# hissterical's picture
# Upload 9 files
# a5c89a3 verified
"""
DB Schema Migration Environment
Three tasks:
- easy: rename messy columns/tables to clean names
- medium: rename + add columns + fix data types
- hard: full normalization + foreign keys + type fixes
"""
import copy
from typing import Dict, Any, List, Optional, Tuple
from server.schemas import (
Action, Observation, TableInfo, ColumnInfo,
OperationType, StepResult, ResetResult
)
# ---------------------------------------------------------------------------
# Task definitions
# ---------------------------------------------------------------------------
# Task registry. Each task defines:
#   - "description":          human-readable problem statement shown on reset
#   - "initial_schema":       tables the episode starts from (deep-copied per episode)
#   - "target_requirements":  requirement strings; the reward-shaping helpers
#                             match actions against these by substring, so the
#                             exact phrasing here is load-bearing
#   - "hints":                free-form guidance surfaced in observations
#   - "max_steps":            step budget before the episode is force-ended
#   - "expected_schema":      ground truth the graders compare against
# The hard task additionally lists "expected_tables" (names the agent must
# create via NORMALIZE_TABLE).
TASKS = {
    # easy: pure renaming — table and column names only, no type changes.
    "easy": {
        "description": (
            "A legacy user table was created with terrible naming conventions. "
            "Rename all tables and columns to follow snake_case conventions and "
            "meaningful names as described in the requirements."
        ),
        "initial_schema": [
            TableInfo(name="tbl_usr", columns=[
                ColumnInfo(name="usr_id", data_type="INT", primary_key=True),
                ColumnInfo(name="usr_nm", data_type="VARCHAR(50)"),
                ColumnInfo(name="usr_eml", data_type="VARCHAR(100)"),
                ColumnInfo(name="dt_crt", data_type="VARCHAR(20)"),
                ColumnInfo(name="stat_cd", data_type="INT"),
            ])
        ],
        "target_requirements": [
            "Rename table 'tbl_usr' to 'users'",
            "Rename column 'usr_id' to 'id'",
            "Rename column 'usr_nm' to 'username'",
            "Rename column 'usr_eml' to 'email'",
            "Rename column 'dt_crt' to 'created_at'",
            "Rename column 'stat_cd' to 'status'",
        ],
        "hints": [
            "All operations are RENAME_TABLE or RENAME_COLUMN",
            "Start with the table rename, then columns",
        ],
        "max_steps": 10,
        # Note: data types intentionally unchanged from the initial schema.
        "expected_schema": [
            TableInfo(name="users", columns=[
                ColumnInfo(name="id", data_type="INT", primary_key=True),
                ColumnInfo(name="username", data_type="VARCHAR(50)"),
                ColumnInfo(name="email", data_type="VARCHAR(100)"),
                ColumnInfo(name="created_at", data_type="VARCHAR(20)"),
                ColumnInfo(name="status", data_type="INT"),
            ])
        ],
    },
    # medium: renames + type corrections + two missing columns, across two tables.
    "medium": {
        "description": (
            "An orders database has wrong column names, wrong data types, and is "
            "missing important columns. Fix naming, types, and add the missing fields."
        ),
        "initial_schema": [
            TableInfo(name="order_tbl", columns=[
                ColumnInfo(name="oid", data_type="VARCHAR(10)", primary_key=True),
                ColumnInfo(name="cust", data_type="VARCHAR(50)"),
                ColumnInfo(name="amt", data_type="VARCHAR(20)"),
                ColumnInfo(name="ord_dte", data_type="VARCHAR(30)"),
            ]),
            TableInfo(name="prod_tbl", columns=[
                ColumnInfo(name="pid", data_type="VARCHAR(10)", primary_key=True),
                ColumnInfo(name="pname", data_type="TEXT"),
                ColumnInfo(name="prc", data_type="VARCHAR(20)"),
            ]),
        ],
        "target_requirements": [
            "Rename table 'order_tbl' to 'orders'",
            "Rename table 'prod_tbl' to 'products'",
            "Rename column 'oid' to 'order_id' in orders",
            "Rename column 'cust' to 'customer_name' in orders",
            "Rename column 'amt' to 'total_amount' in orders",
            "Rename column 'ord_dte' to 'order_date' in orders",
            "Change type of 'total_amount' in orders to DECIMAL(10,2)",
            "Change type of 'order_date' in orders to TIMESTAMP",
            "Rename column 'pid' to 'product_id' in products",
            "Rename column 'pname' to 'product_name' in products",
            "Rename column 'prc' to 'price' in products",
            "Change type of 'price' in products to DECIMAL(10,2)",
            "Add column 'stock_quantity' (INT) to products",
            "Add column 'status' (VARCHAR(20)) to orders",
        ],
        "hints": [
            "Fix table names first, then column names, then types, then add missing columns",
            "DECIMAL(10,2) is the correct type for money fields",
        ],
        "max_steps": 20,
        "expected_schema": [
            TableInfo(name="orders", columns=[
                ColumnInfo(name="order_id", data_type="VARCHAR(10)", primary_key=True),
                ColumnInfo(name="customer_name", data_type="VARCHAR(50)"),
                ColumnInfo(name="total_amount", data_type="DECIMAL(10,2)"),
                ColumnInfo(name="order_date", data_type="TIMESTAMP"),
                ColumnInfo(name="status", data_type="VARCHAR(20)"),
            ]),
            TableInfo(name="products", columns=[
                ColumnInfo(name="product_id", data_type="VARCHAR(10)", primary_key=True),
                ColumnInfo(name="product_name", data_type="TEXT"),
                ColumnInfo(name="price", data_type="DECIMAL(10,2)"),
                ColumnInfo(name="stock_quantity", data_type="INT"),
            ]),
        ],
    },
    # hard: 3NF normalization of one denormalized table into four tables
    # with primary and foreign keys.
    "hard": {
        "description": (
            "A fully denormalized legacy table stores everything in one blob. "
            "You must normalize it into 3NF: split into proper tables, fix types, "
            "add primary keys, and establish foreign key relationships."
        ),
        "initial_schema": [
            TableInfo(name="everything", columns=[
                ColumnInfo(name="row_id", data_type="INT", primary_key=True),
                ColumnInfo(name="cust_name", data_type="TEXT"),
                ColumnInfo(name="cust_email", data_type="TEXT"),
                ColumnInfo(name="cust_phone", data_type="TEXT"),
                ColumnInfo(name="item_name", data_type="TEXT"),
                ColumnInfo(name="item_price", data_type="TEXT"),
                ColumnInfo(name="item_qty", data_type="TEXT"),
                ColumnInfo(name="order_date", data_type="TEXT"),
                ColumnInfo(name="order_total", data_type="TEXT"),
            ])
        ],
        "target_requirements": [
            "Create table 'customers' with columns: customer_id (INT, PK), name (VARCHAR(100)), email (VARCHAR(150)), phone (VARCHAR(20))",
            "Create table 'products' with columns: product_id (INT, PK), product_name (VARCHAR(200)), price (DECIMAL(10,2))",
            "Create table 'orders' with columns: order_id (INT, PK), customer_id (INT, FK->customers), order_date (TIMESTAMP), total_amount (DECIMAL(10,2))",
            "Create table 'order_items' with columns: item_id (INT, PK), order_id (INT, FK->orders), product_id (INT, FK->products), quantity (INT)",
            "Add foreign key: orders.customer_id -> customers.customer_id",
            "Add foreign key: order_items.order_id -> orders.order_id",
            "Add foreign key: order_items.product_id -> products.product_id",
            "Remove the 'everything' table after normalization",
        ],
        "hints": [
            "Normalize step by step: customers → products → orders → order_items",
            "Foreign keys require the referenced table to exist first",
            "Use NORMALIZE_TABLE action to create the new tables from 'everything'",
        ],
        "max_steps": 30,
        # Presence of "expected_tables" marks this as a normalization task
        # (NORMALIZE_TABLE may materialize these tables).
        "expected_tables": ["customers", "products", "orders", "order_items"],
        "expected_schema": [
            TableInfo(name="customers", columns=[
                ColumnInfo(name="customer_id", data_type="INT", primary_key=True),
                ColumnInfo(name="name", data_type="VARCHAR(100)"),
                ColumnInfo(name="email", data_type="VARCHAR(150)"),
                ColumnInfo(name="phone", data_type="VARCHAR(20)"),
            ]),
            TableInfo(name="products", columns=[
                ColumnInfo(name="product_id", data_type="INT", primary_key=True),
                ColumnInfo(name="product_name", data_type="VARCHAR(200)"),
                ColumnInfo(name="price", data_type="DECIMAL(10,2)"),
            ]),
            TableInfo(name="orders", columns=[
                ColumnInfo(name="order_id", data_type="INT", primary_key=True),
                ColumnInfo(name="customer_id", data_type="INT", foreign_key="customers.customer_id"),
                ColumnInfo(name="order_date", data_type="TIMESTAMP"),
                ColumnInfo(name="total_amount", data_type="DECIMAL(10,2)"),
            ]),
            TableInfo(name="order_items", columns=[
                ColumnInfo(name="item_id", data_type="INT", primary_key=True),
                ColumnInfo(name="order_id", data_type="INT", foreign_key="orders.order_id"),
                ColumnInfo(name="product_id", data_type="INT", foreign_key="products.product_id"),
                ColumnInfo(name="quantity", data_type="INT"),
            ]),
        ],
    },
}
# ---------------------------------------------------------------------------
# Environment class
# ---------------------------------------------------------------------------
class DBMigrationEnv:
    """
    Deterministic DB-schema-migration environment.

    The agent mutates ``self.schema`` (a list of TableInfo) one Action at a
    time until it matches the active task's expected schema.  Per-step
    rewards are small shaping signals (positive for operations that match a
    requirement, negative for mistakes); the final grade is a deterministic
    score in [0.0, 1.0] computed by the ``_grade_*`` methods.
    """

    def __init__(self) -> None:
        # Active task key into TASKS; reset() validates and replaces it.
        self.task_name: str = "easy"
        # Current working schema, deep-copied from the task template.
        self.schema: List[TableInfo] = []
        # History of applied actions (as dicts), surfaced in observations.
        self.steps_taken: List[Dict] = []
        self.done: bool = False
        self.step_count: int = 0
        # Most recently computed grade (partial or final).
        self.score: float = 0.0

    def reset(self, task_name: str = "easy") -> ResetResult:
        """Start a fresh episode on ``task_name`` and return the initial observation.

        Raises AssertionError if ``task_name`` is not a known task.
        """
        assert task_name in TASKS, f"Unknown task: {task_name}. Choose from {list(TASKS.keys())}"
        self.task_name = task_name
        task = TASKS[task_name]
        # Deep copy so episode mutations never leak into the task template.
        self.schema = copy.deepcopy(task["initial_schema"])
        self.steps_taken = []
        self.done = False
        self.step_count = 0
        self.score = 0.0
        obs = self._build_observation()
        return ResetResult(
            observation=obs,
            task_name=task_name,
            task_description=task["description"],
        )

    def step(self, action: Action) -> StepResult:
        """Apply one action and return the resulting StepResult."""
        if self.done:
            # Stepping a finished episode is itself penalized.
            return StepResult(
                observation=self._build_observation(),
                reward=-0.1,
                done=True,
                info={"error": "episode already done"},
                error="Episode already done",
            )
        task = TASKS[self.task_name]
        self.step_count += 1
        # Step limit: the counter is incremented first, so exactly
        # ``max_steps`` actions are allowed before this fires.
        if self.step_count > task["max_steps"]:
            self.done = True
            return StepResult(
                observation=self._build_observation(),
                reward=-0.2,
                done=True,
                info={"error": "max steps exceeded"},
                error="Max steps exceeded",
            )
        # Agent declares completion: the final grade becomes the reward.
        if action.operation == OperationType.DONE:
            self.done = True
            final_score = self._grade()
            self.score = final_score
            reward = final_score
            return StepResult(
                observation=self._build_observation(),
                reward=reward,
                done=True,
                info={"final_score": final_score, "message": "Episode ended by agent"},
            )
        # Apply the mutation, then record it for the observation history.
        reward, error = self._apply_action(action)
        self.steps_taken.append({
            "operation": action.operation,
            "table": action.table,
            "column": action.column,
            "new_name": action.new_name,
            "data_type": action.data_type,
            "reward": reward,
            "error": error,
        })
        # Auto-terminate once the schema grades as perfect.
        partial_score = self._grade()
        self.score = partial_score
        if partial_score >= 1.0:
            self.done = True
        return StepResult(
            observation=self._build_observation(),
            reward=reward,
            done=self.done,
            info={"partial_score": partial_score, "step": self.step_count},
            error=error,
        )

    def state(self):
        """Return a read-only snapshot of the current episode."""
        # Local import mirrors the original; presumably avoids an import
        # cycle with server.schemas at module load — confirm before moving.
        from server.schemas import StateResult
        return StateResult(
            observation=self._build_observation(),
            task_name=self.task_name,
            step_count=self.step_count,
            done=self.done,
            score=self.score,
        )

    # -----------------------------------------------------------------------
    # Action application
    # -----------------------------------------------------------------------
    def _apply_action(self, action: Action) -> Tuple[float, Optional[str]]:
        """Dispatch an action to its handler; returns (reward, error_or_None)."""
        op = action.operation
        if op == OperationType.RENAME_TABLE:
            return self._rename_table(action.table, action.new_name)
        elif op == OperationType.RENAME_COLUMN:
            return self._rename_column(action.table, action.column, action.new_name)
        elif op == OperationType.ADD_COLUMN:
            return self._add_column(action.table, action.column, action.data_type)
        elif op == OperationType.DROP_COLUMN:
            return self._drop_column(action.table, action.column)
        elif op == OperationType.CHANGE_TYPE:
            return self._change_type(action.table, action.column, action.data_type)
        elif op == OperationType.ADD_FOREIGN_KEY:
            return self._add_foreign_key(action.table, action.column,
                                         action.reference_table, action.reference_column)
        elif op == OperationType.NORMALIZE_TABLE:
            return self._normalize_table(action)
        return 0.0, f"Unknown operation: {op}"

    def _find_table(self, name: str) -> Optional[TableInfo]:
        """Return the table with ``name``, or None if absent."""
        for t in self.schema:
            if t.name == name:
                return t
        return None

    def _rename_table(self, old_name: str, new_name: str) -> Tuple[float, Optional[str]]:
        """Rename a table; rewarded only if the rename matches a requirement."""
        if not new_name:
            return -0.1, "new_name is required for rename_table"
        t = self._find_table(old_name)
        if t is None:
            return -0.1, f"Table '{old_name}' not found"
        if self._find_table(new_name):
            return -0.1, f"Table '{new_name}' already exists"
        expected = self._is_expected_table_rename(old_name, new_name)
        # NOTE(review): foreign_key strings in other tables that reference
        # `old_name` are NOT rewritten here; _check_violations will surface
        # them as dangling references.
        t.name = new_name
        return (0.15 if expected else -0.05), None

    def _rename_column(self, table_name: str, col_name: str, new_name: str) -> Tuple[float, Optional[str]]:
        """Rename a column within a table."""
        if not new_name or not col_name:
            return -0.1, "column and new_name are required for rename_column"
        t = self._find_table(table_name)
        if t is None:
            return -0.1, f"Table '{table_name}' not found"
        col = next((c for c in t.columns if c.name == col_name), None)
        if col is None:
            return -0.1, f"Column '{col_name}' not found in '{table_name}'"
        expected = self._is_expected_column_rename(table_name, col_name, new_name)
        col.name = new_name
        return (0.1 if expected else -0.05), None

    def _add_column(self, table_name: str, col_name: str, data_type: str) -> Tuple[float, Optional[str]]:
        """Append a new (non-PK, non-FK) column to a table."""
        if not col_name or not data_type:
            return -0.1, "column and data_type required for add_column"
        t = self._find_table(table_name)
        if t is None:
            return -0.1, f"Table '{table_name}' not found"
        if any(c.name == col_name for c in t.columns):
            return -0.1, f"Column '{col_name}' already exists in '{table_name}'"
        expected = self._is_expected_add_column(table_name, col_name, data_type)
        t.columns.append(ColumnInfo(name=col_name, data_type=data_type))
        return (0.1 if expected else -0.05), None

    def _drop_column(self, table_name: str, col_name: str) -> Tuple[float, Optional[str]]:
        """Remove a column; primary-key columns cannot be dropped."""
        t = self._find_table(table_name)
        if t is None:
            return -0.1, f"Table '{table_name}' not found"
        col = next((c for c in t.columns if c.name == col_name), None)
        if col is None:
            return -0.1, f"Column '{col_name}' not found in '{table_name}'"
        if col.primary_key:
            return -0.2, f"Cannot drop primary key column '{col_name}'"
        t.columns = [c for c in t.columns if c.name != col_name]
        return 0.05, None

    def _change_type(self, table_name: str, col_name: str, data_type: str) -> Tuple[float, Optional[str]]:
        """Change a column's data type."""
        if not data_type:
            return -0.1, "data_type required for change_type"
        t = self._find_table(table_name)
        if t is None:
            return -0.1, f"Table '{table_name}' not found"
        col = next((c for c in t.columns if c.name == col_name), None)
        if col is None:
            return -0.1, f"Column '{col_name}' not found in '{table_name}'"
        expected = self._is_expected_type_change(table_name, col_name, data_type)
        col.data_type = data_type
        return (0.1 if expected else -0.05), None

    def _add_foreign_key(self, table_name, col_name, ref_table, ref_col) -> Tuple[float, Optional[str]]:
        """Attach a 'ref_table.ref_col' foreign-key string to a column.

        The referenced table must already exist; the referenced column is
        not validated (the grader checks the full FK string instead).
        """
        if not ref_table or not ref_col:
            return -0.1, "reference_table and reference_column required"
        t = self._find_table(table_name)
        if t is None:
            return -0.1, f"Table '{table_name}' not found"
        col = next((c for c in t.columns if c.name == col_name), None)
        if col is None:
            return -0.1, f"Column '{col_name}' not found in '{table_name}'"
        if self._find_table(ref_table) is None:
            return -0.15, f"Referenced table '{ref_table}' does not exist yet"
        col.foreign_key = f"{ref_table}.{ref_col}"
        return 0.15, None

    def _normalize_table(self, action: Action) -> Tuple[float, Optional[str]]:
        """
        For normalization tasks (hard): agent uses NORMALIZE_TABLE to declare
        a new table it is extracting from 'everything'.  The new table name
        goes in action.new_name.  If it matches one of the active task's
        expected tables, that table is created with its columns but WITHOUT
        foreign keys (the agent must add those via ADD_FOREIGN_KEY);
        otherwise a stub table is created and the action is penalized.
        """
        new_table_name = action.new_name
        if not new_table_name:
            return -0.1, "new_name required for normalize_table (name of new table to create)"
        if self._find_table(new_table_name):
            return -0.05, f"Table '{new_table_name}' already exists"
        # BUG FIX: this previously always consulted TASKS["hard"], so an agent
        # running the easy/medium task could conjure fully-typed hard-task
        # tables (e.g. 'products' in medium) and collect the +0.2 reward.
        # Only tasks that declare "expected_tables" (normalization tasks)
        # get the expected-table fast path; for all others every
        # NORMALIZE_TABLE target is treated as not-in-requirements.
        task = TASKS[self.task_name]
        expected_schema = task["expected_schema"] if "expected_tables" in task else []
        expected = next((t for t in expected_schema if t.name == new_table_name), None)
        if expected is None:
            # Creating a table not in requirements — penalize, but still
            # create a stub so the agent sees the consequence in the schema.
            self.schema.append(TableInfo(name=new_table_name, columns=[
                ColumnInfo(name="id", data_type="INT", primary_key=True)
            ]))
            return -0.1, f"Table '{new_table_name}' is not in requirements"
        # Create the expected table, stripping foreign keys so the agent
        # must establish them explicitly via ADD_FOREIGN_KEY.
        new_table = copy.deepcopy(expected)
        for col in new_table.columns:
            col.foreign_key = None
        self.schema.append(new_table)
        return 0.2, None

    # -----------------------------------------------------------------------
    # Graders (deterministic, 0.0 - 1.0)
    # -----------------------------------------------------------------------
    def _grade(self) -> float:
        """Dispatch to the active task's grader."""
        task_name = self.task_name
        if task_name == "easy":
            return self._grade_easy()
        elif task_name == "medium":
            return self._grade_medium()
        elif task_name == "hard":
            return self._grade_hard()
        return 0.0

    def _grade_easy(self) -> float:
        """Schema match score minus an efficiency penalty past 6 steps."""
        expected = TASKS["easy"]["expected_schema"]
        score = self._schema_match_score(expected)
        step_penalty = max(0, (self.step_count - 6) * 0.02)
        return max(0.0, min(1.0, score - step_penalty))

    def _grade_medium(self) -> float:
        """Schema match score minus an efficiency penalty past 14 steps."""
        expected = TASKS["medium"]["expected_schema"]
        score = self._schema_match_score(expected)
        step_penalty = max(0, (self.step_count - 14) * 0.01)
        return max(0.0, min(1.0, score - step_penalty))

    def _grade_hard(self) -> float:
        """Weighted grade: tables 25%, columns/types 50%, foreign keys 25%.

        A flat -0.1 penalty applies while the legacy 'everything' table
        still exists.
        """
        expected_tables = ["customers", "products", "orders", "order_items"]
        expected_schema = TASKS["hard"]["expected_schema"]
        # Component 1: required tables exist (25%).
        tables_present = sum(1 for t in expected_tables if self._find_table(t) is not None)
        table_score = tables_present / len(expected_tables) * 0.25
        # Component 2: correct columns/types (50%).
        col_score = self._schema_match_score(expected_schema) * 0.50
        # Component 3: foreign keys exactly as required (25%).
        fk_checks = [
            ("orders", "customer_id", "customers.customer_id"),
            ("order_items", "order_id", "orders.order_id"),
            ("order_items", "product_id", "products.product_id"),
        ]
        fks_correct = 0
        for tname, cname, fk in fk_checks:
            t = self._find_table(tname)
            if t:
                col = next((c for c in t.columns if c.name == cname), None)
                if col and col.foreign_key == fk:
                    fks_correct += 1
        fk_score = (fks_correct / len(fk_checks)) * 0.25
        if self._find_table("everything"):
            total = table_score + col_score + fk_score - 0.1
        else:
            total = table_score + col_score + fk_score
        return max(0.0, min(1.0, total))

    def _schema_match_score(self, expected_tables: List[TableInfo]) -> float:
        """Partial-credit score: per-table column match, averaged over tables.

        Per expected column: 1.0 if present, -0.3 for a wrong data type
        (case-insensitive), -0.2 for a wrong primary-key flag; missing
        columns score 0.  Extra columns in the actual table are ignored.
        """
        if not expected_tables:
            return 0.0
        table_scores = []
        for exp_table in expected_tables:
            actual = self._find_table(exp_table.name)
            if actual is None:
                table_scores.append(0.0)
                continue
            exp_cols = {c.name: c for c in exp_table.columns}
            act_cols = {c.name: c for c in actual.columns}
            if not exp_cols:
                table_scores.append(1.0)
                continue
            col_score = 0.0
            for cname, ecol in exp_cols.items():
                if cname in act_cols:
                    acol = act_cols[cname]
                    match = 1.0
                    if acol.data_type.upper() != ecol.data_type.upper():
                        match -= 0.3
                    if acol.primary_key != ecol.primary_key:
                        match -= 0.2
                    col_score += max(0, match)
                # Missing column contributes 0 for that column.
            table_scores.append(col_score / len(exp_cols))
        return sum(table_scores) / len(table_scores)

    # -----------------------------------------------------------------------
    # Helpers for reward shaping
    # -----------------------------------------------------------------------
    def _is_expected_table_rename(self, old_name: str, new_name: str) -> bool:
        """True if the requirements list this exact table rename."""
        task = TASKS[self.task_name]
        req_str = f"Rename table '{old_name}' to '{new_name}'"
        return any(req_str in r for r in task.get("target_requirements", []))

    def _is_expected_column_rename(self, table: str, old_col: str, new_col: str) -> bool:
        """True if the requirements list this column rename.

        NOTE(review): matches on requirement text only — `table` is unused,
        so a matching rename applied to the wrong table still counts as
        expected.  Harmless while column names are unique across tables.
        """
        task = TASKS[self.task_name]
        req_str = f"Rename column '{old_col}' to '{new_col}'"
        return any(req_str in r for r in task.get("target_requirements", []))

    def _is_expected_add_column(self, table: str, col: str, dtype: str) -> bool:
        """True if some requirement adds a column of this name.

        NOTE(review): neither `table` nor `dtype` is checked, only the name.
        """
        task = TASKS[self.task_name]
        req_str = f"Add column '{col}'"
        return any(req_str in r for r in task.get("target_requirements", []))

    def _is_expected_type_change(self, table: str, col: str, dtype: str) -> bool:
        """True if some requirement mentions this column together with this type
        (case-insensitive substring match; `table` is not checked)."""
        task = TASKS[self.task_name]
        return any(f"'{col}'" in r and dtype.upper() in r.upper()
                   for r in task.get("target_requirements", []))

    def _build_observation(self) -> Observation:
        """Assemble the agent-facing view of the current state."""
        task = TASKS[self.task_name]
        violations = self._check_violations()
        return Observation(
            current_schema=copy.deepcopy(self.schema),  # copy: observers must not mutate state
            target_requirements=task["target_requirements"],
            steps_taken=self.steps_taken[-10:],  # last 10 steps
            violations=violations,
            hints=task.get("hints", []),
            step_count=self.step_count,
            max_steps=task["max_steps"],
        )

    def _check_violations(self) -> List[str]:
        """Return human-readable integrity problems in the current schema."""
        violations = []
        # Duplicate table names.
        names = [t.name for t in self.schema]
        if len(names) != len(set(names)):
            violations.append("Duplicate table names detected")
        # Foreign keys that point at tables which no longer exist.
        for t in self.schema:
            for c in t.columns:
                if c.foreign_key:
                    parts = c.foreign_key.split(".")
                    if len(parts) == 2:
                        ref_table = self._find_table(parts[0])
                        if ref_table is None:
                            violations.append(
                                f"FK violation: {t.name}.{c.name} references non-existent table '{parts[0]}'"
                            )
        return violations