Upload ms-swift/docs/source/Instruction/采样.md with huggingface_hub
Browse files
ms-swift/docs/source/Instruction/采样.md
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 采样
|
| 2 |
+
|
| 3 |
+
采样是SWIFT新支持的重要能力之一,这部分可以理解为`test-time compute`的落地实现。同时,该能力对RFT(强化微调)的实现也至关重要。
|
| 4 |
+
|
| 5 |
+
## 能力介绍
|
| 6 |
+
|
| 7 |
+
SWIFT的sample能力可以使用下面的例子进行:
|
| 8 |
+
```shell
|
| 9 |
+
swift sample --model LLM-Research/Meta-Llama-3.1-8B-Instruct --sampler_engine pt --num_return_sequences 5 --dataset AI-ModelScope/alpaca-gpt4-data-zh#5
|
| 10 |
+
```
|
| 11 |
+
在当前文件夹的`sample_output`目录下,会生成以时间戳为文件名的jsonl文件,该文件应该包含25行,每一行都是一个完整`messages`格式的数据。
|
| 12 |
+
|
| 13 |
+
采样的参数列表请参考[这里](命令行参数.md)。
|
| 14 |
+
|
| 15 |
+
## 环境准备
|
| 16 |
+
|
| 17 |
+
```shell
|
| 18 |
+
pip install ms-swift[llm] -U
|
| 19 |
+
```
|
| 20 |
+
|
| 21 |
+
或从源代码安装:
|
| 22 |
+
|
| 23 |
+
```shell
|
| 24 |
+
git clone https://github.com/modelscope/ms-swift.git
|
| 25 |
+
cd ms-swift
|
| 26 |
+
pip install -e '.[llm]'
|
| 27 |
+
```
|
| 28 |
+
|
| 29 |
+
## 使用PRM和ORM进行结果过滤
|
| 30 |
+
|
| 31 |
+
采样重要的能力就是对过程和结果进行监督,这可以通过设置额外参数来支持。
|
| 32 |
+
|
| 33 |
+
```shell
|
| 34 |
+
swift sample --model LLM-Research/Meta-Llama-3.1-8B-Instruct --sampler_engine lmdeploy --num_return_sequences 5 --n_best_to_keep 2 --dataset tastelikefeet/competition_math#5 --prm_model AI-ModelScope/GRM-llama3.2-3B-rewardmodel-ft --orm_model math
|
| 35 |
+
```
|
| 36 |
+
|
| 37 |
+
在当前文件夹的`sample_output`目录下,会生成以时间戳为文件名的jsonl文件,该文件**至多包含**10行,每一行都是一个完整`messages`格式的数据。
|
| 38 |
+
> 之所以至多包含10行,是因为虽然设置了共处理5个数据,每个数据保留2个(n_best_to_keep),但是orm可能会校验失败,失败数据不会保留到文件中。
|
| 39 |
+
> 另外,增加了--prm_model或--orm_model后文件格式有所不同,包含了rejected_response key,内容来自于prm评分最低的行。
|
| 40 |
+
|
| 41 |
+
## 自定义PRM或ORM
|
| 42 |
+
|
| 43 |
+
PRM和ORM的自定义可以在plugin中按照现有代码增加一个新的实现。例如:
|
| 44 |
+
```python
|
| 45 |
+
class CustomPRM:
|
| 46 |
+
|
| 47 |
+
# 构造需要是无参的
|
| 48 |
+
def __init__(self):
|
| 49 |
+
# init here
|
| 50 |
+
pass
|
| 51 |
+
|
| 52 |
+
def __call__(self, infer_requests: List[InferRequest], ground_truths: List[str], **kwargs) -> List[Union[float, List[float]]]:
|
| 53 |
+
...
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
prms = {'custom': CustomPRM}
|
| 57 |
+
```
|
| 58 |
+
|
| 59 |
+
之后在命令行中使用`--prm_model custom`即可。
|
| 60 |
+
|
| 61 |
+
## 显存控制
|
| 62 |
+
|
| 63 |
+
如果被采样模型和PRM共同加载进显存,则可能出现OOM的问题。因此采样可以分为两段进行:
|
| 64 |
+
|
| 65 |
+
- 第一段指定`--model`和``--sampler_engine`,同时不指定`--orm_model`和`--prm_model`,仅进行采样,并存储为文件
|
| 66 |
+
- 第二段指定`--sampler_engine no`,指定`--orm_model`和`--prm_model`,并同时指定`--cache_files`,仅进行RM数据过滤,不重新采样
|
| 67 |
+
|
| 68 |
+
通过两段方式可以每次仅加载一个模型,防止OOM。
|
| 69 |
+
|
| 70 |
+
## 实际例子
|
| 71 |
+
|
| 72 |
+
请参考[强化微调脚本](https://github.com/modelscope/ms-swift/tree/main/examples/train/rft/rft.py)。该脚本给出了使用采样进行强化微调的实际例子。
|
| 73 |
+
|
| 74 |
+
> 注意:该脚本的实际效果和模型、数据、RM的质量强相关,因此仅作为样例出现,用户请自行修改该脚本并训练自己的RM和generator模型。
|
| 75 |
+
|
| 76 |
+
## 大模型蒸馏采样
|
| 77 |
+
|
| 78 |
+
SWIFT的sample支持使用OpenAI API的方式,用大模型蒸馏数据,如下示例:
|
| 79 |
+
```shell
|
| 80 |
+
OPENAI_API_KEY="your_api_key" \
|
| 81 |
+
swift sample \
|
| 82 |
+
--sampler_type distill \
|
| 83 |
+
--sampler_engine client \
|
| 84 |
+
--model deepseek-r1 \
|
| 85 |
+
--stream true \
|
| 86 |
+
--dataset tastelikefeet/competition_math#5 \
|
| 87 |
+
--num_return_sequences 1 \
|
| 88 |
+
--temperature 0.6 \
|
| 89 |
+
--top_p 0.95 \
|
| 90 |
+
--engine_kwargs '{"base_url":"https://dashscope.aliyuncs.com/compatible-mode/v1"}'
|
| 91 |
+
```
|
| 92 |
+
在以上示例中,base_url和model分别是api地址和模型名称,stream表示发起请求的stream参数。
|
| 93 |
+
|
| 94 |
+
注意,对于Deepseek-R1系列模型,输出会被格式化为:`<think>{reasoning_content}</think>\n\n<answer>{content}</answer>`。
|