# utils/compose.py
import os

import torch
from diffusers import StableDiffusionPipeline
from safetensors.torch import load_file


def _load_patch_state_dict(weights_path):
    """Load a UNet patch state dict from a ``.safetensors`` or torch checkpoint.

    Args:
        weights_path: str, bytes, or os.PathLike pointing at the patch file.

    Returns:
        The mapping of parameter names to tensors stored in the file.
    """
    # os.fsdecode accepts str / bytes / PathLike and always returns str, so
    # pathlib.Path inputs work and the suffix check below is well-defined.
    path = os.fsdecode(weights_path)
    if path.lower().endswith(".safetensors"):
        return load_file(path)
    # NOTE(security): torch.load unpickles the file and can execute arbitrary
    # code embedded in it. Only load checkpoints from trusted sources; if the
    # installed torch supports it, consider torch.load(..., weights_only=True).
    return torch.load(path, map_location="cpu")


def load_and_patch_sd_pipeline(repo_id, unet_weights_path, dtype=torch.float16, device="cuda"):
    """Load a base SD pipeline and patch its UNet with ESD/UCE weights.

    Args:
        repo_id: Identifier passed to ``StableDiffusionPipeline.from_pretrained``.
        unet_weights_path: Path (str or path-like) to the UNet patch. Files
            ending in ``.safetensors`` (case-insensitive) are read with
            safetensors; anything else is read with ``torch.load`` onto CPU.
        dtype: ``torch_dtype`` used when loading the base pipeline.
        device: Device the pipeline is moved to.

    Returns:
        The pipeline, with the safety checker disabled and the patched
        parameters loaded into ``pipe.unet``.

    Raises:
        RuntimeError: If the patch contains keys that are not in the UNet's
            state dict, or (from ``load_state_dict``) if a patched tensor's
            shape does not match the corresponding UNet parameter.
    """
    # Read the patch first so a missing or corrupt weights file fails before
    # the (potentially large) base pipeline is downloaded and moved to device.
    patch = _load_patch_state_dict(unet_weights_path)

    pipe = StableDiffusionPipeline.from_pretrained(
        repo_id, torch_dtype=dtype, safety_checker=None
    ).to(device)

    unet_state = pipe.unet.state_dict()

    # A patch may only overwrite existing UNet entries; report unknown keys
    # explicitly (e.g. a checkpoint saved with a different key prefix) instead
    # of relying on the generic strict-loading error.
    unknown_keys = sorted(set(patch) - set(unet_state))
    if unknown_keys:
        preview = ", ".join(unknown_keys[:5])
        more = f" (and {len(unknown_keys) - 5} more)" if len(unknown_keys) > 5 else ""
        raise RuntimeError(
            f"UNet patch {os.fsdecode(unet_weights_path)!r} contains "
            f"{len(unknown_keys)} key(s) not present in the UNet: {preview}{more}"
        )

    # Overlay the patch on the full state dict so strict loading still checks
    # that every UNet parameter is provided; load_state_dict copies values into
    # the existing parameters, which already live on `device` with `dtype`.
    unet_state.update(patch)
    pipe.unet.load_state_dict(unet_state, strict=True)
    return pipe