File size: 6,065 Bytes
fbd7ca2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 |
# Medical-ChatBot-DPO LoRA 模型
## 模型概述
基于 SFT 模型的直接偏好优化(Direct Preference Optimization, DPO)LoRA 适配器,通过人类偏好数据进行对齐训练。
- **基础模型**: SFT 阶段训练后的 LLaMA-3.1-8B (sft-full-multi)
- **训练阶段**: Direct Preference Optimization (DPO)
- **训练状态**: 🚧 训练中
## 1. 数据集
**数据集**: [bootscoder/Medical-ChatBot-DPO](https://huggingface.co/datasets/bootscoder/Medical-ChatBot-DPO)
详细数据集信息请查看上述链接。
## 2. 训练流程
### 技术栈
- **DeepSpeed**: ZeRO Stage 1 分布式训练
- **PEFT**: LoRA 参数高效微调
- **BitsAndBytes**: 4-bit NF4 量化 + 双重量化
- **Flash Attention 2**: 加速注意力计算
- **TRL**: DPOTrainer 偏好优化
### 训练阶段
1. **加载 SFT 模型**: 使用 SFT 阶段的 full model 作为起点
2. **偏好数据处理**: 处理 chosen/rejected 回答对
3. **DPO 训练**: 使用 sigmoid loss 进行偏好对齐
4. **分布式训练**: DeepSpeed 8卡并行训练,2 epochs
5. **保存模型**: 保存 LoRA 适配器权重
## 3. 参数配置
### 硬件配置
```
GPU: 8 × NVIDIA A5000 (24GB VRAM)
实际使用卡: 0,1,2,4,5,7,8,9
分布式: DeepSpeed ZeRO Stage 1
```
### 训练超参数
```yaml
seq_length: 512 # 序列长度
max_prompt_length: 128 # 最大提示长度
max_completion_length: 128 # 最大回答长度
batch_size: 4 # 每卡批次大小
gradient_accumulation_steps: 8 # 梯度累积
effective_batch_size: 256 # 4 × 8 × 8
num_train_epochs: 2 # 训练轮数
learning_rate: 5e-6 # 学习率 (低于 SFT)
lr_scheduler_type: cosine # 余弦调度
warmup_ratio: 0.05 # 预热比例
bf16: true # BF16 混合精度
gradient_checkpointing: true # 梯度检查点
beta: 0.1 # DPO 温度参数
loss_type: sigmoid # Sigmoid loss
```
### QLoRA 配置
**量化配置**:
```python
load_in_4bit: True # 4-bit 量化
bnb_4bit_quant_type: nf4 # NF4 量化
bnb_4bit_compute_dtype: bfloat16 # BF16 计算
bnb_4bit_use_double_quant: True # 双重量化
```
**LoRA 配置**:
```python
r: 64 # LoRA 秩
lora_alpha: 8 # 缩放因子 (alpha/r = 0.125)
target_modules: [q_proj, k_proj] # Q, K 投影层
bias: none # 不训练 bias
trainable_params: ~54MB # 可训练参数
```
**DPO 特性**:
- **Beta 参数 (0.1)**: 控制偏好强度,较小的 beta 使模型更积极地学习偏好
- **Sigmoid Loss**: 稳定的损失函数,适合偏好学习
- **无需 Reference Model**: 隐式参考模型,节约显存
- **显存节约**: ~90% (相比全参数训练)
## 4. 峰值显存占用
**单卡峰值**: ____________ GB
**8卡总计**: ____________ GB
## 5. 模型预期表现
### 相比 SFT 模型的改进
**改进**:
- 生成内容更符合人类偏好
- 回答质量和安全性提升
- 减少不必要的冗长或不当内容
- 更好的拒答能力(对不确定问题)
- 输出风格更加友好和专业
**相比 Base LLaMA-3.1-8B**:
- 医疗领域知识(CPT)+ 指令遵循(SFT)+ 偏好对齐(DPO)
- 完整的 RLHF 替代方案(DPO 作为 PPO 的替代)
- 更安全、更可控、更符合用户期望
**局限**:
- 依赖偏好数据质量,可能存在偏好数据偏差
- 对于偏好数据未覆盖的场景,改进有限
- 仍建议在实际应用中进行人工审核
## 使用方法
### 加载模型
```python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
# 加载 SFT full model
base_model = AutoModelForCausalLM.from_pretrained(
"/path/to/sft-full-multi",
torch_dtype=torch.bfloat16,
device_map="auto"
)
# 加载 DPO LoRA 适配器
model = PeftModel.from_pretrained(base_model, "/path/to/dpo-lora")
tokenizer = AutoTokenizer.from_pretrained("/path/to/dpo-lora")
# 合并适配器(可选)
model = model.merge_and_unload()
```
### 对话示例
```python
# 构建对话格式
SYSTEM_PROMPT = "You are a Medical Chatbot, you should friendly answer the question."
def format_prompt(question):
return f"###System: {SYSTEM_PROMPT}\n###Question: {question}\n###Answer: "
# 生成回答
question = "感冒了应该怎么办?"
prompt = format_prompt(question)
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(
**inputs,
max_new_tokens=256,
temperature=0.7,
top_p=0.9,
do_sample=True,
repetition_penalty=1.1
)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(response)
```
### 推理优化建议
```python
# 使用更保守的采样参数以提高输出质量
outputs = model.generate(
**inputs,
max_new_tokens=256,
temperature=0.6, # 降低随机性
top_p=0.85, # 更集中的采样
do_sample=True,
repetition_penalty=1.15, # 避免重复
no_repeat_ngram_size=3 # 避免 n-gram 重复
)
```
## 模型文件
```
dpo-lora/
├── adapter_config.json # LoRA 配置
├── adapter_model.safetensors # LoRA 权重 (~54MB)
├── special_tokens_map.json # 特殊 token 映射
├── tokenizer.json # 分词器
└── tokenizer_config.json # 分词器配置
```
## 训练进度
**当前状态**: 🚧 训练中
训练完成后,模型将包含完整的 CPT → SFT → DPO 训练流程,形成一个完整的医疗对话模型。
## 许可证
遵循 [Llama 3.1 Community License](https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/LICENSE)
## 注意事项
⚠️ **医疗免责声明**:
- 本模型仅供研究和教育用途
- 不应作为专业医疗建议的替代
- 任何医疗决策都应咨询专业医疗人员
- 模型输出可能包含错误或不完整信息
|