# workpulse-api / app.py
# Last commit: "Update app.py" by Madhuri1003 (b5d83b3, verified)
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import pipeline
from typing import List, Optional
import os
# FastAPI application object; title, description and version are the
# service metadata handed to FastAPI.
app = FastAPI(
    title="WorkPulse API",
    description="Company culture sentiment analysis using fine-tuned DistilBERT on 838K Glassdoor reviews",
    version="1.0.0"
)
# CORS: any origin, method and header is accepted, with credentials disabled.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Model identifier to load; the HF_MODEL environment variable overrides the default.
MODEL_NAME = os.getenv("HF_MODEL", "Madhuri1003/workpulse-distilbert")
# Text-classification pipeline; remains None until load_model() has run.
classifier = None
@app.on_event("startup")
async def load_model():
    """Build the text-classification pipeline and publish it as ``classifier``.

    Registered as a FastAPI startup handler. The model and tokenizer are both
    loaded from ``MODEL_NAME``; the result is stored in the module-level
    ``classifier`` that the endpoints check before serving requests.
    """
    global classifier

    print(f"Loading model: {MODEL_NAME}")
    pipeline_options = {
        "model": MODEL_NAME,
        "tokenizer": MODEL_NAME,
        "device": -1,  # -1 is the transformers convention for running on CPU
    }
    classifier = pipeline("text-classification", **pipeline_options)
    print("Model loaded successfully.")
# --- Schemas ---
# Request body for POST /predict.
class ReviewRequest(BaseModel):
    # Raw review text; the endpoint strips it and rejects empty or >1000-char input.
    text: str
# Request body for POST /batch.
class BatchReviewRequest(BaseModel):
    # Review texts to classify; the endpoint accepts at most 20 per request.
    texts: List[str]
# Sentiment of one segment of a review. The field names match the entries of
# "parts_analyzed" built by /predict; the model itself is not referenced in
# the visible code.
class PartResult(BaseModel):
    # Segment text.
    text: str
    # Label assigned to the segment.
    sentiment: str
    # Confidence rendered as a percentage string, e.g. "87.5%".
    confidence_pct: str
# Per-text entry of the /batch response.
class SentimentResult(BaseModel):
    # The stripped input text.
    text: str
    # "Positive", "Neutral", "Negative" or "Mixed".
    sentiment: str
    # Highest averaged label score.
    confidence: float
    # The same score formatted as a percentage string with one decimal.
    confidence_pct: str
# Response body of POST /batch.
class BatchSentimentResult(BaseModel):
    # One result per classified text, in input order.
    results: List[SentimentResult]
    # Number of entries in ``results``.
    total: int
    # Count of results per sentiment label.
    summary: dict
# --- Helpers ---
# Contrast markers, tried in this order; the surrounding spaces make them
# match whole words only.
CONTRAST_KEYWORDS = [" but ", " however ", " although ", " yet ", " though ", " despite "]


def _lower_keep_length(text: str) -> str:
    """Lowercase ``text`` per character without ever changing its length."""
    # A few characters (e.g. U+0130) lowercase to more than one code point;
    # leaving those untouched keeps indexes aligned with the original string.
    return "".join(ch.lower() if len(ch.lower()) == 1 else ch for ch in text)


def split_on_contrast(text: str) -> List[str]:
    """Split a review into two parts around a contrast keyword.

    Keywords from ``CONTRAST_KEYWORDS`` are tried in list order, matched
    case-insensitively at their first occurrence. A split is accepted only
    when both resulting parts contain at least three words; otherwise the
    next keyword is tried.

    Args:
        text: The review text.

    Returns:
        ``[before, after]`` (both stripped) for the first acceptable split,
        or ``[text]`` unchanged when no keyword yields one.
    """
    # Lowercase once, in a way that keeps positions valid for slicing ``text``.
    lowered = _lower_keep_length(text)
    for keyword in CONTRAST_KEYWORDS:
        position = lowered.find(keyword)
        if position == -1:
            continue
        before = text[:position].strip()
        after = text[position + len(keyword):].strip()
        # Require real clauses on both sides, not one- or two-word fragments.
        if len(before.split()) >= 3 and len(after.split()) >= 3:
            return [before, after]
    return [text]
def score_text(text):
    """Score ``text`` with the loaded classifier.

    Asks the pipeline for its top 3 labels and returns a dict mapping each
    label to its score rounded to four decimals. Callers look up the keys
    "Negative", "Neutral" and "Positive" in the result.

    The module-level ``classifier`` must already be loaded; callers check
    this before calling.
    """
    # truncation=True clips the input to the tokenizer's maximum length, so
    # long (or token-dense) text is scored on its prefix instead of raising
    # inside the model.
    predictions = classifier(text, top_k=3, truncation=True)
    return {item["label"]: round(item["score"], 4) for item in predictions}
# --- Endpoints ---
@app.get("/")
def root():
    """Return static service metadata: name, description, model id and endpoint list."""
    return dict(
        name="WorkPulse API",
        description="Company culture sentiment analysis",
        model=MODEL_NAME,
        endpoints=["/predict", "/batch", "/health", "/docs"],
    )
@app.get("/health")
def health():
    """Report liveness plus whether the classifier has been loaded."""
    model_ready = classifier is not None
    return dict(status="ok", model_loaded=model_ready, model=MODEL_NAME)
@app.post("/predict")
def predict(request: ReviewRequest):
    """Classify the sentiment of a single review.

    The stripped text is split with ``split_on_contrast``, every part is
    scored with ``score_text`` and the per-label scores are averaged over the
    parts. The review is reported as "Mixed" when it was split, or when the
    two highest averaged scores are less than 0.20 apart.

    Returns:
        A dict with the text, the sentiment, the top averaged score (as a
        float and as a percentage string), the mixed flag, all averaged
        scores, and, only when the text was split, a per-part breakdown.

    Raises:
        HTTPException: 503 when the model is not loaded; 400 when the text is
            empty after stripping or longer than 1000 characters.
    """
    if classifier is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet")

    review = request.text.strip()
    if not review:
        raise HTTPException(status_code=400, detail="Text cannot be empty")
    if len(review) > 1000:
        raise HTTPException(status_code=400, detail="Text too long, max 1000 characters")

    segments = split_on_contrast(review)
    was_split = len(segments) > 1
    segment_scores = [score_text(segment) for segment in segments]

    # Mean score per label; the label order also decides ties in max() below.
    segment_count = len(segment_scores)
    averaged = {
        label: round(
            sum(scores.get(label, 0) for scores in segment_scores) / segment_count, 4
        )
        for label in ("Negative", "Neutral", "Positive")
    }

    best_label = max(averaged, key=averaged.get)
    best_score = averaged[best_label]

    # A narrow margin between the two strongest labels also counts as mixed.
    ranked = sorted(averaged.values(), reverse=True)
    narrow_margin = (ranked[0] - ranked[1]) < 0.20
    mixed = was_split or narrow_margin

    # The per-part breakdown is only reported for reviews that were split.
    breakdown = None
    if was_split:
        breakdown = []
        for segment, scores in zip(segments, segment_scores):
            breakdown.append({
                "text": segment,
                "sentiment": max(scores, key=scores.get),
                "confidence_pct": f"{max(scores.values()) * 100:.1f}%",
            })

    return {
        "text": review,
        "sentiment": "Mixed" if mixed else best_label,
        "confidence": best_score,
        "confidence_pct": f"{best_score * 100:.1f}%",
        "is_mixed": mixed,
        "all_scores": averaged,
        "parts_analyzed": breakdown,
    }
@app.post("/batch", response_model=BatchSentimentResult)
def batch_predict(request: BatchReviewRequest):
    """Classify up to 20 reviews in one request.

    Blank entries are dropped. Each remaining text is split with
    ``split_on_contrast``, its parts are scored with ``score_text`` and the
    per-label scores are averaged. A text is labelled "Mixed" when it was
    split, or when its two highest averaged scores are less than 0.20 apart.

    Returns:
        BatchSentimentResult with one entry per non-blank text (input order),
        the number of entries, and a count of entries per sentiment.

    Raises:
        HTTPException: 503 when the model is not loaded; 400 when the list is
            empty, contains only blank texts, holds more than 20 texts, or
            any text exceeds 1000 characters.
    """
    if classifier is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet")
    if not request.texts:
        raise HTTPException(status_code=400, detail="Texts list cannot be empty")
    if len(request.texts) > 20:
        raise HTTPException(status_code=400, detail="Max 20 texts per batch")

    texts = [stripped for stripped in (raw.strip() for raw in request.texts) if stripped]
    # A batch made only of blank strings has nothing to classify; report it
    # as a client error instead of returning an empty success.
    if not texts:
        raise HTTPException(status_code=400, detail="Texts list cannot be empty")
    # Apply the same per-text size limit that /predict enforces.
    if any(len(text) > 1000 for text in texts):
        raise HTTPException(status_code=400, detail="Text too long, max 1000 characters")

    summary = {"Positive": 0, "Neutral": 0, "Negative": 0, "Mixed": 0}
    results = []
    for text in texts:
        parts = split_on_contrast(text)
        is_split = len(parts) > 1
        part_scores = [score_text(part) for part in parts]

        # Mean score per label; the label order also decides ties in max().
        final_scores = {
            label: round(
                sum(scores.get(label, 0) for scores in part_scores) / len(part_scores), 4
            )
            for label in ("Negative", "Neutral", "Positive")
        }
        top_label = max(final_scores, key=final_scores.get)
        top_score = final_scores[top_label]

        ranked = sorted(final_scores.values(), reverse=True)
        is_mixed = is_split or (ranked[0] - ranked[1] < 0.20)
        sentiment = "Mixed" if is_mixed else top_label

        summary[sentiment] = summary.get(sentiment, 0) + 1
        results.append(SentimentResult(
            text=text,
            sentiment=sentiment,
            confidence=top_score,
            confidence_pct=f"{top_score * 100:.1f}%",
        ))

    return BatchSentimentResult(
        results=results,
        total=len(results),
        summary=summary,
    )