"""Factory helpers for the BLIP3o generation stack: the EVA-CLIP vision tower,
feature projectors, and the diffusion transformer (DiT) + scheduler bundle."""

from __future__ import annotations

import logging
from pathlib import Path
from typing import Any

import torch.nn as nn

from .diffusion_auto import AutoDiffusionModel, DiffusionConfig
from .vision_tower import AutoEvaClipVisionTower, default_device

logger = logging.getLogger(__name__)

_PKG_ROOT = Path(__file__).resolve().parent
_DEFAULT_DIFF_ROOT = (_PKG_ROOT.parent / "BLIP3o-4B-Diffusion-Decoder").resolve()


def _get_attr(obj: Any, name: str, default: Any = None) -> Any:
    return getattr(obj, name, default)


class IdentityMap(nn.Module):
    """No-op projector that passes features through unchanged."""

    def forward(self, x, *args, **kwargs):
        return x

    @property
    def config(self):
        return {"mm_projector_type": "identity"}


def build_gen_vision_projector(config, delay_load: bool = False, **kwargs):
    """Build the projector mapping gen_hidden_size features to the LM hidden size."""
    projector_type = getattr(config, "gen_projector_type", "linear")
    hidden = getattr(config, "hidden_size")
    gen_hidden = getattr(config, "gen_hidden_size", hidden)
    if projector_type == "linear":
        return nn.Linear(gen_hidden, hidden)
    if projector_type.startswith("mlp") and "gelu" in projector_type:
        # e.g. "mlp2x_gelu" -> depth 2; fall back to depth 2 if unparseable.
        depth_str = projector_type.replace("mlp", "").split("x")[0]
        depth = int(depth_str) if depth_str.isdigit() else 2
        modules = [nn.Linear(gen_hidden, hidden)]
        for _ in range(1, depth):
            modules.append(nn.GELU())
            modules.append(nn.Linear(hidden, hidden))
        return nn.Sequential(*modules)
    if projector_type == "identity":
        return IdentityMap()
    raise ValueError(f"Unknown projector type: {projector_type}")


def build_down_projector(config, delay_load: bool = False, **kwargs):
    """Build the projector mapping mm_hidden_size features to the LM hidden size."""
    projector_type = getattr(config, "mm_projector_type", "identity")
    mm_hidden = getattr(
        config,
        "mm_hidden_size",
        getattr(config, "hidden_size", getattr(config, "gen_hidden_size", None)),
    )
    if mm_hidden is None:
        raise AttributeError(
            "Config must define one of mm_hidden_size/gen_hidden_size/hidden_size for down_projector."
        )
    if projector_type == "identity":
        return IdentityMap()
    if projector_type == "linear":
        return nn.Linear(mm_hidden, config.hidden_size)
    if projector_type.startswith("mlp") and "gelu" in projector_type:
        depth_str = projector_type.replace("mlp", "").split("x")[0]
        depth = int(depth_str) if depth_str.isdigit() else 2
        modules = [nn.Linear(mm_hidden, config.hidden_size)]
        for _ in range(1, depth):
            modules.append(nn.GELU())
            modules.append(nn.Linear(config.hidden_size, config.hidden_size))
        return nn.Sequential(*modules)
    raise ValueError(f"Unknown mm projector type: {projector_type}")


def build_gen_vision_tower(config, delay_load: bool = False, **kwargs):
    """Instantiate the EVA-CLIP tower purely from HF Hub assets."""
    tower = AutoEvaClipVisionTower(
        config=config,
        torch_dtype=kwargs.get("torch_dtype"),
        device=default_device(kwargs.get("device")),
        delay_load=delay_load,
    )
    if hasattr(tower, "load_model") and not delay_load:
        tower.load_model(torch_dtype=kwargs.get("torch_dtype"), device=kwargs.get("device"))
    return tower


def build_dit(config, **kwargs):
    """Instantiate the diffusion transformer + scheduler bundle."""
    default_diff = DiffusionConfig()
    latent_dim = _get_attr(config, "hidden_size", default_diff.latent_embedding_size)
    weights_path = _get_attr(config, "diffusion_weights_path", None)
    scheduler_path = _get_attr(config, "diffusion_scheduler_path", None)
    # Fall back to the local BLIP3o-4B-Diffusion-Decoder checkout when no
    # explicit weight/scheduler paths are configured.
    if weights_path is None and _DEFAULT_DIFF_ROOT.exists():
        candidate = _DEFAULT_DIFF_ROOT / "unet" / "diffusion_pytorch_model.bf16.safetensors"
        if candidate.exists():
            weights_path = str(candidate)
    if scheduler_path is None and _DEFAULT_DIFF_ROOT.exists():
        sched_dir = _DEFAULT_DIFF_ROOT / "scheduler"
        if (sched_dir / "scheduler_config.json").exists():
            scheduler_path = str(sched_dir)
    diff_cfg = DiffusionConfig(
        weights_path=weights_path,
        scheduler_path=scheduler_path,
        latent_embedding_size=latent_dim,
        dim=_get_attr(config, "dit_hidden_size", default_diff.dim),
        n_layers=_get_attr(config, "dit_layers", default_diff.n_layers),
        n_heads=_get_attr(config, "dit_heads", default_diff.n_heads),
        n_kv_heads=_get_attr(config, "dit_kv_heads", default_diff.n_kv_heads),
        input_size=_get_attr(config, "dit_input_size", default_diff.input_size),
        patch_size=_get_attr(config, "dit_patch_size", default_diff.patch_size),
        in_channels=_get_attr(config, "dit_in_channels", default_diff.in_channels),
    )
    bundle = AutoDiffusionModel.from_config(
        diff_cfg,
        device=kwargs.get("device"),
        torch_dtype=kwargs.get("torch_dtype"),
    )
    logger.debug(
        "Loaded diffusion transformer: missing=%s unexpected=%s",
        bundle.load_info.get("missing_keys"),
        bundle.load_info.get("unexpected_keys"),
    )
    return bundle.dit, bundle.scheduler


__all__ = ["build_gen_vision_tower", "build_gen_vision_projector", "build_down_projector", "build_dit"]
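
# Usage sketch (kept in a comment so the module stays import-clean). The
# projector builders only read plain attributes off the config object, so a
# SimpleNamespace is enough to smoke-test them. The sizes below are made-up
# example values, not project defaults.
#
#     from types import SimpleNamespace
#
#     cfg = SimpleNamespace(
#         hidden_size=3584,           # LM width the projectors map into
#         gen_hidden_size=1792,       # generation-feature width
#         mm_hidden_size=1792,        # multimodal-feature width
#         gen_projector_type="mlp2x_gelu",
#         mm_projector_type="linear",
#     )
#     up = build_gen_vision_projector(cfg)   # Linear(1792->3584), GELU, Linear(3584->3584)
#     down = build_down_projector(cfg)       # Linear(1792->3584)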