File size: 7,557 Bytes
cbe419f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21a5794
cbe419f
 
21a5794
 
 
cbe419f
 
 
 
 
 
 
 
 
 
 
 
 
 
406b217
cbe419f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
383d2d3
 
 
cbe419f
 
383d2d3
 
cbe419f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21a5794
 
cbe419f
 
 
21a5794
 
 
 
cbe419f
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
import os
import pandas as pd

from PIL import Image
import io
from typing import TypedDict, Annotated

from dotenv import load_dotenv
from langgraph.graph import START, StateGraph
from langchain_core.messages import AnyMessage
from langgraph.graph.message import add_messages
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_openai import AzureChatOpenAI
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt import tools_condition
from langgraph.prebuilt import ToolNode

from typing import Optional

from tools import get_all_tools


load_dotenv(override=True)


class AgentState(TypedDict):
    """State schema handed to ``StateGraph`` when the agent graph is built."""
    # The input document (optional). The code visible in this file only ever
    # passes "messages" when invoking the graph, so this key is left unset here.
    input_file:  Optional[str]
    # Conversation messages; annotated with langgraph's ``add_messages`` reducer.
    messages: Annotated[list[AnyMessage], add_messages]

# System prompt placed in front of the conversation on every assistant call.
# It asks the model to reason step-by-step and to finish with a
# "FINAL ANSWER: ..." line; AssistantModel._get_final_answer splits on that
# exact marker, so the marker text must not be changed independently.
assistant_system = (
    'You are a general AI assistant. I will ask you a question. Think step-by-step, Report your thoughts, and finish '
    'your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number '
    "OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, "
    "don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If "
    "you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in "
    "plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules "
    "depending of whether the element to be put in the list is a number or a string."
)

class AssistantModel:
    def __init__(self, api_key: str | None = None, deployment: str | None = None, endpoint: str | None = None):
        """Create the Azure OpenAI chat client, bind the tools and compile the graph.

        Each argument left as ``None`` is read from the environment instead:
        ``AZURE_OPENAI_API_KEY``, ``AZURE_OPENAI_DEPLOYMENT`` and
        ``AZURE_OPENAI_ENDPOINT`` respectively.
        """
        # Resolve the environment fallbacks up front so the constructor call
        # below only deals with final values.
        if deployment is None:
            deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT")
        if api_key is None:
            api_key = os.getenv("AZURE_OPENAI_API_KEY")
        if endpoint is None:
            endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")

        chat_model = AzureChatOpenAI(
            openai_api_version="2024-02-01",
            azure_deployment=deployment,
            openai_api_key=api_key,
            azure_endpoint=endpoint,
            temperature=0.0,
        )

        # Parallel tool calls are explicitly switched off when binding the tools.
        self.llm_with_tools = chat_model.bind_tools(get_all_tools(), parallel_tool_calls=False)
        self.graph = self._build_graph()
        # self.show_graph()


    def _assistant(self, state: AgentState):
        """Graph node: run the tool-bound model on the system prompt plus the state's messages.

        Returns a dict whose ``"messages"`` entry is a one-element list holding
        the model's reply.
        """
        conversation = [SystemMessage(content=assistant_system)] + state["messages"]
        reply = self.llm_with_tools.invoke(conversation)
        return {"messages": [reply]}

    def show_graph(self):
        """Draw the compiled graph as a PNG and show it in a non-blocking matplotlib window."""
        import matplotlib.pyplot as plt

        # Install command noted by the original author for pygraphviz on Windows:
        # python -m pip install --config-settings="--global-option=build_ext" --config-settings="--global-option=-IC:\Program Files\Graphviz\include" --config-settings="--global-option=-LC:\Program Files\Graphviz\lib" pygraphviz
        drawable = self.graph.get_graph(xray=True)
        png_bytes = drawable.draw_png()
        picture = Image.open(io.BytesIO(png_bytes))

        plt.imshow(picture)
        plt.axis('off')  # no axes around the picture
        plt.show(block=False)


    def _build_graph(self) -> CompiledStateGraph:
        """Assemble and compile the assistant <-> tools loop.

        The graph starts at the "assistant" node. After each assistant turn,
        ``tools_condition`` routes to the "tools" node when the latest message
        is a tool call and to END otherwise; the "tools" node always hands
        control back to "assistant".
        """
        workflow = StateGraph(AgentState)

        # Nodes: the model call and the tool executor.
        workflow.add_node("assistant", self._assistant)
        workflow.add_node("tools", ToolNode(get_all_tools()))

        # Edges: entry point, conditional branch after the assistant, loop back from tools.
        workflow.add_edge(START, "assistant")
        workflow.add_conditional_edges("assistant", tools_condition)
        workflow.add_edge("tools", "assistant")

        return workflow.compile()

    @staticmethod
    def _get_final_answer(message: AnyMessage) -> str:
        """Extract the final answer from the message content."""
        # Assuming the final answer is always at the end of the message
        return message.content.split("FINAL ANSWER:")[-1].strip()

    def _get_file_content(self, file_name: str) -> str:
        """Get the file content."""
        if file_name is None or file_name == '':
            return ''

        header = '**Attached file content:**\n'

        text_file = ['.py', '.txt', '.json']

        full_file_name = os.path.join(r'.\dataset', file_name)

        if any(file_name.endswith(ext) for ext in text_file):
            with open(full_file_name, 'r', encoding='utf-8') as f:
                return header + f.read()

        elif file_name.endswith(".xlsx"):
            df = pd.read_excel(full_file_name)
            res = df.to_html(index=False)
            return header + res if res else ''

        else:
            return ''

    def _get_image_url(self, file_name: str) -> str:
        exts = ['.png', '.jpg', '.jpeg', '.gif']

        if any(file_name.endswith(ext) for ext in exts):
            without_ext = file_name.split('.')[0]
            return f'https://agents-course-unit4-scoring.hf.space/files/{without_ext}'
        else:
            return ''


    def ask_question(self, question: str, file_name: str) -> str:
        """Run the agent graph on *question* (plus its attachment) and return the final answer.

        The text returned by ``_get_file_content`` is appended to the question
        after a newline. When ``_get_image_url`` yields a URL, the human message
        is sent as an image part followed by a text part; otherwise it is sent
        as plain text. The question, the image URL, every message of the run
        and the extracted answer are printed along the way.
        """
        attachment_text = self._get_file_content(file_name)
        question_with_file = '\n'.join((question, attachment_text))
        print('Question:', question_with_file)

        image_url = self._get_image_url(file_name)
        print('Image URL:', image_url)

        # Plain text by default; switch to a multi-part payload only when there is an image.
        content = question_with_file
        if image_url != '':
            image_part = {"type": "image_url", "image_url": {"url": image_url}}
            text_part = {"type": "text", "text": question_with_file}
            content = [image_part, text_part]

        result = self.graph.invoke({"messages": [HumanMessage(content=content)]})
        history = result['messages']

        for msg in history:
            msg.pretty_print()

        print('@' * 50)
        # The last message of the run is the one that carries the answer marker.
        final_answer = AssistantModel._get_final_answer(history[-1])
        print('The final answer is:', final_answer)

        return final_answer

if __name__ == '__main__':
    # Manual smoke test: build the agent and ask one sample question.
    model = AssistantModel()

    # Active sample (no attachment).
    question = 'How many studio albums were published by Mercedes Sosa between 2000 and 2009 (included)? You can use the latest 2022 version of english wikipedia.'
    attachment = ''

    # Other samples; uncomment a question (and its attachment, if any) to try it.
    # question = 'Divide 6790 by 5'
    # question = '.rewsna eht sa "tfel" drow eht fo etisoppo eht etirw ,ecnetnes siht dnatsrednu uoy fI'
    # question = 'Who nominated the only Featured Article on English Wikipedia about a dinosaur that was promoted in November 2016?'
    # question = 'Who did the actor who played Ray in the Polish-language version of Everybody Loves Raymond play in Magda M.? Give only the first name.'
    # question = 'What is the final numeric output from the attached Python code?'
    # attachment = 'f918266a-b3e0-4914-865d-4faa564f1aef.py'
    # question = 'The attached Excel file contains the sales of menu items for a local fast-food chain. What were the total sales that the chain made from food (not including drinks)? Express your answer in USD with two decimal places.'
    # attachment = '7bd855d8-463d-4ed5-93ca-5fe35145f733.xlsx'
    # question = "Review the chess position provided in the image. It is black's turn. Provide the correct next move for black which guarantees a win. Please provide your response in algebraic notation."
    # attachment = 'cca530fc-4052-43b2-b130-b30968d8aa44.png'

    answer = model.ask_question(question, attachment)