File size: 4,710 Bytes
9f83ce9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Optional

from utils import MODELS, VIDEO_EXTENSIONS


@dataclass
class TransformConfig:
    """Augmentation and preprocessing knobs, grouped by pipeline.

    One flat config shared by every data pipeline; each section below is
    consumed only by the model family named in its comment. Only the
    AugMix and MixUp augmentation families are accepted for ``aug_type``.
    """

    # --- RGB pipeline ---
    horizontal_flip_prob: float = 0.5
    aug_type: str = "augmix"
    # default_factory gives every instance its own dict (no shared mutable
    # default between configs)
    aug_paras: dict = field(
        default_factory=lambda: {
            "magnitude": 3,
            "alpha": 1.0,
            "width": 5,
            "depth": -1,
        }
    )
    sample_rate: int = 4

    # --- Pose pipeline ---
    normalization: bool = True

    # --- SL-GCN / DSTA-SLR ---
    random_choose: bool = False
    random_shift: bool = False
    random_move: bool = False
    random_mirror: bool = False
    random_mirror_p: float = 0.5
    bone_stream: bool = False
    motion_stream: bool = False

    # --- SPOTER ---
    augmentation: bool = True
    aug_prob: float = 0.5
    noise: bool = True

    def __post_init__(self):
        """Reject augmentation families other than AugMix / MixUp."""
        assert self.aug_type in ("augmix", "mixup"), \
            "Only AugMix and MixUp are supported for now"


@dataclass
class DataConfig:
    """Dataset selection and loading options.

    Fixes over the previous revision:

    * ``transform`` was declared twice (once as ``Any = None``, once as a
      ``TransformConfig`` factory); the duplicate is collapsed into a
      single definition, kept in its original slot so the dataclass field
      order (and thus positional-argument order) is unchanged.
    * The old default ``dataset="vsl"`` failed this class's own
      ``__post_init__`` validation, making ``DataConfig()`` raise; the
      default is now the valid ``"vsl_98"``.
    """

    dataset: str = "vsl_98"  # one of {"vsl_98", "vsl_400"}
    modality: str = "rgb"    # "rgb" or "pose"
    subset: Optional[str] = None  # optional split name; None means "all"
    data_dir: str = "data/processed/vsl"
    transform: TransformConfig = field(default_factory=TransformConfig)
    fps: int = 30
    debug: bool = False

    def __post_init__(self):
        # Validate eagerly so a bad config fails at construction time
        # rather than deep inside the data pipeline.
        assert self.dataset in ["vsl_98", "vsl_400"], \
            "Only VSL dataset is supported for now"
        assert self.modality in ["rgb", "pose"], \
            "Only RGB and Pose modalities are supported for now"


@dataclass
class ModelConfig:
    """Architecture choice plus per-architecture hyperparameters.

    Only the fields in the section matching ``arch`` are consumed by any
    given model; the rest are ignored. ``arch`` must be a key of the
    project-level ``MODELS`` registry (imported from ``utils``).
    """

    arch: str = "sl_gcn"
    pretrained: str = "vsltranslation/sl_gcn_joint_v3_0"
    num_frozen_layers: int = 0
    # `list` itself is the idiomatic default_factory (was `lambda: []`);
    # each instance still gets its own fresh list.
    ignored_weights: list = field(default_factory=list)
    num_frames: int = 16

    # SL-GCN specific
    num_points: int = 27
    groups: int = 8
    block_size: int = 41
    in_channels: int = 3
    labeling_mode: str = "spatial"
    is_vector: bool = False

    # DSTA-SLR specific
    graph: str = "wlasl"
    inner_dim: int = 64
    drop_layers: int = 2
    depth: int = 4
    s_num_heads: int = 1
    window_size: int = 120

    # SPOTER specific
    hidden_dim: int = 108

    def __post_init__(self):
        # Fail fast on architectures the registry does not know about.
        assert self.arch in MODELS, f"Model {self.arch} is not supported"


@dataclass
class TrainingConfig:
    """HuggingFace-Trainer-style training arguments.

    ``__post_init__`` materializes the output directory on disk and
    derives Hub push settings from ``hub_model_id``.
    """

    output_dir: str = "experiments"  # converted to Path in __post_init__
    remove_unused_columns: bool = False
    do_train: bool = True
    use_cpu: bool = False

    # Evaluation / logging / checkpoint cadence
    eval_strategy: str = "epoch"
    logging_strategy: str = "epoch"
    save_strategy: str = "epoch"
    logging_steps: int = 1
    save_steps: int = 1
    eval_steps: int = 1

    # Optimizer (AdamW) settings
    learning_rate: float = 5e-5
    weight_decay: float = 0
    adam_beta1: float = 0.9
    adam_beta2: float = 0.999
    adam_epsilon: float = 1e-8
    warmup_ratio: float = 0.1

    # Schedule / batching
    num_train_epochs: int = 10
    per_device_train_batch_size: int = 8
    per_device_eval_batch_size: int = 8
    dataloader_num_workers: int = 0

    # Checkpointing / model selection
    load_best_model_at_end: bool = True
    metric_for_best_model: str = "accuracy"
    resume_from_checkpoint: str = None

    # Experiment tracking / Hub upload
    run_name: str = "swin3d"
    report_to: str = None
    push_to_hub: bool = False
    hub_model_id: str = None
    hub_strategy: str = "checkpoint"
    hub_private_repo: bool = True

    def __post_init__(self):
        """Create the output directory and normalize Hub settings."""
        out = Path(self.output_dir)
        # The bare default "experiments" gets a per-run subdirectory.
        if str(out) == "experiments":
            out = out / self.run_name
        out.mkdir(parents=True, exist_ok=True)
        self.output_dir = out

        if self.hub_model_id is not None:
            # Providing a model id implies pushing; a bare namespace
            # (no "/") is completed with the run name.
            self.push_to_hub = True
            if "/" not in self.hub_model_id:
                self.hub_model_id = f"{self.hub_model_id}/{self.run_name}"


@dataclass
class InferenceConfig:
    """Settings for running the demo / inference pipeline.

    ``source`` is either the literal string "webcam" or a path to an
    existing video file whose suffix is in ``VIDEO_EXTENSIONS``.
    """

    source: str = "webcam"   # converted to Path in __post_init__
    output_dir: str = "demo"  # created on disk in __post_init__
    use_onnx: bool = False
    device: str = "cpu"
    cache_dir: str = "models/huggingface"

    # Rendering options
    visualize: bool = False
    show_skeleton: bool = False

    # Sign-segmentation heuristics
    visibility: float = 0.5
    angle_threshold: int = 140
    min_num_up_frames: int = 10
    min_num_down_frames: int = 10
    delay: int = 400

    top_k: int = 3
    # SL-GCN, DSTA-SLR specific
    bone_stream: bool = False
    motion_stream: bool = False

    def __post_init__(self):
        """Validate the source and materialize the output directory."""
        src = Path(self.source)
        self.source = src
        is_webcam = str(src) == "webcam"
        is_video = src.exists() and str(src).endswith(VIDEO_EXTENSIONS)
        assert is_webcam or is_video, \
            f"Only Webcam and Video sources are supported for now (got {self.source})"
        out = Path(self.output_dir)
        out.mkdir(parents=True, exist_ok=True)
        self.output_dir = out