matejpekar commited on
Commit
ae0d54d
·
verified ·
1 Parent(s): 389a8d6

Upload model

Browse files
Files changed (1) hide show
  1. modeling.py +3 -14
modeling.py CHANGED
@@ -18,7 +18,7 @@ from transformers.utils.backbone_utils import load_backbone
18
  from .configuration import LSPDetrConfig, STAConfig
19
 
20
 
21
- flex_attention = torch.compile(flex_attention, dynamic=True)
22
 
23
 
24
  def init_freqs(head_dim: int, num_heads: int, pos_dim: int, theta: float) -> Tensor:
@@ -106,6 +106,7 @@ class CayleySTRING(nn.Module):
106
  positions ([b, n, pos_dim]): Positions tensor.
107
  """
108
  # Compute (I + S)^-1 @ x
 
109
  if self.training:
110
  # Use linalg.solve during training for numerical stability.
111
  y = torch.linalg.solve(
@@ -518,19 +519,7 @@ class LSPTransformer(nn.Module):
518
  "absolute_points": relative_to_absolute_pos(
519
  ref_points, self.query_block_size, self.query_block_size
520
  ).flatten(1, 2),
521
- "aux_outputs": [
522
- {
523
- "logits": a,
524
- "points": b,
525
- "radial_distances": c,
526
- }
527
- for a, b, c in zip(
528
- logits_list[:-1],
529
- ref_points_list[:-1],
530
- radial_distances_list[:-1],
531
- strict=True,
532
- )
533
- ],
534
  }
535
 
536
 
 
18
  from .configuration import LSPDetrConfig, STAConfig
19
 
20
 
21
+ flex_attention = torch.compile(flex_attention, dynamic=False)
22
 
23
 
24
  def init_freqs(head_dim: int, num_heads: int, pos_dim: int, theta: float) -> Tensor:
 
106
  positions ([b, n, pos_dim]): Positions tensor.
107
  """
108
  # Compute (I + S)^-1 @ x
109
+ print(self.training)
110
  if self.training:
111
  # Use linalg.solve during training for numerical stability.
112
  y = torch.linalg.solve(
 
519
  "absolute_points": relative_to_absolute_pos(
520
  ref_points, self.query_block_size, self.query_block_size
521
  ).flatten(1, 2),
522
+ "embeddings": tgt.flatten(1, 2),
 
 
 
 
 
 
 
 
 
 
 
 
523
  }
524
 
525