File size: 2,227 Bytes
24c2665
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
"""
Rollout data saver utility for VeRL training
"""
import os
import json
from typing import List, Any, Dict
from datetime import datetime


class RolloutSaver:
    """Saves rollout data immediately after generation.

    Each call to ``save_rollout`` writes one JSONL file named
    ``{step}_rollout.jsonl`` inside ``save_dir`` (overwriting any existing
    file for the same step).
    """

    def __init__(self, save_dir: str):
        """Create the saver and make sure ``save_dir`` exists.

        Args:
            save_dir: Directory the JSONL files are written to. It is created
                (including missing parents) if it does not exist yet.
        """
        self.save_dir = save_dir
        os.makedirs(save_dir, exist_ok=True)
        # Next auto-assigned step; only advanced when save_rollout() is
        # called with step=None.
        self.step_counter = 0

    def save_rollout(self,
                     prompts: List[str],
                     responses: List[str],
                     scores: List[float] = None,
                     step: int = None,
                     extra_info: Dict[str, Any] = None):
        """
        Save rollout data to a JSONL file (one JSON object per prompt).

        Args:
            prompts: List of input prompts; one line is written per prompt.
            responses: List of generated responses, matched to prompts by
                index. Extra responses beyond ``len(prompts)`` are ignored.
            scores: List of scores (optional). Missing or short lists are
                padded with 0.0.
            step: Training step number. If None, the internal step counter
                is used and then incremented.
            extra_info: Additional key/value pairs merged into every entry
                (may override the built-in keys).

        Returns:
            The path of the written file, or None if saving failed. Failures
            are reported with a printed warning rather than raised, so a
            logging problem never interrupts training.
        """
        if step is None:
            step = self.step_counter
            self.step_counter += 1

        if scores is None:
            scores = [0.0] * len(prompts)

        filename = os.path.join(self.save_dir, f"{step}_rollout.jsonl")

        try:
            # Serialize every entry before opening the file so that a bad
            # sample (too few responses, non-JSON-serializable extra_info)
            # fails cleanly instead of leaving a truncated file behind.
            lines = []
            for i, prompt in enumerate(prompts):
                entry = {
                    "step": step,
                    "input": prompt,
                    "output": responses[i],
                    "score": scores[i] if i < len(scores) else 0.0,
                    "saved_at": datetime.now().isoformat(),
                    "saved_after": "rollout",
                }

                # Add extra info if provided
                if extra_info:
                    entry.update(extra_info)

                lines.append(json.dumps(entry, ensure_ascii=False) + "\n")

            # ensure_ascii=False emits raw non-ASCII characters, so the file
            # must be UTF-8 explicitly; the platform default encoding
            # (e.g. cp1252 on Windows) could otherwise raise on write.
            with open(filename, "w", encoding="utf-8") as f:
                f.writelines(lines)

            print(f"✅ Saved {len(prompts)} rollout samples to {filename}")
            return filename

        except Exception as e:
            # Deliberate best-effort: report and carry on with training.
            print(f"⚠️ Failed to save rollout data: {e}")
            return None