// llm-manager HTTP front-end: request routing, API-key auth, rate limiting,
// queued inference dispatch, and llama.cpp /completion + SSE compatibility.
namespace beast = boost::beast;
namespace http = beast::http;
// Process-wide request counter; fetch_add gives each request a unique,
// monotonically increasing id (starting at 1) for logging and X-Request-Id.
static std::atomic<uint64_t> g_req_id{1};
| static std::string build_sse_event(const json &payload) { | |
| return "data: " + payload.dump() + "\n\n"; | |
| } | |
| static std::string extract_chat_text(const json &completion) { | |
| if (!completion.is_object()) return ""; | |
| if (!completion.contains("choices") || !completion["choices"].is_array() || completion["choices"].empty()) { | |
| return ""; | |
| } | |
| const auto &choice = completion["choices"][0]; | |
| if (!choice.is_object()) return ""; | |
| if (choice.contains("message") && choice["message"].is_object()) { | |
| const auto &message = choice["message"]; | |
| if (message.contains("content") && message["content"].is_string()) { | |
| return message["content"].get<std::string>(); | |
| } | |
| } | |
| if (choice.contains("text") && choice["text"].is_string()) { | |
| return choice["text"].get<std::string>(); | |
| } | |
| return ""; | |
| } | |
| static json completion_payload_to_chat_payload(const json &payload, const LimitsConfig &limits) { | |
| json chat_payload = json::object(); | |
| chat_payload["messages"] = json::array(); | |
| if (payload.contains("prompt")) { | |
| if (payload["prompt"].is_string()) { | |
| chat_payload["messages"].push_back({ | |
| {"role", "user"}, | |
| {"content", payload["prompt"].get<std::string>()} | |
| }); | |
| } else if (payload["prompt"].is_array()) { | |
| std::string joined_prompt; | |
| bool first = true; | |
| for (const auto &item : payload["prompt"]) { | |
| if (!item.is_string()) continue; | |
| if (!first) joined_prompt += "\n"; | |
| joined_prompt += item.get<std::string>(); | |
| first = false; | |
| } | |
| chat_payload["messages"].push_back({ | |
| {"role", "user"}, | |
| {"content", joined_prompt} | |
| }); | |
| } | |
| } | |
| int normalized_max_tokens = limits.default_max_tokens; | |
| if (payload.contains("n_predict") && payload["n_predict"].is_number_integer()) { | |
| normalized_max_tokens = payload["n_predict"].get<int>(); | |
| } else if (payload.contains("max_tokens") && payload["max_tokens"].is_number_integer()) { | |
| normalized_max_tokens = payload["max_tokens"].get<int>(); | |
| } | |
| if (normalized_max_tokens <= 0) { | |
| normalized_max_tokens = limits.default_max_tokens; | |
| } | |
| chat_payload["max_tokens"] = normalized_max_tokens; | |
| if (chat_payload["messages"].empty()) { | |
| chat_payload["messages"].push_back({ | |
| {"role", "user"}, | |
| {"content", ""} | |
| }); | |
| } | |
| if (payload.contains("temperature")) chat_payload["temperature"] = payload["temperature"]; | |
| if (payload.contains("top_p")) chat_payload["top_p"] = payload["top_p"]; | |
| if (payload.contains("top_k")) chat_payload["top_k"] = payload["top_k"]; | |
| if (payload.contains("stop")) chat_payload["stop"] = payload["stop"]; | |
| if (payload.contains("stream")) chat_payload["stream"] = payload["stream"]; | |
| return chat_payload; | |
| } | |
| static std::string build_completion_compat_response(const std::string &completion_body) { | |
| json completion = json::parse(completion_body, nullptr, false); | |
| if (completion.is_discarded() || !completion.is_object()) { | |
| return completion_body; | |
| } | |
| json out = { | |
| {"content", extract_chat_text(completion)} | |
| }; | |
| if (completion.contains("stop")) out["stop"] = completion["stop"]; | |
| if (completion.contains("stopped_eos")) out["stopped_eos"] = completion["stopped_eos"]; | |
| if (completion.contains("stopped_limit")) out["stopped_limit"] = completion["stopped_limit"]; | |
| if (completion.contains("tokens_predicted")) out["tokens_predicted"] = completion["tokens_predicted"]; | |
| if (completion.contains("tokens_evaluated")) out["tokens_evaluated"] = completion["tokens_evaluated"]; | |
| return out.dump(); | |
| } | |
| static std::string build_buffered_stream_response(const std::string &completion_body) { | |
| json completion = json::parse(completion_body, nullptr, false); | |
| if (completion.is_discarded() || !completion.is_object()) { | |
| return "data: [DONE]\n\n"; | |
| } | |
| const std::string id = completion.value("id", "chatcmpl-buffered"); | |
| const std::string model = completion.value("model", ""); | |
| const auto created = completion.value("created", 0); | |
| const std::string assistant_content = extract_chat_text(completion); | |
| std::ostringstream oss; | |
| oss << build_sse_event({ | |
| {"id", id}, | |
| {"object", "chat.completion.chunk"}, | |
| {"created", created}, | |
| {"model", model}, | |
| {"choices", json::array({ | |
| { | |
| {"index", 0}, | |
| {"delta", {{"role", "assistant"}}}, | |
| {"finish_reason", nullptr} | |
| } | |
| })} | |
| }); | |
| if (!assistant_content.empty()) { | |
| oss << build_sse_event({ | |
| {"id", id}, | |
| {"object", "chat.completion.chunk"}, | |
| {"created", created}, | |
| {"model", model}, | |
| {"choices", json::array({ | |
| { | |
| {"index", 0}, | |
| {"delta", {{"content", assistant_content}}}, | |
| {"finish_reason", nullptr} | |
| } | |
| })} | |
| }); | |
| } | |
| oss << build_sse_event({ | |
| {"id", id}, | |
| {"object", "chat.completion.chunk"}, | |
| {"created", created}, | |
| {"model", model}, | |
| {"choices", json::array({ | |
| { | |
| {"index", 0}, | |
| {"delta", json::object()}, | |
| {"finish_reason", "stop"} | |
| } | |
| })} | |
| }); | |
| oss << "data: [DONE]\n\n"; | |
| return oss.str(); | |
| } | |
// Central HTTP dispatcher for the manager. Routes by (method, path):
//   GET  /health, /models           -> model state snapshot (no auth)
//   GET  /queue/metrics             -> Prometheus text exposition (no auth)
//   POST /switch-model, /stop       -> admin-only model lifecycle ops
//   POST <cancel path>              -> cancel one in-flight request (user)
//   POST /v1/chat/completions,
//        /completion                -> authenticated, rate-limited, queued
//                                      inference with timeout handling
//   GET  <anything else>            -> proxied to the active worker
// Every response carries X-Request-Id; requests/responses are logged and
// latency + inflight gauges are recorded. All exceptions become 500s.
http::response<http::string_body> handle_request(
    ModelManager &manager,
    const ManagerConfig &config,
    const ApiKeyAuth &auth,
    RateLimiterStore &rate_limiter,
    RequestRegistry &registry,
    MetricsRegistry &metrics,
    Scheduler &scheduler,
    http::request<http::string_body> &&req) {
    const auto start = std::chrono::steady_clock::now();
    // Unique per-request id from the process-wide atomic counter.
    const auto req_id_num = g_req_id.fetch_add(1);
    const std::string request_id = std::to_string(req_id_num);
    const std::string target = req.target().to_string();
    const std::string method = req.method_string().to_string();
    // Path with any query string stripped (find('?')==npos keeps it whole).
    const std::string path = target.substr(0, target.find('?'));
    // Filled in lazily by ensure_authenticated below.
    auto authenticated = std::optional<ApiKeyRecord>{};
    metrics.inc_requests_total();
    metrics.inc_requests_inflight();
    // RAII guard: decrements the inflight gauge on every exit path,
    // including exceptions.
    struct InflightGuard {
        MetricsRegistry &metrics;
        ~InflightGuard() { metrics.dec_requests_inflight(); }
    } inflight_guard{metrics};
    log_line("request_id=" + request_id + " method=" + method + " path=" + target);
    if (!req.body().empty()) {
        log_line("request_id=" + request_id + " body=" + truncate_body(req.body()));
    }
    // Builds a JSON response, injecting request_id into the payload and
    // headers, and records latency + a status log line (error bodies are
    // logged truncated for status >= 400).
    auto json_response = [&](http::status status, const json &obj) {
        json payload = obj;
        payload["request_id"] = request_id;
        http::response<http::string_body> res{status, req.version()};
        res.set(http::field::content_type, "application/json");
        res.set(http::field::server, "llm-manager");
        res.set("X-Request-Id", request_id);
        res.keep_alive(req.keep_alive());
        res.body() = payload.dump();
        res.prepare_payload();
        auto elapsed_ms = std::chrono::duration_cast<std::chrono::milliseconds>(
            std::chrono::steady_clock::now() - start).count();
        metrics.observe_request_latency_ms(elapsed_ms);
        std::string log_message = "request_id=" + request_id +
            " status=" + std::to_string(res.result_int()) +
            " elapsed_ms=" + std::to_string(elapsed_ms);
        if (res.result_int() >= 400) {
            log_message += " error_body=" + truncate_body(res.body());
        }
        log_line(log_message);
        return res;
    };
    // Same as json_response but adds a Retry-After header (clamped to >= 1s)
    // for 429/503 responses.
    auto json_response_with_retry_after = [&](http::status status, const json &obj, int retry_after_sec) {
        auto res = json_response(status, obj);
        res.set(http::field::retry_after, std::to_string(std::max(1, retry_after_sec)));
        return res;
    };
    // Authenticates the request and enforces the minimum role. Returns a
    // ready 401/403 response on failure, nullopt on success (and populates
    // `authenticated` as a side effect).
    auto ensure_authenticated = [&](Role minimum_role) -> std::optional<http::response<http::string_body>> {
        std::string auth_error;
        authenticated = auth.authenticate(req, auth_error);
        if (!authenticated) {
            return json_response(http::status::unauthorized, {{"error", auth_error}});
        }
        if (minimum_role == Role::ADMIN && authenticated->role != Role::ADMIN) {
            return json_response(http::status::forbidden, {{"error", "Admin role required"}});
        }
        log_line("request_id=" + request_id +
            " api_key_id=" + authenticated->key_id +
            " role=" + role_to_string(authenticated->role));
        return std::nullopt;
    };
    try {
        // Unauthenticated read-only endpoints.
        if (path == "/health" && req.method() == http::verb::get) {
            return json_response(http::status::ok, manager.models_view());
        }
        if (path == "/models" && req.method() == http::verb::get) {
            return json_response(http::status::ok, manager.models_view());
        }
        // Prometheus scrape endpoint: plain text, not JSON, so it builds
        // the response by hand instead of via json_response.
        if (path == "/queue/metrics" && req.method() == http::verb::get) {
            http::response<http::string_body> res{http::status::ok, req.version()};
            res.set(http::field::content_type, "text/plain; version=0.0.4; charset=utf-8");
            res.set(http::field::server, "llm-manager");
            res.set("X-Request-Id", request_id);
            res.keep_alive(req.keep_alive());
            res.body() = metrics.render_prometheus(scheduler.snapshot(), manager);
            res.prepare_payload();
            auto elapsed_ms = std::chrono::duration_cast<std::chrono::milliseconds>(
                std::chrono::steady_clock::now() - start).count();
            metrics.observe_request_latency_ms(elapsed_ms);
            log_line("request_id=" + request_id + " status=" + std::to_string(res.result_int()) +
                " elapsed_ms=" + std::to_string(elapsed_ms));
            return res;
        }
        // Admin: swap the active model. Accepts "model" or "model_name"
        // (note: "model" wins when both are present).
        if (path == "/switch-model" && req.method() == http::verb::post) {
            if (auto auth_res = ensure_authenticated(Role::ADMIN)) return *auth_res;
            json j = json::parse(req.body(), nullptr, false);
            if (j.is_discarded()) {
                return json_response(http::status::bad_request, {{"error", "Invalid JSON"}});
            }
            std::string model;
            if (j.contains("model_name")) model = j["model_name"].get<std::string>();
            if (j.contains("model")) model = j["model"].get<std::string>();
            if (model.empty()) {
                return json_response(http::status::bad_request, {{"error", "Expected 'model' or 'model_name'"}});
            }
            std::string err;
            bool ok = manager.switch_model(model, err);
            if (!ok) {
                // A concurrent switch maps to 409; anything else is a 500.
                auto status = (err == "Switch already in progress")
                    ? http::status::conflict
                    : http::status::internal_server_error;
                return json_response(status, {{"status", "error"}, {"error", err}});
            }
            auto state = manager.models_view();
            state["message"] = "Switched model successfully";
            metrics.inc_switch_total();
            return json_response(http::status::ok, state);
        }
        // Admin: cancel everything in flight and restart the active worker.
        if (path == "/stop" && req.method() == http::verb::post) {
            if (auto auth_res = ensure_authenticated(Role::ADMIN)) return *auth_res;
            const auto cancelled = registry.cancel_all();
            metrics.add_cancellations_total(cancelled.size());
            std::string err;
            bool ok = manager.restart_active(err);
            if (!ok) {
                http::status status = http::status::internal_server_error;
                if (err == "Switch already in progress") status = http::status::conflict;
                else if (err == "No active model") status = http::status::service_unavailable;
                return json_response(status, {{"status", "error"}, {"error", err}});
            }
            auto state = manager.models_view();
            state["message"] = "Stopped in-flight prompts and restarted model";
            metrics.inc_worker_restarts_total();
            return json_response(http::status::ok, state);
        }
        // Per-request cancellation. Users may only cancel their own
        // requests; admins may cancel anyone's. A RUNNING request also
        // forces a worker restart to actually abort generation.
        if (req.method() == http::verb::post) {
            if (auto cancel_id = extract_cancel_request_id(path)) {
                if (auto auth_res = ensure_authenticated(Role::USER)) return *auth_res;
                auto ctx = registry.find(*cancel_id);
                if (!ctx) {
                    return json_response(http::status::not_found, {{"error", "Unknown request id"}});
                }
                if (authenticated->role != Role::ADMIN && authenticated->key_id != ctx->api_key_id) {
                    return json_response(http::status::forbidden, {{"error", "Cannot cancel another API key request"}});
                }
                // Snapshot state before cancelling to decide on restart.
                const auto previous_state = ctx->state.load();
                registry.cancel_request(*cancel_id);
                metrics.add_cancellations_total();
                std::string restart_error;
                bool restarted = true;
                if (previous_state == RequestState::RUNNING) {
                    restarted = manager.restart_active(restart_error);
                    if (restarted) metrics.inc_worker_restarts_total();
                }
                json payload = {
                    {"cancelled_request_id", *cancel_id},
                    {"state", state_to_string(ctx->state.load())}
                };
                if (!restarted) payload["restart_error"] = restart_error;
                return json_response(http::status::ok, payload);
            }
        }
        // Inference path: /completion is normalized into chat shape first;
        // stream requests are downgraded to buffered + synthesized SSE.
        if ((path == "/v1/chat/completions" || path == "/completion") && req.method() == http::verb::post) {
            if (auto auth_res = ensure_authenticated(Role::USER)) return *auth_res;
            json payload = json::parse(req.body(), nullptr, false);
            if (payload.is_discarded()) {
                return json_response(http::status::bad_request, {{"error", "Invalid JSON"}});
            }
            const bool completion_compat_mode = path == "/completion";
            if (completion_compat_mode) {
                payload = completion_payload_to_chat_payload(payload, config.limits);
            }
            const bool stream_requested = request_stream_enabled(payload);
            if (stream_requested) {
                // Upstream is always called non-streaming; the SSE frames
                // are reconstructed from the buffered result afterwards.
                payload["stream"] = false;
                log_line("request_id=" + request_id +
                    " stream_requested=true mode=buffered_sse_fallback");
            }
            if (completion_compat_mode) {
                log_line("request_id=" + request_id + " completion_compat_mode=true");
            }
            // Token estimation gates both validation and rate limiting.
            std::string token_error;
            auto estimate = estimate_chat_tokens(payload, config.limits, token_error);
            if (!estimate) {
                return json_response(http::status::bad_request, {{"error", token_error}});
            }
            log_line("request_id=" + request_id +
                " prompt_tokens=" + std::to_string(estimate->prompt_tokens) +
                " max_tokens=" + std::to_string(estimate->requested_max_tokens) +
                " estimated_total_tokens=" + std::to_string(estimate->estimated_total_tokens));
            auto rate_limit_decision = rate_limiter.allow(authenticated->key_id, estimate->estimated_total_tokens);
            if (!rate_limit_decision.allowed) {
                metrics.inc_rate_limited_total();
                return json_response_with_retry_after(
                    http::status::too_many_requests,
                    {{"error", rate_limit_decision.error}},
                    rate_limit_decision.retry_after_sec);
            }
            const std::string upstream_request_body = payload.dump();
            auto ctx = registry.create(request_id, *authenticated, *estimate, upstream_request_body);
            if (!scheduler.try_enqueue(ctx)) {
                // Queue full: mark the context cancelled/complete so the
                // registry doesn't leak it, then 503 with Retry-After.
                ctx->cancelled.store(true);
                registry.complete(ctx, RequestState::CANCELLED, {503, R"({"error":"Queue full"})"});
                metrics.inc_queue_rejected_total();
                return json_response_with_retry_after(
                    http::status::service_unavailable,
                    {{"error", "Queue full"}},
                    scheduler.retry_after_sec());
            }
            // Block this handler thread until the scheduler completes the
            // request or the configured timeout elapses.
            std::unique_lock<std::mutex> lock(ctx->mu);
            const bool finished = ctx->cv.wait_for(
                lock,
                std::chrono::seconds(std::max(1, config.limits.request_timeout_sec)),
                [&]() { return ctx->completed; });
            if (!finished) {
                // Timeout: cancel, and restart the worker if it was already
                // generating so the abort takes effect.
                lock.unlock();
                registry.cancel_request(request_id);
                metrics.add_cancellations_total();
                std::string restart_error;
                bool restarted = true;
                if (ctx->state.load() == RequestState::RUNNING) {
                    restarted = manager.restart_active(restart_error);
                    if (restarted) metrics.inc_worker_restarts_total();
                }
                json timeout_payload = {
                    {"error", "Request timed out"},
                    {"state", state_to_string(ctx->state.load())}
                };
                if (!restarted) timeout_payload["restart_error"] = restart_error;
                return json_response(http::status::gateway_timeout, timeout_payload);
            }
            // Copy the result out while still holding the lock.
            const auto final_state = ctx->state.load();
            RequestResult result = ctx->result;
            lock.unlock();
            if (final_state == RequestState::CANCELLED) {
                return json_response(http::status::ok, {{"status", "cancelled"}});
            }
            http::response<http::string_body> res{
                static_cast<http::status>(result.status), req.version()};
            // Successful upstream bodies may need reshaping: SSE synthesis
            // for stream requests, llama.cpp shape for /completion.
            if (stream_requested && result.status >= 200 && result.status < 300) {
                result.body = build_buffered_stream_response(result.body);
                result.content_type = "text/event-stream; charset=utf-8";
            } else if (completion_compat_mode && result.status >= 200 && result.status < 300) {
                result.body = build_completion_compat_response(result.body);
                result.content_type = "application/json";
            }
            res.set(http::field::content_type, result.content_type);
            res.set(http::field::server, "llm-manager");
            res.set("X-Request-Id", request_id);
            res.keep_alive(req.keep_alive());
            res.body() = result.body;
            res.prepare_payload();
            auto elapsed_ms = std::chrono::duration_cast<std::chrono::milliseconds>(
                std::chrono::steady_clock::now() - start).count();
            metrics.observe_request_latency_ms(elapsed_ms);
            log_line("request_id=" + request_id +
                " final_state=" + state_to_string(final_state) +
                " upstream_status=" + std::to_string(result.status) +
                " elapsed_ms=" + std::to_string(elapsed_ms));
            if (result.status >= 400) {
                log_line("request_id=" + request_id +
                    " upstream_error_body=" + truncate_body(result.body));
            }
            return res;
        }
        // Any other GET is transparently proxied to the active worker,
        // preserving content type and encoding.
        if (req.method() == http::verb::get) {
            auto worker = manager.active_worker();
            if (!worker) {
                return json_response(http::status::service_unavailable, {{"error", "No active model"}});
            }
            auto upstream = forward_get_to_worker(*worker, target);
            http::response<http::string_body> res{
                static_cast<http::status>(upstream.status), req.version()};
            res.set(http::field::content_type, upstream.content_type);
            if (!upstream.content_encoding.empty()) {
                res.set(http::field::content_encoding, upstream.content_encoding);
            }
            res.set(http::field::server, "llm-manager");
            res.set("X-Request-Id", request_id);
            res.keep_alive(req.keep_alive());
            res.body() = upstream.body;
            res.prepare_payload();
            auto elapsed_ms = std::chrono::duration_cast<std::chrono::milliseconds>(
                std::chrono::steady_clock::now() - start).count();
            log_line("request_id=" + request_id +
                " proxied_get model=" + worker->model +
                " upstream_status=" + std::to_string(upstream.status) +
                " elapsed_ms=" + std::to_string(elapsed_ms));
            if (upstream.status >= 400) {
                log_line("request_id=" + request_id +
                    " proxied_get_error_body=" + truncate_body(upstream.body));
            }
            return res;
        }
        return json_response(http::status::not_found, {{"error", "Not found"}});
    } catch (const std::exception &e) {
        log_line("request_id=" + request_id + " exception=" + std::string(e.what()));
        return json_response(http::status::internal_server_error, {{"error", e.what()}});
    } catch (...) {
        log_line("request_id=" + request_id + " exception=unknown");
        return json_response(http::status::internal_server_error, {{"error", "Unknown exception"}});
    }
}
| void do_session( | |
| boost::asio::ip::tcp::socket socket, | |
| ModelManager &manager, | |
| const ManagerConfig &config, | |
| const ApiKeyAuth &auth, | |
| RateLimiterStore &rate_limiter, | |
| RequestRegistry ®istry, | |
| MetricsRegistry &metrics, | |
| Scheduler &scheduler) { | |
| try { | |
| beast::flat_buffer buffer; | |
| http::request<http::string_body> req; | |
| http::read(socket, buffer, req); | |
| auto res = handle_request(manager, config, auth, rate_limiter, registry, metrics, scheduler, std::move(req)); | |
| http::write(socket, res); | |
| beast::error_code ec; | |
| socket.shutdown(boost::asio::ip::tcp::socket::shutdown_send, ec); | |
| } catch (const std::exception &e) { | |
| log_line("session_exception=" + std::string(e.what())); | |
| } catch (...) { | |
| log_line("session_exception=unknown"); | |
| } | |
| } | |