OsamaAbdeljaber committed
Commit 39cd58c · verified · 1 Parent(s): 85e2b85

Update app.py

Files changed (1)
  1. app.py +153 -154
app.py CHANGED
@@ -1,155 +1,154 @@
- # -*- coding: utf-8 -*-
- """
- Created on Tue Jun 10 11:16:28 2025
-
- @author: camaac
- """
-
- import gradio as gr
- from PIL import Image
- import torch
- from inference import inference
- from diffusers import StableDiffusionInstructPix2PixPipeline, UNet2DModel, AutoencoderKL, DDPMScheduler
- from transformers import CLIPTokenizer, CLIPTextModel, CLIPImageProcessor
- import torch.nn as nn
-
- class UNetNoCondWrapper(nn.Module):
-     def __init__(self, base_unet: UNet2DModel):
-         super().__init__()
-         self.unet = base_unet
-
-     def forward(
-         self,
-         sample,
-         timestep,
-         encoder_hidden_states=None,
-         added_cond_kwargs=None,
-         cross_attention_kwargs=None,
-         return_dict=False,
-         **kwargs
-     ):
-
-         return self.unet(sample, timestep, return_dict=return_dict, **kwargs)
-
-     def __getattr__(self, name):
-         if name in ("unet", "forward", "__getstate__", "__setstate__"):
-             return super().__getattr__(name)
-         return getattr(self.unet, name)
-
-     def save_pretrained(self, save_directory, **kwargs):
-         # delegate to the real UNet2DModel instance
-         return self.unet.save_pretrained(save_directory, **kwargs)
-
- # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-
- device = torch.device('cpu')
-
- model_id = "CarolineM5/InstructPix2Pix_WithoutPrompt"
- vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae").to(device)
- scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")
- tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
- text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder").to(device)
- feature_extractor = CLIPImageProcessor.from_pretrained(model_id, subfolder="feature_extractor")
-
- # 2) Load your unconditional UNet and wrap it
- base_unet = UNet2DModel.from_pretrained(model_id, subfolder="unet").to(device)
- wrapped_unet = UNetNoCondWrapper(base_unet).to(device)
-
- # 3) Build the pipeline manually
- pipe = StableDiffusionInstructPix2PixPipeline(
-     vae=vae,
-     text_encoder=text_encoder,
-     tokenizer=tokenizer,
-     unet=wrapped_unet,
-     scheduler=scheduler,
-     safety_checker=None,
-     feature_extractor=feature_extractor,
- )
-
- pipe = pipe.to(torch.float32).to(device)
-
- def gradio_generate(fibers_map: Image.Image,
-                     rings_map: Image.Image,
-                     num_steps: int) -> Image.Image:
-     # 1) normalise the image mode
-     fibers_map = fibers_map.convert("RGB")
-     rings_map = rings_map.convert("RGB")
-
-
-     # 2) run inference
-     result_img = inference(pipe,
-                            rings_map,
-                            fibers_map,
-                            num_steps)
-     return result_img
-
-
- iface = gr.Interface(
-     fn=gradio_generate,
-     inputs=[
-         gr.Image(type="pil", label="Fibre orientation map"),
-         gr.Image(type="pil", label="Growth ring map"),
-         gr.Number(value=10, label="Number of inference steps")
-     ],
-     outputs=gr.Image(
-         type="pil",
-         label="Photorealistic wood generated",
-         format="png"  # ← force .png on download
-     ),
-     title="Photorealistic wood generator",
-     description="""
- Upload:
- 1) a fibre orientation map,
- 2) a growth ring map.
-
- Set the number of inference steps.
- Higher values can improve quality but increase processing time.
-
- The model will return a photo-realistic rendering of the wood that you can download.
- """
- )
-
- if __name__ == "__main__":
-     iface.launch(server_name="0.0.0.0", server_port=7860, share=True)
-
-
- # with gr.Blocks() as demo:
- #     gr.Markdown("## Photorealistic Wood Generator\nUpload your two maps, run inference, then use the slider to browse steps.")
-
- #     with gr.Row():
- #         fibers = gr.Image(type="pil", label="Fibre orientation map")
- #         rings = gr.Image(type="pil", label="Growth ring map")
- #     steps = gr.Number(value=10, label="Number of inference steps")
- #     btn = gr.Button("Generate")
-
- #     # State to store the list of images
- #     state_images = gr.State([])
-
- #     # Slider to browse the steps
- #     slider = gr.Slider(minimum=0, maximum=0, step=1, value=0, interactive=True, label="Step index")
- #     # Displayed image
- #     display = gr.Image(label="Intermediate result")
-
- #     # 1) On click, generate and update state + slider + display
- #     def run_and_store(fib, ring, num_steps):
- #         imgs = inference(pipe, ring, fib, int(num_steps))
- #         # Return: the list, the slider's new maximum, and image 0
- #         return imgs, gr.update(maximum=len(imgs)-1, value=0), imgs[0]
-
- #     btn.click(
- #         fn=run_and_store,
- #         inputs=[fibers, rings, steps],
- #         outputs=[state_images, slider, display]
- #     )
-
- #     # 2) When the slider moves, display state_images[slider]
- #     def select_step(imgs, idx):
- #         return imgs[int(idx)]
-
- #     slider.change(
- #         fn=select_step,
- #         inputs=[state_images, slider],
- #         outputs=display
- #     )
-
 
+ # -*- coding: utf-8 -*-
+ """
+ Created on Tue Jun 10 11:16:28 2025
+
+ @author: camaac
+ """
+ import spaces
+ import gradio as gr
+ from PIL import Image
+ import torch
+ from inference import inference
+ from diffusers import StableDiffusionInstructPix2PixPipeline, UNet2DModel, AutoencoderKL, DDPMScheduler
+ from transformers import CLIPTokenizer, CLIPTextModel, CLIPImageProcessor
+ import torch.nn as nn
+
+ class UNetNoCondWrapper(nn.Module):
+     def __init__(self, base_unet: UNet2DModel):
+         super().__init__()
+         self.unet = base_unet
+
+     def forward(
+         self,
+         sample,
+         timestep,
+         encoder_hidden_states=None,
+         added_cond_kwargs=None,
+         cross_attention_kwargs=None,
+         return_dict=False,
+         **kwargs
+     ):
+
+         return self.unet(sample, timestep, return_dict=return_dict, **kwargs)
+
+     def __getattr__(self, name):
+         if name in ("unet", "forward", "__getstate__", "__setstate__"):
+             return super().__getattr__(name)
+         return getattr(self.unet, name)
+
+     def save_pretrained(self, save_directory, **kwargs):
+         # delegate to the real UNet2DModel instance
+         return self.unet.save_pretrained(save_directory, **kwargs)
+
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+ model_id = "CarolineM5/InstructPix2Pix_WithoutPrompt"
+ vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae").to(device)
+ scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")
+ tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
+ text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder").to(device)
+ feature_extractor = CLIPImageProcessor.from_pretrained(model_id, subfolder="feature_extractor")
+
+ # 2) Load your unconditional UNet and wrap it
+ base_unet = UNet2DModel.from_pretrained(model_id, subfolder="unet").to(device)
+ wrapped_unet = UNetNoCondWrapper(base_unet).to(device)
+
+ # 3) Build the pipeline manually
+ pipe = StableDiffusionInstructPix2PixPipeline(
+     vae=vae,
+     text_encoder=text_encoder,
+     tokenizer=tokenizer,
+     unet=wrapped_unet,
+     scheduler=scheduler,
+     safety_checker=None,
+     feature_extractor=feature_extractor,
+ )
+
+ pipe = pipe.to(torch.float32).to(device)
+
+ @spaces.GPU
+ def gradio_generate(fibers_map: Image.Image,
+                     rings_map: Image.Image,
+                     num_steps: int) -> Image.Image:
+     # 1) normalise the image mode
+     fibers_map = fibers_map.convert("RGB")
+     rings_map = rings_map.convert("RGB")
+
+
+     # 2) run inference
+     result_img = inference(pipe,
+                            rings_map,
+                            fibers_map,
+                            num_steps)
+     return result_img
+
+
+ iface = gr.Interface(
+     fn=gradio_generate,
+     inputs=[
+         gr.Image(type="pil", label="Fibre orientation map"),
+         gr.Image(type="pil", label="Growth ring map"),
+         gr.Number(value=10, label="Number of inference steps")
+     ],
+     outputs=gr.Image(
+         type="pil",
+         label="Photorealistic wood generated",
+         format="png"  # force .png on download
+     ),
+     title="Photorealistic wood generator",
+     description="""
+ Upload:
+ 1) a fibre orientation map,
+ 2) a growth ring map.
+
+ Set the number of inference steps.
+ Higher values can improve quality but increase processing time.
+
+ The model will return a photo-realistic rendering of the wood that you can download.
+ """
+ )
+
+ if __name__ == "__main__":
+     iface.launch(server_name="0.0.0.0", server_port=7860, share=True)
+
+
+ # with gr.Blocks() as demo:
+ #     gr.Markdown("## Photorealistic Wood Generator\nUpload your two maps, run inference, then use the slider to browse steps.")
+
+ #     with gr.Row():
+ #         fibers = gr.Image(type="pil", label="Fibre orientation map")
+ #         rings = gr.Image(type="pil", label="Growth ring map")
+ #     steps = gr.Number(value=10, label="Number of inference steps")
+ #     btn = gr.Button("Generate")
+
+ #     # State to store the list of images
+ #     state_images = gr.State([])
+
+ #     # Slider to browse the steps
+ #     slider = gr.Slider(minimum=0, maximum=0, step=1, value=0, interactive=True, label="Step index")
+ #     # Displayed image
+ #     display = gr.Image(label="Intermediate result")
+
+ #     # 1) On click, generate and update state + slider + display
+ #     def run_and_store(fib, ring, num_steps):
+ #         imgs = inference(pipe, ring, fib, int(num_steps))
+ #         # Return: the list, the slider's new maximum, and image 0
+ #         return imgs, gr.update(maximum=len(imgs)-1, value=0), imgs[0]
+
+ #     btn.click(
+ #         fn=run_and_store,
+ #         inputs=[fibers, rings, steps],
+ #         outputs=[state_images, slider, display]
+ #     )
+
+ #     # 2) When the slider moves, display state_images[slider]
+ #     def select_step(imgs, idx):
+ #         return imgs[int(idx)]
+
+ #     slider.change(
+ #         fn=select_step,
+ #         inputs=[state_images, slider],
+ #         outputs=display
+ #     )
+
  # demo.launch()
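
The substance of this commit is the move from pinned CPU inference to GPU inference on Hugging Face ZeroGPU: the new version adds import spaces, selects cuda when available, and decorates the Gradio handler with @spaces.GPU, which attaches a GPU only while the decorated function runs. A minimal sketch of that pattern, assuming a ZeroGPU Space; the checkpoint and handler below are illustrative stand-ins, not this app's actual code:

    # Minimal ZeroGPU sketch; checkpoint and handler are hypothetical.
    import spaces
    import torch
    from diffusers import StableDiffusionInstructPix2PixPipeline

    # Load once at startup; on ZeroGPU the spaces package lets startup code
    # target 'cuda' even though the physical GPU is attached lazily.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
        "timbrooks/instruct-pix2pix"  # hypothetical checkpoint, for illustration only
    ).to(device)

    @spaces.GPU  # a GPU is held only for the duration of each call
    def generate(image, steps):
        return pipe("", image=image, num_inference_steps=int(steps)).images[0]

Loading the pipeline once at import time and putting only the handler behind the decorator keeps per-request GPU time short, which is why the decorator lands on gradio_generate rather than on the loading code.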
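
For context on the UNetNoCondWrapper class carried on both sides of the diff: StableDiffusionInstructPix2PixPipeline calls its UNet as a text-conditional model, passing encoder_hidden_states and related keyword arguments, while the unconditional UNet2DModel.forward accepts only (sample, timestep). The wrapper absorbs the conditioning arguments and delegates everything else. A toy sanity check of that behaviour; the small UNet config here is hypothetical, chosen only so the snippet runs quickly, and UNetNoCondWrapper is the class from app.py above:

    # Toy check that the wrapper tolerates the pipeline's conditional calling convention.
    import torch
    from diffusers import UNet2DModel

    unet = UNet2DModel(
        sample_size=8, in_channels=3, out_channels=3,
        block_out_channels=(32, 64),
        down_block_types=("DownBlock2D", "AttnDownBlock2D"),
        up_block_types=("AttnUpBlock2D", "UpBlock2D"),
    )
    wrapped = UNetNoCondWrapper(unet)  # class defined in app.py above
    sample = torch.randn(1, 3, 8, 8)
    out = wrapped(sample, torch.tensor([1]),
                  encoder_hidden_states=torch.randn(1, 77, 768))  # dropped by the wrapper
    print(out[0].shape)  # torch.Size([1, 3, 8, 8]); return_dict=False yields a tuple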