"""
DB Schema Migration Environment

Three tasks:
- easy:   rename messy columns/tables to clean names
- medium: rename + add columns + fix data types
- hard:   full normalization + foreign keys + type fixes
"""

import copy
from typing import Dict, Any, List, Optional, Tuple

from server.schemas import (
    Action, Observation, TableInfo, ColumnInfo,
    OperationType, StepResult, ResetResult
)

# ---------------------------------------------------------------------------
# Task definitions
# ---------------------------------------------------------------------------
# Each task maps to: a description, the starting schema, the human-readable
# requirements (also used for reward shaping via string matching), hints,
# a step budget, and the schema the graders compare against.

TASKS = {
    "easy": {
        "description": (
            "A legacy user table was created with terrible naming conventions. "
            "Rename all tables and columns to follow snake_case conventions and "
            "meaningful names as described in the requirements."
        ),
        "initial_schema": [
            TableInfo(name="tbl_usr", columns=[
                ColumnInfo(name="usr_id", data_type="INT", primary_key=True),
                ColumnInfo(name="usr_nm", data_type="VARCHAR(50)"),
                ColumnInfo(name="usr_eml", data_type="VARCHAR(100)"),
                ColumnInfo(name="dt_crt", data_type="VARCHAR(20)"),
                ColumnInfo(name="stat_cd", data_type="INT"),
            ])
        ],
        "target_requirements": [
            "Rename table 'tbl_usr' to 'users'",
            "Rename column 'usr_id' to 'id'",
            "Rename column 'usr_nm' to 'username'",
            "Rename column 'usr_eml' to 'email'",
            "Rename column 'dt_crt' to 'created_at'",
            "Rename column 'stat_cd' to 'status'",
        ],
        "hints": [
            "All operations are RENAME_TABLE or RENAME_COLUMN",
            "Start with the table rename, then columns",
        ],
        "max_steps": 10,
        "expected_schema": [
            TableInfo(name="users", columns=[
                ColumnInfo(name="id", data_type="INT", primary_key=True),
                ColumnInfo(name="username", data_type="VARCHAR(50)"),
                ColumnInfo(name="email", data_type="VARCHAR(100)"),
                ColumnInfo(name="created_at", data_type="VARCHAR(20)"),
                ColumnInfo(name="status", data_type="INT"),
            ])
        ],
    },
    "medium": {
        "description": (
            "An orders database has wrong column names, wrong data types, and is "
            "missing important columns. Fix naming, types, and add the missing fields."
        ),
        "initial_schema": [
            TableInfo(name="order_tbl", columns=[
                ColumnInfo(name="oid", data_type="VARCHAR(10)", primary_key=True),
                ColumnInfo(name="cust", data_type="VARCHAR(50)"),
                ColumnInfo(name="amt", data_type="VARCHAR(20)"),
                ColumnInfo(name="ord_dte", data_type="VARCHAR(30)"),
            ]),
            TableInfo(name="prod_tbl", columns=[
                ColumnInfo(name="pid", data_type="VARCHAR(10)", primary_key=True),
                ColumnInfo(name="pname", data_type="TEXT"),
                ColumnInfo(name="prc", data_type="VARCHAR(20)"),
            ]),
        ],
        "target_requirements": [
            "Rename table 'order_tbl' to 'orders'",
            "Rename table 'prod_tbl' to 'products'",
            "Rename column 'oid' to 'order_id' in orders",
            "Rename column 'cust' to 'customer_name' in orders",
            "Rename column 'amt' to 'total_amount' in orders",
            "Rename column 'ord_dte' to 'order_date' in orders",
            "Change type of 'total_amount' in orders to DECIMAL(10,2)",
            "Change type of 'order_date' in orders to TIMESTAMP",
            "Rename column 'pid' to 'product_id' in products",
            "Rename column 'pname' to 'product_name' in products",
            "Rename column 'prc' to 'price' in products",
            "Change type of 'price' in products to DECIMAL(10,2)",
            "Add column 'stock_quantity' (INT) to products",
            "Add column 'status' (VARCHAR(20)) to orders",
        ],
        "hints": [
            "Fix table names first, then column names, then types, then add missing columns",
            "DECIMAL(10,2) is the correct type for money fields",
        ],
        "max_steps": 20,
        "expected_schema": [
            TableInfo(name="orders", columns=[
                ColumnInfo(name="order_id", data_type="VARCHAR(10)", primary_key=True),
                ColumnInfo(name="customer_name", data_type="VARCHAR(50)"),
                ColumnInfo(name="total_amount", data_type="DECIMAL(10,2)"),
                ColumnInfo(name="order_date", data_type="TIMESTAMP"),
                ColumnInfo(name="status", data_type="VARCHAR(20)"),
            ]),
            TableInfo(name="products", columns=[
                ColumnInfo(name="product_id", data_type="VARCHAR(10)", primary_key=True),
                ColumnInfo(name="product_name", data_type="TEXT"),
                ColumnInfo(name="price", data_type="DECIMAL(10,2)"),
                ColumnInfo(name="stock_quantity", data_type="INT"),
            ]),
        ],
    },
    "hard": {
        "description": (
            "A fully denormalized legacy table stores everything in one blob. "
            "You must normalize it into 3NF: split into proper tables, fix types, "
            "add primary keys, and establish foreign key relationships."
        ),
        "initial_schema": [
            TableInfo(name="everything", columns=[
                ColumnInfo(name="row_id", data_type="INT", primary_key=True),
                ColumnInfo(name="cust_name", data_type="TEXT"),
                ColumnInfo(name="cust_email", data_type="TEXT"),
                ColumnInfo(name="cust_phone", data_type="TEXT"),
                ColumnInfo(name="item_name", data_type="TEXT"),
                ColumnInfo(name="item_price", data_type="TEXT"),
                ColumnInfo(name="item_qty", data_type="TEXT"),
                ColumnInfo(name="order_date", data_type="TEXT"),
                ColumnInfo(name="order_total", data_type="TEXT"),
            ])
        ],
        "target_requirements": [
            "Create table 'customers' with columns: customer_id (INT, PK), name (VARCHAR(100)), email (VARCHAR(150)), phone (VARCHAR(20))",
            "Create table 'products' with columns: product_id (INT, PK), product_name (VARCHAR(200)), price (DECIMAL(10,2))",
            "Create table 'orders' with columns: order_id (INT, PK), customer_id (INT, FK->customers), order_date (TIMESTAMP), total_amount (DECIMAL(10,2))",
            "Create table 'order_items' with columns: item_id (INT, PK), order_id (INT, FK->orders), product_id (INT, FK->products), quantity (INT)",
            "Add foreign key: orders.customer_id -> customers.customer_id",
            "Add foreign key: order_items.order_id -> orders.order_id",
            "Add foreign key: order_items.product_id -> products.product_id",
            "Remove the 'everything' table after normalization",
        ],
        "hints": [
            "Normalize step by step: customers → products → orders → order_items",
            "Foreign keys require the referenced table to exist first",
            "Use NORMALIZE_TABLE action to create the new tables from 'everything'",
        ],
        "max_steps": 30,
        "expected_tables": ["customers", "products", "orders", "order_items"],
        "expected_schema": [
            TableInfo(name="customers", columns=[
                ColumnInfo(name="customer_id", data_type="INT", primary_key=True),
                ColumnInfo(name="name", data_type="VARCHAR(100)"),
                ColumnInfo(name="email", data_type="VARCHAR(150)"),
                ColumnInfo(name="phone", data_type="VARCHAR(20)"),
            ]),
            TableInfo(name="products", columns=[
                ColumnInfo(name="product_id", data_type="INT", primary_key=True),
                ColumnInfo(name="product_name", data_type="VARCHAR(200)"),
                ColumnInfo(name="price", data_type="DECIMAL(10,2)"),
            ]),
            TableInfo(name="orders", columns=[
                ColumnInfo(name="order_id", data_type="INT", primary_key=True),
                ColumnInfo(name="customer_id", data_type="INT", foreign_key="customers.customer_id"),
                ColumnInfo(name="order_date", data_type="TIMESTAMP"),
                ColumnInfo(name="total_amount", data_type="DECIMAL(10,2)"),
            ]),
            TableInfo(name="order_items", columns=[
                ColumnInfo(name="item_id", data_type="INT", primary_key=True),
                ColumnInfo(name="order_id", data_type="INT", foreign_key="orders.order_id"),
                ColumnInfo(name="product_id", data_type="INT", foreign_key="products.product_id"),
                ColumnInfo(name="quantity", data_type="INT"),
            ]),
        ],
    },
}


# ---------------------------------------------------------------------------
# Environment class
# ---------------------------------------------------------------------------

class DBMigrationEnv:
    """Step-based environment in which an agent edits an in-memory schema.

    The agent submits ``Action`` objects; each one mutates ``self.schema`` and
    yields a shaped per-step reward, while ``_grade`` compares the current
    schema with the task's expected schema to produce a 0.0-1.0 score.
    """

    def __init__(self):
        self.task_name: str = "easy"
        self.schema: List[TableInfo] = []
        self.steps_taken: List[Dict] = []
        self.done: bool = False
        self.step_count: int = 0
        self.score: float = 0.0

    def reset(self, task_name: str = "easy") -> ResetResult:
        """Start a new episode for ``task_name`` and return the first observation.

        Raises:
            AssertionError: if ``task_name`` is not a key of ``TASKS``.
        """
        if task_name not in TASKS:
            # Raised explicitly instead of via `assert` so the validation is
            # not stripped under `python -O`; the exception type is kept so
            # existing callers keep working.
            raise AssertionError(
                f"Unknown task: {task_name}. Choose from {list(TASKS.keys())}"
            )
        self.task_name = task_name
        task = TASKS[task_name]
        # Deep copy so edits made during the episode never touch TASKS.
        self.schema = copy.deepcopy(task["initial_schema"])
        self.steps_taken = []
        self.done = False
        self.step_count = 0
        self.score = 0.0

        obs = self._build_observation()
        return ResetResult(
            observation=obs,
            task_name=task_name,
            task_description=task["description"],
        )

    def step(self, action: Action) -> StepResult:
        """Apply one action and return the resulting observation and reward.

        Stepping a finished episode returns a -0.1 reward; exceeding the
        task's ``max_steps`` ends the episode with -0.2. A ``DONE`` action
        ends the episode and uses the final grade as the reward.
        """
        if self.done:
            return StepResult(
                observation=self._build_observation(),
                reward=-0.1,
                done=True,
                info={"error": "episode already done"},
                error="Episode already done",
            )

        task = TASKS[self.task_name]
        self.step_count += 1

        # Check step limit
        if self.step_count > task["max_steps"]:
            self.done = True
            return StepResult(
                observation=self._build_observation(),
                reward=-0.2,
                done=True,
                info={"error": "max steps exceeded"},
                error="Max steps exceeded",
            )

        # Handle DONE
        if action.operation == OperationType.DONE:
            self.done = True
            final_score = self._grade()
            self.score = final_score
            reward = final_score
            return StepResult(
                observation=self._build_observation(),
                reward=reward,
                done=True,
                info={"final_score": final_score, "message": "Episode ended by agent"},
            )

        # Apply action
        reward, error = self._apply_action(action)
        self.steps_taken.append({
            "operation": action.operation,
            "table": action.table,
            "column": action.column,
            "new_name": action.new_name,
            "data_type": action.data_type,
            "reward": reward,
            "error": error,
        })

        # Check if task is complete
        partial_score = self._grade()
        self.score = partial_score
        if partial_score >= 1.0:
            self.done = True

        return StepResult(
            observation=self._build_observation(),
            reward=reward,
            done=self.done,
            info={"partial_score": partial_score, "step": self.step_count},
            error=error,
        )

    def state(self):
        """Return a ``StateResult`` snapshot of the current episode."""
        from server.schemas import StateResult
        return StateResult(
            observation=self._build_observation(),
            task_name=self.task_name,
            step_count=self.step_count,
            done=self.done,
            score=self.score,
        )

    # -----------------------------------------------------------------------
    # Action application
    # -----------------------------------------------------------------------
    # Every handler returns ``(reward, error)``; ``error`` is None on success.

    def _apply_action(self, action: Action) -> Tuple[float, Optional[str]]:
        """Dispatch ``action`` to the handler for its operation type."""
        # NOTE(review): no operation handled here removes a table, so the
        # hard task's "Remove the 'everything' table" requirement looks
        # unreachable from this dispatcher - confirm against OperationType.
        op = action.operation

        if op == OperationType.RENAME_TABLE:
            return self._rename_table(action.table, action.new_name)
        elif op == OperationType.RENAME_COLUMN:
            return self._rename_column(action.table, action.column, action.new_name)
        elif op == OperationType.ADD_COLUMN:
            return self._add_column(action.table, action.column, action.data_type)
        elif op == OperationType.DROP_COLUMN:
            return self._drop_column(action.table, action.column)
        elif op == OperationType.CHANGE_TYPE:
            return self._change_type(action.table, action.column, action.data_type)
        elif op == OperationType.ADD_FOREIGN_KEY:
            return self._add_foreign_key(action.table, action.column,
                                         action.reference_table, action.reference_column)
        elif op == OperationType.NORMALIZE_TABLE:
            return self._normalize_table(action)

        return 0.0, f"Unknown operation: {op}"

    def _find_table(self, name: str) -> Optional[TableInfo]:
        """Return the table called ``name`` in the current schema, or None."""
        for t in self.schema:
            if t.name == name:
                return t
        return None

    def _retarget_foreign_keys(self, old_table: str, new_table: str,
                               old_col: Optional[str] = None,
                               new_col: Optional[str] = None) -> None:
        """Rewrite ``"table.column"`` FK strings after a rename.

        With ``old_col`` left as None, every reference to ``old_table`` is
        pointed at ``new_table``. Otherwise only references to
        ``old_table.old_col`` are rewritten to use ``new_col``.
        """
        for t in self.schema:
            for c in t.columns:
                if not c.foreign_key:
                    continue
                ref_table, sep, ref_col = c.foreign_key.partition(".")
                if not sep or ref_table != old_table:
                    continue
                if old_col is None:
                    c.foreign_key = f"{new_table}.{ref_col}"
                elif ref_col == old_col:
                    c.foreign_key = f"{ref_table}.{new_col}"

    def _rename_table(self, old_name: str, new_name: str) -> Tuple[float, Optional[str]]:
        """Rename a table; +0.15 if the rename is a listed requirement, else -0.05."""
        if not new_name:
            return -0.1, "new_name is required for rename_table"
        t = self._find_table(old_name)
        if t is None:
            return -0.1, f"Table '{old_name}' not found"
        if self._find_table(new_name):
            return -0.1, f"Table '{new_name}' already exists"

        # Check if this rename is expected
        expected = self._is_expected_table_rename(old_name, new_name)
        t.name = new_name
        # Keep existing foreign keys pointing at the renamed table so the
        # rename does not leave dangling references behind.
        self._retarget_foreign_keys(old_name, new_name)
        return (0.15 if expected else -0.05), None

    def _rename_column(self, table_name: str, col_name: str, new_name: str) -> Tuple[float, Optional[str]]:
        """Rename a column; +0.1 if the rename is a listed requirement, else -0.05."""
        if not new_name or not col_name:
            return -0.1, "column and new_name are required for rename_column"
        t = self._find_table(table_name)
        if t is None:
            return -0.1, f"Table '{table_name}' not found"
        col = next((c for c in t.columns if c.name == col_name), None)
        if col is None:
            return -0.1, f"Column '{col_name}' not found in '{table_name}'"
        # Mirror _add_column: never allow two columns with the same name.
        if any(c.name == new_name for c in t.columns):
            return -0.1, f"Column '{new_name}' already exists in '{table_name}'"

        expected = self._is_expected_column_rename(table_name, col_name, new_name)
        col.name = new_name
        # Follow the rename in any foreign key that referenced this column.
        self._retarget_foreign_keys(table_name, table_name, col_name, new_name)
        return (0.1 if expected else -0.05), None

    def _add_column(self, table_name: str, col_name: str, data_type: str) -> Tuple[float, Optional[str]]:
        """Append a new column; +0.1 if it is a listed requirement, else -0.05."""
        if not col_name or not data_type:
            return -0.1, "column and data_type required for add_column"
        t = self._find_table(table_name)
        if t is None:
            return -0.1, f"Table '{table_name}' not found"
        if any(c.name == col_name for c in t.columns):
            return -0.1, f"Column '{col_name}' already exists in '{table_name}'"

        expected = self._is_expected_add_column(table_name, col_name, data_type)
        t.columns.append(ColumnInfo(name=col_name, data_type=data_type))
        return (0.1 if expected else -0.05), None

    def _drop_column(self, table_name: str, col_name: str) -> Tuple[float, Optional[str]]:
        """Remove a non-primary-key column; dropping a primary key is refused (-0.2)."""
        t = self._find_table(table_name)
        if t is None:
            return -0.1, f"Table '{table_name}' not found"
        col = next((c for c in t.columns if c.name == col_name), None)
        if col is None:
            return -0.1, f"Column '{col_name}' not found in '{table_name}'"
        if col.primary_key:
            return -0.2, f"Cannot drop primary key column '{col_name}'"
        t.columns = [c for c in t.columns if c.name != col_name]
        return 0.05, None

    def _change_type(self, table_name: str, col_name: str, data_type: str) -> Tuple[float, Optional[str]]:
        """Set a column's data type; +0.1 if it is a listed requirement, else -0.05."""
        if not data_type:
            return -0.1, "data_type required for change_type"
        t = self._find_table(table_name)
        if t is None:
            return -0.1, f"Table '{table_name}' not found"
        col = next((c for c in t.columns if c.name == col_name), None)
        if col is None:
            return -0.1, f"Column '{col_name}' not found in '{table_name}'"

        expected = self._is_expected_type_change(table_name, col_name, data_type)
        col.data_type = data_type
        return (0.1 if expected else -0.05), None

    def _add_foreign_key(self, table_name: str, col_name: str,
                         ref_table: Optional[str],
                         ref_col: Optional[str]) -> Tuple[float, Optional[str]]:
        """Mark ``table_name.col_name`` as referencing ``ref_table.ref_col``.

        Both the referenced table and the referenced column must already
        exist; otherwise the action is rejected with a -0.15 reward.
        """
        if not ref_table or not ref_col:
            return -0.1, "reference_table and reference_column required"
        t = self._find_table(table_name)
        if t is None:
            return -0.1, f"Table '{table_name}' not found"
        col = next((c for c in t.columns if c.name == col_name), None)
        if col is None:
            return -0.1, f"Column '{col_name}' not found in '{table_name}'"
        target = self._find_table(ref_table)
        if target is None:
            return -0.15, f"Referenced table '{ref_table}' does not exist yet"
        # A foreign key to a missing column would be a dangling reference.
        if not any(c.name == ref_col for c in target.columns):
            return -0.15, f"Referenced column '{ref_col}' not found in '{ref_table}'"

        col.foreign_key = f"{ref_table}.{ref_col}"
        return 0.15, None

    def _normalize_table(self, action: Action) -> Tuple[float, Optional[str]]:
        """Create the table named by ``action.new_name`` for the hard task.

        If the name matches a table in the hard task's expected schema, a
        copy of that table is added with its foreign keys cleared (the agent
        must add them via ADD_FOREIGN_KEY) and the reward is 0.2. Any other
        name creates a stub table with a single ``id`` primary key and is
        penalized. Only ``action.new_name`` is read; no column spec is parsed.
        """
        new_table_name = action.new_name
        if not new_table_name:
            return -0.1, "new_name required for normalize_table (name of new table to create)"

        if self._find_table(new_table_name):
            return -0.05, f"Table '{new_table_name}' already exists"

        expected_tables = TASKS["hard"]["expected_schema"]
        expected = next((t for t in expected_tables if t.name == new_table_name), None)

        if expected is None:
            # Creating a table not in requirements — penalize
            self.schema.append(TableInfo(name=new_table_name, columns=[
                ColumnInfo(name="id", data_type="INT", primary_key=True)
            ]))
            return -0.1, f"Table '{new_table_name}' is not in requirements"

        # Create the expected table, minus its foreign keys.
        created = copy.deepcopy(expected)
        for col in created.columns:
            col.foreign_key = None
        self.schema.append(created)
        return 0.2, None

    # -----------------------------------------------------------------------
    # Graders (deterministic, 0.0 - 1.0)
    # -----------------------------------------------------------------------

    def _grade(self) -> float:
        """Return the current score for the active task (0.0 if unknown)."""
        task_name = self.task_name
        if task_name == "easy":
            return self._grade_easy()
        elif task_name == "medium":
            return self._grade_medium()
        elif task_name == "hard":
            return self._grade_hard()
        return 0.0

    def _grade_easy(self) -> float:
        """Check match against expected schema, minus 0.02 per step beyond 6."""
        expected = TASKS["easy"]["expected_schema"]
        score = self._schema_match_score(expected)
        # Penalty for extra steps (efficiency)
        step_penalty = max(0, (self.step_count - 6) * 0.02)
        return max(0.0, min(1.0, score - step_penalty))

    def _grade_medium(self) -> float:
        """Check match against expected schema, minus 0.01 per step beyond 14."""
        expected = TASKS["medium"]["expected_schema"]
        score = self._schema_match_score(expected)
        step_penalty = max(0, (self.step_count - 14) * 0.01)
        return max(0.0, min(1.0, score - step_penalty))

    def _grade_hard(self) -> float:
        """Weighted score: tables present 25%, columns/types 50%, foreign keys 25%.

        0.1 is subtracted while a table named 'everything' still exists.
        """
        # Single source of truth: reuse the list declared on the task.
        expected_tables = TASKS["hard"]["expected_tables"]
        expected_schema = TASKS["hard"]["expected_schema"]

        # Component 1: tables exist (25%)
        tables_present = sum(1 for t in expected_tables if self._find_table(t) is not None)
        table_score = tables_present / len(expected_tables) * 0.25

        # Component 2: correct columns/types (50%)
        col_score = self._schema_match_score(expected_schema) * 0.50

        # Component 3: foreign keys (25%)
        fk_checks = [
            ("orders", "customer_id", "customers.customer_id"),
            ("order_items", "order_id", "orders.order_id"),
            ("order_items", "product_id", "products.product_id"),
        ]
        fks_correct = 0
        for tname, cname, fk in fk_checks:
            t = self._find_table(tname)
            if t:
                col = next((c for c in t.columns if c.name == cname), None)
                if col and col.foreign_key == fk:
                    fks_correct += 1
        fk_score = (fks_correct / len(fk_checks)) * 0.25

        total = table_score + col_score + fk_score
        # Penalty if 'everything' table still exists
        if self._find_table("everything"):
            total -= 0.1

        return max(0.0, min(1.0, total))

    def _schema_match_score(self, expected_tables: List[TableInfo]) -> float:
        """Partial credit: score each expected table and average.

        A missing table or column scores 0. A present column starts at 1.0
        and loses 0.3 for a data type mismatch (case-insensitive) and 0.2
        for a primary-key flag mismatch.
        """
        if not expected_tables:
            return 0.0

        table_scores = []
        for exp_table in expected_tables:
            actual = self._find_table(exp_table.name)
            if actual is None:
                table_scores.append(0.0)
                continue

            # Score columns
            exp_cols = {c.name: c for c in exp_table.columns}
            act_cols = {c.name: c for c in actual.columns}

            if not exp_cols:
                table_scores.append(1.0)
                continue

            col_score = 0.0
            for cname, ecol in exp_cols.items():
                if cname in act_cols:
                    acol = act_cols[cname]
                    match = 1.0
                    if acol.data_type.upper() != ecol.data_type.upper():
                        match -= 0.3
                    if acol.primary_key != ecol.primary_key:
                        match -= 0.2
                    col_score += max(0, match)
                # Missing column = 0 for that column

            table_scores.append(col_score / len(exp_cols))

        return sum(table_scores) / len(table_scores)

    # -----------------------------------------------------------------------
    # Helpers for reward shaping
    # -----------------------------------------------------------------------
    # These match actions against the task's requirement strings.

    def _is_expected_table_rename(self, old_name: str, new_name: str) -> bool:
        """True if a requirement asks for exactly this table rename."""
        task = TASKS[self.task_name]
        req_str = f"Rename table '{old_name}' to '{new_name}'"
        return any(req_str in r for r in task.get("target_requirements", []))

    def _is_expected_column_rename(self, table: str, old_col: str, new_col: str) -> bool:
        """True if a requirement asks for this column rename (table is not checked)."""
        task = TASKS[self.task_name]
        req_str = f"Rename column '{old_col}' to '{new_col}'"
        return any(req_str in r for r in task.get("target_requirements", []))

    def _is_expected_add_column(self, table: str, col: str, dtype: str) -> bool:
        """True if a requirement asks to add ``col`` with this data type.

        The type is compared case-insensitively; the table is not checked.
        """
        task = TASKS[self.task_name]
        req_str = f"Add column '{col}' ("
        type_tag = f"({dtype})".upper()
        return any(
            req_str in r and type_tag in r.upper()
            for r in task.get("target_requirements", [])
        )

    def _is_expected_type_change(self, table: str, col: str, dtype: str) -> bool:
        """True if a "Change type of" requirement names ``col`` and this exact type.

        Only "Change type of '<col>' ... to <TYPE>" requirements count, so an
        unrelated requirement that merely mentions the column and contains
        the type as a substring no longer earns the reward. The type is
        compared case-insensitively; the table is not checked, so the reward
        still applies before the table itself has been renamed.
        """
        task = TASKS[self.task_name]
        req_str = f"Change type of '{col}' "
        type_suffix = f" to {dtype}".upper()
        return any(
            req_str in r and r.upper().endswith(type_suffix)
            for r in task.get("target_requirements", [])
        )

    def _build_observation(self) -> Observation:
        """Assemble the observation handed back to the agent."""
        task = TASKS[self.task_name]
        violations = self._check_violations()
        return Observation(
            current_schema=copy.deepcopy(self.schema),
            target_requirements=task["target_requirements"],
            steps_taken=self.steps_taken[-10:],  # last 10 steps
            violations=violations,
            hints=task.get("hints", []),
            step_count=self.step_count,
            max_steps=task["max_steps"],
        )

    def _check_violations(self) -> List[str]:
        """List integrity problems: duplicate table names and dangling FKs."""
        violations = []
        # Check for duplicate table names
        names = [t.name for t in self.schema]
        if len(names) != len(set(names)):
            violations.append("Duplicate table names detected")

        # Check FK references exist
        for t in self.schema:
            for c in t.columns:
                if not c.foreign_key:
                    continue
                parts = c.foreign_key.split(".")
                if len(parts) != 2:
                    continue
                ref_table = self._find_table(parts[0])
                if ref_table is None:
                    violations.append(
                        f"FK violation: {t.name}.{c.name} references non-existent table '{parts[0]}'"
                    )
                elif not any(rc.name == parts[1] for rc in ref_table.columns):
                    # The table exists but the referenced column was
                    # dropped or never existed.
                    violations.append(
                        f"FK violation: {t.name}.{c.name} references non-existent column '{parts[0]}.{parts[1]}'"
                    )
        return violations