vichudo commited on
Commit
254ca68
·
1 Parent(s): 0560e17
download_from_hub.py CHANGED
@@ -4,6 +4,8 @@ import pickle
4
  import sys
5
  import numpy as np
6
  from huggingface_hub import hf_hub_download, list_repo_files
 
 
7
 
8
  def ensure_dirs():
9
  """Create necessary directories if they don't exist."""
@@ -11,6 +13,64 @@ def ensure_dirs():
11
  os.makedirs("embeddings", exist_ok=True)
12
  os.makedirs("pdfs", exist_ok=True)
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  def create_fallback_data():
15
  """Create minimal fallback data if downloads fail."""
16
  print("Creating fallback data files...")
@@ -18,20 +78,22 @@ def create_fallback_data():
18
  # Create minimal embeddings
19
  try:
20
  print("Creating fallback embeddings...")
21
- # Create a small random matrix as embeddings (10 documents, 384 dimensions)
22
- embeddings = np.random.random((10, 384)).astype(np.float32)
 
 
23
  with open("embeddings/embeddings.pkl", "wb") as f:
24
  pickle.dump(embeddings, f)
25
 
26
- # Create a minimal FAISS index
27
  import faiss
28
- dimension = 384
29
  index = faiss.IndexFlatL2(dimension)
30
  index.add(embeddings)
31
  faiss.write_index(index, "embeddings/faiss_index.index")
32
  print("Fallback embeddings and FAISS index created successfully!")
33
  except Exception as e:
34
  print(f"Error creating fallback embeddings: {e}")
 
35
  return False
36
 
37
  # Create minimal document chunks
@@ -51,47 +113,27 @@ def create_fallback_data():
51
  print("Fallback document chunks created successfully!")
52
  except Exception as e:
53
  print(f"Error creating fallback document chunks: {e}")
 
54
  return False
55
 
56
- return True
 
57
 
58
  def download_datasets():
59
  """Download datasets from Hugging Face Hub."""
60
  print("Downloading data files from Hugging Face Hub...")
61
  download_success = True
62
 
63
- # Download embeddings
64
- try:
65
- from datasets import load_dataset
66
- print("Downloading embeddings...")
67
- # First check what files are available in the dataset repository
68
- try:
69
- files = list_repo_files("vichudo/agentic-defensor-embeddings", repo_type="dataset")
70
- print(f"Files in embeddings repository: {files}")
71
- except Exception as e:
72
- print(f"Error listing files in embeddings repository: {e}")
73
-
74
- embeddings_ds = load_dataset("vichudo/agentic-defensor-embeddings", split="train")
75
- print(f"Embeddings dataset info: {embeddings_ds}")
76
- print(f"Embeddings dataset features: {embeddings_ds.features}")
77
- print(f"First row of embeddings dataset: {embeddings_ds[0]}")
78
-
79
- if "data" not in embeddings_ds[0]:
80
- print("Error: No 'data' field found in embeddings dataset")
81
- print(f"Available fields: {embeddings_ds[0].keys()}")
82
- download_success = False
83
- else:
84
- embeddings_data = pickle.loads(embeddings_ds[0]["data"])
85
- with open("embeddings/embeddings.pkl", "wb") as f:
86
- pickle.dump(embeddings_data, f)
87
- print("Embeddings downloaded and saved successfully!")
88
- except Exception as e:
89
- print(f"Error downloading embeddings: {e}")
90
- download_success = False
91
 
92
- # Download FAISS index
93
  try:
94
- print("Downloading FAISS index...")
95
  # Try direct file download
96
  try:
97
  faiss_path = hf_hub_download(
@@ -99,9 +141,15 @@ def download_datasets():
99
  filename="faiss_index.index",
100
  repo_type="dataset"
101
  )
102
- # Copy to correct location
103
- os.system(f"cp {faiss_path} embeddings/faiss_index.index")
104
- print("FAISS index downloaded and saved successfully!")
 
 
 
 
 
 
105
  except Exception as e:
106
  print(f"Direct download of FAISS index failed: {e}")
107
  # Try alternate approach using dataset API
@@ -113,44 +161,278 @@ def download_datasets():
113
  import faiss
114
  faiss.write_index(embeddings_ds.faiss_index, "embeddings/faiss_index.index")
115
  print("FAISS index from dataset attributes saved successfully!")
 
116
  else:
117
- raise ValueError("No FAISS index found in dataset attributes")
118
  except Exception as inner_e:
119
  print(f"Alternative FAISS index download failed: {inner_e}")
120
- raise
121
  except Exception as e:
122
  print(f"Error downloading FAISS index: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  download_success = False
124
 
125
  # Download document chunks
126
  try:
127
- from datasets import load_dataset
128
- print("Downloading document chunks...")
129
  # First check what files are available
130
  try:
131
  files = list_repo_files("vichudo/agentic-defensor-chunks", repo_type="dataset")
132
  print(f"Files in chunks repository: {files}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  except Exception as e:
134
  print(f"Error listing files in chunks repository: {e}")
135
-
136
- chunks_ds = load_dataset("vichudo/agentic-defensor-chunks", split="train")
137
- print(f"Chunks dataset info: {chunks_ds}")
138
- print(f"Chunks dataset features: {chunks_ds.features}")
139
- print(f"First row of chunks dataset: {chunks_ds[0]}")
140
 
141
- if "data" not in chunks_ds[0]:
142
- print("Error: No 'data' field found in chunks dataset")
143
- print(f"Available fields: {chunks_ds[0].keys()}")
144
- download_success = False
145
- else:
146
- chunks_data = pickle.loads(chunks_ds[0]["data"])
147
- with open("data/doc_chunks.pkl", "wb") as f:
148
- pickle.dump(chunks_data, f)
149
- print("Document chunks downloaded and saved successfully!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  except Exception as e:
151
  print(f"Error downloading document chunks: {e}")
 
152
  download_success = False
153
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  return download_success
155
 
156
  if __name__ == "__main__":
@@ -159,12 +441,33 @@ if __name__ == "__main__":
159
 
160
  # If download fails, create fallback data
161
  if not success:
162
- print("Downloads failed. Creating fallback data...")
163
  success = create_fallback_data()
164
 
165
  if success:
166
- print("Data files setup completed successfully!")
167
- sys.exit(0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  else:
169
- print("Failed to set up data files.")
170
  sys.exit(1)
 
4
  import sys
5
  import numpy as np
6
  from huggingface_hub import hf_hub_download, list_repo_files
7
+ import traceback
8
+ import shutil
9
 
10
  def ensure_dirs():
11
  """Create necessary directories if they don't exist."""
 
13
  os.makedirs("embeddings", exist_ok=True)
14
  os.makedirs("pdfs", exist_ok=True)
15
 
16
+ def verify_embeddings_faiss_compatibility():
17
+ """Verify that the downloaded embeddings and FAISS index are compatible."""
18
+ print("Verifying compatibility between embeddings and FAISS index...")
19
+
20
+ try:
21
+ # Check if files exist
22
+ if not os.path.exists("embeddings/embeddings.pkl"):
23
+ print("Error: embeddings.pkl does not exist")
24
+ return False
25
+
26
+ if not os.path.exists("embeddings/faiss_index.index"):
27
+ print("Error: faiss_index.index does not exist")
28
+ return False
29
+
30
+ # Load embeddings
31
+ with open("embeddings/embeddings.pkl", "rb") as f:
32
+ embeddings = pickle.load(f)
33
+
34
+ print(f"Loaded embeddings with shape: {embeddings.shape if hasattr(embeddings, 'shape') else 'Unknown'}")
35
+
36
+ # Load FAISS index and check compatibility
37
+ import faiss
38
+ index = faiss.read_index("embeddings/faiss_index.index")
39
+
40
+ # Print FAISS index stats
41
+ print(f"FAISS index contains {index.ntotal} vectors of dimension {index.d}")
42
+
43
+ # Check if the dimensionality matches
44
+ if hasattr(embeddings, 'shape'):
45
+ if len(embeddings.shape) != 2:
46
+ print(f"Warning: embeddings should be a 2D array, got shape {embeddings.shape}")
47
+ return False
48
+
49
+ if embeddings.shape[1] != index.d:
50
+ print(f"Error: Dimension mismatch - embeddings: {embeddings.shape[1]}, FAISS index: {index.d}")
51
+ return False
52
+
53
+ # Check if number of vectors matches
54
+ if embeddings.shape[0] != index.ntotal:
55
+ print(f"Warning: Count mismatch - embeddings: {embeddings.shape[0]}, FAISS index: {index.ntotal}")
56
+ print("This might be acceptable if the index was created from a subset of embeddings")
57
+
58
+ # Test a simple query to ensure the index works
59
+ try:
60
+ test_query = np.zeros((1, index.d), dtype=np.float32)
61
+ D, I = index.search(test_query, 1)
62
+ print("FAISS index test query successful")
63
+ return True
64
+ except Exception as e:
65
+ print(f"FAISS index test query failed: {e}")
66
+ traceback.print_exc()
67
+ return False
68
+
69
+ except Exception as e:
70
+ print(f"Compatibility verification failed: {e}")
71
+ traceback.print_exc()
72
+ return False
73
+
74
  def create_fallback_data():
75
  """Create minimal fallback data if downloads fail."""
76
  print("Creating fallback data files...")
 
78
  # Create minimal embeddings
79
  try:
80
  print("Creating fallback embeddings...")
81
+
82
+ # Create a small random matrix as embeddings (10 documents, 1536 dimensions - OpenAI dimension)
83
+ dimension = 1536 # text-embedding-3-small dimension
84
+ embeddings = np.random.random((10, dimension)).astype(np.float32)
85
  with open("embeddings/embeddings.pkl", "wb") as f:
86
  pickle.dump(embeddings, f)
87
 
88
+ # Create a minimal FAISS index with same dimension
89
  import faiss
 
90
  index = faiss.IndexFlatL2(dimension)
91
  index.add(embeddings)
92
  faiss.write_index(index, "embeddings/faiss_index.index")
93
  print("Fallback embeddings and FAISS index created successfully!")
94
  except Exception as e:
95
  print(f"Error creating fallback embeddings: {e}")
96
+ traceback.print_exc()
97
  return False
98
 
99
  # Create minimal document chunks
 
113
  print("Fallback document chunks created successfully!")
114
  except Exception as e:
115
  print(f"Error creating fallback document chunks: {e}")
116
+ traceback.print_exc()
117
  return False
118
 
119
+ # Verify compatibility
120
+ return verify_embeddings_faiss_compatibility()
121
 
122
  def download_datasets():
123
  """Download datasets from Hugging Face Hub."""
124
  print("Downloading data files from Hugging Face Hub...")
125
  download_success = True
126
 
127
+ # Track what we've downloaded
128
+ faiss_downloaded = False
129
+ embeddings_downloaded = False
130
+ chunks_downloaded = False
131
+
132
+ # Try multiple download methods
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
 
134
+ # Download FAISS index first
135
  try:
136
+ print("\nDownloading FAISS index...")
137
  # Try direct file download
138
  try:
139
  faiss_path = hf_hub_download(
 
141
  filename="faiss_index.index",
142
  repo_type="dataset"
143
  )
144
+ # Copy to correct location with error handling
145
+ if os.path.exists(faiss_path):
146
+ shutil.copy(faiss_path, "embeddings/faiss_index.index")
147
+ print(f"FAISS index downloaded and saved successfully to embeddings/faiss_index.index!")
148
+ print(f"FAISS index size: {os.path.getsize('embeddings/faiss_index.index') / (1024*1024):.2f} MB")
149
+ faiss_downloaded = True
150
+ else:
151
+ print(f"Downloaded FAISS path {faiss_path} does not exist")
152
+
153
  except Exception as e:
154
  print(f"Direct download of FAISS index failed: {e}")
155
  # Try alternate approach using dataset API
 
161
  import faiss
162
  faiss.write_index(embeddings_ds.faiss_index, "embeddings/faiss_index.index")
163
  print("FAISS index from dataset attributes saved successfully!")
164
+ faiss_downloaded = True
165
  else:
166
+ print("No FAISS index found in dataset attributes")
167
  except Exception as inner_e:
168
  print(f"Alternative FAISS index download failed: {inner_e}")
 
169
  except Exception as e:
170
  print(f"Error downloading FAISS index: {e}")
171
+ traceback.print_exc()
172
+ download_success = False
173
+
174
+ # Download embeddings
175
+ try:
176
+ print("\nDownloading embeddings...")
177
+ # First check what files are available in the dataset repository
178
+ try:
179
+ files = list_repo_files("vichudo/agentic-defensor-embeddings", repo_type="dataset")
180
+ print(f"Files in embeddings repository: {files}")
181
+
182
+ # Try downloading directly if .pkl file is found
183
+ for file in files:
184
+ if file.endswith("embeddings.pkl") or file.endswith("embeddings.pt") or file.endswith("embeddings.npy"):
185
+ print(f"Found embeddings file: {file}")
186
+ try:
187
+ emb_path = hf_hub_download(
188
+ repo_id="vichudo/agentic-defensor-embeddings",
189
+ filename=file,
190
+ repo_type="dataset"
191
+ )
192
+ # Copy to correct location
193
+ shutil.copy(emb_path, "embeddings/embeddings.pkl")
194
+ print(f"Embeddings downloaded directly from file {file} and saved successfully!")
195
+ embeddings_downloaded = True
196
+ break
197
+ except Exception as file_e:
198
+ print(f"Direct embeddings file download failed: {file_e}")
199
+ except Exception as e:
200
+ print(f"Error listing files in embeddings repository: {e}")
201
+
202
+ # If direct file download failed, try using the dataset API
203
+ if not embeddings_downloaded:
204
+ try:
205
+ from datasets import load_dataset
206
+ import pandas as pd
207
+
208
+ # Try to download the dataset
209
+ embeddings_ds = load_dataset("vichudo/agentic-defensor-embeddings", split="train")
210
+ print(f"Embeddings dataset info: {embeddings_ds}")
211
+ print(f"Embeddings dataset features: {embeddings_ds.features}")
212
+
213
+ # Check first row to understand structure
214
+ if len(embeddings_ds) > 0:
215
+ print(f"First row keys: {embeddings_ds[0].keys()}")
216
+
217
+ # Approach 1: Try to find data blob
218
+ if "data" in embeddings_ds[0]:
219
+ print("Found 'data' blob in dataset")
220
+ embeddings_data = pickle.loads(embeddings_ds[0]["data"])
221
+ with open("embeddings/embeddings.pkl", "wb") as f:
222
+ pickle.dump(embeddings_data, f)
223
+ print("Embeddings from data blob saved successfully!")
224
+ embeddings_downloaded = True
225
+
226
+ # Approach 2: Try to find embedding column
227
+ elif "embedding" in embeddings_ds[0]:
228
+ print("Found 'embedding' column in dataset")
229
+ # Convert dataset to pandas to handle embedding extraction
230
+ df = pd.DataFrame(embeddings_ds)
231
+ embeddings_array = np.stack(df.embedding.values)
232
+ with open("embeddings/embeddings.pkl", "wb") as f:
233
+ pickle.dump(embeddings_array, f)
234
+ print("Embeddings from column data saved successfully!")
235
+ embeddings_downloaded = True
236
+
237
+ # Approach 3: Try to work with parquet files directly
238
+ else:
239
+ try:
240
+ print("Trying to work with parquet files directly")
241
+ import pyarrow.parquet as pq
242
+
243
+ # Find all parquet files in the repository
244
+ parquet_files = [f for f in files if f.endswith('.parquet')]
245
+ if parquet_files:
246
+ print(f"Found parquet files: {parquet_files}")
247
+ for parquet_file in parquet_files:
248
+ try:
249
+ parquet_path = hf_hub_download(
250
+ repo_id="vichudo/agentic-defensor-embeddings",
251
+ filename=parquet_file,
252
+ repo_type="dataset"
253
+ )
254
+
255
+ # Try to read parquet and extract embeddings
256
+ table = pq.read_table(parquet_path)
257
+ df = table.to_pandas()
258
+ print(f"Parquet columns: {df.columns}")
259
+
260
+ if "embedding" in df.columns:
261
+ print("Found 'embedding' column in parquet file")
262
+ embeddings_array = np.stack(df.embedding.values)
263
+ with open("embeddings/embeddings.pkl", "wb") as f:
264
+ pickle.dump(embeddings_array, f)
265
+ print("Embeddings from parquet file saved successfully!")
266
+ embeddings_downloaded = True
267
+ break
268
+ elif "data" in df.columns:
269
+ print("Found 'data' column in parquet file")
270
+ embeddings_data = pickle.loads(df.data.iloc[0])
271
+ with open("embeddings/embeddings.pkl", "wb") as f:
272
+ pickle.dump(embeddings_data, f)
273
+ print("Embeddings data from parquet file saved successfully!")
274
+ embeddings_downloaded = True
275
+ break
276
+ except Exception as parquet_e:
277
+ print(f"Error processing parquet file {parquet_file}: {parquet_e}")
278
+ except Exception as parquet_approach_e:
279
+ print(f"Error in parquet approach: {parquet_approach_e}")
280
+ except Exception as ds_e:
281
+ print(f"Error processing embeddings dataset: {ds_e}")
282
+ traceback.print_exc()
283
+ except Exception as e:
284
+ print(f"Error downloading embeddings: {e}")
285
+ traceback.print_exc()
286
  download_success = False
287
 
288
  # Download document chunks
289
  try:
290
+ print("\nDownloading document chunks...")
 
291
  # First check what files are available
292
  try:
293
  files = list_repo_files("vichudo/agentic-defensor-chunks", repo_type="dataset")
294
  print(f"Files in chunks repository: {files}")
295
+
296
+ # Try direct file download if .pkl file exists
297
+ for file in files:
298
+ if file.endswith("doc_chunks.pkl") or file.endswith("chunks.pkl"):
299
+ print(f"Found chunks file: {file}")
300
+ try:
301
+ chunks_path = hf_hub_download(
302
+ repo_id="vichudo/agentic-defensor-chunks",
303
+ filename=file,
304
+ repo_type="dataset"
305
+ )
306
+ # Copy to correct location
307
+ shutil.copy(chunks_path, "data/doc_chunks.pkl")
308
+ print(f"Document chunks downloaded directly from file {file} and saved successfully!")
309
+ chunks_downloaded = True
310
+ break
311
+ except Exception as file_e:
312
+ print(f"Direct chunks file download failed: {file_e}")
313
  except Exception as e:
314
  print(f"Error listing files in chunks repository: {e}")
 
 
 
 
 
315
 
316
+ # If direct file approach failed, try dataset API
317
+ if not chunks_downloaded:
318
+ try:
319
+ from datasets import load_dataset
320
+ import pandas as pd
321
+
322
+ chunks_ds = load_dataset("vichudo/agentic-defensor-chunks", split="train")
323
+ print(f"Chunks dataset info: {chunks_ds}")
324
+ print(f"Chunks dataset features: {chunks_ds.features}")
325
+
326
+ if len(chunks_ds) > 0:
327
+ print(f"First row keys: {chunks_ds[0].keys()}")
328
+
329
+ # Approach 1: Try to find data blob
330
+ if "data" in chunks_ds[0]:
331
+ print("Found 'data' blob in chunks dataset")
332
+ chunks_data = pickle.loads(chunks_ds[0]["data"])
333
+ with open("data/doc_chunks.pkl", "wb") as f:
334
+ pickle.dump(chunks_data, f)
335
+ print("Document chunks from data blob saved successfully!")
336
+ chunks_downloaded = True
337
+
338
+ # Approach 2: Try to reconstruct from text columns
339
+ elif all(field in chunks_ds[0] for field in ["text", "source"]):
340
+ print("Found text and source columns, reconstructing chunks")
341
+ df = pd.DataFrame(chunks_ds)
342
+ chunks_list = []
343
+ for _, row in df.iterrows():
344
+ chunk = {
345
+ "text": row["text"],
346
+ "source": row["source"]
347
+ }
348
+ # Add other fields if available
349
+ for field in ["page", "chunk_id", "metadata"]:
350
+ if field in row:
351
+ chunk[field] = row[field]
352
+ chunks_list.append(chunk)
353
+
354
+ with open("data/doc_chunks.pkl", "wb") as f:
355
+ pickle.dump(chunks_list, f)
356
+ print(f"Reconstructed {len(chunks_list)} document chunks successfully!")
357
+ chunks_downloaded = True
358
+
359
+ # Approach 3: Try to work with parquet files directly
360
+ else:
361
+ try:
362
+ print("Trying to work with parquet files directly for chunks")
363
+ import pyarrow.parquet as pq
364
+
365
+ # Find all parquet files in the repository
366
+ parquet_files = [f for f in files if f.endswith('.parquet')]
367
+ if parquet_files:
368
+ print(f"Found parquet files: {parquet_files}")
369
+ for parquet_file in parquet_files:
370
+ try:
371
+ parquet_path = hf_hub_download(
372
+ repo_id="vichudo/agentic-defensor-chunks",
373
+ filename=parquet_file,
374
+ repo_type="dataset"
375
+ )
376
+
377
+ # Try to read parquet and extract chunks
378
+ table = pq.read_table(parquet_path)
379
+ df = table.to_pandas()
380
+ print(f"Parquet columns: {df.columns}")
381
+
382
+ if "data" in df.columns:
383
+ print("Found 'data' column in chunks parquet file")
384
+ chunks_data = pickle.loads(df.data.iloc[0])
385
+ with open("data/doc_chunks.pkl", "wb") as f:
386
+ pickle.dump(chunks_data, f)
387
+ print("Chunks data from parquet file saved successfully!")
388
+ chunks_downloaded = True
389
+ break
390
+ elif all(field in df.columns for field in ["text", "source"]):
391
+ print("Found text and source columns in parquet, reconstructing")
392
+ chunks_list = []
393
+ for _, row in df.iterrows():
394
+ chunk = {
395
+ "text": row["text"],
396
+ "source": row["source"]
397
+ }
398
+ # Add other fields if available
399
+ for field in ["page", "chunk_id", "metadata"]:
400
+ if field in row:
401
+ chunk[field] = row[field]
402
+ chunks_list.append(chunk)
403
+
404
+ with open("data/doc_chunks.pkl", "wb") as f:
405
+ pickle.dump(chunks_list, f)
406
+ print(f"Reconstructed {len(chunks_list)} document chunks from parquet successfully!")
407
+ chunks_downloaded = True
408
+ break
409
+ except Exception as parquet_e:
410
+ print(f"Error processing chunks parquet file {parquet_file}: {parquet_e}")
411
+ except Exception as parquet_approach_e:
412
+ print(f"Error in chunks parquet approach: {parquet_approach_e}")
413
+ except Exception as ds_e:
414
+ print(f"Error processing chunks dataset: {ds_e}")
415
+ traceback.print_exc()
416
  except Exception as e:
417
  print(f"Error downloading document chunks: {e}")
418
+ traceback.print_exc()
419
  download_success = False
420
 
421
+ # Check what was successfully downloaded
422
+ print("\nDownload summary:")
423
+ print(f"- FAISS index: {'✓' if faiss_downloaded else '✗'}")
424
+ print(f"- Embeddings: {'✓' if embeddings_downloaded else '✗'}")
425
+ print(f"- Document chunks: {'✓' if chunks_downloaded else '✗'}")
426
+
427
+ download_success = faiss_downloaded and embeddings_downloaded and chunks_downloaded
428
+
429
+ # If downloads were successful, verify compatibility
430
+ if download_success:
431
+ compatible = verify_embeddings_faiss_compatibility()
432
+ if not compatible:
433
+ print("Warning: Downloaded files are not compatible, will use fallback data")
434
+ download_success = False
435
+
436
  return download_success
437
 
438
  if __name__ == "__main__":
 
441
 
442
  # If download fails, create fallback data
443
  if not success:
444
+ print("\n\nDownloads failed or data is incompatible. Creating fallback data...")
445
  success = create_fallback_data()
446
 
447
  if success:
448
+ # Just to be extra sure, load everything to verify
449
+ try:
450
+ import faiss
451
+ index = faiss.read_index("embeddings/faiss_index.index")
452
+ with open("embeddings/embeddings.pkl", "rb") as f:
453
+ embeddings = pickle.load(f)
454
+ with open("data/doc_chunks.pkl", "rb") as f:
455
+ chunks = pickle.load(f)
456
+
457
+ print("\nFinal verification:")
458
+ print(f"FAISS index: {index.ntotal} vectors of dimension {index.d}")
459
+ if hasattr(embeddings, 'shape'):
460
+ print(f"Embeddings: shape {embeddings.shape}")
461
+ else:
462
+ print(f"Embeddings: type {type(embeddings)}")
463
+ print(f"Document chunks: {len(chunks)} chunks")
464
+
465
+ print("\nData files setup completed successfully!")
466
+ sys.exit(0)
467
+ except Exception as e:
468
+ print(f"\nFinal verification failed: {e}")
469
+ traceback.print_exc()
470
+ sys.exit(1)
471
  else:
472
+ print("\nFailed to set up data files.")
473
  sys.exit(1)
requirements.txt CHANGED
@@ -9,4 +9,8 @@ numpy>=1.24.0
9
  scikit-learn>=1.3.0
10
  pandas>=2.0.0
11
  torch>=2.0.0
12
- langchain>=0.0.335
 
 
 
 
 
9
  scikit-learn>=1.3.0
10
  pandas>=2.0.0
11
  torch>=2.0.0
12
+ langchain>=0.0.335
13
+ pyarrow>=14.0.1
14
+ datasets>=2.15.0
15
+ huggingface_hub>=0.19.0
16
+ requests>=2.31.0
src/embeddings/embedder.py CHANGED
@@ -20,7 +20,31 @@ class TextEmbedder:
20
  self.model = model
21
  self.batch_size = batch_size
22
  self.client = OpenAI(api_key=OPENAI_API_KEY)
23
- self.embedding_dim = 1536 # Default dimension for text-embedding-3-small
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  def get_embedding_for_text(self, text: str) -> List[float]:
26
  """Generate embedding for a single text."""
@@ -80,7 +104,28 @@ class TextEmbedder:
80
  input=[query],
81
  model=self.model
82
  )
83
- return np.array(q_response.data[0].embedding, dtype='float32').reshape(1, -1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  except Exception as e:
85
  print(f"Error creating embedding for query: {e}")
 
 
86
  return np.zeros((1, self.embedding_dim), dtype='float32')
 
20
  self.model = model
21
  self.batch_size = batch_size
22
  self.client = OpenAI(api_key=OPENAI_API_KEY)
23
+
24
+ # Default dimension for different models
25
+ self.embedding_dim = self._get_model_dimension(model)
26
+ print(f"Initialized TextEmbedder with model {model}, dimension {self.embedding_dim}")
27
+
28
+ def _get_model_dimension(self, model_name: str) -> int:
29
+ """Get the embedding dimension for a given model."""
30
+ # Mapping of model names to dimensions
31
+ dimensions = {
32
+ "text-embedding-3-small": 1536,
33
+ "text-embedding-3-large": 3072,
34
+ "text-embedding-ada-002": 1536,
35
+ # Add other models if needed
36
+ }
37
+
38
+ # Return the dimension for the model or default to 1536 (most common)
39
+ return dimensions.get(model_name, 1536)
40
+
41
+ def set_dimension(self, dimension: int) -> None:
42
+ """
43
+ Set the embedding dimension explicitly.
44
+ Use this to ensure compatibility with existing FAISS indices.
45
+ """
46
+ self.embedding_dim = dimension
47
+ print(f"Explicitly set embedding dimension to {dimension}")
48
 
49
  def get_embedding_for_text(self, text: str) -> List[float]:
50
  """Generate embedding for a single text."""
 
104
  input=[query],
105
  model=self.model
106
  )
107
+ embedding = np.array(q_response.data[0].embedding, dtype='float32')
108
+
109
+ # Check and log the actual dimension
110
+ actual_dim = embedding.shape[0]
111
+ if actual_dim != self.embedding_dim:
112
+ print(f"Warning: OpenAI returned embedding of dimension {actual_dim}, expected {self.embedding_dim}")
113
+
114
+ # Handle dimension mismatch
115
+ if actual_dim > self.embedding_dim:
116
+ # Truncate the embedding to match expected dimension
117
+ print(f"Truncating embedding from {actual_dim} to {self.embedding_dim}")
118
+ embedding = embedding[:self.embedding_dim]
119
+ elif actual_dim < self.embedding_dim:
120
+ # Pad the embedding to match expected dimension
121
+ print(f"Padding embedding from {actual_dim} to {self.embedding_dim}")
122
+ padding = np.zeros(self.embedding_dim - actual_dim, dtype='float32')
123
+ embedding = np.concatenate([embedding, padding])
124
+
125
+ # Return the embedding as a 2D array
126
+ return embedding.reshape(1, -1)
127
  except Exception as e:
128
  print(f"Error creating embedding for query: {e}")
129
+ import traceback
130
+ traceback.print_exc()
131
  return np.zeros((1, self.embedding_dim), dtype='float32')
src/models/retriever.py CHANGED
@@ -89,10 +89,27 @@ class Retriever:
89
  resource_manager.faiss_index = self.index
90
  resource_manager.doc_chunks = self.doc_chunks
91
  resource_manager.initialized = True
 
 
 
 
92
  except Exception as e:
93
  print(f"Error loading resources: {e}")
 
 
94
  raise
95
 
 
 
 
 
 
 
 
 
 
 
 
96
  def retrieve(self, query: str, top_k: Optional[int] = None) -> List[Dict[str, Any]]:
97
  """
98
  Retrieve the most relevant document chunks for a query.
@@ -117,9 +134,35 @@ class Retriever:
117
 
118
  # Search the FAISS index
119
  try:
 
 
120
  distances, indices = self.index.search(query_embedding, top_k)
 
 
 
 
121
  except Exception as e:
122
  print(f"Error during FAISS search: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
  # Return all available chunks as fallback
124
  return self._get_all_chunks_with_placeholder_scores()
125
 
 
89
  resource_manager.faiss_index = self.index
90
  resource_manager.doc_chunks = self.doc_chunks
91
  resource_manager.initialized = True
92
+
93
+ # Ensure embedder dimension matches FAISS index
94
+ self._ensure_embedder_compatibility()
95
+
96
  except Exception as e:
97
  print(f"Error loading resources: {e}")
98
+ import traceback
99
+ traceback.print_exc()
100
  raise
101
 
102
+ def _ensure_embedder_compatibility(self) -> None:
103
+ """Ensure the embedder's dimension matches the FAISS index dimension."""
104
+ if self.index is not None and hasattr(self.embedder, 'set_dimension'):
105
+ faiss_dim = self.index.d
106
+ embedder_dim = self.embedder.embedding_dim
107
+
108
+ if faiss_dim != embedder_dim:
109
+ print(f"Warning: Dimension mismatch between FAISS index ({faiss_dim}) and embedder ({embedder_dim})")
110
+ print(f"Adjusting embedder dimension to match FAISS index")
111
+ self.embedder.set_dimension(faiss_dim)
112
+
113
  def retrieve(self, query: str, top_k: Optional[int] = None) -> List[Dict[str, Any]]:
114
  """
115
  Retrieve the most relevant document chunks for a query.
 
134
 
135
  # Search the FAISS index
136
  try:
137
+ print(f"FAISS index info - ntotal: {self.index.ntotal}, dimension: {self.index.d}")
138
+ print(f"Query embedding shape: {query_embedding.shape}")
139
  distances, indices = self.index.search(query_embedding, top_k)
140
+ # Log first few results for debugging
141
+ top_indices = indices[0][:min(3, len(indices[0]))]
142
+ top_distances = distances[0][:min(3, len(distances[0]))]
143
+ print(f"Top 3 results - indices: {top_indices}, distances: {top_distances}")
144
  except Exception as e:
145
  print(f"Error during FAISS search: {e}")
146
+ import traceback
147
+ traceback.print_exc()
148
+
149
+ # Provide diagnostic information
150
+ try:
151
+ # Check if embeddings and index are compatible
152
+ if self.index is None:
153
+ print("FAISS index is None - index was not loaded properly")
154
+ else:
155
+ print(f"FAISS index dimension: {self.index.d}, total vectors: {self.index.ntotal}")
156
+
157
+ if query_embedding is None:
158
+ print("Query embedding is None")
159
+ else:
160
+ print(f"Query embedding shape: {query_embedding.shape}, dtype: {query_embedding.dtype}")
161
+ if query_embedding.shape[1] != self.index.d:
162
+ print(f"Dimension mismatch: query embedding ({query_embedding.shape[1]}) vs. FAISS index ({self.index.d})")
163
+ except Exception as diagnostic_e:
164
+ print(f"Error during diagnostics: {diagnostic_e}")
165
+
166
  # Return all available chunks as fallback
167
  return self._get_all_chunks_with_placeholder_scores()
168