Spaces:
Sleeping
Sleeping
| import base64 | |
| import json | |
| import os | |
| import torch | |
| import torch.nn.functional as F | |
| from fastapi import FastAPI, HTTPException, Request | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from transformers import ( | |
| AutoTokenizer, | |
| AutoModelForSequenceClassification, | |
| ) | |
# FastAPI application instance for the service.
app = FastAPI(title="Text Emotion Recognition")

# Open CORS policy: any origin, method and header is accepted.
# Credentials are deliberately disabled: a wildcard origin combined with
# allow_credentials=True lets any website issue credentialed requests, and
# the FastAPI/Starlette CORS docs require explicit origins in that case.
# None of the handlers visible in this file read cookies or auth headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Emotion class names. predict_b64 indexes this list with the argmax of the
# model's output probabilities, so the order must match the classifier head.
MODEL_LABELS = ["Anger", "Disgust", "Fear", "Happy", "Neutral", "Sad", "Surprise"]

# Model state, populated by load_model(); all three stay None until it runs.
_model = None
_tokenizer = None
_device = None
def load_model():
    """Load the tokenizer and classifier into the module-level globals.

    Both are read from the directory that contains this source file. The
    model is moved to CUDA when available (otherwise CPU) and switched to
    eval mode. Sets ``_device``, ``_tokenizer`` and ``_model``, then prints
    a one-line confirmation.
    """
    global _model, _tokenizer, _device

    if torch.cuda.is_available():
        _device = "cuda"
    else:
        _device = "cpu"

    # The weights and tokenizer files live next to this script.
    here = os.path.dirname(os.path.abspath(__file__))

    _tokenizer = AutoTokenizer.from_pretrained(here)

    classifier = AutoModelForSequenceClassification.from_pretrained(here)
    classifier = classifier.to(_device)
    _model = classifier.eval()

    print(f"[INFO] Model loaded on {_device}")
async def startup():
    """Load the model by delegating to ``load_model()``."""
    # NOTE(review): presumably meant to run as the app's startup hook, but no
    # decorator or registration call is visible in this view. Confirm it is
    # wired to the application's startup event.
    load_model()
async def health():
    """Return a small status dict for health checks.

    Keys: ``status`` (always ``"ok"``), ``model_loaded`` (whether the
    ``_model`` global has been populated) and ``device`` (the current value
    of ``_device``, which is ``None`` before the model is loaded).
    """
    # NOTE(review): no route decorator is visible on this handler in this
    # view. Confirm it is registered on the app.
    model_ready = _model is not None
    return dict(status="ok", model_loaded=model_ready, device=_device)
def _extract_text_field(body, content_type):
    """Return the raw ``text`` value from a JSON or form-encoded request body.

    Two body shapes are accepted:

    * a JSON object (chosen when the content type mentions
      ``application/json`` or the body starts with ``{``);
    * a URL-encoded form whose ``data`` field holds that JSON object.

    Args:
        body: Raw request body as bytes.
        content_type: Value of the ``Content-Type`` header ("" if absent).

    Returns:
        The value stored under ``"text"``, or ``""`` when the key is missing.

    Raises:
        HTTPException: 400 when the body is not valid UTF-8 or JSON, the form
            has no ``data`` field, or the JSON payload is not an object.
    """
    # Kept as a local import: only the form-encoded branch needs it.
    import urllib.parse

    if "application/json" in content_type or body.startswith(b"{"):
        raw = body
    else:
        try:
            form = urllib.parse.parse_qs(body.decode())
        except UnicodeDecodeError as exc:
            raise HTTPException(
                status_code=400, detail="Request body is not valid UTF-8"
            ) from exc
        raw = form.get("data", [None])[0]
        if raw is None:
            raise HTTPException(status_code=400, detail="Missing 'data' field")

    try:
        payload = json.loads(raw)
    except ValueError as exc:
        # JSONDecodeError and UnicodeDecodeError are both ValueError subclasses.
        raise HTTPException(
            status_code=400, detail="Request body is not valid JSON"
        ) from exc

    if not isinstance(payload, dict):
        raise HTTPException(status_code=400, detail="JSON payload must be an object")

    return payload.get("text", "")


def _decode_text(text_field):
    """Return the text to classify, base64-decoding it when possible.

    If ``text_field`` decodes as base64 to valid UTF-8, the decoded string is
    returned. Otherwise the value is assumed to be plain text and is returned
    unchanged (best-effort behaviour kept from the original handler).
    """
    try:
        return base64.b64decode(text_field).decode("utf-8")
    except ValueError:
        # Covers binascii.Error (bad base64), UnicodeDecodeError (decoded bytes
        # are not UTF-8) and non-ASCII input strings.
        return text_field


async def predict_b64(request: Request):
    """Classify the emotion of the text carried in the request body.

    The body is either a JSON object ``{"text": ...}`` or a URL-encoded form
    whose ``data`` field contains that JSON object. ``text`` may be
    base64-encoded UTF-8 or plain text. Input is truncated to 512 tokens by
    the tokenizer call.

    Args:
        request: Incoming request; its raw body and ``Content-Type`` header
            are read.

    Returns:
        A dict with ``emotion`` (the top label from ``MODEL_LABELS``),
        ``confidence`` (its probability, rounded to 4 decimals) and
        ``probabilities`` (label -> rounded probability for every label).

    Raises:
        HTTPException: 400 for malformed or missing input, 503 when the model
            has not been loaded, 500 for any other failure during inference.
    """
    # NOTE(review): no route decorator is visible on this handler in this
    # view. Confirm it is registered on the app.
    try:
        body = await request.body()
        content_type = request.headers.get("content-type", "")
        text_field = _extract_text_field(body, content_type)

        if not text_field:
            raise HTTPException(status_code=400, detail="No text data found")
        if not isinstance(text_field, str):
            raise HTTPException(status_code=400, detail="'text' must be a string")

        # Fail with a clear status instead of "'NoneType' object is not callable".
        if _model is None or _tokenizer is None:
            raise HTTPException(status_code=503, detail="Model not loaded")

        text = _decode_text(text_field)

        inputs = _tokenizer(
            text, return_tensors="pt", truncation=True, max_length=512
        ).to(_device)
        with torch.no_grad():
            outputs = _model(**inputs)

        # A single input yields a batch of one; drop that leading dimension.
        probs = F.softmax(outputs.logits, dim=-1).squeeze(0)
        probs_np = probs.cpu().numpy()
        pred_idx = int(probs_np.argmax())

        return {
            "emotion": MODEL_LABELS[pred_idx],
            "confidence": round(float(probs_np[pred_idx]), 4),
            "probabilities": {
                label: round(float(probs_np[i]), 4)
                for i, label in enumerate(MODEL_LABELS)
            },
        }
    except HTTPException:
        # Deliberate client/service errors pass through untouched.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e