---
license: mit
---

## [mlx-community/Mistral-7B-v0.1-LoRA-Text2SQL](https://huggingface.co/mlx-community/Mistral-7B-v0.1-LoRA-Text2SQL)

I have uploaded the model fine-tuned in this post to the HuggingFace Hub, so you can try it out.

### Install mlx-lm

```bash
pip install mlx-lm
```

### Generate SQL
```bash
python -m mlx_lm.generate --model mlx-community/Mistral-7B-v0.1-LoRA-Text2SQL \
                          --max-tokens 50 \
                          --prompt "table: students
columns: Name, Age, School, Grade, Height, Weight
Q: Which school did Wang Junjian come from?
A: "
```
```
SELECT School FROM Students WHERE Name = 'Wang Junjian'
```
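
The same request can also be issued from Python. The following is a minimal sketch, assuming the `load` and `generate` helpers exported by `mlx_lm` (argument names may differ between versions). `generate_sql` is a hypothetical wrapper name, and the import is deferred so the file can still be imported on a machine without MLX.

```python
# Prompt in the same layout as the command-line example above.
PROMPT = (
    "table: students\n"
    "columns: Name, Age, School, Grade, Height, Weight\n"
    "Q: Which school did Wang Junjian come from?\n"
    "A: "
)


def generate_sql(prompt: str, max_tokens: int = 50) -> str:
    """Hypothetical wrapper; needs Apple silicon with mlx-lm installed."""
    # Deferred import: mlx_lm is only required when this function is called.
    from mlx_lm import load, generate

    model, tokenizer = load("mlx-community/Mistral-7B-v0.1-LoRA-Text2SQL")
    return generate(model, tokenizer, prompt=prompt, max_tokens=max_tokens)


# Usage (downloads the model on first call):
# print(generate_sql(PROMPT))
```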


## [Fine-tuning Text2SQL based on Mistral-7B using LoRA on MLX (Part 1)](https://wangjunjian.com/mlx/lora/2024/01/23/Fine-tuning-Text2SQL-based-on-Mistral-7B-using-LoRA-on-MLX-1.html)

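📌 In Part 1, the dataset was not generated in the model's annotation format, so generation never stopped on its own and kept going until it reached the maximum number of tokens.

This time we fix that problem.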

## Dataset: WikiSQL

- [WikiSQL](https://github.com/salesforce/WikiSQL)
- [sqllama/sqllama-V0](https://huggingface.co/sqllama/sqllama-V0/blob/main/wikisql.ipynb)

### Modify the script mlx-examples/lora/data/wikisql.py
```py
if __name__ == "__main__":
    # ......
    for dataset, name, size in datasets:
        with open(f"data/{name}.jsonl", "w") as fid:
            for e, t in zip(range(size), dataset):
                """
                The text in the variable t looks like this:
                ------------------------
                <s>table: 1-1058787-1
                columns: Approximate Age, Virtues, Psycho Social Crisis, Significant Relationship, Existential Question [ not in citation given ], Examples
                Q: How many significant relationships list Will as a virtue?
                A: SELECT COUNT Significant Relationship FROM 1-1058787-1 WHERE Virtues = 'Will'</s>
                """
                t = t[3:]  # Drop the leading <s>, because the tokenizer adds <s> automatically
                json.dump({"text": t}, fid)
                fid.write("\n")
```

Run the script `data/wikisql.py` to generate the dataset.
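
A quick sanity check on the generated records can catch formatting mistakes before training. The sketch below is not part of `wikisql.py`; `strip_bos` and `check_record` are hypothetical helper names. `strip_bos` is a guarded variant of `t[3:]`, and `check_record` verifies that a JSONL line has no leading `<s>` and still ends with `</s>`.

```python
import json


def strip_bos(text: str, bos: str = "<s>") -> str:
    """Remove one leading BOS marker, but only if it is actually present."""
    return text[len(bos):] if text.startswith(bos) else text


def check_record(line: str) -> bool:
    """True if the JSONL line's text has no leading <s> and ends with </s>."""
    text = json.loads(line)["text"]
    return not text.startswith("<s>") and text.endswith("</s>")


# Example with a made-up record in the same layout as the WikiSQL samples.
raw = "<s>table: t\ncolumns: a, b\nQ: What is a?\nA: SELECT a FROM t</s>"
line = json.dumps({"text": strip_bos(raw)})
print(check_record(line))
```

Applied to each line of a generated `.jsonl` file, `check_record` should return `True` for every record.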

### Sample

```
table: 1-10753917-1
columns: Season, Driver, Team, Engine, Poles, Wins, Podiums, Points, Margin of defeat
Q: Which podiums did the alfa romeo team have?
A: SELECT Podiums FROM 1-10753917-1 WHERE Team = 'Alfa Romeo'</s>
```


## Fine-tuning

- Pretrained model: [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1)

### LoRA fine-tuning

```bash
python lora.py --model mistralai/Mistral-7B-v0.1 \
               --train \
               --iters 600
```
```
Total parameters 7243.436M
Trainable parameters 1.704M
python lora.py --model mistralai/Mistral-7B-v0.1 --train --iters 600  50.58s user 214.71s system 21% cpu 20:26.04 total
```

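Only about 2.35 per ten thousand (1.704M / 7243.436M × 10000), i.e. roughly 0.0235%, of the model parameters are trained.

LoRA fine-tuning for 600 iterations took 20 minutes 26 seconds and used 46 GB of memory.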
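
The figures above can be reproduced from the training output with a few lines of arithmetic. This is only a worked calculation; the variable names are illustrative, and the per-iteration time is an average over the whole wall-clock run, so it also includes model loading and any validation passes.

```python
# Values copied from the training output above.
total_params_m = 7243.436      # total parameters, in millions
trainable_params_m = 1.704     # trainable (LoRA) parameters, in millions
wall_seconds = 20 * 60 + 26.04  # "20:26.04 total"
iterations = 600

per_ten_thousand = trainable_params_m / total_params_m * 10_000
seconds_per_iter = wall_seconds / iterations

print(f"{per_ten_thousand:.2f} per ten thousand ({per_ten_thousand / 100:.4f}%)")
print(f"{seconds_per_iter:.2f} s per iteration on average")
```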

## Evaluation

Compute the perplexity (PPL) and cross-entropy loss on the test set.

```bash
python lora.py --model mistralai/Mistral-7B-v0.1 \
               --adapter-file adapters.npz \
               --test
```
```
Iter 100: Test loss 1.351, Test ppl 3.862.
Iter 200: Test loss 1.327, Test ppl 3.770.
Iter 300: Test loss 1.353, Test ppl 3.869.
Iter 400: Test loss 1.355, Test ppl 3.875.
Iter 500: Test loss 1.294, Test ppl 3.646.
Iter 600: Test loss 1.351, Test ppl 3.863.
```

| Iter | Test loss | Test ppl |
| :--: | --------: | -------: |
| 100  | 1.351     | 3.862    |
| 200  | 1.327     | 3.770    |
| 300  | 1.353     | 3.869    |
| 400  | 1.355     | 3.875    |
| 500  | 1.294     | 3.646    |
| 600  | 1.351     | 3.863    |

Evaluation used 26 GB of memory.
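
Perplexity is the exponential of the cross-entropy loss, so the two columns can be cross-checked against each other. The sketch below parses the log lines shown above, confirms that `exp(loss)` matches the reported perplexity up to rounding, and picks the iteration with the lowest test loss. The names `PATTERN`, `parse`, and `max_deviation` are illustrative.

```python
import math
import re

LOG = """\
Iter 100: Test loss 1.351, Test ppl 3.862.
Iter 200: Test loss 1.327, Test ppl 3.770.
Iter 300: Test loss 1.353, Test ppl 3.869.
Iter 400: Test loss 1.355, Test ppl 3.875.
Iter 500: Test loss 1.294, Test ppl 3.646.
Iter 600: Test loss 1.351, Test ppl 3.863.
"""

PATTERN = re.compile(r"Iter (\d+): Test loss ([\d.]+), Test ppl ([\d.]+)\.")


def parse(log: str) -> list[tuple[int, float, float]]:
    """Return (iteration, loss, ppl) tuples for every matching log line."""
    return [(int(i), float(loss), float(ppl)) for i, loss, ppl in PATTERN.findall(log)]


results = parse(LOG)

# The printed loss is rounded to three decimals, so allow a small tolerance.
max_deviation = max(abs(math.exp(loss) - ppl) for _, loss, ppl in results)
best = min(results, key=lambda r: r[1])

print(f"largest |exp(loss) - ppl| = {max_deviation:.4f}")
print(f"lowest test loss at iteration {best[0]}")
```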


## Fuse

```bash
python fuse.py --model mistralai/Mistral-7B-v0.1 \
               --adapter-file adapters.npz \
               --save-path lora_fused_model
```
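
Conceptually, fusing folds the low-rank LoRA update into the base weight, so the saved model no longer needs a separate adapter file. The sketch below illustrates that general idea on tiny hand-written matrices in plain Python. It is not the code of `fuse.py`; the shapes, the `scale` factor, and the helper names `matmul`, `transpose`, and `fuse` are assumptions made for illustration.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(m):
    return [list(row) for row in zip(*m)]


def fuse(W, A, B, scale):
    """Return W + scale * (A @ B)^T.

    Assumed shapes: W is (out, in), A is (in, rank), B is (rank, out).
    """
    delta = transpose(matmul(A, B))
    return [[w + scale * d for w, d in zip(w_row, d_row)] for w_row, d_row in zip(W, delta)]


# Toy example: in=3, out=2, rank=1.
W = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
A = [[1.0], [2.0], [3.0]]
B = [[0.5, -1.0]]
scale = 2.0
x = [[1.0, 1.0, 1.0]]

# Output with a separate adapter: x W^T + scale * (x A) B
base_out = matmul(x, transpose(W))
lora_out = matmul(matmul(x, A), B)
adapter_out = [[b + scale * d for b, d in zip(br, dr)] for br, dr in zip(base_out, lora_out)]

# Output with the fused weight: x W'^T
fused_out = matmul(x, transpose(fuse(W, A, B, scale)))

print(adapter_out == fused_out)
```

Both paths give the same output, which is why the fused model can be used directly by `mlx_lm.generate` without passing an adapter file.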


## Generate SQL

### What is Wang Junjian's name?

```bash
python -m mlx_lm.generate --model lora_fused_model \
                          --max-tokens 50 \
                          --prompt "table: students
columns: Name, Age, School, Grade, Height, Weight
Q: What is Wang Junjian's name?
A: "
```
```
SELECT Name FROM students WHERE Name = 'Wang Junjian'
```

### How old is Wang Junjian?

```bash
python -m mlx_lm.generate --model lora_fused_model \
                          --max-tokens 50 \
                          --prompt "table: students
columns: Name, Age, School, Grade, Height, Weight
Q: How old is Wang Junjian?
A: "
```
```
SELECT Age FROM Students WHERE Name = 'Wang Junjian'
```

### Which school did Wang Junjian come from?

```bash
python -m mlx_lm.generate --model lora_fused_model \
                          --max-tokens 50 \
                          --prompt "table: students
columns: Name, Age, School, Grade, Height, Weight
Q: Which school did Wang Junjian come from?
A: "
```
```
SELECT School FROM Students WHERE Name = 'Wang Junjian'
```

### Query Wang Junjian's name, age, and school information.

```bash
python -m mlx_lm.generate --model lora_fused_model \
                          --max-tokens 50 \
                          --prompt "table: students
columns: Name, Age, School, Grade, Height, Weight
Q: Query Wang Junjian’s name, age, and school information.
A: "
```
```
SELECT Name, Age, School FROM Students WHERE Name = 'Wang Junjian'
```

### Query all information about Wang Junjian.

```bash
python -m mlx_lm.generate --model lora_fused_model \
                          --max-tokens 50 \
                          --prompt "table: students
columns: Name, Age, School, Grade, Height, Weight
Q: Query all information about Wang Junjian.
A: "
```
```
SELECT Name FROM students WHERE Name = 'Wang Junjian'
```

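The model selects only `Name` instead of all columns, possibly because the training data is insufficient.

### Count how many students are in ninth grade.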

```bash
python -m mlx_lm.generate --model lora_fused_model \
                          --max-tokens 50 \
                          --prompt "table: students
columns: Name, Age, School, Grade, Height, Weight
Q: Count how many students there are in ninth grade.
A: "
```
```
SELECT COUNT Name FROM Students WHERE Grade = '9th'
```

### Count how many students are in ninth grade (the value for ninth grade is 9).

```bash
python -m mlx_lm.generate --model lora_fused_model \
                          --max-tokens 50 \
                          --prompt "table: students
columns: Name, Age, School, Grade, Height, Weight
The value for ninth grade is 9.
Q: Count how many students there are in ninth grade.
A: "
```

```bash
python -m mlx_lm.generate --model lora_fused_model \
                          --max-tokens 50 \
                          --prompt "table: students
columns: Name, Age, School, Grade, Height, Weight
Q: Count how many students there are in ninth grade.(The value for ninth grade is 9.)
A: "
```

```
SELECT COUNT Name FROM students WHERE Grade = 9
```

Additional hints are easy to add, and their position in the prompt does not matter much.
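
The two hint placements tried above can be generated from one small helper. This is a sketch only; `build_prompt` and its parameters are hypothetical names, and the layout simply mirrors the prompts used in this section.

```python
def build_prompt(table, columns, question, hint="", hint_in_question=False):
    """Build a Text2SQL prompt in the 'table / columns / Q / A' layout.

    hint: optional extra information for the model.
    hint_in_question: if True, append the hint in parentheses to the question;
    otherwise put it on its own line before the question.
    """
    lines = [f"table: {table}", "columns: " + ", ".join(columns)]
    if hint and not hint_in_question:
        lines.append(hint)
    q_line = f"Q: {question}"
    if hint and hint_in_question:
        q_line += f"({hint})"
    lines.append(q_line)
    lines.append("A: ")
    return "\n".join(lines)


columns = ["Name", "Age", "School", "Grade", "Height", "Weight"]
question = "Count how many students there are in ninth grade."
hint = "The value for ninth grade is 9."

print(build_prompt("students", columns, question, hint))
print(build_prompt("students", columns, question, hint, hint_in_question=True))
```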


## Upload the model to the HuggingFace Hub

1. Join the [MLX Community](https://huggingface.co/mlx-community) organization

2. Create a new model in the MLX Community organization: [mlx-community/Mistral-7B-v0.1-LoRA-Text2SQL](https://huggingface.co/mlx-community/Mistral-7B-v0.1-LoRA-Text2SQL)

3. Clone the repository [mlx-community/Mistral-7B-v0.1-LoRA-Text2SQL](https://huggingface.co/mlx-community/Mistral-7B-v0.1-LoRA-Text2SQL)

```bash
git clone https://huggingface.co/mlx-community/Mistral-7B-v0.1-LoRA-Text2SQL
```

4. Copy the generated model files (everything in the `lora_fused_model` directory) into the repository directory

5. Upload the model to the HuggingFace Hub

```bash
git add .
git commit -m "Fine tuning Text2SQL based on Mistral-7B using LoRA on MLX"
git push
```

### git push errors

1. Cannot push

Error message:

```
Uploading LFS objects:   0% (0/2), 0 B | 0 B/s, done.
batch response: Authorization error.
error: failed to push some refs to 'https://huggingface.co/mlx-community/Mistral-7B-v0.1-LoRA-Text2SQL'
```

Solution: add your user name and a write token to the remote URL.

```bash
vim .git/config
```
```conf
[remote "origin"]
    url = https://wangjunjian:write_token@huggingface.co/mlx-community/Mistral-7B-v0.1-LoRA-Text2SQL
    fetch = +refs/heads/*:refs/remotes/origin/*
```

2. Cannot upload files larger than 5GB

Error message:

```
warning: current Git remote contains credentials
batch response:
You need to configure your repository to enable upload of files > 5GB.
Run "huggingface-cli lfs-enable-largefiles ./path/to/your/repo" and try again.
```


Solution:

```bash
huggingface-cli login
huggingface-cli lfs-enable-largefiles /Users/junjian/HuggingFace/mlx-community/Mistral-7B-v0.1-LoRA-Text2SQL
```


## References
- [MLX Community](https://huggingface.co/mlx-community)
- [Fine-Tuning with LoRA or QLoRA](https://github.com/ml-explore/mlx-examples/tree/main/lora)
- [Generate Text with LLMs and MLX](https://github.com/ml-explore/mlx-examples/tree/main/llms)
- [Awesome Text2SQL](https://github.com/eosphoros-ai/Awesome-Text2SQL)
- [Awesome Text2SQL (Chinese)](https://github.com/eosphoros-ai/Awesome-Text2SQL/blob/main/README.zh.md)
- [Mistral AI](https://huggingface.co/mistralai)
- [A Beginner’s Guide to Fine-Tuning Mistral 7B Instruct Model](https://adithyask.medium.com/a-beginners-guide-to-fine-tuning-mistral-7b-instruct-model-0f39647b20fe)
- [Mistral Instruct 7B Finetuning on MedMCQA Dataset](https://saankhya.medium.com/mistral-instruct-7b-finetuning-on-medmcqa-dataset-6ec2532b1ff1)
- [Fine-tuning Mistral on your own data](https://github.com/brevdev/notebooks/blob/main/mistral-finetune-own-data.ipynb)
- [mlx-examples llms Mistral](https://github.com/ml-explore/mlx-examples/blob/main/llms/mistral/README.md)