MissingBreath's picture
Update api.py
04bd264 verified
raw
history blame
1.63 kB
import io
from dataclasses import dataclass
from typing import List

import numpy as np
import tensorflow as tf
import torch
from fastapi import FastAPI, File, UploadFile
from PIL import Image
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Hugging Face Hub checkpoint supplying both the tokenizer and the fine-tuned
# sequence-classification model. Keeping a single constant guarantees the two
# are always loaded from the same checkpoint (matching vocabulary/token ids).
MODEL_NAME = "chillies/distilbert-course-review-classification"

# Loaded once at import time; `inference` reuses these module-level objects
# for every request instead of reloading them.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
# Human-readable names for the classifier's output indices, in index order.
# NOTE(review): assumed to match the checkpoint's label order
# (model.config.id2label) — confirm against the model card.
CLASS_LABELS = (
    'Improvement Suggestions', 'Questions', 'Confusion', 'Support Request',
    'Discussion', 'Course Comparison', 'Related Course Suggestions',
    'Negative', 'Positive',
)


def inference(review):
    """Classify a single course review and return its label.

    Args:
        review: The review text to classify (a single string; the
            ``.item()`` call below requires exactly one prediction).

    Returns:
        The entry of ``CLASS_LABELS`` at the index of the highest logit.
    """
    inputs = tokenizer(review, return_tensors="pt", padding=True, truncation=True)
    # Pure inference: disable autograd so no computation graph is recorded,
    # saving memory and time on every request.
    with torch.no_grad():
        outputs = model(**inputs)
    # Sequence-classification heads return pre-softmax scores in `.logits`;
    # argmax over the last (class) dimension gives the predicted index.
    predicted_class = outputs.logits.argmax(dim=-1).item()
    return CLASS_LABELS[predicted_class]
app = FastAPI()


@dataclass
class ReviewRequest:
    """Request body for ``POST /classify``: the review texts to classify."""

    reviews: List[str]


@app.post("/classify")
def classify(request: ReviewRequest):
    """Classify every review in the request body.

    Declared with plain ``def`` rather than ``async def`` because model
    inference is blocking CPU work; FastAPI runs sync path operations in a
    worker threadpool, so a slow batch does not stall the event loop.

    Returns:
        ``{"predictions": [...]}`` with one label string per input review,
        in the same order as ``request.reviews`` (empty list for no reviews).
    """
    # Collect the bare label strings; wrapping each in `{...}` would build a
    # one-element set, which serializes as a nested list.
    predictions = [inference(review) for review in request.reviews]
    return {"predictions": predictions}