""" AAC Micro Brain — 16M parameter conversational flow model. Tiny transformer that only knows how humans talk in everyday situations. No world knowledge. No encyclopedia. Just conversation patterns. Architecture: ~16M params - vocab_size: 8192 - d_model: 512 - n_heads: 8 - n_layers: 6 - d_ff: 1024 - max_seq_len: 128 """ import mlx.core as mx import mlx.nn as nn import math class MultiHeadAttention(nn.Module): def __init__(self, d_model: int, n_heads: int): super().__init__() self.n_heads = n_heads self.d_head = d_model // n_heads self.qkv = nn.Linear(d_model, 3 * d_model, bias=False) self.out = nn.Linear(d_model, d_model, bias=False) def __call__(self, x, mask=None): B, T, C = x.shape qkv = self.qkv(x) q, k, v = mx.split(qkv, 3, axis=-1) q = q.reshape(B, T, self.n_heads, self.d_head).transpose(0, 2, 1, 3) k = k.reshape(B, T, self.n_heads, self.d_head).transpose(0, 2, 1, 3) v = v.reshape(B, T, self.n_heads, self.d_head).transpose(0, 2, 1, 3) scale = math.sqrt(self.d_head) attn = (q @ k.transpose(0, 1, 3, 2)) / scale if mask is not None: attn = attn + mask attn = mx.softmax(attn, axis=-1) out = (attn @ v).transpose(0, 2, 1, 3).reshape(B, T, C) return self.out(out) class TransformerBlock(nn.Module): def __init__(self, d_model: int, n_heads: int, d_ff: int): super().__init__() self.attn = MultiHeadAttention(d_model, n_heads) self.ff = nn.Sequential( nn.Linear(d_model, d_ff, bias=False), nn.GELU(), nn.Linear(d_ff, d_model, bias=False), ) self.ln1 = nn.RMSNorm(d_model) self.ln2 = nn.RMSNorm(d_model) def __call__(self, x, mask=None): x = x + self.attn(self.ln1(x), mask=mask) x = x + self.ff(self.ln2(x)) return x class MicroBrain(nn.Module): """16M param conversational flow predictor.""" def __init__( self, vocab_size: int = 8192, d_model: int = 512, n_heads: int = 8, n_layers: int = 6, d_ff: int = 1024, max_seq_len: int = 128, ): super().__init__() self.d_model = d_model self.max_seq_len = max_seq_len self.token_emb = nn.Embedding(vocab_size, d_model) self.pos_emb = nn.Embedding(max_seq_len, d_model) self.layers = [TransformerBlock(d_model, n_heads, d_ff) for _ in range(n_layers)] self.ln_final = nn.RMSNorm(d_model) self.output = nn.Linear(d_model, vocab_size, bias=False) def __call__(self, tokens): B, T = tokens.shape positions = mx.arange(T) x = self.token_emb(tokens) + self.pos_emb(positions) # Causal mask mask = nn.MultiHeadAttention.create_additive_causal_mask(T) for layer in self.layers: x = layer(x, mask=mask) x = self.ln_final(x) logits = self.output(x) return logits def count_params(self): """Count total parameters.""" from mlx.utils import tree_flatten return sum(v.size for _, v in tree_flatten(self.parameters())) def create_model(**kwargs): model = MicroBrain(**kwargs) mx.eval(model.parameters()) n_params = model.count_params() print(f"MicroBrain: {n_params:,} parameters ({n_params / 1e6:.1f}M)") return model if __name__ == "__main__": model = create_model()