| """ |
| HSIGene modular components: path setup and component loading. |
| |
| AeroGen-style: ensure_ldm_path adds model dir to sys.path so hsigene can be imported. |
| No manual sys.path.insert needed when using DiffusionPipeline.from_pretrained(path). |
| """ |
|
|
| import importlib |
| import json |
| import sys |
| from pathlib import Path |
| from typing import Union |
|
|
| from diffusers import DDIMScheduler |
|
|
| |
# Make this pipeline's own directory importable so that the per-component
# packages living next to this file (unet/, vae/, ...) can be resolved by
# importlib when _get_class loads a "_target" such as "unet.model.HSIGeneUNet".
_pipeline_dir = Path(__file__).resolve().parent
if str(_pipeline_dir) not in sys.path:
    sys.path.insert(0, str(_pipeline_dir))
|
|
# Component sub-directories expected under the model repo root. Each one is
# assumed to contain a config.json (with a "_target" class path) and,
# optionally, a diffusion_pytorch_model.{safetensors,bin} weight file.
_COMPONENT_NAMES = (
    "unet", "vae", "text_encoder", "local_adapter",
    "global_content_adapter", "global_text_adapter", "metadata_encoder",
)
|
|
# Remaps legacy "_target" class paths (as written by older HSIGene / ldm
# checkpoint layouts) onto the per-component module layout that ships with
# this pipeline. Targets not present here are used verbatim by load_component.
_TARGET_MAP = {
    "hsigene_models.HSIGeneUNet": "unet.model.HSIGeneUNet",
    "hsigene.HSIGeneUNet": "unet.model.HSIGeneUNet",
    "hsigene_models.HSIGeneAutoencoderKL": "vae.model.HSIGeneAutoencoderKL",
    "hsigene.HSIGeneAutoencoderKL": "vae.model.HSIGeneAutoencoderKL",
    "ldm.modules.encoders.modules.FrozenCLIPEmbedder": "text_encoder.model.CLIPTextEncoder",
    "hsigene.CLIPTextEncoder": "text_encoder.model.CLIPTextEncoder",
    "models.local_adapter.LocalAdapter": "local_adapter.model.LocalAdapter",
    "hsigene.LocalAdapter": "local_adapter.model.LocalAdapter",
    "models.global_adapter.GlobalContentAdapter": "global_content_adapter.model.GlobalContentAdapter",
    "hsigene.GlobalContentAdapter": "global_content_adapter.model.GlobalContentAdapter",
    "models.global_adapter.GlobalTextAdapter": "global_text_adapter.model.GlobalTextAdapter",
    "hsigene.GlobalTextAdapter": "global_text_adapter.model.GlobalTextAdapter",
    "models.metadata_embedding.metadata_embeddings": "metadata_encoder.model.metadata_embeddings",
    "hsigene.metadata_embeddings": "metadata_encoder.model.metadata_embeddings",
}
|
|
|
|
def ensure_ldm_path(pretrained_model_name_or_path: Union[str, Path]) -> Path:
    """Prepend the model repo to ``sys.path`` so its modules can be imported.

    Accepts either a local directory or a Hub repo id; a non-existent local
    path is treated as a repo id and fetched with ``snapshot_download``.
    Returns the resolved local path of the repo.
    """
    repo = Path(pretrained_model_name_or_path)
    if not repo.exists():
        # Not a local directory -- assume it is a Hub repo id and download it.
        from huggingface_hub import snapshot_download
        repo = Path(snapshot_download(pretrained_model_name_or_path))
    repo = repo.resolve()
    entry = str(repo)
    if entry not in sys.path:
        sys.path.insert(0, entry)
    return repo
|
|
|
|
| def _get_class(target: str): |
| module_path, cls_name = target.rsplit(".", 1) |
| mod = importlib.import_module(module_path) |
| return getattr(mod, cls_name) |
|
|
|
|
def load_component(model_path: Path, name: str):
    """Load a single pipeline component (unet, vae, text_encoder, ...).

    ``model_path`` may point at the repo root or directly at a component
    sub-directory; in the latter case ``name`` is ignored. The component
    class is resolved from the config's ``_target`` entry (remapped through
    ``_TARGET_MAP``), instantiated from the remaining config keys, and its
    weights are loaded when a weight file is present. Returns the component
    in eval mode.

    Raises:
        ValueError: if the component's config.json has no ``_target`` key.
    """
    import torch

    path = Path(model_path)
    # Decide once whether `path` is itself a component directory (it names a
    # known component and carries a config.json). The original code evaluated
    # this same filesystem test twice with the operands swapped, which was
    # redundant and easy to desynchronize.
    is_component_dir = path.name in _COMPONENT_NAMES and (path / "config.json").exists()
    root = path.parent if is_component_dir else path
    ensure_ldm_path(root)  # make the repo's modules importable for _get_class
    comp_path = path if is_component_dir else path / name

    with open(comp_path / "config.json") as f:
        cfg = json.load(f)
    target = cfg.pop("_target", None)
    if not target:
        raise ValueError(f"No _target in {comp_path / 'config.json'}")
    target = _TARGET_MAP.get(target, target)
    cls_ref = _get_class(target)

    # Keys starting with "_" are metadata, not constructor kwargs.
    params = {k: v for k, v in cfg.items() if not k.startswith("_")}
    comp = cls_ref(**params)

    # Prefer safetensors over the pickle checkpoint. If neither file exists,
    # the component deliberately keeps its freshly initialized weights.
    for wfile in ("diffusion_pytorch_model.safetensors", "diffusion_pytorch_model.bin"):
        wp = comp_path / wfile
        if not wp.exists():
            continue
        if wfile.endswith(".safetensors"):
            from safetensors.torch import load_file
            state = load_file(str(wp))
        else:
            try:
                # weights_only= is not accepted by older torch versions.
                state = torch.load(wp, map_location="cpu", weights_only=True)
            except TypeError:
                state = torch.load(wp, map_location="cpu")
        comp.load_state_dict(state, strict=True)
        break
    comp.eval()
    return comp
|
|
|
|
def load_components(model_path: Union[str, Path]) -> dict:
    """Load every pipeline component plus scheduler and VAE scale factor.

    Returns a dict keyed by the names in ``_COMPONENT_NAMES``, with two
    extra entries: ``"scheduler"`` (a DDIMScheduler) and ``"scale_factor"``
    (float, defaulting to 0.18215 unless overridden by model_index.json).
    """
    root = Path(ensure_ldm_path(model_path))
    # If we were handed a component sub-directory, step up to the repo root.
    if root.name in _COMPONENT_NAMES and (root / "config.json").exists():
        root = root.parent

    scheduler = DDIMScheduler.from_pretrained(root / "scheduler")
    result = {name: load_component(root, name) for name in _COMPONENT_NAMES}

    # Standard SD latent scale; model_index.json may override it.
    scale_factor = 0.18215
    index_path = root / "model_index.json"
    if index_path.exists():
        with open(index_path) as f:
            scale_factor = json.load(f).get("scale_factor", scale_factor)

    result["scheduler"] = scheduler
    result["scale_factor"] = scale_factor
    return result
|
|