Spaces:
Sleeping
Sleeping
| # -*- coding: utf-8 -*- | |
| """ | |
| Created on Tue Jun 10 11:16:28 2025 | |
| @author: camaac | |
| """ | |
| import gradio as gr | |
| from PIL import Image | |
| import torch | |
| from inference import inference | |
| from diffusers import StableDiffusionInstructPix2PixPipeline, UNet2DModel, AutoencoderKL, DDPMScheduler | |
| from transformers import CLIPTokenizer, CLIPTextModel, CLIPImageProcessor | |
| import torch.nn as nn | |
class UNetNoCondWrapper(nn.Module):
    """Adapter that lets an unconditional UNet be called like a conditional one.

    ``forward`` accepts the conditioning keyword arguments named in its
    signature and drops them before calling the wrapped model. Attributes
    that are not found on the wrapper are looked up on the wrapped model.
    """

    # Names that are resolved through ``nn.Module.__getattr__`` on the wrapper
    # itself; every other missing attribute is fetched from the wrapped UNet.
    _WRAPPER_OWNED = ("unet", "forward", "__getstate__", "__setstate__")

    def __init__(self, base_unet: UNet2DModel):
        """Register ``base_unet`` as the ``unet`` submodule."""
        super().__init__()
        self.unet = base_unet

    def forward(
        self,
        sample,
        timestep,
        encoder_hidden_states=None,
        added_cond_kwargs=None,
        cross_attention_kwargs=None,
        return_dict=False,
        **kwargs
    ):
        """Call the wrapped UNet, ignoring the conditioning arguments.

        ``encoder_hidden_states``, ``added_cond_kwargs`` and
        ``cross_attention_kwargs`` are accepted but never forwarded.
        ``sample``, ``timestep``, ``return_dict`` and any extra keyword
        arguments are passed through unchanged.

        Returns:
            Whatever the wrapped model returns for these arguments.
        """
        inner_unet = self.unet
        return inner_unet(sample, timestep, **kwargs, return_dict=return_dict)

    def __getattr__(self, name):
        """Resolve attributes that normal lookup did not find.

        Python only calls this hook after regular attribute lookup fails.
        """
        if name not in self._WRAPPER_OWNED:
            return getattr(self.unet, name)
        return super().__getattr__(name)

    def save_pretrained(self, save_directory, **kwargs):
        """Delegate saving to the wrapped UNet instance."""
        wrapped_model = self.unet
        return wrapped_model.save_pretrained(save_directory, **kwargs)
# Everything is placed on the CPU. To use a GPU when one is available, switch to:
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')
model_id = "CarolineM5/InstructPix2Pix_WithoutPrompt"


def _load_component(component_cls, subfolder):
    """Load ``component_cls`` from the given subfolder of ``model_id``."""
    return component_cls.from_pretrained(model_id, subfolder=subfolder)


# 1) Load the individual pipeline components.
vae = _load_component(AutoencoderKL, "vae").to(device)
scheduler = _load_component(DDPMScheduler, "scheduler")
tokenizer = _load_component(CLIPTokenizer, "tokenizer")
text_encoder = _load_component(CLIPTextModel, "text_encoder").to(device)
feature_extractor = _load_component(CLIPImageProcessor, "feature_extractor")

# 2) Load the unconditional UNet and wrap it.
base_unet = _load_component(UNet2DModel, "unet").to(device)
wrapped_unet = UNetNoCondWrapper(base_unet).to(device)

# 3) Assemble the pipeline manually; no safety checker is attached.
pipe = StableDiffusionInstructPix2PixPipeline(
    unet=wrapped_unet,
    vae=vae,
    scheduler=scheduler,
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    feature_extractor=feature_extractor,
    safety_checker=None,
)
pipe = pipe.to(torch.float32)
pipe = pipe.to(device)
def gradio_generate(fibers_map: Image.Image,
                    rings_map: Image.Image,
                    num_steps: int) -> Image.Image:
    """Gradio callback: run ``inference`` on the two uploaded maps.

    Args:
        fibers_map: Fibre orientation map supplied by the UI
            (``None`` when nothing was uploaded).
        rings_map: Growth ring map supplied by the UI
            (``None`` when nothing was uploaded).
        num_steps: Number of inference steps. The UI number field may
            deliver a float or ``None``, so the value is validated and
            truncated to an ``int`` before use.

    Returns:
        The value returned by ``inference`` (annotated as a PIL image).

    Raises:
        gr.Error: If either image is missing, or ``num_steps`` is not a
            number greater than or equal to 1.
    """
    # An empty image slot reaches the callback as None; report it clearly
    # instead of failing later on ``None.convert``.
    if fibers_map is None or rings_map is None:
        raise gr.Error(
            "Please upload both a fibre orientation map and a growth ring map."
        )

    # Coerce the step count to an int, as the number field is not guaranteed
    # to deliver one.
    try:
        step_count = int(num_steps)
    except (TypeError, ValueError, OverflowError) as exc:
        raise gr.Error(
            "The number of inference steps must be a whole number."
        ) from exc
    if step_count < 1:
        raise gr.Error("The number of inference steps must be at least 1.")

    # 1) Normalise both inputs to the same colour mode.
    fibers_map = fibers_map.convert("RGB")
    rings_map = rings_map.convert("RGB")

    # 2) Run inference. Note the argument order: ring map first, then fibre map.
    result_img = inference(pipe,
                           rings_map,
                           fibers_map,
                           step_count)
    return result_img
# Web UI: two image uploads and a step count in, one generated image out.
# The inputs are listed in the positional order of ``gradio_generate``.
iface = gr.Interface(
    fn=gradio_generate,
    inputs=[
        gr.Image(type="pil", label="Fibre orientation map"),
        gr.Image(type="pil", label="Growth ring map"),
        # precision=0 makes Gradio round the value and pass it as an int;
        # minimum=1 rejects zero or negative step counts in the UI.
        gr.Number(
            value=10,
            label="Number of inference steps",
            precision=0,
            minimum=1,
        ),
    ],
    outputs=gr.Image(
        type="pil",
        label="Photorealistic wood generated",
        format="png",  # force .png for the downloaded file
    ),
    title="Photorealistic wood generator",
    description="""
Upload :
1) a fibre orientation map,
2) a growth ring map.
Set the number of inference steps.
Higher values can improve quality but increase processing time.
The model will return a photo-realistic rendering of the wood that you can download.
"""
)
if __name__ == "__main__":
    # Serve on all network interfaces at port 7860 and request a share link.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": True,
    }
    iface.launch(**launch_options)
| # with gr.Blocks() as demo: | |
| # gr.Markdown("## Photorealistic Wood Generator\nUpload your two maps, run inference, then use the slider to browse steps.") | |
| # with gr.Row(): | |
| # fibers = gr.Image(type="pil", label="Fibre orientation map") | |
| # rings = gr.Image(type="pil", label="Growth ring map") | |
| # steps = gr.Number(value=10, label="Number of inference steps") | |
| # btn = gr.Button("Generate") | |
| # # State pour stocker la liste des images | |
| # state_images = gr.State([]) | |
| # # Slider pour parcourir | |
| # slider = gr.Slider(minimum=0, maximum=0, step=1, value=0, interactive=True, label="Step index") | |
| # # Image affichée | |
| # display = gr.Image(label="Intermediate result") | |
| # # 1) Au clique, on génère et on met à jour state + slider + display | |
| # def run_and_store(fib, ring, num_steps): | |
| # imgs = inference(pipe, ring,fib, int(num_steps)) | |
| # # On renvoie : la liste, la nouvelle valeur max du slider, et l’image 0 | |
| # return imgs, gr.update(maximum=len(imgs)-1, value=0), imgs[0] | |
| # btn.click( | |
| # fn=run_and_store, | |
| # inputs=[fibers, rings, steps], | |
| # outputs=[state_images, slider, display] | |
| # ) | |
| # # 2) Quand on bouge le slider, on affiche state_images[slider] | |
| # def select_step(imgs, idx): | |
| # return imgs[int(idx)] | |
| # slider.change( | |
| # fn=select_step, | |
| # inputs=[state_images, slider], | |
| # outputs=display | |
| # ) | |
| # demo.launch() |