ssoxye commited on
Commit
ce9b543
·
1 Parent(s): c132490

Fix VAE dtype mismatch (fp16)

Browse files
Files changed (1) hide show
  1. app.py +13 -1
app.py CHANGED
@@ -209,7 +209,7 @@ def save_cropped(imgs, out_path: str):
209
  out = np.concatenate(cropped, axis=1)
210
  os.makedirs(os.path.dirname(out_path), exist_ok=True)
211
  imageio.imsave(out_path, out)
212
-
213
  from diffusers import UniPCMultistepScheduler, AutoencoderKL, UNet2DConditionModel
214
 
215
  @lru_cache(maxsize=1)
@@ -348,6 +348,18 @@ def run_one(paths: Paths, prompt: str, steps: int = DEFAULT_STEPS):
348
  garment_images=garment_pil,
349
  garment_mask=garment_mask_pil,
350
  )
 
 
 
 
 
 
 
 
 
 
 
 
351
 
352
  style_img = Image.open(paths.style_path).convert("RGB")
353
 
 
209
  out = np.concatenate(cropped, axis=1)
210
  os.makedirs(os.path.dirname(out_path), exist_ok=True)
211
  imageio.imsave(out_path, out)
212
+
213
  from diffusers import UniPCMultistepScheduler, AutoencoderKL, UNet2DConditionModel
214
 
215
  @lru_cache(maxsize=1)
 
348
  garment_images=garment_pil,
349
  garment_mask=garment_mask_pil,
350
  )
351
+
352
+
353
+ if device == "cuda":
354
+ pipe.to(dtype=torch.float16)
355
+ try:
356
+ for _, proc in pipe.unet.attn_processors.items():
357
+ proc.to(dtype=torch.float16)
358
+ except Exception:
359
+ pass
360
+
361
+
362
+
363
 
364
  style_img = Image.open(paths.style_path).convert("RGB")
365