CarolineM5 committed on
Commit
0600e9e
·
verified ·
1 Parent(s): a28b5ab

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +44 -62
  2. inference.py +40 -56
app.py CHANGED
@@ -45,83 +45,67 @@ class UNetNoCondWrapper(nn.Module):
45
  return getattr(self.unet, name)
46
 
47
  def save_pretrained(self, save_directory, **kwargs):
48
- # délègue à la vraie instance UNet2DModel
49
  return self.unet.save_pretrained(save_directory, **kwargs)
50
 
51
  # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
52
 
53
- # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
54
-
55
- # model_id = "CarolineM5/InstructPix2Pix_WithoutPrompt_4_faces"
56
- # vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae").to(device)
57
- # scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")
58
- # tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
59
- # text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder").to(device)
60
- # feature_extractor = CLIPImageProcessor.from_pretrained(model_id, subfolder="feature_extractor")
61
-
62
- # # 2) Chargez votre UNet non‑conditionné et wrappez‑le
63
- # base_unet = UNet2DModel.from_pretrained(model_id, subfolder="unet").to(device)
64
- # wrapped_unet = UNetNoCondWrapper(base_unet).to(device)
65
-
66
- # # 3) Construisez la pipeline manuellement
67
- # pipe = StableDiffusionInstructPix2PixPipeline(
68
- # vae=vae,
69
- # text_encoder=text_encoder,
70
- # tokenizer=tokenizer,
71
- # unet=wrapped_unet,
72
- # scheduler=scheduler,
73
- # safety_checker=None,
74
- # feature_extractor=feature_extractor,
75
- # )
76
-
77
- # pipe = pipe.to(torch.float32).to(device)
78
  # @spaces.GPU
79
 
80
- # def pil_to_data_uri(img: Image.Image) -> str:
81
- # buf = io.BytesIO()
82
- # img.save(buf, format="PNG")
83
- # b = base64.b64encode(buf.getvalue()).decode("utf-8")
84
- # return f"data:image/png;base64,{b}"
85
-
86
 
87
 
88
  def build_textured_cube(pil_imgs, face_rotations=None):
89
  """
90
- Crée un parallélépipède texturé (OBJ + MTL + textures).
91
- - pil_imgs: liste/tuple de 4 PIL.Image dans l'ordre [front, right, back, left]
92
- - Retour: (chemin_absolu_obj, tmpdir)
93
- Defaults:
94
- default_rots = {"front": 0, "right": 270, "back": 180, "left": 90, "top": 0, "bottom": 0}
95
- face_order = ["top","right","bottom","left","front","back"]
96
- Notes:
97
- - Ecrit les fichiers dans /tmp/gradio si possible (HF Spaces).
98
- - front/back utilisent la taille de pil_imgs[0] ; left/right utilisent leur propre largeur (rectangles).
99
  """
100
  import os
101
  import tempfile
102
  from PIL import Image
103
 
104
- # validation
105
  if not (isinstance(pil_imgs, (list, tuple)) and len(pil_imgs) >= 4):
106
- raise ValueError("build_textured_cube attend une liste/tuple de 4 images PIL (front, right, back, left).")
107
 
108
- # defaults rotation & ordre
109
  default_rots = {"front": 0, "right": 270, "back": 180, "left": 270, "top": 0, "bottom": 0}
110
- # default_rots = {"front": 0, "right": 0, "back": 0, "left": 0, "top": 0, "bottom": 0}
111
  if face_rotations is None:
112
  face_rotations = default_rots
113
  else:
114
  for k, v in default_rots.items():
115
  face_rotations.setdefault(k, v)
116
 
117
- # dossier temporaire (préférer /tmp/gradio sur HF)
118
  base_dir = "/tmp/gradio"
119
  if os.path.isdir(base_dir) and os.access(base_dir, os.W_OK):
120
  tmpdir = tempfile.mkdtemp(prefix="parallelep_", dir=base_dir)
121
  else:
122
  tmpdir = tempfile.mkdtemp(prefix="parallelep_")
123
 
124
- # noms relatifs pour textures (mtl utilisera ces noms)
125
  tex_names = {
126
  "front": "tex_front.png",
127
  "right": "tex_right.png",
@@ -131,34 +115,32 @@ def build_textured_cube(pil_imgs, face_rotations=None):
131
  "bottom": "tex_bottom.png",
132
  }
133
 
134
- # récupérer tailles (on suppose que inference a déjà redimensionné left/right si besoin)
135
  front_w, front_h = pil_imgs[0].size
136
  right_w, right_h = pil_imgs[1].size
137
 
138
  ratio = 45/145
139
  right_w = int(front_w * ratio)
140
 
141
- # définir dimensions physiques du parallélépipède (en "px" puis on normalise)
142
- width_px = float(front_w) # largeur X (front width)
143
- height_px = float(right_w) # hauteur Z
144
- # profondeur Y : on prend la largeur des faces latérales (moyenne left/right)
145
  depth_px = float(front_h)
146
 
147
- # normalisation pour garder des coordonnées de l'ordre de ±0.5
148
  max_dim = max(width_px, depth_px, height_px, 1.0)
149
  scale = 1.0 / max_dim
150
- half_x = (width_px * 0.5) * scale # demi-largeur en X
151
- half_y = (depth_px * 0.5) * scale # demi-profondeur en Y
152
- half_z = (height_px * 0.5) * scale # demi-hauteur en Z
153
 
154
- # mapping attendu pour pil_imgs
155
  mapping_order = ["front", "right", "back", "left"]
156
- # sauvegarder textures (avec rotation demandée) dans tmpdir
157
  for img, face_name in zip(pil_imgs[:4], mapping_order):
158
  im = img.convert("RGB")
159
  angle = face_rotations.get(face_name, 0)
160
  if angle % 360 != 0:
161
- # PIL rotate: angle en degrés, positif = CCW
162
  im = im.rotate(angle, resample=Image.BICUBIC, expand=False)
163
  path = os.path.join(tmpdir, tex_names[face_name])
164
  im.save(path, format="PNG")
@@ -167,7 +149,7 @@ def build_textured_cube(pil_imgs, face_rotations=None):
167
  except Exception:
168
  pass
169
 
170
- # top/bottom textures noires (même taille que front pour cohérence)
171
  black = Image.new("RGB", (front_w, front_h), (0, 0, 0))
172
  for face_name in ("top", "bottom"):
173
  im = black
@@ -181,7 +163,7 @@ def build_textured_cube(pil_imgs, face_rotations=None):
181
  except Exception:
182
  pass
183
 
184
- # --- wri .mtl (références relatives) ---
185
  mtl_path = os.path.join(tmpdir, "parallelep.mtl")
186
  with open(mtl_path, "w", encoding="utf-8") as f:
187
  f.write("# Material file for parallelepiped\n")
@@ -294,7 +276,7 @@ def build_textured_cube(pil_imgs, face_rotations=None):
294
  # -------------------------
295
  def run(fibers: Image.Image, rings: Image.Image, num_steps: int):
296
  try:
297
- outputs = inference(fibers, rings, num_steps)
298
  if not (isinstance(outputs, (list, tuple)) and len(outputs) >= 4):
299
  raise ValueError("user_inference must return a list/tuple of 4 images.")
300
 
 
45
  return getattr(self.unet, name)
46
 
47
  def save_pretrained(self, save_directory, **kwargs):
48
+
49
  return self.unet.save_pretrained(save_directory, **kwargs)
50
 
51
  # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
52
 
53
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
54
+
55
+ model_id = "CarolineM5/InstructPix2Pix_WithoutPrompt_4_faces"
56
+ vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae").to(device)
57
+ scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")
58
+ tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
59
+ text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder").to(device)
60
+ feature_extractor = CLIPImageProcessor.from_pretrained(model_id, subfolder="feature_extractor")
61
+
62
+ # 2) Load the unconditioned UNet and wrap it
63
+ base_unet = UNet2DModel.from_pretrained(model_id, subfolder="unet").to(device)
64
+ wrapped_unet = UNetNoCondWrapper(base_unet).to(device)
65
+
66
+ # 3) Build the pipeline manually
67
+ pipe = StableDiffusionInstructPix2PixPipeline(
68
+ vae=vae,
69
+ text_encoder=text_encoder,
70
+ tokenizer=tokenizer,
71
+ unet=wrapped_unet,
72
+ scheduler=scheduler,
73
+ safety_checker=None,
74
+ feature_extractor=feature_extractor,
75
+ )
76
+
77
+ pipe = pipe.to(torch.float32).to(device)
78
  # @spaces.GPU
79
 
 
 
 
 
 
 
80
 
81
 
82
  def build_textured_cube(pil_imgs, face_rotations=None):
83
  """
84
+ Creates a textured parallelepiped (OBJ + MTL + textures).
 
 
 
 
 
 
 
 
85
  """
86
  import os
87
  import tempfile
88
  from PIL import Image
89
 
90
+
91
  if not (isinstance(pil_imgs, (list, tuple)) and len(pil_imgs) >= 4):
92
+ raise ValueError("build_textured_cube expects a list/tuple of 4 PIL images (front, right, back, left).")
93
 
 
94
  default_rots = {"front": 0, "right": 270, "back": 180, "left": 270, "top": 0, "bottom": 0}
 
95
  if face_rotations is None:
96
  face_rotations = default_rots
97
  else:
98
  for k, v in default_rots.items():
99
  face_rotations.setdefault(k, v)
100
 
101
+
102
  base_dir = "/tmp/gradio"
103
  if os.path.isdir(base_dir) and os.access(base_dir, os.W_OK):
104
  tmpdir = tempfile.mkdtemp(prefix="parallelep_", dir=base_dir)
105
  else:
106
  tmpdir = tempfile.mkdtemp(prefix="parallelep_")
107
 
108
+ # relative names for textures (mtl will use these names)
109
  tex_names = {
110
  "front": "tex_front.png",
111
  "right": "tex_right.png",
 
115
  "bottom": "tex_bottom.png",
116
  }
117
 
 
118
  front_w, front_h = pil_imgs[0].size
119
  right_w, right_h = pil_imgs[1].size
120
 
121
  ratio = 45/145
122
  right_w = int(front_w * ratio)
123
 
124
+ # define the physical dimensions of the parallelepiped (in "px", then normalize)
125
+ width_px = float(front_w)
126
+ height_px = float(right_w)
 
127
  depth_px = float(front_h)
128
 
129
+ # normalization to keep coordinates within ±0.5
130
  max_dim = max(width_px, depth_px, height_px, 1.0)
131
  scale = 1.0 / max_dim
132
+ half_x = (width_px * 0.5) * scale
133
+ half_y = (depth_px * 0.5) * scale
134
+ half_z = (height_px * 0.5) * scale
135
 
136
+
137
  mapping_order = ["front", "right", "back", "left"]
138
+ # save textures in tmpdir
139
  for img, face_name in zip(pil_imgs[:4], mapping_order):
140
  im = img.convert("RGB")
141
  angle = face_rotations.get(face_name, 0)
142
  if angle % 360 != 0:
143
+ # PIL rotate: angle in degrees, positive = CCW
144
  im = im.rotate(angle, resample=Image.BICUBIC, expand=False)
145
  path = os.path.join(tmpdir, tex_names[face_name])
146
  im.save(path, format="PNG")
 
149
  except Exception:
150
  pass
151
 
152
+ # black top/bottom
153
  black = Image.new("RGB", (front_w, front_h), (0, 0, 0))
154
  for face_name in ("top", "bottom"):
155
  im = black
 
163
  except Exception:
164
  pass
165
 
166
+ # --- write .mtl ---
167
  mtl_path = os.path.join(tmpdir, "parallelep.mtl")
168
  with open(mtl_path, "w", encoding="utf-8") as f:
169
  f.write("# Material file for parallelepiped\n")
 
276
  # -------------------------
277
  def run(fibers: Image.Image, rings: Image.Image, num_steps: int):
278
  try:
279
+ outputs = inference(pipe, fibers, rings, num_steps)
280
  if not (isinstance(outputs, (list, tuple)) and len(outputs) >= 4):
281
  raise ValueError("user_inference must return a list/tuple of 4 images.")
282
 
inference.py CHANGED
@@ -21,8 +21,8 @@ def pil_from(x):
21
  if isinstance(x, str):
22
  return PIL.Image.open(x)
23
  return x
24
- def inference(fiber_imgs, ring_imgs, num_steps):
25
- # def inference(pipe, fiber_imgs, ring_imgs, num_steps):
26
  """
27
  fiber_imgs: PIL.Image or paths
28
  ring_imgs: PIL.Image or paths
@@ -31,68 +31,52 @@ def inference(fiber_imgs, ring_imgs, num_steps):
31
  returns: list of 4 PIL.Image (L mode), order [1, 4, 3, 2]
32
  """
33
  # seed + generator
34
- # seed = random.randrange(0, 2**32)
35
- # torch.manual_seed(seed)
36
- # generator = torch.Generator("cpu").manual_seed(seed)
37
 
38
  # sizes
39
  tile = 512
40
  canvas_size = tile * 2
41
- ratio = 45/145
42
- original_width, original_height = (512,512)
43
- new_width = int(original_width * ratio)
44
-
45
-
46
- # # stack channels: [fiber, ring, ring] -> H,W,3
47
- # arr_f = np.array(fiber_imgs).astype(np.uint8)
48
- # arr_r = np.array(ring_imgs).astype(np.uint8)
49
- # arr_in = np.stack([arr_f, arr_r, arr_r], axis=2) # H,W,3
50
- # input_image = PIL.Image.fromarray(arr_in) # PIL RGB
51
-
52
- # # run pipeline (use autocast consistent with device)
53
- # edited_images = []
54
- # if torch.backends.mps.is_available():
55
- # autocast_ctx = nullcontext()
56
- # else:
57
- # autocast_ctx = torch.autocast(torch.device("cuda").type if torch.cuda.is_available() else "cpu")
58
-
59
- # with autocast_ctx:
60
- # out = pipe(
61
- # prompt="", # empty prompt (your model ignores prompt)
62
- # image=input_image,
63
- # num_inference_steps=num_steps,
64
- # image_guidance_scale=1.9,
65
- # guidance_scale=10.0,
66
- # generator=generator,
67
- # safety_checker=None,
68
- # num_images_per_prompt=1,
69
- # )
70
- # # out.images may be a list; take first
71
- # pred = out.images[0]
72
-
73
- # # ensure pred is canvas_size x canvas_size
74
- # if pred.size != (canvas_size, canvas_size):
75
- # pred = pred.resize((canvas_size, canvas_size), PIL.Image.BILINEAR)
76
 
77
  # split into 4 tiles in same order TL, TR, BL, BR
78
- # tl = pred.crop((0, 0, tile, tile))
79
- # tr = pred.crop((tile, 0, canvas_size, tile))
80
- # bl = pred.crop((0, tile, tile, canvas_size))
81
- # br = pred.crop((tile, tile, canvas_size, canvas_size))
82
 
83
- ring_imgs = PIL.Image.fromarray(ring_imgs)
84
- fiber_imgs = PIL.Image.fromarray(fiber_imgs)
85
 
86
- tl = ring_imgs.crop((0, 0, tile, tile))
87
- tr = ring_imgs.crop((tile, 0, canvas_size, tile))
88
- bl = ring_imgs.crop((0, tile, tile, canvas_size))
89
- br = ring_imgs.crop((tile, tile, canvas_size, canvas_size))
90
-
91
- # tr = cv2.resize(np.asarray(tr), (new_width, original_height), interpolation=cv2.INTER_LANCZOS4)
92
- # br = cv2.resize(np.asarray(br), (new_width, original_height), interpolation=cv2.INTER_LANCZOS4)
93
-
94
- # tr = PIL.Image.fromarray(tr)
95
- # br = PIL.Image.fromarray(br)
96
 
97
  # close opened images to free handles
98
  fiber_imgs.close()
 
21
  if isinstance(x, str):
22
  return PIL.Image.open(x)
23
  return x
24
+
25
+ def inference(pipe, fiber_imgs, ring_imgs, num_steps):
26
  """
27
  fiber_imgs: PIL.Image or paths
28
  ring_imgs: PIL.Image or paths
 
31
  returns: list of 4 PIL.Image (L mode), order [1, 4, 3, 2]
32
  """
33
  # seed + generator
34
+ seed = random.randrange(0, 2**32)
35
+ torch.manual_seed(seed)
36
+ generator = torch.Generator("cpu").manual_seed(seed)
37
 
38
  # sizes
39
  tile = 512
40
  canvas_size = tile * 2
41
+
42
+
43
+ # stack channels: [fiber, ring, ring] -> H,W,3
44
+ arr_f = np.array(fiber_imgs).astype(np.uint8)
45
+ arr_r = np.array(ring_imgs).astype(np.uint8)
46
+ arr_in = np.stack([arr_f, arr_r, arr_r], axis=2) # H,W,3
47
+ input_image = PIL.Image.fromarray(arr_in) # PIL RGB
48
+
49
+ # run pipeline (use autocast consistent with device)
50
+ if torch.backends.mps.is_available():
51
+ autocast_ctx = nullcontext()
52
+ else:
53
+ autocast_ctx = torch.autocast(torch.device("cuda").type if torch.cuda.is_available() else "cpu")
54
+
55
+ with autocast_ctx:
56
+ out = pipe(
57
+ prompt="", # empty prompt (your model ignores prompt)
58
+ image=input_image,
59
+ num_inference_steps=num_steps,
60
+ image_guidance_scale=1.9,
61
+ guidance_scale=10.0,
62
+ generator=generator,
63
+ safety_checker=None,
64
+ num_images_per_prompt=1,
65
+ )
66
+ # out.images may be a list; take first
67
+ pred = out.images[0]
68
+
69
+ # ensure pred is canvas_size x canvas_size
70
+ if pred.size != (canvas_size, canvas_size):
71
+ pred = pred.resize((canvas_size, canvas_size), PIL.Image.BILINEAR)
 
 
 
 
72
 
73
  # split into 4 tiles in same order TL, TR, BL, BR
74
+ tl = pred.crop((0, 0, tile, tile))
75
+ tr = pred.crop((tile, 0, canvas_size, tile))
76
+ bl = pred.crop((0, tile, tile, canvas_size))
77
+ br = pred.crop((tile, tile, canvas_size, canvas_size))
78
 
 
 
79
 
 
 
 
 
 
 
 
 
 
 
80
 
81
  # close opened images to free handles
82
  fiber_imgs.close()