| --- |
| title: "RLHF (Beta)" |
| description: "Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human feedback." |
| back-to-top-navigation: true |
| toc: true |
| toc-expand: 2 |
| toc-depth: 4 |
| --- |
|
|
| ## Overview |
|
|
Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human
feedback. Supported methods include, but are not limited to:
|
|
- [Direct Preference Optimization (DPO)](#dpo)
- [Identity Preference Optimization (IPO)](#ipo)
- [Kahneman-Tversky Optimization (KTO)](#kto)
- [Odds Ratio Preference Optimization (ORPO)](#orpo)
- [Group Relative Policy Optimization (GRPO)](#grpo)
- [Simple Preference Optimization (SimPO)](#simpo)
- Proximal Policy Optimization (PPO) (not yet supported in axolotl)
|
|
|
|
| ## RLHF using Axolotl |
|
|
| ::: {.callout-important} |
| This is a BETA feature and many features are not fully implemented. You are encouraged to open new PRs to improve the integration and functionality. |
| ::: |
|
|
We rely on the [TRL](https://github.com/huggingface/trl) library for the implementations of the various RL training methods, which we wrap to expose in axolotl. Each method has its own supported ways of loading datasets and prompt formats.
|
|
| ::: {.callout-tip} |
You can find what each method supports by looking in `src/axolotl/prompt_strategies/{method}`, where `{method}` is one of our supported methods. The corresponding `type:` value takes the form `{method}.{function_name}`.
| ::: |
|
|
| ### DPO |
|
|
| Example config: |
|
|
| ```yaml |
| rl: dpo |
| datasets: |
| - path: Intel/orca_dpo_pairs |
| split: train |
| type: chatml.intel |
| - path: argilla/ultrafeedback-binarized-preferences |
| split: train |
| type: chatml |
| ``` |
|
|
DPO supports the following types, each with the dataset format shown below:
|
|
| #### chatml.argilla |
|
|
| ```json |
| { |
| "system": "...", // optional |
| "instruction": "...", |
| "chosen_response": "...", |
| "rejected_response": "..." |
| } |
| ``` |
|
|
| #### chatml.argilla_chat |
|
|
| ```json |
| { |
| "chosen": [ |
| {"role": "user", "content": "..."}, |
| {"role": "assistant", "content": "..."} |
| ], |
| "rejected": [ |
| {"role": "user", "content": "..."}, |
| {"role": "assistant", "content": "..."} |
| ] |
| } |
| ``` |
|
|
| #### chatml.icr |
|
|
| ```json |
| { |
| "system": "...", // optional |
| "input": "...", |
| "chosen": "...", |
| "rejected": "..." |
| } |
| ``` |
|
|
| #### chatml.intel |
|
|
| ```json |
| { |
| "system": "...", // optional |
| "question": "...", |
| "chosen": "...", |
| "rejected": "..." |
| } |
| ``` |
|
|
| #### chatml.prompt_pairs |
|
|
| ```json |
| { |
| "system": "...", // optional |
| "prompt": "...", |
| "chosen": "...", |
| "rejected": "..." |
| } |
| ``` |
|
|
| #### chatml.ultra |
|
|
| ```json |
| { |
| "system": "...", // optional |
| "prompt": "...", |
| "chosen": [ |
| {"role": "user", "content": "..."}, |
| {"role": "assistant", "content": "..."} |
| ], |
| "rejected": [ |
| {"role": "user", "content": "..."}, |
| {"role": "assistant", "content": "..."} |
| ] |
| } |
| ``` |
|
|
| #### llama3.argilla |
|
|
| ```json |
| { |
| "system": "...", // optional |
| "instruction": "...", |
| "chosen_response": "...", |
| "rejected_response": "..." |
| } |
| ``` |
|
|
| #### llama3.argilla_chat |
|
|
| ```json |
| { |
| "chosen": [ |
| {"role": "user", "content": "..."}, |
| {"role": "assistant", "content": "..."} |
| ], |
| "rejected": [ |
| {"role": "user", "content": "..."}, |
| {"role": "assistant", "content": "..."} |
| ] |
| } |
| ``` |
|
|
| #### llama3.icr |
|
|
| ```json |
| { |
| "system": "...", // optional |
| "input": "...", |
| "chosen": "...", |
| "rejected": "..." |
| } |
| ``` |
|
|
| #### llama3.intel |
|
|
| ```json |
| { |
| "system": "...", // optional |
| "question": "...", |
| "chosen": "...", |
| "rejected": "..." |
| } |
| ``` |
|
|
| #### llama3.prompt_pairs |
|
|
| ```json |
| { |
| "system": "...", // optional |
| "prompt": "...", |
| "chosen": "...", |
| "rejected": "..." |
| } |
| ``` |
|
|
| #### llama3.ultra |
|
|
| ```json |
| { |
| "system": "...", // optional |
| "prompt": "...", |
| "chosen": [ |
| {"role": "user", "content": "..."}, |
| {"role": "assistant", "content": "..."} |
| ], |
| "rejected": [ |
| {"role": "user", "content": "..."}, |
| {"role": "assistant", "content": "..."} |
| ] |
| } |
| ``` |
|
|
| #### zephyr.nectar |
|
|
| ```json |
| { |
| "prompt": "...", |
| "answers": [ |
| { |
| "answer": "...", |
| "rank": 1 |
| }, |
| { |
| "answer": "...", |
| "rank": 2 |
| } |
| // ... more answers with ranks |
| ] |
| } |
| ``` |
|
|
| #### chat_template.default |
|
|
| ```yaml |
| rl: dpo |
| datasets: |
| - path: ... |
| split: train |
| type: chat_template.default |
| field_messages: "messages" |
| field_chosen: "chosen" |
| field_rejected: "rejected" |
| message_property_mappings: |
| role: role |
| content: content |
| roles: |
| user: ["user"] |
| assistant: ["assistant"] |
| system: ["system"] |
| ``` |
|
|
| Sample input format: |
|
|
| ```json |
| { |
| "messages": [ |
| { |
| "role": "system", |
| "content": "..." |
| }, |
| { |
| "role": "user", |
| "content": "..." |
| }, |
| // ... more messages |
| ], |
| "chosen": { |
| "role": "assistant", |
| "content": "..." |
| }, |
| "rejected": { |
| "role": "assistant", |
| "content": "..." |
| } |
| } |
| ``` |
|
|
| #### user_defined.default |
|
|
For custom behaviors, use the following config:
|
|
| ```yaml |
| rl: dpo |
| datasets: |
| - path: ... |
| split: train |
| type: user_defined.default |
| |
| field_prompt: "prompt" |
| field_system: "system" |
| field_chosen: "chosen" |
| field_rejected: "rejected" |
| prompt_format: "{prompt}" |
| chosen_format: "{chosen}" |
| rejected_format: "{rejected}" |
| ``` |
| |
The input format is simple JSON, with field names customizable via the config above.
|
|
| ```json |
| { |
| "system": "...", // optional |
| "prompt": "...", |
| "chosen": "...", |
| "rejected": "..." |
| } |
| ``` |
|
|
| ### IPO |
|
|
| As IPO is just DPO with a different loss function, all supported dataset formats for [DPO](#dpo) are also supported for IPO. |
|
|
| ```yaml |
| rl: ipo |
| ``` |
|
|
| ### ORPO |
|
|
Paper: <https://arxiv.org/abs/2403.07691>
|
|
| ```yaml |
| rl: orpo |
| orpo_alpha: 0.1 |
| remove_unused_columns: false |
| |
| chat_template: chatml |
| datasets: |
| - path: argilla/ultrafeedback-binarized-preferences-cleaned |
| type: chat_template.argilla |
| ``` |
| |
ORPO supports the following types, each with the dataset format shown below:
|
|
| #### chat_template.argilla |
|
|
| ```json |
| { |
| "system": "...", // optional |
  "prompt": "...", // optional; if present, used as the single-turn user message instead of the one from the lists below

  // chosen and rejected should be identical up to the last assistant message, with an even number of alternating user/assistant turns
| "chosen": [ |
| {"role": "user", "content": "..."}, |
| {"role": "assistant", "content": "..."} |
| ], |
| "rejected": [ |
| {"role": "user", "content": "..."}, |
| {"role": "assistant", "content": "..."} |
| ] |
| } |
| ``` |
| |
| ### KTO |
|
|
| ```yaml |
| rl: kto |
| rl_beta: 0.1 # default |
| kto_desirable_weight: 1.0 # default |
| kto_undesirable_weight: 1.0 # default |
| |
| remove_unused_columns: false |
|
|
| datasets: |
| - path: argilla/ultrafeedback-binarized-preferences-cleaned-kto |
| type: llama3.ultra |
| split: train |
|
|
| gradient_checkpointing: true |
| gradient_checkpointing_kwargs: |
| use_reentrant: true |
| ``` |
| |
KTO supports the following types, each with the dataset format shown below:
|
|
| #### chatml.argilla |
|
|
| ```json |
| { |
| "system": "...", // optional |
| "instruction": "...", |
| "completion": "..." |
| } |
| ``` |
|
|
| #### chatml.argilla_chat |
|
|
| ```json |
| { |
| "chosen": [ |
| {"role": "user", "content": "..."} |
| ], |
| "completion": [ |
| {"role": "assistant", "content": "..."} |
| ] |
| } |
| ``` |
|
|
| #### chatml.intel |
|
|
| ```json |
| { |
| "system": "...", // optional |
| "question": "...", |
| "completion": "..." |
| } |
| ``` |
|
|
| #### chatml.prompt_pairs |
|
|
| ```json |
| { |
| "system": "...", // optional |
| "prompt": "...", |
| "completion": "..." |
| } |
| ``` |
|
|
| #### chatml.ultra |
|
|
| ```json |
| { |
| "system": "...", // optional |
| "prompt": "...", |
| "completion": "..." |
| } |
| ``` |
|
|
| #### llama3.argilla |
|
|
| ```json |
| { |
| "system": "...", // optional |
| "instruction": "...", |
| "completion": "..." |
| } |
| ``` |
|
|
| #### llama3.argilla_chat |
|
|
| ```json |
| { |
| "completion": [ |
| {"role": "user", "content": "..."}, |
| {"role": "assistant", "content": "..."} |
| ] |
| } |
| ``` |
|
|
| #### llama3.intel |
|
|
| ```json |
| { |
| "system": "...", // optional |
| "question": "...", |
| "completion": "..." |
| } |
| ``` |
|
|
| #### llama3.prompt_pairs |
|
|
| ```json |
| { |
| "system": "...", // optional |
| "prompt": "...", |
| "completion": "..." |
| } |
| ``` |
|
|
| #### llama3.ultra |
|
|
| ```json |
| { |
| "system": "...", // optional |
| "prompt": "...", |
| "completion": "..." |
| } |
| ``` |
|
|
| #### user_defined.default |
|
|
For custom behaviors, use the following config:
|
|
| ```yaml |
| rl: kto |
| datasets: |
| - path: ... |
| split: train |
| type: user_defined.default |
| |
| field_prompt: "prompt" |
| field_system: "system" |
| field_completion: "completion" |
| field_label: "label" |
| prompt_format: "{prompt}" |
| completion_format: "{completion}" |
| ``` |
| |
The input format is simple JSON, with field names customizable via the config above.
|
|
| ```json |
| { |
| "system": "...", // optional |
| "prompt": "...", |
| "completion": "...", |
| "label": "..." |
| } |
| ``` |
|
|
| ### GRPO |
|
|
| ::: {.callout-tip} |
| Check out our [GRPO cookbook](https://github.com/axolotl-ai-cloud/axolotl-cookbook/tree/main/grpo#training-an-r1-style-large-language-model-using-grpo). |
| ::: |
|
|
GRPO uses custom reward functions and dataset transformations. Please have these available as local Python files.
|
|
For example, to load OpenAI's GSM8K dataset and assign a random reward to each completion:
|
|
| ```python |
| # rewards.py |
| import random |
| |
| def rand_reward_func(completions, **kwargs) -> list[float]: |
| return [random.uniform(0, 1) for _ in completions] |
|
|
| def oai_gsm8k_transform(cfg, *args, **kwargs): |
| def transform_fn(example, tokenizer=None): |
| label = example["answer"].split("####")[-1].strip().replace(",", "") |
| return { |
| "prompt": [{"role": "user", "content": example["question"]},], |
| "answer": label, |
| } |
| return transform_fn, {"remove_columns": ["question"]} |
| ``` |
| |
| ```yaml |
| rl: grpo |
| |
| trl: |
| beta: 0.001 |
| max_completion_length: 256 |
| use_vllm: True |
| vllm_device: auto |
| vllm_gpu_memory_utilization: 0.15 |
| num_generations: 4 |
| reward_funcs: ["rewards.rand_reward_func"] # format: '{file_name}.{fn_name}' |
| reward_weights: [1.0] |
| datasets: |
| - path: openai/gsm8k |
| name: main |
| type: rewards.oai_gsm8k_transform # format: '{file_name}.{fn_name}' |
| ``` |
| |
For more examples of custom reward functions, see the [TRL GRPO docs](https://github.com/huggingface/trl/blob/main/docs/source/grpo_trainer.md#using-a-custom-reward-function).
|
|
For descriptions of the config options, see [TRLConfig](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/utils/config/models/input/v0_4_1/trl.py).
|
|
| ### SimPO |
|
|
SimPO uses the [CPOTrainer](https://huggingface.co/docs/trl/main/en/cpo_trainer) but with an alternative loss function.
|
|
| ```yaml |
| rl: simpo |
| rl_beta: 0.1 # default in CPOTrainer |
| cpo_alpha: 1.0 # default in CPOTrainer |
| simpo_gamma: 0.5 # default in CPOTrainer |
| ``` |
|
|
| This method uses the same dataset format as [DPO](#dpo). |
|
|
| ### Using local dataset files |
|
|
| ```yaml |
| datasets: |
| - ds_type: json |
| data_files: |
| - orca_rlhf.jsonl |
| split: train |
| type: chatml.intel |
| ``` |
|
|
| ### TRL auto-unwrapping for PEFT |
|
|
TRL supports auto-unwrapping PEFT models for RL training paradigms that rely on a reference model. This significantly reduces memory pressure, as an additional reference model does not need to be loaded; the reference model's log probabilities are obtained by disabling the PEFT adapters. This is enabled by default. To turn it off, pass the following config:
|
|
| ```yaml |
# load a separate ref model when training with an adapter
| rl_adapter_ref_model: true |
| ``` |
|
|