# autonomic-dbre / dbre/environment.py
# Author: ZeroiJ
# Initial commit: Autonomic DBRE - Self-Improving Database Reliability Engineer
# Commit: b59a07e
from __future__ import annotations
import time
import uuid
from typing import Dict, Any, Optional, Tuple
from pydantic import BaseModel, Field
from dbre.database import DBREPostgres
from dbre.workload_generator import WorkloadGenerator
from dbre.schema_drift import SchemaDrifter
from dbre.playbook import PlaybookManager
from dbre.meta_agent import MetaAgent
from dbre.elo_system import PlaybookELOTracker
from dbre.rewards import compute_total_reward
class DBREObservation(BaseModel):
    """Snapshot of the environment state returned to the agent by reset()/step()."""

    # Short identifier of the current episode.
    episode_id: str
    # The slow query the agent is asked to improve.
    broken_query: str
    # Human-readable summary of the database tables.
    schema_description: str = Field(default="")
    # Strings reported by the schema drifter describing drift applied this episode.
    schema_diff: list[str] = Field(default_factory=list)
    # Free-form execution details for the last action.
    execution_trace: dict = Field(default_factory=dict)
    # Text of the current agent playbook.
    agent_playbook: str = Field(default="")
    # Latency of the broken query, in milliseconds.
    baseline_latency_ms: float = Field(default=0.0)
    current_score: float = Field(default=0.0)
    # Steps taken so far and the per-episode step budget.
    attempts: int = Field(default=0)
    max_attempts: int = Field(default=20)
class DBREAction(BaseModel):
    """A single agent action; which optional fields are used depends on action_type."""

    # Dispatch key; DBREEnvironment.step() gives unknown values a zero reward.
    action_type: str = Field(default=..., description="One of: rewrite_query, add_index, commit_playbook_diff")
    # rewrite_query: replacement SQL to execute and score.
    new_sql: Optional[str] = Field(default=None)
    # add_index: target table and column for the new index.
    table_name: Optional[str] = Field(default=None)
    column_name: Optional[str] = Field(default=None)
    # commit_playbook_diff: proposed playbook change text.
    diff: Optional[str] = Field(default=None)
class DBREEnvironment:
    """OpenEnv-compatible environment for Autonomic DBRE.

    ``reset()`` starts an episode: a random schema drift is applied, a slow
    ("broken") query is generated, and its rows are captured as the
    correctness reference. The agent then calls ``step()`` with one of
    ``rewrite_query``, ``add_index`` or ``commit_playbook_diff`` until the
    reward reaches the success threshold or the step budget runs out.

    Recognised ``config`` keys (all optional):
        max_steps: step budget per episode (default 20).
        latency_threshold_pct: stored on the instance but not read by this
            class (default 0.6).
        success_threshold: total reward at or above which a step ends the
            episode as a success (default 0.6).
    """

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        """Connect to the database, create and seed tables, and build helpers.

        Note: this performs database I/O (connect, create_tables, seed_data).
        """
        config = config or {}
        self.max_steps = config.get("max_steps", 20)
        self.latency_threshold_pct = config.get("latency_threshold_pct", 0.6)
        # Previously hard-coded as 0.6 inside step(); the default is unchanged.
        self.success_threshold = config.get("success_threshold", 0.6)

        self.db = DBREPostgres()
        self.db.connect()
        self.db.create_tables()
        self.db.seed_data()

        self.workload_gen = WorkloadGenerator(self.db.conn)
        self.schema_drifter = SchemaDrifter(self.db.conn)
        self.playbook_manager = PlaybookManager()
        self.elo_tracker = PlaybookELOTracker()
        self.meta_agent = MetaAgent(self.playbook_manager, self.elo_tracker, episode_history_limit=5)

        # Per-episode state; (re)initialised by reset().
        self.episode_id: str = ""
        self.broken_query: str = ""
        self.reference_rows: list = []
        self.baseline_latency_ms: float = 0.0
        self.current_optimized_query: str = ""
        self.attempts: int = 0
        self.current_score: float = 0.0
        self.episode_done: bool = False
        self.episode_success: bool = False

        # Register the initial playbook only if the ELO tracker has no history yet.
        if not self.elo_tracker.history:
            self.elo_tracker.register_playbook("v1", 1000)

    def reset(self) -> DBREObservation:
        """Start a new episode and return its initial observation."""
        self.schema_drifter.apply_random_drift()
        self.broken_query, self.baseline_latency_ms = self.workload_gen.generate_broken_query()
        self.reference_rows = self.workload_gen.get_expected_rows(self.broken_query)
        self.episode_id = str(uuid.uuid4())[:8]
        self.attempts = 0
        self.current_score = 0.0
        self.episode_done = False
        self.episode_success = False
        self.current_optimized_query = ""
        return self._build_observation()

    def step(self, action: DBREAction) -> Tuple[DBREObservation, float, bool, Dict[str, Any]]:
        """Execute an action and return (observation, reward, terminated, info).

        Handler exceptions become a 0.0 reward with an ``"error"`` entry in
        ``info["reward_breakdown"]``. When the episode ends (reward at or above
        ``success_threshold``, or ``max_steps`` reached) it is reported to the
        meta agent exactly once, which may then re-evaluate the playbook.
        Stepping a finished episode is a no-op returning reward 0.0 and
        ``terminated=True``; call ``reset()`` first.
        """
        if self.episode_done:
            # Without this guard, every extra step would re-report the same
            # episode to the meta agent and could flip episode_success.
            info = {
                "reward_breakdown": {"total": 0.0, "error": "Episode already finished; call reset()"},
                "episode_success": self.episode_success,
            }
            return self._build_observation(), 0.0, True, info

        self.attempts += 1
        try:
            if action.action_type == "rewrite_query":
                reward_info = self._handle_rewrite_query(action.new_sql)
            elif action.action_type == "add_index":
                reward_info = self._handle_add_index(action.table_name, action.column_name)
            elif action.action_type == "commit_playbook_diff":
                reward_info = self._handle_playbook_diff(action.diff)
            else:
                reward_info = {"total": 0.0, "error": f"Unknown action_type: {action.action_type}"}
        except Exception as e:
            # Environment boundary: a failing handler yields a zero reward, not an exception.
            reward_info = {"total": 0.0, "error": str(e)}

        total_reward = reward_info.get("total", 0.0)
        self.current_score = total_reward
        succeeded = total_reward >= self.success_threshold

        if succeeded or self.attempts >= self.max_steps:
            self.episode_done = True
            self.episode_success = succeeded
            self.meta_agent.observe_episode({
                "episode_id": self.episode_id,
                "success": self.episode_success,
                "total_reward": total_reward,
                "reward_breakdown": reward_info,
                "attempts": self.attempts
            })
            # Auto-trigger the meta agent once it reports it is ready.
            if self.meta_agent.should_trigger():
                print("[META] Triggering playbook evaluation...")
                meta_result = self.meta_agent.evaluate_and_commit(self.db.conn)
                print(f"[META] Result: {meta_result}")

        observation = self._build_observation()
        info = {"reward_breakdown": reward_info, "episode_success": self.episode_success}
        return observation, total_reward, self.episode_done, info

    def state(self) -> DBREObservation:
        """Return the current observation without stepping."""
        return self._build_observation()

    def _rollback_quietly(self) -> None:
        """Best-effort rollback of the connection's open transaction.

        Assumes a psycopg-style connection, where a failed statement aborts the
        open transaction and all later statements fail until ``rollback()``.
        Any uncommitted work on the connection is discarded. Errors from the
        rollback itself are ignored on purpose: this runs only on
        already-failing paths, and the original error is what gets reported.
        """
        try:
            self.db.conn.rollback()
        except Exception:
            pass

    @staticmethod
    def _is_plain_identifier(name: str) -> bool:
        """Return True if *name* matches ``[A-Za-z_][A-Za-z0-9_]*``."""
        return name.isascii() and name.isidentifier()

    def _handle_rewrite_query(self, new_sql: Optional[str]) -> Dict[str, Any]:
        """Execute the agent's rewritten query and score it against the broken one.

        Returns the reward dict from ``compute_total_reward``, or
        ``{"total": 0.0, "error": ...}`` when no SQL is given or execution fails.
        """
        if not new_sql:
            return {"total": 0.0, "error": "No SQL provided"}
        # NOTE(review): agent SQL is executed verbatim on the shared connection,
        # so a non-SELECT statement could modify data. Consider a read-only role.
        conn = self.db.conn
        try:
            cur = conn.cursor()
            try:
                cur.execute(new_sql)
                new_rows = cur.fetchall()
            finally:
                cur.close()
            new_latency = self.workload_gen.measure_latency(conn, new_sql)
        except Exception as e:
            # Recover the connection so later steps are not stuck in an aborted transaction.
            self._rollback_quietly()
            return {"total": 0.0, "error": f"SQL execution error: {str(e)}"}

        self.current_optimized_query = new_sql
        return compute_total_reward(
            original_query=self.broken_query,
            new_query=new_sql,
            reference_rows=self.reference_rows,
            baseline_latency_ms=self.baseline_latency_ms,
            new_latency_ms=new_latency,
            new_rows=new_rows
        )

    def _handle_add_index(self, table_name: Optional[str], column_name: Optional[str]) -> Dict[str, Any]:
        """Create an index on ``table_name(column_name)`` and re-score the current query.

        ``table_name`` may be schema-qualified (``schema.table``). Every name
        part must be a plain ASCII identifier. The names are interpolated into
        DDL, and identifiers cannot be bound as query parameters, so anything
        else is rejected to prevent SQL injection from agent-supplied input.

        Returns the reward dict from ``compute_total_reward`` when a rewritten
        query exists. Otherwise it returns ``{"total": 0.1, ...}`` after a
        successful index creation, or ``{"total": 0.0, "error": ...}`` on
        invalid input or DDL failure.
        """
        if not table_name or not column_name:
            return {"total": 0.0, "error": "table_name and column_name required"}
        table_parts = table_name.split(".")
        if not all(self._is_plain_identifier(part) for part in table_parts):
            return {"total": 0.0, "error": f"Invalid table_name: {table_name!r}"}
        if not self._is_plain_identifier(column_name):
            return {"total": 0.0, "error": f"Invalid column_name: {column_name!r}"}

        # An index name cannot be schema-qualified, so flatten dots into underscores.
        index_name = f"idx_{'_'.join(table_parts)}_{column_name}"
        try:
            cursor = self.db.conn.cursor()
            try:
                cursor.execute(f"CREATE INDEX IF NOT EXISTS {index_name} ON {table_name}({column_name})")
                self.db.conn.commit()
            finally:
                cursor.close()
        except Exception as e:
            self._rollback_quietly()
            return {"total": 0.0, "error": f"Index creation error: {str(e)}"}

        if not self.current_optimized_query:
            return {"total": 0.1, "note": "Index created but no query to evaluate yet"}

        try:
            new_latency = self.workload_gen.measure_latency(self.db.conn, self.current_optimized_query)
            return compute_total_reward(
                original_query=self.broken_query,
                new_query=self.current_optimized_query,
                reference_rows=self.reference_rows,
                baseline_latency_ms=self.baseline_latency_ms,
                new_latency_ms=new_latency
            )
        except Exception as e:
            # The index exists, so keep the small creation reward, but surface the failure
            # instead of claiming there was nothing to evaluate.
            self._rollback_quietly()
            return {
                "total": 0.1,
                "note": "Index created but re-evaluation failed",
                "error": f"Re-evaluation error: {str(e)}",
            }

    def _handle_playbook_diff(self, diff: Optional[str]) -> Dict[str, Any]:
        """Ask the meta agent to evaluate and possibly commit a new playbook.

        Returns 0.3 reward if the meta agent accepts the playbook, 0.0 if it
        does not, and an ``"error"`` entry on a missing diff or failure.
        """
        if not diff:
            return {"total": 0.0, "error": "No diff provided"}
        # NOTE(review): ``diff`` is only checked for non-emptiness; its content is never
        # passed to the meta agent. Confirm whether MetaAgent should receive it.
        try:
            result = self.meta_agent.evaluate_and_commit(self.db.conn)
            if result["accepted"]:
                return {"total": 0.3, "note": f"Playbook updated. New ELO: {result['new_elo']}"}
            return {"total": 0.0, "note": "Playbook not accepted"}
        except Exception as e:
            return {"total": 0.0, "error": f"Playbook update error: {str(e)}"}

    def _build_observation(self) -> DBREObservation:
        """Build an observation from the current episode state.

        ``current_score`` is the total reward of the most recent step (0.0
        right after reset); ``execution_trace`` is currently always empty.
        """
        return DBREObservation(
            episode_id=self.episode_id,
            broken_query=self.broken_query,
            schema_description=self._get_schema_description(),
            schema_diff=self.schema_drifter.get_schema_diff(),
            execution_trace={},
            agent_playbook=self.playbook_manager.get_current(),
            baseline_latency_ms=self.baseline_latency_ms,
            current_score=self.current_score,
            attempts=self.attempts,
            max_attempts=self.max_steps
        )

    def _get_schema_description(self) -> str:
        """Return ``"Tables: a, b, ..."`` listing the tables in the ``public`` schema.

        Falls back to ``"Schema unavailable"`` if the catalog query fails.
        """
        try:
            cursor = self.db.conn.cursor()
            try:
                cursor.execute("""
                    SELECT table_name
                    FROM information_schema.tables
                    WHERE table_schema = 'public'
                    ORDER BY table_name
                """)
                tables = [row[0] for row in cursor.fetchall()]
            finally:
                cursor.close()
        except Exception:
            # Best-effort: the description is informational only.
            return "Schema unavailable"
        return f"Tables: {', '.join(tables)}"