| # utils/compose.py | |
| import torch | |
| from diffusers import StableDiffusionPipeline | |
| from safetensors.torch import load_file | |
def load_and_patch_sd_pipeline(repo_id, unet_weights_path, dtype=torch.float16, device="cuda"):
    """
    Load a base Stable Diffusion pipeline and patch its UNet with ESD/UCE weights.

    The patch may contain any subset of the UNet's state-dict entries; entries
    it omits keep the base model's values.

    Args:
        repo_id: Model identifier or local directory passed to
            ``StableDiffusionPipeline.from_pretrained``.
        unet_weights_path: Path (``str``, ``bytes`` or ``os.PathLike``) to the
            patch file. Files ending in ``.safetensors`` (case-insensitive) are
            read with ``safetensors.torch.load_file``; anything else is read
            with ``torch.load`` onto the CPU.
        dtype: ``torch_dtype`` used when loading the pipeline.
        device: Device the pipeline is moved to.

    Returns:
        The loaded pipeline (safety checker disabled) with its UNet patched
        in place.

    Raises:
        RuntimeError: If the patch contains keys the UNet does not have, or
            (raised by ``load_state_dict``) a tensor whose shape does not match
            the corresponding UNet tensor.
    """
    import os

    # Read the patch first so a bad path or corrupt file fails before the
    # (possibly large) base pipeline is downloaded and moved to the device.
    weights_path = os.fsdecode(unet_weights_path)
    if weights_path.lower().endswith(".safetensors"):
        patch = load_file(weights_path)
    else:
        # NOTE(security): torch.load unpickles the file and can execute
        # arbitrary code; only load checkpoints from trusted sources.
        patch = torch.load(weights_path, map_location="cpu")

    pipe = StableDiffusionPipeline.from_pretrained(
        repo_id, torch_dtype=dtype, safety_checker=None
    ).to(device)

    # Load only the patched entries: strict=False tolerates the keys the patch
    # leaves out (they keep their base values). Keys the UNet does not know are
    # rejected explicitly, so a checkpoint saved for a different model (e.g.
    # with an extra key prefix) fails loudly instead of silently patching
    # nothing. load_state_dict copies (and casts) into the existing parameters,
    # so CPU/fp32 patch tensors are fine for a GPU/fp16 UNet.
    incompatible = pipe.unet.load_state_dict(patch, strict=False)
    if incompatible.unexpected_keys:
        raise RuntimeError(
            "UNet patch contains keys not present in the base UNet: "
            + ", ".join(sorted(incompatible.unexpected_keys))
        )
    return pipe