# File size: 6,186 bytes | commit 7b9f456
from __future__ import annotations
import os
import sqlite3
import tempfile
from typing import Annotated, Any, Dict, List, Optional, TypedDict
from dotenv import load_dotenv
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_community.vectorstores import FAISS
from langchain_core.messages import BaseMessage, SystemMessage
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.graph import START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition
import requests
load_dotenv()
# -------------------
# 1. LLM + embeddings
# -------------------
# Chat model served through OpenRouter's OpenAI-compatible endpoint.
# NOTE(review): api_key is None if OPENROUTER_API_KEY is unset — confirm the
# deployment always provides it before this module is imported.
llm = ChatOpenAI(
    model="openai/gpt-oss-120b:free",
    base_url="https://openrouter.ai/api/v1",
    api_key=os.getenv("OPENROUTER_API_KEY"),
    # OpenRouter-specific flag asking the provider to run with reasoning enabled.
    extra_body={"reasoning": {"enabled": True}}
)

# CPU-only sentence-transformer embeddings; vectors are L2-normalized
# (normalize_embeddings=True) before being stored in FAISS.
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    model_kwargs={"device": "cpu"},
    encode_kwargs={"normalize_embeddings": True}
)
# -------------------
# 2. Multi-PDF Store (per thread)
# -------------------
# Changed from _THREAD_RETRIEVERS to _THREAD_STORES to keep access to .add_documents()
# In-memory, per-process registries keyed by the stringified thread id:
#   _THREAD_STORES   -> one FAISS index per thread, holding chunks from every PDF
#   _THREAD_METADATA -> one info dict per ingested file (filename / page / chunk counts)
# NOTE(review): state lives only in this process — lost on restart; confirm that
# is acceptable for the deployment.
_THREAD_STORES: Dict[str, FAISS] = {}
_THREAD_METADATA: Dict[str, List[dict]] = {}
def ingest_pdf(file_bytes: bytes, thread_id: str, filename: Optional[str] = None) -> dict:
    """Index a PDF into the FAISS store for a chat thread.

    Splits the PDF into overlapping chunks and adds them to the thread's
    existing FAISS index, or creates a new index on the first upload.

    Args:
        file_bytes: Raw bytes of the PDF file.
        thread_id: Chat thread identifier; coerced to ``str`` for keying.
        filename: Original file name for bookkeeping; falls back to the
            temporary file's basename when omitted.

    Returns:
        dict with ``filename``, ``documents`` (page count) and ``chunks``.

    Raises:
        ValueError: If ``file_bytes`` is empty, or if no text chunks could be
            extracted (e.g. a scanned/image-only PDF).
    """
    if not file_bytes:
        raise ValueError("No bytes received for ingestion.")

    # PyPDFLoader wants a filesystem path, so spill the bytes to a temp file.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as temp_file:
        temp_file.write(file_bytes)
        temp_path = temp_file.name

    try:
        loader = PyPDFLoader(temp_path)
        docs = loader.load()
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=500, chunk_overlap=100, separators=["\n\n", "\n", " ", ""]
        )
        chunks = splitter.split_documents(docs)
        # FAISS.from_documents fails cryptically on an empty list — surface a
        # clear error for text-less PDFs instead.
        if not chunks:
            raise ValueError("No text could be extracted from the PDF.")

        thread_key = str(thread_id)
        # --- Multi-PDF Logic ---
        if thread_key in _THREAD_STORES:
            # Merge into the thread's existing index so rag_tool searches
            # across every uploaded PDF at once.
            _THREAD_STORES[thread_key].add_documents(chunks)
        else:
            # First upload for this thread: create the index.
            _THREAD_STORES[thread_key] = FAISS.from_documents(chunks, embeddings)

        # Track metadata as a list of files, one entry per ingested PDF.
        file_info = {
            "filename": filename or os.path.basename(temp_path),
            "documents": len(docs),
            "chunks": len(chunks),
        }
        _THREAD_METADATA.setdefault(thread_key, []).append(file_info)
        return file_info
    finally:
        # Best-effort cleanup of the temp file; ignore already-gone/permission races.
        try:
            os.remove(temp_path)
        except OSError:
            pass
# -------------------
# 3. Tools
# -------------------
# DuckDuckGo web search, pinned to US-English results.
search_tool = DuckDuckGoSearchRun(region="us-en")
@tool
def calculator(first_num: float, second_num: float, operation: str) -> dict:
    """Perform basic arithmetic: add, sub, mul, div."""
    # Lazily evaluate only the requested operation instead of building a
    # dict of all four results up front.
    if operation == "add":
        outcome = first_num + second_num
    elif operation == "sub":
        outcome = first_num - second_num
    elif operation == "mul":
        outcome = first_num * second_num
    elif operation == "div":
        # Division by zero is reported as the string "Error" rather than raising.
        outcome = first_num / second_num if second_num != 0 else "Error"
    else:
        outcome = "Unsupported"
    return {"result": outcome}
@tool
def get_stock_price(symbol: str) -> dict:
    """Fetch the latest stock quote for ``symbol`` from Alpha Vantage.

    Returns the raw JSON payload of the GLOBAL_QUOTE endpoint as a dict.
    May raise ``requests`` network/JSON errors on failure.
    """
    # SECURITY: the API key used to be hard-coded here. Prefer the environment,
    # keeping the original key as a fallback for backward compatibility.
    api_key = os.getenv("ALPHAVANTAGE_API_KEY", "C9PE94QUEW9VWGFM")
    url = (
        "https://www.alphavantage.co/query"
        f"?function=GLOBAL_QUOTE&symbol={symbol}&apikey={api_key}"
    )
    # A timeout prevents the tool node from hanging forever on a stalled request.
    return requests.get(url, timeout=10).json()
@tool
def rag_tool(query: str, thread_id: Optional[str] = None) -> dict:
    """
    Retrieve information from ALL uploaded PDFs for this chat thread.
    """
    key = str(thread_id)
    store = _THREAD_STORES.get(key)
    if store is None:
        # Nothing indexed yet for this thread — tell the model so it can ask
        # the user to upload documents.
        return {
            "error": "No documents indexed for this chat. Please upload one or more PDFs.",
            "query": query,
        }

    # One similarity search spans every PDF ingested into this thread's index.
    hits = store.similarity_search(query, k=4)
    context, sources = [], []
    for hit in hits:
        context.append(hit.page_content)
        sources.append(hit.metadata)

    return {
        "query": query,
        "context": context,
        "sources": sources,
        "uploaded_files": [entry["filename"] for entry in _THREAD_METADATA.get(key, [])],
    }
# All tools exposed to the model; rag_tool is the per-thread PDF retriever.
tools = [search_tool, get_stock_price, calculator, rag_tool]
# Bind the tool schemas so the model can emit structured tool calls.
llm_with_tools = llm.bind_tools(tools)
# -------------------
# 4. State & Nodes (Same as previous)
# -------------------
class ChatState(TypedDict):
    # Conversation history; the add_messages reducer merges new messages into
    # the existing list when LangGraph applies a node's update.
    messages: Annotated[list[BaseMessage], add_messages]
def chat_node(state: ChatState, config=None):
    """Run one LLM turn: prepend the system prompt, invoke the tool-bound model."""
    cfg = config or {}
    thread_id = cfg.get("configurable", {}).get("thread_id")
    # The system prompt carries the thread_id so the model can pass it to rag_tool.
    system_message = SystemMessage(
        content=(
            "You are a helpful assistant. You have access to multiple PDFs uploaded by the user. "
            f"To search them, use `rag_tool` with thread_id `{thread_id}`. "
            "You can synthesize info from multiple documents if needed."
        )
    )
    prompt = [system_message, *state["messages"]]
    response = llm_with_tools.invoke(prompt, config=config)
    return {"messages": [response]}
# -------------------
# 5. Graph Setup
# -------------------
# Executes whatever tool calls the model emitted on its last message.
tool_node = ToolNode(tools)
# check_same_thread=False allows the connection to be used from worker threads;
# SqliteSaver persists per-thread conversation checkpoints to chatbot.db.
conn = sqlite3.connect(database="chatbot.db", check_same_thread=False)
checkpointer = SqliteSaver(conn=conn)
builder = StateGraph(ChatState)
builder.add_node("chat_node", chat_node)
builder.add_node("tools", tool_node)
builder.add_edge(START, "chat_node")
# tools_condition routes to "tools" when the last AI message contains tool
# calls, otherwise ends the turn.
builder.add_conditional_edges("chat_node", tools_condition)
builder.add_edge("tools", "chat_node")
chatbot = builder.compile(checkpointer=checkpointer)
# -------------------
# 6. Helpers
# -------------------
def get_all_uploaded_files(thread_id: str) -> List[dict]:
    """Return the metadata entry for every file uploaded to this thread."""
    key = str(thread_id)
    return _THREAD_METADATA.get(key, [])