st24hour's picture
Upload folder using huggingface_hub
e101805 verified
from functools import partial
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
import torch
import torch.nn.init
from torch import Tensor, nn
from .layers import SelfAttentionBlock, SwiGLUFFN, WSIFeatureEmbed, RopePositionEmbedding
from .model_utils import named_apply
ffn_layer_dict = {
"swiglu": SwiGLUFFN,
"swiglu32": partial(SwiGLUFFN, align_to=32),
"swiglu64": partial(SwiGLUFFN, align_to=64),
"swiglu128": partial(SwiGLUFFN, align_to=128),
}
norm_layer_dict = {
"layernorm": partial(nn.LayerNorm, eps=1e-5),
"layernormbf16": partial(nn.LayerNorm, eps=1e-5),
"rmsnorm": partial(nn.RMSNorm, eps=1e-5),
}
dtype_dict = {
"fp32": torch.float32,
"fp16": torch.float16,
"bf16": torch.bfloat16,
}
def init_weights_vit(module: nn.Module, name: str = ""):
if isinstance(module, nn.Linear):
torch.nn.init.trunc_normal_(module.weight, std=0.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
if isinstance(module, nn.LayerNorm):
module.reset_parameters()
if isinstance(module, nn.RMSNorm):
module.reset_parameters()
class VisionTransformer(nn.Module):
def __init__(
self,
*,
# Always WSI feature input mode
input_dim: int = 768,
patch_size: int = 256,
embed_use_norm: bool = True,
pos_embed_rope_base: float = 100.0,
pos_embed_rope_min_period: float | None = None,
pos_embed_rope_max_period: float | None = None,
pos_embed_rope_dtype: str = "fp32",
embed_dim: int = 768,
depth: int = 12,
num_heads: int = 12,
ffn_ratio: float = 4.0,
qkv_bias: bool = True,
norm_layer: str = "layernorm",
ffn_layer: str = "swiglu128",
ffn_bias: bool = True,
proj_bias: bool = True,
ffn_drop: float = 0.0,
attn_drop: float = 0.0,
n_storage_tokens: int = 0,
nope_interval: int = 2,
device: Any | None = None,
**ignored_kwargs,
):
super().__init__()
del ignored_kwargs
norm_layer_cls = norm_layer_dict[norm_layer]
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.n_blocks = depth
self.num_heads = num_heads
self.patch_size = patch_size
self.nope_interval = max(1, int(nope_interval)) # e.g., nope_interval=2 applies NOPE attention at block indices 0,2,4,...
if input_dim is None:
raise ValueError("VisionTransformer requires input_dim for WSI feature inputs.")
self.patch_embed = WSIFeatureEmbed(
input_dim=int(input_dim), embed_dim=embed_dim, use_norm=embed_use_norm
)
self.cls_token = nn.Parameter(torch.empty(1, 1, embed_dim, device=device))
self.n_storage_tokens = n_storage_tokens
if self.n_storage_tokens > 0:
self.storage_tokens = nn.Parameter(torch.empty(1, n_storage_tokens, embed_dim, device=device))
self.rope_embed = RopePositionEmbedding(
embed_dim=embed_dim,
num_heads=num_heads,
patch_size=patch_size,
base=pos_embed_rope_base,
min_period=pos_embed_rope_min_period,
max_period=pos_embed_rope_max_period,
dtype=dtype_dict[pos_embed_rope_dtype],
device=device,
)
ffn_layer_cls = ffn_layer_dict[ffn_layer]
ffn_ratio_sequence = [ffn_ratio] * depth
blocks_list = [
SelfAttentionBlock(
dim=embed_dim,
num_heads=num_heads,
ffn_ratio=ffn_ratio_sequence[i],
qkv_bias=qkv_bias,
proj_bias=proj_bias,
ffn_bias=ffn_bias,
ffn_drop=ffn_drop,
attn_drop=attn_drop,
norm_layer=norm_layer_cls,
ffn_layer=ffn_layer_cls,
device=device,
)
for i in range(depth)
]
self.blocks = nn.ModuleList(blocks_list)
# This norm is applied to everything: CLS, registers, patch, and mask tokens.
self.norm = norm_layer_cls(embed_dim)
def init_weights(self):
self.rope_embed._init_weights()
nn.init.normal_(self.cls_token, std=0.02)
if self.n_storage_tokens > 0:
nn.init.normal_(self.storage_tokens, std=0.02)
named_apply(init_weights_vit, self)
def forward(
self,
x: Tensor,
masks: Tensor,
coords: Optional[Tensor] = None,
contour_index: Optional[Tensor] = None,
) -> Dict[str, Tensor]:
# x: [B, N, C] (WSI feature tokens)
B = x.size(0)
# Concatenate CLS and storage tokens inside embedder
storage_tokens = self.storage_tokens if self.n_storage_tokens > 0 else None
x = self.patch_embed(x, cls_token=self.cls_token, storage_tokens=storage_tokens)
# Attention key padding mask: True = valid
attn_key_mask = masks
if attn_key_mask.dim() > 2:
raise ValueError("masks must be of shape [B, N] with True for valid tokens")
if attn_key_mask.dtype != torch.bool:
attn_key_mask = attn_key_mask.to(dtype=torch.bool)
# prepend valid tokens for CLS/storage
prefix = torch.ones(B, 1 + self.n_storage_tokens, dtype=torch.bool, device=x.device)
attn_key_mask = torch.cat([prefix, attn_key_mask], dim=1)
# Build pairwise contour-based attention mask when contour_index is provided.
# NOTE (mask semantics): This project follows PyTorch SDPA docs where a boolean attn_mask
# True indicates that the element should take part in attention (allowed).
# Rules:
# - CLS and storage tokens do NOT interact with patch tokens in contour-constrained ("rope attention") blocks.
# - Patch tokens can attend only to patch tokens with the same contour_index.
# - Key padding mask (invalid/padded tokens) are always excluded from attention.
T = x.size(1)
pre_len = 1 + self.n_storage_tokens # CLS + storage tokens length
attn_mask_pairwise: Optional[Tensor] = None
if contour_index is not None:
# Ensure contour_index shape is [B, N]
if contour_index.dim() != 2:
raise ValueError(f"contour_index must be of shape [B, N], received shape={tuple(contour_index.shape)}")
contour_index = contour_index.to(device=x.device, dtype=torch.long)
# Combine prefix (-1) with provided contour indices for patch tokens
prefix_ci = torch.full((B, pre_len), fill_value=-1, device=x.device, dtype=torch.long)
full_ci = torch.cat([prefix_ci, contour_index], dim=1) # [B, T]
# For patch-to-patch, allow only same contour_index (equal values)
attn_mask_pairwise = full_ci.unsqueeze(2).eq(full_ci.unsqueeze(1)) # [B, T, T] # True means allowed
# Key padding mask as broadcastable shape for SDPA: [B, 1, 1, S]
# True = allowed (participate), False = excluded
nope_mask_broadcast = attn_key_mask.view(B, 1, 1, -1)
# Default to key-padding-only mask; if pairwise is available, intersect (AND) to keep only allowed entries
attn_mask_broadcast = nope_mask_broadcast
if attn_mask_pairwise is not None:
pairwise_broadcast = attn_mask_pairwise.view(B, 1, T, T)
# Combine pairwise (allowed) with key padding (allowed) using AND so only allowed stays True
attn_mask_broadcast = pairwise_broadcast & nope_mask_broadcast # broadcast over K dimension
# RoPE from coords (patch tokens only). coords shape: [B, N, 2] or [N, 2]
rope_sincos = self.rope_embed(coords=coords) if (self.rope_embed is not None and coords is not None) else None
for idx, blk in enumerate(self.blocks):
if (idx % self.nope_interval) == 0:
# NOPE attention block: no RoPE, only key padding constraints
x = blk(x, rope=None, attn_mask=nope_mask_broadcast)
else:
# ROPE attention block: RoPE + contour-aware pairwise mask (when provided) + key padding
x = blk(x, rope=rope_sincos, attn_mask=attn_mask_broadcast)
# Output packing
x_norm = self.norm(x)
x_norm_cls_reg = x_norm[:, : self.n_storage_tokens + 1]
x_norm_patch = x_norm[:, self.n_storage_tokens + 1 :]
return {
"x_norm_clstoken": x_norm_cls_reg[:, 0],
"x_storage_tokens": x_norm_cls_reg[:, 1:],
"x_norm_patchtokens": x_norm_patch,
"x_prenorm": x,
"masks": masks,
}