"""
HSIGene modular components: path setup and component loading.

AeroGen-style: ``ensure_ldm_path`` adds the model directory to ``sys.path`` so
``hsigene`` modules can be imported. No manual ``sys.path.insert`` is needed
when using ``DiffusionPipeline.from_pretrained(path)``.
"""
| |
|
| | import importlib |
| | import json |
| | import sys |
| | from pathlib import Path |
| | from typing import Union |
| |
|
| | from diffusers import DDIMScheduler |
| |
|
| | |
# Make the sibling component packages (unet/, vae/, ...) importable by
# prepending this pipeline directory to sys.path exactly once.
_pipeline_dir = Path(__file__).resolve().parent
_pipeline_dir_str = str(_pipeline_dir)
if _pipeline_dir_str not in sys.path:
    sys.path.insert(0, _pipeline_dir_str)
| |
|
# Component subdirectories that make up the pipeline. Each directory holds
# its own config.json and weight file; iteration order is the load order.
_COMPONENT_NAMES = (
    "unet",
    "vae",
    "text_encoder",
    "local_adapter",
    "global_content_adapter",
    "global_text_adapter",
    "metadata_encoder",
)
| |
|
# Translation table from legacy "_target" strings found in component
# config.json files to the modular in-repo class paths. Targets not listed
# here are used verbatim.
_TARGET_MAP = {
    # UNet
    "hsigene_models.HSIGeneUNet": "unet.model.HSIGeneUNet",
    "hsigene.HSIGeneUNet": "unet.model.HSIGeneUNet",
    # VAE
    "hsigene_models.HSIGeneAutoencoderKL": "vae.model.HSIGeneAutoencoderKL",
    "hsigene.HSIGeneAutoencoderKL": "vae.model.HSIGeneAutoencoderKL",
    # Text encoder
    "ldm.modules.encoders.modules.FrozenCLIPEmbedder": "text_encoder.model.CLIPTextEncoder",
    "hsigene.CLIPTextEncoder": "text_encoder.model.CLIPTextEncoder",
    # Local adapter
    "models.local_adapter.LocalAdapter": "local_adapter.model.LocalAdapter",
    "hsigene.LocalAdapter": "local_adapter.model.LocalAdapter",
    # Global content adapter
    "models.global_adapter.GlobalContentAdapter": "global_content_adapter.model.GlobalContentAdapter",
    "hsigene.GlobalContentAdapter": "global_content_adapter.model.GlobalContentAdapter",
    # Global text adapter
    "models.global_adapter.GlobalTextAdapter": "global_text_adapter.model.GlobalTextAdapter",
    "hsigene.GlobalTextAdapter": "global_text_adapter.model.GlobalTextAdapter",
    # Metadata encoder
    "models.metadata_embedding.metadata_embeddings": "metadata_encoder.model.metadata_embeddings",
    "hsigene.metadata_embeddings": "metadata_encoder.model.metadata_embeddings",
}
| |
|
| |
|
def ensure_ldm_path(pretrained_model_name_or_path: Union[str, Path]) -> Path:
    """Make the model repo importable and return its resolved local path.

    If the given path does not exist locally it is treated as a Hugging Face
    repo id and fetched with ``snapshot_download``. The resolved directory is
    prepended to ``sys.path`` (once) so ``hsigene`` modules can be imported.
    """
    local = Path(pretrained_model_name_or_path)
    if not local.exists():
        # Lazy import: huggingface_hub is only required for remote repo ids.
        from huggingface_hub import snapshot_download
        local = Path(snapshot_download(pretrained_model_name_or_path))
    resolved = local.resolve()
    entry = str(resolved)
    if entry not in sys.path:
        sys.path.insert(0, entry)
    return resolved
| |
|
| |
|
def _get_class(target: str):
    """Resolve a dotted ``package.module.ClassName`` string to the object."""
    module_name, cls_name = target.rsplit(".", 1)
    module = importlib.import_module(module_name)
    return getattr(module, cls_name)
| |
|
| |
|
def load_component(model_path: Path, name: str):
    """Instantiate and load one pipeline component (unet, vae, ...).

    Parameters
    ----------
    model_path:
        Either the pipeline root directory or a component directory itself
        (a directory named after a component and containing ``config.json``).
    name:
        Component name; expected to be one of ``_COMPONENT_NAMES``.

    Returns
    -------
    The component instance in eval mode. If no weight file is present a
    warning is emitted and the freshly constructed (untrained) instance is
    returned.

    Raises
    ------
    ValueError
        If the component's ``config.json`` has no ``_target`` entry.
    """
    import torch

    path = Path(model_path)
    # Hoisted: the original computed this same condition twice (in two
    # different operand orders) for `root` and `comp_path`.
    is_component_dir = path.name in _COMPONENT_NAMES and (path / "config.json").exists()
    root = path.parent if is_component_dir else path
    ensure_ldm_path(root)
    comp_path = path if is_component_dir else path / name

    with open(comp_path / "config.json") as f:
        cfg = json.load(f)
    target = cfg.pop("_target", None)
    if not target:
        raise ValueError(f"No _target in {comp_path / 'config.json'}")
    # Translate legacy target strings to the modular in-repo classes.
    target = _TARGET_MAP.get(target, target)
    cls_ref = _get_class(target)

    # Underscore-prefixed keys are metadata, not constructor kwargs.
    params = {k: v for k, v in cfg.items() if not k.startswith("_")}
    comp = cls_ref(**params)

    for wfile in ("diffusion_pytorch_model.safetensors", "diffusion_pytorch_model.bin"):
        wp = comp_path / wfile
        if not wp.exists():
            continue
        if wfile.endswith(".safetensors"):
            from safetensors.torch import load_file
            state = load_file(str(wp))
        else:
            try:
                # weights_only=True is the safe load path; older torch
                # versions don't accept the keyword.
                state = torch.load(wp, map_location="cpu", weights_only=True)
            except TypeError:
                state = torch.load(wp, map_location="cpu")
        comp.load_state_dict(state, strict=True)
        break
    else:
        # Fix: previously a missing weight file was silently ignored and a
        # randomly initialized model was returned. Keep the best-effort
        # behavior but make it loud.
        import warnings
        warnings.warn(
            f"No weight file found in {comp_path}; component '{name}' "
            "keeps its initial (random) weights."
        )

    comp.eval()
    return comp
| |
|
| |
|
def load_components(model_path: Union[str, Path]) -> dict:
    """Load every pipeline component plus scheduler and VAE scale factor.

    Parameters
    ----------
    model_path:
        Local pipeline directory, a single component subdirectory inside it,
        or a Hugging Face repo id (resolved via ``ensure_ldm_path``).

    Returns
    -------
    dict mapping each name in ``_COMPONENT_NAMES`` to its loaded module,
    plus ``"scheduler"`` (a DDIMScheduler) and ``"scale_factor"`` (float).
    """
    # ensure_ldm_path already returns a resolved Path; the original wrapped
    # it in a redundant Path(...) call.
    path = ensure_ldm_path(model_path)
    # If pointed at a single component directory, walk up to the pipeline root.
    if path.name in _COMPONENT_NAMES and (path / "config.json").exists():
        path = path.parent

    scheduler = DDIMScheduler.from_pretrained(path / "scheduler")
    components = {name: load_component(path, name) for name in _COMPONENT_NAMES}

    # Default Stable-Diffusion latent scale, overridden by model_index.json
    # when the pipeline ships one.
    scale_factor = 0.18215
    model_index = path / "model_index.json"
    if model_index.exists():
        with open(model_index) as f:
            scale_factor = json.load(f).get("scale_factor", scale_factor)

    components["scheduler"] = scheduler
    components["scale_factor"] = scale_factor
    return components
| |
|