runthebandsup commited on
Commit
2a0af58
·
verified ·
1 Parent(s): e7130af

Create Gradio interface for EcommerceClassifier

Browse files
Files changed (1) hide show
  1. app.py +51 -0
app.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
3
+ import torch
4
+
5
+ # Load the EcommerceClassifier model
6
+ model_name = "Maverick98/EcommerceClassifier"
7
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
8
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
9
+
10
+ # Define classification function
11
+ def classify_product(product_text):
12
+ inputs = tokenizer(product_text, return_tensors="pt", truncation=True, max_length=512)
13
+ outputs = model(**inputs)
14
+ predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
15
+
16
+ # Get the predicted class and confidence
17
+ predicted_class = torch.argmax(predictions, dim=-1).item()
18
+ confidence = predictions[0][predicted_class].item()
19
+
20
+ # Map class index to label (adjust based on model's classes)
21
+ class_labels = model.config.id2label
22
+ predicted_label = class_labels.get(predicted_class, f"Class {predicted_class}")
23
+
24
+ # Return all probabilities for each class
25
+ results = {class_labels.get(i, f"Class {i}"): predictions[0][i].item()
26
+ for i in range(len(predictions[0]))}
27
+
28
+ return results
29
+
30
+ # Create Gradio interface
31
+ demo = gr.Interface(
32
+ fn=classify_product,
33
+ inputs=gr.Textbox(
34
+ label="Product Description",
35
+ placeholder="Enter product title or description...",
36
+ lines=5
37
+ ),
38
+ outputs=gr.Label(label="Classification Results", num_top_classes=10),
39
+ title="🛍️ E-Commerce Product Classifier",
40
+ description="Fast and accurate e-commerce product classification powered by EcommerceClassifier. Enter a product title or description to classify it into the appropriate category.",
41
+ examples=[
42
+ ["Women's Cotton T-Shirt - Casual Summer Wear"],
43
+ ["Wireless Bluetooth Headphones with Noise Cancellation"],
44
+ ["Organic Green Tea - 100 Tea Bags"],
45
+ ["Leather Office Chair with Lumbar Support"],
46
+ ],
47
+ theme="soft"
48
+ )
49
+
50
+ if __name__ == "__main__":
51
+ demo.launch()