| | """CRS-Diff modular loading utilities for custom diffusers pipeline.""" |
| |
|
| | import importlib |
| | import json |
| | import sys |
| | from pathlib import Path |
| | from typing import Dict, Optional, Union |
| |
|
| | import torch |
| | from diffusers import DDIMScheduler |
| |
|
# Directory containing this file; prepended to sys.path so the sibling
# "crs_core" package shipped next to the pipeline is importable without
# an install step.
_PIPELINE_DIR = Path(__file__).resolve().parent
if str(_PIPELINE_DIR) not in sys.path:
    sys.path.insert(0, str(_PIPELINE_DIR))


# Sub-folder names of the split component export; each must contain a
# config.json (and optionally safetensors weights) under the model root.
_COMPONENT_NAMES = (
    "unet",
    "vae",
    "text_encoder",
    "local_adapter",
    "global_content_adapter",
    "global_text_adapter",
    "metadata_encoder",
)


# Remaps "_target" strings found in component configs to importable dotted
# class paths. NOTE(review): every entry currently maps to itself, so this
# table is a no-op — presumably kept as a hook for renamed/relocated
# modules in future exports; confirm before removing.
_TARGET_MAP = {
    "crs_core.local_adapter.LocalControlUNetModel": "crs_core.local_adapter.LocalControlUNetModel",
    "crs_core.autoencoder.AutoencoderKL": "crs_core.autoencoder.AutoencoderKL",
    "crs_core.text_encoder.FrozenCLIPEmbedder": "crs_core.text_encoder.FrozenCLIPEmbedder",
    "crs_core.local_adapter.LocalAdapter": "crs_core.local_adapter.LocalAdapter",
    "crs_core.global_adapter.GlobalContentAdapter": "crs_core.global_adapter.GlobalContentAdapter",
    "crs_core.global_adapter.GlobalTextAdapter": "crs_core.global_adapter.GlobalTextAdapter",
    "crs_core.metadata_embedding.metadata_embeddings": "crs_core.metadata_embedding.metadata_embeddings",
}
| |
|
| |
|
def ensure_model_path(pretrained_model_name_or_path: Union[str, Path]) -> Path:
    """Return a local directory for the model, downloading a snapshot if needed.

    If the argument is not an existing path it is treated as a Hugging Face
    repo id and fetched with ``snapshot_download``. The resolved directory is
    prepended to ``sys.path`` so any Python modules shipped alongside the
    weights become importable.
    """
    candidate = Path(pretrained_model_name_or_path)
    if not candidate.exists():
        # Lazy import: huggingface_hub is only needed for remote repos.
        from huggingface_hub import snapshot_download

        candidate = Path(snapshot_download(str(pretrained_model_name_or_path)))
    resolved = candidate.resolve()
    resolved_str = str(resolved)
    if resolved_str not in sys.path:
        sys.path.insert(0, resolved_str)
    return resolved
| |
|
| |
|
def resolve_model_root(candidate: Optional[Union[str, Path]]) -> Optional[Path]:
    """Locate the folder containing ``model_index.json`` for *candidate*.

    Resolves (or downloads) the candidate path, then checks it and up to five
    ancestor directories for a ``model_index.json``. Returns ``None`` when the
    candidate is falsy or no index file is found.
    """
    if not candidate:
        return None
    current = ensure_model_path(candidate)
    if (current / "model_index.json").exists():
        return current
    # Walk upward a bounded number of levels before giving up.
    for _ in range(5):
        parent = current.parent
        if parent == current:
            # Reached the filesystem root; nothing further to inspect.
            break
        if (parent / "model_index.json").exists():
            return parent
        current = parent
    return None
| |
|
| |
|
| | def _get_class(target: str): |
| | module_path, cls_name = target.rsplit(".", 1) |
| | mod = importlib.import_module(module_path) |
| | return getattr(mod, cls_name) |
| |
|
| |
|
def load_component(model_root: Path, name: str):
    """Instantiate one split component from ``<model_root>/<name>/``.

    Reads ``config.json`` for the ``_target`` class path and constructor
    kwargs, builds the module, loads safetensors weights when present, and
    returns the module in eval mode.

    Raises:
        ValueError: if the config has no ``_target`` entry.
    """
    component_dir = Path(model_root) / name
    config_file = component_dir / "config.json"
    cfg = json.loads(config_file.read_text(encoding="utf-8"))

    target = cfg.pop("_target", None)
    if not target:
        raise ValueError(f"No _target in {component_dir / 'config.json'}")
    cls_ref = _get_class(_TARGET_MAP.get(target, target))

    # Keys beginning with "_" are export metadata, not constructor arguments.
    init_kwargs = {key: value for key, value in cfg.items() if not key.startswith("_")}
    module = cls_ref(**init_kwargs)

    weights = component_dir / "diffusion_pytorch_model.safetensors"
    if weights.exists():
        # Lazy import: safetensors only needed when weights are shipped.
        from safetensors.torch import load_file

        module.load_state_dict(load_file(str(weights)), strict=True)
    module.eval()
    return module
| |
|
| |
|
class CRSModelWrapper(torch.nn.Module):
    """Wrap split components to mimic CRSControlNet APIs used by pipeline.

    Exposes the attribute layout the original CRSControlNet code base uses
    (``model.diffusion_model``, ``first_stage_model``, ``cond_stage_model``)
    so the custom pipeline can drive the split-component export unchanged.
    """

    def __init__(
        self,
        unet,
        vae,
        text_encoder,
        local_adapter,
        global_content_adapter,
        global_text_adapter,
        metadata_encoder,
        channels: int = 4,
    ):
        super().__init__()
        # Nest the UNet under model.diffusion_model to mirror the original
        # CRSControlNet attribute path expected by the pipeline.
        self.model = torch.nn.Module()
        self.model.add_module("diffusion_model", unet)
        self.first_stage_model = vae          # VAE used by decode_first_stage
        self.cond_stage_model = text_encoder  # text encoder with .encode()
        self.local_adapter = local_adapter
        self.global_content_adapter = global_content_adapter
        self.global_text_adapter = global_text_adapter
        self.metadata_emb = metadata_encoder
        # One multiplier per local-control feature map (13 levels).
        self.local_control_scales = [1.0] * 13
        self.channels = channels

    @torch.no_grad()
    def get_learned_conditioning(self, prompts):
        """Encode text prompts via the wrapped text encoder (no grad).

        Keeps the encoder's ``device`` attribute (a string, when present) in
        sync with this wrapper's current device before encoding.
        """
        if hasattr(self.cond_stage_model, "device"):
            self.cond_stage_model.device = str(next(self.parameters()).device)
        return self.cond_stage_model.encode(prompts)

    def apply_model(self, x_noisy, t, cond, metadata=None, global_strength=1.0, **kwargs):
        """Run one denoising step with optional global/local control.

        Args:
            x_noisy: noisy latent input to the UNet.
            t: diffusion timestep(s).
            cond: dict with "c_crossattn" (list of text embeddings) and
                optionally "metadata", "global_control", "local_control"
                entries (the control entries are lists; element 0 may be None).
            metadata: overrides cond["metadata"] when given.
            global_strength: multiplier applied to the global control tokens.

        Returns:
            The UNet ("diffusion_model") output.
        """
        del kwargs  # accepted for call-site compatibility, intentionally unused
        if metadata is None:
            metadata = cond["metadata"]
        cond_txt = torch.cat(cond["c_crossattn"], 1)

        if cond.get("global_control") is not None and cond["global_control"][0] is not None:
            # NOTE(review): metadata is passed through metadata_emb only when
            # global control is active — confirm this asymmetry is intended
            # (without global control, raw metadata reaches the UNet below).
            metadata = self.metadata_emb(metadata)
            # Global control tensor packs (content, other) halves along dim 1;
            # only the content half feeds the content adapter.
            content_t, _ = cond["global_control"][0].chunk(2, dim=1)
            global_control = self.global_content_adapter(content_t)
            # Project text embeddings, then append scaled global tokens.
            cond_txt = self.global_text_adapter(cond_txt)
            cond_txt = torch.cat([cond_txt, global_strength * global_control], dim=1)

        local_control = None
        if cond.get("local_control") is not None and cond["local_control"][0] is not None:
            local_control = torch.cat(cond["local_control"], 1)
            local_control = self.local_adapter(
                x=x_noisy, timesteps=t, context=cond_txt, local_conditions=local_control
            )
            # Per-level scaling of the adapter's feature maps.
            local_control = [c * s for c, s in zip(local_control, self.local_control_scales)]

        return self.model.diffusion_model(
            x=x_noisy,
            timesteps=t,
            metadata=metadata,
            context=cond_txt,
            local_control=local_control,
            meta=True,
        )

    def decode_first_stage(self, z):
        """Decode latents *z* to image space via the wrapped VAE."""
        return self.first_stage_model.decode(z)
| |
|
| |
|
def load_components(model_root: Union[str, Path]) -> Dict[str, object]:
    """Assemble pipeline pieces from a split-component export directory.

    Loads the DDIM scheduler, reads optional overrides from
    ``model_index.json``, instantiates every component listed in
    ``_COMPONENT_NAMES``, and wraps them in a :class:`CRSModelWrapper`.

    Returns:
        Dict with keys "crs_model", "scheduler", and "scale_factor".

    Raises:
        FileNotFoundError: if any split component folder lacks a config.json.
    """
    root = ensure_model_path(model_root)
    scheduler = DDIMScheduler.from_pretrained(root, subfolder="scheduler")

    # Defaults used when model_index.json is absent or omits these keys.
    scale_factor = 0.18215
    channels = 4
    index_file = root / "model_index.json"
    if index_file.exists():
        with index_file.open("r", encoding="utf-8") as f:
            idx = json.load(f)
        scale_factor = float(idx.get("scale_factor", scale_factor))
        channels = int(idx.get("channels", channels))

    missing = [name for name in _COMPONENT_NAMES if not (root / name / "config.json").exists()]
    if missing:
        raise FileNotFoundError(
            f"CRS-Diff split component export incomplete. Missing: {missing}. "
            "Expected split folders with config.json and weights."
        )

    # _COMPONENT_NAMES matches CRSModelWrapper's keyword parameters exactly.
    loaded = {name: load_component(root, name) for name in _COMPONENT_NAMES}
    crs_model = CRSModelWrapper(channels=channels, **loaded)

    return {"crs_model": crs_model, "scheduler": scheduler, "scale_factor": scale_factor}
| |
|