File size: 3,199 Bytes
1b3b5e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import os
from typing import List, AsyncGenerator

from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_qdrant import QdrantVectorStore
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain_cohere import CohereRerank
from langchain.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from operator import itemgetter
from dotenv import load_dotenv

# Load API credentials from a local .env file — presumably OPENAI_API_KEY
# (ChatOpenAI) and the Cohere key (CohereRerank); confirm against deployment.
load_dotenv()

class RAGProcessor:
    """End-to-end RAG pipeline: embed, store, retrieve, rerank, generate.

    Documents are chunked and indexed into an in-memory Qdrant collection;
    queries are answered by retrieving candidates, reranking them with
    Cohere, and streaming a context-grounded answer from an OpenAI model.
    """

    def __init__(
        self,
        model_name: str = "bsmith3715/legal-ft-demo_final",
        collection_name: str = "reformer_docs",
        embedding_dim: int = 768,
        rerank_model: str = "rerank-v3.5"
    ):
        """Build the retrieval and generation chain.

        Args:
            model_name: HuggingFace model id used for dense embeddings.
            collection_name: Name of the Qdrant collection to create.
            embedding_dim: Dimensionality of the embedding vectors; must
                match the output size of ``model_name``.
            rerank_model: Cohere rerank model identifier.
        """
        # 1. Embedding model
        self.embeddings = HuggingFaceEmbeddings(model_name=model_name)

        # 2. In-memory Qdrant store — all indexed data is lost on process exit
        self.client = QdrantClient(":memory:")
        self.collection_name = collection_name
        self.client.create_collection(
            collection_name=collection_name,
            vectors_config=VectorParams(size=embedding_dim, distance=Distance.COSINE),
        )

        self.vectorstore = QdrantVectorStore(
            client=self.client,
            collection_name=collection_name,
            embedding=self.embeddings,
        )

        # 3. Retriever (top-5 nearest neighbours by cosine similarity)
        self.retriever = self.vectorstore.as_retriever(search_kwargs={"k": 5})

        # 4. Contextual compression with reranking.
        # NOTE(review): top_n=10 exceeds the retriever's k=5, so the reranker
        # can return at most 5 documents — confirm whether k should be raised.
        self.compression_retriever = ContextualCompressionRetriever(
            base_compressor=CohereRerank(model=rerank_model, top_n=10),
            base_retriever=self.retriever
        )

        # 5. Prompt — instructs the model to answer only from retrieved context
        template = """You are a helpful assistant who answers questions based on provided context. You must only use the provided context, and cannot use your own knowledge.

### Question
{question}

### Context
{context}
"""
        self.prompt = ChatPromptTemplate.from_template(template)

        # 6. LLM (streaming enabled so tokens arrive incrementally)
        self.llm = ChatOpenAI(model="gpt-4o-mini", streaming=True)

        # 7. Final chain: fan out to retrieval + question passthrough, then a
        # parallel dict so callers get both the answer stream and the
        # retrieved documents ({"response": ..., "context": ...}).
        self.chain = (
            {"context": itemgetter("question") | self.compression_retriever, "question": itemgetter("question")}
            | RunnablePassthrough.assign(context=itemgetter("context"))
            | {"response": self.prompt | self.llm | StrOutputParser(), "context": itemgetter("context")}
        )

    def add_documents(self, texts: List[str]) -> None:
        """Split raw texts into overlapping 750-char chunks and index them."""
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=750, chunk_overlap=100)
        docs = text_splitter.create_documents(texts)
        self.vectorstore.add_documents(docs)

    async def generate_response(self, query: str) -> AsyncGenerator[str, None]:
        """Stream answer tokens for *query* through the RAG pipeline.

        Fixes two defects in the original implementation:
        - ``chain.stream()`` returns a *synchronous* iterator, so ``async
          for`` over it raised ``TypeError``; ``astream()`` is the async API.
        - Chunks streamed from the final parallel step are partial dicts and
          may carry only the ``"context"`` key; indexing ``chunk["response"]``
          unconditionally raised ``KeyError``. Yield only when present.
        """
        async for chunk in self.chain.astream({"question": query}):
            if "response" in chunk:
                yield chunk["response"]