File size: 766 Bytes
b8877ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
# utils/compose.py
import os
from collections.abc import Mapping

import torch
from diffusers import StableDiffusionPipeline
from safetensors.torch import load_file

def _load_unet_patch(unet_weights_path):
    """Read a UNet patch state dict from disk onto the CPU.

    Args:
        unet_weights_path: ``str`` or ``os.PathLike`` pointing at the patch.
            Names ending in ``.safetensors`` (any letter case) are read with
            safetensors; every other file is read with ``torch.load``.

    Returns:
        The mapping of parameter names to tensors stored in the file.

    Raises:
        TypeError: if the file does not contain a mapping.
    """
    # os.fspath lets callers pass pathlib.Path objects as well as plain strings.
    path = os.fspath(unet_weights_path)
    if path.lower().endswith(".safetensors"):
        patch = load_file(path)
    else:
        # NOTE(security): on torch versions where weights_only defaults to
        # False, torch.load unpickles arbitrary objects. Only load trusted files.
        patch = torch.load(path, map_location="cpu")

    if not isinstance(patch, Mapping):
        raise TypeError(
            f"expected a state-dict mapping in {path!r}, got {type(patch).__name__}"
        )
    return patch


def load_and_patch_sd_pipeline(repo_id, unet_weights_path, dtype=torch.float16, device="cuda"):
    """
    Load a base SD pipeline and patch its UNet with ESD/UCE weights.

    Args:
        repo_id: Identifier passed to ``StableDiffusionPipeline.from_pretrained``.
        unet_weights_path: ``str`` or ``os.PathLike`` path to the UNet patch.
            ``.safetensors`` files (case-insensitive) are read with safetensors;
            anything else is read with ``torch.load`` onto the CPU.
        dtype: torch dtype used to load the pipeline weights.
        device: Device the pipeline is moved to.

    Returns:
        The pipeline, with every UNet entry named in the patch overwritten and
        all other UNet weights left as loaded from ``repo_id``.

    Raises:
        TypeError: if the patch file does not contain a mapping.
        RuntimeError: from ``load_state_dict`` if the patch contains keys the
            UNet does not define or tensors whose shapes do not match.
    """
    # Read the patch first so a bad path or file fails before the base
    # pipeline is built and moved to the device.
    patch = _load_unet_patch(unet_weights_path)

    pipe = StableDiffusionPipeline.from_pretrained(
        repo_id, torch_dtype=dtype, safety_checker=None
    ).to(device)

    # Start from the full current UNet state, so keys absent from the patch
    # keep their base weights. strict=True then rejects any patch key that the
    # UNet does not have.
    state = pipe.unet.state_dict()
    state.update(patch)
    pipe.unet.load_state_dict(state, strict=True)
    return pipe