Spaces:
Runtime error
Runtime error
File size: 4,710 Bytes
9f83ce9 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 | from pathlib import Path
from typing import Any
from dataclasses import dataclass, field
from utils import MODELS, VIDEO_EXTENSIONS
@dataclass
class TransformConfig:
    """Transform / augmentation options, grouped by the model family they apply to.

    ``aug_type`` is validated on construction and must be either
    ``"augmix"`` or ``"mixup"``.
    """

    # RGB specific
    horizontal_flip_prob: float = 0.5
    aug_type: str = "augmix"
    # Built per instance so that configs never share one mutable dict.
    aug_paras: dict = field(
        default_factory=lambda: dict(magnitude=3, alpha=1.0, width=5, depth=-1)
    )
    sample_rate: int = 4

    # Pose specific
    normalization: bool = True

    # SL-GCN, DSTA-SLR specific
    random_choose: bool = False
    random_shift: bool = False
    random_move: bool = False
    random_mirror: bool = False
    random_mirror_p: float = 0.5
    bone_stream: bool = False
    motion_stream: bool = False

    # SPOTER specific
    augmentation: bool = True
    aug_prob: float = 0.5
    noise: bool = True

    def __post_init__(self):
        """Reject augmentation types other than AugMix and MixUp."""
        supported_aug_types = ("augmix", "mixup")
        assert self.aug_type in supported_aug_types, \
            "Only AugMix and MixUp are supported for now"
@dataclass
class DataConfig:
    """Dataset selection and loading options.

    ``dataset`` must be ``"vsl_98"`` or ``"vsl_400"`` and ``modality`` must be
    ``"rgb"`` or ``"pose"``; both are checked on construction.
    """

    # The default has to pass the check in __post_init__. The previous
    # default, "vsl", was not in the accepted list, so DataConfig() with no
    # arguments always raised.
    # NOTE(review): "vsl_98" is simply the first accepted value - confirm it
    # is the intended default.
    dataset: str = "vsl_98"
    modality: str = "rgb"
    subset: str = None
    data_dir: str = "data/processed/vsl"
    # Declared once, in the same positional slot the field occupied before
    # (right after data_dir), so positional construction is unchanged. A fresh
    # TransformConfig is built for every DataConfig instance.
    transform: TransformConfig = field(default_factory=TransformConfig)
    fps: int = 30
    debug: bool = False

    def __post_init__(self):
        """Validate the dataset name and the modality.

        Raises:
            AssertionError: if ``dataset`` or ``modality`` is not supported.
        """
        assert self.dataset in ["vsl_98", "vsl_400"], \
            "Only VSL dataset is supported for now"
        assert self.modality in ["rgb", "pose"], \
            "Only RGB and Pose modalities are supported for now"
@dataclass
class ModelConfig:
    """Model architecture options; ``arch`` must be one of ``MODELS``."""

    arch: str = "sl_gcn"
    pretrained: str = "vsltranslation/sl_gcn_joint_v3_0"
    num_frozen_layers: int = 0
    # A new empty list per instance, never shared between configs.
    ignored_weights: list = field(default_factory=list)
    num_frames: int = 16

    # SL-GCN specific
    num_points: int = 27
    groups: int = 8
    block_size: int = 41
    in_channels: int = 3
    labeling_mode: str = "spatial"
    is_vector: bool = False

    # DSTA-SLR specific
    graph: str = "wlasl"
    inner_dim: int = 64
    drop_layers: int = 2
    depth: int = 4
    s_num_heads: int = 1
    window_size: int = 120

    # SPOTER specific
    hidden_dim: int = 108

    def __post_init__(self):
        """Check that the requested architecture is listed in ``MODELS``."""
        assert self.arch in MODELS, f"Model {self.arch} is not supported"
@dataclass
class TrainingConfig:
    """Training-run options.

    On construction ``output_dir`` is converted to a ``Path`` and created on
    disk, and the hub settings are normalised (see ``__post_init__``).
    """

    output_dir: str = "experiments"
    remove_unused_columns: bool = False
    do_train: bool = True
    use_cpu: bool = False

    eval_strategy: str = "epoch"
    logging_strategy: str = "epoch"
    save_strategy: str = "epoch"
    logging_steps: int = 1
    save_steps: int = 1
    eval_steps: int = 1

    learning_rate: float = 5e-5
    weight_decay: float = 0
    adam_beta1: float = 0.9
    adam_beta2: float = 0.999
    adam_epsilon: float = 1e-8
    warmup_ratio: float = 0.1

    num_train_epochs: int = 10
    per_device_train_batch_size: int = 8
    per_device_eval_batch_size: int = 8
    dataloader_num_workers: int = 0

    load_best_model_at_end: bool = True
    metric_for_best_model: str = "accuracy"
    resume_from_checkpoint: str = None
    run_name: str = "swin3d"
    report_to: str = None

    push_to_hub: bool = False
    hub_model_id: str = None
    hub_strategy: str = "checkpoint"
    hub_private_repo: bool = True

    def __post_init__(self):
        """Resolve and create the output directory, then normalise hub settings.

        * If ``output_dir`` is exactly ``"experiments"``, ``run_name`` is
          appended to it. Any other value is used as given.
        * The resulting directory is created, including missing parents.
        * Setting ``hub_model_id`` switches ``push_to_hub`` on; an id without
          a ``/`` gets ``/<run_name>`` appended.
        """
        out_dir = Path(self.output_dir)
        if str(out_dir) == "experiments":
            out_dir = out_dir / self.run_name
        self.output_dir = out_dir
        self.output_dir.mkdir(parents=True, exist_ok=True)

        if self.hub_model_id is None:
            return

        self.push_to_hub = True
        if "/" not in self.hub_model_id:
            self.hub_model_id = f"{self.hub_model_id}/{self.run_name}"
@dataclass
class InferenceConfig:
    """Inference options.

    On construction ``source`` and ``output_dir`` are converted to ``Path``
    objects, ``source`` is validated, and ``output_dir`` is created on disk.
    """

    source: str = "webcam"
    output_dir: str = "demo"
    use_onnx: bool = False
    device: str = "cpu"
    cache_dir: str = "models/huggingface"

    visualize: bool = False
    show_skeleton: bool = False

    visibility: float = 0.5
    angle_threshold: int = 140
    min_num_up_frames: int = 10
    min_num_down_frames: int = 10
    delay: int = 400
    top_k: int = 3

    # SL-GCN, DSTA-SLR specific
    bone_stream: bool = False
    motion_stream: bool = False

    def _source_is_supported(self):
        """Return whether ``source`` is ``"webcam"`` or an existing path
        whose string form ends with one of ``VIDEO_EXTENSIONS``."""
        is_webcam = str(self.source) == "webcam"
        is_existing_video = (
            self.source.exists() and str(self.source).endswith(VIDEO_EXTENSIONS)
        )
        return bool(is_webcam or is_existing_video)

    def __post_init__(self):
        """Validate the source first, then create the output directory."""
        self.source = Path(self.source)
        assert self._source_is_supported(), \
            f"Only Webcam and Video sources are supported for now (got {self.source})"

        self.output_dir = Path(self.output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
|