Spaces:
Sleeping
Sleeping
| """Trajectory collection helpers for RL-style training loops.""" | |
| from __future__ import annotations | |
| import json | |
| from dataclasses import dataclass, asdict | |
| from pathlib import Path | |
| from typing import Callable, Dict, List, Optional | |
| from client import PythonEnv | |
| from models import PythonReviewAction | |
@dataclass
class TrajectoryStep:
    """A single recorded transition of an episode.

    Instances are built by ``collect_episode`` (one per ``env.step`` call) and
    serialized with ``dataclasses.asdict``. That is why the ``@dataclass``
    decorator is required: it provides the keyword constructor and makes
    ``asdict`` work.
    """

    # Observation the policy acted on, dumped to a plain dict.
    observation: Dict[str, object]
    # Action chosen by the policy, dumped to a plain dict.
    action: Dict[str, object]
    # Reward reported for this step.
    reward: float
    # Whether the environment reported the episode as finished after this step.
    done: bool
@dataclass
class TrajectoryEpisode:
    """A complete collected episode, ready to be written out as one JSONL record.

    Built by ``collect_episode`` and serialized with ``dataclasses.asdict``.
    The ``@dataclass`` decorator is required for both the keyword constructor
    and ``asdict``.
    """

    # Task identifier, as read from the final observation.
    task_id: str
    # Snippet identifier, as read from the final observation.
    snippet_id: str
    # Score metric from the final observation.
    final_score: float
    # Cumulative reward metric from the final observation.
    cumulative_reward: float
    # Recorded transitions, in the order they were taken.
    steps: List[TrajectoryStep]
# Type of a policy: called with the current observation, returns the next action.
PolicyFn = Callable[
    [object],  # the observation handed to the policy
    PythonReviewAction,  # the action it chooses
]
def collect_episode(env, task_id: str, policy: PolicyFn, max_steps: Optional[int] = None) -> TrajectoryEpisode:
    """Collect one benchmark episode for an external trainer.

    Resets ``env`` for ``task_id``, then repeatedly asks ``policy`` for an
    action and applies it with ``env.step`` until the environment reports
    ``done`` or the step budget runs out.

    Args:
        env: Environment exposing ``reset(task_id=...)`` and ``step(action)``.
            Presumably a ``client.PythonEnv`` — TODO confirm.
        task_id: Task to reset the environment to.
        policy: Callable mapping the current observation to an action.
        max_steps: Step budget. ``None`` means "use ``observation.max_steps``
            from the reset observation". An explicit ``0`` yields an episode
            with no steps.

    Returns:
        A ``TrajectoryEpisode`` whose identifiers and metrics are read from
        the last observation seen, with one ``TrajectoryStep`` per
        ``env.step`` call.
    """
    result = env.reset(task_id=task_id)
    observation = result.observation
    steps: List[TrajectoryStep] = []
    # Compare against None explicitly. ``max_steps or ...`` would silently
    # replace an explicit budget of 0 with the environment default.
    step_limit = max_steps if max_steps is not None else observation.max_steps

    for _ in range(step_limit):
        action = policy(observation)
        result = env.step(action)
        steps.append(
            TrajectoryStep(
                # Record the observation the policy acted on, not the one step() returned.
                observation=observation.model_dump(),
                action=action.model_dump(exclude_none=True),
                # A missing/None reward is recorded as 0.0.
                reward=float(result.reward or 0.0),
                done=bool(result.done),
            )
        )
        observation = result.observation
        if result.done:
            break

    return TrajectoryEpisode(
        task_id=observation.task_id,
        snippet_id=observation.snippet_id,
        final_score=observation.metrics.current_score,
        cumulative_reward=observation.metrics.cumulative_reward,
        steps=steps,
    )
def write_jsonl(episodes: List[TrajectoryEpisode], output_path: str | Path) -> None:
    """Persist collected trajectories in a trainer-friendly JSONL format.

    Writes one JSON object per episode (via ``dataclasses.asdict``). Each
    object goes on its own newline-terminated line, and any existing file is
    overwritten. Missing parent directories are created.

    Args:
        episodes: Episodes to serialize. An empty list produces an empty file.
        output_path: Destination file path.
    """
    path = Path(output_path)
    path.parent.mkdir(parents=True, exist_ok=True)
    # JSONL expects every record to end in "\n". A bare "\n".join would leave
    # the last record unterminated, so appending to or concatenating output
    # files later would fuse two records onto one line.
    payload = "".join(json.dumps(asdict(episode)) + "\n" for episode in episodes)
    path.write_text(payload, encoding="utf-8")