from dotenv import load_dotenv
from typing_extensions import List, TypedDict
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain_qdrant import QdrantVectorStore
from langchain_huggingface import HuggingFaceEmbeddings
from langgraph.graph import START, StateGraph
from langchain_community.document_loaders import DirectoryLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams

# NLTK data required by DirectoryLoader's default (unstructured) parser
import nltk
nltk.download('punkt_tab')
nltk.download('averaged_perceptron_tagger_eng')
# Chunk configuration
CHUNK_SIZE = 1000
CHUNK_OVERLAP = CHUNK_SIZE // 2

# RAG prompt template
RAG_PROMPT = """\
You are a helpful assistant who helps Shopify merchants automate their businesses.
Your goal is to provide a helpful response to the merchant's question in straightforward, non-technical language.
Try to be brief and to the point, but explain technical jargon.
You must only use the provided context, and cannot use your own knowledge.

### Question
{question}

### Context
{context}
"""
class RagGraph:
    def __init__(self, qdrant_client, use_finetuned_embeddings=False):
        self.llm = ChatOpenAI(model="gpt-4-turbo-preview", streaming=True)
        self.collection_name = "rag_collection_finetuned" if use_finetuned_embeddings else "rag_collection"
        self.embeddings_model = (
            HuggingFaceEmbeddings(model_name="thomfoolery/AIE5-MidTerm-finetuned-embeddings")
            if use_finetuned_embeddings
            else OpenAIEmbeddings(model="text-embedding-3-small")
        )
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)
        self.rag_prompt = ChatPromptTemplate.from_template(RAG_PROMPT)
        self.qdrant_client = qdrant_client

        does_collection_exist = self.qdrant_client.collection_exists(collection_name=self.collection_name)
        # Embedding dimensions: 1024 for the fine-tuned model, 1536 for text-embedding-3-small
        dimension_size = 1024 if use_finetuned_embeddings else 1536
        print(f"Collection {self.collection_name} exists: {does_collection_exist}")
        if not does_collection_exist:
            qdrant_client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(size=dimension_size, distance=Distance.COSINE),
            )

        self.vector_store = QdrantVectorStore(
            client=qdrant_client,
            collection_name=self.collection_name,
            embedding=self.embeddings_model,
        )

        # Populate the collection on first run
        if not does_collection_exist:
            loader = DirectoryLoader("data/scraped/clean", glob="*.txt")
            documents = self.text_splitter.split_documents(loader.load())
            self.vector_store.add_documents(documents=documents)

        self.vector_db_retriever = self.vector_store.as_retriever(search_kwargs={"k": 5})
        self.graph = None
        self.create()
    def create(self):
        """Create the RAG graph."""
        class State(TypedDict):
            """State for the conversation."""
            question: str
            context: List[Document]

        def retrieve(state):
            """LangGraph node that retrieves context for the question."""
            question = state["question"]
            context = self.vector_db_retriever.invoke(question)
            return {"question": state["question"], "context": context}

        async def stream(state):
            """LangGraph node that streams responses."""
            question = state["question"]
            context = "\n\n".join(doc.page_content for doc in state["context"])
            messages = self.rag_prompt.format_messages(question=question, context=context)
            async for chunk in self.llm.astream(messages):
                yield {"content": chunk.content}

        graph_builder = StateGraph(State).add_sequence([retrieve, stream])
        graph_builder.add_edge(START, "retrieve")
        self.graph = graph_builder.compile()
    def run(self, question):
        """Invoke RAG response without streaming."""
        chunks = self.vector_db_retriever.invoke(question)
        context = "\n\n".join(doc.page_content for doc in chunks)
        messages = self.rag_prompt.format_messages(question=question, context=context)
        response = self.llm.invoke(messages)
        return {
            "response": response.content,
            "context": chunks,
        }

    async def stream(self, question, msg):
        """Stream RAG response token-by-token into a Chainlit-style message object."""
        async for event in self.graph.astream({"question": question, "context": []}, stream_mode=["messages"]):
            _event_name, (message_chunk, _metadata) = event
            if message_chunk.content:
                await msg.stream_token(message_chunk.content)
        await msg.send()
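
# Usage sketch (assumption, not part of this module): the `msg` argument to
# RagGraph.stream() matches Chainlit's cl.Message interface (stream_token()/
# send()), so wiring into a Chainlit app could look roughly like this:
#
#   import chainlit as cl
#
#   @cl.on_chat_start
#   async def start():
#       cl.user_session.set("rag_graph", RagGraph(QdrantClient(":memory:")))
#
#   @cl.on_message
#   async def on_message(message: cl.Message):
#       rag_graph = cl.user_session.get("rag_graph")
#       msg = cl.Message(content="")
#       await rag_graph.stream(message.content, msg)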
# Run RAG with CLI (no streaming)
def main():
    """Test the RAG graph."""
    load_dotenv()
    qdrant_client = QdrantClient(":memory:")  # in-memory instance for local testing
    rag_graph = RagGraph(qdrant_client)  # __init__ builds the graph via create()
    response = rag_graph.run("What is Shopify Flow?")
    print(response["response"])


if __name__ == "__main__":
    main()
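
# Streaming from the CLI without a UI (hedged sketch): `_PrintMessage` is a
# hypothetical stand-in satisfying the stream_token()/send() interface that
# RagGraph.stream() expects.
#
#   import asyncio
#
#   class _PrintMessage:
#       async def stream_token(self, token):
#           print(token, end="", flush=True)
#       async def send(self):
#           print()
#
#   rag_graph = RagGraph(QdrantClient(":memory:"))
#   asyncio.run(rag_graph.stream("What is Shopify Flow?", _PrintMessage()))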