"""FastAPI service that classifies the emotion expressed in a text snippet.

The tokenizer and sequence-classification model are loaded from the directory
containing this file when the application starts. ``POST /predict_b64`` accepts
either a JSON body ``{"text": "<base64 or plain text>"}`` or a form-encoded body
whose ``data`` field holds that same JSON document.
"""

import base64
import json
import logging
import os
import threading
import urllib.parse

import torch
import torch.nn.functional as F
from fastapi import FastAPI, HTTPException, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
)

logger = logging.getLogger(__name__)

app = FastAPI(title="Text Emotion Recognition")

# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# website issue credentialed cross-origin requests. This API reads no cookies or
# auth headers, so consider allow_credentials=False or an explicit origin list
# before exposing it publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Output labels, indexed by the position of the classifier's logits.
MODEL_LABELS = ["Anger", "Disgust", "Fear", "Happy", "Neutral", "Sad", "Surprise"]

# Maximum number of tokens fed to the model; longer inputs are truncated.
MAX_LENGTH = 512

# Populated by load_model() at startup; None until then.
_model = None
_tokenizer = None
_device = None

# Inference runs in worker threads; serialize access to the shared tokenizer and
# model, since a fast tokenizer mutates internal state (e.g. truncation
# settings) on each call and is not safe to use from several threads at once.
_inference_lock = threading.Lock()


def load_model():
    """Load the tokenizer and classifier from this file's directory.

    The model is moved to CUDA when available (otherwise CPU) and put in eval
    mode. Sets the module globals ``_model``, ``_tokenizer`` and ``_device``.
    """
    global _model, _tokenizer, _device
    _device = "cuda" if torch.cuda.is_available() else "cpu"
    model_path = os.path.dirname(os.path.abspath(__file__))
    _tokenizer = AutoTokenizer.from_pretrained(model_path)
    _model = AutoModelForSequenceClassification.from_pretrained(model_path).to(_device).eval()
    num_labels = _model.config.num_labels
    if num_labels != len(MODEL_LABELS):
        # A mismatch means predictions may be mislabeled or raise IndexError.
        logger.warning(
            "Model has %d labels but MODEL_LABELS has %d entries",
            num_labels,
            len(MODEL_LABELS),
        )
    print(f"[INFO] Model loaded on {_device}")


# NOTE(review): on_event is deprecated in recent FastAPI releases in favour of a
# lifespan handler; kept here for compatibility with older versions.
@app.on_event("startup")
async def startup():
    """Load the model once when the server starts."""
    load_model()


@app.get("/")
@app.get("/health")
async def health():
    """Report liveness, whether the model is loaded, and the inference device."""
    return {
        "status": "ok",
        "model_loaded": _model is not None,
        "device": _device,
    }


def _extract_text_b64(body, content_type):
    """Return the non-empty string ``text`` field from a JSON or form request body.

    Raises:
        HTTPException: 400 when the body is malformed, is not a JSON object,
            lacks the ``data`` form field, or has a missing/empty/non-string
            ``text`` value.
    """
    try:
        if "application/json" in content_type or body.startswith(b"{"):
            payload = json.loads(body)
        else:
            parsed = urllib.parse.parse_qs(body.decode("utf-8"))
            raw = parsed.get("data", [None])[0]
            if raw is None:
                raise HTTPException(status_code=400, detail="Missing 'data' field")
            payload = json.loads(raw)
    except ValueError as exc:
        # json.JSONDecodeError and UnicodeDecodeError both subclass ValueError.
        raise HTTPException(status_code=400, detail=f"Malformed request body: {exc}") from exc

    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="Request JSON must be an object")
    text_b64 = payload.get("text", "")
    if not text_b64:
        raise HTTPException(status_code=400, detail="No text data found")
    if not isinstance(text_b64, str):
        raise HTTPException(status_code=400, detail="'text' must be a string")
    return text_b64


def _decode_text(text_b64):
    """Decode base64-encoded UTF-8 text, falling back to the raw string.

    Whitespace (e.g. line breaks in MIME-wrapped base64) is ignored. Any other
    character outside the standard base64 alphabet makes the value be treated
    as plain text, rather than being silently dropped before decoding.
    """
    compact = "".join(text_b64.split())
    try:
        return base64.b64decode(compact, validate=True).decode("utf-8")
    except ValueError:
        # binascii.Error (invalid base64, including non-ASCII input) and
        # UnicodeDecodeError (bytes that are not UTF-8) both subclass ValueError.
        return text_b64


def _classify(text):
    """Run the classifier on *text* and build the response payload.

    Blocking; intended to be called from a worker thread.
    """
    with _inference_lock:
        inputs = _tokenizer(
            text, return_tensors="pt", truncation=True, max_length=MAX_LENGTH
        ).to(_device)
        # Grad mode is thread-local, so it must be disabled inside the worker.
        with torch.no_grad():
            logits = _model(**inputs).logits
    probs = F.softmax(logits, dim=-1).squeeze(0).cpu().numpy()
    pred_idx = int(probs.argmax())
    return {
        "emotion": MODEL_LABELS[pred_idx],
        "confidence": round(float(probs[pred_idx]), 4),
        "probabilities": {
            label: round(float(probs[i]), 4) for i, label in enumerate(MODEL_LABELS)
        },
    }


@app.post("/predict_b64")
async def predict_b64(request: Request):
    """Classify the emotion of base64-encoded (or plain) text.

    Returns a dict with the top ``emotion``, its ``confidence`` and the
    per-label ``probabilities`` (all rounded to 4 decimals).

    Raises:
        HTTPException: 400 for malformed or empty input, 503 if the model has
            not been loaded yet, 500 if inference itself fails.
    """
    if _model is None or _tokenizer is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    body = await request.body()
    content_type = request.headers.get("content-type", "")
    text = _decode_text(_extract_text_b64(body, content_type))
    if not text.strip():
        raise HTTPException(status_code=400, detail="No text data found")

    try:
        # Offload blocking tokenization/inference so the event loop stays free.
        return await run_in_threadpool(_classify, text)
    except Exception as e:
        logger.exception("Emotion prediction failed")
        raise HTTPException(status_code=500, detail=str(e)) from e