# CRS-Diff / modular_pipeline.py
# BiliSakura's picture
# Add files using upload-large-folder tool
# b6acc0a verified
"""CRS-Diff modular loading utilities for custom diffusers pipeline."""
import importlib
import json
import sys
from pathlib import Path
from typing import Dict, Optional, Union
import torch
from diffusers import DDIMScheduler
# Directory that contains this file.  It is placed at the front of sys.path
# (once) so that imports resolved relative to it succeed -- presumably the
# ``crs_core`` package named in _TARGET_MAP lives here; verify against the repo.
_PIPELINE_DIR = Path(__file__).resolve().parent
if str(_PIPELINE_DIR) not in sys.path:
    sys.path.insert(0, str(_PIPELINE_DIR))

# Sub-folder names of the split export; each is expected to hold a config.json.
_COMPONENT_NAMES = (
    "unet",
    "vae",
    "text_encoder",
    "local_adapter",
    "global_content_adapter",
    "global_text_adapter",
    "metadata_encoder",
)

# Maps the ``_target`` string found in a component config to the dotted path
# that is actually imported.  Every entry currently maps to itself.
_TARGET_MAP = {
    dotted_path: dotted_path
    for dotted_path in (
        "crs_core.local_adapter.LocalControlUNetModel",
        "crs_core.autoencoder.AutoencoderKL",
        "crs_core.text_encoder.FrozenCLIPEmbedder",
        "crs_core.local_adapter.LocalAdapter",
        "crs_core.global_adapter.GlobalContentAdapter",
        "crs_core.global_adapter.GlobalTextAdapter",
        "crs_core.metadata_embedding.metadata_embeddings",
    )
}
def ensure_model_path(pretrained_model_name_or_path: Union[str, Path]) -> Path:
    """Return a resolved local directory for a filesystem path or Hub repo id.

    If the argument names something that exists on disk it is used directly.
    Otherwise it is passed (as a string) to
    ``huggingface_hub.snapshot_download`` and the directory that call returns
    is used.  The resolved directory is inserted at the front of ``sys.path``
    unless it is already listed there.

    Args:
        pretrained_model_name_or_path: Local path, or an identifier accepted
            by ``snapshot_download``.

    Returns:
        The resolved ``Path`` of the local directory.
    """
    requested = Path(pretrained_model_name_or_path)
    if requested.exists():
        resolved = requested.resolve()
    else:
        # Imported lazily: only needed when nothing exists at the given path.
        from huggingface_hub import snapshot_download

        snapshot_dir = snapshot_download(str(pretrained_model_name_or_path))
        resolved = Path(snapshot_dir).resolve()

    path_entry = str(resolved)
    if path_entry not in sys.path:
        sys.path.insert(0, path_entry)
    return resolved
def resolve_model_root(candidate: Optional[Union[str, Path]]) -> Optional[Path]:
    """Find the directory holding ``model_index.json`` at or above *candidate*.

    The candidate is first turned into a local directory with
    ``ensure_model_path``.  That directory is checked first, followed by at
    most five of its ancestors (nearest first).

    Args:
        candidate: Path or identifier to start from; a falsy value
            (``None``, ``""``) short-circuits to ``None``.

    Returns:
        The first directory that contains ``model_index.json``, or ``None``
        when none of the inspected directories does.
    """
    if not candidate:
        return None

    start_dir = ensure_model_path(candidate)

    # Starting directory, then up to five parents; ``parents`` ends at the
    # filesystem root, so the walk stops there on shallow paths.
    search_dirs = [start_dir]
    search_dirs.extend(ancestor for _, ancestor in zip(range(5), start_dir.parents))

    for folder in search_dirs:
        if (folder / "model_index.json").exists():
            return folder
    return None
def _get_class(target: str):
    """Import and return the attribute named by a dotted ``module.Name`` path.

    The text before the last dot is imported as a module with
    ``importlib.import_module``; the text after it is looked up on that
    module with ``getattr``.

    Args:
        target: Dotted path such as ``"package.module.ClassName"``.

    Returns:
        Whatever object the module exposes under the final name.

    Raises:
        ValueError: If *target* has no module part (no dot, or a leading
            dot only), so there is nothing to import.
        ImportError: If the module part cannot be imported.
        AttributeError: If the module has no attribute with the final name.
    """
    module_path, _, attr_name = target.rpartition(".")
    if not module_path:
        # Without this guard the failure is an opaque tuple-unpacking or
        # "Empty module name" error that does not mention the bad target.
        raise ValueError(
            f"Expected a dotted 'module.ClassName' target, got {target!r}"
        )
    module = importlib.import_module(module_path)
    return getattr(module, attr_name)
def load_component(model_root: Path, name: str):
    """Build the component stored under ``<model_root>/<name>/``.

    ``config.json`` in that folder must carry a ``_target`` entry naming the
    class to instantiate.  The target is passed through ``_TARGET_MAP`` and
    imported dynamically; all remaining config keys that do not start with
    an underscore become constructor keyword arguments.  When
    ``diffusion_pytorch_model.safetensors`` is present its tensors are
    loaded with ``strict=True``; when it is absent the freshly constructed
    object is returned without loading any weights.

    Args:
        model_root: Directory containing the per-component sub-folders.
        name: Sub-folder name of the component to load.

    Returns:
        The constructed object, after ``.eval()`` has been called on it.

    Raises:
        ValueError: If the config has no (or an empty) ``_target``.
    """
    component_dir = Path(model_root) / name
    config_file = component_dir / "config.json"
    with config_file.open("r", encoding="utf-8") as handle:
        config = json.load(handle)

    target = config.pop("_target", None)
    if not target:
        raise ValueError(f"No _target in {config_file}")

    # NOTE: the class is selected by a string read from config.json and then
    # imported, so the config effectively decides which code runs.
    component_cls = _get_class(_TARGET_MAP.get(target, target))
    init_kwargs = {
        key: value for key, value in config.items() if not key.startswith("_")
    }
    component = component_cls(**init_kwargs)

    weights_file = component_dir / "diffusion_pytorch_model.safetensors"
    if weights_file.exists():
        from safetensors.torch import load_file

        component.load_state_dict(load_file(str(weights_file)), strict=True)

    component.eval()
    return component
class CRSModelWrapper(torch.nn.Module):
    """Expose the split components through the CRSControlNet-style calls the pipeline uses.

    The UNet is registered as ``self.model.diffusion_model``; the VAE and
    text encoder are stored as ``first_stage_model`` and ``cond_stage_model``.
    The three adapters and the metadata encoder are kept as attributes and
    combined inside ``apply_model``.
    """

    def __init__(
        self,
        unet,
        vae,
        text_encoder,
        local_adapter,
        global_content_adapter,
        global_text_adapter,
        metadata_encoder,
        channels: int = 4,
    ):
        """Register every component as a submodule and set default scales.

        Args:
            unet: Module called as the denoiser in ``apply_model``.
            vae: Object whose ``decode`` is used by ``decode_first_stage``.
            text_encoder: Object whose ``encode`` turns prompts into conditioning.
            local_adapter: Module producing the local control features.
            global_content_adapter: Module applied to the global content tensor.
            global_text_adapter: Module applied to the text context.
            metadata_encoder: Module applied to the metadata input.
            channels: Stored as ``self.channels``; not used inside this class.
        """
        super().__init__()
        # Nest the UNet one level down so it is reachable (and its weights
        # are keyed) as ``model.diffusion_model``.
        holder = torch.nn.Module()
        holder.add_module("diffusion_model", unet)
        self.model = holder
        self.first_stage_model = vae
        self.cond_stage_model = text_encoder
        self.local_adapter = local_adapter
        self.global_content_adapter = global_content_adapter
        self.global_text_adapter = global_text_adapter
        self.metadata_emb = metadata_encoder
        # One multiplier per local-control feature; all neutral by default.
        self.local_control_scales = [1.0] * 13
        self.channels = channels

    @torch.no_grad()
    def get_learned_conditioning(self, prompts):
        """Encode *prompts* with the text encoder, without tracking gradients.

        If the encoder has a ``device`` attribute it is first overwritten
        with the device (as a string) of this wrapper's first parameter.
        """
        encoder = self.cond_stage_model
        if hasattr(encoder, "device"):
            encoder.device = str(next(self.parameters()).device)
        return encoder.encode(prompts)

    def apply_model(self, x_noisy, t, cond, metadata=None, global_strength=1.0, **kwargs):
        """Run the UNet on *x_noisy* with text, global and local conditioning.

        Args:
            x_noisy: Input passed to the UNet (and the local adapter) as ``x``.
            t: Timesteps passed through as ``timesteps``.
            cond: Dict with ``"c_crossattn"`` (list of tensors concatenated on
                dim 1) and optionally ``"metadata"``, ``"global_control"`` and
                ``"local_control"`` (lists whose first entry may be ``None``).
            metadata: Metadata input; falls back to ``cond["metadata"]``.
            global_strength: Multiplier on the adapted global content.
            **kwargs: Accepted for call compatibility and ignored.

        Returns:
            Whatever the UNet returns.
        """
        del kwargs  # extra keyword arguments are deliberately discarded
        if metadata is None:
            metadata = cond["metadata"]
        context = torch.cat(cond["c_crossattn"], 1)

        global_entries = cond.get("global_control")
        if global_entries is not None and global_entries[0] is not None:
            # NOTE(review): the metadata embedding is applied only on this
            # branch; without global control the raw metadata reaches the
            # UNet.  Confirm this is intended.
            metadata = self.metadata_emb(metadata)
            # Only the first half (along dim 1) of the global tensor is used.
            content_half, _ = global_entries[0].chunk(2, dim=1)
            content_tokens = self.global_content_adapter(content_half)
            context = self.global_text_adapter(context)
            context = torch.cat([context, global_strength * content_tokens], dim=1)

        scaled_local = None
        local_entries = cond.get("local_control")
        if local_entries is not None and local_entries[0] is not None:
            adapter_out = self.local_adapter(
                x=x_noisy,
                timesteps=t,
                context=context,
                local_conditions=torch.cat(local_entries, 1),
            )
            # zip stops at the shorter sequence, so surplus features or
            # scales are dropped silently.
            scaled_local = [
                feature * scale
                for feature, scale in zip(adapter_out, self.local_control_scales)
            ]

        return self.model.diffusion_model(
            x=x_noisy,
            timesteps=t,
            metadata=metadata,
            context=context,
            local_control=scaled_local,
            meta=True,
        )

    def decode_first_stage(self, z):
        """Decode latent *z* with the VAE's ``decode`` method."""
        return self.first_stage_model.decode(z)
def load_components(model_root: Union[str, Path]) -> Dict[str, object]:
    """Assemble the wrapped model and scheduler from a split-component export.

    Steps, in order: resolve *model_root* with ``ensure_model_path``; load
    the ``DDIMScheduler`` from the ``scheduler`` sub-folder; read optional
    ``scale_factor`` / ``channels`` overrides from ``model_index.json``;
    verify every name in ``_COMPONENT_NAMES`` has a ``config.json``; load
    each component and wrap them in ``CRSModelWrapper``.

    Args:
        model_root: Local directory or identifier accepted by
            ``ensure_model_path``.

    Returns:
        Dict with keys ``"crs_model"``, ``"scheduler"`` and ``"scale_factor"``.

    Raises:
        FileNotFoundError: If any component folder lacks ``config.json``.
    """
    root = ensure_model_path(model_root)
    scheduler = DDIMScheduler.from_pretrained(root, subfolder="scheduler")

    # Defaults used when model_index.json is absent or omits these keys.
    scale_factor, channels = 0.18215, 4
    index_file = root / "model_index.json"
    if index_file.exists():
        with index_file.open("r", encoding="utf-8") as handle:
            index = json.load(handle)
        scale_factor = float(index.get("scale_factor", scale_factor))
        channels = int(index.get("channels", channels))

    missing = [
        name for name in _COMPONENT_NAMES if not (root / name / "config.json").exists()
    ]
    if missing:
        raise FileNotFoundError(
            f"CRS-Diff split component export incomplete. Missing: {missing}. "
            "Expected split folders with config.json and weights."
        )

    # The component names are exactly the wrapper's keyword parameters, so
    # the loaded mapping can be forwarded as-is.
    modules = {name: load_component(root, name) for name in _COMPONENT_NAMES}
    wrapper = CRSModelWrapper(channels=channels, **modules)
    return {"crs_model": wrapper, "scheduler": scheduler, "scale_factor": scale_factor}