File size: 3,342 Bytes
c2f85cf
29bec31
c2f85cf
 
 
 
41f6453
7eb753f
4fd40c0
e5a6189
c2f85cf
 
 
 
 
 
a380329
7eb753f
c2f85cf
7eb753f
 
 
dab557d
35bce75
c2f85cf
a380329
c2f85cf
a380329
c2f85cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2618ebd
a380329
eca3d44
2618ebd
5440230
 
 
 
 
2618ebd
5440230
58c57f1
a18d9f7
 
 
58c57f1
09e510d
f46a489
 
37cf57f
f46a489
 
 
 
 
 
 
 
 
 
 
c2f85cf
37cf57f
c2f85cf
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
from langgraph.graph.message import add_messages
from langchain_core.messages import AnyMessage, HumanMessage, AIMessage, SystemMessage
from langgraph.prebuilt import ToolNode
from langgraph.graph import START, StateGraph
from langgraph.prebuilt import tools_condition
from langchain_huggingface import HuggingFaceEndpoint, ChatHuggingFace
from tools import extract_text, describe_image
from langchain_community.tools import DuckDuckGoSearchRun
from langchain_openai import ChatOpenAI
from typing import TypedDict, Annotated, Optional


# Graph state shared by every node: the running conversation.
# The field is Annotated with LangGraph's ``add_messages`` reducer, so message
# lists returned by nodes are merged into this history instead of replacing it.
AgentState = TypedDict(
    "AgentState",
    {"messages": Annotated[list[AnyMessage], add_messages]},
)


class BasicAgent:
    """Tool-using question-answering agent built on a LangGraph state graph.

    The compiled graph starts at an ``assistant`` node (a Hugging Face chat
    model with tools bound). ``tools_condition`` routes to the ``tools`` node
    (a ``ToolNode``) whenever the model requests a tool call, and the tool
    results are fed back to ``assistant`` until it answers directly.
    """

    def __init__(self, llm):
        """Wrap *llm* as a chat model, bind the tools and compile the graph.

        Args:
            llm: the LLM passed to ``ChatHuggingFace(llm=...)`` — presumably
                a ``HuggingFaceEndpoint``; confirm against the caller.
        """
        chat = ChatHuggingFace(llm=llm, verbose=True)

        search_tool = DuckDuckGoSearchRun()
        # NOTE: the former unused ``ChatOpenAI(model="gpt-4o")`` local was
        # removed; building it required OpenAI credentials for no purpose.
        self.tools = [extract_text, describe_image, search_tool]
        self.chat_with_tools = chat.bind_tools(self.tools)
        self._initialize_graph()
        print("BasicAgent initialized.")

    def _initialize_graph(self):
        """Build and compile the assistant <-> tools loop into ``self.agent``."""
        builder = StateGraph(AgentState)

        # Nodes: the model turn and the executor for requested tool calls.
        builder.add_node("assistant", self.assistant)
        builder.add_node("tools", ToolNode(self.tools))

        # Edges: always start with the model; tools_condition picks "tools"
        # when the last message has tool calls, otherwise the run ends.
        builder.add_edge(START, "assistant")
        builder.add_conditional_edges("assistant", tools_condition)
        builder.add_edge("tools", "assistant")

        self.agent = builder.compile()

    def __call__(self, question: str) -> str:
        """Run the agent on *question* and return the content of the last message.

        Args:
            question: the user question, sent as a single ``HumanMessage``.

        Returns:
            ``.content`` of the final message in the graph's output state.
        """
        print(f"Agent received question: {question}.")
        messages = [HumanMessage(content=question)]
        response = self.agent.invoke({"messages": messages})
        answer = response['messages'][-1].content
        print(f"Agent returning answer: {answer}")
        return answer

    def assistant(self, state: AgentState):
        """Run one model turn over the conversation held in *state*.

        The answer-format system prompt is prepended to the messages for this
        invocation only. It is not returned as part of the update, so it never
        accumulates in the stored history across assistant turns.

        Args:
            state: graph state whose ``"messages"`` holds the conversation.

        Returns:
            A partial state update ``{"messages": [...]}`` containing the
            model's reply.

        Raises:
            RuntimeError: if the model invocation returns ``None``.
        """
        sys_msg = SystemMessage(content="""
        You are a general AI assistant. I will ask you a question. Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. 
        YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
        """)
        print("Calling assistant with state: ", state["messages"])
        # Previously sys_msg was built but never sent, so the model never saw
        # the FINAL ANSWER instructions.
        response = self.chat_with_tools.invoke([sys_msg, *state["messages"]])

        if response is None:
            raise RuntimeError("chat_with_tools.invoke returned None")

        # Defensive: accept either a single message or a list of messages.
        updated_messages = response if isinstance(response, list) else [response]
        return {"messages": updated_messages}