repl / server /repl_environment.py
sergiopaniego's picture
Upload folder using huggingface_hub
3d90b75 verified
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
REPL Environment Implementation.
A Python REPL environment for training language models on code execution tasks,
based on the Recursive Language Models (RLM) paradigm.
References:
- RLM Paper: https://arxiv.org/abs/2512.24601
- Prime Intellect Blog: https://www.primeintellect.ai/blog/rlm
- Alex Zhang Blog: https://alexzhang13.github.io/blog/2025/rlm/
"""
import os
import re
from collections.abc import Callable
from typing import Any, Dict, List, Optional
from uuid import uuid4
# Support both in-repo and standalone imports
try:
from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State, EnvironmentMetadata
except ImportError:
from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State, EnvironmentMetadata
try:
from ..models import REPLAction, REPLObservation, REPLState, CodeBlockResult
except ImportError:
from models import REPLAction, REPLObservation, REPLState, CodeBlockResult
try:
from .python_executor import PythonExecutor
except ImportError:
from python_executor import PythonExecutor
class REPLEnvironment(Environment):
    """
    A REPL environment for training language models to use code execution.

    Based on the Recursive Language Models (RLM) paradigm, this environment allows
    language models to:
    - Execute Python code in a sandboxed REPL
    - Work with large contexts loaded as variables
    - Finalize answers via FINAL(), FINAL_VAR(), or answer dict pattern
    - Optionally make recursive LLM calls via llm_query() / llm_batch()

    Supports two finalization patterns:
    1. RLM-style: print('FINAL(answer)') or print('FINAL_VAR(var_name)')
    2. Prime Intellect style: answer = {"content": "...", "ready": True}

    Example:
        >>> env = REPLEnvironment(context="Hello World", task_prompt="Count chars")
        >>> obs = env.reset()
        >>> print(obs.context_preview)  # "Hello World"
        >>>
        >>> obs = env.step(REPLAction(code="result = len(context)"))
        >>> print(obs.result.success)  # True
        >>> print(obs.available_variables)  # ["context", "result", "answer"]
        >>>
        >>> obs = env.step(REPLAction(code="print(f'FINAL({result})')"))
        >>> print(obs.done)  # True
        >>> print(obs.metadata["final_answer"])  # "11"
    """

    SUPPORTS_CONCURRENT_SESSIONS = True

    def __init__(
        self,
        context: Optional[str] = None,
        task_prompt: Optional[str] = None,
        max_iterations: int = 30,
        max_output_length: int = 8192,
        context_preview_length: int = 500,
        reward_on_success: float = 1.0,
        reward_on_iteration: float = 0.0,
        reward_on_failure: float = -0.1,
        reward_on_error: float = -0.05,
        llm_query_fn: Optional[Callable[[str], str]] = None,
        llm_batch_fn: Optional[Callable[[List[str]], List[str]]] = None,
    ):
        """Initialize the REPL environment.

        Args:
            context: Initial context to load (can also be set via REPL_CONTEXT env var)
            task_prompt: Task description (can also be set via REPL_TASK_PROMPT env var)
            max_iterations: Maximum steps per episode (default 30, env var REPL_MAX_ITERATIONS)
            max_output_length: Max chars for stdout/stderr per turn (default 8192)
            context_preview_length: Chars to show in context preview (default 500)
            reward_on_success: Reward when final answer is submitted (default 1.0)
            reward_on_iteration: Reward per iteration step (default 0.0)
            reward_on_failure: Reward when max iterations reached (default -0.1)
            reward_on_error: Reward when code execution fails (default -0.05)
            llm_query_fn: Optional function for llm_query() support
            llm_batch_fn: Optional function for llm_batch() support
        """
        # Explicit arguments win over the environment variables for these two.
        self.initial_context = context or os.environ.get("REPL_CONTEXT", "")
        self.initial_task_prompt = task_prompt or os.environ.get("REPL_TASK_PROMPT", "")
        # NOTE(review): here the environment variable, when set, overrides the
        # constructor argument (the opposite precedence of the two above).
        self.max_iterations = int(os.environ.get("REPL_MAX_ITERATIONS", max_iterations))
        self.max_output_length = max_output_length
        self.context_preview_length = context_preview_length

        # Reward configuration
        self.reward_on_success = reward_on_success
        self.reward_on_iteration = reward_on_iteration
        self.reward_on_failure = reward_on_failure
        self.reward_on_error = reward_on_error

        # Optional LLM functions for recursive calls
        self.llm_query_fn = llm_query_fn
        self.llm_batch_fn = llm_batch_fn

        # State (initialized on reset)
        self._state: Optional[REPLState] = None
        self._executor: Optional[PythonExecutor] = None

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        context: Optional[str] = None,
        task_prompt: Optional[str] = None,
        **kwargs: Any,
    ) -> REPLObservation:
        """Reset the environment with optional new context.

        A fresh state and a fresh executor are created on every call, so no
        variables survive from a previous episode.

        Args:
            seed: Optional random seed (accepted for interface compatibility;
                not used by this method)
            episode_id: Optional episode identifier (if not provided, one is generated)
            context: Context to load (overrides initial_context)
            task_prompt: Task description (overrides initial_task_prompt)
            **kwargs: Additional reset parameters (ignored)

        Returns:
            Initial REPLObservation with environment ready message
        """
        # Falsy overrides (None or "") fall back to the constructor values.
        effective_context = context or self.initial_context
        effective_task_prompt = task_prompt or self.initial_task_prompt

        # Initialize state
        self._state = REPLState(
            episode_id=episode_id or str(uuid4()),
            step_count=0,
            context=effective_context,
            task_prompt=effective_task_prompt,
            iteration=0,
            max_iterations=self.max_iterations,
            namespace_keys=[],
            final_answer=None,
            total_execution_time=0.0,
        )

        # Initialize executor
        self._executor = PythonExecutor(max_output_length=self.max_output_length)

        # Initialize answer dict (Prime Intellect style)
        self._executor.set_variable("answer", {"content": "", "ready": False})

        # Load context into namespace if provided
        if effective_context:
            self._executor.set_context(effective_context)

        # Inject LLM functions if provided
        if self.llm_query_fn:
            self._executor.inject_function("llm_query", self.llm_query_fn)
        if self.llm_batch_fn:
            self._executor.inject_function("llm_batch", self.llm_batch_fn)

        # Inject FINAL helper function so both FINAL(x) and print(f'FINAL({x})') work
        # Returns the FINAL pattern as a string so it appears in stdout for detection
        def final_helper(value):
            """Helper that returns FINAL(value) string for detection."""
            return f"FINAL({value})"

        self._executor.inject_function("FINAL", final_helper)

        # Update namespace keys
        self._state.namespace_keys = self._executor.list_variables()

        # Build initial message
        message_parts = ["REPL environment initialized."]
        if effective_context:
            message_parts.append(
                f"Context loaded ({len(effective_context)} chars). "
                "Use 'context' variable to access it."
            )
        if effective_task_prompt:
            message_parts.append(f"Task: {effective_task_prompt}")
        message_parts.append(
            "Use answer['content'] to store your answer, "
            "and set answer['ready'] = True when done."
        )

        return REPLObservation(
            result=CodeBlockResult(
                stdout="\n".join(message_parts),
                stderr="",
                locals_snapshot={},
                execution_time=0.0,
                success=True,
                exception=None,
            ),
            context_preview=(
                effective_context[: self.context_preview_length]
                if effective_context
                else None
            ),
            context_length=len(effective_context) if effective_context else 0,
            available_variables=self._state.namespace_keys,
            iteration=0,
            max_iterations=self.max_iterations,
            done=False,
            reward=0.0,
            metadata={
                "task_prompt": effective_task_prompt,
                "message": "Environment ready.",
            },
        )

    def step(
        self,
        action: REPLAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> REPLObservation:
        """Execute code and return observation.

        Args:
            action: REPLAction containing code to execute
            timeout_s: Optional timeout in seconds (not currently used)
            **kwargs: Additional step parameters (ignored)

        Returns:
            REPLObservation with execution results

        Raises:
            RuntimeError: If reset() has not been called first
        """
        if self._state is None or self._executor is None:
            raise RuntimeError("Environment not initialized. Call reset() first.")

        self._state.step_count += 1
        self._state.iteration += 1

        # Check if agent explicitly signals final answer
        if action.is_final:
            self._state.final_answer = action.final_answer or ""
            return self._create_final_observation(
                success=True,
                message="Final answer submitted.",
                reward=self.reward_on_success,
            )

        # Check iteration limit
        # NOTE(review): the limit is tested before execution, so the code of
        # the step that reaches max_iterations is never run - confirm intended.
        if self._state.iteration >= self.max_iterations:
            # Check if there's a partial answer in the answer dict
            answer_var = self._executor.get_variable("answer")
            if isinstance(answer_var, dict) and answer_var.get("content"):
                self._state.final_answer = str(answer_var.get("content", ""))
            return self._create_final_observation(
                success=False,
                message=f"Maximum iterations ({self.max_iterations}) reached.",
                reward=self.reward_on_failure,
            )

        # Execute code
        result = self._executor.execute(action.code)
        self._state.total_execution_time += result["execution_time"]
        self._state.namespace_keys = self._executor.list_variables()

        # Calculate reward
        reward = self.reward_on_iteration
        if not result["success"]:
            reward += self.reward_on_error

        # Check for final answer patterns
        final_answer = self._extract_final_answer(result["stdout"])
        done = final_answer is not None
        if done:
            self._state.final_answer = final_answer
            # A detected final answer replaces (not adds to) the step reward.
            reward = self.reward_on_success

        return REPLObservation(
            result=CodeBlockResult(
                stdout=result["stdout"],
                stderr=result["stderr"],
                locals_snapshot=result["locals_snapshot"],
                execution_time=result["execution_time"],
                success=result["success"],
                exception=result["exception"],
            ),
            context_preview=(
                self._state.context[: self.context_preview_length]
                if self._state.context
                else None
            ),
            context_length=len(self._state.context) if self._state.context else 0,
            available_variables=self._state.namespace_keys,
            iteration=self._state.iteration,
            max_iterations=self.max_iterations,
            done=done,
            reward=reward,
            metadata={
                "task_prompt": self._state.task_prompt,
                "final_answer": final_answer,
                "execution_time": result["execution_time"],
            },
        )

    @staticmethod
    def _match_final_call(stdout: str) -> Optional[str]:
        """Return the raw argument text of the first ``FINAL(...)`` in *stdout*.

        The closing parenthesis is located by counting nesting depth, so an
        answer that itself contains balanced parentheses (for example a tuple
        or a call expression) is returned whole instead of being cut at its
        first ``)``. If the parentheses never balance (e.g. truncated output),
        the text up to the first ``)`` is returned instead.

        Args:
            stdout: Standard output from code execution

        Returns:
            The unstripped text between the parentheses, or None when there is
            no ``FINAL(`` followed by a ``)``
        """
        marker = "FINAL("
        marker_pos = stdout.find(marker)
        if marker_pos == -1:
            return None

        body_start = marker_pos + len(marker)
        depth = 1
        for idx in range(body_start, len(stdout)):
            char = stdout[idx]
            if char == "(":
                depth += 1
            elif char == ")":
                depth -= 1
                if depth == 0:
                    return stdout[body_start:idx]

        # Unbalanced parentheses: fall back to the first ")" after the marker.
        first_close = stdout.find(")", body_start)
        if first_close == -1:
            return None
        return stdout[body_start:first_close]

    def _extract_final_answer(self, stdout: str) -> Optional[str]:
        """Extract final answer from output.

        Supports multiple patterns, checked in this order:
        1. RLM-style: FINAL(answer) in stdout
        2. RLM-style: FINAL_VAR(variable_name) in stdout
        3. Prime Intellect style: answer = {"content": "...", "ready": True} in namespace

        Args:
            stdout: Standard output from code execution

        Returns:
            Final answer string or None if not found
        """
        # Pattern 1: RLM-style FINAL(answer), with balanced-parenthesis matching
        final_text = self._match_final_call(stdout)
        if final_text is not None:
            return final_text.strip()

        # Pattern 2: RLM-style FINAL_VAR(variable_name)
        final_var_match = re.search(r"FINAL_VAR\((\w+)\)", stdout)
        if final_var_match and self._executor:
            var_name = final_var_match.group(1)
            value = self._executor.get_variable(var_name)
            # A variable whose value is None is treated as "no answer yet".
            if value is not None:
                return str(value)

        # Pattern 3: Prime Intellect style answer dict
        if self._executor:
            answer_var = self._executor.get_variable("answer")
            if isinstance(answer_var, dict):
                if answer_var.get("ready", False):
                    return str(answer_var.get("content", ""))

        return None

    def _create_final_observation(
        self, success: bool, message: str, reward: float
    ) -> REPLObservation:
        """Create observation for episode termination.

        Args:
            success: Whether the episode ended successfully
            message: Termination message
            reward: Final reward value

        Returns:
            Final REPLObservation with done=True
        """
        return REPLObservation(
            result=CodeBlockResult(
                stdout=message,
                stderr="",
                locals_snapshot={},
                execution_time=0.0,
                success=success,
                exception=None,
            ),
            context_preview=None,
            context_length=0,
            available_variables=[],
            iteration=self._state.iteration if self._state else 0,
            max_iterations=self.max_iterations,
            done=True,
            reward=reward,
            metadata={
                "final_answer": self._state.final_answer if self._state else None,
                "total_execution_time": (
                    self._state.total_execution_time if self._state else 0
                ),
                "total_iterations": self._state.iteration if self._state else 0,
            },
        )

    @property
    def state(self) -> REPLState:
        """Get the current environment state.

        Returns:
            Current REPLState

        Raises:
            RuntimeError: If environment not initialized
        """
        if self._state is None:
            raise RuntimeError("Environment not initialized. Call reset() first.")
        return self._state

    def close(self) -> None:
        """Cleanup resources by dropping the executor and the episode state."""
        self._executor = None
        self._state = None

    def get_metadata(self) -> EnvironmentMetadata:
        """Get environment metadata.

        Returns:
            EnvironmentMetadata with environment info
        """
        return EnvironmentMetadata(
            name="repl_env",
            description="Python REPL environment for RLM-style code execution",
            version="0.1.0",
        )