Fix VAE dtype mismatch (fp16)
Browse files
app.py
CHANGED
|
@@ -211,6 +211,9 @@ def save_cropped(imgs, out_path: str):
|
|
| 211 |
imageio.imsave(out_path, out)
|
| 212 |
|
| 213 |
|
|
|
|
|
|
|
|
|
|
| 214 |
@lru_cache(maxsize=1)
|
| 215 |
def get_pipe_and_device() -> Tuple[StableDiffusionXLControlNetImg2ImgPipeline, str, torch.dtype]:
|
| 216 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
@@ -218,18 +221,28 @@ def get_pipe_and_device() -> Tuple[StableDiffusionXLControlNetImg2ImgPipeline, s
|
|
| 218 |
|
| 219 |
print(f"[PIPE] device={device}, dtype={dtype}", flush=True)
|
| 220 |
|
|
|
|
| 221 |
cn_kwargs = dict(torch_dtype=dtype, use_safetensors=True)
|
| 222 |
if dtype == torch.float16:
|
| 223 |
cn_kwargs["variant"] = "fp16"
|
| 224 |
-
|
| 225 |
controlnet = ControlNetModel.from_pretrained(CONTROLNET_ID, **cn_kwargs).to(device)
|
| 226 |
|
| 227 |
-
#
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 231 |
pipe_kwargs = dict(
|
| 232 |
controlnet=controlnet,
|
|
|
|
| 233 |
use_safetensors=True,
|
| 234 |
torch_dtype=dtype,
|
| 235 |
add_watermarker=False,
|
|
@@ -242,15 +255,13 @@ def get_pipe_and_device() -> Tuple[StableDiffusionXLControlNetImg2ImgPipeline, s
|
|
| 242 |
**pipe_kwargs,
|
| 243 |
).to(device)
|
| 244 |
|
| 245 |
-
#
|
| 246 |
if device == "cuda":
|
| 247 |
try:
|
| 248 |
pipe.vae.to(dtype=dtype)
|
| 249 |
-
# Some pipelines keep VAE in fp32 on purpose; your custom pipeline doesn't
|
| 250 |
-
# auto-cast inputs to vae.dtype before encode, so disable upcast if present.
|
| 251 |
if hasattr(pipe.vae, "config") and hasattr(pipe.vae.config, "force_upcast"):
|
| 252 |
pipe.vae.config.force_upcast = False
|
| 253 |
-
print(f"[PIPE] VAE casted to {dtype}. force_upcast
|
| 254 |
except Exception as e:
|
| 255 |
print("[PIPE] VAE dtype cast failed:", repr(e), flush=True)
|
| 256 |
|
|
@@ -264,6 +275,7 @@ def get_pipe_and_device() -> Tuple[StableDiffusionXLControlNetImg2ImgPipeline, s
|
|
| 264 |
return pipe, device, dtype
|
| 265 |
|
| 266 |
|
|
|
|
| 267 |
def run_one(paths: Paths, prompt: str, steps: int = DEFAULT_STEPS):
|
| 268 |
global H, W
|
| 269 |
pipe, device, _dtype = get_pipe_and_device()
|
|
|
|
| 211 |
imageio.imsave(out_path, out)
|
| 212 |
|
| 213 |
|
| 214 |
+
from diffusers import UniPCMultistepScheduler, AutoencoderKL
|
| 215 |
+
# (AutoencoderKL added to the import line above)
|
| 216 |
+
|
| 217 |
@lru_cache(maxsize=1)
|
| 218 |
def get_pipe_and_device() -> Tuple[StableDiffusionXLControlNetImg2ImgPipeline, str, torch.dtype]:
|
| 219 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
| 221 |
|
| 222 |
print(f"[PIPE] device={device}, dtype={dtype}", flush=True)
|
| 223 |
|
| 224 |
+
# ControlNet
|
| 225 |
cn_kwargs = dict(torch_dtype=dtype, use_safetensors=True)
|
| 226 |
if dtype == torch.float16:
|
| 227 |
cn_kwargs["variant"] = "fp16"
|
|
|
|
| 228 |
controlnet = ControlNetModel.from_pretrained(CONTROLNET_ID, **cn_kwargs).to(device)
|
| 229 |
|
| 230 |
+
# ✅ Load the VAE explicitly from safetensors first and inject it (bypasses the code path that looks for .bin weights)
|
| 231 |
+
vae_kwargs = dict(
|
| 232 |
+
torch_dtype=dtype,
|
| 233 |
+
use_safetensors=True,
|
| 234 |
+
subfolder="vae",
|
| 235 |
+
)
|
| 236 |
+
# Only helps when the SDXL base repo provides an fp16 variant (works without it as well)
|
| 237 |
+
if dtype == torch.float16:
|
| 238 |
+
vae_kwargs["variant"] = "fp16"
|
| 239 |
+
|
| 240 |
+
vae = AutoencoderKL.from_pretrained(BASE_MODEL_ID, **vae_kwargs).to(device)
|
| 241 |
+
|
| 242 |
+
# Pipeline
|
| 243 |
pipe_kwargs = dict(
|
| 244 |
controlnet=controlnet,
|
| 245 |
+
vae=vae, # ✅ 주입
|
| 246 |
use_safetensors=True,
|
| 247 |
torch_dtype=dtype,
|
| 248 |
add_watermarker=False,
|
|
|
|
| 255 |
**pipe_kwargs,
|
| 256 |
).to(device)
|
| 257 |
|
| 258 |
+
# ✅ (Prevents the earlier dtype mismatch) Force the VAE dtype to match
|
| 259 |
if device == "cuda":
|
| 260 |
try:
|
| 261 |
pipe.vae.to(dtype=dtype)
|
|
|
|
|
|
|
| 262 |
if hasattr(pipe.vae, "config") and hasattr(pipe.vae.config, "force_upcast"):
|
| 263 |
pipe.vae.config.force_upcast = False
|
| 264 |
+
print(f"[PIPE] VAE casted to {dtype}. force_upcast False (if supported).", flush=True)
|
| 265 |
except Exception as e:
|
| 266 |
print("[PIPE] VAE dtype cast failed:", repr(e), flush=True)
|
| 267 |
|
|
|
|
| 275 |
return pipe, device, dtype
|
| 276 |
|
| 277 |
|
| 278 |
+
|
| 279 |
def run_one(paths: Paths, prompt: str, steps: int = DEFAULT_STEPS):
|
| 280 |
global H, W
|
| 281 |
pipe, device, _dtype = get_pipe_and_device()
|