File size: 2,227 Bytes
24c2665 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 |
"""
Rollout data saver utility for VeRL training
"""
import json
import os
from datetime import datetime
from typing import Any, Dict, List, Optional
class RolloutSaver:
    """Saves rollout data immediately after generation.

    Each call to :meth:`save_rollout` writes one JSON Lines file named
    ``<step>_rollout.jsonl`` inside ``save_dir``, with one JSON object per
    prompt/response pair.

    Attributes:
        save_dir: Directory the rollout files are written to.
        step_counter: Step number used (and then incremented) when
            ``save_rollout`` is called without an explicit ``step``.
    """

    def __init__(self, save_dir: str):
        """Remember the target directory and create it if it is missing.

        Args:
            save_dir: Directory for rollout files; created (including
                parents) when it does not exist yet.
        """
        self.save_dir = save_dir
        os.makedirs(save_dir, exist_ok=True)
        self.step_counter = 0

    def save_rollout(self,
                     prompts: List[str],
                     responses: List[str],
                     scores: Optional[List[float]] = None,
                     step: Optional[int] = None,
                     extra_info: Optional[Dict[str, Any]] = None) -> Optional[str]:
        """Save rollout data to a JSONL file.

        Saving is best-effort: failures are reported on stdout and signalled
        by a ``None`` return value instead of an exception. All entries are
        serialized before the file is opened, so a length mismatch or a
        non-serializable value does not leave a truncated file behind or
        clobber an earlier file for the same step.

        Args:
            prompts: List of input prompts; one entry is written per prompt.
            responses: List of generated responses, at least as long as
                ``prompts``. Responses beyond ``len(prompts)`` are ignored.
            scores: List of scores (optional). Missing scores, including
                the case where the list is shorter than ``prompts``,
                default to 0.0.
            step: Training step number used in the file name and in every
                entry. When omitted, the internal counter is used and then
                advanced by one.
            extra_info: Additional key/value pairs merged into every entry.
                Keys that collide with the built-in fields override them.

        Returns:
            The path of the written file, or ``None`` if saving failed.
        """
        if step is None:
            step = self.step_counter
            self.step_counter += 1
        if scores is None:
            scores = []

        filename = os.path.join(self.save_dir, f"{step}_rollout.jsonl")

        # Reject short `responses` up front: indexing past its end used to
        # abort the write halfway through and leave a partial file.
        sample_count = len(prompts)
        if len(responses) < sample_count:
            print(
                "⚠️ Failed to save rollout data: "
                f"got {sample_count} prompts but only {len(responses)} responses"
            )
            return None

        # Deliberately broad catch: saving rollouts must never take down the
        # caller, so any failure is reported and turned into `None`.
        try:
            lines = []
            for i, (prompt, response) in enumerate(zip(prompts, responses)):
                entry = {
                    "step": step,
                    "input": prompt,
                    "output": response,
                    "score": scores[i] if i < len(scores) else 0.0,
                    # Naive local time at the moment the entry is built.
                    "saved_at": datetime.now().isoformat(),
                    "saved_after": "rollout",
                }
                # Add extra info if provided
                if extra_info:
                    entry.update(extra_info)
                lines.append(json.dumps(entry, ensure_ascii=False) + "\n")

            # ensure_ascii=False emits raw non-ASCII characters, so the file
            # encoding must be pinned instead of depending on the locale.
            with open(filename, "w", encoding="utf-8") as f:
                f.writelines(lines)
        except Exception as e:
            print(f"⚠️ Failed to save rollout data: {e}")
            return None

        print(f"✅ Saved {sample_count} rollout samples to {filename}")
        return filename