# BLIP3o-4B / builders.py
from __future__ import annotations

import logging
from pathlib import Path
from typing import Any

import torch.nn as nn

from .diffusion_auto import AutoDiffusionModel, DiffusionConfig
from .vision_tower import AutoEvaClipVisionTower, default_device

logger = logging.getLogger(__name__)

_PKG_ROOT = Path(__file__).resolve().parent
_DEFAULT_DIFF_ROOT = (_PKG_ROOT.parent / "BLIP3o-4B-Diffusion-Decoder").resolve()
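
# Expected local fallback layout for the diffusion decoder assets, as probed
# by build_dit below when the config does not pin explicit paths:
#
#   BLIP3o-4B-Diffusion-Decoder/
#       unet/diffusion_pytorch_model.bf16.safetensors   # DiT weights
#       scheduler/scheduler_config.json                 # scheduler config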

def _get_attr(obj: Any, name: str, default: Any = None) -> Any:
    return getattr(obj, name, default)

class IdentityMap(nn.Module):
    """Pass-through projector used when no learned mapping is needed."""

    def forward(self, x, *args, **kwargs):
        return x

    @property
    def config(self):
        return {"mm_projector_type": "identity"}

def build_gen_vision_projector(config, delay_load: bool = False, **kwargs):
    """Build the projector mapping generation-side vision features to the LM hidden size."""
    projector_type = getattr(config, "gen_projector_type", "linear")
    hidden = config.hidden_size
    gen_hidden = getattr(config, "gen_hidden_size", hidden)
    if projector_type == "linear":
        return nn.Linear(gen_hidden, hidden)
    if projector_type.startswith("mlp") and "gelu" in projector_type:
        # e.g. "mlp2x_gelu" -> depth 2; fall back to depth 2 if unparsable.
        depth_str = projector_type.replace("mlp", "").split("x")[0]
        depth = int(depth_str) if depth_str.isdigit() else 2
        modules = [nn.Linear(gen_hidden, hidden)]
        for _ in range(1, depth):
            modules.append(nn.GELU())
            modules.append(nn.Linear(hidden, hidden))
        return nn.Sequential(*modules)
    if projector_type == "identity":
        return IdentityMap()
    raise ValueError(f"Unknown projector type: {projector_type}")
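

# Illustrative sketch (not part of the original module): how the projector
# builder is typically exercised. SimpleNamespace stands in for the real model
# config, and the sizes below are made-up placeholders.
def _demo_gen_projector() -> nn.Module:
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        hidden_size=3072,        # placeholder LM width
        gen_hidden_size=1792,    # placeholder vision-feature width
        gen_projector_type="mlp2x_gelu",
    )
    # "mlp2x_gelu" yields Linear(1792, 3072) -> GELU -> Linear(3072, 3072).
    return build_gen_vision_projector(cfg)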


def build_down_projector(config, delay_load: bool = False, **kwargs):
    """Build the projector mapping multimodal features down into the LM hidden size."""
    projector_type = getattr(config, "mm_projector_type", "identity")
    # Resolve the input width: prefer mm_hidden_size, then hidden_size,
    # then gen_hidden_size.
    mm_hidden = getattr(
        config,
        "mm_hidden_size",
        getattr(config, "hidden_size", getattr(config, "gen_hidden_size", None)),
    )
    if mm_hidden is None:
        raise AttributeError(
            "Config must define one of mm_hidden_size/hidden_size/gen_hidden_size for the down projector."
        )
    if projector_type == "identity":
        return IdentityMap()
    if projector_type == "linear":
        return nn.Linear(mm_hidden, config.hidden_size)
    if projector_type.startswith("mlp") and "gelu" in projector_type:
        depth_str = projector_type.replace("mlp", "").split("x")[0]
        depth = int(depth_str) if depth_str.isdigit() else 2
        modules = [nn.Linear(mm_hidden, config.hidden_size)]
        for _ in range(1, depth):
            modules.append(nn.GELU())
            modules.append(nn.Linear(config.hidden_size, config.hidden_size))
        return nn.Sequential(*modules)
    raise ValueError(f"Unknown mm projector type: {projector_type}")
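

# Illustrative sketch (not part of the original module): the down projector
# resolves its input width through the fallback chain above, so a config that
# only defines hidden_size still works. Sizes are made-up placeholders.
def _demo_down_projector() -> nn.Module:
    from types import SimpleNamespace

    cfg = SimpleNamespace(hidden_size=3072, mm_projector_type="linear")
    # No mm_hidden_size on cfg -> falls back to hidden_size -> Linear(3072, 3072).
    return build_down_projector(cfg)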


def build_gen_vision_tower(config, delay_load: bool = False, **kwargs):
    """Instantiate the EVA-CLIP vision tower purely from HF Hub assets."""
    tower = AutoEvaClipVisionTower(
        config=config,
        torch_dtype=kwargs.get("torch_dtype"),
        device=default_device(kwargs.get("device")),
        delay_load=delay_load,
    )
    # Eagerly load weights unless the caller asked for lazy initialization.
    if hasattr(tower, "load_model") and not delay_load:
        tower.load_model(torch_dtype=kwargs.get("torch_dtype"), device=kwargs.get("device"))
    return tower
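

# Illustrative sketch (not part of the original module): deferring weight
# download with delay_load, then loading explicitly. Assumes the EVA-CLIP Hub
# assets are reachable; the device/dtype choices are placeholders.
def _demo_vision_tower(config):
    import torch

    tower = build_gen_vision_tower(config, delay_load=True)
    tower.load_model(torch_dtype=torch.bfloat16, device="cuda")
    return tower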


def build_dit(config, **kwargs):
    """Instantiate the diffusion transformer + scheduler bundle."""
    default_diff = DiffusionConfig()
    latent_dim = _get_attr(config, "hidden_size", default_diff.latent_embedding_size)
    weights_path = _get_attr(config, "diffusion_weights_path", None)
    scheduler_path = _get_attr(config, "diffusion_scheduler_path", None)
    # Fall back to a sibling checkout of the diffusion decoder when the config
    # does not pin explicit paths (see the layout note near the top of the file).
    if weights_path is None and _DEFAULT_DIFF_ROOT.exists():
        candidate = _DEFAULT_DIFF_ROOT / "unet" / "diffusion_pytorch_model.bf16.safetensors"
        if candidate.exists():
            weights_path = str(candidate)
    if scheduler_path is None and _DEFAULT_DIFF_ROOT.exists():
        sched_dir = _DEFAULT_DIFF_ROOT / "scheduler"
        if (sched_dir / "scheduler_config.json").exists():
            scheduler_path = str(sched_dir)
    diff_cfg = DiffusionConfig(
        weights_path=weights_path,
        scheduler_path=scheduler_path,
        latent_embedding_size=latent_dim,
        dim=_get_attr(config, "dit_hidden_size", default_diff.dim),
        n_layers=_get_attr(config, "dit_layers", default_diff.n_layers),
        n_heads=_get_attr(config, "dit_heads", default_diff.n_heads),
        n_kv_heads=_get_attr(config, "dit_kv_heads", default_diff.n_kv_heads),
        input_size=_get_attr(config, "dit_input_size", default_diff.input_size),
        patch_size=_get_attr(config, "dit_patch_size", default_diff.patch_size),
        in_channels=_get_attr(config, "dit_in_channels", default_diff.in_channels),
    )
    bundle = AutoDiffusionModel.from_config(
        diff_cfg,
        device=kwargs.get("device"),
        torch_dtype=kwargs.get("torch_dtype"),
    )
    logger.debug(
        "Loaded diffusion transformer: missing=%s unexpected=%s",
        bundle.load_info.get("missing_keys"),
        bundle.load_info.get("unexpected_keys"),
    )
    return bundle.dit, bundle.scheduler
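

# Illustrative sketch (not part of the original module): pinning explicit
# diffusion assets through the model config instead of relying on the local
# fallback checkout. The path strings are hypothetical placeholders.
def _demo_build_dit(config):
    import torch

    config.diffusion_weights_path = "/path/to/unet/diffusion_pytorch_model.bf16.safetensors"
    config.diffusion_scheduler_path = "/path/to/scheduler"
    return build_dit(config, device="cuda", torch_dtype=torch.bfloat16)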


__all__ = [
    "build_gen_vision_tower",
    "build_gen_vision_projector",
    "build_down_projector",
    "build_dit",
]
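

if __name__ == "__main__":
    # Minimal offline smoke test (illustrative only): the projector builders
    # need no checkpoint downloads, unlike the tower/DiT builders above.
    print(_demo_gen_projector())
    print(_demo_down_projector())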