Spaces:
Runtime error
Runtime error
import copy
import logging
import tempfile
from time import time
from typing import Optional
from pathlib import Path

import pandas as pd
import numpy as np
import cv2
from fastapi import FastAPI, HTTPException, UploadFile, File, Query
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse
import mediapipe as mp

from configs import ModelConfig, InferenceConfig
from utils import config_logger, POSE_BASED_MODELS
from data import Arm, get_sample_timestamp, ok_to_get_frame
from tools import load_pipeline, Predictions
from visualization import draw_text_on_image
app = FastAPI()


def _upload_preset(arch, pretrained, **stream_flags):
    """Build one preset entry: a ModelConfig plus an InferenceConfig for uploaded video.

    Every preset reads an uploaded file (not a webcam), writes to ``demo/run_1``,
    runs the ONNX export and draws the skeleton overlay. ``stream_flags`` carries
    the per-model stream switches (``bone_stream`` / ``motion_stream``); when
    omitted, InferenceConfig's own defaults apply.
    """
    return {
        "model": ModelConfig(arch=arch, pretrained=pretrained),
        "inference": InferenceConfig(
            source="upload",  # uploaded video, not a webcam
            output_dir="demo/run_1",
            use_onnx=True,
            show_skeleton=True,
            visualize=True,
            **stream_flags,
        ),
    }


# The three selectable model presets, keyed by the `model_name` query value.
MODEL_PRESETS = {
    "dsta_slr": _upload_preset(
        "dsta_slr",
        "models/dsta_slr_joint_motion_v3_0.onnx",
        bone_stream=False,
        motion_stream=True,
    ),
    # NOTE(review): sl_gcn loads the same ONNX file as dsta_slr; this looks like a
    # copy-paste slip — confirm the intended sl_gcn weights path.
    "sl_gcn": _upload_preset(
        "sl_gcn",
        "models/dsta_slr_joint_motion_v3_0.onnx",
        bone_stream=True,
        motion_stream=False,
    ),
    "spoter": _upload_preset("spoter", "models/spoter_v3.0.onnx"),
}
config_logger("inference.log")
logging.info("API started")

# Body landmarks highlighted by the skeleton overlay (the SPOTER subset);
# order matches the original explicit listing.
SPOTER_POSE_LANDMARKS = [
    mp.solutions.pose.PoseLandmark[landmark_name]
    for landmark_name in (
        "NOSE",
        "LEFT_EYE",
        "RIGHT_EYE",
        "RIGHT_SHOULDER",
        "LEFT_SHOULDER",
        "RIGHT_ELBOW",
        "LEFT_ELBOW",
        "RIGHT_WRIST",
        "LEFT_WRIST",
    )
]

# Hand landmarks highlighted by the overlay: wrist, then TIP->MCP for the four
# fingers, then the thumb joints (all 21 MediaPipe hand landmarks).
SPOTER_HAND_LANDMARKS = (
    [mp.solutions.hands.HandLandmark.WRIST]
    + [
        mp.solutions.hands.HandLandmark[f"{finger}_{joint}"]
        for finger in ("INDEX_FINGER", "MIDDLE_FINGER", "RING_FINGER", "PINKY")
        for joint in ("TIP", "DIP", "PIP", "MCP")
    ]
    + [
        mp.solutions.hands.HandLandmark[thumb_joint]
        for thumb_joint in ("THUMB_TIP", "THUMB_IP", "THUMB_MCP", "THUMB_CMC")
    ]
)
# NOTE(review): no `@app.get(...)` route decorator is visible on this function;
# if it is really missing, FastAPI never registers this endpoint — confirm.
async def healthcheck():
    """Liveness probe: always answer HTTP 200 with ``{"status": "UP"}``."""
    body = {"status": "UP"}
    return JSONResponse(content=body, status_code=200)
def _build_skeleton_styles(mp_holistic, mp_drawing):
    """Build drawing specs/connections that render only the SPOTER landmark subset.

    Landmarks outside SPOTER_POSE_LANDMARKS / SPOTER_HAND_LANDMARKS get a black,
    zero-size spec, and every connection touching such a landmark is dropped.
    Returns a dict with pose/right-hand/left-hand styles and pose/hand connections.
    """
    mp_drawing_styles = mp.solutions.drawing_styles
    pose_style = mp_drawing_styles.get_default_pose_landmarks_style()
    right_hand_style = mp_drawing_styles.get_default_hand_landmarks_style()
    left_hand_style = mp_drawing_styles.get_default_hand_landmarks_style()

    def spec(color, size):
        # DrawingSpec lives in mp.solutions.drawing_utils; `mp.drawing` does not exist.
        return mp_drawing.DrawingSpec(color=color, thickness=size, circle_radius=size)

    kept_pose = set(SPOTER_POSE_LANDMARKS)
    hidden_pose = set()
    for landmark in mp.solutions.pose.PoseLandmark:
        if landmark in kept_pose:
            pose_style[landmark] = spec((0, 255, 0), 2)
        else:
            pose_style[landmark] = spec((0, 0, 0), 0)
            hidden_pose.add(landmark.value)

    kept_hand = set(SPOTER_HAND_LANDMARKS)
    hidden_hand = set()
    for landmark in mp.solutions.hands.HandLandmark:
        if landmark in kept_hand:
            right_hand_style[landmark] = spec((0, 0, 255), 2)
            left_hand_style[landmark] = spec((255, 0, 0), 2)
        else:
            right_hand_style[landmark] = spec((0, 0, 0), 0)
            left_hand_style[landmark] = spec((0, 0, 0), 0)
            hidden_hand.add(landmark.value)

    # Filter into new lists: calling list.remove() while iterating the same list
    # (the previous approach) skips the element after every removal.
    pose_connections = [
        conn for conn in mp_holistic.POSE_CONNECTIONS if hidden_pose.isdisjoint(conn)
    ]
    hand_connections = [
        conn for conn in mp_holistic.HAND_CONNECTIONS if hidden_hand.isdisjoint(conn)
    ]
    return {
        "pose_style": pose_style,
        "right_hand_style": right_hand_style,
        "left_hand_style": left_hand_style,
        "pose_connections": pose_connections,
        "hand_connections": hand_connections,
    }


def _open_video_writer(output_dir, input_frames, fps=30):
    """Open ``<output_dir>/output.mp4`` sized like the first frame, or return None.

    The directory is created whenever ``output_dir`` is set, even when no writer
    can be opened (no frames, or the first frame is not a numpy array).
    """
    if output_dir is None:
        return None
    out_path = Path(output_dir)
    out_path.mkdir(parents=True, exist_ok=True)
    if len(input_frames) == 0 or not isinstance(input_frames[0], np.ndarray):
        return None
    height, width = input_frames[0].shape[:2]
    return cv2.VideoWriter(
        str(out_path / "output.mp4"),
        cv2.VideoWriter_fourcc(*"mp4v"),
        fps,
        (width, height),
    )


def _draw_skeleton(frame, detection_results, skeleton, mp_drawing):
    """Draw the filtered pose and both hands onto ``frame`` in place."""
    mp_drawing.draw_landmarks(
        frame,
        detection_results.pose_landmarks,
        connections=skeleton["pose_connections"],
        landmark_drawing_spec=skeleton["pose_style"],
    )
    mp_drawing.draw_landmarks(
        frame,
        detection_results.right_hand_landmarks,
        connections=skeleton["hand_connections"],
        landmark_drawing_spec=skeleton["right_hand_style"],
    )
    mp_drawing.draw_landmarks(
        frame,
        detection_results.left_hand_landmarks,
        connections=skeleton["hand_connections"],
        landmark_drawing_spec=skeleton["left_hand_style"],
    )


def run_inference(model_config, inference_config, input_frames):
    """Segment signs from raw frames by arm pose and classify each segment.

    Each frame goes through MediaPipe Holistic. The left/right ``Arm`` trackers
    decide (via ``ok_to_get_frame``) which frames belong to the current sample.
    When ``get_sample_timestamp`` reports a closed window, the collected data is
    fed to the loaded pipeline and the arm state is reset.

    If ``inference_config.output_dir`` is set, an annotated ``output.mp4`` is
    written there, plus ``results.csv`` when any sample was classified.

    Args:
        model_config: preset ModelConfig passed to ``load_pipeline``.
        inference_config: InferenceConfig; reads visibility, angle/frame
            thresholds, delay, use_pose_model, show_skeleton and output_dir.
        input_frames: sequence of BGR frames (converted with COLOR_BGR2RGB here).

    Returns:
        ``(last_predictions.predictions, merged_results)``. ``merged_results`` is
        None when no sample was classified.
    """
    # Timestamps are synthesized assuming ~30 fps; the real upload fps is not read.
    frame_interval_ms = 33

    pipeline = load_pipeline(model_config, inference_config)
    logging.info("Pipeline loaded")
    right_arm = Arm("right", inference_config.visibility)
    left_arm = Arm("left", inference_config.visibility)
    data = []
    results = None
    predictions = Predictions()

    mp_holistic = mp.solutions.holistic
    mp_drawing = mp.solutions.drawing_utils
    skeleton = (
        _build_skeleton_styles(mp_holistic, mp_drawing)
        if inference_config.show_skeleton
        else None
    )

    writer = _open_video_writer(inference_config.output_dir, input_frames)
    try:
        with mp_holistic.Holistic(
            min_detection_confidence=0.9, min_tracking_confidence=0.5
        ) as holistic:
            for frame_idx, frame in enumerate(input_frames):
                current_time_ms = frame_idx * frame_interval_ms
                rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                rgb_frame.flags.writeable = False
                detection_results = holistic.process(rgb_frame)

                if detection_results.pose_landmarks is None:
                    # No body found, so the arm state cannot be updated. Still emit
                    # the raw frame so the output video keeps the input's length
                    # and timing.
                    if writer is not None:
                        writer.write(frame)
                    continue

                landmarks = detection_results.pose_landmarks.landmark
                left_arm.set_pose(landmarks)
                right_arm.set_pose(landmarks)

                # Evaluate both arms (no short-circuit): ok_to_get_frame is called
                # per arm on every frame, as before.
                collecting = [
                    ok_to_get_frame(
                        arm=arm,
                        angle_threshold=inference_config.angle_threshold,
                        min_num_up_frames=inference_config.min_num_up_frames,
                        min_num_down_frames=inference_config.min_num_down_frames,
                        current_time=current_time_ms,
                        delay=inference_config.delay,
                    )
                    for arm in (left_arm, right_arm)
                ]
                if any(collecting):
                    # Clear the on-screen prediction while a new sample is recorded.
                    predictions = Predictions()
                    data.append(
                        detection_results if inference_config.use_pose_model else frame
                    )

                start_time, end_time = get_sample_timestamp(left_arm, right_arm)
                start_time /= 1000
                end_time /= 1000
                if start_time != 0 and end_time != 0:
                    if data:
                        start_inference_time = time()
                        predictions = Predictions(predictions=pipeline(np.array(data)))
                        predictions.inference_time = time() - start_inference_time
                        predictions.start_time = start_time
                        predictions.end_time = end_time
                        logging.info(str(predictions))
                        results = predictions.merge_results(results)
                    else:
                        logging.warning(
                            "Sample window %s-%s closed with no collected frames; "
                            "skipping inference",
                            start_time,
                            end_time,
                        )
                    left_arm.reset_state()
                    right_arm.reset_state()
                    data = []

                # Overlay arm angles, current prediction and (optionally) skeleton.
                frame = left_arm.visualize(frame, (20, 10), "Left arm angle")
                frame = right_arm.visualize(frame, (20, 40), "Right arm angle")
                frame = predictions.visualize(frame, (20, 70))
                if skeleton is not None:
                    _draw_skeleton(frame, detection_results, skeleton, mp_drawing)
                if writer is not None:
                    writer.write(frame)
    finally:
        # Always finalize the container, even if inference raised mid-video.
        if writer is not None:
            writer.release()

    if results is not None and inference_config.output_dir is not None:
        pd.DataFrame(results).to_csv(
            Path(inference_config.output_dir) / "results.csv", index=False
        )
    return predictions.predictions, results
def _read_video_frames(video_path):
    """Decode every frame of ``video_path`` with OpenCV and return them as a list."""
    cap = cv2.VideoCapture(str(video_path))
    frames = []
    try:
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            frames.append(frame)
    finally:
        cap.release()
    return frames


# NOTE(review): no `@app.post(...)` route decorator is visible on this function;
# if it is really missing, FastAPI never registers this endpoint — confirm.
async def inference_endpoint(
    model_name: str = Query(..., description="Choose model: dsta_slr, sl_gcn, spoter"),
    output_option: str = Query("all", description="Output option: 'predictions', 'csv', 'video', 'all'"),
    output_dir: str = Query("demo/run_1", description="Output directory for results"),
    file: UploadFile = File(...)
):
    """Run sign-language inference on one uploaded video.

    - model_name: one of the MODEL_PRESETS keys (dsta_slr, sl_gcn, spoter).
    - output_option: 'predictions', 'csv', 'video' or 'all'; selects which keys
      appear in the response.
    - output_dir: relative output directory, e.g. 'my_results'.
    - file: the uploaded video.

    Returns a dict with any of ``predictions``, ``csv_path`` and ``video_path``.
    A path is None when that file was not produced.

    Raises:
        HTTPException(400): unknown model_name/output_option, an output_dir that
            is absolute or contains '..', or a video with no decodable frames.
    """
    if model_name not in MODEL_PRESETS:
        raise HTTPException(status_code=400, detail="Invalid model_name")
    if output_option not in ("predictions", "csv", "video", "all"):
        raise HTTPException(status_code=400, detail="Invalid output_option")
    # output_dir comes straight from the query string: keep it inside the app dir.
    requested_dir = Path(output_dir)
    if requested_dir.is_absolute() or ".." in requested_dir.parts:
        raise HTTPException(
            status_code=400,
            detail="output_dir must be a relative path without '..'",
        )

    # A unique temp file per request, so concurrent uploads don't overwrite each
    # other. It is deleted once decoded.
    video_bytes = await file.read()
    suffix = Path(file.filename or "").suffix or ".mp4"
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        tmp.write(video_bytes)
        temp_video_path = Path(tmp.name)
    try:
        input_frames = await run_in_threadpool(_read_video_frames, temp_video_path)
    finally:
        temp_video_path.unlink(missing_ok=True)
    if not input_frames:
        raise HTTPException(
            status_code=400,
            detail="Could not decode any frames from the uploaded video",
        )

    preset = MODEL_PRESETS[model_name]
    model_config = preset["model"]
    # Copy before overriding: mutating the shared module-level preset leaked one
    # request's output_dir into later/concurrent requests.
    inference_config = copy.deepcopy(preset["inference"])
    inference_config.output_dir = output_dir
    inference_config.use_pose_model = model_config.arch in POSE_BASED_MODELS

    # CPU-bound work runs off the event loop so other requests (e.g. health
    # checks) are still served.
    predictions, _results = await run_in_threadpool(
        run_inference, model_config, inference_config, input_frames
    )

    out_dir = Path(inference_config.output_dir)
    resp = {}
    if output_option in ("predictions", "all"):
        resp["predictions"] = [] if predictions is None else predictions
    if output_option in ("csv", "all"):
        csv_path = out_dir / "results.csv"
        resp["csv_path"] = str(csv_path) if csv_path.exists() else None
    if output_option in ("video", "all"):
        video_path = out_dir / "output.mp4"
        resp["video_path"] = str(video_path) if video_path.exists() else None
    return resp