# neural-mesh-v2/test/utils/rollout_saver.py
# hjkim00's picture
# Restore all essential files - code, configs, and MBPP/HumanEval data
# 24c2665 verified
"""
Rollout data saver utility for VeRL training
"""
import json
import os
from datetime import datetime
from typing import Any, Dict, List, Optional
class RolloutSaver:
    """Write rollout samples to per-step JSONL files right after generation.

    Every successful :meth:`save_rollout` call produces
    ``<save_dir>/<step>_rollout.jsonl`` containing one JSON object per prompt.
    """

    def __init__(self, save_dir: str):
        """Create the output directory if needed and reset the step counter.

        Args:
            save_dir: Directory that receives the ``*_rollout.jsonl`` files.
        """
        self.save_dir = save_dir
        os.makedirs(save_dir, exist_ok=True)
        # Supplies the step number whenever the caller does not pass one.
        self.step_counter = 0

    def save_rollout(self,
                     prompts: List[str],
                     responses: List[str],
                     scores: Optional[List[float]] = None,
                     step: Optional[int] = None,
                     extra_info: Optional[Dict[str, Any]] = None) -> Optional[str]:
        """Save one batch of rollout data to a JSONL file.

        Saving is best-effort: any failure is reported on stdout and signalled
        by a ``None`` return value instead of an exception, so a broken save
        does not interrupt the caller.

        Args:
            prompts: List of input prompts; one entry is written per prompt.
            responses: List of generated responses, at least as long as
                ``prompts`` (surplus responses are ignored).
            scores: Optional list of scores. Defaults to 0.0 for every sample;
                samples beyond the end of a shorter list also get 0.0.
            step: Training step number used in the file name. When omitted,
                the internal counter is used and then incremented.
            extra_info: Optional mapping merged into every entry. Its keys
                override the built-in fields on a name clash.

        Returns:
            The path of the written file, or ``None`` if saving failed.
        """
        if step is None:
            step = self.step_counter
            self.step_counter += 1
        if scores is None:
            scores = [0.0] * len(prompts)

        filename = os.path.join(self.save_dir, f"{step}_rollout.jsonl")
        try:
            if len(responses) < len(prompts):
                raise ValueError(
                    f"got {len(prompts)} prompts but only "
                    f"{len(responses)} responses"
                )

            # Serialize everything BEFORE touching the file so that a bad
            # entry (e.g. a non-JSON-serializable value in extra_info) cannot
            # leave a truncated file behind.
            lines = []
            for i, (prompt, response) in enumerate(zip(prompts, responses)):
                entry = {
                    "step": step,
                    "input": prompt,
                    "output": response,
                    "score": scores[i] if i < len(scores) else 0.0,
                    "saved_at": datetime.now().isoformat(),
                    "saved_after": "rollout",
                }
                # Add extra info if provided
                if extra_info:
                    entry.update(extra_info)
                lines.append(json.dumps(entry, ensure_ascii=False) + "\n")

            # ensure_ascii=False emits raw non-ASCII characters, so the file
            # encoding must be pinned instead of relying on the locale default.
            with open(filename, "w", encoding="utf-8") as f:
                f.writelines(lines)

            print(f"✅ Saved {len(prompts)} rollout samples to {filename}")
            return filename
        except Exception as e:
            # Deliberately broad: a failed save is reported, never raised.
            print(f"⚠️ Failed to save rollout data: {e}")
            return None