juliaturc committed on
Commit
d4d6659
·
1 Parent(s): 8f148e0

Update "Processing chunks" to "Processing files"

Browse files
Files changed (1) hide show
  1. sage/embedder.py +11 -10
sage/embedder.py CHANGED
@@ -53,14 +53,13 @@ class OpenAIBatchEmbedder(BatchEmbedder):
53
 
54
  def embed_dataset(self, chunks_per_batch: int, max_embedding_jobs: int = None) -> str:
55
  """Issues batch embedding jobs for the entire dataset. Returns the filename containing the job IDs."""
56
- num_files = len([x for x in self.data_manager.walk(get_content=False)])
57
-
58
  batch = []
59
  batch_ids = {} # job_id -> metadata
60
  chunk_count = 0
61
  dataset_name = self.data_manager.dataset_id.replace("/", "_")
62
 
63
- pbar = tqdm(total=num_files, desc="Processing chunks", unit="chunk")
 
64
 
65
  for content, metadata in self.data_manager.walk():
66
  chunks = self.chunker.chunk(content, metadata)
@@ -227,12 +226,11 @@ class VoyageBatchEmbedder(BatchEmbedder):
227
 
228
  def embed_dataset(self, chunks_per_batch: int, max_embedding_jobs: int = None):
229
  """Issues batch embedding jobs for the entire dataset."""
230
- num_files = len([x for x in self.data_manager.walk(get_content=False)])
231
-
232
  batch = []
233
  chunk_count = 0
234
 
235
- pbar = tqdm(total=num_files, desc="Processing chunks", unit="chunk")
 
236
 
237
  for content, metadata in self.data_manager.walk():
238
  chunks = self.chunker.chunk(content, metadata)
@@ -308,11 +306,12 @@ class MarqoEmbedder(BatchEmbedder):
308
  if chunks_per_batch > 64:
309
  raise ValueError("Marqo enforces a limit of 64 chunks per batch.")
310
 
311
- num_files = len([x for x in self.data_manager.walk(get_content=False)])
312
  chunk_count = 0
313
  batch = []
314
  job_count = 0
315
- pbar = tqdm(total=num_files, desc="Processing chunks", unit="file")
 
 
316
 
317
  for content, metadata in self.data_manager.walk():
318
  chunks = self.chunker.chunk(content, metadata)
@@ -369,13 +368,15 @@ class GeminiBatchEmbedder(BatchEmbedder):
369
 
370
  def embed_dataset(self, chunks_per_batch: int, max_embedding_jobs: int = None):
371
  """Issues batch embedding jobs for the entire dataset."""
372
- num_files = len([x for x in self.data_manager.walk(get_content=False)])
373
  batch = []
374
  chunk_count = 0
375
 
376
  request_count = 0
377
  last_request_time = time.time()
378
- pbar = tqdm(total=num_files, desc="Processing chunks", unit="file")
 
 
 
379
  for content, metadata in self.data_manager.walk():
380
  chunks = self.chunker.chunk(content, metadata)
381
  chunk_count += len(chunks)
 
53
 
54
  def embed_dataset(self, chunks_per_batch: int, max_embedding_jobs: int = None) -> str:
55
  """Issues batch embedding jobs for the entire dataset. Returns the filename containing the job IDs."""
 
 
56
  batch = []
57
  batch_ids = {} # job_id -> metadata
58
  chunk_count = 0
59
  dataset_name = self.data_manager.dataset_id.replace("/", "_")
60
 
61
+ num_files = len([x for x in self.data_manager.walk(get_content=False)])
62
+ pbar = tqdm(total=num_files, desc="Processing files", unit="file")
63
 
64
  for content, metadata in self.data_manager.walk():
65
  chunks = self.chunker.chunk(content, metadata)
 
226
 
227
  def embed_dataset(self, chunks_per_batch: int, max_embedding_jobs: int = None):
228
  """Issues batch embedding jobs for the entire dataset."""
 
 
229
  batch = []
230
  chunk_count = 0
231
 
232
+ num_files = len([x for x in self.data_manager.walk(get_content=False)])
233
+ pbar = tqdm(total=num_files, desc="Processing files", unit="file")
234
 
235
  for content, metadata in self.data_manager.walk():
236
  chunks = self.chunker.chunk(content, metadata)
 
306
  if chunks_per_batch > 64:
307
  raise ValueError("Marqo enforces a limit of 64 chunks per batch.")
308
 
 
309
  chunk_count = 0
310
  batch = []
311
  job_count = 0
312
+
313
+ num_files = len([x for x in self.data_manager.walk(get_content=False)])
314
+ pbar = tqdm(total=num_files, desc="Processing files", unit="file")
315
 
316
  for content, metadata in self.data_manager.walk():
317
  chunks = self.chunker.chunk(content, metadata)
 
368
 
369
  def embed_dataset(self, chunks_per_batch: int, max_embedding_jobs: int = None):
370
  """Issues batch embedding jobs for the entire dataset."""
 
371
  batch = []
372
  chunk_count = 0
373
 
374
  request_count = 0
375
  last_request_time = time.time()
376
+
377
+ num_files = len([x for x in self.data_manager.walk(get_content=False)])
378
+ pbar = tqdm(total=num_files, desc="Processing files", unit="file")
379
+
380
  for content, metadata in self.data_manager.walk():
381
  chunks = self.chunker.chunk(content, metadata)
382
  chunk_count += len(chunks)