File size: 4,940 Bytes
c8d9c8c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
"""LangGraph Agent"""
import os
import json
import getpass
from dotenv import load_dotenv
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import tools_condition, ToolNode
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_core.documents import Document
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_ollama import ChatOllama
from tools.math.multiply import multiply
from tools.math.add import add
from tools.math.subtract import subtract
from tools.math.divide import divide
from tools.math.modulus import modulus
from tools.math.power import power
from tools.math.square_root import square_root
from tools.search.arxiv_search import arxiv_search
from tools.search.web_search import web_search
from tools.search.wiki_search import wiki_search
from tools.file.analyze_csv_file import analyze_csv_file
from tools.file.analyze_excel_file import analyze_excel_file
from tools.file.analyze_image import analyze_image
from tools.file.download_file_from_url import download_file_from_url
from tools.file.save_content_to_file import save_content_to_file
# --- Load environment variables ---
# load_dotenv comes from python-dotenv; presumably it populates os.environ
# from a local .env file so that keys such as OPENAI_API_KEY (read by get_llm)
# are available -- confirm against the deployment setup.
load_dotenv()
# --- Constants ---
# Both paths are relative, so they are resolved against the current working
# directory of the process, not against this file's location.
# JSON Lines dataset read by load_vector_store (one JSON object per line).
DATASET_PATH = "dataset/metadata.jsonl"
# Plain-text system prompt read by load_system_prompt.
SYSTEM_PROMPT_PATH = "prompts/system_prompt.txt"
# Tools made available to the agent: build_graph binds this list to the LLM
# and passes the same list to the ToolNode that executes tool calls.
TOOLS = [
    # Math tools
    add,
    subtract,
    multiply,
    divide,
    modulus,
    power,
    square_root,
    # Search tools
    web_search,
    wiki_search,
    arxiv_search,
    # File tools
    analyze_csv_file,
    analyze_excel_file,
    analyze_image,
    download_file_from_url,
    save_content_to_file,
]
def load_vector_store() -> InMemoryVectorStore:
    """Build an in-memory vector store from the dataset examples.

    Every non-blank line of ``DATASET_PATH`` is parsed as a JSON object and
    converted into a document whose text is the question followed by its
    final answer, with the entry's ``task_id`` stored as the ``source``
    metadata field.

    Returns:
        An ``InMemoryVectorStore`` backed by ``OpenAIEmbeddings`` holding one
        document per dataset entry.

    Raises:
        FileNotFoundError: If ``DATASET_PATH`` does not exist.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If an entry lacks ``Question``, ``Final answer`` or
            ``task_id``.
    """
    if not os.path.exists(DATASET_PATH):
        raise FileNotFoundError(f"Dataset not found at {DATASET_PATH}.")

    # NOTE(review): embeddings always come from OpenAI, regardless of the
    # chat provider chosen in build_graph.
    embeddings = OpenAIEmbeddings()
    vector_store = InMemoryVectorStore(embeddings)

    documents = []
    with open(DATASET_PATH, "r", encoding="utf-8") as f:
        for line in f:
            # JSONL files frequently end with (or contain) blank lines;
            # json.loads would raise on them, so skip them explicitly.
            if not line.strip():
                continue
            entry = json.loads(line)
            content = (
                f"Question: {entry['Question']}\nFinal answer: {entry['Final answer']}"
            )
            documents.append(
                Document(page_content=content, metadata={"source": entry["task_id"]})
            )

    # Nothing to embed for an empty dataset; avoid a pointless embedding call.
    if documents:
        vector_store.add_documents(documents)
    return vector_store
def get_llm(provider: str):
    """Return the chat model for the requested provider.

    Args:
        provider: Either ``"openai"`` or ``"ollama"`` (matched exactly).

    Returns:
        A ``ChatOllama`` instance for ``"ollama"``, or a ``ChatOpenAI``
        instance for ``"openai"``. Both are created with ``temperature=0``.

    Raises:
        ValueError: If ``provider`` is neither ``"openai"`` nor ``"ollama"``.
    """
    if provider == "ollama":
        return ChatOllama(model="llama3", temperature=0)

    if provider != "openai":
        raise ValueError("Unsupported provider: choose 'openai' or 'ollama'")

    # Ask interactively for the key only when the environment has none
    # (unset or empty), and store it in the environment for the client.
    key_is_set = bool(os.environ.get("OPENAI_API_KEY"))
    if not key_is_set:
        os.environ["OPENAI_API_KEY"] = getpass.getpass("Enter OpenAI API key: ")
    return ChatOpenAI(model="gpt-4.1", temperature=0)
def load_system_prompt() -> SystemMessage:
    """Read the system prompt file and wrap its text in a SystemMessage.

    Returns:
        A ``SystemMessage`` whose content is the full text of
        ``SYSTEM_PROMPT_PATH`` (decoded as UTF-8).

    Raises:
        FileNotFoundError: If ``SYSTEM_PROMPT_PATH`` does not exist.
    """
    prompt_is_missing = not os.path.exists(SYSTEM_PROMPT_PATH)
    if prompt_is_missing:
        raise FileNotFoundError(f"System prompt not found at {SYSTEM_PROMPT_PATH}.")

    with open(SYSTEM_PROMPT_PATH, "r", encoding="utf-8") as prompt_file:
        prompt_text = prompt_file.read()
        return SystemMessage(content=prompt_text)
def build_graph(provider: str = "openai"):
    """Build and compile the LangGraph agent.

    The graph runs ``retriever`` first, then ``assistant``. ``tools_condition``
    routes from ``assistant`` to the ``tools`` node when the model requests a
    tool call; ``tools`` always hands control back to ``assistant``.

    The system prompt is prepended to the conversation each time the model
    is invoked rather than being written into the graph state. The
    ``MessagesState`` reducer replaces messages whose id is already known in
    place and appends unseen ones, so a system message returned from a node
    would be stored *after* the user's message instead of before it. As a
    consequence the final state contains the user message, the retrieved
    examples and the model/tool messages, but not the system prompt itself.

    Args:
        provider: LLM provider name forwarded to ``get_llm``.

    Returns:
        The compiled graph.

    Raises:
        ValueError: If ``provider`` is not supported by ``get_llm``.
        FileNotFoundError: If the dataset or the system prompt file is
            missing.
    """
    llm = get_llm(provider).bind_tools(TOOLS)
    vector_store = load_vector_store()
    system_msg = load_system_prompt()

    def retriever(state: MessagesState):
        """Add similar dataset examples for the first message in the state."""
        query = state["messages"][0].content
        similar = vector_store.similarity_search(query, k=3)
        if not similar:
            # No state change: the conversation is passed on untouched.
            return {"messages": []}
        refs = "\n\n".join(doc.page_content for doc in similar)
        example_msg = HumanMessage(content=f"Here are similar examples:\n\n{refs}")
        # Return only the new message; the reducer appends it to the state.
        return {"messages": [example_msg]}

    def assistant(state: MessagesState):
        """Call the LLM with the system prompt followed by the conversation."""
        # Prepending here guarantees the system prompt is the first message
        # the model sees on every turn, including after tool calls.
        response = llm.invoke([system_msg] + state["messages"])
        return {"messages": [response]}

    # --- Build graph ---
    graph = StateGraph(MessagesState)
    graph.add_node("retriever", retriever)
    graph.add_node("assistant", assistant)
    graph.add_node("tools", ToolNode(TOOLS))
    graph.add_edge(START, "retriever")
    graph.add_edge("retriever", "assistant")
    graph.add_conditional_edges("assistant", tools_condition)
    graph.add_edge("tools", "assistant")
    return graph.compile()
def run_agent(query: str, provider: str = "openai"):
    """Build the agent, run it on a single query and print the conversation.

    Args:
        query: The user's question, sent as the only initial message.
        provider: LLM provider name forwarded to ``build_graph``.
    """
    agent = build_graph(provider)
    initial_state = {"messages": [HumanMessage(content=query)]}
    final_state = agent.invoke(initial_state)
    # Print every message held in the final state, in order.
    for message in final_state["messages"]:
        message.pretty_print()
# --- Run locally ---
if __name__ == "__main__":
    # Read one question from stdin and answer it with the default provider.
    run_agent(input("Enter your question: "))
|