Spaces:
Running
Running
Clean up the structure of the code.
Browse files- src/chat.py +3 -32
- src/index.py +2 -7
- src/vector_store.py +56 -1
src/chat.py
CHANGED
|
@@ -4,20 +4,17 @@ You must run main.py first in order to index the codebase into a vector store.
|
|
| 4 |
"""
|
| 5 |
|
| 6 |
import argparse
|
| 7 |
-
from typing import List
|
| 8 |
|
| 9 |
import gradio as gr
|
| 10 |
-
import marqo
|
| 11 |
from dotenv import load_dotenv
|
| 12 |
from langchain.chains import (create_history_aware_retriever,
|
| 13 |
create_retrieval_chain)
|
| 14 |
from langchain.chains.combine_documents import create_stuff_documents_chain
|
| 15 |
from langchain.schema import AIMessage, HumanMessage
|
| 16 |
-
from langchain_community.vectorstores import Marqo, Pinecone
|
| 17 |
-
from langchain_core.documents import Document
|
| 18 |
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
| 19 |
-
from langchain_openai import ChatOpenAI
|
| 20 |
|
|
|
|
| 21 |
from repo_manager import RepoManager
|
| 22 |
|
| 23 |
load_dotenv()
|
|
@@ -26,33 +23,7 @@ load_dotenv()
|
|
| 26 |
def build_rag_chain(args):
|
| 27 |
"""Builds a RAG chain via LangChain."""
|
| 28 |
llm = ChatOpenAI(model=args.openai_model)
|
| 29 |
-
|
| 30 |
-
if args.vector_store_type == "pinecone":
|
| 31 |
-
vectorstore = Pinecone.from_existing_index(
|
| 32 |
-
index_name=args.pinecone_index_name,
|
| 33 |
-
embedding=OpenAIEmbeddings(),
|
| 34 |
-
namespace=args.repo_id,
|
| 35 |
-
)
|
| 36 |
-
elif args.vector_store_type == "marqo":
|
| 37 |
-
marqo_client = marqo.Client(url=args.marqo_url)
|
| 38 |
-
vectorstore = Marqo(
|
| 39 |
-
client=marqo_client,
|
| 40 |
-
index_name=args.index_name,
|
| 41 |
-
)
|
| 42 |
-
|
| 43 |
-
# Monkey-patch the _construct_documents_from_results_without_score method to not expect a "metadata" field in the
|
| 44 |
-
# result, and instead take the "filename" directly from the result.
|
| 45 |
-
def patched_method(self, results):
|
| 46 |
-
documents: List[Document] = []
|
| 47 |
-
for res in results["hits"]:
|
| 48 |
-
documents.append(Document(page_content=res["text"], metadata={"filename": res["filename"]}))
|
| 49 |
-
return documents
|
| 50 |
-
|
| 51 |
-
vectorstore._construct_documents_from_results_without_score = patched_method.__get__(
|
| 52 |
-
vectorstore, vectorstore.__class__
|
| 53 |
-
)
|
| 54 |
-
|
| 55 |
-
retriever = vectorstore.as_retriever()
|
| 56 |
|
| 57 |
# Prompt to contextualize the latest query based on the chat history.
|
| 58 |
contextualize_q_system_prompt = (
|
|
|
|
| 4 |
"""
|
| 5 |
|
| 6 |
import argparse
|
|
|
|
| 7 |
|
| 8 |
import gradio as gr
|
|
|
|
| 9 |
from dotenv import load_dotenv
|
| 10 |
from langchain.chains import (create_history_aware_retriever,
|
| 11 |
create_retrieval_chain)
|
| 12 |
from langchain.chains.combine_documents import create_stuff_documents_chain
|
| 13 |
from langchain.schema import AIMessage, HumanMessage
|
|
|
|
|
|
|
| 14 |
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
|
| 15 |
+
from langchain_openai import ChatOpenAI
|
| 16 |
|
| 17 |
+
import vector_store
|
| 18 |
from repo_manager import RepoManager
|
| 19 |
|
| 20 |
load_dotenv()
|
|
|
|
| 23 |
def build_rag_chain(args):
|
| 24 |
"""Builds a RAG chain via LangChain."""
|
| 25 |
llm = ChatOpenAI(model=args.openai_model)
|
| 26 |
+
retriever = vector_store.build_from_args(args).to_langchain().as_retriever()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
# Prompt to contextualize the latest query based on the chat history.
|
| 29 |
contextualize_q_system_prompt = (
|
src/index.py
CHANGED
|
@@ -7,11 +7,10 @@ import time
|
|
| 7 |
from chunker import UniversalChunker
|
| 8 |
from embedder import MarqoEmbedder, OpenAIBatchEmbedder
|
| 9 |
from repo_manager import RepoManager
|
| 10 |
-
from vector_store import
|
| 11 |
|
| 12 |
logging.basicConfig(level=logging.INFO)
|
| 13 |
|
| 14 |
-
OPENAI_EMBEDDING_SIZE = 1536
|
| 15 |
MAX_TOKENS_PER_CHUNK = 8192 # The ADA embedder from OpenAI has a maximum of 8192 tokens.
|
| 16 |
MAX_CHUNKS_PER_BATCH = 2048 # The OpenAI batch embedding API enforces a maximum of 2048 chunks per batch.
|
| 17 |
MAX_TOKENS_PER_JOB = 3_000_000 # The OpenAI batch embedding API enforces a maximum of 3M tokens processed at once.
|
|
@@ -129,11 +128,7 @@ def main():
|
|
| 129 |
|
| 130 |
logging.info("Moving embeddings to the vector store...")
|
| 131 |
# Note to developer: Replace this with your preferred vector store.
|
| 132 |
-
vector_store =
|
| 133 |
-
index_name=args.index_name,
|
| 134 |
-
dimension=OPENAI_EMBEDDING_SIZE,
|
| 135 |
-
namespace=repo_manager.repo_id,
|
| 136 |
-
)
|
| 137 |
vector_store.ensure_exists()
|
| 138 |
vector_store.upsert(embedder.download_embeddings())
|
| 139 |
logging.info("Done!")
|
|
|
|
| 7 |
from chunker import UniversalChunker
|
| 8 |
from embedder import MarqoEmbedder, OpenAIBatchEmbedder
|
| 9 |
from repo_manager import RepoManager
|
| 10 |
+
from vector_store import build_from_args
|
| 11 |
|
| 12 |
logging.basicConfig(level=logging.INFO)
|
| 13 |
|
|
|
|
| 14 |
MAX_TOKENS_PER_CHUNK = 8192 # The ADA embedder from OpenAI has a maximum of 8192 tokens.
|
| 15 |
MAX_CHUNKS_PER_BATCH = 2048 # The OpenAI batch embedding API enforces a maximum of 2048 chunks per batch.
|
| 16 |
MAX_TOKENS_PER_JOB = 3_000_000 # The OpenAI batch embedding API enforces a maximum of 3M tokens processed at once.
|
|
|
|
| 128 |
|
| 129 |
logging.info("Moving embeddings to the vector store...")
|
| 130 |
# Note to developer: Replace this with your preferred vector store.
|
| 131 |
+
vector_store = build_from_args(args)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
vector_store.ensure_exists()
|
| 133 |
vector_store.upsert(embedder.download_embeddings())
|
| 134 |
logging.info("Done!")
|
src/vector_store.py
CHANGED
|
@@ -3,8 +3,13 @@
|
|
| 3 |
from abc import ABC, abstractmethod
|
| 4 |
from typing import Dict, Generator, List, Tuple
|
| 5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
from pinecone import Pinecone
|
| 7 |
|
|
|
|
| 8 |
Vector = Tuple[Dict, List[float]] # (metadata, embedding)
|
| 9 |
|
| 10 |
|
|
@@ -30,11 +35,15 @@ class VectorStore(ABC):
|
|
| 30 |
if batch:
|
| 31 |
self.upsert_batch(batch)
|
| 32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
class PineconeVectorStore(VectorStore):
|
| 35 |
"""Vector store implementation using Pinecone."""
|
| 36 |
|
| 37 |
-
def __init__(self, index_name: str,
|
| 38 |
self.index_name = index_name
|
| 39 |
self.dimension = dimension
|
| 40 |
self.client = Pinecone()
|
|
@@ -50,3 +59,49 @@ class PineconeVectorStore(VectorStore):
|
|
| 50 |
(metadata.get("id", str(i)), embedding, metadata) for i, (metadata, embedding) in enumerate(vectors)
|
| 51 |
]
|
| 52 |
self.index.upsert(vectors=pinecone_vectors, namespace=self.namespace)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
from abc import ABC, abstractmethod
|
| 4 |
from typing import Dict, Generator, List, Tuple
|
| 5 |
|
| 6 |
+
import marqo
|
| 7 |
+
from langchain_community.vectorstores import Marqo
|
| 8 |
+
from langchain_core.documents import Document
|
| 9 |
+
from langchain_openai import OpenAIEmbeddings
|
| 10 |
from pinecone import Pinecone
|
| 11 |
|
| 12 |
+
OPENAI_EMBEDDING_SIZE = 1536
|
| 13 |
Vector = Tuple[Dict, List[float]] # (metadata, embedding)
|
| 14 |
|
| 15 |
|
|
|
|
| 35 |
if batch:
|
| 36 |
self.upsert_batch(batch)
|
| 37 |
|
| 38 |
+
@abstractmethod
|
| 39 |
+
def to_langchain(self):
|
| 40 |
+
"""Converts the vector store to a LangChain vector store object."""
|
| 41 |
+
|
| 42 |
|
| 43 |
class PineconeVectorStore(VectorStore):
|
| 44 |
"""Vector store implementation using Pinecone."""
|
| 45 |
|
| 46 |
+
def __init__(self, index_name: str, namespace: str, dimension: int = OPENAI_EMBEDDING_SIZE):
|
| 47 |
self.index_name = index_name
|
| 48 |
self.dimension = dimension
|
| 49 |
self.client = Pinecone()
|
|
|
|
| 59 |
(metadata.get("id", str(i)), embedding, metadata) for i, (metadata, embedding) in enumerate(vectors)
|
| 60 |
]
|
| 61 |
self.index.upsert(vectors=pinecone_vectors, namespace=self.namespace)
|
| 62 |
+
|
| 63 |
+
def to_langchain(self):
|
| 64 |
+
return Pinecone.from_existing_index(
|
| 65 |
+
index_name=self.index_name, embedding=OpenAIEmbeddings(), namespace=self.namespace
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class MarqoVectorStore(VectorStore):
|
| 70 |
+
"""Vector store implementation using Marqo."""
|
| 71 |
+
|
| 72 |
+
def __init__(self, url: str, index_name: str):
|
| 73 |
+
self.client = marqo.Client(url=url)
|
| 74 |
+
self.index_name = index_name
|
| 75 |
+
|
| 76 |
+
def ensure_exists(self):
|
| 77 |
+
pass
|
| 78 |
+
|
| 79 |
+
def upsert_batch(self, vectors: List[Vector]):
|
| 80 |
+
# Since Marqo is both an embedder and a vector store, the embedder is already doing the upsert.
|
| 81 |
+
pass
|
| 82 |
+
|
| 83 |
+
def to_langchain(self):
|
| 84 |
+
vectorstore = Marqo(client=self.client, index_name=self.index_name)
|
| 85 |
+
|
| 86 |
+
# Monkey-patch the _construct_documents_from_results_without_score method to not expect a "metadata" field in
|
| 87 |
+
# the result, and instead take the "filename" directly from the result.
|
| 88 |
+
def patched_method(self, results):
|
| 89 |
+
documents: List[Document] = []
|
| 90 |
+
for res in results["hits"]:
|
| 91 |
+
documents.append(Document(page_content=res["text"], metadata={"filename": res["filename"]}))
|
| 92 |
+
return documents
|
| 93 |
+
|
| 94 |
+
vectorstore._construct_documents_from_results_without_score = patched_method.__get__(
|
| 95 |
+
vectorstore, vectorstore.__class__
|
| 96 |
+
)
|
| 97 |
+
return vectorstore
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def build_from_args(args: dict) -> VectorStore:
|
| 101 |
+
"""Builds a vector store from the given command-line arguments."""
|
| 102 |
+
if args.vector_store_type == "pinecone":
|
| 103 |
+
return PineconeVectorStore(index_name=args.index_name, namespace=args.repo_id)
|
| 104 |
+
elif args.vector_store_type == "marqo":
|
| 105 |
+
return MarqoVectorStore(url=args.marqo_url, index_name=args.index_name)
|
| 106 |
+
else:
|
| 107 |
+
raise ValueError(f"Unrecognized vector store type {args.vector_store_type}")
|