LisaMegaWatts committed on
Commit
91c86b7
Β·
verified Β·
1 Parent(s): 3724bdb

Upload server.jl with huggingface_hub

Browse files
Files changed (1) hide show
  1. server.jl +313 -0
server.jl ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #=
2
+ server.jl β€” OpenAI-compatible inference server for MonarchSLM
3
+
4
+ Serves a Lux.jl trained Monarch Mixer model (sub-quadratic sequence mixing,
5
+ RMSNorm, SwiGLU, weight-tied). Downloads artifacts from HuggingFace on first run.
6
+
7
+ Endpoints:
8
+ GET / -> health check / API info
9
+ GET /v1/models -> list available models
10
+ POST /v1/chat/completions -> generate text (OpenAI format, streaming supported)
11
+ =#
12
+
13
+ include("checkpoint.jl")
14
+ using HTTP
15
+ using UUIDs
16
+ using Downloads
17
+
18
# ═══════════════════════════════════════════════════════════════════
# Download artifacts from HuggingFace
# ═══════════════════════════════════════════════════════════════════

# Local cache directory and per-artifact paths checked before downloading.
const CKPT_DIR = "checkpoints"
const CKPT_PATH = joinpath(CKPT_DIR, "final.jld2")      # trained weights (JLD2)
const CONFIG_PATH = joinpath(CKPT_DIR, "config.toml")   # model hyperparameters
const VOCAB_PATH = joinpath(CKPT_DIR, "vocab.json")     # tokenizer vocabulary
const MERGES_PATH = joinpath(CKPT_DIR, "merges.txt")    # optional BPE merges
# Source repo on HuggingFace; overridable via the HF_REPO env var.
const HF_REPO = get(ENV, "HF_REPO", "LisaMegaWatts/MonarchSLM")
# Listen port; 7860 is the conventional HuggingFace Spaces port.
const PORT = parse(Int, get(ENV, "PORT", "7860"))
29
+
30
function download_from_hf(repo::String, filename::String, local_path::String)
    # Fetch one artifact from a HuggingFace repo's main branch into
    # `local_path`, creating parent directories as needed.
    # Throws (propagates the Downloads error) if the transfer fails.
    url = "https://huggingface.co/$repo/resolve/main/$filename"
    println("Downloading $url ...")
    mkpath(dirname(local_path))
    # Download to a temporary path first so a failed or interrupted transfer
    # never leaves a partial file at `local_path` — callers (ensure_artifacts)
    # use isfile() to decide whether to skip the download next run, so a
    # partial file would otherwise be mistaken for a complete artifact.
    tmp_path = local_path * ".part"
    try
        Downloads.download(url, tmp_path)
        mv(tmp_path, local_path; force=true)
    finally
        # Clean up the temp file if the download threw before the mv.
        isfile(tmp_path) && rm(tmp_path; force=true)
    end
    sz = round(filesize(local_path) / 1024^2, digits=1)
    println(" -> $local_path ($sz MB)")
end
38
+
39
function ensure_artifacts()
    # Make sure every model artifact exists locally, downloading from
    # HF_REPO when missing. Required artifacts abort the process (exit 1)
    # if they cannot be obtained; merges.txt is best-effort.
    required = [(CKPT_PATH, "final.jld2"),
                (CONFIG_PATH, "config.toml"),
                (VOCAB_PATH, "vocab.json")]
    for (localpath, remote) in required
        isfile(localpath) && continue
        println("No local $remote found, downloading from $HF_REPO ...")
        try
            download_from_hf(HF_REPO, remote, localpath)
        catch e
            println("Download failed for $remote: $e")
            println("Place $remote at $localpath manually.")
            exit(1)
        end
    end
    # Optional BPE merges file: a failed download here is tolerated.
    if !isfile(MERGES_PATH)
        println("Attempting to download merges.txt (optional, for BPE) ...")
        try
            download_from_hf(HF_REPO, "merges.txt", MERGES_PATH)
        catch e
            println(" merges.txt not found (will use char tokenizer if vocab is array format)")
        end
    end
end
63
+
64
# ═══════════════════════════════════════════════════════════════════
# Download and load model
# ═══════════════════════════════════════════════════════════════════

# Fetch any missing artifacts first; exits the process if a required one fails.
ensure_artifacts()

println("\nLoading model...")
# load_inference_model comes from checkpoint.jl (included above); it appears
# to return an object exposing `config`, `ps`, and `tokenizer` fields —
# TODO confirm against checkpoint.jl.
const INF_MODEL = load_inference_model(CKPT_PATH, CONFIG_PATH, VOCAB_PATH, MERGES_PATH)
const CONFIG = INF_MODEL.config        # model hyperparameters (arch, dims, ctx, ...)
const PS = INF_MODEL.ps                # trained parameters
const TOKENIZER = INF_MODEL.tokenizer  # used by encode() in the request handler
# Unix timestamp reported as "created" by GET /v1/models.
const MODEL_CREATED_AT = Int(floor(time()))

println("\nModel ready: arch=$(CONFIG.arch), vocab=$(CONFIG.vocab_size), embd=$(CONFIG.embed_dim), " *
    "layers=$(CONFIG.n_layers), monarch_heads=$(CONFIG.n_monarch_heads), ctx=$(CONFIG.context_length)")
79
+
80
# ═══════════════════════════════════════════════════════════════════
# HTTP helpers
# ═══════════════════════════════════════════════════════════════════

# Permissive CORS headers attached to every response so browser clients
# from any origin can call the API directly.
const CORS_HEADERS = [
    "Access-Control-Allow-Origin" => "*",
    "Access-Control-Allow-Methods" => "GET, POST, OPTIONS",
    "Access-Control-Allow-Headers" => "Content-Type, Authorization",
]
89
+
90
function json_response(status::Int, body; extra_headers=[])
    # Serialize `body` to JSON and wrap it in an HTTP response carrying the
    # JSON content type, the shared CORS headers, and any caller extras
    # (in that order).
    headers = vcat(["Content-Type" => "application/json"],
                   CORS_HEADERS,
                   extra_headers)
    return HTTP.Response(status, headers, JSON3.write(body))
end
99
+
100
# Respond to an OPTIONS preflight: 204 No Content plus the CORS headers.
cors_preflight() = HTTP.Response(204, CORS_HEADERS)
103
+
104
+ # ═══════════════════════════════════════════════════════════════════
105
+ # Extract prompt from OpenAI chat messages
106
+ # ═══════════════════════════════════════════════════════════════════
107
+
108
function extract_prompt(messages)
    # Pick the prompt text from an OpenAI-style message list: the most
    # recent message with role "user", falling back to the last message's
    # content when no user message exists. Empty list -> "".
    isempty(messages) && return ""
    idx = findlast(m -> string(get(m, :role, "")) == "user", messages)
    chosen = idx === nothing ? messages[end] : messages[idx]
    return string(get(chosen, :content, ""))
end
120
+
121
+ # ═══════════════════════════════════════════════════════════════════
122
+ # SSE helpers
123
+ # ═══════════════════════════════════════════════════════════════════
124
+
125
# Format one payload as a Server-Sent Events "data:" line (JSON + blank line).
sse_line(data) = string("data: ", JSON3.write(data), "\n\n")
128
+
129
+ # ═══════════════════════════════════════════════════════════════════
130
+ # Request handler
131
+ # ═══════════════════════════════════════════════════════════════════
132
+
133
function handle_request(request::HTTP.Request)
    # Route an incoming HTTP request. Supported routes:
    #   OPTIONS *                   -> CORS preflight (204)
    #   GET  /                      -> health check / API info
    #   GET  /v1/models             -> model listing (OpenAI format)
    #   POST /v1/chat/completions   -> generation (OpenAI format, stream flag)
    # Anything else returns a 404 JSON error object.
    method = request.method
    target = request.target

    if method == "OPTIONS"
        return cors_preflight()
    end

    # GET / — health check
    if method == "GET" && target == "/"
        return json_response(200, Dict(
            "name" => "MonarchSLM",
            "version" => "1.0.0",
            "description" => "A Monarch Mixer model trained on classical philosophy texts",
            "architecture" => "Decoder-only (Monarch Mixer, RMSNorm, SwiGLU, weight-tied)",
            "model" => Dict(
                "arch" => CONFIG.arch,
                "vocab_size" => CONFIG.vocab_size,
                "embed_dim" => CONFIG.embed_dim,
                "n_layers" => CONFIG.n_layers,
                "n_monarch_heads" => CONFIG.n_monarch_heads,
                "conv_kernel_size" => CONFIG.conv_kernel_size,
                "context_length" => CONFIG.context_length
            ),
            "endpoints" => ["/v1/models", "/v1/chat/completions"],
            "features" => ["streaming", "OpenAI-compatible", "top-k", "top-p"],
            "compatible_with" => ["OpenAI API", "OpenRouter"]
        ))
    end

    # GET /v1/models
    if method == "GET" && target == "/v1/models"
        return json_response(200, Dict(
            "object" => "list",
            "data" => [Dict(
                "id" => "monarchslm-philosophy",
                "object" => "model",
                "created" => MODEL_CREATED_AT,
                "owned_by" => "monarchslm"
            )]
        ))
    end

    # POST /v1/chat/completions
    if method == "POST" && target == "/v1/chat/completions"
        local body
        try
            body = JSON3.read(String(request.body))
        catch e
            return json_response(400, Dict("error" => Dict(
                "message" => "Invalid JSON in request body",
                "type" => "invalid_request_error",
                "code" => "invalid_json")))
        end

        # Sampling parameters, clamped to safe ranges.
        temperature = Float64(clamp(get(body, :temperature, 0.8), 0.01, 2.0))
        max_tokens = Int(clamp(get(body, :max_tokens, 200), 1, CONFIG.context_length))
        top_k_val = Int(clamp(get(body, :top_k, 40), 0, CONFIG.vocab_size))
        top_p_val = Float64(clamp(get(body, :top_p, 1.0), 0.0, 1.0))
        stream = Bool(get(body, :stream, false))

        messages = get(body, :messages, [])
        prompt_text = extract_prompt(messages)

        if stream
            # NOTE(review): this branch formats the output as SSE chunks but
            # generates the ENTIRE completion into an IOBuffer before sending
            # one response — clients get SSE framing, not incremental
            # delivery. True streaming would require an HTTP.Stream handler.
            completion_id = "chatcmpl-" * string(uuid4())
            created = Int(floor(time()))

            buf = IOBuffer()

            # First chunk carries the assistant role with empty content,
            # per the OpenAI streaming convention.
            initial_chunk = Dict(
                "id" => completion_id,
                "object" => "chat.completion.chunk",
                "created" => created,
                "model" => "monarchslm-philosophy",
                "choices" => [Dict(
                    "index" => 0,
                    "delta" => Dict("role" => "assistant", "content" => ""),
                    "finish_reason" => nothing
                )]
            )
            write(buf, sse_line(initial_chunk))

            token_count = Ref(0)

            # generate_streaming invokes on_token once per generated token.
            generate_streaming(CONFIG, PS, TOKENIZER, prompt_text;
                max_tokens, temperature, top_k=top_k_val, top_p=top_p_val,
                on_token = function(token_str)
                    token_count[] += 1
                    chunk = Dict(
                        "id" => completion_id,
                        "object" => "chat.completion.chunk",
                        "created" => created,
                        "model" => "monarchslm-philosophy",
                        "choices" => [Dict(
                            "index" => 0,
                            "delta" => Dict("content" => token_str),
                            "finish_reason" => nothing
                        )]
                    )
                    write(buf, sse_line(chunk))
                end)

            prompt_tokens = length(encode(TOKENIZER, prompt_text))
            # Final chunk: empty delta, finish_reason, and usage totals.
            finish_chunk = Dict(
                "id" => completion_id,
                "object" => "chat.completion.chunk",
                "created" => created,
                "model" => "monarchslm-philosophy",
                "choices" => [Dict(
                    "index" => 0,
                    "delta" => Dict(),
                    "finish_reason" => token_count[] >= max_tokens ? "length" : "stop"
                )],
                "usage" => Dict(
                    "prompt_tokens" => prompt_tokens,
                    "completion_tokens" => token_count[],
                    "total_tokens" => prompt_tokens + token_count[]
                )
            )
            write(buf, sse_line(finish_chunk))
            write(buf, "data: [DONE]\n\n")

            sse_body = take!(buf)
            headers = [
                "Content-Type" => "text/event-stream",
                "Cache-Control" => "no-cache",
                "X-Accel-Buffering" => "no",
                CORS_HEADERS...
            ]
            return HTTP.Response(200, headers, sse_body)

        else
            # Non-streaming: generate n independent completions (capped at 4).
            n_completions = Int(clamp(get(body, :n, 1), 1, 4))

            choices = []
            total_completion_tokens = 0
            for i in 1:n_completions
                text = generate_streaming(CONFIG, PS, TOKENIZER, prompt_text;
                    max_tokens, temperature, top_k=top_k_val, top_p=top_p_val)
                # BUG FIX: previous code used length(text) — a CHARACTER count —
                # both for the finish_reason cutoff (compared against the TOKEN
                # budget max_tokens) and for the completion_tokens usage field.
                # Re-encode the completion to count tokens, matching how the
                # streaming branch counts them.
                completion_tokens = length(encode(TOKENIZER, text))
                finish_reason = completion_tokens >= max_tokens ? "length" : "stop"
                push!(choices, Dict(
                    "index" => i - 1,
                    "message" => Dict("role" => "assistant", "content" => text),
                    "finish_reason" => finish_reason))
                total_completion_tokens += completion_tokens
            end

            prompt_tokens = length(encode(TOKENIZER, prompt_text))
            return json_response(200, Dict(
                "id" => "chatcmpl-" * string(uuid4()),
                "object" => "chat.completion",
                "created" => Int(floor(time())),
                "model" => "monarchslm-philosophy",
                "choices" => choices,
                "usage" => Dict(
                    "prompt_tokens" => prompt_tokens,
                    "completion_tokens" => total_completion_tokens,
                    "total_tokens" => prompt_tokens + total_completion_tokens),
                "system_fingerprint" => "monarchslm-v1"))
        end
    end

    return json_response(404, Dict("error" => Dict(
        "message" => "Not found: $method $target",
        "type" => "invalid_request_error",
        "code" => "not_found")))
end
301
+
302
# ═══════════════════════════════════════════════════════════════════
# Start server
# ═══════════════════════════════════════════════════════════════════

println("\nMonarchSLM server starting on 0.0.0.0:$PORT ...")
println(" GET http://localhost:$PORT/")
println(" GET http://localhost:$PORT/v1/models")
println(" POST http://localhost:$PORT/v1/chat/completions")
println(" POST http://localhost:$PORT/v1/chat/completions (stream=true)")
println()

# Blocks forever, serving requests on all interfaces.
HTTP.serve(handle_request, "0.0.0.0", PORT)