sebastianfrench commited on
Commit
a4def59
·
1 Parent(s): f6559d0

add search agent

Browse files
.gitignore CHANGED
@@ -1,2 +1,4 @@
1
  .vscode
2
  .venv
 
 
 
1
  .vscode
2
  .venv
3
+ *.env
4
+ __pycache__/
agents/agent.py CHANGED
@@ -1,14 +1,10 @@
1
- import os
2
  from typing import TypedDict, List, Dict, Any, Optional
3
  from langgraph.graph import StateGraph, START, END
4
  from langchain_anthropic import ChatAnthropic
5
  from langchain_groq import ChatGroq
6
  from langchain_core.messages import HumanMessage
7
  import getpass
8
-
9
- if "ANTHROPIC_API_KEY" not in os.environ:
10
- os.environ["ANTHROPIC_API_KEY"] = getpass.getpass("Enter your Anthropic API key: ")
11
-
12
  class EmailState(TypedDict):
13
  # The email being processed
14
  email: Dict[str, Any] # Contains subject, sender, body, etc.
 
 
1
  from typing import TypedDict, List, Dict, Any, Optional
2
  from langgraph.graph import StateGraph, START, END
3
  from langchain_anthropic import ChatAnthropic
4
  from langchain_groq import ChatGroq
5
  from langchain_core.messages import HumanMessage
6
  import getpass
7
+
 
 
 
8
  class EmailState(TypedDict):
9
  # The email being processed
10
  email: Dict[str, Any] # Contains subject, sender, body, etc.
{models → agents}/basic_agent.py RENAMED
File without changes
agents/search_agent.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
from graphs.search import build_workflow
from langchain_core.messages import HumanMessage


class SearchAgent:
    """Agent that answers a question by running the LangGraph search workflow.

    The compiled workflow is stateless between questions, so it is built once
    at construction time instead of on every call (the original rebuilt the
    graph per question).
    """

    def __init__(self):
        # Build and compile the graph once; reused by every __call__.
        self._workflow = build_workflow()
        print("SearchAgent initialized.")

    def __call__(self, question: str) -> str:
        """Run *question* through the workflow and return the final answer text."""
        print(f"Agent received question (first 50 chars): {question[:50]}...")
        result = self._workflow.invoke({"messages": [HumanMessage(question)]})
        # The last message in the conversation is the model's final answer.
        return result["messages"][-1].content
agents/tools/search.py DELETED
File without changes
app.py CHANGED
@@ -3,7 +3,7 @@ import gradio as gr
3
  import requests
4
  import inspect
5
  import pandas as pd
6
- from models.basic_agent import BasicAgent
7
 
8
  # (Keep Constants as is)
9
  # --- Constants ---
@@ -31,7 +31,7 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
31
 
32
  # 1. Instantiate Agent ( modify this part to create your agent)
33
  try:
34
- agent = BasicAgent()
35
  except Exception as e:
36
  print(f"Error instantiating agent: {e}")
37
  return f"Error initializing agent: {e}", None
 
3
  import requests
4
  import inspect
5
  import pandas as pd
6
+ from agents.search_agent import SearchAgent
7
 
8
  # (Keep Constants as is)
9
  # --- Constants ---
 
31
 
32
  # 1. Instantiate Agent ( modify this part to create your agent)
33
  try:
34
+ agent = SearchAgent()
35
  except Exception as e:
36
  print(f"Error instantiating agent: {e}")
37
  return f"Error initializing agent: {e}", None
{agents/graphs → graphs}/__init__.py RENAMED
File without changes
graphs/search.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
from models.models import groq_model
from tools.search import arxiv_search, web_search, wikipedia_search
from langgraph.graph import StateGraph, START, END, MessagesState
from langgraph.prebuilt import ToolNode
from langchain_core.messages import HumanMessage

# Search tools the agent may call.
tools = [arxiv_search, web_search, wikipedia_search]

# Node that executes tool calls, and the model with tool schemas attached.
tool_node = ToolNode(tools)
bound_model = groq_model.bind_tools(tools)
def call_model(state: MessagesState):
    """Invoke the tool-bound model on the conversation accumulated so far."""
    reply = bound_model.invoke(state["messages"])
    # LangGraph appends the returned message(s) to the existing message list.
    return {"messages": reply}
def should_continue(state: MessagesState):
    """Route to the tool node while the latest message still requests tools."""
    # Pending tool calls on the newest message -> run the "action" node.
    if state["messages"][-1].tool_calls:
        return "action"
    # Otherwise the model produced a final answer; stop the graph.
    return END
def build_workflow():
    """Assemble and compile the agent <-> tool search graph."""
    graph = StateGraph(MessagesState)
    graph.add_node("agent", call_model)
    graph.add_node("action", tool_node)

    # Start at the model; loop through the tool node until should_continue
    # reports that no more tool calls are pending.
    graph.add_edge(START, "agent")
    graph.add_conditional_edges("agent", should_continue)
    graph.add_edge("action", "agent")
    return graph.compile()
if __name__ == "__main__":
    question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
    # Compile the graph, run the question through it, then dump the whole
    # conversation (including tool calls and tool results).
    graph = build_workflow()
    final_state = graph.invoke({"messages": [HumanMessage(content=question)]})
    for message in final_state["messages"]:
        message.pretty_print()
models/models.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
from langchain_anthropic import ChatAnthropic
from langchain_groq import ChatGroq
from dotenv import load_dotenv

# Load environment variables (presumably the provider API keys) from a local
# .env file before constructing the clients.
load_dotenv()

# Shared chat models; temperature 0 for deterministic outputs.
anthropic_model = ChatAnthropic(model="claude-3-5-haiku-latest", temperature=0)

groq_model = ChatGroq(model="meta-llama/llama-4-maverick-17b-128e-instruct", temperature=0)
{agents/tools → tools}/__init__.py RENAMED
File without changes
tools/search.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langchain_core.tools import tool
2
+ from langchain_community.tools.tavily_search import TavilySearchResults
3
+ from langchain_community.document_loaders import WikipediaLoader
4
+ from langchain_community.document_loaders import ArxivLoader
5
+ from dotenv import load_dotenv
6
+
7
+ load_dotenv()
8
+
@tool
def wikipedia_search(query: str) -> dict:
    """Search Wikipedia for a query and return a maximum of 2 documents.

    Args:
        query: The search query.
    """
    # NOTE: load_max_docs=2, so up to two articles come back (the original
    # docstring incorrectly said one).
    search_docs = WikipediaLoader(query=query, load_max_docs=2).load()

    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{doc.metadata["source"]}"/>\n{doc.page_content}\n</Document>'
        for doc in search_docs
        # Drop the "Closed-ended question" article Wikipedia search tends to
        # return for question-shaped queries.
        if "Closed-ended question" not in doc.metadata.get("title", "")
    )

    # Returns a dict (not str); annotation fixed to match the actual value.
    return {"wiki_results": formatted_search_docs}
@tool
def web_search(query: str) -> dict:
    """Search Tavily for a query and return a maximum of 1 result.

    Args:
        query: The search query.
    """
    search_docs = TavilySearchResults(max_results=1).invoke(input=query)

    # Bug fix: the original emitted a stray extra double quote after the
    # source attribute (source="...""/>), producing malformed markup.
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{doc["url"]}"/>\n{doc["content"]}\n</Document>'
        for doc in search_docs
    )

    # Returns a dict (not str); annotation fixed to match the actual value.
    return {"web_results": formatted_search_docs}
@tool
def arxiv_search(query: str) -> dict:
    """Search Arxiv for a query and return a maximum of 3 results.

    Args:
        query: The search query.
    """
    search_docs = ArxivLoader(query=query, load_max_docs=3).load()
    # Removed leftover debug print(search_docs) — it dumped full papers to
    # stdout on every call.
    # Bug fix: stray extra double quote after the source attribute removed.
    # Only the first 1000 chars of each paper, to keep the tool output small.
    formatted_search_docs = "\n\n---\n\n".join(
        f'<Document source="{doc.metadata["Title"]}"/>\n{doc.page_content[:1000]}\n</Document>'
        for doc in search_docs
    )
    # Returns a dict (not str); annotation fixed to match the actual value.
    return {"arxiv_results": formatted_search_docs}