thienphuc12339 committed on
Commit
a7eca0b
·
1 Parent(s): c1a2a6d

Add all source code

Browse files
.dockerignore ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Ignore build artifacts
2
+ *.log
3
+ *.tmp
4
+
5
+ # Ignore compiled Python files
6
+ __pycache__/
7
+ *.pyc
8
+ *.pyo
9
+ *.pyd
10
+
11
+ # Ignore files/directories
12
+ # engines/data/
Dockerfile ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.10-slim

# Disable output buffering so logs reach the terminal immediately.
ENV PYTHONUNBUFFERED=1

# Install the system libraries required at runtime.
# `libgl1` is used instead of `libgl1-mesa-glx`: the latter was only a
# transitional package and is no longer available on current Debian releases,
# which makes `apt-get install` fail on newer base images.
RUN apt-get update && apt-get install -y \
        libgl1 \
        libglib2.0-0 \
    && rm -rf /var/lib/apt/lists/*

WORKDIR /app

# Copy requirements.txt into the image and install the Python dependencies.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy the whole source tree into the image.
COPY . .

# Set the PORT environment variable (Hugging Face routes traffic to this port).
ENV PORT=7860
EXPOSE 7860

# Run the FastAPI application with uvicorn.
# This assumes the main module is app.py and the FastAPI instance is named `app`.
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # WRITER: PhucNTT2 # EMAIL: thienphuc12339@gmail.com # DATE: 11/2023
2
+ # FROM: akaOCR Team
3
+ # ALL USE CASES MUST BE APPROVED BY AKAOCR TEAM
app.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ from fastapi import FastAPI, HTTPException, UploadFile, File, Form
4
+ from fastapi.responses import JSONResponse
5
+ from pydantic import BaseModel
6
+ from pathlib import Path
7
+ import shutil
8
+ import logging
9
+ import uvicorn
10
+ import asyncio
11
+ from typing import Optional
12
+
13
+ from configs import ModelConfig, InferenceConfig
14
+ from tools.models import load_pipeline
15
+ from inference import inference as run_inference
16
+
17
+ # Initialize FastAPI app
18
+ app = FastAPI(title="Sign Language Recognition API")
19
+
20
+ # Configure logging
21
+ logging.basicConfig(level=logging.INFO)
22
+ logger = logging.getLogger(__name__)
23
+
24
+ # Define a Pydantic model for the response
25
class InferenceResponse(BaseModel):
    # Response body of the /inference endpoint.
    status: str  # the endpoint in this file sets it to "success"
    predictions: Optional[list] = None  # rows loaded from results.csv, one dict per row
    message: Optional[str] = None  # optional free-form text; not populated by the endpoint in this file
29
+
30
# Maps the public `model_name` form value to the pretrained checkpoint to load.
_PRETRAINED_BY_MODEL = {
    "spoter": "vsltranslation/spoter_v3.0",
    "sl_gcn": "vsltranslation/sl_gcn_joint_v3_0",
    "dsta_slr": "vsltranslation/dsta_slr_joint_motion_v3_0",
}


@app.post("/inference", response_model=InferenceResponse)
async def inference_endpoint(
    file: UploadFile = File(...),
    model_name: str = Form(...),
    output_dir: Optional[str] = Form("output")
):
    """
    Handle a sign-language recognition request for an uploaded video.

    Args:
        file (UploadFile): The uploaded video file.
        model_name (str): Model to use (one of 'spoter', 'sl_gcn', 'dsta_slr').
        output_dir (str, optional): Relative directory where the upload and the
            results are stored. Defaults to 'output'.

    Returns:
        InferenceResponse: The recognition result.

    Raises:
        HTTPException: 400 for an unsupported file type, an unsupported model
            name or an invalid output directory; 500 when inference fails.
    """
    import pandas as pd

    # The client controls the filename: keep only its last path component so
    # values such as "../../etc/x.mp4" cannot escape the output directory.
    filename = Path((file.filename or "").replace("\\", "/")).name
    if not filename.endswith((".mp4", ".avi", ".mov", ".mkv")):
        raise HTTPException(status_code=400, detail="Unsupported file type.")

    # Validate the model before touching the disk; a bad name is a client
    # error (400), not a server failure (500).
    pretrained = _PRETRAINED_BY_MODEL.get(model_name)
    if pretrained is None:
        raise HTTPException(status_code=400, detail="Unsupported model name.")

    # The output directory is also client-controlled: refuse absolute paths
    # and parent-directory references so writes stay below the working dir.
    output_path = Path(output_dir or "output")
    if output_path.is_absolute() or ".." in output_path.parts:
        raise HTTPException(status_code=400, detail="Invalid output directory.")
    output_path.mkdir(parents=True, exist_ok=True)

    # Store the uploaded video next to the results.
    video_path = output_path / filename
    with open(video_path, "wb") as buffer:
        shutil.copyfileobj(file.file, buffer)

    logger.info(f"Video saved to {video_path}")

    try:
        model_config = ModelConfig(arch=model_name, pretrained=pretrained)

        inference_config = InferenceConfig(
            source=str(video_path),
            output_dir=str(output_path),
            use_onnx=False,
            device="cpu",  # switch to "cuda" when a GPU is available
            cache_dir="models/huggingface",
            visualize=False,
            show_skeleton=False,
            visibility=0.5,
            angle_threshold=140,
            min_num_up_frames=10,
            min_num_down_frames=10,
            delay=400,
            top_k=3,
            bone_stream=False,
            motion_stream=False
        )

        # Load the pipeline
        pipeline = load_pipeline(model_config, inference_config)
        logger.info("Pipeline loaded successfully.")

        # Run inference
        run_inference(model_config, inference_config, pipeline)
        logger.info("Inference completed successfully.")

        # Read the results back from the CSV file written by the inference run.
        results_csv = output_path / "results.csv"
        predictions = []
        if results_csv.exists():
            try:
                predictions = pd.read_csv(results_csv).to_dict(orient="records")
            except pd.errors.EmptyDataError:
                # A run that produced no rows may leave a CSV without columns;
                # report that as "no predictions" rather than a server error.
                predictions = []

        return InferenceResponse(status="success", predictions=predictions)

    except Exception as e:
        logger.exception("Error during inference.")
        raise HTTPException(status_code=500, detail=str(e))
configs/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .arguments import *
configs/arguments.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #configs/arguments.py
2
+
3
+ from pathlib import Path
4
+ from typing import Any
5
+ from dataclasses import dataclass, field
6
+ from utils import MODELS, VIDEO_EXTENSIONS
7
+
8
+
9
@dataclass
class TransformConfig:
    """Augmentation / preprocessing options, grouped by the model family they target."""

    # --- RGB models ---
    horizontal_flip_prob: float = 0.5
    aug_type: str = "augmix"
    aug_paras: dict = field(
        default_factory=lambda: dict(magnitude=3, alpha=1.0, width=5, depth=-1)
    )
    sample_rate: int = 4

    # --- Pose models ---
    normalization: bool = True

    # --- SL-GCN and DSTA-SLR ---
    random_choose: bool = False
    random_shift: bool = False
    random_move: bool = False
    random_mirror: bool = False
    random_mirror_p: float = 0.5
    bone_stream: bool = False
    motion_stream: bool = False

    # --- SPOTER ---
    augmentation: bool = True
    aug_prob: float = 0.5
    noise: bool = True

    def __post_init__(self):
        """Reject any augmentation type other than "augmix" or "mixup"."""
        supported_aug_types = ("augmix", "mixup")
        assert self.aug_type in supported_aug_types, \
            "Only AugMix and MixUp are supported for now"
44
+
45
+
46
@dataclass
class DataConfig:
    """
    Dataset selection and loading options.

    `dataset` must be "vsl_98" or "vsl_400" and `modality` must be "rgb" or
    "pose"; anything else fails the checks in `__post_init__`.
    """

    # The previous default ("vsl") was rejected by __post_init__, so a bare
    # DataConfig() could never be constructed; default to an accepted value.
    dataset: str = "vsl_98"
    modality: str = "rgb"
    subset: str = None
    data_dir: str = "data/processed/vsl"
    # `transform` used to be declared twice (first as `Any = None`, then as a
    # TransformConfig factory). Only the second declaration took effect, but
    # the field kept the position of the first, so it is declared once, here,
    # to preserve the positional argument order.
    transform: TransformConfig = field(default_factory=TransformConfig)
    fps: int = 30
    debug: bool = False

    def __post_init__(self):
        """Validate the dataset name and the modality."""
        # Kept as assertions so callers that catch AssertionError still work.
        assert self.dataset in ["vsl_98", "vsl_400"], \
            "Only VSL dataset is supported for now"
        assert self.modality in ["rgb", "pose"], \
            "Only RGB and Pose modalities are supported for now"
64
+
65
+
66
@dataclass
class ModelConfig:
    """Architecture choice, pretrained checkpoint and per-architecture hyper-parameters."""

    arch: str = "sl_gcn"
    pretrained: str = "vsltranslation/sl_gcn_joint_v3_0"
    num_frozen_layers: int = 0
    ignored_weights: list = field(default_factory=list)
    num_frames: int = 16

    # --- SL-GCN ---
    num_points: int = 27
    groups: int = 8
    block_size: int = 41
    in_channels: int = 3
    labeling_mode: str = "spatial"
    is_vector: bool = False

    # --- DSTA-SLR ---
    graph: str = "wlasl"
    inner_dim: int = 64
    drop_layers: int = 2
    depth: int = 4
    s_num_heads: int = 1
    window_size: int = 120

    # --- SPOTER ---
    hidden_dim: int = 108

    def __post_init__(self):
        """Check that `arch` is one of the names listed in `MODELS`."""
        is_known_arch = self.arch in MODELS
        assert is_known_arch, f"Model {self.arch} is not supported"
95
+
96
+
97
@dataclass
class TrainingConfig:
    """Training-run options: output location, schedules, optimiser settings and hub upload."""

    output_dir: str = "experiments"
    remove_unused_columns: bool = False
    do_train: bool = True
    use_cpu: bool = False

    # Evaluation / logging / checkpoint cadence
    eval_strategy: str = "epoch"
    logging_strategy: str = "epoch"
    save_strategy: str = "epoch"
    logging_steps: int = 1
    save_steps: int = 1
    eval_steps: int = 1

    # Optimiser
    learning_rate: float = 5e-5
    weight_decay: float = 0
    adam_beta1: float = 0.9
    adam_beta2: float = 0.999
    adam_epsilon: float = 1e-8
    warmup_ratio: float = 0.1

    # Loop sizing
    num_train_epochs: int = 10
    per_device_train_batch_size: int = 8
    per_device_eval_batch_size: int = 8
    dataloader_num_workers: int = 0

    # Checkpoint selection
    load_best_model_at_end: bool = True
    metric_for_best_model: str = "accuracy"
    resume_from_checkpoint: str = None

    # Run naming, reporting and hub upload
    run_name: str = "swin3d"
    report_to: str = None
    push_to_hub: bool = False
    hub_model_id: str = None
    hub_strategy: str = "checkpoint"
    hub_private_repo: bool = True

    def __post_init__(self):
        """Resolve and create the output directory, then normalise the hub settings."""
        root = Path(self.output_dir)
        # Only the stock "experiments" root gets a per-run sub-directory.
        self.output_dir = root / self.run_name if str(root) == "experiments" else root
        self.output_dir.mkdir(parents=True, exist_ok=True)

        if self.hub_model_id is None:
            return

        # Supplying a hub id implies pushing to the hub.
        self.push_to_hub = True
        if "/" not in self.hub_model_id:
            # An id without a slash is completed with the run name.
            self.hub_model_id = f"{self.hub_model_id}/{self.run_name}"
144
+
145
+
146
@dataclass
class InferenceConfig:
    """Inference options: input source, output location, sign-segmentation thresholds."""

    source: str = "webcam"
    output_dir: str = "demo"
    use_onnx: bool = False
    device: str = "cpu"
    cache_dir: str = "models/huggingface"

    # Rendering
    visualize: bool = False
    show_skeleton: bool = False

    # Arm-based sign segmentation
    visibility: float = 0.5
    angle_threshold: int = 140
    min_num_up_frames: int = 10
    min_num_down_frames: int = 10
    delay: int = 400

    top_k: int = 3
    # --- SL-GCN and DSTA-SLR ---
    bone_stream: bool = False
    motion_stream: bool = False

    def __post_init__(self):
        """Convert both paths to `Path`, validate the source and create the output directory."""
        self.source = Path(self.source)
        is_webcam = str(self.source) == "webcam"
        # A file source must exist and end with one of VIDEO_EXTENSIONS.
        is_video_file = self.source.exists() and str(self.source).endswith(VIDEO_EXTENSIONS)
        assert is_webcam or is_video_file, \
            f"Only Webcam and Video sources are supported for now (got {self.source})"

        self.output_dir = Path(self.output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
configs/dsta_slr.yaml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #configs/dsta_slr.yaml
2
+
3
+ model:
4
+ arch: dsta_slr
5
+ pretrained: vsltranslation/dsta_slr_joint_motion_v3_0
6
+ inference:
7
+ source: webcam
8
+ output_dir: demo/run_1
9
+ use_onnx: True
10
+ show_skeleton: True
11
+ visualize: True
12
+ bone_stream: False
13
+ motion_stream: True
configs/sl_gcn.yaml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #configs/sl_gcn.yaml
2
+
3
+ model:
4
+ arch: sl_gcn
5
+ pretrained: vsltranslation/sl_gcn_joint_v3_0
6
+ inference:
7
+ source: webcam
8
+ output_dir: demo/run_1
9
+ use_onnx: True
10
+ show_skeleton: True
11
+ visualize: True
12
+ bone_stream: True
13
+ motion_stream: False
configs/spoter.yaml ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #configs/spoter.yaml
2
+
3
+ model:
4
+ arch: spoter
5
+ pretrained: vsltranslation/spoter_v3.0
6
+ inference:
7
+ source: webcam
8
+ output_dir: demo/run_1
9
+ use_onnx: True
10
+ show_skeleton: True
11
+ visualize: True
data/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .utils import *
data/utils.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #data/utils.py
2
+
3
+ import numpy as np
4
+ from mediapipe.python.solutions import pose
5
+ from visualization import draw_text_on_image
6
+
7
+
8
class Arm:
    """Per-arm tracking state: landmark indices, last seen joints, elbow angle and timing."""

    def __init__(
        self,
        side: str,
        visibility: float = 0.5,
    ) -> None:
        """
        Select the landmark indices for `side` and start from a cleared state.

        Raises:
            ValueError: if `side` is neither "left" nor "right".
        """
        if side not in ("left", "right"):
            raise ValueError("Side must be either 'left' or 'right'")

        # PoseLandmark members are named LEFT_* / RIGHT_*, so look them up by name.
        prefix = side.upper()
        self.shoulde_idx = pose.PoseLandmark[f"{prefix}_SHOULDER"].value
        self.elbow_idx = pose.PoseLandmark[f"{prefix}_ELBOW"].value
        self.wrist_idx = pose.PoseLandmark[f"{prefix}_WRIST"].value
        self.visibility = visibility

        # The initial state is exactly the reset state.
        self.reset_state()

    def reset_state(self) -> None:
        """Clear the up/down flag, frame counters, timestamps, joints and angle."""
        self.is_up = False
        self.num_up_frames = self.num_down_frames = 0
        self.start_time = self.end_time = 0
        self.shoulder = self.elbow = self.wrist = None
        self.angle = 0

    def set_pose(self, landmarks) -> bool:
        """
        Store the (x, y) of shoulder, elbow and wrist and recompute the angle.

        Joints are processed in that order; the first one whose visibility is
        below `self.visibility` aborts the update and returns False, leaving
        joints stored before it updated and the angle unchanged.
        """
        joints = (
            ("shoulder", self.shoulde_idx),
            ("elbow", self.elbow_idx),
            ("wrist", self.wrist_idx),
        )
        for attr_name, idx in joints:
            if landmarks[idx].visibility < self.visibility:
                return False
            setattr(self, attr_name, (landmarks[idx].x, landmarks[idx].y))

        self.angle = calculate_angle(self.shoulder, self.elbow, self.wrist)
        return True

    def visualize(
        self,
        frame: np.ndarray,
        position: tuple = (20, 50),
        prefix: str = "Angle",
        color: tuple = (0, 0, 255),
    ) -> np.ndarray:
        """Draw "<prefix>: <angle rounded to 2 decimals>" on `frame` and return the result."""
        rounded_angle = str(round(self.angle, 2))
        label = prefix + ": " + rounded_angle
        return draw_text_on_image(
            image=frame,
            text=label,
            position=position,
            color=color,
            font_size=20,
        )
87
+
88
+
89
def get_sample_timestamp(left_arm: Arm, right_arm: Arm) -> tuple:
    """
    Derive the (start, end) time of a sign from the two arms' recorded times.

    An arm counts as "available" when both its start and end time are positive.
    - Only the left arm available and the right never started: left arm's times.
    - Only the right arm available and the left never started: right arm's times.
    - Both available and both currently down: earliest start, latest end.
    Otherwise both values stay 0.

    The result is divided by 1000 before being returned (milliseconds to
    seconds, assuming the arms store milliseconds - confirm against the caller).
    """
    left_ready = left_arm.start_time > 0 and left_arm.end_time > 0
    right_ready = right_arm.start_time > 0 and right_arm.end_time > 0

    begin, finish = 0, 0
    # The three cases are mutually exclusive, so an elif chain is equivalent
    # to testing them one after another.
    if left_ready and right_arm.start_time == 0:
        begin, finish = left_arm.start_time, left_arm.end_time
    elif right_ready and left_arm.start_time == 0:
        begin, finish = right_arm.start_time, right_arm.end_time
    elif left_ready and right_ready and not (left_arm.is_up or right_arm.is_up):
        begin = min(left_arm.start_time, right_arm.start_time)
        finish = max(left_arm.end_time, right_arm.end_time)

    return begin / 1000, finish / 1000
111
+
112
+
113
def calculate_angle(a: tuple, b: tuple, c: tuple) -> float:
    """
    Return the angle at vertex `b` formed by points `a` and `c`, in degrees.

    Only the first two coordinates of each point are used. The result is
    folded into the range [0, 180].
    """
    first = np.array(a)
    vertex = np.array(b)
    last = np.array(c)

    # Direction of each ray leaving the vertex, measured from the x axis.
    heading_to_last = np.arctan2(last[1] - vertex[1], last[0] - vertex[0])
    heading_to_first = np.arctan2(first[1] - vertex[1], first[0] - vertex[0])

    radians = heading_to_last - heading_to_first
    degrees = np.abs(radians * 180.0 / np.pi)

    # Report the smaller of the two angles between the rays.
    if degrees > 180:
        return 360 - degrees
    return degrees
122
+
123
+
124
def ok_to_get_frame(
    arm: Arm,
    angle_threshold: int,
    min_num_up_frames: int,
    min_num_down_frames: int,
    current_time: int,
    delay: int,
) -> bool:
    """
    Advance the up/down state of `arm` by one frame and tell whether to keep the frame.

    The arm's `is_up` flag only flips after the angle condition has held for a
    number of consecutive frames (`min_num_up_frames` to go up,
    `min_num_down_frames` to go down). `arm` is mutated in place.

    NOTE(review): the nesting below was reconstructed from a source that had
    lost its indentation - confirm against the original file.

    Args:
        arm: Arm whose `angle` was already updated for this frame.
        angle_threshold: Angles strictly between 0 and this value count as "up".
        min_num_up_frames: Consecutive "up" frames needed before `is_up` becomes True.
        min_num_down_frames: Consecutive "down" frames needed before `is_up` becomes False.
        current_time: Timestamp of the current frame (same unit as `delay`).
        delay: Margin subtracted from the start time and added to the end time.

    Returns:
        True while the arm is considered up (including the frames counted
        towards going down), False otherwise.
    """
    if 0 < arm.angle < angle_threshold:
        # Angle satisfies the "up" condition.
        if arm.is_up:
            # Still up: cancel any pending "going down" count.
            arm.num_down_frames = 0
            arm.end_time = 0
        else:
            if arm.num_up_frames == min_num_up_frames:
                # Held long enough: the arm is now up.
                arm.is_up = True
                arm.num_up_frames = 0
            else:
                if arm.num_up_frames == 0:
                    # First frame of a candidate sign: remember when it began.
                    arm.start_time = current_time - delay
                arm.num_up_frames += 1
                return False
    else:
        # Angle does not satisfy the "up" condition.
        if arm.is_up:
            if arm.num_down_frames == min_num_down_frames:
                # Held long enough: the arm is now down.
                arm.is_up = False
                arm.num_down_frames = 0
            else:
                if arm.num_down_frames == 0:
                    # First frame of a candidate end: remember when it ended.
                    arm.end_time = current_time + delay
                arm.num_down_frames += 1
                return True
        else:
            # Not up and not rising: discard any partial "going up" count.
            arm.num_up_frames = 0
            arm.start_time = 0

    return arm.is_up
inference.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # inference.py
2
+
3
+ import logging
4
+ import pandas as pd
5
+ import cv2
6
+ import numpy as np
7
+ from pathlib import Path
8
+ import time
9
+
10
+ from configs import ModelConfig, InferenceConfig
11
+ from tools.models import load_pipeline
12
+ from utils import POSE_BASED_MODELS
13
+ from data import Arm, get_sample_timestamp, ok_to_get_frame
14
+ from visualization.utils import draw_text_on_image
15
+ from tools.models import Predictions
16
+
17
def _build_skeleton_styles(show_skeleton: bool) -> tuple:
    """Return (pose_style, right_hand_style, left_hand_style, pose_connections, hand_connections)."""
    import mediapipe as mp
    from mediapipe.python.solutions.pose import PoseLandmark
    from mediapipe.python.solutions.hands import HandLandmark
    from mediapipe.python.solutions.drawing_utils import DrawingSpec

    mp_holistic = mp.solutions.holistic
    mp_drawing_styles = mp.solutions.drawing_styles

    pose_style = mp_drawing_styles.get_default_pose_landmarks_style()
    right_hand_style = mp_drawing_styles.get_default_hand_landmarks_style()
    left_hand_style = mp_drawing_styles.get_default_hand_landmarks_style()
    pose_connections = list(mp_holistic.POSE_CONNECTIONS)
    hand_connections = list(mp_holistic.HAND_CONNECTIONS)

    if not show_skeleton:
        return pose_style, right_hand_style, left_hand_style, pose_connections, hand_connections

    # Landmarks that stay visible; every other landmark is drawn with a zero
    # size spec and the connections touching it are removed.
    pose_landmarks = [
        PoseLandmark.NOSE,
        PoseLandmark.LEFT_EYE,
        PoseLandmark.RIGHT_EYE,
        PoseLandmark.LEFT_SHOULDER,
        PoseLandmark.RIGHT_SHOULDER,
        PoseLandmark.LEFT_ELBOW,
        PoseLandmark.RIGHT_ELBOW,
        PoseLandmark.LEFT_WRIST,
        PoseLandmark.RIGHT_WRIST
    ]
    hand_landmarks = [
        HandLandmark.WRIST,
        HandLandmark.INDEX_FINGER_TIP, HandLandmark.INDEX_FINGER_DIP, HandLandmark.INDEX_FINGER_PIP, HandLandmark.INDEX_FINGER_MCP,
        HandLandmark.MIDDLE_FINGER_TIP, HandLandmark.MIDDLE_FINGER_DIP, HandLandmark.MIDDLE_FINGER_PIP, HandLandmark.MIDDLE_FINGER_MCP,
        HandLandmark.RING_FINGER_TIP, HandLandmark.RING_FINGER_DIP, HandLandmark.RING_FINGER_PIP, HandLandmark.RING_FINGER_MCP,
        HandLandmark.PINKY_TIP, HandLandmark.PINKY_DIP, HandLandmark.PINKY_PIP, HandLandmark.PINKY_MCP,
        HandLandmark.THUMB_TIP, HandLandmark.THUMB_IP, HandLandmark.THUMB_MCP, HandLandmark.THUMB_CMC,
    ]

    for landmark in PoseLandmark:
        if landmark in pose_landmarks:
            pose_style[landmark] = DrawingSpec(color=(0,255,0), thickness=2, circle_radius=2)
        else:
            pose_style[landmark] = DrawingSpec(color=(0,0,0), thickness=0, circle_radius=0)
            pose_connections = [conn for conn in pose_connections if landmark.value not in conn]

    for landmark in HandLandmark:
        if landmark in hand_landmarks:
            right_hand_style[landmark] = DrawingSpec(color=(0,0,255), thickness=2, circle_radius=2)
            left_hand_style[landmark] = DrawingSpec(color=(255,0,0), thickness=2, circle_radius=2)
        else:
            right_hand_style[landmark] = DrawingSpec(color=(0,0,0), thickness=0, circle_radius=0)
            left_hand_style[landmark] = DrawingSpec(color=(0,0,0), thickness=0, circle_radius=0)
            hand_connections = [conn for conn in hand_connections if landmark.value not in conn]

    return pose_style, right_hand_style, left_hand_style, pose_connections, hand_connections


def inference(model_config: ModelConfig, inference_config: InferenceConfig, pipeline) -> dict:
    """
    Run the recognition process over a video.

    Frames are collected while either arm is detected as raised; once a sign's
    start and end times are both known the collected frames are passed to
    `pipeline` and the prediction is merged into the results.

    Args:
        model_config (ModelConfig): Model configuration.
        inference_config (InferenceConfig): Inference configuration.
        pipeline: The already loaded pipeline.

    Returns:
        dict: {"video_path": path of the rendered video (None when no output
        directory is configured), "results": merged recognition results}.
    """
    import mediapipe as mp

    # Load video: fall back to camera index 0 when the source is not an existing file.
    source = str(inference_config.source) if Path(inference_config.source).is_file() else 0
    cap = cv2.VideoCapture(source)
    output_dir = None
    writer = None

    # InferenceConfig declares no `use_pose_model` field, so reading it
    # directly raised AttributeError. Honour the attribute when a caller's
    # config provides it, otherwise derive it from the model architecture.
    use_pose_model = getattr(
        inference_config, "use_pose_model", model_config.arch in POSE_BASED_MODELS
    )

    results = None
    try:
        if inference_config.output_dir is not None:
            output_dir = Path(inference_config.output_dir)
            output_dir.mkdir(parents=True, exist_ok=True)
            writer = cv2.VideoWriter(
                str(output_dir / "output.mp4"),
                cv2.VideoWriter_fourcc(*"mp4v"),
                cap.get(cv2.CAP_PROP_FPS),
                (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))),
            )

        # Init Mediapipe
        mp_holistic = mp.solutions.holistic
        mp_drawing = mp.solutions.drawing_utils
        (
            custom_pose_style,
            custom_right_hand_style,
            custom_left_hand_style,
            custom_pose_connections,
            custom_hand_connections,
        ) = _build_skeleton_styles(inference_config.show_skeleton)

        # Init variables
        right_arm = Arm("right", inference_config.visibility)
        left_arm = Arm("left", inference_config.visibility)
        data = []
        predictions = Predictions()

        with mp_holistic.Holistic(min_detection_confidence=0.9, min_tracking_confidence=0.5) as holistic:
            while cap.isOpened():
                success, frame = cap.read()
                if not success:
                    break

                # Recolor image to RGB, because mp processes on RGB image
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frame.flags.writeable = False

                # Make detections
                detection_results = holistic.process(frame)

                # Recolor image back to BGR, because cv2 processes on BGR image
                frame.flags.writeable = True
                frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)

                # Frames without a detected pose are skipped entirely (they are
                # not written to the output video either, as before).
                if detection_results.pose_landmarks is None:
                    continue
                landmarks = detection_results.pose_landmarks.landmark

                left_arm.set_pose(landmarks)
                right_arm.set_pose(landmarks)

                # Check if arms are up or down. Both calls must run every frame
                # because each one updates its arm's state.
                gate_options = dict(
                    angle_threshold=inference_config.angle_threshold,
                    min_num_up_frames=inference_config.min_num_up_frames,
                    min_num_down_frames=inference_config.min_num_down_frames,
                    current_time=cap.get(cv2.CAP_PROP_POS_MSEC),
                    delay=inference_config.delay,
                )
                left_arm_ok_to_get_frame = ok_to_get_frame(arm=left_arm, **gate_options)
                right_arm_ok_to_get_frame = ok_to_get_frame(arm=right_arm, **gate_options)
                if left_arm_ok_to_get_frame or right_arm_ok_to_get_frame:
                    predictions = Predictions()
                    data.append(detection_results.pose_landmarks if use_pose_model else frame)

                # Calculate the start and end time of sign.
                # get_sample_timestamp already divides by 1000, so the values
                # must not be divided a second time here.
                start_time, end_time = get_sample_timestamp(left_arm, right_arm)

                if start_time != 0 and end_time != 0:
                    # Run inference
                    start_inference_time = time.time()
                    predictions = Predictions(predictions=pipeline(np.array(data)))
                    predictions.inference_time = time.time() - start_inference_time

                    predictions.start_time = start_time
                    predictions.end_time = end_time
                    logging.info(str(predictions))
                    results = predictions.merge_results(results)

                    # Reset variables
                    left_arm.reset_state()
                    right_arm.reset_state()
                    data = []

                # Render detections
                frame = left_arm.visualize(frame, (20, 10), "Left arm angle")
                frame = right_arm.visualize(frame, (20, 40), "Right arm angle")
                frame = predictions.visualize(frame, (20, 70))
                if inference_config.show_skeleton:
                    mp_drawing.draw_landmarks(
                        frame,
                        detection_results.pose_landmarks,
                        connections=custom_pose_connections,
                        landmark_drawing_spec=custom_pose_style
                    )

                    mp_drawing.draw_landmarks(
                        frame,
                        detection_results.right_hand_landmarks,
                        connections=custom_hand_connections,
                        landmark_drawing_spec=custom_right_hand_style
                    )

                    mp_drawing.draw_landmarks(
                        frame,
                        detection_results.left_hand_landmarks,
                        connections=custom_hand_connections,
                        landmark_drawing_spec=custom_left_hand_style
                    )

                if writer:
                    writer.write(frame)
    finally:
        # Release the capture and the writer even when a frame fails.
        cap.release()
        if writer:
            writer.release()

    video_path = None
    if output_dir is not None:
        video_path = str(output_dir / "output.mp4")
        logging.info(f"Video is recorded and saved to {output_dir / 'output.mp4'}")
        pd.DataFrame(results).to_csv(output_dir / "results.csv", index=False)
        logging.info(f"Results saved to {output_dir / 'results.csv'}")

    # The original returned `output_path / "output.mp4"`, a name that is not
    # defined in this function (NameError on every call).
    return {
        "video_path": video_path,
        "results": results
    }
models/dsta_slr_joint_motion_v3_0.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ecfcb2b459fd68bfe838569d41bdb502f7cd21ddd675790146034cf0e6f71632
3
+ size 29678372
models/sl_gcn_joint_v3_0.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3ab4e3b86ec2a828c9e8f72f1f80ca131c0b7439539412fe15244dbcb64fb2a1
3
+ size 17046336
models/spoter_v3.0.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:38c21cd96446475cdc110f7748b11ad58b84cd055133379684f9f463dea8fcbd
3
+ size 24208453
request.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import requests

url = 'https://<your-hf-space-url>.hf.space/inference'  # real URL once deployed to HF
video_path = '/path/to/your_video.mp4'
params = {
    'model_name': 'spoter',
    'output_option': 'all',
    'output_dir': 'custom_output_folder'  # the caller may choose the output folder
}

# The endpoint declares model_name / output_dir as form fields, so they have to
# travel in the multipart body (`data=`); sent as query-string `params=` the
# required `model_name` form field would be missing from the request.
# The `with` block closes the video file once the upload has finished.
with open(video_path, 'rb') as video_file:
    files = {
        'file': video_file
    }
    response = requests.post(url=url, files=files, data=params)

print(response.json())
requirements.txt ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ transformers
2
+ pandas
3
+ evaluate
4
+ simple-parsing
5
+ torch
6
+ torchvision
7
+ hf-transfer
8
+ decord
9
+ accelerate
10
+ scikit-learn
11
+ wandb
12
+ pose-format
13
+ torchsummary
14
+ mediapipe
15
+ opencv-python
16
+ onnxruntime
17
+ onnx
18
+ imageio
19
+ tk
20
+ timm
21
+ einops
22
+ fastapi
23
+ uvicorn
24
+ pydantic
25
+ numpy
26
+ opencv-python
27
+ simple_parsing
tools/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .models import *
2
+ from .features import *
3
+ # from .utils import exists_on_hf
tools/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (203 Bytes). View file
 
tools/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (234 Bytes). View file
 
tools/__pycache__/features.cpython-39.pyc ADDED
Binary file (1.51 kB). View file
 
tools/__pycache__/models.cpython-312.pyc ADDED
Binary file (15.4 kB). View file
 
tools/__pycache__/models.cpython-39.pyc ADDED
Binary file (9.63 kB). View file
 
tools/features.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #tools/features.py
2
+
3
+ import torch
4
+ from configs import DataConfig
5
+ from features import BaseDataset, VSL98Dataset, VSL400Dataset
6
+
7
+
8
def load_dataset(data_config: DataConfig) -> BaseDataset:
    """
    Build the dataset selected by ``data_config.dataset``.

    Supported names are 'vsl_98' and 'vsl_400'; any other name raises KeyError.
    The chosen class is instantiated with ``data_config`` as its only argument.
    """
    registry = dict(vsl_98=VSL98Dataset, vsl_400=VSL400Dataset)
    dataset_cls = registry[data_config.dataset]
    return dataset_cls(data_config)
16
+
17
+
18
def rgb_collate_fn(examples) -> dict:
    """
    Batch RGB examples into ``{"pixel_values", "labels"}``.

    Each example's ``"video"`` tensor has its first two axes swapped before
    stacking (to frames-first order, assuming the input is channels-first -
    confirm against the dataset); ``"label"`` values become a 1-D tensor.
    """
    clips = []
    for sample in examples:
        clips.append(sample["video"].permute(1, 0, 2, 3))
    pixel_values = torch.stack(clips)

    targets = []
    for sample in examples:
        targets.append(sample["label"])
    labels = torch.tensor(targets)

    return {"pixel_values": pixel_values, "labels": labels}
25
+
26
+
27
def pose_collate_fn(examples) -> dict:
    """
    Batch pose examples into ``{"poses", "labels"}``.

    The ``"pose"`` tensors are stacked unchanged along a new leading axis and
    the ``"label"`` values are gathered into a 1-D tensor.
    """
    pose_tensors = []
    for sample in examples:
        pose_tensors.append(sample["pose"])
    poses = torch.stack(pose_tensors)

    targets = []
    for sample in examples:
        targets.append(sample["label"])
    labels = torch.tensor(targets)

    return {"poses": poses, "labels": labels}
tools/models.py ADDED
@@ -0,0 +1,443 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #tools/models.py
2
+
3
+ import torch
4
+ import logging
5
+ import onnxruntime as ort
6
+ from time import time
7
+ from typing import Union
8
+ from configs import ModelConfig, InferenceConfig
9
+ from utils import (
10
+ POSE_BASED_MODELS,
11
+ RGB_BASED_MODELS,
12
+ HUGGINGFACE_RGB_BASED_MODELS,
13
+ TORCHHUB_RGB_BASED_MODELS,
14
+ )
15
+ from transformers import (
16
+ ImageProcessingMixin,
17
+ FeatureExtractionMixin,
18
+ AutoModelForVideoClassification,
19
+ AutoModel,
20
+ Pipeline,
21
+ pipeline,
22
+ )
23
+ from transformers.pipelines import PIPELINE_REGISTRY
24
+ from visualization import draw_text_on_image
25
+ from utils import exists_on_hf
26
+ from models import (
27
+ Swin3DConfig, Swin3DImageProcessor, Swin3DForVideoClassification,
28
+ S3DConfig, S3DImageProcessor, S3DForVideoClassification,
29
+ VideoResNetConfig, VideoResNetImageProcessor, VideoResNetForVideoClassification,
30
+ MViTConfig, MViTImageProcessor, MViTForVideoClassification,
31
+ SLGCNConfig, SLGCNFeatureExtractor, SLGCNForGraphClassification,
32
+ SPOTERConfig, SPOTERFeatureExtractor, SPOTERForGraphClassification,
33
+ DSTASLRConfig, DSTASLRFeatureExtractor, DSTASLRForGraphClassification,
34
+ VideoMAEConfig, VideoMAEImageProcessor, VideoMAEForVideoClassification
35
+ )
36
+ from pipelines import (
37
+ VideoClassificationPipeline,
38
+ SLGCNGraphClassificationPipeline,
39
+ SPOTERGraphClassificationPipeline,
40
+ )
41
+
42
+
43
def load_model(
    model_config: ModelConfig,
    label2id: dict = None,
    id2label: dict = None,
    do_train: bool = False,
) -> tuple:
    '''
    Load the config, processor and model described by ``model_config``.

    Parameters
    ----------
    model_config : ModelConfig
        Provides ``arch`` (architecture name) and ``pretrained``
        (identifier passed to ``from_pretrained``).
    label2id : dict, optional
        Label-to-index mapping; only forwarded when ``do_train`` is True.
    id2label : dict, optional
        Index-to-label mapping; only forwarded when ``do_train`` is True.
    do_train : bool, optional
        When True, delegate to the pose or RGB training loader.
        When False (default), load pretrained weights and call ``eval()``.

    Returns
    -------
    tuple
        ``(config, processor, model)``.
    '''
    is_pose_arch = model_config.arch in POSE_BASED_MODELS

    if do_train:
        training_loader = (
            load_pose_model_for_training
            if is_pose_arch
            else load_rgb_model_for_training
        )
        return training_loader(model_config, label2id, id2label)

    # Pose architectures use a feature extractor and the generic AutoModel;
    # everything else uses an image processor and the video-classification head.
    if is_pose_arch:
        processor_cls, model_cls = FeatureExtractionMixin, AutoModel
    else:
        processor_cls, model_cls = ImageProcessingMixin, AutoModelForVideoClassification

    hub_options = {"trust_remote_code": True, "cache_dir": "models/huggingface"}
    processor = processor_cls.from_pretrained(model_config.pretrained, **hub_options)
    model = model_cls.from_pretrained(model_config.pretrained, **hub_options)
    model.eval()
    return model.config, processor, model
80
+
81
+
82
def load_rgb_model_for_training(
    model_config: ModelConfig,
    label2id: dict = None,
    id2label: dict = None,
) -> tuple:
    '''
    Build the config, image processor and model used to train an RGB model.

    ``videomae`` and the architectures in ``TORCHHUB_RGB_BASED_MODELS`` are
    instantiated from the local config/processor/model classes using the
    fields of ``model_config``. Any other architecture listed in
    ``HUGGINGFACE_RGB_BASED_MODELS`` is loaded with ``from_pretrained`` when
    ``exists_on_hf`` accepts ``model_config.pretrained``.

    Parameters
    ----------
    model_config : ModelConfig
        Provides ``arch`` and ``pretrained``; all of its attributes are
        expanded into the config class constructor.
    label2id : dict, optional
        Label-to-index mapping handed to the model.
    id2label : dict, optional
        Index-to-label mapping handed to the model.

    Returns
    -------
    tuple
        ``(config, processor, model)``.

    Raises
    ------
    SystemExit
        With exit code 1 when the architecture is not supported.
    '''
    arch = model_config.arch
    classes = None

    if arch in HUGGINGFACE_RGB_BASED_MODELS:
        if arch == "videomae":
            classes = (
                VideoMAEConfig,
                VideoMAEImageProcessor,
                VideoMAEForVideoClassification,
            )
        elif exists_on_hf(model_config.pretrained):
            processor = ImageProcessingMixin.from_pretrained(
                model_config.pretrained,
                trust_remote_code=True,
                cache_dir="models/huggingface",
            )
            # The label maps must be keyword arguments: extra positional
            # arguments of from_pretrained are forwarded to the model
            # constructor instead of overriding the config.
            model = AutoModelForVideoClassification.from_pretrained(
                model_config.pretrained,
                label2id=label2id,
                id2label=id2label,
                ignore_mismatched_sizes=True,
                trust_remote_code=True,
                cache_dir="models/huggingface",
            )
            return model.config, processor, model
    elif arch in TORCHHUB_RGB_BASED_MODELS:
        if arch in ['swin3d_t', 'swin3d_s', 'swin3d_b']:
            classes = (
                Swin3DConfig,
                Swin3DImageProcessor,
                Swin3DForVideoClassification,
            )
        elif arch in ['r3d_18', 'mc3_18', 'r2plus1d_18']:
            classes = (
                VideoResNetConfig,
                VideoResNetImageProcessor,
                VideoResNetForVideoClassification,
            )
        elif arch in ['s3d']:
            classes = (S3DConfig, S3DImageProcessor, S3DForVideoClassification)
        elif arch in ['mvit_v1_b', 'mvit_v2_s']:
            classes = (MViTConfig, MViTImageProcessor, MViTForVideoClassification)

    if classes is None:
        # Covers every path that selected no classes, so the code below can
        # never hit an unbound name.
        logging.error(f"Model {model_config.arch} is not supported")
        # SystemExit is raised directly: the `exit` builtin is only injected
        # by the `site` module and is meant for interactive use.
        raise SystemExit(1)

    config_class, processor_class, model_class = classes
    config_class.register_for_auto_class()
    processor_class.register_for_auto_class("AutoImageProcessor")
    model_class.register_for_auto_class("AutoModel")
    model_class.register_for_auto_class("AutoModelForVideoClassification")
    logging.info(f"{model_config.arch} classes registered")

    config = config_class(**vars(model_config))
    processor = processor_class(config=config)
    model = model_class(config=config, label2id=label2id, id2label=id2label)

    return config, processor, model
141
+
142
+
143
def load_pose_model_for_training(
    model_config: ModelConfig,
    label2id: dict = None,
    id2label: dict = None,
) -> tuple:
    '''
    Build the config, feature extractor and model used to train a pose model.

    When ``exists_on_hf`` accepts ``model_config.pretrained`` the processor
    and model are loaded with ``from_pretrained``. Otherwise the local
    classes for the architecture are registered for the auto classes and
    instantiated from the fields of ``model_config``.

    Parameters
    ----------
    model_config : ModelConfig
        Provides ``arch`` and ``pretrained``; all of its attributes are
        expanded into the config class constructor.
    label2id : dict, optional
        Label-to-index mapping handed to the model.
    id2label : dict, optional
        Index-to-label mapping handed to the model.

    Returns
    -------
    tuple
        ``(config, processor, model)``.

    Raises
    ------
    SystemExit
        With exit code 1 when the architecture is not supported.
    '''
    if exists_on_hf(model_config.pretrained):
        processor = FeatureExtractionMixin.from_pretrained(
            model_config.pretrained,
            trust_remote_code=True,
            cache_dir="models/huggingface",
        )
        model = AutoModel.from_pretrained(
            model_config.pretrained,
            label2id=label2id,
            id2label=id2label,
            ignore_mismatched_sizes=True,
            trust_remote_code=True,
            cache_dir="models/huggingface",
        )
        return model.config, processor, model

    pose_classes = {
        "spoter": (
            SPOTERConfig,
            SPOTERFeatureExtractor,
            SPOTERForGraphClassification,
        ),
        "sl_gcn": (
            SLGCNConfig,
            SLGCNFeatureExtractor,
            SLGCNForGraphClassification,
        ),
        "dsta_slr": (
            DSTASLRConfig,
            DSTASLRFeatureExtractor,
            DSTASLRForGraphClassification,
        ),
    }
    arch = model_config.arch
    if arch not in POSE_BASED_MODELS or arch not in pose_classes:
        # Every unsupported architecture ends here, so the class names used
        # below are always bound.
        logging.error(f"Model {model_config.arch} is not supported")
        # SystemExit is raised directly: the `exit` builtin is only injected
        # by the `site` module and is meant for interactive use.
        raise SystemExit(1)

    config_class, processor_class, model_class = pose_classes[arch]
    config_class.register_for_auto_class()
    processor_class.register_for_auto_class("AutoFeatureExtractor")
    model_class.register_for_auto_class("AutoModel")
    logging.info(F"Registering {model_config.arch} classes")

    config = config_class(**vars(model_config))
    processor = processor_class(config=config)
    model = model_class(config=config, label2id=label2id, id2label=id2label)

    return config, processor, model
192
+
193
+
194
class Predictions:
    '''
    Top-k predictions for one sample together with timing information.

    Parameters
    ----------
    predictions : list[dict], optional
        One dict per prediction with the keys ``'gloss'`` and ``'score'``.
    inference_time : float, optional
        Model runtime in seconds, by default 0.
    start_time : float, optional
        Start of the sample in seconds, by default 0.
    end_time : float, optional
        End of the sample in seconds, by default 0.
    '''

    def __init__(
        self,
        predictions: list[dict] = None,
        inference_time: float = 0,
        start_time: float = 0,
        end_time: float = 0,
    ) -> None:
        self.predictions = predictions
        self.inference_time = inference_time
        self.start_time = start_time
        self.end_time = end_time

    def _is_empty(self) -> bool:
        '''Return True when no timing and no prediction has been set.'''
        return not any((
            self.start_time,
            self.end_time,
            self.inference_time,
            self.predictions,
        ))

    def visualize(
        self,
        frame: torch.Tensor,
        position: tuple = (20, 100),
        prefix: str = "Predictions",
        color: tuple = (0, 0, 255),
    ):
        '''
        Draw ``"<prefix>: <predictions>"`` on ``frame``.

        Parameters
        ----------
        frame : torch.Tensor
            Image passed unchanged to ``draw_text_on_image``
            (NOTE(review): that helper is annotated as taking a numpy
            array - confirm what callers pass).
        position : tuple, optional
            Text position, by default (20, 100).
        prefix : str, optional
            Text placed before the prediction list.
        color : tuple, optional
            Text colour, by default (0, 0, 255).

        Returns
        -------
        The value returned by ``draw_text_on_image`` (the annotated image).
        '''
        text = prefix + ": " + self.get_pred_message()
        return draw_text_on_image(
            image=frame,
            text=text,
            position=position,
            color=color,
            font_size=20,
        )

    def get_pred_message(self) -> str:
        '''
        Format the predictions as ``"gloss (score%)"`` joined by ``", "``.

        Returns an empty string when the object is empty or when only the
        timings are set and there are no predictions.
        '''
        if self._is_empty():
            return ""

        # `predictions` may still be None when only a timing value is set.
        return ', '.join(
            f"{pred['gloss']} ({pred['score']*100:.2f}%)"
            for pred in self.predictions or []
        )

    def __str__(self) -> str:
        '''Return a one-line summary, or an empty string when empty.'''
        if self._is_empty():
            return ""

        predictions = self.get_pred_message()
        message = "Sample start: {:.2f}s - end: {:.2f}s | Runtime: {:.2f}s | Predictions: {}"
        return message.format(self.start_time, self.end_time, self.inference_time, predictions)

    def merge_results(self, results: dict = None) -> dict:
        '''
        Append this object's fields to a column-wise results dict.

        Parameters
        ----------
        results : dict, optional
            Accumulator with the list-valued keys ``"start_time"``,
            ``"end_time"``, ``"inference_time"`` and ``"prediction"``.
            A new one is created when omitted.

        Returns
        -------
        dict
            The same accumulator, mutated in place.
        '''
        if results is None:
            results = {
                "start_time": [],
                "end_time": [],
                "inference_time": [],
                "prediction": [],
            }
        results["start_time"].append(self.start_time)
        results["end_time"].append(self.end_time)
        results["inference_time"].append(self.inference_time)
        results["prediction"].append(self.predictions)
        return results
265
+
266
+
267
def get_predictions(
    inputs: torch.Tensor,
    model: Union[ort.InferenceSession, AutoModel],
    id2gloss: dict,
    k: int = 3,
) -> Predictions:
    '''
    Get the top-k predictions.

    Parameters
    ----------
    inputs : torch.Tensor
        Model inputs. ``None`` yields an empty ``Predictions``.
    model : Union[ort.InferenceSession, AutoModel]
        Model to get predictions from. An ONNX session is fed the inputs
        under the name ``"pixel_values"``; any other model is called
        directly and must return an object with a ``logits`` attribute.
    id2gloss : dict
        Mapping of class indices (as strings) to glosses.
    k : int, optional
        Number of predictions to return, by default 3. Clamped to the
        number of classes.

    Returns
    -------
    Predictions
        Top-k predictions of the first batch item and the inference time.
        The scores are a softmax over the k selected logits only, so they
        sum to 1 across the returned predictions.
    '''
    if inputs is None:
        return Predictions()

    # Get logits
    start_time = time()
    if isinstance(model, ort.InferenceSession):
        inputs = inputs.cpu().numpy()
        logits = torch.from_numpy(model.run(None, {"pixel_values": inputs})[0])
    else:
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            logits = model(inputs.to(model.device)).logits
    inference_time = time() - start_time

    # Move to CPU first so the conversion below also works for GPU models.
    logits = logits.detach().cpu()
    k = min(k, logits.shape[1])

    # Get top-k predictions. Index the first batch item instead of
    # squeeze(), which collapses to a scalar when k == 1.
    topk_scores, topk_indices = torch.topk(logits, k, dim=1)
    topk_scores = torch.nn.functional.softmax(topk_scores, dim=1)[0].tolist()
    topk_indices = topk_indices[0].tolist()
    predictions = [
        {
            'gloss': id2gloss[str(index)],
            'score': score,
        }
        for index, score in zip(topk_indices, topk_scores)
    ]

    return Predictions(predictions=predictions, inference_time=inference_time)
315
+
316
+
317
def register_pipeline(model_config: ModelConfig) -> Pipeline:
    '''
    Register the pipeline class for ``model_config.arch`` and instantiate it.

    The model and processor come from ``load_model(model_config)``. The
    pipeline class is registered under the ``"video-classification"`` task
    before an instance is returned.

    Parameters
    ----------
    model_config : ModelConfig
        Provides ``arch``, which selects the pipeline class.

    Returns
    -------
    Pipeline
        SPOTER pipeline for ``"spoter"``, SL-GCN pipeline for ``"sl_gcn"``
        and ``"dsta_slr"``, video classification pipeline otherwise.
    '''
    _, processor, model = load_model(model_config)
    arch = model_config.arch

    # Per-architecture choices: pipeline class, auto-model class, default
    # checkpoint, and the keyword under which the processor is passed.
    if arch == "spoter":
        pipeline_cls, auto_cls = SPOTERGraphClassificationPipeline, AutoModel
        default_repo = "vsltranslation/spoter_v3.0"
        processor_kwarg = "feature_extractor"
    elif arch in ["sl_gcn", "dsta_slr"]:
        pipeline_cls, auto_cls = SLGCNGraphClassificationPipeline, AutoModel
        default_repo = "vsltranslation/sl_gcn_joint_v1.0"
        processor_kwarg = "feature_extractor"
    else:
        pipeline_cls = VideoClassificationPipeline
        auto_cls = AutoModelForVideoClassification
        default_repo = "vsltranslation/swin3d_t_v1.0"
        processor_kwarg = "image_processor"

    PIPELINE_REGISTRY.register_pipeline(
        "video-classification",
        pipeline_class=pipeline_cls,
        pt_model=auto_cls,
        default={"pt": (default_repo, "main")},
        type="multimodal",
    )
    return pipeline_cls(model=model, **{processor_kwarg: processor})
358
+
359
+
360
def load_pipeline(
    model_config: ModelConfig,
    inference_config: InferenceConfig,
) -> Pipeline:
    '''
    Create a ``"video-classification"`` pipeline via ``transformers.pipeline``.

    Parameters
    ----------
    model_config : ModelConfig
        Provides ``arch`` and ``pretrained``; ``pretrained`` is used for
        both the model and its processor.
    inference_config : InferenceConfig
        Provides ``device``, ``cache_dir``, ``use_onnx`` and ``top_k``, plus
        ``bone_stream`` and ``motion_stream`` for pose architectures.

    Returns
    -------
    Pipeline
        The pipeline built from these options.
    '''
    is_pose_arch = model_config.arch in POSE_BASED_MODELS

    options = dict(
        model=model_config.pretrained,
        device=inference_config.device,
        model_kwargs={"cache_dir": inference_config.cache_dir},
        trust_remote_code=True,
        use_onnx=inference_config.use_onnx,
        top_k=inference_config.top_k,
    )
    if is_pose_arch:
        # Pose models take a feature extractor and the two stream switches.
        options["feature_extractor"] = model_config.pretrained
        options["bone_stream"] = inference_config.bone_stream
        options["motion_stream"] = inference_config.motion_stream
    else:
        options["image_processor"] = model_config.pretrained

    return pipeline("video-classification", **options)
394
+
395
+
396
def get_input_shape(
    arch: str,
    processor: Union[ImageProcessingMixin, FeatureExtractionMixin],
    batch_size: int = 1,
) -> tuple:
    '''
    Get the input shape for the model.

    Parameters
    ----------
    arch : str
        Architecture name; selects which processor attributes are read.
    processor : Union[ImageProcessingMixin, FeatureExtractionMixin]
        Model processor.
    batch_size : int, optional
        Batch size, by default 1.

    Returns
    -------
    tuple
        Input shape:
        RGB models: ``(batch, num_frames, 3, height, width)``;
        ``spoter``: ``(batch, num_frames, num_points, in_channels)``;
        ``sl_gcn`` / ``dsta_slr``:
        ``(batch, in_channels, window_size, num_points, num_people)``.

    Raises
    ------
    SystemExit
        With exit code 1 when the architecture is not supported.
    '''
    if arch in RGB_BASED_MODELS:
        return (
            batch_size,
            processor.num_frames,
            3,
            processor.size["height"],
            processor.size["width"]
        )

    if arch in POSE_BASED_MODELS:
        if arch == "spoter":
            return (
                batch_size,
                processor.num_frames,
                processor.num_points,
                processor.in_channels,
            )
        if arch in ["sl_gcn", "dsta_slr"]:
            return (
                batch_size,
                processor.in_channels,
                processor.window_size,
                processor.num_points,
                processor.num_people,
            )

    logging.error(f"Model {arch} is not supported")
    # SystemExit is raised directly: the `exit` builtin is only injected by
    # the `site` module and is meant for interactive use.
    raise SystemExit(1)
utils/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .loggers import *
2
+ from .constants import *
utils/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (205 Bytes). View file
 
utils/__pycache__/constants.cpython-312.pyc ADDED
Binary file (4.35 kB). View file
 
utils/__pycache__/loggers.cpython-312.pyc ADDED
Binary file (1.59 kB). View file
 
utils/constants.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #utils/constants.py
2
+
3
+ import numpy as np
4
+
5
+
6
+ VIDEO_EXTENSIONS = (".mp4", ".avi", ".mov", ".mkv")
7
+
8
+ TORCHHUB_RGB_BASED_MODELS = (
9
+ 'swin3d_t',
10
+ 'swin3d_s',
11
+ 'swin3d_b',
12
+ "r3d_18",
13
+ "mc3_18",
14
+ "r2plus1d_18",
15
+ "s3d",
16
+ "mvit_v1_b",
17
+ "mvit_v2_s",
18
+ )
19
+ HUGGINGFACE_RGB_BASED_MODELS = (
20
+ "videomae",
21
+ )
22
+ RGB_BASED_MODELS = HUGGINGFACE_RGB_BASED_MODELS + TORCHHUB_RGB_BASED_MODELS
23
+
24
+ POSE_BASED_MODELS = (
25
+ "spoter",
26
+ "sl_gcn",
27
+ "dsta_slr"
28
+ )
29
+
30
+ MODELS = RGB_BASED_MODELS + POSE_BASED_MODELS
31
+
32
+ HAND_LANDMARKS = [
33
+ "wrist",
34
+ "indexTip",
35
+ "indexDIP",
36
+ "indexPIP",
37
+ "indexMCP",
38
+ "middleTip",
39
+ "middleDIP",
40
+ "middlePIP",
41
+ "middleMCP",
42
+ "ringTip",
43
+ "ringDIP",
44
+ "ringPIP",
45
+ "ringMCP",
46
+ "littleTip",
47
+ "littleDIP",
48
+ "littlePIP",
49
+ "littleMCP",
50
+ "thumbTip",
51
+ "thumbIP",
52
+ "thumbMP",
53
+ "thumbCMC",
54
+ ]
55
+ BODY_LANDMARKS = [
56
+ "nose",
57
+ "neck",
58
+ "rightEye",
59
+ "leftEye",
60
+ "rightEar",
61
+ "leftEar",
62
+ "rightShoulder",
63
+ "leftShoulder",
64
+ "rightElbow",
65
+ "leftElbow",
66
+ "rightWrist",
67
+ "leftWrist",
68
+ ]
69
+ ARM_LANDMARKS_ORDER = ["neck", "$side$Shoulder", "$side$Elbow", "$side$Wrist"]
70
+
71
+ FLIP_IDXS = np.concatenate(
72
+ (
73
+ [0, 2, 1, 4, 3, 6, 5],
74
+ [17, 18, 19, 20, 21, 22, 23, 24, 25, 26],
75
+ [7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
76
+ ),
77
+ axis=0,
78
+ )
79
+
80
+ SLGCN_JOINTS = {
81
+ 59: np.concatenate((np.arange(0, 17), np.arange(91, 133)), axis=0), # 59
82
+ 31: np.concatenate(
83
+ (
84
+ np.arange(0, 11),
85
+ [91, 95, 96, 99, 100, 103, 104, 107, 108, 111],
86
+ [112, 116, 117, 120, 121, 124, 125, 128, 129, 132],
87
+ ),
88
+ axis=0,
89
+ ), # 31
90
+ 27: np.concatenate(
91
+ (
92
+ [0, 5, 6, 7, 8, 9, 10],
93
+ [91, 95, 96, 99, 100, 103, 104, 107, 108, 111],
94
+ [112, 116, 117, 120, 121, 124, 125, 128, 129, 132],
95
+ ),
96
+ axis=0,
97
+ ), # 27
98
+ }
99
+
100
+ COCO_TO_POSE_FORMAT = {
101
+ 0: ("POSE_LANDMARKS", "NOSE"),
102
+ 1: ("POSE_LANDMARKS", "LEFT_EYE"),
103
+ 2: ("POSE_LANDMARKS", "RIGHT_EYE"),
104
+ 3: ("POSE_LANDMARKS", "LEFT_EAR"),
105
+ 4: ("POSE_LANDMARKS", "RIGHT_EAR"),
106
+ 5: ("POSE_LANDMARKS", "LEFT_SHOULDER"),
107
+ 6: ("POSE_LANDMARKS", "RIGHT_SHOULDER"),
108
+ 7: ("POSE_LANDMARKS", "LEFT_ELBOW"),
109
+ 8: ("POSE_LANDMARKS", "RIGHT_ELBOW"),
110
+ 9: ("POSE_LANDMARKS", "LEFT_WRIST"),
111
+ 10: ("POSE_LANDMARKS", "RIGHT_WRIST"),
112
+ 11: ("POSE_LANDMARKS", "LEFT_HIP"),
113
+ 12: ("POSE_LANDMARKS", "RIGHT_HIP"),
114
+ 13: ("POSE_LANDMARKS", "LEFT_KNEE"),
115
+ 14: ("POSE_LANDMARKS", "RIGHT_KNEE"),
116
+ 15: ("POSE_LANDMARKS", "LEFT_ANKLE"),
117
+ 16: ("POSE_LANDMARKS", "RIGHT_ANKLE"),
118
+ 91: ("LEFT_HAND_LANDMARKS", "WRIST"),
119
+ 92: ("LEFT_HAND_LANDMARKS", "THUMB_CMC"),
120
+ 93: ("LEFT_HAND_LANDMARKS", "THUMB_MCP"),
121
+ 94: ("LEFT_HAND_LANDMARKS", "THUMB_IP"),
122
+ 95: ("LEFT_HAND_LANDMARKS", "THUMB_TIP"),
123
+ 96: ("LEFT_HAND_LANDMARKS", "INDEX_FINGER_MCP"),
124
+ 97: ("LEFT_HAND_LANDMARKS", "INDEX_FINGER_PIP"),
125
+ 98: ("LEFT_HAND_LANDMARKS", "INDEX_FINGER_DIP"),
126
+ 99: ("LEFT_HAND_LANDMARKS", "INDEX_FINGER_TIP"),
127
+ 100: ("LEFT_HAND_LANDMARKS", "MIDDLE_FINGER_MCP"),
128
+ 101: ("LEFT_HAND_LANDMARKS", "MIDDLE_FINGER_PIP"),
129
+ 102: ("LEFT_HAND_LANDMARKS", "MIDDLE_FINGER_DIP"),
130
+ 103: ("LEFT_HAND_LANDMARKS", "MIDDLE_FINGER_TIP"),
131
+ 104: ("LEFT_HAND_LANDMARKS", "RING_FINGER_MCP"),
132
+ 105: ("LEFT_HAND_LANDMARKS", "RING_FINGER_PIP"),
133
+ 106: ("LEFT_HAND_LANDMARKS", "RING_FINGER_DIP"),
134
+ 107: ("LEFT_HAND_LANDMARKS", "RING_FINGER_TIP"),
135
+ 108: ("LEFT_HAND_LANDMARKS", "PINKY_MCP"),
136
+ 109: ("LEFT_HAND_LANDMARKS", "PINKY_PIP"),
137
+ 110: ("LEFT_HAND_LANDMARKS", "PINKY_DIP"),
138
+ 111: ("LEFT_HAND_LANDMARKS", "PINKY_TIP"),
139
+ 112: ("RIGHT_HAND_LANDMARKS", "WRIST"),
140
+ 113: ("RIGHT_HAND_LANDMARKS", "THUMB_CMC"),
141
+ 114: ("RIGHT_HAND_LANDMARKS", "THUMB_MCP"),
142
+ 115: ("RIGHT_HAND_LANDMARKS", "THUMB_IP"),
143
+ 116: ("RIGHT_HAND_LANDMARKS", "THUMB_TIP"),
144
+ 117: ("RIGHT_HAND_LANDMARKS", "INDEX_FINGER_MCP"),
145
+ 118: ("RIGHT_HAND_LANDMARKS", "INDEX_FINGER_PIP"),
146
+ 119: ("RIGHT_HAND_LANDMARKS", "INDEX_FINGER_DIP"),
147
+ 120: ("RIGHT_HAND_LANDMARKS", "INDEX_FINGER_TIP"),
148
+ 121: ("RIGHT_HAND_LANDMARKS", "MIDDLE_FINGER_MCP"),
149
+ 122: ("RIGHT_HAND_LANDMARKS", "MIDDLE_FINGER_PIP"),
150
+ 123: ("RIGHT_HAND_LANDMARKS", "MIDDLE_FINGER_DIP"),
151
+ 124: ("RIGHT_HAND_LANDMARKS", "MIDDLE_FINGER_TIP"),
152
+ 125: ("RIGHT_HAND_LANDMARKS", "RING_FINGER_MCP"),
153
+ 126: ("RIGHT_HAND_LANDMARKS", "RING_FINGER_PIP"),
154
+ 127: ("RIGHT_HAND_LANDMARKS", "RING_FINGER_DIP"),
155
+ 128: ("RIGHT_HAND_LANDMARKS", "RING_FINGER_TIP"),
156
+ 129: ("RIGHT_HAND_LANDMARKS", "PINKY_MCP"),
157
+ 130: ("RIGHT_HAND_LANDMARKS", "PINKY_PIP"),
158
+ 131: ("RIGHT_HAND_LANDMARKS", "PINKY_DIP"),
159
+ 132: ("RIGHT_HAND_LANDMARKS", "PINKY_TIP"),
160
+ }
utils/loggers.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #utils/loggers.py
2
+
3
+ import sys
4
+ import logging
5
+ from pathlib import Path
6
+ from transformers import TrainerCallback
7
+
8
+
9
class TrainingCallback(TrainerCallback):
    '''Trainer callback that sends every log payload to the root logger.'''

    def on_log(self, args, state, control, logs=None, **kwargs):
        '''Emit ``logs`` at INFO level; the other arguments are unused.'''
        logging.log(logging.INFO, logs)
12
+
13
+
14
def config_logger(log_file: str = None) -> None:
    '''
    Configure root logging at INFO level.

    Records always go to stdout. When ``log_file`` is given they are also
    written to that file, and its parent directory is created if missing.

    Parameters
    ----------
    log_file : str, optional
        Path of the log file, by default None (stdout only).
    '''
    sinks = [logging.StreamHandler(sys.stdout)]

    if log_file is not None:
        parent_dir = Path(log_file).parent
        if not parent_dir.exists():
            parent_dir.mkdir(parents=True, exist_ok=True)
        sinks.append(logging.FileHandler(filename=log_file))

    settings = {
        "datefmt": "%m/%d/%Y %H:%M:%S",
        "level": logging.INFO,
        "format": "[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s",
        "handlers": sinks,
    }
    logging.basicConfig(**settings)
visualization/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .utils import *
visualization/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (187 Bytes). View file
 
visualization/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (221 Bytes). View file
 
visualization/__pycache__/utils.cpython-312.pyc ADDED
Binary file (2.44 kB). View file
 
visualization/__pycache__/utils.cpython-39.pyc ADDED
Binary file (1.7 kB). View file
 
visualization/utils.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #visualization/utils.py
2
+
3
+ import torch
4
+ import numpy as np
5
+ from imageio import mimsave
6
+ from PIL import Image, ImageDraw, ImageFont
7
+
8
+
9
def unnormalize_img(image: np.ndarray, std: tuple, mean: tuple) -> np.ndarray:
    '''
    Undo ``(x - mean) / std`` normalisation and convert to 8-bit pixels.

    Parameters
    ----------
    image : np.ndarray
        Normalised image; ``std`` and ``mean`` must broadcast against it.
    std : tuple
        Standard deviation used for normalisation.
    mean : tuple
        Mean used for normalisation.

    Returns
    -------
    np.ndarray
        ``uint8`` array with values saturated to ``[0, 255]``.
    '''
    image = (image * std) + mean
    # Clip BEFORE the cast: once the data is uint8, out-of-range values have
    # already wrapped around and clipping can no longer repair them.
    return (image * 255).clip(0, 255).astype('uint8')
13
+
14
+
15
def save_as_gif(
    video_tensor: torch.Tensor,
    save_path: str = 'sample.gif',
    std: tuple = None,
    mean: tuple = None,
):
    '''
    Write the frames of ``video_tensor`` to a GIF file.

    Parameters
    ----------
    video_tensor : torch.Tensor
        Iterated frame by frame; the axes of each frame are reordered with
        ``permute(1, 2, 0)`` before conversion (presumably
        (channels, H, W) -> (H, W, channels) - confirm).
    save_path : str, optional
        Output path, by default ``'sample.gif'``.
    std : tuple, optional
        Normalisation standard deviation. ``None`` (default) means the
        frames were not scaled.
    mean : tuple, optional
        Normalisation mean. ``None`` (default) means the frames were not
        shifted.

    Returns
    -------
    str
        ``save_path``.
    '''
    # The declared defaults are None, which cannot be multiplied with an
    # array; fall back to the identity transform instead of crashing.
    std = 1.0 if std is None else std
    mean = 0.0 if mean is None else mean

    frames = [
        unnormalize_img(
            # detach()/cpu() so GPU or grad-tracking tensors convert too.
            image=video_frame.detach().cpu().permute(1, 2, 0).numpy(),
            std=std,
            mean=mean,
        )
        for video_frame in video_tensor
    ]
    mimsave(save_path, frames, 'GIF', duration=0.25)
    return save_path
32
+
33
+
34
def display_gif(gif_path: str) -> Image.Image:
    '''
    Open the GIF at ``gif_path`` as a PIL image.

    ``Image`` in this module is the ``PIL.Image`` module, which is not
    callable, so the file is opened with ``Image.open``.

    Parameters
    ----------
    gif_path : str
        Path of the GIF file.

    Returns
    -------
    PIL.Image.Image
        The opened image.
    '''
    return Image.open(gif_path)
36
+
37
+
38
def draw_text_on_image(
    image: np.ndarray,
    text: str,
    position: tuple = (20, 20),
    color: tuple = (0, 0, 255),
    font_size: int = 20,
) -> np.ndarray:
    '''
    Render ``text`` onto a copy of ``image``.

    Parameters
    ----------
    image : np.ndarray
        Source image, converted with ``Image.fromarray``.
    text : str
        Text to draw.
    position : tuple, optional
        ``xy`` anchor of the text, by default (20, 20).
    color : tuple, optional
        Fill colour, by default (0, 0, 255).
    font_size : int, optional
        Font size, by default 20.

    Returns
    -------
    np.ndarray
        New array containing the image with the text drawn on it.
    '''
    # The font path is relative to the current working directory.
    typeface = ImageFont.truetype(font="fonts/OpenSans-Regular.ttf", size=font_size)
    canvas = Image.fromarray(image)
    ImageDraw.Draw(canvas).text(xy=position, text=text, fill=color, font=typeface)
    return np.array(canvas)