Mihail Eric committed on
Commit
9581b48
·
2 Parent(s): bee461c 8b42d65

fix merge conflict

Browse files
Files changed (7) hide show
  1. README.md +10 -0
  2. sage/data_manager.py +3 -2
  3. sage/github.py +2 -3
  4. sage/index.py +9 -5
  5. sage/retriever.py +3 -2
  6. sage/vector_store.py +18 -18
  7. setup.py +1 -1
README.md CHANGED
@@ -141,6 +141,15 @@ If you are planning on indexing GitHub issues in addition to the codebase, you w
141
 
142
  ## Additional features
143
 
 
 
 
 
 
 
 
 
 
144
  <details>
145
  <summary><strong>:hammer_and_wrench: Control which files get indexed</strong></summary>
146
 
@@ -177,6 +186,7 @@ By default, we use the exclusion file [sample-exclude.txt](sage/sample-exclude.t
177
  <details>
178
  <summary><strong>:bug: Index open GitHub issues</strong></summary>
179
  You will need a GitHub token first:
 
180
  ```
181
  export GITHUB_TOKEN=...
182
  ```
 
141
 
142
  ## Additional features
143
 
144
+ <details>
145
+ <summary><strong>:lock: Working with private repositories</strong></summary>
146
+ To index and chat with a private repository, simply set the GITHUB_TOKEN environment variable. To obtain this token: go to github.com > click on your profile icon > Settings > Developer settings > Personal access tokens. You can either make a fine-grained token for the desired repository, or a classic token.
147
+
148
+ ```
149
+ export GITHUB_TOKEN=...
150
+ ```
151
+ </details>
152
+
153
  <details>
154
  <summary><strong>:hammer_and_wrench: Control which files get indexed</strong></summary>
155
 
 
186
  <details>
187
  <summary><strong>:bug: Index open GitHub issues</strong></summary>
188
  You will need a GitHub token first:
189
+
190
  ```
191
  export GITHUB_TOKEN=...
192
  ```
sage/data_manager.py CHANGED
@@ -30,6 +30,7 @@ class GitHubRepoManager(DataManager):
30
  self,
31
  repo_id: str,
32
  commit_hash: str = None,
 
33
  local_dir: str = None,
34
  inclusion_file: str = None,
35
  exclusion_file: str = None,
@@ -38,6 +39,7 @@ class GitHubRepoManager(DataManager):
38
  Args:
39
  repo_id: The identifier of the repository in owner/repo format, e.g. "Storia-AI/sage".
40
  commit_hash: Optional commit hash to checkout. If not specified, we pull the latest version of the repo.
 
41
  local_dir: The local directory where the repository will be cloned.
42
  inclusion_file: A file with a lists of files/directories/extensions to include. Each line must be in one of
43
  the following formats: "ext:.my-extension", "file:my-file.py", or "dir:my-directory".
@@ -47,6 +49,7 @@ class GitHubRepoManager(DataManager):
47
  super().__init__(dataset_id=repo_id)
48
  self.repo_id = repo_id
49
  self.commit_hash = commit_hash
 
50
 
51
  self.local_dir = local_dir or "/tmp/"
52
  if not os.path.exists(self.local_dir):
@@ -57,8 +60,6 @@ class GitHubRepoManager(DataManager):
57
  if not os.path.exists(self.log_dir):
58
  os.makedirs(self.log_dir)
59
 
60
- self.access_token = os.getenv("GITHUB_TOKEN")
61
-
62
  if inclusion_file and exclusion_file:
63
  raise ValueError("Only one of inclusion_file or exclusion_file should be provided.")
64
 
 
30
  self,
31
  repo_id: str,
32
  commit_hash: str = None,
33
+ access_token: str = None,
34
  local_dir: str = None,
35
  inclusion_file: str = None,
36
  exclusion_file: str = None,
 
39
  Args:
40
  repo_id: The identifier of the repository in owner/repo format, e.g. "Storia-AI/sage".
41
  commit_hash: Optional commit hash to checkout. If not specified, we pull the latest version of the repo.
42
+ access_token: A GitHub access token to use for cloning private repositories. Not needed for public repos.
43
  local_dir: The local directory where the repository will be cloned.
44
  inclusion_file: A file with a lists of files/directories/extensions to include. Each line must be in one of
45
  the following formats: "ext:.my-extension", "file:my-file.py", or "dir:my-directory".
 
49
  super().__init__(dataset_id=repo_id)
50
  self.repo_id = repo_id
51
  self.commit_hash = commit_hash
52
+ self.access_token = access_token
53
 
54
  self.local_dir = local_dir or "/tmp/"
55
  if not os.path.exists(self.local_dir):
 
60
  if not os.path.exists(self.log_dir):
61
  os.makedirs(self.log_dir)
62
 
 
 
63
  if inclusion_file and exclusion_file:
64
  raise ValueError("Only one of inclusion_file or exclusion_file should be provided.")
65
 
sage/github.py CHANGED
@@ -1,7 +1,6 @@
1
  """GitHub-specific implementations for DataManager and Chunker."""
2
 
3
  import logging
4
- import os
5
  from dataclasses import dataclass
6
  from typing import Any, Dict, Generator, List, Tuple
7
 
@@ -47,12 +46,12 @@ class GitHubIssue:
47
  class GitHubIssuesManager(DataManager):
48
  """Class to manage the GitHub issues of a particular repository."""
49
 
50
- def __init__(self, repo_id: str, index_comments: bool = False, max_issues: int = None):
51
  super().__init__(dataset_id=repo_id + "/issues")
52
  self.repo_id = repo_id
53
  self.index_comments = index_comments
54
  self.max_issues = max_issues
55
- self.access_token = os.getenv("GITHUB_TOKEN")
56
  if not self.access_token:
57
  raise ValueError("Please set the GITHUB_TOKEN environment variable when indexing GitHub issues.")
58
  self.issues = []
 
1
  """GitHub-specific implementations for DataManager and Chunker."""
2
 
3
  import logging
 
4
  from dataclasses import dataclass
5
  from typing import Any, Dict, Generator, List, Tuple
6
 
 
46
  class GitHubIssuesManager(DataManager):
47
  """Class to manage the GitHub issues of a particular repository."""
48
 
49
+ def __init__(self, repo_id: str, access_token: str, index_comments: bool = False, max_issues: int = None):
50
  super().__init__(dataset_id=repo_id + "/issues")
51
  self.repo_id = repo_id
52
  self.index_comments = index_comments
53
  self.max_issues = max_issues
54
+ self.access_token = access_token
55
  if not self.access_token:
56
  raise ValueError("Please set the GITHUB_TOKEN environment variable when indexing GitHub issues.")
57
  self.issues = []
sage/index.py CHANGED
@@ -1,11 +1,10 @@
1
  """Runs a batch job to compute embeddings for an entire repo and stores them into a vector store."""
2
 
3
  import logging
4
- import nltk
5
  import time
6
 
7
  import configargparse
8
- import pkg_resources
9
 
10
  import sage.config as sage_config
11
  from sage.chunker import UniversalFileChunker
@@ -55,6 +54,7 @@ def main():
55
  repo_manager = GitHubRepoManager(
56
  args.repo_id,
57
  commit_hash=args.commit_hash,
 
58
  local_dir=args.local_dir,
59
  inclusion_file=args.include,
60
  exclusion_file=args.exclude,
@@ -69,7 +69,9 @@ def main():
69
  issues_embedder = None
70
  if args.index_issues:
71
  logging.info("Issuing embedding jobs for GitHub issues...")
72
- issues_manager = GitHubIssuesManager(args.repo_id, index_comments=args.index_issue_comments)
 
 
73
  issues_manager.download()
74
  logging.info("Embedding GitHub issues...")
75
  chunker = GitHubIssuesChunker(max_tokens=args.tokens_per_chunk)
@@ -94,7 +96,7 @@ def main():
94
  logging.info("Moving embeddings to the repo vector store...")
95
  repo_vector_store = build_vector_store_from_args(args, repo_manager)
96
  repo_vector_store.ensure_exists()
97
- repo_vector_store.upsert(repo_embedder.download_embeddings(repo_jobs_file))
98
 
99
  if issues_embedder is not None:
100
  logging.info("Waiting for issue embeddings to be ready...")
@@ -105,7 +107,9 @@ def main():
105
  logging.info("Moving embeddings to the issues vector store...")
106
  issues_vector_store = build_vector_store_from_args(args, issues_manager)
107
  issues_vector_store.ensure_exists()
108
- issues_vector_store.upsert(issues_embedder.download_embeddings(issues_jobs_file))
 
 
109
 
110
  logging.info("Done!")
111
 
 
1
  """Runs a batch job to compute embeddings for an entire repo and stores them into a vector store."""
2
 
3
  import logging
4
+ import os
5
  import time
6
 
7
  import configargparse
 
8
 
9
  import sage.config as sage_config
10
  from sage.chunker import UniversalFileChunker
 
54
  repo_manager = GitHubRepoManager(
55
  args.repo_id,
56
  commit_hash=args.commit_hash,
57
+ access_token=os.getenv("GITHUB_TOKEN"),
58
  local_dir=args.local_dir,
59
  inclusion_file=args.include,
60
  exclusion_file=args.exclude,
 
69
  issues_embedder = None
70
  if args.index_issues:
71
  logging.info("Issuing embedding jobs for GitHub issues...")
72
+ issues_manager = GitHubIssuesManager(
73
+ args.repo_id, access_token=os.getenv("GITHUB_TOKEN"), index_comments=args.index_issue_comments
74
+ )
75
  issues_manager.download()
76
  logging.info("Embedding GitHub issues...")
77
  chunker = GitHubIssuesChunker(max_tokens=args.tokens_per_chunk)
 
96
  logging.info("Moving embeddings to the repo vector store...")
97
  repo_vector_store = build_vector_store_from_args(args, repo_manager)
98
  repo_vector_store.ensure_exists()
99
+ repo_vector_store.upsert(repo_embedder.download_embeddings(repo_jobs_file), namespace=args.index_namespace)
100
 
101
  if issues_embedder is not None:
102
  logging.info("Waiting for issue embeddings to be ready...")
 
107
  logging.info("Moving embeddings to the issues vector store...")
108
  issues_vector_store = build_vector_store_from_args(args, issues_manager)
109
  issues_vector_store.ensure_exists()
110
+ issues_vector_store.upsert(
111
+ issues_embedder.download_embeddings(issues_jobs_file), namespace=args.index_namespace
112
+ )
113
 
114
  logging.info("Done!")
115
 
sage/retriever.py CHANGED
@@ -2,7 +2,6 @@ from langchain.retrievers import ContextualCompressionRetriever
2
  from langchain_openai import OpenAIEmbeddings
3
  from langchain_voyageai import VoyageAIEmbeddings
4
 
5
-
6
  from sage.reranker import build_reranker
7
  from sage.vector_store import build_vector_store_from_args
8
 
@@ -17,7 +16,9 @@ def build_retriever_from_args(args):
17
  else:
18
  embeddings = None
19
 
20
- retriever = build_vector_store_from_args(args).as_retriever(top_k=args.retriever_top_k, embeddings=embeddings)
 
 
21
 
22
  reranker = build_reranker(args.reranker_provider, args.reranker_model, args.reranker_top_k)
23
  if reranker:
 
2
  from langchain_openai import OpenAIEmbeddings
3
  from langchain_voyageai import VoyageAIEmbeddings
4
 
 
5
  from sage.reranker import build_reranker
6
  from sage.vector_store import build_vector_store_from_args
7
 
 
16
  else:
17
  embeddings = None
18
 
19
+ retriever = build_vector_store_from_args(args).as_retriever(
20
+ top_k=args.retriever_top_k, embeddings=embeddings, namespace=args.index_namespace
21
+ )
22
 
23
  reranker = build_reranker(args.reranker_provider, args.reranker_model, args.reranker_top_k)
24
  if reranker:
sage/vector_store.py CHANGED
@@ -1,7 +1,7 @@
1
  """Vector store abstraction and implementations."""
2
 
3
- import os
4
  import logging
 
5
  from abc import ABC, abstractmethod
6
  from functools import cached_property
7
  from typing import Dict, Generator, List, Optional, Tuple
@@ -36,33 +36,32 @@ class VectorStore(ABC):
36
  """Ensures that the vector store exists. Creates it if it doesn't."""
37
 
38
  @abstractmethod
39
- def upsert_batch(self, vectors: List[Vector]):
40
  """Upserts a batch of vectors."""
41
 
42
- def upsert(self, vectors: Generator[Vector, None, None]):
43
  """Upserts in batches of 100, since vector stores have a limit on upsert size."""
44
  batch = []
45
  for metadata, embedding in vectors:
46
  batch.append((metadata, embedding))
47
  if len(batch) == 100:
48
- self.upsert_batch(batch)
49
  batch = []
50
  if batch:
51
- self.upsert_batch(batch)
52
 
53
  @abstractmethod
54
- def as_retriever(self, top_k: int, embeddings: Embeddings):
55
  """Converts the vector store to a LangChain retriever object."""
56
 
57
 
58
  class PineconeVectorStore(VectorStore):
59
  """Vector store implementation using Pinecone."""
60
 
61
- def __init__(self, index_name: str, namespace: str, dimension: int, alpha: float, bm25_cache: Optional[str] = None):
62
  """
63
  Args:
64
  index_name: The name of the Pinecone index to use. If it doesn't exist already, we'll create it.
65
- namespace: The namespace within the index to use.
66
  dimension: The dimension of the vectors.
67
  alpha: The alpha parameter for hybrid search: alpha == 1.0 means pure dense search, alpha == 0.0 means pure
68
  BM25, and 0.0 < alpha < 1.0 means a hybrid of the two.
@@ -72,7 +71,6 @@ class PineconeVectorStore(VectorStore):
72
  self.index_name = index_name
73
  self.dimension = dimension
74
  self.client = Pinecone()
75
- self.namespace = namespace
76
  self.alpha = alpha
77
 
78
  if alpha < 1.0:
@@ -105,7 +103,8 @@ class PineconeVectorStore(VectorStore):
105
  def patched_query(*args, **kwargs):
106
  result = original_query(*args, **kwargs)
107
  for res in result["matches"]:
108
- res["metadata"]["context"] = res["metadata"][TEXT_FIELD]
 
109
  return result
110
 
111
  index.query = patched_query
@@ -121,7 +120,7 @@ class PineconeVectorStore(VectorStore):
121
  spec=ServerlessSpec(cloud="aws", region="us-east-1"),
122
  )
123
 
124
- def upsert_batch(self, vectors: List[Vector]):
125
  pinecone_vectors = []
126
  for i, (metadata, embedding) in enumerate(vectors):
127
  vector = {"id": metadata.get("id", str(i)), "values": embedding, "metadata": metadata}
@@ -129,21 +128,21 @@ class PineconeVectorStore(VectorStore):
129
  vector["sparse_values"] = self.bm25_encoder.encode_documents(metadata[TEXT_FIELD])
130
  pinecone_vectors.append(vector)
131
 
132
- self.index.upsert(vectors=pinecone_vectors, namespace=self.namespace)
133
 
134
- def as_retriever(self, top_k: int, embeddings: Embeddings):
135
  if self.bm25_encoder:
136
  return PineconeHybridSearchRetriever(
137
  embeddings=embeddings,
138
  sparse_encoder=self.bm25_encoder,
139
  index=self.index,
140
- namespace=self.namespace,
141
  top_k=top_k,
142
  alpha=self.alpha,
143
  )
144
 
145
  return LangChainPinecone.from_existing_index(
146
- index_name=self.index_name, embedding=embeddings, namespace=self.namespace
147
  ).as_retriever(search_kwargs={"k": top_k})
148
 
149
 
@@ -157,12 +156,14 @@ class MarqoVectorStore(VectorStore):
157
  def ensure_exists(self):
158
  pass
159
 
160
- def upsert_batch(self, vectors: List[Vector]):
161
  # Since Marqo is both an embedder and a vector store, the embedder is already doing the upsert.
162
  pass
163
 
164
- def as_retriever(self, top_k: int, embeddings: Embeddings = None):
165
  del embeddings # Unused; The Marqo vector store is also an embedder.
 
 
166
  vectorstore = Marqo(client=self.client, index_name=self.index_name)
167
 
168
  # Monkey-patch the _construct_documents_from_results_without_score method to not expect a "metadata" field in
@@ -202,7 +203,6 @@ def build_vector_store_from_args(args: dict, data_manager: Optional[DataManager]
202
 
203
  return PineconeVectorStore(
204
  index_name=args.pinecone_index_name,
205
- namespace=args.index_namespace,
206
  dimension=args.embedding_size if "embedding_size" in args else None,
207
  alpha=args.retrieval_alpha,
208
  bm25_cache=bm25_cache,
 
1
  """Vector store abstraction and implementations."""
2
 
 
3
  import logging
4
+ import os
5
  from abc import ABC, abstractmethod
6
  from functools import cached_property
7
  from typing import Dict, Generator, List, Optional, Tuple
 
36
  """Ensures that the vector store exists. Creates it if it doesn't."""
37
 
38
  @abstractmethod
39
+ def upsert_batch(self, vectors: List[Vector], namespace: str):
40
  """Upserts a batch of vectors."""
41
 
42
+ def upsert(self, vectors: Generator[Vector, None, None], namespace: str):
43
  """Upserts in batches of 100, since vector stores have a limit on upsert size."""
44
  batch = []
45
  for metadata, embedding in vectors:
46
  batch.append((metadata, embedding))
47
  if len(batch) == 100:
48
+ self.upsert_batch(batch, namespace)
49
  batch = []
50
  if batch:
51
+ self.upsert_batch(batch, namespace)
52
 
53
  @abstractmethod
54
+ def as_retriever(self, top_k: int, embeddings: Embeddings, namespace: str):
55
  """Converts the vector store to a LangChain retriever object."""
56
 
57
 
58
  class PineconeVectorStore(VectorStore):
59
  """Vector store implementation using Pinecone."""
60
 
61
+ def __init__(self, index_name: str, dimension: int, alpha: float, bm25_cache: Optional[str] = None):
62
  """
63
  Args:
64
  index_name: The name of the Pinecone index to use. If it doesn't exist already, we'll create it.
 
65
  dimension: The dimension of the vectors.
66
  alpha: The alpha parameter for hybrid search: alpha == 1.0 means pure dense search, alpha == 0.0 means pure
67
  BM25, and 0.0 < alpha < 1.0 means a hybrid of the two.
 
71
  self.index_name = index_name
72
  self.dimension = dimension
73
  self.client = Pinecone()
 
74
  self.alpha = alpha
75
 
76
  if alpha < 1.0:
 
103
  def patched_query(*args, **kwargs):
104
  result = original_query(*args, **kwargs)
105
  for res in result["matches"]:
106
+ if TEXT_FIELD in res["metadata"]:
107
+ res["metadata"]["context"] = res["metadata"][TEXT_FIELD]
108
  return result
109
 
110
  index.query = patched_query
 
120
  spec=ServerlessSpec(cloud="aws", region="us-east-1"),
121
  )
122
 
123
+ def upsert_batch(self, vectors: List[Vector], namespace: str):
124
  pinecone_vectors = []
125
  for i, (metadata, embedding) in enumerate(vectors):
126
  vector = {"id": metadata.get("id", str(i)), "values": embedding, "metadata": metadata}
 
128
  vector["sparse_values"] = self.bm25_encoder.encode_documents(metadata[TEXT_FIELD])
129
  pinecone_vectors.append(vector)
130
 
131
+ self.index.upsert(vectors=pinecone_vectors, namespace=namespace)
132
 
133
+ def as_retriever(self, top_k: int, embeddings: Embeddings, namespace: str):
134
  if self.bm25_encoder:
135
  return PineconeHybridSearchRetriever(
136
  embeddings=embeddings,
137
  sparse_encoder=self.bm25_encoder,
138
  index=self.index,
139
+ namespace=namespace,
140
  top_k=top_k,
141
  alpha=self.alpha,
142
  )
143
 
144
  return LangChainPinecone.from_existing_index(
145
+ index_name=self.index_name, embedding=embeddings, namespace=namespace
146
  ).as_retriever(search_kwargs={"k": top_k})
147
 
148
 
 
156
  def ensure_exists(self):
157
  pass
158
 
159
+ def upsert_batch(self, vectors: List[Vector], namespace: str):
160
  # Since Marqo is both an embedder and a vector store, the embedder is already doing the upsert.
161
  pass
162
 
163
+ def as_retriever(self, top_k: int, embeddings: Embeddings = None, namespace: str = None):
164
  del embeddings # Unused; The Marqo vector store is also an embedder.
165
+ del namespace # Unused; Unlike Pinecone, Marqo doesn't differentiate between index name and namespace.
166
+
167
  vectorstore = Marqo(client=self.client, index_name=self.index_name)
168
 
169
  # Monkey-patch the _construct_documents_from_results_without_score method to not expect a "metadata" field in
 
203
 
204
  return PineconeVectorStore(
205
  index_name=args.pinecone_index_name,
 
206
  dimension=args.embedding_size if "embedding_size" in args else None,
207
  alpha=args.retrieval_alpha,
208
  bm25_cache=bm25_cache,
setup.py CHANGED
@@ -8,7 +8,7 @@ def readfile(filename):
8
 
9
  setup(
10
  name="sage",
11
- version="0.1.0",
12
  packages=find_packages(),
13
  include_package_data=True,
14
  package_data={
 
8
 
9
  setup(
10
  name="sage",
11
+ version="0.1.2",
12
  packages=find_packages(),
13
  include_package_data=True,
14
  package_data={