from typing import List, AsyncGenerator
from operator import itemgetter

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_qdrant import QdrantVectorStore
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain_cohere import CohereRerank
from langchain.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from dotenv import load_dotenv

load_dotenv()
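
# NOTE: ChatOpenAI and CohereRerank read OPENAI_API_KEY and COHERE_API_KEY
# from the environment, so both must be available (e.g. via the .env file).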

class RAGProcessor:
    def __init__(
        self,
        model_name: str = "bsmith3715/legal-ft-demo_final",
        collection_name: str = "reformer_docs",
        embedding_dim: int = 768,
        rerank_model: str = "rerank-v3.5",
    ):
        # 1. Embedding model
        self.embeddings = HuggingFaceEmbeddings(model_name=model_name)

        # 2. In-memory Qdrant store
        self.client = QdrantClient(":memory:")
        self.collection_name = collection_name
        self.client.create_collection(
            collection_name=collection_name,
            vectors_config=VectorParams(size=embedding_dim, distance=Distance.COSINE),
        )
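        # NOTE: ":memory:" keeps the index in RAM only, so it is rebuilt from
        # scratch on every restart; point QdrantClient at a persistent URL or
        # local path for durable storage.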
        self.vectorstore = QdrantVectorStore(
            client=self.client,
            collection_name=collection_name,
            embedding=self.embeddings,
        )

        # 3. Retriever: fetch a wide candidate pool, since the reranker can
        # return at most as many documents as it receives (k should exceed
        # top_n for reranking to add value).
        self.retriever = self.vectorstore.as_retriever(search_kwargs={"k": 20})

        # 4. Contextual compression with reranking
        self.compression_retriever = ContextualCompressionRetriever(
            base_compressor=CohereRerank(model=rerank_model, top_n=10),
            base_retriever=self.retriever,
        )

        # 5. Prompt
        template = """You are a helpful assistant who answers questions based on the provided context. You must use only the provided context and must not rely on your own knowledge.

### Question
{question}

### Context
{context}
"""
        self.prompt = ChatPromptTemplate.from_template(template)

        # 6. LLM
        self.llm = ChatOpenAI(model="gpt-4o-mini", streaming=True)

        # 7. Final chain: fetch and rerank context for the question, then
        # produce the streamed response alongside the documents used.
        self.chain = (
            {
                "context": itemgetter("question") | self.compression_retriever,
                "question": itemgetter("question"),
            }
            | RunnablePassthrough.assign(context=itemgetter("context"))
            | {
                "response": self.prompt | self.llm | StrOutputParser(),
                "context": itemgetter("context"),
            }
        )
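        # When streamed, the final parallel map yields partial dicts: one
        # chunk carrying "context" (the reranked documents) and many chunks
        # carrying incremental "response" text.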

    def add_documents(self, texts: List[str]):
        """Splits and indexes a list of text documents."""
        # chunk_size and chunk_overlap are measured in characters, not tokens
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=750, chunk_overlap=100)
        docs = text_splitter.create_documents(texts)
        self.vectorstore.add_documents(docs)

    async def generate_response(self, query: str) -> AsyncGenerator[str, None]:
        """Streams the response using the LangChain RAG pipeline."""
        # astream() is required here: chain.stream() returns a sync generator,
        # which cannot be consumed with `async for`.
        async for chunk in self.chain.astream({"question": query}):
            # Parallel-map chunks may carry "context" instead of "response".
            if "response" in chunk:
                yield chunk["response"]
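
# Minimal usage sketch, assuming OPENAI_API_KEY and COHERE_API_KEY are set
# (e.g. in .env); the sample text and question below are placeholders.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        rag = RAGProcessor()
        rag.add_documents(["Placeholder corpus text about the reformer documents."])
        async for token in rag.generate_response("What is this corpus about?"):
            print(token, end="", flush=True)
        print()

    asyncio.run(_demo())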