Spaces:
Sleeping
Sleeping
| # coding=utf-8 | |
| # Copyright 2024 HuggingFace Inc. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import inspect | |
| import os | |
| import pathlib | |
| import tempfile | |
| import textwrap | |
| import unittest | |
| import pytest | |
| from IPython.core.interactiveshell import InteractiveShell | |
| from smolagents import Tool | |
| from smolagents.tools import tool | |
| from smolagents.utils import get_source, parse_code_blobs | |
# Tests for `parse_code_blobs`.
# NOTE(review): this copy of the file has lost its indentation and blank lines (every
# line is wrapped in table-cell pipes), so the multi-line literals below may not carry
# their original whitespace — confirm against the upstream file before editing them.
| class AgentTextTests(unittest.TestCase): | |
# Covers three inputs: text with no code (expects ValueError), a fenced ```py block
# terminated by <end_code>, and a bare code string.
| def test_parse_code_blobs(self): | |
| with pytest.raises(ValueError): | |
| parse_code_blobs("Wrong blob!") | |
| # Parsing markdown with code blobs should work | |
| output = parse_code_blobs(""" | |
| Here is how to solve the problem: | |
| Code: | |
| ```py | |
| import numpy as np | |
| ```<end_code> | |
| """) | |
# Only the code inside the fence is expected back, without the surrounding prose.
| assert output == "import numpy as np" | |
| # Parsing code blobs should work | |
| code_blob = "import numpy as np" | |
| output = parse_code_blobs(code_blob) | |
# A string that is already plain code is expected to be returned unchanged.
| assert output == code_blob | |
# Input holds two fenced blocks, one tagged ```python and one tagged ```py; the
# expected result contains the contents of both, in order.
| def test_multiple_code_blobs(self): | |
| test_input = """Here's a function that adds numbers: | |
| ```python | |
| def add(a, b): | |
| return a + b | |
| ``` | |
| And here's a function that multiplies them: | |
| ```py | |
| def multiply(a, b): | |
| return a * b | |
| ```""" | |
# NOTE(review): the separator between the two extracted blocks is not recoverable
# from this copy (blank lines were dropped) — verify against parse_code_blobs.
| expected_output = """def add(a, b): | |
| return a + b | |
| def multiply(a, b): | |
| return a * b""" | |
| result = parse_code_blobs(test_input) | |
| assert result == expected_output | |
@pytest.fixture
def ipython_shell():
    """Provide an IPython shell that is reset before and after each test.

    Registered as a pytest fixture so that tests can request it simply by
    declaring an ``ipython_shell`` parameter. The shell is obtained through
    ``InteractiveShell.instance()`` and reset on both sides of the ``yield``
    so that definitions made by one test are not visible to the next.

    Yields:
        The freshly reset ``InteractiveShell`` object.
    """
    shell = InteractiveShell.instance()
    shell.reset()  # Clean before test
    yield shell
    shell.reset()  # Clean after test
@pytest.mark.parametrize(
    "obj_name, code_blob",
    [
        ("test_func", "def test_func():\n    return 42"),
        ("TestClass", "class TestClass:\n    ..."),
    ],
)
def test_get_source_ipython(ipython_shell, obj_name, code_blob):
    """Check that ``get_source`` returns the cell text of an IPython-defined object.

    Args:
        ipython_shell: Shell fixture in which the definition is executed.
        obj_name: Name under which the defined object is looked up in the
            shell's user namespace after the cell has run.
        code_blob: Source of the cell to run; also the exact text
            ``get_source`` is expected to return.
    """
    # store_history=True records the cell in the shell's input history.
    ipython_shell.run_cell(code_blob, store_history=True)
    obj = ipython_shell.user_ns[obj_name]
    assert get_source(obj) == code_blob
def test_get_source_standard_class():
    """Check ``get_source`` on a class defined in a regular source file.

    The result must equal both the literal one-line definition and the
    dedented, stripped output of ``inspect.getsource`` for the same class.
    """
    # Keep the definition on a single line: the first assertion compares against it verbatim.
    class TestClass: ...

    extracted = get_source(TestClass)
    assert extracted == "class TestClass: ..."
    reference = inspect.getsource(TestClass)
    assert extracted == textwrap.dedent(reference).strip()
def test_get_source_standard_function():
    """Check ``get_source`` on a function defined in a regular source file.

    The result must equal both the literal one-line definition and the
    dedented, stripped output of ``inspect.getsource`` for the same function.
    """
    # Keep the definition on a single line: the first assertion compares against it verbatim.
    def test_func(): ...

    extracted = get_source(test_func)
    assert extracted == "def test_func(): ..."
    reference = inspect.getsource(test_func)
    assert extracted == textwrap.dedent(reference).strip()
def test_get_source_ipython_errors_empty_cells(ipython_shell):
    """Expect a ``ValueError`` when the shell's ``In`` list holds only an empty cell.

    Args:
        ipython_shell: Shell fixture in which ``TestClass`` is defined.
    """
    cell_source = textwrap.dedent("""class TestClass:\n ...""").strip()
    # Overwrite the recorded inputs before running the cell that defines the class.
    ipython_shell.user_ns["In"] = [""]
    ipython_shell.run_cell(cell_source, store_history=True)
    expected_message = "No code cells found in IPython session"
    with pytest.raises(ValueError, match=expected_message):
        get_source(ipython_shell.user_ns["TestClass"])
def test_get_source_ipython_errors_definition_not_found(ipython_shell):
    """Expect a ``ValueError`` when no recorded cell contains the class definition.

    Args:
        ipython_shell: Shell fixture in which ``TestClass`` is defined.
    """
    cell_source = textwrap.dedent("""class TestClass:\n ...""").strip()
    # Recorded inputs are replaced by cells that do not define TestClass.
    ipython_shell.user_ns["In"] = ["", "print('No class definition here')"]
    ipython_shell.run_cell(cell_source, store_history=True)
    expected_message = "Could not find source code for TestClass in IPython history"
    with pytest.raises(ValueError, match=expected_message):
        get_source(ipython_shell.user_ns["TestClass"])
def test_get_source_ipython_errors_type_error():
    """Expect a ``TypeError`` when ``get_source`` is handed ``None``."""
    expected_message = "Expected class or callable"
    with pytest.raises(TypeError, match=expected_message):
        get_source(None)
# End-to-end check of `Tool.save` for a class-based tool defined in this module:
# saves into a temporary directory with make_gradio_app=True, then compares the
# directory listing and the text of tool.py, requirements.txt and app.py.
# NOTE(review): indentation and blank lines are missing from this copy of the file,
# so the expected-file literals below may not match the original byte-for-byte —
# confirm against upstream before editing them.
| def test_e2e_class_tool_save(): | |
| class TestTool(Tool): | |
| name = "test_tool" | |
| description = "Test tool description" | |
| inputs = { | |
| "task": { | |
| "type": "string", | |
| "description": "tool input", | |
| } | |
| } | |
| output_type = "string" | |
# `forward` imports IPython locally; the assertions below expect that import to show
# up both at the top of the saved tool.py and in requirements.txt.
| def forward(self, task: str): | |
| import IPython # noqa: F401 | |
| return task | |
| test_tool = TestTool() | |
| with tempfile.TemporaryDirectory() as tmp_dir: | |
| test_tool.save(tmp_dir, make_gradio_app=True) | |
# Exactly these three files are expected, nothing more.
| assert set(os.listdir(tmp_dir)) == {"requirements.txt", "app.py", "tool.py"} | |
| assert ( | |
| pathlib.Path(tmp_dir, "tool.py").read_text() | |
| == """from typing import Any, Optional | |
| from smolagents.tools import Tool | |
| import IPython | |
| class TestTool(Tool): | |
| name = "test_tool" | |
| description = "Test tool description" | |
| inputs = {'task': {'type': 'string', 'description': 'tool input'}} | |
| output_type = "string" | |
| def forward(self, task: str): | |
| import IPython # noqa: F401 | |
| return task | |
| def __init__(self, *args, **kwargs): | |
| self.is_initialized = False | |
| """ | |
| ) | |
# Requirements are compared as a set of whitespace-separated tokens, so order is ignored.
| requirements = set(pathlib.Path(tmp_dir, "requirements.txt").read_text().split()) | |
| assert requirements == {"IPython", "smolagents"} | |
| assert ( | |
| pathlib.Path(tmp_dir, "app.py").read_text() | |
| == """from smolagents import launch_gradio_demo | |
| from tool import TestTool | |
| tool = TestTool() | |
| launch_gradio_demo(tool) | |
| """ | |
| ) | |
# Same end-to-end check as the class-based save test, but the Tool subclass is defined
# and saved from inside an IPython cell instead of from this module.
# Uses InteractiveShell.instance() directly rather than the ipython_shell parameter,
# and the shell is not reset in this function.
# NOTE(review): indentation and blank lines are missing from this copy of the file,
# so the cell source and expected-file literals below may not match the original
# byte-for-byte — confirm against upstream before editing them.
| def test_e2e_ipython_class_tool_save(): | |
| shell = InteractiveShell.instance() | |
| with tempfile.TemporaryDirectory() as tmp_dir: | |
# f-string: doubled braces are literal braces in the cell; {tmp_dir} is interpolated
# into a double-quoted string literal inside the cell source.
# NOTE(review): a path containing backslashes or quotes would not survive that
# embedding unchanged — confirm whether this matters on the platforms tested.
| code_blob = textwrap.dedent(f""" | |
| from smolagents.tools import Tool | |
| class TestTool(Tool): | |
| name = "test_tool" | |
| description = "Test tool description" | |
| inputs = {{"task": {{"type": "string", | |
| "description": "tool input", | |
| }} | |
| }} | |
| output_type = "string" | |
| def forward(self, task: str): | |
| import IPython # noqa: F401 | |
| return task | |
| TestTool().save("{tmp_dir}", make_gradio_app=True) | |
| """) | |
# The cell must report success before the saved files are inspected.
| assert shell.run_cell(code_blob, store_history=True).success | |
| assert set(os.listdir(tmp_dir)) == {"requirements.txt", "app.py", "tool.py"} | |
| assert ( | |
| pathlib.Path(tmp_dir, "tool.py").read_text() | |
| == """from typing import Any, Optional | |
| from smolagents.tools import Tool | |
| import IPython | |
| class TestTool(Tool): | |
| name = "test_tool" | |
| description = "Test tool description" | |
| inputs = {'task': {'type': 'string', 'description': 'tool input'}} | |
| output_type = "string" | |
| def forward(self, task: str): | |
| import IPython # noqa: F401 | |
| return task | |
| def __init__(self, *args, **kwargs): | |
| self.is_initialized = False | |
| """ | |
| ) | |
# Requirements are compared as a set of whitespace-separated tokens, so order is ignored.
| requirements = set(pathlib.Path(tmp_dir, "requirements.txt").read_text().split()) | |
| assert requirements == {"IPython", "smolagents"} | |
| assert ( | |
| pathlib.Path(tmp_dir, "app.py").read_text() | |
| == """from smolagents import launch_gradio_demo | |
| from tool import TestTool | |
| tool = TestTool() | |
| launch_gradio_demo(tool) | |
| """ | |
| ) | |
# End-to-end check of saving a function-based tool: saves into a temporary directory
# with make_gradio_app=True, then compares the directory listing and the text of
# tool.py, requirements.txt and app.py. The expected tool.py wraps the function body
# in a class named SimpleTool.
# NOTE(review): `test_tool` below is a plain function, yet `.save` is called on it;
# the IPython variant of this test decorates the same function with `@tool`, and
# `tool` is imported at the top of the file — a decorator line looks to be missing
# here. Confirm against upstream.
# NOTE(review): indentation and blank lines are missing from this copy of the file,
# so the expected-file literals below may not match the original byte-for-byte.
| def test_e2e_function_tool_save(): | |
| def test_tool(task: str) -> str: | |
| """ | |
| Test tool description | |
| Args: | |
| task: tool input | |
| """ | |
| import IPython # noqa: F401 | |
| return task | |
| with tempfile.TemporaryDirectory() as tmp_dir: | |
| test_tool.save(tmp_dir, make_gradio_app=True) | |
# Exactly these three files are expected, nothing more.
| assert set(os.listdir(tmp_dir)) == {"requirements.txt", "app.py", "tool.py"} | |
| assert ( | |
| pathlib.Path(tmp_dir, "tool.py").read_text() | |
| == """from smolagents import Tool | |
| from typing import Any, Optional | |
| class SimpleTool(Tool): | |
| name = "test_tool" | |
| description = "Test tool description" | |
| inputs = {"task":{"type":"string","description":"tool input"}} | |
| output_type = "string" | |
| def forward(self, task: str) -> str: | |
| \""" | |
| Test tool description | |
| Args: | |
| task: tool input | |
| \""" | |
| import IPython # noqa: F401 | |
| return task""" | |
| ) | |
# Unlike the class-based tests, only smolagents is expected here even though the
# function body imports IPython; the FIXME below records that as a known gap.
| requirements = set(pathlib.Path(tmp_dir, "requirements.txt").read_text().split()) | |
| assert requirements == {"smolagents"} # FIXME: IPython should be in the requirements | |
| assert ( | |
| pathlib.Path(tmp_dir, "app.py").read_text() | |
| == """from smolagents import launch_gradio_demo | |
| from tool import SimpleTool | |
| tool = SimpleTool() | |
| launch_gradio_demo(tool) | |
| """ | |
| ) | |
# Same end-to-end check as the function-based save test, but the `@tool`-decorated
# function is defined and saved from inside an IPython cell instead of from this module.
# Uses InteractiveShell.instance() directly rather than the ipython_shell parameter,
# and the shell is not reset in this function.
# NOTE(review): indentation and blank lines are missing from this copy of the file,
# so the cell source and expected-file literals below may not match the original
# byte-for-byte — confirm against upstream before editing them.
| def test_e2e_ipython_function_tool_save(): | |
| shell = InteractiveShell.instance() | |
| with tempfile.TemporaryDirectory() as tmp_dir: | |
# f-string: {tmp_dir} is interpolated into a double-quoted string literal inside the
# cell source; the escaped triple quotes delimit the tool function's own docstring.
# NOTE(review): a path containing backslashes or quotes would not survive that
# embedding unchanged — confirm whether this matters on the platforms tested.
| code_blob = textwrap.dedent(f""" | |
| from smolagents import tool | |
| @tool | |
| def test_tool(task: str) -> str: | |
| \""" | |
| Test tool description | |
| Args: | |
| task: tool input | |
| \""" | |
| import IPython # noqa: F401 | |
| return task | |
| test_tool.save("{tmp_dir}", make_gradio_app=True) | |
| """) | |
# The cell must report success before the saved files are inspected.
| assert shell.run_cell(code_blob, store_history=True).success | |
| assert set(os.listdir(tmp_dir)) == {"requirements.txt", "app.py", "tool.py"} | |
| assert ( | |
| pathlib.Path(tmp_dir, "tool.py").read_text() | |
| == """from smolagents import Tool | |
| from typing import Any, Optional | |
| class SimpleTool(Tool): | |
| name = "test_tool" | |
| description = "Test tool description" | |
| inputs = {"task":{"type":"string","description":"tool input"}} | |
| output_type = "string" | |
| def forward(self, task: str) -> str: | |
| \""" | |
| Test tool description | |
| Args: | |
| task: tool input | |
| \""" | |
| import IPython # noqa: F401 | |
| return task""" | |
| ) | |
# Only smolagents is expected here even though the function body imports IPython;
# the FIXME below records that as a known gap.
| requirements = set(pathlib.Path(tmp_dir, "requirements.txt").read_text().split()) | |
| assert requirements == {"smolagents"} # FIXME: IPython should be in the requirements | |
| assert ( | |
| pathlib.Path(tmp_dir, "app.py").read_text() | |
| == """from smolagents import launch_gradio_demo | |
| from tool import SimpleTool | |
| tool = SimpleTool() | |
| launch_gradio_demo(tool) | |
| """ | |
| ) | |