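"""Vision Transformer backbone over pre-extracted WSI feature tokens.

Combines a feature embedder, a CLS token, optional storage (register) tokens,
RoPE positional encoding, and self-attention blocks with optional
contour-based attention masking.
"""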
from functools import partial
from typing import Any, Dict, Optional
import torch
import torch.nn.init
from torch import Tensor, nn
from .layers import SelfAttentionBlock, SwiGLUFFN, WSIFeatureEmbed, RopePositionEmbedding
from .model_utils import named_apply

# Registries mapping config strings to layer constructors. The ``align_to``
# variants are assumed to round the SwiGLU hidden width up to that multiple
# (see ``SwiGLUFFN`` in ``.layers``).
ffn_layer_dict = {
"swiglu": SwiGLUFFN,
"swiglu32": partial(SwiGLUFFN, align_to=32),
"swiglu64": partial(SwiGLUFFN, align_to=64),
"swiglu128": partial(SwiGLUFFN, align_to=128),
}

norm_layer_dict = {
    "layernorm": partial(nn.LayerNorm, eps=1e-5),
    # NOTE: "layernormbf16" currently aliases the standard LayerNorm; no
    # bf16-specific variant is constructed here.
    "layernormbf16": partial(nn.LayerNorm, eps=1e-5),
    "rmsnorm": partial(nn.RMSNorm, eps=1e-5),
}

dtype_dict = {
"fp32": torch.float32,
"fp16": torch.float16,
"bf16": torch.bfloat16,
}


def init_weights_vit(module: nn.Module, name: str = ""):
    """Truncated-normal init for Linear layers; reset norm-layer parameters."""
if isinstance(module, nn.Linear):
torch.nn.init.trunc_normal_(module.weight, std=0.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
    if isinstance(module, (nn.LayerNorm, nn.RMSNorm)):
        module.reset_parameters()


class VisionTransformer(nn.Module):
    """ViT-style transformer over pre-extracted WSI feature tokens.

    Alternates NoPE attention blocks (key-padding mask only) with RoPE
    blocks that can additionally restrict patch-to-patch attention to
    tokens sharing the same contour index.
    """
def __init__(
self,
*,
# Always WSI feature input mode
input_dim: int = 768,
patch_size: int = 256,
embed_use_norm: bool = True,
pos_embed_rope_base: float = 100.0,
pos_embed_rope_min_period: float | None = None,
pos_embed_rope_max_period: float | None = None,
pos_embed_rope_dtype: str = "fp32",
embed_dim: int = 768,
depth: int = 12,
num_heads: int = 12,
ffn_ratio: float = 4.0,
qkv_bias: bool = True,
norm_layer: str = "layernorm",
ffn_layer: str = "swiglu128",
ffn_bias: bool = True,
proj_bias: bool = True,
ffn_drop: float = 0.0,
attn_drop: float = 0.0,
n_storage_tokens: int = 0,
nope_interval: int = 2,
device: Any | None = None,
**ignored_kwargs,
):
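        """Construct the backbone.

        ``input_dim`` is the width of the incoming WSI feature tokens;
        ``embed_dim``, ``depth``, and ``num_heads`` size the transformer;
        ``n_storage_tokens`` adds learnable register-style tokens; and
        ``nope_interval`` controls how often a NoPE (no positional encoding)
        block is used (every ``nope_interval``-th block, starting at block 0).
        Unknown keyword arguments are ignored.
        """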
super().__init__()
del ignored_kwargs
norm_layer_cls = norm_layer_dict[norm_layer]
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.n_blocks = depth
self.num_heads = num_heads
self.patch_size = patch_size
self.nope_interval = max(1, int(nope_interval)) # e.g., nope_interval=2 applies NOPE attention at block indices 0,2,4,...
if input_dim is None:
raise ValueError("VisionTransformer requires input_dim for WSI feature inputs.")
self.patch_embed = WSIFeatureEmbed(
input_dim=int(input_dim), embed_dim=embed_dim, use_norm=embed_use_norm
)
self.cls_token = nn.Parameter(torch.empty(1, 1, embed_dim, device=device))
self.n_storage_tokens = n_storage_tokens
if self.n_storage_tokens > 0:
self.storage_tokens = nn.Parameter(torch.empty(1, n_storage_tokens, embed_dim, device=device))
self.rope_embed = RopePositionEmbedding(
embed_dim=embed_dim,
num_heads=num_heads,
patch_size=patch_size,
base=pos_embed_rope_base,
min_period=pos_embed_rope_min_period,
max_period=pos_embed_rope_max_period,
dtype=dtype_dict[pos_embed_rope_dtype],
device=device,
)
ffn_layer_cls = ffn_layer_dict[ffn_layer]
ffn_ratio_sequence = [ffn_ratio] * depth
blocks_list = [
SelfAttentionBlock(
dim=embed_dim,
num_heads=num_heads,
ffn_ratio=ffn_ratio_sequence[i],
qkv_bias=qkv_bias,
proj_bias=proj_bias,
ffn_bias=ffn_bias,
ffn_drop=ffn_drop,
attn_drop=attn_drop,
norm_layer=norm_layer_cls,
ffn_layer=ffn_layer_cls,
device=device,
)
for i in range(depth)
]
self.blocks = nn.ModuleList(blocks_list)
        # Final norm applied to all token types: CLS, storage, and patch tokens.
        self.norm = norm_layer_cls(embed_dim)

    def init_weights(self):
self.rope_embed._init_weights()
nn.init.normal_(self.cls_token, std=0.02)
if self.n_storage_tokens > 0:
nn.init.normal_(self.storage_tokens, std=0.02)
        named_apply(init_weights_vit, self)

    def forward(
self,
x: Tensor,
masks: Tensor,
coords: Optional[Tensor] = None,
contour_index: Optional[Tensor] = None,
) -> Dict[str, Tensor]:
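        """Encode padded WSI feature tokens.

        Args:
            x: [B, N, C] feature tokens.
            masks: [B, N] boolean key-padding mask, True = valid token.
            coords: optional [B, N, 2] (or [N, 2]) patch coordinates used to
                build RoPE; when None, RoPE blocks run without RoPE.
            contour_index: optional [B, N] integer ids; in RoPE blocks, patch
                tokens may only attend to patches with the same id.

        Returns:
            Dict of normalized CLS / storage / patch tokens, the pre-norm
            sequence, and the input masks.
        """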
# x: [B, N, C] (WSI feature tokens)
B = x.size(0)
# Concatenate CLS and storage tokens inside embedder
storage_tokens = self.storage_tokens if self.n_storage_tokens > 0 else None
x = self.patch_embed(x, cls_token=self.cls_token, storage_tokens=storage_tokens)
# Attention key padding mask: True = valid
attn_key_mask = masks
if attn_key_mask.dim() > 2:
raise ValueError("masks must be of shape [B, N] with True for valid tokens")
if attn_key_mask.dtype != torch.bool:
attn_key_mask = attn_key_mask.to(dtype=torch.bool)
# prepend valid tokens for CLS/storage
prefix = torch.ones(B, 1 + self.n_storage_tokens, dtype=torch.bool, device=x.device)
attn_key_mask = torch.cat([prefix, attn_key_mask], dim=1)
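        # attn_key_mask now covers 1 CLS + n_storage_tokens + N patch tokens.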
# Build pairwise contour-based attention mask when contour_index is provided.
# NOTE (mask semantics): This project follows PyTorch SDPA docs where a boolean attn_mask
# True indicates that the element should take part in attention (allowed).
# Rules:
# - CLS and storage tokens do NOT interact with patch tokens in contour-constrained ("rope attention") blocks.
# - Patch tokens can attend only to patch tokens with the same contour_index.
# - Key padding mask (invalid/padded tokens) are always excluded from attention.
T = x.size(1)
pre_len = 1 + self.n_storage_tokens # CLS + storage tokens length
attn_mask_pairwise: Optional[Tensor] = None
if contour_index is not None:
# Ensure contour_index shape is [B, N]
if contour_index.dim() != 2:
raise ValueError(f"contour_index must be of shape [B, N], received shape={tuple(contour_index.shape)}")
contour_index = contour_index.to(device=x.device, dtype=torch.long)
# Combine prefix (-1) with provided contour indices for patch tokens
prefix_ci = torch.full((B, pre_len), fill_value=-1, device=x.device, dtype=torch.long)
full_ci = torch.cat([prefix_ci, contour_index], dim=1) # [B, T]
# For patch-to-patch, allow only same contour_index (equal values)
attn_mask_pairwise = full_ci.unsqueeze(2).eq(full_ci.unsqueeze(1)) # [B, T, T] # True means allowed
# Key padding mask as broadcastable shape for SDPA: [B, 1, 1, S]
# True = allowed (participate), False = excluded
nope_mask_broadcast = attn_key_mask.view(B, 1, 1, -1)
# Default to key-padding-only mask; if pairwise is available, intersect (AND) to keep only allowed entries
attn_mask_broadcast = nope_mask_broadcast
if attn_mask_pairwise is not None:
pairwise_broadcast = attn_mask_pairwise.view(B, 1, T, T)
            # AND the pairwise-allowed mask with the key-padding mask so only
            # mutually allowed entries stay True; the [B, 1, 1, T] padding mask
            # broadcasts across the query dimension.
            attn_mask_broadcast = pairwise_broadcast & nope_mask_broadcast
# RoPE from coords (patch tokens only). coords shape: [B, N, 2] or [N, 2]
rope_sincos = self.rope_embed(coords=coords) if (self.rope_embed is not None and coords is not None) else None
for idx, blk in enumerate(self.blocks):
if (idx % self.nope_interval) == 0:
# NOPE attention block: no RoPE, only key padding constraints
x = blk(x, rope=None, attn_mask=nope_mask_broadcast)
else:
# ROPE attention block: RoPE + contour-aware pairwise mask (when provided) + key padding
x = blk(x, rope=rope_sincos, attn_mask=attn_mask_broadcast)
# Output packing
x_norm = self.norm(x)
x_norm_cls_reg = x_norm[:, : self.n_storage_tokens + 1]
x_norm_patch = x_norm[:, self.n_storage_tokens + 1 :]
return {
"x_norm_clstoken": x_norm_cls_reg[:, 0],
"x_storage_tokens": x_norm_cls_reg[:, 1:],
"x_norm_patchtokens": x_norm_patch,
"x_prenorm": x,
"masks": masks,
}
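

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The shapes, the all-valid mask,
# and the coordinate values below are assumptions for demonstration; the
# coordinate convention ultimately depends on ``RopePositionEmbedding`` in
# ``.layers``. Because of the relative imports, run this as a module
# (``python -m package.module``) rather than as a script.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = VisionTransformer(
        input_dim=768, embed_dim=384, depth=4, num_heads=6, n_storage_tokens=4
    )
    model.init_weights()
    B, N = 2, 128
    feats = torch.randn(B, N, 768)                   # pre-extracted patch features
    masks = torch.ones(B, N, dtype=torch.bool)       # True = valid token
    coords = torch.rand(B, N, 2)                     # patch coordinates for RoPE
    contours = torch.zeros(B, N, dtype=torch.long)   # single contour per slide
    out = model(feats, masks, coords=coords, contour_index=contours)
    print(out["x_norm_clstoken"].shape)              # -> [2, 384]
    print(out["x_norm_patchtokens"].shape)           # -> [2, 128, 384]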