|
|
""" |
|
|
Utilities for checkpointing learning dynamics-related states (e.g., activations, weights, gradients).
|
|
|
|
|
We save the learning dynamics states in a subdirectory of the checkpointing directory. |
|
|
""" |
|
|
|
|
|
import os |
|
|
import re |
|
|
from typing import Dict, Optional |
|
|
|
|
|
import deepspeed |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.optim as optim |
|
|
from datasets import Dataset |
|
|
from huggingface_hub import upload_folder |
|
|
from lightning.fabric import Fabric |
|
|
from lightning.fabric.strategies import DeepSpeedStrategy |
|
|
from lightning.fabric.utilities.rank_zero import rank_zero_only |
|
|
from torch.nn import functional as F |
|
|
from torch.utils.data import DataLoader |
|
|
from transformers import PreTrainedTokenizerBase |
|
|
|
|
|
from src.config import CheckpointingConfig |
|
|
from src.config.checkpointing_config import LearningDynamicsCheckpointingConfig |
|
|
from src.training.utils.initialization import initialize_model |
|
|
from src.training.utils.io import use_backoff |
|
|
|
|
|
|
|
|
|
|
|
class DummyOptimizer(optim.Optimizer):
    """A no-op placeholder optimizer.

    Some setup paths (e.g. ``fabric.setup`` under the DeepSpeed strategy in
    ``compute_learning_dynamics_states``) require an optimizer to be supplied
    alongside the model. When we only need forward/backward passes to extract
    states, this placeholder registers the parameters without defining any
    hyperparameters; its ``step`` is never called, so weights are never updated.
    """

    def __init__(self, params):
        # No defaults (lr, momentum, ...) are needed because no optimization
        # step is ever performed with this optimizer.
        super().__init__(params, defaults=dict())
|
|
|
|
|
|
|
|
class CheckpointStateExtractor:
    """
    Class to extract and save the states of a model at a given checkpoint step for learning
    dynamics research.

    Forward hooks registered on the configured layers record per-layer activations and
    weights during forward passes; gradients (optional) are read off the model's
    parameters after backward passes.
    """

    def __init__(
        self,
        learning_dynamics_config: LearningDynamicsCheckpointingConfig,
        fabric: Fabric,
        model: nn.Module,
    ):
        # Config supplies which layers to track (layer_suffixes) and which
        # sequence position to capture activations at (sequence_idx).
        self.learning_dynamics_config = learning_dynamics_config
        self.fabric = fabric
        self.model = model

    def extract_states(self, dataloader, compute_gradients: bool = False):
        """Extracts model states (activations, weights, and optionally gradients).

        Given a dataloader, this function will perform a forward pass of the model on each batch,
        and save the activations and weights at each layer. If compute_gradients is True, it will
        also compute the gradients of the model parameters.

        Args:
            dataloader: The dataloader containing the dataset to extract states from.
            compute_gradients: Whether to compute the gradients of the model parameters.

        Returns:
            A tuple ``(checkpoint_activations, checkpoint_weights, checkpoint_gradients)``
            of dicts keyed by layer name; ``checkpoint_gradients`` is empty when
            ``compute_gradients`` is False.
        """
        checkpoint_activations = {}
        checkpoint_weights = {}

        # Register hooks that populate the two dicts above during forward passes.
        forward_hooks = self._setup_forward_hooks(
            checkpoint_activations,
            checkpoint_weights,
        )

        for sub_batch in dataloader:
            _input_ids = torch.tensor(sub_batch["input_ids"], device=self.fabric.device)

            if compute_gradients:
                if "labels" in sub_batch:
                    input_ids = _input_ids
                    labels = torch.tensor(
                        sub_batch["labels"], device=self.fabric.device
                    )
                else:
                    # No explicit labels: fall back to next-token prediction by
                    # shifting the inputs one position to form the targets.
                    input_ids = _input_ids[:, :-1]
                    labels = _input_ids[:, 1:]
            else:
                input_ids = _input_ids
                labels = None

            if labels is None:
                # Forward-only pass; the hooks capture activations/weights.
                with torch.no_grad():
                    _ = self.model(input_ids)
            else:
                # NOTE(review): assumes the model returns (logits, ...) — the
                # second element is discarded; confirm against the model API.
                outputs, _ = self.model(input_ids)
                # F.cross_entropy expects (batch, num_classes, seq_len).
                outputs = outputs.transpose(1, 2)
                loss = F.cross_entropy(outputs, labels)
                self.fabric.backward(loss, model=self.model)

        # Detach the hooks so later forward passes don't keep appending state.
        for hook in forward_hooks:
            hook.remove()

        layer_suffixes = self.learning_dynamics_config.layer_suffixes
        checkpoint_gradients = {}
        if compute_gradients:
            for name, param in self.model.named_parameters():
                # Only gradients of the weight matrices of tracked layers are kept.
                if (
                    any(layer_suffix in name for layer_suffix in layer_suffixes)
                    and "weight" in name
                ):
                    if isinstance(self.fabric.strategy, DeepSpeedStrategy):
                        # Under DeepSpeed, param.grad may be sharded/unavailable;
                        # fetch the full gradient via the DeepSpeed utility.
                        _grad = deepspeed.utils.safe_get_full_grad(param)
                    else:
                        _grad = param.grad

                    assert _grad is not None, f"Gradient is None for layer: {name}"
                    # Strip the ".weight" suffix so gradient keys line up with the
                    # module names used for activations/weights by the hooks.
                    name = re.sub(r"\.weight", "", name)
                    checkpoint_gradients[name] = _grad.detach().cpu()

            # Leave the model with clean gradients after extraction.
            self.model.zero_grad()

        return checkpoint_activations, checkpoint_weights, checkpoint_gradients

    def _setup_forward_hooks(self, checkpoint_activations, checkpoint_weights):
        """Setup forward hooks for the model to save activations and weights at each layer.

        This function will setup forward hooks on the layers of the model that we are interested in.
        The forward hooks will save the activations and weights at each layer whenever the forward pass
        is performed.

        Args:
            checkpoint_activations: A dictionary to store the activations at each layer.
            checkpoint_weights: A dictionary to store the weights at each layer.

        Returns:
            A list of forward hooks. We do this so that we can remove the hooks after the forward pass
            is complete.
        """
        forward_hooks = []
        layer_suffixes = self.learning_dynamics_config.layer_suffixes

        for name, module in self.model.named_modules():
            # Track any module whose qualified name contains a configured suffix.
            if any(layer_suffix in name for layer_suffix in layer_suffixes):
                _forward_hook = module.register_forward_hook(
                    self._get_forward_hook(
                        name, checkpoint_activations, checkpoint_weights
                    )
                )
                forward_hooks.append(_forward_hook)
        return forward_hooks

    def _get_forward_hook(
        self, module_name, checkpoint_activations, checkpoint_weights
    ):
        """Get a forward hook for a given module.

        This function is called by the _setup_forward_hooks function to setup a forward hook for a given
        module. This functions is a closure that captures the module_name, checkpoint_activations, and
        checkpoint_weights.

        Args:
            module_name: The name of the module to setup a forward hook for.
            checkpoint_activations: A dictionary to store the activations at each layer.
            checkpoint_weights: A dictionary to store the weights at each layer.

        Returns:
            A forward hook for the given module.
        """

        def _forward_hook(module, _, module_out):
            # Sequence position whose activation vector is recorded.
            sequence_idx = self.learning_dynamics_config.sequence_idx

            # NOTE(review): assumes module_out is (batch, seq_len, hidden_dim) —
            # confirm for every tracked layer type.
            local_activations = module_out[:, sequence_idx, :].detach()

            # Gather activations from all ranks; result has an extra leading
            # world-size dimension.
            gathered_activations = self.fabric.all_gather(local_activations)

            # Flatten (world_size, local_batch, dim) -> (global_batch, dim),
            # interleaving ranks — presumably to restore the dataset's original
            # sample order under the distributed sampler; verify against the
            # sampler's sharding scheme.
            gathered_activations = gathered_activations.transpose(0, 1).reshape(
                -1, gathered_activations.shape[-1]
            )

            if module_name not in checkpoint_activations:
                # First batch for this layer: start the activation buffer and
                # record the layer's weight matrix once (weights don't change
                # across batches during extraction).
                checkpoint_activations[module_name] = (
                    gathered_activations.detach().cpu()
                )

                weight_matrix = module.weight.detach().cpu()
                checkpoint_weights[module_name] = weight_matrix
            else:
                # Subsequent batches: append along the batch dimension.
                checkpoint_activations[module_name] = torch.cat(
                    (
                        checkpoint_activations[module_name],
                        gathered_activations.detach().cpu(),
                    )
                )

        return _forward_hook
|
|
|
|
|
|
|
|
def compute_learning_dynamics_states(
    checkpointing_config: CheckpointingConfig,
    fabric: Fabric,
    model: nn.Module,
    dataset: Dataset,
    compute_gradients: bool = False,
) -> Dict[str, torch.Tensor]:
    """Computes the learning dynamics metrics for a given checkpoint step.

    Uses the CheckpointStateExtractor to extract the activations, weights, and optionally gradients
    of the model at a given checkpoint step. The extraction runs on a fresh copy of the
    model so the training model and its optimizer state are left untouched.

    Args:
        checkpointing_config: The configuration object for checkpointing.
        fabric: The Fabric instance for distributed training.
        model: The model to extract states from.
        dataset: The dataset to extract states from.
        compute_gradients: Whether to compute the gradients of the model parameters.

    Returns:
        A dictionary with keys "activations", "weights", and "gradients", each mapping
        layer names to tensors ("gradients" is empty when compute_gradients is False).
    """
    # Synchronize all ranks before touching the model.
    fabric.barrier()
    # Move the training model off the accelerator so the temporary copy created
    # below can use the device memory.
    model.to("cpu")

    def _collate_fn(batch):
        # Keep input_ids as plain lists; the extractor converts them to tensors.
        return {"input_ids": [entry["input_ids"] for entry in batch]}

    batch_size = checkpointing_config.learning_dynamics.batch_size
    # Each rank processes an equal share of the global batch.
    # NOTE(review): assumes batch_size >= world_size and is divisible by it —
    # otherwise sub_batch_size can be 0 or drop samples; confirm in config validation.
    sub_batch_size = batch_size // fabric.world_size

    extractor_dataloader = DataLoader(
        dataset,
        batch_size=sub_batch_size,
        shuffle=False,
        collate_fn=_collate_fn,
        drop_last=False,
    )
    # The distributed sampler shards (and may pad) the dataset across ranks.
    extractor_dataloader = fabric.setup_dataloaders(
        extractor_dataloader, use_distributed_sampler=True
    )

    # Work on a fresh copy of the model so hooks/backward passes cannot perturb
    # the training model.
    _model = initialize_model(model.config)
    _model.load_state_dict(model.state_dict())

    if isinstance(fabric.strategy, DeepSpeedStrategy):
        # DeepSpeed setup requires an optimizer; pass a no-op placeholder.
        _model, _ = fabric.setup(_model, DummyOptimizer(_model.parameters()))
    else:
        _model = fabric.setup(_model)

    # Start from clean gradients before any backward passes.
    _model.zero_grad()

    state_extractor = CheckpointStateExtractor(
        checkpointing_config.learning_dynamics, fabric, _model
    )

    checkpoint_activations, checkpoint_weights, checkpoint_gradients = (
        state_extractor.extract_states(
            extractor_dataloader, compute_gradients=compute_gradients
        )
    )

    # Free the temporary copy before restoring the training model to the device.
    del _model
    torch.cuda.empty_cache()

    fabric.barrier()

    model.to(fabric.device)

    # The distributed sampler may pad the dataset so that every rank receives
    # equal-sized batches; trim any padded rows so activations line up 1:1 with
    # the dataset. Fewer activations than samples indicates a real bug.
    for layer_name, layer_activations in checkpoint_activations.items():
        if len(layer_activations) > len(dataset):
            checkpoint_activations[layer_name] = layer_activations[: len(dataset)]
        elif len(layer_activations) < len(dataset):
            raise ValueError(
                f"Number of activations ({len(layer_activations)}) in layer {layer_name} does not match number of samples in dataset ({len(dataset)})"
            )

    return {
        "activations": checkpoint_activations,
        "weights": checkpoint_weights,
        "gradients": checkpoint_gradients,
    }
|
|
|
|
|
|
|
|
@rank_zero_only
@use_backoff()
def save_learning_dynamics_states(
    checkpointing_config: CheckpointingConfig,
    checkpoint_step: int,
    prefix: str,
    fabric: Fabric,
    learning_dynamics_states: Dict[str, torch.Tensor],
    learning_dynamics_dataset: Optional[Dataset] = None,
    tokenizer: Optional[PreTrainedTokenizerBase] = None,
) -> None:
    """Save the learning dynamics metrics to the checkpointing directory.

    By default only the learning dynamics states are saved. If the learning dynamics dataset
    is provided, it is also saved; if a tokenizer is additionally provided, the dataset is
    detokenized first (a "text" column with the decoded input IDs is added).

    The learning dynamics dataset is saved as a HuggingFace dataset inside the versioned
    checkpoint directory:

        {checkpointing_config.runs_dir}/
        └── {checkpointing_config.run_name}/
            └── {checkpointing_config.checkpoints_dir}/
                ├── step_{checkpoint_step}/
                │   └── {checkpointing_config.learning_dynamics_dir}/  # Learning Dynamics files
                │       ├── {prefix}_activations.pt
                │       ├── {prefix}_weights.pt
                │       ├── {prefix}_gradients.pt
                │       └── {prefix}_data/      # if learning_dynamics_dataset is provided
                └── latest -> step_{checkpoint_step}/

    NOTE: this function is only called on rank 0.

    Args:
        checkpointing_config: The configuration object for checkpointing.
        checkpoint_step: The checkpoint step at which the learning dynamics states were computed.
        prefix: The prefix for the learning dynamics states.
        fabric: The Fabric instance for distributed training (unused here; kept for API symmetry).
        learning_dynamics_states: The learning dynamics states to save.
        learning_dynamics_dataset: The dataset containing learning dynamics data,
            including input IDs that need to be decoded. (optional)
        tokenizer: The tokenizer used to decode input IDs into text. (optional)
    """
    # Resolve {runs_dir}/{run_name}/{checkpoints_dir}/step_{N}/{learning_dynamics_dir}
    learning_dynamics_path = os.path.join(
        checkpointing_config.runs_dir,
        checkpointing_config.run_name,
        checkpointing_config.checkpoints_dir,
        f"step_{checkpoint_step}",
        checkpointing_config.learning_dynamics_dir,
    )
    os.makedirs(learning_dynamics_path, exist_ok=True)

    # Persist each non-empty state dict (activations / weights / gradients).
    for state_name, state_tensors in learning_dynamics_states.items():
        if state_tensors is None or len(state_tensors) == 0:
            continue
        torch.save(
            state_tensors,
            os.path.join(learning_dynamics_path, f"{prefix}_{state_name}.pt"),
        )

    if learning_dynamics_dataset is not None:
        if tokenizer is not None:
            # Detokenize: rebuild the dataset with an added human-readable
            # "text" column decoded from the input IDs.
            all_input_ids = []
            all_texts = []
            for entry in learning_dynamics_dataset:
                entry_input_ids = entry["input_ids"]
                all_input_ids.append(entry_input_ids)
                all_texts.append(
                    tokenizer.decode(entry_input_ids, skip_special_tokens=True)
                )
            learning_dynamics_dataset = Dataset.from_dict(
                {"input_ids": all_input_ids, "text": all_texts}
            )

        learning_dynamics_dataset.save_to_disk(
            os.path.join(learning_dynamics_path, f"{prefix}_data")
        )

    if checkpointing_config.save_to_hf:
        # Mirror the learning dynamics files to the HuggingFace Hub, using the
        # run name as the revision (branch).
        upload_folder(
            folder_path=learning_dynamics_path,
            path_in_repo=checkpointing_config.learning_dynamics_dir,
            repo_id=checkpointing_config.hf_checkpoint.repo_id,
            commit_message=f"Saving Learning Dynamics Data ({prefix}) -- Step {checkpoint_step}",
            revision=checkpointing_config.run_name,
            token=os.getenv("HF_TOKEN"),
        )
|
|