File size: 4,635 Bytes
66a2b45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
"""
HSIGene modular components: path setup and component loading.

AeroGen-style: ensure_ldm_path adds model dir to sys.path so hsigene can be imported.
No manual sys.path.insert needed when using DiffusionPipeline.from_pretrained(path).
"""

import importlib
import json
import sys
from pathlib import Path
from typing import Union

from diffusers import DDIMScheduler

# Make the directory containing this file importable so the sibling component
# packages (unet, vae, text_encoder, ...) resolve even when this module is
# imported standalone rather than via DiffusionPipeline.from_pretrained.
_pipeline_dir = Path(__file__).resolve().parent
if str(_pipeline_dir) not in sys.path:
    sys.path.insert(0, str(_pipeline_dir))

# Subdirectory names of the pipeline components this loader knows how to build;
# each is expected to hold a config.json (and optionally weight files).
_COMPONENT_NAMES = (
    "unet", "vae", "text_encoder", "local_adapter",
    "global_content_adapter", "global_text_adapter", "metadata_encoder",
)

# Maps legacy / alternate "_target" strings found in component config.json
# files to the canonical "<component>.model.<ClassName>" import paths used by
# the modular layout. Unknown targets pass through unchanged (see
# load_component).
_TARGET_MAP = {
    "hsigene_models.HSIGeneUNet": "unet.model.HSIGeneUNet",
    "hsigene.HSIGeneUNet": "unet.model.HSIGeneUNet",
    "hsigene_models.HSIGeneAutoencoderKL": "vae.model.HSIGeneAutoencoderKL",
    "hsigene.HSIGeneAutoencoderKL": "vae.model.HSIGeneAutoencoderKL",
    "ldm.modules.encoders.modules.FrozenCLIPEmbedder": "text_encoder.model.CLIPTextEncoder",
    "hsigene.CLIPTextEncoder": "text_encoder.model.CLIPTextEncoder",
    "models.local_adapter.LocalAdapter": "local_adapter.model.LocalAdapter",
    "hsigene.LocalAdapter": "local_adapter.model.LocalAdapter",
    "models.global_adapter.GlobalContentAdapter": "global_content_adapter.model.GlobalContentAdapter",
    "hsigene.GlobalContentAdapter": "global_content_adapter.model.GlobalContentAdapter",
    "models.global_adapter.GlobalTextAdapter": "global_text_adapter.model.GlobalTextAdapter",
    "hsigene.GlobalTextAdapter": "global_text_adapter.model.GlobalTextAdapter",
    "models.metadata_embedding.metadata_embeddings": "metadata_encoder.model.metadata_embeddings",
    "hsigene.metadata_embeddings": "metadata_encoder.model.metadata_embeddings",
}


def ensure_ldm_path(pretrained_model_name_or_path: Union[str, Path]) -> Path:
    """Make the model repo importable and return its resolved local path.

    A non-existent local path is treated as a Hugging Face Hub repo id and
    fetched with ``snapshot_download``. The resolved directory is prepended
    to ``sys.path`` (at most once) so the repo's hsigene modules import.
    """
    repo = Path(pretrained_model_name_or_path)
    if not repo.exists():
        # Lazy import: huggingface_hub is only needed for remote repos.
        from huggingface_hub import snapshot_download
        repo = Path(snapshot_download(pretrained_model_name_or_path))
    resolved = repo.resolve()
    entry = str(resolved)
    if entry not in sys.path:
        sys.path.insert(0, entry)
    return resolved


def _get_class(target: str):
    module_path, cls_name = target.rsplit(".", 1)
    mod = importlib.import_module(module_path)
    return getattr(mod, cls_name)


def load_component(model_path: Path, name: str):
    """Build one pipeline component (unet, vae, text_encoder, ...) from disk.

    *model_path* may be the repo root (then *name* selects the component
    subdirectory) or a component directory itself. The component's
    config.json must contain a "_target" entry naming the class to
    instantiate; remaining non-underscore keys become constructor kwargs.
    Weights are loaded from safetensors (preferred) or a .bin checkpoint when
    present. Returns the component in eval mode.

    Raises ValueError if config.json lacks a "_target" key.
    """
    import torch
    path = Path(model_path)
    # If handed a component dir directly, its parent is the repo root that
    # must be on sys.path so the "_target" imports resolve.
    root = path.parent if path.name in _COMPONENT_NAMES and (path / "config.json").exists() else path
    ensure_ldm_path(root)
    # Accept both "repo/<name>" and a direct component path.
    comp_path = path if (path / "config.json").exists() and path.name in _COMPONENT_NAMES else path / name
    with open(comp_path / "config.json") as f:
        cfg = json.load(f)
    target = cfg.pop("_target", None)
    if not target:
        raise ValueError(f"No _target in {comp_path / 'config.json'}")
    # Translate legacy target strings to the modular layout's import paths;
    # unknown targets pass through unchanged.
    target = _TARGET_MAP.get(target, target)
    cls_ref = _get_class(target)
    # Underscore-prefixed keys are loader metadata, not constructor kwargs.
    params = {k: v for k, v in cfg.items() if not k.startswith("_")}
    comp = cls_ref(**params)
    # Prefer safetensors; fall back to the torch pickle checkpoint.
    for wfile in ("diffusion_pytorch_model.safetensors", "diffusion_pytorch_model.bin"):
        wp = comp_path / wfile
        if wp.exists():
            if wfile.endswith(".safetensors"):
                from safetensors.torch import load_file
                state = load_file(str(wp))
            else:
                try:
                    # weights_only= is unavailable on older torch versions.
                    state = torch.load(wp, map_location="cpu", weights_only=True)
                except TypeError:
                    state = torch.load(wp, map_location="cpu")
            comp.load_state_dict(state, strict=True)
            break
    # NOTE(review): if neither weight file exists the component silently keeps
    # its random initialization — confirm this is intended for all components.
    comp.eval()
    return comp


def load_components(model_path: Union[str, Path]) -> dict:
    """Load all pipeline components from *model_path*.

    Accepts the repo root or a single component subdirectory (in which case
    the parent repo is used). Returns a dict with one entry per component
    name, plus "scheduler" (DDIMScheduler) and "scale_factor" (float, taken
    from model_index.json when available, otherwise 0.18215).
    """
    root = Path(ensure_ldm_path(model_path))
    # Walk up to the repo root if given a component directory.
    if root.name in _COMPONENT_NAMES and (root / "config.json").exists():
        root = root.parent
    scheduler = DDIMScheduler.from_pretrained(root / "scheduler")
    components = {name: load_component(root, name) for name in _COMPONENT_NAMES}
    scale = 0.18215  # SD-style VAE latent scaling default
    index_file = root / "model_index.json"
    if index_file.exists():
        with open(index_file) as f:
            scale = json.load(f).get("scale_factor", scale)
    components["scheduler"] = scheduler
    components["scale_factor"] = scale
    return components