/*!
 * Copyright (c) 2023 by Contributors
 * \file cli_main.cc
 * \brief Implementation of a CLI version of chat
 */
// NOTE: we only interact with the module through the TVM runtime,
// so there is no need to depend on a header interface;
// the same set of operations can be implemented in other languages.
#include <argparse/argparse.hpp>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/logging.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>

#include <algorithm>
#include <filesystem>
#include <iostream>
#include <optional>
#include <string>
#include <vector>

#include "llm_chat.h"

std::string DetectDeviceName(std::string device_name) {
  using tvm::runtime::DeviceAPI;
  if (device_name == "auto") {
    bool allow_missing = true;
    if (DeviceAPI::Get(DLDevice{kDLCUDA, 0}, allow_missing)) {
      return "cuda";
    }
    if (DeviceAPI::Get(DLDevice{kDLMetal, 0}, allow_missing)) {
      return "metal";
    }
    if (DeviceAPI::Get(DLDevice{kDLVulkan, 0}, allow_missing)) {
      return "vulkan";
    }
    if (DeviceAPI::Get(DLDevice{kDLOpenCL, 0}, allow_missing)) {
      return "opencl";
    }
    LOG(FATAL) << "Cannot auto detect device-name";
  }
  return device_name;
}

DLDevice GetDevice(const std::string& device_name, int device_id) {
  if (device_name == "cuda") return DLDevice{kDLCUDA, device_id};
  if (device_name == "metal") return DLDevice{kDLMetal, device_id};
  if (device_name == "vulkan") return DLDevice{kDLVulkan, device_id};
  if (device_name == "opencl") return DLDevice{kDLOpenCL, device_id};
  LOG(FATAL) << "Do not recognize device name " << device_name;
  return DLDevice{kDLCPU, 0};
}

/*!
 * \brief Search for a file and return the first match.
 *
 * \param search_paths The directories to search.
 * \param names The base names to look for.
 * \param suffixes The suffixes to try for each name.
 */
std::optional<std::filesystem::path> FindFile(const std::vector<std::string>& search_paths,
                                              const std::vector<std::string>& names,
                                              const std::vector<std::string>& suffixes) {
  for (const auto& prefix : search_paths) {
    for (const auto& name : names) {
      for (const auto& suffix : suffixes) {
        std::filesystem::path path(prefix + "/" + name + suffix);
        if (std::filesystem::exists(path)) {
          path = std::filesystem::canonical(path);
          if (std::filesystem::is_regular_file(path)) {
            return path;
          }
        }
      }
    }
  }
  return std::nullopt;
}

/*! \brief Get the default library suffixes for the current platform. */
std::vector<std::string> GetLibSuffixes() {
#if defined(WIN32)
  return {".dll"};
#elif defined(__APPLE__)
  return {".dylib", ".so"};
#else
  return {".so"};
#endif
}

/*! \brief Get the architecture suffix used in library file names. */
std::string GetArchSuffix() {
#if defined(__x86_64__)
  return "_x86_64";
#elif defined(__aarch64__)
  return "_arm64";
#endif
  return "";
}

/*! \brief Split a string into whole UTF-8 encoded characters. */
std::vector<std::string> CountUTF8(const std::string& s) {
  // assume that the string is always valid utf8
  std::vector<std::string> ret;
  for (size_t pos = 0; pos < s.size();) {
    if ((s[pos] & 0x80) == 0x00) {
      ret.push_back(s.substr(pos, 1));
      pos += 1;
    } else if (pos + 1 < s.size() && (s[pos] & 0xE0) == 0xC0 && (s[pos + 1] & 0xC0) == 0x80) {
      ret.push_back(s.substr(pos, 2));
      pos += 2;
    } else if (pos + 2 < s.size() && (s[pos] & 0xF0) == 0xE0 && (s[pos + 1] & 0xC0) == 0x80 &&
               (s[pos + 2] & 0xC0) == 0x80) {
      ret.push_back(s.substr(pos, 3));
      pos += 3;
    } else if (pos + 3 < s.size() && (s[pos] & 0xF8) == 0xF0 && (s[pos + 1] & 0xC0) == 0x80 &&
               (s[pos + 2] & 0xC0) == 0x80 && (s[pos + 3] & 0xC0) == 0x80) {
      ret.push_back(s.substr(pos, 4));
      pos += 4;
    } else {
      LOG(FATAL) << "Invalid UTF8 string";
    }
  }
  return ret;
}
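// A worked example of the splitting above, assuming valid UTF-8 input:
// CountUTF8("hi\xE4\xBD\xA0") returns {"h", "i", "\xE4\xBD\xA0"}, i.e. two
// one-byte ASCII characters followed by a single three-byte code point.
// Chat() below diffs these per-character vectors so that a multi-byte
// character is never erased or reprinted halfway through.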
void PrintSpecialCommands() {
  std::cout << "You can use the following special commands:\n"
            << "  /help    print the special commands\n"
            << "  /exit    quit the cli\n"
            << "  /stats   print out the latest stats (token/sec)\n"
            << "  /reset   restart a fresh chat\n"
            << std::endl
            << std::flush;
}

/*!
 * \brief Start a chat conversation.
 *
 * \param chat_mod The chat module.
 * \param model The model to use.
 * \param max_gen_len The maximum length of the generated sequence.
 * \param temperature The temperature to use for sampling.
 * \param top_p The top_p to use for sampling.
 * \param stream_interval The number of decode steps between streaming prints.
 * \param max_window_size Forwarded to the module's init_chat function.
 * \param mean_gen_len Forwarded to the module's init_chat function.
 * \param shift_fill_factor Forwarded to the module's init_chat function.
 */
void Chat(tvm::runtime::Module chat_mod, const std::string& model, int64_t max_gen_len = 2048,
          double temperature = 0.7, double top_p = 0.95, int64_t stream_interval = 2,
          int max_window_size = 768, int mean_gen_len = 128, double shift_fill_factor = 0.3) {
  // detect the conversation template from the model name
  std::string conv_template;
  if (model.find("vicuna") == 0) {
    conv_template = "vicuna_v1.1";
  } else if (model.find("dolly-") == 0) {
    conv_template = "dolly";
  } else if (model.find("stablelm") == 0) {
    conv_template = "stablelm";
  } else if (model.find("moss") == 0) {
    conv_template = "moss";
  } else {
    LOG(FATAL) << "Do not recognize model name " << model;
  }
  // initialize chat context
  chat_mod.GetFunction("init_chat")(model, conv_template, max_gen_len, temperature, top_p,
                                    stream_interval, max_window_size, mean_gen_len,
                                    shift_fill_factor);
  auto f_stop = chat_mod.GetFunction("stopped");
  auto f_encode = chat_mod.GetFunction("encode");
  auto f_decode = chat_mod.GetFunction("decode");
  auto f_stats = chat_mod.GetFunction("runtime_stats_text");
  std::string role0 = chat_mod.GetFunction("get_role0")();
  std::string role1 = chat_mod.GetFunction("get_role1")();

  while (true) {
    std::string inp;
    std::cout << role0 << ": " << std::flush;
    std::getline(std::cin, inp);
    if (!std::cin.good()) break;
    if (inp.substr(0, 6) == "/reset") {
      // reset the chat context
      chat_mod.GetFunction("reset_chat")();
      std::cout << "RESET CHAT SUCCESS" << std::endl << std::flush;
      continue;
    } else if (inp.substr(0, 5) == "/exit") {
      break;
    } else if (inp.substr(0, 6) == "/stats") {
      std::string stats_text = f_stats();
      std::cout << stats_text << std::endl << std::flush;
      continue;
    } else if (inp.substr(0, 5) == "/help") {
      PrintSpecialCommands();
      continue;
    }
    std::string prev_printed = "";
    auto printed_UTF8_chars = CountUTF8(prev_printed);
    std::cout << role1 << ": " << std::flush;
    f_encode(inp);
    for (size_t step = 0; !f_stop(); ++step) {
      f_decode();
      if (step % stream_interval == 0 || f_stop()) {
        std::string cur_msg = chat_mod.GetFunction("get_message")();
        // erase the stale tail of the previous message, then print the new tail
        std::string print = "";
        auto cur_UTF8_chars = CountUTF8(cur_msg);
        // find the longest common prefix of the printed and current messages
        size_t common = 0;
        for (; common < std::min(printed_UTF8_chars.size(), cur_UTF8_chars.size()); ++common) {
          if (printed_UTF8_chars[common] != cur_UTF8_chars[common]) {
            break;
          }
        }
        for (size_t j = common; j < printed_UTF8_chars.size(); ++j) {
          print += "\b \b";
        }
        for (size_t j = common; j < cur_UTF8_chars.size(); ++j) {
          print += cur_UTF8_chars[j];
        }
        std::cout << print << std::flush;
        prev_printed = cur_msg;
        printed_UTF8_chars = std::move(cur_UTF8_chars);
      }
    }
    std::cout << std::endl << std::flush;
  }
}
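// A worked example of the erase-and-reprint logic in Chat(): if the screen
// shows "The capital of" and the re-detokenized message becomes "The capitol",
// the common prefix is "The capit"; the printer emits five "\b \b" sequences
// to erase "al of", then prints "ol". A plain append would not work because
// the detokenizer may rewrite earlier characters as more tokens arrive.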
int main(int argc, char* argv[]) {
  using namespace tvm::runtime;
  argparse::ArgumentParser args("mlc_chat");

  args.add_argument("--device-name").default_value("auto");
  args.add_argument("--device_id").default_value(0);
  args.add_argument("--artifact-path").default_value("dist");
  args.add_argument("--model").default_value("vicuna-v1-7b");
  args.add_argument("--dtype").default_value("auto");
  args.add_argument("--params").default_value("auto");
  args.add_argument("--evaluate").default_value(false).implicit_value(true);

  try {
    args.parse_args(argc, argv);
  } catch (const std::runtime_error& err) {
    std::cerr << err.what() << std::endl;
    std::cerr << args;
    return 1;
  }

  std::string device_name = DetectDeviceName(args.get("--device-name"));
  int device_id = args.get<int>("--device_id");
  DLDevice device = GetDevice(device_name, device_id);
  std::string artifact_path = args.get("--artifact-path");
  std::string model = args.get("--model");
  std::string dtype = args.get("--dtype");
  std::string params = args.get("--params");

  std::string lib_name = model + "_" + device_name;
  std::string arch_suffix = GetArchSuffix();

  std::vector<std::string> dtype_candidates;
  if (dtype == "auto") {
    dtype_candidates = {"float16", "float32"};
  } else {
    dtype_candidates = {dtype};
  }

  std::optional<std::filesystem::path> lib_path_opt;
  for (const auto& candidate : dtype_candidates) {
    std::vector<std::string> search_paths = {artifact_path + "/" + model + "/" + candidate,
                                             artifact_path + "/models/" + model,
                                             artifact_path + "/" + model, artifact_path + "/lib"};
    // search for both lib_<arch>_<dtype> and lib_<dtype>
    lib_path_opt = FindFile(search_paths,
                            {lib_name + arch_suffix + "_" + candidate, lib_name + "_" + candidate},
                            GetLibSuffixes());
    if (lib_path_opt) {
      dtype = candidate;
      break;
    }
  }
  if (!lib_path_opt) {
    std::cerr << "Cannot find " << model << " lib in preferred path \"" << artifact_path << "/"
              << model << "/" << dtype_candidates[0] << "/" << lib_name << "_"
              << dtype_candidates[0] << GetLibSuffixes()[0] << "\" or other candidate paths";
    return 1;
  }
  std::cout << "Use lib " << lib_path_opt.value().string() << std::endl;
  std::string model_path = lib_path_opt.value().parent_path().string();

  // locate the tokenizer next to the library or under the artifact path
  std::optional<std::filesystem::path> tokenizer_path_opt =
      FindFile({model_path, artifact_path + "/models/" + model, artifact_path + "/" + model},
               {"tokenizer"}, {".model", ".json"});
  if (!tokenizer_path_opt) {
    // Try a ByteLevelBPE tokenizer.
    tokenizer_path_opt =
        FindFile({model_path, artifact_path + "/models/" + model, artifact_path + "/" + model},
                 {"vocab"}, {".json"});
    if (!tokenizer_path_opt) {
      std::cerr << "Cannot find tokenizer file in " << model_path;
      return 1;
    } else {
      // A GPT-2 style tokenizer needs multiple files, so pass the
      // directory that stores vocab.json instead of the file itself.
      tokenizer_path_opt = tokenizer_path_opt.value().parent_path();
    }
  }

  if (params == "auto") {
    auto params_json_opt = FindFile({model_path + "/params",
                                     artifact_path + "/" + model + "/" + dtype,
                                     artifact_path + "/" + model},
                                    {"ndarray-cache"}, {".json"});
    if (!params_json_opt) {
      std::cerr << "Cannot find ndarray-cache.json for params in preferred path \"" << model_path
                << "/params\", \"" << artifact_path << "/" << model << "/" << dtype
                << "\", and \"" << artifact_path << "/" << model << "\".";
      return 1;
    }
    std::string params_json = params_json_opt.value().string();
    // strip the trailing "ndarray-cache.json" (18 characters) to get the directory
    params = params_json.substr(0, params_json.length() - 18);
  } else if (!FindFile({params}, {"ndarray-cache"}, {".json"})) {
    std::cerr << "Cannot find params/ndarray-cache.json in " << model_path;
    return 1;
  }

  try {
    auto lib = Module::LoadFromFile(lib_path_opt.value().string());
    std::cout << "Initializing the chat module..." << std::endl;
    Module chat_mod =
        mlc::llm::CreateChatModule(lib, tokenizer_path_opt.value().string(), params, device);
    std::cout << "Finish loading" << std::endl;
    PrintSpecialCommands();
    if (args.get<bool>("--evaluate")) {
      chat_mod.GetFunction("evaluate")();
    } else {
      Chat(chat_mod, model);
    }
  } catch (const std::runtime_error& err) {
    // Catch the exception so the error message is reported here
    // instead of the program silently quitting.
    std::cerr << err.what();
    return 1;
  }
  return 0;
}
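// Example invocation, assuming the default artifact layout (adjust
// --artifact-path and --model for a local setup):
//
//   ./mlc_chat --model vicuna-v1-7b --device-name auto
//
// This searches dist/vicuna-v1-7b/float16 (then float32) for a library such
// as vicuna-v1-7b_cuda_x86_64_float16.so, plus the tokenizer and parameter
// files next to it.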