Spaces:
Building
Building
| RateLimiterStore::RateLimiterStore(const RateLimitConfig &config) | |
| : requests_per_minute_(std::max(0, config.requests_per_minute)), | |
| estimated_tokens_per_minute_(std::max(0, config.estimated_tokens_per_minute)) {} | |
| RateLimitDecision RateLimiterStore::allow(const std::string &api_key_id, int estimated_tokens) { | |
| if (requests_per_minute_ <= 0 && estimated_tokens_per_minute_ <= 0) return {}; | |
| std::lock_guard<std::mutex> lock(mu_); | |
| auto &bucket = buckets_[api_key_id]; | |
| const auto now = std::chrono::steady_clock::now(); | |
| refill(bucket.request_tokens, bucket.last_request_refill, requests_per_minute_, now); | |
| refill(bucket.estimated_tokens, bucket.last_estimated_refill, estimated_tokens_per_minute_, now); | |
| if (requests_per_minute_ > 0 && bucket.request_tokens < 1.0) { | |
| return {false, 1, "Rate limit exceeded: requests"}; | |
| } | |
| if (estimated_tokens_per_minute_ > 0 && bucket.estimated_tokens < estimated_tokens) { | |
| return {false, 1, "Rate limit exceeded: estimated tokens"}; | |
| } | |
| if (requests_per_minute_ > 0) bucket.request_tokens -= 1.0; | |
| if (estimated_tokens_per_minute_ > 0) bucket.estimated_tokens -= estimated_tokens; | |
| return {}; | |
| } | |
| void RateLimiterStore::refill( | |
| double &tokens, | |
| std::chrono::steady_clock::time_point &last_refill, | |
| int limit_per_minute, | |
| std::chrono::steady_clock::time_point now) { | |
| if (limit_per_minute <= 0) return; | |
| if (last_refill.time_since_epoch().count() == 0) { | |
| tokens = limit_per_minute; | |
| last_refill = now; | |
| return; | |
| } | |
| const auto elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(now - last_refill).count(); | |
| if (elapsed <= 0) return; | |
| const double refill_amount = (static_cast<double>(limit_per_minute) * elapsed) / 60000.0; | |
| tokens = std::min(static_cast<double>(limit_per_minute), tokens + refill_amount); | |
| last_refill = now; | |
| } | |
| std::shared_ptr<RequestContext> RequestRegistry::create( | |
| const std::string &request_id, | |
| const ApiKeyRecord &principal, | |
| const TokenEstimate &estimate, | |
| const std::string &request_body) { | |
| auto ctx = std::make_shared<RequestContext>(); | |
| ctx->request_id = request_id; | |
| ctx->api_key_id = principal.key_id; | |
| ctx->role = principal.role; | |
| ctx->priority = role_to_priority(principal.role); | |
| ctx->estimate = estimate; | |
| ctx->request_body = request_body; | |
| ctx->created_at = std::chrono::steady_clock::now(); | |
| ctx->enqueue_time = ctx->created_at; | |
| std::lock_guard<std::mutex> lock(mu_); | |
| requests_[request_id] = ctx; | |
| return ctx; | |
| } | |
| std::shared_ptr<RequestContext> RequestRegistry::find(const std::string &request_id) const { | |
| std::lock_guard<std::mutex> lock(mu_); | |
| const auto it = requests_.find(request_id); | |
| if (it == requests_.end()) return nullptr; | |
| return it->second; | |
| } | |
| void RequestRegistry::complete(const std::shared_ptr<RequestContext> &ctx, RequestState state, RequestResult result) { | |
| { | |
| std::lock_guard<std::mutex> lock(ctx->mu); | |
| ctx->state.store(state); | |
| ctx->result = std::move(result); | |
| ctx->completed = true; | |
| } | |
| ctx->cv.notify_all(); | |
| } | |
// Cancel a single request by id.
//
// Returns the context so the caller can count/log the cancellation, or
// nullptr when the id is unknown. A QUEUED request is completed right
// here with HTTP 499; a RUNNING request is only flagged — the scheduler
// notices `cancelled` after its upstream call returns and completes it.
//
// NOTE(review): the state check below races with the worker loop (a
// request can move QUEUED -> RUNNING between load() and complete()).
// The apparent worst case is a benign double-complete with the same
// 499 payload — confirm waiters tolerate a second notify.
std::shared_ptr<RequestContext> RequestRegistry::cancel_request(const std::string &request_id) {
  auto ctx = find(request_id);
  if (!ctx) return nullptr;
  ctx->cancelled.store(true);
  const auto state = ctx->state.load();
  if (state == RequestState::QUEUED) {
    // Queued: nobody is executing it, so finish it immediately.
    complete(ctx, RequestState::CANCELLED, {499, R"({"error":"Request cancelled"})"});
  } else if (state == RequestState::RUNNING) {
    // Running: mark only; the worker owns completion of in-flight work.
    ctx->state.store(RequestState::CANCELLED);
  }
  return ctx;
}
// Cancel every registered request (shutdown / model-switch path).
//
// Holds the registry lock mu_ for the whole sweep and briefly takes
// each ctx->mu while publishing the 499 result (lock order: mu_ before
// ctx->mu). Queued requests are completed inline rather than via
// complete() — same publication logic, minus the per-state dispatch.
// Running requests are only flagged, as in cancel_request(). Returns
// the full context list so the caller can update metrics.
std::vector<std::shared_ptr<RequestContext>> RequestRegistry::cancel_all() {
  std::vector<std::shared_ptr<RequestContext>> out;
  std::lock_guard<std::mutex> lock(mu_);
  out.reserve(requests_.size());
  for (auto &[_, ctx] : requests_) {
    ctx->cancelled.store(true);
    const auto state = ctx->state.load();
    if (state == RequestState::QUEUED) {
      {
        // Publish (state, result, completed) under ctx->mu so waiters
        // see a consistent terminal snapshot.
        std::lock_guard<std::mutex> ctx_lock(ctx->mu);
        ctx->state.store(RequestState::CANCELLED);
        ctx->result = {499, R"({"error":"Request cancelled"})"};
        ctx->completed = true;
      }
      ctx->cv.notify_all();
    } else if (state == RequestState::RUNNING) {
      // In-flight: the scheduler completes it after the upstream call.
      ctx->state.store(RequestState::CANCELLED);
    }
    out.push_back(ctx);
  }
  return out;
}
| void MetricsRegistry::inc_requests_total() { requests_total_.fetch_add(1); } | |
| void MetricsRegistry::inc_requests_inflight() { requests_inflight_.fetch_add(1); } | |
| void MetricsRegistry::dec_requests_inflight() { requests_inflight_.fetch_sub(1); } | |
| void MetricsRegistry::inc_queue_rejected_total() { queue_rejected_total_.fetch_add(1); } | |
| void MetricsRegistry::inc_rate_limited_total() { rate_limited_total_.fetch_add(1); } | |
| void MetricsRegistry::add_cancellations_total(uint64_t delta) { cancellations_total_.fetch_add(delta); } | |
| void MetricsRegistry::inc_switch_total() { switch_total_.fetch_add(1); } | |
| void MetricsRegistry::inc_worker_restarts_total() { worker_restarts_total_.fetch_add(1); } | |
| void MetricsRegistry::observe_request_latency_ms(int64_t value) { | |
| request_latency_ms_total_.fetch_add(value); | |
| request_latency_samples_.fetch_add(1); | |
| } | |
| void MetricsRegistry::observe_queue_wait_ms(int64_t value) { | |
| queue_wait_ms_total_.fetch_add(value); | |
| queue_wait_samples_.fetch_add(1); | |
| } | |
| std::string MetricsRegistry::render_prometheus(const QueueSnapshot &queue, ModelManager &manager) const { | |
| std::ostringstream oss; | |
| oss << "llm_manager_requests_total " << requests_total_.load() << '\n'; | |
| oss << "llm_manager_requests_inflight " << requests_inflight_.load() << '\n'; | |
| oss << "llm_manager_request_latency_ms_total " << request_latency_ms_total_.load() << '\n'; | |
| oss << "llm_manager_request_latency_ms_samples " << request_latency_samples_.load() << '\n'; | |
| oss << "llm_manager_queue_size " << queue.total_size << '\n'; | |
| oss << "llm_manager_queue_admin_size " << queue.admin_size << '\n'; | |
| oss << "llm_manager_queue_user_size " << queue.user_size << '\n'; | |
| oss << "llm_manager_queue_tokens " << queue.total_tokens << '\n'; | |
| oss << "llm_manager_queue_rejected_total " << queue_rejected_total_.load() << '\n'; | |
| oss << "llm_manager_rate_limited_total " << rate_limited_total_.load() << '\n'; | |
| oss << "llm_manager_queue_wait_time_ms_total " << queue_wait_ms_total_.load() << '\n'; | |
| oss << "llm_manager_queue_wait_time_ms_samples " << queue_wait_samples_.load() << '\n'; | |
| oss << "llm_manager_cancellations_total " << cancellations_total_.load() << '\n'; | |
| oss << "llm_manager_switch_total " << switch_total_.load() << '\n'; | |
| oss << "llm_manager_worker_restarts_total " << worker_restarts_total_.load() << '\n'; | |
| const auto active = manager.active_worker(); | |
| oss << "llm_manager_active_worker " << (active ? 1 : 0) << '\n'; | |
| return oss.str(); | |
| } | |
| PrioritySchedulerQueue::PrioritySchedulerQueue(const QueueConfig &config) | |
| : max_size_(config.max_size), | |
| max_tokens_(config.max_tokens), | |
| admin_quota_(std::max(1, config.admin_quota)), | |
| retry_after_sec_(std::max(1, config.retry_after_sec)) {} | |
| bool PrioritySchedulerQueue::try_push(const std::shared_ptr<RequestContext> &ctx) { | |
| std::lock_guard<std::mutex> lock(mu_); | |
| if (current_size_ >= max_size_) return false; | |
| if (current_tokens_ + ctx->estimate.estimated_total_tokens > max_tokens_) return false; | |
| if (ctx->priority == Priority::ADMIN) admin_queue_.push_back(ctx); | |
| else user_queue_.push_back(ctx); | |
| ++current_size_; | |
| current_tokens_ += ctx->estimate.estimated_total_tokens; | |
| cv_.notify_one(); | |
| return true; | |
| } | |
| std::shared_ptr<RequestContext> PrioritySchedulerQueue::pop_next() { | |
| std::unique_lock<std::mutex> lock(mu_); | |
| cv_.wait(lock, [&]() { return stopped_ || current_size_ > 0; }); | |
| if (stopped_) return nullptr; | |
| std::deque<std::shared_ptr<RequestContext>> *selected_queue = nullptr; | |
| if (!admin_queue_.empty() && (admin_streak_ < admin_quota_ || user_queue_.empty())) { | |
| selected_queue = &admin_queue_; | |
| ++admin_streak_; | |
| } else if (!user_queue_.empty()) { | |
| selected_queue = &user_queue_; | |
| admin_streak_ = 0; | |
| } else if (!admin_queue_.empty()) { | |
| selected_queue = &admin_queue_; | |
| admin_streak_ = 1; | |
| } | |
| if (!selected_queue || selected_queue->empty()) return nullptr; | |
| auto best_it = std::min_element( | |
| selected_queue->begin(), | |
| selected_queue->end(), | |
| [](const auto &a, const auto &b) { | |
| return a->estimate.estimated_total_tokens < b->estimate.estimated_total_tokens; | |
| }); | |
| auto ctx = *best_it; | |
| selected_queue->erase(best_it); | |
| --current_size_; | |
| current_tokens_ -= ctx->estimate.estimated_total_tokens; | |
| return ctx; | |
| } | |
| void PrioritySchedulerQueue::stop() { | |
| std::lock_guard<std::mutex> lock(mu_); | |
| stopped_ = true; | |
| cv_.notify_all(); | |
| } | |
| int PrioritySchedulerQueue::retry_after_sec() const { | |
| return retry_after_sec_; | |
| } | |
| QueueSnapshot PrioritySchedulerQueue::snapshot() const { | |
| std::lock_guard<std::mutex> lock(mu_); | |
| return QueueSnapshot{current_size_, admin_queue_.size(), user_queue_.size(), current_tokens_}; | |
| } | |
| Scheduler::Scheduler( | |
| ModelManager &manager, | |
| RequestRegistry ®istry, | |
| MetricsRegistry &metrics, | |
| const QueueConfig &queue_config) | |
| : manager_(manager), registry_(registry), metrics_(metrics), queue_(queue_config) { | |
| worker_ = std::thread([this]() { worker_loop(); }); | |
| } | |
| Scheduler::~Scheduler() { | |
| queue_.stop(); | |
| if (worker_.joinable()) worker_.join(); | |
| } | |
// Thin forwarding facade over the internal queue — callers talk to the
// Scheduler, never to PrioritySchedulerQueue directly.

// Attempt to admit a request; false means the queue rejected it and the
// caller should tell the client to retry later.
bool Scheduler::try_enqueue(const std::shared_ptr<RequestContext> &ctx) {
  return queue_.try_push(ctx);
}
// Seconds to advertise (e.g. in a Retry-After header) on rejection.
int Scheduler::retry_after_sec() const {
  return queue_.retry_after_sec();
}
// Current queue depth / token usage for metrics and status endpoints.
QueueSnapshot Scheduler::snapshot() const {
  return queue_.snapshot();
}
// Worker thread body: pull requests off the priority queue and execute
// them one at a time against the active model worker. Returns (ending
// the thread) only when the queue is stopped.
void Scheduler::worker_loop() {
  for (;;) {
    auto ctx = queue_.pop_next();
    // pop_next() yields nullptr only after stop() — shutdown signal.
    if (!ctx) return;
    // Cancelled while still queued: finish with 499, never touch the model.
    if (ctx->cancelled.load()) {
      registry_.complete(ctx, RequestState::CANCELLED, {499, R"({"error":"Request cancelled"})"});
      continue;
    }
    ctx->state.store(RequestState::RUNNING);
    ctx->start_time = std::chrono::steady_clock::now();
    metrics_.observe_queue_wait_ms(
        std::chrono::duration_cast<std::chrono::milliseconds>(ctx->start_time - ctx->enqueue_time).count());
    auto worker = manager_.active_worker();
    if (!worker) {
      // No model currently loaded (e.g. mid-switch): fail fast with 503.
      registry_.complete(ctx, RequestState::FAILED, {503, R"({"error":"No active model"})"});
      continue;
    }
    try {
      auto [status, body] = forward_chat(*worker, ctx->request_body);
      // Cancellation may have arrived while the upstream call was in
      // flight; the cancelled result takes precedence over the response.
      if (ctx->cancelled.load()) {
        registry_.complete(ctx, RequestState::CANCELLED, {499, R"({"error":"Request cancelled"})"});
        continue;
      }
      registry_.complete(ctx, RequestState::DONE, {status, body});
    } catch (const std::exception &e) {
      // Never let an exception escape the thread (that would terminate
      // the process) — log it and fail only this request.
      log_line("request_id=" + ctx->request_id + " scheduler_exception=" + std::string(e.what()));
      registry_.complete(ctx, RequestState::FAILED, {500, json({{"error", e.what()}}).dump()});
    } catch (...) {
      log_line("request_id=" + ctx->request_id + " scheduler_exception=unknown");
      registry_.complete(ctx, RequestState::FAILED, {500, json({{"error", "Unknown exception"}}).dump()});
    }
  }
}
| ApiKeyAuth::ApiKeyAuth(const ManagerConfig &config) | |
| : header_name_(config.auth.header), scheme_(config.auth.scheme) { | |
| for (const auto &record : config.api_keys) { | |
| records_by_secret_.emplace(record.secret, record); | |
| } | |
| } | |
| bool ApiKeyAuth::enabled() const { | |
| return !records_by_secret_.empty(); | |
| } | |
// Resolve the caller's API key record from the request, or nullopt with
// `error` set to a human-readable reason.
//
// NOTE(review): when NO keys are configured, auth is disabled and every
// caller is treated as an anonymous ADMIN — the server fails open.
// Confirm this is intentional for the deployment model.
std::optional<ApiKeyRecord> ApiKeyAuth::authenticate(
    const http::request<http::string_body> &req,
    std::string &error) const {
  if (!enabled()) {
    error.clear();
    return ApiKeyRecord{"anonymous", "", Role::ADMIN, true};
  }
  const auto token = extract_bearer_token(req, error);
  if (!token) return std::nullopt;  // `error` already set by the extractor
  const auto it = records_by_secret_.find(*token);
  if (it == records_by_secret_.end()) {
    error = "Invalid API key";
    return std::nullopt;
  }
  if (!it->second.enabled) {
    // Key exists but has been administratively disabled.
    error = "API key disabled";
    return std::nullopt;
  }
  error.clear();
  return it->second;
}
| std::optional<std::string> ApiKeyAuth::extract_bearer_token( | |
| const http::request<http::string_body> &req, | |
| std::string &error) const { | |
| const auto header_it = req.find(header_name_); | |
| if (header_it == req.end()) { | |
| error = "Missing authorization header"; | |
| return std::nullopt; | |
| } | |
| const std::string value = trim_copy(header_it->value().to_string()); | |
| const std::string prefix = scheme_ + " "; | |
| if (value.size() <= prefix.size() || value.rfind(prefix, 0) != 0) { | |
| error = "Invalid authorization scheme"; | |
| return std::nullopt; | |
| } | |
| std::string token = trim_copy(value.substr(prefix.size())); | |
| if (token.empty()) { | |
| error = "Missing API key"; | |
| return std::nullopt; | |
| } | |
| return token; | |
| } | |