Spaces:
Sleeping
Sleeping
| import json | |
| import random | |
| from pathlib import Path | |
| from env.models import DifficultyLevel, TaskInfo | |
# ─────────────────────────────────────────────
# LOAD DATASETS — Round 1 + Round 2
# ─────────────────────────────────────────────
# Directory that holds the JSON datasets: "dataset", two levels above this file.
BASE_DIR = Path(__file__).parent.parent / "dataset"


def _load(filename: str) -> list[dict]:
    """Parse and return the JSON document stored at ``BASE_DIR / filename``.

    The file is decoded as UTF-8. Whatever ``json`` produces for the
    document is returned as-is; no validation of its shape is performed.
    """
    raw_text = (BASE_DIR / filename).read_text(encoding="utf-8")
    return json.loads(raw_text)
# Round 1: single-query cases, retained for backward compatibility.
EASY_CASES = _load("easy_cases.json")
MEDIUM_CASES = _load("medium_cases.json")
HARD_CASES = _load("hard_cases.json")

# Round 2: long-horizon database-engineering scenarios.
EASY_SCENARIOS = _load("easy_scenarios.json")
MEDIUM_SCENARIOS = _load("medium_scenarios.json")
HARD_SCENARIOS = _load("hard_scenarios.json")

# Round 2 pools on their own (for the training pipeline).
SCENARIO_ONLY: dict[str, list[dict]] = {
    DifficultyLevel.EASY: EASY_SCENARIOS,
    DifficultyLevel.MEDIUM: MEDIUM_SCENARIOS,
    DifficultyLevel.HARD: HARD_SCENARIOS,
}

# Everything per difficulty. Each value is a new list with the Round 2
# scenarios first, followed by the Round 1 cases.
ALL_CASES: dict[str, list[dict]] = {
    DifficultyLevel.EASY: [*EASY_SCENARIOS, *EASY_CASES],
    DifficultyLevel.MEDIUM: [*MEDIUM_SCENARIOS, *MEDIUM_CASES],
    DifficultyLevel.HARD: [*HARD_SCENARIOS, *HARD_CASES],
}
# ─────────────────────────────────────────────
# ACTION SCHEMA (required by /tasks validator)
# ─────────────────────────────────────────────
# Maps each action name to a human-readable description and its payload
# fields. Every payload field is described by a dict with "type",
# "required" and "description" keys.
ACTION_SCHEMA = {
    # ── Round 1 actions ──────────────────────────────────────────
    "identify_error": {
        "description": "Identify where and what the error is without fixing it yet",
        "payload_fields": {
            "error_location": {"type": "string", "required": True, "description": "Where in the query the error occurs"},
            "error_type": {"type": "string", "required": True, "description": "Type: syntax | logic | performance"},
            "explanation": {"type": "string", "required": False, "description": "Brief explanation of the error"},
        },
    },
    "propose_fix": {
        "description": "Propose a fix without submitting as final answer",
        "payload_fields": {
            "fixed_query": {"type": "string", "required": True, "description": "The proposed corrected SQL query"},
            "change_made": {"type": "string", "required": True, "description": "What specifically was changed"},
            "confidence": {"type": "float", "required": False, "description": "Confidence score 0.0-1.0"},
        },
    },
    "submit_answer": {
        "description": "Submit the final fixed query as the definitive answer",
        "payload_fields": {
            "fixed_query": {"type": "string", "required": True, "description": "Final corrected SQL query"},
            "explanation": {"type": "string", "required": True, "description": "Full explanation of fix"},
            "error_type": {"type": "string", "required": False, "description": "syntax | logic | performance"},
            "confidence": {"type": "float", "required": False, "description": "Confidence 0.0-1.0"},
        },
    },
    "request_hint": {
        "description": "Request a hint — costs 0.10 reward penalty per hint",
        "payload_fields": {
            "hint_type": {"type": "string", "required": False, "description": "location | error_type | fix_direction"},
        },
    },
    "explain_issue": {
        "description": "Explain the issue in detail",
        "payload_fields": {
            "explanation": {"type": "string", "required": True, "description": "Detailed explanation"},
            "impact": {"type": "string", "required": False, "description": "Impact on query performance"},
            "root_cause": {"type": "string", "required": False, "description": "Root cause analysis"},
        },
    },
    "optimize_query": {
        "description": "Submit an optimized version of the query",
        "payload_fields": {
            "optimized_query": {"type": "string", "required": True, "description": "Optimized SQL"},
            "optimization_type": {"type": "string", "required": True, "description": "What optimization was applied"},
            "expected_improvement": {"type": "string", "required": False, "description": "Expected performance gain"},
            "explanation": {"type": "string", "required": False, "description": "Why this optimization works"},
            "confidence": {"type": "float", "required": False, "description": "Confidence 0.0-1.0"},
        },
    },
    # ── Round 2 actions ──────────────────────────────────────────
    "inspect_query": {
        "description": "EXPLAIN a slow query — reveals scan type, rows examined, index usage",
        "payload_fields": {
            "query_id": {"type": "string", "required": True, "description": "ID of slow query to inspect (e.g. 'q1')"},
        },
    },
    "analyze_indexes": {
        "description": "Show all indexes on a table + usage frequency + missing index hints",
        "payload_fields": {
            "table": {"type": "string", "required": True, "description": "Table name to analyze"},
        },
    },
    "create_index": {
        "description": "Add a composite index on specified columns — core optimization action",
        "payload_fields": {
            "table": {"type": "string", "required": True, "description": "Table to index"},
            "columns": {"type": "list|string", "required": True, "description": "Columns to index (list or comma-separated string)"},
        },
    },
    "rewrite_query": {
        "description": "Submit a rewritten SQL query — system evaluates execution time improvement",
        "payload_fields": {
            "query_id": {"type": "string", "required": True, "description": "ID of query to rewrite"},
            "new_sql": {"type": "string", "required": True, "description": "Rewritten SQL query"},
        },
    },
    "add_column": {
        "description": "Add a denormalization column to reduce expensive JOINs",
        "payload_fields": {
            "table": {"type": "string", "required": True, "description": "Table to modify"},
            "column": {"type": "string", "required": True, "description": "New column name"},
            "purpose": {"type": "string", "required": False, "description": "Why this column helps"},
        },
    },
    "drop_index": {
        "description": "Remove an unused index to reduce write overhead",
        "payload_fields": {
            "table": {"type": "string", "required": True, "description": "Table name"},
            "index_name": {"type": "string", "required": True, "description": "Index name to drop (cannot drop PRIMARY)"},
        },
    },
    "partition_table": {
        "description": "Partition a large table by date or ID range for range query efficiency",
        "payload_fields": {
            "table": {"type": "string", "required": True, "description": "Table to partition"},
            "partition_by": {"type": "string", "required": False, "description": "Column to partition on (e.g. 'created_at')"},
            "partition_type": {"type": "string", "required": False, "description": "RANGE | LIST | HASH"},
        },
    },
    "analyze_statistics": {
        "description": "Update table statistics for query planner accuracy",
        "payload_fields": {
            "table": {"type": "string", "required": True, "description": "Table to analyze"},
        },
    },
    "submit_report": {
        "description": "TERMINAL: Submit final optimization report — ends episode, computes full score",
        "payload_fields": {
            "summary": {"type": "string", "required": True, "description": "Summary of optimizations applied"},
            "actions_taken": {"type": "list", "required": False, "description": "List of key actions taken"},
            "expected_gain": {"type": "string", "required": False, "description": "Expected performance improvement"},
        },
    },
}
# ─────────────────────────────────────────────
# TASK MANAGER
# ─────────────────────────────────────────────
class TaskManager:
    """Select tasks and build the agent-facing views of them.

    Two task formats are handled:

    * Round 2 scenarios, recognised by the presence of a ``slow_queries`` key.
    * Round 1 cases, which carry a ``buggy_query`` instead.
    """

    def __init__(self):
        # IDs already handed out by get_task(); used to avoid repeats.
        self._used_ids: set[str] = set()

    def get_task(self, difficulty: DifficultyLevel, task_id: str | None = None) -> dict:
        """Return a task from the combined pool for ``difficulty``.

        When ``task_id`` is given, the pool is searched in order (Round 2
        scenarios are listed before Round 1 cases) and the first match is
        returned. Otherwise a task that has not been handed out yet is
        picked at random and recorded as used.

        Raises:
            ValueError: if ``task_id`` is given but not present in the pool.
        """
        pool = ALL_CASES[difficulty]

        if task_id:
            for case in pool:
                if case["id"] == task_id:
                    return case
            raise ValueError(f"Task '{task_id}' not found in {difficulty} pool")

        available = [case for case in pool if case["id"] not in self._used_ids]
        if not available:
            # This pool is exhausted: forget only its own IDs so the
            # history kept for the other difficulty levels is preserved.
            self._used_ids.difference_update(case["id"] for case in pool)
            available = pool

        task = random.choice(available)
        self._used_ids.add(task["id"])
        return task

    def get_random_task(self) -> dict:
        """Return a task from a randomly chosen difficulty level."""
        difficulty = random.choice(list(DifficultyLevel))
        return self.get_task(difficulty)

    def get_scenario(self, difficulty: DifficultyLevel, scenario_id: str | None = None) -> dict:
        """Return a Round 2 scenario for ``difficulty``.

        With ``scenario_id`` the matching scenario is returned; without it
        one is chosen at random. Unlike get_task(), no usage history is
        consulted or updated.

        Raises:
            ValueError: if ``scenario_id`` is given but not found.
        """
        pool = SCENARIO_ONLY[difficulty]
        if scenario_id:
            for scenario in pool:
                if scenario["id"] == scenario_id:
                    return scenario
            raise ValueError(f"Scenario '{scenario_id}' not found")
        return random.choice(pool)

    def build_observation_context(self, task: dict) -> dict:
        """Build the ``current_context`` dict for an Observation.

        Handles both the Round 2 scenario format and the Round 1 case
        format. Only an explicit allow-list of keys is copied from
        ``task``, so ground-truth fields such as ``fixed_query`` or
        ``optimal_actions`` are never included.
        """
        # ── Round 2 scenario format ──────────────────────────────
        if "slow_queries" in task:
            return {
                "scenario_id": task["id"],
                "description": task.get("description", ""),
                "tables": task.get("tables", []),
                "slow_queries": task.get("slow_queries", []),
                "performance_score_baseline": task.get("performance_score_baseline", 0.0),
                "target_score": task.get("target_score", 85.0),
                "max_steps": task.get("max_steps", 50),
                "category": task.get("category", ""),
                # Do NOT include missing_index_hints (that's the answer)
                # Do NOT include optimal_actions (that's the answer)
            }

        # ── Round 1 case format (backward compatible) ────────────
        context = {
            "buggy_query": task.get("buggy_query", ""),
            "error_message": task.get("error_message", ""),
            "database_schema": task.get("database_schema", ""),
            "error_type_hint": task.get("error_type", ""),
            "category": task.get("category", ""),
            "estimated_steps": task.get("estimated_fix_steps", 5),
        }

        issue = task.get("performance_issue")
        if issue:
            # Tolerate a partially filled entry, like every other field
            # above, instead of failing with KeyError.
            context["performance_issue"] = {
                "type": issue.get("type", ""),
                "impact": issue.get("impact", ""),
            }

        expected = task.get("expected_output")
        if expected and isinstance(expected, list):
            # Expose only the first row as a sample.
            context["expected_output_sample"] = expected[:1]

        return context

    def get_hint(self, task: dict, hint_number: int) -> str:
        """Return the progressive hint number ``hint_number`` for ``task``.

        Hints are 1-based and each one reveals more information.
        Out-of-range numbers are clamped: anything below 1 yields the
        first hint, anything above the last yields the last hint. Each
        hint is meant to cost 0.10 reward; this method does not apply
        that penalty itself.
        """
        if "slow_queries" in task:
            # Round 2 scenario hints
            hints = [
                "Hint 1: Start by inspecting your slow queries with inspect_query action.",
                "Hint 2: Use analyze_indexes on tables appearing in slow queries.",
                f"Hint 3: Category is '{task.get('category', 'indexing')}'. Target score: {task.get('target_score', 85.0)}.",
            ]
        else:
            # Round 1 hints
            hints = [
                f"Hint 1: The error is in the {task.get('error_location', 'query')}.",
                f"Hint 2: This is a {task.get('error_type', 'unknown')} error. Category: {task.get('category')}.",
                f"Hint 3: Fix: {task.get('fix_description', 'Review the query carefully.')}",
            ]
        index = max(0, min(hint_number - 1, len(hints) - 1))
        return hints[index]

    def list_all_tasks(self) -> list[TaskInfo]:
        """Return one TaskInfo per task in ALL_CASES, for the /tasks endpoint.

        Every entry carries the task's id, difficulty and description
        together with the shared ACTION_SCHEMA.
        """
        return [
            TaskInfo(
                id=case["id"],
                difficulty=difficulty,
                description=case.get("description", ""),
                action_schema=ACTION_SCHEMA,
            )
            for difficulty, cases in ALL_CASES.items()
            for case in cases
        ]

    def get_ground_truth(self, task_id: str) -> dict | None:
        """Return the full task, ground truth included, or None if unknown.

        Intended for the grader only: the returned dict is the stored
        task itself, not the filtered view that
        build_observation_context() produces.
        """
        return next(
            (
                case
                for cases in ALL_CASES.values()
                for case in cases
                if case["id"] == task_id
            ),
            None,
        )


# Singleton instance
task_manager = TaskManager()