# Copyright 2025 The Scenic Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Training script with transfer learning utilities."""
import functools
from typing import Any, Callable, Dict, Iterator, Tuple, Optional, Type
from absl import logging
from clu import metric_writers
from clu import periodic_actions
from clu import platform
from flax import jax_utils
import flax.linen as nn
import jax
from jax.example_libraries.optimizers import clip_grads
import jax.numpy as jnp
import jax.profiler
import ml_collections
import numpy as np
import optax
from scenic.common_lib import video_utils
from scenic.dataset_lib import dataset_utils
from scenic.model_lib.base_models import base_model
from scenic.train_lib import lr_schedules
from scenic.train_lib import optimizers
from scenic.train_lib import pretrain_utils
from scenic.train_lib import train_utils
from scenic.train_lib.transfer import fewshot_utils
from scenic.train_lib.transfer import linear_probe_utils
# Aliases for custom types:
Batch = Dict[str, jnp.ndarray]
MetricFn = Callable[[jnp.ndarray, Dict[str, jnp.ndarray]],
Dict[str, Tuple[float, int]]]
LossFn = Callable[
[jnp.ndarray, Batch, Optional[jnp.ndarray], Any], float]
LrFn = Callable[[jnp.ndarray], jnp.ndarray]
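# For reference, a MetricFn returns (value-sum, normalizer) pairs per metric,
# e.g. {'accuracy': (num_correct, num_examples)}; the pairs are summed over
# batches and devices and their ratio is reported when summaries are written.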
def train_step(
train_state: train_utils.TrainState,
batch: Batch,
*,
flax_model: nn.Module,
loss_fn: LossFn,
metrics_fn: MetricFn,
lr_fn: LrFn,
config: ml_collections.ConfigDict,
debug: Optional[bool] = False
) -> Tuple[train_utils.TrainState, Dict[str, Tuple[float, int]], Dict[str,
Any]]:
"""Runs a single step of training.
Given the state of the training and a batch of data, computes
the loss and updates the parameters of the model.
Note that in this code, the buffers of the first (train_state) and second
(batch) arguments are donated to the computation.
Args:
train_state: The state of training including the current global_step,
model_state, rng, params, and optimizer. The buffer of this argument can
be donated to the computation.
batch: A single batch of data. The buffer of this argument can be donated to
the computation.
flax_model: A Flax model.
loss_fn: A loss function that given logits, a batch, and parameters of the
model calculates the loss.
metrics_fn: A metrics function that given logits and batch of data,
calculates the metrics as well as the loss.
    lr_fn: The learning rate fn, used for logging the learning rate.
config: Configurations of the experiment.
debug: Whether the debug mode is enabled during training. `debug=True`
enables model specific logging/storing some values using
jax.host_callback.
Returns:
Updated state of training and computed metrics for logging.
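  Example:
    A minimal sketch of how this function is typically wrapped (this mirrors
    `train()` below; the partial-application arguments are illustrative):

      train_step_pmapped = jax.pmap(
          functools.partial(
              train_step,
              flax_model=model.flax_model,
              loss_fn=model.loss_function,
              metrics_fn=model.get_metrics_fn('train'),
              lr_fn=lr_fn,
              config=config),
          axis_name='batch',
          donate_argnums=(0, 1))
      new_train_state, metrics, logs = train_step_pmapped(train_state, batch)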
"""
training_logs = {}
new_rng, rng = jax.random.split(train_state.rng)
if config.get('mixup') and config.mixup.alpha:
mixup_rng, rng = jax.random.split(rng, 2)
mixup_rng = train_utils.bind_rng_to_host_device(
mixup_rng,
axis_name='batch',
bind_to=config.mixup.get('bind_to', 'device'))
batch = dataset_utils.mixup(
batch,
config.mixup.alpha,
config.mixup.get('image_format', 'NHWC'),
rng=mixup_rng)
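  # Note: mixup (when enabled above) convexly combines pairs of examples and
  # their one-hot labels; `config.mixup.alpha` parameterizes the Beta
  # distribution from which the mixing weight is drawn.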
# Bind the rng to the host/device we are on.
dropout_rng = train_utils.bind_rng_to_host_device(
rng, axis_name='batch', bind_to='device')
def training_loss_fn(params):
variables = {'params': params, **train_state.model_state}
(logits, auxiliary_outputs), new_model_state = flax_model.apply(
variables,
batch['inputs'],
mutable=['batch_stats'],
train=True,
rngs={'dropout': dropout_rng},
debug=debug)
loss = loss_fn(logits, batch, variables['params'], auxiliary_outputs)
return loss, (new_model_state, logits, auxiliary_outputs)
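  # `has_aux=True` below because `training_loss_fn` returns `(loss, aux)`:
  # the gradient is taken w.r.t. the scalar loss only, while the auxiliary
  # outputs (new model state, logits, model-specific extras) pass through.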
compute_gradient_fn = jax.value_and_grad(training_loss_fn, has_aux=True)
(train_cost, compute_outputs), grad = compute_gradient_fn(train_state.params)
(new_model_state, logits, auxiliary_outputs) = compute_outputs
del train_cost
# Re-use same axis_name as in the call to `pmap(...train_step...)` below.
grad = jax.lax.pmean(grad, axis_name='batch')
if config.get('max_grad_norm') is not None:
grad = clip_grads(grad, config.max_grad_norm)
updates, new_opt_state = train_state.tx.update(grad, train_state.opt_state,
train_state.params)
new_params = optax.apply_updates(train_state.params, updates)
training_logs['l2_grads'] = jnp.sqrt(
sum([jnp.vdot(g, g) for g in jax.tree_util.tree_leaves(grad)]))
ps = jax.tree_util.tree_leaves(new_params)
training_logs['l2_params'] = jnp.sqrt(sum([jnp.vdot(p, p) for p in ps]))
us = jax.tree_util.tree_leaves(updates)
training_logs['l2_updates'] = jnp.sqrt(sum([jnp.vdot(u, u) for u in us]))
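  # The `l2_*` entries above are global L2 norms, i.e.
  # sqrt(sum_i ||leaf_i||^2) over all leaves of the corresponding pytree
  # (gradients, parameters, and optimizer updates, respectively).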
training_logs['learning_rate'] = lr_fn(train_state.global_step)
# Logging the average sequence length we are using
if config.get('model.ac_config.dynamic_tape_length') is not None:
# Logging the sequence length
input_masks = auxiliary_outputs[0]
input_len = jnp.sum(input_masks, axis=-1)
avg_input_len = jnp.mean(input_len)
training_logs['sequence length'] = avg_input_len
# To check whether all examples are using the same sequence length,
# we write the variance of input_len into training_logs.
training_logs['sequence length var'] = jnp.var(input_len)
loss_atr = auxiliary_outputs[1]
training_logs['loss_atr'] = jnp.mean(loss_atr)
metrics = metrics_fn(logits, batch)
new_train_state = train_state.replace( # pytype: disable=attribute-error
global_step=train_state.global_step + 1,
opt_state=new_opt_state,
params=new_params,
model_state=new_model_state,
rng=new_rng)
return new_train_state, metrics, training_logs
def eval_step(
train_state: train_utils.TrainState,
batch: Batch,
*,
flax_model: nn.Module,
metrics_fn: MetricFn,
debug: Optional[bool] = False
) -> Tuple[Dict[str, Tuple[float, int]], jnp.ndarray]:
"""Runs a single step of training.
Note that in this code, the buffer of the second argument (batch) is donated
to the computation.
  Assumed API of metrics_fn is:
    ```metrics = metrics_fn(logits, batch)```
  where batch is yielded by the batch iterator, and metrics is a dictionary
  mapping each metric name to a vector of per-example measurements. eval_step
  will aggregate (by summing) the per-example measurements and divide by the
  aggregated normalizers. For each metric we compute:
    1/N * sum_{b in batch_iter} metric(b),
  where N is the sum of the normalizers over all batches.
Args:
train_state: TrainState, the state of training including the current
global_step, model_state, rng, params and optimizer state. The buffer of
this argument can be donated to the computation.
    batch: A single batch of data. The buffer of this argument can be donated
      to the computation.
flax_model: A Flax model.
metrics_fn: A metrics function, that given logits and batch of data,
calculates the metrics as well as the loss.
debug: Whether the debug mode is enabled during evaluation. `debug=True`
enables model specific logging/storing some values using
jax.host_callback.
Returns:
Calculated metrics and logits.
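  Example:
    A minimal sketch of the typical pmapped usage (this mirrors `train()`
    below; the partial-application arguments are illustrative):

      eval_step_pmapped = jax.pmap(
          functools.partial(
              eval_step,
              flax_model=model.flax_model,
              metrics_fn=model.get_metrics_fn('validation')),
          axis_name='batch',
          donate_argnums=(1,))
      metrics, logits = eval_step_pmapped(train_state, eval_batch)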
"""
variables = {'params': train_state.params, **train_state.model_state}
logits, _ = flax_model.apply(
variables, batch['inputs'], train=False, mutable=False, debug=debug)
metrics = metrics_fn(logits, batch)
return metrics, logits
def representation_fn(
train_state: train_utils.TrainState,
batch: Batch,
*,
flax_model: nn.Module,
representation_layer: str,
gather_to_host: bool = True
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""Feeds the inputs to the model and returns their representations.
Args:
train_state: TrainState, the state of training including the current
global_step, model_state, rng, and optimizer. The buffer of this argument
can be donated to the computation.
batch: A single batch of data from the dataset.
flax_model: A Flax model.
representation_layer: The name of the layer to use as the representation.
gather_to_host: Whether to gather results from all devices to the host,
rather than leaving them distributed.
Returns:
Representation learned by the model for the given inputs and the labels and
masks. If `gather_to_host` is True, these are collected from all hosts.
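  Example:
    For a hypothetical `representation_layer='pre_logits'`, the captured
    activation is read back as
    `model_state['intermediates']['pre_logits']['__call__'][0]`; the loop in
    the function body generalizes this lookup to nested 'a/b/c'-style paths.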
"""
variables = {'params': train_state.params, **train_state.model_state}
representation_layer_parts = representation_layer.split('/')
filter_rep = lambda mdl, _: mdl.name == representation_layer_parts[-1]
_, model_state = flax_model.apply(
variables,
batch['inputs'],
train=False,
capture_intermediates=filter_rep,
mutable=['intermediates'],
debug=False)
if 'intermediates' not in model_state:
raise ValueError(f'Layer with name "{representation_layer}"'
' does not exist in your model.')
representation = model_state['intermediates']
for rep_layer in representation_layer_parts:
if rep_layer:
representation = representation[rep_layer]
representation = representation['__call__'][0]
if gather_to_host:
representation = jax.lax.all_gather(representation, 'batch')
batch = jax.lax.all_gather(batch, 'batch')
return representation, batch['label'], batch['batch_mask']
def representation_fn_video(
train_state: train_utils.TrainState,
batch: Batch,
*,
flax_model: nn.Module,
config: ml_collections.ConfigDict,
gather_to_host: bool = True,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""Feeds the video inputs to the model and returns their representations.
Video representations are obtained by temporally average-pooling per-frame
representations from the input video clip.
Args:
train_state: TrainState, the state of training including the current
global_step, model_state, rng, params, and optimizer. The buffer of this
argument can be donated to the computation.
batch: A single batch of data from the video dataset.
flax_model: A Flax model.
config: Configurations of the experiment.
gather_to_host: Whether to gather results from all devices to the host,
rather than leaving them distributed.
Returns:
Representation learned by the model for the given inputs and the labels and
masks. If `gather_to_host` is True, these are collected from all hosts.
The shape of the returned tensors when `gather_to_host` is False are:
representation: `[num_devices, global_batch, features]`.
labels: `[num_devices, global_batch]`.
mask: `[num_devices, global_batch]`.
If `gather_to_host` is True then each shape is prepended with
`[num_hosts,]`
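  Example:
    A sketch of the `config.video_fewshot` fields this function reads (the
    field names follow the code below; the values are illustrative only):

      config.video_fewshot.representation_layer = 'pre_logits'
      config.video_fewshot.n_sampled_frames = 8  # Optional frame subsampling.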
"""
variables = {'params': train_state.params, **train_state.model_state}
representation_layer = config.video_fewshot.representation_layer.split('/')
filter_rep = lambda mdl, _: mdl.name == representation_layer[-1]
def get_representation(inputs, variables, training, capture_intermediates,
mutable, debug):
_, model_state = flax_model.apply(
variables,
inputs,
train=training,
capture_intermediates=capture_intermediates,
mutable=mutable,
debug=debug)
if 'intermediates' not in model_state:
raise ValueError(
f'Layer with name "{config.video_fewshot.representation_layer}"'
' does not exist in your model.')
representation = model_state['intermediates']
for rep_layer in representation_layer:
if rep_layer:
representation = representation[rep_layer]
representation = representation['__call__'][0]
return representation
# Get representations for each frame in the video sample.
if config.video_fewshot.get('n_sampled_frames'):
inputs = video_utils.sample_frames_uniformly(
batch['inputs'], config.video_fewshot.n_sampled_frames)
else:
inputs = batch['inputs']
representation = jax.vmap(
functools.partial(
get_representation,
variables=variables,
training=False,
capture_intermediates=filter_rep,
mutable=['intermediates'],
debug=False),
in_axes=1,
out_axes=1,
axis_name='time')(
inputs)
# Average pooling of representations over time axis.
representation = jnp.mean(representation, axis=1)
if gather_to_host:
representation = jax.lax.all_gather(representation, 'batch')
batch = jax.lax.all_gather(batch, 'batch')
return representation, batch['label'], batch['batch_mask']
def train(
*,
rng: jnp.ndarray,
config: ml_collections.ConfigDict,
model_cls: Type[base_model.BaseModel],
dataset: dataset_utils.Dataset,
workdir: str,
writer: metric_writers.MetricWriter,
) -> Tuple[train_utils.TrainState, Dict[str, Any], Dict[str, Any]]:
"""Main training loop lives in this function.
Given the model class and dataset, it prepares the items needed to run the
training, including the TrainState.
Args:
rng: Jax rng key.
config: Configurations of the experiment.
model_cls: Model class; A model has a flax_module, a loss_fn, and a
metrics_fn associated with it.
dataset: The dataset that has train_iter, eval_iter, meta_data, and
optionally, test_iter.
workdir: Directory for checkpointing.
writer: CLU metrics writer instance.
Returns:
    train_state that has the state of training (including current global_step,
    model_state, rng, and the optimizer), train_summary and eval_summary which
    are dicts of metrics (from the last evaluation and train metric logging
    respectively). These outputs are used for regression testing.
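  Example:
    A sketch of a typical invocation (normally done from a project's
    `main.py`; the argument values here are illustrative):

      train_state, train_summary, eval_summary = train(
          rng=jax.random.PRNGKey(config.rng_seed),
          config=config,
          model_cls=model_cls,
          dataset=dataset,
          workdir=workdir,
          writer=writer)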
"""
lead_host = jax.process_index() == 0
# Build the loss_fn, metrics, and flax_model.
model = model_cls(config, dataset.meta_data)
# Initialize model.
rng, init_rng = jax.random.split(rng)
(params, model_state, num_trainable_params,
gflops) = train_utils.initialize_model(
model_def=model.flax_model,
input_spec=[(dataset.meta_data['input_shape'],
dataset.meta_data.get('input_dtype', jnp.float32))],
config=config,
rngs=init_rng)
# Create optimizer.
lr_fn = lr_schedules.get_learning_rate_fn(config)
optimizer_config = optimizers.get_optax_optimizer_config(config)
# If the config is already an optax-compatible config, better call directly:
# optimizers.get_optimizer(config.optimizer_configs, lr_fn)
tx = optimizers.get_optimizer(optimizer_config, lr_fn, params=params)
# We jit this, such that the arrays that are created are created on the same
# device as the input is, in this case the CPU. Else they'd be on device[0].
opt_state = jax.jit(tx.init, backend='cpu')(params)
rng, train_rng = jax.random.split(rng)
# Create chrono class to track and store training statistics and metadata:
chrono = train_utils.Chrono()
train_state = train_utils.TrainState(
global_step=0,
opt_state=opt_state,
tx=tx,
params=params,
model_state=model_state,
rng=train_rng,
metadata={'chrono': chrono.save()})
start_step = train_state.global_step
if config.checkpoint:
train_state, start_step = train_utils.restore_checkpoint(
workdir, train_state)
chrono.load(train_state.metadata['chrono'])
if (start_step == 0 # Which means "no" checkpoint is restored!
and config.get('init_from') is not None):
restored_model_cfg = config.init_from.get('model_config')
init_checkpoint_path = config.init_from.get('checkpoint_path')
if init_checkpoint_path is not None:
restored_train_state = pretrain_utils.restore_pretrained_checkpoint(
init_checkpoint_path, train_state, assert_exist=True)
# Load params from the init_model.
train_state = model.init_from_train_state( # pytype: disable=attribute-error
train_state, restored_train_state, restored_model_cfg)
del restored_train_state
  # Replicate the optimizer, state, and rng.
train_state = jax_utils.replicate(train_state)
del params # Do not keep a copy of the initial params.
# Calculate the total number of training steps.
total_steps, steps_per_epoch = train_utils.get_num_training_steps(
config, dataset.meta_data)
train_step_pmapped = jax.pmap(
functools.partial(
train_step,
flax_model=model.flax_model,
loss_fn=model.loss_function,
metrics_fn=model.get_metrics_fn('train'),
lr_fn=lr_fn,
config=config,
debug=config.debug_train),
axis_name='batch',
# We can donate both buffers of train_state and train_batch.
donate_argnums=(0, 1),
)
eval_step_pmapped = jax.pmap(
functools.partial(
eval_step,
flax_model=model.flax_model,
metrics_fn=model.get_metrics_fn('validation'),
debug=config.debug_eval),
axis_name='batch',
# We can donate the eval_batch's buffer.
donate_argnums=(1,),
)
if 'fewshot' in config:
representation_fn_fewshot = functools.partial(
representation_fn,
flax_model=model.flax_model,
representation_layer=config.fewshot.representation_layer)
fewshotter = fewshot_utils.FewShotEvaluator(representation_fn_fewshot,
config.fewshot)
if 'video_fewshot' in config:
representation_fn_video_fewshot = functools.partial(
representation_fn_video, flax_model=model.flax_model, config=config)
video_fewshotter = fewshot_utils.FewShotEvaluatorVideo(
representation_fn_video_fewshot, config.video_fewshot)
if 'linear_probe' in config:
representation_fn_linear_probe = functools.partial(
representation_fn,
flax_model=model.flax_model,
representation_layer=config.linear_probe.representation_layer,
gather_to_host=False)
rng, linear_probe_rng = jax.random.split(rng)
linear_probe = linear_probe_utils.LinearEvaluator(
representation_fn=representation_fn_linear_probe,
rng=linear_probe_rng,
linear_eval_config=config.linear_probe)
log_eval_steps = config.get('log_eval_steps') or steps_per_epoch
if not log_eval_steps:
raise ValueError("'log_eval_steps' should be specified in the config.")
checkpoint_steps = config.get('checkpoint_steps') or log_eval_steps
log_summary_steps = config.get('log_summary_steps') or log_eval_steps
def evaluate(train_state: train_utils.TrainState, step: int,
valid_iter: Iterator[Batch],
num_valid_ex: int) -> Dict[str, Any]:
eval_summary = {}
if not isinstance(valid_iter, dict): # Only on validation set.
valid_iter, num_valid_ex = {'valid': valid_iter}, {'valid': num_valid_ex}
for val_name, val_iter in valid_iter.items():
num_ex = num_valid_ex[val_name]
# Ceil rounding such that we include the last incomplete batch.
eval_batch_size = config.get('eval_batch_size', config.batch_size)
total_eval_steps = int(np.ceil(num_ex / eval_batch_size))
steps_per_eval = config.get('steps_per_eval') or total_eval_steps
eval_metrics = []
for _ in range(steps_per_eval):
eval_batch = next(val_iter)
if dataset.meta_data['target_is_onehot']: # Which includes multi-hot.
# Ignore the entries with all zero label for evaluation.
eval_batch['batch_mask'] *= eval_batch['label'].max(axis=-1)
e_metrics, _ = eval_step_pmapped(train_state, eval_batch)
eval_metrics.append(train_utils.unreplicate_and_get(e_metrics))
eval_summary.update(
train_utils.log_eval_summary(
step=step,
eval_metrics=eval_metrics,
writer=writer,
prefix=val_name))
del eval_metrics
writer.flush()
return eval_summary
train_metrics, extra_training_logs = [], []
train_summary, eval_summary = None, None
chrono.inform(start_step, total_steps, config.batch_size, steps_per_epoch)
logging.info('Starting training loop at step %d.', start_step + 1)
report_progress = periodic_actions.ReportProgress(
num_train_steps=total_steps, writer=writer)
def write_note(note):
if lead_host:
platform.work_unit().set_notes(note)
hooks = []
if lead_host:
hooks.append(report_progress)
if config.get('xprof', True) and lead_host:
hooks.append(periodic_actions.Profile(num_profile_steps=5, logdir=workdir))
if start_step == 0:
step0_log = {'num_trainable_params': num_trainable_params}
if gflops:
step0_log['gflops'] = gflops
writer.write_scalars(1, step0_log)
write_note(f'First step compilations...\n{chrono.note}')
for step in range(start_step + 1, total_steps + 1):
with jax.profiler.StepTraceAnnotation('train', step_num=step):
train_batch = next(dataset.train_iter)
train_state, t_metrics, t_logs = train_step_pmapped(
train_state, train_batch)
# This will accumulate metrics in TPU memory up to the point that we log
# them. This is no problem for small metrics but may be a problem for
# large (e.g. segmentation) metrics. An alternative is to set
# `log_summary_steps` to a small number, or to use
# `train_utils.unreplicate_and_get` here instead of right before writing
# summaries, but that means in each step, we have data transfer between
# tpu and host, which might slow down the training.
train_metrics.append(t_metrics)
# Additional training logs: learning rate:
t_logs = jax.tree_util.tree_map(jax_utils.unreplicate, t_logs)
# t_logs.update({'learning_rate': lr_fn(step)})
extra_training_logs.append(t_logs)
# Quick indication that training is happening.
logging.log_first_n(logging.INFO, 'Finished training step %d.', 5, step)
for h in hooks:
h(step)
############### LOG TRAIN SUMMARY ###############
if (step % log_summary_steps == 1) or (step
== total_steps) or chrono.warmup:
chrono.pause(wait_for=(train_metrics))
if lead_host:
chrono.tick(step, writer, write_note)
train_summary = train_utils.log_train_summary(
step=step,
train_metrics=jax.tree_util.tree_map(train_utils.unreplicate_and_get,
train_metrics),
extra_training_logs=jax.tree_util.tree_map(jax.device_get,
extra_training_logs),
writer=writer)
# Reset metric accumulation for next evaluation cycle.
train_metrics, extra_training_logs = [], []
chrono.resume()
################### EVALUATION #######################
if (step % log_eval_steps == 1) or (step == total_steps):
chrono.pause(wait_for=(train_state.params))
with report_progress.timed('eval'):
# Sync model state across replicas.
train_state = train_utils.sync_model_state_across_replicas(train_state)
eval_summary = evaluate(train_state, step, dataset.valid_iter,
dataset.meta_data['num_eval_examples'])
chrono.resume()
##################### CHECKPOINTING ############################
if ((step % checkpoint_steps == 1 and step > 1) or
(step == total_steps)) and config.checkpoint:
chrono.pause(wait_for=(train_state.params, train_state.opt_state))
with report_progress.timed('checkpoint'):
# Sync model state across replicas.
train_state = train_utils.sync_model_state_across_replicas(train_state)
if lead_host:
# Take the first replica.
unrep_train_state = jax_utils.unreplicate(train_state)
metadata = unrep_train_state.metadata
metadata['chrono'] = chrono.save()
          unrep_train_state = unrep_train_state.replace(  # pytype: disable=attribute-error
              metadata=metadata)
train_utils.save_checkpoint(workdir, unrep_train_state)
del unrep_train_state
chrono.resume() # Un-pause now.
##################### FEWSHOT EVALUATION ############################
if 'fewshot' in config:
# Compute few-shot on-the-fly evaluation.
if (step % config.fewshot.log_eval_steps == 1) or (step == total_steps):
chrono.pause(wait_for=(train_state.params))
with report_progress.timed('fewshot'):
results = fewshotter.run_all(train_state, config.fewshot.datasets)
fewshotter.log_fewshot_summary(
writer=writer, step=step, results=results)
del results
writer.write_scalars(step, {'zz/epoch': step / steps_per_epoch})
writer.flush()
chrono.resume() # Un-pause now.
########### FEWSHOT EVALUATION USING VIDEO DATASETS ###############
if 'video_fewshot' in config:
# Compute few-shot on-the-fly evaluation using video dataset.
if ((step % config.video_fewshot.log_eval_steps == 1) or
step == total_steps):
chrono.pause(wait_for=(train_state.params))
with report_progress.timed('video_fewshot'):
results = video_fewshotter.run_all(train_state,
config.video_fewshot.datasets)
video_fewshotter.log_fewshot_summary(
writer=writer, step=step, results=results)
del results
writer.write_scalars(step, {'zz/epoch': step / steps_per_epoch})
writer.flush()
chrono.resume() # Un-pause now.
##################### LINEAR-PROBE EVALUATION ##########################
if 'linear_probe' in config:
if (config.linear_probe.log_eval_steps > 0 and
step % config.linear_probe.log_eval_steps == 1) or (step
== total_steps):
chrono.pause(wait_for=(train_state.params))
with report_progress.timed('linear_probe'):
linear_probe.run_all(
train_state,
config.linear_probe.datasets,
writer=writer,
repr_step=step)
writer.flush()
chrono.resume() # Un-pause now.
# Wait until computations are done before exiting.
train_utils.barrier_across_hosts()
# Return the train and eval summary after last step for regression testing.
return train_state, train_summary, eval_summary