"""Trajectory collection helpers for RL-style training loops."""

from __future__ import annotations

import json
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Callable, Dict, List, Optional

from client import PythonEnv
from models import PythonReviewAction


@dataclass
class TrajectoryStep:
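    """One recorded transition: the observation the policy saw, the action it took, and the resulting reward/done flag."""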
    observation: Dict[str, object]
    action: Dict[str, object]
    reward: float
    done: bool


@dataclass
class TrajectoryEpisode:
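    """A full episode plus the summary metrics reported by the environment."""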
    task_id: str
    snippet_id: str
    final_score: float
    cumulative_reward: float
    steps: List[TrajectoryStep]


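# A policy maps the most recent observation to the next PythonReviewAction to submit.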
PolicyFn = Callable[[object], PythonReviewAction]


def collect_episode(env: PythonEnv, task_id: str, policy: PolicyFn, max_steps: Optional[int] = None) -> TrajectoryEpisode:
    """Collect one benchmark episode for an external trainer."""

    result = env.reset(task_id=task_id)
    observation = result.observation
    steps: List[TrajectoryStep] = []
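    # Fall back to the environment's advertised step budget when no explicit cap is given.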
    step_limit = max_steps or observation.max_steps

    for _ in range(step_limit):
        action = policy(observation)
        result = env.step(action)
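        # Record the observation the policy acted on together with the resulting reward and done flag.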
        steps.append(
            TrajectoryStep(
                observation=observation.model_dump(),
                action=action.model_dump(exclude_none=True),
                reward=float(result.reward or 0.0),
                done=bool(result.done),
            )
        )
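        # Advance to the post-step observation so the next policy call sees the updated state.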
        observation = result.observation
        if result.done:
            break

    return TrajectoryEpisode(
        task_id=observation.task_id,
        snippet_id=observation.snippet_id,
        final_score=observation.metrics.current_score,
        cumulative_reward=observation.metrics.cumulative_reward,
        steps=steps,
    )


def write_jsonl(episodes: List[TrajectoryEpisode], output_path: str | Path) -> None:
    """Persist collected trajectories in a trainer-friendly JSONL format."""

    path = Path(output_path)
    lines = [json.dumps(asdict(episode)) for episode in episodes]
    # One JSON object per line; the trailing newline keeps the file valid JSONL for later appends.
    path.write_text("\n".join(lines) + "\n", encoding="utf-8")
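

# Example wiring (a minimal sketch, not exercised anywhere in this module). It assumes
# PythonEnv() can be constructed without arguments and that a task id such as
# "example-task" exists on the benchmark server; both are assumptions about the client,
# so adapt them to your deployment. `my_policy` stands in for any callable that maps an
# observation to a PythonReviewAction:
#
#     env = PythonEnv()
#     episodes = [collect_episode(env, task_id="example-task", policy=my_policy)]
#     write_jsonl(episodes, "trajectories.jsonl")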