File size: 2,273 Bytes
389a8d6
d57f1a3
c3f67c4
d57f1a3
c3f67c4
 
389a8d6
 
 
 
 
 
b5a1bdb
 
c3f67c4
 
 
d57f1a3
 
 
 
 
c3f67c4
 
389a8d6
45f73f8
389a8d6
c3f67c4
389a8d6
 
45f73f8
 
 
389a8d6
c3f67c4
 
389a8d6
45f73f8
389a8d6
d57f1a3
 
 
 
 
 
 
 
 
 
 
9c78786
 
c3f67c4
9c78786
 
c3f67c4
389a8d6
c3f67c4
 
389a8d6
c3f67c4
389a8d6
 
 
c3f67c4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from typing import Any, TypedDict

from transformers import PretrainedConfig
from transformers.utils.backbone_utils import verify_backbone_config_arguments


# Schema of a single STA configuration entry: three integer fields, all
# required. Used to type ``self_sta_config`` and each element of
# ``cross_sta_config`` on ``LSPDetrConfig``.
# NOTE(review): "STA" presumably stands for sliding-tile attention, with
# ``kernel`` the window size and ``q_tile`` / ``kv_tile`` the query and
# key/value tile sizes -- confirm against the attention implementation.
STAConfig = TypedDict(
    "STAConfig",
    {
        "kernel": int,
        "q_tile": int,
        "kv_tile": int,
    },
)


class LSPDetrConfig(PretrainedConfig):
    """Configuration for the ``lsp_detr`` model type.

    Validates the backbone-related arguments with
    ``verify_backbone_config_arguments`` and then stores every argument as an
    attribute of the same name. Remaining keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.

    Args:
        use_timm_backbone: Backbone-selection flag, passed to the backbone
            argument check and stored unchanged.
        use_pretrained_backbone: Backbone-selection flag, passed to the
            backbone argument check and stored unchanged.
        backbone: Backbone identifier. Defaults to a SwinV2-tiny checkpoint
            name.
            NOTE(review): the argument check is expected to reject a
            non-``None`` ``backbone`` combined with ``backbone_config``, so
            callers supplying ``backbone_config`` likely need to pass
            ``backbone=None`` explicitly -- confirm.
        backbone_kwargs: Extra keyword arguments for the backbone. When
            ``None`` and no ``backbone_config`` is given, defaults to
            requesting ``out_features`` ``stage1`` .. ``stage4``.
        backbone_config: Explicit backbone configuration, as an alternative
            to ``backbone``. When it is given, ``backbone_kwargs`` is left
            untouched (no default is injected).
        dim: Stored as ``self.dim``; presumably the model embedding width.
        num_heads: Stored as ``self.num_heads``; presumably the attention
            head count.
        num_classes: Stored as ``self.num_classes``.
        query_block_size: Stored as ``self.query_block_size``. The default
            of 8 is annotated in the signature as ``256 / 32``.
        feature_levels: Stored as ``self.feature_levels``; presumably the
            backbone feature-level index used by each decoder layer --
            TODO confirm.
        num_radial_distances: Stored as ``self.num_radial_distances``.
        self_sta_config: STA settings for self-attention. Defaults to
            ``{"kernel": 5, "q_tile": 4, "kv_tile": 4}`` when ``None``.
        cross_sta_config: STA settings for cross-attention, one entry per
            element. Each entry is shallow-copied before being stored and
            the result is always a tuple.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """

    model_type = "lsp_detr"

    def __init__(
        self,
        use_timm_backbone: bool = False,
        use_pretrained_backbone: bool = True,
        backbone: str = "microsoft/swinv2-tiny-patch4-window16-256",
        backbone_kwargs: dict[str, Any] | None = None,
        backbone_config: Any | None = None,
        dim: int = 384,
        num_heads: int = 12,
        num_classes: int = 1,
        query_block_size: float = 8,  # 256 / 32
        feature_levels: tuple[int, ...] = (2, 1, 0, 2, 1, 0),
        num_radial_distances: int = 64,
        self_sta_config: STAConfig | None = None,
        cross_sta_config: tuple[STAConfig, ...] = (
            {"kernel": 5, "q_tile": 4, "kv_tile": 8},
            {"kernel": 5, "q_tile": 4, "kv_tile": 4},
            {"kernel": 5, "q_tile": 4, "kv_tile": 2},
        ),
        **kwargs,
    ) -> None:
        if self_sta_config is None:
            self_sta_config = {"kernel": 5, "q_tile": 4, "kv_tile": 4}

        # Inject the default stage list only when the backbone is selected by
        # name. Filling it in alongside an explicit ``backbone_config`` would
        # make the argument check below see both and reject the combination.
        if backbone_kwargs is None and backbone_config is None:
            backbone_kwargs = {"out_features": ["stage1", "stage2", "stage3", "stage4"]}

        verify_backbone_config_arguments(
            use_timm_backbone=use_timm_backbone,
            use_pretrained_backbone=use_pretrained_backbone,
            backbone=backbone,
            backbone_config=backbone_config,
            backbone_kwargs=backbone_kwargs,
        )

        self.use_timm_backbone = use_timm_backbone
        self.use_pretrained_backbone = use_pretrained_backbone
        self.backbone = backbone
        self.backbone_config = backbone_config
        self.backbone_kwargs = backbone_kwargs
        self.dim = dim
        self.num_heads = num_heads
        self.num_classes = num_classes
        self.query_block_size = query_block_size
        self.feature_levels = feature_levels
        self.num_radial_distances = num_radial_distances
        self.self_sta_config = self_sta_config
        # The default entries are dict literals created once at definition
        # time; copy them so in-place edits on one config instance cannot leak
        # into the signature default (and thus into later instances).
        self.cross_sta_config = tuple(dict(entry) for entry in cross_sta_config)

        super().__init__(**kwargs)