# envs/your_env/server/environment.py from __future__ import annotations import random from datetime import datetime, timezone from pathlib import Path from typing import Dict, List, Optional, Union from openenv.core.env_server import Environment from ..models import ( Action, Difficulty, HistoryEntry, Observation, ResetResult, State, StatePayload, StepInfo, StepResult, ) from .actions import ParsedActionResult, parse_and_validate_action from .loader import load_episode_bundle, load_episode_bundle_from_paths from .observation import build_observation from .reward import compute_reward from .termination import is_episode_done from .transitions import apply_action_to_state class GitHubIssueTriageEnvironment(Environment): SUPPORTS_CONCURRENT_SESSIONS = True def __init__( self, *, episodes: Optional[list[State]] = None, repo_rules_source: Optional[Union[str, Path]] = None, tasks_source: Optional[Union[str, Path]] = None, issues_source: Optional[Union[str, Path]] = None, data_dir: Optional[Union[str, Path]] = None, strict_mode: bool = True, live_github: bool = False, ) -> None: self.strict_mode = strict_mode self.live_github = live_github self._episodes_source: list[State] = episodes or [] self._episode_index: int = -1 self._state: Optional[State] = None self._seed: Optional[int] = None self._global_sequence: List[int] = [] self._global_position: int = -1 self._difficulty_sequences: Dict[Difficulty, List[int]] = {} self._difficulty_positions: Dict[Difficulty, int] = {} if not self._episodes_source: if data_dir is not None: self._episodes_source = load_episode_bundle_from_paths( data_dir, live_github=live_github, ) elif repo_rules_source and tasks_source and issues_source: self._episodes_source = load_episode_bundle( repo_rules_path=repo_rules_source, tasks_path=tasks_source, issues_path=issues_source, live_github=live_github, ) self._initialize_sequences() def _initialize_sequences(self, seed: Optional[int] = None) -> None: if not self._episodes_source: self._global_sequence = [] self._global_position = -1 self._difficulty_sequences = {} self._difficulty_positions = {} return rng = random.Random(seed) if seed is not None else None indices = list(range(len(self._episodes_source))) if rng is not None: rng.shuffle(indices) self._global_sequence = indices self._global_position = -1 self._difficulty_sequences = {} self._difficulty_positions = {} for difficulty in Difficulty: seq = [ idx for idx in indices if self._episodes_source[idx].task.difficulty == difficulty ] if rng is not None: rng.shuffle(seq) self._difficulty_sequences[difficulty] = seq self._difficulty_positions[difficulty] = -1 def _set_seed(self, seed: Optional[int]) -> None: self._seed = seed self._initialize_sequences(seed) @staticmethod def _timestamp() -> str: return datetime.now(timezone.utc).isoformat() def _record_history( self, state: State, *, action: Action, outcome: str, success: bool, ) -> None: state.current_action_history.append( HistoryEntry( step_index=state.step_count, action_type=action.type, action_payload=action.model_dump(), outcome=outcome, success=success, timestamp=self._timestamp(), ) ) def _normalize_difficulty( self, difficulty: Optional[Union[str, Difficulty]] ) -> Optional[Difficulty]: if difficulty is None: return None if isinstance(difficulty, Difficulty): return difficulty try: return Difficulty(difficulty.strip().lower()) except Exception as exc: raise KeyError(f"Unknown difficulty: {difficulty}") from exc def _next_index(self, *, difficulty: Optional[Difficulty] = None) -> int: if difficulty is None: if not self._global_sequence: raise RuntimeError("No episodes available.") self._global_position = (self._global_position + 1) % len(self._global_sequence) return self._global_sequence[self._global_position] seq = self._difficulty_sequences.get(difficulty, []) if not seq: raise KeyError(f"No episodes available for difficulty '{difficulty.value}'.") position = (self._difficulty_positions[difficulty] + 1) % len(seq) self._difficulty_positions[difficulty] = position return seq[position] def reset( self, task_id: Optional[str] = None, difficulty: Optional[Union[str, Difficulty]] = None, seed: Optional[int] = None, ) -> Observation: if not self._episodes_source: raise RuntimeError( "No episodes loaded. Pass episodes=..., data_dir=..., " "or repo_rules_source/tasks_source/issues_source." ) if seed is not None: self._set_seed(seed) difficulty_enum = self._normalize_difficulty(difficulty) if task_id is None: if difficulty_enum is None: index = self._next_index() else: index = self._next_index(difficulty=difficulty_enum) self._episode_index = index base_state = self._episodes_source[index] else: match_idx = None for idx, ep in enumerate(self._episodes_source): if ep.task.task_id == task_id or ep.episode_id == task_id: match_idx = idx break if match_idx is None: raise KeyError(f"Unknown task_id or episode_id: {task_id}") self._episode_index = match_idx base_state = self._episodes_source[match_idx] self._state = base_state.model_copy(deep=True) self._state.step_count = 0 self._state.done = False self._state.current_action_history = [] self._state.pending_missing_fields = ( list(self._state.hidden_target.required_missing_fields) if self._state.hidden_target else [] ) self._state.requested_fields = [] self._state.public_notes = [] self._state.last_action_valid = True self._state.last_action_message = "" self._state.internal_score_cache = None return build_observation(self._state) def step(self, action: Action | dict) -> StepResult: state = self._require_state() if state.done: obs = build_observation(state) reward = compute_reward(state) reward_dump = reward.model_dump() reward_components = ( reward_dump.pop("components", {}) if isinstance(reward_dump.get("components"), dict) else {} ) reward_breakdown = { key: float(value) for key, value in reward_dump.items() if isinstance(value, (int, float)) } return StepResult( observation=obs, reward=reward, done=True, info=StepInfo( action_valid=False, action_effect="episode_already_done", changed_fields=[], reward_breakdown=reward_breakdown, reward_components=reward_components, grader_notes=["Episode already completed."], ), ) validation: ParsedActionResult = parse_and_validate_action( action, state.task.allowed_actions ) parsed_action = validation.action if not validation.valid: action_effect = validation.effect or "action_validation_failed" notes = validation.notes or ["Action failed validation."] self._record_history( state, action=parsed_action, outcome=action_effect, success=False, ) state.step_count += 1 if is_episode_done(state): state.done = True state.last_action_valid = False state.last_action_message = notes[0] reward = compute_reward(state) obs = build_observation(state) state.internal_score_cache = reward.total reward_dump = reward.model_dump() reward_components = ( reward_dump.pop("components", {}) if isinstance(reward_dump.get("components"), dict) else {} ) reward_breakdown = { key: float(value) for key, value in reward_dump.items() if isinstance(value, (int, float)) } info = StepInfo( action_valid=False, action_effect=action_effect, changed_fields=[], reward_breakdown=reward_breakdown, reward_components=reward_components, grader_notes=notes, ) return StepResult( observation=obs, reward=reward, done=state.done, info=info, ) transition = apply_action_to_state(state, parsed_action) state.step_count += 1 if is_episode_done(state): state.done = True reward = compute_reward(state) obs = build_observation(state) transition_notes = list(getattr(transition, "notes", [])) transition_effect = str(getattr(transition, "action_effect", "")) state.internal_score_cache = reward.total state.last_action_valid = bool(getattr(transition, "action_valid", True)) state.last_action_message = ( transition_notes[0] if transition_notes else transition_effect ) reward_dump = reward.model_dump() reward_components = ( reward_dump.pop("components", {}) if isinstance(reward_dump.get("components"), dict) else {} ) reward_breakdown = { key: float(value) for key, value in reward_dump.items() if isinstance(value, (int, float)) } info = StepInfo( action_valid=bool(getattr(transition, "action_valid", True)), action_effect=transition_effect, changed_fields=list(getattr(transition, "changed_fields", [])), reward_breakdown=reward_breakdown, reward_components=reward_components, grader_notes=transition_notes, ) return StepResult( observation=obs, reward=reward, done=state.done, info=info, ) @property def state(self) -> State: return self._require_state().model_copy(deep=True) def snapshot(self) -> StatePayload: return StatePayload(state=self.state) def reset_result( self, task_id: Optional[str] = None, difficulty: Optional[Union[str, Difficulty]] = None, seed: Optional[int] = None, ) -> ResetResult: obs = self.reset(task_id=task_id, difficulty=difficulty, seed=seed) return ResetResult(observation=obs, state=self.state) def _require_state(self) -> State: if self._state is None: raise RuntimeError("Environment has not been reset yet.") return self._state