LisaMegaWatts committed
Commit 3724bdb · verified · 1 Parent(s): 0575c49

Upload checkpoint.jl with huggingface_hub

Files changed (1)
checkpoint.jl +98 -0
checkpoint.jl ADDED
@@ -0,0 +1,98 @@
#=
checkpoint.jl -- Load a Lux-trained MonarchSLM checkpoint for inference

Loads model parameters from JLD2, config from TOML, and the tokenizer from JSON + merges.
Converts Float16 parameters to Float32 for efficient CPU inference.
No RoPE caches needed -- Monarch uses learned position mixing.
=#
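
# Expected on-disk artifacts (inferred from the loaders below):
#   - checkpoint .jld2 with key "parameters" (required); "step" and
#     "best_val_loss" are optional training metadata
#   - config .toml, vocab .json, and an optional BPE merges file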

include("model.jl")   # expected to provide BPETokenizer, load_bpe_tokenizer,
                      # load_char_vocab_json, load_config_toml, ModelConfig, etc.
using JLD2
using JSON3           # used below to parse and sniff the vocab JSON

# ═══════════════════════════════════════════════════════════════════
# Float32 conversion for CPU inference
# ═══════════════════════════════════════════════════════════════════

# Recursively widen Float16 arrays to Float32, leaving every other leaf as-is.
ensure_f32(x::AbstractArray{Float16}) = Float32.(x)
ensure_f32(x::AbstractArray) = x
ensure_f32(x::NamedTuple) = NamedTuple{keys(x)}(map(ensure_f32, values(x)))
ensure_f32(x::Tuple) = map(ensure_f32, x)
ensure_f32(x) = x

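# A quick REPL sanity check (toy values, purely illustrative):
#   julia> toy = (blk = (weight = Float16[1 2; 3 4], bias = Float16[0, 0]),);
#   julia> eltype(ensure_f32(toy).blk.weight)
#   Float32
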
# ═══════════════════════════════════════════════════════════════════
# Tokenizer loading -- auto-detect BPE vs. char based on file format
# ═══════════════════════════════════════════════════════════════════

function load_tokenizer(vocab_path::String, merges_path::String)
    # A merges file alongside the vocab means a full byte-level BPE tokenizer.
    if isfile(merges_path)
        println("Loading BPE tokenizer from $vocab_path + $merges_path ...")
        tok = load_bpe_tokenizer(vocab_path, merges_path)
        println("  BPE vocab_size = $(tok.vocab_size), merges = $(length(tok.merges))")
        return tok
    end

    # No merges file: a JSON object is still a BPE vocab; anything else is
    # treated as a character-level vocab.
    raw_text = read(vocab_path, String)
    parsed = JSON3.read(raw_text)
    if parsed isa AbstractDict
        println("Loading BPE tokenizer from $vocab_path (no merges file) ...")
        tok = load_bpe_tokenizer_no_merges(vocab_path)
        println("  BPE vocab_size = $(tok.vocab_size) (no merges)")
        return tok
    end

    println("Loading character tokenizer from $vocab_path ...")
    tok = load_char_vocab_json(vocab_path)
    println("  char vocab_size = $(tok.vocab_size)")
    return tok
end
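
# Detection summary: merges file present -> full BPE; vocab parses to a JSON
# object -> BPE without merges; anything else (e.g. a JSON array of characters)
# is assumed to be a character-level vocab.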

function load_bpe_tokenizer_no_merges(vocab_path::String)
    encoder = JSON3.read(read(vocab_path, String), Dict{String, Int})
    decoder = Dict{Int, String}(v => k for (k, v) in encoder)
    b2u = _build_byte_to_unicode()
    u2b = Dict{Char, UInt8}(v => k for (k, v) in b2u)
    # GPT-2-style pre-tokenization pattern: contractions, letter runs,
    # digit runs, punctuation runs, and whitespace.
    pat = r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"
    # With no merges, encoding degenerates to whole-token vocabulary lookups.
    return BPETokenizer(encoder, decoder, Tuple{String,String}[],
                        Dict{Tuple{String,String},Int}(), b2u, u2b,
                        length(encoder), pat)
end
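
# Example of the pre-tokenization split (hypothetical REPL session, with `pat`
# as defined above):
#   julia> [m.match for m in eachmatch(pat, "Don't panic!")]
#   ["Don", "'t", " panic", "!"]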

# ═══════════════════════════════════════════════════════════════════
# Load everything needed for inference
# ═══════════════════════════════════════════════════════════════════

function load_inference_model(ckpt_path::String, config_path::String,
                              vocab_path::String, merges_path::String)
    # Tokenizer first -- it determines vocab_size.
    tokenizer = load_tokenizer(vocab_path, merges_path)
    vs = tokenizer_vocab_size(tokenizer)

    # Config, with vocab_size set dynamically from the tokenizer.
    println("Loading config from $config_path ...")
    config = load_config_toml(config_path; vocab_size=vs)
    println("  arch=$(config.arch), embed_dim=$(config.embed_dim), layers=$(config.n_layers)")
    println("  monarch_heads=$(config.n_monarch_heads), conv_kernel=$(config.conv_kernel_size)")
    println("  context_length=$(config.context_length), weight_tying=$(config.weight_tying)")

    # Parameters: Float16 on disk, widened to Float32 for CPU inference.
    println("Loading parameters from $ckpt_path ...")
    ps = ensure_f32(JLD2.load(ckpt_path, "parameters"))

    # Training metadata is optional (older checkpoints may lack these keys).
    step = try JLD2.load(ckpt_path, "step") catch; 0 end
    val_loss = try JLD2.load(ckpt_path, "best_val_loss") catch; Inf end
    println("  step=$step, best_val_loss=$(round(val_loss; digits=4))")

    # Verify that the embedding width matches the configured vocab size.
    emb_shape = size(ps.tok_emb.weight)
    println("  embedding weight: $(emb_shape) (expect $(config.embed_dim) x $(config.vocab_size))")
    if emb_shape[2] != config.vocab_size
        @warn "Vocab size mismatch!" config_vocab=config.vocab_size embedding_vocab=emb_shape[2]
        # Trust the checkpoint: rebuild the config around the actual embedding width.
        config = ModelConfig(config.arch, config.embed_dim, config.n_layers,
                             config.n_monarch_heads, config.conv_kernel_size,
                             config.context_length, emb_shape[2],
                             config.weight_tying, config.bias)
        println("  Adjusted vocab_size to $(config.vocab_size) from embedding weight")
    end

    return (; config, ps, tokenizer, step, val_loss)
end
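
A minimal usage sketch (the file names here are placeholders, not artifacts shipped in this commit):

    model = load_inference_model("checkpoint.jld2", "config.toml",
                                 "vocab.json", "merges.txt")
    @info "ready" model.config.arch model.step model.val_loss

The returned NamedTuple carries everything generation code needs: config, ps (Float32 parameters), tokenizer, step, and val_loss.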