junaid0600 committed on
Commit
d366e7a
·
verified ·
1 Parent(s): bb2cfec

Update env/db_simulator.py

Browse files
Files changed (1) hide show
  1. env/db_simulator.py +107 -105
env/db_simulator.py CHANGED
@@ -1,23 +1,15 @@
1
  """
2
  env/db_simulator.py — SQL Database Engineer Agent
3
  Simulates a production database responding to optimization actions.
4
- Core mechanism: index coverage reduces query execution time by up to 85-90%.
 
5
  """
6
 
7
  import math
8
- import random
9
  from typing import Optional
10
 
11
 
12
  class DatabaseSimulator:
13
- """
14
- Simulates a production database that degrades over time.
15
- The agent applies optimization actions and sees performance scores change.
16
-
17
- Performance score: 0-100 (100 = all queries running at target speed).
18
- The agent's goal: get performance_score >= target_score.
19
- """
20
-
21
  def __init__(self, scenario: dict):
22
  self.scenario = scenario
23
  self.tables = {t["name"]: dict(t) for t in scenario["tables"]}
@@ -28,6 +20,12 @@ class DatabaseSimulator:
28
  }
29
  self.stats_fresh = {name: False for name in self.tables}
30
  self.partitioned = {name: False for name in self.tables}
 
 
 
 
 
 
31
  self.baseline = self._compute_score()
32
  self.history = [self.baseline]
33
  self.best_score = self.baseline
@@ -38,10 +36,6 @@ class DatabaseSimulator:
38
  # ─────────────────────────────────────────────
39
 
40
  def apply_action(self, action_type: str, payload: dict) -> dict:
41
- """
42
- Apply an optimization action to the database.
43
- Returns delta showing performance change.
44
- """
45
  old_score = self._compute_score()
46
  affected = []
47
 
@@ -55,11 +49,11 @@ class DatabaseSimulator:
55
  self.indexes[table].append(idx_name)
56
  affected = self._queries_benefiting_from_index(table, cols)
57
  else:
58
- # Duplicate index — no benefit
59
  return {
60
  "old_score": old_score, "new_score": old_score,
61
  "delta": 0.0, "affected_queries": [],
62
- "improved": False, "message": "Index already exists or table not found."
 
63
  }
64
 
65
  elif action_type == "rewrite_query":
@@ -76,32 +70,38 @@ class DatabaseSimulator:
76
  table = payload.get("table", "")
77
  if table in self.tables and not self.partitioned.get(table):
78
  self.partitioned[table] = True
79
- affected = [q["id"] for q in self.queries if table in q.get("sql", "")]
 
 
 
80
 
81
  elif action_type == "analyze_statistics":
82
  table = payload.get("table", "")
83
  if table in self.tables:
84
  self.stats_fresh[table] = True
85
- affected = [q["id"] for q in self.queries if table in q.get("sql", "")]
 
 
 
86
 
87
  elif action_type == "drop_index":
88
  table = payload.get("table", "")
89
  idx_name = payload.get("index_name", "")
90
- if idx_name in self.indexes.get(table, []) and idx_name != "PRIMARY":
 
91
  self.indexes[table].remove(idx_name)
92
 
93
  elif action_type == "add_column":
94
- table = payload.get("table", "")
95
- col = payload.get("column", "")
96
- purpose = payload.get("purpose", "")
97
  if table in self.tables:
98
  if "extra_columns" not in self.tables[table]:
99
  self.tables[table]["extra_columns"] = []
100
  self.tables[table]["extra_columns"].append(col)
101
- # Denormalization can help JOINy queries
102
  affected = [
103
  q["id"] for q in self.queries
104
- if "join" in q.get("sql", "").lower() and table in q.get("sql", "")
 
105
  ]
106
 
107
  new_score = self._compute_score()
@@ -118,60 +118,58 @@ class DatabaseSimulator:
118
  }
119
 
120
  def inspect_query(self, query_id: str) -> dict:
121
- """
122
- EXPLAIN a slow query — reveals scan type, rows examined, cost.
123
- This is the agent's primary investigation tool.
124
- """
125
  for q in self.queries:
126
  if q["id"] == query_id:
127
- has_index = self._check_query_index_coverage(q) > 0.1
128
- is_partition = self.partitioned.get(q.get("main_table", ""), False)
129
- rows_examined = 50 if has_index else q.get("rows_examined",
130
- self.tables.get(q.get("main_table", ""), {}).get("rows", 50000))
131
-
 
 
 
 
 
132
  return {
133
- "query_id": query_id,
134
- "sql": q["sql"],
135
- "avg_ms": q["avg_ms"],
136
- "scan_type": "INDEX RANGE SCAN" if has_index else "FULL TABLE SCAN",
137
- "rows_examined": rows_examined,
138
- "partitioned": is_partition,
 
139
  "optimization_hint": (
140
  "Query is using index efficiently."
141
- if has_index
142
- else "No index covering WHERE columns. Consider adding composite index."
 
143
  ),
144
- "main_table": q.get("main_table", "unknown"),
145
  }
146
  return {"error": f"Query '{query_id}' not found"}
147
 
148
  def analyze_indexes(self, table: str) -> dict:
149
- """
150
- Show all indexes on a table + usage stats + missing index hints.
151
- """
152
  if table not in self.tables:
153
  return {"error": f"Table '{table}' not found"}
154
-
155
- existing = self.indexes.get(table, [])
156
- hints = [
157
  h for h in self.scenario.get("missing_index_hints", [])
158
  if h.get("table") == table
159
  ]
160
- used_by = []
161
  for q in self.queries:
162
  cov = self._check_query_index_coverage(q)
163
  if table in q.get("sql", "") and cov > 0.1:
164
  used_by.append(q["id"])
165
-
166
  return {
167
- "table": table,
168
- "row_count": self.tables[table].get("rows", 0),
169
  "existing_indexes": existing,
170
- "indexes_used_by": used_by,
171
- "missing_hints": hints,
172
- "stats_fresh": self.stats_fresh.get(table, False),
173
- "partitioned": self.partitioned.get(table, False),
174
- "size_mb": self.tables[table].get("size_mb", 0),
175
  }
176
 
177
  # ─────────────────────────────────────────────
@@ -179,7 +177,6 @@ class DatabaseSimulator:
179
  # ─────────────────────────────────────────────
180
 
181
  def get_current_state(self) -> dict:
182
- """Returns the full current DB state for the Observation."""
183
  return {
184
  "performance_score": round(self._compute_score(), 2),
185
  "baseline_score": round(self.baseline, 2),
@@ -198,35 +195,57 @@ class DatabaseSimulator:
198
  return self._compute_score() >= self.target_score
199
 
200
  # ─────────────────────────────────────────────
201
- # INTERNAL SCORING ENGINE
202
  # ─────────────────────────────────────────────
203
 
204
  def _compute_score(self) -> float:
205
  """
206
- Core scoring: calculates performance score 0-100.
207
- Higher = better. Based on how fast queries run given current indexes.
 
 
 
208
  """
209
  if not self.queries:
210
  return 0.0
211
 
 
 
 
212
  scores = []
213
  for q in self.queries:
214
- table = q.get("main_table", "")
215
- coverage = self._check_query_index_coverage(q)
216
- part_bonus = 0.30 if self.partitioned.get(table, False) else 0.0
217
- stats_bonus = 0.05 if self.stats_fresh.get(table, False) else 0.0
218
- total_reduction = min(coverage * 0.85 + part_bonus + stats_bonus, 0.97)
219
- effective_ms = q["avg_ms"] * (1 - total_reduction)
220
- # Score formula: 100ms = score 99, 1000ms = score 90, 8500ms = ~14
221
- score = max(0.0, 100.0 - (effective_ms / 100.0))
222
- scores.append(score)
223
 
224
- return round(sum(scores) / len(scores), 2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
 
226
  def _check_query_index_coverage(self, query: dict) -> float:
227
  """
228
- Returns 0.0-1.0 representing how well indexes cover this query's WHERE clause.
229
- 0.0 = full table scan, 1.0 = perfect index coverage.
230
  """
231
  sql = query.get("sql", "").lower()
232
  for table, indexes in self.indexes.items():
@@ -234,59 +253,42 @@ class DatabaseSimulator:
234
  continue
235
  for idx in indexes:
236
  if idx == "PRIMARY":
237
- # Primary key only helps if query filters by primary key
238
  if "where id=" in sql or "where id =" in sql:
239
  return 0.95
240
  continue
241
- # Extract columns from index name (idx_col1_col2)
242
- cols = idx.replace("idx_", "").split("_")
243
  matches = sum(1 for c in cols if c in sql)
244
  if matches >= 2:
245
- return 0.90 # Composite index — excellent coverage
246
  if matches == 1:
247
- return 0.60 # Single column — partial coverage
248
  return 0.0
249
 
250
- def _queries_benefiting_from_index(self, table: str, cols: list) -> list:
251
- """Returns query IDs that would benefit from an index on given table/columns."""
252
- benefiting = []
253
- for q in self.queries:
254
- sql = q.get("sql", "").lower()
255
- if table in sql and any(c.lower() in sql for c in cols):
256
- benefiting.append(q["id"])
257
- return benefiting
258
 
259
  def _estimate_rewrite(self, new_sql: str, query: dict) -> float:
260
- """
261
- Estimates improvement factor from a query rewrite (0.0 to 0.70).
262
- Checks for common optimization patterns.
263
- """
264
  new_lower = new_sql.lower()
265
  old_lower = query.get("sql", "").lower()
266
  improvement = 0.0
267
-
268
- # Remove SELECT * → specific columns
269
  if "select *" not in new_lower and "select *" in old_lower:
270
  improvement += 0.20
271
-
272
- # Add LIMIT clause
273
  if "limit " in new_lower and "limit " not in old_lower:
274
  improvement += 0.15
275
-
276
- # Use EXISTS instead of IN subquery
277
  if "exists" in new_lower and "in (select" in old_lower:
278
  improvement += 0.25
279
-
280
- # Use INNER JOIN instead of implicit cross join
281
- if "inner join" in new_lower and "," in old_lower and "join" not in old_lower:
282
  improvement += 0.30
283
-
284
- # Add WHERE clause that was missing
285
  if "where" in new_lower and "where" not in old_lower:
286
  improvement += 0.35
287
-
288
- # Use COALESCE / ISNULL
289
  if "coalesce" in new_lower:
290
  improvement += 0.05
291
-
292
- return min(improvement, 0.70)
 
1
  """
2
  env/db_simulator.py — SQL Database Engineer Agent
3
  Simulates a production database responding to optimization actions.
4
+ Core fix: _compute_score() now interpolates from JSON baseline → target
5
+ so baseline matches scenario JSON (e.g. 8.0 not 80.0).
6
  """
7
 
8
  import math
 
9
  from typing import Optional
10
 
11
 
12
  class DatabaseSimulator:
 
 
 
 
 
 
 
 
13
  def __init__(self, scenario: dict):
14
  self.scenario = scenario
15
  self.tables = {t["name"]: dict(t) for t in scenario["tables"]}
 
20
  }
21
  self.stats_fresh = {name: False for name in self.tables}
22
  self.partitioned = {name: False for name in self.tables}
23
+
24
+ # Store original ms for detecting rewrite improvements
25
+ self._original_query_ms = {
26
+ q["id"]: q["avg_ms"] for q in scenario["slow_queries"]
27
+ }
28
+
29
  self.baseline = self._compute_score()
30
  self.history = [self.baseline]
31
  self.best_score = self.baseline
 
36
  # ─────────────────────────────────────────────
37
 
38
  def apply_action(self, action_type: str, payload: dict) -> dict:
 
 
 
 
39
  old_score = self._compute_score()
40
  affected = []
41
 
 
49
  self.indexes[table].append(idx_name)
50
  affected = self._queries_benefiting_from_index(table, cols)
51
  else:
 
52
  return {
53
  "old_score": old_score, "new_score": old_score,
54
  "delta": 0.0, "affected_queries": [],
55
+ "improved": False,
56
+ "message": "Index already exists or table not found."
57
  }
58
 
59
  elif action_type == "rewrite_query":
 
70
  table = payload.get("table", "")
71
  if table in self.tables and not self.partitioned.get(table):
72
  self.partitioned[table] = True
73
+ affected = [
74
+ q["id"] for q in self.queries
75
+ if table in q.get("sql", "")
76
+ ]
77
 
78
  elif action_type == "analyze_statistics":
79
  table = payload.get("table", "")
80
  if table in self.tables:
81
  self.stats_fresh[table] = True
82
+ affected = [
83
+ q["id"] for q in self.queries
84
+ if table in q.get("sql", "")
85
+ ]
86
 
87
  elif action_type == "drop_index":
88
  table = payload.get("table", "")
89
  idx_name = payload.get("index_name", "")
90
+ if (idx_name in self.indexes.get(table, [])
91
+ and idx_name != "PRIMARY"):
92
  self.indexes[table].remove(idx_name)
93
 
94
  elif action_type == "add_column":
95
+ table = payload.get("table", "")
96
+ col = payload.get("column", "")
 
97
  if table in self.tables:
98
  if "extra_columns" not in self.tables[table]:
99
  self.tables[table]["extra_columns"] = []
100
  self.tables[table]["extra_columns"].append(col)
 
101
  affected = [
102
  q["id"] for q in self.queries
103
+ if "join" in q.get("sql", "").lower()
104
+ and table in q.get("sql", "")
105
  ]
106
 
107
  new_score = self._compute_score()
 
118
  }
119
 
120
  def inspect_query(self, query_id: str) -> dict:
 
 
 
 
121
  for q in self.queries:
122
  if q["id"] == query_id:
123
+ has_index = self._check_query_index_coverage(q) > 0.1
124
+ is_partitioned = self.partitioned.get(
125
+ q.get("main_table", ""), False
126
+ )
127
+ rows_examined = 50 if has_index else q.get(
128
+ "rows_examined",
129
+ self.tables.get(
130
+ q.get("main_table", ""), {}
131
+ ).get("rows", 50000)
132
+ )
133
  return {
134
+ "query_id": query_id,
135
+ "sql": q["sql"],
136
+ "avg_ms": q["avg_ms"],
137
+ "scan_type": "INDEX RANGE SCAN" if has_index
138
+ else "FULL TABLE SCAN",
139
+ "rows_examined": rows_examined,
140
+ "partitioned": is_partitioned,
141
  "optimization_hint": (
142
  "Query is using index efficiently."
143
+ if has_index else
144
+ "No index covering WHERE columns. "
145
+ "Consider adding composite index."
146
  ),
147
+ "main_table": q.get("main_table", "unknown"),
148
  }
149
  return {"error": f"Query '{query_id}' not found"}
150
 
151
  def analyze_indexes(self, table: str) -> dict:
 
 
 
152
  if table not in self.tables:
153
  return {"error": f"Table '{table}' not found"}
154
+ existing = self.indexes.get(table, [])
155
+ hints = [
 
156
  h for h in self.scenario.get("missing_index_hints", [])
157
  if h.get("table") == table
158
  ]
159
+ used_by = []
160
  for q in self.queries:
161
  cov = self._check_query_index_coverage(q)
162
  if table in q.get("sql", "") and cov > 0.1:
163
  used_by.append(q["id"])
 
164
  return {
165
+ "table": table,
166
+ "row_count": self.tables[table].get("rows", 0),
167
  "existing_indexes": existing,
168
+ "indexes_used_by": used_by,
169
+ "missing_hints": hints,
170
+ "stats_fresh": self.stats_fresh.get(table, False),
171
+ "partitioned": self.partitioned.get(table, False),
172
+ "size_mb": self.tables[table].get("size_mb", 0),
173
  }
174
 
175
  # ─────────────────────────────────────────────
 
177
  # ─────────────────────────────────────────────
178
 
179
  def get_current_state(self) -> dict:
 
180
  return {
181
  "performance_score": round(self._compute_score(), 2),
182
  "baseline_score": round(self.baseline, 2),
 
195
  return self._compute_score() >= self.target_score
196
 
197
  # ─────────────────────────────────────────────
198
+ # INTERNAL SCORING ENGINE — FIXED
199
  # ─────────────────────────────────────────────
200
 
201
  def _compute_score(self) -> float:
202
  """
203
+ FIXED: Interpolates from json_baseline → target_score
204
+ based on index coverage + rewrite improvements.
205
+
206
+ Before fix: used raw ms formula → gave 80 when JSON said 8
207
+ After fix: no index = json_baseline, full index = target_score
208
  """
209
  if not self.queries:
210
  return 0.0
211
 
212
+ json_baseline = self.scenario.get("performance_score_baseline", 50.0)
213
+ target = self.scenario.get("target_score", 85.0)
214
+
215
  scores = []
216
  for q in self.queries:
217
+ table = q.get("main_table", "")
 
 
 
 
 
 
 
 
218
 
219
+ # ── Index coverage improvement ────────────────────────
220
+ coverage = self._check_query_index_coverage(q)
221
+ part_bonus = 0.25 if self.partitioned.get(table, False) else 0.0
222
+ stats_bonus = 0.04 if self.stats_fresh.get(table, False) else 0.0
223
+ index_improvement = min(coverage + part_bonus + stats_bonus, 0.95)
224
+
225
+ # ── Query rewrite improvement ─────────────────────────
226
+ original_ms = self._original_query_ms.get(q["id"], q["avg_ms"])
227
+ rewrite_factor = max(
228
+ 0.0,
229
+ 1.0 - q["avg_ms"] / max(1, original_ms)
230
+ )
231
+ rewrite_improvement = rewrite_factor * 0.40
232
+
233
+ # ── Combined improvement fraction (0 → 1) ─────────────
234
+ combined = min(index_improvement + rewrite_improvement, 1.0)
235
+
236
+ # ── Interpolate: baseline → target ────────────────────
237
+ q_score = json_baseline + (target - json_baseline) * combined
238
+ scores.append(q_score)
239
+
240
+ return round(
241
+ min(100.0, max(0.0, sum(scores) / len(scores))),
242
+ 2
243
+ )
244
 
245
  def _check_query_index_coverage(self, query: dict) -> float:
246
  """
247
+ Returns 0.0-1.0: how well indexes cover this query's WHERE clause.
248
+ 0.0 = full table scan, 0.9 = composite index match.
249
  """
250
  sql = query.get("sql", "").lower()
251
  for table, indexes in self.indexes.items():
 
253
  continue
254
  for idx in indexes:
255
  if idx == "PRIMARY":
 
256
  if "where id=" in sql or "where id =" in sql:
257
  return 0.95
258
  continue
259
+ cols = idx.replace("idx_", "").split("_")
 
260
  matches = sum(1 for c in cols if c in sql)
261
  if matches >= 2:
262
+ return 0.90 # Composite — excellent
263
  if matches == 1:
264
+ return 0.60 # Single column — partial
265
  return 0.0
266
 
267
+ def _queries_benefiting_from_index(
268
+ self, table: str, cols: list
269
+ ) -> list:
270
+ return [
271
+ q["id"] for q in self.queries
272
+ if table in q.get("sql", "").lower()
273
+ and any(c.lower() in q.get("sql", "").lower() for c in cols)
274
+ ]
275
 
276
  def _estimate_rewrite(self, new_sql: str, query: dict) -> float:
 
 
 
 
277
  new_lower = new_sql.lower()
278
  old_lower = query.get("sql", "").lower()
279
  improvement = 0.0
 
 
280
  if "select *" not in new_lower and "select *" in old_lower:
281
  improvement += 0.20
 
 
282
  if "limit " in new_lower and "limit " not in old_lower:
283
  improvement += 0.15
 
 
284
  if "exists" in new_lower and "in (select" in old_lower:
285
  improvement += 0.25
286
+ if ("inner join" in new_lower
287
+ and "," in old_lower
288
+ and "join" not in old_lower):
289
  improvement += 0.30
 
 
290
  if "where" in new_lower and "where" not in old_lower:
291
  improvement += 0.35
 
 
292
  if "coalesce" in new_lower:
293
  improvement += 0.05
294
+ return min(improvement, 0.70)