File size: 6,239 Bytes
9b5b26a
 
 
 
c19d193
6aae614
b39d1c1
 
 
8fe992b
9b5b26a
 
5df72d6
9b5b26a
3d1237b
9b5b26a
 
 
 
 
 
 
b39d1c1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9b5b26a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8c01ffb
 
6aae614
ae7a494
 
 
 
e121372
bf6d34c
 
29ec968
fe328e0
13d500a
8c01ffb
 
9b5b26a
 
8c01ffb
861422e
 
9b5b26a
8c01ffb
8fe992b
83c8c40
8c01ffb
 
 
 
 
 
861422e
8fe992b
 
9b5b26a
8c01ffb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import datetime
import re
import sqlite3
from contextlib import closing

import pytz
import requests
import sqlparse  # pip install sqlparse
import yaml
from faker import Faker  # pip install faker
from smolagents import CodeAgent, DuckDuckGoSearchTool, HfApiModel, load_tool, tool

from Gradio_UI import GradioUI
from tools.final_answer import FinalAnswerTool

# Below is an example of a tool that does nothing. Amaze us with your creativity!
@tool
def my_custom_tool(arg1: str, arg2: int) -> str:  # it's important to specify the return type
    # Keep this docstring format (description / args / args description),
    # but feel free to modify the tool itself.
    """Placeholder tool that ignores its inputs and returns a fixed teaser string.

    Args:
        arg1: the first argument (unused by this placeholder)
        arg2: the second argument (unused by this placeholder)
    """
    teaser = "What magic will you build ?"
    return teaser
    
# --- Regex building blocks for extract_table_info ---------------------------
# One SQL identifier: "double quoted", `backquoted`, [bracketed] or a bare word.
_IDENTIFIER_PATTERN = r'(?:"(?:[^"]|"")*"|`[^`]*`|\[[^\]]*\]|[\w$]+)'
# Whitespace and SQL comments that may precede the CREATE keyword.
_LEADING_NOISE_RE = re.compile(r"(?:\s|--[^\n]*|/\*.*?\*/)*", re.DOTALL)
# CREATE [TEMP|TEMPORARY] TABLE [IF NOT EXISTS] [schema.]name <rest of statement>
_CREATE_TABLE_RE = re.compile(
    r"CREATE\s+(?:TEMP(?:ORARY)?\s+)?TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?"
    rf"(?P<name>{_IDENTIFIER_PATTERN}(?:\s*\.\s*{_IDENTIFIER_PATTERN})?)"
    r"\s*(?P<rest>.*)",
    re.IGNORECASE | re.DOTALL,
)
# Table-level constraints that can appear inside the column list; they are not columns.
_TABLE_CONSTRAINT_RE = re.compile(
    r"(?:CONSTRAINT|PRIMARY|FOREIGN|UNIQUE|CHECK)\b", re.IGNORECASE
)
# A column definition: its name, then the next whitespace-delimited token (normally the type).
_COLUMN_RE = re.compile(rf"(?P<name>{_IDENTIFIER_PATTERN}|[^\s(]+)\s*(?P<type>\S*)")


def _split_column_definitions(text: str) -> list:
    """Split the parenthesized group that opens at ``text[0]`` on its top-level commas.

    Commas nested in parentheses (``DECIMAL(10,2)``) or inside quoted text
    ('...', "...", `...`, [...]) do not split; ``--`` and ``/* */`` comments are
    replaced by a single space. Returns [] when the group is never closed.
    """
    pieces = []
    current = []
    depth = 0
    index = 0
    while index < len(text):
        char = text[index]
        if char in "'\"`[":
            # Copy a quoted run verbatim; a doubled quote simply closes and reopens it.
            closer = "]" if char == "[" else char
            end = text.find(closer, index + 1)
            if end == -1:
                return []
            current.append(text[index:end + 1])
            index = end + 1
            continue
        if text.startswith(("--", "/*"), index):
            terminator = "\n" if char == "-" else "*/"
            end = text.find(terminator, index + 2)
            if end == -1:
                return []
            current.append(" ")  # a comment separates tokens just like whitespace
            index = end + len(terminator)
            continue
        index += 1
        if char == "(":
            depth += 1
            if depth == 1:
                continue  # the outer "(" belongs to no column definition
        elif char == ")":
            depth -= 1
            if depth == 0:
                pieces.append("".join(current))
                return pieces
        elif char == "," and depth == 1:
            pieces.append("".join(current))
            current = []
            continue
        current.append(char)
    return []


def extract_table_info(create_table_sql: str):
    """
    Parse a CREATE TABLE statement into its table name and column list.

    This is a lightweight parser for typical DDL, not a full SQL grammar.

    Args:
        create_table_sql: A single ``CREATE [TEMP|TEMPORARY] TABLE [IF NOT EXISTS]``
            statement; leading whitespace and SQL comments are ignored.

    Returns:
        ``(table_name, columns)``. ``table_name`` is the name exactly as written
        (quotes and an optional ``schema.`` prefix kept, so it can be pasted back
        into SQL). ``columns`` is ``[(column_name, data_type), ...]`` where
        ``data_type`` is the first whitespace-delimited token after the column
        name (normally the declared type, e.g. ``DECIMAL(10,2)``) or ``""`` when
        there is none. Table constraints (PRIMARY KEY, UNIQUE, CHECK, FOREIGN KEY,
        CONSTRAINT) are skipped.
        Returns ``(None, [])`` if the text is not a CREATE TABLE statement, and
        ``(table_name, [])`` if it has no parsable column list
        (e.g. ``CREATE TABLE ... AS SELECT``).
    """
    sql = create_table_sql or ""
    sql = sql[_LEADING_NOISE_RE.match(sql).end():]
    match = _CREATE_TABLE_RE.match(sql)
    if match is None:
        return None, []
    table_name = match.group("name")
    rest = match.group("rest")
    if not rest.startswith("("):
        return table_name, []
    columns = []
    for definition in _split_column_definitions(rest):
        definition = definition.strip()
        if not definition or _TABLE_CONSTRAINT_RE.match(definition):
            continue
        column = _COLUMN_RE.match(definition)
        if column is not None:
            columns.append((column.group("name"), column.group("type")))
    return table_name, columns

# Shared Faker instance that generate_value_for_type draws its sample values from.
fake = Faker()


def generate_value_for_type(data_type: str):
    """
    Return a random sample value suited to a declared SQL column type.

    Args:
        data_type: The declared column type (e.g. ``"VARCHAR(20)"``), matched
            case-insensitively by substring. ``None`` or ``""`` is accepted.

    Returns:
        Checked in this order:
        - INT types: an int between 1 and 1000.
        - CHAR/CLOB/TEXT: a random word.
        - DATE types: a ``fake.date()`` value.
        - REAL/FLOAT/DOUBLE/DECIMAL/NUMERIC: a positive ``fake.pyfloat`` value.
        - Anything else: the string ``"test"``.
        Because of the ordering, ``"BIGINT"`` yields an int.
    """
    declared = (data_type or "").upper()
    if "INT" in declared:
        return fake.random_int(min=1, max=1000)
    # CHAR/CLOB/TEXT are SQLite's markers for TEXT affinity.
    if any(marker in declared for marker in ("CHAR", "CLOB", "TEXT")):
        return fake.word()
    if "DATE" in declared:
        return fake.date()
    # REAL/FLOAT/DOUBLE carry REAL affinity in SQLite; DECIMAL/NUMERIC are numeric too.
    if any(marker in declared for marker in ("REAL", "FLOAT", "DOUBLE", "DECIMAL", "NUMERIC")):
        return fake.pyfloat(left_digits=3, right_digits=2, positive=True)
    return "test"

def _deny_attach(action, arg1, arg2, db_name, trigger_or_view):
    """SQLite authorizer callback: refuse ATTACH/DETACH and allow every other action."""
    if action in (sqlite3.SQLITE_ATTACH, sqlite3.SQLITE_DETACH):
        return sqlite3.SQLITE_DENY
    return sqlite3.SQLITE_OK


def _quote_identifier(name: str) -> str:
    """Quote *name* as an SQLite identifier, doubling any embedded double quotes."""
    return '"' + name.replace('"', '""') + '"'


def _user_tables(conn) -> set:
    """Return the names of all main and temp tables, excluding SQLite's internal sqlite_* tables."""
    rows = conn.execute(
        "SELECT name FROM sqlite_master WHERE type = 'table' "
        "UNION SELECT name FROM sqlite_temp_master WHERE type = 'table'"
    ).fetchall()
    return {name for (name,) in rows if not name.lower().startswith("sqlite_")}


def _distinct_key(value, used: set, row_number: int):
    """Return *value*, or a variant of it that is not already in *used*."""
    if value not in used:
        return value
    if isinstance(value, (int, float)) and not isinstance(value, bool):
        return max(used) + 1  # numeric keys: step past the largest key handed out so far
    return f"{value}_{row_number}"


def _seed_table(conn, table: str, num_rows: int) -> None:
    """Insert *num_rows* generated rows into *table*, unless it already holds rows."""
    quoted_table = _quote_identifier(table)
    # Tables that already hold rows (e.g. from CREATE TABLE ... AS SELECT) are left alone.
    if conn.execute(f"SELECT EXISTS (SELECT 1 FROM {quoted_table})").fetchone()[0]:
        return
    # PRAGMA table_info rows are (cid, name, type, notnull, dflt_value, pk).
    columns = conn.execute(f"PRAGMA table_info({quoted_table})").fetchall()
    if not columns:
        return
    key_positions = [position for position, column in enumerate(columns) if column[5]]
    # A single-column primary key must never receive the same random value twice.
    key_position = key_positions[0] if len(key_positions) == 1 else None
    column_list = ", ".join(_quote_identifier(column[1]) for column in columns)
    placeholders = ", ".join(["?"] * len(columns))
    insert_sql = f"INSERT INTO {quoted_table} ({column_list}) VALUES ({placeholders})"
    used_keys = set()
    rows = []
    for row_number in range(1, num_rows + 1):
        values = [generate_value_for_type(column[2]) for column in columns]
        if key_position is not None:
            values[key_position] = _distinct_key(values[key_position], used_keys, row_number)
            used_keys.add(values[key_position])
        rows.append(values)
    conn.executemany(insert_sql, rows)


def _run_statements(conn, sql_code: str, num_rows: int) -> list:
    """Run every statement of *sql_code* on *conn* in source order.

    Each table a statement creates is seeded right after that statement, and
    each statement that returns rows contributes one result line.
    """
    results = []
    known_tables = _user_tables(conn)
    for statement in sqlparse.split(sql_code):
        statement = statement.strip()
        if not statement:
            continue
        cursor = conn.execute(statement)
        # description is set for any row-returning statement (SELECT, WITH ..., PRAGMA ...).
        if cursor.description is not None:
            results.append(f"查询结果: {cursor.fetchall()}")
        current_tables = _user_tables(conn)
        for table in sorted(current_tables - known_tables):
            _seed_table(conn, table, num_rows)
        known_tables = current_tables
    return results


@tool
def sql_unit_test(sql_code: str, num_rows: int = 10, expected_output: str = None) -> str:
    """
    SQL development and unit-testing tool: runs SQL code in a fresh in-memory SQLite database and automatically generates test data for every table the code creates.

    Args:
        sql_code: SQL code with table definitions, data inserts and queries; multiple statements are separated by semicolons.
        num_rows: Number of test-data rows to generate for each created table, 10 by default.
        expected_output: Optional expected query output, used to verify the query results automatically.

    Returns:
        One result line per statement that returns rows (or an error message), followed by a verification verdict when expected_output is given.
    """
    try:
        # isolation_level=None (autocommit) lets BEGIN/COMMIT inside sql_code behave as written;
        # closing() releases the connection even when a statement fails.
        with closing(sqlite3.connect(':memory:', isolation_level=None)) as conn:
            # Keep the sandbox in memory: refuse ATTACH, which could open or create database files.
            conn.set_authorizer(_deny_attach)
            results = _run_statements(conn, sql_code, num_rows)
    except Exception as e:
        return f"SQL代码执行出错,错误信息: {str(e)}"

    output = "\n".join(results) if results else "SQL代码执行成功,但未检测到SELECT查询。"

    # When an expected output is supplied, compare it with the actual output.
    if expected_output:
        if expected_output.strip() == output.strip():
            output += "\n验证结果:查询结果与预期一致。"
        else:
            output += "\n验证结果:查询结果与预期不匹配。"

    return output

@tool
def get_current_time_in_timezone(timezone: str) -> str:
    """Report the current wall-clock time in the requested timezone.

    Args:
        timezone: A string representing a valid timezone (e.g., 'America/New_York').
    """
    try:
        zone = pytz.timezone(timezone)
        stamp = datetime.datetime.now(zone).strftime("%Y-%m-%d %H:%M:%S")
        message = f"The current local time in {timezone} is: {stamp}"
    except Exception as err:
        # Unknown zone names (or any other failure) are reported back as text, not raised.
        message = f"Error fetching time for timezone '{timezone}': {str(err)}"
    return message


final_answer = FinalAnswerTool()

# If the agent does not answer, the model is probably overloaded: use another model, or the
# following Hugging Face Endpoint, which also serves Qwen2.5 Coder:
# model_id='https://pflgm2locj2t89co.us-east-1.aws.endpoints.huggingface.cloud'

model = HfApiModel(
    max_tokens=2096,
    temperature=0.5,
    model_id='Qwen/Qwen2.5-Coder-32B-Instruct',  # it is possible that this model may be overloaded
    custom_role_conversions=None,
)


# Import tool from Hub.
# NOTE(review): trust_remote_code=True presumably runs code fetched from this Hub repo —
# confirm the repo is trusted.
image_generation_tool = load_tool("agents-course/text-to-image", trust_remote_code=True)

# Explicit UTF-8 so the prompt templates decode the same way regardless of the platform default.
with open("prompts.yaml", 'r', encoding="utf-8") as stream:
    prompt_templates = yaml.safe_load(stream)

agent = CodeAgent(
    model=model,
    tools=[final_answer, sql_unit_test, get_current_time_in_timezone, image_generation_tool],  # add your tools here (don't remove final answer)
    max_steps=6,
    verbosity_level=1,
    grammar=None,
    planning_interval=None,
    name=None,
    description=None,
    prompt_templates=prompt_templates
)


GradioUI(agent).launch()