thienphuc12339's picture
Update app.py
4d27720 verified
# api/main.py
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
from pydantic import BaseModel
from pathlib import Path
import shutil
import logging
import uvicorn
from typing import Optional, List
import pandas as pd
import json
from configs import ModelConfig, InferenceConfig
from tools.models import load_pipeline
from inference import inference as run_inference
# Initialize FastAPI app
app = FastAPI(title="Sign Language Recognition API")
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Define a Pydantic model for individual prediction
class Prediction(BaseModel):
gloss: str
score: float
start_time: float
end_time: float
inference_time: float
# Define a Pydantic model for the response
class InferenceResponse(BaseModel):
status: str
predictions: Optional[List[Prediction]] = None
message: Optional[str] = None
# Define id2gloss mapping
# Đây là một ví dụ. Bạn cần thay thế bằng bản đồ thực tế của bạn.
id2gloss = {
"0": "hello",
"1": "thanks",
"2": "yes",
# Thêm các ánh xạ cần thiết
}
@app.post("/inference", response_model=InferenceResponse)
async def inference_endpoint(
file: UploadFile = File(...),
model_name: str = Form(...),
output_dir: Optional[str] = Form("output")
):
"""
Endpoint để xử lý yêu cầu nhận diện ngôn ngữ ký hiệu từ video.
Args:
file (UploadFile): Video file được tải lên.
model_name (str): Tên mô hình sẽ sử dụng (ví dụ: 'spoter', 'sl_gcn', 'dsta_slr').
output_dir (str, optional): Thư mục để lưu kết quả. Mặc định là 'output'.
Returns:
InferenceResponse: Kết quả nhận diện.
"""
try:
# Kiểm tra file có hợp lệ không
if not file.filename.endswith((".mp4", ".avi", ".mov", ".mkv")):
raise HTTPException(status_code=400, detail="Unsupported file type.")
# Tạo thư mục output nếu không tồn tại
output_path = Path(output_dir)
output_path.mkdir(parents=True, exist_ok=True)
# Lưu video tạm thời
video_path = output_path / file.filename
with open(video_path, "wb") as buffer:
shutil.copyfileobj(file.file, buffer)
logger.info(f"Video saved to {video_path}")
# Tải cấu hình mô hình dựa trên model_name
try:
if model_name == "spoter":
model_config = ModelConfig(arch="spoter", pretrained="vsltranslation/spoter_v3.0")
elif model_name == "sl_gcn":
model_config = ModelConfig(arch="sl_gcn", pretrained="vsltranslation/sl_gcn_joint_v3_0")
elif model_name == "dsta_slr":
model_config = ModelConfig(arch="dsta_slr", pretrained="models/dsta_slr_joint_motion_v3_0.onnx")
else:
raise ValueError("Unsupported model name.")
inference_config = InferenceConfig(
source=str(video_path),
output_dir=str(output_path),
use_onnx=True if model_config.pretrained.endswith(".onnx") else False,
device="cpu", # Bạn có thể thay đổi thành "cuda" nếu sử dụng GPU
cache_dir="models/huggingface",
visualize=False,
show_skeleton=False,
visibility=0.5,
angle_threshold=140,
min_num_up_frames=10,
min_num_down_frames=10,
delay=400,
top_k=3,
bone_stream=False,
motion_stream=True # Theo cấu hình YAML bạn cung cấp
)
# Tải pipeline hoặc session
pipeline_or_session = load_pipeline(model_config, inference_config)
logger.info("Pipeline loaded successfully.")
# Chạy inference
run_inference(model_config, inference_config, pipeline_or_session)
logger.info("Inference completed successfully.")
# Đọc kết quả từ results.csv
results_csv = output_path / "results.csv"
if not results_csv.exists():
raise HTTPException(status_code=500, detail="Inference did not produce results.")
results_df = pd.read_csv(results_csv)
# Chuyển đổi DataFrame thành list of Prediction
predictions = []
for _, row in results_df.iterrows():
# Giả sử results.csv có các cột: start_time, end_time, inference_time, prediction
# Và 'prediction' là một danh sách các từ điển với 'gloss' và 'score'
start_time = row.get("start_time", 0.0)
end_time = row.get("end_time", 0.0)
inference_time = row.get("inference_time", 0.0)
prediction_list = row.get("prediction", [])
if isinstance(prediction_list, str):
# Nếu prediction được lưu dưới dạng chuỗi JSON
try:
prediction_list = json.loads(prediction_list.replace("'", '"'))
except json.JSONDecodeError:
prediction_list = []
for pred in prediction_list:
gloss = pred.get('gloss', 'Unknown')
score = pred.get('score', 0.0)
predictions.append(Prediction(
gloss=gloss,
score=score,
start_time=start_time,
end_time=end_time,
inference_time=inference_time
))
return InferenceResponse(status="success", predictions=predictions)
except ValueError as ve:
logger.exception("ValueError during inference.")
raise HTTPException(status_code=400, detail=str(ve))
except Exception as e:
logger.exception("Error during inference.")
raise HTTPException(status_code=500, detail=str(e))
except Exception as e:
logger.exception("Unexpected error.")
raise HTTPException(status_code=500, detail="An unexpected error occurred.")
if __name__ == "__main__":
uvicorn.run("api.main:app", host="0.0.0.0", port=8000, reload=True)