ktluege committed on
Commit
3be3277
Β·
verified Β·
1 Parent(s): 81917a3

Create agent.py

Browse files
Files changed (1) hide show
  1. agent.py +71 -0
agent.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # agent.py
2
+ import os
3
+ from dotenv import load_dotenv
4
+ from langgraph.graph import StateGraph, MessagesState
5
+ from langchain_community.tools.tavily_search import TavilySearchResults
6
+ from langchain_community.document_loaders import WikipediaLoader
7
+ from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
8
+ from langchain_core.tools import tool
9
+
10
+ # Load environment variables (API keys, etc.)
11
+ load_dotenv()
12
+
13
+ # === TOOLS ===
14
+
15
+ @tool
16
+ def add(a: int, b: int) -> int:
17
+ """Add two numbers."""
18
+ return a + b
19
+
20
+ @tool
21
+ def wiki_search(query: str) -> str:
22
+ """Search Wikipedia and return up to 2 results."""
23
+ search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
24
+ formatted = "\n\n---\n\n".join(
25
+ [f"{doc.metadata}\n{doc.page_content[:500]}" for doc in search_docs]
26
+ )
27
+ return {"wiki_results": formatted}
28
+
29
+ @tool
30
+ def web_search(query: str) -> str:
31
+ """Search Tavily for a query and return max 3 results."""
32
+ search_docs = TavilySearchResults(max_results=3).invoke(query=query)
33
+ formatted = "\n\n---\n\n".join(
34
+ [f"{doc.metadata}\n{doc.page_content[:500]}" for doc in search_docs]
35
+ )
36
+ return {"web_results": formatted}
37
+
38
+ tools = [add, wiki_search, web_search]
39
+
40
+ # === SYSTEM PROMPT ===
41
+ with open("system_prompt.txt", "r", encoding="utf-8") as f:
42
+ system_prompt = f.read()
43
+ sys_msg = SystemMessage(content=system_prompt)
44
+
45
+ # === GRAPH BUILD ===
46
+ def build_graph():
47
+ # For demo, use OpenAI API if key provided; else default to Gemini
48
+ from langchain_openai import ChatOpenAI
49
+ openai_api_key = os.environ.get("OPENAI_API_KEY")
50
+ if openai_api_key:
51
+ llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0, openai_api_key=openai_api_key)
52
+ else:
53
+ # Gemini fallback (or any other, e.g., groq, etc.)
54
+ from langchain_google_genai import ChatGoogleGenerativeAI
55
+ llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash", temperature=0)
56
+
57
+ # Bind tools to LLM
58
+ llm_with_tools = llm.bind_tools(tools)
59
+
60
+ # Simple assistant node
61
+ def assistant(state: MessagesState):
62
+ # Always include system prompt!
63
+ messages = [sys_msg] + state["messages"]
64
+ return {"messages": [llm_with_tools.invoke(messages)]}
65
+
66
+ builder = StateGraph(MessagesState)
67
+ builder.add_node("assistant", assistant)
68
+ builder.set_entry_point("assistant")
69
+ builder.set_finish_point("assistant")
70
+
71
+ return builder.compile()