File size: 7,223 Bytes
b6acc0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
"""CRS-Diff modular loading utilities for custom diffusers pipeline."""

import importlib
import json
import sys
from pathlib import Path
from typing import Dict, Optional, Union

import torch
from diffusers import DDIMScheduler

# Absolute directory that contains this file.
_PIPELINE_DIR = Path(__file__).resolve().parent
# Put that directory at the front of sys.path (only once) so that modules
# located next to this file can be imported by their top-level name.
# NOTE(review): presumably this is what makes the `crs_core.*` targets listed
# in _TARGET_MAP importable -- confirm against the repository layout.
if str(_PIPELINE_DIR) not in sys.path:
    sys.path.insert(0, str(_PIPELINE_DIR))

# Names of the split component sub-directories. load_components() requires a
# `<root>/<name>/config.json` for every entry, loads them in this order, and
# uses the same strings as keys when wiring the loaded modules together.
_COMPONENT_NAMES = (
    "unet",
    "vae",
    "text_encoder",
    "local_adapter",
    "global_content_adapter",
    "global_text_adapter",
    "metadata_encoder",
)

# Translation table from the `_target` string stored in a component's
# config.json to the dotted import path that is actually imported.
# load_component() falls back to the raw `_target` value when it is not a key
# here. Every entry currently maps to itself, so the table acts as a hook for
# future renames rather than changing anything today.
_TARGET_MAP = {
    "crs_core.local_adapter.LocalControlUNetModel": "crs_core.local_adapter.LocalControlUNetModel",
    "crs_core.autoencoder.AutoencoderKL": "crs_core.autoencoder.AutoencoderKL",
    "crs_core.text_encoder.FrozenCLIPEmbedder": "crs_core.text_encoder.FrozenCLIPEmbedder",
    "crs_core.local_adapter.LocalAdapter": "crs_core.local_adapter.LocalAdapter",
    "crs_core.global_adapter.GlobalContentAdapter": "crs_core.global_adapter.GlobalContentAdapter",
    "crs_core.global_adapter.GlobalTextAdapter": "crs_core.global_adapter.GlobalTextAdapter",
    "crs_core.metadata_embedding.metadata_embeddings": "crs_core.metadata_embedding.metadata_embeddings",
}


def ensure_model_path(pretrained_model_name_or_path: Union[str, Path]) -> Path:
    """Return an absolute local directory for a model path or Hub repo id.

    If the argument names something that exists on disk it is used directly.
    Otherwise it is passed, as a string, to
    ``huggingface_hub.snapshot_download`` and the returned snapshot location
    is used instead.

    Side effect: the resolved path is inserted at the front of ``sys.path``
    unless it is already listed there.

    Args:
        pretrained_model_name_or_path: Local filesystem path, or an identifier
            accepted by ``snapshot_download``.

    Returns:
        The resolved (absolute) ``Path``.
    """
    local_candidate = Path(pretrained_model_name_or_path)
    if local_candidate.exists():
        resolved = local_candidate.resolve()
    else:
        # Imported lazily so purely local usage never needs huggingface_hub.
        from huggingface_hub import snapshot_download

        snapshot_dir = snapshot_download(str(pretrained_model_name_or_path))
        resolved = Path(snapshot_dir).resolve()

    resolved_text = str(resolved)
    if resolved_text not in sys.path:
        sys.path.insert(0, resolved_text)
    return resolved


def resolve_model_root(candidate: Optional[Union[str, Path]]) -> Optional[Path]:
    """Find the directory holding ``model_index.json`` for ``candidate``.

    The candidate is first turned into a local absolute path with
    ``ensure_model_path`` (which may download it and which also touches
    ``sys.path``). That path is checked first, then at most five of its
    ancestors, nearest first.

    Args:
        candidate: Path or repo identifier; any falsy value short-circuits.

    Returns:
        The first directory that contains ``model_index.json``, or ``None``
        when ``candidate`` is falsy or no such directory is found within the
        search depth.
    """
    if not candidate:
        return None

    index_name = "model_index.json"
    start = ensure_model_path(candidate)
    if (start / index_name).exists():
        return start

    # `start` is absolute, so `parents` runs from the direct parent up to the
    # filesystem root; zip with range(5) caps the climb at five levels.
    for _, ancestor in zip(range(5), start.parents):
        if (ancestor / index_name).exists():
            return ancestor
    return None


def _get_class(target: str):
    """Import and return the object named by a dotted ``module.attr`` path.

    Args:
        target: Dotted path whose last segment is the attribute name and whose
            leading part is the module to import, e.g.
            ``"package.module.ClassName"``.

    Returns:
        The attribute looked up on the imported module.

    Raises:
        ValueError: If ``target`` has no dot, or the module or attribute part
            is empty. The message names the offending value.
        ImportError: If the module part cannot be imported.
        AttributeError: If the module has no attribute with that name.
    """
    module_path, dot, attr_name = target.rpartition(".")
    # Reject malformed targets up front so the error names the bad value
    # instead of surfacing as an opaque tuple-unpacking failure.
    if not (dot and module_path and attr_name):
        raise ValueError(
            f"Invalid target {target!r}: expected a dotted path such as "
            "'package.module.ClassName'"
        )
    module = importlib.import_module(module_path)
    return getattr(module, attr_name)


def load_component(model_root: Path, name: str):
    """Build one split component from ``<model_root>/<name>/``.

    ``config.json`` in that folder must contain a ``_target`` entry naming the
    class to instantiate. It is translated through ``_TARGET_MAP`` when
    present there. All remaining config entries whose key does not start with
    an underscore are passed to the class as keyword arguments.

    If ``diffusion_pytorch_model.safetensors`` exists in the folder, it is
    loaded into the instance with ``strict=True``. If it does not exist, the
    instance keeps whatever state its constructor produced. The instance is
    switched to eval mode before being returned.

    Args:
        model_root: Directory that contains the component sub-folders.
        name: Sub-folder name of the component.

    Returns:
        The instantiated component, in eval mode.

    Raises:
        ValueError: If the config has no (or an empty) ``_target``.
    """
    component_dir = Path(model_root) / name
    with (component_dir / "config.json").open("r", encoding="utf-8") as handle:
        config = json.load(handle)

    raw_target = config.pop("_target", None)
    if not raw_target:
        raise ValueError(f"No _target in {component_dir / 'config.json'}")

    # NOTE(review): the class is imported from a path taken from the config
    # file, so only load component folders from sources you trust.
    component_cls = _get_class(_TARGET_MAP.get(raw_target, raw_target))

    # Underscore-prefixed keys are treated as metadata, not constructor args.
    init_kwargs = {}
    for key, value in config.items():
        if key.startswith("_"):
            continue
        init_kwargs[key] = value
    instance = component_cls(**init_kwargs)

    checkpoint = component_dir / "diffusion_pytorch_model.safetensors"
    if checkpoint.exists():
        from safetensors.torch import load_file

        tensors = load_file(str(checkpoint))
        instance.load_state_dict(tensors, strict=True)

    instance.eval()
    return instance


class CRSModelWrapper(torch.nn.Module):
    """Bundle the split components behind CRSControlNet-style method names.

    The denoiser is exposed as ``self.model.diffusion_model``, the VAE as
    ``first_stage_model`` and the text encoder as ``cond_stage_model``. The
    adapters and the metadata encoder are stored under their own names.
    """

    def __init__(
        self,
        unet,
        vae,
        text_encoder,
        local_adapter,
        global_content_adapter,
        global_text_adapter,
        metadata_encoder,
        channels: int = 4,
    ):
        """Store the components; ``channels`` is kept as a plain attribute."""
        super().__init__()
        # Empty container module so the denoiser is reachable as
        # `self.model.diffusion_model`.
        denoiser_holder = torch.nn.Module()
        denoiser_holder.add_module("diffusion_model", unet)
        self.model = denoiser_holder
        self.first_stage_model = vae
        self.cond_stage_model = text_encoder
        self.local_adapter = local_adapter
        self.global_content_adapter = global_content_adapter
        self.global_text_adapter = global_text_adapter
        self.metadata_emb = metadata_encoder
        # One multiplier per local-adapter output consumed in apply_model.
        # NOTE(review): 13 presumably matches the adapter's output count -- confirm.
        self.local_control_scales = [1.0 for _ in range(13)]
        self.channels = channels

    @torch.no_grad()
    def get_learned_conditioning(self, prompts):
        """Encode ``prompts`` with the text encoder (gradients disabled).

        If the encoder has a ``device`` attribute, it is first overwritten
        with the device string of this wrapper's first parameter.
        """
        encoder = self.cond_stage_model
        if hasattr(encoder, "device"):
            first_param = next(self.parameters())
            encoder.device = str(first_param.device)
        return encoder.encode(prompts)

    def apply_model(self, x_noisy, t, cond, metadata=None, global_strength=1.0, **kwargs):
        """Run the denoiser on ``x_noisy`` with the conditioning in ``cond``.

        Args:
            x_noisy: Noisy latent passed to the adapter and denoiser as ``x``.
            t: Timesteps passed through as ``timesteps``.
            cond: Mapping with ``"c_crossattn"`` (sequence of tensors joined
                along dim 1), optional ``"global_control"`` and
                ``"local_control"`` sequences, and ``"metadata"`` (read only
                when the ``metadata`` argument is ``None``).
            metadata: Overrides ``cond["metadata"]`` when given.
            global_strength: Multiplier applied to the global content tokens.
            **kwargs: Accepted and discarded.

        Returns:
            Whatever ``self.model.diffusion_model`` returns.
        """
        del kwargs
        meta_input = cond["metadata"] if metadata is None else metadata
        context = torch.cat(cond["c_crossattn"], 1)

        global_entries = cond.get("global_control")
        if not (global_entries is None or global_entries[0] is None):
            # NOTE(review): metadata is only passed through the metadata
            # encoder on this branch; without global control it reaches the
            # denoiser unchanged -- confirm this is intended.
            meta_input = self.metadata_emb(meta_input)
            # Only the first half (along dim 1) of the global control is used.
            content_half, _ = global_entries[0].chunk(2, dim=1)
            content_tokens = self.global_content_adapter(content_half)
            context = self.global_text_adapter(context)
            context = torch.cat([context, global_strength * content_tokens], dim=1)

        scaled_local = None
        local_entries = cond.get("local_control")
        if not (local_entries is None or local_entries[0] is None):
            stacked_local = torch.cat(local_entries, 1)
            adapter_outputs = self.local_adapter(
                x=x_noisy,
                timesteps=t,
                context=context,
                local_conditions=stacked_local,
            )
            # zip stops at the shorter sequence, so outputs beyond the number
            # of configured scales are dropped.
            scaled_local = [
                feature * scale
                for feature, scale in zip(adapter_outputs, self.local_control_scales)
            ]

        denoiser = self.model.diffusion_model
        return denoiser(
            x=x_noisy,
            timesteps=t,
            metadata=meta_input,
            context=context,
            local_control=scaled_local,
            meta=True,
        )

    def decode_first_stage(self, z):
        """Decode latent ``z`` with the VAE's ``decode`` method."""
        vae = self.first_stage_model
        return vae.decode(z)


def load_components(model_root: Union[str, Path]) -> Dict[str, object]:
    """Assemble the CRS-Diff model and scheduler from a split export.

    The root is resolved with ``ensure_model_path``. The scheduler is read
    from its ``scheduler`` sub-folder. ``scale_factor`` and ``channels`` come
    from ``model_index.json`` when that file exists, and otherwise default to
    0.18215 and 4. Every name in ``_COMPONENT_NAMES`` is then loaded with
    ``load_component`` and wrapped in a ``CRSModelWrapper``.

    Args:
        model_root: Local path or repo identifier of the export.

    Returns:
        A dict with the keys ``"crs_model"``, ``"scheduler"`` and
        ``"scale_factor"``.

    Raises:
        FileNotFoundError: If any component folder lacks a ``config.json``.
            The message lists the missing component names.
    """
    root = ensure_model_path(model_root)
    scheduler = DDIMScheduler.from_pretrained(root, subfolder="scheduler")

    # Optional overrides; an absent index file leaves the defaults in place.
    index_file = root / "model_index.json"
    index = {}
    if index_file.exists():
        with index_file.open("r", encoding="utf-8") as handle:
            index = json.load(handle)
    scale_factor = float(index.get("scale_factor", 0.18215))
    channels = int(index.get("channels", 4))

    # Check every component up front so the error can list all gaps at once.
    missing = [
        name
        for name in _COMPONENT_NAMES
        if not (root / name / "config.json").exists()
    ]
    if missing:
        raise FileNotFoundError(
            f"CRS-Diff split component export incomplete. Missing: {missing}. "
            "Expected split folders with config.json and weights."
        )

    loaded = {}
    for name in _COMPONENT_NAMES:
        loaded[name] = load_component(root, name)

    crs_model = CRSModelWrapper(
        unet=loaded["unet"],
        vae=loaded["vae"],
        text_encoder=loaded["text_encoder"],
        local_adapter=loaded["local_adapter"],
        global_content_adapter=loaded["global_content_adapter"],
        global_text_adapter=loaded["global_text_adapter"],
        metadata_encoder=loaded["metadata_encoder"],
        channels=channels,
    )

    return {
        "crs_model": crs_model,
        "scheduler": scheduler,
        "scale_factor": scale_factor,
    }