| | """ |
| | Chunked embedding loader for scalable model embeddings. |
| | Loads embeddings in chunks to reduce memory usage and startup time. |
| | """ |
| | import os |
| | import logging |
| | from pathlib import Path |
| | from typing import Optional, List, Dict, Tuple |
| | import pandas as pd |
| | import numpy as np |
| | import pyarrow.parquet as pq |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| |
|
class ChunkedEmbeddingLoader:
    """
    Load embeddings from chunked parquet files.

    Only the chunks containing requested model IDs are read from disk,
    keeping memory usage and startup time low. Recently loaded chunks are
    kept in a small FIFO cache so repeated queries avoid re-reading files.
    """

    def __init__(self, data_dir: str = "precomputed_data", version: str = "v1", chunk_size: int = 50000):
        """
        Initialize chunked loader.

        Args:
            data_dir: Directory containing pre-computed files
            version: Version tag embedded in index/chunk file names
            chunk_size: Number of models per chunk (informational here;
                actual chunk membership comes from the chunk index file)
        """
        self.data_dir = Path(data_dir)
        self.version = version
        self.chunk_size = chunk_size
        # Index mapping model_id -> chunk_id (loaded lazily on first use).
        self.chunk_index: Optional[pd.DataFrame] = None
        # FIFO cache of loaded chunk DataFrames, keyed by chunk_id.
        # Python dicts preserve insertion order, which is what makes the
        # FIFO eviction in _load_chunk correct.
        self._chunk_cache: Dict[int, pd.DataFrame] = {}
        self._max_cache_size = 10

    def load_chunk_index(self) -> pd.DataFrame:
        """Load the chunk index mapping model_id to chunk_id.

        Returns:
            DataFrame with (at least) 'model_id' and 'chunk_id' columns.
            Also stored on self.chunk_index.

        Raises:
            FileNotFoundError: If the index file has not been generated.
        """
        index_file = self.data_dir / f"chunk_index_{self.version}.parquet"

        if not index_file.exists():
            raise FileNotFoundError(
                f"Chunk index not found: {index_file}\n"
                f"Run precompute_data.py with --chunked flag to generate chunked data."
            )

        logger.info(f"Loading chunk index from {index_file}...")
        self.chunk_index = pd.read_parquet(index_file)
        logger.info(f"Loaded chunk index: {len(self.chunk_index):,} models in {self.chunk_index['chunk_id'].nunique()} chunks")

        return self.chunk_index

    def _load_chunk(self, chunk_id: int) -> pd.DataFrame:
        """Load a single chunk file, serving from the FIFO cache when possible.

        Args:
            chunk_id: Integer chunk identifier from the chunk index.

        Returns:
            The chunk's DataFrame (with 'model_id' and 'embedding' columns).

        Raises:
            FileNotFoundError: If the chunk file does not exist on disk.
        """
        if chunk_id in self._chunk_cache:
            return self._chunk_cache[chunk_id]

        chunk_file = self.data_dir / f"embeddings_chunk_{chunk_id:03d}_{self.version}.parquet"

        if not chunk_file.exists():
            raise FileNotFoundError(f"Chunk file not found: {chunk_file}")

        logger.debug(f"Loading chunk {chunk_id} from {chunk_file}...")
        chunk_df = pd.read_parquet(chunk_file)

        # Evict the chunk that has been cached longest (true FIFO): dicts
        # iterate in insertion order, so the first key is the oldest entry.
        # (The previous min(keys) eviction dropped the *lowest chunk id*,
        # which is not necessarily the oldest cached chunk.)
        if len(self._chunk_cache) >= self._max_cache_size:
            oldest_chunk = next(iter(self._chunk_cache))
            del self._chunk_cache[oldest_chunk]

        self._chunk_cache[chunk_id] = chunk_df
        return chunk_df

    def load_embeddings_for_models(
        self,
        model_ids: List[str],
        return_as_dict: bool = False
    ) -> Tuple[np.ndarray, List[str]]:
        """
        Load embeddings only for specified model IDs.

        Args:
            model_ids: List of model IDs to load
            return_as_dict: If True, return dict mapping model_id to embedding

        Returns:
            Tuple of (embeddings_array, model_ids_found), where the array rows
            follow the order of `model_ids` (missing IDs are skipped).
            If return_as_dict=True, returns (embeddings_dict, model_ids_found)
            with found IDs in chunk-traversal order.
        """
        if self.chunk_index is None:
            self.load_chunk_index()

        requested_ids = set(model_ids)

        # Restrict the index to the requested IDs so we only touch the
        # chunks that actually contain them.
        model_chunks = self.chunk_index[
            self.chunk_index['model_id'].isin(requested_ids)
        ]

        if len(model_chunks) == 0:
            logger.warning(f"No embeddings found for {len(model_ids)} requested models")
            return (np.array([]), []) if not return_as_dict else ({}, [])

        embeddings_dict = {}
        found_ids = []

        for chunk_id in model_chunks['chunk_id'].unique():
            chunk_df = self._load_chunk(chunk_id)

            # Only pull the rows of this chunk that were requested.
            chunk_model_ids = model_chunks[model_chunks['chunk_id'] == chunk_id]['model_id'].tolist()
            chunk_embeddings = chunk_df[chunk_df['model_id'].isin(chunk_model_ids)]

            # Column-wise zip avoids the per-row overhead of iterrows().
            for model_id, embedding in zip(chunk_embeddings['model_id'], chunk_embeddings['embedding']):
                embeddings_dict[model_id] = np.array(embedding)
                found_ids.append(model_id)

        if return_as_dict:
            return embeddings_dict, found_ids

        # Re-order results to match the caller's requested order.
        embeddings_list = [embeddings_dict[mid] for mid in model_ids if mid in embeddings_dict]
        found_ids_ordered = [mid for mid in model_ids if mid in embeddings_dict]

        if len(embeddings_list) == 0:
            return np.array([]), []

        embeddings_array = np.array(embeddings_list)
        return embeddings_array, found_ids_ordered

    def load_all_embeddings(self) -> Tuple[np.ndarray, pd.Series]:
        """
        Load all embeddings (for backward compatibility).

        Warning: This loads all chunks into memory!

        Returns:
            Tuple of (embeddings_array, model_ids_series) covering every
            model in the chunk index, in ascending chunk order.
        """
        if self.chunk_index is None:
            self.load_chunk_index()

        all_chunk_ids = sorted(self.chunk_index['chunk_id'].unique())
        logger.warning(f"Loading all {len(all_chunk_ids)} chunks - this may use significant memory!")

        all_embeddings = []
        all_model_ids = []

        for chunk_id in all_chunk_ids:
            chunk_df = self._load_chunk(chunk_id)
            all_embeddings.extend(chunk_df['embedding'].tolist())
            all_model_ids.extend(chunk_df['model_id'].tolist())

        embeddings_array = np.array(all_embeddings)
        model_ids_series = pd.Series(all_model_ids)

        return embeddings_array, model_ids_series

    def get_chunk_info(self) -> Dict:
        """Get summary information about chunks.

        Returns:
            Dict with total_models, total_chunks, chunk_size,
            per-chunk model counts, and the currently cached chunk IDs.
        """
        if self.chunk_index is None:
            self.load_chunk_index()

        chunk_counts = self.chunk_index['chunk_id'].value_counts().sort_index()

        return {
            'total_models': len(self.chunk_index),
            'total_chunks': self.chunk_index['chunk_id'].nunique(),
            'chunk_size': self.chunk_size,
            'chunk_counts': chunk_counts.to_dict(),
            'cached_chunks': list(self._chunk_cache.keys())
        }

    def clear_cache(self):
        """Clear the chunk cache."""
        self._chunk_cache.clear()
        logger.info("Chunk cache cleared")
| |
|
| |
|
def create_chunk_index(
    df: pd.DataFrame,
    chunk_size: int = 50000,
    output_dir: Optional[Path] = None,
    version: str = "v1"
) -> pd.DataFrame:
    """
    Create chunk index from dataframe.

    Rows are assigned to chunks by position: row i belongs to chunk
    i // chunk_size at offset i % chunk_size.

    Args:
        df: DataFrame with model_id column
        chunk_size: Number of models per chunk (must be positive)
        output_dir: Directory to save index (index is not written if None)
        version: Version tag embedded in the index file name

    Returns:
        DataFrame with columns: model_id, chunk_id, chunk_offset

    Raises:
        ValueError: If chunk_size is not a positive integer.
    """
    # Guard against a zero/negative chunk size, which would otherwise
    # surface as an opaque division error below.
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    model_ids = df['model_id'].astype(str).values

    # Positional assignment: chunk id is the row's bucket, offset is its
    # position within that bucket.
    positions = np.arange(len(model_ids))
    chunk_index = pd.DataFrame({
        'model_id': model_ids,
        'chunk_id': (positions // chunk_size).astype(int),
        'chunk_offset': positions % chunk_size
    })

    if output_dir:
        index_file = output_dir / f"chunk_index_{version}.parquet"
        chunk_index.to_parquet(index_file, compression='snappy', index=False)
        logger.info(f"Saved chunk index: {index_file}")

    return chunk_index
| |
|
| |
|
| |
|