FlexiBrain / flexibrain /models /classifier.py
OneMore1's picture
Sync from GitHub FlexiBrain main
6a51385 verified
Raw
History Blame Contribute Delete
8.82 kB
import torch
import torch.nn as nn
from flexibrain.models.transformer_block import Block
from flexibrain.models.mamba_jepa import VolumeMambaJEPA
class MambaJEPAClassifier(nn.Module):
def __init__(
self,
backbone: 'VolumeMambaJEPA',
num_classes: int,
head_depth: int = 2,
head_num_heads: int = 8,
head_mlp_ratio: float = 4.0,
head_qkv_bias: bool = True,
head_attn_drop: float = 0.0,
head_proj_drop: float = 0.0,
head_drop_path: float = 0.0,
head_norm_epsilon: float = 1e-5,
mlp_hidden: int = 1024,
mlp_depth: int = 2,
mlp_dropout: float = 0.1,
freeze_backbone: bool = False,
device=None,
dtype=None,
):
super().__init__()
self.backbone = backbone
self.embed_dim = backbone.embed_dim
if freeze_backbone:
for p in self.backbone.parameters():
p.requires_grad = False
else:
for p in self.backbone.parameters():
p.requires_grad = True
if dtype is None:
dtype = next(self.backbone.parameters()).dtype
if device is None:
device = next(self.backbone.parameters()).device
factory = dict(device=device, dtype=dtype)
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim, **factory))
nn.init.normal_(self.cls_token, std=0.02)
dpr = [x.item() for x in torch.linspace(0, head_drop_path, head_depth)] if head_depth > 0 else []
def _norm_layer_with_dtype(dim):
ln = nn.LayerNorm(dim, eps=head_norm_epsilon)
if device is not None:
ln = ln.to(device=device)
if dtype is not None:
ln = ln.to(dtype=dtype)
return ln
self.head_blocks = nn.ModuleList([
Block(
dim=self.embed_dim,
num_heads=head_num_heads,
mlp_ratio=head_mlp_ratio,
qkv_bias=head_qkv_bias,
attn_drop=head_attn_drop,
drop=head_proj_drop,
drop_path=dpr[i] if head_depth > 0 else 0.0,
norm_layer=_norm_layer_with_dtype,
) for i in range(head_depth)
])
self.head_norm = nn.LayerNorm(self.embed_dim, eps=head_norm_epsilon, **factory)
layers = []
in_dim = self.embed_dim
for _ in range(max(mlp_depth - 1, 0)):
layers += [nn.Linear(in_dim, mlp_hidden, **factory), nn.GELU(), nn.Dropout(mlp_dropout)]
in_dim = mlp_hidden
layers += [nn.Linear(in_dim, num_classes, **factory)]
self.classifier = nn.Sequential(*layers)
@torch.no_grad()
def _encode_backbone_nograd(self, x, meta=None, orig_Ts=None, affines=None, inference_params=None):
xf, attn_pad, _, _ = self.backbone.patch_embed(x, meta, orig_Ts, affines)
feat = self.backbone._run_blocks(
xf, attn_pad,
blocks=self.backbone.blocks,
norm_layer=self.backbone.norm_f,
inference_params=inference_params
)
return feat, attn_pad
def _encode_backbone(self, x, meta=None, orig_Ts=None, affines=None, inference_params=None):
xf, attn_pad, _, _ = self.backbone.patch_embed(x, meta, orig_Ts, affines)
feat = self.backbone._run_blocks(
xf, attn_pad,
blocks=self.backbone.blocks,
norm_layer=self.backbone.norm_f,
inference_params=inference_params
)
return feat, attn_pad
def forward_from_tokens(self, tokens, attn_pad, inference_params=None):
B, L, D = tokens.shape
device = tokens.device
cls_tok = self.cls_token.to(dtype=tokens.dtype).expand(B, -1, -1) # [B,1,D]
x_cat = torch.cat([cls_tok, tokens], dim=1) # [B,1+L,D]
cls_pad = torch.zeros(B, 1, dtype=torch.bool, device=device) # False=valid
attn_cat = torch.cat([cls_pad, attn_pad], dim=1) # [B,1+L]
attn_cat_for_flash = ~attn_cat
h = x_cat
for blk in self.head_blocks:
h = blk(h, attention_mask=attn_cat_for_flash)
h = self.head_norm(h)
cls_feat = h[:, 0, :] # [B,D]
logits = self.classifier(cls_feat) # [B,C]
return logits
def forward(self, x, meta=None, orig_Ts=None, affines=None, inference_params=None):
feat, attn_pad = self._encode_backbone(x, meta=meta, orig_Ts=orig_Ts, affines=affines,
inference_params=inference_params)
B, L, D = feat.shape
device = feat.device
cls_tok = self.cls_token.to(dtype=feat.dtype).expand(B, -1, -1) # [B,1,D]
x_cat = torch.cat([cls_tok, feat], dim=1) # [B,1+L,D]
cls_pad = torch.zeros(B, 1, dtype=torch.bool, device=device) # False=valid
attn_cat = torch.cat([cls_pad, attn_pad], dim=1) # [B,1+L]
attn_cat_for_flash = ~attn_cat
h = x_cat
for blk in self.head_blocks:
h = blk(h, attention_mask=attn_cat_for_flash)
h = self.head_norm(h)
cls_feat = h[:, 0, :] # [B,D]
logits = self.classifier(cls_feat) # [B,C]
return logits
class MambaJEPAClassifierAvgPool(nn.Module):
def __init__(
self,
backbone: 'VolumeMambaJEPA',
num_classes: int,
mlp_hidden: int = 1024,
mlp_depth: int = 3,
mlp_dropout: float = 0.1,
freeze_backbone: bool = False,
device=None,
dtype=None,
):
super().__init__()
self.backbone = backbone
self.embed_dim = backbone.embed_dim
if freeze_backbone:
for p in self.backbone.parameters():
p.requires_grad = False
else:
for p in self.backbone.parameters():
p.requires_grad = True
if dtype is None:
dtype = next(self.backbone.parameters()).dtype
if device is None:
device = next(self.backbone.parameters()).device
factory = dict(device=device, dtype=dtype)
layers = []
in_dim = self.embed_dim
for _ in range(max(mlp_depth - 1, 0)):
layers += [nn.Linear(in_dim, mlp_hidden, **factory), nn.GELU(), nn.Dropout(mlp_dropout)]
in_dim = mlp_hidden
layers += [nn.Linear(in_dim, num_classes, **factory)]
self.classifier = nn.Sequential(*layers)
def _encode_backbone(self, x, meta=None, orig_Ts=None, affines=None, inference_params=None):
xf, attn_pad, _, _ = self.backbone.patch_embed(x, meta, orig_Ts, affines)
feat = self.backbone._run_blocks(
xf, attn_pad,
blocks=self.backbone.blocks,
norm_layer=self.backbone.norm_f,
inference_params=inference_params
)
return feat, attn_pad
def forward_from_tokens(self, tokens, attn_pad, inference_params=None):
"""
Args:
tokens: [B, L, D] backbone token
attn_pad: [B, L] attention mask (True=padding, False=valid)
inference_params
Returns:
logits: [B, num_classes]
"""
B, L, D = tokens.shape
valid_mask = ~attn_pad # [B, L] -> True=valid
valid_counts = valid_mask.sum(dim=1, keepdim=True).clamp(min=1) # [B, 1]
feat_masked = tokens * valid_mask.unsqueeze(-1).float() # [B, L, D]
feat_sum = feat_masked.sum(dim=1) # [B, D]
feat_avg = feat_sum / valid_counts.float() # [B, D]
logits = self.classifier(feat_avg) # [B, num_classes]
return logits
def forward(self, x, meta=None, orig_Ts=None, affines=None, inference_params=None):
feat, attn_pad = self._encode_backbone(x, meta=meta, orig_Ts=orig_Ts, affines=affines,
inference_params=inference_params)
B, L, D = feat.shape
valid_mask = ~attn_pad # [B, L] -> True=valid
valid_counts = valid_mask.sum(dim=1, keepdim=True).clamp(min=1) # [B, 1]
feat_masked = feat * valid_mask.unsqueeze(-1).float() # [B, L, D]
feat_sum = feat_masked.sum(dim=1) # [B, D]
feat_avg = feat_sum / valid_counts.float() # [B, D]
logits = self.classifier(feat_avg) # [B, num_classes]
return logits