|
|
import math |
|
|
from typing import Any, Literal |
|
|
|
|
|
import chex |
|
|
import einops
|
|
from flax import linen as nn |
|
|
from flax.linen.module import Module |
|
|
from flax.linen.module import compact |
|
|
from flax.struct import dataclass |
|
|
from flax.typing import Array |
|
|
import jax |
|
|
import jax.numpy as jnp |
|
|
|
|
|
|
|
|
class FsqCodebook(nn.Module): |
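    """Finite Scalar Quantization (FSQ) codebook.

    Inputs are projected down to ``len(bins_per_dim)`` dimensions, bounded with ``tanh``,
    and each dimension is rounded to a fixed number of bins. A token is the mixed-radix
    integer formed by the per-dimension bin indices, so the vocabulary size is
    ``math.prod(bins_per_dim)``.

    Minimal usage sketch (argument values are illustrative only)::

        codebook = FsqCodebook(input_dim=256, target_codebook_size=2**10, codebook_type="fsq")
        params = codebook.init(rng, x)              # x: (..., 256)
        tokens, recon = codebook.apply(params, x)   # tokens: (...,), recon: (..., 256)
    """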
|
|
input_dim: int |
|
|
target_codebook_size: int |
|
|
    codebook_type: Literal["fsq", "lfq", "custom"]
|
|
|
|
|
    _bins_per_dim: tuple[int, ...] | None = None
|
|
|
|
|
@property |
|
|
    def bins_per_dim(self) -> tuple[int, ...]:
|
|
if self._bins_per_dim is not None: |
|
|
return self._bins_per_dim |
|
|
|
|
|
if self.codebook_type == "fsq": |
|
|
return self._get_bins_fsq(self.target_codebook_size) |
|
|
elif self.codebook_type == "lfq": |
|
|
return self._get_bins_lfq(self.target_codebook_size) |
|
|
elif self.codebook_type == "custom": |
|
|
return self._get_bins_custom(self.target_codebook_size) |
|
|
else: |
|
|
raise ValueError(f"Codebook type {self.codebook_type} not supported.") |
|
|
|
|
|
@property |
|
|
def place_values(self) -> jnp.ndarray: |
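        # Mixed-radix place values for packing per-dimension digits into a single token:
        # place_values[i] is the product of all lower-dimension bin counts, e.g. bins
        # (8, 6, 5) give place values (1, 8, 48).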
|
|
place_values = [1] |
|
|
for b in self.bins_per_dim[:-1]: |
|
|
place_values.append(place_values[-1] * b) |
|
|
return jnp.array(place_values) |
|
|
|
|
|
@staticmethod |
|
|
    def _get_bins_fsq(target_codebook_size: int) -> tuple[int, ...]:
|
|
""" |
|
|
Get bins per dimension based on codebook size, from the original FSQ paper. |
|
|
""" |
|
|
if target_codebook_size == 2**8: |
|
|
return (8, 6, 5) |
|
|
elif target_codebook_size == 2**10: |
|
|
return (8, 5, 5, 5) |
|
|
elif target_codebook_size == 2**12: |
|
|
return (7, 5, 5, 5, 5) |
|
|
elif target_codebook_size == 2**14: |
|
|
return (8, 8, 8, 6, 5) |
|
|
elif target_codebook_size == 2**16: |
|
|
return (8, 8, 8, 5, 5, 5) |
|
|
else: |
|
|
raise ValueError(f"Codebook size {target_codebook_size} not supported.") |
|
|
|
|
|
@staticmethod |
|
|
    def _get_bins_custom(target_codebook_size: int) -> tuple[int, ...]:
|
|
if target_codebook_size == 2**8: |
|
|
return (16, 16) |
|
|
elif target_codebook_size == 2**10: |
|
|
return (32, 32) |
|
|
elif target_codebook_size == 2**12: |
|
|
return (64, 64) |
|
|
elif target_codebook_size == 2**14: |
|
|
return (128, 128) |
|
|
elif target_codebook_size == 2**16: |
|
|
return (256, 256) |
|
|
        raise ValueError(f"Codebook size {target_codebook_size} not supported.")
|
|
|
|
|
@staticmethod |
|
|
    def _get_bins_lfq(target_codebook_size: int) -> tuple[int, ...]:
|
|
""" |
|
|
        Get bins per dimension according to the Lookup-Free Quantization paper (two bins per dimension).
|
|
""" |
|
|
assert target_codebook_size & (target_codebook_size - 1) == 0, "Codebook size should be a power of two for LFQ" |
|
|
|
|
|
return (2,) * int(math.log2(target_codebook_size)) |
|
|
|
|
|
def setup(self): |
|
|
self.proj_down = nn.Dense(len(self.bins_per_dim)) |
|
|
self.proj_up = nn.Dense(self.input_dim) |
|
|
|
|
|
def __call__(self, inputs: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]: |
|
|
tokens, z = self.encode(inputs) |
|
|
output = self.decode(tokens, z_grad=z) |
|
|
return tokens, output |
|
|
|
|
|
def encode(self, inputs: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]: |
|
|
bases = jnp.array(self.bins_per_dim) |
|
|
|
|
|
x = self.proj_down(inputs) |
|
|
z = jnp.tanh(x) |
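        # z lies in (-1, 1); rescale each dimension to [0, b - 1] and round to the nearest
        # integer bin, e.g. b = 5 bins yields digits in {0, 1, 2, 3, 4}.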
|
|
|
|
|
|
|
|
digits = jnp.round((z + 1) * (bases - 1) / 2).astype(jnp.int32) |
|
|
tokens = self.undigitize(digits) |
|
|
|
|
|
return tokens, z |
|
|
|
|
|
def decode(self, tokens: jnp.ndarray, z_grad: jax.Array | None = None) -> jnp.ndarray: |
|
|
bases = jnp.array(self.bins_per_dim) |
|
|
digits = self.digitize(tokens) |
|
|
|
|
|
z_q = digits / (bases - 1) * 2 - 1 |
|
|
|
|
|
if z_grad is not None: |
|
|
chex.assert_equal_shape([z_q, z_grad]) |
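            # Straight-through estimator: the forward pass uses the quantized z_q, while
            # gradients bypass the rounding and flow through z_grad to the encoder.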
|
|
z_q = jax.lax.stop_gradient(z_q - z_grad) + z_grad |
|
|
|
|
|
return self.proj_up(z_q) |
|
|
|
|
|
def undigitize(self, digits: jnp.ndarray) -> jnp.ndarray: |
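        # Pack per-dimension bin indices into a single integer token (mixed-radix encoding).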
|
|
return jnp.sum(digits * jnp.array(self.place_values), axis=-1) |
|
|
|
|
|
def digitize(self, tokens: jnp.ndarray) -> jnp.ndarray: |
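        # Inverse of `undigitize`: recover the per-dimension bin indices from a token.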
|
|
return (tokens[..., None] // jnp.array(self.place_values)) % jnp.array(self.bins_per_dim) |
|
|
|
|
|
@property |
|
|
def vocab_size(self) -> int: |
|
|
return math.prod(self.bins_per_dim) |
|
|
|
|
|
|
|
|
class ResNetDownBlock(nn.Module): |
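    """1-D residual block with an optional strided convolution for downsampling.

    The skip connection is passed through a matching strided convolution whenever the
    stride or channel count changes, so it can be added to the main branch.
    """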
|
|
stride: int = 1 |
|
|
n_filters: int = 64 |
|
|
dropout_rate: float = 0.0 |
|
|
group_size: int = 32 |
|
|
|
|
|
@nn.compact |
|
|
def __call__(self, x: jnp.ndarray, *, train: bool = True) -> jnp.ndarray: |
|
|
skip = x |
|
|
|
|
|
if self.stride > 1 or x.shape[-1] != self.n_filters: |
|
|
skip = nn.Conv(self.n_filters, (self.stride,), (self.stride,), "SAME")(skip) |
|
|
|
|
|
x = nn.Conv(self.n_filters, (3,), (self.stride,), "SAME")(x) |
|
|
x = nn.GroupNorm(num_groups=self.n_filters // self.group_size)(x) |
|
|
x = nn.Dropout(self.dropout_rate)(x, deterministic=not train) |
|
|
x = nn.relu(x) |
|
|
x = nn.Conv(self.n_filters, (3,), (1,), "SAME")(x) |
|
|
|
|
|
return skip + x |
|
|
|
|
|
|
|
|
class ResNetUpBlock(nn.Module): |
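    """1-D residual block with transposed convolutions for upsampling.

    The skip connection is upsampled with a matching transposed convolution whenever the
    stride is greater than one.
    """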
|
|
stride: int = 1 |
|
|
n_filters: int = 64 |
|
|
dropout_rate: float = 0.0 |
|
|
group_size: int = 32 |
|
|
|
|
|
@nn.compact |
|
|
def __call__(self, x: jnp.ndarray, *, train: bool = True) -> jnp.ndarray: |
|
|
skip = x |
|
|
|
|
|
if self.stride > 1: |
|
|
skip = nn.ConvTranspose(self.n_filters, (self.stride,), (self.stride,), "SAME")(skip) |
|
|
|
|
|
x = nn.ConvTranspose(self.n_filters, (3,), (self.stride,), "SAME")(x) |
|
|
x = nn.GroupNorm(num_groups=self.n_filters // self.group_size)(x) |
|
|
x = nn.Dropout(self.dropout_rate)(x, deterministic=not train) |
|
|
x = nn.relu(x) |
|
|
x = nn.ConvTranspose(self.n_filters, (3,), (1,), "SAME")(x) |
|
|
|
|
|
return skip + x |
|
|
|
|
|
|
|
|
@dataclass |
|
|
class LfqCodebookOutput: |
|
|
tokens: jnp.ndarray |
|
|
z: jnp.ndarray |
|
|
z_q: jnp.ndarray |
|
|
token_log_probs: jnp.ndarray |
|
|
commit_loss: jnp.ndarray |
|
|
|
|
|
|
|
|
class LookupFreeQuantization(nn.Module): |
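    """Lookup-Free Quantization (LFQ) over ``num_dims`` binary latent dimensions.

    Each projected dimension is quantized independently to the nearest entry of
    ``{-1, +1}``; a token is the integer whose binary expansion gives the sign pattern.
    For example, with ``num_dims=3`` the pattern ``(+1, -1, +1)`` has bits ``(1, 0, 1)``
    and maps to token ``1 + 4 = 5``.
    """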
|
|
num_dims: int |
|
|
latent_dim: int |
|
|
|
|
|
def setup(self): |
|
|
self.codebook = jnp.array([-1, 1]) |
|
|
self.activation = nn.tanh |
|
|
|
|
|
self.project_down = nn.Dense(self.num_dims) |
|
|
self.project_up = nn.Dense(self.latent_dim) |
|
|
|
|
|
def encode(self, z: jnp.ndarray) -> jnp.ndarray: |
|
|
z = self.project_down(z) |
|
|
token_squared_distances = jnp.square(z[..., None] - self.codebook) |
|
|
token_bits = jnp.argmin(token_squared_distances, axis=-1) |
|
|
return jnp.sum(token_bits * (2 ** jnp.arange(self.num_dims)), axis=-1) |
|
|
|
|
|
def decode(self, tokens: jnp.ndarray) -> jnp.ndarray: |
|
|
        # Extract each bit of the token as 0/1, then map 0 -> -1 and 1 -> +1 via the codebook.
        token_bits = ((tokens[..., None] & (2 ** jnp.arange(self.num_dims))) > 0).astype(jnp.int32)
|
|
return self.project_up(self.codebook[token_bits]) |
|
|
|
|
|
def loss(self, x: jnp.ndarray) -> LfqCodebookOutput: |
|
|
z = self.project_down(x) |
|
|
z = self.activation(z) |
|
|
|
|
|
token_squared_distances = jnp.square(z[..., None] - self.codebook) |
|
|
tokens = jnp.argmin(token_squared_distances, axis=-1) |
|
|
|
|
|
token_bit_log_probs = -token_squared_distances |
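        # Unnormalized per-token log-probs: for every possible token, sum the per-dimension
        # log-probs (negative squared distances) of the bits in its binary expansion.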
|
|
|
|
|
        # Binary expansion of every possible token: entry (i, c) is bit i of token c (0 or 1).
        token_bit_expansions = (
            jnp.bitwise_and(jnp.arange(2**self.num_dims)[None, :], 2 ** jnp.arange(self.num_dims)[:, None]) > 0
        ).astype(jnp.int32)
|
|
token_log_probs = ( |
|
|
token_bit_log_probs[..., 0] @ (1 - token_bit_expansions) |
|
|
+ token_bit_log_probs[..., 1] @ token_bit_expansions |
|
|
) |
|
|
token_log_probs = jax.lax.stop_gradient(jax.nn.log_softmax(token_log_probs, axis=-1)) |
|
|
chex.assert_shape(token_log_probs, (*x.shape[:-1], 2**self.num_dims)) |
|
|
|
|
|
z_q = self.codebook[tokens] |
|
|
commit_loss = jnp.square(z - z_q).mean() |
|
|
z_q = jax.lax.stop_gradient(z_q - z) + z |
|
|
|
|
|
z_q = self.project_up(z_q) |
|
|
z = self.project_up(z) |
|
|
|
|
|
tokens = jnp.sum(tokens * (len(self.codebook) ** jnp.arange(self.num_dims)), axis=-1) |
|
|
return LfqCodebookOutput( |
|
|
tokens=tokens, |
|
|
z=z, |
|
|
z_q=z_q, |
|
|
            token_log_probs=token_log_probs,
|
|
commit_loss=commit_loss, |
|
|
) |
|
|
|
|
|
|
|
|
def make_block_causal_attention_matrix(q: jnp.ndarray, k: jnp.ndarray, bs_q: int, bs_k: int) -> jnp.ndarray:
    # Block-causal mask over positions: query i may attend to key j iff i // bs_q >= j // bs_k.
    q_pos = jnp.broadcast_to(jnp.arange(q.shape[-1]), q.shape)
    k_pos = jnp.broadcast_to(jnp.arange(k.shape[-1]), k.shape)
    return nn.make_attention_mask(q_pos, k_pos, pairwise_fn=lambda i, j: jnp.greater_equal(i // bs_q, j // bs_k))
|
|
|
|
|
|
|
|
class GeGLU(Module): |
|
|
"""Gated Linear Unit with GELU (GeGLU) activation function. |
|
|
GeGLU is a Flax layer that combines a linear transformation with a GELU |
|
|
activation function in a gating mechanism. It is often used in Transformer models |
|
|
to provide non-linear capabilities while preserving a strong linear component. |
|
|
|
|
|
Attributes: |
|
|
        output_dim: the number of output features (default: -1, which keeps the input's last dimension).
|
|
""" |
|
|
|
|
|
output_dim: int = -1 |
|
|
|
|
|
@compact |
|
|
def __call__(self, inputs: Array) -> Array: |
|
|
"""Applies the GeGLU activation to the inputs. |
|
|
Args: |
|
|
inputs: the nd-array to apply the GeGLU activation function to. |
|
|
Returns: |
|
|
The transformed input. |
|
|
""" |
|
|
output_dim = inputs.shape[-1] if self.output_dim == -1 else self.output_dim |
|
|
|
|
|
x = nn.Dense(output_dim * 2)(inputs) |
|
|
x, gate = x[..., :output_dim], x[..., output_dim:] |
|
|
return x * nn.gelu(gate) |
|
|
|
|
|
|
|
|
class CrossAttentionLayer(nn.Module): |
|
|
dropout_rate: float = 0.0 |
|
|
    num_heads: int | None = None
|
|
causal: bool = False |
|
|
mlp_ratio: float = 4.0 |
|
|
|
|
|
@nn.compact |
|
|
def __call__( |
|
|
self, |
|
|
x: jnp.ndarray, |
|
|
y: jnp.ndarray, |
|
|
*, |
|
|
mask_self: jnp.ndarray | None = None, |
|
|
mask_cross: jnp.ndarray | None = None, |
|
|
train: bool = True, |
|
|
) -> jnp.ndarray: |
|
|
d_embed = x.shape[-1] |
|
|
seq_len_q = x.shape[-2] |
|
|
seq_len_k = y.shape[-2] |
|
|
|
|
|
if self.causal: |
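            # Derive block sizes from the ratio of sequence lengths and build block-causal
            # masks over positions; note this overrides any masks supplied by the caller.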
|
|
|
|
|
bs_q = max(seq_len_q // seq_len_k, 1) |
|
|
bs_k = max(seq_len_k // seq_len_q, 1) |
|
|
|
|
|
mask_self = nn.make_causal_mask(x[..., 0]) |
|
|
mask_cross = make_block_causal_attention_matrix(x[..., 0], y[..., 0], bs_q, bs_k) |
|
|
|
|
|
|
|
|
skip = x |
|
|
x = nn.LayerNorm()(x) |
|
|
x = nn.MultiHeadDotProductAttention( |
|
|
num_heads=self.num_heads or d_embed // 64, |
|
|
dropout_rate=self.dropout_rate, |
|
|
deterministic=not train, |
|
|
)(x, x, x, mask=mask_self) |
|
|
x = skip + x |
|
|
|
|
|
|
|
|
skip = x |
|
|
x = nn.LayerNorm()(x) |
|
|
x = nn.MultiHeadDotProductAttention( |
|
|
num_heads=self.num_heads or d_embed // 64, |
|
|
dropout_rate=self.dropout_rate, |
|
|
deterministic=not train, |
|
|
)(x, y, y, mask=mask_cross) |
|
|
x = skip + x |
|
|
|
|
|
|
|
|
skip = x |
|
|
x = nn.LayerNorm()(x) |
|
|
x = nn.Dense(int(d_embed * self.mlp_ratio))(x) |
|
|
x = nn.Dropout(self.dropout_rate)(x, deterministic=not train) |
|
|
x = GeGLU()(x) |
|
|
x = nn.Dense(d_embed)(x) |
|
|
return skip + x |
|
|
|
|
|
|
|
|
def sinusoidal_pe_init(_, shape: tuple[int, int]) -> jnp.ndarray: |
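    """Parameter initializer producing fixed sinusoidal positional encodings of shape
    ``(seq_len, d_embed)``: sine terms fill the first half of the embedding, cosine terms
    the second half. The first argument (the PRNG key) is unused."""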
|
|
seq_len, d_embed = shape |
|
|
|
|
|
position = jnp.arange(0, seq_len, 1) |
|
|
div_term = jnp.exp(jnp.arange(0, d_embed, 2) * -(jnp.log(10000.0) / d_embed)) |
|
|
return jnp.concatenate( |
|
|
[ |
|
|
jnp.sin(position[:, jnp.newaxis] * div_term), |
|
|
jnp.cos(position[:, jnp.newaxis] * div_term), |
|
|
], |
|
|
axis=-1, |
|
|
) |
|
|
|
|
|
|
|
|
class TokenizerEncoderDecoder(nn.Module): |
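    """Cross-attention encoder/decoder that maps ``num_cross_tokens`` input embeddings to
    ``num_tokens`` output embeddings.

    A set of learned query embeddings (initialized with sinusoidal positional encodings)
    cross-attends to the inputs over ``num_layers`` CrossAttentionLayer blocks. When
    ``use_state_conditioning`` is set, a projected state vector is appended to the
    keys/values.
    """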
|
|
num_tokens: int |
|
|
num_cross_tokens: int |
|
|
num_layers: int |
|
|
causal: bool |
|
|
|
|
|
mlp_ratio: float = 4.0 |
|
|
use_state_conditioning: bool = False |
|
|
|
|
|
@nn.compact |
|
|
def __call__( |
|
|
self, |
|
|
y: jnp.ndarray, |
|
|
*, |
|
|
train: bool = True, |
|
|
state_conditioning: jnp.ndarray | None = None, |
|
|
mask: jnp.ndarray | None = None, |
|
|
) -> jnp.ndarray: |
|
|
x = self.param("q_embed", sinusoidal_pe_init, (self.num_tokens, y.shape[-1])) |
|
|
x = jax.numpy.broadcast_to(x, y.shape[:-2] + x.shape[-2:]) |
|
|
|
|
|
if mask is not None: |
|
|
|
|
|
chex.assert_equal_shape([y[..., 0], mask]) |
|
|
attn_mask = einops.repeat(mask, "... kv -> ... 1 q kv", q=self.num_tokens) |
|
|
else: |
|
|
attn_mask = jnp.ones((*y.shape[:-2], 1, self.num_tokens, self.num_cross_tokens)) |
|
|
|
|
|
if self.use_state_conditioning: |
|
|
assert state_conditioning is not None, "State conditioning is required for this model." |
|
|
state_embed = nn.Dense(y.shape[-1], name="state_proj")(state_conditioning)[..., None, :] |
|
|
y = jnp.concatenate([y, state_embed], axis=-2) |
|
|
attn_mask = jnp.concatenate([attn_mask, jnp.ones_like(attn_mask[..., 0:1])], axis=-1) |
|
|
|
|
|
y = y + self.param("y_pos_enc", sinusoidal_pe_init, y.shape[-2:]) |
|
|
|
|
|
for _ in range(self.num_layers): |
|
|
x = CrossAttentionLayer(causal=self.causal, mlp_ratio=self.mlp_ratio)( |
|
|
x, y, train=train, mask_self=None, mask_cross=attn_mask |
|
|
) |
|
|
|
|
|
return x |
|
|
|
|
|
|
|
|
class FsqAttentionTokenizer(nn.Module): |
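    """Attention-based tokenizer for fixed-horizon action chunks with an FSQ codebook.

    ``tokenize`` encodes an action chunk of shape ``(..., data_horizon, data_dim)`` into
    ``num_tokens`` discrete tokens, ``detokenize`` reconstructs the chunk, and ``loss``
    returns the MSE reconstruction objective used for training.

    Minimal usage sketch (argument values are illustrative only)::

        tok = FsqAttentionTokenizer(embed_dim=128, data_dim=7, data_horizon=16,
                                    num_tokens=4, num_layers=2, target_codebook_size=2**10)
        params = tok.init(rng, action, train=False)                 # action: (..., 16, 7)
        tokens, _ = tok.apply(params, action, method=tok.tokenize)
        recon = tok.apply(params, tokens, method=tok.detokenize)
    """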
|
|
embed_dim: int |
|
|
data_dim: int |
|
|
data_horizon: int |
|
|
num_tokens: int |
|
|
num_layers: int |
|
|
target_codebook_size: int |
|
|
causal: bool = False |
|
|
mlp_ratio: float = 2.0 |
|
|
|
|
|
bound: float | None = None |
|
|
|
|
|
use_state_conditioning: bool = False |
|
|
|
|
|
@property |
|
|
def vocab_size(self) -> int: |
|
|
        # Must match the bins actually used by the codebook (codebook_type="custom" in setup).
        return math.prod(FsqCodebook._get_bins_custom(self.target_codebook_size))
|
|
|
|
|
def setup(self): |
|
|
self.proj = nn.Dense(self.embed_dim) |
|
|
self.encoder = TokenizerEncoderDecoder( |
|
|
num_tokens=self.num_tokens, |
|
|
num_cross_tokens=self.data_horizon, |
|
|
num_layers=self.num_layers, |
|
|
causal=self.causal, |
|
|
use_state_conditioning=self.use_state_conditioning, |
|
|
mlp_ratio=self.mlp_ratio, |
|
|
) |
|
|
self.codebook = FsqCodebook( |
|
|
input_dim=self.embed_dim, |
|
|
target_codebook_size=self.target_codebook_size, |
|
|
codebook_type="custom", |
|
|
) |
|
|
self.decoder = TokenizerEncoderDecoder( |
|
|
num_tokens=self.data_horizon, |
|
|
num_cross_tokens=self.num_tokens, |
|
|
num_layers=self.num_layers, |
|
|
causal=self.causal, |
|
|
use_state_conditioning=self.use_state_conditioning, |
|
|
mlp_ratio=self.mlp_ratio, |
|
|
) |
|
|
|
|
|
self.proj_mean = nn.Dense(self.data_dim) |
|
|
self.out_scale = self.param("out_scale", lambda _: jnp.full((), 1.0)) |
|
|
|
|
|
def tokenize( |
|
|
self, action: jnp.ndarray, *, obs: jnp.ndarray | None = None, train: bool = False |
|
|
) -> tuple[jnp.ndarray, jnp.ndarray]: |
|
|
if self.bound is not None: |
|
|
action = jnp.clip(action, -self.bound, self.bound) |
|
|
|
|
|
x = self.proj(action) |
|
|
x = self.encoder(x, train=train, state_conditioning=obs) |
|
|
|
|
|
return self.codebook.encode(x) |
|
|
|
|
|
def detokenize(self, tokens: jnp.ndarray, *, obs: jnp.ndarray | None = None) -> jnp.ndarray: |
|
|
x = self.decoder(self.codebook.decode(tokens), state_conditioning=obs) |
|
|
mean = self.proj_mean(x) |
|
|
return mean * self.out_scale |
|
|
|
|
|
def loss( |
|
|
self, action: jnp.ndarray, *, obs: jnp.ndarray | None = None, train: bool = True |
|
|
) -> tuple[jnp.ndarray, dict[str, jnp.ndarray]]: |
|
|
|
|
|
x = self.proj(action) |
|
|
z = self.encoder(x, train=train, state_conditioning=obs) |
|
|
|
|
|
|
|
|
tokens, z = self.codebook(z) |
|
|
|
|
|
|
|
|
x = self.decoder(z, train=train, state_conditioning=obs) |
|
|
mean = self.proj_mean(x) * self.out_scale |
|
|
|
|
|
mse = jnp.mean(jnp.square(action - mean)) |
|
|
mae = jnp.mean(jnp.abs(action - mean)) |
|
|
|
|
|
return mse, { |
|
|
"mse": mse, |
|
|
"mae": mae, |
|
|
} |
|
|
|
|
|
def __call__(self, *args: Any, **kwargs: Any) -> tuple[jnp.ndarray, dict[str, jnp.ndarray]]: |
|
|
""" |
|
|
Dummy for .init |
|
|
""" |
|
|
return self.loss(*args, **kwargs) |
|
|
|