Upload model
Browse files- config.json +1 -1
- configuration.py +1 -1
- model.safetensors +2 -2
- modeling.py +72 -77
config.json
CHANGED
|
@@ -42,5 +42,5 @@
|
|
| 42 |
"transformers_version": "4.51.3",
|
| 43 |
"use_pretrained_backbone": true,
|
| 44 |
"use_timm_backbone": false,
|
| 45 |
-
"window_size":
|
| 46 |
}
|
|
|
|
| 42 |
"transformers_version": "4.51.3",
|
| 43 |
"use_pretrained_backbone": true,
|
| 44 |
"use_timm_backbone": false,
|
| 45 |
+
"window_size": 8
|
| 46 |
}
|
configuration.py
CHANGED
|
@@ -19,7 +19,7 @@ class LSPDetrConfig(PretrainedConfig):
|
|
| 19 |
depths: tuple[int, ...] = (6, 2, 2),
|
| 20 |
query_block_size: int = 16,
|
| 21 |
num_heads: int = 12,
|
| 22 |
-
window_size: int =
|
| 23 |
tgt_window_sizes: tuple[int, ...] = (8, 8, 8),
|
| 24 |
src_window_sizes: tuple[int, ...] = (8, 16, 32),
|
| 25 |
num_radial_distances: int = 64,
|
|
|
|
| 19 |
depths: tuple[int, ...] = (6, 2, 2),
|
| 20 |
query_block_size: int = 16,
|
| 21 |
num_heads: int = 12,
|
| 22 |
+
window_size: int = 8,
|
| 23 |
tgt_window_sizes: tuple[int, ...] = (8, 8, 8),
|
| 24 |
src_window_sizes: tuple[int, ...] = (8, 16, 32),
|
| 25 |
num_radial_distances: int = 64,
|
model.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7f85a91f41a67c7c3f6fb50f71f41e95e1e47c72e561b98308d65d899160de43
|
| 3 |
+
size 204465704
|
modeling.py
CHANGED
|
@@ -21,7 +21,7 @@ class MLP(nn.Sequential):
|
|
| 21 |
hidden_dim: int,
|
| 22 |
output_dim: int,
|
| 23 |
num_layers: int,
|
| 24 |
-
act_layer: type[nn.Module] = nn.
|
| 25 |
dropout: float = 0.0,
|
| 26 |
) -> None:
|
| 27 |
assert num_layers > 1
|
|
@@ -65,6 +65,7 @@ class FeedForward(nn.Module):
|
|
| 65 |
|
| 66 |
|
| 67 |
def init_freqs(head_dim: int, num_heads: int, pos_dim: int, theta: float) -> Tensor:
|
|
|
|
| 68 |
freqs_x = []
|
| 69 |
freqs_y = []
|
| 70 |
freqs = 1 / (theta ** (torch.arange(0, head_dim, 2 * pos_dim).float() / head_dim))
|
|
@@ -168,33 +169,15 @@ def maybe_pad(x: Tensor, window_size: int) -> Tensor:
|
|
| 168 |
|
| 169 |
|
| 170 |
@torch.autocast("cuda", enabled=False)
|
| 171 |
-
def
|
| 172 |
-
|
| 173 |
-
h, w =
|
| 174 |
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
anchor_x = torch.arange(0, width, step_x, device=points.device)[:w]
|
| 179 |
-
anchor_y = torch.arange(0, height, step_y, device=points.device)[:h, None]
|
| 180 |
-
|
| 181 |
-
absolute_x = points[..., 0] * step_x + anchor_x
|
| 182 |
-
absolute_y = points[..., 1] * step_y + anchor_y
|
| 183 |
-
|
| 184 |
-
return torch.stack((absolute_x, absolute_y), dim=-1)
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
@torch.autocast("cuda", enabled=False)
|
| 188 |
-
def relative_to_absolute_points_normalized(points: Tensor) -> Tensor:
|
| 189 |
-
points = points.sigmoid()
|
| 190 |
-
h, w = points.shape[1:3]
|
| 191 |
-
|
| 192 |
-
anchor_x = torch.arange(0, 1, 1 / w, device=points.device)[:w]
|
| 193 |
-
anchor_y = torch.arange(0, 1, 1 / h, device=points.device)[:h, None]
|
| 194 |
-
|
| 195 |
-
absolute_x = points[..., 0] / w + anchor_x
|
| 196 |
-
absolute_y = points[..., 1] / h + anchor_y
|
| 197 |
|
|
|
|
|
|
|
| 198 |
return torch.stack((absolute_x, absolute_y), dim=-1)
|
| 199 |
|
| 200 |
|
|
@@ -297,8 +280,8 @@ class WindowCrossAttention(nn.Module):
|
|
| 297 |
|
| 298 |
# partition windows
|
| 299 |
tgt = window_partition(tgt, self.tgt_window_size).flatten(1, 2)
|
| 300 |
-
tgt_coords = window_partition(tgt_coords, self.tgt_window_size).flatten(1, 2)
|
| 301 |
src = window_partition(src, self.src_window_size).flatten(1, 2)
|
|
|
|
| 302 |
src_coord = window_partition(src_coord, self.src_window_size).flatten(1, 2)
|
| 303 |
|
| 304 |
attn_mask = self.get_attn_mask(
|
|
@@ -509,31 +492,13 @@ class Stage(nn.Module):
|
|
| 509 |
return tgt
|
| 510 |
|
| 511 |
|
| 512 |
-
class FeatureSampling(nn.Module):
|
| 513 |
-
def __init__(self, in_dim: int, out_dim: int) -> None:
|
| 514 |
-
super().__init__()
|
| 515 |
-
self.reduction = nn.Linear(in_dim, out_dim, bias=False)
|
| 516 |
-
self.norm = nn.LayerNorm(out_dim)
|
| 517 |
-
|
| 518 |
-
def forward(self, points: Tensor, feature: Tensor) -> Tensor:
|
| 519 |
-
x = F.grid_sample(feature, points * 2 - 1, align_corners=False)
|
| 520 |
-
return self.norm(self.reduction(rearrange(x, "b c h w -> b h w c")))
|
| 521 |
-
|
| 522 |
-
|
| 523 |
class LSPTransformer(nn.Module):
|
| 524 |
-
def __init__(
|
| 525 |
-
self,
|
| 526 |
-
config: LSPDetrConfig,
|
| 527 |
-
bottleneck_channels: int,
|
| 528 |
-
feature_channels: list[int],
|
| 529 |
-
) -> None:
|
| 530 |
super().__init__()
|
| 531 |
|
| 532 |
self.query_block_size = config.query_block_size
|
| 533 |
self.num_radial_distances = config.num_radial_distances
|
| 534 |
|
| 535 |
-
self.feature_sampling = FeatureSampling(bottleneck_channels, config.dim)
|
| 536 |
-
|
| 537 |
self.stages = nn.ModuleList()
|
| 538 |
for i, depth in enumerate(config.depths):
|
| 539 |
stage = Stage(
|
|
@@ -552,9 +517,9 @@ class LSPTransformer(nn.Module):
|
|
| 552 |
|
| 553 |
# output heads
|
| 554 |
self.class_head = nn.Linear(config.dim, config.num_classes + 1, bias=False)
|
| 555 |
-
self.point_head = MLP(config.dim, config.dim, 2,
|
| 556 |
self.radial_distances_head = MLP(
|
| 557 |
-
config.dim, config.dim, config.num_radial_distances,
|
| 558 |
)
|
| 559 |
|
| 560 |
self.init_weights()
|
|
@@ -565,29 +530,24 @@ class LSPTransformer(nn.Module):
|
|
| 565 |
nn.init.constant_(self.point_head[-1].bias, 0.0)
|
| 566 |
|
| 567 |
def forward(
|
| 568 |
-
self,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 569 |
) -> dict[str, Tensor | list[dict[str, Tensor]]]:
|
| 570 |
-
b = bottleneck.size(0)
|
| 571 |
-
|
| 572 |
src = []
|
| 573 |
src_coords = []
|
| 574 |
-
for i, feature in enumerate(
|
| 575 |
-
h, w = feature.shape
|
| 576 |
coords = torch.zeros(b, h, w, 2, dtype=torch.float32, device=feature.device)
|
| 577 |
src.append(self.input_norm[i](rearrange(feature, "b c h w -> b h w c")))
|
| 578 |
-
src_coords.append(
|
| 579 |
-
|
| 580 |
-
|
| 581 |
-
|
| 582 |
-
|
| 583 |
-
width // self.query_block_size,
|
| 584 |
-
2,
|
| 585 |
-
dtype=torch.float32,
|
| 586 |
-
device=bottleneck.device,
|
| 587 |
-
) # center positions
|
| 588 |
-
tgt = self.feature_sampling(
|
| 589 |
-
relative_to_absolute_points_normalized(ref_points), bottleneck
|
| 590 |
-
)
|
| 591 |
|
| 592 |
logits_list: list[Tensor] = []
|
| 593 |
ref_points_list: list[Tensor] = []
|
|
@@ -598,7 +558,9 @@ class LSPTransformer(nn.Module):
|
|
| 598 |
tgt = stage(
|
| 599 |
tgt=tgt,
|
| 600 |
src=src[i],
|
| 601 |
-
tgt_coords=
|
|
|
|
|
|
|
| 602 |
src_coords=src_coords[i],
|
| 603 |
)
|
| 604 |
|
|
@@ -608,8 +570,10 @@ class LSPTransformer(nn.Module):
|
|
| 608 |
logits = self.class_head(tgt)
|
| 609 |
|
| 610 |
ref_points_list.append(
|
| 611 |
-
|
| 612 |
-
new_ref_points + delta_point
|
|
|
|
|
|
|
| 613 |
).flatten(1, 2)
|
| 614 |
)
|
| 615 |
logits_list.append(logits.flatten(1, 2))
|
|
@@ -622,8 +586,8 @@ class LSPTransformer(nn.Module):
|
|
| 622 |
"logits": logits_list[-1],
|
| 623 |
"points": ref_points_list[-1],
|
| 624 |
"radial_distances": radial_distances_list[-1],
|
| 625 |
-
"absolute_points":
|
| 626 |
-
ref_points,
|
| 627 |
).flatten(1, 2),
|
| 628 |
"aux_outputs": [
|
| 629 |
{
|
|
@@ -641,17 +605,48 @@ class LSPTransformer(nn.Module):
|
|
| 641 |
}
|
| 642 |
|
| 643 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 644 |
class LSPDetrModel(PreTrainedModel):
|
| 645 |
config_class = LSPDetrConfig
|
| 646 |
|
| 647 |
def __init__(self, config: LSPDetrConfig) -> None:
|
| 648 |
super().__init__(config)
|
|
|
|
| 649 |
|
| 650 |
self.backbone = load_backbone(config)
|
| 651 |
-
_, *feature_channels,
|
| 652 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 653 |
|
| 654 |
-
|
| 655 |
-
*features, bottleneck = self.backbone(image).feature_maps
|
| 656 |
-
height, width = image.shape[2:]
|
| 657 |
-
return self.decode_head(bottleneck, features, height, width)
|
|
|
|
| 21 |
hidden_dim: int,
|
| 22 |
output_dim: int,
|
| 23 |
num_layers: int,
|
| 24 |
+
act_layer: type[nn.Module] = nn.GELU,
|
| 25 |
dropout: float = 0.0,
|
| 26 |
) -> None:
|
| 27 |
assert num_layers > 1
|
|
|
|
| 65 |
|
| 66 |
|
| 67 |
def init_freqs(head_dim: int, num_heads: int, pos_dim: int, theta: float) -> Tensor:
|
| 68 |
+
"""Taken from https://github.com/naver-ai/rope-vit/blob/main/self-attn/rope_self_attn.py."""
|
| 69 |
freqs_x = []
|
| 70 |
freqs_y = []
|
| 71 |
freqs = 1 / (theta ** (torch.arange(0, head_dim, 2 * pos_dim).float() / head_dim))
|
|
|
|
| 169 |
|
| 170 |
|
| 171 |
@torch.autocast("cuda", enabled=False)
|
| 172 |
+
def relative_to_absolute_pos(pos: Tensor, step_x: float, step_y: float) -> Tensor:
|
| 173 |
+
pos = pos.sigmoid()
|
| 174 |
+
h, w = pos.shape[1:3]
|
| 175 |
|
| 176 |
+
anchor_x = torch.arange(w, dtype=torch.float32, device=pos.device) * step_x
|
| 177 |
+
anchor_y = torch.arange(h, dtype=torch.float32, device=pos.device) * step_y
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 178 |
|
| 179 |
+
absolute_x = pos[..., 0] * step_x + anchor_x
|
| 180 |
+
absolute_y = pos[..., 1] * step_y + anchor_y.unsqueeze(1)
|
| 181 |
return torch.stack((absolute_x, absolute_y), dim=-1)
|
| 182 |
|
| 183 |
|
|
|
|
| 280 |
|
| 281 |
# partition windows
|
| 282 |
tgt = window_partition(tgt, self.tgt_window_size).flatten(1, 2)
|
|
|
|
| 283 |
src = window_partition(src, self.src_window_size).flatten(1, 2)
|
| 284 |
+
tgt_coords = window_partition(tgt_coords, self.tgt_window_size).flatten(1, 2)
|
| 285 |
src_coord = window_partition(src_coord, self.src_window_size).flatten(1, 2)
|
| 286 |
|
| 287 |
attn_mask = self.get_attn_mask(
|
|
|
|
| 492 |
return tgt
|
| 493 |
|
| 494 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 495 |
class LSPTransformer(nn.Module):
|
| 496 |
+
def __init__(self, config: LSPDetrConfig, feature_channels: list[int]) -> None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 497 |
super().__init__()
|
| 498 |
|
| 499 |
self.query_block_size = config.query_block_size
|
| 500 |
self.num_radial_distances = config.num_radial_distances
|
| 501 |
|
|
|
|
|
|
|
| 502 |
self.stages = nn.ModuleList()
|
| 503 |
for i, depth in enumerate(config.depths):
|
| 504 |
stage = Stage(
|
|
|
|
| 517 |
|
| 518 |
# output heads
|
| 519 |
self.class_head = nn.Linear(config.dim, config.num_classes + 1, bias=False)
|
| 520 |
+
self.point_head = MLP(config.dim, config.dim, 2, 2)
|
| 521 |
self.radial_distances_head = MLP(
|
| 522 |
+
config.dim, config.dim, config.num_radial_distances, 2
|
| 523 |
)
|
| 524 |
|
| 525 |
self.init_weights()
|
|
|
|
| 530 |
nn.init.constant_(self.point_head[-1].bias, 0.0)
|
| 531 |
|
| 532 |
def forward(
|
| 533 |
+
self,
|
| 534 |
+
tgt: Tensor,
|
| 535 |
+
ref_points: Tensor,
|
| 536 |
+
features: list[Tensor],
|
| 537 |
+
height: int,
|
| 538 |
+
width: int,
|
| 539 |
) -> dict[str, Tensor | list[dict[str, Tensor]]]:
|
|
|
|
|
|
|
| 540 |
src = []
|
| 541 |
src_coords = []
|
| 542 |
+
for i, feature in enumerate(features):
|
| 543 |
+
b, _, h, w = feature.shape
|
| 544 |
coords = torch.zeros(b, h, w, 2, dtype=torch.float32, device=feature.device)
|
| 545 |
src.append(self.input_norm[i](rearrange(feature, "b c h w -> b h w c")))
|
| 546 |
+
src_coords.append(
|
| 547 |
+
relative_to_absolute_pos(
|
| 548 |
+
coords, step_x=math.ceil(width / w), step_y=math.ceil(height / h)
|
| 549 |
+
)
|
| 550 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 551 |
|
| 552 |
logits_list: list[Tensor] = []
|
| 553 |
ref_points_list: list[Tensor] = []
|
|
|
|
| 558 |
tgt = stage(
|
| 559 |
tgt=tgt,
|
| 560 |
src=src[i],
|
| 561 |
+
tgt_coords=relative_to_absolute_pos(
|
| 562 |
+
ref_points, self.query_block_size, self.query_block_size
|
| 563 |
+
),
|
| 564 |
src_coords=src_coords[i],
|
| 565 |
)
|
| 566 |
|
|
|
|
| 570 |
logits = self.class_head(tgt)
|
| 571 |
|
| 572 |
ref_points_list.append(
|
| 573 |
+
relative_to_absolute_pos(
|
| 574 |
+
new_ref_points + delta_point,
|
| 575 |
+
step_x=self.query_block_size / width,
|
| 576 |
+
step_y=self.query_block_size / height,
|
| 577 |
).flatten(1, 2)
|
| 578 |
)
|
| 579 |
logits_list.append(logits.flatten(1, 2))
|
|
|
|
| 586 |
"logits": logits_list[-1],
|
| 587 |
"points": ref_points_list[-1],
|
| 588 |
"radial_distances": radial_distances_list[-1],
|
| 589 |
+
"absolute_points": relative_to_absolute_pos(
|
| 590 |
+
ref_points, self.query_block_size, self.query_block_size
|
| 591 |
).flatten(1, 2),
|
| 592 |
"aux_outputs": [
|
| 593 |
{
|
|
|
|
| 605 |
}
|
| 606 |
|
| 607 |
|
| 608 |
+
class FeatureSampling(nn.Module):
|
| 609 |
+
def __init__(self, in_dim: int, out_dim: int) -> None:
|
| 610 |
+
super().__init__()
|
| 611 |
+
self.reduction = nn.Linear(in_dim, out_dim, bias=False)
|
| 612 |
+
self.norm = nn.LayerNorm(out_dim)
|
| 613 |
+
|
| 614 |
+
def forward(self, points: Tensor, feature: Tensor) -> Tensor:
|
| 615 |
+
x = F.grid_sample(feature, points * 2 - 1, align_corners=False)
|
| 616 |
+
return self.norm(self.reduction(rearrange(x, "b c h w -> b h w c")))
|
| 617 |
+
|
| 618 |
+
|
| 619 |
class LSPDetrModel(PreTrainedModel):
|
| 620 |
config_class = LSPDetrConfig
|
| 621 |
|
| 622 |
def __init__(self, config: LSPDetrConfig) -> None:
|
| 623 |
super().__init__(config)
|
| 624 |
+
self.query_block_size = config.query_block_size
|
| 625 |
|
| 626 |
self.backbone = load_backbone(config)
|
| 627 |
+
_, *feature_channels, neck = self.backbone.num_features
|
| 628 |
+
|
| 629 |
+
self.feature_sampling = FeatureSampling(neck, config.dim)
|
| 630 |
+
self.decode_head = LSPTransformer(config, feature_channels[::-1])
|
| 631 |
+
|
| 632 |
+
def forward(self, pixel_values: Tensor) -> dict[str, Tensor]:
|
| 633 |
+
b, _, h, w = pixel_values.shape
|
| 634 |
+
|
| 635 |
+
*features, neck = self.backbone(pixel_values).feature_maps
|
| 636 |
+
|
| 637 |
+
ref_points = torch.zeros(
|
| 638 |
+
b,
|
| 639 |
+
math.ceil(h / self.query_block_size),
|
| 640 |
+
math.ceil(w / self.query_block_size),
|
| 641 |
+
2,
|
| 642 |
+
dtype=torch.float32,
|
| 643 |
+
device=neck.device,
|
| 644 |
+
) # center positions
|
| 645 |
+
tgt = self.feature_sampling(
|
| 646 |
+
relative_to_absolute_pos(
|
| 647 |
+
ref_points, self.query_block_size, self.query_block_size
|
| 648 |
+
),
|
| 649 |
+
neck,
|
| 650 |
+
)
|
| 651 |
|
| 652 |
+
return self.decode_head(tgt, ref_points, features[::-1], h, w)
|
|
|
|
|
|
|
|
|