import logging

import torch
import torch.nn.functional as F
import uvicorn
from fastapi import FastAPI, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Hugging Face Hub identifier of the fine-tuned emotion classifier.
MODEL_NAME = "vijjj1/emotion3"

# Maps the classifier's output index to a human-readable label; the tuple
# position is the class id (0 = "neutral", ..., 4 = "sarcasm").
_LABEL_NAMES = ("neutral", "positive", "negative", "angry", "sarcasm")
id2label = dict(enumerate(_LABEL_NAMES))

# Load the tokenizer and classifier once at import time so every request
# reuses them. Loading is best-effort: if it fails, the server still starts
# (the "/" route keeps working). Both handles are then bound to None, so
# request handlers can detect the missing model instead of hitting a
# NameError or a half-initialized state.
print("Loading model...")
try:
    # use_fast=False selects the slow (Python) tokenizer; kept from the
    # original -- presumably required by this checkpoint (TODO confirm).
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
    model.eval()  # evaluation mode for inference (e.g. disables dropout)
    print("Model loaded successfully!")
except Exception as e:
    # Reset both handles: the tokenizer may have loaded before the model
    # failed, and a usable tokenizer without a model is not a usable service.
    tokenizer = None
    model = None
    print(f"Error loading model: {e}")

app = FastAPI(
    title="Emotion Analysis API",
    version="5-Labels",
)

# Cross-origin settings: accept requests from any origin, with any method
# and any header, and allow credentials.
# NOTE(review): the Starlette CORS docs say wildcard origins are not meant
# to be combined with allow_credentials=True. Nothing visible here uses
# cookies or auth, so confirm whether credentials are actually needed, or
# switch to an explicit origin list.
_CORS_SETTINGS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)

@app.get("/")
def home():
    """Liveness endpoint: report that the API process is up."""
    status_text = "Emotion Analysis API (5 Labels) is running!"
    payload = {"message": status_text}
    return payload

_logger = logging.getLogger(__name__)


def _classify(text):
    """Run the emotion model on a single comment and build the response dict.

    This is synchronous, blocking tensor work; ``predict`` runs it in a worker
    thread so it does not stall the asyncio event loop.

    Args:
        text: Non-empty, already-stripped comment text.

    Returns:
        ``{"label": str, "score": float, "probabilities": list[float]}``,
        where ``probabilities`` is indexed by class id (see ``id2label``).
    """
    # Inputs longer than 128 tokens are truncated.
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128)
    with torch.no_grad():
        logits = model(**inputs).logits
    # Exactly one comment was tokenized, so take row 0 of the batch.
    probs = F.softmax(logits, dim=-1)[0]
    pred_idx = torch.argmax(probs).item()
    return {
        # Indices outside id2label fall back to "neutral" (original behavior).
        "label": id2label.get(pred_idx, "neutral"),
        "score": probs[pred_idx].item(),
        "probabilities": probs.tolist(),
    }


@app.post("/predict")
async def predict(req: Request):
    """Classify the emotion of the ``comment`` field of a JSON request body.

    Expects a JSON object such as ``{"comment": "..."}``.

    Returns:
        On success, ``{"label": ..., "score": ..., "probabilities": [...]}``
        (see ``_classify``). On any failure, ``{"error": <message>}`` is
        returned in the response body; no HTTP error status is set, which
        matches the original API contract.
    """
    try:
        data = await req.json()
        if not isinstance(data, dict):
            return {"error": "Request body must be a JSON object"}

        comment = data.get("comment")
        if comment is None:  # missing key or explicit null
            comment = ""
        if not isinstance(comment, str):
            return {"error": "'comment' must be a string"}

        text = comment.strip()
        if not text:
            # "Please enter the comment content."
            return {"error": "Vui lòng nhập nội dung bình luận"}

        # Guard against a failed model load at startup (the handles may be
        # None); any other problem is reported by the generic handler below.
        if tokenizer is None or model is None:
            return {"error": "Model is not loaded"}

        # Tokenization and the forward pass are blocking work; offload them so
        # one slow request does not freeze every other request on the loop.
        return await run_in_threadpool(_classify, text)

    except Exception as e:
        # Top-level request boundary: log the traceback server-side and report
        # the error to the client in the established {"error": ...} shape.
        _logger.exception("Prediction failed")
        return {"error": str(e)}

if __name__ == "__main__":
    # When executed directly, serve the app on all interfaces at port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")