---
language:
- en
pipeline_tag: text-generation
tags:
- pytorch
- Mistral
---
## Model Details
We employ **Mistral-Base (7B)** as one of the base models to evaluate our proposed **Reward-Driven Selective Penalization for Preference Alignment Optimization (RSPO)** method. The model is trained with RSPO for **one epoch** on the **UltraFeedback Binarized** dataset.
## How to use
#### Transformers AutoModelForCausalLM
```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

model_id = "li11111/Mistral-7B-Base-RSPO"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

messages = [
    {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
    {"role": "user", "content": "Who are you?"},
]

# Build the prompt with the model's chat template
input_ids = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt",
).to(model.device)

terminators = [tokenizer.eos_token_id]

outputs = model.generate(
    input_ids,
    max_new_tokens=256,
    eos_token_id=terminators,
    do_sample=True,
    temperature=0.6,
    top_p=0.9,
)

# Decode only the newly generated tokens, not the prompt
response = outputs[0][input_ids.shape[-1]:]
print(tokenizer.decode(response, skip_special_tokens=True))
```
## Experiment Parameters
| **Parameter** | **Mistral-Base (7B)** |
| ------------------- | -------------------- |
| `GPU` | 8 × Ascend 910B |
| `beta` | 0.01 |
| `batch` | 128 |
| `learning_rate` | 5e-7 |
| `max_prompt_length` | 512 |
| `max_length` | 1024 |
| `num_train_epochs` | 1 |
| `torch_dtype` | `bfloat16` |
| `warmup_ratio` | 0.1 |
| `β_w` | 0.01 |
| `β_l` | 0.1 |
| `λ` | 0.1 |
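The card does not spell out the RSPO objective itself. As a rough illustration of how the `β_w`, `β_l`, and `λ` values above could enter a DPO-style preference loss with separate coefficients for the winning and losing responses, here is a minimal sketch. The loss form, the `clamp`-based penalty, and the function name are assumptions for illustration, not the paper's actual formulation:

```python
import torch
import torch.nn.functional as F

def illustrative_selective_preference_loss(
    policy_chosen_logps, policy_rejected_logps,
    ref_chosen_logps, ref_rejected_logps,
    beta_w=0.01, beta_l=0.1, lam=0.1,
):
    """Hypothetical DPO-style loss with separate betas for the chosen
    (winning) and rejected (losing) log-ratio terms. Illustrative only;
    not the published RSPO objective."""
    chosen_ratio = policy_chosen_logps - ref_chosen_logps
    rejected_ratio = policy_rejected_logps - ref_rejected_logps
    # Separate temperatures for winner and loser terms (beta_w, beta_l)
    margin = beta_w * chosen_ratio - beta_l * rejected_ratio
    pref_loss = -F.logsigmoid(margin)
    # lambda-weighted penalty when the policy upweights rejected responses
    # relative to the reference (an assumption, for illustration)
    penalty = lam * rejected_ratio.clamp(min=0)
    return (pref_loss + penalty).mean()
```

The separate `beta_w`/`beta_l` split mirrors the two β values reported in the table; a single shared β would reduce this to a standard DPO loss plus penalty.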
## Training Data
We use the [HuggingFaceH4/ultrafeedback_binarized](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized) dataset to train the Mistral Base model.
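Each row of the binarized dataset pairs one prompt with a preferred (`chosen`) and a dispreferred (`rejected`) conversation. A minimal sketch of that record layout and how one might flatten it into a preference pair (the example text is made up; the field names match the dataset card):

```python
# Real loading would be, e.g.:
#   from datasets import load_dataset
#   ds = load_dataset("HuggingFaceH4/ultrafeedback_binarized", split="train_prefs")
# Below is a toy row with the same field layout, so it runs offline.
example = {
    "prompt": "What is the capital of France?",
    "chosen": [
        {"role": "user", "content": "What is the capital of France?"},
        {"role": "assistant", "content": "The capital of France is Paris."},
    ],
    "rejected": [
        {"role": "user", "content": "What is the capital of France?"},
        {"role": "assistant", "content": "France's capital is Lyon."},
    ],
}

def to_preference_pair(row):
    """Pull the final assistant turns out of a binarized row."""
    return {
        "prompt": row["prompt"],
        "chosen": row["chosen"][-1]["content"],
        "rejected": row["rejected"][-1]["content"],
    }

pair = to_preference_pair(example)
```

The `chosen`/`rejected` lists are full message transcripts, so taking the last `content` yields the assistant response that preference training compares.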
## Benchmarks
<table>
<tr>
<th>Method</th>
<th colspan="3" style="text-align: center;">AlpacaEval 2.0</th>
</tr>
<tr>
<th></th>
<th>LC (length-controlled win rate)</th>
<th>WR (win rate)</th>
<th>Avg. Len</th>
</tr>
<tr>
<td><b>RSPO</b></td>
<td><b>25.4</b></td>
<td><b>23.7</b></td>
<td>1873</td>
</tr>
</table>
| **Method** | **GSM8K** | **ARC** | **TQA (TruthfulQA)** | **MMLU** | **IFEval** | **Avg.** |
| ---------- | --------- | --------- | --------- | --------- | ---------- | --------- |
| **SFT** | **42.61** | 55.97 | 28.15 | 57.17 | 36.59 | 44.10 |
| **DPO** | 33.13 | 59.64 | 46.14 | 57.46 | 50.48 | 49.37 |
| **R-DPO** | 30.10 | 56.06 | 40.64 | 58.48 | 53.24 | 47.70 |
| **SimPO** | 33.59 | **60.15** | 43.45 | 58.25 | 52.98 | 49.68 |
| **WPO** | 30.63 | 57.00 | 40.51 | 58.54 | **55.64** | 48.46 |
| **RSPO** | 37.45 | 57.94 | **47.25** | **58.58** | 55.04 | **51.25** |