gajula21's picture
Update app.py
118bc34 verified
raw
history blame contribute delete
984 Bytes
from fastapi import FastAPI
import uvicorn
from pydantic import BaseModel
from typing import List
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
class CommentsInput(BaseModel):
    """JSON request body accepted by ``POST /sentiment``."""

    # Raw comment texts to classify; the endpoint answers with one label per
    # entry, in the same order as given here.
    comments: List[str]
# Hugging Face Hub checkpoint to serve (by its name, a Telugu YouTube-comment
# sentiment classifier).
model_name = "gajula21/youtube-sentiment-model-telugu"

# Tokenizer and model are loaded once, at import time, and shared by all requests.
tokenizer = AutoTokenizer.from_pretrained(model_name)
# ``eval()`` returns the module itself, so it is chained onto the load; it puts
# the network in inference mode (e.g. dropout disabled).
model = AutoModelForSequenceClassification.from_pretrained(model_name).eval()

# Output class index (argmax over the logits) -> human-readable label.
# NOTE(review): assumes the checkpoint orders its classes
# Negative/Neutral/Positive — confirm against ``model.config.id2label``.
label_mapping = dict(enumerate(("Negative", "Neutral", "Positive")))

app = FastAPI()
@app.get("/")
def read_root():
    """Root endpoint that returns a fixed greeting (a quick "is it up?" check)."""
    greeting = "Hello, World!"
    return {"message": greeting}
@app.post("/sentiment")
def predict_sentiments(data: CommentsInput):
    """Classify each submitted comment as Negative, Neutral or Positive.

    Args:
        data: Request body holding the list of comment strings.

    Returns:
        ``{"sentiments": [...]}`` with one label per input comment, in input
        order. An empty ``comments`` list yields an empty ``sentiments`` list.
    """
    comments = data.comments
    # An empty batch cannot be pushed through the tokenizer/model (it errors
    # instead of returning nothing), so answer directly.
    if not comments:
        return {"sentiments": []}

    # Process in fixed-size chunks so a very large request is never padded
    # into one giant tensor and run in a single forward pass.
    batch_size = 32
    sentiments = []
    # Inference only: disabling autograd saves memory and time.
    with torch.no_grad():
        for start in range(0, len(comments), batch_size):
            batch = comments[start:start + batch_size]
            inputs = tokenizer(
                batch,
                return_tensors="pt",
                padding=True,      # pad to the longest comment in this chunk
                truncation=True,   # over-long comments are cut, not rejected
                max_length=256,
            )
            logits = model(**inputs).logits
            predictions = torch.argmax(logits, dim=1).tolist()
            sentiments.extend(label_mapping[p] for p in predictions)
    return {"sentiments": sentiments}