Spaces:
Runtime error
Runtime error
| # api/main.py | |
| from fastapi import FastAPI, HTTPException, UploadFile, File, Form | |
| from pydantic import BaseModel | |
| from pathlib import Path | |
| import shutil | |
| import logging | |
| import uvicorn | |
| from typing import Optional, List | |
| import pandas as pd | |
| import json | |
| from configs import ModelConfig, InferenceConfig | |
| from tools.models import load_pipeline | |
| from inference import inference as run_inference | |
| # Initialize FastAPI app | |
| app = FastAPI(title="Sign Language Recognition API") | |
| # Configure logging | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |
| # Define a Pydantic model for individual prediction | |
| class Prediction(BaseModel): | |
| gloss: str | |
| score: float | |
| start_time: float | |
| end_time: float | |
| inference_time: float | |
| # Define a Pydantic model for the response | |
| class InferenceResponse(BaseModel): | |
| status: str | |
| predictions: Optional[List[Prediction]] = None | |
| message: Optional[str] = None | |
| # Define id2gloss mapping | |
| # Đây là một ví dụ. Bạn cần thay thế bằng bản đồ thực tế của bạn. | |
| id2gloss = { | |
| "0": "hello", | |
| "1": "thanks", | |
| "2": "yes", | |
| # Thêm các ánh xạ cần thiết | |
| } | |
| async def inference_endpoint( | |
| file: UploadFile = File(...), | |
| model_name: str = Form(...), | |
| output_dir: Optional[str] = Form("output") | |
| ): | |
| """ | |
| Endpoint để xử lý yêu cầu nhận diện ngôn ngữ ký hiệu từ video. | |
| Args: | |
| file (UploadFile): Video file được tải lên. | |
| model_name (str): Tên mô hình sẽ sử dụng (ví dụ: 'spoter', 'sl_gcn', 'dsta_slr'). | |
| output_dir (str, optional): Thư mục để lưu kết quả. Mặc định là 'output'. | |
| Returns: | |
| InferenceResponse: Kết quả nhận diện. | |
| """ | |
| try: | |
| # Kiểm tra file có hợp lệ không | |
| if not file.filename.endswith((".mp4", ".avi", ".mov", ".mkv")): | |
| raise HTTPException(status_code=400, detail="Unsupported file type.") | |
| # Tạo thư mục output nếu không tồn tại | |
| output_path = Path(output_dir) | |
| output_path.mkdir(parents=True, exist_ok=True) | |
| # Lưu video tạm thời | |
| video_path = output_path / file.filename | |
| with open(video_path, "wb") as buffer: | |
| shutil.copyfileobj(file.file, buffer) | |
| logger.info(f"Video saved to {video_path}") | |
| # Tải cấu hình mô hình dựa trên model_name | |
| try: | |
| if model_name == "spoter": | |
| model_config = ModelConfig(arch="spoter", pretrained="vsltranslation/spoter_v3.0") | |
| elif model_name == "sl_gcn": | |
| model_config = ModelConfig(arch="sl_gcn", pretrained="vsltranslation/sl_gcn_joint_v3_0") | |
| elif model_name == "dsta_slr": | |
| model_config = ModelConfig(arch="dsta_slr", pretrained="models/dsta_slr_joint_motion_v3_0.onnx") | |
| else: | |
| raise ValueError("Unsupported model name.") | |
| inference_config = InferenceConfig( | |
| source=str(video_path), | |
| output_dir=str(output_path), | |
| use_onnx=True if model_config.pretrained.endswith(".onnx") else False, | |
| device="cpu", # Bạn có thể thay đổi thành "cuda" nếu sử dụng GPU | |
| cache_dir="models/huggingface", | |
| visualize=False, | |
| show_skeleton=False, | |
| visibility=0.5, | |
| angle_threshold=140, | |
| min_num_up_frames=10, | |
| min_num_down_frames=10, | |
| delay=400, | |
| top_k=3, | |
| bone_stream=False, | |
| motion_stream=True # Theo cấu hình YAML bạn cung cấp | |
| ) | |
| # Tải pipeline hoặc session | |
| pipeline_or_session = load_pipeline(model_config, inference_config) | |
| logger.info("Pipeline loaded successfully.") | |
| # Chạy inference | |
| run_inference(model_config, inference_config, pipeline_or_session) | |
| logger.info("Inference completed successfully.") | |
| # Đọc kết quả từ results.csv | |
| results_csv = output_path / "results.csv" | |
| if not results_csv.exists(): | |
| raise HTTPException(status_code=500, detail="Inference did not produce results.") | |
| results_df = pd.read_csv(results_csv) | |
| # Chuyển đổi DataFrame thành list of Prediction | |
| predictions = [] | |
| for _, row in results_df.iterrows(): | |
| # Giả sử results.csv có các cột: start_time, end_time, inference_time, prediction | |
| # Và 'prediction' là một danh sách các từ điển với 'gloss' và 'score' | |
| start_time = row.get("start_time", 0.0) | |
| end_time = row.get("end_time", 0.0) | |
| inference_time = row.get("inference_time", 0.0) | |
| prediction_list = row.get("prediction", []) | |
| if isinstance(prediction_list, str): | |
| # Nếu prediction được lưu dưới dạng chuỗi JSON | |
| try: | |
| prediction_list = json.loads(prediction_list.replace("'", '"')) | |
| except json.JSONDecodeError: | |
| prediction_list = [] | |
| for pred in prediction_list: | |
| gloss = pred.get('gloss', 'Unknown') | |
| score = pred.get('score', 0.0) | |
| predictions.append(Prediction( | |
| gloss=gloss, | |
| score=score, | |
| start_time=start_time, | |
| end_time=end_time, | |
| inference_time=inference_time | |
| )) | |
| return InferenceResponse(status="success", predictions=predictions) | |
| except ValueError as ve: | |
| logger.exception("ValueError during inference.") | |
| raise HTTPException(status_code=400, detail=str(ve)) | |
| except Exception as e: | |
| logger.exception("Error during inference.") | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| except Exception as e: | |
| logger.exception("Unexpected error.") | |
| raise HTTPException(status_code=500, detail="An unexpected error occurred.") | |
| if __name__ == "__main__": | |
| uvicorn.run("api.main:app", host="0.0.0.0", port=8000, reload=True) | |