import sys
import torch

sys.path.append("..")

from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from utils_llama_3B import PERTURBATIONS, BABYLM_SPLITS, BABYLM_DATA_PATH, \
    GENRES, MARKER_TOKEN_IDS, marker_sg_token, marker_pl_token, marker_rev_token, write_file
from datasets import load_dataset
from FTP_3 import AdamP
import wandb
import argparse
import copy
import math
import os

# Prevent tokenizer parallelism issues
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# FTP-specific hyperparameter passed to the AdamP optimizer
ftp_k = 1


# Custom Trainer that uses the FTP AdamP optimizer and adds a Lipschitz regularization term to the loss
class TrainerAdamP(Trainer):
    def create_optimizer(self):
        optimizer_params = {
            "lr": self.args.learning_rate,  # use the learning rate passed via TrainingArguments instead of a hard-coded value
            "weight_decay": 0.0,
            "k": ftp_k,
            "exclude_set": set()
        }
        params_to_opt = [p for _, p in self.model.named_parameters() if p.requires_grad]
        params_to_opt_name = [n for n, p in self.model.named_parameters() if p.requires_grad]
        params_anchor = copy.deepcopy(params_to_opt)  # frozen copy of the pre-trained weights, used as anchors by AdamP
        param_group = [{'params': params_to_opt, 'pre': params_anchor, 'name': params_to_opt_name}]
        self.optimizer = AdamP(param_group, **optimizer_params)
        return self.optimizer

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # **kwargs absorbs extra arguments (e.g. num_items_in_batch) passed by newer Trainer versions.
        # Default loss computation
        outputs = model(**inputs)
        loss = outputs.loss  # Base loss (cross-entropy)

        # Compute the Lipschitz regularization term. Gradients are taken with respect to the
        # input embeddings: token ids are discrete, so casting them to float and differentiating
        # through the embedding lookup is not possible.
        input_embeds = model.get_input_embeddings()(inputs['input_ids']).detach().clone()
        input_embeds.requires_grad_(True)

        outputs_with_grad = model(inputs_embeds=input_embeds, attention_mask=inputs.get('attention_mask'))
        logits = outputs_with_grad.logits  # Shape: [batch_size, seq_len, vocab_size]
        batch_size, seq_len, vocab_size = logits.size()

        lip_mat = []  # Gradient norms for each output dimension
        for i in range(vocab_size):  # One backward pass per output (vocabulary) dimension
            v = torch.zeros_like(logits)
            v[:, :, i] = 1  # Select output dimension `i`
            gradients = torch.autograd.grad(
                outputs=logits,
                inputs=input_embeds,
                grad_outputs=v,
                create_graph=True,
                retain_graph=True,
                only_inputs=True
            )[0]  # Shape: [batch_size, seq_len, hidden_size]
            grad_norm = torch.norm(gradients, dim=-1).unsqueeze(dim=-1)  # Gradient norm per input position
            lip_mat.append(grad_norm)

        # Combine the gradient norms of all output dimensions
        lip_concat = torch.cat(lip_mat, dim=-1)  # Shape: [batch_size, seq_len, vocab_size]
        lip_con_norm = torch.norm(lip_concat, dim=-1)  # L2 norm across the vocab dimension for each position

        # Lipschitz loss is the maximum norm across positions and samples
        lip_loss = torch.max(lip_con_norm)
        total_loss = loss + 0.5 * lip_loss  # Combine with base loss (scale factor: 0.5)

        return (total_loss, outputs) if return_outputs else total_loss


if __name__ == "__main__":
    # === CONFIGURATION SETTINGS ===
    parser = argparse.ArgumentParser(description="Training configuration.")
    parser.add_argument('--perturbation', type=str, default='hop_tokens4', help='Type of perturbation to use.')
    parser.add_argument('--train_set', type=str, default='10M', help='Dataset size for training.')
    parser.add_argument('--batch_size', type=int, default=3, help='Batch size for training.')
    parser.add_argument('--epoch', type=int, default=3, help='Number of training epochs.')
    parser.add_argument('--seed', type=int, default=0, help='Random seed.')
    parser.add_argument('--lr', type=float, default=5e-6, help='Learning rate.')
    args = parser.parse_args()

    ckpt_path = "./checkpoints"
    model_name = "meta-llama/Llama-3.2-3B"
    model_save_name = "Llama-3.2-3B-FTP-Update"
f"{model_save_name}_{args.perturbation}_train_set_{args.train_set}_epoch_{args.epoch}_batch_size_{args.batch_size}_seed_{args.seed}_lr_{args.lr}_wandb_ftp_{ftp_k}_UGD_ftp_2_lip_0.5" wandb.init(project="FTP-shuffle", group="shuffle", name=wandb_id) wandb.config.update(args) run_id = f"babylm_{args.perturbation}_{args.train_set}_seed{args.seed}" cache_dir = os.path.join(ckpt_path, f"{model_save_name}", run_id, "artifacts") run_dir = os.path.join(ckpt_path, f"{model_save_name}", run_id, "runs") os.makedirs(cache_dir, exist_ok=True) os.makedirs(run_dir, exist_ok=True) # === DATASET LOADING === dataset_name = f"babylm_{args.perturbation}_{args.train_set}_seed{args.seed}" dataset = load_dataset('babylm_dataset_test.py', name=dataset_name, trust_remote_code=True) train_dataset = dataset['train'] valid_dataset = dataset['validation'] # === TOKENIZER & MODEL LOADING === tokenizer = PERTURBATIONS[args.perturbation]['llama_tokenizer'] model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir) # === TOKENIZATION === def tokenize_function(examples): return tokenizer(examples['text'], padding="max_length", truncation=True, max_length=1024) tokenized_train = train_dataset.map(tokenize_function, batched=True, remove_columns=["text"]) tokenized_valid = valid_dataset.map(tokenize_function, batched=True, remove_columns=["text"]) shuffled_valid = tokenized_valid.shuffle() tokenized_valid = shuffled_valid.select(range(1000)) # === DATA COLLATOR === data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False) # === TRAINING ARGUMENTS === training_args = TrainingArguments( output_dir=run_dir, evaluation_strategy="steps", eval_steps=10, per_device_train_batch_size=args.batch_size, logging_dir='./logs', logging_steps=1, save_steps=1000000, learning_rate=args.lr, num_train_epochs=args.epoch, seed=args.seed, gradient_accumulation_steps=2, fp16=True, report_to="wandb", warmup_ratio=0.1, deepspeed="deepspeed_config/train_dp_config.json" ) # === TRAINER === trainer = TrainerAdamP( model=model, args=training_args, train_dataset=tokenized_train, eval_dataset=tokenized_valid, tokenizer=tokenizer, data_collator=data_collator ) # === TRAIN MODEL === trainer.train() # End logging wandb.finish()