CarolineM5 committed on
Commit 990a91c · verified · 1 Parent(s): c468816

Upload 2 files

Files changed (2)
  1. app.py +56 -6
  2. inference.py +72 -29
app.py CHANGED
@@ -9,13 +9,66 @@ import gradio as gr
 from PIL import Image
 import torch
 from inference import inference
+import torch.nn as nn
+from diffusers import StableDiffusionInstructPix2PixPipeline, UNet2DModel, AutoencoderKL, DDPMScheduler
+from transformers import CLIPTokenizer, CLIPTextModel, CLIPImageProcessor
 
+class UNetNoCondWrapper(nn.Module):
+    def __init__(self, base_unet: UNet2DModel):
+        super().__init__()
+        self.unet = base_unet
+
+    def forward(
+        self,
+        sample,
+        timestep,
+        encoder_hidden_states=None,
+        added_cond_kwargs=None,
+        cross_attention_kwargs=None,
+        return_dict=False,
+        **kwargs
+    ):
+        # Drop the conditioning arguments the pipeline passes in;
+        # the underlying UNet2DModel is unconditional.
+        return self.unet(sample, timestep, return_dict=return_dict, **kwargs)
+
+    def __getattr__(self, name):
+        if name in ("unet", "forward", "__getstate__", "__setstate__"):
+            return super().__getattr__(name)
+        return getattr(self.unet, name)
+
+    def save_pretrained(self, save_directory, **kwargs):
+        # delegate to the real UNet2DModel instance
+        return self.unet.save_pretrained(save_directory, **kwargs)
+
 # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 device = torch.device('cpu')
 
+model_id = "CarolineM5/InstructPix2Pix_WithoutPrompt"
+vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae").to(device)
+scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")
+tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
+text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder").to(device)
+feature_extractor = CLIPImageProcessor.from_pretrained(model_id, subfolder="feature_extractor")
+
+# 2) Load the unconditional UNet and wrap it
+base_unet = UNet2DModel.from_pretrained(model_id, subfolder="unet").to(device)
+wrapped_unet = UNetNoCondWrapper(base_unet).to(device)
+
+# 3) Build the pipeline manually
+pipe = StableDiffusionInstructPix2PixPipeline(
+    vae=vae,
+    text_encoder=text_encoder,
+    tokenizer=tokenizer,
+    unet=wrapped_unet,
+    scheduler=scheduler,
+    safety_checker=None,
+    feature_extractor=feature_extractor,
+)
+
+pipe = pipe.to(torch.float32).to(device)
+
 # --- 3) GRADIO INTERFACE FUNCTION ---
 def gradio_generate(fibers_map: Image.Image, rings_map: Image.Image) -> Image.Image:
     """
@@ -26,13 +79,10 @@ def gradio_generate(fibers_map: Image.Image, rings_map: Image.Image) -> Image.Im
     fibers_map = fibers_map.convert("RGB")
     rings_map = rings_map.convert("RGB")
 
-    model_id = "CarolineM5/InstructPix2Pix_WithoutPrompt"
-
-    result_img = inference(model_id, device, rings_map, fibers_map)
+    result_img = inference(pipe, device, rings_map, fibers_map)
 
     return result_img
 
-
 # --- 4) GRADIO INTERFACE DEFINITION ---
 iface = gr.Interface(
     fn=gradio_generate,
@@ -44,8 +94,8 @@ iface = gr.Interface(
     title="Photorealistic wood generator",
     description="""
     Upload:
-    1) a fibre orientation mapping image,
-    2) a tree-ring boundary mapping image.
+    1) a fibre orientation map,
+    2) a growth ring map.
 
     The model will return a photo-realistic rendering of the wood that you can download.
     """
inference.py CHANGED
@@ -40,34 +40,9 @@ class UNetNoCondWrapper(nn.Module):
         # delegate to the real UNet2DModel instance
         return self.unet.save_pretrained(save_directory, **kwargs)
 
-def inference(model_id, device, img1, img2):
-
-    vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae").to(device)
-    scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")
-    tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
-    text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder").to(device)
-    feature_extractor = CLIPImageProcessor.from_pretrained(model_id, subfolder="feature_extractor")
-
-    # 2) Load the unconditional UNet and wrap it
-    base_unet = UNet2DModel.from_pretrained(model_id, subfolder="unet").to(device)
-    wrapped_unet = UNetNoCondWrapper(base_unet).to(device)
-
-    # 3) Build the pipeline manually
-    pipe = StableDiffusionInstructPix2PixPipeline(
-        vae=vae,
-        text_encoder=text_encoder,
-        tokenizer=tokenizer,
-        unet=wrapped_unet,
-        scheduler=scheduler,
-        safety_checker=None,
-        feature_extractor=feature_extractor,
-    )
-    # pipe = pipe.to(torch.float16).to(device)
-    pipe = pipe.to(torch.float32).to(device)
-
-    generator = torch.Generator("cpu").manual_seed(0)
-
+def inference(pipe, device, img1, img2):
+
+    generator = torch.Generator("cpu").manual_seed(0)
 
     img1 = img1.resize((512, 512))
     img2 = img2.resize((512, 512))
@@ -88,7 +63,7 @@ def inference(model_id, device, img1, img2):
     image = PIL.Image.fromarray(img_np)
     image = PIL.ImageOps.exif_transpose(image)
 
-    num_inference_steps = 20
+    num_inference_steps = 5
     image_guidance_scale = 1.9
     guidance_scale = 10
 
@@ -106,3 +81,71 @@ def inference(model_id, device, img1, img2):
     edited_image = edited_image[0]
 
     return edited_image
+
+
+# def inference(model_id, device, img1, img2):
+
+#     vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae").to(device)
+#     scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")
+#     tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
+#     text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder").to(device)
+#     feature_extractor = CLIPImageProcessor.from_pretrained(model_id, subfolder="feature_extractor")
+
+#     # 2) Load the unconditional UNet and wrap it
+#     base_unet = UNet2DModel.from_pretrained(model_id, subfolder="unet").to(device)
+#     wrapped_unet = UNetNoCondWrapper(base_unet).to(device)
+
+#     # 3) Build the pipeline manually
+#     pipe = StableDiffusionInstructPix2PixPipeline(
+#         vae=vae,
+#         text_encoder=text_encoder,
+#         tokenizer=tokenizer,
+#         unet=wrapped_unet,
+#         scheduler=scheduler,
+#         safety_checker=None,
+#         feature_extractor=feature_extractor,
+#     )
+#     # pipe = pipe.to(torch.float16).to(device)
+#     pipe = pipe.to(torch.float32).to(device)
+
+#     generator = torch.Generator("cpu").manual_seed(0)
+
+#     img1 = img1.resize((512, 512))
+#     img2 = img2.resize((512, 512))
+
+#     img1_np = np.array(img1)
+#     if len(img1_np.shape) > 2:
+#         img1_np = img1_np[:, :, 0]
+
+#     img2_np = np.array(img2)
+#     if len(img2_np.shape) > 2:
+#         img2_np = img2_np[:, :, 0]
+
+#     img1_np[img1_np > 200] = 255
+#     img1_np[img1_np <= 200] = 0
+#     img1_np = 255 - img1_np
+#     img_np = np.stack([img1_np, img2_np, img2_np], axis=2)
+
+#     image = PIL.Image.fromarray(img_np)
+#     image = PIL.ImageOps.exif_transpose(image)
+
+#     num_inference_steps = 20
+#     image_guidance_scale = 1.9
+#     guidance_scale = 10
+
+#     edited_image = pipe(
+#         prompt=[""],
+#         image=image,
+#         num_inference_steps=num_inference_steps,
+#         image_guidance_scale=image_guidance_scale,
+#         guidance_scale=guidance_scale,
+#         generator=generator,
+#         safety_checker=None,
+#         num_images_per_prompt=1
+#     ).images
+
+#     edited_image = edited_image[0]
+
+#     return edited_image
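
For reference, the preprocessing that inference() performs before invoking the pipeline (preserved in the commented-out body above) packs the two grayscale maps into one RGB conditioning image. A standalone sketch of that packing step, assuming hypothetical input files rings.png and fibers.png:

import numpy as np
import PIL.Image

# Load and resize the two maps to the model's working resolution
img1 = PIL.Image.open("rings.png").resize((512, 512))
img2 = PIL.Image.open("fibers.png").resize((512, 512))

# Keep a single channel from each map
img1_np = np.array(img1)
if img1_np.ndim > 2:
    img1_np = img1_np[:, :, 0]
img2_np = np.array(img2)
if img2_np.ndim > 2:
    img2_np = img2_np[:, :, 0]

# Binarise the first map at 200, then invert it
img1_np[img1_np > 200] = 255
img1_np[img1_np <= 200] = 0
img1_np = 255 - img1_np

# Channel 0 carries the binary map, channels 1-2 carry the other map
cond_image = PIL.Image.fromarray(np.stack([img1_np, img2_np, img2_np], axis=2))

The packed image is then handed to the pipeline as image= with an empty prompt (prompt=[""]), which is why the text encoder's output is irrelevant and the UNet can stay unconditional.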