optigami / training /runner.py
sissississi's picture
new-environment (#5)
1e49495
"""
TrainingRunner — parallel episode executor for GRPO training.
Each episode runs in a ThreadPoolExecutor thread.
After every env.step(), observations are pushed to the broadcast server (fire-and-forget).
"""
from __future__ import annotations
import uuid
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Callable, Optional
from server.models import OrigamiAction
from server.origami_environment import OrigamiEnvironment
BroadcastFn = Callable[[str, dict], None]
def run_episode(
strategy_fn: Callable[[dict], dict],
task_name: str,
ep_id: Optional[str] = None,
broadcast_fn: Optional[BroadcastFn] = None,
max_steps: Optional[int] = None,
) -> dict:
"""Run a single origami episode with a given strategy function.
Args:
strategy_fn: Callable that receives paper_state dict and returns a fold dict:
{"type": "valley"|"mountain"|"pleat"|"crimp"|"stop",
"line": {"start": [x, y], "end": [x, y]},
"angle": 180.0}
task_name: Name of the task (from server/tasks.py)
ep_id: Episode identifier for broadcast; auto-generated if None
broadcast_fn: Optional callback(ep_id, data) for live streaming
max_steps: Override task's max_folds if provided
Returns:
dict with keys: episode_id, score, final_metrics, fold_history, status
"""
ep_id = ep_id or str(uuid.uuid4())[:8]
env = OrigamiEnvironment()
obs = env.reset(task_name=task_name)
if broadcast_fn:
broadcast_fn(ep_id, {
"type": "episode_update",
"episode_id": ep_id,
"task_name": task_name,
"step": 0,
"observation": _obs_to_dict(obs),
})
step_limit = max_steps or env._task.get("max_folds", 20) if env._task else 20
status = "done"
for step_idx in range(step_limit):
if obs.done:
break
# Strategy generates a fold dict
try:
fold_dict = strategy_fn(obs.paper_state)
except Exception as exc:
status = "error"
if broadcast_fn:
broadcast_fn(ep_id, {
"type": "episode_done",
"episode_id": ep_id,
"status": "error",
"score": obs.reward or 0.0,
"final_metrics": obs.metrics,
"error": str(exc),
})
break
fold_type = fold_dict.get("type", "valley")
fold_line = fold_dict.get("line", {"start": [0, 0.5], "end": [1, 0.5]})
fold_angle = float(fold_dict.get("angle", 180.0))
action = OrigamiAction(
fold_type=fold_type,
fold_line=fold_line,
fold_angle=fold_angle,
)
obs = env.step(action)
if broadcast_fn:
broadcast_fn(ep_id, {
"type": "episode_update",
"episode_id": ep_id,
"task_name": task_name,
"step": step_idx + 1,
"observation": _obs_to_dict(obs),
})
if obs.done:
break
else:
status = "timeout"
score = obs.reward if obs.reward is not None else (env._total_reward or 0.0)
if broadcast_fn:
broadcast_fn(ep_id, {
"type": "episode_done",
"episode_id": ep_id,
"status": status,
"score": float(score),
"final_metrics": obs.metrics,
})
return {
"episode_id": ep_id,
"score": float(score),
"final_metrics": obs.metrics,
"fold_history": obs.fold_history,
"status": status,
}
def run_batch(
strategy_fns: list[Callable[[dict], dict]],
task_name: str,
broadcast_fn: Optional[BroadcastFn] = None,
batch_id: Optional[int] = None,
max_workers: int = 8,
) -> list[dict]:
"""Run G episodes in parallel with a ThreadPoolExecutor.
Args:
strategy_fns: List of G strategy callables (one per completion)
task_name: Task to use for all episodes
broadcast_fn: Optional broadcast callback, called after each step
batch_id: Batch identifier for broadcast
max_workers: Max parallel threads (bounded by G)
Returns:
List of episode result dicts, in same order as strategy_fns
"""
n = len(strategy_fns)
ep_ids = [f"ep_{(batch_id or 0):04d}_{i:02d}" for i in range(n)]
workers = min(max_workers, n)
results: list[dict] = [{}] * n
with ThreadPoolExecutor(max_workers=workers) as pool:
futures = {
pool.submit(
run_episode,
fn,
task_name,
ep_ids[i],
broadcast_fn,
): i
for i, fn in enumerate(strategy_fns)
}
for future in as_completed(futures):
idx = futures[future]
try:
results[idx] = future.result()
except Exception as exc:
results[idx] = {
"episode_id": ep_ids[idx],
"score": 0.0,
"final_metrics": {},
"fold_history": [],
"status": "error",
"error": str(exc),
}
return results
def _obs_to_dict(obs) -> dict:
"""Convert OrigamiObservation to a JSON-serializable dict."""
try:
return obs.model_dump()
except AttributeError:
return {
"task": obs.task if hasattr(obs, "task") else {},
"paper_state": obs.paper_state if hasattr(obs, "paper_state") else {},
"metrics": obs.metrics if hasattr(obs, "metrics") else {},
"fold_history": obs.fold_history if hasattr(obs, "fold_history") else [],
"done": obs.done if hasattr(obs, "done") else False,
"reward": obs.reward if hasattr(obs, "reward") else None,
"error": obs.error if hasattr(obs, "error") else None,
}