juliaturc committed on
Commit
2730f0a
·
1 Parent(s): 7a1bb92

Option to index repo at specific commit.

Browse files
Files changed (3) hide show
  1. sage/data_manager.py +8 -1
  2. sage/index.py +4 -3
  3. sage/vector_store.py +14 -2
sage/data_manager.py CHANGED
@@ -29,6 +29,7 @@ class GitHubRepoManager(DataManager):
29
  def __init__(
30
  self,
31
  repo_id: str,
 
32
  local_dir: str = None,
33
  inclusion_file: str = None,
34
  exclusion_file: str = None,
@@ -36,6 +37,7 @@ class GitHubRepoManager(DataManager):
36
  """
37
  Args:
38
  repo_id: The identifier of the repository in owner/repo format, e.g. "Storia-AI/sage".
 
39
  local_dir: The local directory where the repository will be cloned.
40
  inclusion_file: A file with a list of files/directories/extensions to include. Each line must be in one of
41
  the following formats: "ext:.my-extension", "file:my-file.py", or "dir:my-directory".
@@ -44,6 +46,7 @@ class GitHubRepoManager(DataManager):
44
  """
45
  super().__init__(dataset_id=repo_id)
46
  self.repo_id = repo_id
 
47
 
48
  self.local_dir = local_dir or "/tmp/"
49
  if not os.path.exists(self.local_dir):
@@ -103,7 +106,11 @@ class GitHubRepoManager(DataManager):
103
  clone_url = f"https://github.com/{self.repo_id}.git"
104
 
105
  try:
106
- Repo.clone_from(clone_url, self.local_path, depth=1, single_branch=True)
 
 
 
 
107
  except GitCommandError as e:
108
  logging.error("Unable to clone %s from %s. Error: %s", self.repo_id, clone_url, e)
109
  return False
 
29
  def __init__(
30
  self,
31
  repo_id: str,
32
+ commit_hash: str = None,
33
  local_dir: str = None,
34
  inclusion_file: str = None,
35
  exclusion_file: str = None,
 
37
  """
38
  Args:
39
  repo_id: The identifier of the repository in owner/repo format, e.g. "Storia-AI/sage".
40
+ commit_hash: Optional commit hash to checkout. If not specified, we pull the latest version of the repo.
41
  local_dir: The local directory where the repository will be cloned.
42
  inclusion_file: A file with a list of files/directories/extensions to include. Each line must be in one of
43
  the following formats: "ext:.my-extension", "file:my-file.py", or "dir:my-directory".
 
46
  """
47
  super().__init__(dataset_id=repo_id)
48
  self.repo_id = repo_id
49
+ self.commit_hash = commit_hash
50
 
51
  self.local_dir = local_dir or "/tmp/"
52
  if not os.path.exists(self.local_dir):
 
106
  clone_url = f"https://github.com/{self.repo_id}.git"
107
 
108
  try:
109
+ if self.commit_hash:
110
+ repo = Repo.clone_from(clone_url, self.local_path)
111
+ repo.git.checkout(self.commit_hash)
112
+ else:
113
+ Repo.clone_from(clone_url, self.local_path, depth=1, single_branch=True)
114
  except GitCommandError as e:
115
  logging.error("Unable to clone %s from %s. Error: %s", self.repo_id, clone_url, e)
116
  return False
sage/index.py CHANGED
@@ -38,6 +38,7 @@ OPENAI_DEFAULT_EMBEDDING_SIZE = {
38
  def main():
39
  parser = argparse.ArgumentParser(description="Batch-embeds a GitHub repository and its issues.")
40
  parser.add_argument("repo_id", help="The ID of the repository to index")
 
41
  parser.add_argument("--embedder-type", default="marqo", choices=["openai", "marqo"])
42
  parser.add_argument(
43
  "--embedding-model",
@@ -72,9 +73,8 @@ def main():
72
  parser.add_argument(
73
  "--index-name",
74
  default=None,
75
- help="Vector store index name. For Marqo, we default it to the repository name. Required for Pinecone, since "
76
- "it needs to be created manually on their website. In Pinecone terminology, this is *not* the namespace (which "
77
- "we default to the repo ID).",
78
  )
79
  parser.add_argument(
80
  "--include",
@@ -202,6 +202,7 @@ def main():
202
  logging.info("Cloning the repository...")
203
  repo_manager = GitHubRepoManager(
204
  args.repo_id,
 
205
  local_dir=args.local_dir,
206
  inclusion_file=args.include,
207
  exclusion_file=args.exclude,
 
38
  def main():
39
  parser = argparse.ArgumentParser(description="Batch-embeds a GitHub repository and its issues.")
40
  parser.add_argument("repo_id", help="The ID of the repository to index")
41
+ parser.add_argument("--commit-hash", help="Optional commit hash to checkout. When not provided, defaults to HEAD.")
42
  parser.add_argument("--embedder-type", default="marqo", choices=["openai", "marqo"])
43
  parser.add_argument(
44
  "--embedding-model",
 
73
  parser.add_argument(
74
  "--index-name",
75
  default=None,
76
+ help="Vector store index name. For Marqo, we default it to the repository name. Required for Pinecone. "
77
+ "In Pinecone terminology, this is *not* the namespace (which we default to the repo ID).",
 
78
  )
79
  parser.add_argument(
80
  "--include",
 
202
  logging.info("Cloning the repository...")
203
  repo_manager = GitHubRepoManager(
204
  args.repo_id,
205
+ commit_hash=args.commit_hash,
206
  local_dir=args.local_dir,
207
  inclusion_file=args.include,
208
  exclusion_file=args.exclude,
sage/vector_store.py CHANGED
@@ -152,12 +152,24 @@ def build_from_args(args: dict) -> VectorStore:
152
  if not args.index_name:
153
  raise ValueError("Please specify --index-name for Pinecone.")
154
  dimension = args.embedding_size if "embedding_size" in args else None
 
 
 
 
 
155
  return PineconeVectorStore(
156
- index_name=args.index_name, namespace=args.repo_id, dimension=dimension, hybrid=args.hybrid_retrieval
157
  )
158
  elif args.vector_store_type == "marqo":
159
  marqo_url = args.marqo_url or "http://localhost:8882"
160
- index_name = args.index_name or args.repo_id.split("/")[1]
 
 
 
 
 
 
 
161
  return MarqoVectorStore(url=marqo_url, index_name=index_name)
162
  else:
163
  raise ValueError(f"Unrecognized vector store type {args.vector_store_type}")
 
152
  if not args.index_name:
153
  raise ValueError("Please specify --index-name for Pinecone.")
154
  dimension = args.embedding_size if "embedding_size" in args else None
155
+
156
+ namespace = args.repo_id
157
+ if args.commit_hash:
158
+ namespace += "/" + args.commit_hash
159
+
160
  return PineconeVectorStore(
161
+ index_name=args.index_name, namespace=namespace, dimension=dimension, hybrid=args.hybrid_retrieval
162
  )
163
  elif args.vector_store_type == "marqo":
164
  marqo_url = args.marqo_url or "http://localhost:8882"
165
+
166
+ index_name = args.index_name
167
+ if not index_name:
168
+ # Marqo doesn't allow slashes in the index name.
169
+ index_name = args.repo_id.split("/")[1]
170
+ if args.commit_hash:
171
+ index_name += "_" + args.commit_hash
172
+
173
  return MarqoVectorStore(url=marqo_url, index_name=index_name)
174
  else:
175
  raise ValueError(f"Unrecognized vector store type {args.vector_store_type}")