|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """Video classification task definition."""
|
| from typing import Any, Optional, List, Tuple
|
|
|
| from absl import logging
|
| import tensorflow as tf, tf_keras
|
| from official.core import base_task
|
| from official.core import task_factory
|
| from official.modeling import tf_utils
|
| from official.vision.configs import video_classification as exp_cfg
|
| from official.vision.dataloaders import input_reader_factory
|
| from official.vision.dataloaders import video_input
|
| from official.vision.modeling import factory_3d
|
| from official.vision.ops import augment
|
|
|
|
|
@task_factory.register_task_cls(exp_cfg.VideoClassificationTask)
class VideoClassificationTask(base_task.Task):
  """A task for video classification."""

  def _get_num_classes(self):
    """Gets the number of classes (taken from the train data config)."""
    return self.task_config.train_data.num_classes

  def _get_feature_shape(self):
    """Gets the common feature shape for train and eval.

    Dimensions on which the train and validation feature shapes agree are
    kept; dimensions on which they differ become `None` (unknown size).

    Returns:
      A list with one entry per feature dimension.

    Raises:
      ValueError: if the train and validation feature shapes have different
        ranks (`zip` would otherwise silently drop the extra dimensions).
    """
    train_shape = list(self.task_config.train_data.feature_shape)
    eval_shape = list(self.task_config.validation_data.feature_shape)
    if len(train_shape) != len(eval_shape):
      raise ValueError(
          'train_data.feature_shape {!r} and validation_data.feature_shape '
          '{!r} must have the same rank.'.format(train_shape, eval_shape))
    return [d1 if d1 == d2 else None for d1, d2 in zip(train_shape, eval_shape)]

  def _get_num_test_views(self):
    """Gets the number of views (clips x crops) per example at test time."""
    num_test_clips = self.task_config.validation_data.num_test_clips
    num_test_crops = self.task_config.validation_data.num_test_crops
    return num_test_clips * num_test_crops

  def _is_multilabel(self):
    """Returns whether labels are multi-label (from the train data config)."""
    return self.task_config.train_data.is_multilabel

  def build_model(self):
    """Builds video classification model."""
    common_input_shape = self._get_feature_shape()
    input_specs = tf_keras.layers.InputSpec(shape=[None] + common_input_shape)
    logging.info('Build model input %r', common_input_shape)

    l2_weight_decay = float(self.task_config.losses.l2_weight_decay)
    # The configured decay is halved, presumably so the penalty follows the
    # tf.nn.l2_loss convention (sum(w ** 2) / 2); tf_keras.regularizers.l2(c)
    # itself computes c * sum(w ** 2).
    l2_regularizer = (tf_keras.regularizers.l2(
        l2_weight_decay / 2.0) if l2_weight_decay else None)

    model = factory_3d.build_model(
        self.task_config.model.model_type,
        input_specs=input_specs,
        model_config=self.task_config.model,
        num_classes=self._get_num_classes(),
        l2_regularizer=l2_regularizer)

    if self.task_config.freeze_backbone:
      logging.info('Freezing model backbone.')
      model.backbone.trainable = False
    return model

  def initialize(self, model: tf_keras.Model):
    """Loads pretrained checkpoint.

    Does nothing when `init_checkpoint` is unset. A directory is resolved to
    its latest checkpoint. Depending on `init_checkpoint_modules`, either the
    whole model ('all') or only `model.backbone` ('backbone') is restored.

    Args:
      model: the model to initialize in place.

    Raises:
      ValueError: if `init_checkpoint` is a directory without any checkpoint,
        or if `init_checkpoint_modules` is neither 'all' nor 'backbone'.
    """
    if not self.task_config.init_checkpoint:
      return

    ckpt_dir_or_file = self.task_config.init_checkpoint
    if tf.io.gfile.isdir(ckpt_dir_or_file):
      ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
      # latest_checkpoint returns None when the directory has no checkpoint;
      # fail here with a clear message instead of inside ckpt.read(None).
      if ckpt_dir_or_file is None:
        raise ValueError('No checkpoint found in directory {!r}.'.format(
            self.task_config.init_checkpoint))

    modules = self.task_config.init_checkpoint_modules
    if modules == 'all':
      ckpt = tf.train.Checkpoint(model=model)
    elif modules == 'backbone':
      ckpt = tf.train.Checkpoint(backbone=model.backbone)
    else:
      raise ValueError(
          "Only 'all' or 'backbone' can be used to initialize the model.")

    status = ckpt.read(ckpt_dir_or_file)
    status.expect_partial().assert_existing_objects_matched()
    logging.info('Finished loading pretrained checkpoint from %s',
                 ckpt_dir_or_file)

  def _get_dataset_fn(self, params):
    """Returns the dataset constructor for `params.file_type`.

    Raises:
      ValueError: if the file type is not 'tfrecord'.
    """
    if params.file_type == 'tfrecord':
      return tf.data.TFRecordDataset
    raise ValueError('Unknown input file type {!r}'.format(params.file_type))

  def _get_decoder_fn(self, params):
    """Returns the decode function for the data described by `params`.

    A TFDS decoder is used when `params.tfds_name` is set, otherwise the
    default video decoder.

    Raises:
      ValueError: if audio output is enabled but no audio feature is named.
    """
    if params.tfds_name:
      decoder = video_input.VideoTfdsDecoder(
          image_key=params.image_field_key, label_key=params.label_field_key)
    else:
      decoder = video_input.Decoder(
          image_key=params.image_field_key, label_key=params.label_field_key)
    # Audio decoding is driven by train_data for both the train and the
    # validation pipelines. NOTE(review): presumably so both decode the same
    # features; confirm this is intended rather than `params.output_audio`.
    train_data = self.task_config.train_data
    if train_data.output_audio:
      if not train_data.audio_feature:
        raise ValueError('audio feature is empty')
      decoder.add_feature(train_data.audio_feature,
                          tf.io.VarLenFeature(dtype=tf.float32))
    return decoder.decode

  def build_inputs(self,
                   params: exp_cfg.DataConfig,
                   input_context: Optional[tf.distribute.InputContext] = None):
    """Builds classification input.

    Args:
      params: the data config to build the input pipeline for.
      input_context: optional distribution input context, forwarded to the
        reader.

    Returns:
      The dataset produced by the input reader.
    """
    parser = video_input.Parser(
        input_params=params,
        image_key=params.image_field_key,
        label_key=params.label_field_key)
    post_batch_processor = video_input.PostBatchProcessor(params)
    postprocess_fn = post_batch_processor
    if params.mixup_and_cutmix is not None:
      # Build the augmenter once, not on every trace of the mapped function.
      augmenter = augment.MixupAndCutmix(
          mixup_alpha=params.mixup_and_cutmix.mixup_alpha,
          cutmix_alpha=params.mixup_and_cutmix.cutmix_alpha,
          prob=params.mixup_and_cutmix.prob,
          label_smoothing=params.mixup_and_cutmix.label_smoothing,
          num_classes=self._get_num_classes())

      def _batch_process_and_mix(features, labels):
        # Apply the regular post-batch processing first so that enabling
        # mixup/cutmix does not silently skip PostBatchProcessor.
        features, labels = post_batch_processor(features, labels)
        features['image'], labels = augmenter(features['image'], labels)
        return features, labels

      postprocess_fn = _batch_process_and_mix

    reader = input_reader_factory.input_reader_generator(
        params,
        dataset_fn=self._get_dataset_fn(params),
        decoder_fn=self._get_decoder_fn(params),
        parser_fn=parser.parse_fn(params.is_training),
        postprocess_fn=postprocess_fn)

    return reader.read(input_context=input_context)

  def build_losses(self,
                   labels: Any,
                   model_outputs: Any,
                   aux_losses: Optional[Any] = None):
    """Builds the classification loss plus optional auxiliary losses.

    Multi-label tasks use binary cross entropy and additionally report the
    mean prediction entropy. Otherwise, categorical cross entropy (one-hot
    labels, optional label smoothing) or sparse categorical cross entropy is
    used. In every case the per-example class loss is averaged to a scalar.

    Args:
      labels: labels.
      model_outputs: class probabilities. Sigmoid/softmax is already applied
        by the callers; the losses here use `from_logits=False`.
      aux_losses: auxiliary loss tensors, i.e. `losses` in keras.Model.

    Returns:
      A dict of loss tensors; the entry under `self.loss` is the total loss.
    """
    all_losses = {}
    losses_config = self.task_config.losses
    entropy = None
    if self._is_multilabel():
      entropy = -tf.reduce_mean(
          tf.reduce_sum(model_outputs * tf.math.log(model_outputs + 1e-8), -1))
      per_example_loss = tf_keras.losses.binary_crossentropy(
          labels, model_outputs, from_logits=False)
    elif losses_config.one_hot:
      per_example_loss = tf_keras.losses.categorical_crossentropy(
          labels,
          model_outputs,
          from_logits=False,
          label_smoothing=losses_config.label_smoothing)
    else:
      per_example_loss = tf_keras.losses.sparse_categorical_crossentropy(
          labels, model_outputs, from_logits=False)

    # Reduce to a scalar in every branch. An unreduced per-example loss would
    # make the gradient a sum over the batch and broadcast the regularization
    # term onto every example.
    total_loss = tf_utils.safe_mean(per_example_loss)
    all_losses['class_loss'] = total_loss
    if entropy is not None:
      all_losses['entropy'] = entropy
    if aux_losses:
      all_losses['reg_loss'] = aux_losses
      total_loss += tf.add_n(aux_losses)
    all_losses[self.loss] = total_loss

    return all_losses

  def build_metrics(self, training: bool = True):
    """Gets streaming metrics for training/validation.

    Args:
      training: unused; the same metrics are built for training and eval.

    Returns:
      A list of keras metrics. Categorical variants are used for one-hot
      labels (plus AUC/recall metrics when multi-label), sparse variants
      otherwise.
    """
    if not self.task_config.losses.one_hot:
      return [
          tf_keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
          tf_keras.metrics.SparseTopKCategoricalAccuracy(
              k=1, name='top_1_accuracy'),
          tf_keras.metrics.SparseTopKCategoricalAccuracy(
              k=5, name='top_5_accuracy')
      ]

    metrics = [
        tf_keras.metrics.CategoricalAccuracy(name='accuracy'),
        tf_keras.metrics.TopKCategoricalAccuracy(k=1, name='top_1_accuracy'),
        tf_keras.metrics.TopKCategoricalAccuracy(k=5, name='top_5_accuracy')
    ]
    if self._is_multilabel():
      metrics.append(
          tf_keras.metrics.AUC(curve='ROC', multi_label=True, name='ROC-AUC'))
      metrics.append(
          tf_keras.metrics.RecallAtPrecision(
              0.95, name='RecallAtPrecision95'))
      metrics.append(
          tf_keras.metrics.AUC(curve='PR', multi_label=True, name='PR-AUC'))
      if self.task_config.metrics.use_per_class_recall:
        metrics.extend(
            tf_keras.metrics.Recall(class_id=i, name=f'recall-{i}')
            for i in range(self._get_num_classes()))
    return metrics

  def process_metrics(self, metrics: List[Any], labels: Any,
                      model_outputs: Any):
    """Process and update metrics.

    Called when using custom training loop API.

    Args:
      metrics: a nested structure of metrics objects. The return of function
        self.build_metrics.
      labels: a tensor or a nested structure of tensors.
      model_outputs: a tensor or a nested structure of tensors. For example,
        output of the keras model built by self.build_model.
    """
    for metric in metrics:
      metric.update_state(labels, model_outputs)

  def train_step(self,
                 inputs: Tuple[Any, Any],
                 model: tf_keras.Model,
                 optimizer: tf_keras.optimizers.Optimizer,
                 metrics: Optional[List[Any]] = None):
    """Does forward and backward.

    Args:
      inputs: a `(features, labels)` tuple; `features` is a dict with an
        'image' entry.
      model: the model, forward pass definition.
      optimizer: the optimizer for this training step.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs (losses plus metric results).
    """
    features, labels = inputs
    input_partition_dims = self.task_config.train_input_partition_dims
    if input_partition_dims:
      strategy = tf.distribute.get_strategy()
      features['image'] = strategy.experimental_split_to_logical_devices(
          features['image'], input_partition_dims)

    num_replicas = tf.distribute.get_strategy().num_replicas_in_sync
    is_loss_scale_optimizer = isinstance(
        optimizer, tf_keras.mixed_precision.LossScaleOptimizer)
    with tf.GradientTape() as tape:
      outputs = model(features, training=True)
      # Cast outputs to float32 so losses and metrics are computed in full
      # precision even if the model runs under a mixed-precision policy.
      outputs = tf.nest.map_structure(
          lambda x: tf.cast(x, tf.float32), outputs)

      # The losses use from_logits=False, so convert logits to probabilities.
      if self._is_multilabel():
        outputs = tf.nest.map_structure(tf.math.sigmoid, outputs)
      else:
        outputs = tf.nest.map_structure(tf.math.softmax, outputs)
      all_losses = self.build_losses(
          model_outputs=outputs, labels=labels, aux_losses=model.losses)
      loss = all_losses[self.loss]

      # Per-replica loss is divided by the replica count, presumably because
      # the default cross-replica gradient all-reduce sums the gradients.
      scaled_loss = loss / num_replicas

      # With a LossScaleOptimizer, scale the loss up for numerical stability
      # of reduced-precision gradients.
      if is_loss_scale_optimizer:
        scaled_loss = optimizer.get_scaled_loss(scaled_loss)

    tvars = model.trainable_variables
    grads = tape.gradient(scaled_loss, tvars)
    # Undo the loss scaling before applying the gradients.
    if is_loss_scale_optimizer:
      grads = optimizer.get_unscaled_gradients(grads)
    optimizer.apply_gradients(list(zip(grads, tvars)))

    logs = all_losses
    if metrics:
      self.process_metrics(metrics, labels, outputs)
      logs.update({m.name: m.result() for m in metrics})
    elif model.compiled_metrics:
      self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
      logs.update({m.name: m.result() for m in model.metrics})
    return logs

  def validation_step(self,
                      inputs: Tuple[Any, Any],
                      model: tf_keras.Model,
                      metrics: Optional[List[Any]] = None):
    """Validation step.

    Args:
      inputs: a `(features, labels)` tuple; `features` is a dict with an
        'image' entry.
      model: the keras.Model.
      metrics: a nested structure of metrics objects.

    Returns:
      A dictionary of logs (losses plus metric results).
    """
    features, labels = inputs
    input_partition_dims = self.task_config.eval_input_partition_dims
    if input_partition_dims:
      strategy = tf.distribute.get_strategy()
      features['image'] = strategy.experimental_split_to_logical_devices(
          features['image'], input_partition_dims)

    outputs = self.inference_step(features, model)
    outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
    logs = self.build_losses(model_outputs=outputs, labels=labels,
                             aux_losses=model.losses)

    if metrics:
      self.process_metrics(metrics, labels, outputs)
      logs.update({m.name: m.result() for m in metrics})
    elif model.compiled_metrics:
      self.process_compiled_metrics(model.compiled_metrics, labels, outputs)
      logs.update({m.name: m.result() for m in model.metrics})
    return logs

  def inference_step(self, features: tf.Tensor, model: tf_keras.Model):
    """Performs the forward step and returns class probabilities.

    Applies sigmoid (multi-label) or softmax to the model outputs. When more
    than one test view is configured, the probabilities of each example's
    views are averaged.
    """
    outputs = model(features, training=False)
    if self._is_multilabel():
      outputs = tf.nest.map_structure(tf.math.sigmoid, outputs)
    else:
      outputs = tf.nest.map_structure(tf.math.softmax, outputs)
    num_test_views = self._get_num_test_views()
    if num_test_views > 1:
      # Average probabilities across views. The reshape assumes the batch
      # holds each example's `num_test_views` views as consecutive rows
      # (NOTE(review): confirm against the eval input pipeline).
      outputs = tf.reshape(outputs, [-1, num_test_views, outputs.shape[-1]])
      outputs = tf.reduce_mean(outputs, axis=1)
    return outputs
|
|
|