| """HSIGenePipeline - diffusers DiffusionPipeline for HSIGene hyperspectral generation. |
| |
| AeroGen-style loading: use DiffusionPipeline.from_pretrained(path) - no sys.path.insert needed. |
| Self-contained: loading logic inlined (no separate modular_pipeline import). |
| """ |

import importlib
import json
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Union

import numpy as np
import torch
import torch.nn.functional as F

from diffusers import DDIMScheduler, DiffusionPipeline
from diffusers.utils import BaseOutput

# Make the component modules that ship next to this file importable.
_pipeline_dir = Path(__file__).resolve().parent
if str(_pipeline_dir) not in sys.path:
    sys.path.insert(0, str(_pipeline_dir))

# Register this module under its canonical name so that references to
# "pipeline_hsigene" resolve even when diffusers imports this file under a
# generated module name.
sys.modules["pipeline_hsigene"] = sys.modules[__name__]

_COMPONENT_NAMES = (
    "unet", "vae", "text_encoder", "local_adapter",
    "global_content_adapter", "global_text_adapter", "metadata_encoder",
)

# Map legacy `_target` class paths from the original HSIGene/ldm configs to the
# module layout of the diffusers-format repo; unknown targets pass through.
_TARGET_MAP = {
    "hsigene_models.HSIGeneUNet": "unet.model.HSIGeneUNet",
    "hsigene.HSIGeneUNet": "unet.model.HSIGeneUNet",
    "hsigene_models.HSIGeneAutoencoderKL": "vae.model.HSIGeneAutoencoderKL",
    "hsigene.HSIGeneAutoencoderKL": "vae.model.HSIGeneAutoencoderKL",
    "ldm.modules.encoders.modules.FrozenCLIPEmbedder": "text_encoder.model.CLIPTextEncoder",
    "hsigene.CLIPTextEncoder": "text_encoder.model.CLIPTextEncoder",
    "models.local_adapter.LocalAdapter": "local_adapter.model.LocalAdapter",
    "hsigene.LocalAdapter": "local_adapter.model.LocalAdapter",
    "models.global_adapter.GlobalContentAdapter": "global_content_adapter.model.GlobalContentAdapter",
    "hsigene.GlobalContentAdapter": "global_content_adapter.model.GlobalContentAdapter",
    "models.global_adapter.GlobalTextAdapter": "global_text_adapter.model.GlobalTextAdapter",
    "hsigene.GlobalTextAdapter": "global_text_adapter.model.GlobalTextAdapter",
    "models.metadata_embedding.metadata_embeddings": "metadata_encoder.model.metadata_embeddings",
    "hsigene.metadata_embeddings": "metadata_encoder.model.metadata_embeddings",
}

def ensure_ldm_path(pretrained_model_name_or_path: Union[str, Path]) -> Path:
    """Add the model repo to sys.path so hsigene modules can be imported. Returns the resolved path."""
    path = Path(pretrained_model_name_or_path)
    if not path.exists():
        from huggingface_hub import snapshot_download
        path = Path(snapshot_download(pretrained_model_name_or_path))
    path = path.resolve()
    s = str(path)
    if s not in sys.path:
        sys.path.insert(0, s)
    return path

def _get_class(target: str):
    """Resolve a dotted "module.ClassName" path to the class object."""
    module_path, cls_name = target.rsplit(".", 1)
    mod = importlib.import_module(module_path)
    return getattr(mod, cls_name)
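# For example (sketch): once ensure_ldm_path() has put a checkpoint root on
# sys.path, _get_class("local_adapter.model.LocalAdapter") imports
# <root>/local_adapter/model.py and returns its LocalAdapter class.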

def load_component(model_path: Path, name: str):
    """Load a single component (unet, vae, text_encoder, etc.)."""
    path = Path(model_path)
    root = path.parent if path.name in _COMPONENT_NAMES and (path / "config.json").exists() else path
    ensure_ldm_path(root)
    comp_path = path if (path / "config.json").exists() and path.name in _COMPONENT_NAMES else path / name
    with open(comp_path / "config.json") as f:
        cfg = json.load(f)
    target = cfg.pop("_target", None)
    if not target:
        raise ValueError(f"No _target in {comp_path / 'config.json'}")
    target = _TARGET_MAP.get(target, target)
    cls_ref = _get_class(target)
    params = {k: v for k, v in cfg.items() if not k.startswith("_")}
    comp = cls_ref(**params)
    for wfile in ("diffusion_pytorch_model.safetensors", "diffusion_pytorch_model.bin"):
        wp = comp_path / wfile
        if wp.exists():
            if wfile.endswith(".safetensors"):
                from safetensors.torch import load_file
                state = load_file(str(wp))
            else:
                try:
                    state = torch.load(wp, map_location="cpu", weights_only=True)
                except TypeError:
                    # Older torch releases do not accept the `weights_only` kwarg.
                    state = torch.load(wp, map_location="cpu")
            comp.load_state_dict(state, strict=True)
            break
    comp.eval()
    return comp
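# Example (sketch; "<ckpt>" is a placeholder for a local checkpoint directory):
#
#     vae = load_component(Path("<ckpt>"), "vae")  # reads <ckpt>/vae/config.json + weights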

def load_components(model_path: Union[str, Path]) -> dict:
    """Load all pipeline components."""
    path = Path(ensure_ldm_path(model_path))
    if path.name in _COMPONENT_NAMES and (path / "config.json").exists():
        path = path.parent
    scheduler = DDIMScheduler.from_pretrained(path / "scheduler")
    components = {}
    for name in _COMPONENT_NAMES:
        components[name] = load_component(path, name)
    scale_factor = 0.18215
    if (path / "model_index.json").exists():
        with open(path / "model_index.json") as f:
            scale_factor = json.load(f).get("scale_factor", scale_factor)
    components["scheduler"] = scheduler
    components["scale_factor"] = scale_factor
    return components
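# Example (sketch): load_components("<ckpt>") returns a dict keyed by the seven
# names in _COMPONENT_NAMES plus "scheduler" and "scale_factor".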

class _CRSModelWrapper(torch.nn.Module):
    """Wrapper that mimics the CRSControlNet interface."""

    def __init__(
        self,
        unet,
        vae,
        text_encoder,
        local_adapter,
        global_content_adapter,
        global_text_adapter,
        metadata_emb,
        scale_factor=0.18215,
        local_control_scales=None,
    ):
        super().__init__()
        # Nest the UNet under `model.diffusion_model` to match the attribute
        # layout of the original CRSControlNet implementation.
        self.model = torch.nn.Module()
        self.model.add_module("diffusion_model", unet)
        self.first_stage_model = vae
        self.cond_stage_model = text_encoder
        self.local_adapter = local_adapter
        self.global_content_adapter = global_content_adapter
        self.global_text_adapter = global_text_adapter
        self.metadata_emb = metadata_emb
        self.scale_factor = scale_factor
        self.local_control_scales = local_control_scales or [1.0] * 13

    @torch.no_grad()
    def get_learned_conditioning(self, prompts):
        return self.cond_stage_model(prompts)

    def apply_model(self, x_noisy, t, cond, metadata=None, global_strength=1.0, text_strength=1.0, **kwargs):
        if metadata is None:
            # The cond dict stores metadata as a single-element list; unwrap it
            # to match the explicit-argument path.
            metadata = cond["metadata"][0]
        metadata_emb = self.metadata_emb(metadata)
        content_t = cond["global_control"][0]
        global_control = self.global_content_adapter(content_t)
        cond_txt = torch.cat(cond["c_crossattn"], 1)
        cond_txt = self.global_text_adapter(cond_txt)
        cond_txt = F.normalize(cond_txt, p=2, dim=-1) * text_strength
        global_control = F.normalize(global_control, p=2, dim=-1) * global_strength
        cond_txt = torch.cat([cond_txt, global_control], dim=1)
        local_control = torch.cat(cond["local_control"], 1)
        local_control = self.local_adapter(
            x=x_noisy, timesteps=t, context=cond_txt, local_conditions=local_control
        )
        local_control = [c * s for c, s in zip(local_control, self.local_control_scales)]
        return self.model.diffusion_model(
            x=x_noisy,
            timesteps=t,
            metadata=metadata_emb,
            context=cond_txt,
            local_control=local_control,
            meta=True,
        )

    def decode_first_stage(self, z):
        z = (1.0 / self.scale_factor) * z
        return self.first_stage_model.decode(z)

    def low_vram_shift(self, is_diffusing):
        # Ping-pong modules between GPU and CPU so that only the stage in use
        # (denoising vs. encoding/decoding) occupies VRAM.
        if is_diffusing:
            self.model.diffusion_model = self.model.diffusion_model.cuda()
            self.local_adapter = self.local_adapter.cuda()
            self.global_text_adapter = self.global_text_adapter.cuda()
            self.global_content_adapter = self.global_content_adapter.cuda()
            self.first_stage_model = self.first_stage_model.cpu()
            self.cond_stage_model = self.cond_stage_model.cpu()
        else:
            self.model.diffusion_model = self.model.diffusion_model.cpu()
            self.local_adapter = self.local_adapter.cpu()
            self.global_text_adapter = self.global_text_adapter.cpu()
            self.global_content_adapter = self.global_content_adapter.cpu()
            self.first_stage_model = self.first_stage_model.cuda()
            self.cond_stage_model = self.cond_stage_model.cuda()

@dataclass
class HSIGeneOutput(BaseOutput):
    """Output of the HSIGene pipeline.

    Attributes:
        images: Decoded images as an (N, H, W, C) numpy array scaled to [0, 1].
        latents: Final latents, populated when `output_type="latent"`.
    """

    images: Optional[np.ndarray] = None
    latents: Optional[torch.Tensor] = None

def _is_component_list(v):
    """Check whether a value is the raw config placeholder format [library, class_name]."""
    return isinstance(v, (list, tuple)) and len(v) == 2 and isinstance(v[0], str) and isinstance(v[1], str)
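# e.g. ["diffusers", "DDIMScheduler"] -> True, while an instantiated
# torch.nn.Module (or None) -> False.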

def _resolve_model_root(candidate: Optional[Union[str, Path]]) -> Optional[Path]:
    """Resolve a candidate path/repo id to the model root containing model_index.json."""
    if not candidate:
        return None
    try:
        path = Path(candidate)
        if not path.exists():
            from huggingface_hub import snapshot_download
            path = Path(snapshot_download(str(candidate)))
        path = path.resolve()
        if (path / "model_index.json").exists():
            return path
        # Walk up a few levels in case the candidate points at a subfolder.
        cur = path
        for _ in range(5):
            parent = cur.parent
            if parent == cur:
                break
            if (parent / "model_index.json").exists():
                return parent
            cur = parent
    except Exception:
        return None
    return None
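# e.g. (sketch) "<repo>/unet" resolves to "<repo>" when model_index.json lives
# at the repo root; a hub repo id is first materialized via snapshot_download.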

class HSIGenePipeline(DiffusionPipeline):
    """Pipeline for HSIGene hyperspectral image generation.

    AeroGen-style: load with DiffusionPipeline.from_pretrained(path); no manual
    sys.path.insert is needed.
    """

    def register_modules(self, **kwargs):
        """Override to handle list-format component specs from the diffusers config."""
        for name, module in kwargs.items():
            if module is None or (isinstance(module, (tuple, list)) and len(module) > 0 and module[0] is None):
                self.register_to_config(**{name: (None, None)})
                setattr(self, name, module)
            elif _is_component_list(module):
                # Raw (library, class_name) placeholder straight from the config.
                self.register_to_config(**{name: (module[0], module[1])})
                setattr(self, name, module)
            else:
                from diffusers.pipelines.pipeline_loading_utils import _fetch_class_library_tuple

                library, class_name = _fetch_class_library_tuple(module)
                self.register_to_config(**{name: (library, class_name)})
                setattr(self, name, module)

    def __init__(
        self,
        unet=None,
        vae=None,
        text_encoder=None,
        local_adapter=None,
        global_content_adapter=None,
        global_text_adapter=None,
        metadata_encoder=None,
        scheduler=None,
        crs_model=None,
        scale_factor=0.18215,
        model_path: Optional[Union[str, Path]] = None,
        _name_or_path: Optional[Union[str, Path]] = None,
    ):
        super().__init__()
        if crs_model is not None:
            self.register_modules(crs_model=crs_model, scheduler=scheduler)
        else:
            components_are_lists = any(
                _is_component_list(x)
                for x in (
                    unet,
                    vae,
                    text_encoder,
                    local_adapter,
                    global_content_adapter,
                    global_text_adapter,
                    metadata_encoder,
                )
                if x is not None
            )
            if components_are_lists:
                # Components arrived as raw (library, class_name) placeholders
                # from the diffusers config; resolve the model root and load
                # the real modules ourselves.
                model_root = (
                    _resolve_model_root(model_path)
                    or _resolve_model_root(_name_or_path)
                    or _resolve_model_root(getattr(getattr(self, "config", None), "_name_or_path", None))
                )
                if model_root is None:
                    raise ValueError(
                        "HSIGene received raw config placeholders but could not resolve the model path. "
                        "Pass `model_path` to HSIGenePipeline or load via "
                        "`DiffusionPipeline.from_pretrained(<path>, custom_pipeline=<pipeline_file>)` "
                        "with a valid local model directory."
                    )
                loaded = load_components(model_root)
                unet = loaded["unet"]
                vae = loaded["vae"]
                text_encoder = loaded["text_encoder"]
                local_adapter = loaded["local_adapter"]
                global_content_adapter = loaded["global_content_adapter"]
                global_text_adapter = loaded["global_text_adapter"]
                metadata_encoder = loaded["metadata_encoder"]
                scheduler = loaded["scheduler"] if scheduler is None else scheduler
                scale_factor = loaded["scale_factor"]
            crs_model = _CRSModelWrapper(
                unet=unet,
                vae=vae,
                text_encoder=text_encoder,
                local_adapter=local_adapter,
                global_content_adapter=global_content_adapter,
                global_text_adapter=global_text_adapter,
                metadata_emb=metadata_encoder,
                scale_factor=scale_factor,
            )
            self.register_modules(
                unet=unet,
                vae=vae,
                text_encoder=text_encoder,
                local_adapter=local_adapter,
                global_content_adapter=global_content_adapter,
                global_text_adapter=global_text_adapter,
                metadata_encoder=metadata_encoder,
                scheduler=scheduler,
                crs_model=crs_model,
            )

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, Path],
        device: Optional[Union[str, torch.device]] = None,
        subfolder: Optional[str] = None,
        **kwargs,
    ):
        """Load from a diffusers-format directory. Supports `subfolder` for single-component loading."""
        path = Path(ensure_ldm_path(pretrained_model_name_or_path))
        subfolder = kwargs.pop("subfolder", subfolder)

        # An explicit subfolder requests a single component.
        if subfolder in _COMPONENT_NAMES:
            return load_component(path, subfolder)

        # The path itself may point directly at a component directory.
        if path.name in _COMPONENT_NAMES and (path / "config.json").exists():
            ensure_ldm_path(path.parent)
            return load_component(path.parent, path.name)

        # Otherwise walk up a few levels to find the repo root with model_index.json.
        if not (path / "model_index.json").exists():
            for _ in range(5):
                parent = path.parent
                if (parent / "model_index.json").exists():
                    path = parent
                    break
                if parent == path:
                    break
                path = parent

        components = load_components(path)
        pipe = cls(
            unet=components["unet"],
            vae=components["vae"],
            text_encoder=components["text_encoder"],
            local_adapter=components["local_adapter"],
            global_content_adapter=components["global_content_adapter"],
            global_text_adapter=components["global_text_adapter"],
            metadata_encoder=components["metadata_encoder"],
            scheduler=components["scheduler"],
            scale_factor=components["scale_factor"],
        )
        if device is not None:
            pipe = pipe.to(device)
        return pipe
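    # Example (sketch; paths are placeholders):
    #
    #     pipe = HSIGenePipeline.from_pretrained("path/to/hsigene", device="cuda")
    #     unet = HSIGenePipeline.from_pretrained("path/to/hsigene", subfolder="unet")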

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = "",
        num_samples: int = 1,
        height: int = 256,
        width: int = 256,
        num_inference_steps: int = 50,
        eta: float = 0.0,
        global_strength: float = 1.0,
        text_strength: Optional[float] = None,
        local_conditions: Optional[torch.Tensor] = None,
        global_conditions: Optional[torch.Tensor] = None,
        metadata: Optional[torch.Tensor] = None,
        condition_resolution: int = 512,
        guidance_scale: float = 1.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        generator: Optional[torch.Generator] = None,
        latents: Optional[torch.Tensor] = None,
        output_type: str = "numpy",
        return_dict: bool = True,
        save_memory: bool = False,
    ):
        # Infer the target device: prefer the UNet's device, then any
        # user-supplied latents or generator.
        target_device = next(self.crs_model.parameters()).device
        if hasattr(self, "unet") and isinstance(self.unet, torch.nn.Module):
            target_device = next(self.unet.parameters()).device
        if latents is not None:
            target_device = latents.device
        elif generator is not None and hasattr(generator, "device"):
            target_device = torch.device(generator.device)

        if next(self.crs_model.parameters()).device != target_device:
            self.crs_model = self.crs_model.to(target_device)
        device = target_device
        if text_strength is None:
            text_strength = global_strength

        if isinstance(prompt, str):
            prompts = [prompt] * num_samples
        else:
            prompts = list(prompt)
            num_samples = len(prompts)

        if save_memory:
            self.crs_model.low_vram_shift(is_diffusing=False)

        text_embedding = self.crs_model.get_learned_conditioning(prompts)

        # Missing conditions default to all-zero tensors (unconditional control).
        if local_conditions is None:
            local_conditions = torch.zeros(
                num_samples, 18, condition_resolution, condition_resolution,
                device=device, dtype=torch.float32,
            )
        else:
            local_conditions = local_conditions.to(device=device, dtype=torch.float32)

        if global_conditions is None:
            global_conditions = torch.zeros(
                num_samples, 768, device=device, dtype=torch.float32,
            )
        else:
            global_conditions = global_conditions.to(device=device, dtype=torch.float32)

        if metadata is None:
            metadata = torch.zeros(7, device=device, dtype=torch.float32)
        else:
            metadata = metadata.to(device=device, dtype=torch.float32)

        # The cond dict follows the ldm convention: each entry is a list of
        # tensors that apply_model concatenates (or unwraps).
        cond = {
            "local_control": [local_conditions],
            "c_crossattn": [text_embedding],
            "global_control": [global_conditions],
            "metadata": [metadata],
        }

        do_cfg = guidance_scale > 1.0
        if do_cfg:
            if negative_prompt is None:
                neg_prompts = [""] * num_samples
            elif isinstance(negative_prompt, str):
                neg_prompts = [negative_prompt] * num_samples
            else:
                neg_prompts = list(negative_prompt)
            uc_text = self.crs_model.get_learned_conditioning(neg_prompts)
            uncond = {
                "local_control": [local_conditions],
                "c_crossattn": [uc_text],
                "global_control": [torch.zeros_like(global_conditions)],
                "metadata": [metadata],
            }

        latent_shape = (num_samples, 4, height // 4, width // 4)
        if latents is None:
            if generator is not None and hasattr(generator, "device"):
                gen_device = torch.device(generator.device)
                if gen_device.type != device.type:
                    # torch.randn requires the generator and output device to
                    # match; re-seed a generator on the target device.
                    if hasattr(generator, "initial_seed"):
                        generator = torch.Generator(device=device).manual_seed(generator.initial_seed())
                    else:
                        generator = torch.Generator(device=device)
            latents = torch.randn(
                latent_shape, device=device, generator=generator, dtype=torch.float32,
            )
        else:
            latents = latents.to(device)

        self.scheduler.set_timesteps(num_inference_steps, device=device)

        if save_memory:
            self.crs_model.low_vram_shift(is_diffusing=True)

        for t in self.progress_bar(self.scheduler.timesteps):
            t_batch = t.expand(num_samples)
            if do_cfg:
                # Classifier-free guidance: extrapolate from the unconditional
                # prediction toward the conditional one.
                noise_pred_cond = self.crs_model.apply_model(
                    latents, t_batch, cond,
                    metadata=metadata,
                    global_strength=global_strength,
                    text_strength=text_strength,
                )
                noise_pred_uncond = self.crs_model.apply_model(
                    latents, t_batch, uncond,
                    metadata=metadata,
                    global_strength=global_strength,
                    text_strength=text_strength,
                )
                noise_pred = noise_pred_uncond + guidance_scale * (
                    noise_pred_cond - noise_pred_uncond
                )
            else:
                noise_pred = self.crs_model.apply_model(
                    latents, t_batch, cond,
                    metadata=metadata,
                    global_strength=global_strength,
                    text_strength=text_strength,
                )
            latents = self.scheduler.step(
                noise_pred, t, latents, eta=eta, generator=generator,
            ).prev_sample

        if output_type == "latent":
            if not return_dict:
                return (latents,)
            return HSIGeneOutput(latents=latents)

        if save_memory:
            self.crs_model.low_vram_shift(is_diffusing=False)

        # Decode from latent space and map from [-1, 1] to [0, 1].
        images = self.crs_model.decode_first_stage(latents)
        images = images.permute(0, 2, 3, 1).cpu().numpy()
        images = images * 0.5 + 0.5
        images = np.clip(images, 0, 1)

        if not return_dict:
            return (images,)
        return HSIGeneOutput(images=images)
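

if __name__ == "__main__":
    # Smoke-test sketch: pass a local diffusers-format checkpoint directory as
    # argv[1]; the generation settings below are illustrative, not canonical.
    _ckpt = sys.argv[1]
    _device = "cuda" if torch.cuda.is_available() else "cpu"
    _pipe = HSIGenePipeline.from_pretrained(_ckpt, device=_device)
    _out = _pipe(prompt="a hyperspectral scene", num_samples=1, num_inference_steps=20)
    print("generated images:", _out.images.shape)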