| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """Training Script.""" |
|
|
| import copy |
| import functools |
| from typing import Any, Callable, Dict, Iterator, Optional, Tuple, Type, Union |
|
|
| from absl import logging |
| from clu import metric_writers |
| from clu import periodic_actions |
| from flax import jax_utils |
| import flax.linen as nn |
| import jax |
| from jax.example_libraries.optimizers import clip_grads |
| import jax.numpy as jnp |
| import jax.profiler |
| import ml_collections |
| import numpy as np |
| import optax |
| from scenic.common_lib import debug_utils |
| from scenic.dataset_lib import dataset_utils |
| from scenic.model_lib.base_models import base_model |
| from scenic.projects.av_mae import evaluation_lib |
| from scenic.projects.av_mae import optimizer_utils |
| from scenic.projects.av_mae import train_utils as avmae_train_utils |
| from scenic.projects.av_mae import trainer_multimodal |
| from scenic.projects.vivit import train_utils as vivit_train_utils |
| from scenic.train_lib import lr_schedules |
| from scenic.train_lib import optimizers |
| from scenic.train_lib import pretrain_utils |
| from scenic.train_lib import train_utils |
| from scenic.train_lib.transfer import fewshot_utils |
|
|
|
|
| |
| Batch = Dict[str, jnp.ndarray] |
| MetricFn = Callable[[jnp.ndarray, Dict[str, jnp.ndarray]], |
| Dict[str, Tuple[float, int]]] |
| LossFn = Callable[[jnp.ndarray, Batch, Optional[jnp.ndarray]], float] |
|
|
|
|
| def train_step( |
| train_state: train_utils.TrainState, |
| batch: Batch, |
| *, |
| flax_model: nn.Module, |
| loss_fn: LossFn, |
| metrics_fn: MetricFn, |
| config: ml_collections.ConfigDict, |
| debug: Optional[bool] = False |
| ) -> Tuple[train_utils.TrainState, Dict[str, Tuple[float, int]]]: |
| """Runs a single step of training. |
| |
| Given the state of the training and a batch of data, computes |
| the loss and updates the parameters of the model. |
| |
| Note that in this code, the buffers of the first (train_state) and second |
| (batch) arguments are donated to the computation. |
| |
| Args: |
| train_state: The state of training including the current global_step, |
| model_state, rng, and optimizer. The buffer of this argument can be |
| donated to the computation. |
| batch: A single batch of data. The buffer of this argument can be donated to |
| the computation. |
| flax_model: A Flax model. |
| loss_fn: A loss function that given logits, a batch, and parameters of the |
| model calculates the loss. |
| metrics_fn: A metrics function that given logits and batch of data, |
| calculates the metrics as well as the loss. |
| config: Configurations of the experiment. |
| debug: Whether the debug mode is enabled during training. `debug=True` |
| enables model specific logging/storing some values using |
| jax.host_callback. |
| |
| Returns: |
| Updated state of training, computed metrics, and learning rate for logging. |
| """ |
| new_rng, rng = jax.random.split(train_state.rng) |
|
|
| if config.get('mixup') and config.mixup.alpha: |
| mixup_rng, rng = jax.random.split(rng, 2) |
| mixup_rng = train_utils.bind_rng_to_host_device( |
| mixup_rng, |
| axis_name='batch', |
| bind_to=config.mixup.get('bind_to', 'device')) |
| image_format = config.mixup.get('image_format', 'NTHWC') |
| batch_first = True |
| if image_format.index('N') > 0: |
| batch_first = False |
| batch = avmae_train_utils.mixup_modalities( |
| batch, |
| config.mixup.alpha, |
| batch_first, |
| mixmod=config.get('mixmod', False), |
| rng=mixup_rng) |
| else: |
| |
| if config.get('labels_as_dict', True): |
| labels = batch['label'] |
| batch['label'] = {} |
| for modality in batch['inputs']: |
| batch['label'][modality] = labels |
| batch['label']['all'] = labels |
|
|
| |
| dropout_rng = train_utils.bind_rng_to_host_device( |
| rng, axis_name='batch', bind_to='device') |
|
|
| def training_loss_fn(params): |
| variables = {'params': params, **train_state.model_state} |
| logits, new_model_state = flax_model.apply( |
| variables, |
| batch['inputs'], |
| mutable=['batch_stats'], |
| train=True, |
| rngs={'dropout': dropout_rng}, |
| debug=debug) |
| loss = loss_fn(logits, batch, variables['params']) |
| return loss, (new_model_state, logits) |
|
|
| compute_gradient_fn = jax.value_and_grad(training_loss_fn, has_aux=True) |
| (_, (new_model_state, |
| logits)), grad = compute_gradient_fn(train_state.params) |
|
|
| if isinstance(logits, dict): |
| |
| |
| modality = list(logits.keys())[0] |
| batch['label'] = batch['label'][modality] |
| metrics = metrics_fn(logits[modality], batch) |
| else: |
| metrics = metrics_fn(logits, batch) |
|
|
| if not config.get('grad_clip_after_pmean', True): |
| metrics['max_grad_norm_preclip_before_pmean'] = ( |
| avmae_train_utils.compute_max_norm(grad), 1) |
| if config.get('max_grad_norm', None) is not None: |
| grad = clip_grads(grad, config.max_grad_norm) |
| metrics['max_grad_norm_postclip_before_pmean'] = ( |
| avmae_train_utils.compute_max_norm(grad), 1) |
|
|
| |
| grad = jax.lax.pmean(grad, axis_name='batch') |
|
|
| if config.get('grad_clip_after_pmean', True): |
| metrics['max_grad_norm_preclip_after_pmean'] = ( |
| avmae_train_utils.compute_max_norm(grad), 1) |
| if config.get('max_grad_norm', None) is not None: |
| grad = clip_grads(grad, config.max_grad_norm) |
| metrics['max_grad_norm_postclip_after_pmean'] = ( |
| avmae_train_utils.compute_max_norm(grad), 1) |
|
|
| |
| |
| |
| updates, new_opt_state = train_state.tx.update(grad, train_state.opt_state, |
| train_state.params) |
| new_params = optax.apply_updates(train_state.params, updates) |
|
|
| |
| |
| metrics['l2_grads'] = (jnp.sqrt( |
| sum([jnp.vdot(g, g) for g in jax.tree_util.tree_leaves(grad)])), 1) |
| metrics['l2_params'] = (jnp.sqrt( |
| sum([jnp.vdot(p, p) for p in jax.tree_util.tree_leaves(new_params)])), 1) |
| metrics['l2_updates'] = (jnp.sqrt( |
| sum([jnp.vdot(u, u) for u in jax.tree_util.tree_leaves(updates)])), 1) |
|
|
| new_train_state = train_state.replace( |
| global_step=train_state.global_step + 1, |
| opt_state=new_opt_state, |
| params=new_params, |
| model_state=new_model_state, |
| rng=new_rng) |
| return new_train_state, metrics |
|
|
|
|
| def eval_step( |
| train_state: train_utils.TrainState, |
| batch: Batch, |
| *, |
| flax_model: nn.Module, |
| metrics_fn: MetricFn, |
| return_logits_and_labels: bool = False, |
| debug: Optional[bool] = False, |
| ) -> Union[ |
| Tuple[Dict[str, Tuple[float, int]], jnp.ndarray, jnp.ndarray], |
| Tuple[Dict[str, Tuple[float, int]], jnp.ndarray], |
| ]: |
| """Runs a single step of validation. |
| |
| Note that in this code, the buffer of the second argument (batch) is donated |
| to the computation. |
| |
| Assumed API of metrics_fn is: |
| ```metrics = metrics_fn(logits, batch) |
| where batch is yielded by the batch iterator, and metrics is a dictionary |
| mapping metric name to a vector of per example measurements. eval_step will |
| aggregate (by summing) all per example measurements and divide by the |
| aggregated normalizers. For each given metric we compute: |
| 1/N sum_{b in batch_iter} metric(b), where N is the sum of normalizer |
| over all batches. |
| |
| Args: |
| train_state: TrainState, the state of training including the current |
| global_step, model_state, rng, and optimizer. The buffer of this argument |
| can be donated to the computation. |
| batch: A single batch of data. a metrics function, that given logits and |
| batch of data, calculates the metrics as well as the loss. |
| flax_model: A Flax model. |
| metrics_fn: A metrics function, that given logits and batch of data, |
| calculates the metrics as well as the loss. |
| return_logits_and_labels: If true, returns logits and labels. Can be used |
| for calculating mean Average Precision for multi-label problems. |
| debug: Whether the debug mode is enabled during evaluation. `debug=True` |
| enables model specific logging/storing some values using |
| jax.host_callback. |
| |
| Returns: |
| Calculated metrics and logits. |
| """ |
| variables = { |
| 'params': train_state.params, |
| **train_state.model_state |
| } |
| logits = flax_model.apply( |
| variables, batch['inputs'], train=False, mutable=False, debug=debug) |
| metrics = metrics_fn(logits, batch) |
| |
| |
| |
| |
| |
| |
| if return_logits_and_labels: |
| logits = jax.lax.all_gather(logits, 'batch') |
| labels = jax.lax.all_gather(batch['label'], 'batch') |
| return metrics, logits, labels |
| return metrics, logits |
|
|
|
|
| def representation_fn( |
| train_state: train_utils.TrainState, |
| batch: Batch, |
| *, |
| flax_model: nn.Module, |
| representation_layer: str, |
| gather_to_host: bool = True |
| ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: |
| """Feeds the inputs to the model and returns their representations. |
| |
| Args: |
| train_state: TrainState, the state of training including the current |
| global_step, model_state, rng, and optimizer. The buffer of this argument |
| can be donated to the computation. |
| batch: A single batch of data from the dataset. |
| flax_model: A Flax model. |
| representation_layer: The name of the layer to use as the representation. |
| gather_to_host: Whether to gather results from all devices to the host, |
| rather than leaving them distributed. |
| |
| Returns: |
| Representation learned by the model for the given inputs and the labels and |
| masks. If `gather_to_host` is True, these are collected from all hosts. |
| """ |
| variables = { |
| 'params': train_state.params, |
| **train_state.model_state |
| } |
|
|
| representation_layer_parts = representation_layer.split('/') |
| filter_rep = lambda mdl, _: mdl.name == representation_layer_parts[-1] |
| _, model_state = flax_model.apply( |
| variables, |
| batch['inputs'], |
| train=False, |
| capture_intermediates=filter_rep, |
| mutable=['intermediates'], |
| debug=False) |
| if 'intermediates' not in model_state: |
| raise ValueError(f'Layer with name "{representation_layer}"' |
| ' does not exist in your model.') |
|
|
| representation = model_state['intermediates'] |
| for rep_layer in representation_layer_parts: |
| if rep_layer: |
| representation = representation[rep_layer] |
| representation = representation['__call__'][0] |
| if gather_to_host: |
| representation = jax.lax.all_gather(representation, 'batch') |
| batch = jax.lax.all_gather(batch, 'batch') |
| return representation, batch['label'], batch['batch_mask'] |
|
|
|
|
| def train( |
| *, |
| rng: jnp.ndarray, |
| config: ml_collections.ConfigDict, |
| model_cls: Type[base_model.BaseModel], |
| dataset: dataset_utils.Dataset, |
| workdir: str, |
| writer: metric_writers.MetricWriter): |
| """Main training loop lives in this function. |
| |
| Given the model class and dataset, it prepares the items needed to run the |
| training, including the TrainState. |
| |
| Args: |
| rng: Jax rng key. |
| config: Configurations of the experiment. |
| model_cls: Model class; A model has a flax_module, a loss_fn, and a |
| metrics_fn associated with it. |
| dataset: The dataset that has train_iter, eval_iter, meta_data, and |
| optionally, test_iter. |
| workdir: Directory for checkpointing. |
| writer: CLU metrics writer instance. |
| |
| Returns: |
| train_sate that has the state of training (including current global_step, |
| model_state, rng, and the optimizer), train_summary and eval_summary which |
| are dict of metrics (from the last evaluation and train metric logging |
| respectively). These outputs are used for regression testing. |
| """ |
| lead_host = jax.process_index() == 0 |
| |
| model = model_cls(config, dataset.meta_data) |
|
|
| is_multilabel_model = config.model_name in { |
| 'vit_multilabel_classification_mae', 'vit_multilabel_classification', |
| 'vivit_multimodal_multilabel_classification', |
| 'mbt_multilabel_classification' |
| } |
| logging.info('is_multilabel_model: %s', is_multilabel_model) |
|
|
| |
| rng, params_init_rng, dropout_init_rng = jax.random.split(rng, num=3) |
| init_rngs = {'params': params_init_rng, 'dropout': dropout_init_rng} |
| input_spec_dict = {} |
| for key in config.dataset_configs.modalities: |
| if isinstance(dataset.meta_data['input_dtype'], dict): |
| dtype = dataset.meta_data['input_dtype'][key] |
| else: |
| dtype = dataset.meta_data['input_dtype'] |
| input_spec = (dataset.meta_data['input_shape'][key], dtype) |
| input_spec_dict[key] = input_spec |
|
|
| (params, model_state, num_trainable_params, |
| gflops) = trainer_multimodal.initialize_model( |
| model_def=model.flax_model, |
| input_spec_dict=input_spec_dict, |
| config=config, |
| rngs=init_rngs, |
| is_train=True) |
|
|
| |
| lr_fn = lr_schedules.get_learning_rate_fn(config) |
| if 'layerwise_decay' in config.optimizer_configs: |
| tx = optimizer_utils.optimizer_with_layerwise_decay(config, params) |
| else: |
| optimizer_config = optimizers.get_optax_optimizer_config(config) |
| tx = optimizers.get_optimizer(optimizer_config, lr_fn, params) |
| |
| |
| opt_state = jax.jit(tx.init, backend='cpu')(params) |
|
|
| |
| chrono = train_utils.Chrono() |
|
|
| rng, train_rng = jax.random.split(rng) |
| train_state = train_utils.TrainState( |
| global_step=0, |
| opt_state=opt_state, |
| tx=tx, |
| params=params, |
| model_state=model_state, |
| rng=train_rng, |
| metadata={'chrono': chrono.save()}) |
| start_step = train_state.global_step |
| if config.checkpoint: |
| train_state, start_step = train_utils.restore_checkpoint( |
| workdir, train_state) |
| logging.info('Parameter summary after restoring checkpoint') |
| debug_utils.log_param_shapes(train_state.params) |
|
|
| chrono.load(train_state.metadata['chrono']) |
| train_state = train_state.replace(metadata={}) |
| if (start_step == 0 |
| and config.get('init_from') is not None): |
| restored_model_cfg = config.init_from.get('model_config') |
| init_checkpoint_path = config.init_from.get('checkpoint_path') |
|
|
| if init_checkpoint_path is not None: |
| checkpoint_format = config.init_from.get('checkpoint_format', 'scenic') |
| if checkpoint_format == 'scenic': |
| restored_train_state = pretrain_utils.restore_pretrained_checkpoint( |
| init_checkpoint_path, train_state, assert_exist=True) |
| elif checkpoint_format == 'big_vision': |
| restored_train_state = ( |
| pretrain_utils.convert_big_vision_to_scenic_checkpoint( |
| init_checkpoint_path, train_state |
| ) |
| ) |
| |
| |
| restored_model_cfg = copy.deepcopy(config) |
| |
| |
| |
| restored_model_cfg.model.classifier = config.init_from.get( |
| 'classifier_type', 'token') |
|
|
| train_state = model.init_from_train_state(train_state, |
| restored_train_state, |
| restored_model_cfg) |
| |
| del restored_train_state |
| logging.info('Parameters after initialising weights from checkpoint.') |
| debug_utils.log_param_shapes(train_state.params) |
|
|
| |
| train_state = jax_utils.replicate(train_state) |
| del params |
|
|
| |
| total_steps, steps_per_epoch = train_utils.get_num_training_steps( |
| config, dataset.meta_data) |
|
|
| train_step_pmapped = jax.pmap( |
| functools.partial( |
| train_step, |
| flax_model=model.flax_model, |
| loss_fn=model.loss_function, |
| metrics_fn=model.get_metrics_fn('train'), |
| config=config, |
| debug=config.debug_train), |
| axis_name='batch', |
| |
| donate_argnums=(0, 1), |
| ) |
| eval_step_pmapped = jax.pmap( |
| functools.partial( |
| eval_step, |
| flax_model=model.flax_model, |
| metrics_fn=model.get_metrics_fn('validation'), |
| return_logits_and_labels=is_multilabel_model, |
| debug=config.debug_eval), |
| axis_name='batch', |
| |
| donate_argnums=(1,), |
| ) |
|
|
| if 'fewshot' in config: |
| repr_fn = functools.partial( |
| representation_fn, |
| flax_model=model.flax_model, |
| representation_layer=config.fewshot.representation_layer) |
|
|
| if config.model_name.startswith('mvit'): |
| is_2d_model = len(config.model.patch_size) == 2 |
| else: |
| is_2d_model = len(config.model.patches.size) == 2 |
|
|
| if is_2d_model: |
| fewshotter = fewshot_utils.FewShotEvaluator(repr_fn, config.fewshot) |
| else: |
| fewshotter = fewshot_utils.FewShotEvaluatorVideo(repr_fn, config.fewshot) |
|
|
| if config.dataset_configs.get('do_multicrop_test'): |
| log_test_steps = int( |
| steps_per_epoch * config.dataset_configs.log_test_epochs) |
|
|
| test_step_pmapped = jax.pmap( |
| functools.partial( |
| avmae_train_utils.test_step_multimodal, |
| flax_model=model.flax_model, |
| metrics_fn=model.get_metrics_fn('test'), |
| n_clips=config.get('multicrop_clips_per_device', 2), |
| return_logits_and_labels=is_multilabel_model, |
| debug=config.debug_eval), |
| axis_name='batch', |
| |
| donate_argnums=(1,), |
| ) |
|
|
| if config.dataset_configs.test_batch_size != jax.local_device_count(): |
| raise ValueError( |
| 'The per-host batch size must be equal to the number of local devices' |
| 'This ensures that each TPU device is processing different views of' |
| 'the same original video. Got ' |
| f'{config.dataset_configs.test_batch_size} vs' |
| f'{jax.local_device_count()}.') |
|
|
| total_test_steps = int( |
| np.ceil(dataset.meta_data['num_test_examples'] / |
| (config.get('dataset_configs.test_batch_size') * |
| config.get('dataset_configs.num_test_clips') * |
| jax.process_count()))) |
| steps_per_test = config.get('steps_per_test') or total_test_steps |
|
|
| log_eval_steps = config.get('log_eval_steps') or steps_per_epoch |
| if not log_eval_steps: |
| raise ValueError("'log_eval_steps' should be specified in the config.") |
| checkpoint_steps = config.get('checkpoint_steps') or log_eval_steps |
| log_summary_steps = config.get('log_summary_steps') or log_eval_steps |
|
|
| def evaluate(train_state: train_utils.TrainState, step: int, |
| valid_iter: Iterator[Batch], |
| num_valid_ex: int, |
| compute_map: bool = False) -> Dict[str, Any]: |
| """Perform validation and log results, possibly including mAP. |
| """ |
|
|
| eval_summary = {} |
| if not isinstance(valid_iter, dict): |
| valid_iter, num_valid_ex = {'valid': valid_iter}, {'valid': num_valid_ex} |
|
|
| for val_name, val_iter in valid_iter.items(): |
| num_ex = num_valid_ex[val_name] |
| |
| eval_batch_size = config.get('eval_batch_size', config.batch_size) |
| total_eval_steps = int(np.ceil(num_ex / eval_batch_size)) |
| steps_per_eval = config.get('steps_per_eval') or total_eval_steps |
| eval_metrics = [] |
| additional_summary = None |
| if compute_map: |
| eval_logits = [] |
| eval_labels = [] |
| n_classes = dataset.meta_data['num_classes'] |
|
|
| for _ in range(steps_per_eval): |
| eval_batch = next(val_iter) |
| if dataset.meta_data['target_is_onehot']: |
| |
| eval_batch['batch_mask'] *= eval_batch['label'].max(axis=-1) |
|
|
| |
| if not compute_map: |
| |
| e_metrics, _ = eval_step_pmapped(train_state, eval_batch) |
| else: |
| |
| e_metrics = eval_step_pmapped(train_state, eval_batch) |
| e_metrics, logits_batch, labels_batch = e_metrics |
| |
| |
| |
|
|
| |
| |
| |
| logits_batch_in_cpu = vivit_train_utils.to_cpu(logits_batch) |
| labels_batch_in_cpu = vivit_train_utils.to_cpu(labels_batch) |
| eval_logits.append(logits_batch_in_cpu) |
| eval_labels.append(labels_batch_in_cpu) |
| eval_metrics.append(train_utils.unreplicate_and_get(e_metrics)) |
|
|
| if compute_map: |
| |
| additional_summary = evaluation_lib.compute_mean_avg_precision_dprime( |
| np.concatenate(eval_logits, axis=0), |
| np.concatenate(eval_labels, axis=0), |
| return_per_class_ap=n_classes < 10) |
|
|
| eval_summary.update( |
| train_utils.log_eval_summary( |
| step=step, |
| eval_metrics=eval_metrics, |
| extra_eval_summary=additional_summary, |
| writer=writer, |
| key_separator='/', |
| prefix=val_name)) |
| del eval_metrics |
| writer.flush() |
| return eval_summary |
|
|
| train_metrics, extra_training_logs = [], [] |
| train_summary, eval_summary = None, None |
| chrono.inform(start_step, total_steps, config.batch_size, steps_per_epoch) |
| logging.info('Starting training loop at step %d.', start_step) |
|
|
| report_progress = periodic_actions.ReportProgress( |
| num_train_steps=total_steps, writer=writer) |
| hooks = [] |
| if lead_host: |
| hooks.append(report_progress) |
| if config.get('xprof', True) and lead_host: |
| hooks.append(periodic_actions.Profile(num_profile_steps=5, logdir=workdir)) |
|
|
| if start_step == 0: |
| step0_log = {'num_trainable_params': num_trainable_params} |
| if gflops: |
| step0_log['gflops'] = gflops |
| writer.write_scalars(1, step0_log) |
|
|
| for step in range(start_step + 1, total_steps + 1): |
| with jax.profiler.StepTraceAnnotation('train', step_num=step): |
| train_batch = next(dataset.train_iter) |
|
|
|
|
| train_state, t_metrics = train_step_pmapped(train_state, train_batch) |
| train_metrics.append(t_metrics) |
| |
| extra_training_logs.append({'learning_rate': lr_fn(step)}) |
|
|
| |
| logging.log_first_n(logging.INFO, 'Finished training step %d.', 5, step) |
| for h in hooks: |
| h(step) |
|
|
| |
| if ((step % log_summary_steps == 1) or (step == total_steps) or |
| (chrono.warmup and lead_host)): |
| chrono.pause(wait_for=(train_metrics)) |
| if lead_host: |
| chrono.tick(step, writer, avmae_train_utils.log_note) |
| train_summary = train_utils.log_train_summary( |
| step=step, |
| train_metrics=jax.tree_util.tree_map(train_utils.unreplicate_and_get, |
| train_metrics), |
| extra_training_logs=extra_training_logs, |
| writer=writer) |
| |
| train_metrics, extra_training_logs = [], [] |
| chrono.resume() |
|
|
| |
| if (step % log_eval_steps == 1) or (step == total_steps): |
| chrono.pause(wait_for=(train_state.params)) |
| with report_progress.timed('eval'): |
| |
| train_state = train_utils.sync_model_state_across_replicas(train_state) |
| logging.info('Starting validation') |
| eval_summary = evaluate(train_state, step, dataset.valid_iter, |
| dataset.meta_data['num_eval_examples'], |
| is_multilabel_model) |
| chrono.resume() |
|
|
| |
| if ((step % checkpoint_steps == 1 and step > 1) or |
| (step == total_steps)) and config.checkpoint: |
| chrono.pause(wait_for=(train_state.params, train_state.opt_state)) |
| with report_progress.timed('checkpoint'): |
| train_utils.handle_checkpointing(train_state, chrono, workdir) |
| chrono.resume() |
|
|
| |
| if 'fewshot' in config: |
| |
| if (step % config.fewshot.log_eval_steps == 1) or (step == total_steps): |
| chrono.pause(wait_for=(train_state.params)) |
| with report_progress.timed('fewshot'): |
| results = fewshotter.run_all(train_state, config.fewshot.datasets) |
| fewshotter.log_fewshot_summary( |
| writer=writer, step=step, results=results) |
| del results |
| writer.write_scalars(step, {'zz/epoch': step / steps_per_epoch}) |
| writer.flush() |
| chrono.resume() |
|
|
| |
| if (config.dataset_configs.get('do_multicrop_test') and |
| ((step % log_test_steps == 1) or step == total_steps)): |
|
|
| chrono.pause(wait_for=(train_state.params)) |
| with report_progress.timed('test'): |
| test_metrics = [] |
| additional_summary = None |
| if is_multilabel_model: |
| all_test_logits, all_test_labels = [], [] |
| |
| train_state = train_utils.sync_model_state_across_replicas(train_state) |
|
|
| |
| if step == total_steps: |
| steps_per_test = total_test_steps |
|
|
| logging.info('Starting multicrop test') |
| for _ in range(steps_per_test): |
| test_batch = next(dataset.test_iter) |
|
|
| if not is_multilabel_model: |
| |
| t_metrics = test_step_pmapped(train_state, test_batch) |
| else: |
| |
| t_metrics, t_logits, t_labels = test_step_pmapped( |
| train_state, test_batch) |
| |
| |
| t_logits_in_cpu = vivit_train_utils.to_cpu(t_logits) |
| t_labels_in_cpu = vivit_train_utils.to_cpu(t_labels) |
| all_test_logits.append(t_logits_in_cpu) |
| all_test_labels.append(t_labels_in_cpu) |
|
|
| |
| test_metrics.append(train_utils.unreplicate_and_get(t_metrics)) |
|
|
| if is_multilabel_model: |
| |
| all_test_logits_concat = np.concatenate(all_test_logits, axis=0) |
| all_test_labels_concat = np.concatenate(all_test_labels, axis=0) |
| additional_summary = evaluation_lib.compute_mean_avg_precision_dprime( |
| all_test_logits_concat, |
| all_test_labels_concat, |
| return_per_class_ap=dataset.meta_data['num_classes'] < 10) |
|
|
| |
| train_utils.log_eval_summary( |
| step=step, |
| eval_metrics=test_metrics, |
| extra_eval_summary=additional_summary, |
| writer=writer, |
| key_separator='/', |
| prefix='test') |
| logging.info('Completed multicrop test') |
| del test_metrics |
| writer.flush() |
| chrono.resume() |
|
|
| |
| train_utils.barrier_across_hosts() |
| logging.info('Parameter summary after completing training.') |
| |
| unrep_train_state = jax_utils.unreplicate(train_state) |
| debug_utils.log_param_shapes(unrep_train_state.params) |
| |
| return train_state, train_summary, eval_summary |
|
|