File size: 3,168 Bytes
e101805 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
from __future__ import annotations
from typing import Dict, Optional
import torch
from torch import Tensor, nn
from .slide_transformer import VisionTransformer
__all__ = ["WSIEncoderHead"]
class WSIEncoderHead(nn.Module):
    """Adapter around VisionTransformer with aggregation over patch tokens.

    Inputs:
        - patch_features: [B, N, C]
        - patch_mask: [B, N] with 1 for valid tokens (required for correct masking)
        - patch_coords: optional [B, N, 2] integer coords for RoPE

    Returns:
        - dict with exactly two keys:
            - patch_embedding: [B, N, C_in + C] concat(raw_patch_features, transformer_patch_tokens)
            - slide_embedding: [B, C_in + C] concat(masked_mean(raw_patch_features), masked_mean(transformer_patch_tokens))
    """

    def __init__(
        self,
        transformer: VisionTransformer,
        input_dim: int,
        embed_dim: int,  # aggregator token channel dim
    ) -> None:
        super().__init__()
        self.transformer = transformer
        self.embed_dim = int(embed_dim)
        self.input_dim = int(input_dim)

    def _masked_mean(self, tokens: Tensor, mask: Optional[Tensor]) -> Tensor:
        """Mask-aware mean over the sequence dimension.

        - tokens: [B, N, C]
        - mask: [B, N] with 1 valid, 0 invalid. When ``mask`` is None this
          falls back to a plain mean over all tokens. When every token in a
          row is invalid, the clamped count makes the result a zero vector
          (sum=0 / count=1) instead of NaN.
        """
        if mask is None:
            return tokens.mean(dim=1)
        valid = mask.to(dtype=tokens.dtype).unsqueeze(-1)  # [B, N, 1]
        sums = (tokens * valid).sum(dim=1)  # [B, C]
        counts = valid.sum(dim=1).clamp_min(1.0)  # [B, 1], avoid divide-by-zero
        return sums / counts

    def forward(
        self,
        patch_features: Tensor,
        patch_mask: Tensor,
        patch_coords: Optional[Tensor] = None,
        patch_contour_index: Optional[Tensor] = None,
    ) -> Dict[str, Tensor]:
        """Encode patches and aggregate to patch- and slide-level embeddings.

        - patch_features: [B, N, C] raw per-patch features
        - patch_mask: [B, N] with 1 for valid tokens (required)
        - patch_coords: optional [B, N, 2] coords forwarded to the transformer
        - patch_contour_index: optional per-patch contour indices forwarded to
          the transformer to restrict attention within contours

        Raises:
            ValueError: if ``patch_mask`` is None.
        """
        if patch_mask is None:
            # Bug fix: the message previously named "WSIFeatureEncoder",
            # which does not match this class.
            raise ValueError("WSIEncoderHead requires patch_mask (shape [B, N]) to be provided.")
        mask = patch_mask.to(device=patch_features.device)
        encoded = self.transformer(
            patch_features,
            masks=mask,
            coords=patch_coords,
            contour_index=patch_contour_index,
        )
        patch_tokens = encoded["x_norm_patchtokens"]  # [B, N, C]
        # Patch-level embedding: concat(raw patch features, transformer patch tokens).
        patch_embedding = torch.cat([patch_features, patch_tokens], dim=-1)  # [B, N, C_in + C]
        # Slide-level embedding: concat of the two masked means.
        raw_patch_mean = self._masked_mean(patch_features, mask)  # [B, C_in]
        token_mean = self._masked_mean(patch_tokens, mask)  # [B, C]
        slide_embedding = torch.cat([raw_patch_mean, token_mean], dim=-1)  # [B, C_in + C]
        return {
            "patch_embedding": patch_embedding,
            "slide_embedding": slide_embedding,
        }
|