"""FastAPI service exposing sentiment classification with a fine-tuned DistilBERT model."""

import logging
import os
import threading

import torch
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
from transformers import AutoModelForSequenceClassification, AutoTokenizer

logger = logging.getLogger(__name__)

app = FastAPI()

# Populate os.environ from a local .env file (if one exists) before reading the token below.
load_dotenv()

checkpoint = "GaaS-Team/DistilBERT-finetuned-GaaS"
hf_token = os.getenv("HUGGINGFACE_TOKEN")

# Tokenizer and model are loaded once at import time and shared by every request.
# NOTE(review): `use_auth_token` is deprecated in recent transformers releases in favour of
# `token`; it is kept here for compatibility with older installs. Switch once the pinned
# transformers version is known to accept `token`.
tokenizer = AutoTokenizer.from_pretrained(checkpoint, use_auth_token=hf_token)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, use_auth_token=hf_token)

# Inference runs on worker threads (see predict_sentiment). Serialize it: fast (Rust-backed)
# tokenizers, which AutoTokenizer returns when available, can raise "Already borrowed" when
# called concurrently with truncation/padding enabled.
_inference_lock = threading.Lock()


class TextInput(BaseModel):
    """Request body for the /predict endpoint."""

    text: str  # raw text to classify


def _classify(text: str) -> str:
    """Tokenize *text*, run the classifier, and return the predicted label name.

    Blocking and CPU/GPU-bound; must not be called directly on the event loop.
    """
    with _inference_lock:
        inputs = tokenizer(text, padding=True, truncation=True, return_tensors="pt")
        with torch.no_grad():  # inference only; skip autograd bookkeeping
            logits = model(**inputs).logits
    predicted_class_id = logits.argmax(dim=-1).item()
    # id2label maps the argmax class index to the label name stored in the model config.
    return model.config.id2label[predicted_class_id]


@app.post("/predict")
async def predict_sentiment(input: TextInput):
    """Classify the sentiment of ``input.text``.

    Returns:
        ``{"sentiment": <label>}``, where the label comes from the model's ``id2label`` config.

    Raises:
        HTTPException: 500 if tokenization or inference fails for any reason.
    """
    try:
        # Offload blocking inference to the threadpool so the event loop stays responsive.
        predicted_class = await run_in_threadpool(_classify, input.text)
    except Exception as e:
        logger.exception("Sentiment prediction failed")
        raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}") from e
    return {"sentiment": predicted_class}