ssoxye commited on
Commit
f4ac6fe
·
1 Parent(s): fb56917

Fix VAE dtype mismatch (fp16)

Browse files
Files changed (1) hide show
  1. app.py +25 -2
app.py CHANGED
@@ -223,14 +223,37 @@ def get_pipe_and_device() -> Tuple[StableDiffusionXLControlNetImg2ImgPipeline, s
223
  cn_kwargs["variant"] = "fp16"
224
 
225
  controlnet = ControlNetModel.from_pretrained(CONTROLNET_ID, **cn_kwargs).to(device)
226
- pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
227
- BASE_MODEL_ID,
 
 
 
 
228
  controlnet=controlnet,
229
  use_safetensors=True,
230
  torch_dtype=dtype,
231
  add_watermarker=False,
 
 
 
 
 
 
 
232
  ).to(device)
233
 
 
 
 
 
 
 
 
 
 
 
 
 
234
  pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
235
  pipe.enable_attention_slicing()
236
  try:
 
223
  cn_kwargs["variant"] = "fp16"
224
 
225
  controlnet = ControlNetModel.from_pretrained(CONTROLNET_ID, **cn_kwargs).to(device)
226
+
227
+ # ---------------------------------------------------------
228
+ # ✅ Fix 1) Force a consistent dtype for VAE to avoid:
229
+ # RuntimeError: Input type (c10::Half) and bias type (float) should be the same
230
+ # ---------------------------------------------------------
231
+ pipe_kwargs = dict(
232
  controlnet=controlnet,
233
  use_safetensors=True,
234
  torch_dtype=dtype,
235
  add_watermarker=False,
236
+ )
237
+ if dtype == torch.float16:
238
+ pipe_kwargs["variant"] = "fp16"
239
+
240
+ pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
241
+ BASE_MODEL_ID,
242
+ **pipe_kwargs,
243
  ).to(device)
244
 
245
+ # Force VAE params/bias dtype to match the pipeline dtype
246
+ if device == "cuda":
247
+ try:
248
+ pipe.vae.to(dtype=dtype)
249
+ # Some pipelines keep VAE in fp32 on purpose; your custom pipeline doesn't
250
+ # auto-cast inputs to vae.dtype before encode, so disable upcast if present.
251
+ if hasattr(pipe.vae, "config") and hasattr(pipe.vae.config, "force_upcast"):
252
+ pipe.vae.config.force_upcast = False
253
+ print(f"[PIPE] VAE casted to {dtype}. force_upcast set to False (if supported).", flush=True)
254
+ except Exception as e:
255
+ print("[PIPE] VAE dtype cast failed:", repr(e), flush=True)
256
+
257
  pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
258
  pipe.enable_attention_slicing()
259
  try: