| | from __future__ import annotations |
| |
|
| | import logging |
| | from pathlib import Path |
| | from typing import Any |
| |
|
| | import torch.nn as nn |
| |
|
| | from .diffusion_auto import AutoDiffusionModel, DiffusionConfig |
| | from .vision_tower import AutoEvaClipVisionTower, default_device |
| |
|
logger = logging.getLogger(__name__)

# Package root and the conventional on-disk location of the bundled
# diffusion-decoder checkout (expected as a sibling directory of this
# package). build_dit() probes this path for default weights/scheduler
# assets when the config does not specify explicit paths.
_PKG_ROOT = Path(__file__).resolve().parent
_DEFAULT_DIFF_ROOT = (_PKG_ROOT.parent / "BLIP3o-4B-Diffusion-Decoder").resolve()
| |
|
| |
|
| | def _get_attr(obj: Any, name: str, default: Any = None) -> Any: |
| | return getattr(obj, name, default) |
| |
|
| |
|
class IdentityMap(nn.Module):
    """Pass-through projector: returns its input unchanged.

    Extra positional/keyword arguments are accepted and ignored so this
    module is call-compatible with the other projector variants.
    """

    def forward(self, x, *args, **kwargs):
        # Identity: echo the input, ignoring any extra call arguments.
        return x

    @property
    def config(self):
        # Minimal config dict mirroring the shape of real projector configs.
        return {"mm_projector_type": "identity"}
| |
|
| |
|
def build_gen_vision_projector(config, delay_load: bool = False, **kwargs):
    """Build the projector mapping generation-tower features into the LM space.

    Supported ``config.gen_projector_type`` values (default ``"linear"``):
      * ``"linear"``: single Linear gen_hidden_size -> hidden_size.
      * ``"mlp{N}x_gelu"``: Linear followed by (N-1) GELU+Linear pairs.
      * ``"identity"``: pass-through.

    ``delay_load`` and ``kwargs`` are accepted for interface parity with the
    other builders and are unused here.

    Raises:
        AttributeError: if ``config.hidden_size`` is missing.
        ValueError: for an unrecognized projector type.
    """
    proj_type = getattr(config, "gen_projector_type", "linear")
    out_dim = getattr(config, "hidden_size")  # required; raises if absent
    in_dim = getattr(config, "gen_hidden_size", out_dim)

    if proj_type == "linear":
        return nn.Linear(in_dim, out_dim)

    if proj_type == "identity":
        return IdentityMap()

    if proj_type.startswith("mlp") and "gelu" in proj_type:
        # e.g. "mlp2x_gelu" -> depth 2; a malformed depth falls back to 2.
        head = proj_type.replace("mlp", "").split("x", 1)[0]
        depth = int(head) if head.isdigit() else 2
        layers = [nn.Linear(in_dim, out_dim)]
        layers += [
            layer
            for _ in range(depth - 1)
            for layer in (nn.GELU(), nn.Linear(out_dim, out_dim))
        ]
        return nn.Sequential(*layers)

    raise ValueError(f"Unknown projector type: {proj_type}")
| |
|
| |
|
def build_down_projector(config, delay_load: bool = False, **kwargs):
    """Build the projector mapping multimodal features down into the LM space.

    Supported ``config.mm_projector_type`` values (default ``"identity"``):
      * ``"identity"``: pass-through; needs no size attributes.
      * ``"linear"``: Linear mm_hidden -> hidden_size.
      * ``"mlp{N}x_gelu"``: Linear followed by (N-1) GELU+Linear pairs.

    The input width is resolved from ``mm_hidden_size``, falling back to
    ``hidden_size`` and then ``gen_hidden_size``.

    ``delay_load`` and ``kwargs`` are accepted for interface parity with the
    other builders and are unused here.

    Raises:
        AttributeError: if no usable input width is defined for a
            non-identity projector type.
        ValueError: for an unrecognized projector type.
    """
    projector_type = getattr(config, "mm_projector_type", "identity")

    # Fix: the identity projector (the default) uses no size information, so
    # handle it before demanding mm_hidden_size/hidden_size/gen_hidden_size.
    # Previously a size-less config raised AttributeError even for identity.
    if projector_type == "identity":
        return IdentityMap()

    mm_hidden = getattr(
        config,
        "mm_hidden_size",
        getattr(config, "hidden_size", getattr(config, "gen_hidden_size", None)),
    )
    if mm_hidden is None:
        raise AttributeError(
            "Config must define one of mm_hidden_size/gen_hidden_size/hidden_size for down_projector."
        )
    if projector_type == "linear":
        return nn.Linear(mm_hidden, config.hidden_size)
    if projector_type.startswith("mlp") and "gelu" in projector_type:
        # e.g. "mlp2x_gelu" -> depth 2; a malformed depth falls back to 2.
        depth_str = projector_type.replace("mlp", "").split("x")[0]
        depth = int(depth_str) if depth_str.isdigit() else 2
        modules = [nn.Linear(mm_hidden, config.hidden_size)]
        for _ in range(1, depth):
            modules.append(nn.GELU())
            modules.append(nn.Linear(config.hidden_size, config.hidden_size))
        return nn.Sequential(*modules)
    raise ValueError(f"Unknown mm projector type: {projector_type}")
| |
|
| |
|
def build_gen_vision_tower(config, delay_load: bool = False, **kwargs):
    """Instantiate the EVA-CLIP tower purely from HF Hub assets.

    Args:
        config: model config forwarded to ``AutoEvaClipVisionTower``.
        delay_load: when True, skip the explicit ``load_model`` call below
            (and forward the flag to the tower constructor).
        **kwargs: may supply ``torch_dtype`` and ``device``.

    Returns:
        The constructed vision tower instance.
    """

    tower = AutoEvaClipVisionTower(
        config=config,
        torch_dtype=kwargs.get("torch_dtype"),
        device=default_device(kwargs.get("device")),
        delay_load=delay_load,
    )
    # Eagerly materialize weights unless the caller asked for lazy loading.
    # NOTE(review): if the constructor already loads when delay_load=False,
    # this is a redundant second load — confirm AutoEvaClipVisionTower's
    # semantics. Also, the raw kwargs device is passed here while the
    # constructor received default_device(...); verify load_model tolerates
    # a None device.
    if hasattr(tower, "load_model") and not delay_load:
        tower.load_model(torch_dtype=kwargs.get("torch_dtype"), device=kwargs.get("device"))
    return tower
| |
|
| |
|
def build_dit(config, **kwargs):
    """Instantiate the diffusion transformer + scheduler bundle.

    Checkpoint locations come from the config when set; otherwise the
    conventional sibling checkout under ``_DEFAULT_DIFF_ROOT`` is probed.

    Args:
        config: LM config whose ``dit_*`` / diffusion attributes override
            the ``DiffusionConfig`` defaults.
        **kwargs: may supply ``device`` and ``torch_dtype``.

    Returns:
        ``(dit, scheduler)`` from the loaded bundle.
    """

    defaults = DiffusionConfig()

    # Resolve checkpoint locations, preferring explicit config values over
    # the conventional on-disk layout.
    weights_path = _get_attr(config, "diffusion_weights_path", None)
    scheduler_path = _get_attr(config, "diffusion_scheduler_path", None)
    if _DEFAULT_DIFF_ROOT.exists():
        if weights_path is None:
            candidate = _DEFAULT_DIFF_ROOT / "unet" / "diffusion_pytorch_model.bf16.safetensors"
            if candidate.exists():
                weights_path = str(candidate)
        if scheduler_path is None:
            sched_dir = _DEFAULT_DIFF_ROOT / "scheduler"
            if (sched_dir / "scheduler_config.json").exists():
                scheduler_path = str(sched_dir)

    # Map LM-config attributes onto the diffusion config, keeping the
    # library defaults for anything unspecified.
    def _cfg(name, fallback):
        return _get_attr(config, name, fallback)

    diff_cfg = DiffusionConfig(
        weights_path=weights_path,
        scheduler_path=scheduler_path,
        latent_embedding_size=_cfg("hidden_size", defaults.latent_embedding_size),
        dim=_cfg("dit_hidden_size", defaults.dim),
        n_layers=_cfg("dit_layers", defaults.n_layers),
        n_heads=_cfg("dit_heads", defaults.n_heads),
        n_kv_heads=_cfg("dit_kv_heads", defaults.n_kv_heads),
        input_size=_cfg("dit_input_size", defaults.input_size),
        patch_size=_cfg("dit_patch_size", defaults.patch_size),
        in_channels=_cfg("dit_in_channels", defaults.in_channels),
    )

    bundle = AutoDiffusionModel.from_config(
        diff_cfg,
        device=kwargs.get("device"),
        torch_dtype=kwargs.get("torch_dtype"),
    )
    logger.debug(
        "Loaded diffusion transformer: missing=%s unexpected=%s",
        bundle.load_info.get("missing_keys"),
        bundle.load_info.get("unexpected_keys"),
    )
    return bundle.dit, bundle.scheduler
| |
|
| |
|
# Public factory surface of this builders module.
__all__ = ["build_gen_vision_tower", "build_gen_vision_projector", "build_down_projector", "build_dit"]
| |
|
| |
|
| |
|