spacesedan commited on
Commit
bb9b235
·
1 Parent(s): 1609cc2

feat: changing model with a more finetuned respose

Browse files
Files changed (1) hide show
  1. app.py +70 -28
app.py CHANGED
@@ -1,43 +1,85 @@
1
- from fastapi import FastAPI
2
  from pydantic import BaseModel
3
- from typing import List
4
  from transformers import pipeline
 
 
5
 
 
6
  app = FastAPI()
7
 
8
- sentiment_analyzer = pipeline("sentiment-analysis", model="nlptown/bert-base-multilingual-uncased-sentiment")
 
 
 
 
9
 
 
10
  class SentimentRequest(BaseModel):
 
11
  text: str
12
- post_id: str
13
-
14
- class SentimentResponse(BaseModel):
15
- post_id: str
16
- label: str
17
- score: float
18
 
19
- @app.post("/sentiment", response_model=SentimentResponse)
20
- async def analyze_sentiment(request: SentimentRequest):
21
- result = sentiment_analyzer(request.text)
22
- return {"post_id": request.post_id, "label": result[0]["label"], "score": result[0]["score"]}
23
 
24
- class SentimentPost(BaseModel):
25
- post_id: str
26
- text: str
 
 
 
27
 
28
- class SentimentBatchRequest(BaseModel):
29
- posts: List[SentimentPost]
 
 
 
 
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
- @app.post("/sentiment/batch", response_model=List[SentimentResponse])
33
- async def analyze_sentiment_batch(request: SentimentBatchRequest):
34
- texts = [post.text for post in request.posts]
35
- results = sentiment_analyzer(texts)
36
- return [
37
- {"post_id": post.post_id, "label": res["label"], "score": res["score"]}
38
- for post, res in zip(request.posts, results)
39
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
 
41
  @app.get("/")
42
- def greet_json():
43
- return {"message": "BERT Sentiment Analysis API is running!"}
 
1
+ from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
 
3
  from transformers import pipeline
4
+ import torch
5
+ from typing import List
6
 
7
+ # Initialize FastAPI
8
  app = FastAPI()
9
 
10
+ # Load RoBERTa sentiment analysis model
11
+ sentiment_pipeline = pipeline(
12
+ "sentiment-analysis",
13
+ model="cardiffnlp/twitter-roberta-base-sentiment"
14
+ )
15
 
16
+ # Request models
17
  class SentimentRequest(BaseModel):
18
+ content_id: str
19
  text: str
 
 
 
 
 
 
20
 
21
+ class BatchSentimentRequest(BaseModel):
22
+ posts: List[SentimentRequest]
 
 
23
 
24
+ # Response model
25
+ class SentimentResponse(BaseModel):
26
+ content_id: str
27
+ sentiment_score: float
28
+ sentiment_label: str
29
+ confidence: float
30
 
31
+ # Mapping RoBERTa labels to a floating-point scale
32
+ LABEL_MAP = {
33
+ "LABEL_0": -1.0, # Negative
34
+ "LABEL_1": 0.0, # Neutral
35
+ "LABEL_2": 1.0 # Positive
36
+ }
37
 
38
+ @app.post("/analyze", response_model=SentimentResponse)
39
+ def analyze_sentiment(request: SentimentRequest):
40
+ try:
41
+ # Get model prediction
42
+ result = sentiment_pipeline(request.text)[0]
43
+ label = result["label"]
44
+ score = result["score"]
45
+
46
+ # Convert RoBERTa labels to floating-point scores
47
+ sentiment_score = LABEL_MAP[label]
48
+ confidence = round(score, 3)
49
+
50
+ return SentimentResponse(
51
+ content_id=request.content_id,
52
+ sentiment_score=sentiment_score,
53
+ sentiment_label="positive" if sentiment_score > 0 else "neutral" if sentiment_score == 0 else "negative",
54
+ confidence=confidence
55
+ )
56
+ except Exception as e:
57
+ raise HTTPException(status_code=500, detail=str(e))
58
 
59
+ @app.post("/analyze_batch", response_model=List[SentimentResponse])
60
+ def analyze_sentiment_batch(request: BatchSentimentRequest):
61
+ try:
62
+ responses = []
63
+ for post in request.posts:
64
+ result = sentiment_pipeline(post.text)[0]
65
+ label = result["label"]
66
+ score = result["score"]
67
+
68
+ # Convert RoBERTa labels to floating-point scores
69
+ sentiment_score = LABEL_MAP[label]
70
+ confidence = round(score, 3)
71
+
72
+ responses.append(SentimentResponse(
73
+ content_id=post.content_id,
74
+ sentiment_score=sentiment_score,
75
+ sentiment_label="positive" if sentiment_score > 0 else "neutral" if sentiment_score == 0 else "negative",
76
+ confidence=confidence
77
+ ))
78
+ return responses
79
+ except Exception as e:
80
+ raise HTTPException(status_code=500, detail=str(e))
81
 
82
+ # Root endpoint
83
  @app.get("/")
84
+ def root():
85
+ return {"message": "RoBERTa Sentiment Analysis API is running!"}