# Provenance: uploaded as "model.py" by xTimeCrystal on Hugging Face (commit 0100f82, verified).
import mlx.core as mx
import mlx.nn as nn
import mlx.utils as utils
import mlx.optimizers as optim
import os
import math
import time
import umap
import numpy as np
import torch
import pandas as pd
import plotly.express as px
import matplotlib.pyplot as plt
from functools import partial
from datasets import load_dataset
from IPython.display import clear_output
from collections import Counter, OrderedDict
from typing import Callable, List, Optional, Tuple, Union
# Globally enable MLX graph compilation so compiled/fused kernels are used
# throughout the model.
mx.enable_compile()
@mx.custom_function
@mx.checkpoint
def _rwkv_v7_state_update_loop_metal(r, w, k, v, a, b, state):
    """RWKV-v7 sequential state-update scan as a custom Metal kernel.

    Inputs r, w, k, v, a, b are unpacked as (B, T, H, N); `state` is
    presumably (B, H, N, N), matching the `state_out` output shape below —
    TODO confirm against callers. Returns (out, state_out) with shapes
    (B, T, H*N) and (B, H, N, N), both float32.

    The launch grid is (B, H, N): one GPU thread per (batch, head, state row),
    and each thread loops over the T timesteps inside the kernel source.
    """
    B, T, H, N = r.shape
    C = H * N
    # The kernel vectorizes in float4 chunks and keeps N/4 thread-local
    # float4s per row, so N must be a multiple of 4 and bounded.
    assert (N % 4 == 0 and N < 256)
    source = (
        f"""
thread int T = w_shape[1];
thread int H = w_shape[2];
thread int N = w_shape[3];
thread int loc_mat = thread_position_in_grid.x*H*N*N + thread_position_in_grid.y*N*N + thread_position_in_grid.z*N;
thread int loc_vec = thread_position_in_grid.x*T*H*N + thread_position_in_grid.y*N;
thread int loc_tmp = thread_position_in_grid.x*T*H*N + thread_position_in_grid.y*N + thread_position_in_grid.z;
thread int loc_tmp_ = thread_position_in_grid.x*H*N + thread_position_in_grid.y*N + thread_position_in_grid.z;
thread float4 w_vec;
thread float4 b_vec;
thread float4 k_vec;
thread float4 r_vec;
thread float4 a_vec;
thread float4 out_prev[{N // 4}];
thread float tmp_;
thread float _tmp_;
thread float out_acc;
""" + """
MLX_MTL_PRAGMA_UNROLL
for (int i = 0; i < N; i+= 4){
a_vec = float4(a[loc_vec + i], a[loc_vec + i+1], a[loc_vec + i+2], a[loc_vec + i+3]);
out_prev[i/4] = float4(state[loc_mat + i], state[loc_mat + i+1], state[loc_mat + i+2], state[loc_mat + i+3]);
_tmp_ += metal::dot(out_prev[i/4], a_vec);
}
MLX_MTL_PRAGMA_UNROLL
for (int t = 0; t < T; t++){
tmp_ = _tmp_;
_tmp_ = 0;
out_acc = 0;
MLX_MTL_PRAGMA_UNROLL
for (int i = 0; i < N; i+= 4){
w_vec = float4(w[loc_vec + i], w[loc_vec + i+1], w[loc_vec + i+2], w[loc_vec + i+3]);
b_vec = float4(b[loc_vec + i], b[loc_vec + i+1], b[loc_vec + i+2], b[loc_vec + i+3]);
k_vec = float4(k[loc_vec + i], k[loc_vec + i+1], k[loc_vec + i+2], k[loc_vec + i+3]);
r_vec = float4(r[loc_vec + i], r[loc_vec + i+1], r[loc_vec + i+2], r[loc_vec + i+3]);
a_vec = float4(a[loc_vec + H*N + i], a[loc_vec + H*N + i+1], a[loc_vec + H*N + i+2], a[loc_vec + H*N + i+3]);
out_prev[i/4] = out_prev[i/4] * metal::exp(w_vec) + tmp_ * b_vec + v[loc_tmp] * k_vec;
out_acc += metal::dot(out_prev[i/4], r_vec);
_tmp_ += metal::dot(out_prev[i/4], a_vec);
}
out[loc_tmp] = out_acc;
loc_vec += H * N;
loc_tmp += H * N;
}
MLX_MTL_PRAGMA_UNROLL
for (int i = 0; i < N; i+= 4){
state_out[loc_mat + i] = out_prev[i/4].x;
state_out[loc_mat + i+1] = out_prev[i/4].y;
state_out[loc_mat + i+2] = out_prev[i/4].z;
state_out[loc_mat + i+3] = out_prev[i/4].w;
}
"""
    )
    # Header defines the unroll pragma used throughout the kernel source.
    header = """
#define MLX_MTL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
"""
    kernel = mx.fast.metal_kernel(
        name="rwkv_v7_state_update_loop_metal",
        input_names=["r", "w", "k", "v", "a", "b", "state"],
        output_names=["out", "state_out"],
        source=source,
        header=header,
        atomic_outputs=False,
        # The kernel indexes with raw flat offsets, so inputs must be
        # row-contiguous.
        ensure_row_contiguous=True,
    )
    outputs = kernel(
        inputs=[r, w, k, v, a, b, state],
        output_shapes=[(B, T, C), (B, H, N, N)],
        output_dtypes=[mx.float32, mx.float32],
        grid=(B, H, N),  # one thread per (batch, head, state row)
        threadgroup=(1, 1, 1),
        init_value=0,  # zero-fill outputs before the kernel writes them
    )
    return (outputs[0], outputs[1])
def rwkv_v7_state_update_loop(*args):
    """Entry point for the RWKV-v7 state-update scan.

    Forwards all arguments — (r, w, k, v, a, b, state) — to the Metal kernel
    implementation; presumably kept as an indirection point so a non-Metal
    fallback could be swapped in without touching call sites.
    """
    return _rwkv_v7_state_update_loop_metal(*args)
@mx.custom_function
@mx.checkpoint
def relu_2(x):
    """Squared ReLU, max(x, 0)**2, computed elementwise by a Metal kernel.

    The input is flattened and processed 4 elements per GPU thread via
    float4; a scalar tail loop executed by thread 0 covers any remainder.
    Returns an array with the same shape and dtype as `x`.
    """
    # Total element count; np.prod yields a NumPy scalar, so cast to int.
    n_outputs = int(np.prod(x.shape))
    source = """
const uint stride = 4; // Process 4 elements at a time
uint gid = thread_position_in_grid.x;
uint baseIndex = gid * stride;
uint x_length = x_shape[0];
// Main SIMD loop for processing groups of 4
if (baseIndex + stride <= x_length) {
float4 input = float4(x[baseIndex], x[baseIndex + 1], x[baseIndex + 2], x[baseIndex + 3]);
float4 relu = max(input, float4(0.0));
float4 output = relu * relu;
a[baseIndex] = T(output[0]);
a[baseIndex + 1] = T(output[1]);
a[baseIndex + 2] = T(output[2]);
a[baseIndex + 3] = T(output[3]);
}
// Handle leftover elements
uint leftoverIndex = x_length - (x_length % stride);
if (gid == 0 && leftoverIndex < x_length) { // Only the first thread processes leftovers
for (uint i = leftoverIndex; i < x_length; ++i) {
T value = T(max(x[i], T(0)));
a[i] = value * value;
}
}
"""
    kernel = mx.fast.metal_kernel(
        name="relu2_metal",
        input_names=["x"],
        output_names=["a"],
        source=source,
        atomic_outputs=False,
        ensure_row_contiguous=True,
    )
    outputs = kernel(
        inputs=[x.flatten()],
        template=[("T", x.dtype)],
        output_shapes=[x.shape],
        output_dtypes=[x.dtype],
        # One thread per float4 chunk. Launch at least one thread so the
        # in-kernel leftover path (gid == 0) still runs when the input has
        # fewer than 4 elements; n_outputs // 4 alone would launch a zero
        # grid and never write the output.
        grid=(max(n_outputs // 4, 1), 1, 1),
        threadgroup=(32, 1, 1),
    )
    return outputs[0]
@relu_2.vjp
def relu_2_vjp(primals, cotangents, _outputs_):
    """VJP for relu_2: grad = 2 * max(x, 0) * cotangent, as a Metal kernel.

    Since d/dx max(x, 0)^2 = 2 * max(x, 0), the kernel computes
    max(2*x, 0) * cotangent elementwise (equivalent because 2 > 0).
    """
    x = primals
    cotangent = cotangents
    # Total element count; np.prod yields a NumPy scalar, so cast to int.
    n_outputs = int(np.prod(x.shape))
    source = """
const uint stride = 4; // Process 4 elements at a time
uint gid = thread_position_in_grid.x;
uint baseIndex = gid * stride;
uint x_length = x_shape[0];
// Main SIMD loop: process groups of 4 elements
if (baseIndex + stride <= x_length) {
// Load 4 elements into a vector
float4 input = float4(x[baseIndex], x[baseIndex + 1], x[baseIndex + 2], x[baseIndex + 3]);
float4 cotan = float4(c[baseIndex], c[baseIndex + 1], c[baseIndex + 2], c[baseIndex + 3]);
// Compute 2 * x[loc] and apply ReLU using max intrinsic
float4 output = max(2 * input, float4(0.0)) * cotan;
// Write results back to x_grad
x_grad[baseIndex] = T(output[0]);
x_grad[baseIndex + 1] = T(output[1]);
x_grad[baseIndex + 2] = T(output[2]);
x_grad[baseIndex + 3] = T(output[3]);
}
// Handle leftover elements (if any)
uint leftoverIndex = x_length - (x_length % stride);
if (gid == 0 && leftoverIndex < x_length) { // Only the first thread handles leftovers
for (uint i = leftoverIndex; i < x_length; ++i) {
x_grad[i] = T(max(2 * x[i], T(0)) * c[i]);
}
}
"""
    kernel = mx.fast.metal_kernel(
        name="relu2_metal_vjp",
        input_names=["x", "c"],
        output_names=["x_grad"],
        source=source,
        atomic_outputs=False,
        ensure_row_contiguous=True,
    )
    outputs = kernel(
        inputs=[x.flatten(), cotangent.flatten()],
        template=[("T", x.dtype)],
        output_shapes=[x.shape],
        output_dtypes=[x.dtype],
        # One thread per float4 chunk. Launch at least one thread so the
        # leftover path (gid == 0) still runs when there are fewer than 4
        # elements; n_outputs // 4 alone would launch a zero grid.
        grid=(max(n_outputs // 4, 1), 1, 1),
        threadgroup=(32, 1, 1),
    )
    return outputs[0]
def RWKV7_OP_V(q, k, v, d, kk, ka, state=None):
    """Like RWKV7_OP, but additionally unrolls the recurrence in Python to
    materialize a (B, T, T, H) attention-style map — O(T^2) memory, intended
    for visualization/inspection rather than training speed.

    NOTE(review): `state.shape` is unpacked immediately, so despite the
    default the caller must pass a real state array; `state=None` raises.
    NOTE(review): the `if state is None` check below tests the state
    *returned* by RWKV7_OP (the original argument is kept in `s`), so that
    branch appears dead.
    """
    dtype = q.dtype
    B, T, C = q.shape
    B, H, N, N = state.shape
    s = state  # keep the caller-supplied state for the Python-side unroll
    # Fast path computes the actual output and the updated state.
    out, state = RWKV7_OP(q, k, v, d, kk, ka, state=state)
    q,k,v,d,kk,ka = q.reshape(B,T,H,N), k.reshape(B,T,H,N), v.reshape(B,T,H,N), d.reshape(B,T,H,N), kk.reshape(B,T,H,N), ka.reshape(B,T,H,N)
    # Promote to float32 for the matrix recurrence below.
    q,k,v,d,kk,ka = q.astype(mx.float32),k.astype(mx.float32),v.astype(mx.float32),d.astype(mx.float32),kk.astype(mx.float32),ka.astype(mx.float32)
    if state is None:
        state = mx.zeros([B, H, N, N])
    d_exp = mx.exp(d)
    attn = mx.zeros([B, T, T, H, N])
    for t in range(T):
        # Per-timestep transition matrix M = diag(exp(d_t)) + kk_t @ ka_t^T.
        kk_ = kk[:, t, :, :, None]
        ka_ = ka[:, t, :, None, :]
        d_ = d_exp[:, t, :, :]
        M = mx.vmap(mx.vmap(mx.diag))(d_) + kk_ @ ka_
        s = s @ M
        # Propagate all earlier keys through this step's transition, then
        # record the prefix as row t of the attention tensor.
        k[:, :t, :, :] = mx.einsum('bthm,bhmn->bthn', k[:, :t, :, :], M)
        attn[:, t, :t+1, :, :] = k[:, :t+1, :, :]
    # Contract propagated keys with queries: (B, T, T, H) map.
    attn = mx.einsum('btThm,bthm->btTh', attn, q)
    # attn /= out.reshape(B, T, 1, H, N).square().mean(-1).sqrt()
    return out.astype(dtype), state, attn
def RWKV7_OP(q, k, v, d, kk, ka, state=None):
    """Run the RWKV-v7 recurrence over a (B, T, C) sequence.

    The six projections are reshaped to (B, T, H, N) — H and N taken from
    `state` — promoted to float32, and fed to the Metal scan. Returns
    (out, new_state): out is (B, T, C) cast back to q's dtype; new_state is
    (B, H, N, N) float32.

    NOTE(review): `state.shape` is unpacked before the `state is None`
    check, so the None branch is unreachable — passing state=None raises
    AttributeError first.
    """
    dtype = q.dtype
    B, T, C = q.shape
    B, H, N, N = state.shape
    q,k,v,d,kk,ka = q.reshape(B,T,H,N), k.reshape(B,T,H,N), v.reshape(B,T,H,N), d.reshape(B,T,H,N), kk.reshape(B,T,H,N), ka.reshape(B,T,H,N)
    # Promote to float32: the Metal kernel computes in float32.
    q,k,v,d,kk,ka = q.astype(mx.float32),k.astype(mx.float32),v.astype(mx.float32),d.astype(mx.float32),kk.astype(mx.float32),ka.astype(mx.float32)
    if state is None:
        state = mx.zeros([B, H, N, N])
    # Kernel argument order is (r, w, k, v, a, b, state).
    out, state = rwkv_v7_state_update_loop(q, d, k, v, kk, ka, state)
    return out.astype(dtype), state
def ATTN_OP_V(q, k, v):
    """Causal softmax attention that also returns the attention weights.

    `q` arrives as (B, T, H, N); `k` and `v` may arrive flat as (B, T, H*N)
    and are reshaped to match. Returns (out, weights) where out is
    (B, T, H*N) in q's original dtype and weights is the softmaxed score
    tensor with its head axis moved last.
    """
    orig_dtype = q.dtype
    B, T, H, N = q.shape
    # Promote everything to float32 for a numerically stable softmax.
    q = q.reshape(B, T, H, N).astype(mx.float32)
    k = k.reshape(B, T, H, N).astype(mx.float32)
    v = v.reshape(B, T, H, N).astype(mx.float32)
    # Scores per head: (B, H, T, N) @ (B, H, N, T) -> (B, H, T, T).
    scores = mx.swapaxes(q, -3, -2) @ mx.moveaxis(k, -3, -1)
    # log(tril(ones)) is 0 on/below the diagonal and -inf above it, which
    # acts as an additive causal mask (scaling leaves -inf unchanged).
    causal_mask = mx.log(mx.tril(mx.ones_like(scores)))
    weights = mx.softmax((scores + causal_mask) / (N ** 0.5), axis=-1)
    out = weights @ mx.swapaxes(v, -3, -2)
    out = mx.swapaxes(out, -3, -2).reshape(B, T, H * N)
    # Move the head axis last so downstream code can index weights[..., h].
    weights = mx.moveaxis(weights, -3, -1)
    return out.astype(orig_dtype), weights
def norm(x):
    """Weightless RMS normalization over the last axis (eps=1e-8)."""
    return mx.fast.rms_norm(x, weight=None, eps=1e-8)
class RWKV_Tmix_x070(nn.Module):
    """RWKV-v7 time-mixing ("attention") block.

    Projects the (pre-normalized) input to query/key/value, a per-channel
    decay `d`, and a per-head gate `a`, then runs the RWKV-v7 recurrence via
    RWKV7_OP (or RWKV7_OP_V when an attention map is requested).

    NOTE(review): `layers` and `hidden_dims` are accepted but unused here.
    """

    def __init__(
        self,
        layer_id: int,
        layers: int,
        num_heads: int,
        input_dims: int,
        hidden_dims: Union[int, None] = None,
    ):
        super().__init__()
        self.layer_id = layer_id
        self.head_dim = input_dims // num_heads
        self.num_heads = num_heads
        assert input_dims % self.num_heads == 0
        H = self.num_heads
        N = self.head_dim
        C = input_dims
        # Learned per-channel mixing weight for the one-token shift on keys.
        self.x_k = mx.ones([C])
        self.q_proj = nn.Linear(C, C, bias=False)
        self.k_proj = nn.Linear(C, C, bias=False)
        self.v_proj = nn.Linear(C, C, bias=False)
        self.o_proj = nn.Linear(C, C, bias=False)
        # Decay projection; only linear layer here with a bias.
        self.d_proj = nn.Linear(C, C, bias=True)
        # One gate scalar per head.
        self.a_proj = nn.Linear(C, H, bias=False)

    def __call__(self, x, x_prev=None, state_prev=None, attn_map=False):
        """Returns (output, new x_prev, new recurrent state, attn map or None)."""
        B, T, C = x.shape
        H = self.num_heads
        N = C // H
        def forward1(x, x_prev):
            def normalize_last(x):
                # RMS-normalize the last axis, then scale by dim**-0.5 so
                # each head vector has (approximately) unit L2 norm.
                return mx.fast.rms_norm(x, weight=None, eps=1e-8) * x.shape[-1] ** -0.5
            x = norm(x)
            # One-token shift: previous-token minus current activations,
            # mixed into the key input through x_k.
            xx = mx.concatenate([x_prev, x], axis=1)[:, :-1] - x
            xk = x + xx * self.x_k
            q = self.q_proj(x)
            k = self.k_proj(xk)
            v = self.v_proj(x)
            d = self.d_proj(x)
            a = self.a_proj(x)
            # Decay constrained to (-1, 0); the kernel applies exp(d).
            d = -mx.sigmoid(d)
            # Per-head gate in (0, 1), broadcast over the head dim.
            a = mx.sigmoid(a).reshape(B,T,H,-1)
            q = normalize_last(q.reshape(B, T, H, N)).reshape(B, T, C)
            k = normalize_last(k.reshape(B, T, H, N))
            # k_ feeds the kv outer product; (kk, ka) = (-k, k*a) form the
            # rank-1 state-transition pair used inside the recurrence.
            k_, kk, ka = k*a, -k, k*a
            x_prev = x[:, -1]
            return q, k_, v, d, kk, ka, x_prev.reshape(B, 1, C)
        q, k, v, d, kk, ka, x_prev = forward1(x, x_prev)
        if not attn_map:
            x, state_prev = RWKV7_OP(q, k, v, d, kk, ka, state=state_prev)
            attn = None
        else:
            # Slow path additionally materializes an attention map.
            x, state_prev, attn = RWKV7_OP_V(q, k, v, d, kk, ka, state=state_prev)
        def forward2(x):
            x = self.o_proj(x)
            return x
        x = forward2(x)
        return x, x_prev, state_prev, attn
class SelfAttn(nn.Module):
    """Causal multi-head softmax self-attention block with a one-token-shift
    key input; interface-compatible with RWKV_Tmix_x070 (`state_prev` is
    passed through untouched).

    NOTE(review): `layers` and `hidden_dims` are accepted but unused.
    """

    def __init__(
        self,
        layer_id: int,
        layers: int,
        num_heads: int,
        input_dims: int,
        hidden_dims: Union[int, None] = None,
    ):
        super().__init__()
        self.layer_id = layer_id
        self.head_dim = input_dims // num_heads
        self.num_heads = num_heads
        assert input_dims % self.num_heads == 0
        H = self.num_heads
        N = self.head_dim
        C = input_dims
        # Learned per-channel mixing weight for the one-token shift on keys.
        self.x_k = mx.ones([C])
        self.q_proj = nn.Linear(C, C, bias=False)
        self.k_proj = nn.Linear(C, C, bias=False)
        self.v_proj = nn.Linear(C, C, bias=False)
        self.o_proj = nn.Linear(C, C, bias=False)

    def __call__(self, x, x_prev=None, state_prev=None, attn_map=False):
        """Returns (output, new x_prev, state_prev unchanged, attn map or None)."""
        B, T, C = x.shape
        H = self.num_heads
        N = C // H
        def forward1(x, x_prev):
            def normalize_last(x):
                # NOTE(review): defined but never called in this class
                # (unlike RWKV_Tmix_x070); kept for parity.
                return mx.fast.rms_norm(x, weight=None, eps=1e-8) * x.shape[-1] ** -0.5
            x = norm(x)
            # One-token shift mixed into the key input via x_k.
            xx = mx.concatenate([x_prev, x], axis=1)[:, :-1] - x
            xk = x + xx * self.x_k
            q = self.q_proj(x).reshape(B, T, H, N) # Shape detection in ATTN_OP_V
            k = self.k_proj(xk)
            v = self.v_proj(x)
            x_prev = x[:, -1]
            return q, k, v, x_prev.reshape(B, 1, C)
        q, k, v, x_prev = forward1(x, x_prev)
        if not attn_map:
            # Fast path: fused SDPA with a causal mask, heads moved to axis -3.
            q = mx.swapaxes(q.reshape(B, T, H, N), -3, -2)
            k = mx.swapaxes(k.reshape(B, T, H, N), -3, -2)
            v = mx.swapaxes(v.reshape(B, T, H, N), -3, -2)
            x = mx.fast.scaled_dot_product_attention(q, k, v, scale=(N ** -0.5), mask='causal')
            x = mx.swapaxes(x, -3, -2)
            attn = None
        else:
            # Slow path also returns the attention weights for inspection.
            x, attn = ATTN_OP_V(q, k, v)
        def forward2(x):
            x = x.reshape(B, T, C)
            x = self.o_proj(x)
            return x
        x = forward2(x)
        return x, x_prev, state_prev, attn
class RWKV_CMix_x070(nn.Module):
    """RWKV-v7 channel-mixing (feed-forward) block: token-shifted input,
    squared-ReLU hidden activation, linear projection back to C.

    NOTE(review): `layers` and `num_heads` are accepted but unused.
    """

    def __init__(
        self,
        layer_id: int,
        layers: int,
        num_heads: int,
        input_dims: int,
        hidden_dims: Union[int, None] = None,
    ):
        super().__init__()
        self.layer_id = layer_id
        C = input_dims
        # Default FFN expansion factor of 4x when hidden_dims is not given.
        hidden_dims = hidden_dims or 4 * C
        # Learned per-channel mixing weight for the one-token shift.
        self.x_k = mx.ones([C])
        self.k_proj = nn.Linear(C, hidden_dims, bias=False)
        self.v_proj = nn.Linear(hidden_dims, C, bias=False)

    def __call__(self, x, x_prev=None, ffn_act=False):
        """Returns (output, new x_prev, hidden activations if ffn_act else None)."""
        B, T, C = x.shape
        def forward1(x, x_prev, ffn_act):
            x = norm(x)
            # One-token shift mixed into the hidden-layer input via x_k.
            xx = mx.concatenate([x_prev, x], axis=1)[:, :-1] - x
            k = x + xx * self.x_k
            # Squared ReLU (same math as the relu_2 Metal kernel above).
            k = nn.relu(self.k_proj(k)) ** 2
            x_prev = x[:, -1]
            return self.v_proj(k), x_prev.reshape(B, 1, C), k if ffn_act else None
        output, x_prev, ffn = forward1(x, x_prev, ffn_act)
        return output, x_prev, ffn
class RWKV_v7_Block(nn.Module):
    """One RWKV-v7 layer: time-mixing (RWKV attention) followed by
    channel-mixing (FFN), each wrapped in a residual connection."""

    def __init__(
        self,
        layer_id: int,
        layers: int,
        num_heads: int,
        input_dims: int,
        hidden_dims: Union[int, None] = None,
    ):
        super().__init__()
        self.layer_id = layer_id
        self.att = RWKV_Tmix_x070(layer_id, layers, num_heads, input_dims, hidden_dims)
        self.ffn = RWKV_CMix_x070(layer_id, layers, num_heads, input_dims, hidden_dims)

    def __call__(self, x, x_prev_0=None, state_prev=None, x_prev_1=None, attn_map=False, ffn_act=False):
        # Time-mixing sublayer with residual connection.
        att_out, x_prev_0, state_prev, attn = self.att(x, x_prev_0, state_prev, attn_map)
        x = x + att_out
        # Channel-mixing sublayer with residual connection.
        ffn_out, x_prev_1, ffn = self.ffn(x, x_prev_1, ffn_act)
        x = x + ffn_out
        return x, x_prev_0, state_prev, x_prev_1, attn, ffn
class Softmax_Block(nn.Module):
    """One hybrid layer: causal softmax self-attention for the time-mixing
    sublayer, followed by the RWKV channel-mix FFN; both residual."""

    def __init__(
        self,
        layer_id: int,
        layers: int,
        num_heads: int,
        input_dims: int,
        hidden_dims: Union[int, None] = None,
    ):
        super().__init__()
        self.layer_id = layer_id
        self.att = SelfAttn(layer_id, layers, num_heads, input_dims, hidden_dims)
        self.ffn = RWKV_CMix_x070(layer_id, layers, num_heads, input_dims, hidden_dims)

    def __call__(self, x, x_prev_0=None, state_prev=None, x_prev_1=None, attn_map=False, ffn_act=False):
        # Attention sublayer with residual connection.
        att_out, x_prev_0, state_prev, attn = self.att(x, x_prev_0, state_prev, attn_map)
        x = x + att_out
        # Feed-forward sublayer with residual connection.
        ffn_out, x_prev_1, ffn = self.ffn(x, x_prev_1, ffn_act)
        x = x + ffn_out
        return x, x_prev_0, state_prev, x_prev_1, attn, ffn
class RWKV_v7(nn.Module):
    """Hybrid RWKV-v7 language model: even-indexed layers are RWKV
    time-mix blocks, odd-indexed layers are softmax self-attention blocks.
    Output logits are tied to the input embedding (norm(x) @ emb.weight.T).

    NOTE(review): `Embedding` is not defined or imported in this file and is
    called with a `dtype` kwarg, so it must be a project-local class — verify.
    NOTE(review): `mtp_heads` is accepted but unused.
    NOTE(review): the None defaults for x_prev_0s/state_prevs/x_prev_1s are
    unusable — mx.zeros_like(None) fails — so callers must always pass
    pre-allocated stacked states with one leading entry per layer.
    NOTE(review): with attn_map=True but ffn_act=False, `ffns` contains None
    entries and mx.stack(ffns) will raise.
    """

    def __init__(
        self,
        layers: int,
        num_heads: int,
        vocab_size: int,
        input_dims: int,
        hidden_dims: Union[int, None] = None,
        mtp_heads: int = 0,
        dtype = None
    ):
        super().__init__()
        self.emb = Embedding(vocab_size, input_dims)
        # Alternate recurrent (even index) and softmax-attention (odd) layers.
        self.blocks = ([RWKV_v7_Block(i, layers, num_heads, input_dims, hidden_dims) if i%2 == 0 else Softmax_Block(i, layers, num_heads, input_dims, hidden_dims) for i in range(layers)])
        self.dtype = dtype

    def __call__(self, idx, x_prev_0s=None, state_prevs=None, x_prev_1s=None, attn_map=False, ffn_act=False):
        """Returns (logits, new states) — plus stacked attn maps and FFN
        activations when attn_map is True."""
        x = norm(self.emb(idx, dtype=self.dtype))
        attns = []
        ffns = []
        # Fresh buffers for the updated per-layer states (same stacked
        # layout as the inputs).
        x_prev_0s_ = mx.zeros_like(x_prev_0s)
        state_prevs_ = mx.zeros_like(state_prevs)
        x_prev_1s_ = mx.zeros_like(x_prev_1s)
        layers = len(self.blocks)  # NOTE(review): unused local
        for i, block in enumerate(self.blocks):
            x, x_prev_0, state_prev, x_prev_1, attn, ffn = block(x, x_prev_0s[i], state_prevs[i], x_prev_1s[i], attn_map, ffn_act)
            x_prev_0s_[i] = x_prev_0
            state_prevs_[i] = state_prev
            x_prev_1s_[i] = x_prev_1
            attns.append(attn)
            ffns.append(ffn)
            # print(x.square().mean())
        # Tied embeddings: project through the transposed embedding matrix.
        logits = norm(x) @ self.emb.weight.T
        if not attn_map:
            return logits, (x_prev_0s_, state_prevs_, x_prev_1s_)
        else:
            return logits, (x_prev_0s_, state_prevs_, x_prev_1s_), mx.stack(attns, axis=0), mx.stack(ffns, axis=0)