Spaces:
Running
Running
Merge pull request #8 from Storia-AI/julia/fixes
Browse files
- pyproject.toml +2 -0
- src/embedder.py +11 -18
- src/index.py +10 -3
- src/repo_manager.py +2 -0
- src/sample-exclude.txt +1 -0
pyproject.toml
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[tool.black]
|
| 2 |
+
line-length = 120
|
src/embedder.py
CHANGED
|
@@ -34,9 +34,7 @@ class BatchEmbedder(ABC):
|
|
| 34 |
class OpenAIBatchEmbedder(BatchEmbedder):
|
| 35 |
"""Batch embedder that calls OpenAI. See https://platform.openai.com/docs/guides/batch/overview."""
|
| 36 |
|
| 37 |
-
def __init__(
|
| 38 |
-
self, repo_manager: RepoManager, chunker: Chunker, local_dir: str
|
| 39 |
-
):
|
| 40 |
self.repo_manager = repo_manager
|
| 41 |
self.chunker = chunker
|
| 42 |
self.local_dir = local_dir
|
|
@@ -44,7 +42,7 @@ class OpenAIBatchEmbedder(BatchEmbedder):
|
|
| 44 |
self.openai_batch_ids = {}
|
| 45 |
self.client = OpenAI()
|
| 46 |
|
| 47 |
-
def embed_repo(self, chunks_per_batch: int):
|
| 48 |
"""Issues batch embedding jobs for the entire repository."""
|
| 49 |
if self.openai_batch_ids:
|
| 50 |
raise ValueError("Embeddings are in progress.")
|
|
@@ -60,24 +58,21 @@ class OpenAIBatchEmbedder(BatchEmbedder):
|
|
| 60 |
|
| 61 |
if len(batch) > chunks_per_batch:
|
| 62 |
for i in range(0, len(batch), chunks_per_batch):
|
| 63 |
-
|
| 64 |
openai_batch_id = self._issue_job_for_chunks(
|
| 65 |
-
|
| 66 |
-
)
|
| 67 |
-
self.openai_batch_ids[openai_batch_id] = self._metadata_for_chunks(
|
| 68 |
-
batch
|
| 69 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
batch = []
|
| 71 |
|
| 72 |
# Finally, commit the last batch.
|
| 73 |
if batch:
|
| 74 |
-
openai_batch_id = self._issue_job_for_chunks(
|
| 75 |
-
batch, batch_id=f"{repo_name}/{len(self.openai_batch_ids)}"
|
| 76 |
-
)
|
| 77 |
self.openai_batch_ids[openai_batch_id] = self._metadata_for_chunks(batch)
|
| 78 |
-
logging.info(
|
| 79 |
-
"Issued %d jobs for %d chunks.", len(self.openai_batch_ids), chunk_count
|
| 80 |
-
)
|
| 81 |
|
| 82 |
# Save the job IDs to a file, just in case this script is terminated by mistake.
|
| 83 |
metadata_file = os.path.join(self.local_dir, "openai_batch_ids.json")
|
|
@@ -149,9 +144,7 @@ class OpenAIBatchEmbedder(BatchEmbedder):
|
|
| 149 |
metadata={},
|
| 150 |
)
|
| 151 |
except Exception as e:
|
| 152 |
-
print(
|
| 153 |
-
f"Failed to create batch job with input_file_id={input_file_id}. Error: {e}"
|
| 154 |
-
)
|
| 155 |
return None
|
| 156 |
|
| 157 |
@staticmethod
|
|
|
|
| 34 |
class OpenAIBatchEmbedder(BatchEmbedder):
|
| 35 |
"""Batch embedder that calls OpenAI. See https://platform.openai.com/docs/guides/batch/overview."""
|
| 36 |
|
| 37 |
+
def __init__(self, repo_manager: RepoManager, chunker: Chunker, local_dir: str):
|
|
|
|
|
|
|
| 38 |
self.repo_manager = repo_manager
|
| 39 |
self.chunker = chunker
|
| 40 |
self.local_dir = local_dir
|
|
|
|
| 42 |
self.openai_batch_ids = {}
|
| 43 |
self.client = OpenAI()
|
| 44 |
|
| 45 |
+
def embed_repo(self, chunks_per_batch: int, max_embedding_jobs: int = None):
|
| 46 |
"""Issues batch embedding jobs for the entire repository."""
|
| 47 |
if self.openai_batch_ids:
|
| 48 |
raise ValueError("Embeddings are in progress.")
|
|
|
|
| 58 |
|
| 59 |
if len(batch) > chunks_per_batch:
|
| 60 |
for i in range(0, len(batch), chunks_per_batch):
|
| 61 |
+
sub_batch = batch[i : i + chunks_per_batch]
|
| 62 |
openai_batch_id = self._issue_job_for_chunks(
|
| 63 |
+
sub_batch, batch_id=f"{repo_name}/{len(self.openai_batch_ids)}"
|
|
|
|
|
|
|
|
|
|
| 64 |
)
|
| 65 |
+
self.openai_batch_ids[openai_batch_id] = self._metadata_for_chunks(sub_batch)
|
| 66 |
+
if max_embedding_jobs and len(self.openai_batch_ids) >= max_embedding_jobs:
|
| 67 |
+
logging.info("Reached the maximum number of embedding jobs. Stopping.")
|
| 68 |
+
return
|
| 69 |
batch = []
|
| 70 |
|
| 71 |
# Finally, commit the last batch.
|
| 72 |
if batch:
|
| 73 |
+
openai_batch_id = self._issue_job_for_chunks(batch, batch_id=f"{repo_name}/{len(self.openai_batch_ids)}")
|
|
|
|
|
|
|
| 74 |
self.openai_batch_ids[openai_batch_id] = self._metadata_for_chunks(batch)
|
| 75 |
+
logging.info("Issued %d jobs for %d chunks.", len(self.openai_batch_ids), chunk_count)
|
|
|
|
|
|
|
| 76 |
|
| 77 |
# Save the job IDs to a file, just in case this script is terminated by mistake.
|
| 78 |
metadata_file = os.path.join(self.local_dir, "openai_batch_ids.json")
|
|
|
|
| 144 |
metadata={},
|
| 145 |
)
|
| 146 |
except Exception as e:
|
| 147 |
+
print(f"Failed to create batch job with input_file_id={input_file_id}. Error: {e}")
|
|
|
|
|
|
|
| 148 |
return None
|
| 149 |
|
| 150 |
@staticmethod
|
src/index.py
CHANGED
|
@@ -47,10 +47,17 @@ def main():
|
|
| 47 |
"--pinecone_index_name", required=True, help="Pinecone index name"
|
| 48 |
)
|
| 49 |
parser.add_argument(
|
| 50 |
-
"--include",
|
|
|
|
| 51 |
)
|
| 52 |
parser.add_argument(
|
| 53 |
-
"--exclude",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
)
|
| 55 |
|
| 56 |
args = parser.parse_args()
|
|
@@ -84,7 +91,7 @@ def main():
|
|
| 84 |
logging.info("Issuing embedding jobs...")
|
| 85 |
chunker = UniversalChunker(max_tokens=args.tokens_per_chunk)
|
| 86 |
embedder = OpenAIBatchEmbedder(repo_manager, chunker, args.local_dir)
|
| 87 |
-
embedder.embed_repo(args.chunks_per_batch)
|
| 88 |
|
| 89 |
logging.info("Waiting for embeddings to be ready...")
|
| 90 |
while not embedder.embeddings_are_ready():
|
|
|
|
| 47 |
"--pinecone_index_name", required=True, help="Pinecone index name"
|
| 48 |
)
|
| 49 |
parser.add_argument(
|
| 50 |
+
"--include",
|
| 51 |
+
help="Path to a file containing a list of extensions to include. One extension per line.",
|
| 52 |
)
|
| 53 |
parser.add_argument(
|
| 54 |
+
"--exclude",
|
| 55 |
+
help="Path to a file containing a list of extensions to exclude. One extension per line.",
|
| 56 |
+
)
|
| 57 |
+
parser.add_argument(
|
| 58 |
+
"--max_embedding_jobs", type=int,
|
| 59 |
+
help="Maximum number of embedding jobs to run. Specifying this might result in "
|
| 60 |
+
"indexing only part of the repository, but prevents you from burning through OpenAI credits.",
|
| 61 |
)
|
| 62 |
|
| 63 |
args = parser.parse_args()
|
|
|
|
| 91 |
logging.info("Issuing embedding jobs...")
|
| 92 |
chunker = UniversalChunker(max_tokens=args.tokens_per_chunk)
|
| 93 |
embedder = OpenAIBatchEmbedder(repo_manager, chunker, args.local_dir)
|
| 94 |
+
embedder.embed_repo(args.chunks_per_batch, args.max_embedding_jobs)
|
| 95 |
|
| 96 |
logging.info("Waiting for embeddings to be ready...")
|
| 97 |
while not embedder.embeddings_are_ready():
|
src/repo_manager.py
CHANGED
|
@@ -89,6 +89,8 @@ class RepoManager:
|
|
| 89 |
|
| 90 |
def _should_include(self, file_path: str) -> bool:
|
| 91 |
"""Checks whether the file should be indexed, based on the included and excluded extensions."""
|
|
|
|
|
|
|
| 92 |
_, extension = os.path.splitext(file_path)
|
| 93 |
extension = extension.lower()
|
| 94 |
if self.included_extensions and extension not in self.included_extensions:
|
|
|
|
| 89 |
|
| 90 |
def _should_include(self, file_path: str) -> bool:
|
| 91 |
"""Checks whether the file should be indexed, based on the included and excluded extensions."""
|
| 92 |
+
if os.path.islink(file_path):
|
| 93 |
+
return False
|
| 94 |
_, extension = os.path.splitext(file_path)
|
| 95 |
extension = extension.lower()
|
| 96 |
if self.included_extensions and extension not in self.included_extensions:
|
src/sample-exclude.txt
CHANGED
|
@@ -41,6 +41,7 @@
|
|
| 41 |
.pt
|
| 42 |
.ptl
|
| 43 |
.s
|
|
|
|
| 44 |
.sqlite
|
| 45 |
.stl
|
| 46 |
.sum
|
|
|
|
| 41 |
.pt
|
| 42 |
.ptl
|
| 43 |
.s
|
| 44 |
+
.so
|
| 45 |
.sqlite
|
| 46 |
.stl
|
| 47 |
.sum
|