|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """Flags for benchmarking models."""
|
|
|
| from absl import flags
|
|
|
| from official.utils.flags._conventions import help_wrap
|
|
|
|
|
def define_log_steps():
  """Register the integer flag --log_steps (default 100).

  Returns:
    An empty list: this helper does not contribute any key flags.
  """
  log_steps_help = (
      "Frequency with which to log timing information with TimeHistory.")
  flags.DEFINE_integer(name="log_steps", default=100, help=log_steps_help)

  # Nothing defined here is reported back as a key flag.
  no_key_flags = []
  return no_key_flags
|
|
|
|
|
def define_benchmark(benchmark_log_dir=True, bigquery_uploader=True):
  """Register benchmarking flags.

  --benchmark_logger_type, --benchmark_test_id and (through
  define_log_steps) --log_steps are always defined. The remaining flags are
  only defined when the corresponding argument is truthy.

  Args:
    benchmark_log_dir: Create a flag to specify location for benchmark logging.
    bigquery_uploader: Create flags for uploading results to BigQuery.

  Returns:
    A list of flags for core.py to mark as key flags. No flag defined here is
    appended to it, so the list is currently always empty.
  """

  key_flags = []

  flags.DEFINE_enum(
      name="benchmark_logger_type",
      default="BaseBenchmarkLogger",
      enum_values=["BaseBenchmarkLogger", "BenchmarkFileLogger"],
      help=help_wrap("The type of benchmark logger to use. Defaults to using "
                     "BaseBenchmarkLogger which logs to STDOUT. Different "
                     "loggers will require other flags to be able to work."))
  flags.DEFINE_string(
      name="benchmark_test_id",
      short_name="bti",
      default=None,
      help=help_wrap("The unique test ID of the benchmark run. It could be the "
                     "combination of key parameters. It is hardware "
                     "independent and could be used compare the performance "
                     "between different test runs. This flag is designed for "
                     "human consumption, and does not have any impact within "
                     "the system."))

  define_log_steps()

  if benchmark_log_dir:
    flags.DEFINE_string(
        name="benchmark_log_dir",
        short_name="bld",
        default=None,
        help=help_wrap("The location of the benchmark logging."))

  if bigquery_uploader:
    flags.DEFINE_string(
        name="gcp_project",
        short_name="gp",
        default=None,
        help=help_wrap(
            "The GCP project name where the benchmark will be uploaded."))

    flags.DEFINE_string(
        name="bigquery_data_set",
        short_name="bds",
        default="test_benchmark",
        help=help_wrap(
            "The Bigquery dataset name where the benchmark will be uploaded."))

    flags.DEFINE_string(
        name="bigquery_run_table",
        short_name="brt",
        default="benchmark_run",
        help=help_wrap("The Bigquery table name where the benchmark run "
                       "information will be uploaded."))

    flags.DEFINE_string(
        name="bigquery_run_status_table",
        short_name="brst",
        default="benchmark_run_status",
        help=help_wrap("The Bigquery table name where the benchmark run "
                       "status information will be uploaded."))

    flags.DEFINE_string(
        name="bigquery_metric_table",
        short_name="bmt",
        default="benchmark_metric",
        help=help_wrap("The Bigquery table name where the benchmark metric "
                       "information will be uploaded."))

  # The validator reads --benchmark_log_dir, and absl raises when a validator
  # is attached to a flag that is not registered. Register it only when the
  # flag exists (defined above, or already defined by another module) so that
  # define_benchmark(benchmark_log_dir=False) no longer fails here.
  if "benchmark_log_dir" in flags.FLAGS:

    @flags.multi_flags_validator(
        ["benchmark_logger_type", "benchmark_log_dir"],
        message="--benchmark_logger_type=BenchmarkFileLogger will require "
        "--benchmark_log_dir being set")
    def _check_benchmark_log_dir(flags_dict):
      """Require a log dir only when BenchmarkFileLogger is selected."""
      if flags_dict["benchmark_logger_type"] == "BenchmarkFileLogger":
        # Unset (None) and empty values both fail validation.
        return bool(flags_dict["benchmark_log_dir"])
      return True

  return key_flags
|
|
|