Aryan6192's picture
deep
79cf6ef verified
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model utilities."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import json
import os
import sys
from absl import logging
import numpy as np
import tensorflow.compat.v1 as tf
from tensorflow.python.tpu import tpu_function # pylint:disable=g-direct-tensorflow-import
def build_learning_rate(initial_lr,
global_step,
steps_per_epoch=None,
lr_decay_type='exponential',
decay_factor=0.97,
decay_epochs=2.4,
total_steps=None,
warmup_epochs=5):
"""Build learning rate."""
if lr_decay_type == 'exponential':
assert steps_per_epoch is not None
decay_steps = steps_per_epoch * decay_epochs
lr = tf.train.exponential_decay(
initial_lr, global_step, decay_steps, decay_factor, staircase=True)
elif lr_decay_type == 'cosine':
assert total_steps is not None
lr = 0.5 * initial_lr * (
1 + tf.cos(np.pi * tf.cast(global_step, tf.float32) / total_steps))
elif lr_decay_type == 'constant':
lr = initial_lr
else:
assert False, 'Unknown lr_decay_type : %s' % lr_decay_type
if warmup_epochs:
logging.info('Learning rate warmup_epochs: %d', warmup_epochs)
warmup_steps = int(warmup_epochs * steps_per_epoch)
warmup_lr = (
initial_lr * tf.cast(global_step, tf.float32) / tf.cast(
warmup_steps, tf.float32))
lr = tf.cond(global_step < warmup_steps, lambda: warmup_lr, lambda: lr)
return lr
def build_optimizer(learning_rate,
optimizer_name='rmsprop',
decay=0.9,
epsilon=0.001,
momentum=0.9):
"""Build optimizer."""
if optimizer_name == 'sgd':
logging.info('Using SGD optimizer')
optimizer = tf.train.GradientDescentOptimizer(learning_rate=learning_rate)
elif optimizer_name == 'momentum':
logging.info('Using Momentum optimizer')
optimizer = tf.train.MomentumOptimizer(
learning_rate=learning_rate, momentum=momentum)
elif optimizer_name == 'rmsprop':
logging.info('Using RMSProp optimizer')
optimizer = tf.train.RMSPropOptimizer(learning_rate, decay, momentum,
epsilon)
else:
logging.fatal('Unknown optimizer: %s', optimizer_name)
return optimizer
class TpuBatchNormalization(tf.layers.BatchNormalization):
# class TpuBatchNormalization(tf.layers.BatchNormalization):
"""Cross replica batch normalization."""
def __init__(self, fused=False, **kwargs):
if fused in (True, None):
raise ValueError('TpuBatchNormalization does not support fused=True.')
super(TpuBatchNormalization, self).__init__(fused=fused, **kwargs)
def _cross_replica_average(self, t, num_shards_per_group):
"""Calculates the average value of input tensor across TPU replicas."""
num_shards = tpu_function.get_tpu_context().number_of_shards
group_assignment = None
if num_shards_per_group > 1:
if num_shards % num_shards_per_group != 0:
raise ValueError('num_shards: %d mod shards_per_group: %d, should be 0'
% (num_shards, num_shards_per_group))
num_groups = num_shards // num_shards_per_group
group_assignment = [[
x for x in range(num_shards) if x // num_shards_per_group == y
] for y in range(num_groups)]
return tf.tpu.cross_replica_sum(t, group_assignment) / tf.cast(
num_shards_per_group, t.dtype)
def _moments(self, inputs, reduction_axes, keep_dims):
"""Compute the mean and variance: it overrides the original _moments."""
shard_mean, shard_variance = super(TpuBatchNormalization, self)._moments(
inputs, reduction_axes, keep_dims=keep_dims)
num_shards = tpu_function.get_tpu_context().number_of_shards or 1
if num_shards <= 8: # Skip cross_replica for 2x2 or smaller slices.
num_shards_per_group = 1
else:
num_shards_per_group = max(8, num_shards // 8)
logging.info('TpuBatchNormalization with num_shards_per_group %s',
num_shards_per_group)
if num_shards_per_group > 1:
# Compute variance using: Var[X]= E[X^2] - E[X]^2.
shard_square_of_mean = tf.math.square(shard_mean)
shard_mean_of_square = shard_variance + shard_square_of_mean
group_mean = self._cross_replica_average(
shard_mean, num_shards_per_group)
group_mean_of_square = self._cross_replica_average(
shard_mean_of_square, num_shards_per_group)
group_variance = group_mean_of_square - tf.math.square(group_mean)
return (group_mean, group_variance)
else:
return (shard_mean, shard_variance)
class BatchNormalization(tf.layers.BatchNormalization):
"""Fixed default name of BatchNormalization to match TpuBatchNormalization."""
def __init__(self, name='tpu_batch_normalization', **kwargs):
super(BatchNormalization, self).__init__(name=name, **kwargs)
def drop_connect(inputs, is_training, survival_prob):
"""Drop the entire conv with given survival probability."""
# "Deep Networks with Stochastic Depth", https://arxiv.org/pdf/1603.09382.pdf
if not is_training:
return inputs
# Compute tensor.
batch_size = tf.shape(inputs)[0]
random_tensor = survival_prob
random_tensor += tf.random_uniform([batch_size, 1, 1, 1], dtype=inputs.dtype)
binary_tensor = tf.floor(random_tensor)
# Unlike conventional way that multiply survival_prob at test time, here we
# divide survival_prob at training time, such that no addition compute is
# needed at test time.
output = tf.div(inputs, survival_prob) * binary_tensor
return output
def archive_ckpt(ckpt_eval, ckpt_objective, ckpt_path):
"""Archive a checkpoint if the metric is better."""
ckpt_dir, ckpt_name = os.path.split(ckpt_path)
saved_objective_path = os.path.join(ckpt_dir, 'best_objective.txt')
saved_objective = float('-inf')
if tf.gfile.Exists(saved_objective_path):
with tf.gfile.GFile(saved_objective_path, 'r') as f:
saved_objective = float(f.read())
if saved_objective > ckpt_objective:
logging.info('Ckpt %s is worse than %s', ckpt_objective, saved_objective)
return False
filenames = tf.gfile.Glob(ckpt_path + '.*')
if filenames is None:
logging.info('No files to copy for checkpoint %s', ckpt_path)
return False
# Clear the old folder.
dst_dir = os.path.join(ckpt_dir, 'archive')
if tf.gfile.Exists(dst_dir):
tf.gfile.DeleteRecursively(dst_dir)
tf.gfile.MakeDirs(dst_dir)
# Write checkpoints.
for f in filenames:
dest = os.path.join(dst_dir, os.path.basename(f))
tf.gfile.Copy(f, dest, overwrite=True)
ckpt_state = tf.train.generate_checkpoint_state_proto(
dst_dir,
model_checkpoint_path=ckpt_name,
all_model_checkpoint_paths=[ckpt_name])
with tf.gfile.GFile(os.path.join(dst_dir, 'checkpoint'), 'w') as f:
f.write(str(ckpt_state))
with tf.gfile.GFile(os.path.join(dst_dir, 'best_eval.txt'), 'w') as f:
f.write('%s' % ckpt_eval)
# Update the best objective.
with tf.gfile.GFile(saved_objective_path, 'w') as f:
f.write('%f' % ckpt_objective)
logging.info('Copying checkpoint %s to %s', ckpt_path, dst_dir)
return True
def get_ema_vars():
"""Get all exponential moving average (ema) variables."""
ema_vars = tf.trainable_variables() + tf.get_collection('moving_vars')
for v in tf.global_variables():
# We maintain mva for batch norm moving mean and variance as well.
if 'moving_mean' in v.name or 'moving_variance' in v.name:
ema_vars.append(v)
return list(set(ema_vars))
class DepthwiseConv2D(tf.keras.layers.DepthwiseConv2D, tf.layers.Layer):
"""Wrap keras DepthwiseConv2D to tf.layers."""
pass
class EvalCkptDriver(object):
"""A driver for running eval inference.
Attributes:
model_name: str. Model name to eval.
batch_size: int. Eval batch size.
image_size: int. Input image size, determined by model name.
num_classes: int. Number of classes, default to 1000 for ImageNet.
include_background_label: whether to include extra background label.
"""
def __init__(self,
model_name,
batch_size=1,
image_size=224,
num_classes=1000,
include_background_label=False):
"""Initialize internal variables."""
self.model_name = model_name
self.batch_size = batch_size
self.num_classes = num_classes
self.include_background_label = include_background_label
self.image_size = image_size
def restore_model(self, sess, ckpt_dir, enable_ema=True, export_ckpt=None):
"""Restore variables from checkpoint dir."""
sess.run(tf.global_variables_initializer())
checkpoint = tf.train.latest_checkpoint(ckpt_dir)
if enable_ema:
ema = tf.train.ExponentialMovingAverage(decay=0.0)
ema_vars = get_ema_vars()
var_dict = ema.variables_to_restore(ema_vars)
ema_assign_op = ema.apply(ema_vars)
else:
var_dict = get_ema_vars()
ema_assign_op = None
tf.train.get_or_create_global_step()
sess.run(tf.global_variables_initializer())
saver = tf.train.Saver(var_dict, max_to_keep=1)
saver.restore(sess, checkpoint)
if export_ckpt:
if ema_assign_op is not None:
sess.run(ema_assign_op)
saver = tf.train.Saver(max_to_keep=1, save_relative_paths=True)
saver.save(sess, export_ckpt)
def build_model(self, features, is_training):
"""Build model with input features."""
del features, is_training
raise ValueError('Must be implemented by subclasses.')
def get_preprocess_fn(self):
raise ValueError('Must be implemented by subclsses.')
def build_dataset(self, filenames, labels, is_training):
"""Build input dataset."""
batch_drop_remainder = False
if 'condconv' in self.model_name and not is_training:
# CondConv layers can only be called with known batch dimension. Thus, we
# must drop all remaining examples that do not make up one full batch.
# To ensure all examples are evaluated, use a batch size that evenly
# divides the number of files.
batch_drop_remainder = True
num_files = len(filenames)
if num_files % self.batch_size != 0:
tf.logging.warn('Remaining examples in last batch are not being '
'evaluated.')
filenames = tf.constant(filenames)
labels = tf.constant(labels)
dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
def _parse_function(filename, label):
image_string = tf.read_file(filename)
preprocess_fn = self.get_preprocess_fn()
image_decoded = preprocess_fn(
image_string, is_training, image_size=self.image_size)
image = tf.cast(image_decoded, tf.float32)
return image, label
dataset = dataset.map(_parse_function)
dataset = dataset.batch(self.batch_size,
drop_remainder=batch_drop_remainder)
iterator = dataset.make_one_shot_iterator()
images, labels = iterator.get_next()
return images, labels
def run_inference(self,
ckpt_dir,
image_files,
labels,
enable_ema=True,
export_ckpt=None):
"""Build and run inference on the target images and labels."""
label_offset = 1 if self.include_background_label else 0
with tf.Graph().as_default(), tf.Session() as sess:
images, labels = self.build_dataset(image_files, labels, False)
probs = self.build_model(images, is_training=False)
if isinstance(probs, tuple):
probs = probs[0]
self.restore_model(sess, ckpt_dir, enable_ema, export_ckpt)
prediction_idx = []
prediction_prob = []
for _ in range(len(image_files) // self.batch_size):
out_probs = sess.run(probs)
idx = np.argsort(out_probs)[::-1]
prediction_idx.append(idx[:5] - label_offset)
prediction_prob.append([out_probs[pid] for pid in idx[:5]])
# Return the top 5 predictions (idx and prob) for each image.
return prediction_idx, prediction_prob
def eval_example_images(self,
ckpt_dir,
image_files,
labels_map_file,
enable_ema=True,
export_ckpt=None):
"""Eval a list of example images.
Args:
ckpt_dir: str. Checkpoint directory path.
image_files: List[str]. A list of image file paths.
labels_map_file: str. The labels map file path.
enable_ema: enable expotential moving average.
export_ckpt: export ckpt folder.
Returns:
A tuple (pred_idx, and pred_prob), where pred_idx is the top 5 prediction
index and pred_prob is the top 5 prediction probability.
"""
classes = json.loads(tf.gfile.Open(labels_map_file).read())
pred_idx, pred_prob = self.run_inference(
ckpt_dir, image_files, [0] * len(image_files), enable_ema, export_ckpt)
for i in range(len(image_files)):
print('predicted class for image {}: '.format(image_files[i]))
for j, idx in enumerate(pred_idx[i]):
print(' -> top_{} ({:4.2f}%): {} '.format(j, pred_prob[i][j] * 100,
classes[str(idx)]))
return pred_idx, pred_prob
def eval_imagenet(self, ckpt_dir, imagenet_eval_glob,
imagenet_eval_label, num_images, enable_ema, export_ckpt):
"""Eval ImageNet images and report top1/top5 accuracy.
Args:
ckpt_dir: str. Checkpoint directory path.
imagenet_eval_glob: str. File path glob for all eval images.
imagenet_eval_label: str. File path for eval label.
num_images: int. Number of images to eval: -1 means eval the whole
dataset.
enable_ema: enable expotential moving average.
export_ckpt: export checkpoint folder.
Returns:
A tuple (top1, top5) for top1 and top5 accuracy.
"""
imagenet_val_labels = [int(i) for i in tf.gfile.GFile(imagenet_eval_label)]
imagenet_filenames = sorted(tf.gfile.Glob(imagenet_eval_glob))
if num_images < 0:
num_images = len(imagenet_filenames)
image_files = imagenet_filenames[:num_images]
labels = imagenet_val_labels[:num_images]
pred_idx, _ = self.run_inference(
ckpt_dir, image_files, labels, enable_ema, export_ckpt)
top1_cnt, top5_cnt = 0.0, 0.0
for i, label in enumerate(labels):
top1_cnt += label in pred_idx[i][:1]
top5_cnt += label in pred_idx[i][:5]
if i % 100 == 0:
print('Step {}: top1_acc = {:4.2f}% top5_acc = {:4.2f}%'.format(
i, 100 * top1_cnt / (i + 1), 100 * top5_cnt / (i + 1)))
sys.stdout.flush()
top1, top5 = 100 * top1_cnt / num_images, 100 * top5_cnt / num_images
print('Final: top1_acc = {:4.2f}% top5_acc = {:4.2f}%'.format(top1, top5))
return top1, top5