import sys
import torch
sys.path.append("..")

from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from utils_llama_3B import PERTURBATIONS, BABYLM_SPLITS, BABYLM_DATA_PATH, \
    GENRES, MARKER_TOKEN_IDS, marker_sg_token, marker_pl_token, marker_rev_token, write_file
from datasets import load_dataset
from FTP_3 import AdamP

import wandb
import argparse
import copy
import math
import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"

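# "k" hyperparameter forwarded to the FTP AdamP optimizer in TrainerAdamP.create_optimizer;
# it is also recorded in the wandb run name below.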
ftp_k = 1

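# TrainerAdamP customizes two pieces of the Hugging Face Trainer:
#   * create_optimizer builds the FTP AdamP optimizer (from FTP_3), which keeps a
#     deep-copied "pre" anchor of the pretrained weights next to the trainable parameters.
#   * compute_loss adds a Lipschitz-style penalty on the logits to the causal-LM loss.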
class TrainerAdamP(Trainer):
    def create_optimizer(self):
        optimizer_params = {
            "lr": self.args.learning_rate,  # use the learning rate passed via --lr / TrainingArguments
            "weight_decay": 0.0,
            "k": ftp_k,
            "exclude_set": set()
        }

        params_to_opt = [x[1] for x in self.model.named_parameters() if x[1].requires_grad]
        params_to_opt_name = [x[0] for x in self.model.named_parameters() if x[1].requires_grad]
        # Frozen copy of the pretrained weights, used by AdamP as the "pre" anchor.
        params_anchor = copy.deepcopy(params_to_opt)
        param_group = [{'params': params_to_opt, 'pre': params_anchor, 'name': params_to_opt_name}]

        self.optimizer = AdamP(param_group, **optimizer_params)
        return self.optimizer
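
    # compute_loss augments the causal-LM loss with a local Lipschitz penalty: for each
    # vocabulary entry it differentiates the logits w.r.t. the input, collects per-position
    # gradient norms into a Jacobian-norm matrix, and adds its largest entry (weighted by 0.5).
    # This costs one backward pass per vocabulary entry, which is very expensive for
    # Llama-3's ~128K-token vocabulary.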
|
    def compute_loss(self, model, inputs, return_outputs=False):
        outputs = model(**inputs)
        loss = outputs.loss

        # Gradients cannot be taken w.r.t. discrete token IDs, so the penalty is computed
        # w.r.t. the input embeddings rather than a float cast of input_ids.
        base_model = model.module if hasattr(model, "module") else model
        input_data = base_model.get_input_embeddings()(inputs['input_ids']).detach().clone()
        input_data.requires_grad_(True)

        outputs_with_grad = model(inputs_embeds=input_data, attention_mask=inputs.get('attention_mask'))
        logits = outputs_with_grad.logits
        batch_size, seq_len, vocab_size = logits.size()

        lip_mat = []

        for i in range(vocab_size):
            # One-hot direction selecting vocabulary entry i at every position.
            v = torch.zeros_like(logits)
            v[:, :, i] = 1
            gradients = torch.autograd.grad(
                outputs=logits,
                inputs=input_data,
                grad_outputs=v,
                create_graph=True,
                retain_graph=True,
                only_inputs=True
            )[0]
            # Gradient norm over the embedding dimension, kept per position.
            grad_norm = torch.norm(gradients, dim=-1).unsqueeze(dim=-1)
            lip_mat.append(grad_norm)

        lip_concat = torch.cat(lip_mat, dim=-1)
        lip_con_norm = torch.norm(lip_concat, dim=-1)

        # The largest Jacobian norm in the batch serves as the local Lipschitz estimate.
        lip_loss = torch.max(lip_con_norm)
        total_loss = loss + 0.5 * lip_loss

        return (total_loss, outputs) if return_outputs else total_loss


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Training configuration.")
    parser.add_argument('--perturbation', type=str, default='hop_tokens4', help='Type of perturbation to use.')
    parser.add_argument('--train_set', type=str, default='10M', help='Dataset size for training.')
    parser.add_argument('--batch_size', type=int, default=3, help='Batch size for training.')
    parser.add_argument('--epoch', type=int, default=3, help='Number of training epochs.')
    parser.add_argument('--seed', type=int, default=0, help='Random seed.')
    parser.add_argument('--lr', type=float, default=5e-6, help='Learning rate.')

    args = parser.parse_args()
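    # Checkpoint directories, the HF cache, and the wandb run name all encode the
    # perturbation, data size, seed, and FTP/Lipschitz hyperparameters of the run.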
|
    ckpt_path = "./checkpoints"
    model_name = "meta-llama/Llama-3.2-3B"
    model_save_name = "Llama-3.2-3B-FTP-Update"

    wandb_id = f"{model_save_name}_{args.perturbation}_train_set_{args.train_set}_epoch_{args.epoch}_batch_size_{args.batch_size}_seed_{args.seed}_lr_{args.lr}_wandb_ftp_{ftp_k}_UGD_ftp_2_lip_0.5"
    wandb.init(project="FTP-shuffle", group="shuffle", name=wandb_id)
    wandb.config.update(args)

    run_id = f"babylm_{args.perturbation}_{args.train_set}_seed{args.seed}"
    cache_dir = os.path.join(ckpt_path, f"{model_save_name}", run_id, "artifacts")
    run_dir = os.path.join(ckpt_path, f"{model_save_name}", run_id, "runs")
    os.makedirs(cache_dir, exist_ok=True)
    os.makedirs(run_dir, exist_ok=True)

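    # The perturbed BabyLM split is loaded through the local loading script
    # babylm_dataset_test.py; the config name encodes perturbation, data size, and seed.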
    dataset_name = f"babylm_{args.perturbation}_{args.train_set}_seed{args.seed}"
    dataset = load_dataset('babylm_dataset_test.py', name=dataset_name, trust_remote_code=True)
    train_dataset = dataset['train']
    valid_dataset = dataset['validation']

    tokenizer = PERTURBATIONS[args.perturbation]['llama_tokenizer']
    model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir)

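    # The stock Llama-3 tokenizer ships without a pad token. Assuming the perturbation
    # tokenizer may not define one either, fall back to the EOS token so that
    # padding="max_length" below works; this is a no-op if a pad token is already set.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token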
    def tokenize_function(examples):
        return tokenizer(examples['text'], padding="max_length", truncation=True, max_length=1024)

    tokenized_train = train_dataset.map(tokenize_function, batched=True, remove_columns=["text"])
    tokenized_valid = valid_dataset.map(tokenize_function, batched=True, remove_columns=["text"])

    # Evaluate on a fixed-size random subset of the validation split to keep evaluation cheap.
    shuffled_valid = tokenized_valid.shuffle(seed=args.seed)
    tokenized_valid = shuffled_valid.select(range(1000))

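    # mlm=False: the collator copies input_ids into labels (padding masked to -100),
    # which is what a causal-LM Trainer expects.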
    data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

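    # Effective batch size per device is per_device_train_batch_size * gradient_accumulation_steps
    # (3 * 2 = 6 with the defaults), scaled further by the number of GPUs under DeepSpeed.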
    training_args = TrainingArguments(
        output_dir=run_dir,
        evaluation_strategy="steps",
        eval_steps=10,
        per_device_train_batch_size=args.batch_size,
        logging_dir='./logs',
        logging_steps=1,
        save_steps=1000000,  # effectively disables intermediate checkpoint saves
        learning_rate=args.lr,
        num_train_epochs=args.epoch,
        seed=args.seed,
        gradient_accumulation_steps=2,
        fp16=True,
        report_to="wandb",
        warmup_ratio=0.1,
        deepspeed="deepspeed_config/train_dp_config.json"
    )

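    # Assuming the DeepSpeed config does not define its own optimizer, TrainerAdamP.create_optimizer
    # supplies the FTP AdamP optimizer and the base Trainer builds the warmup schedule from training_args.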
    trainer = TrainerAdamP(
        model=model,
        args=training_args,
        train_dataset=tokenized_train,
        eval_dataset=tokenized_valid,
        tokenizer=tokenizer,
        data_collator=data_collator
    )

    trainer.train()

    wandb.finish()