File size: 6,034 Bytes
c5f8654
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
"""
Continual pretraining script for CPM-2B model using DeepSpeed + HuggingFace Trainer.
"""

import os
import json
import math
import logging
from dataclasses import dataclass, field
from typing import Optional

import contextlib

import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoConfig,
    Trainer,
    TrainingArguments,
    HfArgumentParser,
    DataCollatorForLanguageModeling,
    set_seed,
)

import deepspeed
# Keep a handle on DeepSpeed's stock implementation so the wrapper can delegate to it.
_orig_no_sync = deepspeed.DeepSpeedEngine.no_sync


@contextlib.contextmanager
def _patched_no_sync(self):
    """Enter DeepSpeed's ``no_sync`` when it permits it; otherwise run the body without it.

    The stock ``no_sync`` is attempted first. If entering it raises
    ``AssertionError``, the error is swallowed and the ``with`` body still runs,
    just without the original context being active.
    NOTE(review): presumably the assertion being tolerated is DeepSpeed's
    "no_sync is incompatible with this ZeRO stage" check -- confirm against the
    installed DeepSpeed version.

    Only the *enter* step is guarded. Exceptions raised by the ``with`` body
    (including ``AssertionError``) propagate unchanged, and the generator yields
    exactly once on every path, as ``contextlib.contextmanager`` requires.
    """
    with contextlib.ExitStack() as stack:
        try:
            stack.enter_context(_orig_no_sync(self))
        except AssertionError:
            # Deliberate best-effort: nothing was registered on the stack, so
            # the body below simply runs without the original context.
            pass
        yield


# Install the tolerant wrapper for every DeepSpeedEngine instance.
deepspeed.DeepSpeedEngine.no_sync = _patched_no_sync

logger = logging.getLogger(__name__)


@dataclass
class ModelArguments:
    """Options selecting which checkpoint to load and the dtype of its weights."""

    # Has no default, so a value must always be supplied.
    model_name_or_path: str = field(
        metadata=dict(help="Path to pretrained model or model identifier"),
    )
    # Name of the torch dtype used when loading the weights.
    torch_dtype: Optional[str] = field(
        metadata=dict(help="torch dtype for model weights (float16, bfloat16, float32)"),
        default="bfloat16",
    )


@dataclass
class DataArguments:
    """Options describing where the training text lives and how it is preprocessed."""

    # Has no default, so a value must always be supplied.
    data_path: str = field(
        metadata=dict(help="Path to training data (parquet file or directory)"),
    )
    # Number of tokens in each packed training block.
    max_seq_length: int = field(
        metadata=dict(help="Maximum sequence length for training"),
        default=4096,
    )
    # Dataset column that holds the raw text.
    text_column: str = field(
        metadata=dict(help="Name of the text column in the dataset"),
        default="text",
    )
    # Passed as ``num_proc`` to the dataset ``map`` calls.
    preprocessing_num_workers: int = field(
        metadata=dict(help="Number of workers for data preprocessing"),
        default=8,
    )


def tokenize_and_group(dataset, tokenizer, data_args):
    """Tokenize a text dataset and pack the tokens into fixed-length blocks.

    Args:
        dataset: Dataset-like object exposing ``column_names`` and a ``map``
            method that accepts ``batched``, ``num_proc``, ``remove_columns``
            and ``desc``.
        tokenizer: Callable invoked as ``tokenizer(texts, add_special_tokens=False)``.
            Its output must contain an ``"input_ids"`` column.
        data_args: Object providing ``text_column``, ``max_seq_length`` and
            ``preprocessing_num_workers``.

    Returns:
        The dataset produced by the grouping ``map`` call. Every row holds
        exactly ``max_seq_length`` items per tokenizer output column, plus a
        ``labels`` column that mirrors ``input_ids``. Within each mapped batch,
        the trailing tokens that do not fill a whole block are dropped.

    Raises:
        ValueError: If ``data_args.max_seq_length`` is not a positive integer.
    """
    # Imported here so the function brings its own dependency into scope.
    from itertools import chain

    block_size = data_args.max_seq_length
    if block_size <= 0:
        # Fail before the expensive tokenization pass rather than with a
        # ZeroDivisionError (or nonsense slicing) inside the map workers.
        raise ValueError(f"max_seq_length must be a positive integer, got {block_size}")

    column_names = dataset.column_names
    text_column = data_args.text_column
    if text_column not in column_names:
        # Requested column is missing: prefer any column whose name contains
        # "text", otherwise fall back to the first column.
        candidates = [c for c in column_names if "text" in c.lower()]
        text_column = candidates[0] if candidates else column_names[0]
        logger.warning(f"Column '{data_args.text_column}' not found, using '{text_column}'")

    def tokenize_function(examples):
        # NOTE(review): special tokens are disabled here and nothing is inserted
        # between documents before packing, so consecutive documents are joined
        # back-to-back -- confirm that is intended for this tokenizer.
        return tokenizer(examples[text_column], add_special_tokens=False)

    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        remove_columns=column_names,
        desc="Tokenizing",
    )

    def group_texts(examples):
        # Flatten each column in a single linear pass; repeated list
        # concatenation (``sum(lists, [])``) would be quadratic in the number
        # of tokens in the batch.
        concatenated = {
            k: list(chain.from_iterable(examples[k])) for k in examples.keys()
        }
        # Keep only as many tokens as fill whole blocks.
        total_length = (len(concatenated["input_ids"]) // block_size) * block_size

        result = {
            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated.items()
        }
        # Labels start as a copy of the inputs.
        result["labels"] = result["input_ids"].copy()
        return result

    grouped_dataset = tokenized_dataset.map(
        group_texts,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        desc="Grouping texts",
    )

    return grouped_dataset


def main():
    """Parse CLI arguments, prepare the packed dataset, and run training.

    The steps, in order: parse ``ModelArguments`` / ``DataArguments`` /
    ``TrainingArguments`` from the command line, configure logging, seed the
    RNGs, load the tokenizer and model, load parquet data from a file or a
    directory, tokenize and pack it with ``tokenize_and_group``, train with
    ``Trainer``, then save the model, trainer state and training metrics.

    Raises:
        ValueError: If ``data_path`` is neither an existing file nor an existing
            directory, or if it is a directory with no ``.parquet`` files.
    """
    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        # Local ranks other than -1/0 only log warnings and above.
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.info(f"Training args: {training_args}")

    set_seed(training_args.seed)

    dtype_map = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }
    if model_args.torch_dtype not in dtype_map:
        # Keep the bfloat16 fallback, but make it visible: a typo here would
        # otherwise silently change the training precision.
        logger.warning(
            f"Unknown torch_dtype '{model_args.torch_dtype}', falling back to bfloat16"
        )
    torch_dtype = dtype_map.get(model_args.torch_dtype, torch.bfloat16)

    # NOTE(review): trust_remote_code=True (here and for the model below) is
    # understood to allow code shipped with the checkpoint to run locally --
    # only use with model sources you trust.
    logger.info(f"Loading tokenizer from {model_args.model_name_or_path}")
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        trust_remote_code=True,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    logger.info(f"Loading model from {model_args.model_name_or_path}")
    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        torch_dtype=torch_dtype,
        trust_remote_code=True,
        attn_implementation="sdpa",
    )
    model.config.use_cache = False

    logger.info(f"Loading dataset from {data_args.data_path}")
    if os.path.isfile(data_args.data_path):
        raw_dataset = load_dataset("parquet", data_files=data_args.data_path, split="train")
    elif os.path.isdir(data_args.data_path):
        # os.listdir returns entries in arbitrary order; sort so the shard
        # order (and therefore the sample order) is the same on every run and
        # on every rank.
        parquet_files = sorted(
            os.path.join(data_args.data_path, f)
            for f in os.listdir(data_args.data_path)
            if f.endswith(".parquet")
        )
        if not parquet_files:
            raise ValueError(f"No .parquet files found in directory: {data_args.data_path}")
        raw_dataset = load_dataset("parquet", data_files=parquet_files, split="train")
    else:
        raise ValueError(f"Data path not found: {data_args.data_path}")

    logger.info(f"Dataset loaded: {len(raw_dataset)} samples, columns: {raw_dataset.column_names}")

    train_dataset = tokenize_and_group(raw_dataset, tokenizer, data_args)
    logger.info(f"Processed dataset: {len(train_dataset)} samples of length {data_args.max_seq_length}")

    # NOTE(review): with mlm=False this collator is expected to rebuild labels
    # from input_ids and ignore pad-token positions; because pad_token may have
    # been set to eos_token above, confirm EOS positions are not unintentionally
    # excluded from the loss.
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=data_collator,
    )

    logger.info("Starting training...")
    train_result = trainer.train(
        resume_from_checkpoint=training_args.resume_from_checkpoint
    )

    trainer.save_model()
    trainer.save_state()

    metrics = train_result.metrics
    metrics["train_samples"] = len(train_dataset)
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)


# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()