AegisOpenEnv / verify_dataset_rollout.py
armaan020's picture
Upload folder using huggingface_hub
ae6e5e2 verified
"""
Verification script for dataset-augmented rollout function.
"""
from transformers import AutoTokenizer
from train import rollout_func, dataset_iterator
import trl.experimental.openenv
def verify():
    """Smoke-test ``rollout_func`` with text generation replaced by a stub.

    Loads the tokenizer for ``Qwen/Qwen2.5-0.5B-Instruct``, temporarily swaps
    ``trl.experimental.openenv.generate_rollout_completions`` for a stub that
    returns a single canned JSON completion, runs ``rollout_func`` on one
    placeholder prompt, and prints the ``env_reward`` entry of the result.

    The stub is uninstalled again before the result is reported (also when
    ``rollout_func`` raises), so the TRL module is not left in a mocked state
    for any later code running in the same process.
    """
    print("=== Verifying Dataset-Augmented Rollout ===")

    model_name = "Qwen/Qwen2.5-0.5B-Instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    def mock_gen(trainer, prompts):
        """Stand-in for generation: echo the first prompt, return one canned completion."""
        # Assumes prompts[0] is a string (it is sliced to 500 chars) -- TODO confirm
        # against what rollout_func actually passes in.
        print(f"\n[Generated Prompt Sample]:\n{prompts[0][:500]}...\n")
        fake_completion = '{"action_type": "APPROVE", "target_id": "ACC-001", "regulation_citation": "EU-Regulation"}'
        # Always a single-element list, regardless of how many prompts were given.
        return [{"prompt_ids": [1], "completion_ids": [1], "logprobs": [0.0], "text": fake_completion}]

    # NOTE(review): this patch only affects rollout_func if `train` resolves
    # generate_rollout_completions through the module attribute at call time;
    # a `from ... import generate_rollout_completions` in train.py would keep
    # the real function bound -- confirm.
    openenv_module = trl.experimental.openenv
    not_set = object()  # distinguishes "attribute absent" from any real value
    previous = getattr(openenv_module, "generate_rollout_completions", not_set)
    openenv_module.generate_rollout_completions = mock_gen
    try:
        print("Running rollout_func...")
        result = rollout_func(None, ["prompt"], tokenizer)
    finally:
        # Undo the monkeypatch so the stub does not leak past this check.
        if previous is not_set:
            del openenv_module.generate_rollout_completions
        else:
            openenv_module.generate_rollout_completions = previous

    print("\nRollout Finished!")
    print(f"Reward: {result['env_reward']}")
# Run the check only when executed as a script, not when imported.
if __name__ == "__main__":
    verify()