# Copyright 2026
#
# MLX implementation of talkie-lm/talkie-1930-13b-base.
# This file is intentionally self-contained so an MLX model directory can load it
# through config.json: {"model_file": "talkie_mlx.py"}.

import math
from dataclasses import dataclass
from typing import Any, Optional

import mlx.core as mx
import mlx.nn as nn
from mlx_lm.models.base import BaseModelArgs, create_attention_mask
from mlx_lm.models.base import scaled_dot_product_attention


@dataclass
class ModelArgs(BaseModelArgs):
    """Hyper-parameters for the Talkie model, populated from config.json."""

    model_type: str = "talkie"
    vocab_size: int = 65536
    hidden_size: int = 5120
    num_hidden_layers: int = 40
    num_attention_heads: int = 40
    intermediate_size: int = 13696
    head_dim: int = 128
    max_position_embeddings: int = 2048
    rope_theta: float = 1_000_000.0
    tie_word_embeddings: bool = False
    # Default equals 2**-23 (float32 machine epsilon). ``None`` makes
    # ``rms_norm`` fall back to the epsilon of the input's dtype.
    rms_norm_eps: Optional[float] = 1.1920928955078125e-7


def rms_norm(x: mx.array, eps: Optional[float] = None) -> mx.array:
    """Apply weight-free RMS normalisation via ``mx.fast.rms_norm``.

    Args:
        x: Array to normalise.
        eps: Numerical-stability epsilon. When ``None`` the machine epsilon
            of ``x.dtype`` is used.

    Returns:
        The normalised array.
    """
    if eps is None:
        eps = mx.finfo(x.dtype).eps
    return mx.fast.rms_norm(x, None, eps)


def _talkie_rope_freqs(head_dim: int, base: float) -> mx.array:
    """Build the float32 frequency table handed to ``mx.fast.rope``.

    Entry ``i`` is ``-(base ** (i / half_dim))`` for ``i`` in
    ``[0, half_dim)``, where ``half_dim = head_dim // 2``.
    """
    half_dim = head_dim // 2
    # NOTE(review): the table is negated relative to the usual
    # base ** (i / half_dim) form; presumably this flips the rotation
    # direction to match the reference Talkie implementation - confirm.
    return -mx.exp(
        mx.arange(0.0, half_dim, dtype=mx.float32) * (math.log(base) / half_dim)
    )


def apply_talkie_rope(
    x: mx.array,
    offset: int,
    base: float,
    freqs: Optional[mx.array] = None,
) -> mx.array:
    """Apply Talkie's split-half RoPE to tensors shaped [B, H, T, D].

    Args:
        x: Query or key tensor; the last axis is the head dimension.
        offset: Position offset forwarded to ``mx.fast.rope`` (the attention
            layer passes the KV cache offset, or 0 without a cache).
        base: RoPE base; only used to build the table when ``freqs`` is
            ``None``.
        freqs: Optional precomputed table from ``_talkie_rope_freqs``. Passing
            it avoids rebuilding the same small array on every call.

    Returns:
        ``x`` with rotary position embedding applied over the full head
        dimension.
    """
    head_dim = x.shape[-1]
    if freqs is None:
        freqs = _talkie_rope_freqs(head_dim, base)
    return mx.fast.rope(
        x,
        dims=head_dim,
        traditional=False,
        base=None,
        freqs=freqs,
        scale=1.0,
        offset=offset,
    )


class HeadGain(nn.Module):
    """Learned per-head multiplicative gain, initialised to ones."""

    def __init__(self, num_heads: int):
        super().__init__()
        self.head_g = mx.ones((num_heads,), dtype=mx.float32)

    def __call__(self, x: mx.array) -> mx.array:
        # The gain is reshaped to (1, H, 1, 1) so it broadcasts along axis 1
        # of a 4-D input, and is cast to the input dtype first.
        return x * self.head_g.astype(x.dtype).reshape(1, -1, 1, 1)


class WeightGain(nn.Module):
    """Learned single-element gain applied to a weight matrix."""

    def __init__(self):
        super().__init__()
        self.w_g = mx.ones((1,), dtype=mx.float32)

    def __call__(self, w: mx.array) -> mx.array:
        return w * self.w_g.astype(w.dtype)


class ActGain(nn.Module):
    """Learned single-element gain applied to activations."""

    def __init__(self, init_value: float):
        super().__init__()
        self.a_g = mx.array([init_value], dtype=mx.float32)

    def __call__(self, x: mx.array) -> mx.array:
        return x * self.a_g.astype(x.dtype)


class CausalSelfAttention(nn.Module):
    """Multi-head self-attention with RoPE, weight-free QK RMS norm and a
    per-head gain on the queries.

    Raises:
        ValueError: If ``num_attention_heads * head_dim != hidden_size``. The
            projections are ``hidden_size`` wide and are reshaped to
            ``(num_attention_heads, head_dim)``, so a mismatch could only
            fail later with an opaque reshape error.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_head = args.num_attention_heads
        self.head_dim = args.head_dim
        self.rope_theta = args.rope_theta
        self.rms_norm_eps = args.rms_norm_eps
        self.scale = self.head_dim**-0.5

        n_state = args.hidden_size
        if self.n_head * self.head_dim != n_state:
            raise ValueError(
                "num_attention_heads * head_dim "
                f"({self.n_head} * {self.head_dim}) must equal hidden_size "
                f"({n_state})"
            )

        self.attn_query = nn.Linear(n_state, n_state, bias=False)
        self.attn_key = nn.Linear(n_state, n_state, bias=False)
        self.attn_value = nn.Linear(n_state, n_state, bias=False)
        self.attn_resid = nn.Linear(n_state, n_state, bias=False)
        self.head_gain = HeadGain(self.n_head)

        # Built once instead of on every forward call. The leading underscore
        # keeps the array out of the module's parameters, so checkpoint
        # loading is unaffected.
        self._rope_freqs = _talkie_rope_freqs(self.head_dim, self.rope_theta)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        """Run attention over ``x`` of shape (batch, seq_len, hidden).

        Args:
            x: Input activations.
            mask: Attention mask forwarded to ``scaled_dot_product_attention``.
            cache: Optional per-layer KV cache exposing ``offset`` and
                ``update_and_fetch``.

        Returns:
            Output of the ``attn_resid`` projection, shaped like ``x``.
        """
        bsz, seq_len, _ = x.shape
        split_heads = (bsz, seq_len, self.n_head, self.head_dim)

        # Project, split into heads, and move heads ahead of time:
        # (B, T, H, D) -> (B, H, T, D).
        q = self.attn_query(x).reshape(*split_heads).transpose(0, 2, 1, 3)
        k = self.attn_key(x).reshape(*split_heads).transpose(0, 2, 1, 3)
        v = self.attn_value(x).reshape(*split_heads).transpose(0, 2, 1, 3)

        # New tokens are positioned after whatever the cache already holds.
        offset = cache.offset if cache is not None else 0
        q = apply_talkie_rope(
            q, offset=offset, base=self.rope_theta, freqs=self._rope_freqs
        )
        k = apply_talkie_rope(
            k, offset=offset, base=self.rope_theta, freqs=self._rope_freqs
        )

        # QK norm is applied after the rotation; only queries get the gain.
        q = rms_norm(q, self.rms_norm_eps)
        k = rms_norm(k, self.rms_norm_eps)
        q = self.head_gain(q)

        if cache is not None:
            k, v = cache.update_and_fetch(k, v)

        y = scaled_dot_product_attention(
            q, k, v, cache=cache, scale=self.scale, mask=mask
        )
        # Merge heads back: (B, H, T, D) -> (B, T, H * D).
        y = y.transpose(0, 2, 1, 3).reshape(bsz, seq_len, -1)
        return self.attn_resid(y)


class MLP(nn.Module):
    """Gated feed-forward block: ``resid(silu(gate(x)) * linear(x))``."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        n_state = args.hidden_size
        n_mlp = args.intermediate_size
        self.mlp_gate = nn.Linear(n_state, n_mlp, bias=False)
        self.mlp_linear = nn.Linear(n_state, n_mlp, bias=False)
        self.mlp_resid = nn.Linear(n_mlp, n_state, bias=False)

    def __call__(self, x: mx.array) -> mx.array:
        gate = self.mlp_gate(x)
        # gate * sigmoid(gate) is SiLU, written out explicitly.
        x = gate * mx.sigmoid(gate) * self.mlp_linear(x)
        return self.mlp_resid(x)


class Block(nn.Module):
    """One transformer layer: gained attention and MLP residual branches
    followed by a gained skip connection from the normalised embeddings."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        # Both residual branches start at (2 * num_layers) ** -0.5.
        init_gain = (2 * args.num_hidden_layers) ** -0.5
        self.attn = CausalSelfAttention(args)
        self.attn_gain = ActGain(init_gain)
        self.mlp = MLP(args)
        self.mlp_gain = ActGain(init_gain)
        # The embedding skip starts at zero gain.
        self.embed_skip = ActGain(0.0)
        self.rms_norm_eps = args.rms_norm_eps

    def __call__(
        self,
        e_x: mx.array,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[Any] = None,
    ) -> mx.array:
        """Apply the layer.

        Args:
            e_x: Normalised token embeddings, shared by every layer.
            x: Residual stream entering this layer.
            mask: Attention mask passed to the attention module.
            cache: Optional KV cache for this layer.

        Returns:
            The updated residual stream.
        """
        x = x + self.attn_gain(self.attn(rms_norm(x, self.rms_norm_eps), mask, cache))
        x = x + self.mlp_gain(self.mlp(rms_norm(x, self.rms_norm_eps)))
        x = x + self.embed_skip(e_x)
        return x


class Model(nn.Module):
    """Talkie language model: embedding, a stack of ``Block`` layers, a final
    RMS norm and an output projection whose weight is scaled by a learned
    gain."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.model_type = args.model_type
        self.embed = nn.Embedding(args.vocab_size, args.hidden_size)
        self.blocks = [Block(args) for _ in range(args.num_hidden_layers)]
        # Held as a raw (vocab, hidden) array rather than nn.Linear so that
        # the gain can be applied to the weight before the projection.
        self.lm_head = mx.zeros((args.vocab_size, args.hidden_size), dtype=mx.float32)
        self.lm_head_gain = WeightGain()

    def __call__(
        self,
        input_ids: mx.array,
        cache: Optional[Any] = None,
        input_embeddings: Optional[mx.array] = None,
    ) -> mx.array:
        """Compute logits.

        Args:
            input_ids: Token ids; ignored when ``input_embeddings`` is given.
            cache: Optional sequence of per-layer KV caches, one per block.
            input_embeddings: Optional embeddings used instead of looking up
                ``input_ids``.

        Returns:
            Logits over the vocabulary for every input position.
        """
        if input_embeddings is not None:
            x = input_embeddings
        else:
            x = self.embed(input_ids)

        # The normalised embeddings are both the initial residual stream and
        # the skip input handed to every block.
        x = rms_norm(x, self.args.rms_norm_eps)
        e_x = x

        if cache is None:
            cache = [None] * len(self.blocks)

        mask = create_attention_mask(x, cache[0])

        for block, c in zip(self.blocks, cache):
            x = block(e_x, x, mask=mask, cache=c)

        x = rms_norm(x, self.args.rms_norm_eps)
        return x @ self.lm_head_gain(self.lm_head).T

    @property
    def layers(self):
        """Return the list of transformer blocks."""
        return self.blocks