Spaces:
Sleeping
Sleeping
My model added
Browse files- .gitignore +4 -0
- model.py +180 -0
- tools.py +104 -0
- validation.py +63 -0
.gitignore
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.idea/
|
| 2 |
+
__pycache__/
|
| 3 |
+
dataset/
|
| 4 |
+
.env
|
model.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import pandas as pd
|
| 3 |
+
|
| 4 |
+
from PIL import Image
|
| 5 |
+
import io
|
| 6 |
+
from typing import TypedDict, Annotated
|
| 7 |
+
|
| 8 |
+
from dotenv import load_dotenv
|
| 9 |
+
from langgraph.graph import START, StateGraph
|
| 10 |
+
from langchain_core.messages import AnyMessage
|
| 11 |
+
from langgraph.graph.message import add_messages
|
| 12 |
+
from langchain_core.messages import HumanMessage, SystemMessage
|
| 13 |
+
from langchain_openai import AzureChatOpenAI
|
| 14 |
+
from langgraph.graph.state import CompiledStateGraph
|
| 15 |
+
from langgraph.prebuilt import tools_condition
|
| 16 |
+
from langgraph.prebuilt import ToolNode
|
| 17 |
+
import matplotlib.pyplot as plt
|
| 18 |
+
|
| 19 |
+
from typing import Optional
|
| 20 |
+
|
| 21 |
+
from tools import get_all_tools
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
load_dotenv(override=True)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class AgentState(TypedDict):
    """Shared state threaded through the LangGraph agent.

    ``messages`` is annotated with the ``add_messages`` reducer, so each
    node's returned messages are appended to the history rather than
    replacing it.
    """

    # Name of the document attached to the question, if any.
    input_file: Optional[str]
    # Conversation history; new messages are appended, not overwritten.
    messages: Annotated[list[AnyMessage], add_messages]
|
| 31 |
+
|
| 32 |
+
# System prompt for the assistant node. The rigid answer format is mandated
# by the GAIA-style scorer, which extracts everything after "FINAL ANSWER:".
assistant_system = (
    "You are a general AI assistant. I will ask you a question. "
    "Think step-by-step, Report your thoughts, and finish your answer with the "
    "following template: FINAL ANSWER: [YOUR FINAL ANSWER]. "
    "YOUR FINAL ANSWER should be a number OR as few words as possible OR a "
    "comma separated list of numbers and/or strings. "
    "If you are asked for a number, don't use comma to write your number "
    "neither use units such as $ or percent sign unless specified otherwise. "
    "If you are asked for a string, don't use articles, neither abbreviations "
    "(e.g. for cities), and write the digits in plain text unless specified "
    "otherwise. "
    "If you are asked for a comma separated list, apply the above rules "
    "depending of whether the element to be put in the list is a number or a "
    "string."
)
|
| 41 |
+
|
| 42 |
+
class AssistantModel:
    """ReAct-style agent: an Azure OpenAI chat model wired to tools via LangGraph.

    The compiled graph loops assistant -> tools -> assistant until the model
    stops emitting tool calls; the text after the last "FINAL ANSWER:" marker
    of the closing message is returned as the answer.
    """

    def __init__(self):
        # Deterministic (temperature=0) Azure-hosted chat model; the
        # deployment name and credentials come from the environment (.env).
        llm = AzureChatOpenAI(
            openai_api_version="2024-02-01",
            azure_deployment=os.getenv("AZURE_OPENAI_DEPLOYMENT"),
            openai_api_key=os.getenv("AZURE_OPENAI_API_KEY"),
            azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
            temperature=0.0
        )

        # Sequential tool calls keep the ReAct trace simple to follow/debug.
        self.llm_with_tools = llm.bind_tools(get_all_tools(), parallel_tool_calls=False)
        self.graph = self._build_graph()
        # self.show_graph()

    def _assistant(self, state: AgentState):
        """Graph node: run the LLM on the system prompt plus the history."""
        sys_msg = SystemMessage(content=assistant_system)

        return {"messages": [self.llm_with_tools.invoke([sys_msg] + state["messages"])]}

    def show_graph(self):
        """Render the compiled graph with matplotlib (debugging helper).

        Requires pygraphviz; the install hint below is for Windows.
        """
        # python -m pip install --config-settings="--global-option=build_ext" --config-settings="--global-option=-IC:\Program Files\Graphviz\include" --config-settings="--global-option=-LC:\Program Files\Graphviz\lib" pygraphviz
        png = self.graph.get_graph(xray=True).draw_png()
        image = Image.open(io.BytesIO(png))

        plt.imshow(image)
        plt.axis('off')  # Turn off axes for better visualization
        plt.show(block=False)

    def _build_graph(self) -> CompiledStateGraph:
        """Compile the two-node ReAct graph (assistant + tool executor)."""
        builder = StateGraph(AgentState)

        # Define nodes: these do the work
        builder.add_node("assistant", self._assistant)
        builder.add_node("tools", ToolNode(get_all_tools()))

        # Define edges: these determine how the control flow moves
        builder.add_edge(START, "assistant")
        builder.add_conditional_edges(
            "assistant",
            # If the latest message (result) from assistant is a tool call -> tools_condition routes to tools
            # If the latest message (result) from assistant is a not a tool call -> tools_condition routes to END
            tools_condition,
        )
        builder.add_edge("tools", "assistant")
        react_graph = builder.compile()

        return react_graph

    @staticmethod
    def _get_final_answer(message: AnyMessage) -> str:
        """Extract the final answer from the message content.

        Returns everything after the last "FINAL ANSWER:" marker, stripped.
        If the marker is absent, the whole content (stripped) is returned.
        """
        return message.content.split("FINAL ANSWER:")[-1].strip()

    def _get_file_content(self, file_name: str) -> str:
        """Return the attached file's content as text to inline into the prompt.

        Text-like files are read verbatim; .xlsx is rendered as an HTML table
        (tables survive LLM tokenization better than raw spreadsheet bytes).
        Missing names and unknown extensions yield ''.
        """
        if file_name is None or file_name == '':
            return ''

        header = '**Attached file content:**\n'

        text_file = ['.py', '.txt', '.json']

        # os.path.join keeps this portable; the previous r'.\dataset' prefix
        # was Windows-only and broke on POSIX systems.
        full_file_name = os.path.join('dataset', file_name)

        if any(file_name.endswith(ext) for ext in text_file):
            with open(full_file_name, 'r', encoding='utf-8') as f:
                return header + f.read()

        elif file_name.endswith(".xlsx"):
            df = pd.read_excel(full_file_name)
            res = df.to_html(index=False)
            return header + res if res else ''

        else:
            return ''

    def _get_image_url(self, file_name: str) -> str:
        """Return the scoring-server URL for an image attachment, or ''.

        The server serves files by task id, i.e. the file name without its
        extension.
        """
        exts = ['.png', '.jpg', '.jpeg', '.gif']

        if any(file_name.endswith(ext) for ext in exts):
            # splitext (not split('.')) so task ids containing dots survive.
            without_ext = os.path.splitext(file_name)[0]
            return f'https://agents-course-unit4-scoring.hf.space/files/{without_ext}'
        else:
            return ''

    def ask_question(self, question: str, file_name: str) -> str:
        """Answer one question, optionally augmented with an attached file.

        Text attachments are inlined into the prompt; image attachments are
        passed as an image_url content part. The full message trace is
        printed, and the extracted final answer is returned.
        """
        question_with_file = question + '\n' + self._get_file_content(file_name)
        image_url = self._get_image_url(file_name)

        if image_url != '':
            content = [
                {
                    "type": "image_url",
                    "image_url": {
                        "url": image_url
                    }
                },
                {
                    "type": "text",
                    "text": question_with_file
                }
            ]
        else:
            content = question_with_file

        messages = [HumanMessage(content=content)]

        messages = self.graph.invoke({"messages": messages})

        for m in messages['messages']:
            m.pretty_print()

        print('@' * 50)
        final_answer = AssistantModel._get_final_answer(messages['messages'][-1])
        print('The final answer is:', final_answer)

        return final_answer
|
| 164 |
+
|
| 165 |
+
if __name__ == '__main__':
    model = AssistantModel()

    # Sample questions kept around for manual smoke-testing; swap which
    # q/f pair is left uncommented to try a different task.
    # q = 'Divide 6790 by 5'
    # q = 'How many studio albums were published by Mercedes Sosa between 2000 and 2009 (included)? You can use the latest 2022 version of english wikipedia.'
    # q = '.rewsna eht sa "tfel" drow eht fo etisoppo eht etirw ,ecnetnes siht dnatsrednu uoy fI'
    # q = 'Who nominated the only Featured Article on English Wikipedia about a dinosaur that was promoted in November 2016?'
    # q = 'Who did the actor who played Ray in the Polish-language version of Everybody Loves Raymond play in Magda M.? Give only the first name.'
    # q = 'What is the final numeric output from the attached Python code?'
    # f = 'f918266a-b3e0-4914-865d-4faa564f1aef.py'
    # q = "Review the chess position provided in the image. It is black's turn. Provide the correct next move for black which guarantees a win. Please provide your response in algebraic notation."
    # f = 'cca530fc-4052-43b2-b130-b30968d8aa44.png'
    q = 'The attached Excel file contains the sales of menu items for a local fast-food chain. What were the total sales that the chain made from food (not including drinks)? Express your answer in USD with two decimal places.'
    f = '7bd855d8-463d-4ed5-93ca-5fe35145f733.xlsx'

    answer = model.ask_question(q, f)
|
tools.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Type, Optional
|
| 2 |
+
|
| 3 |
+
from langchain_community.document_loaders import AsyncChromiumLoader
|
| 4 |
+
from langchain_community.document_transformers import BeautifulSoupTransformer
|
| 5 |
+
from langchain_community.tools.wikipedia.tool import WikipediaQueryInput
|
| 6 |
+
from langchain_community.tools import WikipediaQueryRun
|
| 7 |
+
from langchain_community.utilities import WikipediaAPIWrapper
|
| 8 |
+
from langchain_core.callbacks import CallbackManagerForToolRun
|
| 9 |
+
from langchain_core.tools import BaseTool
|
| 10 |
+
from langchain_tavily import TavilySearch
|
| 11 |
+
from pydantic import BaseModel, Field
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def _get_web_page(url: str) -> str:
    """Load *url* with a headless Chromium browser and return its text.

    The raw HTML is cleaned up by BeautifulSoupTransformer; each resulting
    document is prefixed with a '=' separator rule.
    """
    documents = AsyncChromiumLoader([url]).load()
    documents = BeautifulSoupTransformer().transform_documents(documents)

    separator = '=' * 30 + '\n'
    return '\n'.join(separator + doc.page_content for doc in documents)
|
| 22 |
+
|
| 23 |
+
class WikipediaQueryLoad(BaseTool):
    """Tool that searches the Wikipedia API.

    Returns the best-matching page title, the API summary, and the full
    rendered page (scraped separately, because the wikipedia package drops
    tables and some other wiki markup).
    """

    name: str = "wikipedia"
    description: str = (
        "A wrapper around Wikipedia. "
        "Useful for when you need to answer general questions about "
        "people, places, companies, facts, historical events, or other subjects. "
        "Input should be a search query."
    )
    # Single best match with a generous character budget so long articles
    # are not truncated.
    api_wrapper: WikipediaAPIWrapper = WikipediaAPIWrapper(top_k_results=1, doc_content_chars_max=20000)

    args_schema: Type[BaseModel] = WikipediaQueryInput

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Use the Wikipedia tool."""
        # The query is capped at 300 characters, mirroring the API wrapper's
        # own limit.
        page_titles = self.api_wrapper.wiki_client.search(
            query[:300], results=1
        )
        if not page_titles:
            # Guard: the search can come back empty; the original code
            # crashed with an IndexError on page_titles[0] here.
            return f'No Wikipedia page found for: {query}'

        summary = self.api_wrapper.run(query)

        # Wikipedia python package doesn't properly support some wiki syntax (i.e. tables), so
        # the full wiki page is read separately
        full_page = _get_web_page(f"https://en.wikipedia.org/wiki/{page_titles[0]}")

        res = [
            '**Wiki page url**:',
            page_titles[0],
            '**Wiki page summary:**',
            summary,
            '**Full page content:**',
            full_page
        ]
        return '\n'.join(res)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class WebScrapTool(BaseTool):
    """Tool exposing the Chromium-based page scraper to the agent."""

    name: str = "webscraper"

    description: str = "Loads full content of the web page."

    def _run(
        self,
        url: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Fetch *url* and return the page's full text content."""
        page_text = _get_web_page(url)
        return page_text
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class CalculatorTool(BaseTool):
    """Tool that performs basic calculations."""

    name: str = "calculator"
    description: str = (
        "A calculator. "
        "Useful for when you need to perform basic calculations."
    )

    def _run(
        self,
        expression: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> float:
        """Use the calculator tool."""
        # SECURITY: eval() executes arbitrary Python, not just arithmetic.
        # The expression comes from the LLM rather than directly from an end
        # user, but this should still be replaced with a restricted evaluator
        # (e.g. an ast-based math-only parser) before any untrusted exposure.
        # Malformed input raises whatever eval raises (SyntaxError, NameError,
        # ZeroDivisionError, ...).
        return eval(expression)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
# Shared Wikipedia tool instance (single best hit, 20k character budget).
wiki = WikipediaQueryLoad(api_wrapper=WikipediaAPIWrapper(top_k_results=1, doc_content_chars_max=20000))


def get_all_tools() -> list[BaseTool]:
    """Return the full toolbox handed to the agent."""
    tools: list[BaseTool] = [wiki]
    tools.append(WebScrapTool())
    tools.append(TavilySearch(max_results=5, topic="general"))
    tools.append(CalculatorTool())
    return tools
|
validation.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
|
| 3 |
+
import requests
|
| 4 |
+
|
| 5 |
+
# Base URL of the HF Agents-course scoring service (questions + file downloads).
DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
|
| 6 |
+
|
| 7 |
+
def _get_response(url: str):
    """GET *url* (15 s timeout) and return the Response, or None on failure.

    Errors are reported to stdout rather than raised, so callers must be
    prepared for a None return.
    """
    try:
        response = requests.get(url, timeout=15)
        response.raise_for_status()
    except requests.exceptions.RequestException as e:
        # Covers connection errors, timeouts and non-2xx statuses.
        # The message now names the URL (the old hard-coded "questions"
        # wording was wrong for file downloads).
        print(f"Error fetching {url}: {e}")
        return None
    except Exception as e:
        # Last-resort guard so one bad URL cannot kill a batch download.
        print(f"An unexpected error occurred fetching {url}: {e}")
        return None
    return response
|
| 18 |
+
|
| 19 |
+
def _get_response_json(url: str):
    """GET *url* and return the decoded JSON payload.

    Returns {} on any failure (network error, empty payload, bad JSON) so
    callers always receive a JSON-serializable object. The original version
    returned a stray ({}, None) tuple on error paths and crashed with
    AttributeError when the request itself failed; both are fixed here.
    """
    response = _get_response(url)
    if response is None:
        # _get_response has already printed the error details.
        return {}

    try:
        questions_data = response.json()
    except requests.exceptions.JSONDecodeError as e:
        print(f"Error decoding JSON response from questions endpoint: {e}")
        print(f"Response text: {response.text[:500]}")
        return {}

    if not questions_data:
        print("Fetched questions list is empty.")
        return {}
    print(f"Fetched {len(questions_data)} questions.")

    return questions_data
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def load_questions() -> None:
    """Download the question list and cache it as dataset/questions.json."""
    data = _get_response_json(f"{DEFAULT_API_URL}/questions")

    with open('./dataset/questions.json', 'w') as f:
        json.dump(data, f, indent=2)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def load_files() -> None:
    """Download every attachment referenced in dataset/questions.json.

    Files are fetched from the scoring service by task id and written to
    ./dataset under their original file names.
    """
    with open('./dataset/questions.json', 'r') as f:
        questions_data = json.load(f)

    for q in questions_data:
        if q["file_name"] != '':
            files_url = f'{DEFAULT_API_URL}/files/{q["task_id"]}'

            print(f"Fetching file from: {files_url}")

            file_data = _get_response(files_url)
            if file_data is None:
                # Fetch failed (already logged); skip instead of crashing on
                # `.content` of None so the remaining files still download.
                continue

            with open(f'./dataset/{q["file_name"]}', 'wb') as f:
                f.write(file_data.content)
            print(f"File {q['file_name']} downloaded successfully.")
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
if __name__ == '__main__':
    # Run load_questions() first (once) to refresh the cached question list.
    # load_questions()
    load_files()
|