| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """Utilities for constructing PyTrees of PartitionSpecs.""" |
|
|
| |
|
|
| import re |
|
|
| from flax.core.frozen_dict import freeze |
| from flax.traverse_util import flatten_dict, unflatten_dict |
| from jax.experimental import PartitionSpec as P |
|
|
|
|
| |
# Sentinel placed at every flattened path before the rules are applied; a path
# still holding it afterwards was matched by no partition rule.
_unmatched = object()

# Sentinel object; not referenced in the visible part of this file.
# NOTE(review): presumably stands in for an empty leaf dict `{}` - confirm
# against callers.
empty_dict = object()
|
|
|
|
def _match(qs, ks):
    """Return True if regexes in qs match any window of strings in tuple ks.

    Args:
        qs: Sequence of regular-expression strings, one per window position.
        ks: Tuple of strings to search (e.g. a flattened parameter path).

    Returns:
        True when some contiguous run of ``len(qs)`` entries of ``ks`` is
        matched element by element, each entry matched in full by the
        corresponding pattern. False otherwise, including when ``qs`` is
        empty or longer than ``ks``.
    """
    # An empty query has no window to match, so it never succeeds.
    if not qs:
        return False

    # Compile once, outside the window loop.
    patterns = [re.compile(q) for q in qs]
    width = len(patterns)

    for start in range(len(ks) - width + 1):
        window = ks[start:start + width]
        # fullmatch anchors the *whole* pattern. Appending "$" instead would
        # bind only to the last branch of a top-level alternation ("a|b$")
        # and would also accept a trailing newline in the key.
        if all(p.fullmatch(k) for p, k in zip(patterns, window)):
            return True
    return False
|
|
|
|
def _replacement_rules(rules):
    """Turn an ordered rule table into a ``(key, val) -> value`` lookup.

    Args:
        rules: Iterable of ``(pattern_tuple, replacement)`` pairs, tried in
            order.

    Returns:
        A function ``replace(key, val)`` that yields the replacement of the
        first rule whose pattern tuple matches ``key`` (via ``_match``), or
        ``val`` unchanged when no rule matches.
    """

    def replace(key, val):
        # First matching rule wins; `val` is the fallback when none match.
        hits = (
            replacement
            for pattern, replacement in rules
            if _match(pattern, key)
        )
        return next(hits, val)

    return replace
|
|
|
|
| |
| |
def _get_partition_rules():
    """Return the ordered table of ``(path_patterns, replacement)`` rules.

    Each entry pairs a tuple of regex strings, matched against a window of a
    flattened parameter path, with the value assigned to matching paths:
    either a ``PartitionSpec`` using the mesh-axis name ``"mp"`` or ``None``.
    A fresh list is built on every call.

    NOTE(review): a ``None`` replacement presumably means the parameter is
    left unpartitioned - confirm against the call site that consumes it.
    """
    embedding_rules = [
        (("transformer", "wpe", "embedding"), P("mp", None)),
        (("transformer", "wte", "embedding"), P("mp", None)),
    ]
    attention_rules = [
        (("attention", "(q_proj|k_proj|v_proj)", "kernel"), P(None, "mp")),
        (("attention", "out_proj", "kernel"), P("mp", None)),
        (("attention", "out_proj", "bias"), None),
    ]
    mlp_rules = [
        (("mlp", "c_fc", "kernel"), P(None, "mp")),
        (("mlp", "c_fc", "bias"), P("mp")),
        (("mlp", "c_proj", "kernel"), P("mp", None)),
        (("mlp", "c_proj", "bias"), None),
    ]
    layer_norm_rules = [
        ((r"ln_\d+", "bias"), None),
        ((r"\d+", r"ln_\d+", "scale"), None),
        (("ln_f", "bias"), None),
        (("ln_f", "scale"), None),
    ]
    # Order matters to consumers that take the first matching rule.
    return embedding_rules + attention_rules + mlp_rules + layer_norm_rules
|
|
|
|
def set_partitions(in_dict):
    """Build a frozen tree of partition specs mirroring ``in_dict``.

    Every path produced by ``flatten_dict(in_dict)`` is looked up in the rule
    table from ``_get_partition_rules``; the first matching rule supplies the
    value stored at that path. The leaf values of ``in_dict`` are ignored;
    only its paths are used.

    Args:
        in_dict: Nested dict whose structure the result copies.

    Returns:
        The result of ``freeze(unflatten_dict(...))`` over the per-path
        replacements (each a ``PartitionSpec`` or ``None``).

    Raises:
        AssertionError: If one or more paths are matched by no rule. The
            message lists the offending paths.
    """
    replace = _replacement_rules(_get_partition_rules())
    result = {path: replace(path, _unmatched) for path in flatten_dict(in_dict)}

    # Checked with an explicit raise rather than `assert`, so the validation
    # is not stripped under `python -O`. Identity comparison avoids invoking
    # `__eq__` on the spec values.
    missing = [path for path, spec in result.items() if spec is _unmatched]
    if missing:
        raise AssertionError(
            "Incomplete partition spec. Unmatched paths: {}".format(missing)
        )
    return freeze(unflatten_dict(result))
|
|