matejpekar committed
Commit f6fbbfa · verified · 1 Parent(s): 7e51321

Upload model

Files changed (4)
  1. config.json +1 -1
  2. configuration.py +1 -1
  3. model.safetensors +2 -2
  4. modeling.py +72 -77
config.json CHANGED
@@ -42,5 +42,5 @@
   "transformers_version": "4.51.3",
   "use_pretrained_backbone": true,
   "use_timm_backbone": false,
-  "window_size": 16
+  "window_size": 8
 }
configuration.py CHANGED
@@ -19,7 +19,7 @@ class LSPDetrConfig(PretrainedConfig):
         depths: tuple[int, ...] = (6, 2, 2),
         query_block_size: int = 16,
         num_heads: int = 12,
-        window_size: int = 16,
+        window_size: int = 8,
         tgt_window_sizes: tuple[int, ...] = (8, 8, 8),
         src_window_sizes: tuple[int, ...] = (8, 16, 32),
         num_radial_distances: int = 64,
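
Note: the default moves from 16 to 8 in both config.json and configuration.py, so a freshly loaded config should report the new value. A minimal check (the repo id here is hypothetical, and the custom LSPDetrConfig requires trust_remote_code):

    from transformers import AutoConfig

    config = AutoConfig.from_pretrained(
        "matejpekar/lsp-detr",   # hypothetical repo id for this checkpoint
        trust_remote_code=True,  # configuration.py above defines LSPDetrConfig
    )
    assert config.window_size == 8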
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:6411cad5a0ebad05cbeb8324502f020a4a2a145fa4605dd09757cedb1018ad45
-size 205648888
+oid sha256:7f85a91f41a67c7c3f6fb50f71f41e95e1e47c72e561b98308d65d899160de43
+size 204465704
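
Note: the weights shrink by roughly 1.2 MB, consistent with the shallower MLP heads in modeling.py below. A minimal sketch (not part of the commit) for verifying a downloaded file against this LFS pointer:

    import hashlib
    import os

    path = "model.safetensors"  # local download path, assumed
    sha = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):
            sha.update(chunk)
    assert os.path.getsize(path) == 204465704
    assert sha.hexdigest() == "7f85a91f41a67c7c3f6fb50f71f41e95e1e47c72e561b98308d65d899160de43"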
modeling.py CHANGED
@@ -21,7 +21,7 @@ class MLP(nn.Sequential):
         hidden_dim: int,
         output_dim: int,
         num_layers: int,
-        act_layer: type[nn.Module] = nn.ReLU,
+        act_layer: type[nn.Module] = nn.GELU,
         dropout: float = 0.0,
     ) -> None:
         assert num_layers > 1
@@ -65,6 +65,7 @@ class FeedForward(nn.Module):


 def init_freqs(head_dim: int, num_heads: int, pos_dim: int, theta: float) -> Tensor:
+    """Taken from https://github.com/naver-ai/rope-vit/blob/main/self-attn/rope_self_attn.py."""
     freqs_x = []
     freqs_y = []
     freqs = 1 / (theta ** (torch.arange(0, head_dim, 2 * pos_dim).float() / head_dim))
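
Note: the visible line is the standard RoPE inverse-frequency schedule, split across pos_dim axes. For intuition (illustrative values only, not from the commit):

    import torch

    head_dim, pos_dim, theta = 64, 2, 100.0  # assumed values
    freqs = 1 / (theta ** (torch.arange(0, head_dim, 2 * pos_dim).float() / head_dim))
    print(freqs.shape)  # torch.Size([16]) -- head_dim / (2 * pos_dim) frequencies per axis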
@@ -168,33 +169,15 @@ def maybe_pad(x: Tensor, window_size: int) -> Tensor:


 @torch.autocast("cuda", enabled=False)
-def relative_to_absolute_points(points: Tensor, height: int, width: int) -> Tensor:
-    points = points.sigmoid()
-    h, w = points.shape[1:3]
+def relative_to_absolute_pos(pos: Tensor, step_x: float, step_y: float) -> Tensor:
+    pos = pos.sigmoid()
+    h, w = pos.shape[1:3]

-    step_x = width / w
-    step_y = height / h
-
-    anchor_x = torch.arange(0, width, step_x, device=points.device)[:w]
-    anchor_y = torch.arange(0, height, step_y, device=points.device)[:h, None]
-
-    absolute_x = points[..., 0] * step_x + anchor_x
-    absolute_y = points[..., 1] * step_y + anchor_y
-
-    return torch.stack((absolute_x, absolute_y), dim=-1)
-
-
-@torch.autocast("cuda", enabled=False)
-def relative_to_absolute_points_normalized(points: Tensor) -> Tensor:
-    points = points.sigmoid()
-    h, w = points.shape[1:3]
-
-    anchor_x = torch.arange(0, 1, 1 / w, device=points.device)[:w]
-    anchor_y = torch.arange(0, 1, 1 / h, device=points.device)[:h, None]
+    anchor_x = torch.arange(w, dtype=torch.float32, device=pos.device) * step_x
+    anchor_y = torch.arange(h, dtype=torch.float32, device=pos.device) * step_y

-    absolute_x = points[..., 0] / w + anchor_x
-    absolute_y = points[..., 1] / h + anchor_y
+    absolute_x = pos[..., 0] * step_x + anchor_x
+    absolute_y = pos[..., 1] * step_y + anchor_y.unsqueeze(1)
     return torch.stack((absolute_x, absolute_y), dim=-1)

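Note: the two helpers collapse into one; pixel anchors and normalized anchors are now both expressed through step_x/step_y. A worked sketch (illustrative, assuming relative_to_absolute_pos above is in scope):

    import torch

    pos = torch.zeros(1, 2, 3, 2)  # b=1, h=2, w=3 grid of raw (pre-sigmoid) offsets
    out = relative_to_absolute_pos(pos, step_x=16.0, step_y=16.0)
    # sigmoid(0) = 0.5, so each point lands half a step into its cell:
    print(out[0, 0, :, 0])  # tensor([ 8., 24., 40.])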
 
@@ -297,8 +280,8 @@ class WindowCrossAttention(nn.Module):

         # partition windows
         tgt = window_partition(tgt, self.tgt_window_size).flatten(1, 2)
-        tgt_coords = window_partition(tgt_coords, self.tgt_window_size).flatten(1, 2)
         src = window_partition(src, self.src_window_size).flatten(1, 2)
+        tgt_coords = window_partition(tgt_coords, self.tgt_window_size).flatten(1, 2)
         src_coord = window_partition(src_coord, self.src_window_size).flatten(1, 2)

         attn_mask = self.get_attn_mask(
@@ -509,31 +492,13 @@ class Stage(nn.Module):
         return tgt


-class FeatureSampling(nn.Module):
-    def __init__(self, in_dim: int, out_dim: int) -> None:
-        super().__init__()
-        self.reduction = nn.Linear(in_dim, out_dim, bias=False)
-        self.norm = nn.LayerNorm(out_dim)
-
-    def forward(self, points: Tensor, feature: Tensor) -> Tensor:
-        x = F.grid_sample(feature, points * 2 - 1, align_corners=False)
-        return self.norm(self.reduction(rearrange(x, "b c h w -> b h w c")))
-
-
 class LSPTransformer(nn.Module):
-    def __init__(
-        self,
-        config: LSPDetrConfig,
-        bottleneck_channels: int,
-        feature_channels: list[int],
-    ) -> None:
+    def __init__(self, config: LSPDetrConfig, feature_channels: list[int]) -> None:
         super().__init__()

         self.query_block_size = config.query_block_size
         self.num_radial_distances = config.num_radial_distances

-        self.feature_sampling = FeatureSampling(bottleneck_channels, config.dim)
-
         self.stages = nn.ModuleList()
         for i, depth in enumerate(config.depths):
             stage = Stage(
@@ -552,9 +517,9 @@ class LSPTransformer(nn.Module):

         # output heads
         self.class_head = nn.Linear(config.dim, config.num_classes + 1, bias=False)
-        self.point_head = MLP(config.dim, config.dim, 2, 3)
+        self.point_head = MLP(config.dim, config.dim, 2, 2)
         self.radial_distances_head = MLP(
-            config.dim, config.dim, config.num_radial_distances, 3
+            config.dim, config.dim, config.num_radial_distances, 2
         )

         self.init_weights()
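
Note: both output heads drop from three layers to two. A sketch of what the new point head plausibly expands to (assumed from the MLP signature in the first hunk; the class body and its dropout handling are not shown here):

    import torch.nn as nn

    dim = 256  # placeholder for config.dim
    # MLP(config.dim, config.dim, 2, 2): hidden dim = dim, output dim = 2, two layers
    point_head = nn.Sequential(
        nn.Linear(dim, dim),
        nn.GELU(),          # new default act_layer from the first hunk
        nn.Linear(dim, 2),  # point_head[-1], whose bias init_weights zeroes
    )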
@@ -565,29 +530,24 @@ class LSPTransformer(nn.Module):
         nn.init.constant_(self.point_head[-1].bias, 0.0)

     def forward(
-        self, bottleneck: Tensor, features: list[Tensor], height: int, width: int
+        self,
+        tgt: Tensor,
+        ref_points: Tensor,
+        features: list[Tensor],
+        height: int,
+        width: int,
     ) -> dict[str, Tensor | list[dict[str, Tensor]]]:
-        b = bottleneck.size(0)
-
         src = []
         src_coords = []
-        for i, feature in enumerate(reversed(features)):
-            h, w = feature.shape[2:4]
+        for i, feature in enumerate(features):
+            b, _, h, w = feature.shape
             coords = torch.zeros(b, h, w, 2, dtype=torch.float32, device=feature.device)
             src.append(self.input_norm[i](rearrange(feature, "b c h w -> b h w c")))
-            src_coords.append(relative_to_absolute_points(coords, height, width))
-
-        ref_points = torch.zeros(
-            b,
-            height // self.query_block_size,
-            width // self.query_block_size,
-            2,
-            dtype=torch.float32,
-            device=bottleneck.device,
-        ) # center positions
-        tgt = self.feature_sampling(
-            relative_to_absolute_points_normalized(ref_points), bottleneck
-        )
+            src_coords.append(
+                relative_to_absolute_pos(
+                    coords, step_x=math.ceil(width / w), step_y=math.ceil(height / h)
+                )
+            )

         logits_list: list[Tensor] = []
         ref_points_list: list[Tensor] = []
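
Note on math.ceil: with padded or otherwise non-divisible input sizes, width / w can fall just below the backbone stride, and truncating would shift every anchor. An illustrative check (sizes assumed):

    import math

    height, width = 500, 750  # assumed non-divisible input size
    h, w = 63, 94             # stride-8 feature map (ceil-divided by the backbone)
    print(round(width / w, 3), round(height / h, 3))    # 7.979 7.937
    print(math.ceil(width / w), math.ceil(height / h))  # 8 8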
@@ -598,7 +558,9 @@ class LSPTransformer(nn.Module):
             tgt = stage(
                 tgt=tgt,
                 src=src[i],
-                tgt_coords=relative_to_absolute_points(ref_points, height, width),
+                tgt_coords=relative_to_absolute_pos(
+                    ref_points, self.query_block_size, self.query_block_size
+                ),
                 src_coords=src_coords[i],
             )

@@ -608,8 +570,10 @@ class LSPTransformer(nn.Module):
             logits = self.class_head(tgt)

             ref_points_list.append(
-                relative_to_absolute_points_normalized(
-                    new_ref_points + delta_point
+                relative_to_absolute_pos(
+                    new_ref_points + delta_point,
+                    step_x=self.query_block_size / width,
+                    step_y=self.query_block_size / height,
                 ).flatten(1, 2)
             )
             logits_list.append(logits.flatten(1, 2))
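
Note: the removed _normalized helper is reproduced here by passing normalized steps. A quick equivalence check (values assumed): a 512-wide image with query_block_size = 16 yields a 32-query row, and

    query_block_size, width, w = 16, 512, 32  # assumed values
    assert query_block_size / width == 1 / w == 0.03125  # new step equals the old 1 / w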
@@ -622,8 +586,8 @@ class LSPTransformer(nn.Module):
             "logits": logits_list[-1],
             "points": ref_points_list[-1],
             "radial_distances": radial_distances_list[-1],
-            "absolute_points": relative_to_absolute_points(
-                ref_points, height, width
+            "absolute_points": relative_to_absolute_pos(
+                ref_points, self.query_block_size, self.query_block_size
             ).flatten(1, 2),
             "aux_outputs": [
                 {
@@ -641,17 +605,48 @@ class LSPTransformer(nn.Module):
         }


+class FeatureSampling(nn.Module):
+    def __init__(self, in_dim: int, out_dim: int) -> None:
+        super().__init__()
+        self.reduction = nn.Linear(in_dim, out_dim, bias=False)
+        self.norm = nn.LayerNorm(out_dim)
+
+    def forward(self, points: Tensor, feature: Tensor) -> Tensor:
+        x = F.grid_sample(feature, points * 2 - 1, align_corners=False)
+        return self.norm(self.reduction(rearrange(x, "b c h w -> b h w c")))
+
+
 class LSPDetrModel(PreTrainedModel):
     config_class = LSPDetrConfig

     def __init__(self, config: LSPDetrConfig) -> None:
         super().__init__(config)
+        self.query_block_size = config.query_block_size

         self.backbone = load_backbone(config)
-        _, *feature_channels, bottleneck = self.backbone.num_features
-        self.decode_head = LSPTransformer(config, bottleneck, feature_channels[::-1])
+        _, *feature_channels, neck = self.backbone.num_features
+
+        self.feature_sampling = FeatureSampling(neck, config.dim)
+        self.decode_head = LSPTransformer(config, feature_channels[::-1])

-    def forward(self, image: Tensor) -> dict[str, Tensor]:
-        *features, bottleneck = self.backbone(image).feature_maps
-        height, width = image.shape[2:]
-        return self.decode_head(bottleneck, features, height, width)
+    def forward(self, pixel_values: Tensor) -> dict[str, Tensor]:
+        b, _, h, w = pixel_values.shape
+
+        *features, neck = self.backbone(pixel_values).feature_maps
+
+        ref_points = torch.zeros(
+            b,
+            math.ceil(h / self.query_block_size),
+            math.ceil(w / self.query_block_size),
+            2,
+            dtype=torch.float32,
+            device=neck.device,
+        ) # center positions
+        tgt = self.feature_sampling(
+            relative_to_absolute_pos(
+                ref_points, self.query_block_size, self.query_block_size
+            ),
+            neck,
+        )
+
+        return self.decode_head(tgt, ref_points, features[::-1], h, w)
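
Note: FeatureSampling moves out of LSPTransformer; LSPDetrModel now builds the reference-point grid and initial queries itself, then hands both to the decode head. A minimal sketch (shapes assumed) of the grid_sample call at its core, including the [0, 1] to [-1, 1] remap:

    import torch
    import torch.nn.functional as F

    feature = torch.randn(1, 256, 16, 16)  # assumed neck feature map
    points = torch.rand(1, 4, 4, 2)        # query centers, normalized to [0, 1]
    x = F.grid_sample(feature, points * 2 - 1, align_corners=False)
    print(x.shape)  # torch.Size([1, 256, 4, 4]) -- one feature vector per query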
 