Basshole committed on
Commit
88d4bd7
·
verified ·
1 Parent(s): 2f66663

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +308 -0
app.py ADDED
@@ -0,0 +1,308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ logging.basicConfig(level=logging.DEBUG)
3
+ import gradio as gr
4
+ import pandas as pd
5
+ import numpy as np
6
+ import os
7
+ from openai import OpenAI
8
+ from typing import List, Dict
9
+ import pickle
10
+ import time
11
+ from sklearn.metrics.pairwise import cosine_similarity
12
+ from huggingface_hub import HfApi, hf_hub_download, upload_file
13
+ from pathlib import Path
14
+
15
+ # Initialize OpenAI client
16
+ client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
17
+
18
+ # Hugging Face configuration
19
+ HF_TOKEN = os.environ.get("HF_TOKEN")
20
+ REPO_ID = os.environ.get("REPO_ID") # format: "username/space-name"
21
+ EMBEDDING_FILE = "product_embeddings.pkl"
22
+
23
+ # Initialize Hugging Face API
24
+ hf_api = HfApi(token=HF_TOKEN)
25
+
26
+ # Load CSV data
27
+ df = pd.read_csv("item_new.csv", encoding='utf-8')
28
+
def create_product_text(row):
    """Return the text that represents a product for embedding.

    Only the item description is embedded. A richer representation
    (class descriptions, brand, spec) was tried and deliberately
    dropped; the description alone gave better matches.
    """
    return f"{row['item_desc']}"
def get_embedding(text: str, model="text-embedding-3-small"):
    """Embed *text* with OpenAI and return the vector, or None on failure.

    Newlines are flattened to spaces before the API call. Any error is
    swallowed deliberately so one bad item cannot abort the startup
    indexing loop — callers must handle the ``None`` fallback.
    """
    try:
        text = text.replace("\n", " ")
        response = client.embeddings.create(
            input=[text],
            model=model
        )
        return response.data[0].embedding
    except Exception as e:
        # The file configures logging at DEBUG on line 1; use it instead
        # of a bare print so errors end up in the configured handler.
        logging.error("Error getting embedding: %s", e)
        return None
def download_embeddings():
    """Fetch the cached embedding list from the Hub; None if unavailable."""
    try:
        cached_path = hf_hub_download(
            repo_id=REPO_ID,
            filename=EMBEDDING_FILE,
            token=HF_TOKEN
        )
        # NOTE(review): unpickling remote data is only safe because REPO_ID
        # is our own space — never point it at an untrusted repo.
        with open(cached_path, 'rb') as fh:
            return pickle.load(fh)
    except Exception as e:
        print(f"Error downloading embeddings: {e}")
        return None
def upload_embeddings(embeddings):
    """Serialize *embeddings* and push them to the Hugging Face repo.

    Best-effort: failures are printed, never raised. The temporary pickle
    is removed in a ``finally`` block — the original only removed it on
    success, leaking ``temp_embeddings.pkl`` whenever the upload raised.
    """
    temp_path = "temp_embeddings.pkl"
    try:
        # Save embeddings locally first
        with open(temp_path, 'wb') as f:
            pickle.dump(embeddings, f)

        # Upload to Hugging Face
        hf_api.upload_file(
            path_or_fileobj=temp_path,
            path_in_repo=EMBEDDING_FILE,
            repo_id=REPO_ID,
            token=HF_TOKEN
        )
        print("Successfully uploaded embeddings")
    except Exception as e:
        print(f"Error uploading embeddings: {e}")
    finally:
        # Clean up temp file even when the upload failed.
        if os.path.exists(temp_path):
            os.remove(temp_path)
def initialize_embeddings():
    """Load cached product embeddings from the Hub, or build and publish them.

    Returns one embedding per row of ``df``, in row order. Rows whose API
    call failed get a zero vector so alignment with ``df`` is preserved.
    """
    print("Checking for existing embeddings...")
    cached = download_embeddings()
    if cached is not None:
        print("Loaded existing embeddings")
        return cached

    print("Creating new embeddings...")
    vectors = []
    for _, row in df.iterrows():
        vec = get_embedding(create_product_text(row))
        # 1536 is the text-embedding-3-small dimension; a zero vector keeps
        # the list aligned with df when an embedding call fails.
        vectors.append(vec if vec else [0] * 1536)
        time.sleep(0.1)  # crude client-side rate limiting

    # Persist so the next startup can skip the (slow, paid) rebuild.
    upload_embeddings(vectors)
    return vectors
# Build (or fetch) the embedding matrix once at import time; the
# similarity search below reuses this array for every query.
print("Initializing embeddings...")
product_embeddings = initialize_embeddings()
product_embeddings_array = np.array(product_embeddings)
print("Embeddings initialized")
def find_similar_products(query_embedding, top_k=8):
    """Return (rows, scores) for the *top_k* products closest to the query.

    Rows come from the module-level ``df``; scores are cosine similarities
    against ``product_embeddings_array``, sorted descending.
    """
    scores = cosine_similarity(
        [query_embedding],
        product_embeddings_array
    )[0]

    # Indices of the k largest similarities, best first.
    ranked = scores.argsort()[::-1][:top_k]
    return df.iloc[ranked], scores[ranked]
def analyze_query_and_find_products(query: str) -> str:
    """Full search pipeline: analyze the query, retrieve products, advise.

    Steps:
      1. GPT analyzes the user's intent (Traditional Chinese).
      2. query + analysis are embedded and matched against the product index.
      3. GPT writes shopping tips for the matched items.

    Never raises — any failure is returned as an error string so the UI
    always receives displayable text.
    """
    if not query.strip():
        return "請輸入您的問題或搜尋需求"

    try:
        # Step 1: understand what the user actually needs.
        analysis_messages = [
            {"role": "system", "content": """You are a knowledgeable shopping assistant.
            When given a query:
            1. Analyze what the user is looking for
            2. Predict what user will need in a supermarket

            Provide your analysis in Traditional Chinese, focusing on understanding user needs."""},
            {"role": "user", "content": f"Analyze this query and explain what the user needs: {query}"}
        ]

        analysis_response = client.chat.completions.create(
            model="gpt-4o",
            messages=analysis_messages,
            temperature=0.7,
            max_tokens=500
        )
        analysis = analysis_response.choices[0].message.content

        # Step 2: semantic retrieval. Embedding query+analysis widens recall
        # beyond the user's literal words.
        query_embedding = get_embedding(query + " " + analysis)
        if query_embedding is None:
            # Fix: a None embedding previously fell straight into
            # cosine_similarity and crashed with an opaque TypeError;
            # fail with a clear cause (still caught by the handler below).
            raise RuntimeError("query embedding unavailable")

        matching_products, similarities = find_similar_products(query_embedding)
        print(f"Found {len(matching_products)} matching products")

        product_descriptions = "\n".join(
            f"- {row['item_desc']} ({row['item_class1_desc']})"
            for _, row in matching_products.iterrows()
        )

        # Step 3: turn the matches into actionable shopping advice.
        recommendation_messages = [
            {"role": "system", "content": """Based on the query and available products,
            provide helpful recommendations and tips. Consider:
            1. How the products can be used
            2. What to look for when choosing
            3. Alternative options if available
            Respond in Traditional Chinese."""},
            {"role": "user", "content": f"""Query: {query}
            Analysis: {analysis}
            Available products: {product_descriptions}"""}
        ]

        recommendation_response = client.chat.completions.create(
            model="gpt-4o",
            messages=recommendation_messages,
            temperature=0.7,
            max_tokens=250
        )

        # Assemble the final display text: analysis, then product cards,
        # then the recommendation section.
        response_parts = [
            "🔍 需求分析:",
            analysis,
            "\n📦 相關商品推薦:\n"
        ]

        for (_, product), similarity in zip(matching_products.iterrows(), similarities):
            confidence = similarity * 100
            response_parts.append(f"""
• {product['item_desc']}
  分類: {product['item_class1_desc']} > {product['item_class2_desc']}
  規格: {product['spec']}
  價格: NT$ {float(product['sales_amt']):,.0f} / {product['unit']}
  相關度: {confidence:.1f}%""")

        response_parts.extend([
            "\n💡 購物建議:",
            recommendation_response.choices[0].message.content
        ])

        return "\n".join(response_parts)

    except Exception as e:
        print(f"Error in search: {str(e)}")
        return f"搜尋發生錯誤: {str(e)}"
def get_system_status():
    """Summarize how much data the app has loaded, for the status banner."""
    loaded = product_embeddings is not None
    count = len(product_embeddings) if product_embeddings else 0
    return {
        "embeddings_loaded": loaded,
        "embedding_count": count,
        "product_count": len(df)
    }
# Gradio front-end. Built at import time so Hugging Face Spaces can serve it.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🛒 智慧商品推薦系統

        輸入您的問題或需求,系統會:
        1. 分析您的需求
        2. 推薦相關商品
        3. 提供實用建議
        """
    )

    # System status banner (computed once at startup).
    with gr.Row():
        status = get_system_status()
        status_md = f"""
        ### 系統狀態:
        - 資料庫商品數:{status['product_count']}
        - 向量嵌入狀態:{'✅ 已載入' if status['embeddings_loaded'] else '❌ 未載入'}
        """
        gr.Markdown(status_md)

    # Main interface
    with gr.Column():
        input_text = gr.Textbox(
            label="請輸入您的問題或需求",
            placeholder="您可以詢問任何商品相關的問題,例如:\n- 想找一些適合做便當的食材\n- 需要營養均衡的食材\n- 想買一些新鮮的海鮮\n- 有什麼適合老人家的食物",
            lines=3
        )

        with gr.Row():
            submit_btn = gr.Button("搜尋", variant="primary")
            clear_btn = gr.Button("清除")

        output_text = gr.Textbox(
            label="分析結果與建議",
            lines=25
        )

    def clear_inputs():
        """Reset both textboxes.

        Fix: the original returned {"input_text": "", "output_text": ""}.
        With list-style outputs, Gradio expects one value per component
        (a dict would have to be keyed by the component objects), so the
        dict itself was dumped into the first textbox. Return a tuple.
        """
        return "", ""

    submit_btn.click(
        fn=analyze_query_and_find_products,
        inputs=input_text,
        outputs=output_text,
        api_name="search"  # exposes this handler over the Gradio API
    )

    clear_btn.click(
        fn=clear_inputs,
        inputs=[],
        outputs=[input_text, output_text],
        api_name="clear"
    )

    # Examples section
    gr.Markdown("### 搜尋範例")
    with gr.Row():
        # NOTE(review): cache_examples=True runs every example through the
        # OpenAI-backed pipeline at startup — confirm that cost is intended.
        examples = gr.Examples(
            examples=[
                ["想找一些適合做便當的食材"],
                ["需要一些營養均衡的食物"],
                ["有沒有適合老人家吃的食物?"],
                ["想買一些新鮮的海鮮,有什麼推薦?"],
                ["最近感冒了,有什麼食材可以幫助恢復?"],
            ],
            inputs=input_text,
            outputs=output_text,
            fn=analyze_query_and_find_products,
            cache_examples=True
        )

    # Footer
    gr.Markdown(
        """
        ---
        系統使用AI語意分析技術,能更好地理解您的需求並提供相關建議。
        如有任何問題或建議,歡迎反饋。
        """
    )

# Launch the app
demo.launch()