Spaces:
Sleeping
Sleeping
| """ | |
| ClarifyEnvironment — OpenEnv MCPEnvironment for the ClarifyRL task. | |
| Three MCP tools: | |
| - `get_task_info()` — free, returns the original ambiguous request and meta | |
| - `ask_question(question)` — costs 1 from the 6-question budget | |
| - `propose_plan(plan)` — terminal; runs the composable rubric | |
| """ | |
| from __future__ import annotations | |
| import random | |
| from typing import Any, Optional | |
| from fastmcp import FastMCP | |
| from openenv.core.env_server.interfaces import EnvironmentMetadata | |
| from openenv.core.env_server.mcp_environment import MCPEnvironment | |
| from openenv.core.env_server.mcp_types import CallToolAction, CallToolObservation | |
| from server.grader import ( | |
| PENALTY_OVER_CAP, | |
| ask_question_reward, | |
| parse_plan, | |
| ) | |
| from server.rubrics import RubricContext, build_rubric, score_breakdown | |
| from server.scenarios import Scenario, generate | |
| from server.user_simulator import answer | |
| from models import ClarifyState | |
# Agent-facing guidance text. It is returned verbatim under the "instructions"
# key by both `reset()` and the `get_task_info` tool.
_INSTRUCTIONS = " ".join(
    (
        "Ask clarifying questions via ask_question(question) — you have a 6-question budget.",
        "Then submit your final plan via propose_plan(plan) where plan is a JSON string"
        " object containing the required keys for the task family.",
        "Avoid hallucinating values for fields you never asked about.",
    )
)
class ClarifyEnvironment(MCPEnvironment):
    """MCP environment in which an agent clarifies an ambiguous request.

    Three tools are registered on the embedded FastMCP server:

    - ``get_task_info()``: returns the request and task metadata with zero
      reward.
    - ``ask_question(question)``: consumes one unit of the question budget
      and returns the simulated user's answer.
    - ``propose_plan(plan)``: terminal; scores the plan with the rubric.

    Each tool records its outcome in ``_last_step_reward`` and
    ``_last_step_done``; ``step`` / ``step_async`` then copy those values
    onto the observation produced by the base class.
    """

    # All state is per-instance (`_scenario`, `_asked_field_keys`, `_public_state`,
    # `_last_step_reward`). The grader/rubric/scenarios modules are pure functions
    # of their inputs, so a fresh instance per WebSocket session is independent
    # and safe. Required so multiple parallel HF Jobs runs (and TRL's
    # num_generations > 1) do not contend on a single shared session slot.
    SUPPORTS_CONCURRENT_SESSIONS = True

    def __init__(self, max_questions: int = 6) -> None:
        """Build the MCP server, register the tools and initialise state.

        Args:
            max_questions: Stored as ``_default_max_questions``. Nothing in
                this class reads it back; the per-episode budget is taken
                from the scenario's ``max_questions`` in ``reset``.
        """
        mcp_server = FastMCP("clarify_rl")

        # The tool functions are thin closures over `self` that forward to the
        # `_tool_*` methods. They deliberately carry no docstrings: FastMCP
        # would expose a docstring as the tool description shown to clients,
        # which would change what the agent sees.
        def get_task_info() -> dict[str, Any]:
            return self._tool_get_task_info()

        def ask_question(question: str) -> dict[str, Any]:
            return self._tool_ask_question(question)

        def propose_plan(plan: str) -> dict[str, Any]:
            return self._tool_propose_plan(plan)

        mcp_server.tool()(get_task_info)
        mcp_server.tool()(ask_question)
        mcp_server.tool()(propose_plan)

        super().__init__(mcp_server=mcp_server)

        self.rubric = build_rubric()
        self._default_max_questions: int = max_questions
        self._scenario: Optional[Scenario] = None
        # Keys of hidden-profile fields the agent has already uncovered.
        self._asked_field_keys: set[str] = set()
        self._public_state: ClarifyState = ClarifyState()
        # Outcome of the most recent tool call, consumed by `_patch_obs`.
        self._last_step_reward: float = 0.0
        self._last_step_done: bool = False

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> CallToolObservation:
        """Start a new episode and return the initial task observation.

        Args:
            seed: Scenario seed; a random one in ``[0, 10**9]`` is drawn
                when omitted.
            episode_id: Identifier recorded in the public state.
            **kwargs: Only ``task_id`` is read (default ``"medium"``).

        Returns:
            A non-terminal, zero-reward observation whose ``result`` holds
            the request, task metadata, budget and instructions.
        """
        task_id = kwargs.get("task_id", "medium")
        if seed is None:
            seed = random.randint(0, 10**9)

        sc = generate(seed=seed, task_id=task_id)
        self._scenario = sc
        self._asked_field_keys = set()
        self._last_step_reward = 0.0
        self._last_step_done = False

        self._public_state = ClarifyState(
            episode_id=episode_id,
            step_count=0,
            task_id=sc["task_id"],
            task_title=sc["task_title"],
            questions_asked=[],
            questions_remaining=sc["max_questions"],
            answers_received=[],
            fields_revealed=[],
            plan_submitted=False,
            episode_done=False,
            final_score=None,
            score_breakdown=None,
        )

        result = {
            "type": "task",
            "request": sc["request"],
            "task_id": sc["task_id"],
            "task_title": sc["task_title"],
            "family": sc["family"],
            "max_steps": sc["max_steps"],
            "questions_remaining": sc["max_questions"],
            "instructions": _INSTRUCTIONS,
        }
        return CallToolObservation(
            tool_name="reset",
            result=result,
            done=False,
            reward=0.0,
        )

    def _reset_step_outcome(self) -> None:
        """Clear the per-step reward/done record before a new step runs.

        The tool bodies are the only writers of ``_last_step_reward`` and
        ``_last_step_done``. Without this reset, a tool call that never
        reaches a tool body (unknown tool name, argument validation failure)
        would be credited with the previous step's reward.
        """
        self._last_step_reward = 0.0
        # A finished episode stays finished even if the call fails early.
        self._last_step_done = self._public_state.episode_done

    def _patch_obs(self, obs: CallToolObservation, action: Any) -> CallToolObservation:
        """Copy the recorded tool outcome onto ``obs`` and advance the step count.

        Observations for actions that are not ``CallToolAction`` are returned
        untouched. The episode is also force-terminated once ``step_count``
        reaches the scenario's ``max_steps``.
        """
        if not isinstance(action, CallToolAction):
            return obs

        obs.reward = self._last_step_reward
        obs.done = self._last_step_done
        self._public_state.step_count = self._public_state.step_count + 1

        if self._last_step_done:
            self._public_state.episode_done = True

        sc = self._scenario
        if sc and self._public_state.step_count >= sc["max_steps"] and not obs.done:
            # Step budget exhausted without a terminal tool call.
            obs.done = True
            self._public_state.episode_done = True
        return obs

    def step(
        self,
        action: Any,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> CallToolObservation:
        """Run one action through the base class, then attach reward/done."""
        self._reset_step_outcome()
        obs = super().step(action, timeout_s=timeout_s, **kwargs)
        return self._patch_obs(obs, action)

    async def step_async(
        self,
        action: Any,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> CallToolObservation:
        """Async counterpart of ``step`` with identical post-processing."""
        self._reset_step_outcome()
        obs = await super().step_async(action, timeout_s=timeout_s, **kwargs)
        return self._patch_obs(obs, action)

    def _step_impl(
        self,
        action: Any,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> CallToolObservation:
        """Return an empty, non-terminal, zero-reward observation.

        NOTE(review): presumably the base class calls this hook for actions
        it does not route to an MCP tool; confirm against MCPEnvironment.
        """
        del timeout_s, kwargs  # intentionally unused
        return CallToolObservation(
            tool_name=getattr(action, "tool_name", "unknown"),
            result=None,
            done=False,
            reward=0.0,
        )

    # NOTE(review): defined as a plain method here. If the OpenEnv base class
    # declares `state` as a property, a `@property` decorator may be missing;
    # confirm before changing, since callers may rely on `env.state()`.
    def state(self) -> ClarifyState:
        """Return the public episode state object."""
        return self._public_state

    def get_metadata(self) -> EnvironmentMetadata:
        """Return static descriptive metadata for this environment."""
        return EnvironmentMetadata(
            name="ClarifyRL — AskBeforeYouAct",
            description=(
                "Train LLMs to ask clarifying questions instead of hallucinating. "
                "Five task families (coding / medical-intake / support-triage / meeting / event), "
                "rule-based simulator, composable rubric."
            ),
            version="0.1.0",
            author="Team Bhole Chature",
        )

    def _require_scenario(self) -> Scenario:
        """Return the active scenario.

        Raises:
            RuntimeError: If ``reset()`` has not produced a scenario yet.
        """
        if self._scenario is None:
            raise RuntimeError("Environment must be reset() before tool calls.")
        return self._scenario

    def _guard_episode_done(self) -> Optional[dict[str, Any]]:
        """Return an error payload if the episode is over, else ``None``.

        When the episode has ended, the step outcome is also recorded as
        terminal with zero reward.
        """
        if self._public_state.episode_done:
            self._last_step_reward = 0.0
            self._last_step_done = True
            return {"error": "episode already ended", "episode_done": True}
        return None

    def _tool_get_task_info(self) -> dict[str, Any]:
        """Implement ``get_task_info``: a zero-reward view of the task."""
        sc = self._require_scenario()
        blocked = self._guard_episode_done()
        if blocked:
            return blocked

        self._last_step_reward = 0.0
        self._last_step_done = False
        return {
            "request": sc["request"],
            "task_id": sc["task_id"],
            "task_title": sc["task_title"],
            "family": sc["family"],
            "questions_remaining": self._public_state.questions_remaining,
            "instructions": _INSTRUCTIONS,
        }

    def _tool_ask_question(self, question: str) -> dict[str, Any]:
        """Implement ``ask_question``: answer one question and charge the budget.

        The question is truncated to 200 characters. Asking with no budget
        left ends the episode with ``PENALTY_OVER_CAP``. Otherwise the
        simulator's answer is recorded and the reward depends on whether the
        matched field is new or a duplicate.
        """
        sc = self._require_scenario()
        st = self._public_state
        blocked = self._guard_episode_done()
        if blocked:
            return blocked

        question = question[:200]

        if st.questions_remaining <= 0:
            self._last_step_reward = PENALTY_OVER_CAP
            self._last_step_done = True
            return {
                "answer": "(no more questions allowed)",
                "questions_remaining": 0,
                "field_revealed": None,
                "duplicate": False,
                "over_cap": True,
            }

        text, matched = answer(question, sc["hidden_profile"], sc["family"])
        # A question counts as a duplicate only when it maps to a field that
        # was already revealed; unmatched questions are neither new nor duplicate.
        is_duplicate = matched is not None and matched in self._asked_field_keys
        revealed_new = matched is not None and not is_duplicate

        if revealed_new:
            self._asked_field_keys.add(matched)
            st.fields_revealed = sorted(self._asked_field_keys)

        st.questions_asked = st.questions_asked + [question]
        st.answers_received = st.answers_received + [text]
        st.questions_remaining = st.questions_remaining - 1

        self._last_step_reward = ask_question_reward(
            over_cap=False,
            is_duplicate_field=is_duplicate,
            revealed_new_field=revealed_new,
        )
        self._last_step_done = False
        return {
            "answer": text,
            "questions_remaining": st.questions_remaining,
            "field_revealed": matched if revealed_new else None,
            "duplicate": is_duplicate,
            "over_cap": False,
        }

    def _tool_propose_plan(self, plan: str) -> dict[str, Any]:
        """Implement ``propose_plan``: score the plan and end the episode.

        The plan string is parsed, then scored by the rubric. The score
        becomes the step reward, and the resolution payload also exposes the
        hidden profile and the per-component breakdown.
        """
        sc = self._require_scenario()
        st = self._public_state
        blocked = self._guard_episode_done()
        if blocked:
            return blocked

        parsed, parse_err = parse_plan(plan)
        ctx = RubricContext(
            family=sc["family"],
            hidden_profile=sc["hidden_profile"],
            critical_fields=frozenset(sc["critical_fields"]),
            required_keys=tuple(sc["required_keys"]),
            asked_field_keys=frozenset(self._asked_field_keys),
            questions_asked_count=len(st.questions_asked),
            max_questions=sc["max_questions"],
            parsed_plan=parsed,
            parse_error=parse_err,
        )
        score = float(self.rubric(action=None, observation=ctx))
        # Read the breakdown only after the rubric has been evaluated.
        breakdown = score_breakdown(self.rubric)

        self._last_step_reward = score
        self._last_step_done = True
        st.plan_submitted = True
        st.episode_done = True
        st.final_score = score
        st.score_breakdown = breakdown

        return {
            "type": "resolution",
            "score": score,
            "breakdown": breakdown,
            "expected_profile": sc["hidden_profile"],
            "critical_fields": list(sc["critical_fields"]),
            "required_keys": list(sc["required_keys"]),
            "submitted_plan": parsed,
            "parse_error": parse_err,
            "questions_asked": len(st.questions_asked),
            "fields_revealed": sorted(self._asked_field_keys),
        }
# Public surface of this module: only the environment class is exported.
__all__ = [
    "ClarifyEnvironment",
]