matejpekar committed
Commit 389a8d6 · verified · 1 Parent(s): 835482b

Upload model

Files changed (4):
  1. config.json +32 -19
  2. configuration.py +24 -14
  3. model.safetensors +2 -2
  4. modeling.py +292 -363
config.json CHANGED
@@ -16,31 +16,44 @@
       "stage4"
     ]
   },
-  "depths": [
-    6,
-    2,
-    2
+  "cross_sta_config": [
+    {
+      "kernel": 5,
+      "kv_tile": 8,
+      "q_tile": 3
+    },
+    {
+      "kernel": 5,
+      "kv_tile": 4,
+      "q_tile": 3
+    },
+    {
+      "kernel": 5,
+      "kv_tile": 2,
+      "q_tile": 3
+    }
   ],
   "dim": 384,
-  "dropout": 0.1,
+  "feature_levels": [
+    2,
+    1,
+    0,
+    2,
+    1,
+    0
+  ],
   "model_type": "lsp_detr",
-  "num_classes": 1,
+  "num_classes": 5,
   "num_heads": 12,
   "num_radial_distances": 64,
   "query_block_size": 14.222222222222223,
-  "src_window_sizes": [
-    8,
-    16,
-    32
-  ],
-  "tgt_window_sizes": [
-    9,
-    9,
-    9
-  ],
+  "self_sta_config": {
+    "kernel": 3,
+    "kv_tile": 3,
+    "q_tile": 3
+  },
   "torch_dtype": "float32",
-  "transformers_version": "4.51.3",
+  "transformers_version": "4.52.3",
   "use_pretrained_backbone": true,
-  "use_timm_backbone": false,
-  "window_size": 9
+  "use_timm_backbone": false
 }
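For orientation, the new STA (presumably sliding-tile attention) keys can be read straight from the JSON. A minimal sketch, assuming config.json has been downloaded to the working directory:

import json

with open("config.json") as f:
    config = json.load(f)

# One STA config per feature level for cross-attention,
# and a single shared STA config for self-attention.
for level, sta in enumerate(config["cross_sta_config"]):
    print(level, sta["kernel"], sta["q_tile"], sta["kv_tile"])
print(config["self_sta_config"])  # {'kernel': 3, 'kv_tile': 3, 'q_tile': 3}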
configuration.py CHANGED
@@ -1,9 +1,15 @@
-from typing import Any
+from typing import Any, TypedDict

 from transformers import PretrainedConfig
 from transformers.utils.backbone_utils import verify_backbone_config_arguments


+class STAConfig(TypedDict):
+    kernel: int
+    q_tile: int
+    kv_tile: int
+
+
 class LSPDetrConfig(PretrainedConfig):
     model_type = "lsp_detr"

@@ -15,17 +21,22 @@ class LSPDetrConfig(PretrainedConfig):
         backbone_kwargs: dict[str, Any] | None = None,
         backbone_config: Any | None = None,
         dim: int = 384,
-        num_classes: int = 1,
-        depths: tuple[int, ...] = (6, 2, 2),
-        query_block_size: int = 16,
         num_heads: int = 12,
-        window_size: int = 8,
-        tgt_window_sizes: tuple[int, ...] = (8, 8, 8),
-        src_window_sizes: tuple[int, ...] = (8, 16, 32),
+        num_classes: int = 1,
+        query_block_size: float = 14.222222222222223,  # 256 / 18
+        feature_levels: tuple[int, ...] = (2, 1, 0, 2, 1, 0),
         num_radial_distances: int = 64,
-        dropout: float = 0.1,
+        self_sta_config: STAConfig | None = None,
+        cross_sta_config: tuple[STAConfig, ...] = (
+            {"kernel": 5, "q_tile": 3, "kv_tile": 8},
+            {"kernel": 5, "q_tile": 3, "kv_tile": 4},
+            {"kernel": 5, "q_tile": 3, "kv_tile": 2},
+        ),
         **kwargs,
     ) -> None:
+        if self_sta_config is None:
+            self_sta_config = {"kernel": 3, "q_tile": 3, "kv_tile": 3}
+
         if backbone_kwargs is None:
             backbone_kwargs = {"out_features": ["stage1", "stage2", "stage3", "stage4"]}

@@ -43,13 +54,12 @@ class LSPDetrConfig(PretrainedConfig):
         self.backbone_config = backbone_config
         self.backbone_kwargs = backbone_kwargs
         self.dim = dim
+        self.num_heads = num_heads
         self.num_classes = num_classes
-        self.depths = depths
         self.query_block_size = query_block_size
-        self.num_heads = num_heads
-        self.window_size = window_size
-        self.tgt_window_sizes = tgt_window_sizes
-        self.src_window_sizes = src_window_sizes
+        self.feature_levels = feature_levels
         self.num_radial_distances = num_radial_distances
-        self.dropout = dropout
+        self.self_sta_config = self_sta_config
+        self.cross_sta_config = cross_sta_config
+
         super().__init__(**kwargs)
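Since STAConfig is a TypedDict, each attention config is a plain dict with checked keys. A usage sketch (hypothetical; backbone-related arguments, which these hunks do not show, are left at their defaults, and the two files above are assumed importable as a package):

from configuration import LSPDetrConfig, STAConfig

sta: STAConfig = {"kernel": 5, "q_tile": 3, "kv_tile": 8}

# num_classes=5 matches the committed config.json above.
config = LSPDetrConfig(num_classes=5, cross_sta_config=(sta, sta, sta))
assert config.self_sta_config == {"kernel": 3, "q_tile": 3, "kv_tile": 3}
assert config.feature_levels == (2, 1, 0, 2, 1, 0)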
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:f1366763c506eadd8933ed22481c4f11ab6d2984ac78f91157c461d3ad59c526
-size 204465704
+oid sha256:3f5437eb889a864ff88ae121ed7581217778b30430657ce39751a7f5e4b96082
+size 180151024
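model.safetensors is a Git LFS pointer, so only its oid and size change here; the checkpoint shrinks from ~204 MB to ~180 MB, consistent with the refactor from stacked window-attention blocks to six STA layers. A downloaded copy can be checked against the pointer — a sketch, assuming the resolved file sits in the working directory:

import hashlib
from pathlib import Path

data = Path("model.safetensors").read_bytes()
print(len(data))                         # expected: 180151024
print(hashlib.sha256(data).hexdigest())  # expected: 3f5437eb88...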
modeling.py CHANGED
@@ -1,67 +1,24 @@
 import math
+from functools import cached_property, lru_cache

 import torch
 import torch.nn.functional as F
-from einops import rearrange, repeat
+from einops import rearrange
 from torch import Tensor, nn
+from torch.nn.attention.flex_attention import (
+    BlockMask,
+    _mask_mod_signature,
+    create_block_mask,
+    flex_attention,
+)
 from torch.nn.utils import parametrize
-from transformers import PreTrainedModel
-from transformers.models.swinv2.modeling_swinv2 import window_partition, window_reverse
+from transformers.modeling_utils import PreTrainedModel
 from transformers.utils.backbone_utils import load_backbone

-from .configuration import LSPDetrConfig
+from .configuration import LSPDetrConfig, STAConfig


-class MLP(nn.Sequential):
-    """Very simple multi-layer perceptron."""
-
-    def __init__(
-        self,
-        input_dim: int,
-        hidden_dim: int,
-        output_dim: int,
-        num_layers: int,
-        act_layer: type[nn.Module] = nn.GELU,
-        dropout: float = 0.0,
-    ) -> None:
-        assert num_layers > 1
-
-        layers = []
-        h = [hidden_dim] * (num_layers - 1)
-        for n, k in zip([input_dim, *h], h, strict=False):
-            layers.append(nn.Linear(n, k))
-            layers.append(act_layer())
-            if dropout > 0:
-                layers.append(nn.Dropout(dropout))
-
-        layers.append(nn.Linear(hidden_dim, output_dim))
-        super().__init__(*layers)
-
-
-class FeedForward(nn.Module):
-    """FeedForward module.
-
-    Taken from https://github.com/meta-llama/llama-models/blob/main/models/llama4/ffn.py
-    """
-
-    def __init__(self, dim: int, hidden_dim: int, multiple_of: int = 256) -> None:
-        """Initialize the FeedForward module.
-
-        Args:
-            dim (int): Input dimension.
-            hidden_dim (int): Hidden dimension of the feedforward layer.
-            multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
-        """
-        super().__init__()
-        hidden_dim = int(2 * hidden_dim / 3)
-        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
-
-        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
-        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
-        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
-
-    def forward(self, x: Tensor) -> Tensor:
-        return self.w2(F.silu(self.w1(x)) * self.w3(x))
+flex_attention = torch.compile(flex_attention, dynamic=True)


 def init_freqs(head_dim: int, num_heads: int, pos_dim: int, theta: float) -> Tensor:
@@ -134,6 +91,11 @@ class CayleySTRING(nn.Module):
     def init_weights(self) -> None:
         self.S = nn.init.kaiming_uniform_(self.S, a=math.sqrt(5))

+    @cached_property
+    def P(self) -> Tensor:
+        i_plus_s_inv = torch.linalg.inv(self.I + self.S)
+        return torch.matmul(self.I - self.S, i_plus_s_inv)
+
     @parametrize.cached()
     @torch.autocast("cuda", enabled=False)
     def forward(self, x: Tensor, positions: Tensor) -> Tensor:
@@ -144,13 +106,18 @@ class CayleySTRING(nn.Module):
             positions ([b, n, pos_dim]): Positions tensor.
         """
         # Compute (I + S)^-1 @ x
-        y = torch.linalg.solve(
-            self.I + self.S, rearrange(x.float(), "b h n d -> h d (b n)")
-        )
+        if self.training:
+            # Use linalg.solve during training for numerical stability.
+            y = torch.linalg.solve(
+                self.I + self.S, rearrange(x.float(), "b h n d -> (b h) d n")
+            )
+            px = torch.matmul(self.I - self.S, y)
+            px = rearrange(px, "(b h) d n -> b h n d", b=x.size(0))
+        else:
+            # During inference, use the pre-calculated matrix P for performance.
+            px = x.float() @ self.P.T

-        # change of basis
-        px = torch.matmul(self.I - self.S, y)
-        px = rearrange(px, "h d (b n) -> b h n d", b=x.size(0)).contiguous()
+        px = px.contiguous()

         # apply RoPE-Mixed
         angles = torch.einsum("bnk,khc->bhnc", positions, self.freqs)
@@ -161,11 +128,123 @@ class CayleySTRING(nn.Module):
         return out.type_as(x)


-def maybe_pad(x: Tensor, window_size: int) -> Tensor:
-    h, w = x.shape[1:3]
-    pad_right = (window_size - w % window_size) % window_size
-    pad_bottom = (window_size - h % window_size) % window_size
-    return F.pad(x, (0, 0, 0, pad_right, 0, pad_bottom))
+class MLP(nn.Sequential):
+    """Very simple multi-layer perceptron."""
+
+    def __init__(
+        self,
+        input_dim: int,
+        hidden_dim: int,
+        output_dim: int,
+        num_layers: int,
+        act_layer: type[nn.Module] = nn.GELU,
+        dropout: float = 0.0,
+    ) -> None:
+        assert num_layers > 1
+
+        layers = []
+        h = [hidden_dim] * (num_layers - 1)
+        for n, k in zip([input_dim, *h], h, strict=False):
+            layers.append(nn.Linear(n, k))
+            layers.append(act_layer())
+            if dropout > 0:
+                layers.append(nn.Dropout(dropout))
+
+        layers.append(nn.Linear(hidden_dim, output_dim))
+        super().__init__(*layers)
+
+
+class FeedForward(nn.Module):
+    """FeedForward module.
+
+    Taken from https://github.com/meta-llama/llama-models/blob/main/models/llama4/ffn.py
+    """
+
+    def __init__(self, dim: int, hidden_dim: int, multiple_of: int = 256) -> None:
+        """Initialize the FeedForward module.
+
+        Args:
+            dim (int): Input dimension.
+            hidden_dim (int): Hidden dimension of the feedforward layer.
+            multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
+        """
+        super().__init__()
+        hidden_dim = int(2 * hidden_dim / 3)
+        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+
+        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
+        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
+        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+
+def generate_sta_mask(
+    q_canvas_w: int,
+    kv_canvas_hw: tuple[int, int],
+    kernel: int,
+    q_tile: int,
+    kv_tile: int,
+) -> _mask_mod_signature:
+    q_canvas_tile_w = q_canvas_w // q_tile
+    kv_canvas_tile_h = kv_canvas_hw[0] // kv_tile
+    kv_canvas_tile_w = kv_canvas_hw[1] // kv_tile
+
+    def q_tile_rescale(x: Tensor):
+        # Computes round(x * (kv_canvas_tile_w - 1) / (q_canvas_tile_w - 1))
+        scale_numerator = kv_canvas_tile_w - 1
+        scale_denominator = q_canvas_tile_w - 1
+        return (x * scale_numerator + scale_denominator // 2) // scale_denominator
+
+    def get_tile_xy(
+        idx: Tensor, tile_size: int, canvas_tile_w: int
+    ) -> tuple[Tensor, Tensor]:
+        tile_id = idx // (tile_size * tile_size)
+        tile_x = tile_id % canvas_tile_w
+        tile_y = tile_id // canvas_tile_w
+        return tile_x, tile_y
+
+    def sta_mask_2d(b: Tensor, h: Tensor, q_idx: Tensor, kv_idx: Tensor) -> Tensor:
+        q_x_tile, q_y_tile = get_tile_xy(q_idx, q_tile, q_canvas_tile_w)
+        kv_x_tile, kv_y_tile = get_tile_xy(kv_idx, kv_tile, kv_canvas_tile_w)
+
+        q_x_tile = q_tile_rescale(q_x_tile)
+        q_y_tile = q_tile_rescale(q_y_tile)
+
+        center_x = q_x_tile.clamp(kernel // 2, (kv_canvas_tile_w - 1) - kernel // 2)
+        center_y = q_y_tile.clamp(kernel // 2, (kv_canvas_tile_h - 1) - kernel // 2)
+
+        # Apply kernel mask in canvas coordinates (not tile coordinates)
+        x_mask = torch.abs(center_x - kv_x_tile) <= kernel // 2
+        y_mask = torch.abs(center_y - kv_y_tile) <= kernel // 2
+
+        return x_mask & y_mask
+
+    return sta_mask_2d
+
+
+@lru_cache
+def create_sta_block_mask(
+    q_len: int,
+    kv_len: int,
+    q_width: int,
+    kv_width: int,
+    kernel: int,
+    q_tile: int,
+    kv_tile: int,
+) -> BlockMask:
+    return create_block_mask(
+        generate_sta_mask(
+            q_width, (kv_len // kv_width, kv_width), kernel, q_tile, kv_tile
+        ),
+        B=None,
+        H=None,
+        device="cuda" if torch.cuda.is_available() else "cpu",
+        Q_LEN=q_len,
+        KV_LEN=kv_len,
+        _compile=True,
+    )


 @torch.autocast("cuda", enabled=False)
@@ -181,315 +260,142 @@ def relative_to_absolute_pos(pos: Tensor, step_x: float, step_y: float) -> Tensor:
     return torch.stack((absolute_x, absolute_y), dim=-1)


-def get_mask_windows(
-    height: int, width: int, window_size: int, shift_size: int, device: torch.device
-) -> Tensor:
-    # Create indices for height and width regions
-    h_idx = torch.zeros(height, dtype=torch.long, device=device)
-    h_idx[height - window_size : height - shift_size] = 1
-    h_idx[height - shift_size :] = 2
-
-    w_idx = torch.zeros(width, dtype=torch.long, device=device)
-    w_idx[width - window_size : width - shift_size] = 1
-    w_idx[width - shift_size :] = 2
-
-    # Calculate region index for each pixel using broadcasting
-    mask = h_idx.unsqueeze(1) * 3 + w_idx.unsqueeze(0)
-
-    mask_windows = window_partition(mask[None, ..., None], window_size)
-    return rearrange(mask_windows, "n w1 w2 1 -> n (w1 w2)")
-
-
-class WindowCrossAttention(nn.Module):
+class STAttention(nn.Module):
     def __init__(
         self,
         dim: int,
         src_dim: int,
-        tgt_window_size: int,
-        src_window_size: int,
         num_heads: int,
-        src_shift_size: int = 0,
-        tgt_shift_size: int = 0,
-        dropout: float = 0.0,
+        kernel: int,
+        q_tile: int,
+        kv_tile: int,
     ) -> None:
         super().__init__()
-
         self.num_heads = num_heads
-        self.tgt_window_size = tgt_window_size
-        self.src_window_size = src_window_size
-        self.src_shift_size = src_shift_size
-        self.tgt_shift_size = tgt_shift_size
-        self.dropout = dropout
+        self.kernel = kernel
+        self.q_tile = q_tile
+        self.kv_tile = kv_tile

         self.pe = CayleySTRING(dim, num_heads)
-        self.query = nn.Linear(dim, dim, bias=False)
+        self.q = nn.Linear(dim, dim, bias=False)
         self.kv = nn.Linear(src_dim, dim * 2, bias=False)
         self.wo = nn.Linear(dim, dim, bias=False)

-    def get_attn_mask(
-        self,
-        height: int,
-        width: int,
-        key_height: int,
-        key_width: int,
-        device: torch.device,
-        dtype: torch.dtype,
-    ) -> Tensor | None:
-        if self.tgt_shift_size == 0:
-            return None
-
-        query_mask = get_mask_windows(
-            height, width, self.tgt_window_size, self.tgt_shift_size, device
-        )
-        key_mask = get_mask_windows(
-            key_height, key_width, self.src_window_size, self.src_shift_size, device
-        )
-
-        attn_mask = query_mask.unsqueeze(2) - key_mask.unsqueeze(1)
-        return attn_mask.type(dtype).masked_fill(attn_mask != 0, -torch.inf)
+    def maybe_pad(self, x: Tensor, tile: int) -> Tensor:
+        h, w = x.shape[1:3]
+        pad_right = (tile - w % tile) % tile
+        pad_bottom = (tile - h % tile) % tile
+        return F.pad(x, (0, 0, 0, pad_right, 0, pad_bottom))
+
+    def tile(self, x: Tensor, height: int, tile: int) -> tuple[Tensor, int, int]:
+        x = rearrange(x, "b head (h w) dim -> b h w (head dim)", h=height)
+        x = self.maybe_pad(x, tile)
+        h, w = x.shape[1:3]
+        x = rearrange(
+            x,
+            "b (n_h ts_h) (n_w ts_w) (h d) -> b h (n_h n_w ts_h ts_w) d",
+            ts_h=tile,
+            ts_w=tile,
+            h=self.num_heads,
+        )
+        return x, h, w

     def forward(
-        self, tgt: Tensor, src: Tensor, tgt_coords: Tensor, src_coord: Tensor
+        self, tgt: Tensor, src: Tensor, q_coords: Tensor, k_coords: Tensor
     ) -> Tensor:
-        b, h, w, c = tgt.shape
-
-        # pad to multiples of window size
-        tgt = maybe_pad(tgt, self.tgt_window_size)
-        src = maybe_pad(src, self.src_window_size)
-        tgt_coords = maybe_pad(tgt_coords, self.tgt_window_size)
-        src_coord = maybe_pad(src_coord, self.src_window_size)
-        h_pad, w_pad = tgt.shape[1:3]
-        src_h, src_w = src.shape[1:3]
-
-        # cyclic shift
-        if self.tgt_shift_size > 0:
-            tgt = tgt.roll(
-                shifts=(-self.tgt_shift_size, -self.tgt_shift_size), dims=(1, 2)
-            )
-            tgt_coords = tgt_coords.roll(
-                shifts=(-self.tgt_shift_size, -self.tgt_shift_size), dims=(1, 2)
-            )
-
-        if self.src_shift_size > 0:
-            src = src.roll(
-                shifts=(-self.src_shift_size, -self.src_shift_size), dims=(1, 2)
-            )
-            src_coord = src_coord.roll(
-                shifts=(-self.src_shift_size, -self.src_shift_size), dims=(1, 2)
-            )
-
-        # partition windows
-        tgt = window_partition(tgt, self.tgt_window_size).flatten(1, 2)
-        src = window_partition(src, self.src_window_size).flatten(1, 2)
-        tgt_coords = window_partition(tgt_coords, self.tgt_window_size).flatten(1, 2)
-        src_coord = window_partition(src_coord, self.src_window_size).flatten(1, 2)
-
-        attn_mask = self.get_attn_mask(
-            h_pad, w_pad, src_h, src_w, tgt.device, tgt.dtype
-        )
-
-        if attn_mask is not None:
-            attn_mask = repeat(attn_mask, "n l s -> (b n) h l s", b=b, h=self.num_heads)
-
-        # W-MCA/SW-MCA
-        q = rearrange(self.query(tgt), "b n (h d) -> b h n d", h=self.num_heads)
+        h, w = tgt.shape[1:3]
+
+        q = rearrange(
+            self.q(tgt), "b h w (head d) -> b head (h w) d", head=self.num_heads
+        )
         k, v = rearrange(
-            self.kv(src), "b n (two h d) -> two b h n d", two=2, h=self.num_heads
+            self.kv(src),
+            "b h w (two head d) -> two b head (h w) d",
+            two=2,
+            head=self.num_heads,
         )
-        x = F.scaled_dot_product_attention(
-            query=self.pe(q, tgt_coords),
-            key=self.pe(k, src_coord),
-            value=v,
-            attn_mask=attn_mask,
-            dropout_p=self.dropout if self.training else 0.0,
-        )
-        tgt = self.wo(rearrange(x, "b h n d -> b n (h d)"))
-
-        # merge windows
-        tgt = tgt.view(-1, self.tgt_window_size, self.tgt_window_size, c)
-        tgt = window_reverse(tgt, self.tgt_window_size, h_pad, w_pad)
-
-        # reverse cyclic shift
-        if self.tgt_shift_size > 0:
-            tgt = torch.roll(
-                tgt, shifts=(self.tgt_shift_size, self.tgt_shift_size), dims=(1, 2)
-            )
-
-        return tgt[:, :h, :w, :].contiguous()  # remove padding
-
-
-class WindowSelfAttention(nn.Module):
-    def __init__(
-        self,
-        dim: int,
-        window_size: int,
-        num_heads: int,
-        shift_size: int = 0,
-        dropout: float = 0.0,
-    ) -> None:
-        super().__init__()
-
-        self.num_heads = num_heads
-        self.window_size = window_size
-        self.shift_size = shift_size
-        self.dropout = dropout
-
-        self.pe = CayleySTRING(dim, num_heads)
-        self.qkv = nn.Linear(dim, dim * 3, bias=False)
-        self.wo = nn.Linear(dim, dim, bias=False)
-
-    def get_attn_mask(
-        self, height: int, width: int, device: torch.device, dtype: torch.dtype
-    ) -> Tensor | None:
-        if self.shift_size == 0:
-            return None
-
-        mask_windows = get_mask_windows(
-            height, width, self.window_size, self.shift_size, device
-        )
-        # Calculate the attention mask based on window differences
-        attn_mask = mask_windows.unsqueeze(2) - mask_windows.unsqueeze(1)
-        return attn_mask.type(dtype).masked_fill(attn_mask != 0, -torch.inf)
-
-    def forward(self, x: Tensor, coords: Tensor) -> Tensor:
-        """Forward function for Window Self-Attention.
-
-        Args:
-            x ([b, h, w, c]): Hidden states.
-            coords ([b, h, w, 2]): Absolute positions.
-        """
-        b, h, w, c = x.shape
-
-        # pad to multiples of window size
-        x = maybe_pad(x, self.window_size)
-        coords = maybe_pad(coords, self.window_size)
-        h_pad, w_pad = x.shape[1:3]
-
-        # cyclic shift
-        if self.shift_size > 0:
-            x = x.roll(shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
-            coords = coords.roll(
-                shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
-            )
-
-        # partition windows
-        x = window_partition(x, self.window_size).flatten(1, 2)
-        coords = window_partition(coords, self.window_size).flatten(1, 2)
-
-        attn_mask = self.get_attn_mask(h_pad, w_pad, x.device, x.dtype)
-        if attn_mask is not None:
-            attn_mask = repeat(attn_mask, "n l s -> (b n) h l s", b=b, h=self.num_heads)
-
-        # W-MSA/SW-MSA
-        q, k, v = rearrange(
-            self.qkv(x), "b n (three h d) -> three b h n d", three=3, h=self.num_heads
-        )
-        x = F.scaled_dot_product_attention(
-            query=self.pe(q, coords),
-            key=self.pe(k, coords),
-            value=v,
-            attn_mask=attn_mask,
-            dropout_p=self.dropout if self.training else 0.0,
-        )
-        x = self.wo(rearrange(x, "b h n d -> b n (h d)"))
-
-        # merge windows
-        x = x.view(-1, self.window_size, self.window_size, c)
-        x = window_reverse(x, self.window_size, h_pad, w_pad)
-
-        # reverse cyclic shift
-        if self.shift_size > 0:
-            x = torch.roll(x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
-
-        return x[:, :h, :w, :].contiguous()  # remove padding
-
-
-class Block(nn.Module):
+
+        # RoPE
+        q = self.pe(q, q_coords)
+        k = self.pe(k, k_coords)
+
+        # tile
+        q, q_h, q_w = self.tile(q, h, self.q_tile)
+        k, _, kv_w = self.tile(k, src.shape[1], self.kv_tile)
+        v, _, _ = self.tile(v, src.shape[1], self.kv_tile)
+
+        # flex attention
+        block_mask = create_sta_block_mask(
+            q_len=q.shape[2],
+            kv_len=k.shape[2],
+            q_width=q_w,
+            kv_width=kv_w,
+            kernel=self.kernel,
+            q_tile=self.q_tile,
+            kv_tile=self.kv_tile,
+        )
+        x = flex_attention(q, k, v, block_mask=block_mask)
+
+        # un-tile
+        x = rearrange(
+            x,
+            "b h (n_h n_w ts_h ts_w) d -> b (n_h ts_h) (n_w ts_w) (h d)",
+            n_h=q_h // self.q_tile,
+            n_w=q_w // self.q_tile,
+            ts_h=self.q_tile,
+            ts_w=self.q_tile,
+        )
+
+        # remove padding
+        x = x[:, :h, :w, :].contiguous()
+
+        return self.wo(x)
+
+
+class Layer(nn.Module):
     def __init__(
         self,
         dim: int,
         src_dim: int,
         num_heads: int,
-        window_size: int,
-        tgt_window_size: int,
-        src_window_size: int,
-        shift_size: int = 0,
-        tgt_shift_size: int = 0,
-        src_shift_size: int = 0,
-        dropout: float = 0.1,
+        self_sta_config: STAConfig,
+        cross_sta_config: STAConfig,
     ) -> None:
         super().__init__()

-        self.cross_attention = WindowCrossAttention(
+        self.self_attention = STAttention(
             dim,
-            src_dim,
-            num_heads=num_heads,
-            tgt_window_size=tgt_window_size,
-            src_window_size=src_window_size,
-            tgt_shift_size=tgt_shift_size,
-            src_shift_size=src_shift_size,
-            dropout=dropout,
+            dim,
+            num_heads,
+            kernel=self_sta_config["kernel"],
+            q_tile=self_sta_config["q_tile"],
+            kv_tile=self_sta_config["kv_tile"],
         )
-        self.cross_attention_norm = nn.LayerNorm(dim)
-        self.cross_attention_dropout = nn.Dropout(dropout)
+        self.self_attention_norm = nn.LayerNorm(dim)

-        self.self_attention = WindowSelfAttention(
-            dim, window_size, num_heads, shift_size, dropout=dropout
+        self.cross_attention = STAttention(
+            dim,
+            src_dim,
+            num_heads,
+            kernel=cross_sta_config["kernel"],
+            q_tile=cross_sta_config["q_tile"],
+            kv_tile=cross_sta_config["kv_tile"],
         )
-        self.self_attention_norm = nn.LayerNorm(dim)
-        self.self_attention_dropout = nn.Dropout(dropout)
+        self.cross_attention_norm = nn.LayerNorm(dim)

         self.ffn = FeedForward(dim, dim * 4)
         self.ffn_norm = nn.LayerNorm(dim)
-        self.ffn_dropout = nn.Dropout(dropout)

     def forward(
-        self, tgt: Tensor, src: Tensor, tgt_coords: Tensor, src_coords
+        self, tgt: Tensor, src: Tensor, tgt_coords: Tensor, src_coords: Tensor
     ) -> Tensor:
-        x = self.self_attention(tgt, tgt_coords)
-        tgt = self.self_attention_norm(tgt + self.self_attention_dropout(x))
+        x = self.self_attention(tgt, tgt, tgt_coords, tgt_coords)
+        tgt = self.self_attention_norm(tgt + x)

         x = self.cross_attention(tgt, src, tgt_coords, src_coords)
-        tgt = self.cross_attention_norm(tgt + self.cross_attention_dropout(x))
+        tgt = self.cross_attention_norm(tgt + x)

-        return self.ffn_norm(tgt + self.ffn_dropout(self.ffn(tgt)))
-
-
-class Stage(nn.Module):
-    def __init__(
-        self,
-        dim: int,
-        src_dim: int,
-        depth: int,
-        num_heads: int,
-        window_size: int,
-        tgt_window_size: int,
-        src_window_size: int,
-        dropout: float = 0.0,
-    ) -> None:
-        super().__init__()
-        self.blocks = nn.ModuleList()
-        for i in range(depth):
-            block = Block(
-                dim=dim,
-                src_dim=src_dim,
-                num_heads=num_heads,
-                window_size=window_size,
-                tgt_window_size=tgt_window_size,
-                src_window_size=src_window_size,
-                shift_size=0 if i % 2 == 0 else window_size // 2,
-                tgt_shift_size=0 if i % 2 == 0 else tgt_window_size // 2,
-                src_shift_size=0 if i % 2 == 0 else src_window_size // 2,
-                dropout=dropout,
-            )
-            self.blocks.append(block)
-
-    def forward(
-        self, tgt: Tensor, src: Tensor, tgt_coords: Tensor, src_coords: Tensor
-    ) -> Tensor:
-        for block in self.blocks:
-            tgt = block(tgt, src, tgt_coords, src_coords)
-        return tgt
+        return self.ffn_norm(tgt + self.ffn(tgt))


 class LSPTransformer(nn.Module):
@@ -498,36 +404,45 @@ class LSPTransformer(nn.Module):

         self.query_block_size = config.query_block_size
         self.num_radial_distances = config.num_radial_distances
+        self.feature_levels = config.feature_levels
+        self.num_classes = config.num_classes + 1

-        self.stages = nn.ModuleList()
-        for i, depth in enumerate(config.depths):
-            stage = Stage(
+        self.layers = nn.ModuleList()
+        for level in config.feature_levels:
+            layer = Layer(
                 dim=config.dim,
-                src_dim=feature_channels[i],
-                depth=depth,
+                src_dim=feature_channels[level],
                 num_heads=config.num_heads,
-                window_size=config.window_size,
-                tgt_window_size=config.tgt_window_sizes[i],
-                src_window_size=config.src_window_sizes[i],
-                dropout=config.dropout,
+                self_sta_config=config.self_sta_config,
+                cross_sta_config=config.cross_sta_config[level],
             )
-            self.stages.append(stage)
-
-        self.input_norm = nn.ModuleList(nn.LayerNorm(d) for d in feature_channels)
+            self.layers.append(layer)

         # output heads
-        self.class_head = nn.Linear(config.dim, config.num_classes + 1, bias=False)
-        self.point_head = MLP(config.dim, config.dim, 2, 2)
-        self.radial_distances_head = MLP(
-            config.dim, config.dim, config.num_radial_distances, 2
+        self.class_head = nn.Linear(config.dim, self.num_classes)
+        self.point_head = nn.ModuleList(
+            MLP(config.dim, config.dim, 2, 3) for _ in config.feature_levels
+        )
+        self.radial_distances_head = nn.ModuleList(
+            MLP(config.dim, config.dim, config.num_radial_distances, 3)
+            for _ in config.feature_levels
         )

         self.init_weights()

     def init_weights(self) -> None:
+        prior_prob = 0.01
+        bias_value = -math.log((1 - prior_prob) / prior_prob)
+        self.class_head.bias.data = torch.ones(self.num_classes) * bias_value
+
         # initialize regression layers
-        nn.init.constant_(self.point_head[-1].weight, 0.0)
-        nn.init.constant_(self.point_head[-1].bias, 0.0)
+        for head in self.point_head:
+            nn.init.constant_(head[-1].weight, 0)
+            nn.init.constant_(head[-1].bias, 0)
+
+        for head in self.radial_distances_head:
+            nn.init.constant_(head[-1].weight, 0)
+            nn.init.constant_(head[-1].bias, 0)

     def forward(
         self,
@@ -539,34 +454,44 @@
     ) -> dict[str, Tensor | list[dict[str, Tensor]]]:
         src = []
         src_coords = []
-        for i, feature in enumerate(features):
+        for feature in features:
             b, _, h, w = feature.shape
             coords = torch.zeros(b, h, w, 2, dtype=torch.float32, device=feature.device)
-            src.append(self.input_norm[i](rearrange(feature, "b c h w -> b h w c")))
-            src_coords.append(
-                relative_to_absolute_pos(
-                    coords, step_x=math.ceil(width / w), step_y=math.ceil(height / h)
-                )
+            coords = relative_to_absolute_pos(
+                coords, step_x=math.ceil(width / w), step_y=math.ceil(height / h)
             )
+            # the outputs from SwinV2 are already normalized
+            src.append(rearrange(feature, "b c h w -> b h w c"))
+            src_coords.append(rearrange(coords, "b h w pos -> b (h w) pos"))
+
+        radial_distances = torch.full(
+            (*tgt.shape[:3], self.num_radial_distances),
+            math.log1p(self.query_block_size / 2),
+            dtype=torch.float32,
+            device=tgt.device,
+        )

         logits_list: list[Tensor] = []
         ref_points_list: list[Tensor] = []
         radial_distances_list: list[Tensor] = []

-        new_ref_points = ref_points.clone()  # for look forward twice
-        for i, stage in enumerate(self.stages):
-            tgt = stage(
+        # for look forward twice
+        new_ref_points = ref_points.clone()
+        new_radial_distances = radial_distances.clone()
+
+        for i, layer in enumerate(self.layers):
+            tgt = layer(
                 tgt=tgt,
-                src=src[i],
+                src=src[self.feature_levels[i]],
                 tgt_coords=relative_to_absolute_pos(
                     ref_points, self.query_block_size, self.query_block_size
-                ),
-                src_coords=src_coords[i],
+                ).flatten(1, 2),
+                src_coords=src_coords[self.feature_levels[i]],
             )

             # output heads
-            delta_point = self.point_head(tgt)
-            radial_distances = self.radial_distances_head(tgt)
+            delta_point = self.point_head[i](tgt)
+            delta_distances = self.radial_distances_head[i](tgt)
             logits = self.class_head(tgt)

             ref_points_list.append(
@@ -577,10 +502,14 @@ class LSPTransformer(nn.Module):
                 ).flatten(1, 2)
             )
             logits_list.append(logits.flatten(1, 2))
-            radial_distances_list.append(radial_distances.flatten(1, 2))
+            radial_distances_list.append(
+                torch.flatten(new_radial_distances + delta_distances, 1, 2)
+            )

             new_ref_points = ref_points + delta_point
+            new_radial_distances = radial_distances + delta_distances
             ref_points = new_ref_points.detach()
+            radial_distances = new_radial_distances.detach()

         return {
             "logits": logits_list[-1],
@@ -608,12 +537,12 @@ class LSPTransformer(nn.Module):
 class FeatureSampling(nn.Module):
     def __init__(self, in_dim: int, out_dim: int) -> None:
         super().__init__()
-        self.reduction = nn.Linear(in_dim, out_dim, bias=False)
+        self.reduction = nn.Conv2d(in_dim, out_dim, kernel_size=1, bias=False)
         self.norm = nn.LayerNorm(out_dim)

     def forward(self, points: Tensor, feature: Tensor) -> Tensor:
-        x = F.grid_sample(feature, points * 2 - 1, align_corners=False)
-        return self.norm(self.reduction(rearrange(x, "b c h w -> b h w c")))
+        x = F.grid_sample(self.reduction(feature), points * 2 - 1, align_corners=False)
+        return self.norm(rearrange(x, "b c h w -> b h w c"))


 class LSPDetrModel(PreTrainedModel):
@@ -627,7 +556,7 @@ class LSPDetrModel(PreTrainedModel):
         _, *feature_channels, neck = self.backbone.num_features

         self.feature_sampling = FeatureSampling(neck, config.dim)
-        self.decode_head = LSPTransformer(config, feature_channels[::-1])
+        self.decode_head = LSPTransformer(config, feature_channels)

     def forward(self, pixel_values: Tensor) -> dict[str, Tensor]:
         b, _, h, w = pixel_values.shape
@@ -649,4 +578,4 @@ class LSPDetrModel(PreTrainedModel):
             neck,
         )

-        return self.decode_head(tgt, ref_points, features[::-1], h, w)
+        return self.decode_head(tgt, ref_points, features, h, w)
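The core of this rewrite is that generate_sta_mask builds a FlexAttention mask_mod and create_sta_block_mask caches the resulting BlockMask per shape. For readers new to FlexAttention, here is a minimal, self-contained 1D analogue of that masking idea. It is not from this repo; it assumes a PyTorch build where eager flex_attention runs on your device (roughly 2.5+ with CUDA, newer for CPU):

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# 1D sliding-window analogue of generate_sta_mask: each query attends to the
# kernel-sized window of keys around its (clamped) position.
KV_LEN = Q_LEN = 64
KERNEL = 5

def sliding_mask(b, h, q_idx, kv_idx):
    # Clamping keeps edge queries' windows fully inside the canvas,
    # mirroring the clamp on center_x / center_y in generate_sta_mask.
    center = q_idx.clamp(KERNEL // 2, (KV_LEN - 1) - KERNEL // 2)
    return (center - kv_idx).abs() <= KERNEL // 2

device = "cuda" if torch.cuda.is_available() else "cpu"
block_mask = create_block_mask(
    sliding_mask, B=None, H=None, Q_LEN=Q_LEN, KV_LEN=KV_LEN, device=device
)
q, k, v = (torch.randn(1, 12, Q_LEN, 32, device=device) for _ in range(3))
out = flex_attention(q, k, v, block_mask=block_mask)
print(out.shape)  # torch.Size([1, 12, 64, 32])

The clamp is what distinguishes this scheme from plain neighborhood attention: queries near the border keep a full-width window instead of a truncated one, so every query sees the same number of keys.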