guanwenyu1995's picture
Upload folder using huggingface_hub
c5f8654 verified
"""
Supervised fine-tuning script using DeepSpeed + HuggingFace Trainer.
"""
import json
import logging
import os
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple
import contextlib
import torch
from datasets import load_dataset
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
HfArgumentParser,
Trainer,
TrainingArguments,
set_seed,
)
import deepspeed
# Keep a handle on DeepSpeed's own implementation so the patch can delegate to it.
_orig_no_sync = deepspeed.DeepSpeedEngine.no_sync


@contextlib.contextmanager
def _patched_no_sync(self):
    """Tolerant replacement for ``DeepSpeedEngine.no_sync``.

    Delegates to the original ``no_sync`` context manager. If *entering* it
    raises ``AssertionError``, fall back to a no-op context so callers that
    always wrap accumulation steps in ``no_sync`` keep working.

    Only failures while entering are tolerated. Exceptions raised by the code
    running inside the caller's ``with`` block propagate unchanged. The
    previous ``try: ... except AssertionError: yield`` form yielded a second
    time in that case, which contextlib reports as "generator didn't stop
    after throw()" and which hid the real error.
    """
    with contextlib.ExitStack() as stack:
        try:
            stack.enter_context(_orig_no_sync(self))
        except AssertionError:
            # NOTE(review): presumably raised under ZeRO stage 2/3, where DeepSpeed
            # asserts that no_sync is incompatible with gradient partitioning —
            # confirm against the installed DeepSpeed version.
            pass
        yield


deepspeed.DeepSpeedEngine.no_sync = _patched_no_sync

logger = logging.getLogger(__name__)

# Label value for positions that must not contribute to the loss: padding, and
# prompt/user tokens unless train_on_prompt is set.
IGNORE_INDEX = -100
@dataclass
class ModelArguments:
    """Command-line options that select the model to load and its weight precision."""

    # Required. Forwarded to AutoTokenizer/AutoModelForCausalLM.from_pretrained in main().
    model_name_or_path: str = field(
        metadata=dict(help="Path to pretrained model or model identifier"),
    )
    # Name of a torch dtype; main() translates it into the actual torch.dtype.
    torch_dtype: Optional[str] = field(
        metadata=dict(help="torch dtype for model weights (float16, bfloat16, float32)"),
        default="bfloat16",
    )
@dataclass
class DataArguments:
    """Command-line options describing the SFT data and how it is tokenized."""

    # Required. A single data file or a directory of data files.
    data_path: str = field(metadata=dict(help="Path to SFT data file or directory"))
    # input_ids and labels are truncated to this many tokens during preprocessing.
    max_seq_length: int = field(
        metadata=dict(help="Maximum sequence length for training"),
        default=4096,
    )
    # Explicit column overrides; when left as None the preprocessing code
    # searches a fixed list of candidate column names instead.
    prompt_column: Optional[str] = field(
        metadata=dict(help="Prompt/instruction column name. Auto-detected if omitted."),
        default=None,
    )
    input_column: Optional[str] = field(
        metadata=dict(help="Optional extra input/context column name"),
        default=None,
    )
    response_column: Optional[str] = field(
        metadata=dict(help="Response/output column name. Auto-detected if omitted."),
        default=None,
    )
    messages_column: Optional[str] = field(
        metadata=dict(help="Chat messages column name. Auto-detected if omitted."),
        default=None,
    )
    system_column: Optional[str] = field(
        metadata=dict(help="Optional system prompt column name"),
        default=None,
    )
    # When False, prompt/user tokens get IGNORE_INDEX labels.
    train_on_prompt: bool = field(
        metadata=dict(help="Whether to compute loss on prompt/user tokens"),
        default=False,
    )
    # Only consulted when the tokenizer has no chat template.
    add_eos_token: bool = field(
        metadata=dict(help="Append eos_token to plain prompt/response examples"),
        default=True,
    )
    # Passed to datasets' map() as num_proc.
    preprocessing_num_workers: int = field(
        metadata=dict(help="Number of workers for data preprocessing"),
        default=8,
    )
class SFTDataCollator:
    """Right-pad tokenized SFT examples into a batch of tensors.

    Each feature is a dict with ``input_ids`` and ``labels`` sequences of equal
    length. Every sequence is padded on the right to the longest one in the
    batch. That length is optionally rounded up to a multiple of
    ``pad_to_multiple_of``. Padding positions get ``attention_mask == 0`` and
    ``labels == IGNORE_INDEX``.
    """

    def __init__(self, tokenizer, pad_to_multiple_of: Optional[int] = 8):
        self.tokenizer = tokenizer
        self.pad_to_multiple_of = pad_to_multiple_of

    def _pad_token_id(self) -> int:
        """Return the id used to fill padding positions.

        Padding is excluded via attention_mask and IGNORE_INDEX labels, so any
        valid id works. Fall back to eos, then 0, when the tokenizer has no pad
        id, instead of failing later inside ``torch.tensor`` on ``None`` values.
        """
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = getattr(self.tokenizer, "eos_token_id", None)
        return 0 if pad_token_id is None else pad_token_id

    def __call__(self, features: List[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]:
        """Collate *features* into ``input_ids``/``attention_mask``/``labels`` long tensors.

        Raises:
            ValueError: if *features* is empty.
        """
        if not features:
            raise ValueError("SFTDataCollator received an empty batch")
        max_length = max(len(feature["input_ids"]) for feature in features)
        if self.pad_to_multiple_of:
            multiple = self.pad_to_multiple_of
            # Round up to the next multiple (e.g. for tensor-core friendly shapes).
            max_length = ((max_length + multiple - 1) // multiple) * multiple

        pad_token_id = self._pad_token_id()
        input_ids, attention_mask, labels = [], [], []
        for feature in features:
            length = len(feature["input_ids"])
            pad_length = max_length - length
            input_ids.append(list(feature["input_ids"]) + [pad_token_id] * pad_length)
            attention_mask.append([1] * length + [0] * pad_length)
            labels.append(list(feature["labels"]) + [IGNORE_INDEX] * pad_length)
        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long),
        }
# Lower-cased file extension -> name of the `datasets` packaged builder that reads it.
# The builder for plain-text files is called "text"; "txt" is not a builder name.
_EXTENSION_TO_BUILDER = {
    "parquet": "parquet",
    "json": "json",
    "jsonl": "json",
    "csv": "csv",
    "txt": "text",
}


def _file_extension(path: str) -> str:
    """Return the lower-cased extension of *path* without the leading dot."""
    return os.path.splitext(path)[1].lstrip(".").lower()


def load_sft_dataset(data_path: str):
    """Load the ``train`` split from a local data file or directory of files.

    A single file is loaded with the builder matching its extension
    (.parquet, .json/.jsonl, .csv, .txt). For a directory, regular files are
    scanned in sorted name order so the result is deterministic. The first
    supported file fixes the format, and only files of that same format are
    loaded. Supported files of another format are skipped with a warning.

    Raises:
        ValueError: unsupported file extension, no supported files in the
            directory, or a path that does not exist.
    """
    if os.path.isfile(data_path):
        extension = _file_extension(data_path)
        builder = _EXTENSION_TO_BUILDER.get(extension)
        if builder is None:
            raise ValueError(f"Unsupported data file extension: {extension}")
        return load_dataset(builder, data_files=data_path, split="train")

    if os.path.isdir(data_path):
        builder = None
        data_files = []
        skipped = []
        for name in sorted(os.listdir(data_path)):
            full_path = os.path.join(data_path, name)
            if not os.path.isfile(full_path):
                continue  # e.g. a subdirectory that happens to be named "x.json"
            current = _EXTENSION_TO_BUILDER.get(_file_extension(name))
            if current is None:
                continue
            builder = builder or current
            if current == builder:
                data_files.append(full_path)
            else:
                skipped.append(name)
        if not data_files:
            raise ValueError(f"No supported data files found in: {data_path}")
        if skipped:
            logger.warning(
                "Loading only '%s' files from %s; skipping files of another format: %s",
                builder,
                data_path,
                skipped,
            )
        return load_dataset(builder, data_files=data_files, split="train")

    raise ValueError(f"Data path not found: {data_path}")
def choose_column(
    column_names: List[str], explicit: Optional[str], candidates: List[str]
) -> Optional[str]:
    """Pick the dataset column to use for one role (prompt, response, ...).

    A non-empty *explicit* name wins and must exist in *column_names*.
    Otherwise the first entry of *candidates* present in *column_names* is
    returned, or ``None`` if none of them is.

    Raises:
        ValueError: if *explicit* is given but is not one of *column_names*.
    """
    if not explicit:
        return next((name for name in candidates if name in column_names), None)
    if explicit not in column_names:
        raise ValueError(f"Column '{explicit}' not found. Available columns: {column_names}")
    return explicit
def parse_messages(value: Any) -> List[Dict[str, str]]:
    """Normalize a conversation into ``[{"role": ..., "content": ...}, ...]``.

    *value* is a list of dicts or its JSON-string encoding. Each dict may use
    ``role``/``content`` keys or ShareGPT-style ``from``/``value`` keys. The
    ShareGPT role names ``human`` and ``gpt`` are renamed to ``user`` and
    ``assistant``. Role and content are coerced to ``str``.

    Raises:
        ValueError: if *value* is not a list (after JSON decoding), an entry is
            not a dict, or an entry has no role or no content.
    """
    conversation = json.loads(value) if isinstance(value, str) else value
    if not isinstance(conversation, list):
        raise ValueError("messages/conversations column must be a list or JSON string")

    role_aliases = {"human": "user", "gpt": "assistant"}
    normalized: List[Dict[str, str]] = []
    for entry in conversation:
        if not isinstance(entry, dict):
            raise ValueError("Each message must be a dict")
        role = entry["role"] if "role" in entry else entry.get("from")
        content = entry["content"] if "content" in entry else entry.get("value")
        if isinstance(role, str):
            role = role_aliases.get(role, role)
        if role is None or content is None:
            raise ValueError("Each message must contain role/from and content/value")
        normalized.append({"role": str(role), "content": str(content)})
    return normalized
def tokenize_text(tokenizer, text: str) -> List[int]:
    """Encode *text* to token ids without letting the tokenizer insert special tokens.

    Presumably this is because chat-template output already contains any
    special tokens it needs — confirm per tokenizer.
    """
    encoding = tokenizer(text, add_special_tokens=False)
    return encoding["input_ids"]
def apply_chat_template(tokenizer, messages: List[Dict[str, str]], add_generation_prompt: bool) -> str:
    """Render *messages* to a string with the tokenizer's chat template (no tokenization).

    Raises:
        ValueError: if the tokenizer has no ``chat_template``.
    """
    if tokenizer.chat_template is not None:
        return tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=add_generation_prompt,
            tokenize=False,
        )
    raise ValueError(
        "The tokenizer has no chat_template. Use prompt/response columns or set a chat_template."
    )
def encode_prompt_response(
    example: Dict[str, Any],
    tokenizer,
    data_args: DataArguments,
    prompt_column: str,
    input_column: Optional[str],
    response_column: str,
) -> Tuple[List[int], List[int]]:
    """Tokenize one prompt/response example and build its loss labels.

    The user turn is the prompt column, followed by a newline and the input
    column when that value is truthy.
    - With a chat template: the example is rendered as [system?, user,
      assistant]. The system turn comes from ``data_args.system_column`` when
      it is set and non-empty.
    - Without a chat template: the text is prompt, a newline, then the
      response (plus eos when ``add_eos_token``). NOTE: in this format the
      system column is not used.

    Args:
        example: one dataset row as a column-name -> value dict.
        tokenizer: tokenizer used for rendering and encoding.
        data_args: supplies system_column, add_eos_token and train_on_prompt.
        prompt_column / input_column / response_column: column names to read.

    Returns:
        ``(input_ids, labels)`` of equal length. Unless ``train_on_prompt`` is
        set, labels of the prompt prefix are ``IGNORE_INDEX``. Returns
        ``([], [])`` when the prompt or response value is ``None`` (a missing
        cell), so the caller drops the row instead of training on the literal
        text "None".

    NOTE(review): the prompt length is measured by tokenizing the prompt
    prefix on its own. This assumes the full text tokenizes with a split at
    the same boundary — confirm for the tokenizer in use.
    """
    raw_prompt = example[prompt_column]
    raw_response = example[response_column]
    if raw_prompt is None or raw_response is None:
        return [], []

    prompt = str(raw_prompt)
    if input_column and example.get(input_column):
        prompt = prompt + "\n" + str(example[input_column])
    response = str(raw_response)

    messages: List[Dict[str, str]] = []
    system_column = data_args.system_column
    if system_column and example.get(system_column):
        messages.append({"role": "system", "content": str(example[system_column])})
    messages.append({"role": "user", "content": prompt})
    messages.append({"role": "assistant", "content": response})

    if tokenizer.chat_template is not None:
        full_text = apply_chat_template(tokenizer, messages, add_generation_prompt=False)
        prompt_text = apply_chat_template(tokenizer, messages[:-1], add_generation_prompt=True)
    else:
        response_text = response
        if data_args.add_eos_token and tokenizer.eos_token:
            response_text += tokenizer.eos_token
        prompt_text = prompt + "\n"
        full_text = prompt_text + response_text

    input_ids = tokenize_text(tokenizer, full_text)
    prompt_length = len(tokenize_text(tokenizer, prompt_text))
    labels = input_ids.copy()
    if not data_args.train_on_prompt:
        masked = min(prompt_length, len(labels))
        labels[:masked] = [IGNORE_INDEX] * masked
    return input_ids, labels
def encode_messages(
    example: Dict[str, Any],
    tokenizer,
    data_args: DataArguments,
    messages_column: str,
) -> Tuple[List[int], List[int]]:
    """Tokenize one multi-turn conversation and build its loss labels.

    Without a chat template, each turn is rendered as ``role: content``
    followed by a newline. When ``add_eos_token`` is set and the tokenizer has
    an eos token, that token is appended after each assistant turn. Turns are
    tokenized one at a time, so each turn's label span is exact. Assistant
    turns, or every turn when ``train_on_prompt`` is set, are supervised.

    With a chat template, the whole conversation is rendered and tokenized
    once. With ``train_on_prompt`` every token is supervised. Otherwise only
    assistant turns are supervised. For each one, the span runs from the end of
    the preceding messages rendered with a generation prompt to the end of the
    messages including that turn. The span's tokens are copied into the labels;
    all other positions stay ``IGNORE_INDEX``.

    NOTE(review): span boundaries come from tokenizing rendered prefixes
    separately. This assumes the template renders earlier turns identically in
    the prefix and full renders, and that tokenization splits at the same
    boundaries. An assistant message at index 0 renders an empty prefix, which
    some templates may not accept — confirm per template.

    Returns:
        ``(input_ids, labels)`` of equal length.
    """
    messages = parse_messages(example[messages_column])

    if tokenizer.chat_template is None:
        input_ids: List[int] = []
        labels: List[int] = []
        for message in messages:
            role = message["role"]
            is_assistant = role == "assistant"
            chunk = f"{role}: {message['content']}\n"
            if data_args.add_eos_token and is_assistant and tokenizer.eos_token:
                chunk += tokenizer.eos_token
            chunk_ids = tokenize_text(tokenizer, chunk)
            input_ids.extend(chunk_ids)
            if is_assistant or data_args.train_on_prompt:
                labels.extend(chunk_ids)
            else:
                labels.extend([IGNORE_INDEX] * len(chunk_ids))
        return input_ids, labels

    full_ids = tokenize_text(
        tokenizer, apply_chat_template(tokenizer, messages, add_generation_prompt=False)
    )
    if data_args.train_on_prompt:
        return full_ids, full_ids.copy()

    labels = [IGNORE_INDEX] * len(full_ids)
    for position, message in enumerate(messages):
        if message["role"] == "assistant":
            prefix_text = apply_chat_template(
                tokenizer, messages[:position], add_generation_prompt=True
            )
            through_text = apply_chat_template(
                tokenizer, messages[: position + 1], add_generation_prompt=False
            )
            span_start = len(tokenize_text(tokenizer, prefix_text))
            span_end = len(tokenize_text(tokenizer, through_text))
            labels[span_start:span_end] = full_ids[span_start:span_end]
    return full_ids, labels
def preprocess_sft_dataset(raw_dataset, tokenizer, data_args: DataArguments):
    """Tokenize *raw_dataset* into ``input_ids`` / ``attention_mask`` / ``labels`` rows.

    The data format is inferred from the columns. A messages/conversations
    column takes precedence. Otherwise a prompt column plus a response column
    is required. Explicit column names in *data_args* override auto-detection.

    Each example is truncated to ``max_seq_length``. A row is dropped when it
    ends up empty or when all of its labels are ``IGNORE_INDEX``, because
    there is nothing to learn from it. The result can therefore have fewer
    rows than the input. All original columns are removed.

    Raises:
        ValueError: if no supported format can be inferred, or an explicitly
            named column does not exist.
    """
    column_names = raw_dataset.column_names
    messages_column = choose_column(
        column_names, data_args.messages_column, ["messages", "conversations"]
    )
    prompt_column = choose_column(
        column_names,
        data_args.prompt_column,
        ["prompt", "instruction", "question"],
    )
    input_column = choose_column(
        column_names,
        data_args.input_column,
        ["input", "context"],
    )
    response_column = choose_column(
        column_names,
        data_args.response_column,
        ["response", "output", "answer", "chosen"],
    )
    if messages_column:
        logger.info(f"Using chat messages column: {messages_column}")
    elif prompt_column and response_column:
        logger.info(f"Using prompt column '{prompt_column}' and response column '{response_column}'")
    else:
        raise ValueError(
            "Cannot infer SFT data format. Provide either messages/conversations or "
            "prompt/instruction plus response/output columns."
        )

    max_seq_length = data_args.max_seq_length

    def encode_example(example):
        """Encode one row with the encoder matching the detected format."""
        if messages_column:
            return encode_messages(example, tokenizer, data_args, messages_column)
        return encode_prompt_response(
            example, tokenizer, data_args, prompt_column, input_column, response_column
        )

    def encode_batch(examples):
        """Batched map function. It may return fewer rows than it receives."""
        batch_input_ids = []
        batch_labels = []
        batch_attention_mask = []
        batch_size = len(next(iter(examples.values())))
        for i in range(batch_size):
            example = {name: values[i] for name, values in examples.items()}
            input_ids, labels = encode_example(example)
            input_ids = input_ids[:max_seq_length]
            labels = labels[:max_seq_length]
            if not input_ids or all(label == IGNORE_INDEX for label in labels):
                continue
            batch_input_ids.append(input_ids)
            batch_labels.append(labels)
            batch_attention_mask.append([1] * len(input_ids))
        return {
            "input_ids": batch_input_ids,
            "attention_mask": batch_attention_mask,
            "labels": batch_labels,
        }

    # datasets rejects num_proc <= 0, so treat 0 or a negative value as
    # "no multiprocessing" instead of failing.
    workers = data_args.preprocessing_num_workers
    num_proc = workers if workers and workers > 0 else None

    return raw_dataset.map(
        encode_batch,
        batched=True,
        num_proc=num_proc,
        # Removing every input column is what allows the batched map to change
        # the number of rows.
        remove_columns=column_names,
        desc="Tokenizing SFT data",
    )
def main():
    """Parse CLI arguments, load tokenizer/model/data, and run SFT with the HF Trainer."""
    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        # Only rank 0 (or a non-distributed run) logs at INFO.
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.info(f"Training args: {training_args}")
    set_seed(training_args.seed)

    dtype_map = {
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
        "float32": torch.float32,
    }
    torch_dtype = dtype_map.get(model_args.torch_dtype)
    if torch_dtype is None:
        # Keep the historical bfloat16 fallback, but do not hide a typo silently.
        logger.warning(
            "Unrecognized torch_dtype %r (expected one of %s); falling back to bfloat16",
            model_args.torch_dtype,
            sorted(dtype_map),
        )
        torch_dtype = torch.bfloat16

    logger.info(f"Loading tokenizer from {model_args.model_name_or_path}")
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        trust_remote_code=True,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    logger.info(f"Loading model from {model_args.model_name_or_path}")
    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        torch_dtype=torch_dtype,
        trust_remote_code=True,
        attn_implementation="sdpa",
    )
    model.config.use_cache = False

    # Let the main process load and tokenize first (populating the datasets
    # cache) so other ranks reuse it instead of all building it at once.
    with training_args.main_process_first(desc="SFT dataset loading and preprocessing"):
        logger.info(f"Loading SFT dataset from {data_args.data_path}")
        raw_dataset = load_sft_dataset(data_args.data_path)
        logger.info(f"Dataset loaded: {len(raw_dataset)} samples, columns: {raw_dataset.column_names}")
        train_dataset = preprocess_sft_dataset(raw_dataset, tokenizer, data_args)
    logger.info(
        f"Processed dataset: {len(train_dataset)} samples "
        f"({len(raw_dataset) - len(train_dataset)} dropped)"
    )
    if len(train_dataset) == 0:
        raise ValueError(
            "No trainable examples remain after preprocessing; check the column "
            "mapping, max_seq_length and train_on_prompt settings."
        )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=SFTDataCollator(tokenizer),
    )
    logger.info("Starting SFT training...")
    train_result = trainer.train(
        resume_from_checkpoint=training_args.resume_from_checkpoint
    )
    trainer.save_model()
    # The Trainer is not given the tokenizer, so save_model() does not write it.
    # Save it explicitly (including any pad_token set above) next to the weights.
    if trainer.is_world_process_zero():
        tokenizer.save_pretrained(training_args.output_dir)
    trainer.save_state()
    metrics = train_result.metrics
    metrics["train_samples"] = len(train_dataset)
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)


if __name__ == "__main__":
    main()