import json
from collections.abc import Iterator
from pathlib import Path
from typing import Any

from fastembed import SparseTextEmbedding
from loguru import logger
from tqdm import tqdm

from scientific_rag.application.embeddings.encoder import encoder
from scientific_rag.domain.documents import PaperChunk
from scientific_rag.infrastructure.qdrant import qdrant_service
from scientific_rag.settings import settings


def load_chunks_generator(
    chunks_file: Path, batch_size: int = 10000
) -> Iterator[list[PaperChunk]]:
    """Yield batches of ``PaperChunk`` parsed from a JSON file.

    Args:
        chunks_file: Path to a JSON file containing a list of chunk dicts.
        batch_size: Number of chunks per yielded batch.

    Yields:
        Lists of at most ``batch_size`` ``PaperChunk`` instances.

    NOTE(review): ``json.load`` reads the entire file into memory up front,
    so only the model parsing is batched, not the file I/O — confirm this is
    acceptable for the expected file sizes.
    """
    logger.info(f"Loading chunks from {chunks_file} in batches of {batch_size}")
    with open(chunks_file, encoding="utf-8") as f:
        chunks_data = json.load(f)
    total_chunks = len(chunks_data)
    logger.info(f"Found {total_chunks} chunks in file")
    for i in range(0, total_chunks, batch_size):
        batch_data = chunks_data[i : i + batch_size]
        yield [PaperChunk(**chunk_data) for chunk_data in batch_data]
    # Drop the parsed JSON once the generator is exhausted to free memory.
    del chunks_data


def embed_chunks(
    chunks: list[PaperChunk],
    batch_size: int = 32,
    show_progress: bool = True,
) -> list[PaperChunk]:
    """Embed chunks in place using the dense encoder and return them.

    Args:
        chunks: Chunks whose ``embedding`` attribute will be populated.
        batch_size: Batch size passed to the encoder.
        show_progress: Whether the encoder shows a progress bar.

    Returns:
        The same ``chunks`` list, with ``embedding`` set on each element.
    """
    logger.info(f"Embedding {len(chunks)} chunks (Dense) with batch size {batch_size}")
    texts = [chunk.text for chunk in chunks]
    embeddings = encoder.encode(
        texts=texts,
        mode="passage",
        batch_size=batch_size,
        show_progress=show_progress,
    )
    # strict=True: with strict=False a length mismatch between chunks and
    # embeddings would be silently truncated, leaving some chunks without
    # an embedding; fail loudly instead.
    for chunk, embedding in zip(chunks, embeddings, strict=True):
        chunk.embedding = embedding
    logger.success(f"Generated dense embeddings for {len(chunks)} chunks")
    return chunks


def embed_sparse_chunks(
    chunks: list[PaperChunk],
    sparse_encoder: SparseTextEmbedding,
    batch_size: int = 32,
    show_progress: bool = True,
) -> list[Any]:
    """Generate sparse BM25 embeddings for chunks.

    Args:
        chunks: Chunks whose ``text`` is embedded.
        sparse_encoder: fastembed sparse model used for encoding.
        batch_size: Batch size passed to the sparse encoder.
        show_progress: Accepted for signature parity with ``embed_chunks``
            but currently unused (fastembed's ``embed`` has no such option).

    Returns:
        One sparse embedding object per chunk, in input order.
    """
    logger.info(
        f"Embedding {len(chunks)} chunks (Sparse BM25) with batch size {batch_size}"
    )
    texts = [chunk.text for chunk in chunks]
    # embed() returns a generator; materialize so callers can slice it.
    sparse_embeddings = list(
        sparse_encoder.embed(documents=texts, batch_size=batch_size, parallel=None)
    )
    logger.success(f"Generated sparse embeddings for {len(chunks)} chunks")
    return sparse_embeddings


def index_chunks_to_qdrant(
    chunks: list[PaperChunk],
    sparse_embeddings: list[Any],
    batch_size: int = 100,
    show_progress: bool = True,
) -> int:
    """Upload chunks (and aligned sparse embeddings) to Qdrant in batches.

    Args:
        chunks: Chunks with dense embeddings already populated.
        sparse_embeddings: Sparse embeddings aligned index-for-index with
            ``chunks``; may be empty/None to skip sparse vectors.
        batch_size: Number of chunks per upsert call.
        show_progress: Whether to show a tqdm progress bar.

    Returns:
        Total number of chunks reported uploaded by the Qdrant service.
    """
    total_uploaded = 0
    iterator = tqdm(
        range(0, len(chunks), batch_size),
        desc="Uploading to Qdrant",
        disable=not show_progress,
    )
    for i in iterator:
        batch_chunks = chunks[i : i + batch_size]
        # Slice sparse embeddings with the same offsets to keep alignment.
        batch_sparse = None
        if sparse_embeddings:
            batch_sparse = sparse_embeddings[i : i + batch_size]
        uploaded = qdrant_service.upsert_chunks(
            batch_chunks, sparse_embeddings=batch_sparse
        )
        total_uploaded += uploaded
    return total_uploaded


def index_qdrant(
    chunks_file: Path | str | None = None,
    embedding_batch_size: int = 32,
    upload_batch_size: int = 100,
    create_collection: bool = True,
    process_batch_size: int = 10000,
) -> dict[str, int]:
    """Complete pipeline to index chunks to Qdrant.

    Args:
        chunks_file: Path to chunks JSON file; defaults to the processed
            chunks file for the configured dataset split.
        embedding_batch_size: Batch size for embedding generation
        upload_batch_size: Batch size for Qdrant upload
        create_collection: Whether to create the collection
        process_batch_size: Process chunks in batches of this size to manage memory

    Returns:
        Stats dict with ``chunks_uploaded``, ``collection_points`` and
        ``collection_vectors``.

    Raises:
        FileNotFoundError: If ``chunks_file`` does not exist.
    """
    if chunks_file is None:
        chunks_file = (
            Path(settings.root_dir)
            / "data"
            / "processed"
            / f"chunks_{settings.dataset_split}.json"
        )
    else:
        chunks_file = Path(chunks_file)

    if not chunks_file.exists():
        raise FileNotFoundError(f"Chunks file not found: {chunks_file}")

    # Use the module-level singleton so in-memory Qdrant is shared
    # between the indexer and the application runtime.
    if create_collection:
        # Fall back to 384 (common MiniLM dim) when the encoder does not
        # expose its embedding size.
        vector_size = getattr(encoder, "embedding_dim", 384)
        qdrant_service.create_collection(vector_size=vector_size)

    logger.info(f"Initializing Sparse Encoder: {settings.sparse_embedding_model_name}")
    sparse_encoder = SparseTextEmbedding(
        model_name=settings.sparse_embedding_model_name
    )

    logger.info("Processing chunks in streaming batches to manage memory...")
    total_uploaded = 0
    batch_num = 0
    for batch_chunks in load_chunks_generator(
        chunks_file, batch_size=process_batch_size
    ):
        batch_num += 1
        batch_start = (batch_num - 1) * process_batch_size
        batch_end = batch_start + len(batch_chunks)
        logger.info(
            f"--- Processing Batch {batch_num} (Chunks {batch_start}-{batch_end}) ---"
        )

        # Dense embeddings are written onto the chunks in place.
        batch_chunks = embed_chunks(
            chunks=batch_chunks,
            batch_size=embedding_batch_size,
            show_progress=True,
        )

        # Sparse embeddings are returned as a separate aligned list.
        batch_sparse = embed_sparse_chunks(
            chunks=batch_chunks,
            sparse_encoder=sparse_encoder,
            batch_size=embedding_batch_size,
            show_progress=True,
        )

        logger.info(f"Batch {batch_num}: Uploading chunks to Qdrant...")
        batch_uploaded = index_chunks_to_qdrant(
            chunks=batch_chunks,
            sparse_embeddings=batch_sparse,
            batch_size=upload_batch_size,
            show_progress=True,
        )
        total_uploaded += batch_uploaded
        logger.success(
            f"Batch {batch_num} complete: {batch_uploaded} chunks uploaded (Total: {total_uploaded})"
        )

    logger.info("Getting final statistics...")
    collection_info = qdrant_service.get_collection_info()
    stats = {
        "chunks_uploaded": total_uploaded,
        "collection_points": collection_info.get("points_count", 0),
        "collection_vectors": collection_info.get("index_vectors_count", 0),
    }
    logger.success(f"Indexing complete: {stats}")
    return stats


if __name__ == "__main__":
    index_qdrant()