TaoZewen commited on
Commit
b39d1c1
·
verified ·
1 Parent(s): ae7a494

Update app.py

Browse files

增加了 针对 SQL 开发和单元测试中数据准备和结果验证的问题而开发的tool

Files changed (1) hide show
  1. app.py +98 -1
app.py CHANGED
@@ -4,6 +4,9 @@ import requests
4
  import pytz
5
  import yaml
6
  from tools.final_answer import FinalAnswerTool
 
 
 
7
 
8
  from Gradio_UI import GradioUI
9
 
@@ -17,6 +20,100 @@ def my_custom_tool(arg1:str, arg2:int)-> str: #it's import to specify the return
17
  arg2: the second argument
18
  """
19
  return "What magic will you build ?"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
  @tool
22
  def get_current_time_in_timezone(timezone: str) -> str:
@@ -55,7 +152,7 @@ with open("prompts.yaml", 'r') as stream:
55
 
56
  agent = CodeAgent(
57
  model=model,
58
- tools=[final_answer], ## add your tools here (don't remove final answer)
59
  max_steps=6,
60
  verbosity_level=1,
61
  grammar=None,
 
4
  import pytz
5
  import yaml
6
  from tools.final_answer import FinalAnswerTool
7
+ import sqlite3
8
+ import sqlparse # pip install sqlparse
9
+ from faker import Faker # pip install faker
10
 
11
  from Gradio_UI import GradioUI
12
 
 
20
  arg2: the second argument
21
  """
22
  return "What magic will you build ?"
23
+
24
+ def extract_table_info(create_table_sql: str):
25
+ """
26
+ 简单解析 CREATE TABLE 语句,返回表名和字段列表
27
+ 字段列表格式为 [(column_name, data_type), ...]
28
+ 注意:此函数适用于结构较为简单的 SQL 语句
29
+ """
30
+ tokens = sqlparse.parse(create_table_sql)[0].tokens
31
+ table_name = None
32
+ columns = []
33
+ for token in tokens:
34
+ if token.ttype is None and token.is_group:
35
+ if token.value.startswith("("):
36
+ inner = token.value.strip("()")
37
+ for col_def in inner.split(","):
38
+ parts = col_def.strip().split()
39
+ if len(parts) >= 2:
40
+ col_name = parts[0]
41
+ data_type = parts[1]
42
+ columns.append((col_name, data_type))
43
+ break
44
+ elif token.ttype is None and not table_name:
45
+ table_name = token.value.strip()
46
+ return table_name, columns
47
+
48
+ def generate_value_for_type(data_type: str):
49
+ """
50
+ 根据字段数据类型生成测试数据的简单示例
51
+ """
52
+ data_type = data_type.upper()
53
+ if "INT" in data_type:
54
+ return fake.random_int(min=1, max=1000)
55
+ elif "CHAR" in data_type or "TEXT" in data_type:
56
+ return fake.word()
57
+ elif "DATE" in data_type:
58
+ return fake.date()
59
+ elif "FLOAT" in data_type or "DOUBLE" in data_type or "DECIMAL" in data_type:
60
+ return fake.pyfloat(left_digits=3, right_digits=2, positive=True)
61
+ else:
62
+ return "test"
63
+
64
+ @tool
65
+ def sql_unit_test(sql_code: str, num_rows: int = 10, expected_output: str = None) -> str:
66
+ """
67
+ 针对 SQL 开发和单元测试的智能体工具,自动生成测试数据并执行 SQL 代码。
68
+
69
+ Args:
70
+ sql_code: 包含建表、插入数据和查询语句的 SQL 代码,多个语句以分号分隔。
71
+ num_rows: 为每个表生成测试数据的行数,默认 10 行。
72
+ expected_output: 可选的预期查询结果,用于自动验证查询输出。
73
+
74
+ Returns:
75
+ 执行结果、调试日志及验证反馈。
76
+ """
77
+ try:
78
+ conn = sqlite3.connect(':memory:')
79
+ cursor = conn.cursor()
80
+ statements = sqlparse.split(sql_code)
81
+ results = []
82
+
83
+ # 先处理建表和数据生成
84
+ for stmt in statements:
85
+ stmt = stmt.strip()
86
+ if stmt.upper().startswith("CREATE TABLE"):
87
+ cursor.execute(stmt)
88
+ table_name, columns = extract_table_info(stmt)
89
+ if table_name and columns:
90
+ for _ in range(num_rows):
91
+ values = [generate_value_for_type(col_type) for col_name, col_type in columns]
92
+ placeholders = ",".join(["?"] * len(values))
93
+ insert_stmt = f"INSERT INTO {table_name} VALUES ({placeholders})"
94
+ cursor.execute(insert_stmt, values)
95
+ conn.commit()
96
+
97
+ # 执行其他 SQL 语句(如 SELECT)
98
+ for stmt in statements:
99
+ stmt = stmt.strip()
100
+ if stmt.upper().startswith("SELECT"):
101
+ cursor.execute(stmt)
102
+ fetched = cursor.fetchall()
103
+ results.append(f"查询结果: {fetched}")
104
+
105
+ output = "\n".join(results) if results else "SQL代码执行成功,但未检测到SELECT查询。"
106
+
107
+ # 如果提供了预期输出,自动对比结果
108
+ if expected_output:
109
+ if expected_output.strip() == output.strip():
110
+ output += "\n验证结果:查询结果与预期一致。"
111
+ else:
112
+ output += "\n验证结果:查询结果与预期不匹配。"
113
+
114
+ return output
115
+ except Exception as e:
116
+ return f"SQL代码执行出错,错误信息: {str(e)}"
117
 
118
  @tool
119
  def get_current_time_in_timezone(timezone: str) -> str:
 
152
 
153
  agent = CodeAgent(
154
  model=model,
155
+ tools=[final_answer,sql_unit_test,get_current_time_in_timezone,image_generation_tool], ## add your tools here (don't remove final answer)
156
  max_steps=6,
157
  verbosity_level=1,
158
  grammar=None,