File size: 2,696 Bytes
e0b3d16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import verifiers as vf

"""
# install
vf-install complex-json-output (-p /path/to/environments)

# quick eval
vf-eval complex-json-output (-m model_name in endpoints.py)

inference:
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5 vf-vllm --model /raid/workspace/Mango/verifiers/MS3.2-0.35-Beta \
    --data-parallel-size 6 --enforce-eager --disable-log-requests

training:
CUDA_VISIBLE_DEVICES=6,7 accelerate launch --num-processes 2 \
    --config-file configs/zero3.yaml examples/grpo/train_complex_json_output.py
"""

# Hyperparameters: names of the attributes this script sets on `training_args`,
# grouped like the configuration sections further down (batch, generation,
# schedule, evaluation, checkpointing, GRPO) and flattened into one list.
# NOTE(review): nothing else in this file reads HPARAMS — presumably kept for
# logging or sweep tooling; confirm it is still needed.
HPARAMS = [
    name
    for section in (
        (
            "per_device_train_batch_size",
            "num_generations",
            "gradient_accumulation_steps",
        ),
        (
            "max_tokens",
            "max_seq_len",
            "max_prompt_length",
            "max_completion_length",
            "temperature",
        ),
        ("learning_rate", "max_steps", "warmup_steps"),
        ("eval_steps",),
        ("save_steps",),
        ("beta", "loss_type"),
    )
    for name in section
]

# Environment: the verifiers environment registered as "complex-json-output".
# Only an 8000-example training subset is requested to keep runs short, plus
# 50 evaluation examples.
vf_env = vf.load_environment(
    env_id="complex-json-output",
    num_train_examples=8000,  # subset keeps training time down
    num_eval_examples=50,
)

# Model: a local checkpoint. The run name passed to vf.grpo_defaults is built
# from the checkpoint's last path component, lower-cased
# (here "complex-json-grpo_ms3.2-0.35-beta").
model_name = "/raid/workspace/Mango/verifiers/MS3.2-0.35-Beta"
run_name = f"complex-json-grpo_{model_name.rsplit('/', 1)[-1].lower()}"

# Policy model plus its matching tokenizer, loaded from `model_name`.
model, tokenizer = vf.get_model_and_tokenizer(model_name)

# Training arguments: start from the verifiers GRPO defaults, then override.
training_args = vf.grpo_defaults(run_name=run_name)

# Batch configuration.
# GRPO compares each completion with the other completions sampled for the
# same prompt, so a whole group of `num_generations` completions must fit in
# one effective batch. In other words,
#   per_device_train_batch_size x num_processes x gradient_accumulation_steps
# must be evenly divisible by num_generations, which the trainer validates
# when it is constructed.
# With the 2-process launch shown in the header, 2 accumulation steps gave
# 2 x 2 x 2 = 8, which cannot hold a group of 16, so the run was rejected.
# 8 steps give 2 x 2 x 8 = 32, i.e. 2 prompts x 16 generations per step.
# Re-check this product whenever the number of training processes changes.
training_args.per_device_train_batch_size = 2
training_args.num_generations = 16
training_args.gradient_accumulation_steps = 8

# Generation parameters
# NOTE(review): max_tokens (2048) is smaller than max_completion_length (4096).
# If max_tokens caps each sampled completion, the larger completion budget is
# never reached — confirm which limit generation actually honours.
training_args.max_tokens = 2048  # JSON can be long
training_args.max_seq_len = 16000
training_args.max_prompt_length = 8192  # Allow long prompts (questions can be lengthy)
training_args.max_completion_length = 4096  # Allow long completions
training_args.temperature = 1.0  # Full diversity for exploration

# Training schedule
training_args.learning_rate = 5e-6
training_args.max_steps = 1000
training_args.warmup_steps = 15

# Evaluation
# eval_strategy="none" disables periodic evaluation, so eval_steps and
# per_device_eval_batch_size are currently inert. Switch eval_strategy to
# "steps" to evaluate every `eval_steps` steps.
training_args.eval_strategy = "none"
training_args.eval_steps = 50
training_args.per_device_eval_batch_size = 8

# Checkpointing: save every `save_steps` optimizer steps.
training_args.save_strategy = "steps"
training_args.save_steps = 100

# GRPO parameters
training_args.beta = 0.001  # Conservative KL penalty
training_args.loss_type = "dr_grpo"  # Recommended: no length bias

# Logging
training_args.logging_steps = 1
training_args.log_completions = True
training_args.num_completions_to_print = 3
training_args.report_to = "wandb"  # Log to Weights & Biases (set "none" to disable)

# Trainer: couples the training arguments and environment with the policy
# model and its tokenizer. A LoRA adapter from vf.lora_defaults (r=8, alpha=16)
# is used so that only a small set of adapter weights is trained.
trainer = vf.GRPOTrainer(
    args=training_args,
    env=vf_env,
    model=model,
    processing_class=tokenizer,
    peft_config=vf.lora_defaults(r=8, alpha=16),
)

# Start the training loop.
trainer.train()