Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import json | |
| import logging | |
| import os | |
| import re | |
| import shutil | |
| import sqlite3 | |
| import textwrap | |
| import threading | |
| import time | |
| import uuid | |
| from copy import deepcopy | |
| from typing import Any, Optional | |
| from openenv.core.env_server import Environment | |
| from pydantic import ValidationError | |
| from data.init_db import WORKSPACE_ROOT, setup_workspace | |
| from models import ( | |
| PAYLOAD_MODELS, | |
| DataOpsAction, | |
| DataOpsObservation, | |
| DataOpsState, | |
| ExecuteSQLPayload, | |
| ReadFilePayload, | |
| RunScriptPayload, | |
| SendEmailPayload, | |
| WriteFilePayload, | |
| ) | |
| from .safe_exec import PythonRunResult, run_python_code, run_python_script | |
| from .task_specs import ( | |
| TASK_ALLOWED_READ_FILES, | |
| TASK_ALLOWED_RUN_FILES, | |
| TASK_ALLOWED_WRITE_FILES, | |
| TASK_EMAIL_ENABLED, | |
| TASK_IDS, | |
| TASK_SQL_POLICIES, | |
| TaskScenarioBundle, | |
| build_task_scenario, | |
| normalize_task_3_rows, | |
| report_matches_expected, | |
| task_3_data_matches_expected, | |
| ) | |
| logger = logging.getLogger(__name__) | |
| _SQL_COMMENT_RE = re.compile(r"(--[^\n]*|/\*.*?\*/)", re.DOTALL) | |
| _SQL_STRING_RE = re.compile(r"'(?:''|[^'])*'|\"(?:\"\"|[^\"])*\"") | |
| _SQL_TABLE_REF_RE = re.compile( | |
| r"\b(?:from|join|update|into|delete\s+from)\s+([a-zA-Z_][a-zA-Z0-9_]*)", | |
| re.IGNORECASE, | |
| ) | |
| _SQL_CTE_NAME_RE = re.compile( | |
| r"(?:\bwith\b|,)\s*([a-zA-Z_][a-zA-Z0-9_]*)\s+as\s*\(", | |
| re.IGNORECASE, | |
| ) | |
# Episode length cap: the episode is marked done once step_count reaches this.
MAX_STEPS = 15
# SELECT/WITH results with more rows than this are rejected with an error.
MAX_SQL_ROWS = 500
# Limit applied when reading workspace files (size check and read length).
MAX_FILE_SIZE = 1_000_000
# Timeout (seconds) used when the caller does not supply one.
DEFAULT_ACTION_TIMEOUT_S = 10.0
# Caller-supplied timeouts are clamped to at most this many seconds.
MAX_ACTION_TIMEOUT_S = 30.0
# Limits applied to captured script stdout / stderr.
MAX_STDOUT_CHARS = 50_000
MAX_STDERR_CHARS = 10_000
# Added to the step reward when the observation status is not "success".
PENALTY_FAILURE = -0.03
# Value registered below for the "destructive_sql" event.
PENALTY_DESTRUCTIVE = -0.20
# Added when an action repeats the previous action (same type and payload).
PENALTY_REPEAT = -0.08
# Unit penalty for a disallowed tool call; multiplied by the running attempt
# count (capped at 12) when the reward is computed.
PENALTY_DISALLOWED_TOOL_UNIT = -0.04
# Keep milestone bonuses small so the terminal grader remains the dominant signal.
# Each milestone is paid at most once per episode.
REWARD_EVENT_VALUES = {
    "t1_inspected_corruption": 0.05,
    "t1_exact_cleanup": 0.04,
    "t2_read_source": 0.04,
    "t2_candidate_compiles": 0.02,
    "t2_verified_fix": 0.03,
    "t3_nonempty_select": 0.03,
    "t3_matching_sql": 0.03,
    "t3_read_formatter_source": 0.02,
    "t3_report_data_verified": 0.03,
    "t3_formatter_compiles": 0.02,
    "t3_report_generated": 0.03,
    "t3_email_verified": 0.02,
}
# Events that subtract from the reward every time they are recorded.
PENALTY_EVENTS = {
    "destructive_sql": PENALTY_DESTRUCTIVE,
    "multiple_emails": -0.08,
    "t2_run_before_read": -0.05,
    "t2_write_before_read": -0.05,
}
| class DataOpsEnvironment(Environment[DataOpsAction, DataOpsObservation, DataOpsState]): | |
| """Enterprise data pipeline remediation environment (OpenEnv-compliant).""" | |
| SUPPORTS_CONCURRENT_SESSIONS = True | |
def __init__(self) -> None:
    """Prepare per-session paths and empty episode bookkeeping.

    The workspace directory name is a fresh random hex id under
    ``WORKSPACE_ROOT/sessions``. A placeholder scenario for
    ``task_1_easy_anomaly`` (seed 0) is built; ``reset`` replaces it.
    """
    session_dir = os.path.join(WORKSPACE_ROOT, "sessions", uuid.uuid4().hex)
    self._workspace_dir = session_dir
    self._db_path = os.path.join(session_dir, "mock_warehouse.db")

    # Episode state and the scenario it is graded against.
    self._state = DataOpsState()
    self._scenario: TaskScenarioBundle = build_task_scenario(
        "task_1_easy_anomaly", seed=0
    )

    # Reward-shaping bookkeeping; all of it is reinitialised by reset().
    self._evidence: dict[str, Any] = {}
    self._pending_events: list[str] = []
    self.email_outbox: list[dict[str, str]] = []
    self._last_action_key: Optional[str] = None
    self._milestones: set[str] = set()
    self._grader_score = 0.0
    self._disallowed_tool_attempts = 0

    # Serialises reset() and step() on this instance.
    self._lock = threading.Lock()
def reset(
    self,
    seed: Optional[int] = None,
    episode_id: Optional[str] = None,
    **kwargs: Any,
) -> DataOpsObservation:
    """Start a new episode for the task named by ``kwargs['task_id']``.

    Args:
        seed: Forwarded to ``build_task_scenario``.
        episode_id: Identifier for the episode; a random UUID string is
            used when it is falsy.
        **kwargs: Only ``task_id`` is read (default ``task_1_easy_anomaly``).

    Returns:
        A success observation whose message carries the task description.

    Raises:
        ValueError: If the task id is not in ``TASK_IDS``.
    """
    requested_task: str = kwargs.get("task_id", "task_1_easy_anomaly")
    if requested_task not in TASK_IDS:
        raise ValueError(f"Unknown task_id: {requested_task}")

    with self._lock:
        # Build the scenario, then materialise its workspace.
        self._scenario = build_task_scenario(requested_task, seed=seed)
        self._db_path = setup_workspace(
            self._workspace_dir,
            scenario=self._scenario,
        )

        # Wipe everything carried over from the previous episode.
        self.email_outbox.clear()
        self._last_action_key = None
        self._milestones.clear()
        self._pending_events = []
        self._disallowed_tool_attempts = 0
        self._evidence = self._initial_evidence()

        fresh_episode_id = episode_id or str(uuid.uuid4())
        self._state = DataOpsState(
            episode_id=fresh_episode_id,
            step_count=0,
            task_id=requested_task,
            task_description=self._scenario.description,
            max_steps=MAX_STEPS,
            seed=self._scenario.seed,
        )
        # Baseline score, so the first step's reward is a delta from here.
        self._grader_score = self._current_task_score()

        welcome = f"Environment reset. Task: {self._scenario.description}"
        return DataOpsObservation(
            status="success",
            done=False,
            reward=0.0,
            message=welcome,
            step_count=0,
            max_steps=MAX_STEPS,
        )
def step(
    self, action: DataOpsAction, timeout_s: Optional[float] = None, **kwargs: Any
) -> DataOpsObservation:
    """Run one action while holding the instance lock.

    Extra keyword arguments are accepted and discarded; the work itself is
    done by ``_step_locked``.
    """
    del kwargs
    self._lock.acquire()
    try:
        return self._step_locked(action, timeout_s)
    finally:
        self._lock.release()
def state(self) -> DataOpsState:
    """Return a copy of the current episode state (made with ``model_copy``)."""
    snapshot = self._state.model_copy()
    return snapshot
def scenario(self) -> TaskScenarioBundle:
    """Return the scenario bundle for the active episode (not a copy)."""
    active_scenario = self._scenario
    return active_scenario
def evidence(self) -> dict[str, Any]:
    """Return a deep copy of the evidence dict so callers cannot mutate it."""
    evidence_copy = deepcopy(self._evidence)
    return evidence_copy
def workspace_dir(self) -> str:
    """Return the path of this session's workspace directory."""
    session_dir = self._workspace_dir
    return session_dir
def db_path(self) -> str:
    """Return the path of the SQLite database used by this session."""
    database_file = self._db_path
    return database_file
def close(self) -> None:
    """Delete the session workspace directory if it exists.

    Removal errors are ignored (``ignore_errors=True``).
    """
    workspace = self._workspace_dir
    if not os.path.isdir(workspace):
        return
    shutil.rmtree(workspace, ignore_errors=True)
def _step_locked(
    self, action: DataOpsAction, timeout_s: Optional[float]
) -> DataOpsObservation:
    """Validate, dispatch and score a single action.

    Early error returns (finished episode, unknown action type, invalid
    payload) do not advance the step counter or touch the reward. Otherwise
    the handler's observation is returned with reward, done flag and step
    counters filled in.
    """
    if self._state.done:
        return self._obs(
            "error", "Episode is over. Call /reset to start a new one.", done=True
        )

    payload_model = PAYLOAD_MODELS.get(action.action_type)
    if not payload_model:
        return self._obs("error", f"Unknown action_type: {action.action_type}")

    try:
        payload = payload_model(**action.payload)
    except ValidationError as exc:
        problem_count = exc.error_count()
        return self._obs(
            "error",
            f"Invalid payload: {problem_count} validation error(s).",
        )

    # Handlers queue reward events here; start each step with a clean list.
    self._pending_events = []
    observation = self._dispatch(action.action_type, payload, timeout_s)
    step_reward = self._compute_reward(action, observation)

    state = self._state
    state.step_count += 1
    state.cumulative_reward += step_reward
    state.actions_taken.append(action.action_type)
    state.emails_sent = len(self.email_outbox)

    # The step budget is checked first; completion is only evaluated if needed.
    finished = state.step_count >= MAX_STEPS or self._task_completed()
    state.done = finished

    observation.reward = round(step_reward, 4)
    observation.done = finished
    observation.step_count = state.step_count
    observation.max_steps = MAX_STEPS
    return observation
| def _dispatch( | |
| self, action_type: str, payload: Any, timeout_s: Optional[float] | |
| ) -> DataOpsObservation: | |
| handlers = { | |
| "ExecuteSQL": self._handle_sql, | |
| "ReadFile": self._handle_read, | |
| "WriteFile": self._handle_write, | |
| "RunScript": self._handle_run, | |
| "SendEmail": self._handle_email, | |
| } | |
| return handlers[action_type](payload, timeout_s) | |
def _handle_sql(
    self, payload: ExecuteSQLPayload, timeout_s: Optional[float]
) -> DataOpsObservation:
    """Validate and execute one SQL statement against the session database.

    Trailing semicolons are stripped, the statement is checked by
    ``_validate_sql_action``, and execution is bounded by a deadline that is
    enforced through a SQLite progress handler.

    Args:
        payload: Carries the raw query text in ``payload.query``.
        timeout_s: Optional time budget; resolved by ``_resolve_timeout``.

    Returns:
        For SELECT/WITH, a success observation with the rows (or an error if
        more than ``MAX_SQL_ROWS`` rows come back). For other statements, a
        success observation with the affected row count. SQLite errors are
        turned into error observations.
    """
    # Strip surrounding whitespace and any run of trailing semicolons.
    query = payload.query.strip()
    while query.endswith(";"):
        query = query[:-1].rstrip()

    statement_type = self._statement_type(query)
    validation_error = self._validate_sql_action(query, statement_type)
    if validation_error:
        return self._obs("error", validation_error)

    timeout = self._resolve_timeout(timeout_s)
    deadline = time.monotonic() + timeout
    conn: Optional[sqlite3.Connection] = None
    try:
        conn = sqlite3.connect(self._db_path)
        # `with conn` only scopes the transaction (commit on success,
        # rollback on error); it does NOT close the connection, so the
        # `finally` below closes it explicitly.
        with conn:
            # A non-zero return from the handler interrupts the statement.
            conn.set_progress_handler(
                lambda: 1 if time.monotonic() >= deadline else 0,
                1_000,
            )
            conn.row_factory = sqlite3.Row
            cursor = conn.cursor()
            cursor.execute(query)
            if statement_type in {"SELECT", "WITH"}:
                cols = [c[0] for c in cursor.description or []]
                # Fetch one extra row to detect an over-limit result.
                rows_raw = cursor.fetchmany(MAX_SQL_ROWS + 1)
                if len(rows_raw) > MAX_SQL_ROWS:
                    return self._obs(
                        "error",
                        f"Result exceeds {MAX_SQL_ROWS} rows. Add a LIMIT clause.",
                    )
                rows = [dict(zip(cols, row)) for row in rows_raw]
                self._record_sql_select(query, rows)
                return DataOpsObservation(
                    status="success",
                    sql_results=rows,
                    message=f"Query returned {len(rows)} rows.",
                )
            conn.commit()
            self._record_sql_mutation(query, cursor.rowcount)
            return self._obs("success", f"Rows affected: {cursor.rowcount}")
    except sqlite3.Error as exc:
        # The progress handler aborting a statement surfaces as "interrupted".
        if "interrupted" in str(exc).lower():
            return self._obs(
                "error", f"SQL execution timed out ({timeout:.1f}s limit)."
            )
        logger.warning("SQL error: %s", exc)
        msg = "SQL execution error. Check your query syntax."
        if self._state.task_id == "task_3_hard_e2e" and re.search(
            r"\bdate\b", query, re.IGNORECASE
        ):
            if "report_date" not in query.lower():
                msg += " Hint: table `daily_reports` uses column `report_date` for the calendar date."
        return self._obs("error", msg)
    finally:
        if conn is not None:
            conn.close()
def _handle_read(
    self, payload: ReadFilePayload, timeout_s: Optional[float]
) -> DataOpsObservation:
    """Return the text of an allow-listed workspace file.

    Only the basename of ``payload.filepath`` is used. The file must be in
    the task's read allow-list, resolve inside the workspace, exist, and be
    no larger than ``MAX_FILE_SIZE``. Reading certain files records task
    evidence and a milestone event.

    Returns:
        A success observation with the content in ``stdout``, or an error
        observation describing why the read was refused or failed.
    """
    del timeout_s
    basename = os.path.basename(payload.filepath)
    if not self._is_allowed_file(TASK_ALLOWED_READ_FILES, basename):
        return self._obs(
            "error", f"Reading {basename} is not allowed for this task."
        )
    safe_path = self._resolve_workspace_path(basename)
    if safe_path is None:
        return self._obs("error", "Resolved file path escapes the workspace.")
    if not os.path.isfile(safe_path):
        return self._obs("error", f"File not found: {basename}")
    if os.path.getsize(safe_path) > MAX_FILE_SIZE:
        return self._obs("error", "File too large to read.")
    try:
        with open(safe_path, encoding="utf-8") as f:
            content = f.read(MAX_FILE_SIZE)
    except (OSError, UnicodeDecodeError):
        # A file that is not valid UTF-8 is reported like any other read
        # failure instead of propagating out of the step.
        return self._obs("error", "Failed to read file.")

    if (
        self._state.task_id == "task_2_medium_syntax"
        and basename == "broken_pipeline.py"
    ):
        self._evidence["task_2"]["read_source"] = True
        self._record_event("t2_read_source")
    if self._state.task_id == "task_3_hard_e2e" and basename == "format_report.py":
        self._evidence["task_3"]["read_formatter_source"] = True
        self._record_event("t3_read_formatter_source")
    return DataOpsObservation(
        status="success",
        stdout=content,
        message=f"Read {len(content)} chars from {basename}",
    )
def _handle_write(
    self, payload: WriteFilePayload, timeout_s: Optional[float]
) -> DataOpsObservation:
    """Write ``payload.content`` to an allow-listed workspace file.

    Only the basename of ``payload.filepath`` is used. For task 2, writing
    ``broken_pipeline.py`` before it has been read queues the
    ``t2_write_before_read`` penalty event. After a successful write the
    content is passed to ``_record_write_evidence``.

    Returns:
        A success observation with the number of characters written, or an
        error observation if the write was refused or failed.
    """
    del timeout_s
    basename = os.path.basename(payload.filepath)
    if not self._is_allowed_file(TASK_ALLOWED_WRITE_FILES, basename):
        return self._obs(
            "error", f"Writing {basename} is not allowed for this task."
        )
    if (
        self._state.task_id == "task_2_medium_syntax"
        and basename == "broken_pipeline.py"
    ):
        if not self._evidence["task_2"]["read_source"]:
            self._pending_events.append("t2_write_before_read")
    safe_path = self._resolve_workspace_path(basename)
    if safe_path is None:
        return self._obs("error", "Resolved file path escapes the workspace.")
    try:
        with open(safe_path, "w", encoding="utf-8") as f:
            f.write(payload.content)
    except (OSError, UnicodeEncodeError):
        # Text that cannot be encoded as UTF-8 (e.g. lone surrogates) is
        # reported as a failed write instead of propagating out of the step.
        return self._obs("error", "Failed to write file.")
    self._record_write_evidence(basename, payload.content)
    return self._obs("success", f"Wrote {len(payload.content)} chars to {basename}")
def _handle_run(
    self, payload: RunScriptPayload, timeout_s: Optional[float]
) -> DataOpsObservation:
    """Execute an allow-listed workspace script via ``run_python_script``.

    Only the basename of ``payload.filepath`` is used. For task 2, running
    ``broken_pipeline.py`` before it has been read queues the
    ``t2_run_before_read`` penalty event. The observation status is
    ``success`` only when the exit code is 0.
    """
    script_name = os.path.basename(payload.filepath)
    if not self._is_allowed_file(TASK_ALLOWED_RUN_FILES, script_name):
        return self._obs(
            "error", f"Executing {script_name} is not allowed for this task."
        )

    resolved_script = self._resolve_workspace_path(script_name)
    if resolved_script is None:
        return self._obs("error", "Resolved script path escapes the workspace.")
    if not os.path.isfile(resolved_script):
        return self._obs("error", f"Script not found: {script_name}")

    targets_task_2_script = (
        self._state.task_id == "task_2_medium_syntax"
        and script_name == "broken_pipeline.py"
    )
    if targets_task_2_script and not self._evidence["task_2"]["read_source"]:
        self._pending_events.append("t2_run_before_read")

    budget = self._resolve_timeout(timeout_s)
    try:
        outcome = run_python_script(
            script_name,
            cwd=self._workspace_dir,
            args=list(payload.args),
            timeout_s=budget,
            stdout_limit=MAX_STDOUT_CHARS,
            stderr_limit=MAX_STDERR_CHARS,
        )
    except OSError:
        return self._obs("error", "Failed to execute script.")

    if outcome.timed_out:
        return self._obs("error", f"Script timed out ({budget:.1f}s limit).")

    self._record_run_evidence(script_name, payload.args, outcome)
    captured_out = (outcome.stdout or "")[:MAX_STDOUT_CHARS]
    captured_err = (outcome.stderr or "")[:MAX_STDERR_CHARS]
    return DataOpsObservation(
        status="success" if outcome.returncode == 0 else "error",
        stdout=captured_out,
        stderr=captured_err,
        message=f"Exit code: {outcome.returncode}",
    )
def _handle_email(
    self, payload: SendEmailPayload, timeout_s: Optional[float]
) -> DataOpsObservation:
    """Queue an email in the in-memory outbox, if the task allows email.

    For tasks outside ``TASK_EMAIL_ENABLED`` the attempt is counted, a
    ``disallowed_tool`` event is queued and an error observation is returned.

    NOTE(review): the refusal message names tools as read_file / write_file /
    invoke_python, while this class dispatches on ReadFile / WriteFile /
    RunScript — confirm the wording matches what the agent actually sees.
    """
    del timeout_s
    email_allowed = self._state.task_id in TASK_EMAIL_ENABLED
    if not email_allowed:
        self._disallowed_tool_attempts += 1
        self._pending_events.append("disallowed_tool")
        return self._obs(
            "error",
            "Email is not available for this task. Use read_file, write_file, and invoke_python only.",
        )

    recipient = payload.to_email
    message = {
        "to_email": recipient,
        "subject": payload.subject,
        "body": payload.body,
    }
    self.email_outbox.append(message)
    self._record_email_evidence(message)
    return DataOpsObservation(
        status="success",
        email_delivery_status=f"Queued for {recipient}",
        message=f"Email queued for delivery to {recipient}",
    )
def _compute_reward(self, action: DataOpsAction, obs: DataOpsObservation) -> float:
    """Compute the reward for the step that produced ``obs``.

    The reward is the change in grader score since the previous step, plus
    a failure penalty for non-success observations, a repeat penalty when
    the action equals the previous one, and the value of every event queued
    in ``_pending_events`` during the step.
    """
    new_score = self._current_task_score()
    total = new_score - self._grader_score
    self._grader_score = new_score

    if obs.status != "success":
        total += PENALTY_FAILURE

    # Canonical fingerprint of the action, used to detect immediate repeats.
    serialized_payload = json.dumps(
        action.payload, sort_keys=True, ensure_ascii=True
    )
    fingerprint = f"{action.action_type}:{serialized_payload}"
    if fingerprint == self._last_action_key:
        total += PENALTY_REPEAT
    self._last_action_key = fingerprint

    for event in self._pending_events:
        if event == "disallowed_tool":
            # Escalates with the number of attempts so far, capped at 12 units.
            total += PENALTY_DISALLOWED_TOOL_UNIT * min(
                self._disallowed_tool_attempts, 12
            )
        elif event in PENALTY_EVENTS:
            total += PENALTY_EVENTS[event]
        else:
            total += self._award_event(event)
    return total
def _award_event(self, event: str) -> float:
    """Return the bonus for ``event`` the first time it is seen, else 0.0.

    The event is remembered in ``_milestones``; events without an entry in
    ``REWARD_EVENT_VALUES`` are remembered too and are worth 0.0.
    """
    if event not in self._milestones:
        self._milestones.add(event)
        return REWARD_EVENT_VALUES.get(event, 0.0)
    return 0.0
| def _initial_evidence(self) -> dict[str, Any]: | |
| return { | |
| "task_1": { | |
| "inspected_corrupted_rows": False, | |
| "exact_cleanup": False, | |
| "destructive_sql_attempted": False, | |
| }, | |
| "task_2": { | |
| "read_source": False, | |
| "candidate_compiles": False, | |
| "verified_fix": False, | |
| }, | |
| "task_3": { | |
| "matching_sql_executed": False, | |
| "last_matching_sql_rows": [], | |
| "read_formatter_source": False, | |
| "report_data_matches_sql": False, | |
| "formatter_compiles": False, | |
| "format_output_matches_expected": False, | |
| "last_formatter_output": "", | |
| "email_matches_formatter_output": False, | |
| "single_email_sent": True, | |
| }, | |
| } | |
| def _record_sql_select(self, query: str, rows: list[dict[str, Any]]) -> None: | |
| if self._scenario.task_1 and self._state.task_id == "task_1_easy_anomaly": | |
| row_ids = {int(row.get("id")) for row in rows if row.get("id") is not None} | |
| corrupted = set(self._scenario.task_1.corrupted_row_ids) | |
| if row_ids & corrupted: | |
| self._evidence["task_1"]["inspected_corrupted_rows"] = True | |
| self._record_event("t1_inspected_corruption") | |
| if self._scenario.task_3 and self._state.task_id == "task_3_hard_e2e": | |
| normalised_rows = normalize_task_3_rows(rows, require_headcount=True) | |
| expected_rows = list(self._scenario.task_3.expected_rows) | |
| if task_3_data_matches_expected( | |
| normalised_rows, | |
| expected_rows, | |
| require_headcount=True, | |
| ): | |
| self._evidence["task_3"]["matching_sql_executed"] = True | |
| self._evidence["task_3"]["last_matching_sql_rows"] = normalised_rows | |
| self._record_event("t3_matching_sql") | |
| elif rows: | |
| self._record_event("t3_nonempty_select") | |
def _record_sql_mutation(self, query: str, rowcount: int) -> None:
    """Classify the table state after a task 1 mutation.

    If the ``transactions`` rows now equal the expected rows, the cleanup
    flag is set and its milestone queued. Otherwise, if any expected row is
    missing or altered, the statement is flagged as destructive and the
    ``destructive_sql`` penalty event is queued. Does nothing for other tasks.
    """
    del rowcount
    task_1 = self._scenario.task_1
    if not (task_1 and self._state.task_id == "task_1_easy_anomaly"):
        return

    current_rows = self._current_transactions_rows()
    wanted_rows = list(task_1.expected_rows)
    wanted_by_id = {row["id"]: row for row in wanted_rows}
    current_by_id = {row["id"]: row for row in current_rows}
    task_1_evidence = self._evidence["task_1"]

    if current_rows == wanted_rows:
        task_1_evidence["exact_cleanup"] = True
        self._record_event("t1_exact_cleanup")
        return

    # A valid row is harmed when it has disappeared or no longer matches.
    harmed_valid_row = any(
        row_id not in current_by_id or current_by_id[row_id] != wanted_row
        for row_id, wanted_row in wanted_by_id.items()
    )
    if harmed_valid_row:
        task_1_evidence["destructive_sql_attempted"] = True
        self._pending_events.append("destructive_sql")
def _record_write_evidence(self, basename: str, content: str) -> None:
    """Update evidence flags after a file has been written.

    Task 2 (``broken_pipeline.py``): records whether the content compiles.
    Task 3 (``report_data.json``): records whether the JSON list matches both
    the last matching SQL rows and the expected rows.
    Task 3 (``format_report.py``): records whether the content compiles.
    """
    task_id = self._state.task_id
    if task_id == "task_2_medium_syntax" and basename == "broken_pipeline.py":
        candidate_ok = self._script_compiles(content, basename)
        self._evidence["task_2"]["candidate_compiles"] = candidate_ok
        if candidate_ok:
            self._record_event("t2_candidate_compiles")
        return

    if not self._scenario.task_3 or task_id != "task_3_hard_e2e":
        return

    task_3_evidence = self._evidence["task_3"]

    if basename == "report_data.json":
        try:
            parsed = json.loads(content)
        except json.JSONDecodeError:
            parsed = None
        # Invalid JSON and non-list documents both clear the flag.
        if not isinstance(parsed, list):
            task_3_evidence["report_data_matches_sql"] = False
            return
        written_rows = normalize_task_3_rows(parsed, require_headcount=True)
        wanted_rows = list(self._scenario.task_3.expected_rows)
        sql_rows = task_3_evidence.get("last_matching_sql_rows", [])
        agrees_with_sql = bool(sql_rows) and written_rows == sql_rows
        agrees_with_expected = task_3_data_matches_expected(
            written_rows,
            wanted_rows,
            require_headcount=True,
        )
        task_3_evidence["report_data_matches_sql"] = (
            agrees_with_sql and agrees_with_expected
        )
        if task_3_evidence["report_data_matches_sql"]:
            self._record_event("t3_report_data_verified")
        return

    if basename == "format_report.py":
        formatter_ok = self._script_compiles(content, basename)
        task_3_evidence["formatter_compiles"] = formatter_ok
        if formatter_ok:
            self._record_event("t3_formatter_compiles")
def _record_run_evidence(
    self,
    basename: str,
    args: list[str],
    result: PythonRunResult,
) -> None:
    """Update evidence flags after a script run.

    Task 2 (``broken_pipeline.py``): a zero exit code plus a passing hidden
    check marks the fix as verified.
    Task 3 (``format_report.py``): a zero exit code, a single argument that
    points at ``report_data.json``, verified report data and stdout matching
    the expected report together mark the report as generated.
    """
    task_id = self._state.task_id
    if task_id == "task_2_medium_syntax" and basename == "broken_pipeline.py":
        if result.returncode == 0 and self._task_2_candidate_is_functional():
            self._evidence["task_2"]["verified_fix"] = True
            self._record_event("t2_verified_fix")
        return

    if not self._scenario.task_3 or task_id != "task_3_hard_e2e":
        return
    if basename != "format_report.py":
        return

    task_3_evidence = self._evidence["task_3"]
    report_text = (result.stdout or "").strip()

    # Checks run in order and stop at the first one that fails.
    if result.returncode != 0:
        return
    if not self._task_3_args_reference_report_data(args):
        return
    if not task_3_evidence.get("report_data_matches_sql"):
        return
    if not report_matches_expected(
        report_text,
        self._scenario.task_3.expected_rows,
        self._scenario.task_3.target_date,
    ):
        return

    task_3_evidence["format_output_matches_expected"] = True
    task_3_evidence["last_formatter_output"] = report_text
    self._record_event("t3_report_generated")
def _record_email_evidence(self, email: dict[str, str]) -> None:
    """Update task 3 evidence after an email has been queued.

    A second or later email clears ``single_email_sent`` and queues the
    ``multiple_emails`` penalty. The email is marked verified only when the
    formatter output was verified, no extra email was sent, recipient and
    subject equal the scenario's, and the stripped body equals the stripped
    formatter output.
    """
    if not self._scenario.task_3 or self._state.task_id != "task_3_hard_e2e":
        return

    task_3_evidence = self._evidence["task_3"]
    if len(self.email_outbox) > 1:
        task_3_evidence["single_email_sent"] = False
        self._pending_events.append("multiple_emails")

    # Checks run in order and stop at the first one that fails.
    if not task_3_evidence.get("format_output_matches_expected"):
        return
    if not task_3_evidence.get("single_email_sent"):
        return
    spec = self._scenario.task_3
    if email.get("to_email") != spec.recipient:
        return
    if email.get("subject") != spec.subject:
        return
    sent_body = email.get("body", "").strip()
    reference_body = str(task_3_evidence.get("last_formatter_output", "")).strip()
    if sent_body != reference_body:
        return

    task_3_evidence["email_matches_formatter_output"] = True
    self._record_event("t3_email_verified")
def _task_2_candidate_is_functional(self) -> bool:
    """Run the candidate ``broken_pipeline.py`` against the hidden cases.

    A small wrapper program imports the candidate from the workspace, calls
    ``process_data_stream`` on every hidden case and prints the results as
    JSON behind a ``__RESULT__=`` marker.

    Returns:
        True only if the wrapper exits with code 0 and the reported results
        equal the scenario's hidden expected outputs. False when there is no
        task 2 scenario or the wrapper could not be run.
    """
    task_2 = self._scenario.task_2
    if not task_2:
        return False

    # The hidden cases travel as a JSON string that the wrapper decodes
    # itself; pasting JSON text in as a Python literal would break on
    # true / false / null.
    cases_json = json.dumps(task_2.hidden_cases)
    wrapper = textwrap.dedent(
        f"""
        import importlib.util
        import json
        spec = importlib.util.spec_from_file_location("candidate_pipeline", "broken_pipeline.py")
        module = importlib.util.module_from_spec(spec)
        assert spec.loader is not None
        spec.loader.exec_module(module)
        cases = json.loads({cases_json!r})
        results = [module.process_data_stream(case) for case in cases]
        print("__RESULT__=" + json.dumps(results))
        """
    )
    try:
        result = run_python_code(
            wrapper,
            cwd=self._workspace_dir,
            timeout_s=DEFAULT_ACTION_TIMEOUT_S,
            stdout_limit=MAX_STDOUT_CHARS,
            stderr_limit=MAX_STDERR_CHARS,
        )
    except Exception:
        # Best effort: an unrunnable check counts as "not functional", but
        # the reason is logged instead of being silently dropped.
        logger.exception("Hidden-case check for task 2 could not be executed.")
        return False

    # The wrapper prints its marker last, so take the LAST marker line; the
    # candidate's own output printed earlier cannot shadow it that way.
    marker = "__RESULT__="
    stdout_lines = (result.stdout or "").splitlines()
    payload = next(
        (
            line[len(marker):]
            for line in reversed(stdout_lines)
            if line.startswith(marker)
        ),
        "",
    )
    try:
        parsed = json.loads(payload) if payload else None
    except json.JSONDecodeError:
        parsed = None
    expected = [list(batch) for batch in task_2.hidden_expected]
    return result.returncode == 0 and parsed == expected
| def _task_3_args_reference_report_data(self, args: list[str]) -> bool: | |
| if len(args) != 1: | |
| return False | |
| expected_path = self._resolve_workspace_path("report_data.json") | |
| if expected_path is None: | |
| return False | |
| candidate = args[0] | |
| if os.path.isabs(candidate): | |
| resolved = os.path.realpath(candidate) | |
| else: | |
| resolved = os.path.realpath(os.path.join(self._workspace_dir, candidate)) | |
| return resolved == expected_path | |
def _current_task_score(self) -> float:
    """Return the grader's current score for the active task.

    Returns 0.0 when no task id is set. If grading raises, the failure is
    logged and the previously stored score is returned instead.
    """
    task_id = self._state.task_id
    if not task_id:
        return 0.0
    try:
        # Imported at call time rather than at module import.
        from .grading import evaluate_task

        outcome = evaluate_task(task_id, self)
        return float(outcome.get("score", 0.0))
    except Exception:
        logger.exception(
            "Failed to compute current grader score for reward shaping."
        )
        return self._grader_score
| def _script_compiles(self, content: str, filename: str) -> bool: | |
| try: | |
| compile(content, filename, "exec") | |
| except SyntaxError: | |
| return False | |
| return True | |
def _task_completed(self) -> bool:
    """Report whether the active task has reached its completion condition.

    Task 1 is complete when the ``transactions`` rows equal the expected
    rows. Tasks 2 and 3 are complete when the stored grader score is at
    least 1.0, because their evidence flags can be partly true while the
    grader score is still below 1.0.
    """
    task_id = self._state.task_id
    if task_id == "task_1_easy_anomaly" and self._scenario.task_1:
        wanted_rows = list(self._scenario.task_1.expected_rows)
        return self._current_transactions_rows() == wanted_rows
    if task_id in ("task_2_medium_syntax", "task_3_hard_e2e"):
        return self._grader_score >= 1.0
    return False
| def _current_transactions_rows(self) -> list[dict[str, Any]]: | |
| with sqlite3.connect(self._db_path) as conn: | |
| conn.row_factory = sqlite3.Row | |
| rows = conn.execute( | |
| "SELECT id, user_id, amount, status FROM transactions ORDER BY id" | |
| ).fetchall() | |
| return [ | |
| { | |
| "id": int(row["id"]), | |
| "user_id": int(row["user_id"]), | |
| "amount": None | |
| if row["amount"] is None | |
| else round(float(row["amount"]), 2), | |
| "status": str(row["status"]), | |
| } | |
| for row in rows | |
| ] | |
| def _record_event(self, event: str) -> None: | |
| self._pending_events.append(event) | |
def _resolve_timeout(self, timeout_s: Optional[float]) -> float:
    """Turn an optional caller timeout into the effective timeout in seconds.

    ``None`` yields ``DEFAULT_ACTION_TIMEOUT_S``; any other value is capped
    at ``MAX_ACTION_TIMEOUT_S`` and raised to at least 0.1.
    """
    if timeout_s is None:
        return DEFAULT_ACTION_TIMEOUT_S
    capped = min(float(timeout_s), MAX_ACTION_TIMEOUT_S)
    return max(0.1, capped)
| def _is_allowed_file( | |
| self, allowed_registry: dict[str, frozenset[str]], basename: str | |
| ) -> bool: | |
| return basename in allowed_registry.get(self._state.task_id, frozenset()) | |
| def _resolve_workspace_path(self, basename: str) -> Optional[str]: | |
| workspace_root = os.path.realpath(self._workspace_dir) | |
| candidate = os.path.realpath(os.path.join(self._workspace_dir, basename)) | |
| if candidate == workspace_root: | |
| return None | |
| if not candidate.startswith(f"{workspace_root}{os.sep}"): | |
| return None | |
| return candidate | |
| def _statement_type(self, query: str) -> str: | |
| parts = query.split(None, 1) | |
| return parts[0].upper() if parts else "" | |
def _validate_sql_action(self, query: str, statement_type: str) -> Optional[str]:
    """Check a statement against the active task's SQL policy.

    Returns:
        None when the statement is acceptable, otherwise the error message
        to show to the caller.

    NOTE(review): table references come from ``_extract_sql_table_refs``.
    Whether comma-separated FROM lists or quoted identifiers are caught
    depends on that helper's pattern — confirm coverage.
    """
    if not query:
        return "SQL query cannot be empty."

    policy = TASK_SQL_POLICIES.get(self._state.task_id)
    if policy is None:
        return "SQL is not available for the active task."

    if statement_type not in policy.allowed_commands:
        permitted = ", ".join(sorted(policy.allowed_commands))
        return f"Only {permitted} statements are allowed for this task."

    # Checks below run on text with literals and comments removed and
    # whitespace collapsed, so their contents cannot trigger or hide a match.
    stripped = self._strip_sql_literals_and_comments(query)
    collapsed = " ".join(stripped.split())
    collapsed_lower = collapsed.lower()

    if ";" in collapsed:
        return "Only a single SQL statement is allowed."

    banned_fragments = ("pragma", "attach", "detach", "sqlite_", "alter ", "drop ")
    for fragment in banned_fragments:
        if fragment in collapsed_lower:
            return "Query contains disallowed SQL constructs."

    if statement_type == "DELETE":
        delete_shape = (
            rf"^delete\s+from\s+{re.escape(policy.required_table)}\s+where\b"
        )
        if not re.match(delete_shape, collapsed_lower):
            return f"DELETE statements must target '{policy.required_table}' with an explicit WHERE clause."

    cte_names = self._extract_cte_names(collapsed)
    referenced = self._extract_sql_table_refs(collapsed)
    if policy.required_table not in referenced:
        return f"Query must target the '{policy.required_table}' table."

    # Anything referenced beyond the required table and the query's own CTEs.
    unexpected = sorted(referenced - {policy.required_table, *cte_names})
    if unexpected:
        return f"Query references disallowed table(s): {', '.join(unexpected)}."
    return None
def _strip_sql_literals_and_comments(self, query: str) -> str:
    """Return ``query`` with comments blanked out and quoted runs emptied.

    Comments are replaced by a single space first; each quoted run is then
    replaced by an empty pair of single quotes.
    """
    return _SQL_STRING_RE.sub("''", _SQL_COMMENT_RE.sub(" ", query))
def _extract_cte_names(self, query: str) -> set[str]:
    """Return the lower-cased CTE names of a query that starts with ``with ``.

    Queries that do not start with ``with `` (ignoring case and leading
    whitespace) yield an empty set.
    """
    if not query.lower().lstrip().startswith("with "):
        return set()
    names: set[str] = set()
    for found in _SQL_CTE_NAME_RE.finditer(query):
        names.add(found.group(1).lower())
    return names
def _extract_sql_table_refs(self, query: str) -> set[str]:
    """Return the lower-cased identifiers matched by ``_SQL_TABLE_REF_RE``."""
    refs: set[str] = set()
    for found in _SQL_TABLE_REF_RE.finditer(query):
        refs.add(found.group(1).lower())
    return refs
def _obs(
    self, status: str, message: str, *, done: bool = False
) -> DataOpsObservation:
    """Build an observation stamped with the current step count and step cap."""
    fields = {
        "status": status,
        "message": message,
        "step_count": self._state.step_count,
        "max_steps": MAX_STEPS,
        "done": done,
    }
    return DataOpsObservation(**fields)