# Author: juinkinn
# Commit: Initial commit (1536543)
import asyncio
import logging
import os

import torch
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Populate os.environ from a local .env file (if one exists) before any
# configuration is read from the environment below.
load_dotenv()

app = FastAPI()

# Hugging Face Hub repository id of the fine-tuned classification model.
checkpoint = "GaaS-Team/DistilBERT-finetuned-GaaS"
# Hub access token from the HUGGINGFACE_TOKEN env var; None when unset, in
# which case the model is fetched without a token.
hf_token = os.getenv("HUGGINGFACE_TOKEN")

# `token=` replaces the deprecated `use_auth_token=` keyword of from_pretrained.
tokenizer = AutoTokenizer.from_pretrained(checkpoint, token=hf_token)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, token=hf_token)
# The model is only ever used for inference; make that explicit (disables
# dropout). Harmless if from_pretrained already left it in eval mode.
model.eval()
class TextInput(BaseModel):
    """JSON request body accepted by ``POST /predict``."""

    # The raw text to classify; required string field (no default).
    text: str
logger = logging.getLogger(__name__)


def _forward_logits(inputs):
    """Run the classifier's forward pass on tokenized *inputs* and return logits.

    Executed in a worker thread via ``asyncio.to_thread``. PyTorch's grad mode
    is thread-local, so ``torch.no_grad()`` must be entered here, inside the
    worker, rather than around the ``await`` in the caller.
    """
    with torch.no_grad():
        return model(**inputs).logits


@app.post("/predict")
async def predict_sentiment(input: TextInput):
    """Classify ``input.text`` and return ``{"sentiment": <label>}``.

    The label is ``model.config.id2label`` of the highest-scoring class.

    Raises:
        HTTPException: status 500 if tokenization, inference or label lookup
            fails; the original exception is logged and chained.
    """
    try:
        # Tokenization is cheap, so it stays on the event-loop thread; calling
        # a shared fast tokenizer from several worker threads at once can fail.
        inputs = tokenizer(input.text, padding=True, truncation=True, return_tensors="pt")
        # The forward pass is blocking compute; running it directly inside this
        # async handler would stall the event loop for every other request.
        logits = await asyncio.to_thread(_forward_logits, inputs)
        predicted_class_id = logits.argmax(dim=-1).item()
        predicted_class = model.config.id2label[predicted_class_id]
    except Exception as e:
        # Top-level request boundary: keep the traceback in the server log and
        # chain the cause instead of discarding it.
        logger.exception("Prediction failed")
        raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}") from e
    return {"sentiment": predicted_class}