File size: 14,010 Bytes
91c86b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76b7110
91c86b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76b7110
91c86b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76b7110
 
f0aedd4
91c86b7
 
 
 
f0aedd4
91c86b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
#=
server.jl — OpenAI-compatible inference server for MonarchSLM

Serves a Lux.jl trained Monarch Mixer model (sub-quadratic sequence mixing,
RMSNorm, SwiGLU, weight-tied). Downloads artifacts from HuggingFace on first run.

Endpoints:
    GET  /                       -> health check / API info
    GET  /v1/models              -> list available models
    POST /v1/chat/completions    -> generate text (OpenAI format, streaming supported)
=#

include("checkpoint.jl")
using HTTP
using UUIDs
using Downloads

# ═══════════════════════════════════════════════════════════════════
# Download artifacts from HuggingFace
# ═══════════════════════════════════════════════════════════════════

# Local directory that holds every downloaded model artifact.
const CKPT_DIR = "checkpoints"
# Paths of the individual artifacts inside CKPT_DIR.
const CKPT_PATH = joinpath(CKPT_DIR, "final.jld2")
const CONFIG_PATH = joinpath(CKPT_DIR, "config.toml")
const VOCAB_PATH = joinpath(CKPT_DIR, "vocab.json")
const MERGES_PATH = joinpath(CKPT_DIR, "merges.txt")
# HuggingFace repository to download from; overridable via the HF_REPO environment variable.
const HF_REPO = get(ENV, "HF_REPO", "LisaMegaWatts/MonarchSLM")
# TCP port to listen on; overridable via the PORT environment variable
# (parse throws at load time if the value is not an integer).
const PORT = parse(Int, get(ENV, "PORT", "7860"))

"""
    download_from_hf(repo, filename, local_path)

Download `filename` from the `main` revision of the HuggingFace repository
`repo` and store it at `local_path`, creating the parent directory if needed.

Progress and the resulting file size (in MB, one decimal) are printed to
standard output. Any error raised by `Downloads.download` propagates to the
caller. Returns `nothing`.
"""
function download_from_hf(repo::String, filename::String, local_path::String)
    url = "https://huggingface.co/$repo/resolve/main/$filename"
    println("Downloading $url ...")

    # Make sure the destination directory exists before writing into it.
    target_dir = dirname(local_path)
    mkpath(target_dir)
    Downloads.download(url, local_path)

    n_bytes = filesize(local_path)
    sz = round(n_bytes / 1024^2; digits=1)
    println("  -> $local_path ($sz MB)")
    return nothing
end

"""
    ensure_artifacts()

Make sure all model artifacts exist locally, downloading missing ones from
`HF_REPO`.

The checkpoint, config and vocabulary files are required: if one is missing
and its download fails, instructions are printed and the process exits with
status 1. `merges.txt` is optional: a failed download only prints a notice.
Returns `nothing`.
"""
function ensure_artifacts()
    required = (
        (CKPT_PATH, "final.jld2"),
        (CONFIG_PATH, "config.toml"),
        (VOCAB_PATH, "vocab.json"),
    )

    for (localpath, remote) in required
        # Already present on disk: nothing to do for this artifact.
        isfile(localpath) && continue

        println("No local $remote found, downloading from $HF_REPO ...")
        try
            download_from_hf(HF_REPO, remote, localpath)
        catch e
            println("Download failed for $remote: $e")
            println("Place $remote at $localpath manually.")
            exit(1)
        end
    end

    # merges.txt is best-effort only; the server can start without it.
    isfile(MERGES_PATH) && return nothing

    println("Attempting to download merges.txt (optional, for BPE) ...")
    try
        download_from_hf(HF_REPO, "merges.txt", MERGES_PATH)
    catch
        println("  merges.txt not found (will use char tokenizer if vocab is array format)")
    end
    return nothing
end

# ═══════════════════════════════════════════════════════════════════
# Download and load model
# ═══════════════════════════════════════════════════════════════════

# Fetch any missing artifacts first; the load below reads them from disk.
ensure_artifacts()

println("\nLoading model...")
# load_inference_model is not defined in this file — presumably it comes from
# the included checkpoint.jl (TODO confirm).
const INF_MODEL = load_inference_model(CKPT_PATH, CONFIG_PATH, VOCAB_PATH, MERGES_PATH)
# Unpack the loaded model into process-wide constants shared by all requests.
const CONFIG = INF_MODEL.config
const PS = INF_MODEL.ps
const TOKENIZER = INF_MODEL.tokenizer
const CACHES = INF_MODEL.caches
# Unix timestamp (seconds) taken at load time; reported as "created" by /v1/models.
const MODEL_CREATED_AT = Int(floor(time()))

println("\nModel ready: arch=$(CONFIG.arch), vocab=$(CONFIG.vocab_size), embd=$(CONFIG.embed_dim), " *
        "layers=$(CONFIG.n_layers), monarch_heads=$(CONFIG.n_monarch_heads), ctx=$(CONFIG.context_length)")

# ═══════════════════════════════════════════════════════════════════
# HTTP helpers
# ═══════════════════════════════════════════════════════════════════

# CORS headers attached to every response built in this file: any origin may
# call the API with GET/POST/OPTIONS and send Content-Type / Authorization.
const CORS_HEADERS = [
    "Access-Control-Allow-Origin" => "*",
    "Access-Control-Allow-Methods" => "GET, POST, OPTIONS",
    "Access-Control-Allow-Headers" => "Content-Type, Authorization",
]

"""
    json_response(status, body; extra_headers=[])

Build an `HTTP.Response` with status code `status` whose payload is `body`
serialised by `JSON3.write`.

The response carries `Content-Type: application/json`, then the shared
`CORS_HEADERS`, then any `extra_headers` (header pairs), in that order.
"""
function json_response(status::Int, body; extra_headers=[])
    payload = JSON3.write(body)
    response_headers = ["Content-Type" => "application/json", CORS_HEADERS..., extra_headers...]
    return HTTP.Response(status, response_headers, payload)
end

"""
    cors_preflight()

Return the body-less `204` response used to answer CORS preflight (`OPTIONS`)
requests; it carries only the shared `CORS_HEADERS`.
"""
cors_preflight() = HTTP.Response(204, CORS_HEADERS)

# ═══════════════════════════════════════════════════════════════════
# Extract prompt from OpenAI chat messages
# ═══════════════════════════════════════════════════════════════════

"""
    extract_prompt(messages) -> String

Return the prompt text from an OpenAI-style chat `messages` list.

The content of the most recent message whose `:role` is `"user"` is used; when
no user message exists, the content of the last message is used instead.
An empty list or `nothing` yields `""`.

Message content is converted to text as follows:
- a string is returned as is;
- `nothing` (JSON `null`) or a missing `:content` key yields `""` rather than
  the literal text `"nothing"`;
- a list of content parts yields the text of its string parts and of its
  object parts that carry a string `:text` field, joined with newlines
  (non-text parts such as images are skipped);
- anything else is passed through `string`.

Each message must support `get(message, key, default)` with `Symbol` keys.
"""
function extract_prompt(messages)
    if messages === nothing || isempty(messages)
        return ""
    end
    # Walk backwards so the latest user turn wins.
    for i in length(messages):-1:1
        role = string(get(messages[i], :role, ""))
        if role == "user"
            return _content_text(get(messages[i], :content, ""))
        end
    end
    # No user turn at all: fall back to whatever the last message says.
    return _content_text(get(messages[end], :content, ""))
end

# Convert one message's `content` value to plain prompt text.
_content_text(::Nothing) = ""
_content_text(content::AbstractString) = string(content)
_content_text(content) = string(content)
function _content_text(content::AbstractVector)
    texts = String[]
    for part in content
        if part isa AbstractString
            push!(texts, string(part))
        elseif part isa AbstractDict
            text = get(part, :text, nothing)
            text isa AbstractString && push!(texts, string(text))
        end
    end
    return join(texts, "\n")
end

# ═══════════════════════════════════════════════════════════════════
# SSE helpers
# ═══════════════════════════════════════════════════════════════════

"""
    sse_line(data) -> String

Format `data` as one Server-Sent Events message: the prefix `data: `, the
JSON serialisation of `data`, and the terminating blank line.
"""
sse_line(data) = string("data: ", JSON3.write(data), "\n\n")

# ═══════════════════════════════════════════════════════════════════
# Request handler
# ═══════════════════════════════════════════════════════════════════

"""
    handle_request(request::HTTP.Request) -> HTTP.Response

Route one HTTP request to the matching endpoint and return its response.

Routes (matched on the path only, so a query string does not affect routing):
- `OPTIONS *`                  -> CORS preflight (204)
- `GET /`                      -> health check / API info
- `GET /v1/models`             -> list of available models
- `POST /v1/chat/completions`  -> text generation in OpenAI chat format
- anything else                -> 404 with an OpenAI-style error object

For chat completions, malformed JSON, a non-object JSON body, or parameters of
the wrong type are answered with a 400 error object instead of raising.
When `stream` is true the reply is formatted as Server-Sent Events, but the
whole event sequence is generated into memory first and returned as a single
response body.
"""
function handle_request(request::HTTP.Request)
    method = request.method
    target = request.target
    # Route on the path alone so that e.g. `/v1/models?x=1` is not a 404.
    path = first(split(target, '?'; limit=2))

    if method == "OPTIONS"
        return cors_preflight()
    elseif method == "GET" && path == "/"
        # GET / — health check
        return _health_response()
    elseif method == "GET" && path == "/v1/models"
        return _models_response()
    elseif method == "POST" && path == "/v1/chat/completions"
        return _chat_completions_response(request)
    end

    return json_response(404, Dict("error" => Dict(
        "message" => "Not found: $method $target",
        "type" => "invalid_request_error",
        "code" => "not_found")))
end

# Build a 400 response carrying an OpenAI-style error object.
function _bad_request(message::AbstractString, code::AbstractString)
    return json_response(400, Dict("error" => Dict(
        "message" => message,
        "type" => "invalid_request_error",
        "code" => code)))
end

# Response for `GET /`: static service description plus the loaded model's config.
function _health_response()
    return json_response(200, Dict(
        "name" => "MonarchSLM",
        "version" => "1.0.0",
        "description" => "A Monarch Mixer model trained on classical philosophy texts",
        "architecture" => "Decoder-only (Monarch Mixer, RMSNorm, SwiGLU, weight-tied)",
        "model" => Dict(
            "arch" => CONFIG.arch,
            "vocab_size" => CONFIG.vocab_size,
            "embed_dim" => CONFIG.embed_dim,
            "n_layers" => CONFIG.n_layers,
            "n_monarch_heads" => CONFIG.n_monarch_heads,
            "conv_kernel_size" => CONFIG.conv_kernel_size,
            "context_length" => CONFIG.context_length
        ),
        "endpoints" => ["/v1/models", "/v1/chat/completions"],
        "features" => ["streaming", "OpenAI-compatible", "top-k", "top-p"],
        "compatible_with" => ["OpenAI API", "OpenRouter"]
    ))
end

# Response for `GET /v1/models`: the single model served by this process.
function _models_response()
    return json_response(200, Dict(
        "object" => "list",
        "data" => [Dict(
            "id" => "monarchslm-philosophy",
            "object" => "model",
            "created" => MODEL_CREATED_AT,
            "owned_by" => "monarchslm"
        )]
    ))
end

# Read and clamp the sampling parameters and the prompt from a parsed request
# body. Throws when a value has an unusable type (e.g. a string temperature or
# a fractional max_tokens); the caller maps that to a 400 response.
function _parse_chat_params(body)
    stream = Bool(get(body, :stream, false))
    return (
        temperature = Float64(clamp(get(body, :temperature, 0.8), 0.01, 2.0)),
        max_tokens = Int(clamp(get(body, :max_tokens, 200), 1, CONFIG.context_length)),
        top_k = Int(clamp(get(body, :top_k, 40), 0, CONFIG.vocab_size)),
        top_p = Float64(clamp(get(body, :top_p, 1.0), 0.0, 1.0)),
        stream = stream,
        # `n` is only honoured (and therefore only validated) for non-streamed replies.
        n = stream ? 1 : Int(clamp(get(body, :n, 1), 1, 4)),
        prompt = extract_prompt(get(body, :messages, [])),
    )
end

# Response for `POST /v1/chat/completions`.
function _chat_completions_response(request::HTTP.Request)
    local body
    try
        body = JSON3.read(String(request.body))
    catch
        return _bad_request("Invalid JSON in request body", "invalid_json")
    end
    # Valid JSON that is not an object (array, string, number) has no parameters to read.
    if !(body isa AbstractDict)
        return _bad_request("Request body must be a JSON object", "invalid_request")
    end

    local params
    try
        params = _parse_chat_params(body)
    catch
        # Only lookups, clamping and conversions run above, so a failure here
        # means the client sent a value of the wrong type or shape.
        return _bad_request("Invalid or malformed request parameters", "invalid_parameter")
    end

    return params.stream ? _sse_completion(params) : _json_completion(params)
end

# One `chat.completion.chunk` object with a single choice.
function _completion_chunk(completion_id, created, delta, finish_reason)
    return Dict{String,Any}(
        "id" => completion_id,
        "object" => "chat.completion.chunk",
        "created" => created,
        "model" => "monarchslm-philosophy",
        "choices" => [Dict{String,Any}(
            "index" => 0,
            "delta" => delta,
            "finish_reason" => finish_reason
        )]
    )
end

# Generate a completion and return it as one SSE-formatted body: role chunk,
# one chunk per generated token, a finish chunk with usage, then `[DONE]`.
function _sse_completion(params)
    completion_id = "chatcmpl-" * string(uuid4())
    created = Int(floor(time()))

    buf = IOBuffer()
    opening_delta = Dict("role" => "assistant", "content" => "")
    write(buf, sse_line(_completion_chunk(completion_id, created, opening_delta, nothing)))

    token_count = Ref(0)
    generate_streaming(CONFIG, PS, TOKENIZER, params.prompt;
                       max_tokens=params.max_tokens, temperature=params.temperature,
                       top_k=params.top_k, top_p=params.top_p,
                       caches=CACHES,
                       on_token = function(token_str)
                           token_count[] += 1
                           token_delta = Dict("content" => token_str)
                           write(buf, sse_line(_completion_chunk(
                               completion_id, created, token_delta, nothing)))
                       end)

    prompt_tokens = length(encode(TOKENIZER, params.prompt))
    # Hitting the token budget is reported as "length", anything shorter as "stop".
    finish_reason = token_count[] >= params.max_tokens ? "length" : "stop"
    finish_chunk = _completion_chunk(completion_id, created, Dict(), finish_reason)
    finish_chunk["usage"] = Dict(
        "prompt_tokens" => prompt_tokens,
        "completion_tokens" => token_count[],
        "total_tokens" => prompt_tokens + token_count[]
    )
    write(buf, sse_line(finish_chunk))
    write(buf, "data: [DONE]\n\n")

    sse_body = take!(buf)
    headers = [
        "Content-Type" => "text/event-stream",
        "Cache-Control" => "no-cache",
        "X-Accel-Buffering" => "no",
        CORS_HEADERS...
    ]
    return HTTP.Response(200, headers, sse_body)
end

# Generate `params.n` completions and return them as one `chat.completion` object.
function _json_completion(params)
    choices = Dict{String,Any}[]
    total_completion_tokens = 0
    for i in 1:params.n
        text = generate_streaming(CONFIG, PS, TOKENIZER, params.prompt;
                                  max_tokens=params.max_tokens,
                                  temperature=params.temperature,
                                  top_k=params.top_k, top_p=params.top_p,
                                  caches=CACHES)
        # NOTE(review): assumes generate_streaming always emits exactly
        # max_tokens tokens, hence the fixed "length" and the token accounting
        # below — confirm against its implementation.
        push!(choices, Dict{String,Any}(
            "index" => i - 1,
            "message" => Dict("role" => "assistant", "content" => text),
            "finish_reason" => "length"))
        total_completion_tokens += params.max_tokens  # count tokens, not decoded chars
    end

    prompt_tokens = length(encode(TOKENIZER, params.prompt))
    return json_response(200, Dict(
        "id" => "chatcmpl-" * string(uuid4()),
        "object" => "chat.completion",
        "created" => Int(floor(time())),
        "model" => "monarchslm-philosophy",
        "choices" => choices,
        "usage" => Dict(
            "prompt_tokens" => prompt_tokens,
            "completion_tokens" => total_completion_tokens,
            "total_tokens" => prompt_tokens + total_completion_tokens),
        "system_fingerprint" => "monarchslm-v1"))
end

# ═══════════════════════════════════════════════════════════════════
# Start server
# ═══════════════════════════════════════════════════════════════════

# Startup banner listing the endpoints. The URLs use localhost for
# convenience; the server itself binds to all interfaces (0.0.0.0).
println("\nMonarchSLM server starting on 0.0.0.0:$PORT ...")
println("  GET  http://localhost:$PORT/")
println("  GET  http://localhost:$PORT/v1/models")
println("  POST http://localhost:$PORT/v1/chat/completions")
println("  POST http://localhost:$PORT/v1/chat/completions  (stream=true)")
println()

# Start serving; every incoming request is passed to handle_request.
HTTP.serve(handle_request, "0.0.0.0", PORT)