ssoxye commited on
Commit
ed13881
·
1 Parent(s): ce9b543

Fix VAE dtype mismatch (fp16)

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -215,7 +215,7 @@ from diffusers import UniPCMultistepScheduler, AutoencoderKL, UNet2DConditionMod
215
  @lru_cache(maxsize=1)
216
  def get_pipe_and_device() -> Tuple[StableDiffusionXLControlNetImg2ImgPipeline, str, torch.dtype]:
217
  device = "cuda" if torch.cuda.is_available() else "cpu"
218
- dtype = torch.float16 if device == "cuda" else torch.float32
219
 
220
  print(f"[PIPE] device={device}, dtype={dtype}", flush=True)
221
 
@@ -351,10 +351,10 @@ def run_one(paths: Paths, prompt: str, steps: int = DEFAULT_STEPS):
351
 
352
 
353
  if device == "cuda":
354
- pipe.to(dtype=torch.float16)
355
  try:
356
  for _, proc in pipe.unet.attn_processors.items():
357
- proc.to(dtype=torch.float16)
358
  except Exception:
359
  pass
360
 
 
215
  @lru_cache(maxsize=1)
216
  def get_pipe_and_device() -> Tuple[StableDiffusionXLControlNetImg2ImgPipeline, str, torch.dtype]:
217
  device = "cuda" if torch.cuda.is_available() else "cpu"
218
+ dtype = torch.float32 if device == "cuda" else torch.float32
219
 
220
  print(f"[PIPE] device={device}, dtype={dtype}", flush=True)
221
 
 
351
 
352
 
353
  if device == "cuda":
354
+ pipe.to(dtype=torch.float32)
355
  try:
356
  for _, proc in pipe.unet.attn_processors.items():
357
+ proc.to(dtype=torch.float32)
358
  except Exception:
359
  pass
360