"""
Continual pretraining script for CPM-2B model using DeepSpeed + HuggingFace Trainer.
"""

import os
import json
import math
import logging
from dataclasses import dataclass, field
from itertools import chain
from typing import Optional
import contextlib

import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoConfig,
    Trainer,
    TrainingArguments,
    HfArgumentParser,
    DataCollatorForLanguageModeling,
    set_seed,
)
import deepspeed

logger = logging.getLogger(__name__)

# Keep a handle on the stock implementation so the patched version can delegate.
_orig_no_sync = deepspeed.DeepSpeedEngine.no_sync


@contextlib.contextmanager
def _patched_no_sync(self):
    """Best-effort replacement for ``DeepSpeedEngine.no_sync``.

    Delegates to the original context manager, but if *entering* it fails
    with an ``AssertionError`` the body is still executed, just without the
    no-sync behaviour. (Presumably DeepSpeed asserts when ``no_sync`` is used
    in a configuration it does not support -- confirm against the installed
    DeepSpeed version.)

    Only the entry is guarded. An ``AssertionError`` raised by the caller's
    code inside the ``with`` body propagates unchanged; catching it here and
    yielding a second time would make ``contextlib`` raise
    ``RuntimeError("generator didn't stop after throw()")`` and hide the
    real failure.
    """
    with contextlib.ExitStack() as stack:
        try:
            stack.enter_context(_orig_no_sync(self))
        except AssertionError as err:
            logger.debug("DeepSpeed no_sync unavailable, continuing without it: %s", err)
        # Single yield point: exceptions from the body unwind through the
        # ExitStack, which closes the original context if it was entered.
        yield


deepspeed.DeepSpeedEngine.no_sync = _patched_no_sync


@dataclass
class ModelArguments:
    """Command-line arguments selecting the model checkpoint and its dtype."""

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier"}
    )
    torch_dtype: Optional[str] = field(
        default="bfloat16",
        metadata={"help": "torch dtype for model weights (float16, bfloat16, float32)"},
    )


@dataclass
class DataArguments:
    """Command-line arguments describing the training data and preprocessing."""

    data_path: str = field(
        metadata={"help": "Path to training data (parquet file or directory)"}
    )
    max_seq_length: int = field(
        default=4096,
        metadata={"help": "Maximum sequence length for training"},
    )
    text_column: str = field(
        default="text",
        metadata={"help": "Name of the text column in the dataset"},
    )
    preprocessing_num_workers: int = field(
        default=8,
        metadata={"help": "Number of workers for data preprocessing"},
    )


def _resolve_text_column(column_names, requested):
    """Return the column to tokenize, falling back when *requested* is absent.

    Fallback order: the first column whose name contains "text"
    (case-insensitive), otherwise the first column of the dataset.
    """
    if requested in column_names:
        return requested
    if not column_names:
        raise ValueError("Dataset has no columns to tokenize")
    candidates = [c for c in column_names if "text" in c.lower()]
    resolved = candidates[0] if candidates else column_names[0]
    logger.warning(f"Column '{requested}' not found, using '{resolved}'")
    return resolved


def tokenize_and_group(dataset, tokenizer, data_args):
    """Tokenize texts and group into chunks of max_seq_length.

    Args:
        dataset: Raw dataset exposing ``column_names`` and ``map``.
        tokenizer: Callable tokenizer applied to batches of text.
        data_args: ``DataArguments`` supplying ``text_column``,
            ``max_seq_length`` and ``preprocessing_num_workers``.

    Returns:
        A dataset whose rows are fixed-length blocks of ``max_seq_length``
        tokens, with a ``labels`` column that copies ``input_ids``. Within
        each map batch the trailing tokens that do not fill a whole block
        are dropped.

    Raises:
        ValueError: If ``max_seq_length`` is not positive or the dataset has
            no columns.
    """
    block_size = data_args.max_seq_length
    if block_size <= 0:
        # Fail early: a zero block size would otherwise surface as a
        # ZeroDivisionError inside the map workers.
        raise ValueError(f"max_seq_length must be positive, got {block_size}")

    column_names = dataset.column_names
    text_column = _resolve_text_column(column_names, data_args.text_column)

    def tokenize_function(examples):
        # No special tokens: documents are concatenated back-to-back below.
        return tokenizer(examples[text_column], add_special_tokens=False)

    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        remove_columns=column_names,
        desc="Tokenizing",
    )

    def group_texts(examples):
        # Flatten each column in linear time (sum(lists, []) is quadratic).
        concatenated = {
            k: list(chain.from_iterable(examples[k])) for k in examples.keys()
        }
        total_length = len(concatenated["input_ids"])
        # Drop the remainder so every block is exactly block_size long.
        total_length = (total_length // block_size) * block_size
        result = {
            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated.items()
        }
        result["labels"] = result["input_ids"].copy()
        return result

    grouped_dataset = tokenized_dataset.map(
        group_texts,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        desc="Grouping texts",
    )
    return grouped_dataset


def _resolve_torch_dtype(name):
    """Map a dtype name to a ``torch.dtype``, defaulting to bfloat16.

    Unknown names keep the historical bfloat16 fallback but are reported, so
    a typo such as "fp32" no longer changes precision silently.
    """
    dtype_map = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }
    resolved = dtype_map.get(name)
    if resolved is None:
        if name is not None:
            logger.warning(
                "Unknown torch_dtype %r; falling back to bfloat16 (valid: %s)",
                name,
                ", ".join(dtype_map),
            )
        resolved = torch.bfloat16
    return resolved


def _load_parquet_dataset(data_path):
    """Load the ``train`` split from a parquet file or a directory of them.

    Raises:
        ValueError: If *data_path* does not exist, or is a directory that
            contains no ``.parquet`` files.
    """
    if os.path.isfile(data_path):
        return load_dataset("parquet", data_files=data_path, split="train")
    if os.path.isdir(data_path):
        # os.listdir order is arbitrary; sort so the sample order (and hence
        # the seeded run) is reproducible across machines.
        parquet_files = sorted(
            os.path.join(data_path, f)
            for f in os.listdir(data_path)
            if f.endswith(".parquet")
        )
        if not parquet_files:
            raise ValueError(f"No .parquet files found in directory: {data_path}")
        return load_dataset("parquet", data_files=parquet_files, split="train")
    raise ValueError(f"Data path not found: {data_path}")


def main():
    """Parse CLI arguments, prepare model and data, then train and save."""
    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        # Only the main process (or a non-distributed run) logs at INFO.
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.info(f"Training args: {training_args}")

    set_seed(training_args.seed)

    torch_dtype = _resolve_torch_dtype(model_args.torch_dtype)

    logger.info(f"Loading tokenizer from {model_args.model_name_or_path}")
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        trust_remote_code=True,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    logger.info(f"Loading model from {model_args.model_name_or_path}")
    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        torch_dtype=torch_dtype,
        trust_remote_code=True,
        attn_implementation="sdpa",
    )
    # The generation KV cache is not needed while training.
    model.config.use_cache = False

    logger.info(f"Loading dataset from {data_args.data_path}")
    raw_dataset = _load_parquet_dataset(data_args.data_path)
    logger.info(f"Dataset loaded: {len(raw_dataset)} samples, columns: {raw_dataset.column_names}")

    train_dataset = tokenize_and_group(raw_dataset, tokenizer, data_args)
    logger.info(f"Processed dataset: {len(train_dataset)} samples of length {data_args.max_seq_length}")
    if len(train_dataset) == 0:
        raise ValueError(
            "Processed dataset is empty: the corpus holds fewer than "
            f"max_seq_length={data_args.max_seq_length} tokens per preprocessing batch"
        )

    # NOTE(review): when pad_token falls back to eos_token above, the causal-LM
    # collator presumably masks every eos position in the labels -- confirm
    # this is acceptable for the corpus.
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=data_collator,
    )

    logger.info("Starting training...")
    train_result = trainer.train(
        resume_from_checkpoint=training_args.resume_from_checkpoint
    )

    trainer.save_model()
    trainer.save_state()

    metrics = train_result.metrics
    metrics["train_samples"] = len(train_dataset)
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)


if __name__ == "__main__":
    main()