Spaces:
Running
Running
Cache Monarch matrix + causal mask for faster inference
Browse files
- checkpoint.jl +6 -1
- model.jl +41 -20
- server.jl +4 -1
checkpoint.jl
CHANGED
|
@@ -95,5 +95,10 @@ function load_inference_model(ckpt_path::String, config_path::String,
|
|
| 95 |
println(" Adjusted vocab_size to $(config.vocab_size) from embedding weight")
|
| 96 |
end
|
| 97 |
|
| 98 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
end
|
|
|
|
| 95 |
println(" Adjusted vocab_size to $(config.vocab_size) from embedding weight")
|
| 96 |
end
|
| 97 |
|
| 98 |
+
# Pre-compute inference caches (Monarch matrices + causal mask)
|
| 99 |
+
println("Pre-computing inference caches ...")
|
| 100 |
+
caches = precompute_inference_caches(config, ps)
|
| 101 |
+
println(" Cached $(config.n_layers) Monarch matrices ($(config.context_length)x$(config.context_length))")
|
| 102 |
+
|
| 103 |
+
return (; config, ps, tokenizer, step, val_loss, caches)
|
| 104 |
end
|
model.jl
CHANGED
|
@@ -379,28 +379,46 @@ function organelle_gate_forward(organelle_outputs, logits)
|
|
| 379 |
end
|
| 380 |
|
| 381 |
# ───────────────────────────────────────────────────────────────────
|
| 382 |
-
#
|
| 383 |
# ───────────────────────────────────────────────────────────────────
|
| 384 |
|
| 385 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 386 |
D, T, B = size(x)
|
| 387 |
-
p = isqrt(context_length)
|
| 388 |
|
| 389 |
# ── Organelle 1: CausalConv (local n-gram patterns) ──
|
| 390 |
conv_out = causal_depthwise_conv1d(x, ps.conv.kernel)
|
| 391 |
|
| 392 |
# ── Organelle 2: MonarchMatrix (global structured mixing, single-head) ──
|
| 393 |
-
#
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
# Apply causal mask (multiplicative 0/1)
|
| 397 |
-
M_causal = M .* mask
|
| 398 |
-
|
| 399 |
-
# Slice to actual sequence length (handles generation where T < context_length)
|
| 400 |
-
M_t = M_causal[1:T, 1:T]
|
| 401 |
|
| 402 |
# Single-head: apply Monarch to ALL channels at once
|
| 403 |
-
# x: (D, T, B) → permute to (T, D, B) → flatten → matmul → reshape back
|
| 404 |
x_seq = reshape(permutedims(x, (2, 1, 3)), T, D * B) # (T, D*B)
|
| 405 |
y_monarch = M_t * x_seq # (T, D*B)
|
| 406 |
monarch_out = permutedims(reshape(y_monarch, T, D, B), (2, 1, 3)) # (D, T, B)
|
|
@@ -415,18 +433,15 @@ function symbio_sequence_mixer_forward(x, ps, context_length::Int, mask)
|
|
| 415 |
end
|
| 416 |
|
| 417 |
# ───────────────────────────────────────────────────────────────────
|
| 418 |
-
# Full model forward pass
|
| 419 |
# ───────────────────────────────────────────────────────────────────
|
| 420 |
|
| 421 |
-
function model_forward(config::ModelConfig, ps, x)
|
| 422 |
T = size(x, 1) # x: (seq_len, batch) of integer token IDs
|
| 423 |
|
| 424 |
# Token embedding: (seq_len, batch) → (embed_dim, seq_len, batch)
|
| 425 |
h = ps.tok_emb.weight[:, x]
|
| 426 |
|
| 427 |
-
# Causal mask (multiplicative 0/1 for symbiogenesis)
|
| 428 |
-
mask = make_causal_mask(config.context_length)
|
| 429 |
-
|
| 430 |
# Symbiogenesis blocks
|
| 431 |
for i in 1:config.n_layers
|
| 432 |
name = Symbol("block_$i")
|
|
@@ -435,7 +450,7 @@ function model_forward(config::ModelConfig, ps, x)
|
|
| 435 |
# Pre-norm sequence mixing + residual
|
| 436 |
normed = rmsnorm_forward(h, bp.ln1.weight)
|
| 437 |
mixed = symbio_sequence_mixer_forward(normed, bp.seq_mixer,
|
| 438 |
-
|
| 439 |
h = h .+ mixed
|
| 440 |
|
| 441 |
# Pre-norm FFN + residual
|
|
@@ -502,12 +517,18 @@ function generate_streaming(config::ModelConfig, ps,
|
|
| 502 |
temperature::Float64=0.8,
|
| 503 |
top_k::Int=0,
|
| 504 |
top_p::Float64=1.0,
|
| 505 |
-
on_token=nothing
|
|
|
|
| 506 |
tokens = encode(tokenizer, prompt)
|
| 507 |
if isempty(tokens)
|
| 508 |
tokens = [rand(1:tokenizer_vocab_size(tokenizer))]
|
| 509 |
end
|
| 510 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 511 |
generated = String[]
|
| 512 |
|
| 513 |
for _ in 1:max_tokens
|
|
@@ -518,7 +539,7 @@ function generate_streaming(config::ModelConfig, ps,
|
|
| 518 |
end
|
| 519 |
|
| 520 |
x = reshape(ctx, :, 1)
|
| 521 |
-
logits = model_forward(config, ps, x)
|
| 522 |
next_logits = Vector{Float32}(logits[:, end, 1])
|
| 523 |
|
| 524 |
if temperature != 1.0
|
|
|
|
| 379 |
end
|
| 380 |
|
| 381 |
# ───────────────────────────────────────────────────────────────────
|
| 382 |
+
# Pre-compute inference caches (Monarch matrices + causal mask)
|
| 383 |
# ───────────────────────────────────────────────────────────────────
|
| 384 |
|
| 385 |
+
"""
|
| 386 |
+
precompute_inference_caches(config, ps) -> NamedTuple
|
| 387 |
+
|
| 388 |
+
Pre-realize Monarch matrices and apply causal mask once at startup.
|
| 389 |
+
Avoids recomputing them on every forward pass during generation.
|
| 390 |
+
"""
|
| 391 |
+
function precompute_inference_caches(config::ModelConfig, ps)
|
| 392 |
+
p = isqrt(config.context_length)
|
| 393 |
+
mask = make_causal_mask(config.context_length)
|
| 394 |
+
|
| 395 |
+
# Pre-realize Monarch matrix per layer (single-head): monarch_ms[layer] = masked T×T
|
| 396 |
+
monarch_ms = Vector{Matrix{Float32}}(undef, config.n_layers)
|
| 397 |
+
for i in 1:config.n_layers
|
| 398 |
+
name = Symbol("block_$i")
|
| 399 |
+
bp = getproperty(ps.blocks, name)
|
| 400 |
+
M = monarch_realize(bp.seq_mixer.monarch.L1, bp.seq_mixer.monarch.L2, p) .* mask
|
| 401 |
+
monarch_ms[i] = M
|
| 402 |
+
end
|
| 403 |
+
|
| 404 |
+
return (; mask, monarch_ms)
|
| 405 |
+
end
|
| 406 |
+
|
| 407 |
+
# ───────────────────────────────────────────────────────────────────
|
| 408 |
+
# Symbiogenesis Sequence Mixer β 3 organelles + gate (uses caches)
|
| 409 |
+
# ───────────────────────────────────────────────────────────────────
|
| 410 |
+
|
| 411 |
+
function symbio_sequence_mixer_forward(x, ps, monarch_M)
|
| 412 |
D, T, B = size(x)
|
|
|
|
| 413 |
|
| 414 |
# ── Organelle 1: CausalConv (local n-gram patterns) ──
|
| 415 |
conv_out = causal_depthwise_conv1d(x, ps.conv.kernel)
|
| 416 |
|
| 417 |
# ── Organelle 2: MonarchMatrix (global structured mixing, single-head) ──
|
| 418 |
+
# Use pre-realized + masked matrix, slice to actual T
|
| 419 |
+
M_t = monarch_M[1:T, 1:T]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 420 |
|
| 421 |
# Single-head: apply Monarch to ALL channels at once
|
|
|
|
| 422 |
x_seq = reshape(permutedims(x, (2, 1, 3)), T, D * B) # (T, D*B)
|
| 423 |
y_monarch = M_t * x_seq # (T, D*B)
|
| 424 |
monarch_out = permutedims(reshape(y_monarch, T, D, B), (2, 1, 3)) # (D, T, B)
|
|
|
|
| 433 |
end
|
| 434 |
|
| 435 |
# ───────────────────────────────────────────────────────────────────
|
| 436 |
+
# Full model forward pass (uses cached data)
|
| 437 |
# ───────────────────────────────────────────────────────────────────
|
| 438 |
|
| 439 |
+
function model_forward(config::ModelConfig, ps, x, caches)
|
| 440 |
T = size(x, 1) # x: (seq_len, batch) of integer token IDs
|
| 441 |
|
| 442 |
# Token embedding: (seq_len, batch) → (embed_dim, seq_len, batch)
|
| 443 |
h = ps.tok_emb.weight[:, x]
|
| 444 |
|
|
|
|
|
|
|
|
|
|
| 445 |
# Symbiogenesis blocks
|
| 446 |
for i in 1:config.n_layers
|
| 447 |
name = Symbol("block_$i")
|
|
|
|
| 450 |
# Pre-norm sequence mixing + residual
|
| 451 |
normed = rmsnorm_forward(h, bp.ln1.weight)
|
| 452 |
mixed = symbio_sequence_mixer_forward(normed, bp.seq_mixer,
|
| 453 |
+
caches.monarch_ms[i])
|
| 454 |
h = h .+ mixed
|
| 455 |
|
| 456 |
# Pre-norm FFN + residual
|
|
|
|
| 517 |
temperature::Float64=0.8,
|
| 518 |
top_k::Int=0,
|
| 519 |
top_p::Float64=1.0,
|
| 520 |
+
on_token=nothing,
|
| 521 |
+
caches=nothing)
|
| 522 |
tokens = encode(tokenizer, prompt)
|
| 523 |
if isempty(tokens)
|
| 524 |
tokens = [rand(1:tokenizer_vocab_size(tokenizer))]
|
| 525 |
end
|
| 526 |
|
| 527 |
+
# Use provided caches or compute them once
|
| 528 |
+
if caches === nothing
|
| 529 |
+
caches = precompute_inference_caches(config, ps)
|
| 530 |
+
end
|
| 531 |
+
|
| 532 |
generated = String[]
|
| 533 |
|
| 534 |
for _ in 1:max_tokens
|
|
|
|
| 539 |
end
|
| 540 |
|
| 541 |
x = reshape(ctx, :, 1)
|
| 542 |
+
logits = model_forward(config, ps, x, caches)
|
| 543 |
next_logits = Vector{Float32}(logits[:, end, 1])
|
| 544 |
|
| 545 |
if temperature != 1.0
|
server.jl
CHANGED
|
@@ -73,6 +73,7 @@ const INF_MODEL = load_inference_model(CKPT_PATH, CONFIG_PATH, VOCAB_PATH, MERGE
|
|
| 73 |
const CONFIG = INF_MODEL.config
|
| 74 |
const PS = INF_MODEL.ps
|
| 75 |
const TOKENIZER = INF_MODEL.tokenizer
|
|
|
|
| 76 |
const MODEL_CREATED_AT = Int(floor(time()))
|
| 77 |
|
| 78 |
println("\nModel ready: arch=$(CONFIG.arch), vocab=$(CONFIG.vocab_size), embd=$(CONFIG.embed_dim), " *
|
|
@@ -225,6 +226,7 @@ function handle_request(request::HTTP.Request)
|
|
| 225 |
|
| 226 |
generate_streaming(CONFIG, PS, TOKENIZER, prompt_text;
|
| 227 |
max_tokens, temperature, top_k=top_k_val, top_p=top_p_val,
|
|
|
|
| 228 |
on_token = function(token_str)
|
| 229 |
token_count[] += 1
|
| 230 |
chunk = Dict(
|
|
@@ -277,7 +279,8 @@ function handle_request(request::HTTP.Request)
|
|
| 277 |
total_completion_tokens = 0
|
| 278 |
for i in 1:n_completions
|
| 279 |
text = generate_streaming(CONFIG, PS, TOKENIZER, prompt_text;
|
| 280 |
-
max_tokens, temperature, top_k=top_k_val, top_p=top_p_val
|
|
|
|
| 281 |
finish_reason = "length" # generate_streaming always produces exactly max_tokens tokens
|
| 282 |
push!(choices, Dict(
|
| 283 |
"index" => i - 1,
|
|
|
|
| 73 |
const CONFIG = INF_MODEL.config
|
| 74 |
const PS = INF_MODEL.ps
|
| 75 |
const TOKENIZER = INF_MODEL.tokenizer
|
| 76 |
+
const CACHES = INF_MODEL.caches
|
| 77 |
const MODEL_CREATED_AT = Int(floor(time()))
|
| 78 |
|
| 79 |
println("\nModel ready: arch=$(CONFIG.arch), vocab=$(CONFIG.vocab_size), embd=$(CONFIG.embed_dim), " *
|
|
|
|
| 226 |
|
| 227 |
generate_streaming(CONFIG, PS, TOKENIZER, prompt_text;
|
| 228 |
max_tokens, temperature, top_k=top_k_val, top_p=top_p_val,
|
| 229 |
+
caches=CACHES,
|
| 230 |
on_token = function(token_str)
|
| 231 |
token_count[] += 1
|
| 232 |
chunk = Dict(
|
|
|
|
| 279 |
total_completion_tokens = 0
|
| 280 |
for i in 1:n_completions
|
| 281 |
text = generate_streaming(CONFIG, PS, TOKENIZER, prompt_text;
|
| 282 |
+
max_tokens, temperature, top_k=top_k_val, top_p=top_p_val,
|
| 283 |
+
caches=CACHES)
|
| 284 |
finish_reason = "length" # generate_streaming always produces exactly max_tokens tokens
|
| 285 |
push!(choices, Dict(
|
| 286 |
"index" => i - 1,
|