File size: 766 Bytes
# utils/compose.py
import torch
from diffusers import StableDiffusionPipeline
from safetensors.torch import load_file
def load_and_patch_sd_pipeline(repo_id, unet_weights_path, dtype=torch.float16, device="cuda"):
    """Load a base Stable Diffusion pipeline and patch its UNet with ESD/UCE weights.

    The base pipeline is loaded with its safety checker disabled and moved to
    ``device``. The tensors stored at ``unet_weights_path`` are then merged over
    the UNet's current state dict and loaded back into the UNet.

    Args:
        repo_id: Identifier passed to ``StableDiffusionPipeline.from_pretrained``.
        unet_weights_path: Path (``str`` or ``os.PathLike``) to the patch state
            dict. Files whose name ends in ``.safetensors`` (case-insensitive)
            are read with ``safetensors.torch.load_file``; anything else is read
            with ``torch.load``.
        dtype: Passed to ``from_pretrained`` as ``torch_dtype``.
        device: Device the pipeline is moved to before patching.

    Returns:
        The pipeline with the patched UNet.

    Raises:
        TypeError: If ``unet_weights_path`` is not a ``str`` or path-like object.
        RuntimeError: From ``load_state_dict`` if the patch contains keys the
            UNet does not have, or tensors whose shapes do not match.
    """
    # Local import keeps this block self-contained; the module does not
    # import ``os`` at file level.
    import os

    # Normalise early so pathlib.Path inputs work with the suffix check below
    # and an invalid path type fails before the expensive pipeline load.
    weights_path = os.fspath(unet_weights_path)

    pipe = StableDiffusionPipeline.from_pretrained(
        repo_id, torch_dtype=dtype, safety_checker=None
    ).to(device)

    # Load patch state dict
    if weights_path.lower().endswith(".safetensors"):
        patch = load_file(weights_path)
    else:
        # NOTE(security): torch.load unpickles the file; on PyTorch versions
        # where ``weights_only`` defaults to False this can execute arbitrary
        # code. Only pass checkpoints from trusted sources, or prefer the
        # .safetensors format.
        patch = torch.load(weights_path, map_location="cpu")

    # Merge the patch over the full UNet state dict so a partial patch still
    # yields a complete dict; strict=True then only trips on keys in the patch
    # that the UNet does not recognise.
    merged_state = pipe.unet.state_dict()
    merged_state.update(patch)
    pipe.unet.load_state_dict(merged_state, strict=True)
    return pipe
|