Dmitry Beresnev committed on
Commit
d9ce859
·
1 Parent(s): 8ef326a

Add auth, token policy, queue scheduler, cancel flow, and related changes

Browse files
Files changed (3) hide show
  1. .gitignore +3 -0
  2. config.toml.example +52 -0
  3. cpp/llm_manager.cpp +1031 -35
.gitignore CHANGED
@@ -131,3 +131,6 @@ temp/
131
  tests/
132
  *.md
133
  docs/
 
 
 
 
131
  tests/
132
  *.md
133
  docs/
134
+
135
+ #
136
+ .clang-format
config.toml.example ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [server]
2
+ host = "0.0.0.0"
3
+ port = 7860
4
+
5
+ [worker]
6
+ default_model = "QuantFactory/Qwen2.5-7B-Instruct-GGUF:q4_k_m"
7
+ llama_server_bin = "/usr/local/bin/llama-server"
8
+ host = "127.0.0.1"
9
+ bind_host = "0.0.0.0"
10
+ base_port = 8080
11
+ switch_timeout_sec = 300
12
+
13
+ [llama]
14
+ n_ctx = 8192
15
+ threads = 4
16
+ ngl = 0
17
+ batch = 128
18
+ ubatch = 64
19
+
20
+ [auth]
21
+ header = "Authorization"
22
+ scheme = "Bearer"
23
+
24
+ [limits]
25
+ default_max_tokens = 256
26
+ max_tokens_per_request = 2048
27
+ request_timeout_sec = 30
28
+
29
+ [queue]
30
+ max_size = 100
31
+ max_tokens = 20000
32
+ admin_quota = 3
33
+ retry_after_sec = 5
34
+
35
+ [scheduler]
36
+ max_concurrent = 1
37
+
38
+ [rate_limit]
39
+ requests_per_minute = 60
40
+ estimated_tokens_per_minute = 6000
41
+
42
+ [[api_keys]]
43
+ key_id = "admin-main"
44
+ secret = "change-me-admin"
45
+ role = "admin"
46
+ enabled = true
47
+
48
+ [[api_keys]]
49
+ key_id = "user-main"
50
+ secret = "change-me-user"
51
+ role = "user"
52
+ enabled = true
cpp/llm_manager.cpp CHANGED
@@ -4,18 +4,25 @@
4
  #include <boost/beast/version.hpp>
5
  #include <nlohmann/json.hpp>
6
 
 
7
  #include <atomic>
8
  #include <chrono>
 
9
  #include <csignal>
10
  #include <cstdlib>
11
  #include <ctime>
 
 
 
12
  #include <iomanip>
13
  #include <iostream>
 
14
  #include <mutex>
15
  #include <optional>
16
  #include <sstream>
17
  #include <string>
18
  #include <thread>
 
19
  #include <vector>
20
 
21
  #include <sys/types.h>
@@ -27,6 +34,81 @@ namespace beast = boost::beast;
27
  namespace http = beast::http;
28
  using json = nlohmann::json;
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  struct WorkerInfo {
31
  std::string model;
32
  int port = 0;
@@ -34,6 +116,8 @@ struct WorkerInfo {
34
  std::string last_loaded;
35
  };
36
 
 
 
37
  static std::string now_utc_iso() {
38
  std::time_t t = std::time(nullptr);
39
  std::tm tm{};
@@ -58,6 +142,399 @@ static int get_env_int_or(const char *name, int fallback) {
58
  }
59
  }
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  static bool is_alive(pid_t pid) {
62
  if (pid <= 0) return false;
63
  return kill(pid, 0) == 0;
@@ -80,18 +557,18 @@ static void shutdown_worker(pid_t pid, int wait_seconds = 15) {
80
 
81
  class ModelManager {
82
  public:
83
- ModelManager()
84
- : _default_model(get_env_or("DEFAULT_MODEL", "QuantFactory/Qwen2.5-7B-Instruct-GGUF:q4_k_m")),
85
- _llama_server_bin(get_env_or("LLAMA_SERVER_BIN", "/usr/local/bin/llama-server")),
86
- _worker_host(get_env_or("WORKER_HOST", "127.0.0.1")),
87
- _worker_bind_host(get_env_or("WORKER_BIND_HOST", "0.0.0.0")),
88
- _base_port(get_env_int_or("WORKER_BASE_PORT", 8080)),
89
- _switch_timeout_sec(get_env_int_or("SWITCH_TIMEOUT_SEC", 300)),
90
- _n_ctx(get_env_int_or("MODEL_N_CTX", 8192)),
91
- _n_threads(get_env_int_or("MODEL_THREADS", 4)),
92
- _n_gpu_layers(get_env_int_or("MODEL_NGL", 0)),
93
- _n_batch(get_env_int_or("MODEL_BATCH", 128)),
94
- _n_ubatch(get_env_int_or("MODEL_UBATCH", 64)),
95
  _next_port(_base_port) {}
96
 
97
  bool initialize_default(std::string &error) {
@@ -318,6 +795,364 @@ private:
318
 
319
  static std::atomic<uint64_t> g_req_id{1};
320
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
321
  static void log_line(const std::string &line) {
322
  std::cout << "[" << now_utc_iso() << "] " << line << std::endl;
323
  }
@@ -327,6 +1162,17 @@ static std::string truncate_body(const std::string &body, size_t max_len = 2000)
327
  return body.substr(0, max_len) + "...[truncated]";
328
  }
329
 
 
 
 
 
 
 
 
 
 
 
 
330
  static std::pair<int, std::string> forward_chat(const WorkerInfo &worker, const std::string &body) {
331
  asio::io_context ioc;
332
  asio::ip::tcp::resolver resolver(ioc);
@@ -391,34 +1237,65 @@ static ProxiedGetResult forward_get_to_worker(const WorkerInfo &worker,
391
  template <typename Body, typename Allocator>
392
  http::response<http::string_body> handle_request(
393
  ModelManager &manager,
 
 
 
 
 
394
  http::request<Body, http::basic_fields<Allocator>> &&req) {
395
  const auto start = std::chrono::steady_clock::now();
396
- const auto req_id = g_req_id.fetch_add(1);
 
397
  const std::string target = req.target().to_string();
398
  const std::string method = req.method_string().to_string();
399
  const std::string path = target.substr(0, target.find('?'));
 
400
 
401
- log_line("request_id=" + std::to_string(req_id) + " method=" + method + " path=" + target);
402
  if constexpr (std::is_same_v<Body, http::string_body>) {
403
  if (!req.body().empty()) {
404
- log_line("request_id=" + std::to_string(req_id) + " body=" + truncate_body(req.body()));
405
  }
406
  }
407
 
408
  auto json_response = [&](http::status status, const json &obj) {
 
 
409
  http::response<http::string_body> res{status, req.version()};
410
  res.set(http::field::content_type, "application/json");
411
  res.set(http::field::server, "llm-manager");
 
412
  res.keep_alive(req.keep_alive());
413
- res.body() = obj.dump();
414
  res.prepare_payload();
415
  auto elapsed_ms = std::chrono::duration_cast<std::chrono::milliseconds>(
416
  std::chrono::steady_clock::now() - start).count();
417
- log_line("request_id=" + std::to_string(req_id) + " status=" + std::to_string(res.result_int()) +
418
  " elapsed_ms=" + std::to_string(elapsed_ms));
419
  return res;
420
  };
421
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
422
  try {
423
  if (path == "/health" && req.method() == http::verb::get) {
424
  return json_response(http::status::ok, manager.models_view());
@@ -429,6 +1306,7 @@ http::response<http::string_body> handle_request(
429
  }
430
 
431
  if (path == "/switch-model" && req.method() == http::verb::post) {
 
432
  std::string body(req.body().data(), req.body().size());
433
  json j = json::parse(body, nullptr, false);
434
  if (j.is_discarded()) {
@@ -453,6 +1331,8 @@ http::response<http::string_body> handle_request(
453
  }
454
 
455
  if (path == "/stop" && req.method() == http::verb::post) {
 
 
456
  std::string err;
457
  bool ok = manager.restart_active(err);
458
  if (!ok) {
@@ -469,24 +1349,114 @@ http::response<http::string_body> handle_request(
469
  return json_response(http::status::ok, state);
470
  }
471
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
472
  if (path == "/v1/chat/completions" && req.method() == http::verb::post) {
473
- auto worker = manager.active_worker();
474
- if (!worker) {
475
- return json_response(http::status::service_unavailable, {{"error", "No active model"}});
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
476
  }
477
- auto [upstream_status, upstream_body] = forward_chat(*worker, req.body());
478
- http::response<http::string_body> res{static_cast<http::status>(upstream_status), req.version()};
479
- res.set(http::field::content_type, "application/json");
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
480
  res.set(http::field::server, "llm-manager");
 
481
  res.keep_alive(req.keep_alive());
482
- res.body() = upstream_body;
483
  res.prepare_payload();
484
  auto elapsed_ms = std::chrono::duration_cast<std::chrono::milliseconds>(
485
  std::chrono::steady_clock::now() - start).count();
486
- log_line("request_id=" + std::to_string(req_id) + " model=" + worker->model +
487
- " active_pid=" + std::to_string(worker->pid) +
488
- " active_port=" + std::to_string(worker->port) +
489
- " upstream_status=" + std::to_string(upstream_status) +
490
  " elapsed_ms=" + std::to_string(elapsed_ms));
491
  return res;
492
  }
@@ -506,13 +1476,14 @@ http::response<http::string_body> handle_request(
506
  res.set(http::field::content_encoding, upstream.content_encoding);
507
  }
508
  res.set(http::field::server, "llm-manager");
 
509
  res.keep_alive(req.keep_alive());
510
  res.body() = upstream.body;
511
  res.prepare_payload();
512
  auto elapsed_ms = std::chrono::duration_cast<std::chrono::milliseconds>(
513
  std::chrono::steady_clock::now() - start)
514
  .count();
515
- log_line("request_id=" + std::to_string(req_id) +
516
  " proxied_get model=" + worker->model +
517
  " upstream_status=" + std::to_string(upstream.status) +
518
  " elapsed_ms=" + std::to_string(elapsed_ms));
@@ -525,12 +1496,19 @@ http::response<http::string_body> handle_request(
525
  }
526
  }
527
 
528
- void do_session(asio::ip::tcp::socket socket, ModelManager &manager) {
 
 
 
 
 
 
 
529
  try {
530
  beast::flat_buffer buffer;
531
  http::request<http::string_body> req;
532
  http::read(socket, buffer, req);
533
- auto res = handle_request(manager, std::move(req));
534
  http::write(socket, res);
535
  beast::error_code ec;
536
  socket.shutdown(asio::ip::tcp::socket::shutdown_send, ec);
@@ -539,11 +1517,21 @@ void do_session(asio::ip::tcp::socket socket, ModelManager &manager) {
539
  }
540
 
541
  int main() {
542
- const auto bind_host = get_env_or("MANAGER_HOST", "0.0.0.0");
543
- const int bind_port = get_env_int_or("MANAGER_PORT", 7860);
544
- ModelManager manager;
 
 
 
 
 
545
 
546
  std::string init_error;
 
 
 
 
 
547
  log_line("startup: loading default model");
548
  if (!manager.initialize_default(init_error)) {
549
  log_line("startup: default model failed: " + init_error);
@@ -558,6 +1546,14 @@ int main() {
558
  for (;;) {
559
  asio::ip::tcp::socket socket{ioc};
560
  acceptor.accept(socket);
561
- std::thread(&do_session, std::move(socket), std::ref(manager)).detach();
 
 
 
 
 
 
 
 
562
  }
563
  }
 
4
  #include <boost/beast/version.hpp>
5
  #include <nlohmann/json.hpp>
6
 
7
+ #include <algorithm>
8
  #include <atomic>
9
  #include <chrono>
10
+ #include <condition_variable>
11
  #include <csignal>
12
  #include <cstdlib>
13
  #include <ctime>
14
+ #include <deque>
15
+ #include <filesystem>
16
+ #include <fstream>
17
  #include <iomanip>
18
  #include <iostream>
19
+ #include <memory>
20
  #include <mutex>
21
  #include <optional>
22
  #include <sstream>
23
  #include <string>
24
  #include <thread>
25
+ #include <unordered_map>
26
  #include <vector>
27
 
28
  #include <sys/types.h>
 
34
  namespace http = beast::http;
35
  using json = nlohmann::json;
36
 
37
// Forward declaration; log_line is defined later in this file but is needed
// by config loading below.
static void log_line(const std::string &line);

// [server] section: where the manager's own HTTP endpoint listens.
struct ServerConfig {
    std::string host = "0.0.0.0";
    int port = 7860;
};

// [worker] section: how llama-server worker processes are launched and reached.
struct WorkerConfig {
    std::string default_model = "QuantFactory/Qwen2.5-7B-Instruct-GGUF:q4_k_m";
    std::string llama_server_bin = "/usr/local/bin/llama-server";
    std::string host = "127.0.0.1";      // address the manager dials to reach a worker
    std::string bind_host = "0.0.0.0";   // address a worker binds to
    int base_port = 8080;                // first port assigned to workers
    int switch_timeout_sec = 300;        // seconds allowed for a model switch
};

// [llama] section: llama-server tuning knobs passed through on launch.
struct LlamaConfig {
    int n_ctx = 8192;
    int threads = 4;
    int ngl = 0;       // GPU layers; 0 presumably means CPU-only — confirm against launch args
    int batch = 128;
    int ubatch = 64;
};

// Access level attached to an API key.
enum class Role {
    ADMIN,
    USER
};

// [auth] section: which HTTP header and scheme carry credentials.
struct AuthConfig {
    std::string header = "Authorization";
    std::string scheme = "Bearer";
};

// [limits] section: per-request ceilings.
struct LimitsConfig {
    int default_max_tokens = 256;        // used when a request omits max_tokens (see estimate_chat_tokens)
    int max_tokens_per_request = 2048;   // cap on estimated prompt + completion tokens
    int request_timeout_sec = 30;
};

// [queue] section: admission control for queued requests.
struct QueueConfig {
    size_t max_size = 100;
    int max_tokens = 20000;
    int admin_quota = 3;       // NOTE(review): looks like reserved admin capacity — confirm against scheduler use
    int retry_after_sec = 5;   // Retry-After hint when the queue rejects a request
};

// [rate_limit] section: per-key limits (consumed by RateLimiterStore).
struct RateLimitConfig {
    int requests_per_minute = 60;
    int estimated_tokens_per_minute = 6000;
};

// [scheduler] section.
struct SchedulerConfig {
    int max_concurrent = 1;
};

// One [[api_keys]] entry from the config file.
struct ApiKeyRecord {
    std::string key_id;
    std::string secret;
    Role role = Role::USER;
    bool enabled = true;
};

// Aggregate of every config section; produced by load_manager_config().
struct ManagerConfig {
    ServerConfig server;
    WorkerConfig worker;
    LlamaConfig llama;
    AuthConfig auth;
    LimitsConfig limits;
    QueueConfig queue;
    RateLimitConfig rate_limit;
    SchedulerConfig scheduler;
    std::vector<ApiKeyRecord> api_keys;
};
111
+
112
  struct WorkerInfo {
113
  std::string model;
114
  int port = 0;
 
116
  std::string last_loaded;
117
  };
118
 
119
+ static std::pair<int, std::string> forward_chat(const WorkerInfo &worker, const std::string &body);
120
+
121
  static std::string now_utc_iso() {
122
  std::time_t t = std::time(nullptr);
123
  std::tm tm{};
 
142
  }
143
  }
144
 
145
// Return `value` with leading and trailing spaces, tabs, CR and LF removed.
// An all-whitespace (or empty) input yields "".
static std::string trim_copy(const std::string &value) {
    static const char *const kWhitespace = " \t\r\n";
    const auto begin = value.find_first_not_of(kWhitespace);
    if (begin == std::string::npos) {
        return "";
    }
    const auto end = value.find_last_not_of(kWhitespace);
    return value.substr(begin, end - begin + 1);
}
151
+
152
// If `value` is wrapped in a matching pair of single or double quotes,
// return the inner text; otherwise return `value` unchanged.
static std::string strip_quotes(const std::string &value) {
    if (value.size() < 2) {
        return value;
    }
    const char head = value.front();
    const bool quoted = (head == '"' || head == '\'') && value.back() == head;
    return quoted ? value.substr(1, value.size() - 2) : value;
}
162
+
163
// Parse a TOML boolean ("true"/"false", surrounding whitespace ignored);
// any other text yields `fallback`. Trimming is done inline so the function
// stands alone.
static bool parse_bool_or(const std::string &value, bool fallback) {
    const auto first = value.find_first_not_of(" \t\r\n");
    if (first == std::string::npos) {
        return fallback;   // empty / all-whitespace input
    }
    const auto last = value.find_last_not_of(" \t\r\n");
    const std::string token = value.substr(first, last - first + 1);
    if (token == "true") {
        return true;
    }
    return token == "false" ? false : fallback;
}
169
+
170
+ static Role parse_role_or(const std::string &value, Role fallback) {
171
+ const std::string normalized = trim_copy(value);
172
+ if (normalized == "admin" || normalized == "ADMIN") return Role::ADMIN;
173
+ if (normalized == "user" || normalized == "USER") return Role::USER;
174
+ return fallback;
175
+ }
176
+
177
+ static std::string role_to_string(Role role) {
178
+ return role == Role::ADMIN ? "admin" : "user";
179
+ }
180
+
181
// Scheduling priority derived from Role.
// NOTE(review): ADMIN=0 appears intended to sort ahead of USER=1 in the
// queue — confirm against the scheduler's comparator.
enum class Priority {
    ADMIN = 0,
    USER = 1
};
185
+
186
+ static Priority role_to_priority(Role role) {
187
+ return role == Role::ADMIN ? Priority::ADMIN : Priority::USER;
188
+ }
189
+
190
// Lifecycle of a queued chat request.
enum class RequestState {
    QUEUED,     // accepted, waiting to be scheduled
    RUNNING,    // currently executing against the worker
    CANCELLED,  // cancelled by the client (see RequestRegistry::cancel_request)
    FAILED,
    DONE
};
197
+
198
+ static std::string state_to_string(RequestState state) {
199
+ switch (state) {
200
+ case RequestState::QUEUED: return "queued";
201
+ case RequestState::RUNNING: return "running";
202
+ case RequestState::CANCELLED: return "cancelled";
203
+ case RequestState::FAILED: return "failed";
204
+ case RequestState::DONE: return "done";
205
+ }
206
+ return "unknown";
207
+ }
208
+
209
// Rough token accounting for one chat request (filled by estimate_chat_tokens).
struct TokenEstimate {
    int prompt_tokens = 0;
    int requested_max_tokens = 0;
    int estimated_total_tokens = 0;   // prompt_tokens + requested_max_tokens
};

// Outcome of a rate-limit check (see RateLimiterStore::allow).
struct RateLimitDecision {
    bool allowed = true;
    int retry_after_sec = 0;   // retry hint when !allowed
    std::string error;
};

// Final result stored on a RequestContext when it completes.
struct RequestResult {
    int status = 500;
    std::string body;
    std::string content_type = "application/json";
};

// Shared state for one in-flight request, handed around as a shared_ptr
// (see RequestRegistry). `mu`/`cv`/`completed` implement the wait-for-result
// handshake; `state` and `cancelled` are atomics so they can be flipped from
// another thread without taking `mu`.
struct RequestContext {
    std::string request_id;
    std::string api_key_id;
    Role role = Role::USER;
    Priority priority = Priority::USER;
    TokenEstimate estimate;
    std::string request_body;
    std::atomic<RequestState> state{RequestState::QUEUED};
    std::atomic<bool> cancelled{false};
    std::chrono::steady_clock::time_point created_at{std::chrono::steady_clock::now()};
    std::chrono::steady_clock::time_point enqueue_time{created_at};
    std::chrono::steady_clock::time_point start_time{};
    std::mutex mu;                 // guards `completed` and `result`
    std::condition_variable cv;    // signalled when `completed` flips true
    bool completed = false;
    RequestResult result;          // meaningful only once completed == true
};
244
+
245
// Crude token estimate: roughly 4 characters per token, rounded up, with a
// floor of 1 for any non-empty text. Empty text estimates to 0.
static int estimate_text_tokens_rough(const std::string &text) {
    if (text.empty()) {
        return 0;
    }
    const int approx = static_cast<int>((text.size() + 3) / 4);
    return approx < 1 ? 1 : approx;
}
249
+
250
+ static std::string flatten_json_content(const json &content) {
251
+ if (content.is_string()) {
252
+ return content.get<std::string>();
253
+ }
254
+ if (content.is_array()) {
255
+ std::ostringstream oss;
256
+ bool first = true;
257
+ for (const auto &item : content) {
258
+ std::string part;
259
+ if (item.is_string()) {
260
+ part = item.get<std::string>();
261
+ } else if (item.is_object() && item.contains("text") && item["text"].is_string()) {
262
+ part = item["text"].get<std::string>();
263
+ }
264
+ if (part.empty()) continue;
265
+ if (!first) oss << '\n';
266
+ oss << part;
267
+ first = false;
268
+ }
269
+ return oss.str();
270
+ }
271
+ return "";
272
+ }
273
+
274
// Validate a chat-completions payload and compute a rough token estimate.
// On success returns the estimate with `error` cleared; on any validation
// failure returns nullopt with a human-readable message in `error`.
// Checks, in order: payload is an object, has a 'messages' array,
// 'max_tokens' (if present) is a positive integer, and the estimated total
// does not exceed limits.max_tokens_per_request.
static std::optional<TokenEstimate> estimate_chat_tokens(
    const json &payload,
    const LimitsConfig &limits,
    std::string &error) {
    if (!payload.is_object()) {
        error = "Expected JSON object";
        return std::nullopt;
    }
    if (!payload.contains("messages") || !payload["messages"].is_array()) {
        error = "Expected 'messages' array";
        return std::nullopt;
    }

    TokenEstimate estimate;
    // Absent max_tokens falls back to the configured default.
    estimate.requested_max_tokens = limits.default_max_tokens;
    if (payload.contains("max_tokens")) {
        if (!payload["max_tokens"].is_number_integer()) {
            error = "Expected integer 'max_tokens'";
            return std::nullopt;
        }
        estimate.requested_max_tokens = payload["max_tokens"].get<int>();
    }

    if (estimate.requested_max_tokens <= 0) {
        error = "'max_tokens' must be > 0";
        return std::nullopt;
    }

    // Sum rough estimates over role + content of every message; the flat +4
    // per message approximates chat-format framing overhead.
    for (const auto &message : payload["messages"]) {
        if (!message.is_object()) continue;
        if (message.contains("role") && message["role"].is_string()) {
            estimate.prompt_tokens += estimate_text_tokens_rough(message["role"].get<std::string>());
        }
        if (message.contains("content")) {
            estimate.prompt_tokens += estimate_text_tokens_rough(
                flatten_json_content(message["content"]));
        }
        estimate.prompt_tokens += 4;
    }

    estimate.estimated_total_tokens = estimate.prompt_tokens + estimate.requested_max_tokens;
    if (estimate.estimated_total_tokens > limits.max_tokens_per_request) {
        error = "Estimated request tokens exceed configured limit";
        return std::nullopt;
    }

    error.clear();
    return estimate;
}
323
+
324
// Minimal TOML reader: returns section -> (key -> value) for flat
// `key = value` files. '#' starts a comment anywhere on a line — including
// inside quoted values — so this parser is intentionally simple. An
// unreadable file yields an empty map. Trimming/unquoting are done with
// local lambdas so the function is self-contained.
static std::unordered_map<std::string, std::unordered_map<std::string, std::string>>
parse_simple_toml(const std::string &path) {
    const auto trim = [](const std::string &s) -> std::string {
        const auto b = s.find_first_not_of(" \t\r\n");
        if (b == std::string::npos) return "";
        const auto e = s.find_last_not_of(" \t\r\n");
        return s.substr(b, e - b + 1);
    };
    const auto unquote = [](const std::string &s) -> std::string {
        if (s.size() >= 2 &&
            ((s.front() == '"' && s.back() == '"') ||
             (s.front() == '\'' && s.back() == '\''))) {
            return s.substr(1, s.size() - 2);
        }
        return s;
    };

    std::unordered_map<std::string, std::unordered_map<std::string, std::string>> result;
    std::ifstream file(path);
    if (!file.is_open()) return result;

    std::string section;
    for (std::string raw; std::getline(file, raw);) {
        const auto comment = raw.find('#');
        if (comment != std::string::npos) raw.erase(comment);
        const std::string text = trim(raw);
        if (text.empty()) continue;

        // [section] header
        if (text.front() == '[' && text.back() == ']') {
            section = trim(text.substr(1, text.size() - 2));
            continue;
        }

        // key = value (lines without '=' are silently skipped)
        const auto eq = text.find('=');
        if (eq == std::string::npos) continue;
        result[section][trim(text.substr(0, eq))] = unquote(trim(text.substr(eq + 1)));
    }
    return result;
}
352
+
353
// Second pass over the config file that extracts [[api_keys]] array-of-table
// entries, which parse_simple_toml cannot represent. Entries missing key_id
// or secret are dropped. Returns an empty list if the file cannot be read.
static std::vector<ApiKeyRecord> parse_api_keys_toml(const std::string &path) {
    std::vector<ApiKeyRecord> keys;
    std::ifstream input(path);
    if (!input.is_open()) return keys;

    std::string line;
    bool in_api_key = false;    // true while inside a [[api_keys]] table
    ApiKeyRecord current;
    bool has_any_field = false; // current entry has seen at least one key = value

    // Commit `current` if complete, then reset for the next entry.
    auto flush_current = [&]() {
        if (has_any_field && !current.key_id.empty() && !current.secret.empty()) {
            keys.push_back(current);
        }
        current = ApiKeyRecord{};
        has_any_field = false;
    };

    while (std::getline(input, line)) {
        // Strip comments ('#' to end of line; note this also fires inside
        // quoted values — same limitation as parse_simple_toml).
        auto hash = line.find('#');
        if (hash != std::string::npos) line = line.substr(0, hash);
        line = trim_copy(line);
        if (line.empty()) continue;

        if (line == "[[api_keys]]") {
            flush_current();
            in_api_key = true;
            continue;
        }

        if (!in_api_key) continue;

        // Any other section header terminates the current [[api_keys]] run.
        if (line.front() == '[' && line.back() == ']') {
            flush_current();
            in_api_key = false;
            continue;
        }

        const auto eq = line.find('=');
        if (eq == std::string::npos) continue;

        std::string key = trim_copy(line.substr(0, eq));
        std::string value = strip_quotes(trim_copy(line.substr(eq + 1)));
        has_any_field = true;

        // Unknown keys are ignored; role/enabled keep their current value on
        // unparseable input.
        if (key == "key_id") current.key_id = value;
        else if (key == "secret") current.secret = value;
        else if (key == "role") current.role = parse_role_or(value, current.role);
        else if (key == "enabled") current.enabled = parse_bool_or(value, current.enabled);
    }

    // Flush the trailing entry at EOF.
    flush_current();
    return keys;
}
407
+
408
// Look up data[section][key]; returns `fallback` when the section or key is
// absent, or when the stored value is an empty string.
static std::string get_toml_string_or(
    const std::unordered_map<std::string, std::unordered_map<std::string, std::string>> &data,
    const std::string &section,
    const std::string &key,
    const std::string &fallback) {
    const auto section_it = data.find(section);
    if (section_it == data.end()) return fallback;
    const auto &values = section_it->second;
    const auto value_it = values.find(key);
    const bool missing = (value_it == values.end()) || value_it->second.empty();
    return missing ? fallback : value_it->second;
}
419
+
420
// Integer variant of get_toml_string_or: returns `fallback` when the entry
// is missing, empty, or rejected by std::stoi. Note std::stoi accepts a
// numeric prefix ("12abc" parses as 12).
static int get_toml_int_or(
    const std::unordered_map<std::string, std::unordered_map<std::string, std::string>> &data,
    const std::string &section,
    const std::string &key,
    int fallback) {
    const auto section_it = data.find(section);
    if (section_it == data.end()) return fallback;
    const auto &values = section_it->second;
    const auto value_it = values.find(key);
    if (value_it == values.end() || value_it->second.empty()) return fallback;
    try {
        return std::stoi(value_it->second);
    } catch (...) {
        return fallback;   // non-numeric or out-of-range value
    }
}
435
+
436
// Build the effective ManagerConfig. Precedence per setting:
//   environment variable > config-file value > compiled-in default.
// The file path comes from MANAGER_CONFIG (default "config.toml"); a missing
// file is not an error — env vars and defaults still apply.
static ManagerConfig load_manager_config() {
    ManagerConfig cfg;

    const std::string config_path = get_env_or("MANAGER_CONFIG", "config.toml");
    std::unordered_map<std::string, std::unordered_map<std::string, std::string>> toml;
    if (std::filesystem::exists(config_path)) {
        toml = parse_simple_toml(config_path);
        log_line("config: loaded " + config_path);
    } else {
        log_line("config: using environment/defaults (file not found: " + config_path + ")");
    }

    // [server]
    cfg.server.host = get_env_or(
        "MANAGER_HOST",
        get_toml_string_or(toml, "server", "host", cfg.server.host));
    cfg.server.port = get_env_int_or(
        "MANAGER_PORT",
        get_toml_int_or(toml, "server", "port", cfg.server.port));

    // [worker]
    cfg.worker.default_model = get_env_or(
        "DEFAULT_MODEL",
        get_toml_string_or(toml, "worker", "default_model", cfg.worker.default_model));
    cfg.worker.llama_server_bin = get_env_or(
        "LLAMA_SERVER_BIN",
        get_toml_string_or(toml, "worker", "llama_server_bin", cfg.worker.llama_server_bin));
    cfg.worker.host = get_env_or(
        "WORKER_HOST",
        get_toml_string_or(toml, "worker", "host", cfg.worker.host));
    cfg.worker.bind_host = get_env_or(
        "WORKER_BIND_HOST",
        get_toml_string_or(toml, "worker", "bind_host", cfg.worker.bind_host));
    cfg.worker.base_port = get_env_int_or(
        "WORKER_BASE_PORT",
        get_toml_int_or(toml, "worker", "base_port", cfg.worker.base_port));
    cfg.worker.switch_timeout_sec = get_env_int_or(
        "SWITCH_TIMEOUT_SEC",
        get_toml_int_or(toml, "worker", "switch_timeout_sec", cfg.worker.switch_timeout_sec));

    // [llama]
    cfg.llama.n_ctx = get_env_int_or(
        "MODEL_N_CTX",
        get_toml_int_or(toml, "llama", "n_ctx", cfg.llama.n_ctx));
    cfg.llama.threads = get_env_int_or(
        "MODEL_THREADS",
        get_toml_int_or(toml, "llama", "threads", cfg.llama.threads));
    cfg.llama.ngl = get_env_int_or(
        "MODEL_NGL",
        get_toml_int_or(toml, "llama", "ngl", cfg.llama.ngl));
    cfg.llama.batch = get_env_int_or(
        "MODEL_BATCH",
        get_toml_int_or(toml, "llama", "batch", cfg.llama.batch));
    cfg.llama.ubatch = get_env_int_or(
        "MODEL_UBATCH",
        get_toml_int_or(toml, "llama", "ubatch", cfg.llama.ubatch));

    // [auth]
    cfg.auth.header = get_env_or(
        "AUTH_HEADER",
        get_toml_string_or(toml, "auth", "header", cfg.auth.header));
    cfg.auth.scheme = get_env_or(
        "AUTH_SCHEME",
        get_toml_string_or(toml, "auth", "scheme", cfg.auth.scheme));

    // [limits]
    cfg.limits.default_max_tokens = get_env_int_or(
        "DEFAULT_MAX_TOKENS",
        get_toml_int_or(toml, "limits", "default_max_tokens", cfg.limits.default_max_tokens));
    cfg.limits.max_tokens_per_request = get_env_int_or(
        "MAX_TOKENS_PER_REQUEST",
        get_toml_int_or(toml, "limits", "max_tokens_per_request", cfg.limits.max_tokens_per_request));
    cfg.limits.request_timeout_sec = get_env_int_or(
        "REQUEST_TIMEOUT_SEC",
        get_toml_int_or(toml, "limits", "request_timeout_sec", cfg.limits.request_timeout_sec));

    // [queue] — max_size is clamped to at least 1.
    cfg.queue.max_size = static_cast<size_t>(std::max(
        1,
        get_env_int_or("QUEUE_MAX_SIZE", get_toml_int_or(toml, "queue", "max_size", static_cast<int>(cfg.queue.max_size)))));
    cfg.queue.max_tokens = get_env_int_or(
        "QUEUE_MAX_TOKENS",
        get_toml_int_or(toml, "queue", "max_tokens", cfg.queue.max_tokens));
    cfg.queue.admin_quota = get_env_int_or(
        "QUEUE_ADMIN_QUOTA",
        get_toml_int_or(toml, "queue", "admin_quota", cfg.queue.admin_quota));
    cfg.queue.retry_after_sec = get_env_int_or(
        "QUEUE_RETRY_AFTER_SEC",
        get_toml_int_or(toml, "queue", "retry_after_sec", cfg.queue.retry_after_sec));

    // [rate_limit]
    cfg.rate_limit.requests_per_minute = get_env_int_or(
        "REQUESTS_PER_MINUTE",
        get_toml_int_or(toml, "rate_limit", "requests_per_minute", cfg.rate_limit.requests_per_minute));
    cfg.rate_limit.estimated_tokens_per_minute = get_env_int_or(
        "ESTIMATED_TOKENS_PER_MINUTE",
        get_toml_int_or(toml, "rate_limit", "estimated_tokens_per_minute", cfg.rate_limit.estimated_tokens_per_minute));

    // [scheduler]
    cfg.scheduler.max_concurrent = get_env_int_or(
        "SCHEDULER_MAX_CONCURRENT",
        get_toml_int_or(toml, "scheduler", "max_concurrent", cfg.scheduler.max_concurrent));

    // [[api_keys]] entries require a second, dedicated parse pass.
    if (!config_path.empty() && std::filesystem::exists(config_path)) {
        cfg.api_keys = parse_api_keys_toml(config_path);
    }

    return cfg;
}
537
+
538
  static bool is_alive(pid_t pid) {
539
  if (pid <= 0) return false;
540
  return kill(pid, 0) == 0;
 
557
 
558
  class ModelManager {
559
  public:
560
+ explicit ModelManager(const ManagerConfig &config)
561
+ : _default_model(config.worker.default_model),
562
+ _llama_server_bin(config.worker.llama_server_bin),
563
+ _worker_host(config.worker.host),
564
+ _worker_bind_host(config.worker.bind_host),
565
+ _base_port(config.worker.base_port),
566
+ _switch_timeout_sec(config.worker.switch_timeout_sec),
567
+ _n_ctx(config.llama.n_ctx),
568
+ _n_threads(config.llama.threads),
569
+ _n_gpu_layers(config.llama.ngl),
570
+ _n_batch(config.llama.batch),
571
+ _n_ubatch(config.llama.ubatch),
572
  _next_port(_base_port) {}
573
 
574
  bool initialize_default(std::string &error) {
 
795
 
796
  static std::atomic<uint64_t> g_req_id{1};
797
 
798
// Per-API-key token-bucket rate limiter over two dimensions: request count
// and estimated token volume, each refilled continuously at the configured
// per-minute rate. A limit <= 0 disables that dimension.
class RateLimiterStore {
public:
    explicit RateLimiterStore(const RateLimitConfig &config)
        : _requests_per_minute(std::max(0, config.requests_per_minute)),
          _estimated_tokens_per_minute(std::max(0, config.estimated_tokens_per_minute)) {}

    // Check-and-consume: on success debits the key's buckets and returns an
    // allowed decision; on denial returns a fixed 1-second retry hint.
    // Thread-safe (all bucket state is accessed under _mu).
    RateLimitDecision allow(const std::string &api_key_id, int estimated_tokens) {
        if (_requests_per_minute <= 0 && _estimated_tokens_per_minute <= 0) {
            // Limiting fully disabled: default RateLimitDecision is allowed.
            return {};
        }

        std::lock_guard<std::mutex> lock(_mu);
        auto &bucket = _buckets[api_key_id];   // creates an empty bucket on first use
        const auto now = std::chrono::steady_clock::now();
        refill(bucket.request_tokens, bucket.last_request_refill, _requests_per_minute, now);
        refill(bucket.estimated_tokens, bucket.last_estimated_refill, _estimated_tokens_per_minute, now);

        if (_requests_per_minute > 0 && bucket.request_tokens < 1.0) {
            return {false, 1, "Rate limit exceeded: requests"};
        }
        if (_estimated_tokens_per_minute > 0 && bucket.estimated_tokens < estimated_tokens) {
            return {false, 1, "Rate limit exceeded: estimated tokens"};
        }

        // Both checks passed: debit only the enabled dimensions.
        if (_requests_per_minute > 0) bucket.request_tokens -= 1.0;
        if (_estimated_tokens_per_minute > 0) bucket.estimated_tokens -= estimated_tokens;
        return {};
    }

private:
    // One pair of buckets per api_key_id.
    struct Bucket {
        double request_tokens = 0.0;
        double estimated_tokens = 0.0;
        std::chrono::steady_clock::time_point last_request_refill{};
        std::chrono::steady_clock::time_point last_estimated_refill{};
    };

    std::mutex _mu;   // guards _buckets and all Bucket fields
    std::unordered_map<std::string, Bucket> _buckets;
    int _requests_per_minute;
    int _estimated_tokens_per_minute;

    // Top up `tokens` proportionally to elapsed time, capped at the
    // per-minute limit. An epoch-zero `last_refill` marks a never-refilled
    // bucket, which starts full.
    static void refill(
        double &tokens,
        std::chrono::steady_clock::time_point &last_refill,
        int limit_per_minute,
        std::chrono::steady_clock::time_point now) {
        if (limit_per_minute <= 0) return;
        if (last_refill.time_since_epoch().count() == 0) {
            tokens = limit_per_minute;
            last_refill = now;
            return;
        }
        const auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(now - last_refill).count();
        if (elapsed <= 0) return;
        const double refill_amount = (static_cast<double>(limit_per_minute) * elapsed) / 60000.0;
        tokens = std::min(static_cast<double>(limit_per_minute), tokens + refill_amount);
        last_refill = now;
    }
};
858
+
859
+ class RequestRegistry {
860
+ public:
861
+ std::shared_ptr<RequestContext> create(
862
+ const std::string &request_id,
863
+ const ApiKeyRecord &principal,
864
+ const TokenEstimate &estimate,
865
+ const std::string &request_body) {
866
+ auto ctx = std::make_shared<RequestContext>();
867
+ ctx->request_id = request_id;
868
+ ctx->api_key_id = principal.key_id;
869
+ ctx->role = principal.role;
870
+ ctx->priority = role_to_priority(principal.role);
871
+ ctx->estimate = estimate;
872
+ ctx->request_body = request_body;
873
+ ctx->created_at = std::chrono::steady_clock::now();
874
+ ctx->enqueue_time = ctx->created_at;
875
+
876
+ std::lock_guard<std::mutex> lock(_mu);
877
+ _requests[request_id] = ctx;
878
+ return ctx;
879
+ }
880
+
881
+ std::shared_ptr<RequestContext> find(const std::string &request_id) const {
882
+ std::lock_guard<std::mutex> lock(_mu);
883
+ const auto it = _requests.find(request_id);
884
+ if (it == _requests.end()) return nullptr;
885
+ return it->second;
886
+ }
887
+
888
+ void mark_state(const std::string &request_id, RequestState state) {
889
+ auto ctx = find(request_id);
890
+ if (ctx) ctx->state.store(state);
891
+ }
892
+
893
+ void complete(const std::shared_ptr<RequestContext> &ctx, RequestState state, RequestResult result) {
894
+ {
895
+ std::lock_guard<std::mutex> lock(ctx->mu);
896
+ ctx->state.store(state);
897
+ ctx->result = std::move(result);
898
+ ctx->completed = true;
899
+ }
900
+ ctx->cv.notify_all();
901
+ }
902
+
903
+ std::shared_ptr<RequestContext> cancel_request(const std::string &request_id) {
904
+ auto ctx = find(request_id);
905
+ if (!ctx) return nullptr;
906
+
907
+ ctx->cancelled.store(true);
908
+ const auto state = ctx->state.load();
909
+ if (state == RequestState::QUEUED) {
910
+ complete(ctx, RequestState::CANCELLED, {499, R"({"error":"Request cancelled"})"});
911
+ } else if (state == RequestState::RUNNING) {
912
+ ctx->state.store(RequestState::CANCELLED);
913
+ }
914
+ return ctx;
915
+ }
916
+
917
+ std::vector<std::shared_ptr<RequestContext>> cancel_all() {
918
+ std::vector<std::shared_ptr<RequestContext>> out;
919
+ std::lock_guard<std::mutex> lock(_mu);
920
+ out.reserve(_requests.size());
921
+ for (auto &[_, ctx] : _requests) {
922
+ ctx->cancelled.store(true);
923
+ const auto state = ctx->state.load();
924
+ if (state == RequestState::QUEUED) {
925
+ {
926
+ std::lock_guard<std::mutex> ctx_lock(ctx->mu);
927
+ ctx->state.store(RequestState::CANCELLED);
928
+ ctx->result = {499, R"({"error":"Request cancelled"})"};
929
+ ctx->completed = true;
930
+ }
931
+ ctx->cv.notify_all();
932
+ } else if (state == RequestState::RUNNING) {
933
+ ctx->state.store(RequestState::CANCELLED);
934
+ }
935
+ out.push_back(ctx);
936
+ }
937
+ return out;
938
+ }
939
+
940
+ private:
941
+ mutable std::mutex _mu;
942
+ std::unordered_map<std::string, std::shared_ptr<RequestContext>> _requests;
943
+ };
944
+
945
+ class PrioritySchedulerQueue {
946
+ public:
947
+ explicit PrioritySchedulerQueue(const QueueConfig &config)
948
+ : _max_size(config.max_size),
949
+ _max_tokens(config.max_tokens),
950
+ _admin_quota(std::max(1, config.admin_quota)),
951
+ _retry_after_sec(std::max(1, config.retry_after_sec)) {}
952
+
953
+ bool try_push(const std::shared_ptr<RequestContext> &ctx) {
954
+ std::lock_guard<std::mutex> lock(_mu);
955
+ if (_current_size >= _max_size) return false;
956
+ if (_current_tokens + ctx->estimate.estimated_total_tokens > _max_tokens) return false;
957
+
958
+ if (ctx->priority == Priority::ADMIN) {
959
+ _admin_queue.push_back(ctx);
960
+ } else {
961
+ _user_queue.push_back(ctx);
962
+ }
963
+ ++_current_size;
964
+ _current_tokens += ctx->estimate.estimated_total_tokens;
965
+ _cv.notify_one();
966
+ return true;
967
+ }
968
+
969
+ std::shared_ptr<RequestContext> pop_next() {
970
+ std::unique_lock<std::mutex> lock(_mu);
971
+ _cv.wait(lock, [&]() { return _stopped || _current_size > 0; });
972
+ if (_stopped) return nullptr;
973
+
974
+ std::deque<std::shared_ptr<RequestContext>> *selected_queue = nullptr;
975
+ if (!_admin_queue.empty() && (_admin_streak < _admin_quota || _user_queue.empty())) {
976
+ selected_queue = &_admin_queue;
977
+ ++_admin_streak;
978
+ } else if (!_user_queue.empty()) {
979
+ selected_queue = &_user_queue;
980
+ _admin_streak = 0;
981
+ } else if (!_admin_queue.empty()) {
982
+ selected_queue = &_admin_queue;
983
+ _admin_streak = 1;
984
+ }
985
+
986
+ if (!selected_queue || selected_queue->empty()) return nullptr;
987
+
988
+ auto best_it = std::min_element(
989
+ selected_queue->begin(),
990
+ selected_queue->end(),
991
+ [](const auto &a, const auto &b) {
992
+ return a->estimate.estimated_total_tokens < b->estimate.estimated_total_tokens;
993
+ });
994
+ auto ctx = *best_it;
995
+ selected_queue->erase(best_it);
996
+ --_current_size;
997
+ _current_tokens -= ctx->estimate.estimated_total_tokens;
998
+ return ctx;
999
+ }
1000
+
1001
+ void stop() {
1002
+ std::lock_guard<std::mutex> lock(_mu);
1003
+ _stopped = true;
1004
+ _cv.notify_all();
1005
+ }
1006
+
1007
+ int retry_after_sec() const {
1008
+ return _retry_after_sec;
1009
+ }
1010
+
1011
+ private:
1012
+ mutable std::mutex _mu;
1013
+ std::condition_variable _cv;
1014
+ std::deque<std::shared_ptr<RequestContext>> _admin_queue;
1015
+ std::deque<std::shared_ptr<RequestContext>> _user_queue;
1016
+ size_t _max_size;
1017
+ size_t _current_size = 0;
1018
+ int _max_tokens;
1019
+ int _current_tokens = 0;
1020
+ int _admin_quota;
1021
+ int _admin_streak = 0;
1022
+ int _retry_after_sec;
1023
+ bool _stopped = false;
1024
+ };
1025
+
1026
+ class Scheduler {
1027
+ public:
1028
+ Scheduler(ModelManager &manager, RequestRegistry &registry, const QueueConfig &queue_config)
1029
+ : _manager(manager), _registry(registry), _queue(queue_config) {
1030
+ _worker = std::thread([this]() { worker_loop(); });
1031
+ }
1032
+
1033
+ ~Scheduler() {
1034
+ _queue.stop();
1035
+ if (_worker.joinable()) _worker.join();
1036
+ }
1037
+
1038
+ bool try_enqueue(const std::shared_ptr<RequestContext> &ctx) {
1039
+ return _queue.try_push(ctx);
1040
+ }
1041
+
1042
+ int retry_after_sec() const {
1043
+ return _queue.retry_after_sec();
1044
+ }
1045
+
1046
+ private:
1047
+ ModelManager &_manager;
1048
+ RequestRegistry &_registry;
1049
+ PrioritySchedulerQueue _queue;
1050
+ std::thread _worker;
1051
+
1052
+ void worker_loop() {
1053
+ for (;;) {
1054
+ auto ctx = _queue.pop_next();
1055
+ if (!ctx) return;
1056
+
1057
+ if (ctx->cancelled.load()) {
1058
+ _registry.complete(ctx, RequestState::CANCELLED, {499, R"({"error":"Request cancelled"})"});
1059
+ continue;
1060
+ }
1061
+
1062
+ ctx->state.store(RequestState::RUNNING);
1063
+ ctx->start_time = std::chrono::steady_clock::now();
1064
+ auto worker = _manager.active_worker();
1065
+ if (!worker) {
1066
+ _registry.complete(ctx, RequestState::FAILED, {503, R"({"error":"No active model"})"});
1067
+ continue;
1068
+ }
1069
+
1070
+ try {
1071
+ auto [status, body] = forward_chat(*worker, ctx->request_body);
1072
+ if (ctx->cancelled.load()) {
1073
+ _registry.complete(ctx, RequestState::CANCELLED, {499, R"({"error":"Request cancelled"})"});
1074
+ continue;
1075
+ }
1076
+ _registry.complete(ctx, RequestState::DONE, {status, body});
1077
+ } catch (const std::exception &e) {
1078
+ _registry.complete(
1079
+ ctx,
1080
+ RequestState::FAILED,
1081
+ {500, json({{"error", e.what()}}).dump()});
1082
+ }
1083
+ }
1084
+ }
1085
+ };
1086
+
1087
+ class ApiKeyAuth {
1088
+ public:
1089
+ explicit ApiKeyAuth(const ManagerConfig &config)
1090
+ : _header_name(config.auth.header), _scheme(config.auth.scheme) {
1091
+ for (const auto &record : config.api_keys) {
1092
+ _records_by_secret.emplace(record.secret, record);
1093
+ }
1094
+ }
1095
+
1096
+ bool enabled() const {
1097
+ return !_records_by_secret.empty();
1098
+ }
1099
+
1100
+ template <typename Body, typename Allocator>
1101
+ std::optional<ApiKeyRecord> authenticate(
1102
+ const http::request<Body, http::basic_fields<Allocator>> &req,
1103
+ std::string &error) const {
1104
+ if (!enabled()) {
1105
+ error.clear();
1106
+ return ApiKeyRecord{"anonymous", "", Role::ADMIN, true};
1107
+ }
1108
+
1109
+ const auto token = extract_bearer_token(req, error);
1110
+ if (!token) return std::nullopt;
1111
+
1112
+ const auto it = _records_by_secret.find(*token);
1113
+ if (it == _records_by_secret.end()) {
1114
+ error = "Invalid API key";
1115
+ return std::nullopt;
1116
+ }
1117
+ if (!it->second.enabled) {
1118
+ error = "API key disabled";
1119
+ return std::nullopt;
1120
+ }
1121
+ error.clear();
1122
+ return it->second;
1123
+ }
1124
+
1125
+ private:
1126
+ std::string _header_name;
1127
+ std::string _scheme;
1128
+ std::unordered_map<std::string, ApiKeyRecord> _records_by_secret;
1129
+
1130
+ template <typename Body, typename Allocator>
1131
+ std::optional<std::string> extract_bearer_token(
1132
+ const http::request<Body, http::basic_fields<Allocator>> &req,
1133
+ std::string &error) const {
1134
+ const auto header_it = req.find(_header_name);
1135
+ if (header_it == req.end()) {
1136
+ error = "Missing authorization header";
1137
+ return std::nullopt;
1138
+ }
1139
+
1140
+ const std::string value = trim_copy(header_it->value().to_string());
1141
+ const std::string prefix = _scheme + " ";
1142
+ if (value.size() <= prefix.size() || value.rfind(prefix, 0) != 0) {
1143
+ error = "Invalid authorization scheme";
1144
+ return std::nullopt;
1145
+ }
1146
+
1147
+ std::string token = trim_copy(value.substr(prefix.size()));
1148
+ if (token.empty()) {
1149
+ error = "Missing API key";
1150
+ return std::nullopt;
1151
+ }
1152
+ return token;
1153
+ }
1154
+ };
1155
+
1156
  static void log_line(const std::string &line) {
1157
  std::cout << "[" << now_utc_iso() << "] " << line << std::endl;
1158
  }
 
1162
  return body.substr(0, max_len) + "...[truncated]";
1163
  }
1164
 
1165
+ static std::optional<std::string> extract_cancel_request_id(const std::string &path) {
1166
+ const std::string prefix = "/requests/";
1167
+ const std::string suffix = "/cancel";
1168
+ if (path.size() <= prefix.size() + suffix.size()) return std::nullopt;
1169
+ if (path.rfind(prefix, 0) != 0) return std::nullopt;
1170
+ if (path.substr(path.size() - suffix.size()) != suffix) return std::nullopt;
1171
+ const std::string request_id = path.substr(prefix.size(), path.size() - prefix.size() - suffix.size());
1172
+ if (request_id.empty()) return std::nullopt;
1173
+ return request_id;
1174
+ }
1175
+
1176
  static std::pair<int, std::string> forward_chat(const WorkerInfo &worker, const std::string &body) {
1177
  asio::io_context ioc;
1178
  asio::ip::tcp::resolver resolver(ioc);
 
1237
  template <typename Body, typename Allocator>
1238
  http::response<http::string_body> handle_request(
1239
  ModelManager &manager,
1240
+ const ManagerConfig &config,
1241
+ const ApiKeyAuth &auth,
1242
+ RateLimiterStore &rate_limiter,
1243
+ RequestRegistry &registry,
1244
+ Scheduler &scheduler,
1245
  http::request<Body, http::basic_fields<Allocator>> &&req) {
1246
  const auto start = std::chrono::steady_clock::now();
1247
+ const auto req_id_num = g_req_id.fetch_add(1);
1248
+ const std::string request_id = std::to_string(req_id_num);
1249
  const std::string target = req.target().to_string();
1250
  const std::string method = req.method_string().to_string();
1251
  const std::string path = target.substr(0, target.find('?'));
1252
+ auto authenticated = std::optional<ApiKeyRecord>{};
1253
 
1254
+ log_line("request_id=" + request_id + " method=" + method + " path=" + target);
1255
  if constexpr (std::is_same_v<Body, http::string_body>) {
1256
  if (!req.body().empty()) {
1257
+ log_line("request_id=" + request_id + " body=" + truncate_body(req.body()));
1258
  }
1259
  }
1260
 
1261
  auto json_response = [&](http::status status, const json &obj) {
1262
+ json payload = obj;
1263
+ payload["request_id"] = request_id;
1264
  http::response<http::string_body> res{status, req.version()};
1265
  res.set(http::field::content_type, "application/json");
1266
  res.set(http::field::server, "llm-manager");
1267
+ res.set("X-Request-Id", request_id);
1268
  res.keep_alive(req.keep_alive());
1269
+ res.body() = payload.dump();
1270
  res.prepare_payload();
1271
  auto elapsed_ms = std::chrono::duration_cast<std::chrono::milliseconds>(
1272
  std::chrono::steady_clock::now() - start).count();
1273
+ log_line("request_id=" + request_id + " status=" + std::to_string(res.result_int()) +
1274
  " elapsed_ms=" + std::to_string(elapsed_ms));
1275
  return res;
1276
  };
1277
 
1278
+ auto json_response_with_retry_after = [&](http::status status, const json &obj, int retry_after_sec) {
1279
+ auto res = json_response(status, obj);
1280
+ res.set(http::field::retry_after, std::to_string(std::max(1, retry_after_sec)));
1281
+ return res;
1282
+ };
1283
+
1284
+ auto ensure_authenticated = [&](Role minimum_role) -> std::optional<http::response<http::string_body>> {
1285
+ std::string auth_error;
1286
+ authenticated = auth.authenticate(req, auth_error);
1287
+ if (!authenticated) {
1288
+ return json_response(http::status::unauthorized, {{"error", auth_error}});
1289
+ }
1290
+ if (minimum_role == Role::ADMIN && authenticated->role != Role::ADMIN) {
1291
+ return json_response(http::status::forbidden, {{"error", "Admin role required"}});
1292
+ }
1293
+ log_line("request_id=" + request_id +
1294
+ " api_key_id=" + authenticated->key_id +
1295
+ " role=" + role_to_string(authenticated->role));
1296
+ return std::nullopt;
1297
+ };
1298
+
1299
  try {
1300
  if (path == "/health" && req.method() == http::verb::get) {
1301
  return json_response(http::status::ok, manager.models_view());
 
1306
  }
1307
 
1308
  if (path == "/switch-model" && req.method() == http::verb::post) {
1309
+ if (auto auth_res = ensure_authenticated(Role::ADMIN)) return *auth_res;
1310
  std::string body(req.body().data(), req.body().size());
1311
  json j = json::parse(body, nullptr, false);
1312
  if (j.is_discarded()) {
 
1331
  }
1332
 
1333
  if (path == "/stop" && req.method() == http::verb::post) {
1334
+ if (auto auth_res = ensure_authenticated(Role::ADMIN)) return *auth_res;
1335
+ registry.cancel_all();
1336
  std::string err;
1337
  bool ok = manager.restart_active(err);
1338
  if (!ok) {
 
1349
  return json_response(http::status::ok, state);
1350
  }
1351
 
1352
+ if (req.method() == http::verb::post) {
1353
+ if (auto cancel_id = extract_cancel_request_id(path)) {
1354
+ if (auto auth_res = ensure_authenticated(Role::USER)) return *auth_res;
1355
+ auto ctx = registry.find(*cancel_id);
1356
+ if (!ctx) {
1357
+ return json_response(http::status::not_found, {{"error", "Unknown request id"}});
1358
+ }
1359
+ if (authenticated->role != Role::ADMIN && authenticated->key_id != ctx->api_key_id) {
1360
+ return json_response(http::status::forbidden, {{"error", "Cannot cancel another API key request"}});
1361
+ }
1362
+
1363
+ const auto previous_state = ctx->state.load();
1364
+ registry.cancel_request(*cancel_id);
1365
+ std::string restart_error;
1366
+ bool restarted = true;
1367
+ if (previous_state == RequestState::RUNNING) {
1368
+ restarted = manager.restart_active(restart_error);
1369
+ }
1370
+
1371
+ json payload = {
1372
+ {"cancelled_request_id", *cancel_id},
1373
+ {"state", state_to_string(ctx->state.load())}
1374
+ };
1375
+ if (!restarted) {
1376
+ payload["restart_error"] = restart_error;
1377
+ }
1378
+ return json_response(http::status::ok, payload);
1379
+ }
1380
+ }
1381
+
1382
  if (path == "/v1/chat/completions" && req.method() == http::verb::post) {
1383
+ if (auto auth_res = ensure_authenticated(Role::USER)) return *auth_res;
1384
+ json payload = json::parse(req.body(), nullptr, false);
1385
+ if (payload.is_discarded()) {
1386
+ return json_response(http::status::bad_request, {{"error", "Invalid JSON"}});
1387
+ }
1388
+ std::string token_error;
1389
+ auto estimate = estimate_chat_tokens(payload, config.limits, token_error);
1390
+ if (!estimate) {
1391
+ return json_response(http::status::bad_request, {{"error", token_error}});
1392
+ }
1393
+ log_line("request_id=" + request_id +
1394
+ " prompt_tokens=" + std::to_string(estimate->prompt_tokens) +
1395
+ " max_tokens=" + std::to_string(estimate->requested_max_tokens) +
1396
+ " estimated_total_tokens=" + std::to_string(estimate->estimated_total_tokens));
1397
+
1398
+ auto rate_limit_decision = rate_limiter.allow(
1399
+ authenticated->key_id,
1400
+ estimate->estimated_total_tokens);
1401
+ if (!rate_limit_decision.allowed) {
1402
+ return json_response_with_retry_after(
1403
+ http::status::too_many_requests,
1404
+ {{"error", rate_limit_decision.error}},
1405
+ rate_limit_decision.retry_after_sec);
1406
  }
1407
+
1408
+ auto ctx = registry.create(request_id, *authenticated, *estimate, req.body());
1409
+ if (!scheduler.try_enqueue(ctx)) {
1410
+ ctx->cancelled.store(true);
1411
+ registry.complete(ctx, RequestState::CANCELLED, {503, R"({"error":"Queue full"})"});
1412
+ return json_response_with_retry_after(
1413
+ http::status::service_unavailable,
1414
+ {{"error", "Queue full"}},
1415
+ scheduler.retry_after_sec());
1416
+ }
1417
+
1418
+ std::unique_lock<std::mutex> lock(ctx->mu);
1419
+ const bool finished = ctx->cv.wait_for(
1420
+ lock,
1421
+ std::chrono::seconds(std::max(1, config.limits.request_timeout_sec)),
1422
+ [&]() { return ctx->completed; });
1423
+ if (!finished) {
1424
+ lock.unlock();
1425
+ registry.cancel_request(request_id);
1426
+ std::string restart_error;
1427
+ bool restarted = true;
1428
+ if (ctx->state.load() == RequestState::RUNNING) {
1429
+ restarted = manager.restart_active(restart_error);
1430
+ }
1431
+ json timeout_payload = {
1432
+ {"error", "Request timed out"},
1433
+ {"state", state_to_string(ctx->state.load())}
1434
+ };
1435
+ if (!restarted) timeout_payload["restart_error"] = restart_error;
1436
+ return json_response(http::status::gateway_timeout, timeout_payload);
1437
+ }
1438
+
1439
+ const auto final_state = ctx->state.load();
1440
+ RequestResult result = ctx->result;
1441
+ lock.unlock();
1442
+
1443
+ if (final_state == RequestState::CANCELLED) {
1444
+ return json_response(http::status::ok, {{"status", "cancelled"}});
1445
+ }
1446
+
1447
+ http::response<http::string_body> res{
1448
+ static_cast<http::status>(result.status), req.version()};
1449
+ res.set(http::field::content_type, result.content_type);
1450
  res.set(http::field::server, "llm-manager");
1451
+ res.set("X-Request-Id", request_id);
1452
  res.keep_alive(req.keep_alive());
1453
+ res.body() = result.body;
1454
  res.prepare_payload();
1455
  auto elapsed_ms = std::chrono::duration_cast<std::chrono::milliseconds>(
1456
  std::chrono::steady_clock::now() - start).count();
1457
+ log_line("request_id=" + request_id +
1458
+ " final_state=" + state_to_string(final_state) +
1459
+ " upstream_status=" + std::to_string(result.status) +
 
1460
  " elapsed_ms=" + std::to_string(elapsed_ms));
1461
  return res;
1462
  }
 
1476
  res.set(http::field::content_encoding, upstream.content_encoding);
1477
  }
1478
  res.set(http::field::server, "llm-manager");
1479
+ res.set("X-Request-Id", request_id);
1480
  res.keep_alive(req.keep_alive());
1481
  res.body() = upstream.body;
1482
  res.prepare_payload();
1483
  auto elapsed_ms = std::chrono::duration_cast<std::chrono::milliseconds>(
1484
  std::chrono::steady_clock::now() - start)
1485
  .count();
1486
+ log_line("request_id=" + request_id +
1487
  " proxied_get model=" + worker->model +
1488
  " upstream_status=" + std::to_string(upstream.status) +
1489
  " elapsed_ms=" + std::to_string(elapsed_ms));
 
1496
  }
1497
  }
1498
 
1499
+ void do_session(
1500
+ asio::ip::tcp::socket socket,
1501
+ ModelManager &manager,
1502
+ const ManagerConfig &config,
1503
+ const ApiKeyAuth &auth,
1504
+ RateLimiterStore &rate_limiter,
1505
+ RequestRegistry &registry,
1506
+ Scheduler &scheduler) {
1507
  try {
1508
  beast::flat_buffer buffer;
1509
  http::request<http::string_body> req;
1510
  http::read(socket, buffer, req);
1511
+ auto res = handle_request(manager, config, auth, rate_limiter, registry, scheduler, std::move(req));
1512
  http::write(socket, res);
1513
  beast::error_code ec;
1514
  socket.shutdown(asio::ip::tcp::socket::shutdown_send, ec);
 
1517
  }
1518
 
1519
  int main() {
1520
+ const ManagerConfig config = load_manager_config();
1521
+ const auto &bind_host = config.server.host;
1522
+ const int bind_port = config.server.port;
1523
+ ModelManager manager(config);
1524
+ ApiKeyAuth auth(config);
1525
+ RateLimiterStore rate_limiter(config.rate_limit);
1526
+ RequestRegistry registry;
1527
+ Scheduler scheduler(manager, registry, config.queue);
1528
 
1529
  std::string init_error;
1530
+ if (auth.enabled()) {
1531
+ log_line("auth: enabled api_keys=" + std::to_string(config.api_keys.size()));
1532
+ } else {
1533
+ log_line("auth: disabled (no configured api keys)");
1534
+ }
1535
  log_line("startup: loading default model");
1536
  if (!manager.initialize_default(init_error)) {
1537
  log_line("startup: default model failed: " + init_error);
 
1546
  for (;;) {
1547
  asio::ip::tcp::socket socket{ioc};
1548
  acceptor.accept(socket);
1549
+ std::thread(
1550
+ &do_session,
1551
+ std::move(socket),
1552
+ std::ref(manager),
1553
+ std::cref(config),
1554
+ std::cref(auth),
1555
+ std::ref(rate_limiter),
1556
+ std::ref(registry),
1557
+ std::ref(scheduler)).detach();
1558
  }
1559
  }