| import torch
|
| import torch.nn as nn
|
| from flexibrain.models.transformer_block import Block
|
|
|
| from flexibrain.models.mamba_jepa import VolumeMambaJEPA
|
|
|
|
|
class MambaJEPAClassifier(nn.Module):
    """Classify volumes from backbone tokens through a learnable CLS token.

    The backbone token sequence is prefixed with a learnable CLS token,
    passed through ``head_depth`` transformer ``Block`` layers and a final
    LayerNorm; the CLS position is then fed to an MLP that produces logits.

    Args:
        backbone: Feature extractor. It must expose ``embed_dim``,
            ``patch_embed``, ``_run_blocks``, ``blocks`` and ``norm_f``.
        num_classes: Size of the logits vector.
        head_depth: Number of ``Block`` layers applied on top of the backbone
            tokens; ``0`` (or less) disables them.
        head_num_heads: Forwarded to every ``Block`` as ``num_heads``.
        head_mlp_ratio: Forwarded to every ``Block`` as ``mlp_ratio``.
        head_qkv_bias: Forwarded to every ``Block`` as ``qkv_bias``.
        head_attn_drop: Forwarded to every ``Block`` as ``attn_drop``.
        head_proj_drop: Forwarded to every ``Block`` as ``drop``.
        head_drop_path: Upper end of the per-block ``drop_path`` schedule,
            which is spaced linearly from 0 across the head blocks.
        head_norm_epsilon: ``eps`` of every LayerNorm created for the head.
        mlp_hidden: Width of the hidden layers of the classifier MLP.
        mlp_depth: Number of Linear layers in the classifier MLP; values of
            1 or less produce a single Linear layer.
        mlp_dropout: Dropout probability after each hidden activation.
        freeze_backbone: When True every backbone parameter gets
            ``requires_grad=False``; when False every backbone parameter is
            explicitly set to ``requires_grad=True``.
        device: Device for the head; defaults to the device of the first
            backbone parameter.
        dtype: Dtype for the head; defaults to the dtype of the first
            backbone parameter.
    """

    def __init__(
        self,
        backbone: 'VolumeMambaJEPA',
        num_classes: int,
        head_depth: int = 2,
        head_num_heads: int = 8,
        head_mlp_ratio: float = 4.0,
        head_qkv_bias: bool = True,
        head_attn_drop: float = 0.0,
        head_proj_drop: float = 0.0,
        head_drop_path: float = 0.0,
        head_norm_epsilon: float = 1e-5,
        mlp_hidden: int = 1024,
        mlp_depth: int = 2,
        mlp_dropout: float = 0.1,
        freeze_backbone: bool = False,
        device=None,
        dtype=None,
    ):
        super().__init__()
        self.backbone = backbone
        self.embed_dim = backbone.embed_dim

        # The flag is applied in both directions: a non-frozen backbone has
        # all of its parameters (re-)enabled for training.
        trainable = not freeze_backbone
        for p in self.backbone.parameters():
            p.requires_grad = trainable

        # Follow the backbone's placement unless told otherwise. A backbone
        # without parameters leaves the torch defaults in effect instead of
        # failing with a bare StopIteration.
        reference = next(self.backbone.parameters(), None)
        if reference is not None:
            if dtype is None:
                dtype = reference.dtype
            if device is None:
                device = reference.device
        factory = dict(device=device, dtype=dtype)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim, **factory))
        nn.init.normal_(self.cls_token, std=0.02)

        # Linearly increasing drop-path rate, one entry per head block.
        if head_depth > 0:
            dpr = torch.linspace(0, head_drop_path, head_depth).tolist()
        else:
            dpr = []

        def _norm_layer_with_dtype(dim):
            # Norm factory handed to Block so that its LayerNorms are created
            # directly on the head's device/dtype.
            return nn.LayerNorm(dim, eps=head_norm_epsilon, **factory)

        self.head_blocks = nn.ModuleList([
            Block(
                dim=self.embed_dim,
                num_heads=head_num_heads,
                mlp_ratio=head_mlp_ratio,
                qkv_bias=head_qkv_bias,
                attn_drop=head_attn_drop,
                drop=head_proj_drop,
                drop_path=rate,
                norm_layer=_norm_layer_with_dtype,
            )
            for rate in dpr
        ])
        # Block is not given device/dtype above, so place it explicitly; every
        # other part of the head (CLS token, norms, classifier) already
        # follows `factory`.
        placement = {key: value for key, value in factory.items() if value is not None}
        if placement:
            self.head_blocks.to(**placement)

        self.head_norm = nn.LayerNorm(self.embed_dim, eps=head_norm_epsilon, **factory)
        self.classifier = self._build_mlp(
            self.embed_dim, num_classes, mlp_hidden, mlp_depth, mlp_dropout, factory
        )

    @staticmethod
    def _build_mlp(in_dim, num_classes, hidden, depth, dropout, factory):
        """Build ``depth - 1`` Linear/GELU/Dropout stages plus the output Linear."""
        layers = []
        for _ in range(max(depth - 1, 0)):
            layers += [nn.Linear(in_dim, hidden, **factory), nn.GELU(), nn.Dropout(dropout)]
            in_dim = hidden
        layers.append(nn.Linear(in_dim, num_classes, **factory))
        return nn.Sequential(*layers)

    def _encode_backbone(self, x, meta=None, orig_Ts=None, affines=None, inference_params=None):
        """Run patch embedding and the backbone blocks.

        Returns:
            ``(feat, attn_pad)``: the output of ``backbone._run_blocks`` and
            the padding mask produced by ``backbone.patch_embed``.
        """
        xf, attn_pad, _, _ = self.backbone.patch_embed(x, meta, orig_Ts, affines)
        feat = self.backbone._run_blocks(
            xf, attn_pad,
            blocks=self.backbone.blocks,
            norm_layer=self.backbone.norm_f,
            inference_params=inference_params,
        )
        return feat, attn_pad

    @torch.no_grad()
    def _encode_backbone_nograd(self, x, meta=None, orig_Ts=None, affines=None, inference_params=None):
        """Same as ``_encode_backbone`` but executed with autograd disabled."""
        return self._encode_backbone(
            x, meta=meta, orig_Ts=orig_Ts, affines=affines,
            inference_params=inference_params,
        )

    def forward_from_tokens(self, tokens, attn_pad, inference_params=None):
        """Compute logits from already-encoded backbone tokens.

        Args:
            tokens: Backbone tokens of shape ``[B, L, D]``.
            attn_pad: ``[B, L]`` padding mask in which True marks padding.
                Non-boolean masks are interpreted by truthiness.
            inference_params: Accepted for signature parity with ``forward``;
                the head itself does not use it.

        Returns:
            Logits of shape ``[B, num_classes]``.
        """
        B, _, _ = tokens.shape  # also enforces a rank-3 input

        cls_tok = self.cls_token.to(dtype=tokens.dtype).expand(B, -1, -1)
        x_cat = torch.cat([cls_tok, tokens], dim=1)

        # The CLS position is never padding. Coerce the mask to bool first:
        # `~` on an integer mask would be a bitwise NOT, not a logical one.
        cls_pad = torch.zeros(B, 1, dtype=torch.bool, device=tokens.device)
        attn_cat = torch.cat([cls_pad, attn_pad.to(torch.bool)], dim=1)

        # The head blocks receive the inverted mask (True = non-padding).
        attn_cat_for_flash = ~attn_cat

        h = x_cat
        for blk in self.head_blocks:
            h = blk(h, attention_mask=attn_cat_for_flash)
        h = self.head_norm(h)

        # Only the CLS position feeds the classifier.
        return self.classifier(h[:, 0, :])

    def forward(self, x, meta=None, orig_Ts=None, affines=None, inference_params=None):
        """Encode ``x`` with the backbone and return ``[B, num_classes]`` logits."""
        feat, attn_pad = self._encode_backbone(
            x, meta=meta, orig_Ts=orig_Ts, affines=affines,
            inference_params=inference_params,
        )
        return self.forward_from_tokens(feat, attn_pad, inference_params=inference_params)
|
|
|
class MambaJEPAClassifierAvgPool(nn.Module):
    """Classify volumes from the masked mean of the backbone tokens.

    Backbone tokens are averaged over their non-padding positions and the
    pooled vector is fed to an MLP that produces the logits.

    Args:
        backbone: Feature extractor. It must expose ``embed_dim``,
            ``patch_embed``, ``_run_blocks``, ``blocks`` and ``norm_f``.
        num_classes: Size of the logits vector.
        mlp_hidden: Width of the hidden layers of the classifier MLP.
        mlp_depth: Number of Linear layers in the classifier MLP; values of
            1 or less produce a single Linear layer.
        mlp_dropout: Dropout probability after each hidden activation.
        freeze_backbone: When True every backbone parameter gets
            ``requires_grad=False``; when False every backbone parameter is
            explicitly set to ``requires_grad=True``.
        device: Device for the classifier; defaults to the device of the
            first backbone parameter.
        dtype: Dtype for the classifier; defaults to the dtype of the first
            backbone parameter.
    """

    def __init__(
        self,
        backbone: 'VolumeMambaJEPA',
        num_classes: int,
        mlp_hidden: int = 1024,
        mlp_depth: int = 3,
        mlp_dropout: float = 0.1,
        freeze_backbone: bool = False,
        device=None,
        dtype=None,
    ):
        super().__init__()
        self.backbone = backbone
        self.embed_dim = backbone.embed_dim

        # The flag is applied in both directions: a non-frozen backbone has
        # all of its parameters (re-)enabled for training.
        trainable = not freeze_backbone
        for p in self.backbone.parameters():
            p.requires_grad = trainable

        # Follow the backbone's placement unless told otherwise. A backbone
        # without parameters leaves the torch defaults in effect instead of
        # failing with a bare StopIteration.
        reference = next(self.backbone.parameters(), None)
        if reference is not None:
            if dtype is None:
                dtype = reference.dtype
            if device is None:
                device = reference.device
        factory = dict(device=device, dtype=dtype)

        # (mlp_depth - 1) hidden stages followed by the output projection.
        layers = []
        in_dim = self.embed_dim
        for _ in range(max(mlp_depth - 1, 0)):
            layers += [nn.Linear(in_dim, mlp_hidden, **factory), nn.GELU(), nn.Dropout(mlp_dropout)]
            in_dim = mlp_hidden
        layers.append(nn.Linear(in_dim, num_classes, **factory))
        self.classifier = nn.Sequential(*layers)

    def _encode_backbone(self, x, meta=None, orig_Ts=None, affines=None, inference_params=None):
        """Run patch embedding and the backbone blocks.

        Returns:
            ``(feat, attn_pad)``: the output of ``backbone._run_blocks`` and
            the padding mask produced by ``backbone.patch_embed``.
        """
        xf, attn_pad, _, _ = self.backbone.patch_embed(x, meta, orig_Ts, affines)
        feat = self.backbone._run_blocks(
            xf, attn_pad,
            blocks=self.backbone.blocks,
            norm_layer=self.backbone.norm_f,
            inference_params=inference_params,
        )
        return feat, attn_pad

    def forward_from_tokens(self, tokens, attn_pad, inference_params=None):
        """Compute logits from already-encoded backbone tokens.

        Args:
            tokens: ``[B, L, D]`` backbone tokens.
            attn_pad: ``[B, L]`` padding mask (True=padding, False=valid).
                Non-boolean masks are interpreted by truthiness.
            inference_params: Accepted for signature parity with ``forward``;
                pooling does not use it.

        Returns:
            logits: ``[B, num_classes]``

        Raises:
            ValueError: If ``tokens`` is not a rank-3 tensor.
        """
        if tokens.dim() != 3:
            raise ValueError(
                f"tokens must have shape [B, L, D], got {tuple(tokens.shape)}"
            )

        # Coerce to bool first: `~` on an integer mask would be a bitwise NOT
        # and produce negative weights.
        valid_mask = ~attn_pad.to(torch.bool)

        # Rows that are entirely padding divide by 1 and pool to zeros.
        valid_counts = valid_mask.sum(dim=1, keepdim=True).clamp(min=1)

        # Accumulate in at least float32 so that long half-precision
        # sequences do not lose precision in the sum.
        acc_dtype = torch.promote_types(tokens.dtype, torch.float32)
        weights = valid_mask.unsqueeze(-1).to(acc_dtype)
        feat_sum = (tokens.to(acc_dtype) * weights).sum(dim=1)

        # Cast back to the token dtype: the classifier weights are created in
        # the backbone dtype, so feeding it promoted float32 features fails
        # for fp16/bf16 models outside autocast.
        feat_avg = (feat_sum / valid_counts.to(acc_dtype)).to(tokens.dtype)

        return self.classifier(feat_avg)

    def forward(self, x, meta=None, orig_Ts=None, affines=None, inference_params=None):
        """Encode ``x`` with the backbone and return ``[B, num_classes]`` logits."""
        feat, attn_pad = self._encode_backbone(
            x, meta=meta, orig_Ts=orig_Ts, affines=affines,
            inference_params=inference_params,
        )
        return self.forward_from_tokens(feat, attn_pad, inference_params=inference_params)
|
|
|