LisaMegaWatts committed on
Commit
76b7110
·
verified ·
1 Parent(s): f0aedd4

Cache Monarch matrices + causal mask for faster inference

Browse files
Files changed (3) hide show
  1. checkpoint.jl +7 -1
  2. model.jl +49 -24
  3. server.jl +4 -1
checkpoint.jl CHANGED
@@ -94,5 +94,11 @@ function load_inference_model(ckpt_path::String, config_path::String,
94
  println(" Adjusted vocab_size to $(config.vocab_size) from embedding weight")
95
  end
96
 
97
- return (; config, ps, tokenizer, step, val_loss)
 
 
 
 
 
 
98
  end
 
94
  println(" Adjusted vocab_size to $(config.vocab_size) from embedding weight")
95
  end
96
 
97
+ # Pre-compute inference caches (Monarch matrices + causal mask)
98
+ println("Pre-computing inference caches ...")
99
+ caches = precompute_inference_caches(config, ps)
100
+ n_cached = config.n_layers * config.n_monarch_heads
101
+ println(" Cached $n_cached Monarch matrices ($(config.context_length)x$(config.context_length))")
102
+
103
+ return (; config, ps, tokenizer, step, val_loss, caches)
104
  end
model.jl CHANGED
@@ -305,31 +305,53 @@ function causal_depthwise_conv1d(x, kernel)
305
  end
306
 
307
  # ═══════════════════════════════════════════════════════════════════
308
- # Monarch Sequence Mixer forward pass
309
  # ═══════════════════════════════════════════════════════════════════
310
 
311
- function monarch_sequence_mixer_forward(x, ps, n_heads::Int, context_length::Int, mask)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
312
  D, T, B = size(x)
313
  H = n_heads
314
- HD = D ÷ H # channels per head
315
- p = isqrt(context_length)
316
 
317
  # 1. Causal depthwise conv for local context
318
  conv_out = causal_depthwise_conv1d(x, ps.conv.kernel)
319
 
320
- # 2. Multi-head Monarch mixing for global context
321
  monarch_slices = map(1:H) do i
322
- name = Symbol("head_$i")
323
- ps_m = getproperty(ps.monarchs, name)
324
-
325
- # Realize full T_max × T_max Monarch matrix
326
- M = monarch_realize(ps_m.L1, ps_m.L2, p)
327
-
328
- # Apply causal mask
329
- M_causal = M .* mask
330
-
331
- # Slice to actual sequence length T (for generation where T < context_length)
332
- M_t = M_causal[1:T, 1:T]
333
 
334
  # Extract this head's channel slice: (HD, T, B)
335
  ch_start = (i - 1) * HD + 1
@@ -356,18 +378,15 @@ function monarch_sequence_mixer_forward(x, ps, n_heads::Int, context_length::Int
356
  end
357
 
358
  # ═══════════════════════════════════════════════════════════════════
359
- # Full model forward pass
360
  # ═══════════════════════════════════════════════════════════════════
361
 
362
- function model_forward(config::ModelConfig, ps, x)
363
  T = size(x, 1) # x: (seq_len, batch) of integer token IDs
364
 
365
  # Token embedding: (seq_len, batch) → (embed_dim, seq_len, batch)
366
  h = ps.tok_emb.weight[:, x]
367
 
368
- # Causal mask (multiplicative 0/1 for Monarch)
369
- mask = make_causal_mask(config.context_length)
370
-
371
  # Monarch blocks
372
  for i in 1:config.n_layers
373
  name = Symbol("block_$i")
@@ -377,7 +396,7 @@ function model_forward(config::ModelConfig, ps, x)
377
  normed = rmsnorm_forward(h, bp.ln1.weight)
378
  mixed = monarch_sequence_mixer_forward(normed, bp.seq_mixer,
379
  config.n_monarch_heads,
380
- config.context_length, mask)
381
  h = h .+ mixed
382
 
383
  # Pre-norm FFN + residual
@@ -444,12 +463,18 @@ function generate_streaming(config::ModelConfig, ps,
444
  temperature::Float64=0.8,
445
  top_k::Int=0,
446
  top_p::Float64=1.0,
447
- on_token=nothing)
 
448
  tokens = encode(tokenizer, prompt)
449
  if isempty(tokens)
450
  tokens = [rand(1:tokenizer_vocab_size(tokenizer))]
451
  end
452
 
 
 
 
 
 
453
  generated = String[]
454
 
455
  for _ in 1:max_tokens
@@ -460,7 +485,7 @@ function generate_streaming(config::ModelConfig, ps,
460
  end
461
 
462
  x = reshape(ctx, :, 1)
463
- logits = model_forward(config, ps, x)
464
  next_logits = Vector{Float32}(logits[:, end, 1])
465
 
466
  if temperature != 1.0
 
305
  end
306
 
307
  # ═══════════════════════════════════════════════════════════════════
308
+ # Pre-compute inference caches (Monarch matrices + causal mask)
309
  # ═══════════════════════════════════════════════════════════════════
310
 
311
+ """
312
+ precompute_inference_caches(config, ps) -> NamedTuple
313
+
314
+ Pre-realize all Monarch matrices and apply causal mask once at startup.
315
+ Avoids recomputing them on every forward pass during generation.
316
+ """
317
+ function precompute_inference_caches(config::ModelConfig, ps)
318
+ p = isqrt(config.context_length)
319
+ mask = make_causal_mask(config.context_length)
320
+
321
+ # Pre-realize all Monarch matrices: monarch_ms[layer][head] = masked TΓ—T matrix
322
+ monarch_ms = Vector{Vector{Matrix{Float32}}}(undef, config.n_layers)
323
+ for i in 1:config.n_layers
324
+ name = Symbol("block_$i")
325
+ bp = getproperty(ps.blocks, name)
326
+ layer_ms = Vector{Matrix{Float32}}(undef, config.n_monarch_heads)
327
+ for j in 1:config.n_monarch_heads
328
+ head_name = Symbol("head_$j")
329
+ ps_m = getproperty(bp.seq_mixer.monarchs, head_name)
330
+ M = monarch_realize(ps_m.L1, ps_m.L2, p) .* mask
331
+ layer_ms[j] = M
332
+ end
333
+ monarch_ms[i] = layer_ms
334
+ end
335
+
336
+ return (; mask, monarch_ms)
337
+ end
338
+
339
+ # ═══════════════════════════════════════════════════════════════════
340
+ # Monarch Sequence Mixer forward pass (uses cached matrices)
341
+ # ═══════════════════════════════════════════════════════════════════
342
+
343
+ function monarch_sequence_mixer_forward(x, ps, n_heads::Int, monarch_ms_layer)
344
  D, T, B = size(x)
345
  H = n_heads
346
+ HD = D ÷ H
 
347
 
348
  # 1. Causal depthwise conv for local context
349
  conv_out = causal_depthwise_conv1d(x, ps.conv.kernel)
350
 
351
+ # 2. Multi-head Monarch mixing (pre-realized matrices)
352
  monarch_slices = map(1:H) do i
353
+ # Slice cached matrix to actual sequence length
354
+ M_t = monarch_ms_layer[i][1:T, 1:T]
 
 
 
 
 
 
 
 
 
355
 
356
  # Extract this head's channel slice: (HD, T, B)
357
  ch_start = (i - 1) * HD + 1
 
378
  end
379
 
380
  # ═══════════════════════════════════════════════════════════════════
381
+ # Full model forward pass (uses cached data)
382
  # ═══════════════════════════════════════════════════════════════════
383
 
384
+ function model_forward(config::ModelConfig, ps, x, caches)
385
  T = size(x, 1) # x: (seq_len, batch) of integer token IDs
386
 
387
  # Token embedding: (seq_len, batch) → (embed_dim, seq_len, batch)
388
  h = ps.tok_emb.weight[:, x]
389
 
 
 
 
390
  # Monarch blocks
391
  for i in 1:config.n_layers
392
  name = Symbol("block_$i")
 
396
  normed = rmsnorm_forward(h, bp.ln1.weight)
397
  mixed = monarch_sequence_mixer_forward(normed, bp.seq_mixer,
398
  config.n_monarch_heads,
399
+ caches.monarch_ms[i])
400
  h = h .+ mixed
401
 
402
  # Pre-norm FFN + residual
 
463
  temperature::Float64=0.8,
464
  top_k::Int=0,
465
  top_p::Float64=1.0,
466
+ on_token=nothing,
467
+ caches=nothing)
468
  tokens = encode(tokenizer, prompt)
469
  if isempty(tokens)
470
  tokens = [rand(1:tokenizer_vocab_size(tokenizer))]
471
  end
472
 
473
+ # Use provided caches or compute them once
474
+ if caches === nothing
475
+ caches = precompute_inference_caches(config, ps)
476
+ end
477
+
478
  generated = String[]
479
 
480
  for _ in 1:max_tokens
 
485
  end
486
 
487
  x = reshape(ctx, :, 1)
488
+ logits = model_forward(config, ps, x, caches)
489
  next_logits = Vector{Float32}(logits[:, end, 1])
490
 
491
  if temperature != 1.0
server.jl CHANGED
@@ -72,6 +72,7 @@ const INF_MODEL = load_inference_model(CKPT_PATH, CONFIG_PATH, VOCAB_PATH, MERGE
72
  const CONFIG = INF_MODEL.config
73
  const PS = INF_MODEL.ps
74
  const TOKENIZER = INF_MODEL.tokenizer
 
75
  const MODEL_CREATED_AT = Int(floor(time()))
76
 
77
  println("\nModel ready: arch=$(CONFIG.arch), vocab=$(CONFIG.vocab_size), embd=$(CONFIG.embed_dim), " *
@@ -217,6 +218,7 @@ function handle_request(request::HTTP.Request)
217
 
218
  generate_streaming(CONFIG, PS, TOKENIZER, prompt_text;
219
  max_tokens, temperature, top_k=top_k_val, top_p=top_p_val,
 
220
  on_token = function(token_str)
221
  token_count[] += 1
222
  chunk = Dict(
@@ -269,7 +271,8 @@ function handle_request(request::HTTP.Request)
269
  total_completion_tokens = 0
270
  for i in 1:n_completions
271
  text = generate_streaming(CONFIG, PS, TOKENIZER, prompt_text;
272
- max_tokens, temperature, top_k=top_k_val, top_p=top_p_val)
 
273
  finish_reason = "length" # generate_streaming always produces exactly max_tokens tokens
274
  push!(choices, Dict(
275
  "index" => i - 1,
 
72
  const CONFIG = INF_MODEL.config
73
  const PS = INF_MODEL.ps
74
  const TOKENIZER = INF_MODEL.tokenizer
75
+ const CACHES = INF_MODEL.caches
76
  const MODEL_CREATED_AT = Int(floor(time()))
77
 
78
  println("\nModel ready: arch=$(CONFIG.arch), vocab=$(CONFIG.vocab_size), embd=$(CONFIG.embed_dim), " *
 
218
 
219
  generate_streaming(CONFIG, PS, TOKENIZER, prompt_text;
220
  max_tokens, temperature, top_k=top_k_val, top_p=top_p_val,
221
+ caches=CACHES,
222
  on_token = function(token_str)
223
  token_count[] += 1
224
  chunk = Dict(
 
271
  total_completion_tokens = 0
272
  for i in 1:n_completions
273
  text = generate_streaming(CONFIG, PS, TOKENIZER, prompt_text;
274
+ max_tokens, temperature, top_k=top_k_val, top_p=top_p_val,
275
+ caches=CACHES)
276
  finish_reason = "length" # generate_streaming always produces exactly max_tokens tokens
277
  push!(choices, Dict(
278
  "index" => i - 1,