sofieneb commited on
Commit
91eed61
·
verified ·
1 Parent(s): d85d2a1

Upload folder using huggingface_hub

Browse files
Files changed (7) hide show
  1. README.md +10 -0
  2. __init__.py +1 -0
  3. config.json +56 -0
  4. configuration_histaug.py +33 -0
  5. histaug_model.py +569 -0
  6. model.safetensors +3 -0
  7. modeling_histaug.py +72 -0
README.md ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
---
tags:
- model_hub_mixin
- pytorch_model_hub_mixin
---

This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration:
- Code: [More Information Needed]
- Paper: [More Information Needed]
- Docs: [More Information Needed]
__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .histaug_model import HistaugModel
config.json ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_commit_hash": null,
3
+ "model_type": "histaug",
4
+ "architectures": [
5
+ "HistaugPretrainedModel"
6
+ ],
7
+ "auto_map": {
8
+ "AutoConfig": "configuration_histaug.HistaugConfig",
9
+ "AutoModel": "modeling_histaug.HistaugPretrainedModel"
10
+ },
11
+ "chunk_size": 4,
12
+ "depth": 16,
13
+ "input_dim": 512,
14
+ "mlp_ratio": 4,
15
+ "name": "histaug_model",
16
+ "num_heads": 8,
17
+ "positional_encoding_type": "learnable",
18
+ "transforms": {
19
+ "parameters": {
20
+ "brightness": [
21
+ -0.5,
22
+ 0.5
23
+ ],
24
+ "contrast": [
25
+ -0.5,
26
+ 0.5
27
+ ],
28
+ "crop": 0.75,
29
+ "dilation": 0.75,
30
+ "erosion": 0.75,
31
+ "gamma": [
32
+ -0.5,
33
+ 0.5
34
+ ],
35
+ "gaussian_blur": 0.75,
36
+ "h_flip": 0.75,
37
+ "hed": [
38
+ -0.5,
39
+ 0.5
40
+ ],
41
+ "hue": [
42
+ -0.5,
43
+ 0.5
44
+ ],
45
+ "rotation": 0.75,
46
+ "saturation": [
47
+ -0.5,
48
+ 0.5
49
+ ],
50
+ "v_flip": 0.75
51
+ }
52
+ },
53
+ "use_transform_pos_embeddings": true,
54
+ "torch_dtype": "float32",
55
+ "transformers_version": "4.46.0"
56
+ }
configuration_histaug.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
class HistaugConfig(PretrainedConfig):
    """Configuration for the Histaug augmentation transformer.

    Mirrors the constructor arguments of ``HistaugModel``; any extra keyword
    arguments are forwarded to ``PretrainedConfig``.
    """

    model_type = "histaug"

    def __init__(
        self,
        input_dim: int = 512,
        depth: int = 6,
        num_heads: int = 8,
        mlp_ratio: float = 4.0,
        use_transform_pos_embeddings: bool = True,
        positional_encoding_type: str = "learnable",
        final_activation: str = "Identity",
        embedding_type: str = "linear",
        chunk_size: int = 16,
        transforms: dict = None,
        **kwargs,
    ):
        # Core architecture hyperparameters.
        self.input_dim = input_dim
        self.depth = depth
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio

        # Positional-encoding behaviour for the augmentation tokens.
        self.use_transform_pos_embeddings = use_transform_pos_embeddings
        self.positional_encoding_type = positional_encoding_type

        # Embedding / output-head configuration.
        self.final_activation = final_activation
        self.embedding_type = embedding_type
        self.chunk_size = chunk_size

        # Fall back to an empty parameter set when transforms is None or empty
        # (matches the original `transforms or {...}` truthiness semantics).
        self.transforms = transforms if transforms else {"parameters": {}}

        super().__init__(**kwargs)
histaug_model.py ADDED
@@ -0,0 +1,569 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import random
3
+ from collections import OrderedDict
4
+ from typing import Literal
5
+
6
+ import torch
7
+ from timm.layers import DropPath, Mlp
8
+ from timm.models.vision_transformer import LayerScale
9
+ from torch import nn
10
+ from torch.nn import functional as F
11
+ from torch.nn.attention import SDPBackend, sdpa_kernel
12
+
13
+
14
+ class Attention(nn.Module):
15
+ """
16
+ Multi-head attention module with optional query/key normalization.
17
+
18
+ :param dim: Total feature dimension.
19
+ :param num_heads: Number of attention heads.
20
+ :param qkv_bias: Whether to include bias terms in linear projections.
21
+ :param qk_norm: Whether to apply LayerNorm to individual head queries and keys.
22
+ :param attn_drop: Dropout probability for attention weights.
23
+ :param proj_drop: Dropout probability after the output projection.
24
+ :param norm_layer: Normalization layer to use if qk_norm is True.
25
+
26
+ :return: Output tensor of shape (B, N1, dim) after attention and projection.
27
+ """
28
+
29
+ def __init__(
30
+ self,
31
+ dim: int,
32
+ num_heads: int = 8,
33
+ qkv_bias: bool = False,
34
+ qk_norm: bool = False,
35
+ attn_drop: float = 0.0,
36
+ proj_drop: float = 0.0,
37
+ norm_layer: nn.Module = nn.LayerNorm,
38
+ ) -> None:
39
+ super().__init__()
40
+ assert dim % num_heads == 0, "dim should be divisible by num_heads"
41
+ self.num_heads = num_heads
42
+ self.head_dim = dim // num_heads
43
+ self.scale = self.head_dim**-0.5
44
+
45
+ self.q = nn.Linear(dim, dim, bias=qkv_bias)
46
+ self.kv = nn.Linear(dim, 2 * dim, bias=qkv_bias)
47
+
48
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
49
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
50
+ self.attn_drop = attn_drop
51
+ self.proj = nn.Linear(dim, dim)
52
+ self.proj_drop = nn.Dropout(proj_drop)
53
+
54
+ def forward(self, x: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
55
+ """
56
+ Forward pass for multi-head attention.
57
+
58
+ :param x: Query tensor of shape (B, N1, dim).
59
+ :param z: Key/Value tensor of shape (B, N2, dim).
60
+ :return: Attention output tensor of shape (B, N1, dim).
61
+ """
62
+ B, N1, C = x.shape
63
+ B, N2, C = z.shape
64
+
65
+ q = self.q(x).reshape([B, N1, self.num_heads, self.head_dim]).swapaxes(1, 2)
66
+ kv = (
67
+ self.kv(z)
68
+ .reshape(B, N2, 2, self.num_heads, self.head_dim)
69
+ .permute(2, 0, 3, 1, 4)
70
+ )
71
+ k, v = kv.unbind(0)
72
+
73
+ q, k = self.q_norm(q), self.k_norm(k)
74
+ with sdpa_kernel(
75
+ [
76
+ SDPBackend.MATH,
77
+ ]
78
+ ):
79
+ x = F.scaled_dot_product_attention(
80
+ query=q, key=k, value=v, dropout_p=self.attn_drop, scale=self.scale
81
+ )
82
+
83
+ x = x.transpose(1, 2).reshape(B, N1, C)
84
+ x = self.proj(x)
85
+ x = self.proj_drop(x)
86
+ return x
87
+
88
+
89
class Block(nn.Module):
    """
    Transformer block: cross-attention followed by an MLP, each wrapped in a
    pre-norm residual branch with optional LayerScale and DropPath.

    :param dim: Feature dimension.
    :param num_heads: Number of attention heads.
    :param mlp_ratio: Hidden-dimension multiplier for the MLP.
    :param qkv_bias: Whether to include bias in the Q/KV projections.
    :param qk_norm: Whether to normalize per-head queries and keys.
    :param proj_drop: Dropout probability after projections and in the MLP.
    :param attn_drop: Dropout probability for attention weights.
    :param init_values: Initial LayerScale value (None disables LayerScale).
    :param drop_path: Stochastic-depth drop probability.
    :param act_layer: Activation layer for the MLP.
    :param norm_layer: Normalization layer.
    :param mlp_layer: MLP module class.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        qk_norm: bool = True,
        proj_drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values: float = None,
        drop_path: float = 0.0,
        act_layer: nn.Module = nn.GELU,
        norm_layer: nn.Module = nn.LayerNorm,
        mlp_layer: nn.Module = Mlp,
    ) -> None:
        super().__init__()
        # CONSISTENCY FIX: honor the norm_layer argument here as well.
        # Previously nn.LayerNorm was hard-coded for x_norm/z_norm while
        # norm2 used norm_layer; behavior is unchanged for the default
        # norm_layer=nn.LayerNorm (the only usage visible in this file).
        self.x_norm = norm_layer(dim)
        self.z_norm = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_norm=qk_norm,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
            norm_layer=norm_layer,
        )

        self.ls1 = (
            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        )
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        self.mlp = mlp_layer(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=proj_drop,
        )
        self.ls2 = (
            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        )
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        """
        Forward pass for a transformer block.

        :param x: Input tensor of shape (B, N, dim).
        :param z: Conditioning tensor attended to by the cross-attention.
        :return: Output tensor of shape (B, N, dim).
        """
        x = x + self.drop_path1(self.ls1(self.attn(self.x_norm(x), self.z_norm(z))))
        x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
        return x
164
+
165
+
166
class HistaugModel(nn.Module):
    """
    Augmentation transformer that conditions feature vectors on augmentation
    parameters.

    The (B, input_dim) input is split into ``chunk_size`` tokens of size
    ``embed_dim = input_dim // chunk_size``; the tokens cross-attend to
    embedded augmentation parameters through ``depth`` transformer blocks and
    are re-flattened into a (B, input_dim) output by a linear head.

    :param input_dim: Dimensionality of raw input features.
    :param depth: Number of transformer blocks.
    :param num_heads: Number of attention heads.
    :param mlp_ratio: Ratio for hidden features in MLP layers.
    :param use_transform_pos_embeddings: Whether to add sequence positional
        embeddings to the augmentation tokens.
    :param positional_encoding_type: 'learnable' or 'sinusoidal'.
    :param final_activation: Name of an activation class in ``torch.nn``
        applied after the output head's linear layer.
    :param chunk_size: Number of tokens the input vector is split into.
    :param transforms: Dict with a ``"parameters"`` mapping of augmentation
        names to their configuration.
    :param device: Unused; kept for backward compatibility.
    :param kwargs: Additional unused keyword arguments.
    """

    def __init__(
        self,
        input_dim,
        depth,
        num_heads,
        mlp_ratio,
        use_transform_pos_embeddings=True,
        positional_encoding_type="learnable",
        final_activation="Identity",
        chunk_size=16,
        transforms=None,
        device=torch.device("cpu"),
        **kwargs,
    ):
        super().__init__()
        # FIX: typo in the assert message ("divisble" -> "divisible").
        assert input_dim % chunk_size == 0, "input_dim must be divisible by chunk_size"

        self.input_dim = input_dim
        self.chunk_size = chunk_size
        self.transforms_parameters = transforms["parameters"]
        self.aug_param_names = sorted(self.transforms_parameters.keys())

        self.use_transform_pos_embeddings = use_transform_pos_embeddings
        self.positional_encoding_type = positional_encoding_type
        # Kept for feature-extractor API compatibility (the original assigned
        # num_classes twice; the duplicate assignment is removed).
        self.num_classes = 0
        self.num_features = 0

        self.embed_dim = self.input_dim // self.chunk_size
        # Fixed sinusoidal encodings for the feature chunks. Registered as a
        # buffer (same name as the original checkpoint key) so it follows
        # .to(device)/dtype casts; the redundant plain-attribute copy the
        # original also kept is dropped — it was not used anywhere.
        self.register_buffer(
            "chunk_pos_embeddings_buffer",
            self._get_sinusoidal_embeddings(self.chunk_size, self.embed_dim),
        )

        if use_transform_pos_embeddings:
            if positional_encoding_type == "learnable":
                self.sequence_pos_embedding = nn.Embedding(
                    len(transforms["parameters"]), self.embed_dim
                )
            elif positional_encoding_type == "sinusoidal":
                self.register_buffer(
                    "sequence_pos_embedding",
                    self._get_sinusoidal_embeddings(
                        len(transforms["parameters"]), self.embed_dim
                    ),
                )
            else:
                raise ValueError(
                    f"Invalid positional_encoding_type: {positional_encoding_type}. Choose 'learnable' or 'sinusoidal'."
                )
        else:
            print("Do not use transform positional embeddings")

        self.transform_embeddings = self._get_transforms_embeddings(
            transforms["parameters"], self.embed_dim
        )

        # NOTE(review): features_embed is never called in forward(); it is
        # kept so existing checkpoints (which contain its weights) still load.
        self.features_embed = nn.Sequential(
            nn.Linear(input_dim, self.embed_dim), nn.LayerNorm(self.embed_dim)
        )

        self.blocks = nn.ModuleList(
            [
                Block(dim=self.embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio)
                for _ in range(depth)
            ]
        )
        self.norm = nn.LayerNorm(self.embed_dim)

        if hasattr(nn, final_activation):
            self.final_activation = getattr(nn, final_activation)()
        else:
            raise ValueError(f"Activation {final_activation} is not found in torch.nn")

        self.head = nn.Sequential(
            nn.Linear(input_dim, input_dim), self.final_activation
        )

    def _get_sinusoidal_embeddings(self, num_positions, embed_dim):
        """
        Build standard sinusoidal positional encodings.

        :param num_positions: Number of positions to encode.
        :param embed_dim: Dimensionality of each embedding vector (must be even).
        :return: Tensor of shape (num_positions, embed_dim).
        """
        assert embed_dim % 2 == 0, "embed_dim must be even"
        position = torch.arange(0, num_positions, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, embed_dim, 2, dtype=torch.float)
            * (-math.log(10000.0) / embed_dim)
        )  # shape (embed_dim / 2,)

        pe = torch.zeros(num_positions, embed_dim)
        pe[:, 0::2] = torch.sin(position * div_term)  # even indices
        pe[:, 1::2] = torch.cos(position * div_term)  # odd indices
        return pe

    def _get_transforms_embeddings(self, transforms, embed_dim):
        """
        Create an embedding module per augmentation parameter.

        Discrete transforms (flips, rotation, crop, blur, morphology) get an
        ``nn.Embedding``; continuous ones get a small MLP over the scalar value.

        :param transforms: Mapping of augmentation names to configuration.
        :param embed_dim: Dimensionality of the embeddings.
        :return: ModuleDict of embeddings keyed by augmentation name.
        :raises ValueError: If an augmentation name is not recognized.
        """
        transform_embeddings = nn.ModuleDict()
        for aug_name in transforms:
            if aug_name in [
                "rotation",
                "h_flip",
                "v_flip",
                "gaussian_blur",
                "erosion",
                "dilation",
            ]:
                # Discrete: rotation has 4 states (0/90/180/270), others 2.
                transform_embeddings[aug_name] = nn.Embedding(
                    num_embeddings=2 if aug_name != "rotation" else 4,
                    embedding_dim=embed_dim,
                )
            elif aug_name in ["crop"]:
                # Discrete crop codes (0 = no crop, 1-5 = crop variants).
                transform_embeddings[aug_name] = nn.Embedding(
                    num_embeddings=6, embedding_dim=embed_dim
                )
            elif aug_name in [
                "brightness",
                "contrast",
                "saturation",
                "hed",
                "hue",
                "gamma",
            ]:
                # Continuous: embed the scalar strength via a 2-layer MLP.
                transform_embeddings[aug_name] = nn.Sequential(
                    nn.Linear(1, embed_dim * 2),
                    nn.SiLU(),
                    nn.Linear(embed_dim * 2, embed_dim),
                )
            else:
                raise ValueError(
                    f"{aug_name} is not a valid augmentation parameter name"
                )
        return transform_embeddings

    def forward_aug_params_embed(self, aug_params):
        """
        Embed augmentation parameters and add positional embeddings if enabled.

        :param aug_params: OrderedDict mapping augmentation names to
            (value_tensor, position_tensor) pairs.
        :return: Tensor of shape (B, K, embed_dim) of transform tokens.
        :raises ValueError: On unknown augmentation names or invalid
            positional_encoding_type.
        """
        z_transforms = []
        for aug_name, (aug_param, pos) in aug_params.items():
            if aug_name in [
                "rotation",
                "h_flip",
                "v_flip",
                "gaussian_blur",
                "erosion",
                "dilation",
                "crop",
            ]:
                z_transform = self.transform_embeddings[aug_name](aug_param)
            elif aug_name in [
                "brightness",
                "contrast",
                "saturation",
                "hue",
                "gamma",
                "hed",
            ]:
                # Continuous values go through the MLP as a trailing scalar dim.
                z_transform = self.transform_embeddings[aug_name](
                    aug_param[..., None].float()
                )
            else:
                raise ValueError(
                    f"{aug_name} is not a valid augmentation parameter name"
                )
            if self.use_transform_pos_embeddings:
                if self.positional_encoding_type == "learnable":
                    pos_index = torch.as_tensor(pos, device=aug_param.device)
                    pos_embedding = self.sequence_pos_embedding(pos_index)
                elif self.positional_encoding_type == "sinusoidal":
                    pos_embedding = self.sequence_pos_embedding[pos].to(
                        aug_param.device
                    )
                else:
                    raise ValueError(
                        f"Invalid positional_encoding_type: {self.positional_encoding_type}"
                    )
                z_transforms.append(z_transform + pos_embedding)
            else:
                z_transforms.append(z_transform)

        # Stack per-transform tokens along a new sequence dimension.
        return torch.stack(z_transforms, dim=1)

    def sample_aug_params(
        self,
        batch_size: int,
        device: torch.device = torch.device("cuda"),
        mode: Literal["instance_wise", "wsi_wise"] = "wsi_wise",
    ):
        """
        Sample random augmentation parameters and their relative positions.

        In "instance_wise" mode each batch element gets its own random values
        and ordering; in "wsi_wise" mode one sample is shared by the batch.
        Supported transforms missing from self.transforms_parameters are
        included with zero (no-op) values at unique tail positions so the
        downstream model always sees a consistent set of transforms.

        :param batch_size: Number of samples to draw.
        :param device: Device for the returned tensors.
        :param mode: "instance_wise" or "wsi_wise".
        :return: OrderedDict of name -> (value_tensor, position_tensor).
        :raises ValueError: On an invalid mode or unrecognized transform name.
        """
        if mode not in ("instance_wise", "wsi_wise"):
            raise ValueError('mode must be "instance_wise" or "wsi_wise"')

        supported_aug_names = [
            "rotation",
            "crop",
            "h_flip",
            "v_flip",
            "gaussian_blur",
            "erosion",
            "dilation",
            "brightness",
            "contrast",
            "saturation",
            "hue",
            "gamma",
            "hed",
        ]

        canonical_names = sorted(self.transforms_parameters.keys())
        num_transforms = len(canonical_names)

        # Missing supported transforms are appended later with zero values
        # (identity / no-op) at unique tail positions.
        missing_names = [n for n in supported_aug_names if n not in canonical_names]

        # Build a random ordering for the configured transforms only.
        if mode == "instance_wise":
            permutation_matrix = (
                torch.stack(
                    [
                        torch.randperm(num_transforms, device=device)
                        for _ in range(batch_size)
                    ],
                    dim=0,
                )
                if num_transforms > 0
                else torch.empty((batch_size, 0), dtype=torch.long, device=device)
            )
        else:  # wsi_wise: one shared permutation for the whole batch
            if num_transforms > 0:
                single_permutation = torch.randperm(num_transforms, device=device)
                permutation_matrix = single_permutation.unsqueeze(0).repeat(
                    batch_size, 1
                )
            else:
                permutation_matrix = torch.empty(
                    (batch_size, 0), dtype=torch.long, device=device
                )

        # argsort of a permutation yields each transform's position in it.
        positions_matrix = (
            torch.argsort(permutation_matrix, dim=1)
            if num_transforms > 0
            else torch.empty((batch_size, 0), dtype=torch.long, device=device)
        )

        augmentation_parameters = OrderedDict()
        for transform_index, name in enumerate(canonical_names):
            config = self.transforms_parameters[name]

            if name == "rotation":
                probability = float(config)
                if mode == "instance_wise":
                    apply_mask = torch.rand(batch_size, device=device) < probability
                    random_angles = torch.randint(0, 4, (batch_size,), device=device)
                    random_angles[~apply_mask] = 0
                    value_tensor = random_angles
                else:
                    apply = random.random() < probability
                    angle = random.randint(1, 3) if apply else 0
                    value_tensor = torch.full(
                        (batch_size,), angle, dtype=torch.int64, device=device
                    )

            elif name == "crop":
                probability = float(config)
                if mode == "instance_wise":
                    apply_mask = torch.rand(batch_size, device=device) < probability
                    random_crops = torch.randint(0, 5, (batch_size,), device=device)
                    random_crops[~apply_mask] = 0
                    value_tensor = random_crops
                else:
                    apply = random.random() < probability
                    crop_code = random.randint(1, 4) if apply else 0
                    value_tensor = torch.full(
                        (batch_size,), crop_code, dtype=torch.int64, device=device
                    )

            elif name in ("h_flip", "v_flip", "gaussian_blur", "erosion", "dilation"):
                probability = float(config)
                if mode == "instance_wise":
                    value_tensor = (
                        torch.rand(batch_size, device=device) < probability
                    ).int()
                else:
                    bit = int(random.random() < probability)
                    value_tensor = torch.full(
                        (batch_size,), bit, dtype=torch.int32, device=device
                    )

            elif name in (
                "brightness",
                "contrast",
                "saturation",
                "hue",
                "gamma",
                "hed",
            ):
                lower_bound, upper_bound = map(float, config)
                if mode == "instance_wise":
                    value_tensor = torch.empty(batch_size, device=device).uniform_(
                        lower_bound, upper_bound
                    )
                else:
                    scalar_value = random.uniform(lower_bound, upper_bound)
                    value_tensor = torch.full(
                        (batch_size,), scalar_value, dtype=torch.float32, device=device
                    )

            else:
                raise ValueError(f"'{name}' is not a recognised augmentation name")

            position_tensor = positions_matrix[:, transform_index]
            augmentation_parameters[name] = (value_tensor, position_tensor)

        # Append missing transforms with no-op values at tail positions.
        for i, name in enumerate(missing_names):
            if name in ("rotation", "crop"):
                zeros = torch.zeros(batch_size, dtype=torch.int64, device=device)
            elif name in ("h_flip", "v_flip", "gaussian_blur", "erosion", "dilation"):
                zeros = torch.zeros(batch_size, dtype=torch.int32, device=device)
            else:  # continuous
                zeros = torch.zeros(batch_size, dtype=torch.float32, device=device)

            tail_pos = num_transforms + i  # unique: K, K+1, ..., K+M-1
            pos = torch.full((batch_size,), tail_pos, dtype=torch.long, device=device)
            augmentation_parameters[name] = (zeros, pos)

        return augmentation_parameters

    def forward(self, x, aug_params, **kwargs):
        """
        Embed features, apply transformer blocks, and produce the output.

        :param x: Input tensor of shape (B, input_dim).
        :param aug_params: Augmentation parameters from sample_aug_params.
        :return: Output tensor of shape (B, input_dim).
        """
        # Split each feature vector into chunk_size tokens of embed_dim each
        # (the original inserted a throwaway singleton dim first; removed).
        x = x.view(x.shape[0], self.chunk_size, self.embed_dim)
        x = x + self.chunk_pos_embeddings_buffer.unsqueeze(0)
        z = self.forward_aug_params_embed(aug_params)

        for block in self.blocks:
            x = block(x, z)
        x = self.norm(x)

        # Re-flatten the tokens into one vector; equivalent to the original
        # view(B, 1, -1) -> head -> [:, 0, :] since Linear acts on the last dim.
        x = x.reshape(x.shape[0], -1)
        x = self.head(x)
        return x
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:67e387d4200fb0feed782ebc60cd8084a884a82a3fa928661b3d96f5b1a268a0
3
+ size 14881744
modeling_histaug.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from transformers import PreTrainedModel
4
+
5
+ from .configuration_histaug import HistaugConfig
6
+ from .histaug_model import HistaugModel
7
+
8
class HistaugPretrainedModel(PreTrainedModel):
    """
    Hugging Face ``PreTrainedModel`` wrapper around :class:`HistaugModel`.

    The wrapped model is put in eval mode and fully frozen at construction
    time: this wrapper exposes a fixed, inference-only augmentation module.
    """

    config_class = HistaugConfig

    def __init__(self, config: HistaugConfig, *model_args, **model_kwargs):
        super().__init__(config)

        # Instantiate the core model using values from the config.
        self.histaug = HistaugModel(
            input_dim=config.input_dim,
            depth=config.depth,
            num_heads=config.num_heads,
            mlp_ratio=config.mlp_ratio,
            use_transform_pos_embeddings=config.use_transform_pos_embeddings,
            positional_encoding_type=config.positional_encoding_type,
            final_activation=config.final_activation,
            embedding_type=config.embedding_type,
            chunk_size=config.chunk_size,
            transforms=config.transforms,
            **model_kwargs,
        )

        self.post_init()

        # Freeze all parameters: the module is not meant to be fine-tuned.
        self.histaug.eval()
        for p in self.histaug.parameters():
            p.requires_grad = False

    def forward(self, x: torch.Tensor, aug_params, **kwargs) -> torch.Tensor:
        """
        Forward pass through the wrapped histaug model.

        :param x: Input tensor of shape (batch_size, input_dim).
        :param aug_params: Augmentation parameters dict as expected by
            :meth:`HistaugModel.forward` (see :meth:`sample_aug_params`).
        :return: Output tensor of shape (batch_size, input_dim).
        """
        return self.histaug(x, aug_params, **kwargs)

    def sample_aug_params(
        self,
        batch_size: int,
        device: torch.device = None,
        mode: str = "wsi_wise",
    ):
        """
        Proxy to :meth:`HistaugModel.sample_aug_params`.

        When ``device`` is None, defaults to CUDA if available, else CPU.
        """
        device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
        return self.histaug.sample_aug_params(
            batch_size=batch_size, device=device, mode=mode
        )

    # NOTE(review): the original also overrode save_pretrained and
    # from_pretrained with bodies that only delegated to super(); those
    # redundant overrides are removed — the inherited implementations
    # behave identically.