juliaturc committed on
Commit 5f9eeb4 · 1 Parent(s): 57007fe

Clean up the structure of the code.

Files changed (3)
  1. src/chat.py +3 -32
  2. src/index.py +2 -7
  3. src/vector_store.py +56 -1
src/chat.py CHANGED
@@ -4,20 +4,17 @@ You must run main.py first in order to index the codebase into a vector store.
 """

 import argparse
-from typing import List

 import gradio as gr
-import marqo
 from dotenv import load_dotenv
 from langchain.chains import (create_history_aware_retriever,
                               create_retrieval_chain)
 from langchain.chains.combine_documents import create_stuff_documents_chain
 from langchain.schema import AIMessage, HumanMessage
-from langchain_community.vectorstores import Marqo, Pinecone
-from langchain_core.documents import Document
 from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
-from langchain_openai import ChatOpenAI, OpenAIEmbeddings
+from langchain_openai import ChatOpenAI

+import vector_store
 from repo_manager import RepoManager

 load_dotenv()
@@ -26,33 +23,7 @@ load_dotenv()
 def build_rag_chain(args):
     """Builds a RAG chain via LangChain."""
     llm = ChatOpenAI(model=args.openai_model)
-
-    if args.vector_store_type == "pinecone":
-        vectorstore = Pinecone.from_existing_index(
-            index_name=args.pinecone_index_name,
-            embedding=OpenAIEmbeddings(),
-            namespace=args.repo_id,
-        )
-    elif args.vector_store_type == "marqo":
-        marqo_client = marqo.Client(url=args.marqo_url)
-        vectorstore = Marqo(
-            client=marqo_client,
-            index_name=args.index_name,
-        )
-
-        # Monkey-patch the _construct_documents_from_results_without_score method to not expect a "metadata" field in the
-        # result, and instead take the "filename" directly from the result.
-        def patched_method(self, results):
-            documents: List[Document] = []
-            for res in results["hits"]:
-                documents.append(Document(page_content=res["text"], metadata={"filename": res["filename"]}))
-            return documents
-
-        vectorstore._construct_documents_from_results_without_score = patched_method.__get__(
-            vectorstore, vectorstore.__class__
-        )
-
-    retriever = vectorstore.as_retriever()
+    retriever = vector_store.build_from_args(args).to_langchain().as_retriever()

     # Prompt to contextualize the latest query based on the chat history.
     contextualize_q_system_prompt = (
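A note on the arguments: build_rag_chain now only reads args.openai_model directly and delegates everything vector-store-related to vector_store.build_from_args(args). As a rough sketch of how such an args namespace could be assembled (the flag names are inferred from the attributes used in the diffs; the defaults and help text are assumptions, not something this commit defines):

# Hypothetical CLI wiring; only the attribute names are grounded in the diff,
# the defaults are illustrative.
import argparse

from chat import build_rag_chain  # assumes src/ is on the import path

parser = argparse.ArgumentParser(description="Chat with an indexed codebase.")
parser.add_argument("--openai_model", default="gpt-4o")
parser.add_argument("--vector_store_type", choices=["pinecone", "marqo"], default="pinecone")
parser.add_argument("--index_name", required=True)
parser.add_argument("--repo_id", required=True)  # used as the Pinecone namespace
parser.add_argument("--marqo_url", default="http://localhost:8882")  # only relevant for Marqo
args = parser.parse_args()

rag_chain = build_rag_chain(args)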
src/index.py CHANGED
@@ -7,11 +7,10 @@ import time
 from chunker import UniversalChunker
 from embedder import MarqoEmbedder, OpenAIBatchEmbedder
 from repo_manager import RepoManager
-from vector_store import PineconeVectorStore
+from vector_store import build_from_args

 logging.basicConfig(level=logging.INFO)

-OPENAI_EMBEDDING_SIZE = 1536
 MAX_TOKENS_PER_CHUNK = 8192  # The ADA embedder from OpenAI has a maximum of 8192 tokens.
 MAX_CHUNKS_PER_BATCH = 2048  # The OpenAI batch embedding API enforces a maximum of 2048 chunks per batch.
 MAX_TOKENS_PER_JOB = 3_000_000  # The OpenAI batch embedding API enforces a maximum of 3M tokens processed at once.
@@ -129,11 +128,7 @@ def main():

     logging.info("Moving embeddings to the vector store...")
     # Note to developer: Replace this with your preferred vector store.
-    vector_store = PineconeVectorStore(
-        index_name=args.index_name,
-        dimension=OPENAI_EMBEDDING_SIZE,
-        namespace=repo_manager.repo_id,
-    )
+    vector_store = build_from_args(args)
     vector_store.ensure_exists()
     vector_store.upsert(embedder.download_embeddings())
     logging.info("Done!")
src/vector_store.py CHANGED
@@ -3,8 +3,13 @@
 from abc import ABC, abstractmethod
 from typing import Dict, Generator, List, Tuple

+import marqo
+from langchain_community.vectorstores import Marqo
+from langchain_core.documents import Document
+from langchain_openai import OpenAIEmbeddings
 from pinecone import Pinecone

+OPENAI_EMBEDDING_SIZE = 1536
 Vector = Tuple[Dict, List[float]]  # (metadata, embedding)


@@ -30,11 +35,15 @@ class VectorStore(ABC):
         if batch:
             self.upsert_batch(batch)

+    @abstractmethod
+    def to_langchain(self):
+        """Converts the vector store to a LangChain vector store object."""
+

 class PineconeVectorStore(VectorStore):
     """Vector store implementation using Pinecone."""

-    def __init__(self, index_name: str, dimension: int, namespace: str):
+    def __init__(self, index_name: str, namespace: str, dimension: int = OPENAI_EMBEDDING_SIZE):
         self.index_name = index_name
         self.dimension = dimension
         self.client = Pinecone()
@@ -50,3 +59,49 @@ class PineconeVectorStore(VectorStore):
             (metadata.get("id", str(i)), embedding, metadata) for i, (metadata, embedding) in enumerate(vectors)
         ]
         self.index.upsert(vectors=pinecone_vectors, namespace=self.namespace)
+
+    def to_langchain(self):
+        return Pinecone.from_existing_index(
+            index_name=self.index_name, embedding=OpenAIEmbeddings(), namespace=self.namespace
+        )
+
+
+class MarqoVectorStore(VectorStore):
+    """Vector store implementation using Marqo."""
+
+    def __init__(self, url: str, index_name: str):
+        self.client = marqo.Client(url=url)
+        self.index_name = index_name
+
+    def ensure_exists(self):
+        pass
+
+    def upsert_batch(self, vectors: List[Vector]):
+        # Since Marqo is both an embedder and a vector store, the embedder is already doing the upsert.
+        pass
+
+    def to_langchain(self):
+        vectorstore = Marqo(client=self.client, index_name=self.index_name)
+
+        # Monkey-patch the _construct_documents_from_results_without_score method to not expect a "metadata" field in
+        # the result, and instead take the "filename" directly from the result.
+        def patched_method(self, results):
+            documents: List[Document] = []
+            for res in results["hits"]:
+                documents.append(Document(page_content=res["text"], metadata={"filename": res["filename"]}))
+            return documents
+
+        vectorstore._construct_documents_from_results_without_score = patched_method.__get__(
+            vectorstore, vectorstore.__class__
+        )
+        return vectorstore
+
+
+def build_from_args(args: dict) -> VectorStore:
+    """Builds a vector store from the given command-line arguments."""
+    if args.vector_store_type == "pinecone":
+        return PineconeVectorStore(index_name=args.index_name, namespace=args.repo_id)
+    elif args.vector_store_type == "marqo":
+        return MarqoVectorStore(url=args.marqo_url, index_name=args.index_name)
+    else:
+        raise ValueError(f"Unrecognized vector store type {args.vector_store_type}")
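The factory plus the to_langchain() hook make the backend pluggable: a new store only needs the abstract methods shown in this diff and one extra branch in build_from_args. A toy in-memory backend, purely illustrative and not part of this commit, would slot in like this:

from typing import List

from vector_store import Vector, VectorStore


class InMemoryVectorStore(VectorStore):
    """Toy backend that keeps vectors in a Python list (handy for tests)."""

    def __init__(self):
        self.vectors: List[Vector] = []

    def ensure_exists(self):
        pass  # nothing to provision

    def upsert_batch(self, vectors: List[Vector]):
        self.vectors.extend(vectors)

    def to_langchain(self):
        # A real backend would return a LangChain vector store wrapper here.
        raise NotImplementedError

build_from_args would then gain a matching branch, e.g. elif args.vector_store_type == "in_memory": return InMemoryVectorStore().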