Spaces:
Sleeping
Sleeping
| """ | |
| env/environment.py — SQL Database Engineer Agent (SDEA) | |
| Round 2: Long-horizon DB optimization environment. | |
| Agent manages a simulated production database over 50 steps. | |
| """ | |
| import time | |
| import random | |
| from typing import Optional | |
| from pydantic import ValidationError | |
| from env.models import ( | |
| Action, Observation, Reward, EpisodeState, | |
| DifficultyLevel, ActionType, StepResponse | |
| ) | |
| from env.tasks import task_manager | |
| from env.reward import compute_reward, is_done, MAX_STEPS | |
| from env.db_simulator import DatabaseSimulator | |
class SQLDebuggerEnvironment:
    """
    OpenEnv-compliant SQL Database Engineer Agent Environment.

    Round 2 evolution:
    - 50-step long-horizon episodes (up from 20)
    - 10 action types including DB-specific actions
    - DatabaseSimulator tracks real performance score 0-100
    - Milestone bonuses at 25%/50%/75% improvement
    - Backward compatible with Round 1 actions
    """

    def __init__(self):
        """Create an environment with an empty, uninitialized episode."""
        self._state = EpisodeState()
        self._current_task = None
        self._started_at = None
        self._db_sim: Optional[DatabaseSimulator] = None
        # Shared with compute_reward(), which receives this set on every step.
        self._milestones_earned: set = set()
        self._baseline_score: float = 0.0

    # ---------------------------------------------
    # reset() -> Observation
    # ---------------------------------------------
    def reset(self, difficulty: Optional[str] = None, task_id: Optional[str] = None) -> Observation:
        """
        Start a fresh episode and return its first observation.

        Args:
            difficulty: Name of a DifficultyLevel (case-insensitive). When it
                is None or not a valid level, a level is picked at random.
            task_id: Optional task identifier forwarded to the task manager.

        Returns:
            The Observation for step 0 of the new episode.

        Raises:
            ValueError: If the task manager fails to provide a task. The
                original exception is attached as ``__cause__``.
        """
        # -- Resolve difficulty ---------------------------------
        if difficulty is not None:
            try:
                diff_enum = DifficultyLevel(difficulty.lower())
            except ValueError:
                # Unknown difficulty name: fall back to a random level.
                diff_enum = random.choice(list(DifficultyLevel))
        else:
            diff_enum = random.choice(list(DifficultyLevel))

        # -- Load task ------------------------------------------
        try:
            task = task_manager.get_task(diff_enum, task_id=task_id)
        except Exception as e:
            raise ValueError(f"Failed to load task: {str(e)}") from e

        # step() stores "_last_hint" / "_last_inspect" / "_last_analysis" in
        # the task dict. Work on a shallow copy so those per-episode keys can
        # never survive into a later episode, in case the task manager hands
        # out the same dict object again.
        task = dict(task)

        # -- Initialize DatabaseSimulator -----------------------
        # Only Round 2 scenarios (with 'tables' and 'slow_queries') get a
        # simulator; Round 1 tasks run without one.
        if "tables" in task and "slow_queries" in task:
            self._db_sim = DatabaseSimulator(task)
            self._baseline_score = self._db_sim.get_performance_score()
        else:
            self._db_sim = None
            self._baseline_score = 0.0
        self._milestones_earned = set()

        # -- Reset episode state --------------------------------
        self._current_task = task
        self._started_at = time.time()
        self._state = EpisodeState(
            task_id=task["id"],
            difficulty=diff_enum,
            step_count=0,
            total_reward=0.0,
            done=False,
            hints_used=0,
            previous_actions=[],
            # action_counts doubles as storage for progress metadata; the
            # underscore-prefixed keys are not action names.
            action_counts={
                "_baseline_score": self._baseline_score,
                "_target_score": task.get("target_score", 85.0),
                "_milestones": [],
                "_perf_history": [self._baseline_score],
                "_best_score": self._baseline_score,
            },
            started_at=self._started_at,
            last_reward=0.0,
            initialized=True,
        )
        return self._build_observation()

    # ---------------------------------------------
    # step() -> StepResponse
    # ---------------------------------------------
    def step(self, action: Optional[Action]) -> StepResponse:
        """
        Process one action, update the DB simulator and compute the reward.

        Special cases, each returning early with a fixed reward:
        - environment not initialized: auto-resets and discards the action;
        - episode already done: returns the current observation with done=True;
        - ``action`` or ``action.payload`` is None: counts as a step.

        On the normal path the reward from compute_reward() is rescaled with
        ``(score + 1) / 2`` and clamped to [0.001, 0.999] before returning.

        Args:
            action: The agent's action, or None.

        Returns:
            A StepResponse with observation, reward, done flag and info dict.
        """
        # -- Auto-reset if not initialized ----------------------
        if not self._state.initialized or self._current_task is None:
            obs = self.reset()
            return StepResponse(
                observation=obs,
                reward=Reward(score=0.5, breakdown={"auto_reset": True}, feedback="Environment auto-reset."),
                done=False,
                info={"auto_reset": True},
            )

        # -- Episode already done -------------------------------
        if self._state.done:
            obs = self._build_observation()
            return StepResponse(
                observation=obs,
                reward=Reward(score=0.5, breakdown={"episode_done": True}, feedback="Episode finished. Call reset()."),
                done=True,
                info={"episode_done": True, "total_reward": self._state.total_reward},
            )

        # -- Handle null action ---------------------------------
        if action is None or action.payload is None:
            self._state.step_count += 1
            obs = self._build_observation()
            reward = Reward(score=0.001, breakdown={"invalid_action": 0.001}, feedback="Null action.")
            done = self._state.step_count >= MAX_STEPS
            self._state.done = done
            return StepResponse(observation=obs, reward=reward, done=done, info={"error": "null_action"})

        # action_type may be an enum member or a plain value; keep both the
        # printable value (for bookkeeping) and the raw object (for comparison).
        action_type_val = action.action_type.value if hasattr(action.action_type, "value") else str(action.action_type)
        action_type_enum = action.action_type

        # -- Update step count ----------------------------------
        self._state.step_count += 1
        self._state.previous_actions.append(action_type_val)
        self._state.action_counts[action_type_val] = \
            self._state.action_counts.get(action_type_val, 0) + 1

        # -- Handle hints ---------------------------------------
        if action_type_enum == ActionType.REQUEST_HINT:
            self._state.hints_used += 1
            hint_text = task_manager.get_hint(self._current_task, self._state.hints_used)
            self._current_task["_last_hint"] = hint_text

        # -- Apply DB action and get delta ----------------------
        db_delta = 0.0
        current_score = self._baseline_score
        action_info = {}
        if self._db_sim is not None:
            payload = action.payload or {}
            db_delta, action_info = self._apply_db_action(action_type_enum, payload)
            current_score = self._db_sim.get_performance_score()
            # Update tracking in action_counts dict (used by /progress)
            perf_history = self._state.action_counts.get("_perf_history", [])
            perf_history.append(current_score)
            self._state.action_counts["_perf_history"] = perf_history
            self._state.action_counts["_best_score"] = self._db_sim.best_score

        # -- Compute reward -------------------------------------
        reward = compute_reward(
            action=action,
            task_id=self._state.task_id,
            difficulty=self._state.difficulty,
            step_count=self._state.step_count,
            # Exclude the action that was just appended above.
            previous_actions=self._state.previous_actions[:-1],
            hints_used=self._state.hints_used,
            estimated_steps=self._current_task.get("estimated_fix_steps", MAX_STEPS),
            action_counts=self._state.action_counts,
            db_delta=db_delta,
            baseline_score=self._baseline_score,
            current_score=current_score,
            milestones_earned=self._milestones_earned,
        )
        # Update milestone tracking
        self._state.action_counts["_milestones"] = list(self._milestones_earned)

        # -- Update cumulative reward ---------------------------
        # The running total uses the raw score, not the normalized one.
        self._state.last_reward = reward.score
        self._state.total_reward = round(self._state.total_reward + reward.score, 4)

        # -- Check done -----------------------------------------
        target_reached = (
            self._db_sim.is_target_reached() if self._db_sim else False
        )
        done = is_done(
            action_type=action_type_enum,
            step_count=self._state.step_count,
            grader_score=reward.breakdown.get("grader_score", 0.0),
            target_reached=target_reached,
        )
        self._state.done = done

        # -- Build observation ----------------------------------
        obs = self._build_observation()

        # -- Info dict ------------------------------------------
        info = {
            "step_count": self._state.step_count,
            "total_reward": self._state.total_reward,
            "hints_used": self._state.hints_used,
            "task_id": self._state.task_id,
            "difficulty": self._state.difficulty.value if self._state.difficulty else None,
            "performance_score": current_score,
            "db_delta": db_delta,
            "milestones": list(self._milestones_earned),
            "action_result": action_info,
        }
        if done:
            info["episode_summary"] = {
                "total_steps": self._state.step_count,
                "total_reward": self._state.total_reward,
                "hints_used": self._state.hints_used,
                "duration_sec": round(time.time() - (self._started_at or time.time()), 2),
                "final_score": current_score,
                "baseline_score": self._baseline_score,
                "improvement": round(current_score - self._baseline_score, 2),
                "milestones_earned": list(self._milestones_earned),
            }

        # Normalize reward for validator compliance
        normalized_score = max(0.001, min(0.999, (reward.score + 1.0) / 2.0))
        reward = Reward(
            score=normalized_score,
            breakdown=reward.breakdown,
            feedback=reward.feedback,
        )
        return StepResponse(observation=obs, reward=reward, done=done, info=info)

    def _apply_db_action(self, action_type, payload: dict) -> tuple:
        """
        Route one action to the DatabaseSimulator.

        Must only be called while a simulator is attached.

        Args:
            action_type: The action's type, compared with ``==`` against
                ActionType members.
            payload: The action payload (already defaulted to a dict).

        Returns:
            ``(db_delta, action_info)``. Investigation actions and action
            types the simulator does not handle yield a delta of 0.0; the
            latter also yield an empty info dict.
        """
        sim = self._db_sim

        # Investigation actions: no score change, result is remembered on the
        # task so the next observation can surface it.
        if action_type == ActionType.INSPECT_QUERY:
            info = sim.inspect_query(payload.get("query_id", "q1"))
            self._current_task["_last_inspect"] = info
            return 0.0, info
        if action_type == ActionType.ANALYZE_INDEXES:
            info = sim.analyze_indexes(payload.get("table", ""))
            self._current_task["_last_analysis"] = info
            return 0.0, info

        # Mutating actions all go through apply_action() under the
        # simulator's own action name and report their score delta.
        mutating_actions = (
            (ActionType.CREATE_INDEX, "create_index"),
            (ActionType.REWRITE_QUERY, "rewrite_query"),
            (ActionType.ADD_COLUMN, "add_column"),
            (ActionType.DROP_INDEX, "drop_index"),
            (ActionType.PARTITION_TABLE, "partition_table"),
            (ActionType.ANALYZE_STATS, "analyze_statistics"),
        )
        for member, sim_action in mutating_actions:
            if action_type == member:
                result = sim.apply_action(sim_action, payload)
                return result["delta"], result

        # Any other action type (e.g. hints) does not touch the simulator.
        return 0.0, {}

    # ---------------------------------------------
    # state() -> EpisodeState
    # ---------------------------------------------
    def state(self) -> EpisodeState:
        """Return the live episode state object (not a copy)."""
        return self._state

    # ---------------------------------------------
    # INTERNAL HELPERS
    # ---------------------------------------------
    def _build_observation(self) -> Observation:
        """Build an Observation from the episode state and DB simulator state."""
        if self._current_task is None:
            return Observation(
                task_id="none",
                task_description="No task loaded. Call reset() first.",
                current_context={},
                step_count=self._state.step_count,
                difficulty=DifficultyLevel.EASY,
                max_steps=MAX_STEPS,
                hints_used=self._state.hints_used,
                # Copy, as in the main branch below, so the observation does
                # not alias the mutable episode history.
                previous_actions=self._state.previous_actions.copy(),
                metadata={},
            )

        # Base context from task
        context = task_manager.build_observation_context(self._current_task)

        # Inject DB simulator state
        if self._db_sim is not None:
            db_state = self._db_sim.get_current_state()
            context.update({
                "performance_score": db_state["performance_score"],
                "target_score": db_state["target_score"],
                "baseline_score": db_state["baseline_score"],
                "tables": db_state["tables"],
                "slow_queries": db_state["slow_queries"],
                "indexes": db_state["indexes"],
                "improvement_history": db_state["history"],
                "best_score": db_state["best_score"],
                "milestones_earned": list(self._milestones_earned),
            })

        # Inject last action result if available
        if "_last_inspect" in self._current_task:
            context["last_inspect_result"] = self._current_task["_last_inspect"]
        if "_last_analysis" in self._current_task:
            context["last_analysis_result"] = self._current_task["_last_analysis"]
        if "_last_hint" in self._current_task:
            context["last_hint"] = self._current_task["_last_hint"]

        context["steps_remaining"] = MAX_STEPS - self._state.step_count
        context["total_reward_so_far"] = self._state.total_reward

        return Observation(
            task_id=self._state.task_id or "none",
            task_description=self._current_task.get("description", ""),
            current_context=context,
            step_count=self._state.step_count,
            difficulty=self._state.difficulty or DifficultyLevel.EASY,
            max_steps=MAX_STEPS,
            hints_used=self._state.hints_used,
            previous_actions=self._state.previous_actions.copy(),
            metadata={
                "category": self._current_task.get("category", ""),
                "baseline_score": self._baseline_score,
                "target_score": self._current_task.get("target_score", 85.0),
                "total_reward": self._state.total_reward,
                "milestones": list(self._milestones_earned),
            },
        )
# ---------------------------------------------
# SINGLETON INSTANCE (used by FastAPI)
# ---------------------------------------------
# Created at import time; the instance starts uninitialized, and step()
# auto-resets it on first use if reset() has not been called.
environment = SQLDebuggerEnvironment()