import os
from operator import itemgetter
from typing import AsyncGenerator, List

from dotenv import load_dotenv
from langchain.prompts import ChatPromptTemplate
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_cohere import CohereRerank
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_openai import ChatOpenAI
from langchain_qdrant import QdrantVectorStore
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams

load_dotenv()


class RAGProcessor:
    """End-to-end RAG pipeline: embed -> retrieve -> rerank -> generate.

    Builds an in-memory Qdrant vector store over HuggingFace embeddings,
    retrieves candidates, reranks them with Cohere, and streams an answer
    from an OpenAI chat model constrained to the retrieved context.
    """

    def __init__(
        self,
        model_name: str = "bsmith3715/legal-ft-demo_final",
        collection_name: str = "reformer_docs",
        embedding_dim: int = 768,
        rerank_model: str = "rerank-v3.5",
    ):
        """Set up the embedding model, vector store, retriever stack, and chain.

        Args:
            model_name: HuggingFace model id used for document/query embeddings.
            collection_name: Name of the in-memory Qdrant collection.
            embedding_dim: Dimensionality of the embedding vectors
                (must match the output size of `model_name` — confirm for
                any non-default model).
            rerank_model: Cohere rerank model identifier.
        """
        # 1. Embedding model
        self.embeddings = HuggingFaceEmbeddings(model_name=model_name)

        # 2. In-memory Qdrant store (ephemeral: data is lost when the
        #    process exits).
        self.client = QdrantClient(":memory:")
        self.collection_name = collection_name
        self.client.create_collection(
            collection_name=collection_name,
            vectors_config=VectorParams(size=embedding_dim, distance=Distance.COSINE),
        )
        self.vectorstore = QdrantVectorStore(
            client=self.client,
            collection_name=collection_name,
            embedding=self.embeddings,
        )

        # 3. Base retriever: top-5 by cosine similarity.
        self.retriever = self.vectorstore.as_retriever(search_kwargs={"k": 5})

        # 4. Contextual compression with reranking.
        #    NOTE(review): top_n=10 exceeds the retriever's k=5, so the
        #    reranker can only ever reorder the 5 candidates it receives —
        #    raise k or lower top_n if true top-10 reranking is intended.
        self.compression_retriever = ContextualCompressionRetriever(
            base_compressor=CohereRerank(model=rerank_model, top_n=10),
            base_retriever=self.retriever,
        )

        # 5. Prompt: answer strictly from retrieved context.
        template = """You are a helpful assistant who answers questions based on provided context. You must only use the provided context, and cannot use your own knowledge.

### Question
{question}

### Context
{context}
"""
        self.prompt = ChatPromptTemplate.from_template(template)

        # 6. LLM (streaming enabled so tokens arrive incrementally).
        self.llm = ChatOpenAI(model="gpt-4o-mini", streaming=True)

        # 7. Final chain. The first dict step fans the question out to the
        #    retriever (as "context") and passes it through (as "question");
        #    the last step returns both the generated answer and the raw
        #    context documents.
        #    (The original interposed RunnablePassthrough.assign(
        #    context=itemgetter("context")), which reassigned "context" to
        #    itself — a no-op stage, removed.)
        self.chain = (
            {
                "context": itemgetter("question") | self.compression_retriever,
                "question": itemgetter("question"),
            }
            | {
                "response": self.prompt | self.llm | StrOutputParser(),
                "context": itemgetter("context"),
            }
        )

    def add_documents(self, texts: List[str]) -> None:
        """Split raw texts into overlapping chunks and index them.

        Args:
            texts: Plain-text documents to chunk (750 chars, 100 overlap)
                and embed into the vector store.
        """
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=750, chunk_overlap=100)
        docs = text_splitter.create_documents(texts)
        self.vectorstore.add_documents(docs)

    async def generate_response(self, query: str) -> AsyncGenerator[str, None]:
        """Stream answer tokens for `query` through the RAG pipeline.

        Yields:
            Incremental string fragments of the generated answer.
        """
        # BUG FIX: the original used `async for ... in self.chain.stream(...)`.
        # Runnable.stream() returns a *synchronous* generator with no
        # __aiter__, so `async for` over it raises TypeError; the async
        # variant is .astream().
        async for chunk in self.chain.astream({"question": query}):
            # BUG FIX: the final parallel step streams one key per chunk, so
            # chunks carrying only "context" lack "response" and
            # chunk["response"] would raise KeyError — skip those chunks.
            piece = chunk.get("response")
            if piece is not None:
                yield piece