File size: 12,368 Bytes
9f83ce9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e7a1f4a
9f83ce9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e7a1f4a
9f83ce9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
import logging
from time import time
import pandas as pd
import numpy as np
import cv2
from typing import Optional
from pathlib import Path
from fastapi import FastAPI, HTTPException, UploadFile, File, Query
from fastapi.responses import JSONResponse
import mediapipe as mp

from configs import ModelConfig, InferenceConfig
from utils import config_logger, POSE_BASED_MODELS
from data import Arm, get_sample_timestamp, ok_to_get_frame
from tools import load_pipeline, Predictions
from visualization import draw_text_on_image

# FastAPI application object; the /healthcheck and /inference routes below are
# registered on it through the @app decorators.
app = FastAPI()

def _make_preset(arch, pretrained, **stream_options):
    """Build one ``{"model": ..., "inference": ...}`` preset for uploaded videos.

    Every preset shares the same upload / ONNX / visualisation settings; only
    the architecture, the weights file and the optional stream flags differ.
    """
    return {
        "model": ModelConfig(arch=arch, pretrained=pretrained),
        "inference": InferenceConfig(
            source="upload",  # input comes from the upload endpoint, not a webcam
            output_dir="demo/run_1",
            use_onnx=True,
            show_skeleton=True,
            visualize=True,
            **stream_options,
        ),
    }


# The three model presets selectable through the ``model_name`` query parameter.
MODEL_PRESETS = {
    "dsta_slr": _make_preset(
        "dsta_slr",
        "models/dsta_slr_joint_motion_v3_0.onnx",
        bone_stream=False,
        motion_stream=True,
    ),
    # NOTE(review): sl_gcn loads the same weights file as dsta_slr
    # (dsta_slr_joint_motion_v3_0.onnx). This looks like a copy-paste slip;
    # confirm the intended sl_gcn model path.
    "sl_gcn": _make_preset(
        "sl_gcn",
        "models/dsta_slr_joint_motion_v3_0.onnx",
        bone_stream=True,
        motion_stream=False,
    ),
    "spoter": _make_preset("spoter", "models/spoter_v3.0.onnx"),
}

# Module import side effect: set up logging (presumably to the "inference.log"
# file; see utils.config_logger), then log that the API module was loaded.
config_logger("inference.log")
logging.info("API started")

_POSE = mp.solutions.pose.PoseLandmark
_HAND = mp.solutions.hands.HandLandmark

# Upper-body pose landmarks drawn in the skeleton overlay (order preserved).
SPOTER_POSE_LANDMARKS = [
    _POSE[name]
    for name in (
        "NOSE",
        "LEFT_EYE",
        "RIGHT_EYE",
        "RIGHT_SHOULDER",
        "LEFT_SHOULDER",
        "RIGHT_ELBOW",
        "LEFT_ELBOW",
        "RIGHT_WRIST",
        "LEFT_WRIST",
    )
]

# Hand landmarks drawn in the skeleton overlay: the wrist, then the four
# fingers from tip to knuckle, then the thumb from tip to base.
SPOTER_HAND_LANDMARKS = (
    [_HAND.WRIST]
    + [
        _HAND[f"{finger}_{joint}"]
        for finger in ("INDEX_FINGER", "MIDDLE_FINGER", "RING_FINGER", "PINKY")
        for joint in ("TIP", "DIP", "PIP", "MCP")
    ]
    + [_HAND.THUMB_TIP, _HAND.THUMB_IP, _HAND.THUMB_MCP, _HAND.THUMB_CMC]
)


@app.get("/healthcheck")
async def healthcheck():
    """Liveness probe: always answer HTTP 200 with ``{"status": "UP"}``."""
    payload = {"status": "UP"}
    return JSONResponse(content=payload, status_code=200)


# Assumed spacing between consecutive input frames in milliseconds (~30 fps).
# The uploaded video's real frame rate is not inspected, so the arm-timing
# logic and the output video writer both assume this rate (demo logic).
_FRAME_INTERVAL_MS = 33
_OUTPUT_FPS = 30


def _build_skeleton_styles():
    """Build drawing specs and connection lists for the skeleton overlay.

    Landmarks listed in ``SPOTER_POSE_LANDMARKS`` / ``SPOTER_HAND_LANDMARKS``
    are drawn visibly: pose in (0, 255, 0), right hand in (0, 0, 255), left
    hand in (255, 0, 0). Every other landmark gets a zero-size spec, and every
    connection touching it is dropped.

    Returns:
        ``(pose_style, right_hand_style, left_hand_style, pose_connections,
        hand_connections)`` for ``drawing_utils.draw_landmarks``.
    """
    drawing = mp.solutions.drawing_utils
    styles = mp.solutions.drawing_styles
    holistic = mp.solutions.holistic

    def hidden_spec():
        return drawing.DrawingSpec(color=(0, 0, 0), thickness=0, circle_radius=0)

    pose_style = styles.get_default_pose_landmarks_style()
    visible_pose = set(SPOTER_POSE_LANDMARKS)
    hidden_pose = set()
    for landmark in mp.solutions.pose.PoseLandmark:
        if landmark in visible_pose:
            pose_style[landmark] = drawing.DrawingSpec(color=(0, 255, 0), thickness=2, circle_radius=2)
        else:
            pose_style[landmark] = hidden_spec()
            hidden_pose.add(landmark.value)

    right_hand_style = styles.get_default_hand_landmarks_style()
    left_hand_style = styles.get_default_hand_landmarks_style()
    visible_hand = set(SPOTER_HAND_LANDMARKS)
    hidden_hand = set()
    for landmark in mp.solutions.hands.HandLandmark:
        if landmark in visible_hand:
            right_hand_style[landmark] = drawing.DrawingSpec(color=(0, 0, 255), thickness=2, circle_radius=2)
            left_hand_style[landmark] = drawing.DrawingSpec(color=(255, 0, 0), thickness=2, circle_radius=2)
        else:
            right_hand_style[landmark] = hidden_spec()
            left_hand_style[landmark] = hidden_spec()
            hidden_hand.add(landmark.value)

    # Filter into new lists. Calling list.remove() while iterating the same
    # list skips the element after every removal.
    pose_connections = [c for c in holistic.POSE_CONNECTIONS if not hidden_pose.intersection(c)]
    hand_connections = [c for c in holistic.HAND_CONNECTIONS if not hidden_hand.intersection(c)]
    return pose_style, right_hand_style, left_hand_style, pose_connections, hand_connections


def _draw_skeleton(frame, detection_results, skeleton):
    """Draw the pose and both hand skeletons from *detection_results* onto *frame*."""
    pose_style, right_hand_style, left_hand_style, pose_connections, hand_connections = skeleton
    draw_landmarks = mp.solutions.drawing_utils.draw_landmarks
    for landmark_list, connections, style in (
        (detection_results.pose_landmarks, pose_connections, pose_style),
        (detection_results.right_hand_landmarks, hand_connections, right_hand_style),
        (detection_results.left_hand_landmarks, hand_connections, left_hand_style),
    ):
        draw_landmarks(frame, landmark_list, connections=connections, landmark_drawing_spec=style)


def _arm_ready(arm, inference_config, current_time_ms):
    """Call ``ok_to_get_frame`` for *arm* with the thresholds from *inference_config*."""
    return ok_to_get_frame(
        arm=arm,
        angle_threshold=inference_config.angle_threshold,
        min_num_up_frames=inference_config.min_num_up_frames,
        min_num_down_frames=inference_config.min_num_down_frames,
        current_time=current_time_ms,
        delay=inference_config.delay,
    )


def run_inference(model_config, inference_config, input_frames):
    """Detect signing segments in *input_frames* and classify each one.

    Each frame goes through MediaPipe Holistic. Arm angles from the pose
    landmarks decide, via ``ok_to_get_frame`` / ``get_sample_timestamp``, when
    a segment starts and ends. The collected segment is passed to the pipeline
    from ``load_pipeline``: raw frames, or Holistic results when
    ``use_pose_model`` is set.

    When ``inference_config.output_dir`` is set, an annotated ``output.mp4``
    (one output frame per input frame) is written there. A ``results.csv`` is
    also written if at least one segment was classified.

    Args:
        model_config: model preset passed to ``load_pipeline``.
        inference_config: inference settings (arm thresholds, ``output_dir``,
            ``show_skeleton``, ``use_pose_model``, ...).
        input_frames: sequence of BGR frames (e.g. as read by ``cv2.VideoCapture``).

    Returns:
        ``(predictions, results)``. ``predictions`` is the ``predictions``
        attribute of the latest ``Predictions`` object; it may be ``None``,
        presumably for an empty ``Predictions()``. ``results`` holds the
        merged results of all segments, or ``None`` if no segment completed.
    """
    pipeline = load_pipeline(model_config, inference_config)
    logging.info("Pipeline loaded")

    right_arm = Arm("right", inference_config.visibility)
    left_arm = Arm("left", inference_config.visibility)
    data = []
    results = None
    predictions = Predictions()

    skeleton = _build_skeleton_styles() if inference_config.show_skeleton else None

    writer = None
    out_path = None
    if inference_config.output_dir is not None:
        out_path = Path(inference_config.output_dir)
        out_path.mkdir(parents=True, exist_ok=True)
        if len(input_frames) > 0 and isinstance(input_frames[0], np.ndarray):
            height, width, _ = input_frames[0].shape
            writer = cv2.VideoWriter(
                str(out_path / "output.mp4"),
                cv2.VideoWriter_fourcc(*"mp4v"),
                _OUTPUT_FPS,
                (width, height),
            )

    try:
        with mp.solutions.holistic.Holistic(
            min_detection_confidence=0.9, min_tracking_confidence=0.5
        ) as holistic:
            for frame_index, frame in enumerate(input_frames):
                current_time_ms = frame_index * _FRAME_INTERVAL_MS

                rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                # MediaPipe examples mark the input read-only so it can be
                # passed by reference.
                rgb_frame.flags.writeable = False
                detection_results = holistic.process(rgb_frame)

                if detection_results.pose_landmarks is None:
                    # No body detected, so there is nothing to analyse. Still emit
                    # the frame so output.mp4 stays aligned with the input timeline.
                    if writer is not None:
                        writer.write(frame)
                    continue

                landmarks = detection_results.pose_landmarks.landmark
                left_arm.set_pose(landmarks)
                right_arm.set_pose(landmarks)

                # Evaluate both arms on every frame (no short-circuit):
                # ok_to_get_frame receives the Arm object and may update it.
                left_ready = _arm_ready(left_arm, inference_config, current_time_ms)
                right_ready = _arm_ready(right_arm, inference_config, current_time_ms)
                if left_ready or right_ready:
                    # A segment is being recorded: clear the on-screen prediction.
                    predictions = Predictions()
                    data.append(detection_results if inference_config.use_pose_model else frame)

                start_ms, end_ms = get_sample_timestamp(left_arm, right_arm)
                if start_ms != 0 and end_ms != 0:
                    started_at = time()
                    predictions = Predictions(predictions=pipeline(np.array(data)))
                    predictions.inference_time = time() - started_at
                    predictions.start_time = start_ms / 1000
                    predictions.end_time = end_ms / 1000
                    logging.info(str(predictions))
                    results = predictions.merge_results(results)

                    # Segment consumed: reset arm state and start collecting anew.
                    left_arm.reset_state()
                    right_arm.reset_state()
                    data = []

                # Draw arm angles and the current prediction onto the frame.
                frame = left_arm.visualize(frame, (20, 10), "Left arm angle")
                frame = right_arm.visualize(frame, (20, 40), "Right arm angle")
                frame = predictions.visualize(frame, (20, 70))

                if skeleton is not None:
                    _draw_skeleton(frame, detection_results, skeleton)

                if writer is not None:
                    writer.write(frame)
    finally:
        # Release the writer even if detection or inference raises mid-video.
        if writer is not None:
            writer.release()

    if results is not None and out_path is not None:
        pd.DataFrame(results).to_csv(out_path / "results.csv", index=False)

    return predictions.predictions, results


@app.post("/inference")
async def inference_endpoint(
    model_name: str = Query(..., description="Choose model: dsta_slr, sl_gcn, spoter"),
    output_option: str = Query("all", description="Output option: 'predictions', 'csv', 'video', 'all'"),
    output_dir: str = Query("demo/run_1", description="Output directory for results"),
    file: UploadFile = File(...)
):
    """Run a model preset on an uploaded video.

    - model_name: model preset to use: dsta_slr, sl_gcn or spoter
    - output_option: 'predictions', 'csv', 'video' or 'all'
    - output_dir: relative output directory, e.g. 'my_results'
    - file: one uploaded video file

    Returns:
        A dict with ``predictions`` and/or ``csv_path`` / ``video_path``,
        depending on ``output_option``. A path is ``None`` when this request
        did not produce that file.

    Raises:
        HTTPException(400): unknown ``model_name`` or ``output_option``; an
            absolute or ``..``-containing ``output_dir``; or an upload from
            which no video frame could be decoded.
    """
    import copy
    import tempfile

    if model_name not in MODEL_PRESETS:
        raise HTTPException(status_code=400, detail="Invalid model_name")
    if output_option not in ("predictions", "csv", "video", "all"):
        raise HTTPException(status_code=400, detail="Invalid output_option")
    # output_dir comes straight from the client. Keep writes relative to the
    # server's working directory instead of anywhere on the filesystem.
    requested_dir = Path(output_dir)
    if requested_dir.is_absolute() or ".." in requested_dir.parts:
        raise HTTPException(status_code=400, detail="Invalid output_dir")

    # Decode the upload through a private temporary file. A fixed shared path
    # would be clobbered by concurrent requests and would never be cleaned up.
    video_bytes = await file.read()
    input_frames = []
    with tempfile.TemporaryDirectory() as tmp_dir:
        temp_video_path = Path(tmp_dir) / "input.mp4"
        temp_video_path.write_bytes(video_bytes)
        cap = cv2.VideoCapture(str(temp_video_path))
        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                input_frames.append(frame)
        finally:
            cap.release()

    if not input_frames:
        raise HTTPException(status_code=400, detail="Could not decode any video frames")

    model_config = MODEL_PRESETS[model_name]["model"]
    # Per-request shallow copy, so request-specific overrides (output_dir,
    # use_pose_model) never leak into the shared preset for later requests.
    inference_config = copy.copy(MODEL_PRESETS[model_name]["inference"])
    inference_config.output_dir = output_dir
    inference_config.use_pose_model = model_config.arch in POSE_BASED_MODELS

    # NOTE(review): run_inference is synchronous and is called directly from
    # this async endpoint, so it blocks the event loop (including
    # /healthcheck) for the whole run. Consider fastapi.concurrency.run_in_threadpool.
    predictions, results = run_inference(model_config, inference_config, input_frames)
    if predictions is None:
        predictions = []

    out_dir = Path(inference_config.output_dir)
    resp = {}
    if output_option in ("predictions", "all"):
        resp["predictions"] = predictions
    if output_option in ("csv", "all"):
        csv_path = out_dir / "results.csv"
        # run_inference writes results.csv only when results is not None. If
        # it is None, any existing results.csv comes from an earlier run.
        resp["csv_path"] = str(csv_path) if results is not None and csv_path.exists() else None
    if output_option in ("video", "all"):
        video_path = out_dir / "output.mp4"
        resp["video_path"] = str(video_path) if video_path.exists() else None
    return resp