# Hugging Face Spaces page header (scrape artifact, not part of the program):
# Spaces: Sleeping
import logging

# Verbose logging for the whole app (including library internals).
logging.basicConfig(level=logging.DEBUG)

import gradio as gr
import pandas as pd
import numpy as np
import os
from openai import OpenAI
import pickle
import faiss
import json  # used to parse the structured-output JSON from the chat model

# OpenAI client, authenticated via the OPENAI_API_KEY environment variable.
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

# Product catalogue loaded once at startup; rows are looked up by FAISS result
# indices later, so row order must match the embedding order.
# (Adjust the path to your own xlsx file.)
df = pd.read_excel("product_list.xlsx")
# Initialize FAISS Index
def initialize_embeddings_from_pkl(pkl_path: str, faiss_path: str):
    """Load precomputed embeddings from a pickle file plus their FAISS index.

    Args:
        pkl_path: Path to the pickled embedding array.
        faiss_path: Path to the serialized FAISS index.

    Returns:
        (index, embeddings): the loaded FAISS index and the embedding array.

    Raises:
        FileNotFoundError: if either file does not exist.
    """
    # Validate both inputs up front so the error names the missing file.
    for label, path in (("Embedding", pkl_path), ("FAISS index", faiss_path)):
        if not os.path.exists(path):
            raise FileNotFoundError(f"{label} file not found at {path}")
    print(f"Loading embeddings from {pkl_path}...")
    with open(pkl_path, "rb") as handle:
        loaded_embeddings = pickle.load(handle)
    print(f"Loading FAISS index from {faiss_path}...")
    loaded_index = faiss.read_index(faiss_path)
    print(f"FAISS index loaded with {loaded_index.ntotal} embeddings.")
    return loaded_index, loaded_embeddings
# Paths to FAISS index and PKL file (adjust to your own files).
faiss_path = "product_index.faiss"
pkl_path = "product_embeddings.pkl"

# Load the index and embeddings at import time so the UI can report status.
print("Initializing embeddings...")
faiss_index, product_embeddings_array = initialize_embeddings_from_pkl(pkl_path, faiss_path)
# Fix: use an explicit check instead of `assert` — asserts are stripped under
# `python -O`, and an unnoticed row-count mismatch here would make every FAISS
# result map to the wrong catalogue row.
if faiss_index.ntotal != len(df):
    raise ValueError("FAISS 索引與 xlsx 文件的行數不一致!")
print(f"Loaded embeddings with shape: {len(product_embeddings_array)} x {len(product_embeddings_array[0])}")
print("Embeddings initialized")
# Generate embeddings for query
def get_embedding(text: str, model="text-embedding-ada-002"):
    """Fetch the embedding vector for *text* from the OpenAI embeddings API.

    Newlines are flattened to spaces before the request. Any failure is
    printed and reported as None (best-effort semantics; callers must cope).
    """
    try:
        normalized = text.replace("\n", " ")
        api_response = client.embeddings.create(
            input=[normalized],
            model=model,
        )
        return api_response.data[0].embedding
    except Exception as e:
        print(f"Error getting embedding: {e}")
        return None
# Find similar products
def find_similar_products(query_embedding, top_k=8):
    """Return the top_k catalogue rows most similar to *query_embedding*.

    Args:
        query_embedding: Sequence of floats (the query vector); must not be None.
        top_k: Number of nearest neighbours to retrieve (default 8).

    Returns:
        A copy of the matching ``df`` rows with an extra "similarity" column
        holding the FAISS scores (metric depends on how the index was built).

    Raises:
        ValueError: if the FAISS index is not loaded or the embedding is None.
    """
    if faiss_index is None:
        raise ValueError("FAISS index is not loaded.")
    if query_embedding is None:
        # Fix: get_embedding() returns None on API failure; fail with a clear
        # message instead of an opaque numpy dtype error below.
        raise ValueError("query_embedding is None (embedding lookup failed).")
    # FAISS expects a 2-D float32 array of shape (n_queries, dim).
    query_vector = np.asarray(query_embedding, dtype="float32").reshape(1, -1)
    distances, indices = faiss_index.search(query_vector, top_k)
    # Fix: FAISS pads results with -1 when the index holds fewer than top_k
    # vectors; df.iloc would silently interpret -1 as "last row". Drop them.
    valid = indices[0] != -1
    matching_products = df.iloc[indices[0][valid]].copy()
    matching_products["similarity"] = distances[0][valid]
    return matching_products
# JSON schema handed to the Chat Completions structured-output feature
# (response_format={"type": "json_schema", ...}); with "strict": True the
# model must return exactly {"analysis": <str>, "ingredients": [<str>, ...]}.
# Field descriptions are in Traditional Chinese because the model is asked
# to answer in that language — they are part of the API payload, not comments.
json_schema = {
    "name": "CookingIngredientsSchema",
    "description": "Extract cooking ingredients and analysis from user query.",
    "strict": True,
    "schema": {
        "type": "object",
        "properties": {
            "analysis": {
                "type": "string",
                "description": "完整的需求分析,解釋用戶的目標和需要的物品。"
            },
            "ingredients": {
                "type": "array",
                "description": "提取的食材或關鍵物品清單。",
                "items": {
                    "type": "string",
                    "description": "單一食材或關鍵物品名稱。"
                }
            }
        },
        "required": ["analysis", "ingredients"],
        "additionalProperties": False
    }
}
# Analyze query and find products
def analyze_query_and_find_products(query: str) -> str:
    """Analyze a shopping query with GPT-4o and recommend catalogue products.

    Pipeline:
      1. Ask the chat model (structured output) for a needs analysis and a
         list of key ingredients/items.
      2. Embed each extracted item and retrieve similar products via FAISS.
      3. Format everything into one Traditional-Chinese response string.

    Args:
        query: Free-text user request; blank input short-circuits with a prompt.

    Returns:
        The formatted recommendation text, or an error message on failure.
    """
    if not query.strip():
        return "請輸入您的問題或搜尋需求"
    try:
        # Step 1: structured analysis of what the user needs.
        analysis_messages = [
            {"role": "system", "content": """You are a knowledgeable shopping assistant.
When given a query:
1. Analyze what the user is looking for
2. Predict what user will need in a supermarket
Provide your analysis in Traditional Chinese, focusing on understanding user needs."""},
            {"role": "user", "content": f"Analyze this query and explain what the user needs: {query}"}
        ]
        analysis_response = client.chat.completions.create(
            model="gpt-4o",
            messages=analysis_messages,
            temperature=0.7,
            response_format={
                "type": "json_schema",
                "json_schema": json_schema
            }
        )
        analysis_json = json.loads(analysis_response.choices[0].message.content)
        description = analysis_json["analysis"]    # analysis text for display
        ingredients = analysis_json["ingredients"] # extracted item keywords
        print("=======")
        print("關鍵字陣列:")
        print(ingredients)
        print("=======")

        # Step 2: per-ingredient similarity search.
        search_results = []
        for item in ingredients:
            print(f"正在搜尋:{item}")
            item_embedding = get_embedding(item)
            if item_embedding is None:
                # Fix: a failed embedding previously crashed the whole request
                # inside FAISS; skip just that item instead.
                print(f"Error in search: embedding failed for {item}")
                continue
            search_results.append((item, find_similar_products(item_embedding)))

        # Fix: removed an extra get_embedding(query + description) + FAISS
        # search whose results were never displayed (the formatting loop
        # rebound `matching_products`) — it only cost a paid API call.

        # Step 3: format the answer. Fix: show the parsed analysis text, not
        # the raw JSON payload the model returned.
        response_parts = [
            "🔍 需求分析:",
            description,
            "\n📦 相關商品推薦:\n"
        ]
        for item, matching_products in search_results:
            response_parts.append(f"### {item} 的推薦商品:")
            for _, row in matching_products.iterrows():
                product_info = f"""
• {row['item_name']}
  ID: {row['item_id']}
  描述: {row['description']}
  分類: {row['tags']}
  相似度: {row['similarity']:.2f}"""
                response_parts.append(product_info)
        response_parts.append("\n💡 購物建議:")
        response_parts.append("根據您的需求,以上是推薦的商品!")
        return "\n".join(response_parts)
    except Exception as e:
        print(f"Error in search: {str(e)}")
        return f"搜尋發生錯誤: {str(e)}"
# Get system status
def get_system_status():
    """Summarize startup state: index loaded?, vector count, catalogue size."""
    loaded = faiss_index is not None
    return {
        "embeddings_loaded": loaded,
        "embedding_count": faiss_index.ntotal if loaded else 0,
        "product_count": len(df),
    }
# Gradio Interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Page header describing what the assistant does.
    gr.Markdown(
        """
# 🛒 智慧商品推薦系統
輸入您的問題或需求,系統會:
1. 分析您的需求
2. 推薦相關商品
3. 提供實用建議
"""
    )

    # Status banner: catalogue size and embedding-load state at startup.
    with gr.Row():
        status = get_system_status()
        gr.Markdown(
            f"""
### 系統狀態:
- 資料庫商品數:{status['product_count']}
- 向量嵌入狀態:{'✅ 已載入' if status['embeddings_loaded'] else '❌ 未載入'}
"""
        )

    # Main query/result area wired to the recommendation pipeline.
    with gr.Column():
        input_text = gr.Textbox(
            label="請輸入您的問題或需求",
            placeholder="例如:需要適合便當的食材",
            lines=3,
        )
        output_text = gr.Textbox(
            label="分析結果與建議",
            lines=20,
        )
        submit_btn = gr.Button("搜尋")
        submit_btn.click(
            fn=analyze_query_and_find_products,
            inputs=input_text,
            outputs=output_text,
        )

    gr.Markdown("--- 系統使用 AI 分析需求並推薦商品。")

# Launch the app
demo.launch()