selvaonline commited on
Commit
4ec9a8e
·
verified ·
1 Parent(s): 9df7f3a

Upload gradio_demo.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. gradio_demo.py +126 -0
gradio_demo.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Gradio demo for the Shopping Assistant model
4
+ """
5
+ import gradio as gr
6
+ import requests
7
+ import numpy as np
8
+ import argparse
9
+
10
+ def query_model(text, api_token=None, model_id="selvaonline/shopping-assistant"):
11
+ """
12
+ Query the model using the Hugging Face Inference API
13
+ """
14
+ api_url = f"https://api-inference.huggingface.co/models/{model_id}"
15
+
16
+ headers = {}
17
+ if api_token:
18
+ headers["Authorization"] = f"Bearer {api_token}"
19
+
20
+ payload = {
21
+ "inputs": text,
22
+ "options": {
23
+ "wait_for_model": True
24
+ }
25
+ }
26
+
27
+ response = requests.post(api_url, headers=headers, json=payload)
28
+
29
+ if response.status_code == 200:
30
+ return response.json()
31
+ else:
32
+ print(f"Error: {response.status_code}")
33
+ print(response.text)
34
+ return None
35
+
36
+ def process_results(results, text):
37
+ """
38
+ Process the results from the Inference API
39
+ """
40
+ if not results or not isinstance(results, list) or len(results) == 0:
41
+ return f"No results found for '{text}'"
42
+
43
+ # The API returns logits, we need to convert them to probabilities
44
+
45
+ # Apply sigmoid to convert logits to probabilities
46
+ probabilities = 1 / (1 + np.exp(-np.array(results[0])))
47
+
48
+ # Define the categories (should match the model's categories)
49
+ categories = ["electronics", "clothing", "home", "kitchen", "toys", "other"]
50
+
51
+ # Get the top categories
52
+ top_categories = []
53
+ for i, score in enumerate(probabilities):
54
+ if score > 0.5: # Threshold for multi-label classification
55
+ top_categories.append((categories[i], float(score)))
56
+
57
+ # Sort by score
58
+ top_categories.sort(key=lambda x: x[1], reverse=True)
59
+
60
+ # Format the results
61
+ if top_categories:
62
+ result = f"Top categories for '{text}':\n\n"
63
+ for category, score in top_categories:
64
+ result += f"- {category}: {score:.4f}\n"
65
+
66
+ result += f"\nBased on your query, I would recommend looking for deals in the **{top_categories[0][0]}** category."
67
+ else:
68
+ result = f"No categories found for '{text}'. Please try a different query."
69
+
70
+ return result
71
+
72
+ def classify_query(query, api_token=None, model_id="selvaonline/shopping-assistant"):
73
+ """
74
+ Classify a shopping query using the model
75
+ """
76
+ results = query_model(query, api_token, model_id)
77
+ return process_results(results, query)
78
+
79
+ def create_gradio_interface(api_token=None, model_id="selvaonline/shopping-assistant"):
80
+ """
81
+ Create a Gradio interface for the Shopping Assistant model
82
+ """
83
+ # Define the interface
84
+ demo = gr.Interface(
85
+ fn=lambda query: classify_query(query, api_token, model_id),
86
+ inputs=gr.Textbox(
87
+ lines=2,
88
+ placeholder="Enter your shopping query here...",
89
+ label="Shopping Query"
90
+ ),
91
+ outputs=gr.Markdown(label="Results"),
92
+ title="Shopping Assistant",
93
+ description="""
94
+ This demo shows how to use the Shopping Assistant model to classify shopping queries into categories.
95
+ Enter a shopping query below to see which categories it belongs to.
96
+
97
+ Examples:
98
+ - "I'm looking for headphones"
99
+ - "Do you have any kitchen appliance deals?"
100
+ - "Show me the best laptop deals"
101
+ - "I need a new smart TV"
102
+ """,
103
+ examples=[
104
+ ["I'm looking for headphones"],
105
+ ["Do you have any kitchen appliance deals?"],
106
+ ["Show me the best laptop deals"],
107
+ ["I need a new smart TV"]
108
+ ],
109
+ theme=gr.themes.Soft()
110
+ )
111
+
112
+ return demo
113
+
114
+ def main():
115
+ parser = argparse.ArgumentParser(description="Gradio demo for the Shopping Assistant model")
116
+ parser.add_argument("--token", type=str, help="Hugging Face API token")
117
+ parser.add_argument("--model-id", type=str, default="selvaonline/shopping-assistant", help="Hugging Face model ID")
118
+ parser.add_argument("--share", action="store_true", help="Create a public link")
119
+ args = parser.parse_args()
120
+
121
+ print(f"Starting Gradio demo for model {args.model_id}")
122
+ demo = create_gradio_interface(args.token, args.model_id)
123
+ demo.launch(share=args.share)
124
+
125
+ if __name__ == "__main__":
126
+ main()