File size: 6,527 Bytes
9b1822c
a7eca0b
 
 
 
 
 
 
4d27720
 
 
a7eca0b
 
 
 
 
 
 
 
 
 
 
 
4d27720
 
 
 
 
 
 
 
a7eca0b
 
 
4d27720
a7eca0b
 
4d27720
 
 
 
 
 
 
 
 
a7eca0b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4d27720
 
 
a7eca0b
4d27720
 
 
a7eca0b
4d27720
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a7eca0b
4d27720
 
 
 
a7eca0b
4d27720
a7eca0b
4d27720
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a7eca0b
4d27720
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
# api/main.py

from fastapi import FastAPI, HTTPException, UploadFile, File, Form
from pydantic import BaseModel
from pathlib import Path
import shutil
import logging
import uvicorn
from typing import Optional, List
import pandas as pd
import json

from configs import ModelConfig, InferenceConfig
from tools.models import load_pipeline
from inference import inference as run_inference

# Initialize FastAPI app
app = FastAPI(title="Sign Language Recognition API")

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Define a Pydantic model for individual prediction
class Prediction(BaseModel):
    gloss: str
    score: float
    start_time: float
    end_time: float
    inference_time: float

# Define a Pydantic model for the response
class InferenceResponse(BaseModel):
    status: str
    predictions: Optional[List[Prediction]] = None
    message: Optional[str] = None

# Define id2gloss mapping
# Đây là một ví dụ. Bạn cần thay thế bằng bản đồ thực tế của bạn.
id2gloss = {
    "0": "hello",
    "1": "thanks",
    "2": "yes",
    # Thêm các ánh xạ cần thiết
}

@app.post("/inference", response_model=InferenceResponse)
async def inference_endpoint(
    file: UploadFile = File(...),
    model_name: str = Form(...),
    output_dir: Optional[str] = Form("output")
):
    """
    Endpoint để xử lý yêu cầu nhận diện ngôn ngữ ký hiệu từ video.

    Args:
        file (UploadFile): Video file được tải lên.
        model_name (str): Tên mô hình sẽ sử dụng (ví dụ: 'spoter', 'sl_gcn', 'dsta_slr').
        output_dir (str, optional): Thư mục để lưu kết quả. Mặc định là 'output'.

    Returns:
        InferenceResponse: Kết quả nhận diện.
    """
    try:
        # Kiểm tra file có hợp lệ không
        if not file.filename.endswith((".mp4", ".avi", ".mov", ".mkv")):
            raise HTTPException(status_code=400, detail="Unsupported file type.")
        
        # Tạo thư mục output nếu không tồn tại
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)
        
        # Lưu video tạm thời
        video_path = output_path / file.filename
        with open(video_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
        
        logger.info(f"Video saved to {video_path}")

        # Tải cấu hình mô hình dựa trên model_name
        try:
            if model_name == "spoter":
                model_config = ModelConfig(arch="spoter", pretrained="vsltranslation/spoter_v3.0")
            elif model_name == "sl_gcn":
                model_config = ModelConfig(arch="sl_gcn", pretrained="vsltranslation/sl_gcn_joint_v3_0")
            elif model_name == "dsta_slr":
                model_config = ModelConfig(arch="dsta_slr", pretrained="models/dsta_slr_joint_motion_v3_0.onnx")
            else:
                raise ValueError("Unsupported model name.")
            
            inference_config = InferenceConfig(
                source=str(video_path),
                output_dir=str(output_path),
                use_onnx=True if model_config.pretrained.endswith(".onnx") else False,
                device="cpu",  # Bạn có thể thay đổi thành "cuda" nếu sử dụng GPU
                cache_dir="models/huggingface",
                visualize=False,
                show_skeleton=False,
                visibility=0.5,
                angle_threshold=140,
                min_num_up_frames=10,
                min_num_down_frames=10,
                delay=400,
                top_k=3,
                bone_stream=False,
                motion_stream=True  # Theo cấu hình YAML bạn cung cấp
            )
            
            # Tải pipeline hoặc session
            pipeline_or_session = load_pipeline(model_config, inference_config)
            logger.info("Pipeline loaded successfully.")

            # Chạy inference
            run_inference(model_config, inference_config, pipeline_or_session)
            logger.info("Inference completed successfully.")

            # Đọc kết quả từ results.csv
            results_csv = output_path / "results.csv"
            if not results_csv.exists():
                raise HTTPException(status_code=500, detail="Inference did not produce results.")

            results_df = pd.read_csv(results_csv)

            # Chuyển đổi DataFrame thành list of Prediction
            predictions = []
            for _, row in results_df.iterrows():
                # Giả sử results.csv có các cột: start_time, end_time, inference_time, prediction
                # Và 'prediction' là một danh sách các từ điển với 'gloss' và 'score'
                start_time = row.get("start_time", 0.0)
                end_time = row.get("end_time", 0.0)
                inference_time = row.get("inference_time", 0.0)
                prediction_list = row.get("prediction", [])
                
                if isinstance(prediction_list, str):
                    # Nếu prediction được lưu dưới dạng chuỗi JSON
                    try:
                        prediction_list = json.loads(prediction_list.replace("'", '"'))
                    except json.JSONDecodeError:
                        prediction_list = []
                
                for pred in prediction_list:
                    gloss = pred.get('gloss', 'Unknown')
                    score = pred.get('score', 0.0)
                    predictions.append(Prediction(
                        gloss=gloss,
                        score=score,
                        start_time=start_time,
                        end_time=end_time,
                        inference_time=inference_time
                    ))
            
            return InferenceResponse(status="success", predictions=predictions)

        except ValueError as ve:
            logger.exception("ValueError during inference.")
            raise HTTPException(status_code=400, detail=str(ve))
        except Exception as e:
            logger.exception("Error during inference.")
            raise HTTPException(status_code=500, detail=str(e))
    
    except Exception as e:
        logger.exception("Unexpected error.")
        raise HTTPException(status_code=500, detail="An unexpected error occurred.")

if __name__ == "__main__":
    uvicorn.run("api.main:app", host="0.0.0.0", port=8000, reload=True)