# Provenance: uploaded to Hugging Face by user junaid17 (commit da6e212, verified).
import re

import numpy as np
import torch
import torch.nn as nn
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
# ===============================
# App Init
# ===============================
app = FastAPI(title="GoEmotions Sentiment API", version="1.0")
# Use GPU when available; all tensors and the model are moved to this device.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ===============================
# Emotion Mapping
# ===============================
# Index-aligned with the model's 28 output logits: emotion_map[i] is the label
# for output i. Do NOT reorder — the ordering is fixed by the trained checkpoint.
emotion_map = [
"admiration","amusement","anger","annoyance","approval","caring","confusion",
"curiosity","desire","disappointment","disapproval","disgust","embarrassment",
"excitement","fear","gratitude","grief","joy","love","nervousness","optimism",
"pride","realization","relief","remorse","sadness","surprise","neutral"
]
# Coarse sentiment buckets: every emotion in emotion_map falls into exactly one
# of these three sets (anything not positive/negative is treated as neutral by
# aggregate_sentiment's else-branch).
POSITIVE_EMOTIONS = {
"admiration","amusement","approval","caring","desire","excitement",
"gratitude","joy","love","optimism","pride","relief"
}
NEGATIVE_EMOTIONS = {
"anger","annoyance","disappointment","disapproval","disgust","embarrassment",
"fear","grief","nervousness","remorse","sadness"
}
NEUTRAL_EMOTIONS = {
"confusion","curiosity","realization","surprise","neutral"
}
# ===============================
# Text Utils
# ===============================
def simple_tokenize(text):
    """Whitespace-tokenize *text*; returns [] for an empty or blank string."""
    return text.split()
def clean_text(text):
    """Normalize raw user text for tokenization.

    Lowercases, replaces every character that is not a-z/0-9/whitespace
    with a space, then collapses whitespace runs and trims the ends.
    """
    lowered = text.lower()
    alnum_only = re.sub(r'[^a-z0-9\s]', ' ', lowered)
    # join/split collapses any whitespace run to one space and strips ends,
    # equivalent to re.sub(r'\s+', ' ', ...).strip()
    return ' '.join(alnum_only.split())
# ===============================
# Model Definition
# ===============================
class GoEmotionsLSTM(nn.Module):
def __init__(self, vocab_size, embed_dim=200, hidden_dim=256, num_classes=28, num_layers=2):
super().__init__()
self.embeddings = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
self.lstm = nn.LSTM(
input_size=embed_dim,
hidden_size=hidden_dim,
num_layers=num_layers,
batch_first=True,
dropout=0.2,
bidirectional=True
)
self.fc = nn.Linear(hidden_dim * 2, num_classes)
def forward(self, x):
x = self.embeddings(x)
_, (h, _) = self.lstm(x)
h_forward = h[-2]
h_backward = h[-1]
h_cat = torch.cat((h_forward, h_backward), dim=1)
out = self.fc(h_cat)
return out
# ===============================
# Globals (Loaded Once)
# ===============================
# Populated once by load_model() at application startup; None until then.
model = None    # GoEmotionsLSTM instance, moved to DEVICE and set to eval mode
vocab = None    # token -> integer-index mapping loaded from the checkpoint
max_len = None  # fixed sequence length (pad/truncate target) from the checkpoint
# ===============================
# Load Model at Startup
# ===============================
@app.on_event("startup")
def load_model():
    """Load the BiLSTM checkpoint once at startup and publish it via module globals.

    The checkpoint file is expected to bundle the vocabulary, the training
    sequence length, and the model weights under the keys "vocab",
    "max_len", and "model_state".
    """
    global model, vocab, max_len
    print("Loading GoEmotions BiLSTM model...")
    ckpt = torch.load("goemotions_bilstm_checkpoint.pth", map_location=DEVICE)
    vocab = ckpt["vocab"]
    max_len = ckpt["max_len"]
    net = GoEmotionsLSTM(vocab_size=len(vocab))
    net.load_state_dict(ckpt["model_state"])
    net.to(DEVICE)
    net.eval()  # inference only: disables dropout
    model = net
    print("Model loaded successfully.")
# ===============================
# Request Schema
# ===============================
class PredictRequest(BaseModel):
    """Request body for POST /predict."""
    # Raw user text to classify; cleaning/tokenization happens server-side.
    text: str
# ===============================
# Status Endpoint
# ===============================
@app.get("/status")
def status():
    """Readiness probe: reports 'loading' until the startup hook sets `model`."""
    if model is not None:
        return {"status": "ok", "model_loaded": True}
    return {"status": "loading"}
# ===============================
# Sentiment Aggregation Logic
# ===============================
def aggregate_sentiment(probs):
    """Collapse per-emotion probabilities into one coarse sentiment.

    `probs` is a sequence of 28 scores index-aligned with `emotion_map`.
    Each score is added to the Positive, Negative, or Neutral bucket
    (anything not in the positive/negative sets counts as neutral).
    Returns (label, score) for the bucket with the strictly largest sum;
    any tie falls back to Neutral.
    """
    pos_total = neg_total = neu_total = 0.0
    for idx, p in enumerate(probs):
        emotion = emotion_map[idx]
        if emotion in POSITIVE_EMOTIONS:
            pos_total += p
        elif emotion in NEGATIVE_EMOTIONS:
            neg_total += p
        else:
            neu_total += p
    if pos_total > neg_total and pos_total > neu_total:
        return "Positive", pos_total
    if neg_total > pos_total and neg_total > neu_total:
        return "Negative", neg_total
    return "Neutral", neu_total
# ===============================
# Prediction Endpoint
# ===============================
@app.post("/predict")
def predict(req: PredictRequest):
    """Classify the sentiment of `req.text`.

    Pipeline: clean -> whitespace tokenize -> map tokens to vocab indices
    (unknown tokens map to index 1, i.e. <UNK>) -> pad/truncate to `max_len`
    -> BiLSTM -> per-emotion sigmoid probabilities -> aggregate into
    Positive/Negative/Neutral.

    Returns {"sentiment": label, "confidence": percent}. NOTE(review): the
    "confidence" is a SUM of independent per-emotion sigmoid outputs scaled
    by 100 — it is not a calibrated probability and can exceed 100.

    Raises HTTPException(503) when the model has not finished loading.
    """
    # Robustness fix: before the startup hook completes, the module globals
    # are None and the original code crashed with an opaque 500; fail with an
    # explicit 503 instead so clients know to retry.
    if model is None or vocab is None or max_len is None:
        raise HTTPException(status_code=503, detail="Model is still loading; try again shortly.")
    text = clean_text(req.text)
    tokens = simple_tokenize(text)
    # Convert tokens to indices; out-of-vocabulary tokens -> <UNK> = 1.
    seq = [vocab.get(tok, 1) for tok in tokens]
    # Pad with <PAD> (or truncate) to the fixed training length.
    if len(seq) < max_len:
        seq += [vocab["<PAD>"]] * (max_len - len(seq))
    else:
        seq = seq[:max_len]
    x = torch.tensor([seq], dtype=torch.long).to(DEVICE)
    with torch.no_grad():
        logits = model(x)
        # Multi-label head: independent sigmoid per emotion, not softmax.
        probs = torch.sigmoid(logits).squeeze(0).cpu().numpy()
    sentiment, score = aggregate_sentiment(probs)
    return {
        "sentiment": sentiment,
        "confidence": round(float(score) * 100, 2)
    }