TaoZewen's picture
bug fix
83c8c40 verified
from smolagents import CodeAgent,DuckDuckGoSearchTool, HfApiModel,load_tool,tool
import datetime
import requests
import pytz
import yaml
from tools.final_answer import FinalAnswerTool
import sqlite3
import sqlparse # pip install sqlparse
from faker import Faker # pip install faker
from Gradio_UI import GradioUI
# Example tool that does nothing yet -- a starting point for your own creative tools.
# @tool turns the function into an agent tool: the return type must be annotated,
# and the docstring must keep its layout (a description, then an "Args:" section
# with one "name: description" line per argument).
@tool
def my_custom_tool(arg1: str, arg2: int) -> str:
    """A placeholder tool that does nothing yet.

    Args:
        arg1: the first argument
        arg2: the second argument
    """
    placeholder_reply = "What magic will you build ?"
    return placeholder_reply
# Characters that open a quoted string/identifier, mapped to the character that closes it.
_SQL_QUOTE_CLOSERS = {"'": "'", '"': '"', "`": "`", "[": "]"}
# Leading keywords of table-level constraints; such items are not column definitions.
_TABLE_CONSTRAINT_KEYWORDS = frozenset({"CONSTRAINT", "PRIMARY", "FOREIGN", "UNIQUE", "CHECK"})


def _remove_sql_comments(sql: str) -> str:
    """Return sql without -- line comments and /* */ block comments; quoted text is left intact."""
    pieces = []
    i, length = 0, len(sql)
    closing_quote = None
    while i < length:
        ch = sql[i]
        if closing_quote is None and sql.startswith("--", i):
            newline = sql.find("\n", i)
            i = length if newline == -1 else newline  # the newline itself is kept
            continue
        if closing_quote is None and sql.startswith("/*", i):
            end = sql.find("*/", i + 2)
            i = length if end == -1 else end + 2
            pieces.append(" ")  # keep the tokens on either side separated
            continue
        if closing_quote is not None:
            if ch == closing_quote:
                closing_quote = None
        elif ch in _SQL_QUOTE_CLOSERS:
            closing_quote = _SQL_QUOTE_CLOSERS[ch]
        pieces.append(ch)
        i += 1
    return "".join(pieces)


def _split_top_level_items(text: str) -> list:
    """Split the text after a column list's "(" into its comma-separated items.

    Stops at the ")" that closes the list. Commas nested in parentheses
    (e.g. DECIMAL(10,2)) or inside quotes do not split. Items are stripped;
    empty items are dropped.
    """
    items, current = [], []
    depth = 0
    closing_quote = None
    for ch in text:
        if closing_quote is not None:
            if ch == closing_quote:
                closing_quote = None
        elif ch in _SQL_QUOTE_CLOSERS:
            closing_quote = _SQL_QUOTE_CLOSERS[ch]
        elif ch == "(":
            depth += 1
        elif ch == ")":
            if depth == 0:
                break  # end of the column list
            depth -= 1
        elif ch == "," and depth == 0:
            items.append("".join(current).strip())
            current = []
            continue
        current.append(ch)
    items.append("".join(current).strip())
    return [item for item in items if item]


def extract_table_info(create_table_sql: str):
    """
    Parse a simple CREATE TABLE statement into its table name and column list.

    Returns (table_name, columns) where columns is [(column_name, data_type), ...].
    Names are returned exactly as written (quotes included). data_type is the
    first word of the declared type plus any "(...)" size suffix, e.g.
    "VARCHAR(255)" or "DECIMAL(10,2)", and "" when the column declares no type.
    SQL comments are ignored and table-level constraints (PRIMARY KEY (...),
    FOREIGN KEY, UNIQUE (...), CHECK (...), CONSTRAINT ...) are skipped.
    Returns (None, []) when the text is not a CREATE TABLE statement, and
    (table_name, []) when there is no column list (e.g. CREATE TABLE ... AS SELECT).
    Only intended for reasonably simple statements.
    """
    import re

    sql = _remove_sql_comments(create_table_sql)
    header = re.match(
        r"\s*CREATE\s+(?:TEMP(?:ORARY)?\s+)?TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?"
        r'(?P<name>"[^"]*"|`[^`]*`|\[[^\]]*\]|[^\s(]+)',
        sql,
        re.IGNORECASE,
    )
    if header is None:
        return None, []
    table_name = header.group("name")
    after_name = sql[header.end():].lstrip()
    if not after_name.startswith("("):
        return table_name, []

    columns = []
    for item in _split_top_level_items(after_name[1:]):
        keyword = re.match(r"[A-Za-z_]+", item)
        if keyword and keyword.group(0).upper() in _TABLE_CONSTRAINT_KEYWORDS:
            continue  # table constraint, not a column
        # Item is stripped and non-empty, so the \S+ fallback always matches.
        col_name, definition = re.match(
            r'("[^"]*"|`[^`]*`|\[[^\]]*\]|\S+)\s*(.*)', item, re.DOTALL
        ).groups()
        type_match = re.match(r"[^\s(]+(?:\s*\([^)]*\))?", definition)
        # Typeless columns are kept (type "") so the column count matches the table.
        columns.append((col_name, type_match.group(0) if type_match else ""))
    return table_name, columns
# Shared Faker instance, created lazily by _get_faker() on first use.
_fake = None


def _get_faker():
    """Return the shared Faker instance, creating it on the first call."""
    global _fake
    if _fake is None:
        _fake = Faker()
    return _fake


def generate_value_for_type(data_type: str):
    """
    Generate a simple random test value for a column with the given declared SQL type.

    The type is matched case-insensitively by substring, checked in this order:
      INT                                  -> random int in [1, 1000]
      CHAR / TEXT / CLOB                   -> a random word
      DATE                                 -> Faker date() string
      FLOAT / DOUBLE / DECIMAL / REAL / NUMERIC
                                           -> positive Faker pyfloat(left_digits=3, right_digits=2)
      BOOL                                 -> a random bool
      anything else                        -> the string "test"
    """
    normalized = data_type.upper()
    if "INT" in normalized:
        return _get_faker().random_int(min=1, max=1000)
    if "CHAR" in normalized or "TEXT" in normalized or "CLOB" in normalized:
        return _get_faker().word()
    if "DATE" in normalized:
        return _get_faker().date()
    # REAL is SQLite's native floating-point type; NUMERIC covers NUMERIC(p, s).
    if any(tag in normalized for tag in ("FLOAT", "DOUBLE", "DECIMAL", "REAL", "NUMERIC")):
        return _get_faker().pyfloat(left_digits=3, right_digits=2, positive=True)
    if "BOOL" in normalized:
        return _get_faker().pybool()
    return "test"
def _is_create_table_stmt(clean_stmt: str) -> bool:
    """Return True if the comment-free statement starts with CREATE [TEMP|TEMPORARY] TABLE."""
    words = clean_stmt.upper().split(None, 3)
    if words[:2] == ["CREATE", "TABLE"]:
        return True
    return words[:3] in (["CREATE", "TEMP", "TABLE"], ["CREATE", "TEMPORARY", "TABLE"])


def _insert_generated_rows(cursor, table_name, columns, num_rows, max_attempts=20):
    """Insert num_rows rows of generated values into table_name.

    Random values can collide on PRIMARY KEY / UNIQUE columns (the integer
    generator draws from a small range). Each row is therefore re-drawn up to
    max_attempts times before the sqlite3.IntegrityError is re-raised.
    """
    # Names come from the user's own CREATE TABLE, which was already executed.
    column_list = ", ".join(name for name, _ in columns)
    placeholders = ", ".join("?" for _ in columns)
    insert_stmt = f"INSERT INTO {table_name} ({column_list}) VALUES ({placeholders})"
    for _ in range(num_rows):
        for attempt in range(1, max_attempts + 1):
            values = [generate_value_for_type(col_type) for _, col_type in columns]
            try:
                cursor.execute(insert_stmt, values)
            except sqlite3.IntegrityError:
                if attempt == max_attempts:
                    raise
            else:
                break


@tool
def sql_unit_test(sql_code: str, num_rows: int = 10, expected_output: str = None) -> str:
    """
    Run SQL code against a fresh in-memory SQLite database that is auto-filled with generated test data.

    Statements are processed in two passes. First, every non-query statement runs
    in order, and each table created by CREATE TABLE is immediately filled with
    num_rows generated rows. Examples are CREATE TABLE, INSERT, UPDATE and CREATE INDEX.
    Then every SELECT query runs, including WITH ... SELECT, and its rows are reported.

    Args:
        sql_code: SQL code with table creation, data insertion and query statements, separated by semicolons.
        num_rows: Number of generated test rows to insert into each created table, 10 by default.
        expected_output: Optional expected result, compared verbatim with the query output (one '查询结果: [...]' line per SELECT).

    Returns:
        The query results plus a verification verdict when expected_output is given, or an error message.
    """
    conn = None
    try:
        # Autocommit mode, so BEGIN/COMMIT statements inside sql_code work as written.
        conn = sqlite3.connect(":memory:", isolation_level=None)
        cursor = conn.cursor()
        queries = []
        # Pass 1: everything except queries, in the order written.
        for raw_stmt in sqlparse.split(sql_code):
            stmt = raw_stmt.strip()
            # Comment-free text, so a leading "-- ..." line cannot hide the statement type.
            clean_stmt = sqlparse.format(stmt, strip_comments=True).strip()
            if not clean_stmt.strip(";").strip():
                continue  # empty or comment-only fragment
            if sqlparse.parse(clean_stmt)[0].get_type() == "SELECT":
                queries.append(stmt)
                continue
            cursor.execute(stmt)
            if _is_create_table_stmt(clean_stmt):
                table_name, columns = extract_table_info(clean_stmt)
                if table_name and columns:
                    _insert_generated_rows(cursor, table_name, columns, num_rows)

        # Pass 2: queries run last, so they see every table and all inserted rows.
        results = []
        for query in queries:
            cursor.execute(query)
            results.append(f"查询结果: {cursor.fetchall()}")
        output = "\n".join(results) if results else "SQL代码执行成功,但未检测到SELECT查询。"

        # Optional verification: exact (whitespace-trimmed) comparison with the expectation.
        if expected_output:
            if expected_output.strip() == output.strip():
                output += "\n验证结果:查询结果与预期一致。"
            else:
                output += "\n验证结果:查询结果与预期不匹配。"
        return output
    except Exception as e:
        # Tool boundary: report any failure to the agent as text instead of raising.
        return f"SQL代码执行出错,错误信息: {str(e)}"
    finally:
        if conn is not None:
            conn.close()
@tool
def get_current_time_in_timezone(timezone: str) -> str:
    """Report the current wall-clock time in the requested timezone.

    Args:
        timezone: A string representing a valid timezone (e.g., 'America/New_York').
    """
    try:
        zone = pytz.timezone(timezone)
        now_text = datetime.datetime.now(zone).strftime("%Y-%m-%d %H:%M:%S")
    except Exception as exc:
        # Any failure (e.g. an unknown timezone name) is returned to the agent as text.
        return f"Error fetching time for timezone '{timezone}': {str(exc)}"
    return f"The current local time in {timezone} is: {now_text}"
final_answer = FinalAnswerTool()

# If the agent does not answer, the model is overloaded; please use another model or the
# following Hugging Face Endpoint that also contains qwen2.5 coder:
# model_id='https://pflgm2locj2t89co.us-east-1.aws.endpoints.huggingface.cloud'
model = HfApiModel(
    max_tokens=2096,
    temperature=0.5,
    model_id='Qwen/Qwen2.5-Coder-32B-Instruct',  # it is possible that this model may be overloaded
    custom_role_conversions=None,
)

# Import tool from Hub.
# NOTE(review): trust_remote_code=True runs code downloaded from this Hub repo;
# keep it only for repositories you trust.
image_generation_tool = load_tool("agents-course/text-to-image", trust_remote_code=True)

# Explicit UTF-8 so the prompt templates load identically regardless of the platform locale.
with open("prompts.yaml", 'r', encoding="utf-8") as stream:
    prompt_templates = yaml.safe_load(stream)

agent = CodeAgent(
    model=model,
    tools=[final_answer, sql_unit_test, get_current_time_in_timezone, image_generation_tool],  ## add your tools here (don't remove final answer)
    max_steps=6,
    verbosity_level=1,
    grammar=None,
    planning_interval=None,
    name=None,
    description=None,
    prompt_templates=prompt_templates
)

GradioUI(agent).launch()