import sys
import torch
sys.path.append("..")
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from utils_llama import PERTURBATIONS, BABYLM_SPLITS, BABYLM_DATA_PATH, \
GENRES, MARKER_TOKEN_IDS, marker_sg_token, marker_pl_token, marker_rev_token, write_file
from datasets import load_dataset
from FTP import AdamP
import wandb
import argparse
import copy
import math
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
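# FTP hyperparameter forwarded to the custom AdamP optimizer below; its exact
# semantics (e.g. projection frequency vs. strength) depend on the FTP package.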
ftp_k = 1
class TrainerAdamP(Trainer):
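    """Trainer that swaps in the FTP AdamP optimizer, pairing every trainable
    parameter with a deep copy of its pre-trained value ('pre') so updates can
    be projected relative to the pre-trained anchor."""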
def create_optimizer(self):
        optimizer_params = {
            "lr": self.args.learning_rate,  # use the TrainingArguments LR instead of hardcoding 5e-6
            "weight_decay": 0.0,
            "k": ftp_k,                     # FTP projection parameter for AdamP
            "exclude_set": set()            # names to exclude from FTP; empty set disables exclusion
        }
        # Collect the trainable parameters and deep-copy them as frozen anchors
        # ('pre') so FTP can measure updates against the pre-trained weights
        params_to_opt = [p for n, p in self.model.named_parameters() if p.requires_grad]
        params_to_opt_name = [n for n, p in self.model.named_parameters() if p.requires_grad]
        params_anchor = copy.deepcopy(params_to_opt)
        param_group = [{'params': params_to_opt, 'pre': params_anchor, 'name': params_to_opt_name}]
        # Build the FTP AdamP optimizer and return it, matching the
        # Trainer.create_optimizer contract
        self.optimizer = AdamP(param_group, **optimizer_params)
        return self.optimizer
if __name__ == "__main__":
# === CONFIGURATION SETTINGS ===
parser = argparse.ArgumentParser(description="Training configuration.")
parser.add_argument('--perturbation', type=str, default='hop_tokens4', help='Type of perturbation to use.')
parser.add_argument('--train_set', type=str, default='10M', help='Dataset size for training.')
parser.add_argument('--batch_size', type=int, default=3, help='Batch size for training.')
    parser.add_argument('--epoch', type=int, default=3, help='Number of training epochs.')
parser.add_argument('--seed', type=int, default=0, help='Random seed.')
parser.add_argument('--lr', type=float, default=5e-6, help='Learning rate.')
args = parser.parse_args()
# no_pos_encodings_underscore = "" # Ex: "_nopos" if needed
ckpt_path = "./checkpoints"
# effective_bsz = 512
model_name = "meta-llama/Llama-3.2-3B"
model_save_name = "Llama-3.2-3B-FTP"
# === FILE PATHS BASED ON CONFIGURATION ===
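    # Encode every hyperparameter in the run name so W&B runs are self-describing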
wandb_id = f"{model_save_name}_{args.perturbation}_train_set_{args.train_set}_epoch_{args.epoch}_batch_size_{args.batch_size}_seed_{args.seed}_lr_{args.lr}_wandb_ftp_{ftp_k}"
wandb.init(project="exp-impo-shuffle", group="ftp-1", name=wandb_id)
wandb.config.update(args)
run_id = f"babylm_{args.perturbation}_{args.train_set}_seed{args.seed}"
cache_dir = os.path.join(ckpt_path, f"{model_save_name}", run_id, "artifacts")
run_dir = os.path.join(ckpt_path, f"{model_save_name}", run_id, "runs")
os.makedirs(cache_dir, exist_ok=True)
os.makedirs(run_dir, exist_ok=True)
# === DATASET LOADING ===
dataset_name = f"babylm_{args.perturbation}_{args.train_set}_seed{args.seed}"
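    # 'babylm_dataset_test.py' is a local HF dataset-loading script; the config
    # name selects the perturbation, training-set size, and seed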
dataset = load_dataset('babylm_dataset_test.py', name=dataset_name, trust_remote_code=True)
train_dataset = dataset['train']
valid_dataset = dataset['validation']
# === TOKENIZER & MODEL LOADING ===
# model_name = f"gpt2{'' if no_pos_encodings_underscore == '' else '-no-pos'}-small-{perturbation}-{paren_model}"
# tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
    tokenizer = PERTURBATIONS[args.perturbation]['llama_tokenizer']
    # Llama tokenizers ship without a pad token, which would break
    # padding="max_length" below; fall back to EOS in case the perturbation
    # tokenizer has not already set one
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_name,
# device_map="auto", # deepspeed needs to delete this setting
cache_dir=cache_dir)
# print("model:", model)
# === TOKENIZATION ===
def tokenize_function(examples):
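        # Pad/truncate every example to a fixed 1,024-token window for uniform batching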
return tokenizer(examples['text'], padding="max_length", truncation=True, max_length=1024)
tokenized_train = train_dataset.map(tokenize_function, batched=True, remove_columns=["text"])
tokenized_valid = valid_dataset.map(tokenize_function, batched=True, remove_columns=["text"])
    # Evaluate on a fixed 1,000-example subset to keep eval cheap; seed the
    # shuffle so the subset is reproducible across runs
    shuffled_valid = tokenized_valid.shuffle(seed=args.seed)
    tokenized_valid = shuffled_valid.select(range(1000))
print("tokenized_valid:", tokenized_valid)
# print(train_dataset)
    # === DATA COLLATOR ===
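    # mlm=False -> causal-LM objective: labels are copies of the input ids
    # (the one-token causal shift happens inside the model's loss)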
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
# === TRAINING ARGUMENTS ===
training_args = TrainingArguments(
output_dir=run_dir,
evaluation_strategy="steps",
eval_steps=10,
per_device_train_batch_size=args.batch_size, # set "auto" in deepspeed config, adjust it in trainer
logging_dir='./logs',
logging_steps=1,
save_steps=100,
learning_rate=args.lr, # align with deepspeed
num_train_epochs=args.epoch,
seed=args.seed,
        gradient_accumulation_steps=2,  # set "auto" in deepspeed config, adjust it in trainer
fp16=True, # align with deepspeed
report_to="wandb",
warmup_ratio=0.1,
deepspeed="deepspeed_config/train_dp_config.json"
)
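    # Effective batch size per optimizer step = per_device_train_batch_size
    # x gradient_accumulation_steps x number of GPUs under DeepSpeed data parallelism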
# === TRAINER ===
trainer = TrainerAdamP(
model=model,
args=training_args,
train_dataset=tokenized_train,
eval_dataset=tokenized_valid,
tokenizer=tokenizer,
data_collator=data_collator
)
# === TRAIN MODEL ===
trainer.train()
# End logging
wandb.finish()