Create modeling_genbio_pathfm.py

#2
Files changed (1) hide show
  1. modeling_genbio_pathfm.py +566 -0
modeling_genbio_pathfm.py ADDED
@@ -0,0 +1,566 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """GenBio-PathFM modeling for HuggingFace AutoModel.
2
+
3
+ This file is intended to live in the HuggingFace repo at
4
+ ``genbio-ai/genbio-pathfm`` so that users can load the model with:
5
+
6
+ from transformers import AutoModel
7
+ model = AutoModel.from_pretrained("genbio-ai/genbio-pathfm", trust_remote_code=True)
8
+ """
9
+
10
+ import math
11
+ from functools import partial
12
+ from typing import Callable, Dict, List, Literal, Optional, Tuple, Union
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from torch import Tensor
18
+ from transformers import PreTrainedModel
19
+
20
+ from .configuration_genbio_pathfm import GenBioPathFMConfig
21
+
22
+
23
+ # ──────────────────────────────────────────────────────────────
24
+ # Helpers
25
+ # ──────────────────────────────────────────────────────────────
26
+
27
+ def _cat_keep_shapes(x_list: List[Tensor]) -> Tuple[Tensor, List[Tuple[int, ...]], List[int]]:
28
+ shapes = [x.shape for x in x_list]
29
+ num_tokens = [x.select(dim=-1, index=0).numel() for x in x_list]
30
+ flattened = torch.cat([x.flatten(0, -2) for x in x_list])
31
+ return flattened, shapes, num_tokens
32
+
33
+
34
+ def _uncat_with_shapes(
35
+ flattened: Tensor,
36
+ shapes: List[Tuple[int, ...]],
37
+ num_tokens: List[int],
38
+ ) -> List[Tensor]:
39
+ outputs_splitted = torch.split_with_sizes(flattened, num_tokens, dim=0)
40
+ shapes_adjusted = [shape[:-1] + torch.Size([flattened.shape[-1]]) for shape in shapes]
41
+ return [o.reshape(s) for o, s in zip(outputs_splitted, shapes_adjusted)]
42
+
43
+
44
+ def _mlp_forward_list(mlp: nn.Module, x_list: List[Tensor]) -> List[Tensor]:
45
+ x_flat, shapes, num_tokens = _cat_keep_shapes(x_list)
46
+ x_flat = mlp(x_flat)
47
+ return _uncat_with_shapes(x_flat, shapes, num_tokens)
48
+
49
+
50
+ def _make_2tuple(x):
51
+ if isinstance(x, tuple):
52
+ assert len(x) == 2
53
+ return x
54
+ assert isinstance(x, int)
55
+ return (x, x)
56
+
57
+
58
+ # ──────────────────────────────────────────────────────────────
59
+ # LayerScale
60
+ # ──────────────────────────────────────────────────────────────
61
+
62
+ class LayerScale(nn.Module):
63
+ def __init__(self, dim: int, init_values: Union[float, Tensor] = 1e-5, inplace: bool = False, device=None):
64
+ super().__init__()
65
+ self.inplace = inplace
66
+ self.gamma = nn.Parameter(torch.empty(dim, device=device))
67
+ self.init_values = init_values
68
+
69
+ def reset_parameters(self):
70
+ nn.init.constant_(self.gamma, self.init_values)
71
+
72
+ def forward(self, x: Tensor) -> Tensor:
73
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
74
+
75
+
76
+ # ──────────────────────────────────────────────────────────────
77
+ # FFN layers
78
+ # ──────────────────────────────────────────────────────────────
79
+
80
+ class Mlp(nn.Module):
81
+ def __init__(self, in_features, hidden_features=None, out_features=None,
82
+ act_layer=nn.GELU, drop=0.0, bias=True, device=None):
83
+ super().__init__()
84
+ out_features = out_features or in_features
85
+ hidden_features = hidden_features or in_features
86
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias, device=device)
87
+ self.act = act_layer()
88
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias, device=device)
89
+ self.drop = nn.Dropout(drop)
90
+
91
+ def forward(self, x: Tensor) -> Tensor:
92
+ x = self.fc1(x)
93
+ x = self.act(x)
94
+ x = self.drop(x)
95
+ x = self.fc2(x)
96
+ x = self.drop(x)
97
+ return x
98
+
99
+
100
+ class SwiGLUFFN(nn.Module):
101
+ def __init__(self, in_features, hidden_features=None, out_features=None,
102
+ act_layer=None, drop=0.0, bias=True, align_to=8, device=None):
103
+ super().__init__()
104
+ out_features = out_features or in_features
105
+ hidden_features = hidden_features or in_features
106
+ d = int(hidden_features * 2 / 3)
107
+ h = d + (-d % align_to)
108
+ self.w1 = nn.Linear(in_features, h, bias=bias, device=device)
109
+ self.w2 = nn.Linear(in_features, h, bias=bias, device=device)
110
+ self.w3 = nn.Linear(h, out_features, bias=bias, device=device)
111
+
112
+ def forward(self, x: Tensor) -> Tensor:
113
+ return self.w3(F.silu(self.w1(x)) * self.w2(x))
114
+
115
+
116
+ # ──────────────────────────────────────────────────────────────
117
+ # PatchEmbed
118
+ # ──────────────────────────────────────────────────────────────
119
+
120
+ class PatchEmbed(nn.Module):
121
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768,
122
+ norm_layer=None, flatten_embedding=True):
123
+ super().__init__()
124
+ image_HW = _make_2tuple(img_size)
125
+ patch_HW = _make_2tuple(patch_size)
126
+ patch_grid_size = (image_HW[0] // patch_HW[0], image_HW[1] // patch_HW[1])
127
+
128
+ self.img_size = image_HW
129
+ self.patch_size = patch_HW
130
+ self.patches_resolution = patch_grid_size
131
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
132
+ self.in_chans = in_chans
133
+ self.embed_dim = embed_dim
134
+ self.flatten_embedding = flatten_embedding
135
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
136
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
137
+
138
+ def forward(self, x: Tensor) -> Tensor:
139
+ x = self.proj(x)
140
+ H, W = x.size(2), x.size(3)
141
+ x = x.flatten(2).transpose(1, 2)
142
+ x = self.norm(x)
143
+ if not self.flatten_embedding:
144
+ x = x.reshape(-1, H, W, self.embed_dim)
145
+ return x
146
+
147
+ def reset_parameters(self):
148
+ k = 1 / (self.in_chans * (self.patch_size[0] ** 2))
149
+ nn.init.uniform_(self.proj.weight, -math.sqrt(k), math.sqrt(k))
150
+ if self.proj.bias is not None:
151
+ nn.init.uniform_(self.proj.bias, -math.sqrt(k), math.sqrt(k))
152
+
153
+
154
+ # ──────────────────────────────────────────────────────────────
155
+ # RoPE
156
+ # ──────────────────────────────────────────────────────────────
157
+
158
+ class RopePositionEmbedding(nn.Module):
159
+ def __init__(self, embed_dim, *, num_heads, base=100.0, min_period=None,
160
+ max_period=None, normalize_coords="separate", shift_coords=None,
161
+ jitter_coords=None, rescale_coords=None, dtype=None, device=None):
162
+ super().__init__()
163
+ assert embed_dim % (4 * num_heads) == 0
164
+ both_periods = min_period is not None and max_period is not None
165
+ if (base is None and not both_periods) or (base is not None and both_periods):
166
+ raise ValueError("Provide either `base` or both `min_period`+`max_period`.")
167
+
168
+ D_head = embed_dim // num_heads
169
+ self.base = base
170
+ self.min_period = min_period
171
+ self.max_period = max_period
172
+ self.D_head = D_head
173
+ self.normalize_coords = normalize_coords
174
+ self.shift_coords = shift_coords
175
+ self.jitter_coords = jitter_coords
176
+ self.rescale_coords = rescale_coords
177
+ self.dtype = dtype
178
+ self.register_buffer("periods", torch.empty(D_head // 4, device=device, dtype=dtype), persistent=True)
179
+ self._init_weights()
180
+
181
+ def _init_weights(self):
182
+ device = self.periods.device
183
+ dtype = self.dtype
184
+ if self.base is not None:
185
+ periods = self.base ** (
186
+ 2 * torch.arange(self.D_head // 4, device=device, dtype=dtype) / (self.D_head // 2)
187
+ )
188
+ else:
189
+ base = self.max_period / self.min_period
190
+ exponents = torch.linspace(0, 1, self.D_head // 4, device=device, dtype=dtype)
191
+ periods = (base ** exponents) / base * self.max_period
192
+ self.periods.data = periods
193
+
194
+ def forward(self, *, H: int, W: int) -> Tuple[Tensor, Tensor]:
195
+ device, dtype = self.periods.device, self.dtype
196
+ dd = {"device": device, "dtype": dtype}
197
+
198
+ if self.normalize_coords == "max":
199
+ m = max(H, W)
200
+ coords_h = torch.arange(0.5, H, **dd) / m
201
+ coords_w = torch.arange(0.5, W, **dd) / m
202
+ elif self.normalize_coords == "min":
203
+ m = min(H, W)
204
+ coords_h = torch.arange(0.5, H, **dd) / m
205
+ coords_w = torch.arange(0.5, W, **dd) / m
206
+ else:
207
+ coords_h = torch.arange(0.5, H, **dd) / H
208
+ coords_w = torch.arange(0.5, W, **dd) / W
209
+
210
+ coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1)
211
+ coords = coords.flatten(0, 1)
212
+ coords = 2.0 * coords - 1.0
213
+
214
+ if self.training and self.shift_coords is not None:
215
+ coords += torch.empty(2, **dd).uniform_(-self.shift_coords, self.shift_coords)
216
+ if self.training and self.jitter_coords is not None:
217
+ jmax = math.log(self.jitter_coords)
218
+ coords *= torch.empty(2, **dd).uniform_(-jmax, jmax).exp()
219
+ if self.training and self.rescale_coords is not None:
220
+ rmax = math.log(self.rescale_coords)
221
+ coords *= torch.empty(1, **dd).uniform_(-rmax, rmax).exp()
222
+
223
+ angles = 2 * math.pi * coords[:, :, None] / self.periods[None, None, :]
224
+ angles = angles.flatten(1, 2).tile(2)
225
+ return torch.sin(angles), torch.cos(angles)
226
+
227
+
228
+ # ──────────────────────────────────────────────────────────────
229
+ # Attention
230
+ # ──────────────────────────────────────────────────────────────
231
+
232
+ def _rope_rotate_half(x: Tensor) -> Tensor:
233
+ x1, x2 = x.chunk(2, dim=-1)
234
+ return torch.cat([-x2, x1], dim=-1)
235
+
236
+
237
+ def _rope_apply(x: Tensor, sin: Tensor, cos: Tensor) -> Tensor:
238
+ return (x * cos) + (_rope_rotate_half(x) * sin)
239
+
240
+
241
+ class SelfAttention(nn.Module):
242
+ def __init__(self, dim, num_heads=8, qkv_bias=False, proj_bias=True,
243
+ attn_drop=0.0, proj_drop=0.0, device=None):
244
+ super().__init__()
245
+ self.num_heads = num_heads
246
+ self.scale = (dim // num_heads) ** -0.5
247
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias, device=device)
248
+ self.attn_drop = nn.Dropout(attn_drop)
249
+ self.proj = nn.Linear(dim, dim, bias=proj_bias, device=device)
250
+ self.proj_drop = nn.Dropout(proj_drop)
251
+
252
+ def _apply_rope(self, q, k, rope):
253
+ q_dtype, k_dtype = q.dtype, k.dtype
254
+ sin, cos = rope
255
+ q = q.to(sin.dtype)
256
+ k = k.to(sin.dtype)
257
+ prefix = q.shape[-2] - sin.shape[-2]
258
+ assert prefix >= 0
259
+ q = torch.cat((q[:, :, :prefix], _rope_apply(q[:, :, prefix:], sin, cos)), dim=-2)
260
+ k = torch.cat((k[:, :, :prefix], _rope_apply(k[:, :, prefix:], sin, cos)), dim=-2)
261
+ return q.to(q_dtype), k.to(k_dtype)
262
+
263
+ def compute_attention(self, qkv, attn_bias=None, rope=None):
264
+ assert attn_bias is None
265
+ B, N, _ = qkv.shape
266
+ C = self.qkv.in_features
267
+ qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads)
268
+ q, k, v = [t.transpose(1, 2) for t in torch.unbind(qkv, 2)]
269
+ if rope is not None:
270
+ q, k = self._apply_rope(q, k, rope)
271
+ x = F.scaled_dot_product_attention(q, k, v)
272
+ return x.transpose(1, 2).reshape(B, N, C)
273
+
274
+ def forward(self, x, attn_bias=None, rope=None):
275
+ x = self.proj(self.compute_attention(self.qkv(x), attn_bias=attn_bias, rope=rope))
276
+ return self.proj_drop(x)
277
+
278
+ def forward_list(self, x_list, attn_bias=None, rope_list=None):
279
+ x_flat, shapes, num_tokens = _cat_keep_shapes(x_list)
280
+ qkv_flat = self.qkv(x_flat)
281
+ qkv_list = _uncat_with_shapes(qkv_flat, shapes, num_tokens)
282
+ att_out = [
283
+ self.compute_attention(qkv, attn_bias=attn_bias, rope=rope)
284
+ for qkv, rope in zip(qkv_list, rope_list)
285
+ ]
286
+ x_flat, shapes, num_tokens = _cat_keep_shapes(att_out)
287
+ return _uncat_with_shapes(self.proj(x_flat), shapes, num_tokens)
288
+
289
+
290
+ # ──────────────────────────────────────────────────────────────
291
+ # Transformer block
292
+ # ──────────────────────────────────────────────────────────────
293
+
294
+ class SelfAttentionBlock(nn.Module):
295
+ def __init__(self, dim, num_heads, ffn_ratio=4.0, qkv_bias=False,
296
+ proj_bias=True, ffn_bias=True, drop=0.0, attn_drop=0.0,
297
+ init_values=None, drop_path=0.0, act_layer=nn.GELU,
298
+ norm_layer=nn.LayerNorm, attn_class=SelfAttention,
299
+ ffn_layer=Mlp, device=None):
300
+ super().__init__()
301
+ self.norm1 = norm_layer(dim)
302
+ self.attn = attn_class(
303
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, proj_bias=proj_bias,
304
+ attn_drop=attn_drop, proj_drop=drop, device=device,
305
+ )
306
+ self.ls1 = LayerScale(dim, init_values=init_values, device=device) if init_values else nn.Identity()
307
+ self.norm2 = norm_layer(dim)
308
+ self.mlp = ffn_layer(
309
+ in_features=dim, hidden_features=int(dim * ffn_ratio),
310
+ act_layer=act_layer, drop=drop, bias=ffn_bias, device=device,
311
+ )
312
+ self.ls2 = LayerScale(dim, init_values=init_values, device=device) if init_values else nn.Identity()
313
+ self.sample_drop_ratio = drop_path
314
+
315
+ @staticmethod
316
+ def _maybe_index_rope(rope, indices):
317
+ if rope is None:
318
+ return None
319
+ sin, cos = rope
320
+ if sin.ndim == 4:
321
+ return sin[indices], cos[indices]
322
+ return sin, cos
323
+
324
+ def _forward_list(self, x_list, rope_list=None):
325
+ if self.training and self.sample_drop_ratio > 0.0:
326
+ b_list = [x.shape[0] for x in x_list]
327
+ ss = [max(int(b * (1 - self.sample_drop_ratio)), 1) for b in b_list]
328
+ rsf = [b / s for b, s in zip(b_list, ss)]
329
+
330
+ idx1 = [(torch.randperm(b, device=x.device))[:s] for x, b, s in zip(x_list, b_list, ss)]
331
+ sub1 = [x[i] for x, i in zip(x_list, idx1)]
332
+ rope_sub = [self._maybe_index_rope(r, i) for r, i in zip(rope_list, idx1)] if rope_list else rope_list
333
+
334
+ flat, shapes, nt = _cat_keep_shapes(sub1)
335
+ norm1 = _uncat_with_shapes(self.norm1(flat), shapes, nt)
336
+ res1 = self.attn.forward_list(norm1, rope_list=rope_sub)
337
+
338
+ x_attn = [
339
+ torch.index_add(x, 0, i, self.ls1(r), alpha=f)
340
+ for x, r, i, f in zip(x_list, res1, idx1, rsf)
341
+ ]
342
+ idx2 = [(torch.randperm(b, device=x.device))[:s] for x, b, s in zip(x_attn, b_list, ss)]
343
+ sub2 = [x[i] for x, i in zip(x_attn, idx2)]
344
+ flat2, shapes2, nt2 = _cat_keep_shapes(sub2)
345
+ res2 = _mlp_forward_list(self.mlp, _uncat_with_shapes(self.norm2(flat2), shapes2, nt2))
346
+
347
+ return [
348
+ torch.index_add(xa, 0, i, self.ls2(r), alpha=f)
349
+ for xa, r, i, f in zip(x_attn, res2, idx2, rsf)
350
+ ]
351
+ else:
352
+ out = []
353
+ for x, rope in zip(x_list, rope_list):
354
+ x = x + self.ls1(self.attn(self.norm1(x), rope=rope))
355
+ x = x + self.ls2(self.mlp(self.norm2(x)))
356
+ out.append(x)
357
+ return out
358
+
359
+ def forward(self, x_or_list, rope_or_list=None):
360
+ if isinstance(x_or_list, Tensor):
361
+ return self._forward_list([x_or_list], rope_list=[rope_or_list])[0]
362
+ elif isinstance(x_or_list, list):
363
+ if rope_or_list is None:
364
+ rope_or_list = [None] * len(x_or_list)
365
+ return self._forward_list(x_or_list, rope_list=rope_or_list)
366
+ raise AssertionError
367
+
368
+
369
+ # ──────────────────────────────────────────────────────────────
370
+ # Backbone ViT
371
+ # ──────────────────────────────────────────────────────────────
372
+
373
+ _FFN_LAYERS = {
374
+ "mlp": Mlp,
375
+ "swiglu": SwiGLUFFN,
376
+ "swiglu32": partial(SwiGLUFFN, align_to=32),
377
+ "swiglu64": partial(SwiGLUFFN, align_to=64),
378
+ "swiglu128": partial(SwiGLUFFN, align_to=128),
379
+ }
380
+ _NORM_LAYERS = {
381
+ "layernorm": partial(nn.LayerNorm, eps=1e-6),
382
+ "layernormbf16": partial(nn.LayerNorm, eps=1e-5),
383
+ }
384
+ _DTYPES = {
385
+ "fp32": torch.float32,
386
+ "fp16": torch.float16,
387
+ "bf16": torch.bfloat16,
388
+ }
389
+
390
+
391
+ class VisionTransformer(nn.Module):
392
+ def __init__(self, *, img_size=224, patch_size=16, in_chans=1,
393
+ pos_embed_rope_base=100.0, pos_embed_rope_min_period=None,
394
+ pos_embed_rope_max_period=None, pos_embed_rope_normalize_coords="separate",
395
+ pos_embed_rope_shift_coords=None, pos_embed_rope_jitter_coords=None,
396
+ pos_embed_rope_rescale_coords=None, pos_embed_rope_dtype="bf16",
397
+ embed_dim=768, depth=12, num_heads=12, ffn_ratio=3.0,
398
+ qkv_bias=True, drop_path_rate=0.0, layerscale_init=None,
399
+ norm_layer="layernorm", ffn_layer="swiglu64", ffn_bias=True,
400
+ proj_bias=True, n_storage_tokens=4, device=None, **ignored_kwargs):
401
+ super().__init__()
402
+ norm_layer_cls = _NORM_LAYERS[norm_layer]
403
+ ffn_layer_cls = _FFN_LAYERS[ffn_layer]
404
+
405
+ self.num_features = self.embed_dim = embed_dim
406
+ self.n_blocks = depth
407
+ self.num_heads = num_heads
408
+ self.patch_size = patch_size
409
+ self.n_storage_tokens = n_storage_tokens
410
+
411
+ self.patch_embed = PatchEmbed(
412
+ img_size=img_size, patch_size=patch_size,
413
+ in_chans=in_chans, embed_dim=embed_dim, flatten_embedding=False,
414
+ )
415
+ self.cls_token = nn.Parameter(torch.empty(1, 1, embed_dim, device=device))
416
+ if n_storage_tokens > 0:
417
+ self.storage_tokens = nn.Parameter(torch.empty(1, n_storage_tokens, embed_dim, device=device))
418
+
419
+ self.rope_embed = RopePositionEmbedding(
420
+ embed_dim=embed_dim, num_heads=num_heads,
421
+ base=pos_embed_rope_base,
422
+ min_period=pos_embed_rope_min_period,
423
+ max_period=pos_embed_rope_max_period,
424
+ normalize_coords=pos_embed_rope_normalize_coords,
425
+ shift_coords=pos_embed_rope_shift_coords,
426
+ jitter_coords=pos_embed_rope_jitter_coords,
427
+ rescale_coords=pos_embed_rope_rescale_coords,
428
+ dtype=_DTYPES[pos_embed_rope_dtype],
429
+ device=device,
430
+ )
431
+
432
+ self.blocks = nn.ModuleList([
433
+ SelfAttentionBlock(
434
+ dim=embed_dim, num_heads=num_heads, ffn_ratio=ffn_ratio,
435
+ qkv_bias=qkv_bias, proj_bias=proj_bias, ffn_bias=ffn_bias,
436
+ drop_path=drop_path_rate, norm_layer=norm_layer_cls,
437
+ act_layer=nn.GELU, ffn_layer=ffn_layer_cls,
438
+ init_values=layerscale_init, device=device,
439
+ )
440
+ for _ in range(depth)
441
+ ])
442
+ self.norm = norm_layer_cls(embed_dim)
443
+
444
+ def prepare_tokens(self, x):
445
+ x = self.patch_embed(x)
446
+ B, H, W, _ = x.shape
447
+ x = x.flatten(1, 2)
448
+ ct = self.cls_token
449
+ st = self.storage_tokens if self.n_storage_tokens > 0 else torch.empty(
450
+ 1, 0, ct.shape[-1], dtype=ct.dtype, device=ct.device
451
+ )
452
+ x = torch.cat([ct.expand(B, -1, -1), st.expand(B, -1, -1), x], dim=1)
453
+ return x, (H, W)
454
+
455
+ def forward_features(self, x):
456
+ tokens, (H, W) = self.prepare_tokens(x)
457
+ rope = self.rope_embed(H=H, W=W)
458
+ for blk in self.blocks:
459
+ tokens = blk(tokens, rope)
460
+ tokens = self.norm(tokens)
461
+ n = self.n_storage_tokens
462
+ return {
463
+ "x_norm_clstoken": tokens[:, 0],
464
+ "x_storage_tokens": tokens[:, 1:n + 1],
465
+ "x_norm_patchtokens": tokens[:, n + 1:],
466
+ "x_prenorm": tokens,
467
+ }
468
+
469
+ def forward(self, x):
470
+ return self.forward_features(x)
471
+
472
+
473
+ # ──────────────────────────────────────────────────────────────
474
+ # HuggingFace PreTrainedModel wrapper
475
+ # ──────────────────────────────────────────────────────────────
476
+
477
+ class GenBioPathFMModel(PreTrainedModel):
478
+ """
479
+ GenBio-PathFM wrapped as a HuggingFace ``PreTrainedModel``.
480
+
481
+ Usage::
482
+
483
+ from transformers import AutoModel
484
+ model = AutoModel.from_pretrained("genbio-ai/genbio-pathfm", trust_remote_code=True)
485
+
486
+ # CLS-only: [B, embed_dim*3]
487
+ cls_features = model(rgb_tensor)
488
+
489
+ # CLS + patch tokens:
490
+ cls_features, patch_features = model.forward_with_patches(rgb_tensor)
491
+ """
492
+
493
+ config_class = GenBioPathFMConfig
494
+
495
+ def __init__(self, config: GenBioPathFMConfig):
496
+ super().__init__(config)
497
+ self.backbone = VisionTransformer(
498
+ img_size=config.img_size,
499
+ patch_size=config.patch_size,
500
+ embed_dim=config.embed_dim,
501
+ depth=config.depth,
502
+ num_heads=config.num_heads,
503
+ ffn_ratio=config.ffn_ratio,
504
+ in_chans=config.in_chans,
505
+ n_storage_tokens=config.n_storage_tokens,
506
+ ffn_layer=config.ffn_layer,
507
+ layerscale_init=config.layerscale_init,
508
+ qkv_bias=config.qkv_bias,
509
+ proj_bias=config.proj_bias,
510
+ ffn_bias=config.ffn_bias,
511
+ norm_layer=config.norm_layer,
512
+ drop_path_rate=config.drop_path_rate,
513
+ pos_embed_rope_base=config.pos_embed_rope_base,
514
+ pos_embed_rope_min_period=config.pos_embed_rope_min_period,
515
+ pos_embed_rope_max_period=config.pos_embed_rope_max_period,
516
+ pos_embed_rope_normalize_coords=config.pos_embed_rope_normalize_coords,
517
+ pos_embed_rope_shift_coords=config.pos_embed_rope_shift_coords,
518
+ pos_embed_rope_jitter_coords=config.pos_embed_rope_jitter_coords,
519
+ pos_embed_rope_rescale_coords=config.pos_embed_rope_rescale_coords,
520
+ pos_embed_rope_dtype=config.pos_embed_rope_dtype,
521
+ )
522
+
523
+ def _encode(self, x: Tensor) -> Dict[str, Tensor]:
524
+ """Encode single-channel [B, 1, H, W] images."""
525
+ tokens, (h, w) = self.backbone.prepare_tokens(x)
526
+ rope = self.backbone.rope_embed(H=h, W=w)
527
+ for blk in self.backbone.blocks:
528
+ tokens = blk(tokens, rope)
529
+ tokens = self.backbone.norm(tokens)
530
+ return {
531
+ "x_norm_clstoken": tokens[:, 0],
532
+ "x_norm_patchtokens": tokens[:, 1 + self.config.n_storage_tokens:],
533
+ }
534
+
535
+ def forward(self, pixel_values: Tensor, **kwargs) -> Tensor:
536
+ """
537
+ Args:
538
+ pixel_values: ``[B, 3, H, W]`` RGB images (already normalized).
539
+
540
+ Returns:
541
+ ``[B, embed_dim * 3]`` – CLS features from R, G, B channels
542
+ concatenated along the feature dimension.
543
+ """
544
+ b, _c, h, w = pixel_values.shape
545
+ features = self._encode(pixel_values.view(b * 3, 1, h, w))
546
+ cls = features["x_norm_clstoken"].view(b, 3, -1)
547
+ return torch.cat([cls[:, 0], cls[:, 1], cls[:, 2]], dim=-1)
548
+
549
+ def forward_with_patches(self, pixel_values: Tensor) -> Tuple[Tensor, Tensor]:
550
+ """
551
+ Returns:
552
+ cls_out: ``[B, embed_dim * 3]``
553
+ patch_out: ``[B, num_patches, embed_dim * 3]``
554
+ """
555
+ b, _c, h, w = pixel_values.shape
556
+ features = self._encode(pixel_values.view(b * 3, 1, h, w))
557
+
558
+ cls = features["x_norm_clstoken"].view(b, 3, -1)
559
+ cls_out = torch.cat([cls[:, 0], cls[:, 1], cls[:, 2]], dim=-1)
560
+
561
+ patches = features["x_norm_patchtokens"]
562
+ n, d = patches.shape[1], patches.shape[2]
563
+ patches = patches.view(b, 3, n, d)
564
+ patch_out = torch.cat([patches[:, 0], patches[:, 1], patches[:, 2]], dim=-1)
565
+
566
+ return cls_out, patch_out