---
license: mit
---
## [mlx-community/Mistral-7B-v0.1-LoRA-Text2SQL](https://huggingface.co/mlx-community/Mistral-7B-v0.1-LoRA-Text2SQL)
The model fine-tuned in this post has been uploaded to the HuggingFace Hub, so you can try it out directly.
### Install mlx-lm
```bash
pip install mlx-lm
```
### Generate SQL
```bash
python -m mlx_lm.generate --model mlx-community/Mistral-7B-v0.1-LoRA-Text2SQL \
--max-tokens 50 \
--prompt "table: students
columns: Name, Age, School, Grade, Height, Weight
Q: Which school did Wang Junjian come from?
A: "
```
```
SELECT School FROM Students WHERE Name = 'Wang Junjian'
```
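Every prompt in this post uses the same four-line layout as the training data: table name, column list, question, and an open `A: ` for the model to complete. The small helper below assembles that layout. `build_prompt` is a hypothetical name used for illustration and is not part of mlx-lm.

```python
from typing import List


def build_prompt(table: str, columns: List[str], question: str) -> str:
    """Assemble a prompt in the table/columns/Q/A layout used in this post."""
    # Each field goes on its own line. The prompt ends with "A: " so that
    # the model continues with the SQL answer.
    return (
        f"table: {table}\n"
        f"columns: {', '.join(columns)}\n"
        f"Q: {question}\n"
        "A: "
    )


if __name__ == "__main__":
    print(build_prompt(
        "students",
        ["Name", "Age", "School", "Grade", "Height", "Weight"],
        "Which school did Wang Junjian come from?",
    ))
```

The resulting string is exactly what the `--prompt` argument above contains.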
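## [Fine-tuning Text2SQL based on Mistral-7B using LoRA on MLX (Part 1)](https://wangjunjian.com/mlx/lora/2024/01/23/Fine-tuning-Text2SQL-based-on-Mistral-7B-using-LoRA-on-MLX-1.html)
📌 In Part 1 the dataset was not built in the model's expected format (in particular, samples lacked the `</s>` end-of-sequence marker). As a result, the fine-tuned model never stopped on its own and kept generating until it reached the maximum number of tokens.
This time we fix that problem.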
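A toy decoding loop shows why the missing end marker matters. Generation stops early only when the model emits the end-of-sequence token (`</s>` for Mistral). If that token never appears, the loop runs all the way to `max_tokens`. The `next_token` callables below are hypothetical stand-ins for a model and do not reflect mlx-lm's actual generation code.

```python
from typing import Callable, List


def generate_tokens(
    next_token: Callable[[List[int]], int],
    prompt_ids: List[int],
    eos_id: int,
    max_tokens: int,
) -> List[int]:
    """Toy decoding loop: stop at the EOS token or after max_tokens tokens."""
    ids = list(prompt_ids)
    output: List[int] = []
    for _ in range(max_tokens):
        tok = next_token(ids)
        if tok == eos_id:  # the "model" signalled the end of its answer
            break
        output.append(tok)
        ids.append(tok)
    return output


if __name__ == "__main__":
    EOS = 2  # hypothetical id standing in for </s>

    # A "model" trained with </s>: emits an answer, then EOS.
    scripted = iter([11, 12, 13, EOS])
    print(generate_tokens(lambda ids: next(scripted), [1], EOS, 50))  # [11, 12, 13]

    # A "model" that never learned </s>: keeps going until max_tokens.
    print(len(generate_tokens(lambda ids: 99, [1], EOS, 50)))  # 50
```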
## Dataset: WikiSQL
- [WikiSQL](https://github.com/salesforce/WikiSQL)
- [sqllama/sqllama-V0](https://huggingface.co/sqllama/sqllama-V0/blob/main/wikisql.ipynb)
### Modify the script mlx-examples/lora/data/wikisql.py
```py
if __name__ == "__main__":
    # ......
    for dataset, name, size in datasets:
        with open(f"data/{name}.jsonl", "w") as fid:
            for e, t in zip(range(size), dataset):
                # The text in t looks like this (the trailing </s> is kept,
                # so the model learns where an answer ends):
                # ------------------------
                # <s>table: 1-1058787-1
                # columns: Approximate Age, Virtues, Psycho Social Crisis, Significant Relationship, Existential Question [ not in citation given ], Examples
                # Q: How many significant relationships list Will as a virtue?
                # A: SELECT COUNT Significant Relationship FROM 1-1058787-1 WHERE Virtues = 'Will'</s>
                t = t[3:]  # Strip the leading <s>, because the tokenizer adds <s> automatically
                json.dump({"text": t}, fid)
                fid.write("\n")
```
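The sketch below performs the same conversion as a standalone function. It assumes each sample is a string like the one shown in the script above. It differs from the script in one respect: it strips `<s>` only when it is actually present, instead of unconditionally cutting the first three characters. `write_jsonl` is a hypothetical helper, not part of mlx-examples.

```python
import json
from typing import Iterable, TextIO

BOS = "<s>"


def write_jsonl(samples: Iterable[str], fid: TextIO, size: int) -> int:
    """Write up to `size` samples as {"text": ...} JSON lines; return the count written."""
    written = 0
    for _, text in zip(range(size), samples):
        # The tokenizer prepends <s> itself, so drop a leading one here.
        # The trailing </s> is kept so the model learns to end its answer.
        if text.startswith(BOS):
            text = text[len(BOS):]
        json.dump({"text": text}, fid)  # json escapes newlines, so one record per line
        fid.write("\n")
        written += 1
    return written


if __name__ == "__main__":
    import io

    buf = io.StringIO()
    write_jsonl(["<s>table: t\nQ: q\nA: SELECT a FROM t</s>"], buf, size=10)
    print(buf.getvalue(), end="")
```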
Run the script `data/wikisql.py` to generate the dataset.
### Sample
```
table: 1-10753917-1
columns: Season, Driver, Team, Engine, Poles, Wins, Podiums, Points, Margin of defeat
Q: Which podiums did the alfa romeo team have?
A: SELECT Podiums FROM 1-10753917-1 WHERE Team = 'Alfa Romeo'</s>
```
## Fine-tuning
- Pretrained model: [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1)
### LoRA fine-tuning
```bash
python lora.py --model mistralai/Mistral-7B-v0.1 \
--train \
--iters 600
```
```
Total parameters 7243.436M
Trainable parameters 1.704M
python lora.py --model mistralai/Mistral-7B-v0.1 --train --iters 600 50.58s user 214.71s system 21% cpu 20:26.04 total
```
Only about 2.35 in every 10,000 (0.0235%) of the model's parameters are trained (1.704M / 7243.436M × 10000 ≈ 2.35).
The 600 LoRA iterations took 20 min 26 s and used 46 GB of memory.
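The trainable-parameter count in the log can be reproduced by hand under a few assumptions:

- `lora.py`'s defaults at the time were rank 8, with adapters on `q_proj` and `v_proj` in 16 layers.
- Mistral-7B uses its published dimensions: hidden size 4096, and 8 key/value heads of dimension 128.

These settings are not printed in the log and are assumptions. They do, however, land exactly on 1.704M.

```python
# Assumed Mistral-7B-v0.1 attention dimensions (from its published config).
HIDDEN = 4096     # hidden_size
N_KV_HEADS = 8    # num_key_value_heads (grouped-query attention)
HEAD_DIM = 128    # hidden_size / num_attention_heads = 4096 / 32

# Assumed LoRA settings: rank 8 on q_proj and v_proj in 16 layers.
RANK = 8
LORA_LAYERS = 16

TOTAL_PARAMS = 7243.436e6  # "Total parameters" from the training log


def lora_params(in_dim: int, out_dim: int, rank: int) -> int:
    """Parameters of one LoRA adapter: A is (in_dim x rank), B is (rank x out_dim)."""
    return rank * (in_dim + out_dim)


q_proj = lora_params(HIDDEN, HIDDEN, RANK)                 # 4096 -> 4096
v_proj = lora_params(HIDDEN, N_KV_HEADS * HEAD_DIM, RANK)  # 4096 -> 1024
trainable = LORA_LAYERS * (q_proj + v_proj)

print(trainable)                                # 1703936, i.e. ~1.704M
print(f"{trainable / TOTAL_PARAMS * 1e4:.2f}")  # 2.35 (parts per ten thousand)
```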
## Evaluation
Compute the perplexity (PPL) and cross-entropy loss on the test set.
```bash
python lora.py --model mistralai/Mistral-7B-v0.1 \
--adapter-file adapters.npz \
--test
```
```
Iter 100: Test loss 1.351, Test ppl 3.862.
Iter 200: Test loss 1.327, Test ppl 3.770.
Iter 300: Test loss 1.353, Test ppl 3.869.
Iter 400: Test loss 1.355, Test ppl 3.875.
Iter 500: Test loss 1.294, Test ppl 3.646.
Iter 600: Test loss 1.351, Test ppl 3.863.
```
| Iter | Test loss | Test ppl |
| :--: | --------: | -------: |
| 100 | 1.351 | 3.862 |
| 200 | 1.327 | 3.770 |
| 300 | 1.353 | 3.869 |
| 400 | 1.355 | 3.875 |
| 500 | 1.294 | 3.646 |
| 600 | 1.351 | 3.863 |
Evaluation used 26 GB of memory.
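Perplexity is the exponential of the average cross-entropy loss, so the two columns carry the same information. The check below assumes the loss is measured in nats (natural log). The small mismatches come from the loss being rounded to three decimals.

```python
import math

# (iteration, test loss, reported test ppl), copied from the evaluation log above
RESULTS = [
    (100, 1.351, 3.862),
    (200, 1.327, 3.770),
    (300, 1.353, 3.869),
    (400, 1.355, 3.875),
    (500, 1.294, 3.646),
    (600, 1.351, 3.863),
]

for it, loss, ppl in RESULTS:
    # Perplexity = exp(mean per-token cross-entropy), with the loss in nats.
    print(f"Iter {it}: exp({loss}) = {math.exp(loss):.3f}, reported {ppl}")

# exp is monotonic, so the lowest test loss also gives the lowest perplexity.
best_iter, best_loss, best_ppl = min(RESULTS, key=lambda r: r[1])
print(f"Best: iter {best_iter} (loss {best_loss}, ppl {best_ppl})")
```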
## Fuse
```bash
python fuse.py --model mistralai/Mistral-7B-v0.1 \
--adapter-file adapters.npz \
--save-path lora_fused_model
```
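Fusing folds the adapter into the base weights. The resulting model then runs without `adapters.npz` and without any extra compute at inference time. The pure-Python sketch below assumes the standard LoRA form y = xWᵀ + s·(xA)B, where A has shape (in, r) and B has shape (r, out). The exact `scale` value and weight layout used by `fuse.py` are assumptions and are not taken from the script.

```python
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain matrix product of a (n x k) and b (k x m)."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def transpose(m: Matrix) -> Matrix:
    return [list(row) for row in zip(*m)]


def lora_forward(x: Matrix, w: Matrix, lora_a: Matrix, lora_b: Matrix, scale: float) -> Matrix:
    """Unfused layer: y = x W^T + scale * (x A) B, with W (out x in), A (in x r), B (r x out)."""
    base = matmul(x, transpose(w))
    update = matmul(matmul(x, lora_a), lora_b)
    return [[bv + scale * uv for bv, uv in zip(rb, ru)] for rb, ru in zip(base, update)]


def fuse(w: Matrix, lora_a: Matrix, lora_b: Matrix, scale: float) -> Matrix:
    """Fold the adapter into the weight: W' = W + scale * (A B)^T."""
    delta = transpose(matmul(lora_a, lora_b))  # (out x in), same shape as W
    return [[wv + scale * dv for wv, dv in zip(rw, rd)] for rw, rd in zip(w, delta)]


if __name__ == "__main__":
    w = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]  # toy base weight, out=2, in=3
    a = [[1.0], [0.0], [2.0]]               # rank-1 adapter A (3 x 1)
    b = [[1.0, -1.0]]                       # adapter B (1 x 2)
    x = [[1.0, 1.0, 1.0]]                   # one input row

    unfused = lora_forward(x, w, a, b, scale=2.0)
    fused = matmul(x, transpose(fuse(w, a, b, scale=2.0)))
    print(unfused, fused)  # both [[12.0, 9.0]]
```

Because xW'ᵀ = xWᵀ + s·(xA)B, the fused weight reproduces the adapter's output exactly.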
## Generating SQL
### What is Wang Junjian's name?
```bash
python -m mlx_lm.generate --model lora_fused_model \
--max-tokens 50 \
--prompt "table: students
columns: Name, Age, School, Grade, Height, Weight
Q: What is Wang Junjian's name?
A: "
```
```
SELECT Name FROM students WHERE Name = 'Wang Junjian'
```
### How old is Wang Junjian?
```bash
python -m mlx_lm.generate --model lora_fused_model \
--max-tokens 50 \
--prompt "table: students
columns: Name, Age, School, Grade, Height, Weight
Q: How old is Wang Junjian?
A: "
```
```
SELECT Age FROM Students WHERE Name = 'Wang Junjian'
```
### Which school does Wang Junjian come from?
```bash
python -m mlx_lm.generate --model lora_fused_model \
--max-tokens 50 \
--prompt "table: students
columns: Name, Age, School, Grade, Height, Weight
Q: Which school did Wang Junjian come from?
A: "
```
```
SELECT School FROM Students WHERE Name = 'Wang Junjian'
```
### Query Wang Junjian's name, age, and school.
```bash
python -m mlx_lm.generate --model lora_fused_model \
--max-tokens 50 \
--prompt "table: students
columns: Name, Age, School, Grade, Height, Weight
Q: Query Wang Junjian’s name, age, and school information.
A: "
```
```
SELECT Name, Age, School FROM Students WHERE Name = 'Wang Junjian'
```
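Queries like the ones above are valid SQL, so they can be sanity-checked against a real database. The sketch below uses Python's built-in `sqlite3` with a `students` table whose rows are made up for illustration. The model's inconsistent `students`/`Students` casing is harmless here, because SQLite matches table names case-insensitively.

```python
import sqlite3
from typing import List, Tuple

# Hypothetical rows, made up only to exercise the generated queries.
ROWS = [
    ("Wang Junjian", 15, "No. 1 Middle School", 9, 170.0, 60.0),
    ("Li Lei", 14, "No. 2 Middle School", 8, 165.0, 55.0),
]


def run(sql: str) -> List[Tuple]:
    """Execute one generated query against a fresh in-memory students table."""
    conn = sqlite3.connect(":memory:")
    try:
        conn.execute(
            "CREATE TABLE students (Name TEXT, Age INTEGER, School TEXT, "
            "Grade INTEGER, Height REAL, Weight REAL)"
        )
        conn.executemany("INSERT INTO students VALUES (?, ?, ?, ?, ?, ?)", ROWS)
        return conn.execute(sql).fetchall()
    finally:
        conn.close()


if __name__ == "__main__":
    # [('Wang Junjian', 15, 'No. 1 Middle School')]
    print(run("SELECT Name, Age, School FROM Students WHERE Name = 'Wang Junjian'"))
```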
### Query all information about Wang Junjian.
```bash
python -m mlx_lm.generate --model lora_fused_model \
--max-tokens 50 \
--prompt "table: students
columns: Name, Age, School, Grade, Height, Weight
Q: Query all information about Wang Junjian.
A: "
```
```
SELECT Name FROM students WHERE Name = 'Wang Junjian'
```
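The question asks for every column, yet the model selects only `Name`. This may be due to insufficient training data: WikiSQL queries select a single column, so patterns like `SELECT *` are likely absent from it.
### Count how many students are in ninth grade.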
```bash
python -m mlx_lm.generate --model lora_fused_model \
--max-tokens 50 \
--prompt "table: students
columns: Name, Age, School, Grade, Height, Weight
Q: Count how many students there are in ninth grade.
A: "
```
```
SELECT COUNT Name FROM Students WHERE Grade = '9th'
```
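`SELECT COUNT Name ...` follows WikiSQL's logical-form notation, which writes the aggregation operator without parentheses, so it will not run as-is on a standard SQL engine. A hypothetical post-processing step could rewrite the SELECT clause. The sketch below handles only the five WikiSQL aggregations and leaves the rest of the query untouched, including the `'9th'` value issue addressed in the next section.

```python
import re

# The aggregation operators that appear in WikiSQL's logical forms.
AGGREGATIONS = ("MAX", "MIN", "COUNT", "SUM", "AVG")

_AGG_SELECT = re.compile(
    r"^SELECT\s+(" + "|".join(AGGREGATIONS) + r")\s+(.+?)\s+FROM\s+",
    re.IGNORECASE,
)


def wikisql_to_sql(query: str) -> str:
    """Rewrite 'SELECT AGG col FROM ...' as 'SELECT AGG("col") FROM ...'.

    Only the SELECT clause is touched; queries without an aggregation
    are returned unchanged. The column is double-quoted because WikiSQL
    column names may contain spaces.
    """
    return _AGG_SELECT.sub(
        lambda m: f'SELECT {m.group(1).upper()}("{m.group(2)}") FROM ',
        query,
        count=1,
    )


if __name__ == "__main__":
    # SELECT COUNT("Name") FROM Students WHERE Grade = 9
    print(wikisql_to_sql("SELECT COUNT Name FROM Students WHERE Grade = 9"))
```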
### Count how many students are in ninth grade (the value for ninth grade is 9).
```bash
python -m mlx_lm.generate --model lora_fused_model \
--max-tokens 50 \
--prompt "table: students
columns: Name, Age, School, Grade, Height, Weight
The value for ninth grade is 9.
Q: Count how many students there are in ninth grade.
A: "
```
Alternatively, append the hint to the end of the question:
```bash
python -m mlx_lm.generate --model lora_fused_model \
--max-tokens 50 \
--prompt "table: students
columns: Name, Age, School, Grade, Height, Weight
Q: Count how many students there are in ninth grade.(The value for ninth grade is 9.)
A: "
```
```
SELECT COUNT Name FROM students WHERE Grade = 9
```
Extra hints are easy to add to the prompt, and their placement does not seem to matter much.
## Uploading the model to the HuggingFace Hub
1. Join the [MLX Community](https://huggingface.co/mlx-community) organization.
2. Create a new model in the MLX Community organization: [mlx-community/Mistral-7B-v0.1-LoRA-Text2SQL](https://huggingface.co/mlx-community/Mistral-7B-v0.1-LoRA-Text2SQL)
3. Clone the repository [mlx-community/Mistral-7B-v0.1-LoRA-Text2SQL](https://huggingface.co/mlx-community/Mistral-7B-v0.1-LoRA-Text2SQL):
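The two prompt variants above differ only in where the hint goes. The sketch below expresses both placements as a string transformation. `add_hint` is a hypothetical helper. The inline form reproduces the second prompt exactly, with no space before the parenthesis.

```python
def add_hint(prompt: str, hint: str, inline: bool = False) -> str:
    """Insert an extra hint into a table/columns/Q/A prompt.

    inline=False: put the hint on its own line directly above the "Q:" line.
    inline=True:  append the hint, in parentheses, to the end of the question.
    Raises StopIteration if the prompt has no line starting with "Q: ".
    """
    lines = prompt.split("\n")
    q_idx = next(i for i, line in enumerate(lines) if line.startswith("Q: "))
    if inline:
        lines[q_idx] = f"{lines[q_idx]}({hint})"
    else:
        lines.insert(q_idx, hint)
    return "\n".join(lines)


if __name__ == "__main__":
    base = (
        "table: students\n"
        "columns: Name, Age, School, Grade, Height, Weight\n"
        "Q: Count how many students there are in ninth grade.\n"
        "A: "
    )
    print(add_hint(base, "The value for ninth grade is 9."))
    print(add_hint(base, "The value for ninth grade is 9.", inline=True))
```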
```bash
git clone https://huggingface.co/mlx-community/Mistral-7B-v0.1-LoRA-Text2SQL
```
4. Copy the generated model files (everything in the `lora_fused_model` directory) into the repository directory.
5. Upload the model to the HuggingFace Hub:
```bash
git add .
git commit -m "Fine tuning Text2SQL based on Mistral-7B using LoRA on MLX"
git push
```
### git push errors
1. Push fails with an authorization error
Error message:
```
Uploading LFS objects: 0% (0/2), 0 B | 0 B/s, done.
batch response: Authorization error.
error: failed to push some refs to 'https://huggingface.co/mlx-community/Mistral-7B-v0.1-LoRA-Text2SQL'
```
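Solution: put your username and a write-access token into the remote URL in `.git/config` (replace `write_token` with your own token):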
```bash
vim .git/config
```
```conf
[remote "origin"]
url = https://wangjunjian:write_token@huggingface.co/mlx-community/Mistral-7B-v0.1-LoRA-Text2SQL
fetch = +refs/heads/*:refs/remotes/origin/*
```
2. Files larger than 5GB cannot be uploaded
Error message:
```
warning: current Git remote contains credentials
batch response:
You need to configure your repository to enable upload of files > 5GB.
Run "huggingface-cli lfs-enable-largefiles ./path/to/your/repo" and try again.
```
Solution: log in, then enable large-file uploads for the local repository:
```bash
huggingface-cli login
huggingface-cli lfs-enable-largefiles /Users/junjian/HuggingFace/mlx-community/Mistral-7B-v0.1-LoRA-Text2SQL
```
## References
- [MLX Community](https://huggingface.co/mlx-community)
- [Fine-Tuning with LoRA or QLoRA](https://github.com/ml-explore/mlx-examples/tree/main/lora)
- [Generate Text with LLMs and MLX](https://github.com/ml-explore/mlx-examples/tree/main/llms)
- [Awesome Text2SQL](https://github.com/eosphoros-ai/Awesome-Text2SQL)
- [Awesome Text2SQL (Chinese)](https://github.com/eosphoros-ai/Awesome-Text2SQL/blob/main/README.zh.md)
- [Mistral AI](https://huggingface.co/mistralai)
- [A Beginner’s Guide to Fine-Tuning Mistral 7B Instruct Model](https://adithyask.medium.com/a-beginners-guide-to-fine-tuning-mistral-7b-instruct-model-0f39647b20fe)
- [Mistral Instruct 7B Finetuning on MedMCQA Dataset](https://saankhya.medium.com/mistral-instruct-7b-finetuning-on-medmcqa-dataset-6ec2532b1ff1)
- [Fine-tuning Mistral on your own data](https://github.com/brevdev/notebooks/blob/main/mistral-finetune-own-data.ipynb)
- [mlx-examples llms Mistral](https://github.com/ml-explore/mlx-examples/blob/main/llms/mistral/README.md)