# GRPO With Replay Buffer

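This experimental trainer trains a model with GRPO, but replaces groups (and their corresponding completions) whose rewards have zero standard deviation with groups from prior batches that had high rewards and high standard deviation and were already used to train the model. Because GRPO normalizes advantages within each group, a group whose rewards are all identical has zero advantage and contributes no learning signal. Swapping in a previously seen, informative group keeps that part of the batch useful.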
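The replacement step described above can be pictured with a small, self-contained sketch. It is not TRL's `ReplayBuffer`. The class `SketchReplayBuffer`, the helper `replace_zero_std_groups`, the group dict layout (`{"rewards": [...]}`), the scoring rule (mean reward plus standard deviation) and uniform sampling are all hypothetical choices made for illustration. The buffer holds at most `max_size` groups in a min-heap, so the lowest-scoring group is the one evicted when a better group arrives.

```python
import heapq
import itertools
import random
import statistics


class SketchReplayBuffer:
    """Hypothetical bounded buffer that keeps the highest-scoring groups seen so far."""

    def __init__(self, max_size: int):
        self.max_size = max_size
        self._heap = []  # min-heap of (score, tie_breaker, group); lowest score sits at index 0
        self._counter = itertools.count()  # tie-breaker so group dicts are never compared

    def add(self, score: float, group: dict) -> None:
        item = (score, next(self._counter), group)
        if len(self._heap) < self.max_size:
            heapq.heappush(self._heap, item)
        elif score > self._heap[0][0]:
            # Evict the lowest-scoring group in favor of the new, better one
            heapq.heapreplace(self._heap, item)

    def sample(self, k: int) -> list:
        # Hypothetical: uniform sampling with replacement (the real class may weight by score)
        if not self._heap:
            return []
        return [random.choice(self._heap)[2] for _ in range(k)]

    def __len__(self) -> int:
        return len(self._heap)


def replace_zero_std_groups(groups: list, buffer: SketchReplayBuffer) -> list:
    """Swap zero-std groups for buffered ones; store informative groups for later reuse."""
    out = []
    for group in groups:
        std = statistics.pstdev(group["rewards"])
        if std == 0:
            replacement = buffer.sample(1)
            # With an empty buffer there is nothing to replay, so keep the original group
            out.append(replacement[0] if replacement else group)
        else:
            # Hypothetical score favoring groups with high reward and high spread
            buffer.add(statistics.fmean(group["rewards"]) + std, group)
            out.append(group)
    return out
```

Weighting `sample` by score, or dropping a group after it has been replayed, are equally plausible design choices; the `ReplayBuffer` reference at the end of this page documents what TRL actually does.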

## Usage

```python
import torch
from datasets import load_dataset

from trl.experimental.grpo_with_replay_buffer import GRPOWithReplayBufferConfig, GRPOWithReplayBufferTrainer

dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")


# Guarantee that some rewards have 0 std: about 25% of calls return identical rewards
def custom_reward_func(completions, **kwargs):
    if torch.rand(1).item() < 0.25:
        return [0] * len(completions)  # all-zero rewards -> zero-std group(s)
    else:
        return torch.rand(len(completions)).tolist()


training_args = GRPOWithReplayBufferConfig(
    output_dir="./tmp",
    learning_rate=1e-4,
    per_device_train_batch_size=4,
    num_generations=4,
    max_completion_length=8,
    replay_buffer_size=8,
    report_to="none",
)

trainer = GRPOWithReplayBufferTrainer(
    model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
    reward_funcs=[custom_reward_func],
    args=training_args,
    train_dataset=dataset,
)

trainer.train()
```
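To see which groups the trainer would treat as candidates for replacement, it helps to look at how a flat list of rewards splits into groups. The sketch below assumes, for illustration, that every `num_generations` consecutive rewards belong to the same prompt; the helper name `zero_std_group_mask` is hypothetical and not part of TRL, and the trainer's internal tensor layout may differ.

```python
import statistics


def zero_std_group_mask(rewards: list[float], num_generations: int) -> list[bool]:
    """Hypothetical helper: split rewards into consecutive groups and flag zero-std ones."""
    if len(rewards) % num_generations != 0:
        raise ValueError("len(rewards) must be a multiple of num_generations")
    groups = [rewards[i : i + num_generations] for i in range(0, len(rewards), num_generations)]
    # Population std of 0 means every completion in the group got the same reward
    return [statistics.pstdev(g) == 0 for g in groups]


# First group: all-zero rewards; second group: varied rewards
rewards = [0, 0, 0, 0, 0.2, 0.7, 0.1, 0.9]
print(zero_std_group_mask(rewards, num_generations=4))  # [True, False]
```

With a reward function like `custom_reward_func` above, any call that returns all zeros makes every group it covers show up as `True` in such a mask.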

## GRPOWithReplayBufferTrainer

[[autodoc]] experimental.grpo_with_replay_buffer.GRPOWithReplayBufferTrainer
    - train
    - save_model
    - push_to_hub

## GRPOWithReplayBufferConfig

[[autodoc]] experimental.grpo_with_replay_buffer.GRPOWithReplayBufferConfig

## ReplayBuffer

[[autodoc]] experimental.grpo_with_replay_buffer.ReplayBuffer