# Author: Nefflymicn
# Commit 0cdbb60 (verified): "Update app.py"
import gradio as gr
import torch
from transformers import DistilBertTokenizer
from model import TransformerSentimentModel # Import your custom class
from huggingface_hub import hf_hub_download
# 1. Load Tokenizer and Model
REPO_ID = "Nefflymicn/amazon-sentiment-transformer"
# Single pinned revision shared by the tokenizer and the weights download, so
# the vocabulary and the trained embedding table always come from the same tag.
REVISION = "v2.0"
device = torch.device("cpu")  # Spaces use CPU by default for free tier
tokenizer = DistilBertTokenizer.from_pretrained(REPO_ID, revision=REVISION)

# Architecture must match your trained version exactly: load_state_dict is
# strict by default and raises if any parameter name or shape differs.
model = TransformerSentimentModel(
    vocab_size=tokenizer.vocab_size,
    embed_dim=128,
    num_heads=8,
    ff_dim=512,
    num_layers=4,
    output_dim=2,
)

weights_path = hf_hub_download(
    repo_id=REPO_ID, filename="pytorch_model.bin", revision=REVISION
)
# The checkpoint is fed straight into load_state_dict, so it must be a plain
# state_dict. weights_only=True restricts unpickling to tensors and basic
# containers, which prevents a tampered download from executing arbitrary code.
# The dict is not bound to a name, so it can be freed once it has been copied
# into the model's parameters.
model.load_state_dict(
    torch.load(weights_path, map_location=device, weights_only=True)
)
# Switch to inference behaviour (affects dropout/normalization layers, if the
# model contains any).
model.eval()
# 2. Define the Prediction Function
def predict(text):
    """Classify a product review as Negative or Positive.

    Args:
        text: Review text from the Gradio textbox. ``None`` is treated as an
            empty string.

    Returns:
        A dict mapping each label ("Negative", "Positive") to its softmax
        probability. This is the {label: confidence} format that ``gr.Label``
        displays.
    """
    # Model output index -> display label (index 1 is the positive class).
    labels = ("Negative", "Positive")
    inputs = tokenizer(
        text or "",
        padding="max_length",
        truncation=True,
        # Presumably the sequence length used in training -- TODO confirm.
        max_length=300,
        return_tensors="pt",
    )
    with torch.no_grad():
        # NOTE(review): only input_ids are passed. Padding tokens therefore
        # reach the model unmasked, unless TransformerSentimentModel masks them
        # itself -- confirm against model.py.
        logits = model(inputs["input_ids"])
    # A single string yields a batch of one. Report every class probability
    # straight from the softmax instead of deriving one as 1 - max.
    probs = torch.softmax(logits, dim=1)[0].tolist()
    return dict(zip(labels, probs))
# 3. Build the Gradio Interface
# One-click sample reviews. Each example row holds the value for the single
# Textbox input.
SAMPLE_REVIEWS = [
    "The build quality is incredible and it arrived much faster than expected. Highly recommend!",
    "Total waste of money. The item stopped working after two days and customer service was useless.",
    "While the setup was a bit difficult, the performance once it's running is unmatched. Best in its class.",
    "I was skeptical at first because of the price, but after using it for a month, I can honestly say it is perfect.",
    "The camera is great, the battery is decent, but the software is buggy and keeps crashing.",
]

# Input: a multi-line review box. Output: both class confidences.
review_box = gr.Textbox(lines=5, placeholder="Type your product review here...")
label_view = gr.Label(num_top_classes=2)

demo = gr.Interface(
    fn=predict,
    inputs=review_box,
    outputs=label_view,
    title="Amazon Sentiment Transformer",
    description="A custom 4-layer Transformer trained on 500k Amazon reviews.",
    examples=[[review] for review in SAMPLE_REVIEWS],
)
# Start the Gradio server only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    demo.launch()