YuryS commited on
Commit
cbe419f
·
1 Parent(s): ae7a494

My model added

Browse files
Files changed (4) hide show
  1. .gitignore +4 -0
  2. model.py +180 -0
  3. tools.py +104 -0
  4. validation.py +63 -0
.gitignore ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ .idea/
2
+ __pycache__/
3
+ dataset/
4
+ .env
model.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pandas as pd
3
+
4
+ from PIL import Image
5
+ import io
6
+ from typing import TypedDict, Annotated
7
+
8
+ from dotenv import load_dotenv
9
+ from langgraph.graph import START, StateGraph
10
+ from langchain_core.messages import AnyMessage
11
+ from langgraph.graph.message import add_messages
12
+ from langchain_core.messages import HumanMessage, SystemMessage
13
+ from langchain_openai import AzureChatOpenAI
14
+ from langgraph.graph.state import CompiledStateGraph
15
+ from langgraph.prebuilt import tools_condition
16
+ from langgraph.prebuilt import ToolNode
17
+ import matplotlib.pyplot as plt
18
+
19
+ from typing import Optional
20
+
21
+ from tools import get_all_tools
22
+
23
+
24
+ load_dotenv(override=True)
25
+
26
+
27
+ class AgentState(TypedDict):
28
+ # The input document
29
+ input_file: Optional[str]
30
+ messages: Annotated[list[AnyMessage], add_messages]
31
+
32
+ assistant_system = (
33
+ 'You are a general AI assistant. I will ask you a question. Think step-by-step, Report your thoughts, and finish '
34
+ 'your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number '
35
+ "OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, "
36
+ "don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If "
37
+ "you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in "
38
+ "plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules "
39
+ "depending of whether the element to be put in the list is a number or a string."
40
+ )
41
+
42
+ class AssistantModel:
43
+ def __init__(self):
44
+ llm = AzureChatOpenAI(
45
+ openai_api_version="2024-02-01",
46
+ azure_deployment=os.getenv("AZURE_OPENAI_DEPLOYMENT"),
47
+ openai_api_key=os.getenv("AZURE_OPENAI_API_KEY"),
48
+ azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
49
+ temperature=0.0
50
+ )
51
+
52
+ self.llm_with_tools = llm.bind_tools(get_all_tools(), parallel_tool_calls=False)
53
+ self.graph = self._build_graph()
54
+ # self.show_graph()
55
+
56
+
57
+ def _assistant(self, state: AgentState):
58
+ sys_msg = SystemMessage(content=assistant_system)
59
+
60
+ return {"messages": [self.llm_with_tools.invoke([sys_msg] + state["messages"])]}
61
+
62
+ def show_graph(self):
63
+ # python -m pip install --config-settings="--global-option=build_ext" --config-settings="--global-option=-IC:\Program Files\Graphviz\include" --config-settings="--global-option=-LC:\Program Files\Graphviz\lib" pygraphviz
64
+ png = self.graph.get_graph(xray=True).draw_png()
65
+ image = Image.open(io.BytesIO(png))
66
+
67
+ plt.imshow(image)
68
+ plt.axis('off') # Turn off axes for better visualization
69
+ plt.show(block=False)
70
+
71
+
72
+ def _build_graph(self) -> CompiledStateGraph:
73
+ # Graph
74
+ builder = StateGraph(AgentState)
75
+
76
+ # Define nodes: these do the work
77
+ builder.add_node("assistant", self._assistant)
78
+ builder.add_node("tools", ToolNode(get_all_tools()))
79
+
80
+ # Define edges: these determine how the control flow moves
81
+ builder.add_edge(START, "assistant")
82
+ builder.add_conditional_edges(
83
+ "assistant",
84
+ # If the latest message (result) from assistant is a tool call -> tools_condition routes to tools
85
+ # If the latest message (result) from assistant is a not a tool call -> tools_condition routes to END
86
+ tools_condition,
87
+ )
88
+ builder.add_edge("tools", "assistant")
89
+ react_graph = builder.compile()
90
+
91
+ return react_graph
92
+
93
+ @staticmethod
94
+ def _get_final_answer(message: AnyMessage) -> str:
95
+ """Extract the final answer from the message content."""
96
+ # Assuming the final answer is always at the end of the message
97
+ return message.content.split("FINAL ANSWER:")[-1].strip()
98
+
99
+ def _get_file_content(self, file_name: str) -> str:
100
+ """Get the file content."""
101
+ if file_name is None or file_name == '':
102
+ return ''
103
+
104
+ header = '**Attached file content:**\n'
105
+
106
+ text_file = ['.py', '.txt', '.json']
107
+
108
+ full_file_name = os.path.join(r'.\dataset', file_name)
109
+
110
+ if any(file_name.endswith(ext) for ext in text_file):
111
+ with open(full_file_name, 'r', encoding='utf-8') as f:
112
+ return header + f.read()
113
+
114
+ elif file_name.endswith(".xlsx"):
115
+ df = pd.read_excel(full_file_name)
116
+ res = df.to_html(index=False)
117
+ return header + res if res else ''
118
+
119
+ else:
120
+ return ''
121
+
122
+ def _get_image_url(self, file_name: str) -> str:
123
+ exts = ['.png', '.jpg', '.jpeg', '.gif']
124
+
125
+ if any(file_name.endswith(ext) for ext in exts):
126
+ without_ext = file_name.split('.')[0]
127
+ return f'https://agents-course-unit4-scoring.hf.space/files/{without_ext}'
128
+ else:
129
+ return ''
130
+
131
+
132
+ def ask_question(self, question: str, file_name: str) -> str:
133
+ question_with_file = question + '\n' + self._get_file_content(file_name)
134
+ image_url = self._get_image_url(file_name)
135
+
136
+ if image_url != '':
137
+ content = [
138
+ {
139
+ "type": "image_url",
140
+ "image_url": {
141
+ "url": image_url
142
+ }
143
+ },
144
+ {
145
+ "type": "text",
146
+ "text": question_with_file
147
+ }
148
+ ]
149
+ else:
150
+ content = question_with_file
151
+
152
+ messages = [HumanMessage(content=content)]
153
+
154
+ messages = self.graph.invoke({"messages": messages})
155
+
156
+ for m in messages['messages']:
157
+ m.pretty_print()
158
+
159
+ print('@' * 50)
160
+ final_answer = AssistantModel._get_final_answer(messages['messages'][-1])
161
+ print('The final answer is:', final_answer)
162
+
163
+ return final_answer
164
+
165
+ if __name__ == '__main__':
166
+ model = AssistantModel()
167
+
168
+ q = 'Divide 6790 by 5'
169
+ # q = 'How many studio albums were published by Mercedes Sosa between 2000 and 2009 (included)? You can use the latest 2022 version of english wikipedia.'
170
+ # q = '.rewsna eht sa "tfel" drow eht fo etisoppo eht etirw ,ecnetnes siht dnatsrednu uoy fI'
171
+ # q = 'Who nominated the only Featured Article on English Wikipedia about a dinosaur that was promoted in November 2016?'
172
+ # q = 'Who did the actor who played Ray in the Polish-language version of Everybody Loves Raymond play in Magda M.? Give only the first name.'
173
+ q = 'What is the final numeric output from the attached Python code?'
174
+ f = 'f918266a-b3e0-4914-865d-4faa564f1aef.py'
175
+ q = 'The attached Excel file contains the sales of menu items for a local fast-food chain. What were the total sales that the chain made from food (not including drinks)? Express your answer in USD with two decimal places.'
176
+ f = '7bd855d8-463d-4ed5-93ca-5fe35145f733.xlsx'
177
+ # q = "Review the chess position provided in the image. It is black's turn. Provide the correct next move for black which guarantees a win. Please provide your response in algebraic notation."
178
+ # f = 'cca530fc-4052-43b2-b130-b30968d8aa44.png'
179
+
180
+ answer = model.ask_question(q, f)
tools.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Type, Optional
2
+
3
+ from langchain_community.document_loaders import AsyncChromiumLoader
4
+ from langchain_community.document_transformers import BeautifulSoupTransformer
5
+ from langchain_community.tools.wikipedia.tool import WikipediaQueryInput
6
+ from langchain_community.tools import WikipediaQueryRun
7
+ from langchain_community.utilities import WikipediaAPIWrapper
8
+ from langchain_core.callbacks import CallbackManagerForToolRun
9
+ from langchain_core.tools import BaseTool
10
+ from langchain_tavily import TavilySearch
11
+ from pydantic import BaseModel, Field
12
+
13
+
14
+ def _get_web_page(url: str) -> str:
15
+ """Fetches the content of a web page and transforms (beautify) it into a string."""
16
+
17
+ loader = AsyncChromiumLoader([url])
18
+ docs = loader.load()
19
+ bs_transformer = BeautifulSoupTransformer()
20
+ docs = bs_transformer.transform_documents(docs)
21
+ return '\n'.join(['=' * 30 + '\n' + doc.page_content for doc in docs])
22
+
23
+ class WikipediaQueryLoad(BaseTool):
24
+ """Tool that searches the Wikipedia API."""
25
+
26
+ name: str = "wikipedia"
27
+ description: str = (
28
+ "A wrapper around Wikipedia. "
29
+ "Useful for when you need to answer general questions about "
30
+ "people, places, companies, facts, historical events, or other subjects. "
31
+ "Input should be a search query."
32
+ )
33
+ api_wrapper: WikipediaAPIWrapper = WikipediaAPIWrapper(top_k_results=1, doc_content_chars_max=20000)
34
+
35
+ args_schema: Type[BaseModel] = WikipediaQueryInput
36
+
37
+ def _run(
38
+ self,
39
+ query: str,
40
+ run_manager: Optional[CallbackManagerForToolRun] = None,
41
+ ) -> str:
42
+ """Use the Wikipedia tool."""
43
+ page_titles = self.api_wrapper.wiki_client.search(
44
+ query[:300], results=1
45
+ )
46
+ summary = self.api_wrapper.run(query)
47
+
48
+ # Wikipedia python package doesn't properly support some wiki syntax (i.e. tables), so
49
+ # the full wiki page is read separately
50
+ full_page = _get_web_page(f"https://en.wikipedia.org/wiki/{page_titles[0]}")
51
+
52
+ res = [
53
+ '**Wiki page url**:',
54
+ page_titles[0],
55
+ '**Wiki page summary:**',
56
+ summary,
57
+ '**Full page content:**',
58
+ full_page
59
+ ]
60
+ return '\n'.join(res)
61
+
62
+
63
+ class WebScrapTool(BaseTool):
64
+ name: str = "webscraper"
65
+
66
+ description: str = "Loads full content of the web page."
67
+
68
+ # Load HTML
69
+ def _run(
70
+ self,
71
+ url: str,
72
+ run_manager: Optional[CallbackManagerForToolRun] = None,
73
+ ) -> str:
74
+ return _get_web_page(url)
75
+
76
+
77
+ class CalculatorTool(BaseTool):
78
+ """Tool that performs basic calculations."""
79
+
80
+ name: str = "calculator"
81
+ description: str = (
82
+ "A calculator. "
83
+ "Useful for when you need to perform basic calculations."
84
+ )
85
+
86
+ def _run(
87
+ self,
88
+ expression: str,
89
+ run_manager: Optional[CallbackManagerForToolRun] = None,
90
+ ) -> float:
91
+ """Use the calculator tool."""
92
+ return eval(expression)
93
+
94
+
95
+ wiki = WikipediaQueryLoad(api_wrapper=WikipediaAPIWrapper(top_k_results=1, doc_content_chars_max=20000))
96
+
97
+ def get_all_tools() -> list[BaseTool]:
98
+ """Get all tools."""
99
+ return [
100
+ wiki,
101
+ WebScrapTool(),
102
+ TavilySearch(max_results=5, topic="general"),
103
+ CalculatorTool()
104
+ ]
validation.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+
3
+ import requests
4
+
5
+ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
6
+
7
+ def _get_response(url: str):
8
+ try:
9
+ response = requests.get(url, timeout=15)
10
+ response.raise_for_status()
11
+ except requests.exceptions.RequestException as e:
12
+ print(f"Error fetching questions: {e}")
13
+ return None
14
+ except Exception as e:
15
+ print(f"An unexpected error occurred fetching questions: {e}")
16
+ return None
17
+ return response
18
+
19
+ def _get_response_json(url: str):
20
+ try:
21
+ response = _get_response(url)
22
+ questions_data = response.json()
23
+ if not questions_data:
24
+ print("Fetched questions list is empty.")
25
+ return {}, None
26
+ print(f"Fetched {len(questions_data)} questions.")
27
+ except requests.exceptions.JSONDecodeError as e:
28
+ print(f"Error decoding JSON response from questions endpoint: {e}")
29
+ print(f"Response text: {response.text[:500]}")
30
+ return {}, None
31
+
32
+ return questions_data
33
+
34
+
35
+ def load_questions() -> None:
36
+ questions_url = f"{DEFAULT_API_URL}/questions"
37
+
38
+ questions_data = _get_response_json(questions_url)
39
+
40
+ with open(r'./dataset/questions.json', 'w') as f:
41
+ json.dump(questions_data, f, indent=2)
42
+
43
+
44
+ def load_files() -> None:
45
+ with open(r'./dataset/questions.json', 'r') as f:
46
+ questions_data = json.load(f)
47
+
48
+ for q in questions_data:
49
+ if q["file_name"] != '':
50
+ files_url = f'{DEFAULT_API_URL}/files/{q["task_id"]}'
51
+
52
+ print(f"Fetching file from: {files_url}")
53
+
54
+ file_data = _get_response(files_url)
55
+
56
+ with open(f'./dataset/{q["file_name"]}', 'wb') as f:
57
+ f.write(file_data.content)
58
+ print(f"File {q['file_name']} downloaded successfully.")
59
+
60
+
61
+ if __name__ == '__main__':
62
+ # load_questions()
63
+ load_files()