CarolineM5 committed
Commit 00edf85 (verified)
Parent(s): 4549fe0

Upload 2 files

Files changed (2)
  1. app.py +79 -40
  2. inference.py +16 -71
app.py CHANGED
@@ -68,46 +68,85 @@ pipe = StableDiffusionInstructPix2PixPipeline(
 
 pipe = pipe.to(torch.float32).to(device)
 
-
-
- # --- 3) GRADIO INTERFACE FUNCTION ---
- def gradio_generate(fibers_map: Image.Image, rings_map: Image.Image, num_steps):  # -> Image.Image
-     """
-     This function is called by Gradio on each upload.
-     It must return a PIL.Image (or a path to the saved image).
-     """
-     # Make sure both images are in RGB mode (adapt if needed)
-     fibers_map = fibers_map.convert("RGB")
-     rings_map = rings_map.convert("RGB")
-
-     result_img = inference(pipe, device, rings_map, fibers_map, num_steps)
-
-     return result_img
-
- # --- 4) GRADIO INTERFACE DEFINITION ---
- iface = gr.Interface(
-     fn=gradio_generate,
-     inputs=[
-         gr.Image(type="pil", label="Fibre orientation map"),
-         gr.Image(type="pil", label="Growth ring map"),
-         gr.Number(value=20, label="Number of inference steps")
-     ],
-     outputs=gr.Image(type="pil", label="Photorealistic wood generated"),
-     title="Photorealistic wood generator",
-     description="""
-     Upload:
-     1) a fibre orientation map,
-     2) a growth ring map.
-
-     Set the number of inference steps.
-     Higher values can improve quality but increase processing time.
-
-     The model will return a photo-realistic rendering of the wood that you can download.
-     """
- )
-
- # --- 5) LAUNCH THE APPLICATION ---
- if __name__ == "__main__":
-     # You can set `server_name="0.0.0.0"` if you want the app to be reachable on the network,
-     # and `server_port=7860` (or another port) to customise it.
-     iface.launch(server_name="0.0.0.0", server_port=7860, share=False)
+ with gr.Blocks() as demo:
+     gr.Markdown("## Photorealistic Wood Generator\nUpload your two maps, run inference, then use the slider to browse steps.")
+
+     with gr.Row():
+         fibers = gr.Image(type="pil", label="Fibre orientation map")
+         rings = gr.Image(type="pil", label="Growth ring map")
+         steps = gr.Number(value=10, label="Number of inference steps")
+         btn = gr.Button("Generate")
+
+     # State to store the list of images
+     state_images = gr.State([])
+
+     # Slider to browse the steps
+     slider = gr.Slider(minimum=0, maximum=0, step=1, value=0, interactive=True, label="Step index")
+     # Displayed image
+     display = gr.Image(label="Intermediate result")
+
+     # 1) On click: generate, then update state + slider + display
+     def run_and_store(fib, ring, num_steps):
+         imgs = inference(pipe, fib, ring, int(num_steps))
+         # Return: the image list, the new slider maximum, and image 0
+         return imgs, gr.Slider.update(maximum=len(imgs) - 1, value=0), imgs[0]
+
+     btn.click(
+         fn=run_and_store,
+         inputs=[fibers, rings, steps],
+         outputs=[state_images, slider, display]
+     )
+
+     # 2) When the slider moves, display state_images[slider]
+     def select_step(imgs, idx):
+         return imgs[int(idx)]
+
+     slider.change(
+         fn=select_step,
+         inputs=[state_images, slider],
+         outputs=display
+     )
+
+ demo.launch()
+
+ # # --- 3) GRADIO INTERFACE FUNCTION ---
+ # def gradio_generate(fibers_map: Image.Image, rings_map: Image.Image, num_steps):  # -> Image.Image
+ #     """
+ #     This function is called by Gradio on each upload.
+ #     It must return a PIL.Image (or a path to the saved image).
+ #     """
+ #     # Make sure both images are in RGB mode (adapt if needed)
+ #     fibers_map = fibers_map.convert("RGB")
+ #     rings_map = rings_map.convert("RGB")
+
+ #     result_img = inference(pipe, rings_map, fibers_map, num_steps)
+
+ #     return result_img
+
+ # # --- 4) GRADIO INTERFACE DEFINITION ---
+ # iface = gr.Interface(
+ #     fn=gradio_generate,
+ #     inputs=[
+ #         gr.Image(type="pil", label="Fibre orientation map"),
+ #         gr.Image(type="pil", label="Growth ring map"),
+ #         gr.Number(value=10, label="Number of inference steps")
+ #     ],
+ #     outputs=gr.Image(type="pil", label="Photorealistic wood generated"),
+ #     title="Photorealistic wood generator",
+ #     description="""
+ #     Upload:
+ #     1) a fibre orientation map,
+ #     2) a growth ring map.
+
+ #     Set the number of inference steps.
+ #     Higher values can improve quality but increase processing time.
+
+ #     The model will return a photo-realistic rendering of the wood that you can download.
+ #     """
+ # )
+
+ # # --- 5) LAUNCH THE APPLICATION ---
+ # if __name__ == "__main__":
+ #     # You can set `server_name="0.0.0.0"` if you want the app to be reachable on the network,
+ #     # and `server_port=7860` (or another port) to customise it.
+ #     iface.launch(server_name="0.0.0.0", server_port=7860, share=False)
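Note on the new click handler: the call now passes `pipe` explicitly, matching the new `inference(pipe, img1, img2, num_steps)` signature in inference.py. Also, `gr.Slider.update` is the Gradio 3.x idiom; Gradio 4 removed the per-component `update` methods in favour of the generic `gr.update(...)` (or returning a fresh component). A minimal, version-portable sketch of the same handler under that assumption, reusing the `pipe` and `inference` names defined in this commit:

    import gradio as gr

    def run_and_store(fib, ring, num_steps):
        imgs = inference(pipe, fib, ring, int(num_steps))
        # gr.update(...) is accepted by both Gradio 3.x and 4.x
        return imgs, gr.update(maximum=len(imgs) - 1, value=0), imgs[0]

Returning `gr.update(...)` keeps the handler working if the Space's Gradio dependency is ever bumped.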
inference.py CHANGED
@@ -40,7 +40,7 @@ class UNetNoCondWrapper(nn.Module):
         # delegate to the real UNet2DModel instance
         return self.unet.save_pretrained(save_directory, **kwargs)
 
- def inference(pipe, device, img1, img2, num_steps):
+ def inference(pipe, img1, img2, num_steps):
 
     generator = torch.Generator("cpu").manual_seed(0)
 
@@ -57,14 +57,22 @@ def inference(pipe, device, img1, img2, num_steps):
 
     img1_np[img1_np > 200] = 255
     img1_np[img1_np <= 200] = 0
-     img1_np = 255 - img1_np
+     # img1_np = 255 - img1_np
     img_np = np.stack([img1_np, img2_np, img2_np], axis=2)
 
     image = PIL.Image.fromarray(img_np)
     image = PIL.ImageOps.exif_transpose(image)
+
+     all_images = []
+     def cb_fn(step, timestep, latents):
+         # latents (torch.Tensor) -> image via VAE decode
+         with torch.no_grad():
+             decoded = pipe.vae.decode(latents / pipe.vae.config.scaling_factor).sample
+         # post-process to PIL: rescale [-1, 1] -> [0, 1], then NCHW -> NHWC
+         all_images.append(pipe.numpy_to_pil((decoded / 2 + 0.5).clamp(0, 1).cpu().permute(0, 2, 3, 1).float().numpy())[0])
 
     num_inference_steps = num_steps
-     print(num_inference_steps)
     image_guidance_scale = 1.9
     guidance_scale = 10
 
@@ -76,77 +84,14 @@ def inference(pipe, device, img1, img2, num_steps):
         guidance_scale=guidance_scale,
         generator=generator,
         safety_checker=None,
+         callback=cb_fn,
+         callback_steps=1,
         num_images_per_prompt=1
     ).images
 
-     edited_image = edited_image[0]
-
-     return edited_image
-
-
- # def inference(model_id, device, img1, img2):
-
-
- #     vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae").to(device)
- #     scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")
- #     tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
- #     text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder").to(device)
- #     feature_extractor = CLIPImageProcessor.from_pretrained(model_id, subfolder="feature_extractor")
-
- #     # 2) Load your unconditional UNet and wrap it
- #     base_unet = UNet2DModel.from_pretrained(model_id, subfolder="unet").to(device)
- #     wrapped_unet = UNetNoCondWrapper(base_unet).to(device)
-
- #     # 3) Build the pipeline manually
- #     pipe = StableDiffusionInstructPix2PixPipeline(
- #         vae=vae,
- #         text_encoder=text_encoder,
- #         tokenizer=tokenizer,
- #         unet=wrapped_unet,
- #         scheduler=scheduler,
- #         safety_checker=None,
- #         feature_extractor=feature_extractor,
- #     )
- #     # pipe = pipe.to(torch.float16).to(device)
- #     pipe = pipe.to(torch.float32).to(device)
-
- #     generator = torch.Generator("cpu").manual_seed(0)
-
-
- #     img1 = img1.resize((512, 512))
- #     img2 = img2.resize((512, 512))
-
- #     img1_np = np.array(img1)
- #     if len(img1_np.shape) > 2:
- #         img1_np = img1_np[:, :, 0]
-
- #     img2_np = np.array(img2)
- #     if len(img2_np.shape) > 2:
- #         img2_np = img2_np[:, :, 0]
-
- #     img1_np[img1_np > 200] = 255
- #     img1_np[img1_np <= 200] = 0
- #     img1_np = 255 - img1_np
- #     img_np = np.stack([img1_np, img2_np, img2_np], axis=2)
-
- #     image = PIL.Image.fromarray(img_np)
- #     image = PIL.ImageOps.exif_transpose(image)
-
- #     num_inference_steps = 20
- #     image_guidance_scale = 1.9
- #     guidance_scale = 10
-
- #     edited_image = pipe(
- #         prompt=[""],
- #         image=image,
- #         num_inference_steps=num_inference_steps,
- #         image_guidance_scale=image_guidance_scale,
- #         guidance_scale=guidance_scale,
- #         generator=generator,
- #         safety_checker=None,
- #         num_images_per_prompt=1
- #     ).images
-
- #     edited_image = edited_image[0]
-
- #     return edited_image
+     return all_images
+     # edited_image = edited_image[0]
+
+     # return edited_image
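Note on the step-capture callback: `pipe.vae.decode(...)` returns a `DecoderOutput` whose `.sample` tensor lies in [-1, 1], so the callback above takes `.sample`, rescales to [0, 1], and reorders NCHW -> NHWC before `numpy_to_pil`. Separately, the `callback`/`callback_steps` arguments are diffusers' legacy per-step hook and are deprecated in recent releases in favour of `callback_on_step_end`. A sketch of the same idea against the newer API, assuming a diffusers version that supports `callback_on_step_end` on `StableDiffusionInstructPix2PixPipeline` (`capture_step` is an illustrative name; `pipe`, `image`, and the guidance variables are the ones defined above):

    import torch

    all_images = []

    def capture_step(pipeline, step, timestep, callback_kwargs):
        # tensors listed in callback_on_step_end_tensor_inputs are exposed here
        latents = callback_kwargs["latents"]
        with torch.no_grad():
            decoded = pipeline.vae.decode(latents / pipeline.vae.config.scaling_factor).sample
        decoded = (decoded / 2 + 0.5).clamp(0, 1)
        all_images.append(pipeline.numpy_to_pil(decoded.cpu().permute(0, 2, 3, 1).float().numpy())[0])
        return callback_kwargs  # the hook must return the kwargs dict

    edited_image = pipe(
        prompt=[""],
        image=image,
        num_inference_steps=num_inference_steps,
        image_guidance_scale=image_guidance_scale,
        guidance_scale=guidance_scale,
        generator=generator,
        callback_on_step_end=capture_step,
        callback_on_step_end_tensor_inputs=["latents"],
        num_images_per_prompt=1,
    ).images

Decoding the VAE at every step is the main cost of this feature; raising `callback_steps` (or sampling fewer steps in the hook) trades slider granularity for speed.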