Spaces:
Sleeping
Sleeping
| # utils.py | |
| import re | |
| import json | |
| from typing import List, Dict, Any, Optional, Callable, Iterator | |
| from torch.utils.data import IterableDataset | |
| try: | |
| from disasim import DisasimEnv | |
| from disasim.server.actors.VertexICSActor import VertexICSActor | |
| from disasim.models import DisasimAction | |
| except ModuleNotFoundError: | |
| from . import DisasimEnv | |
| from models import DisasimAction | |
| # ---------------------------------------------------------------------- | |
| # 1. Dataset: rolls out environment with dummy actions, yields prompt + combined snapshot | |
| # ---------------------------------------------------------------------- | |
class EnvPromptDataset(IterableDataset):
    """Iterable dataset that rolls an environment forward and yields prompts.

    Every item is a dict with two keys:

    * ``"prompt"``: the value returned by ``actor._build_prompt(obs, state)``.
    * ``"snapshot"``: a dict with ``env.get_snapshot()``,
      ``actor.get_snapshot()`` and the ``obs`` / ``state`` objects the prompt
      was built from.

    Between items the environment is advanced with ``dummy_action``; when the
    returned observation reports ``done`` the environment is reset.

    Project types are annotated as forward references (strings) so that
    defining this class does not evaluate names the module's fallback import
    branch never binds (``VertexICSActor``).
    """

    def __init__(
        self,
        env: "DisasimEnv",
        actor: "VertexICSActor",
        total_steps: int = 1000,
        dummy_action: Optional["DisasimAction"] = None,
    ) -> None:
        """Store the rollout environment, the prompt-building actor and limits.

        Args:
            env: Environment that is reset/stepped to produce observations.
            actor: Object providing ``_build_prompt`` and ``get_snapshot``.
            total_steps: Number of items one pass over the dataset yields.
            dummy_action: Action used to advance the environment; when omitted
                a ``DisasimAction`` with empty ``SAR_1``/``MED_1``/``ENG_1``
                lists is created.
        """
        super().__init__()
        self.env = env
        self.actor = actor
        self.total_steps = total_steps
        if dummy_action is None:
            dummy_action = DisasimAction(SAR_1=[], MED_1=[], ENG_1=[])
        self.dummy_action = dummy_action

    def __len__(self) -> int:
        """Return the number of items produced by one iteration pass."""
        return self.total_steps

    def __iter__(self) -> Iterator[dict]:
        """Yield ``total_steps`` prompt/snapshot dicts from a fresh rollout."""
        # reset() returns only the observation; the matching state is read
        # from the environment's ``state`` attribute.
        obs = self.env.reset()
        state = self.env.state
        for _ in range(self.total_steps):
            prompt = self.actor._build_prompt(obs, state)
            combined_snapshot = {
                "env_snapshot": self.env.get_snapshot(),
                "actor_snapshot": self.actor.get_snapshot(),
                # obs and state are stored next to the snapshots so a consumer
                # can see exactly what the prompt was built from.
                "obs": obs,
                "state": state,
            }
            yield {"prompt": prompt, "snapshot": combined_snapshot}

            # Advance the rollout with the placeholder action.
            obs = self.env.step(self.dummy_action)
            state = self.env.state
            if obs.done:
                # Episode finished: start a new one so the stream continues.
                obs = self.env.reset()
                state = self.env.state
| # ---------------------------------------------------------------------- | |
| # 2. Parse model completion to directives (JSON inside markdown) | |
| # ---------------------------------------------------------------------- | |
def parse_completion_to_directives(completion: str) -> dict:
    """Extract the first JSON object from a model completion.

    Markdown code fences (```` ```json ```` / ```` ``` ````) are stripped, then
    a JSON object is decoded starting at the first ``{``. Text that follows
    the object is ignored, so trailing commentary containing braces no longer
    invalidates an otherwise well-formed answer.

    Args:
        completion: Raw text produced by the model.

    Returns:
        The decoded object as a dict, or ``{}`` when ``completion`` is not a
        string, contains no ``{``, or no valid JSON object starts at the
        first ``{``.
    """
    if not isinstance(completion, str):
        # re.sub would raise TypeError on non-string input; treat it as
        # "nothing parseable" like every other failure mode.
        return {}

    clean = re.sub(r"```json|```", "", completion)
    start = clean.find("{")
    if start == -1:
        return {}

    try:
        # raw_decode stops at the end of the first complete JSON value
        # instead of requiring the rest of the string to be part of it.
        parsed, _ = json.JSONDecoder().raw_decode(clean, start)
    except json.JSONDecodeError:
        return {}
    return parsed if isinstance(parsed, dict) else {}
| # ---------------------------------------------------------------------- | |
| # 3. Evaluate a single completion given a combined snapshot | |
| # ---------------------------------------------------------------------- | |
def evaluate_completion(
    env: "DisasimEnv",
    actor: "VertexICSActor",
    completion: str,
    snapshot: dict,
    parse_fn: Callable[[str], dict] = parse_completion_to_directives,
    penalty: float = -1.0,
) -> float:
    """Score one completion by replaying it from a stored snapshot.

    The completion is parsed first; if parsing yields nothing usable the
    penalty is returned without touching the environment. Otherwise the
    environment and actor are restored from ``snapshot`` and the parsed
    directives are applied with a single ``env.step`` call.

    Project types are annotated as forward references so the definition does
    not evaluate ``VertexICSActor``, which the module's fallback import branch
    never binds.

    Args:
        env: Environment used for scoring; its state is overwritten.
        actor: Actor whose state is restored alongside the environment.
        completion: Raw model output.
        snapshot: Dict with at least ``"env_snapshot"`` and
            ``"actor_snapshot"`` entries.
        parse_fn: Converts the completion into directives; a falsy result
            means "unparseable".
        penalty: Value returned for an unparseable completion.

    Returns:
        ``penalty`` for an unparseable completion, otherwise the ``reward``
        attribute of the object returned by ``env.step`` as a float.
    """
    directives = parse_fn(completion)
    if not directives:
        # Nothing to execute, so skip the snapshot restores entirely.
        return penalty

    env.restore_snapshot(snapshot["env_snapshot"])
    actor.restore_snapshot(snapshot["actor_snapshot"])

    # NOTE(review): step() receives the parsed dict as-is, whereas
    # EnvPromptDataset steps with a DisasimAction instance -- confirm the
    # environment accepts both forms.
    new_obs = env.step(directives)
    return float(new_obs.reward)
| # ---------------------------------------------------------------------- | |
| # 4. Reward function factory for GRPOTrainer | |
| # ---------------------------------------------------------------------- | |
def make_reward_fn(
    scoring_env: "DisasimEnv",
    scoring_actor: "VertexICSActor",
    parse_fn: Callable[[str], dict] = parse_completion_to_directives,
    penalty: float = -1.0,
) -> Callable:
    """Build a batch reward function bound to a scoring environment and actor.

    Project types are annotated as forward references so the definition does
    not evaluate ``VertexICSActor``, which the module's fallback import branch
    never binds.

    Args:
        scoring_env: Environment restored and stepped for every scored
            completion.
        scoring_actor: Actor restored together with the environment.
        parse_fn: Converts a completion into directives; a falsy result means
            "unparseable".
        penalty: Reward assigned when a completion cannot be scored.

    Returns:
        A callable ``reward_fn(prompts, completions, snapshot=None, **kwargs)``
        returning one float per completion.
    """

    def reward_fn(
        prompts: List[str],
        completions: List[str],
        snapshot: Optional[List[dict]] = None,
        **kwargs,
    ) -> List[float]:
        """Return one reward per completion.

        A completion receives ``penalty`` when no snapshot exists at its index
        or when ``parse_fn`` returns a falsy value; otherwise its reward is the
        ``reward`` attribute of the object returned by ``scoring_env.step``.
        """
        snapshots = [] if snapshot is None else snapshot
        rewards: List[float] = []
        for index, completion in enumerate(completions):
            if index >= len(snapshots):
                rewards.append(penalty)
                continue

            directives = parse_fn(completion)
            if not directives:
                # Unparseable: no need to restore the scoring environment.
                rewards.append(penalty)
                continue

            snap = snapshots[index]
            scoring_env.restore_snapshot(snap["env_snapshot"])
            scoring_actor.restore_snapshot(snap["actor_snapshot"])
            # Only the restored snapshots are used here; the "obs"/"state"
            # entries of the snapshot dict are not read.
            new_obs = scoring_env.step(directives)
            rewards.append(float(new_obs.reward))
        return rewards

    return reward_fn
| # ---------------------------------------------------------------------- | |
| # 5. Convenience function: create collection + scoring envs, dataset, reward fn | |
| # ---------------------------------------------------------------------- | |
def build_grpo_components(
    ENV_URL: str,
    actor_class,
    actor_config: dict,
    total_dataset_steps: int = 2000,
    dummy_action: Optional[DisasimAction] = None,
) -> tuple:
    """Create the prompt dataset and reward function used for GRPO training.

    Two independent environment/actor pairs are built from the same URL and
    actor configuration: one drives the dataset rollout, the other is handed
    to the reward function for scoring.

    Args:
        ENV_URL: Passed as ``base_url`` to each ``DisasimEnv``.
        actor_class: Callable invoked as ``actor_class(**actor_config)``.
        actor_config: Keyword arguments for ``actor_class``.
        total_dataset_steps: ``total_steps`` of the resulting dataset.
        dummy_action: Forwarded unchanged to ``EnvPromptDataset``.

    Returns:
        A ``(dataset, reward_fn)`` tuple.
    """
    # Construction order: rollout env, rollout actor, scoring env, scoring actor.
    rollout_env, rollout_actor = (
        DisasimEnv(base_url=ENV_URL),
        actor_class(**actor_config),
    )
    reward_env, reward_actor = (
        DisasimEnv(base_url=ENV_URL),
        actor_class(**actor_config),
    )
    return (
        EnvPromptDataset(
            env=rollout_env,
            actor=rollout_actor,
            total_steps=total_dataset_steps,
            dummy_action=dummy_action,
        ),
        make_reward_fn(scoring_env=reward_env, scoring_actor=reward_actor),
    )