matejpekar committed
Commit c3f67c4 · verified · 1 Parent(s): 1a37868

Upload LSPDETR

Files changed (4)
  1. config.json +41 -0
  2. configuration.py +35 -0
  3. model.safetensors +3 -0
  4. modeling.py +671 -0
config.json ADDED
@@ -0,0 +1,41 @@
+{
+  "architectures": [
+    "LSPDETR"
+  ],
+  "auto_map": {
+    "AutoConfig": "configuration.LSPDETRConfig",
+    "AutoModelForObjectDetection": "modeling.LSPDETR"
+  },
+  "backbone": "microsoft/swinv2-tiny-patch4-window16-256",
+  "depths": [
+    6,
+    2,
+    2
+  ],
+  "dim": 384,
+  "dropout": 0.1,
+  "in_channels": [
+    768,
+    384,
+    192,
+    96
+  ],
+  "model_type": "LSP-DETR",
+  "num_classes": 2,
+  "num_heads": 12,
+  "num_radial_distances": 64,
+  "query_block_size": 16,
+  "src_window_sizes": [
+    8,
+    16,
+    32
+  ],
+  "tgt_window_sizes": [
+    8,
+    8,
+    8
+  ],
+  "torch_dtype": "float32",
+  "transformers_version": "4.51.3",
+  "window_size": 16
+}
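The auto_map block above is what lets transformers resolve the custom classes shipped in this repo at load time. A minimal loading sketch; the repo id below is a placeholder, not the actual repository name:

from transformers import AutoConfig, AutoModelForObjectDetection

repo_id = "<namespace>/LSP-DETR"  # placeholder; substitute the real Hub repo id

config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForObjectDetection.from_pretrained(repo_id, trust_remote_code=True)
assert config.model_type == "LSP-DETR"

trust_remote_code=True is required because configuration.py and modeling.py ship with the checkpoint rather than with the transformers library itself.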
configuration.py ADDED
@@ -0,0 +1,35 @@
+from transformers import PretrainedConfig
+
+
+class LSPDETRConfig(PretrainedConfig):
+    model_type = "LSP-DETR"
+
+    def __init__(
+        self,
+        backbone: str = "microsoft/swinv2-tiny-patch4-window16-256",
+        dim: int = 384,
+        num_classes: int = 2,
+        depths: tuple[int, ...] = (6, 2, 2),
+        in_channels: tuple[int, ...] = (768, 384, 192, 96),
+        query_block_size: int = 16,
+        num_heads: int = 12,
+        window_size: int = 16,
+        tgt_window_sizes: tuple[int, ...] = (8, 8, 8),
+        src_window_sizes: tuple[int, ...] = (8, 16, 32),
+        num_radial_distances: int = 64,
+        dropout: float = 0.1,
+        **kwargs,
+    ) -> None:
+        self.backbone = backbone
+        self.dim = dim
+        self.num_classes = num_classes
+        self.depths = depths
+        self.in_channels = in_channels
+        self.query_block_size = query_block_size
+        self.num_heads = num_heads
+        self.window_size = window_size
+        self.tgt_window_sizes = tgt_window_sizes
+        self.src_window_sizes = src_window_sizes
+        self.num_radial_distances = num_radial_distances
+        self.dropout = dropout
+        super().__init__(**kwargs)
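Since every argument has a default mirroring config.json, the config can also be built standalone; a quick sketch of the round trip, assuming it is run from a local clone of the repo:

from configuration import LSPDETRConfig

config = LSPDETRConfig()  # defaults reproduce the uploaded config.json
assert config.to_dict()["model_type"] == "LSP-DETR"
assert len(config.depths) == len(config.tgt_window_sizes) == 3  # one entry per decoder stage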
model.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e725b30741b7f18033487d32688d1cc223f445318bd0a600f1dca082cd9e9352
+size 205650424
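This file is a Git LFS pointer, not the weights themselves: the oid is the SHA-256 of the roughly 196 MiB safetensors blob. A sketch that re-verifies a downloaded copy against the pointer (repo id again a placeholder):

import hashlib
from huggingface_hub import hf_hub_download

path = hf_hub_download("<namespace>/LSP-DETR", filename="model.safetensors")

digest = hashlib.sha256()
with open(path, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):  # hash in 1 MiB chunks
        digest.update(chunk)

assert digest.hexdigest() == "e725b30741b7f18033487d32688d1cc223f445318bd0a600f1dca082cd9e9352"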
modeling.py ADDED
@@ -0,0 +1,671 @@
+import math
+
+import torch
+import torch.nn.functional as F
+from einops import rearrange, repeat
+from torch import Tensor, nn
+from torch.nn.utils import parametrize
+from transformers import PreTrainedModel, Swinv2Backbone
+from transformers.models.swinv2.modeling_swinv2 import window_partition, window_reverse
+
+from .configuration import LSPDETRConfig
+
+
+def init_freqs(head_dim: int, num_heads: int, pos_dim: int, theta: float) -> Tensor:
+    """Build 2D RoPE frequencies with a random in-plane rotation per head (RoPE-Mixed)."""
+    freqs_x = []
+    freqs_y = []
+    freqs = 1 / (theta ** (torch.arange(0, head_dim, 2 * pos_dim).float() / head_dim))
+    for _ in range(num_heads):
+        angles = torch.rand(1) * 2 * torch.pi
+        fx = torch.cat(
+            [freqs * torch.cos(angles), freqs * torch.cos(torch.pi / 2 + angles)],
+            dim=-1,
+        )
+        fy = torch.cat(
+            [freqs * torch.sin(angles), freqs * torch.sin(torch.pi / 2 + angles)],
+            dim=-1,
+        )
+        freqs_x.append(fx)
+        freqs_y.append(fy)
+    freqs_x = torch.stack(freqs_x, dim=0)
+    freqs_y = torch.stack(freqs_y, dim=0)
+    return torch.stack([freqs_x, freqs_y], dim=0)
+
+
+class Skew(nn.Module):
+    """Skew-symmetric matrix parameterization."""
+
+    def forward(self, x: Tensor) -> Tensor:
+        a = x.triu(1)
+        return a - a.transpose(-1, -2)
+
+    def right_inverse(self, x: Tensor) -> Tensor:
+        return x.triu(1)
+
+
+class CayleySTRING(nn.Module):
+    """Implements the Cayley-STRING positional encoding.
+
+    Based on "Learning the RoPEs: Better 2D and 3D Position Encodings with STRING"
+    (https://arxiv.org/abs/2502.02562).
+
+    Applies RoPE followed by multiplication with a learnable orthogonal matrix P
+    parameterized by the Cayley transform: P = (I - S)(I + S)^-1, where S is
+    a learnable skew-symmetric matrix.
+
+    Args:
+        dim (int): The feature dimension of the input tensor. Must be divisible by
+            num_heads, with an even per-head dimension.
+        num_heads (int): The number of attention heads.
+        pos_dim (int): The dimensionality of the position vectors (e.g., 1 for 1D, 2 for 2D). Defaults to 2.
+        theta (float): The base value for the RoPE frequency calculation. Defaults to 100.0.
+    """
+
+    def __init__(
+        self, dim: int, num_heads: int, pos_dim: int = 2, theta: float = 100.0
+    ) -> None:
+        super().__init__()
+        assert dim % num_heads == 0, "Dimension must be divisible by num_heads."
+
+        head_dim = dim // num_heads
+
+        self.freqs = nn.Parameter(init_freqs(head_dim, num_heads, pos_dim, theta))
+
+        self.S = nn.Parameter(torch.zeros(head_dim, head_dim))
+        parametrize.register_parametrization(self, "S", Skew())
+
+        self.register_buffer("I", torch.eye(head_dim), persistent=False)
+
+        self.init_weights()
+
+    def init_weights(self) -> None:
+        self.S = nn.init.kaiming_uniform_(self.S, a=math.sqrt(5))
+
+    @parametrize.cached()
+    @torch.autocast("cuda", enabled=False)
+    def forward(self, x: Tensor, positions: Tensor) -> Tensor:
+        """Apply Cayley-STRING positional encoding.
+
+        Args:
+            x ([b, h, n, d]): Input tensor.
+            positions ([b, n, pos_dim]): Positions tensor.
+        """
+        # Compute (I + S)^-1 @ x
+        y = torch.linalg.solve(
+            self.I + self.S, rearrange(x.float(), "b h n d -> h d (b n)")
+        )
+
+        # change of basis: P @ x with P = (I - S)(I + S)^-1
+        px = torch.matmul(self.I - self.S, y)
+        px = rearrange(px, "h d (b n) -> b h n d", b=x.size(0)).contiguous()
+
+        # apply RoPE-Mixed
+        angles = torch.einsum("bnk,khc->bhnc", positions, self.freqs)
+        freqs_cis = torch.polar(torch.ones_like(angles), angles)
+        px_ = torch.view_as_complex(rearrange(px, "... (d two) -> ... d two", two=2))
+        out = rearrange(torch.view_as_real(px_ * freqs_cis), "... d two -> ... (d two)")
+
+        return out.type_as(x)
+
+
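An editorial aside on the docstring above, not part of the uploaded file: for any skew-symmetric S, the Cayley transform P = (I - S)(I + S)^-1 is orthogonal, so the learned change of basis preserves dot products exactly as plain RoPE does. A standalone numerical check:

import torch

head_dim = 32
raw = torch.randn(head_dim, head_dim)
s = raw.triu(1) - raw.triu(1).T  # skew-symmetric, as in the Skew parameterization
eye = torch.eye(head_dim)
p = (eye - s) @ torch.linalg.inv(eye + s)  # Cayley transform
assert torch.allclose(p @ p.T, eye, atol=1e-4)  # P is orthogonal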
+class MLP(nn.Sequential):
+    """Very simple multi-layer perceptron."""
+
+    def __init__(
+        self,
+        input_dim: int,
+        hidden_dim: int,
+        output_dim: int,
+        num_layers: int,
+        act_layer: type[nn.Module] = nn.ReLU,
+        dropout: float = 0.0,
+    ) -> None:
+        assert num_layers > 1
+
+        layers = []
+        h = [hidden_dim] * (num_layers - 1)
+        for n, k in zip([input_dim, *h], h, strict=False):
+            layers.append(nn.Linear(n, k))
+            layers.append(act_layer())
+            if dropout > 0:
+                layers.append(nn.Dropout(dropout))
+
+        layers.append(nn.Linear(hidden_dim, output_dim))
+        super().__init__(*layers)
+
+
+class FeedForward(nn.Module):
+    """FeedForward module.
+
+    Taken from https://github.com/meta-llama/llama-models/blob/main/models/llama4/ffn.py
+    """
+
+    def __init__(self, dim: int, hidden_dim: int, multiple_of: int = 256) -> None:
+        """Initialize the FeedForward module.
+
+        Args:
+            dim (int): Input dimension.
+            hidden_dim (int): Hidden dimension of the feedforward layer.
+            multiple_of (int): Round the hidden dimension up to a multiple of this value.
+        """
+        super().__init__()
+        hidden_dim = int(2 * hidden_dim / 3)
+        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+
+        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
+        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
+        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+
+@torch.autocast("cuda", enabled=False)
+def relative_to_absolute_points(points: Tensor, height: int, width: int) -> Tensor:
+    points = points.sigmoid()
+    h, w = points.shape[1:3]
+
+    step_x = width / w
+    step_y = height / h
+
+    anchor_x = torch.arange(0, width, step_x, device=points.device)[:w]
+    anchor_y = torch.arange(0, height, step_y, device=points.device)[:h, None]
+
+    absolute_x = points[..., 0] * step_x + anchor_x
+    absolute_y = points[..., 1] * step_y + anchor_y
+
+    return torch.stack((absolute_x, absolute_y), dim=-1)
+
+
+@torch.autocast("cuda", enabled=False)
+def relative_to_absolute_points_normalized(points: Tensor) -> Tensor:
+    points = points.sigmoid()
+    h, w = points.shape[1:3]
+
+    anchor_x = torch.arange(0, 1, 1 / w, device=points.device)[:w]
+    anchor_y = torch.arange(0, 1, 1 / h, device=points.device)[:h, None]
+
+    absolute_x = points[..., 0] / w + anchor_x
+    absolute_y = points[..., 1] / h + anchor_y
+
+    return torch.stack((absolute_x, absolute_y), dim=-1)
+
+
+def get_mask_windows(
+    height: int, width: int, window_size: int, shift_size: int, device: torch.device
+) -> Tensor:
+    # Create indices for height and width regions
+    h_idx = torch.zeros(height, dtype=torch.long, device=device)
+    h_idx[height - window_size : height - shift_size] = 1
+    h_idx[height - shift_size :] = 2
+
+    w_idx = torch.zeros(width, dtype=torch.long, device=device)
+    w_idx[width - window_size : width - shift_size] = 1
+    w_idx[width - shift_size :] = 2
+
+    # Calculate region index for each pixel using broadcasting
+    mask = h_idx.unsqueeze(1) * 3 + w_idx.unsqueeze(0)
+
+    mask_windows = window_partition(mask[None, ..., None], window_size)
+    return rearrange(mask_windows, "n w1 w2 1 -> n (w1 w2)")
+
+
+class WindowCrossAttention(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        src_dim: int,
+        tgt_window_size: int,
+        src_window_size: int,
+        num_heads: int,
+        src_shift_size: int = 0,
+        tgt_shift_size: int = 0,
+        dropout: float = 0.0,
+    ) -> None:
+        super().__init__()
+
+        self.num_heads = num_heads
+        self.tgt_window_size = tgt_window_size
+        self.src_window_size = src_window_size
+        self.src_shift_size = src_shift_size
+        self.tgt_shift_size = tgt_shift_size
+        self.dropout = dropout
+
+        self.pe = CayleySTRING(dim, num_heads)
+        self.query = nn.Linear(dim, dim, bias=False)
+        self.kv = nn.Linear(src_dim, dim * 2, bias=False)
+        self.wo = nn.Linear(dim, dim, bias=False)
+
+    def get_attn_mask(
+        self,
+        height: int,
+        width: int,
+        key_height: int,
+        key_width: int,
+        device: torch.device,
+        dtype: torch.dtype,
+    ) -> Tensor | None:
+        if self.tgt_shift_size == 0:
+            return None
+
+        query_mask = get_mask_windows(
+            height, width, self.tgt_window_size, self.tgt_shift_size, device
+        )
+        key_mask = get_mask_windows(
+            key_height, key_width, self.src_window_size, self.src_shift_size, device
+        )
+
+        attn_mask = query_mask.unsqueeze(2) - key_mask.unsqueeze(1)
+        return attn_mask.type(dtype).masked_fill(attn_mask != 0, -torch.inf)
+
+    def forward(
+        self, tgt: Tensor, src: Tensor, tgt_coords: Tensor, src_coords: Tensor
+    ) -> Tensor:
+        """Forward function for Window Cross-Attention.
+
+        Args:
+            tgt ([b, h, w, c]): Query hidden states.
+            src ([b, src_h, src_w, src_c]): Key/value hidden states.
+            tgt_coords ([b, h, w, 2]): Absolute positions of the queries.
+            src_coords ([b, src_h, src_w, 2]): Absolute positions of the keys.
+        """
+        b, h, w, c = tgt.shape
+        src_h, src_w = src.shape[1:3]
+
+        # cyclic shift
+        if self.tgt_shift_size > 0:
+            tgt = tgt.roll(
+                shifts=(-self.tgt_shift_size, -self.tgt_shift_size), dims=(1, 2)
+            )
+            tgt_coords = tgt_coords.roll(
+                shifts=(-self.tgt_shift_size, -self.tgt_shift_size), dims=(1, 2)
+            )
+
+        if self.src_shift_size > 0:
+            src = src.roll(
+                shifts=(-self.src_shift_size, -self.src_shift_size), dims=(1, 2)
+            )
+            src_coords = src_coords.roll(
+                shifts=(-self.src_shift_size, -self.src_shift_size), dims=(1, 2)
+            )
+
+        # partition windows
+        tgt = window_partition(tgt, self.tgt_window_size).flatten(1, 2)
+        tgt_coords = window_partition(tgt_coords, self.tgt_window_size).flatten(1, 2)
+        src = window_partition(src, self.src_window_size).flatten(1, 2)
+        src_coords = window_partition(src_coords, self.src_window_size).flatten(1, 2)
+
+        attn_mask = self.get_attn_mask(h, w, src_h, src_w, tgt.device, tgt.dtype)
+
+        if attn_mask is not None:
+            attn_mask = repeat(attn_mask, "n l s -> (b n) h l s", b=b, h=self.num_heads)
+
+        # W-MCA/SW-MCA
+        q = rearrange(self.query(tgt), "b n (h d) -> b h n d", h=self.num_heads)
+        k, v = rearrange(
+            self.kv(src), "b n (two h d) -> two b h n d", two=2, h=self.num_heads
+        )
+        x = F.scaled_dot_product_attention(
+            query=self.pe(q, tgt_coords),
+            key=self.pe(k, src_coords),
+            value=v,
+            attn_mask=attn_mask,
+            dropout_p=self.dropout if self.training else 0.0,
+        )
+        tgt = self.wo(rearrange(x, "b h n d -> b n (h d)"))
+
+        # merge windows
+        tgt = tgt.view(-1, self.tgt_window_size, self.tgt_window_size, c)
+        tgt = window_reverse(tgt, self.tgt_window_size, h, w)
+
+        # reverse cyclic shift
+        if self.tgt_shift_size > 0:
+            tgt = torch.roll(
+                tgt, shifts=(self.tgt_shift_size, self.tgt_shift_size), dims=(1, 2)
+            )
+
+        return tgt
+
+
+class WindowSelfAttention(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        window_size: int,
+        num_heads: int,
+        shift_size: int = 0,
+        dropout: float = 0.0,
+    ) -> None:
+        super().__init__()
+
+        self.num_heads = num_heads
+        self.window_size = window_size
+        self.shift_size = shift_size
+        self.dropout = dropout
+
+        self.pe = CayleySTRING(dim, num_heads)
+        self.qkv = nn.Linear(dim, dim * 3, bias=False)
+        self.wo = nn.Linear(dim, dim, bias=False)
+
+    def get_attn_mask(
+        self, height: int, width: int, device: torch.device, dtype: torch.dtype
+    ) -> Tensor | None:
+        if self.shift_size == 0:
+            return None
+
+        mask_windows = get_mask_windows(
+            height, width, self.window_size, self.shift_size, device
+        )
+        # Calculate the attention mask based on window differences
+        attn_mask = mask_windows.unsqueeze(2) - mask_windows.unsqueeze(1)
+        return attn_mask.type(dtype).masked_fill(attn_mask != 0, -torch.inf)
+
+    def forward(self, x: Tensor, coords: Tensor) -> Tensor:
+        """Forward function for Window Self-Attention.
+
+        Args:
+            x ([b, h, w, c]): Hidden states.
+            coords ([b, h, w, 2]): Absolute positions.
+        """
+        b, h, w, c = x.shape
+
+        # cyclic shift
+        if self.shift_size > 0:
+            x = x.roll(shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+            coords = coords.roll(
+                shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
+            )
+
+        # partition windows
+        x = window_partition(x, self.window_size).flatten(1, 2)
+        coords = window_partition(coords, self.window_size).flatten(1, 2)
+
+        attn_mask = self.get_attn_mask(h, w, x.device, x.dtype)
+        if attn_mask is not None:
+            attn_mask = repeat(attn_mask, "n l s -> (b n) h l s", b=b, h=self.num_heads)
+
+        # W-MSA/SW-MSA
+        q, k, v = rearrange(
+            self.qkv(x), "b n (three h d) -> three b h n d", three=3, h=self.num_heads
+        )
+        x = F.scaled_dot_product_attention(
+            query=self.pe(q, coords),
+            key=self.pe(k, coords),
+            value=v,
+            attn_mask=attn_mask,
+            dropout_p=self.dropout if self.training else 0.0,
+        )
+        x = self.wo(rearrange(x, "b h n d -> b n (h d)"))
+
+        # merge windows
+        x = x.view(-1, self.window_size, self.window_size, c)
+        x = window_reverse(x, self.window_size, h, w)
+
+        # reverse cyclic shift
+        if self.shift_size > 0:
+            x = torch.roll(x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+
+        return x
+
+
+class Block(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        src_dim: int,
+        num_heads: int,
+        window_size: int,
+        tgt_window_size: int,
+        src_window_size: int,
+        shift_size: int = 0,
+        tgt_shift_size: int = 0,
+        src_shift_size: int = 0,
+        dropout: float = 0.1,
+    ) -> None:
+        super().__init__()
+
+        self.cross_attention = WindowCrossAttention(
+            dim,
+            src_dim,
+            num_heads=num_heads,
+            tgt_window_size=tgt_window_size,
+            src_window_size=src_window_size,
+            tgt_shift_size=tgt_shift_size,
+            src_shift_size=src_shift_size,
+            dropout=dropout,
+        )
+        self.cross_attention_norm = nn.LayerNorm(dim)
+        self.cross_attention_dropout = nn.Dropout(dropout)
+
+        self.self_attention = WindowSelfAttention(
+            dim, window_size, num_heads, shift_size, dropout=dropout
+        )
+        self.self_attention_norm = nn.LayerNorm(dim)
+        self.self_attention_dropout = nn.Dropout(dropout)
+
+        self.ffn = FeedForward(dim, dim * 4)
+        self.ffn_norm = nn.LayerNorm(dim)
+        self.ffn_dropout = nn.Dropout(dropout)
+
+    def forward(
+        self, tgt: Tensor, src: Tensor, tgt_coords: Tensor, src_coords: Tensor
+    ) -> Tensor:
+        x = self.self_attention(tgt, tgt_coords)
+        tgt = self.self_attention_norm(tgt + self.self_attention_dropout(x))
+
+        x = self.cross_attention(tgt, src, tgt_coords, src_coords)
+        tgt = self.cross_attention_norm(tgt + self.cross_attention_dropout(x))
+
+        return self.ffn_norm(tgt + self.ffn_dropout(self.ffn(tgt)))
+
+
+class Stage(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        src_dim: int,
+        depth: int,
+        num_heads: int,
+        window_size: int,
+        tgt_window_size: int,
+        src_window_size: int,
+        dropout: float = 0.0,
+    ) -> None:
+        super().__init__()
+        self.blocks = nn.ModuleList()
+        for i in range(depth):
+            block = Block(
+                dim=dim,
+                src_dim=src_dim,
+                num_heads=num_heads,
+                window_size=window_size,
+                tgt_window_size=tgt_window_size,
+                src_window_size=src_window_size,
+                shift_size=0 if i % 2 == 0 else window_size // 2,
+                tgt_shift_size=0 if i % 2 == 0 else tgt_window_size // 2,
+                src_shift_size=0 if i % 2 == 0 else src_window_size // 2,
+                dropout=dropout,
+            )
+            self.blocks.append(block)
+
+    def forward(
+        self, tgt: Tensor, src: Tensor, tgt_coords: Tensor, src_coords: Tensor
+    ) -> Tensor:
+        for block in self.blocks:
+            tgt = block(tgt, src, tgt_coords, src_coords)
+        return tgt
+
+
+class FeatureSampling(nn.Module):
+    def __init__(self, in_dim: int, out_dim: int) -> None:
+        super().__init__()
+        self.reduction = nn.Linear(in_dim, out_dim, bias=False)
+        self.norm = nn.LayerNorm(out_dim)
+
+    def forward(self, points: Tensor, feature: Tensor) -> Tensor:
+        x = F.grid_sample(feature, points * 2 - 1, align_corners=False)
+        return self.norm(self.reduction(rearrange(x, "b c h w -> b h w c")))
+
+
+class LSPTransformer(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        num_classes: int,
+        query_block_size: int,
+        in_channels: list[int],
+        depths: list[int],
+        num_heads: int,
+        window_size: int,
+        tgt_window_sizes: list[int],
+        src_window_sizes: list[int],
+        num_radial_distances: int,
+        dropout: float = 0.0,
+    ) -> None:
+        super().__init__()
+
+        self.dim = dim
+        self.query_block_size = query_block_size
+        self.num_radial_distances = num_radial_distances
+
+        bottleneck, *in_channels = in_channels
+        self.feature_sampling = FeatureSampling(bottleneck, dim)
+
+        self.stages = nn.ModuleList()
+        for i, depth in enumerate(depths):
+            stage = Stage(
+                dim=dim,
+                src_dim=in_channels[i],
+                depth=depth,
+                num_heads=num_heads,
+                window_size=window_size,
+                tgt_window_size=tgt_window_sizes[i],
+                src_window_size=src_window_sizes[i],
+                dropout=dropout,
+            )
+            self.stages.append(stage)
+
+        self.input_norm = nn.ModuleList(nn.LayerNorm(d) for d in in_channels)
+
+        # output heads
+        self.class_head = nn.Linear(dim, num_classes + 1, bias=False)
+        self.point_head = MLP(dim, dim, 2, 3)
+        self.radial_distances_head = MLP(dim, dim, num_radial_distances, 3)
+
+        self.init_weights()
+
+    def init_weights(self) -> None:
+        # initialize regression layers
+        nn.init.constant_(self.point_head[-1].weight, 0.0)
+        nn.init.constant_(self.point_head[-1].bias, 0.0)
+
+    def forward(
+        self, multi_scale_features: list[Tensor], height: int, width: int
+    ) -> dict[str, Tensor | list[dict[str, Tensor]]]:
+        *multi_scale_features, bottleneck = multi_scale_features
+
+        b = bottleneck.size(0)
+
+        src = []
+        src_coords = []
+        for i, feature in enumerate(reversed(multi_scale_features)):
+            h, w = feature.shape[2:4]
+            # zeros pass through sigmoid as 0.5, i.e. the center of each grid cell
+            coords = torch.zeros(b, h, w, 2, dtype=torch.float32, device=feature.device)
+            src.append(self.input_norm[i](rearrange(feature, "b c h w -> b h w c")))
+            src_coords.append(relative_to_absolute_points(coords, height, width))
+
+        ref_points = torch.zeros(
+            b,
+            height // self.query_block_size,
+            width // self.query_block_size,
+            2,
+            dtype=torch.float32,
+            device=bottleneck.device,
+        )  # center positions
+        tgt = self.feature_sampling(
+            relative_to_absolute_points_normalized(ref_points), bottleneck
+        )
+
+        logits_list: list[Tensor] = []
+        ref_points_list: list[Tensor] = []
+        radial_distances_list: list[Tensor] = []
+
+        new_ref_points = ref_points.clone()  # for look forward twice
+        for i, stage in enumerate(self.stages):
+            tgt = stage(
+                tgt=tgt,
+                src=src[i],
+                tgt_coords=relative_to_absolute_points(ref_points, height, width),
+                src_coords=src_coords[i],
+            )
+
+            # output heads
+            delta_point = self.point_head(tgt)
+            radial_distances = self.radial_distances_head(tgt)
+            logits = self.class_head(tgt)
+
+            ref_points_list.append(
+                relative_to_absolute_points_normalized(
+                    new_ref_points + delta_point
+                ).flatten(1, 2)
+            )
+            logits_list.append(logits.flatten(1, 2))
+            radial_distances_list.append(radial_distances.flatten(1, 2))
+
+            new_ref_points = ref_points + delta_point
+            ref_points = new_ref_points.detach()
+
+        return {
+            "logits": logits_list[-1],
+            "points": ref_points_list[-1],
+            "radial_distances": radial_distances_list[-1],
+            "polygons": self.get_polygons(
+                relative_to_absolute_points(ref_points, height, width).flatten(1, 2),
+                radial_distances_list[-1],
+            ),
+            "aux_outputs": [
+                {
+                    "logits": stage_logits,
+                    "points": stage_points,
+                    "radial_distances": stage_distances,
+                }
+                for stage_logits, stage_points, stage_distances in zip(
+                    logits_list[:-1],
+                    ref_points_list[:-1],
+                    radial_distances_list[:-1],
+                    strict=True,
+                )
+            ],
+        }
+
+    @torch.no_grad()
+    @torch.autocast("cuda", enabled=False)
+    def get_polygons(self, ref_points: Tensor, radial_distances: Tensor) -> Tensor:
+        t = torch.linspace(
+            0, 1, self.num_radial_distances + 1, device=ref_points.device
+        )[:-1]
+        cos = torch.cos(2 * torch.pi * t)
+        sin = torch.sin(2 * torch.pi * t)
+
+        radial_distances = radial_distances.expm1()
+        polar = radial_distances.unsqueeze(-1) * torch.stack([sin, cos], dim=-1)
+        return ref_points.unsqueeze(-2) + polar
+
+
+class LSPDETR(PreTrainedModel):
+    config_class = LSPDETRConfig
+
+    def __init__(self, config: LSPDETRConfig) -> None:
+        super().__init__(config)
+
+        self.backbone = Swinv2Backbone.from_pretrained(
+            config.backbone, out_features=["stage1", "stage2", "stage3", "stage4"]
+        )
+
+        self.decode_head = LSPTransformer(
+            dim=config.dim,
+            num_classes=config.num_classes,
+            query_block_size=config.query_block_size,
+            in_channels=config.in_channels,
+            depths=config.depths,
+            num_heads=config.num_heads,
+            window_size=config.window_size,
+            tgt_window_sizes=config.tgt_window_sizes,
+            src_window_sizes=config.src_window_sizes,
+            num_radial_distances=config.num_radial_distances,
+            dropout=config.dropout,
+        )
+
+    def forward(self, image: Tensor) -> dict[str, Tensor]:
+        features = self.backbone(image).feature_maps
+        height, width = image.shape[2:]
+        return self.decode_head(features, height, width)
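Putting the pieces together, LSPDETR.forward maps a [b, 3, H, W] image batch to per-query logits, center points, radial distances, and decoded polygons, plus aux outputs from the intermediate stages. A dummy forward-pass sketch under the default config, assuming `model` was loaded as sketched after config.json; 256x256 matches the SwinV2 backbone's pretraining resolution and, with query_block_size=16, gives a 16x16 = 256 query grid:

import torch

model.eval()
image = torch.randn(1, 3, 256, 256)

with torch.no_grad():
    out = model(image)

print(out["logits"].shape)            # [1, 256, num_classes + 1]
print(out["points"].shape)            # [1, 256, 2], normalized to [0, 1]
print(out["radial_distances"].shape)  # [1, 256, 64]
print(out["polygons"].shape)          # [1, 256, 64, 2], in pixel coordinates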