Spaces:
Sleeping
Sleeping
File size: 5,282 Bytes
38ecdbd 76bee53 990a91c e5dea97 38ecdbd 990a91c 38ecdbd 990a91c 38ecdbd 990a91c e40eff1 38ecdbd e40eff1 38ecdbd 990a91c b3652cf 7402c4e b3652cf 7402c4e b3652cf 7402c4e b3652cf 7402c4e b3652cf 7402c4e 85e2b85 7402c4e 00edf85 7402c4e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
# -*- coding: utf-8 -*-
"""
Created on Tue Jun 10 11:16:28 2025
@author: camaac
"""
import gradio as gr
from PIL import Image
import torch
from inference import inference
from diffusers import StableDiffusionInstructPix2PixPipeline, UNet2DModel, AutoencoderKL, DDPMScheduler
from transformers import CLIPTokenizer, CLIPTextModel, CLIPImageProcessor
import torch.nn as nn
class UNetNoCondWrapper(nn.Module):
    """Adapter that lets an unconditional ``UNet2DModel`` stand in for the
    text-conditioned UNet expected by ``StableDiffusionInstructPix2PixPipeline``.

    The pipeline passes conditioning kwargs (``encoder_hidden_states``,
    ``added_cond_kwargs``, ...) at every denoising step; this wrapper accepts
    and silently drops them, forwarding only ``(sample, timestep)`` to the
    wrapped UNet.
    """
    def __init__(self, base_unet: UNet2DModel):
        super().__init__()
        # Registered as a submodule so .to(device) / .parameters() recurse into it.
        self.unet = base_unet
    def forward(
        self,
        sample,
        timestep,
        encoder_hidden_states=None,   # accepted but ignored (unconditional UNet)
        added_cond_kwargs=None,       # ignored
        cross_attention_kwargs=None,  # ignored
        return_dict=False,
        **kwargs
    ):
        # Drop all conditioning arguments; the wrapped UNet2DModel only
        # receives (sample, timestep) plus whatever extra kwargs remain.
        return self.unet(sample, timestep, return_dict=return_dict, **kwargs)
    def __getattr__(self, name):
        # __getattr__ only fires when normal lookup fails. "unet" lives in
        # nn.Module's _modules registry, so it must be resolved through
        # nn.Module.__getattr__; every other missing attribute (e.g. .config,
        # .dtype) is proxied to the wrapped UNet so the pipeline treats this
        # wrapper as a real UNet.
        if name in ("unet", "forward", "__getstate__", "__setstate__"):
            return super().__getattr__(name)
        return getattr(self.unet, name)
    def save_pretrained(self, save_directory, **kwargs):
        # Delegate to the real UNet2DModel instance so the saved checkpoint
        # keeps the standard diffusers directory layout.
        return self.unet.save_pretrained(save_directory, **kwargs)
# Inference runs on CPU; switch to 'cuda' here if a GPU becomes available.
device = torch.device('cpu')
model_id = "CarolineM5/InstructPix2Pix_WithoutPrompt"

# 1) Load each pipeline component individually from the hub checkpoint.
scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
feature_extractor = CLIPImageProcessor.from_pretrained(model_id, subfolder="feature_extractor")
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae").to(device)
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder").to(device)

# 2) Load the unconditional UNet and wrap it so it tolerates the conditioning
#    kwargs the InstructPix2Pix pipeline passes on every denoising step.
wrapped_unet = UNetNoCondWrapper(
    UNet2DModel.from_pretrained(model_id, subfolder="unet").to(device)
).to(device)

# 3) Assemble the pipeline by hand (this model ships no safety checker).
pipe = StableDiffusionInstructPix2PixPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=wrapped_unet,
    scheduler=scheduler,
    safety_checker=None,
    feature_extractor=feature_extractor,
).to(torch.float32).to(device)
def gradio_generate(fibers_map: Image.Image,
                    rings_map: Image.Image,
                    num_steps: int) -> Image.Image:
    """Run the diffusion pipeline on a pair of guidance maps.

    Parameters
    ----------
    fibers_map : PIL.Image.Image
        Fibre orientation map uploaded by the user.
    rings_map : PIL.Image.Image
        Growth ring map uploaded by the user.
    num_steps : int
        Number of denoising steps. ``gr.Number`` delivers a float at
        runtime, so the value is cast to ``int`` before use.

    Returns
    -------
    PIL.Image.Image
        The generated photorealistic wood image.
    """
    # 1) Normalize both inputs to RGB so the pipeline sees a consistent mode.
    fibers_map = fibers_map.convert("RGB")
    rings_map = rings_map.convert("RGB")
    # 2) gr.Number yields a float; schedulers require an integer step count
    #    (the Blocks prototype below already casts with int(num_steps)).
    return inference(pipe,
                     rings_map,
                     fibers_map,
                     int(num_steps))
# Build the UI pieces separately, then hand them to gr.Interface.
_input_components = [
    gr.Image(type="pil", label="Fibre orientation map"),
    gr.Image(type="pil", label="Growth ring map"),
    gr.Number(value=10, label="Number of inference steps"),
]
_output_component = gr.Image(
    type="pil",
    label="Photorealistic wood generated",
    format="png",  # force .png when the result is downloaded
)
iface = gr.Interface(
    fn=gradio_generate,
    inputs=_input_components,
    outputs=_output_component,
    title="Photorealistic wood generator",
    description="""
Upload :
1) a fibre orientation map,
2) a growth ring map.
Set the number of inference steps.
Higher values can improve quality but increase processing time.
The model will return a photo-realistic rendering of the wood that you can download.
"""
)
if __name__ == "__main__":
    # Bind to all interfaces on the standard HF Spaces port (7860);
    # share=True additionally opens a public gradio.live tunnel.
    iface.launch(server_name="0.0.0.0", server_port=7860, share=True)
# with gr.Blocks() as demo:
# gr.Markdown("## Photorealistic Wood Generator\nUpload your two maps, run inference, then use the slider to browse steps.")
# with gr.Row():
# fibers = gr.Image(type="pil", label="Fibre orientation map")
# rings = gr.Image(type="pil", label="Growth ring map")
# steps = gr.Number(value=10, label="Number of inference steps")
# btn = gr.Button("Generate")
# # State pour stocker la liste des images
# state_images = gr.State([])
# # Slider pour parcourir
# slider = gr.Slider(minimum=0, maximum=0, step=1, value=0, interactive=True, label="Step index")
# # Image affichée
# display = gr.Image(label="Intermediate result")
# # 1) Au clique, on génère et on met à jour state + slider + display
# def run_and_store(fib, ring, num_steps):
# imgs = inference(pipe, ring,fib, int(num_steps))
# # On renvoie : la liste, la nouvelle valeur max du slider, et l’image 0
# return imgs, gr.update(maximum=len(imgs)-1, value=0), imgs[0]
# btn.click(
# fn=run_and_store,
# inputs=[fibers, rings, steps],
# outputs=[state_images, slider, display]
# )
# # 2) Quand on bouge le slider, on affiche state_images[slider]
# def select_step(imgs, idx):
# return imgs[int(idx)]
# slider.change(
# fn=select_step,
# inputs=[state_images, slider],
# outputs=display
# )
# demo.launch()