# Spaces: Running
import os
import re
from typing import List, Optional

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import pipeline
# Application object for the WorkPulse service; the keyword arguments are the
# API metadata handed to the FastAPI constructor.
app = FastAPI(
    title="WorkPulse API",
    description="Company culture sentiment analysis using fine-tuned DistilBERT on 838K Glassdoor reviews",
    version="1.0.0"
)

# CORS is wide open (any origin, method and header) with credentials disabled.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Model identifier: taken from the HF_MODEL environment variable when set,
# otherwise the default below.
MODEL_NAME = os.getenv("HF_MODEL", "Madhuri1003/workpulse-distilbert")

# Holds the loaded pipeline. It stays None until load_model() assigns it; the
# prediction handlers answer 503 while it is None.
classifier = None
# Registered as a startup handler: nothing else in this module calls
# load_model(), so without the registration `classifier` would stay None and
# every prediction request would be answered with 503.
@app.on_event("startup")
async def load_model():
    """Load the text-classification pipeline into the module-level ``classifier``.

    The model and tokenizer are both resolved from ``MODEL_NAME``. Progress is
    reported on stdout before and after the load.
    """
    global classifier
    print(f"Loading model: {MODEL_NAME}")
    classifier = pipeline(
        "text-classification",
        model=MODEL_NAME,
        tokenizer=MODEL_NAME,
        device=-1,  # presumably selects CPU inference - confirm against the transformers docs
    )
    print("Model loaded successfully.")
# --- Schemas ---
class ReviewRequest(BaseModel):
    """Request body carrying a single review text to classify."""

    # Raw review text; predict() strips it and rejects empty input or input
    # longer than 1000 characters.
    text: str
class BatchReviewRequest(BaseModel):
    """Request body carrying several review texts to classify in one call."""

    # Review texts; batch_predict() rejects an empty list and lists with more
    # than 20 entries, and drops entries that are blank after stripping.
    texts: List[str]
class PartResult(BaseModel):
    """Sentiment of one fragment of a review.

    NOTE(review): not referenced by the code visible in this file; the fields
    appear to mirror the entries of predict()'s ``parts_analyzed`` list -
    confirm before relying on it.
    """

    text: str  # the fragment text
    sentiment: str  # sentiment label for the fragment
    confidence_pct: str  # confidence rendered as a percentage string
class SentimentResult(BaseModel):
    """Classification outcome for one review, as produced by batch_predict()."""

    text: str  # the (stripped) review text that was scored
    sentiment: str  # "Mixed" or the top-scoring label
    confidence: float  # top averaged label score
    confidence_pct: str  # the same score formatted as a percentage, one decimal place
class BatchSentimentResult(BaseModel):
    """Response of the batch endpoint: per-review results plus aggregate counts."""

    results: List[SentimentResult]  # one entry per non-blank input text
    total: int  # number of entries in ``results``
    summary: dict  # count of results per sentiment label (including "Mixed")
# --- Helpers ---
# Contrast markers, tried in this order. The surrounding spaces make them match
# only as space-delimited words.
CONTRAST_KEYWORDS = [" but ", " however ", " although ", " yet ", " though ", " despite "]


def split_on_contrast(text: str) -> List[str]:
    """Split *text* into two clauses around the first usable contrast keyword.

    Keywords are tried in ``CONTRAST_KEYWORDS`` order and matched
    case-insensitively. A keyword is usable only when both the text before it
    and the text after it contain at least three words; otherwise the next
    keyword is tried.

    Args:
        text: The review text to split.

    Returns:
        ``[before, after]`` (both stripped) when a usable keyword is found,
        otherwise ``[text]`` unchanged.
    """
    for keyword in CONTRAST_KEYWORDS:
        # Match on the original string instead of searching text.lower():
        # lower() can change the string's length (e.g. "İ" becomes two code
        # points), which would make an index found in the lowered copy point
        # at the wrong place in the original.
        match = re.search(re.escape(keyword), text, re.IGNORECASE)
        if match is None:
            continue
        before = text[:match.start()].strip()
        after = text[match.end():].strip()
        # Very short fragments are not scored on their own.
        if len(before.split()) >= 3 and len(after.split()) >= 3:
            return [before, after]
    return [text]
def score_text(text):
    """Score *text* with the module-level classifier.

    Calls ``classifier(text, top_k=3)`` and returns a dict mapping each
    returned label to its score rounded to four decimal places.
    """
    predictions = classifier(text, top_k=3)
    label_scores = {}
    for prediction in predictions:
        label_scores[prediction["label"]] = round(prediction["score"], 4)
    return label_scores
# --- Endpoints ---
# Route registration: without it this handler is never exposed by the app.
@app.get("/")
def root():
    """Return static service metadata: name, description, model id and endpoint paths."""
    return {
        "name": "WorkPulse API",
        "description": "Company culture sentiment analysis",
        "model": MODEL_NAME,
        "endpoints": ["/predict", "/batch", "/health", "/docs"],
    }
# Route registration: "/health" is the path advertised by the root() payload;
# without the decorator this handler is never exposed by the app.
@app.get("/health")
def health():
    """Report service status and whether the classifier has been loaded."""
    return {
        "status": "ok",
        # False until load_model() has assigned the pipeline.
        "model_loaded": classifier is not None,
        "model": MODEL_NAME,
    }
# Route registration: "/predict" is the path advertised by the root() payload;
# without the decorator this handler is never exposed by the app.
@app.post("/predict")
def predict(request: ReviewRequest):
    """Classify a single review.

    The text is split on a contrast keyword when possible, each part is scored
    separately, and the per-label scores are averaged across parts.

    Args:
        request: Body containing the review text.

    Returns:
        A dict with the stripped text, the final sentiment ("Mixed" or the
        top averaged label), the top averaged score as a float and as a
        percentage string, the ``is_mixed`` flag, the averaged scores, and,
        when the text was split, a per-part breakdown (otherwise ``None``).

    Raises:
        HTTPException: 503 if the model is not loaded; 400 if the text is
            empty after stripping or longer than 1000 characters.
    """
    if classifier is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet")

    text = request.text.strip()
    if not text:
        raise HTTPException(status_code=400, detail="Text cannot be empty")
    if len(text) > 1000:
        raise HTTPException(status_code=400, detail="Text too long, max 1000 characters")

    parts = split_on_contrast(text)
    is_split = len(parts) > 1

    # Score each part individually, then average every label across the parts.
    all_scores = [score_text(part) for part in parts]
    final_scores = {
        label: round(sum(scores.get(label, 0) for scores in all_scores) / len(all_scores), 4)
        for label in ("Negative", "Neutral", "Positive")
    }

    top_label = max(final_scores, key=final_scores.get)
    top_score = final_scores[top_label]

    # The result is "Mixed" when the text was split or when the two best
    # averaged scores are less than 0.20 apart.
    # NOTE(review): a split review is reported as Mixed even if both parts get
    # the same label - confirm this is intended.
    sorted_vals = sorted(final_scores.values(), reverse=True)
    is_close = (sorted_vals[0] - sorted_vals[1]) < 0.20
    is_mixed = is_split or is_close

    # Per-part breakdown, only when the text was actually split
    # (is_split already implies is_mixed).
    parts_analyzed = None
    if is_split:
        parts_analyzed = [
            {
                "text": part,
                "sentiment": max(scores, key=scores.get),
                "confidence_pct": f"{max(scores.values()) * 100:.1f}%",
            }
            for part, scores in zip(parts, all_scores)
        ]

    return {
        "text": text,
        "sentiment": "Mixed" if is_mixed else top_label,
        "confidence": top_score,
        "confidence_pct": f"{top_score * 100:.1f}%",
        "is_mixed": is_mixed,
        "all_scores": final_scores,
        "parts_analyzed": parts_analyzed,
    }
# Route registration: "/batch" is the path advertised by the root() payload;
# without the decorator this handler is never exposed by the app.
@app.post("/batch")
def batch_predict(request: BatchReviewRequest):
    """Classify up to 20 reviews in one call.

    Entries that are blank after stripping are dropped. Every remaining text
    is split on a contrast keyword when possible, each part is scored, and the
    per-label scores are averaged across parts.

    Args:
        request: Body containing the list of review texts.

    Returns:
        A ``BatchSentimentResult`` with one ``SentimentResult`` per non-blank
        text, the number of results, and a count of results per sentiment.

    Raises:
        HTTPException: 503 if the model is not loaded; 400 if the list is
            empty, contains only blank entries, or has more than 20 entries.
    """
    if classifier is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet")
    if not request.texts:
        raise HTTPException(status_code=400, detail="Texts list cannot be empty")
    if len(request.texts) > 20:
        raise HTTPException(status_code=400, detail="Max 20 texts per batch")

    texts = [stripped for stripped in (t.strip() for t in request.texts) if stripped]
    # Re-check after filtering: a batch made only of blank strings passes the
    # first emptiness check but leaves nothing to classify.
    if not texts:
        raise HTTPException(status_code=400, detail="Texts list cannot be empty")

    # NOTE(review): unlike predict(), no per-text length cap is applied here -
    # confirm whether the 1000-character limit should also hold for batches.
    summary = {"Positive": 0, "Neutral": 0, "Negative": 0, "Mixed": 0}
    results = []
    for text in texts:
        parts = split_on_contrast(text)
        is_split = len(parts) > 1

        part_scores = [score_text(part) for part in parts]
        final_scores = {
            label: round(sum(scores.get(label, 0) for scores in part_scores) / len(part_scores), 4)
            for label in ("Negative", "Neutral", "Positive")
        }

        top_label = max(final_scores, key=final_scores.get)
        top_score = final_scores[top_label]

        # "Mixed" when the text was split or when the two best averaged scores
        # are less than 0.20 apart.
        sorted_vals = sorted(final_scores.values(), reverse=True)
        is_mixed = is_split or (sorted_vals[0] - sorted_vals[1] < 0.20)
        sentiment = "Mixed" if is_mixed else top_label

        summary[sentiment] = summary.get(sentiment, 0) + 1
        results.append(SentimentResult(
            text=text,
            sentiment=sentiment,
            confidence=top_score,
            confidence_pct=f"{top_score * 100:.1f}%",
        ))

    return BatchSentimentResult(
        results=results,
        total=len(results),
        summary=summary,
    )