# Author: Kaiyeee
# Last commit: "Update app.py" (ca5d70b, verified)
import gradio as gr
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification
import torch
# ✅ Fetch the fine-tuned checkpoint and its matching tokenizer from the
# Hugging Face Hub once, at import time, so every request reuses them.
# Point MODEL_NAME at your own Hub repository if you have forked this app.
MODEL_NAME = "Kaiyeee/fine_tuned_distilbert_imdb"
tokenizer, model = (
    DistilBertTokenizer.from_pretrained(MODEL_NAME),
    DistilBertForSequenceClassification.from_pretrained(MODEL_NAME),
)
def predict_sentiment(text, *, max_length=128):
    """Classify ``text`` as positive or negative with the fine-tuned DistilBERT.

    Args:
        text: The string to classify (the value of the Gradio textbox).
        max_length: Maximum number of tokens passed to the model; longer
            inputs are truncated. Keyword-only; defaults to 128, the limit
            the app has always used.

    Returns:
        ``"positive"`` if the highest-scoring class id is 1, otherwise
        ``"negative"``.
    """
    # A single sequence needs no padding. Padding to max_length only made the
    # model process up to 128 attention-masked pad tokens on every call.
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=max_length,
    )
    # Inference only: skip autograd graph construction.
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class_id = torch.argmax(logits, dim=-1).item()
    # NOTE(review): assumes class id 1 means "positive" in this checkpoint's
    # label scheme — confirm against the model config / training labels.
    return "positive" if predicted_class_id == 1 else "negative"
# ✅ Wire predict_sentiment into a simple Gradio UI: a five-line text box
# goes in, and the predicted label comes back as plain text.
_review_box = gr.Textbox(lines=5, placeholder="Enter text for sentiment analysis...")
demo = gr.Interface(
    fn=predict_sentiment,
    inputs=_review_box,
    outputs="text",
    title="Sentiment Analysis with DistilBERT",
    description="Enter text to predict sentiment (positive or negative).",
)
def main():
    """Start the Gradio server for the sentiment demo."""
    demo.launch()


if __name__ == "__main__":
    main()