Fix VAE dtype mismatch (fp16)
Browse files
app.py
CHANGED
|
@@ -215,7 +215,7 @@ from diffusers import UniPCMultistepScheduler, AutoencoderKL, UNet2DConditionMod
|
|
| 215 |
@lru_cache(maxsize=1)
|
| 216 |
def get_pipe_and_device() -> Tuple[StableDiffusionXLControlNetImg2ImgPipeline, str, torch.dtype]:
|
| 217 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 218 |
-
dtype = torch.
|
| 219 |
|
| 220 |
print(f"[PIPE] device={device}, dtype={dtype}", flush=True)
|
| 221 |
|
|
@@ -351,10 +351,10 @@ def run_one(paths: Paths, prompt: str, steps: int = DEFAULT_STEPS):
|
|
| 351 |
|
| 352 |
|
| 353 |
if device == "cuda":
|
| 354 |
-
pipe.to(dtype=torch.
|
| 355 |
try:
|
| 356 |
for _, proc in pipe.unet.attn_processors.items():
|
| 357 |
-
proc.to(dtype=torch.
|
| 358 |
except Exception:
|
| 359 |
pass
|
| 360 |
|
|
|
|
| 215 |
@lru_cache(maxsize=1)
|
| 216 |
def get_pipe_and_device() -> Tuple[StableDiffusionXLControlNetImg2ImgPipeline, str, torch.dtype]:
|
| 217 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 218 |
+
dtype = torch.float32 if device == "cuda" else torch.float32
|
| 219 |
|
| 220 |
print(f"[PIPE] device={device}, dtype={dtype}", flush=True)
|
| 221 |
|
|
|
|
| 351 |
|
| 352 |
|
| 353 |
if device == "cuda":
|
| 354 |
+
pipe.to(dtype=torch.float32)
|
| 355 |
try:
|
| 356 |
for _, proc in pipe.unet.attn_processors.items():
|
| 357 |
+
proc.to(dtype=torch.float32)
|
| 358 |
except Exception:
|
| 359 |
pass
|
| 360 |
|