# load_model.py
import traceback

import torch
from diffusers import StableDiffusionInpaintPipeline

# Optional ControlNet import: older or partial diffusers installs may not
# provide it, so any import-time failure just disables the ControlNet path.
try:
    from diffusers import ControlNetModel
except Exception:
    ControlNetModel = None


def load_inpaint_model(
    model_id: str = "SG161222/Realistic_Vision_V5.0_noVAE",
    controlnet_id: str | None = None,
    device: str | None = None,
    dtype=None,
):
    """Load an SD 1.5 style inpainting pipeline.

    Args:
        model_id: HF repo id or local path for the converted diffusers model.
        controlnet_id: Optional HF repo id for a ControlNet (compatible SDv1
            ControlNet). When given and loadable, the model is attached as
            ``pipe.controlnet``; the inpaint pipeline does not consume it by
            itself, so callers need custom logic to use it.
        device: Target device such as ``"cuda"``, ``"cuda:0"`` or ``"cpu"``.
            Auto-detected (CUDA if available, else CPU) when ``None``.
        dtype: Torch dtype for the weights. When ``None``, ``torch.float16``
            is used for CUDA devices and ``torch.float32`` otherwise.

    Returns:
        The loaded pipeline, moved to ``device``. Memory-efficient attention
        (xFormers) is enabled when possible, with attention slicing as the
        fallback.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    if dtype is None:
        # Prefix match so indexed devices ("cuda:0", "cuda:1", ...) also get
        # half precision, not only the bare "cuda" string.
        on_cuda = str(device).startswith("cuda")
        dtype = torch.float16 if on_cuda else torch.float32

    print(f"[load_model] Loading inpaint pipeline '{model_id}' -> device={device}, dtype={dtype}")
    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        model_id,
        torch_dtype=dtype,
        safety_checker=None,
    )

    # If user supplied a ControlNet id and the environment has the ControlNetModel class:
    if controlnet_id:
        if ControlNetModel is None:
            print("[load_model] ControlNetModel unavailable in this environment (skipping controlnet attach).")
        else:
            try:
                print(f"[load_model] Loading ControlNet: {controlnet_id}")
                cnet = ControlNetModel.from_pretrained(controlnet_id, torch_dtype=dtype)
                # The ControlNet is attached as a plain attribute, not a
                # registered pipeline component, so `pipe.to(device)` below is
                # not expected to move it; place it on the device explicitly.
                cnet = cnet.to(device)
                # attach as attribute for later use; exact usage depends on pipeline implementation
                pipe.controlnet = cnet
                print("[load_model] ControlNet attached to pipeline as `pipe.controlnet` (you may need custom logic to use it).")
            except Exception as e:
                # Best effort: a broken ControlNet must not prevent the base
                # pipeline from loading.
                print("[load_model] Failed to load/attach ControlNet:", e)
                traceback.print_exc()

    # Move to device and enable memory optimizations
    pipe = pipe.to(device)

    # Try to enable xformers for memory efficient attention (may fail if not installed)
    try:
        pipe.enable_xformers_memory_efficient_attention()
        print("[load_model] xFormers enabled")
    except Exception as e:
        print("[load_model] xFormers not enabled (it may not be installed):", e)
        # Fall back to attention slicing; still best effort, but report the
        # failure instead of swallowing it silently.
        try:
            pipe.enable_attention_slicing()
        except Exception as slicing_err:
            print("[load_model] Attention slicing not enabled:", slicing_err)

    print("[load_model] Pipeline ready")
    return pipe