from functools import partial
from typing import Any, Dict, Optional

import torch
import torch.nn.init
from torch import Tensor, nn

from .layers import RopePositionEmbedding, SelfAttentionBlock, SwiGLUFFN, WSIFeatureEmbed
from .model_utils import named_apply

ffn_layer_dict = {
    "swiglu": SwiGLUFFN,
    "swiglu32": partial(SwiGLUFFN, align_to=32),
    "swiglu64": partial(SwiGLUFFN, align_to=64),
    "swiglu128": partial(SwiGLUFFN, align_to=128),
}

# Note: "layernormbf16" currently resolves to the same standard LayerNorm as "layernorm".
norm_layer_dict = {
    "layernorm": partial(nn.LayerNorm, eps=1e-5),
    "layernormbf16": partial(nn.LayerNorm, eps=1e-5),
    "rmsnorm": partial(nn.RMSNorm, eps=1e-5),
}

dtype_dict = {
    "fp32": torch.float32,
    "fp16": torch.float16,
    "bf16": torch.bfloat16,
}

def init_weights_vit(module: nn.Module, name: str = ""):
    """ViT-style init: truncated-normal linear weights, zero biases, default norm params."""
    if isinstance(module, nn.Linear):
        torch.nn.init.trunc_normal_(module.weight, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    if isinstance(module, (nn.LayerNorm, nn.RMSNorm)):
        module.reset_parameters()

class VisionTransformer(nn.Module):
    """Vision transformer over pre-extracted WSI patch features.

    Every `nope_interval`-th attention block runs without rotary position
    embeddings (NoPE); the remaining blocks use RoPE (see `forward`).
    """

    def __init__(
        self,
        *,
        input_dim: int = 768,
        patch_size: int = 256,
        embed_use_norm: bool = True,
        pos_embed_rope_base: float = 100.0,
        pos_embed_rope_min_period: float | None = None,
        pos_embed_rope_max_period: float | None = None,
        pos_embed_rope_dtype: str = "fp32",
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        ffn_ratio: float = 4.0,
        qkv_bias: bool = True,
        norm_layer: str = "layernorm",
        ffn_layer: str = "swiglu128",
        ffn_bias: bool = True,
        proj_bias: bool = True,
        ffn_drop: float = 0.0,
        attn_drop: float = 0.0,
        n_storage_tokens: int = 0,
        nope_interval: int = 2,
        device: Any | None = None,
        **ignored_kwargs,
    ):
        super().__init__()
        del ignored_kwargs

        norm_layer_cls = norm_layer_dict[norm_layer]

        self.num_features = self.embed_dim = embed_dim
        self.n_blocks = depth
        self.num_heads = num_heads
        self.patch_size = patch_size
        self.nope_interval = max(1, int(nope_interval))

        if input_dim is None:
            raise ValueError("VisionTransformer requires input_dim for WSI feature inputs.")
        self.patch_embed = WSIFeatureEmbed(
            input_dim=int(input_dim), embed_dim=embed_dim, use_norm=embed_use_norm
        )

        self.cls_token = nn.Parameter(torch.empty(1, 1, embed_dim, device=device))
        self.n_storage_tokens = n_storage_tokens
        if self.n_storage_tokens > 0:
            self.storage_tokens = nn.Parameter(
                torch.empty(1, n_storage_tokens, embed_dim, device=device)
            )

        self.rope_embed = RopePositionEmbedding(
            embed_dim=embed_dim,
            num_heads=num_heads,
            patch_size=patch_size,
            base=pos_embed_rope_base,
            min_period=pos_embed_rope_min_period,
            max_period=pos_embed_rope_max_period,
            dtype=dtype_dict[pos_embed_rope_dtype],
            device=device,
        )

        ffn_layer_cls = ffn_layer_dict[ffn_layer]
        ffn_ratio_sequence = [ffn_ratio] * depth
        blocks_list = [
            SelfAttentionBlock(
                dim=embed_dim,
                num_heads=num_heads,
                ffn_ratio=ffn_ratio_sequence[i],
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                ffn_bias=ffn_bias,
                ffn_drop=ffn_drop,
                attn_drop=attn_drop,
                norm_layer=norm_layer_cls,
                ffn_layer=ffn_layer_cls,
                device=device,
            )
            for i in range(depth)
        ]
        self.blocks = nn.ModuleList(blocks_list)

        self.norm = norm_layer_cls(embed_dim)

    def init_weights(self):
        self.rope_embed._init_weights()
        nn.init.normal_(self.cls_token, std=0.02)
        if self.n_storage_tokens > 0:
            nn.init.normal_(self.storage_tokens, std=0.02)
        named_apply(init_weights_vit, self)

    def forward(
        self,
        x: Tensor,
        masks: Tensor,
        coords: Optional[Tensor] = None,
        contour_index: Optional[Tensor] = None,
    ) -> Dict[str, Tensor]:
        B = x.size(0)

        # Embed patch features and prepend the cls (and optional storage) tokens.
        storage_tokens = self.storage_tokens if self.n_storage_tokens > 0 else None
        x = self.patch_embed(x, cls_token=self.cls_token, storage_tokens=storage_tokens)

        # Key-padding mask: True marks valid tokens.
        attn_key_mask = masks
        if attn_key_mask.dim() > 2:
            raise ValueError("masks must be of shape [B, N] with True for valid tokens")
        if attn_key_mask.dtype != torch.bool:
            attn_key_mask = attn_key_mask.to(dtype=torch.bool)

        # The cls and storage tokens are always valid keys.
        prefix = torch.ones(B, 1 + self.n_storage_tokens, dtype=torch.bool, device=x.device)
        attn_key_mask = torch.cat([prefix, attn_key_mask], dim=1)

        T = x.size(1)
        pre_len = 1 + self.n_storage_tokens
        attn_mask_pairwise: Optional[Tensor] = None
        if contour_index is not None:
            if contour_index.dim() != 2:
                raise ValueError(
                    f"contour_index must be of shape [B, N], received shape={tuple(contour_index.shape)}"
                )
            contour_index = contour_index.to(device=x.device, dtype=torch.long)

            # Prefix (cls/storage) tokens share the sentinel contour index -1.
            prefix_ci = torch.full((B, pre_len), fill_value=-1, device=x.device, dtype=torch.long)
            full_ci = torch.cat([prefix_ci, contour_index], dim=1)

            # [B, T, T]: True where query and key carry the same contour index.
            attn_mask_pairwise = full_ci.unsqueeze(2).eq(full_ci.unsqueeze(1))

        # Key-only mask broadcast over heads and queries: [B, 1, 1, T].
        nope_mask_broadcast = attn_key_mask.view(B, 1, 1, -1)

        attn_mask_broadcast = nope_mask_broadcast
        if attn_mask_pairwise is not None:
            pairwise_broadcast = attn_mask_pairwise.view(B, 1, T, T)
            # Restrict attention to same-contour pairs, then drop padded keys.
            attn_mask_broadcast = pairwise_broadcast & nope_mask_broadcast

        rope_sincos = (
            self.rope_embed(coords=coords)
            if (self.rope_embed is not None and coords is not None)
            else None
        )

        for idx, blk in enumerate(self.blocks):
            if (idx % self.nope_interval) == 0:
                # NoPE block: no rotary embedding, padding mask only (global mixing).
                x = blk(x, rope=None, attn_mask=nope_mask_broadcast)
            else:
                # RoPE block: rotary embedding plus the contour-restricted mask, if any.
                x = blk(x, rope=rope_sincos, attn_mask=attn_mask_broadcast)

        x_norm = self.norm(x)
        x_norm_cls_reg = x_norm[:, : self.n_storage_tokens + 1]
        x_norm_patch = x_norm[:, self.n_storage_tokens + 1 :]

        return {
            "x_norm_clstoken": x_norm_cls_reg[:, 0],
            "x_storage_tokens": x_norm_cls_reg[:, 1:],
            "x_norm_patchtokens": x_norm_patch,
            "x_prenorm": x,
            "masks": masks,
        }
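

# Minimal usage sketch (smoke test), not part of the model code. The input
# contracts below are assumptions: it presumes WSIFeatureEmbed accepts features
# of shape [B, N, input_dim] and that RopePositionEmbedding consumes per-patch
# coordinates of shape [B, N, 2]; check the actual contracts in .layers before
# relying on this.
if __name__ == "__main__":
    model = VisionTransformer(
        input_dim=768, embed_dim=768, depth=12, num_heads=12, n_storage_tokens=4
    )
    model.init_weights()
    B, N = 2, 196
    feats = torch.randn(B, N, 768)              # pre-extracted WSI patch features
    masks = torch.ones(B, N, dtype=torch.bool)  # all patches valid
    coords = torch.rand(B, N, 2)                # assumed (x, y) patch coordinates
    out = model(feats, masks, coords=coords)
    print(out["x_norm_clstoken"].shape)         # expected: torch.Size([2, 768])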