# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for object_detection.utils.config_util."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import unittest
from six.moves import range
import tensorflow.compat.v1 as tf

from google.protobuf import text_format

from object_detection.protos import eval_pb2
from object_detection.protos import image_resizer_pb2
from object_detection.protos import input_reader_pb2
from object_detection.protos import model_pb2
from object_detection.protos import pipeline_pb2
from object_detection.protos import train_pb2
from object_detection.utils import config_util
from object_detection.utils import tf_version
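
# Note: HParams comes from tf.contrib.training, which exists only in TF1;
# the HParams-based tests below are therefore skipped under TF2.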
# pylint: disable=g-import-not-at-top
try:
  from tensorflow.contrib import training as contrib_training
except ImportError:
  # tf.contrib (and with it HParams) was removed in TF 2.x.
  pass
# pylint: enable=g-import-not-at-top


def _write_config(config, config_path):
  """Writes a config object to disk."""
  config_text = text_format.MessageToString(config)
  with tf.gfile.Open(config_path, "wb") as f:
    f.write(config_text)


def _update_optimizer_with_constant_learning_rate(optimizer, learning_rate):
  """Adds a new constant learning rate."""
  constant_lr = optimizer.learning_rate.constant_learning_rate
  constant_lr.learning_rate = learning_rate


def _update_optimizer_with_exponential_decay_learning_rate(
    optimizer, learning_rate):
  """Adds a new exponential decay learning rate."""
  exponential_lr = optimizer.learning_rate.exponential_decay_learning_rate
  exponential_lr.initial_learning_rate = learning_rate


def _update_optimizer_with_manual_step_learning_rate(
    optimizer, initial_learning_rate, learning_rate_scaling):
  """Adds a learning rate schedule."""
  manual_lr = optimizer.learning_rate.manual_step_learning_rate
  manual_lr.initial_learning_rate = initial_learning_rate
  for i in range(3):
    schedule = manual_lr.schedule.add()
    schedule.learning_rate = initial_learning_rate * learning_rate_scaling**i


def _update_optimizer_with_cosine_decay_learning_rate(
    optimizer, learning_rate, warmup_learning_rate):
  """Adds a new cosine decay learning rate."""
  cosine_lr = optimizer.learning_rate.cosine_decay_learning_rate
  cosine_lr.learning_rate_base = learning_rate
  cosine_lr.warmup_learning_rate = warmup_learning_rate


class ConfigUtilTest(tf.test.TestCase):

  def _create_and_load_test_configs(self, pipeline_config):
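    """Round-trips a pipeline proto through disk into a configs dictionary."""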
    pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
    _write_config(pipeline_config, pipeline_config_path)
    return config_util.get_configs_from_pipeline_file(pipeline_config_path)

  def test_get_configs_from_pipeline_file(self):
    """Test that proto configs can be read from pipeline config file."""
    pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")

    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline_config.model.faster_rcnn.num_classes = 10
    pipeline_config.train_config.batch_size = 32
    pipeline_config.train_input_reader.label_map_path = "path/to/label_map"
    pipeline_config.eval_config.num_examples = 20
    pipeline_config.eval_input_reader.add().queue_capacity = 100

    _write_config(pipeline_config, pipeline_config_path)

    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
    self.assertProtoEquals(pipeline_config.model, configs["model"])
    self.assertProtoEquals(pipeline_config.train_config,
                           configs["train_config"])
    self.assertProtoEquals(pipeline_config.train_input_reader,
                           configs["train_input_config"])
    self.assertProtoEquals(pipeline_config.eval_config,
                           configs["eval_config"])
    self.assertProtoEquals(pipeline_config.eval_input_reader,
                           configs["eval_input_configs"])

  def test_create_configs_from_pipeline_proto(self):
    """Tests creating configs dictionary from pipeline proto."""

    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline_config.model.faster_rcnn.num_classes = 10
    pipeline_config.train_config.batch_size = 32
    pipeline_config.train_input_reader.label_map_path = "path/to/label_map"
    pipeline_config.eval_config.num_examples = 20
    pipeline_config.eval_input_reader.add().queue_capacity = 100

    configs = config_util.create_configs_from_pipeline_proto(pipeline_config)
    self.assertProtoEquals(pipeline_config.model, configs["model"])
    self.assertProtoEquals(pipeline_config.train_config,
                           configs["train_config"])
    self.assertProtoEquals(pipeline_config.train_input_reader,
                           configs["train_input_config"])
    self.assertProtoEquals(pipeline_config.eval_config, configs["eval_config"])
    self.assertProtoEquals(pipeline_config.eval_input_reader,
                           configs["eval_input_configs"])

  def test_create_pipeline_proto_from_configs(self):
    """Tests that proto can be reconstructed from configs dictionary."""
    pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")

    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline_config.model.faster_rcnn.num_classes = 10
    pipeline_config.train_config.batch_size = 32
    pipeline_config.train_input_reader.label_map_path = "path/to/label_map"
    pipeline_config.eval_config.num_examples = 20
    pipeline_config.eval_input_reader.add().queue_capacity = 100
    _write_config(pipeline_config, pipeline_config_path)

    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
    pipeline_config_reconstructed = (
        config_util.create_pipeline_proto_from_configs(configs))
    self.assertEqual(pipeline_config, pipeline_config_reconstructed)

  def test_save_pipeline_config(self):
    """Tests that the pipeline config is properly saved to disk."""
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline_config.model.faster_rcnn.num_classes = 10
    pipeline_config.train_config.batch_size = 32
    pipeline_config.train_input_reader.label_map_path = "path/to/label_map"
    pipeline_config.eval_config.num_examples = 20
    pipeline_config.eval_input_reader.add().queue_capacity = 100

    config_util.save_pipeline_config(pipeline_config, self.get_temp_dir())
    configs = config_util.get_configs_from_pipeline_file(
        os.path.join(self.get_temp_dir(), "pipeline.config"))
    pipeline_config_reconstructed = (
        config_util.create_pipeline_proto_from_configs(configs))

    self.assertEqual(pipeline_config, pipeline_config_reconstructed)

  def test_get_configs_from_multiple_files(self):
    """Tests that proto configs can be read from multiple files."""
    temp_dir = self.get_temp_dir()

    # Write model config file.
    model_config_path = os.path.join(temp_dir, "model.config")
    model = model_pb2.DetectionModel()
    model.faster_rcnn.num_classes = 10
    _write_config(model, model_config_path)

    # Write train config file.
    train_config_path = os.path.join(temp_dir, "train.config")
    train_config = train_pb2.TrainConfig()
    train_config.batch_size = 32
    _write_config(train_config, train_config_path)

    # Write train input config file.
    train_input_config_path = os.path.join(temp_dir, "train_input.config")
    train_input_config = input_reader_pb2.InputReader()
    train_input_config.label_map_path = "path/to/label_map"
    _write_config(train_input_config, train_input_config_path)

    # Write eval config file.
    eval_config_path = os.path.join(temp_dir, "eval.config")
    eval_config = eval_pb2.EvalConfig()
    eval_config.num_examples = 20
    _write_config(eval_config, eval_config_path)

    # Write eval input config file.
    eval_input_config_path = os.path.join(temp_dir, "eval_input.config")
    eval_input_config = input_reader_pb2.InputReader()
    eval_input_config.label_map_path = "path/to/another/label_map"
    _write_config(eval_input_config, eval_input_config_path)

    configs = config_util.get_configs_from_multiple_files(
        model_config_path=model_config_path,
        train_config_path=train_config_path,
        train_input_config_path=train_input_config_path,
        eval_config_path=eval_config_path,
        eval_input_config_path=eval_input_config_path)
    self.assertProtoEquals(model, configs["model"])
    self.assertProtoEquals(train_config, configs["train_config"])
    self.assertProtoEquals(train_input_config,
                           configs["train_input_config"])
    self.assertProtoEquals(eval_config, configs["eval_config"])
    self.assertProtoEquals(eval_input_config, configs["eval_input_configs"][0])

  def _assertOptimizerWithNewLearningRate(self, optimizer_name):
    """Asserts successful updating of all learning rate schemes."""
    original_learning_rate = 0.7
    learning_rate_scaling = 0.1
    warmup_learning_rate = 0.07
    hparams = contrib_training.HParams(learning_rate=0.15)
    pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")

    # Constant learning rate.
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    optimizer = getattr(pipeline_config.train_config.optimizer, optimizer_name)
    _update_optimizer_with_constant_learning_rate(optimizer,
                                                  original_learning_rate)
    _write_config(pipeline_config, pipeline_config_path)

    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
    configs = config_util.merge_external_params_with_configs(configs, hparams)
    optimizer = getattr(configs["train_config"].optimizer, optimizer_name)
    constant_lr = optimizer.learning_rate.constant_learning_rate
    self.assertAlmostEqual(hparams.learning_rate, constant_lr.learning_rate)

    # Learning rate with exponential decay.
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    optimizer = getattr(pipeline_config.train_config.optimizer, optimizer_name)
    _update_optimizer_with_exponential_decay_learning_rate(
        optimizer, original_learning_rate)
    _write_config(pipeline_config, pipeline_config_path)

    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
    configs = config_util.merge_external_params_with_configs(configs, hparams)
    optimizer = getattr(configs["train_config"].optimizer, optimizer_name)
    exponential_lr = optimizer.learning_rate.exponential_decay_learning_rate
    self.assertAlmostEqual(hparams.learning_rate,
                           exponential_lr.initial_learning_rate)

    # Manual step learning rate.
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    optimizer = getattr(pipeline_config.train_config.optimizer, optimizer_name)
    _update_optimizer_with_manual_step_learning_rate(
        optimizer, original_learning_rate, learning_rate_scaling)
    _write_config(pipeline_config, pipeline_config_path)

    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
    configs = config_util.merge_external_params_with_configs(configs, hparams)
    optimizer = getattr(configs["train_config"].optimizer, optimizer_name)
    manual_lr = optimizer.learning_rate.manual_step_learning_rate
    self.assertAlmostEqual(hparams.learning_rate,
                           manual_lr.initial_learning_rate)
    for i, schedule in enumerate(manual_lr.schedule):
      self.assertAlmostEqual(hparams.learning_rate * learning_rate_scaling**i,
                             schedule.learning_rate)

    # Cosine decay learning rate.
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    optimizer = getattr(pipeline_config.train_config.optimizer, optimizer_name)
    _update_optimizer_with_cosine_decay_learning_rate(optimizer,
                                                      original_learning_rate,
                                                      warmup_learning_rate)
    _write_config(pipeline_config, pipeline_config_path)

    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
    configs = config_util.merge_external_params_with_configs(configs, hparams)
    optimizer = getattr(configs["train_config"].optimizer, optimizer_name)
    cosine_lr = optimizer.learning_rate.cosine_decay_learning_rate

    self.assertAlmostEqual(hparams.learning_rate, cosine_lr.learning_rate_base)
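    # The merge rescales warmup_learning_rate so that the original
    # warmup-to-base ratio is preserved under the new base learning rate.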
    warmup_scale_factor = warmup_learning_rate / original_learning_rate
    self.assertAlmostEqual(hparams.learning_rate * warmup_scale_factor,
                           cosine_lr.warmup_learning_rate)

  @unittest.skipIf(tf_version.is_tf2(), "Skipping TF1.X only test.")
  def testRMSPropWithNewLearningRate(self):
    """Tests new learning rates for RMSProp Optimizer."""
    self._assertOptimizerWithNewLearningRate("rms_prop_optimizer")

  @unittest.skipIf(tf_version.is_tf2(), "Skipping TF1.X only test.")
  def testMomentumOptimizerWithNewLearningRate(self):
    """Tests new learning rates for Momentum Optimizer."""
    self._assertOptimizerWithNewLearningRate("momentum_optimizer")

  @unittest.skipIf(tf_version.is_tf2(), "Skipping TF1.X only test.")
  def testAdamOptimizerWithNewLearningRate(self):
    """Tests new learning rates for Adam Optimizer."""
    self._assertOptimizerWithNewLearningRate("adam_optimizer")

  @unittest.skipIf(tf_version.is_tf2(), "Skipping TF1.X only test.")
  def testGenericConfigOverride(self):
    """Tests generic config overrides for all top-level configs."""

    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline_config.model.ssd.num_classes = 1
    pipeline_config.train_config.batch_size = 1
    pipeline_config.eval_config.num_visualizations = 1
    pipeline_config.train_input_reader.label_map_path = "/some/path"
    pipeline_config.eval_input_reader.add().label_map_path = "/some/path"
    pipeline_config.graph_rewriter.quantization.weight_bits = 1

    pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
    _write_config(pipeline_config, pipeline_config_path)

    # Override each of the parameters.
    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
    hparams = contrib_training.HParams(
        **{
            "model.ssd.num_classes": 2,
            "train_config.batch_size": 2,
            "train_input_config.label_map_path": "/some/other/path",
            "eval_config.num_visualizations": 2,
            "graph_rewriter_config.quantization.weight_bits": 2
        })
    configs = config_util.merge_external_params_with_configs(configs, hparams)

    # Ensure that the parameters have the overridden values.
    self.assertEqual(2, configs["model"].ssd.num_classes)
    self.assertEqual(2, configs["train_config"].batch_size)
    self.assertEqual("/some/other/path",
                     configs["train_input_config"].label_map_path)
    self.assertEqual(2, configs["eval_config"].num_visualizations)
    self.assertEqual(2,
                     configs["graph_rewriter_config"].quantization.weight_bits)

  @unittest.skipIf(tf_version.is_tf2(), "Skipping TF1.X only test.")
  def testNewBatchSize(self):
    """Tests that batch size is updated appropriately."""
    original_batch_size = 2
    hparams = contrib_training.HParams(batch_size=16)
    pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")

    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline_config.train_config.batch_size = original_batch_size
    _write_config(pipeline_config, pipeline_config_path)

    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
    configs = config_util.merge_external_params_with_configs(configs, hparams)
    new_batch_size = configs["train_config"].batch_size
    self.assertEqual(16, new_batch_size)

  @unittest.skipIf(tf_version.is_tf2(), "Skipping TF1.X only test.")
  def testNewBatchSizeWithClipping(self):
    """Tests that batch size is clipped to 1 from below."""
    original_batch_size = 2
    hparams = contrib_training.HParams(batch_size=0.5)
    pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")

    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline_config.train_config.batch_size = original_batch_size
    _write_config(pipeline_config, pipeline_config_path)

    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
    configs = config_util.merge_external_params_with_configs(configs, hparams)
    new_batch_size = configs["train_config"].batch_size
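    # The fractional batch size must be clipped from below to a minimum of 1.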
    self.assertEqual(1, new_batch_size)

  @unittest.skipIf(tf_version.is_tf2(), "Skipping TF1.X only test.")
  def testOverwriteBatchSizeWithKeyValue(self):
    """Tests that batch size is overwritten based on key/value."""
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline_config.train_config.batch_size = 2
    configs = self._create_and_load_test_configs(pipeline_config)
    hparams = contrib_training.HParams(**{"train_config.batch_size": 10})
    configs = config_util.merge_external_params_with_configs(configs, hparams)
    new_batch_size = configs["train_config"].batch_size
    self.assertEqual(10, new_batch_size)

  @unittest.skipIf(tf_version.is_tf2(), "Skipping TF1.X only test.")
  def testOverwriteSampleFromDatasetWeights(self):
    """Tests config override for sample_from_datasets_weights."""
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline_config.train_input_reader.sample_from_datasets_weights.extend(
        [1, 2])
    pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
    _write_config(pipeline_config, pipeline_config_path)

    # Override the dataset weights.
    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
    hparams = contrib_training.HParams(sample_from_datasets_weights=[0.5, 0.5])
    configs = config_util.merge_external_params_with_configs(configs, hparams)

    # Ensure that the weights were overridden.
    self.assertListEqual(
        [0.5, 0.5],
        list(configs["train_input_config"].sample_from_datasets_weights))

  @unittest.skipIf(tf_version.is_tf2(), "Skipping TF1.X only test.")
  def testOverwriteSampleFromDatasetWeightsWrongLength(self):
    """Tests that a length-mismatched weights override raises an error."""
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline_config.train_input_reader.sample_from_datasets_weights.extend(
        [1, 2])
    pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
    _write_config(pipeline_config, pipeline_config_path)

    # Try to override the dataset weights with a list of the wrong length.
    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
    hparams = contrib_training.HParams(
        sample_from_datasets_weights=[0.5, 0.5, 0.5])
    with self.assertRaises(
        ValueError,
        msg="sample_from_datasets_weights override has a different number of"
        " values (3) than the configured dataset weights (2)."
    ):
      config_util.merge_external_params_with_configs(configs, hparams)

  @unittest.skipIf(tf_version.is_tf2(), "Skipping TF1.X only test.")
  def testKeyValueOverrideBadKey(self):
    """Tests that overwriting with a bad key causes an exception."""
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    configs = self._create_and_load_test_configs(pipeline_config)
    hparams = contrib_training.HParams(**{"train_config.no_such_field": 10})
    with self.assertRaises(ValueError):
      config_util.merge_external_params_with_configs(configs, hparams)

  @unittest.skipIf(tf_version.is_tf2(), "Skipping TF1.X only test.")
  def testOverwriteBatchSizeWithBadValueType(self):
    """Tests that overwriting with a bad value type causes an exception."""
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline_config.train_config.batch_size = 2
    configs = self._create_and_load_test_configs(pipeline_config)

    hparams = contrib_training.HParams(**{"train_config.batch_size": "10"})
    with self.assertRaises(TypeError):
      config_util.merge_external_params_with_configs(configs, hparams)

  @unittest.skipIf(tf_version.is_tf2(), "Skipping TF1.X only test.")
  def testNewMomentumOptimizerValue(self):
    """Tests that new momentum value is updated appropriately."""
    original_momentum_value = 0.4
    hparams = contrib_training.HParams(momentum_optimizer_value=1.1)
    pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")

    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    optimizer_config = pipeline_config.train_config.optimizer.rms_prop_optimizer
    optimizer_config.momentum_optimizer_value = original_momentum_value
    _write_config(pipeline_config, pipeline_config_path)

    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
    configs = config_util.merge_external_params_with_configs(configs, hparams)
    optimizer_config = configs["train_config"].optimizer.rms_prop_optimizer
    new_momentum_value = optimizer_config.momentum_optimizer_value
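    # The momentum override of 1.1 is expected to be clipped to the maximum
    # valid value of 1.0.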
    self.assertAlmostEqual(1.0, new_momentum_value)

  @unittest.skipIf(tf_version.is_tf2(), "Skipping TF1.X only test.")
  def testNewClassificationLocalizationWeightRatio(self):
    """Tests that the loss weight ratio is updated appropriately."""
    original_localization_weight = 0.1
    original_classification_weight = 0.2
    new_weight_ratio = 5.0
    hparams = contrib_training.HParams(
        classification_localization_weight_ratio=new_weight_ratio)
    pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")

    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline_config.model.ssd.loss.localization_weight = (
        original_localization_weight)
    pipeline_config.model.ssd.loss.classification_weight = (
        original_classification_weight)
    _write_config(pipeline_config, pipeline_config_path)

    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
    configs = config_util.merge_external_params_with_configs(configs, hparams)
    loss = configs["model"].ssd.loss
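    # The weights are rescaled so that localization_weight becomes 1.0 and
    # classification_weight equals the requested ratio.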
    self.assertAlmostEqual(1.0, loss.localization_weight)
    self.assertAlmostEqual(new_weight_ratio, loss.classification_weight)

  @unittest.skipIf(tf_version.is_tf2(), "Skipping TF1.X only test.")
  def testNewFocalLossParameters(self):
    """Tests that the focal loss parameters are updated appropriately."""
    original_alpha = 1.0
    original_gamma = 1.0
    new_alpha = 0.3
    new_gamma = 2.0
    hparams = contrib_training.HParams(
        focal_loss_alpha=new_alpha, focal_loss_gamma=new_gamma)
    pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")

    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    classification_loss = pipeline_config.model.ssd.loss.classification_loss
    classification_loss.weighted_sigmoid_focal.alpha = original_alpha
    classification_loss.weighted_sigmoid_focal.gamma = original_gamma
    _write_config(pipeline_config, pipeline_config_path)

    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
    configs = config_util.merge_external_params_with_configs(configs, hparams)
    classification_loss = configs["model"].ssd.loss.classification_loss
    self.assertAlmostEqual(new_alpha,
                           classification_loss.weighted_sigmoid_focal.alpha)
    self.assertAlmostEqual(new_gamma,
                           classification_loss.weighted_sigmoid_focal.gamma)

  def testMergingKeywordArguments(self):
    """Tests that keyword arguments get merged as do hyperparameters."""
    original_num_train_steps = 100
    desired_num_train_steps = 10
    pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")

    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline_config.train_config.num_steps = original_num_train_steps
    _write_config(pipeline_config, pipeline_config_path)

    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
    override_dict = {"train_steps": desired_num_train_steps}
    configs = config_util.merge_external_params_with_configs(
        configs, kwargs_dict=override_dict)
    train_steps = configs["train_config"].num_steps
    self.assertEqual(desired_num_train_steps, train_steps)

  def testGetNumberOfClasses(self):
    """Tests that number of classes can be retrieved."""
    pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline_config.model.faster_rcnn.num_classes = 20
    _write_config(pipeline_config, pipeline_config_path)

    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
    number_of_classes = config_util.get_number_of_classes(configs["model"])
    self.assertEqual(20, number_of_classes)

  def testNewTrainInputPath(self):
    """Tests that train input path can be overwritten with single file."""
    original_train_path = ["path/to/data"]
    new_train_path = "another/path/to/data"
    pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")

    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    reader_config = pipeline_config.train_input_reader.tf_record_input_reader
    reader_config.input_path.extend(original_train_path)
    _write_config(pipeline_config, pipeline_config_path)

    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
    override_dict = {"train_input_path": new_train_path}
    configs = config_util.merge_external_params_with_configs(
        configs, kwargs_dict=override_dict)
    reader_config = configs["train_input_config"].tf_record_input_reader
    final_path = reader_config.input_path
    self.assertEqual([new_train_path], final_path)

  def testNewTrainInputPathList(self):
    """Tests that train input path can be overwritten with multiple files."""
    original_train_path = ["path/to/data"]
    new_train_path = ["another/path/to/data", "yet/another/path/to/data"]
    pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")

    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    reader_config = pipeline_config.train_input_reader.tf_record_input_reader
    reader_config.input_path.extend(original_train_path)
    _write_config(pipeline_config, pipeline_config_path)

    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
    override_dict = {"train_input_path": new_train_path}
    configs = config_util.merge_external_params_with_configs(
        configs, kwargs_dict=override_dict)
    reader_config = configs["train_input_config"].tf_record_input_reader
    final_path = reader_config.input_path
    self.assertEqual(new_train_path, final_path)

  def testNewLabelMapPath(self):
    """Tests that label map path can be overwritten in input readers."""
    original_label_map_path = "path/to/original/label_map"
    new_label_map_path = "path/to/new/label_map"
    pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")

    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    train_input_reader = pipeline_config.train_input_reader
    train_input_reader.label_map_path = original_label_map_path
    eval_input_reader = pipeline_config.eval_input_reader.add()
    eval_input_reader.label_map_path = original_label_map_path
    _write_config(pipeline_config, pipeline_config_path)

    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
    override_dict = {"label_map_path": new_label_map_path}
    configs = config_util.merge_external_params_with_configs(
        configs, kwargs_dict=override_dict)
    self.assertEqual(new_label_map_path,
                     configs["train_input_config"].label_map_path)
    for eval_input_config in configs["eval_input_configs"]:
      self.assertEqual(new_label_map_path, eval_input_config.label_map_path)

  def testDontOverwriteEmptyLabelMapPath(self):
    """Tests that label map path will not be overwritten with empty string."""
    original_label_map_path = "path/to/original/label_map"
    new_label_map_path = ""
    pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")

    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    train_input_reader = pipeline_config.train_input_reader
    train_input_reader.label_map_path = original_label_map_path
    eval_input_reader = pipeline_config.eval_input_reader.add()
    eval_input_reader.label_map_path = original_label_map_path
    _write_config(pipeline_config, pipeline_config_path)

    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
    override_dict = {"label_map_path": new_label_map_path}
    configs = config_util.merge_external_params_with_configs(
        configs, kwargs_dict=override_dict)
    self.assertEqual(original_label_map_path,
                     configs["train_input_config"].label_map_path)
    self.assertEqual(original_label_map_path,
                     configs["eval_input_configs"][0].label_map_path)

  def testNewMaskType(self):
    """Tests that mask type can be overwritten in input readers."""
    original_mask_type = input_reader_pb2.NUMERICAL_MASKS
    new_mask_type = input_reader_pb2.PNG_MASKS
    pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")

    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    train_input_reader = pipeline_config.train_input_reader
    train_input_reader.mask_type = original_mask_type
    eval_input_reader = pipeline_config.eval_input_reader.add()
    eval_input_reader.mask_type = original_mask_type
    _write_config(pipeline_config, pipeline_config_path)

    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
    override_dict = {"mask_type": new_mask_type}
    configs = config_util.merge_external_params_with_configs(
        configs, kwargs_dict=override_dict)
    self.assertEqual(new_mask_type, configs["train_input_config"].mask_type)
    self.assertEqual(new_mask_type, configs["eval_input_configs"][0].mask_type)

  def testUseMovingAverageForEval(self):
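    """Tests that eval_with_moving_averages overrides use_moving_averages."""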
    use_moving_averages_orig = False
    pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")

    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline_config.eval_config.use_moving_averages = use_moving_averages_orig
    _write_config(pipeline_config, pipeline_config_path)

    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
    override_dict = {"eval_with_moving_averages": True}
    configs = config_util.merge_external_params_with_configs(
        configs, kwargs_dict=override_dict)
    self.assertEqual(True, configs["eval_config"].use_moving_averages)

  def testGetImageResizerConfig(self):
    """Tests that the image resizer config can be retrieved."""
    model_config = model_pb2.DetectionModel()
    model_config.faster_rcnn.image_resizer.fixed_shape_resizer.height = 100
    model_config.faster_rcnn.image_resizer.fixed_shape_resizer.width = 300
    image_resizer_config = config_util.get_image_resizer_config(model_config)
    self.assertEqual(image_resizer_config.fixed_shape_resizer.height, 100)
    self.assertEqual(image_resizer_config.fixed_shape_resizer.width, 300)

  def testGetSpatialImageSizeFromFixedShapeResizerConfig(self):
    image_resizer_config = image_resizer_pb2.ImageResizer()
    image_resizer_config.fixed_shape_resizer.height = 100
    image_resizer_config.fixed_shape_resizer.width = 200
    image_shape = config_util.get_spatial_image_size(image_resizer_config)
    self.assertAllEqual(image_shape, [100, 200])

  def testGetSpatialImageSizeFromAspectPreservingResizerConfig(self):
    image_resizer_config = image_resizer_pb2.ImageResizer()
    image_resizer_config.keep_aspect_ratio_resizer.min_dimension = 100
    image_resizer_config.keep_aspect_ratio_resizer.max_dimension = 600
    image_resizer_config.keep_aspect_ratio_resizer.pad_to_max_dimension = True
    image_shape = config_util.get_spatial_image_size(image_resizer_config)
    self.assertAllEqual(image_shape, [600, 600])

  def testGetSpatialImageSizeFromAspectPreservingResizerDynamic(self):
    image_resizer_config = image_resizer_pb2.ImageResizer()
    image_resizer_config.keep_aspect_ratio_resizer.min_dimension = 100
    image_resizer_config.keep_aspect_ratio_resizer.max_dimension = 600
    image_shape = config_util.get_spatial_image_size(image_resizer_config)
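    # Without pad_to_max_dimension, the output size depends on the input
    # image, so both spatial dimensions are reported as unknown (-1).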
    self.assertAllEqual(image_shape, [-1, -1])

  def testGetSpatialImageSizeFromConditionalShapeResizer(self):
    image_resizer_config = image_resizer_pb2.ImageResizer()
    image_resizer_config.conditional_shape_resizer.size_threshold = 100
    image_shape = config_util.get_spatial_image_size(image_resizer_config)
    self.assertAllEqual(image_shape, [-1, -1])

  def testGetMaxNumContextFeaturesFromModelConfig(self):
    model_config = model_pb2.DetectionModel()
    model_config.faster_rcnn.context_config.max_num_context_features = 10
    max_num_context_features = config_util.get_max_num_context_features(
        model_config)
    self.assertAllEqual(max_num_context_features, 10)

  def testGetContextFeatureLengthFromModelConfig(self):
    model_config = model_pb2.DetectionModel()
    model_config.faster_rcnn.context_config.context_feature_length = 100
    context_feature_length = config_util.get_context_feature_length(
        model_config)
    self.assertAllEqual(context_feature_length, 100)

  def testEvalShuffle(self):
    """Tests that `eval_shuffle` keyword arguments are applied correctly."""
    original_shuffle = True
    desired_shuffle = False

    pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline_config.eval_input_reader.add().shuffle = original_shuffle
    _write_config(pipeline_config, pipeline_config_path)

    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
    override_dict = {"eval_shuffle": desired_shuffle}
    configs = config_util.merge_external_params_with_configs(
        configs, kwargs_dict=override_dict)
    self.assertEqual(desired_shuffle, configs["eval_input_configs"][0].shuffle)

  def testTrainShuffle(self):
    """Tests that `train_shuffle` keyword arguments are applied correctly."""
    original_shuffle = True
    desired_shuffle = False

    pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline_config.train_input_reader.shuffle = original_shuffle
    _write_config(pipeline_config, pipeline_config_path)

    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
    override_dict = {"train_shuffle": desired_shuffle}
    configs = config_util.merge_external_params_with_configs(
        configs, kwargs_dict=override_dict)
    train_shuffle = configs["train_input_config"].shuffle
    self.assertEqual(desired_shuffle, train_shuffle)

  def testOverWriteRetainOriginalImages(self):
    """Tests that `retain_original_images_in_eval` is applied correctly."""
    original_retain_original_images = True
    desired_retain_original_images = False

    pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline_config.eval_config.retain_original_images = (
        original_retain_original_images)
    _write_config(pipeline_config, pipeline_config_path)

    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
    override_dict = {
        "retain_original_images_in_eval": desired_retain_original_images
    }
    configs = config_util.merge_external_params_with_configs(
        configs, kwargs_dict=override_dict)
    retain_original_images = configs["eval_config"].retain_original_images
    self.assertEqual(desired_retain_original_images, retain_original_images)

  def testOverwriteAllEvalSampling(self):
    original_num_eval_examples = 1
    new_num_eval_examples = 10

    pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline_config.eval_input_reader.add().sample_1_of_n_examples = (
        original_num_eval_examples)
    pipeline_config.eval_input_reader.add().sample_1_of_n_examples = (
        original_num_eval_examples)
    _write_config(pipeline_config, pipeline_config_path)

    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
    override_dict = {"sample_1_of_n_eval_examples": new_num_eval_examples}
    configs = config_util.merge_external_params_with_configs(
        configs, kwargs_dict=override_dict)
    for eval_input_config in configs["eval_input_configs"]:
      self.assertEqual(new_num_eval_examples,
                       eval_input_config.sample_1_of_n_examples)

  def testOverwriteAllEvalNumEpochs(self):
    original_num_epochs = 10
    new_num_epochs = 1

    pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline_config.eval_input_reader.add().num_epochs = original_num_epochs
    pipeline_config.eval_input_reader.add().num_epochs = original_num_epochs
    _write_config(pipeline_config, pipeline_config_path)

    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
    override_dict = {"eval_num_epochs": new_num_epochs}
    configs = config_util.merge_external_params_with_configs(
        configs, kwargs_dict=override_dict)
    for eval_input_config in configs["eval_input_configs"]:
      self.assertEqual(new_num_epochs, eval_input_config.num_epochs)

  def testUpdateMaskTypeForAllInputConfigs(self):
    original_mask_type = input_reader_pb2.NUMERICAL_MASKS
    new_mask_type = input_reader_pb2.PNG_MASKS

    pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    train_config = pipeline_config.train_input_reader
    train_config.mask_type = original_mask_type
    eval_1 = pipeline_config.eval_input_reader.add()
    eval_1.mask_type = original_mask_type
    eval_1.name = "eval_1"
    eval_2 = pipeline_config.eval_input_reader.add()
    eval_2.mask_type = original_mask_type
    eval_2.name = "eval_2"
    _write_config(pipeline_config, pipeline_config_path)

    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
    override_dict = {"mask_type": new_mask_type}
    configs = config_util.merge_external_params_with_configs(
        configs, kwargs_dict=override_dict)

    self.assertEqual(configs["train_input_config"].mask_type, new_mask_type)
    for eval_input_config in configs["eval_input_configs"]:
      self.assertEqual(eval_input_config.mask_type, new_mask_type)

  def testErrorOverwritingMultipleInputConfig(self):
    original_shuffle = False
    new_shuffle = True
    pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    eval_1 = pipeline_config.eval_input_reader.add()
    eval_1.shuffle = original_shuffle
    eval_1.name = "eval_1"
    eval_2 = pipeline_config.eval_input_reader.add()
    eval_2.shuffle = original_shuffle
    eval_2.name = "eval_2"
    _write_config(pipeline_config, pipeline_config_path)

    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
    override_dict = {"eval_shuffle": new_shuffle}
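    # A legacy (non input-specific) key cannot disambiguate between multiple
    # eval inputs, so the merge must fail.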
    with self.assertRaises(ValueError):
      configs = config_util.merge_external_params_with_configs(
          configs, kwargs_dict=override_dict)

  def testCheckAndParseInputConfigKey(self):
    pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline_config.eval_input_reader.add().name = "eval_1"
    pipeline_config.eval_input_reader.add().name = "eval_2"
    _write_config(pipeline_config, pipeline_config_path)
    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)

    specific_shuffle_update_key = "eval_input_configs:eval_2:shuffle"
    is_valid_input_config_key, key_name, input_name, field_name = (
        config_util.check_and_parse_input_config_key(
            configs, specific_shuffle_update_key))
    self.assertTrue(is_valid_input_config_key)
    self.assertEqual(key_name, "eval_input_configs")
    self.assertEqual(input_name, "eval_2")
    self.assertEqual(field_name, "shuffle")

    legacy_shuffle_update_key = "eval_shuffle"
    is_valid_input_config_key, key_name, input_name, field_name = (
        config_util.check_and_parse_input_config_key(
            configs, legacy_shuffle_update_key))
    self.assertTrue(is_valid_input_config_key)
    self.assertEqual(key_name, "eval_input_configs")
    self.assertEqual(input_name, None)
    self.assertEqual(field_name, "shuffle")

    non_input_config_update_key = "label_map_path"
    is_valid_input_config_key, key_name, input_name, field_name = (
        config_util.check_and_parse_input_config_key(
            configs, non_input_config_update_key))
    self.assertFalse(is_valid_input_config_key)
    self.assertEqual(key_name, None)
    self.assertEqual(input_name, None)
    self.assertEqual(field_name, "label_map_path")

    with self.assertRaisesRegexp(
        ValueError, "Invalid key format when overriding configs."):
      config_util.check_and_parse_input_config_key(
          configs, "train_input_config:shuffle")

    with self.assertRaisesRegexp(
        ValueError, "Invalid key_name when overriding input config."):
      config_util.check_and_parse_input_config_key(
          configs, "invalid_key_name:train_name:shuffle")

    with self.assertRaisesRegexp(
        ValueError, "Invalid input_name when overriding input config."):
      config_util.check_and_parse_input_config_key(
          configs, "eval_input_configs:unknown_eval_name:shuffle")

    with self.assertRaisesRegexp(
        ValueError, "Invalid field_name when overriding input config."):
      config_util.check_and_parse_input_config_key(
          configs, "eval_input_configs:eval_2:unknown_field_name")

  def testUpdateInputReaderConfigSuccess(self):
    original_shuffle = False
    new_shuffle = True
    pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline_config.train_input_reader.shuffle = original_shuffle
    _write_config(pipeline_config, pipeline_config_path)
    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)

    config_util.update_input_reader_config(
        configs,
        key_name="train_input_config",
        input_name=None,
        field_name="shuffle",
        value=new_shuffle)
    self.assertEqual(configs["train_input_config"].shuffle, new_shuffle)

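    # Applying the same update a second time is a no-op and must still succeed.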
    config_util.update_input_reader_config(
        configs,
        key_name="train_input_config",
        input_name=None,
        field_name="shuffle",
        value=new_shuffle)
    self.assertEqual(configs["train_input_config"].shuffle, new_shuffle)

  def testUpdateInputReaderConfigErrors(self):
    pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline_config.eval_input_reader.add().name = "same_eval_name"
    pipeline_config.eval_input_reader.add().name = "same_eval_name"
    _write_config(pipeline_config, pipeline_config_path)
    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)

    with self.assertRaisesRegexp(
        ValueError, "Duplicate input name found when overriding."):
      config_util.update_input_reader_config(
          configs,
          key_name="eval_input_configs",
          input_name="same_eval_name",
          field_name="shuffle",
          value=False)

    with self.assertRaisesRegexp(
        ValueError, "Input name name_not_exist not found when overriding."):
      config_util.update_input_reader_config(
          configs,
          key_name="eval_input_configs",
          input_name="name_not_exist",
          field_name="shuffle",
          value=False)

    with self.assertRaisesRegexp(ValueError,
                                 "Unknown input config overriding."):
      config_util.update_input_reader_config(
          configs,
          key_name="eval_input_configs",
          input_name=None,
          field_name="shuffle",
          value=False)

  def testOverWriteRetainOriginalImageAdditionalChannels(self):
    """Tests that keyword arguments are applied correctly."""
    original_retain_original_image_additional_channels = True
    desired_retain_original_image_additional_channels = False

    pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline_config.eval_config.retain_original_image_additional_channels = (
        original_retain_original_image_additional_channels)
    _write_config(pipeline_config, pipeline_config_path)

    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
    override_dict = {
        "retain_original_image_additional_channels_in_eval":
            desired_retain_original_image_additional_channels
    }
    configs = config_util.merge_external_params_with_configs(
        configs, kwargs_dict=override_dict)
    retain_original_image_additional_channels = configs[
        "eval_config"].retain_original_image_additional_channels
    self.assertEqual(desired_retain_original_image_additional_channels,
                     retain_original_image_additional_channels)

  def testUpdateNumClasses(self):
    pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    pipeline_config.model.faster_rcnn.num_classes = 10

    _write_config(pipeline_config, pipeline_config_path)

    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)

    self.assertEqual(config_util.get_number_of_classes(configs["model"]), 10)

    config_util.merge_external_params_with_configs(
        configs, kwargs_dict={"num_classes": 2})

    self.assertEqual(config_util.get_number_of_classes(configs["model"]), 2)

  def testRemoveUnnecessaryEma(self):
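    # remove_unnecessary_ema strips the ExponentialMovingAverage suffix from
    # variable names whose base name ends in one of the given collections
    # ("/min", "/max" here); all other names pass through unchanged.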
    input_dict = {
        "expanded_conv_10/project/act_quant/min":
            1,
        "FeatureExtractor/MobilenetV2_2/expanded_conv_5/expand/act_quant/min":
            2,
        "expanded_conv_10/expand/BatchNorm/gamma/min/ExponentialMovingAverage":
            3,
        "expanded_conv_3/depthwise/BatchNorm/beta/max/ExponentialMovingAverage":
            4,
        "BoxPredictor_1/ClassPredictor_depthwise/act_quant":
            5
    }

    no_ema_collection = ["/min", "/max"]

    output_dict = {
        "expanded_conv_10/project/act_quant/min":
            1,
        "FeatureExtractor/MobilenetV2_2/expanded_conv_5/expand/act_quant/min":
            2,
        "expanded_conv_10/expand/BatchNorm/gamma/min":
            3,
        "expanded_conv_3/depthwise/BatchNorm/beta/max":
            4,
        "BoxPredictor_1/ClassPredictor_depthwise/act_quant":
            5
    }

    self.assertEqual(
        output_dict,
        config_util.remove_unnecessary_ema(input_dict, no_ema_collection))

  def testUpdateRescoreInstances(self):
    pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    kpt_task = pipeline_config.model.center_net.keypoint_estimation_task.add()
    kpt_task.rescore_instances = True

    _write_config(pipeline_config, pipeline_config_path)

    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
    cn_config = configs["model"].center_net
    self.assertEqual(
        True, cn_config.keypoint_estimation_task[0].rescore_instances)

    config_util.merge_external_params_with_configs(
        configs, kwargs_dict={"rescore_instances": False})
    cn_config = configs["model"].center_net
    self.assertEqual(
        False, cn_config.keypoint_estimation_task[0].rescore_instances)

  def testUpdateRescoreInstancesWithBooleanString(self):
    pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    kpt_task = pipeline_config.model.center_net.keypoint_estimation_task.add()
    kpt_task.rescore_instances = True

    _write_config(pipeline_config, pipeline_config_path)

    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
    cn_config = configs["model"].center_net
    self.assertEqual(
        True, cn_config.keypoint_estimation_task[0].rescore_instances)

    config_util.merge_external_params_with_configs(
        configs, kwargs_dict={"rescore_instances": "False"})
    cn_config = configs["model"].center_net
    self.assertEqual(
        False, cn_config.keypoint_estimation_task[0].rescore_instances)

  def testUpdateRescoreInstancesWithMultipleTasks(self):
    pipeline_config_path = os.path.join(self.get_temp_dir(), "pipeline.config")
    pipeline_config = pipeline_pb2.TrainEvalPipelineConfig()
    kpt_task = pipeline_config.model.center_net.keypoint_estimation_task.add()
    kpt_task.rescore_instances = True
    kpt_task = pipeline_config.model.center_net.keypoint_estimation_task.add()
    kpt_task.rescore_instances = True

    _write_config(pipeline_config, pipeline_config_path)

    configs = config_util.get_configs_from_pipeline_file(pipeline_config_path)
    cn_config = configs["model"].center_net
    self.assertEqual(
        True, cn_config.keypoint_estimation_task[0].rescore_instances)

    config_util.merge_external_params_with_configs(
        configs, kwargs_dict={"rescore_instances": False})
    cn_config = configs["model"].center_net
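    # With more than one keypoint estimation task, the single
    # rescore_instances override is not applied; both tasks keep True.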
    self.assertEqual(
        True, cn_config.keypoint_estimation_task[0].rescore_instances)
    self.assertEqual(
        True, cn_config.keypoint_estimation_task[1].rescore_instances)


if __name__ == "__main__":
  tf.test.main()