|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#include "server-task.cpp" |
|
|
#include "server-queue.cpp" |
|
|
#include "server-common.cpp" |
|
|
#include "server-context.cpp" |
|
|
|
|
|
|
|
|
|
|
|
#include "backend.pb.h" |
|
|
#include "backend.grpc.pb.h" |
|
|
#include "common.h" |
|
|
#include <getopt.h> |
|
|
#include <grpcpp/ext/proto_server_reflection_plugin.h> |
|
|
#include <grpcpp/grpcpp.h> |
|
|
#include <grpcpp/health_check_service_interface.h> |
|
|
#include <regex> |
|
|
#include <atomic> |
|
|
#include <mutex> |
|
|
#include <signal.h> |
|
|
#include <thread> |
|
|
|
|
|
#if defined(_WIN32) |
|
|
#include <windows.h> |
|
|
#endif |
|
|
|
|
|
|
|
|
using grpc::Server; |
|
|
using grpc::ServerBuilder; |
|
|
using grpc::ServerContext; |
|
|
using grpc::Status; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Set once LoadModel() completes successfully; busy-polled by
// start_llama_server() from a different thread. Made atomic because a plain
// bool read/written concurrently is a data race under the C++ memory model.
std::atomic<bool> loaded_model{false};

// Graceful-shutdown callback installed by start_llama_server(); invoked from
// signal_handler() on the first SIGINT/SIGTERM.
static std::function<void(int)> shutdown_handler;

// Latched on the first termination signal so a second signal exits immediately.
static std::atomic_flag is_terminating = ATOMIC_FLAG_INIT;
|
|
|
|
|
// Shared SIGINT/SIGTERM handler. First signal: delegate to shutdown_handler
// for a graceful stop. Any subsequent signal: terminate immediately.
static inline void signal_handler(int signal) {
    // test_and_set() returns the previous value: true means a shutdown is
    // already in progress, so this is at least the second signal.
    if (is_terminating.test_and_set()) {
        // NOTE(review): fprintf/exit are not strictly async-signal-safe; this
        // mirrors upstream llama.cpp behavior — confirm acceptable here.
        fprintf(stderr, "Received second interrupt, terminating immediately.\n");
        exit(1);
    }
    // First signal: forward to the handler installed in start_llama_server().
    shutdown_handler(signal);
}
|
|
|
|
|
|
|
|
//
// Forward declarations for the file-local helpers defined below.
//
static void start_llama_server(server_context& ctx_server);
static json parse_options(bool streaming, const backend::PredictOptions* predict, const common_params& params_base, llama_context* ctx);
static ggml_type kv_cache_type_from_str(const std::string & s);
static std::string get_all_kv_cache_types();
static void add_rpc_devices(std::string servers);
static void params_parse(server_context& ctx_server, const backend::ModelOptions* request, common_params & params);
|
|
|
|
|
// Waits for LoadModel() to flag the model as loaded, installs the
// SIGINT/SIGTERM (or Windows console Ctrl-C) handlers that terminate the
// server context, then runs the request-processing loop until shutdown.
static void start_llama_server(server_context& ctx_server) {
    LOG_INF("%s: starting llama server\n", __func__);

    LOG_INF("%s: waiting for model to be loaded\n", __func__);

    // Busy-wait (100 ms poll) until the gRPC LoadModel handler sets loaded_model.
    while (!loaded_model) {
        std::this_thread::sleep_for(std::chrono::milliseconds(100));
    }

    LOG_INF("%s: model loaded\n", __func__);

    // Route termination signals into the server context.
    // NOTE(review): the lambda captures ctx_server by reference; it must
    // outlive the installed signal handlers — holds as long as this function
    // (and its caller's ctx_server) stays alive for the process lifetime.
    shutdown_handler = [&](int) {
        ctx_server.terminate();
    };

#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
    // POSIX: install the same handler for SIGINT and SIGTERM.
    struct sigaction sigint_action;
    sigint_action.sa_handler = signal_handler;
    sigemptyset (&sigint_action.sa_mask);
    sigint_action.sa_flags = 0;
    sigaction(SIGINT, &sigint_action, NULL);
    sigaction(SIGTERM, &sigint_action, NULL);
#elif defined (_WIN32)
    // Windows: translate console Ctrl-C events into the POSIX-style handler.
    auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
        return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false;
    };
    SetConsoleCtrlHandler(reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
#endif

    // Blocks processing server tasks; returns only on termination.
    ctx_server.start_loop();
}
|
|
|
|
|
// Convert a gRPC PredictOptions request into the JSON request body consumed by
// the llama.cpp server task layer.
//   streaming   - whether the caller wants a streamed response
//   predict     - the incoming prediction request (proto accessors are
//                 lower-cased field names)
//   params_base - model-level params; supplies grammar triggers and preserved
//                 tokens captured at load time
//   ctx         - llama context used to detokenize preserved tokens
// Returns the assembled JSON object.
json parse_options(bool streaming, const backend::PredictOptions* predict, const common_params& params_base, llama_context* ctx)
{
    // Scalar sampling / generation parameters map 1:1 onto server JSON keys.
    json data;
    data["stream"] = streaming;
    data["cache_prompt"] = predict->promptcacheall();
    // 0 on the wire means "unbounded"; the server expects -1 for that.
    data["n_predict"] = predict->tokens() == 0 ? -1 : predict->tokens();
    data["top_k"] = predict->topk();
    data["top_p"] = predict->topp();
    data["typical_p"] = predict->typicalp();
    data["temperature"] = predict->temperature();
    data["repeat_last_n"] = predict->repeat();
    data["repeat_penalty"] = predict->penalty();
    data["frequency_penalty"] = predict->frequencypenalty();
    data["presence_penalty"] = predict->presencepenalty();
    data["mirostat"] = predict->mirostat();
    data["mirostat_tau"] = predict->mirostattau();
    data["mirostat_eta"] = predict->mirostateta();
    data["n_keep"] = predict->nkeep();
    data["seed"] = predict->seed();

    // Optional GBNF grammar, passed through verbatim when present.
    std::string grammar_str = predict->grammar();
    if (!grammar_str.empty()) {
        data["grammar"] = grammar_str;
        SRV_INF("Using grammar: %s\n", grammar_str.c_str());
    }

    // Only forward the raw prompt when the tokenizer chat template is not
    // being applied (otherwise the messages drive the prompt).
    if (!predict->usetokenizertemplate() || predict->messages_size() == 0) {
        data["prompt"] = predict->prompt();
    }

    // Tools arrive as a JSON string; parse and attach when present.
    SRV_INF("[TOOLS DEBUG] parse_options: Checking for tools in proto, tools().empty()=%d, tools().size()=%zu\n",
        predict->tools().empty() ? 1 : 0, predict->tools().size());
    if (!predict->tools().empty()) {
        SRV_INF("[TOOLS DEBUG] parse_options: Tools string from proto (first 500 chars): %s\n",
            predict->tools().substr(0, std::min<size_t>(500, predict->tools().size())).c_str());
        try {
            json tools_json = json::parse(predict->tools());
            data["tools"] = tools_json;
            SRV_INF("Extracted tools from proto: %s\n", predict->tools().c_str());
            if (tools_json.is_array()) {
                // Log each tool name for debugging (OpenAI "function" wrapper
                // or a flat "name" field).
                SRV_INF("[TOOLS DEBUG] parse_options: Successfully parsed %zu tools from Go layer\n", tools_json.size());
                for (size_t i = 0; i < tools_json.size(); i++) {
                    if (tools_json[i].contains("function") && tools_json[i]["function"].contains("name")) {
                        SRV_INF("[TOOLS DEBUG] parse_options: Tool %zu: %s\n", i, tools_json[i]["function"]["name"].get<std::string>().c_str());
                    } else if (tools_json[i].contains("name")) {
                        SRV_INF("[TOOLS DEBUG] parse_options: Tool %zu: %s\n", i, tools_json[i]["name"].get<std::string>().c_str());
                    }
                }
            } else {
                SRV_WRN("[TOOLS DEBUG] parse_options: Parsed tools JSON is not an array: %s\n", tools_json.dump().c_str());
            }
        } catch (const json::parse_error& e) {
            // Malformed tools are dropped (logged), not fatal.
            SRV_WRN("Failed to parse tools JSON from proto: %s\n", e.what());
            SRV_WRN("[TOOLS DEBUG] parse_options: Tools string that failed to parse: %s\n", predict->tools().c_str());
        }
    } else {
        SRV_INF("%s", "[TOOLS DEBUG] parse_options: No tools received from Go layer (predict->tools() is empty)\n");
    }

    if (data.contains("tools")) {
        SRV_INF("[TOOLS DEBUG] parse_options: Tools successfully added to data, count: %zu\n",
            data["tools"].is_array() ? data["tools"].size() : 0);
    } else {
        SRV_INF("%s", "[TOOLS DEBUG] parse_options: WARNING - Tools NOT in data after extraction!\n");
    }
    // tool_choice may be a JSON string ("auto"/"none"/...), a JSON object, or
    // a plain non-JSON string (used verbatim on parse failure).
    if (!predict->toolchoice().empty()) {
        try {
            json tool_choice_json = json::parse(predict->toolchoice());
            if (tool_choice_json.is_string()) {
                data["tool_choice"] = tool_choice_json.get<std::string>();
                SRV_DBG("[TOOLS DEBUG] Received tool_choice from Go layer: %s\n", tool_choice_json.get<std::string>().c_str());
            } else {
                data["tool_choice"] = tool_choice_json;
                SRV_DBG("[TOOLS DEBUG] Received tool_choice object from Go layer: %s\n", tool_choice_json.dump().c_str());
            }
            SRV_INF("Extracted tool_choice from proto: %s\n", predict->toolchoice().c_str());
        } catch (const json::parse_error& e) {
            data["tool_choice"] = predict->toolchoice();
            SRV_INF("Extracted tool_choice as string: %s\n", predict->toolchoice().c_str());
        }
    }

    // Logprobs: the server uses "n_probs" internally, the OpenAI-compatible
    // layer uses "logprobs" — set both from the same request field.
    if (predict->logprobs() > 0) {
        data["logprobs"] = predict->logprobs();
        data["n_probs"] = predict->logprobs();
        SRV_INF("Using logprobs: %d\n", predict->logprobs());
    }
    if (predict->toplogprobs() > 0) {
        data["top_logprobs"] = predict->toplogprobs();
        SRV_INF("Using top_logprobs: %d\n", predict->toplogprobs());
    }

    // logit_bias arrives as a JSON string; invalid JSON is logged and ignored.
    if (!predict->logitbias().empty()) {
        try {
            json logit_bias_json = json::parse(predict->logitbias());
            data["logit_bias"] = logit_bias_json;
            SRV_INF("Using logit_bias: %s\n", predict->logitbias().c_str());
        } catch (const json::parse_error& e) {
            SRV_ERR("Failed to parse logit_bias JSON from proto: %s\n", e.what());
        }
    }

    data["ignore_eos"] = predict->ignoreeos();
    data["embeddings"] = predict->embeddings();

    // Correlation id is propagated so responses can be matched to requests.
    data["correlation_id"] = predict->correlationid();

    // Attach raw (base64) image payloads, indexed by position.
    for (int i = 0; i < predict->images_size(); i++) {
        data["image_data"].push_back(json
            {
                {"id", i},
                {"data", predict->images(i)},
            });
    }

    // Attach raw (base64) audio payloads, indexed by position.
    for (int i = 0; i < predict->audios_size(); i++) {
        data["audio_data"].push_back(json
            {
                {"id", i},
                {"data", predict->audios(i)},
            });
    }

    data["stop"] = predict->stopprompts();

    // Serialize model-level grammar triggers captured at load time.
    // NOTE(review): the type is always serialized as WORD regardless of
    // trigger.type (LoadModel may have converted some to TOKEN) — confirm
    // this is intended before changing.
    if (!params_base.sampling.grammar_triggers.empty()) {
        json grammar_triggers = json::array();
        for (const auto& trigger : params_base.sampling.grammar_triggers) {
            json trigger_json;
            trigger_json["value"] = trigger.value;
            trigger_json["type"] = static_cast<int>(COMMON_GRAMMAR_TRIGGER_TYPE_WORD);
            grammar_triggers.push_back(trigger_json);
        }
        data["grammar_triggers"] = grammar_triggers;
    }

    // Detokenize preserved tokens back to text pieces for the server layer.
    if (!params_base.sampling.preserved_tokens.empty()) {
        json preserved_tokens = json::array();
        for (const auto& token : params_base.sampling.preserved_tokens) {
            preserved_tokens.push_back(common_token_to_piece(ctx, token));
        }
        data["preserved_tokens"] = preserved_tokens;
    }

    return data;
}
|
|
|
|
|
|
|
|
// KV-cache quantization types accepted by the "cachetypekey"/"cachetypevalue"
// model options; also used to render the list of supported names.
const std::vector<ggml_type> kv_cache_types = {
    GGML_TYPE_F32,
    GGML_TYPE_F16,
    GGML_TYPE_BF16,
    GGML_TYPE_Q8_0,
    GGML_TYPE_Q4_0,
    GGML_TYPE_Q4_1,
    GGML_TYPE_IQ4_NL,
    GGML_TYPE_Q5_0,
    GGML_TYPE_Q5_1,
};
|
|
|
|
|
// Map a type name (as produced by ggml_type_name) back to its ggml_type.
// Throws std::runtime_error when the name is not a supported KV-cache type.
static ggml_type kv_cache_type_from_str(const std::string & s) {
    for (size_t idx = 0; idx < kv_cache_types.size(); ++idx) {
        const ggml_type candidate = kv_cache_types[idx];
        if (s == ggml_type_name(candidate)) {
            return candidate;
        }
    }
    throw std::runtime_error("Unsupported cache type: " + s);
}
|
|
|
|
|
static std::string get_all_kv_cache_types() { |
|
|
std::ostringstream msg; |
|
|
for (const auto & type : kv_cache_types) { |
|
|
msg << ggml_type_name(type) << (&type == &kv_cache_types.back() ? "" : ", "); |
|
|
} |
|
|
return msg.str(); |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
static void add_rpc_devices(std::string servers) { |
|
|
auto rpc_servers = string_split<std::string>(servers, ','); |
|
|
|
|
|
for (std::string & server : rpc_servers) |
|
|
{ |
|
|
server.erase(0, server.find_first_not_of(" \t\n\r")); |
|
|
server.erase(server.find_last_not_of(" \t\n\r") + 1); |
|
|
} |
|
|
if (rpc_servers.empty()) { |
|
|
throw std::invalid_argument("no RPC servers specified"); |
|
|
} |
|
|
ggml_backend_reg_t rpc_reg = ggml_backend_reg_by_name("RPC"); |
|
|
if (!rpc_reg) { |
|
|
throw std::invalid_argument("failed to find RPC backend"); |
|
|
} |
|
|
typedef ggml_backend_reg_t (*ggml_backend_rpc_add_server_t)(const char * endpoint); |
|
|
ggml_backend_rpc_add_server_t ggml_backend_rpc_add_server_fn = (ggml_backend_rpc_add_server_t) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_server"); |
|
|
if (!ggml_backend_rpc_add_server_fn) { |
|
|
throw std::invalid_argument("failed to find RPC add server function"); |
|
|
} |
|
|
for (const auto & server : rpc_servers) { |
|
|
ggml_backend_reg_t reg = ggml_backend_rpc_add_server_fn(server.c_str()); |
|
|
ggml_backend_register(reg); |
|
|
} |
|
|
} |
|
|
|
|
|
// Parse a boolean option value. Recognizes "true"/"1"/"yes"/"on"/"enabled"
// and the corresponding negatives; leaves `out` untouched otherwise.
static bool parse_bool_option(const std::string & value, bool & out) {
    if (value == "true" || value == "1" || value == "yes" || value == "on" || value == "enabled") {
        out = true;
        return true;
    }
    if (value == "false" || value == "0" || value == "no" || value == "off" || value == "disabled") {
        out = false;
        return true;
    }
    return false;
}

// Parse an integer option value into `out`; keeps the previous value on
// malformed input (mirrors the original swallow-and-keep-default behavior).
template <typename T>
static void parse_int_option(const std::string & value, T & out) {
    try {
        out = std::stoi(value);
    } catch (const std::exception &) {
        // keep previous value on malformed input
    }
}

// Translate a gRPC ModelOptions request into llama.cpp common_params.
// Sets fixed server-side defaults first, then applies the free-form
// "name[:value]" strings from request->options(), then the typed request
// fields (tensor split, LoRA, rope/yarn, flash attention, ...).
static void params_parse(server_context& , const backend::ModelOptions* request,
                         common_params & params) {

    params.model.path = request->modelfile();
    if (!request->mmproj().empty()) {
        // Multimodal projector sits alongside the model.
        params.mmproj.path = request->mmproj();
    }

    params.model_alias = request->modelfile();
    if (!request->cachetypekey().empty()) {
        params.cache_type_k = kv_cache_type_from_str(request->cachetypekey());
    }
    if (!request->cachetypevalue().empty()) {
        params.cache_type_v = kv_cache_type_from_str(request->cachetypevalue());
    }
    params.n_ctx = request->contextsize();
    params.cpuparams.n_threads = request->threads();
    params.n_gpu_layers = request->ngpulayers();
    params.n_batch = request->nbatch();
    // Micro-batch follows the logical batch size.
    params.n_ubatch = request->nbatch();

    // Fixed defaults; each can be overridden through request->options() below.
    params.ctx_shift = false;
    params.cache_ram_mib = -1;
    params.n_parallel = 1;

    std::string grpc_servers_option = "";

    params.fit_params = true;
    // Default fit target: 1 GiB per device.
    params.fit_params_target = std::vector<size_t>(llama_max_devices(), 1024 * 1024 * 1024);
    params.fit_params_min_ctx = 4096;
    params.n_cache_reuse = 0;
    params.slot_prompt_similarity = 0.1f;
    params.swa_full = false;
    params.cont_batching = true;
    params.check_tensors = false;
    params.warmup = true;
    params.no_op_offload = false;
    params.kv_unified = false;
    params.n_ctx_checkpoints = 8;

    // Free-form options arrive as "name" or "name:value" strings.
    // Parsed with std::string operations instead of strtok(): strtok is not
    // thread-safe (gRPC handlers can run concurrently), it truncated values at
    // a second ':' (breaking e.g. "grpc_servers:host:port"), and an empty
    // option string made it return NULL, feeding strcmp(NULL, ...) — UB.
    for (int i = 0; i < request->options_size(); i++) {
        const std::string opt = request->options(i);
        const size_t colon = opt.find(':');
        const std::string optname = opt.substr(0, colon == std::string::npos ? opt.size() : colon);
        // A bare name or a trailing ':' both mean "true" (flag form).
        const bool has_val = colon != std::string::npos && colon + 1 < opt.size();
        const std::string optval_str = has_val ? opt.substr(colon + 1) : "true";

        if (optname == "context_shift") {
            parse_bool_option(optval_str, params.ctx_shift);
        } else if (optname == "use_jinja" || optname == "jinja") {
            parse_bool_option(optval_str, params.use_jinja);
        } else if (optname == "cache_ram") {
            if (has_val) {
                parse_int_option(optval_str, params.cache_ram_mib);
            }
        } else if (optname == "parallel" || optname == "n_parallel") {
            if (has_val) {
                parse_int_option(optval_str, params.n_parallel);
                if (params.n_parallel > 1) {
                    // Multiple slots require continuous batching.
                    params.cont_batching = true;
                }
            }
        } else if (optname == "grpc_servers" || optname == "rpc_servers") {
            if (has_val) {
                grpc_servers_option = optval_str;
            }
        } else if (optname == "fit_params" || optname == "fit") {
            parse_bool_option(optval_str, params.fit_params);
        } else if (optname == "fit_params_target" || optname == "fit_target") {
            if (has_val) {
                try {
                    // Value is a ','- or '/'-separated list of per-device MiB
                    // targets, or a single value applied to all devices.
                    std::string arg_next = optval_str;
                    const std::regex regex{ R"([,/]+)" };
                    std::sregex_token_iterator it{ arg_next.begin(), arg_next.end(), regex, -1 };
                    std::vector<std::string> split_arg{ it, {} };
                    if (split_arg.size() >= llama_max_devices()) {
                        // More entries than devices: ignore the option.
                        continue;
                    }
                    if (split_arg.size() == 1) {
                        size_t value_mib = std::stoul(split_arg[0]);
                        std::fill(params.fit_params_target.begin(), params.fit_params_target.end(), value_mib * 1024 * 1024);
                    } else {
                        for (size_t j = 0; j < split_arg.size() && j < params.fit_params_target.size(); j++) {
                            params.fit_params_target[j] = std::stoul(split_arg[j]) * 1024 * 1024;
                        }
                    }
                } catch (const std::exception& e) {
                    // malformed list: keep defaults
                }
            }
        } else if (optname == "fit_params_min_ctx" || optname == "fit_ctx") {
            if (has_val) {
                parse_int_option(optval_str, params.fit_params_min_ctx);
            }
        } else if (optname == "n_cache_reuse" || optname == "cache_reuse") {
            if (has_val) {
                parse_int_option(optval_str, params.n_cache_reuse);
            }
        } else if (optname == "slot_prompt_similarity" || optname == "sps") {
            if (has_val) {
                try {
                    params.slot_prompt_similarity = std::stof(optval_str);
                } catch (const std::exception& e) {
                    // malformed float: keep default
                }
            }
        } else if (optname == "swa_full") {
            parse_bool_option(optval_str, params.swa_full);
        } else if (optname == "cont_batching" || optname == "continuous_batching") {
            parse_bool_option(optval_str, params.cont_batching);
        } else if (optname == "check_tensors") {
            parse_bool_option(optval_str, params.check_tensors);
        } else if (optname == "warmup") {
            parse_bool_option(optval_str, params.warmup);
        } else if (optname == "no_op_offload") {
            parse_bool_option(optval_str, params.no_op_offload);
        } else if (optname == "kv_unified" || optname == "unified_kv") {
            parse_bool_option(optval_str, params.kv_unified);
        } else if (optname == "n_ctx_checkpoints" || optname == "ctx_checkpoints") {
            if (has_val) {
                parse_int_option(optval_str, params.n_ctx_checkpoints);
            }
        }
    }

    // Environment fallback for parallelism when not set via options.
    if (params.n_parallel == 1) {
        const char *env_parallel = std::getenv("LLAMACPP_PARALLEL");
        if (env_parallel != NULL) {
            try {
                params.n_parallel = std::stoi(env_parallel);
                if (params.n_parallel > 1) {
                    params.cont_batching = true;
                }
            } catch (const std::exception& e) {
                // malformed env value: keep default
            }
        }
    }

    // RPC device registration: option takes precedence over the environment.
    if (!grpc_servers_option.empty()) {
        add_rpc_devices(grpc_servers_option);
    } else {
        const char *llama_grpc_servers = std::getenv("LLAMACPP_GRPC_SERVERS");
        if (llama_grpc_servers != NULL) {
            add_rpc_devices(std::string(llama_grpc_servers));
        }
    }

    // Metadata key/value overrides, e.g. "tokenizer.ggml.add_bos_token=bool:false".
    if (request->overrides_size() > 0) {
        for (int i = 0; i < request->overrides_size(); i++) {
            string_parse_kv_override(request->overrides(i).c_str(), params.kv_overrides);
        }
    }

    // llama.cpp expects the override list to be terminated by an entry with an
    // empty key.
    if (!params.kv_overrides.empty()) {
        params.kv_overrides.emplace_back();
        params.kv_overrides.back().key[0] = 0;
    }

    // Per-device tensor split, ','- or '/'-separated; missing entries are 0.
    if (!request->tensorsplit().empty()) {
        std::string arg_next = request->tensorsplit();

        const std::regex regex{ R"([,/]+)" };
        std::sregex_token_iterator it{ arg_next.begin(), arg_next.end(), regex, -1 };
        std::vector<std::string> split_arg{ it, {} };

        GGML_ASSERT(split_arg.size() <= llama_max_devices());

        for (size_t i_device = 0; i_device < llama_max_devices(); ++i_device) {
            if (i_device < split_arg.size()) {
                params.tensor_split[i_device] = std::stof(split_arg[i_device]);
            }
            else {
                params.tensor_split[i_device] = 0.0f;
            }
        }
    }

    if (!request->maingpu().empty()) {
        params.main_gpu = std::stoi(request->maingpu());
    }
    if (!request->loraadapter().empty() && !request->lorabase().empty()) {
        float scale_factor = 1.0f;
        if (request->lorascale() != 0.0f) {
            scale_factor = request->lorascale();
        }

        // LoRA adapter paths are resolved relative to the model's directory.
        std::string model_dir = params.model.path.substr(0, params.model.path.find_last_of("/\\"));
        common_adapter_lora_info lora_info;
        lora_info.path = model_dir + "/" + request->loraadapter();
        lora_info.scale = scale_factor;
        lora_info.task_name = "";
        lora_info.prompt_prefix = "";
        lora_info.ptr = nullptr;
        params.lora_adapters.push_back(std::move(lora_info));
    }
    params.use_mlock = request->mlock();
    params.use_mmap = request->mmap();

    // Flash attention: tri-state (enabled / disabled / auto); other values
    // leave the default untouched.
    if (request->flashattention() == "on" || request->flashattention() == "enabled") {
        params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_ENABLED;
    } else if (request->flashattention() == "off" || request->flashattention() == "disabled") {
        params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED;
    } else if (request->flashattention() == "auto") {
        params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO;
    }

    params.no_kv_offload = request->nokvoffload();
    // Reranking implies embedding mode with rank pooling.
    params.embedding = request->embeddings() || request->reranking();
    if (request->reranking()) {
        params.pooling_type = LLAMA_POOLING_TYPE_RANK;
    }

    if (request->ropescaling() == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_NONE; }
    else if (request->ropescaling() == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_YARN; }
    else if (request->ropescaling() == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_LINEAR; }

    // YaRN / RoPE tuning: 0.0 on the wire means "not set", keep the default.
    if ( request->yarnextfactor() != 0.0f ) {
        params.yarn_ext_factor = request->yarnextfactor();
    }
    if ( request->yarnattnfactor() != 0.0f ) {
        params.yarn_attn_factor = request->yarnattnfactor();
    }
    if ( request->yarnbetafast() != 0.0f ) {
        params.yarn_beta_fast = request->yarnbetafast();
    }
    if ( request->yarnbetaslow() != 0.0f ) {
        params.yarn_beta_slow = request->yarnbetaslow();
    }
    if ( request->ropefreqbase() != 0.0f ) {
        params.rope_freq_base = request->ropefreqbase();
    }
    if ( request->ropefreqscale() != 0.0f ) {
        params.rope_freq_scale = request->ropefreqscale();
    }

    // Grammar trigger words from the request; converted to token triggers
    // later (in LoadModel) where a tokenizer is available.
    if (request->grammartriggers_size() > 0) {
        for (int i = 0; i < request->grammartriggers_size(); i++) {
            const auto & word = request->grammartriggers(i).word();
            common_grammar_trigger trigger;
            trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_WORD;
            trigger.value = word;
            params.sampling.grammar_triggers.push_back(std::move(trigger));
        }
    }
}
|
|
|
|
|
|
|
|
|
|
|
// gRPC service implementation backing this llama.cpp server process.
// Holds a reference to the single server_context plus a snapshot of the
// parameters taken after a successful LoadModel().
class BackendServiceImpl final : public backend::Backend::Service {
private:
    server_context& ctx_server;   // shared inference context (not owned)
    common_params params_base;    // params captured at LoadModel() time

public:
    // NOTE(review): single-argument constructor is implicit — consider
    // marking it `explicit`.
    BackendServiceImpl(server_context& ctx) : ctx_server(ctx) {}
|
|
|
|
|
// Liveness probe: unconditionally reports "OK" once the service is serving.
grpc::Status Health(ServerContext* , const backend::HealthMessage* , backend::Reply* reply) override {
    static const char ok_message[] = "OK";
    reply->set_message(ok_message);
    return Status::OK;
}
|
|
|
|
|
// Load the model described by `request` into the shared server context.
// On success: sets result/loaded_model and snapshots params into params_base.
// On failure: returns INTERNAL, appending any GGML error text captured while
// loading to the message.
grpc::Status LoadModel(ServerContext* , const backend::ModelOptions* request, backend::Result* result) override {
    // Translate the proto request into llama.cpp parameters.
    common_params params;
    params_parse(ctx_server, request, params);

    common_init();

    common_log_set_verbosity_thold(params.verbosity);

    // One-time backend / NUMA initialization.
    llama_backend_init();
    llama_numa_init(params.numa);

    LOG_INF("system info: n_threads = %d, n_threads_batch = %d, total_threads = %d\n", params.cpuparams.n_threads, params.cpuparams_batch.n_threads, std::thread::hardware_concurrency());
    LOG_INF("\n");
    LOG_INF("%s\n", common_params_get_system_info(params).c_str());
    LOG_INF("\n");

    // Temporarily intercept the GGML log stream so error lines emitted while
    // loading can be reported back to the caller.
    struct error_capture {
        std::string captured_error;          // "; "-joined GGML error lines
        std::mutex error_mutex;              // log callback may fire from other threads
        ggml_log_callback original_callback; // previous callback, restored below
        void* original_user_data;
    } error_capture_data;

    // Save the current logger so it can be chained to and restored.
    llama_log_get(&error_capture_data.original_callback, &error_capture_data.original_user_data);

    llama_log_set([](ggml_log_level level, const char * text, void * user_data) {
        auto* capture = static_cast<error_capture*>(user_data);
        if (level == GGML_LOG_LEVEL_ERROR) {
            std::lock_guard<std::mutex> lock(capture->error_mutex);
            // Strip trailing newlines before accumulating.
            std::string msg(text);
            while (!msg.empty() && (msg.back() == '\n' || msg.back() == '\r')) {
                msg.pop_back();
            }
            if (!msg.empty()) {
                if (!capture->captured_error.empty()) {
                    capture->captured_error.append("; ");
                }
                capture->captured_error.append(msg);
            }
        }
        // Chain to the original logger so normal output still appears.
        if (capture->original_callback) {
            capture->original_callback(level, text, capture->original_user_data);
        }
    }, &error_capture_data);

    bool load_success = ctx_server.load_model(params);

    // Restore the original logger before error_capture_data goes out of scope.
    llama_log_set(error_capture_data.original_callback, error_capture_data.original_user_data);

    if (!load_success) {
        // Assemble a descriptive error: model path, optional mmproj/draft
        // model, plus any captured GGML error text.
        std::string error_msg = "Failed to load model: " + params.model.path;
        if (!params.mmproj.path.empty()) {
            error_msg += " (with mmproj: " + params.mmproj.path + ")";
        }
        if (params.has_speculative() && !params.speculative.model.path.empty()) {
            error_msg += " (with draft model: " + params.speculative.model.path + ")";
        }
        {
            std::lock_guard<std::mutex> lock(error_capture_data.error_mutex);
            if (!error_capture_data.captured_error.empty()) {
                error_msg += ". Error: " + error_capture_data.captured_error;
            } else {
                error_msg += ". Model file may not exist or be invalid.";
            }
        }
        result->set_message(error_msg);
        result->set_success(false);
        return grpc::Status(grpc::StatusCode::INTERNAL, error_msg);
    }

    // Post-load: convert single-token grammar trigger WORDs into TOKEN
    // triggers (tokenizer is available only now) and make sure those tokens
    // are preserved during detokenization.
    if (!params.sampling.grammar_triggers.empty()) {
        std::vector<common_grammar_trigger> processed_triggers;
        for (const auto& trigger : params.sampling.grammar_triggers) {
            if (trigger.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) {
                auto ids = common_tokenize(ctx_server.impl->vocab, trigger.value, false, true);
                if (ids.size() == 1) {
                    auto token = ids[0];
                    if (params.sampling.preserved_tokens.find(token) == params.sampling.preserved_tokens.end()) {
                        params.sampling.preserved_tokens.insert(token);
                        LOG_INF("Added grammar trigger token to preserved tokens: %d (`%s`)\n", token, trigger.value.c_str());
                    }
                    LOG_INF("Grammar trigger token: %d (`%s`)\n", token, trigger.value.c_str());
                    common_grammar_trigger processed_trigger;
                    processed_trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN;
                    processed_trigger.value = trigger.value;
                    processed_trigger.token = token;
                    processed_triggers.push_back(std::move(processed_trigger));
                } else {
                    // Multi-token word: keep as a WORD trigger.
                    LOG_INF("Grammar trigger word: `%s`\n", trigger.value.c_str());
                    processed_triggers.push_back(trigger);
                }
            } else {
                processed_triggers.push_back(trigger);
            }
        }

        params.sampling.grammar_triggers = std::move(processed_triggers);
    }

    result->set_message("Loading succeeded");
    result->set_success(true);
    // Unblocks start_llama_server()'s wait loop.
    loaded_model = true;

    // Snapshot (including the trigger post-processing above) for use by the
    // prediction handlers.
    params_base = params;

    return Status::OK;
}
|
|
|
|
|
|
|
|
// Extract logprobs from a completion result, accepting the three layouts the
// server can produce: OpenAI-style choices[0].logprobs, a top-level
// completion_probabilities array (wrapped under "content"), or a top-level
// "logprobs" object. Returns an empty JSON object when none is present.
static json extract_logprobs_from_json(const json& res_json) {
    // OpenAI chat/completion shape: choices[0].logprobs
    if (res_json.contains("choices") && res_json["choices"].is_array() &&
        res_json["choices"].size() > 0 && res_json["choices"][0].contains("logprobs")) {
        return res_json["choices"][0]["logprobs"];
    }

    // Legacy llama.cpp shape: top-level completion_probabilities
    if (res_json.contains("completion_probabilities")) {
        json wrapped = json::object();
        wrapped["content"] = res_json["completion_probabilities"];
        return wrapped;
    }

    // Already-extracted shape: top-level logprobs
    if (res_json.contains("logprobs")) {
        return res_json["logprobs"];
    }

    return json::object();
}
|
|
|
|
|
// Streams completion results for one prediction request.
//
// If the request asks for the tokenizer template, an OpenAI-compatible chat body is
// built from the proto messages (attaching base64 images/audio to the last user
// message) and rendered through the model's chat template; otherwise the raw prompt
// is used. Tasks are posted to the server queue and every partial result is written
// to `writer` as a backend::Reply until generation finishes or the client cancels.
//
// Returns: OK on success, FAILED_PRECONDITION when no model is loaded,
// INVALID_ARGUMENT for embedding models / malformed input, CANCELLED on client
// cancellation, INTERNAL when the first task result reports an error.
grpc::Status PredictStream(grpc::ServerContext* context, const backend::PredictOptions* request, grpc::ServerWriter<backend::Reply>* writer) override {
    if (params_base.model.path.empty()) {
        return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
    }

    json data = parse_options(true, request, params_base, ctx_server.get_llama_context());

    if (params_base.embedding) {
        return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "Embedding is not supported in streaming mode");
    }

    auto completion_id = gen_chatcmplid();
    auto rd = ctx_server.get_response_reader();

    // Appends OpenAI-style image_url / input_audio chunks (base64 payloads taken from
    // the request) to a multimodal content array.
    auto append_media_chunks = [&](json & content_array) {
        for (int j = 0; j < request->images_size(); j++) {
            json image_url;
            image_url["url"] = "data:image/jpeg;base64," + request->images(j);
            json image_chunk;
            image_chunk["type"]      = "image_url";
            image_chunk["image_url"] = image_url;
            content_array.push_back(image_chunk);
        }
        for (int j = 0; j < request->audios_size(); j++) {
            json input_audio;
            input_audio["data"]   = request->audios(j);
            input_audio["format"] = "wav";
            json audio_chunk;
            audio_chunk["type"]        = "input_audio";
            audio_chunk["input_audio"] = input_audio;
            content_array.push_back(audio_chunk);
        }
    };

    // Converts a single result json object into a backend::Reply and streams it.
    auto write_result = [&](const json & res) {
        backend::Reply reply;
        reply.set_message(res.value("content", ""));
        reply.set_tokens(res.value("tokens_predicted", 0));
        reply.set_prompt_tokens(res.value("tokens_evaluated", 0));

        if (res.contains("timings")) {
            reply.set_timing_prompt_processing(res.at("timings").value("prompt_ms", 0.0));
            reply.set_timing_token_generation(res.at("timings").value("predicted_ms", 0.0));
        }

        json logprobs_json = extract_logprobs_from_json(res);
        if (!logprobs_json.empty() && !logprobs_json.is_null()) {
            reply.set_logprobs(logprobs_json.dump());
        }

        writer->Write(reply);
    };

    // A task result is either a single object or an array of objects (one per input).
    auto write_results = [&](const json & res_json) {
        if (res_json.is_array()) {
            for (const auto & res : res_json) {
                write_result(res);
            }
        } else {
            write_result(res_json);
        }
    };

    try {
        std::vector<server_task> tasks;

        std::string prompt_str;
        std::vector<raw_buffer> files;

        if (request->usetokenizertemplate() && request->messages_size() > 0 && ctx_server.impl->chat_templates != nullptr) {
            // Build an OpenAI-compatible chat body and let the model's chat template
            // render the final prompt string.
            json body_json;
            json messages_json = json::array();

            // Media chunks are attached only to the last user message.
            int last_user_msg_idx = -1;
            for (int i = request->messages_size() - 1; i >= 0; i--) {
                if (request->messages(i).role() == "user") {
                    last_user_msg_idx = i;
                    break;
                }
            }

            const bool has_images_or_audio = (request->images_size() > 0 || request->audios_size() > 0);

            for (int i = 0; i < request->messages_size(); i++) {
                const auto & msg = request->messages(i);
                json msg_json;
                msg_json["role"] = msg.role();

                const bool is_last_user_msg = (i == last_user_msg_idx);

                if (!msg.content().empty()) {
                    // Content may arrive as a JSON-encoded value (string/array/object)
                    // or as a plain string.
                    json content_val;
                    try {
                        content_val = json::parse(msg.content());
                        if (content_val.is_null()) {
                            content_val = "";
                        }
                    } catch (const json::parse_error&) {
                        content_val = msg.content();
                    }

                    // Chat templates expect content to be a string or an array of chunks.
                    if (content_val.is_object()) {
                        content_val = content_val.dump();
                    }

                    if (content_val.is_string() && is_last_user_msg && has_images_or_audio) {
                        // Promote plain text to a multimodal content array so the media
                        // chunks can be attached alongside it.
                        json content_array = json::array();
                        content_array.push_back({{"type", "text"}, {"text", content_val.get<std::string>()}});
                        append_media_chunks(content_array);
                        msg_json["content"] = content_array;
                    } else if (content_val.is_null()) {
                        // Defensive: null content breaks template rendering.
                        msg_json["content"] = "";
                    } else {
                        msg_json["content"] = content_val;
                    }
                } else if (is_last_user_msg && has_images_or_audio) {
                    // No text, but media to attach: content becomes a pure media array.
                    json content_array = json::array();
                    append_media_chunks(content_array);
                    msg_json["content"] = content_array;
                } else if (msg.role() == "tool") {
                    // Only reachable with empty content (non-empty content is handled by
                    // the first branch); templates still require a content field.
                    SRV_INF("[CONTENT DEBUG] PredictStream: Message %d (tool): empty content, set to empty string\n", i);
                    msg_json["content"] = "";
                } else if (!msg_json.contains("content")) {
                    // Templates require every message to carry a content field.
                    msg_json["content"] = "";
                }

                // Optional OpenAI-style message fields.
                if (!msg.name().empty()) {
                    msg_json["name"] = msg.name();
                }
                if (!msg.tool_call_id().empty()) {
                    msg_json["tool_call_id"] = msg.tool_call_id();
                }
                if (!msg.reasoning_content().empty()) {
                    msg_json["reasoning_content"] = msg.reasoning_content();
                }

                if (!msg.tool_calls().empty()) {
                    // tool_calls arrives as a JSON-encoded array; a parse failure is
                    // logged and the field is omitted.
                    try {
                        json tool_calls = json::parse(msg.tool_calls());
                        msg_json["tool_calls"] = tool_calls;
                        SRV_INF("[TOOL CALLS DEBUG] PredictStream: Message %d has tool_calls: %s\n", i, tool_calls.dump().c_str());

                        // Some templates render nothing for an assistant message whose
                        // content is empty, which breaks role-alternation checks - use a
                        // single space instead of an empty string.
                        if (!msg_json.contains("content") || (msg_json["content"].is_string() && msg_json["content"].get<std::string>().empty())) {
                            SRV_INF("[CONTENT DEBUG] PredictStream: Message %d has tool_calls but empty content, setting to space\n", i);
                            msg_json["content"] = " ";
                        }

                        // Log each call's name/arguments for troubleshooting.
                        if (tool_calls.is_array()) {
                            for (size_t tc_idx = 0; tc_idx < tool_calls.size(); tc_idx++) {
                                const auto & tc = tool_calls[tc_idx];
                                std::string tool_name = "unknown";
                                std::string tool_args = "{}";
                                if (tc.contains("function")) {
                                    const auto & func = tc["function"];
                                    if (func.contains("name")) {
                                        tool_name = func["name"].get<std::string>();
                                    }
                                    if (func.contains("arguments")) {
                                        tool_args = func["arguments"].is_string() ?
                                            func["arguments"].get<std::string>() :
                                            func["arguments"].dump();
                                    }
                                } else if (tc.contains("name")) {
                                    tool_name = tc["name"].get<std::string>();
                                    if (tc.contains("arguments")) {
                                        tool_args = tc["arguments"].is_string() ?
                                            tc["arguments"].get<std::string>() :
                                            tc["arguments"].dump();
                                    }
                                }
                                SRV_INF("[TOOL CALLS DEBUG] PredictStream: Message %d, tool_call %zu: name=%s, arguments=%s\n",
                                       i, tc_idx, tool_name.c_str(), tool_args.c_str());
                            }
                        }
                    } catch (const json::parse_error& e) {
                        SRV_WRN("Failed to parse tool_calls JSON: %s\n", e.what());
                    }
                }

                messages_json.push_back(msg_json);
            }

            // Defensive pass: chat templates reject messages whose content is null or
            // missing, so normalize those to empty strings before rendering.
            for (size_t idx = 0; idx < messages_json.size(); idx++) {
                auto & msg = messages_json[idx];
                if (!msg.contains("content") || msg["content"].is_null()) {
                    SRV_INF("[CONTENT DEBUG] PredictStream: Safety check normalizing message %zu content to empty string\n", idx);
                    msg["content"] = "";
                }
            }

            int tool_msg_count = 0;
            for (const auto & msg : messages_json) {
                if (msg.contains("role") && msg["role"] == "tool") {
                    tool_msg_count++;
                }
            }
            SRV_DBG("[TOOLS DEBUG] PredictStream: Built %d tool messages out of %zu total messages\n", tool_msg_count, messages_json.size());
            SRV_DBG("[CONVERSATION DEBUG] PredictStream: Full messages array:\n%s\n", messages_json.dump(2).c_str());

            body_json["messages"] = messages_json;
            body_json["stream"]   = true;

            // A grammar supplied by the Go layer takes precedence over any grammar the
            // chat template would derive from the tools definition.
            const bool has_grammar_from_go = data.contains("grammar") &&
                                             data["grammar"].is_string() &&
                                             !data["grammar"].get<std::string>().empty();

            if (!has_grammar_from_go) {
                if (data.contains("tools")) {
                    body_json["tools"] = data["tools"];
                    SRV_INF("Using tools from data (NoGrammar=true): %s\n", data["tools"].dump().c_str());
                } else {
                    SRV_WRN("%s", "No tools found in data - tool calls will not work without tools field\n");
                }
                if (data.contains("tool_choice")) {
                    if (data["tool_choice"].is_string()) {
                        body_json["tool_choice"] = data["tool_choice"].get<std::string>();
                    } else if (data["tool_choice"].is_object()) {
                        // The object form (forcing one specific function) is approximated
                        // with "required".
                        body_json["tool_choice"] = "required";
                        SRV_INF("Converted object tool_choice to 'required': %s\n", data["tool_choice"].dump().c_str());
                    } else {
                        body_json["tool_choice"] = data["tool_choice"].dump();
                    }
                    SRV_INF("Using tool_choice: %s\n", body_json["tool_choice"].get<std::string>().c_str());
                } else {
                    body_json["tool_choice"] = "auto";
                }
            } else {
                SRV_INF("%s", "Grammar provided from Go layer - using it instead of template-generated grammar\n");
            }

            // Optional generation controls passed straight through to the parser.
            for (const char * key : {"json_schema", "response_format", "chat_template_kwargs", "parallel_tool_calls", "add_generation_prompt"}) {
                if (data.contains(key)) {
                    body_json[key] = data[key];
                }
            }
            if (has_grammar_from_go) {
                body_json["grammar"] = data["grammar"];
            }

            SRV_DBG("[CONVERSATION DEBUG] PredictStream: Full body_json before oaicompat_chat_params_parse:\n%s\n", body_json.dump(2).c_str());

            // Configure the OpenAI-compat parser with this model's templates and the
            // multimodal capabilities actually supported by the loaded projector.
            oaicompat_parser_options parser_opt = ctx_server.impl->oai_parser_opt;
            parser_opt.tmpls       = ctx_server.impl->chat_templates.get();
            parser_opt.allow_image = ctx_server.impl->mctx ? mtmd_support_vision(ctx_server.impl->mctx) : false;
            parser_opt.allow_audio = ctx_server.impl->mctx ? mtmd_support_audio(ctx_server.impl->mctx) : false;

            // Last-chance normalization right before the template runs.
            if (body_json.contains("messages") && body_json["messages"].is_array()) {
                for (auto & msg : body_json["messages"]) {
                    const std::string role_str = msg.contains("role") ? msg["role"].get<std::string>() : "unknown";
                    if (!msg.contains("content") || msg["content"].is_null()) {
                        msg["content"] = "";
                    } else if (role_str == "tool" && msg["content"].is_array()) {
                        // Some templates cannot iterate array content on tool messages.
                        msg["content"] = msg["content"].dump();
                    } else if (!msg["content"].is_string() && !msg["content"].is_array()) {
                        if (msg["content"].is_object()) {
                            msg["content"] = msg["content"].dump();
                        } else {
                            msg["content"] = "";
                        }
                    }
                }
            }

            // Renders the template; also extracts base64 media into `files`.
            json parsed_data = oaicompat_chat_params_parse(body_json, parser_opt, files);

            prompt_str = parsed_data.at("prompt").get<std::string>();

            // Merge everything the parser produced (grammar, stops, tool metadata, ...)
            // into the task data - except the prompt itself, and except the grammar when
            // the Go layer already supplied one.
            for (const auto & item : parsed_data.items()) {
                if (item.key() == "prompt") {
                    continue; // already extracted above
                }
                if (item.key() == "grammar" && has_grammar_from_go) {
                    continue; // keep the caller-supplied grammar
                }
                data[item.key()] = item.value();
            }

            if (data.contains("parse_tool_calls")) {
                SRV_DBG("[TOOLS DEBUG] PredictStream: parse_tool_calls=%s\n", data["parse_tool_calls"].get<bool>() ? "true" : "false");
            }
        } else {
            // No tokenizer template: use the pre-rendered prompt from the parsed
            // options, falling back to the raw prompt field of the request.
            if (data.contains("prompt") && data["prompt"].is_string()) {
                prompt_str = data["prompt"].get<std::string>();
            } else {
                prompt_str = request->prompt();
            }
        }

        const auto type = SERVER_TASK_TYPE_COMPLETION;

        // Outside the template path, media arrives as base64 blobs in the parsed
        // options' image_data / audio_data arrays.
        if (!request->usetokenizertemplate() || request->messages_size() == 0 || ctx_server.impl->chat_templates == nullptr) {
            for (const char * key : {"image_data", "audio_data"}) {
                const auto it = data.find(key);
                if (it != data.end() && it->is_array()) {
                    for (const auto & entry : *it) {
                        files.push_back(base64_decode(entry["data"].get<std::string>()));
                    }
                }
            }
        }

        const bool has_mtmd = ctx_server.impl->mctx != nullptr;

        // Tokenize the prompt - the multimodal path chunks text and media together.
        std::vector<server_tokens> inputs;
        if (has_mtmd) {
            inputs.push_back(process_mtmd_prompt(ctx_server.impl->mctx, prompt_str, files));
        } else {
            inputs = tokenize_input_prompts(ctx_server.impl->vocab, ctx_server.impl->mctx, prompt_str, true, true);
        }

        tasks.reserve(inputs.size());
        for (size_t i = 0; i < inputs.size(); i++) {
            server_task task = server_task(type);

            task.id     = rd.queue_tasks.get_new_id();
            task.index  = i;
            task.tokens = std::move(inputs[i]);
            task.params = server_task::params_from_json_cmpl(
                ctx_server.impl->vocab,
                params_base,
                ctx_server.get_meta().slot_n_ctx,
                data);
            task.id_slot = json_value(data, "id_slot", -1);

            // Raw streaming replies: no OpenAI-style response wrapping.
            task.params.res_type          = TASK_RESPONSE_TYPE_NONE;
            task.params.oaicompat_cmpl_id = completion_id;

            tasks.push_back(std::move(task));
        }

        rd.post_tasks(std::move(tasks));
    } catch (const std::exception & e) {
        return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, e.what());
    }

    // The first result is awaited separately so task-level errors can be mapped to a
    // gRPC error status instead of a silently-truncated stream.
    server_task_result_ptr first_result = rd.next([&context]() { return context->IsCancelled(); });
    if (first_result == nullptr) {
        return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
    } else if (first_result->is_error()) {
        json error_json = first_result->to_json();
        backend::Reply reply;
        reply.set_message(error_json.value("message", ""));
        writer->Write(reply);
        return grpc::Status(grpc::StatusCode::INTERNAL, error_json.value("message", "Error occurred"));
    }

    write_results(first_result->to_json());

    // Stream the remaining partial results until generation finishes or the client
    // goes away.
    while (rd.has_next()) {
        if (context->IsCancelled()) {
            break;
        }

        auto result = rd.next([&context]() { return context->IsCancelled(); });
        if (result == nullptr) {
            break; // cancelled while waiting
        }

        write_results(result->to_json());
    }

    if (context->IsCancelled()) {
        return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
    }

    return grpc::Status::OK;
}
|
|
|
|
|
grpc::Status Predict(ServerContext* context, const backend::PredictOptions* request, backend::Reply* reply) override { |
|
|
if (params_base.model.path.empty()) { |
|
|
return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded"); |
|
|
} |
|
|
json data = parse_options(true, request, params_base, ctx_server.get_llama_context()); |
|
|
|
|
|
data["stream"] = false; |
|
|
|
|
|
if (params_base.embedding) { |
|
|
return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "Embedding is not supported in Predict mode"); |
|
|
} |
|
|
std::cout << "[PREDICT] Received result: " << data.dump(2) << std::endl; |
|
|
auto completion_id = gen_chatcmplid(); |
|
|
auto rd = ctx_server.get_response_reader(); |
|
|
try { |
|
|
std::vector<server_task> tasks; |
|
|
|
|
|
std::string prompt_str; |
|
|
std::vector<raw_buffer> files; |
|
|
|
|
|
if (request->usetokenizertemplate() && request->messages_size() > 0 && ctx_server.impl->chat_templates != nullptr) { |
|
|
|
|
|
json body_json; |
|
|
json messages_json = json::array(); |
|
|
|
|
|
|
|
|
int last_user_msg_idx = -1; |
|
|
for (int i = request->messages_size() - 1; i >= 0; i--) { |
|
|
if (request->messages(i).role() == "user") { |
|
|
last_user_msg_idx = i; |
|
|
break; |
|
|
} |
|
|
} |
|
|
|
|
|
SRV_INF("[CONTENT DEBUG] Predict: Processing %d messages\n", request->messages_size()); |
|
|
for (int i = 0; i < request->messages_size(); i++) { |
|
|
const auto& msg = request->messages(i); |
|
|
json msg_json; |
|
|
msg_json["role"] = msg.role(); |
|
|
|
|
|
SRV_INF("[CONTENT DEBUG] Predict: Message %d: role=%s, content_empty=%d, content_length=%zu\n", |
|
|
i, msg.role().c_str(), msg.content().empty() ? 1 : 0, msg.content().size()); |
|
|
if (!msg.content().empty()) { |
|
|
SRV_INF("[CONTENT DEBUG] Predict: Message %d content (first 200 chars): %s\n", |
|
|
i, msg.content().substr(0, std::min<size_t>(200, msg.content().size())).c_str()); |
|
|
} |
|
|
|
|
|
bool is_last_user_msg = (i == last_user_msg_idx); |
|
|
bool has_images_or_audio = (request->images_size() > 0 || request->audios_size() > 0); |
|
|
|
|
|
|
|
|
|
|
|
if (!msg.content().empty()) { |
|
|
|
|
|
json content_val; |
|
|
try { |
|
|
content_val = json::parse(msg.content()); |
|
|
|
|
|
if (content_val.is_null()) { |
|
|
SRV_INF("[CONTENT DEBUG] Predict: Message %d parsed JSON is null, converting to empty string\n", i); |
|
|
content_val = ""; |
|
|
} |
|
|
} catch (const json::parse_error&) { |
|
|
|
|
|
content_val = msg.content(); |
|
|
} |
|
|
|
|
|
|
|
|
if (content_val.is_object()) { |
|
|
SRV_INF("[CONTENT DEBUG] Predict: Message %d content is object, converting to string\n", i); |
|
|
content_val = content_val.dump(); |
|
|
} |
|
|
|
|
|
|
|
|
if (content_val.is_string() && is_last_user_msg && has_images_or_audio) { |
|
|
json content_array = json::array(); |
|
|
|
|
|
content_array.push_back({{"type", "text"}, {"text", content_val.get<std::string>()}}); |
|
|
|
|
|
if (request->images_size() > 0) { |
|
|
for (int j = 0; j < request->images_size(); j++) { |
|
|
json image_chunk; |
|
|
image_chunk["type"] = "image_url"; |
|
|
json image_url; |
|
|
image_url["url"] = "data:image/jpeg;base64," + request->images(j); |
|
|
image_chunk["image_url"] = image_url; |
|
|
content_array.push_back(image_chunk); |
|
|
} |
|
|
} |
|
|
|
|
|
if (request->audios_size() > 0) { |
|
|
for (int j = 0; j < request->audios_size(); j++) { |
|
|
json audio_chunk; |
|
|
audio_chunk["type"] = "input_audio"; |
|
|
json input_audio; |
|
|
input_audio["data"] = request->audios(j); |
|
|
input_audio["format"] = "wav"; |
|
|
audio_chunk["input_audio"] = input_audio; |
|
|
content_array.push_back(audio_chunk); |
|
|
} |
|
|
} |
|
|
msg_json["content"] = content_array; |
|
|
} else { |
|
|
|
|
|
|
|
|
if (content_val.is_null()) { |
|
|
SRV_INF("[CONTENT DEBUG] Predict: Message %d content_val was null, setting to empty string\n", i); |
|
|
msg_json["content"] = ""; |
|
|
} else { |
|
|
msg_json["content"] = content_val; |
|
|
SRV_INF("[CONTENT DEBUG] Predict: Message %d content set, type=%s\n", |
|
|
i, content_val.is_string() ? "string" : |
|
|
content_val.is_array() ? "array" : |
|
|
content_val.is_object() ? "object" : "other"); |
|
|
} |
|
|
} |
|
|
} else if (is_last_user_msg && has_images_or_audio) { |
|
|
|
|
|
json content_array = json::array(); |
|
|
if (request->images_size() > 0) { |
|
|
for (int j = 0; j < request->images_size(); j++) { |
|
|
json image_chunk; |
|
|
image_chunk["type"] = "image_url"; |
|
|
json image_url; |
|
|
image_url["url"] = "data:image/jpeg;base64," + request->images(j); |
|
|
image_chunk["image_url"] = image_url; |
|
|
content_array.push_back(image_chunk); |
|
|
} |
|
|
} |
|
|
if (request->audios_size() > 0) { |
|
|
for (int j = 0; j < request->audios_size(); j++) { |
|
|
json audio_chunk; |
|
|
audio_chunk["type"] = "input_audio"; |
|
|
json input_audio; |
|
|
input_audio["data"] = request->audios(j); |
|
|
input_audio["format"] = "wav"; |
|
|
audio_chunk["input_audio"] = input_audio; |
|
|
content_array.push_back(audio_chunk); |
|
|
} |
|
|
} |
|
|
msg_json["content"] = content_array; |
|
|
SRV_INF("[CONTENT DEBUG] Predict: Message %d created content array with media\n", i); |
|
|
} else if (!msg.tool_calls().empty()) { |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
SRV_INF("[CONTENT DEBUG] Predict: Message %d has tool_calls, setting content to space (not empty string)\n", i); |
|
|
msg_json["content"] = " "; |
|
|
} else if (msg.role() == "tool") { |
|
|
|
|
|
|
|
|
SRV_INF("[CONTENT DEBUG] Predict: Message %d is tool role, content_empty=%d\n", i, msg.content().empty() ? 1 : 0); |
|
|
if (msg.content().empty()) { |
|
|
msg_json["content"] = ""; |
|
|
SRV_INF("[CONTENT DEBUG] Predict: Message %d (tool): empty content, set to empty string\n", i); |
|
|
} else { |
|
|
SRV_INF("[CONTENT DEBUG] Predict: Message %d (tool): content exists: %s\n", |
|
|
i, msg.content().substr(0, std::min<size_t>(200, msg.content().size())).c_str()); |
|
|
|
|
|
json content_val; |
|
|
try { |
|
|
content_val = json::parse(msg.content()); |
|
|
SRV_INF("[CONTENT DEBUG] Predict: Message %d (tool): parsed JSON, type=%s\n", |
|
|
i, content_val.is_null() ? "null" : |
|
|
content_val.is_object() ? "object" : |
|
|
content_val.is_string() ? "string" : |
|
|
content_val.is_array() ? "array" : "other"); |
|
|
|
|
|
if (content_val.is_null()) { |
|
|
msg_json["content"] = ""; |
|
|
SRV_INF("[CONTENT DEBUG] Predict: Message %d (tool): null content, converted to empty string\n", i); |
|
|
} else if (content_val.is_object()) { |
|
|
|
|
|
msg_json["content"] = content_val.dump(); |
|
|
SRV_INF("[CONTENT DEBUG] Predict: Message %d (tool): object content, converted to string: %s\n", |
|
|
i, content_val.dump().substr(0, std::min<size_t>(200, content_val.dump().size())).c_str()); |
|
|
} else if (content_val.is_string()) { |
|
|
msg_json["content"] = content_val.get<std::string>(); |
|
|
SRV_INF("[CONTENT DEBUG] Predict: Message %d (tool): string content, using as-is\n", i); |
|
|
} else { |
|
|
|
|
|
msg_json["content"] = content_val.dump(); |
|
|
SRV_INF("[CONTENT DEBUG] Predict: Message %d (tool): %s content, converted to string\n", |
|
|
i, content_val.is_array() ? "array" : "other type"); |
|
|
} |
|
|
} catch (const json::parse_error&) { |
|
|
|
|
|
msg_json["content"] = msg.content(); |
|
|
SRV_INF("[CONTENT DEBUG] Predict: Message %d (tool): not JSON, using as string\n", i); |
|
|
} |
|
|
} |
|
|
} else { |
|
|
|
|
|
|
|
|
if (!msg_json.contains("content")) { |
|
|
SRV_INF("[CONTENT DEBUG] Predict: Message %d (role=%s): no content field, adding empty string\n", |
|
|
i, msg.role().c_str()); |
|
|
msg_json["content"] = ""; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
if (!msg.name().empty()) { |
|
|
msg_json["name"] = msg.name(); |
|
|
} |
|
|
if (!msg.tool_call_id().empty()) { |
|
|
msg_json["tool_call_id"] = msg.tool_call_id(); |
|
|
} |
|
|
if (!msg.reasoning_content().empty()) { |
|
|
msg_json["reasoning_content"] = msg.reasoning_content(); |
|
|
} |
|
|
if (!msg.tool_calls().empty()) { |
|
|
|
|
|
try { |
|
|
json tool_calls = json::parse(msg.tool_calls()); |
|
|
msg_json["tool_calls"] = tool_calls; |
|
|
SRV_INF("[TOOL CALLS DEBUG] Predict: Message %d has tool_calls: %s\n", i, tool_calls.dump().c_str()); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (!msg_json.contains("content") || (msg_json.contains("content") && msg_json["content"].is_string() && msg_json["content"].get<std::string>().empty())) { |
|
|
SRV_INF("[CONTENT DEBUG] Predict: Message %d has tool_calls but empty content, setting to space\n", i); |
|
|
msg_json["content"] = " "; |
|
|
} |
|
|
|
|
|
if (tool_calls.is_array()) { |
|
|
for (size_t tc_idx = 0; tc_idx < tool_calls.size(); tc_idx++) { |
|
|
const auto& tc = tool_calls[tc_idx]; |
|
|
std::string tool_name = "unknown"; |
|
|
std::string tool_args = "{}"; |
|
|
if (tc.contains("function")) { |
|
|
const auto& func = tc["function"]; |
|
|
if (func.contains("name")) { |
|
|
tool_name = func["name"].get<std::string>(); |
|
|
} |
|
|
if (func.contains("arguments")) { |
|
|
tool_args = func["arguments"].is_string() ? |
|
|
func["arguments"].get<std::string>() : |
|
|
func["arguments"].dump(); |
|
|
} |
|
|
} else if (tc.contains("name")) { |
|
|
tool_name = tc["name"].get<std::string>(); |
|
|
if (tc.contains("arguments")) { |
|
|
tool_args = tc["arguments"].is_string() ? |
|
|
tc["arguments"].get<std::string>() : |
|
|
tc["arguments"].dump(); |
|
|
} |
|
|
} |
|
|
SRV_INF("[TOOL CALLS DEBUG] Predict: Message %d, tool_call %zu: name=%s, arguments=%s\n", |
|
|
i, tc_idx, tool_name.c_str(), tool_args.c_str()); |
|
|
} |
|
|
} |
|
|
} catch (const json::parse_error& e) { |
|
|
SRV_WRN("Failed to parse tool_calls JSON: %s\n", e.what()); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
if (msg_json.contains("content")) { |
|
|
if (msg_json["content"].is_null()) { |
|
|
SRV_INF("[CONTENT DEBUG] Predict: Message %d FINAL STATE: content is NULL - THIS WILL CAUSE ERROR!\n", i); |
|
|
} else { |
|
|
SRV_INF("[CONTENT DEBUG] Predict: Message %d FINAL STATE: content type=%s, has_value=%d\n", |
|
|
i, msg_json["content"].is_string() ? "string" : |
|
|
msg_json["content"].is_array() ? "array" : |
|
|
msg_json["content"].is_object() ? "object" : "other", |
|
|
msg_json["content"].is_null() ? 0 : 1); |
|
|
} |
|
|
} else { |
|
|
SRV_INF("[CONTENT DEBUG] Predict: Message %d FINAL STATE: NO CONTENT FIELD - THIS WILL CAUSE ERROR!\n", i); |
|
|
} |
|
|
|
|
|
messages_json.push_back(msg_json); |
|
|
} |
|
|
|
|
|
|
|
|
SRV_INF("[CONTENT DEBUG] Predict: Running final safety check on %zu messages\n", messages_json.size()); |
|
|
for (size_t idx = 0; idx < messages_json.size(); idx++) { |
|
|
auto& msg = messages_json[idx]; |
|
|
std::string role_str = msg.contains("role") ? msg["role"].get<std::string>() : "unknown"; |
|
|
if (msg.contains("content") && msg["content"].is_null()) { |
|
|
SRV_INF("[CONTENT DEBUG] Predict: Safety check found message %zu (role=%s) with NULL content, converting to empty string\n", idx, role_str.c_str()); |
|
|
msg["content"] = ""; |
|
|
} else if (!msg.contains("content")) { |
|
|
SRV_INF("[CONTENT DEBUG] Predict: Safety check found message %zu (role=%s) without content field, adding empty string\n", idx, role_str.c_str()); |
|
|
msg["content"] = ""; |
|
|
} else { |
|
|
SRV_INF("[CONTENT DEBUG] Predict: Safety check message %zu (role=%s): content OK, type=%s\n", |
|
|
idx, role_str.c_str(), |
|
|
msg["content"].is_string() ? "string" : |
|
|
msg["content"].is_array() ? "array" : |
|
|
msg["content"].is_object() ? "object" : "other"); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
int tool_msg_count = 0; |
|
|
for (const auto& msg : messages_json) { |
|
|
if (msg.contains("role") && msg["role"] == "tool") { |
|
|
tool_msg_count++; |
|
|
} |
|
|
} |
|
|
SRV_DBG("[TOOLS DEBUG] Predict: Built %d tool messages out of %zu total messages\n", tool_msg_count, messages_json.size()); |
|
|
|
|
|
|
|
|
SRV_DBG("[CONVERSATION DEBUG] Predict: Full messages array:\n%s\n", messages_json.dump(2).c_str()); |
|
|
|
|
|
body_json["messages"] = messages_json; |
|
|
body_json["stream"] = false; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bool has_grammar_from_go = data.contains("grammar") && |
|
|
data["grammar"].is_string() && |
|
|
!data["grammar"].get<std::string>().empty(); |
|
|
|
|
|
SRV_INF("[TOOLS DEBUG] Predict: has_grammar_from_go=%d, data.contains(\"tools\")=%d, data.contains(\"grammar\")=%d\n", |
|
|
has_grammar_from_go ? 1 : 0, |
|
|
data.contains("tools") ? 1 : 0, |
|
|
data.contains("grammar") ? 1 : 0); |
|
|
if (data.contains("grammar")) { |
|
|
SRV_INF("[TOOLS DEBUG] Predict: grammar type=%s, empty=%d\n", |
|
|
data["grammar"].is_string() ? "string" : "other", |
|
|
data["grammar"].is_string() && data["grammar"].get<std::string>().empty() ? 1 : 0); |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (!has_grammar_from_go) { |
|
|
|
|
|
if (data.contains("tools")) { |
|
|
body_json["tools"] = data["tools"]; |
|
|
std::string tools_str = data["tools"].dump(); |
|
|
SRV_INF("Using tools from data (NoGrammar=true): %s\n", tools_str.c_str()); |
|
|
|
|
|
if (data["tools"].is_array()) { |
|
|
SRV_INF("[TOOLS DEBUG] Predict: Passing %zu tools to oaicompat_chat_params_parse\n", data["tools"].size()); |
|
|
for (size_t t_idx = 0; t_idx < data["tools"].size(); t_idx++) { |
|
|
const auto& tool = data["tools"][t_idx]; |
|
|
std::string tool_name = "unknown"; |
|
|
std::string tool_desc = ""; |
|
|
if (tool.contains("function")) { |
|
|
const auto& func = tool["function"]; |
|
|
if (func.contains("name")) { |
|
|
tool_name = func["name"].get<std::string>(); |
|
|
} |
|
|
if (func.contains("description")) { |
|
|
tool_desc = func["description"].is_string() ? |
|
|
func["description"].get<std::string>() : ""; |
|
|
} |
|
|
} else if (tool.contains("name")) { |
|
|
tool_name = tool["name"].get<std::string>(); |
|
|
if (tool.contains("description")) { |
|
|
tool_desc = tool["description"].is_string() ? |
|
|
tool["description"].get<std::string>() : ""; |
|
|
} |
|
|
} |
|
|
SRV_INF("[TOOLS DEBUG] Predict: Tool %zu: name=%s, description=%s\n", |
|
|
t_idx, tool_name.c_str(), tool_desc.substr(0, 100).c_str()); |
|
|
} |
|
|
} |
|
|
} else { |
|
|
SRV_WRN("%s", "No tools found in data - tool calls will not work without tools field\n"); |
|
|
SRV_DBG("[TOOLS DEBUG] Predict: No tools in data, tool_choice=%s\n", data.contains("tool_choice") ? data["tool_choice"].dump().c_str() : "not set"); |
|
|
} |
|
|
if (data.contains("tool_choice")) { |
|
|
|
|
|
|
|
|
if (data["tool_choice"].is_string()) { |
|
|
body_json["tool_choice"] = data["tool_choice"].get<std::string>(); |
|
|
} else if (data["tool_choice"].is_object()) { |
|
|
|
|
|
body_json["tool_choice"] = "required"; |
|
|
std::string tool_choice_obj_str = data["tool_choice"].dump(); |
|
|
SRV_INF("Converted object tool_choice to 'required': %s\n", tool_choice_obj_str.c_str()); |
|
|
} else { |
|
|
|
|
|
body_json["tool_choice"] = data["tool_choice"].dump(); |
|
|
} |
|
|
std::string tool_choice_str = body_json["tool_choice"].get<std::string>(); |
|
|
SRV_INF("Using tool_choice: %s\n", tool_choice_str.c_str()); |
|
|
} else { |
|
|
|
|
|
body_json["tool_choice"] = "auto"; |
|
|
} |
|
|
} else { |
|
|
|
|
|
SRV_INF("%s", "Grammar provided from Go layer - using it instead of template-generated grammar\n"); |
|
|
|
|
|
} |
|
|
|
|
|
if (data.contains("json_schema")) { |
|
|
body_json["json_schema"] = data["json_schema"]; |
|
|
} |
|
|
|
|
|
|
|
|
if (has_grammar_from_go) { |
|
|
body_json["grammar"] = data["grammar"]; |
|
|
} |
|
|
if (data.contains("response_format")) { |
|
|
body_json["response_format"] = data["response_format"]; |
|
|
} |
|
|
if (data.contains("chat_template_kwargs")) { |
|
|
body_json["chat_template_kwargs"] = data["chat_template_kwargs"]; |
|
|
} |
|
|
|
|
|
if (data.contains("parallel_tool_calls")) { |
|
|
body_json["parallel_tool_calls"] = data["parallel_tool_calls"]; |
|
|
} |
|
|
|
|
|
if (data.contains("add_generation_prompt")) { |
|
|
body_json["add_generation_prompt"] = data["add_generation_prompt"]; |
|
|
} |
|
|
|
|
|
|
|
|
SRV_DBG("[CONVERSATION DEBUG] Predict: Full body_json before oaicompat_chat_params_parse:\n%s\n", body_json.dump(2).c_str()); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
oaicompat_parser_options parser_opt = ctx_server.impl->oai_parser_opt; |
|
|
parser_opt.tmpls = ctx_server.impl->chat_templates.get(); |
|
|
|
|
|
parser_opt.allow_image = ctx_server.impl->mctx ? mtmd_support_vision(ctx_server.impl->mctx) : false; |
|
|
parser_opt.allow_audio = ctx_server.impl->mctx ? mtmd_support_audio(ctx_server.impl->mctx) : false; |
|
|
|
|
|
|
|
|
if (body_json.contains("tools")) { |
|
|
SRV_DBG("[TOOLS DEBUG] Predict: Before oaicompat_chat_params_parse - tools count: %zu\n", |
|
|
body_json["tools"].is_array() ? body_json["tools"].size() : 0); |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
if (body_json.contains("messages") && body_json["messages"].is_array()) { |
|
|
SRV_INF("[CONTENT DEBUG] Predict: Before oaicompat_chat_params_parse - checking %zu messages\n", body_json["messages"].size()); |
|
|
for (size_t idx = 0; idx < body_json["messages"].size(); idx++) { |
|
|
auto& msg = body_json["messages"][idx]; |
|
|
std::string role_str = msg.contains("role") ? msg["role"].get<std::string>() : "unknown"; |
|
|
if (msg.contains("content")) { |
|
|
if (msg["content"].is_null()) { |
|
|
SRV_INF("[CONTENT DEBUG] Predict: BEFORE TEMPLATE - Message %zu (role=%s) has NULL content - FIXING!\n", idx, role_str.c_str()); |
|
|
msg["content"] = ""; |
|
|
} else if (role_str == "tool" && msg["content"].is_array()) { |
|
|
|
|
|
|
|
|
SRV_INF("[CONTENT DEBUG] Predict: BEFORE TEMPLATE - Message %zu (role=tool) has array content, converting to string\n", idx); |
|
|
msg["content"] = msg["content"].dump(); |
|
|
} else if (!msg["content"].is_string() && !msg["content"].is_array()) { |
|
|
|
|
|
SRV_INF("[CONTENT DEBUG] Predict: BEFORE TEMPLATE - Message %zu (role=%s) content is not string/array, converting\n", idx, role_str.c_str()); |
|
|
if (msg["content"].is_object()) { |
|
|
msg["content"] = msg["content"].dump(); |
|
|
} else { |
|
|
msg["content"] = ""; |
|
|
} |
|
|
} else { |
|
|
SRV_INF("[CONTENT DEBUG] Predict: BEFORE TEMPLATE - Message %zu (role=%s): content type=%s\n", |
|
|
idx, role_str.c_str(), |
|
|
msg["content"].is_string() ? "string" : |
|
|
msg["content"].is_array() ? "array" : |
|
|
msg["content"].is_object() ? "object" : "other"); |
|
|
} |
|
|
} else { |
|
|
SRV_INF("[CONTENT DEBUG] Predict: BEFORE TEMPLATE - Message %zu (role=%s) MISSING content field - ADDING!\n", idx, role_str.c_str()); |
|
|
msg["content"] = ""; |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
json parsed_data = oaicompat_chat_params_parse(body_json, parser_opt, files); |
|
|
|
|
|
|
|
|
if (parsed_data.contains("tools")) { |
|
|
SRV_DBG("[TOOLS DEBUG] Predict: After oaicompat_chat_params_parse - tools count: %zu\n", |
|
|
parsed_data["tools"].is_array() ? parsed_data["tools"].size() : 0); |
|
|
} else { |
|
|
SRV_DBG("%s", "[TOOLS DEBUG] Predict: After oaicompat_chat_params_parse - no tools in parsed_data\n"); |
|
|
} |
|
|
|
|
|
|
|
|
prompt_str = parsed_data.at("prompt").get<std::string>(); |
|
|
|
|
|
|
|
|
|
|
|
json preserved_grammar; |
|
|
if (has_grammar_from_go && data.contains("grammar")) { |
|
|
preserved_grammar = data["grammar"]; |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for (const auto& item : parsed_data.items()) { |
|
|
if (item.key() != "prompt") { |
|
|
|
|
|
if (item.key() == "grammar" && has_grammar_from_go && !preserved_grammar.is_null()) { |
|
|
data["grammar"] = preserved_grammar; |
|
|
} else { |
|
|
data[item.key()] = item.value(); |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
if (data.contains("parse_tool_calls")) { |
|
|
SRV_DBG("[TOOLS DEBUG] Predict: parse_tool_calls=%s\n", data["parse_tool_calls"].get<bool>() ? "true" : "false"); |
|
|
} |
|
|
} else { |
|
|
|
|
|
if (data.contains("prompt") && data["prompt"].is_string()) { |
|
|
prompt_str = data["prompt"].get<std::string>(); |
|
|
} else { |
|
|
prompt_str = request->prompt(); |
|
|
} |
|
|
} |
|
|
|
|
|
const auto type = SERVER_TASK_TYPE_COMPLETION; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (!request->usetokenizertemplate() || request->messages_size() == 0 || ctx_server.impl->chat_templates == nullptr) { |
|
|
const auto &images_data = data.find("image_data"); |
|
|
if (images_data != data.end() && images_data->is_array()) |
|
|
{ |
|
|
std::cout << "[PREDICT] Processing " << images_data->size() << " images" << std::endl; |
|
|
for (const auto &img : *images_data) |
|
|
{ |
|
|
std::cout << "[PREDICT] Processing image" << std::endl; |
|
|
auto decoded_data = base64_decode(img["data"].get<std::string>()); |
|
|
files.push_back(decoded_data); |
|
|
} |
|
|
} |
|
|
|
|
|
const auto &audio_data = data.find("audio_data"); |
|
|
if (audio_data != data.end() && audio_data->is_array()) |
|
|
{ |
|
|
for (const auto &audio : *audio_data) |
|
|
{ |
|
|
auto decoded_data = base64_decode(audio["data"].get<std::string>()); |
|
|
files.push_back(decoded_data); |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
const bool has_mtmd = ctx_server.impl->mctx != nullptr; |
|
|
|
|
|
|
|
|
std::vector<server_tokens> inputs; |
|
|
if (has_mtmd) { |
|
|
|
|
|
inputs.push_back(process_mtmd_prompt(ctx_server.impl->mctx, prompt_str, files)); |
|
|
} else { |
|
|
|
|
|
inputs = tokenize_input_prompts(ctx_server.impl->vocab, ctx_server.impl->mctx, prompt_str, true, true); |
|
|
} |
|
|
|
|
|
tasks.reserve(inputs.size()); |
|
|
for (size_t i = 0; i < inputs.size(); i++) { |
|
|
server_task task = server_task(type); |
|
|
|
|
|
task.id = rd.queue_tasks.get_new_id(); |
|
|
task.index = i; |
|
|
|
|
|
task.tokens = std::move(inputs[i]); |
|
|
task.params = server_task::params_from_json_cmpl( |
|
|
ctx_server.impl->vocab, |
|
|
params_base, |
|
|
ctx_server.get_meta().slot_n_ctx, |
|
|
data); |
|
|
task.id_slot = json_value(data, "id_slot", -1); |
|
|
|
|
|
|
|
|
task.params.res_type = TASK_RESPONSE_TYPE_NONE; |
|
|
task.params.oaicompat_cmpl_id = completion_id; |
|
|
|
|
|
|
|
|
tasks.push_back(std::move(task)); |
|
|
} |
|
|
|
|
|
rd.post_tasks(std::move(tasks)); |
|
|
} catch (const std::exception & e) { |
|
|
return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, e.what()); |
|
|
} |
|
|
|
|
|
|
|
|
std::cout << "[DEBUG] Waiting for results..." << std::endl; |
|
|
|
|
|
|
|
|
auto all_results = rd.wait_for_all([&context]() { return context->IsCancelled(); }); |
|
|
|
|
|
if (all_results.is_terminated) { |
|
|
return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client"); |
|
|
} else if (all_results.error) { |
|
|
std::cout << "[DEBUG] Error in results: " << all_results.error->to_json().value("message", "") << std::endl; |
|
|
reply->set_message(all_results.error->to_json().value("message", "")); |
|
|
return grpc::Status(grpc::StatusCode::INTERNAL, all_results.error->to_json().value("message", "Error occurred")); |
|
|
} else { |
|
|
std::cout << "[DEBUG] Received " << all_results.results.size() << " results" << std::endl; |
|
|
if (all_results.results.size() == 1) { |
|
|
|
|
|
GGML_ASSERT(dynamic_cast<server_task_result_cmpl_final*>(all_results.results[0].get()) != nullptr); |
|
|
json result_json = all_results.results[0]->to_json(); |
|
|
reply->set_message(result_json.value("content", "")); |
|
|
|
|
|
int32_t tokens_predicted = result_json.value("tokens_predicted", 0); |
|
|
reply->set_tokens(tokens_predicted); |
|
|
int32_t tokens_evaluated = result_json.value("tokens_evaluated", 0); |
|
|
reply->set_prompt_tokens(tokens_evaluated); |
|
|
|
|
|
if (result_json.contains("timings")) { |
|
|
double timing_prompt_processing = result_json.at("timings").value("prompt_ms", 0.0); |
|
|
reply->set_timing_prompt_processing(timing_prompt_processing); |
|
|
double timing_token_generation = result_json.at("timings").value("predicted_ms", 0.0); |
|
|
reply->set_timing_token_generation(timing_token_generation); |
|
|
} |
|
|
|
|
|
|
|
|
json logprobs_json = extract_logprobs_from_json(result_json); |
|
|
if (!logprobs_json.empty() && !logprobs_json.is_null()) { |
|
|
std::string logprobs_str = logprobs_json.dump(); |
|
|
reply->set_logprobs(logprobs_str); |
|
|
} |
|
|
|
|
|
} else { |
|
|
|
|
|
json arr = json::array(); |
|
|
json logprobs_arr = json::array(); |
|
|
bool has_logprobs = false; |
|
|
for (auto & res : all_results.results) { |
|
|
GGML_ASSERT(dynamic_cast<server_task_result_cmpl_final*>(res.get()) != nullptr); |
|
|
json res_json = res->to_json(); |
|
|
arr.push_back(res_json.value("content", "")); |
|
|
|
|
|
|
|
|
json logprobs_json = extract_logprobs_from_json(res_json); |
|
|
if (!logprobs_json.empty() && !logprobs_json.is_null()) { |
|
|
has_logprobs = true; |
|
|
logprobs_arr.push_back(logprobs_json); |
|
|
} else { |
|
|
logprobs_arr.push_back(json::object()); |
|
|
} |
|
|
} |
|
|
reply->set_message(arr); |
|
|
|
|
|
|
|
|
if (has_logprobs) { |
|
|
std::string logprobs_str = logprobs_arr.dump(); |
|
|
reply->set_logprobs(logprobs_str); |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
std::cout << "[DEBUG] Predict request completed successfully" << std::endl; |
|
|
|
|
|
|
|
|
if (context->IsCancelled()) { |
|
|
return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client"); |
|
|
} |
|
|
|
|
|
return grpc::Status::OK; |
|
|
} |
|
|
|
|
|
grpc::Status Embedding(ServerContext* context, const backend::PredictOptions* request, backend::EmbeddingResult* embeddingResult) override {
    // Computes embeddings for the input(s) carried by the request and
    // flattens every resulting float vector into the reply's repeated
    // `embeddings` field.
    if (params_base.model.path.empty()) {
        return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
    }

    json body = parse_options(false, request, params_base, ctx_server.get_llama_context());
    body["stream"] = false;

    std::vector<server_tokens> tokenized_prompts;
    try {
        // parse_options() is expected to place the embedding input under the
        // "embeddings" key; body.at() throws json::out_of_range when the key
        // is missing and tokenization may throw on malformed input. Surface
        // both as INVALID_ARGUMENT instead of letting an exception escape the
        // gRPC handler (Predict guards its setup the same way).
        json prompt = body.at("embeddings");
        tokenized_prompts = tokenize_input_prompts(ctx_server.impl->vocab, ctx_server.impl->mctx, prompt, true, true);
    } catch (const std::exception & e) {
        return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, e.what());
    }

    for (const auto & tokens : tokenized_prompts) {
        // Empty input after tokenization cannot be embedded.
        if (tokens.empty()) {
            return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "Input content cannot be empty");
        }
    }

    int embd_normalize = 2; // normalization mode passed to the embedding task -- presumably L2; TODO confirm against embd_normalize convention

    auto rd = ctx_server.get_response_reader();
    {
        // Post one embedding task per tokenized prompt.
        std::vector<server_task> tasks;
        tasks.reserve(tokenized_prompts.size());
        for (size_t i = 0; i < tokenized_prompts.size(); i++) {
            server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);
            task.id = rd.queue_tasks.get_new_id();
            task.index = i;
            task.tokens = std::move(tokenized_prompts[i]);
            task.params.res_type = TASK_RESPONSE_TYPE_NONE;
            task.params.embd_normalize = embd_normalize;
            tasks.push_back(std::move(task));
        }
        rd.post_tasks(std::move(tasks));
    }

    // Block until every task finishes, aborting early if the client cancels.
    auto all_results = rd.wait_for_all([&context]() { return context->IsCancelled(); });

    if (all_results.is_terminated) {
        return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
    } else if (all_results.error) {
        return grpc::Status(grpc::StatusCode::INTERNAL, all_results.error->to_json().value("message", "Error in receiving results"));
    }

    json responses = json::array();
    for (auto & res : all_results.results) {
        GGML_ASSERT(dynamic_cast<server_task_result_embd*>(res.get()) != nullptr);
        responses.push_back(res->to_json());
    }

    std::cout << "[DEBUG] Responses size: " << responses.size() << std::endl;

    // Flatten results: each response either carries an "embedding" key whose
    // value is an array of float vectors, or is itself a flat float array.
    for (const auto & response_elem : responses) {
        if (response_elem.contains("embedding")) {
            const json embedding_data = json_value(response_elem, "embedding", json::array());
            if (embedding_data.is_array() && !embedding_data.empty()) {
                for (const auto & embedding_vector : embedding_data) {
                    if (embedding_vector.is_array()) {
                        for (const auto & embedding_value : embedding_vector) {
                            embeddingResult->add_embeddings(embedding_value.get<float>());
                        }
                    }
                }
            }
        } else if (response_elem.is_array()) {
            for (const auto & embedding_value : response_elem) {
                embeddingResult->add_embeddings(embedding_value.get<float>());
            }
        }
    }

    return grpc::Status::OK;
}
|
|
|
|
|
// Reranks the provided documents against the query and returns them ordered
// by relevance score (highest first), optionally truncated to top_n.
grpc::Status Rerank(ServerContext* context, const backend::RerankRequest* request, backend::RerankResult* rerankResult) override {
    if (!params_base.embedding || params_base.pooling_type != LLAMA_POOLING_TYPE_RANK) {
        return grpc::Status(grpc::StatusCode::UNIMPLEMENTED, "This server does not support reranking. Start it with `--reranking` and without `--embedding`");
    }

    // Basic request validation.
    if (request->query().empty()) {
        return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "\"query\" must be provided");
    }
    if (request->documents_size() == 0) {
        return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "\"documents\" must be a non-empty string array");
    }

    auto reader = ctx_server.get_response_reader();
    {
        // One rerank task per (query, document) pair.
        std::vector<server_task> pending;
        pending.reserve(request->documents_size());
        for (int doc_idx = 0; doc_idx < request->documents_size(); doc_idx++) {
            auto prompt_tokens = format_prompt_rerank(ctx_server.impl->model, ctx_server.impl->vocab, ctx_server.impl->mctx, request->query(), request->documents(doc_idx));
            server_task task = server_task(SERVER_TASK_TYPE_RERANK);
            task.id = reader.queue_tasks.get_new_id();
            task.index = doc_idx;
            task.tokens = std::move(prompt_tokens);
            pending.push_back(std::move(task));
        }
        reader.post_tasks(std::move(pending));
    }

    // Block until all rerank tasks finish, aborting early on client cancel.
    auto outcome = reader.wait_for_all([&context]() { return context->IsCancelled(); });

    if (outcome.is_terminated) {
        return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client");
    }
    if (outcome.error) {
        return grpc::Status(grpc::StatusCode::INTERNAL, outcome.error->to_json().value("message", "Error in receiving results"));
    }

    json scored = json::array();
    for (auto & res : outcome.results) {
        GGML_ASSERT(dynamic_cast<server_task_result_rerank*>(res.get()) != nullptr);
        scored.push_back(res->to_json());
    }

    // Highest score first.
    std::sort(scored.begin(), scored.end(), [](const json& lhs, const json& rhs) {
        return lhs.value("score", 0.0f) > rhs.value("score", 0.0f);
    });

    // Keep only the top_n entries when the client asked for a cutoff.
    const int top_n = request->top_n();
    if (top_n > 0 && top_n < static_cast<int>(scored.size())) {
        scored.erase(scored.begin() + top_n, scored.end());
    }

    backend::Usage* usage = rerankResult->mutable_usage();
    int total_tokens = 0;
    int prompt_tokens = 0;

    // Copy each scored entry into the protobuf reply and accumulate usage.
    for (const auto& entry : scored) {
        backend::DocumentResult* doc_result = rerankResult->add_results();
        const int original_index = entry.value("index", 0);
        doc_result->set_index(original_index);
        doc_result->set_text(request->documents(original_index));
        doc_result->set_relevance_score(entry.value("score", 0.0f));

        const int tokens_evaluated = entry.value("tokens_evaluated", 0);
        total_tokens += tokens_evaluated;
        prompt_tokens += tokens_evaluated;
    }

    usage->set_total_tokens(total_tokens);
    usage->set_prompt_tokens(prompt_tokens);

    return grpc::Status::OK;
}
|
|
|
|
|
grpc::Status TokenizeString(ServerContext* , const backend::PredictOptions* request, backend::TokenizationResponse* response) override {
    // Tokenizes the prompt carried by the request and returns the token ids.
    if (params_base.model.path.empty()) {
        return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded");
    }
    json body = parse_options(false, request, params_base, ctx_server.get_llama_context());
    body["stream"] = false;

    if (body.count("prompt") != 0) {
        const bool add_special = json_value(body, "add_special", false);

        // BUG FIX: the guard above checks for "prompt" but the old code read
        // body.at("content"), which throws json::out_of_range whenever
        // "content" is absent — read the key that was actually checked.
        llama_tokens tokens = tokenize_mixed(ctx_server.impl->vocab, body.at("prompt"), add_special, true);

        for (const auto& token : tokens) {
            // Only the raw token ids are returned; no text pieces are needed.
            response->add_tokens(token);
        }
    }

    return grpc::Status::OK;
}
|
|
|
|
|
// Collects server-wide metrics via a SERVER_TASK_TYPE_METRICS task and
// copies the relevant counters into the gRPC response.
grpc::Status GetMetrics(ServerContext* , const backend::MetricsRequest* , backend::MetricsResponse* response) override {
    auto reader = ctx_server.get_response_reader();
    const int metrics_task_id = reader.queue_tasks.get_new_id();
    {
        // Register as a waiter first, then post the task with priority.
        server_task task(SERVER_TASK_TYPE_METRICS);
        task.id = metrics_task_id;
        reader.queue_results.add_waiting_task_id(metrics_task_id);
        reader.queue_tasks.post(std::move(task), true);
    }

    // Block until the metrics task result arrives.
    server_task_result_ptr result = reader.queue_results.recv(metrics_task_id);
    reader.queue_results.remove_waiting_task_id(metrics_task_id);

    if (result->is_error()) {
        // Zero out every field before reporting the failure.
        response->set_slot_id(0);
        response->set_prompt_json_for_slot("");
        response->set_tokens_per_second(0);
        response->set_tokens_generated(0);
        response->set_prompt_tokens_processed(0);
        return grpc::Status(grpc::StatusCode::INTERNAL, "Error in receiving results");
    }

    auto res_metrics = dynamic_cast<server_task_result_metrics*>(result.get());
    GGML_ASSERT(res_metrics != nullptr);

    // Per-slot details are not tracked here; report aggregate throughput.
    // tokens_per_second is derived from prompt-processing time (guarded
    // against a zero token count to avoid a meaningless division).
    response->set_slot_id(0);
    response->set_prompt_json_for_slot("");
    response->set_tokens_per_second(res_metrics->n_prompt_tokens_processed ? 1.e3 / res_metrics->t_prompt_processing * res_metrics->n_prompt_tokens_processed : 0.);
    response->set_tokens_generated(res_metrics->n_tokens_predicted_total);
    response->set_prompt_tokens_processed(res_metrics->n_prompt_tokens_processed_total);

    return grpc::Status::OK;
}
|
|
}; |
|
|
|
|
|
|
|
|
int main(int argc, char** argv) { |
|
|
std::string server_address("localhost:50051"); |
|
|
|
|
|
|
|
|
struct option long_options[] = { |
|
|
{"addr", required_argument, nullptr, 'a'}, |
|
|
{nullptr, 0, nullptr, 0} |
|
|
}; |
|
|
|
|
|
|
|
|
int option; |
|
|
int option_index = 0; |
|
|
while ((option = getopt_long(argc, argv, "a:", long_options, &option_index)) != -1) { |
|
|
switch (option) { |
|
|
case 'a': |
|
|
server_address = optarg; |
|
|
break; |
|
|
default: |
|
|
std::cerr << "Usage: " << argv[0] << " [--addr=<address>] or [-a <address>]" << std::endl; |
|
|
return 1; |
|
|
} |
|
|
} |
|
|
|
|
|
server_context ctx_server; |
|
|
BackendServiceImpl service(ctx_server); |
|
|
|
|
|
ServerBuilder builder; |
|
|
builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); |
|
|
builder.RegisterService(&service); |
|
|
builder.SetMaxMessageSize(50 * 1024 * 1024); |
|
|
builder.SetMaxSendMessageSize(50 * 1024 * 1024); |
|
|
builder.SetMaxReceiveMessageSize(50 * 1024 * 1024); |
|
|
std::unique_ptr<Server> server(builder.BuildAndStart()); |
|
|
|
|
|
std::thread t([&]() |
|
|
{ |
|
|
std::cout << "Server listening on " << server_address << std::endl; |
|
|
server->Wait(); |
|
|
return 0; |
|
|
}); |
|
|
|
|
|
|
|
|
auto clean_up = [&server, &ctx_server]() { |
|
|
SRV_INF("%s: cleaning up before exit...\n", __func__); |
|
|
server->Shutdown(); |
|
|
ctx_server.terminate(); |
|
|
llama_backend_free(); |
|
|
}; |
|
|
|
|
|
|
|
|
|
|
|
start_llama_server(ctx_server); |
|
|
std::cout << "stopping" << std::endl; |
|
|
|
|
|
|
|
|
clean_up(); |
|
|
t.join(); |
|
|
|
|
|
return 0; |
|
|
} |
|
|
|