# app.py — AI product recommendation demo (Gradio UI + OpenAI embeddings, HF Hub embedding cache)
import logging
logging.basicConfig(level=logging.DEBUG)
import gradio as gr
import pandas as pd
import numpy as np
import os
from openai import OpenAI
from typing import List, Dict
import pickle
import time
from sklearn.metrics.pairwise import cosine_similarity
from huggingface_hub import HfApi, hf_hub_download, upload_file
from pathlib import Path
# --- External services & data -------------------------------------------
# OpenAI client used for chat completions (gpt-4o) and text embeddings.
# NOTE(review): assumes OPENAI_API_KEY is set in the environment — confirm deployment config.
client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
# Hugging Face configuration for the embedding-cache repo.
HF_TOKEN = os.environ.get("HF_TOKEN")
REPO_ID = os.environ.get("REPO_ID") # format: "username/space-name"
# Name of the pickled embedding cache stored in the HF repo.
EMBEDDING_FILE = "product_embeddings.pkl"
# Hub client used for uploads (downloads go through hf_hub_download).
hf_api = HfApi(token=HF_TOKEN)
# Product catalogue; code below reads item_name / description / tags / item_id columns.
df = pd.read_csv("item_new.csv", encoding='utf-8')
def create_product_text(row):
    """Build the single text string that represents one product row.

    Concatenates name, free-text description, and tags so the combined
    string can be embedded as one document for semantic search.
    (Removed the stale commented-out variant that referenced columns
    not present in the current CSV schema.)
    """
    return f"{row['item_name']} {row['description']} {row['tags']}"
def get_embedding(text: str, model="text-embedding-3-small"):
    """Return the embedding vector for *text* via OpenAI, or None on failure.

    Newlines are flattened to spaces before embedding; any API error is
    logged and swallowed so callers can fall back gracefully.
    """
    try:
        cleaned = text.replace("\n", " ")
        result = client.embeddings.create(input=[cleaned], model=model)
        return result.data[0].embedding
    except Exception as exc:
        print(f"Error getting embedding: {exc}")
        return None
def download_embeddings():
    """Fetch the cached embedding pickle from the HF Hub.

    Returns the unpickled object, or None when the download or the
    unpickling fails (e.g. first run, missing file, bad token).
    NOTE(review): pickle.load is only safe because the repo is our own;
    confirm REPO_ID is never user-controlled.
    """
    try:
        cached_path = hf_hub_download(
            repo_id=REPO_ID,
            filename=EMBEDDING_FILE,
            token=HF_TOKEN,
        )
        with open(cached_path, 'rb') as fh:
            return pickle.load(fh)
    except Exception as exc:
        print(f"Error downloading embeddings: {exc}")
        return None
def upload_embeddings(embeddings):
    """Persist *embeddings* to the HF Hub as EMBEDDING_FILE.

    Serializes to a local scratch file, uploads it, and always deletes
    the scratch file — the previous version only removed it on success,
    leaking temp_embeddings.pkl whenever the upload raised.
    Errors are logged, not raised (best-effort cache write).
    """
    temp_path = "temp_embeddings.pkl"
    try:
        # Save embeddings locally first
        with open(temp_path, 'wb') as f:
            pickle.dump(embeddings, f)
        # Upload to Hugging Face
        hf_api.upload_file(
            path_or_fileobj=temp_path,
            path_in_repo=EMBEDDING_FILE,
            repo_id=REPO_ID,
            token=HF_TOKEN,
        )
        print("Successfully uploaded embeddings")
    except Exception as e:
        print(f"Error uploading embeddings: {e}")
    finally:
        # Clean up the scratch file even when dump/upload failed.
        if os.path.exists(temp_path):
            os.remove(temp_path)
def initialize_embeddings():
    """Load cached product embeddings or build them from the catalogue.

    Returns a list with exactly one embedding (list[float]) per row of
    *df*, in row order — find_similar_products relies on this alignment.
    A cached pickle is reused only when its length matches the current
    catalogue; a stale cache (catalogue grew/shrank) is regenerated,
    which the previous version silently accepted.
    """
    embedding_dim = 1536  # output size of text-embedding-3-small
    print("Checking for existing embeddings...")
    cached = download_embeddings()
    if cached is not None:
        if len(cached) == len(df):
            print("Loaded existing embeddings")
            return cached
        print("Cached embeddings out of sync with catalogue; regenerating...")
    print("Creating new embeddings...")
    embeddings = []
    for _, row in df.iterrows():
        vector = get_embedding(create_product_text(row))
        # Zero vector on API failure keeps indices aligned with df rows.
        embeddings.append(vector if vector else [0] * embedding_dim)
        time.sleep(0.1)  # crude rate limiting for the embeddings API
    # Upload new embeddings so the next start can skip regeneration.
    upload_embeddings(embeddings)
    return embeddings
# Build (or fetch) the embedding matrix once at import time; the Gradio
# handlers below read these module-level globals.
print("Initializing embeddings...")
product_embeddings = initialize_embeddings()
# Dense 2-D array consumed by cosine_similarity in find_similar_products.
product_embeddings_array = np.array(product_embeddings)
print("Embeddings initialized")
def find_similar_products(query_embedding, top_k=8):
    """Return the *top_k* catalogue rows closest to *query_embedding*.

    Computes cosine similarity between the query vector and every row of
    the module-level embedding matrix, then returns a (DataFrame,
    ndarray) pair — the matching rows of *df* and their scores, both
    ordered best-first.
    """
    scores = cosine_similarity([query_embedding], product_embeddings_array)[0]
    best = np.argsort(scores)[::-1][:top_k]
    return df.iloc[best], scores[best]
# --- Query analysis & recommendation pipeline ---------------------------
def analyze_query_and_find_products(query: str) -> str:
    """Answer a shopping query end-to-end.

    Pipeline:
      1. GPT-4o analyses what the user needs (in Traditional Chinese).
      2. Query + analysis are embedded and matched against the catalogue.
      3. GPT-4o writes shopping tips based on the retrieved products.

    Returns a formatted Traditional-Chinese response string; failures are
    reported to the user in the return value rather than raised.
    Fixes: guards against get_embedding() returning None (previously
    crashed inside similarity search and surfaced as a generic error).
    """
    if not query.strip():
        return "請輸入您的問題或搜尋需求"
    try:
        # Step 1: understand the user's intent.
        analysis_messages = [
            {"role": "system", "content": """You are a knowledgeable shopping assistant.
When given a query:
1. Analyze what the user is looking for
2. Predict what user will need in a supermarket
Provide your analysis in Traditional Chinese, focusing on understanding user needs."""},
            {"role": "user", "content": f"Analyze this query and explain what the user needs: {query}"}
        ]
        analysis_response = client.chat.completions.create(
            model="gpt-4o",
            messages=analysis_messages,
            temperature=0.7,
            max_tokens=500,
        )
        analysis = analysis_response.choices[0].message.content
        print("Embedding Query的是:" + analysis)
        # Step 2: embed query + analysis and retrieve similar products.
        query_embedding = get_embedding(query + " " + analysis)
        if query_embedding is None:
            # Embedding API failed; return a user-facing message instead
            # of crashing inside cosine similarity.
            return "搜尋發生錯誤: 無法取得查詢向量,請稍後再試"
        matching_products, similarities = find_similar_products(query_embedding)
        print(f"Found {len(matching_products)} matching products")
        # Step 3: ask the model for advice grounded in the hits.
        product_descriptions = "\n".join(
            f"- {row['item_id']} ({row['item_name']})"
            for _, row in matching_products.iterrows()
        )
        recommendation_messages = [
            {"role": "system", "content": """Based on the query and available products,
provide helpful recommendations and tips. Consider:
1. How the products can be used
2. What to look for when choosing
3. Alternative options if available
Respond in Traditional Chinese."""},
            {"role": "user", "content": f"""Query: {query}
Analysis: {analysis}
Available products: {product_descriptions}"""}
        ]
        recommendation_response = client.chat.completions.create(
            model="gpt-4o",
            messages=recommendation_messages,
            temperature=0.7,
            max_tokens=250,
        )
        # Assemble the final answer: analysis, product list, advice.
        response_parts = [
            "🔍 需求分析:",
            analysis,
            "\n📦 相關商品推薦:\n"
        ]
        for (_, product), similarity in zip(matching_products.iterrows(), similarities):
            confidence = similarity * 100
            response_parts.append(f"""
{product['item_name']}
分類: {product['tags']}
規格: {product['description']}
相關度: {confidence:.1f}%""")
        response_parts.extend([
            "\n💡 購物建議:",
            recommendation_response.choices[0].message.content
        ])
        return "\n".join(response_parts)
    except Exception as e:
        print(f"Error in search: {str(e)}")
        return f"搜尋發生錯誤: {str(e)}"
# System status helper used by the UI header.
def get_system_status():
    """Summarize readiness for the status panel.

    Reports whether the module-level embeddings loaded, how many there
    are, and how many products the catalogue holds.
    """
    has_embeddings = product_embeddings is not None
    count = len(product_embeddings) if product_embeddings else 0
    return {
        "embeddings_loaded": has_embeddings,
        "embedding_count": count,
        "product_count": len(df),
    }
# Gradio UI: status panel, query box, search/clear buttons, examples.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🛒 智慧商品推薦系統
        輸入您的問題或需求,系統會:
        1. 分析您的需求
        2. 推薦相關商品
        3. 提供實用建議
        """
    )
    # System status panel (computed once at UI build time).
    with gr.Row():
        status = get_system_status()
        status_md = f"""
        ### 系統狀態:
        - 資料庫商品數:{status['product_count']}
        - 向量嵌入狀態:{'✅ 已載入' if status['embeddings_loaded'] else '❌ 未載入'}
        """
        gr.Markdown(status_md)
    # Main interface
    with gr.Column():
        input_text = gr.Textbox(
            label="請輸入您的問題或需求",
            placeholder="您可以詢問任何商品相關的問題,例如:\n- 想找一些適合做便當的食材\n- 需要營養均衡的食材\n- 想買一些新鮮的海鮮\n- 有什麼適合老人家的食物",
            lines=3
        )
        with gr.Row():
            submit_btn = gr.Button("搜尋", variant="primary")
            clear_btn = gr.Button("清除")
        output_text = gr.Textbox(
            label="分析結果與建議",
            lines=25
        )

    def clear_inputs():
        # BUG FIX: Gradio maps positional return values onto the
        # components listed in *outputs*; the previous version returned
        # a dict keyed by strings, so the Clear button never cleared
        # the textboxes.
        return "", ""

    submit_btn.click(
        fn=analyze_query_and_find_products,
        inputs=input_text,
        outputs=output_text,
        api_name="search"  # This enables API access
    )
    clear_btn.click(
        fn=clear_inputs,
        inputs=[],
        outputs=[input_text, output_text],
        api_name="clear"
    )

    # Examples section
    gr.Markdown("### 搜尋範例")
    with gr.Row():
        examples = gr.Examples(
            examples=[
                ["想找一些適合做便當的食材"],
                ["需要一些營養均衡的食物"],
                ["有沒有適合老人家吃的食物?"],
                ["想買一些新鮮的海鮮,有什麼推薦?"],
                ["最近感冒了,有什麼食材可以幫助恢復?"],
            ],
            inputs=input_text,
            outputs=output_text,
            fn=analyze_query_and_find_products,
            # NOTE(review): cache_examples=True runs every example
            # (OpenAI calls) at startup — confirm that cost is intended.
            cache_examples=True
        )

    # Footer
    gr.Markdown(
        """
        ---
        系統使用AI語意分析技術,能更好地理解您的需求並提供相關建議。
        如有任何問題或建議,歡迎反饋。
        """
    )

# Launch the app
demo.launch()