import math
from typing import Any, Literal

import chex
import einops
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.linen.module import Module, compact
from flax.struct import dataclass
from flax.typing import Array


class FsqCodebook(nn.Module):
    input_dim: int
    target_codebook_size: int
    codebook_type: Literal["fsq", "lfq", "custom"]
    _bins_per_dim: tuple[int, ...] | None = None

    @property
    def bins_per_dim(self) -> tuple[int, ...]:
        if self._bins_per_dim is not None:
            return self._bins_per_dim
        if self.codebook_type == "fsq":
            return self._get_bins_fsq(self.target_codebook_size)
        elif self.codebook_type == "lfq":  # noqa: RET505
            return self._get_bins_lfq(self.target_codebook_size)
        elif self.codebook_type == "custom":
            return self._get_bins_custom(self.target_codebook_size)
        else:
            raise ValueError(f"Codebook type {self.codebook_type} not supported.")

    @property
    def place_values(self) -> jnp.ndarray:
        # Mixed-radix place values: dimension i advances the code by prod(bins[:i]).
        place_values = [1]
        for b in self.bins_per_dim[:-1]:
            place_values.append(place_values[-1] * b)
        return jnp.array(place_values)

    @staticmethod
    def _get_bins_fsq(target_codebook_size: int) -> tuple[int, ...]:
        """
        Get bins per dimension based on codebook size, from the original FSQ paper.
        """
        if target_codebook_size == 2**8:
            return (8, 6, 5)
        elif target_codebook_size == 2**10:  # noqa: RET505
            return (8, 5, 5, 5)
        elif target_codebook_size == 2**12:
            return (7, 5, 5, 5, 5)
        elif target_codebook_size == 2**14:
            return (8, 8, 8, 6, 5)
        elif target_codebook_size == 2**16:
            return (8, 8, 8, 5, 5, 5)
        else:
            raise ValueError(f"Codebook size {target_codebook_size} not supported.")

    @staticmethod
    def _get_bins_custom(target_codebook_size: int) -> tuple[int, ...]:
        if target_codebook_size == 2**8:
            return (16, 16)
        elif target_codebook_size == 2**10:  # noqa: RET505
            return (32, 32)
        elif target_codebook_size == 2**12:
            return (64, 64)
        elif target_codebook_size == 2**14:
            return (128, 128)
        elif target_codebook_size == 2**16:
            return (256, 256)
        else:
            raise ValueError(f"Codebook size {target_codebook_size} not supported.")

    @staticmethod
    def _get_bins_lfq(target_codebook_size: int) -> tuple[int, ...]:
        """
        Get bins per dimension according to the Lookup-Free Quantization paper (2 bins per dimension).
        """
        assert target_codebook_size & (target_codebook_size - 1) == 0, "Codebook size should be a power of two for LFQ"
        return (2,) * int(math.log2(target_codebook_size))

    def setup(self):
        self.proj_down = nn.Dense(len(self.bins_per_dim))
        self.proj_up = nn.Dense(self.input_dim)

    def __call__(self, inputs: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:
        tokens, z = self.encode(inputs)
        output = self.decode(tokens, z_grad=z)
        return tokens, output

    def encode(self, inputs: jnp.ndarray) -> tuple[jnp.ndarray, jnp.ndarray]:
        bases = jnp.array(self.bins_per_dim)
        x = self.proj_down(inputs)
        z = jnp.tanh(x)
        # Quantize: map each dimension of z in [-1, 1] onto `bases` uniform levels.
        digits = jnp.round((z + 1) * (bases - 1) / 2).astype(jnp.int32)
        tokens = self.undigitize(digits)
        return tokens, z

    def decode(self, tokens: jnp.ndarray, z_grad: jax.Array | None = None) -> jnp.ndarray:
        bases = jnp.array(self.bins_per_dim)
        digits = self.digitize(tokens)
        z_q = digits / (bases - 1) * 2 - 1
        if z_grad is not None:
            # Straight-through estimator: forward pass uses z_q, backward pass sees z_grad.
            chex.assert_equal_shape([z_q, z_grad])
            z_q = jax.lax.stop_gradient(z_q - z_grad) + z_grad
        return self.proj_up(z_q)

    def undigitize(self, digits: jnp.ndarray) -> jnp.ndarray:
        return jnp.sum(digits * self.place_values, axis=-1)

    def digitize(self, tokens: jnp.ndarray) -> jnp.ndarray:
        return (tokens[..., None] // self.place_values) % jnp.array(self.bins_per_dim)

    @property
    def vocab_size(self) -> int:
        return math.prod(self.bins_per_dim)
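
# Usage sketch (illustrative only; shapes, sizes, and the PRNG seed are arbitrary
# assumptions, not values from any particular configuration):
#
#   codebook = FsqCodebook(input_dim=32, target_codebook_size=2**10, codebook_type="fsq")
#   x = jnp.zeros((4, 7, 32))                      # (batch, seq, input_dim)
#   params = codebook.init(jax.random.PRNGKey(0), x)
#   tokens, recon = codebook.apply(params, x)      # tokens: (4, 7) ints, recon: (4, 7, 32)
#
# With the "fsq" bins (8, 5, 5, 5), the effective vocabulary is 8 * 5 * 5 * 5 = 1000,
# slightly under the 2**10 target; the vocab_size property reports the exact product.
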
class ResNetDownBlock(nn.Module):
    stride: int = 1
    n_filters: int = 64
    dropout_rate: float = 0.0
    group_size: int = 32

    @nn.compact
    def __call__(self, x: jnp.ndarray, *, train: bool = True) -> jnp.ndarray:
        skip = x
        if self.stride > 1 or x.shape[-1] != self.n_filters:
            # Project the skip connection whenever the output shape changes.
            skip = nn.Conv(self.n_filters, (self.stride,), (self.stride,), "SAME")(skip)
        x = nn.Conv(self.n_filters, (3,), (self.stride,), "SAME")(x)
        x = nn.GroupNorm(num_groups=self.n_filters // self.group_size)(x)
        x = nn.Dropout(self.dropout_rate)(x, deterministic=not train)
        x = nn.relu(x)
        x = nn.Conv(self.n_filters, (3,), (1,), "SAME")(x)
        return skip + x


class ResNetUpBlock(nn.Module):
    stride: int = 1
    n_filters: int = 64
    dropout_rate: float = 0.0
    group_size: int = 32

    @nn.compact
    def __call__(self, x: jnp.ndarray, *, train: bool = True) -> jnp.ndarray:
        skip = x
        if self.stride > 1:
            skip = nn.ConvTranspose(self.n_filters, (self.stride,), (self.stride,), "SAME")(skip)
        x = nn.ConvTranspose(self.n_filters, (3,), (self.stride,), "SAME")(x)
        x = nn.GroupNorm(num_groups=self.n_filters // self.group_size)(x)
        x = nn.Dropout(self.dropout_rate)(x, deterministic=not train)
        x = nn.relu(x)
        x = nn.ConvTranspose(self.n_filters, (3,), (1,), "SAME")(x)
        return skip + x


@dataclass
class LfqCodebookOutput:
    tokens: jnp.ndarray
    z: jnp.ndarray
    z_q: jnp.ndarray
    token_log_probs: jnp.ndarray
    commit_loss: jnp.ndarray


class LookupFreeQuantization(nn.Module):
    num_dims: int
    latent_dim: int

    def setup(self):
        self.codebook = jnp.array([-1, 1])
        self.activation = nn.tanh
        self.project_down = nn.Dense(self.num_dims)
        self.project_up = nn.Dense(self.latent_dim)

    def encode(self, z: jnp.ndarray) -> jnp.ndarray:
        z = self.project_down(z)
        token_squared_distances = jnp.square(z[..., None] - self.codebook)
        token_bits = jnp.argmin(token_squared_distances, axis=-1)
        # Pack the per-dimension bits into a single integer token.
        return jnp.sum(token_bits * (2 ** jnp.arange(self.num_dims)), axis=-1)

    def decode(self, tokens: jnp.ndarray) -> jnp.ndarray:
        # Unpack bit i as 0/1 (not 0/2**i) before indexing the two-entry codebook.
        token_bits = ((tokens[..., None] & (2 ** jnp.arange(self.num_dims))) != 0).astype(jnp.int32)
        return self.project_up(self.codebook[token_bits])

    def loss(self, x: jnp.ndarray) -> LfqCodebookOutput:
        z = self.project_down(x)
        z = self.activation(z)
        token_squared_distances = jnp.square(z[..., None] - self.codebook)
        tokens = jnp.argmin(token_squared_distances, axis=-1)
        token_bit_log_probs = -token_squared_distances
        # Compute token log probs for tokens 0..2**num_dims - 1 by summing the
        # log-probs of their bits. The expansion matrix holds 0/1 indicators of
        # whether bit i is set in each token.
        token_bit_expansions = (
            jnp.bitwise_and(jnp.arange(2**self.num_dims)[None, :], 2 ** jnp.arange(self.num_dims)[:, None]) != 0
        ).astype(jnp.int32)
        token_log_probs = (
            token_bit_log_probs[..., 0] @ (1 - token_bit_expansions)
            + token_bit_log_probs[..., 1] @ token_bit_expansions
        )  # (batch_size, num_tokens, 2 ** num_dims)
        token_log_probs = jax.lax.stop_gradient(jax.nn.log_softmax(token_log_probs, axis=-1))
        chex.assert_shape(token_log_probs, (*x.shape[:-1], 2**self.num_dims))
        z_q = self.codebook[tokens]
        commit_loss = jnp.square(z - z_q).mean()
        # Straight-through estimator through the quantization step.
        z_q = jax.lax.stop_gradient(z_q - z) + z
        z_q = self.project_up(z_q)
        z = self.project_up(z)
        tokens = jnp.sum(tokens * (len(self.codebook) ** jnp.arange(self.num_dims)), axis=-1)
        return LfqCodebookOutput(
            tokens=tokens,
            z=z,
            z_q=z_q,
            token_log_probs=token_log_probs,
            commit_loss=commit_loss,
        )


def make_block_causal_attention_matrix(q: jnp.ndarray, k: jnp.ndarray, bs_q: int, bs_k: int) -> jnp.ndarray:
    """Block-causal mask: query i may attend key j iff i // bs_q >= j // bs_k.

    Position indices are built from the shapes of `q` and `k`, since
    `nn.make_attention_mask` applies `pairwise_fn` to the array values.
    """
    q_pos = jnp.broadcast_to(jnp.arange(q.shape[-1]), q.shape)
    k_pos = jnp.broadcast_to(jnp.arange(k.shape[-1]), k.shape)
    return nn.make_attention_mask(q_pos, k_pos, pairwise_fn=lambda x, y: jnp.greater_equal(x // bs_q, y // bs_k))
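
# Illustration (shapes chosen arbitrarily, not executed at import): with 4 queries
# attending 2 keys, CrossAttentionLayer below picks bs_q = 2 and bs_k = 1, so
# query i may attend key j iff i // 2 >= j:
#
#   queries 0, 1 -> key 0
#   queries 2, 3 -> keys 0 and 1
#
# i.e. a key only becomes visible once the block of queries it corresponds to has started.
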
class GeGLU(Module):
    """Gated Linear Unit with GELU (GeGLU) activation function.

    GeGLU is a Flax layer that combines a linear transformation with a GELU
    activation function in a gating mechanism. It is often used in Transformer
    models to provide non-linear capabilities while preserving a strong linear
    component.

    Attributes:
      output_dim: the number of output features (default: -1, i.e. match the
        last dimension of the input).
    """

    output_dim: int = -1

    @compact
    def __call__(self, inputs: Array) -> Array:
        """Applies the GeGLU activation to the inputs.

        Args:
          inputs: the nd-array to apply the GeGLU activation function to.

        Returns:
          The transformed input.
        """
        output_dim = inputs.shape[-1] if self.output_dim == -1 else self.output_dim
        x = nn.Dense(output_dim * 2)(inputs)
        x, gate = x[..., :output_dim], x[..., output_dim:]
        return x * nn.gelu(gate)


class CrossAttentionLayer(nn.Module):
    dropout_rate: float = 0.0
    num_heads: int | None = None
    causal: bool = False
    mlp_ratio: float = 4.0

    @nn.compact
    def __call__(
        self,
        x: jnp.ndarray,
        y: jnp.ndarray,
        *,
        mask_self: jnp.ndarray | None = None,
        mask_cross: jnp.ndarray | None = None,
        train: bool = True,
    ) -> jnp.ndarray:
        d_embed = x.shape[-1]
        seq_len_q = x.shape[-2]
        seq_len_k = y.shape[-2]
        if self.causal:
            # One block size will be 1; the other is the ratio of sequence lengths.
            bs_q = max(seq_len_q // seq_len_k, 1)
            bs_k = max(seq_len_k // seq_len_q, 1)
            mask_self = nn.make_causal_mask(x[..., 0])
            mask_cross = make_block_causal_attention_matrix(x[..., 0], y[..., 0], bs_q, bs_k)

        # Self-attention block
        skip = x
        x = nn.LayerNorm()(x)
        x = nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads or d_embed // 64,
            dropout_rate=self.dropout_rate,
            deterministic=not train,
        )(x, x, x, mask=mask_self)
        x = skip + x

        # Cross-attention block
        skip = x
        x = nn.LayerNorm()(x)
        x = nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads or d_embed // 64,
            dropout_rate=self.dropout_rate,
            deterministic=not train,
        )(x, y, y, mask=mask_cross)
        x = skip + x

        # MLP block
        skip = x
        x = nn.LayerNorm()(x)
        x = nn.Dense(int(d_embed * self.mlp_ratio))(x)
        x = nn.Dropout(self.dropout_rate)(x, deterministic=not train)
        x = GeGLU()(x)
        x = nn.Dense(d_embed)(x)
        return skip + x


def sinusoidal_pe_init(_, shape: tuple[int, int]) -> jnp.ndarray:
    """Initializer for fixed sinusoidal positional embeddings (ignores the PRNG key)."""
    seq_len, d_embed = shape
    position = jnp.arange(0, seq_len, 1)
    div_term = jnp.exp(jnp.arange(0, d_embed, 2) * -(jnp.log(10000.0) / d_embed))
    return jnp.concatenate(
        [
            jnp.sin(position[:, jnp.newaxis] * div_term),
            jnp.cos(position[:, jnp.newaxis] * div_term),
        ],
        axis=-1,
    )
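
# Example (illustrative): sinusoidal_pe_init returns a fixed (seq_len, d_embed) table
# with all sine components followed by all cosine components (concatenated, not
# interleaved). For shape (4, 6), row 0 is
#   [sin(0), sin(0), sin(0), cos(0), cos(0), cos(0)] = [0, 0, 0, 1, 1, 1].
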
class TokenizerEncoderDecoder(nn.Module):
    num_tokens: int
    num_cross_tokens: int
    num_layers: int
    causal: bool
    mlp_ratio: float = 4.0
    use_state_conditioning: bool = False

    @nn.compact
    def __call__(
        self,
        y: jnp.ndarray,
        *,
        train: bool = True,
        state_conditioning: jnp.ndarray | None = None,
        mask: jnp.ndarray | None = None,
    ) -> jnp.ndarray:
        # Learned queries, initialized to a sinusoidal table and broadcast over batch dims.
        x = self.param("q_embed", sinusoidal_pe_init, (self.num_tokens, y.shape[-1]))
        x = jnp.broadcast_to(x, y.shape[:-2] + x.shape[-2:])

        if mask is not None:
            # mask is (batch_dims..., num_cross_tokens)
            chex.assert_equal_shape([y[..., 0], mask])
            attn_mask = einops.repeat(mask, "... kv -> ... 1 q kv", q=self.num_tokens)
        else:
            attn_mask = jnp.ones((*y.shape[:-2], 1, self.num_tokens, self.num_cross_tokens))

        if self.use_state_conditioning:
            assert state_conditioning is not None, "State conditioning is required for this model."
            state_embed = nn.Dense(y.shape[-1], name="state_proj")(state_conditioning)[..., None, :]
            y = jnp.concatenate([y, state_embed], axis=-2)
            attn_mask = jnp.concatenate([attn_mask, jnp.ones_like(attn_mask[..., 0:1])], axis=-1)

        y = y + self.param("y_pos_enc", sinusoidal_pe_init, y.shape[-2:])
        for _ in range(self.num_layers):
            x = CrossAttentionLayer(causal=self.causal, mlp_ratio=self.mlp_ratio)(
                x, y, train=train, mask_self=None, mask_cross=attn_mask
            )
        return x


class FsqAttentionTokenizer(nn.Module):
    embed_dim: int
    data_dim: int
    data_horizon: int
    num_tokens: int
    num_layers: int
    target_codebook_size: int
    causal: bool = False
    mlp_ratio: float = 2.0
    bound: float | None = None
    use_state_conditioning: bool = False

    @property
    def vocab_size(self) -> int:
        # Must match the "custom" bins used for the codebook constructed in setup().
        return math.prod(FsqCodebook._get_bins_custom(self.target_codebook_size))  # noqa: SLF001

    def setup(self):
        self.proj = nn.Dense(self.embed_dim)
        self.encoder = TokenizerEncoderDecoder(
            num_tokens=self.num_tokens,
            num_cross_tokens=self.data_horizon,
            num_layers=self.num_layers,
            causal=self.causal,
            use_state_conditioning=self.use_state_conditioning,
            mlp_ratio=self.mlp_ratio,
        )
        self.codebook = FsqCodebook(
            input_dim=self.embed_dim,
            target_codebook_size=self.target_codebook_size,
            codebook_type="custom",
        )
        self.decoder = TokenizerEncoderDecoder(
            num_tokens=self.data_horizon,
            num_cross_tokens=self.num_tokens,
            num_layers=self.num_layers,
            causal=self.causal,
            use_state_conditioning=self.use_state_conditioning,
            mlp_ratio=self.mlp_ratio,
        )
        self.proj_mean = nn.Dense(self.data_dim)
        self.out_scale = self.param("out_scale", lambda _: jnp.full((), 1.0))

    def tokenize(
        self, action: jnp.ndarray, *, obs: jnp.ndarray | None = None, train: bool = False
    ) -> tuple[jnp.ndarray, jnp.ndarray]:
        if self.bound is not None:
            action = jnp.clip(action, -self.bound, self.bound)
        x = self.proj(action)
        x = self.encoder(x, train=train, state_conditioning=obs)
        return self.codebook.encode(x)

    def detokenize(self, tokens: jnp.ndarray, *, obs: jnp.ndarray | None = None) -> jnp.ndarray:
        x = self.decoder(self.codebook.decode(tokens), state_conditioning=obs)
        mean = self.proj_mean(x)
        return mean * self.out_scale

    def loss(
        self, action: jnp.ndarray, *, obs: jnp.ndarray | None = None, train: bool = True
    ) -> tuple[jnp.ndarray, dict[str, jnp.ndarray]]:
        # Encode
        x = self.proj(action)
        z = self.encoder(x, train=train, state_conditioning=obs)
        # Quantize (straight-through gradients flow through the codebook output)
        tokens, z = self.codebook(z)
        # Decode
        x = self.decoder(z, train=train, state_conditioning=obs)
        mean = self.proj_mean(x) * self.out_scale
        mse = jnp.mean(jnp.square(action - mean))
        mae = jnp.mean(jnp.abs(action - mean))
        return mse, {
            "mse": mse,
            "mae": mae,
        }

    def __call__(self, *args: Any, **kwargs: Any) -> tuple[jnp.ndarray, dict[str, jnp.ndarray]]:
        """Dummy for .init"""
        return self.loss(*args, **kwargs)
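
# Minimal smoke test (a sketch: every hyperparameter below is an arbitrary
# assumption chosen only to exercise the shapes, not a recommended setting).
if __name__ == "__main__":
    tokenizer = FsqAttentionTokenizer(
        embed_dim=64,
        data_dim=8,
        data_horizon=16,
        num_tokens=4,
        num_layers=2,
        target_codebook_size=2**10,
    )
    actions = jnp.zeros((2, 16, 8))  # (batch, horizon, action_dim)
    params = tokenizer.init(jax.random.PRNGKey(0), actions, train=False)
    mse, metrics = tokenizer.apply(params, actions, train=False)
    print({name: float(value) for name, value in metrics.items()})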