File size: 1,101 Bytes
2f0250d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import logging
import os

import torch
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoModelForSequenceClassification, AutoTokenizer

app = FastAPI()

# Populate os.environ from a .env file (if one is found) so HUGGINGFACE_TOKEN
# can be provided without exporting it in the shell.
load_dotenv()

# Hugging Face Hub repository id of the fine-tuned classification model.
checkpoint = "GaaS-Team/DistilBERT-finetuned-GaaS"
# Hub access token; None when HUGGINGFACE_TOKEN is not set.
# NOTE(review): presumably needed because the repo is private/gated — confirm.
hf_token = os.getenv("HUGGINGFACE_TOKEN")

# Loaded once at import time so every request reuses the same tokenizer/model.
# `token=` replaces the deprecated `use_auth_token=` argument, which is slated
# for removal in transformers v5.
tokenizer = AutoTokenizer.from_pretrained(checkpoint, token=hf_token)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, token=hf_token)

class TextInput(BaseModel):
    # Request body for the POST /predict endpoint.
    # text: the raw string to classify; passed unchanged to the tokenizer.
    text: str

@app.post("/predict")
async def predict_sentiment(input: TextInput):
    """Classify ``input.text`` and return its predicted label.

    The text is tokenized with truncation enabled, run through the
    module-level ``model`` with gradient tracking disabled, and the
    highest-scoring class id is mapped to a label via
    ``model.config.id2label``.

    Returns:
        A dict of the form ``{"sentiment": <label>}``.

    Raises:
        HTTPException: status 500 if tokenization, inference or the label
            lookup fails. The underlying exception is logged with its
            traceback and chained as ``__cause__``.
    """
    text = input.text
    try:
        inputs = tokenizer(text, padding=True, truncation=True, return_tensors="pt")

        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            logits = model(**inputs).logits

        # One input text -> `.item()` extracts the single winning class id.
        predicted_class_id = logits.argmax(dim=-1).item()
        predicted_class = model.config.id2label[predicted_class_id]
    except Exception as e:
        # Record the full traceback server-side; the HTTP response below only
        # carries the exception message.
        logging.getLogger(__name__).exception("Prediction failed")
        # NOTE(review): the detail echoes the raw exception text to the client,
        # which may leak internals; consider a generic message.
        raise HTTPException(
            status_code=500, detail=f"Error processing request: {str(e)}"
        ) from e
    return {"sentiment": predicted_class}