""" FIX 9: Trajectory logging for GRPO training data collection. Per rulebook Section 5 & 6: Save episode trajectories to enable GRPO training. Each episode is saved as JSON with metadata, summary, and full trajectory. """ import json import os from pathlib import Path from datetime import datetime from typing import Dict, List, Any, Optional class TrajectoryLogger: """Save episode trajectories for GRPO training.""" def __init__(self, output_dir: str = "./episodes"): self.output_dir = Path(output_dir) self.output_dir.mkdir(parents=True, exist_ok=True) def save_episode( self, task: str, difficulty: str, success: bool, steps: int, rewards: List[float], trajectory: List[Dict[str, Any]], model: str = "unknown", elapsed_s: float = 0.0, ) -> str: """ Save one episode to JSON for GRPO training. Args: task: Task identifier (e.g., "easy", "problem_1") difficulty: Difficulty level (easy/medium/hard) success: Whether episode succeeded steps: Number of steps taken rewards: List of rewards per step trajectory: List of {observation, action, reward, done, test_score} model: Model name used elapsed_s: Total episode time Returns: Path to saved episode file """ episode = { "metadata": { "task": task, "difficulty": difficulty, "success": success, "model": model, "timestamp": datetime.now().isoformat(), "elapsed_s": round(elapsed_s, 3), }, "summary": { "steps": steps, "rewards": [round(r, 4) for r in rewards], "final_reward": round(rewards[-1], 4) if rewards else 0.0, "mean_reward": round(sum(rewards) / len(rewards), 4) if rewards else 0.0, "max_reward": round(max(rewards), 4) if rewards else 0.0, }, "trajectory": trajectory, } # Filename: difficulty_timestamp.json timestamp_str = datetime.now().strftime('%Y%m%d_%H%M%S') filename = f"{difficulty}_{task}_{timestamp_str}.json" filepath = self.output_dir / filename with open(filepath, 'w') as f: json.dump(episode, f, indent=2) return str(filepath) @staticmethod def load_episodes(output_dir: str = "./episodes") -> List[Dict[str, Any]]: """Load all saved episodes from directory.""" episodes = [] episode_dir = Path(output_dir) if not episode_dir.exists(): return episodes for json_file in sorted(episode_dir.glob("*.json")): try: with open(json_file, 'r') as f: episode = json.load(f) episodes.append(episode) except Exception as e: print(f"Failed to load {json_file}: {e}") return episodes