Spaces:
Running
Running
Add max_embedding_jobs flag.
Browse files- src/embedder.py +13 -5
- src/index.py +10 -3
src/embedder.py
CHANGED
|
@@ -34,9 +34,7 @@ class BatchEmbedder(ABC):
|
|
| 34 |
class OpenAIBatchEmbedder(BatchEmbedder):
|
| 35 |
"""Batch embedder that calls OpenAI. See https://platform.openai.com/docs/guides/batch/overview."""
|
| 36 |
|
| 37 |
-
def __init__(
|
| 38 |
-
self, repo_manager: RepoManager, chunker: Chunker, local_dir: str
|
| 39 |
-
):
|
| 40 |
self.repo_manager = repo_manager
|
| 41 |
self.chunker = chunker
|
| 42 |
self.local_dir = local_dir
|
|
@@ -44,7 +42,7 @@ class OpenAIBatchEmbedder(BatchEmbedder):
|
|
| 44 |
self.openai_batch_ids = {}
|
| 45 |
self.client = OpenAI()
|
| 46 |
|
| 47 |
-
def embed_repo(self, chunks_per_batch: int):
|
| 48 |
"""Issues batch embedding jobs for the entire repository."""
|
| 49 |
if self.openai_batch_ids:
|
| 50 |
raise ValueError("Embeddings are in progress.")
|
|
@@ -67,6 +65,14 @@ class OpenAIBatchEmbedder(BatchEmbedder):
|
|
| 67 |
self.openai_batch_ids[openai_batch_id] = self._metadata_for_chunks(
|
| 68 |
sub_batch
|
| 69 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
batch = []
|
| 71 |
|
| 72 |
# Finally, commit the last batch.
|
|
@@ -133,7 +139,9 @@ class OpenAIBatchEmbedder(BatchEmbedder):
|
|
| 133 |
OpenAIBatchEmbedder._export_to_jsonl([request], input_file)
|
| 134 |
|
| 135 |
# Upload the file and issue the embedding job.
|
| 136 |
-
batch_input_file = self.client.files.create(
|
|
|
|
|
|
|
| 137 |
batch_status = self._create_batch_job(batch_input_file.id)
|
| 138 |
logging.info("Created job with ID %s", batch_status.id)
|
| 139 |
return batch_status.id
|
|
|
|
| 34 |
class OpenAIBatchEmbedder(BatchEmbedder):
|
| 35 |
"""Batch embedder that calls OpenAI. See https://platform.openai.com/docs/guides/batch/overview."""
|
| 36 |
|
| 37 |
+
def __init__(self, repo_manager: RepoManager, chunker: Chunker, local_dir: str):
|
|
|
|
|
|
|
| 38 |
self.repo_manager = repo_manager
|
| 39 |
self.chunker = chunker
|
| 40 |
self.local_dir = local_dir
|
|
|
|
| 42 |
self.openai_batch_ids = {}
|
| 43 |
self.client = OpenAI()
|
| 44 |
|
| 45 |
+
def embed_repo(self, chunks_per_batch: int, max_embedding_jobs: int = None):
|
| 46 |
"""Issues batch embedding jobs for the entire repository."""
|
| 47 |
if self.openai_batch_ids:
|
| 48 |
raise ValueError("Embeddings are in progress.")
|
|
|
|
| 65 |
self.openai_batch_ids[openai_batch_id] = self._metadata_for_chunks(
|
| 66 |
sub_batch
|
| 67 |
)
|
| 68 |
+
if (
|
| 69 |
+
max_embedding_jobs
|
| 70 |
+
and len(self.openai_batch_ids) >= max_embedding_jobs
|
| 71 |
+
):
|
| 72 |
+
logging.info(
|
| 73 |
+
"Reached the maximum number of embedding jobs. Stopping."
|
| 74 |
+
)
|
| 75 |
+
return
|
| 76 |
batch = []
|
| 77 |
|
| 78 |
# Finally, commit the last batch.
|
|
|
|
| 139 |
OpenAIBatchEmbedder._export_to_jsonl([request], input_file)
|
| 140 |
|
| 141 |
# Upload the file and issue the embedding job.
|
| 142 |
+
batch_input_file = self.client.files.create(
|
| 143 |
+
file=open(input_file, "rb"), purpose="batch"
|
| 144 |
+
)
|
| 145 |
batch_status = self._create_batch_job(batch_input_file.id)
|
| 146 |
logging.info("Created job with ID %s", batch_status.id)
|
| 147 |
return batch_status.id
|
src/index.py
CHANGED
|
@@ -47,10 +47,17 @@ def main():
|
|
| 47 |
"--pinecone_index_name", required=True, help="Pinecone index name"
|
| 48 |
)
|
| 49 |
parser.add_argument(
|
| 50 |
-
"--include",
|
|
|
|
| 51 |
)
|
| 52 |
parser.add_argument(
|
| 53 |
-
"--exclude",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
)
|
| 55 |
|
| 56 |
args = parser.parse_args()
|
|
@@ -84,7 +91,7 @@ def main():
|
|
| 84 |
logging.info("Issuing embedding jobs...")
|
| 85 |
chunker = UniversalChunker(max_tokens=args.tokens_per_chunk)
|
| 86 |
embedder = OpenAIBatchEmbedder(repo_manager, chunker, args.local_dir)
|
| 87 |
-
embedder.embed_repo(args.chunks_per_batch)
|
| 88 |
|
| 89 |
logging.info("Waiting for embeddings to be ready...")
|
| 90 |
while not embedder.embeddings_are_ready():
|
|
|
|
| 47 |
"--pinecone_index_name", required=True, help="Pinecone index name"
|
| 48 |
)
|
| 49 |
parser.add_argument(
|
| 50 |
+
"--include",
|
| 51 |
+
help="Path to a file containing a list of extensions to include. One extension per line.",
|
| 52 |
)
|
| 53 |
parser.add_argument(
|
| 54 |
+
"--exclude",
|
| 55 |
+
help="Path to a file containing a list of extensions to exclude. One extension per line.",
|
| 56 |
+
)
|
| 57 |
+
parser.add_argument(
|
| 58 |
+
"--max_embedding_jobs", type=int,
|
| 59 |
+
help="Maximum number of embedding jobs to run. Specifying this might result in "
|
| 60 |
+
"indexing only part of the repository, but prevents you from burning through OpenAI credits.",
|
| 61 |
)
|
| 62 |
|
| 63 |
args = parser.parse_args()
|
|
|
|
| 91 |
logging.info("Issuing embedding jobs...")
|
| 92 |
chunker = UniversalChunker(max_tokens=args.tokens_per_chunk)
|
| 93 |
embedder = OpenAIBatchEmbedder(repo_manager, chunker, args.local_dir)
|
| 94 |
+
embedder.embed_repo(args.chunks_per_batch, args.max_embedding_jobs)
|
| 95 |
|
| 96 |
logging.info("Waiting for embeddings to be ready...")
|
| 97 |
while not embedder.embeddings_are_ready():
|