matejpekar committed on
Commit
116afe9
·
verified ·
1 Parent(s): 58b7416

Upload model

Browse files
Files changed (1) hide show
  1. modeling.py +3 -11
modeling.py CHANGED
@@ -19,19 +19,13 @@ from transformers.utils.backbone_utils import load_backbone
19
  from .configuration import LSPDetrConfig, STAConfig
20
 
21
 
22
- def _meta_safe_is_orthogonal(Q, eps=None):
23
- return True if Q.device == torch.device("meta") else _is_orthogonal(Q, eps=eps)
24
-
25
-
26
  patch(
27
  "torch.nn.utils.parametrizations._is_orthogonal",
28
- _meta_safe_is_orthogonal,
29
  ).start()
30
 
31
 
32
- flex_attention = torch.compile(flex_attention, dynamic=True)
33
-
34
-
35
  class CayleySTRING(nn.Module):
36
  """Implements the Cayley-STRING positional encoding.
37
 
@@ -464,6 +458,7 @@ class LSPTransformer(nn.Module):
464
  "absolute_points": relative_to_absolute_pos(
465
  ref_points, self.query_block_size, self.query_block_size
466
  ).flatten(1, 2),
 
467
  "aux_outputs": [
468
  {
469
  "logits": a,
@@ -525,6 +520,3 @@ class LSPDetrModel(PreTrainedModel):
525
  )
526
 
527
  return self.decode_head(tgt, ref_points, features, h, w)
528
-
529
-
530
- LSPDetrModel.from_pretrained("RationAI/LSP-DETR", trust_remote_code=True)
 
19
  from .configuration import LSPDetrConfig, STAConfig
20
 
21
 
22
+ flex_attention = torch.compile(flex_attention, dynamic=True)
 
 
 
23
  patch(
24
  "torch.nn.utils.parametrizations._is_orthogonal",
25
+ lambda Q, eps=None: Q.device == torch.device("meta") or _is_orthogonal(Q, eps=eps),
26
  ).start()
27
 
28
 
 
 
 
29
  class CayleySTRING(nn.Module):
30
  """Implements the Cayley-STRING positional encoding.
31
 
 
458
  "absolute_points": relative_to_absolute_pos(
459
  ref_points, self.query_block_size, self.query_block_size
460
  ).flatten(1, 2),
461
+ "embeddings": tgt.flatten(1, 2),
462
  "aux_outputs": [
463
  {
464
  "logits": a,
 
520
  )
521
 
522
  return self.decode_head(tgt, ref_points, features, h, w)