# Using flags in official models

1. **All common flags must be incorporated in the models.**

   Common flags (e.g. `batch_size`, `model_dir`) are provided by various flag definition functions
   and channeled through `official.utils.flags.core`. For instance, to define common supervised
   learning parameters, one could use the following code:
   ```python
   from absl import app as absl_app
   from absl import flags

   from official.utils.flags import core as flags_core


   def define_flags():
     flags_core.define_base()
     # Adopt the flags defined in flags_core as key flags of this module so
     # that they are listed in --help.
     flags.adopt_module_key_flags(flags_core)


   def main(_):
     flags_obj = flags.FLAGS
     print(flags_obj)


   if __name__ == "__main__":
     define_flags()
     absl_app.run(main)
   ```
2. **Validate flag values.**

   See the [Validators](#validators) section for implementation details.

   Validators in the official model repo should not access the file system (for example, to verify
   that files exist), due to the strict ordering requirements.

3. **Flag values should not be mutated.**

   Instead of mutating flag values, use getter functions to return the desired values. An example
   getter function is the `get_tf_dtype` function below:
   ```python
   from absl import flags
   import tensorflow as tf

   # Map string to TensorFlow dtype
   DTYPE_MAP = {
       "fp16": tf.float16,
       "fp32": tf.float32,
   }


   def get_tf_dtype(flags_obj):
     if getattr(flags_obj, "fp16_implementation", None) == "graph_rewrite":
       # If the graph_rewrite is used, we build the graph with fp32, and let the
       # graph rewrite change ops to fp16.
       return tf.float32
     return DTYPE_MAP[flags_obj.dtype]


   def main(_):
     # flags.FLAGS is the already-parsed global object; do not call it here.
     flags_obj = flags.FLAGS

     # Do not mutate flags_obj
     # if flags_obj.fp16_implementation == "graph_rewrite":
     #   flags_obj.dtype = "float32" # Don't do this

     print(get_tf_dtype(flags_obj))
     ...
   ```
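As an illustration of item 2, a validator that inspects only the flag value and never touches the file system might look like the sketch below. It uses the standard `absl.flags` validator decorator. The `EXAMPLE_FLAGS` container, the `parse` helper, and the flag defaults are hypothetical names chosen for this example and are not taken from the official models repo, whose own helpers may differ.

```python
from absl import flags

# Hypothetical private container so this sketch does not touch the global
# flags.FLAGS; real model code would normally register on flags.FLAGS.
EXAMPLE_FLAGS = flags.FlagValues()

flags.DEFINE_integer(
    name="batch_size", default=32, help="Batch size for training.",
    flag_values=EXAMPLE_FLAGS)
flags.DEFINE_string(
    name="model_dir", default="/tmp/model", help="Directory for checkpoints.",
    flag_values=EXAMPLE_FLAGS)


@flags.validator("batch_size", message="batch_size must be positive.",
                 flag_values=EXAMPLE_FLAGS)
def _check_batch_size(batch_size):
  # Pure check on the value itself.
  return batch_size > 0


@flags.validator("model_dir", message="model_dir must not be empty.",
                 flag_values=EXAMPLE_FLAGS)
def _check_model_dir(model_dir):
  # Only inspects the string; deliberately avoids os.path.exists and similar
  # file system calls.
  return bool(model_dir)


def parse(argv):
  """Parses argv (argv[0] is the program name) and returns the flag values."""
  EXAMPLE_FLAGS.unparse_flags()  # Reset to defaults so repeated calls are independent.
  EXAMPLE_FLAGS(argv)  # Raises flags.IllegalFlagValueError if a validator fails.
  return EXAMPLE_FLAGS
```

Calling `parse(["prog", "--batch_size=0"])` fails at parse time with `flags.IllegalFlagValueError`, so an invalid value is rejected before `main` runs. Checks that need the file system, such as confirming that `model_dir` exists, would then belong in the model code itself rather than in a validator.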