# clarify-rl/server/clarify_environment.py
# Author: Anurag Agarwal
# Commit: env: enable concurrent rollout sessions (max_concurrent_envs=8)
# Revision: 895f00d
"""
ClarifyEnvironment — OpenEnv MCPEnvironment for the ClarifyRL task.
Three MCP tools:
- `get_task_info()` — free, returns the original ambiguous request and meta
- `ask_question(question)` — costs 1 from the 6-question budget
- `propose_plan(plan)` — terminal; runs the composable rubric
"""
from __future__ import annotations
import random
from typing import Any, Optional
from fastmcp import FastMCP
from openenv.core.env_server.interfaces import EnvironmentMetadata
from openenv.core.env_server.mcp_environment import MCPEnvironment
from openenv.core.env_server.mcp_types import CallToolAction, CallToolObservation
from server.grader import (
PENALTY_OVER_CAP,
ask_question_reward,
parse_plan,
)
from server.rubrics import RubricContext, build_rubric, score_breakdown
from server.scenarios import Scenario, generate
from server.user_simulator import answer
from models import ClarifyState
# Agent-facing usage instructions. The same text is returned under the
# "instructions" key by both reset() and the get_task_info tool.
# NOTE(review): the wording hard-codes a 6-question budget, while the budget
# actually enforced is the scenario's "max_questions" — confirm they agree
# for every task id.
_INSTRUCTIONS = (
"Ask clarifying questions via ask_question(question) — you have a 6-question budget. "
"Then submit your final plan via propose_plan(plan) where plan is a JSON string "
"object containing the required keys for the task family. "
"Avoid hallucinating values for fields you never asked about."
)
class ClarifyEnvironment(MCPEnvironment):
    """MCP environment that runs one ClarifyRL episode per reset().

    An episode begins with :meth:`reset`, which generates a scenario and
    returns the ambiguous request. The agent then calls the MCP tools
    ``get_task_info``, ``ask_question`` and ``propose_plan``. Each tool body
    records the reward and done flag for the step on the instance, and
    :meth:`step` / :meth:`step_async` copy those values onto the observation
    returned by the base class.

    The episode ends when a plan is proposed, when a question is asked with
    no budget left, or when the step count reaches the scenario's
    ``max_steps``.
    """

    # All state is per-instance (`_scenario`, `_asked_field_keys`, `_public_state`,
    # `_last_step_reward`, `_last_step_done`). The grader/rubric/scenarios modules
    # are pure functions of their inputs, so a fresh instance per WebSocket session
    # is independent and safe. Required so multiple parallel HF Jobs runs (and TRL's
    # num_generations > 1) do not contend on a single shared session slot.
    SUPPORTS_CONCURRENT_SESSIONS = True

    def __init__(self, max_questions: int = 6) -> None:
        """Register the three MCP tools and initialise empty episode state.

        Args:
            max_questions: Stored as ``_default_max_questions``.
                NOTE(review): nothing in this class reads that attribute; the
                budget actually enforced is ``scenario["max_questions"]`` set
                in reset(). Confirm whether this argument should cap it.
        """
        mcp_server = FastMCP("clarify_rl")

        # Thin closures so the registered tools carry no `self` parameter.
        # Deliberately left without docstrings: the tool framework may publish
        # a function's docstring as the tool description shown to the agent.
        def get_task_info() -> dict[str, Any]:
            return self._tool_get_task_info()

        def ask_question(question: str) -> dict[str, Any]:
            return self._tool_ask_question(question)

        def propose_plan(plan: str) -> dict[str, Any]:
            return self._tool_propose_plan(plan)

        mcp_server.tool()(get_task_info)
        mcp_server.tool()(ask_question)
        mcp_server.tool()(propose_plan)

        super().__init__(mcp_server=mcp_server)

        self.rubric = build_rubric()
        self._default_max_questions: int = max_questions
        self._scenario: Optional[Scenario] = None
        # Field keys already revealed by the simulator in this episode.
        self._asked_field_keys: set[str] = set()
        self._public_state: ClarifyState = ClarifyState()
        # Reward / done signal written by the most recent tool body and
        # copied onto the observation in _patch_obs().
        self._last_step_reward: float = 0.0
        self._last_step_done: bool = False

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> CallToolObservation:
        """Start a new episode and return the task description.

        Args:
            seed: Scenario seed. When None, a random integer in
                ``[0, 10**9]`` is drawn.
            episode_id: Identifier stored on the public state.
            **kwargs: ``task_id`` selects the scenario variant and defaults
                to ``"medium"``; other keys are ignored.

        Returns:
            A non-terminal observation with reward 0.0 whose ``result`` holds
            the request, task metadata, question budget and instructions.
        """
        task_id = kwargs.get("task_id", "medium")
        if seed is None:
            seed = random.randint(0, 10**9)
        sc = generate(seed=seed, task_id=task_id)

        self._scenario = sc
        self._asked_field_keys = set()
        self._clear_step_signals()
        self._public_state = ClarifyState(
            episode_id=episode_id,
            step_count=0,
            task_id=sc["task_id"],
            task_title=sc["task_title"],
            questions_asked=[],
            questions_remaining=sc["max_questions"],
            answers_received=[],
            fields_revealed=[],
            plan_submitted=False,
            episode_done=False,
            final_score=None,
            score_breakdown=None,
        )

        result = {
            "type": "task",
            "request": sc["request"],
            "task_id": sc["task_id"],
            "task_title": sc["task_title"],
            "family": sc["family"],
            "max_steps": sc["max_steps"],
            "questions_remaining": sc["max_questions"],
            "instructions": _INSTRUCTIONS,
        }
        return CallToolObservation(
            tool_name="reset",
            result=result,
            done=False,
            reward=0.0,
        )

    def _clear_step_signals(self) -> None:
        """Reset the per-step reward/done signal to the neutral values."""
        self._last_step_reward = 0.0
        self._last_step_done = False

    def _patch_obs(self, obs: CallToolObservation, action: Any) -> CallToolObservation:
        """Copy the recorded reward/done onto ``obs`` and advance the step count.

        Only tool calls count as steps; observations for any other action
        type are returned untouched.
        """
        if not isinstance(action, CallToolAction):
            return obs

        st = self._public_state
        obs.reward = self._last_step_reward
        # A finished episode stays finished even when this call never reached
        # a tool body (so no tool had a chance to set the done flag).
        obs.done = self._last_step_done or st.episode_done
        st.step_count = st.step_count + 1
        if obs.done:
            st.episode_done = True

        # Step-limit truncation: end the episode once max_steps is reached.
        sc = self._scenario
        if sc and st.step_count >= sc["max_steps"] and not obs.done:
            obs.done = True
            st.episode_done = True
        return obs

    def step(
        self,
        action: Any,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> CallToolObservation:
        """Run one action through the base class, then attach reward/done."""
        # Clear first: if the call fails before any tool body records its
        # outcome (unknown tool, invalid arguments, exception inside a tool),
        # the observation must not inherit the previous step's reward.
        self._clear_step_signals()
        obs = super().step(action, timeout_s=timeout_s, **kwargs)
        return self._patch_obs(obs, action)

    async def step_async(
        self,
        action: Any,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> CallToolObservation:
        """Async counterpart of :meth:`step` with identical post-processing."""
        # Same stale-signal protection as in step().
        self._clear_step_signals()
        obs = await super().step_async(action, timeout_s=timeout_s, **kwargs)
        return self._patch_obs(obs, action)

    def _step_impl(
        self,
        action: Any,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> CallToolObservation:
        """Return an empty, non-terminal, zero-reward observation.

        The observation's ``tool_name`` is taken from ``action.tool_name``
        when present, otherwise ``"unknown"``.
        """
        del timeout_s, kwargs
        return CallToolObservation(
            tool_name=getattr(action, "tool_name", "unknown"),
            result=None,
            done=False,
            reward=0.0,
        )

    @property
    def state(self) -> ClarifyState:
        """The public episode state object (not a copy)."""
        return self._public_state

    def get_metadata(self) -> EnvironmentMetadata:
        """Return the static name, description, version and author."""
        return EnvironmentMetadata(
            name="ClarifyRL — AskBeforeYouAct",
            description=(
                "Train LLMs to ask clarifying questions instead of hallucinating. "
                "Five task families (coding / medical-intake / support-triage / meeting / event), "
                "rule-based simulator, composable rubric."
            ),
            version="0.1.0",
            author="Team Bhole Chature",
        )

    def _require_scenario(self) -> Scenario:
        """Return the active scenario.

        Raises:
            RuntimeError: If reset() has not been called yet.
        """
        if self._scenario is None:
            raise RuntimeError("Environment must be reset() before tool calls.")
        return self._scenario

    def _guard_episode_done(self) -> Optional[dict[str, Any]]:
        """Return an error payload if the episode has ended, else None.

        When the episode has ended, the step is also recorded as terminal
        with zero reward.
        """
        if self._public_state.episode_done:
            self._last_step_reward = 0.0
            self._last_step_done = True
            return {"error": "episode already ended", "episode_done": True}
        return None

    def _tool_get_task_info(self) -> dict[str, Any]:
        """Return the request and task metadata; costs no question budget."""
        sc = self._require_scenario()
        blocked = self._guard_episode_done()
        if blocked:
            return blocked

        self._last_step_reward = 0.0
        self._last_step_done = False
        return {
            "request": sc["request"],
            "task_id": sc["task_id"],
            "task_title": sc["task_title"],
            "family": sc["family"],
            "questions_remaining": self._public_state.questions_remaining,
            "instructions": _INSTRUCTIONS,
        }

    def _tool_ask_question(self, question: str) -> dict[str, Any]:
        """Answer one clarifying question and charge it against the budget.

        The question is truncated to 200 characters. Asking with no budget
        left ends the episode with ``PENALTY_OVER_CAP`` as the reward.
        Otherwise the simulator's answer is recorded, the budget drops by
        one, and the reward comes from ``ask_question_reward`` depending on
        whether the matched field is new or was already revealed.
        """
        sc = self._require_scenario()
        st = self._public_state
        blocked = self._guard_episode_done()
        if blocked:
            return blocked

        question = question[:200]

        if st.questions_remaining <= 0:
            self._last_step_reward = PENALTY_OVER_CAP
            self._last_step_done = True
            return {
                "answer": "(no more questions allowed)",
                "questions_remaining": 0,
                "field_revealed": None,
                "duplicate": False,
                "over_cap": True,
            }

        # `matched` is the field key the simulator associated with the
        # question, or None when it matched nothing.
        text, matched = answer(question, sc["hidden_profile"], sc["family"])
        is_duplicate = matched is not None and matched in self._asked_field_keys
        revealed_new = matched is not None and not is_duplicate
        if revealed_new:
            self._asked_field_keys.add(matched)
            st.fields_revealed = sorted(self._asked_field_keys)

        st.questions_asked = st.questions_asked + [question]
        st.answers_received = st.answers_received + [text]
        st.questions_remaining = st.questions_remaining - 1

        self._last_step_reward = ask_question_reward(
            over_cap=False,
            is_duplicate_field=is_duplicate,
            revealed_new_field=revealed_new,
        )
        self._last_step_done = False
        return {
            "answer": text,
            "questions_remaining": st.questions_remaining,
            "field_revealed": matched if revealed_new else None,
            "duplicate": is_duplicate,
            "over_cap": False,
        }

    def _tool_propose_plan(self, plan: str) -> dict[str, Any]:
        """Score the submitted plan with the rubric and end the episode.

        The rubric score becomes both the step reward and the recorded final
        score. The returned payload also discloses the hidden profile and the
        per-component breakdown.
        """
        sc = self._require_scenario()
        st = self._public_state
        blocked = self._guard_episode_done()
        if blocked:
            return blocked

        parsed, parse_err = parse_plan(plan)
        ctx = RubricContext(
            family=sc["family"],
            hidden_profile=sc["hidden_profile"],
            critical_fields=frozenset(sc["critical_fields"]),
            required_keys=tuple(sc["required_keys"]),
            asked_field_keys=frozenset(self._asked_field_keys),
            questions_asked_count=len(st.questions_asked),
            max_questions=sc["max_questions"],
            parsed_plan=parsed,
            parse_error=parse_err,
        )
        score = float(self.rubric(action=None, observation=ctx))
        # Read the breakdown from the rubric only after it has been evaluated.
        breakdown = score_breakdown(self.rubric)

        self._last_step_reward = score
        self._last_step_done = True
        st.plan_submitted = True
        st.episode_done = True
        st.final_score = score
        st.score_breakdown = breakdown

        return {
            "type": "resolution",
            "score": score,
            "breakdown": breakdown,
            "expected_profile": sc["hidden_profile"],
            "critical_fields": list(sc["critical_fields"]),
            "required_keys": list(sc["required_keys"]),
            "submitted_plan": parsed,
            "parse_error": parse_err,
            "questions_asked": len(st.questions_asked),
            "fields_revealed": sorted(self._asked_field_keys),
        }
# Public API of this module: only the environment class is exported.
__all__ = ["ClarifyEnvironment"]