"""SQL sandbox environment: statement validation plus the episode loop.

``validate_sql`` screens agent-supplied SQL before it reaches SQLite, and
``DataOpsEnv`` drives one episode at a time (reset / step / grade) on top of
the helpers imported from the ``app`` package.
"""

import asyncio
import random
import re
import sqlite3
import traceback
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Tuple

from app.models import Action, Observation, Reward, StateSnapshot
from app.state_manager import EpisodeState, generate_episode, get_schema_info, take_snapshot
from app.reward import RewardEngine
from app.tasks import get_task
from app.graders import grade_task1, grade_task2, grade_task3

# Limits applied to what is echoed back to the agent.
_MAX_DISPLAY_ROWS = 10      # rows of a query result included in an observation
_MAX_CELL_CHARS = 100       # characters kept per cell before "..." is appended
_MAX_SYSTEM_LOGS = 20       # error-log messages included in an observation

# A hint is offered once the agent is past this step and still below this score.
_HINT_AFTER_STEP = 8
_HINT_SCORE_THRESHOLD = 0.1

_INACTIVE_EPISODE_MSG = "Episode is not active. Call reset()."

_QUERY_ALLOWED_STARTS = frozenset({"SELECT", "WITH", "EXPLAIN", "PRAGMA"})
_QUERY_BLOCKED_PATTERNS = (
    r"DROP\s+TABLE",
    r"DELETE\s+FROM\s+SQLITE_MASTER",
    r"DROP\s+INDEX",
    r"ATTACH\s+DATABASE",
)
_DDL_ALLOWED_STARTS = frozenset({"UPDATE", "INSERT", "DELETE", "ALTER", "CREATE", "DROP"})

# SQLite accepts ``WITH cte AS (...) INSERT|UPDATE|DELETE ...``, so a statement
# that starts with WITH is not necessarily read-only.  This pattern spots the
# data-modifying main statement (matched against upper-cased SQL).
_CTE_WRITE_PATTERN = re.compile(
    r"\bINSERT\s+(?:OR\s+\w+\s+)?INTO\b"
    r"|\bREPLACE\s+INTO\b"
    r"|\bDELETE\s+FROM\b"
    r"|\bUPDATE\b(?:\s+\S+){1,8}?\s+SET\b"
)

# Plain identifier, optionally schema-qualified (e.g. ``temp.my_table``).
_IDENTIFIER_PATTERN = re.compile(r"[^\W\d]\w*(?:\.[^\W\d]\w*)?")


def _truncate_value(value: Any, max_len: int = _MAX_CELL_CHARS) -> Optional[str]:
    """Stringify *value* and cap it at *max_len* characters; ``None`` stays ``None``."""
    if value is None:
        return None
    text = str(value)
    return text[:max_len] + "..." if len(text) > max_len else text


def _quote_identifier(name: Any) -> str:
    """Return *name* as a double-quoted SQL identifier.

    Only plain (optionally schema-qualified) identifiers are accepted, so the
    result can be interpolated into a statement without letting the caller
    smuggle in sub-queries or extra clauses.

    Raises:
        ValueError: if *name* is not a string or is not a plain identifier.
    """
    if not isinstance(name, str) or not _IDENTIFIER_PATTERN.fullmatch(name.strip()):
        raise ValueError(f"Invalid table name: {name!r}")
    return ".".join(f'"{part}"' for part in name.strip().split("."))


def validate_sql(sql: str, action_type: str) -> Tuple[bool, str]:
    """Decide whether *sql* may be executed for the given *action_type*.

    Args:
        sql: The raw statement supplied by the agent.
        action_type: ``"query"`` (read-only) or ``"ddl"`` (restricted writes).
            Any other value is accepted without further checks.

    Returns:
        ``(True, "")`` when the statement is allowed, otherwise
        ``(False, reason)`` with a human-readable rejection reason.
    """
    if not sql:
        return False, "Empty SQL statement"

    sql_upper = sql.strip().upper()
    tokens = sql_upper.split()
    if not tokens:
        return False, "Empty SQL statement"
    first_token = tokens[0]

    if action_type == "query":
        if first_token not in _QUERY_ALLOWED_STARTS:
            return False, "Only SELECT statements allowed in query actions"
        for pat in _QUERY_BLOCKED_PATTERNS:
            if re.search(pat, sql_upper):
                clean_pat = pat.replace('\\s+', ' ')
                return False, f"Blocked pattern detected in query: {clean_pat}"
        # A CTE can front a write statement; keep query actions read-only.
        if first_token == "WITH" and _CTE_WRITE_PATTERN.search(sql_upper):
            return False, "Data-modifying statements are not allowed in query actions"
        return True, ""

    if action_type == "ddl":
        if re.search(r"\bDROP\s+TABLE\b", sql_upper):
            return False, "DROP TABLE is blocked"
        if re.search(r"\bATTACH\b|\bDETACH\b", sql_upper):
            return False, "ATTACH and DETACH are blocked"
        if re.search(
            r"(UPDATE|INSERT\s+INTO|DELETE\s+FROM)\s+(SQLITE_MASTER|SQLITE_SEQUENCE)\b",
            sql_upper,
        ):
            return False, "Writing to sqlite_master or sqlite_sequence is blocked"

        if first_token == "ALTER":
            if not re.search(r"^ALTER\s+TABLE\s+.*?\s+RENAME\s+COLUMN", sql_upper):
                return False, "Only ALTER TABLE ... RENAME COLUMN is allowed"

        if first_token == "CREATE":
            # A bare CREATE TABLE (no TEMP/TEMPORARY) gets its own message.
            if re.search(r"^CREATE\s+TABLE", sql_upper):
                return False, "Only temporary tables are allowed (CREATE TEMP TABLE)"
            is_temp_table = re.search(r"^CREATE\s+(TEMP|TEMPORARY)\s+TABLE", sql_upper)
            is_view = re.search(r"^CREATE\s+VIEW", sql_upper)
            if not (is_temp_table or is_view):
                return False, "Only CREATE VIEW or CREATE TEMP TABLE allowed for CREATE"

        if first_token == "DROP":
            if not re.search(r"^DROP\s+VIEW", sql_upper):
                return False, "Only DROP VIEW is allowed for DROP statements"

        if first_token not in _DDL_ALLOWED_STARTS:
            return False, f"DDL action does not allow '{first_token}' statements"
        return True, ""

    return True, ""


class DataOpsEnv:
    """One SQL-repair episode at a time: reset, step through actions, grade.

    All state-changing entry points (``reset`` and ``step``) serialize on an
    ``asyncio.Lock`` held by the instance.
    """

    def __init__(self):
        self.state: Optional[EpisodeState] = None
        self.reward_engine: Optional[RewardEngine] = None
        self.task_config: Optional[dict] = None
        self._lock = asyncio.Lock()
        self.last_activity = datetime.now(timezone.utc)
        # Grader score cached after the previous reset/step so that step()
        # does not have to re-grade the "before" state.
        self._last_grader_score = None

    async def reset(
        self,
        task_id: int,
        seed: Optional[int] = None,
        difficulty_multiplier: float = 1.0,
    ) -> Observation:
        """Start a fresh episode for *task_id* and return its first observation.

        Args:
            task_id: 1, 2 or 3.
            seed: Passed through to ``generate_episode``.
            difficulty_multiplier: Passed through to ``generate_episode``.

        Raises:
            ValueError: if *task_id* is not 1, 2 or 3.
        """
        async with self._lock:
            self.last_activity = datetime.now(timezone.utc)
            if task_id not in (1, 2, 3):
                raise ValueError("task_id must be 1, 2, or 3")

            self.state = generate_episode(task_id, seed, difficulty_multiplier)
            task_info = get_task(task_id)

            self.task_config = {
                "task_id": task_id,
            }
            main_table = self.state.table_registry.get("main")
            if task_id == 1:
                id_col = self.state.column_registry.get("id")
                rows = self.state.initial_snapshot.get(main_table, [])
                self.task_config["initial_null_count"] = sum(
                    1 for r in rows if r.get(id_col) is None
                )
            elif task_id == 2:
                rows = self.state.initial_snapshot.get(main_table, [])
                self.task_config["total_rows"] = len(rows)
                self.task_config["pii_columns"] = [
                    self.state.column_registry.get("email"),
                    self.state.column_registry.get("phone"),
                ]
                self.task_config["ssn_col"] = self.state.column_registry.get("ssn_col")
            elif task_id == 3:
                self.task_config["expected_view_output"] = True

            self.reward_engine = RewardEngine(self.task_config)
            system_logs = self._read_system_logs(task_id)
            self._last_grader_score = self.grader_score()

            return Observation(
                current_step=0,
                max_steps=self.state.max_steps,
                task_id=task_id,
                task_description=task_info["description"],
                last_action_status="NONE",
                last_error_message=None,
                query_results=[],
                results_truncated=False,
                total_rows_returned=0,
                schema_info=get_schema_info(self.state),
                system_logs=system_logs[:_MAX_SYSTEM_LOGS],
                logs_truncated=len(system_logs) > _MAX_SYSTEM_LOGS,
                progress_hint=None,
            )

    def _read_system_logs(self, task_id: int) -> List[Any]:
        """Return the ``msg`` values of the episode's error-log table (task 3 only).

        Best effort: any failure while reading the table yields an empty list
        rather than aborting the surrounding reset/step.
        """
        if task_id != 3:
            return []
        err_table = self.state.table_registry.get("error_log")
        if not err_table:
            return []
        try:
            cursor = self.state.db.cursor()
            # err_table comes from the episode's own registry, not from the agent.
            cursor.execute(f"SELECT msg FROM {err_table}")
            return [r["msg"] for r in cursor.fetchall()]
        except Exception:
            # Deliberately swallowed: the agent may have broken or dropped
            # the log source, and logs are only auxiliary information.
            return []

    def grader_score(self) -> float:
        """Grade the current database with the grader matching the task (0.0 if no episode)."""
        if not self.state:
            return 0.0
        if self.state.task_id == 1:
            return grade_task1(self.state.db, self.state)
        if self.state.task_id == 2:
            return grade_task2(self.state.db, self.state)
        if self.state.task_id == 3:
            return grade_task3(self.state.db, self.state)
        return 0.0

    def get_state(self) -> StateSnapshot:
        """Snapshot the episode: tables, trajectory and a freshly computed grader score.

        Raises:
            ValueError: if no episode has been started.
        """
        if not self.state:
            raise ValueError("Environment not initialized")
        tables = take_snapshot(self.state)
        return StateSnapshot(
            episode_id=self.state.episode_id,
            task_id=self.state.task_id,
            current_step=self.state.current_step,
            tables=tables,
            trajectory=self.state.trajectory,
            grader_score=self.grader_score(),
            seed=self.state.seed,
            difficulty_multiplier=self.state.difficulty_multiplier,
        )

    async def step(self, action: Action, session_id: str = "") -> Tuple[Observation, Reward]:
        """Apply one agent action and return ``(observation, reward)``.

        This method does not raise for agent-triggered problems: an inactive
        episode, SQL errors and unexpected failures are all reported as an
        ERROR observation with a small negative reward.

        Args:
            action: An ``Action`` model (or a mapping with the same keys).
            session_id: Accepted for interface compatibility; not used here.
        """
        async with self._lock:
            self.last_activity = datetime.now(timezone.utc)

            # Report the real reason instead of letting it fall into the
            # generic "Internal error" handler below.
            if not self.state or self.state.done:
                return self._error_observation(
                    error_msg=_INACTIVE_EPISODE_MSG,
                    reward_penalty=-0.05
                ), self._error_reward(breakdown={"internal_error": -0.05})

            try:
                return self._step_active_episode(action)
            except sqlite3.OperationalError as e:
                # SQL syntax errors, missing tables, broken views
                return self._error_observation(
                    error_msg=f"SQL error: {str(e)}",
                    reward_penalty=-0.05
                ), self._error_reward(breakdown={"sql_error": -0.05})
            except sqlite3.DatabaseError as e:
                # Corrupted state, PRAGMA failures, trigger issues
                return self._error_observation(
                    error_msg=f"Database error: {str(e)}",
                    reward_penalty=-0.10
                ), self._error_reward(breakdown={"db_error": -0.10})
            except Exception:
                # Catch-all for unknown agent-triggered edge cases.  The
                # traceback is kept in the trajectory for debugging but is
                # never exposed to the agent.
                internal_log = traceback.format_exc()
                if self.state:
                    self.state.trajectory.append({
                        "step": self.state.current_step,
                        "internal_error": internal_log[:500]
                    })
                return self._error_observation(
                    error_msg="Internal error — action could not be processed",
                    reward_penalty=-0.05
                ), self._error_reward(breakdown={"internal_error": -0.05})

    def _step_active_episode(self, action: Action) -> Tuple[Observation, Reward]:
        """Run *action* against the active episode; the caller holds the lock."""
        score_before = self._last_grader_score
        if score_before is None:
            score_before = self.grader_score()

        try:
            action_dict = action.model_dump()
        except AttributeError:
            # Not a pydantic model: accept a plain mapping as well.
            action_dict = action if isinstance(action, dict) else dict(action)

        action_type = getattr(action, "action_type", action_dict.get("action_type"))
        sql = getattr(action, "sql", action_dict.get("sql", ""))

        # Must be captured before the action touches the database.
        state_before = self.get_state().model_dump()
        action_result = self._execute_action(action, action_dict, action_type, sql)

        score_after = self.grader_score()
        self._last_grader_score = score_after
        state_after = self.get_state().model_dump()
        state_after["grader_score"] = score_after

        step_reward_val, breakdown = self.reward_engine.compute(
            action=action_dict,
            action_result=action_result,
            state_before=state_before,
            state_after=state_after,
            grader_score_before=score_before,
            grader_score_after=score_after
        )

        truncated = False
        if self.state.current_step >= self.state.max_steps:
            truncated = True
            self.state.done = True

        task_info = get_task(self.state.task_id)
        progress_hint = None
        if self.state.current_step > _HINT_AFTER_STEP and score_after < _HINT_SCORE_THRESHOLD:
            hints = task_info.get("hints", [])
            progress_hint = (
                random.choice(hints) if hints else "Review the schema and target carefully."
            )

        system_logs = self._read_system_logs(self.state.task_id)

        obs = Observation(
            current_step=self.state.current_step,
            max_steps=self.state.max_steps,
            task_id=self.state.task_id,
            task_description=task_info["description"],
            last_action_status=action_result["status"],
            last_error_message=action_result["error_message"],
            query_results=action_result["rows"],
            results_truncated=action_result.get("results_truncated", False),
            total_rows_returned=action_result.get("total_rows_returned", 0),
            schema_info=get_schema_info(self.state),
            system_logs=system_logs[:_MAX_SYSTEM_LOGS],
            logs_truncated=len(system_logs) > _MAX_SYSTEM_LOGS,
            progress_hint=progress_hint
        )
        reward = Reward(
            step_reward=step_reward_val,
            cumulative_reward=self.reward_engine.cumulative,
            reward_breakdown=breakdown,
            done=self.state.done,
            truncated=truncated,
            grader_score_before=score_before,
            grader_score_after=score_after
        )

        self.state.trajectory.append({
            "action": action_dict,
            "observation": obs.model_dump(),
            "reward": reward.model_dump()
        })
        return obs, reward

    def _execute_action(
        self,
        action: Any,
        action_dict: Dict[str, Any],
        action_type: Any,
        sql: Any,
    ) -> Dict[str, Any]:
        """Validate and execute one action, returning its result record.

        A statement rejected by ``validate_sql`` does not consume a step; any
        action that passes validation does, even if executing it then fails.
        """
        result: Dict[str, Any] = {
            "status": "SUCCESS",
            "error_message": None,
            "rows": [],
            "results_truncated": False,
            "total_rows_returned": 0
        }

        if action_type in ("query", "ddl"):
            is_valid, val_msg = validate_sql(sql, action_type)
            if not is_valid:
                result["status"] = "ERROR"
                result["error_message"] = val_msg
                return result

        self.state.current_step += 1
        try:
            cursor = self.state.db.cursor()
            if action_type == "query":
                cursor.execute(sql)
                all_rows = cursor.fetchall()
                col_names = [d[0] for d in cursor.description] if cursor.description else []
                result["rows"] = [
                    {col: _truncate_value(val) for col, val in zip(col_names, row)}
                    for row in all_rows[:_MAX_DISPLAY_ROWS]
                ]
                result["results_truncated"] = len(all_rows) > _MAX_DISPLAY_ROWS
                result["total_rows_returned"] = len(all_rows)
            elif action_type == "ddl":
                cursor.execute(sql)
                self.state.db.commit()
            elif action_type == "test":
                target_table = getattr(action, "target_table", action_dict.get("target_table"))
                # target_table is agent-supplied: validate and quote it
                # rather than splicing it into the statement verbatim.
                cursor.execute(
                    f"SELECT COUNT(*) as cnt FROM {_quote_identifier(target_table)}"
                )
                result["rows"] = [dict(r) for r in cursor.fetchall()]
            elif action_type == "submit":
                self.state.done = True
        except Exception as e:
            # Execution failures are reported to the agent, not raised.
            result["status"] = "ERROR"
            result["error_message"] = str(e)
        return result

    def _error_observation(self, error_msg: str, reward_penalty: float) -> Observation:
        """Build a minimal ERROR observation carrying *error_msg*.

        ``reward_penalty`` is accepted for call-site symmetry with
        ``_error_reward`` but is not used when building the observation.
        """
        return Observation(
            current_step=self.state.current_step if self.state else 0,
            max_steps=self.state.max_steps if self.state else 20,
            task_id=self.state.task_id if self.state else 0,
            task_description="",
            last_action_status="ERROR",
            last_error_message=error_msg,
            query_results=[],
            schema_info={},
            system_logs=[f"ERROR: {error_msg}"],
            logs_truncated=False,
            results_truncated=False,
            total_rows_returned=0,
            progress_hint=None
        )

    def _error_reward(self, breakdown: dict) -> Reward:
        """Build the penalty reward for a failed step.

        The step reward is the sum of *breakdown*; ``done`` mirrors the
        episode's real state so callers looping on it stop once it is over.
        """
        step_reward = sum(breakdown.values())
        # NOTE(review): the normal path reports ``reward_engine.cumulative``
        # while this path tracks ``state.cumulative_reward`` — confirm that
        # EpisodeState defines this attribute and the two totals are meant to
        # be separate.
        if self.state:
            self.state.cumulative_reward += step_reward
        return Reward(
            step_reward=step_reward,
            cumulative_reward=self.state.cumulative_reward if self.state else step_reward,
            reward_breakdown=breakdown,
            done=bool(self.state.done) if self.state else False,
            truncated=False,
            grader_score_before=0.0,
            grader_score_after=0.0
        )