"""Gradio demo serving a custom Transformer sentiment classifier for Amazon reviews."""
import gradio as gr
import torch
from transformers import DistilBertTokenizer
from model import TransformerSentimentModel  # Import your custom class
from huggingface_hub import hf_hub_download

# 1. Load Tokenizer and Model
REPO_ID = "Nefflymicn/amazon-sentiment-transformer"
MODEL_REVISION = "v2.0"
MAX_LENGTH = 300  # tokens per review after padding/truncation
# Class index -> display label; index 1 is "Positive", matching the trained head.
LABELS = ("Negative", "Positive")

device = torch.device("cpu")  # Spaces use CPU by default for free tier

tokenizer = DistilBertTokenizer.from_pretrained(REPO_ID, revision=MODEL_REVISION)

# Architecture must match your trained version exactly
model = TransformerSentimentModel(
    vocab_size=tokenizer.vocab_size,
    embed_dim=128,
    num_heads=8,
    ff_dim=512,
    num_layers=4,
    output_dim=len(LABELS),
)

weights_path = hf_hub_download(
    repo_id=REPO_ID, filename="pytorch_model.bin", revision=MODEL_REVISION
)
# NOTE(review): torch.load unpickles the downloaded file. If this checkpoint is a
# plain state_dict, consider passing weights_only=True so arbitrary pickled objects
# cannot execute on load.
model.load_state_dict(torch.load(weights_path, map_location=device))
model.to(device)
model.eval()


# 2. Define the Prediction Function
def predict(text):
    """Classify the sentiment of a single review.

    Args:
        text: Raw review text from the Gradio textbox.

    Returns:
        A dict mapping each label in ``LABELS`` to its softmax probability,
        in the ``{label: confidence}`` form that ``gr.Label`` renders.
    """
    inputs = tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=MAX_LENGTH,
        return_tensors="pt",
    )
    # NOTE(review): only input_ids are passed, so padding tokens reach the model
    # without an attention mask. Confirm TransformerSentimentModel masks the pad id
    # internally, or pass inputs["attention_mask"] if its forward() accepts one.
    with torch.no_grad():
        logits = model(inputs["input_ids"].to(device))
        probs = torch.softmax(logits, dim=1)[0]  # single example -> (num_classes,)

    return dict(zip(LABELS, probs.tolist()))


# 3. Build the Gradio Interface
demo = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=5, placeholder="Type your product review here..."),
    outputs=gr.Label(num_top_classes=2),
    title="Amazon Sentiment Transformer",
    description="A custom 4-layer Transformer trained on 500k Amazon reviews.",
    examples=[
        ["The build quality is incredible and it arrived much faster than expected. Highly recommend!"],
        ["Total waste of money. The item stopped working after two days and customer service was useless."],
        ["While the setup was a bit difficult, the performance once it's running is unmatched. Best in its class."],
        ["I was skeptical at first because of the price, but after using it for a month, I can honestly say it is perfect."],
        ["The camera is great, the battery is decent, but the software is buggy and keeps crashing."],
    ],
)

if __name__ == "__main__":
    demo.launch()