Julia Turc committed on
Commit
fcecd4b
·
1 Parent(s): 210c3c5

Apply formatting

Browse files
Files changed (1) hide show
  1. sage/embedder.py +10 -9
sage/embedder.py CHANGED
@@ -7,12 +7,13 @@ import time
7
  from abc import ABC, abstractmethod
8
  from collections import Counter
9
  from typing import Dict, Generator, List, Optional, Tuple
10
- from tqdm import tqdm
11
  import google.generativeai as genai
12
  import marqo
13
  import requests
14
  from openai import OpenAI
15
  from tenacity import retry, stop_after_attempt, wait_random_exponential
 
16
 
17
  from sage.chunker import Chunk, Chunker
18
  from sage.constants import TEXT_FIELD
@@ -53,14 +54,14 @@ class OpenAIBatchEmbedder(BatchEmbedder):
53
  def embed_dataset(self, chunks_per_batch: int, max_embedding_jobs: int = None) -> str:
54
  """Issues batch embedding jobs for the entire dataset. Returns the filename containing the job IDs."""
55
  num_files = len([x for x in self.data_manager.walk(get_content=False)])
56
-
57
  batch = []
58
  batch_ids = {} # job_id -> metadata
59
  chunk_count = 0
60
  dataset_name = self.data_manager.dataset_id.replace("/", "_")
61
 
62
  pbar = tqdm(total=num_files, desc="Processing chunks", unit="chunk")
63
-
64
  for content, metadata in self.data_manager.walk():
65
  chunks = self.chunker.chunk(content, metadata)
66
  chunk_count += len(chunks)
@@ -81,7 +82,7 @@ class OpenAIBatchEmbedder(BatchEmbedder):
81
  if batch:
82
  openai_batch_id = self._issue_job_for_chunks(batch, batch_id=f"{dataset_name}/{len(batch_ids)}")
83
  batch_ids[openai_batch_id] = [chunk.metadata for chunk in batch]
84
-
85
  logging.info("Issued %d jobs for %d chunks.", len(batch_ids), chunk_count)
86
 
87
  timestamp = int(time.time())
@@ -227,16 +228,16 @@ class VoyageBatchEmbedder(BatchEmbedder):
227
  def embed_dataset(self, chunks_per_batch: int, max_embedding_jobs: int = None):
228
  """Issues batch embedding jobs for the entire dataset."""
229
  num_files = len([x for x in self.data_manager.walk(get_content=False)])
230
-
231
  batch = []
232
  chunk_count = 0
233
 
234
  pbar = tqdm(total=num_files, desc="Processing chunks", unit="chunk")
235
-
236
  for content, metadata in self.data_manager.walk():
237
  chunks = self.chunker.chunk(content, metadata)
238
  chunk_count += len(chunks)
239
- batch.extend(chunks)
240
  pbar.update(1)
241
 
242
  token_count = chunk_count * self.chunker.max_tokens
@@ -317,7 +318,7 @@ class MarqoEmbedder(BatchEmbedder):
317
  chunks = self.chunker.chunk(content, metadata)
318
  chunk_count += len(chunks)
319
  batch.extend(chunks)
320
- pbar.update(1)
321
  if len(batch) > chunks_per_batch:
322
  for i in range(0, len(batch), chunks_per_batch):
323
  sub_batch = batch[i : i + chunks_per_batch]
@@ -379,7 +380,7 @@ class GeminiBatchEmbedder(BatchEmbedder):
379
  chunks = self.chunker.chunk(content, metadata)
380
  chunk_count += len(chunks)
381
  batch.extend(chunks)
382
- pbar.update(1)
383
 
384
  if len(batch) > chunks_per_batch:
385
  for i in range(0, len(batch), chunks_per_batch):
 
7
  from abc import ABC, abstractmethod
8
  from collections import Counter
9
  from typing import Dict, Generator, List, Optional, Tuple
10
+
11
  import google.generativeai as genai
12
  import marqo
13
  import requests
14
  from openai import OpenAI
15
  from tenacity import retry, stop_after_attempt, wait_random_exponential
16
+ from tqdm import tqdm
17
 
18
  from sage.chunker import Chunk, Chunker
19
  from sage.constants import TEXT_FIELD
 
54
  def embed_dataset(self, chunks_per_batch: int, max_embedding_jobs: int = None) -> str:
55
  """Issues batch embedding jobs for the entire dataset. Returns the filename containing the job IDs."""
56
  num_files = len([x for x in self.data_manager.walk(get_content=False)])
57
+
58
  batch = []
59
  batch_ids = {} # job_id -> metadata
60
  chunk_count = 0
61
  dataset_name = self.data_manager.dataset_id.replace("/", "_")
62
 
63
  pbar = tqdm(total=num_files, desc="Processing chunks", unit="chunk")
64
+
65
  for content, metadata in self.data_manager.walk():
66
  chunks = self.chunker.chunk(content, metadata)
67
  chunk_count += len(chunks)
 
82
  if batch:
83
  openai_batch_id = self._issue_job_for_chunks(batch, batch_id=f"{dataset_name}/{len(batch_ids)}")
84
  batch_ids[openai_batch_id] = [chunk.metadata for chunk in batch]
85
+
86
  logging.info("Issued %d jobs for %d chunks.", len(batch_ids), chunk_count)
87
 
88
  timestamp = int(time.time())
 
228
  def embed_dataset(self, chunks_per_batch: int, max_embedding_jobs: int = None):
229
  """Issues batch embedding jobs for the entire dataset."""
230
  num_files = len([x for x in self.data_manager.walk(get_content=False)])
231
+
232
  batch = []
233
  chunk_count = 0
234
 
235
  pbar = tqdm(total=num_files, desc="Processing chunks", unit="chunk")
236
+
237
  for content, metadata in self.data_manager.walk():
238
  chunks = self.chunker.chunk(content, metadata)
239
  chunk_count += len(chunks)
240
+ batch.extend(chunks)
241
  pbar.update(1)
242
 
243
  token_count = chunk_count * self.chunker.max_tokens
 
318
  chunks = self.chunker.chunk(content, metadata)
319
  chunk_count += len(chunks)
320
  batch.extend(chunks)
321
+ pbar.update(1)
322
  if len(batch) > chunks_per_batch:
323
  for i in range(0, len(batch), chunks_per_batch):
324
  sub_batch = batch[i : i + chunks_per_batch]
 
380
  chunks = self.chunker.chunk(content, metadata)
381
  chunk_count += len(chunks)
382
  batch.extend(chunks)
383
+ pbar.update(1)
384
 
385
  if len(batch) > chunks_per_batch:
386
  for i in range(0, len(batch), chunks_per_batch):