juliaturc committed on
Commit
a8c35cd
·
2 Parent(s): 5cf92d3 77a0875

Merge pull request #8 from Storia-AI/julia/fixes

Browse files
pyproject.toml ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ [tool.black]
2
+ line-length = 120
src/embedder.py CHANGED
@@ -34,9 +34,7 @@ class BatchEmbedder(ABC):
34
  class OpenAIBatchEmbedder(BatchEmbedder):
35
  """Batch embedder that calls OpenAI. See https://platform.openai.com/docs/guides/batch/overview."""
36
 
37
- def __init__(
38
- self, repo_manager: RepoManager, chunker: Chunker, local_dir: str
39
- ):
40
  self.repo_manager = repo_manager
41
  self.chunker = chunker
42
  self.local_dir = local_dir
@@ -44,7 +42,7 @@ class OpenAIBatchEmbedder(BatchEmbedder):
44
  self.openai_batch_ids = {}
45
  self.client = OpenAI()
46
 
47
- def embed_repo(self, chunks_per_batch: int):
48
  """Issues batch embedding jobs for the entire repository."""
49
  if self.openai_batch_ids:
50
  raise ValueError("Embeddings are in progress.")
@@ -60,24 +58,21 @@ class OpenAIBatchEmbedder(BatchEmbedder):
60
 
61
  if len(batch) > chunks_per_batch:
62
  for i in range(0, len(batch), chunks_per_batch):
63
- batch = batch[i : i + chunks_per_batch]
64
  openai_batch_id = self._issue_job_for_chunks(
65
- batch, batch_id=f"{repo_name}/{len(self.openai_batch_ids)}"
66
- )
67
- self.openai_batch_ids[openai_batch_id] = self._metadata_for_chunks(
68
- batch
69
  )
 
 
 
 
70
  batch = []
71
 
72
  # Finally, commit the last batch.
73
  if batch:
74
- openai_batch_id = self._issue_job_for_chunks(
75
- batch, batch_id=f"{repo_name}/{len(self.openai_batch_ids)}"
76
- )
77
  self.openai_batch_ids[openai_batch_id] = self._metadata_for_chunks(batch)
78
- logging.info(
79
- "Issued %d jobs for %d chunks.", len(self.openai_batch_ids), chunk_count
80
- )
81
 
82
  # Save the job IDs to a file, just in case this script is terminated by mistake.
83
  metadata_file = os.path.join(self.local_dir, "openai_batch_ids.json")
@@ -149,9 +144,7 @@ class OpenAIBatchEmbedder(BatchEmbedder):
149
  metadata={},
150
  )
151
  except Exception as e:
152
- print(
153
- f"Failed to create batch job with input_file_id={input_file_id}. Error: {e}"
154
- )
155
  return None
156
 
157
  @staticmethod
 
34
  class OpenAIBatchEmbedder(BatchEmbedder):
35
  """Batch embedder that calls OpenAI. See https://platform.openai.com/docs/guides/batch/overview."""
36
 
37
+ def __init__(self, repo_manager: RepoManager, chunker: Chunker, local_dir: str):
 
 
38
  self.repo_manager = repo_manager
39
  self.chunker = chunker
40
  self.local_dir = local_dir
 
42
  self.openai_batch_ids = {}
43
  self.client = OpenAI()
44
 
45
+ def embed_repo(self, chunks_per_batch: int, max_embedding_jobs: int = None):
46
  """Issues batch embedding jobs for the entire repository."""
47
  if self.openai_batch_ids:
48
  raise ValueError("Embeddings are in progress.")
 
58
 
59
  if len(batch) > chunks_per_batch:
60
  for i in range(0, len(batch), chunks_per_batch):
61
+ sub_batch = batch[i : i + chunks_per_batch]
62
  openai_batch_id = self._issue_job_for_chunks(
63
+ sub_batch, batch_id=f"{repo_name}/{len(self.openai_batch_ids)}"
 
 
 
64
  )
65
+ self.openai_batch_ids[openai_batch_id] = self._metadata_for_chunks(sub_batch)
66
+ if max_embedding_jobs and len(self.openai_batch_ids) >= max_embedding_jobs:
67
+ logging.info("Reached the maximum number of embedding jobs. Stopping.")
68
+ return
69
  batch = []
70
 
71
  # Finally, commit the last batch.
72
  if batch:
73
+ openai_batch_id = self._issue_job_for_chunks(batch, batch_id=f"{repo_name}/{len(self.openai_batch_ids)}")
 
 
74
  self.openai_batch_ids[openai_batch_id] = self._metadata_for_chunks(batch)
75
+ logging.info("Issued %d jobs for %d chunks.", len(self.openai_batch_ids), chunk_count)
 
 
76
 
77
  # Save the job IDs to a file, just in case this script is terminated by mistake.
78
  metadata_file = os.path.join(self.local_dir, "openai_batch_ids.json")
 
144
  metadata={},
145
  )
146
  except Exception as e:
147
+ print(f"Failed to create batch job with input_file_id={input_file_id}. Error: {e}")
 
 
148
  return None
149
 
150
  @staticmethod
src/index.py CHANGED
@@ -47,10 +47,17 @@ def main():
47
  "--pinecone_index_name", required=True, help="Pinecone index name"
48
  )
49
  parser.add_argument(
50
- "--include", help="Path to a file containing a list of extensions to include. One extension per line."
 
51
  )
52
  parser.add_argument(
53
- "--exclude", help="Path to a file containing a list of extensions to exclude. One extension per line."
 
 
 
 
 
 
54
  )
55
 
56
  args = parser.parse_args()
@@ -84,7 +91,7 @@ def main():
84
  logging.info("Issuing embedding jobs...")
85
  chunker = UniversalChunker(max_tokens=args.tokens_per_chunk)
86
  embedder = OpenAIBatchEmbedder(repo_manager, chunker, args.local_dir)
87
- embedder.embed_repo(args.chunks_per_batch)
88
 
89
  logging.info("Waiting for embeddings to be ready...")
90
  while not embedder.embeddings_are_ready():
 
47
  "--pinecone_index_name", required=True, help="Pinecone index name"
48
  )
49
  parser.add_argument(
50
+ "--include",
51
+ help="Path to a file containing a list of extensions to include. One extension per line.",
52
  )
53
  parser.add_argument(
54
+ "--exclude",
55
+ help="Path to a file containing a list of extensions to exclude. One extension per line.",
56
+ )
57
+ parser.add_argument(
58
+ "--max_embedding_jobs", type=int,
59
+ help="Maximum number of embedding jobs to run. Specifying this might result in "
60
+ "indexing only part of the repository, but prevents you from burning through OpenAI credits.",
61
  )
62
 
63
  args = parser.parse_args()
 
91
  logging.info("Issuing embedding jobs...")
92
  chunker = UniversalChunker(max_tokens=args.tokens_per_chunk)
93
  embedder = OpenAIBatchEmbedder(repo_manager, chunker, args.local_dir)
94
+ embedder.embed_repo(args.chunks_per_batch, args.max_embedding_jobs)
95
 
96
  logging.info("Waiting for embeddings to be ready...")
97
  while not embedder.embeddings_are_ready():
src/repo_manager.py CHANGED
@@ -89,6 +89,8 @@ class RepoManager:
89
 
90
  def _should_include(self, file_path: str) -> bool:
91
  """Checks whether the file should be indexed, based on the included and excluded extensions."""
 
 
92
  _, extension = os.path.splitext(file_path)
93
  extension = extension.lower()
94
  if self.included_extensions and extension not in self.included_extensions:
 
89
 
90
  def _should_include(self, file_path: str) -> bool:
91
  """Checks whether the file should be indexed, based on the included and excluded extensions."""
92
+ if os.path.islink(file_path):
93
+ return False
94
  _, extension = os.path.splitext(file_path)
95
  extension = extension.lower()
96
  if self.included_extensions and extension not in self.included_extensions:
src/sample-exclude.txt CHANGED
@@ -41,6 +41,7 @@
41
  .pt
42
  .ptl
43
  .s
 
44
  .sqlite
45
  .stl
46
  .sum
 
41
  .pt
42
  .ptl
43
  .s
44
+ .so
45
  .sqlite
46
  .stl
47
  .sum