File size: 8,747 Bytes
e101805
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
from functools import partial
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import torch
import torch.nn.init
from torch import Tensor, nn

from .layers import SelfAttentionBlock, SwiGLUFFN, WSIFeatureEmbed, RopePositionEmbedding
from .model_utils import named_apply

# Registry of feed-forward constructors, keyed by config string.
# The "swigluN" variants round the hidden width up to a multiple of N.
ffn_layer_dict = {
    "swiglu": SwiGLUFFN,
    **{f"swiglu{align}": partial(SwiGLUFFN, align_to=align) for align in (32, 64, 128)},
}

# Registry of normalization constructors, all with eps=1e-5.
# NOTE: "layernormbf16" currently maps to the same plain LayerNorm as "layernorm".
norm_layer_dict = {
    name: partial(cls, eps=1e-5)
    for name, cls in (
        ("layernorm", nn.LayerNorm),
        ("layernormbf16", nn.LayerNorm),
        ("rmsnorm", nn.RMSNorm),
    )
}

# Config-string -> torch dtype lookup (used for the RoPE embedding dtype).
dtype_dict = dict(fp32=torch.float32, fp16=torch.float16, bf16=torch.bfloat16)


def init_weights_vit(module: nn.Module, name: str = ""):
    if isinstance(module, nn.Linear):
        torch.nn.init.trunc_normal_(module.weight, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    if isinstance(module, nn.LayerNorm):
        module.reset_parameters()
    if isinstance(module, nn.RMSNorm):
        module.reset_parameters()


class VisionTransformer(nn.Module):
    """Vision Transformer over pre-extracted WSI (whole-slide image) feature tokens.

    Input tokens are projected by ``WSIFeatureEmbed``; a CLS token and
    ``n_storage_tokens`` register tokens are prepended inside the embedder.
    Blocks alternate between "NOPE" attention (no positional encoding,
    key-padding mask only — every ``nope_interval``-th block starting at
    index 0) and RoPE attention (rotary position embedding computed from
    patch coordinates, optionally restricted by a contour-based pairwise
    mask).

    Boolean mask semantics follow PyTorch SDPA: True = allowed to attend.
    """

    def __init__(
        self,
        *,
        # Always WSI feature input mode
        input_dim: int = 768,  # channel dim of the incoming WSI feature tokens
        patch_size: int = 256,  # passed to RoPE; spatial size a token represents
        embed_use_norm: bool = True,  # normalize inside WSIFeatureEmbed
        pos_embed_rope_base: float = 100.0,
        pos_embed_rope_min_period: float | None = None,
        pos_embed_rope_max_period: float | None = None,
        pos_embed_rope_dtype: str = "fp32",  # key into dtype_dict
        embed_dim: int = 768,
        depth: int = 12,  # number of transformer blocks
        num_heads: int = 12,
        ffn_ratio: float = 4.0,
        qkv_bias: bool = True,
        norm_layer: str = "layernorm",  # key into norm_layer_dict
        ffn_layer: str = "swiglu128",  # key into ffn_layer_dict
        ffn_bias: bool = True,
        proj_bias: bool = True,
        ffn_drop: float = 0.0,
        attn_drop: float = 0.0,
        n_storage_tokens: int = 0,  # number of register tokens after CLS
        nope_interval: int = 2,  # every nope_interval-th block uses NOPE attention
        device: Any | None = None,
        **ignored_kwargs,
    ):
        """Build the model.

        Raises:
            ValueError: if ``input_dim`` is None.
            KeyError: if ``norm_layer``, ``ffn_layer`` or
                ``pos_embed_rope_dtype`` is not a known registry key.
        """
        super().__init__()
        del ignored_kwargs

        norm_layer_cls = norm_layer_dict[norm_layer]

        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.n_blocks = depth
        self.num_heads = num_heads
        self.patch_size = patch_size
        self.nope_interval = max(1, int(nope_interval)) # e.g., nope_interval=2 applies NOPE attention at block indices 0,2,4,...

        if input_dim is None:
            raise ValueError("VisionTransformer requires input_dim for WSI feature inputs.")
        self.patch_embed = WSIFeatureEmbed(
            input_dim=int(input_dim), embed_dim=embed_dim, use_norm=embed_use_norm
        )

        # Parameters are allocated empty here; init_weights() fills them.
        self.cls_token = nn.Parameter(torch.empty(1, 1, embed_dim, device=device))
        self.n_storage_tokens = n_storage_tokens
        if self.n_storage_tokens > 0:
            self.storage_tokens = nn.Parameter(torch.empty(1, n_storage_tokens, embed_dim, device=device))
        self.rope_embed = RopePositionEmbedding(
            embed_dim=embed_dim,
            num_heads=num_heads,
            patch_size=patch_size,
            base=pos_embed_rope_base,
            min_period=pos_embed_rope_min_period,
            max_period=pos_embed_rope_max_period,
            dtype=dtype_dict[pos_embed_rope_dtype],
            device=device,
        )
        ffn_layer_cls = ffn_layer_dict[ffn_layer]
        # Uniform FFN ratio per block (kept as a sequence for easy per-depth overrides).
        ffn_ratio_sequence = [ffn_ratio] * depth
        blocks_list = [
            SelfAttentionBlock(
                dim=embed_dim,
                num_heads=num_heads,
                ffn_ratio=ffn_ratio_sequence[i],
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                ffn_bias=ffn_bias,
                ffn_drop=ffn_drop,
                attn_drop=attn_drop,
                norm_layer=norm_layer_cls,
                ffn_layer=ffn_layer_cls,
                device=device,
            )
            for i in range(depth)
        ]

        self.blocks = nn.ModuleList(blocks_list)

        # This norm is applied to everything: CLS, registers, patch, and mask tokens.
        self.norm = norm_layer_cls(embed_dim)

    def init_weights(self):
        """Initialize all learnable state: RoPE internals, CLS/storage tokens
        (normal, std=0.02), and every submodule via init_weights_vit."""
        self.rope_embed._init_weights()
        nn.init.normal_(self.cls_token, std=0.02)
        if self.n_storage_tokens > 0:
            nn.init.normal_(self.storage_tokens, std=0.02)
        named_apply(init_weights_vit, self)

    def forward(
        self,
        x: Tensor,
        masks: Tensor,
        coords: Optional[Tensor] = None,
        contour_index: Optional[Tensor] = None,
    ) -> Dict[str, Tensor]:
        """Encode WSI feature tokens.

        Args:
            x: [B, N, C] feature tokens (C == input_dim).
            masks: [B, N] key-padding mask, True = valid token. Non-bool
                dtypes are cast to bool.
            coords: patch coordinates for RoPE; [B, N, 2] or [N, 2] per the
                note below — assumed consumed by RopePositionEmbedding, confirm there.
            contour_index: optional [B, N] integer tensor; in RoPE blocks,
                patch tokens may only attend to patch tokens with an equal
                contour index.

        Returns:
            Dict with the normalized CLS token, storage tokens and patch
            tokens, the pre-norm sequence ``x_prenorm``, and the input
            ``masks`` passed through unchanged.
        """
        # x: [B, N, C] (WSI feature tokens)
        B = x.size(0)
        # Concatenate CLS and storage tokens inside embedder
        storage_tokens = self.storage_tokens if self.n_storage_tokens > 0 else None
        x = self.patch_embed(x, cls_token=self.cls_token, storage_tokens=storage_tokens)

        # Attention key padding mask: True = valid
        attn_key_mask = masks
        if attn_key_mask.dim() > 2:
            raise ValueError("masks must be of shape [B, N] with True for valid tokens")
        if attn_key_mask.dtype != torch.bool:
            attn_key_mask = attn_key_mask.to(dtype=torch.bool)
        # prepend valid tokens for CLS/storage
        prefix = torch.ones(B, 1 + self.n_storage_tokens, dtype=torch.bool, device=x.device)
        attn_key_mask = torch.cat([prefix, attn_key_mask], dim=1)

        # Build pairwise contour-based attention mask when contour_index is provided.
        # NOTE (mask semantics): This project follows PyTorch SDPA docs where a boolean attn_mask
        # True indicates that the element should take part in attention (allowed).
        # Rules:
        # - CLS and storage tokens do NOT interact with patch tokens in contour-constrained ("rope attention") blocks.
        # - Patch tokens can attend only to patch tokens with the same contour_index.
        # - Key padding mask (invalid/padded tokens) are always excluded from attention.
        T = x.size(1)
        pre_len = 1 + self.n_storage_tokens  # CLS + storage tokens length
        attn_mask_pairwise: Optional[Tensor] = None
        if contour_index is not None:
            # Ensure contour_index shape is [B, N]
            if contour_index.dim() != 2:
                raise ValueError(f"contour_index must be of shape [B, N], received shape={tuple(contour_index.shape)}")
            contour_index = contour_index.to(device=x.device, dtype=torch.long)
            # Combine prefix (-1) with provided contour indices for patch tokens.
            # The shared -1 lets CLS/storage attend among themselves but not to patches
            # (assumes callers never pass -1 as a real contour id — verify upstream).
            prefix_ci = torch.full((B, pre_len), fill_value=-1, device=x.device, dtype=torch.long)
            full_ci = torch.cat([prefix_ci, contour_index], dim=1)  # [B, T]

            # For patch-to-patch, allow only same contour_index (equal values)
            attn_mask_pairwise = full_ci.unsqueeze(2).eq(full_ci.unsqueeze(1))  # [B, T, T] # True means allowed

        # Key padding mask as broadcastable shape for SDPA: [B, 1, 1, S]
        # True = allowed (participate), False = excluded
        nope_mask_broadcast = attn_key_mask.view(B, 1, 1, -1)
        # Default to key-padding-only mask; if pairwise is available, intersect (AND) to keep only allowed entries
        attn_mask_broadcast = nope_mask_broadcast
        if attn_mask_pairwise is not None:
            pairwise_broadcast = attn_mask_pairwise.view(B, 1, T, T)
            # Combine pairwise (allowed) with key padding (allowed) using AND so only allowed stays True
            attn_mask_broadcast = pairwise_broadcast & nope_mask_broadcast  # broadcast over K dimension

        # RoPE from coords (patch tokens only). coords shape: [B, N, 2] or [N, 2]
        rope_sincos = self.rope_embed(coords=coords) if (self.rope_embed is not None and coords is not None) else None

        for idx, blk in enumerate(self.blocks):
            if (idx % self.nope_interval) == 0:
                # NOPE attention block: no RoPE, only key padding constraints
                x = blk(x, rope=None, attn_mask=nope_mask_broadcast)
            else:
                # ROPE attention block: RoPE + contour-aware pairwise mask (when provided) + key padding
                x = blk(x, rope=rope_sincos, attn_mask=attn_mask_broadcast)

        # Output packing
        x_norm = self.norm(x)
        x_norm_cls_reg = x_norm[:, : self.n_storage_tokens + 1]
        x_norm_patch = x_norm[:, self.n_storage_tokens + 1 :]

        return {
            "x_norm_clstoken": x_norm_cls_reg[:, 0],
            "x_storage_tokens": x_norm_cls_reg[:, 1:],
            "x_norm_patchtokens": x_norm_patch,
            "x_prenorm": x,
            "masks": masks,
        }