Spaces:
Sleeping
Sleeping
| """ | |
| env/db_simulator.py — SQL Database Engineer Agent | |
| Simulates a production database responding to optimization actions. | |
| Core mechanism: index coverage reduces query execution time by up to 85-90%. | |
| """ | |
| import math | |
| import random | |
| from typing import Optional | |
class DatabaseSimulator:
    """
    Simulates a production database responding to optimization actions.

    State is built from ``scenario``: its ``tables`` (each a dict with at
    least ``name``), its ``slow_queries`` (each a dict with at least ``id``,
    ``sql`` and ``avg_ms``) and an optional ``target_score``. Actions change
    indexes, partitioning, statistics freshness or query timings, and the
    performance score is recomputed from that state.

    Performance score: 0-100 (higher is better).
    The agent's goal: get performance_score >= target_score.
    """

    def __init__(self, scenario: dict):
        """Copy tables/queries out of ``scenario`` and record the baseline score."""
        self.scenario = scenario
        self.tables = {t["name"]: dict(t) for t in scenario["tables"]}
        self.queries = [dict(q) for q in scenario["slow_queries"]]
        self.indexes = {
            name: list(t.get("indexes", ["PRIMARY"]))
            for name, t in self.tables.items()
        }
        # table -> index name -> lower-cased column list, for indexes created
        # through ``create_index``. Keeping the real column list avoids
        # re-deriving it from the index name, which breaks on column names
        # that themselves contain underscores (e.g. ``created_at``).
        self._index_columns = {name: {} for name in self.tables}
        # Per-query rewrite bookkeeping: latency before the first rewrite and
        # the best improvement factor credited so far.
        self._rewrite_base_ms = {}
        self._rewrite_gain = {}
        self.stats_fresh = {name: False for name in self.tables}
        self.partitioned = {name: False for name in self.tables}
        self.baseline = self._compute_score()
        self.history = [self.baseline]
        self.best_score = self.baseline
        self.target_score = scenario.get("target_score", 85.0)

    # ---------------------------------------------
    # PUBLIC ACTIONS
    # ---------------------------------------------
    def apply_action(self, action_type: str, payload: dict) -> dict:
        """
        Apply an optimization action to the database.

        Supported ``action_type`` values: ``create_index``, ``rewrite_query``,
        ``partition_table``, ``analyze_statistics``, ``drop_index`` and
        ``add_column``. Unknown action types change nothing but are still
        recorded in ``history``.

        Returns a dict with ``old_score``, ``new_score``, ``delta``,
        ``affected_queries`` and ``improved``. A rejected ``create_index``
        additionally carries a ``message`` and is not recorded in ``history``.
        """
        old_score = self._compute_score()
        affected = []

        if action_type == "create_index":
            table = payload.get("table", "")
            cols = self._normalize_columns(payload.get("columns", []))
            if not cols:
                # An index without columns would match every query.
                return self._no_change(old_score, "No columns specified for index.")
            idx_name = "idx_" + "_".join(cols)
            if table not in self.indexes or idx_name in self.indexes[table]:
                # Duplicate index or unknown table: no benefit.
                return self._no_change(
                    old_score, "Index already exists or table not found."
                )
            self.indexes[table].append(idx_name)
            self._index_columns.setdefault(table, {})[idx_name] = [
                c.lower() for c in cols
            ]
            affected = self._queries_benefiting_from_index(table, cols)

        elif action_type == "rewrite_query":
            qid = payload.get("query_id", "")
            new_sql = payload.get("new_sql") or ""
            for q in self.queries:
                if q["id"] == qid:
                    self._apply_rewrite(q, new_sql)
                    affected = [qid]
                    break

        elif action_type == "partition_table":
            table = payload.get("table", "")
            if table in self.tables and not self.partitioned.get(table):
                self.partitioned[table] = True
                affected = self._queries_touching(table)

        elif action_type == "analyze_statistics":
            table = payload.get("table", "")
            if table in self.tables:
                self.stats_fresh[table] = True
                affected = self._queries_touching(table)

        elif action_type == "drop_index":
            table = payload.get("table", "")
            idx_name = payload.get("index_name", "")
            if idx_name != "PRIMARY" and idx_name in self.indexes.get(table, []):
                self.indexes[table].remove(idx_name)
                self._index_columns.get(table, {}).pop(idx_name, None)

        elif action_type == "add_column":
            table = payload.get("table", "")
            col = payload.get("column", "")
            if table in self.tables:
                self.tables[table].setdefault("extra_columns", []).append(col)
                # Denormalization can help JOIN-heavy queries.
                affected = self._queries_touching(table, joins_only=True)

        new_score = self._compute_score()
        self.history.append(new_score)
        if new_score > self.best_score:
            self.best_score = new_score
        return {
            "old_score": round(old_score, 2),
            "new_score": round(new_score, 2),
            "delta": round(new_score - old_score, 2),
            "affected_queries": affected,
            "improved": new_score > old_score,
        }

    def inspect_query(self, query_id: str) -> dict:
        """
        EXPLAIN-style report for a slow query: scan type, rows examined,
        partitioning and an optimization hint.

        Returns ``{"error": ...}`` when no query has the given id.
        """
        for q in self.queries:
            if q["id"] != query_id:
                continue
            main_table = q.get("main_table", "")
            has_index = self._check_query_index_coverage(q) > 0.1
            if has_index:
                rows_examined = 50
            else:
                table_rows = self.tables.get(main_table, {}).get("rows", 50000)
                rows_examined = q.get("rows_examined", table_rows)
            return {
                "query_id": query_id,
                "sql": q["sql"],
                "avg_ms": q["avg_ms"],
                "scan_type": "INDEX RANGE SCAN" if has_index else "FULL TABLE SCAN",
                "rows_examined": rows_examined,
                "partitioned": self.partitioned.get(main_table, False),
                "optimization_hint": (
                    "Query is using index efficiently."
                    if has_index
                    else "No index covering WHERE columns. Consider adding composite index."
                ),
                "main_table": q.get("main_table", "unknown"),
            }
        return {"error": f"Query '{query_id}' not found"}

    def analyze_indexes(self, table: str) -> dict:
        """
        Show all indexes on a table, the queries currently served by an
        index, and the scenario's missing-index hints for that table.

        Returns ``{"error": ...}`` when the table is unknown.
        """
        if table not in self.tables:
            return {"error": f"Table '{table}' not found"}
        hints = [
            h for h in self.scenario.get("missing_index_hints", [])
            if h.get("table") == table
        ]
        table_lower = table.lower()
        used_by = [
            q["id"] for q in self.queries
            if table_lower in q.get("sql", "").lower()
            and self._check_query_index_coverage(q) > 0.1
        ]
        info = self.tables[table]
        return {
            "table": table,
            "row_count": info.get("rows", 0),
            "existing_indexes": self.indexes.get(table, []),
            "indexes_used_by": used_by,
            "missing_hints": hints,
            "stats_fresh": self.stats_fresh.get(table, False),
            "partitioned": self.partitioned.get(table, False),
            "size_mb": info.get("size_mb", 0),
        }

    # ---------------------------------------------
    # STATE
    # ---------------------------------------------
    def get_current_state(self) -> dict:
        """Returns the full current DB state for the Observation."""
        return {
            "performance_score": round(self._compute_score(), 2),
            "baseline_score": round(self.baseline, 2),
            "target_score": self.target_score,
            "tables": list(self.tables.values()),
            "slow_queries": self.queries,
            "indexes": self.indexes,
            "history": self.history,
            "best_score": round(self.best_score, 2),
        }

    def get_performance_score(self) -> float:
        """Current performance score, rounded to two decimals."""
        return round(self._compute_score(), 2)

    def is_target_reached(self) -> bool:
        """True once the current score is at or above ``target_score``."""
        return self._compute_score() >= self.target_score

    # ---------------------------------------------
    # INTERNAL SCORING ENGINE
    # ---------------------------------------------
    def _compute_score(self) -> float:
        """
        Core scoring: calculates performance score 0-100 (higher = better).

        Each query's ``avg_ms`` is reduced by index coverage (x0.85), a
        partitioning bonus (0.30) and a fresh-statistics bonus (0.05), capped
        at a 97% total reduction; the per-query scores are then averaged.
        Returns 0.0 when there are no queries.
        """
        if not self.queries:
            return 0.0
        scores = []
        for q in self.queries:
            table = q.get("main_table", "")
            coverage = self._check_query_index_coverage(q)
            part_bonus = 0.30 if self.partitioned.get(table, False) else 0.0
            stats_bonus = 0.05 if self.stats_fresh.get(table, False) else 0.0
            total_reduction = min(coverage * 0.85 + part_bonus + stats_bonus, 0.97)
            effective_ms = q["avg_ms"] * (1 - total_reduction)
            # Score formula: 100ms = score 99, 1000ms = score 90, 8500ms = 15
            scores.append(max(0.0, 100.0 - (effective_ms / 100.0)))
        return round(sum(scores) / len(scores), 2)

    def _check_query_index_coverage(self, query: dict) -> float:
        """
        Returns 0.0-1.0 representing how well indexes cover this query.

        The best coverage over every index of every table mentioned in the
        SQL wins: 0.95 for a primary-key lookup (``where id =``), 0.90 when
        two or more columns of one index appear in the SQL, 0.60 when exactly
        one does, 0.0 otherwise. Matching is case-insensitive substring
        matching against the SQL text.
        """
        sql = query.get("sql", "").lower()
        best = 0.0
        for table, indexes in self.indexes.items():
            if table.lower() not in sql:
                continue
            recorded = self._index_columns.get(table, {})
            for idx in indexes:
                if idx == "PRIMARY":
                    # Primary key only helps if query filters by primary key.
                    if "where id=" in sql or "where id =" in sql:
                        return 0.95  # nothing scores higher
                    continue
                cols = recorded.get(idx)
                if cols is None:
                    # Pre-seeded index: only the name is known, so fall back
                    # to the idx_col1_col2 naming convention.
                    cols = [
                        part
                        for part in idx.lower().replace("idx_", "").split("_")
                        if part
                    ]
                matches = sum(1 for c in cols if c in sql)
                if matches >= 2:
                    best = max(best, 0.90)  # composite index: excellent coverage
                elif matches == 1:
                    best = max(best, 0.60)  # single column: partial coverage
        return best

    def _queries_benefiting_from_index(self, table: str, cols: list) -> list:
        """Returns query IDs whose SQL mentions ``table`` and any of ``cols``."""
        table_lower = table.lower()
        wanted = [c.lower() for c in cols]
        benefiting = []
        for q in self.queries:
            sql = q.get("sql", "").lower()
            if table_lower in sql and any(c in sql for c in wanted):
                benefiting.append(q["id"])
        return benefiting

    def _estimate_rewrite(self, new_sql: str, query: dict) -> float:
        """
        Estimates improvement factor from a query rewrite (0.0 to 0.70).

        Compares ``new_sql`` with the query's stored SQL and adds a fixed
        credit per recognised optimization pattern. A blank rewrite earns
        nothing.
        """
        new_lower = (new_sql or "").lower()
        if not new_lower.strip():
            return 0.0
        old_lower = query.get("sql", "").lower()
        improvement = 0.0
        # Remove SELECT * in favour of specific columns
        if "select *" not in new_lower and "select *" in old_lower:
            improvement += 0.20
        # Add LIMIT clause
        if "limit " in new_lower and "limit " not in old_lower:
            improvement += 0.15
        # Use EXISTS instead of IN subquery
        if "exists" in new_lower and "in (select" in old_lower:
            improvement += 0.25
        # Use INNER JOIN instead of implicit cross join
        if "inner join" in new_lower and "," in old_lower and "join" not in old_lower:
            improvement += 0.30
        # Add WHERE clause that was missing
        if "where" in new_lower and "where" not in old_lower:
            improvement += 0.35
        # Use COALESCE / ISNULL
        if "coalesce" in new_lower:
            improvement += 0.05
        return min(improvement, 0.70)

    # ---------------------------------------------
    # PRIVATE HELPERS
    # ---------------------------------------------
    def _apply_rewrite(self, query: dict, new_sql: str) -> None:
        """
        Lower ``query["avg_ms"]`` according to the estimated rewrite gain.

        The estimate is always measured against the query's stored SQL, so
        the gain is applied to the latency the query had before its first
        rewrite rather than compounding on an already-reduced value. The
        latency is never raised, and never pushed below 10 ms unless it
        already started below that.
        """
        qid = query["id"]
        base_ms = self._rewrite_base_ms.setdefault(qid, query["avg_ms"])
        gain = max(
            self._rewrite_gain.get(qid, 0.0),
            self._estimate_rewrite(new_sql, query),
        )
        self._rewrite_gain[qid] = gain
        if gain <= 0.0:
            return
        floor_ms = min(10, base_ms)
        candidate = max(floor_ms, int(base_ms * (1 - gain)))
        if candidate < query["avg_ms"]:
            query["avg_ms"] = candidate

    def _queries_touching(self, table: str, joins_only: bool = False) -> list:
        """IDs of queries whose SQL mentions ``table`` (optionally only JOINs)."""
        table_lower = table.lower()
        touched = []
        for q in self.queries:
            sql = q.get("sql", "").lower()
            if table_lower not in sql:
                continue
            if joins_only and "join" not in sql:
                continue
            touched.append(q["id"])
        return touched

    @staticmethod
    def _normalize_columns(columns) -> list:
        """Turn a comma-separated string or a sequence into clean column names."""
        if columns is None:
            return []
        if isinstance(columns, str):
            columns = columns.split(",")
        stripped = (str(c).strip() for c in columns)
        return [name for name in stripped if name]

    @staticmethod
    def _no_change(score: float, message: str) -> dict:
        """Result dict for an action that was rejected without touching state."""
        return {
            "old_score": score, "new_score": score,
            "delta": 0.0, "affected_queries": [],
            "improved": False, "message": message,
        }