# utils.py
import re
import json
from typing import List, Dict, Any, Optional, Callable, Iterator

from torch.utils.data import IterableDataset

try:
    from disasim import DisasimEnv
    from disasim.server.actors.VertexICSActor import VertexICSActor
    from disasim.models import DisasimAction
except ModuleNotFoundError:
    from . import DisasimEnv
    from models import DisasimAction


# ----------------------------------------------------------------------
# 1. Dataset: rolls out environment with dummy actions, yields prompt + combined snapshot
# ----------------------------------------------------------------------
class EnvPromptDataset(IterableDataset):
    """Iterable dataset that rolls an environment forward and yields prompts.

    Each item is a dict with two keys:

    * ``"prompt"``: the value returned by ``actor._build_prompt(obs, state)``.
    * ``"snapshot"``: a dict holding ``env.get_snapshot()``,
      ``actor.get_snapshot()`` and the ``obs`` / ``state`` objects that were
      current when the prompt was built.

    Project types in the annotations are quoted on purpose: the fallback
    import branch of this module does not bind ``VertexICSActor``, so an
    eagerly evaluated annotation would raise ``NameError`` at class
    definition time.

    Args:
        env: Environment client used to collect prompts.
        actor: Actor used to build prompts and to provide its own snapshot.
        total_steps: Number of items produced by one pass over the dataset.
        dummy_action: Action used to advance the environment between items.
            Defaults to ``DisasimAction(SAR_1=[], MED_1=[], ENG_1=[])``.
    """

    def __init__(
        self,
        env: "DisasimEnv",
        actor: "VertexICSActor",
        total_steps: int = 1000,
        dummy_action: Optional["DisasimAction"] = None,
    ):
        super().__init__()
        self.env = env
        self.actor = actor
        self.total_steps = total_steps
        if dummy_action is None:
            dummy_action = DisasimAction(SAR_1=[], MED_1=[], ENG_1=[])
        self.dummy_action = dummy_action

    def __len__(self) -> int:
        """Return the number of items one full iteration yields."""
        return self.total_steps

    def __iter__(self) -> Iterator[dict]:
        """Reset the environment and yield ``total_steps`` prompt/snapshot dicts."""
        # reset() returns the observation only; the environment's internal
        # state is read separately from env.state.
        obs = self.env.reset()
        state = self.env.state

        for _ in range(self.total_steps):
            prompt = self.actor._build_prompt(obs, state)

            combined_snapshot = {
                "env_snapshot": self.env.get_snapshot(),
                "actor_snapshot": self.actor.get_snapshot(),
                # Kept alongside the snapshots because, per the original
                # author's note, restore_snapshot does not bring these back.
                "obs": obs,
                "state": state,
            }
            yield {"prompt": prompt, "snapshot": combined_snapshot}

            # Advance with the dummy action, then refresh the cached state.
            obs = self.env.step(self.dummy_action)
            state = self.env.state

            # Start a new episode once the observation reports termination.
            if obs.done:
                obs = self.env.reset()
                state = self.env.state


# ----------------------------------------------------------------------
# 2.
# Parse model completion to directives (JSON inside markdown)
# ----------------------------------------------------------------------
def parse_completion_to_directives(completion: str) -> dict:
    """Extract the first JSON object embedded in a model completion.

    Markdown code fences (```` ```json ```` and ```` ``` ````) are removed,
    then the JSON object that begins at the first ``{`` is decoded. Any text
    after that object is ignored, so trailing prose that happens to contain
    braces no longer invalidates an otherwise well-formed answer.

    Args:
        completion: Raw text produced by the model.

    Returns:
        The decoded object, or ``{}`` when ``completion`` is not a string,
        contains no ``{``, or the text starting at the first ``{`` is not
        valid JSON.
    """
    if not isinstance(completion, str):
        return {}

    clean = re.sub(r"```json|```", "", completion)
    start = clean.find("{")
    if start == -1:
        return {}

    try:
        # raw_decode stops at the end of the first complete JSON value
        # instead of requiring the rest of the string to be empty.
        directives, _ = json.JSONDecoder().raw_decode(clean, start)
    except json.JSONDecodeError:
        return {}
    return directives


# ----------------------------------------------------------------------
# 3. Evaluate a single completion given a combined snapshot
# ----------------------------------------------------------------------
def evaluate_completion(
    env: "DisasimEnv",
    actor: "VertexICSActor",
    completion: str,
    snapshot: dict,
    parse_fn: Callable[[str], dict] = parse_completion_to_directives,
    penalty: float = -1.0,
) -> float:
    """Score one completion by replaying it from a stored snapshot.

    The environment and actor are restored first (also when the completion
    turns out to be unparsable), then the parsed directives are passed to
    ``env.step`` and the resulting observation's ``reward`` is returned.

    Project types in the annotations are quoted because this module's
    fallback import branch does not bind ``VertexICSActor``.

    Args:
        env: Environment to restore and step.
        actor: Actor whose snapshot is restored alongside the environment.
        completion: Raw model output to evaluate.
        snapshot: Dict with at least ``"env_snapshot"`` and
            ``"actor_snapshot"``. Extra keys such as ``"obs"`` / ``"state"``
            are ignored, since stepping does not read them here.
        parse_fn: Converts the completion text into directives.
        penalty: Value returned when ``parse_fn`` yields nothing usable.

    Returns:
        ``penalty`` if no directives were parsed, otherwise
        ``float(new_obs.reward)``.

    Raises:
        KeyError: If ``snapshot`` lacks one of the two required keys.
    """
    env.restore_snapshot(snapshot["env_snapshot"])
    actor.restore_snapshot(snapshot["actor_snapshot"])

    directives = parse_fn(completion)
    if not directives:
        return penalty

    # NOTE(review): EnvPromptDataset steps the environment with a
    # DisasimAction instance, while the parsed dict is passed here as-is.
    # Confirm env.step accepts a plain dict.
    new_obs = env.step(directives)
    return float(new_obs.reward)


# ----------------------------------------------------------------------
# 4.
# Reward function factory for GRPOTrainer
# ----------------------------------------------------------------------
def make_reward_fn(
    scoring_env: "DisasimEnv",
    scoring_actor: "VertexICSActor",
    parse_fn: Callable[[str], dict] = parse_completion_to_directives,
    penalty: float = -1.0,
) -> Callable:
    """Build a batch reward function bound to a scoring environment and actor.

    Project types in the annotations are quoted because this module's
    fallback import branch does not bind ``VertexICSActor``.

    Args:
        scoring_env: Environment restored and stepped for every completion.
        scoring_actor: Actor restored together with the environment.
        parse_fn: Converts a completion string into directives.
        penalty: Reward assigned when a completion has no matching snapshot
            or yields no directives.

    Returns:
        ``reward_fn(prompts, completions, snapshot=None, **kwargs)``, which
        returns one float per completion.
    """

    def reward_fn(
        prompts: List[str],
        completions: List[str],
        snapshot: Optional[List[dict]] = None,
        **kwargs,
    ) -> List[float]:
        """Score each completion against the snapshot at the same index."""
        snapshots = [] if snapshot is None else snapshot
        rewards = []
        for i, completion in enumerate(completions):
            # A completion with no snapshot cannot be replayed.
            if i >= len(snapshots):
                rewards.append(penalty)
                continue

            snap = snapshots[i]
            scoring_env.restore_snapshot(snap["env_snapshot"])
            scoring_actor.restore_snapshot(snap["actor_snapshot"])
            # The stored obs/state are not restored: per the original
            # author's note, step only uses the env's internal state.

            directives = parse_fn(completion)
            if not directives:
                rewards.append(penalty)
                continue

            new_obs = scoring_env.step(directives)
            rewards.append(float(new_obs.reward))
        return rewards

    return reward_fn


# ----------------------------------------------------------------------
# 5. Convenience function: create collection + scoring envs, dataset, reward fn
# ----------------------------------------------------------------------
def build_grpo_components(
    ENV_URL: str,
    actor_class,
    actor_config: dict,
    total_dataset_steps: int = 2000,
    dummy_action: Optional["DisasimAction"] = None,
    parse_fn: Callable[[str], dict] = parse_completion_to_directives,
    penalty: float = -1.0,
) -> tuple:
    """Create the prompt dataset and reward function used for GRPO training.

    Two environment clients and two actors are built: one pair collects
    prompts for the dataset, the other pair replays snapshots for scoring.

    Args:
        ENV_URL: Base URL passed to both ``DisasimEnv`` clients.
        actor_class: Callable invoked as ``actor_class(**actor_config)``.
        actor_config: Keyword arguments for ``actor_class``.
        total_dataset_steps: ``total_steps`` for the dataset.
        dummy_action: Optional action forwarded to the dataset.
        parse_fn: Completion parser forwarded to ``make_reward_fn``.
        penalty: Penalty value forwarded to ``make_reward_fn``.

    Returns:
        A ``(dataset, reward_fn)`` tuple.
    """
    # NOTE(review): both clients point at the same ENV_URL. Confirm the
    # server isolates their sessions; otherwise restoring snapshots while
    # scoring could disturb the rollout that feeds the dataset.
    collection_env = DisasimEnv(base_url=ENV_URL)
    collection_actor = actor_class(**actor_config)

    scoring_env = DisasimEnv(base_url=ENV_URL)
    scoring_actor = actor_class(**actor_config)

    dataset = EnvPromptDataset(
        env=collection_env,
        actor=collection_actor,
        total_steps=total_dataset_steps,
        dummy_action=dummy_action,
    )
    reward_fn = make_reward_fn(
        scoring_env=scoring_env,
        scoring_actor=scoring_actor,
        parse_fn=parse_fn,
        penalty=penalty,
    )
    return dataset, reward_fn