Dinesh310 commited on
Commit
d31a43f
·
verified ·
1 Parent(s): 6602c82

Update src/rag_engine.py

Browse files
Files changed (1) hide show
  1. src/rag_engine.py +58 -30
src/rag_engine.py CHANGED
@@ -7,59 +7,87 @@ from langchain_core.prompts import ChatPromptTemplate
7
  from langchain_core.output_parsers import StrOutputParser
8
  from langchain_core.runnables import RunnablePassthrough, RunnableParallel
9
 
 
10
  class ProjectRAGEngine:
11
- def __init__(self, api_key):
12
- self.embeddings = OpenAIEmbeddings(openai_api_key=api_key)
13
- self.llm = ChatOpenAI(model="openai/gpt-oss-120b:free", openai_api_key=api_key, temperature=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  self.vector_store = None
15
 
16
  def process_documents(self, pdf_paths):
17
  all_docs = []
 
18
  for path in pdf_paths:
19
- try:
20
- loader = PyPDFLoader(path)
21
- docs = loader.load()
22
- all_docs.extend(docs)
23
- except Exception as e:
24
- print(f"Error loading {path}: {e}")
25
-
26
- # Splitting logic to handle large reports [cite: 10]
27
- text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
28
- splits = text_splitter.split_documents(all_docs)
29
- self.vector_store = FAISS.from_documents(splits, self.embeddings)
 
 
30
 
31
  def _format_docs(self, docs):
32
- return "\n\n".join(doc.page_content for doc in docs)
33
 
34
  def get_answer(self, query):
35
  if not self.vector_store:
36
  return "Please upload documents first.", []
37
 
38
- # System prompt ensuring grounded responses [cite: 18, 25]
39
  template = """
40
- You are a professional Project Analyst. Answer strictly based on the provided context.
41
- If the answer is not in the context, say you don't know.
42
- Cite document names and page numbers for every answer. Include direct quotes.
 
 
 
 
43
 
44
- Context: {context}
45
- Question: {question}
46
  """
 
47
  prompt = ChatPromptTemplate.from_template(template)
48
  retriever = self.vector_store.as_retriever(search_kwargs={"k": 5})
49
 
50
- # Pure LCEL Chain composition
51
- rag_chain_from_docs = (
52
- RunnablePassthrough.assign(context=(lambda x: self._format_docs(x["context"])))
 
53
  | prompt
54
  | self.llm
55
  | StrOutputParser()
56
  )
57
 
58
- rag_chain_with_source = RunnableParallel(
59
  {"context": retriever, "question": RunnablePassthrough()}
60
- ).assign(answer=rag_chain_from_docs)
61
 
62
- result = rag_chain_with_source.invoke(query)
63
-
64
- sources = [{"content": doc.page_content, "metadata": doc.metadata} for doc in result["context"]]
65
- return result["answer"], sources
 
 
 
 
 
7
  from langchain_core.output_parsers import StrOutputParser
8
  from langchain_core.runnables import RunnablePassthrough, RunnableParallel
9
 
10
+
11
  class ProjectRAGEngine:
12
+ def __init__(self):
13
+ # OpenAI embeddings (OFFICIAL)
14
+ self.embeddings = OpenAIEmbeddings(
15
+ model="text-embedding-3-small"
16
+ )
17
+
18
+ # ✅ OpenRouter LLM
19
+ self.llm = ChatOpenAI(
20
+ model="openai/gpt-oss-120b:free",
21
+ temperature=0,
22
+ openai_api_key=os.getenv("OPENROUTER_API_KEY"),
23
+ openai_api_base="https://openrouter.ai/api/v1",
24
+ default_headers={
25
+ "HTTP-Referer": "http://localhost:8501",
26
+ "X-Title": "Project-RAG-App"
27
+ }
28
+ )
29
+
30
  self.vector_store = None
31
 
32
  def process_documents(self, pdf_paths):
33
  all_docs = []
34
+
35
  for path in pdf_paths:
36
+ loader = PyPDFLoader(path)
37
+ all_docs.extend(loader.load())
38
+
39
+ splitter = RecursiveCharacterTextSplitter(
40
+ chunk_size=1000,
41
+ chunk_overlap=200
42
+ )
43
+
44
+ splits = splitter.split_documents(all_docs)
45
+
46
+ self.vector_store = FAISS.from_documents(
47
+ splits, self.embeddings
48
+ )
49
 
50
  def _format_docs(self, docs):
51
+ return "\n\n".join(d.page_content for d in docs)
52
 
53
  def get_answer(self, query):
54
  if not self.vector_store:
55
  return "Please upload documents first.", []
56
 
 
57
  template = """
58
+ You are a professional Project Analyst.
59
+ Answer strictly using the context.
60
+ If unknown, say you don't know.
61
+ Cite document names and page numbers.
62
+
63
+ Context:
64
+ {context}
65
 
66
+ Question:
67
+ {question}
68
  """
69
+
70
  prompt = ChatPromptTemplate.from_template(template)
71
  retriever = self.vector_store.as_retriever(search_kwargs={"k": 5})
72
 
73
+ rag_chain = (
74
+ RunnablePassthrough.assign(
75
+ context=lambda x: self._format_docs(x["context"])
76
+ )
77
  | prompt
78
  | self.llm
79
  | StrOutputParser()
80
  )
81
 
82
+ chain = RunnableParallel(
83
  {"context": retriever, "question": RunnablePassthrough()}
84
+ ).assign(answer=rag_chain)
85
 
86
+ result = chain.invoke(query)
87
+
88
+ sources = [
89
+ {"content": d.page_content, "metadata": d.metadata}
90
+ for d in result["context"]
91
+ ]
92
+
93
+ return result["answer"], sources