Dinesh310 committed on
Commit
7b9f456
·
verified ·
1 Parent(s): e012c2e

Create langraph_rag_backend

Browse files
Files changed (1) hide show
  1. src/langraph_rag_backend +183 -0
src/langraph_rag_backend ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import sqlite3
5
+ import tempfile
6
+ from typing import Annotated, Any, Dict, List, Optional, TypedDict
7
+
8
+ from dotenv import load_dotenv
9
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
10
+ from langchain_community.document_loaders import PyPDFLoader
11
+ from langchain_community.embeddings import HuggingFaceEmbeddings
12
+ from langchain_community.tools import DuckDuckGoSearchRun
13
+ from langchain_community.vectorstores import FAISS
14
+ from langchain_core.messages import BaseMessage, SystemMessage
15
+ from langchain_core.tools import tool
16
+ from langchain_openai import ChatOpenAI
17
+ from langgraph.checkpoint.sqlite import SqliteSaver
18
+ from langgraph.graph import START, StateGraph
19
+ from langgraph.graph.message import add_messages
20
+ from langgraph.prebuilt import ToolNode, tools_condition
21
+ import requests
22
+
23
+ load_dotenv()
24
+
25
# -------------------
# 1. LLM + embeddings
# -------------------
# Chat model served via OpenRouter; extra_body switches on reasoning traces.
llm = ChatOpenAI(
    model="openai/gpt-oss-120b:free",
    base_url="https://openrouter.ai/api/v1",
    api_key=os.getenv("OPENROUTER_API_KEY"),
    extra_body={"reasoning": {"enabled": True}},
)

# CPU-only sentence-transformer embeddings; vectors are normalized at encode time.
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    model_kwargs={"device": "cpu"},
    encode_kwargs={"normalize_embeddings": True},
)

# -------------------
# 2. Multi-PDF Store (per thread)
# -------------------
# One FAISS index per chat thread, so each conversation searches only its own
# uploads. Kept as stores (not retrievers) so .add_documents() stays reachable.
_THREAD_STORES: Dict[str, FAISS] = {}
# Per-thread list of upload summaries: {"filename", "documents", "chunks"}.
_THREAD_METADATA: Dict[str, List[dict]] = {}
48
def ingest_pdf(file_bytes: bytes, thread_id: str, filename: Optional[str] = None) -> dict:
    """
    Index a PDF into the thread's FAISS store, creating the store on first upload.

    Args:
        file_bytes: Raw PDF bytes (must be non-empty).
        thread_id: Chat thread the document belongs to.
        filename: Display name; falls back to the temp file's basename.

    Returns:
        Summary dict with "filename", "documents" (page count) and "chunks".

    Raises:
        ValueError: If file_bytes is empty.
    """
    if not file_bytes:
        raise ValueError("No bytes received for ingestion.")

    # PyPDFLoader needs a real path, so spill the upload to a temp file first.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as handle:
        handle.write(file_bytes)
        pdf_path = handle.name

    try:
        pages = PyPDFLoader(pdf_path).load()
        chunks = RecursiveCharacterTextSplitter(
            chunk_size=500, chunk_overlap=100, separators=["\n\n", "\n", " ", ""]
        ).split_documents(pages)

        key = str(thread_id)

        # --- Multi-PDF logic ---
        store = _THREAD_STORES.get(key)
        if store is not None:
            # Thread already has an index: merge the new chunks into it.
            store.add_documents(chunks)
        else:
            # First upload for this thread: build a fresh index.
            _THREAD_STORES[key] = FAISS.from_documents(chunks, embeddings)

        # Record this upload alongside any earlier ones for the thread.
        summary = {
            "filename": filename or os.path.basename(pdf_path),
            "documents": len(pages),
            "chunks": len(chunks),
        }
        _THREAD_METADATA.setdefault(key, []).append(summary)
        return summary
    finally:
        # Best-effort cleanup of the temp file; ignore filesystem races.
        try:
            os.remove(pdf_path)
        except OSError:
            pass
98
# -------------------
# 3. Tools
# -------------------
# Web search tool (no API key required), pinned to US-English results.
search_tool = DuckDuckGoSearchRun(region="us-en")
102
@tool
def calculator(first_num: float, second_num: float, operation: str) -> dict:
    """Perform basic arithmetic: add, sub, mul, div."""
    # Evaluate only the requested operation (the original built every result
    # eagerly in a dict); returned values are identical.
    if operation == "add":
        result: Any = first_num + second_num
    elif operation == "sub":
        result = first_num - second_num
    elif operation == "mul":
        result = first_num * second_num
    elif operation == "div":
        # Division is the only op that needs a zero guard.
        result = first_num / second_num if second_num != 0 else "Error"
    else:
        result = "Unsupported"
    return {"result": result}
111
+ @tool
112
+ def get_stock_price(symbol: str) -> dict:
113
+ """Fetch latest stock price for a symbol."""
114
+ url = f"https://www.alphavantage.co/query?function=GLOBAL_QUOTE&symbol={symbol}&apikey=C9PE94QUEW9VWGFM"
115
+ return requests.get(url).json()
116
+
117
@tool
def rag_tool(query: str, thread_id: Optional[str] = None) -> dict:
    """
    Retrieve information from ALL uploaded PDFs for this chat thread.
    """
    key = str(thread_id)
    store = _THREAD_STORES.get(key)

    # Guard: nothing has been indexed for this thread yet.
    if store is None:
        return {
            "error": "No documents indexed for this chat. Please upload one or more PDFs.",
            "query": query,
        }

    # A single similarity search spans every PDF ingested into this index.
    hits = store.similarity_search(query, k=4)
    uploaded = [entry["filename"] for entry in _THREAD_METADATA.get(key, [])]

    return {
        "query": query,
        "context": [hit.page_content for hit in hits],
        "sources": [hit.metadata for hit in hits],
        "uploaded_files": uploaded,
    }
# Register the full tool belt with the LLM so it can emit tool calls.
tools = [search_tool, get_stock_price, calculator, rag_tool]
llm_with_tools = llm.bind_tools(tools)
144
# -------------------
# 4. State & Nodes
# -------------------
class ChatState(TypedDict):
    # Conversation history; the add_messages reducer appends new messages
    # instead of replacing the list.
    messages: Annotated[list[BaseMessage], add_messages]
150
def chat_node(state: ChatState, config=None):
    """Run one LLM turn: prepend the system prompt, invoke the tool-aware model."""
    # The thread id travels in the LangGraph run config, not in the state.
    thread_id = config.get("configurable", {}).get("thread_id") if config else None

    system_message = SystemMessage(
        content=(
            "You are a helpful assistant. You have access to multiple PDFs uploaded by the user. "
            f"To search them, use `rag_tool` with thread_id `{thread_id}`. "
            "You can synthesize info from multiple documents if needed."
        )
    )
    reply = llm_with_tools.invoke([system_message, *state["messages"]], config=config)
    return {"messages": [reply]}
162
# -------------------
# 5. Graph Setup
# -------------------
tool_node = ToolNode(tools)

# SQLite-backed checkpointer persists per-thread conversation state.
# check_same_thread=False allows the connection to be used from other threads.
conn = sqlite3.connect(database="chatbot.db", check_same_thread=False)
checkpointer = SqliteSaver(conn=conn)

builder = StateGraph(ChatState)
builder.add_node("chat_node", chat_node)
builder.add_node("tools", tool_node)

builder.add_edge(START, "chat_node")
# Route to the tool node when the model requested tools, otherwise finish.
builder.add_conditional_edges("chat_node", tools_condition)
builder.add_edge("tools", "chat_node")

chatbot = builder.compile(checkpointer=checkpointer)
178
# -------------------
# 6. Helpers
# -------------------
def get_all_uploaded_files(thread_id: str) -> List[dict]:
    """Return the upload metadata recorded for this thread (empty list if none)."""
    key = str(thread_id)
    return _THREAD_METADATA.get(key, [])