# -*- coding: utf-8 -*- """ Created on Tue Jun 10 11:16:28 2025 @author: camaac """ import gradio as gr from PIL import Image import torch from inference import inference from diffusers import StableDiffusionInstructPix2PixPipeline, UNet2DModel, AutoencoderKL, DDPMScheduler from transformers import CLIPTokenizer, CLIPTextModel, CLIPImageProcessor import torch.nn as nn class UNetNoCondWrapper(nn.Module): def __init__(self, base_unet: UNet2DModel): super().__init__() self.unet = base_unet def forward( self, sample, timestep, encoder_hidden_states=None, added_cond_kwargs=None, cross_attention_kwargs=None, return_dict=False, **kwargs ): return self.unet(sample, timestep, return_dict=return_dict, **kwargs) def __getattr__(self, name): if name in ("unet", "forward", "__getstate__", "__setstate__"): return super().__getattr__(name) return getattr(self.unet, name) def save_pretrained(self, save_directory, **kwargs): # délègue à la vraie instance UNet2DModel return self.unet.save_pretrained(save_directory, **kwargs) # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') device = torch.device('cpu') model_id = "CarolineM5/InstructPix2Pix_WithoutPrompt" vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae").to(device) scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler") tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer") text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder").to(device) feature_extractor = CLIPImageProcessor.from_pretrained(model_id, subfolder="feature_extractor") # 2) Chargez votre UNet non‑conditionné et wrappez‑le base_unet = UNet2DModel.from_pretrained(model_id, subfolder="unet").to(device) wrapped_unet = UNetNoCondWrapper(base_unet).to(device) # 3) Construisez la pipeline manuellement pipe = StableDiffusionInstructPix2PixPipeline( vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=wrapped_unet, scheduler=scheduler, safety_checker=None, feature_extractor=feature_extractor, ) pipe = pipe.to(torch.float32).to(device) def gradio_generate(fibers_map: Image.Image, rings_map: Image.Image, num_steps: int) -> Image.Image: # 1) uniformiser le mode fibers_map = fibers_map.convert("RGB") rings_map = rings_map.convert("RGB") # 3) appeler l'inference avec la seed result_img = inference(pipe, rings_map, fibers_map, num_steps) return result_img iface = gr.Interface( fn=gradio_generate, inputs=[ gr.Image(type="pil", label="Fibre orientation map"), gr.Image(type="pil", label="Growth ring map"), gr.Number(value=10, label="Number of inference steps") ], outputs=gr.Image( type="pil", label="Photorealistic wood generated", format="png" # ← force le .png au téléchargement ), title="Photorealistic wood generator", description=""" Upload : 1) a fibre orientation map, 2) a growth ring map. Set the number of inference steps. Higher values can improve quality but increase processing time. The model will return a photo-realistic rendering of the wood that you can download. """ ) if __name__ == "__main__": iface.launch(server_name="0.0.0.0", server_port=7860, share=True) # with gr.Blocks() as demo: # gr.Markdown("## Photorealistic Wood Generator\nUpload your two maps, run inference, then use the slider to browse steps.") # with gr.Row(): # fibers = gr.Image(type="pil", label="Fibre orientation map") # rings = gr.Image(type="pil", label="Growth ring map") # steps = gr.Number(value=10, label="Number of inference steps") # btn = gr.Button("Generate") # # State pour stocker la liste des images # state_images = gr.State([]) # # Slider pour parcourir # slider = gr.Slider(minimum=0, maximum=0, step=1, value=0, interactive=True, label="Step index") # # Image affichée # display = gr.Image(label="Intermediate result") # # 1) Au clique, on génère et on met à jour state + slider + display # def run_and_store(fib, ring, num_steps): # imgs = inference(pipe, ring,fib, int(num_steps)) # # On renvoie : la liste, la nouvelle valeur max du slider, et l’image 0 # return imgs, gr.update(maximum=len(imgs)-1, value=0), imgs[0] # btn.click( # fn=run_and_store, # inputs=[fibers, rings, steps], # outputs=[state_images, slider, display] # ) # # 2) Quand on bouge le slider, on affiche state_images[slider] # def select_step(imgs, idx): # return imgs[int(idx)] # slider.change( # fn=select_step, # inputs=[state_images, slider], # outputs=display # ) # demo.launch()