File size: 6,069 Bytes
1e49495
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
"""
TrainingRunner — parallel episode executor for GRPO training.

Each episode runs in a ThreadPoolExecutor thread.
After every env.step(), observations are pushed to the broadcast server (fire-and-forget).
"""
from __future__ import annotations

import uuid
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Callable, Optional

from server.models import OrigamiAction
from server.origami_environment import OrigamiEnvironment


# Signature of the optional live-streaming callback:
# broadcast_fn(episode_id, message_dict) -> None.
BroadcastFn = Callable[
    [str, dict],
    None,
]


def run_episode(
    strategy_fn: Callable[[dict], dict],
    task_name: str,
    ep_id: Optional[str] = None,
    broadcast_fn: Optional[BroadcastFn] = None,
    max_steps: Optional[int] = None,
) -> dict:
    """Run a single origami episode with a given strategy function.

    The strategy is queried once per step with the current ``paper_state``;
    its fold is applied via ``env.step()`` until the environment reports
    ``done``, the step budget runs out, or the strategy fails.

    Args:
        strategy_fn: Callable that receives paper_state dict and returns a fold dict:
                     {"type": "valley"|"mountain"|"pleat"|"crimp"|"stop",
                      "line": {"start": [x, y], "end": [x, y]},
                      "angle": 180.0}
                     Missing keys default to type "valley", line
                     {"start": [0, 0.5], "end": [1, 0.5]} and angle 180.0.
        task_name: Name of the task (from server/tasks.py)
        ep_id: Episode identifier for broadcast; auto-generated if None
        broadcast_fn: Optional callback(ep_id, data) for live streaming. It
                      receives an "episode_update" after reset and after every
                      step, then exactly one "episode_done" message.
        max_steps: Override task's max_folds if provided. A falsy value falls
                   back to the task's max_folds (or 20 when there is no task).

    Returns:
        dict with keys: episode_id, score, final_metrics, fold_history, status
        ("done", "timeout" or "error"). When status is "error" an extra
        "error" key holds the message of the exception raised while calling
        the strategy or turning its output into an OrigamiAction.

    Raises:
        Exceptions from env.reset(), env.step() or broadcast_fn propagate;
        only strategy-side failures are converted into an "error" status.
    """
    ep_id = ep_id or str(uuid.uuid4())[:8]
    env = OrigamiEnvironment()

    def emit(payload: dict) -> None:
        # Forward one message to the live-streaming callback when one is set.
        if broadcast_fn:
            broadcast_fn(ep_id, payload)

    obs = env.reset(task_name=task_name)
    emit({
        "type": "episode_update",
        "episode_id": ep_id,
        "task_name": task_name,
        "step": 0,
        "observation": _obs_to_dict(obs),
    })

    # An explicit override must win even when the env exposes no task, so
    # the override check is kept separate from the task-default lookup.
    if max_steps:
        step_limit = max_steps
    elif env._task:
        step_limit = env._task.get("max_folds", 20)
    else:
        step_limit = 20

    status = "done"
    error: Optional[str] = None

    for step_idx in range(step_limit):
        if obs.done:
            break

        # Everything derived from the strategy's output lives in the try.
        # A non-dict return, a non-numeric angle, or a fold the action model
        # rejects is a strategy failure, not a runner crash.
        try:
            fold_dict = strategy_fn(obs.paper_state)
            action = OrigamiAction(
                fold_type=fold_dict.get("type", "valley"),
                fold_line=fold_dict.get("line", {"start": [0, 0.5], "end": [1, 0.5]}),
                fold_angle=float(fold_dict.get("angle", 180.0)),
            )
        except Exception as exc:
            status = "error"
            error = str(exc)
            break

        obs = env.step(action)
        emit({
            "type": "episode_update",
            "episode_id": ep_id,
            "task_name": task_name,
            "step": step_idx + 1,
            "observation": _obs_to_dict(obs),
        })

        if obs.done:
            break
    else:
        # The step budget was exhausted without the environment finishing.
        status = "timeout"

    # Prefer the last observation's reward; fall back to the env's running total.
    score = float(obs.reward if obs.reward is not None else (env._total_reward or 0.0))

    # Exactly one terminal message for every outcome, including errors.
    done_msg = {
        "type": "episode_done",
        "episode_id": ep_id,
        "status": status,
        "score": score,
        "final_metrics": obs.metrics,
    }
    if error is not None:
        done_msg["error"] = error
    emit(done_msg)

    result = {
        "episode_id": ep_id,
        "score": score,
        "final_metrics": obs.metrics,
        "fold_history": obs.fold_history,
        "status": status,
    }
    if error is not None:
        result["error"] = error
    return result


def run_batch(
    strategy_fns: list[Callable[[dict], dict]],
    task_name: str,
    broadcast_fn: Optional[BroadcastFn] = None,
    batch_id: Optional[int] = None,
    max_workers: int = 8,
) -> list[dict]:
    """Run G episodes in parallel with a ThreadPoolExecutor.

    Args:
        strategy_fns: List of G strategy callables (one per completion)
        task_name: Task to use for all episodes
        broadcast_fn: Optional broadcast callback, called after each step
        batch_id: Batch identifier for broadcast (treated as 0 when None)
        max_workers: Max parallel threads (bounded by G)

    Returns:
        List of episode result dicts, in same order as strategy_fns. If an
        episode's run_episode call raises, its slot holds a placeholder with
        score 0.0, status "error" and the exception message under "error".
        An empty strategy_fns yields an empty list.
    """
    n = len(strategy_fns)
    if n == 0:
        # ThreadPoolExecutor rejects max_workers=0, so there is nothing to run.
        return []

    ep_ids = [f"ep_{(batch_id or 0):04d}_{i:02d}" for i in range(n)]
    workers = min(max_workers, n)

    # One slot per episode; futures complete out of order, so results are
    # written by index to preserve the input ordering.
    results: list[Optional[dict]] = [None] * n

    with ThreadPoolExecutor(max_workers=workers) as pool:
        futures = {
            pool.submit(run_episode, fn, task_name, ep_ids[i], broadcast_fn): i
            for i, fn in enumerate(strategy_fns)
        }

        for future in as_completed(futures):
            idx = futures[future]
            try:
                results[idx] = future.result()
            except Exception as exc:
                results[idx] = {
                    "episode_id": ep_ids[idx],
                    "score": 0.0,
                    "final_metrics": {},
                    "fold_history": [],
                    "status": "error",
                    "error": str(exc),
                }

    return results


def _obs_to_dict(obs) -> dict:
    """Serialize an observation into a plain dict for broadcasting.

    Objects exposing ``model_dump()`` (presumably pydantic models) are dumped
    directly. Anything else falls back to a fixed set of fields, each
    replaced by an empty/neutral default when the attribute is absent.
    """
    try:
        return obs.model_dump()
    except AttributeError:
        pass

    # (field name, value used when obs lacks the attribute). The tuple is
    # rebuilt on every call, so the mutable defaults are never shared.
    fallback_fields = (
        ("task", {}),
        ("paper_state", {}),
        ("metrics", {}),
        ("fold_history", []),
        ("done", False),
        ("reward", None),
        ("error", None),
    )
    return {name: getattr(obs, name, default) for name, default in fallback_fields}