import os
import math
from typing import Any, Mapping, Text, Tuple, Union, NamedTuple
from functools import partial
import re
import dataclasses
import random
from ml_collections import ConfigDict
from ml_collections.config_dict.config_dict import placeholder

import flax
import jax
import jax.numpy as jnp
import jax.flatten_util
from jax.sharding import PartitionSpec as PS
from jax.sharding import Mesh
from jax.experimental import mesh_utils
from jax.lax import with_sharding_constraint as _with_sharding_constraint
from jax.experimental.pjit import pjit
from jax.interpreters import pxla
import numpy as np
from transformers import FlaxLogitsWarper


class JaxRNG(object):
    """ A convenient stateful JAX RNG wrapper. Can be used to wrap an RNG
        inside a pure function.
    """

    @classmethod
    def from_seed(cls, seed):
        return cls(jax.random.PRNGKey(seed))

    def __init__(self, rng):
        self.rng = rng

    def __call__(self, keys=None):
        if keys is None:
            self.rng, split_rng = jax.random.split(self.rng)
            return split_rng
        elif isinstance(keys, int):
            split_rngs = jax.random.split(self.rng, num=keys + 1)
            self.rng = split_rngs[0]
            return tuple(split_rngs[1:])
        else:
            split_rngs = jax.random.split(self.rng, num=len(keys) + 1)
            self.rng = split_rngs[0]
            return {key: val for key, val in zip(keys, split_rngs[1:])}


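# Example usage (illustrative sketch; the seed value and key names are
# arbitrary): the wrapper hands out fresh keys while keeping the carry state
# internal.
#
#   rng_gen = JaxRNG.from_seed(42)
#   dropout_key = rng_gen()                     # a single fresh key
#   k1, k2 = rng_gen(2)                         # a tuple of two keys
#   named_keys = rng_gen(('params', 'dropout')) # a dict keyed by name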
class JaxDistributedConfig(object):
    """ Utility class for initializing JAX distributed. """

    @staticmethod
    def get_default_config(updates=None):
        config = ConfigDict()
        config.initialize_jax_distributed = False
        config.coordinator_address = placeholder(str)
        config.num_processes = placeholder(int)
        config.process_id = placeholder(int)
        config.local_device_ids = placeholder(str)

        if updates is not None:
            config.update(ConfigDict(updates).copy_and_resolve_references())
        return config

    @classmethod
    def initialize(cls, config):
        config = cls.get_default_config(config)
        if config.initialize_jax_distributed:
            if config.local_device_ids is not None:
                local_device_ids = [int(x) for x in config.local_device_ids.split(',')]
            else:
                local_device_ids = None

            jax.distributed.initialize(
                coordinator_address=config.coordinator_address,
                num_processes=config.num_processes,
                process_id=config.process_id,
                local_device_ids=local_device_ids,
            )


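# Example usage (illustrative sketch; the coordinator address and process
# counts are made up): on each host, call initialize() before any other JAX
# operation.
#
#   JaxDistributedConfig.initialize(dict(
#       initialize_jax_distributed=True,
#       coordinator_address='10.0.0.1:1234',
#       num_processes=2,
#       process_id=0,
#   ))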
class FlaxTemperatureLogitsWarper(FlaxLogitsWarper):
    """ JIT traceable version of FlaxLogitsWarper that performs temperature scaling. """

    def __init__(self, temperature):
        self.temperature = temperature

    def __call__(self, input_ids, scores, cur_len):
        return scores / jnp.clip(self.temperature, a_min=1e-8)


def make_shard_and_gather_fns(partition_specs, dtype_specs=None):
    """ Create pytree of sharding and gathering functions from pytree of
        partition specs.
    """
    float_dtypes = (jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64)

    def make_to_dtype_fn(dtype_spec):
        def to_dtype(tensor):
            if dtype_specs in float_dtypes and getattr(tensor, 'dtype', None) in float_dtypes:
                # dtype_specs is a single float dtype: cast all float tensors to it
                return tensor.astype(dtype_specs)
            elif hasattr(dtype_spec, 'dtype') and hasattr(tensor, 'dtype'):
                # dtype_spec is a per-leaf spec carrying a dtype attribute: match it
                return tensor.astype(dtype_spec.dtype)
            return tensor
        return to_dtype

    def make_shard_fn(partition_spec, dtype_spec=None):
        jax_shard_function = pjit(
            make_to_dtype_fn(dtype_spec),
            in_shardings=None,
            out_shardings=partition_spec
        )
        def shard_fn(tensor):
            return jax_shard_function(tensor).block_until_ready()
        return shard_fn

    def make_gather_fn(partition_spec, dtype_spec=None):
        jax_gather_fn = pjit(
            make_to_dtype_fn(dtype_spec),
            in_shardings=partition_spec,
            out_shardings=None
        )
        def gather_fn(tensor):
            return jax.device_get(jax_gather_fn(tensor))
        return gather_fn

    if dtype_specs is None or dtype_specs in float_dtypes:
        shard_fns = jax.tree_util.tree_map(make_shard_fn, partition_specs)
        gather_fns = jax.tree_util.tree_map(make_gather_fn, partition_specs)
    else:
        shard_fns = jax.tree_util.tree_map(
            make_shard_fn, partition_specs, dtype_specs
        )
        gather_fns = jax.tree_util.tree_map(
            make_gather_fn, partition_specs, dtype_specs
        )
    return shard_fns, gather_fns


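# Example usage (illustrative sketch; `train_state_partition` and
# `train_state_shapes` are hypothetical pytrees of PartitionSpec and of
# jax.ShapeDtypeStruct leaves): shard_fns place each leaf on the mesh,
# gather_fns pull it back to the host.
#
#   shard_fns, gather_fns = make_shard_and_gather_fns(
#       train_state_partition, dtype_specs=train_state_shapes
#   )
#   with mesh:  # the pjit-based shard functions need an active mesh
#       sharded_state = tree_apply(shard_fns, host_side_state)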
def set_random_seed(seed):
    np.random.seed(seed)
    random.seed(seed)
    init_rng(seed)


def get_jax_mesh(axis_dims, names):
    if axis_dims.startswith('!'):
        # A leading '!' forces mesh axis splitting: the devices are laid out in
        # their default order instead of going through mesh_utils.
        mesh_axis_splitting = True
        axis_dims = axis_dims[1:]
    else:
        mesh_axis_splitting = False

    if ':' in axis_dims:
        dims = []
        dim_names = []
        for axis in axis_dims.split(','):
            name, dim = axis.split(':')
            assert name in names
            dims.append(int(dim))
            dim_names.append(name)
        assert set(dim_names) == set(names)
    else:
        dims = [int(x) for x in axis_dims.split(',')]
        dim_names = names
        assert len(dims) == len(names)
    mesh_shape = np.arange(jax.device_count()).reshape(dims).shape
    if mesh_axis_splitting:
        physical_mesh = np.array(jax.devices()).reshape(mesh_shape)
    else:
        physical_mesh = mesh_utils.create_device_mesh(mesh_shape)
    return Mesh(physical_mesh, dim_names)


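# Example usage (illustrative sketch; axis names and sizes are arbitrary):
# `axis_dims` is a comma-separated spec, either plain sizes in the order of
# `names`, or `name:size` pairs. A -1 entry lets numpy infer that axis from
# jax.device_count().
#
#   mesh = get_jax_mesh('1,-1,1', ('dp', 'fsdp', 'mp'))
#   mesh = get_jax_mesh('dp:2,fsdp:4,mp:1', ('dp', 'fsdp', 'mp'))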
def names_in_current_mesh(*names):
    """ Check if current mesh axes contain these names. """
    mesh_axis_names = pxla.thread_resources.env.physical_mesh.axis_names
    return set(names) <= set(mesh_axis_names)


def get_names_from_partition_spec(partition_specs):
    """ Return axis names from partition specs. """
    names = set()
    if isinstance(partition_specs, dict):
        partition_specs = partition_specs.values()
    for item in partition_specs:
        if item is None:
            continue
        elif isinstance(item, str):
            names.add(item)
        else:
            names.update(get_names_from_partition_spec(item))

    return list(names)


def with_sharding_constraint(x, partition_specs):
    """ A smarter version of with_sharding_constraint that only applies the
        constraint if the current mesh contains the axes in the partition specs.
    """
    axis_names = get_names_from_partition_spec(partition_specs)
    if names_in_current_mesh(*axis_names):
        x = _with_sharding_constraint(x, partition_specs)
    return x


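# Example usage (illustrative sketch; the axis names assume a mesh created
# with get_jax_mesh above): inside a pjit-ed function, hint how an activation
# should be sharded; outside any matching mesh this is a no-op.
#
#   hidden = with_sharding_constraint(hidden, PS(('dp', 'fsdp'), None, 'mp'))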
def wrap_function_with_rng(rng):
    """ To be used as a decorator: automatically bookkeeps an RNG for the wrapped function. """
    def wrap_function(function):
        def wrapped(*args, **kwargs):
            nonlocal rng
            rng, split_rng = jax.random.split(rng)
            return function(split_rng, *args, **kwargs)
        return wrapped
    return wrap_function


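# Example usage (illustrative sketch; the decorated function is hypothetical):
# the wrapped function receives a fresh key as its first argument on every call.
#
#   @wrap_function_with_rng(jax.random.PRNGKey(0))
#   def sample_noise(rng, shape):
#       return jax.random.normal(rng, shape)
#
#   noise = sample_noise((8, 128))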
def init_rng(seed):
    global jax_utils_rng
    jax_utils_rng = JaxRNG.from_seed(seed)


def next_rng(*args, **kwargs):
    global jax_utils_rng
    return jax_utils_rng(*args, **kwargs)


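# Example usage (illustrative sketch): seed the module-level RNG once, then
# pull fresh keys anywhere in the training script.
#
#   set_random_seed(42)
#   rng_dict = next_rng(('params', 'dropout'))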
def get_metrics(metrics, unreplicate=False, stack=False):
    if unreplicate:
        metrics = flax.jax_utils.unreplicate(metrics)
    metrics = jax.device_get(metrics)
    if stack:
        return jax.tree_util.tree_map(lambda *args: np.stack(args), *metrics)
    else:
        return {key: float(val) for key, val in metrics.items()}


def mse_loss(val, target, valid=None):
    if valid is None:
        valid = jnp.ones((*target.shape[:2], 1))
    valid = valid.astype(jnp.float32)
    loss = jnp.mean(
        jnp.where(
            valid > 0.0,
            jnp.square(val - target),
            0.0
        )
    )
    return loss


def cross_entropy_loss_and_accuracy(logits, tokens, valid=None):
    if valid is None:
        valid = jnp.ones(tokens.shape[:2])
    valid = valid.astype(jnp.float32)
    valid_text_length = jnp.maximum(jnp.sum(valid, axis=-1), 1e-10)
    logits = logits.astype(jnp.float32)  # cast up for numerical stability
    token_log_prob = jnp.squeeze(
        jnp.take_along_axis(
            jax.nn.log_softmax(logits, axis=-1),
            jnp.expand_dims(tokens, -1),
            axis=-1,
        ),
        -1,
    )
    token_log_prob = jnp.where(valid > 0.0, token_log_prob, jnp.array(0.0))
    loss = -jnp.mean(jnp.sum(token_log_prob, axis=-1) / valid_text_length)
    correct = jnp.where(
        valid > 0.0,
        jnp.argmax(logits, axis=-1) == tokens,
        jnp.array(False)
    )
    accuracy = jnp.mean(jnp.sum(correct, axis=-1) / valid_text_length)
    return loss, accuracy


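# Example usage (illustrative sketch; shapes are arbitrary): logits are
# [batch, seq_len, vocab], tokens are [batch, seq_len], and the optional mask
# zeroes out padding positions.
#
#   loss, accuracy = cross_entropy_loss_and_accuracy(
#       logits, target_tokens, valid=loss_mask
#   )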
def global_norm(tree):
    """ Return the global L2 norm of a pytree. """
    squared = jax.tree_util.tree_map(lambda x: jnp.sum(jnp.square(x)), tree)
    flattened, _ = jax.flatten_util.ravel_pytree(squared)
    return jnp.sqrt(jnp.sum(flattened))


def average_metrics(metrics):
    with jax.spmd_mode("allow_all"):
        return jax.tree_util.tree_map(
            lambda *args: jnp.mean(jnp.stack(args)),
            *metrics
        )


def get_float_dtype_by_name(dtype):
    return {
        'bf16': jnp.bfloat16,
        'bfloat16': jnp.bfloat16,
        'fp16': jnp.float16,
        'float16': jnp.float16,
        'fp32': jnp.float32,
        'float32': jnp.float32,
        'fp64': jnp.float64,
        'float64': jnp.float64,
    }[dtype]


def float_tensor_to_dtype(tensor, dtype):
    if dtype is None or dtype == '':
        return tensor
    if isinstance(dtype, str):
        dtype = get_float_dtype_by_name(dtype)
    float_dtypes = (jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64)
    if getattr(tensor, 'dtype', None) in float_dtypes:
        tensor = tensor.astype(dtype)
    return tensor


def float_to_dtype(tree, dtype):
    return jax.tree_util.tree_map(
        partial(float_tensor_to_dtype, dtype=dtype), tree
    )


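# Example usage (illustrative sketch; `params` is a hypothetical parameter
# pytree): cast every floating point leaf to bfloat16 while leaving integer
# leaves untouched.
#
#   params_bf16 = float_to_dtype(params, 'bf16')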
def get_gradient_checkpoint_policy(name):
    return {
        'everything_saveable': jax.checkpoint_policies.everything_saveable,
        'nothing_saveable': jax.checkpoint_policies.nothing_saveable,
        'checkpoint_dots': jax.checkpoint_policies.checkpoint_dots,
        'checkpoint_dots_with_no_batch_dims': jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims,
    }[name]


def tree_path_to_string(path, sep=None):
    keys = []
    for key in path:
        if isinstance(key, jax.tree_util.SequenceKey):
            keys.append(str(key.idx))
        elif isinstance(key, jax.tree_util.DictKey):
            keys.append(str(key.key))
        elif isinstance(key, jax.tree_util.GetAttrKey):
            keys.append(str(key.name))
        elif isinstance(key, jax.tree_util.FlattenedIndexKey):
            keys.append(str(key.key))
        else:
            keys.append(str(key))
    if sep is None:
        return tuple(keys)
    return sep.join(keys)


def flatten_tree(xs, is_leaf=None, sep=None):
    flattened, _ = jax.tree_util.tree_flatten_with_path(xs, is_leaf=is_leaf)
    output = {}
    for key, val in flattened:
        output[tree_path_to_string(key, sep=sep)] = val
    return output


def named_tree_map(f, tree, *rest, is_leaf=None, sep=None):
    """ An extended version of jax.tree_util.tree_map, where the mapped function
        f takes both the name (path) and the tree leaf as input.
    """
    return jax.tree_util.tree_map_with_path(
        lambda path, x, *r: f(tree_path_to_string(path, sep=sep), x, *r),
        tree, *rest,
        is_leaf=is_leaf
    )


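# Example usage (illustrative sketch; the parameter tree is made up): the
# mapped function sees slash-joined key paths such as 'params/dense/kernel'.
#
#   params = {'params': {'dense': {'kernel': jnp.ones((4, 4)),
#                                  'bias': jnp.zeros((4,))}}}
#   shapes = named_tree_map(lambda name, x: (name, x.shape), params, sep='/')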
def match_partition_rules(rules, params):
    """ Returns a pytree of PartitionSpec according to rules. Supports handling
        Flax TrainState and Optax optimizer state.
    """
    def get_partition_spec(name, leaf):
        if len(leaf.shape) == 0 or np.prod(leaf.shape) == 1:
            # Don't partition scalar values.
            return PS()
        for rule, ps in rules:
            if re.search(rule, name) is not None:
                return ps
        raise ValueError(f'Partition rule not found for param: {name}')
    return named_tree_map(get_partition_spec, params, sep='/')


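# Example usage (illustrative sketch; the rule patterns and axis names are
# made up): rules are (regex, PartitionSpec) pairs matched against the
# slash-joined parameter path, first match wins.
#
#   rules = (
#       ('attention/(wq|wk|wv)/kernel', PS('fsdp', 'mp')),
#       ('feed_forward/w1/kernel', PS('fsdp', 'mp')),
#       ('.*', PS()),
#   )
#   param_partition = match_partition_rules(rules, params)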
def get_weight_decay_mask(exclusions):
    """ Return a weight decay mask function that computes the pytree masks
        according to the given exclusion rules.
    """
    def decay(name, _):
        for rule in exclusions:
            if re.search(rule, name) is not None:
                return False
        return True

    def weight_decay_mask(params):
        return named_tree_map(decay, params, sep='/')

    return weight_decay_mask


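# Example usage (illustrative sketch; optax is not imported in this module and
# the exclusion patterns are made up): leaves whose path matches an exclusion
# regex receive no weight decay.
#
#   import optax
#   wd_mask = get_weight_decay_mask(('bias', 'layer_norm'))
#   optimizer = optax.adamw(1e-4, weight_decay=0.01, mask=wd_mask)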
def tree_apply(fns, tree):
    """ Apply a pytree of functions to the pytree. """
    return jax.tree_util.tree_map(lambda fn, x: fn(x), fns, tree)