DenysKovalML committed on
Commit
8ddd8e2
·
1 Parent(s): 6c023b4

fix: add batching to qdrant index

Browse files
src/scientific_rag/cli.py CHANGED
@@ -24,6 +24,7 @@ def index(
24
  embedding_batch_size: int = typer.Option(32, "--embedding-batch-size", "-eb"),
25
  upload_batch_size: int = typer.Option(100, "--upload-batch-size", "-ub"),
26
  create_collection: bool = typer.Option(True, "--create-collection/--no-create-collection"),
 
27
  ) -> None:
28
  """Embed chunks and upload to Qdrant."""
29
  chunks_path = Path(chunks_file) if chunks_file else None
@@ -32,6 +33,7 @@ def index(
32
  embedding_batch_size=embedding_batch_size,
33
  upload_batch_size=upload_batch_size,
34
  create_collection=create_collection,
 
35
  )
36
 
37
 
@@ -41,6 +43,7 @@ def pipeline(
41
  embedding_batch_size: int = typer.Option(32, "--embedding-batch-size", "-eb"),
42
  upload_batch_size: int = typer.Option(100, "--upload-batch-size", "-ub"),
43
  create_collection: bool = typer.Option(True, "--create-collection/--no-create-collection"),
 
44
  ) -> None:
45
  """Run complete pipeline: chunk → embed → index."""
46
  logger.info("Step 1/2: Chunking data")
@@ -52,6 +55,7 @@ def pipeline(
52
  embedding_batch_size=embedding_batch_size,
53
  upload_batch_size=upload_batch_size,
54
  create_collection=create_collection,
 
55
  )
56
 
57
 
 
24
  embedding_batch_size: int = typer.Option(32, "--embedding-batch-size", "-eb"),
25
  upload_batch_size: int = typer.Option(100, "--upload-batch-size", "-ub"),
26
  create_collection: bool = typer.Option(True, "--create-collection/--no-create-collection"),
27
+ process_batch_size: int = typer.Option(10000, "--process-batch-size", "-pb", help="Process chunks in batches"),
28
  ) -> None:
29
  """Embed chunks and upload to Qdrant."""
30
  chunks_path = Path(chunks_file) if chunks_file else None
 
33
  embedding_batch_size=embedding_batch_size,
34
  upload_batch_size=upload_batch_size,
35
  create_collection=create_collection,
36
+ process_batch_size=process_batch_size,
37
  )
38
 
39
 
 
43
  embedding_batch_size: int = typer.Option(32, "--embedding-batch-size", "-eb"),
44
  upload_batch_size: int = typer.Option(100, "--upload-batch-size", "-ub"),
45
  create_collection: bool = typer.Option(True, "--create-collection/--no-create-collection"),
46
+ process_batch_size: int = typer.Option(10000, "--process-batch-size", "-pb", help="Process chunks in batches"),
47
  ) -> None:
48
  """Run complete pipeline: chunk → embed → index."""
49
  logger.info("Step 1/2: Chunking data")
 
55
  embedding_batch_size=embedding_batch_size,
56
  upload_batch_size=upload_batch_size,
57
  create_collection=create_collection,
58
+ process_batch_size=process_batch_size,
59
  )
60
 
61
 
src/scientific_rag/scripts/index_qdrant.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import json
2
  from pathlib import Path
3
 
@@ -10,17 +11,21 @@ from scientific_rag.infrastructure.qdrant import QdrantService
10
  from scientific_rag.settings import settings
11
 
12
 
13
- def load_chunks(chunks_file: Path) -> list[PaperChunk]:
14
- """Load chunks from JSON file."""
15
- logger.info(f"Loading chunks from {chunks_file}")
16
 
17
  with open(chunks_file, encoding="utf-8") as f:
18
  chunks_data = json.load(f)
19
 
20
- chunks = [PaperChunk(**chunk_data) for chunk_data in chunks_data]
21
- logger.info(f"Loaded {len(chunks)} chunks")
22
 
23
- return chunks
 
 
 
 
 
24
 
25
 
26
  def embed_chunks(
@@ -51,18 +56,17 @@ def index_chunks_to_qdrant(
51
  chunks: list[PaperChunk],
52
  qdrant_service: QdrantService,
53
  batch_size: int = 100,
 
54
  ) -> int:
55
  """Upload chunks to Qdrant in batches."""
56
- logger.info(f"Indexing {len(chunks)} chunks to Qdrant")
57
-
58
  total_uploaded = 0
59
 
60
- for i in tqdm(range(0, len(chunks), batch_size), desc="Uploading to Qdrant"):
 
61
  batch = chunks[i : i + batch_size]
62
  uploaded = qdrant_service.upsert_chunks(batch)
63
  total_uploaded += uploaded
64
 
65
- logger.success(f"Indexed {total_uploaded} chunks to Qdrant")
66
  return total_uploaded
67
 
68
 
@@ -71,8 +75,17 @@ def index_qdrant(
71
  embedding_batch_size: int = 32,
72
  upload_batch_size: int = 100,
73
  create_collection: bool = True,
 
74
  ) -> dict[str, int]:
75
- """Complete pipeline to index chunks to Qdrant."""
 
 
 
 
 
 
 
 
76
  if chunks_file is None:
77
  chunks_file = Path(settings.root_dir) / "data" / "processed" / f"chunks_{settings.dataset_split}.json"
78
  else:
@@ -85,27 +98,42 @@ def index_qdrant(
85
  if create_collection:
86
  qdrant_service.create_collection(vector_size=encoder.embedding_dim)
87
 
88
- chunks = load_chunks(chunks_file)
89
- chunks = embed_chunks(
90
- chunks=chunks,
91
- batch_size=embedding_batch_size,
92
- show_progress=True,
93
- )
94
- total_uploaded = index_chunks_to_qdrant(
95
- chunks=chunks,
96
- qdrant_service=qdrant_service,
97
- batch_size=upload_batch_size,
98
- )
99
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  collection_info = qdrant_service.get_collection_info()
101
  stats = {
102
- "chunks_loaded": len(chunks),
103
  "chunks_uploaded": total_uploaded,
104
  "collection_points": collection_info.get("points_count", 0),
105
  "collection_vectors": collection_info.get("index_vectors_count", 0),
106
  }
107
 
108
- logger.info(f"Indexing complete: {stats}")
109
  return stats
110
 
111
 
 
1
+ from collections.abc import Iterator
2
  import json
3
  from pathlib import Path
4
 
 
11
  from scientific_rag.settings import settings
12
 
13
 
14
+ def load_chunks_generator(chunks_file: Path, batch_size: int = 10000) -> Iterator[list[PaperChunk]]:
15
+ logger.info(f"Loading chunks from {chunks_file} in batches of {batch_size}")
 
16
 
17
  with open(chunks_file, encoding="utf-8") as f:
18
  chunks_data = json.load(f)
19
 
20
+ total_chunks = len(chunks_data)
21
+ logger.info(f"Found {total_chunks} chunks in file")
22
 
23
+ for i in range(0, total_chunks, batch_size):
24
+ batch_data = chunks_data[i : i + batch_size]
25
+ batch_chunks = [PaperChunk(**chunk_data) for chunk_data in batch_data]
26
+ yield batch_chunks
27
+
28
+ del chunks_data
29
 
30
 
31
  def embed_chunks(
 
56
  chunks: list[PaperChunk],
57
  qdrant_service: QdrantService,
58
  batch_size: int = 100,
59
+ show_progress: bool = True,
60
  ) -> int:
61
  """Upload chunks to Qdrant in batches."""
 
 
62
  total_uploaded = 0
63
 
64
+ iterator = tqdm(range(0, len(chunks), batch_size), desc="Uploading to Qdrant", disable=not show_progress)
65
+ for i in iterator:
66
  batch = chunks[i : i + batch_size]
67
  uploaded = qdrant_service.upsert_chunks(batch)
68
  total_uploaded += uploaded
69
 
 
70
  return total_uploaded
71
 
72
 
 
75
  embedding_batch_size: int = 32,
76
  upload_batch_size: int = 100,
77
  create_collection: bool = True,
78
+ process_batch_size: int = 10000,
79
  ) -> dict[str, int]:
80
+ """Complete pipeline to index chunks to Qdrant.
81
+
82
+ Args:
83
+ chunks_file: Path to chunks JSON file
84
+ embedding_batch_size: Batch size for embedding generation
85
+ upload_batch_size: Batch size for Qdrant upload
86
+ create_collection: Whether to create the collection
87
+ process_batch_size: Process chunks in batches of this size to manage memory
88
+ """
89
  if chunks_file is None:
90
  chunks_file = Path(settings.root_dir) / "data" / "processed" / f"chunks_{settings.dataset_split}.json"
91
  else:
 
98
  if create_collection:
99
  qdrant_service.create_collection(vector_size=encoder.embedding_dim)
100
 
101
+ logger.info("Processing chunks in streaming batches to manage memory...")
102
+ total_uploaded = 0
103
+ batch_num = 0
104
+
105
+ for batch_chunks in load_chunks_generator(chunks_file, batch_size=process_batch_size):
106
+ batch_num += 1
107
+ batch_start = (batch_num - 1) * process_batch_size
108
+ batch_end = batch_start + len(batch_chunks)
109
+
110
+ logger.info(f"Batch {batch_num}: Embedding chunks {batch_start}-{batch_end} ({len(batch_chunks)} chunks)...")
111
+ batch_chunks = embed_chunks(
112
+ chunks=batch_chunks,
113
+ batch_size=embedding_batch_size,
114
+ show_progress=True,
115
+ )
116
+
117
+ logger.info(f"Batch {batch_num}: Uploading chunks {batch_start}-{batch_end} to Qdrant...")
118
+ batch_uploaded = index_chunks_to_qdrant(
119
+ chunks=batch_chunks,
120
+ qdrant_service=qdrant_service,
121
+ batch_size=upload_batch_size,
122
+ show_progress=True,
123
+ )
124
+ total_uploaded += batch_uploaded
125
+
126
+ logger.success(f"Batch {batch_num} complete: {batch_uploaded} chunks uploaded (Total: {total_uploaded})")
127
+
128
+ logger.info("Getting final statistics...")
129
  collection_info = qdrant_service.get_collection_info()
130
  stats = {
 
131
  "chunks_uploaded": total_uploaded,
132
  "collection_points": collection_info.get("points_count", 0),
133
  "collection_vectors": collection_info.get("index_vectors_count", 0),
134
  }
135
 
136
+ logger.success(f"Indexing complete: {stats}")
137
  return stats
138
 
139