# bert-yelp / app.py — Hugging Face Space by ogflash (commit c3bea6e, verified)
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
# Load the fine-tuned classifier and its matching tokenizer from the Hugging Face
# Hub (served from the local cache when already downloaded).
# NOTE(review): the UI below assumes 3 classes (Negative/Neutral/Positive) — confirm
# the checkpoint's num_labels matches.
model = AutoModelForSequenceClassification.from_pretrained("ogflash/yelp_review_classifier")
tokenizer = AutoTokenizer.from_pretrained("ogflash/yelp_review_classifier")
# Prediction function
def classify(text):
    """Classify *text* as Negative, Neutral, or Positive sentiment.

    Args:
        text: The review text to classify.

    Returns:
        A string of the form "<label> (<confidence>%)", e.g. "Positive (97.31%)".
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    # DistilBERT-style models do not accept token_type_ids; drop it if present.
    inputs.pop("token_type_ids", None)
    # Inference only — no_grad avoids building an autograd graph per request.
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = torch.softmax(logits, dim=1)[0]
    predicted_class_id = torch.argmax(probs).item()
    score = probs[predicted_class_id].item()
    # Map class ids to display names; fall back to the raw label for any
    # unexpected id so the app degrades gracefully instead of mislabeling.
    id_to_name = {0: "Negative", 1: "Neutral", 2: "Positive"}
    label_name = id_to_name.get(predicted_class_id, f"LABEL_{predicted_class_id}")
    return f"{label_name} ({score * 100:.2f}%)"
# Gradio UI
# Build the Gradio UI: a single text box feeding classify(), plain-text output.
iface = gr.Interface(
    fn=classify,
    inputs=gr.Textbox(lines=2, placeholder="Enter your review here..."),
    outputs="text",
    title="Sentiment Classifier",
    description="Classifies text into Positive, Neutral, or Negative.",
)

# Start the Gradio server (blocks until shut down).
iface.launch()