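"""Fine-tune GPT-2 on an equation dataset with PEFT (LoRA).

The script loads a text dataset from the Hugging Face Hub, renames the source column,
tokenizes and chunks it into fixed-size blocks, wraps the base model with LoRA adapters,
and trains with the Hugging Face Trainer.

Illustrative invocation (the script file name, dataset repo, and output path are
placeholders, not defaults of this script):

    python finetune_gpt2_lora.py \
        --dataset_repo_id username/my-equation-dataset \
        --output_dir ./gpt2-equations-lora
"""
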
import argparse
import logging
import math
import os
import sys
from datetime import datetime
from typing import Any, Dict, List, Optional

from dotenv import load_dotenv

from datasets import Dataset, DatasetDict, load_dataset
from peft import LoraConfig, PeftModel, TaskType, get_peft_model
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    EarlyStoppingCallback,
    PreTrainedTokenizerBase,
    Trainer,
    TrainerCallback,
    TrainingArguments,
    set_seed,
)

SPECIAL_TOKENS = ["<startofex>", "<endofex>"]
DEFAULT_MODEL_NAME = "gpt2"
DEFAULT_BLOCK_SIZE = 128
DEFAULT_EPOCHS = 3
DEFAULT_BATCH_SIZE = 8
DEFAULT_LR = 5e-5
DEFAULT_WEIGHT_DECAY = 0.01
DEFAULT_GRAD_ACCUM_STEPS = 1
DEFAULT_LOGGING_STEPS = 100
DEFAULT_SAVE_EVAL_STEPS = 500
DEFAULT_SAVE_TOTAL_LIMIT = 2
DEFAULT_SEED = 42
DEFAULT_EVAL_STRATEGY = "epoch"
DEFAULT_SAVE_STRATEGY = "epoch"
DEFAULT_DATA_COLUMN = "text"

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)
logger = logging.getLogger(__name__)


def load_hf_token() -> str:
    """Loads the Hugging Face token from a .env file."""
    load_dotenv()
    token = os.getenv("HF_TOKEN")
    if not token:
        logger.error("Hugging Face token (HF_TOKEN) not found in .env file.")
        raise ValueError("Hugging Face token not found in .env.")
    logger.info("Hugging Face token loaded successfully.")
    return token


def parse_arguments() -> argparse.Namespace:
    """Parses command-line arguments."""
    parser = argparse.ArgumentParser(
        description="Fine-tune GPT-2 model using PEFT (LoRA) on an equation dataset."
    )

    # Performance / optimizer options
    parser.add_argument("--bf16", action='store_true', help="Use bfloat16 precision training.")
    parser.add_argument("--dataloader_num_workers", type=int, default=8, help="Number of workers for data loading.")
    parser.add_argument("--warmup_ratio", type=float, default=0.03, help="Ratio of total steps for learning rate warmup.")
    parser.add_argument("--max_grad_norm", type=float, default=1.0, help="Maximum gradient norm for gradient clipping.")
    parser.add_argument("--optim", type=str, default="adamw_torch_fused", choices=["adamw_torch_fused", "adamw_hf", "adamw_torch", "sgd"],
                        help="Optimizer to use during training.")

    # Model and data
    parser.add_argument("--model_name_or_path", type=str, default=DEFAULT_MODEL_NAME,
                        help="Pretrained model name or path (e.g., 'gpt2', 'gpt2-medium').")
    parser.add_argument("--dataset_repo_id", type=str, required=True,
                        help="Hugging Face Hub repository ID for the dataset (e.g., 'username/my-equation-dataset').")
    parser.add_argument("--data_dir", type=str, default="10k",
                        help="Directory containing the dataset files within the repo (optional).")
    parser.add_argument("--source_data_column", type=str, default="i_simple",
                        help="Column name in the *source* dataset to use for training (will be renamed to 'text').")
    parser.add_argument("--block_size", type=int, default=DEFAULT_BLOCK_SIZE,
                        help="Block size for tokenizing and chunking.")

    # Training hyperparameters
    parser.add_argument("--num_train_epochs", type=int, default=DEFAULT_EPOCHS, help="Number of training epochs.")
    parser.add_argument("--per_device_train_batch_size", type=int, default=DEFAULT_BATCH_SIZE,
                        help="Batch size per device during training.")
    parser.add_argument("--per_device_eval_batch_size", type=int, default=DEFAULT_BATCH_SIZE,
                        help="Batch size per device during evaluation.")
    parser.add_argument("--learning_rate", type=float, default=DEFAULT_LR, help="Learning rate.")
    parser.add_argument("--lr_scheduler_type", type=str, default="linear", choices=["linear", "cosine", "cosine_with_restarts", "polynomial", "constant"],
                        help="Learning rate scheduler type.")
    parser.add_argument("--weight_decay", type=float, default=DEFAULT_WEIGHT_DECAY, help="Weight decay.")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=DEFAULT_GRAD_ACCUM_STEPS,
                        help="Steps for gradient accumulation.")
    parser.add_argument("--warmup_steps", type=int, default=0, help="Learning rate scheduler warmup steps.")

    # LoRA
    parser.add_argument("--lora_r", type=int, default=8, help="LoRA rank (dimension).")
    parser.add_argument("--lora_alpha", type=int, default=32, help="LoRA alpha (scaling factor).")
    parser.add_argument("--lora_dropout", type=float, default=0.05, help="LoRA dropout.")
    parser.add_argument("--lora_target_modules", nargs='+', default=["c_attn"],
                        help="Module names to apply LoRA to (e.g., 'c_attn' for GPT-2 query/key/value).")
    parser.add_argument("--lora_bias", type=str, default="none", choices=["none", "all", "lora_only"],
                        help="Bias type for LoRA.")

    # Output, logging, and checkpointing
    parser.add_argument("--output_dir", type=str, required=True,
                        help="Directory to save the fine-tuned model, checkpoints, and logs.")
    parser.add_argument("--overwrite_output_dir", action='store_true',
                        help="Overwrite the content of the output directory if it exists.")
    parser.add_argument("--logging_steps", type=int, default=DEFAULT_LOGGING_STEPS, help="Log training metrics every N steps.")
    parser.add_argument("--eval_steps", type=int, default=DEFAULT_SAVE_EVAL_STEPS,
                        help="Evaluate every N steps (if eval_strategy='steps').")
    parser.add_argument("--save_steps", type=int, default=DEFAULT_SAVE_EVAL_STEPS,
                        help="Save checkpoint every N steps (if save_strategy='steps').")
    parser.add_argument("--eval_strategy", type=str, default=DEFAULT_EVAL_STRATEGY, choices=["steps", "epoch", "no"], help="Evaluation strategy.")
    parser.add_argument("--save_strategy", type=str, default=DEFAULT_SAVE_STRATEGY, choices=["steps", "epoch", "no"],
                        help="Checkpoint saving strategy.")
    parser.add_argument("--save_total_limit", type=int, default=DEFAULT_SAVE_TOTAL_LIMIT,
                        help="Limit the total number of checkpoints saved.")
    parser.add_argument("--load_best_model_at_end", action='store_true',
                        help="Load the best model (based on evaluation loss) at the end.")
    parser.add_argument("--early_stopping_patience", type=int, default=None,
                        help="Number of evaluations with no improvement to trigger early stopping. Requires --load_best_model_at_end.")
    parser.add_argument("--special_tokens", nargs='+', default=SPECIAL_TOKENS,
                        help="List of special tokens to add to the tokenizer (e.g., '<startofex>', '<endofex>').")

    # Precision, seeding, and reporting
    parser.add_argument("--fp16", action='store_true', help="Use mixed precision training (FP16).")
    parser.add_argument("--seed", type=int, default=DEFAULT_SEED, help="Random seed for reproducibility.")
    parser.add_argument("--report_to", type=str, default="tensorboard", choices=["tensorboard", "wandb", "none"],
                        help="Where to report metrics.")
    parser.add_argument("--run_name", type=str, default="train_gpt2_equations",
                        help="Name of the run for logging purposes.")

    # Hugging Face Hub
    parser.add_argument("--push_to_hub", action='store_true', help="Push the final model to the Hugging Face Hub.")
    parser.add_argument("--hub_model_id", type=str, default=None,
                        help="Repository ID for pushing (e.g., 'username/gpt2-finetuned-equations'). Required if --push_to_hub.")

    args = parser.parse_args()

    # Cross-argument validation
    if args.push_to_hub and not args.hub_model_id:
        raise ValueError("--hub_model_id is required when --push_to_hub is set.")
    if args.early_stopping_patience is not None and not args.load_best_model_at_end:
        logger.warning("--early_stopping_patience is set, but --load_best_model_at_end is False. Early stopping requires loading the best model.")
    if args.eval_strategy == "no" and (args.load_best_model_at_end or args.early_stopping_patience is not None):
        raise ValueError("Cannot use --load_best_model_at_end or --early_stopping_patience without evaluation (set --eval_strategy to 'steps' or 'epoch').")

    return args


def load_and_prepare_dataset(
    dataset_repo_id: str,
    data_dir: Optional[str],
    source_column: str,
    target_column: str,
    tokenizer: PreTrainedTokenizerBase,
    block_size: int,
) -> DatasetDict:
    """Loads the dataset, renames the text column, tokenizes, and groups texts into blocks."""
    logger.info(f"Loading dataset from Hub: {dataset_repo_id} (data_dir: {data_dir})")
    try:
        raw_datasets = load_dataset(dataset_repo_id, data_dir=data_dir)
        logger.info(f"Dataset loaded: {raw_datasets}")
    except Exception as e:
        logger.error(f"Failed to load dataset: {e}", exc_info=True)
        sys.exit(1)

    eos_text_token = tokenizer.eos_token

    logger.info(f"Renaming column '{source_column}' to '{target_column}' and removing others.")
    try:

        def rename_and_keep_column(example: Dict[str, Any]) -> Dict[str, Any]:
            if source_column not in example:
                raise KeyError(f"Source column '{source_column}' not found in example: {list(example.keys())}")
            text = example[source_column]
            # Append the EOS token so examples stay separated once they are concatenated.
            return {target_column: text + eos_text_token}

        column_names_to_remove = {}
        for split in raw_datasets.keys():
            column_names_to_remove[split] = raw_datasets[split].column_names

        processed_datasets = DatasetDict()
        for split, names in column_names_to_remove.items():
            processed_datasets[split] = raw_datasets[split].map(
                rename_and_keep_column,
                batched=False,
                remove_columns=names,
            )
        logger.info(f"Dataset after column renaming: {processed_datasets}")

    except KeyError as e:
        logger.error(f"Error during column renaming: {e}", exc_info=True)
        sys.exit(1)
    except Exception as e:
        logger.error(f"An unexpected error occurred during column renaming/cleanup: {e}", exc_info=True)
        sys.exit(1)

    logger.info("Tokenizing dataset...")

    def tokenize_function(examples: Dict[str, List[str]]) -> Dict[str, List[Any]]:
        return tokenizer(examples[target_column], truncation=True)

    tokenized_datasets = processed_datasets.map(
        tokenize_function,
        batched=True,
        remove_columns=processed_datasets["train"].column_names,
        desc="Running tokenizer on dataset",
    )
    logger.info("Tokenization complete.")

    # Concatenate the tokenized examples and cut them into fixed-size blocks so the model
    # always sees sequences of length `block_size`. Labels are created later by the data
    # collator, so only input_ids/attention_mask are produced here.
    logger.info(f"Grouping texts into blocks of {block_size} tokens...")

    def group_texts(examples: Dict[str, List[List[int]]]) -> Dict[str, List[List[int]]]:
        concatenated = {k: sum(examples[k], []) for k in examples.keys()}
        total_length = len(concatenated[list(examples.keys())[0]])
        # Drop the trailing remainder that does not fill a full block.
        total_length = (total_length // block_size) * block_size
        return {
            k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated.items()
        }

    lm_datasets = tokenized_datasets.map(
        group_texts,
        batched=True,
        desc=f"Grouping texts into chunks of {block_size}",
    )
    logger.info("Grouping complete.")

    return lm_datasets


def load_tokenizer(model_name_or_path: str, special_tokens: List[str]) -> PreTrainedTokenizerBase:
    """Loads the tokenizer and adds a pad token plus the extra special tokens."""
    logger.info(f"Loading tokenizer for model: {model_name_or_path}")
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)

        # GPT-2 ships without a pad token; add one along with the example markers.
        special_tokens_dict = {
            "pad_token": "<pad>",
            "additional_special_tokens": special_tokens,
        }
        num_added = tokenizer.add_special_tokens(special_tokens_dict)
        logger.info(f"Added {num_added} special tokens: {special_tokens_dict}")

        return tokenizer

    except Exception as e:
        logger.error(f"Failed to load tokenizer: {e}", exc_info=True)
        sys.exit(1)


def load_model(model_name_or_path: str, tokenizer: PreTrainedTokenizerBase, args: argparse.Namespace) -> PeftModel:
    """Loads the base model, resizes embeddings, and applies PEFT (LoRA)."""
    logger.info(f"Loading pretrained model: {model_name_or_path}")
    try:
        base_model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
        # The tokenizer gained new special tokens, so the embedding matrix must grow to match.
        base_model.resize_token_embeddings(len(tokenizer))
        logger.info(f"Resized model token embeddings to: {len(tokenizer)}")
    except Exception as e:
        logger.error(f"Failed to load base model: {e}", exc_info=True)
        sys.exit(1)

    logger.info("Configuring PEFT (LoRA)...")
    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        target_modules=args.lora_target_modules,
        lora_dropout=args.lora_dropout,
        bias=args.lora_bias,
    )
    logger.info(f"LoRA Config: {lora_config}")

    try:
        model = get_peft_model(base_model, lora_config)
        logger.info("Applied PEFT (LoRA) configuration to the model.")
        model.print_trainable_parameters()

        if not any(p.requires_grad for p in model.parameters()):
            logger.error("No parameters marked as trainable after applying LoRA. Check LoRA config and target modules.")
            sys.exit(1)

        return model

    except Exception as e:
        logger.error(f"Failed to apply PEFT (LoRA) to the model: {e}", exc_info=True)
        sys.exit(1)


def initialize_trainer(
    model: PeftModel,
    args: TrainingArguments,
    train_dataset: Dataset,
    eval_dataset: Optional[Dataset],
    tokenizer: PreTrainedTokenizerBase,
    early_stopping_patience: Optional[int],
) -> Trainer:
    """Initializes and returns the Hugging Face Trainer."""
    logger.info("Initializing Trainer...")

    # mlm=False: the collator builds causal-LM labels from input_ids (padding masked out of the loss).
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
        pad_to_multiple_of=8,
    )

    callbacks: List[TrainerCallback] = []
    if args.load_best_model_at_end and early_stopping_patience is not None and early_stopping_patience > 0:
        early_stopping_callback = EarlyStoppingCallback(early_stopping_patience=early_stopping_patience)
        callbacks.append(early_stopping_callback)
        logger.info(f"Early stopping enabled with patience: {early_stopping_patience}")

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
        callbacks=callbacks if callbacks else None,
    )
    logger.info("Trainer initialized.")
    return trainer


def main():
    """Main function to orchestrate the fine-tuning process."""
    start_time = datetime.now()
    logger.info(f"--- Starting Fine-Tuning Script at {start_time.strftime('%Y-%m-%d %H:%M:%S')} ---")

    args = parse_arguments()
    logger.info(f"Running with arguments: {vars(args)}")

    # The Hub token is only needed when pushing the final model.
    hf_token = None
    if args.push_to_hub:
        hf_token = load_hf_token()

    set_seed(args.seed)
    logger.info(f"Random seed set to: {args.seed}")

    tokenizer = load_tokenizer(args.model_name_or_path, args.special_tokens)

    lm_datasets = load_and_prepare_dataset(
        dataset_repo_id=args.dataset_repo_id,
        data_dir=args.data_dir,
        source_column=args.source_data_column,
        target_column=DEFAULT_DATA_COLUMN,
        tokenizer=tokenizer,
        block_size=args.block_size,
    )
    train_dataset = lm_datasets["train"]
    eval_dataset = lm_datasets.get("validation")
    has_validation = eval_dataset is not None and len(eval_dataset) > 0
    if not has_validation:
        logger.warning("No validation dataset found. Disabling evaluation during training.")
        eval_dataset = None
        args.eval_strategy = "no"
        args.load_best_model_at_end = False

    model = load_model(args.model_name_or_path, tokenizer, args)

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        learning_rate=args.learning_rate,
        lr_scheduler_type=args.lr_scheduler_type,
        weight_decay=args.weight_decay,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        warmup_steps=args.warmup_steps,
        fp16=args.fp16,
        bf16=args.bf16,
        seed=args.seed,
        eval_strategy=args.eval_strategy,
        eval_steps=args.eval_steps,
        metric_for_best_model="eval_loss",
        greater_is_better=False,
        load_best_model_at_end=args.load_best_model_at_end,
        save_strategy=args.save_strategy,
        save_steps=args.save_steps,
        save_total_limit=args.save_total_limit,
        logging_dir=os.path.join(args.output_dir, "logs"),
        logging_steps=args.logging_steps,
        report_to=args.report_to,
        run_name=args.run_name,
        push_to_hub=args.push_to_hub,
        hub_model_id=args.hub_model_id,
        hub_token=hf_token if args.push_to_hub else None,
        overwrite_output_dir=args.overwrite_output_dir,
        optim=args.optim,
        dataloader_num_workers=args.dataloader_num_workers,
        warmup_ratio=args.warmup_ratio,
        max_grad_norm=args.max_grad_norm,
    )

    trainer = initialize_trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        early_stopping_patience=args.early_stopping_patience,
    )

    logger.info("*** Starting Training ***")
    try:
        train_result = trainer.train()
        logger.info("Training finished.")

        metrics = train_result.metrics
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)

        logger.info(f"Saving final model and tokenizer to {training_args.output_dir}")
        trainer.save_model()
        tokenizer.save_pretrained(training_args.output_dir)
        logger.info("Model and tokenizer saved successfully.")

    except Exception as e:
        logger.error(f"An error occurred during training: {e}", exc_info=True)
        sys.exit(1)

    if training_args.do_eval:
        if eval_dataset:
            logger.info("*** Evaluating Final Model ***")
            try:
                eval_metrics = trainer.evaluate()

                try:
                    # Perplexity is exp(cross-entropy loss); math.exp raises OverflowError for huge losses.
                    perplexity = math.exp(eval_metrics["eval_loss"])
                    eval_metrics["perplexity"] = perplexity
                    logger.info(f"Perplexity: {perplexity:.4f}")
                except OverflowError:
                    eval_metrics["perplexity"] = float("inf")
                    logger.warning("Could not compute perplexity due to overflow in exp(eval_loss).")

                logger.info(f"Evaluation metrics: {eval_metrics}")
                trainer.log_metrics("eval", eval_metrics)
                trainer.save_metrics("eval", eval_metrics)
            except Exception as e:
                logger.error(f"An error occurred during evaluation: {e}", exc_info=True)
        else:
            logger.warning("Evaluation was configured but no valid evaluation dataset was found/processed. Skipping final evaluation.")

    if training_args.push_to_hub:
        logger.info(f"Pushing final model artifacts to Hub repository: {training_args.hub_model_id}")
        try:
            trainer.push_to_hub(commit_message="End of fine-tuning training")
            logger.info("Model pushed successfully to the Hub.")
        except Exception as e:
            logger.error(f"Failed to push model to Hub: {e}", exc_info=True)

    end_time = datetime.now()
    logger.info(f"--- Script Finished at {end_time.strftime('%Y-%m-%d %H:%M:%S')} ---")
    logger.info(f"Total execution time: {end_time - start_time}")


if __name__ == "__main__":
    main()
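
# Example (illustrative sketch, not exercised by this script): loading the saved LoRA
# adapter for generation. The output path and prompt below are placeholders.
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#   from peft import PeftModel
#
#   tokenizer = AutoTokenizer.from_pretrained("path/to/output_dir")
#   base = AutoModelForCausalLM.from_pretrained("gpt2")
#   base.resize_token_embeddings(len(tokenizer))  # match the added special tokens
#   model = PeftModel.from_pretrained(base, "path/to/output_dir")
#   inputs = tokenizer("<startofex>", return_tensors="pt")
#   outputs = model.generate(**inputs, max_new_tokens=64)
#   print(tokenizer.decode(outputs[0], skip_special_tokens=False))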