File size: 3,266 Bytes
03a907a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
"""

FIX 9: Trajectory logging for GRPO training data collection.



Per rulebook Section 5 & 6: Save episode trajectories to enable GRPO training.

Each episode is saved as JSON with metadata, summary, and full trajectory.

"""

import json
import os
from pathlib import Path
from datetime import datetime
from typing import Dict, List, Any, Optional


class TrajectoryLogger:
    """Save episode trajectories to JSON files for GRPO training.

    Each call to ``save_episode`` writes one ``.json`` file into
    ``output_dir``; ``load_episodes`` reads such a directory back.
    """

    def __init__(self, output_dir: str = "./episodes"):
        """Create the logger and make sure *output_dir* exists.

        Args:
            output_dir: Directory episode files are written to. It is created
                (including missing parents) if it does not exist.
        """
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def _safe_name(part: str) -> str:
        """Return *part* made safe to embed in a single filename component.

        Every character that is not alphanumeric or one of ``-``, ``_``, ``.``
        is replaced by ``_``, so values such as ``"a/b"`` cannot create
        sub-paths or escape ``output_dir``. An empty result becomes ``"_"``.
        """
        cleaned = "".join(c if c.isalnum() or c in "-_." else "_" for c in str(part))
        return cleaned or "_"

    def save_episode(
        self,
        task: str,
        difficulty: str,
        success: bool,
        steps: int,
        rewards: List[float],
        trajectory: List[Dict[str, Any]],
        model: str = "unknown",
        elapsed_s: float = 0.0,
    ) -> str:
        """Save one episode to a JSON file for GRPO training.

        Args:
            task: Task identifier (e.g. "easy", "problem_1").
            difficulty: Difficulty level (easy/medium/hard).
            success: Whether the episode succeeded.
            steps: Number of steps taken.
            rewards: Reward obtained at each step.
            trajectory: Per-step records, e.g.
                {observation, action, reward, done, test_score}; must be
                JSON-serializable.
            model: Name of the model used.
            elapsed_s: Total episode wall time in seconds.

        Returns:
            Path of the saved file as a string. The filename has the form
            ``{difficulty}_{task}_{YYYYmmdd_HHMMSS}.json`` (with unsafe
            characters in difficulty/task replaced by ``_``). If that file
            already exists, e.g. two episodes saved within the same second,
            a suffix ``_1``, ``_2``, ... is appended instead of overwriting.

        Raises:
            TypeError / ValueError: If the episode cannot be JSON-serialized
                (unserializable values, circular references). No file is
                created in that case.
            OSError: If the file cannot be written.
        """
        # Read the clock once so the metadata timestamp and the filename agree.
        now = datetime.now()
        episode = {
            "metadata": {
                "task": task,
                "difficulty": difficulty,
                "success": success,
                "model": model,
                "timestamp": now.isoformat(),
                "elapsed_s": round(elapsed_s, 3),
            },
            "summary": {
                "steps": steps,
                "rewards": [round(r, 4) for r in rewards],
                "final_reward": round(rewards[-1], 4) if rewards else 0.0,
                "mean_reward": round(sum(rewards) / len(rewards), 4) if rewards else 0.0,
                "max_reward": round(max(rewards), 4) if rewards else 0.0,
            },
            "trajectory": trajectory,
        }

        # Serialize before touching the filesystem: streaming json.dump into an
        # open file would leave a truncated .json behind on failure, which
        # load_episodes would then be unable to parse.
        payload = json.dumps(episode, indent=2)

        # Filename: {difficulty}_{task}_{timestamp}[_{n}].json
        stem = (
            f"{self._safe_name(difficulty)}_{self._safe_name(task)}_"
            f"{now.strftime('%Y%m%d_%H%M%S')}"
        )
        attempt = 0
        while True:
            filename = f"{stem}.json" if attempt == 0 else f"{stem}_{attempt}.json"
            filepath = self.output_dir / filename
            try:
                # Mode "x" (exclusive create) fails rather than silently
                # overwriting an episode saved earlier in the same second.
                with open(filepath, "x", encoding="utf-8") as f:
                    f.write(payload)
            except FileExistsError:
                attempt += 1
                continue
            return str(filepath)

    @staticmethod
    def load_episodes(output_dir: str = "./episodes") -> List[Dict[str, Any]]:
        """Load every ``*.json`` episode file found directly in *output_dir*.

        Files are read in sorted filename order. A file that cannot be read or
        is not valid JSON is reported on stdout and skipped, so one corrupt
        file does not prevent loading the rest.

        Args:
            output_dir: Directory to scan.

        Returns:
            The parsed episodes; an empty list if the directory does not exist.
        """
        episodes: List[Dict[str, Any]] = []
        episode_dir = Path(output_dir)

        if not episode_dir.exists():
            return episodes

        for json_file in sorted(episode_dir.glob("*.json")):
            try:
                with open(json_file, "r", encoding="utf-8") as f:
                    episodes.append(json.load(f))
            except (OSError, ValueError) as e:
                # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
                print(f"Failed to load {json_file}: {e}")

        return episodes