"""WorkPulse API: company-culture sentiment analysis over HTTP.

Exposes a Hugging Face text-classification pipeline through FastAPI with a
single-review endpoint (``/predict``) and a batch endpoint (``/batch``).
"""

import os
import re
from typing import Dict, List, NamedTuple, Optional

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import pipeline

app = FastAPI(
    title="WorkPulse API",
    description="Company culture sentiment analysis using fine-tuned DistilBERT on 838K Glassdoor reviews",
    version="1.0.0"
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

MODEL_NAME = os.getenv("HF_MODEL", "Madhuri1003/workpulse-distilbert")

# Populated by the startup hook; endpoints answer 503 while it is still None.
classifier = None

# Request limits and the decision threshold used by the endpoints below.
MAX_TEXT_LENGTH = 1000   # characters accepted by /predict
MAX_BATCH_SIZE = 20      # texts accepted by /batch
MIXED_MARGIN = 0.20      # top-two score gap below which a review is "Mixed"

# Iteration order matters: it is the tie-break order for the top label and
# the key order of `all_scores` in the /predict response.
SENTIMENT_LABELS = ("Negative", "Neutral", "Positive")


@app.on_event("startup")
async def load_model():
    """Load the text-classification pipeline into the module-level `classifier`."""
    global classifier
    print(f"Loading model: {MODEL_NAME}")
    classifier = pipeline(
        "text-classification",
        model=MODEL_NAME,
        tokenizer=MODEL_NAME,
        device=-1
    )
    print("Model loaded successfully.")


# --- Schemas ---

class ReviewRequest(BaseModel):
    """Request body for /predict: one review text."""

    text: str


class BatchReviewRequest(BaseModel):
    """Request body for /batch: a list of review texts."""

    texts: List[str]


class PartResult(BaseModel):
    """Per-part result; mirrors the keys of a `parts_analyzed` entry from /predict."""

    text: str
    sentiment: str
    confidence_pct: str


class SentimentResult(BaseModel):
    """Result for one text in a /batch response."""

    text: str
    sentiment: str
    confidence: float
    confidence_pct: str


class BatchSentimentResult(BaseModel):
    """Response body for /batch: per-text results plus per-sentiment counts."""

    results: List[SentimentResult]
    total: int
    summary: dict


# --- Helpers ---

CONTRAST_KEYWORDS = [" but ", " however ", " although ", " yet ", " though ", " despite "]


def split_on_contrast(text):
    """Split `text` in two around the first usable contrast keyword.

    Keywords are tried in `CONTRAST_KEYWORDS` order, each at its first
    occurrence. A split is accepted only when both sides contain at least
    three whitespace-separated words.

    Returns:
        ``[before, after]`` (both stripped) when a split is accepted,
        otherwise ``[text]`` unchanged.
    """
    for kw in CONTRAST_KEYWORDS:
        # Search the ORIGINAL string case-insensitively. Looking the keyword
        # up in text.lower() and slicing `text` with that index is wrong
        # whenever lowercasing changes the string length (e.g. "İ" becomes
        # two code points), which shifts the split point.
        match = re.search(re.escape(kw), text, flags=re.IGNORECASE | re.ASCII)
        if match is None:
            continue
        halves = [text[:match.start()].strip(), text[match.end():].strip()]
        if all(len(half.split()) >= 3 for half in halves):
            return halves
    return [text]


def score_text(text):
    """Run the classifier on `text` and return ``{label: score}``.

    Scores are rounded to 4 decimal places; up to three labels are requested.
    """
    # truncation=True asks the tokenizer to cut the input to the model's
    # maximum sequence length, so over-long text is scored on its leading
    # portion instead of raising inside the model. A character cap alone
    # does not bound the token count.
    result = classifier(text, top_k=3, truncation=True)
    return {r["label"]: round(r["score"], 4) for r in result}


class _Analysis(NamedTuple):
    """Intermediate scoring result shared by /predict and /batch."""

    parts: List[str]
    part_scores: List[Dict[str, float]]
    final_scores: Dict[str, float]
    top_label: str
    top_score: float
    is_split: bool
    is_mixed: bool


def _require_model():
    """Raise HTTP 503 if the classifier has not been loaded yet."""
    if classifier is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet")


def _analyze(text):
    """Score `text`, part by part if it splits on a contrast keyword.

    Per-label scores are averaged across parts and rounded to 4 decimals.
    The text is flagged as mixed when it was split, or when the gap between
    the two highest averaged scores is below `MIXED_MARGIN`.
    """
    parts = split_on_contrast(text)
    part_scores = [score_text(part) for part in parts]

    final_scores = {
        label: round(sum(s.get(label, 0) for s in part_scores) / len(part_scores), 4)
        for label in SENTIMENT_LABELS
    }
    top_label = max(final_scores, key=final_scores.get)
    ranked = sorted(final_scores.values(), reverse=True)

    is_split = len(parts) > 1
    is_close = (ranked[0] - ranked[1]) < MIXED_MARGIN
    return _Analysis(
        parts=parts,
        part_scores=part_scores,
        final_scores=final_scores,
        top_label=top_label,
        top_score=final_scores[top_label],
        is_split=is_split,
        is_mixed=is_split or is_close,
    )


# --- Endpoints ---

@app.get("/")
def root():
    """Return basic service information and the list of endpoints."""
    return {
        "name": "WorkPulse API",
        "description": "Company culture sentiment analysis",
        "model": MODEL_NAME,
        "endpoints": ["/predict", "/batch", "/health", "/docs"]
    }


@app.get("/health")
def health():
    """Report service status and whether the model has been loaded."""
    return {
        "status": "ok",
        "model_loaded": classifier is not None,
        "model": MODEL_NAME
    }


@app.post("/predict")
def predict(request: ReviewRequest):
    """Classify a single review.

    Raises:
        HTTPException: 503 if the model is not loaded; 400 if the stripped
            text is empty or longer than `MAX_TEXT_LENGTH` characters.
    """
    _require_model()

    text = request.text.strip()
    if not text:
        raise HTTPException(status_code=400, detail="Text cannot be empty")
    if len(text) > MAX_TEXT_LENGTH:
        raise HTTPException(
            status_code=400,
            detail=f"Text too long, max {MAX_TEXT_LENGTH} characters",
        )

    analysis = _analyze(text)

    # Per-part breakdown is only reported when the text was actually split.
    parts_analyzed: Optional[List[dict]] = None
    if analysis.is_split:
        parts_analyzed = [
            {
                "text": part,
                "sentiment": max(scores, key=scores.get),
                "confidence_pct": f"{max(scores.values()) * 100:.1f}%"
            }
            for part, scores in zip(analysis.parts, analysis.part_scores)
        ]

    return {
        "text": text,
        "sentiment": "Mixed" if analysis.is_mixed else analysis.top_label,
        "confidence": analysis.top_score,
        "confidence_pct": f"{analysis.top_score * 100:.1f}%",
        "is_mixed": analysis.is_mixed,
        "all_scores": analysis.final_scores,
        "parts_analyzed": parts_analyzed
    }


@app.post("/batch", response_model=BatchSentimentResult)
def batch_predict(request: BatchReviewRequest):
    """Classify up to `MAX_BATCH_SIZE` reviews and summarise the sentiments.

    Texts that are empty after stripping are skipped, so `total` may be
    smaller than the number of texts submitted.

    Raises:
        HTTPException: 503 if the model is not loaded; 400 if the list is
            empty or holds more than `MAX_BATCH_SIZE` texts.
    """
    _require_model()

    if not request.texts:
        raise HTTPException(status_code=400, detail="Texts list cannot be empty")
    if len(request.texts) > MAX_BATCH_SIZE:
        raise HTTPException(
            status_code=400,
            detail=f"Max {MAX_BATCH_SIZE} texts per batch",
        )

    # NOTE(review): unlike /predict there is no per-text character cap here;
    # over-long texts rely on the tokenizer truncation in score_text.
    texts = [stripped for stripped in (t.strip() for t in request.texts) if stripped]

    summary = {"Positive": 0, "Neutral": 0, "Negative": 0, "Mixed": 0}
    results = []

    for text in texts:
        analysis = _analyze(text)
        sentiment = "Mixed" if analysis.is_mixed else analysis.top_label
        summary[sentiment] = summary.get(sentiment, 0) + 1

        results.append(SentimentResult(
            text=text,
            sentiment=sentiment,
            confidence=analysis.top_score,
            confidence_pct=f"{analysis.top_score * 100:.1f}%"
        ))

    return BatchSentimentResult(
        results=results,
        total=len(results),
        summary=summary
    )