Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import os | |
| import random | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Any | |
| import requests | |
# Base URL of the environment server. Taken from the ENV_URL environment
# variable when set, otherwise the local default; trailing slashes are removed
# so endpoint paths can be appended with a single "/".
ENV_URL = os.getenv("ENV_URL", "http://localhost:7860").rstrip("/")
@dataclass(frozen=True)
class TaskSpec:
    """Static description of one task: the SQL queries the agent may run.

    The index of a query in ``queries`` is the action index used for it.
    The decorator is required because instances are built with the keyword
    constructor ``TaskSpec(queries=[...])``; a plain class with only an
    annotation would reject that call with ``TypeError``.
    """

    # SQL statements, in action-index order.
    queries: list[str]
# Per-task query lists, keyed by task id. A query's position in its list is
# the action index the learner uses to select it.
_TASK_QUERIES: dict[int, list[str]] = {
    # Task 1: customers table.
    1: [
        # Sample rows.
        "SELECT * FROM customers LIMIT 5",
        # Count rows with NULL email / NULL customer_id.
        "SELECT SUM(CASE WHEN email IS NULL THEN 1 ELSE 0 END) AS null_email, "
        "SUM(CASE WHEN customer_id IS NULL THEN 1 ELSE 0 END) AS null_customer_id FROM customers",
        # Count surplus copies of fully duplicated rows.
        "SELECT COALESCE(SUM(c-1),0) AS duplicate_rows FROM ("
        "SELECT customer_id, email, name, signup_date, country, COUNT(*) AS c "
        "FROM customers GROUP BY 1,2,3,4,5 HAVING COUNT(*) > 1) t",
    ],
    # Task 2: orders table.
    2: [
        # Sample rows.
        "SELECT * FROM orders LIMIT 5",
        # Count rows with a negative quantity.
        "SELECT SUM(CASE WHEN quantity < 0 THEN 1 ELSE 0 END) AS negative_quantity_rows FROM orders",
        # Count rows whose amount cannot be cast to DOUBLE after dropping '$'.
        "SELECT SUM(CASE WHEN try_cast(replace(amount, '$', '') AS DOUBLE) IS NULL THEN 1 ELSE 0 END) AS unparseable_amount_rows FROM orders",
        # Row counts per status value.
        "SELECT status, COUNT(*) AS n FROM orders GROUP BY status ORDER BY n DESC",
    ],
    # Task 3: transactions_baseline vs transactions_current.
    3: [
        # Mean amount in each table.
        "SELECT (SELECT AVG(amount) FROM transactions_baseline) AS baseline_mean, (SELECT AVG(amount) FROM transactions_current) AS current_mean",
        # Categories present in current but absent from baseline.
        "SELECT DISTINCT c.category FROM transactions_current c LEFT JOIN (SELECT DISTINCT category FROM transactions_baseline) b "
        "ON c.category=b.category WHERE b.category IS NULL ORDER BY c.category",
        # Fraction of current rows whose user_id never appears in baseline.
        "SELECT AVG(CASE WHEN b.user_id IS NULL THEN 1.0 ELSE 0.0 END) AS new_user_row_pct "
        "FROM transactions_current c LEFT JOIN (SELECT DISTINCT user_id FROM transactions_baseline) b ON c.user_id=b.user_id",
        # Row count of the current table.
        "SELECT COUNT(*) AS n FROM transactions_current",
    ],
}

# Task id -> TaskSpec, built from the query table above.
TASK_SPECS: dict[int, TaskSpec] = {
    task_id: TaskSpec(queries=sql_list) for task_id, sql_list in _TASK_QUERIES.items()
}
def call_env(endpoint: str, payload: dict | None = None, method: str = "POST") -> dict:
    """Call an endpoint of the environment server and return its decoded JSON.

    Args:
        endpoint: Path appended to ``ENV_URL`` after a ``/``.
        payload: JSON body for POST requests; an empty object is sent when
            it is ``None`` or empty. Ignored for GET.
        method: ``"POST"`` issues a POST; any other value issues a GET.

    Returns:
        The JSON-decoded response body.

    Raises:
        requests.HTTPError: If the response has an error status.
    """
    target = f"{ENV_URL}/{endpoint}"
    response = (
        requests.post(target, json=payload or {}, timeout=30)
        if method == "POST"
        else requests.get(target, timeout=30)
    )
    response.raise_for_status()
    return response.json()
def as_int(v: Any, default: int = 0) -> int:
    """Best-effort conversion of ``v`` to an integer.

    Integers (including ``bool``) are returned unchanged so that values
    beyond float precision are not corrupted by a round-trip through
    ``float``. Anything else is converted with ``float`` and rounded with
    the built-in ``round``.

    Args:
        v: Value to convert (number, numeric string, ``None``, ...).
        default: Returned when ``v`` cannot be converted.

    Returns:
        The converted integer, or ``default`` on failure (non-numeric input,
        ``None``, NaN or infinity).
    """
    if isinstance(v, int):
        return int(v)
    try:
        return int(round(float(v)))
    except (TypeError, ValueError, OverflowError):
        # TypeError: None / unsupported type; ValueError: bad string or NaN;
        # OverflowError: infinity cannot be rounded to an int.
        return default
def as_float(v: Any, default: float = 0.0) -> float:
    """Convert ``v`` with ``float``; return ``default`` if that raises."""
    try:
        converted = float(v)
    except Exception:
        return default
    return converted
def state_key(task_id: int, mask: int, step: int) -> str:
    """Build the Q-table key for a (task, executed-query mask, step) state."""
    parts = (f"t{task_id}", f"m{mask}", f"s{step}")
    return "|".join(parts)
def ensure_state(q: dict[str, list[float]], key: str, num_actions: int) -> list[float]:
    """Return the action-value list stored under ``key``, creating it if absent.

    A missing entry is initialised to ``num_actions`` zeros and stored in
    ``q``. The returned list is the stored object, so in-place updates by the
    caller modify the table.
    """
    try:
        return q[key]
    except KeyError:
        fresh = [0.0] * num_actions
        q[key] = fresh
        return fresh
def epsilon_greedy(values: list[float], epsilon: float, available: list[int]) -> int:
    """Choose an action index from ``available`` with an epsilon-greedy rule.

    One random number is always drawn. If it is below ``epsilon`` a uniformly
    random available action is returned; otherwise the available action with
    the largest entry in ``values`` is returned (the first such action in
    ``available`` order on ties).
    """
    explore = random.random() < epsilon
    if explore:
        return random.choice(available)
    return max(available, key=values.__getitem__)
def update_evidence(task_id: int, action_idx: int, rows: list[dict], evidence: dict) -> None:
    """Record the fields of a query result that the final report uses.

    ``evidence`` is updated in place. Scalar fields are read from the first
    result row (an empty row is assumed when ``rows`` is empty) and coerced
    with ``as_int`` / ``as_float``. (task, action) pairs not listed below
    leave ``evidence`` untouched.

    Args:
        task_id: Task the query belongs to.
        action_idx: Index of the executed query within the task.
        rows: Result rows as dictionaries keyed by column name.
        evidence: Accumulator mutated by this call.
    """
    head = rows[0] if rows else {}
    step = (task_id, action_idx)

    # Task 3, query 1 returns one row per new category rather than a scalar.
    if step == (3, 1):
        categories = (row.get("category") for row in rows)
        evidence["new_categories"] = [str(cat) for cat in categories if cat is not None]
        return

    integer_fields = {
        (1, 1): ("null_email", "null_customer_id"),
        (1, 2): ("duplicate_rows",),
        (2, 1): ("negative_quantity_rows",),
        (2, 2): ("unparseable_amount_rows",),
    }
    float_fields = {
        (3, 0): ("baseline_mean", "current_mean"),
        (3, 2): ("new_user_row_pct",),
    }

    for column in integer_fields.get(step, ()):
        evidence[column] = as_int(head.get(column))
    for column in float_fields.get(step, ()):
        evidence[column] = as_float(head.get(column))
def build_report(task_id: int, evidence: dict) -> dict:
    """Assemble the report dictionary submitted for a task.

    Every report has the same six keys in the same order. Task 1 and task 2
    have dedicated layouts; any other ``task_id`` produces the drift report.
    Evidence fields that are missing fall back to zero or an empty list.

    Args:
        task_id: Task being reported on.
        evidence: Values gathered by ``update_evidence``.

    Returns:
        The report dictionary.
    """
    # Shared skeleton; assigning to existing keys keeps this key order.
    report: dict = {
        "null_issues": {},
        "duplicate_row_count": 0,
        "schema_violations": [],
        "drifted_columns": [],
        "drift_details": {},
        "recommended_fixes": [],
    }

    if task_id == 1:
        report["null_issues"] = {
            "email": as_int(evidence.get("null_email", 0)),
            "customer_id": as_int(evidence.get("null_customer_id", 0)),
        }
        report["duplicate_row_count"] = as_int(evidence.get("duplicate_rows", 0))
        report["recommended_fixes"] = ["Fill nulls", "Deduplicate rows"]
    elif task_id == 2:
        report["null_issues"] = {
            "negative_quantity_rows": as_int(evidence.get("negative_quantity_rows", 0)),
            "unparseable_amount_rows": as_int(evidence.get("unparseable_amount_rows", 0)),
        }
        # Fixed list of violations; not derived from evidence.
        report["schema_violations"] = [
            {"column": "amount", "issue_type": "type_violation", "example": "$12.50"},
            {"column": "order_date", "issue_type": "date_format_violation", "example": "Jan 5 2024"},
            {"column": "quantity", "issue_type": "negative_value", "example": "-1"},
            {"column": "amount", "issue_type": "unparseable", "example": "N/A"},
        ]
        report["recommended_fixes"] = ["Cast amount", "Normalize dates", "Validate ranges"]
    else:
        baseline_mean = as_float(evidence.get("baseline_mean", 0.0))
        current_mean = as_float(evidence.get("current_mean", 0.0))
        category_list = ", ".join(evidence.get("new_categories", []))
        new_user_pct = as_float(evidence.get("new_user_row_pct", 0.0))
        report["drifted_columns"] = ["amount", "category", "user_id"]
        report["drift_details"] = {
            "amount": f"mean shifted from {baseline_mean:.2f} to {current_mean:.2f}",
            "category": f"new categories: {category_list}",
            "user_id": f"new user row pct ~ {new_user_pct:.3f}",
        }
        report["recommended_fixes"] = ["Drift monitoring", "Category governance", "Upstream checks"]

    return report
def _evidence_ready(task_id: int, evidence: dict) -> bool:
    """Return True once the evidence keys required before submitting are present."""
    if task_id == 1:
        required = ("null_email", "duplicate_rows")
    elif task_id == 2:
        required = ("negative_quantity_rows", "unparseable_amount_rows")
    else:
        required = ("baseline_mean", "new_categories", "new_user_row_pct")
    return all(name in evidence for name in required)


def _legal_actions(
    task_id: int,
    mask: int,
    step: int,
    evidence: dict,
    num_query_actions: int,
    force_submit_step: int,
) -> list[int]:
    """List the action indices selectable in a state.

    Queries whose bit is not yet set in ``mask`` are legal. The submit action
    (index ``num_query_actions``) is legal once the evidence is ready or the
    step counter has reached ``force_submit_step``.
    """
    actions = [i for i in range(num_query_actions) if not mask & (1 << i)]
    if _evidence_ready(task_id, evidence) or step >= force_submit_step:
        actions.append(num_query_actions)
    return actions


def run_episode(task_id: int, q: dict[str, list[float]], epsilon: float, alpha: float, gamma: float, seed: int) -> float:
    """Play one episode against the environment and update ``q`` in place.

    The environment is reset, then for up to eight steps an action is chosen
    epsilon-greedily among the legal ones: either run a not-yet-executed
    query, or submit the report built from the evidence gathered so far.
    Submitting ends the episode.

    The TD target bootstraps only from actions that are legal in the next
    state. Slots of illegal actions are never updated and stay at their
    initial 0.0, so including them in the max would bias the target upward
    whenever the legal values are negative. The last step of the horizon is
    treated as terminal.

    Args:
        task_id: Key into ``TASK_SPECS``.
        q: Q-table mapping state keys to per-action values; mutated.
        epsilon: Exploration probability.
        alpha: Learning rate.
        gamma: Discount factor.
        seed: Seed sent to the environment's reset endpoint.

    Returns:
        Sum of the rewards received during the episode.
    """
    max_steps = 8
    force_submit_step = max_steps - 1  # submit becomes legal here regardless of evidence

    spec = TASK_SPECS[task_id]
    num_query_actions = len(spec.queries)
    submit_action = num_query_actions
    num_actions = num_query_actions + 1

    call_env("reset", {"task_id": task_id, "seed": seed})

    mask = 0  # bit i set <=> query i already executed
    evidence: dict[str, Any] = {}
    total_reward = 0.0

    for step in range(1, max_steps + 1):
        q_s = ensure_state(q, state_key(task_id, mask, step), num_actions)
        available = _legal_actions(
            task_id, mask, step, evidence, num_query_actions, force_submit_step
        )
        action = epsilon_greedy(q_s, epsilon, available)

        if action == submit_action:
            report = build_report(task_id, evidence)
            out = call_env("step", {"action": {"action_type": "submit_report", "report": report}})
            # Parsed as tolerantly as the query branch: a missing reward counts as 0.
            r = as_float((out.get("reward") or {}).get("value", 0.0))
            total_reward += r
            # Terminal transition: no bootstrap term.
            q_s[action] += alpha * (r - q_s[action])
            break

        out = call_env("step", {"action": {"action_type": "query", "sql": spec.queries[action]}})
        obs = out.get("observation") or {}
        reward = out.get("reward") or {}
        r = as_float(reward.get("value", 0.0))
        done = bool(reward.get("done", False))
        rows = obs.get("last_query_result") or []

        update_evidence(task_id, action, rows, evidence)
        mask |= 1 << action
        total_reward += r

        if done or step == max_steps:
            td_target = r
        else:
            next_step = step + 1
            q_ns = ensure_state(q, state_key(task_id, mask, next_step), num_actions)
            next_actions = _legal_actions(
                task_id, mask, next_step, evidence, num_query_actions, force_submit_step
            )
            best_next = max((q_ns[a] for a in next_actions), default=0.0)
            td_target = r + gamma * best_next

        q_s[action] += alpha * (td_target - q_s[action])
        if done:
            break

    return total_reward
def evaluate(q: dict[str, list[float]], episodes_per_task: int = 5) -> dict[str, float]:
    """Run greedy episodes on tasks 1-3 and return the average episode reward.

    Episodes use ``epsilon=0.0`` and seeds ``42, 43, ...``. They still apply
    Q-updates with a small learning rate (0.01), so ``q`` is modified
    slightly by evaluation.

    Args:
        q: Q-table to evaluate.
        episodes_per_task: Number of episodes to run for each task.

    Returns:
        ``{"task_1": ..., "task_2": ..., "task_3": ..., "mean": ...}`` where
        each task entry is the mean total reward (0 when no episodes ran)
        and ``"mean"`` is the average of the three task entries.
    """
    summary: dict[str, float] = {}
    for task_id in (1, 2, 3):
        # The summed episode reward is used as the score for that episode.
        returns = [
            run_episode(task_id, q, epsilon=0.0, alpha=0.01, gamma=0.95, seed=42 + i)
            for i in range(episodes_per_task)
        ]
        summary[f"task_{task_id}"] = sum(returns) / max(1, len(returns))
    summary["mean"] = sum(summary.values()) / 3.0
    return summary
def train(args: argparse.Namespace) -> None:
    """Train a tabular Q-learning policy and save it as JSON.

    Episodes cycle through tasks 1, 2, 3. Epsilon starts at
    ``args.epsilon_start`` and is multiplied by ``args.epsilon_decay`` after
    each episode, never dropping below ``args.epsilon_end``. Every
    ``args.log_every`` episodes a short greedy evaluation is printed; a value
    of 0 disables this logging instead of raising ``ZeroDivisionError``.

    The policy is written to ``args.output``. Missing parent directories are
    created first, so any output path works, not only the default one.

    Args:
        args: Parsed options of the ``train`` sub-command (``episodes``,
            ``seed``, ``alpha``, ``gamma``, ``epsilon_start``,
            ``epsilon_end``, ``epsilon_decay``, ``log_every``, ``output``).
    """
    random.seed(args.seed)
    q: dict[str, list[float]] = {}
    eps = args.epsilon_start

    for ep in range(1, args.episodes + 1):
        task_id = ((ep - 1) % 3) + 1
        run_episode(
            task_id,
            q,
            epsilon=eps,
            alpha=args.alpha,
            gamma=args.gamma,
            seed=args.seed + ep,
        )
        eps = max(args.epsilon_end, eps * args.epsilon_decay)
        if args.log_every and ep % args.log_every == 0:
            metrics = evaluate(q, episodes_per_task=2)
            print(f"episode={ep} epsilon={eps:.3f} eval={json.dumps(metrics)}")

    output = {
        "version": 1,
        "algo": "tabular_q_learning",
        "episodes": args.episodes,
        "q_table": q,
    }
    out_path = Path(args.output)
    # Create the destination directory here so the save cannot fail just
    # because a caller picked a path whose parent does not exist yet.
    out_path.parent.mkdir(parents=True, exist_ok=True)
    out_path.write_text(json.dumps(output), encoding="utf-8")
    print(f"saved policy -> {args.output}")
def eval_only(args: argparse.Namespace) -> None:
    """Load a saved policy file, evaluate it, and print the metrics as JSON.

    Args:
        args: Parsed options of the ``eval`` sub-command; ``policy`` is the
            path of the JSON file and ``episodes_per_task`` is forwarded to
            ``evaluate``. A file without a ``q_table`` entry is evaluated as
            an empty table.
    """
    policy_text = Path(args.policy).read_text()
    q_table = json.loads(policy_text).get("q_table", {})
    report = evaluate(q_table, episodes_per_task=args.episodes_per_task)
    print(json.dumps(report, indent=2))
def build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser with ``train`` and ``eval`` sub-commands.

    The chosen sub-command is stored in ``cmd`` and is mandatory.
    """
    parser = argparse.ArgumentParser(
        description="Train/evaluate a self-learning RL policy for DataQualityEnv"
    )
    commands = parser.add_subparsers(dest="cmd", required=True)

    # (flag, type, default) for each option, in registration order.
    train_options = (
        ("--episodes", int, 300),
        ("--seed", int, 123),
        ("--alpha", float, 0.25),
        ("--gamma", float, 0.95),
        ("--epsilon-start", float, 0.30),
        ("--epsilon-end", float, 0.03),
        ("--epsilon-decay", float, 0.995),
        ("--log-every", int, 25),
        ("--output", str, "outputs/rl_policy.json"),
    )
    eval_options = (
        ("--policy", str, "outputs/rl_policy.json"),
        ("--episodes-per-task", int, 5),
    )

    train_cmd = commands.add_parser("train")
    for flag, kind, default in train_options:
        train_cmd.add_argument(flag, type=kind, default=default)

    eval_cmd = commands.add_parser("eval")
    for flag, kind, default in eval_options:
        eval_cmd.add_argument(flag, type=kind, default=default)

    return parser
def main() -> None:
    """Parse the command line and run the selected sub-command."""
    args = build_parser().parse_args()
    if args.cmd != "train":
        eval_only(args)
        return
    # The default policy path lives under ./outputs, so make sure it exists.
    os.makedirs("outputs", exist_ok=True)
    train(args)


if __name__ == "__main__":
    main()