LisaMegaWatts committed on
Commit
676f15f
·
verified ·
1 Parent(s): d167fb8

Cache Monarch matrix + causal mask for faster inference

Browse files
Files changed (3) hide show
  1. checkpoint.jl +6 -1
  2. model.jl +41 -20
  3. server.jl +4 -1
checkpoint.jl CHANGED
@@ -95,5 +95,10 @@ function load_inference_model(ckpt_path::String, config_path::String,
95
  println(" Adjusted vocab_size to $(config.vocab_size) from embedding weight")
96
  end
97
 
98
- return (; config, ps, tokenizer, step, val_loss)
 
 
 
 
 
99
  end
 
95
  println(" Adjusted vocab_size to $(config.vocab_size) from embedding weight")
96
  end
97
 
98
+ # Pre-compute inference caches (Monarch matrices + causal mask)
99
+ println("Pre-computing inference caches ...")
100
+ caches = precompute_inference_caches(config, ps)
101
+ println(" Cached $(config.n_layers) Monarch matrices ($(config.context_length)x$(config.context_length))")
102
+
103
+ return (; config, ps, tokenizer, step, val_loss, caches)
104
  end
model.jl CHANGED
@@ -379,28 +379,46 @@ function organelle_gate_forward(organelle_outputs, logits)
379
  end
380
 
381
  # ═══════════════════════════════════════════════════════════════════
382
- # Symbiogenesis Sequence Mixer — 3 organelles + gate
383
  # ═══════════════════════════════════════════════════════════════════
384
 
385
- function symbio_sequence_mixer_forward(x, ps, context_length::Int, mask)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
386
  D, T, B = size(x)
387
- p = isqrt(context_length)
388
 
389
  # ── Organelle 1: CausalConv (local n-gram patterns) ──
390
  conv_out = causal_depthwise_conv1d(x, ps.conv.kernel)
391
 
392
  # ── Organelle 2: MonarchMatrix (global structured mixing, single-head) ──
393
- # Realize full T_max × T_max Monarch matrix
394
- M = monarch_realize(ps.monarch.L1, ps.monarch.L2, p)
395
-
396
- # Apply causal mask (multiplicative 0/1)
397
- M_causal = M .* mask
398
-
399
- # Slice to actual sequence length (handles generation where T < context_length)
400
- M_t = M_causal[1:T, 1:T]
401
 
402
  # Single-head: apply Monarch to ALL channels at once
403
- # x: (D, T, B) → permute to (T, D, B) → flatten → matmul → reshape back
404
  x_seq = reshape(permutedims(x, (2, 1, 3)), T, D * B) # (T, D*B)
405
  y_monarch = M_t * x_seq # (T, D*B)
406
  monarch_out = permutedims(reshape(y_monarch, T, D, B), (2, 1, 3)) # (D, T, B)
@@ -415,18 +433,15 @@ function symbio_sequence_mixer_forward(x, ps, context_length::Int, mask)
415
  end
416
 
417
  # ═══════════════════════════════════════════════════════════════════
418
- # Full model forward pass
419
  # ═══════════════════════════════════════════════════════════════════
420
 
421
- function model_forward(config::ModelConfig, ps, x)
422
  T = size(x, 1) # x: (seq_len, batch) of integer token IDs
423
 
424
  # Token embedding: (seq_len, batch) → (embed_dim, seq_len, batch)
425
  h = ps.tok_emb.weight[:, x]
426
 
427
- # Causal mask (multiplicative 0/1 for symbiogenesis)
428
- mask = make_causal_mask(config.context_length)
429
-
430
  # Symbiogenesis blocks
431
  for i in 1:config.n_layers
432
  name = Symbol("block_$i")
@@ -435,7 +450,7 @@ function model_forward(config::ModelConfig, ps, x)
435
  # Pre-norm sequence mixing + residual
436
  normed = rmsnorm_forward(h, bp.ln1.weight)
437
  mixed = symbio_sequence_mixer_forward(normed, bp.seq_mixer,
438
- config.context_length, mask)
439
  h = h .+ mixed
440
 
441
  # Pre-norm FFN + residual
@@ -502,12 +517,18 @@ function generate_streaming(config::ModelConfig, ps,
502
  temperature::Float64=0.8,
503
  top_k::Int=0,
504
  top_p::Float64=1.0,
505
- on_token=nothing)
 
506
  tokens = encode(tokenizer, prompt)
507
  if isempty(tokens)
508
  tokens = [rand(1:tokenizer_vocab_size(tokenizer))]
509
  end
510
 
 
 
 
 
 
511
  generated = String[]
512
 
513
  for _ in 1:max_tokens
@@ -518,7 +539,7 @@ function generate_streaming(config::ModelConfig, ps,
518
  end
519
 
520
  x = reshape(ctx, :, 1)
521
- logits = model_forward(config, ps, x)
522
  next_logits = Vector{Float32}(logits[:, end, 1])
523
 
524
  if temperature != 1.0
 
379
  end
380
 
381
  # ═══════════════════════════════════════════════════════════════════
382
+ # Pre-compute inference caches (Monarch matrices + causal mask)
383
  # ═══════════════════════════════════════════════════════════════════
384
 
385
+ """
386
+ precompute_inference_caches(config, ps) -> NamedTuple
387
+
388
+ Pre-realize Monarch matrices and apply causal mask once at startup.
389
+ Avoids recomputing them on every forward pass during generation.
390
+ """
391
+ function precompute_inference_caches(config::ModelConfig, ps)
392
+ p = isqrt(config.context_length)
393
+ mask = make_causal_mask(config.context_length)
394
+
395
+ # Pre-realize Monarch matrix per layer (single-head): monarch_ms[layer] = masked T×T
396
+ monarch_ms = Vector{Matrix{Float32}}(undef, config.n_layers)
397
+ for i in 1:config.n_layers
398
+ name = Symbol("block_$i")
399
+ bp = getproperty(ps.blocks, name)
400
+ M = monarch_realize(bp.seq_mixer.monarch.L1, bp.seq_mixer.monarch.L2, p) .* mask
401
+ monarch_ms[i] = M
402
+ end
403
+
404
+ return (; mask, monarch_ms)
405
+ end
406
+
407
+ # ═══════════════════════════════════════════════════════════════════
408
+ # Symbiogenesis Sequence Mixer — 3 organelles + gate (uses caches)
409
+ # ═══════════════════════════════════════════════════════════════════
410
+
411
+ function symbio_sequence_mixer_forward(x, ps, monarch_M)
412
  D, T, B = size(x)
 
413
 
414
  # ── Organelle 1: CausalConv (local n-gram patterns) ──
415
  conv_out = causal_depthwise_conv1d(x, ps.conv.kernel)
416
 
417
  # ── Organelle 2: MonarchMatrix (global structured mixing, single-head) ──
418
+ # Use pre-realized + masked matrix, slice to actual T
419
+ M_t = monarch_M[1:T, 1:T]
 
 
 
 
 
 
420
 
421
  # Single-head: apply Monarch to ALL channels at once
 
422
  x_seq = reshape(permutedims(x, (2, 1, 3)), T, D * B) # (T, D*B)
423
  y_monarch = M_t * x_seq # (T, D*B)
424
  monarch_out = permutedims(reshape(y_monarch, T, D, B), (2, 1, 3)) # (D, T, B)
 
433
  end
434
 
435
  # ═══════════════════════════════════════════════════════════════════
436
+ # Full model forward pass (uses cached data)
437
  # ═══════════════════════════════════════════════════════════════════
438
 
439
+ function model_forward(config::ModelConfig, ps, x, caches)
440
  T = size(x, 1) # x: (seq_len, batch) of integer token IDs
441
 
442
  # Token embedding: (seq_len, batch) → (embed_dim, seq_len, batch)
443
  h = ps.tok_emb.weight[:, x]
444
 
 
 
 
445
  # Symbiogenesis blocks
446
  for i in 1:config.n_layers
447
  name = Symbol("block_$i")
 
450
  # Pre-norm sequence mixing + residual
451
  normed = rmsnorm_forward(h, bp.ln1.weight)
452
  mixed = symbio_sequence_mixer_forward(normed, bp.seq_mixer,
453
+ caches.monarch_ms[i])
454
  h = h .+ mixed
455
 
456
  # Pre-norm FFN + residual
 
517
  temperature::Float64=0.8,
518
  top_k::Int=0,
519
  top_p::Float64=1.0,
520
+ on_token=nothing,
521
+ caches=nothing)
522
  tokens = encode(tokenizer, prompt)
523
  if isempty(tokens)
524
  tokens = [rand(1:tokenizer_vocab_size(tokenizer))]
525
  end
526
 
527
+ # Use provided caches or compute them once
528
+ if caches === nothing
529
+ caches = precompute_inference_caches(config, ps)
530
+ end
531
+
532
  generated = String[]
533
 
534
  for _ in 1:max_tokens
 
539
  end
540
 
541
  x = reshape(ctx, :, 1)
542
+ logits = model_forward(config, ps, x, caches)
543
  next_logits = Vector{Float32}(logits[:, end, 1])
544
 
545
  if temperature != 1.0
server.jl CHANGED
@@ -73,6 +73,7 @@ const INF_MODEL = load_inference_model(CKPT_PATH, CONFIG_PATH, VOCAB_PATH, MERGE
73
  const CONFIG = INF_MODEL.config
74
  const PS = INF_MODEL.ps
75
  const TOKENIZER = INF_MODEL.tokenizer
 
76
  const MODEL_CREATED_AT = Int(floor(time()))
77
 
78
  println("\nModel ready: arch=$(CONFIG.arch), vocab=$(CONFIG.vocab_size), embd=$(CONFIG.embed_dim), " *
@@ -225,6 +226,7 @@ function handle_request(request::HTTP.Request)
225
 
226
  generate_streaming(CONFIG, PS, TOKENIZER, prompt_text;
227
  max_tokens, temperature, top_k=top_k_val, top_p=top_p_val,
 
228
  on_token = function(token_str)
229
  token_count[] += 1
230
  chunk = Dict(
@@ -277,7 +279,8 @@ function handle_request(request::HTTP.Request)
277
  total_completion_tokens = 0
278
  for i in 1:n_completions
279
  text = generate_streaming(CONFIG, PS, TOKENIZER, prompt_text;
280
- max_tokens, temperature, top_k=top_k_val, top_p=top_p_val)
 
281
  finish_reason = "length" # generate_streaming always produces exactly max_tokens tokens
282
  push!(choices, Dict(
283
  "index" => i - 1,
 
73
  const CONFIG = INF_MODEL.config
74
  const PS = INF_MODEL.ps
75
  const TOKENIZER = INF_MODEL.tokenizer
76
+ const CACHES = INF_MODEL.caches
77
  const MODEL_CREATED_AT = Int(floor(time()))
78
 
79
  println("\nModel ready: arch=$(CONFIG.arch), vocab=$(CONFIG.vocab_size), embd=$(CONFIG.embed_dim), " *
 
226
 
227
  generate_streaming(CONFIG, PS, TOKENIZER, prompt_text;
228
  max_tokens, temperature, top_k=top_k_val, top_p=top_p_val,
229
+ caches=CACHES,
230
  on_token = function(token_str)
231
  token_count[] += 1
232
  chunk = Dict(
 
279
  total_completion_tokens = 0
280
  for i in 1:n_completions
281
  text = generate_streaming(CONFIG, PS, TOKENIZER, prompt_text;
282
+ max_tokens, temperature, top_k=top_k_val, top_p=top_p_val,
283
+ caches=CACHES)
284
  finish_reason = "length" # generate_streaming always produces exactly max_tokens tokens
285
  push!(choices, Dict(
286
  "index" => i - 1,