python_env / rollout.py
darshanajudiya7's picture
Upload folder using huggingface_hub
d25ab77 verified
"""Trajectory collection helpers for RL-style training loops."""
from __future__ import annotations
import json
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Callable, Dict, List, Optional
from client import PythonEnv
from models import PythonReviewAction
@dataclass()
class TrajectoryStep:
    """One transition recorded during an episode.

    As filled in by ``collect_episode`` in this module, each step pairs the
    observation the policy acted on with the action it chose and the
    environment's response to that action.
    """

    observation: Dict[str, object]  # ``model_dump()`` of the pre-action observation
    action: Dict[str, object]  # ``model_dump(exclude_none=True)`` of the chosen action
    reward: float  # reward reported by ``env.step`` (a missing reward is stored as 0.0)
    done: bool  # whether ``env.step`` flagged the episode as finished
@dataclass()
class TrajectoryEpisode:
    """Summary of one collected episode together with its per-step transitions.

    ``write_jsonl`` serializes instances one-per-line via ``dataclasses.asdict``,
    so these field names double as the JSONL record keys.
    """

    task_id: str  # task identifier reported by the final observation
    snippet_id: str  # snippet identifier reported by the final observation
    final_score: float  # ``metrics.current_score`` of the final observation
    cumulative_reward: float  # ``metrics.cumulative_reward`` of the final observation
    steps: List[TrajectoryStep]  # transitions in the order they were taken
# Signature expected of a policy: it receives the current observation (the
# ``.observation`` of whatever ``env.reset`` / ``env.step`` returned) and
# returns the next ``PythonReviewAction`` to pass to ``env.step``.
PolicyFn = Callable[[object], PythonReviewAction]
def collect_episode(env, task_id: str, policy: PolicyFn, max_steps: Optional[int] = None) -> TrajectoryEpisode:
    """Collect one benchmark episode for an external trainer.

    Resets ``env`` to ``task_id``, then repeatedly asks ``policy`` for an
    action and feeds it to ``env.step``. This continues until the environment
    reports ``done`` or the step budget is exhausted.

    Args:
        env: Environment whose ``reset(task_id=...)`` result exposes
            ``observation``. Its ``step(action)`` result must expose
            ``observation``, ``reward`` and ``done``. (Presumably a
            ``PythonEnv`` client; any object with that shape works.)
        task_id: Identifier of the task to reset the environment to.
        policy: Callable mapping the current observation to a
            ``PythonReviewAction``.
        max_steps: Optional cap on the number of steps. ``None`` (the default)
            uses the ``max_steps`` advertised by the initial observation. An
            explicit ``0`` collects no steps.

    Returns:
        A ``TrajectoryEpisode``. Its identifiers and metrics are read from the
        last observation returned by the environment. Its ``steps`` holds one
        entry per ``env.step`` call.
    """
    result = env.reset(task_id=task_id)
    observation = result.observation
    steps: List[TrajectoryStep] = []
    # Compare against None explicitly: ``max_steps or ...`` would silently
    # replace an explicit 0 with the environment's own budget.
    step_limit = observation.max_steps if max_steps is None else max_steps
    for _ in range(step_limit):
        action = policy(observation)
        result = env.step(action)
        steps.append(
            TrajectoryStep(
                # Record the observation the policy acted on, not the one
                # produced by this step.
                observation=observation.model_dump(),
                action=action.model_dump(exclude_none=True),
                # Treat a missing (falsy) reward as 0.0 and coerce ``done`` to
                # a plain bool so the step serializes cleanly.
                reward=float(result.reward or 0.0),
                done=bool(result.done),
            )
        )
        observation = result.observation
        if result.done:
            break
    return TrajectoryEpisode(
        task_id=observation.task_id,
        snippet_id=observation.snippet_id,
        final_score=observation.metrics.current_score,
        cumulative_reward=observation.metrics.cumulative_reward,
        steps=steps,
    )
def write_jsonl(episodes: List[TrajectoryEpisode], output_path: str | Path) -> None:
    """Persist collected trajectories in a trainer-friendly JSONL format.

    Each episode is converted with ``dataclasses.asdict`` and written as one
    JSON object per line. Every line, including the last, is terminated with
    ``"\\n"``, so files can be concatenated or appended to without merging
    records. An empty ``episodes`` list produces an empty file.

    Args:
        episodes: Episodes to serialize, written in iteration order.
        output_path: Destination file. It is overwritten if it exists, and
            missing parent directories are created.
    """
    path = Path(output_path)
    # Serialize every record before touching the filesystem. A record that
    # fails ``json.dumps`` then leaves any existing file intact.
    payload = "".join(json.dumps(asdict(episode)) + "\n" for episode in episodes)
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(payload, encoding="utf-8")