Zaious commited on
Commit
33ef8ae
·
verified ·
1 Parent(s): f4cbb4b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -97
app.py CHANGED
@@ -11,115 +11,50 @@ import time
11
  from sklearn.metrics.pairwise import cosine_similarity
12
  from huggingface_hub import HfApi, hf_hub_download, upload_file
13
  from pathlib import Path
 
14
 
15
  # Initialize OpenAI client
16
  client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
17
 
18
- # Hugging Face configuration
19
- HF_TOKEN = os.environ.get("HF_TOKEN")
20
- REPO_ID = os.environ.get("REPO_ID") # format: "username/space-name"
21
- EMBEDDING_FILE = "product_embeddings.pkl"
22
-
23
- # Initialize Hugging Face API
24
- hf_api = HfApi(token=HF_TOKEN)
 
 
 
 
 
 
25
 
26
  # Load CSV data
27
  df = pd.read_csv("item_new.csv", encoding='utf-8')
28
 
29
- def create_product_text(row):
30
- """Create a comprehensive text representation of a product"""
31
- #return f"{row['item_desc']} {row['item_class1_desc']} {row['item_class2_desc']} {row['item_class3_desc']} {str(row['brand'])} {str(row['spec'])}"
32
- return f"{row['item_name']} {row['description']} {row['tags']}"
33
-
34
- def get_embedding(text: str, model="text-embedding-3-small"):
35
- """Get embeddings for a text using OpenAI's API"""
36
- try:
37
- text = text.replace("\n", " ")
38
- response = client.embeddings.create(
39
- input=[text],
40
- model=model
41
- )
42
- return response.data[0].embedding
43
- except Exception as e:
44
- print(f"Error getting embedding: {e}")
45
- return None
46
-
47
- def download_embeddings():
48
- """Try to download embeddings from Hugging Face"""
49
- try:
50
- local_path = hf_hub_download(
51
- repo_id=REPO_ID,
52
- filename=EMBEDDING_FILE,
53
- token=HF_TOKEN
54
- )
55
- with open(local_path, 'rb') as f:
56
- return pickle.load(f)
57
- except Exception as e:
58
- print(f"Error downloading embeddings: {e}")
59
- return None
60
-
61
- def upload_embeddings(embeddings):
62
- """Upload embeddings to Hugging Face"""
63
- try:
64
- # Save embeddings locally first
65
- temp_path = "temp_embeddings.pkl"
66
- with open(temp_path, 'wb') as f:
67
- pickle.dump(embeddings, f)
68
-
69
- # Upload to Hugging Face
70
- hf_api.upload_file(
71
- path_or_fileobj=temp_path,
72
- path_in_repo=EMBEDDING_FILE,
73
- repo_id=REPO_ID,
74
- token=HF_TOKEN
75
- )
76
-
77
- # Clean up temp file
78
- os.remove(temp_path)
79
- print("Successfully uploaded embeddings")
80
- except Exception as e:
81
- print(f"Error uploading embeddings: {e}")
82
 
83
- def initialize_embeddings():
84
- """Initialize or load product embeddings"""
85
- print("Checking for existing embeddings...")
86
- embeddings = download_embeddings()
87
-
88
- if embeddings is not None:
89
- print("Loaded existing embeddings")
90
- return embeddings
91
-
92
- print("Creating new embeddings...")
93
- embeddings = []
94
- for idx, row in df.iterrows():
95
- product_text = create_product_text(row)
96
- embedding = get_embedding(product_text)
97
- if embedding:
98
- embeddings.append(embedding)
99
- else:
100
- embeddings.append([0] * 1536) # Default embedding dimension
101
- time.sleep(0.1) # Rate limiting for API calls
102
-
103
- # Upload new embeddings
104
- upload_embeddings(embeddings)
105
-
106
- return embeddings
107
-
108
- # Load or create embeddings
109
  print("Initializing embeddings...")
110
- product_embeddings = initialize_embeddings()
111
- product_embeddings_array = np.array(product_embeddings)
112
  print("Embeddings initialized")
113
 
 
114
  def find_similar_products(query_embedding, top_k=8):
115
- """Find most similar products using cosine similarity"""
116
- similarities = cosine_similarity(
117
- [query_embedding],
118
- product_embeddings_array
119
- )[0]
 
 
 
 
120
 
121
- top_indices = similarities.argsort()[-top_k:][::-1]
122
- return df.iloc[top_indices], similarities[top_indices]
 
123
 
124
  # Rest of the code remains the same...
125
  def analyze_query_and_find_products(query: str) -> str:
@@ -209,9 +144,11 @@ def analyze_query_and_find_products(query: str) -> str:
209
  # Add system status message
210
  def get_system_status():
211
  """Get system initialization status"""
 
 
212
  return {
213
- "embeddings_loaded": product_embeddings is not None,
214
- "embedding_count": len(product_embeddings) if product_embeddings else 0,
215
  "product_count": len(df)
216
  }
217
 
 
11
  from sklearn.metrics.pairwise import cosine_similarity
12
  from huggingface_hub import HfApi, hf_hub_download, upload_file
13
  from pathlib import Path
14
+ import faiss
15
 
16
  # Initialize OpenAI client
17
  client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
18
 
19
+ def initialize_embeddings_from_faiss(faiss_path: str):
20
+ """Load product embeddings directly from FAISS index"""
21
+ if not os.path.exists(faiss_path):
22
+ raise FileNotFoundError(f"FAISS index file not found at {faiss_path}")
23
+
24
+ print(f"Loading FAISS index from {faiss_path}...")
25
+ index = faiss.read_index(faiss_path)
26
+
27
+ # Extract embeddings from FAISS index
28
+ product_embeddings_array = faiss.vector_to_array(index.xb).reshape(index.ntotal, index.d)
29
+ print(f"FAISS index loaded with {index.ntotal} embeddings.")
30
+
31
+ return index, product_embeddings_array
32
 
33
  # Load CSV data
34
  df = pd.read_csv("item_new.csv", encoding='utf-8')
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
+ # Load embeddings from FAISS
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  print("Initializing embeddings...")
39
+ faiss_path = "product_index.faiss" # Path to FAISS index file
40
+ faiss_index, product_embeddings_array = initialize_embeddings_from_faiss(faiss_path)
41
  print("Embeddings initialized")
42
 
43
+
44
  def find_similar_products(query_embedding, top_k=8):
45
+ """Find most similar products using FAISS index"""
46
+ if faiss_index is None:
47
+ raise ValueError("FAISS index is not loaded.")
48
+
49
+ # FAISS expects float32 type embeddings
50
+ query_embedding = np.array(query_embedding).astype('float32').reshape(1, -1)
51
+
52
+ # Perform FAISS search
53
+ distances, indices = faiss_index.search(query_embedding, top_k)
54
 
55
+ # Retrieve matching products
56
+ matching_products = df.iloc[indices[0]]
57
+ return matching_products, distances[0]
58
 
59
  # Rest of the code remains the same...
60
  def analyze_query_and_find_products(query: str) -> str:
 
144
  # Add system status message
145
  def get_system_status():
146
  """Get system initialization status"""
147
+ embeddings_loaded = faiss_index is not None
148
+ embedding_count = faiss_index.ntotal if embeddings_loaded else 0
149
  return {
150
+ "embeddings_loaded": embeddings_loaded,
151
+ "embedding_count": embedding_count,
152
  "product_count": len(df)
153
  }
154