Spaces:
Runtime error
Runtime error
File size: 6,239 Bytes
9b5b26a c19d193 6aae614 b39d1c1 8fe992b 9b5b26a 5df72d6 9b5b26a 3d1237b 9b5b26a b39d1c1 9b5b26a 8c01ffb 6aae614 ae7a494 e121372 bf6d34c 29ec968 fe328e0 13d500a 8c01ffb 9b5b26a 8c01ffb 861422e 9b5b26a 8c01ffb 8fe992b 83c8c40 8c01ffb 861422e 8fe992b 9b5b26a 8c01ffb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 |
from smolagents import CodeAgent,DuckDuckGoSearchTool, HfApiModel,load_tool,tool
import datetime
import requests
import pytz
import yaml
from tools.final_answer import FinalAnswerTool
import sqlite3
import sqlparse # pip install sqlparse
from faker import Faker # pip install faker
from Gradio_UI import GradioUI
# Below is an example of a tool that does nothing. Amaze us with your creativity !
@tool
def my_custom_tool(arg1: str, arg2: int) -> str:  # the return type annotation is required
    # Keep the description / Args / per-argument layout of the docstring below;
    # everything else about this tool may be changed freely.
    """A tool that does nothing yet
    Args:
        arg1: the first argument
        arg2: the second argument
    """
    placeholder_reply = "What magic will you build ?"
    return placeholder_reply
# Leading keywords that mark a table-level constraint rather than a column.
_TABLE_CONSTRAINT_KEYWORDS = ("PRIMARY", "FOREIGN", "UNIQUE", "CHECK", "CONSTRAINT")


def _split_top_level_commas(text: str):
    """Split *text* on commas that are not nested inside parentheses."""
    pieces = []
    current = []
    depth = 0
    for ch in text:
        if ch == "(":
            depth += 1
        elif ch == ")":
            depth = max(0, depth - 1)
        if ch == "," and depth == 0:
            pieces.append("".join(current))
            current = []
        else:
            current.append(ch)
    pieces.append("".join(current))
    return pieces


def extract_table_info(create_table_sql: str):
    """Parse a simple CREATE TABLE statement into a table name and column list.

    Args:
        create_table_sql: A single ``CREATE TABLE name (...)`` statement.

    Returns:
        A ``(table_name, columns)`` tuple. ``columns`` is a list of
        ``(column_name, data_type)`` pairs in declaration order; a column
        declared without a type gets an empty ``data_type``. ``table_name``
        is ``None`` and ``columns`` is empty when no parenthesised column
        list is found.

    Note:
        Only intended for structurally simple statements.
    """
    parsed = sqlparse.parse(create_table_sql)
    if not parsed:
        return None, []

    table_name = None
    columns = []
    previous = None  # last significant token seen before the column list
    for token in parsed[0].tokens:
        if token.is_whitespace or token.value.lstrip().startswith(("--", "/*")):
            continue
        if token.is_group and token.value.startswith("("):
            # The token right before the parenthesis is the table name. The
            # original elif-based lookup could never run because every
            # ttype-None token is a group and was consumed by the first branch.
            if previous is not None and previous.value.upper() not in ("TABLE", "EXISTS"):
                table_name = previous.value.strip()
            raw = token.value
            # Drop exactly one pair of outer parentheses; str.strip("()")
            # would also eat the closing paren of e.g. VARCHAR(50).
            inner = raw[1:-1] if raw.endswith(")") else raw[1:]
            for col_def in _split_top_level_commas(inner):
                parts = col_def.split()
                if not parts or parts[0].upper() in _TABLE_CONSTRAINT_KEYWORDS:
                    continue
                data_type = parts[1] if len(parts) >= 2 else ""
                columns.append((parts[0], data_type))
            break
        previous = token
    return table_name, columns
_fake = None  # shared Faker instance, created on first use


def _get_fake():
    """Return the shared Faker instance, creating it on first use."""
    global _fake
    if _fake is None:
        _fake = Faker()
    return _fake


def generate_value_for_type(data_type: str):
    """Generate one sample value suitable for a column of the given SQL type.

    The type name is matched case-insensitively by substring, so
    ``VARCHAR(50)`` or ``BIGINT`` are recognised.

    Args:
        data_type: The declared SQL column type; may be empty or ``None``.

    Returns:
        An int for integer types, a word for CHAR/TEXT types, a date string
        for DATE types, a positive float for FLOAT/DOUBLE/DECIMAL/REAL/NUMERIC
        types, and the literal string ``"test"`` for anything else.
    """
    normalized = (data_type or "").upper()
    if "INT" in normalized:
        return _get_fake().random_int(min=1, max=1000)
    if "CHAR" in normalized or "TEXT" in normalized:
        return _get_fake().word()
    if "DATE" in normalized:
        return _get_fake().date()
    float_markers = ("FLOAT", "DOUBLE", "DECIMAL", "REAL", "NUMERIC")
    if any(marker in normalized for marker in float_markers):
        return _get_fake().pyfloat(left_digits=3, right_digits=2, positive=True)
    # Unknown type: fall back to a fixed placeholder value.
    return "test"
@tool
def sql_unit_test(sql_code: str, num_rows: int = 10, expected_output: str = None) -> str:
    """An agent tool for SQL development and unit testing that generates test data automatically and runs the SQL code.

    Args:
        sql_code: SQL code containing CREATE TABLE, INSERT and SELECT statements, with statements separated by semicolons.
        num_rows: Number of test-data rows to generate for each created table, 10 by default.
        expected_output: Optional expected query result, used to validate the query output automatically.

    Returns:
        The execution results, debug log and validation feedback.
    """
    conn = None
    try:
        conn = sqlite3.connect(":memory:")
        cursor = conn.cursor()
        statements = [piece.strip() for piece in sqlparse.split(sql_code)]
        statements = [piece for piece in statements if piece]
        results = []

        # Pass 1: create every table and seed it with generated rows.
        for stmt in statements:
            if not stmt.upper().startswith("CREATE TABLE"):
                continue
            cursor.execute(stmt)
            table_name, columns = extract_table_info(stmt)
            if not (table_name and columns):
                continue
            placeholders = ",".join(["?"] * len(columns))
            insert_stmt = f"INSERT INTO {table_name} VALUES ({placeholders})"
            rows = [
                [generate_value_for_type(col_type) for _col_name, col_type in columns]
                for _ in range(num_rows)
            ]
            cursor.executemany(insert_stmt, rows)
        conn.commit()

        # Pass 2: run the remaining statements in their original order so that
        # user-supplied INSERT/UPDATE/DELETE statements take effect before the
        # queries that follow them; only SELECT results are reported.
        for stmt in statements:
            upper_stmt = stmt.upper()
            if upper_stmt.startswith("CREATE TABLE"):
                continue
            cursor.execute(stmt)
            if upper_stmt.startswith("SELECT"):
                fetched = cursor.fetchall()
                results.append(f"查询结果: {fetched}")
        conn.commit()

        output = "\n".join(results) if results else "SQL代码执行成功,但未检测到SELECT查询。"

        # When an expected output is supplied, compare it with the report text.
        if expected_output:
            if expected_output.strip() == output.strip():
                output += "\n验证结果:查询结果与预期一致。"
            else:
                output += "\n验证结果:查询结果与预期不匹配。"
        return output
    except Exception as e:
        # Tool boundary: report the failure as text instead of raising.
        return f"SQL代码执行出错,错误信息: {str(e)}"
    finally:
        # Always release the in-memory database, on success or failure.
        if conn is not None:
            conn.close()
@tool
def get_current_time_in_timezone(timezone: str) -> str:
    """A tool that fetches the current local time in a specified timezone.
    Args:
        timezone: A string representing a valid timezone (e.g., 'America/New_York').
    """
    try:
        # Resolve the zone name, then format "now" as seen in that zone.
        zone = pytz.timezone(timezone)
        now_in_zone = datetime.datetime.now(zone)
        local_time = now_in_zone.strftime("%Y-%m-%d %H:%M:%S")
    except Exception as err:
        # Any lookup/formatting failure is reported as text, never raised.
        return f"Error fetching time for timezone '{timezone}': {str(err)}"
    return f"The current local time in {timezone} is: {local_time}"
final_answer = FinalAnswerTool()

# If the agent does not answer, the model is overloaded; use another model or the
# following Hugging Face Endpoint, which also serves Qwen2.5 Coder:
# model_id='https://pflgm2locj2t89co.us-east-1.aws.endpoints.huggingface.cloud'
model = HfApiModel(
    max_tokens=2096,
    temperature=0.5,
    model_id='Qwen/Qwen2.5-Coder-32B-Instruct',  # it is possible that this model may be overloaded
    custom_role_conversions=None,
)

# Import tool from Hub
image_generation_tool = load_tool("agents-course/text-to-image", trust_remote_code=True)

# Read the prompt templates with an explicit encoding so the result does not
# depend on the platform's default locale.
with open("prompts.yaml", 'r', encoding="utf-8") as stream:
    prompt_templates = yaml.safe_load(stream)

agent = CodeAgent(
    model=model,
    tools=[final_answer, sql_unit_test, get_current_time_in_timezone, image_generation_tool],  # add your tools here (don't remove final answer)
    max_steps=6,
    verbosity_level=1,
    grammar=None,
    planning_interval=None,
    name=None,
    description=None,
    prompt_templates=prompt_templates
)

# Start the Gradio web UI for the agent.
GradioUI(agent).launch()