# Source: talkie-box/talkie_mlx.py, uploaded by N8Programs via huggingface_hub
# (commit b8bde03).
# Copyright 2026
#
# MLX implementation of talkie-lm/talkie-1930-13b-base.
# This file is intentionally self-contained so an MLX model directory can load it
# through config.json: {"model_file": "talkie_mlx.py"}.
import math
from dataclasses import dataclass
from typing import Any, Optional
import mlx.core as mx
import mlx.nn as nn
from mlx_lm.models.base import BaseModelArgs, create_attention_mask
from mlx_lm.models.base import scaled_dot_product_attention
@dataclass
class ModelArgs(BaseModelArgs):
    """Hyper-parameters for the Talkie MLX model.

    Field names match the keys read from the model directory's config.json;
    the values below are only fallbacks for keys the config omits.
    """

    model_type: str = "talkie"
    vocab_size: int = 65_536
    hidden_size: int = 5_120
    num_hidden_layers: int = 40
    num_attention_heads: int = 40
    intermediate_size: int = 13_696
    head_dim: int = 128
    max_position_embeddings: int = 2_048
    rope_theta: float = 1e6
    # Not consulted by Model in this file: the output projection is always the
    # separate ``lm_head`` parameter.
    tie_word_embeddings: bool = False
    # 2**-23, i.e. float32 machine epsilon. ``None`` makes rms_norm fall back to
    # the epsilon of the activation dtype instead.
    rms_norm_eps: Optional[float] = 2.0**-23
def rms_norm(x: mx.array, eps: Optional[float] = None) -> mx.array:
    """RMS-normalize ``x`` over its last axis without a learned scale.

    Args:
        x: Activations to normalize.
        eps: Numerical stabilizer forwarded to ``mx.fast.rms_norm``. When
            ``None``, the machine epsilon of ``x.dtype`` is used.

    Returns:
        An array with the same shape and dtype as ``x``.
    """
    resolved_eps = mx.finfo(x.dtype).eps if eps is None else eps
    # Passing ``None`` as the weight means no per-channel scaling is applied.
    return mx.fast.rms_norm(x, None, resolved_eps)
def apply_talkie_rope(x: mx.array, offset: int, base: float) -> mx.array:
    """Apply Talkie's split-half rotary embedding to ``x`` shaped [B, H, T, D].

    The whole last axis is rotated (``dims=D``) in non-interleaved layout
    (``traditional=False``). Each of the D/2 rotation pairs uses the value
    ``-exp(i * ln(base) / (D/2))`` = ``-(base ** (2i / D))``, computed in float32.

    Args:
        x: Query or key tensor, [batch, heads, tokens, head_dim].
        offset: Absolute position of the first token in ``x``, taken from the
            KV cache by the caller.
        base: RoPE theta.
    """
    rotary_dims = x.shape[-1]
    n_pairs = rotary_dims // 2
    log_step = math.log(base) / n_pairs
    exponents = mx.arange(0.0, n_pairs, dtype=mx.float32) * log_step
    # NOTE(review): the values are passed negated. This flips the rotation
    # direction relative to a plain base-``base`` RoPE, presumably to match the
    # sign convention the checkpoint was trained with. Do not "fix" the sign
    # without checking against the reference implementation.
    periods = -mx.exp(exponents)
    return mx.fast.rope(
        x,
        dims=rotary_dims,
        traditional=False,
        base=None,  # must be None when explicit freqs are supplied
        scale=1.0,
        offset=offset,
        freqs=periods,
    )
class HeadGain(nn.Module):
    """Learned per-head multiplier.

    Scales axis 1 of a 4-D input. In CausalSelfAttention that is the head axis
    of a [B, H, T, D] query tensor.
    """

    def __init__(self, num_heads: int):
        super().__init__()
        # One float32 gain per head, starting at 1. Keep the attribute name:
        # it is the module parameter name used when weights are loaded.
        self.head_g = mx.ones((num_heads,), dtype=mx.float32)

    def __call__(self, x: mx.array) -> mx.array:
        gain = self.head_g.astype(x.dtype)
        return x * gain[None, :, None, None]
class WeightGain(nn.Module):
    """Learned scalar multiplier applied to a weight tensor before it is used."""

    def __init__(self):
        super().__init__()
        # Single float32 scalar (shape (1,)) starting at 1; parameter name ``w_g``.
        self.w_g = mx.ones((1,), dtype=mx.float32)

    def __call__(self, w: mx.array) -> mx.array:
        # Cast the gain to the weight's dtype so the product keeps w's dtype.
        factor = self.w_g.astype(w.dtype)
        return w * factor
class ActGain(nn.Module):
    """Learned scalar multiplier for activations.

    Args:
        init_value: Initial gain, stored as a float32 array of shape (1,).
    """

    def __init__(self, init_value: float):
        super().__init__()
        self.a_g = mx.full((1,), init_value, dtype=mx.float32)

    def __call__(self, x: mx.array) -> mx.array:
        return mx.multiply(x, self.a_g.astype(x.dtype))
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention used by every Talkie block.

    Pipeline per call:

    1. Project to q/k/v and split into heads ([B, H, T, D]).
    2. Apply Talkie RoPE to q and k at the cache offset.
    3. RMS-normalize q and k over the head dimension.
    4. Scale q per head.
    5. Update the KV cache.
    6. Run scaled dot-product attention and merge the heads back.

    Every head has its own key/value projection; there is no grouped-query
    sharing.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_head = args.num_attention_heads
        self.head_dim = args.head_dim
        self.rope_theta = args.rope_theta
        self.rms_norm_eps = args.rms_norm_eps
        self.scale = self.head_dim**-0.5
        n_state = args.hidden_size
        # Width of the concatenated heads. For the default config this equals
        # hidden_size (40 * 128 == 5120). Sizing the projections from it, rather
        # than from hidden_size, keeps the head reshape valid when a config sets
        # head_dim independently of hidden_size / num_attention_heads.
        n_attn = self.n_head * self.head_dim
        self.attn_query = nn.Linear(n_state, n_attn, bias=False)
        self.attn_key = nn.Linear(n_state, n_attn, bias=False)
        self.attn_value = nn.Linear(n_state, n_attn, bias=False)
        self.attn_resid = nn.Linear(n_attn, n_state, bias=False)
        self.head_gain = HeadGain(self.n_head)

    def _split_heads(self, t: mx.array, bsz: int, seq_len: int) -> mx.array:
        """Reshape a [B, T, H*D] projection into [B, H, T, D]."""
        return t.reshape(bsz, seq_len, self.n_head, self.head_dim).transpose(
            0, 2, 1, 3
        )

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        """Attend over ``x`` (rank 3: [batch, tokens, hidden_size]).

        Args:
            x: Normalized hidden states.
            mask: Attention mask forwarded to ``scaled_dot_product_attention``.
            cache: Optional per-layer KV cache. It must expose ``offset`` and
                ``update_and_fetch(k, v)``.

        Returns:
            Output of ``attn_resid``, shaped [batch, tokens, hidden_size].
        """
        bsz, seq_len, _ = x.shape
        q = self._split_heads(self.attn_query(x), bsz, seq_len)
        k = self._split_heads(self.attn_key(x), bsz, seq_len)
        v = self._split_heads(self.attn_value(x), bsz, seq_len)

        # New tokens start at the number of positions already cached.
        offset = cache.offset if cache is not None else 0
        q = apply_talkie_rope(q, offset=offset, base=self.rope_theta)
        k = apply_talkie_rope(k, offset=offset, base=self.rope_theta)
        # QK-norm is applied after RoPE. RoPE only rotates coordinate pairs, so
        # the order matters only up to rounding; it is kept as-is.
        q = rms_norm(q, self.rms_norm_eps)
        k = rms_norm(k, self.rms_norm_eps)
        q = self.head_gain(q)

        if cache is not None:
            # Cache the rotated, normalized keys plus values; attend over the
            # full history.
            k, v = cache.update_and_fetch(k, v)
        y = scaled_dot_product_attention(
            q, k, v, cache=cache, scale=self.scale, mask=mask
        )
        y = y.transpose(0, 2, 1, 3).reshape(bsz, seq_len, -1)
        return self.attn_resid(y)
class MLP(nn.Module):
    """Gated SiLU feed-forward layer.

    Computes ``mlp_resid(silu(mlp_gate(x)) * mlp_linear(x))``, mapping
    hidden_size -> intermediate_size -> hidden_size with bias-free projections.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        d_model, d_ff = args.hidden_size, args.intermediate_size
        self.mlp_gate = nn.Linear(d_model, d_ff, bias=False)
        self.mlp_linear = nn.Linear(d_model, d_ff, bias=False)
        self.mlp_resid = nn.Linear(d_ff, d_model, bias=False)

    def __call__(self, x: mx.array) -> mx.array:
        g = self.mlp_gate(x)
        up = self.mlp_linear(x)
        # SiLU spelled out as g * sigmoid(g), then gated by the linear branch.
        hidden = g * mx.sigmoid(g) * up
        return self.mlp_resid(hidden)
class Block(nn.Module):
    """Pre-norm transformer layer with learned residual gains.

    It also has an additive skip from the normalized input embeddings.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        # Initial gains only (loaded weights replace them): each residual
        # branch starts at 1/sqrt(2 * num_layers), and the embedding skip
        # starts at 0.
        residual_init = (2 * args.num_hidden_layers) ** -0.5
        self.attn = CausalSelfAttention(args)
        self.attn_gain = ActGain(residual_init)
        self.mlp = MLP(args)
        self.mlp_gain = ActGain(residual_init)
        self.embed_skip = ActGain(0.0)
        self.rms_norm_eps = args.rms_norm_eps

    def __call__(
        self,
        e_x: mx.array,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        """Run one layer.

        Args:
            e_x: Normalized token embeddings, added back through
                ``embed_skip``.
            x: Residual stream entering this layer.
            mask: Attention mask passed to the attention sub-layer.
            cache: This layer's KV cache, or ``None``.
        """
        eps = self.rms_norm_eps
        attn_out = self.attn(rms_norm(x, eps), mask, cache)
        x = x + self.attn_gain(attn_out)
        mlp_out = self.mlp(rms_norm(x, eps))
        x = x + self.mlp_gain(mlp_out)
        return x + self.embed_skip(e_x)
class Model(nn.Module):
    """Talkie causal language model.

    The forward pass runs: embedding, RMS norm, the stack of Blocks, a final
    RMS norm, and a gained, untied LM head.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.embed = nn.Embedding(args.vocab_size, args.hidden_size)
        self.blocks = [Block(args) for _ in range(args.num_hidden_layers)]
        # Output projection stored as a raw [vocab_size, hidden_size] parameter,
        # independent of ``embed``; it is scaled by ``lm_head_gain`` at use time.
        self.lm_head = mx.zeros((args.vocab_size, args.hidden_size), dtype=mx.float32)
        self.lm_head_gain = WeightGain()

    def __call__(
        self,
        input_ids: mx.array,
        cache: Optional[Any] = None,
        input_embeddings: Optional[mx.array] = None,
    ) -> mx.array:
        """Return logits for ``input_ids``.

        Args:
            input_ids: Token ids. Ignored when ``input_embeddings`` is given.
            cache: Per-layer KV caches (one per block), or ``None``.
            input_embeddings: Optional precomputed embeddings used instead of
                ``self.embed(input_ids)``.

        Returns:
            ``h @ (lm_head * gain).T``, i.e. one score per vocabulary entry.
        """
        h = self.embed(input_ids) if input_embeddings is None else input_embeddings
        h = rms_norm(h, self.args.rms_norm_eps)
        # Every block re-injects these normalized embeddings via embed_skip.
        e_x = h
        layer_caches = [None] * len(self.blocks) if cache is None else cache
        mask = create_attention_mask(h, layer_caches[0])
        for block, layer_cache in zip(self.blocks, layer_caches):
            h = block(e_x, h, mask=mask, cache=layer_cache)
        h = rms_norm(h, self.args.rms_norm_eps)
        head = self.lm_head_gain(self.lm_head)
        return h @ head.T

    @property
    def layers(self):
        """Alias for ``blocks``, the attribute name mlx_lm helpers look up."""
        return self.blocks