"""Factories for the Mamba-JEPA backbone and the downstream classifiers built on it."""
from __future__ import annotations

import logging
from typing import Optional

import torch
import torch.nn as nn

from flexibrain.config import ModelConfig, apply_checkpoint_config
from flexibrain.models.mamba_jepa import VolumeMambaJEPA
from flexibrain.models.classifier import MambaJEPAClassifier, MambaJEPAClassifierAvgPool

# Path components that the checkpoint loader treats as backward-scan BiMamba
# parameters.  A missing key is tolerated only if one of its dot-separated
# components is listed here or ends with the backward suffix.
_BACKWARD_SCAN_NAMES = frozenset({"conv1d_b", "x_proj_b", "dt_proj_b", "A_b_log", "D_b"})
_BACKWARD_SCAN_SUFFIX = "_b"


def build_mamba_backbone(
    cfg: ModelConfig,
    device: torch.device,
    dtype: torch.dtype = torch.float32,
) -> VolumeMambaJEPA:
    """Instantiate a ``VolumeMambaJEPA`` from the fields of ``cfg``.

    Args:
        cfg: Model configuration; the ``cfg`` attributes read are exactly the
            ones forwarded in the constructor call below.
        device: Device forwarded to the backbone constructor.
        dtype: Parameter dtype forwarded to the backbone constructor.

    Returns:
        The newly constructed backbone.
    """
    return VolumeMambaJEPA(
        embed_dim=cfg.embed_dim,
        depth=cfg.depth,
        predictor_depth=cfg.predictor_depth,
        # The next options are not exposed through ModelConfig and are pinned here.
        ssm_cfg=None,
        encoder_attn_layer_idx=None,
        attn_cfg=None,
        drop_path_rate=cfg.drop_path_rate,
        norm_epsilon=1e-5,
        rms_norm=cfg.rms_norm,
        initializer_cfg=None,
        fused_add_norm=cfg.fused_add_norm,
        residual_in_fp32=cfg.residual_in_fp32,
        device=device,
        dtype=dtype,
        bimamba_type=cfg.bimamba_type,
        if_bimamba=cfg.if_bimamba,
        mixer_type=cfg.mixer_type,
        if_devide_out=cfg.if_devide_out,
        momentum=cfg.momentum,
        norm_target=cfg.norm_target,
    )


def build_pretrain_model(cfg: ModelConfig, device: torch.device) -> nn.Module:
    """Build the float32 pretraining model and move it to ``device``.

    Raises:
        ValueError: If ``cfg.model_type`` is anything other than ``"mamba"``.
    """
    if cfg.model_type != "mamba":
        raise ValueError("This cleaned Flexibrain build currently keeps only the Mamba pretrain/downstream path")
    return build_mamba_backbone(cfg, device=device, dtype=torch.float32).to(device)


def load_checkpoint(path: str, device: torch.device):
    """Load a checkpoint file with ``torch.load``, mapping its tensors onto ``device``."""
    # NOTE(review): torch.load unpickles the file, so only trusted checkpoints
    # should be passed here.  Whether the safer ``weights_only`` mode is active
    # depends on the installed torch version's default -- confirm before
    # loading checkpoints from external sources.
    return torch.load(path, map_location=device)


def state_dict_from_checkpoint(checkpoint: dict):
    """Return the model weights stored in ``checkpoint``.

    ``"model_state_dict"`` takes precedence over ``"model"`` when both exist.

    Raises:
        KeyError: If neither key is present.
    """
    for key in ("model_state_dict", "model"):
        if key in checkpoint:
            return checkpoint[key]
    raise KeyError("Checkpoint has neither model_state_dict nor model")


def _is_backward_scan_key(key: str) -> bool:
    """Tell whether a state-dict key names a backward-scan BiMamba parameter.

    The key is matched component by component instead of by raw substring, so
    a path such as ``"encoder_blocks.0.norm.weight"`` is not mistaken for a
    backward-scan parameter merely because it contains ``"_b"``.
    """
    return any(
        part in _BACKWARD_SCAN_NAMES or part.endswith(_BACKWARD_SCAN_SUFFIX)
        for part in key.split(".")
    )


def _load_backbone_state(
    backbone: nn.Module,
    state: dict,
    checkpoint_path: Optional[str],
    logger: Optional[logging.Logger],
) -> None:
    """Load ``state`` into ``backbone``, tolerating only missing backward-scan keys.

    A strict load is attempted first.  If it fails, the load is retried with
    ``strict=False`` and accepted only when there are no unexpected keys and
    every missing key is a backward-scan parameter; otherwise the original
    strict-load ``RuntimeError`` is re-raised.
    """
    try:
        backbone.load_state_dict(state, strict=True)
    except RuntimeError:
        incompatible = backbone.load_state_dict(state, strict=False)
        missing = list(incompatible.missing_keys)
        only_backward = bool(missing) and all(_is_backward_scan_key(key) for key in missing)
        if not only_backward or incompatible.unexpected_keys:
            # Anything beyond missing backward-scan weights is a real
            # mismatch: surface the strict-load error unchanged.
            raise
        if logger:
            logger.warning(
                "Strict load missed %d backward-scan BiMamba keys; loaded checkpoint with strict=False compatibility",
                len(missing),
            )
    else:
        if logger:
            logger.info("Loaded pretrained backbone strictly from %s", checkpoint_path)


def build_downstream_model(
    cfg: ModelConfig,
    device: torch.device,
    logger: Optional[logging.Logger] = None,
    checkpoint_path: Optional[str] = None,
    from_scratch: bool = False,
    use_checkpoint_config: bool = True,
) -> nn.Module:
    """Build a classifier on top of a (possibly pretrained) Mamba backbone.

    Args:
        cfg: Model configuration.  When a checkpoint is loaded and
            ``use_checkpoint_config`` is true it is passed to
            ``apply_checkpoint_config`` before the backbone is built
            (presumably updated in place -- the return value is not used here).
        device: Device used for checkpoint loading and for the final model.
        logger: Optional logger for progress messages.
        checkpoint_path: Pretrained checkpoint to load; ignored when empty or
            when ``from_scratch`` is true.
        from_scratch: Skip checkpoint loading even if a path is given.
        use_checkpoint_config: Apply the checkpoint's ``"config"`` entry to
            ``cfg`` before building the backbone.

    Returns:
        The classifier selected by ``cfg.head_type``, moved to ``device``.

    Raises:
        ValueError: If ``cfg.model_type`` is not ``"mamba"`` or
            ``cfg.head_type`` is neither ``"transformer"`` nor ``"avgpool"``.
        KeyError: If the checkpoint holds no recognised state-dict entry.
        RuntimeError: If the checkpoint weights do not fit the backbone beyond
            missing backward-scan BiMamba parameters.
    """
    checkpoint = None
    if checkpoint_path and not from_scratch:
        checkpoint = load_checkpoint(checkpoint_path, device)
        if use_checkpoint_config:
            checkpoint_cfg = checkpoint.get("config", {})
            apply_checkpoint_config(cfg, checkpoint_cfg)
            if logger:
                logger.info("Backbone config restored from checkpoint: %s", checkpoint_cfg)

    # Validated after the checkpoint config has been applied, since that step
    # may change cfg.
    if cfg.model_type != "mamba":
        raise ValueError("This cleaned Flexibrain build currently keeps only the Mamba downstream path")

    backbone = build_mamba_backbone(cfg, device=device, dtype=torch.float32)

    if checkpoint is not None:
        _load_backbone_state(backbone, state_dict_from_checkpoint(checkpoint), checkpoint_path, logger)
    elif logger:
        logger.info("Backbone initialized from scratch")

    if cfg.head_type == "transformer":
        model = MambaJEPAClassifier(
            backbone=backbone,
            num_classes=cfg.num_classes,
            head_depth=cfg.head_depth,
            head_num_heads=cfg.head_num_heads,
            head_mlp_ratio=cfg.head_mlp_ratio,
            head_proj_drop=cfg.head_proj_drop,
            head_drop_path=cfg.head_drop_path,
            mlp_hidden=cfg.mlp_hidden,
            mlp_depth=cfg.mlp_depth,
            mlp_dropout=cfg.mlp_dropout,
            freeze_backbone=cfg.freeze_backbone,
            device=device,
        )
    elif cfg.head_type == "avgpool":
        model = MambaJEPAClassifierAvgPool(
            backbone=backbone,
            num_classes=cfg.num_classes,
            mlp_hidden=cfg.mlp_hidden,
            mlp_depth=cfg.mlp_depth,
            mlp_dropout=cfg.mlp_dropout,
            freeze_backbone=cfg.freeze_backbone,
            device=device,
        )
    else:
        raise ValueError(f"Unknown head_type: {cfg.head_type}")
    return model.to(device)