# Copyright 2022 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for using gin configurations."""
import logging
from typing import Any, Sequence, Union, cast
import gin
from absl import app
def parse_gin_flags(
    gin_search_paths: Sequence[str],
    gin_files: Sequence[str],
    gin_bindings: Sequence[str],
    skip_unknown: Union[bool, Sequence[str]] = False,
    finalize_config: bool = True,
) -> None:
    """Registers search paths, parses gin files and bindings, and logs the result.

    Args:
      gin_search_paths: paths that will be searched for gin files.
      gin_files: paths to gin config files to be parsed. Files are parsed in
        order, with conflicting settings being overridden by later files. Paths
        may be relative to paths in `gin_search_paths`.
      gin_bindings: individual gin bindings to be applied after the gin files
        are parsed. They are applied in order, with conflicting settings being
        overridden by later ones.
      skip_unknown: whether to ignore unknown bindings or raise an error
        (default behavior). Alternatively, a list of configurable names to skip
        if unknown.
      finalize_config: whether to finalize the config so that it cannot be
        modified (default behavior).
    """
    # Make every search path known to gin before any file is resolved.
    for search_path in gin_search_paths:
        gin.add_config_file_search_path(search_path)

    # Config files first, then the individual bindings supplied via flags.
    gin.parse_config_files_and_bindings(
        gin_files,
        gin_bindings,
        finalize_config=finalize_config,
        skip_unknown=skip_unknown,
    )

    # Dump the effective configuration to the log, one line per record.
    logging.info("Gin Configuration:")
    config_lines = gin.config_str().splitlines()
    for config_line in config_lines:
        logging.info("%s", config_line)
def rewrite_gin_args(args: Sequence[str]) -> Sequence[str]:
    """Rewrite `--gin.NAME=VALUE` flags to `--gin_bindings=NAME = VALUE`.

    Arguments that do not start with `--gin.` are passed through untouched.
    Each rewritten argument is also printed to stdout.

    Args:
      args: command-line arguments, in order.

    Returns:
      A new list with the same length and order as `args`, in which every
      `--gin.NAME=VALUE` entry has been replaced by the equivalent
      `--gin_bindings=NAME = VALUE` entry.

    Raises:
      ValueError: if an argument starts with `--gin.` but contains no `=`.
    """
    gin_prefix = "--gin."

    def _rewrite_gin_arg(arg: str) -> str:
        if not arg.startswith(gin_prefix):
            return arg
        if "=" not in arg:
            raise ValueError(
                "Gin bindings must be of the form '--gin.NAME=VALUE', got: " + arg
            )
        binding = arg.removeprefix(gin_prefix)
        # Split on the first '=' only, so the value itself may contain '='.
        name, _, value = binding.partition("=")
        r_arg = f"--gin_bindings={name} = {value}"
        # Echo the translated flag so the rewrite is visible on stdout.
        print(f"Rewritten gin arg: {r_arg}")
        return r_arg

    return [_rewrite_gin_arg(arg) for arg in args]
def run(main: Any) -> None:
    """Starts `main` via `app.run`, rewriting gin args before flags are parsed.

    Args:
      main: the entry point handed to `app.run`.
    """

    def _parse_rewritten_flags(argv):
        # Pass the args through rewrite_gin_args before the flag parser sees them.
        return app.parse_flags_with_usage(rewrite_gin_args(argv))

    app.run(main, flags_parser=_parse_rewritten_flags)
# ====================== Configurable Utility Functions ======================
@gin.configurable  # type: ignore
def bool_fn(var1: object = gin.REQUIRED) -> bool:
    """Gin-configurable wrapper that returns the truth value of `var1`.

    Args:
      var1: any object; defaults to `gin.REQUIRED`.

    Returns:
      `bool(var1)`.
    """
    truth_value = bool(var1)
    return truth_value
@gin.configurable  # type: ignore
def string_split_fn(
    text: str = cast(str, gin.REQUIRED),
    separator: str = cast(str, gin.REQUIRED),
    maxsplit: int = -1,
    index: int | None = None,
) -> list[str] | Any:
    """Gin-configurable wrapper around `str.split` with optional indexing.

    Args:
      text: the string to split; defaults to `gin.REQUIRED`.
      separator: the delimiter passed to `str.split`; defaults to
        `gin.REQUIRED`.
      maxsplit: maximum number of splits, forwarded to `str.split` (-1 by
        default).
      index: if given, position of the single piece to return.

    Returns:
      The full list of pieces when `index` is None, otherwise the piece at
      position `index`.
    """
    pieces = text.split(separator, maxsplit)
    return pieces if index is None else pieces[index]