File size: 5,286 Bytes
70dc500
0919908
35fc704
b6a5916
d1b27f1
 
996f53d
b6a5916
d1b27f1
b6a5916
 
d1b27f1
 
0919908
 
b6a5916
 
 
70dc500
b6a5916
 
996f53d
 
 
 
 
243b5b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a0d0488
243b5b7
 
 
 
 
 
7b77919
35fc704
243b5b7
8f0d578
7b77919
243b5b7
 
 
 
 
 
 
 
 
 
 
369866b
01582bf
243b5b7
 
7b77919
243b5b7
 
 
 
 
 
 
 
 
 
d1b27f1
243b5b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b6a5916
996f53d
243b5b7
7b77919
 
996f53d
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
from typing import Annotated, TypedDict
import os

from langchain_google_genai.chat_models import ChatGoogleGenerativeAI
from langchain_openai.chat_models import ChatOpenAI
from langchain_anthropic.chat_models import ChatAnthropic
from langchain_groq.chat_models import ChatGroq
from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage
from langgraph.graph import StateGraph, START
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition
from langfuse import Langfuse
from langfuse.langchain import CallbackHandler

from tools import *



class AgentState(TypedDict):
    """State schema for the agent graph (passed to ``StateGraph``)."""

    # Conversation history. The add_messages reducer merges the messages a node
    # returns into the existing list rather than replacing the whole value.
    messages: Annotated[list[AnyMessage], add_messages]


# for local testing
# from dotenv import load_dotenv
# load_dotenv()

class Gaia_Agent:
    """Tool-calling LangGraph agent whose chat model is selected via environment variables.

    Required environment variables (read in ``__init__``): ``MODEL_FAMILY``,
    ``MODEL_NAME``, ``LLM_API_KEY``, ``LANGFUSE_HOST``, ``LANGFUSE_SECRET_KEY``
    and ``LANGFUSE_PUBLIC_KEY``.

    Usage: construct the agent, optionally call ``build_graph()``, then call
    ``Run_Agent()``. ``Run_Agent`` builds the graph itself if it has not been
    built yet.
    """

    # Accepted MODEL_FAMILY values (compared lower-cased) -> canonical provider name.
    _FAMILY_ALIASES = {
        "gemini": "google",
        "google": "google",
        "openai": "openai",
        "gpt": "openai",
        "anthropic": "anthropic",
        "claude": "anthropic",
        "groq": "groq",
    }

    def __init__(self):
        """Validate configuration and create the chat model and Langfuse callback handler.

        Raises:
            ValueError: if a required environment variable is missing or empty,
                or if MODEL_FAMILY is not a supported provider.
        """
        # All variables are validated before any client is created, in this order.
        model_family = self._require_env("MODEL_FAMILY")
        model_name = self._require_env("MODEL_NAME")
        llm_api_key = self._require_env("LLM_API_KEY")
        langfuse_host = self._require_env("LANGFUSE_HOST")
        langfuse_secret_key = self._require_env("LANGFUSE_SECRET_KEY")
        langfuse_public_key = self._require_env("LANGFUSE_PUBLIC_KEY")

        # Setting the llm
        family = self._resolve_family(model_family)
        if family == "google":
            self.llm = ChatGoogleGenerativeAI(model=model_name, google_api_key=llm_api_key, verbose=True)
        elif family == "openai":
            self.llm = ChatOpenAI(model=model_name, api_key=llm_api_key, verbose=True)
        elif family == "anthropic":
            self.llm = ChatAnthropic(model_name=model_name, api_key=llm_api_key, verbose=True)
        else:  # "groq" -- _resolve_family only returns the four canonical names
            self.llm = ChatGroq(model=model_name, temperature=0, api_key=llm_api_key)

        # The handler below is created without explicit credentials; presumably it
        # picks up the client configured here (langfuse get_client behaviour) --
        # confirm against the installed langfuse version. Keep a reference to it.
        self.langfuse_client = Langfuse(
            public_key=langfuse_public_key,
            secret_key=langfuse_secret_key,
            host=langfuse_host
        )
        self.langfuse_handler = CallbackHandler()

        # Compiled graph; created by build_graph() (called lazily by Run_Agent).
        self.graph = None

    @staticmethod
    def _require_env(name):
        """Return the value of environment variable *name*.

        Raises:
            ValueError: if the variable is unset or set to an empty string.
        """
        value = os.getenv(name)
        # A blank value is treated as missing: none of these settings is usable empty.
        if not value:
            raise ValueError(f"{name} is not set in environment variables")
        return value

    @classmethod
    def _resolve_family(cls, model_family):
        """Map a MODEL_FAMILY value (case-insensitive) to a canonical provider name.

        Returns one of ``"google"``, ``"openai"``, ``"anthropic"`` or ``"groq"``.

        Raises:
            ValueError: if *model_family* is not a recognised alias.
        """
        family = cls._FAMILY_ALIASES.get(model_family.strip().lower())
        if family is None:
            raise ValueError(f"model family not an acceptable value!! (got {model_family!r})")
        return family

    def add_tools(self):
        """Populate ``self.tools_list`` and bind those tools to the chat model.

        The tool objects are names pulled in by the module's ``from tools import *``.
        """
        self.tools_list = [
            calculator,
            Web_Search,
            Arxiv_Search,
            Wikipedia_Search,
            get_yt_video_info_metadata,
            get_yt_video_transcript,
            analyze_excel_file,
            read_file,
            save_file_temp,
            analyze_image,
            transcribe_audio_file,
            execute_code_file
        ]

        self.llm_with_tools = self.llm.bind_tools(tools=self.tools_list)

    def build_graph(self):
        """Build, compile and store the assistant <-> tools graph.

        Graph shape: START -> assistant. From assistant, ``tools_condition`` routes
        to the "tools" node when the model requested tool calls and otherwise ends
        the run. "tools" always returns to assistant.

        Returns:
            The compiled graph (also stored on ``self.graph``).
        """
        def assistant(state: AgentState) -> AgentState:
            """Invoke the tool-bound LLM on the conversation so far."""
            try:
                response = self.llm_with_tools.invoke(state["messages"])
            except Exception as e:
                print(f"encountered an error: {e}")
                raise
            # `messages` uses the add_messages reducer, so returning just the new
            # message appends it to the existing history.
            return {"messages": [response]}

        graph_builder = StateGraph(AgentState)
        self.add_tools()

        # nodes
        graph_builder.add_node("assistant", assistant)
        graph_builder.add_node("tools", ToolNode(tools=self.tools_list))

        # edges
        graph_builder.add_edge(START, "assistant")
        graph_builder.add_edge("tools", "assistant")
        graph_builder.add_conditional_edges(
            "assistant", tools_condition
        )

        self.graph = graph_builder.compile()
        return self.graph

    def Run_Agent(self, prompt:str, sys_prompt_file:str):
        """Run the agent on *prompt* with the system prompt read from *sys_prompt_file*.

        Builds the graph first if ``build_graph()`` has not been called yet.

        Args:
            prompt: The user question, sent as a HumanMessage.
            sys_prompt_file: Path to a UTF-8 text file holding the system prompt.

        Returns:
            The final graph state returned by ``graph.invoke``; the answer is the
            content of ``response["messages"][-1]``.
        """
        with open(sys_prompt_file, "r", encoding="utf-8") as f:
            system_prompt = f.read()
        # Previously this raised AttributeError unless build_graph() ran first.
        if getattr(self, "graph", None) is None:
            self.build_graph()
        sys_message = SystemMessage(content=system_prompt)
        message = HumanMessage(content=prompt)
        response = self.graph.invoke(
            input={"messages": [sys_message, message]},
            config={"callbacks": [self.langfuse_handler]}
        )
        return response

if __name__ == "__main__":
    agent = Gaia_Agent()
    # Run_Agent invokes self.graph, so compile the graph before the first run.
    agent.build_graph()
    prompt = "If Eliud Kipchoge could maintain his record-making marathon pace indefinitely, how many thousand hours would it take him to run the distance between the Earth and the Moon its closest approach? Please use the minimum perigee value on the Wikipedia page for the Moon when carrying out your calculation. Round your result to the nearest 1000 hours and do not use any comma separators if necessary."
    answer = agent.Run_Agent(prompt, "sys_prompt_v2.txt")

    # Run_Agent returns the final graph state; the last message carries the answer.
    print(answer["messages"][-1].content)