Spaces:
Running
Running
"""
Skill Invocation Environment Implementation.

Trains LLMs to decide WHEN to invoke procedural knowledge (skills) during
task-solving. Context cost model: each loaded skill costs context budget.

Reward has two distinct cost signals:
- Context hygiene (bloat_penalty): penalizes irrelevant skills still loaded at
  submit time (-0.15 per skill).
- Token efficiency (token_waste_penalty): penalizes skills that were ever loaded
  but turned out to be irrelevant, even if unloaded before submission (-0.05 per
  skill). This captures cumulative token waste across the episode.

Actions: list, load, unload, submit (plus "invoke" as backward-compat alias for load).
"""
| import random | |
| from typing import Optional | |
| from uuid import uuid4 | |
| from openenv.core.env_server.interfaces import Environment | |
| from openenv.core.env_server.types import State | |
| from models import SkillInvocationAction, SkillInvocationObservation, SkillInvocationState | |
| from task_bank import TASK_BANK, SKILL_BANK | |
| from task_generator import TaskGenerator | |
# Default number of skills that may be loaded into context at the same time.
DEFAULT_CONTEXT_BUDGET = 5


class SkillInvocationEnvironment(Environment):
    """
    RL environment for training skill invocation decisions.

    Episodes:
    1. reset() samples a task, assembles skill catalog (relevant + distractors)
    2. Agent can list, load, and unload skills (within context budget)
    3. Agent submits a solution
    4. Reward = correctness + precision + recall - bloat - token_waste

    The reward weights are exposed as class attributes so a subclass can
    retune them without overriding ``_handle_submit``.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    # Reward shaping weights applied once, at submit time.
    CORRECTNESS_REWARD: float = 0.6
    PRECISION_WEIGHT: float = 0.3
    RECALL_WEIGHT: float = 0.1
    BLOAT_PENALTY_PER_SKILL: float = 0.15
    TOKEN_WASTE_PENALTY_PER_SKILL: float = 0.05
    MIN_REWARD: float = -1.0

    def __init__(
        self,
        use_procedural: bool = False,
        procedural_seed: int = 0,
        context_budget: int = DEFAULT_CONTEXT_BUDGET,
    ):
        """Create an environment with no active task; call reset() to start.

        Args:
            use_procedural: When true, tasks come from a ``TaskGenerator``
                instead of the static ``TASK_BANK``.
            procedural_seed: Seed passed to the ``TaskGenerator``.
            context_budget: Maximum number of simultaneously loaded skills.
        """
        super().__init__()
        self._state = SkillInvocationState(episode_id=str(uuid4()), step_count=0)
        self._current_task: Optional[dict] = None
        self._catalog_skill_ids: list[str] = []
        self._messages: list[str] = []
        self._use_procedural = use_procedural
        self._task_generator = TaskGenerator(seed=procedural_seed) if use_procedural else None
        self._episode_skills: dict = {}
        self._context_budget = context_budget
        # Per-instance RNG to avoid mutating global random state (concurrency-safe)
        self._rng = random.Random()

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs,
    ) -> SkillInvocationObservation:
        """Sample a task, assemble its skill catalog and start a new episode.

        Args:
            seed: Optional seed; makes task choice and catalog order
                reproducible. In procedural mode it is also the generator seed.
            episode_id: Optional episode identifier; a UUID4 is generated
                when omitted.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            The initial observation (reward 0.0, not done).

        Raises:
            KeyError: If the task references a skill id that is missing from
                the episode's skill table. Instance state is left untouched
                in that case.
        """
        # Fresh per-episode RNG so parallel rollouts cannot clobber each
        # other's seeds through the global `random` module.
        self._rng = random.Random(seed) if seed is not None else random.Random()

        if self._use_procedural and self._task_generator:
            gen_seed = seed if seed is not None else self._rng.randint(0, 2**31)
            generated = self._task_generator.generate_with_seed(gen_seed)
            task = generated["task"]
            episode_skills = generated["skills"]
        else:
            task = self._rng.choice(TASK_BANK)
            episode_skills = SKILL_BANK

        # Build catalog: relevant + distractor skills, shuffled
        catalog_ids = list(task["relevant_skills"]) + list(task["distractor_skills"])
        self._rng.shuffle(catalog_ids)

        # Fail fast on a misconfigured task *before* committing anything, so
        # a failed reset does not leave the environment half-updated.
        missing = [sid for sid in catalog_ids if sid not in episode_skills]
        if missing:
            raise KeyError(
                f"Task {task['id']} references unknown skill ids: {missing}"
            )

        self._current_task = task
        self._episode_skills = episode_skills
        self._catalog_skill_ids = catalog_ids

        # Initialize state
        self._state = SkillInvocationState(
            episode_id=episode_id or str(uuid4()),
            step_count=0,
            task_id=task["id"],
            loaded_skills=[],
            skills_ever_loaded=[],
            skills_invoked=[],
            difficulty=task["difficulty"],
            done=False,
            context_budget_total=self._context_budget,
            remaining_invocations=self._context_budget,
        )
        self._messages = [f"Episode started. Task: {task['id']} ({task['difficulty']})"]
        return self._make_observation(
            skill_content=None,
            reward=0.0,
            done=False,
        )

    def step(
        self,
        action: SkillInvocationAction,
        timeout_s: Optional[float] = None,
        **kwargs,
    ) -> SkillInvocationObservation:
        """Process a list, load, unload, or submit action.

        ``"invoke"`` is accepted as a backward-compatible alias for ``"load"``.
        Every call increments the step counter, including calls made after the
        episode has finished.

        Args:
            action: The agent's action; ``action_type`` selects the handler.
            timeout_s: Accepted for interface compatibility and ignored.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            The observation produced by the selected handler. Only ``submit``
            ends the episode and carries a non-zero reward.
        """
        self._state.step_count += 1
        if self._state.done:
            self._messages.append("Episode already done. Call reset().")
            return self._make_observation(
                skill_content=None,
                verification_result="Episode already finished.",
                reward=0.0,
                done=True,
            )

        action_type = action.action_type
        # Backward compat: "invoke" is an alias for "load"
        if action_type == "invoke":
            action_type = "load"

        if action_type == "list":
            return self._handle_list(action)
        if action_type == "load":
            return self._handle_load(action)
        if action_type == "unload":
            return self._handle_unload(action)
        if action_type == "submit":
            return self._handle_submit(action)
        return self._notice(f"Unknown action_type: {action.action_type}")

    def _notice(self, message: str) -> SkillInvocationObservation:
        """Record ``message`` and return a non-terminal, zero-reward observation."""
        self._messages.append(message)
        return self._make_observation(skill_content=None, reward=0.0, done=False)

    def _sync_remaining_budget(self) -> None:
        """Recompute ``remaining_invocations`` from the currently loaded skills."""
        self._state.remaining_invocations = (
            self._state.context_budget_total - len(self._state.loaded_skills)
        )

    def _handle_list(self, action: SkillInvocationAction) -> SkillInvocationObservation:
        """Return the skill catalog without changing which skills are loaded.

        The catalog is part of every observation; this handler exists so the
        documented ``list`` action is not reported as unknown.
        """
        return self._notice(
            f"Listed {len(self._catalog_skill_ids)} skills in catalog."
        )

    def _handle_load(self, action: SkillInvocationAction) -> SkillInvocationObservation:
        """Load a skill into context and return its full content.

        Rejected (zero reward, nothing changes) when the skill id is missing,
        unknown, not in the current catalog, or the context budget is full.
        Loading an already-loaded skill is a no-op that still returns content.
        """
        skill_id = action.skill_id
        if not skill_id:
            return self._notice("load action requires skill_id")
        if skill_id not in self._episode_skills:
            return self._notice(f"Unknown skill_id: {skill_id}")
        if skill_id not in self._catalog_skill_ids:
            return self._notice(f"Skill {skill_id} not in current catalog.")

        state = self._state
        skill = self._episode_skills[skill_id]
        full_content = skill["full_content"]

        # Already loaded — no-op, but still return content
        if skill_id in state.loaded_skills:
            self._messages.append(f"Skill {skill_id} already loaded.")
            return self._make_observation(skill_content=full_content, reward=0.0, done=False)

        # Check context budget
        if len(state.loaded_skills) >= state.context_budget_total:
            return self._notice(
                f"Context budget full ({state.context_budget_total} skills loaded). "
                "Unload a skill first."
            )

        # Load skill
        state.loaded_skills.append(skill_id)
        if skill_id not in state.skills_ever_loaded:
            state.skills_ever_loaded.append(skill_id)
        # Backward compat: skills_invoked mirrors skills_ever_loaded
        state.skills_invoked = list(state.skills_ever_loaded)
        self._sync_remaining_budget()

        self._messages.append(
            f"Loaded skill '{skill['name']}' ({skill_id}). "
            f"Context: {len(state.loaded_skills)}/{state.context_budget_total}"
        )
        return self._make_observation(
            skill_content=full_content,
            reward=0.0,
            done=False,
        )

    def _handle_unload(self, action: SkillInvocationAction) -> SkillInvocationObservation:
        """Unload a skill from context, freeing one slot of the budget.

        The skill stays in ``skills_ever_loaded``, so it still counts toward
        the token-waste penalty at submit time if it is irrelevant.
        """
        skill_id = action.skill_id
        if not skill_id:
            return self._notice("unload action requires skill_id")

        state = self._state
        if skill_id not in state.loaded_skills:
            return self._notice(f"Skill {skill_id} is not currently loaded.")

        state.loaded_skills.remove(skill_id)
        self._sync_remaining_budget()

        skill_name = self._episode_skills[skill_id]["name"]
        return self._notice(
            f"Unloaded skill '{skill_name}' ({skill_id}). "
            f"Context: {len(state.loaded_skills)}/{state.context_budget_total}"
        )

    def _handle_submit(self, action: SkillInvocationAction) -> SkillInvocationObservation:
        """Handle a solution submission and end the episode.

        Reward = correctness + precision + recall - bloat - token_waste,
        floored at ``MIN_REWARD``.

        Two distinct cost signals:
        - bloat_penalty (-0.15 per skill): penalizes irrelevant skills still
          loaded at submit time (context hygiene).
        - token_waste_penalty (-0.05 per skill): penalizes skills that were ever
          loaded but turned out irrelevant, capturing cumulative token waste
          across the episode (token efficiency).

        Submitting before reset() has selected a task is rejected with a
        message instead of raising.
        """
        task = self._current_task
        if task is None:
            return self._notice("No active task. Call reset() before submitting.")

        answer = action.answer or ""

        # Run deterministic verifier. A verifier that raises on malformed
        # input is deliberately treated as an incorrect answer, not a crash.
        try:
            task_correct = task["verifier"](answer)
        except Exception:
            task_correct = False

        # Compute reward
        loaded = set(self._state.loaded_skills)
        ever_loaded = set(self._state.skills_ever_loaded)
        relevant = set(task["relevant_skills"])
        relevant_loaded = loaded & relevant

        # 1. Correctness
        correctness = self.CORRECTNESS_REWARD if task_correct else 0.0

        # 2. Precision: what fraction of loaded skills are relevant?
        #    Submitting with nothing loaded earns no precision bonus.
        precision = len(relevant_loaded) / len(loaded) if loaded else 0.0
        precision_bonus = self.PRECISION_WEIGHT * precision

        # 3. Recall: did you load all relevant skills?
        #    A task with no relevant skills counts as full recall.
        recall = len(relevant_loaded) / len(relevant) if relevant else 1.0
        recall_bonus = self.RECALL_WEIGHT * recall

        # 4. Bloat: penalty for unnecessary skills loaded at submit time
        bloat_penalty = -self.BLOAT_PENALTY_PER_SKILL * len(loaded - relevant)

        # 5. Token waste: penalty for skills ever loaded that were irrelevant
        token_waste_penalty = -self.TOKEN_WASTE_PENALTY_PER_SKILL * len(ever_loaded - relevant)

        total_reward = correctness + precision_bonus + recall_bonus + bloat_penalty + token_waste_penalty
        total_reward = max(total_reward, self.MIN_REWARD)

        self._state.done = True
        verification_msg = (
            f"{'CORRECT' if task_correct else 'INCORRECT'}. "
            f"Reward: correctness={correctness:.2f}, "
            f"precision={precision_bonus:.2f}, recall={recall_bonus:.2f}, "
            f"bloat={bloat_penalty:.2f}, token_waste={token_waste_penalty:.2f}, "
            f"total={total_reward:.2f}"
        )
        self._messages.append(f"Submitted answer. {verification_msg}")
        return self._make_observation(
            skill_content=None,
            verification_result=verification_msg,
            reward=total_reward,
            done=True,
        )

    def _make_observation(
        self,
        skill_content: Optional[str],
        reward: float,
        done: bool,
        verification_result: Optional[str] = None,
    ) -> SkillInvocationObservation:
        """Build an observation from current state.

        Args:
            skill_content: Full content of the skill touched by this step, if any.
            reward: Reward to report for this step.
            done: Whether the episode has ended.
            verification_result: Verifier summary, set on submit or after done.

        Returns:
            An observation with the task description, the short-form catalog,
            the loaded skills and their contents, budget counters and a copy
            of the message log.
        """
        task = self._current_task
        skills = self._episode_skills
        loaded_ids = self._state.loaded_skills
        budget_total = self._state.context_budget_total

        # Short descriptions only; full content is exposed via load.
        catalog = []
        if task:
            catalog = [
                {
                    "id": sid,
                    "name": skills[sid]["name"],
                    "description": skills[sid]["short_description"],
                }
                for sid in self._catalog_skill_ids
                if sid in skills
            ]

        # Build loaded skill contents
        loaded_contents = {
            sid: skills[sid]["full_content"] for sid in loaded_ids if sid in skills
        }

        return SkillInvocationObservation(
            task_description=task["description"] if task else "",
            skill_catalog=catalog,
            difficulty=self._state.difficulty,
            loaded_skills=list(loaded_ids),
            loaded_skill_contents=loaded_contents,
            context_budget_used=len(loaded_ids),
            context_budget_total=budget_total,
            skill_content=skill_content,
            remaining_invocations=budget_total - len(loaded_ids),
            verification_result=verification_result,
            skills_invoked=list(self._state.skills_ever_loaded),
            messages=list(self._messages),
            done=done,
            reward=reward,
        )

    # NOTE(review): the OpenEnv base interface may expect `state` to be a
    # property rather than a method — confirm against `Environment` before
    # adding or omitting a @property decorator here.
    def state(self) -> SkillInvocationState:
        """Get current episode state."""
        return self._state