Upload model
Browse files- modeling.py +3 -11
modeling.py
CHANGED
|
@@ -19,19 +19,13 @@ from transformers.utils.backbone_utils import load_backbone
|
|
| 19 |
from .configuration import LSPDetrConfig, STAConfig
|
| 20 |
|
| 21 |
|
| 22 |
-
|
| 23 |
-
return True if Q.device == torch.device("meta") else _is_orthogonal(Q, eps=eps)
|
| 24 |
-
|
| 25 |
-
|
| 26 |
patch(
|
| 27 |
"torch.nn.utils.parametrizations._is_orthogonal",
|
| 28 |
-
|
| 29 |
).start()
|
| 30 |
|
| 31 |
|
| 32 |
-
flex_attention = torch.compile(flex_attention, dynamic=True)
|
| 33 |
-
|
| 34 |
-
|
| 35 |
class CayleySTRING(nn.Module):
|
| 36 |
"""Implements the Cayley-STRING positional encoding.
|
| 37 |
|
|
@@ -464,6 +458,7 @@ class LSPTransformer(nn.Module):
|
|
| 464 |
"absolute_points": relative_to_absolute_pos(
|
| 465 |
ref_points, self.query_block_size, self.query_block_size
|
| 466 |
).flatten(1, 2),
|
|
|
|
| 467 |
"aux_outputs": [
|
| 468 |
{
|
| 469 |
"logits": a,
|
|
@@ -525,6 +520,3 @@ class LSPDetrModel(PreTrainedModel):
|
|
| 525 |
)
|
| 526 |
|
| 527 |
return self.decode_head(tgt, ref_points, features, h, w)
|
| 528 |
-
|
| 529 |
-
|
| 530 |
-
LSPDetrModel.from_pretrained("RationAI/LSP-DETR", trust_remote_code=True)
|
|
|
|
| 19 |
from .configuration import LSPDetrConfig, STAConfig
|
| 20 |
|
| 21 |
|
| 22 |
+
flex_attention = torch.compile(flex_attention, dynamic=True)
|
|
|
|
|
|
|
|
|
|
| 23 |
patch(
|
| 24 |
"torch.nn.utils.parametrizations._is_orthogonal",
|
| 25 |
+
lambda Q, eps=None: Q.device == torch.device("meta") or _is_orthogonal(Q, eps=eps),
|
| 26 |
).start()
|
| 27 |
|
| 28 |
|
|
|
|
|
|
|
|
|
|
| 29 |
class CayleySTRING(nn.Module):
|
| 30 |
"""Implements the Cayley-STRING positional encoding.
|
| 31 |
|
|
|
|
| 458 |
"absolute_points": relative_to_absolute_pos(
|
| 459 |
ref_points, self.query_block_size, self.query_block_size
|
| 460 |
).flatten(1, 2),
|
| 461 |
+
"embeddings": tgt.flatten(1, 2),
|
| 462 |
"aux_outputs": [
|
| 463 |
{
|
| 464 |
"logits": a,
|
|
|
|
| 520 |
)
|
| 521 |
|
| 522 |
return self.decode_head(tgt, ref_points, features, h, w)
|
|
|
|
|
|
|
|
|