import mlx.core as mx
import mlx.nn as nn
import mlx.utils as utils
import mlx.optimizers as optim


import os
import math
import time
import umap
import numpy as np
import torch
import pandas as pd
import plotly.express as px
import matplotlib.pyplot as plt


from functools import partial
from datasets import load_dataset
from IPython.display import clear_output
from collections import Counter, OrderedDict
from typing import Callable, List, Optional, Tuple, Union

mx.enable_compile()
|
|
@mx.custom_function
@mx.checkpoint
def _rwkv_v7_state_update_loop_metal(r, w, k, v, a, b, state):
    """RWKV-7 state-update recurrence as a fused custom Metal kernel.

    r, w, k, v, a, b are (B, T, H, N); state is (B, H, N, N).
    Returns the per-token outputs (B, T, C) and the final state (B, H, N, N).
    """
    B, T, H, N = r.shape
    C = H * N

    # N must be a multiple of 4 (float4 vectorization) and small enough for one
    # state row to live in thread-local registers.
    assert (N % 4 == 0 and N < 256)

    source = (
        f"""
        thread int T = w_shape[1];
        thread int H = w_shape[2];
        thread int N = w_shape[3];

        // One thread per (batch, head, state-row): x = batch, y = head, z = row.
        thread int loc_mat = thread_position_in_grid.x*H*N*N + thread_position_in_grid.y*N*N + thread_position_in_grid.z*N;
        thread int loc_vec = thread_position_in_grid.x*T*H*N + thread_position_in_grid.y*N;
        thread int loc_tmp = thread_position_in_grid.x*T*H*N + thread_position_in_grid.y*N + thread_position_in_grid.z;
        thread int loc_tmp_ = thread_position_in_grid.x*H*N + thread_position_in_grid.y*N + thread_position_in_grid.z;

        thread float4 w_vec;
        thread float4 b_vec;
        thread float4 k_vec;
        thread float4 r_vec;
        thread float4 a_vec;
        thread float4 out_prev[{N // 4}];
        thread float tmp_;
        thread float _tmp_ = 0;
        thread float out_acc;
        """ + """
        // Load the initial state row and take its dot product with a at t = 0.
        MLX_MTL_PRAGMA_UNROLL
        for (int i = 0; i < N; i+= 4){
            a_vec = float4(a[loc_vec + i], a[loc_vec + i+1], a[loc_vec + i+2], a[loc_vec + i+3]);
            out_prev[i/4] = float4(state[loc_mat + i], state[loc_mat + i+1], state[loc_mat + i+2], state[loc_mat + i+3]);
            _tmp_ += metal::dot(out_prev[i/4], a_vec);
        }

        MLX_MTL_PRAGMA_UNROLL
        for (int t = 0; t < T; t++){
            tmp_ = _tmp_;
            _tmp_ = 0;
            out_acc = 0;

            MLX_MTL_PRAGMA_UNROLL
            for (int i = 0; i < N; i+= 4){
                w_vec = float4(w[loc_vec + i], w[loc_vec + i+1], w[loc_vec + i+2], w[loc_vec + i+3]);
                b_vec = float4(b[loc_vec + i], b[loc_vec + i+1], b[loc_vec + i+2], b[loc_vec + i+3]);
                k_vec = float4(k[loc_vec + i], k[loc_vec + i+1], k[loc_vec + i+2], k[loc_vec + i+3]);
                r_vec = float4(r[loc_vec + i], r[loc_vec + i+1], r[loc_vec + i+2], r[loc_vec + i+3]);
                a_vec = float4(a[loc_vec + H*N + i], a[loc_vec + H*N + i+1], a[loc_vec + H*N + i+2], a[loc_vec + H*N + i+3]);

                // state_row = state_row * exp(w_t) + (prev_state @ a_t)[row] * b_t + v_t[row] * k_t
                out_prev[i/4] = out_prev[i/4] * metal::exp(w_vec) + tmp_ * b_vec + v[loc_tmp] * k_vec;

                out_acc += metal::dot(out_prev[i/4], r_vec);
                _tmp_ += metal::dot(out_prev[i/4], a_vec);
            }

            out[loc_tmp] = out_acc;

            loc_vec += H * N;
            loc_tmp += H * N;
        }

        // Write the final state row back out.
        MLX_MTL_PRAGMA_UNROLL
        for (int i = 0; i < N; i+= 4){
            state_out[loc_mat + i] = out_prev[i/4].x;
            state_out[loc_mat + i+1] = out_prev[i/4].y;
            state_out[loc_mat + i+2] = out_prev[i/4].z;
            state_out[loc_mat + i+3] = out_prev[i/4].w;
        }
        """
    )

    header = """
    #define MLX_MTL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
    """

    kernel = mx.fast.metal_kernel(
        name="rwkv_v7_state_update_loop_metal",
        input_names=["r", "w", "k", "v", "a", "b", "state"],
        output_names=["out", "state_out"],
        source=source,
        header=header,
        atomic_outputs=False,
        ensure_row_contiguous=True,
    )
    outputs = kernel(
        inputs=[r, w, k, v, a, b, state],
        output_shapes=[(B, T, C), (B, H, N, N)],
        output_dtypes=[mx.float32, mx.float32],
        grid=(B, H, N),
        threadgroup=(1, 1, 1),
        init_value=0,
    )
    return (outputs[0], outputs[1])
|
|
def rwkv_v7_state_update_loop(*args):
    return _rwkv_v7_state_update_loop_metal(*args)
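

# Hedged sanity-check sketch (not part of the original code): a naive per-timestep
# reference for the recurrence the Metal kernel implements, assuming the same
# (B, T, H, N) inputs and a (B, H, N, N) state. Useful for spot-checking the fused
# kernel with mx.allclose on small shapes; the function name here is ours.
def _rwkv_v7_state_update_loop_ref(r, w, k, v, a, b, state):
    B, T, H, N = r.shape
    S = state
    outs = []
    for t in range(T):
        w_t = mx.exp(w[:, t])                                  # (B, H, N) per-column decay
        sa = S @ a[:, t, :, :, None]                           # (B, H, N, 1): dot of each state row with a_t
        S = (S * w_t[:, :, None, :]                            # decay
             + sa @ b[:, t, :, None, :]                        # rank-1 update driven by a_t / b_t
             + v[:, t, :, :, None] @ k[:, t, :, None, :])      # rank-1 key/value write
        outs.append((S @ r[:, t, :, :, None])[..., 0])         # (B, H, N): readout with r_t
    out = mx.stack(outs, axis=1).reshape(B, T, H * N)
    return out, S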
|
|
@mx.custom_function
@mx.checkpoint
def relu_2(x):
    """Squared ReLU, max(x, 0) ** 2, as a float4-vectorized custom Metal kernel."""
    n_outputs = int(np.prod(x.shape))
    source = """
        const uint stride = 4; // Process 4 elements at a time
        uint gid = thread_position_in_grid.x;
        uint baseIndex = gid * stride;
        uint x_length = x_shape[0];

        // Main SIMD loop for processing groups of 4
        if (baseIndex + stride <= x_length) {
            float4 input = float4(x[baseIndex], x[baseIndex + 1], x[baseIndex + 2], x[baseIndex + 3]);
            float4 relu = max(input, float4(0.0));
            float4 output = relu * relu;

            a[baseIndex] = T(output[0]);
            a[baseIndex + 1] = T(output[1]);
            a[baseIndex + 2] = T(output[2]);
            a[baseIndex + 3] = T(output[3]);
        }

        // Handle leftover elements
        uint leftoverIndex = x_length - (x_length % stride);
        if (gid == 0 && leftoverIndex < x_length) { // Only the first thread processes leftovers
            for (uint i = leftoverIndex; i < x_length; ++i) {
                T value = T(max(x[i], T(0)));
                a[i] = value * value;
            }
        }
    """

    kernel = mx.fast.metal_kernel(
        name="relu2_metal",
        input_names=["x"],
        output_names=["a"],
        source=source,
        atomic_outputs=False,
        ensure_row_contiguous=True,
    )
    outputs = kernel(
        inputs=[x.flatten()],
        template=[("T", x.dtype)],
        output_shapes=[x.shape],
        output_dtypes=[x.dtype],
        grid=(max(n_outputs // 4, 1), 1, 1),  # at least one thread so leftovers are handled
        threadgroup=(32, 1, 1),
    )
    return outputs[0]
|
|
@relu_2.vjp
def relu_2_vjp(primals, cotangents, _outputs_):
    """Custom VJP for relu_2: d/dx relu(x)**2 = 2 * relu(x), so
    x_grad = max(2 * x, 0) * cotangent."""
    x = primals
    cotangent = cotangents
    n_outputs = int(np.prod(x.shape))

    source = """
        const uint stride = 4; // Process 4 elements at a time
        uint gid = thread_position_in_grid.x;
        uint baseIndex = gid * stride;
        uint x_length = x_shape[0];

        // Main SIMD loop: process groups of 4 elements
        if (baseIndex + stride <= x_length) {
            // Load 4 elements into a vector
            float4 input = float4(x[baseIndex], x[baseIndex + 1], x[baseIndex + 2], x[baseIndex + 3]);
            float4 cotan = float4(c[baseIndex], c[baseIndex + 1], c[baseIndex + 2], c[baseIndex + 3]);

            // Compute 2 * x and apply ReLU using the max intrinsic
            float4 output = max(2 * input, float4(0.0)) * cotan;

            // Write results back to x_grad
            x_grad[baseIndex] = T(output[0]);
            x_grad[baseIndex + 1] = T(output[1]);
            x_grad[baseIndex + 2] = T(output[2]);
            x_grad[baseIndex + 3] = T(output[3]);
        }

        // Handle leftover elements (if any)
        uint leftoverIndex = x_length - (x_length % stride);
        if (gid == 0 && leftoverIndex < x_length) { // Only the first thread handles leftovers
            for (uint i = leftoverIndex; i < x_length; ++i) {
                x_grad[i] = T(max(2 * x[i], T(0)) * c[i]);
            }
        }
    """

    kernel = mx.fast.metal_kernel(
        name="relu2_metal_vjp",
        input_names=["x", "c"],
        output_names=["x_grad"],
        source=source,
        atomic_outputs=False,
        ensure_row_contiguous=True,
    )
    outputs = kernel(
        inputs=[x.flatten(), cotangent.flatten()],
        template=[("T", x.dtype)],
        output_shapes=[x.shape],
        output_dtypes=[x.dtype],
        grid=(max(n_outputs // 4, 1), 1, 1),  # at least one thread so leftovers are handled
        threadgroup=(32, 1, 1),
    )
    return outputs[0]
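

# Hedged sanity-check sketch (the function name and sizes are illustrative, not
# part of the original code): relu_2 and its custom VJP should agree with the
# eager composition nn.relu(x) ** 2 and its gradient.
def _check_relu_2():
    x = mx.random.normal((1024,))
    eager = lambda y: (nn.relu(y) ** 2).sum()
    fused = lambda y: relu_2(y).sum()
    ok_fwd = mx.allclose(relu_2(x), nn.relu(x) ** 2)
    ok_bwd = mx.allclose(mx.grad(eager)(x), mx.grad(fused)(x))
    return ok_fwd, ok_bwd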
|
|
def RWKV7_OP_V(q, k, v, d, kk, ka, state=None):
    """Visualization variant of RWKV7_OP: runs the fused kernel for the output and
    final state, then replays the recurrence with dense ops to build a
    (B, T, T, H) attention-style map."""
    dtype = q.dtype

    B, T, C = q.shape
    B, H, N, N = state.shape
    s = state

    out, state = RWKV7_OP(q, k, v, d, kk, ka, state=state)

    q,k,v,d,kk,ka = q.reshape(B,T,H,N), k.reshape(B,T,H,N), v.reshape(B,T,H,N), d.reshape(B,T,H,N), kk.reshape(B,T,H,N), ka.reshape(B,T,H,N)
    q,k,v,d,kk,ka = q.astype(mx.float32),k.astype(mx.float32),v.astype(mx.float32),d.astype(mx.float32),kk.astype(mx.float32),ka.astype(mx.float32)

    if state is None:
        state = mx.zeros([B, H, N, N])

    d_exp = mx.exp(d)
    attn = mx.zeros([B, T, T, H, N])

    for t in range(T):
        kk_ = kk[:, t, :, :, None]
        ka_ = ka[:, t, :, None, :]

        d_ = d_exp[:, t, :, :]

        # Per-step transition matrix: diagonal decay plus the rank-1 kk @ ka update.
        M = mx.vmap(mx.vmap(mx.diag))(d_) + kk_ @ ka_

        s = s @ M
        # Propagate earlier keys through the transition, so row t of the map
        # reflects how much of each past key survives in the state at time t.
        k[:, :t, :, :] = mx.einsum('bthm,bhmn->bthn', k[:, :t, :, :], M)

        attn[:, t, :t+1, :, :] = k[:, :t+1, :, :]

    attn = mx.einsum('btThm,bthm->btTh', attn, q)

    return out.astype(dtype), state, attn
|
|
def RWKV7_OP(q, k, v, d, kk, ka, state=None):
    """RWKV-7 time-mix recurrence over a (B, T, C) sequence via the fused Metal
    kernel; state is the (B, H, N, N) recurrent state carried between calls."""
    dtype = q.dtype

    B, T, C = q.shape
    B, H, N, N = state.shape

    q,k,v,d,kk,ka = q.reshape(B,T,H,N), k.reshape(B,T,H,N), v.reshape(B,T,H,N), d.reshape(B,T,H,N), kk.reshape(B,T,H,N), ka.reshape(B,T,H,N)
    q,k,v,d,kk,ka = q.astype(mx.float32),k.astype(mx.float32),v.astype(mx.float32),d.astype(mx.float32),kk.astype(mx.float32),ka.astype(mx.float32)

    if state is None:
        state = mx.zeros([B, H, N, N])

    out, state = rwkv_v7_state_update_loop(q, d, k, v, kk, ka, state)

    return out.astype(dtype), state
|
|
def ATTN_OP_V(q, k, v):
    """Eager causal softmax attention that also returns the (B, T, T, H) attention
    map for visualization. q must be (B, T, H, N); k and v may be (B, T, H*N) and
    are reshaped to match."""
    dtype = q.dtype

    B, T, H, N = q.shape

    q,k,v = q.reshape(B,T,H,N), k.reshape(B,T,H,N), v.reshape(B,T,H,N)
    q,k,v = q.astype(mx.float32),k.astype(mx.float32),v.astype(mx.float32)

    # (B, H, T, N) @ (B, H, N, T) -> (B, H, T, T) scores.
    attn = mx.swapaxes(q, -3, -2) @ mx.moveaxis(k, -3, -1)

    # Causal mask: log(1) = 0 on the lower triangle, log(0) = -inf above it.
    mask = mx.tril(mx.ones_like(attn))
    mask = mx.log(mask)

    attn = mx.softmax((attn + mask) / (N ** 0.5), axis=-1)

    out = attn @ mx.swapaxes(v, -3, -2)

    out = mx.swapaxes(out, -3, -2).reshape(B, T, H*N)

    # (B, H, T, T) -> (B, T, T, H) to match the map layout used elsewhere.
    attn = mx.moveaxis(attn, -3, -1)

    return out.astype(dtype), attn
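

# Hedged sanity-check sketch (the function name and shapes are illustrative):
# ATTN_OP_V should match mx.fast.scaled_dot_product_attention with a causal mask,
# since dividing the -inf mask entries by sqrt(N) leaves them at -inf.
def _check_attn_op_v(B=2, T=8, H=2, N=16):
    q = mx.random.normal((B, T, H, N))
    k = mx.random.normal((B, T, H, N))
    v = mx.random.normal((B, T, H, N))
    out, attn_map = ATTN_OP_V(q, k, v)
    ref = mx.fast.scaled_dot_product_attention(
        mx.swapaxes(q, 1, 2), mx.swapaxes(k, 1, 2), mx.swapaxes(v, 1, 2),
        scale=N ** -0.5, mask="causal",
    )
    ref = mx.swapaxes(ref, 1, 2).reshape(B, T, H * N)
    return mx.allclose(out, ref, atol=1e-5), attn_map.shape  # map: (B, T, T, H)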
|
|
def norm(x):
    return mx.fast.rms_norm(x, weight=None, eps=1e-8)
|
|
class RWKV_Tmix_x070(nn.Module):
    """RWKV-7 time-mix block: token-shifted projections feed the state-update
    recurrence in RWKV7_OP. Returns the mixed output, the last token (for the
    next chunk's shift), the updated state, and optionally an attention map."""

    def __init__(
        self,
        layer_id: int,
        layers: int,
        num_heads: int,
        input_dims: int,
        hidden_dims: Union[int, None] = None,
    ):
        super().__init__()

        self.layer_id = layer_id
        self.head_dim = input_dims // num_heads
        self.num_heads = num_heads
        assert input_dims % self.num_heads == 0

        H = self.num_heads
        N = self.head_dim
        C = input_dims

        self.x_k = mx.ones([C])

        self.q_proj = nn.Linear(C, C, bias=False)
        self.k_proj = nn.Linear(C, C, bias=False)
        self.v_proj = nn.Linear(C, C, bias=False)
        self.o_proj = nn.Linear(C, C, bias=False)

        self.d_proj = nn.Linear(C, C, bias=True)
        self.a_proj = nn.Linear(C, H, bias=False)  # one gating scalar per head

    def __call__(self, x, x_prev=None, state_prev=None, attn_map=False):
        B, T, C = x.shape
        H = self.num_heads
        N = C // H

        def forward1(x, x_prev):
            def normalize_last(x):
                return mx.fast.rms_norm(x, weight=None, eps=1e-8) * x.shape[-1] ** -0.5

            x = norm(x)

            # Token shift: difference between the previous token and the current one.
            xx = mx.concatenate([x_prev, x], axis=1)[:, :-1] - x

            xk = x + xx * self.x_k

            q = self.q_proj(x)
            k = self.k_proj(xk)
            v = self.v_proj(x)
            d = self.d_proj(x)
            a = self.a_proj(x)

            d = -mx.sigmoid(d)                   # log-decay in (-1, 0)
            a = mx.sigmoid(a).reshape(B,T,H,-1)  # per-head gate in (0, 1)

            q = normalize_last(q.reshape(B, T, H, N)).reshape(B, T, C)
            k = normalize_last(k.reshape(B, T, H, N))

            k_, kk, ka = k*a, -k, k*a

            x_prev = x[:, -1]

            return q, k_, v, d, kk, ka, x_prev.reshape(B, 1, C)

        q, k, v, d, kk, ka, x_prev = forward1(x, x_prev)

        if not attn_map:
            x, state_prev = RWKV7_OP(q, k, v, d, kk, ka, state=state_prev)
            attn = None
        else:
            x, state_prev, attn = RWKV7_OP_V(q, k, v, d, kk, ka, state=state_prev)

        def forward2(x):
            x = self.o_proj(x)
            return x

        x = forward2(x)

        return x, x_prev, state_prev, attn
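

# Hedged usage sketch (sizes and the function name are illustrative): running the
# time-mix block on random input with zero-initialized shift and state buffers.
# N = C // H has to satisfy the kernel's constraint (a multiple of 4, below 256).
def _tmix_smoke_test(B=2, T=16, H=2, C=64):
    N = C // H
    blk = RWKV_Tmix_x070(layer_id=0, layers=4, num_heads=H, input_dims=C)
    x = mx.random.normal((B, T, C))
    x_prev = mx.zeros((B, 1, C))
    state = mx.zeros((B, H, N, N))
    y, x_prev, state, attn = blk(x, x_prev, state, attn_map=False)
    return y.shape, state.shape  # expected: (B, T, C) and (B, H, N, N)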
|
|
class SelfAttn(nn.Module):
    """Causal softmax self-attention block using the same token-shifted key as the
    RWKV block, so it can be dropped into the same layer interface (state_prev is
    passed through unchanged)."""

    def __init__(
        self,
        layer_id: int,
        layers: int,
        num_heads: int,
        input_dims: int,
        hidden_dims: Union[int, None] = None,
    ):
        super().__init__()

        self.layer_id = layer_id
        self.head_dim = input_dims // num_heads
        self.num_heads = num_heads
        assert input_dims % self.num_heads == 0

        H = self.num_heads
        N = self.head_dim
        C = input_dims

        self.x_k = mx.ones([C])

        self.q_proj = nn.Linear(C, C, bias=False)
        self.k_proj = nn.Linear(C, C, bias=False)
        self.v_proj = nn.Linear(C, C, bias=False)
        self.o_proj = nn.Linear(C, C, bias=False)

    def __call__(self, x, x_prev=None, state_prev=None, attn_map=False):
        B, T, C = x.shape
        H = self.num_heads
        N = C // H

        def forward1(x, x_prev):
            def normalize_last(x):
                return mx.fast.rms_norm(x, weight=None, eps=1e-8) * x.shape[-1] ** -0.5

            x = norm(x)

            # Token shift, as in the RWKV time-mix block.
            xx = mx.concatenate([x_prev, x], axis=1)[:, :-1] - x

            xk = x + xx * self.x_k

            q = self.q_proj(x).reshape(B, T, H, N)
            k = self.k_proj(xk)
            v = self.v_proj(x)

            x_prev = x[:, -1]

            return q, k, v, x_prev.reshape(B, 1, C)

        q, k, v, x_prev = forward1(x, x_prev)

        if not attn_map:
            q = mx.swapaxes(q.reshape(B, T, H, N), -3, -2)
            k = mx.swapaxes(k.reshape(B, T, H, N), -3, -2)
            v = mx.swapaxes(v.reshape(B, T, H, N), -3, -2)
            x = mx.fast.scaled_dot_product_attention(q, k, v, scale=(N ** -0.5), mask='causal')
            x = mx.swapaxes(x, -3, -2)
            attn = None
        else:
            x, attn = ATTN_OP_V(q, k, v)

        def forward2(x):
            x = x.reshape(B, T, C)
            x = self.o_proj(x)
            return x

        x = forward2(x)

        return x, x_prev, state_prev, attn
|
|
class RWKV_CMix_x070(nn.Module):
    """Channel-mix (feed-forward) block: token-shifted input, squared-ReLU hidden
    activation, and a projection back to the model width."""

    def __init__(
        self,
        layer_id: int,
        layers: int,
        num_heads: int,
        input_dims: int,
        hidden_dims: Union[int, None] = None,
    ):
        super().__init__()

        self.layer_id = layer_id

        C = input_dims
        hidden_dims = hidden_dims or 4 * C

        self.x_k = mx.ones([C])

        self.k_proj = nn.Linear(C, hidden_dims, bias=False)
        self.v_proj = nn.Linear(hidden_dims, C, bias=False)

    def __call__(self, x, x_prev=None, ffn_act=False):
        B, T, C = x.shape

        def forward1(x, x_prev, ffn_act):
            x = norm(x)

            xx = mx.concatenate([x_prev, x], axis=1)[:, :-1] - x

            k = x + xx * self.x_k
            k = nn.relu(self.k_proj(k)) ** 2

            x_prev = x[:, -1]

            return self.v_proj(k), x_prev.reshape(B, 1, C), k if ffn_act else None

        output, x_prev, ffn = forward1(x, x_prev, ffn_act)

        return output, x_prev, ffn
|
|
class RWKV_v7_Block(nn.Module):
    """Residual block pairing RWKV-7 time-mix with the channel-mix feed-forward."""

    def __init__(
        self,
        layer_id: int,
        layers: int,
        num_heads: int,
        input_dims: int,
        hidden_dims: Union[int, None] = None,
    ):
        super().__init__()
        self.layer_id = layer_id

        self.att = RWKV_Tmix_x070(layer_id, layers, num_heads, input_dims, hidden_dims)
        self.ffn = RWKV_CMix_x070(layer_id, layers, num_heads, input_dims, hidden_dims)

    def __call__(self, x, x_prev_0=None, state_prev=None, x_prev_1=None, attn_map=False, ffn_act=False):
        xx, x_prev_0, state_prev, attn = self.att(x, x_prev_0, state_prev, attn_map)
        x = x + xx

        xx, x_prev_1, ffn = self.ffn(x, x_prev_1, ffn_act)
        x = x + xx

        return x, x_prev_0, state_prev, x_prev_1, attn, ffn
|
|
class Softmax_Block(nn.Module):
    """Residual block pairing softmax self-attention with the channel-mix feed-forward."""

    def __init__(
        self,
        layer_id: int,
        layers: int,
        num_heads: int,
        input_dims: int,
        hidden_dims: Union[int, None] = None,
    ):
        super().__init__()
        self.layer_id = layer_id

        self.att = SelfAttn(layer_id, layers, num_heads, input_dims, hidden_dims)
        self.ffn = RWKV_CMix_x070(layer_id, layers, num_heads, input_dims, hidden_dims)

    def __call__(self, x, x_prev_0=None, state_prev=None, x_prev_1=None, attn_map=False, ffn_act=False):
        xx, x_prev_0, state_prev, attn = self.att(x, x_prev_0, state_prev, attn_map)
        x = x + xx

        xx, x_prev_1, ffn = self.ffn(x, x_prev_1, ffn_act)
        x = x + xx

        return x, x_prev_0, state_prev, x_prev_1, attn, ffn
|
|
class RWKV_v7(nn.Module):
    """Hybrid stack alternating RWKV-7 blocks (even layers) and softmax-attention
    blocks (odd layers), with a tied embedding used as the output head."""

    def __init__(
        self,
        layers: int,
        num_heads: int,
        vocab_size: int,
        input_dims: int,
        hidden_dims: Union[int, None] = None,
        mtp_heads: int = 0,
        dtype = None
    ):
        super().__init__()

        # Embedding refers to the custom embedding module defined elsewhere in this
        # notebook; it accepts a dtype argument at call time and exposes a tied
        # `weight` used below for the output head.
        self.emb = Embedding(vocab_size, input_dims)

        self.blocks = [
            RWKV_v7_Block(i, layers, num_heads, input_dims, hidden_dims)
            if i % 2 == 0
            else Softmax_Block(i, layers, num_heads, input_dims, hidden_dims)
            for i in range(layers)
        ]

        self.dtype = dtype

    def __call__(self, idx, x_prev_0s=None, state_prevs=None, x_prev_1s=None, attn_map=False, ffn_act=False):

        x = norm(self.emb(idx, dtype=self.dtype))

        attns = []
        ffns = []
        x_prev_0s_ = mx.zeros_like(x_prev_0s)
        state_prevs_ = mx.zeros_like(state_prevs)
        x_prev_1s_ = mx.zeros_like(x_prev_1s)

        layers = len(self.blocks)

        for i, block in enumerate(self.blocks):
            x, x_prev_0, state_prev, x_prev_1, attn, ffn = block(x, x_prev_0s[i], state_prevs[i], x_prev_1s[i], attn_map, ffn_act)
            x_prev_0s_[i] = x_prev_0
            state_prevs_[i] = state_prev
            x_prev_1s_[i] = x_prev_1
            attns.append(attn)
            ffns.append(ffn)

        # Tied embedding as the output head.
        logits = norm(x) @ self.emb.weight.T

        if not attn_map:
            return logits, (x_prev_0s_, state_prevs_, x_prev_1s_)
        else:
            return logits, (x_prev_0s_, state_prevs_, x_prev_1s_), mx.stack(attns, axis=0), mx.stack(ffns, axis=0)
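

# Hedged end-to-end usage sketch (function name and sizes are illustrative). It
# assumes the Embedding module referenced above is available and behaves like an
# embedding table with a `weight` attribute. Per-layer shift/state buffers are
# stacked along a leading layer axis, matching how __call__ indexes them.
def _model_smoke_test(B=2, T=16, layers=4, H=2, C=64, vocab=256):
    N = C // H
    model = RWKV_v7(layers=layers, num_heads=H, vocab_size=vocab, input_dims=C)
    idx = mx.random.randint(0, vocab, (B, T))
    x_prev_0s = mx.zeros((layers, B, 1, C))
    state_prevs = mx.zeros((layers, B, H, N, N))
    x_prev_1s = mx.zeros((layers, B, 1, C))
    logits, (x_prev_0s, state_prevs, x_prev_1s) = model(idx, x_prev_0s, state_prevs, x_prev_1s)
    return logits.shape  # expected: (B, T, vocab)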