| | #include "arg.h"
|
| |
|
| | #include "common.h"
|
| | #include "gguf.h"
|
| | #include "log.h"
|
| | #include "download.h"
|
| |
|
| | #define JSON_ASSERT GGML_ASSERT
|
| | #include <nlohmann/json.hpp>
|
| |
|
#include <algorithm>
#include <filesystem>
#include <fstream>
#include <future>
#include <iomanip>
#include <iostream>
#include <map>
#include <mutex>
#include <regex>
#include <string>
#include <system_error>
#include <thread>
#include <vector>
|
| |
|
| | #include "http.h"
|
| |
|
| | #ifndef __EMSCRIPTEN__
|
| | #ifdef __linux__
|
| | #include <linux/limits.h>
|
| | #elif defined(_WIN32)
|
| | # if !defined(PATH_MAX)
|
| | # define PATH_MAX MAX_PATH
|
| | # endif
|
| | #elif defined(_AIX)
|
| | #include <sys/limits.h>
|
| | #else
|
| | #include <sys/syslimits.h>
|
| | #endif
|
| | #endif
|
| |
|
| | #define LLAMA_MAX_URL_LENGTH 2084
|
| |
|
| |
|
| | #if defined(_WIN32)
|
| | #include <io.h>
|
| | #else
|
| | #include <unistd.h>
|
| | #endif
|
| |
|
| | using json = nlohmann::ordered_json;
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
// Returns true when `repo` has the shape "<owner>/<name>": exactly one '/',
// with both sides non-empty and made only of ASCII letters, digits, '_', '.'
// and '-'.
static bool validate_repo_name(const std::string & repo) {
    // compiled on first use and shared by all later calls
    static const std::regex owner_slash_name(R"(^[A-Za-z0-9_.\-]+\/[A-Za-z0-9_.\-]+$)");

    const bool is_valid = std::regex_match(repo.begin(), repo.end(), owner_slash_name);
    return is_valid;
}
|
| |
|
| | static std::string get_manifest_path(const std::string & repo, const std::string & tag) {
|
| |
|
| | std::string fname = "manifest=" + repo + "=" + tag + ".json";
|
| | if (!validate_repo_name(repo)) {
|
| | throw std::runtime_error("error: repo name must be in the format 'owner/repo'");
|
| | }
|
| | string_replace_all(fname, "/", "=");
|
| | return fs_get_cache_file(fname);
|
| | }
|
| |
|
| | static std::string read_file(const std::string & fname) {
|
| | std::ifstream file(fname);
|
| | if (!file) {
|
| | throw std::runtime_error(string_format("error: failed to open file '%s'\n", fname.c_str()));
|
| | }
|
| | std::string content((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
|
| | file.close();
|
| | return content;
|
| | }
|
| |
|
| | static void write_file(const std::string & fname, const std::string & content) {
|
| | const std::string fname_tmp = fname + ".tmp";
|
| | std::ofstream file(fname_tmp);
|
| | if (!file) {
|
| | throw std::runtime_error(string_format("error: failed to open file '%s'\n", fname.c_str()));
|
| | }
|
| |
|
| | try {
|
| | file << content;
|
| | file.close();
|
| |
|
| |
|
| | if (rename(fname_tmp.c_str(), fname.c_str()) != 0) {
|
| | LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, fname_tmp.c_str(), fname.c_str());
|
| |
|
| | if (remove(fname_tmp.c_str()) != 0) {
|
| | LOG_ERR("%s: unable to delete temporary file: %s\n", __func__, fname_tmp.c_str());
|
| | }
|
| | }
|
| | } catch (...) {
|
| |
|
| | if (remove(fname_tmp.c_str()) != 0) {
|
| | LOG_ERR("%s: unable to delete temporary file: %s\n", __func__, fname_tmp.c_str());
|
| | }
|
| |
|
| | throw std::runtime_error(string_format("error: failed to write file '%s'\n", fname.c_str()));
|
| | }
|
| | }
|
| |
|
| | static void write_etag(const std::string & path, const std::string & etag) {
|
| | const std::string etag_path = path + ".etag";
|
| | write_file(etag_path, etag);
|
| | LOG_DBG("%s: file etag saved: %s\n", __func__, etag_path.c_str());
|
| | }
|
| |
|
| | static std::string read_etag(const std::string & path) {
|
| | const std::string etag_path = path + ".etag";
|
| | if (!std::filesystem::exists(etag_path)) {
|
| | return {};
|
| | }
|
| | std::ifstream etag_in(etag_path);
|
| | if (!etag_in) {
|
| | LOG_ERR("%s: could not open .etag file for reading: %s\n", __func__, etag_path.c_str());
|
| | return {};
|
| | }
|
| | std::string etag;
|
| | std::getline(etag_in, etag);
|
| | return etag;
|
| | }
|
| |
|
// Returns true for status codes in the range [200, 400).
// The negative codes used in this file to signal failure are rejected.
static bool is_http_status_ok(int status) {
    if (status < 200) {
        return false;
    }
    return status <= 399;
}
|
| |
|
| | std::pair<std::string, std::string> common_download_split_repo_tag(const std::string & hf_repo_with_tag) {
|
| | auto parts = string_split<std::string>(hf_repo_with_tag, ':');
|
| | std::string tag = parts.size() > 1 ? parts.back() : "latest";
|
| | std::string hf_repo = parts[0];
|
| | if (string_split<std::string>(hf_repo, '/').size() != 2) {
|
| | throw std::invalid_argument("error: invalid HF repo format, expected <user>/<model>[:quant]\n");
|
| | }
|
| | return {hf_repo, tag};
|
| | }
|
| |
|
// Draws a download progress bar on stdout using ANSI escape sequences.
// Nothing is printed when stdout is not a terminal.
//
// All instances share one mutex and one table that records which line each bar
// was given the first time it drew, so several bars can be repainted on their
// own lines.
class ProgressBar {
    static inline std::mutex mutex;                         // guards `lines`, `max_line` and the output in update()
    static inline std::map<const ProgressBar *, int> lines; // line index assigned to each registered bar
    static inline int max_line = 0;                         // next line index to hand out

    // Unregisters `line`; the line counter restarts once no bar is registered.
    // Both call sites hold `mutex`.
    static void cleanup(const ProgressBar * line) {
        lines.erase(line);
        if (lines.empty()) {
            max_line = 0;
        }
    }

    static bool is_output_a_tty() {
#if defined(_WIN32)
        return _isatty(_fileno(stdout));
#else
        return isatty(1);
#endif
    }

public:
    ProgressBar() = default;

    ~ProgressBar() {
        std::lock_guard<std::mutex> lock(mutex);
        cleanup(this);
    }

    // Number of filled cells of a bar that is `width` cells wide.
    //
    // The result is clamped to `width`: update() pads the bar with
    // `width - filled` spaces, and that unsigned subtraction would wrap around
    // if `current` were allowed to exceed `total`. Returns 0 when `total` is 0.
    static size_t filled_width(size_t current, size_t total, size_t width) {
        if (total == 0) {
            return 0;
        }
        if (current >= total) {
            return width;
        }
        return (width * current) / total;
    }

    // Repaints this bar to show `current` out of `total` bytes.
    // Does nothing when stdout is not a terminal or when `total` is 0.
    void update(size_t current, size_t total) {
        if (!is_output_a_tty()) {
            return;
        }

        if (!total) {
            return;
        }

        std::lock_guard<std::mutex> lock(mutex);

        // first draw of this bar: assign it a line and open that line
        if (lines.find(this) == lines.end()) {
            lines[this] = max_line++;
            std::cout << "\n";
        }
        int lines_up = max_line - lines[this];

        const size_t width = 50;
        const size_t pct   = (100 * current) / total;
        const size_t pos   = filled_width(current, total, width);

        std::cout << "\033[s"; // save the cursor position

        if (lines_up > 0) {
            std::cout << "\033[" << lines_up << "A"; // move up to this bar's line
        }
        std::cout << "\033[2K\r[" // erase the line and return to column 0
                  << std::string(pos, '=')
                  << (pos < width ? ">" : "")
                  << std::string(width - pos, ' ')
                  << "] " << std::setw(3) << pct << "% ("
                  << current / (1024 * 1024) << " MB / "
                  << total / (1024 * 1024) << " MB) "
                  << "\033[u"; // restore the cursor position

        std::cout.flush();

        if (current == total) {
            cleanup(this);
        }
    }

    ProgressBar(const ProgressBar &) = delete;
    ProgressBar & operator=(const ProgressBar &) = delete;
};
|
| |
|
| | static bool common_pull_file(httplib::Client & cli,
|
| | const std::string & resolve_path,
|
| | const std::string & path_tmp,
|
| | bool supports_ranges,
|
| | size_t existing_size,
|
| | size_t & total_size) {
|
| | std::ofstream ofs(path_tmp, std::ios::binary | std::ios::app);
|
| | if (!ofs.is_open()) {
|
| | LOG_ERR("%s: error opening local file for writing: %s\n", __func__, path_tmp.c_str());
|
| | return false;
|
| | }
|
| |
|
| | httplib::Headers headers;
|
| | if (supports_ranges && existing_size > 0) {
|
| | headers.emplace("Range", "bytes=" + std::to_string(existing_size) + "-");
|
| | }
|
| |
|
| | const char * func = __func__;
|
| | size_t downloaded = existing_size;
|
| | size_t progress_step = 0;
|
| | ProgressBar bar;
|
| |
|
| | auto res = cli.Get(resolve_path, headers,
|
| | [&](const httplib::Response &response) {
|
| | if (existing_size > 0 && response.status != 206) {
|
| | LOG_WRN("%s: server did not respond with 206 Partial Content for a resume request. Status: %d\n", func, response.status);
|
| | return false;
|
| | }
|
| | if (existing_size == 0 && response.status != 200) {
|
| | LOG_WRN("%s: download received non-successful status code: %d\n", func, response.status);
|
| | return false;
|
| | }
|
| | if (total_size == 0 && response.has_header("Content-Length")) {
|
| | try {
|
| | size_t content_length = std::stoull(response.get_header_value("Content-Length"));
|
| | total_size = existing_size + content_length;
|
| | } catch (const std::exception &e) {
|
| | LOG_WRN("%s: invalid Content-Length header: %s\n", func, e.what());
|
| | }
|
| | }
|
| | return true;
|
| | },
|
| | [&](const char *data, size_t len) {
|
| | ofs.write(data, len);
|
| | if (!ofs) {
|
| | LOG_ERR("%s: error writing to file: %s\n", func, path_tmp.c_str());
|
| | return false;
|
| | }
|
| | downloaded += len;
|
| | progress_step += len;
|
| |
|
| | if (progress_step >= total_size / 1000 || downloaded == total_size) {
|
| | bar.update(downloaded, total_size);
|
| | progress_step = 0;
|
| | }
|
| | return true;
|
| | },
|
| | nullptr
|
| | );
|
| |
|
| | if (!res) {
|
| | LOG_ERR("%s: download failed: %s (status: %d)\n",
|
| | __func__,
|
| | httplib::to_string(res.error()).c_str(),
|
| | res ? res->status : -1);
|
| | return false;
|
| | }
|
| |
|
| | return true;
|
| | }
|
| |
|
| |
|
| |
|
| | static int common_download_file_single_online(const std::string & url,
|
| | const std::string & path,
|
| | const std::string & bearer_token,
|
| | const common_header_list & custom_headers) {
|
| | static const int max_attempts = 3;
|
| | static const int retry_delay_seconds = 2;
|
| |
|
| | auto [cli, parts] = common_http_client(url);
|
| |
|
| | httplib::Headers headers;
|
| | for (const auto & h : custom_headers) {
|
| | headers.emplace(h.first, h.second);
|
| | }
|
| | if (headers.find("User-Agent") == headers.end()) {
|
| | headers.emplace("User-Agent", "llama-cpp/" + build_info);
|
| | }
|
| | if (!bearer_token.empty()) {
|
| | headers.emplace("Authorization", "Bearer " + bearer_token);
|
| | }
|
| | cli.set_default_headers(headers);
|
| |
|
| | const bool file_exists = std::filesystem::exists(path);
|
| |
|
| | std::string last_etag;
|
| | if (file_exists) {
|
| | last_etag = read_etag(path);
|
| | } else {
|
| | LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str());
|
| | }
|
| |
|
| | auto head = cli.Head(parts.path);
|
| | if (!head || head->status < 200 || head->status >= 300) {
|
| | LOG_WRN("%s: HEAD failed, status: %d\n", __func__, head ? head->status : -1);
|
| | if (file_exists) {
|
| | LOG_INF("%s: using cached file (HEAD failed): %s\n", __func__, path.c_str());
|
| | return 304;
|
| | }
|
| | return head ? head->status : -1;
|
| | }
|
| |
|
| | std::string etag;
|
| | if (head->has_header("ETag")) {
|
| | etag = head->get_header_value("ETag");
|
| | }
|
| |
|
| | size_t total_size = 0;
|
| | if (head->has_header("Content-Length")) {
|
| | try {
|
| | total_size = std::stoull(head->get_header_value("Content-Length"));
|
| | } catch (const std::exception& e) {
|
| | LOG_WRN("%s: invalid Content-Length in HEAD response: %s\n", __func__, e.what());
|
| | }
|
| | }
|
| |
|
| | bool supports_ranges = false;
|
| | if (head->has_header("Accept-Ranges")) {
|
| | supports_ranges = head->get_header_value("Accept-Ranges") != "none";
|
| | }
|
| |
|
| | if (file_exists) {
|
| | if (etag.empty()) {
|
| | LOG_INF("%s: using cached file (no server etag): %s\n", __func__, path.c_str());
|
| | return 304;
|
| | }
|
| | if (!last_etag.empty() && last_etag == etag) {
|
| | LOG_INF("%s: using cached file (same etag): %s\n", __func__, path.c_str());
|
| | return 304;
|
| | }
|
| | if (remove(path.c_str()) != 0) {
|
| | LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str());
|
| | return -1;
|
| | }
|
| | }
|
| |
|
| | const std::string path_temporary = path + ".downloadInProgress";
|
| | int delay = retry_delay_seconds;
|
| |
|
| | for (int i = 0; i < max_attempts; ++i) {
|
| | if (i) {
|
| | LOG_WRN("%s: retrying after %d seconds...\n", __func__, delay);
|
| | std::this_thread::sleep_for(std::chrono::seconds(delay));
|
| | delay *= retry_delay_seconds;
|
| | }
|
| |
|
| | size_t existing_size = 0;
|
| |
|
| | if (std::filesystem::exists(path_temporary)) {
|
| | if (supports_ranges) {
|
| | existing_size = std::filesystem::file_size(path_temporary);
|
| | } else if (remove(path_temporary.c_str()) != 0) {
|
| | LOG_ERR("%s: unable to delete file: %s\n", __func__, path_temporary.c_str());
|
| | return -1;
|
| | }
|
| | }
|
| |
|
| | LOG_INF("%s: downloading from %s to %s (etag:%s)...\n",
|
| | __func__, common_http_show_masked_url(parts).c_str(),
|
| | path_temporary.c_str(), etag.c_str());
|
| |
|
| | if (common_pull_file(cli, parts.path, path_temporary, supports_ranges, existing_size, total_size)) {
|
| | if (std::rename(path_temporary.c_str(), path.c_str()) != 0) {
|
| | LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str());
|
| | return -1;
|
| | }
|
| | if (!etag.empty()) {
|
| | write_etag(path, etag);
|
| | }
|
| | return head->status;
|
| | }
|
| | }
|
| |
|
| | LOG_ERR("%s: download failed after %d attempts\n", __func__, max_attempts);
|
| | return -1;
|
| | }
|
| |
|
| | std::pair<long, std::vector<char>> common_remote_get_content(const std::string & url,
|
| | const common_remote_params & params) {
|
| | auto [cli, parts] = common_http_client(url);
|
| |
|
| | httplib::Headers headers;
|
| | for (const auto & h : params.headers) {
|
| | headers.emplace(h.first, h.second);
|
| | }
|
| | if (headers.find("User-Agent") == headers.end()) {
|
| | headers.emplace("User-Agent", "llama-cpp/" + build_info);
|
| | }
|
| |
|
| | if (params.timeout > 0) {
|
| | cli.set_read_timeout(params.timeout, 0);
|
| | cli.set_write_timeout(params.timeout, 0);
|
| | }
|
| |
|
| | std::vector<char> buf;
|
| | auto res = cli.Get(parts.path, headers,
|
| | [&](const char *data, size_t len) {
|
| | buf.insert(buf.end(), data, data + len);
|
| | return params.max_size == 0 ||
|
| | buf.size() <= static_cast<size_t>(params.max_size);
|
| | },
|
| | nullptr
|
| | );
|
| |
|
| | if (!res) {
|
| | throw std::runtime_error("error: cannot make GET request");
|
| | }
|
| |
|
| | return { res->status, std::move(buf) };
|
| | }
|
| |
|
| | int common_download_file_single(const std::string & url,
|
| | const std::string & path,
|
| | const std::string & bearer_token,
|
| | bool offline,
|
| | const common_header_list & headers) {
|
| | if (!offline) {
|
| | return common_download_file_single_online(url, path, bearer_token, headers);
|
| | }
|
| |
|
| | if (!std::filesystem::exists(path)) {
|
| | LOG_ERR("%s: required file is not available in cache (offline mode): %s\n", __func__, path.c_str());
|
| | return -1;
|
| | }
|
| |
|
| | LOG_INF("%s: using cached file (offline mode): %s\n", __func__, path.c_str());
|
| | return 304;
|
| | }
|
| |
|
| |
|
| |
|
| | static bool common_download_file_multiple(const std::vector<std::pair<std::string, std::string>> & urls,
|
| | const std::string & bearer_token,
|
| | bool offline,
|
| | const common_header_list & headers) {
|
| |
|
| | std::vector<std::future<bool>> futures_download;
|
| | futures_download.reserve(urls.size());
|
| |
|
| | for (auto const & item : urls) {
|
| | futures_download.push_back(
|
| | std::async(
|
| | std::launch::async,
|
| | [&bearer_token, offline, &headers](const std::pair<std::string, std::string> & it) -> bool {
|
| | const int http_status = common_download_file_single(it.first, it.second, bearer_token, offline, headers);
|
| | return is_http_status_ok(http_status);
|
| | },
|
| | item
|
| | )
|
| | );
|
| | }
|
| |
|
| |
|
| | for (auto & f : futures_download) {
|
| | if (!f.get()) {
|
| | return false;
|
| | }
|
| | }
|
| |
|
| | return true;
|
| | }
|
| |
|
| | bool common_download_model(const common_params_model & model,
|
| | const std::string & bearer_token,
|
| | bool offline,
|
| | const common_header_list & headers) {
|
| |
|
| | if (model.url.empty()) {
|
| | LOG_ERR("%s: invalid model url\n", __func__);
|
| | return false;
|
| | }
|
| |
|
| | const int http_status = common_download_file_single(model.url, model.path, bearer_token, offline, headers);
|
| | if (!is_http_status_ok(http_status)) {
|
| | return false;
|
| | }
|
| |
|
| |
|
| | int n_split = 0;
|
| | {
|
| | struct gguf_init_params gguf_params = {
|
| | true,
|
| | NULL,
|
| | };
|
| | auto * ctx_gguf = gguf_init_from_file(model.path.c_str(), gguf_params);
|
| | if (!ctx_gguf) {
|
| | LOG_ERR("\n%s: failed to load input GGUF from %s\n", __func__, model.path.c_str());
|
| | return false;
|
| | }
|
| |
|
| | auto key_n_split = gguf_find_key(ctx_gguf, LLM_KV_SPLIT_COUNT);
|
| | if (key_n_split >= 0) {
|
| | n_split = gguf_get_val_u16(ctx_gguf, key_n_split);
|
| | }
|
| |
|
| | gguf_free(ctx_gguf);
|
| | }
|
| |
|
| | if (n_split > 1) {
|
| | char split_prefix[PATH_MAX] = {0};
|
| | char split_url_prefix[LLAMA_MAX_URL_LENGTH] = {0};
|
| |
|
| |
|
| |
|
| | {
|
| | if (!llama_split_prefix(split_prefix, sizeof(split_prefix), model.path.c_str(), 0, n_split)) {
|
| | LOG_ERR("\n%s: unexpected model file name: %s n_split=%d\n", __func__, model.path.c_str(), n_split);
|
| | return false;
|
| | }
|
| |
|
| | if (!llama_split_prefix(split_url_prefix, sizeof(split_url_prefix), model.url.c_str(), 0, n_split)) {
|
| | LOG_ERR("\n%s: unexpected model url: %s n_split=%d\n", __func__, model.url.c_str(), n_split);
|
| | return false;
|
| | }
|
| | }
|
| |
|
| | std::vector<std::pair<std::string, std::string>> urls;
|
| | for (int idx = 1; idx < n_split; idx++) {
|
| | char split_path[PATH_MAX] = {0};
|
| | llama_split_path(split_path, sizeof(split_path), split_prefix, idx, n_split);
|
| |
|
| | char split_url[LLAMA_MAX_URL_LENGTH] = {0};
|
| | llama_split_path(split_url, sizeof(split_url), split_url_prefix, idx, n_split);
|
| |
|
| | if (std::string(split_path) == model.path) {
|
| | continue;
|
| | }
|
| |
|
| | urls.push_back({split_url, split_path});
|
| | }
|
| |
|
| |
|
| | common_download_file_multiple(urls, bearer_token, offline, headers);
|
| | }
|
| |
|
| | return true;
|
| | }
|
| |
|
| | common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag,
|
| | const std::string & bearer_token,
|
| | bool offline,
|
| | const common_header_list & custom_headers) {
|
| |
|
| | auto [hf_repo, tag] = common_download_split_repo_tag(hf_repo_with_tag);
|
| |
|
| | std::string url = get_model_endpoint() + "v2/" + hf_repo + "/manifests/" + tag;
|
| |
|
| |
|
| | common_header_list headers = custom_headers;
|
| | headers.push_back({"Accept", "application/json"});
|
| | if (!bearer_token.empty()) {
|
| | headers.push_back({"Authorization", "Bearer " + bearer_token});
|
| | }
|
| |
|
| |
|
| |
|
| |
|
| | common_remote_params params;
|
| | params.headers = headers;
|
| | long res_code = 0;
|
| | std::string res_str;
|
| | bool use_cache = false;
|
| | std::string cached_response_path = get_manifest_path(hf_repo, tag);
|
| | if (!offline) {
|
| | try {
|
| | auto res = common_remote_get_content(url, params);
|
| | res_code = res.first;
|
| | res_str = std::string(res.second.data(), res.second.size());
|
| | } catch (const std::exception & e) {
|
| | LOG_WRN("error: failed to get manifest at %s: %s\n", url.c_str(), e.what());
|
| | }
|
| | }
|
| | if (res_code == 0) {
|
| | if (std::filesystem::exists(cached_response_path)) {
|
| | LOG_WRN("trying to read manifest from cache: %s\n", cached_response_path.c_str());
|
| | res_str = read_file(cached_response_path);
|
| | res_code = 200;
|
| | use_cache = true;
|
| | } else {
|
| | throw std::runtime_error(
|
| | offline ? "error: failed to get manifest (offline mode)"
|
| | : "error: failed to get manifest (check your internet connection)");
|
| | }
|
| | }
|
| | std::string ggufFile;
|
| | std::string mmprojFile;
|
| |
|
| | if (res_code == 200 || res_code == 304) {
|
| | try {
|
| | auto j = json::parse(res_str);
|
| |
|
| | if (j.contains("ggufFile") && j["ggufFile"].contains("rfilename")) {
|
| | ggufFile = j["ggufFile"]["rfilename"].get<std::string>();
|
| | }
|
| | if (j.contains("mmprojFile") && j["mmprojFile"].contains("rfilename")) {
|
| | mmprojFile = j["mmprojFile"]["rfilename"].get<std::string>();
|
| | }
|
| | } catch (const std::exception & e) {
|
| | throw std::runtime_error(std::string("error parsing manifest JSON: ") + e.what());
|
| | }
|
| | if (!use_cache) {
|
| |
|
| | write_file(cached_response_path, res_str);
|
| | }
|
| | } else if (res_code == 401) {
|
| | throw std::runtime_error("error: model is private or does not exist; if you are accessing a gated model, please provide a valid HF token");
|
| | } else {
|
| | throw std::runtime_error(string_format("error from HF API (%s), response code: %ld, data: %s", url.c_str(), res_code, res_str.c_str()));
|
| | }
|
| |
|
| |
|
| | if (ggufFile.empty()) {
|
| | throw std::runtime_error("error: model does not have ggufFile");
|
| | }
|
| |
|
| | return { hf_repo, ggufFile, mmprojFile };
|
| | }
|
| |
|
| |
|
| |
|
| |
|
| |
|
| | static std::string common_docker_get_token(const std::string & repo) {
|
| | std::string url = "https://auth.docker.io/token?service=registry.docker.io&scope=repository:" + repo + ":pull";
|
| |
|
| | common_remote_params params;
|
| | auto res = common_remote_get_content(url, params);
|
| |
|
| | if (res.first != 200) {
|
| | throw std::runtime_error("Failed to get Docker registry token, HTTP code: " + std::to_string(res.first));
|
| | }
|
| |
|
| | std::string response_str(res.second.begin(), res.second.end());
|
| | nlohmann::ordered_json response = nlohmann::ordered_json::parse(response_str);
|
| |
|
| | if (!response.contains("token")) {
|
| | throw std::runtime_error("Docker registry token response missing 'token' field");
|
| | }
|
| |
|
| | return response["token"].get<std::string>();
|
| | }
|
| |
|
| | std::string common_docker_resolve_model(const std::string & docker) {
|
| |
|
| | size_t colon_pos = docker.find(':');
|
| | std::string repo, tag;
|
| | if (colon_pos != std::string::npos) {
|
| | repo = docker.substr(0, colon_pos);
|
| | tag = docker.substr(colon_pos + 1);
|
| | } else {
|
| | repo = docker;
|
| | tag = "latest";
|
| | }
|
| |
|
| |
|
| | size_t slash_pos = docker.find('/');
|
| | if (slash_pos == std::string::npos) {
|
| | repo.insert(0, "ai/");
|
| | }
|
| |
|
| | LOG_INF("%s: Downloading Docker Model: %s:%s\n", __func__, repo.c_str(), tag.c_str());
|
| | try {
|
| |
|
| | auto validate_oci_digest = [](const std::string & digest) -> std::string {
|
| |
|
| |
|
| | static const std::regex re("^sha256:([a-fA-F0-9]{64})$");
|
| | std::smatch m;
|
| | if (!std::regex_match(digest, m, re)) {
|
| | throw std::runtime_error("Invalid OCI digest format received in manifest: " + digest);
|
| | }
|
| |
|
| | std::string normalized = digest;
|
| | std::transform(normalized.begin()+7, normalized.end(), normalized.begin()+7, [](unsigned char c){
|
| | return std::tolower(c);
|
| | });
|
| | return normalized;
|
| | };
|
| |
|
| | std::string token = common_docker_get_token(repo);
|
| |
|
| |
|
| |
|
| | const std::string url_prefix = "https://registry-1.docker.io/v2/" + repo;
|
| | std::string manifest_url = url_prefix + "/manifests/" + tag;
|
| | common_remote_params manifest_params;
|
| | manifest_params.headers.push_back({"Authorization", "Bearer " + token});
|
| | manifest_params.headers.push_back({"Accept",
|
| | "application/vnd.docker.distribution.manifest.v2+json,application/vnd.oci.image.manifest.v1+json"
|
| | });
|
| | auto manifest_res = common_remote_get_content(manifest_url, manifest_params);
|
| | if (manifest_res.first != 200) {
|
| | throw std::runtime_error("Failed to get Docker manifest, HTTP code: " + std::to_string(manifest_res.first));
|
| | }
|
| |
|
| | std::string manifest_str(manifest_res.second.begin(), manifest_res.second.end());
|
| | nlohmann::ordered_json manifest = nlohmann::ordered_json::parse(manifest_str);
|
| | std::string gguf_digest;
|
| | if (manifest.contains("layers")) {
|
| | for (const auto & layer : manifest["layers"]) {
|
| | if (layer.contains("mediaType")) {
|
| | std::string media_type = layer["mediaType"].get<std::string>();
|
| | if (media_type == "application/vnd.docker.ai.gguf.v3" ||
|
| | media_type.find("gguf") != std::string::npos) {
|
| | gguf_digest = layer["digest"].get<std::string>();
|
| | break;
|
| | }
|
| | }
|
| | }
|
| | }
|
| |
|
| | if (gguf_digest.empty()) {
|
| | throw std::runtime_error("No GGUF layer found in Docker manifest");
|
| | }
|
| |
|
| |
|
| | gguf_digest = validate_oci_digest(gguf_digest);
|
| | LOG_DBG("%s: Using validated digest: %s\n", __func__, gguf_digest.c_str());
|
| |
|
| |
|
| | std::string model_filename = repo;
|
| | std::replace(model_filename.begin(), model_filename.end(), '/', '_');
|
| | model_filename += "_" + tag + ".gguf";
|
| | std::string local_path = fs_get_cache_file(model_filename);
|
| |
|
| | const std::string blob_url = url_prefix + "/blobs/" + gguf_digest;
|
| | const int http_status = common_download_file_single(blob_url, local_path, token, false, {});
|
| | if (!is_http_status_ok(http_status)) {
|
| | throw std::runtime_error("Failed to download Docker Model");
|
| | }
|
| |
|
| | LOG_INF("%s: Downloaded Docker Model to: %s\n", __func__, local_path.c_str());
|
| | return local_path;
|
| | } catch (const std::exception & e) {
|
| | LOG_ERR("%s: Docker Model download failed: %s\n", __func__, e.what());
|
| | throw;
|
| | }
|
| | }
|
| |
|
| | std::vector<common_cached_model_info> common_list_cached_models() {
|
| | std::vector<common_cached_model_info> models;
|
| | const std::string cache_dir = fs_get_cache_directory();
|
| | const std::vector<common_file_info> files = fs_list(cache_dir, false);
|
| | for (const auto & file : files) {
|
| | if (string_starts_with(file.name, "manifest=") && string_ends_with(file.name, ".json")) {
|
| | common_cached_model_info model_info;
|
| | model_info.manifest_path = file.path;
|
| | std::string fname = file.name;
|
| | string_replace_all(fname, ".json", "");
|
| | auto parts = string_split<std::string>(fname, '=');
|
| | if (parts.size() == 4) {
|
| |
|
| | model_info.user = parts[1];
|
| | model_info.model = parts[2];
|
| | model_info.tag = parts[3];
|
| | } else {
|
| |
|
| | continue;
|
| | }
|
| | model_info.size = 0;
|
| | models.push_back(model_info);
|
| | }
|
| | }
|
| | return models;
|
| | }
|
| |
|