File size: 4,940 Bytes
c8d9c8c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
"""LangGraph Agent"""

import os
import json
import getpass
from dotenv import load_dotenv

from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import tools_condition, ToolNode

from langchain_core.messages import SystemMessage, HumanMessage
from langchain_core.vectorstores import InMemoryVectorStore
from langchain_core.documents import Document
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_ollama import ChatOllama

from tools.math.multiply import multiply
from tools.math.add import add
from tools.math.subtract import subtract
from tools.math.divide import divide
from tools.math.modulus import modulus
from tools.math.power import power
from tools.math.square_root import square_root

from tools.search.arxiv_search import arxiv_search
from tools.search.web_search import web_search
from tools.search.wiki_search import wiki_search

from tools.file.analyze_csv_file import analyze_csv_file
from tools.file.analyze_excel_file import analyze_excel_file
from tools.file.analyze_image import analyze_image
from tools.file.download_file_from_url import download_file_from_url
from tools.file.save_content_to_file import save_content_to_file

# --- Load environment variables ---
# Pulls API keys (e.g. OPENAI_API_KEY) from a local .env file, if present.
load_dotenv()

# --- Constants ---
# JSONL file of solved examples; each line must provide the keys
# "Question", "Final answer", and "task_id" (consumed by load_vector_store).
DATASET_PATH = "dataset/metadata.jsonl"
# Plain-text file whose entire contents become the agent's system prompt.
SYSTEM_PROMPT_PATH = "prompts/system_prompt.txt"
# Every tool exposed to the LLM: bound via bind_tools() and executed by ToolNode.
TOOLS = [
    add,
    subtract,
    multiply,
    divide,
    modulus,
    power,
    square_root,
    web_search,
    wiki_search,
    arxiv_search,
    analyze_csv_file,
    analyze_excel_file,
    analyze_image,
    download_file_from_url,
    save_content_to_file,
]


def load_vector_store() -> InMemoryVectorStore:
    """Build an in-memory vector store seeded with solved dataset examples.

    Reads DATASET_PATH (JSONL: one JSON object per line with keys
    "Question", "Final answer", and "task_id") and embeds each example with
    OpenAIEmbeddings so the retriever node can surface similar solved
    questions at run time.

    Returns:
        InMemoryVectorStore populated with one Document per dataset entry;
        each document's metadata carries the originating task_id as "source".

    Raises:
        FileNotFoundError: if DATASET_PATH does not exist.
        KeyError: if a dataset line lacks one of the required keys.
        json.JSONDecodeError: if a non-empty line is not valid JSON.
    """
    if not os.path.exists(DATASET_PATH):
        raise FileNotFoundError(f"Dataset not found at {DATASET_PATH}.")
    embeddings = OpenAIEmbeddings()
    vector_store = InMemoryVectorStore(embeddings)
    documents = []
    with open(DATASET_PATH, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank and trailing lines — common in hand-edited
                # JSONL — which would otherwise crash json.loads.
                continue
            entry = json.loads(line)
            content = (
                f"Question: {entry['Question']}\nFinal answer: {entry['Final answer']}"
            )
            documents.append(
                Document(page_content=content, metadata={"source": entry["task_id"]})
            )
    vector_store.add_documents(documents)
    return vector_store


def get_llm(provider: str):
    """Return a chat model instance for the requested backend.

    Args:
        provider: Either "openai" (prompts for an API key if OPENAI_API_KEY
            is unset) or "ollama" (local llama3 model).

    Raises:
        ValueError: if *provider* is not one of the supported backends.
    """
    if provider == "ollama":
        return ChatOllama(model="llama3", temperature=0)
    if provider == "openai":
        # Interactively collect the key only when the environment lacks it.
        if not os.environ.get("OPENAI_API_KEY"):
            os.environ["OPENAI_API_KEY"] = getpass.getpass("Enter OpenAI API key: ")
        return ChatOpenAI(model="gpt-4.1", temperature=0)
    raise ValueError("Unsupported provider: choose 'openai' or 'ollama'")


def load_system_prompt() -> SystemMessage:
    """Read the agent's system prompt from SYSTEM_PROMPT_PATH.

    Returns:
        SystemMessage wrapping the file's full text content.

    Raises:
        FileNotFoundError: if the prompt file is missing.
    """
    if not os.path.exists(SYSTEM_PROMPT_PATH):
        raise FileNotFoundError(f"System prompt not found at {SYSTEM_PROMPT_PATH}.")
    with open(SYSTEM_PROMPT_PATH, "r", encoding="utf-8") as handle:
        prompt_text = handle.read()
    return SystemMessage(content=prompt_text)


def build_graph(provider: str = "openai"):
    """Assemble and compile the retriever -> assistant -> tools agent graph.

    Args:
        provider: LLM backend passed to get_llm(); "openai" or "ollama".

    Returns:
        A compiled LangGraph runnable operating on MessagesState.
    """
    model = get_llm(provider).bind_tools(TOOLS)
    examples_store = load_vector_store()
    system_message = load_system_prompt()

    def retriever(state: MessagesState):
        """Prepend the system prompt; append up to 3 similar solved examples."""
        user_query = state["messages"][0].content
        matches = examples_store.similarity_search(user_query, k=3)
        messages = [system_message] + state["messages"]
        if matches:
            references = "\n\n".join(doc.page_content for doc in matches)
            messages = messages + [
                HumanMessage(content=f"Here are similar examples:\n\n{references}")
            ]
        return {"messages": messages}

    def assistant(state: MessagesState):
        """Ask the LLM for the next message (may contain tool calls)."""
        return {"messages": [model.invoke(state["messages"])]}

    # --- Wire the graph: retrieve examples once, then loop assistant/tools ---
    workflow = StateGraph(MessagesState)
    workflow.add_node("retriever", retriever)
    workflow.add_node("assistant", assistant)
    workflow.add_node("tools", ToolNode(TOOLS))

    workflow.add_edge(START, "retriever")
    workflow.add_edge("retriever", "assistant")
    # tools_condition routes to "tools" when the LLM requested a tool call,
    # otherwise ends the run.
    workflow.add_conditional_edges("assistant", tools_condition)
    workflow.add_edge("tools", "assistant")

    return workflow.compile()


def run_agent(query: str, provider: str = "openai"):
    """Build the agent graph, run it on *query*, and pretty-print the transcript.

    Args:
        query: The user's question, sent as the initial HumanMessage.
        provider: LLM backend forwarded to build_graph().
    """
    agent = build_graph(provider)
    result = agent.invoke({"messages": [HumanMessage(content=query)]})
    for message in result["messages"]:
        message.pretty_print()


# --- Run locally ---
if __name__ == "__main__":
    # Interactive CLI entry point: prompt for a question and run the agent
    # with the default "openai" provider.
    user_query = input("Enter your question: ")
    run_agent(user_query)