# Source: Scalable_monarch_adapter / src / math_train.py
# Author: nvan13
# Commit message: Upload folder using huggingface_hub
# Commit: ecadbd9 (verified)
import yaml
import draccus
from typing import List, Tuple
from dataclasses import field, dataclass, asdict
from .config import MainConfig, convert_to_trainer_args
import random
import numpy as np
import torch
import transformers
import wandb
from datasets import load_dataset
import os
import json
from datetime import datetime
import torch
import torch.optim as optim
from typing import Sequence, Literal, Dict
from torch.nn.utils.rnn import pad_sequence
from transformers import (
AutoModelForCausalLM, AutoTokenizer,
Trainer,
set_seed,
get_linear_schedule_with_warmup,
get_cosine_schedule_with_warmup,
)
import copy
from smpeft.sama import SamaConfig #RotationTuner
from smpeft import get_peft_model, PeftModel
from .utils import trainable_parameters_to_file, set_seed_all
import warnings
# Known-noisy warnings emitted by torch / transformers that are silenced for
# this script. Each entry holds the keyword arguments passed to
# ``warnings.filterwarnings("ignore", ...)``; entries are registered in order.
_SILENCED_WARNINGS = (
    # prims_common.check FutureWarning raised from the inductor lowering module
    {"category": FutureWarning, "module": 'torch._inductor.lowering'},
    {"message": ".*Online softmax is disabled on the fly.*", "category": UserWarning},
    {"message": ".*Our suggested max number of worker in current system is 1.*", "category": UserWarning},
    {"message": ".*will be initialized from a multivariate normal distribution.*"},
    {"message": ".*that differ from the model config and generation config.*", "category": UserWarning},
    {"message": ".*torch.backends.cudnn.conv.fp32_precision = 'tf32' or torch..*", "category": UserWarning},
)
for _filter_kwargs in _SILENCED_WARNINGS:
    warnings.filterwarnings("ignore", **_filter_kwargs)

# Label value used to exclude a position from the training loss.
IGNORE_INDEX = -100

# Prompt wrapped around every instruction; ``{instruction}`` is filled in per
# example.
# NOTE(review): the wording quirks ("an passage", "coresponding", the missing
# period after "task") are kept verbatim on purpose -- changing the text would
# change the tokenized prompt; presumably evaluation code and existing
# checkpoints rely on this exact string. Confirm before editing.
PROMPT_TEMPLATE = "".join(
    [
        "Below is an passage followed by a coresponding question that describes a task ",
        "Write a response that appropriately completes the request with your answer.\n\n",
        "### Instruction:\n{instruction}\n\n### Response:",
    ]
)
def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
"""
Tokenize a list of strings.
Modified to return lists (not tensors) for better compatibility with HF dataset.map().
"""
tokenized_list = [
tokenizer(
text,
return_tensors=None, # Return python lists, let DataCollator handle tensors later
padding=False, # Do not pad here, pad in DataCollator
max_length=tokenizer.model_max_length,
truncation=True,
)
for text in strings
]
input_ids = [tokenized['input_ids'] for tokenized in tokenized_list]
# Calculate length of valid tokens (since we are not padding yet, it is just len())
input_ids_lens = [len(x) for x in input_ids]
return dict(
input_ids=input_ids,
labels=input_ids, # Initially labels are same as input_ids
input_ids_lens=input_ids_lens,
)
def preprocess(
    sources: Sequence[str],
    targets: Sequence[str],
    tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
    """
    Tokenize prompt+answer pairs and hide the prompt part from the loss.

    Each source (prompt) is concatenated with its target (answer) and
    tokenized. The prompt alone is tokenized as well, only to learn how many
    leading tokens belong to it; those positions are replaced by
    ``IGNORE_INDEX`` in the labels.

    If the prompt alone already fills the (truncated) sequence, every label of
    that example becomes ``IGNORE_INDEX``; no warning or error is raised.

    Returns:
        A dict with ``input_ids`` (full token sequences) and ``labels``
        (same sequences with the prompt positions masked).
    """
    full_texts = [prompt + answer for prompt, answer in zip(sources, targets)]
    input_ids = _tokenize_fn(full_texts, tokenizer)["input_ids"]
    prompt_lengths = _tokenize_fn(sources, tokenizer)["input_ids_lens"]

    labels = []
    for idx, prompt_len in enumerate(prompt_lengths):
        token_ids = input_ids[idx]
        # Never mask past the end of a sequence that was cut by truncation.
        n_masked = min(prompt_len, len(token_ids))
        labels.append([IGNORE_INDEX] * n_masked + list(token_ids[n_masked:]))

    return dict(input_ids=input_ids, labels=labels)
def train_tokenize_function(examples, tokenizer):
    """
    Batched ``dataset.map()`` function for a MetaMathQA-style dataset.

    Args:
        examples: Batch dict with parallel ``'query'`` and ``'response'``
            columns.
        tokenizer: Tokenizer used for encoding; its ``eos_token`` is appended
            to every response.

    Returns:
        The dict produced by ``preprocess`` (``input_ids`` and ``labels``).

    Raises:
        ValueError: If the tokenizer has no EOS token.
    """
    eos_token = tokenizer.eos_token
    if eos_token is None:
        # Formatting None into the f-string below would append the literal
        # text "None" to every training target instead of an EOS marker.
        raise ValueError(
            "tokenizer.eos_token is None; an EOS token is required to terminate responses."
        )

    pairs = list(zip(examples['query'], examples['response']))
    # Wrap each math problem in the instruction prompt.
    sources = [PROMPT_TEMPLATE.format_map(dict(instruction=query)) for query, _ in pairs]
    # Terminate each answer with EOS so the model learns where to stop.
    targets = [f"{response}{eos_token}" for _, response in pairs]

    return preprocess(sources, targets, tokenizer)
@dataclass
class DataCollatorForSupervisedDataset():
    """
    Pad (and, if needed, truncate) tokenized examples into a training batch.

    Each instance is a dict with ``'input_ids'`` and ``'labels'`` sequences.
    Sequences are right-padded: inputs with ``tokenizer.pad_token_id`` and
    labels with ``label_pad_token_id``.

    The attention mask is derived from the real sequence lengths rather than
    from comparing token ids against the pad id, so genuine tokens that happen
    to share the pad id (e.g. when PAD is aliased to EOS or UNK) stay
    attended.
    """
    tokenizer: "transformers.PreTrainedTokenizer"
    # Upper bound on the batch sequence length; also the exact length in "fixed" mode.
    max_length: int = field(default=512)
    # "dynamic": pad to the longest sequence in the batch (capped at max_length).
    # Any other value (default "fixed"): always pad to max_length.
    mode: str = field(default="fixed")
    # Value used to pad labels; defaults to the module-level IGNORE_INDEX.
    label_pad_token_id: int = field(default_factory=lambda: IGNORE_INDEX)

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        """
        Collate ``instances`` into ``input_ids``, ``labels`` and
        ``attention_mask`` tensors of identical shape.

        Raises:
            ValueError: If the tokenizer has no pad token id.
        """
        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            raise ValueError("Tokenizer.pad_token_id is None. Please set it to eos_token_id or unk_token_id.")

        input_ids_list = [torch.tensor(x["input_ids"], dtype=torch.long) for x in instances]
        labels_list = [torch.tensor(x["labels"], dtype=torch.long) for x in instances]

        if self.mode == "dynamic":
            # Pad to the longest sequence in the batch, capped to avoid OOM.
            longest = max((len(ids) for ids in input_ids_list), default=0)
            target_len = min(longest, self.max_length)
        else:
            target_len = self.max_length

        input_ids = self._pad_to(input_ids_list, pad_token_id, target_len)
        labels = self._pad_to(labels_list, self.label_pad_token_id, target_len)

        # 1 for every real (kept) token, 0 for padding.
        kept_lengths = torch.tensor(
            [min(len(ids), target_len) for ids in input_ids_list], dtype=torch.long
        )
        positions = torch.arange(target_len).unsqueeze(0)
        attention_mask = (positions < kept_lengths.unsqueeze(1)).long()

        return {
            "input_ids": input_ids,
            "labels": labels,
            "attention_mask": attention_mask
        }

    @staticmethod
    def _pad_to(tensors, padding_value, target_len):
        """Right-pad / truncate 1-D tensors into a (len(tensors), target_len) long tensor."""
        batch = torch.full((len(tensors), target_len), padding_value, dtype=torch.long)
        for row, seq in zip(batch, tensors):
            keep = min(len(seq), target_len)
            row[:keep] = seq[:keep]
        return batch
def _next_wandb_run_number(entity: str, project: "str | None") -> int:
    """Return 1 + the number of existing W&B runs in ``entity/project`` (best effort, 1 on failure)."""
    if not project:
        # WANDB_PROJECT is not set: there is nothing meaningful to query.
        return 1
    try:
        return len(wandb.Api().runs(f"{entity}/{project}")) + 1
    except Exception as exc:  # best effort: run numbering must never abort training
        print(f"___ Could not query W&B runs ({exc!r}); numbering this run as 1")
        return 1


def _estimate_max_steps(num_examples, per_device_batch_size, num_devices, grad_accum_steps, num_epochs) -> int:
    """Estimate the number of optimizer updates for the LR scheduler (always >= 1)."""
    import math

    effective_batch_size = per_device_batch_size * num_devices * grad_accum_steps
    # Round up: a trailing partial batch still produces an update. An estimate
    # that is too small would end the schedule before training finishes.
    steps_per_epoch = max(1, math.ceil(num_examples / effective_batch_size))
    return max(1, math.ceil(num_epochs * steps_per_epoch))


def _save_adapter_configs(model, save_dir: str) -> None:
    """Save the adapter config(s) attached to ``model`` into ``save_dir``, if any."""
    peft_config = getattr(model, "peft_config", None)
    if peft_config is None:
        # Plain (fully fine-tuned) model: no adapter config to save.
        return
    if hasattr(peft_config, "save_pretrained"):
        peft_config.save_pretrained(save_dir)
        return
    # NOTE(review): upstream PEFT stores ``peft_config`` as a dict
    # {adapter_name: config}; presumably smpeft does the same -- confirm.
    if isinstance(peft_config, dict):
        for adapter_name, adapter_config in peft_config.items():
            if len(peft_config) == 1 or adapter_name == "default":
                target_dir = save_dir
            else:
                target_dir = os.path.join(save_dir, adapter_name)
            adapter_config.save_pretrained(target_dir)


@draccus.wrap()
def main(mainCfg: MainConfig):
    """
    Fine-tune a causal LM on a query/response math dataset.

    Steps: seed everything, load the base model, attach an adapter (load an
    existing one, create Sama adapters, or fall back to full fine-tuning),
    prepare the tokenizer and datasets, build the optimizer and LR scheduler,
    train with the HF ``Trainer`` and save tokenizer, trainer state, adapter
    config and model under the run's output directory.

    Args:
        mainCfg: Parsed experiment configuration (filled in by ``draccus``).

    Raises:
        ValueError: If the tokenizer has no pad, unk or eos token to pad with.
    """
    print('='*120)
    set_seed_all(mainCfg.seed)
    training_args = convert_to_trainer_args(mainCfg)
    task_name = mainCfg.data.dataset_name

    # --- wandb: number this run after the runs that already exist -----------
    entity = os.environ.get("WANDB_ENTITY", "nvan-13-korea-university")
    project = os.environ.get("WANDB_PROJECT")
    next_run_num = _next_wandb_run_number(entity, project)

    # --- model ---------------------------------------------------------------
    model = AutoModelForCausalLM.from_pretrained(
        mainCfg.model.model_name,
        device_map="auto",
        low_cpu_mem_usage=True,
        dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
        # attn_implementation="sdpa",
    )
    total_params_now = sum(p.numel() for p in model.parameters())
    print(f'#params of the pretrained model, {total_params_now:,}')

    if mainCfg.model.adapter_path is not None:
        print('___ Loading from: ', mainCfg.model.adapter_path)
        model = PeftModel.from_pretrained(model, mainCfg.model.adapter_path, is_trainable = True)
    elif mainCfg.sama_adapter.col_L is not None:
        sama_adapter_config = asdict(mainCfg.sama_adapter)
        for adapter_name in mainCfg.data.adapter_names:
            print("Init from Sama Config:", json.dumps(sama_adapter_config, indent=4, sort_keys=True))
            sama_config = SamaConfig(**sama_adapter_config)
            model = get_peft_model(model, sama_config, adapter_name=adapter_name)
    else:
        print("Full Parameter Fine-Tuning")

    # Only adapter-wrapped models provide this helper; a plain model used for
    # full fine-tuning does not.
    if hasattr(model, "print_trainable_parameters"):
        model.print_trainable_parameters()

    # Materialize as a list: a bare ``filter`` object is exhausted after one pass.
    sama_trainable_layers = [p for p in model.parameters() if p.requires_grad]

    # --- tokenizer -----------------------------------------------------------
    tokenizer = AutoTokenizer.from_pretrained(
        mainCfg.model.model_name,
        model_max_length=mainCfg.model.model_max_seq_length,
        padding_side="right",
        use_fast=True,
    )
    if tokenizer.pad_token is None:
        if tokenizer.unk_token_id is not None:
            tokenizer.pad_token_id = tokenizer.unk_token_id
            tokenizer.pad_token = tokenizer.unk_token
            print("Set PAD token to UNK token.")
        elif tokenizer.eos_token_id is not None:
            tokenizer.pad_token_id = tokenizer.eos_token_id
            tokenizer.pad_token = tokenizer.eos_token
            print("Set PAD token to EOS token.")
    if tokenizer.pad_token_id is None:
        # Fail before the (expensive) dataset mapping instead of at the first batch.
        raise ValueError("Tokenizer has no pad, unk or eos token; cannot pad batches.")
    model.config.pad_token_id = tokenizer.pad_token_id

    # --- data ----------------------------------------------------------------
    metamathqa_train = load_dataset(path=mainCfg.data.path, split=mainCfg.data.dataset_split)
    # NOTE(review): the validation rows are a fixed slice of the *train* split and
    # may overlap the training split selected above -- confirm this is intended.
    metamathqa_valid = load_dataset(path=mainCfg.data.path, split='train[20000:20256]')
    train_dataset = metamathqa_train.map(
        train_tokenize_function,
        batched=True,
        batch_size=20000,
        num_proc=32,  # Adjust based on your CPU
        remove_columns=metamathqa_train.column_names,
        load_from_cache_file=True,  # Set False for debugging new logic
        desc="Running tokenizer on train dataset",
        fn_kwargs={"tokenizer": tokenizer}
    )
    dev_dataset = metamathqa_valid.map(
        train_tokenize_function,
        batched=True,
        batch_size=20000,
        num_proc=32,
        load_from_cache_file=True,
        remove_columns=metamathqa_valid.column_names,
        fn_kwargs={"tokenizer": tokenizer}
    )
    print('- Train dataset size: ', len(train_dataset))
    print('- Dev dataset size: ', len(dev_dataset))

    data_collator = DataCollatorForSupervisedDataset(
        tokenizer=tokenizer,
        max_length=mainCfg.model.model_max_seq_length,
        # mode=mainCfg.model.data_collator_mode,
    )
    data_module = dict(train_dataset=train_dataset, data_collator=data_collator, eval_dataset=dev_dataset)

    # --- optimizer & scheduler ----------------------------------------------
    optimizer = optim.AdamW(
        sama_trainable_layers,
        lr=mainCfg.trainer_args.learning_rate,
        eps=1e-8
    )

    num_devices = training_args.n_gpu if training_args.n_gpu > 0 else 1
    if training_args.max_steps > 0:
        # A positive max_steps overrides num_train_epochs in the Trainer.
        max_steps = training_args.max_steps
    else:
        max_steps = _estimate_max_steps(
            num_examples=len(train_dataset),
            per_device_batch_size=training_args.per_device_train_batch_size,
            num_devices=num_devices,
            grad_accum_steps=training_args.gradient_accumulation_steps,
            num_epochs=training_args.num_train_epochs,
        )
    print(f"___ Estimated Total Training Steps: {max_steps}")

    if training_args.lr_scheduler_type == "cosine":
        make_scheduler = get_cosine_schedule_with_warmup
    else:
        # Every other scheduler type falls back to linear decay.
        make_scheduler = get_linear_schedule_with_warmup
    lr_scheduler = make_scheduler(
        optimizer=optimizer,
        num_warmup_steps=training_args.warmup_steps,
        num_training_steps=max_steps,
    )

    # --- run naming ----------------------------------------------------------
    start_time = datetime.now()
    # The [1:] drops the first digit of the two-digit year (kept as in the original naming).
    date_str = start_time.strftime("%y%m%dd%Hh%Mm%S")[1:]
    # Hyper-parameter tag shared by the output directory and the W&B run name.
    hparam_tag = (
        f'mlr{training_args.learning_rate:.1e},'
        f'b{mainCfg.trainer_args.per_device_train_batch_size},{mainCfg.trainer_args.gradient_accumulation_steps},'
        f'nb{mainCfg.sama_adapter.num_unique_blocks_L},{mainCfg.sama_adapter.num_unique_blocks_R},'
        f'cL{mainCfg.sama_adapter.col_L},'
        f'rR{mainCfg.sama_adapter.row_R},s{mainCfg.sama_adapter.scaling},'
        f'init{mainCfg.run_text},dr{mainCfg.sama_adapter.drop_out}'
    )
    # Parenthesized concatenation: no trailing backslash that could splice the
    # next statement onto the string expression.
    output_dir = (
        f'{training_args.output_dir}/{task_name}/'
        f't{date_str},'
        f'{hparam_tag},'
        f'ep{training_args.num_train_epochs}'
    )
    print('out', type(output_dir), output_dir)

    # Save info about the trainable parameters to a file.
    trainable_parameters_to_file(model, output_dir)
    training_args.output_dir = output_dir
    print(f'Current output_dir: {output_dir}')

    training_args.run_name = (
        f'[{next_run_num}-{task_name}]'
        f'{hparam_tag}'
        f'ep{training_args.num_train_epochs}'
        f't{date_str}'
    )
    print('out', type(training_args.run_name), training_args.run_name)
    print(f'data: {task_name}, train: {len(train_dataset)}, valid: {len(dev_dataset)}')

    # --- training ------------------------------------------------------------
    from .utils import ExperimentMonitorCallback
    # Built but currently NOT registered with the Trainer (see ``callbacks`` below).
    monitor = ExperimentMonitorCallback(
        log_file_path="./training_metrics_bs8.json",
        run_name="Experiment_BatchSize_8",
        log_interval=20  # presumably the averaging window in steps -- see ExperimentMonitorCallback
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        # compute_metrics=compute_metrics,
        processing_class=tokenizer,
        optimizers=(optimizer, lr_scheduler),
        **data_module,
        # callbacks=[monitor],
    )
    model.config.use_cache = False
    trainer.train()

    end_time = datetime.now()
    print('end time: ', end_time.strftime("%Y-%m-%d %H:%M:%S"), '| duration: ', end_time - start_time)

    # --- saving --------------------------------------------------------------
    ft_dir = os.path.join(training_args.output_dir, 'ft')
    tokenizer.save_pretrained(ft_dir)
    trainer.save_state()
    _save_adapter_configs(model, ft_dir)
    model.save_pretrained(os.path.join(training_args.output_dir, 'ft2'))


if __name__ == "__main__":
    main()