Spaces:
Runtime error
Runtime error
| from smolagents import CodeAgent,DuckDuckGoSearchTool, HfApiModel,load_tool,tool | |
| import datetime | |
| import requests | |
| import pytz | |
| import yaml | |
| from tools.final_answer import FinalAnswerTool | |
| import sqlite3 | |
| import sqlparse # pip install sqlparse | |
| from faker import Faker # pip install faker | |
| from Gradio_UI import GradioUI | |
# Placeholder tool kept from the course template. It is not included in the
# agent's tools list and always returns the same message.
def my_custom_tool(arg1: str, arg2: int) -> str:  # the return annotation is required for tools
    """Return a fixed placeholder message; both arguments are ignored.

    Args:
        arg1: the first argument (unused)
        arg2: the second argument (unused)
    """
    placeholder = "What magic will you build ?"
    return placeholder
def extract_table_info(create_table_sql: str):
    """Parse a simple CREATE TABLE statement into its table name and column list.

    Args:
        create_table_sql: One ``CREATE [TEMP|TEMPORARY] TABLE [IF NOT EXISTS]
            name (...)`` statement. SQL comments inside it are not understood.

    Returns:
        ``(table_name, columns)``. ``table_name`` is the name as written
        (quotes kept). ``columns`` lists ``(column_name, data_type)`` tuples in
        declaration order; ``data_type`` is the bare type name without any size
        suffix (``VARCHAR(10)`` -> ``"VARCHAR"``) and ``""`` for a column
        declared without a type (SQLite allows this). ``(None, [])`` is
        returned when the statement has no parenthesised column list
        (e.g. ``CREATE TABLE ... AS SELECT``).

    Table-level constraints (``PRIMARY KEY (...)``, ``FOREIGN KEY``,
    ``UNIQUE``, ``CHECK``, ``CONSTRAINT``) are skipped, so the result has one
    entry per real column. Intended for reasonably simple DDL only.
    """
    import re

    # First word of a definition that is a table constraint, not a column.
    table_constraints = {"CONSTRAINT", "PRIMARY", "FOREIGN", "UNIQUE", "CHECK"}
    # Words that may follow a column name when the column has no declared type.
    column_constraints = {
        "CONSTRAINT", "PRIMARY", "NOT", "NULL", "UNIQUE", "CHECK",
        "DEFAULT", "COLLATE", "REFERENCES", "GENERATED", "AS",
    }
    # Opening quote/bracket character -> the character that closes it.
    quote_pairs = {'"': '"', "'": "'", "`": "`", "[": "]"}

    header = re.match(
        r'\s*CREATE\s+(?:TEMP(?:ORARY)?\s+)?TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?'
        r'("[^"]+"|`[^`]+`|\[[^\]]+\]|[\w.]+)\s*\(',
        create_table_sql,
        re.IGNORECASE,
    )
    if header is None:
        return None, []
    table_name = header.group(1)

    # Split the column list on commas at nesting depth 0 only, so that
    # DECIMAL(10, 2) or DEFAULT 'a,b' stay inside one definition.
    definitions = []
    current = []
    depth = 0
    closing_quote = None
    for ch in create_table_sql[header.end():]:
        if closing_quote is not None:
            if ch == closing_quote:
                closing_quote = None
        elif ch in quote_pairs:
            closing_quote = quote_pairs[ch]
        elif ch == "(":
            depth += 1
        elif ch == ")":
            if depth == 0:
                break  # closing parenthesis of the column list
            depth -= 1
        elif ch == "," and depth == 0:
            definitions.append("".join(current))
            current = []
            continue
        current.append(ch)
    definitions.append("".join(current))

    columns = []
    for definition in definitions:
        match = re.match(
            r'("[^"]*"|`[^`]*`|\[[^\]]*\]|[^\s(]+)\s*(.*)',
            definition.strip(),
            re.DOTALL,
        )
        if match is None:
            continue  # empty fragment, e.g. from a trailing comma
        col_name, rest = match.groups()
        if col_name.upper() in table_constraints:
            continue
        data_type = re.match(r"[^\s(]*", rest).group()
        if data_type.upper() in column_constraints:
            data_type = ""  # typeless column followed directly by a constraint
        columns.append((col_name, data_type))
    return table_name, columns
# Module-level Faker instance shared by every generate_value_for_type call;
# it must exist before any test value is generated.
fake = Faker()


def generate_value_for_type(data_type: str):
    """Return one random test value matching a column's declared SQL type.

    The declared type is matched case-insensitively by substring, in the
    spirit of SQLite's type-affinity rules, checked in this order:

    * contains ``INT``                          -> random int in [1, 1000]
    * contains ``CHAR``, ``CLOB`` or ``TEXT``   -> a random word
    * contains ``DATE``                         -> a random date (``fake.date()``)
    * contains ``REAL``, ``FLOA``, ``DOUB``, ``DECIMAL`` or ``NUMERIC``
                                                -> positive float, up to 3 integer
                                                   and 2 fractional digits
    * contains ``BOOL``                         -> random bool
    * anything else, including ``""``           -> the string ``"test"``

    Args:
        data_type: Declared column type, e.g. ``"VARCHAR"`` or ``"INTEGER"``.
    """
    upper = data_type.upper()
    if "INT" in upper:
        return fake.random_int(min=1, max=1000)
    if any(marker in upper for marker in ("CHAR", "CLOB", "TEXT")):
        return fake.word()
    if "DATE" in upper:
        return fake.date()
    # REAL is SQLite's native floating-point type; FLOA/DOUB cover FLOAT/DOUBLE.
    if any(marker in upper for marker in ("REAL", "FLOA", "DOUB", "DECIMAL", "NUMERIC")):
        return fake.pyfloat(left_digits=3, right_digits=2, positive=True)
    if "BOOL" in upper:
        return fake.boolean()
    return "test"
def _insert_generated_rows(cursor, table_name, columns, num_rows, max_attempts=20):
    """Fill ``table_name`` with ``num_rows`` rows of generated test values.

    Args:
        cursor: sqlite3 cursor of the database that holds the table.
        table_name: Table to insert into, exactly as written in CREATE TABLE.
        columns: ``(column_name, data_type)`` pairs in declaration order, as
            returned by ``extract_table_info``.
        num_rows: Number of rows to insert.
        max_attempts: How many times one row is regenerated after violating a
            constraint before the ``sqlite3.IntegrityError`` is re-raised.
    """
    # Values are bound positionally: one placeholder per declared column.
    placeholders = ", ".join(["?"] * len(columns))
    insert_stmt = f"INSERT INTO {table_name} VALUES ({placeholders})"
    for _ in range(num_rows):
        for attempt in range(1, max_attempts + 1):
            values = [generate_value_for_type(col_type) for _, col_type in columns]
            try:
                cursor.execute(insert_stmt, values)
            except sqlite3.IntegrityError:
                # Random values can collide on PRIMARY KEY / UNIQUE columns
                # (generated ints come from a small range); retry with new values.
                if attempt == max_attempts:
                    raise
            else:
                break


# @tool turns the function into a smolagents Tool; CodeAgent expects Tool
# instances in its tools list, not plain functions.
@tool
def sql_unit_test(sql_code: str, num_rows: int = 10, expected_output: str = None) -> str:
    """Run SQL code on a fresh in-memory SQLite database filled with generated test data.

    Statements run in order. Right after each CREATE TABLE statement the new
    table is filled with num_rows rows of random values that match its column
    types. The rows of every statement that returns a result set (SELECT,
    WITH ... SELECT, PRAGMA queries) are reported.

    Args:
        sql_code: SQL code with table definitions, data inserts and queries; separate multiple statements with semicolons.
        num_rows: Number of random rows generated for each created table (default 10); pass 0 to use only the rows inserted by sql_code.
        expected_output: Optional expected query result, compared with the returned report or with the raw rows (the Python list of row tuples of each query, one per line).

    Returns:
        The query results, plus a validation verdict when expected_output is given, or an error message if a statement fails.
    """
    import re
    from contextlib import closing

    # CREATE [TEMP|TEMPORARY] TABLE at the start of a comment-free statement.
    create_table_re = re.compile(r"CREATE\s+(?:TEMP(?:ORARY)?\s+)?TABLE\b", re.IGNORECASE)
    if num_rows is None:
        num_rows = 10

    try:
        result_sets = []  # fetched rows of each result-producing statement
        # closing() releases the connection even when a statement fails
        # (sqlite3's own context manager only handles transactions).
        with closing(sqlite3.connect(":memory:")) as conn:
            cursor = conn.cursor()
            for raw_stmt in sqlparse.split(sql_code):
                stmt = raw_stmt.strip()
                # Comments are stripped only to classify and parse the
                # statement; the statement itself is executed unchanged.
                code = sqlparse.format(stmt, strip_comments=True).strip()
                if not code:
                    continue  # empty or comment-only fragment
                cursor.execute(stmt)
                if create_table_re.match(code):
                    table_name, columns = extract_table_info(code)
                    if table_name and columns:
                        _insert_generated_rows(cursor, table_name, columns, num_rows)
                elif cursor.description is not None:
                    # description is set for every statement that returns rows,
                    # even when no row matches.
                    result_sets.append(cursor.fetchall())

        if result_sets:
            output = "\n".join(f"查询结果: {rows}" for rows in result_sets)
        else:
            output = "SQL代码执行成功,但未检测到SELECT查询。"

        # Accept either the full report text or just the raw rows per query.
        if expected_output:
            accepted = {output.strip()}
            if result_sets:
                accepted.add("\n".join(str(rows) for rows in result_sets))
            if expected_output.strip() in accepted:
                output += "\n验证结果:查询结果与预期一致。"
            else:
                output += "\n验证结果:查询结果与预期不匹配。"
        return output
    except Exception as e:
        # Deliberately returned as text so the calling agent can read and fix it.
        return f"SQL代码执行出错,错误信息: {str(e)}"
# @tool turns the function into a smolagents Tool; CodeAgent expects Tool
# instances in its tools list, not plain functions.
@tool
def get_current_time_in_timezone(timezone: str) -> str:
    """A tool that fetches the current local time in a specified timezone.

    Args:
        timezone: A string representing a valid timezone (e.g., 'America/New_York').

    Returns:
        A sentence with the local time formatted as YYYY-MM-DD HH:MM:SS, or an error message.
    """
    try:
        tz = pytz.timezone(timezone)
        local_time = datetime.datetime.now(tz).strftime("%Y-%m-%d %H:%M:%S")
        return f"The current local time in {timezone} is: {local_time}"
    except Exception as e:
        # Returned as text (e.g. for an unknown timezone name) so the agent can react.
        return f"Error fetching time for timezone '{timezone}': {str(e)}"
# Tool the agent uses to deliver its final answer.
final_answer = FinalAnswerTool()

# Inference model. If the agent does not answer, the model may be overloaded:
# switch to another model, or to this Hugging Face Endpoint, which also serves
# Qwen2.5-Coder:
# model_id='https://pflgm2locj2t89co.us-east-1.aws.endpoints.huggingface.cloud'
model = HfApiModel(
    model_id='Qwen/Qwen2.5-Coder-32B-Instruct',
    max_tokens=2096,
    temperature=0.5,
    custom_role_conversions=None,
)

# Text-to-image tool loaded from the Hub. trust_remote_code=True lets code from
# that repository run here. NOTE(review): confirm the source is trusted.
image_generation_tool = load_tool("agents-course/text-to-image", trust_remote_code=True)

# Prompt templates, read from prompts.yaml relative to the working directory.
with open("prompts.yaml", 'r') as prompts_file:
    prompt_templates = yaml.safe_load(prompts_file)

agent_tools = [
    final_answer,  # original note: don't remove final answer
    sql_unit_test,
    get_current_time_in_timezone,
    image_generation_tool,
]

agent = CodeAgent(
    model=model,
    tools=agent_tools,
    prompt_templates=prompt_templates,
    max_steps=6,
    verbosity_level=1,
    planning_interval=None,
    grammar=None,
    name=None,
    description=None,
)

# Launch the Gradio UI (Gradio_UI.GradioUI) for the agent.
GradioUI(agent).launch()