Upload model
Browse files — modeling.py (+3, −14)
modeling.py
CHANGED
|
@@ -18,7 +18,7 @@ from transformers.utils.backbone_utils import load_backbone
|
|
| 18 |
from .configuration import LSPDetrConfig, STAConfig
|
| 19 |
|
| 20 |
|
| 21 |
-
flex_attention = torch.compile(flex_attention, dynamic=True)
|
| 22 |
|
| 23 |
|
| 24 |
def init_freqs(head_dim: int, num_heads: int, pos_dim: int, theta: float) -> Tensor:
|
|
@@ -106,6 +106,7 @@ class CayleySTRING(nn.Module):
|
|
| 106 |
positions ([b, n, pos_dim]): Positions tensor.
|
| 107 |
"""
|
| 108 |
# Compute (I + S)^-1 @ x
|
|
|
|
| 109 |
if self.training:
|
| 110 |
# Use linalg.solve during training for numerical stability.
|
| 111 |
y = torch.linalg.solve(
|
|
@@ -518,19 +519,7 @@ class LSPTransformer(nn.Module):
|
|
| 518 |
"absolute_points": relative_to_absolute_pos(
|
| 519 |
ref_points, self.query_block_size, self.query_block_size
|
| 520 |
).flatten(1, 2),
|
| 521 |
-
"aux_outputs": [
|
| 522 |
-
{
|
| 523 |
-
"logits": a,
|
| 524 |
-
"points": b,
|
| 525 |
-
"radial_distances": c,
|
| 526 |
-
}
|
| 527 |
-
for a, b, c in zip(
|
| 528 |
-
logits_list[:-1],
|
| 529 |
-
ref_points_list[:-1],
|
| 530 |
-
radial_distances_list[:-1],
|
| 531 |
-
strict=True,
|
| 532 |
-
)
|
| 533 |
-
],
|
| 534 |
}
|
| 535 |
|
| 536 |
|
|
|
|
| 18 |
from .configuration import LSPDetrConfig, STAConfig
|
| 19 |
|
| 20 |
|
| 21 |
+
flex_attention = torch.compile(flex_attention, dynamic=False)
|
| 22 |
|
| 23 |
|
| 24 |
def init_freqs(head_dim: int, num_heads: int, pos_dim: int, theta: float) -> Tensor:
|
|
|
|
| 106 |
positions ([b, n, pos_dim]): Positions tensor.
|
| 107 |
"""
|
| 108 |
# Compute (I + S)^-1 @ x
|
| 109 |
+
print(self.training)
|
| 110 |
if self.training:
|
| 111 |
# Use linalg.solve during training for numerical stability.
|
| 112 |
y = torch.linalg.solve(
|
|
|
|
| 519 |
"absolute_points": relative_to_absolute_pos(
|
| 520 |
ref_points, self.query_block_size, self.query_block_size
|
| 521 |
).flatten(1, 2),
|
| 522 |
+
"embeddings": tgt.flatten(1, 2),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 523 |
}
|
| 524 |
|
| 525 |
|