disasim / utils.py
JonathanShiju12's picture
Upload folder using huggingface_hub
be150a7 verified
# utils.py
import re
import json
from typing import List, Dict, Any, Optional, Callable, Iterator
from torch.utils.data import IterableDataset
try:
from disasim import DisasimEnv
from disasim.server.actors.VertexICSActor import VertexICSActor
from disasim.models import DisasimAction
except ModuleNotFoundError:
from . import DisasimEnv
from models import DisasimAction
# ----------------------------------------------------------------------
# 1. Dataset: rolls out environment with dummy actions, yields prompt + combined snapshot
# ----------------------------------------------------------------------
class EnvPromptDataset(IterableDataset):
    """Iterable dataset that rolls an environment forward with a fixed action.

    Every sample is a dict with two keys:

    * ``"prompt"``: the value returned by ``actor._build_prompt(obs, state)``.
    * ``"snapshot"``: a dict holding ``env.get_snapshot()``,
      ``actor.get_snapshot()`` and the ``obs`` / ``state`` objects that were
      current when the prompt was built.

    Each pass over the dataset starts with ``env.reset()`` and resets again
    whenever a step returns an observation whose ``done`` attribute is truthy.

    Args:
        env: Environment that is reset, stepped and snapshotted.
        actor: Actor providing ``_build_prompt`` and ``get_snapshot``.
        total_steps: Number of samples produced by one pass.
        dummy_action: Action applied between consecutive samples. When
            ``None``, a ``DisasimAction`` with empty ``SAR_1``, ``MED_1`` and
            ``ENG_1`` lists is used.
    """

    # Project types are annotated as strings: annotations are otherwise
    # evaluated when the ``def`` executes, and ``VertexICSActor`` is not bound
    # on the module's fallback import path, which would raise NameError.
    def __init__(
        self,
        env: "DisasimEnv",
        actor: "VertexICSActor",
        total_steps: int = 1000,
        dummy_action: Optional["DisasimAction"] = None,
    ):
        super().__init__()
        self.env = env
        self.actor = actor
        self.total_steps = total_steps
        if dummy_action is None:
            dummy_action = DisasimAction(SAR_1=[], MED_1=[], ENG_1=[])
        self.dummy_action = dummy_action

    def __len__(self) -> int:
        """Return the number of samples one pass over the dataset yields."""
        return self.total_steps

    def __iter__(self) -> Iterator[dict]:
        """Yield ``total_steps`` prompt/snapshot samples from a fresh rollout."""
        obs = self.env.reset()
        state = self.env.state
        for step in range(self.total_steps):
            if step > 0:
                # Advance only when another sample is actually needed, so no
                # step is spent after the final sample has been produced.
                obs = self.env.step(self.dummy_action)
                state = self.env.state
                if obs.done:
                    obs = self.env.reset()
                    state = self.env.state
            # NOTE(review): obs and state are stored by reference; if the
            # environment mutates them in place, earlier samples would change
            # too. Confirm they are fresh objects per step.
            yield {
                "prompt": self.actor._build_prompt(obs, state),
                "snapshot": {
                    "env_snapshot": self.env.get_snapshot(),
                    "actor_snapshot": self.actor.get_snapshot(),
                    "obs": obs,
                    "state": state,
                },
            }
# ----------------------------------------------------------------------
# 2. Parse model completion to directives (JSON inside markdown)
# ----------------------------------------------------------------------
def parse_completion_to_directives(completion: str) -> dict:
    """Extract the first JSON object embedded in a model completion.

    The text is scanned for ``{`` characters and a JSON decode is attempted at
    each one in turn, so surrounding prose, Markdown code fences and stray
    braces before or after the object do not prevent it from being found.

    Args:
        completion: Raw completion text, optionally wrapped in a fenced
            ``json`` code block.

    Returns:
        The first JSON object that decodes successfully, as a dict. An empty
        dict is returned when ``completion`` is not a string or contains no
        decodable JSON object.
    """
    if not isinstance(completion, str):
        return {}

    decoder = json.JSONDecoder()
    start = completion.find("{")
    while start != -1:
        try:
            # raw_decode stops at the end of the object, so trailing text
            # (closing fences, commentary, further braces) is ignored.
            parsed, _ = decoder.raw_decode(completion, start)
        except ValueError:
            # json.JSONDecodeError is a ValueError; move on to the next brace.
            start = completion.find("{", start + 1)
            continue
        return parsed if isinstance(parsed, dict) else {}
    return {}
# ----------------------------------------------------------------------
# 3. Evaluate a single completion given a combined snapshot
# ----------------------------------------------------------------------
def evaluate_completion(
    env: "DisasimEnv",
    actor: "VertexICSActor",
    completion: str,
    snapshot: dict,
    parse_fn: Callable[[str], dict] = parse_completion_to_directives,
    penalty: float = -1.0,
) -> float:
    """Score one completion by replaying it from a stored snapshot.

    The environment and actor are restored from the snapshot first; the
    completion is then parsed and, if it yields directives, applied with a
    single ``env.step`` call.

    Args:
        env: Environment to restore and step.
        actor: Actor to restore.
        completion: Model output to score.
        snapshot: Dict with ``"env_snapshot"`` and ``"actor_snapshot"``
            entries. Other keys are ignored.
        parse_fn: Converts the completion into the directives passed to
            ``env.step``.
        penalty: Value returned when ``parse_fn`` yields nothing usable.

    Returns:
        ``penalty`` when the parsed directives are falsy, otherwise the
        ``reward`` attribute of the value returned by ``env.step``, as a float.

    Raises:
        KeyError: If ``snapshot`` lacks ``"env_snapshot"`` or
            ``"actor_snapshot"``.
    """
    # Project types are annotated as strings so this definition does not fail
    # with NameError when VertexICSActor was not imported (fallback imports).
    env.restore_snapshot(snapshot["env_snapshot"])
    actor.restore_snapshot(snapshot["actor_snapshot"])

    directives = parse_fn(completion)
    if not directives:
        return penalty

    # NOTE(review): the parsed dict is passed to env.step as-is, whereas
    # EnvPromptDataset steps with a DisasimAction instance. Confirm the
    # environment accepts a plain dict here.
    return float(env.step(directives).reward)
# ----------------------------------------------------------------------
# 4. Reward function factory for GRPOTrainer
# ----------------------------------------------------------------------
def make_reward_fn(
    scoring_env: "DisasimEnv",
    scoring_actor: "VertexICSActor",
    parse_fn: Callable[[str], dict] = parse_completion_to_directives,
    penalty: float = -1.0,
) -> Callable:
    """Build a batch reward function bound to a scoring environment and actor.

    Args:
        scoring_env: Environment restored and stepped for every scored
            completion.
        scoring_actor: Actor restored alongside the environment.
        parse_fn: Converts a completion into directives for ``env.step``.
        penalty: Reward assigned when a completion has no matching snapshot
            or parses to nothing usable.

    Returns:
        A callable ``reward_fn(prompts, completions, snapshot=None, **kwargs)``
        that returns one reward per completion.
    """
    # Project types are annotated as strings so this definition does not fail
    # with NameError when VertexICSActor was not imported (fallback imports).

    def reward_fn(
        prompts: List[str],
        completions: List[str],
        snapshot: Optional[List[dict]] = None,
        **kwargs,
    ) -> List[float]:
        """Score each completion against the snapshot at the same index.

        ``prompts`` and any extra keyword arguments are accepted but unused.
        NOTE(review): ``snapshot`` is presumably the dataset's ``"snapshot"``
        column forwarded by the trainer. Confirm against the trainer's
        reward-function calling convention.
        """
        snapshots = snapshot if snapshot is not None else []
        rewards: List[float] = []
        for index, completion in enumerate(completions):
            if index >= len(snapshots):
                # No snapshot to replay from, so the completion cannot be scored.
                rewards.append(penalty)
                continue
            # Same restore -> parse -> step sequence as a single evaluation,
            # so reuse the module's helper instead of repeating it here.
            rewards.append(
                evaluate_completion(
                    scoring_env,
                    scoring_actor,
                    completion,
                    snapshots[index],
                    parse_fn=parse_fn,
                    penalty=penalty,
                )
            )
        return rewards

    return reward_fn
# ----------------------------------------------------------------------
# 5. Convenience function: create collection + scoring envs, dataset, reward fn
# ----------------------------------------------------------------------
def build_grpo_components(
    ENV_URL: str,
    actor_class,
    actor_config: dict,
    total_dataset_steps: int = 2000,
    dummy_action: Optional[DisasimAction] = None,
) -> tuple:
    """Create the prompt dataset and reward function for a training run.

    Two ``DisasimEnv`` clients are created for ``ENV_URL`` and two actors are
    built from ``actor_class(**actor_config)``: one pair feeds the
    ``EnvPromptDataset``, the other is bound into the reward function.

    Args:
        ENV_URL: Base URL handed to both ``DisasimEnv`` instances.
        actor_class: Callable used to construct each actor.
        actor_config: Keyword arguments for ``actor_class``.
        total_dataset_steps: ``total_steps`` of the resulting dataset.
        dummy_action: Forwarded unchanged to ``EnvPromptDataset``.

    Returns:
        A ``(dataset, reward_fn)`` tuple.
    """
    # NOTE(review): both environments point at the same URL. Confirm the
    # server keeps their state separate, since scoring restores snapshots.
    rollout_env, rollout_actor = (
        DisasimEnv(base_url=ENV_URL),
        actor_class(**actor_config),
    )
    judge_env, judge_actor = (
        DisasimEnv(base_url=ENV_URL),
        actor_class(**actor_config),
    )
    return (
        EnvPromptDataset(
            env=rollout_env,
            actor=rollout_actor,
            total_steps=total_dataset_steps,
            dummy_action=dummy_action,
        ),
        make_reward_fn(scoring_env=judge_env, scoring_actor=judge_actor),
    )