# Source: FlexiBrain/flexibrain/models/factory.py
# Synced from GitHub FlexiBrain main (revision 6a51385).
from __future__ import annotations
import logging
from typing import Optional
import torch
import torch.nn as nn
from flexibrain.config import ModelConfig, apply_checkpoint_config
from flexibrain.models.mamba_jepa import VolumeMambaJEPA
from flexibrain.models.classifier import MambaJEPAClassifier, MambaJEPAClassifierAvgPool
def build_mamba_backbone(cfg: ModelConfig, device: torch.device, dtype=torch.float32) -> VolumeMambaJEPA:
    """Instantiate a ``VolumeMambaJEPA`` backbone from ``cfg``.

    Args:
        cfg: Model configuration; the attributes listed in ``forwarded``
            below are read from it and passed through unchanged.
        device: Forwarded to the ``VolumeMambaJEPA`` constructor.
        dtype: Forwarded to the ``VolumeMambaJEPA`` constructor.

    Returns:
        The newly constructed backbone.
    """
    # Config attributes whose names match the constructor keyword they feed.
    forwarded = (
        "embed_dim",
        "depth",
        "predictor_depth",
        "drop_path_rate",
        "rms_norm",
        "fused_add_norm",
        "residual_in_fp32",
        "bimamba_type",
        "if_bimamba",
        "mixer_type",
        "if_devide_out",
        "momentum",
        "norm_target",
    )
    ctor_kwargs = {field: getattr(cfg, field) for field in forwarded}

    # Arguments this factory pins instead of exposing through the config.
    ctor_kwargs.update(
        ssm_cfg=None,
        encoder_attn_layer_idx=None,
        attn_cfg=None,
        norm_epsilon=1e-5,
        initializer_cfg=None,
        device=device,
        dtype=dtype,
    )
    return VolumeMambaJEPA(**ctor_kwargs)
def build_pretrain_model(cfg: ModelConfig, device: torch.device) -> nn.Module:
    """Build the pretraining model and move it to ``device``.

    Args:
        cfg: Model configuration; ``cfg.model_type`` must equal ``"mamba"``.
        device: Device used to construct the backbone and passed to ``.to()``.

    Returns:
        The float32 Mamba backbone produced by ``build_mamba_backbone``.

    Raises:
        ValueError: If ``cfg.model_type`` is anything other than ``"mamba"``.
    """
    if cfg.model_type == "mamba":
        backbone = build_mamba_backbone(cfg, device=device, dtype=torch.float32)
        return backbone.to(device)
    raise ValueError("This cleaned Flexibrain build currently keeps only the Mamba pretrain/downstream path")
def load_checkpoint(path: str, device: torch.device):
    """Read a checkpoint file with ``torch.load``.

    Args:
        path: Location of the checkpoint file.
        device: Passed as ``map_location`` to ``torch.load``.

    Returns:
        Whatever object ``torch.load`` deserialises from ``path``.
    """
    # NOTE(review): torch.load unpickles the file; only point this at
    # checkpoints from a trusted source.
    loaded = torch.load(path, map_location=device)
    return loaded
def state_dict_from_checkpoint(checkpoint: dict):
    """Pull the model weights out of a loaded checkpoint.

    The ``"model_state_dict"`` entry is preferred; ``"model"`` is used only
    when the former is absent.

    Args:
        checkpoint: Loaded checkpoint mapping.

    Returns:
        The value stored under the first of the two keys that is present.

    Raises:
        KeyError: If the checkpoint contains neither key.
    """
    # Order matters: the first key found wins.
    for candidate in ("model_state_dict", "model"):
        if candidate in checkpoint:
            return checkpoint[candidate]
    raise KeyError("Checkpoint has neither model_state_dict nor model")
def build_downstream_model(
    cfg: ModelConfig,
    device: torch.device,
    logger: Optional[logging.Logger] = None,
    checkpoint_path: Optional[str] = None,
    from_scratch: bool = False,
    use_checkpoint_config: bool = True,
) -> nn.Module:
    """Build a downstream classifier on top of a Mamba backbone.

    Args:
        cfg: Model configuration. When a checkpoint is loaded and
            ``use_checkpoint_config`` is true, it is handed to
            ``apply_checkpoint_config`` together with the checkpoint's
            ``"config"`` entry (the return value is ignored, so that helper
            presumably updates ``cfg`` in place -- confirm in its definition).
        device: Device used to load the checkpoint, construct the modules, and
            place the final model.
        logger: Optional logger for progress messages.
        checkpoint_path: Optional path to a pretrained checkpoint.
        from_scratch: When true, ``checkpoint_path`` is ignored.
        use_checkpoint_config: Whether to apply the checkpoint's stored config
            to ``cfg`` before building the backbone.

    Returns:
        The classifier, moved to ``device``.

    Raises:
        ValueError: If ``cfg.model_type`` is not ``"mamba"`` or
            ``cfg.head_type`` is neither ``"transformer"`` nor ``"avgpool"``.
        RuntimeError: If the strict weight load fails for any reason other
            than missing keys that all match the backward-scan markers.
    """
    ckpt = None
    if checkpoint_path and not from_scratch:
        ckpt = load_checkpoint(checkpoint_path, device)
        if use_checkpoint_config:
            apply_checkpoint_config(cfg, ckpt.get("config", {}))
            if logger:
                logger.info("Backbone config restored from checkpoint: %s", ckpt.get("config", {}))

    # Checked only after the checkpoint config has had a chance to be applied.
    if cfg.model_type != "mamba":
        raise ValueError("This cleaned Flexibrain build currently keeps only the Mamba downstream path")

    backbone = build_mamba_backbone(cfg, device=device, dtype=torch.float32)

    if ckpt is None:
        if logger:
            logger.info("Backbone initialized from scratch")
    else:
        weights = state_dict_from_checkpoint(ckpt)
        try:
            backbone.load_state_dict(weights, strict=True)
            if logger:
                logger.info("Loaded pretrained backbone strictly from %s", checkpoint_path)
        except RuntimeError as exc:
            # Retry non-strictly to find out exactly which keys disagree.
            report = backbone.load_state_dict(weights, strict=False)
            backward_markers = ("_b", "conv1d_b", "x_proj_b", "dt_proj_b", "A_b_log", "D_b")
            absent = list(report.missing_keys)
            # The mismatch is tolerated only when something is missing, nothing
            # is unexpected, and every missing key contains a backward marker.
            tolerable = (
                absent
                and not report.unexpected_keys
                and all(any(tag in name for tag in backward_markers) for name in absent)
            )
            if not tolerable:
                raise exc
            if logger:
                logger.warning("Strict load missed %d backward-scan BiMamba keys; loaded checkpoint with strict=False compatibility", len(absent))

    # Reject unknown heads before reading any head-related config fields.
    if cfg.head_type not in ("transformer", "avgpool"):
        raise ValueError(f"Unknown head_type: {cfg.head_type}")

    # Keyword arguments accepted by both classifier variants.
    shared_kwargs = dict(
        backbone=backbone,
        num_classes=cfg.num_classes,
        mlp_hidden=cfg.mlp_hidden,
        mlp_depth=cfg.mlp_depth,
        mlp_dropout=cfg.mlp_dropout,
        freeze_backbone=cfg.freeze_backbone,
        device=device,
    )
    if cfg.head_type == "avgpool":
        classifier = MambaJEPAClassifierAvgPool(**shared_kwargs)
    else:
        classifier = MambaJEPAClassifier(
            head_depth=cfg.head_depth,
            head_num_heads=cfg.head_num_heads,
            head_mlp_ratio=cfg.head_mlp_ratio,
            head_proj_drop=cfg.head_proj_drop,
            head_drop_path=cfg.head_drop_path,
            **shared_kwargs,
        )
    return classifier.to(device)