Sampling
Sampling is one of the newly supported key capabilities of SWIFT. This feature can be understood as the practical implementation of test-time compute. Additionally, this capability is crucial for the implementation of RFT (Reinforcement Fine-Tuning).
Capability Introduction
The sampling capability of SWIFT can be demonstrated with the following example:
swift sample --model LLM-Research/Meta-Llama-3.1-8B-Instruct --sampler_engine pt --num_return_sequences 5 --dataset AI-ModelScope/alpaca-gpt4-data-zh#5
A jsonl file with a timestamp as the filename will be generated in the sample_output directory of the current folder. This file should contain 25 lines, each representing a complete messages format data.
For a list of sampling parameters, please refer to here.
Environment Setup
pip install ms-swift[llm] -U
Or install swift from source:
git clone https://github.com/modelscope/ms-swift.git
cd ms-swift
pip install -e '.[llm]'
Using PRM and ORM for Result Filtering
An important capability of sampling is supervising the process and results, which can be supported by setting additional parameters.
swift sample --model LLM-Research/Meta-Llama-3.1-8B-Instruct --sampler_engine lmdeploy --num_return_sequences 5 --n_best_to_keep 2 --dataset tastelikefeet/competition_math#5 --prm_model AI-ModelScope/GRM-llama3.2-3B-rewardmodel-ft --orm_model math
A jsonl file with a timestamp as the filename will be generated in the sample_output directory of the current folder. This file will contain at most 10 lines, each representing a complete messages format data.
The reason it contains at most 10 lines is that although 5 data points are processed in total, and 2 are kept for each data point (
n_best_to_keep), ORM may fail some validations, and failed data will not be retained in the file. Additionally, after adding--prm_modelor--orm_model, the file format is slightly different and includes arejected_responsekey, which contains the responses with the lowest PRM scores.
Customizing PRM or ORM
PRM and ORM can be customized by adding a new implementation in the plugin according to the existing code. For example:
class CustomPRM:
# The constructor should be parameterless
def __init__(self):
# Initialize here
pass
def __call__(self, infer_requests: List[InferRequest], ground_truths: List[str], **kwargs) -> List[Union[float, List[float]]]:
...
prms = {'custom': CustomPRM}
Afterward, use --prm_model custom in the command line.
Memory Control
If the sampled model and PRM are loaded into memory simultaneously, it may lead to an OOM (Out of Memory) issue. To address this, sampling can be divided into two stages:
- Stage 1: Specify
--modeland--sampler_enginewithout specifying--orm_modeland--prm_model. Perform sampling only and save the results to a file. - Stage 2: Specify
--sampler_engine no, along with--orm_modeland--prm_model, and also specify--cache_files. Perform only RM data filtering without re-sampling.
By dividing the process into two stages, only one model is loaded at a time, avoiding OOM issues.
Practical Example
Please refer to the Reinforcement Fine-Tuning Script. This script provides a practical example of using sampling for reinforcement fine-tuning.
Note: The actual effectiveness of this script is strongly related to the quality of the model, data, and RM. Therefore, it is presented only as an example. Users should modify this script and train their own RM and generator models accordingly.
Sampling From Large Model
SWIFT's sample supports using the OpenAI API to distill data with large models. Example:
OPENAI_API_KEY="your_api_key" \
swift sample \
--sampler_type distill \
--sampler_engine client \
--model deepseek-r1 \
--stream true \
--dataset tastelikefeet/competition_math#5 \
--num_return_sequences 1 \
--temperature 0.6 \
--top_p 0.95 \
--engine_kwargs '{"base_url":"https://dashscope.aliyuncs.com/compatible-mode/v1"}'
In this example:
base_url and model represent the API endpoint and model name, respectively. stream indicates the stream parameter for the request.
Note: For Deepseek-R1 series models, the output will be formatted as:<thinking>{reasoning_content}</thinking>\n\n<answer>{content}</answer>.