|
|
import gradio as gr |
|
|
import torch |
|
|
from transformers import DistilBertTokenizer |
|
|
from model import TransformerSentimentModel |
|
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
|
|
|
|
|
|
# Hugging Face Hub repository hosting both the tokenizer and the trained weights.
REPO_ID = "Nefflymicn/amazon-sentiment-transformer"

# Tokenizer and weights are pinned to the same tagged revision so the two
# downloads cannot drift apart if the repository is updated.
REVISION = "v2.0"

# Inference runs on CPU; the checkpoint tensors are mapped here when loaded.
device = torch.device("cpu")

tokenizer = DistilBertTokenizer.from_pretrained(REPO_ID, revision=REVISION)

# These hyperparameters must match the ones the checkpoint was trained with;
# otherwise load_state_dict (strict by default) raises on mismatched keys/shapes.
model = TransformerSentimentModel(
    vocab_size=tokenizer.vocab_size,
    embed_dim=128,
    num_heads=8,
    ff_dim=512,
    num_layers=4,
    output_dim=2,
)

weights_path = hf_hub_download(
    repo_id=REPO_ID,
    filename="pytorch_model.bin",
    revision=REVISION,
)

# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint fetched from the network cannot execute arbitrary code.
model.load_state_dict(
    torch.load(weights_path, map_location=device, weights_only=True)
)

# Evaluation mode for inference (affects train-only layers such as dropout,
# if the model contains any).
model.eval()
|
|
|
|
|
|
|
|
|
|
|
def predict(text):
    """Classify a product review as Positive or Negative.

    The review is tokenized with the module-level tokenizer, padded and
    truncated to exactly 300 tokens, and scored by the module-level model.
    Only ``input_ids`` are passed to the model. The attention mask returned
    by the tokenizer is not used, so padding tokens are visible to the model.
    NOTE(review): confirm whether TransformerSentimentModel accepts a mask.

    Args:
        text: The review text entered by the user.

    Returns:
        A dict mapping both labels ("Positive" and "Negative") to confidence
        values, suitable for a Gradio Label component. The predicted label
        carries its softmax probability, and the other label gets the
        complement (1 minus that probability).
    """
    encoded = tokenizer(
        text,
        padding='max_length',
        truncation=True,
        max_length=300,
        return_tensors="pt",
    )

    with torch.no_grad():
        # Assumes the model returns logits shaped (batch, 2), matching
        # output_dim=2 at construction. TODO confirm.
        logits = model(encoded['input_ids'])

    probabilities = torch.softmax(logits, dim=1)
    top_prob, top_idx = torch.max(probabilities, 1)

    # Class index 1 is treated as Positive; any other index as Negative.
    if top_idx.item() == 1:
        winner, loser = "Positive", "Negative"
    else:
        winner, loser = "Negative", "Positive"

    confidence = float(top_prob.item())
    return {winner: confidence, loser: 1 - confidence}
|
|
|
|
|
|
|
|
|
|
|
# Sample reviews offered as clickable examples in the UI.
_EXAMPLE_REVIEWS = (
    "The build quality is incredible and it arrived much faster than expected. Highly recommend!",
    "Total waste of money. The item stopped working after two days and customer service was useless.",
    "While the setup was a bit difficult, the performance once it's running is unmatched. Best in its class.",
    "I was skeptical at first because of the price, but after using it for a month, I can honestly say it is perfect.",
    "The camera is great, the battery is decent, but the software is buggy and keeps crashing.",
)

# Web UI: a multi-line textbox feeds `predict`, whose label->confidence dict
# is rendered by a Label component showing both classes. Each example is
# wrapped in a one-element list because the interface has a single input.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=5, placeholder="Type your product review here..."),
    outputs=gr.Label(num_top_classes=2),
    title="Amazon Sentiment Transformer",
    description="A custom 4-layer Transformer trained on 500k Amazon reviews.",
    examples=[[review] for review in _EXAMPLE_REVIEWS],
)
|
|
|
|
|
def main():
    """Launch the Gradio sentiment demo."""
    demo.launch()


if __name__ == "__main__":
    main()