from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware from transformers import AutoTokenizer, AutoModelForSequenceClassification import torch import torch.nn.functional as F import uvicorn # ===== 1. CẤU HÌNH MODEL EMOTION (5 NHÃN) ===== MODEL_NAME = "vijjj1/emotion3" # Repo chứa model 5 nhãn của bạn # Định nghĩa nhãn khớp với thứ tự huấn luyện (0->4) id2label = { 0: "neutral", 1: "positive", 2: "negative", 3: "angry", 4: "sarcasm" } # ===== 2. LOAD MODEL ===== print("Loading model...") try: # use_fast=False để tránh lỗi với các model gốc PhoBERT/ViT5 tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False) model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME) model.eval() # Chuyển sang chế độ dự đoán print("Model loaded successfully!") except Exception as e: print(f"Error loading model: {e}") # ===== 3. KHỞI TẠO API ===== app = FastAPI(title="Emotion Analysis API", version="5-Labels") # Cấu hình CORS để Extension gọi được app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) @app.get("/") def home(): return {"message": "Emotion Analysis API (5 Labels) is running!"} # ===== 4. XỬ LÝ DỰ ĐOÁN ===== @app.post("/predict") async def predict(req: Request): try: data = await req.json() # Extension gửi key là "comment" text = data.get("comment", "").strip() if not text: return {"error": "Vui lòng nhập nội dung bình luận"} # Tokenize inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128) # Inference with torch.no_grad(): outputs = model(**inputs) # Tính xác suất bằng Softmax probs = F.softmax(outputs.logits, dim=1)[0] # Tìm nhãn có điểm cao nhất pred_idx = torch.argmax(probs).item() pred_label = id2label.get(pred_idx, "neutral") max_score = probs[pred_idx].item() # Trả về kết quả return { "label": pred_label, # Ví dụ: "angry" "score": max_score, # Ví dụ: 0.945 "probabilities": probs.tolist() # Danh sách điểm [0.01, 0.02, ...] để vẽ biểu đồ nếu cần } except Exception as e: return {"error": str(e)} # Run Local (Dùng khi chạy thử trên máy) if __name__ == "__main__": uvicorn.run(app, host="0.0.0.0", port=7860)