# LisaMegaWatts's picture
# Initial space setup: Distilled LLaMA-style OpenAI-compatible server
# 0624a08 verified
#=
model.jl β€” LLaMA-style GPT model in Flux.jl for JuliaFluxGPT
Contains: RMSNorm, SwiGLU, CausalSelfAttention (GQA + RoPE),
TransformerBlock, GPT, and generation utilities.
Same architecture as juliaflux_v2.ipynb β€” extracted for inference serving.
NOTE: Weight tying is done by computing the output projection directly using
m.wte.weight in the forward pass. This matches the training notebooks and
ensures Flux.loadmodel! works without needing to skip lm_head.
=#
using Flux
using NNlib
using NNlib: batched_mul
using Statistics
using Random
using LinearAlgebra
# ═══════════════════════════════════════════════════════════════════════════════
# RoPE β€” Rotary Positional Embeddings
# ═══════════════════════════════════════════════════════════════════════════════
"""
    precompute_rope_freqs(head_dim, max_seq_len; base=10000.0f0)

Build the rotary-embedding angle tables for a head of size `head_dim`.
Returns `(cos_table, sin_table)`, each of size `(head_dim ÷ 2, max_seq_len)`,
where column `t` holds the angles for (0-based) position `t - 1`.
"""
function precompute_rope_freqs(head_dim::Int, max_seq_len::Int; base::Float32 = 10000.0f0)
    half = head_dim ÷ 2
    # inv_freq[k] = base^(-2(k-1)/head_dim), the standard RoPE frequency ladder
    inv_freq = Float32[1.0f0 / (base ^ (Float32(2 * (k - 1)) / Float32(head_dim))) for k in 1:half]
    # theta[k, t] = inv_freq[k] * (t - 1): one angle per (frequency, position) pair
    theta = Float32[inv_freq[k] * (t - 1) for k in 1:half, t in 1:max_seq_len]
    return cos.(theta), sin.(theta)
end
"""
    apply_rope(x, cos_f, sin_f, T)

Rotate the feature halves of `x` (shape `(head_dim, T, heads*batch)`) by the
precomputed angle tables, GPT-NeoX "rotate-half" style: the first and second
halves of the feature axis form the (real, imaginary) pair for each frequency.
"""
function apply_rope(x, cos_f, sin_f, T::Int)
    half = size(x, 1) ÷ 2
    lo = x[1:half, :, :]
    hi = x[half+1:2half, :, :]
    # Truncate the tables to the current sequence length; they broadcast
    # over the trailing (heads*batch) axis.
    c = cos_f[:, 1:T]
    s = sin_f[:, 1:T]
    rotated_lo = lo .* c .- hi .* s
    rotated_hi = lo .* s .+ hi .* c
    return vcat(rotated_lo, rotated_hi)
end
# ═══════════════════════════════════════════════════════════════════════════════
# Model components
# ═══════════════════════════════════════════════════════════════════════════════
"""
    RMSNorm(dim; eps=1f-6)

Root-mean-square layer normalisation (LLaMA-style, no mean subtraction):
each column of the input is scaled to unit RMS over the feature axis, then
multiplied by a learned per-feature gain initialised to ones.
"""
struct RMSNorm{W <: AbstractVector}
    weight::W     # learned per-feature gain
    eps::Float32  # numerical-stability floor inside the sqrt
end
Flux.@layer RMSNorm
RMSNorm(dim::Int; eps::Float32 = 1.0f-6) = RMSNorm(ones(Float32, dim), eps)

function (rn::RMSNorm)(x)
    # Mean square over the feature axis (dims=1), stabilised by eps.
    ms = mean(abs2.(x), dims=1)
    scale = sqrt.(ms .+ rn.eps)
    return (x ./ scale) .* rn.weight
end
"""
    SwiGLUFFN(n_embd; bias=false, dropout=0.0)

SwiGLU feed-forward block: `down(swish(gate(x)) .* up(x))` followed by
dropout. The hidden width is 2/3 of the classic `4*n_embd` (the SwiGLU
convention), rounded to the nearest multiple of 64 with a floor of 64.
"""
struct SwiGLUFFN
    w_gate::Dense
    w_up::Dense
    w_down::Dense
    drop::Dropout
end
Flux.@layer SwiGLUFFN

function SwiGLUFFN(n_embd::Int; bias=false, dropout=0.0)
    target = Int(floor(4 * n_embd * 2 / 3))
    # Round to the nearest multiple of 64 for hardware-friendly widths.
    hidden = max(64, 64 * div(target + 32, 64))
    gate = Dense(n_embd => hidden; bias)
    up = Dense(n_embd => hidden; bias)
    down = Dense(hidden => n_embd; bias)
    return SwiGLUFFN(gate, up, down, Dropout(dropout))
end

function (ff::SwiGLUFFN)(x)
    gated = NNlib.swish(ff.w_gate(x)) .* ff.w_up(x)
    return ff.drop(ff.w_down(gated))
end
"""
    CausalSelfAttention(n_embd, n_head, n_kv_head; bias=false)

Grouped-query causal self-attention. Queries get a full `n_embd => n_embd`
projection; keys and values share one fused projection sized for `n_kv_head`
heads, so `n_kv_head < n_head` shrinks the KV width (GQA).
"""
struct CausalSelfAttention
    wq::Dense       # query projection
    wkv::Dense      # fused key+value projection
    proj::Dense     # output projection
    n_head::Int     # number of query heads
    n_kv_head::Int  # number of key/value heads (divides n_head)
end
# Integer head counts are excluded from the trainable parameters.
Flux.@layer CausalSelfAttention trainable=(wq, wkv, proj)

function CausalSelfAttention(n_embd::Int, n_head::Int, n_kv_head::Int; bias=false)
    per_head = n_embd ÷ n_head
    kv_width = per_head * n_kv_head
    wq = Dense(n_embd => n_embd; bias)
    wkv = Dense(n_embd => 2 * kv_width; bias)
    proj = Dense(n_embd => n_embd; bias)
    return CausalSelfAttention(wq, wkv, proj, n_head, n_kv_head)
end
# Causal GQA attention forward pass.
# x: (C, T, B) activations. causal_mask: additive matrix with typemin(Float32)
# strictly above the diagonal. rope_cos / rope_sin: (head_dim/2, max_seq) tables.
function (attn::CausalSelfAttention)(x, causal_mask, rope_cos, rope_sin)
    C, T, B = size(x)        # embed dim, sequence length, batch
    nh = attn.n_head
    nkv = attn.n_kv_head
    hs = C ÷ nh              # per-head feature size
    kv_dim = hs * nkv
    groups = nh ÷ nkv        # query heads sharing each KV head
    q = attn.wq(x)
    kv = attn.wkv(x)         # fused projection; K is the first half, V the second
    k = kv[1:kv_dim, :, :]
    v = kv[kv_dim+1:2*kv_dim, :, :]
    # Split heads and fold them into the batch axis: (hs, T, heads*B),
    # slab index = head + (batch-1)*heads.
    q = reshape(permutedims(reshape(q, hs, nh, T, B), (1, 3, 2, 4)), hs, T, nh * B)
    k = reshape(permutedims(reshape(k, hs, nkv, T, B), (1, 3, 2, 4)), hs, T, nkv * B)
    v = reshape(permutedims(reshape(v, hs, nkv, T, B), (1, 3, 2, 4)), hs, T, nkv * B)
    # Rotary position encoding is applied to queries and keys only.
    q = apply_rope(q, rope_cos, rope_sin, T)
    k = apply_rope(k, rope_cos, rope_sin, T)
    if groups > 1
        # GQA: replicate each KV head `groups` times so every query head has a
        # matching slab. After the permute/reshape, query heads
        # (kv-1)*groups+1 .. kv*groups all read KV head `kv` (consecutive grouping).
        k_4d = reshape(k, hs, T, nkv, B)
        v_4d = reshape(v, hs, T, nkv, B)
        k_rep = repeat(reshape(k_4d, hs, T, nkv, 1, B), 1, 1, 1, groups, 1)
        v_rep = repeat(reshape(v_4d, hs, T, nkv, 1, B), 1, 1, 1, groups, 1)
        k = reshape(permutedims(k_rep, (1, 2, 4, 3, 5)), hs, T, nh * B)
        v = reshape(permutedims(v_rep, (1, 2, 4, 3, 5)), hs, T, nh * B)
    end
    scale = Float32(1 / sqrt(hs))
    # wei[i, j] = scaled dot(query_i, key_j) per (head, batch) slab: (T, T, nh*B).
    wei = batched_mul(permutedims(q, (2, 1, 3)), k) .* scale
    # Additive mask sends j > i (future positions) to -Inf before the softmax.
    wei = wei .+ causal_mask[1:T, 1:T]
    wei = softmax(wei; dims=2)   # normalise over key positions
    out = batched_mul(v, permutedims(wei, (2, 1, 3)))
    # Un-fold heads back to (C, T, B), then mix with the output projection.
    out = reshape(permutedims(reshape(out, hs, T, nh, B), (1, 3, 2, 4)), C, T, B)
    attn.proj(out)
end
"""
    TransformerBlock(n_embd, n_head, n_kv_head; dropout=0.0)

Pre-norm transformer block pairing RMSNorm + GQA attention with
RMSNorm + SwiGLU feed-forward. The residual additions themselves are done
by the caller (the GPT forward pass), which is why this type has no functor.
"""
struct TransformerBlock
    ln1::RMSNorm
    attn::CausalSelfAttention
    ln2::RMSNorm
    ffwd::SwiGLUFFN
end
Flux.@layer TransformerBlock

function TransformerBlock(n_embd::Int, n_head::Int, n_kv_head::Int; dropout=0.0)
    # Construction order matches the field order (keeps RNG draws identical).
    norm_attn = RMSNorm(n_embd)
    attention = CausalSelfAttention(n_embd, n_head, n_kv_head)
    norm_ffn = RMSNorm(n_embd)
    feedforward = SwiGLUFFN(n_embd; dropout)
    return TransformerBlock(norm_attn, attention, norm_ffn, feedforward)
end
# ═══════════════════════════════════════════════════════════════════════════════
# GPT β€” weight-tied output projection (matches training notebooks)
# ═══════════════════════════════════════════════════════════════════════════════
# Decoder-only transformer with LLaMA-style components (RMSNorm, SwiGLU,
# GQA, RoPE). There is no separate lm_head: the forward pass reuses
# `wte.weight` as the tied output projection (see file header note).
struct GPT
    wte::Embedding   # token embedding; also the tied output matrix
    drop::Dropout    # dropout applied to the embedded tokens
    blocks::Chain    # stack of TransformerBlocks
    ln_f::RMSNorm    # final norm before the tied LM head
    # Precomputed constants (not trainable)
    causal_mask::Matrix{Float32}   # (block_size, block_size) additive mask
    rope_cos::Matrix{Float32}      # RoPE cosine table
    rope_sin::Matrix{Float32}      # RoPE sine table
    n_head::Int
    n_kv_head::Int
    block_size::Int                # maximum sequence length
end
# Restrict `trainable` so the precomputed constants are skipped by optimisers.
Flux.@layer GPT trainable=(wte, drop, blocks, ln_f)
"""
    GPT(; vocab_size, n_embd, block_size, n_layer, n_head, n_kv_head, dropout=0.0)

Build a GPT with `n_layer` transformer blocks, precomputing the additive
causal mask and the RoPE angle tables for sequences up to `block_size`.
"""
function GPT(; vocab_size, n_embd, block_size, n_layer, n_head, n_kv_head, dropout=0.0)
    per_head = n_embd ÷ n_head
    # Embedding first: keeps the RNG draw order identical to field order.
    token_emb = Embedding(vocab_size => n_embd)
    # triu(..., 1): typemin(Float32) strictly above the diagonal, 0 elsewhere.
    mask = triu(fill(typemin(Float32), block_size, block_size), 1)
    cos_tab, sin_tab = precompute_rope_freqs(per_head, block_size)
    layers = Chain([TransformerBlock(n_embd, n_head, n_kv_head; dropout) for _ in 1:n_layer]...)
    return GPT(token_emb, Dropout(dropout), layers, RMSNorm(n_embd),
               mask, cos_tab, sin_tab, n_head, n_kv_head, block_size)
end
# Forward pass: token ids (B, T) -> logits (vocab, T, B).
# The output projection is weight-tied: it multiplies by wte.weight' rather
# than a separate lm_head (matches the training notebooks).
function (m::GPT)(idx)
    B, T = size(idx)
    # Embedding yields (C, B, T); reorder to the (C, T, B) layout the blocks use.
    h = m.drop(permutedims(m.wte(idx), (1, 3, 2)))
    for blk in m.blocks
        h = h .+ blk.attn(blk.ln1(h), m.causal_mask, m.rope_cos, m.rope_sin)
        h = h .+ blk.ffwd(blk.ln2(h))
    end
    h = m.ln_f(h)
    # Tied head: logits = Wteᵀ · h, computed on the flattened (C, T*B) view.
    emb = m.wte.weight
    features = size(h, 1)
    logits = emb' * reshape(h, features, T * B)
    return reshape(logits, size(emb, 2), T, B)
end
# ═══════════════════════════════════════════════════════════════════════════════
# Text generation with streaming support
# ═══════════════════════════════════════════════════════════════════════════════
"""
    generate_streaming(model, encode_fn, decode_fn, vocab_size, block_size;
                       prompt="", max_tokens=200, temperature=0.8,
                       top_k=40, top_p=1.0, on_token=nothing,
                       rng=Random.default_rng())

Autoregressively sample up to `max_tokens` tokens from `model` and return the
decoded string of the generated ids (the prompt is not included).

`model` is called as `model(idx)` with an `Int` matrix `(1, T)` and must return
logits shaped `(vocab, T, batch)`. The context is cropped to the last
`block_size` tokens before each call. `top_k <= 0` disables top-k filtering;
`top_p >= 1.0` disables nucleus filtering. If `on_token` is given it is called
with each newly decoded token string as it is produced (streaming).
Pass an explicit `rng` for reproducible sampling.
"""
function generate_streaming(model, encode_fn, decode_fn, vocab_size::Int, block_size::Int;
                            prompt::String="", max_tokens::Int=200, temperature::Float64=0.8,
                            top_k::Int=40, top_p::Float64=1.0, on_token=nothing,
                            rng::AbstractRNG=Random.default_rng())
    # Seed the context with the encoded prompt, or a single random token.
    idx = if isempty(prompt)
        reshape([rand(rng, 1:vocab_size)], 1, 1)
    else
        reshape(encode_fn(prompt), 1, :)
    end
    generated_ids = Int[]
    for _ in 1:max_tokens
        # Crop the context to the model's maximum sequence length.
        idx_cond = idx[:, max(1, end-block_size+1):end]
        logits = model(idx_cond)
        logits_last = Vector{Float32}(logits[:, end, 1])
        # Temperature scaling, floored at 0.01 to avoid division blow-up.
        logits_last ./= Float32(max(temperature, 0.01))
        _filter_top_k!(logits_last, top_k)
        _filter_top_p!(logits_last, top_p)
        next_id = _sample_categorical(rng, _stable_softmax(logits_last))
        push!(generated_ids, next_id)
        idx = hcat(idx, reshape([next_id], 1, 1))
        if on_token !== nothing
            on_token(decode_fn([next_id]))
        end
    end
    return decode_fn(generated_ids)
end

# Keep only the `k` largest logits, masking the rest to -Inf32.
# Ties at the k-th value all survive (same semantics as the threshold compare).
function _filter_top_k!(logits::Vector{Float32}, k::Int)
    (k > 0 && k < length(logits)) || return logits
    threshold = partialsort(logits, k; rev=true)
    for i in eachindex(logits)
        if logits[i] < threshold
            logits[i] = -Inf32
        end
    end
    return logits
end

# Nucleus (top-p) filtering: keep the smallest prefix of tokens, in descending
# probability order, whose cumulative probability reaches `p`; mask the rest.
function _filter_top_p!(logits::Vector{Float32}, p::Float64)
    p < 1.0 || return logits
    order = sortperm(logits; rev=true)
    probs = _stable_softmax(logits[order])
    cumprobs = cumsum(probs)
    cutoff = something(findfirst(>=(Float32(p)), cumprobs), length(probs))
    for i in (cutoff+1):length(order)
        logits[order[i]] = -Inf32
    end
    return logits
end

# Numerically stable softmax (max-subtracted). Matches NNlib.softmax on
# vectors, including exp(-Inf32 - max) == 0 for masked entries, and keeps
# the sampling path free of the NNlib dependency.
function _stable_softmax(v::Vector{Float32})
    m = maximum(v)
    e = exp.(v .- m)
    return e ./ sum(e)
end

# Inverse-CDF sampling from a probability vector; falls back to the last
# index if float rounding leaves `r` above the accumulated total.
function _sample_categorical(rng::AbstractRNG, probs::Vector{Float32})
    r = rand(rng)
    acc = 0.0
    for (i, p) in enumerate(probs)
        acc += p
        if r <= acc
            return i
        end
    end
    return length(probs)
end