Text Generation
Transformers
PyTorch
Chinese
English
llama
conversational
custom_code
text-generation-inference
Instructions to use openbmb/BitCPM-CANN-1B-unquantized with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use openbmb/BitCPM-CANN-1B-unquantized with Transformers:
# Use a pipeline as a high-level helper from transformers import pipeline pipe = pipeline("text-generation", model="openbmb/BitCPM-CANN-1B-unquantized", trust_remote_code=True) messages = [ {"role": "user", "content": "Who are you?"}, ] pipe(messages)# Load model directly from transformers import AutoTokenizer, AutoModelForCausalLM tokenizer = AutoTokenizer.from_pretrained("openbmb/BitCPM-CANN-1B-unquantized", trust_remote_code=True) model = AutoModelForCausalLM.from_pretrained("openbmb/BitCPM-CANN-1B-unquantized", trust_remote_code=True) messages = [ {"role": "user", "content": "Who are you?"}, ] inputs = tokenizer.apply_chat_template( messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt", ).to(model.device) outputs = model.generate(**inputs, max_new_tokens=40) print(tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:])) - Notebooks
- Google Colab
- Kaggle
- Local Apps
- vLLM
How to use openbmb/BitCPM-CANN-1B-unquantized with vLLM:
Install from pip and serve model
# Install vLLM from pip: pip install vllm # Start the vLLM server: vllm serve "openbmb/BitCPM-CANN-1B-unquantized" # Call the server using curl (OpenAI-compatible API): curl -X POST "http://localhost:8000/v1/chat/completions" \ -H "Content-Type: application/json" \ --data '{ "model": "openbmb/BitCPM-CANN-1B-unquantized", "messages": [ { "role": "user", "content": "What is the capital of France?" } ] }'Use Docker
docker model run hf.co/openbmb/BitCPM-CANN-1B-unquantized
- SGLang
How to use openbmb/BitCPM-CANN-1B-unquantized with SGLang:
Install from pip and serve model
# Install SGLang from pip: pip install sglang # Start the SGLang server: python3 -m sglang.launch_server \ --model-path "openbmb/BitCPM-CANN-1B-unquantized" \ --host 0.0.0.0 \ --port 30000 # Call the server using curl (OpenAI-compatible API): curl -X POST "http://localhost:30000/v1/chat/completions" \ -H "Content-Type: application/json" \ --data '{ "model": "openbmb/BitCPM-CANN-1B-unquantized", "messages": [ { "role": "user", "content": "What is the capital of France?" } ] }'Use Docker images
docker run --gpus all \ --shm-size 32g \ -p 30000:30000 \ -v ~/.cache/huggingface:/root/.cache/huggingface \ --env "HF_TOKEN=<secret>" \ --ipc=host \ lmsysorg/sglang:latest \ python3 -m sglang.launch_server \ --model-path "openbmb/BitCPM-CANN-1B-unquantized" \ --host 0.0.0.0 \ --port 30000 # Call the server using curl (OpenAI-compatible API): curl -X POST "http://localhost:30000/v1/chat/completions" \ -H "Content-Type: application/json" \ --data '{ "model": "openbmb/BitCPM-CANN-1B-unquantized", "messages": [ { "role": "user", "content": "What is the capital of France?" } ] }' - Docker Model Runner
How to use openbmb/BitCPM-CANN-1B-unquantized with Docker Model Runner:
docker model run hf.co/openbmb/BitCPM-CANN-1B-unquantized
Add example/ folder with training scripts
Browse files- example/README.md +131 -0
- example/ds_config.json +29 -0
- example/ds_config_z2.json +22 -0
- example/requirements.txt +8 -0
- example/run.sh +37 -0
- example/run_sft.sh +38 -0
- example/train.py +203 -0
- example/train_sft.py +424 -0
example/README.md
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# BitCPM4 Continue Pretrain Example
|
| 2 |
+
|
| 3 |
+
This project provides scripts for continue pretraining **BitCPM4-CANN-1B-unquantized**.
|
| 4 |
+
|
| 5 |
+
## Environment Setup
|
| 6 |
+
|
| 7 |
+
### Docker Image
|
| 8 |
+
|
| 9 |
+
Use the following Huawei NPU image:
|
| 10 |
+
|
| 11 |
+
```
|
| 12 |
+
swr.cn-south-1.myhuaweicloud.com/ascendhub/mindspeed-llm:openeuler22.03-mindspeed-llm-2.3.0-a3-arm
|
| 13 |
+
```
|
| 14 |
+
|
| 15 |
+
Other Huawei NPU images may also work but have not been fully tested.
|
| 16 |
+
|
| 17 |
+
### Install Dependencies
|
| 18 |
+
|
| 19 |
+
After entering the container, install the Python dependencies:
|
| 20 |
+
|
| 21 |
+
```bash
|
| 22 |
+
pip install -r requirements.txt
|
| 23 |
+
```
|
| 24 |
+
|
| 25 |
+
Dependency list:
|
| 26 |
+
|
| 27 |
+
| Package | Version |
|
| 28 |
+
| --- | --- |
|
| 29 |
+
| transformers | 4.46.3 |
|
| 30 |
+
| tokenizers | 0.20.3 |
|
| 31 |
+
| accelerate | 1.1.1 |
|
| 32 |
+
| deepspeed | 0.16.2 |
|
| 33 |
+
| datasets | 3.1.0 |
|
| 34 |
+
| safetensors | 0.4.5 |
|
| 35 |
+
| pyarrow | 17.0.0 |
|
| 36 |
+
| tensorboard | 2.18.0 |
|
| 37 |
+
|
| 38 |
+
## Dataset
|
| 39 |
+
|
| 40 |
+
The test dataset used is [C4-Pro](https://huggingface.co/datasets/gair-prox/c4-pro), stored in parquet format after downloading.
|
| 41 |
+
|
| 42 |
+
## Usage
|
| 43 |
+
|
| 44 |
+
Modify the path configuration in `run.sh`:
|
| 45 |
+
|
| 46 |
+
```bash
|
| 47 |
+
MODEL_PATH="/path/to/BitCPM4-CANN-1B-unquantized/"
|
| 48 |
+
DATA_PATH="/path/to/c4-pro/data/your_file.parquet"
|
| 49 |
+
```
|
| 50 |
+
|
| 51 |
+
Then start training:
|
| 52 |
+
|
| 53 |
+
```bash
|
| 54 |
+
bash run.sh
|
| 55 |
+
```
|
| 56 |
+
|
| 57 |
+
By default, the script trains for 500 steps using 8 devices, DeepSpeed ZeRO-2, and bf16 precision.
|
| 58 |
+
|
| 59 |
+
## Training Results Reference
|
| 60 |
+
|
| 61 |
+
Below is the loss curve for the first 100 steps (learning rate warmup covers the first 50 steps):
|
| 62 |
+
|
| 63 |
+
| Step | Loss | Learning Rate | Epoch |
|
| 64 |
+
| --- | --- | --- | --- |
|
| 65 |
+
| 2 | 2.7920 | 1.60e-06 | 0.01 |
|
| 66 |
+
| 4 | 2.8012 | 3.20e-06 | 0.02 |
|
| 67 |
+
| 6 | 2.7984 | 4.80e-06 | 0.03 |
|
| 68 |
+
| 8 | 2.7839 | 6.40e-06 | 0.04 |
|
| 69 |
+
| 10 | 2.8084 | 8.00e-06 | 0.05 |
|
| 70 |
+
| 12 | 2.8064 | 9.60e-06 | 0.06 |
|
| 71 |
+
| 14 | 2.7994 | 1.12e-05 | 0.07 |
|
| 72 |
+
| 16 | 2.7463 | 1.28e-05 | 0.08 |
|
| 73 |
+
| 18 | 2.7580 | 1.44e-05 | 0.09 |
|
| 74 |
+
| 20 | 2.8007 | 1.60e-05 | 0.10 |
|
| 75 |
+
| 22 | 2.8916 | 1.76e-05 | 0.12 |
|
| 76 |
+
| 24 | 2.8144 | 1.92e-05 | 0.13 |
|
| 77 |
+
| 26 | 2.7723 | 2.08e-05 | 0.14 |
|
| 78 |
+
| 28 | 2.7556 | 2.24e-05 | 0.15 |
|
| 79 |
+
| 30 | 2.7414 | 2.40e-05 | 0.16 |
|
| 80 |
+
| 32 | 2.7469 | 2.56e-05 | 0.17 |
|
| 81 |
+
| 34 | 2.7428 | 2.72e-05 | 0.18 |
|
| 82 |
+
| 36 | 2.7392 | 2.88e-05 | 0.19 |
|
| 83 |
+
| 38 | 2.7132 | 3.04e-05 | 0.20 |
|
| 84 |
+
| 40 | 2.7008 | 3.20e-05 | 0.21 |
|
| 85 |
+
| 42 | 2.7547 | 3.36e-05 | 0.22 |
|
| 86 |
+
| 44 | 2.7151 | 3.52e-05 | 0.23 |
|
| 87 |
+
| 46 | 2.7119 | 3.68e-05 | 0.24 |
|
| 88 |
+
| 48 | 2.7029 | 3.84e-05 | 0.25 |
|
| 89 |
+
| 50 | 2.6803 | 4.00e-05 | 0.26 |
|
| 90 |
+
| 52 | 2.6980 | 4.00e-05 | 0.27 |
|
| 91 |
+
| 54 | 2.6923 | 4.00e-05 | 0.28 |
|
| 92 |
+
| 56 | 2.7068 | 4.00e-05 | 0.29 |
|
| 93 |
+
| 58 | 2.6965 | 4.00e-05 | 0.30 |
|
| 94 |
+
| 60 | 2.7179 | 3.99e-05 | 0.31 |
|
| 95 |
+
| 62 | 2.7119 | 3.99e-05 | 0.32 |
|
| 96 |
+
| 64 | 2.7178 | 3.99e-05 | 0.33 |
|
| 97 |
+
| 66 | 2.7069 | 3.99e-05 | 0.35 |
|
| 98 |
+
| 68 | 2.6870 | 3.98e-05 | 0.36 |
|
| 99 |
+
| 70 | 2.6775 | 3.98e-05 | 0.37 |
|
| 100 |
+
| 72 | 2.7038 | 3.98e-05 | 0.38 |
|
| 101 |
+
| 74 | 2.6924 | 3.97e-05 | 0.39 |
|
| 102 |
+
| 76 | 2.7061 | 3.97e-05 | 0.40 |
|
| 103 |
+
| 78 | 2.6929 | 3.96e-05 | 0.41 |
|
| 104 |
+
| 80 | 2.6787 | 3.96e-05 | 0.42 |
|
| 105 |
+
| 82 | 2.6749 | 3.95e-05 | 0.43 |
|
| 106 |
+
| 84 | 2.6909 | 3.94e-05 | 0.44 |
|
| 107 |
+
| 86 | 2.6893 | 3.94e-05 | 0.45 |
|
| 108 |
+
| 88 | 2.6788 | 3.93e-05 | 0.46 |
|
| 109 |
+
| 90 | 2.6831 | 3.92e-05 | 0.47 |
|
| 110 |
+
| 92 | 2.7039 | 3.91e-05 | 0.48 |
|
| 111 |
+
| 94 | 2.6619 | 3.91e-05 | 0.49 |
|
| 112 |
+
| 96 | 2.6903 | 3.90e-05 | 0.50 |
|
| 113 |
+
| 98 | 2.6993 | 3.89e-05 | 0.51 |
|
| 114 |
+
| 100 | 2.6891 | 3.88e-05 | 0.52 |
|
| 115 |
+
| 102 | 2.6739 | 3.87e-05 | 0.53 |
|
| 116 |
+
|
| 117 |
+
> **Note:** BitCPM has its own training dataset and data mixture. It is expected that the loss continues to decrease when continue pretraining on open-source datasets.
|
| 118 |
+
|
| 119 |
+
As shown in the table, the loss gradually decreases from ~2.79 to ~2.67, indicating a stable training process and that the model is learning normally.
|
| 120 |
+
|
| 121 |
+
## File Description
|
| 122 |
+
|
| 123 |
+
| File | Description |
|
| 124 |
+
| --- | --- |
|
| 125 |
+
| `train.py` | Training script based on HuggingFace Trainer + DeepSpeed |
|
| 126 |
+
| `run.sh` | Launch script with training hyperparameter configuration |
|
| 127 |
+
| `train_sft.py` | Supervised fine-tuning script based on HuggingFace Trainer + DeepSpeed |
|
| 128 |
+
| `run_sft.sh` | Launch script for SFT with hyperparameter configuration |
|
| 129 |
+
| `ds_config.json` | DeepSpeed ZeRO-3 configuration (with CPU offload) |
|
| 130 |
+
| `ds_config_z2.json` | DeepSpeed ZeRO-2 configuration (used by default) |
|
| 131 |
+
| `requirements.txt` | Python dependency list |
|
example/ds_config.json
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bf16": {
|
| 3 |
+
"enabled": true
|
| 4 |
+
},
|
| 5 |
+
"zero_optimization": {
|
| 6 |
+
"stage": 3,
|
| 7 |
+
"offload_optimizer": {
|
| 8 |
+
"device": "cpu",
|
| 9 |
+
"pin_memory": true
|
| 10 |
+
},
|
| 11 |
+
"offload_param": {
|
| 12 |
+
"device": "none"
|
| 13 |
+
},
|
| 14 |
+
"overlap_comm": true,
|
| 15 |
+
"contiguous_gradients": true,
|
| 16 |
+
"sub_group_size": 1e9,
|
| 17 |
+
"reduce_bucket_size": 2e8,
|
| 18 |
+
"stage3_prefetch_bucket_size": 2e8,
|
| 19 |
+
"stage3_param_persistence_threshold": 1e5,
|
| 20 |
+
"stage3_max_live_parameters": 2e9,
|
| 21 |
+
"stage3_max_reuse_distance": 2e9,
|
| 22 |
+
"stage3_gather_16bit_weights_on_model_save": true
|
| 23 |
+
},
|
| 24 |
+
"gradient_accumulation_steps": "auto",
|
| 25 |
+
"gradient_clipping": "auto",
|
| 26 |
+
"train_batch_size": "auto",
|
| 27 |
+
"train_micro_batch_size_per_gpu": "auto",
|
| 28 |
+
"wall_clock_breakdown": false
|
| 29 |
+
}
|
example/ds_config_z2.json
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"bf16": {
|
| 3 |
+
"enabled": true
|
| 4 |
+
},
|
| 5 |
+
"zero_optimization": {
|
| 6 |
+
"stage": 2,
|
| 7 |
+
"offload_optimizer": {
|
| 8 |
+
"device": "none"
|
| 9 |
+
},
|
| 10 |
+
"allgather_partitions": true,
|
| 11 |
+
"allgather_bucket_size": 2e8,
|
| 12 |
+
"overlap_comm": true,
|
| 13 |
+
"reduce_scatter": true,
|
| 14 |
+
"reduce_bucket_size": 2e8,
|
| 15 |
+
"contiguous_gradients": true
|
| 16 |
+
},
|
| 17 |
+
"gradient_accumulation_steps": "auto",
|
| 18 |
+
"gradient_clipping": "auto",
|
| 19 |
+
"train_batch_size": "auto",
|
| 20 |
+
"train_micro_batch_size_per_gpu": "auto",
|
| 21 |
+
"wall_clock_breakdown": false
|
| 22 |
+
}
|
example/requirements.txt
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
transformers==4.46.3
|
| 2 |
+
tokenizers==0.20.3
|
| 3 |
+
accelerate==1.1.1
|
| 4 |
+
deepspeed==0.16.2
|
| 5 |
+
datasets==3.1.0
|
| 6 |
+
safetensors==0.4.5
|
| 7 |
+
pyarrow==17.0.0
|
| 8 |
+
tensorboard==2.18.0
|
example/run.sh
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
MODEL_PATH="/model/BitCPM/BitCPM4-CANN-1B-unquantized/"
|
| 4 |
+
DATA_PATH="/dataset/c4-pro/data/000_1_7.parquet"
|
| 5 |
+
OUTPUT_DIR="./output"
|
| 6 |
+
DS_CONFIG="./ds_config_z2.json"
|
| 7 |
+
|
| 8 |
+
NUM_GPUS=8
|
| 9 |
+
BATCH_SIZE_PER_GPU=8
|
| 10 |
+
GRAD_ACCUM_STEPS=8
|
| 11 |
+
MAX_SEQ_LENGTH=1024
|
| 12 |
+
|
| 13 |
+
export ASCEND_RT_VISIBLE_DEVICES=8,9,10,11,12,13,14,15
|
| 14 |
+
|
| 15 |
+
torchrun --nproc_per_node=$NUM_GPUS train.py \
|
| 16 |
+
--model_name_or_path $MODEL_PATH \
|
| 17 |
+
--data_path $DATA_PATH \
|
| 18 |
+
--max_seq_length $MAX_SEQ_LENGTH \
|
| 19 |
+
--output_dir $OUTPUT_DIR \
|
| 20 |
+
--per_device_train_batch_size $BATCH_SIZE_PER_GPU \
|
| 21 |
+
--gradient_accumulation_steps $GRAD_ACCUM_STEPS \
|
| 22 |
+
--max_steps 500 \
|
| 23 |
+
--learning_rate 4e-5 \
|
| 24 |
+
--lr_scheduler_type cosine \
|
| 25 |
+
--warmup_ratio 0.1 \
|
| 26 |
+
--weight_decay 1e-2 \
|
| 27 |
+
--logging_steps 2 \
|
| 28 |
+
--save_steps 500 \
|
| 29 |
+
--save_total_limit 3 \
|
| 30 |
+
--bf16 \
|
| 31 |
+
--deepspeed $DS_CONFIG \
|
| 32 |
+
--gradient_checkpointing \
|
| 33 |
+
--seed 42 \
|
| 34 |
+
--dataloader_num_workers 4 \
|
| 35 |
+
--report_to tensorboard \
|
| 36 |
+
--logging_dir /data/tensorboard/ \
|
| 37 |
+
--gradient_checkpointing_kwargs '{"use_reentrant": false}'
|
example/run_sft.sh
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
MODEL_PATH="/model/BitCPM/BitCPM4-CANN-3B-unquantized/"
|
| 4 |
+
DATA_PATH=""
|
| 5 |
+
OUTPUT_DIR="./output_sft"
|
| 6 |
+
DS_CONFIG="./ds_config.json"
|
| 7 |
+
|
| 8 |
+
NUM_GPUS=8
|
| 9 |
+
BATCH_SIZE_PER_GPU=2
|
| 10 |
+
GRAD_ACCUM_STEPS=1
|
| 11 |
+
MAX_SEQ_LENGTH=4096
|
| 12 |
+
|
| 13 |
+
export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
|
| 14 |
+
|
| 15 |
+
torchrun --nproc_per_node=$NUM_GPUS train_sft.py \
|
| 16 |
+
--model_name_or_path $MODEL_PATH \
|
| 17 |
+
--data_path $DATA_PATH \
|
| 18 |
+
--max_seq_length $MAX_SEQ_LENGTH \
|
| 19 |
+
--output_dir $OUTPUT_DIR \
|
| 20 |
+
--per_device_train_batch_size $BATCH_SIZE_PER_GPU \
|
| 21 |
+
--gradient_accumulation_steps $GRAD_ACCUM_STEPS \
|
| 22 |
+
--num_train_epochs 3 \
|
| 23 |
+
--learning_rate 2e-5 \
|
| 24 |
+
--lr_scheduler_type cosine \
|
| 25 |
+
--warmup_ratio 0.03 \
|
| 26 |
+
--weight_decay 0.0 \
|
| 27 |
+
--logging_steps 2 \
|
| 28 |
+
--save_steps 500 \
|
| 29 |
+
--save_total_limit 3 \
|
| 30 |
+
--bf16 \
|
| 31 |
+
--deepspeed $DS_CONFIG \
|
| 32 |
+
--gradient_checkpointing \
|
| 33 |
+
--seed 42 \
|
| 34 |
+
--dataloader_num_workers 4 \
|
| 35 |
+
--report_to tensorboard \
|
| 36 |
+
--logging_dir /data/tensorboard/sft \
|
| 37 |
+
--train_on_prompt false \
|
| 38 |
+
--gradient_checkpointing_kwargs '{"use_reentrant": false}'
|
example/train.py
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Continual pretraining script for CPM-2B model using DeepSpeed + HuggingFace Trainer.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import json
|
| 7 |
+
import math
|
| 8 |
+
import logging
|
| 9 |
+
from dataclasses import dataclass, field
|
| 10 |
+
from typing import Optional
|
| 11 |
+
|
| 12 |
+
import contextlib
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
from datasets import load_dataset
|
| 16 |
+
from transformers import (
|
| 17 |
+
AutoModelForCausalLM,
|
| 18 |
+
AutoTokenizer,
|
| 19 |
+
AutoConfig,
|
| 20 |
+
Trainer,
|
| 21 |
+
TrainingArguments,
|
| 22 |
+
HfArgumentParser,
|
| 23 |
+
DataCollatorForLanguageModeling,
|
| 24 |
+
set_seed,
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
import deepspeed
|
| 28 |
+
_orig_no_sync = deepspeed.DeepSpeedEngine.no_sync
|
| 29 |
+
|
| 30 |
+
@contextlib.contextmanager
|
| 31 |
+
def _patched_no_sync(self):
|
| 32 |
+
try:
|
| 33 |
+
with _orig_no_sync(self):
|
| 34 |
+
yield
|
| 35 |
+
except AssertionError:
|
| 36 |
+
yield
|
| 37 |
+
|
| 38 |
+
deepspeed.DeepSpeedEngine.no_sync = _patched_no_sync
|
| 39 |
+
|
| 40 |
+
logger = logging.getLogger(__name__)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
@dataclass
|
| 44 |
+
class ModelArguments:
|
| 45 |
+
model_name_or_path: str = field(
|
| 46 |
+
metadata={"help": "Path to pretrained model or model identifier"}
|
| 47 |
+
)
|
| 48 |
+
torch_dtype: Optional[str] = field(
|
| 49 |
+
default="bfloat16",
|
| 50 |
+
metadata={"help": "torch dtype for model weights (float16, bfloat16, float32)"},
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
@dataclass
|
| 55 |
+
class DataArguments:
|
| 56 |
+
data_path: str = field(
|
| 57 |
+
metadata={"help": "Path to training data (parquet file or directory)"}
|
| 58 |
+
)
|
| 59 |
+
max_seq_length: int = field(
|
| 60 |
+
default=4096,
|
| 61 |
+
metadata={"help": "Maximum sequence length for training"},
|
| 62 |
+
)
|
| 63 |
+
text_column: str = field(
|
| 64 |
+
default="text",
|
| 65 |
+
metadata={"help": "Name of the text column in the dataset"},
|
| 66 |
+
)
|
| 67 |
+
preprocessing_num_workers: int = field(
|
| 68 |
+
default=8,
|
| 69 |
+
metadata={"help": "Number of workers for data preprocessing"},
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def tokenize_and_group(dataset, tokenizer, data_args):
|
| 74 |
+
"""Tokenize texts and group into chunks of max_seq_length."""
|
| 75 |
+
|
| 76 |
+
column_names = dataset.column_names
|
| 77 |
+
text_column = data_args.text_column
|
| 78 |
+
if text_column not in column_names:
|
| 79 |
+
candidates = [c for c in column_names if "text" in c.lower()]
|
| 80 |
+
if candidates:
|
| 81 |
+
text_column = candidates[0]
|
| 82 |
+
else:
|
| 83 |
+
text_column = column_names[0]
|
| 84 |
+
logger.warning(f"Column '{data_args.text_column}' not found, using '{text_column}'")
|
| 85 |
+
|
| 86 |
+
def tokenize_function(examples):
|
| 87 |
+
return tokenizer(examples[text_column], add_special_tokens=False)
|
| 88 |
+
|
| 89 |
+
tokenized_dataset = dataset.map(
|
| 90 |
+
tokenize_function,
|
| 91 |
+
batched=True,
|
| 92 |
+
num_proc=data_args.preprocessing_num_workers,
|
| 93 |
+
remove_columns=column_names,
|
| 94 |
+
desc="Tokenizing",
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
block_size = data_args.max_seq_length
|
| 98 |
+
|
| 99 |
+
def group_texts(examples):
|
| 100 |
+
concatenated = {k: sum(examples[k], []) for k in examples.keys()}
|
| 101 |
+
total_length = len(concatenated["input_ids"])
|
| 102 |
+
total_length = (total_length // block_size) * block_size
|
| 103 |
+
|
| 104 |
+
result = {
|
| 105 |
+
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
|
| 106 |
+
for k, t in concatenated.items()
|
| 107 |
+
}
|
| 108 |
+
result["labels"] = result["input_ids"].copy()
|
| 109 |
+
return result
|
| 110 |
+
|
| 111 |
+
grouped_dataset = tokenized_dataset.map(
|
| 112 |
+
group_texts,
|
| 113 |
+
batched=True,
|
| 114 |
+
num_proc=data_args.preprocessing_num_workers,
|
| 115 |
+
desc="Grouping texts",
|
| 116 |
+
)
|
| 117 |
+
|
| 118 |
+
return grouped_dataset
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def main():
|
| 122 |
+
parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
|
| 123 |
+
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
| 124 |
+
|
| 125 |
+
logging.basicConfig(
|
| 126 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
| 127 |
+
datefmt="%Y-%m-%d %H:%M:%S",
|
| 128 |
+
level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
|
| 129 |
+
)
|
| 130 |
+
logger.info(f"Training args: {training_args}")
|
| 131 |
+
|
| 132 |
+
set_seed(training_args.seed)
|
| 133 |
+
|
| 134 |
+
dtype_map = {
|
| 135 |
+
"float16": torch.float16,
|
| 136 |
+
"bfloat16": torch.bfloat16,
|
| 137 |
+
"float32": torch.float32,
|
| 138 |
+
}
|
| 139 |
+
torch_dtype = dtype_map.get(model_args.torch_dtype, torch.bfloat16)
|
| 140 |
+
|
| 141 |
+
logger.info(f"Loading tokenizer from {model_args.model_name_or_path}")
|
| 142 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 143 |
+
model_args.model_name_or_path,
|
| 144 |
+
trust_remote_code=True,
|
| 145 |
+
)
|
| 146 |
+
if tokenizer.pad_token is None:
|
| 147 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 148 |
+
|
| 149 |
+
logger.info(f"Loading model from {model_args.model_name_or_path}")
|
| 150 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 151 |
+
model_args.model_name_or_path,
|
| 152 |
+
torch_dtype=torch_dtype,
|
| 153 |
+
trust_remote_code=True,
|
| 154 |
+
attn_implementation="sdpa",
|
| 155 |
+
)
|
| 156 |
+
model.config.use_cache = False
|
| 157 |
+
|
| 158 |
+
logger.info(f"Loading dataset from {data_args.data_path}")
|
| 159 |
+
if os.path.isfile(data_args.data_path):
|
| 160 |
+
raw_dataset = load_dataset("parquet", data_files=data_args.data_path, split="train")
|
| 161 |
+
elif os.path.isdir(data_args.data_path):
|
| 162 |
+
parquet_files = [
|
| 163 |
+
os.path.join(data_args.data_path, f)
|
| 164 |
+
for f in os.listdir(data_args.data_path)
|
| 165 |
+
if f.endswith(".parquet")
|
| 166 |
+
]
|
| 167 |
+
raw_dataset = load_dataset("parquet", data_files=parquet_files, split="train")
|
| 168 |
+
else:
|
| 169 |
+
raise ValueError(f"Data path not found: {data_args.data_path}")
|
| 170 |
+
|
| 171 |
+
logger.info(f"Dataset loaded: {len(raw_dataset)} samples, columns: {raw_dataset.column_names}")
|
| 172 |
+
|
| 173 |
+
train_dataset = tokenize_and_group(raw_dataset, tokenizer, data_args)
|
| 174 |
+
logger.info(f"Processed dataset: {len(train_dataset)} samples of length {data_args.max_seq_length}")
|
| 175 |
+
|
| 176 |
+
data_collator = DataCollatorForLanguageModeling(
|
| 177 |
+
tokenizer=tokenizer,
|
| 178 |
+
mlm=False,
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
trainer = Trainer(
|
| 182 |
+
model=model,
|
| 183 |
+
args=training_args,
|
| 184 |
+
train_dataset=train_dataset,
|
| 185 |
+
data_collator=data_collator,
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
logger.info("Starting training...")
|
| 189 |
+
train_result = trainer.train(
|
| 190 |
+
resume_from_checkpoint=training_args.resume_from_checkpoint
|
| 191 |
+
)
|
| 192 |
+
|
| 193 |
+
trainer.save_model()
|
| 194 |
+
trainer.save_state()
|
| 195 |
+
|
| 196 |
+
metrics = train_result.metrics
|
| 197 |
+
metrics["train_samples"] = len(train_dataset)
|
| 198 |
+
trainer.log_metrics("train", metrics)
|
| 199 |
+
trainer.save_metrics("train", metrics)
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
if __name__ == "__main__":
|
| 203 |
+
main()
|
example/train_sft.py
ADDED
|
@@ -0,0 +1,424 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Supervised fine-tuning script using DeepSpeed + HuggingFace Trainer.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import json
|
| 6 |
+
import logging
|
| 7 |
+
import os
|
| 8 |
+
from dataclasses import dataclass, field
|
| 9 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 10 |
+
|
| 11 |
+
import contextlib
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
from datasets import load_dataset
|
| 15 |
+
from transformers import (
|
| 16 |
+
AutoModelForCausalLM,
|
| 17 |
+
AutoTokenizer,
|
| 18 |
+
HfArgumentParser,
|
| 19 |
+
Trainer,
|
| 20 |
+
TrainingArguments,
|
| 21 |
+
set_seed,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
import deepspeed
|
| 25 |
+
_orig_no_sync = deepspeed.DeepSpeedEngine.no_sync
|
| 26 |
+
|
| 27 |
+
@contextlib.contextmanager
|
| 28 |
+
def _patched_no_sync(self):
|
| 29 |
+
try:
|
| 30 |
+
with _orig_no_sync(self):
|
| 31 |
+
yield
|
| 32 |
+
except AssertionError:
|
| 33 |
+
yield
|
| 34 |
+
|
| 35 |
+
deepspeed.DeepSpeedEngine.no_sync = _patched_no_sync
|
| 36 |
+
|
| 37 |
+
logger = logging.getLogger(__name__)
|
| 38 |
+
|
| 39 |
+
IGNORE_INDEX = -100
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@dataclass
|
| 43 |
+
class ModelArguments:
|
| 44 |
+
model_name_or_path: str = field(
|
| 45 |
+
metadata={"help": "Path to pretrained model or model identifier"}
|
| 46 |
+
)
|
| 47 |
+
torch_dtype: Optional[str] = field(
|
| 48 |
+
default="bfloat16",
|
| 49 |
+
metadata={"help": "torch dtype for model weights (float16, bfloat16, float32)"},
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
@dataclass
|
| 54 |
+
class DataArguments:
|
| 55 |
+
data_path: str = field(metadata={"help": "Path to SFT data file or directory"})
|
| 56 |
+
max_seq_length: int = field(
|
| 57 |
+
default=4096,
|
| 58 |
+
metadata={"help": "Maximum sequence length for training"},
|
| 59 |
+
)
|
| 60 |
+
prompt_column: Optional[str] = field(
|
| 61 |
+
default=None,
|
| 62 |
+
metadata={"help": "Prompt/instruction column name. Auto-detected if omitted."},
|
| 63 |
+
)
|
| 64 |
+
input_column: Optional[str] = field(
|
| 65 |
+
default=None,
|
| 66 |
+
metadata={"help": "Optional extra input/context column name"},
|
| 67 |
+
)
|
| 68 |
+
response_column: Optional[str] = field(
|
| 69 |
+
default=None,
|
| 70 |
+
metadata={"help": "Response/output column name. Auto-detected if omitted."},
|
| 71 |
+
)
|
| 72 |
+
messages_column: Optional[str] = field(
|
| 73 |
+
default=None,
|
| 74 |
+
metadata={"help": "Chat messages column name. Auto-detected if omitted."},
|
| 75 |
+
)
|
| 76 |
+
system_column: Optional[str] = field(
|
| 77 |
+
default=None,
|
| 78 |
+
metadata={"help": "Optional system prompt column name"},
|
| 79 |
+
)
|
| 80 |
+
train_on_prompt: bool = field(
|
| 81 |
+
default=False,
|
| 82 |
+
metadata={"help": "Whether to compute loss on prompt/user tokens"},
|
| 83 |
+
)
|
| 84 |
+
add_eos_token: bool = field(
|
| 85 |
+
default=True,
|
| 86 |
+
metadata={"help": "Append eos_token to plain prompt/response examples"},
|
| 87 |
+
)
|
| 88 |
+
preprocessing_num_workers: int = field(
|
| 89 |
+
default=8,
|
| 90 |
+
metadata={"help": "Number of workers for data preprocessing"},
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class SFTDataCollator:
|
| 95 |
+
def __init__(self, tokenizer, pad_to_multiple_of: Optional[int] = 8):
|
| 96 |
+
self.tokenizer = tokenizer
|
| 97 |
+
self.pad_to_multiple_of = pad_to_multiple_of
|
| 98 |
+
|
| 99 |
+
def __call__(self, features: List[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]:
|
| 100 |
+
max_length = max(len(feature["input_ids"]) for feature in features)
|
| 101 |
+
if self.pad_to_multiple_of:
|
| 102 |
+
multiple = self.pad_to_multiple_of
|
| 103 |
+
max_length = ((max_length + multiple - 1) // multiple) * multiple
|
| 104 |
+
|
| 105 |
+
input_ids = []
|
| 106 |
+
attention_mask = []
|
| 107 |
+
labels = []
|
| 108 |
+
pad_token_id = self.tokenizer.pad_token_id
|
| 109 |
+
|
| 110 |
+
for feature in features:
|
| 111 |
+
length = len(feature["input_ids"])
|
| 112 |
+
pad_length = max_length - length
|
| 113 |
+
input_ids.append(feature["input_ids"] + [pad_token_id] * pad_length)
|
| 114 |
+
attention_mask.append([1] * length + [0] * pad_length)
|
| 115 |
+
labels.append(feature["labels"] + [IGNORE_INDEX] * pad_length)
|
| 116 |
+
|
| 117 |
+
return {
|
| 118 |
+
"input_ids": torch.tensor(input_ids, dtype=torch.long),
|
| 119 |
+
"attention_mask": torch.tensor(attention_mask, dtype=torch.long),
|
| 120 |
+
"labels": torch.tensor(labels, dtype=torch.long),
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def load_sft_dataset(data_path: str):
|
| 125 |
+
if os.path.isfile(data_path):
|
| 126 |
+
extension = os.path.splitext(data_path)[1].lstrip(".").lower()
|
| 127 |
+
if extension == "jsonl":
|
| 128 |
+
extension = "json"
|
| 129 |
+
if extension not in {"parquet", "json", "csv", "txt"}:
|
| 130 |
+
raise ValueError(f"Unsupported data file extension: {extension}")
|
| 131 |
+
return load_dataset(extension, data_files=data_path, split="train")
|
| 132 |
+
|
| 133 |
+
if os.path.isdir(data_path):
|
| 134 |
+
data_files = []
|
| 135 |
+
extension = None
|
| 136 |
+
for name in os.listdir(data_path):
|
| 137 |
+
current_extension = os.path.splitext(name)[1].lstrip(".").lower()
|
| 138 |
+
if current_extension == "jsonl":
|
| 139 |
+
current_extension = "json"
|
| 140 |
+
if current_extension in {"parquet", "json", "csv", "txt"}:
|
| 141 |
+
extension = extension or current_extension
|
| 142 |
+
if current_extension == extension:
|
| 143 |
+
data_files.append(os.path.join(data_path, name))
|
| 144 |
+
if not data_files or extension is None:
|
| 145 |
+
raise ValueError(f"No supported data files found in: {data_path}")
|
| 146 |
+
return load_dataset(extension, data_files=sorted(data_files), split="train")
|
| 147 |
+
|
| 148 |
+
raise ValueError(f"Data path not found: {data_path}")
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def choose_column(
|
| 152 |
+
column_names: List[str], explicit: Optional[str], candidates: List[str]
|
| 153 |
+
) -> Optional[str]:
|
| 154 |
+
if explicit:
|
| 155 |
+
if explicit not in column_names:
|
| 156 |
+
raise ValueError(f"Column '{explicit}' not found. Available columns: {column_names}")
|
| 157 |
+
return explicit
|
| 158 |
+
for name in candidates:
|
| 159 |
+
if name in column_names:
|
| 160 |
+
return name
|
| 161 |
+
return None
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def parse_messages(value: Any) -> List[Dict[str, str]]:
|
| 165 |
+
if isinstance(value, str):
|
| 166 |
+
value = json.loads(value)
|
| 167 |
+
if not isinstance(value, list):
|
| 168 |
+
raise ValueError("messages/conversations column must be a list or JSON string")
|
| 169 |
+
|
| 170 |
+
messages = []
|
| 171 |
+
for item in value:
|
| 172 |
+
if not isinstance(item, dict):
|
| 173 |
+
raise ValueError("Each message must be a dict")
|
| 174 |
+
|
| 175 |
+
role = item.get("role", item.get("from"))
|
| 176 |
+
content = item.get("content", item.get("value"))
|
| 177 |
+
if role == "human":
|
| 178 |
+
role = "user"
|
| 179 |
+
elif role == "gpt":
|
| 180 |
+
role = "assistant"
|
| 181 |
+
|
| 182 |
+
if role is None or content is None:
|
| 183 |
+
raise ValueError("Each message must contain role/from and content/value")
|
| 184 |
+
messages.append({"role": str(role), "content": str(content)})
|
| 185 |
+
|
| 186 |
+
return messages
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def tokenize_text(tokenizer, text: str) -> List[int]:
|
| 190 |
+
return tokenizer(text, add_special_tokens=False)["input_ids"]
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def apply_chat_template(tokenizer, messages: List[Dict[str, str]], add_generation_prompt: bool) -> str:
|
| 194 |
+
if tokenizer.chat_template is None:
|
| 195 |
+
raise ValueError(
|
| 196 |
+
"The tokenizer has no chat_template. Use prompt/response columns or set a chat_template."
|
| 197 |
+
)
|
| 198 |
+
return tokenizer.apply_chat_template(
|
| 199 |
+
messages,
|
| 200 |
+
tokenize=False,
|
| 201 |
+
add_generation_prompt=add_generation_prompt,
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def encode_prompt_response(
|
| 206 |
+
example: Dict[str, Any],
|
| 207 |
+
tokenizer,
|
| 208 |
+
data_args: DataArguments,
|
| 209 |
+
prompt_column: str,
|
| 210 |
+
input_column: Optional[str],
|
| 211 |
+
response_column: str,
|
| 212 |
+
) -> Tuple[List[int], List[int]]:
|
| 213 |
+
prompt = str(example[prompt_column])
|
| 214 |
+
if input_column and example.get(input_column):
|
| 215 |
+
prompt = prompt + "\n" + str(example[input_column])
|
| 216 |
+
response = str(example[response_column])
|
| 217 |
+
|
| 218 |
+
messages = []
|
| 219 |
+
if data_args.system_column and example.get(data_args.system_column):
|
| 220 |
+
messages.append({"role": "system", "content": str(example[data_args.system_column])})
|
| 221 |
+
messages.append({"role": "user", "content": prompt})
|
| 222 |
+
messages.append({"role": "assistant", "content": response})
|
| 223 |
+
|
| 224 |
+
if tokenizer.chat_template is not None:
|
| 225 |
+
full_text = apply_chat_template(tokenizer, messages, add_generation_prompt=False)
|
| 226 |
+
prompt_text = apply_chat_template(tokenizer, messages[:-1], add_generation_prompt=True)
|
| 227 |
+
input_ids = tokenize_text(tokenizer, full_text)
|
| 228 |
+
prompt_length = len(tokenize_text(tokenizer, prompt_text))
|
| 229 |
+
else:
|
| 230 |
+
response_text = response
|
| 231 |
+
if data_args.add_eos_token and tokenizer.eos_token:
|
| 232 |
+
response_text += tokenizer.eos_token
|
| 233 |
+
full_text = prompt + "\n" + response_text
|
| 234 |
+
input_ids = tokenize_text(tokenizer, full_text)
|
| 235 |
+
prompt_length = len(tokenize_text(tokenizer, prompt + "\n"))
|
| 236 |
+
|
| 237 |
+
labels = input_ids.copy()
|
| 238 |
+
if not data_args.train_on_prompt:
|
| 239 |
+
labels[:prompt_length] = [IGNORE_INDEX] * min(prompt_length, len(labels))
|
| 240 |
+
return input_ids, labels
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def encode_messages(
|
| 244 |
+
example: Dict[str, Any],
|
| 245 |
+
tokenizer,
|
| 246 |
+
data_args: DataArguments,
|
| 247 |
+
messages_column: str,
|
| 248 |
+
) -> Tuple[List[int], List[int]]:
|
| 249 |
+
messages = parse_messages(example[messages_column])
|
| 250 |
+
|
| 251 |
+
if tokenizer.chat_template is not None:
|
| 252 |
+
full_text = apply_chat_template(tokenizer, messages, add_generation_prompt=False)
|
| 253 |
+
input_ids = tokenize_text(tokenizer, full_text)
|
| 254 |
+
labels = [IGNORE_INDEX] * len(input_ids)
|
| 255 |
+
|
| 256 |
+
if data_args.train_on_prompt:
|
| 257 |
+
labels = input_ids.copy()
|
| 258 |
+
else:
|
| 259 |
+
for index, message in enumerate(messages):
|
| 260 |
+
if message["role"] != "assistant":
|
| 261 |
+
continue
|
| 262 |
+
before_text = apply_chat_template(
|
| 263 |
+
tokenizer, messages[:index], add_generation_prompt=True
|
| 264 |
+
)
|
| 265 |
+
after_text = apply_chat_template(
|
| 266 |
+
tokenizer, messages[: index + 1], add_generation_prompt=False
|
| 267 |
+
)
|
| 268 |
+
start = len(tokenize_text(tokenizer, before_text))
|
| 269 |
+
end = len(tokenize_text(tokenizer, after_text))
|
| 270 |
+
labels[start:end] = input_ids[start:end]
|
| 271 |
+
else:
|
| 272 |
+
labels = []
|
| 273 |
+
input_ids = []
|
| 274 |
+
for message in messages:
|
| 275 |
+
part = f"{message['role']}: {message['content']}\n"
|
| 276 |
+
if data_args.add_eos_token and message["role"] == "assistant" and tokenizer.eos_token:
|
| 277 |
+
part += tokenizer.eos_token
|
| 278 |
+
part_ids = tokenize_text(tokenizer, part)
|
| 279 |
+
input_ids.extend(part_ids)
|
| 280 |
+
if data_args.train_on_prompt or message["role"] == "assistant":
|
| 281 |
+
labels.extend(part_ids)
|
| 282 |
+
else:
|
| 283 |
+
labels.extend([IGNORE_INDEX] * len(part_ids))
|
| 284 |
+
|
| 285 |
+
return input_ids, labels
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
def preprocess_sft_dataset(raw_dataset, tokenizer, data_args: DataArguments):
|
| 289 |
+
column_names = raw_dataset.column_names
|
| 290 |
+
messages_column = choose_column(
|
| 291 |
+
column_names, data_args.messages_column, ["messages", "conversations"]
|
| 292 |
+
)
|
| 293 |
+
prompt_column = choose_column(
|
| 294 |
+
column_names,
|
| 295 |
+
data_args.prompt_column,
|
| 296 |
+
["prompt", "instruction", "question"],
|
| 297 |
+
)
|
| 298 |
+
input_column = choose_column(
|
| 299 |
+
column_names,
|
| 300 |
+
data_args.input_column,
|
| 301 |
+
["input", "context"],
|
| 302 |
+
)
|
| 303 |
+
response_column = choose_column(
|
| 304 |
+
column_names,
|
| 305 |
+
data_args.response_column,
|
| 306 |
+
["response", "output", "answer", "chosen"],
|
| 307 |
+
)
|
| 308 |
+
|
| 309 |
+
if messages_column:
|
| 310 |
+
logger.info(f"Using chat messages column: {messages_column}")
|
| 311 |
+
elif prompt_column and response_column:
|
| 312 |
+
logger.info(f"Using prompt column '{prompt_column}' and response column '{response_column}'")
|
| 313 |
+
else:
|
| 314 |
+
raise ValueError(
|
| 315 |
+
"Cannot infer SFT data format. Provide either messages/conversations or "
|
| 316 |
+
"prompt/instruction plus response/output columns."
|
| 317 |
+
)
|
| 318 |
+
|
| 319 |
+
def encode_batch(examples):
|
| 320 |
+
batch_input_ids = []
|
| 321 |
+
batch_labels = []
|
| 322 |
+
batch_attention_mask = []
|
| 323 |
+
|
| 324 |
+
batch_size = len(next(iter(examples.values())))
|
| 325 |
+
for i in range(batch_size):
|
| 326 |
+
example = {name: values[i] for name, values in examples.items()}
|
| 327 |
+
if messages_column:
|
| 328 |
+
input_ids, labels = encode_messages(example, tokenizer, data_args, messages_column)
|
| 329 |
+
else:
|
| 330 |
+
input_ids, labels = encode_prompt_response(
|
| 331 |
+
example, tokenizer, data_args, prompt_column, input_column, response_column
|
| 332 |
+
)
|
| 333 |
+
|
| 334 |
+
input_ids = input_ids[: data_args.max_seq_length]
|
| 335 |
+
labels = labels[: data_args.max_seq_length]
|
| 336 |
+
if not input_ids or all(label == IGNORE_INDEX for label in labels):
|
| 337 |
+
continue
|
| 338 |
+
|
| 339 |
+
batch_input_ids.append(input_ids)
|
| 340 |
+
batch_labels.append(labels)
|
| 341 |
+
batch_attention_mask.append([1] * len(input_ids))
|
| 342 |
+
|
| 343 |
+
return {
|
| 344 |
+
"input_ids": batch_input_ids,
|
| 345 |
+
"attention_mask": batch_attention_mask,
|
| 346 |
+
"labels": batch_labels,
|
| 347 |
+
}
|
| 348 |
+
|
| 349 |
+
return raw_dataset.map(
|
| 350 |
+
encode_batch,
|
| 351 |
+
batched=True,
|
| 352 |
+
num_proc=data_args.preprocessing_num_workers,
|
| 353 |
+
remove_columns=column_names,
|
| 354 |
+
desc="Tokenizing SFT data",
|
| 355 |
+
)
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
def main():
|
| 359 |
+
parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
|
| 360 |
+
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
| 361 |
+
|
| 362 |
+
logging.basicConfig(
|
| 363 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
| 364 |
+
datefmt="%Y-%m-%d %H:%M:%S",
|
| 365 |
+
level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
|
| 366 |
+
)
|
| 367 |
+
logger.info(f"Training args: {training_args}")
|
| 368 |
+
|
| 369 |
+
set_seed(training_args.seed)
|
| 370 |
+
|
| 371 |
+
dtype_map = {
|
| 372 |
+
"float16": torch.float16,
|
| 373 |
+
"bfloat16": torch.bfloat16,
|
| 374 |
+
"float32": torch.float32,
|
| 375 |
+
}
|
| 376 |
+
torch_dtype = dtype_map.get(model_args.torch_dtype, torch.bfloat16)
|
| 377 |
+
|
| 378 |
+
logger.info(f"Loading tokenizer from {model_args.model_name_or_path}")
|
| 379 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
| 380 |
+
model_args.model_name_or_path,
|
| 381 |
+
trust_remote_code=True,
|
| 382 |
+
)
|
| 383 |
+
if tokenizer.pad_token is None:
|
| 384 |
+
tokenizer.pad_token = tokenizer.eos_token
|
| 385 |
+
|
| 386 |
+
logger.info(f"Loading model from {model_args.model_name_or_path}")
|
| 387 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 388 |
+
model_args.model_name_or_path,
|
| 389 |
+
torch_dtype=torch_dtype,
|
| 390 |
+
trust_remote_code=True,
|
| 391 |
+
attn_implementation="sdpa",
|
| 392 |
+
)
|
| 393 |
+
model.config.use_cache = False
|
| 394 |
+
|
| 395 |
+
logger.info(f"Loading SFT dataset from {data_args.data_path}")
|
| 396 |
+
raw_dataset = load_sft_dataset(data_args.data_path)
|
| 397 |
+
logger.info(f"Dataset loaded: {len(raw_dataset)} samples, columns: {raw_dataset.column_names}")
|
| 398 |
+
|
| 399 |
+
train_dataset = preprocess_sft_dataset(raw_dataset, tokenizer, data_args)
|
| 400 |
+
logger.info(f"Processed dataset: {len(train_dataset)} samples")
|
| 401 |
+
|
| 402 |
+
trainer = Trainer(
|
| 403 |
+
model=model,
|
| 404 |
+
args=training_args,
|
| 405 |
+
train_dataset=train_dataset,
|
| 406 |
+
data_collator=SFTDataCollator(tokenizer),
|
| 407 |
+
)
|
| 408 |
+
|
| 409 |
+
logger.info("Starting SFT training...")
|
| 410 |
+
train_result = trainer.train(
|
| 411 |
+
resume_from_checkpoint=training_args.resume_from_checkpoint
|
| 412 |
+
)
|
| 413 |
+
|
| 414 |
+
trainer.save_model()
|
| 415 |
+
trainer.save_state()
|
| 416 |
+
|
| 417 |
+
metrics = train_result.metrics
|
| 418 |
+
metrics["train_samples"] = len(train_dataset)
|
| 419 |
+
trainer.log_metrics("train", metrics)
|
| 420 |
+
trainer.save_metrics("train", metrics)
|
| 421 |
+
|
| 422 |
+
|
| 423 |
+
if __name__ == "__main__":
|
| 424 |
+
main()
|