Saksham Chaudhary committed
Commit 1dc0365 · 1 Parent(s): 28fa826

Added the progress bar (#104)

Files changed (1):
  1. sage/embedder.py (+26, -9)
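
The same change is applied to all four embedders (OpenAI, Voyage, Marqo, Gemini): a sizing pass counts the files via a metadata-only walk, a tqdm bar ticks once per file inside the embedding loop, and the bar is closed at the end of the method (the Marqo embedder also closes it before its early return). A minimal sketch of the pattern, assuming, as the diff does, that the data manager's walk(get_content=False) yields entries without reading file contents:

    from tqdm import tqdm

    def embed_dataset(self, chunks_per_batch: int, max_embedding_jobs: int = None):
        # Sizing pass: walk the dataset metadata-only to learn the file count.
        num_files = len([x for x in self.data_manager.walk(get_content=False)])

        pbar = tqdm(total=num_files, desc="Processing chunks", unit="file")
        for content, metadata in self.data_manager.walk():
            # ... chunk the file and issue embedding jobs, as before ...
            pbar.update(1)  # one tick per file walked
        pbar.close()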
sage/embedder.py CHANGED

@@ -7,7 +7,7 @@ import time
 from abc import ABC, abstractmethod
 from collections import Counter
 from typing import Dict, Generator, List, Optional, Tuple
-
+from tqdm import tqdm
 import google.generativeai as genai
 import marqo
 import requests
@@ -52,15 +52,20 @@ class OpenAIBatchEmbedder(BatchEmbedder):
 
     def embed_dataset(self, chunks_per_batch: int, max_embedding_jobs: int = None) -> str:
         """Issues batch embedding jobs for the entire dataset. Returns the filename containing the job IDs."""
+        num_files = len([x for x in self.data_manager.walk(get_content=False)])
+
         batch = []
         batch_ids = {}  # job_id -> metadata
         chunk_count = 0
         dataset_name = self.data_manager.dataset_id.replace("/", "_")
 
+        pbar = tqdm(total=num_files, desc="Processing chunks", unit="chunk")
+
         for content, metadata in self.data_manager.walk():
             chunks = self.chunker.chunk(content, metadata)
             chunk_count += len(chunks)
             batch.extend(chunks)
+            pbar.update(1)
 
             if len(batch) > chunks_per_batch:
                 for i in range(0, len(batch), chunks_per_batch):
@@ -76,6 +81,7 @@ class OpenAIBatchEmbedder(BatchEmbedder):
         if batch:
             openai_batch_id = self._issue_job_for_chunks(batch, batch_id=f"{dataset_name}/{len(batch_ids)}")
             batch_ids[openai_batch_id] = [chunk.metadata for chunk in batch]
+
         logging.info("Issued %d jobs for %d chunks.", len(batch_ids), chunk_count)
 
         timestamp = int(time.time())
@@ -83,6 +89,7 @@ class OpenAIBatchEmbedder(BatchEmbedder):
         with open(metadata_file, "w") as f:
             json.dump(batch_ids, f)
         logging.info("Job metadata saved at %s", metadata_file)
+        pbar.close()
         return metadata_file
 
     def embeddings_are_ready(self, metadata_file: str) -> bool:
@@ -219,13 +226,18 @@ class VoyageBatchEmbedder(BatchEmbedder):
 
     def embed_dataset(self, chunks_per_batch: int, max_embedding_jobs: int = None):
         """Issues batch embedding jobs for the entire dataset."""
+        num_files = len([x for x in self.data_manager.walk(get_content=False)])
+
         batch = []
         chunk_count = 0
 
+        pbar = tqdm(total=num_files, desc="Processing chunks", unit="chunk")
+
         for content, metadata in self.data_manager.walk():
             chunks = self.chunker.chunk(content, metadata)
             chunk_count += len(chunks)
-            batch.extend(chunks)
+            batch.extend(chunks)
+            pbar.update(1)
 
             token_count = chunk_count * self.chunker.max_tokens
             if token_count % 900_000 == 0:
@@ -247,7 +259,7 @@ class VoyageBatchEmbedder(BatchEmbedder):
             result = self._make_batch_request(batch)
             for chunk, datum in zip(batch, result["data"]):
                 self.embedding_data.append((chunk.metadata, datum["embedding"]))
-
+        pbar.close()
         logging.info(f"Successfully embedded {chunk_count} chunks.")
 
     def embeddings_are_ready(self, *args, **kwargs) -> bool:
@@ -291,19 +303,21 @@ class MarqoEmbedder(BatchEmbedder):
         self.client.create_index(index_name, model=model)
 
     def embed_dataset(self, chunks_per_batch: int, max_embedding_jobs: int = None):
-        """Issues batch embedding jobs for the entire dataset."""
+        """Issues batch embedding jobs for the entire dataset with progress tracking."""
         if chunks_per_batch > 64:
             raise ValueError("Marqo enforces a limit of 64 chunks per batch.")
 
+        num_files = len([x for x in self.data_manager.walk(get_content=False)])
         chunk_count = 0
         batch = []
         job_count = 0
+        pbar = tqdm(total=num_files, desc="Processing chunks", unit="file")
 
         for content, metadata in self.data_manager.walk():
             chunks = self.chunker.chunk(content, metadata)
             chunk_count += len(chunks)
             batch.extend(chunks)
-
+            pbar.update(1)
             if len(batch) > chunks_per_batch:
                 for i in range(0, len(batch), chunks_per_batch):
                     sub_batch = batch[i : i + chunks_per_batch]
@@ -316,12 +330,13 @@ class MarqoEmbedder(BatchEmbedder):
 
                     if max_embedding_jobs and job_count >= max_embedding_jobs:
                         logging.info("Reached the maximum number of embedding jobs. Stopping.")
+                        pbar.close()
                         return
                 batch = []
-
-        # Finally, commit the last batch.
         if batch:
             self.index.add_documents(documents=[chunk.metadata for chunk in batch], tensor_fields=[TEXT_FIELD])
+
+        pbar.close()
         logging.info(f"Successfully embedded {chunk_count} chunks.")
 
     def embeddings_are_ready(self) -> bool:
@@ -353,16 +368,18 @@ class GeminiBatchEmbedder(BatchEmbedder):
 
     def embed_dataset(self, chunks_per_batch: int, max_embedding_jobs: int = None):
        """Issues batch embedding jobs for the entire dataset."""
+        num_files = len([x for x in self.data_manager.walk(get_content=False)])
         batch = []
         chunk_count = 0
 
         request_count = 0
         last_request_time = time.time()
+        pbar = tqdm(total=num_files, desc="Processing chunks", unit="file")
         for content, metadata in self.data_manager.walk():
             chunks = self.chunker.chunk(content, metadata)
             chunk_count += len(chunks)
             batch.extend(chunks)
+            pbar.update(1)
 
             if len(batch) > chunks_per_batch:
                 for i in range(0, len(batch), chunks_per_batch):
@@ -395,7 +412,7 @@ class GeminiBatchEmbedder(BatchEmbedder):
             result = self._make_batch_request(batch)
             for chunk, embedding in zip(batch, result["embedding"]):
                 self.embedding_data.append((chunk.metadata, embedding))
-
+        pbar.close()
         logging.info(f"Successfully embedded {chunk_count} chunks.")
 
     def embeddings_are_ready(self, *args, **kwargs) -> bool:
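
One design note: the bar advances once per file, so unit="file" (as in the Marqo and Gemini embedders) labels the tick more accurately than unit="chunk" (as in the OpenAI and Voyage embedders). Also, len([x for x in ...]) materializes the whole walk just to count it; a generator sum counts without the intermediate list, and passing the iterator to tqdm directly drops the manual update/close bookkeeping. An alternative sketch of the same behavior, not what this commit ships:

    from tqdm import tqdm

    # Count files without building an intermediate list.
    num_files = sum(1 for _ in self.data_manager.walk(get_content=False))

    # tqdm wraps the iterator, updating once per yielded file and
    # closing itself when the iterator is exhausted.
    for content, metadata in tqdm(self.data_manager.walk(), total=num_files,
                                  desc="Processing files", unit="file"):
        ...  # chunk and embed as before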