File size: 3,324 Bytes
af5a423
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7bf44ed
af5a423
 
 
 
 
fbbbedb
7bf44ed
 
af5a423
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
02af49f
af5a423
 
 
 
 
 
 
 
 
 
 
 
 
 
02af49f
af5a423
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import os
from dotenv import load_dotenv
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.prebuilt import tools_condition, ToolNode
from langchain_openai import ChatOpenAI
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
from omegaconf import OmegaConf
from .tools import *


def load_config(config_path: str):
    """Load an OmegaConf configuration from a YAML file.

    Args:
        config_path: Path to the YAML configuration file.

    Returns:
        The parsed OmegaConf configuration object.
    """
    return OmegaConf.load(config_path)

# --- Constants ---
# NOTE(review): path is resolved relative to the current working directory,
# so the script must be launched from the project root — confirm.
CONFIG = load_config("config.yaml")
SYSTEM_PROMPT = CONFIG["system_prompt"]["custom"]

# Load environment variables from .env file (e.g. OPENAI_API_KEY, read below).
load_dotenv()


class LangGraphAgent4GAIA:
    """LangGraph ReAct-style agent for the GAIA benchmark.

    Wires a tool-bound LLM (OpenAI or HuggingFace) into a two-node
    LangGraph state machine: an ``assistant`` node that calls the LLM,
    and a ``tools`` node that executes any tool calls it emits. The loop
    ends when the LLM answers without requesting a tool.
    """

    def __init__(self, model_provider: str, model_name: str):
        """Build the compiled agent graph.

        Args:
            model_provider: Either ``"openai"`` or ``"huggingface"``.
            model_name: OpenAI model name or HuggingFace repo id.
        """
        self.sys_prompt = SystemMessage(content=SYSTEM_PROMPT)
        self.graph = self.get_agent(model_provider, model_name)

    def assistant(self, state: MessagesState):
        """Assistant node: invoke the tool-bound LLM on the conversation.

        The system prompt is prepended on every call so it is always the
        first message the model sees, regardless of history length.
        """
        return {"messages": [self.llm_with_tools.invoke([self.sys_prompt] + state["messages"])]}

    def _build_llm(self, provider: str, model_name: str):
        """Instantiate the chat model for the given provider.

        Args:
            provider: ``"openai"`` or ``"huggingface"``.
            model_name: Model identifier for the chosen provider.

        Raises:
            ValueError: If *provider* is not one of the supported values.
        """
        if provider == "openai":
            return ChatOpenAI(
                model=model_name,
                temperature=0,  # deterministic answers for benchmark runs
                max_retries=2,
                api_key=os.getenv("OPENAI_API_KEY")
            )
        if provider == "huggingface":
            # NOTE(review): relies on HF credentials being picked up from the
            # environment by HuggingFaceEndpoint — confirm a token is set.
            return ChatHuggingFace(
                llm=HuggingFaceEndpoint(
                    repo_id=model_name,
                    task="text-generation",
                    max_new_tokens=1024,
                    do_sample=False,  # greedy decoding, matches temperature=0
                    repetition_penalty=1.03,
                    temperature=0
                ),
                verbose=True
            )
        raise ValueError("Invalid provider. Choose 'openai' or 'huggingface'.")

    def get_agent(self, provider: str, model_name: str):
        """Assemble and compile the LangGraph agent graph.

        Side effect: stores the tool-bound LLM on ``self.llm_with_tools``,
        which the ``assistant`` node reads at invocation time — this must
        run before the graph is invoked (``__init__`` guarantees it).

        Args:
            provider: ``"openai"`` or ``"huggingface"``.
            model_name: Model identifier for the chosen provider.

        Returns:
            The compiled LangGraph runnable.
        """
        # Tools are imported from .tools via the module-level star import.
        tools = [
            multiply,
            add,
            add_list,
            subtract,
            divide,
            modulo,
            web_search,
            arxiv_search,
            wiki_search,
            read_xlsx_file,
            get_python_file
        ]

        llm = self._build_llm(provider, model_name)

        # Bind tools; sequential tool calls keep the ReAct loop simple.
        self.llm_with_tools = llm.bind_tools(tools, parallel_tool_calls=False)

        # START -> assistant -> (tools_condition) -> tools -> assistant ...
        # tools_condition routes to END when no tool call is requested.
        builder = StateGraph(MessagesState)
        builder.add_node("assistant", self.assistant)
        builder.add_node("tools", ToolNode(tools))
        builder.add_edge(START, "assistant")
        builder.add_conditional_edges(
            "assistant",
            tools_condition,
        )
        builder.add_edge("tools", "assistant")

        # Compile graph
        return builder.compile()


if __name__ == "__main__":
    # Smoke test: render the graph topology to a PNG, then answer one question.
    from langchain_core.runnables.graph import MermaidDrawMethod

    question = "What is the capital of Spain?"
    # Build the graph from the provider/model configured in config.yaml
    agent_manager = LangGraphAgent4GAIA(CONFIG["model"]["provider"], CONFIG["model"]["name"])
    # MermaidDrawMethod.API renders via a remote web service (needs network).
    img_data = agent_manager.graph.get_graph().draw_mermaid_png(draw_method=MermaidDrawMethod.API)
    # assumes the agentic/ directory already exists — TODO confirm
    with open('agentic/graph.png', "wb") as f:
        f.write(img_data)

    # Run the graph; recursion_limit caps assistant<->tools round trips.
    messages = [HumanMessage(content=question)]
    messages = agent_manager.graph.invoke({"messages": messages}, {"recursion_limit": 50})
    for m in messages["messages"]:
        m.pretty_print()