selvaonline commited on
Commit
0cf3ada
·
verified ·
1 Parent(s): 2af87ab

Upload widget.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. widget.py +49 -0
widget.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
2
+ import torch
3
+ import json
4
+ import os
5
+
6
+ # Define the function that will be called when the widget is used
7
+ def infer(text):
8
+ # Load the model and tokenizer
9
+ model_path = os.path.dirname(os.path.abspath(__file__))
10
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
11
+ model = AutoModelForSequenceClassification.from_pretrained(model_path)
12
+
13
+ # Load the categories
14
+ try:
15
+ with open(os.path.join(model_path, "categories.json"), "r") as f:
16
+ categories = json.load(f)
17
+ except Exception as e:
18
+ print(f"Error loading categories: {str(e)}")
19
+ categories = ["electronics", "clothing", "home", "kitchen", "toys", "other"]
20
+
21
+ # Prepare the input
22
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
23
+
24
+ # Get the model prediction
25
+ with torch.no_grad():
26
+ outputs = model(**inputs)
27
+ predictions = torch.sigmoid(outputs.logits)
28
+
29
+ # Get the top categories
30
+ top_categories = []
31
+ for i, score in enumerate(predictions[0]):
32
+ if score > 0.5: # Threshold for multi-label classification
33
+ top_categories.append((categories[i], score.item()))
34
+
35
+ # Sort by score
36
+ top_categories.sort(key=lambda x: x[1], reverse=True)
37
+
38
+ # Format the results
39
+ if top_categories:
40
+ result = f"Top categories for '{text}':\n\n"
41
+ for category, score in top_categories:
42
+ result += f"- {category}: {score:.4f}\n"
43
+
44
+ result += "\nBased on your query, I would recommend looking for deals in the "
45
+ result += f"**{top_categories[0][0]}** category."
46
+ else:
47
+ result = f"No categories found for '{text}'. Please try a different query."
48
+
49
+ return result