|
|
|
|
|
|
|
|
""" |
|
|
Training script for OlmoE model with adapters on the mlfoundations/dclm-baseline-1.0 dataset. |
|
|
This script demonstrates parameter-efficient fine-tuning using adapters. |
|
|
""" |
|
|
|
|
|
import os |
|
|
import math |
|
|
import logging |
|
|
import argparse |
|
|
from dataclasses import dataclass, field |
|
|
from typing import Dict, List, Optional, Tuple, Any, Union |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from torch.utils.data import DataLoader, IterableDataset |
|
|
from torch.optim import AdamW |
|
|
from torch.optim.lr_scheduler import LambdaLR |
|
|
|
|
|
from datasets import load_dataset |
|
|
from transformers import ( |
|
|
OlmoConfig, |
|
|
OlmoForCausalLM, |
|
|
AutoTokenizer, |
|
|
DataCollatorForLanguageModeling, |
|
|
HfArgumentParser, |
|
|
TrainingArguments, |
|
|
set_seed, |
|
|
get_scheduler, |
|
|
) |
|
|
from tqdm import tqdm |
|
|
from accelerate import Accelerator, DistributedType |
|
|
from accelerate.utils import find_batch_size |
|
|
|
|
|
from modeling_olmoe import ( |
|
|
OlmoEWithAdaptersForCausalLM, |
|
|
OlmoEForCausalLM, |
|
|
) |
|
|
|
|
|
|
|
|
# Configure process-wide logging once at import time, then grab the
# module-level logger used throughout this script.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)

logger = logging.getLogger(__name__)
|
|
|
|
|
@dataclass
class ModelArguments:
    """Command-line arguments controlling the model and adapter setup."""

    # Base checkpoint to load; adapters are added on top of it.
    model_name_or_path: str = field(
        default="allenai/OLMo-7B-Instruct",
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"},
    )
    # Width of the adapter bottleneck layers.
    adapter_size: int = field(
        default=64,
        metadata={"help": "Size of the adapter layers"},
    )
    # When True, only adapter parameters are left trainable.
    freeze_base_model: bool = field(
        default=True,
        metadata={"help": "Whether to freeze all parameters except the adapters"},
    )
    # Where checkpoints are written; defaulted elsewhere when left as None.
    checkpoint_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Path to save model checkpoints"},
    )
|
|
|
|
|
|
|
|
@dataclass
class DataArguments:
    """Command-line arguments describing the training data pipeline."""

    # HF hub id (or local path) of the pretraining corpus.
    dataset_name: str = field(
        default="mlfoundations/dclm-baseline-1.0",
        metadata={"help": "Dataset name or path for training"},
    )
    # Stream examples on the fly instead of downloading the full dataset.
    streaming: bool = field(
        default=True,
        metadata={"help": "Whether to stream the dataset"},
    )
    # Shuffle-buffer size used when streaming.
    streaming_buffer_size: int = field(
        default=8192,
        metadata={"help": "Buffer size for streaming dataset"},
    )
    # Length of each packed training sequence.
    max_seq_length: int = field(
        default=1024,
        metadata={"help": "Maximum sequence length for training"},
    )
    # DataLoader worker count (None is treated as 0 downstream).
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "Number of workers for preprocessing"},
    )
    # Name of the dataset column containing raw text.
    text_column_name: str = field(
        default="text",
        metadata={"help": "Column name for text data"},
    )
|
|
|
|
|
|
|
|
class StreamingTextDataset(IterableDataset):
    """Iterable dataset that streams raw text, tokenizes it, and packs the
    tokens into fixed-length causal-LM training chunks.

    Documents are concatenated into a rolling token buffer; every time the
    buffer holds at least ``max_seq_length`` tokens, one chunk is emitted.
    Tokens left over when the stream ends are discarded.
    """

    def __init__(
        self,
        dataset_name: str,
        tokenizer,
        max_seq_length: int,
        streaming: bool = True,
        text_column_name: str = "text",
        buffer_size: int = 8192,
        split: str = "train",
    ):
        """
        Args:
            dataset_name: Name or path passed to ``load_dataset``.
            tokenizer: Callable tokenizer returning a mapping with
                ``input_ids``.
            max_seq_length: Length of each emitted chunk.
            streaming: If True, stream the dataset and shuffle it with a
                reservoir buffer.
            text_column_name: Column holding the raw text.
            buffer_size: Shuffle-buffer size used when streaming.
            split: Dataset split to load.
        """
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        self.text_column_name = text_column_name

        self.dataset = load_dataset(
            dataset_name,
            split=split,
            streaming=streaming,
        )
        if streaming:
            # Approximate shuffling for a stream: sample from a buffer of
            # `buffer_size` pending examples.
            self.dataset = self.dataset.shuffle(buffer_size=buffer_size)

    def __iter__(self):
        """Yield dicts with ``input_ids`` and ``labels``: identical token
        chunks of exactly ``max_seq_length``, as ``torch.long`` tensors."""
        token_buffer: List[int] = []

        for example in self.dataset:
            text = example[self.text_column_name]
            # Skip empty / whitespace-only documents.
            if not text or not text.strip():
                continue

            # Tokenize the whole document without truncation; packing below
            # handles arbitrary lengths.
            tokenized = self.tokenizer(
                text,
                truncation=False,
                return_attention_mask=False,
                return_token_type_ids=False,
                add_special_tokens=False,
            )
            token_buffer.extend(tokenized["input_ids"])

            # Emit as many full-length chunks as the buffer allows.
            while len(token_buffer) >= self.max_seq_length:
                # Slice and tensorize once; labels are the inputs themselves
                # for causal LM, cloned so the tensors don't share storage.
                chunk = torch.tensor(
                    token_buffer[: self.max_seq_length], dtype=torch.long
                )
                yield {"input_ids": chunk, "labels": chunk.clone()}
                token_buffer = token_buffer[self.max_seq_length :]
|
|
|
|
|
|
|
|
def create_optimizer_and_scheduler(
    model: nn.Module,
    args: TrainingArguments,
    num_training_steps: int
) -> Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler]:
    """Create the AdamW optimizer and LR scheduler for training.

    Args:
        model: Model to optimize. If it exposes ``get_trainable_parameters()``
            (adapter models), only those parameters are optimized; otherwise
            every parameter with ``requires_grad`` is used.
        args: HuggingFace ``TrainingArguments`` supplying the optimizer and
            scheduler hyper-parameters.
        num_training_steps: Total number of scheduler steps.

    Returns:
        Tuple of ``(optimizer, lr_scheduler)``.
    """
    if hasattr(model, "get_trainable_parameters"):
        # BUGFIX: materialize to a list so len() works and the parameters can
        # be consumed by AdamW even if a generator is returned.
        optimizer_params = list(model.get_trainable_parameters())
        # Note: this counts parameter *tensors*, not scalar elements.
        logger.info(f"Training with {len(optimizer_params)} trainable parameters")
    else:
        optimizer_params = [p for p in model.parameters() if p.requires_grad]
        logger.info(f"Training with {len(optimizer_params)} parameters")

    optimizer = AdamW(
        optimizer_params,
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        eps=args.adam_epsilon,
        weight_decay=args.weight_decay,
    )

    scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=num_training_steps,
    )

    return optimizer, scheduler
|
|
|
|
|
|
|
|
def train(
    model_args: ModelArguments,
    data_args: DataArguments,
    training_args: TrainingArguments,
):
    """Run adapter fine-tuning of an OlmoE model.

    Loads the tokenizer, base model and streaming dataset, wraps everything
    with ``accelerate``, trains for ``training_args.max_steps`` steps with
    periodic logging/checkpointing, then saves a final model.

    Args:
        model_args: Model / adapter configuration.
        data_args: Dataset configuration.
        training_args: HuggingFace ``TrainingArguments`` (lr, steps, logging,
            saving, mixed precision, ...).
    """
    # Resolve the mixed-precision mode explicitly instead of the fragile
    # `cond and a or b` chain; fp16 takes precedence when both flags are set,
    # matching the original behavior.
    if training_args.fp16:
        mixed_precision = "fp16"
    elif training_args.bf16:
        mixed_precision = "bf16"
    else:
        mixed_precision = "no"

    accelerator = Accelerator(
        gradient_accumulation_steps=training_args.gradient_accumulation_steps,
        mixed_precision=mixed_precision,
    )

    logger.info(accelerator.state)
    if accelerator.is_local_main_process:
        logger.info(f"Model arguments: {model_args}")
        logger.info(f"Data arguments: {data_args}")
        logger.info(f"Training arguments: {training_args}")

    set_seed(training_args.seed)

    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Causal-LM tokenizers often ship without a pad token; reuse EOS.
        tokenizer.pad_token = tokenizer.eos_token

    config = OlmoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
    config.adapter_size = model_args.adapter_size

    logger.info(f"Loading OlmoE model with adapters from {model_args.model_name_or_path}")
    base_model = OlmoForCausalLM.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)

    model = OlmoEWithAdaptersForCausalLM(config)
    # strict=False because the adapter weights are new and therefore absent
    # from the pretrained state dict.
    model.load_state_dict(base_model.state_dict(), strict=False)

    if model_args.freeze_base_model:
        logger.info("Freezing base model parameters")
        model.freeze_base_model()

    logger.info(f"Loading dataset: {data_args.dataset_name}")
    train_dataset = StreamingTextDataset(
        dataset_name=data_args.dataset_name,
        tokenizer=tokenizer,
        max_seq_length=data_args.max_seq_length,
        streaming=data_args.streaming,
        buffer_size=data_args.streaming_buffer_size,
        text_column_name=data_args.text_column_name,
    )

    # mlm=False -> plain causal-LM collation (no token masking).
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
    )

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=training_args.per_device_train_batch_size,
        collate_fn=data_collator,
        num_workers=data_args.preprocessing_num_workers or 0,
    )

    # A streaming dataset has no known length, so train for exactly
    # `max_steps` steps.
    num_update_steps_per_epoch = training_args.max_steps
    num_training_steps = training_args.max_steps

    optimizer, lr_scheduler = create_optimizer_and_scheduler(
        model=model,
        args=training_args,
        num_training_steps=num_training_steps,
    )

    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )

    total_batch_size = (
        training_args.per_device_train_batch_size
        * accelerator.num_processes
        * training_args.gradient_accumulation_steps
    )
    logger.info(f"Total batch size (with parallel & accumulation): {total_batch_size}")
    logger.info(f"Number of training steps: {num_training_steps}")
    logger.info(f"Number of warmup steps: {training_args.warmup_steps}")

    progress_bar = tqdm(
        range(num_training_steps),
        disable=not accelerator.is_local_main_process,
        desc="Training",
    )
    completed_steps = 0
    starting_epoch = 0
    global_step = 0

    logger.info("Starting training...")
    model.train()

    for step, batch in enumerate(train_dataloader):
        # When resuming mid-run, fast-forward over batches already consumed.
        if starting_epoch > 0 and step < starting_epoch * num_update_steps_per_epoch:
            progress_bar.update(1)
            continue

        with accelerator.accumulate(model):
            outputs = model(**batch)
            loss = outputs.loss
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

        # NOTE(review): these counters advance once per micro-batch, so with
        # gradient_accumulation_steps > 1 `global_step` counts micro-batches,
        # not optimizer updates — confirm that is the intended meaning of
        # max_steps/logging_steps/save_steps here.
        progress_bar.update(1)
        completed_steps += 1
        global_step += 1

        if global_step % training_args.logging_steps == 0:
            # Average the latest micro-batch loss across all processes.
            loss_value = accelerator.gather(loss).mean().item()
            logger.info(f"Step {global_step}: loss = {loss_value:.4f}, lr = {lr_scheduler.get_last_lr()[0]:.8f}")

            # BUGFIX: guard the tracker list — indexing `trackers[0]` raised
            # IndexError when no experiment trackers were configured.
            if accelerator.trackers and hasattr(accelerator.trackers[0], "store"):
                accelerator.trackers[0].store({
                    "loss": loss_value,
                    "learning_rate": lr_scheduler.get_last_lr()[0],
                    "step": global_step,
                })

        if training_args.save_steps > 0 and global_step % training_args.save_steps == 0:
            # BUGFIX: the model/tokenizer save is nested under the
            # checkpoint_dir check so `output_dir` is always bound when used
            # (previously a NameError when checkpoint_dir was None).
            if model_args.checkpoint_dir is not None:
                output_dir = os.path.join(model_args.checkpoint_dir, f"checkpoint-{global_step}")
                accelerator.save_state(output_dir)
                logger.info(f"Saved checkpoint to {output_dir}")

                if accelerator.is_main_process:
                    unwrapped_model = accelerator.unwrap_model(model)
                    unwrapped_model.save_pretrained(
                        output_dir,
                        is_main_process=accelerator.is_main_process,
                        save_function=accelerator.save,
                    )
                    tokenizer.save_pretrained(output_dir)

        if completed_steps >= num_training_steps:
            break

    # Final save once training finishes (skipped entirely when no checkpoint
    # directory is configured — avoids referencing an unbound output_dir).
    if model_args.checkpoint_dir is not None:
        output_dir = os.path.join(model_args.checkpoint_dir, "final-model")
        accelerator.save_state(output_dir)

        if accelerator.is_main_process:
            unwrapped_model = accelerator.unwrap_model(model)
            unwrapped_model.save_pretrained(
                output_dir,
                is_main_process=accelerator.is_main_process,
                save_function=accelerator.save,
            )
            tokenizer.save_pretrained(output_dir)

        logger.info(f"Saved final model to {output_dir}")

    logger.info("Training complete!")
|
|
|
|
|
|
|
|
def main():
    """Main entry point.

    Parses ``ModelArguments``, ``DataArguments`` and ``TrainingArguments``
    from the command line, defaults the checkpoint directory to the trainer's
    output directory, ensures it exists, and launches training.
    """
    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Fall back to output_dir when no explicit checkpoint directory is given,
    # so the training loop always has somewhere to save.
    if model_args.checkpoint_dir is None:
        model_args.checkpoint_dir = training_args.output_dir
    os.makedirs(model_args.checkpoint_dir, exist_ok=True)

    train(model_args, data_args, training_args)
|
|
|
|
|
|
|
|
# Script entry point: parse CLI args and run training.
if __name__ == "__main__":
    main()