"""
Supervised fine-tuning script using DeepSpeed + HuggingFace Trainer.

Two input layouts are supported:
  * chat data: a ``messages``/``conversations`` column holding a list of
    ``{"role"/"from": ..., "content"/"value": ...}`` dicts (or a JSON string of one);
  * instruction data: a prompt column (+ optional input and system columns) and a
    response column.

Unless ``--train_on_prompt`` is set, loss is computed only on response/assistant
tokens; every other position is labelled IGNORE_INDEX.
"""

import contextlib
import json
import logging
import os
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple

import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    set_seed,
)

import deepspeed

logger = logging.getLogger(__name__)

IGNORE_INDEX = -100

# Data-file extension -> name of the `datasets` packaged builder that reads it.
# The plain-text builder is called "text"; passing "txt" to load_dataset fails.
_EXTENSION_TO_BUILDER = {
    "parquet": "parquet",
    "json": "json",
    "jsonl": "json",
    "csv": "csv",
    "txt": "text",
}

# Accepted spellings for --torch_dtype (matched case-insensitively).
_DTYPE_ALIASES = {
    "float16": torch.float16,
    "fp16": torch.float16,
    "half": torch.float16,
    "bfloat16": torch.bfloat16,
    "bf16": torch.bfloat16,
    "float32": torch.float32,
    "fp32": torch.float32,
    "float": torch.float32,
}


# ---------------------------------------------------------------------------
# DeepSpeed compatibility patch (applied at import time)
# ---------------------------------------------------------------------------

_orig_no_sync = deepspeed.DeepSpeedEngine.no_sync


@contextlib.contextmanager
def _patched_no_sync(self):
    """Replacement for ``DeepSpeedEngine.no_sync`` that tolerates unsupported configs.

    The original ``no_sync`` may fail an assertion on entry for configurations where
    skipping gradient synchronisation is not allowed (NOTE(review): in current
    DeepSpeed this is ZeRO gradient partitioning, i.e. stage 2/3, or re-entry —
    confirm for the installed version). In that case the body simply runs with
    normal gradient synchronisation instead of crashing.

    Only an AssertionError raised while *entering* the original context manager is
    tolerated. Exceptions raised by the caller's code inside the ``with`` block
    propagate unchanged; catching them around ``yield`` and yielding again would make
    contextlib raise "generator didn't stop after throw()" and hide the real error.
    """
    with contextlib.ExitStack() as stack:
        try:
            stack.enter_context(_orig_no_sync(self))
        except AssertionError:
            logger.debug(
                "DeepSpeedEngine.no_sync is not supported in this configuration; "
                "gradients will be synchronised on every micro-step."
            )
        yield


deepspeed.DeepSpeedEngine.no_sync = _patched_no_sync


# ---------------------------------------------------------------------------
# Arguments
# ---------------------------------------------------------------------------


@dataclass
class ModelArguments:
    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier"}
    )
    torch_dtype: Optional[str] = field(
        default="bfloat16",
        metadata={
            "help": "torch dtype for model weights (float16/fp16, bfloat16/bf16, "
            "float32/fp32, or auto)"
        },
    )
    attn_implementation: Optional[str] = field(
        default="sdpa",
        metadata={
            "help": "Attention implementation passed to from_pretrained "
            "(e.g. sdpa, eager, flash_attention_2)"
        },
    )


@dataclass
class DataArguments:
    data_path: str = field(metadata={"help": "Path to SFT data file or directory"})
    max_seq_length: int = field(
        default=4096,
        metadata={"help": "Maximum sequence length for training"},
    )
    prompt_column: Optional[str] = field(
        default=None,
        metadata={"help": "Prompt/instruction column name. Auto-detected if omitted."},
    )
    input_column: Optional[str] = field(
        default=None,
        metadata={"help": "Optional extra input/context column name"},
    )
    response_column: Optional[str] = field(
        default=None,
        metadata={"help": "Response/output column name. Auto-detected if omitted."},
    )
    messages_column: Optional[str] = field(
        default=None,
        metadata={"help": "Chat messages column name. Auto-detected if omitted."},
    )
    system_column: Optional[str] = field(
        default=None,
        metadata={"help": "Optional system prompt column name"},
    )
    train_on_prompt: bool = field(
        default=False,
        metadata={"help": "Whether to compute loss on prompt/user tokens"},
    )
    add_eos_token: bool = field(
        default=True,
        metadata={"help": "Append eos_token to plain prompt/response examples"},
    )
    preprocessing_num_workers: int = field(
        default=8,
        metadata={
            "help": "Number of workers for data preprocessing (<= 0 means single process)"
        },
    )


# ---------------------------------------------------------------------------
# Collation
# ---------------------------------------------------------------------------


class SFTDataCollator:
    """Right-pad pre-tokenized SFT features into a batch of ``torch.long`` tensors.

    ``input_ids`` are padded with the tokenizer's ``pad_token_id``, ``attention_mask``
    with 0 and ``labels`` with IGNORE_INDEX. The batch length is the longest feature,
    rounded up to a multiple of ``pad_to_multiple_of`` when it is truthy. Any
    ``attention_mask`` already present on a feature is ignored and rebuilt here.
    """

    def __init__(self, tokenizer, pad_to_multiple_of: Optional[int] = 8):
        self.tokenizer = tokenizer
        self.pad_to_multiple_of = pad_to_multiple_of

    def __call__(self, features: List[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]:
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            raise ValueError(
                "Tokenizer has no pad_token_id; set tokenizer.pad_token before collating."
            )

        max_length = max(len(feature["input_ids"]) for feature in features)
        if self.pad_to_multiple_of:
            multiple = self.pad_to_multiple_of
            max_length = ((max_length + multiple - 1) // multiple) * multiple

        input_ids = []
        attention_mask = []
        labels = []
        for feature in features:
            length = len(feature["input_ids"])
            pad_length = max_length - length
            input_ids.append(feature["input_ids"] + [pad_token_id] * pad_length)
            attention_mask.append([1] * length + [0] * pad_length)
            labels.append(feature["labels"] + [IGNORE_INDEX] * pad_length)

        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
        }


# ---------------------------------------------------------------------------
# Data loading
# ---------------------------------------------------------------------------


def _file_extension(path: str) -> str:
    """Return the lower-cased extension of ``path`` without the leading dot."""
    return os.path.splitext(path)[1].lstrip(".").lower()


def load_sft_dataset(data_path: str):
    """Load the ``train`` split from a single data file or a flat directory of files.

    Supported extensions: parquet, json, jsonl, csv, txt.

    For a directory, only regular files directly inside it are considered, in sorted
    name order. The first supported file fixes the format, and files of other
    supported formats are skipped with a warning. Sorting makes that choice
    deterministic; ``os.listdir`` order is arbitrary.

    Raises:
        ValueError: if the path does not exist, has an unsupported extension, or the
            directory contains no supported files.
    """
    if os.path.isfile(data_path):
        extension = _file_extension(data_path)
        builder = _EXTENSION_TO_BUILDER.get(extension)
        if builder is None:
            raise ValueError(f"Unsupported data file extension: {extension}")
        return load_dataset(builder, data_files=data_path, split="train")

    if os.path.isdir(data_path):
        builder = None
        data_files = []
        skipped = []
        for name in sorted(os.listdir(data_path)):
            path = os.path.join(data_path, name)
            if not os.path.isfile(path):
                continue
            current_builder = _EXTENSION_TO_BUILDER.get(_file_extension(name))
            if current_builder is None:
                continue
            if builder is None:
                builder = current_builder
            if current_builder == builder:
                data_files.append(path)
            else:
                skipped.append(name)
        if not data_files or builder is None:
            raise ValueError(f"No supported data files found in: {data_path}")
        if skipped:
            logger.warning(
                f"Loading only '{builder}' files from {data_path}; "
                f"skipping files of another format: {skipped}"
            )
        return load_dataset(builder, data_files=data_files, split="train")

    raise ValueError(f"Data path not found: {data_path}")


def choose_column(
    column_names: List[str], explicit: Optional[str], candidates: List[str]
) -> Optional[str]:
    """Pick a dataset column.

    An explicitly requested name must exist (ValueError otherwise). Without one, the
    first entry of ``candidates`` present in ``column_names`` is returned, or None.
    """
    if explicit:
        if explicit not in column_names:
            raise ValueError(f"Column '{explicit}' not found. Available columns: {column_names}")
        return explicit
    for name in candidates:
        if name in column_names:
            return name
    return None


def parse_messages(value: Any) -> List[Dict[str, str]]:
    """Normalise a messages/conversations cell into ``[{"role", "content"}, ...]``.

    Accepts a list or a JSON string of a list. Each item must be a dict with
    ``role`` (or ``from``) and ``content`` (or ``value``). The role aliases
    ``human``/``gpt`` are mapped to ``user``/``assistant``. Role and content are
    converted to ``str``.

    Raises:
        ValueError: on any other structure.
    """
    if isinstance(value, str):
        value = json.loads(value)
    if not isinstance(value, list):
        raise ValueError("messages/conversations column must be a list or JSON string")

    messages = []
    for item in value:
        if not isinstance(item, dict):
            raise ValueError("Each message must be a dict")
        role = item.get("role", item.get("from"))
        content = item.get("content", item.get("value"))
        if role == "human":
            role = "user"
        elif role == "gpt":
            role = "assistant"
        if role is None or content is None:
            raise ValueError("Each message must contain role/from and content/value")
        messages.append({"role": str(role), "content": str(content)})
    return messages


# ---------------------------------------------------------------------------
# Tokenization / label masking
# ---------------------------------------------------------------------------


def tokenize_text(tokenizer, text: str) -> List[int]:
    """Tokenize already-formatted text without letting the tokenizer add BOS/EOS."""
    return tokenizer(text, add_special_tokens=False)["input_ids"]


def apply_chat_template(tokenizer, messages: List[Dict[str, str]], add_generation_prompt: bool) -> str:
    """Render ``messages`` to a string with the tokenizer's chat template.

    Raises:
        ValueError: if the tokenizer has no chat template.
    """
    if tokenizer.chat_template is None:
        raise ValueError(
            "The tokenizer has no chat_template. Use prompt/response columns or set a chat_template."
        )
    return tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=add_generation_prompt,
    )


def _common_prefix_length(prefix_ids: List[int], full_ids: List[int]) -> int:
    """Return how many leading token ids ``prefix_ids`` and ``full_ids`` share.

    Used to place a loss-mask boundary. Tokenizing a prefix string on its own usually
    gives a token prefix of the full string's tokenization, but a token merge across
    the boundary can break that. Stopping at the first mismatch ensures a token that
    differs from the prefix tokenization is never masked.
    """
    length = 0
    for prefix_id, full_id in zip(prefix_ids, full_ids):
        if prefix_id != full_id:
            break
        length += 1
    return length


def encode_prompt_response(
    example: Dict[str, Any],
    tokenizer,
    data_args: DataArguments,
    prompt_column: str,
    input_column: Optional[str],
    response_column: str,
) -> Tuple[List[int], List[int]]:
    """Tokenize one prompt/response example and build its labels.

    The prompt is ``example[prompt_column]``, followed by ``"\\n"`` plus
    ``example[input_column]`` when that column is set and its value is truthy.

    * With a chat template, the example is rendered as [system] + user + assistant
      turns. The system turn is included only when ``data_args.system_column`` is set
      and non-empty for this row.
    * Without a chat template, the text is ``prompt + "\\n" + response`` (+ eos_token
      when ``add_eos_token``). The system column is not used in this path.

    Returns:
        ``(input_ids, labels)`` of equal length. Prompt positions in ``labels`` are
        IGNORE_INDEX unless ``data_args.train_on_prompt``.
    """
    prompt = str(example[prompt_column])
    if input_column and example.get(input_column):
        prompt = prompt + "\n" + str(example[input_column])
    response = str(example[response_column])

    if tokenizer.chat_template is not None:
        messages = []
        if data_args.system_column and example.get(data_args.system_column):
            messages.append({"role": "system", "content": str(example[data_args.system_column])})
        messages.append({"role": "user", "content": prompt})
        messages.append({"role": "assistant", "content": response})
        full_text = apply_chat_template(tokenizer, messages, add_generation_prompt=False)
        prompt_text = apply_chat_template(tokenizer, messages[:-1], add_generation_prompt=True)
    else:
        response_text = response
        if data_args.add_eos_token and tokenizer.eos_token:
            response_text += tokenizer.eos_token
        full_text = prompt + "\n" + response_text
        prompt_text = prompt + "\n"

    input_ids = tokenize_text(tokenizer, full_text)
    labels = input_ids.copy()
    if not data_args.train_on_prompt:
        prompt_length = _common_prefix_length(tokenize_text(tokenizer, prompt_text), input_ids)
        labels[:prompt_length] = [IGNORE_INDEX] * prompt_length
    return input_ids, labels


def encode_messages(
    example: Dict[str, Any],
    tokenizer,
    data_args: DataArguments,
    messages_column: str,
) -> Tuple[List[int], List[int]]:
    """Tokenize one multi-turn conversation and build its labels.

    * With a chat template, the whole conversation is rendered once. Each assistant
      turn's span is located by rendering the conversation up to that turn: with a
      generation prompt for the start, and including the turn for the end. Only
      those spans are labelled, unless ``train_on_prompt``.
    * Without a template, each message becomes ``"{role}: {content}\\n"`` (assistant
      turns get eos_token appended when ``add_eos_token``). Only assistant parts are
      labelled, unless ``train_on_prompt``.

    Returns:
        ``(input_ids, labels)`` of equal length.
    """
    messages = parse_messages(example[messages_column])

    if tokenizer.chat_template is None:
        input_ids: List[int] = []
        labels: List[int] = []
        for message in messages:
            part = f"{message['role']}: {message['content']}\n"
            if data_args.add_eos_token and message["role"] == "assistant" and tokenizer.eos_token:
                part += tokenizer.eos_token
            part_ids = tokenize_text(tokenizer, part)
            input_ids.extend(part_ids)
            if data_args.train_on_prompt or message["role"] == "assistant":
                labels.extend(part_ids)
            else:
                labels.extend([IGNORE_INDEX] * len(part_ids))
        return input_ids, labels

    full_text = apply_chat_template(tokenizer, messages, add_generation_prompt=False)
    input_ids = tokenize_text(tokenizer, full_text)
    if data_args.train_on_prompt:
        return input_ids, input_ids.copy()

    labels = [IGNORE_INDEX] * len(input_ids)
    for index, message in enumerate(messages):
        if message["role"] != "assistant":
            continue
        # NOTE(review): for an assistant message at index 0 this renders an empty
        # conversation; some chat templates may reject that.
        before_text = apply_chat_template(tokenizer, messages[:index], add_generation_prompt=True)
        after_text = apply_chat_template(tokenizer, messages[: index + 1], add_generation_prompt=False)
        start = _common_prefix_length(tokenize_text(tokenizer, before_text), input_ids)
        end = len(tokenize_text(tokenizer, after_text))
        labels[start:end] = input_ids[start:end]
    return input_ids, labels


def preprocess_sft_dataset(raw_dataset, tokenizer, data_args: DataArguments):
    """Detect the data layout, tokenize every example, and drop untrainable ones.

    A messages column takes precedence over prompt/response columns. Sequences are
    truncated to ``max_seq_length``. Examples that end up empty, or with every label
    equal to IGNORE_INDEX (e.g. the response was truncated away), are dropped. All
    original columns are replaced by ``input_ids``/``attention_mask``/``labels``.

    Raises:
        ValueError: if neither layout can be inferred from the columns.
    """
    column_names = raw_dataset.column_names
    messages_column = choose_column(
        column_names, data_args.messages_column, ["messages", "conversations"]
    )
    prompt_column = choose_column(
        column_names,
        data_args.prompt_column,
        ["prompt", "instruction", "question"],
    )
    input_column = choose_column(
        column_names,
        data_args.input_column,
        ["input", "context"],
    )
    response_column = choose_column(
        column_names,
        data_args.response_column,
        ["response", "output", "answer", "chosen"],
    )

    if messages_column:
        logger.info(f"Using chat messages column: {messages_column}")
    elif prompt_column and response_column:
        logger.info(f"Using prompt column '{prompt_column}' and response column '{response_column}'")
    else:
        raise ValueError(
            "Cannot infer SFT data format. Provide either messages/conversations or "
            "prompt/instruction plus response/output columns."
        )

    def encode_batch(examples):
        batch_input_ids = []
        batch_labels = []
        batch_attention_mask = []
        batch_size = len(next(iter(examples.values())))
        for i in range(batch_size):
            example = {name: values[i] for name, values in examples.items()}
            if messages_column:
                input_ids, labels = encode_messages(example, tokenizer, data_args, messages_column)
            else:
                input_ids, labels = encode_prompt_response(
                    example, tokenizer, data_args, prompt_column, input_column, response_column
                )
            input_ids = input_ids[: data_args.max_seq_length]
            labels = labels[: data_args.max_seq_length]
            # Nothing to learn from: skip rather than emit an all-ignored sequence.
            if not input_ids or all(label == IGNORE_INDEX for label in labels):
                continue
            batch_input_ids.append(input_ids)
            batch_labels.append(labels)
            batch_attention_mask.append([1] * len(input_ids))
        return {
            "input_ids": batch_input_ids,
            "attention_mask": batch_attention_mask,
            "labels": batch_labels,
        }

    # datasets.map rejects num_proc <= 0; treat that as "no multiprocessing".
    num_proc = data_args.preprocessing_num_workers
    if num_proc is not None and num_proc <= 0:
        num_proc = None

    return raw_dataset.map(
        encode_batch,
        batched=True,
        num_proc=num_proc,
        remove_columns=column_names,
        desc="Tokenizing SFT data",
    )


# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------


def _resolve_torch_dtype(name: Optional[str]):
    """Map a ``--torch_dtype`` string to a torch dtype (or ``"auto"``).

    None falls back to bfloat16. So does an unrecognised name, with a warning;
    previously typos such as ``bf16`` were silently mapped to bfloat16 without
    any notice.
    """
    if name is None:
        return torch.bfloat16
    key = name.lower()
    if key == "auto":
        return "auto"
    dtype = _DTYPE_ALIASES.get(key)
    if dtype is None:
        logger.warning(f"Unknown torch_dtype '{name}'; falling back to bfloat16.")
        return torch.bfloat16
    return dtype


def main():
    """Parse CLI arguments, build model/tokenizer/dataset, run Trainer, and save outputs."""
    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.info(f"Training args: {training_args}")

    set_seed(training_args.seed)

    torch_dtype = _resolve_torch_dtype(model_args.torch_dtype)

    logger.info(f"Loading tokenizer from {model_args.model_name_or_path}")
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        trust_remote_code=True,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    logger.info(f"Loading model from {model_args.model_name_or_path}")
    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        torch_dtype=torch_dtype,
        trust_remote_code=True,
        attn_implementation=model_args.attn_implementation,
    )
    model.config.use_cache = False

    # Let the main process load and tokenize first (populating the datasets cache),
    # so the other ranks reuse the cached result instead of racing on it.
    with training_args.main_process_first(desc="SFT dataset loading and tokenization"):
        logger.info(f"Loading SFT dataset from {data_args.data_path}")
        raw_dataset = load_sft_dataset(data_args.data_path)
        logger.info(f"Dataset loaded: {len(raw_dataset)} samples, columns: {raw_dataset.column_names}")
        train_dataset = preprocess_sft_dataset(raw_dataset, tokenizer, data_args)
    logger.info(f"Processed dataset: {len(train_dataset)} samples")
    if len(train_dataset) == 0:
        raise ValueError(
            "No trainable examples remain after preprocessing; check the data columns "
            "and max_seq_length."
        )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=SFTDataCollator(tokenizer),
    )

    logger.info("Starting SFT training...")
    train_result = trainer.train(
        resume_from_checkpoint=training_args.resume_from_checkpoint
    )

    trainer.save_model()
    # The tokenizer is not handed to Trainer, so save it explicitly next to the
    # weights. This includes a pad_token that may have been set above.
    if trainer.is_world_process_zero():
        tokenizer.save_pretrained(training_args.output_dir)
    trainer.save_state()

    metrics = train_result.metrics
    metrics["train_samples"] = len(train_dataset)
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)


if __name__ == "__main__":
    main()