|
|
""" |
|
|
Rollout data saver utility for VeRL training |
|
|
""" |
|
|
import os |
|
|
import json |
|
|
from typing import List, Any, Dict |
|
|
from datetime import datetime |
|
|
|
|
|
|
|
|
class RolloutSaver:
    """Saves rollout data immediately after generation.

    Each call to ``save_rollout`` writes one JSONL file named
    ``{step}_rollout.jsonl`` inside ``save_dir``, one JSON object per
    prompt/response pair.
    """

    def __init__(self, save_dir: str):
        """Remember the target directory and create it if needed.

        Args:
            save_dir: Directory that receives the rollout files. It is
                created (including missing parents) when absent.
        """
        self.save_dir = save_dir
        os.makedirs(save_dir, exist_ok=True)
        # Fallback step number, used only when the caller passes no ``step``.
        self.step_counter = 0

    def save_rollout(self,
                     prompts: List[str],
                     responses: List[str],
                     scores: List[float] = None,
                     step: int = None,
                     extra_info: Dict[str, Any] = None):
        """Save rollout data to a JSONL file.

        Args:
            prompts: List of input prompts.
            responses: List of generated responses; must contain at least
                one entry per prompt (surplus responses are ignored).
            scores: List of scores (optional). Missing scores, or a list
                shorter than ``prompts``, are filled with ``0.0``.
            step: Training step number. When omitted, the internal counter
                is used and then incremented.
            extra_info: Additional key/value pairs merged into every
                entry; keys that collide with the built-in fields
                override them.

        Returns:
            The path of the written file, or ``None`` if saving failed.
            Failures are reported on stdout and never raised, so a broken
            save cannot interrupt training.
        """
        if step is None:
            step = self.step_counter
            self.step_counter += 1

        if scores is None:
            scores = [0.0] * len(prompts)

        filename = os.path.join(self.save_dir, f"{step}_rollout.jsonl")

        try:
            # Validate before touching the disk so that a length mismatch
            # does not leave a truncated file behind.
            if len(responses) < len(prompts):
                raise ValueError(
                    f"got {len(responses)} responses for {len(prompts)} prompts"
                )

            # Explicit UTF-8: ensure_ascii=False emits raw non-ASCII text,
            # which the platform default encoding may be unable to encode.
            with open(filename, "w", encoding="utf-8") as f:
                for i, (prompt, response) in enumerate(zip(prompts, responses)):
                    entry = {
                        "step": step,
                        "input": prompt,
                        "output": response,
                        "score": scores[i] if i < len(scores) else 0.0,
                        "saved_at": datetime.now().isoformat(),
                        "saved_after": "rollout",
                    }

                    if extra_info:
                        entry.update(extra_info)

                    f.write(json.dumps(entry, ensure_ascii=False) + "\n")

            print(f"✅ Saved {len(prompts)} rollout samples to {filename}")
            return filename

        except Exception as e:
            # Deliberately broad: saving is best-effort and must not crash
            # the caller.
            print(f"⚠️ Failed to save rollout data: {e}")
            return None