from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import pipeline
from typing import List, Optional
import os
# ASGI application object; title, description and version are handed to
# FastAPI as application metadata.
app = FastAPI(
    title="WorkPulse API",
    description="Company culture sentiment analysis using fine-tuned DistilBERT on 838K Glassdoor reviews",
    version="1.0.0",
)

# CORS policy: every origin, method and header is accepted, with credentials
# disabled.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Model identifier; the HF_MODEL environment variable overrides the default.
MODEL_NAME = os.environ.get("HF_MODEL", "Madhuri1003/workpulse-distilbert")

# Inference pipeline placeholder: None until load_model() assigns it.
classifier = None
@app.on_event("startup")
async def load_model():
    """Build the text-classification pipeline and publish it as `classifier`.

    Registered as a startup handler on `app`. The model and the tokenizer are
    both loaded from `MODEL_NAME`. Progress is reported with `print`.
    """
    global classifier

    print(f"Loading model: {MODEL_NAME}")
    pipeline_options = {
        "model": MODEL_NAME,
        "tokenizer": MODEL_NAME,
        # NOTE(review): -1 presumably selects CPU per the transformers
        # convention -- confirm against the installed version.
        "device": -1,
    }
    classifier = pipeline("text-classification", **pipeline_options)
    print("Model loaded successfully.")
# --- Schemas ---
class ReviewRequest(BaseModel):
    # Request body of POST /predict.
    # text: the review to classify; the endpoint strips it and rejects
    # empty or over-long values.
    text: str
class BatchReviewRequest(BaseModel):
    # Request body of POST /batch.
    # texts: the reviews to classify in a single call.
    texts: List[str]
class PartResult(BaseModel):
    # Per-clause result shape (text / sentiment / confidence_pct), matching the
    # keys of the "parts_analyzed" entries that /predict builds.
    # NOTE(review): not referenced anywhere else in the visible code.
    text: str
    sentiment: str
    confidence_pct: str
class SentimentResult(BaseModel):
    # One classified review, as returned inside a /batch response.
    # text: the stripped input text
    # sentiment: winning label, or "Mixed"
    # confidence: score of the winning label
    # confidence_pct: the same score formatted as a percentage string
    text: str
    sentiment: str
    confidence: float
    confidence_pct: str
class BatchSentimentResult(BaseModel):
    # Response body of POST /batch.
    # results: one SentimentResult per processed text
    # total: number of entries in `results`
    # summary: sentiment label -> count over `results`
    results: List[SentimentResult]
    total: int
    summary: dict
# --- Helpers ---
# Space-delimited contrast markers, tried in this order by split_on_contrast().
CONTRAST_KEYWORDS = [" but ", " however ", " although ", " yet ", " though ", " despite "]


def split_on_contrast(text: str) -> List[str]:
    """Split *text* into two clauses around a contrast keyword.

    Keywords are tried in the order they appear in ``CONTRAST_KEYWORDS`` and
    only the first occurrence of each keyword is considered. A split is
    accepted only when both resulting clauses contain at least three
    whitespace-separated words; otherwise the next keyword is tried.

    Args:
        text: The review text to split.

    Returns:
        ``[left, right]`` (both stripped) when a usable split is found,
        otherwise ``[text]`` unchanged.
    """
    # Function-scope import: `re` is not among the module's top-level imports.
    import re

    for keyword in CONTRAST_KEYWORDS:
        # Search the original string case-insensitively instead of indexing
        # into text.lower(): lowercasing can change the string's length
        # (e.g. U+0130), which would shift the slice offsets.
        match = re.search(re.escape(keyword), text, re.IGNORECASE)
        if match is None:
            continue
        clauses = [text[:match.start()].strip(), text[match.end():].strip()]
        if all(len(clause.split()) >= 3 for clause in clauses):
            return clauses
    return [text]
def score_text(text):
    """Run the classifier on *text* and return its scores keyed by label.

    Args:
        text: The text to classify.

    Returns:
        A dict mapping each returned label to its score rounded to four
        decimal places (at most three labels, because ``top_k=3``).
    """
    # truncation=True is forwarded to the tokenizer so that input tokenizing
    # past the model's maximum sequence length is cut off instead of raising
    # inside the model. A character cap alone does not bound the token count.
    predictions = classifier(text, top_k=3, truncation=True)
    return {item["label"]: round(item["score"], 4) for item in predictions}
# --- Endpoints ---
@app.get("/")
def root():
    # Service descriptor: static name/description, the configured model name,
    # and the list of advertised endpoint paths.
    advertised_paths = ["/predict", "/batch", "/health", "/docs"]
    descriptor = {
        "name": "WorkPulse API",
        "description": "Company culture sentiment analysis",
        "model": MODEL_NAME,
    }
    descriptor["endpoints"] = advertised_paths
    return descriptor
@app.get("/health")
def health():
    # Liveness probe: status is always "ok"; model_loaded reports whether the
    # module-level classifier has been assigned yet.
    model_ready = classifier is not None
    return dict(status="ok", model_loaded=model_ready, model=MODEL_NAME)
# Limits and labels shared by the /predict and /batch endpoints.
_MAX_TEXT_LENGTH = 1000  # maximum characters accepted per review text
_MAX_BATCH_SIZE = 20  # maximum number of texts accepted by /batch
_MIXED_MARGIN = 0.20  # top-two score gap below which a result counts as "Mixed"
# Order matters: it is the tie-break order when picking the top label.
_SENTIMENT_LABELS = ("Negative", "Neutral", "Positive")


def _require_model():
    """Raise HTTP 503 if the classifier has not been loaded yet."""
    if classifier is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet")


def _require_valid_length(text):
    """Raise HTTP 400 if *text* exceeds the per-text character limit."""
    if len(text) > _MAX_TEXT_LENGTH:
        raise HTTPException(
            status_code=400,
            detail=f"Text too long, max {_MAX_TEXT_LENGTH} characters",
        )


def _analyze_text(text):
    """Score one stripped, non-empty review and derive its overall sentiment.

    The text is split on a contrast keyword when possible, each clause is
    scored separately, and the per-label scores are averaged across clauses.
    The result is "Mixed" when the text was split or when the two best
    averaged scores are closer than ``_MIXED_MARGIN``.

    Returns:
        A dict with keys ``parts``, ``part_scores``, ``final_scores``,
        ``top_score``, ``is_split``, ``is_mixed`` and ``sentiment``.
    """
    parts = split_on_contrast(text)
    part_scores = [score_text(part) for part in parts]

    final_scores = {
        label: round(sum(scores.get(label, 0) for scores in part_scores) / len(part_scores), 4)
        for label in _SENTIMENT_LABELS
    }
    top_label = max(final_scores, key=final_scores.get)
    best, runner_up = sorted(final_scores.values(), reverse=True)[:2]

    is_split = len(parts) > 1
    is_mixed = is_split or (best - runner_up) < _MIXED_MARGIN
    return {
        "parts": parts,
        "part_scores": part_scores,
        "final_scores": final_scores,
        "top_score": final_scores[top_label],
        "is_split": is_split,
        "is_mixed": is_mixed,
        "sentiment": "Mixed" if is_mixed else top_label,
    }


@app.post("/predict")
def predict(request: ReviewRequest):
    """Classify a single review.

    Responds with the overall sentiment (or "Mixed"), the confidence of the
    top averaged label, the averaged per-label scores and, when the text was
    split on a contrast keyword, the sentiment of each clause.

    Errors: 503 if the model is not loaded; 400 if the text is empty or
    longer than 1000 characters.
    """
    _require_model()

    text = request.text.strip()
    if not text:
        raise HTTPException(status_code=400, detail="Text cannot be empty")
    _require_valid_length(text)

    analysis = _analyze_text(text)
    top_score = analysis["top_score"]

    # Per-clause breakdown is only reported when the text was actually split.
    parts_analyzed = None
    if analysis["is_split"]:
        parts_analyzed = [
            {
                "text": part,
                "sentiment": max(scores, key=scores.get),
                "confidence_pct": f"{max(scores.values()) * 100:.1f}%",
            }
            for part, scores in zip(analysis["parts"], analysis["part_scores"])
        ]

    return {
        "text": text,
        "sentiment": analysis["sentiment"],
        "confidence": top_score,
        "confidence_pct": f"{top_score * 100:.1f}%",
        "is_mixed": analysis["is_mixed"],
        "all_scores": analysis["final_scores"],
        "parts_analyzed": parts_analyzed,
    }


@app.post("/batch", response_model=BatchSentimentResult)
def batch_predict(request: BatchReviewRequest):
    """Classify up to 20 reviews in one call.

    Blank texts are dropped. Every remaining text is analysed exactly like a
    /predict request, and the response adds a per-sentiment count summary.

    Errors: 503 if the model is not loaded; 400 if the list is empty (or
    contains only blank texts), holds more than 20 texts, or any text is
    longer than 1000 characters.
    """
    _require_model()

    if not request.texts:
        raise HTTPException(status_code=400, detail="Texts list cannot be empty")
    if len(request.texts) > _MAX_BATCH_SIZE:
        raise HTTPException(status_code=400, detail=f"Max {_MAX_BATCH_SIZE} texts per batch")

    texts = [stripped for stripped in (raw.strip() for raw in request.texts) if stripped]
    if not texts:
        # A list holding only blank strings has nothing to classify.
        raise HTTPException(status_code=400, detail="Texts list cannot be empty")
    # Validate every text before scoring so no model work is done for a
    # request that will be rejected anyway.
    for text in texts:
        _require_valid_length(text)

    summary = {"Positive": 0, "Neutral": 0, "Negative": 0, "Mixed": 0}
    results = []
    for text in texts:
        analysis = _analyze_text(text)
        top_score = analysis["top_score"]
        summary[analysis["sentiment"]] += 1
        results.append(SentimentResult(
            text=text,
            sentiment=analysis["sentiment"],
            confidence=top_score,
            confidence_pct=f"{top_score * 100:.1f}%",
        ))

    return BatchSentimentResult(
        results=results,
        total=len(results),
        summary=summary,
    )