# inpaint/load_model.py
# (HF Hub page residue preserved as comments: uploaded by r455-007,
#  commit message "Update load_model.py", rev e6376a4, verified)
# load_model.py
import torch
import traceback
from diffusers import StableDiffusionInpaintPipeline
# Optional ControlNet import (import only when needed)
try:
from diffusers import ControlNetModel
except Exception:
ControlNetModel = None
def load_inpaint_model(
    model_id: str = "SG161222/Realistic_Vision_V5.0_noVAE",
    controlnet_id: str | None = None,
    device: str | None = None,
    dtype: "torch.dtype | None" = None,
):
    """
    Load an SD 1.5 style inpainting pipeline.

    Args:
        model_id: HF repo id or local path for the converted diffusers model.
        controlnet_id: Optional HF repo id for a ControlNet (must be an
            SDv1-compatible ControlNet). Skipped with a message if the
            ``ControlNetModel`` class is unavailable in this environment.
        device: "cuda" or "cpu" (auto-detected if None).
        dtype: Torch dtype for the weights. torch.float16 is recommended for
            CUDA; defaults to float16 on CUDA and float32 on CPU.

    Returns:
        The loaded ``StableDiffusionInpaintPipeline``, moved to ``device``,
        with memory-efficient attention enabled where possible. If a
        ControlNet was loaded it is attached as ``pipe.controlnet``.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if dtype is None:
        dtype = torch.float16 if device == "cuda" else torch.float32

    print(f"[load_model] Loading inpaint pipeline '{model_id}' -> device={device}, dtype={dtype}")
    # safety_checker=None disables the NSFW filter module (saves memory/time).
    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        model_id,
        torch_dtype=dtype,
        safety_checker=None,
    )

    # If user supplied a ControlNet id and the environment has the ControlNetModel class:
    if controlnet_id:
        if ControlNetModel is None:
            print("[load_model] ControlNetModel unavailable in this environment (skipping controlnet attach).")
        else:
            try:
                print(f"[load_model] Loading ControlNet: {controlnet_id}")
                cnet = ControlNetModel.from_pretrained(controlnet_id, torch_dtype=dtype)
                # Attach as attribute for later use; exact usage depends on
                # the pipeline implementation (plain inpaint pipeline does not
                # consume it automatically).
                pipe.controlnet = cnet
                print("[load_model] ControlNet attached to pipeline as `pipe.controlnet` (you may need custom logic to use it).")
            except Exception as e:
                # Best-effort attach: a missing/incompatible ControlNet should
                # not prevent the base pipeline from loading.
                print("[load_model] Failed to load/attach ControlNet:", e)
                traceback.print_exc()

    # Move to device and enable memory optimizations.
    pipe = pipe.to(device)

    # Prefer xFormers memory-efficient attention; fall back to attention
    # slicing ONLY if xFormers is unavailable. (Fix: the original enabled
    # attention slicing unconditionally, which needlessly slows inference
    # when xFormers attention is already active.)
    try:
        pipe.enable_xformers_memory_efficient_attention()
        print("[load_model] xFormers enabled")
    except Exception as e:
        print("[load_model] xFormers not enabled (it may not be installed):", e)
        try:
            pipe.enable_attention_slicing()
            print("[load_model] Attention slicing enabled as fallback")
        except Exception:
            # Slicing is purely an optimization; ignore if unsupported.
            pass

    print("[load_model] Pipeline ready")
    return pipe