# -*- coding: utf-8 -*-
"""
Created on Tue Jun 10 11:16:28 2025

@author: camaac
"""

import gradio as gr
from PIL import Image
import torch
from inference import inference
from diffusers import StableDiffusionInstructPix2PixPipeline, UNet2DModel, AutoencoderKL, DDPMScheduler 
from transformers import CLIPTokenizer, CLIPTextModel, CLIPImageProcessor
import torch.nn as nn

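# UNet2DModel is an unconditional UNet, while StableDiffusionInstructPix2PixPipeline calls its
# UNet with the text-conditioning arguments of a UNet2DConditionModel. This thin wrapper accepts
# and discards those extra arguments so the unconditional UNet can be dropped into the pipeline.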
class UNetNoCondWrapper(nn.Module):
    def __init__(self, base_unet: UNet2DModel):
        super().__init__()
        self.unet = base_unet

    def forward(
        self,
        sample,
        timestep,
        encoder_hidden_states=None,    # ignored: the wrapped UNet is unconditional
        added_cond_kwargs=None,        # ignored
        cross_attention_kwargs=None,   # ignored
        return_dict=False,
        **kwargs
    ):
        # Drop all conditioning arguments and forward only what UNet2DModel expects
        return self.unet(sample, timestep, return_dict=return_dict, **kwargs)

    def __getattr__(self, name):
        if name in ("unet", "forward", "__getstate__", "__setstate__"):
            return super().__getattr__(name)
        return getattr(self.unet, name)
    
    def save_pretrained(self, save_directory, **kwargs):
        # delegate to the real UNet2DModel instance
        return self.unet.save_pretrained(save_directory, **kwargs)
    
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

device = torch.device('cpu')

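# 1) Load the pretrained pipeline components (VAE, scheduler, tokenizer, text encoder, feature extractor)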
model_id = "CarolineM5/InstructPix2Pix_WithoutPrompt"
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae").to(device)
scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder").to(device)
feature_extractor = CLIPImageProcessor.from_pretrained(model_id, subfolder="feature_extractor")

# 2) Load the unconditioned UNet and wrap it
base_unet = UNet2DModel.from_pretrained(model_id, subfolder="unet").to(device)
wrapped_unet = UNetNoCondWrapper(base_unet).to(device)

# 3) Build the pipeline manually
pipe = StableDiffusionInstructPix2PixPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=wrapped_unet,
    scheduler=scheduler,
    safety_checker=None,
    feature_extractor=feature_extractor,
)

pipe = pipe.to(torch.float32).to(device)
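
# The app runs on CPU in float32. If a CUDA GPU were available, the same pipeline could be moved
# there instead (a sketch, mirroring the commented-out device line above):
# if torch.cuda.is_available():
#     pipe = pipe.to(torch.float16).to("cuda")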

def gradio_generate(fibers_map: Image.Image,
                    rings_map: Image.Image,
                    num_steps: int) -> Image.Image:
    # 1) normalise both input maps to RGB
    fibers_map = fibers_map.convert("RGB")
    rings_map = rings_map.convert("RGB")

    # 2) run the diffusion inference
    result_img = inference(pipe,
                           rings_map,
                           fibers_map,
                           num_steps)
    return result_img

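# gradio_generate can also be called directly, outside the Gradio UI
# (a sketch; "fibers.png" and "rings.png" are placeholder file names):
#
#   out = gradio_generate(Image.open("fibers.png"), Image.open("rings.png"), 10)
#   out.save("generated_wood.png")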

iface = gr.Interface(
    fn=gradio_generate,
    inputs=[
        gr.Image(type="pil", label="Fibre orientation map"),
        gr.Image(type="pil", label="Growth ring map"),
        gr.Number(value=10, label="Number of inference steps")
    ],
    outputs=gr.Image(
        type="pil",
        label="Generated photorealistic wood",
        format="png"           # force .png as the download format
    ),
    title="Photorealistic wood generator",
    description="""
    Upload:
    1) a fibre orientation map,
    2) a growth ring map.

    Set the number of inference steps. Higher values can improve quality but increase processing time.

    The model will return a photorealistic rendering of the wood that you can download.
    """
)

if __name__ == "__main__":
    iface.launch(server_name="0.0.0.0", server_port=7860, share=True)
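    # For long-running generations, enabling Gradio's request queue before launch may help
    # (a sketch; behaviour depends on the installed Gradio version):
    # iface.queue().launch(server_name="0.0.0.0", server_port=7860, share=True)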


# with gr.Blocks() as demo:
#     gr.Markdown("## Photorealistic Wood Generator\nUpload your two maps, run inference, then use the slider to browse steps.")
    
#     with gr.Row():
#         fibers = gr.Image(type="pil", label="Fibre orientation map")
#         rings  = gr.Image(type="pil", label="Growth ring map")
#     steps = gr.Number(value=10, label="Number of inference steps")
#     btn   = gr.Button("Generate")

#     # State holding the list of intermediate images
#     state_images = gr.State([])

#     # Slider to browse through the steps
#     slider = gr.Slider(minimum=0, maximum=0, step=1, value=0, interactive=True, label="Step index")
#     # Displayed image
#     display = gr.Image(label="Intermediate result")

#     # 1) On click, generate and update state + slider + display
#     def run_and_store(fib, ring, num_steps):
#         imgs = inference(pipe, ring, fib, int(num_steps))
#         # Return: the image list, the new slider maximum, and image 0
#         return imgs, gr.update(maximum=len(imgs)-1, value=0), imgs[0]

#     btn.click(
#         fn=run_and_store,
#         inputs=[fibers, rings, steps],
#         outputs=[state_images, slider, display]
#     )

#     # 2) When the slider moves, display state_images[slider]
#     def select_step(imgs, idx):
#         return imgs[int(idx)]

#     slider.change(
#         fn=select_step,
#         inputs=[state_images, slider],
#         outputs=display
#     )

#     demo.launch()