# Provenance (copied from the Hugging Face file viewer): api.py, 2.49 kB,
# last commit 7fb86a3 "Update api.py" by MissingBreath (verified).
import io
import os
from dataclasses import dataclass
from typing import List

import numpy as np
import tensorflow as tf
from fastapi import FastAPI, File, UploadFile
from PIL import Image
# from transformers import AutoTokenizer, AutoModelForSequenceClassification
# # os.environ['HF_TOKEN']=''
# from huggingface_hub import login
# hf_token = os.getenv("HF_TOKEN")
# login(token=hf_token)
# Read token from environment
# hf_token = os.getenv("HF_TOKEN")
# print("HF_TOKEN:", hf_token)
# Load tokenizer directly with the token (no login)
# tokenizer = AutoTokenizer.from_pretrained(
# "chillies/distilbert-course-review-classification",
# token=hf_token # Pass it directly
# )
# tokenizer = AutoTokenizer.from_pretrained("chillies/distilbert-course-review-classification")
# from transformers import DistilBertTokenizer
# tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
# model = AutoModelForSequenceClassification.from_pretrained("chillies/distilbert-course-review-classification")
# from transformers import DistilBertTokenizerFast
# tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
# from transformers import pipeline
# model = pipeline("text-classification", model="distilbert-base-uncased-finetuned-sst-2-english")
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# Filesystem locations of the fine-tuned classifier and its tokenizer.
MODEL_DIR = "/my_model"
TOKENIZER_DIR = "/my_tokenizer"

# Load the model and tokenizer once, at import time. A failure is reported but
# does not stop the service from starting (best-effort, as before). Both names
# are bound to None in that case, so later code sees a detectable "not loaded"
# state instead of an unbound-name NameError.
try:
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_DIR)
    print("Model and tokenizer loaded successfully.")
except Exception as e:
    model = None
    tokenizer = None
    print(f"Error loading model or tokenizer: {e}")
# Human-readable names for the classifier's output indices, in index order.
# NOTE(review): assumes the checkpoint at MODEL_DIR has exactly these 9 labels
# in this order — confirm against the model's config.json.
CLASS_LABELS = (
    'Improvement Suggestions', 'Questions', 'Confusion', 'Support Request',
    'Discussion', 'Course Comparison', 'Related Course Suggestions',
    'Negative', 'Positive',
)


def inference(review):
    """Classify a single course review and return its label name.

    Args:
        review: The review text to classify.

    Returns:
        One of the strings in ``CLASS_LABELS``.

    Raises:
        RuntimeError: If the module-level ``model`` or ``tokenizer`` is None
            (i.e. not loaded).
    """
    # Imported here rather than at module top so that a missing torch install
    # still surfaces at request time, exactly as with return_tensors="pt".
    import torch

    if model is None or tokenizer is None:
        raise RuntimeError("Model or tokenizer is not loaded; see startup log.")

    inputs = tokenizer(review, return_tensors="pt", padding=True, truncation=True)
    # Pure inference: skip autograd graph construction to save memory and time.
    with torch.no_grad():
        outputs = model(**inputs)
    # Assuming the model outputs logits; take the highest-scoring class of the
    # single input in the batch.
    predicted_class = outputs.logits.argmax(dim=-1).item()
    return CLASS_LABELS[predicted_class]
@dataclass
class ReviewRequest:
    """Request body for ``POST /classify``: ``{"reviews": ["...", ...]}``."""

    # Review texts to classify; one prediction is produced per entry.
    reviews: List[str]


app = FastAPI()


@app.post("/classify")
async def classify(request: ReviewRequest):
    """Classify every review in the request body.

    Returns:
        ``{"predictions": [[label], ...]}``: one single-element list per input
        review, in input order.
    """
    # NOTE(review): the original appended the set literal ``{predicted_class}``,
    # which FastAPI's encoder serialises as a one-element JSON array. A list is
    # used here to keep that wire format byte-for-byte; if a bare label string
    # was intended, change it deliberately — it alters the response for clients.
    # NOTE(review): inference() is synchronous, so it blocks the event loop for
    # the duration of each forward pass inside this async handler.
    predictions = [[inference(review)] for review in request.reviews]
    return {"predictions": predictions}