Spaces:
Sleeping
Sleeping
| import json | |
| import os | |
| from typing import Annotated, Dict | |
| from spider_env import SpiderEnv | |
class SQLExec:
    """Drive a Spider text-to-SQL episode and expose a SQL-execution tool.

    The environment is ``spider_env.SpiderEnv`` (install with
    ``pip install spider-env``). ``sql_spider`` starts an episode on a random
    Spider question; ``sql_exec`` builds the tool function that runs a
    candidate SQL query against that episode.
    """

    def __init__(self, sql_writer, user_proxy, cache_dir='./.cache'):
        """Create the Spider environment and keep references to the agents.

        Args:
            sql_writer: Stored as ``self.sql_writer``; not used by this class
                itself (presumably the agent that writes SQL and that the
                tool is registered with -- confirm against the caller).
            user_proxy: Stored as ``self.user_proxy``; not used by this class
                itself (presumably the agent that executes the tool --
                confirm against the caller).
            cache_dir: Cache directory passed to ``SpiderEnv``. Defaults to
                ``'./.cache'``.
        """
        self.gym = SpiderEnv(cache_dir=cache_dir)
        self.sql_writer = sql_writer
        self.user_proxy = user_proxy

    def sql_spider(self):
        """Reset the environment to a randomly selected Spider question.

        Prints the natural-language question and then the schema of the
        corresponding database.

        Returns:
            A ``(question, schema)`` tuple, taken from
            ``observation["instruction"]`` and ``info["schema"]``.
        """
        # reset() randomly selects a question from Spider.
        observation, info = self.gym.reset()
        question = observation["instruction"]
        print(question)
        schema = info["schema"]
        print(schema)
        return question, schema

    def sql_exec(self):
        """Build the tool function that executes SQL against the environment.

        Returns:
            The ``execute_sql`` callable, so the caller can invoke it or
            register it with its agents. The callable closes over
            ``self.gym``.
        """

        # execute_sql(reflection, sql) runs `sql` as one environment step.
        # `reflection` is accepted but not used here (presumably a
        # chain-of-thought slot for the calling LLM -- confirm).
        # Returns {"result": ...} on success. Otherwise it returns
        # {"error": ..., "wrong_result": ..., "correct_result": ...}, which
        # includes the gold result from the environment. The Annotated
        # description strings below are runtime metadata (presumably the tool
        # schema shown to the LLM), so they are kept verbatim even though the
        # return description omits the extra error keys.
        def execute_sql(
            reflection: Annotated[str, "Think about what to do"], sql: Annotated[str, "SQL query"]
        ) -> Annotated[Dict[str, str], "Dictionary with keys 'result' and 'error'"]:
            # NOTE(review): presumably requires sql_spider() (reset) to have
            # been called first -- confirm against spider_env.
            observation, reward, _, _, info = self.gym.step(sql)
            feedback = observation["feedback"]
            error = feedback["error"]
            # A query that runs without error but earns zero reward is
            # reported as an incorrect result.
            if not error and reward == 0:
                error = "The SQL query returned an incorrect result"
            if error:
                return {
                    "error": error,
                    "wrong_result": feedback["result"],
                    "correct_result": info["gold_result"],
                }
            return {
                "result": feedback["result"],
            }

        return execute_sql