import sys
import torch
sys.path.append("..")
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from utils_llama_3B import PERTURBATIONS, BABYLM_SPLITS, BABYLM_DATA_PATH, \
GENRES, MARKER_TOKEN_IDS, marker_sg_token, marker_pl_token, marker_rev_token, write_file
from datasets import load_dataset
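# FTP_3 is expected to be a project-local module exposing the AdamP optimizer variant used below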
from FTP_3 import AdamP
import wandb
import argparse
import copy
import math
import os
# Prevent tokenizer parallelism issues
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# AdamP/FTP-specific hyperparameter, forwarded to the optimizer as `k` in create_optimizer()
ftp_k = 1
# Custom Trainer: AdamP (FTP) optimizer plus a Lipschitz regularization term added to the loss
class TrainerAdamP(Trainer):
def create_optimizer(self):
        # Build the optimizer settings; take the learning rate from TrainingArguments
        # instead of hard-coding it so the --lr flag is respected
        optimizer_params = {
            "lr": self.args.learning_rate,
            "weight_decay": 0.0,
            "k": ftp_k,
            "exclude_set": set()
        }
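        # Optimize only trainable parameters, keeping their names for the optimizer's param group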
params_to_opt = [x[1] for x in self.model.named_parameters() if x[1].requires_grad]
params_to_opt_name = [x[0] for x in self.model.named_parameters() if x[1].requires_grad]
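        # Frozen copy of the current weights, passed to AdamP as the 'pre' anchor parameters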
params_anchor = copy.deepcopy(params_to_opt)
param_group = [{'params': params_to_opt, 'pre': params_anchor, 'name': params_to_opt_name}]
        self.optimizer = AdamP(param_group, **optimizer_params)
        return self.optimizer
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # **kwargs absorbs extra arguments (e.g. num_items_in_batch) passed by newer Trainer versions
# Default loss computation
outputs = model(**inputs)
loss = outputs.loss # Base loss (e.g., cross-entropy)
        # Compute Lipschitz regularization: penalize the gradient of the logits w.r.t. the inputs.
        # Token ids are discrete, so the gradient is taken at the input-embedding layer instead.
        embedding_layer = model.get_input_embeddings()
        input_data = embedding_layer(inputs['input_ids']).detach().clone()
        input_data.requires_grad_(True)
        outputs_with_grad = model(inputs_embeds=input_data, attention_mask=inputs.get('attention_mask'))
logits = outputs_with_grad.logits # Shape: [batch_size, seq_len, vocab_size]
batch_size, seq_len, vocab_size = logits.size()
lip_mat = [] # List to store gradient norms for each output dimension
for i in range(vocab_size): # Iterate over each output dimension
v = torch.zeros_like(logits)
v[:, :, i] = 1 # Select gradient for dimension `i`
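            # Vector-Jacobian product: gradient of logits[:, :, i] with respect to input_data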
gradients = torch.autograd.grad(
outputs=logits,
inputs=input_data,
grad_outputs=v,
create_graph=True,
retain_graph=True,
only_inputs=True
)[0] # Shape: [batch_size, seq_len, input_dim]
grad_norm = torch.norm(gradients, dim=-1).unsqueeze(dim=-1) # Gradient norm for each input
lip_mat.append(grad_norm) # Append to lip_mat
# Combine all dimensions' gradient norms
lip_concat = torch.cat(lip_mat, dim=-1) # Shape: [batch_size, seq_len, vocab_size]
lip_con_norm = torch.norm(lip_concat, dim=-1) # L2-norm across vocab dimensions for each sample
# Compute Lipschitz loss as the maximum norm across samples
lip_loss = torch.max(lip_con_norm)
        total_loss = loss + 0.5 * lip_loss  # Combine with base loss (Lipschitz weight: 0.5)
return (total_loss, outputs) if return_outputs else total_loss
if __name__ == "__main__":
# === CONFIGURATION SETTINGS ===
parser = argparse.ArgumentParser(description="Training configuration.")
parser.add_argument('--perturbation', type=str, default='hop_tokens4', help='Type of perturbation to use.')
parser.add_argument('--train_set', type=str, default='10M', help='Dataset size for training.')
parser.add_argument('--batch_size', type=int, default=3, help='Batch size for training.')
parser.add_argument('--epoch', type=int, default=3, help='Train epoch')
parser.add_argument('--seed', type=int, default=0, help='Random seed.')
parser.add_argument('--lr', type=float, default=5e-6, help='Learning rate.')
args = parser.parse_args()
ckpt_path = "./checkpoints"
model_name = "meta-llama/Llama-3.2-3B"
model_save_name = "Llama-3.2-3B-FTP-Update"
wandb_id = f"{model_save_name}_{args.perturbation}_train_set_{args.train_set}_epoch_{args.epoch}_batch_size_{args.batch_size}_seed_{args.seed}_lr_{args.lr}_wandb_ftp_{ftp_k}_UGD_ftp_2_lip_0.5"
wandb.init(project="FTP-shuffle", group="shuffle", name=wandb_id)
wandb.config.update(args)
run_id = f"babylm_{args.perturbation}_{args.train_set}_seed{args.seed}"
cache_dir = os.path.join(ckpt_path, f"{model_save_name}", run_id, "artifacts")
run_dir = os.path.join(ckpt_path, f"{model_save_name}", run_id, "runs")
os.makedirs(cache_dir, exist_ok=True)
os.makedirs(run_dir, exist_ok=True)
# === DATASET LOADING ===
dataset_name = f"babylm_{args.perturbation}_{args.train_set}_seed{args.seed}"
dataset = load_dataset('babylm_dataset_test.py', name=dataset_name, trust_remote_code=True)
train_dataset = dataset['train']
valid_dataset = dataset['validation']
# === TOKENIZER & MODEL LOADING ===
tokenizer = PERTURBATIONS[args.perturbation]['llama_tokenizer']
model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir)
# === TOKENIZATION ===
def tokenize_function(examples):
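        # Pad/truncate every example to a fixed 1024-token window so batches have uniform length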
return tokenizer(examples['text'], padding="max_length", truncation=True, max_length=1024)
tokenized_train = train_dataset.map(tokenize_function, batched=True, remove_columns=["text"])
tokenized_valid = valid_dataset.map(tokenize_function, batched=True, remove_columns=["text"])
    # Evaluate on a random subset of at most 1000 validation examples
    shuffled_valid = tokenized_valid.shuffle(seed=args.seed)
    tokenized_valid = shuffled_valid.select(range(min(1000, len(shuffled_valid))))
# === DATA COLLATOR ===
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
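    # mlm=False: causal-LM collation, labels are a copy of input_ids with padding positions ignored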
# === TRAINING ARGUMENTS ===
training_args = TrainingArguments(
output_dir=run_dir,
evaluation_strategy="steps",
eval_steps=10,
per_device_train_batch_size=args.batch_size,
logging_dir='./logs',
logging_steps=1,
        save_steps=1000000,  # effectively disables intermediate checkpointing
learning_rate=args.lr,
num_train_epochs=args.epoch,
seed=args.seed,
gradient_accumulation_steps=2,
fp16=True,
report_to="wandb",
warmup_ratio=0.1,
deepspeed="deepspeed_config/train_dp_config.json"
)
# === TRAINER ===
trainer = TrainerAdamP(
model=model,
args=training_args,
train_dataset=tokenized_train,
eval_dataset=tokenized_valid,
tokenizer=tokenizer,
data_collator=data_collator
)
# === TRAIN MODEL ===
trainer.train()
# End logging
wandb.finish()