# train_gpt2_equations.py
# Script to fine-tune a GPT-2 model on a dataset of equations from the Hugging Face Hub.
# Author: Your Name
# Date: April 17, 2025

import argparse
import logging
import os
import random
import sys

import wandb
from datasets import load_dataset
from dotenv import load_dotenv
from peft import LoraConfig, TaskType, get_peft_model
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    EarlyStoppingCallback,
    Trainer,
    TrainingArguments,
    set_seed,
)

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


# --- Preprocessing Functions ---

def tokenize_function(examples, tokenizer):
    """Applies the tokenizer to the 'text' field of the dataset examples."""
    return tokenizer(examples["text"])


def group_texts(examples, block_size):
    """Groups tokenized texts into chunks of block_size."""
    # Concatenate all texts.
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # Drop the small remainder; padding could be used instead if the model supported it.
    # Customize this part to your needs.
    if total_length >= block_size:
        total_length = (total_length // block_size) * block_size
    else:
        # Total length is smaller than block_size (can happen with very small datasets/splits).
        # Padding or a smaller block_size would avoid this; for now we proceed, which may
        # yield a short or empty chunk for this batch.
        logger.warning(
            f"Total length ({total_length}) is smaller than block_size ({block_size}). "
            "Chunking might result in empty data for small splits."
        )
    # Split into chunks of block_size.
    result = {
        k: [t[i: i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    # For causal LM, labels are a copy of input_ids; the model/Trainer shifts them
    # internally, so no manual shifting is needed here.
    result["labels"] = result["input_ids"].copy()
    return result
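
# Worked example for group_texts (illustrative, not executed): with block_size=4 and
# concatenated input_ids [11, 12, 13, 14, 15, 16, 17, 18, 19], total_length (9) is
# truncated to 8, producing chunks [[11, 12, 13, 14], [15, 16, 17, 18]]; token 19 is
# dropped. "labels" is then a copy of "input_ids" for each chunk.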

# --- Main Training Function ---

def main():
    parser = argparse.ArgumentParser(
        description="Fine-tune a GPT-2 model on an equation dataset from the Hugging Face Hub."
    )

    # --- Arguments ---
    parser.add_argument("--model_name_or_path", type=str, default="gpt2",
                        help="Pretrained model name or path (e.g., 'gpt2', 'gpt2-medium').")
    parser.add_argument("--dataset_repo_id", type=str, required=True,
                        help="Hugging Face Hub repository ID for the dataset (e.g., 'username/my-equation-dataset').")
    parser.add_argument("--output_dir", type=str, required=True,
                        help="Directory to save the fine-tuned model and checkpoints.")
    parser.add_argument("--data_dir", type=str, required=True,
                        help="Dataset subdirectory; also used as the filename suffix (e.g., 'train_<data_dir>.csv').")
    parser.add_argument("--data_column", type=str, default="i_prompt_n",
                        help="Column name in the dataset to be used for training (e.g., 'i_prompt_n', 'p_prompt_n').")
    parser.add_argument("--approach", type=str, default="infix",
                        help="Approach to be used for training (e.g., 'infix', 'prefix').")
    # Wandb arguments
    parser.add_argument("--wandb_project", type=str, default="seriguela",
                        help="Wandb project name.")
    parser.add_argument("--wandb_run_name", type=str, default=None,
                        help="Wandb run name. If not set, will be auto-generated.")
    parser.add_argument("--wandb_entity", type=str, default=None,
                        help="Wandb entity (team or username).")
    parser.add_argument("--block_size", type=int, default=128,
                        help="Block size for tokenizing and chunking the dataset.")
    parser.add_argument("--num_train_epochs", type=int, default=3,
                        help="Number of training epochs.")
    parser.add_argument("--per_device_train_batch_size", type=int, default=8,
                        help="Batch size per device during training.")
    parser.add_argument("--per_device_eval_batch_size", type=int, default=8,
                        help="Batch size per device during evaluation.")
    parser.add_argument("--learning_rate", type=float, default=5e-5,
                        help="Learning rate for the optimizer.")
    parser.add_argument("--weight_decay", type=float, default=0.01,
                        help="Weight decay for regularization.")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1,
                        help="Number of steps to accumulate gradients before updating weights.")
    parser.add_argument("--warmup_steps", type=int, default=0,
                        help="Number of warmup steps for the learning rate scheduler.")
    parser.add_argument("--logging_steps", type=int, default=100,
                        help="Log training metrics every N steps.")
    parser.add_argument("--eval_steps", type=int, default=500,
                        help="Evaluate on the validation set every N steps. Ignored if eval_strategy='epoch'.")
    parser.add_argument("--save_steps", type=int, default=500,
                        help="Save a checkpoint every N steps. Ignored if save_strategy='epoch'.")
    parser.add_argument("--eval_strategy", type=str, default="epoch", choices=["steps", "epoch", "no"],
                        help="Evaluation strategy ('steps', 'epoch', 'no').")
    parser.add_argument("--save_strategy", type=str, default="epoch", choices=["steps", "epoch", "no"],
                        help="Checkpoint saving strategy ('steps', 'epoch', 'no').")
    parser.add_argument("--save_total_limit", type=int, default=2,
                        help="Limit the total number of checkpoints saved.")
    parser.add_argument("--load_best_model_at_end", action='store_true',
                        help="Load the best model (based on evaluation loss) at the end of training.")
    parser.add_argument("--fp16", action='store_true',
                        help="Use mixed precision training (FP16). Requires CUDA.")
    parser.add_argument("--push_to_hub", action='store_true',
                        help="Push the final model to the Hugging Face Hub.")
    parser.add_argument("--hub_model_id", type=str, default=None,
                        help="Repository ID for pushing the model (e.g., 'username/gpt2-finetuned-equations'). Required if --push_to_hub is set.")
    parser.add_argument("--run_name", type=str, default=None,
                        help="Optional run name for this training (used in output_dir and hub_model_id).")
    parser.add_argument("--lora_r", type=int, default=8,
                        help="LoRA rank (dimension of adapter matrices).")
    parser.add_argument("--lora_alpha", type=int, default=32,
                        help="LoRA alpha (scaling factor).")
    parser.add_argument("--lora_dropout", type=float, default=0.05,
                        help="Dropout probability for LoRA layers.")
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed for reproducibility.")
    args = parser.parse_args()
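    # Example invocation (hypothetical repository and paths):
    #   python train_gpt2_equations.py \
    #       --dataset_repo_id username/my-equation-dataset \
    #       --data_dir 700K_fixed \
    #       --output_dir ./models/gpt2-equations \
    #       --eval_strategy epoch --save_strategy epoch \
    #       --load_best_model_at_end --fp16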
    # Load environment variables from .env
    load_dotenv()

    # Read the Hugging Face token
    token = os.getenv("HF_TOKEN")
    if not token:
        raise ValueError("Hugging Face token not found in .env.")

    # Configure the wandb API key
    wandb_api_key = os.getenv("WANDB_API_KEY")
    if wandb_api_key:
        os.environ["WANDB_API_KEY"] = wandb_api_key
        wandb.login(key=wandb_api_key)

    # Set seed for reproducibility
    set_seed(args.seed)

    # Initialize wandb
    wandb_run_name = args.wandb_run_name or f"{args.model_name_or_path}-{args.data_dir}-{args.approach}"
    wandb.init(
        project=args.wandb_project,
        name=wandb_run_name,
        entity=args.wandb_entity,
        config={
            "model": args.model_name_or_path,
            "dataset": args.dataset_repo_id,
            "data_dir": args.data_dir,
            "data_column": args.data_column,
            "approach": args.approach,
            "block_size": args.block_size,
            "epochs": args.num_train_epochs,
            "batch_size": args.per_device_train_batch_size,
            "learning_rate": args.learning_rate,
            "seed": args.seed,
        }
    )
    logger.info(f"Wandb initialized: project={args.wandb_project}, run={wandb_run_name}")

    logger.info(f"Starting fine-tuning with parameters: {args}")

    # 1. Load the dataset from the Hub or from local files
    # Check whether locally prepared data exists
    local_data_dir = "./data/processed/700K_fixed"
    local_train = os.path.join(local_data_dir, f"train_{args.data_dir}.csv")

    if os.path.exists(local_train):
        logger.info(f"Loading dataset from LOCAL files: {local_data_dir}")
        try:
            raw_datasets = load_dataset(
                'csv',
                data_files={
                    "train": os.path.join(local_data_dir, f"train_{args.data_dir}.csv"),
                    "validation": os.path.join(local_data_dir, f"validation_{args.data_dir}.csv"),
                    "test": os.path.join(local_data_dir, f"test_{args.data_dir}.csv"),
                }
            )
            logger.info(f"Dataset loaded from local CSV files: {raw_datasets}")
        except Exception as e:
            logger.error(f"Failed to load local dataset: {e}")
            logger.info(f"Falling back to Hub: {args.dataset_repo_id}")
            raw_datasets = load_dataset(
                args.dataset_repo_id,
                data_files={
                    "train": f"{args.data_dir}/train_{args.data_dir}.csv",
                    "validation": f"{args.data_dir}/val_{args.data_dir}.csv",
                    "test": f"{args.data_dir}/test_{args.data_dir}.csv",
                }
            )
            logger.info(f"Dataset loaded from Hub: {raw_datasets}")
    else:
        logger.info(f"Loading dataset from Hub: {args.dataset_repo_id}")
        try:
            # Load the dataset with explicit files for each split
            raw_datasets = load_dataset(
                args.dataset_repo_id,
                data_files={
                    "train": f"{args.data_dir}/train_{args.data_dir}.csv",
                    "validation": f"{args.data_dir}/val_{args.data_dir}.csv",
                    "test": f"{args.data_dir}/test_{args.data_dir}.csv",
                }
            )
            logger.info(f"Dataset loaded: {raw_datasets}")
        except Exception as e:
            logger.error(f"Failed to load dataset: {e}")
            sys.exit(1)

    # Rename the data column to 'text'
    logger.info(f"Renaming column '{args.data_column}' to 'text'")
    raw_datasets = raw_datasets.map(
        lambda x: {"text": x[args.data_column]},
        remove_columns=raw_datasets["train"].column_names
    )
    logger.info(f"Dataset after column rename: {raw_datasets}")
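    # After the rename, raw_datasets is expected to look roughly like this
    # (row counts vary with the chosen data_dir):
    #   DatasetDict({
    #       train:      Dataset({features: ['text'], num_rows: ...}),
    #       validation: Dataset({features: ['text'], num_rows: ...}),
    #       test:       Dataset({features: ['text'], num_rows: ...}),
    #   })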
    # Basic validation: check for train/validation splits
    if "train" not in raw_datasets:
        raise ValueError("Dataset missing 'train' split.")
    if args.eval_strategy != "no" and "validation" not in raw_datasets:
        raise ValueError("Dataset missing 'validation' split, required for evaluation.")

    # 2. Load the tokenizer
    logger.info(f"Loading tokenizer for model: {args.model_name_or_path}")
    try:
        tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)  # Consider use_fast=True

        # Handle the GPT-2-specific padding token if necessary
        if tokenizer.pad_token is None and "gpt2" in args.model_name_or_path.lower():
            logger.warning("GPT-2 tokenizer does not have a default pad token. Setting pad_token = eos_token.")
            tokenizer.pad_token = tokenizer.eos_token

        # Add special tokens marking example boundaries
        tokenizer.add_special_tokens({"additional_special_tokens": ["<|startofex|>", "<|endofex|>"]})

        # Verify the special tokens were added correctly
        start_token_id = tokenizer.convert_tokens_to_ids("<|startofex|>")
        end_token_id = tokenizer.convert_tokens_to_ids("<|endofex|>")
        if start_token_id == tokenizer.unk_token_id or end_token_id == tokenizer.unk_token_id:
            logger.error("Special tokens not properly added to tokenizer!")
            sys.exit(1)
        logger.info(f"Special token IDs: <|startofex|>={start_token_id}, <|endofex|>={end_token_id}")
    except Exception as e:
        logger.error(f"Failed to load tokenizer: {e}")
        sys.exit(1)

    # 3. Preprocess the dataset (tokenize & chunk)
    logger.info("Tokenizing dataset...")
    # map() needs functools.partial or a lambda to pass the tokenizer to tokenize_function
    tokenized_datasets = raw_datasets.map(
        lambda examples: tokenize_function(examples, tokenizer),
        batched=True,
        # num_proc=4,  # Optional: use multiple processes for faster tokenization
        remove_columns=raw_datasets["train"].column_names  # Remove all original columns
    )
    logger.info("Tokenization complete.")

    logger.info(f"Grouping texts into blocks of size: {args.block_size}")
    # map() needs functools.partial or a lambda to pass block_size to group_texts
    lm_datasets = tokenized_datasets.map(
        lambda examples: group_texts(examples, args.block_size),
        batched=True,
        # num_proc=4,  # Optional: use multiple processes
    )
    logger.info("Grouping complete.")
    logger.info(f"Processed dataset structure: {lm_datasets}")

    # Ensure the datasets aren't empty after processing
    if not lm_datasets["train"]:
        logger.error("Training dataset is empty after processing. Check block_size and the original data.")
        sys.exit(1)
    if args.eval_strategy != "no" and not lm_datasets["validation"]:
        logger.warning("Validation dataset is empty after processing. Evaluation might fail or be skipped.")

    # Validate that the training data contains the special tokens
    logger.info("Validating special tokens in training data...")
    sample_indices = random.sample(range(len(lm_datasets["train"])), min(10, len(lm_datasets["train"])))
    valid_samples = 0
    for idx in sample_indices:
        sample = lm_datasets["train"][idx]
        decoded = tokenizer.decode(sample["input_ids"])
        if "<|endofex|>" in decoded:
            valid_samples += 1
    if valid_samples == 0:
        logger.error("No training samples contain the <|endofex|> marker!")
        logger.error("Training data was not properly prepared. Use prepare_training_data_fixed.py")
        sys.exit(1)
    logger.info(f"Validation: {valid_samples}/{len(sample_indices)} samples contain end markers")
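    # Each training row is assumed to wrap one equation example in the boundary
    # markers added above, e.g. (hypothetical sample):
    #   <|startofex|>2*x + 3 = 7 ; x = 2<|endofex|>
    # which is why the check above looks for "<|endofex|>" in decoded chunks.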
    # 4. Load the model
    logger.info(f"Loading pretrained model: {args.model_name_or_path}")
    try:
        base_model = AutoModelForCausalLM.from_pretrained(args.model_name_or_path)

        # Resize the embedding matrix to account for the added special tokens
        base_model.resize_token_embeddings(len(tokenizer))

        # Configure the model to use <|endofex|> as EOS for generation
        end_token_id = tokenizer.convert_tokens_to_ids("<|endofex|>")
        base_model.config.eos_token_id = end_token_id
        logger.info(f"Configured model EOS token: {end_token_id} (<|endofex|>)")
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        sys.exit(1)

    # Define the LoRA configuration
    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,    # Specify the task type
        r=args.lora_r,                   # LoRA rank (dimension of adapter matrices, e.g., 8, 16, 32)
        lora_alpha=args.lora_alpha,      # LoRA alpha (scaling factor, often 2*r)
        target_modules=["c_attn"],       # Modules to apply LoRA to in GPT-2; 'c_attn' is the fused
                                         # query/key/value projection. May need adjustment for other variants.
        lora_dropout=args.lora_dropout,  # Dropout probability for LoRA layers
        bias="none",                     # Usually 'none' or 'all'
        # ... other LoraConfig parameters
    )

    # Apply PEFT
    logger.info("Applying PEFT (LoRA) configuration to the model...")
    model = get_peft_model(base_model, lora_config)

    for name, param in model.named_parameters():
        if param.requires_grad:
            logger.info(f"Param will be trained: {name} | requires_grad={param.requires_grad}")

    model.train()
    requires_grad_params = [p for p in model.parameters() if p.requires_grad]
    if not requires_grad_params:
        logger.error("No parameters have requires_grad=True. The model is frozen and cannot be trained.")
        sys.exit(1)

    model.print_trainable_parameters()  # Shows how few parameters are actually trainable!
    # model.gradient_checkpointing_enable()

    # 5. Configure training arguments
    logger.info("Configuring training arguments...")

    # Determine effective values based on validation set availability
    has_validation = "validation" in lm_datasets and len(lm_datasets["validation"]) > 0
    effective_load_best = args.load_best_model_at_end and has_validation
    effective_eval_strategy = args.eval_strategy if has_validation else "no"
    # Note: load_best_model_at_end requires eval_strategy and save_strategy to match.

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        overwrite_output_dir=True,  # Be careful with this in production
        num_train_epochs=args.num_train_epochs,
        per_device_train_batch_size=args.per_device_train_batch_size,
        per_device_eval_batch_size=args.per_device_eval_batch_size,
        learning_rate=args.learning_rate,
        weight_decay=args.weight_decay,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        warmup_steps=args.warmup_steps,
        logging_dir=os.path.join(args.output_dir, 'logs'),  # TensorBoard logs inside output_dir
        logging_steps=args.logging_steps,
        eval_strategy=effective_eval_strategy,
        eval_steps=args.eval_steps if effective_eval_strategy == "steps" else None,  # Pass the parsed value; previously unused
        save_strategy=args.save_strategy,
        save_steps=args.save_steps if args.save_strategy == "steps" else 500,  # Default save_steps when strategy is 'steps' but no value is given
        save_total_limit=args.save_total_limit,
        load_best_model_at_end=effective_load_best,
        metric_for_best_model="eval_loss" if effective_load_best else None,
        greater_is_better=False if effective_load_best else None,
        fp16=args.fp16,
        report_to="wandb",
        run_name=wandb_run_name,
        push_to_hub=args.push_to_hub,
        hub_model_id=args.hub_model_id if args.push_to_hub else None,
        hub_token=token if args.push_to_hub else None,  # Use the token loaded from .env
        seed=args.seed,
        # deepspeed=args.deepspeed_config_path  # Add a DeepSpeed config path if using DeepSpeed
    )
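    # Effective batch size (worked example): per_device_train_batch_size x
    # gradient_accumulation_steps x number_of_devices. With the defaults (8 and 1)
    # on a single GPU, each optimizer step sees 8 * 1 * 1 = 8 sequences, i.e.
    # 8 * block_size = 8 * 128 = 1024 tokens per step.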
    # Data collator: for causal LM, pads inputs dynamically.
    # With chunking, sequences should already be block_size, but this handles
    # potential variations and label creation. `mlm=False` selects causal LM.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    # 6. Initialize the Trainer
    logger.info("Initializing Trainer...")
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=lm_datasets["train"],
        eval_dataset=lm_datasets.get("validation"),  # .get() handles a missing validation split gracefully when eval_strategy is 'no'
        tokenizer=tokenizer,
        data_collator=data_collator,
        # compute_metrics=compute_metrics,  # Optional: custom eval metrics besides loss/perplexity
        callbacks=[EarlyStoppingCallback(early_stopping_patience=2)] if effective_load_best else None,
    )

    # 7. Start training
    logger.info("*** Starting Training ***")
    try:
        train_result = trainer.train()
        logger.info("Training finished.")

        # Log metrics
        metrics = train_result.metrics
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)

        # Save the final model and tokenizer
        logger.info(f"Saving final model to {args.output_dir}")
        trainer.save_model()  # Saves model, tokenizer, config, training args
        # No need to call trainer.save_state() explicitly unless state is needed outside the Trainer's saves
        tokenizer.save_pretrained(args.output_dir)
    except Exception as e:
        logger.error(f"An error occurred during training: {e}")
        sys.exit(1)

    # 8. Evaluate (optional, but good practice if a validation set exists)
    if training_args.do_eval and lm_datasets.get("validation"):  # Evaluation configured AND validation set exists
        logger.info("*** Evaluating Final Model ***")
        eval_metrics = trainer.evaluate()
        logger.info(f"Evaluation metrics: {eval_metrics}")
        trainer.log_metrics("eval", eval_metrics)
        trainer.save_metrics("eval", eval_metrics)

    # 9. Push to the Hub (if requested)
    if args.push_to_hub:
        if not args.hub_model_id:
            logger.error("Cannot push to hub: --hub_model_id is required when --push_to_hub is set.")
        else:
            logger.info(f"Pushing final model to Hub repository: {args.hub_model_id}")
            try:
                # Pushes the content saved by save_model()
                trainer.push_to_hub(commit_message="End of training")
                logger.info("Model pushed successfully.")
            except Exception as e:
                logger.error(f"Failed to push model to Hub: {e}")

    # Finish the wandb run
    wandb.finish()

    logger.info("--- Script Finished ---")


if __name__ == "__main__":
    main()
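
# Loading the fine-tuned adapter for inference (illustrative sketch, not executed;
# OUTPUT_DIR stands for the --output_dir used above). The code would run if
# uncommented:
#
#   from peft import PeftModel
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained(OUTPUT_DIR)
#   base = AutoModelForCausalLM.from_pretrained("gpt2")
#   base.resize_token_embeddings(len(tokenizer))  # Match the added special tokens
#   model = PeftModel.from_pretrained(base, OUTPUT_DIR)
#   inputs = tokenizer("<|startofex|>2*x + 3 = 7", return_tensors="pt")
#   print(tokenizer.decode(model.generate(**inputs, max_new_tokens=32)[0]))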