Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import threading | |
| from typing import Any | |
| from fastapi import FastAPI, HTTPException | |
| from env.dataset_gen import generate_dataset | |
| from env.engine import SQLEngine | |
| from env.models import Action, EpisodeState, Observation, Reward, RewardBreakdown | |
| from tasks.base import BaseTask | |
| from tasks.task1_nulls import Task1 | |
| from tasks.task2_schema import Task2 | |
| from tasks.task3_drift import Task3 | |
| from tasks.task4_relational import Task4 | |
# FastAPI application object for this service.
app = FastAPI(title="DataQualityEnv")

# Lock taken by reset() and step() while they touch the episode globals below.
_lock = threading.Lock()

# One task instance per supported task id.
TASKS = {
    1: Task1(),
    2: Task2(),
    3: Task3(),
    4: Task4(),
}

# Step budget per episode; step() ends the episode once it is exceeded.
MAX_STEPS = 12
# Number of fix_sql attempts an episode starts with.
FIX_STEPS = 3

# Episode globals: all four are rebound by reset().
state: EpisodeState | None = None
engine: SQLEngine | None = None
gold: dict[str, Any] = {}
table_names: list[str] = []
def health() -> dict[str, str]:
    """Return a static status payload naming the environment and its version."""
    report: dict[str, str] = {}
    report["status"] = "ok"
    report["env"] = "DataQualityEnv"
    report["version"] = "2.0.0"
    return report
def reset(payload: dict):
    """Start a fresh episode and return its initial observation.

    Reads ``task_id`` (default 1) and ``seed`` (default 42) from *payload*,
    closes any previous SQL engine, generates and loads the dataset for the
    task, and rebinds the module-level ``state``, ``engine``, ``gold`` and
    ``table_names``.

    Args:
        payload: Request body; only the ``task_id`` and ``seed`` keys are read.

    Returns:
        The first observation of the new episode, as a plain dict
        (``Observation.model_dump()``).

    Raises:
        HTTPException: 400 when ``task_id`` or ``seed`` cannot be converted
            to an integer, or when ``task_id`` is not a key of ``TASKS``.
    """
    global state, engine, gold, table_names

    # Malformed numbers are a client error: answer 400 rather than letting
    # the conversion failure escape as an unhandled exception.
    try:
        task_id = int(payload.get("task_id", 1))
        seed = int(payload.get("seed", 42))
    except (TypeError, ValueError, OverflowError) as exc:
        raise HTTPException(400, f"task_id and seed must be integers: {exc}") from exc

    if task_id not in TASKS:
        raise HTTPException(400, f"task_id must be 1-4, got {task_id}")

    with _lock:
        # Identity check, not truthiness: an engine object that happened to
        # evaluate falsy would otherwise never be closed.
        if engine is not None:
            engine.close()
        engine = SQLEngine()
        tables, gold = generate_dataset(task_id, seed)
        engine.load_tables(tables)
        table_names = list(tables.keys())
        state = EpisodeState(
            task_id=task_id,
            seed=seed,
            gold_faults=gold,
            max_steps=MAX_STEPS,
            fix_steps_remaining=FIX_STEPS,
        )
        task = TASKS[task_id]
        obs = _make_observation(task, state, engine, table_names, None, None, None)
        return obs.model_dump()
def step(payload: dict):
    """Apply one agent action to the active episode.

    The action is read from ``payload["action"]`` when that key exists,
    otherwise from *payload* itself. Supported ``action_type`` values are
    ``"query"``, ``"submit_report"`` and ``"fix_sql"``.

    The whole call runs under ``_lock`` so the "is an episode active?" check
    and the mutations that follow see the same episode. The step counter is
    incremented before the action is dispatched, so a request that is later
    rejected with a 400 (for example a query without ``sql``) still consumes
    a step.

    Args:
        payload: Request body carrying the action fields.

    Returns:
        ``{"observation": ..., "reward": ...}`` as built by ``_step_response``.

    Raises:
        HTTPException: 400 when no episode is active, the action cannot be
            parsed, a required field is missing, the report was already
            submitted, ``fix_sql`` is used before a report, or the action
            type is unsupported; 500 when an episode exists without an
            engine.
    """
    with _lock:
        episode = state
        if episode is None or episode.done:
            raise HTTPException(400, "Call /reset first.")

        try:
            action = Action(**payload.get("action", payload))
        except Exception as e:
            # Any failure to build the Action is reported as a bad request;
            # the original cause stays attached for server-side debugging.
            raise HTTPException(400, f"Invalid action: {e}") from e

        task = TASKS[episode.task_id]
        eng = engine
        # Explicit check instead of `assert`, which is stripped under -O.
        if eng is None:
            raise HTTPException(500, "SQL engine is not initialised.")

        episode.step += 1
        if episode.step > MAX_STEPS:
            episode.done = True
            total, rb = _episode_totals(episode)
            obs = _make_observation(task, episode, eng, table_names, None, "max_steps", None)
            return _step_response(
                obs,
                Reward(value=total, breakdown=rb, done=True, info={"reason": "max_steps"}),
            )

        if action.action_type == "query":
            return _handle_query(task, episode, eng, action)
        if action.action_type == "submit_report":
            return _handle_submit_report(task, episode, eng, action)
        if action.action_type == "fix_sql":
            return _handle_fix_sql(task, episode, eng, action)
        raise HTTPException(400, f"Unsupported action_type: {action.action_type}")


def _episode_totals(st: EpisodeState) -> tuple[Any, RewardBreakdown]:
    """Return the episode total (audit score + fix bonus, capped at 1.25) and its breakdown."""
    total = BaseTask.strict_score(round(min(1.25, st.audit_score + st.fix_bonus), 4))
    breakdown = RewardBreakdown(
        base_audit_score=st.audit_score,
        confidence_brier_adjustment=0.0,
        budget_efficiency_bonus=0.0,
        fix_verification_bonus=round(st.fix_bonus, 4),
        total=total,
    )
    return total, breakdown


def _soft_reject(task, st: EpisodeState, eng: SQLEngine, message: str) -> dict[str, Any]:
    """Build a zero-reward, not-done response whose observation carries *message* as the error."""
    obs = _make_observation(task, st, eng, table_names, None, message, None)
    reward = Reward(value=0.0, breakdown=_zero_breakdown(), done=False, info={})
    return _step_response(obs, reward)


def _handle_query(task, st: EpisodeState, eng: SQLEngine, action: Action) -> dict[str, Any]:
    """Run an exploratory SQL query; a successful one costs one query credit."""
    if st.phase == "fix":
        return _soft_reject(task, st, eng, "Use fix_sql action in fix phase, not query.")
    if st.query_credits <= 0:
        return _soft_reject(task, st, eng, "No query credits remaining.")
    if not action.sql:
        raise HTTPException(400, "sql is required for query action")

    result = eng.execute(action.sql)
    if isinstance(result, str) and result.startswith("ERROR"):
        # A failed query is reported back to the agent and costs no credit.
        obs = _make_observation(task, st, eng, table_names, None, result, None)
        reward = Reward(value=0.0, breakdown=_zero_breakdown(), done=False, info={"error": result})
        return _step_response(obs, reward)

    st.query_credits -= 1
    rows = result if isinstance(result, list) else None
    obs = _make_observation(task, st, eng, table_names, rows, None, None)
    reward = Reward(
        value=_query_reward(action.sql, rows),
        breakdown=_zero_breakdown(),
        done=False,
        info={},
    )
    return _step_response(obs, reward)


def _handle_submit_report(task, st: EpisodeState, eng: SQLEngine, action: Action) -> dict[str, Any]:
    """Grade the audit report, add the unused-credit bonus and switch the episode to the fix phase."""
    if action.report is None:
        raise HTTPException(400, "report is required for submit_report")
    if st.report_submitted:
        raise HTTPException(400, "Report already submitted. Use fix_sql or reset.")

    base_score, score_breakdown = task.grade(action.report, gold)
    # Numeric entries go through the same strict_score normalisation as totals.
    score_breakdown = {
        k: BaseTask.strict_score(v) if isinstance(v, (int, float)) else v
        for k, v in score_breakdown.items()
    }
    # 0.01 per unused query credit, capped at 0.10.
    budget_bonus = round(min(0.10, st.query_credits * 0.01), 4)
    total = BaseTask.strict_score(round(min(1.0, base_score + budget_bonus), 4))

    st.audit_score = total
    st.report_submitted = True
    st.phase = "fix"

    breakdown = RewardBreakdown(
        base_audit_score=float(base_score),
        confidence_brier_adjustment=0.0,
        budget_efficiency_bonus=budget_bonus,
        fix_verification_bonus=0.0,
        total=total,
    )
    # With no fix attempts left there is nothing further to do in the episode.
    done = st.fix_steps_remaining == 0
    if done:
        st.done = True
    obs = _make_observation(task, st, eng, table_names, None, None, None)
    return _step_response(
        obs,
        Reward(
            value=total,
            breakdown=breakdown,
            done=done,
            info={"score_breakdown": score_breakdown, "fix_steps_available": FIX_STEPS},
        ),
    )


def _handle_fix_sql(task, st: EpisodeState, eng: SQLEngine, action: Action) -> dict[str, Any]:
    """Run one fix attempt and fold its score into the fix bonus (capped at 0.25)."""
    if not st.report_submitted:
        raise HTTPException(400, "Submit report before using fix_sql.")
    if not action.sql:
        raise HTTPException(400, "sql is required for fix_sql")

    if st.fix_steps_remaining <= 0:
        # No attempts left: close the episode without running the SQL.
        st.done = True
        total, breakdown = _episode_totals(st)
        obs = _make_observation(task, st, eng, table_names, None, None, 0.0)
        return _step_response(obs, Reward(value=total, breakdown=breakdown, done=True, info={}))

    fix_score = eng.run_fix_sql(action.sql, gold)
    st.fix_bonus = min(0.25, st.fix_bonus + fix_score * 0.08)
    st.fix_steps_remaining -= 1
    done = st.fix_steps_remaining == 0
    if done:
        st.done = True
    total, breakdown = _episode_totals(st)
    obs = _make_observation(task, st, eng, table_names, None, None, fix_score)
    return _step_response(obs, Reward(value=total, breakdown=breakdown, done=done, info={}))
def get_state():
    """Return the current episode state as a plain dict.

    Raises:
        HTTPException: 400 when no episode has been created yet.
    """
    current = state
    if current is not None:
        return current.model_dump()
    raise HTTPException(400, "No active episode.")
def _make_observation(task, st: EpisodeState, eng: SQLEngine, tables: list[str], query_result, error, last_fix_score) -> Observation:
    """Assemble the Observation describing the episode after an action.

    Table schemas and row counts are fetched from *eng* for *tables*; the
    step counter, query credits and phase are read from *st*. A list
    *query_result* is cut to its first 50 rows; anything else is reported as
    ``None``. *error* and *last_fix_score* are passed through unchanged.
    """
    table_schemas = eng.get_table_schemas(tables)
    counts = eng.get_row_counts(tables)

    preview = None
    if isinstance(query_result, list):
        preview = query_result[:50]

    fields = {
        "task_id": st.task_id,
        "task_description": task.get_description(),
        "tables": table_schemas,
        "row_counts": counts,
        "step": st.step,
        "max_steps": MAX_STEPS,
        "query_credits_remaining": st.query_credits,
        "phase": st.phase,
        "last_query_result": preview,
        "last_action_error": error,
        "last_fix_score": last_fix_score,
    }
    return Observation(**fields)
def _step_response(obs: Observation, reward: Reward) -> dict[str, Any]:
    """Bundle the dumped observation and reward into the dict returned by a step."""
    body: dict[str, Any] = {}
    body["observation"] = obs.model_dump()
    body["reward"] = reward.model_dump()
    return body
def _zero_breakdown(destructive: float = 0.0) -> RewardBreakdown:
    """Return a RewardBreakdown that is zero except for *destructive*.

    *destructive* is written to both ``fix_verification_bonus`` and
    ``total``; the three remaining components are 0.0.
    """
    components = {
        "base_audit_score": 0.0,
        "confidence_brier_adjustment": 0.0,
        "budget_efficiency_bonus": 0.0,
        "fix_verification_bonus": destructive,
        "total": destructive,
    }
    return RewardBreakdown(**components)
def _query_reward(sql: str, result: list[dict[str, Any]] | None) -> float:
    """Score an exploratory query: 0.1 when it looks like a hit, otherwise 0.01.

    A query earns 0.1 only when both hold:
      * its lower-cased SQL contains one of the data-quality keywords, and
      * at least one returned row has a value that is ``None``, a positive
        number, or a placeholder string such as ``"n/a"`` or ``""``.
    Everything else, including an empty or missing result, earns 0.01.
    """
    if not result:
        return 0.01

    lowered_sql = (sql or "").lower()
    keywords = ("null", "dup", "duplicate", "group by", "having", "count(", "is null")
    if not any(word in lowered_sql for word in keywords):
        return 0.01

    placeholders = {"null", "n/a", "na", "unknown", "none", "-", ""}
    for record in result:
        for cell in record.values():
            if cell is None:
                return 0.1
            if isinstance(cell, (int, float)):
                if cell > 0:
                    return 0.1
            elif isinstance(cell, str) and cell.strip().lower() in placeholders:
                return 0.1
    return 0.01