Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, HTTPException, Query | |
| from pydantic import BaseModel | |
| from typing import Dict, Any | |
| import torch | |
| from transformers import ( | |
| AutoTokenizer, | |
| AutoModelForSequenceClassification, | |
| TextClassificationPipeline, | |
| ) | |
# -------- Device selection --------
# Device index handed to the transformers pipelines: 0 = first CUDA GPU, -1 = CPU.
if torch.cuda.is_available():
    DEVICE = 0
else:
    DEVICE = -1
print("API using " + ("GPU" if DEVICE == 0 else "CPU"))
# -------- Models config --------
# Display name -> Hugging Face Hub model id. Insertion order is the order in
# which the models are loaded and reported.
MODELS = dict(
    FABSA="Anudeep-Narala/fabsa-roberta-sentiment",
    MoodMeter="Priyanshuchaudhary2425/MoodMeter-sentimental-analysis",
    Twitter="cardiffnlp/twitter-roberta-base-sentiment-latest",
)
# -------- Load pipelines once --------
def _load_pipeline(model_id: str) -> TextClassificationPipeline:
    """Load *model_id* and wrap it in a text-classification pipeline on DEVICE."""
    # NOTE(review): the slow (Python) tokenizer is forced here; the reason is
    # not visible in this file — confirm before switching to use_fast=True.
    tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
    model = AutoModelForSequenceClassification.from_pretrained(model_id)
    # top_k=None makes the pipeline return a score for every label rather
    # than only the best one.
    return TextClassificationPipeline(
        model=model, tokenizer=tokenizer, device=DEVICE, top_k=None
    )


def _load_all_pipelines() -> Dict[str, TextClassificationPipeline | None]:
    """Load every model in MODELS, mapping a failed load to None instead of raising."""
    loaded: Dict[str, TextClassificationPipeline | None] = {}
    for name, model_id in MODELS.items():
        try:
            loaded[name] = _load_pipeline(model_id)
            print(f"Loaded: {name} -> {model_id}")
        except Exception as e:
            # Best-effort: one broken model must not take the whole API down;
            # run_models reports a None entry as "Error: Model failed to load".
            print(f"Error loading model {name} ({model_id}): {e}")
            loaded[name] = None
    return loaded


# Loaded once at import time. Keeping the loop inside a function means its
# temporaries (including a model whose pipeline construction failed) are not
# pinned as module globals.
pipes: Dict[str, TextClassificationPipeline | None] = _load_all_pipelines()
| def _normalize(scores: Any): | |
| """Normalize HF outputs into negative/neutral/positive probs and pick label.""" | |
| out = {"negative": 0.0, "neutral": 0.0, "positive": 0.0} | |
| for e in scores: | |
| lbl = e["label"].lower() | |
| s = float(e["score"]) | |
| if "neg" in lbl or lbl == "label_0": | |
| out["negative"] = s | |
| elif "neu" in lbl or lbl == "label_1": | |
| out["neutral"] = s | |
| elif "pos" in lbl or lbl == "label_2": | |
| out["positive"] = s | |
| pred = max(out, key=out.get) | |
| return pred, out | |
def run_models(text: str) -> Dict[str, Any]:
    """Run every configured sentiment model on *text* and combine the results.

    Args:
        text: Input text. ``None`` is treated as empty, and surrounding
            whitespace is stripped.

    Returns:
        A dict containing:
        - one ``{"label", "scores"}`` entry per model in ``pipes``;
        - an ``"Ensemble"`` entry (``{"label"}``);
        - ``"text"``, the stripped input.
        Empty input gives ``"N/A"`` everywhere. A model that failed to load or
        raised during inference gets an ``"Error..."`` label instead of
        aborting the whole request.
    """
    text = (text or "").strip()
    res: Dict[str, Any] = {}

    if not text:
        # Keys come from the loaded pipelines so this branch stays in sync with MODELS.
        for name in pipes:
            res[name] = {"label": "N/A", "scores": {}}
        res["Ensemble"] = {"label": "N/A"}
        res["text"] = ""
        return res

    for name, pipe in pipes.items():
        if pipe is None:
            res[name] = {"label": "Error: Model failed to load", "scores": {}}
            continue
        try:
            # Truncate to the encoder limit; longer inputs otherwise make the
            # model fail. 512 assumes BERT/RoBERTa-style encoders; revisit if
            # a model with a smaller context is added.
            out = pipe(text, truncation=True, max_length=512)
            # For a single string the pipeline returns either [[{...}, ...]]
            # or [{...}, ...], depending on how top_k was supplied; accept both.
            raw = out[0] if out and isinstance(out[0], list) else out
            pred, probs = _normalize(raw)
            res[name] = {"label": pred, "scores": probs}
        except Exception as e:
            # Report per-model failures so the remaining models still answer.
            res[name] = {"label": f"Error during inference: {e}", "scores": {}}

    fabsa_label = res.get("FABSA", {}).get("label", "N/A")
    twitter_label = res.get("Twitter", {}).get("label", "N/A")
    both_usable = all(
        lbl != "N/A" and "Error" not in lbl for lbl in (fabsa_label, twitter_label)
    )
    if both_usable:
        # A FABSA "negative" takes precedence; otherwise use the Twitter model's label.
        ensemble = "negative" if fabsa_label == "negative" else twitter_label
    else:
        ensemble = "N/A"
    res["Ensemble"] = {"label": ensemble}
    res["text"] = text
    return res
# -------- FastAPI app --------
# ASGI application object (the one served by uvicorn in the __main__ block).
app = FastAPI(
    version="1.0.0",
    title="Mental Health Sentiment API",
)
class PredictIn(BaseModel):
    """JSON body accepted by the ``predict`` endpoint."""

    # Text to classify; run_models strips it and returns "N/A" results when it is empty.
    text: str
@app.get("/health")
def health():
    """Liveness probe that always returns ``{"status": "ok"}``.

    It does not check whether the models loaded; per-model load failures are
    reported in the ``run_models`` results instead.
    """
    return {"status": "ok"}
@app.post("/predict")
def predict(body: PredictIn):
    """Classify ``body.text`` with every model; see ``run_models`` for the response shape.

    Raises:
        HTTPException: 500, carrying the error message, if ``run_models`` raises.
    """
    try:
        return run_models(body.text)
    except Exception as e:
        # Chain the cause so server-side logs keep the original traceback.
        raise HTTPException(status_code=500, detail=str(e)) from e
# Optional: GET /predict?text=...
@app.get("/predict")
def predict_get(text: str = Query("", description="Input text to analyze")):
    """Query-string variant of ``predict``; a missing ``text`` parameter is treated as empty.

    Raises:
        HTTPException: 500, carrying the error message, if ``run_models`` raises.
    """
    try:
        return run_models(text)
    except Exception as e:
        # Chain the cause so server-side logs keep the original traceback.
        raise HTTPException(status_code=500, detail=str(e)) from e
if __name__ == "__main__":
    import uvicorn

    # Pass the app object, not the "app:app" import string. When this file
    # runs as a script, an import string makes uvicorn import the module a
    # second time under the name "app", which re-runs the model loading
    # above. Passing the object is valid because reload is off.
    uvicorn.run(app, host="127.0.0.1", port=8000, reload=False)