Spaces:
Configuration error
Configuration error
| """ | |
| DB Schema Migration Environment | |
| Three tasks: | |
| - easy: rename messy columns/tables to clean names | |
| - medium: rename + add columns + fix data types | |
| - hard: full normalization + foreign keys + type fixes | |
| """ | |
| import copy | |
| from typing import Dict, Any, List, Optional, Tuple | |
| from server.schemas import ( | |
| Action, Observation, TableInfo, ColumnInfo, | |
| OperationType, StepResult, ResetResult | |
| ) | |
| # --------------------------------------------------------------------------- | |
| # Task definitions | |
| # --------------------------------------------------------------------------- | |
# Registry of migration tasks, keyed by difficulty name. Each entry carries the
# starting schema, the human-readable requirements shown to the agent, optional
# hints, a step budget, and the schema the graders compare against.
TASKS = {
    # -- easy: pure renames on a single table -------------------------------
    "easy": {
        "description": (
            "A legacy user table was created with terrible naming conventions. "
            "Rename all tables and columns to follow snake_case conventions and "
            "meaningful names as described in the requirements."
        ),
        "initial_schema": [
            TableInfo(
                name="tbl_usr",
                columns=[
                    ColumnInfo(name="usr_id", data_type="INT", primary_key=True),
                    ColumnInfo(name="usr_nm", data_type="VARCHAR(50)"),
                    ColumnInfo(name="usr_eml", data_type="VARCHAR(100)"),
                    ColumnInfo(name="dt_crt", data_type="VARCHAR(20)"),
                    ColumnInfo(name="stat_cd", data_type="INT"),
                ],
            ),
        ],
        "target_requirements": [
            "Rename table 'tbl_usr' to 'users'",
            "Rename column 'usr_id' to 'id'",
            "Rename column 'usr_nm' to 'username'",
            "Rename column 'usr_eml' to 'email'",
            "Rename column 'dt_crt' to 'created_at'",
            "Rename column 'stat_cd' to 'status'",
        ],
        "hints": [
            "All operations are RENAME_TABLE or RENAME_COLUMN",
            "Start with the table rename, then columns",
        ],
        "max_steps": 10,
        "expected_schema": [
            TableInfo(
                name="users",
                columns=[
                    ColumnInfo(name="id", data_type="INT", primary_key=True),
                    ColumnInfo(name="username", data_type="VARCHAR(50)"),
                    ColumnInfo(name="email", data_type="VARCHAR(100)"),
                    ColumnInfo(name="created_at", data_type="VARCHAR(20)"),
                    ColumnInfo(name="status", data_type="INT"),
                ],
            ),
        ],
    },
    # -- medium: renames, type changes and added columns on two tables ------
    "medium": {
        "description": (
            "An orders database has wrong column names, wrong data types, and is "
            "missing important columns. Fix naming, types, and add the missing fields."
        ),
        "initial_schema": [
            TableInfo(
                name="order_tbl",
                columns=[
                    ColumnInfo(name="oid", data_type="VARCHAR(10)", primary_key=True),
                    ColumnInfo(name="cust", data_type="VARCHAR(50)"),
                    ColumnInfo(name="amt", data_type="VARCHAR(20)"),
                    ColumnInfo(name="ord_dte", data_type="VARCHAR(30)"),
                ],
            ),
            TableInfo(
                name="prod_tbl",
                columns=[
                    ColumnInfo(name="pid", data_type="VARCHAR(10)", primary_key=True),
                    ColumnInfo(name="pname", data_type="TEXT"),
                    ColumnInfo(name="prc", data_type="VARCHAR(20)"),
                ],
            ),
        ],
        "target_requirements": [
            "Rename table 'order_tbl' to 'orders'",
            "Rename table 'prod_tbl' to 'products'",
            "Rename column 'oid' to 'order_id' in orders",
            "Rename column 'cust' to 'customer_name' in orders",
            "Rename column 'amt' to 'total_amount' in orders",
            "Rename column 'ord_dte' to 'order_date' in orders",
            "Change type of 'total_amount' in orders to DECIMAL(10,2)",
            "Change type of 'order_date' in orders to TIMESTAMP",
            "Rename column 'pid' to 'product_id' in products",
            "Rename column 'pname' to 'product_name' in products",
            "Rename column 'prc' to 'price' in products",
            "Change type of 'price' in products to DECIMAL(10,2)",
            "Add column 'stock_quantity' (INT) to products",
            "Add column 'status' (VARCHAR(20)) to orders",
        ],
        "hints": [
            "Fix table names first, then column names, then types, then add missing columns",
            "DECIMAL(10,2) is the correct type for money fields",
        ],
        "max_steps": 20,
        "expected_schema": [
            TableInfo(
                name="orders",
                columns=[
                    ColumnInfo(name="order_id", data_type="VARCHAR(10)", primary_key=True),
                    ColumnInfo(name="customer_name", data_type="VARCHAR(50)"),
                    ColumnInfo(name="total_amount", data_type="DECIMAL(10,2)"),
                    ColumnInfo(name="order_date", data_type="TIMESTAMP"),
                    ColumnInfo(name="status", data_type="VARCHAR(20)"),
                ],
            ),
            TableInfo(
                name="products",
                columns=[
                    ColumnInfo(name="product_id", data_type="VARCHAR(10)", primary_key=True),
                    ColumnInfo(name="product_name", data_type="TEXT"),
                    ColumnInfo(name="price", data_type="DECIMAL(10,2)"),
                    ColumnInfo(name="stock_quantity", data_type="INT"),
                ],
            ),
        ],
    },
    # -- hard: split one wide table into four related tables ----------------
    "hard": {
        "description": (
            "A fully denormalized legacy table stores everything in one blob. "
            "You must normalize it into 3NF: split into proper tables, fix types, "
            "add primary keys, and establish foreign key relationships."
        ),
        "initial_schema": [
            TableInfo(
                name="everything",
                columns=[
                    ColumnInfo(name="row_id", data_type="INT", primary_key=True),
                    ColumnInfo(name="cust_name", data_type="TEXT"),
                    ColumnInfo(name="cust_email", data_type="TEXT"),
                    ColumnInfo(name="cust_phone", data_type="TEXT"),
                    ColumnInfo(name="item_name", data_type="TEXT"),
                    ColumnInfo(name="item_price", data_type="TEXT"),
                    ColumnInfo(name="item_qty", data_type="TEXT"),
                    ColumnInfo(name="order_date", data_type="TEXT"),
                    ColumnInfo(name="order_total", data_type="TEXT"),
                ],
            ),
        ],
        "target_requirements": [
            "Create table 'customers' with columns: customer_id (INT, PK), name (VARCHAR(100)), email (VARCHAR(150)), phone (VARCHAR(20))",
            "Create table 'products' with columns: product_id (INT, PK), product_name (VARCHAR(200)), price (DECIMAL(10,2))",
            "Create table 'orders' with columns: order_id (INT, PK), customer_id (INT, FK->customers), order_date (TIMESTAMP), total_amount (DECIMAL(10,2))",
            "Create table 'order_items' with columns: item_id (INT, PK), order_id (INT, FK->orders), product_id (INT, FK->products), quantity (INT)",
            "Add foreign key: orders.customer_id -> customers.customer_id",
            "Add foreign key: order_items.order_id -> orders.order_id",
            "Add foreign key: order_items.product_id -> products.product_id",
            "Remove the 'everything' table after normalization",
        ],
        "hints": [
            "Normalize step by step: customers → products → orders → order_items",
            "Foreign keys require the referenced table to exist first",
            "Use NORMALIZE_TABLE action to create the new tables from 'everything'",
        ],
        "max_steps": 30,
        "expected_tables": ["customers", "products", "orders", "order_items"],
        "expected_schema": [
            TableInfo(
                name="customers",
                columns=[
                    ColumnInfo(name="customer_id", data_type="INT", primary_key=True),
                    ColumnInfo(name="name", data_type="VARCHAR(100)"),
                    ColumnInfo(name="email", data_type="VARCHAR(150)"),
                    ColumnInfo(name="phone", data_type="VARCHAR(20)"),
                ],
            ),
            TableInfo(
                name="products",
                columns=[
                    ColumnInfo(name="product_id", data_type="INT", primary_key=True),
                    ColumnInfo(name="product_name", data_type="VARCHAR(200)"),
                    ColumnInfo(name="price", data_type="DECIMAL(10,2)"),
                ],
            ),
            TableInfo(
                name="orders",
                columns=[
                    ColumnInfo(name="order_id", data_type="INT", primary_key=True),
                    ColumnInfo(name="customer_id", data_type="INT", foreign_key="customers.customer_id"),
                    ColumnInfo(name="order_date", data_type="TIMESTAMP"),
                    ColumnInfo(name="total_amount", data_type="DECIMAL(10,2)"),
                ],
            ),
            TableInfo(
                name="order_items",
                columns=[
                    ColumnInfo(name="item_id", data_type="INT", primary_key=True),
                    ColumnInfo(name="order_id", data_type="INT", foreign_key="orders.order_id"),
                    ColumnInfo(name="product_id", data_type="INT", foreign_key="products.product_id"),
                    ColumnInfo(name="quantity", data_type="INT"),
                ],
            ),
        ],
    },
}
| # --------------------------------------------------------------------------- | |
| # Environment class | |
| # --------------------------------------------------------------------------- | |
class DBMigrationEnv:
    """Step-based environment in which an agent edits a database schema.

    An episode starts with ``reset(task_name)``, which loads a deep copy of the
    task's ``initial_schema`` from ``TASKS``. Each ``step(action)`` applies one
    schema operation, returns a shaped per-step reward, and re-grades the
    current schema against the task's ``expected_schema``. The episode ends
    when the agent sends ``OperationType.DONE``, when the grade reaches 1.0,
    or when a step is attempted after ``max_steps`` has been used up.
    """

    def __init__(self):
        self.task_name: str = "easy"
        self.schema: List[TableInfo] = []
        self.steps_taken: List[Dict] = []
        self.done: bool = False
        self.step_count: int = 0
        self.score: float = 0.0

    def reset(self, task_name: str = "easy") -> ResetResult:
        """Start a new episode for ``task_name`` and return the first observation.

        Raises:
            AssertionError: if ``task_name`` is not a key of ``TASKS``.
        """
        if task_name not in TASKS:
            # Raised explicitly instead of via `assert` so the validation is
            # not stripped under `python -O`; the exception type is unchanged.
            raise AssertionError(
                f"Unknown task: {task_name}. Choose from {list(TASKS.keys())}"
            )
        self.task_name = task_name
        task = TASKS[task_name]
        # Deep copy so edits made during the episode never touch TASKS.
        self.schema = copy.deepcopy(task["initial_schema"])
        self.steps_taken = []
        self.done = False
        self.step_count = 0
        self.score = 0.0
        return ResetResult(
            observation=self._build_observation(),
            task_name=task_name,
            task_description=task["description"],
        )

    def step(self, action: Action) -> StepResult:
        """Apply one action and return the observation, reward and done flag.

        Stepping a finished episode yields reward -0.1; stepping past the
        task's ``max_steps`` ends the episode with reward -0.2. A ``DONE``
        action ends the episode and returns the final grade as the reward.
        Any other action returns the shaped reward from ``_apply_action``
        and reports the current grade in ``info["partial_score"]``.
        """
        if self.done:
            return StepResult(
                observation=self._build_observation(),
                reward=-0.1,
                done=True,
                info={"error": "episode already done"},
                error="Episode already done",
            )

        task = TASKS[self.task_name]
        self.step_count += 1

        # Step budget: the (max_steps + 1)-th call terminates the episode.
        if self.step_count > task["max_steps"]:
            self.done = True
            return StepResult(
                observation=self._build_observation(),
                reward=-0.2,
                done=True,
                info={"error": "max steps exceeded"},
                error="Max steps exceeded",
            )

        # Agent-declared end of episode: the reward is the final grade.
        if action.operation == OperationType.DONE:
            self.done = True
            final_score = self._grade()
            self.score = final_score
            return StepResult(
                observation=self._build_observation(),
                reward=final_score,
                done=True,
                info={"final_score": final_score, "message": "Episode ended by agent"},
            )

        reward, error = self._apply_action(action)
        self.steps_taken.append({
            "operation": action.operation,
            "table": action.table,
            "column": action.column,
            "new_name": action.new_name,
            "data_type": action.data_type,
            "reward": reward,
            "error": error,
        })

        # Re-grade after every action; a perfect grade ends the episode.
        partial_score = self._grade()
        self.score = partial_score
        if partial_score >= 1.0:
            self.done = True

        return StepResult(
            observation=self._build_observation(),
            reward=reward,
            done=self.done,
            info={"partial_score": partial_score, "step": self.step_count},
            error=error,
        )

    def state(self):
        """Return a ``StateResult`` snapshot of the current episode."""
        from server.schemas import StateResult
        return StateResult(
            observation=self._build_observation(),
            task_name=self.task_name,
            step_count=self.step_count,
            done=self.done,
            score=self.score,
        )

    # -----------------------------------------------------------------------
    # Action application
    # -----------------------------------------------------------------------
    def _apply_action(self, action: Action) -> Tuple[float, Optional[str]]:
        """Dispatch ``action`` to its handler; return ``(reward, error)``."""
        op = action.operation
        if op == OperationType.RENAME_TABLE:
            return self._rename_table(action.table, action.new_name)
        if op == OperationType.RENAME_COLUMN:
            return self._rename_column(action.table, action.column, action.new_name)
        if op == OperationType.ADD_COLUMN:
            return self._add_column(action.table, action.column, action.data_type)
        if op == OperationType.DROP_COLUMN:
            return self._drop_column(action.table, action.column)
        if op == OperationType.CHANGE_TYPE:
            return self._change_type(action.table, action.column, action.data_type)
        if op == OperationType.ADD_FOREIGN_KEY:
            return self._add_foreign_key(
                action.table, action.column,
                action.reference_table, action.reference_column,
            )
        if op == OperationType.NORMALIZE_TABLE:
            return self._normalize_table(action)
        return 0.0, f"Unknown operation: {op}"

    def _find_table(self, name: str) -> Optional[TableInfo]:
        """Return the first table in the current schema named ``name``, or None."""
        return next((t for t in self.schema if t.name == name), None)

    @staticmethod
    def _find_column(table, name):
        """Return the first column of ``table`` named ``name``, or None."""
        return next((c for c in table.columns if c.name == name), None)

    def _rename_table(self, old_name: str, new_name: str) -> Tuple[float, Optional[str]]:
        """Rename a table; +0.15 if the rename is a listed requirement, else -0.05."""
        if not new_name:
            return -0.1, "new_name is required for rename_table"
        t = self._find_table(old_name)
        if t is None:
            return -0.1, f"Table '{old_name}' not found"
        if self._find_table(new_name):
            return -0.1, f"Table '{new_name}' already exists"
        expected = self._is_expected_table_rename(old_name, new_name)
        t.name = new_name
        return (0.15 if expected else -0.05), None

    def _rename_column(self, table_name: str, col_name: str, new_name: str) -> Tuple[float, Optional[str]]:
        """Rename a column; +0.1 if the rename is a listed requirement, else -0.05.

        Renaming onto a name that already exists in the table is rejected, so
        a table can never end up with two columns of the same name.
        """
        if not new_name or not col_name:
            return -0.1, "column and new_name are required for rename_column"
        t = self._find_table(table_name)
        if t is None:
            return -0.1, f"Table '{table_name}' not found"
        col = self._find_column(t, col_name)
        if col is None:
            return -0.1, f"Column '{col_name}' not found in '{table_name}'"
        # Same guard that _rename_table and _add_column apply to duplicates.
        if self._find_column(t, new_name) is not None:
            return -0.1, f"Column '{new_name}' already exists in '{table_name}'"
        expected = self._is_expected_column_rename(table_name, col_name, new_name)
        col.name = new_name
        return (0.1 if expected else -0.05), None

    def _add_column(self, table_name: str, col_name: str, data_type: str) -> Tuple[float, Optional[str]]:
        """Append a new column; +0.1 if the column is a listed requirement, else -0.05."""
        if not col_name or not data_type:
            return -0.1, "column and data_type required for add_column"
        t = self._find_table(table_name)
        if t is None:
            return -0.1, f"Table '{table_name}' not found"
        if self._find_column(t, col_name) is not None:
            return -0.1, f"Column '{col_name}' already exists in '{table_name}'"
        expected = self._is_expected_add_column(table_name, col_name, data_type)
        t.columns.append(ColumnInfo(name=col_name, data_type=data_type))
        return (0.1 if expected else -0.05), None

    def _drop_column(self, table_name: str, col_name: str) -> Tuple[float, Optional[str]]:
        """Remove a non-primary-key column (+0.05); dropping a primary key is refused (-0.2)."""
        t = self._find_table(table_name)
        if t is None:
            return -0.1, f"Table '{table_name}' not found"
        col = self._find_column(t, col_name)
        if col is None:
            return -0.1, f"Column '{col_name}' not found in '{table_name}'"
        if col.primary_key:
            return -0.2, f"Cannot drop primary key column '{col_name}'"
        t.columns = [c for c in t.columns if c.name != col_name]
        return 0.05, None

    def _change_type(self, table_name: str, col_name: str, data_type: str) -> Tuple[float, Optional[str]]:
        """Set a column's data type; +0.1 if it matches a listed requirement, else -0.05."""
        if not data_type:
            return -0.1, "data_type required for change_type"
        t = self._find_table(table_name)
        if t is None:
            return -0.1, f"Table '{table_name}' not found"
        col = self._find_column(t, col_name)
        if col is None:
            return -0.1, f"Column '{col_name}' not found in '{table_name}'"
        expected = self._is_expected_type_change(table_name, col_name, data_type)
        col.data_type = data_type
        return (0.1 if expected else -0.05), None

    def _add_foreign_key(self, table_name, col_name, ref_table, ref_col) -> Tuple[float, Optional[str]]:
        """Point ``table_name.col_name`` at ``ref_table.ref_col`` (+0.15 on success).

        Both the referenced table and the referenced column must already
        exist; a reference to either a missing table or a missing column is
        rejected with -0.15 and leaves the schema unchanged.
        """
        if not ref_table or not ref_col:
            return -0.1, "reference_table and reference_column required"
        t = self._find_table(table_name)
        if t is None:
            return -0.1, f"Table '{table_name}' not found"
        col = self._find_column(t, col_name)
        if col is None:
            return -0.1, f"Column '{col_name}' not found in '{table_name}'"
        target = self._find_table(ref_table)
        if target is None:
            return -0.15, f"Referenced table '{ref_table}' does not exist yet"
        # A key pointing at a column that does not exist would be dangling.
        if self._find_column(target, ref_col) is None:
            return -0.15, f"Referenced column '{ref_col}' not found in '{ref_table}'"
        col.foreign_key = f"{ref_table}.{ref_col}"
        return 0.15, None

    def _normalize_table(self, action: Action) -> Tuple[float, Optional[str]]:
        """Create a new table named ``action.new_name`` for the hard task.

        If the name matches a table in the hard task's ``expected_schema``,
        a copy of that table is added with its foreign keys cleared (the agent
        must add them via ADD_FOREIGN_KEY) and the reward is +0.2. Any other
        name creates a placeholder table with a single ``id`` primary key and
        is penalized with -0.1. Outside the hard task the action is rejected
        without changing the schema.
        """
        # The table templates come from the hard task only; allowing this
        # action in other tasks would hand out rewards for unrelated tables.
        if self.task_name != "hard":
            return -0.1, "normalize_table is only available for the 'hard' task"
        new_table_name = action.new_name
        if not new_table_name:
            return -0.1, "new_name required for normalize_table (name of new table to create)"
        if self._find_table(new_table_name):
            return -0.05, f"Table '{new_table_name}' already exists"

        expected = next(
            (t for t in TASKS["hard"]["expected_schema"] if t.name == new_table_name),
            None,
        )
        if expected is None:
            # Creating a table not in requirements — penalize
            self.schema.append(TableInfo(name=new_table_name, columns=[
                ColumnInfo(name="id", data_type="INT", primary_key=True)
            ]))
            return -0.1, f"Table '{new_table_name}' is not in requirements"

        created = copy.deepcopy(expected)
        for col in created.columns:
            col.foreign_key = None
        self.schema.append(created)
        return 0.2, None

    # -----------------------------------------------------------------------
    # Graders (deterministic, 0.0 - 1.0)
    # -----------------------------------------------------------------------
    def _grade(self) -> float:
        """Grade the current schema with the grader for the active task."""
        if self.task_name == "easy":
            return self._grade_easy()
        if self.task_name == "medium":
            return self._grade_medium()
        if self.task_name == "hard":
            return self._grade_hard()
        return 0.0

    def _grade_easy(self) -> float:
        """Schema match score minus 0.02 per step beyond the sixth, clamped to [0, 1]."""
        score = self._schema_match_score(TASKS["easy"]["expected_schema"])
        step_penalty = max(0, (self.step_count - 6) * 0.02)
        return max(0.0, min(1.0, score - step_penalty))

    def _grade_medium(self) -> float:
        """Schema match score minus 0.01 per step beyond the fourteenth, clamped to [0, 1]."""
        score = self._schema_match_score(TASKS["medium"]["expected_schema"])
        step_penalty = max(0, (self.step_count - 14) * 0.01)
        return max(0.0, min(1.0, score - step_penalty))

    def _grade_hard(self) -> float:
        """Weighted grade: 25% tables present, 50% column match, 25% foreign keys.

        0.1 is subtracted while a table named 'everything' still exists. The
        result is clamped to [0, 1].
        """
        expected_schema = TASKS["hard"]["expected_schema"]
        expected_tables = [t.name for t in expected_schema]

        # Component 1: tables exist (25%)
        tables_present = sum(1 for name in expected_tables if self._find_table(name) is not None)
        table_score = tables_present / len(expected_tables) * 0.25

        # Component 2: correct columns/types (50%)
        col_score = self._schema_match_score(expected_schema) * 0.50

        # Component 3: foreign keys (25%)
        fk_checks = [
            ("orders", "customer_id", "customers.customer_id"),
            ("order_items", "order_id", "orders.order_id"),
            ("order_items", "product_id", "products.product_id"),
        ]
        fks_correct = 0
        for tname, cname, fk in fk_checks:
            t = self._find_table(tname)
            if t:
                col = self._find_column(t, cname)
                if col and col.foreign_key == fk:
                    fks_correct += 1
        fk_score = (fks_correct / len(fk_checks)) * 0.25

        total = table_score + col_score + fk_score
        # Penalty if 'everything' table still exists
        if self._find_table("everything"):
            total -= 0.1
        return max(0.0, min(1.0, total))

    def _schema_match_score(self, expected_tables: List[TableInfo]) -> float:
        """Return the mean per-table match between the current and expected schema.

        A missing table scores 0. Within a table, each expected column that
        is present scores 1.0, minus 0.3 for a data-type mismatch
        (case-insensitive) and minus 0.2 for a primary-key mismatch; a missing
        column scores 0. Columns not in the expected table are ignored.
        Returns 0.0 when ``expected_tables`` is empty.
        """
        if not expected_tables:
            return 0.0
        table_scores = []
        for exp_table in expected_tables:
            actual = self._find_table(exp_table.name)
            if actual is None:
                table_scores.append(0.0)
                continue
            exp_cols = {c.name: c for c in exp_table.columns}
            act_cols = {c.name: c for c in actual.columns}
            if not exp_cols:
                table_scores.append(1.0)
                continue
            col_score = 0.0
            for cname, ecol in exp_cols.items():
                acol = act_cols.get(cname)
                if acol is None:
                    continue  # missing column contributes nothing
                match = 1.0
                if acol.data_type.upper() != ecol.data_type.upper():
                    match -= 0.3
                if acol.primary_key != ecol.primary_key:
                    match -= 0.2
                col_score += max(0, match)
            table_scores.append(col_score / len(exp_cols))
        return sum(table_scores) / len(table_scores)

    # -----------------------------------------------------------------------
    # Helpers for reward shaping
    # -----------------------------------------------------------------------
    def _is_expected_table_rename(self, old_name: str, new_name: str) -> bool:
        """True if the active task lists this exact table rename as a requirement."""
        task = TASKS[self.task_name]
        req_str = f"Rename table '{old_name}' to '{new_name}'"
        return any(req_str in r for r in task.get("target_requirements", []))

    def _is_expected_column_rename(self, table: str, old_col: str, new_col: str) -> bool:
        """True if the active task lists this column rename (``table`` is not checked)."""
        task = TASKS[self.task_name]
        req_str = f"Rename column '{old_col}' to '{new_col}'"
        return any(req_str in r for r in task.get("target_requirements", []))

    def _is_expected_add_column(self, table: str, col: str, dtype: str) -> bool:
        """True if the active task lists adding ``col`` (``table``/``dtype`` are not checked)."""
        task = TASKS[self.task_name]
        req_str = f"Add column '{col}'"
        return any(req_str in r for r in task.get("target_requirements", []))

    def _is_expected_type_change(self, table: str, col: str, dtype: str) -> bool:
        """True if a "Change type of '<col>' ... to <dtype>" requirement exists.

        The type is compared case-insensitively against the whole trailing
        type of the requirement, so a partial type such as ``DECIMAL`` does
        not count as the required ``DECIMAL(10,2)``.
        """
        task = TASKS[self.task_name]
        prefix = f"Change type of '{col}'"
        suffix = f" TO {dtype.upper()}"
        return any(
            r.startswith(prefix) and r.upper().endswith(suffix)
            for r in task.get("target_requirements", [])
        )

    def _build_observation(self) -> Observation:
        """Assemble the observation for the current state of the episode."""
        task = TASKS[self.task_name]
        return Observation(
            current_schema=copy.deepcopy(self.schema),
            # Copies, so a consumer mutating the observation cannot alter TASKS.
            target_requirements=list(task["target_requirements"]),
            steps_taken=self.steps_taken[-10:],  # last 10 steps
            violations=self._check_violations(),
            hints=list(task.get("hints", [])),
            step_count=self.step_count,
            max_steps=task["max_steps"],
        )

    def _check_violations(self) -> List[str]:
        """List integrity problems in the current schema.

        Reports duplicate table names and foreign keys of the form
        ``table.column`` whose table or column does not exist.
        """
        violations = []

        names = [t.name for t in self.schema]
        if len(names) != len(set(names)):
            violations.append("Duplicate table names detected")

        for t in self.schema:
            for c in t.columns:
                if not c.foreign_key:
                    continue
                parts = c.foreign_key.split(".")
                if len(parts) != 2:
                    continue
                ref_table = self._find_table(parts[0])
                if ref_table is None:
                    violations.append(
                        f"FK violation: {t.name}.{c.name} references non-existent table '{parts[0]}'"
                    )
                elif self._find_column(ref_table, parts[1]) is None:
                    violations.append(
                        f"FK violation: {t.name}.{c.name} references non-existent column "
                        f"'{parts[1]}' in table '{parts[0]}'"
                    )
        return violations