| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """Training loop with flexible/schedulable settings.""" |
| |
| import functools |
| import importlib |
| import multiprocessing.pool |
| import os |
|
|
| from absl import app |
| from absl import flags |
| from absl import logging |
| import big_vision.evaluators.common as eval_common |
| import big_vision.input_pipeline as input_pipeline |
| import big_vision.optax as bv_optax |
| import big_vision.trainers.proj.flexi.common as flexi |
| import big_vision.utils as u |
| from clu import parameter_overview |
| import flax |
| import jax |
| import jax.numpy as jnp |
| from ml_collections import config_flags |
| import numpy as np |
| import optax |
| import tensorflow as tf |
|
|
| from tensorflow.io import gfile |
|
|
| |
|
|
|
|
# Command-line interface of this trainer:
#   --config   path to an ml_collections config file; it is locked once loaded
#              (lock_config=True), so the script cannot add new keys to it.
#   --workdir  directory where main() writes checkpoints and metrics.
#   --cleanup  when set, main() asks for the workdir to be removed at the end.
config_flags.DEFINE_config_file(
    "config", None, "Training configuration.", lock_config=True)
flags.DEFINE_string("workdir", None, "Work unit directory.")
flags.DEFINE_boolean(
    "cleanup", False, "Delete workdir (only) after successful completion.")

# Expose JAX's own configuration options through absl flags as well.
jax.config.parse_flags_with_absl()
|
|
|
|
def main(argv):
  """Runs the flexi training loop configured via the `--config` flag.

  Every step draws one value for each key of `config.flexi` (from its `v`
  choices with probabilities `p`). Those values are forwarded to the model as
  keyword arguments of the same name. Evaluators receive one predict function
  per flexi setting.

  Args:
    argv: Positional command-line arguments left over by absl; unused.

  Raises:
    ValueError: If the global batch size is not divisible by the total number
      of devices.
    RuntimeError: If the logged training loss becomes NaN or inf.
  """
  del argv
  # Hide GPUs from TensorFlow so that only JAX claims the accelerators.
  tf.config.experimental.set_visible_devices([], "GPU")

  config = flags.FLAGS.config
  workdir = flags.FLAGS.workdir
  logging.info(
      f"\u001b[33mHello from process {jax.process_index()} holding "
      f"{jax.local_device_count()}/{jax.device_count()} devices and "
      f"writing to workdir {workdir}.\u001b[0m")

  # Checkpointing is only possible when a workdir was given.
  save_ckpt_path = None
  if workdir:
    gfile.makedirs(workdir)
    save_ckpt_path = os.path.join(workdir, "checkpoint.npz")

  # Worker threads used to write checkpoints asynchronously (see below).
  pool = multiprocessing.pool.ThreadPool()

  # Importing these modules is what makes their preprocessing ops available.
  for m in config.get("pp_modules", ["ops_general", "ops_image", "ops_text"]):
    importlib.import_module(f"big_vision.pp.{m}")

  rng = jax.random.PRNGKey(config.get("seed", 0))

  # Experiment / work-unit ids. They are fixed placeholders here, but are also
  # used to seed the per-step flexi choice further down.
  xid, wid = -1, -1
  fillin = lambda s: s  # Identity hook applied to `config.resume` paths.

  def info(s, *a):
    logging.info("\u001b[33mNOTE\u001b[0m: " + s, *a)

  def write_note(note):
    # Only the first process logs notes, to avoid duplicated output.
    if jax.process_index() == 0:
      info("%s", note)

  write_note("Initializing...")

  batch_size = config.input.batch_size
  if batch_size % jax.device_count() != 0:
    raise ValueError(f"Batch size ({batch_size}) must "
                     f"be divisible by device number ({jax.device_count()})")
  info("Global batch size %d on %d hosts results in %d local batch size. With "
       "%d dev per host (%d dev total), that's a %d per-device batch size.",
       batch_size, jax.process_count(), batch_size // jax.process_count(),
       jax.local_device_count(), jax.device_count(),
       batch_size // jax.device_count())

  mw = u.BigVisionMetricWriter(xid, wid, workdir, config)

  write_note("Initializing train dataset...")
  train_ds, ntrain_img = input_pipeline.training(config.input)

  n_prefetch = config.get("prefetch_to_device", 1)
  train_iter = input_pipeline.start_input_pipeline(train_ds, n_prefetch)

  total_steps = u.steps("total", config, ntrain_img, batch_size)

  def get_steps(name, default=ValueError, cfg=config):
    return u.steps(name, cfg, ntrain_img, batch_size, total_steps, default)

  u.chrono.inform(total_steps=total_steps, global_bs=batch_size,
                  steps_per_epoch=ntrain_img / batch_size,
                  measure=mw.measure, write_note=write_note)

  info("Running for %d steps, that means %f epochs",
       total_steps, total_steps * batch_size / ntrain_img)

  write_note(f"Initializing {config.model_name} model...")
  model_mod = importlib.import_module(f"big_vision.models.{config.model_name}")
  model = model_mod.Model(
      num_classes=config.num_classes, **config.get("model", {}))

  # Parameters are initialized on the host CPU; they are replicated onto the
  # devices further down.
  @functools.partial(jax.jit, backend="cpu")
  def init(rng):
    # Dummy input: the per-device batch size followed by the dataset's image
    # shape without its leading (batch) dimension.
    shape = tuple(train_ds.element_spec["image"].shape[1:])
    bs = batch_size // jax.device_count()
    dummy_input = jnp.zeros((bs,) + shape, jnp.float32)
    params = flax.core.unfreeze(model.init(rng, dummy_input))["params"]

    # Optionally start the classification head's bias at a constant value.
    if "init_head_bias" in config:
      params["head"]["bias"] = jnp.full_like(params["head"]["bias"],
                                             config["init_head_bias"])

    return params

  rng, rng_init = jax.random.split(rng)
  with u.chrono.log_timing("z/secs/init"):
    params_cpu = init(rng_init)

  if jax.process_index() == 0:
    num_params = sum(p.size for p in jax.tree_util.tree_leaves(params_cpu))
    parameter_overview.log_parameter_overview(params_cpu, msg="init params")
    mw.measure("num_params", num_params)

  write_note(f"Initializing {config.optax_name} optimizer...")
  tx, sched_fns = bv_optax.make(config, params_cpu, sched_kw=dict(
      total_steps=total_steps, batch_size=batch_size, data_size=ntrain_img))

  # Optimizer state and schedule functions are also built/evaluated on CPU.
  opt_cpu = jax.jit(tx.init, backend="cpu")(params_cpu)
  sched_fns_cpu = [jax.jit(sched_fn, backend="cpu") for sched_fn in sched_fns]

  # Fixed (sorted) order in which flexi values are passed positionally to
  # `update_fn` after its first five arguments.
  flexi_argnames = sorted(config.flexi)

  @functools.partial(jax.pmap, axis_name="batch", donate_argnums=(0, 1),
                     static_broadcasted_argnums=tuple(
                         range(5, 5 + len(flexi_argnames))))
  def update_fn(params, opt, rng, images, labels, *args):
    """Runs one data-parallel optimization step.

    Args:
      params: Replicated model parameters (buffer donated to the call).
      opt: Replicated optimizer state (buffer donated to the call).
      rng: Per-device PRNG key.
      images: Per-device image batch.
      labels: Per-device label batch.
      *args: One value per name in `flexi_argnames`, in that order. They are
        static and broadcast, so each distinct combination compiles its own
        program.

    Returns:
      Tuple (params, opt, rng, loss, measurements). `loss` and the gradients
      are averaged across the "batch" axis. `measurements` holds the L2 norms
      of the gradients (excluding frozen ones), the parameters and the updates.
    """
    measurements = {}

    if config.get("mixup") and config.mixup.p:
      rng, (images, labels), _ = u.mixup(rng, images, labels, **config.mixup)

    # Fold in the device index so that each device gets different dropout.
    rng, rng_model = jax.random.split(rng, 2)
    rng_model_local = jax.random.fold_in(rng_model, jax.lax.axis_index("batch"))

    def loss_fn(params, images, labels):
      logits, _ = model.apply(
          {"params": params}, images,
          train=True, rngs={"dropout": rng_model_local},
          **dict(zip(flexi_argnames, args)))
      return getattr(u, config.get("loss", "sigmoid_xent"))(
          logits=logits, labels=labels)

    l, grads = jax.value_and_grad(loss_fn)(params, images, labels)
    l, grads = jax.lax.pmean((l, grads), axis_name="batch")
    updates, opt = tx.update(grads, opt, params)
    params = optax.apply_updates(params, updates)

    # Frozen parameters' gradients are zeroed before measuring their norm.
    gs = jax.tree_util.tree_leaves(
        bv_optax.replace_frozen(config.schedule, grads, 0.))
    measurements["l2_grads"] = jnp.sqrt(sum([jnp.vdot(g, g) for g in gs]))
    ps = jax.tree_util.tree_leaves(params)
    measurements["l2_params"] = jnp.sqrt(sum([jnp.vdot(p, p) for p in ps]))
    us = jax.tree_util.tree_leaves(updates)
    measurements["l2_updates"] = jnp.sqrt(sum([jnp.vdot(x, x) for x in us]))

    return params, opt, rng, l, measurements

  def predict_fn(params, image, **flexi_kw):
    """Applies the model for evaluation with the given flexi keyword args."""
    logits, out = model.apply({"params": params}, image, **flexi_kw)
    return logits, out

  # Resume priority: an existing checkpoint in the workdir (e.g. after a
  # restart), then `config.resume`. Otherwise, optionally load `model_init`.
  resume_ckpt_path = None
  if save_ckpt_path and gfile.exists(save_ckpt_path):
    resume_ckpt_path = save_ckpt_path
  elif config.get("resume"):
    resume_ckpt_path = fillin(config.resume)
  if resume_ckpt_path:
    write_note("Resume training from checkpoint...")
    # The freshly initialized tree only provides the structure to load into.
    checkpoint = {
        "params": params_cpu,
        "opt": opt_cpu,
        "chrono": u.chrono.save(),
    }
    checkpoint_tree = jax.tree_util.tree_structure(checkpoint)
    loaded = u.load_checkpoint_np(resume_ckpt_path, checkpoint_tree)
    # NOTE(review): presumably restores dtypes (e.g. bfloat16) that the .npz
    # format cannot store natively — confirm against u.recover_dtype.
    checkpoint = jax.tree_util.tree_map(u.recover_dtype, loaded)
    params_cpu, opt_cpu = checkpoint["params"], checkpoint["opt"]
    u.chrono.load(checkpoint["chrono"])
  elif config.get("model_init"):
    write_note(f"Initialize model from {config.model_init}...")
    params_cpu = model_mod.load(
        params_cpu, config.model_init, config.get("model"),
        **config.get("model_load", {}))
    if jax.process_index() == 0:
      parameter_overview.log_parameter_overview(
          params_cpu, msg="restored params")

  write_note("Kicking off misc stuff...")
  # The optimizer's step count tells us where (a resumed) training left off.
  first_step = bv_optax.get_count(opt_cpu)
  u.chrono.inform(first_step=first_step)
  prof = None  # Profiler session handle, managed by u.startstop_prof.

  write_note(f"Replicating...\n{u.chrono.note}")
  params_repl = flax.jax_utils.replicate(params_cpu)
  opt_repl = flax.jax_utils.replicate(opt_cpu)

  # Built lazily on first use, then cached for the rest of the run.
  @functools.cache
  def evaluators():
    return eval_common.from_config(
        config, flexi.mkpredictfns(predict_fn, config.flexi, "predict_{x}"),
        lambda s: write_note(f"Init evaluator: {s}…\n{u.chrono.note}"),
        lambda key, cfg: get_steps(key, default=None, cfg=cfg),
    )

  rng, rng_loop = jax.random.split(rng, 2)
  rngs_loop = flax.jax_utils.replicate(rng_loop)
  ckpt_writer = None  # Pending async checkpoint write, if any.

  write_note(f"First step compilations...\n{u.chrono.note}")

  # Evaluate before training starts (step 0), or when resuming an already
  # finished run. Evaluators flagged `skip_first` are skipped unless the run
  # is already complete.
  if first_step in (total_steps, 0):
    mw.step_start(first_step)
    for (name, evaluator, _, prefix) in evaluators():
      if config.evals[name].get("skip_first") and first_step != total_steps:
        continue
      write_note(f"{name} evaluation...\n{u.chrono.note}")
      with u.chrono.log_timing(f"z/secs/eval/{name}"):
        for key, value in evaluator.run(params_repl):
          mw.measure(f"{prefix}{key}", value)

  for step, batch in zip(range(first_step + 1, total_steps + 1), train_iter):
    mw.step_start(step)

    # Pick this step's flexi values. The RNG is derived from (xid, wid, step);
    # NOTE(review): presumably deterministic so all hosts pick identical static
    # args for pmap — confirm against flexi.mkrng.
    np_rng = flexi.mkrng(xid, wid, step)
    flexi_args = [
        flexi.choice(config.flexi[n].v, config.flexi[n].p, np_rng)
        for n in flexi_argnames
    ]

    with jax.profiler.StepTraceAnnotation("train_step", step_num=step):
      # Only the first step is timed; it includes compilation.
      with u.chrono.log_timing("z/secs/update0", noop=step > first_step + 1):
        params_repl, opt_repl, rngs_loop, loss_value, measurements = update_fn(
            params_repl, opt_repl, rngs_loop, batch["image"], batch["labels"],
            *flexi_args)

    if jax.process_index() == 0:
      prof = u.startstop_prof(prof, step, first_step, get_steps("log_training"))

    # Log training metrics on schedule (or during chrono warmup on host 0).
    if (u.itstime(step, get_steps("log_training"), total_steps, host=0)
        or (u.chrono.warmup and jax.process_index() == 0)):
      for i, sched_fn_cpu in enumerate(sched_fns_cpu):
        mw.measure(f"global_schedule{i if i else ''}", sched_fn_cpu(step - 1))
      l = mw.measure("training_loss", loss_value[0])
      for name, value in measurements.items():
        mw.measure(name, value[0])
      u.chrono.tick(step)
      if not np.isfinite(l):
        raise RuntimeError(f"The loss became nan or inf somewhere within steps "
                           f"[{step - get_steps('log_training')}, {step}]")

    # Checkpoint on the `ckpt` schedule. On the `keep_ckpt` schedule, a copy of
    # this step's checkpoint is additionally kept.
    if (save_ckpt_path and
        (u.itstime(step, get_steps("ckpt", None), total_steps, host=0) or
         u.itstime(step, get_steps("keep_ckpt", None), total_steps, host=0))):
      u.chrono.pause(wait_for=(params_repl, opt_repl))
      # Bound how long we wait for the previous async write to finish.
      u.checkpointing_timeout(ckpt_writer, config.get("ckpt_timeout", 1))
      # Pull device 0's replica to host memory.
      params_cpu = jax.tree_util.tree_map(lambda x: np.array(x[0]), params_repl)
      opt_cpu = jax.tree_util.tree_map(lambda x: np.array(x[0]), opt_repl)

      copy_step = None
      if u.itstime(step, get_steps("keep_ckpt", None), total_steps):
        copy_step = step

      ckpt = {"params": params_cpu, "opt": opt_cpu, "chrono": u.chrono.save()}
      ckpt_writer = pool.apply_async(
          u.save_checkpoint, (ckpt, save_ckpt_path, copy_step))
      u.chrono.resume()

    for (name, evaluator, log_steps, prefix) in evaluators():
      if u.itstime(step, log_steps, total_steps, first=False, last=True):
        u.chrono.pause(wait_for=params_repl)
        u.chrono.tick(step)
        write_note(f"{name} evaluation...\n{u.chrono.note}")
        with u.chrono.log_timing(f"z/secs/eval/{name}"):
          for key, value in evaluator.run(params_repl):
            mw.measure(f"{prefix}{key}", value)
        u.chrono.resume()
    mw.step_end()

  # Stop the profiler if it is still running.
  if jax.process_index() == 0 and prof is not None:
    u.startstop_prof(prof)

  write_note(f"Done!\n{u.chrono.note}")

  # Wait for any pending async checkpoint write before shutting down.
  pool.close()
  pool.join()
  mw.close()

  u.sync()

  u.maybe_cleanup_workdir(workdir, flags.FLAGS.cleanup, info)
|
|
# Script entry point: run `main` through absl's `app.run`.
if __name__ == "__main__":
  app.run(main)
|
|