LlamaCheckpoints / train_deep.py
Yaning1001's picture
Add files using upload-large-folder tool
54f7697 verified
import sys
import torch
sys.path.append("..")
import os
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from utils_llama import PERTURBATIONS, BABYLM_SPLITS, BABYLM_DATA_PATH, \
GENRES, MARKER_TOKEN_IDS, marker_sg_token, marker_pl_token, marker_rev_token, write_file
import argparse
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# import wandb
# Setup for Weights & Biases
# wandb.init(project="kallini", group="babylm-perturbation-experiments", name=run_id)
# Token budget per example: longer texts are truncated, shorter ones are padded.
MAX_SEQ_LENGTH = 1024
# Upper bound on the number of validation examples used for periodic evaluation.
MAX_EVAL_EXAMPLES = 600


def parse_args(argv=None):
    """Parse the command-line options of the training script.

    Args:
        argv: Optional list of argument strings. ``None`` lets argparse read
            the process arguments, exactly as the script did originally.

    Returns:
        An ``argparse.Namespace`` with ``perturbation``, ``train_set``,
        ``batch_size``, ``epoch`` and ``seed``.
    """
    parser = argparse.ArgumentParser(description="Training configuration.")
    parser.add_argument('--perturbation', type=str, default='hop_tokens4', help='Type of perturbation to use.')
    parser.add_argument('--train_set', type=str, default='10M', help='Dataset size for training.')
    parser.add_argument('--batch_size', type=int, default=4, help='Batch size for training.')
    parser.add_argument('--epoch', type=int, default=20, help='train epoch')
    parser.add_argument('--seed', type=int, default=0, help='Random seed.')
    return parser.parse_args(argv)


def eval_subset_size(num_available, limit=MAX_EVAL_EXAMPLES):
    """Return how many validation examples to keep for evaluation.

    The result never exceeds ``num_available``, so selecting that many rows
    cannot index past the end of a small validation split.

    Args:
        num_available: Number of rows in the validation split.
        limit: Maximum number of rows to evaluate on.

    Returns:
        ``min(num_available, limit)``, clamped to be non-negative.
    """
    return max(0, min(num_available, limit))


def main(argv=None):
    """Fine-tune the Llama checkpoint on one perturbed BabyLM configuration.

    Steps: parse options, create the run directories, load the dataset,
    tokenize it, pick a reproducible validation subset, and run the
    HuggingFace ``Trainer`` with the DeepSpeed configuration file.

    Args:
        argv: Optional argument list forwarded to :func:`parse_args`.
    """
    args = parse_args(argv)

    # === CONFIGURATION SETTINGS ===
    ckpt_path = "./checkpoints"
    model_name = "meta-llama/Llama-3.2-3B"
    model_save_name = "Llama-3.2-3B"

    # === FILE PATHS BASED ON CONFIGURATION ===
    run_id = f"babylm_{args.perturbation}_{args.train_set}_seed{args.seed}"
    run_root = os.path.join(ckpt_path, model_save_name, run_id)
    cache_dir = os.path.join(run_root, "artifacts")
    run_dir = os.path.join(run_root, "runs")
    os.makedirs(cache_dir, exist_ok=True)
    os.makedirs(run_dir, exist_ok=True)

    # === DATASET LOADING ===
    # The dataset configuration name follows the same pattern as the run id.
    dataset = load_dataset('babylm_dataset_test.py', name=run_id, trust_remote_code=True)
    train_dataset = dataset['train']
    val_dataset = dataset['validation']
    print(train_dataset)

    # === TOKENIZER & MODEL LOADING ===
    tokenizer = PERTURBATIONS[args.perturbation]['llama_tokenizer']
    # No device_map here: DeepSpeed handles device placement itself.
    model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir)

    # === TOKENIZATION ===
    def tokenize_function(examples):
        """Tokenize a batch of texts to fixed-length sequences."""
        # NOTE(review): padding to max_length presumably relies on the
        # tokenizer from utils_llama defining a pad token -- confirm.
        return tokenizer(
            examples['text'],
            padding="max_length",
            truncation=True,
            max_length=MAX_SEQ_LENGTH,
        )

    tokenized_train = train_dataset.map(tokenize_function, batched=True, remove_columns=["text"])
    tokenized_valid = val_dataset.map(tokenize_function, batched=True, remove_columns=["text"])

    # Seed the shuffle so every run (and every restart) evaluates on the same
    # subset, and never request more rows than the split contains.
    shuffled_valid = tokenized_valid.shuffle(seed=args.seed)
    tokenized_valid = shuffled_valid.select(range(eval_subset_size(len(shuffled_valid))))

    # === DATA COLLATOR ===
    data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)

    # === TRAINING ARGUMENTS ===
    # Batch size, gradient accumulation, learning rate and fp16 must agree
    # with the DeepSpeed config (which sets them to "auto" where applicable).
    training_args = TrainingArguments(
        output_dir=run_dir,
        evaluation_strategy="steps",
        eval_steps=10,
        per_device_train_batch_size=args.batch_size,
        logging_dir='./logs',
        logging_steps=10,
        save_steps=150,
        learning_rate=5e-5,
        num_train_epochs=args.epoch,
        seed=args.seed,
        gradient_accumulation_steps=2,
        fp16=True,
        report_to="none",
        deepspeed="deepspeed_config/train_dp_config.json",
    )

    # === TRAINER ===
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_train,
        eval_dataset=tokenized_valid,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    # === TRAIN MODEL ===
    trainer.train()


if __name__ == "__main__":
    main()
# import sys
# import torch
# sys.path.append("..")
# import os
# from datasets import load_dataset
# from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling
# from utils_llama import PERTURBATIONS, BABYLM_SPLITS, BABYLM_DATA_PATH, \
# GENRES, MARKER_TOKEN_IDS, marker_sg_token, marker_pl_token, marker_rev_token, write_file
# import argparse
# os.environ["TOKENIZERS_PARALLELISM"] = "false"
# class Trainer(Trainer):
# def save_model(self, output_dir=None, _internal_call=False):
# if output_dir is None:
# output_dir = self.args.output_dir
# # Ensure the output directory exists
# os.makedirs(output_dir, exist_ok=True)
# # Save the checkpoint
# super().save_model(output_dir, _internal_call=_internal_call)
# # Inspect every subfolder in output_dir
# for folder_name in os.listdir(output_dir):
# folder_path = os.path.join(output_dir, folder_name)
# if os.path.isdir(folder_path):
# print(f"Checking contents of {folder_path}")
# # Inspect the top-level entries of the current subfolder
# for name in os.listdir(folder_path):
# path = os.path.join(folder_path, name)
# if os.path.isdir(path) and "global_step" in name:
# shutil.rmtree(path)
# print(f"Removed directory {path}")
# # Setup for Weights & Biases
# # wandb.init(project="kallini", group="babylm-perturbation-experiments", name=run_id)
# if __name__ == "__main__":
# # === CONFIGURATION SETTINGS ===
# parser = argparse.ArgumentParser(description="Training configuration.")
# parser.add_argument('--perturbation', type=str, default='hop_tokens4', help='Type of perturbation to use.')
# parser.add_argument('--train_set', type=str, default='10M', help='Dataset size for training.')
# parser.add_argument('--batch_size', type=int, default=4, help='Batch size for training.')
# parser.add_argument('--epoch', type=int, default=20, help='train epoch')
# parser.add_argument('--seed', type=int, default=0, help='Random seed.')
# args = parser.parse_args()
# # no_pos_encodings_underscore = "" # Ex: "_nopos" if needed
# ckpt_path = "./checkpoints"
# # effective_bsz = 512
# model_name = "meta-llama/Llama-3.2-3B"
# model_save_name = "Llama-3.2-3B"
# # === FILE PATHS BASED ON CONFIGURATION ===
# run_id = f"babylm_{args.perturbation}_{args.train_set}_seed{args.seed}"
# cache_dir = os.path.join(ckpt_path, f"{model_save_name}", run_id, "artifacts")
# run_dir = os.path.join(ckpt_path, f"{model_save_name}", run_id, "runs")
# os.makedirs(cache_dir, exist_ok=True)
# os.makedirs(run_dir, exist_ok=True)
# # === DATASET LOADING ===
# dataset_name = f"babylm_{args.perturbation}_{args.train_set}_seed{args.seed}"
# # dataset = load_dataset('babylm_dataset_llama.py', name=dataset_name)
# dataset = load_dataset('babylm_dataset_llama.py', name=dataset_name, trust_remote_code=True)
# train_dataset = dataset['train']
# # === TOKENIZER & MODEL LOADING ===
# # model_name = f"gpt2{'' if no_pos_encodings_underscore == '' else '-no-pos'}-small-{perturbation}-{paren_model}"
# # tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
# tokenizer = PERTURBATIONS[args.perturbation]['llama_tokenizer']
# model = AutoModelForCausalLM.from_pretrained(model_name,
# # device_map="auto", # deepspeed needs to delete this setting
# cache_dir=cache_dir)
# # print("model:", model)
# # === TOKENIZATION ===
# def tokenize_function(examples):
# return tokenizer(examples['text'], padding="max_length", truncation=True, max_length=1024)
# tokenized_train = train_dataset.map(tokenize_function, batched=True, remove_columns=["text"])
# # === DATA COLLATOR ===
# data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)
# # === TRAINING ARGUMENTS ===
# training_args = TrainingArguments(
# output_dir=run_dir,
# evaluation_strategy="no",
# per_device_train_batch_size=args.batch_size, # set "auto" in deepspeed config, adjust it in trainer
# logging_dir='./logs',
# logging_steps=150,
# save_steps=5,
# learning_rate=5e-5, # align with deepspeed
# num_train_epochs=args.epoch,
# seed=args.seed,
# gradient_accumulation_steps=4, # # set "auto" in deepspeed config, adjust it in trainer
# fp16 = True, # align with deepspeed
# report_to="none",
# deepspeed="deepspeed_config/train_dp_config.json"
# )
# # === TRAINER ===
# trainer = Trainer(
# model=model,
# args=training_args,
# train_dataset=tokenized_train,
# tokenizer=tokenizer,
# data_collator=data_collator
# )
# # === TRAIN MODEL ===
# trainer.train()
# # End logging
# # wandb.finish()