"""
Research -> Interactive Explainer Environment (multi-step, async).

Episode flow:
    1. reset()                          → agent gets a topic + tier
    2. step(explore)  × 0..MAX_EXPLORE  → agent calls research tools
    3. step(generate) × 1               → agent produces marimo/manim code
    4. step(repair)   × 0..MAX_REPAIR   → agent fixes lint/build errors if needed

Each step returns a per-step reward. The generate step also includes a
generation reward that accounts for how well the code uses the research.
An unknown action type or an internal error ends the episode.

The environment supports async via reset_async() / step_async() overrides.
OpenEnv's HTTP server detects these and calls them directly (no thread pool).
"""

import asyncio
import concurrent.futures
import random
from uuid import uuid4

from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State

try:
    from ..constants import MAX_EXPLORE_STEPS, MAX_REPAIR_STEPS, clamp_action_reward
    from ..models import ExplainerAction, ExplainerObservation
    from ..research import AVAILABLE_TOOLS, run_research_tool
    from ..rewards.exploration import compute_explore_reward
    from ..rewards.generation import adjust_repair_reward, compute_generate_reward
    from ..rewards.sandbox import validate_code
    from ..task_bank import ALL_TASKS, EASY_TASKS, HARD_TASKS, MEDIUM_TASKS, Task
except ImportError:
    from constants import MAX_EXPLORE_STEPS, MAX_REPAIR_STEPS, clamp_action_reward
    from models import ExplainerAction, ExplainerObservation
    from research import AVAILABLE_TOOLS, run_research_tool
    from rewards.exploration import compute_explore_reward
    from rewards.generation import adjust_repair_reward, compute_generate_reward
    from rewards.sandbox import validate_code
    from task_bank import ALL_TASKS, EASY_TASKS, HARD_TASKS, MEDIUM_TASKS, Task


MB002_REPAIR_HINT = (
    "MB002 repair checklist: Marimo treats every non-underscore assignment as a "
    "global notebook variable, including `for` loop variables. Audit the whole "
    "file and rename cell-local names to private names everywhere: `arr` -> "
    "`_arr`, `target` -> `_target`, `i` -> `_i`, `t` -> `_t`, `freqs` -> "
    "`_freqs`, `fig` -> `_fig`, `ax` -> `_ax`. Public names should only be used "
    "for values intentionally passed to later cells, and each public name may be "
    "defined once globally."
)


def _render_errors_with_hints(errors: str, error_codes: list[str]) -> str:
    """Return *errors*, followed by the MB002 repair checklist if MB002 was reported."""
    if "MB002" not in error_codes:
        return errors
    return f"{errors}\n\n{MB002_REPAIR_HINT}"


def _run_blocking(coro):
    """Run the coroutine *coro* to completion from synchronous code and return its result.

    ``asyncio.run`` raises RuntimeError when the calling thread already has a
    running event loop. In that case the coroutine is run on a fresh loop in a
    worker thread instead. The caller still blocks until it finishes.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop running in this thread: the normal case for sync callers.
        return asyncio.run(coro)
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, coro).result()


def _format_reward_line(components: dict) -> str:
    """Render reward components as a single ``Reward: k=v, ...`` feedback line."""
    pairs = ", ".join(f"{key}={value}" for key, value in components.items())
    return f"Reward: {pairs}"


class ExplainerEnvironment(Environment):
    """
    Multi-step Research → Interactive Explainer environment.

    Phase 1 (explore): agent issues search queries, receives papers/wiki sections.
    Phase 2 (generate): agent produces marimo/manim code using the research.
    Phase 3 (repair): after a failed generate, agent resubmits fixed code.

    Supports async via reset_async() / step_async() — OpenEnv's server
    detects the overrides and awaits them directly instead of using a thread pool.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self):
        super().__init__()
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._current_task: Task | None = None
        # Pool sampled by reset() when no recognised difficulty is requested.
        self._difficulty_pool: list[Task] = EASY_TASKS
        self._reset_episode_state()

    def _reset_episode_state(self) -> None:
        """Clear per-episode progress: research context, step budgets, phase and last artifact."""
        self._accumulated_context: list[str] = []
        self._explore_actions: list[str] = []
        self._used_tools: set[str] = set()
        self._explore_steps: int = 0
        self._repair_steps: int = 0
        self._phase: str = "explore"
        self._done: bool = False
        self._last_code: str = ""
        self._last_format: str = "marimo"
        self._last_narration: str = ""
        self._last_errors: str = ""
        self._last_error_codes: list[str] = []

    # ------------------------------------------------------------------
    # Sync interface (fallback — OpenEnv prefers async when overridden)
    # ------------------------------------------------------------------

    def reset(self, seed=None, episode_id=None, **kwargs) -> ExplainerObservation:
        """Sample a task and return the initial observation (sync)."""
        return self._do_reset(seed=seed, episode_id=episode_id, **kwargs)

    def step(self, action: ExplainerAction, timeout_s=None, **kwargs) -> ExplainerObservation:
        """Route to the phase handler (sync — explore blocks on the async research call)."""
        early = self._begin_step()
        if early is not None:
            return early
        task = self._current_task
        try:
            if action.action_type == "explore":
                return _run_blocking(self._handle_explore(action, task))
            return self._dispatch_sync_action(action, task)
        except Exception as e:
            return self._end_episode(task, f"Environment error: {e}")

    # ------------------------------------------------------------------
    # Async interface (preferred — OpenEnv detects these overrides)
    # ------------------------------------------------------------------

    async def reset_async(self, seed=None, episode_id=None, **kwargs) -> ExplainerObservation:
        """Sample a task and return the initial observation (async)."""
        return self._do_reset(seed=seed, episode_id=episode_id, **kwargs)

    async def step_async(self, action: ExplainerAction, timeout_s=None, **kwargs) -> ExplainerObservation:
        """Route to the phase handler (async)."""
        early = self._begin_step()
        if early is not None:
            return early
        task = self._current_task
        try:
            if action.action_type == "explore":
                return await self._handle_explore(action, task)
            return self._dispatch_sync_action(action, task)
        except Exception as e:
            return self._end_episode(task, f"Environment error: {e}")

    # ------------------------------------------------------------------
    # Internal
    # ------------------------------------------------------------------

    def _begin_step(self) -> ExplainerObservation | None:
        """Count the step and return a terminal observation if no action may run.

        Returns None when a task is active and the episode is still open.
        """
        self._state.step_count += 1
        task = self._current_task
        if task is None:
            return ExplainerObservation(
                feedback="Error: no task set. Call reset() first.",
                done=True,
                reward=-1.0,
            )
        if self._done:
            return self._make_obs(
                task,
                phase="done",
                feedback="Episode is already done. Call reset() to start a new one.",
                reward=0.0,
                done=True,
            )
        return None

    def _dispatch_sync_action(self, action: ExplainerAction, task: Task) -> ExplainerObservation:
        """Handle the action types that need no awaiting (generate, repair, unknown)."""
        if action.action_type == "generate":
            return self._handle_generate(action, task)
        if action.action_type == "repair":
            return self._handle_repair(action, task)
        return self._end_episode(task, f"Unknown action_type: {action.action_type}")

    def _end_episode(self, task: Task, feedback: str) -> ExplainerObservation:
        """Mark the episode finished and return a zero-reward terminal observation.

        The internal phase/done flags are updated as well, so later steps get the
        "already done" response. This keeps the env consistent with the done=True
        the agent is shown.
        """
        self._phase = "done"
        self._done = True
        return self._make_obs(task, phase="done", feedback=feedback, reward=0.0, done=True)

    def _do_reset(self, seed=None, episode_id=None, **kwargs) -> ExplainerObservation:
        """Shared reset logic (no I/O, so sync is fine).

        Recognised kwargs:
            topic: pick the task whose topic equals this value exactly; if none
                matches, a random task from ALL_TASKS is used instead.
            difficulty: "easy" / "medium" / "hard" pool to sample when no topic
                is given; any other value samples the default pool.
        """
        self._state = State(episode_id=episode_id or str(uuid4()), step_count=0)
        self._reset_episode_state()
        self._current_task = self._select_task(seed, kwargs.get("topic"), kwargs.get("difficulty"))

        task = self._current_task
        return ExplainerObservation(
            topic=task.topic,
            content=task.content,
            tier=task.tier,
            keywords=task.keywords,
            data_available=task.data_available,
            difficulty=task.difficulty,
            phase="explore",
            feedback=(
                "Research phase: choose a tool and query relevant to the topic. "
                f"Available tools: {', '.join(AVAILABLE_TOOLS)}."
            ),
            search_results="",
            explored_context="",
            explore_steps_left=MAX_EXPLORE_STEPS,
            repair_attempts_left=MAX_REPAIR_STEPS,
            available_tools=list(AVAILABLE_TOOLS),
            done=False,
            reward=0.0,
        )

    def _select_task(self, seed, topic, difficulty) -> Task:
        """Choose the episode's task by exact topic, else by difficulty pool.

        random.Random(None) seeds from system entropy, so one constructor covers
        both the seeded and unseeded cases.
        """
        if topic:
            match = next((t for t in ALL_TASKS if t.topic == topic), None)
            if match is not None:
                return match
            # Unknown topic: fall back to any task.
            return random.Random(seed).choice(ALL_TASKS)

        pools = {"easy": EASY_TASKS, "medium": MEDIUM_TASKS, "hard": HARD_TASKS}
        if isinstance(difficulty, str) and difficulty in pools:
            pool = pools[difficulty]
        else:
            pool = self._difficulty_pool
        rng = random.Random(seed)
        # An empty pool falls back to the full task bank.
        return rng.choice(pool) if pool else rng.choice(ALL_TASKS)

    async def _handle_explore(self, action: ExplainerAction, task: Task) -> ExplainerObservation:
        """Process an explore action: call a research tool and score the result."""
        if self._phase not in {"explore", "generate"}:
            return self._make_obs(
                task,
                phase=self._phase,
                feedback=f"Cannot explore during phase '{self._phase}'.",
                reward=0.0,
            )

        if self._explore_steps >= MAX_EXPLORE_STEPS:
            self._phase = "generate"
            return self._make_obs(
                task,
                phase="generate",
                feedback="Max explore steps reached. You must now generate.",
                reward=0.0,
            )

        # The step is consumed before validation, so an empty query still uses budget.
        self._explore_steps += 1
        query = (action.query or "").strip()
        intent = (action.intent or "").strip()
        tool = action.tool or "search_wikipedia"

        if not query:
            return self._make_obs(
                task,
                phase="explore",
                feedback="Empty query. Provide a search query.",
                reward=0.0,
            )

        # Snapshot pre-step state; the reward function receives both the
        # before view (previous_*) and the after view (accumulated_context).
        previous_context = list(self._accumulated_context)
        previous_actions = list(self._explore_actions)
        used_tools = set(self._used_tools)

        result = await run_research_tool(tool, query, intent)
        results_text = result.render()
        self._explore_actions.append(_explore_action_text(tool, query, intent))
        if result.ok:
            self._accumulated_context.append(result.text)
            self._used_tools.add(tool)

        reward, components = compute_explore_reward(
            query=query,
            tool=tool,
            intent=intent,
            result=result,
            topic=task.topic,
            keywords_csv=task.keywords,
            task_content=task.content,
            difficulty=task.difficulty,
            previous_context=previous_context,
            accumulated_context=self._accumulated_context,
            used_tools=used_tools,
            previous_actions=previous_actions,
        )

        steps_left = MAX_EXPLORE_STEPS - self._explore_steps
        if steps_left > 1:
            phase = "explore"
            hint = f"Research going well — {steps_left} more steps available. Keep searching or move to generation."
        elif steps_left == 1:
            phase = "explore"
            hint = "Last research step available. Search for any missing context, or proceed to generate."
        else:
            phase = "generate"
            hint = "Research phase complete. Time to generate your explanation."
        self._phase = phase

        top_chunks = _top_chunks_payload(result.chunks)
        return self._make_obs(
            task,
            phase=phase,
            feedback=f"{hint}\nTool: {tool}\nReward: {components}",
            search_results=results_text,
            top_chunks=top_chunks,
            reward=reward,
            metadata={
                "step": self._state.step_count,
                "phase": "explore",
                "tool": tool,
                "source_count": len(result.chunks),
                "top_chunks": top_chunks,
                "error": result.error,
                **components,
            },
        )

    def _evaluate_artifact(self, fmt: str, code: str, narration: str, task: Task):
        """Run the sandbox on *code* and score it; return (sandbox, reward, components)."""
        sandbox = validate_code(fmt, code)
        reward, components = compute_generate_reward(
            code=code,
            fmt=fmt,
            narration=narration,
            task=task,
            exec_success=sandbox.exec_success,
            accumulated_context=self._accumulated_context,
            static_check_passed=sandbox.check_passed,
            error_codes=sandbox.error_codes,
        )
        return sandbox, reward, components

    def _record_artifact(self, fmt: str, code: str, narration: str, sandbox) -> str:
        """Remember the latest artifact and its errors; return the rendered error text."""
        self._last_code = code
        self._last_format = fmt
        self._last_narration = narration
        rendered_errors = _render_errors_with_hints(sandbox.render_errors(), sandbox.error_codes)
        self._last_errors = rendered_errors
        self._last_error_codes = sandbox.error_codes
        return rendered_errors

    def _handle_generate(self, action: ExplainerAction, task: Task) -> ExplainerObservation:
        """Process a generate action: run sandbox, maybe open repair phase."""
        if self._phase not in {"explore", "generate"}:
            return self._make_obs(
                task,
                phase=self._phase,
                feedback=f"Cannot generate during phase '{self._phase}'.",
                reward=0.0,
            )

        fmt = action.format or "marimo"
        code = action.code
        narration = action.narration

        # Penalise generating without any exploration.
        skipped_research = self._explore_steps == 0
        skip_penalty = -0.1 if skipped_research else 0.0

        sandbox, reward, components = self._evaluate_artifact(fmt, code, narration, task)
        reward = clamp_action_reward(reward + skip_penalty)
        components["generate_total"] = round(reward, 4)
        rendered_errors = self._record_artifact(fmt, code, narration, sandbox)

        parts = []
        if skipped_research:
            parts.append("Warning: generating without any research. -0.1 penalty.")
        if not sandbox.parses:
            parts.append("SYNTAX ERROR: code does not parse.")
        elif not sandbox.exec_success:
            parts.append(f"EXECUTION FAILED: {rendered_errors}")
        else:
            parts.append(f"EXECUTION OK: {sandbox.message}")
        parts.append(_format_reward_line(components))

        done = sandbox.exec_success or self._repair_steps >= MAX_REPAIR_STEPS
        phase = "done" if done else "repair"
        self._phase = phase
        self._done = done
        if not done:
            parts.append(
                f"Repair phase: {MAX_REPAIR_STEPS} attempts available. "
                "Submit a revised artifact using the error feedback."
            )

        return self._make_obs(
            task,
            phase=phase,
            feedback="\n".join(parts),
            reward=reward,
            done=done,
            last_errors="" if sandbox.exec_success else rendered_errors,
            metadata={
                "step": self._state.step_count,
                "phase": "generate",
                "explore_steps_used": self._explore_steps,
                "sandbox_message": sandbox.message,
                "error_codes": sandbox.error_codes,
                **components,
            },
        )

    def _handle_repair(self, action: ExplainerAction, task: Task) -> ExplainerObservation:
        """Process one repair attempt after a failed generate action."""
        if self._phase != "repair":
            return self._make_obs(
                task,
                phase=self._phase,
                feedback="Repair is only available after a failed generate step.",
                reward=0.0,
                done=self._done,
            )

        if self._repair_steps >= MAX_REPAIR_STEPS:
            self._phase = "done"
            self._done = True
            return self._make_obs(
                task,
                phase="done",
                feedback="No repair attempts left.",
                reward=0.0,
                done=True,
            )

        self._repair_steps += 1
        # Format and narration default to the previous submission when omitted.
        fmt = action.format or self._last_format or "marimo"
        code = action.code
        narration = action.narration or self._last_narration
        # Captured before _record_artifact overwrites them.
        previous_code = self._last_code
        previous_errors = list(self._last_error_codes)

        sandbox, base_reward, components = self._evaluate_artifact(fmt, code, narration, task)
        repair_reward, repair_components = adjust_repair_reward(
            base_reward,
            repair_success=sandbox.exec_success,
            previous_error_codes=previous_errors,
            new_error_codes=sandbox.error_codes,
            previous_code=previous_code,
            repaired_code=code,
        )
        components.update(repair_components)
        rendered_errors = self._record_artifact(fmt, code, narration, sandbox)

        attempts_left = MAX_REPAIR_STEPS - self._repair_steps
        done = sandbox.exec_success or attempts_left <= 0
        phase = "done" if done else "repair"
        self._phase = phase
        self._done = done

        status = "REPAIR OK" if sandbox.exec_success else "REPAIR FAILED"
        detail = sandbox.message if sandbox.exec_success else rendered_errors
        feedback_parts = [f"{status}: {detail}", _format_reward_line(components)]
        if not done:
            feedback_parts.append(
                f"Repair phase continues: {attempts_left} repair attempts left. "
                "Submit another corrected artifact using the latest error feedback."
            )

        return self._make_obs(
            task,
            phase=phase,
            feedback="\n".join(feedback_parts),
            reward=repair_reward,
            done=done,
            last_errors="" if sandbox.exec_success else rendered_errors,
            metadata={
                "step": self._state.step_count,
                "phase": "repair",
                "explore_steps_used": self._explore_steps,
                "repair_steps_used": self._repair_steps,
                "sandbox_message": sandbox.message,
                "error_codes": sandbox.error_codes,
                **components,
            },
        )

    def _make_obs(
        self,
        task: Task,
        *,
        phase: str,
        feedback: str,
        reward: float = 0.0,
        done: bool = False,
        search_results: str = "",
        top_chunks: list[dict] | None = None,
        last_errors: str | None = None,
        metadata: dict | None = None,
    ) -> ExplainerObservation:
        """Build an observation from *task* plus the current episode state.

        ``last_errors=None`` means "use the most recently recorded errors".
        """
        return ExplainerObservation(
            topic=task.topic,
            content=task.content,
            tier=task.tier,
            keywords=task.keywords,
            data_available=task.data_available,
            difficulty=task.difficulty,
            phase=phase,
            feedback=feedback,
            search_results=search_results,
            top_chunks=top_chunks or [],
            explored_context="\n---\n".join(self._accumulated_context),
            explore_steps_left=MAX_EXPLORE_STEPS - self._explore_steps,
            repair_attempts_left=MAX_REPAIR_STEPS - self._repair_steps,
            last_errors=self._last_errors if last_errors is None else last_errors,
            available_tools=list(AVAILABLE_TOOLS),
            done=done,
            reward=reward,
            metadata=metadata or {},
        )

    @property
    def state(self) -> State:
        return self._state


def _explore_action_text(tool: str, query: str, intent: str) -> str:
    """Join tool, query and intent into one stripped string describing an explore action."""
    return f"{tool} {query.strip()} {intent.strip()}".strip()


def _top_chunks_payload(chunks) -> list[dict]:
    """Serialise the first five result chunks into plain dicts for the observation."""
    return [
        {
            "rank": chunk.rank,
            "source": chunk.source,
            "title": chunk.title,
            "url": chunk.url,
            "score": round(chunk.score, 4),
            "snippet": chunk.text,
        }
        for chunk in chunks[:5]
    ]