import re
import operator
from typing import Annotated, List, TypedDict

from langchain.tools import tool
from langchain_google_community import GoogleSearchAPIWrapper
from langchain_openai import ChatOpenAI
from langchain_core.messages import BaseMessage
from langgraph.graph import StateGraph, END
from langgraph.prebuilt import ToolNode


# This class defines the structure of the agent's state
class AgentState(TypedDict):
    messages: Annotated[List[BaseMessage], operator.add]
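
# The `operator.add` reducer makes LangGraph append each node's returned
# messages to the running history rather than replace it, so the full
# conversation accumulates across agent and tool turns.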

def create_agent_graph(vector_store, nvidia_api_key, google_api_key, google_cse_id):
    """Creates and compiles the LangGraph agent."""
    # 1. Define Tools
    @tool
    def paper_qa_tool(query: str) -> str:
        """
        Answers specific, detailed questions about scientific papers on graph theory,
        sparsity, and the pebble game. Use this for questions that reference specific
        paper details or concepts.
        """
        print("--- Calling Paper Q&A Tool ---")
        retriever = vector_store.as_retriever(search_kwargs={'k': 3})
        # `invoke` supersedes the deprecated `get_relevant_documents`
        context_docs = retriever.invoke(query)
        # Simple cleaning to drop chunks containing hex byte strings
        # (e.g. "/DAN <...>") that some PDF parsers leave behind
        gibberish_pattern = re.compile(r'/DAN <[A-Fa-f0-9]+>')
        cleaned_docs = [doc for doc in context_docs if not gibberish_pattern.search(doc.page_content)]
        if not cleaned_docs:
            return "No relevant information found in the documents after cleaning."
        return "\n\n".join(doc.page_content for doc in cleaned_docs)
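
    # Note: the tool returns raw context passages; the agent model synthesizes
    # the final answer from them on its next turn.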
    search_wrapper = GoogleSearchAPIWrapper(google_api_key=google_api_key, google_cse_id=google_cse_id)

    @tool
    def web_search_tool(query: str) -> str:
        """
        Provides up-to-date answers from the web for general knowledge, definitions,
        or topics not covered in the local scientific papers. Also provides source links.
        """
        print("--- Calling Web Search Tool ---")
        results = search_wrapper.results(query, num_results=3)
        # The wrapper signals "no hits" with a single {'Result': ...} entry,
        # which lacks the title/link/snippet keys formatted below
        if not results or "Result" in results[0]:
            return "No relevant web results found."
        return "\n".join(
            f"Title: {res['title']}\nLink: {res['link']}\nSnippet: {res['snippet']}\n" for res in results
        )
    tools = [paper_qa_tool, web_search_tool]
    tool_node = ToolNode(tools)

    # 2. Define the Model
    # We use ChatOpenAI pointed at the NVIDIA endpoint, which serves an
    # OpenAI-compatible chat API
    model = ChatOpenAI(
        model="meta/llama3-70b-instruct",
        openai_api_key=nvidia_api_key,
        openai_api_base="https://integrate.api.nvidia.com/v1",
        temperature=0.2,
    ).bind_tools(tools)
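
    # `bind_tools` attaches each tool's name, docstring, and argument schema to
    # every request, so the model can emit structured tool calls.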
    # 3. Define Graph Nodes
    def call_model(state):
        """The primary node that calls the LLM."""
        print("--- AGENT: Thinking... ---")
        response = model.invoke(state["messages"])
        return {"messages": [response]}

    def should_continue(state):
        """Router: decides whether to call a tool or end the conversation."""
        last_message = state["messages"][-1]
        if last_message.tool_calls:
            return "continue"
        return "end"
    # 4. Build and Compile the Graph
    workflow = StateGraph(AgentState)
    workflow.add_node("agent", call_model)
    workflow.add_node("action", tool_node)
    workflow.set_entry_point("agent")
    workflow.add_conditional_edges(
        "agent",
        should_continue,
        {"continue": "action", "end": END},
    )
    workflow.add_edge("action", "agent")
    return workflow.compile()
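

# Example usage (a minimal sketch): assumes a `vector_store` built elsewhere,
# e.g. a FAISS index over the papers, and the three API keys in hand; the
# question below is hypothetical.
#
#   from langchain_core.messages import HumanMessage
#
#   graph = create_agent_graph(vector_store, nvidia_api_key, google_api_key, google_cse_id)
#   final_state = graph.invoke(
#       {"messages": [HumanMessage(content="What is the (k, l)-pebble game?")]}
#   )
#   print(final_state["messages"][-1].content)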