File size: 6,186 Bytes
7b9f456
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
from __future__ import annotations

import os
import sqlite3
import tempfile
from typing import Annotated, Any, Dict, List, Optional, TypedDict

import requests
from dotenv import load_dotenv
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_community.utilities import DuckDuckGoSearchAPIWrapper
from langchain_community.vectorstores import FAISS
from langchain_core.messages import BaseMessage, SystemMessage
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.graph import START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition

# Populate os.environ from a local .env file; this must run before the
# os.getenv("OPENROUTER_API_KEY") call below.
load_dotenv()

# -------------------
# 1. LLM + embeddings
# -------------------
# Chat model reached through OpenRouter's OpenAI-compatible endpoint.
llm = ChatOpenAI(
    model="openai/gpt-oss-120b:free",
    base_url="https://openrouter.ai/api/v1",
    # None when the variable is unset — NOTE(review): confirm how ChatOpenAI
    # behaves then (error vs. falling back to another key source).
    api_key=os.getenv("OPENROUTER_API_KEY"),
    # Extra JSON merged into each request body; presumably enables OpenRouter's
    # reasoning output — confirm against the OpenRouter API docs.
    extra_body={"reasoning": {"enabled": True}}
)

# Local sentence-transformer embeddings computed on CPU. Vectors are
# L2-normalised, so if the FAISS store uses plain L2 distance its ranking
# matches cosine similarity.
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    model_kwargs={"device": "cpu"},
    encode_kwargs={"normalize_embeddings": True}
)

# -------------------
# 2. Multi-PDF Store (per thread)
# -------------------
# One FAISS store per chat thread, keyed by str(thread_id). The store itself
# (not a retriever wrapper) is kept so later uploads can call .add_documents().
# NOTE: both dicts live only in process memory. They are not persisted with the
# SQLite chat checkpoints, so a restart keeps the conversation but loses its index.
_THREAD_STORES: Dict[str, FAISS] = {}
# Per-thread upload records ({"filename", "documents", "chunks"}), in upload order.
_THREAD_METADATA: Dict[str, List[dict]] = {}


def ingest_pdf(file_bytes: bytes, thread_id: str, filename: Optional[str] = None) -> dict:
    """
    Add a PDF to the FAISS index of a chat thread, creating the index if needed.

    Args:
        file_bytes: Raw bytes of the uploaded PDF.
        thread_id: Chat thread identifier. Coerced with ``str()`` so it matches
            the key ``rag_tool`` looks up.
        filename: Display name recorded in the thread's upload list and stored
            as every chunk's ``source`` metadata. Defaults to the temporary
            file's basename.

    Returns:
        A new dict with ``filename``, ``documents`` (documents returned by the
        loader) and ``chunks`` (number of indexed chunks).

    Raises:
        ValueError: If ``file_bytes`` is empty or the PDF yields no text chunks.
    """
    if not file_bytes:
        raise ValueError("No bytes received for ingestion.")

    # delete=False because the loader re-opens the file by path. The handle is
    # closed before loading, and the file is removed in the finally block even
    # when the write itself fails.
    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".pdf")
    temp_path = temp_file.name
    try:
        with temp_file:
            temp_file.write(file_bytes)

        docs = PyPDFLoader(temp_path).load()
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=500, chunk_overlap=100, separators=["\n\n", "\n", " ", ""]
        )
        chunks = splitter.split_documents(docs)
        if not chunks:
            # A FAISS index cannot be built from zero vectors (for example, a
            # scanned PDF without a text layer). Fail with a clear message and
            # do not record an empty upload.
            raise ValueError("No extractable text found in the PDF.")

        display_name = filename or os.path.basename(temp_path)
        for chunk in chunks:
            # The loader's "source" is the temp path, which is deleted below.
            # Use the upload name so retrieved chunks can be attributed.
            chunk.metadata["source"] = display_name

        # --- Multi-PDF logic: one growing store per thread ---
        thread_key = str(thread_id)
        store = _THREAD_STORES.get(thread_key)
        if store is None:
            _THREAD_STORES[thread_key] = FAISS.from_documents(chunks, embeddings)
        else:
            store.add_documents(chunks)

        file_info = {
            "filename": display_name,
            "documents": len(docs),
            "chunks": len(chunks),
        }
        _THREAD_METADATA.setdefault(thread_key, []).append(file_info)
        # Hand back a copy so callers cannot mutate the stored record.
        return dict(file_info)
    finally:
        # Best-effort cleanup: a leftover temp file must not mask the real
        # result or exception.
        try:
            os.remove(temp_path)
        except OSError:
            pass


# -------------------
# 3. Tools
# -------------------
# DuckDuckGo web search exposed to the LLM. The region has to be configured on
# the API wrapper: DuckDuckGoSearchRun has no `region` field of its own, so
# passing region= to the tool directly was never applied to the search.
search_tool = DuckDuckGoSearchRun(api_wrapper=DuckDuckGoSearchAPIWrapper(region="us-en"))

@tool
def calculator(first_num: float, second_num: float, operation: str) -> dict:
    """Perform basic arithmetic: add, sub, mul, div (also accepts +, -, *, /)."""
    # Returns {"result": value}. Division by zero yields "Error" and an unknown
    # operation yields "Unsupported", both reported in-band rather than raised.
    # Common symbol and word spellings map onto the canonical names so they do
    # not fall through to "Unsupported"; case and surrounding spaces are ignored.
    aliases = {
        "+": "add", "plus": "add",
        "-": "sub", "subtract": "sub", "minus": "sub",
        "*": "mul", "multiply": "mul", "times": "mul",
        "/": "div", "divide": "div",
    }
    key = operation.strip().lower()
    key = aliases.get(key, key)

    # Compute only the requested operation.
    if key == "add":
        result = first_num + second_num
    elif key == "sub":
        result = first_num - second_num
    elif key == "mul":
        result = first_num * second_num
    elif key == "div":
        result = first_num / second_num if second_num != 0 else "Error"
    else:
        result = "Unsupported"
    return {"result": result}

@tool
def get_stock_price(symbol: str) -> dict:
    """Fetch the latest stock quote for a ticker symbol (e.g. 'AAPL') from Alpha Vantage."""
    # Returns Alpha Vantage's JSON response, or {"error", "symbol"} on a
    # network, HTTP or JSON-decoding failure (matching rag_tool's error shape).
    #
    # Read the key from the environment so it can be rotated without a code change.
    # NOTE(security): the literal fallback keeps the previous behaviour, but it is
    # a secret committed to source. Rotate it and remove the fallback.
    api_key = os.getenv("ALPHAVANTAGE_API_KEY", "C9PE94QUEW9VWGFM")
    try:
        # params= URL-encodes the model-supplied symbol, so it cannot inject
        # extra query parameters. The timeout stops the tool call hanging forever.
        response = requests.get(
            "https://www.alphavantage.co/query",
            params={"function": "GLOBAL_QUOTE", "symbol": symbol, "apikey": api_key},
            timeout=10,
        )
        response.raise_for_status()
        return response.json()
    except (requests.RequestException, ValueError) as exc:
        # Report only the exception type: requests' messages can include the
        # full URL, which contains the API key.
        return {
            "error": f"Stock price lookup failed ({type(exc).__name__}).",
            "symbol": symbol,
        }

@tool
def rag_tool(query: str, thread_id: Optional[str] = None) -> dict:
    """
    Retrieve information from ALL uploaded PDFs for this chat thread.
    Always pass the thread_id given in the system prompt.
    """
    # On success returns {"query", "context", "sources", "uploaded_files"} with
    # up to 4 retrieved chunks. Otherwise returns {"error", "query"}.
    if thread_id is None:
        # Looking up str(None) == "None" would only produce a misleading
        # "no documents" message. Report the actual problem instead.
        return {
            "error": "No thread_id provided. Call rag_tool again with the thread_id from the system prompt.",
            "query": query,
        }

    # Same key normalisation as ingest_pdf.
    thread_key = str(thread_id)
    vector_store = _THREAD_STORES.get(thread_key)
    if vector_store is None:
        return {
            "error": "No documents indexed for this chat. Please upload one or more PDFs.",
            "query": query,
        }

    # A single store per thread holds the chunks of every uploaded PDF, so one
    # search covers all of them.
    docs = vector_store.similarity_search(query, k=4)

    return {
        "query": query,
        "context": [doc.page_content for doc in docs],
        "sources": [doc.metadata for doc in docs],
        "uploaded_files": [f["filename"] for f in _THREAD_METADATA.get(thread_key, [])]
    }

# Every tool the model may call. The same list is handed to ToolNode in the
# graph setup, so each advertised tool can also be executed.
tools = [search_tool, get_stock_price, calculator, rag_tool]
# Chat model with these tools' schemas bound to its requests.
llm_with_tools = llm.bind_tools(tools)

# -------------------
# 4. State & Nodes
# -------------------
class ChatState(TypedDict):
    """State carried through the graph: the conversation's message list."""

    # add_messages is registered as this key's reducer. Messages a node returns
    # (chat_node returns only the new reply) are therefore merged into the
    # stored history instead of replacing it.
    messages: Annotated[list[BaseMessage], add_messages]

def chat_node(state: ChatState, config=None):
    """Run one model turn over the conversation, prefixed with a system prompt.

    The thread id comes from ``config["configurable"]["thread_id"]`` (None when
    no config is given). It is embedded in the prompt so the model can forward
    it to ``rag_tool``. Returns a partial state update holding only the reply.
    """
    settings = config or {}
    thread_id = settings.get("configurable", {}).get("thread_id")

    prompt = (
        "You are a helpful assistant. You have access to multiple PDFs uploaded by the user. "
        f"To search them, use `rag_tool` with thread_id `{thread_id}`. "
        "You can synthesize info from multiple documents if needed."
    )
    history = [SystemMessage(content=prompt), *state["messages"]]
    reply = llm_with_tools.invoke(history, config=config)
    return {"messages": [reply]}

# -------------------
# 5. Graph Setup
# -------------------
# Executes the tool calls requested in the model's latest message.
tool_node = ToolNode(tools)
# check_same_thread=False lets this single module-level connection be used from
# threads other than the one that created it (presumably the UI runs the graph
# on worker threads — confirm). sqlite3 then leaves serialising access to the
# caller. NOTE(review): concurrent writes may need a lock.
conn = sqlite3.connect(database="chatbot.db", check_same_thread=False)
# Persists graph state (the message history) per thread_id in chatbot.db.
checkpointer = SqliteSaver(conn=conn)

builder = StateGraph(ChatState)
builder.add_node("chat_node", chat_node)
builder.add_node("tools", tool_node)
builder.add_edge(START, "chat_node")
# tools_condition routes to the "tools" node when the reply contains tool calls
# and ends the run otherwise.
builder.add_conditional_edges("chat_node", tools_condition)
# Tool results go back to the model for another turn.
builder.add_edge("tools", "chat_node")

chatbot = builder.compile(checkpointer=checkpointer)

# -------------------
# 6. Helpers
# -------------------
def get_all_uploaded_files(thread_id: str) -> List[dict]:
    """Return the upload records for a thread, oldest first (empty list if none).

    Each record holds ``filename``, ``documents`` and ``chunks`` as built by
    ``ingest_pdf``. Shallow copies are returned so callers cannot mutate the
    module-level registry by editing or appending to the result.
    """
    return [dict(info) for info in _THREAD_METADATA.get(str(thread_id), [])]