# BoardGenerator / app.py
# (Hugging Face Space listing residue: CarolineM5's picture — Upload app.py, 85e2b85 verified)
# -*- coding: utf-8 -*-
"""
Created on Tue Jun 10 11:16:28 2025
@author: camaac
"""
import gradio as gr
from PIL import Image
import torch
from inference import inference
from diffusers import StableDiffusionInstructPix2PixPipeline, UNet2DModel, AutoencoderKL, DDPMScheduler
from transformers import CLIPTokenizer, CLIPTextModel, CLIPImageProcessor
import torch.nn as nn
class UNetNoCondWrapper(nn.Module):
    """Adapter that lets an unconditional ``UNet2DModel`` stand in for the
    text-conditioned UNet expected by ``StableDiffusionInstructPix2PixPipeline``.

    The pipeline passes conditioning tensors (``encoder_hidden_states``,
    ``added_cond_kwargs``, ``cross_attention_kwargs``) on every denoising call;
    this wrapper accepts and discards them, forwarding only the sample and
    timestep to the wrapped model.
    """

    # NOTE: the annotation is a string (postponed evaluation) so the class can
    # be imported and unit-tested without diffusers being importable; the
    # original eager annotation forced diffusers at class-definition time.
    def __init__(self, base_unet: "UNet2DModel"):
        super().__init__()
        self.unet = base_unet

    def forward(
        self,
        sample,
        timestep,
        encoder_hidden_states=None,
        added_cond_kwargs=None,
        cross_attention_kwargs=None,
        return_dict=False,
        **kwargs
    ):
        # Conditioning arguments are accepted only for signature compatibility
        # with the pipeline and deliberately ignored: the wrapped UNet2DModel
        # is unconditional.
        return self.unet(sample, timestep, return_dict=return_dict, **kwargs)

    def __getattr__(self, name):
        # nn.Module resolves registered submodules ("unet") through its own
        # __getattr__, so those names must go to the parent class; every other
        # missing attribute (e.g. `config`) is delegated to the wrapped UNet
        # so the pipeline can read it transparently.
        if name in ("unet", "forward", "__getstate__", "__setstate__"):
            return super().__getattr__(name)
        return getattr(self.unet, name)

    def save_pretrained(self, save_directory, **kwargs):
        # Delegate to the real UNet2DModel instance so checkpoints are written
        # in the standard diffusers layout.
        return self.unet.save_pretrained(save_directory, **kwargs)
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# NOTE(review): inference is pinned to CPU in this Space; uncomment the line
# above (and comment the next one) to use CUDA when available.
device = torch.device('cpu')
# Hugging Face Hub repo holding the fine-tuned pipeline components.
model_id = "CarolineM5/InstructPix2Pix_WithoutPrompt"
# 1) Load each pipeline component from its subfolder in the Hub repo.
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae").to(device)
scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder").to(device)
feature_extractor = CLIPImageProcessor.from_pretrained(model_id, subfolder="feature_extractor")
# 2) Load the non-conditioned UNet and wrap it so it matches the conditioned
#    call signature the pipeline expects.
base_unet = UNet2DModel.from_pretrained(model_id, subfolder="unet").to(device)
wrapped_unet = UNetNoCondWrapper(base_unet).to(device)
# 3) Build the pipeline manually (safety checker intentionally disabled
#    for this texture-generation model).
pipe = StableDiffusionInstructPix2PixPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=wrapped_unet,
    scheduler=scheduler,
    safety_checker=None,
    feature_extractor=feature_extractor,
)
# Force float32 on CPU; half precision is a CUDA-only optimization.
pipe = pipe.to(torch.float32).to(device)
def gradio_generate(fibers_map: Image.Image,
                    rings_map: Image.Image,
                    num_steps: int) -> Image.Image:
    """Run the wood-texture pipeline on the two uploaded maps.

    Parameters
    ----------
    fibers_map : PIL.Image.Image
        Fibre orientation map uploaded by the user.
    rings_map : PIL.Image.Image
        Growth ring map uploaded by the user.
    num_steps : int
        Number of diffusion inference steps (gr.Number may deliver a float).

    Returns
    -------
    PIL.Image.Image
        The generated photorealistic wood image.
    """
    # Normalise both inputs to 3-channel RGB regardless of the upload format
    # (e.g. RGBA PNGs or greyscale maps).
    fibers_map = fibers_map.convert("RGB")
    rings_map = rings_map.convert("RGB")
    # gr.Number delivers its value as a float; schedulers require an integer
    # step count (the Blocks variant below also casts), so coerce it here.
    result_img = inference(pipe,
                           rings_map,
                           fibers_map,
                           int(num_steps))
    return result_img
# Single-page Gradio UI: two map uploads plus a step count in, one PNG out.
_wood_inputs = [
    gr.Image(type="pil", label="Fibre orientation map"),
    gr.Image(type="pil", label="Growth ring map"),
    gr.Number(value=10, label="Number of inference steps"),
]
_wood_output = gr.Image(
    type="pil",
    label="Photorealistic wood generated",
    format="png",  # force .png on download
)
iface = gr.Interface(
    fn=gradio_generate,
    inputs=_wood_inputs,
    outputs=_wood_output,
    title="Photorealistic wood generator",
    description="""
Upload :
1) a fibre orientation map,
2) a growth ring map.
Set the number of inference steps.
Higher values can improve quality but increase processing time.
The model will return a photo-realistic rendering of the wood that you can download.
""",
)
if __name__ == "__main__":
    # Bind to all interfaces on the standard Spaces port; share=True also
    # requests a public gradio.live tunnel in addition to the local server.
    iface.launch(server_name="0.0.0.0", server_port=7860, share=True)
# with gr.Blocks() as demo:
# gr.Markdown("## Photorealistic Wood Generator\nUpload your two maps, run inference, then use the slider to browse steps.")
# with gr.Row():
# fibers = gr.Image(type="pil", label="Fibre orientation map")
# rings = gr.Image(type="pil", label="Growth ring map")
# steps = gr.Number(value=10, label="Number of inference steps")
# btn = gr.Button("Generate")
# # State to store the list of generated images
# state_images = gr.State([])
# # Slider to browse the steps
# slider = gr.Slider(minimum=0, maximum=0, step=1, value=0, interactive=True, label="Step index")
# # Displayed image
# display = gr.Image(label="Intermediate result")
# # 1) On click, generate and update state + slider + display
# def run_and_store(fib, ring, num_steps):
# imgs = inference(pipe, ring,fib, int(num_steps))
# # Return: the list, the new slider maximum, and image 0
# return imgs, gr.update(maximum=len(imgs)-1, value=0), imgs[0]
# btn.click(
# fn=run_and_store,
# inputs=[fibers, rings, steps],
# outputs=[state_images, slider, display]
# )
# # 2) When the slider moves, display state_images[slider]
# def select_step(imgs, idx):
# return imgs[int(idx)]
# slider.change(
# fn=select_step,
# inputs=[state_images, slider],
# outputs=display
# )
# demo.launch()