Spaces:
Runtime error
Runtime error
| from pathlib import Path | |
| from typing import Any | |
| from dataclasses import dataclass, field | |
| from utils import MODELS, VIDEO_EXTENSIONS | |
@dataclass
class TransformConfig:
    """Transform/augmentation options, grouped by the model family they target.

    The ``@dataclass`` decorator is required here: ``field(default_factory=...)``
    and ``__post_init__`` are only honoured by the dataclass-generated
    ``__init__``. Without it ``aug_paras`` would be a raw ``Field`` object and
    the ``aug_type`` validation would never run.

    Raises:
        AssertionError: from ``__post_init__`` when ``aug_type`` is neither
            ``"augmix"`` nor ``"mixup"``.
    """

    # RGB specific
    horizontal_flip_prob: float = 0.5
    aug_type: str = "augmix"  # must be "augmix" or "mixup" (see __post_init__)
    # Built through a factory so every instance gets its own dict.
    aug_paras: dict = field(
        default_factory=lambda: {
            "magnitude": 3,
            "alpha": 1.0,
            "width": 5,
            "depth": -1,
        }
    )
    sample_rate: int = 4

    # Pose specific
    normalization: bool = True

    # SL-GCN, DSTA-SLR specific
    random_choose: bool = False
    random_shift: bool = False
    random_move: bool = False
    random_mirror: bool = False
    random_mirror_p: float = 0.5
    bone_stream: bool = False
    motion_stream: bool = False

    # SPOTER specific
    augmentation: bool = True
    aug_prob: float = 0.5
    noise: bool = True

    def __post_init__(self):
        """Reject augmentation types other than AugMix and MixUp."""
        assert self.aug_type in ("augmix", "mixup"), \
            "Only AugMix and MixUp are supported for now"
@dataclass
class DataConfig:
    """Dataset selection and loading options.

    The ``@dataclass`` decorator is required so that ``transform`` is built by
    its ``default_factory`` and ``__post_init__`` validation actually runs.

    Raises:
        AssertionError: from ``__post_init__`` when ``dataset`` is not
            ``"vsl_98"``/``"vsl_400"`` or ``modality`` is not ``"rgb"``/``"pose"``.
    """

    # NOTE(review): the default "vsl" is not in the list accepted by
    # __post_init__, so DataConfig() without an explicit ``dataset`` raises.
    # Confirm whether the default or the accepted list should change.
    dataset: str = "vsl"
    modality: str = "rgb"
    subset: str = None  # None is the default; meaning is decided by the consumer
    data_dir: str = "data/processed/vsl"
    # Declared once, at the position of its first original declaration, so the
    # generated __init__ keeps the same field order. A factory is used because
    # a shared TransformConfig instance is not a valid dataclass default.
    transform: TransformConfig = field(default_factory=TransformConfig)
    fps: int = 30
    debug: bool = False

    def __post_init__(self):
        """Validate the dataset name and the modality."""
        assert self.dataset in ("vsl_98", "vsl_400"), \
            "Only VSL dataset is supported for now"
        assert self.modality in ("rgb", "pose"), \
            "Only RGB and Pose modalities are supported for now"
@dataclass
class ModelConfig:
    """Model architecture options, grouped by the model family they target.

    The ``@dataclass`` decorator is required so that ``ignored_weights`` is
    built by its ``default_factory`` and ``__post_init__`` validation runs.

    Raises:
        AssertionError: from ``__post_init__`` when ``arch`` is not a member
            of ``MODELS`` (imported from ``utils``).
    """

    arch: str = "sl_gcn"
    pretrained: str = "vsltranslation/sl_gcn_joint_v3_0"
    num_frozen_layers: int = 0
    # Fresh list per instance; a literal [] default would be rejected by dataclass.
    ignored_weights: list = field(default_factory=list)
    num_frames: int = 16

    # SL-GCN specific
    num_points: int = 27
    groups: int = 8
    block_size: int = 41
    in_channels: int = 3
    labeling_mode: str = "spatial"
    is_vector: bool = False

    # DSTA-SLR specific
    graph: str = "wlasl"
    inner_dim: int = 64
    drop_layers: int = 2
    depth: int = 4
    s_num_heads: int = 1
    window_size: int = 120

    # SPOTER specific
    hidden_dim: int = 108

    def __post_init__(self):
        """Reject architectures that are not listed in ``MODELS``."""
        assert self.arch in MODELS, f"Model {self.arch} is not supported"
@dataclass
class TrainingConfig:
    """Training-run options.

    The ``@dataclass`` decorator is required so that ``__post_init__`` runs
    and normalises ``output_dir`` and the hub settings.

    NOTE(review): the field names look like Hugging Face ``TrainingArguments``
    keywords; confirm against the caller before renaming anything.

    Side effects:
        Instantiation creates ``output_dir`` on disk (parents included).
    """

    output_dir: str = "experiments"
    remove_unused_columns: bool = False
    do_train: bool = True
    use_cpu: bool = False

    eval_strategy: str = "epoch"
    logging_strategy: str = "epoch"
    save_strategy: str = "epoch"
    logging_steps: int = 1
    save_steps: int = 1
    eval_steps: int = 1

    learning_rate: float = 5e-5
    weight_decay: float = 0
    adam_beta1: float = 0.9
    adam_beta2: float = 0.999
    adam_epsilon: float = 1e-8
    warmup_ratio: float = 0.1

    num_train_epochs: int = 10
    per_device_train_batch_size: int = 8
    per_device_eval_batch_size: int = 8
    dataloader_num_workers: int = 0

    load_best_model_at_end: bool = True
    metric_for_best_model: str = "accuracy"
    resume_from_checkpoint: str = None
    run_name: str = "swin3d"
    report_to: str = None

    push_to_hub: bool = False
    hub_model_id: str = None
    hub_strategy: str = "checkpoint"
    hub_private_repo: bool = True

    def __post_init__(self):
        """Resolve the output directory and normalise the hub settings.

        * ``output_dir`` becomes a ``Path``. When it is still the default
          ``"experiments"``, ``run_name`` is appended as a sub-directory.
        * The resulting directory is created if it does not exist.
        * A non-``None`` ``hub_model_id`` forces ``push_to_hub`` to ``True``.
          An id without a ``/`` gets ``/<run_name>`` appended.
        """
        # Convert before comparing so the check sees the Path-normalised form.
        self.output_dir = Path(self.output_dir)
        if str(self.output_dir) == "experiments":
            self.output_dir = self.output_dir / self.run_name
        self.output_dir.mkdir(parents=True, exist_ok=True)

        if self.hub_model_id is None:
            return
        self.push_to_hub = True
        if "/" not in self.hub_model_id:
            self.hub_model_id = f"{self.hub_model_id}/{self.run_name}"
@dataclass
class InferenceConfig:
    """Inference/demo options.

    The ``@dataclass`` decorator is required so that ``__post_init__`` runs
    and validates ``source`` and prepares ``output_dir``.

    Raises:
        AssertionError: from ``__post_init__`` when ``source`` is neither
            ``"webcam"`` nor an existing path ending with one of
            ``VIDEO_EXTENSIONS`` (imported from ``utils``).

    Side effects:
        Instantiation creates ``output_dir`` on disk (parents included).
    """

    source: str = "webcam"
    output_dir: str = "demo"
    use_onnx: bool = False
    device: str = "cpu"
    cache_dir: str = "models/huggingface"

    visualize: bool = False
    show_skeleton: bool = False
    visibility: float = 0.5
    angle_threshold: int = 140
    min_num_up_frames: int = 10
    min_num_down_frames: int = 10
    delay: int = 400
    top_k: int = 3

    # SL-GCN, DSTA-SLR specific
    bone_stream: bool = False
    motion_stream: bool = False

    def __post_init__(self):
        """Validate ``source`` and make sure ``output_dir`` exists.

        Both ``source`` and ``output_dir`` are converted to ``Path`` objects.
        """
        self.source = Path(self.source)
        is_webcam = str(self.source) == "webcam"
        # str.endswith takes a str or a tuple of str, so VIDEO_EXTENSIONS is
        # assumed to be one of those (TODO confirm in utils).
        assert is_webcam or (
            self.source.exists() and str(self.source).endswith(VIDEO_EXTENSIONS)
        ), \
            f"Only Webcam and Video sources are supported for now (got {self.source})"

        self.output_dir = Path(self.output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)