"""Model utilities.""" |
|
|
|
|
|
from __future__ import absolute_import |
|
|
from __future__ import division |
|
|
from __future__ import print_function |
|
|
|
|
|
import json |
|
|
import os |
|
|
import sys |
|
|
|
|
|
from absl import logging |
|
|
import numpy as np |
|
|
import tensorflow.compat.v1 as tf |
|
|
|
|
|
from tensorflow.python.tpu import tpu_function |
|
|
|
|
|
|
|
|


def build_learning_rate(initial_lr,
                        global_step,
                        steps_per_epoch=None,
                        lr_decay_type='exponential',
                        decay_factor=0.97,
                        decay_epochs=2.4,
                        total_steps=None,
                        warmup_epochs=5):
  """Build learning rate."""
  if lr_decay_type == 'exponential':
    assert steps_per_epoch is not None
    decay_steps = steps_per_epoch * decay_epochs
    lr = tf.train.exponential_decay(
        initial_lr, global_step, decay_steps, decay_factor, staircase=True)
  elif lr_decay_type == 'cosine':
    assert total_steps is not None
    lr = 0.5 * initial_lr * (
        1 + tf.cos(np.pi * tf.cast(global_step, tf.float32) / total_steps))
  elif lr_decay_type == 'constant':
    lr = initial_lr
  else:
    assert False, 'Unknown lr_decay_type: %s' % lr_decay_type

  if warmup_epochs:
    # Linearly ramp up from 0 to initial_lr over the first warmup_epochs.
    # Note that warmup requires steps_per_epoch even for cosine/constant decay.
    logging.info('Learning rate warmup_epochs: %d', warmup_epochs)
    warmup_steps = int(warmup_epochs * steps_per_epoch)
    warmup_lr = (
        initial_lr * tf.cast(global_step, tf.float32) / tf.cast(
            warmup_steps, tf.float32))
    lr = tf.cond(global_step < warmup_steps, lambda: warmup_lr, lambda: lr)

  return lr
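

# A minimal usage sketch (hypothetical values; in a real training loop the
# global step and steps_per_epoch come from the caller's setup):
#
#   global_step = tf.train.get_or_create_global_step()
#   lr = build_learning_rate(
#       initial_lr=0.016, global_step=global_step, steps_per_epoch=1251,
#       lr_decay_type='exponential', decay_factor=0.97, decay_epochs=2.4)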


def build_optimizer(learning_rate,
                    optimizer_name='rmsprop',
                    decay=0.9,
                    epsilon=0.001,
                    momentum=0.9):
  """Build optimizer."""
  if optimizer_name == 'sgd':
    logging.info('Using SGD optimizer')
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
  elif optimizer_name == 'momentum':
    logging.info('Using Momentum optimizer')
    optimizer = tf.train.MomentumOptimizer(
        learning_rate=learning_rate, momentum=momentum)
  elif optimizer_name == 'rmsprop':
    logging.info('Using RMSProp optimizer')
    optimizer = tf.train.RMSPropOptimizer(learning_rate, decay, momentum,
                                          epsilon)
  else:
    logging.fatal('Unknown optimizer: %s', optimizer_name)

  return optimizer
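

# A minimal sketch pairing the two builders above (hypothetical values; `loss`
# is assumed to be defined by the caller):
#
#   lr = build_learning_rate(0.016, global_step, steps_per_epoch=1251)
#   optimizer = build_optimizer(lr, optimizer_name='rmsprop')
#   train_op = optimizer.minimize(loss, global_step=global_step)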


class TpuBatchNormalization(tf.layers.BatchNormalization):
  """Cross replica batch normalization."""

  def __init__(self, fused=False, **kwargs):
    if fused in (True, None):
      raise ValueError('TpuBatchNormalization does not support fused=True.')
    super(TpuBatchNormalization, self).__init__(fused=fused, **kwargs)

  def _cross_replica_average(self, t, num_shards_per_group):
    """Calculates the average value of input tensor across TPU replicas."""
    num_shards = tpu_function.get_tpu_context().number_of_shards
    group_assignment = None
    if num_shards_per_group > 1:
      if num_shards % num_shards_per_group != 0:
        raise ValueError('num_shards: %d mod shards_per_group: %d, should be 0'
                         % (num_shards, num_shards_per_group))
      num_groups = num_shards // num_shards_per_group
      group_assignment = [[
          x for x in range(num_shards) if x // num_shards_per_group == y
      ] for y in range(num_groups)]
    return tf.tpu.cross_replica_sum(t, group_assignment) / tf.cast(
        num_shards_per_group, t.dtype)

  def _moments(self, inputs, reduction_axes, keep_dims):
    """Compute the mean and variance: it overrides the original _moments."""
    shard_mean, shard_variance = super(TpuBatchNormalization, self)._moments(
        inputs, reduction_axes, keep_dims=keep_dims)

    num_shards = tpu_function.get_tpu_context().number_of_shards or 1
    if num_shards <= 8:  # Skip group_assignment in 2x2 or smaller slices.
      num_shards_per_group = 1
    else:
      num_shards_per_group = max(8, num_shards // 8)
    logging.info('TpuBatchNormalization with num_shards_per_group %s',
                 num_shards_per_group)
    if num_shards_per_group > 1:
      # Compute variance using: Var[X] = E[X^2] - E[X]^2.
      shard_square_of_mean = tf.math.square(shard_mean)
      shard_mean_of_square = shard_variance + shard_square_of_mean
      group_mean = self._cross_replica_average(
          shard_mean, num_shards_per_group)
      group_mean_of_square = self._cross_replica_average(
          shard_mean_of_square, num_shards_per_group)
      group_variance = group_mean_of_square - tf.math.square(group_mean)
      return (group_mean, group_variance)
    else:
      return (shard_mean, shard_variance)
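

# A minimal usage sketch (assumes the layer is built inside a TPU replica
# context; the layer arguments here are illustrative, not prescribed):
#
#   bn = TpuBatchNormalization(axis=-1, momentum=0.99, epsilon=1e-3)
#   outputs = bn(inputs, training=True)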


class BatchNormalization(tf.layers.BatchNormalization):
  """Fixed default name of BatchNormalization to match TpuBatchNormalization."""

  def __init__(self, name='tpu_batch_normalization', **kwargs):
    super(BatchNormalization, self).__init__(name=name, **kwargs)


def drop_connect(inputs, is_training, survival_prob):
  """Drop the entire conv with given survival probability."""
  # "Deep Networks with Stochastic Depth", https://arxiv.org/pdf/1603.09382.pdf
  if not is_training:
    return inputs

  # Compute a per-example binary mask: keep each example with probability
  # survival_prob.
  batch_size = tf.shape(inputs)[0]
  random_tensor = survival_prob
  random_tensor += tf.random_uniform([batch_size, 1, 1, 1], dtype=inputs.dtype)
  binary_tensor = tf.floor(random_tensor)
  # Unlike the conventional approach of multiplying by survival_prob at test
  # time, here we divide by survival_prob at training time, so no additional
  # compute is needed at test time.
  output = tf.div(inputs, survival_prob) * binary_tensor
  return output
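

# A minimal usage sketch (hypothetical names; typically applied to a residual
# branch inside a block):
#
#   x = drop_connect(conv_out, is_training=True, survival_prob=0.8)
#
# Since kept examples are rescaled by 1 / survival_prob, the output matches
# conv_out in expectation, and no rescaling is needed at test time.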


def archive_ckpt(ckpt_eval, ckpt_objective, ckpt_path):
  """Archive a checkpoint if the metric is better."""
  ckpt_dir, ckpt_name = os.path.split(ckpt_path)

  saved_objective_path = os.path.join(ckpt_dir, 'best_objective.txt')
  saved_objective = float('-inf')
  if tf.gfile.Exists(saved_objective_path):
    with tf.gfile.GFile(saved_objective_path, 'r') as f:
      saved_objective = float(f.read())
  if saved_objective > ckpt_objective:
    logging.info('Ckpt %s is worse than %s', ckpt_objective, saved_objective)
    return False

  filenames = tf.gfile.Glob(ckpt_path + '.*')
  if not filenames:  # Glob returns a (possibly empty) list, never None.
    logging.info('No files to copy for checkpoint %s', ckpt_path)
    return False

  # Clear the old folder.
  dst_dir = os.path.join(ckpt_dir, 'archive')
  if tf.gfile.Exists(dst_dir):
    tf.gfile.DeleteRecursively(dst_dir)
  tf.gfile.MakeDirs(dst_dir)

  # Write checkpoints.
  for f in filenames:
    dest = os.path.join(dst_dir, os.path.basename(f))
    tf.gfile.Copy(f, dest, overwrite=True)
  ckpt_state = tf.train.generate_checkpoint_state_proto(
      dst_dir,
      model_checkpoint_path=ckpt_name,
      all_model_checkpoint_paths=[ckpt_name])
  with tf.gfile.GFile(os.path.join(dst_dir, 'checkpoint'), 'w') as f:
    f.write(str(ckpt_state))
  with tf.gfile.GFile(os.path.join(dst_dir, 'best_eval.txt'), 'w') as f:
    f.write('%s' % ckpt_eval)

  # Update the best objective.
  with tf.gfile.GFile(saved_objective_path, 'w') as f:
    f.write('%f' % ckpt_objective)

  logging.info('Copying checkpoint %s to %s', ckpt_path, dst_dir)
  return True
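

# A minimal usage sketch (hypothetical paths and metric values):
#
#   if archive_ckpt('top1: 0.771', 0.771, '/tmp/model/model.ckpt-10000'):
#     print('Checkpoint archived as the new best.')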


def get_ema_vars():
  """Get all exponential moving average (ema) variables."""
  ema_vars = tf.trainable_variables() + tf.get_collection('moving_vars')
  for v in tf.global_variables():
    # We maintain moving averages for batch norm moving mean and variance too.
    if 'moving_mean' in v.name or 'moving_variance' in v.name:
      ema_vars.append(v)
  return list(set(ema_vars))
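

# A minimal sketch of maintaining EMA weights during training (hypothetical
# decay value; `train_op` is assumed to be defined by the caller):
#
#   ema = tf.train.ExponentialMovingAverage(decay=0.9999)
#   with tf.control_dependencies([train_op]):
#     train_op = tf.group(ema.apply(get_ema_vars()))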


class DepthwiseConv2D(tf.keras.layers.DepthwiseConv2D, tf.layers.Layer):
  """Wrap keras DepthwiseConv2D to tf.layers."""

  pass


class EvalCkptDriver(object):
  """A driver for running eval inference.

  Attributes:
    model_name: str. Model name to eval.
    batch_size: int. Eval batch size.
    image_size: int. Input image size, determined by model name.
    num_classes: int. Number of classes, defaults to 1000 for ImageNet.
    include_background_label: whether to include the extra background label.
  """

  def __init__(self,
               model_name,
               batch_size=1,
               image_size=224,
               num_classes=1000,
               include_background_label=False):
    """Initialize internal variables."""
    self.model_name = model_name
    self.batch_size = batch_size
    self.num_classes = num_classes
    self.include_background_label = include_background_label
    self.image_size = image_size

  def restore_model(self, sess, ckpt_dir, enable_ema=True, export_ckpt=None):
    """Restore variables from checkpoint dir."""
    sess.run(tf.global_variables_initializer())
    checkpoint = tf.train.latest_checkpoint(ckpt_dir)
    if enable_ema:
      ema = tf.train.ExponentialMovingAverage(decay=0.0)
      ema_vars = get_ema_vars()
      var_dict = ema.variables_to_restore(ema_vars)
      ema_assign_op = ema.apply(ema_vars)
    else:
      var_dict = get_ema_vars()
      ema_assign_op = None

    tf.train.get_or_create_global_step()
    # Re-run the initializer: ema.apply() above may have created new variables.
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver(var_dict, max_to_keep=1)
    saver.restore(sess, checkpoint)

    if export_ckpt:
      if ema_assign_op is not None:
        sess.run(ema_assign_op)
      saver = tf.train.Saver(max_to_keep=1, save_relative_paths=True)
      saver.save(sess, export_ckpt)

  def build_model(self, features, is_training):
    """Build model with input features."""
    del features, is_training
    raise ValueError('Must be implemented by subclasses.')

  def get_preprocess_fn(self):
    raise ValueError('Must be implemented by subclasses.')

  def build_dataset(self, filenames, labels, is_training):
    """Build input dataset."""
    batch_drop_remainder = False
    if 'condconv' in self.model_name and not is_training:
      # CondConv layers can only be called with a known batch dimension. Thus,
      # we must drop all remaining examples that do not make up one full batch.
      batch_drop_remainder = True
      num_files = len(filenames)
      if num_files % self.batch_size != 0:
        tf.logging.warn('Remaining examples in last batch are not being '
                        'evaluated.')
    filenames = tf.constant(filenames)
    labels = tf.constant(labels)
    dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))

    def _parse_function(filename, label):
      image_string = tf.read_file(filename)
      preprocess_fn = self.get_preprocess_fn()
      image_decoded = preprocess_fn(
          image_string, is_training, image_size=self.image_size)
      image = tf.cast(image_decoded, tf.float32)
      return image, label

    dataset = dataset.map(_parse_function)
    dataset = dataset.batch(self.batch_size,
                            drop_remainder=batch_drop_remainder)

    iterator = dataset.make_one_shot_iterator()
    images, labels = iterator.get_next()
    return images, labels

  def run_inference(self,
                    ckpt_dir,
                    image_files,
                    labels,
                    enable_ema=True,
                    export_ckpt=None):
    """Build and run inference on the target images and labels."""
    label_offset = 1 if self.include_background_label else 0
    with tf.Graph().as_default(), tf.Session() as sess:
      images, labels = self.build_dataset(image_files, labels, False)
      probs = self.build_model(images, is_training=False)
      if isinstance(probs, tuple):
        probs = probs[0]

      self.restore_model(sess, ckpt_dir, enable_ema, export_ckpt)

      prediction_idx = []
      prediction_prob = []
      for _ in range(len(image_files) // self.batch_size):
        out_probs = sess.run(probs)
        idx = np.argsort(out_probs)[::-1]
        prediction_idx.append(idx[:5] - label_offset)
        prediction_prob.append([out_probs[pid] for pid in idx[:5]])

      # Return the top 5 predictions (idx and prob) for each image.
      return prediction_idx, prediction_prob

  def eval_example_images(self,
                          ckpt_dir,
                          image_files,
                          labels_map_file,
                          enable_ema=True,
                          export_ckpt=None):
    """Eval a list of example images.

    Args:
      ckpt_dir: str. Checkpoint directory path.
      image_files: List[str]. A list of image file paths.
      labels_map_file: str. The labels map file path.
      enable_ema: enable exponential moving average.
      export_ckpt: export ckpt folder.

    Returns:
      A tuple (pred_idx, pred_prob), where pred_idx is the top 5 prediction
      indices and pred_prob is the corresponding top 5 prediction
      probabilities.
    """
    classes = json.loads(tf.gfile.Open(labels_map_file).read())
    pred_idx, pred_prob = self.run_inference(
        ckpt_dir, image_files, [0] * len(image_files), enable_ema, export_ckpt)
    for i in range(len(image_files)):
      print('predicted class for image {}: '.format(image_files[i]))
      for j, idx in enumerate(pred_idx[i]):
        print('  -> top_{} ({:4.2f}%): {}  '.format(j, pred_prob[i][j] * 100,
                                                    classes[str(idx)]))
    return pred_idx, pred_prob

  def eval_imagenet(self, ckpt_dir, imagenet_eval_glob,
                    imagenet_eval_label, num_images, enable_ema, export_ckpt):
    """Eval ImageNet images and report top1/top5 accuracy.

    Args:
      ckpt_dir: str. Checkpoint directory path.
      imagenet_eval_glob: str. File path glob for all eval images.
      imagenet_eval_label: str. File path for eval label.
      num_images: int. Number of images to eval: -1 means eval the whole
        dataset.
      enable_ema: enable exponential moving average.
      export_ckpt: export checkpoint folder.

    Returns:
      A tuple (top1, top5) for top1 and top5 accuracy.
    """
    imagenet_val_labels = [int(i) for i in tf.gfile.GFile(imagenet_eval_label)]
    imagenet_filenames = sorted(tf.gfile.Glob(imagenet_eval_glob))
    if num_images < 0:
      num_images = len(imagenet_filenames)
    image_files = imagenet_filenames[:num_images]
    labels = imagenet_val_labels[:num_images]

    pred_idx, _ = self.run_inference(
        ckpt_dir, image_files, labels, enable_ema, export_ckpt)
    top1_cnt, top5_cnt = 0.0, 0.0
    for i, label in enumerate(labels):
      top1_cnt += label in pred_idx[i][:1]
      top5_cnt += label in pred_idx[i][:5]
      if i % 100 == 0:
        print('Step {}: top1_acc = {:4.2f}%  top5_acc = {:4.2f}%'.format(
            i, 100 * top1_cnt / (i + 1), 100 * top5_cnt / (i + 1)))
        sys.stdout.flush()
    top1, top5 = 100 * top1_cnt / num_images, 100 * top5_cnt / num_images
    print('Final: top1_acc = {:4.2f}%  top5_acc = {:4.2f}%'.format(top1, top5))
    return top1, top5
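

# A minimal end-to-end usage sketch (hypothetical subclass, model name, and
# paths; a real driver must implement build_model() and get_preprocess_fn()):
#
#   class MyEvalCkptDriver(EvalCkptDriver):
#
#     def build_model(self, features, is_training):
#       ...  # return class probabilities for `features`
#
#     def get_preprocess_fn(self):
#       ...  # return a fn(image_bytes, is_training, image_size) -> image
#
#   driver = MyEvalCkptDriver('efficientnet-b0', image_size=224)
#   driver.eval_example_images('/tmp/ckpt', ['panda.jpg'], 'labels_map.json')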