Fix VAE dtype mismatch (fp16)
Browse files
app.py
CHANGED
|
@@ -223,14 +223,37 @@ def get_pipe_and_device() -> Tuple[StableDiffusionXLControlNetImg2ImgPipeline, s
|
|
| 223 |
cn_kwargs["variant"] = "fp16"
|
| 224 |
|
| 225 |
controlnet = ControlNetModel.from_pretrained(CONTROLNET_ID, **cn_kwargs).to(device)
|
| 226 |
-
|
| 227 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 228 |
controlnet=controlnet,
|
| 229 |
use_safetensors=True,
|
| 230 |
torch_dtype=dtype,
|
| 231 |
add_watermarker=False,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 232 |
).to(device)
|
| 233 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 234 |
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
|
| 235 |
pipe.enable_attention_slicing()
|
| 236 |
try:
|
|
|
|
| 223 |
cn_kwargs["variant"] = "fp16"
|
| 224 |
|
| 225 |
controlnet = ControlNetModel.from_pretrained(CONTROLNET_ID, **cn_kwargs).to(device)
|
| 226 |
+
|
| 227 |
+
# ---------------------------------------------------------
|
| 228 |
+
# ✅ Fix 1) Force a consistent dtype for VAE to avoid:
|
| 229 |
+
# RuntimeError: Input type (c10::Half) and bias type (float) should be the same
|
| 230 |
+
# ---------------------------------------------------------
|
| 231 |
+
pipe_kwargs = dict(
|
| 232 |
controlnet=controlnet,
|
| 233 |
use_safetensors=True,
|
| 234 |
torch_dtype=dtype,
|
| 235 |
add_watermarker=False,
|
| 236 |
+
)
|
| 237 |
+
if dtype == torch.float16:
|
| 238 |
+
pipe_kwargs["variant"] = "fp16"
|
| 239 |
+
|
| 240 |
+
pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
|
| 241 |
+
BASE_MODEL_ID,
|
| 242 |
+
**pipe_kwargs,
|
| 243 |
).to(device)
|
| 244 |
|
| 245 |
+
# Force VAE params/bias dtype to match the pipeline dtype
|
| 246 |
+
if device == "cuda":
|
| 247 |
+
try:
|
| 248 |
+
pipe.vae.to(dtype=dtype)
|
| 249 |
+
# Some pipelines keep VAE in fp32 on purpose; your custom pipeline doesn't
|
| 250 |
+
# auto-cast inputs to vae.dtype before encode, so disable upcast if present.
|
| 251 |
+
if hasattr(pipe.vae, "config") and hasattr(pipe.vae.config, "force_upcast"):
|
| 252 |
+
pipe.vae.config.force_upcast = False
|
| 253 |
+
print(f"[PIPE] VAE casted to {dtype}. force_upcast set to False (if supported).", flush=True)
|
| 254 |
+
except Exception as e:
|
| 255 |
+
print("[PIPE] VAE dtype cast failed:", repr(e), flush=True)
|
| 256 |
+
|
| 257 |
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
|
| 258 |
pipe.enable_attention_slicing()
|
| 259 |
try:
|