---
language:
- en
pipeline_tag: text-generation
tags:
- pytorch
- Mistral
---

## Model Details

We use **Mistral-Base (7B)** as one of the base models for evaluating our proposed **Reward-Driven Selective Penalization for Preference Alignment Optimization (RSPO)** method. The model is trained with RSPO for **one epoch** on the **UltraFeedback Binarized** dataset.

## How to use

### Transformers AutoModelForCausalLM

```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

model_id = "li11111/Mistral-7B-Base-RSPO"

# Load the tokenizer and the bfloat16 weights; device_map="auto" places layers on the available devices.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

messages = [
    {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
    {"role": "user", "content": "Who are you?"},
]

# Format the conversation with the tokenizer's chat template and append the assistant turn prompt.
input_ids = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt"
).to(model.device)

# Stop generation at the end-of-sequence token.
terminators = [
    tokenizer.eos_token_id
]

outputs = model.generate(
    input_ids,
    max_new_tokens=256,
    eos_token_id=terminators,
    do_sample=True,
    temperature=0.6,
    top_p=0.9,
)

# Drop the prompt tokens and decode only the newly generated response.
response = outputs[0][input_ids.shape[-1]:]
print(tokenizer.decode(response, skip_special_tokens=True))
```

## Experiment Parameters

| **Parameter**       | **Mistral-Base (7B)** |
| ------------------- | --------------------- |
| `GPU`               | 8 × Ascend 910B       |
| `beta`              | 0.01                  |
| `batch`             | 128                   |
| `learning_rate`     | 5e-7                  |
| `max_prompt_length` | 512                   |
| `max_length`        | 1024                  |
| `num_train_epochs`  | 1                     |
| `torch_dtype`       | `bfloat16`            |
| `warmup_ratio`      | 0.1                   |
| `β_w`               | 0.01                  |
| `β_l`               | 0.1                   |
| `λ`                 | 0.1                   |

## Training Data

We use the [HuggingFaceH4/ultrafeedback_binarized](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized) dataset to train the Mistral-Base model.

## Benchmarks
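| **Method** | **AlpacaEval 2.0 LC (%)** | **AlpacaEval 2.0 WR (%)** | **AlpacaEval 2.0 Avg. Len** |
| ---------- | ------------------------- | ------------------------- | --------------------------- |
| **RSPO**   | 25.4                      | 23.7                      | 1873                        |

LC is the length-controlled win rate and WR is the raw win rate.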

| **Method** | **GSM8K** | **ARC**   | **TQA**   | **MMLU**  | **IFEval** | **Avg.**  |
| ---------- | --------- | --------- | --------- | --------- | ---------- | --------- |
| **SFT**    | **42.61** | 55.97     | 28.15     | 57.17     | 36.59      | 44.10     |
| **DPO**    | 33.13     | 59.64     | 46.14     | 57.46     | 50.48      | 49.37     |
| **R-DPO**  | 30.10     | 56.06     | 40.64     | 58.48     | 53.24      | 47.70     |
| **SimPO**  | 33.59     | **60.15** | 43.45     | 58.25     | 52.98      | 49.68     |
| **WPO**    | 30.63     | 57.00     | 40.51     | 58.54     | **55.64**  | 48.46     |
| **RSPO**   | 37.45     | 57.94     | **47.25** | **58.58** | 55.04      | **51.25** |
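
The **Avg.** column appears to be the unweighted mean of the five benchmark scores in each row. The short Python sketch below is a sanity check under that assumption: it copies the numbers from the table above and recomputes the mean for each method. The names `RESULTS`, `unweighted_mean`, and `check_averages` are illustrative only and are not part of any released code.

```python
# Scores copied from the benchmark table: (GSM8K, ARC, TQA, MMLU, IFEval, reported Avg.)
RESULTS = {
    "SFT":   (42.61, 55.97, 28.15, 57.17, 36.59, 44.10),
    "DPO":   (33.13, 59.64, 46.14, 57.46, 50.48, 49.37),
    "R-DPO": (30.10, 56.06, 40.64, 58.48, 53.24, 47.70),
    "SimPO": (33.59, 60.15, 43.45, 58.25, 52.98, 49.68),
    "WPO":   (30.63, 57.00, 40.51, 58.54, 55.64, 48.46),
    "RSPO":  (37.45, 57.94, 47.25, 58.58, 55.04, 51.25),
}


def unweighted_mean(scores):
    """Plain arithmetic mean of the benchmark scores (assumed definition of Avg.)."""
    return sum(scores) / len(scores)


def check_averages(results, tol=0.005):
    """Return {method: (recomputed, reported)} for rows whose reported Avg.
    differs from the recomputed mean by more than the 2-decimal rounding tolerance."""
    mismatches = {}
    for method, row in results.items():
        *scores, reported = row
        mean = unweighted_mean(scores)
        if abs(mean - reported) > tol + 1e-9:  # small epsilon absorbs float error
            mismatches[method] = (round(mean, 3), reported)
    return mismatches


if __name__ == "__main__":
    for method, row in RESULTS.items():
        print(f"{method:6s} recomputed={unweighted_mean(row[:-1]):.2f} reported={row[-1]:.2f}")
    print("mismatches:", check_averages(RESULTS))
```

An empty `mismatches` dictionary means every reported average matches the recomputed mean to two decimal places.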