File size: 1,258 Bytes
3b6ded8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from trl import GRPOConfig, GRPOTrainer


def train(tokenizer, model, reward_funcs, train_ds, **config_overrides):
    """Fine-tune ``model`` with TRL's ``GRPOTrainer`` and return the trainer.

    The built-in defaults describe a short (60-step), memory-conscious run.
    Any ``GRPOConfig`` field can be overridden or added by keyword, e.g.
    ``train(tok, model, rewards, ds, max_steps=500, output_dir="run1")``.

    Args:
        tokenizer: Passed to ``GRPOTrainer`` as ``processing_class``. It may
            also be a multimodal processor; the name is kept for
            backward compatibility with existing callers.
        model: The model to optimise; forwarded unchanged to ``GRPOTrainer``.
        reward_funcs: Reward function(s); forwarded unchanged to
            ``GRPOTrainer``.
        train_ds: Training dataset; forwarded unchanged to ``GRPOTrainer``.
            NOTE(review): presumably must provide the ``prompt`` column that
            TRL's GRPOTrainer expects — confirm against the data pipeline.
        **config_overrides: Extra or replacement ``GRPOConfig`` keyword
            arguments. They take precedence over the defaults below.

    Returns:
        The ``GRPOTrainer`` instance, after ``trainer.train()`` has returned.
    """
    config_kwargs = dict(
        learning_rate = 5e-6,
        adam_beta1 = 0.9,
        adam_beta2 = 0.99,
        weight_decay = 0.1,
        warmup_ratio = 0.1,
        lr_scheduler_type = "cosine",
        optim = "adamw_8bit",
        logging_steps = 1,
        per_device_train_batch_size = 1,
        gradient_accumulation_steps = 2, # Increase to 4 for smoother training
        # Completions sampled per prompt; decrease if out of memory.
        # NOTE(review): stock TRL validates that the generation batch is
        # divisible by num_generations; with a per-device batch of 1 this
        # presumably relies on Unsloth's patched GRPOConfig adjusting the
        # batch size — confirm when running without Unsloth.
        num_generations = 4,
        max_prompt_length = 1024,
        max_completion_length = 1024,
        # Sequence-level importance ratios (GSPO) instead of per-token ones.
        importance_sampling_level = "sequence",
        # Keep truncated completions in the loss rather than masking them.
        mask_truncated_completions = False,
        loss_type = 'dr_grpo',
        # A positive max_steps takes precedence over num_train_epochs. To train
        # by epochs, pass max_steps=-1 and num_train_epochs=N as overrides.
        max_steps = 60,
        # With the default max_steps, this writes one checkpoint at the last step.
        save_steps = 60,
        max_grad_norm = 0.1, # Aggressive clipping keeps RL updates small
        report_to = "none", # Can use wandb as per unsloth docs
        output_dir = "outputs",
    )
    # Caller-supplied values win over the defaults above.
    config_kwargs.update(config_overrides)
    training_args = GRPOConfig(**config_kwargs)

    trainer = GRPOTrainer(
        model=model,
        args=training_args,
        # A tokenizer, or a processor for multimodal models.
        processing_class=tokenizer,
        reward_funcs=reward_funcs,
        train_dataset=train_ds,
    )

    trainer.train()
    return trainer