Spaces:
Running
Running
#=
server.jl — OpenAI-compatible inference server for MonarchSLM
Serves a Lux.jl trained Monarch Mixer model (sub-quadratic sequence mixing,
RMSNorm, SwiGLU, weight-tied). Downloads artifacts from HuggingFace on first run.
Endpoints:
  GET  /                    -> health check / API info
  GET  /v1/models           -> list available models
  POST /v1/chat/completions -> generate text (OpenAI format, streaming supported)
=#
| include("checkpoint.jl") | |
| using HTTP | |
| using UUIDs | |
| using Downloads | |
# ───────────────────────────────────────────────────────────────────
# Download artifacts from HuggingFace
# ───────────────────────────────────────────────────────────────────
# Local artifact cache: every file the server needs lives under CKPT_DIR.
const CKPT_DIR = "checkpoints"
const CKPT_PATH, CONFIG_PATH, VOCAB_PATH, MERGES_PATH = map(
    name -> joinpath(CKPT_DIR, name),
    ("final.jld2", "config.toml", "vocab.json", "merges.txt"),
)

# Deployment settings; each one can be overridden through the environment.
const HF_REPO = haskey(ENV, "HF_REPO") ? ENV["HF_REPO"] : "LisaMegaWatts/MonarchSLM"
const PORT = parse(Int, haskey(ENV, "PORT") ? ENV["PORT"] : "7860")
"""
    download_from_hf(repo, filename, local_path)

Fetch `filename` from the `main` revision of the HuggingFace repository `repo`
and store it at `local_path`, creating the parent directory when needed.

The payload is written to a sibling `<local_path>.part` file and renamed to
`local_path` only after the transfer has completed, so an interrupted run does
not leave a truncated file at `local_path` (callers use `isfile(local_path)` as
their "already downloaded" check). Progress and the final size in MB are printed
to stdout. Errors raised by `Downloads.download` propagate to the caller.
"""
function download_from_hf(repo::String, filename::String, local_path::String)
    url = "https://huggingface.co/$repo/resolve/main/$filename"
    println("Downloading $url ...")
    mkpath(dirname(local_path))

    partial_path = local_path * ".part"
    try
        Downloads.download(url, partial_path)
        # Same directory, so the rename replaces the target in one step.
        mv(partial_path, local_path; force=true)
    finally
        # No-op after a successful rename; cleans up after a failed transfer.
        rm(partial_path; force=true)
    end

    sz = round(filesize(local_path) / 1024^2, digits=1)
    println(" -> $local_path ($sz MB)")
    return nothing
end
"""
    ensure_artifacts()

Make sure the model artifacts exist locally, downloading any that are missing
from `HF_REPO`.

The checkpoint, config and vocabulary are required: if one of them cannot be
downloaded, the reason is printed and the process exits with status 1.
`merges.txt` is optional: a failed download is reported (including the
underlying error) and startup continues.
"""
function ensure_artifacts()
    required = (
        (CKPT_PATH, "final.jld2"),
        (CONFIG_PATH, "config.toml"),
        (VOCAB_PATH, "vocab.json"),
    )
    for (localpath, remote) in required
        isfile(localpath) && continue
        println("No local $remote found, downloading from $HF_REPO ...")
        try
            download_from_hf(HF_REPO, remote, localpath)
        catch e
            println("Download failed for $remote: $e")
            println("Place $remote at $localpath manually.")
            exit(1)
        end
    end

    if !isfile(MERGES_PATH)
        println("Attempting to download merges.txt (optional, for BPE) ...")
        try
            download_from_hf(HF_REPO, "merges.txt", MERGES_PATH)
        catch e
            # Best effort only, but surface the real cause instead of assuming
            # the file is absent from the repository.
            println(" merges.txt not downloaded (will use char tokenizer if vocab is array format)")
            println(" reason: $e")
        end
    end
    return nothing
end
# ───────────────────────────────────────────────────────────────────
# Download and load model
# ───────────────────────────────────────────────────────────────────
# Fetch anything missing, then load the model once at startup.
ensure_artifacts()
print("\nLoading model...\n")
const INF_MODEL = load_inference_model(CKPT_PATH, CONFIG_PATH, VOCAB_PATH, MERGES_PATH)

# Unpack the loaded bundle into the globals used by the request handler.
const CONFIG = INF_MODEL.config
const PS = INF_MODEL.ps
const TOKENIZER = INF_MODEL.tokenizer
const CACHES = INF_MODEL.caches

# Startup time in whole seconds; reported as `created` by GET /v1/models.
const MODEL_CREATED_AT = floor(Int, time())

println(
    "\nModel ready: arch=", CONFIG.arch,
    ", vocab=", CONFIG.vocab_size,
    ", embd=", CONFIG.embed_dim,
    ", layers=", CONFIG.n_layers,
    ", monarch_heads=", CONFIG.n_monarch_heads,
    ", ctx=", CONFIG.context_length,
)
# ───────────────────────────────────────────────────────────────────
# HTTP helpers
# ───────────────────────────────────────────────────────────────────
# Permissive CORS headers shared by the JSON, preflight and SSE responses.
const CORS_HEADERS = Pair{String,String}[
    "Access-Control-Allow-Origin" => "*",
    "Access-Control-Allow-Methods" => "GET, POST, OPTIONS",
    "Access-Control-Allow-Headers" => "Content-Type, Authorization",
]
"""
    json_response(status, body; extra_headers=[])

Serialize `body` with `JSON3.write` and wrap it in an `HTTP.Response` carrying
`status`, a JSON content type, the shared CORS headers and any `extra_headers`.
"""
function json_response(status::Int, body; extra_headers=[])
    payload = JSON3.write(body)
    return HTTP.Response(
        status,
        ["Content-Type" => "application/json", CORS_HEADERS..., extra_headers...],
        payload,
    )
end
# Answer a CORS preflight: 204 with the shared CORS headers and no body.
cors_preflight() = HTTP.Response(204, CORS_HEADERS)
# ───────────────────────────────────────────────────────────────────
# Extract prompt from OpenAI chat messages
# ───────────────────────────────────────────────────────────────────
# Text carried by one element of an array-valued `content`: plain strings pass
# through, objects contribute their `text` field when their type is "text".
_part_text(part::AbstractString) = part
function _part_text(part)
    return get(part, :type, "text") == "text" ? string(get(part, :text, "")) : ""
end

# Normalize a message `content` value to prompt text.
_content_text(::Nothing) = ""  # JSON null must not become the string "nothing"
_content_text(parts::AbstractVector) = join(_part_text(part) for part in parts)
_content_text(content) = string(content)

"""
    extract_prompt(messages)

Return the prompt text for a list of chat messages: the content of the most
recent message whose `role` is `"user"`, or the content of the last message
when no user message exists. An empty list yields `""`.

`content` may be a string, `nothing` (treated as empty) or an array of parts,
in which case the text parts are concatenated and non-text parts are ignored.
Messages are read with `get(message, :role / :content, "")`, so each one must
support symbol-keyed lookup.
"""
function extract_prompt(messages)
    isempty(messages) && return ""
    for message in Iterators.reverse(messages)
        if string(get(message, :role, "")) == "user"
            return _content_text(get(message, :content, ""))
        end
    end
    return _content_text(get(last(messages), :content, ""))
end
# ───────────────────────────────────────────────────────────────────
# SSE helpers
# ───────────────────────────────────────────────────────────────────
# Format `data` as one server-sent event: `data: <json>` followed by a blank line.
sse_line(data) = string("data: ", JSON3.write(data), "\n\n")
# ───────────────────────────────────────────────────────────────────
# Request handler
# ───────────────────────────────────────────────────────────────────
# Build the OpenAI-style error envelope used for every error reply.
function _error_response(status::Int, message, code)
    return json_response(status, Dict("error" => Dict(
        "message" => message,
        "type" => "invalid_request_error",
        "code" => code)))
end

# Payload for GET /: static API description plus the loaded model's dimensions.
function _api_info()
    return Dict(
        "name" => "MonarchSLM",
        "version" => "1.0.0",
        "description" => "A Monarch Mixer model trained on classical philosophy texts",
        "architecture" => "Decoder-only (Monarch Mixer, RMSNorm, SwiGLU, weight-tied)",
        "model" => Dict(
            "arch" => CONFIG.arch,
            "vocab_size" => CONFIG.vocab_size,
            "embed_dim" => CONFIG.embed_dim,
            "n_layers" => CONFIG.n_layers,
            "n_monarch_heads" => CONFIG.n_monarch_heads,
            "conv_kernel_size" => CONFIG.conv_kernel_size,
            "context_length" => CONFIG.context_length
        ),
        "endpoints" => ["/v1/models", "/v1/chat/completions"],
        "features" => ["streaming", "OpenAI-compatible", "top-k", "top-p"],
        "compatible_with" => ["OpenAI API", "OpenRouter"]
    )
end

# Payload for GET /v1/models: a one-element model list.
function _model_list()
    return Dict(
        "object" => "list",
        "data" => [Dict(
            "id" => "monarchslm-philosophy",
            "object" => "model",
            "created" => MODEL_CREATED_AT,
            "owned_by" => "monarchslm"
        )]
    )
end

# Read a numeric request field, treating absent/null as `default`, clamped to [lo, hi].
function _number_param(body, key::Symbol, default, lo, hi)
    value = something(get(body, key, nothing), default)
    value isa Real || throw(ArgumentError("`$key` must be a number"))
    return clamp(value, lo, hi)
end

# Like `_number_param`, but the clamped value must be a whole number.
function _integer_param(body, key::Symbol, default, lo, hi)
    value = _number_param(body, key, default, lo, hi)
    isinteger(value) || throw(ArgumentError("`$key` must be an integer"))
    return Int(value)
end

# Read a boolean request field, treating absent/null as `default`.
function _flag_param(body, key::Symbol, default::Bool)
    value = something(get(body, key, nothing), default)
    (value isa Real && (value == 0 || value == 1)) ||
        throw(ArgumentError("`$key` must be a boolean"))
    return Bool(value)
end

# Validate the chat-completion body and collect its sampling settings and prompt.
# Throws `ArgumentError` with a client-facing message on malformed input.
function _chat_params(body)
    stream = _flag_param(body, :stream, false)
    messages = something(get(body, :messages, nothing), [])
    (messages isa AbstractVector && all(m -> m isa AbstractDict, messages)) ||
        throw(ArgumentError("`messages` must be an array of message objects"))
    return (
        temperature = Float64(_number_param(body, :temperature, 0.8, 0.01, 2.0)),
        max_tokens = _integer_param(body, :max_tokens, 200, 1, CONFIG.context_length),
        top_k = _integer_param(body, :top_k, 40, 0, CONFIG.vocab_size),
        top_p = Float64(_number_param(body, :top_p, 1.0, 0.0, 1.0)),
        stream = stream,
        # `n` is only consulted by the non-streaming path.
        n = stream ? 1 : _integer_param(body, :n, 1, 1, 4),
        prompt_text = extract_prompt(messages),
    )
end

# One `chat.completion.chunk` object with a single choice at index 0.
function _completion_chunk(id, created, delta, finish_reason)
    return Dict{String,Any}(
        "id" => id,
        "object" => "chat.completion.chunk",
        "created" => created,
        "model" => "monarchslm-philosophy",
        "choices" => [Dict(
            "index" => 0,
            "delta" => delta,
            "finish_reason" => finish_reason
        )]
    )
end

# Run one generation and return it as a complete text/event-stream body.
# The events are collected in a buffer and sent as a single response once
# generation has finished; they are not flushed to the client incrementally.
function _stream_completion(params)
    completion_id = "chatcmpl-" * string(uuid4())
    created = Int(floor(time()))
    buf = IOBuffer()
    write(buf, sse_line(_completion_chunk(
        completion_id, created, Dict("role" => "assistant", "content" => ""), nothing)))

    token_count = Ref(0)
    generate_streaming(CONFIG, PS, TOKENIZER, params.prompt_text;
        max_tokens=params.max_tokens, temperature=params.temperature,
        top_k=params.top_k, top_p=params.top_p,
        caches=CACHES,
        on_token = function (token_str)
            token_count[] += 1
            write(buf, sse_line(_completion_chunk(
                completion_id, created, Dict("content" => token_str), nothing)))
        end)

    prompt_tokens = length(encode(TOKENIZER, params.prompt_text))
    finish_chunk = _completion_chunk(
        completion_id, created, Dict{String,Any}(),
        token_count[] >= params.max_tokens ? "length" : "stop")
    finish_chunk["usage"] = Dict(
        "prompt_tokens" => prompt_tokens,
        "completion_tokens" => token_count[],
        "total_tokens" => prompt_tokens + token_count[]
    )
    write(buf, sse_line(finish_chunk))
    write(buf, "data: [DONE]\n\n")

    headers = [
        "Content-Type" => "text/event-stream",
        "Cache-Control" => "no-cache",
        "X-Accel-Buffering" => "no",
        CORS_HEADERS...
    ]
    return HTTP.Response(200, headers, take!(buf))
end

# Run `params.n` generations and return them as one `chat.completion` object.
function _batch_completion(params)
    choices = Dict{String,Any}[]
    total_completion_tokens = 0
    for index in 0:(params.n - 1)
        text = generate_streaming(CONFIG, PS, TOKENIZER, params.prompt_text;
            max_tokens=params.max_tokens, temperature=params.temperature,
            top_k=params.top_k, top_p=params.top_p,
            caches=CACHES)
        # NOTE(review): this assumes generate_streaming always emits exactly
        # max_tokens tokens, whereas the streaming path counts on_token calls.
        # Confirm against generate_streaming and unify the two.
        push!(choices, Dict{String,Any}(
            "index" => index,
            "message" => Dict("role" => "assistant", "content" => text),
            "finish_reason" => "length"))
        total_completion_tokens += params.max_tokens
    end

    prompt_tokens = length(encode(TOKENIZER, params.prompt_text))
    return json_response(200, Dict(
        "id" => "chatcmpl-" * string(uuid4()),
        "object" => "chat.completion",
        "created" => Int(floor(time())),
        "model" => "monarchslm-philosophy",
        "choices" => choices,
        "usage" => Dict(
            "prompt_tokens" => prompt_tokens,
            "completion_tokens" => total_completion_tokens,
            "total_tokens" => prompt_tokens + total_completion_tokens),
        "system_fingerprint" => "monarchslm-v1"))
end

# Handle POST /v1/chat/completions: parse and validate the body, then generate.
function _chat_completions(request::HTTP.Request)
    local body
    try
        body = JSON3.read(String(request.body))
    catch
        return _error_response(400, "Invalid JSON in request body", "invalid_json")
    end
    body isa AbstractDict ||
        return _error_response(400, "Request body must be a JSON object", "invalid_body")

    local params
    try
        params = _chat_params(body)
    catch err
        # Only validation failures become 400s; anything else is a server error.
        err isa ArgumentError || rethrow()
        return _error_response(400, err.msg, "invalid_parameter")
    end
    return params.stream ? _stream_completion(params) : _batch_completion(params)
end

"""
    handle_request(request::HTTP.Request)

Route one HTTP request and return the `HTTP.Response`.

- `OPTIONS *` -> CORS preflight (204)
- `GET /` -> health check / API info
- `GET /v1/models` -> model list
- `POST /v1/chat/completions` -> chat completion (JSON, or SSE when `stream` is true)
- anything else -> 404 error object

Routing uses the path portion of the request target only, so a query string
does not prevent a match. A malformed JSON body, or parameters of the wrong
type, yield a 400 error object; parameters that are absent or `null` fall back
to their defaults, and numeric parameters are clamped to their allowed range.
"""
function handle_request(request::HTTP.Request)
    method = request.method
    target = request.target
    # Match on the path alone so e.g. `/v1/models?x=1` still routes.
    path = first(split(target, '?'; limit=2))

    method == "OPTIONS" && return cors_preflight()

    if method == "GET" && path == "/"
        # GET / — health check
        return json_response(200, _api_info())
    elseif method == "GET" && path == "/v1/models"
        return json_response(200, _model_list())
    elseif method == "POST" && path == "/v1/chat/completions"
        return _chat_completions(request)
    end
    return _error_response(404, "Not found: $method $target", "not_found")
end
# ───────────────────────────────────────────────────────────────────
# Start server
# ───────────────────────────────────────────────────────────────────
# Announce the available routes, then start serving on all interfaces.
println("\nMonarchSLM server starting on 0.0.0.0:$PORT ...")
for route in (
    "GET http://localhost:$PORT/",
    "GET http://localhost:$PORT/v1/models",
    "POST http://localhost:$PORT/v1/chat/completions",
    "POST http://localhost:$PORT/v1/chat/completions (stream=true)",
)
    println(" ", route)
end
println()
HTTP.serve(handle_request, "0.0.0.0", PORT)