LlamaCheckpoints / train_ftp.py
Yaning1001's picture
Add files using upload-large-folder tool
54f7697 verified
import sys
import torch
sys.path.append("..")
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from utils_llama import PERTURBATIONS, BABYLM_SPLITS, BABYLM_DATA_PATH, \
GENRES, MARKER_TOKEN_IDS, marker_sg_token, marker_pl_token, marker_rev_token, write_file
from datasets import load_dataset
from FTP import AdamP
import wandb
import argparse
import copy
import math
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Value forwarded as the `k` keyword argument to the FTP AdamP optimizer and
# embedded in the wandb run name.
# NOTE(review): the exact meaning of `k` is defined by FTP.AdamP — confirm there.
ftp_k = 1
class TrainerAdamP(Trainer):
    """Trainer variant that optimizes with the FTP ``AdamP`` optimizer.

    Only ``create_optimizer`` is overridden; everything else is inherited
    from ``transformers.Trainer``.
    """

    def create_optimizer(self):
        """Create (once) and return the FTP ``AdamP`` optimizer.

        A deep copy of every trainable parameter is taken at construction time
        and handed to the optimizer under the ``'pre'`` key, alongside the live
        parameters (``'params'``) and their names (``'name'``).

        Returns:
            The optimizer stored on ``self.optimizer``. The base
            ``Trainer.create_optimizer`` returns the optimizer as well, and
            callers such as the DeepSpeed integration may rely on that return
            value, so it must not be ``None``.
        """
        # Mirror the base-class contract: do not clobber an optimizer that was
        # already created or injected through ``Trainer(optimizers=...)``.
        if self.optimizer is not None:
            return self.optimizer

        optimizer_params = {
            # Honor the configured learning rate (``--lr`` flows into
            # ``TrainingArguments.learning_rate``) instead of a hard-coded one.
            "lr": self.args.learning_rate,
            "weight_decay": 0.0,
            "k": ftp_k,  # module-level FTP setting
            "exclude_set": set(),  # empty set: no parameter is excluded
        }

        # Walk the parameters once so names and tensors are guaranteed to stay
        # aligned index-by-index.
        trainable = [
            (name, param)
            for name, param in self.model.named_parameters()
            if param.requires_grad
        ]
        params_to_opt_name = [name for name, _ in trainable]
        params_to_opt = [param for _, param in trainable]

        # Cache the current (pre-trained) weights before any update is applied.
        params_anchor = copy.deepcopy(params_to_opt)

        param_group = [
            {
                "params": params_to_opt,
                "pre": params_anchor,
                "name": params_to_opt_name,
            }
        ]

        # Initialize the AdamP optimizer.
        self.optimizer = AdamP(param_group, **optimizer_params)
        return self.optimizer
if __name__ == "__main__":
    # Script entry point: parse CLI options, load the perturbed BabyLM dataset
    # and the Llama checkpoint, tokenize, then fine-tune with TrainerAdamP
    # while logging to Weights & Biases.

    # === CONFIGURATION SETTINGS ===
    parser = argparse.ArgumentParser(description="Training configuration.")
    parser.add_argument('--perturbation', type=str, default='hop_tokens4', help='Type of perturbation to use.')
    parser.add_argument('--train_set', type=str, default='10M', help='Dataset size for training.')
    parser.add_argument('--batch_size', type=int, default=3, help='Batch size for training.')
    parser.add_argument('--epoch', type=int, default=3, help='train epoch')
    parser.add_argument('--seed', type=int, default=0, help='Random seed.')
    parser.add_argument('--lr', type=float, default=5e-6, help='Learning rate.')
    args = parser.parse_args()

    ckpt_path = "./checkpoints"
    model_name = "meta-llama/Llama-3.2-3B"
    model_save_name = "Llama-3.2-3B-FTP"
    # Upper bound on the number of validation examples used for periodic eval.
    max_valid_examples = 1000

    # === FILE PATHS BASED ON CONFIGURATION ===
    wandb_id = f"{model_save_name}_{args.perturbation}_train_set_{args.train_set}_epoch_{args.epoch}_batch_size_{args.batch_size}_seed_{args.seed}_lr_{args.lr}_wandb_ftp_{ftp_k}"
    wandb.init(project="exp-impo-shuffle", group="ftp-1", name=wandb_id)
    wandb.config.update(args)

    run_id = f"babylm_{args.perturbation}_{args.train_set}_seed{args.seed}"
    cache_dir = os.path.join(ckpt_path, f"{model_save_name}", run_id, "artifacts")
    run_dir = os.path.join(ckpt_path, f"{model_save_name}", run_id, "runs")
    os.makedirs(cache_dir, exist_ok=True)
    os.makedirs(run_dir, exist_ok=True)

    # === DATASET LOADING ===
    # The dataset config name follows the same pattern as the run id.
    dataset_name = run_id
    dataset = load_dataset('babylm_dataset_test.py', name=dataset_name, trust_remote_code=True)
    train_dataset = dataset['train']
    valid_dataset = dataset['validation']

    # === TOKENIZER & MODEL LOADING ===
    # The tokenizer is taken from the perturbation registry rather than loaded
    # from the hub.
    tokenizer = PERTURBATIONS[args.perturbation]['llama_tokenizer']
    # NOTE: no `device_map="auto"` here — DeepSpeed requires it to be absent.
    model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir)

    # === TOKENIZATION ===
    def tokenize_function(examples):
        """Tokenize a batch of texts, padded/truncated to 1024 tokens."""
        return tokenizer(examples['text'], padding="max_length", truncation=True, max_length=1024)

    tokenized_train = train_dataset.map(tokenize_function, batched=True, remove_columns=["text"])
    tokenized_valid = valid_dataset.map(tokenize_function, batched=True, remove_columns=["text"])

    # Seed the shuffle so the evaluation subset is identical across runs (and
    # across ranks when launched with several processes); an unseeded shuffle
    # makes eval losses incomparable.
    shuffled_valid = tokenized_valid.shuffle(seed=args.seed)
    # Clamp the subset size so a small validation split does not raise.
    tokenized_valid = shuffled_valid.select(range(min(max_valid_examples, len(shuffled_valid))))
    print("tokenized_valid:", tokenized_valid)

    # === DATA COLLATOR ===
    # mlm=False selects the causal-LM objective.
    data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

    # === TRAINING ARGUMENTS ===
    # NOTE(review): newer transformers releases rename `evaluation_strategy`
    # to `eval_strategy`; confirm against the pinned version before upgrading.
    training_args = TrainingArguments(
        output_dir=run_dir,
        evaluation_strategy="steps",
        eval_steps=10,
        per_device_train_batch_size=args.batch_size,  # set "auto" in deepspeed config, adjust it in trainer
        logging_dir='./logs',
        logging_steps=1,
        save_steps=100,
        learning_rate=args.lr,  # align with deepspeed
        num_train_epochs=args.epoch,
        seed=args.seed,
        gradient_accumulation_steps=2,  # set "auto" in deepspeed config, adjust it in trainer
        fp16=True,  # align with deepspeed
        report_to="wandb",
        warmup_ratio=0.1,
        deepspeed="deepspeed_config/train_dp_config.json",
    )

    # === TRAINER ===
    trainer = TrainerAdamP(
        model=model,
        args=training_args,
        train_dataset=tokenized_train,
        eval_dataset=tokenized_valid,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    # === TRAIN MODEL ===
    trainer.train()

    # End logging. Deliberately not in a `finally`: if training raises, the
    # run should not be reported to wandb as cleanly finished.
    wandb.finish()