"""LangGraph Agent"""

import getpass
import json
import os

from dotenv import load_dotenv

from langgraph.graph import START, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode, tools_condition

from langchain_core.documents import Document
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_ollama import ChatOllama
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

from tools.math.add import add
from tools.math.divide import divide
from tools.math.modulus import modulus
from tools.math.multiply import multiply
from tools.math.power import power
from tools.math.square_root import square_root
from tools.math.subtract import subtract

from tools.search.arxiv_search import arxiv_search
from tools.search.web_search import web_search
from tools.search.wiki_search import wiki_search

from tools.file.analyze_csv_file import analyze_csv_file
from tools.file.analyze_excel_file import analyze_excel_file
from tools.file.analyze_image import analyze_image
from tools.file.download_file_from_url import download_file_from_url
from tools.file.save_content_to_file import save_content_to_file

load_dotenv()

DATASET_PATH = "dataset/metadata.jsonl"
SYSTEM_PROMPT_PATH = "prompts/system_prompt.txt"

TOOLS = [
    add,
    subtract,
    multiply,
    divide,
    modulus,
    power,
    square_root,
    web_search,
    wiki_search,
    arxiv_search,
    analyze_csv_file,
    analyze_excel_file,
    analyze_image,
    download_file_from_url,
    save_content_to_file,
]

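# Each record in dataset/metadata.jsonl is expected to carry at least the
# three keys read below ("task_id", "Question", "Final answer"); a
# hypothetical example line, with invented values:
#
#   {"task_id": "task-0001", "Question": "What is 2 + 2?", "Final answer": "4"}
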
def load_vector_store() -> InMemoryVectorStore:
    """Load vector store with dataset examples."""
    if not os.path.exists(DATASET_PATH):
        raise FileNotFoundError(f"Dataset not found at {DATASET_PATH}.")

    # Note: embeddings always go through OpenAI, so OPENAI_API_KEY is needed
    # even when the chat model runs on Ollama.
    embeddings = OpenAIEmbeddings()
    vector_store = InMemoryVectorStore(embeddings)

    # One document per Q/A pair; task_id is kept so retrieved examples can be
    # traced back to the dataset.
    documents = []
    with open(DATASET_PATH, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:  # tolerate blank/trailing lines in the JSONL file
                continue
            entry = json.loads(line)
            content = (
                f"Question: {entry['Question']}\nFinal answer: {entry['Final answer']}"
            )
            doc = Document(page_content=content, metadata={"source": entry["task_id"]})
            documents.append(doc)

    vector_store.add_documents(documents)
    return vector_store

def get_llm(provider: str):
    """Get LLM instance based on provider."""
    if provider == "openai":
        if not os.environ.get("OPENAI_API_KEY"):
            os.environ["OPENAI_API_KEY"] = getpass.getpass("Enter OpenAI API key: ")
        return ChatOpenAI(model="gpt-4.1", temperature=0)
    elif provider == "ollama":
        return ChatOllama(model="llama3", temperature=0)
    else:
        raise ValueError("Unsupported provider: choose 'openai' or 'ollama'")

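# Sketch: the provider could also be chosen via the environment instead of a
# hard-coded argument (AGENT_PROVIDER is a hypothetical variable, not read
# anywhere else in this module):
#
#   llm = get_llm(os.environ.get("AGENT_PROVIDER", "openai"))
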
def load_system_prompt() -> SystemMessage:
    """Load system prompt from file."""
    if not os.path.exists(SYSTEM_PROMPT_PATH):
        raise FileNotFoundError(f"System prompt not found at {SYSTEM_PROMPT_PATH}.")
    with open(SYSTEM_PROMPT_PATH, "r", encoding="utf-8") as f:
        return SystemMessage(content=f.read())

def build_graph(provider: str = "openai"):
    """Build and compile the LangGraph agent."""
    llm = get_llm(provider).bind_tools(TOOLS)
    vector_store = load_vector_store()
    system_msg = load_system_prompt()

    def retriever(state: MessagesState):
        """Retrieve similar examples based on user query."""
        # The original user question is always the first message in the state.
        query = state["messages"][0].content
        similar = vector_store.similarity_search(query, k=3)
        if similar:
            refs = "\n\n".join(doc.page_content for doc in similar)
            example_msg = HumanMessage(content=f"Here are similar examples:\n\n{refs}")
            # add_messages merges by message ID, so re-listing the existing
            # messages here updates them rather than duplicating them.
            return {"messages": [system_msg] + state["messages"] + [example_msg]}
        return {"messages": [system_msg] + state["messages"]}

    def assistant(state: MessagesState):
        """Call LLM to generate next message."""
        response = llm.invoke(state["messages"])
        return {"messages": [response]}

    graph = StateGraph(MessagesState)
    graph.add_node("retriever", retriever)
    graph.add_node("assistant", assistant)
    graph.add_node("tools", ToolNode(TOOLS))

    graph.add_edge(START, "retriever")
    graph.add_edge("retriever", "assistant")
    # tools_condition routes to the "tools" node when the last message contains
    # tool calls, and to END otherwise.
    graph.add_conditional_edges("assistant", tools_condition)
    graph.add_edge("tools", "assistant")

    return graph.compile()

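# A minimal usage sketch for the compiled graph (assumes OPENAI_API_KEY is set
# and that dataset/metadata.jsonl and prompts/system_prompt.txt exist).
# .stream() yields one state update per node, which makes the
# retriever -> assistant -> tools loop easy to inspect:
#
#   agent = build_graph()
#   for step in agent.stream({"messages": [HumanMessage(content="What is 3 * 7?")]}):
#       print(step)
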
def run_agent(query: str, provider: str = "openai"):
    """Run the agent on a given query."""
    graph = build_graph(provider)
    messages = [HumanMessage(content=query)]
    result = graph.invoke({"messages": messages})
    for msg in result["messages"]:
        msg.pretty_print()

if __name__ == "__main__":
    user_query = input("Enter your question: ")
    run_agent(user_query)