| | |
| | import torch |
| | import traceback |
| | from diffusers import StableDiffusionInpaintPipeline |
| |
|
| | |
| | try: |
| | from diffusers import ControlNetModel |
| | except Exception: |
| | ControlNetModel = None |
| |
|
def _attach_controlnet(pipe, controlnet_id, device, dtype):
    """Best-effort: load a ControlNet and expose it as ``pipe.controlnet``.

    Failures are logged and swallowed so that the base inpainting pipeline
    is still usable when the ControlNet cannot be loaded.
    """
    if ControlNetModel is None:
        print("[load_model] ControlNetModel unavailable in this environment (skipping controlnet attach).")
        return

    try:
        print(f"[load_model] Loading ControlNet: {controlnet_id}")
        cnet = ControlNetModel.from_pretrained(controlnet_id, torch_dtype=dtype)
        # `controlnet` is not a registered component of the inpaint pipeline,
        # so a later `pipe.to(device)` would not move it. Place it on the
        # target device explicitly to keep it co-located with the pipeline.
        pipe.controlnet = cnet.to(device)
        print("[load_model] ControlNet attached to pipeline as `pipe.controlnet` (you may need custom logic to use it).")
    except Exception as e:
        # Deliberately broad: the ControlNet is optional and any failure here
        # (download, incompatible weights, out of memory) must not be fatal.
        print("[load_model] Failed to load/attach ControlNet:", e)
        traceback.print_exc()


def _enable_memory_optimizations(pipe):
    """Enable xFormers attention if possible, else fall back to slicing.

    The two are mutually exclusive on purpose: enabling attention slicing
    installs its own attention processor, which would replace the xFormers
    one, and the diffusers documentation warns that combining them can
    cause serious slowdowns.
    """
    try:
        pipe.enable_xformers_memory_efficient_attention()
    except Exception as e:
        print("[load_model] xFormers not enabled (it may not be installed):", e)
    else:
        print("[load_model] xFormers enabled")
        return

    try:
        pipe.enable_attention_slicing()
    except Exception as e:
        # Still non-fatal, but no longer silent.
        print("[load_model] Attention slicing not enabled:", e)
    else:
        print("[load_model] Attention slicing enabled")


def load_inpaint_model(
    model_id: str = "SG161222/Realistic_Vision_V5.0_noVAE",
    controlnet_id: str | None = None,
    device: str | None = None,
    dtype=None,
):
    """
    Load an SD 1.5 style inpainting pipeline.

    - model_id: HF repo id or local path for the converted diffusers model.
    - controlnet_id: optional HF repo id for ControlNet (compatible SDv1
      ControlNet). When given, the model is loaded on a best-effort basis and
      stored as `pipe.controlnet`; it is only attached as an attribute, so
      callers need their own logic to actually use it during inference.
    - device: "cuda" or "cpu" (auto-detected if None)
    - dtype: torch dtype for the weights; defaults to torch.float16 on CUDA
      and torch.float32 otherwise.

    Returns: pipe (the pipeline moved to `device`, safety checker disabled).

    Raises: whatever `StableDiffusionInpaintPipeline.from_pretrained` or
    `pipe.to` raise; ControlNet and memory-optimisation failures are only
    logged.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    if dtype is None:
        # Half precision is only chosen automatically on CUDA.
        dtype = torch.float16 if device == "cuda" else torch.float32

    print(f"[load_model] Loading inpaint pipeline '{model_id}' -> device={device}, dtype={dtype}")
    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        model_id,
        torch_dtype=dtype,
        safety_checker=None,
    )

    if controlnet_id:
        _attach_controlnet(pipe, controlnet_id, device, dtype)

    pipe = pipe.to(device)

    _enable_memory_optimizations(pipe)

    print("[load_model] Pipeline ready")
    return pipe