Spaces:
Running
Running
| """ | |
| FIX 9: Trajectory logging for GRPO training data collection. | |
| Per rulebook Section 5 & 6: Save episode trajectories to enable GRPO training. | |
| Each episode is saved as JSON with metadata, summary, and full trajectory. | |
| """ | |
| import json | |
| import os | |
| from pathlib import Path | |
| from datetime import datetime | |
| from typing import Dict, List, Any, Optional | |
class TrajectoryLogger:
    """Save episode trajectories as JSON files for GRPO training.

    Every call to ``save_episode`` writes exactly one new JSON file into
    ``output_dir``; existing files are never overwritten.
    """

    def __init__(self, output_dir: str = "./episodes"):
        """Remember the output directory and create it if it is missing.

        Args:
            output_dir: Directory that receives the episode files. Missing
                parent directories are created as well.
        """
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

    def save_episode(
        self,
        task: str,
        difficulty: str,
        success: bool,
        steps: int,
        rewards: List[float],
        trajectory: List[Dict[str, Any]],
        model: str = "unknown",
        elapsed_s: float = 0.0,
    ) -> str:
        """Save one episode to JSON for GRPO training.

        The file holds three sections: ``metadata`` (task, difficulty,
        success, model, timestamp, elapsed_s), ``summary`` (steps, rounded
        rewards and their final/mean/max) and the ``trajectory`` as given.

        Args:
            task: Task identifier (e.g., "easy", "problem_1").
            difficulty: Difficulty level (easy/medium/hard).
            success: Whether the episode succeeded.
            steps: Number of steps taken.
            rewards: List of rewards per step. May be empty, in which case
                the final/mean/max summary values are 0.0.
            trajectory: List of {observation, action, reward, done,
                test_score} dicts; must be JSON-serializable.
            model: Model name used.
            elapsed_s: Total episode time in seconds.

        Returns:
            Path to the saved episode file.

        Raises:
            TypeError: If the episode contains a value ``json`` cannot
                serialize. No file is created in that case.
            ValueError: If the episode contains a circular reference. No
                file is created in that case.
        """
        # Take the clock once so the metadata timestamp and the filename
        # always describe the same instant (naive local time).
        now = datetime.now()

        episode = {
            "metadata": {
                "task": task,
                "difficulty": difficulty,
                "success": success,
                "model": model,
                "timestamp": now.isoformat(),
                "elapsed_s": round(elapsed_s, 3),
            },
            "summary": {
                "steps": steps,
                "rewards": [round(r, 4) for r in rewards],
                "final_reward": round(rewards[-1], 4) if rewards else 0.0,
                "mean_reward": round(sum(rewards) / len(rewards), 4) if rewards else 0.0,
                "max_reward": round(max(rewards), 4) if rewards else 0.0,
            },
            "trajectory": trajectory,
        }

        # Serialize before touching the disk: a non-serializable trajectory
        # then raises without leaving a truncated .json file behind.
        payload = json.dumps(episode, indent=2)

        # Filename: difficulty_task_timestamp.json
        stem = f"{difficulty}_{task}_{now.strftime('%Y%m%d_%H%M%S')}"
        # A path separator inside task/difficulty would point into a
        # sub-directory that does not exist; keep the file in output_dir.
        for sep in (os.sep, os.altsep):
            if sep:
                stem = stem.replace(sep, "_")

        # The timestamp only has second resolution, so episodes saved within
        # the same second would share a name. Exclusive-create mode ("x")
        # refuses to overwrite; on a clash append _1, _2, ... until free.
        filepath = self.output_dir / f"{stem}.json"
        attempt = 0
        while True:
            try:
                with open(filepath, "x", encoding="utf-8") as f:
                    f.write(payload)
                break
            except FileExistsError:
                attempt += 1
                filepath = self.output_dir / f"{stem}_{attempt}.json"

        return str(filepath)
def load_episodes(output_dir: str = "./episodes") -> List[Dict[str, Any]]:
    """Load all saved episodes from a directory.

    Reads every ``*.json`` file directly inside ``output_dir`` (no
    recursion) in sorted path order. Loading is best-effort: a file that
    cannot be read or parsed is reported on stdout and skipped.

    Args:
        output_dir: Directory containing the episode JSON files.

    Returns:
        The parsed content of each loadable file, in sorted path order.
        Empty if the directory does not exist or holds no loadable file.
    """
    episodes: List[Dict[str, Any]] = []
    episode_dir = Path(output_dir)
    if not episode_dir.exists():
        return episodes

    for json_file in sorted(episode_dir.glob("*.json")):
        try:
            # JSON text is UTF-8; do not depend on the platform's default
            # encoding, which differs between machines.
            with open(json_file, 'r', encoding='utf-8') as f:
                episode = json.load(f)
        except Exception as e:
            # Deliberately broad: one bad file must not abort the whole load.
            print(f"Failed to load {json_file}: {e}")
        else:
            episodes.append(episode)
    return episodes