1Paras1 committed on
Commit
d4bbaf2
·
verified ·
1 Parent(s): 7a2ee91

Create agent.py

Browse files
Files changed (1) hide show
  1. agent.py +92 -0
agent.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from typing import TypedDict, Annotated, List
3
+ import operator
4
+
5
+ from langchain.tools import tool
6
+ from langchain_google_community.search import GoogleSearchAPIWrapper
7
+ from langchain_openai import ChatOpenAI
8
+ from langchain_core.messages import BaseMessage, ToolMessage
9
+ from langgraph.graph import StateGraph, END
10
+ from langgraph.prebuilt import ToolNode
11
+
12
+ # This class defines the structure of the agent's state
13
class AgentState(TypedDict):
    """Shared state passed between the nodes of the agent graph.

    The only key is ``messages``. Its annotation carries ``operator.add`` as
    the LangGraph reducer, so the list a node returns under ``messages`` is
    concatenated onto the existing history instead of replacing it.
    """

    # Running conversation history (model responses and tool results).
    messages: Annotated[List[BaseMessage], operator.add]
15
+
16
def create_agent_graph(vector_store, nvidia_api_key, google_api_key, google_cse_id):
    """Create and compile the LangGraph agent.

    The graph has two nodes: ``agent`` (the tool-bound chat model) and
    ``action`` (a ``ToolNode`` that executes requested tool calls). After
    each ``agent`` step the router sends control to ``action`` when the last
    message carries tool calls, otherwise the run ends. ``action`` always
    hands control back to ``agent``.

    Args:
        vector_store: Object exposing ``as_retriever(search_kwargs=...)``;
            queried by the paper Q&A tool.
        nvidia_api_key: API key used for the NVIDIA OpenAI-compatible endpoint.
        google_api_key: API key passed to ``GoogleSearchAPIWrapper``.
        google_cse_id: Custom Search Engine id passed to
            ``GoogleSearchAPIWrapper``.

    Returns:
        The compiled graph produced by ``StateGraph.compile()``.
    """

    # Compiled once per graph rather than on every tool call.
    # Matches leftover PDF-object residue that sometimes survives parsing.
    gibberish_pattern = re.compile(r'/DAN <[A-Fa-f0-9]+>')

    # 1. Define Tools
    # NOTE: the tool docstrings below are sent to the LLM as the tool
    # descriptions, so they are part of runtime behavior -- edit with care.
    @tool
    def paper_qa_tool(query: str) -> str:
        """
        Answers specific, detailed questions about scientific papers on graph theory,
        sparsity, and the pebble game. Use this for questions that reference specific
        paper details or concepts.
        """
        print("--- Calling Paper Q&A Tool ---")
        retriever = vector_store.as_retriever(search_kwargs={'k': 3})
        # `invoke` is the supported entry point; `get_relevant_documents`
        # is deprecated in current LangChain releases.
        context_docs = retriever.invoke(query)

        # Simple cleaning: drop any retrieved chunk that contains PDF residue.
        cleaned_docs = [
            doc for doc in context_docs
            if not gibberish_pattern.search(doc.page_content)
        ]

        if not cleaned_docs:
            return "No relevant information found in the documents after cleaning."

        return "\n\n".join(doc.page_content for doc in cleaned_docs)

    search_wrapper = GoogleSearchAPIWrapper(
        google_api_key=google_api_key,
        google_cse_id=google_cse_id,
    )

    @tool
    def web_search_tool(query: str) -> str:
        """
        Provides up-to-date answers from the web for general knowledge, definitions,
        or topics not covered in the local scientific papers. Also provides source links.
        """
        print("--- Calling Web Search Tool ---")
        results = search_wrapper.results(query, num_results=3)

        # When nothing is found the wrapper returns a single {"Result": ...}
        # entry with no title/link; skip such entries instead of raising
        # KeyError while formatting.
        hits = [res for res in results if "title" in res or "link" in res]
        if not hits:
            return "No relevant web results found."

        # "snippet" is optional in the wrapper's output, so every field is
        # read with a fallback.
        return "\n".join(
            f"Title: {res.get('title', 'N/A')}\n"
            f"Link: {res.get('link', 'N/A')}\n"
            f"Snippet: {res.get('snippet', 'N/A')}\n"
            for res in hits
        )

    tools = [paper_qa_tool, web_search_tool]
    tool_node = ToolNode(tools)

    # 2. Define the Model
    # We use ChatOpenAI pointed at the NVIDIA endpoint.
    # The base URL must not contain stray whitespace: a trailing space ends
    # up inside the request path and breaks every call.
    model = ChatOpenAI(
        model="meta/llama3-70b-instruct",
        openai_api_key=nvidia_api_key,
        openai_api_base="https://integrate.api.nvidia.com/v1",
        temperature=0.2,
    ).bind_tools(tools)

    # 3. Define Graph Nodes
    def call_model(state):
        """Invoke the tool-bound LLM on the message history.

        Returns a one-element ``messages`` update, which the state's
        ``operator.add`` reducer appends to the history.
        """
        print("--- AGENT: Thinking... ---")
        response = model.invoke(state["messages"])
        return {"messages": [response]}

    def should_continue(state):
        """Route after the agent step.

        Returns ``"continue"`` when the last message requests tool calls,
        otherwise ``"end"``.
        """
        last_message = state["messages"][-1]
        if last_message.tool_calls:
            return "continue"
        return "end"

    # 4. Build and Compile the Graph
    workflow = StateGraph(AgentState)
    workflow.add_node("agent", call_model)
    workflow.add_node("action", tool_node)

    workflow.set_entry_point("agent")
    workflow.add_conditional_edges(
        "agent",
        should_continue,
        {"continue": "action", "end": END},
    )
    # Tool results always go back to the model for the next decision.
    workflow.add_edge("action", "agent")

    return workflow.compile()