File size: 2,273 Bytes
389a8d6 d57f1a3 c3f67c4 d57f1a3 c3f67c4 389a8d6 b5a1bdb c3f67c4 d57f1a3 c3f67c4 389a8d6 45f73f8 389a8d6 c3f67c4 389a8d6 45f73f8 389a8d6 c3f67c4 389a8d6 45f73f8 389a8d6 d57f1a3 9c78786 c3f67c4 9c78786 c3f67c4 389a8d6 c3f67c4 389a8d6 c3f67c4 389a8d6 c3f67c4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 |
from typing import Any, TypedDict
from transformers import PretrainedConfig
from transformers.utils.backbone_utils import verify_backbone_config_arguments
# Typed mapping describing one STA attention setting, as used by
# LSPDetrConfig's ``self_sta_config`` / ``cross_sta_config`` arguments.
# All three keys are required and hold ints: a kernel size and the query /
# key-value tile sizes (interpretation lives in the attention code).
STAConfig = TypedDict(
    "STAConfig",
    {
        "kernel": int,
        "q_tile": int,
        "kv_tile": int,
    },
)
class LSPDetrConfig(PretrainedConfig):
    """Configuration for the ``lsp_detr`` model.

    Holds the backbone selection, validated with transformers'
    ``verify_backbone_config_arguments``, plus the head hyper-parameters.

    Args:
        use_timm_backbone: Standard transformers backbone flag, validated
            together with the other ``backbone*`` arguments.
        use_pretrained_backbone: Standard transformers backbone flag.
        backbone: Backbone checkpoint name. Pass ``None`` when supplying
            ``backbone_config`` (transformers rejects both at once).
        backbone_kwargs: Extra backbone keyword arguments. When omitted and no
            ``backbone_config`` is given, defaults to
            ``{"out_features": ["stage1", "stage2", "stage3", "stage4"]}``.
            When ``backbone_config`` is given it is left as passed (``None``
            by default), because transformers rejects non-empty
            ``backbone_kwargs`` alongside ``backbone_config``.
        backbone_config: Backbone configuration object, or its ``dict`` form
            (as produced by ``to_dict()`` / read back from ``config.json``).
            A dict is rebuilt into a config object from its ``model_type``.
        dim, num_heads, num_classes, query_block_size, num_radial_distances:
            Head hyper-parameters, stored unchanged; their meaning is defined
            by the model code.
        feature_levels: Feature-level indices, stored as a tuple.
        self_sta_config: STA setting named for self-attention (confirm in the
            model code). Defaults to ``{"kernel": 5, "q_tile": 4,
            "kv_tile": 4}``; a supplied dict is copied.
        cross_sta_config: Sequence of STA settings named for cross-attention,
            stored as a tuple of copied dicts.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If a ``dict`` ``backbone_config`` lacks ``model_type``,
            or (from ``verify_backbone_config_arguments``) if the backbone
            arguments conflict.
    """

    model_type = "lsp_detr"

    def __init__(
        self,
        use_timm_backbone: bool = False,
        use_pretrained_backbone: bool = True,
        backbone: str = "microsoft/swinv2-tiny-patch4-window16-256",
        backbone_kwargs: dict[str, Any] | None = None,
        backbone_config: Any | None = None,
        dim: int = 384,
        num_heads: int = 12,
        num_classes: int = 1,
        query_block_size: float = 8,  # 256 / 32
        feature_levels: tuple[int, ...] = (2, 1, 0, 2, 1, 0),
        num_radial_distances: int = 64,
        self_sta_config: STAConfig | None = None,
        cross_sta_config: tuple[STAConfig, ...] = (
            {"kernel": 5, "q_tile": 4, "kv_tile": 8},
            {"kernel": 5, "q_tile": 4, "kv_tile": 4},
            {"kernel": 5, "q_tile": 4, "kv_tile": 2},
        ),
        **kwargs,
    ) -> None:
        if self_sta_config is None:
            self_sta_config = {"kernel": 5, "q_tile": 4, "kv_tile": 4}
        else:
            # Copy so edits to this config never reach the caller's dict.
            self_sta_config = dict(self_sta_config)
        # Copy every entry: the default tuple's dicts live in the function's
        # defaults and would otherwise be shared (and mutable) across configs.
        cross_sta_config = tuple(dict(cfg) for cfg in cross_sta_config)
        # JSON round-trips turn tuples into lists; normalise to the declared
        # type so reloaded configs compare equal to freshly built ones.
        feature_levels = tuple(feature_levels)

        if isinstance(backbone_config, dict):
            # Nested configs are serialised as plain dicts, so a config loaded
            # from disk arrives here with a dict; rebuild the config object.
            from transformers import CONFIG_MAPPING

            backbone_model_type = backbone_config.get("model_type")
            if backbone_model_type is None:
                raise ValueError("`backbone_config` dict must contain a `model_type` key.")
            backbone_config = CONFIG_MAPPING[backbone_model_type].from_dict(backbone_config)

        if backbone_kwargs is None and backbone_config is None:
            # Only default the stage outputs when building from ``backbone``:
            # non-empty kwargs together with ``backbone_config`` are rejected.
            backbone_kwargs = {"out_features": ["stage1", "stage2", "stage3", "stage4"]}

        verify_backbone_config_arguments(
            use_timm_backbone=use_timm_backbone,
            use_pretrained_backbone=use_pretrained_backbone,
            backbone=backbone,
            backbone_config=backbone_config,
            backbone_kwargs=backbone_kwargs,
        )

        self.use_timm_backbone = use_timm_backbone
        self.use_pretrained_backbone = use_pretrained_backbone
        self.backbone = backbone
        self.backbone_config = backbone_config
        self.backbone_kwargs = backbone_kwargs
        self.dim = dim
        self.num_heads = num_heads
        self.num_classes = num_classes
        self.query_block_size = query_block_size
        self.feature_levels = feature_levels
        self.num_radial_distances = num_radial_distances
        self.self_sta_config = self_sta_config
        self.cross_sta_config = cross_sta_config
        super().__init__(**kwargs)
|