# NOTE: the three lines below are residue from the repository web page this
# file was copied from (author avatar caption, commit message, commit hash).
# Kept as comments so the module parses:
# YuryS's picture
# "Requirements.txt added" — commit 406b217
import os
import pandas as pd
from PIL import Image
import io
from typing import TypedDict, Annotated
from dotenv import load_dotenv
from langgraph.graph import START, StateGraph
from langchain_core.messages import AnyMessage
from langgraph.graph.message import add_messages
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_openai import AzureChatOpenAI
from langgraph.graph.state import CompiledStateGraph
from langgraph.prebuilt import tools_condition
from langgraph.prebuilt import ToolNode
from typing import Optional
from tools import get_all_tools
# Load variables from a local .env file, overriding any values already set in
# the process environment (the Azure credentials below are read via os.getenv).
load_dotenv(override=True)
class AgentState(TypedDict):
    """State shared between the LangGraph nodes (assistant and tools)."""
    # The input document
    input_file: Optional[str]
    # Conversation history; the add_messages reducer appends new messages
    # returned by a node instead of replacing the list.
    messages: Annotated[list[AnyMessage], add_messages]
# System prompt (GAIA-benchmark answer format). It forces every reply to end
# with a machine-parseable "FINAL ANSWER: ..." suffix, which
# AssistantModel._get_final_answer later splits on.
assistant_system = (
    'You are a general AI assistant. I will ask you a question. Think step-by-step, Report your thoughts, and finish '
    'your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number '
    "OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, "
    "don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If "
    "you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in "
    "plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules "
    "depending of whether the element to be put in the list is a number or a string."
)
class AssistantModel:
    """ReAct-style agent wrapping an Azure OpenAI chat model with tools.

    Builds a LangGraph state machine (assistant -> tools -> assistant) and
    answers questions, optionally augmented with an attached file's content
    or an image URL served by the scoring server.
    """

    def __init__(self, api_key: str | None = None, deployment: str | None = None, endpoint: str | None = None):
        """Create the LLM, bind the tool set, and compile the graph.

        Args:
            api_key: Azure OpenAI key; falls back to env AZURE_OPENAI_API_KEY.
            deployment: Azure deployment name; falls back to env AZURE_OPENAI_DEPLOYMENT.
            endpoint: Azure endpoint URL; falls back to env AZURE_OPENAI_ENDPOINT.
        """
        llm = AzureChatOpenAI(
            openai_api_version="2024-02-01",
            azure_deployment=deployment if deployment is not None else os.getenv("AZURE_OPENAI_DEPLOYMENT"),
            openai_api_key=api_key if api_key is not None else os.getenv("AZURE_OPENAI_API_KEY"),
            azure_endpoint=endpoint if endpoint is not None else os.getenv("AZURE_OPENAI_ENDPOINT"),
            temperature=0.0,  # deterministic output for benchmark scoring
        )
        # parallel_tool_calls=False: the model issues at most one tool call
        # per assistant turn, keeping the ReAct loop strictly sequential.
        self.llm_with_tools = llm.bind_tools(get_all_tools(), parallel_tool_calls=False)
        self.graph = self._build_graph()
        # self.show_graph()

    def _assistant(self, state: AgentState):
        """Graph node: run the tool-bound LLM on the system prompt + history."""
        sys_msg = SystemMessage(content=assistant_system)
        return {"messages": [self.llm_with_tools.invoke([sys_msg] + state["messages"])]}

    def show_graph(self):
        """Render the compiled graph as a PNG via pygraphviz/matplotlib (debug aid)."""
        import matplotlib.pyplot as plt
        # pygraphviz on Windows may need explicit Graphviz paths, e.g.:
        # python -m pip install --config-settings="--global-option=build_ext" --config-settings="--global-option=-IC:\Program Files\Graphviz\include" --config-settings="--global-option=-LC:\Program Files\Graphviz\lib" pygraphviz
        png = self.graph.get_graph(xray=True).draw_png()
        image = Image.open(io.BytesIO(png))
        plt.imshow(image)
        plt.axis('off')  # Turn off axes for better visualization
        plt.show(block=False)

    def _build_graph(self) -> CompiledStateGraph:
        """Compile the ReAct loop: START -> assistant -> (tools -> assistant)* -> END."""
        builder = StateGraph(AgentState)
        # Nodes: the LLM step and the tool executor.
        builder.add_node("assistant", self._assistant)
        builder.add_node("tools", ToolNode(get_all_tools()))
        # Edges: tools_condition routes to "tools" when the latest assistant
        # message contains tool calls, otherwise to END.
        builder.add_edge(START, "assistant")
        builder.add_conditional_edges("assistant", tools_condition)
        builder.add_edge("tools", "assistant")
        return builder.compile()

    @staticmethod
    def _get_final_answer(message: AnyMessage) -> str:
        """Extract the text after the last 'FINAL ANSWER:' marker in *message*."""
        # The system prompt makes the marker end the reply; taking the last
        # split segment also tolerates the marker appearing earlier by accident.
        return message.content.split("FINAL ANSWER:")[-1].strip()

    def _get_file_content(self, file_name: str) -> str:
        """Return the attachment's content formatted for the prompt, or ''.

        Text-like files (.py/.txt/.json) are inlined verbatim; .xlsx files are
        rendered as an HTML table. Other extensions (and no file) yield ''.
        """
        if not file_name:
            return ''
        header = '**Attached file content:**\n'
        # FIX: the original r'.\dataset' literal only worked on Windows;
        # build the dataset path portably instead.
        full_file_name = os.path.join('.', 'dataset', file_name)
        if file_name.endswith(('.py', '.txt', '.json')):
            with open(full_file_name, 'r', encoding='utf-8') as f:
                return header + f.read()
        elif file_name.endswith('.xlsx'):
            res = pd.read_excel(full_file_name).to_html(index=False)
            return header + res if res else ''
        else:
            return ''

    def _get_image_url(self, file_name: str) -> str:
        """Return the scoring-server URL for an image attachment, or ''."""
        # FIX: guard None/'' explicitly (the original crashed on None, while
        # _get_file_content handled it — made the two helpers consistent).
        if not file_name:
            return ''
        if file_name.endswith(('.png', '.jpg', '.jpeg', '.gif')):
            # FIX: split('.')[0] truncated names containing extra dots;
            # splitext strips only the final extension.
            without_ext = os.path.splitext(file_name)[0]
            return f'https://agents-course-unit4-scoring.hf.space/files/{without_ext}'
        return ''

    def ask_question(self, question: str, file_name: str) -> str:
        """Answer *question*, attaching *file_name*'s content/image if given.

        Args:
            question: The user question text.
            file_name: Attachment name inside ./dataset, or '' for none.

        Returns:
            The text after the final 'FINAL ANSWER:' marker of the last reply.
        """
        question_with_file = question + '\n' + self._get_file_content(file_name)
        print('Question:', question_with_file)
        image_url = self._get_image_url(file_name)
        print('Image URL:', image_url)
        if image_url != '':
            # Multimodal message: image URL part plus the textual question.
            content = [
                {"type": "image_url", "image_url": {"url": image_url}},
                {"type": "text", "text": question_with_file},
            ]
        else:
            content = question_with_file
        result = self.graph.invoke({"messages": [HumanMessage(content=content)]})
        for m in result['messages']:
            m.pretty_print()
            print('@' * 50)
        final_answer = AssistantModel._get_final_answer(result['messages'][-1])
        print('The final answer is:', final_answer)
        return final_answer
if __name__ == '__main__':
    # Ad-hoc smoke test. The sample GAIA questions below are a scratchpad:
    # q/f are reassigned, so only the LAST uncommented pair is actually asked.
    model = AssistantModel()
    q = 'Divide 6790 by 5'
    # f is the attachment file name inside ./dataset ('' means no attachment).
    f = ''
    q = 'How many studio albums were published by Mercedes Sosa between 2000 and 2009 (included)? You can use the latest 2022 version of english wikipedia.'
    # q = '.rewsna eht sa "tfel" drow eht fo etisoppo eht etirw ,ecnetnes siht dnatsrednu uoy fI'
    # q = 'Who nominated the only Featured Article on English Wikipedia about a dinosaur that was promoted in November 2016?'
    # q = 'Who did the actor who played Ray in the Polish-language version of Everybody Loves Raymond play in Magda M.? Give only the first name.'
    # q = 'What is the final numeric output from the attached Python code?'
    # f = 'f918266a-b3e0-4914-865d-4faa564f1aef.py'
    # q = 'The attached Excel file contains the sales of menu items for a local fast-food chain. What were the total sales that the chain made from food (not including drinks)? Express your answer in USD with two decimal places.'
    # f = '7bd855d8-463d-4ed5-93ca-5fe35145f733.xlsx'
    # q = "Review the chess position provided in the image. It is black's turn. Provide the correct next move for black which guarantees a win. Please provide your response in algebraic notation."
    # f = 'cca530fc-4052-43b2-b130-b30968d8aa44.png'
    answer = model.ask_question(q, f)