# DEMO_COMBINED/rag_processor.py
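"""RAG pipeline: chunks and embeds documents into an in-memory Qdrant collection,
retrieves with Cohere reranking, and streams answers from an OpenAI chat model."""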
from typing import List, AsyncGenerator
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_qdrant import QdrantVectorStore
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain_cohere import CohereRerank
from langchain.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableParallel
from operator import itemgetter
from dotenv import load_dotenv
# Load API keys (e.g. OPENAI_API_KEY, COHERE_API_KEY) from a local .env file
load_dotenv()

class RAGProcessor:
    """RAG pipeline over an in-memory Qdrant store with Cohere reranking and an OpenAI chat model."""

    def __init__(
        self,
        model_name: str = "bsmith3715/legal-ft-demo_final",
        collection_name: str = "reformer_docs",
        embedding_dim: int = 768,
        rerank_model: str = "rerank-v3.5",
    ):
        # 1. Embedding model
        self.embeddings = HuggingFaceEmbeddings(model_name=model_name)

        # 2. In-memory Qdrant store
        self.client = QdrantClient(":memory:")
        self.collection_name = collection_name
        self.client.create_collection(
            collection_name=collection_name,
            vectors_config=VectorParams(size=embedding_dim, distance=Distance.COSINE),
        )
        self.vectorstore = QdrantVectorStore(
            client=self.client,
            collection_name=collection_name,
            embedding=self.embeddings,
        )
        # 3. Retriever: fetch a wide candidate set so the reranker has something
        # to prune (with k <= top_n the rerank step would never drop anything)
        self.retriever = self.vectorstore.as_retriever(search_kwargs={"k": 20})

        # 4. Contextual compression with reranking
        self.compression_retriever = ContextualCompressionRetriever(
            base_compressor=CohereRerank(model=rerank_model, top_n=10),
            base_retriever=self.retriever,
        )
        # 5. Prompt
        template = """You are a helpful assistant who answers questions based on the provided context. You must answer using only the provided context, not your own knowledge.

### Question
{question}

### Context
{context}
"""
        self.prompt = ChatPromptTemplate.from_template(template)
        # 6. LLM
        self.llm = ChatOpenAI(model="gpt-4o-mini", streaming=True)

        # 7. Final chain: fetch reranked context for the question, then emit the
        # generated answer alongside the context that was used
        self.chain = (
            RunnableParallel(
                context=itemgetter("question") | self.compression_retriever,
                question=itemgetter("question"),
            )
            | {
                "response": self.prompt | self.llm | StrOutputParser(),
                "context": itemgetter("context"),
            }
        )
    def add_documents(self, texts: List[str]):
        """Splits and indexes a list of text documents."""
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=750, chunk_overlap=100)
        docs = text_splitter.create_documents(texts)
        self.vectorstore.add_documents(docs)
    async def generate_response(self, query: str) -> AsyncGenerator[str, None]:
        """Streams the response using the LangChain RAG pipeline."""
        # astream (not stream) is required inside `async for`; chunks from the
        # parallel output step may carry only "context", so guard before yielding.
        async for chunk in self.chain.astream({"question": query}):
            if "response" in chunk:
                yield chunk["response"]
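

# --- Example usage: a minimal sketch, not part of the original module. ---
# Assumes OPENAI_API_KEY and COHERE_API_KEY are provided via the .env file
# loaded above; the sample documents and question are purely illustrative.
if __name__ == "__main__":
    import asyncio

    async def main():
        rag = RAGProcessor()
        # Index a couple of illustrative documents.
        rag.add_documents([
            "The Reformer reduces attention cost with locality-sensitive hashing.",
            "Reversible residual layers let the Reformer avoid storing activations.",
        ])
        # Stream an answer token by token.
        async for token in rag.generate_response("How does the Reformer save memory?"):
            print(token, end="", flush=True)
        print()

    asyncio.run(main())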