# Uploaded by rohan9977 as app/env.py with huggingface_hub
# (commit c21e94b, verified).
import random
import asyncio
import re
from datetime import datetime, timezone
from typing import Optional, Tuple
import sqlite3
from app.models import Action, Observation, Reward, StateSnapshot
from app.state_manager import EpisodeState, generate_episode, get_schema_info, take_snapshot
from app.reward import RewardEngine
from app.tasks import get_task
from app.graders import grade_task1, grade_task2, grade_task3
# Statement keywords a read-only "query" action may start with.
_QUERY_ALLOWED_STARTS = frozenset({"SELECT", "WITH", "EXPLAIN", "PRAGMA"})

# Statement keywords a "ddl" (write) action may start with.
_DDL_ALLOWED_STARTS = frozenset({"UPDATE", "INSERT", "DELETE", "ALTER", "CREATE", "DROP"})

# (compiled pattern, human-readable label) pairs rejected anywhere in a query.
_QUERY_BLOCKED_PATTERNS = (
    (re.compile(r"DROP\s+TABLE"), "DROP TABLE"),
    (re.compile(r"DELETE\s+FROM\s+SQLITE_MASTER"), "DELETE FROM SQLITE_MASTER"),
    (re.compile(r"DROP\s+INDEX"), "DROP INDEX"),
    (re.compile(r"ATTACH\s+DATABASE"), "ATTACH DATABASE"),
)

# Single-quoted string literals and double-quoted identifiers (with their
# doubled-quote escapes). They are blanked out before keyword scans so that
# text inside a literal cannot cause a false match.
_QUOTED_SPAN = re.compile(r"'(?:[^']|'')*'|\"(?:[^\"]|\"\")*\"")

# DML keywords that turn a `WITH ...` statement into a write. SQLite allows
# a CTE to prefix INSERT / UPDATE / DELETE / REPLACE.
_CTE_WRITE_KEYWORDS = re.compile(r"\b(?:INSERT|UPDATE|DELETE|REPLACE\s+INTO)\b")

# Writes to SQLite's bookkeeping tables. Also covers OR-conflict clauses
# (`UPDATE OR IGNORE ...`), schema-qualified or quoted names
# (`main."sqlite_sequence"`) and the sqlite_schema aliases.
_SYSTEM_TABLE_WRITE = re.compile(
    r"\b(?:UPDATE(?:\s+OR\s+\w+)?"
    r"|(?:INSERT|REPLACE)(?:\s+OR\s+\w+)?\s+INTO"
    r"|DELETE\s+FROM)"
    r'\s+(?:["`\[]?\w+["`\]]?\s*\.\s*)?["`\[]?'
    r"(?:SQLITE_MASTER|SQLITE_SEQUENCE|SQLITE_SCHEMA"
    r"|SQLITE_TEMP_MASTER|SQLITE_TEMP_SCHEMA)\b"
)


def validate_sql(sql: str, action_type: str) -> Tuple[bool, str]:
    """Statically screen agent-supplied SQL before it is executed.

    Args:
        sql: Raw SQL text from the agent's action.
        action_type: ``"query"`` (must be read-only) or ``"ddl"`` (restricted
            writes). Any other value is not screened and is accepted.

    Returns:
        ``(True, "")`` if the statement may run, otherwise ``(False, reason)``.

    This is a keyword/regex screen, not a SQL parser. It is deliberately
    conservative and may reject some harmless statements.
    """
    if not sql:
        return False, "Empty SQL statement"
    sql_upper = sql.strip().upper()
    tokens = sql_upper.split()
    if not tokens:
        return False, "Empty SQL statement"
    first_token = tokens[0]

    if action_type == "query":
        if first_token not in _QUERY_ALLOWED_STARTS:
            return False, "Only SELECT statements allowed in query actions"
        for pattern, label in _QUERY_BLOCKED_PATTERNS:
            if pattern.search(sql_upper):
                return False, f"Blocked pattern detected in query: {label}"
        unquoted = _QUOTED_SPAN.sub("''", sql_upper)
        # A CTE may prefix a DML statement. Such a WITH statement is a write
        # and must not run through the read-only query path.
        if first_token == "WITH" and _CTE_WRITE_KEYWORDS.search(unquoted):
            return False, "Only SELECT statements allowed in query actions"
        # `PRAGMA name = value` changes connection/database settings.
        # NOTE(review): the `PRAGMA name(value)` form can also set some
        # pragmas; consider an explicit allowlist if that matters here.
        if first_token == "PRAGMA" and "=" in unquoted:
            return False, "PRAGMA assignments are not allowed in query actions"
        return True, ""

    if action_type == "ddl":
        if re.search(r"\bDROP\s+TABLE\b", sql_upper):
            return False, "DROP TABLE is blocked"
        if re.search(r"\bATTACH\b|\bDETACH\b", sql_upper):
            return False, "ATTACH and DETACH are blocked"
        if _SYSTEM_TABLE_WRITE.search(sql_upper):
            return False, "Writing to sqlite_master or sqlite_sequence is blocked"
        if first_token == "ALTER" and not re.search(
            r"^ALTER\s+TABLE\s+.*?\s+RENAME\s+COLUMN", sql_upper
        ):
            return False, "Only ALTER TABLE ... RENAME COLUMN is allowed"
        if first_token == "CREATE":
            # Permanent tables are rejected outright; otherwise only views and
            # TEMP/TEMPORARY tables may be created.
            if re.search(r"^CREATE\s+TABLE", sql_upper):
                return False, "Only temporary tables are allowed (CREATE TEMP TABLE)"
            if not (
                re.search(r"^CREATE\s+(TEMP|TEMPORARY)\s+TABLE", sql_upper)
                or re.search(r"^CREATE\s+VIEW", sql_upper)
            ):
                return False, "Only CREATE VIEW or CREATE TEMP TABLE allowed for CREATE"
        if first_token == "DROP" and not re.search(r"^DROP\s+VIEW", sql_upper):
            return False, "Only DROP VIEW is allowed for DROP statements"
        if first_token not in _DDL_ALLOWED_STARTS:
            return False, f"DDL action does not allow '{first_token}' statements"
        return True, ""

    return True, ""
class DataOpsEnv:
    """Environment that runs one SQLite-backed data-ops episode at a time.

    ``reset`` starts an episode for task 1, 2 or 3. ``step`` applies one agent
    action and returns an ``(Observation, Reward)`` pair. Both coroutines hold
    ``self._lock`` for their whole body, so those calls on one instance do not
    interleave.
    """

    # Query results shown to the agent are capped at this many rows...
    _MAX_RESULT_ROWS = 10
    # ...and each cell's text at this many characters ("..." is appended).
    _MAX_CELL_CHARS = 100
    # Observations carry at most this many system-log lines.
    _MAX_LOG_LINES = 20
    # A progress hint is offered once current_step exceeds this value while
    # the grader score is still below the threshold.
    _HINT_AFTER_STEP = 8
    _HINT_SCORE_THRESHOLD = 0.1
    # The "test" action's target_table is interpolated into SQL text, so it
    # must be a plain or double-quoted identifier, optionally schema-qualified.
    _TABLE_NAME_RE = re.compile(
        r'(?:[A-Za-z_]\w*|"[^"]+")(?:\.(?:[A-Za-z_]\w*|"[^"]+"))?'
    )

    def __init__(self):
        self.state: Optional[EpisodeState] = None
        self.reward_engine: Optional[RewardEngine] = None
        self.task_config: Optional[dict] = None
        self._lock = asyncio.Lock()
        # Refreshed (UTC) at the start of every reset/step.
        self.last_activity = datetime.now(timezone.utc)
        # Grader score after the most recent reset/step. It is reused as the
        # next step's "before" score instead of re-running the grader.
        self._last_grader_score = None

    async def reset(
        self, task_id: int, seed: Optional[int] = None, difficulty_multiplier: float = 1.0
    ) -> Observation:
        """Start a new episode and return its initial observation.

        Args:
            task_id: Which task to run (1, 2 or 3).
            seed: Forwarded to ``generate_episode``.
            difficulty_multiplier: Forwarded to ``generate_episode``.

        Raises:
            ValueError: if ``task_id`` is not 1, 2 or 3.
        """
        async with self._lock:
            self.last_activity = datetime.now(timezone.utc)
            if task_id not in (1, 2, 3):
                raise ValueError("task_id must be 1, 2, or 3")
            self.state = generate_episode(task_id, seed, difficulty_multiplier)
            task_info = get_task(task_id)
            self.task_config = self._build_task_config(task_id)
            self.reward_engine = RewardEngine(self.task_config)
            system_logs = self._read_system_logs(task_id)
            self._last_grader_score = self.grader_score()
            return Observation(
                current_step=0,
                max_steps=self.state.max_steps,
                task_id=task_id,
                task_description=task_info["description"],
                last_action_status="NONE",
                last_error_message=None,
                query_results=[],
                results_truncated=False,
                total_rows_returned=0,
                schema_info=get_schema_info(self.state),
                system_logs=system_logs[: self._MAX_LOG_LINES],
                logs_truncated=len(system_logs) > self._MAX_LOG_LINES,
                progress_hint=None,
            )

    def _build_task_config(self, task_id: int) -> dict:
        """Build the per-task configuration passed to RewardEngine.

        Values are read from the freshly generated ``self.state``.
        """
        config = {"task_id": task_id}
        main_table = self.state.table_registry.get("main")
        if task_id == 1:
            id_col = self.state.column_registry.get("id")
            rows = self.state.initial_snapshot.get(main_table, [])
            config["initial_null_count"] = sum(1 for r in rows if r.get(id_col) is None)
        elif task_id == 2:
            rows = self.state.initial_snapshot.get(main_table, [])
            config["total_rows"] = len(rows)
            config["pii_columns"] = [
                self.state.column_registry.get("email"),
                self.state.column_registry.get("phone"),
            ]
            config["ssn_col"] = self.state.column_registry.get("ssn_col")
        elif task_id == 3:
            config["expected_view_output"] = True
        return config

    def _read_system_logs(self, task_id: int) -> list:
        """Return the ``msg`` values of the episode's error-log table.

        Only task 3 has system logs. This is best effort: if no error-log
        table is registered, or reading it fails for any reason, the result
        is ``[]``.
        """
        if task_id != 3:
            return []
        err_table = self.state.table_registry.get("error_log")
        if not err_table:
            return []
        try:
            cursor = self.state.db.cursor()
            cursor.execute(f"SELECT msg FROM {err_table}")
            return [r["msg"] for r in cursor.fetchall()]
        except Exception:
            # Logs are auxiliary context; never fail an observation over them.
            return []

    def grader_score(self) -> float:
        """Return the current grader score for the active task, or 0.0 if none."""
        if not self.state:
            return 0.0
        grader = {1: grade_task1, 2: grade_task2, 3: grade_task3}.get(self.state.task_id)
        if grader is None:
            return 0.0
        return grader(self.state.db, self.state)

    def get_state(self) -> StateSnapshot:
        """Return a snapshot of the episode: tables, trajectory and grader score.

        Raises:
            ValueError: if ``reset`` has not been called yet.
        """
        if not self.state:
            raise ValueError("Environment not initialized")
        tables = take_snapshot(self.state)
        return StateSnapshot(
            episode_id=self.state.episode_id,
            task_id=self.state.task_id,
            current_step=self.state.current_step,
            tables=tables,
            trajectory=self.state.trajectory,
            grader_score=self.grader_score(),
            seed=self.state.seed,
            difficulty_multiplier=self.state.difficulty_multiplier,
        )

    @staticmethod
    def _truncate_cell(value, max_len: int = _MAX_CELL_CHARS):
        """Stringify a result cell and cap it at ``max_len`` characters.

        ``None`` is passed through unchanged.
        """
        if value is None:
            return None
        text = str(value)
        return text[:max_len] + "..." if len(text) > max_len else text

    def _execute_action(self, action, action_dict: dict, action_type, sql) -> dict:
        """Validate and run one action against the episode database.

        Returns the action-result dict that RewardEngine.compute receives.
        Failures are recorded in it (``status == "ERROR"``) rather than
        raised. The step counter advances only when the action passes SQL
        validation.
        """
        action_result = {
            "status": "SUCCESS",
            "error_message": None,
            "rows": [],
            "results_truncated": False,
            "total_rows_returned": 0,
        }
        if action_type in ("query", "ddl"):
            is_valid, val_msg = validate_sql(sql, action_type)
            if not is_valid:
                # NOTE(review): rejected SQL does not advance current_step, so
                # it never counts toward max_steps — confirm this is intended.
                action_result["status"] = "ERROR"
                action_result["error_message"] = val_msg
                return action_result

        self.state.current_step += 1
        try:
            cursor = self.state.db.cursor()
            if action_type == "query":
                cursor.execute(sql)
                all_rows = cursor.fetchall()
                col_names = [d[0] for d in cursor.description] if cursor.description else []
                action_result["rows"] = [
                    {col: self._truncate_cell(val) for col, val in zip(col_names, row)}
                    for row in all_rows[: self._MAX_RESULT_ROWS]
                ]
                action_result["results_truncated"] = len(all_rows) > self._MAX_RESULT_ROWS
                action_result["total_rows_returned"] = len(all_rows)
            elif action_type == "ddl":
                cursor.execute(sql)
                self.state.db.commit()
            elif action_type == "test":
                target_table = getattr(action, "target_table", action_dict.get("target_table"))
                # The name is spliced into SQL text, so refuse anything that
                # is not an identifier (prevents SQL injection via this field).
                if not isinstance(target_table, str) or not self._TABLE_NAME_RE.fullmatch(target_table):
                    raise ValueError(f"Invalid target_table: {target_table!r}")
                cursor.execute(f"SELECT COUNT(*) as cnt FROM {target_table}")
                action_result["rows"] = [dict(r) for r in cursor.fetchall()]
            elif action_type == "submit":
                self.state.done = True
        except Exception as e:
            action_result["status"] = "ERROR"
            action_result["error_message"] = str(e)
        return action_result

    async def step(self, action: Action, session_id: str = "") -> Tuple[Observation, Reward]:
        """Apply one agent action and return ``(observation, reward)``.

        ``session_id`` is accepted but not used here. Errors are reported
        in-band rather than raised. SQL/validation failures give an ERROR
        status in the observation. Any other failure, including stepping an
        inactive episode, yields an error observation plus a small negative
        reward.
        """
        async with self._lock:
            try:
                self.last_activity = datetime.now(timezone.utc)
                if not self.state or self.state.done:
                    raise RuntimeError("Episode is not active. Call reset().")
                score_before = getattr(self, "_last_grader_score", None)
                if score_before is None:
                    score_before = self.grader_score()
                try:
                    action_dict = action.model_dump()
                except AttributeError:
                    action_dict = action if isinstance(action, dict) else dict(action)
                action_type = getattr(action, "action_type", action_dict.get("action_type"))
                state_before = self.get_state().model_dump()
                sql = getattr(action, "sql", action_dict.get("sql", ""))

                action_result = self._execute_action(action, action_dict, action_type, sql)

                score_after = self.grader_score()
                self._last_grader_score = score_after
                state_after = self.get_state().model_dump()
                state_after["grader_score"] = score_after
                step_reward_val, breakdown = self.reward_engine.compute(
                    action=action_dict,
                    action_result=action_result,
                    state_before=state_before,
                    state_after=state_after,
                    grader_score_before=score_before,
                    grader_score_after=score_after,
                )

                truncated = self.state.current_step >= self.state.max_steps
                if truncated:
                    self.state.done = True

                progress_hint = None
                if (
                    self.state.current_step > self._HINT_AFTER_STEP
                    and score_after < self._HINT_SCORE_THRESHOLD
                ):
                    hints = get_task(self.state.task_id).get("hints", [])
                    progress_hint = (
                        random.choice(hints) if hints else "Review the schema and target carefully."
                    )

                system_logs = self._read_system_logs(self.state.task_id)
                obs = Observation(
                    current_step=self.state.current_step,
                    max_steps=self.state.max_steps,
                    task_id=self.state.task_id,
                    task_description=get_task(self.state.task_id)["description"],
                    last_action_status=action_result["status"],
                    last_error_message=action_result["error_message"],
                    query_results=action_result["rows"],
                    results_truncated=action_result.get("results_truncated", False),
                    total_rows_returned=action_result.get("total_rows_returned", 0),
                    schema_info=get_schema_info(self.state),
                    system_logs=system_logs[: self._MAX_LOG_LINES],
                    logs_truncated=len(system_logs) > self._MAX_LOG_LINES,
                    progress_hint=progress_hint,
                )
                reward = Reward(
                    step_reward=step_reward_val,
                    cumulative_reward=self.reward_engine.cumulative,
                    reward_breakdown=breakdown,
                    done=self.state.done,
                    truncated=truncated,
                    grader_score_before=score_before,
                    grader_score_after=score_after,
                )
                self.state.trajectory.append({
                    "action": action_dict,
                    "observation": obs.model_dump(),
                    "reward": reward.model_dump(),
                })
                return obs, reward
            except sqlite3.OperationalError as e:
                # SQL syntax errors, missing tables, broken views
                return self._error_observation(
                    error_msg=f"SQL error: {e}",
                    reward_penalty=-0.05,
                ), self._error_reward(breakdown={"sql_error": -0.05})
            except sqlite3.DatabaseError as e:
                # Corrupted state, PRAGMA failures, trigger issues
                return self._error_observation(
                    error_msg=f"Database error: {e}",
                    reward_penalty=-0.10,
                ), self._error_reward(breakdown={"db_error": -0.10})
            except Exception:
                # Catch-all for unexpected agent-triggered edge cases. The
                # traceback is kept in the trajectory for debugging; the
                # returned observation carries only a generic message.
                import traceback
                internal_log = traceback.format_exc()
                if self.state:
                    self.state.trajectory.append({
                        "step": self.state.current_step,
                        "internal_error": internal_log[:500],
                    })
                return self._error_observation(
                    error_msg="Internal error — action could not be processed",
                    reward_penalty=-0.05,
                ), self._error_reward(breakdown={"internal_error": -0.05})

    def _error_observation(self, error_msg: str, reward_penalty: float) -> Observation:
        """Build a minimal ERROR observation carrying ``error_msg``.

        ``reward_penalty`` is accepted for call-site readability but is not
        used; the penalty is applied by ``_error_reward``.
        """
        return Observation(
            current_step=self.state.current_step if self.state else 0,
            max_steps=self.state.max_steps if self.state else 20,
            task_id=self.state.task_id if self.state else 0,
            task_description="",
            last_action_status="ERROR",
            last_error_message=error_msg,
            query_results=[],
            schema_info={},
            system_logs=[f"ERROR: {error_msg}"],
            results_truncated=False,
            total_rows_returned=0,
            progress_hint=None,
        )

    def _error_reward(self, breakdown: dict) -> Reward:
        """Build the Reward for a step that failed outside normal handling.

        The step reward is the sum of ``breakdown``. When an episode exists,
        it is added to ``cumulative_reward``. ``done`` mirrors the episode's
        real state, and the step budget is enforced here as well, so an
        agent looping until ``done`` cannot spin forever on repeated errors.
        Both grader-score fields repeat the last known score, because no new
        score could be measured.
        """
        step_reward = sum(breakdown.values())
        done = False
        truncated = False
        last_score = 0.0
        if self.state:
            self.state.cumulative_reward += step_reward
            if not self.state.done and self.state.current_step >= self.state.max_steps:
                self.state.done = True
                truncated = True
            done = bool(self.state.done)
            known_score = getattr(self, "_last_grader_score", None)
            if known_score is not None:
                last_score = known_score
        return Reward(
            step_reward=step_reward,
            cumulative_reward=self.state.cumulative_reward if self.state else step_reward,
            reward_breakdown=breakdown,
            done=done,
            truncated=truncated,
            grader_score_before=last_score,
            grader_score_after=last_score,
        )