Spaces:
Sleeping
Sleeping
Update search_utils.py
Browse files- search_utils.py +44 -32
search_utils.py
CHANGED
|
@@ -12,7 +12,6 @@ from urllib.parse import quote
|
|
| 12 |
import requests
|
| 13 |
import shutil
|
| 14 |
import concurrent.futures
|
| 15 |
-
# Optional: Uncomment if you want to use lru_cache for instance methods
|
| 16 |
from functools import lru_cache
|
| 17 |
|
| 18 |
# Configure logging
|
|
@@ -144,22 +143,28 @@ class MetadataManager:
|
|
| 144 |
shard_path = self.shard_dir / shard
|
| 145 |
if not shard_path.exists():
|
| 146 |
logger.error(f"Shard file not found: {shard_path}")
|
| 147 |
-
return pd.DataFrame(columns=["title", "summary", "similarity","authors", "source"])
|
| 148 |
|
| 149 |
file_size_mb = os.path.getsize(shard_path) / (1024 * 1024)
|
| 150 |
logger.info(f"Loading shard file: {shard} (size: {file_size_mb:.2f} MB)")
|
| 151 |
|
| 152 |
try:
|
| 153 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
logger.info(f"Loaded shard {shard} with {len(self.loaded_shards[shard])} rows")
|
| 155 |
|
| 156 |
except Exception as e:
|
| 157 |
logger.error(f"Failed to read parquet file {shard}: {str(e)}")
|
| 158 |
-
try:
|
| 159 |
-
schema = pd.read_parquet(shard_path, engine='pyarrow').dtypes
|
| 160 |
-
logger.info(f"Parquet schema: {schema}")
|
| 161 |
-
except Exception:
|
| 162 |
-
pass
|
| 163 |
return pd.DataFrame(columns=["title", "summary", "similarity", "source", "authors"])
|
| 164 |
df = self.loaded_shards[shard]
|
| 165 |
df_len = len(df)
|
|
@@ -220,8 +225,8 @@ class MetadataManager:
|
|
| 220 |
else:
|
| 221 |
logger.warning("No metadata records retrieved")
|
| 222 |
return pd.DataFrame(columns=["title", "summary", "similarity", "source"])
|
| 223 |
-
|
| 224 |
-
|
| 225 |
class SemanticSearch:
|
| 226 |
def __init__(self):
|
| 227 |
self.shard_dir = Path("compressed_shards")
|
|
@@ -310,7 +315,6 @@ class SemanticSearch:
|
|
| 310 |
|
| 311 |
all_distances = []
|
| 312 |
all_global_indices = []
|
| 313 |
-
# Run shard searches in parallel
|
| 314 |
with concurrent.futures.ThreadPoolExecutor() as executor:
|
| 315 |
futures = {
|
| 316 |
executor.submit(self._search_shard, shard_idx, index, query_embedding, top_k): shard_idx
|
|
@@ -351,7 +355,7 @@ class SemanticSearch:
|
|
| 351 |
return None
|
| 352 |
|
| 353 |
def _process_results(self, distances, global_indices, top_k):
|
| 354 |
-
"""Process raw search results
|
| 355 |
process_start = time.time()
|
| 356 |
if global_indices.size == 0 or distances.size == 0:
|
| 357 |
self.logger.warning("No search results to process")
|
|
@@ -367,33 +371,41 @@ class SemanticSearch:
|
|
| 367 |
self.logger.warning("No metadata found for indices")
|
| 368 |
return pd.DataFrame(columns=["title", "summary", "source", "authors", "similarity"])
|
| 369 |
|
|
|
|
| 370 |
if len(results) != len(distances):
|
| 371 |
self.logger.warning(f"Mismatch between distances ({len(distances)}) and results ({len(results)})")
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 383 |
required_columns = ["title", "summary", "authors", "source", "similarity"]
|
| 384 |
-
for col in required_columns:
|
| 385 |
-
if col not in results.columns:
|
| 386 |
-
results[col] = None # Fill missing columns with None
|
| 387 |
-
|
| 388 |
pre_dedup = len(results)
|
| 389 |
-
results =
|
| 390 |
-
|
|
|
|
|
|
|
|
|
|
| 391 |
post_dedup = len(results)
|
|
|
|
| 392 |
if pre_dedup > post_dedup:
|
| 393 |
self.logger.info(f"Removed {pre_dedup - post_dedup} duplicate results")
|
| 394 |
-
|
| 395 |
-
self.logger.info(f"Results processed in {time.time() - process_start:.2f}s
|
| 396 |
return results[required_columns].reset_index(drop=True)
|
|
|
|
| 397 |
except Exception as e:
|
| 398 |
self.logger.error(f"Result processing failed: {str(e)}", exc_info=True)
|
| 399 |
-
return pd.DataFrame(columns=["title", "summary", "source", "similarity"])
|
|
|
|
| 12 |
import requests
|
| 13 |
import shutil
|
| 14 |
import concurrent.futures
|
|
|
|
| 15 |
from functools import lru_cache
|
| 16 |
|
| 17 |
# Configure logging
|
|
|
|
| 143 |
shard_path = self.shard_dir / shard
|
| 144 |
if not shard_path.exists():
|
| 145 |
logger.error(f"Shard file not found: {shard_path}")
|
| 146 |
+
return pd.DataFrame(columns=["title", "summary", "similarity", "authors", "source"])
|
| 147 |
|
| 148 |
file_size_mb = os.path.getsize(shard_path) / (1024 * 1024)
|
| 149 |
logger.info(f"Loading shard file: {shard} (size: {file_size_mb:.2f} MB)")
|
| 150 |
|
| 151 |
try:
|
| 152 |
+
# Load with explicit dtype for source column
|
| 153 |
+
self.loaded_shards[shard] = pd.read_parquet(
|
| 154 |
+
shard_path,
|
| 155 |
+
columns=["title", "summary", "source", "authors"],
|
| 156 |
+
dtype={'source': 'str'}
|
| 157 |
+
)
|
| 158 |
+
# Convert source strings to lists
|
| 159 |
+
self.loaded_shards[shard]['source'] = self.loaded_shards[shard]['source'].apply(
|
| 160 |
+
lambda x: x.split("; ") if isinstance(x, str) else []
|
| 161 |
+
)
|
| 162 |
+
# Handle missing summaries
|
| 163 |
+
self.loaded_shards[shard]['summary'] = self.loaded_shards[shard]['summary'].fillna("")
|
| 164 |
logger.info(f"Loaded shard {shard} with {len(self.loaded_shards[shard])} rows")
|
| 165 |
|
| 166 |
except Exception as e:
|
| 167 |
logger.error(f"Failed to read parquet file {shard}: {str(e)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 168 |
return pd.DataFrame(columns=["title", "summary", "similarity", "source", "authors"])
|
| 169 |
df = self.loaded_shards[shard]
|
| 170 |
df_len = len(df)
|
|
|
|
| 225 |
else:
|
| 226 |
logger.warning("No metadata records retrieved")
|
| 227 |
return pd.DataFrame(columns=["title", "summary", "similarity", "source"])
|
| 228 |
+
|
| 229 |
+
|
| 230 |
class SemanticSearch:
|
| 231 |
def __init__(self):
|
| 232 |
self.shard_dir = Path("compressed_shards")
|
|
|
|
| 315 |
|
| 316 |
all_distances = []
|
| 317 |
all_global_indices = []
|
|
|
|
| 318 |
with concurrent.futures.ThreadPoolExecutor() as executor:
|
| 319 |
futures = {
|
| 320 |
executor.submit(self._search_shard, shard_idx, index, query_embedding, top_k): shard_idx
|
|
|
|
| 355 |
return None
|
| 356 |
|
| 357 |
def _process_results(self, distances, global_indices, top_k):
|
| 358 |
+
"""Process raw search results with correct similarity calculation."""
|
| 359 |
process_start = time.time()
|
| 360 |
if global_indices.size == 0 or distances.size == 0:
|
| 361 |
self.logger.warning("No search results to process")
|
|
|
|
| 371 |
self.logger.warning("No metadata found for indices")
|
| 372 |
return pd.DataFrame(columns=["title", "summary", "source", "authors", "similarity"])
|
| 373 |
|
| 374 |
+
# Handle distance-results alignment
|
| 375 |
if len(results) != len(distances):
|
| 376 |
self.logger.warning(f"Mismatch between distances ({len(distances)}) and results ({len(results)})")
|
| 377 |
+
min_len = min(len(results), len(distances))
|
| 378 |
+
results = results.iloc[:min_len]
|
| 379 |
+
distances = distances[:min_len]
|
| 380 |
+
|
| 381 |
+
# Calculate similarity (cosine similarity = inner product for normalized embeddings)
|
| 382 |
+
results['similarity'] = distances
|
| 383 |
+
|
| 384 |
+
# Ensure URL lists are properly formatted
|
| 385 |
+
results['source'] = results['source'].apply(
|
| 386 |
+
lambda x: [
|
| 387 |
+
url.strip().rstrip(')') # Clean trailing parentheses and whitespace
|
| 388 |
+
for url in str(x).split(';') # Split on semicolons
|
| 389 |
+
if url.strip() # Remove empty strings
|
| 390 |
+
] if isinstance(x, (str, list)) else []
|
| 391 |
+
)
|
| 392 |
+
|
| 393 |
+
# Deduplicate and sort
|
| 394 |
required_columns = ["title", "summary", "authors", "source", "similarity"]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 395 |
pre_dedup = len(results)
|
| 396 |
+
results = (
|
| 397 |
+
results.drop_duplicates(subset=["title", "authors"])
|
| 398 |
+
.sort_values("similarity", ascending=False)
|
| 399 |
+
.head(top_k)
|
| 400 |
+
)
|
| 401 |
post_dedup = len(results)
|
| 402 |
+
|
| 403 |
if pre_dedup > post_dedup:
|
| 404 |
self.logger.info(f"Removed {pre_dedup - post_dedup} duplicate results")
|
| 405 |
+
|
| 406 |
+
self.logger.info(f"Results processed in {time.time() - process_start:.2f}s")
|
| 407 |
return results[required_columns].reset_index(drop=True)
|
| 408 |
+
|
| 409 |
except Exception as e:
|
| 410 |
self.logger.error(f"Result processing failed: {str(e)}", exc_info=True)
|
| 411 |
+
return pd.DataFrame(columns=["title", "summary", "source", "similarity"])
|