from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
import torch.nn as nn
import numpy as np
import re

# ===============================
# App Init
# ===============================
app = FastAPI(title="GoEmotions Sentiment API", version="1.0")

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ===============================
# Emotion Mapping
# ===============================
# Index i of the model's 28-way output corresponds to emotion_map[i]
# (standard GoEmotions label order, "neutral" last).
emotion_map = [
    "admiration", "amusement", "anger", "annoyance", "approval", "caring", "confusion",
    "curiosity", "desire", "disappointment", "disapproval", "disgust", "embarrassment",
    "excitement", "fear", "gratitude", "grief", "joy", "love", "nervousness", "optimism",
    "pride", "realization", "relief", "remorse", "sadness", "surprise", "neutral"
]

POSITIVE_EMOTIONS = {
    "admiration", "amusement", "approval", "caring", "desire", "excitement",
    "gratitude", "joy", "love", "optimism", "pride", "relief"
}

NEGATIVE_EMOTIONS = {
    "anger", "annoyance", "disappointment", "disapproval", "disgust", "embarrassment",
    "fear", "grief", "nervousness", "remorse", "sadness"
}

NEUTRAL_EMOTIONS = {
    "confusion", "curiosity", "realization", "surprise", "neutral"
}

# Special token indices. PAD must be 0 to match nn.Embedding(padding_idx=0);
# UNK=1 matches the fallback used when a token is missing from the vocab.
PAD_IDX = 0
UNK_IDX = 1


# ===============================
# Text Utils
# ===============================
def simple_tokenize(text):
    """Whitespace-split a (pre-cleaned) string into tokens."""
    return text.split()


def clean_text(text):
    """Lowercase, strip non-alphanumerics, and collapse whitespace."""
    text = text.lower()
    text = re.sub(r'[^a-z0-9\s]', ' ', text)
    text = re.sub(r'\s+', ' ', text).strip()
    return text


# ===============================
# Model Definition
# ===============================
class GoEmotionsLSTM(nn.Module):
    """Bidirectional LSTM multi-label classifier over the 28 GoEmotions labels.

    Architecture: Embedding -> 2-layer BiLSTM -> Linear on the concatenated
    final forward/backward hidden states of the last layer.
    """

    def __init__(self, vocab_size, embed_dim=200, hidden_dim=256, num_classes=28, num_layers=2):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=0.2,
            bidirectional=True
        )
        # *2 because forward and backward final states are concatenated.
        self.fc = nn.Linear(hidden_dim * 2, num_classes)

    def forward(self, x):
        """x: LongTensor of token indices, shape (batch, seq_len).

        Returns raw logits of shape (batch, num_classes) — apply sigmoid
        externally for multi-label probabilities.
        """
        x = self.embeddings(x)
        _, (h, _) = self.lstm(x)
        # h is (num_layers * 2, batch, hidden): the last two rows are the
        # final layer's forward and backward directions respectively.
        h_forward = h[-2]
        h_backward = h[-1]
        h_cat = torch.cat((h_forward, h_backward), dim=1)
        out = self.fc(h_cat)
        return out


# ===============================
# Globals (Loaded Once)
# ===============================
model = None
vocab = None
max_len = None


# ===============================
# Load Model at Startup
# ===============================
@app.on_event("startup")
def load_model():
    """Load checkpoint (weights + vocab + max_len) once at process start."""
    global model, vocab, max_len
    print("Loading GoEmotions BiLSTM model...")
    # NOTE(review): this checkpoint bundles a Python dict (vocab), so it is
    # unpickled with torch.load — only load checkpoints from trusted sources.
    checkpoint = torch.load("goemotions_bilstm_checkpoint.pth", map_location=DEVICE)
    vocab = checkpoint["vocab"]
    max_len = checkpoint["max_len"]
    model = GoEmotionsLSTM(vocab_size=len(vocab))
    model.load_state_dict(checkpoint["model_state"])
    model.to(DEVICE)
    model.eval()
    print("Model loaded successfully.")


# ===============================
# Request Schema
# ===============================
class PredictRequest(BaseModel):
    text: str


# ===============================
# Status Endpoint
# ===============================
@app.get("/status")
def status():
    """Report whether the model has finished loading."""
    if model is None:
        return {"status": "loading"}
    return {"status": "ok", "model_loaded": True}


# ===============================
# Sentiment Aggregation Logic
# ===============================
def aggregate_sentiment(probs):
    """Collapse 28 per-emotion probabilities into (label, score).

    Sums probabilities within the positive / negative / neutral groups and
    returns the group with the strictly largest mass; ties fall through to
    ("Neutral", neu_score).
    """
    pos_score = 0.0
    neg_score = 0.0
    neu_score = 0.0
    for i, p in enumerate(probs):
        emotion = emotion_map[i]
        if emotion in POSITIVE_EMOTIONS:
            pos_score += p
        elif emotion in NEGATIVE_EMOTIONS:
            neg_score += p
        else:
            neu_score += p
    if pos_score > neg_score and pos_score > neu_score:
        return "Positive", pos_score
    elif neg_score > pos_score and neg_score > neu_score:
        return "Negative", neg_score
    else:
        return "Neutral", neu_score


# ===============================
# Prediction Endpoint
# ===============================
@app.post("/predict")
def predict(req: PredictRequest):
    """Classify the coarse sentiment of req.text.

    Returns {"sentiment": "Positive"|"Negative"|"Neutral",
             "confidence": summed group probability as a percentage}.
    Responds 503 if called before the startup hook has loaded the model.
    """
    # Guard against requests that race the startup event (consistent with
    # the "loading" state already exposed by /status).
    if model is None:
        raise HTTPException(status_code=503, detail="Model is still loading")

    text = clean_text(req.text)
    tokens = simple_tokenize(text)

    # Convert tokens to indices; out-of-vocabulary tokens map to <UNK> = 1.
    seq = [vocab.get(tok, UNK_IDX) for tok in tokens]

    # Pad / truncate to the training-time max_len. Padding uses index 0,
    # matching the embedding's padding_idx (falling back from a "<PAD>"
    # vocab entry if present). The previous vocab[""] lookup raised
    # KeyError for any input shorter than max_len.
    if len(seq) < max_len:
        seq += [vocab.get("<PAD>", PAD_IDX)] * (max_len - len(seq))
    else:
        seq = seq[:max_len]

    x = torch.tensor([seq], dtype=torch.long).to(DEVICE)

    with torch.no_grad():
        logits = model(x)
        # Sigmoid, not softmax: GoEmotions is multi-label.
        probs = torch.sigmoid(logits).squeeze(0).cpu().numpy()

    sentiment, score = aggregate_sentiment(probs)

    return {
        "sentiment": sentiment,
        "confidence": round(float(score) * 100, 2)
    }