"""
HSIGene modular components: path setup and component loading.

AeroGen-style: ensure_ldm_path adds model dir to sys.path so hsigene can be
imported. No manual sys.path.insert needed when using
DiffusionPipeline.from_pretrained(path).
"""

import importlib
import json
import logging
import sys
from pathlib import Path
from typing import Union

from diffusers import DDIMScheduler

_logger = logging.getLogger(__name__)

# Ensure model dir is on path for hsigene imports
_pipeline_dir = Path(__file__).resolve().parent
if str(_pipeline_dir) not in sys.path:
    sys.path.insert(0, str(_pipeline_dir))

# Sub-directory names, one per loadable component, in loading order.
_COMPONENT_NAMES = (
    "unet",
    "vae",
    "text_encoder",
    "local_adapter",
    "global_content_adapter",
    "global_text_adapter",
    "metadata_encoder",
)

# Rewrites `_target` strings found in component configs to the dotted path
# that is actually imported by `_get_class`. Targets not listed here are
# imported unchanged.
_TARGET_MAP = {
    "hsigene_models.HSIGeneUNet": "unet.model.HSIGeneUNet",
    "hsigene.HSIGeneUNet": "unet.model.HSIGeneUNet",
    "hsigene_models.HSIGeneAutoencoderKL": "vae.model.HSIGeneAutoencoderKL",
    "hsigene.HSIGeneAutoencoderKL": "vae.model.HSIGeneAutoencoderKL",
    "ldm.modules.encoders.modules.FrozenCLIPEmbedder": "text_encoder.model.CLIPTextEncoder",
    "hsigene.CLIPTextEncoder": "text_encoder.model.CLIPTextEncoder",
    "models.local_adapter.LocalAdapter": "local_adapter.model.LocalAdapter",
    "hsigene.LocalAdapter": "local_adapter.model.LocalAdapter",
    "models.global_adapter.GlobalContentAdapter": "global_content_adapter.model.GlobalContentAdapter",
    "hsigene.GlobalContentAdapter": "global_content_adapter.model.GlobalContentAdapter",
    "models.global_adapter.GlobalTextAdapter": "global_text_adapter.model.GlobalTextAdapter",
    "hsigene.GlobalTextAdapter": "global_text_adapter.model.GlobalTextAdapter",
    "models.metadata_embedding.metadata_embeddings": "metadata_encoder.model.metadata_embeddings",
    "hsigene.metadata_embeddings": "metadata_encoder.model.metadata_embeddings",
}

# Weight files probed inside a component directory; the first one found wins.
_WEIGHT_FILES = (
    "diffusion_pytorch_model.safetensors",
    "diffusion_pytorch_model.bin",
)

# Used when model_index.json is absent or has no "scale_factor" entry.
_DEFAULT_SCALE_FACTOR = 0.18215


def _is_component_dir(path: Path) -> bool:
    """Return True if *path* is named like a component and holds a config.json."""
    return path.name in _COMPONENT_NAMES and (path / "config.json").exists()


def ensure_ldm_path(pretrained_model_name_or_path: Union[str, Path]) -> Path:
    """Add model repo to path so hsigene can be imported.

    If the argument does not exist on the local filesystem it is passed to
    ``huggingface_hub.snapshot_download`` and the downloaded snapshot
    directory is used instead.

    Returns:
        The resolved (absolute) model directory, which is guaranteed to be
        present in ``sys.path`` afterwards.
    """
    path = Path(pretrained_model_name_or_path)
    if not path.exists():
        # Imported lazily so purely local loading needs no hub dependency.
        from huggingface_hub import snapshot_download

        path = Path(snapshot_download(pretrained_model_name_or_path))
    path = path.resolve()
    entry = str(path)
    if entry not in sys.path:
        sys.path.insert(0, entry)
    return path


def _get_class(target: str):
    """Import and return the attribute named by a dotted ``module.Class`` path.

    Raises:
        ValueError: if *target* is not of the form ``module.path.Name``.
    """
    module_path, sep, cls_name = target.rpartition(".")
    if not (sep and module_path and cls_name):
        raise ValueError(
            f"Invalid _target {target!r}: expected a dotted path like 'package.module.ClassName'"
        )
    mod = importlib.import_module(module_path)
    return getattr(mod, cls_name)


def _load_state_dict(comp_path: Path):
    """Return the state dict from the first weight file in *comp_path*, or None."""
    for wfile in _WEIGHT_FILES:
        wp = comp_path / wfile
        if not wp.exists():
            continue
        if wfile.endswith(".safetensors"):
            from safetensors.torch import load_file

            return load_file(str(wp))

        import torch

        try:
            return torch.load(wp, map_location="cpu", weights_only=True)
        except TypeError:
            # torch.load rejected the `weights_only` keyword, so retry
            # without it.
            # SECURITY: this fallback performs full unpickling, which can
            # execute arbitrary code -- only load checkpoints you trust.
            return torch.load(wp, map_location="cpu")
    return None


def load_component(model_path: Union[str, Path], name: str):
    """Load a single component (unet, vae, text_encoder, etc.).

    Args:
        model_path: Either the model root (the component is then read from
            ``model_path / name``) or a component directory itself. In the
            latter case that directory is loaded directly and *name* is not
            used to locate it.
        name: Component sub-directory name.

    Returns:
        The instantiated component, switched to eval mode. Weights are
        restored with ``strict=True`` from the first available weight file;
        if none exists a warning is logged and the component is returned as
        constructed.

    Raises:
        ValueError: if the component's config.json has no ``_target`` entry,
            or the entry is not a dotted import path.
    """
    path = Path(model_path)
    if _is_component_dir(path):
        root, comp_path = path.parent, path
    else:
        root, comp_path = path, path / name
    # The root must be importable so that `_target` modules resolve.
    ensure_ldm_path(root)

    config_file = comp_path / "config.json"
    with open(config_file, encoding="utf-8") as f:
        cfg = json.load(f)

    target = cfg.pop("_target", None)
    if not target:
        raise ValueError(f"No _target in {config_file}")
    cls_ref = _get_class(_TARGET_MAP.get(target, target))

    # Keys starting with "_" are config metadata, not constructor arguments.
    params = {k: v for k, v in cfg.items() if not k.startswith("_")}
    comp = cls_ref(**params)

    state = _load_state_dict(comp_path)
    if state is None:
        _logger.warning(
            "No weights file (%s) found in %s; component %r is returned as constructed",
            ", ".join(_WEIGHT_FILES),
            comp_path,
            name,
        )
    else:
        comp.load_state_dict(state, strict=True)

    comp.eval()
    return comp


def load_components(model_path: Union[str, Path]) -> dict:
    """Load all pipeline components.

    Args:
        model_path: Local model root, a component directory inside it, or an
            identifier that ``ensure_ldm_path`` can download.

    Returns:
        A dict mapping every name in ``_COMPONENT_NAMES`` to its loaded
        component, plus ``"scheduler"`` (a ``DDIMScheduler``) and
        ``"scale_factor"`` (read from model_index.json when present,
        otherwise 0.18215).
    """
    path = Path(ensure_ldm_path(model_path))
    if _is_component_dir(path):
        # A component directory was given; step up to the model root.
        path = path.parent

    scheduler = DDIMScheduler.from_pretrained(path / "scheduler")
    components = {name: load_component(path, name) for name in _COMPONENT_NAMES}

    scale_factor = _DEFAULT_SCALE_FACTOR
    index_file = path / "model_index.json"
    if index_file.exists():
        with open(index_file, encoding="utf-8") as f:
            scale_factor = json.load(f).get("scale_factor", scale_factor)

    components["scheduler"] = scheduler
    components["scale_factor"] = scale_factor
    return components