"""HSIGenePipeline - diffusers DiffusionPipeline for HSIGene hyperspectral generation.
AeroGen-style loading: use DiffusionPipeline.from_pretrained(path) - no sys.path.insert needed.
Self-contained: loading logic inlined (no separate modular_pipeline import).
"""
import importlib
import json
import sys
from pathlib import Path
from typing import List, Optional, Union
import numpy as np
import torch
import torch.nn.functional as F
from dataclasses import dataclass
from diffusers import DDIMScheduler, DiffusionPipeline
from diffusers.utils import BaseOutput
# Re-export for diffusers component loading (load_method lookup)
DiffusionPipeline = DiffusionPipeline
# Inline path/loading (AeroGen-style) - self-contained for diffusers cache loading
_pipeline_dir = Path(__file__).resolve().parent
if str(_pipeline_dir) not in sys.path:
sys.path.insert(0, str(_pipeline_dir))
# Register as "pipeline_hsigene" so diffusers' get_class_obj_and_candidates finds us when it does
# importlib.import_module("pipeline_hsigene") during component loading. (We may be loaded as
# "diffusers_modules.local.xxx.pipeline_hsigene" from cache, so this alias is required.)
sys.modules["pipeline_hsigene"] = sys.modules[__name__]
_COMPONENT_NAMES = (
"unet", "vae", "text_encoder", "local_adapter",
"global_content_adapter", "global_text_adapter", "metadata_encoder",
)
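# Map legacy `_target` class paths (original LDM / HSIGene repo layouts) onto the
# repo-local module layout so older component configs keep loading.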
_TARGET_MAP = {
"hsigene_models.HSIGeneUNet": "unet.model.HSIGeneUNet",
"hsigene.HSIGeneUNet": "unet.model.HSIGeneUNet",
"hsigene_models.HSIGeneAutoencoderKL": "vae.model.HSIGeneAutoencoderKL",
"hsigene.HSIGeneAutoencoderKL": "vae.model.HSIGeneAutoencoderKL",
"ldm.modules.encoders.modules.FrozenCLIPEmbedder": "text_encoder.model.CLIPTextEncoder",
"hsigene.CLIPTextEncoder": "text_encoder.model.CLIPTextEncoder",
"models.local_adapter.LocalAdapter": "local_adapter.model.LocalAdapter",
"hsigene.LocalAdapter": "local_adapter.model.LocalAdapter",
"models.global_adapter.GlobalContentAdapter": "global_content_adapter.model.GlobalContentAdapter",
"hsigene.GlobalContentAdapter": "global_content_adapter.model.GlobalContentAdapter",
"models.global_adapter.GlobalTextAdapter": "global_text_adapter.model.GlobalTextAdapter",
"hsigene.GlobalTextAdapter": "global_text_adapter.model.GlobalTextAdapter",
"models.metadata_embedding.metadata_embeddings": "metadata_encoder.model.metadata_embeddings",
"hsigene.metadata_embeddings": "metadata_encoder.model.metadata_embeddings",
}
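# A component config.json is expected to carry a `_target` key plus constructor
# kwargs, e.g. (illustrative values only): {"_target": "hsigene.HSIGeneUNet", ...}.
# Underscore-prefixed keys are stripped before the class is instantiated.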
def ensure_ldm_path(pretrained_model_name_or_path: Union[str, Path]) -> Path:
"""Add model repo to path so hsigene can be imported. Returns resolved path."""
path = Path(pretrained_model_name_or_path)
if not path.exists():
from huggingface_hub import snapshot_download
        path = Path(snapshot_download(str(pretrained_model_name_or_path)))
path = path.resolve()
s = str(path)
if s not in sys.path:
sys.path.insert(0, s)
return path
def _get_class(target: str):
module_path, cls_name = target.rsplit(".", 1)
mod = importlib.import_module(module_path)
return getattr(mod, cls_name)
def load_component(model_path: Path, name: str):
"""Load a single component (unet, vae, text_encoder, etc.)."""
    path = Path(model_path)
    is_component_dir = path.name in _COMPONENT_NAMES and (path / "config.json").exists()
    root = path.parent if is_component_dir else path
    ensure_ldm_path(root)
    comp_path = path if is_component_dir else path / name
with open(comp_path / "config.json") as f:
cfg = json.load(f)
target = cfg.pop("_target", None)
if not target:
raise ValueError(f"No _target in {comp_path / 'config.json'}")
target = _TARGET_MAP.get(target, target)
cls_ref = _get_class(target)
params = {k: v for k, v in cfg.items() if not k.startswith("_")}
comp = cls_ref(**params)
for wfile in ("diffusion_pytorch_model.safetensors", "diffusion_pytorch_model.bin"):
wp = comp_path / wfile
if wp.exists():
if wfile.endswith(".safetensors"):
from safetensors.torch import load_file
state = load_file(str(wp))
else:
try:
state = torch.load(wp, map_location="cpu", weights_only=True)
except TypeError:
state = torch.load(wp, map_location="cpu")
comp.load_state_dict(state, strict=True)
break
comp.eval()
return comp
def load_components(model_path: Union[str, Path]) -> dict:
"""Load all pipeline components."""
path = Path(ensure_ldm_path(model_path))
if path.name in _COMPONENT_NAMES and (path / "config.json").exists():
path = path.parent
scheduler = DDIMScheduler.from_pretrained(path / "scheduler")
components = {}
for name in _COMPONENT_NAMES:
components[name] = load_component(path, name)
scale_factor = 0.18215
if (path / "model_index.json").exists():
with open(path / "model_index.json") as f:
scale_factor = json.load(f).get("scale_factor", scale_factor)
components["scheduler"] = scheduler
components["scale_factor"] = scale_factor
return components
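# Example (a sketch; "./HSIGene" stands in for a local diffusers-format export):
#   pipe = HSIGenePipeline(**load_components("./HSIGene"))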
class _CRSModelWrapper(torch.nn.Module):
"""Wrapper that mimics CRSControlNet interface."""
def __init__(
self,
unet,
vae,
text_encoder,
local_adapter,
global_content_adapter,
global_text_adapter,
metadata_emb,
scale_factor=0.18215,
local_control_scales=None,
):
super().__init__()
# Keep diffusion_model as a properly registered submodule so
# wrapper/device transfers (e.g., `.to("cuda")`) move UNet weights.
self.model = torch.nn.Module()
self.model.add_module("diffusion_model", unet)
self.first_stage_model = vae
self.cond_stage_model = text_encoder
self.local_adapter = local_adapter
self.global_content_adapter = global_content_adapter
self.global_text_adapter = global_text_adapter
self.metadata_emb = metadata_emb
self.scale_factor = scale_factor
        self.local_control_scales = local_control_scales or [1.0] * 13  # one scale per local-adapter control feature
@torch.no_grad()
def get_learned_conditioning(self, prompts):
return self.cond_stage_model(prompts)
def apply_model(self, x_noisy, t, cond, metadata=None, global_strength=1.0, text_strength=1.0, **kwargs):
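        # Fuse text, global-content, and metadata conditioning, then feed the
        # local adapter's multi-scale features into the UNet as control signals.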
        if metadata is None:
            metadata = cond["metadata"][0]  # stored as a one-element list, like the other cond entries
metadata_emb = self.metadata_emb(metadata)
content_t = cond["global_control"][0]
global_control = self.global_content_adapter(content_t)
cond_txt = torch.cat(cond["c_crossattn"], 1)
cond_txt = self.global_text_adapter(cond_txt)
cond_txt = F.normalize(cond_txt, p=2, dim=-1) * text_strength
global_control = F.normalize(global_control, p=2, dim=-1) * global_strength
cond_txt = torch.cat([cond_txt, global_control], dim=1)
local_control = torch.cat(cond["local_control"], 1)
local_control = self.local_adapter(
x=x_noisy, timesteps=t, context=cond_txt, local_conditions=local_control
)
local_control = [c * s for c, s in zip(local_control, self.local_control_scales)]
return self.model.diffusion_model(
x=x_noisy,
timesteps=t,
metadata=metadata_emb,
context=cond_txt,
local_control=local_control,
meta=True,
)
def decode_first_stage(self, z):
z = (1.0 / self.scale_factor) * z
return self.first_stage_model.decode(z)
def low_vram_shift(self, is_diffusing):
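        # Ping-pong module groups between GPU and CPU so that only the half
        # needed for the current phase (denoising vs. encode/decode) is
        # resident; assumes a CUDA device is available.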
if is_diffusing:
self.model.diffusion_model = self.model.diffusion_model.cuda()
self.local_adapter = self.local_adapter.cuda()
self.global_text_adapter = self.global_text_adapter.cuda()
self.global_content_adapter = self.global_content_adapter.cuda()
self.first_stage_model = self.first_stage_model.cpu()
self.cond_stage_model = self.cond_stage_model.cpu()
else:
self.model.diffusion_model = self.model.diffusion_model.cpu()
self.local_adapter = self.local_adapter.cpu()
self.global_text_adapter = self.global_text_adapter.cpu()
self.global_content_adapter = self.global_content_adapter.cpu()
self.first_stage_model = self.first_stage_model.cuda()
self.cond_stage_model = self.cond_stage_model.cuda()
@dataclass
class HSIGeneOutput(BaseOutput):
"""Output class for HSIGene pipeline."""
images: Optional[np.ndarray] = None
latents: Optional[torch.Tensor] = None
def _is_component_list(v):
"""Check if value is raw config format [library, class_name]."""
return isinstance(v, (list, tuple)) and len(v) == 2 and isinstance(v[0], str) and isinstance(v[1], str)
def _resolve_model_root(candidate: Optional[Union[str, Path]]) -> Optional[Path]:
"""Resolve candidate path/repo to model root containing model_index.json."""
if not candidate:
return None
try:
path = Path(candidate)
if not path.exists():
from huggingface_hub import snapshot_download
path = Path(snapshot_download(str(candidate)))
path = path.resolve()
if (path / "model_index.json").exists():
return path
cur = path
for _ in range(5):
parent = cur.parent
if parent == cur:
break
if (parent / "model_index.json").exists():
return parent
cur = parent
except Exception:
return None
return None
class HSIGenePipeline(DiffusionPipeline):
"""Pipeline for HSIGene hyperspectral image generation.
    AeroGen-style: load with DiffusionPipeline.from_pretrained(path); callers never need to modify sys.path.
"""
def register_modules(self, **kwargs):
"""Override to handle list-format component specs from diffusers config."""
for name, module in kwargs.items():
if module is None or (isinstance(module, (tuple, list)) and len(module) > 0 and module[0] is None):
self.register_to_config(**{name: (None, None)})
setattr(self, name, module)
elif _is_component_list(module):
self.register_to_config(**{name: (module[0], module[1])})
setattr(self, name, module)
else:
from diffusers.pipelines.pipeline_loading_utils import _fetch_class_library_tuple
library, class_name = _fetch_class_library_tuple(module)
self.register_to_config(**{name: (library, class_name)})
setattr(self, name, module)
def __init__(
self,
unet=None,
vae=None,
text_encoder=None,
local_adapter=None,
global_content_adapter=None,
global_text_adapter=None,
metadata_encoder=None,
scheduler=None,
crs_model=None,
scale_factor=0.18215,
model_path: Optional[Union[str, Path]] = None,
_name_or_path: Optional[Union[str, Path]] = None,
):
super().__init__()
if crs_model is not None:
self.register_modules(crs_model=crs_model, scheduler=scheduler)
else:
components_are_lists = any(
_is_component_list(x)
for x in (
unet,
vae,
text_encoder,
local_adapter,
global_content_adapter,
global_text_adapter,
metadata_encoder,
)
if x is not None
)
if components_are_lists:
# Diffusers custom_pipeline may pass raw [library, class] placeholders to __init__.
# Resolve model root and materialize real components here.
model_root = (
_resolve_model_root(model_path)
or _resolve_model_root(_name_or_path)
or _resolve_model_root(getattr(getattr(self, "config", None), "_name_or_path", None))
)
if model_root is None:
raise ValueError(
"HSIGene received raw config placeholders but could not resolve model path. "
"Pass `model_path` to HSIGenePipeline or load via "
"`DiffusionPipeline.from_pretrained(<path>, custom_pipeline=<pipeline_file>)` "
"with a valid local model directory."
)
loaded = load_components(model_root)
unet = loaded["unet"]
vae = loaded["vae"]
text_encoder = loaded["text_encoder"]
local_adapter = loaded["local_adapter"]
global_content_adapter = loaded["global_content_adapter"]
global_text_adapter = loaded["global_text_adapter"]
metadata_encoder = loaded["metadata_encoder"]
scheduler = loaded["scheduler"] if scheduler is None else scheduler
scale_factor = loaded["scale_factor"]
crs_model = _CRSModelWrapper(
unet=unet,
vae=vae,
text_encoder=text_encoder,
local_adapter=local_adapter,
global_content_adapter=global_content_adapter,
global_text_adapter=global_text_adapter,
metadata_emb=metadata_encoder,
scale_factor=scale_factor,
)
self.register_modules(
unet=unet,
vae=vae,
text_encoder=text_encoder,
local_adapter=local_adapter,
global_content_adapter=global_content_adapter,
global_text_adapter=global_text_adapter,
metadata_encoder=metadata_encoder,
scheduler=scheduler,
crs_model=crs_model,
)
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: Union[str, Path],
device: Optional[Union[str, torch.device]] = None,
subfolder: Optional[str] = None,
**kwargs,
):
"""Load from diffusers-format directory. Supports subfolder for single-component loading."""
path = Path(ensure_ldm_path(pretrained_model_name_or_path))
subfolder = kwargs.pop("subfolder", subfolder)
if subfolder in ("unet", "vae", "text_encoder", "local_adapter",
"global_content_adapter", "global_text_adapter", "metadata_encoder"):
return load_component(path, subfolder)
if path.name in ("unet", "vae", "text_encoder", "local_adapter",
"global_content_adapter", "global_text_adapter", "metadata_encoder"):
if (path / "config.json").exists():
ensure_ldm_path(path.parent)
return load_component(path.parent, path.name)
if not (path / "model_index.json").exists():
for _ in range(5):
parent = path.parent
if (parent / "model_index.json").exists():
path = parent
break
if parent == path:
break
path = parent
components = load_components(path)
pipe = cls(
unet=components["unet"],
vae=components["vae"],
text_encoder=components["text_encoder"],
local_adapter=components["local_adapter"],
global_content_adapter=components["global_content_adapter"],
global_text_adapter=components["global_text_adapter"],
metadata_encoder=components["metadata_encoder"],
scheduler=components["scheduler"],
scale_factor=components["scale_factor"],
)
if device is not None:
pipe = pipe.to(device)
return pipe
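    # e.g. pipe = HSIGenePipeline.from_pretrained("./HSIGene", device="cuda")
    # ("./HSIGene" is an assumed local checkout; any diffusers-format export works.)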
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]] = "",
num_samples: int = 1,
height: int = 256,
width: int = 256,
num_inference_steps: int = 50,
eta: float = 0.0,
global_strength: float = 1.0,
text_strength: Optional[float] = None,
local_conditions: Optional[torch.Tensor] = None,
global_conditions: Optional[torch.Tensor] = None,
metadata: Optional[torch.Tensor] = None,
condition_resolution: int = 512,
guidance_scale: float = 1.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.Tensor] = None,
output_type: str = "numpy",
return_dict: bool = True,
save_memory: bool = False,
):
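        """Run DDIM sampling and decode to hyperspectral images.

        Returns an HSIGeneOutput whose `images` field is an (N, H, W, bands)
        numpy array scaled to [0, 1], or whose `latents` field holds the raw
        latents when output_type="latent". With return_dict=False a plain
        tuple is returned instead.
        """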
target_device = next(self.crs_model.parameters()).device
if hasattr(self, "unet") and isinstance(self.unet, torch.nn.Module):
target_device = next(self.unet.parameters()).device
if latents is not None:
target_device = latents.device
elif generator is not None and hasattr(generator, "device"):
target_device = torch.device(generator.device)
# Keep wrapper submodules on the same device used for sampling.
if next(self.crs_model.parameters()).device != target_device:
self.crs_model = self.crs_model.to(target_device)
device = target_device
if text_strength is None:
text_strength = global_strength
if isinstance(prompt, str):
prompts = [prompt] * num_samples
else:
prompts = list(prompt)
num_samples = len(prompts)
if save_memory:
self.crs_model.low_vram_shift(is_diffusing=False)
text_embedding = self.crs_model.get_learned_conditioning(prompts)
if local_conditions is None:
local_conditions = torch.zeros(
num_samples, 18, condition_resolution, condition_resolution,
device=device, dtype=torch.float32,
)
else:
local_conditions = local_conditions.to(device=device, dtype=torch.float32)
if global_conditions is None:
global_conditions = torch.zeros(
num_samples, 768, device=device, dtype=torch.float32,
)
else:
global_conditions = global_conditions.to(device=device, dtype=torch.float32)
if metadata is None:
metadata = torch.zeros(7, device=device, dtype=torch.float32)
else:
metadata = metadata.to(device=device, dtype=torch.float32)
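        # ControlNet/LDM-style conditioning dict: each entry is a one-element
        # list so that extra conditionings could be concatenated channel-wise.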
cond = {
"local_control": [local_conditions],
"c_crossattn": [text_embedding],
"global_control": [global_conditions],
"metadata": [metadata],
}
do_cfg = guidance_scale > 1.0
if do_cfg:
if negative_prompt is None:
neg_prompts = [""] * num_samples
elif isinstance(negative_prompt, str):
neg_prompts = [negative_prompt] * num_samples
else:
neg_prompts = list(negative_prompt)
uc_text = self.crs_model.get_learned_conditioning(neg_prompts)
uncond = {
"local_control": [local_conditions],
"c_crossattn": [uc_text],
"global_control": [torch.zeros_like(global_conditions)],
"metadata": [metadata],
}
latent_shape = (num_samples, 4, height // 4, width // 4)
if latents is None:
if generator is not None and hasattr(generator, "device"):
gen_device = torch.device(generator.device)
if gen_device.type != device.type:
# Recreate generator on target device while preserving seed
# so CPU/CUDA mismatch does not crash torch.randn.
if hasattr(generator, "initial_seed"):
generator = torch.Generator(device=device).manual_seed(generator.initial_seed())
else:
generator = torch.Generator(device=device)
latents = torch.randn(
latent_shape, device=device, generator=generator, dtype=torch.float32,
)
else:
latents = latents.to(device)
self.scheduler.set_timesteps(num_inference_steps, device=device)
if save_memory:
self.crs_model.low_vram_shift(is_diffusing=True)
for t in self.progress_bar(self.scheduler.timesteps):
t_batch = t.expand(num_samples)
if do_cfg:
noise_pred_cond = self.crs_model.apply_model(
latents, t_batch, cond,
metadata=metadata,
global_strength=global_strength,
text_strength=text_strength,
)
noise_pred_uncond = self.crs_model.apply_model(
latents, t_batch, uncond,
metadata=metadata,
global_strength=global_strength,
text_strength=text_strength,
)
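                # Classifier-free guidance: extrapolate from the unconditional
                # prediction toward the conditional one.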
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_cond - noise_pred_uncond
)
else:
noise_pred = self.crs_model.apply_model(
latents, t_batch, cond,
metadata=metadata,
global_strength=global_strength,
text_strength=text_strength,
)
latents = self.scheduler.step(
noise_pred, t, latents, eta=eta, generator=generator,
).prev_sample
if output_type == "latent":
if not return_dict:
return (latents,)
return HSIGeneOutput(latents=latents)
if save_memory:
self.crs_model.low_vram_shift(is_diffusing=False)
images = self.crs_model.decode_first_stage(latents)
images = images.permute(0, 2, 3, 1).cpu().numpy()
images = images * 0.5 + 0.5
images = np.clip(images, 0, 1)
if not return_dict:
return (images,)
return HSIGeneOutput(images=images)
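

if __name__ == "__main__":
    # Minimal smoke-test sketch. "./HSIGene" is an assumed local model
    # directory (a diffusers-format export containing model_index.json);
    # adjust the path and device to your setup.
    _device = "cuda" if torch.cuda.is_available() else "cpu"
    _pipe = HSIGenePipeline.from_pretrained("./HSIGene", device=_device)
    _out = _pipe(prompt="a hyperspectral scene of farmland", num_inference_steps=20)
    print(_out.images.shape)  # expected: (1, 256, 256, bands)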