itsyuimorii's picture
Upload folder using huggingface_hub
efa0ffc verified
import pprint
import tensorflow as tf
import tensorflow_models as tfm
from official.core import exp_factory
from official.vision.serving import export_saved_model_lib
from shared import show_batch, export_dir, HEIGHT, WIDTH
pp = pprint.PrettyPrinter(indent=4) # Set Pretty Print Indentation
print(tf.__version__) # Check the version of tensorflow used
train_data_input_path = './tfrecords/train-00000-of-00001.tfrecord'
valid_data_input_path = './tfrecords/valid-00000-of-00001.tfrecord'
model_dir = './trained_model/'
exp_config = exp_factory.get_exp_config('retinanet_mobile_coco')
batch_size = 1
num_classes = 7
IMG_SIZE = [HEIGHT, WIDTH, 3]
# Backbone config.
exp_config.task.freeze_backbone = False
exp_config.task.annotation_file = ''
# Model config.
exp_config.task.model.input_size = IMG_SIZE
exp_config.task.model.num_classes = num_classes + 1
exp_config.task.model.detection_generator.tflite_post_processing.max_classes_per_detection = exp_config.task.model.num_classes
# Training data config.
exp_config.task.train_data.input_path = train_data_input_path
exp_config.task.train_data.dtype = 'float32'
exp_config.task.train_data.global_batch_size = batch_size
exp_config.task.train_data.parser.aug_scale_max = 1.0
exp_config.task.train_data.parser.aug_scale_min = 1.0
# Validation data config.
exp_config.task.validation_data.input_path = valid_data_input_path
exp_config.task.validation_data.dtype = 'float32'
exp_config.task.validation_data.global_batch_size = batch_size
logical_device_names = [
logical_device.name for logical_device in tf.config.list_logical_devices()]
if 'GPU' in ''.join(logical_device_names):
print('This may be broken in Colab.')
device = 'GPU'
elif 'TPU' in ''.join(logical_device_names):
print('This may be broken in Colab.')
device = 'TPU'
else:
print('Running on CPU is slow, so only train for a few steps.')
device = 'CPU'
train_steps = 10000
# steps_per_loop = num_of_training_examples // train_batch_size
exp_config.trainer.steps_per_loop = 100
exp_config.trainer.summary_interval = 100
exp_config.trainer.checkpoint_interval = 100
exp_config.trainer.validation_interval = 100
# validation_steps = num_of_validation_examples // eval_batch_size
exp_config.trainer.validation_steps = 100
exp_config.trainer.train_steps = train_steps
exp_config.trainer.optimizer_config.warmup.linear.warmup_steps = 100
exp_config.trainer.optimizer_config.learning_rate.type = 'cosine'
exp_config.trainer.optimizer_config.learning_rate.cosine.decay_steps = train_steps
exp_config.trainer.optimizer_config.learning_rate.cosine.initial_learning_rate = 0.1
exp_config.trainer.optimizer_config.warmup.linear.warmup_learning_rate = 0.05
if exp_config.runtime.mixed_precision_dtype == tf.float16:
tf.keras.mixed_precision.set_global_policy('mixed_float16')
if 'GPU' in ''.join(logical_device_names):
distribution_strategy = tf.distribute.MirroredStrategy()
elif 'TPU' in ''.join(logical_device_names):
tf.tpu.experimental.initialize_tpu_system()
tpu = tf.distribute.cluster_resolver.TPUClusterResolver(
tpu='/device:TPU_SYSTEM:0')
distribution_strategy = tf.distribute.experimental.TPUStrategy(tpu)
else:
print('Warning: this will be really slow.')
distribution_strategy = tf.distribute.OneDeviceStrategy(
logical_device_names[0])
print('Done')
with distribution_strategy.scope():
task = tfm.core.task_factory.get_task(
exp_config.task, logging_dir=model_dir)
for images, labels in task.build_inputs(exp_config.task.train_data).take(1):
print()
print(
f'images.shape: {str(images.shape):16} images.dtype: {images.dtype!r}')
print(f'labels.keys: {labels.keys()}')
model, eval_logs = tfm.core.train_lib.run_experiment(
distribution_strategy=distribution_strategy,
task=task,
mode='train_and_eval',
params=exp_config,
model_dir=model_dir,
run_post_eval=True)
export_saved_model_lib.export_inference_graph(
input_type='image_tensor',
batch_size=1,
input_image_size=[HEIGHT, WIDTH],
params=exp_config,
checkpoint_path=tf.train.latest_checkpoint(model_dir),
export_dir=export_dir)
buffer_size = 20
num_of_examples = 3
raw_records = tf.data.TFRecordDataset(
exp_config.task.train_data.input_path).shuffle(
buffer_size=buffer_size).take(num_of_examples)
show_batch(raw_records)