# NOTE: the three lines below the imports header were Hugging Face Spaces page
# residue ("Spaces / Sleeping") captured during extraction, not part of the program.
| import os | |
| import pandas as pd | |
| from PIL import Image | |
| import io | |
| from typing import TypedDict, Annotated | |
| from dotenv import load_dotenv | |
| from langgraph.graph import START, StateGraph | |
| from langchain_core.messages import AnyMessage | |
| from langgraph.graph.message import add_messages | |
| from langchain_core.messages import HumanMessage, SystemMessage | |
| from langchain_openai import AzureChatOpenAI | |
| from langgraph.graph.state import CompiledStateGraph | |
| from langgraph.prebuilt import tools_condition | |
| from langgraph.prebuilt import ToolNode | |
| from typing import Optional | |
| from tools import get_all_tools | |
# Load AZURE_OPENAI_* credentials from .env; override=True lets .env values
# replace variables already present in the process environment.
load_dotenv(override=True)
class AgentState(TypedDict):
    """Graph state: the optional attached document plus the running message history."""
    # The input document (path/name of the file attached to the question), if any.
    input_file: Optional[str]
    # Conversation history; the add_messages reducer appends new messages on each
    # state update instead of overwriting the list.
    messages: Annotated[list[AnyMessage], add_messages]
# System prompt enforcing the GAIA-benchmark answer format: the model must end
# with "FINAL ANSWER: ...", which AssistantModel._get_final_answer parses out.
# This string is sent to the LLM verbatim — do not reformat its wording.
assistant_system = (
    'You are a general AI assistant. I will ask you a question. Think step-by-step, Report your thoughts, and finish '
    'your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number '
    "OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, "
    "don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If "
    "you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in "
    "plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules "
    "depending of whether the element to be put in the list is a number or a string."
)
| class AssistantModel: | |
| def __init__(self, api_key: str | None = None, deployment: str | None = None, endpoint: str | None = None): | |
| llm = AzureChatOpenAI( | |
| openai_api_version="2024-02-01", | |
| azure_deployment=deployment if deployment is not None else os.getenv("AZURE_OPENAI_DEPLOYMENT"), | |
| openai_api_key=api_key if api_key is not None else os.getenv("AZURE_OPENAI_API_KEY"), | |
| azure_endpoint=endpoint if endpoint is not None else os.getenv("AZURE_OPENAI_ENDPOINT"), | |
| temperature=0.0 | |
| ) | |
| self.llm_with_tools = llm.bind_tools(get_all_tools(), parallel_tool_calls=False) | |
| self.graph = self._build_graph() | |
| # self.show_graph() | |
| def _assistant(self, state: AgentState): | |
| sys_msg = SystemMessage(content=assistant_system) | |
| return {"messages": [self.llm_with_tools.invoke([sys_msg] + state["messages"])]} | |
| def show_graph(self): | |
| import matplotlib.pyplot as plt | |
| # python -m pip install --config-settings="--global-option=build_ext" --config-settings="--global-option=-IC:\Program Files\Graphviz\include" --config-settings="--global-option=-LC:\Program Files\Graphviz\lib" pygraphviz | |
| png = self.graph.get_graph(xray=True).draw_png() | |
| image = Image.open(io.BytesIO(png)) | |
| plt.imshow(image) | |
| plt.axis('off') # Turn off axes for better visualization | |
| plt.show(block=False) | |
| def _build_graph(self) -> CompiledStateGraph: | |
| # Graph | |
| builder = StateGraph(AgentState) | |
| # Define nodes: these do the work | |
| builder.add_node("assistant", self._assistant) | |
| builder.add_node("tools", ToolNode(get_all_tools())) | |
| # Define edges: these determine how the control flow moves | |
| builder.add_edge(START, "assistant") | |
| builder.add_conditional_edges( | |
| "assistant", | |
| # If the latest message (result) from assistant is a tool call -> tools_condition routes to tools | |
| # If the latest message (result) from assistant is a not a tool call -> tools_condition routes to END | |
| tools_condition, | |
| ) | |
| builder.add_edge("tools", "assistant") | |
| react_graph = builder.compile() | |
| return react_graph | |
| def _get_final_answer(message: AnyMessage) -> str: | |
| """Extract the final answer from the message content.""" | |
| # Assuming the final answer is always at the end of the message | |
| return message.content.split("FINAL ANSWER:")[-1].strip() | |
| def _get_file_content(self, file_name: str) -> str: | |
| """Get the file content.""" | |
| if file_name is None or file_name == '': | |
| return '' | |
| header = '**Attached file content:**\n' | |
| text_file = ['.py', '.txt', '.json'] | |
| full_file_name = os.path.join(r'.\dataset', file_name) | |
| if any(file_name.endswith(ext) for ext in text_file): | |
| with open(full_file_name, 'r', encoding='utf-8') as f: | |
| return header + f.read() | |
| elif file_name.endswith(".xlsx"): | |
| df = pd.read_excel(full_file_name) | |
| res = df.to_html(index=False) | |
| return header + res if res else '' | |
| else: | |
| return '' | |
| def _get_image_url(self, file_name: str) -> str: | |
| exts = ['.png', '.jpg', '.jpeg', '.gif'] | |
| if any(file_name.endswith(ext) for ext in exts): | |
| without_ext = file_name.split('.')[0] | |
| return f'https://agents-course-unit4-scoring.hf.space/files/{without_ext}' | |
| else: | |
| return '' | |
| def ask_question(self, question: str, file_name: str) -> str: | |
| question_with_file = question + '\n' + self._get_file_content(file_name) | |
| print('Question:', question_with_file) | |
| image_url = self._get_image_url(file_name) | |
| print('Image URL:', image_url) | |
| if image_url != '': | |
| content = [ | |
| { | |
| "type": "image_url", | |
| "image_url": { | |
| "url": image_url | |
| } | |
| }, | |
| { | |
| "type": "text", | |
| "text": question_with_file | |
| } | |
| ] | |
| else: | |
| content = question_with_file | |
| messages = [HumanMessage(content=content)] | |
| messages = self.graph.invoke({"messages": messages}) | |
| for m in messages['messages']: | |
| m.pretty_print() | |
| print('@' * 50) | |
| final_answer = AssistantModel._get_final_answer(messages['messages'][-1]) | |
| print('The final answer is:', final_answer) | |
| return final_answer | |
if __name__ == '__main__':
    # Ad-hoc smoke-test harness: pick a (q, f) pair below — uncomment one of the
    # sample GAIA questions (and its attachment name f, when it has one) to try it.
    model = AssistantModel()
    q = 'Divide 6790 by 5'
    f = ''
    # NOTE(review): q is reassigned immediately below, so the arithmetic question
    # above is a dead store kept only as a quick toggle target.
    q = 'How many studio albums were published by Mercedes Sosa between 2000 and 2009 (included)? You can use the latest 2022 version of english wikipedia.'
    # q = '.rewsna eht sa "tfel" drow eht fo etisoppo eht etirw ,ecnetnes siht dnatsrednu uoy fI'
    # q = 'Who nominated the only Featured Article on English Wikipedia about a dinosaur that was promoted in November 2016?'
    # q = 'Who did the actor who played Ray in the Polish-language version of Everybody Loves Raymond play in Magda M.? Give only the first name.'
    # q = 'What is the final numeric output from the attached Python code?'
    # f = 'f918266a-b3e0-4914-865d-4faa564f1aef.py'
    # q = 'The attached Excel file contains the sales of menu items for a local fast-food chain. What were the total sales that the chain made from food (not including drinks)? Express your answer in USD with two decimal places.'
    # f = '7bd855d8-463d-4ed5-93ca-5fe35145f733.xlsx'
    # q = "Review the chess position provided in the image. It is black's turn. Provide the correct next move for black which guarantees a win. Please provide your response in algebraic notation."
    # f = 'cca530fc-4052-43b2-b130-b30968d8aa44.png'
    answer = model.ask_question(q, f)