import re
from typing import TypedDict, Annotated, List
import operator

from langchain.tools import tool
from langchain_google_community.search import GoogleSearchAPIWrapper
from langchain_openai import ChatOpenAI
from langchain_core.messages import BaseMessage
from langgraph.graph import StateGraph, END
from langgraph.prebuilt import ToolNode
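
# Sketch (hypothetical helper, not called anywhere in this module): a
# standalone version of the PDF-cleaning filter used inside `paper_qa_tool`,
# so the `/DAN <hex>` pattern can be checked in isolation. It works on plain
# strings rather than Document objects; the sample strings below are made up.
import re

_GIBBERISH_PATTERN = re.compile(r'/DAN <[A-Fa-f0-9]+>')


def _drop_gibberish(texts):
    """Return only the texts that contain no `/DAN <hex>` residue."""
    return [t for t in texts if not _GIBBERISH_PATTERN.search(t)]

# Usage (illustrative):
#   _drop_gibberish(["A (2,3)-sparse graph ...", "/DAN <00FE12> ..."])
#   -> ["A (2,3)-sparse graph ..."]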

# This class defines the structure of the agent's state
class AgentState(TypedDict):
    messages: Annotated[List[BaseMessage], operator.add]
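
# Sketch (hypothetical helper, not used by the graph): LangGraph applies the
# reducer given in the `Annotated` metadata when a node returns a partial
# update. With `operator.add`, the update is concatenated onto the existing
# list, which is why each node returns only its *new* messages. Plain strings
# stand in for BaseMessage objects here.
import operator


def _apply_messages_update(state, update):
    """Merge a node's partial update into state the way `operator.add` would."""
    merged = dict(state)
    if "messages" in update:
        merged["messages"] = operator.add(state["messages"], update["messages"])
    return merged

# Usage (illustrative):
#   _apply_messages_update({"messages": ["user: hi"]}, {"messages": ["ai: hello"]})
#   -> {"messages": ["user: hi", "ai: hello"]}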

def create_agent_graph(vector_store, nvidia_api_key, google_api_key, google_cse_id):
    """Creates and compiles the LangGraph agent."""

    # 1. Define Tools
    @tool
    def paper_qa_tool(query: str) -> str:
        """
        Retrieves passages relevant to the query from the local scientific papers on
        graph theory, sparsity, and the pebble game. Use this for questions that
        reference specific paper details or concepts.
        """
        print("--- Calling Paper Q&A Tool ---")
        retriever = vector_store.as_retriever(search_kwargs={'k': 3})
        context_docs = retriever.invoke(query)

        # Simple cleaning: drop chunks containing `/DAN <hex>` residue from parsed PDFs
        gibberish_pattern = re.compile(r'/DAN <[A-Fa-f0-9]+>')
        cleaned_docs = [doc for doc in context_docs if not gibberish_pattern.search(doc.page_content)]

        if not cleaned_docs:
            return "No relevant information found in the documents after cleaning."

        context_text = "\n\n".join(doc.page_content for doc in cleaned_docs)
        return context_text

    search_wrapper = GoogleSearchAPIWrapper(google_api_key=google_api_key, google_cse_id=google_cse_id)

    @tool
    def web_search_tool(query: str) -> str:
        """
        Provides up-to-date answers from the web for general knowledge, definitions,
        or topics not covered in the local scientific papers. Also provides source links.
        """
        print("--- Calling Web Search Tool ---")
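        results = search_wrapper.results(query, num_results=3)
        # When nothing is found, the wrapper returns a single {"Result": ...} entry,
        # and individual hits may lack a snippet, so read the fields defensively.
        return "\n".join(
            f"Title: {res.get('title', '')}\n"
            f"Link: {res.get('link', '')}\n"
            f"Snippet: {res.get('snippet', res.get('Result', ''))}\n"
            for res in results
        )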

    tools = [paper_qa_tool, web_search_tool]
    tool_node = ToolNode(tools)

    # 2. Define the Model
    # ChatOpenAI works here because NVIDIA's API catalog exposes an OpenAI-compatible endpoint
    model = ChatOpenAI(
        model="meta/llama3-70b-instruct",
        openai_api_key=nvidia_api_key,
        openai_api_base="https://integrate.api.nvidia.com/v1",
        temperature=0.2,
    ).bind_tools(tools)

    # 3. Define Graph Nodes
    def call_model(state):
        """The primary node that calls the LLM."""
        print("--- AGENT: Thinking... ---")
        response = model.invoke(state["messages"])
        return {"messages": [response]}

    def should_continue(state):
        """Router: decides whether to call a tool or end the conversation."""
        last_message = state["messages"][-1]
        if last_message.tool_calls:
            return "continue"
        return "end"

    # 4. Build and Compile the Graph
    workflow = StateGraph(AgentState)
    workflow.add_node("agent", call_model)
    workflow.add_node("action", tool_node)
    
    workflow.set_entry_point("agent")
    workflow.add_conditional_edges(
        "agent",
        should_continue,
        {"continue": "action", "end": END},
    )
    workflow.add_edge("action", "agent")

    return workflow.compile()
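

# Sketch (hypothetical, for illustration only): the control flow the compiled
# graph follows, written as a plain loop. `call_model`, `run_tools`, and
# `route` are stand-ins for the "agent" node, the "action" ToolNode, and
# `should_continue`; `max_steps` is an assumed safety cap (the real graph
# relies on LangGraph's recursion limit instead).
def _run_agent_loop(state, call_model, run_tools, route, max_steps=10):
    """Alternate agent and tool steps until the router returns "end"."""
    for _ in range(max_steps):
        # "agent" node: append the model's reply to the message history.
        state = {"messages": state["messages"] + call_model(state)["messages"]}
        if route(state) == "end":
            return state
        # "action" node: run the requested tools and append their results,
        # then loop back to the agent (the action -> agent edge).
        state = {"messages": state["messages"] + run_tools(state)["messages"]}
    raise RuntimeError("step limit reached without a final answer")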