File size: 1,545 Bytes
e101805
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
from typing import Optional

from torch import Tensor, nn
import torch

__all__ = ["WSIFeatureEmbed"]


class WSIFeatureEmbed(nn.Module):
    """Embed whole-slide-image (WSI) patch features shaped [B, N, C].

    Applies a linear projection into the model dimension, optionally
    followed by LayerNorm, and optionally prepends caller-supplied CLS
    and storage tokens along the sequence axis.
    """

    def __init__(self, input_dim: int, embed_dim: int, use_norm: bool = True) -> None:
        super().__init__()
        self.input_dim = int(input_dim)
        self.embed_dim = int(embed_dim)
        self.proj = nn.Linear(self.input_dim, self.embed_dim)
        # Identity keeps the forward path uniform when normalization is disabled.
        self.norm = nn.LayerNorm(self.embed_dim, eps=1e-5) if use_norm else nn.Identity()

    def forward(
        self,
        x: Tensor,
        cls_token: Optional[Tensor] = None,
        storage_tokens: Optional[Tensor] = None,
    ) -> Tensor:
        """Project ``x`` and optionally prepend CLS / storage tokens.

        Args:
            x: Patch features shaped [B, N, C].
            cls_token: Optional token shaped [1, 1, D], broadcast over the batch.
            storage_tokens: Optional tokens shaped [1, S, D]; attached only when
                ``cls_token`` is also given.

        Returns:
            [B, 1+S+N, D] when ``cls_token`` is provided, else [B, N, D].

        Raises:
            ValueError: If ``x`` is not a rank-3 tensor.
        """
        if x.dim() != 3:
            raise ValueError(f"WSIFeatureEmbed expects [B, N, C] input but received {tuple(x.shape)}")
        batch = x.size(0)
        tokens = self.norm(self.proj(x))
        if cls_token is None:
            # Storage tokens take effect only alongside a CLS token; without
            # one, the projected patch features are returned unchanged.
            return tokens
        prefix = [cls_token.expand(batch, -1, -1)]
        if storage_tokens is not None and storage_tokens.numel() > 0:
            prefix.append(storage_tokens.expand(batch, -1, -1))
        else:
            # Zero-width placeholder so the concat below has a fixed arity.
            prefix.append(tokens.new_empty(batch, 0, tokens.size(-1)))
        return torch.cat(prefix + [tokens], dim=1)