"""Gemma adaptation for Pi, taken from big_vision. |
|
|
|
|
|
We follow this einsum axis naming convention: |
|
|
B: batch |
|
|
T: query length |
|
|
S: k/v length |
|
|
N: num query heads |
|
|
K: num k/v heads |
|
|
G: num query heads per k/v head |
|
|
H: head dim |
|
|
D: d_model ("features") |
|
|
""" |
|
|
|
|
|
from collections.abc import Sequence |
|
|
import dataclasses |
|
|
from typing import Literal, TypeAlias |
|
|
|
|
|
import einops |
|
|
import flax.linen as nn |
|
|
import jax |
|
|
import jax.numpy as jnp |
|
|
|
|
|
import openpi.models.lora as lora |
|
|
import openpi.shared.array_typing as at |
|
|
import openpi.training.sharding as sharding |
|
|
|
|
|
PALIGEMMA_VOCAB_SIZE = 257_152 |
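# 257_152 = 256,000 SentencePiece tokens + 1,024 <locNNNN> + 128 <segNNN> tokens of the
# PaliGemma tokenizer.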


@dataclasses.dataclass
class Config:
    width: int
    depth: int
    mlp_dim: int
    num_heads: int
    num_kv_heads: int
    head_dim: int
    lora_configs: dict[str, lora.LoRAConfig] = dataclasses.field(default_factory=dict)


Variant = Literal["dummy", "gemma_300m", "gemma_300m_lora", "gemma_2b", "gemma_2b_lora"]


def get_config(variant: Variant) -> Config:
    """Returns config for specified gemma variant."""
    if variant == "dummy":
        return Config(
            width=64,
            depth=4,
            mlp_dim=128,
            num_heads=8,
            num_kv_heads=1,
            head_dim=16,
        )
    if variant == "gemma_300m":
        return Config(
            width=1024,
            depth=18,
            mlp_dim=4096,
            num_heads=8,
            num_kv_heads=1,
            head_dim=256,
        )
    if variant == "gemma_2b":
        return Config(
            width=2048,
            depth=18,
            mlp_dim=16_384,
            num_heads=8,
            num_kv_heads=1,
            head_dim=256,
        )
    if variant == "gemma_2b_lora":
        return Config(
            width=2048,
            depth=18,
            mlp_dim=16_384,
            num_heads=8,
            num_kv_heads=1,
            head_dim=256,
            lora_configs={"attn": lora.LoRAConfig(rank=16, alpha=16.0), "ffn": lora.LoRAConfig(rank=16, alpha=16.0)},
        )
    if variant == "gemma_300m_lora":
        return Config(
            width=1024,
            depth=18,
            mlp_dim=4096,
            num_heads=8,
            num_kv_heads=1,
            head_dim=256,
            lora_configs={"attn": lora.LoRAConfig(rank=32, alpha=32.0), "ffn": lora.LoRAConfig(rank=32, alpha=32.0)},
        )
    raise ValueError(f"Unknown variant: {variant}")
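
# Hypothetical usage sketch (this pairing mirrors how openpi's pi0 combines a PaliGemma
# backbone with a smaller action expert; treat it as an illustration, not an API contract):
#   configs = [get_config("gemma_2b"), get_config("gemma_300m")]
#   llm = Module(configs=configs, embed_dtype="bfloat16")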


@at.typecheck
class RMSNorm(nn.Module):
    @nn.compact
    def __call__(self, x, cond):
        dtype = x.dtype
        var = jnp.mean(jnp.square(x.astype(jnp.float32)), axis=-1, keepdims=True)
        normed_inputs = jnp.asarray(x * jnp.reciprocal(jnp.sqrt(var + 1e-06)))
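
        # Adaptive RMSNorm (AdaLN-Zero style conditioning): when `cond` is provided, a
        # zero-initialized Dense layer predicts per-example (scale, shift, gate). At init
        # the modulation is zero, so the layer reduces to plain RMSNorm and the returned
        # gate zeroes the gated residual in Block.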
        if cond is None:
            scale = self.param("scale", nn.initializers.zeros_init(), (x.shape[-1],))
            normed_inputs = normed_inputs * (1 + scale)
            return normed_inputs.astype(dtype), None

        modulation = nn.Dense(x.shape[-1] * 3, kernel_init=nn.initializers.zeros, dtype=dtype)(cond)
        scale, shift, gate = jnp.split(modulation[:, None, :], 3, axis=-1)
        normed_inputs = normed_inputs * (1 + scale) + shift
        return normed_inputs.astype(dtype), gate


@at.typecheck
class Embedder(nn.Module):
    """Embedder module."""

    vocab_size: int
    embed_dim: int

    def setup(self):
        self.input_embedding_table = self.param(
            "input_embedding",
            nn.initializers.normal(),
            (self.vocab_size, self.embed_dim),
        )

    def encode(self, x):
        x = self.input_embedding_table[(x,)]
        x *= jnp.sqrt(self.embed_dim).astype(x.dtype)
        return x
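
    # decode() reuses the transposed embedding table: the LM head is weight-tied to the
    # input embeddings, as in Gemma.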
    def decode(self, x):
        return jnp.dot(x, self.input_embedding_table.T)


@at.typecheck
class Attention(nn.Module):
    """Attention module."""

    configs: Sequence[Config]

    @nn.compact
    def __call__(self, xs, positions, attn_mask, kv_cache):
        # All experts must agree on the attention geometry (head count and head dim),
        # since their tokens are concatenated into a single shared attention call; only
        # the per-expert widths (and MLP dims) may differ.
        assert all(config.head_dim == self.configs[0].head_dim for config in self.configs)
        assert all(config.num_heads == self.configs[0].num_heads for config in self.configs)
        assert all(config.num_kv_heads == self.configs[0].num_kv_heads for config in self.configs)

        dtype = next(x.dtype for x in xs if x is not None)

        qkvs = []
        for i, (x, config) in enumerate(zip(xs, self.configs, strict=True)):
            if x is None:
                continue
            if config.num_kv_heads == config.num_heads:
                # Multi-head attention: one fused einsum produces q, k, and v.
                qkv_einsum = lora.Einsum(
                    shape=(3, config.num_heads, config.width, config.head_dim),
                    name=_name("qkv_einsum", i),
                    init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),
                    lora_config=config.lora_configs.get("attn"),
                )
                qkvs.append(qkv_einsum("BSD,3KDH->3BSKH", x))
            else:
                # Multi-query / grouped-query attention: q has more heads than k/v.
                q_einsum = lora.Einsum(
                    shape=(config.num_heads, config.width, config.head_dim),
                    name=_name("q_einsum", i),
                    init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
                    lora_config=config.lora_configs.get("attn"),
                )
                q = q_einsum("BTD,NDH->BTNH", x)
                kv_einsum = lora.Einsum(
                    shape=(2, config.num_kv_heads, config.width, config.head_dim),
                    name=_name("kv_einsum", i),
                    init_fn=nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0, 1)),
                    lora_config=config.lora_configs.get("attn"),
                )
                k, v = kv_einsum("BSD,2KDH->2BSKH", x)
                qkvs.append((q, k, v))

        # Concatenate the experts' projections along the sequence axis: each expert owns a
        # contiguous slice of tokens, and all tokens attend to each other below.
        q, k, v = (jnp.concatenate(y, axis=1) for y in zip(*qkvs, strict=True))

        q = _apply_rope(q, positions=positions)
        q *= self.configs[0].head_dim ** -0.5

        k = _apply_rope(k, positions=positions)

        assert q.dtype == k.dtype == v.dtype == dtype

        if kv_cache is not None:
            cache_k, cache_v = kv_cache
            k = jnp.concatenate([cache_k, k], axis=1)
            v = jnp.concatenate([cache_v, v], axis=1)

        # Group query heads by their shared k/v head for grouped-query attention.
        q = einops.rearrange(q, "B T (K G) H -> B T K G H", K=self.configs[0].num_kv_heads)
        logits = jnp.einsum("BTKGH,BSKH->BKGTS", q, k, preferred_element_type=jnp.float32)

        if attn_mask.shape != (q.shape[0], 1, q.shape[1], k.shape[1]):
            raise ValueError(
                f"Attention mask with shape {attn_mask.shape} does not match q {q.shape} and k {k.shape}."
            )

        # Large negative fill value (as in the reference Gemma implementation), used
        # instead of -inf so fully masked rows do not produce NaNs in the softmax.
        big_neg = -2.3819763e38
        masked_logits = jnp.where(attn_mask[:, :, None, :, :], logits, big_neg)

        probs = jax.nn.softmax(masked_logits, axis=-1).astype(dtype)

        encoded = jnp.einsum("BKGTS,BSKH->BTKGH", probs, v)
        encoded = einops.rearrange(encoded, "B T K G H -> B T (K G) H")

        # Split the attended sequence back into per-expert slices and project each slice
        # with its own output weights (widths may differ across experts).
        out = []
        start = 0
        for i, (x, config) in enumerate(zip(xs, self.configs, strict=True)):
            if x is not None:
                end = start + x.shape[1]
                out_einsum = lora.Einsum(
                    shape=(config.num_heads, config.head_dim, config.width),
                    name=_name("attn_vec_einsum", i),
                    init_fn=nn.initializers.lecun_normal(in_axis=(-3, -2), out_axis=-1),
                    lora_config=config.lora_configs.get("attn"),
                )
                out.append(out_einsum("BTNH,NHD->BTD", encoded[:, start:end]))
                start = end
            else:
                out.append(None)

        return out, (k, v)


@at.typecheck
class FeedForward(nn.Module):
    """Feed forward module."""

    features: int
    hidden_dim: int

    @nn.compact
    def __call__(self, x):
        dtype = x.dtype
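
        # GeGLU feed-forward as in Gemma: (gelu(x @ W0) * (x @ W1)) @ W_linear, where W0
        # and W1 are stored stacked in the single "gating_einsum" parameter below.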
        w_gating = self.param(
            "gating_einsum",
            nn.initializers.lecun_normal(in_axis=-2, out_axis=-1, batch_axis=(0,)),
            (2, self.features, self.hidden_dim),
        ).astype(dtype)
        ff_gate = jnp.dot(x, w_gating[0])
        gate_value = nn.gelu(ff_gate)

        ff1 = jnp.dot(x, w_gating[1])
        activations = gate_value * ff1

        w_linear = self.param(
            "linear",
            nn.initializers.lecun_normal(in_axis=-2, out_axis=-1),
            (self.hidden_dim, self.features),
        ).astype(dtype)
        outputs = jnp.dot(activations, w_linear)
        assert outputs.dtype == dtype
        return outputs


@at.typecheck
class Block(nn.Module):
    """Transformer block."""

    configs: tuple[Config, ...]

    dropout: float = 0.0
    dropout_bdims: tuple[int, ...] = ()

    @nn.compact
    def __call__(self, xs, kv_cache, positions, attn_mask, adarms_cond, deterministic=True):
        xs = sharding.activation_sharding_constraint(xs)
        drop = nn.Dropout(self.dropout, self.dropout_bdims) if self.dropout else lambda x, _: x

        attn = Attention(configs=self.configs, name="attn")

        pre_attn = []
        gates = []
        for i, x in enumerate(xs):
            if x is not None:
                x, gate = RMSNorm(name=_name("pre_attention_norm", i))(x, adarms_cond[i])
            # Append None placeholders too, keeping the lists aligned with xs across experts.
            pre_attn.append(x)
            gates.append(gate if x is not None else None)

        pre_attn = sharding.activation_sharding_constraint(pre_attn)
        post_attn, kv_cache = attn(pre_attn, positions, attn_mask, kv_cache)
        post_attn = jax.tree.map(lambda x: drop(x, deterministic), post_attn)
        post_attn = sharding.activation_sharding_constraint(post_attn)
        xs = [_gated_residual(x, y, gate) for x, y, gate in zip(xs, post_attn, gates, strict=True)]
        xs = sharding.activation_sharding_constraint(xs)

        out = []
        gates = []
        for i, (x, config) in enumerate(zip(xs, self.configs, strict=True)):
            if x is not None:
                x, gate = RMSNorm(name=_name("pre_ffw_norm", i))(x, adarms_cond[i])
                x = lora.FeedForward(
                    features=config.width,
                    hidden_dim=config.mlp_dim,
                    name=_name("mlp", i),
                    lora_config=config.lora_configs.get("ffn"),
                )(x)
            out.append(x)
            gates.append(gate if x is not None else None)

        out = sharding.activation_sharding_constraint(out)
        out = jax.tree.map(lambda x: drop(x, deterministic), out)
        xs = [_gated_residual(x, y, gate) for x, y, gate in zip(xs, out, gates, strict=True)]
        xs = sharding.activation_sharding_constraint(xs)

        return xs, kv_cache


KVCache: TypeAlias = tuple[at.Float[at.Array, "l b _t _k _h"], at.Float[at.Array, "l b _t _v _h"]]
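# The leading "l" axis stacks keys/values for all transformer layers; nn.scan in
# Module.setup scans over it, one (k, v) slice per Block.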


@at.typecheck
class Module(nn.Module):
    """Transformer model, supporting a mixture of different weights for different tokens."""

    configs: Sequence[Config]
    embed_dtype: str

    dropout: float = 0.0
    dropout_bdims: tuple[int, ...] = ()
    adarms: bool = False

    def setup(self):
        assert all(config.depth == self.configs[0].depth for config in self.configs)

        self.embedder = Embedder(
            vocab_size=PALIGEMMA_VOCAB_SIZE,
            embed_dim=self.configs[0].width,
            name="embedder",
        )
        block_cls = nn.remat(
            Block,
            prevent_cse=False,
            static_argnums=(6,),  # 0=self, 6=deterministic
            policy=jax.checkpoint_policies.nothing_saveable,
        )
        self.layers = nn.scan(
            block_cls,
            variable_axes={"params": 0},
            split_rngs={"params": True, "dropout": True},
            in_axes=(
                0,  # kv_cache: stacked along the layer axis
                nn.broadcast,  # positions
                nn.broadcast,  # attn_mask
                nn.broadcast,  # adarms_cond
                nn.broadcast,  # deterministic
            ),
            length=self.configs[0].depth,
        )(
            configs=self.configs,
            dropout=self.dropout,
            dropout_bdims=self.dropout_bdims,
        )
        self.final_norms = [RMSNorm(name=_name("final_norm", i)) for i in range(len(self.configs))]

    @at.typecheck
    def embed(self, tokens: at.Int[at.Array, "b t"]) -> at.Float[at.Array, "b t d"]:
        return self.embedder.encode(tokens).astype(self.embed_dtype)

    @at.typecheck
    def __call__(
        self,
        embedded: Sequence[at.Float[at.Array, "b _t _d"] | None],
        positions: at.Int[at.Array, "b t"],
        mask: at.Bool[at.Array, "b t s"],
        adarms_cond: Sequence[at.Float[at.Array, "b _d"] | None] | None = None,
        *,
        kv_cache: KVCache | None = None,
        deterministic: bool = True,
    ) -> tuple[Sequence[at.Float[at.Array, "b _t _d"] | None], KVCache]:
        embedded = jax.tree.map(lambda e: e.astype(self.embed_dtype), embedded)
        # Add a broadcast head axis: [b, t, s] -> [b, 1, t, s].
        mask = jnp.asarray(mask)[:, None, :, :]
        if adarms_cond is None:
            adarms_cond = [None] * len(self.configs)

        embedded, kv_cache = self.layers(embedded, kv_cache, positions, mask, adarms_cond, deterministic)

        assert all(e.dtype == jnp.dtype(self.embed_dtype) for e in embedded if e is not None)

        return [
            f(e, a)[0] if e is not None else e for f, e, a in zip(self.final_norms, embedded, adarms_cond, strict=True)
        ], kv_cache

    def init(self, use_adarms: Sequence[bool]):
        """Convenience method for initializing all parameters, necessary due to the quirks of linen."""
        self.embed(jnp.zeros((1, 1), dtype=jnp.int32))
        self(
            [jnp.zeros((1, 1, c.width)) for c in self.configs],
            jnp.zeros((1, len(self.configs)), dtype=jnp.int32),
            jnp.zeros((1, len(self.configs), len(self.configs)), dtype=bool),
            adarms_cond=[jnp.zeros((1, c.width)) if u else None for u, c in zip(use_adarms, self.configs, strict=True)],
        )


def _apply_rope(x, *, positions, max_wavelength=10_000):
    """Applies RoPE positions [B, L] to x [B, L, H, D]."""
    freq_exponents = (2.0 / x.shape[-1]) * jnp.arange(x.shape[-1] // 2, dtype=jnp.float32)
    timescale = max_wavelength**freq_exponents
    radians = positions[..., None] / timescale[None, None, :]
    radians = radians[..., None, :]
    assert radians.dtype == jnp.float32

    sin, cos = jnp.sin(radians), jnp.cos(radians)
    x1, x2 = jnp.split(x, 2, axis=-1)
    res = jnp.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
    assert res.dtype == jnp.float32
    return res.astype(x.dtype)
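
# Note: _apply_rope uses the "split-half" rotation convention (rotating the two halves
# x1, x2 of the feature dim) as in Gemma/big_vision, rather than the interleaved-pairs
# convention of the original RoPE paper.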


def _name(name, i):
    # The first expert's modules get unsuffixed names ("attn", "mlp", ...) so that its
    # parameters line up with the original Gemma/PaliGemma checkpoints; subsequent
    # experts are suffixed with their index.
    if i == 0:
        return name
    return f"{name}_{i}"


def _gated_residual(x, y, gate):
    assert (x is None) == (y is None)
    if x is None:
        return None
    if gate is None:
        return x + y
    return x + y * gate