Spaces:
Sleeping
Sleeping
Update search_utils.py
Browse files- search_utils.py +24 -8
search_utils.py
CHANGED
|
@@ -36,7 +36,6 @@ class MetadataManager:
|
|
| 36 |
self._ensure_directories()
|
| 37 |
self._unzip_if_needed()
|
| 38 |
self._build_shard_map()
|
| 39 |
-
self._init_url_resolver()
|
| 40 |
logger.info(f"Total documents indexed: {self.total_docs}")
|
| 41 |
logger.info(f"Total shards found: {len(self.shard_map)}")
|
| 42 |
|
|
@@ -145,12 +144,15 @@ class MetadataManager:
|
|
| 145 |
shard_path = self.shard_dir / shard
|
| 146 |
if not shard_path.exists():
|
| 147 |
logger.error(f"Shard file not found: {shard_path}")
|
| 148 |
-
return pd.DataFrame(columns=["title", "summary", "similarity", "source"])
|
|
|
|
| 149 |
file_size_mb = os.path.getsize(shard_path) / (1024 * 1024)
|
| 150 |
logger.info(f"Loading shard file: {shard} (size: {file_size_mb:.2f} MB)")
|
|
|
|
| 151 |
try:
|
| 152 |
-
self.loaded_shards[shard] = pd.read_parquet(shard_path, columns=["title", "summary"])
|
| 153 |
logger.info(f"Loaded shard {shard} with {len(self.loaded_shards[shard])} rows")
|
|
|
|
| 154 |
except Exception as e:
|
| 155 |
logger.error(f"Failed to read parquet file {shard}: {str(e)}")
|
| 156 |
try:
|
|
@@ -158,7 +160,7 @@ class MetadataManager:
|
|
| 158 |
logger.info(f"Parquet schema: {schema}")
|
| 159 |
except Exception:
|
| 160 |
pass
|
| 161 |
-
return pd.DataFrame(columns=["title", "summary", "similarity", "source"])
|
| 162 |
df = self.loaded_shards[shard]
|
| 163 |
df_len = len(df)
|
| 164 |
valid_local_indices = [idx for idx in local_indices if 0 <= idx < df_len]
|
|
@@ -168,9 +170,10 @@ class MetadataManager:
|
|
| 168 |
chunk = df.iloc[valid_local_indices]
|
| 169 |
logger.info(f"Retrieved {len(chunk)} records from shard {shard}")
|
| 170 |
return chunk
|
|
|
|
| 171 |
except Exception as e:
|
| 172 |
logger.error(f"Error processing shard {shard}: {str(e)}", exc_info=True)
|
| 173 |
-
return pd.DataFrame(columns=["title", "summary", "similarity", "source"])
|
| 174 |
|
| 175 |
def get_metadata(self, global_indices):
|
| 176 |
"""Retrieve metadata for a batch of global indices using parallel shard processing."""
|
|
@@ -328,14 +331,17 @@ class SemanticSearch:
|
|
| 328 |
if index.ntotal == 0:
|
| 329 |
self.logger.warning(f"Skipping empty shard {shard_idx}")
|
| 330 |
return None
|
|
|
|
| 331 |
try:
|
| 332 |
shard_start = time.time()
|
| 333 |
distances, indices = index.search(query_embedding, top_k)
|
| 334 |
valid_mask = (indices[0] >= 0) & (indices[0] < index.ntotal)
|
| 335 |
valid_indices = indices[0][valid_mask].tolist()
|
| 336 |
valid_distances = distances[0][valid_mask].tolist()
|
|
|
|
| 337 |
if len(valid_indices) != top_k:
|
| 338 |
self.logger.debug(f"Shard {shard_idx}: Found {len(valid_indices)} valid results out of {top_k}")
|
|
|
|
| 339 |
global_indices = [self._global_index(shard_idx, idx) for idx in valid_indices]
|
| 340 |
self.logger.debug(f"Shard {shard_idx} search completed in {time.time() - shard_start:.3f}s")
|
| 341 |
return valid_distances, global_indices
|
|
@@ -348,21 +354,23 @@ class SemanticSearch:
|
|
| 348 |
process_start = time.time()
|
| 349 |
if global_indices.size == 0 or distances.size == 0:
|
| 350 |
self.logger.warning("No search results to process")
|
| 351 |
-
return pd.DataFrame(columns=["title", "summary", "source", "similarity"])
|
| 352 |
try:
|
| 353 |
self.logger.info(f"Retrieving metadata for {len(global_indices)} indices")
|
| 354 |
metadata_start = time.time()
|
| 355 |
results = self.metadata_mgr.get_metadata(global_indices)
|
| 356 |
self.logger.info(f"Metadata retrieved in {time.time() - metadata_start:.2f}s, got {len(results)} records")
|
|
|
|
| 357 |
if len(results) == 0:
|
| 358 |
self.logger.warning("No metadata found for indices")
|
| 359 |
-
return pd.DataFrame(columns=["title", "summary", "source", "similarity"])
|
| 360 |
if len(results) != len(distances):
|
| 361 |
self.logger.warning(f"Mismatch between distances ({len(distances)}) and results ({len(results)})")
|
| 362 |
if len(results) < len(distances):
|
| 363 |
distances = distances[:len(results)]
|
| 364 |
else:
|
| 365 |
distances = np.pad(distances, (0, len(results) - len(distances)), 'constant', constant_values=1.0)
|
|
|
|
| 366 |
self.logger.debug("Calculating similarity scores")
|
| 367 |
results['similarity'] = 1 - (distances / 2)
|
| 368 |
if not results.empty:
|
|
@@ -370,13 +378,21 @@ class SemanticSearch:
|
|
| 370 |
f"max={results['similarity'].max():.3f}, " +
|
| 371 |
f"mean={results['similarity'].mean():.3f}")
|
| 372 |
results['source'] = results["source"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 373 |
|
| 374 |
pre_dedup = len(results)
|
| 375 |
-
results = results.drop_duplicates(subset=["title", "source"]).sort_values("similarity", ascending=False).head(top_k)
|
|
|
|
| 376 |
post_dedup = len(results)
|
| 377 |
if pre_dedup > post_dedup:
|
| 378 |
self.logger.info(f"Removed {pre_dedup - post_dedup} duplicate results")
|
| 379 |
self.logger.info(f"Results processed in {time.time() - process_start:.2f}s, returning {len(results)} items")
|
|
|
|
| 380 |
return results.reset_index(drop=True)
|
| 381 |
except Exception as e:
|
| 382 |
self.logger.error(f"Result processing failed: {str(e)}", exc_info=True)
|
|
|
|
| 36 |
self._ensure_directories()
|
| 37 |
self._unzip_if_needed()
|
| 38 |
self._build_shard_map()
|
|
|
|
| 39 |
logger.info(f"Total documents indexed: {self.total_docs}")
|
| 40 |
logger.info(f"Total shards found: {len(self.shard_map)}")
|
| 41 |
|
|
|
|
| 144 |
shard_path = self.shard_dir / shard
|
| 145 |
if not shard_path.exists():
|
| 146 |
logger.error(f"Shard file not found: {shard_path}")
|
| 147 |
+
return pd.DataFrame(columns=["title", "summary", "similarity","authors", "source"])
|
| 148 |
+
|
| 149 |
file_size_mb = os.path.getsize(shard_path) / (1024 * 1024)
|
| 150 |
logger.info(f"Loading shard file: {shard} (size: {file_size_mb:.2f} MB)")
|
| 151 |
+
|
| 152 |
try:
|
| 153 |
+
self.loaded_shards[shard] = pd.read_parquet(shard_path, columns=["title", "summary", "source", "authors"])
|
| 154 |
logger.info(f"Loaded shard {shard} with {len(self.loaded_shards[shard])} rows")
|
| 155 |
+
|
| 156 |
except Exception as e:
|
| 157 |
logger.error(f"Failed to read parquet file {shard}: {str(e)}")
|
| 158 |
try:
|
|
|
|
| 160 |
logger.info(f"Parquet schema: {schema}")
|
| 161 |
except Exception:
|
| 162 |
pass
|
| 163 |
+
return pd.DataFrame(columns=["title", "summary", "similarity", "source", "authors"])
|
| 164 |
df = self.loaded_shards[shard]
|
| 165 |
df_len = len(df)
|
| 166 |
valid_local_indices = [idx for idx in local_indices if 0 <= idx < df_len]
|
|
|
|
| 170 |
chunk = df.iloc[valid_local_indices]
|
| 171 |
logger.info(f"Retrieved {len(chunk)} records from shard {shard}")
|
| 172 |
return chunk
|
| 173 |
+
|
| 174 |
except Exception as e:
|
| 175 |
logger.error(f"Error processing shard {shard}: {str(e)}", exc_info=True)
|
| 176 |
+
return pd.DataFrame(columns=["title", "summary", "similarity", "source", "authors"])
|
| 177 |
|
| 178 |
def get_metadata(self, global_indices):
|
| 179 |
"""Retrieve metadata for a batch of global indices using parallel shard processing."""
|
|
|
|
| 331 |
if index.ntotal == 0:
|
| 332 |
self.logger.warning(f"Skipping empty shard {shard_idx}")
|
| 333 |
return None
|
| 334 |
+
|
| 335 |
try:
|
| 336 |
shard_start = time.time()
|
| 337 |
distances, indices = index.search(query_embedding, top_k)
|
| 338 |
valid_mask = (indices[0] >= 0) & (indices[0] < index.ntotal)
|
| 339 |
valid_indices = indices[0][valid_mask].tolist()
|
| 340 |
valid_distances = distances[0][valid_mask].tolist()
|
| 341 |
+
|
| 342 |
if len(valid_indices) != top_k:
|
| 343 |
self.logger.debug(f"Shard {shard_idx}: Found {len(valid_indices)} valid results out of {top_k}")
|
| 344 |
+
|
| 345 |
global_indices = [self._global_index(shard_idx, idx) for idx in valid_indices]
|
| 346 |
self.logger.debug(f"Shard {shard_idx} search completed in {time.time() - shard_start:.3f}s")
|
| 347 |
return valid_distances, global_indices
|
|
|
|
| 354 |
process_start = time.time()
|
| 355 |
if global_indices.size == 0 or distances.size == 0:
|
| 356 |
self.logger.warning("No search results to process")
|
| 357 |
+
return pd.DataFrame(columns=["title", "summary", "source", "authors", "similarity"])
|
| 358 |
try:
|
| 359 |
self.logger.info(f"Retrieving metadata for {len(global_indices)} indices")
|
| 360 |
metadata_start = time.time()
|
| 361 |
results = self.metadata_mgr.get_metadata(global_indices)
|
| 362 |
self.logger.info(f"Metadata retrieved in {time.time() - metadata_start:.2f}s, got {len(results)} records")
|
| 363 |
+
|
| 364 |
if len(results) == 0:
|
| 365 |
self.logger.warning("No metadata found for indices")
|
| 366 |
+
return pd.DataFrame(columns=["title", "summary", "source", "authors", "similarity"])
|
| 367 |
if len(results) != len(distances):
|
| 368 |
self.logger.warning(f"Mismatch between distances ({len(distances)}) and results ({len(results)})")
|
| 369 |
if len(results) < len(distances):
|
| 370 |
distances = distances[:len(results)]
|
| 371 |
else:
|
| 372 |
distances = np.pad(distances, (0, len(results) - len(distances)), 'constant', constant_values=1.0)
|
| 373 |
+
|
| 374 |
self.logger.debug("Calculating similarity scores")
|
| 375 |
results['similarity'] = 1 - (distances / 2)
|
| 376 |
if not results.empty:
|
|
|
|
| 378 |
f"max={results['similarity'].max():.3f}, " +
|
| 379 |
f"mean={results['similarity'].mean():.3f}")
|
| 380 |
results['source'] = results["source"]
|
| 381 |
+
|
| 382 |
+
# Ensure we have all required columns
|
| 383 |
+
required_columns = ["title", "summary", "authors", "source", "similarity"]
|
| 384 |
+
for col in required_columns:
|
| 385 |
+
if col not in results.columns:
|
| 386 |
+
results[col] = None # Fill missing columns with None
|
| 387 |
|
| 388 |
pre_dedup = len(results)
|
| 389 |
+
results = results.drop_duplicates(subset=["title","authors", "source"]).sort_values("similarity", ascending=False).head(top_k)
|
| 390 |
+
|
| 391 |
post_dedup = len(results)
|
| 392 |
if pre_dedup > post_dedup:
|
| 393 |
self.logger.info(f"Removed {pre_dedup - post_dedup} duplicate results")
|
| 394 |
self.logger.info(f"Results processed in {time.time() - process_start:.2f}s, returning {len(results)} items")
|
| 395 |
+
|
| 396 |
return results.reset_index(drop=True)
|
| 397 |
except Exception as e:
|
| 398 |
self.logger.error(f"Result processing failed: {str(e)}", exc_info=True)
|