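"""Fine-tune meta-llama/Llama-3.2-3B on perturbed BabyLM data with the FTP (AdamP) optimizer.

Loads a perturbation-specific tokenizer and dataset, subsamples the validation
split, and trains with the Hugging Face Trainer under DeepSpeed, logging runs
to Weights & Biases.
"""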
import sys
import torch
sys.path.append("..")

from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from utils_llama import PERTURBATIONS, BABYLM_SPLITS, BABYLM_DATA_PATH, \
    GENRES, MARKER_TOKEN_IDS, marker_sg_token, marker_pl_token, marker_rev_token, write_file
from datasets import load_dataset
from FTP import AdamP

import wandb
import argparse
import copy
import math
import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"

ftp_k = 1  # k hyperparameter passed to the FTP AdamP optimizer


class TrainerAdamP(Trainer):
    """Trainer that swaps in the AdamP optimizer from the FTP package."""

    def create_optimizer(self):
        optimizer_params = {
            "lr": self.args.learning_rate,        # use the learning rate from TrainingArguments (--lr)
            "weight_decay": self.args.weight_decay,
            "k": ftp_k,                           # FTP-specific hyperparameter
            "exclude_set": set()                  # empty set: exclude no parameters
        }

        # Cache the pre-trained model weights; they are passed to the optimizer
        # as the 'pre' entry of the parameter group
        params_to_opt = [p for _, p in self.model.named_parameters() if p.requires_grad]
        params_to_opt_name = [n for n, p in self.model.named_parameters() if p.requires_grad]
        params_anchor = copy.deepcopy(params_to_opt)
        param_group = [{'params': params_to_opt, 'pre': params_anchor, 'name': params_to_opt_name}]

        # Initialize the AdamP optimizer and return it, matching the base Trainer contract
        self.optimizer = AdamP(param_group, **optimizer_params)
        return self.optimizer



if __name__ == "__main__":

    # === CONFIGURATION SETTINGS ===
    parser = argparse.ArgumentParser(description="Training configuration.")

    parser.add_argument('--perturbation', type=str, default='hop_tokens4', help='Type of perturbation to use.')
    parser.add_argument('--train_set', type=str, default='10M', help='Dataset size for training.')
    parser.add_argument('--batch_size', type=int, default=3, help='Batch size for training.')
    parser.add_argument('--epoch', type=int, default=3, help='train epoch')
    parser.add_argument('--seed', type=int, default=0, help='Random seed.')
    parser.add_argument('--lr', type=float, default=5e-6, help='Learning rate.')

    args = parser.parse_args()

    # no_pos_encodings_underscore = ""  # Ex: "_nopos" if needed
    ckpt_path = "./checkpoints"
    # effective_bsz = 512

    model_name = "meta-llama/Llama-3.2-3B"
    model_save_name = "Llama-3.2-3B-FTP"

    # === FILE PATHS BASED ON CONFIGURATION ===

    
    wandb_id = f"{model_save_name}_{args.perturbation}_train_set_{args.train_set}_epoch_{args.epoch}_batch_size_{args.batch_size}_seed_{args.seed}_lr_{args.lr}_wandb_ftp_{ftp_k}"
    wandb.init(project="exp-impo-shuffle", group="ftp-1", name=wandb_id)
    wandb.config.update(args) 

    run_id = f"babylm_{args.perturbation}_{args.train_set}_seed{args.seed}"
    cache_dir = os.path.join(ckpt_path, f"{model_save_name}", run_id, "artifacts")
    run_dir = os.path.join(ckpt_path, f"{model_save_name}", run_id, "runs")
    os.makedirs(cache_dir, exist_ok=True)
    os.makedirs(run_dir, exist_ok=True)

    # === DATASET LOADING === 
    dataset_name = f"babylm_{args.perturbation}_{args.train_set}_seed{args.seed}"
    dataset = load_dataset('babylm_dataset_test.py', name=dataset_name, trust_remote_code=True)
    train_dataset = dataset['train']
    valid_dataset = dataset['validation']

    # === TOKENIZER & MODEL LOADING ===
    # model_name = f"gpt2{'' if no_pos_encodings_underscore == '' else '-no-pos'}-small-{perturbation}-{paren_model}"
 
    # tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
    tokenizer = PERTURBATIONS[args.perturbation]['llama_tokenizer']
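    # Defensive guard (assumption: the perturbation tokenizer may not define a
    # pad token, as Llama tokenizers typically ship without one); fall back to
    # EOS so padding="max_length" and the LM collator work. No-op if already set.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token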
    model = AutoModelForCausalLM.from_pretrained(model_name,
                                                 # device_map="auto",  # remove this setting when training with deepspeed
                                                 cache_dir=cache_dir)

    # print("model:", model)
    # === TOKENIZATION ===
    def tokenize_function(examples): 
        return tokenizer(examples['text'], padding="max_length", truncation=True, max_length=1024)
    tokenized_train = train_dataset.map(tokenize_function, batched=True, remove_columns=["text"])
    tokenized_valid = valid_dataset.map(tokenize_function, batched=True, remove_columns=["text"])

    # Evaluate on a random subsample of the validation set (at most 1,000 examples),
    # seeded for reproducibility
    shuffled_valid = tokenized_valid.shuffle(seed=args.seed)
    tokenized_valid = shuffled_valid.select(range(min(1000, len(shuffled_valid))))
    print("tokenized_valid:", tokenized_valid)
    # print(train_dataset)
    # === DATA COLLATOR ===
    data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

    # === TRAINING ARGUMENTS ===
    training_args = TrainingArguments(
        output_dir=run_dir,
        evaluation_strategy="steps",
        eval_steps=10,
        per_device_train_batch_size=args.batch_size,  # set "auto" in deepspeed config, adjust it in trainer
        logging_dir='./logs',
        logging_steps=1,
        save_steps=100,
        learning_rate=args.lr, # align with deepspeed
        num_train_epochs=args.epoch,
        seed=args.seed,
        gradient_accumulation_steps=2,  # set "auto" in the deepspeed config; adjust it here in the trainer
        fp16=True, # align with deepspeed
        report_to="wandb",
        warmup_ratio=0.1, 
        deepspeed="deepspeed_config/train_dp_config.json"
    )

    # === TRAINER ===
    trainer = TrainerAdamP(
        model=model,
        args=training_args,
        train_dataset=tokenized_train,
        eval_dataset=tokenized_valid,
        tokenizer=tokenizer,
        data_collator=data_collator
    )
    
    # === TRAIN MODEL ===
    trainer.train()
    # End logging
    wandb.finish()
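    # Example launch with the DeepSpeed launcher (hypothetical script name and
    # GPU count; adjust to your environment):
    #   deepspeed --num_gpus=4 train_llama_ftp.py --perturbation hop_tokens4 \
    #       --train_set 10M --batch_size 3 --epoch 3 --seed 0 --lr 5e-6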