Spaces:
Sleeping
Sleeping
import logging
from typing import List, Mapping, Optional, Tuple

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import pipeline
# ASGI application object for this service.
app = FastAPI()

# Hugging Face text-classification pipeline wrapping the fine-tuned Longformer
# checkpoint. It is created at module level, so the model is loaded once when
# the module is imported and then reused by the handlers in this file.
sentiment_pipeline = pipeline(
    task="text-classification",
    model="spacesedan/reddit-sentiment-analysis-longformer",
)
# Request body for classifying a single piece of content.
class SentimentRequest(BaseModel):
    # Caller-supplied identifier; echoed back unchanged in SentimentResponse.
    content_id: str
    # Raw text that is passed to the sentiment classifier.
    text: str
# Request body for classifying several posts in one call.
class BatchSentimentRequest(BaseModel):
    # Posts to classify; each is classified independently and results are
    # returned in the same order.
    posts: List[SentimentRequest]
# Normalized sentiment result for one piece of content.
class SentimentResponse(BaseModel):
    # Identifier copied from the corresponding request.
    content_id: str
    # Signed score taken from LABEL_MAP via normalize_prediction; 0.0 for
    # neutral, for unknown labels, or when the low-confidence fallback fires.
    sentiment_score: float
    # Lowercased classifier label, or "neutral" after the low-confidence fallback.
    sentiment_label: str
    # Classifier score for the predicted label, rounded to 3 decimals by the handlers.
    confidence: float
# Five-point ordinal scale: classifier label (matched after lowercasing in
# normalize_prediction) -> signed sentiment score in [-1.0, 1.0].
# NOTE(review): assumes the model emits these exact label strings rather than
# generic ids such as "LABEL_0"; confirm against the model's id2label config.
LABEL_MAP = {
    # Negative side of the scale.
    "very negative": -1.0,
    "negative": -0.7,
    # Midpoint.
    "neutral": 0.0,
    # Positive side of the scale.
    "positive": 0.7,
    "very positive": 1.0,
}
def normalize_prediction(
    label: str,
    confidence: float,
    *,
    threshold: float = 0.6,
    label_map: Optional[Mapping[str, float]] = None,
) -> Tuple[float, str]:
    """Map a raw classifier label to a signed score and a normalized label.

    The label is lowercased and looked up in ``label_map`` (``LABEL_MAP`` by
    default). Labels missing from the map score 0.0 but keep their lowercased
    text.

    Low-confidence predictions that are not at either extreme of the scale
    (score between -0.7 and 0.7 inclusive: "negative", "neutral", "positive"
    and unknown labels with the default map) fall back to ``(0.0, "neutral")``.
    The extreme labels ("very negative" / "very positive") are kept regardless
    of confidence.

    Args:
        label: Label string produced by the classifier.
        confidence: Classifier score for ``label``.
        threshold: Confidences strictly below this value trigger the neutral
            fallback. Defaults to 0.6.
        label_map: Label -> score mapping. Defaults to ``LABEL_MAP``.

    Returns:
        A ``(score, label)`` tuple.
    """
    mapping = LABEL_MAP if label_map is None else label_map
    label = label.lower()
    score = mapping.get(label, 0.0)
    # Inclusive bounds so the mild labels (+/-0.7) are covered by the fallback;
    # with strict bounds only labels already scoring 0.0 could ever match.
    if confidence < threshold and -0.7 <= score <= 0.7:
        return 0.0, "neutral"
    return score, label
def analyze_sentiment(request: SentimentRequest):
    """Classify a single post and return its normalized sentiment.

    Raises:
        HTTPException: status 500 if classification or response construction
            fails. The underlying error message is used as the detail, and the
            traceback is logged server-side.
    """
    try:
        # truncation=True clips inputs longer than the tokenizer's maximum
        # length instead of erroring on very long posts.
        result = sentiment_pipeline(request.text, truncation=True)[0]
        confidence = round(result["score"], 3)
        sentiment_score, sentiment_label = normalize_prediction(result["label"], confidence)
        return SentimentResponse(
            content_id=request.content_id,
            sentiment_score=sentiment_score,
            sentiment_label=sentiment_label,
            confidence=confidence,
        )
    except Exception as e:
        # Converting to HTTPException would otherwise drop the traceback from
        # the server logs; record it before re-raising.
        logging.getLogger(__name__).exception(
            "Sentiment analysis failed for content_id=%s", request.content_id
        )
        raise HTTPException(status_code=500, detail=str(e)) from e
def analyze_sentiment_batch(request: BatchSentimentRequest):
    """Classify every post in a batch, returning results in request order.

    The batch is all-or-nothing: if any post fails, no partial results are
    returned.

    Raises:
        HTTPException: status 500 on the first failure. The underlying error
            message is used as the detail, and the traceback is logged together
            with the content_id being processed.
    """
    current_id = None  # content_id being processed, reported if a failure occurs
    try:
        responses = []
        for post in request.posts:
            current_id = post.content_id
            # truncation=True clips inputs longer than the tokenizer's maximum
            # length instead of erroring on very long posts.
            result = sentiment_pipeline(post.text, truncation=True)[0]
            confidence = round(result["score"], 3)
            sentiment_score, sentiment_label = normalize_prediction(result["label"], confidence)
            responses.append(
                SentimentResponse(
                    content_id=post.content_id,
                    sentiment_score=sentiment_score,
                    sentiment_label=sentiment_label,
                    confidence=confidence,
                )
            )
        return responses
    except Exception as e:
        logging.getLogger(__name__).exception(
            "Batch sentiment analysis failed at content_id=%s", current_id
        )
        raise HTTPException(status_code=500, detail=str(e)) from e
def root():
    # Static banner confirming the service is up; takes no input.
    status_message = "Reddit Sentiment Analysis API (Longformer 5-point) is running!"
    return dict(message=status_message)