guanwenyu1995 committed · Commit d7bc7ce · verified · 1 parent: 8a5049d

Add example/ folder with training scripts

example/README.md ADDED
@@ -0,0 +1,131 @@
+ # BitCPM4 Continued Pretraining Example
+
+ This project provides scripts for continued pretraining of **BitCPM4-CANN-1B-unquantized**.
+
+ ## Environment Setup
+
+ ### Docker Image
+
+ Use the following Huawei NPU image:
+
+ ```
+ swr.cn-south-1.myhuaweicloud.com/ascendhub/mindspeed-llm:openeuler22.03-mindspeed-llm-2.3.0-a3-arm
+ ```
+
+ Other Huawei NPU images may also work but have not been fully tested.
+
+ ### Install Dependencies
+
+ After entering the container, install the Python dependencies:
+
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ Dependency list:
+
+ | Package | Version |
+ | --- | --- |
+ | transformers | 4.46.3 |
+ | tokenizers | 0.20.3 |
+ | accelerate | 1.1.1 |
+ | deepspeed | 0.16.2 |
+ | datasets | 3.1.0 |
+ | safetensors | 0.4.5 |
+ | pyarrow | 17.0.0 |
+ | tensorboard | 2.18.0 |
+
+ ## Dataset
+
+ The test dataset is [C4-Pro](https://huggingface.co/datasets/gair-prox/c4-pro), downloaded and stored locally in Parquet format.
+
+ ## Usage
+
+ Modify the path configuration in `run.sh`:
+
+ ```bash
+ MODEL_PATH="/path/to/BitCPM4-CANN-1B-unquantized/"
+ DATA_PATH="/path/to/c4-pro/data/your_file.parquet"
+ ```
+
+ Then start training:
+
+ ```bash
+ bash run.sh
+ ```
+
+ By default, the script trains for 500 steps on 8 devices with DeepSpeed ZeRO-2 and bf16 precision.
+
+ ## Training Results Reference
+
+ The table below lists the logged loss for the first 102 steps, recorded every 2 steps (learning-rate warmup covers the first 50 steps):
+
+ | Step | Loss | Learning Rate | Epoch |
+ | --- | --- | --- | --- |
+ | 2 | 2.7920 | 1.60e-06 | 0.01 |
+ | 4 | 2.8012 | 3.20e-06 | 0.02 |
+ | 6 | 2.7984 | 4.80e-06 | 0.03 |
+ | 8 | 2.7839 | 6.40e-06 | 0.04 |
+ | 10 | 2.8084 | 8.00e-06 | 0.05 |
+ | 12 | 2.8064 | 9.60e-06 | 0.06 |
+ | 14 | 2.7994 | 1.12e-05 | 0.07 |
+ | 16 | 2.7463 | 1.28e-05 | 0.08 |
+ | 18 | 2.7580 | 1.44e-05 | 0.09 |
+ | 20 | 2.8007 | 1.60e-05 | 0.10 |
+ | 22 | 2.8916 | 1.76e-05 | 0.12 |
+ | 24 | 2.8144 | 1.92e-05 | 0.13 |
+ | 26 | 2.7723 | 2.08e-05 | 0.14 |
+ | 28 | 2.7556 | 2.24e-05 | 0.15 |
+ | 30 | 2.7414 | 2.40e-05 | 0.16 |
+ | 32 | 2.7469 | 2.56e-05 | 0.17 |
+ | 34 | 2.7428 | 2.72e-05 | 0.18 |
+ | 36 | 2.7392 | 2.88e-05 | 0.19 |
+ | 38 | 2.7132 | 3.04e-05 | 0.20 |
+ | 40 | 2.7008 | 3.20e-05 | 0.21 |
+ | 42 | 2.7547 | 3.36e-05 | 0.22 |
+ | 44 | 2.7151 | 3.52e-05 | 0.23 |
+ | 46 | 2.7119 | 3.68e-05 | 0.24 |
+ | 48 | 2.7029 | 3.84e-05 | 0.25 |
+ | 50 | 2.6803 | 4.00e-05 | 0.26 |
+ | 52 | 2.6980 | 4.00e-05 | 0.27 |
+ | 54 | 2.6923 | 4.00e-05 | 0.28 |
+ | 56 | 2.7068 | 4.00e-05 | 0.29 |
+ | 58 | 2.6965 | 4.00e-05 | 0.30 |
+ | 60 | 2.7179 | 3.99e-05 | 0.31 |
+ | 62 | 2.7119 | 3.99e-05 | 0.32 |
+ | 64 | 2.7178 | 3.99e-05 | 0.33 |
+ | 66 | 2.7069 | 3.99e-05 | 0.35 |
+ | 68 | 2.6870 | 3.98e-05 | 0.36 |
+ | 70 | 2.6775 | 3.98e-05 | 0.37 |
+ | 72 | 2.7038 | 3.98e-05 | 0.38 |
+ | 74 | 2.6924 | 3.97e-05 | 0.39 |
+ | 76 | 2.7061 | 3.97e-05 | 0.40 |
+ | 78 | 2.6929 | 3.96e-05 | 0.41 |
+ | 80 | 2.6787 | 3.96e-05 | 0.42 |
+ | 82 | 2.6749 | 3.95e-05 | 0.43 |
+ | 84 | 2.6909 | 3.94e-05 | 0.44 |
+ | 86 | 2.6893 | 3.94e-05 | 0.45 |
+ | 88 | 2.6788 | 3.93e-05 | 0.46 |
+ | 90 | 2.6831 | 3.92e-05 | 0.47 |
+ | 92 | 2.7039 | 3.91e-05 | 0.48 |
+ | 94 | 2.6619 | 3.91e-05 | 0.49 |
+ | 96 | 2.6903 | 3.90e-05 | 0.50 |
+ | 98 | 2.6993 | 3.89e-05 | 0.51 |
+ | 100 | 2.6891 | 3.88e-05 | 0.52 |
+ | 102 | 2.6739 | 3.87e-05 | 0.53 |
+
+ > **Note:** BitCPM was trained on its own dataset and data mixture, so the loss is expected to keep decreasing during continued pretraining on open-source datasets.
+
+ As shown in the table, the loss gradually decreases from ~2.79 to ~2.67, which indicates that training is stable and the model is learning normally.
+
+ ## File Description
+
+ | File | Description |
+ | --- | --- |
+ | `train.py` | Training script based on HuggingFace Trainer + DeepSpeed |
+ | `run.sh` | Launch script with training hyperparameter configuration |
+ | `train_sft.py` | Supervised fine-tuning script based on HuggingFace Trainer + DeepSpeed |
+ | `run_sft.sh` | Launch script for SFT with hyperparameter configuration |
+ | `ds_config.json` | DeepSpeed ZeRO-3 configuration (optimizer states offloaded to CPU; used by `run_sft.sh`) |
+ | `ds_config_z2.json` | DeepSpeed ZeRO-2 configuration (used by default in `run.sh`) |
+ | `requirements.txt` | Python dependency list |
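The Learning Rate column in the README table is consistent with linear warmup followed by cosine decay, which is what `--lr_scheduler_type cosine`, `--warmup_ratio 0.1`, and `--max_steps 500` in `run.sh` select. The following is a minimal sketch of that schedule, assuming the textbook warmup-plus-cosine formula; `lr_at` is a hypothetical helper written for this example, not the Trainer's own code, and it agrees with the logged values only up to rounding.

```python
import math

PEAK_LR = 4e-5       # --learning_rate in run.sh
MAX_STEPS = 500      # --max_steps in run.sh
WARMUP_STEPS = 50    # 10% of MAX_STEPS (--warmup_ratio 0.1)


def lr_at(step: int) -> float:
    """Learning rate after `step` optimizer steps (hypothetical helper)."""
    if step < WARMUP_STEPS:
        # Linear warmup from 0 to the peak learning rate.
        return PEAK_LR * step / WARMUP_STEPS
    # Cosine decay from the peak down to 0 at MAX_STEPS.
    progress = (step - WARMUP_STEPS) / (MAX_STEPS - WARMUP_STEPS)
    return PEAK_LR * 0.5 * (1.0 + math.cos(math.pi * progress))


for step in (2, 50, 100):
    print(f"step {step:3d}: lr = {lr_at(step):.2e}")
```

For steps 2, 50, and 100 this gives 1.60e-06, 4.00e-05, and 3.88e-05, which matches the corresponding rows of the table.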
example/ds_config.json ADDED
@@ -0,0 +1,29 @@
+ {
+   "bf16": {
+     "enabled": true
+   },
+   "zero_optimization": {
+     "stage": 3,
+     "offload_optimizer": {
+       "device": "cpu",
+       "pin_memory": true
+     },
+     "offload_param": {
+       "device": "none"
+     },
+     "overlap_comm": true,
+     "contiguous_gradients": true,
+     "sub_group_size": 1e9,
+     "reduce_bucket_size": 2e8,
+     "stage3_prefetch_bucket_size": 2e8,
+     "stage3_param_persistence_threshold": 1e5,
+     "stage3_max_live_parameters": 2e9,
+     "stage3_max_reuse_distance": 2e9,
+     "stage3_gather_16bit_weights_on_model_save": true
+   },
+   "gradient_accumulation_steps": "auto",
+   "gradient_clipping": "auto",
+   "train_batch_size": "auto",
+   "train_micro_batch_size_per_gpu": "auto",
+   "wall_clock_breakdown": false
+ }
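When choosing between this ZeRO-3 config and `ds_config_z2.json`, it can help to estimate how much device memory the model states need at each stage. The sketch below uses the accounting from the ZeRO paper for mixed-precision Adam (16 bytes per parameter in total). It assumes exactly 1e9 parameters for the 1B model, ignores activations and temporary buffers, and does not model the CPU offload of optimizer states configured above; `model_state_gb` is a hypothetical helper, and DeepSpeed's actual bf16 memory use may differ.

```python
def model_state_gb(num_params: float, num_devices: int, stage: int) -> float:
    """Approximate per-device model-state memory in GB (10**9 bytes)."""
    weights = 2 * num_params      # 16-bit weights
    grads = 2 * num_params        # 16-bit gradients
    optim = 12 * num_params       # fp32 master weights + two Adam moments
    if stage >= 1:
        optim /= num_devices      # ZeRO-1: partition optimizer states
    if stage >= 2:
        grads /= num_devices      # ZeRO-2: also partition gradients
    if stage >= 3:
        weights /= num_devices    # ZeRO-3: also partition parameters
    return (weights + grads + optim) / 1e9


for stage in (0, 2, 3):
    print(f"ZeRO-{stage}: {model_state_gb(1e9, 8, stage):.2f} GB per device")
```

Under these assumptions a 1B-parameter model on 8 devices needs about 3.75 GB per device for model states with ZeRO-2 and about 2 GB with ZeRO-3, so the difference matters more for larger models or longer sequences.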
example/ds_config_z2.json ADDED
@@ -0,0 +1,22 @@
+ {
+   "bf16": {
+     "enabled": true
+   },
+   "zero_optimization": {
+     "stage": 2,
+     "offload_optimizer": {
+       "device": "none"
+     },
+     "allgather_partitions": true,
+     "allgather_bucket_size": 2e8,
+     "overlap_comm": true,
+     "reduce_scatter": true,
+     "reduce_bucket_size": 2e8,
+     "contiguous_gradients": true
+   },
+   "gradient_accumulation_steps": "auto",
+   "gradient_clipping": "auto",
+   "train_batch_size": "auto",
+   "train_micro_batch_size_per_gpu": "auto",
+   "wall_clock_breakdown": false
+ }
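The `"auto"` entries in both DeepSpeed configs are filled in by the HuggingFace Trainer from its own arguments. The sketch below illustrates that mapping for the values set in `run.sh`. `resolve_auto` is a hypothetical helper written for this example, not the Trainer's implementation, and `max_grad_norm=1.0` is assumed to be the `TrainingArguments` default because `run.sh` does not set it.

```python
import copy


def resolve_auto(ds_config, per_device_batch, grad_accum, world_size, max_grad_norm):
    """Return a copy of ds_config with its "auto" batch/clipping fields filled in."""
    values = {
        "train_micro_batch_size_per_gpu": per_device_batch,
        "gradient_accumulation_steps": grad_accum,
        # Global batch = micro batch x accumulation steps x number of devices.
        "train_batch_size": per_device_batch * grad_accum * world_size,
        "gradient_clipping": max_grad_norm,
    }
    resolved = copy.deepcopy(ds_config)
    for key, value in values.items():
        if resolved.get(key) == "auto":
            resolved[key] = value
    return resolved


config = {
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
}
resolved = resolve_auto(
    config, per_device_batch=8, grad_accum=8, world_size=8, max_grad_norm=1.0
)
print(resolved["train_batch_size"])
```

With the `run.sh` settings the global batch size resolves to 8 x 8 x 8 = 512 sequences per optimizer step.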
example/requirements.txt ADDED
@@ -0,0 +1,8 @@
+ transformers==4.46.3
+ tokenizers==0.20.3
+ accelerate==1.1.1
+ deepspeed==0.16.2
+ datasets==3.1.0
+ safetensors==0.4.5
+ pyarrow==17.0.0
+ tensorboard==2.18.0
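Because the container image ships its own Python packages, it can be useful to check which installed versions differ from these pins before training. A small sketch using only the standard library follows; `parse_pins` and `find_mismatches` are hypothetical helper names, and the pin list is copied from `requirements.txt` above.

```python
from importlib import metadata

PINNED = """\
transformers==4.46.3
tokenizers==0.20.3
accelerate==1.1.1
deepspeed==0.16.2
datasets==3.1.0
safetensors==0.4.5
pyarrow==17.0.0
tensorboard==2.18.0
"""


def parse_pins(text):
    """Parse 'name==version' lines into a dict, skipping blanks and comments."""
    pins = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        name, _, version = line.partition("==")
        pins[name] = version
    return pins


def find_mismatches(pins):
    """Return {package: (pinned, installed or None)} for every mismatch."""
    mismatches = {}
    for name, pinned in pins.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None  # package is not installed at all
        if installed != pinned:
            mismatches[name] = (pinned, installed)
    return mismatches


if __name__ == "__main__":
    for name, (pinned, installed) in find_mismatches(parse_pins(PINNED)).items():
        print(f"{name}: pinned {pinned}, installed {installed}")
```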
example/run.sh ADDED
@@ -0,0 +1,37 @@
+ #!/bin/bash
+
+ MODEL_PATH="/model/BitCPM/BitCPM4-CANN-1B-unquantized/"
+ DATA_PATH="/dataset/c4-pro/data/000_1_7.parquet"
+ OUTPUT_DIR="./output"
+ DS_CONFIG="./ds_config_z2.json"
+
+ NUM_GPUS=8
+ BATCH_SIZE_PER_GPU=8
+ GRAD_ACCUM_STEPS=8
+ MAX_SEQ_LENGTH=1024
+
+ export ASCEND_RT_VISIBLE_DEVICES=8,9,10,11,12,13,14,15
+
+ torchrun --nproc_per_node=$NUM_GPUS train.py \
+     --model_name_or_path $MODEL_PATH \
+     --data_path $DATA_PATH \
+     --max_seq_length $MAX_SEQ_LENGTH \
+     --output_dir $OUTPUT_DIR \
+     --per_device_train_batch_size $BATCH_SIZE_PER_GPU \
+     --gradient_accumulation_steps $GRAD_ACCUM_STEPS \
+     --max_steps 500 \
+     --learning_rate 4e-5 \
+     --lr_scheduler_type cosine \
+     --warmup_ratio 0.1 \
+     --weight_decay 1e-2 \
+     --logging_steps 2 \
+     --save_steps 500 \
+     --save_total_limit 3 \
+     --bf16 \
+     --deepspeed $DS_CONFIG \
+     --gradient_checkpointing \
+     --seed 42 \
+     --dataloader_num_workers 4 \
+     --report_to tensorboard \
+     --logging_dir /data/tensorboard/ \
+     --gradient_checkpointing_kwargs '{"use_reentrant": false}'
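The batch settings in `run.sh` determine how many tokens each optimizer step and the whole run consume. The arithmetic is sketched below. The epoch estimate relies on the README log, which reports epoch 0.52 at step 100; since that value is rounded to two decimals, the derived dataset size is only a rough figure and not something the commit states.

```python
NUM_DEVICES = 8        # NUM_GPUS in run.sh
BATCH_PER_DEVICE = 8   # BATCH_SIZE_PER_GPU
GRAD_ACCUM = 8         # GRAD_ACCUM_STEPS
SEQ_LEN = 1024         # MAX_SEQ_LENGTH
MAX_STEPS = 500        # --max_steps

sequences_per_step = NUM_DEVICES * BATCH_PER_DEVICE * GRAD_ACCUM
tokens_per_step = sequences_per_step * SEQ_LEN
total_tokens = tokens_per_step * MAX_STEPS

# Rough estimate from the README log: epoch 0.52 is reached at step 100.
steps_per_epoch = 100 / 0.52
approx_packed_sequences = steps_per_epoch * sequences_per_step
approx_epochs = MAX_STEPS / steps_per_epoch

print(f"sequences per step:  {sequences_per_step}")
print(f"tokens per step:     {tokens_per_step:,}")
print(f"tokens in 500 steps: {total_tokens:,}")
print(f"approx. epochs:      {approx_epochs:.1f}")
```

Each step therefore processes 512 sequences (524,288 tokens), and the default 500 steps amount to roughly 2.6 passes over the single Parquet file used here.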
example/run_sft.sh ADDED
@@ -0,0 +1,38 @@
+ #!/bin/bash
+
+ MODEL_PATH="/model/BitCPM/BitCPM4-CANN-3B-unquantized/"
+ DATA_PATH=""
+ OUTPUT_DIR="./output_sft"
+ DS_CONFIG="./ds_config.json"
+
+ NUM_GPUS=8
+ BATCH_SIZE_PER_GPU=2
+ GRAD_ACCUM_STEPS=1
+ MAX_SEQ_LENGTH=4096
+
+ export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+
+ torchrun --nproc_per_node=$NUM_GPUS train_sft.py \
+     --model_name_or_path $MODEL_PATH \
+     --data_path $DATA_PATH \
+     --max_seq_length $MAX_SEQ_LENGTH \
+     --output_dir $OUTPUT_DIR \
+     --per_device_train_batch_size $BATCH_SIZE_PER_GPU \
+     --gradient_accumulation_steps $GRAD_ACCUM_STEPS \
+     --num_train_epochs 3 \
+     --learning_rate 2e-5 \
+     --lr_scheduler_type cosine \
+     --warmup_ratio 0.03 \
+     --weight_decay 0.0 \
+     --logging_steps 2 \
+     --save_steps 500 \
+     --save_total_limit 3 \
+     --bf16 \
+     --deepspeed $DS_CONFIG \
+     --gradient_checkpointing \
+     --seed 42 \
+     --dataloader_num_workers 4 \
+     --report_to tensorboard \
+     --logging_dir /data/tensorboard/sft \
+     --train_on_prompt false \
+     --gradient_checkpointing_kwargs '{"use_reentrant": false}'
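The `--train_on_prompt false` flag makes `train_sft.py` exclude prompt tokens from the loss by setting their labels to `-100`, the index that PyTorch's cross-entropy loss ignores by default. The toy sketch below shows that prefix-masking idea. It assumes a hypothetical whitespace tokenizer (`toy_tokenize`) in place of the real one, and `mask_prompt` is an illustrative name, not a function from the script.

```python
from typing import List, Tuple

IGNORE_INDEX = -100  # same value as IGNORE_INDEX in train_sft.py


def toy_tokenize(text: str) -> List[int]:
    """Hypothetical stand-in tokenizer: one deterministic id per word."""
    return [sum(word.encode("utf-8")) % 1000 for word in text.split()]


def mask_prompt(
    prompt: str, response: str, train_on_prompt: bool = False
) -> Tuple[List[int], List[int]]:
    """Tokenize prompt + response and mask the prompt part of the labels."""
    input_ids = toy_tokenize(prompt + " " + response)
    prompt_length = len(toy_tokenize(prompt))
    labels = list(input_ids)
    if not train_on_prompt:
        # Only the response tokens contribute to the loss.
        labels[:prompt_length] = [IGNORE_INDEX] * prompt_length
    return input_ids, labels


input_ids, labels = mask_prompt("What is 2 + 2 ?", "The answer is 4 .")
print(labels)
```

With a real subword tokenizer, tokenizing the prompt on its own does not always produce an exact prefix of the full sequence, so the mask boundary can be off by a token; it is worth spot-checking a few encoded examples.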
example/train.py ADDED
@@ -0,0 +1,203 @@
+ """
+ Continual pretraining script for BitCPM4 models using DeepSpeed + HuggingFace Trainer.
+ """
+
+ import os
+ import json
+ import math
+ import logging
+ from dataclasses import dataclass, field
+ from typing import Optional
+
+ import contextlib
+
+ import torch
+ from datasets import load_dataset
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     AutoConfig,
+     Trainer,
+     TrainingArguments,
+     HfArgumentParser,
+     DataCollatorForLanguageModeling,
+     set_seed,
+ )
+
+ import deepspeed
+ _orig_no_sync = deepspeed.DeepSpeedEngine.no_sync
+
+ @contextlib.contextmanager
+ def _patched_no_sync(self):
+     try:
+         with _orig_no_sync(self):
+             yield
+     except AssertionError:
+         yield
+
+ deepspeed.DeepSpeedEngine.no_sync = _patched_no_sync
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class ModelArguments:
+     model_name_or_path: str = field(
+         metadata={"help": "Path to pretrained model or model identifier"}
+     )
+     torch_dtype: Optional[str] = field(
+         default="bfloat16",
+         metadata={"help": "torch dtype for model weights (float16, bfloat16, float32)"},
+     )
+
+
+ @dataclass
+ class DataArguments:
+     data_path: str = field(
+         metadata={"help": "Path to training data (parquet file or directory)"}
+     )
+     max_seq_length: int = field(
+         default=4096,
+         metadata={"help": "Maximum sequence length for training"},
+     )
+     text_column: str = field(
+         default="text",
+         metadata={"help": "Name of the text column in the dataset"},
+     )
+     preprocessing_num_workers: int = field(
+         default=8,
+         metadata={"help": "Number of workers for data preprocessing"},
+     )
+
+
+ def tokenize_and_group(dataset, tokenizer, data_args):
+     """Tokenize texts and group into chunks of max_seq_length."""
+
+     column_names = dataset.column_names
+     text_column = data_args.text_column
+     if text_column not in column_names:
+         candidates = [c for c in column_names if "text" in c.lower()]
+         if candidates:
+             text_column = candidates[0]
+         else:
+             text_column = column_names[0]
+         logger.warning(f"Column '{data_args.text_column}' not found, using '{text_column}'")
+
+     def tokenize_function(examples):
+         return tokenizer(examples[text_column], add_special_tokens=False)
+
+     tokenized_dataset = dataset.map(
+         tokenize_function,
+         batched=True,
+         num_proc=data_args.preprocessing_num_workers,
+         remove_columns=column_names,
+         desc="Tokenizing",
+     )
+
+     block_size = data_args.max_seq_length
+
+     def group_texts(examples):
+         concatenated = {k: sum(examples[k], []) for k in examples.keys()}
+         total_length = len(concatenated["input_ids"])
+         total_length = (total_length // block_size) * block_size
+
+         result = {
+             k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
+             for k, t in concatenated.items()
+         }
+         result["labels"] = result["input_ids"].copy()
+         return result
+
+     grouped_dataset = tokenized_dataset.map(
+         group_texts,
+         batched=True,
+         num_proc=data_args.preprocessing_num_workers,
+         desc="Grouping texts",
+     )
+
+     return grouped_dataset
+
+
+ def main():
+     parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
+     model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+
+     logging.basicConfig(
+         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+         datefmt="%Y-%m-%d %H:%M:%S",
+         level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
+     )
+     logger.info(f"Training args: {training_args}")
+
+     set_seed(training_args.seed)
+
+     dtype_map = {
+         "float16": torch.float16,
+         "bfloat16": torch.bfloat16,
+         "float32": torch.float32,
+     }
+     torch_dtype = dtype_map.get(model_args.torch_dtype, torch.bfloat16)
+
+     logger.info(f"Loading tokenizer from {model_args.model_name_or_path}")
+     tokenizer = AutoTokenizer.from_pretrained(
+         model_args.model_name_or_path,
+         trust_remote_code=True,
+     )
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+
+     logger.info(f"Loading model from {model_args.model_name_or_path}")
+     model = AutoModelForCausalLM.from_pretrained(
+         model_args.model_name_or_path,
+         torch_dtype=torch_dtype,
+         trust_remote_code=True,
+         attn_implementation="sdpa",
+     )
+     model.config.use_cache = False
+
+     logger.info(f"Loading dataset from {data_args.data_path}")
+     if os.path.isfile(data_args.data_path):
+         raw_dataset = load_dataset("parquet", data_files=data_args.data_path, split="train")
+     elif os.path.isdir(data_args.data_path):
+         parquet_files = [
+             os.path.join(data_args.data_path, f)
+             for f in os.listdir(data_args.data_path)
+             if f.endswith(".parquet")
+         ]
+         raw_dataset = load_dataset("parquet", data_files=parquet_files, split="train")
+     else:
+         raise ValueError(f"Data path not found: {data_args.data_path}")
+
+     logger.info(f"Dataset loaded: {len(raw_dataset)} samples, columns: {raw_dataset.column_names}")
+
+     train_dataset = tokenize_and_group(raw_dataset, tokenizer, data_args)
+     logger.info(f"Processed dataset: {len(train_dataset)} samples of length {data_args.max_seq_length}")
+
+     data_collator = DataCollatorForLanguageModeling(
+         tokenizer=tokenizer,
+         mlm=False,
+     )
+
+     trainer = Trainer(
+         model=model,
+         args=training_args,
+         train_dataset=train_dataset,
+         data_collator=data_collator,
+     )
+
+     logger.info("Starting training...")
+     train_result = trainer.train(
+         resume_from_checkpoint=training_args.resume_from_checkpoint
+     )
+
+     trainer.save_model()
+     trainer.save_state()
+
+     metrics = train_result.metrics
+     metrics["train_samples"] = len(train_dataset)
+     trainer.log_metrics("train", metrics)
+     trainer.save_metrics("train", metrics)
+
+
+ if __name__ == "__main__":
+     main()
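The `group_texts` step in `train.py` packs tokenized documents into fixed-length blocks by concatenating them and dropping the incomplete tail. The sketch below shows the same idea on toy data, using `itertools.chain` instead of `sum(lists, [])` to avoid repeated list copying. `pack_sequences` is a hypothetical name, and the optional `eos_token_id` separator is an addition for illustration: the script tokenizes with `add_special_tokens=False` and does not insert a separator between documents itself.

```python
from itertools import chain
from typing import Dict, List, Optional


def pack_sequences(
    token_lists: List[List[int]],
    block_size: int,
    eos_token_id: Optional[int] = None,
) -> Dict[str, List[List[int]]]:
    """Concatenate token lists and split them into blocks of block_size tokens."""
    if eos_token_id is not None:
        # Optional: mark document boundaries with an end-of-sequence token.
        token_lists = [tokens + [eos_token_id] for tokens in token_lists]
    flat = list(chain.from_iterable(token_lists))
    usable = (len(flat) // block_size) * block_size  # drop the incomplete tail
    blocks = [flat[i : i + block_size] for i in range(0, usable, block_size)]
    # For causal LM training the labels are a copy of the inputs.
    return {"input_ids": blocks, "labels": [list(block) for block in blocks]}


docs = [[1, 2, 3], [4, 5], [6, 7, 8, 9]]
print(pack_sequences(docs, block_size=4)["input_ids"])
print(pack_sequences(docs, block_size=4, eos_token_id=0)["input_ids"])
```

Whether to add a separator between documents is a design choice; without one, a block can span two unrelated documents with no boundary marker.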
example/train_sft.py ADDED
@@ -0,0 +1,424 @@
+ """
+ Supervised fine-tuning script using DeepSpeed + HuggingFace Trainer.
+ """
+
+ import json
+ import logging
+ import os
+ from dataclasses import dataclass, field
+ from typing import Any, Dict, List, Optional, Tuple
+
+ import contextlib
+
+ import torch
+ from datasets import load_dataset
+ from transformers import (
+     AutoModelForCausalLM,
+     AutoTokenizer,
+     HfArgumentParser,
+     Trainer,
+     TrainingArguments,
+     set_seed,
+ )
+
+ import deepspeed
+ _orig_no_sync = deepspeed.DeepSpeedEngine.no_sync
+
+ @contextlib.contextmanager
+ def _patched_no_sync(self):
+     try:
+         with _orig_no_sync(self):
+             yield
+     except AssertionError:
+         yield
+
+ deepspeed.DeepSpeedEngine.no_sync = _patched_no_sync
+
+ logger = logging.getLogger(__name__)
+
+ IGNORE_INDEX = -100
+
+
+ @dataclass
+ class ModelArguments:
+     model_name_or_path: str = field(
+         metadata={"help": "Path to pretrained model or model identifier"}
+     )
+     torch_dtype: Optional[str] = field(
+         default="bfloat16",
+         metadata={"help": "torch dtype for model weights (float16, bfloat16, float32)"},
+     )
+
+
+ @dataclass
+ class DataArguments:
+     data_path: str = field(metadata={"help": "Path to SFT data file or directory"})
+     max_seq_length: int = field(
+         default=4096,
+         metadata={"help": "Maximum sequence length for training"},
+     )
+     prompt_column: Optional[str] = field(
+         default=None,
+         metadata={"help": "Prompt/instruction column name. Auto-detected if omitted."},
+     )
+     input_column: Optional[str] = field(
+         default=None,
+         metadata={"help": "Optional extra input/context column name"},
+     )
+     response_column: Optional[str] = field(
+         default=None,
+         metadata={"help": "Response/output column name. Auto-detected if omitted."},
+     )
+     messages_column: Optional[str] = field(
+         default=None,
+         metadata={"help": "Chat messages column name. Auto-detected if omitted."},
+     )
+     system_column: Optional[str] = field(
+         default=None,
+         metadata={"help": "Optional system prompt column name"},
+     )
+     train_on_prompt: bool = field(
+         default=False,
+         metadata={"help": "Whether to compute loss on prompt/user tokens"},
+     )
+     add_eos_token: bool = field(
+         default=True,
+         metadata={"help": "Append eos_token to plain prompt/response examples"},
+     )
+     preprocessing_num_workers: int = field(
+         default=8,
+         metadata={"help": "Number of workers for data preprocessing"},
+     )
+
+
+ class SFTDataCollator:
+     def __init__(self, tokenizer, pad_to_multiple_of: Optional[int] = 8):
+         self.tokenizer = tokenizer
+         self.pad_to_multiple_of = pad_to_multiple_of
+
+     def __call__(self, features: List[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]:
+         max_length = max(len(feature["input_ids"]) for feature in features)
+         if self.pad_to_multiple_of:
+             multiple = self.pad_to_multiple_of
+             max_length = ((max_length + multiple - 1) // multiple) * multiple
+
+         input_ids = []
+         attention_mask = []
+         labels = []
+         pad_token_id = self.tokenizer.pad_token_id
+
+         for feature in features:
+             length = len(feature["input_ids"])
+             pad_length = max_length - length
+             input_ids.append(feature["input_ids"] + [pad_token_id] * pad_length)
+             attention_mask.append([1] * length + [0] * pad_length)
+             labels.append(feature["labels"] + [IGNORE_INDEX] * pad_length)
+
+         return {
+             "input_ids": torch.tensor(input_ids, dtype=torch.long),
+             "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
+             "labels": torch.tensor(labels, dtype=torch.long),
+         }
+
+
+ def load_sft_dataset(data_path: str):
+     if os.path.isfile(data_path):
+         extension = os.path.splitext(data_path)[1].lstrip(".").lower()
+         # Map file extensions to `datasets` builder names ("jsonl" -> "json", "txt" -> "text").
+         extension = {"jsonl": "json", "txt": "text"}.get(extension, extension)
+         if extension not in {"parquet", "json", "csv", "text"}:
+             raise ValueError(f"Unsupported data file extension: {extension}")
+         return load_dataset(extension, data_files=data_path, split="train")
+
+     if os.path.isdir(data_path):
+         data_files = []
+         extension = None
+         for name in os.listdir(data_path):
+             current_extension = os.path.splitext(name)[1].lstrip(".").lower()
+             # Same extension-to-builder mapping as for single files.
+             current_extension = {"jsonl": "json", "txt": "text"}.get(current_extension, current_extension)
+             if current_extension in {"parquet", "json", "csv", "text"}:
+                 extension = extension or current_extension
+                 if current_extension == extension:
+                     data_files.append(os.path.join(data_path, name))
+         if not data_files or extension is None:
+             raise ValueError(f"No supported data files found in: {data_path}")
+         return load_dataset(extension, data_files=sorted(data_files), split="train")
+
+     raise ValueError(f"Data path not found: {data_path}")
+
+
+ def choose_column(
+     column_names: List[str], explicit: Optional[str], candidates: List[str]
+ ) -> Optional[str]:
+     if explicit:
+         if explicit not in column_names:
+             raise ValueError(f"Column '{explicit}' not found. Available columns: {column_names}")
+         return explicit
+     for name in candidates:
+         if name in column_names:
+             return name
+     return None
+
+
+ def parse_messages(value: Any) -> List[Dict[str, str]]:
+     if isinstance(value, str):
+         value = json.loads(value)
+     if not isinstance(value, list):
+         raise ValueError("messages/conversations column must be a list or JSON string")
+
+     messages = []
+     for item in value:
+         if not isinstance(item, dict):
+             raise ValueError("Each message must be a dict")
+
+         role = item.get("role", item.get("from"))
+         content = item.get("content", item.get("value"))
+         if role == "human":
+             role = "user"
+         elif role == "gpt":
+             role = "assistant"
+
+         if role is None or content is None:
+             raise ValueError("Each message must contain role/from and content/value")
+         messages.append({"role": str(role), "content": str(content)})
+
+     return messages
+
+
+ def tokenize_text(tokenizer, text: str) -> List[int]:
+     return tokenizer(text, add_special_tokens=False)["input_ids"]
+
+
+ def apply_chat_template(tokenizer, messages: List[Dict[str, str]], add_generation_prompt: bool) -> str:
+     if tokenizer.chat_template is None:
+         raise ValueError(
+             "The tokenizer has no chat_template. Use prompt/response columns or set a chat_template."
+         )
+     return tokenizer.apply_chat_template(
+         messages,
+         tokenize=False,
+         add_generation_prompt=add_generation_prompt,
+     )
+
+
+ def encode_prompt_response(
+     example: Dict[str, Any],
+     tokenizer,
+     data_args: DataArguments,
+     prompt_column: str,
+     input_column: Optional[str],
+     response_column: str,
+ ) -> Tuple[List[int], List[int]]:
+     prompt = str(example[prompt_column])
+     if input_column and example.get(input_column):
+         prompt = prompt + "\n" + str(example[input_column])
+     response = str(example[response_column])
+
+     messages = []
+     if data_args.system_column and example.get(data_args.system_column):
+         messages.append({"role": "system", "content": str(example[data_args.system_column])})
+     messages.append({"role": "user", "content": prompt})
+     messages.append({"role": "assistant", "content": response})
+
+     if tokenizer.chat_template is not None:
+         full_text = apply_chat_template(tokenizer, messages, add_generation_prompt=False)
+         prompt_text = apply_chat_template(tokenizer, messages[:-1], add_generation_prompt=True)
+         input_ids = tokenize_text(tokenizer, full_text)
+         prompt_length = len(tokenize_text(tokenizer, prompt_text))
+     else:
+         response_text = response
+         if data_args.add_eos_token and tokenizer.eos_token:
+             response_text += tokenizer.eos_token
+         full_text = prompt + "\n" + response_text
+         input_ids = tokenize_text(tokenizer, full_text)
+         prompt_length = len(tokenize_text(tokenizer, prompt + "\n"))
+
+     labels = input_ids.copy()
+     if not data_args.train_on_prompt:
+         labels[:prompt_length] = [IGNORE_INDEX] * min(prompt_length, len(labels))
+     return input_ids, labels
+
+
+ def encode_messages(
+     example: Dict[str, Any],
+     tokenizer,
+     data_args: DataArguments,
+     messages_column: str,
+ ) -> Tuple[List[int], List[int]]:
+     messages = parse_messages(example[messages_column])
+
+     if tokenizer.chat_template is not None:
+         full_text = apply_chat_template(tokenizer, messages, add_generation_prompt=False)
+         input_ids = tokenize_text(tokenizer, full_text)
+         labels = [IGNORE_INDEX] * len(input_ids)
+
+         if data_args.train_on_prompt:
+             labels = input_ids.copy()
+         else:
+             for index, message in enumerate(messages):
+                 if message["role"] != "assistant":
+                     continue
+                 before_text = apply_chat_template(
+                     tokenizer, messages[:index], add_generation_prompt=True
+                 )
+                 after_text = apply_chat_template(
+                     tokenizer, messages[: index + 1], add_generation_prompt=False
+                 )
+                 start = len(tokenize_text(tokenizer, before_text))
+                 end = len(tokenize_text(tokenizer, after_text))
+                 labels[start:end] = input_ids[start:end]
+     else:
+         labels = []
+         input_ids = []
+         for message in messages:
+             part = f"{message['role']}: {message['content']}\n"
+             if data_args.add_eos_token and message["role"] == "assistant" and tokenizer.eos_token:
+                 part += tokenizer.eos_token
+             part_ids = tokenize_text(tokenizer, part)
+             input_ids.extend(part_ids)
+             if data_args.train_on_prompt or message["role"] == "assistant":
+                 labels.extend(part_ids)
+             else:
+                 labels.extend([IGNORE_INDEX] * len(part_ids))
+
+     return input_ids, labels
+
+
+ def preprocess_sft_dataset(raw_dataset, tokenizer, data_args: DataArguments):
+     column_names = raw_dataset.column_names
+     messages_column = choose_column(
+         column_names, data_args.messages_column, ["messages", "conversations"]
+     )
+     prompt_column = choose_column(
+         column_names,
+         data_args.prompt_column,
+         ["prompt", "instruction", "question"],
+     )
+     input_column = choose_column(
+         column_names,
+         data_args.input_column,
+         ["input", "context"],
+     )
+     response_column = choose_column(
+         column_names,
+         data_args.response_column,
+         ["response", "output", "answer", "chosen"],
+     )
+
+     if messages_column:
+         logger.info(f"Using chat messages column: {messages_column}")
+     elif prompt_column and response_column:
+         logger.info(f"Using prompt column '{prompt_column}' and response column '{response_column}'")
+     else:
+         raise ValueError(
+             "Cannot infer SFT data format. Provide either messages/conversations or "
+             "prompt/instruction plus response/output columns."
+         )
+
+     def encode_batch(examples):
+         batch_input_ids = []
+         batch_labels = []
+         batch_attention_mask = []
+
+         batch_size = len(next(iter(examples.values())))
+         for i in range(batch_size):
+             example = {name: values[i] for name, values in examples.items()}
+             if messages_column:
+                 input_ids, labels = encode_messages(example, tokenizer, data_args, messages_column)
+             else:
+                 input_ids, labels = encode_prompt_response(
+                     example, tokenizer, data_args, prompt_column, input_column, response_column
+                 )
+
+             input_ids = input_ids[: data_args.max_seq_length]
+             labels = labels[: data_args.max_seq_length]
+             if not input_ids or all(label == IGNORE_INDEX for label in labels):
+                 continue
+
+             batch_input_ids.append(input_ids)
+             batch_labels.append(labels)
+             batch_attention_mask.append([1] * len(input_ids))
+
+         return {
+             "input_ids": batch_input_ids,
+             "attention_mask": batch_attention_mask,
+             "labels": batch_labels,
+         }
+
+     return raw_dataset.map(
+         encode_batch,
+         batched=True,
+         num_proc=data_args.preprocessing_num_workers,
+         remove_columns=column_names,
+         desc="Tokenizing SFT data",
+     )
+
+
+ def main():
+     parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
+     model_args, data_args, training_args = parser.parse_args_into_dataclasses()
+
+     logging.basicConfig(
+         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+         datefmt="%Y-%m-%d %H:%M:%S",
+         level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
+     )
+     logger.info(f"Training args: {training_args}")
+
+     set_seed(training_args.seed)
+
+     dtype_map = {
+         "float16": torch.float16,
+         "bfloat16": torch.bfloat16,
+         "float32": torch.float32,
+     }
+     torch_dtype = dtype_map.get(model_args.torch_dtype, torch.bfloat16)
+
+     logger.info(f"Loading tokenizer from {model_args.model_name_or_path}")
+     tokenizer = AutoTokenizer.from_pretrained(
+         model_args.model_name_or_path,
+         trust_remote_code=True,
+     )
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+
+     logger.info(f"Loading model from {model_args.model_name_or_path}")
+     model = AutoModelForCausalLM.from_pretrained(
+         model_args.model_name_or_path,
+         torch_dtype=torch_dtype,
+         trust_remote_code=True,
+         attn_implementation="sdpa",
+     )
+     model.config.use_cache = False
+
+     logger.info(f"Loading SFT dataset from {data_args.data_path}")
+     raw_dataset = load_sft_dataset(data_args.data_path)
+     logger.info(f"Dataset loaded: {len(raw_dataset)} samples, columns: {raw_dataset.column_names}")
+
+     train_dataset = preprocess_sft_dataset(raw_dataset, tokenizer, data_args)
+     logger.info(f"Processed dataset: {len(train_dataset)} samples")
+
+     trainer = Trainer(
+         model=model,
+         args=training_args,
+         train_dataset=train_dataset,
+         data_collator=SFTDataCollator(tokenizer),
+     )
+
+     logger.info("Starting SFT training...")
+     train_result = trainer.train(
+         resume_from_checkpoint=training_args.resume_from_checkpoint
+     )
+
+     trainer.save_model()
+     trainer.save_state()
+
+     metrics = train_result.metrics
+     metrics["train_samples"] = len(train_dataset)
+     trainer.log_metrics("train", metrics)
+     trainer.save_metrics("train", metrics)
+
+
+ if __name__ == "__main__":
+     main()
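Both training scripts wrap `DeepSpeedEngine.no_sync` so that an `AssertionError` is swallowed. As written, the wrapper also catches an `AssertionError` raised inside the `with` body and then yields a second time, which `contextlib.contextmanager` reports as a `RuntimeError`. The sketch below shows a narrower variant that only tolerates a failure while entering the original context. It uses a hypothetical `FakeEngine` in place of DeepSpeed so that it runs standalone; that DeepSpeed's assertion fires on entry is an assumption about its implementation, not something this commit confirms.

```python
import contextlib


class FakeEngine:
    """Hypothetical stand-in for an engine whose no_sync may refuse to start."""

    def __init__(self, allow_no_sync: bool):
        self.allow_no_sync = allow_no_sync
        self.inside = False

    @contextlib.contextmanager
    def no_sync(self):
        assert self.allow_no_sync, "no_sync is not supported in this mode"
        self.inside = True
        try:
            yield
        finally:
            self.inside = False


_orig_no_sync = FakeEngine.no_sync


@contextlib.contextmanager
def _patched_no_sync(self):
    with contextlib.ExitStack() as stack:
        try:
            # Only a failure while entering the original context is tolerated.
            stack.enter_context(_orig_no_sync(self))
        except AssertionError:
            pass
        # Exceptions raised by the caller's body propagate normally from here.
        yield


FakeEngine.no_sync = _patched_no_sync

engine = FakeEngine(allow_no_sync=False)
with engine.no_sync():
    print("body ran without the original no_sync context")
```

With this variant an `AssertionError` from the training step itself still surfaces to the caller instead of being hidden by the patch.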