juliaturc committed on
Commit
eab5126
·
1 Parent(s): d0e366f

Add max_embedding_jobs flag.

Browse files
Files changed (2) hide show
  1. src/embedder.py +13 -5
  2. src/index.py +10 -3
src/embedder.py CHANGED
@@ -34,9 +34,7 @@ class BatchEmbedder(ABC):
34
  class OpenAIBatchEmbedder(BatchEmbedder):
35
  """Batch embedder that calls OpenAI. See https://platform.openai.com/docs/guides/batch/overview."""
36
 
37
- def __init__(
38
- self, repo_manager: RepoManager, chunker: Chunker, local_dir: str
39
- ):
40
  self.repo_manager = repo_manager
41
  self.chunker = chunker
42
  self.local_dir = local_dir
@@ -44,7 +42,7 @@ class OpenAIBatchEmbedder(BatchEmbedder):
44
  self.openai_batch_ids = {}
45
  self.client = OpenAI()
46
 
47
- def embed_repo(self, chunks_per_batch: int):
48
  """Issues batch embedding jobs for the entire repository."""
49
  if self.openai_batch_ids:
50
  raise ValueError("Embeddings are in progress.")
@@ -67,6 +65,14 @@ class OpenAIBatchEmbedder(BatchEmbedder):
67
  self.openai_batch_ids[openai_batch_id] = self._metadata_for_chunks(
68
  sub_batch
69
  )
 
 
 
 
 
 
 
 
70
  batch = []
71
 
72
  # Finally, commit the last batch.
@@ -133,7 +139,9 @@ class OpenAIBatchEmbedder(BatchEmbedder):
133
  OpenAIBatchEmbedder._export_to_jsonl([request], input_file)
134
 
135
  # Upload the file and issue the embedding job.
136
- batch_input_file = self.client.files.create(file=open(input_file, "rb"), purpose="batch")
 
 
137
  batch_status = self._create_batch_job(batch_input_file.id)
138
  logging.info("Created job with ID %s", batch_status.id)
139
  return batch_status.id
 
34
  class OpenAIBatchEmbedder(BatchEmbedder):
35
  """Batch embedder that calls OpenAI. See https://platform.openai.com/docs/guides/batch/overview."""
36
 
37
+ def __init__(self, repo_manager: RepoManager, chunker: Chunker, local_dir: str):
 
 
38
  self.repo_manager = repo_manager
39
  self.chunker = chunker
40
  self.local_dir = local_dir
 
42
  self.openai_batch_ids = {}
43
  self.client = OpenAI()
44
 
45
+ def embed_repo(self, chunks_per_batch: int, max_embedding_jobs: int = None):
46
  """Issues batch embedding jobs for the entire repository."""
47
  if self.openai_batch_ids:
48
  raise ValueError("Embeddings are in progress.")
 
65
  self.openai_batch_ids[openai_batch_id] = self._metadata_for_chunks(
66
  sub_batch
67
  )
68
+ if (
69
+ max_embedding_jobs
70
+ and len(self.openai_batch_ids) >= max_embedding_jobs
71
+ ):
72
+ logging.info(
73
+ "Reached the maximum number of embedding jobs. Stopping."
74
+ )
75
+ return
76
  batch = []
77
 
78
  # Finally, commit the last batch.
 
139
  OpenAIBatchEmbedder._export_to_jsonl([request], input_file)
140
 
141
  # Upload the file and issue the embedding job.
142
+ batch_input_file = self.client.files.create(
143
+ file=open(input_file, "rb"), purpose="batch"
144
+ )
145
  batch_status = self._create_batch_job(batch_input_file.id)
146
  logging.info("Created job with ID %s", batch_status.id)
147
  return batch_status.id
src/index.py CHANGED
@@ -47,10 +47,17 @@ def main():
47
  "--pinecone_index_name", required=True, help="Pinecone index name"
48
  )
49
  parser.add_argument(
50
- "--include", help="Path to a file containing a list of extensions to include. One extension per line."
 
51
  )
52
  parser.add_argument(
53
- "--exclude", help="Path to a file containing a list of extensions to exclude. One extension per line."
 
 
 
 
 
 
54
  )
55
 
56
  args = parser.parse_args()
@@ -84,7 +91,7 @@ def main():
84
  logging.info("Issuing embedding jobs...")
85
  chunker = UniversalChunker(max_tokens=args.tokens_per_chunk)
86
  embedder = OpenAIBatchEmbedder(repo_manager, chunker, args.local_dir)
87
- embedder.embed_repo(args.chunks_per_batch)
88
 
89
  logging.info("Waiting for embeddings to be ready...")
90
  while not embedder.embeddings_are_ready():
 
47
  "--pinecone_index_name", required=True, help="Pinecone index name"
48
  )
49
  parser.add_argument(
50
+ "--include",
51
+ help="Path to a file containing a list of extensions to include. One extension per line.",
52
  )
53
  parser.add_argument(
54
+ "--exclude",
55
+ help="Path to a file containing a list of extensions to exclude. One extension per line.",
56
+ )
57
+ parser.add_argument(
58
+ "--max_embedding_jobs", type=int,
59
+ help="Maximum number of embedding jobs to run. Specifying this might result in "
60
+ "indexing only part of the repository, but prevents you from burning through OpenAI credits.",
61
  )
62
 
63
  args = parser.parse_args()
 
91
  logging.info("Issuing embedding jobs...")
92
  chunker = UniversalChunker(max_tokens=args.tokens_per_chunk)
93
  embedder = OpenAIBatchEmbedder(repo_manager, chunker, args.local_dir)
94
+ embedder.embed_repo(args.chunks_per_batch, args.max_embedding_jobs)
95
 
96
  logging.info("Waiting for embeddings to be ready...")
97
  while not embedder.embeddings_are_ready():