Spaces:
Running
Running
Cache Monarch matrices + causal mask for faster inference
Browse files- checkpoint.jl +7 -1
- model.jl +49 -24
- server.jl +4 -1
checkpoint.jl
CHANGED
|
@@ -94,5 +94,11 @@ function load_inference_model(ckpt_path::String, config_path::String,
|
|
| 94 |
println(" Adjusted vocab_size to $(config.vocab_size) from embedding weight")
|
| 95 |
end
|
| 96 |
|
| 97 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
end
|
|
|
|
| 94 |
println(" Adjusted vocab_size to $(config.vocab_size) from embedding weight")
|
| 95 |
end
|
| 96 |
|
| 97 |
+
# Pre-compute inference caches (Monarch matrices + causal mask)
|
| 98 |
+
println("Pre-computing inference caches ...")
|
| 99 |
+
caches = precompute_inference_caches(config, ps)
|
| 100 |
+
n_cached = config.n_layers * config.n_monarch_heads
|
| 101 |
+
println(" Cached $n_cached Monarch matrices ($(config.context_length)x$(config.context_length))")
|
| 102 |
+
|
| 103 |
+
return (; config, ps, tokenizer, step, val_loss, caches)
|
| 104 |
end
|
model.jl
CHANGED
|
@@ -305,31 +305,53 @@ function causal_depthwise_conv1d(x, kernel)
|
|
| 305 |
end
|
| 306 |
|
| 307 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 308 |
-
# Monarch
|
| 309 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 310 |
|
| 311 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 312 |
D, T, B = size(x)
|
| 313 |
H = n_heads
|
| 314 |
-
HD = D Γ· H
|
| 315 |
-
p = isqrt(context_length)
|
| 316 |
|
| 317 |
# 1. Causal depthwise conv for local context
|
| 318 |
conv_out = causal_depthwise_conv1d(x, ps.conv.kernel)
|
| 319 |
|
| 320 |
-
# 2. Multi-head Monarch mixing
|
| 321 |
monarch_slices = map(1:H) do i
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
# Realize full T_max Γ T_max Monarch matrix
|
| 326 |
-
M = monarch_realize(ps_m.L1, ps_m.L2, p)
|
| 327 |
-
|
| 328 |
-
# Apply causal mask
|
| 329 |
-
M_causal = M .* mask
|
| 330 |
-
|
| 331 |
-
# Slice to actual sequence length T (for generation where T < context_length)
|
| 332 |
-
M_t = M_causal[1:T, 1:T]
|
| 333 |
|
| 334 |
# Extract this head's channel slice: (HD, T, B)
|
| 335 |
ch_start = (i - 1) * HD + 1
|
|
@@ -356,18 +378,15 @@ function monarch_sequence_mixer_forward(x, ps, n_heads::Int, context_length::Int
|
|
| 356 |
end
|
| 357 |
|
| 358 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 359 |
-
# Full model forward pass
|
| 360 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 361 |
|
| 362 |
-
function model_forward(config::ModelConfig, ps, x)
|
| 363 |
T = size(x, 1) # x: (seq_len, batch) of integer token IDs
|
| 364 |
|
| 365 |
# Token embedding: (seq_len, batch) β (embed_dim, seq_len, batch)
|
| 366 |
h = ps.tok_emb.weight[:, x]
|
| 367 |
|
| 368 |
-
# Causal mask (multiplicative 0/1 for Monarch)
|
| 369 |
-
mask = make_causal_mask(config.context_length)
|
| 370 |
-
|
| 371 |
# Monarch blocks
|
| 372 |
for i in 1:config.n_layers
|
| 373 |
name = Symbol("block_$i")
|
|
@@ -377,7 +396,7 @@ function model_forward(config::ModelConfig, ps, x)
|
|
| 377 |
normed = rmsnorm_forward(h, bp.ln1.weight)
|
| 378 |
mixed = monarch_sequence_mixer_forward(normed, bp.seq_mixer,
|
| 379 |
config.n_monarch_heads,
|
| 380 |
-
|
| 381 |
h = h .+ mixed
|
| 382 |
|
| 383 |
# Pre-norm FFN + residual
|
|
@@ -444,12 +463,18 @@ function generate_streaming(config::ModelConfig, ps,
|
|
| 444 |
temperature::Float64=0.8,
|
| 445 |
top_k::Int=0,
|
| 446 |
top_p::Float64=1.0,
|
| 447 |
-
on_token=nothing
|
|
|
|
| 448 |
tokens = encode(tokenizer, prompt)
|
| 449 |
if isempty(tokens)
|
| 450 |
tokens = [rand(1:tokenizer_vocab_size(tokenizer))]
|
| 451 |
end
|
| 452 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 453 |
generated = String[]
|
| 454 |
|
| 455 |
for _ in 1:max_tokens
|
|
@@ -460,7 +485,7 @@ function generate_streaming(config::ModelConfig, ps,
|
|
| 460 |
end
|
| 461 |
|
| 462 |
x = reshape(ctx, :, 1)
|
| 463 |
-
logits = model_forward(config, ps, x)
|
| 464 |
next_logits = Vector{Float32}(logits[:, end, 1])
|
| 465 |
|
| 466 |
if temperature != 1.0
|
|
|
|
| 305 |
end
|
| 306 |
|
| 307 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 308 |
+
# Pre-compute inference caches (Monarch matrices + causal mask)
|
| 309 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 310 |
|
| 311 |
+
"""
|
| 312 |
+
precompute_inference_caches(config, ps) -> NamedTuple
|
| 313 |
+
|
| 314 |
+
Pre-realize all Monarch matrices and apply causal mask once at startup.
|
| 315 |
+
Avoids recomputing them on every forward pass during generation.
|
| 316 |
+
"""
|
| 317 |
+
function precompute_inference_caches(config::ModelConfig, ps)
|
| 318 |
+
p = isqrt(config.context_length)
|
| 319 |
+
mask = make_causal_mask(config.context_length)
|
| 320 |
+
|
| 321 |
+
# Pre-realize all Monarch matrices: monarch_ms[layer][head] = masked TΓT matrix
|
| 322 |
+
monarch_ms = Vector{Vector{Matrix{Float32}}}(undef, config.n_layers)
|
| 323 |
+
for i in 1:config.n_layers
|
| 324 |
+
name = Symbol("block_$i")
|
| 325 |
+
bp = getproperty(ps.blocks, name)
|
| 326 |
+
layer_ms = Vector{Matrix{Float32}}(undef, config.n_monarch_heads)
|
| 327 |
+
for j in 1:config.n_monarch_heads
|
| 328 |
+
head_name = Symbol("head_$j")
|
| 329 |
+
ps_m = getproperty(bp.seq_mixer.monarchs, head_name)
|
| 330 |
+
M = monarch_realize(ps_m.L1, ps_m.L2, p) .* mask
|
| 331 |
+
layer_ms[j] = M
|
| 332 |
+
end
|
| 333 |
+
monarch_ms[i] = layer_ms
|
| 334 |
+
end
|
| 335 |
+
|
| 336 |
+
return (; mask, monarch_ms)
|
| 337 |
+
end
|
| 338 |
+
|
| 339 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 340 |
+
# Monarch Sequence Mixer forward pass (uses cached matrices)
|
| 341 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 342 |
+
|
| 343 |
+
function monarch_sequence_mixer_forward(x, ps, n_heads::Int, monarch_ms_layer)
|
| 344 |
D, T, B = size(x)
|
| 345 |
H = n_heads
|
| 346 |
+
HD = D Γ· H
|
|
|
|
| 347 |
|
| 348 |
# 1. Causal depthwise conv for local context
|
| 349 |
conv_out = causal_depthwise_conv1d(x, ps.conv.kernel)
|
| 350 |
|
| 351 |
+
# 2. Multi-head Monarch mixing (pre-realized matrices)
|
| 352 |
monarch_slices = map(1:H) do i
|
| 353 |
+
# Slice cached matrix to actual sequence length
|
| 354 |
+
M_t = monarch_ms_layer[i][1:T, 1:T]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 355 |
|
| 356 |
# Extract this head's channel slice: (HD, T, B)
|
| 357 |
ch_start = (i - 1) * HD + 1
|
|
|
|
| 378 |
end
|
| 379 |
|
| 380 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 381 |
+
# Full model forward pass (uses cached data)
|
| 382 |
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 383 |
|
| 384 |
+
function model_forward(config::ModelConfig, ps, x, caches)
|
| 385 |
T = size(x, 1) # x: (seq_len, batch) of integer token IDs
|
| 386 |
|
| 387 |
# Token embedding: (seq_len, batch) β (embed_dim, seq_len, batch)
|
| 388 |
h = ps.tok_emb.weight[:, x]
|
| 389 |
|
|
|
|
|
|
|
|
|
|
| 390 |
# Monarch blocks
|
| 391 |
for i in 1:config.n_layers
|
| 392 |
name = Symbol("block_$i")
|
|
|
|
| 396 |
normed = rmsnorm_forward(h, bp.ln1.weight)
|
| 397 |
mixed = monarch_sequence_mixer_forward(normed, bp.seq_mixer,
|
| 398 |
config.n_monarch_heads,
|
| 399 |
+
caches.monarch_ms[i])
|
| 400 |
h = h .+ mixed
|
| 401 |
|
| 402 |
# Pre-norm FFN + residual
|
|
|
|
| 463 |
temperature::Float64=0.8,
|
| 464 |
top_k::Int=0,
|
| 465 |
top_p::Float64=1.0,
|
| 466 |
+
on_token=nothing,
|
| 467 |
+
caches=nothing)
|
| 468 |
tokens = encode(tokenizer, prompt)
|
| 469 |
if isempty(tokens)
|
| 470 |
tokens = [rand(1:tokenizer_vocab_size(tokenizer))]
|
| 471 |
end
|
| 472 |
|
| 473 |
+
# Use provided caches or compute them once
|
| 474 |
+
if caches === nothing
|
| 475 |
+
caches = precompute_inference_caches(config, ps)
|
| 476 |
+
end
|
| 477 |
+
|
| 478 |
generated = String[]
|
| 479 |
|
| 480 |
for _ in 1:max_tokens
|
|
|
|
| 485 |
end
|
| 486 |
|
| 487 |
x = reshape(ctx, :, 1)
|
| 488 |
+
logits = model_forward(config, ps, x, caches)
|
| 489 |
next_logits = Vector{Float32}(logits[:, end, 1])
|
| 490 |
|
| 491 |
if temperature != 1.0
|
server.jl
CHANGED
|
@@ -72,6 +72,7 @@ const INF_MODEL = load_inference_model(CKPT_PATH, CONFIG_PATH, VOCAB_PATH, MERGE
|
|
| 72 |
const CONFIG = INF_MODEL.config
|
| 73 |
const PS = INF_MODEL.ps
|
| 74 |
const TOKENIZER = INF_MODEL.tokenizer
|
|
|
|
| 75 |
const MODEL_CREATED_AT = Int(floor(time()))
|
| 76 |
|
| 77 |
println("\nModel ready: arch=$(CONFIG.arch), vocab=$(CONFIG.vocab_size), embd=$(CONFIG.embed_dim), " *
|
|
@@ -217,6 +218,7 @@ function handle_request(request::HTTP.Request)
|
|
| 217 |
|
| 218 |
generate_streaming(CONFIG, PS, TOKENIZER, prompt_text;
|
| 219 |
max_tokens, temperature, top_k=top_k_val, top_p=top_p_val,
|
|
|
|
| 220 |
on_token = function(token_str)
|
| 221 |
token_count[] += 1
|
| 222 |
chunk = Dict(
|
|
@@ -269,7 +271,8 @@ function handle_request(request::HTTP.Request)
|
|
| 269 |
total_completion_tokens = 0
|
| 270 |
for i in 1:n_completions
|
| 271 |
text = generate_streaming(CONFIG, PS, TOKENIZER, prompt_text;
|
| 272 |
-
max_tokens, temperature, top_k=top_k_val, top_p=top_p_val
|
|
|
|
| 273 |
finish_reason = "length" # generate_streaming always produces exactly max_tokens tokens
|
| 274 |
push!(choices, Dict(
|
| 275 |
"index" => i - 1,
|
|
|
|
| 72 |
const CONFIG = INF_MODEL.config
|
| 73 |
const PS = INF_MODEL.ps
|
| 74 |
const TOKENIZER = INF_MODEL.tokenizer
|
| 75 |
+
const CACHES = INF_MODEL.caches
|
| 76 |
const MODEL_CREATED_AT = Int(floor(time()))
|
| 77 |
|
| 78 |
println("\nModel ready: arch=$(CONFIG.arch), vocab=$(CONFIG.vocab_size), embd=$(CONFIG.embed_dim), " *
|
|
|
|
| 218 |
|
| 219 |
generate_streaming(CONFIG, PS, TOKENIZER, prompt_text;
|
| 220 |
max_tokens, temperature, top_k=top_k_val, top_p=top_p_val,
|
| 221 |
+
caches=CACHES,
|
| 222 |
on_token = function(token_str)
|
| 223 |
token_count[] += 1
|
| 224 |
chunk = Dict(
|
|
|
|
| 271 |
total_completion_tokens = 0
|
| 272 |
for i in 1:n_completions
|
| 273 |
text = generate_streaming(CONFIG, PS, TOKENIZER, prompt_text;
|
| 274 |
+
max_tokens, temperature, top_k=top_k_val, top_p=top_p_val,
|
| 275 |
+
caches=CACHES)
|
| 276 |
finish_reason = "length" # generate_streaming always produces exactly max_tokens tokens
|
| 277 |
push!(choices, Dict(
|
| 278 |
"index" => i - 1,
|