# guanwenyu1995's picture
# Upload folder using huggingface_hub
# 1a8f0d8 verified
"""
Continual pretraining script for CPM-2B model using DeepSpeed + HuggingFace Trainer.
"""
import os
import json
import math
import logging
from dataclasses import dataclass, field
from typing import Optional
import contextlib
import torch
from datasets import load_dataset
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
AutoConfig,
Trainer,
TrainingArguments,
HfArgumentParser,
DataCollatorForLanguageModeling,
set_seed,
)
import deepspeed
# Keep a handle on the original method so the wrapper below can delegate to it.
_orig_no_sync = deepspeed.DeepSpeedEngine.no_sync


@contextlib.contextmanager
def _patched_no_sync(self):
    """Best-effort replacement for ``DeepSpeedEngine.no_sync``.

    Tries to enter the original ``no_sync`` context. If *entering* it raises
    ``AssertionError``, the context silently degrades to a no-op instead of
    failing.

    NOTE(review): presumably this works around DeepSpeed asserting that
    ``no_sync`` is unsupported for the active ZeRO configuration while the
    Trainer still calls it during gradient accumulation -- confirm against
    the installed DeepSpeed/transformers versions.

    Only the entry of the original context is guarded. An ``AssertionError``
    raised by the caller's ``with`` body propagates unchanged; catching it
    here and yielding a second time would make ``contextlib`` raise
    ``RuntimeError("generator didn't stop after throw()")`` and hide the
    real failure.
    """
    with contextlib.ExitStack() as stack:
        try:
            # enter_context registers the original context's exit handler
            # only if entering it succeeded.
            stack.enter_context(_orig_no_sync(self))
        except AssertionError:
            # The original refused to start: run the body without it.
            pass
        yield


# Install the wrapper for every DeepSpeedEngine instance in this process.
deepspeed.DeepSpeedEngine.no_sync = _patched_no_sync
# Module-level logger; format and level come from logging.basicConfig() in main().
logger = logging.getLogger(__name__)
@dataclass
class ModelArguments:
    """Options that select the model checkpoint and the dtype of its weights."""

    # No default is given, so a value must always be supplied.
    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier"},
    )

    # Dtype is passed around as a plain string name.
    torch_dtype: Optional[str] = field(
        metadata={"help": "torch dtype for model weights (float16, bfloat16, float32)"},
        default="bfloat16",
    )
@dataclass
class DataArguments:
    """Options that describe the training data and how it is preprocessed."""

    # No default is given, so a value must always be supplied.
    data_path: str = field(
        metadata={"help": "Path to training data (parquet file or directory)"},
    )

    # Length, in tokens, of each packed training block.
    max_seq_length: int = field(
        metadata={"help": "Maximum sequence length for training"},
        default=4096,
    )

    # Dataset column that holds the raw text.
    text_column: str = field(
        metadata={"help": "Name of the text column in the dataset"},
        default="text",
    )

    # Forwarded as ``num_proc`` to the dataset ``map`` calls.
    preprocessing_num_workers: int = field(
        metadata={"help": "Number of workers for data preprocessing"},
        default=8,
    )
def tokenize_and_group(dataset, tokenizer, data_args):
    """Tokenize texts and pack the tokens into blocks of ``max_seq_length``.

    Args:
        dataset: Object exposing ``column_names`` and a ``map`` method that
            accepts the keyword arguments used below (``main`` passes the
            result of ``load_dataset``).
        tokenizer: Callable invoked as ``tokenizer(texts,
            add_special_tokens=False)`` on a batch of texts. It must return a
            mapping that contains at least ``"input_ids"``.
        data_args: Object providing ``text_column``, ``max_seq_length`` and
            ``preprocessing_num_workers``.

    Returns:
        The result of the second ``map`` call. Every tokenizer output column
        is split into chunks of exactly ``max_seq_length`` items, and a
        ``"labels"`` column is added as a copy of ``"input_ids"``.

    Raises:
        ValueError: If ``max_seq_length`` is not positive, or if the dataset
            has no columns to fall back on.

    Notes:
        Rows of a batch are concatenated back to back; this function inserts
        no separator between them. Within each mapped batch, the tail that
        does not fill a whole block is dropped.
    """
    # Imported here so this function carries its own dependency.
    from itertools import chain

    block_size = data_args.max_seq_length
    if block_size < 1:
        # Fail before tokenizing rather than with ZeroDivisionError in a worker.
        raise ValueError(f"max_seq_length must be a positive integer, got {block_size}")

    column_names = dataset.column_names
    text_column = data_args.text_column
    if text_column not in column_names:
        # Fall back to the first column whose name mentions "text", then to
        # the very first column.
        candidates = [c for c in column_names if "text" in c.lower()]
        if candidates:
            text_column = candidates[0]
        elif column_names:
            text_column = column_names[0]
        else:
            raise ValueError("Dataset has no columns to tokenize")
        logger.warning(f"Column '{data_args.text_column}' not found, using '{text_column}'")

    def tokenize_function(examples):
        return tokenizer(examples[text_column], add_special_tokens=False)

    # All original columns are removed, so only tokenizer outputs remain.
    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        remove_columns=column_names,
        desc="Tokenizing",
    )

    def group_texts(examples):
        # chain.from_iterable flattens in linear time; sum(rows, []) copies
        # the growing accumulator once per row.
        concatenated = {k: list(chain.from_iterable(examples[k])) for k in examples.keys()}
        total_length = len(concatenated["input_ids"])
        # Round down to a whole number of blocks; the remainder is discarded.
        total_length = (total_length // block_size) * block_size
        result = {
            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated.items()
        }
        result["labels"] = result["input_ids"].copy()
        return result

    grouped_dataset = tokenized_dataset.map(
        group_texts,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        desc="Grouping texts",
    )
    return grouped_dataset
def main():
    """Parse CLI arguments, prepare model and data, train, and save the results.

    Steps, in order: parse ``ModelArguments`` / ``DataArguments`` /
    ``TrainingArguments``; configure logging and the random seed; load the
    tokenizer and model from ``model_name_or_path``; load parquet data from
    ``data_path``; tokenize and pack it; run ``Trainer.train``; then save the
    model, trainer state and training metrics.

    Raises:
        ValueError: If ``data_path`` is neither a file nor a directory, or is
            a directory that contains no ``.parquet`` files.
    """
    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Only local rank -1 or 0 logs at INFO; other ranks are limited to warnings.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARNING,
    )
    logger.info(f"Training args: {training_args}")
    set_seed(training_args.seed)

    dtype_map = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }
    if model_args.torch_dtype not in dtype_map:
        # Keep the bfloat16 fallback, but make a mistyped dtype visible.
        logger.warning(
            f"Unknown torch_dtype '{model_args.torch_dtype}', falling back to bfloat16 "
            f"(expected one of {sorted(dtype_map)})"
        )
    torch_dtype = dtype_map.get(model_args.torch_dtype, torch.bfloat16)

    logger.info(f"Loading tokenizer from {model_args.model_name_or_path}")
    # NOTE(review): trust_remote_code=True runs Python code shipped with the
    # checkpoint; only point this at model sources you trust.
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        trust_remote_code=True,
    )
    if tokenizer.pad_token is None:
        # Reuse EOS as the padding token when the tokenizer defines none.
        tokenizer.pad_token = tokenizer.eos_token

    logger.info(f"Loading model from {model_args.model_name_or_path}")
    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        torch_dtype=torch_dtype,
        trust_remote_code=True,
        attn_implementation="sdpa",
    )
    model.config.use_cache = False

    logger.info(f"Loading dataset from {data_args.data_path}")
    if os.path.isfile(data_args.data_path):
        raw_dataset = load_dataset("parquet", data_files=data_args.data_path, split="train")
    elif os.path.isdir(data_args.data_path):
        # os.listdir returns entries in arbitrary order; sorting keeps the
        # shard order identical across runs and processes.
        parquet_files = sorted(
            os.path.join(data_args.data_path, f)
            for f in os.listdir(data_args.data_path)
            if f.endswith(".parquet")
        )
        if not parquet_files:
            raise ValueError(f"No .parquet files found in directory: {data_args.data_path}")
        raw_dataset = load_dataset("parquet", data_files=parquet_files, split="train")
    else:
        raise ValueError(f"Data path not found: {data_args.data_path}")
    logger.info(f"Dataset loaded: {len(raw_dataset)} samples, columns: {raw_dataset.column_names}")

    # Let the main process preprocess first so other ranks do not repeat the
    # same tokenization work at the same time.
    # NOTE(review): other ranks wait while this runs; for very large corpora
    # the distributed timeout (--ddp_timeout) may need raising -- confirm.
    with training_args.main_process_first(desc="dataset tokenization and grouping"):
        train_dataset = tokenize_and_group(raw_dataset, tokenizer, data_args)
    logger.info(f"Processed dataset: {len(train_dataset)} samples of length {data_args.max_seq_length}")

    # mlm=False: no masked-language-modeling masking is applied by the collator.
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=data_collator,
    )

    logger.info("Starting training...")
    train_result = trainer.train(
        resume_from_checkpoint=training_args.resume_from_checkpoint
    )
    trainer.save_model()
    trainer.save_state()

    metrics = train_result.metrics
    metrics["train_samples"] = len(train_dataset)
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)


if __name__ == "__main__":
    main()