# HSIGene / modular_pipeline.py
# Uploaded by BiliSakura via the upload-large-folder tool (commit 66a2b45, verified).
"""
HSIGene modular components: path setup and component loading.
AeroGen-style: ensure_ldm_path adds model dir to sys.path so hsigene can be imported.
No manual sys.path.insert needed when using DiffusionPipeline.from_pretrained(path).
"""
import importlib
import json
import sys
from pathlib import Path
from typing import Union
from diffusers import DDIMScheduler
# Ensure model dir is on path for hsigene imports
# (prepended so the in-repo modules win over any installed package of the same name).
_pipeline_dir = Path(__file__).resolve().parent
if str(_pipeline_dir) not in sys.path:
    sys.path.insert(0, str(_pipeline_dir))
# Names of the component sub-directories that make up the pipeline;
# load_components loads one model per entry.
_COMPONENT_NAMES = (
    "unet", "vae", "text_encoder", "local_adapter",
    "global_content_adapter", "global_text_adapter", "metadata_encoder",
)
# Remap legacy "_target" strings found in older config.json files
# (hsigene_models.*, hsigene.*, ldm.*, models.*) to the module paths
# used by this repo's per-component layout ("<component>.model.<Class>").
# Unknown targets pass through unchanged (see load_component).
_TARGET_MAP = {
    "hsigene_models.HSIGeneUNet": "unet.model.HSIGeneUNet",
    "hsigene.HSIGeneUNet": "unet.model.HSIGeneUNet",
    "hsigene_models.HSIGeneAutoencoderKL": "vae.model.HSIGeneAutoencoderKL",
    "hsigene.HSIGeneAutoencoderKL": "vae.model.HSIGeneAutoencoderKL",
    "ldm.modules.encoders.modules.FrozenCLIPEmbedder": "text_encoder.model.CLIPTextEncoder",
    "hsigene.CLIPTextEncoder": "text_encoder.model.CLIPTextEncoder",
    "models.local_adapter.LocalAdapter": "local_adapter.model.LocalAdapter",
    "hsigene.LocalAdapter": "local_adapter.model.LocalAdapter",
    "models.global_adapter.GlobalContentAdapter": "global_content_adapter.model.GlobalContentAdapter",
    "hsigene.GlobalContentAdapter": "global_content_adapter.model.GlobalContentAdapter",
    "models.global_adapter.GlobalTextAdapter": "global_text_adapter.model.GlobalTextAdapter",
    "hsigene.GlobalTextAdapter": "global_text_adapter.model.GlobalTextAdapter",
    "models.metadata_embedding.metadata_embeddings": "metadata_encoder.model.metadata_embeddings",
    "hsigene.metadata_embeddings": "metadata_encoder.model.metadata_embeddings",
}
def ensure_ldm_path(pretrained_model_name_or_path: Union[str, Path]) -> Path:
    """Make a model repo importable and return its resolved local path.

    A path that does not exist locally is treated as a Hugging Face repo id
    and fetched with ``snapshot_download``. The resolved directory is
    prepended to ``sys.path`` so the repo's hsigene modules can be imported.
    """
    local_dir = Path(pretrained_model_name_or_path)
    if not local_dir.exists():
        # Not on disk: interpret the argument as a hub repo id.
        from huggingface_hub import snapshot_download
        local_dir = Path(snapshot_download(pretrained_model_name_or_path))
    local_dir = local_dir.resolve()
    entry = str(local_dir)
    if entry not in sys.path:
        sys.path.insert(0, entry)
    return local_dir
def _get_class(target: str):
module_path, cls_name = target.rsplit(".", 1)
mod = importlib.import_module(module_path)
return getattr(mod, cls_name)
def load_component(model_path: Path, name: str):
    """Load a single component (unet, vae, text_encoder, etc.).

    ``model_path`` may be either the repo root (the component is then read
    from ``model_path / name``) or a component directory itself. The class
    is resolved from the ``_target`` key of the component's ``config.json``
    (legacy targets remapped via ``_TARGET_MAP``), instantiated with the
    remaining config keys, and its weights loaded from the first matching
    checkpoint file found. The component is returned in eval mode.

    NOTE(review): if neither weight file exists the component is returned
    with its freshly-initialized weights — confirm this is intentional.
    """
    import torch
    path = Path(model_path)
    # If pointed directly at a component directory, its parent is the repo root.
    root = path.parent if path.name in _COMPONENT_NAMES and (path / "config.json").exists() else path
    ensure_ldm_path(root)
    # Component dir is either `path` itself or the named subdirectory of the root.
    comp_path = path if (path / "config.json").exists() and path.name in _COMPONENT_NAMES else path / name
    with open(comp_path / "config.json") as f:
        cfg = json.load(f)
    target = cfg.pop("_target", None)
    if not target:
        raise ValueError(f"No _target in {comp_path / 'config.json'}")
    # Remap legacy dotted targets to this repo's module layout; unknown targets pass through.
    target = _TARGET_MAP.get(target, target)
    cls_ref = _get_class(target)
    # Keys starting with "_" are metadata, not constructor arguments.
    params = {k: v for k, v in cfg.items() if not k.startswith("_")}
    comp = cls_ref(**params)
    # Prefer safetensors; fall back to a pickled .bin checkpoint.
    for wfile in ("diffusion_pytorch_model.safetensors", "diffusion_pytorch_model.bin"):
        wp = comp_path / wfile
        if wp.exists():
            if wfile.endswith(".safetensors"):
                from safetensors.torch import load_file
                state = load_file(str(wp))
            else:
                try:
                    # weights_only=True avoids arbitrary pickle execution (torch >= 1.13).
                    state = torch.load(wp, map_location="cpu", weights_only=True)
                except TypeError:
                    # Older torch without the weights_only kwarg.
                    state = torch.load(wp, map_location="cpu")
            comp.load_state_dict(state, strict=True)
            break
    comp.eval()
    return comp
def load_components(model_path: Union[str, Path]) -> dict:
    """Load every pipeline component plus scheduler and scale_factor.

    Returns a dict keyed by the names in ``_COMPONENT_NAMES`` with two
    extra entries: ``"scheduler"`` (DDIMScheduler) and ``"scale_factor"``
    (float, read from model_index.json when present).
    """
    root = Path(ensure_ldm_path(model_path))
    # Normalize: if handed a component subdirectory, step up to the repo root.
    if root.name in _COMPONENT_NAMES and (root / "config.json").exists():
        root = root.parent
    scheduler = DDIMScheduler.from_pretrained(root / "scheduler")
    components = {name: load_component(root, name) for name in _COMPONENT_NAMES}
    # Default Stable-Diffusion latent scale; model_index.json may override it.
    scale_factor = 0.18215
    index_file = root / "model_index.json"
    if index_file.exists():
        with open(index_file) as f:
            scale_factor = json.load(f).get("scale_factor", scale_factor)
    components.update(scheduler=scheduler, scale_factor=scale_factor)
    return components