File size: 3,943 Bytes
b9c28c1 cb2f352 b9c28c1 cb2f352 b9c28c1 cb2f352 b9c28c1 cb2f352 b9c28c1 cb2f352 b9c28c1 cb2f352 b9c28c1 cb2f352 b9c28c1 cb2f352 b9c28c1 cb2f352 b9c28c1 cb2f352 b9c28c1 cb2f352 b9c28c1 cb2f352 b9c28c1 cb2f352 b9c28c1 cb2f352 b9c28c1 cb2f352 b9c28c1 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 | # Medical-ChatBot-CPT LoRA 模型
## 模型概述
基于 LLaMA-3.1-8B 的医疗聊天机器人持续预训练(Continual Pre-Training, CPT)LoRA 适配器。
- **基础模型**: meta-llama/Llama-3.1-8B
- **训练阶段**: Continual Pre-Training (CPT)
- **适配器大小**: ~26.5MB
## 1. 数据集
**数据集**: [bootscoder/Medical-ChatBot-CPT](https://huggingface.co/datasets/bootscoder/Medical-ChatBot-CPT)
详细数据集信息请查看上述链接。
## 2. 训练流程
### 技术栈
- **DeepSpeed**: ZeRO Stage 1 分布式训练
- **PEFT**: LoRA 参数高效微调
- **BitsAndBytes**: 4-bit NF4 量化
- **Flash Attention 2**: 加速注意力计算
- **TRL**: SFTTrainer 训练接口
### 训练阶段
1. **模型初始化**: 加载 LLaMA-3.1-8B 并应用 4-bit 量化
2. **LoRA 配置**: 初始化低秩适配器(r=32, alpha=8)
3. **分布式训练**: DeepSpeed 8卡并行训练,1 epoch
4. **保存模型**: 保存 LoRA 适配器权重
## 3. 参数配置
### 硬件配置
```
GPU: 8 × NVIDIA A5000 (24GB VRAM)
分布式: DeepSpeed ZeRO Stage 1
```
### 训练超参数
```yaml
seq_length: 2048 # 序列长度
batch_size: 2 # 每卡批次大小
gradient_accumulation_steps: 16 # 梯度累积
effective_batch_size: 256 # 2 × 8 × 16
num_train_epochs: 1 # 训练轮数
learning_rate: 1e-5 # 学习率
lr_scheduler_type: cosine # 余弦调度
warmup_ratio: 0.1 # 预热比例
bf16: true # BF16 混合精度
gradient_checkpointing: true # 梯度检查点
packing: true # 序列打包
```
### QLoRA 配置
**量化配置**:
```python
load_in_4bit: True # 4-bit 量化
bnb_4bit_quant_type: nf4 # NF4 量化
bnb_4bit_compute_dtype: bfloat16 # BF16 计算
```
**LoRA 配置**:
```python
r: 32 # LoRA 秩
lora_alpha: 8 # 缩放因子 (alpha/r = 0.25)
target_modules: [q_proj, k_proj] # Q, K 投影层
bias: none # 不训练 bias
trainable_params: ~26.5MB # 可训练参数 (~0.2%)
```
**显存优化效果**:
- 原始全参数训练 (FP16): ~72GB per GPU
- 使用 QLoRA: ~7-8GB per GPU
- **显存节约: ~90%**
## 4. 峰值显存占用
**单卡峰值**: ____________ GB
**8卡总计**: ____________ GB
## 5. 模型预期表现
### 相比 Base LLaMA-3.1-8B 的改进
**改进**:
- 更好理解医疗术语和概念
- 输出更符合医疗领域语言风格
- 为后续 SFT 训练提供更好初始化
**局限**:
- 未经指令微调,不理解指令格式
- 输出结构化程度不足
- 不建议直接部署使用
## 使用方法
### 加载模型
```python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
# 加载基础模型
base_model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3.1-8B",
torch_dtype=torch.bfloat16,
device_map="auto"
)
# 加载 LoRA 适配器
model = PeftModel.from_pretrained(base_model, "/path/to/pretrained-lora")
tokenizer = AutoTokenizer.from_pretrained("/path/to/pretrained-lora")
# 合并适配器(可选)
model = model.merge_and_unload()
```
### 生成文本
```python
inputs = tokenizer("高血压是一种", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=128)
print(tokenizer.decode(outputs[0]))
```
## 模型文件
```
pretrained-lora/
├── adapter_config.json # LoRA 配置
├── adapter_model.safetensors # LoRA 权重 (~26.5MB)
├── special_tokens_map.json # 特殊 token 映射
├── tokenizer.json # 分词器
└── tokenizer_config.json # 分词器配置
```
## 许可证
遵循 [Llama 3.1 Community License](https://github.com/meta-llama/llama-models/blob/main/models/llama3_1/LICENSE)
|