rl_code_fix_env/src/reward/trajectory_logger.py
Viraj0112's picture
Upload folder using huggingface_hub
03a907a verified
"""
FIX 9: Trajectory logging for GRPO training data collection.
Per rulebook Section 5 & 6: Save episode trajectories to enable GRPO training.
Each episode is saved as JSON with metadata, summary, and full trajectory.
"""
import json
import os
from pathlib import Path
from datetime import datetime
from typing import Dict, List, Any, Optional
class TrajectoryLogger:
    """Persist episode trajectories as JSON files for GRPO training.

    Each call to ``save_episode`` writes one JSON file (metadata, summary and
    full trajectory) into ``output_dir``; ``load_episodes`` reads every
    ``*.json`` file of a directory back into memory.
    """

    # Path separators and characters that are invalid in file names on common
    # platforms; they are replaced so the episode always lands in output_dir.
    _UNSAFE_FILENAME_CHARS = frozenset('<>:"/\\|?*')

    def __init__(self, output_dir: str = "./episodes"):
        """Remember the output directory and create it if it is missing.

        Args:
            output_dir: Directory that will receive the episode files.
        """
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

    @classmethod
    def _safe_component(cls, value: Any) -> str:
        """Return ``value`` as text with filename-unsafe characters replaced by ``_``."""
        return "".join(
            "_" if (ch in cls._UNSAFE_FILENAME_CHARS or ord(ch) < 32) else ch
            for ch in str(value)
        )

    def save_episode(
        self,
        task: str,
        difficulty: str,
        success: bool,
        steps: int,
        rewards: List[float],
        trajectory: List[Dict[str, Any]],
        model: str = "unknown",
        elapsed_s: float = 0.0,
    ) -> str:
        """Save one episode to JSON for GRPO training.

        The file is named ``{difficulty}_{task}_{YYYYmmdd_HHMMSS}.json``. If a
        file with that name already exists (e.g. two episodes finished within
        the same second) a numeric suffix ``_001``, ``_002``, ... is appended
        instead of overwriting the earlier episode.

        Args:
            task: Task identifier (e.g., "easy", "problem_1").
            difficulty: Difficulty level (easy/medium/hard).
            success: Whether the episode succeeded.
            steps: Number of steps taken.
            rewards: List of rewards per step; may be empty.
            trajectory: List of {observation, action, reward, done, test_score}.
            model: Model name used.
            elapsed_s: Total episode time in seconds.

        Returns:
            Path to the saved episode file.

        Raises:
            TypeError: If the episode contains values that are not JSON
                serializable. No file is created in that case.
        """
        # Read the clock once so the metadata timestamp and the filename agree.
        now = datetime.now()

        if rewards:
            final_reward = round(rewards[-1], 4)
            mean_reward = round(sum(rewards) / len(rewards), 4)
            max_reward = round(max(rewards), 4)
        else:
            final_reward = mean_reward = max_reward = 0.0

        episode = {
            "metadata": {
                "task": task,
                "difficulty": difficulty,
                "success": success,
                "model": model,
                "timestamp": now.isoformat(),
                "elapsed_s": round(elapsed_s, 3),
            },
            "summary": {
                "steps": steps,
                "rewards": [round(r, 4) for r in rewards],
                "final_reward": final_reward,
                "mean_reward": mean_reward,
                "max_reward": max_reward,
            },
            "trajectory": trajectory,
        }

        # Serialize before touching the filesystem: a non-serializable
        # trajectory then raises without leaving a truncated JSON file behind.
        payload = json.dumps(episode, indent=2)

        # Filename: difficulty_task_timestamp.json
        stem = "{}_{}_{}".format(
            self._safe_component(difficulty),
            self._safe_component(task),
            now.strftime('%Y%m%d_%H%M%S'),
        )
        filepath = self.output_dir / f"{stem}.json"
        collisions = 0
        while True:
            try:
                # Mode "x" fails if the file exists, so an earlier episode is
                # never silently overwritten.
                with open(filepath, 'x', encoding='utf-8') as f:
                    f.write(payload)
            except FileExistsError:
                collisions += 1
                filepath = self.output_dir / f"{stem}_{collisions:03d}.json"
            else:
                return str(filepath)

    @staticmethod
    def load_episodes(output_dir: str = "./episodes") -> List[Dict[str, Any]]:
        """Load all saved episodes from a directory.

        Files are read in sorted filename order. A file that cannot be read
        or parsed is reported on stdout and skipped.

        Args:
            output_dir: Directory containing ``*.json`` episode files.

        Returns:
            The decoded episodes; an empty list if the directory is missing.
        """
        episodes: List[Dict[str, Any]] = []
        episode_dir = Path(output_dir)
        if not episode_dir.exists():
            return episodes
        for json_file in sorted(episode_dir.glob("*.json")):
            try:
                with open(json_file, 'r', encoding='utf-8') as f:
                    episode = json.load(f)
            except (OSError, ValueError) as e:
                # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
                print(f"Failed to load {json_file}: {e}")
            else:
                episodes.append(episode)
        return episodes