# dataops-env/server/dataops_env_environment.py
# Uploaded by visheshrathi via huggingface_hub ("Upload folder using huggingface_hub"),
# commit f89b1ac (verified).
from __future__ import annotations
import json
import logging
import os
import re
import shutil
import sqlite3
import textwrap
import threading
import time
import uuid
from copy import deepcopy
from typing import Any, Optional
from openenv.core.env_server import Environment
from pydantic import ValidationError
from data.init_db import WORKSPACE_ROOT, setup_workspace
from models import (
PAYLOAD_MODELS,
DataOpsAction,
DataOpsObservation,
DataOpsState,
ExecuteSQLPayload,
ReadFilePayload,
RunScriptPayload,
SendEmailPayload,
WriteFilePayload,
)
from .safe_exec import PythonRunResult, run_python_code, run_python_script
from .task_specs import (
TASK_ALLOWED_READ_FILES,
TASK_ALLOWED_RUN_FILES,
TASK_ALLOWED_WRITE_FILES,
TASK_EMAIL_ENABLED,
TASK_IDS,
TASK_SQL_POLICIES,
TaskScenarioBundle,
build_task_scenario,
normalize_task_3_rows,
report_matches_expected,
task_3_data_matches_expected,
)
logger = logging.getLogger(__name__)
_SQL_COMMENT_RE = re.compile(r"(--[^\n]*|/\*.*?\*/)", re.DOTALL)
_SQL_STRING_RE = re.compile(r"'(?:''|[^'])*'|\"(?:\"\"|[^\"])*\"")
_SQL_TABLE_REF_RE = re.compile(
r"\b(?:from|join|update|into|delete\s+from)\s+([a-zA-Z_][a-zA-Z0-9_]*)",
re.IGNORECASE,
)
_SQL_CTE_NAME_RE = re.compile(
r"(?:\bwith\b|,)\s*([a-zA-Z_][a-zA-Z0-9_]*)\s+as\s*\(",
re.IGNORECASE,
)
MAX_STEPS = 15
MAX_SQL_ROWS = 500
MAX_FILE_SIZE = 1_000_000
DEFAULT_ACTION_TIMEOUT_S = 10.0
MAX_ACTION_TIMEOUT_S = 30.0
MAX_STDOUT_CHARS = 50_000
MAX_STDERR_CHARS = 10_000
PENALTY_FAILURE = -0.03
PENALTY_DESTRUCTIVE = -0.20
PENALTY_REPEAT = -0.08
PENALTY_DISALLOWED_TOOL_UNIT = -0.04
# Keep milestone bonuses small so the terminal grader remains the dominant signal.
REWARD_EVENT_VALUES = {
"t1_inspected_corruption": 0.05,
"t1_exact_cleanup": 0.04,
"t2_read_source": 0.04,
"t2_candidate_compiles": 0.02,
"t2_verified_fix": 0.03,
"t3_nonempty_select": 0.03,
"t3_matching_sql": 0.03,
"t3_read_formatter_source": 0.02,
"t3_report_data_verified": 0.03,
"t3_formatter_compiles": 0.02,
"t3_report_generated": 0.03,
"t3_email_verified": 0.02,
}
PENALTY_EVENTS = {
"destructive_sql": PENALTY_DESTRUCTIVE,
"multiple_emails": -0.08,
"t2_run_before_read": -0.05,
"t2_write_before_read": -0.05,
}
class DataOpsEnvironment(Environment[DataOpsAction, DataOpsObservation, DataOpsState]):
    """Enterprise data pipeline remediation environment (OpenEnv-compliant).

    Each instance owns a private workspace directory
    (``WORKSPACE_ROOT/sessions/<uuid hex>``) holding the SQLite warehouse and the
    task files. Agents act through five tools: ExecuteSQL, ReadFile, WriteFile,
    RunScript and SendEmail.

    The reward for each step has three parts:
    - the change in the task grader's score;
    - small one-shot milestone bonuses;
    - penalties.

    ``reset`` and ``step`` serialise on a per-instance lock.
    """

    SUPPORTS_CONCURRENT_SESSIONS = True

    def __init__(self) -> None:
        # Every instance gets its own workspace so sessions never share files
        # or the SQLite database.
        self._workspace_dir = os.path.join(WORKSPACE_ROOT, "sessions", uuid.uuid4().hex)
        self._db_path = os.path.join(self._workspace_dir, "mock_warehouse.db")
        self._state = DataOpsState()
        # Placeholder scenario; reset() builds the scenario for the chosen task.
        self._scenario: TaskScenarioBundle = build_task_scenario(
            "task_1_easy_anomaly", seed=0
        )
        self._evidence: dict[str, Any] = {}
        # Reward events raised while processing the current action.
        self._pending_events: list[str] = []
        self.email_outbox: list[dict[str, str]] = []
        self._last_action_key: Optional[str] = None
        # One-shot milestones already rewarded in this episode.
        self._milestones: set[str] = set()
        # Last grader score; each step is rewarded with the delta against it.
        self._grader_score = 0.0
        self._disallowed_tool_attempts = 0
        self._lock = threading.Lock()

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> DataOpsObservation:
        """Start a new episode, rebuilding the workspace and all episode state.

        Args:
            seed: Scenario seed forwarded to ``build_task_scenario``.
            episode_id: Episode identifier. A random UUID is used when it is
                omitted or empty.
            **kwargs: ``task_id`` selects the task. It defaults to
                ``"task_1_easy_anomaly"`` when missing or ``None``.

        Returns:
            The initial observation (status ``"success"``, reward ``0.0``).

        Raises:
            ValueError: If ``task_id`` is not one of ``TASK_IDS``.
        """
        requested_task = kwargs.get("task_id")
        task_id: str = (
            "task_1_easy_anomaly" if requested_task is None else requested_task
        )
        if task_id not in TASK_IDS:
            raise ValueError(f"Unknown task_id: {task_id}")
        with self._lock:
            self._scenario = build_task_scenario(task_id, seed=seed)
            self._db_path = setup_workspace(
                self._workspace_dir,
                scenario=self._scenario,
            )
            self.email_outbox.clear()
            self._last_action_key = None
            self._milestones.clear()
            self._pending_events = []
            self._disallowed_tool_attempts = 0
            self._evidence = self._initial_evidence()
            self._state = DataOpsState(
                episode_id=episode_id or str(uuid.uuid4()),
                step_count=0,
                task_id=task_id,
                task_description=self._scenario.description,
                max_steps=MAX_STEPS,
                seed=self._scenario.seed,
            )
            # _current_task_score() falls back to self._grader_score when the
            # grader raises. Clear it first so the previous episode's score
            # cannot leak into this episode's reward deltas.
            self._grader_score = 0.0
            self._grader_score = self._current_task_score()
            return DataOpsObservation(
                status="success",
                done=False,
                reward=0.0,
                message=f"Environment reset. Task: {self._scenario.description}",
                step_count=0,
                max_steps=MAX_STEPS,
            )

    def step(
        self, action: DataOpsAction, timeout_s: Optional[float] = None, **kwargs: Any
    ) -> DataOpsObservation:
        """Apply one agent action while holding the instance lock.

        Args:
            action: Tool name (``action_type``) plus its raw ``payload`` dict.
            timeout_s: Optional per-action time limit, clamped by
                ``_resolve_timeout``.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            The observation with reward, done flag and step counters filled in.
        """
        del kwargs
        with self._lock:
            return self._step_locked(action, timeout_s)

    @property
    def state(self) -> DataOpsState:
        """Snapshot of the episode state.

        The snapshot is deep-copied so callers cannot mutate internal
        containers such as ``actions_taken``.
        """
        return self._state.model_copy(deep=True)

    @property
    def scenario(self) -> TaskScenarioBundle:
        """The scenario bundle for the current episode."""
        return self._scenario

    @property
    def evidence(self) -> dict[str, Any]:
        """Deep copy of the per-task evidence flags collected so far."""
        return deepcopy(self._evidence)

    @property
    def workspace_dir(self) -> str:
        """Directory holding this session's files and database."""
        return self._workspace_dir

    @property
    def db_path(self) -> str:
        """Path of the session's SQLite warehouse."""
        return self._db_path

    def close(self) -> None:
        """Delete this session's workspace directory, ignoring removal errors."""
        if os.path.isdir(self._workspace_dir):
            shutil.rmtree(self._workspace_dir, ignore_errors=True)

    def _step_locked(
        self, action: DataOpsAction, timeout_s: Optional[float]
    ) -> DataOpsObservation:
        """Validate, execute and score ``action``. The caller must hold ``self._lock``.

        An unknown tool or an invalid payload is rejected before dispatch. Such
        actions do not consume a step and are not scored.
        """
        if self._state.done:
            return self._obs(
                "error", "Episode is over. Call /reset to start a new one.", done=True
            )
        model_cls = PAYLOAD_MODELS.get(action.action_type)
        if not model_cls:
            return self._obs("error", f"Unknown action_type: {action.action_type}")
        try:
            payload = model_cls(**action.payload)
        except ValidationError as exc:
            return self._obs(
                "error",
                f"Invalid payload: {exc.error_count()} validation error(s).",
            )
        self._pending_events = []
        obs = self._dispatch(action.action_type, payload, timeout_s)
        reward = self._compute_reward(action, obs)
        self._state.step_count += 1
        self._state.cumulative_reward += reward
        self._state.actions_taken.append(action.action_type)
        self._state.emails_sent = len(self.email_outbox)
        done = self._state.step_count >= MAX_STEPS or self._task_completed()
        self._state.done = done
        obs.reward = round(reward, 4)
        obs.done = done
        obs.step_count = self._state.step_count
        obs.max_steps = MAX_STEPS
        return obs

    def _dispatch(
        self, action_type: str, payload: Any, timeout_s: Optional[float]
    ) -> DataOpsObservation:
        """Route a validated payload to the handler for ``action_type``."""
        handlers = {
            "ExecuteSQL": self._handle_sql,
            "ReadFile": self._handle_read,
            "WriteFile": self._handle_write,
            "RunScript": self._handle_run,
            "SendEmail": self._handle_email,
        }
        return handlers[action_type](payload, timeout_s)

    def _handle_sql(
        self, payload: ExecuteSQLPayload, timeout_s: Optional[float]
    ) -> DataOpsObservation:
        """Execute one policy-checked SQL statement against the workspace DB.

        SELECT/WITH statements return at most ``MAX_SQL_ROWS`` rows. Any other
        statement is committed and reports the affected row count. The statement
        is interrupted once the resolved timeout elapses.
        """
        query = payload.query.strip()
        # Drop trailing semicolons (and whitespace between them) so a
        # terminating ";" is not mistaken for a second statement.
        while True:
            trimmed = query.rstrip()
            if not trimmed.endswith(";"):
                break
            query = trimmed[:-1].rstrip()
        statement_type = self._statement_type(query)
        validation_error = self._validate_sql_action(query, statement_type)
        if validation_error:
            return self._obs("error", validation_error)
        timeout = self._resolve_timeout(timeout_s)
        deadline = time.monotonic() + timeout
        try:
            # The sqlite3 connection context manager only commits or rolls
            # back; it never closes. Close explicitly so handles do not leak.
            conn = sqlite3.connect(self._db_path)
            try:
                # A non-zero return aborts the running statement with an
                # "interrupted" error, which is reported below as a timeout.
                conn.set_progress_handler(
                    lambda: 1 if time.monotonic() >= deadline else 0,
                    1_000,
                )
                conn.row_factory = sqlite3.Row
                cursor = conn.cursor()
                cursor.execute(query)
                if statement_type in {"SELECT", "WITH"}:
                    cols = [c[0] for c in cursor.description or []]
                    # Fetch one extra row to detect oversized results.
                    rows_raw = cursor.fetchmany(MAX_SQL_ROWS + 1)
                    if len(rows_raw) > MAX_SQL_ROWS:
                        return self._obs(
                            "error",
                            f"Result exceeds {MAX_SQL_ROWS} rows. Add a LIMIT clause.",
                        )
                    rows = [dict(zip(cols, row)) for row in rows_raw]
                    self._record_sql_select(query, rows)
                    return DataOpsObservation(
                        status="success",
                        sql_results=rows,
                        message=f"Query returned {len(rows)} rows.",
                    )
                conn.commit()
                self._record_sql_mutation(query, cursor.rowcount)
                return self._obs("success", f"Rows affected: {cursor.rowcount}")
            finally:
                conn.close()
        except sqlite3.Error as exc:
            if "interrupted" in str(exc).lower():
                return self._obs(
                    "error", f"SQL execution timed out ({timeout:.1f}s limit)."
                )
            logger.warning("SQL error: %s", exc)
            msg = "SQL execution error. Check your query syntax."
            if self._state.task_id == "task_3_hard_e2e" and re.search(
                r"\bdate\b", query, re.IGNORECASE
            ):
                if "report_date" not in query.lower():
                    msg += " Hint: table `daily_reports` uses column `report_date` for the calendar date."
            return self._obs("error", msg)

    def _handle_read(
        self, payload: ReadFilePayload, timeout_s: Optional[float]
    ) -> DataOpsObservation:
        """Return an allowed workspace file's UTF-8 text in ``stdout``."""
        del timeout_s
        # Only the basename is honoured; directory components are discarded.
        basename = os.path.basename(payload.filepath)
        if not self._is_allowed_file(TASK_ALLOWED_READ_FILES, basename):
            return self._obs(
                "error", f"Reading {basename} is not allowed for this task."
            )
        safe_path = self._resolve_workspace_path(basename)
        if safe_path is None:
            return self._obs("error", "Resolved file path escapes the workspace.")
        if not os.path.isfile(safe_path):
            return self._obs("error", f"File not found: {basename}")
        if os.path.getsize(safe_path) > MAX_FILE_SIZE:
            return self._obs("error", "File too large to read.")
        try:
            with open(safe_path, encoding="utf-8") as f:
                content = f.read(MAX_FILE_SIZE)
        except UnicodeDecodeError:
            # Scripts run by the agent can leave arbitrary bytes behind; report
            # it instead of crashing the step.
            return self._obs("error", f"{basename} is not valid UTF-8 text.")
        except OSError:
            return self._obs("error", "Failed to read file.")
        if (
            self._state.task_id == "task_2_medium_syntax"
            and basename == "broken_pipeline.py"
        ):
            self._evidence["task_2"]["read_source"] = True
            self._record_event("t2_read_source")
        if self._state.task_id == "task_3_hard_e2e" and basename == "format_report.py":
            self._evidence["task_3"]["read_formatter_source"] = True
            self._record_event("t3_read_formatter_source")
        return DataOpsObservation(
            status="success",
            stdout=content,
            message=f"Read {len(content)} chars from {basename}",
        )

    def _handle_write(
        self, payload: WriteFilePayload, timeout_s: Optional[float]
    ) -> DataOpsObservation:
        """Overwrite an allowed workspace file with ``payload.content`` (UTF-8)."""
        del timeout_s
        basename = os.path.basename(payload.filepath)
        if not self._is_allowed_file(TASK_ALLOWED_WRITE_FILES, basename):
            return self._obs(
                "error", f"Writing {basename} is not allowed for this task."
            )
        if (
            self._state.task_id == "task_2_medium_syntax"
            and basename == "broken_pipeline.py"
        ):
            if not self._evidence["task_2"]["read_source"]:
                self._pending_events.append("t2_write_before_read")
        safe_path = self._resolve_workspace_path(basename)
        if safe_path is None:
            return self._obs("error", "Resolved file path escapes the workspace.")
        try:
            encoded_size = len(payload.content.encode("utf-8"))
        except UnicodeEncodeError:
            # For example, lone surrogates; f.write would raise the same error.
            return self._obs("error", "Content is not valid UTF-8 text.")
        # Mirror the read-side limit so every written file can be read back.
        if encoded_size > MAX_FILE_SIZE:
            return self._obs("error", "File too large to write.")
        try:
            with open(safe_path, "w", encoding="utf-8") as f:
                f.write(payload.content)
        except OSError:
            return self._obs("error", "Failed to write file.")
        self._record_write_evidence(basename, payload.content)
        return self._obs("success", f"Wrote {len(payload.content)} chars to {basename}")

    def _handle_run(
        self, payload: RunScriptPayload, timeout_s: Optional[float]
    ) -> DataOpsObservation:
        """Run an allowed workspace script via ``run_python_script``.

        Output is capped at ``MAX_STDOUT_CHARS`` / ``MAX_STDERR_CHARS``. A
        non-zero exit code yields status ``"error"``.
        """
        basename = os.path.basename(payload.filepath)
        if not self._is_allowed_file(TASK_ALLOWED_RUN_FILES, basename):
            return self._obs(
                "error", f"Executing {basename} is not allowed for this task."
            )
        script_path = self._resolve_workspace_path(basename)
        if script_path is None:
            return self._obs("error", "Resolved script path escapes the workspace.")
        if not os.path.isfile(script_path):
            return self._obs("error", f"Script not found: {basename}")
        if (
            self._state.task_id == "task_2_medium_syntax"
            and basename == "broken_pipeline.py"
        ):
            if not self._evidence["task_2"]["read_source"]:
                self._pending_events.append("t2_run_before_read")
        timeout = self._resolve_timeout(timeout_s)
        try:
            result = run_python_script(
                basename,
                cwd=self._workspace_dir,
                args=list(payload.args),
                timeout_s=timeout,
                stdout_limit=MAX_STDOUT_CHARS,
                stderr_limit=MAX_STDERR_CHARS,
            )
        except OSError:
            return self._obs("error", "Failed to execute script.")
        if result.timed_out:
            return self._obs("error", f"Script timed out ({timeout:.1f}s limit).")
        self._record_run_evidence(basename, payload.args, result)
        status = "success" if result.returncode == 0 else "error"
        return DataOpsObservation(
            status=status,
            stdout=(result.stdout or "")[:MAX_STDOUT_CHARS],
            stderr=(result.stderr or "")[:MAX_STDERR_CHARS],
            message=f"Exit code: {result.returncode}",
        )

    def _handle_email(
        self, payload: SendEmailPayload, timeout_s: Optional[float]
    ) -> DataOpsObservation:
        """Queue an email in ``email_outbox`` for email-enabled tasks.

        For other tasks the attempt is counted and penalised (escalating with
        the number of attempts in ``_compute_reward``).
        """
        del timeout_s
        if self._state.task_id not in TASK_EMAIL_ENABLED:
            self._disallowed_tool_attempts += 1
            self._pending_events.append("disallowed_tool")
            return self._obs(
                "error",
                "Email is not available for this task. Use read_file, write_file, and invoke_python only.",
            )
        email = {
            "to_email": payload.to_email,
            "subject": payload.subject,
            "body": payload.body,
        }
        self.email_outbox.append(email)
        self._record_email_evidence(email)
        return DataOpsObservation(
            status="success",
            email_delivery_status=f"Queued for {payload.to_email}",
            message=f"Email queued for delivery to {payload.to_email}",
        )

    def _compute_reward(self, action: DataOpsAction, obs: DataOpsObservation) -> float:
        """Compute the reward for one processed action.

        The reward is the sum of:
        - the grader score delta since the previous step;
        - a failure penalty when the action did not succeed;
        - a repeat penalty when the action is identical to the previous one;
        - penalties and one-shot milestone bonuses from ``_pending_events``.
        """
        current_score = self._current_task_score()
        reward = current_score - self._grader_score
        self._grader_score = current_score
        if obs.status != "success":
            reward += PENALTY_FAILURE
        action_key = (
            f"{action.action_type}:"
            f"{json.dumps(action.payload, sort_keys=True, ensure_ascii=True)}"
        )
        if action_key == self._last_action_key:
            reward += PENALTY_REPEAT
        self._last_action_key = action_key
        for event in self._pending_events:
            if event == "disallowed_tool":
                # The penalty grows with repeated misuse, capped at 12 units.
                reward += PENALTY_DISALLOWED_TOOL_UNIT * min(
                    self._disallowed_tool_attempts, 12
                )
                continue
            if event in PENALTY_EVENTS:
                reward += PENALTY_EVENTS[event]
                continue
            reward += self._award_event(event)
        return reward

    def _award_event(self, event: str) -> float:
        """Return the milestone bonus for ``event`` the first time it is seen, else 0."""
        if event in self._milestones:
            return 0.0
        self._milestones.add(event)
        return REWARD_EVENT_VALUES.get(event, 0.0)

    def _initial_evidence(self) -> dict[str, Any]:
        """Fresh evidence flags for every task, as used at episode start."""
        return {
            "task_1": {
                "inspected_corrupted_rows": False,
                "exact_cleanup": False,
                "destructive_sql_attempted": False,
            },
            "task_2": {
                "read_source": False,
                "candidate_compiles": False,
                "verified_fix": False,
            },
            "task_3": {
                "matching_sql_executed": False,
                "last_matching_sql_rows": [],
                "read_formatter_source": False,
                "report_data_matches_sql": False,
                "formatter_compiles": False,
                "format_output_matches_expected": False,
                "last_formatter_output": "",
                "email_matches_formatter_output": False,
                "single_email_sent": True,
            },
        }

    def _record_sql_select(self, query: str, rows: list[dict[str, Any]]) -> None:
        """Record evidence and milestone events for a successful read query."""
        if self._scenario.task_1 and self._state.task_id == "task_1_easy_anomaly":
            # The agent controls the column aliases, so "id" may hold anything.
            # Values that are not integer-like are ignored instead of crashing.
            row_ids: set[int] = set()
            for row in rows:
                row_id = self._coerce_int(row.get("id"))
                if isinstance(row_id, int):
                    row_ids.add(row_id)
            corrupted = set(self._scenario.task_1.corrupted_row_ids)
            if row_ids & corrupted:
                self._evidence["task_1"]["inspected_corrupted_rows"] = True
                self._record_event("t1_inspected_corruption")
        if self._scenario.task_3 and self._state.task_id == "task_3_hard_e2e":
            normalised_rows = normalize_task_3_rows(rows, require_headcount=True)
            expected_rows = list(self._scenario.task_3.expected_rows)
            if task_3_data_matches_expected(
                normalised_rows,
                expected_rows,
                require_headcount=True,
            ):
                self._evidence["task_3"]["matching_sql_executed"] = True
                self._evidence["task_3"]["last_matching_sql_rows"] = normalised_rows
                self._record_event("t3_matching_sql")
            elif rows:
                self._record_event("t3_nonempty_select")

    def _record_sql_mutation(self, query: str, rowcount: int) -> None:
        """For task 1, classify a committed mutation.

        The table either now matches the expected rows exactly (exact cleanup),
        or has lost or altered a valid row (destructive).
        """
        del rowcount
        if self._scenario.task_1 and self._state.task_id == "task_1_easy_anomaly":
            exact_rows = self._current_transactions_rows()
            expected_rows = list(self._scenario.task_1.expected_rows)
            expected_by_id = {row["id"]: row for row in expected_rows}
            actual_by_id = {row["id"]: row for row in exact_rows}
            valid_rows_lost = any(
                row_id not in actual_by_id for row_id in expected_by_id
            )
            valid_rows_changed = any(
                actual_by_id[row_id] != expected_row
                for row_id, expected_row in expected_by_id.items()
                if row_id in actual_by_id
            )
            if exact_rows == expected_rows:
                self._evidence["task_1"]["exact_cleanup"] = True
                self._record_event("t1_exact_cleanup")
            elif valid_rows_lost or valid_rows_changed:
                self._evidence["task_1"]["destructive_sql_attempted"] = True
                self._pending_events.append("destructive_sql")

    def _record_write_evidence(self, basename: str, content: str) -> None:
        """Record evidence for a successful file write (tasks 2 and 3)."""
        if (
            self._state.task_id == "task_2_medium_syntax"
            and basename == "broken_pipeline.py"
        ):
            compiles = self._script_compiles(content, basename)
            self._evidence["task_2"]["candidate_compiles"] = compiles
            if compiles:
                self._record_event("t2_candidate_compiles")
            return
        if not self._scenario.task_3 or self._state.task_id != "task_3_hard_e2e":
            return
        task_3 = self._evidence["task_3"]
        if basename == "report_data.json":
            try:
                payload = json.loads(content)
            except json.JSONDecodeError:
                task_3["report_data_matches_sql"] = False
                return
            if not isinstance(payload, list):
                task_3["report_data_matches_sql"] = False
                return
            normalised_rows = normalize_task_3_rows(payload, require_headcount=True)
            expected_rows = list(self._scenario.task_3.expected_rows)
            last_sql_rows = task_3.get("last_matching_sql_rows", [])
            # The file must reproduce rows the agent actually queried, not
            # just rows that happen to match the expected answer.
            matches_sql = bool(last_sql_rows) and normalised_rows == last_sql_rows
            matches_expected = task_3_data_matches_expected(
                normalised_rows,
                expected_rows,
                require_headcount=True,
            )
            task_3["report_data_matches_sql"] = matches_sql and matches_expected
            if task_3["report_data_matches_sql"]:
                self._record_event("t3_report_data_verified")
            return
        if basename == "format_report.py":
            compiles = self._script_compiles(content, basename)
            task_3["formatter_compiles"] = compiles
            if compiles:
                self._record_event("t3_formatter_compiles")

    def _record_run_evidence(
        self,
        basename: str,
        args: list[str],
        result: PythonRunResult,
    ) -> None:
        """Record evidence for a completed (not timed-out) script run."""
        if (
            self._state.task_id == "task_2_medium_syntax"
            and basename == "broken_pipeline.py"
        ):
            if result.returncode == 0 and self._task_2_candidate_is_functional():
                self._evidence["task_2"]["verified_fix"] = True
                self._record_event("t2_verified_fix")
            return
        if not self._scenario.task_3 or self._state.task_id != "task_3_hard_e2e":
            return
        if basename != "format_report.py":
            return
        task_3 = self._evidence["task_3"]
        stdout = (result.stdout or "").strip()
        if (
            result.returncode == 0
            and self._task_3_args_reference_report_data(args)
            and task_3.get("report_data_matches_sql")
            and report_matches_expected(
                stdout,
                self._scenario.task_3.expected_rows,
                self._scenario.task_3.target_date,
            )
        ):
            task_3["format_output_matches_expected"] = True
            task_3["last_formatter_output"] = stdout
            self._record_event("t3_report_generated")

    def _record_email_evidence(self, email: dict[str, str]) -> None:
        """For task 3, check that the (single) email carries the formatter output."""
        if not self._scenario.task_3 or self._state.task_id != "task_3_hard_e2e":
            return
        task_3 = self._evidence["task_3"]
        if len(self.email_outbox) > 1:
            task_3["single_email_sent"] = False
            self._pending_events.append("multiple_emails")
        if (
            task_3.get("format_output_matches_expected")
            and task_3.get("single_email_sent")
            and email.get("to_email") == self._scenario.task_3.recipient
            and email.get("subject") == self._scenario.task_3.subject
            and email.get("body", "").strip()
            == str(task_3.get("last_formatter_output", "")).strip()
        ):
            task_3["email_matches_formatter_output"] = True
            self._record_event("t3_email_verified")

    def _task_2_candidate_is_functional(self) -> bool:
        """Run ``process_data_stream`` from the candidate on the hidden cases.

        Returns:
            True only if the subprocess exits with code 0 and its JSON results
            equal ``hidden_expected``.
        """
        if not self._scenario.task_2:
            return False
        marker = "__RESULT__="
        cases_json = json.dumps(self._scenario.task_2.hidden_cases)
        # Embed the cases as a Python string literal and decode them in the
        # child. Splicing raw JSON into Python source breaks on true/false/null.
        wrapper = textwrap.dedent(
            f"""
            import importlib.util
            import json
            spec = importlib.util.spec_from_file_location("candidate_pipeline", "broken_pipeline.py")
            module = importlib.util.module_from_spec(spec)
            assert spec.loader is not None
            spec.loader.exec_module(module)
            cases = json.loads({cases_json!r})
            results = [module.process_data_stream(case) for case in cases]
            print("__RESULT__=" + json.dumps(results))
            """
        )
        try:
            result = run_python_code(
                wrapper,
                cwd=self._workspace_dir,
                timeout_s=DEFAULT_ACTION_TIMEOUT_S,
                stdout_limit=MAX_STDOUT_CHARS,
                stderr_limit=MAX_STDERR_CHARS,
            )
        except Exception:
            return False
        # Use the LAST marker line: the wrapper prints it after every call to
        # the candidate, so output printed by the candidate cannot shadow it.
        payload = ""
        for line in (result.stdout or "").splitlines():
            if line.startswith(marker):
                payload = line[len(marker):]
        try:
            parsed = json.loads(payload) if payload else None
        except json.JSONDecodeError:
            parsed = None
        expected = [list(batch) for batch in self._scenario.task_2.hidden_expected]
        return result.returncode == 0 and parsed == expected

    def _task_3_args_reference_report_data(self, args: list[str]) -> bool:
        """True if ``args`` is exactly one path resolving to the workspace's report_data.json."""
        if len(args) != 1:
            return False
        expected_path = self._resolve_workspace_path("report_data.json")
        if expected_path is None:
            return False
        candidate = args[0]
        if os.path.isabs(candidate):
            resolved = os.path.realpath(candidate)
        else:
            resolved = os.path.realpath(os.path.join(self._workspace_dir, candidate))
        return resolved == expected_path

    def _current_task_score(self) -> float:
        """Current grader score, or the last known score if grading fails.

        Returns 0.0 when no task is active.
        """
        if not self._state.task_id:
            return 0.0
        try:
            # Imported lazily; presumably avoids an import cycle with .grading.
            from .grading import evaluate_task

            return float(evaluate_task(self._state.task_id, self).get("score", 0.0))
        except Exception:
            logger.exception(
                "Failed to compute current grader score for reward shaping."
            )
            return self._grader_score

    def _script_compiles(self, content: str, filename: str) -> bool:
        """True if ``content`` compiles as Python source."""
        try:
            compile(content, filename, "exec")
        except (SyntaxError, ValueError):
            # ValueError: NUL bytes in the source on Python < 3.12.
            return False
        return True

    def _task_completed(self) -> bool:
        """Whether the episode's task is solved.

        Task 1 is solved when the table matches exactly. Tasks 2 and 3 require
        a perfect grader score.
        """
        if self._state.task_id == "task_1_easy_anomaly" and self._scenario.task_1:
            return self._current_transactions_rows() == list(
                self._scenario.task_1.expected_rows
            )
        if self._state.task_id == "task_2_medium_syntax":
            # The terminal grader can be < 1.0 even when verified_fix is set
            # (visible/hidden/provenance split).
            return self._grader_score >= 1.0
        if self._state.task_id == "task_3_hard_e2e":
            # Evidence flags can be partially true while the component-weighted
            # grader is still < 1.0.
            return self._grader_score >= 1.0
        return False

    def _current_transactions_rows(self) -> list[dict[str, Any]]:
        """Normalised snapshot of the ``transactions`` table, ordered by id.

        Values that cannot be normalised are returned unchanged, so an agent
        UPDATE that writes NULL or text into numeric columns cannot crash the
        step. Such rows simply fail to compare equal to the expected rows.
        """
        conn = sqlite3.connect(self._db_path)
        try:
            conn.row_factory = sqlite3.Row
            rows = conn.execute(
                "SELECT id, user_id, amount, status FROM transactions ORDER BY id"
            ).fetchall()
        finally:
            conn.close()
        return [
            {
                "id": self._coerce_int(row["id"]),
                "user_id": self._coerce_int(row["user_id"]),
                "amount": self._coerce_amount(row["amount"]),
                "status": str(row["status"]),
            }
            for row in rows
        ]

    @staticmethod
    def _coerce_int(value: Any) -> Any:
        """Return ``int(value)``, or ``value`` unchanged when it cannot be converted."""
        try:
            return int(value)
        except (TypeError, ValueError, OverflowError):
            return value

    @staticmethod
    def _coerce_amount(value: Any) -> Any:
        """Return ``value`` as a float rounded to 2 decimals.

        ``None`` and non-numeric values are returned unchanged.
        """
        if value is None:
            return None
        try:
            return round(float(value), 2)
        except (TypeError, ValueError, OverflowError):
            return value

    def _record_event(self, event: str) -> None:
        """Queue ``event`` to be scored by ``_compute_reward`` for this step."""
        self._pending_events.append(event)

    def _resolve_timeout(self, timeout_s: Optional[float]) -> float:
        """Clamp a requested timeout into [0.1, MAX_ACTION_TIMEOUT_S].

        Uses ``DEFAULT_ACTION_TIMEOUT_S`` when none is given.
        """
        if timeout_s is None:
            return DEFAULT_ACTION_TIMEOUT_S
        return max(0.1, min(float(timeout_s), MAX_ACTION_TIMEOUT_S))

    def _is_allowed_file(
        self, allowed_registry: dict[str, frozenset[str]], basename: str
    ) -> bool:
        """True if ``basename`` is listed for the active task in ``allowed_registry``."""
        return basename in allowed_registry.get(self._state.task_id, frozenset())

    def _resolve_workspace_path(self, basename: str) -> Optional[str]:
        """Resolve ``basename`` inside the workspace.

        Symlinks are followed. Returns ``None`` if the result is the workspace
        root itself or lies outside it.
        """
        workspace_root = os.path.realpath(self._workspace_dir)
        candidate = os.path.realpath(os.path.join(self._workspace_dir, basename))
        if candidate == workspace_root:
            return None
        if not candidate.startswith(f"{workspace_root}{os.sep}"):
            return None
        return candidate

    def _statement_type(self, query: str) -> str:
        """Upper-cased first whitespace-delimited token of ``query`` ("" if empty)."""
        parts = query.split(None, 1)
        return parts[0].upper() if parts else ""

    def _validate_sql_action(self, query: str, statement_type: str) -> Optional[str]:
        """Check ``query`` against the active task's SQL policy.

        Returns:
            A user-facing error message, or ``None`` if the query is allowed.

        NOTE(review): table detection is regex-based and is a guard rail, not a
        sandbox. Only the identifier right after FROM/JOIN/UPDATE/INTO is
        inspected, so comma-separated FROM lists are checked only up to their
        first table. The sanitiser also blanks double-quoted identifiers.
        """
        if not query:
            return "SQL query cannot be empty."
        policy = TASK_SQL_POLICIES.get(self._state.task_id)
        if policy is None:
            return "SQL is not available for the active task."
        if statement_type not in policy.allowed_commands:
            allowed = ", ".join(sorted(policy.allowed_commands))
            return f"Only {allowed} statements are allowed for this task."
        # Inspect a copy with comments and literal contents removed, so text
        # inside strings cannot trigger or evade the checks below.
        sanitized = self._strip_sql_literals_and_comments(query)
        normalized = " ".join(sanitized.split())
        lowered = normalized.lower()
        if ";" in normalized:
            return "Only a single SQL statement is allowed."
        if any(
            token in lowered
            for token in ("pragma", "attach", "detach", "sqlite_", "alter ", "drop ")
        ):
            return "Query contains disallowed SQL constructs."
        # SQLite accepts "WITH ... DELETE/UPDATE/INSERT/REPLACE". _handle_sql
        # treats every WITH statement as a read. That path skips the
        # DELETE ... WHERE rule below and never records mutation/destructive
        # evidence, so only read-only WITH queries are accepted.
        if statement_type == "WITH" and re.search(
            r"\b(?:insert|update|delete)\b|\breplace\s+into\b", lowered
        ):
            return "WITH clauses may only wrap read-only SELECT queries."
        if statement_type == "DELETE" and not re.match(
            rf"^delete\s+from\s+{re.escape(policy.required_table)}\s+where\b",
            lowered,
        ):
            return f"DELETE statements must target '{policy.required_table}' with an explicit WHERE clause."
        cte_names = self._extract_cte_names(normalized)
        table_refs = self._extract_sql_table_refs(normalized)
        if policy.required_table not in table_refs:
            return f"Query must target the '{policy.required_table}' table."
        allowed_refs = {policy.required_table, *cte_names}
        disallowed = sorted(ref for ref in table_refs if ref not in allowed_refs)
        if disallowed:
            return f"Query references disallowed table(s): {', '.join(disallowed)}."
        return None

    def _strip_sql_literals_and_comments(self, query: str) -> str:
        """Replace comments with a space and every quoted literal with ``''``."""
        without_comments = _SQL_COMMENT_RE.sub(" ", query)
        return _SQL_STRING_RE.sub("''", without_comments)

    def _extract_cte_names(self, query: str) -> set[str]:
        """Lower-cased CTE names declared by a query that starts with WITH."""
        lowered = query.lower().lstrip()
        if not lowered.startswith("with "):
            return set()
        return {match.group(1).lower() for match in _SQL_CTE_NAME_RE.finditer(query)}

    def _extract_sql_table_refs(self, query: str) -> set[str]:
        """Lower-cased identifiers following FROM/JOIN/UPDATE/INTO in ``query``."""
        return {match.group(1).lower() for match in _SQL_TABLE_REF_RE.finditer(query)}

    def _obs(
        self, status: str, message: str, *, done: bool = False
    ) -> DataOpsObservation:
        """Build a message-only observation stamped with the current step counters."""
        return DataOpsObservation(
            status=status,
            message=message,
            step_count=self._state.step_count,
            max_steps=MAX_STEPS,
            done=done,
        )