#=
model.jl — LLaMA-style GPT model in Flux.jl for JuliaFluxGPT
Contains: RMSNorm, SwiGLU, CausalSelfAttention (GQA + RoPE),
TransformerBlock, GPT, and generation utilities.
Same architecture as juliaflux_v2.ipynb — extracted for inference serving.
NOTE: Weight tying is done by computing the output projection directly using
m.wte.weight in the forward pass. This matches the training notebooks and
ensures Flux.loadmodel! works without needing to skip lm_head.
=#
using Flux
using NNlib
using NNlib: batched_mul
using Statistics
using Random
using LinearAlgebra
# ──────────────────────────────────────────────────────────────────────────────
# RoPE — Rotary Positional Embeddings
# ──────────────────────────────────────────────────────────────────────────────
"""
    precompute_rope_freqs(head_dim, max_seq_len; base=10000f0)

Precompute the rotary-embedding angle tables. Returns `(cos_table, sin_table)`,
each a `Float32` matrix of size `(head_dim ÷ 2, max_seq_len)`, where column `t`
holds the angles for (0-based) position `t - 1`. `base` is the standard RoPE
frequency base (10000 in the original paper).
"""
function precompute_rope_freqs(head_dim::Int, max_seq_len::Int; base::Float32 = 10000.0f0)
    half_dim = head_dim ÷ 2
    # Inverse frequencies: base^(-2(i-1)/head_dim) for each rotation pair i.
    freqs = Float32[1.0f0 / (base ^ (Float32(2 * (i - 1)) / Float32(head_dim))) for i in 1:half_dim]
    positions = Float32.(collect(0:max_seq_len-1))
    # Outer product: (half_dim, max_seq_len) angle matrix.
    angles = freqs * positions'
    return cos.(angles), sin.(angles)
end
"""
    apply_rope(x, cos_f, sin_f, T)

Apply rotary position embeddings to `x` of shape `(head_dim, T, heads*batch)`.
The feature dimension is split into two halves which are rotated jointly:
`(x1, x2) -> (x1·cos - x2·sin, x1·sin + x2·cos)` using the first `T` columns
of the precomputed `cos_f`/`sin_f` tables (shape `(head_dim ÷ 2, ≥T)`).
"""
function apply_rope(x, cos_f, sin_f, T::Int)
    d = size(x, 1) ÷ 2
    x1 = x[1:d, :, :]
    x2 = x[d+1:2d, :, :]
    # Slice the tables to the current sequence length; they broadcast over batch.
    c = cos_f[:, 1:T]
    s = sin_f[:, 1:T]
    return vcat(x1 .* c .- x2 .* s, x1 .* s .+ x2 .* c)
end
# ──────────────────────────────────────────────────────────────────────────────
# Model components
# ──────────────────────────────────────────────────────────────────────────────
"""
    RMSNorm(dim; eps=1f-6)

Root-mean-square layer normalization (LLaMA-style): no mean-centering and no
bias. Normalizes over the feature dimension (dim 1) and applies a learned
per-feature scale.
"""
struct RMSNorm{W <: AbstractVector}
    weight::W     # learned per-feature scale, length == feature dim
    eps::Float32  # numerical-stability term inside the sqrt
end
Flux.@layer RMSNorm
RMSNorm(dim::Int; eps::Float32 = 1.0f-6) = RMSNorm(ones(Float32, dim), eps)
function (rn::RMSNorm)(x)
    # x is (features, ...); each feature column is divided by its RMS.
    rms = sqrt.(mean(x .^ 2, dims=1) .+ rn.eps)
    return (x ./ rms) .* rn.weight
end
"""
    SwiGLUFFN(n_embd; bias=false, dropout=0.0)

SwiGLU feed-forward block: `drop(w_down(swish(w_gate(x)) .* w_up(x)))`.
The inner width is ≈ 4·n_embd·2/3 (LLaMA convention for matching the
parameter count of a plain 4× MLP), rounded to the nearest multiple of 64
with a floor of 64 so the matmuls stay hardware-friendly.
"""
struct SwiGLUFFN
    w_gate::Dense
    w_up::Dense
    w_down::Dense
    drop::Dropout
end
Flux.@layer SwiGLUFFN
function SwiGLUFFN(n_embd::Int; bias=false, dropout=0.0)
    raw_inner = Int(floor(4 * n_embd * 2 / 3))
    # Round to the nearest multiple of 64 (adding 32 before the div), min 64.
    inner_dim = max(64, 64 * div(raw_inner + 32, 64))
    SwiGLUFFN(
        Dense(n_embd => inner_dim; bias),
        Dense(n_embd => inner_dim; bias),
        Dense(inner_dim => n_embd; bias),
        Dropout(dropout)
    )
end
function (ff::SwiGLUFFN)(x)
    ff.drop(ff.w_down(NNlib.swish(ff.w_gate(x)) .* ff.w_up(x)))
end
"""
    CausalSelfAttention(n_embd, n_head, n_kv_head; bias=false)

Multi-head causal self-attention with grouped-query attention (GQA):
`n_head` query heads share `n_kv_head` key/value heads, so `n_head` must be
a multiple of `n_kv_head`. Queries get a full `n_embd => n_embd` projection;
keys and values share one fused projection of width `2 * head_dim * n_kv_head`.
"""
struct CausalSelfAttention
    wq::Dense      # query projection, n_embd => n_embd
    wkv::Dense     # fused key+value projection, n_embd => 2*kv_dim
    proj::Dense    # output projection, n_embd => n_embd
    n_head::Int
    n_kv_head::Int
end
Flux.@layer CausalSelfAttention trainable=(wq, wkv, proj)
function CausalSelfAttention(n_embd::Int, n_head::Int, n_kv_head::Int; bias=false)
    head_dim = n_embd ÷ n_head
    kv_dim = head_dim * n_kv_head
    CausalSelfAttention(
        Dense(n_embd => n_embd; bias),
        Dense(n_embd => 2 * kv_dim; bias),
        Dense(n_embd => n_embd; bias),
        n_head,
        n_kv_head
    )
end
"""
    (attn::CausalSelfAttention)(x, causal_mask, rope_cos, rope_sin)

Attention forward pass. `x` is `(C, T, B)`; `causal_mask` is an additive
mask (0 on/below the diagonal, -Inf above); `rope_cos`/`rope_sin` are the
tables from `precompute_rope_freqs`. Returns `(C, T, B)`.
"""
function (attn::CausalSelfAttention)(x, causal_mask, rope_cos, rope_sin)
    C, T, B = size(x)
    nh = attn.n_head
    nkv = attn.n_kv_head
    hs = C ÷ nh            # per-head feature size
    kv_dim = hs * nkv
    groups = nh ÷ nkv      # query heads served by each kv head
    q = attn.wq(x)
    kv = attn.wkv(x)
    # Split the fused kv projection into keys and values.
    k = kv[1:kv_dim, :, :]
    v = kv[kv_dim+1:2*kv_dim, :, :]
    # Fold heads into the batch dimension: (hs, T, heads*B).
    q = reshape(permutedims(reshape(q, hs, nh, T, B), (1, 3, 2, 4)), hs, T, nh * B)
    k = reshape(permutedims(reshape(k, hs, nkv, T, B), (1, 3, 2, 4)), hs, T, nkv * B)
    v = reshape(permutedims(reshape(v, hs, nkv, T, B), (1, 3, 2, 4)), hs, T, nkv * B)
    # Rotary embeddings are applied to queries and keys only.
    q = apply_rope(q, rope_cos, rope_sin, T)
    k = apply_rope(k, rope_cos, rope_sin, T)
    if groups > 1
        # GQA: replicate each kv head `groups` times so k/v line up with the
        # nh query heads (consecutive query heads share a kv head).
        k_4d = reshape(k, hs, T, nkv, B)
        v_4d = reshape(v, hs, T, nkv, B)
        k_rep = repeat(reshape(k_4d, hs, T, nkv, 1, B), 1, 1, 1, groups, 1)
        v_rep = repeat(reshape(v_4d, hs, T, nkv, 1, B), 1, 1, 1, groups, 1)
        k = reshape(permutedims(k_rep, (1, 2, 4, 3, 5)), hs, T, nh * B)
        v = reshape(permutedims(v_rep, (1, 2, 4, 3, 5)), hs, T, nh * B)
    end
    scale = Float32(1 / sqrt(hs))
    # wei[i, j] = <q_i, k_j> * scale; rows index queries, columns index keys.
    wei = batched_mul(permutedims(q, (2, 1, 3)), k) .* scale
    wei = wei .+ causal_mask[1:T, 1:T]
    # Softmax over keys (dim 2); masked positions contribute ~0 weight.
    wei = softmax(wei; dims=2)
    # out[:, i] = sum_j v_j * wei[i, j].
    out = batched_mul(v, permutedims(wei, (2, 1, 3)))
    # Unfold heads back out of the batch dimension: (C, T, B).
    out = reshape(permutedims(reshape(out, hs, T, nh, B), (1, 3, 2, 4)), C, T, B)
    attn.proj(out)
end
"""
    TransformerBlock(n_embd, n_head, n_kv_head; dropout=0.0)

Pre-norm transformer block: RMSNorm → causal self-attention and
RMSNorm → SwiGLU feed-forward. The residual additions (and the causal
mask / RoPE tables passed to `attn`) are wired up by the GPT forward pass.
"""
struct TransformerBlock
    ln1::RMSNorm
    attn::CausalSelfAttention
    ln2::RMSNorm
    ffwd::SwiGLUFFN
end
Flux.@layer TransformerBlock
function TransformerBlock(n_embd::Int, n_head::Int, n_kv_head::Int; dropout=0.0)
    TransformerBlock(
        RMSNorm(n_embd),
        CausalSelfAttention(n_embd, n_head, n_kv_head),
        RMSNorm(n_embd),
        SwiGLUFFN(n_embd; dropout)
    )
end
# ──────────────────────────────────────────────────────────────────────────────
# GPT — weight-tied output projection (matches training notebooks)
# ──────────────────────────────────────────────────────────────────────────────
"""
    GPT(; vocab_size, n_embd, block_size, n_layer, n_head, n_kv_head, dropout=0.0)

LLaMA-style decoder-only transformer: token embedding, dropout, `n_layer`
pre-norm transformer blocks, and a final RMSNorm. The output projection is
weight-tied to `wte` inside the forward pass (no separate lm_head), which
keeps `Flux.loadmodel!` compatible with the training notebooks.
"""
struct GPT
    wte::Embedding
    drop::Dropout
    blocks::Chain
    ln_f::RMSNorm
    # Precomputed constants (not trainable)
    causal_mask::Matrix{Float32}  # (block_size, block_size), -Inf strictly above diag
    rope_cos::Matrix{Float32}     # RoPE cos table, (head_dim ÷ 2, block_size)
    rope_sin::Matrix{Float32}     # RoPE sin table, same shape
    n_head::Int
    n_kv_head::Int
    block_size::Int
end
Flux.@layer GPT trainable=(wte, drop, blocks, ln_f)
function GPT(; vocab_size, n_embd, block_size, n_layer, n_head, n_kv_head, dropout=0.0)
    head_dim = n_embd ÷ n_head
    wte = Embedding(vocab_size => n_embd)
    # Additive causal mask: 0 on/below the diagonal, -Inf strictly above.
    causal_mask = triu(fill(typemin(Float32), block_size, block_size), 1)
    rope_cos, rope_sin = precompute_rope_freqs(head_dim, block_size)
    GPT(
        wte,
        Dropout(dropout),
        Chain([TransformerBlock(n_embd, n_head, n_kv_head; dropout) for _ in 1:n_layer]...),
        RMSNorm(n_embd),
        causal_mask,
        rope_cos,
        rope_sin,
        n_head,
        n_kv_head,
        block_size
    )
end
"""
    (m::GPT)(idx)

Forward pass. `idx` is an integer matrix of token ids of shape `(B, T)`;
returns logits of shape `(vocab_size, T, B)`. The final projection reuses
`m.wte.weight` (weight tying) rather than a separate lm_head.
"""
function (m::GPT)(idx)
    B, T = size(idx)
    # Embedding gives (C, B, T); move to (C, T, B) for the blocks.
    tok = permutedims(m.wte(idx), (1, 3, 2))  # (C, T, B)
    x = m.drop(tok)
    for block in m.blocks
        # Pre-norm residual wiring; mask and RoPE tables are shared constants.
        x = x .+ block.attn(block.ln1(x), m.causal_mask, m.rope_cos, m.rope_sin)
        x = x .+ block.ffwd(block.ln2(x))
    end
    x = m.ln_f(x)
    # Weight-tied output projection — same weight matrix as the embedding.
    W = m.wte.weight          # (n_embd, vocab_size)
    C = size(x, 1)
    x_flat = reshape(x, C, T * B)
    out = W' * x_flat         # (vocab_size, T*B)
    reshape(out, size(W, 2), T, B)
end
# ──────────────────────────────────────────────────────────────────────────────
# Text generation with streaming support
# ──────────────────────────────────────────────────────────────────────────────
"""
    generate_streaming(model, encode_fn, decode_fn, vocab_size, block_size;
                       prompt="", max_tokens=200, temperature=0.8,
                       top_k=40, top_p=1.0, on_token=nothing)

Autoregressively sample `max_tokens` tokens from `model` and return the
decoded string of the newly generated tokens (the prompt is not echoed).
`encode_fn`/`decode_fn` convert between strings and vectors of 1-based
token ids. Sampling applies temperature scaling, then top-k filtering,
then top-p (nucleus) filtering. If `on_token` is given it is called with
each decoded token string as soon as it is sampled (streaming). With an
empty prompt, generation is seeded with one uniformly random token.
"""
function generate_streaming(model, encode_fn, decode_fn, vocab_size::Int, block_size::Int;
                            prompt::String="", max_tokens::Int=200, temperature::Float64=0.8,
                            top_k::Int=40, top_p::Float64=1.0, on_token=nothing)
    if !isempty(prompt)
        prompt_ids = encode_fn(prompt)
        idx = reshape(prompt_ids, 1, :)   # model expects (B, T) with B = 1
    else
        idx = reshape([rand(1:vocab_size)], 1, 1)
    end
    generated_ids = Int[]
    for _ in 1:max_tokens
        # Crop the context to the model's maximum block size.
        idx_cond = idx[:, max(1, end-block_size+1):end]
        logits = model(idx_cond)
        logits_last = Vector{Float32}(logits[:, end, 1])  # copy — safe to mutate
        # Temperature scaling (floored at 0.01 to avoid division blow-up).
        logits_last ./= Float32(max(temperature, 0.01))
        # Top-k filtering: drop everything below the k-th largest logit.
        if top_k > 0 && top_k < length(logits_last)
            threshold = partialsort(logits_last, top_k; rev=true)
            for i in eachindex(logits_last)
                if logits_last[i] < threshold
                    logits_last[i] = -Inf32
                end
            end
        end
        # Top-p (nucleus) filtering: keep the smallest prefix of the sorted
        # distribution whose cumulative probability reaches top_p.
        if top_p < 1.0
            sorted_indices = sortperm(logits_last; rev=true)
            sorted_logits = logits_last[sorted_indices]
            probs_sorted = NNlib.softmax(sorted_logits)
            cumprobs = cumsum(Array(probs_sorted))
            cutoff = something(findfirst(>=(Float32(top_p)), cumprobs), length(probs_sorted))
            for i in (cutoff+1):length(sorted_indices)
                logits_last[sorted_indices[i]] = -Inf32
            end
        end
        probs = NNlib.softmax(logits_last)
        # Inverse-CDF sampling; fall back to the last id against rounding error.
        probs_cpu = Float64.(probs)
        r = rand()
        cum = 0.0
        next_id = length(probs_cpu)
        for (i, p) in enumerate(probs_cpu)
            cum += p
            if r <= cum
                next_id = i
                break
            end
        end
        push!(generated_ids, next_id)
        idx = hcat(idx, reshape([next_id], 1, 1))
        if on_token !== nothing
            token_str = decode_fn([next_id])
            on_token(token_str)
        end
    end
    return decode_fn(generated_ids)
end