# app.py — smart product recommendation demo (Gradio + OpenAI embeddings + FAISS).
import logging
logging.basicConfig(level=logging.DEBUG)
import gradio as gr
import pandas as pd
import numpy as np
import os
from openai import OpenAI
import pickle
import faiss
import json # 用於解析 JSON
# Initialize OpenAI client
# API key is read from the environment; calls will fail later if it is unset.
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
# Load Excel data
# Product catalog, one row per product — change the path to your xlsx file.
df = pd.read_excel("product_list.xlsx")
# Initialize FAISS Index
def initialize_embeddings_from_pkl(pkl_path: str, faiss_path: str):
    """Load precomputed embeddings from a pickle file and the FAISS index from disk.

    Returns a ``(index, embeddings)`` tuple.
    Raises FileNotFoundError when either file is missing.
    """
    # Fail fast with a clear message if either artifact is absent.
    for path, label in ((pkl_path, "Embedding"), (faiss_path, "FAISS index")):
        if not os.path.exists(path):
            raise FileNotFoundError(f"{label} file not found at {path}")

    print(f"Loading embeddings from {pkl_path}...")
    with open(pkl_path, "rb") as fh:
        loaded_embeddings = pickle.load(fh)

    print(f"Loading FAISS index from {faiss_path}...")
    loaded_index = faiss.read_index(faiss_path)
    print(f"FAISS index loaded with {loaded_index.ntotal} embeddings.")
    return loaded_index, loaded_embeddings
# Paths to the FAISS index and the pickled embeddings — edit to match your files.
faiss_path = "product_index.faiss"
pkl_path = "product_embeddings.pkl"

# Load the index and embeddings once at startup.
print("Initializing embeddings...")
faiss_index, product_embeddings_array = initialize_embeddings_from_pkl(pkl_path, faiss_path)

# Explicit check instead of `assert`: asserts are stripped under `python -O`,
# which would silently skip this index/catalog consistency validation.
if faiss_index.ntotal != len(df):
    raise ValueError("FAISS 索引與 xlsx 文件的行數不一致!")

print(f"Loaded embeddings with shape: {len(product_embeddings_array)} x {len(product_embeddings_array[0])}")
print("Embeddings initialized")
# Generate embeddings for query
def get_embedding(text: str, model="text-embedding-ada-002"):
    """Return the embedding vector for *text* via OpenAI's API, or None on any failure."""
    try:
        # Newlines can degrade embedding quality; flatten to a single line first.
        flattened = text.replace("\n", " ")
        return client.embeddings.create(input=[flattened], model=model).data[0].embedding
    except Exception as e:
        # Best-effort: log and signal failure with None (callers must handle it).
        print(f"Error getting embedding: {e}")
        return None
# Find similar products
def find_similar_products(query_embedding, top_k=8):
    """Find the *top_k* most similar products using the FAISS index.

    Args:
        query_embedding: embedding vector for the query (sequence of floats).
            Must not be None — ``get_embedding`` returns None on API failure.
        top_k: number of nearest products to return.

    Returns:
        A copy of the matching rows from ``df`` with an added ``similarity``
        column holding the FAISS distances.

    Raises:
        ValueError: if the embedding is missing or the index is not loaded.
    """
    # get_embedding() returns None on failure; fail fast with a clear message
    # instead of crashing deep inside numpy/FAISS.
    if query_embedding is None:
        raise ValueError("query_embedding is None (embedding generation failed)")
    if faiss_index is None:
        raise ValueError("FAISS index is not loaded.")
    # FAISS expects a 2-D float32 array of shape (n_queries, dim).
    query_vec = np.array(query_embedding).astype('float32').reshape(1, -1)
    distances, indices = faiss_index.search(query_vec, top_k)
    # Pull the matching catalog rows and attach the distance scores.
    matching_products = df.iloc[indices[0]].copy()
    matching_products["similarity"] = distances[0]
    return matching_products
# JSON schema handed to the Chat Completions structured-output API via
# response_format={"type": "json_schema", ...}; it forces the model to return
# an object of the form {"analysis": str, "ingredients": [str, ...]}.
# Description strings are runtime data sent to the model (kept as-is).
json_schema = {
    "name": "CookingIngredientsSchema",
    "description": "Extract cooking ingredients and analysis from user query.",
    "strict": True,  # strict mode: output must match the schema exactly
    "schema": {
        "type": "object",
        "properties": {
            "analysis": {
                "type": "string",
                "description": "完整的需求分析,解釋用戶的目標和需要的物品。"
            },
            "ingredients": {
                "type": "array",
                "description": "提取的食材或關鍵物品清單。",
                "items": {
                    "type": "string",
                    "description": "單一食材或關鍵物品名稱。"
                }
            }
        },
        "required": ["analysis", "ingredients"],
        "additionalProperties": False
    }
}
# Analyze query and find products
def analyze_query_and_find_products(query: str) -> str:
    """Analyze a shopping query with GPT, then recommend matching products.

    Pipeline:
      1. Ask the chat model (structured output) for an analysis and a list of
         key ingredients/items.
      2. Embed each extracted item and search the FAISS index for matches.
      3. Format everything into one Traditional-Chinese response string.

    Returns a user-facing message; exceptions are caught and reported as text
    so the Gradio UI never surfaces a traceback.
    """
    if not query.strip():
        return "請輸入您的問題或搜尋需求"
    try:
        # Step 1: structured analysis of the user's intent.
        analysis_messages = [
            {"role": "system", "content": f"""You are a knowledgeable shopping assistant.
When given a query:
1. Analyze what the user is looking for
2. Predict what user will need in a supermarket
Provide your analysis in Traditional Chinese, focusing on understanding user needs."""},
            {"role": "user", "content": f"Analyze this query and explain what the user needs: {query}"}
        ]
        analysis_response = client.chat.completions.create(
            model="gpt-4o",
            messages=analysis_messages,
            temperature=0.7,
            response_format={
                "type": "json_schema",
                "json_schema": json_schema
            }
        )
        analysis = analysis_response.choices[0].message.content
        analysis_json = json.loads(analysis)
        description = analysis_json["analysis"]      # human-readable needs analysis
        ingredients = analysis_json["ingredients"]   # extracted item keywords

        print("=======")
        print("關鍵字陣列:")
        print(ingredients)
        print("=======")

        # Step 2: embed each extracted item and collect its nearest products.
        # (The previous version also embedded and searched the full query here,
        # but that result was never shown — removed to save an API call.)
        search_results = []
        for item in ingredients:
            print(f"正在搜尋:{item}")
            query_embedding = get_embedding(item)
            matching_products = find_similar_products(query_embedding)
            search_results.append((item, matching_products))

        # Step 3: build the response. Show the readable analysis text, not the
        # raw JSON payload returned by the model.
        response_parts = [
            "🔍 需求分析:",
            description,
            "\n📦 相關商品推薦:\n"
        ]
        for item, matching_products in search_results:
            response_parts.append(f"### {item} 的推薦商品:")
            for _, row in matching_products.iterrows():
                product_info = f"""
{row['item_name']}
ID: {row['item_id']}
描述: {row['description']}
分類: {row['tags']}
相似度: {row['similarity']:.2f}"""
                response_parts.append(product_info)
        response_parts.append("\n💡 購物建議:")
        response_parts.append("根據您的需求,以上是推薦的商品!")
        return "\n".join(response_parts)
    except Exception as e:
        # Boundary handler: report the failure as text for the UI.
        print(f"Error in search: {str(e)}")
        return f"搜尋發生錯誤: {str(e)}"
# Get system status
def get_system_status():
    """Report whether the FAISS index is loaded and how many products/embeddings exist."""
    index_ready = faiss_index is not None
    return {
        "embeddings_loaded": index_ready,
        "embedding_count": faiss_index.ntotal if index_ready else 0,
        "product_count": len(df),
    }
# Gradio Interface
# UI layout: header markdown, a status panel, then the query/result widgets.
# User-facing strings are Traditional Chinese and are left untouched.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🛒 智慧商品推薦系統
        輸入您的問題或需求,系統會:
        1. 分析您的需求
        2. 推薦相關商品
        3. 提供實用建議
        """
    )
    # System status
    # Computed once at app start, not refreshed afterwards.
    with gr.Row():
        status = get_system_status()
        status_md = f"""
        ### 系統狀態:
        - 資料庫商品數:{status['product_count']}
        - 向量嵌入狀態:{'✅ 已載入' if status['embeddings_loaded'] else '❌ 未載入'}
        """
        gr.Markdown(status_md)
    # Main interface
    with gr.Column():
        input_text = gr.Textbox(
            label="請輸入您的問題或需求",
            placeholder="例如:需要適合便當的食材",
            lines=3
        )
        output_text = gr.Textbox(
            label="分析結果與建議",
            lines=20
        )
        submit_btn = gr.Button("搜尋")
        # Wire the search button to the analysis pipeline.
        submit_btn.click(
            fn=analyze_query_and_find_products,
            inputs=input_text,
            outputs=output_text
        )
    gr.Markdown("--- 系統使用 AI 分析需求並推薦商品。")
# Launch the app (blocks until the server is stopped).
demo.launch()