# Text-model / app.py
# Author: A7md47
# Commit 0d0980f (verified): "Add text emotion classifier (config, tokenizer, app, Dockerfile)"
# (Header copied from the Hugging Face file viewer: Raw · History · Blame · Contribute · Delete — 2.87 kB)
import base64
import json
import os
import torch
import torch.nn.functional as F
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from transformers import (
AutoTokenizer,
AutoModelForSequenceClassification,
)
app = FastAPI(title="Text Emotion Recognition")

# Fully permissive CORS: any origin, any HTTP method, any request header.
# NOTE(review): wildcard origins together with allow_credentials=True means
# Starlette answers credentialed requests by echoing the caller's Origin —
# confirm this is intended before adding any cookie/session-based auth.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
# Emotion class names, in the order of the classifier's output logits
# (predict_b64 maps the argmax index straight into this list).
MODEL_LABELS = [
    "Anger",
    "Disgust",
    "Fear",
    "Happy",
    "Neutral",
    "Sad",
    "Surprise",
]

# Populated by load_model() at startup; None until then.
_model = _tokenizer = _device = None
def load_model():
    """Load the tokenizer and classifier into the module-level globals.

    Both are read from the directory that contains this file. The model is
    moved to ``"cuda"`` when ``torch.cuda.is_available()`` is true (``"cpu"``
    otherwise) and switched to eval mode.

    Side effects:
        Rebinds ``_device``, ``_tokenizer`` and ``_model``; prints a warning
        if the checkpoint's label count disagrees with ``MODEL_LABELS``.
    """
    global _model, _tokenizer, _device
    _device = "cuda" if torch.cuda.is_available() else "cpu"
    # The checkpoint files are expected to sit next to this script.
    model_path = os.path.dirname(os.path.abspath(__file__))
    _tokenizer = AutoTokenizer.from_pretrained(model_path)
    _model = AutoModelForSequenceClassification.from_pretrained(model_path).to(_device).eval()

    # MODEL_LABELS is hard-coded and indexed by logit position at prediction
    # time; a checkpoint with a different head size would mislabel results or
    # fail on every request, so surface the mismatch once, here.
    num_labels = getattr(_model.config, "num_labels", None)
    if num_labels is not None and num_labels != len(MODEL_LABELS):
        print(
            f"[WARN] Model has {num_labels} output labels but MODEL_LABELS "
            f"defines {len(MODEL_LABELS)}; predictions will be unreliable"
        )
    print(f"[INFO] Model loaded on {_device}")
@app.on_event("startup")
async def startup():
    """Startup hook: load the tokenizer and model via load_model()."""
    # NOTE(review): ``on_event`` is deprecated in recent FastAPI releases in
    # favour of a ``lifespan`` handler; migrating also requires passing
    # ``lifespan=`` to the FastAPI(...) constructor.
    return load_model()
@app.get("/")
@app.get("/health")
async def health():
    """Report service status; served at both ``/`` and ``/health``.

    Returns a JSON object with ``status`` (always ``"ok"``), ``model_loaded``
    (whether the global model has been set) and ``device`` (the device string
    chosen by load_model(), or ``None`` before it has run).
    """
    model_ready = _model is not None
    return dict(status="ok", model_loaded=model_ready, device=_device)
def _parse_payload(body: bytes, content_type: str) -> dict:
    """Extract the JSON object carrying the ``text`` field from a request body.

    Two encodings are accepted:

    * JSON, when the Content-Type mentions ``application/json`` or the body
      starts with ``{``;
    * otherwise a URL-encoded form whose ``data`` field holds a JSON string.

    Raises:
        HTTPException(400): the ``data`` form field is missing, the body/field
            is not valid UTF-8 or JSON, or the decoded JSON is not an object.
    """
    import urllib.parse

    try:
        if "application/json" in content_type or body.startswith(b"{"):
            payload = json.loads(body)
        else:
            form = urllib.parse.parse_qs(body.decode())
            raw = form.get("data", [None])[0]
            if raw is None:
                raise HTTPException(status_code=400, detail="Missing 'data' field")
            payload = json.loads(raw)
    except ValueError as exc:
        # json.JSONDecodeError and UnicodeDecodeError are ValueError subclasses:
        # the client sent a malformed body, which is not a server failure.
        raise HTTPException(status_code=400, detail=f"Malformed request body: {exc}") from exc
    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="Request payload must be a JSON object")
    return payload


def _decode_text(text_b64: str) -> str:
    """Base64-decode *text_b64* as UTF-8, falling back to the raw string.

    The fallback lets clients send plain text in the same field. Invalid
    base64 (binascii.Error), non-ASCII input, and bytes that are not valid
    UTF-8 (UnicodeDecodeError) are all ValueError subclasses and trigger it.
    """
    # NOTE(review): b64decode without validate=True discards non-alphabet
    # characters, so plain text that happens to decode cleanly is misread.
    try:
        return base64.b64decode(text_b64).decode("utf-8")
    except ValueError:
        return text_b64


def _classify(text: str) -> dict:
    """Run the emotion classifier on *text* and build the response dict.

    Returns ``emotion`` (top label), ``confidence`` (its probability) and
    ``probabilities`` (label -> probability over MODEL_LABELS), rounded to
    4 decimals.

    Raises:
        HTTPException(503): load_model() has not populated the globals.
    """
    if _model is None or _tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")
    # Over-long inputs are truncated to 512 tokens rather than rejected.
    inputs = _tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(_device)
    with torch.no_grad():
        logits = _model(**inputs).logits
    # One input string -> batch of one; squeeze(0) drops that leading axis.
    probs_np = F.softmax(logits, dim=-1).squeeze(0).cpu().numpy()
    pred_idx = int(probs_np.argmax())
    return {
        "emotion": MODEL_LABELS[pred_idx],
        "confidence": round(float(probs_np[pred_idx]), 4),
        "probabilities": {c: round(float(probs_np[i]), 4) for i, c in enumerate(MODEL_LABELS)},
    }


@app.post("/predict_b64")
async def predict_b64(request: Request):
    """Classify the emotion of a (normally base64-encoded) text.

    The body is either JSON ``{"text": "<base64>"}`` or a URL-encoded form
    whose ``data`` field contains that JSON. A ``text`` value that does not
    decode as base64 UTF-8 is classified as-is.

    Returns:
        ``{"emotion": str, "confidence": float, "probabilities": {label: float}}``

    Raises:
        HTTPException: 400 for malformed or missing input, 503 if the model
            is not loaded, 500 for any other failure.
    """
    try:
        body = await request.body()
        payload = _parse_payload(body, request.headers.get("content-type", ""))
        text_b64 = payload.get("text", "")
        if not text_b64:
            raise HTTPException(status_code=400, detail="No text data found")
        if not isinstance(text_b64, str):
            raise HTTPException(status_code=400, detail="'text' must be a string")
        return _classify(_decode_text(text_b64))
    except HTTPException:
        raise
    except Exception as exc:
        import traceback

        # Keep the server-side traceback; the client only receives the message.
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(exc)) from exc