ssoxye committed on
Commit
b5130ce
·
1 Parent(s): f4ac6fe

Fix VAE dtype mismatch (fp16)

Browse files
Files changed (1) hide show
  1. app.py +21 -9
app.py CHANGED
@@ -211,6 +211,9 @@ def save_cropped(imgs, out_path: str):
211
  imageio.imsave(out_path, out)
212
 
213
 
 
 
 
214
  @lru_cache(maxsize=1)
215
  def get_pipe_and_device() -> Tuple[StableDiffusionXLControlNetImg2ImgPipeline, str, torch.dtype]:
216
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -218,18 +221,28 @@ def get_pipe_and_device() -> Tuple[StableDiffusionXLControlNetImg2ImgPipeline, s
218
 
219
  print(f"[PIPE] device={device}, dtype={dtype}", flush=True)
220
 
 
221
  cn_kwargs = dict(torch_dtype=dtype, use_safetensors=True)
222
  if dtype == torch.float16:
223
  cn_kwargs["variant"] = "fp16"
224
-
225
  controlnet = ControlNetModel.from_pretrained(CONTROLNET_ID, **cn_kwargs).to(device)
226
 
227
- # ---------------------------------------------------------
228
- # Fix 1) Force a consistent dtype for VAE to avoid:
229
- # RuntimeError: Input type (c10::Half) and bias type (float) should be the same
230
- # ---------------------------------------------------------
 
 
 
 
 
 
 
 
 
231
  pipe_kwargs = dict(
232
  controlnet=controlnet,
 
233
  use_safetensors=True,
234
  torch_dtype=dtype,
235
  add_watermarker=False,
@@ -242,15 +255,13 @@ def get_pipe_and_device() -> Tuple[StableDiffusionXLControlNetImg2ImgPipeline, s
242
  **pipe_kwargs,
243
  ).to(device)
244
 
245
- # Force VAE params/bias dtype to match the pipeline dtype
246
  if device == "cuda":
247
  try:
248
  pipe.vae.to(dtype=dtype)
249
- # Some pipelines keep VAE in fp32 on purpose; your custom pipeline doesn't
250
- # auto-cast inputs to vae.dtype before encode, so disable upcast if present.
251
  if hasattr(pipe.vae, "config") and hasattr(pipe.vae.config, "force_upcast"):
252
  pipe.vae.config.force_upcast = False
253
- print(f"[PIPE] VAE casted to {dtype}. force_upcast set to False (if supported).", flush=True)
254
  except Exception as e:
255
  print("[PIPE] VAE dtype cast failed:", repr(e), flush=True)
256
 
@@ -264,6 +275,7 @@ def get_pipe_and_device() -> Tuple[StableDiffusionXLControlNetImg2ImgPipeline, s
264
  return pipe, device, dtype
265
 
266
 
 
267
  def run_one(paths: Paths, prompt: str, steps: int = DEFAULT_STEPS):
268
  global H, W
269
  pipe, device, _dtype = get_pipe_and_device()
 
211
  imageio.imsave(out_path, out)
212
 
213
 
214
+ from diffusers import UniPCMultistepScheduler, AutoencoderKL
215
+ # (위 import 라인에 AutoencoderKL 추가)
216
+
217
  @lru_cache(maxsize=1)
218
  def get_pipe_and_device() -> Tuple[StableDiffusionXLControlNetImg2ImgPipeline, str, torch.dtype]:
219
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
221
 
222
  print(f"[PIPE] device={device}, dtype={dtype}", flush=True)
223
 
224
+ # ControlNet
225
  cn_kwargs = dict(torch_dtype=dtype, use_safetensors=True)
226
  if dtype == torch.float16:
227
  cn_kwargs["variant"] = "fp16"
 
228
  controlnet = ControlNetModel.from_pretrained(CONTROLNET_ID, **cn_kwargs).to(device)
229
 
230
+ # ✅ VAE를 먼저 safetensors로 강제 로드해서 주입 (bin 찾는 경로 우회)
231
+ vae_kwargs = dict(
232
+ torch_dtype=dtype,
233
+ use_safetensors=True,
234
+ subfolder="vae",
235
+ )
236
+ # SDXL base가 fp16 variant를 제공하는 경우에만 도움이 됨 (없어도 동작)
237
+ if dtype == torch.float16:
238
+ vae_kwargs["variant"] = "fp16"
239
+
240
+ vae = AutoencoderKL.from_pretrained(BASE_MODEL_ID, **vae_kwargs).to(device)
241
+
242
+ # Pipeline
243
  pipe_kwargs = dict(
244
  controlnet=controlnet,
245
+ vae=vae, # ✅ 주입
246
  use_safetensors=True,
247
  torch_dtype=dtype,
248
  add_watermarker=False,
 
255
  **pipe_kwargs,
256
  ).to(device)
257
 
258
+ # (이전 dtype mismatch 방지) VAE dtype 강제 일치
259
  if device == "cuda":
260
  try:
261
  pipe.vae.to(dtype=dtype)
 
 
262
  if hasattr(pipe.vae, "config") and hasattr(pipe.vae.config, "force_upcast"):
263
  pipe.vae.config.force_upcast = False
264
+ print(f"[PIPE] VAE casted to {dtype}. force_upcast False (if supported).", flush=True)
265
  except Exception as e:
266
  print("[PIPE] VAE dtype cast failed:", repr(e), flush=True)
267
 
 
275
  return pipe, device, dtype
276
 
277
 
278
+
279
  def run_one(paths: Paths, prompt: str, steps: int = DEFAULT_STEPS):
280
  global H, W
281
  pipe, device, _dtype = get_pipe_and_device()