# Demo_1/src/rag_engine.py
# (Hugging Face Space file — uploaded by Dinesh310, commit e3f3e21)
import os
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableParallel
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_openai import ChatOpenAI
from langchain_community.chat_models import ChatLiteLLM
from langchain_core.messages import HumanMessage, AIMessage
class ProjectRAGEngine:
    """Retrieval-Augmented Generation engine over uploaded PDF documents.

    Embeddings are computed locally with a HuggingFace sentence-transformer
    (free, CPU by default), chunks are indexed in an in-memory FAISS store,
    and answers are generated through an OpenRouter-hosted chat model.
    """

    # The prompt never varies between queries, so define the template text
    # once at class level instead of rebuilding the string on every call.
    _ANSWER_TEMPLATE = """
You are a professional Project Analyst.
Answer strictly using the context.
If unknown, say you don't know.
Cite document names and page numbers.
Context:
{context}
Question:
{question}
"""

    def __init__(self):
        # Local / free HuggingFace embeddings; change device to "cuda"
        # if a GPU is available.
        self.embeddings = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2",
            model_kwargs={"device": "cpu"},
            encode_kwargs={"normalize_embeddings": True},
        )
        # OpenRouter exposes an OpenAI-compatible API, so ChatOpenAI is
        # pointed at it via base_url. NOTE(review): if OPENROUTER_API_KEY
        # is unset, api_key is None and this only fails on the first call.
        self.llm = ChatOpenAI(
            model="openai/gpt-oss-120b:free",
            base_url="https://openrouter.ai/api/v1",
            api_key=os.getenv("OPENROUTER_API_KEY"),
            extra_body={"reasoning": {"enabled": True}},
        )
        # Populated by process_documents(); None means "nothing indexed yet".
        self.vector_store = None

    def process_documents(self, pdf_paths):
        """Load, split, and index the given PDF files into a FAISS store.

        Args:
            pdf_paths: iterable of filesystem paths to PDF files.

        Raises:
            ValueError: if no pages could be loaded (empty input or PDFs
                with no extractable text) — FAISS.from_documents would
                otherwise fail with an opaque error.
        """
        all_docs = []
        for path in pdf_paths:
            loader = PyPDFLoader(path)
            all_docs.extend(loader.load())
        # Guard: an empty document list would crash deep inside FAISS with
        # a confusing message; fail early with a clear one instead.
        if not all_docs:
            raise ValueError("No documents were loaded; check the PDF paths.")
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=500,
            chunk_overlap=50,
        )
        splits = splitter.split_documents(all_docs)
        # FAISS index built with the local HuggingFace embeddings.
        self.vector_store = FAISS.from_documents(splits, self.embeddings)

    def _format_docs(self, docs):
        """Join retrieved chunks into a single context string for the prompt."""
        return "\n\n".join(d.page_content for d in docs)

    def get_answer(self, query):
        """Answer *query* against the indexed documents.

        Args:
            query: the user's question as a plain string.

        Returns:
            A ``(answer, sources)`` tuple: ``answer`` is the LLM's string
            reply (or a hint to upload documents first), ``sources`` is a
            list of ``{"content", "metadata"}`` dicts for the retrieved
            chunks.
        """
        if not self.vector_store:
            return "Please upload documents first.", []
        prompt = ChatPromptTemplate.from_template(self._ANSWER_TEMPLATE)
        # MMR retrieval trades pure similarity for diversity among the
        # top-5 chunks (lambda_mult near 0 favors diversity).
        retriever = self.vector_store.as_retriever(
            search_type="mmr",
            search_kwargs={"k": 5, "lambda_mult": 0.25},
        )
        rag_chain = (
            RunnablePassthrough.assign(
                context=lambda x: self._format_docs(x["context"])
            )
            | prompt
            | self.llm
            | StrOutputParser()
        )
        # Run retrieval and question pass-through in parallel, then attach
        # the generated answer alongside the raw retrieved context so the
        # caller gets both the reply and its supporting chunks.
        chain = RunnableParallel(
            {"context": retriever, "question": RunnablePassthrough()}
        ).assign(answer=rag_chain)
        result = chain.invoke(query)
        sources = [
            {"content": d.page_content, "metadata": d.metadata}
            for d in result["context"]
        ]
        return result["answer"], sources