vit / app.py
vijjj1's picture
Update app.py
4938926 verified
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import torch.nn.functional as F
import uvicorn
# ===== 1. CẤU HÌNH MODEL EMOTION (5 NHÃN) =====
MODEL_NAME = "vijjj1/emotion3" # Repo chứa model 5 nhãn của bạn
# Định nghĩa nhãn khớp với thứ tự huấn luyện (0->4)
id2label = {
0: "neutral",
1: "positive",
2: "negative",
3: "angry",
4: "sarcasm"
}
# ===== 2. LOAD MODEL =====
print("Loading model...")
try:
# use_fast=False để tránh lỗi với các model gốc PhoBERT/ViT5
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
model.eval() # Chuyển sang chế độ dự đoán
print("Model loaded successfully!")
except Exception as e:
print(f"Error loading model: {e}")
# ===== 3. KHỞI TẠO API =====
app = FastAPI(title="Emotion Analysis API", version="5-Labels")
# Cấu hình CORS để Extension gọi được
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.get("/")
def home():
return {"message": "Emotion Analysis API (5 Labels) is running!"}
# ===== 4. XỬ LÝ DỰ ĐOÁN =====
@app.post("/predict")
async def predict(req: Request):
try:
data = await req.json()
# Extension gửi key là "comment"
text = data.get("comment", "").strip()
if not text:
return {"error": "Vui lòng nhập nội dung bình luận"}
# Tokenize
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128)
# Inference
with torch.no_grad():
outputs = model(**inputs)
# Tính xác suất bằng Softmax
probs = F.softmax(outputs.logits, dim=1)[0]
# Tìm nhãn có điểm cao nhất
pred_idx = torch.argmax(probs).item()
pred_label = id2label.get(pred_idx, "neutral")
max_score = probs[pred_idx].item()
# Trả về kết quả
return {
"label": pred_label, # Ví dụ: "angry"
"score": max_score, # Ví dụ: 0.945
"probabilities": probs.tolist() # Danh sách điểm [0.01, 0.02, ...] để vẽ biểu đồ nếu cần
}
except Exception as e:
return {"error": str(e)}
# Run Local (Dùng khi chạy thử trên máy)
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=7860)