---
language:
- en
pipeline_tag: text-generation
tags:
- pytorch
- Mistral
---

## Model Details

We employ **Mistral-Base (7B)** as one of the base models to evaluate our proposed **Reward-Driven Selective Penalization for Preference Alignment Optimization (RSPO)** method. The model is trained with RSPO for **one epoch** on the **UltraFeedback Binarized** dataset.

## How to use

#### Transformers AutoModelForCausalLM

```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

model_id = "li11111/Mistral-7B-Base-RSPO"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Note: if the tokenizer's chat template does not accept a "system" role,
# fold the system text into the first user message instead.
messages = [
    {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
    {"role": "user", "content": "Who are you?"},
]

input_ids = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt",
).to(model.device)

terminators = [
    tokenizer.eos_token_id,
]

outputs = model.generate(
    input_ids,
    max_new_tokens=256,
    eos_token_id=terminators,
    do_sample=True,
    temperature=0.6,
    top_p=0.9,
)

# Decode only the newly generated tokens, skipping the prompt.
response = outputs[0][input_ids.shape[-1]:]
print(tokenizer.decode(response, skip_special_tokens=True))
```

## Experiment Parameters

| **Parameter**       | **Mistral-Base (7B)** |
| ------------------- | --------------------- |
| `GPU`               | 8×Ascend 910B         |
| `beta`              | 0.01                  |
| `batch`             | 128                   |
| `learning_rate`     | 5e-7                  |
| `max_prompt_length` | 512                   |
| `max_length`        | 1024                  |
| `num_train_epochs`  | 1                     |
| `torch_dtype`       | `bfloat16`            |
| `warmup_ratio`      | 0.1                   |
| `β_w`               | 0.01                  |
| `β_l`               | 0.1                   |
| `λ`                 | 0.1                   |
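
For reference, the hyperparameters above can be collected into a plain configuration dictionary. The key names below are illustrative only; they do not necessarily match the argument names of the actual RSPO training script.

```python
# Hyperparameters from the table above, gathered as a plain dict.
# Key names are illustrative, not the training script's real arguments.
rspo_config = {
    "beta": 0.01,
    "batch_size": 128,
    "learning_rate": 5e-7,
    "max_prompt_length": 512,
    "max_length": 1024,
    "num_train_epochs": 1,
    "torch_dtype": "bfloat16",
    "warmup_ratio": 0.1,
    "beta_w": 0.01,   # β_w
    "beta_l": 0.1,    # β_l
    "lambda_": 0.1,   # λ
}
```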

## Training Data

We use the [HuggingFaceH4/ultrafeedback_binarized](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized) dataset to train the Mistral-Base (7B) model.

## Benchmarks

<table>
  <tr>
    <th>Method</th>
    <th colspan="3" style="text-align: center;">AlpacaEval 2.0</th>
  </tr>
  <tr>
    <th></th>
    <th>LC</th>
    <th>WR</th>
    <th>Avg. Len</th>
  </tr>
  <tr>
    <td><b>RSPO</b></td>
    <td><b>25.4</b></td>
    <td><b>23.7</b></td>
    <td>1873</td>
  </tr>
</table>

LC denotes the length-controlled win rate and WR the raw win rate on AlpacaEval 2.0; Avg. Len is the average response length in characters.

| **Method** | **GSM8K** | **ARC**   | **TQA**   | **MMLU**  | **IFEval** | **Avg.**  |
| ---------- | --------- | --------- | --------- | --------- | ---------- | --------- |
| **SFT**    | **42.61** | 55.97     | 28.15     | 57.17     | 36.59      | 44.10     |
| **DPO**    | 33.13     | 59.64     | 46.14     | 57.46     | 50.48      | 49.37     |
| **R-DPO**  | 30.10     | 56.06     | 40.64     | 58.48     | 53.24      | 47.70     |
| **SimPO**  | 33.59     | **60.15** | 43.45     | 58.25     | 52.98      | 49.68     |
| **WPO**    | 30.63     | 57.00     | 40.51     | 58.54     | **55.64**  | 48.46     |
| **RSPO**   | 37.45     | 57.94     | **47.25** | **58.58** | 55.04      | **51.25** |
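
As a quick sanity check, the **Avg.** column is the unweighted mean of the five benchmark scores:

```python
# Verify that each reported Avg. equals the mean of the five benchmark scores.
scores = {
    "SFT":   [42.61, 55.97, 28.15, 57.17, 36.59],
    "DPO":   [33.13, 59.64, 46.14, 57.46, 50.48],
    "R-DPO": [30.10, 56.06, 40.64, 58.48, 53.24],
    "SimPO": [33.59, 60.15, 43.45, 58.25, 52.98],
    "WPO":   [30.63, 57.00, 40.51, 58.54, 55.64],
    "RSPO":  [37.45, 57.94, 47.25, 58.58, 55.04],
}
reported_avg = {"SFT": 44.10, "DPO": 49.37, "R-DPO": 47.70,
                "SimPO": 49.68, "WPO": 48.46, "RSPO": 51.25}

for method, vals in scores.items():
    mean = sum(vals) / len(vals)
    # Reported averages are rounded to two decimal places.
    assert abs(mean - reported_avg[method]) < 0.005, method
```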