import math
import random
from collections import OrderedDict
from typing import Literal, Optional

import torch
from timm.layers import DropPath, Mlp
from timm.models.vision_transformer import LayerScale
from torch import nn
from torch.nn import functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel


class Attention(nn.Module):
    """
    Multi-head cross-attention module with optional query/key normalization.

    :param dim: Total feature dimension.
    :param num_heads: Number of attention heads.
    :param qkv_bias: Whether to include bias terms in the linear projections.
    :param qk_norm: Whether to apply normalization to per-head queries and keys.
    :param attn_drop: Dropout probability for attention weights.
    :param proj_drop: Dropout probability after the output projection.
    :param norm_layer: Normalization layer to use if qk_norm is True.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        norm_layer: nn.Module = nn.LayerNorm,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5

        # Separate query and key/value projections: queries come from x,
        # keys/values from the conditioning sequence z.
        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, 2 * dim, bias=qkv_bias)

        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = attn_drop
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for multi-head cross-attention.

        :param x: Query tensor of shape (B, N1, dim).
        :param z: Key/value tensor of shape (B, N2, dim).
        :return: Attention output tensor of shape (B, N1, dim).
        """
        B, N1, C = x.shape
        _, N2, _ = z.shape

        q = self.q(x).reshape(B, N1, self.num_heads, self.head_dim).transpose(1, 2)
        kv = (
            self.kv(z)
            .reshape(B, N2, 2, self.num_heads, self.head_dim)
            .permute(2, 0, 3, 1, 4)
        )
        k, v = kv.unbind(0)

        q, k = self.q_norm(q), self.k_norm(k)
        # Restrict SDPA to the math backend; attention dropout is applied
        # only in training mode.
        with sdpa_kernel([SDPBackend.MATH]):
            x = F.scaled_dot_product_attention(
                query=q,
                key=k,
                value=v,
                dropout_p=self.attn_drop if self.training else 0.0,
                scale=self.scale,
            )

        x = x.transpose(1, 2).reshape(B, N1, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
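
# Usage sketch for Attention (hypothetical shapes, not part of the original
# module): queries and keys/values may have different sequence lengths.
#
#     attn = Attention(dim=64, num_heads=8)
#     x = torch.randn(2, 16, 64)  # queries:     (B, N1, dim)
#     z = torch.randn(2, 13, 64)  # keys/values: (B, N2, dim)
#     out = attn(x, z)            # -> (2, 16, 64)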


class Block(nn.Module):
    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        proj_drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values: Optional[float] = None,
        drop_path: float = 0.0,
        act_layer: nn.Module = nn.GELU,
        norm_layer: nn.Module = nn.LayerNorm,
        mlp_layer: nn.Module = Mlp,
    ) -> None:
        """
        Transformer block combining cross-attention and an MLP with residual
        connections and optional LayerScale and DropPath.

        :param dim: Feature dimension.
        :param num_heads: Number of attention heads.
        :param mlp_ratio: Ratio of the MLP hidden dimension to dim.
        :param qkv_bias: Whether to include bias in the Q/KV projections.
        :param qk_norm: Whether to normalize Q and K.
        :param proj_drop: Dropout probability after the output projection.
        :param attn_drop: Dropout probability for attention weights.
        :param init_values: Initial value for LayerScale (if None, LayerScale is an Identity).
        :param drop_path: Drop probability for stochastic depth.
        :param act_layer: Activation layer for the MLP.
        :param norm_layer: Normalization layer.
        :param mlp_layer: MLP module class.
        """
        super().__init__()
        self.x_norm = norm_layer(dim)
        self.z_norm = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            norm_layer=norm_layer,
        )

        self.ls1 = (
            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        )
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        self.mlp = mlp_layer(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=proj_drop,
        )
        self.ls2 = (
            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        )
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for the transformer block.

        :param x: Input tensor of shape (B, N, dim).
        :param z: Conditioning tensor for cross-attention of shape (B, M, dim).
        :return: Output tensor of shape (B, N, dim) after attention and MLP.
        """
        x = x + self.drop_path1(self.ls1(self.attn(self.x_norm(x), self.z_norm(z))))
        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
        return x
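
# Usage sketch for Block (hypothetical shapes, not part of the original
# module): x cross-attends to z, then passes through the MLP, both with
# pre-norm residual connections.
#
#     block = Block(dim=64, num_heads=8)
#     x = torch.randn(2, 16, 64)
#     z = torch.randn(2, 13, 64)
#     y = block(x, z)             # -> (2, 16, 64)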


class HistaugModel(nn.Module):
    """
    Hierarchical augmentation transformer model for embedding input features and augmentations.

    :param input_dim: Dimensionality of the raw input features.
    :param depth: Number of transformer blocks.
    :param num_heads: Number of attention heads.
    :param mlp_ratio: Ratio of the MLP hidden dimension to the embedding dimension.
    :param use_transform_pos_embeddings: Whether to include sequence positional embeddings for augmentations.
    :param positional_encoding_type: Type of transform positional embeddings ('learnable' or 'sinusoidal').
    :param final_activation: Name of the torch.nn activation class for the final head.
    :param chunk_size: Number of chunks the input feature vector is split into.
    :param transforms: Dictionary containing augmentation parameter configurations.
    :param device: Device for tensors and buffers.
    :param kwargs: Additional unused keyword arguments.
    """

    def __init__(
        self,
        input_dim,
        depth,
        num_heads,
        mlp_ratio,
        use_transform_pos_embeddings=True,
        positional_encoding_type="learnable",
        final_activation="Identity",
        chunk_size=16,
        transforms=None,
        device=torch.device("cpu"),
        **kwargs,
    ):
        super().__init__()

        assert input_dim % chunk_size == 0, "input_dim must be divisible by chunk_size"

        self.input_dim = input_dim
        self.chunk_size = chunk_size
        self.transforms_parameters = transforms["parameters"]
        self.aug_param_names = sorted(self.transforms_parameters.keys())

        self.use_transform_pos_embeddings = use_transform_pos_embeddings
        self.positional_encoding_type = positional_encoding_type
        self.num_classes = 0
        self.num_features = 0
        self.embed_dim = self.input_dim // self.chunk_size

        # Fixed sinusoidal positional encodings for the feature chunks.
        self.register_buffer(
            "chunk_pos_embeddings_buffer",
            self._get_sinusoidal_embeddings(self.chunk_size, self.embed_dim),
        )
        if use_transform_pos_embeddings:
            if positional_encoding_type == "learnable":
                self.sequence_pos_embedding = nn.Embedding(
                    len(transforms["parameters"]), self.embed_dim
                )
            elif positional_encoding_type == "sinusoidal":
                sinusoidal_embeddings = self._get_sinusoidal_embeddings(
                    len(transforms["parameters"]), self.embed_dim
                )
                self.register_buffer("sequence_pos_embedding", sinusoidal_embeddings)
            else:
                raise ValueError(
                    f"Invalid positional_encoding_type: {positional_encoding_type}. "
                    "Choose 'learnable' or 'sinusoidal'."
                )
        else:
            print("Transform positional embeddings are disabled")

        self.transform_embeddings = self._get_transforms_embeddings(
            transforms["parameters"], self.embed_dim
        )

        # Linear feature embedding (defined but not used in forward).
        self.features_embed = nn.Sequential(
            nn.Linear(input_dim, self.embed_dim), nn.LayerNorm(self.embed_dim)
        )

        self.blocks = nn.ModuleList(
            [
                Block(dim=self.embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio)
                for _ in range(depth)
            ]
        )
        self.norm = nn.LayerNorm(self.embed_dim)

        if hasattr(nn, final_activation):
            self.final_activation = getattr(nn, final_activation)()
        else:
            raise ValueError(f"Activation {final_activation} is not found in torch.nn")

        self.head = nn.Sequential(
            nn.Linear(input_dim, input_dim), self.final_activation
        )

    def _get_sinusoidal_embeddings(self, num_positions, embed_dim):
        """
        Create sinusoidal embeddings for positional encoding.

        :param num_positions: Number of positions to encode.
        :param embed_dim: Dimensionality of each embedding vector.
        :return: Tensor of shape (num_positions, embed_dim) containing positional encodings.
        """
        assert embed_dim % 2 == 0, "embed_dim must be even"
        position = torch.arange(0, num_positions, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, embed_dim, 2, dtype=torch.float)
            * (-math.log(10000.0) / embed_dim)
        )

        pe = torch.zeros(num_positions, embed_dim)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        return pe
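
    # The encodings above follow the standard transformer recipe:
    #     PE[p, 2i]   = sin(p / 10000^(2i / embed_dim))
    #     PE[p, 2i+1] = cos(p / 10000^(2i / embed_dim))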
    def _get_transforms_embeddings(self, transforms, embed_dim):
        """
        Create an embedding module for each augmentation parameter.

        :param transforms: Mapping of augmentation names to configuration.
        :param embed_dim: Dimensionality of the embeddings.
        :return: ModuleDict of embeddings for each augmentation type.
        """
        transform_embeddings = nn.ModuleDict()
        for aug_name in transforms:
            if aug_name in [
                "rotation",
                "h_flip",
                "v_flip",
                "gaussian_blur",
                "erosion",
                "dilation",
            ]:
                # Categorical parameters: binary flags, or one of four
                # rotation codes.
                transform_embeddings[aug_name] = nn.Embedding(
                    num_embeddings=2 if aug_name != "rotation" else 4,
                    embedding_dim=embed_dim,
                )
            elif aug_name in ["crop"]:
                # Categorical crop code.
                transform_embeddings[aug_name] = nn.Embedding(
                    num_embeddings=6, embedding_dim=embed_dim
                )
            elif aug_name in [
                "brightness",
                "contrast",
                "saturation",
                "hed",
                "hue",
                "powerlaw",
            ]:
                # Continuous parameters: embedded with a small MLP.
                transform_embeddings[aug_name] = nn.Sequential(
                    nn.Linear(1, embed_dim * 2),
                    nn.SiLU(),
                    nn.Linear(embed_dim * 2, embed_dim),
                )
            else:
                raise ValueError(
                    f"{aug_name} is not a valid augmentation parameter name"
                )
        return transform_embeddings

    def forward_aug_params_embed(self, aug_params):
        """
        Embed augmentation parameters and add positional embeddings if enabled.

        :param aug_params: OrderedDict mapping augmentation names to (value_tensor, position_tensor).
        :return: Tensor of shape (B, K, embed_dim) of embedded transform tokens.
        """
        z_transforms = []
        for aug_name, (aug_param, pos) in aug_params.items():
            if aug_name in [
                "rotation",
                "h_flip",
                "v_flip",
                "gaussian_blur",
                "erosion",
                "dilation",
                "crop",
            ]:
                # Categorical parameters are looked up directly.
                z_transform = self.transform_embeddings[aug_name](aug_param)
            elif aug_name in [
                "brightness",
                "contrast",
                "saturation",
                "hue",
                "powerlaw",
                "hed",
            ]:
                # Continuous parameters are embedded as scalars of shape (B, 1).
                z_transform = self.transform_embeddings[aug_name](
                    aug_param[..., None].float()
                )
            else:
                raise ValueError(
                    f"{aug_name} is not a valid augmentation parameter name"
                )

            if self.use_transform_pos_embeddings:
                if self.positional_encoding_type == "learnable":
                    pos_index = torch.as_tensor(pos, device=aug_param.device)
                    pos_embedding = self.sequence_pos_embedding(pos_index)
                elif self.positional_encoding_type == "sinusoidal":
                    pos_embedding = self.sequence_pos_embedding[pos].to(
                        aug_param.device
                    )
                else:
                    raise ValueError(
                        f"Invalid positional_encoding_type: {self.positional_encoding_type}"
                    )
                z_transforms.append(z_transform + pos_embedding)
            else:
                z_transforms.append(z_transform)

        # Stack the K transform tokens along the sequence dimension.
        z_transforms = torch.stack(z_transforms, dim=1)
        return z_transforms
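
    # Expected aug_params layout for forward_aug_params_embed (hypothetical
    # values, as produced by sample_aug_params below):
    #
    #     OrderedDict(
    #         brightness=(torch.tensor([0.97, 1.02]), torch.tensor([0, 0])),
    #         rotation=(torch.tensor([2, 0]), torch.tensor([1, 1])),
    #     )
    #
    # i.e. one (value_tensor, position_tensor) pair, each of shape (B,).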
    def sample_aug_params(
        self,
        batch_size: int,
        device: torch.device = torch.device("cuda"),
        mode: Literal["instance_wise", "wsi_wise"] = "wsi_wise",
    ):
        """
        Sample random augmentation parameters and their relative positions.

        In "instance_wise" mode every batch element gets its own draw; in
        "wsi_wise" mode a single draw is shared across the batch. If a
        transform from the supported list is missing in self.aug_param_names,
        it is included with zero values and appended at unique tail positions.
        """
        if mode not in ("instance_wise", "wsi_wise"):
            raise ValueError('mode must be "instance_wise" or "wsi_wise"')

        supported_aug_names = [
            "rotation",
            "crop",
            "h_flip",
            "v_flip",
            "gaussian_blur",
            "erosion",
            "dilation",
            "brightness",
            "contrast",
            "saturation",
            "hue",
            "powerlaw",
            "hed",
        ]

        canonical_names = sorted(self.transforms_parameters.keys())
        num_transforms = len(canonical_names)

        # Transforms that are supported but not configured are appended later
        # with zero (identity) values; the positional embedding table is
        # assumed to cover these extra tail positions.
        missing_names = [n for n in supported_aug_names if n not in canonical_names]

        # Randomly order the configured transforms; each transform's position
        # is its rank in the sampled permutation.
        if mode == "instance_wise":
            permutation_matrix = (
                torch.stack(
                    [
                        torch.randperm(num_transforms, device=device)
                        for _ in range(batch_size)
                    ],
                    dim=0,
                )
                if num_transforms > 0
                else torch.empty((batch_size, 0), dtype=torch.long, device=device)
            )
        else:
            if num_transforms > 0:
                single_permutation = torch.randperm(num_transforms, device=device)
                permutation_matrix = single_permutation.unsqueeze(0).repeat(
                    batch_size, 1
                )
            else:
                permutation_matrix = torch.empty(
                    (batch_size, 0), dtype=torch.long, device=device
                )

        positions_matrix = (
            torch.argsort(permutation_matrix, dim=1)
            if num_transforms > 0
            else torch.empty((batch_size, 0), dtype=torch.long, device=device)
        )

        augmentation_parameters = OrderedDict()

        for transform_index, name in enumerate(canonical_names):
            config = self.transforms_parameters[name]

            if name == "rotation":
                # Config is an application probability; values are rotation
                # codes in {0, 1, 2, 3}, with 0 meaning "no rotation".
                probability = float(config)
                if mode == "instance_wise":
                    apply_mask = torch.rand(batch_size, device=device) < probability
                    random_angles = torch.randint(0, 4, (batch_size,), device=device)
                    random_angles[~apply_mask] = 0
                    value_tensor = random_angles
                else:
                    apply = random.random() < probability
                    angle = random.randint(1, 3) if apply else 0
                    value_tensor = torch.full(
                        (batch_size,), angle, dtype=torch.int64, device=device
                    )

            elif name == "crop":
                probability = float(config)
                if mode == "instance_wise":
                    apply_mask = torch.rand(batch_size, device=device) < probability
                    random_crops = torch.randint(0, 5, (batch_size,), device=device)
                    random_crops[~apply_mask] = 0
                    value_tensor = random_crops
                else:
                    apply = random.random() < probability
                    crop_code = random.randint(1, 4) if apply else 0
                    value_tensor = torch.full(
                        (batch_size,), crop_code, dtype=torch.int64, device=device
                    )

            elif name in ("h_flip", "v_flip", "gaussian_blur", "erosion", "dilation"):
                # Binary flags sampled with the configured probability.
                probability = float(config)
                if mode == "instance_wise":
                    value_tensor = (
                        torch.rand(batch_size, device=device) < probability
                    ).int()
                else:
                    bit = int(random.random() < probability)
                    value_tensor = torch.full(
                        (batch_size,), bit, dtype=torch.int32, device=device
                    )

            elif name in (
                "brightness",
                "contrast",
                "saturation",
                "hue",
                "powerlaw",
                "hed",
            ):
                # Continuous parameters sampled uniformly from [lower, upper].
                lower_bound, upper_bound = map(float, config)
                if mode == "instance_wise":
                    value_tensor = torch.empty(batch_size, device=device).uniform_(
                        lower_bound, upper_bound
                    )
                else:
                    scalar_value = random.uniform(lower_bound, upper_bound)
                    value_tensor = torch.full(
                        (batch_size,), scalar_value, dtype=torch.float32, device=device
                    )

            else:
                raise ValueError(f"'{name}' is not a recognised augmentation name")

            position_tensor = positions_matrix[:, transform_index]
            augmentation_parameters[name] = (value_tensor, position_tensor)

        # Missing transforms get identity (zero) values and fixed tail positions.
        for i, name in enumerate(missing_names):
            if name in ("rotation", "crop"):
                zeros = torch.zeros(batch_size, dtype=torch.int64, device=device)
            elif name in ("h_flip", "v_flip", "gaussian_blur", "erosion", "dilation"):
                zeros = torch.zeros(batch_size, dtype=torch.int32, device=device)
            else:
                zeros = torch.zeros(batch_size, dtype=torch.float32, device=device)

            tail_pos = num_transforms + i
            pos = torch.full((batch_size,), tail_pos, dtype=torch.long, device=device)
            augmentation_parameters[name] = (zeros, pos)

        return augmentation_parameters
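
    # Example (a sketch with hypothetical sizes): WSI-wise sampling yields one
    # shared draw per transform, broadcast across the batch:
    #
    #     params = model.sample_aug_params(batch_size=4, device=torch.device("cpu"))
    #     values, positions = params["rotation"]  # each of shape (4,)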
    def forward(self, x, aug_params, **kwargs):
        """
        Forward pass: embed features, apply transformer blocks, and produce output.

        :param x: Input tensor of shape (B, input_dim).
        :param aug_params: Augmentation parameters from sample_aug_params.
        :return: Output tensor of shape (B, input_dim).
        """
        # Split each feature vector into chunk_size tokens of size embed_dim
        # and add the fixed chunk positional encodings.
        x = x.view(x.shape[0], self.chunk_size, self.embed_dim)
        x = x + self.chunk_pos_embeddings_buffer.unsqueeze(0)

        # Embed the augmentation parameters into transform tokens.
        z = self.forward_aug_params_embed(aug_params)

        for block in self.blocks:
            x = block(x, z)
        x = self.norm(x)

        # Flatten the tokens back to (B, input_dim) and apply the head.
        x = x.reshape(x.shape[0], -1)
        x = self.head(x)
        return x
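

if __name__ == "__main__":
    # Minimal smoke test. This is a sketch: the transform configuration below
    # (probabilities for categorical transforms, [lower, upper] ranges for
    # continuous ones) is an assumption inferred from sample_aug_params, not a
    # canonical config shipped with the model.
    transforms = {
        "parameters": {
            "rotation": 0.5,
            "crop": 0.5,
            "h_flip": 0.5,
            "v_flip": 0.5,
            "gaussian_blur": 0.2,
            "erosion": 0.1,
            "dilation": 0.1,
            "brightness": [0.9, 1.1],
            "contrast": [0.9, 1.1],
            "saturation": [0.9, 1.1],
            "hue": [-0.05, 0.05],
            "powerlaw": [0.8, 1.2],
            "hed": [-0.05, 0.05],
        }
    }
    model = HistaugModel(
        input_dim=1024,  # split into 16 chunks of embed_dim = 64
        depth=2,
        num_heads=4,
        mlp_ratio=4.0,
        chunk_size=16,
        transforms=transforms,
    )
    aug_params = model.sample_aug_params(batch_size=2, device=torch.device("cpu"))
    out = model(torch.randn(2, 1024), aug_params)
    print(out.shape)  # expected: torch.Size([2, 1024])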