diff --git a/Genie/Genie/GenieSymbols.default b/Genie/Genie/GenieSymbols.default new file mode 100644 index 0000000000000000000000000000000000000000..4084db46f37f5b3d47b9ed1f7e65938d185786f2 --- /dev/null +++ b/Genie/Genie/GenieSymbols.default @@ -0,0 +1,31 @@ +#============================================================================= +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# All Rights Reserved. +# Confidential and Proprietary - Qualcomm Technologies, Inc. +# +#============================================================================= +{ + global: + Genie_getApiMajorVersion*; + Genie_getApiMinorVersion*; + Genie_getApiPatchVersion*; + GenieDialogConfig_createFromJson*; + GenieDialogConfig_free*; + GenieDialog_create*; + GenieDialog_query*; + GenieDialog_tokenQuery*; + GenieDialog_embeddingQuery*; + GenieDialog_save*; + GenieDialog_restore*; + GenieDialog_reset*; + GenieDialog_setLoraStrength*; + GenieDialog_applyLora*; + GenieDialog_free*; + GenieEmbeddingConfig_createFromJson*; + GenieEmbeddingConfig_free*; + GenieEmbedding_create*; + GenieEmbedding_generate*; + GenieEmbedding_free*; + local: *; +}; \ No newline at end of file diff --git a/Genie/Genie/Makefile b/Genie/Genie/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..dab323ec87f1151f78d1f88dcd443636f4d53454 --- /dev/null +++ b/Genie/Genie/Makefile @@ -0,0 +1,57 @@ +#============================================================================= +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# All Rights Reserved. +# Confidential and Proprietary - Qualcomm Technologies, Inc. +# +#============================================================================= + +RUST_TARGET := aarch64-linux-android +RUST_SOURCE_DIR := ./src/qualla/tokenizers/rust +# specify compiler +export CXX := clang++-14 +export PATH := $(ANDROID_NDK_ROOT)/toolchains/llvm/prebuilt/linux-x86_64/bin:$(PATH) +.PHONY: all x86 android clean clean_x86 clean_android +.DEFAULT: x86 + +all: x86 android + +x86: build_x86_tokenizer + @echo "-------------------- Building genie for x86 -------------------- " + @$(MAKE) -f make/Makefile.linux-x86_64 CPATH="/usr/include/x86_64-linux-gnu" || (echo "-------------------- genie x86 build failed --------------------"; exit 1; ) + @echo "-------------------- genie x86 build succeeded -------------------- " + +android: check_ndk build_android_tokenizer + @echo "-------------------- Building genie for android -------------------- " + @$(ANDROID_NDK_ROOT)/ndk-build APP_ALLOW_MISSING_DEPS=true APP_ABI="arm64-v8a" NDK_PROJECT_PATH=./ NDK_APPLICATION_MK=make/Application.mk APP_BUILD_SCRIPT=make/Android.mk || (echo "-------------------- genie android build failed --------------------"; exit 1; ) + @$(rename_target_dirs) + @echo "-------------------- genie android build succeeded -------------------- " + +clean: clean_x86 clean_android + +clean_x86: + @$(MAKE) -f make/Makefile.linux-x86_64 clean + +clean_android: + if [ -d "lib/aarch64-android" ]; then rm -rf lib/aarch64-android; fi + if [ -d "obj/local" ]; then rm -rf obj/local; fi + +# utilities +rename_target_dirs = \ + @if [ -d ./lib/aarch64-android ]; then rm -rf ./lib/aarch64-android; fi; \ + find ./obj/local -type d -execdir rename 's/arm64-v8a/aarch64-android/' '{}' \+ \ + && mkdir -p lib \ + && mv ./obj/local/aarch64-android lib/ \ + && mv ./libs/arm64-v8a/libc++_shared.so lib/aarch64-android/ \ + && rm -rf ./libs \ + +check_ndk: +ifeq ($(ANDROID_NDK_ROOT),) + $(error ERROR: ANDROID_NDK_ROOT 
not set, skipping compilation for Android platform(s).)
+endif
+
+build_x86_tokenizer: $(RUST_SOURCE_DIR)/Cargo.toml
+	cargo build --release --manifest-path=$<
+
+build_android_tokenizer: $(RUST_SOURCE_DIR)/Cargo.toml
+	cargo build --release --manifest-path=$< --target=$(RUST_TARGET)
diff --git a/Genie/Genie/README b/Genie/Genie/README
new file mode 100644
index 0000000000000000000000000000000000000000..2ed544b5dd74dabd328ef8ca2cf04aecd46a20d4
--- /dev/null
+++ b/Genie/Genie/README
@@ -0,0 +1,16 @@
+#=============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# All Rights Reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+#=============================================================================
+
+Genie library source code example
+---------------------------------
+
+The Genie library (libGenie.so / Genie.dll) source code example gives users the ability to rebuild the Genie
+library from source. Note that the Genie library source may be refactored, rewritten, or otherwise modified without
+notice. The Genie C API is the commercially controlled and versioned interface that users should expect to remain
+stable. Please refer to the Genie SDK documentation tutorials at ${SDK_ROOT}/doc/Genie/ for more information on how
+to build the sample code.
\ No newline at end of file
diff --git a/Genie/Genie/make/Android.mk b/Genie/Genie/make/Android.mk
new file mode 100644
index 0000000000000000000000000000000000000000..319f417eb9d4c2be12da0ed39c91719960903ef2
--- /dev/null
+++ b/Genie/Genie/make/Android.mk
@@ -0,0 +1,56 @@
+#=============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# All Rights Reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+# +#============================================================================= + +LOCAL_PATH := $(call my-dir) +SUPPORTED_TARGET_ABI := arm64-v8a x86 x86_64 + +#============================ Verify Target Info and Application Variables ========================================= +ifneq ($(filter $(TARGET_ARCH_ABI),$(SUPPORTED_TARGET_ABI)),) + ifneq ($(APP_STL), c++_shared) + $(error Unsupported APP_STL: "$(APP_STL)") + endif +else + $(error Unsupported TARGET_ARCH_ABI: '$(TARGET_ARCH_ABI)') +endif + +#============================ Define Common Variables =============================================================== +# PACKAGE_C_INCLUDES += -I $(LOCAL_PATH)/../../../../../include/QNN +# Include paths +PACKAGE_C_INCLUDES += -I $(LOCAL_PATH)/../include +PACKAGE_C_INCLUDES += -I $(LOCAL_PATH)/../../../../include/Genie +PACKAGE_C_INCLUDES += -I $(LOCAL_PATH)/../src/qualla/include +PACKAGE_C_INCLUDES += -I $(LOCAL_PATH)/../../../../include/QNN +PACKAGE_C_INCLUDES += -I $(LOCAL_PATH)/../../../../include/QNN/HTP +PACKAGE_C_INCLUDES += -I $(LOCAL_PATH)/../src/qualla/tokenizers +PACKAGE_C_INCLUDES += -I $(LOCAL_PATH)/../src/qualla/engines/qnn-api +PACKAGE_C_INCLUDES += -I $(LOCAL_PATH)/../src/qualla/engines/qnn-cpu +PACKAGE_C_INCLUDES += -I $(LOCAL_PATH)/../src/qualla/engines/qnn-htp + +#========================== Define T2T Lib variables ============================================= +include $(CLEAR_VARS) +LOCAL_MODULE := tokenizers_capi +LOCAL_SRC_FILES := ../src/qualla/tokenizers/rust/target/aarch64-linux-android/release/libtokenizers_capi.a +include $(PREBUILT_STATIC_LIBRARY) + +include $(CLEAR_VARS) +LOCAL_C_INCLUDES := $(PACKAGE_C_INCLUDES) +MY_SRC_FILES := $(wildcard $(LOCAL_PATH)/../src/*.cpp) +MY_SRC_FILES += $(wildcard $(LOCAL_PATH)/../src/qualla/*.cpp) +MY_SRC_FILES += $(wildcard $(LOCAL_PATH)/../src/qualla/dialogs/*.cpp) +MY_SRC_FILES += $(wildcard $(LOCAL_PATH)/../src/qualla/engines/*.cpp) +MY_SRC_FILES += $(wildcard $(LOCAL_PATH)/../src/qualla/engines/qnn-api/*.cpp) +MY_SRC_FILES += $(wildcard $(LOCAL_PATH)/../src/qualla/engines/qnn-cpu/*.cpp) +MY_SRC_FILES += $(wildcard $(LOCAL_PATH)/../src/qualla/engines/qnn-htp/*.cpp) +MY_SRC_FILES += $(wildcard $(LOCAL_PATH)/../src/qualla/utils/*.cpp) +MY_SRC_FILES += $(wildcard $(LOCAL_PATH)/../src/qualla/loggers/*.cpp) +MY_SRC_FILES += $(wildcard $(LOCAL_PATH)/../src/qualla/samplers/*.cpp) + +LOCAL_MODULE := libGenie +LOCAL_SRC_FILES := $(subst make/,,$(MY_SRC_FILES)) +LOCAL_STATIC_LIBRARIES := tokenizers_capi +include $(BUILD_SHARED_LIBRARY) diff --git a/Genie/Genie/make/Application.mk b/Genie/Genie/make/Application.mk new file mode 100644 index 0000000000000000000000000000000000000000..4e0596f856b93970ac937ab0b6302b74306ae05b --- /dev/null +++ b/Genie/Genie/make/Application.mk @@ -0,0 +1,14 @@ +#============================================================================= +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# All Rights Reserved. +# Confidential and Proprietary - Qualcomm Technologies, Inc. 
+#
+#=============================================================================
+
+APP_ABI := arm64-v8a
+APP_STL := c++_shared
+APP_PLATFORM := android-21
+APP_MODULES := Genie
+APP_CPPFLAGS += -std=c++2a -O3 -Wall -frtti -fexceptions -fvisibility=hidden -DGENIE_API="__attribute__((visibility(\"default\")))" -DSPILLFILL -DQUALLA_ENGINE_QNN_HTP=TRUE -DQUALLA_ENGINE_QNN_CPU=TRUE -DQUALLA_APPS=OFF -DFMT_HEADER_ONLY -DGENIE_SAMPLE -DQUALLA_INTERNAL_QNN_SDK -DGENIE_SSD_FEATURE -DGENIE_SPD_FEATURE -DGENIE_LADE_FEATURE -DGENIE_MULTISTREAM_FEATURE -DGENIE_LORA_FEATURE -DGENIE_E2T_FEATURE
+APP_LDFLAGS += -lc -lm -ldl -Wl,--version-script=GenieSymbols.default -Wl,--strip-all
diff --git a/Genie/Genie/make/Makefile.linux-x86_64 b/Genie/Genie/make/Makefile.linux-x86_64
new file mode 100644
index 0000000000000000000000000000000000000000..98d4d4a9657c72a5ea359a475eed0b9103339d68
--- /dev/null
+++ b/Genie/Genie/make/Makefile.linux-x86_64
@@ -0,0 +1,192 @@
+#=============================================================================
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# All Rights Reserved.
+# Confidential and Proprietary - Qualcomm Technologies, Inc.
+#
+#=============================================================================
+
+# define relevant directories
+SRC_DIR := src/qualla
+#
+SRC_DIR_GENIE_TOKENIZERS := src/qualla/tokenizers
+#
+SRC_DIR_SAMPLE_DIALOGS := src/qualla/dialogs
+
+# All engines
+SRC_DIR_GENIE_ENGINES := src/qualla/engines
+SRC_DIR_GENIE_QNN_API := src/qualla/engines/qnn-api
+SRC_DIR_GENIE_ENGINES_CPU := src/qualla/engines/qnn-cpu
+SRC_DIR_GENIE_UTILS := src/qualla/utils
+#
+SRC_DIR_GENIE_LOGGERS := src/qualla/loggers
+
+#
+SRC_DIR_GENIE_SAMPLERS := src/qualla/samplers
+
+#
+SRC_DIR_GENIE := src
+
+# Includes
+GENIE_ENGINES_CPU_INCLUDE := src/qualla/engines/qnn-cpu
+GENIE_ENGINES_API_INCLUDE := src/qualla/engines/qnn-api
+GENIE_ENGINES_HTP_INCLUDE := src/qualla/engines/qnn-htp
+GENIE_TOKENIZER_INCLUDE := src/qualla/tokenizers
+
+GENIE_INCLUDE := include
+GENIE_C_API_HEADERS_INCLUDE := ../../../include/Genie
+QUALLA_INCLUDE := src/qualla/include
+QNN_API_INCLUDE := ../../../include/QNN/
+QNN_API_HTP_INCLUDE := $(QNN_API_INCLUDE)/HTP
+
+AR := /usr/bin/ar
+ARFLAGS := rcs
+# Check whether $(CXX) is clang.
+# If not, switch to clang++.
+ifeq ($(shell $(CXX) -v 2>&1 | grep -c "clang version"), 0)
+CXX := clang++
+endif
+
+QNN_TARGET ?= x86_64-linux-clang
+export TARGET_DIR := ./lib/$(QNN_TARGET)
+
+libGenie := $(TARGET_DIR)/libGenie.so
+libtokenizers := src/qualla/tokenizers/rust/target/release/libtokenizers_capi.a
+
+# define target architecture if not previously defined, default is x86
+ifndef TARGET_AARCH_VARS
+TARGET_AARCH_VARS:= -march=x86-64
+endif
+
+.PHONY: linux_x86_64
+.DEFAULT: linux_x86_64
+GENIE_all: $(libGenie)
+
+# Include paths
+INCLUDES += -I$(GENIE_INCLUDE) -I$(QUALLA_INCLUDE) -I$(SRC_DIR_GENIE_TOKENIZERS) -I$(QNN_API_INCLUDE) -I$(GENIE_ENGINES_CPU_INCLUDE) -I$(QNN_API_HTP_INCLUDE) -I$(GENIE_ENGINES_API_INCLUDE) -I$(GENIE_TOKENIZER_INCLUDE) -I$(GENIE_C_API_HEADERS_INCLUDE)
+
+# set compiler flags
+COMMON_CXXFLAGS = -std=c++2a -frtti -fPIC -Wall -pg -pthread -nostdinc++ -stdlib=libc++ -idirafter /usr/lib/llvm-14/include/c++/v1 -nostdinc -idirafter /usr/lib/llvm-14/lib/clang/14.0.0/include/ -idirafter /usr/include $(INCLUDES)
+COMMON_LDFLAGS = -shared -s -fPIC -pthread -L/usr/lib/x86_64-linux-gnu -L./src/qualla/tokenizers/rust/target/release
+
+COMMON_CFLAGS = -nostdinc -idirafter /usr/lib/llvm-14/lib/clang/14.0.0/include/ -idirafter /usr/include
+
+ifdef QNN_DEBUG_ENABLE
+CXXFLAGS += $(COMMON_CXXFLAGS) -march=x86-64 -O0 -g -DQNN_API="" -DSPILLFILL -DQUALLA_ENGINE_QNN_CPU=TRUE -DQUALLA_APPS=OFF -DFMT_HEADER_ONLY -DGENIE_SAMPLE -DQUALLA_INTERNAL_QNN_SDK -DGENIE_SSD_FEATURE -DGENIE_SPD_FEATURE -DGENIE_LADE_FEATURE -DGENIE_MULTISTREAM_FEATURE -DGENIE_LORA_FEATURE -DGENIE_E2T_FEATURE
+CFLAGS += $(COMMON_CFLAGS)
+LDFLAGS += $(COMMON_LDFLAGS)
+else
+CXXFLAGS += $(COMMON_CXXFLAGS) -march=x86-64 -O3 -Wno-write-strings -fvisibility=hidden -DGENIE_API="__attribute__((visibility(\"default\")))" -DSPILLFILL -DQUALLA_ENGINE_QNN_CPU=TRUE -DQUALLA_APPS=OFF -DFMT_HEADER_ONLY -DGENIE_SAMPLE -DQUALLA_INTERNAL_QNN_SDK -DGENIE_SSD_FEATURE -DGENIE_SPD_FEATURE -DGENIE_LADE_FEATURE -DGENIE_MULTISTREAM_FEATURE -DGENIE_LORA_FEATURE -DGENIE_E2T_FEATURE
+CFLAGS += $(COMMON_CFLAGS)
+LDFLAGS += $(COMMON_LDFLAGS) -fvisibility=hidden -flto
+endif
+
+# define library sources
+SOURCES_GENIE_CPP := $(wildcard $(SRC_DIR_GENIE)/*.cpp)
+SOURCES := $(wildcard $(SRC_DIR)/*.cpp)
+SOURCES_GENIE_TOKENIZERS := $(wildcard $(SRC_DIR_GENIE_TOKENIZERS)/*.cpp)
+SOURCES_GENIE_QNN_API_CPP := $(wildcard $(SRC_DIR_GENIE_QNN_API)/*.cpp)
+
+SOURCES_GENIE_ENGINES_CPP := $(filter-out $(SRC_DIR_GENIE_ENGINES)/qnn-htp.cpp, $(wildcard $(SRC_DIR_GENIE_ENGINES)/*.cpp))
+SOURCES_GENIE_DIALOGS_CPP := $(wildcard $(SRC_DIR_SAMPLE_DIALOGS)/*.cpp)
+SOURCES_GENIE_ENGINES_CPU_CPP := $(wildcard $(SRC_DIR_GENIE_ENGINES_CPU)/*.cpp)
+SOURCES_GENIE_UTILS_CPP := $(wildcard $(SRC_DIR_GENIE_UTILS)/*.cpp)
+
+
+SOURCES_GENIE_LOGGERS_CPP := $(wildcard $(SRC_DIR_GENIE_LOGGERS)/*.cpp)
+SOURCES_GENIE_SAMPLERS_CPP := $(wildcard $(SRC_DIR_GENIE_SAMPLERS)/*.cpp)
+
+
+# define object directory
+OBJ_ROOT := obj
+OBJ_DIR_QUALLA := obj/$(QNN_TARGET)/qualla
+OBJ_DIR_GENIE := obj/$(QNN_TARGET)/genie
+OBJ_DIR_GENIE_TOKENIZERS := $(OBJ_DIR_QUALLA)/tokenizers
+OBJ_DIR_GENIE_QNN_API := $(OBJ_DIR_QUALLA)/qnn-api
+
+OBJ_DIR_GENIE_DIALOGS := $(OBJ_DIR_QUALLA)/dialogs
+OBJ_DIR_GENIE_ENGINES := $(OBJ_DIR_QUALLA)/engines
+OBJ_DIR_GENIE_UTILS := $(OBJ_DIR_QUALLA)/utils
+OBJ_DIR_GENIE_ENGINES_CPU := $(OBJ_DIR_QUALLA)/engines/qnn-cpu
+$(shell mkdir -p $(OBJ_DIR_GENIE_ENGINES_CPU))
+
+OBJ_DIR_GENIE_LOGGERS := obj/$(QNN_TARGET)/qualla/loggers
+OBJ_DIR_GENIE_SAMPLERS :=
obj/$(QNN_TARGET)/qualla/samplers
+
+$(shell mkdir -p $(OBJ_DIR_GENIE))
+$(shell mkdir -p $(OBJ_DIR_GENIE_LOGGERS))
+$(shell mkdir -p $(OBJ_DIR_GENIE_SAMPLERS))
+
+# setup object files in object directory
+OBJECTS_GENIE := $(patsubst %.cpp,$(OBJ_DIR_GENIE)/%.o,$(foreach x,$(SOURCES_GENIE_CPP),$(notdir $(x))))
+OBJECTS_QUALLA := $(patsubst %.cpp,$(OBJ_DIR_QUALLA)/%.o,$(foreach x,$(SOURCES),$(notdir $(x))))
+OBJECTS_GENIE_TOKENIZERS := $(patsubst %.cpp,$(OBJ_DIR_GENIE_TOKENIZERS)/%.o,$(foreach x,$(SOURCES_GENIE_TOKENIZERS),$(notdir $(x))))
+OBJECTS_GENIE_QNN_API := $(patsubst %.cpp,$(OBJ_DIR_GENIE_QNN_API)/%.o,$(foreach x,$(SOURCES_GENIE_QNN_API_CPP),$(notdir $(x))))
+OBJECTS_GENIE_ENGINES := $(patsubst %.cpp,$(OBJ_DIR_GENIE_ENGINES)/%.o,$(foreach x,$(SOURCES_GENIE_ENGINES_CPP),$(notdir $(x))))
+OBJECTS_GENIE_DIALOGS := $(patsubst %.cpp,$(OBJ_DIR_GENIE_DIALOGS)/%.o,$(foreach x,$(SOURCES_GENIE_DIALOGS_CPP),$(notdir $(x))))
+OBJECTS_GENIE_UTILS := $(patsubst %.cpp,$(OBJ_DIR_GENIE_UTILS)/%.o,$(foreach x,$(SOURCES_GENIE_UTILS_CPP),$(notdir $(x))))
+OBJECTS_GENIE_ENGINES_CPU := $(patsubst %.cpp,$(OBJ_DIR_GENIE_ENGINES_CPU)/%.o,$(foreach x,$(SOURCES_GENIE_ENGINES_CPU_CPP),$(notdir $(x))))
+
+OBJECTS_GENIE_LOGGERS := $(patsubst %.cpp,$(OBJ_DIR_GENIE_LOGGERS)/%.o,$(foreach x,$(SOURCES_GENIE_LOGGERS_CPP),$(notdir $(x))))
+OBJECTS_GENIE_SAMPLERS := $(patsubst %.cpp,$(OBJ_DIR_GENIE_SAMPLERS)/%.o,$(foreach x,$(SOURCES_GENIE_SAMPLERS_CPP),$(notdir $(x))))
+
+LIBS=-ldl
+
+
+# Rule to make shared lib
+.PHONY: libGenie
+libGenie: $(libGenie)
+
+# Implicit rule to compile and link object files
+$(OBJ_DIR_GENIE)/%.o: $(SRC_DIR_GENIE)/%.cpp
+	$(CXX) $(CXXFLAGS) -c $^ -o $@
+
+$(OBJ_DIR_QUALLA)/%.o: $(SRC_DIR)/%.cpp
+	$(CXX) $(CXXFLAGS) -c $^ -o $@
+
+$(OBJ_DIR_GENIE_TOKENIZERS)/%.o: $(SRC_DIR_GENIE_TOKENIZERS)/%.cpp
+	$(CXX) $(CXXFLAGS) -c $^ -o $@
+
+$(OBJ_DIR_GENIE_QNN_API)/%.o: $(SRC_DIR_GENIE_QNN_API)/%.cpp
+	$(CXX) $(CXXFLAGS) -c $^ -o $@
+
+$(OBJ_DIR_GENIE_ENGINES)/%.o: $(SRC_DIR_GENIE_ENGINES)/%.cpp
+	$(CXX) $(CXXFLAGS) -c $^ -o $@
+
+$(OBJ_DIR_GENIE_DIALOGS)/%.o: $(SRC_DIR_SAMPLE_DIALOGS)/%.cpp
+	$(CXX) $(CXXFLAGS) -c $^ -o $@
+
+$(OBJ_DIR_GENIE_UTILS)/%.o: $(SRC_DIR_GENIE_UTILS)/%.cpp
+	$(CXX) $(CXXFLAGS) -c $^ -o $@
+
+$(OBJ_DIR_GENIE_ENGINES_CPU)/%.o: $(SRC_DIR_GENIE_ENGINES_CPU)/%.cpp
+	$(CXX) $(CXXFLAGS) -c $^ -o $@
+
+$(OBJ_DIR_GENIE_LOGGERS)/%.o: $(SRC_DIR_GENIE_LOGGERS)/%.cpp
+	$(CXX) $(CXXFLAGS) -c $^ -o $@
+
+$(OBJ_DIR_GENIE_SAMPLERS)/%.o: $(SRC_DIR_GENIE_SAMPLERS)/%.cpp
+	$(CXX) $(CXXFLAGS) -c $^ -o $@
+
+
+# set up resources
+directories := $(TARGET_DIR) $(OBJ_DIR_GENIE) $(OBJ_DIR_GENIE_QNN_API) $(OBJ_DIR_QUALLA) $(OBJ_DIR_GENIE_TOKENIZERS) $(OBJ_DIR_GENIE_ENGINES) $(OBJ_DIR_GENIE_DIALOGS) $(OBJ_DIR_GENIE_UTILS) $(OBJ_DIR_GENIE_ENGINES_CPU) $(OBJ_DIR_GENIE_LOGGERS) $(OBJ_DIR_GENIE_SAMPLERS)
+
+# Compile
+$(libGenie): $(OBJECTS_GENIE) $(OBJECTS_QUALLA) $(OBJECTS_GENIE_QNN_API) $(OBJECTS_GENIE_TOKENIZERS) $(OBJECTS_GENIE_ENGINES) $(OBJECTS_GENIE_DIALOGS) $(OBJECTS_GENIE_UTILS) $(OBJECTS_GENIE_ENGINES_CPU) $(OBJECTS_GENIE_LOGGERS) $(OBJECTS_GENIE_SAMPLERS) | $(directories)
+	$(CXX) $(CXXFLAGS) -shared -o $@ $^ $(LIBS) $(libtokenizers)
+
+
+# rule for object directory resource
+$(OBJECTS_GENIE): | $(OBJ_DIR_GENIE)
+$(OBJECTS_QUALLA): | $(OBJ_DIR_QUALLA)
+$(OBJECTS_GENIE_TOKENIZERS): | $(OBJ_DIR_GENIE_TOKENIZERS)
+$(OBJECTS_GENIE_QNN_API): | $(OBJ_DIR_GENIE_QNN_API)
+$(OBJECTS_GENIE_ENGINES): | $(OBJ_DIR_GENIE_ENGINES)
+$(OBJECTS_GENIE_DIALOGS): | $(OBJ_DIR_GENIE_DIALOGS)
+$(OBJECTS_GENIE_UTILS): |
$(OBJ_DIR_GENIE_UTILS) +$(OBJECTS_GENIE_ENGINES_CPU): | $(OBJ_DIR_GENIE_ENGINES_CPU) +$(OBJECTS_GENIE_LOGGERS): | $(OBJ_DIR_GENIE_LOGGERS) +$(OBJECTS_GENIE_SAMPLERS): | $(OBJ_DIR_GENIE_SAMPLERS) + + +# rule to create directories +$(directories): + mkdir -p $@ + +.PHONY: clean +clean: + rm -rf $(OBJ_ROOT) $(TARGET_DIR) diff --git a/Genie/Genie/src/Dialog.cpp b/Genie/Genie/src/Dialog.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e3812e81362e3ccde2a9e96473339b3b760b883c --- /dev/null +++ b/Genie/Genie/src/Dialog.cpp @@ -0,0 +1,1804 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#include +#include +#include + +#include "Dialog.hpp" +#include "Exception.hpp" +#include "Macro.hpp" +#include "qualla/detail/json.hpp" +#include "qualla/env.hpp" + +using namespace genie; + +#ifdef _WIN32 +inline std::string libPrefix = ""; +inline std::string libSuffix = ".dll"; +#else +inline std::string libPrefix = "lib"; +inline std::string libSuffix = ".so"; +#endif + +inline std::string getLibName(std::string baseName) { return libPrefix + baseName + libSuffix; } + +//============================================================================= +// Context::Config functions +//============================================================================= + +static void validateContextConfig(const qualla::json& config) { + if (!config.is_object()) { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "context config is not an object"); + } + + std::set mandatoryFields{"version", "bos-token", "eos-token", "size", "n-vocab"}; + for (const auto& field : mandatoryFields) { + if (!config.contains(field)) { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing context field: " + field); + } + } + + // component is used in the "ENFORCE" macros + std::string component = "context"; + + for (auto& item : config.items()) { + if (item.key() == "version") { + JSON_ENFORCE_NUMERIC(); + if (item.value().get() != 1) { + throw Exception(GENIE_STATUS_ERROR_JSON_VALUE, + "Invalid context config: unsupported version: " + item.value().dump()); + } + } else if (item.key() == "bos-token") { + JSON_ENFORCE_NUMERIC(); + } else if (item.key() == "eos-token") { + JSON_ENFORCE_ARRAY_OR_NUMERIC(); + } else if (item.key() == "eot-token") { + JSON_ENFORCE_NUMERIC(); + } else if (item.key() == "size") { + JSON_ENFORCE_NUMERIC(); + } else if (item.key() == "n-vocab") { + JSON_ENFORCE_NUMERIC(); + } else if (item.key() == "pad-token") { + JSON_ENFORCE_NUMERIC(); + } else { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown context config key: " + item.key()); + } + } +} + +static void translateContextConfig(const qualla::json& genieConfig, qualla::json& quallaConfig) { + if (genieConfig["dialog"].contains("context")) { + if (genieConfig["dialog"]["context"].contains("bos-token")) { + quallaConfig["context"]["bos-token"] = genieConfig["dialog"]["context"]["bos-token"]; + } + if (genieConfig["dialog"]["context"].contains("eos-token")) { + quallaConfig["context"]["eos-token"] = genieConfig["dialog"]["context"]["eos-token"]; + } + if (genieConfig["dialog"]["context"].contains("eot-token")) { + quallaConfig["context"]["eot-token"] = genieConfig["dialog"]["context"]["eot-token"]; + } + if 
(genieConfig["dialog"]["context"].contains("size")) { + quallaConfig["context"]["size"] = genieConfig["dialog"]["context"]["size"]; + } + if (genieConfig["dialog"]["context"].contains("n-vocab")) { + quallaConfig["context"]["n-vocab"] = genieConfig["dialog"]["context"]["n-vocab"]; + } + if (genieConfig["dialog"]["context"].contains("pad-token")) { + quallaConfig["context"]["pad-token"] = genieConfig["dialog"]["context"]["pad-token"]; + } + } +} + +//============================================================================= +// Sampler::Config functions +//============================================================================= + +static void validateSamplerConfig(const qualla::json& config) { + if (!config.is_object()) { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "sampler config is not an object"); + } + + std::set mandatoryFields{"version"}; + for (const auto& field : mandatoryFields) { + if (!config.contains(field)) { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing sampler field: " + field); + } + } + + // component is used in the "ENFORCE" macros + std::string component = "sampler"; + + for (auto& item : config.items()) { + if (item.key() == "version") { + JSON_ENFORCE_NUMERIC(); + if (item.value().get() != 1) { + throw Exception(GENIE_STATUS_ERROR_JSON_VALUE, + "Invalid sampler config: unsupported version: " + item.value().dump()); + } + } else if (item.key() == "seed") { + JSON_ENFORCE_NUMERIC(); + } else if (item.key() == "temp") { + JSON_ENFORCE_NUMERIC(); + } else if (item.key() == "top-k") { + JSON_ENFORCE_NUMERIC(); + } else if (item.key() == "top-p") { + JSON_ENFORCE_NUMERIC(); + } else if (item.key() == "greedy") { + JSON_ENFORCE_BOOLEAN(); + } else { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown sampler config key: " + item.key()); + } + } +} + +static void translateSamplerConfig(const qualla::json& genieConfig, qualla::json& quallaConfig) { + if (genieConfig["dialog"].contains("sampler")) { + quallaConfig["sampler"]["type"] = "basic"; + + if (genieConfig["dialog"]["sampler"].contains("seed")) { + quallaConfig["sampler"]["seed"] = genieConfig["dialog"]["sampler"]["seed"]; + } + if (genieConfig["dialog"]["sampler"].contains("temp")) { + quallaConfig["sampler"]["temp"] = genieConfig["dialog"]["sampler"]["temp"]; + } + + quallaConfig["sampler"]["role"] = "primary"; +#if defined(GENIE_SPD_FEATURE) + if (genieConfig["dialog"]["type"] == "spd") { + quallaConfig["sampler"]["role"] = "target"; + } +#endif + + if (genieConfig["dialog"]["sampler"].contains("top-k")) { + quallaConfig["sampler"]["top-k"] = genieConfig["dialog"]["sampler"]["top-k"]; + } + if (genieConfig["dialog"]["sampler"].contains("top-p")) { + quallaConfig["sampler"]["top-p"] = genieConfig["dialog"]["sampler"]["top-p"]; + } + if (genieConfig["dialog"]["sampler"].contains("greedy")) { + quallaConfig["sampler"]["greedy"] = genieConfig["dialog"]["sampler"]["greedy"]; + } + if (genieConfig["dialog"]["sampler"].contains("seed")) { + quallaConfig["sampler"]["seed"] = genieConfig["dialog"]["sampler"]["seed"]; + } + } +} + +//============================================================================= +// Tokenizer::Config functions +//============================================================================= + +static void validateTokenizerConfig(const qualla::json& config) { + if (!config.is_object()) { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "tokenizer config is not an object"); + } + + std::set mandatoryFields{"version", "path"}; + for (const auto& field : 
mandatoryFields) { + if (!config.contains(field)) { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing tokenizer field: " + field); + } + } + + // component is used in the "ENFORCE" macros + std::string component = "tokenizer"; + + for (auto& item : config.items()) { + if (item.key() == "version") { + JSON_ENFORCE_NUMERIC(); + if (item.value().get() != 1) { + throw Exception(GENIE_STATUS_ERROR_JSON_VALUE, + "Invalid tokenizer config: unsupported version: " + item.value().dump()); + } + } else if (item.key() == "path") { + JSON_ENFORCE_STRING(); + // Note: the existence of this file is checked by qualla + } else { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, + "Unknown tokenizer config key: " + item.key()); + } + } +} + +static void translateTokenizerConfig(const qualla::json& genieConfig, qualla::json& quallaConfig) { + quallaConfig["tokenizer"] = genieConfig["dialog"]["tokenizer"]["path"]; +} + +//============================================================================= +// Embedding::Config functions +//============================================================================= + +static void validateEmbeddingConfig(const qualla::json& config) { + if (!config.is_object()) { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "embedding config is not an object"); + } + + std::set mandatoryFields{"version", "size"}; + for (const auto& field : mandatoryFields) { + if (!config.contains(field)) { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing embedding field: " + field); + } + } + + // component is used in the "ENFORCE" macros + std::string component = "embedding"; + + for (auto& item : config.items()) { + if (item.key() == "version") { + JSON_ENFORCE_NUMERIC(); + if (item.value().get() != 1) { + throw Exception(GENIE_STATUS_ERROR_JSON_VALUE, + "Invalid embedding config: unsupported version: " + item.value().dump()); + } + } else if (item.key() == "size") { + JSON_ENFORCE_NUMERIC(); + } else if (item.key() == "datatype") { + JSON_ENFORCE_STRING(); + const std::set supportedTypes = {"float32", "native"}; + if (std::find(supportedTypes.begin(), supportedTypes.end(), item.value()) == + supportedTypes.end()) { + throw Exception(GENIE_STATUS_ERROR_JSON_VALUE, + "Unknown embedding datatype: " + std::string(item.value())); + } + } else { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, + "Unknown embedding config key: " + item.key()); + } + } +} + +static void translateEmbeddingConfig(const qualla::json& genieConfig, qualla::json& quallaConfig) { + if (genieConfig["dialog"].contains("embedding")) { + quallaConfig["context"]["n-embd"] = genieConfig["dialog"]["embedding"]["size"]; + + if (genieConfig["dialog"]["embedding"].contains("datatype")) { + quallaConfig["context"]["embedding-datatype"] = + genieConfig["dialog"]["embedding"]["datatype"]; + } + } +} + +bool position_dim_set = false; +bool rope_theta_set = false; + +//============================================================================= +// Backend::Config functions +//============================================================================= + +static void validateBackendHtpConfig(const qualla::json& config) { + if (!config.is_object()) { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "QnnHtp config is not an object"); + } + + std::set mandatoryFields{ + "version", "spill-fill-bufsize", "mmap-budget", "use-mmap", "cpu-mask", "poll"}; + for (const auto& field : mandatoryFields) { + if (!config.contains(field)) { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing QnnHtp field: " + field); + } + } 
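+
+  // For reference, a minimal "QnnHtp" backend section that satisfies the
+  // mandatory-field check above could look like the following sketch
+  // (values are illustrative only, not tuned recommendations):
+  //   "QnnHtp": {
+  //     "version": 1,
+  //     "spill-fill-bufsize": 0,
+  //     "mmap-budget": 0,
+  //     "use-mmap": false,
+  //     "cpu-mask": "0",
+  //     "poll": true
+  //   }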
+ + // component is used in the "ENFORCE" macros + std::string component = "QnnHtp"; + + for (auto& item : config.items()) { + if (item.key() == "version") { + JSON_ENFORCE_NUMERIC(); + if (item.value().get() != 1) { + throw Exception(GENIE_STATUS_ERROR_JSON_VALUE, + "Invalid QnnHtp config: unsupported version: " + item.value().dump()); + } + } else if (item.key() == "spill-fill-bufsize") { + JSON_ENFORCE_NUMERIC(); + } else if (item.key() == "mmap-budget") { + JSON_ENFORCE_NUMERIC(); + } else if (item.key() == "use-mmap") { + JSON_ENFORCE_BOOLEAN(); +#ifdef _WIN32 + if (item.value() == true) { + throw Exception(GENIE_STATUS_ERROR_JSON_VALUE, + "Invalid QnnHtp config. use-mmap not supported on target"); + } +#endif + } else if (item.key() == "pos-id-dim") { + position_dim_set = true; + JSON_ENFORCE_NUMERIC(); + } else if (item.key() == "cpu-mask") { + JSON_ENFORCE_STRING(); + } else if (item.key() == "poll") { + JSON_ENFORCE_BOOLEAN(); + } else if (item.key() == "kv-dim") { + JSON_ENFORCE_NUMERIC(); + } else if (item.key() == "kv-update-method") { + JSON_ENFORCE_STRING(); + } else if (item.key() == "allow-async-init") { + JSON_ENFORCE_BOOLEAN(); + } else if (item.key() == "rope-theta") { + rope_theta_set = true; + JSON_ENFORCE_NUMERIC(); + } else { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown QnnHtp config key: " + item.key()); + } + } +} + +static void validateBackendGenaiConfig(const qualla::json& config) { + if (!config.is_object()) { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "QnnGenAiTransformer config is not an object"); + } + + std::set mandatoryFields{"version"}; + for (const auto& field : mandatoryFields) { + if (!config.contains(field)) { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, + "Missing QnnGenAiTransformer field: " + field); + } + } + + // component is used in the "ENFORCE" macros + std::string component = "QnnGenAiTransformer"; + + for (auto& item : config.items()) { + if (item.key() == "version") { + JSON_ENFORCE_NUMERIC(); + if (item.value().get() != 1) { + throw Exception( + GENIE_STATUS_ERROR_JSON_VALUE, + "Invalid QnnGenAiTransformer config: unsupported version: " + item.value().dump()); + } + } else if (item.key() == "use-mmap") { + JSON_ENFORCE_BOOLEAN(); +#ifdef _WIN32 + if (item.value() == true) { + throw Exception(GENIE_STATUS_ERROR_JSON_VALUE, + "Invalid QnnGenAiTransformer config. 
use-mmap not supported on target"); + } +#endif + } else if (item.key() == "n-logits") { + JSON_ENFORCE_NUMERIC(); + } else if (item.key() == "n-layer") { + JSON_ENFORCE_NUMERIC(); + } else if (item.key() == "n-embd") { + JSON_ENFORCE_NUMERIC(); + } else if (item.key() == "n-heads") { + JSON_ENFORCE_NUMERIC(); + } else { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, + "Unknown QnnGenAiTransformer config key: " + item.key()); + } + } +} + +static void validateBackendConfig(const qualla::json& config) { + if (!config.is_object()) { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "backend config is not an object"); + } + + std::set mandatoryFields{"version", "type"}; + for (const auto& field : mandatoryFields) { + if (!config.contains(field)) { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing backend field: " + field); + } + } + + // component is used in the "ENFORCE" macros + std::string component = "backend"; + + std::string type; + bool htp = false; + qualla::json htpConfig; + bool genai = false; + qualla::json genaiConfig; + + for (auto& item : config.items()) { + if (item.key() == "version") { + JSON_ENFORCE_NUMERIC(); + if (item.value().get() != 1) { + throw Exception(GENIE_STATUS_ERROR_JSON_VALUE, + "Invalid backend config: unsupported version: " + item.value().dump()); + } + } else if (item.key() == "type") { + JSON_ENFORCE_STRING(); + type = item.value().get(); + if (type == "QnnHtp") { + htp = true; + } else if (type == "QnnGenAiTransformer") { + genai = true; + } else { + throw Exception(GENIE_STATUS_ERROR_JSON_VALUE, + "Invalid backend config: unsupported type: " + item.value().dump()); + } + } else if (item.key() == "extensions") { + JSON_ENFORCE_STRING(); + } else if (item.key() == "QnnHtp") { + JSON_ENFORCE_OBJECT(); + htpConfig = item.value(); + } else if (item.key() == "QnnGenAiTransformer") { + JSON_ENFORCE_OBJECT(); + genaiConfig = item.value(); + } else { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown backend config key: " + item.key()); + } + } + + if (htp) { + if (!htpConfig.is_object()) { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing QnnHtp dialog config"); + } + validateBackendHtpConfig(htpConfig); + } else { + if (htpConfig.is_object()) { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, + "QnnHtp backend config for incorrect backend type: " + type); + } + } + + if (genai) { + if (!genaiConfig.is_object()) { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing QnnGenAiTransformer dialog config"); + } + validateBackendGenaiConfig(genaiConfig); + } else { + if (genaiConfig.is_object()) { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, + "QnnGenAiTransformer backend config for incorrect backend type: " + type); + } + } +} + +//============================================================================= +// Model::Config functions +//============================================================================= + +static void validateLoraAdapterConfig(const qualla::json& config, + LORA_VERSION& specifiedLoraVersion) { + if (!config.is_object()) { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "lora adapter config is not an object"); + } + const std::set mandatoryFields{"version", "name"}; + for (const auto& field : mandatoryFields) { + if (!config.contains(field)) { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing lora adapter field: " + field); + } + } + + // component is used in the "ENFORCE" macros + const std::string component = "lora adapter"; + LORA_VERSION configuredLoraVersion = 
LORA_VERSION::GENIE_LORA_VERSION_UNDEFINED;
+  for (auto& item : config.items()) {
+    if (item.key() == "version") {
+      JSON_ENFORCE_NUMERIC();
+      if (item.value().get<int>() != 1) {
+        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
+                        "Invalid lora config: unsupported version: " + item.value().dump());
+      }
+    } else if (item.key() == "name") {
+      JSON_ENFORCE_STRING();
+    } else if (item.key() == "bin-sections") {
+      JSON_ENFORCE_ARRAY();
+      configuredLoraVersion = LORA_VERSION::GENIE_LORA_VERSION_V2;  // Adapter occurs with V2
+      for (auto& elem : item.value()) {
+        if (!elem.is_string()) {
+          throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
+                          "bin-sections must be an array of strings");
+        }
+      }
+    } else if (item.key() == "path") {
+      configuredLoraVersion = LORA_VERSION::GENIE_LORA_VERSION_V1;  // Weights are V1
+      JSON_ENFORCE_STRING();
+      // Note: all directory validations will be done by the NSP engine
+    } else {
+      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
+                      "Unknown lora adapter config key: " + item.key());
+    }
+  }
+
+  if (specifiedLoraVersion == LORA_VERSION::GENIE_LORA_VERSION_V1 &&
+      configuredLoraVersion == LORA_VERSION::GENIE_LORA_VERSION_V2) {
+    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
+                    "LoRA Adapters must be used with lora version: 2");
+  } else if (specifiedLoraVersion == LORA_VERSION::GENIE_LORA_VERSION_V2 &&
+             configuredLoraVersion == LORA_VERSION::GENIE_LORA_VERSION_V1) {
+    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
+                    "LoRA Weights must be used with lora version: 1");
+  } else if (configuredLoraVersion == LORA_VERSION::GENIE_LORA_VERSION_UNDEFINED) {
+    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Invalid lora config.");
+  }
+}
+
+static void validateLoraConfig(const qualla::json& config) {
+  if (!config.is_object()) {
+    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "lora config is not an object");
+  }
+
+  const std::set<std::string> mandatoryFields{"version", "adapters"};
+  for (const auto& field : mandatoryFields) {
+    if (!config.contains(field)) {
+      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing lora field: " + field);
+    }
+  }
+
+  // component is used in the "ENFORCE" macros
+  const std::string component = "lora";
+  LORA_VERSION specifiedLoraVersion = LORA_VERSION::GENIE_LORA_VERSION_V2;  // Default is loraV2
+  if (config.find("lora-version") != config.end()) {
+    switch (static_cast<int>(config["lora-version"])) {
+      case 1:
+        specifiedLoraVersion = LORA_VERSION::GENIE_LORA_VERSION_V1;
+        break;
+      case 2:
+        specifiedLoraVersion = LORA_VERSION::GENIE_LORA_VERSION_V2;
+        break;
+      default:
+        specifiedLoraVersion = LORA_VERSION::GENIE_LORA_VERSION_UNDEFINED;
+        break;
+    }
+  }
+
+  for (auto& item : config.items()) {
+    if (item.key() == "version") {
+      JSON_ENFORCE_NUMERIC();
+      if (item.value().get<int>() != 1) {
+        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
+                        "Invalid lora config: unsupported version: " + item.value().dump());
+      }
+    } else if (item.key() == "alpha-tensor-name") {
+      JSON_ENFORCE_STRING();
+    } else if (item.key() == "adapters") {
+      JSON_ENFORCE_ARRAY();
+      for (auto& elem : item.value()) {
+        validateLoraAdapterConfig(elem, specifiedLoraVersion);
+      }
+    } else if (item.key() == "lora-version") {  // Optional
+      JSON_ENFORCE_NUMERIC();
+    } else {
+      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown lora config key: " + item.key());
+    }
+  }
+  if (specifiedLoraVersion == LORA_VERSION::GENIE_LORA_VERSION_UNDEFINED) {
+    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
+                    "Unsupported lora version: " + to_string(config["lora-version"]));
+  }
+}
+
+static void validateModelBinaryConfig(const
qualla::json& config) {
+  if (!config.is_object()) {
+    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "binary config is not an object");
+  }
+
+  std::set<std::string> mandatoryFields{"version", "ctx-bins"};
+  for (const auto& field : mandatoryFields) {
+    if (!config.contains(field)) {
+      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing binary field: " + field);
+    }
+  }
+
+  // component is used in the "ENFORCE" macros
+  std::string component = "binary";
+
+  for (auto& item : config.items()) {
+    if (item.key() == "version") {
+      JSON_ENFORCE_NUMERIC();
+      if (item.value().get<int>() != 1) {
+        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
+                        "Invalid binary config: unsupported version: " + item.value().dump());
+      }
+    } else if (item.key() == "ctx-bins") {
+      JSON_ENFORCE_ARRAY();
+      for (auto& elem : item.value()) {
+        if (!elem.is_string()) {
+          throw Exception(GENIE_STATUS_ERROR_JSON_VALUE, "ctx-bins must be an array of strings");
+        }
+      }
+    } else if (item.key() == "lora") {
+      JSON_ENFORCE_OBJECT();
+      validateLoraConfig(item.value());
+    } else {
+      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown binary config key: " + item.key());
+    }
+  }
+}
+
+static void validateModelLibraryConfig(const qualla::json& config) {
+  if (!config.is_object()) {
+    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "library config is not an object");
+  }
+
+  std::set<std::string> mandatoryFields{"version", "model-bin"};
+  for (const auto& field : mandatoryFields) {
+    if (!config.contains(field)) {
+      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing library field: " + field);
+    }
+  }
+
+  // component is used in the "ENFORCE" macros
+  std::string component = "library";
+
+  for (auto& item : config.items()) {
+    if (item.key() == "version") {
+      JSON_ENFORCE_NUMERIC();
+      if (item.value().get<int>() != 1) {
+        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
+                        "Invalid library config: unsupported version: " + item.value().dump());
+      }
+    } else if (item.key() == "model-bin") {
+      JSON_ENFORCE_STRING();
+    } else {
+      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown library config key: " + item.key());
+    }
+  }
+}
+
+static void validateRopeScalingConfig(const qualla::json& config) {
+  // component is used in the "ENFORCE" macros
+  std::string component = "rope-scaling";
+  if (config.is_object()) {
+    std::string ropeType;
+    for (auto& item : config.items()) {
+      if (item.key() == "rope-type") {
+        JSON_ENFORCE_STRING();
+        ropeType = item.value().get<std::string>();
+        if (ropeType != "llama3" && ropeType != "default" && ropeType != "longrope") {
+          throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Rope type not supported: " + ropeType);
+        }
+      } else if (item.key() == "factor" || item.key() == "low-freq-factor" ||
+                 item.key() == "high-freq-factor" ||
+                 item.key() == "original-max-position-embeddings") {
+        JSON_ENFORCE_NUMERIC();
+      } else if (item.key() == "short-factor" || item.key() == "long-factor") {
+        JSON_ENFORCE_ARRAY();
+      } else {
+        throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
+                        "Rope scaling parameter not supported " + item.key());
+      }
+    }
+  }
+}
+
+static void validatePositionalEncodingConfig(const qualla::json& config) {
+  // component is used in the "ENFORCE" macros
+  std::string component = "positional-encoding";
+  qualla::json ropeScalingConfig;
+  if (config.is_object()) {
+    for (auto& item : config.items()) {
+      if (item.key() == "type") {
+        std::string positionEncodingType = item.value().get<std::string>();
+        if (positionEncodingType != "rope" && positionEncodingType != "absolute" &&
+            positionEncodingType != "alibi") {
+          throw
Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "positional-encoding type not supported"); + } + } else if (item.key() == "rope-dim") { + JSON_ENFORCE_NUMERIC(); + } else if (item.key() == "rope-theta") { + JSON_ENFORCE_NUMERIC(); + } else if (item.key() == "rope-scaling") { + JSON_ENFORCE_OBJECT(); + ropeScalingConfig = item.value(); + } else { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, + "Unknown positional encoding config key: " + item.key()); + } + } + } + if (position_dim_set) { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, + "Specify one config from pos-id-dim and positional-encoding"); + } + if (rope_theta_set) { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, + "Specify one config from rope-theta and positional-encoding"); + } + if (ropeScalingConfig.is_object()) { + validateRopeScalingConfig(ropeScalingConfig); + } +} + +static void validateModelConfig(const qualla::json& config) { + if (!config.is_object()) { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "model config is not an object"); + } + + std::set mandatoryFields{"version", "type"}; + for (const auto& field : mandatoryFields) { + if (!config.contains(field)) { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing model field: " + field); + } + } + + // component is used in the "ENFORCE" macros + std::string component = "model"; + + std::string type; + bool binary = false; + qualla::json binaryConfig; + bool library = false; + qualla::json libraryConfig; + qualla::json positionalEncodingConfig; + bool positionalEncoding = false; + + for (auto& item : config.items()) { + if (item.key() == "version") { + JSON_ENFORCE_NUMERIC(); + if (item.value().get() != 1) { + throw Exception(GENIE_STATUS_ERROR_JSON_VALUE, + "Invalid model config: unsupported version: " + item.value().dump()); + } + } else if (item.key() == "type") { + JSON_ENFORCE_STRING(); + type = item.value().get(); + if (type == "binary") { + binary = true; + } else if (type == "library") { + library = true; + } else { + throw Exception(GENIE_STATUS_ERROR_JSON_VALUE, + "Invalid model config: unsupported type: " + item.value().dump()); + } + } else if (item.key() == "binary") { + JSON_ENFORCE_OBJECT(); + binaryConfig = item.value(); + } else if (item.key() == "library") { + JSON_ENFORCE_OBJECT(); + libraryConfig = item.value(); + } else if (item.key() == "positional-encoding") { + JSON_ENFORCE_OBJECT(); + positionalEncodingConfig = item.value(); + positionalEncoding = true; + } else { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown model config key: " + item.key()); + } + } + + if (binary) { + if (!binaryConfig.is_object()) { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing binary model config"); + } + validateModelBinaryConfig(binaryConfig); + } else { + if (binaryConfig.is_object()) { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, + "binary model config for incorrect model type: " + type); + } + } + + if (library) { + if (!libraryConfig.is_object()) { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing library model config"); + } + validateModelLibraryConfig(libraryConfig); + } else { + if (libraryConfig.is_object()) { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, + "library model config for incorrect model type: " + type); + } + } + + if (positionalEncoding) { + if (!positionalEncodingConfig.is_object()) { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing Positional encoding config"); + } + validatePositionalEncodingConfig(positionalEncodingConfig); + } else { + if (positionalEncodingConfig.is_object()) { + 
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
+                      "Positional encoding config for incorrect model type: " + type);
+    }
+  }
+}
+
+//=============================================================================
+// Engine::Config functions
+//=============================================================================
+
+static void validateEngineConfig(const qualla::json& config, std::string dialogType) {
+  if (!config.is_object()) {
+    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "engine config is not an object");
+  }
+
+  std::set<std::string> mandatoryFields{"version", "backend", "model", "n-threads"};
+#if defined(GENIE_SPD_FEATURE)
+  if (dialogType == "spd") {
+    mandatoryFields.insert("role");
+  }
+#endif
+  if (dialogType == "kv-share") {
+    mandatoryFields.insert("role");
+  }
+
+  for (const auto& field : mandatoryFields) {
+    if (!config.contains(field)) {
+      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing engine field: " + field);
+    }
+  }
+
+  // component is used in the "ENFORCE" macros
+  std::string component = "engine";
+
+  for (auto& item : config.items()) {
+    if (item.key() == "version") {
+      JSON_ENFORCE_NUMERIC();
+      if (item.value().get<int>() != 1) {
+        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
+                        "Invalid engine config: unsupported version: " + item.value().dump());
+      }
+    } else if (item.key() == "backend") {
+      JSON_ENFORCE_OBJECT();
+      validateBackendConfig(item.value());
+    } else if (item.key() == "model") {
+      JSON_ENFORCE_OBJECT();
+      validateModelConfig(item.value());
+    } else if (item.key() == "n-threads") {
+      JSON_ENFORCE_NUMERIC();
+#if defined(GENIE_SPD_FEATURE)
+    } else if (item.key() == "role" && dialogType == "spd") {
+      JSON_ENFORCE_STRING();
+      if (item.value() != "draft" && item.value() != "target") {
+        throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
+                        "Unknown value for engine config key: " + item.key());
+      }
+#endif
+    } else if (item.key() == "role" && dialogType == "kv-share") {
+      JSON_ENFORCE_STRING();
+      if (item.value() != "primary" && item.value() != "secondary") {
+        throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
+                        "Unknown value for engine config key: " + item.key());
+      }
+    } else {
+      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown engine config key: " + item.key());
+    }
+  }
+}
+
+static void validateMultiEngineConfig(const qualla::json& configs, std::string dialogType) {
+  if (configs.is_object()) {
+    validateEngineConfig(configs, dialogType);
+#if defined(GENIE_SPD_FEATURE)
+    if (dialogType == "spd") {
+      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "engine config for spd is not an array");
+    }
+#endif
+    if (dialogType == "kv-share") {
+      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "engine config for kv-share is not an array");
+    }
+#if defined(GENIE_SPD_FEATURE)
+  } else if (configs.is_array() && dialogType == "spd") {
+    if (configs.size() != 2) {
+      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
+                      "engine config for spd contains an invalid number of engines");
+    }
+    bool engineRoleMask[2] = {false, false};
+    for (auto& item : configs) {
+      validateEngineConfig(item, dialogType);
+      if (item["role"] == "draft") {
+        engineRoleMask[0] = true;
+      } else if (item["role"] == "target") {
+        engineRoleMask[1] = true;
+      }
+    }
+    if (!engineRoleMask[0]) {
+      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
+                      "engine config for spd does not contain draft engine");
+    }
+    if (!engineRoleMask[1]) {
+      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
+                      "engine config for spd does not contain target engine");
+    }
+#endif
+  } else if (configs.is_array() && dialogType == "kv-share") {
+    if (configs.size() != 2) {
+      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
+                      "engine config for kv-share contains an invalid number of engines");
+    }
+    bool engineRoleMask[2] = {false, false};
+    for (auto& item : configs) {
+      validateEngineConfig(item, dialogType);
+      if (item["role"] == "primary") {
+        engineRoleMask[0] = true;
+      } else if (item["role"] == "secondary") {
+        engineRoleMask[1] = true;
+      }
+    }
+    if (!engineRoleMask[0]) {
+      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
+                      "engine config for kv-share does not contain primary");
+    }
+    if (!engineRoleMask[1]) {
+      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
+                      "engine config for kv-share does not contain secondary");
+    }
+  } else {
+    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "engine config is not an object or an array");
+  }
+}
+
+static void translateEngineConfig(const qualla::json& genieEngineConfig,
+                                  qualla::json& quallaEngineConfig) {
+  if (genieEngineConfig["version"] == 1) {
+    if (genieEngineConfig.contains("role")) {
+      quallaEngineConfig["role"] = genieEngineConfig["role"];
+    } else {
+      quallaEngineConfig["role"] = "primary";
+    }
+
+    quallaEngineConfig["n-threads"] = genieEngineConfig["n-threads"];
+
+    if (genieEngineConfig["backend"]["type"] == "QnnHtp") {
+      quallaEngineConfig["type"] = "qnn-htp";
+      quallaEngineConfig["backend-lib"] = getLibName("QnnHtp");
+      quallaEngineConfig["mmap-budget"] = genieEngineConfig["backend"]["QnnHtp"]["mmap-budget"];
+      quallaEngineConfig["use-mmap"] = genieEngineConfig["backend"]["QnnHtp"]["use-mmap"];
+      quallaEngineConfig["spill-fill-bufsize"] =
+          genieEngineConfig["backend"]["QnnHtp"]["spill-fill-bufsize"];
+      if (genieEngineConfig["backend"]["QnnHtp"].contains("pos-id-dim")) {
+        quallaEngineConfig["pos-id-dim"] = genieEngineConfig["backend"]["QnnHtp"]["pos-id-dim"];
+      }
+      quallaEngineConfig["cpumask"] = genieEngineConfig["backend"]["QnnHtp"]["cpu-mask"];
+      quallaEngineConfig["poll"] = genieEngineConfig["backend"]["QnnHtp"]["poll"];
+      quallaEngineConfig["kv-dim"] = genieEngineConfig["backend"]["QnnHtp"]["kv-dim"];
+      if (genieEngineConfig["backend"]["QnnHtp"].contains("rope-theta")) {
+        quallaEngineConfig["rope-theta"] = genieEngineConfig["backend"]["QnnHtp"]["rope-theta"];
+      }
+      if (genieEngineConfig["backend"]["QnnHtp"].contains("kv-update-method")) {
+        quallaEngineConfig["kv-update-method"] =
+            genieEngineConfig["backend"]["QnnHtp"]["kv-update-method"];
+      }
+      // By default, qualla takes the async init path. For now, we force
+      // async init off unless it is explicitly enabled in the Genie config.
+      // This is an HTP-specific feature.
+ quallaEngineConfig["use-async-Init"] = false; + if (genieEngineConfig["backend"]["QnnHtp"].contains("allow-async-init")) { + quallaEngineConfig["use-async-Init"] = + genieEngineConfig["backend"]["QnnHtp"]["allow-async-init"]; + } + } else if (genieEngineConfig["backend"]["type"] == "QnnGenAiTransformer") { + quallaEngineConfig["type"] = "qnn-cpu"; + quallaEngineConfig["backend-lib"] = getLibName("QnnGenAiTransformer"); + if (genieEngineConfig["backend"]["QnnGenAiTransformer"].contains("n-logits")) { + quallaEngineConfig["n_logits"] = + genieEngineConfig["backend"]["QnnGenAiTransformer"]["n-logits"]; + } + if (genieEngineConfig["backend"]["QnnGenAiTransformer"].contains("use-mmap")) { + quallaEngineConfig["use-mmap"] = + genieEngineConfig["backend"]["QnnGenAiTransformer"]["use-mmap"]; + } + if (genieEngineConfig["backend"]["QnnGenAiTransformer"].contains("n-layer")) { + quallaEngineConfig["n_layer"] = + genieEngineConfig["backend"]["QnnGenAiTransformer"]["n-layer"]; + } + if (genieEngineConfig["backend"]["QnnGenAiTransformer"].contains("n-embd")) { + quallaEngineConfig["n_embd"] = + genieEngineConfig["backend"]["QnnGenAiTransformer"]["n-embd"]; + } + if (genieEngineConfig["backend"]["QnnGenAiTransformer"].contains("n-heads")) { + quallaEngineConfig["n_heads"] = + genieEngineConfig["backend"]["QnnGenAiTransformer"]["n-heads"]; + } + } + + if (genieEngineConfig["backend"].contains("extensions")) { + quallaEngineConfig["backend-ext-conf"] = genieEngineConfig["backend"]["extensions"]; + } + + if (genieEngineConfig["model"]["type"] == "binary") { + quallaEngineConfig["model-list"] = genieEngineConfig["model"]["binary"]["ctx-bins"]; + if (genieEngineConfig["model"]["binary"].contains("lora")) { + quallaEngineConfig["lora-version"] = + static_cast(LORA_VERSION::GENIE_LORA_VERSION_V2); + if (genieEngineConfig["model"]["binary"]["lora"].contains("lora-version") && + genieEngineConfig["model"]["binary"]["lora"]["lora-version"] == 1) { + quallaEngineConfig["lora-version"] = + genieEngineConfig["model"]["binary"]["lora"]["lora-version"]; + } + for (int i = 0; i < genieEngineConfig["model"]["binary"]["lora"]["adapters"].size(); i++) { + quallaEngineConfig["lora"][i]["adapter-name"] = + genieEngineConfig["model"]["binary"]["lora"]["adapters"][i]["name"]; + quallaEngineConfig["lora"][i]["alpha-tensor-name"] = ""; + if (genieEngineConfig["model"]["binary"]["lora"].contains("alpha-tensor-name")) { + quallaEngineConfig["lora"][i]["alpha-tensor-name"] = + genieEngineConfig["model"]["binary"]["lora"]["alpha-tensor-name"]; + } + quallaEngineConfig["lora"][i]["alpha-tensor-value"] = 1.0f; + quallaEngineConfig["lora"][i]["binsection-basedir"] = ""; + if (genieEngineConfig["model"]["binary"]["lora"].contains("lora-version") && + genieEngineConfig["model"]["binary"]["lora"]["lora-version"] == 1) { + quallaEngineConfig["lora"][i]["path"] = + genieEngineConfig["model"]["binary"]["lora"]["adapters"][i]["path"]; + } else { + quallaEngineConfig["lora"][i]["bin-sections"] = + genieEngineConfig["model"]["binary"]["lora"]["adapters"][i]["bin-sections"]; + } + } + } + } else if (genieEngineConfig["model"]["type"] == "library") { + quallaEngineConfig["model"] = getLibName("QnnGenAiTransformerModel"); + quallaEngineConfig["model-bin-path"] = genieEngineConfig["model"]["library"]["model-bin"]; + quallaEngineConfig["op-package"] = + getLibName("QnnGenAiTransformerCpuOpPkg") + ":QnnOpPackage_interfaceProvider"; + } + if (genieEngineConfig["model"].contains("positional-encoding")) { + 
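+      // Illustrative shape of the Genie-side input handled by this branch,
+      // using only keys accepted by validatePositionalEncodingConfig()
+      // (values are examples, not defaults):
+      //   "positional-encoding": {
+      //     "type": "rope",
+      //     "rope-dim": 64,
+      //     "rope-theta": 10000.0,
+      //     "rope-scaling": { "rope-type": "llama3", "factor": 8.0 }
+      //   }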
quallaEngineConfig["positional-encoding"]["type"] = + genieEngineConfig["model"]["positional-encoding"]["type"]; + if (genieEngineConfig["model"]["positional-encoding"]["type"] == "rope") { + quallaEngineConfig["positional-encoding"]["rope-dim"] = + genieEngineConfig["model"]["positional-encoding"]["rope-dim"]; + if (genieEngineConfig["model"]["positional-encoding"].contains("rope-theta")) { + quallaEngineConfig["positional-encoding"]["rope-theta"] = + genieEngineConfig["model"]["positional-encoding"]["rope-theta"]; + } + if (genieEngineConfig["model"]["positional-encoding"].contains("rope-scaling")) { + if (genieEngineConfig["model"]["positional-encoding"]["rope-scaling"].contains( + "rope-type")) { + quallaEngineConfig["positional-encoding"]["rope-scaling"]["rope-type"] = + genieEngineConfig["model"]["positional-encoding"]["rope-scaling"]["rope-type"]; + if (genieEngineConfig["model"]["positional-encoding"]["rope-scaling"]["rope-type"] == + "llama3") { + if (genieEngineConfig["model"]["positional-encoding"]["rope-scaling"].contains( + "factor")) { + quallaEngineConfig["positional-encoding"]["rope-scaling"]["factor"] = + genieEngineConfig["model"]["positional-encoding"]["rope-scaling"]["factor"]; + } + if (genieEngineConfig["model"]["positional-encoding"]["rope-scaling"].contains( + "low-freq-factor")) { + quallaEngineConfig["positional-encoding"]["rope-scaling"]["low-freq-factor"] = + genieEngineConfig["model"]["positional-encoding"]["rope-scaling"] + ["low-freq-factor"]; + } + if (genieEngineConfig["model"]["positional-encoding"]["rope-scaling"].contains( + "high-freq-factor")) { + quallaEngineConfig["positional-encoding"]["rope-scaling"]["high-freq-factor"] = + genieEngineConfig["model"]["positional-encoding"]["rope-scaling"] + ["high-freq-factor"]; + } + if (genieEngineConfig["model"]["positional-encoding"]["rope-scaling"].contains( + "original-max-position-embeddings")) { + quallaEngineConfig["positional-encoding"]["rope-scaling"] + ["original-max-position-embeddings"] = + genieEngineConfig["model"]["positional-encoding"] + ["rope-scaling"] + ["original-max-position-embeddings"]; + } + } + if (genieEngineConfig["model"]["positional-encoding"]["rope-scaling"]["rope-type"] == + "longrope") { + if (genieEngineConfig["model"]["positional-encoding"]["rope-scaling"].contains( + "factor")) { + quallaEngineConfig["positional-encoding"]["rope-scaling"]["factor"] = + genieEngineConfig["model"]["positional-encoding"]["rope-scaling"]["factor"]; + } + if (genieEngineConfig["model"]["positional-encoding"]["rope-scaling"].contains( + "short-factor")) { + quallaEngineConfig["positional-encoding"]["rope-scaling"]["short-factor"] = + genieEngineConfig["model"]["positional-encoding"]["rope-scaling"] + ["short-factor"]; + } + if (genieEngineConfig["model"]["positional-encoding"]["rope-scaling"].contains( + "long-factor")) { + quallaEngineConfig["positional-encoding"]["rope-scaling"]["long-factor"] = + genieEngineConfig["model"]["positional-encoding"]["rope-scaling"] + ["long-factor"]; + } + if (genieEngineConfig["model"]["positional-encoding"]["rope-scaling"].contains( + "original-max-position-embeddings")) { + quallaEngineConfig["positional-encoding"]["rope-scaling"] + ["original-max-position-embeddings"] = + genieEngineConfig["model"]["positional-encoding"] + ["rope-scaling"] + ["original-max-position-embeddings"]; + } + } + } + } + } + } + } +} + +static void translateMultiEngineConfig(const qualla::json& genieConfig, + qualla::json& quallaConfig) { + if (genieConfig["dialog"]["engine"].is_array()) { 
+ quallaConfig["engine"] = qualla::json::array(); + for (auto& item : genieConfig["dialog"]["engine"]) { + qualla::json quallaEngineConfig; + translateEngineConfig(item, quallaEngineConfig); + quallaConfig["engine"].push_back(quallaEngineConfig); + } + } else { + translateEngineConfig(genieConfig["dialog"]["engine"], quallaConfig["engine"]); + } +} + +//============================================================================= +// Dialog::Config functions +//============================================================================= + +qnn::util::HandleManager Dialog::Config::s_manager; + +GenieDialogConfig_Handle_t Dialog::Config::add(std::shared_ptr config) { + return (GenieDialogConfig_Handle_t)s_manager.add(config); +} + +std::shared_ptr Dialog::Config::get(GenieDialogConfig_Handle_t handle) { + return s_manager.get((qnn::util::Handle_t)handle); +} + +void Dialog::Config::remove(GenieDialogConfig_Handle_t handle) { + s_manager.remove((qnn::util::Handle_t)handle); +} + +#if defined(GENIE_SSD_FEATURE) +static void validateDialogSsdConfig(const qualla::json& config) { + if (!config.is_object()) { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "ssd-q1 config is not an object"); + } + + std::set mandatoryFields{"version", + "ssd-version", + "forecast-token-count", + "branches", + "forecast-prefix", + "forecast-prefix-name"}; + for (const auto& field : mandatoryFields) { + if (!config.contains(field)) { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing ssd-q1 field: " + field); + } + } + + // component is used in the "ENFORCE" macros + std::string component = "ssd-q1"; + + int branchesSize = 0; + int forecastTokenCount = 0; + + int nStreams = 1; + float pThreshold = 0.0; + + for (auto& item : config.items()) { + if (item.key() == "version") { + JSON_ENFORCE_NUMERIC(); + if (item.value().get() != 1) { + throw Exception(GENIE_STATUS_ERROR_JSON_VALUE, + "Invalid ssd-q1 config: unsupported version: " + item.value().dump()); + } + } else if (item.key() == "ssd-version") { + JSON_ENFORCE_NUMERIC(); + } else if (item.key() == "forecast-token-count") { + JSON_ENFORCE_NUMERIC(); + forecastTokenCount = item.value(); + } else if (item.key() == "branches") { + JSON_ENFORCE_ARRAY(); + for (auto& elem : item.value()) { + if (!elem.is_number_integer()) { + throw Exception(GENIE_STATUS_ERROR_JSON_VALUE, "branches must be an array of integers"); + } + } + branchesSize = item.value().size(); + } else if (item.key() == "forecast-prefix") { + JSON_ENFORCE_NUMERIC(); + } else if (item.key() == "forecast-prefix-name") { + JSON_ENFORCE_STRING(); + } else if (item.key() == "n-streams") { + JSON_ENFORCE_NUMERIC(); + nStreams = item.value(); + } else if (item.key() == "p-threshold") { + JSON_ENFORCE_NUMERIC(); + pThreshold = item.value(); + } else { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown ssd-q1 config key: " + item.key()); + } + } + + if ((pThreshold > 0.0) && (nStreams <= 1)) { + throw Exception(GENIE_STATUS_ERROR_JSON_VALUE, + "p-threshold can only be used with multistream (n-streams > 1)"); + } + + if (branchesSize > forecastTokenCount) { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, + "Size of branches array must be less than forecast-token-count"); + } +} +#endif + +#if defined(GENIE_LADE_FEATURE) +static void validateDialogLadeConfig(const qualla::json& config) { + if (!config.is_object()) { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "lade config is not an object"); + } + + std::set mandatoryFields{"version", "update-mode", "window", "ngram", "gcap"}; + for (const 
auto& field : mandatoryFields) { + if (!config.contains(field)) { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing lade field: " + field); + } + } + + // component is used in the "ENFORCE" macros + std::string component = "lade"; + + for (auto& item : config.items()) { + if (item.key() == "version") { + JSON_ENFORCE_NUMERIC(); + if (item.value().get() != 1) { + throw Exception(GENIE_STATUS_ERROR_JSON_VALUE, + "Invalid lade config: unsupported version: " + item.value().dump()); + } + } else if (item.key() == "update-mode") { + JSON_ENFORCE_STRING(); + std::string mode = item.value().get(); + if ((mode != "FWD_MAX_HIT") && (mode != "FWD_LEVEL") && (mode != "ALWAYS_FWD_ONE")) { + throw Exception(GENIE_STATUS_ERROR_JSON_VALUE, + "Invalid lade config: unsupported update-mode: " + item.value().dump()); + } + } else if (item.key() == "window") { + JSON_ENFORCE_NUMERIC(); + } else if (item.key() == "ngram") { + JSON_ENFORCE_NUMERIC(); + } else if (item.key() == "gcap") { + JSON_ENFORCE_NUMERIC(); + } else { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown lade config key: " + item.key()); + } + } +} +#endif + +#if defined(GENIE_SPD_FEATURE) +static void validateDialogSpdConfig(const qualla::json& config) { + if (!config.is_object()) { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "spd config is not an object"); + } + + std::set mandatoryFields{"version", "draft-len"}; + for (const auto& field : mandatoryFields) { + if (!config.contains(field)) { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing spd field: " + field); + } + } + + // component is used in the "ENFORCE" macros + std::string component = "spd"; + for (auto& item : config.items()) { + if (item.key() == "version") { + JSON_ENFORCE_NUMERIC(); + if (item.value().get() != 1) { + throw Exception(GENIE_STATUS_ERROR_JSON_VALUE, + "Invalid spd config: unsupported version: " + item.value().dump()); + } + } else if (item.key() == "draft-len") { + JSON_ENFORCE_NUMERIC(); + } else { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown spd config key: " + item.key()); + } + } +} +#endif + +#if defined(GENIE_MULTISTREAM_FEATURE) +static void validateDialogMultistreamConfig(const qualla::json& config) { + if (!config.is_object()) { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "multistream config is not an object"); + } + + std::set mandatoryFields{"version", "n-streams"}; + for (const auto& field : mandatoryFields) { + if (!config.contains(field)) { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing multistream field: " + field); + } + } + + // component is used in the "ENFORCE" macros + std::string component = "multistream"; + + for (auto& item : config.items()) { + if (item.key() == "version") { + JSON_ENFORCE_NUMERIC(); + if (item.value().get() != 1) { + throw Exception(GENIE_STATUS_ERROR_JSON_VALUE, + "Invalid multistream config: unsupported version: " + item.value().dump()); + } + } else if (item.key() == "n-streams") { + JSON_ENFORCE_NUMERIC(); + } else if (item.key() == "p-threshold") { + JSON_ENFORCE_NUMERIC(); + } else { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, + "Unknown multistream config key: " + item.key()); + } + } +} +#endif + +static void validateDialogConfig(const qualla::json& config) { + if (!config.is_object()) { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Dialog config is not an object"); + } + + std::set mandatoryFields{"version", "type", "context", "tokenizer", "engine"}; + for (const auto& field : mandatoryFields) { + if (!config.contains(field)) { + throw 
Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing dialog field: " + field); + } + } + + // component is used in the "ENFORCE" macros + std::string component = "dialog"; + + std::string dialogType = "basic"; +#if defined(GENIE_SSD_FEATURE) + bool ssdq1 = false; + qualla::json ssdq1Config; +#endif +#if defined(GENIE_LADE_FEATURE) + bool lade = false; + qualla::json ladeConfig; +#endif +#if defined(GENIE_SPD_FEATURE) + bool spd = false; + qualla::json spdConfig; +#endif +#if defined(GENIE_MULTISTREAM_FEATURE) + bool multistream = false; + qualla::json multistreamConfig; +#endif + + for (auto& item : config.items()) { + if (item.key() == "version") { + JSON_ENFORCE_NUMERIC(); + if (item.value().get() != 1) { + throw Exception(GENIE_STATUS_ERROR_JSON_VALUE, + "Invalid dialog config: unsupported version: " + item.value().dump()); + } + } else if (item.key() == "type") { + JSON_ENFORCE_STRING(); + dialogType = item.value(); + if (dialogType == "basic" || dialogType == "kv-share") { + // Do nothing +#if defined(GENIE_SSD_FEATURE) + } else if (dialogType == "ssd-q1") { + ssdq1 = true; +#endif +#if defined(GENIE_LADE_FEATURE) + } else if (dialogType == "lade") { + lade = true; +#endif +#if defined(GENIE_SPD_FEATURE) + } else if (dialogType == "spd") { + spd = true; +#endif +#if defined(GENIE_MULTISTREAM_FEATURE) + } else if (dialogType == "multistream") { + multistream = true; +#endif + } else { + throw Exception(GENIE_STATUS_ERROR_JSON_VALUE, "Invalid dialog type: " + dialogType); + } +#if defined(GENIE_SSD_FEATURE) + } else if (item.key() == "ssd-q1") { + JSON_ENFORCE_OBJECT(); + ssdq1Config = item.value(); + // ssd-q1 validation is done below +#endif +#if defined(GENIE_LADE_FEATURE) + } else if (item.key() == "lade") { + JSON_ENFORCE_OBJECT(); + ladeConfig = item.value(); + // ssd-q1 validation is done below +#endif +#if defined(GENIE_SPD_FEATURE) + } else if (item.key() == "spd") { + JSON_ENFORCE_OBJECT(); + spdConfig = item.value(); + // spd validation is done below +#endif +#if defined(GENIE_MULTISTREAM_FEATURE) + } else if (item.key() == "multistream") { + JSON_ENFORCE_OBJECT(); + multistreamConfig = item.value(); + // multistream validation is done below +#endif + } else if (item.key() == "stop-sequence") { + JSON_ENFORCE_ARRAY(); + for (auto& elem : item.value()) { + if (!elem.is_string()) { + throw Exception(GENIE_STATUS_ERROR_JSON_VALUE, + "stop-sequence must be an array of strings"); + } + } + } else if (item.key() == "max-num-tokens") { + JSON_ENFORCE_NUMERIC(); + if (item.value().get() < 0) { + throw Exception(GENIE_STATUS_ERROR_JSON_VALUE, + "number of tokens must be > 0. provided: " + item.value().dump()); + } + } else if (item.key() == "context") { + JSON_ENFORCE_OBJECT(); + validateContextConfig(item.value()); + } else if (item.key() == "tokenizer") { + JSON_ENFORCE_OBJECT(); + validateTokenizerConfig(item.value()); + } else if (item.key() == "sampler") { + JSON_ENFORCE_OBJECT(); + validateSamplerConfig(item.value()); + } else if (item.key() == "engine") { + JSON_ENFORCE_ARRAY_OR_OBJECT(); + } else if (item.key() == "embedding") { + JSON_ENFORCE_OBJECT(); + validateEmbeddingConfig(item.value()); + } else { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown dialog config key: " + item.key()); + } + } + + // Engine Verification requires dialogType for engine roles. Since "type" is encounterd + // later than "engine" in loop. Therefore, moving engine validation out of the loop. 
+ validateMultiEngineConfig(config["engine"], dialogType); + +#if defined(GENIE_SSD_FEATURE) + if (ssdq1) { + if (!ssdq1Config.is_object()) { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing ssd-q1 dialog config"); + } + validateDialogSsdConfig(ssdq1Config); + } else { + if (ssdq1Config.is_object()) { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, + "ssd-q1 dialog config for incorrect dialog type: " + dialogType); + } + } +#endif +#if defined(GENIE_LADE_FEATURE) + if (lade) { + if (!ladeConfig.is_object()) { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing lade dialog config"); + } + validateDialogLadeConfig(ladeConfig); + } else { + if (ladeConfig.is_object()) { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, + "lade dialog config for incorrect dialog type: " + dialogType); + } + } +#endif +#if defined(GENIE_SPD_FEATURE) + if (spd) { + if (!spdConfig.is_object()) { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing spd dialog config"); + } + validateDialogSpdConfig(spdConfig); + } else { + if (spdConfig.is_object()) { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, + "spd dialog config for incorrect dialog type: " + dialogType); + } + } +#endif +#if defined(GENIE_MULTISTREAM_FEATURE) + if (multistream) { + if (!multistreamConfig.is_object()) { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing multistream dialog config"); + } + validateDialogMultistreamConfig(multistreamConfig); + } else { + if (multistreamConfig.is_object()) { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, + "multistream dialog config for incorrect dialog type: " + dialogType); + } + } +#endif +} + +static void translateDialogConfig(const qualla::json& genieConfig, qualla::json& quallaConfig) { + if (genieConfig["dialog"]["version"] == 1) { + if (genieConfig["dialog"]["type"] == "lade") { + quallaConfig["type"] = "lhd-dec"; + } else if (genieConfig["dialog"]["type"] == "spd") { + quallaConfig["type"] = "spec-dec"; + } else if (genieConfig["dialog"]["type"] == "multistream") { + quallaConfig["type"] = "multistream"; + } else { + quallaConfig["type"] = genieConfig["dialog"]["type"]; + } +#if defined(GENIE_SSD_FEATURE) + if (genieConfig["dialog"]["type"] == "ssd-q1") { + quallaConfig["ssd-version"] = genieConfig["dialog"]["ssd-q1"]["ssd-version"]; + quallaConfig["forecast-token-count"] = + genieConfig["dialog"]["ssd-q1"]["forecast-token-count"]; + quallaConfig["branches"] = genieConfig["dialog"]["ssd-q1"]["branches"]; + quallaConfig["forecast-prefix"] = genieConfig["dialog"]["ssd-q1"]["forecast-prefix"]; + quallaConfig["forecast-prefix-name"] = + genieConfig["dialog"]["ssd-q1"]["forecast-prefix-name"]; + + if (genieConfig["dialog"]["ssd-q1"].contains("n-streams")) { + quallaConfig["n-streams"] = genieConfig["dialog"]["ssd-q1"]["n-streams"]; + } + if (genieConfig["dialog"]["ssd-q1"].contains("p-threshold")) { + quallaConfig["p-threshold"] = genieConfig["dialog"]["ssd-q1"]["p-threshold"]; + } + } +#endif +#if defined(GENIE_LADE_FEATURE) + if (genieConfig["dialog"]["type"] == "lade") { + quallaConfig["lhd-update-mode"] = genieConfig["dialog"]["lade"]["update-mode"]; + quallaConfig["window"] = genieConfig["dialog"]["lade"]["window"]; + quallaConfig["ngram"] = genieConfig["dialog"]["lade"]["ngram"]; + quallaConfig["gcap"] = genieConfig["dialog"]["lade"]["gcap"]; + } +#endif +#if defined(GENIE_SPD_FEATURE) + if (genieConfig["dialog"]["type"] == "spd") { + quallaConfig["draft-len"] = genieConfig["dialog"]["spd"]["draft-len"]; + } +#endif +#if defined(GENIE_MULTISTREAM_FEATURE) + if 
(genieConfig["dialog"]["type"] == "multistream") { + quallaConfig["n-streams"] = genieConfig["dialog"]["multistream"]["n-streams"]; + if (genieConfig["dialog"]["multistream"].contains("p-threshold")) { + quallaConfig["p-threshold"] = genieConfig["dialog"]["multistream"]["p-threshold"]; + } + } +#endif + } + if (genieConfig["dialog"].contains("stop-sequence")) { + quallaConfig["prompt"]["stop-sequence"] = genieConfig["dialog"]["stop-sequence"]; + } + + translateContextConfig(genieConfig, quallaConfig); + translateTokenizerConfig(genieConfig, quallaConfig); + translateSamplerConfig(genieConfig, quallaConfig); + translateMultiEngineConfig(genieConfig, quallaConfig); + translateEmbeddingConfig(genieConfig, quallaConfig); +} + +uint32_t getMaxNumTokens(const qualla::json& genieConfig) { + uint32_t tokenLimit{UINT32_MAX}; + if (genieConfig["dialog"]["version"] == 1) { + if (genieConfig["dialog"].contains("max-num-tokens")) { + tokenLimit = genieConfig["dialog"]["max-num-tokens"]; + } + } + return tokenLimit; +} + +Dialog::Config::Config(const char* configStr) { + qualla::json config; + rope_theta_set = false; + position_dim_set = false; + { + std::set keys; + + auto callback = [&keys](int depth, qualla::json::parse_event_t event, qualla::json& parsed) { + if ((depth == 1) && (event == qualla::json::parse_event_t::key)) { + if (keys.count(parsed) > 0) { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, + "Multiple dialog config key: " + parsed.dump()); + } + keys.insert(parsed); + } + return true; + }; + + config = qualla::json::parse(configStr, callback); + } + + if (!config.is_object()) { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Dialog config is not an object"); + } + + std::set mandatoryFields{"dialog"}; + for (const auto& field : mandatoryFields) { + if (!config.contains(field)) { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing dialog field: " + field); + } + } + + // component is used in the "ENFORCE" macros + std::string component = "dialog"; + + for (auto& item : config.items()) { + if (item.key() == "dialog") { + JSON_ENFORCE_OBJECT(); + validateDialogConfig(item.value()); + } else { + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown dialog config key: " + item.key()); + } + } + m_config = config; +} + +qualla::json Dialog::Config::getJson() const { return m_config; } + +//============================================================================= +// Dialog functions +//============================================================================= + +qnn::util::HandleManager Dialog::s_manager; +std::atomic Dialog::s_nameCounter{0u}; + +GenieDialog_Handle_t Dialog::add(std::shared_ptr dialog) { + return (GenieDialog_Handle_t)s_manager.add(dialog); +} + +std::shared_ptr Dialog::get(GenieDialog_Handle_t handle) { + return s_manager.get((qnn::util::Handle_t)handle); +} + +void Dialog::remove(GenieDialog_Handle_t handle) { s_manager.remove((qnn::util::Handle_t)handle); } + +Dialog::Dialog(std::shared_ptr config) { + auto env = qualla::Env::create(qualla::json{}); + qualla::json quallaConfig; + translateDialogConfig(config->getJson(), quallaConfig); + m_tokenLimit = getMaxNumTokens(config->getJson()); + m_quallaDialog = qualla::Dialog::create( + env, "dialog" + std::to_string(s_nameCounter.fetch_add(1u)), quallaConfig); + if (!m_quallaDialog) { + throw Exception(GENIE_STATUS_ERROR_MEM_ALLOC, "Could not create a dialog object"); + } +} + +static_assert(qualla::Sentence::Code::COMPLETE == + static_cast(GENIE_DIALOG_SENTENCE_COMPLETE)); 
+static_assert(qualla::Sentence::Code::BEGIN == + static_cast(GENIE_DIALOG_SENTENCE_BEGIN)); +static_assert(qualla::Sentence::Code::CONTINUE == + static_cast(GENIE_DIALOG_SENTENCE_CONTINUE)); +static_assert(qualla::Sentence::Code::END == + static_cast(GENIE_DIALOG_SENTENCE_END)); +static_assert(qualla::Sentence::Code::ABORT == + static_cast(GENIE_DIALOG_SENTENCE_ABORT)); + +int32_t Dialog::query(const char* queryStr, + GenieDialog_SentenceCode_t sentenceCode, + GenieDialog_QueryCallback_t callback, + const void* userData) { + std::string query(queryStr); + uint32_t genTokenCount = 0u; + bool status = m_quallaDialog->query( + query, + static_cast(sentenceCode), + [&](const std::string& response, qualla::Sentence::Code code) { + callback(response.c_str(), static_cast(code), userData); + bool keepGoing = ++genTokenCount < m_tokenLimit; + if (!keepGoing && ((code == qualla::Sentence::Code::BEGIN) || + (code == qualla::Sentence::Code::CONTINUE))) { + callback("", GENIE_DIALOG_SENTENCE_END, userData); + } + return keepGoing; + }); + qualla::Dialog::KPIs kpis = m_quallaDialog->kpis(); + printf( + "\n\n[KPIS]:\nInit Time: %zu us\nPrompt Processing Time: %zu us, Prompt Processing Rate : " + "%f toks/sec\n" + "Token Generation Time: %zu us, Token Generation Rate: %f toks/sec\n", + kpis.init.total_usec, + kpis.prompt.last_usec, + kpis.tps.prompt, + kpis.generate.last_usec, + kpis.tps.generate); + return (status) ? (GENIE_STATUS_SUCCESS) : (GENIE_STATUS_ERROR_QUERY_FAILED); +} + +int32_t Dialog::save(const std::string& name) { + return m_quallaDialog->save(name) ? (GENIE_STATUS_SUCCESS) : (GENIE_STATUS_ERROR_QUERY_FAILED); +} + +int32_t Dialog::restore(const std::string& name) { + return m_quallaDialog->restore(name) ? (GENIE_STATUS_SUCCESS) : (GENIE_STATUS_ERROR_QUERY_FAILED); +} + +#if defined(GENIE_E2T_FEATURE) +int32_t Dialog::embeddingQuery(const void* embeddings, + const uint32_t embeddingsSize, + GenieDialog_SentenceCode_t sentenceCode, + GenieDialog_TokenToEmbeddingCallback_t t2eCallback, + GenieDialog_QueryCallback_t callback, + const void* userData) { + uint32_t genTokenCount = 0u; + + if (embeddingsSize % m_quallaDialog->getEmbeddingBufferSize() != 0) { + throw std::runtime_error( + "The embeddings buffer size must be an integer multiple of the embedding vector size in " + "bytes."); + } + + const uint8_t* embeddingsSrc = static_cast(embeddings); + std::vector embeddingVector(embeddingsSrc, embeddingsSrc + embeddingsSize); + + qualla::Dialog::T2ECallback t2eQuallaCallback{nullptr}; + if (t2eCallback) { + t2eQuallaCallback = [&](const int32_t token, void* embedding, const uint32_t embd_size) { + t2eCallback(token, embedding, embd_size, userData); + }; + } + + bool status = m_quallaDialog->query( + embeddingVector, + static_cast(sentenceCode), + t2eQuallaCallback, + [&](const std::string& response, qualla::Sentence::Code code) { + callback(response.c_str(), static_cast(code), userData); + bool keepGoing = ++genTokenCount < m_tokenLimit; + if (!keepGoing && ((code == qualla::Sentence::Code::BEGIN) || + (code == qualla::Sentence::Code::CONTINUE))) { + callback("", GENIE_DIALOG_SENTENCE_END, userData); + } + return keepGoing; + }); + qualla::Dialog::KPIs kpis = m_quallaDialog->kpis(); + printf( + "\n\n[KPIS]:\nInit Time: %zu us\nPrompt Processing Time: %zu us, Prompt Processing Rate : " + "%f toks/sec\n" + "Token Generation Time: %zu us, Token Generation Rate: %f toks/sec\n", + kpis.init.total_usec, + kpis.prompt.last_usec, + kpis.tps.prompt, + kpis.generate.last_usec, + kpis.tps.generate); + 
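+  // As in query() above, fold qualla's boolean result into a Genie status.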
return (status) ? (GENIE_STATUS_SUCCESS) : (GENIE_STATUS_ERROR_QUERY_FAILED); +} +#endif + +void Dialog::reset() { m_quallaDialog->reset(); } + +#if defined(GENIE_LORA_FEATURE) + +int32_t Dialog::applyLora(std::string loraAdapterName, std::string engine) { + bool status = m_quallaDialog->applyLoraAdapter(loraAdapterName, engine); + return (status) ? (GENIE_STATUS_SUCCESS) : (GENIE_STATUS_ERROR_GENERAL); +} + +int32_t Dialog::applyLoraStrength(std::string tensorName, std::string engine, float alpha) { + bool status = m_quallaDialog->applyLoraStrength(tensorName, alpha, engine); + return (status) ? (GENIE_STATUS_SUCCESS) : (GENIE_STATUS_ERROR_GENERAL); +} + +#endif + +int32_t Dialog::tokenQuery(const uint32_t* tokens, + const uint32_t sizeInputTokens, + GenieDialog_SentenceCode_t sentenceCode, + GenieDialog_TokenQueryCallback_t callback, + const void* userData) { + std::vector inputTokens; + for (size_t i = 0; i < sizeInputTokens; i++) { + inputTokens.push_back(tokens[i]); + } + uint32_t genTokenCount = 0u; + dialogCallback.setCallBackType(qualla::QUALLA_CALLBACK_TYPE_TOKEN); + dialogCallback.getTokenCbFunc() = std::make_shared< + std::function>(); + *(dialogCallback.getTokenCbFunc()) = [&](const int32_t* responseTokens, + const uint32_t sizeResponseTokens, + qualla::Sentence::Code code) { + callback((const uint32_t*)responseTokens, + sizeResponseTokens, + static_cast(code), + userData); + bool keepGoing = ++genTokenCount < m_tokenLimit; + if (!keepGoing && + ((code == qualla::Sentence::Code::BEGIN) || (code == qualla::Sentence::Code::CONTINUE))) { + callback(nullptr, 0, GENIE_DIALOG_SENTENCE_END, userData); + } + return keepGoing; + }; + bool status = m_quallaDialog->query((const std::vector)inputTokens, + static_cast(sentenceCode), + dialogCallback); + qualla::Dialog::KPIs kpis = m_quallaDialog->kpis(); + printf( + "\n\n[KPIS]:\nInit Time: %zu us\nPrompt Processing Time: %zu us, Prompt Processing Rate : " + "%f toks/sec\n" + "Token Generation Time: %zu us, Token Generation Rate: %f toks/sec\n", + kpis.init.total_usec, + kpis.prompt.last_usec, + kpis.tps.prompt, + kpis.generate.last_usec, + kpis.tps.generate); + return (status) ? (GENIE_STATUS_SUCCESS) : (GENIE_STATUS_ERROR_QUERY_FAILED); +} \ No newline at end of file diff --git a/Genie/Genie/src/Dialog.hpp b/Genie/Genie/src/Dialog.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c62690c358732c8159dbc365b0c5ccda2e083de7 --- /dev/null +++ b/Genie/Genie/src/Dialog.hpp @@ -0,0 +1,95 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. 
+// +//============================================================================== + +#pragma once + +#include +#include + +#include "GenieDialog.h" +#include "Util/HandleManager.hpp" +#include "qualla/dialog.hpp" +#include "qualla/DialogCallback.hpp" + +namespace genie { + +enum LORA_VERSION : uint8_t { + GENIE_LORA_VERSION_V1 = 0x1, + GENIE_LORA_VERSION_V2 = 0x2, + GENIE_LORA_VERSION_UNDEFINED = 0xFF +}; + +class Dialog { + public: + class Config { + public: + static GenieDialogConfig_Handle_t add(std::shared_ptr config); + static std::shared_ptr get(GenieDialogConfig_Handle_t handle); + static void remove(GenieDialogConfig_Handle_t handle); + + Config(const char* configStr); + qualla::json getJson() const; + + private: + static qnn::util::HandleManager s_manager; + qualla::json m_config; + }; + + static GenieDialog_Handle_t add(std::shared_ptr dialog); + static std::shared_ptr get(GenieDialog_Handle_t handle); + static void remove(GenieDialog_Handle_t handle); + + qualla::DialogCallback dialogCallback; + + Dialog(std::shared_ptr config); + + Dialog(const Dialog&) = delete; + Dialog& operator=(const Dialog&) = delete; + Dialog(Dialog&&) = delete; + Dialog& operator=(Dialog&&) = delete; + + int32_t query(const char* queryStr, + GenieDialog_SentenceCode_t sentenceCode, + GenieDialog_QueryCallback_t callback, + const void* userData); + + int32_t save(const std::string&); + + int32_t restore(const std::string&); + +#if defined(GENIE_E2T_FEATURE) + int32_t embeddingQuery(const void* embeddings, + const uint32_t embeddingsSize, + GenieDialog_SentenceCode_t sentenceCode, + GenieDialog_TokenToEmbeddingCallback_t t2eCallback, + GenieDialog_QueryCallback_t callback, + const void* userData); +#endif + + + + int32_t tokenQuery(const uint32_t* tokens, + const uint32_t sizeInputTokens, + GenieDialog_SentenceCode_t sentenceCode, + GenieDialog_TokenQueryCallback_t callback, + const void* userData); + + void reset(); + +#if defined(GENIE_LORA_FEATURE) + int32_t applyLora(std::string loraAdapterName, std::string engine); + int32_t applyLoraStrength(std::string tensorName, std::string engine, float alpha); +#endif + + private: + std::unique_ptr m_quallaDialog; + uint32_t m_tokenLimit{UINT32_MAX}; + static qnn::util::HandleManager s_manager; + static std::atomic s_nameCounter; +}; +} // namespace genie diff --git a/Genie/Genie/src/Exception.hpp b/Genie/Genie/src/Exception.hpp new file mode 100644 index 0000000000000000000000000000000000000000..956c935caecb25696b823d093dee0ee9b8e85405 --- /dev/null +++ b/Genie/Genie/src/Exception.hpp @@ -0,0 +1,27 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. 
+// +//============================================================================== + +#pragma once + +#include +#include + +#include "GenieCommon.h" + +namespace genie { + +class Exception : public std::runtime_error { + public: + Exception(Genie_Status_t status, std::string what) : std::runtime_error(what), m_status(status) {} + Genie_Status_t status() const { return m_status; } + + private: + Genie_Status_t m_status; +}; + +} // namespace genie diff --git a/Genie/Genie/src/GenieCommon.cpp b/Genie/Genie/src/GenieCommon.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ed4f084cf83b154d7add8a143301ed12090e2e0f --- /dev/null +++ b/Genie/Genie/src/GenieCommon.cpp @@ -0,0 +1,15 @@ +//============================================================================= +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================= + +#include "GenieCommon.h" + +uint32_t Genie_getApiMajorVersion(void) { return GENIE_API_VERSION_MAJOR; } + +uint32_t Genie_getApiMinorVersion(void) { return GENIE_API_VERSION_MINOR; } + +uint32_t Genie_getApiPatchVersion(void) { return GENIE_API_VERSION_PATCH; } diff --git a/Genie/Genie/src/GenieDialog.cpp b/Genie/Genie/src/GenieDialog.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6f6f3116de4a0261b15aed2194cfc17b8b3bcda8 --- /dev/null +++ b/Genie/Genie/src/GenieDialog.cpp @@ -0,0 +1,249 @@ +//============================================================================= +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. 
+// +//============================================================================= + +#include "Dialog.hpp" +#include "Exception.hpp" +#include "GenieDialog.h" +#include "Macro.hpp" +#include "Util/HandleManager.hpp" +#include "qualla/detail/json.hpp" + +using namespace genie; + +GENIE_API +Genie_Status_t GenieDialogConfig_createFromJson(const char* str, + GenieDialogConfig_Handle_t* configHandle) { + try { + GENIE_ENSURE(str, GENIE_STATUS_ERROR_INVALID_ARGUMENT); + GENIE_ENSURE(configHandle, GENIE_STATUS_ERROR_INVALID_ARGUMENT); + auto config = std::make_shared(str); + GENIE_ENSURE(config, GENIE_STATUS_ERROR_MEM_ALLOC); + *configHandle = genie::Dialog::Config::add(config); + } catch (const qualla::json::parse_error& e) { + std::cerr << e.what() << std::endl; + return GENIE_STATUS_ERROR_JSON_FORMAT; + } catch (const Exception& e) { + std::cerr << e.what() << std::endl; + return e.status(); + } catch (const std::exception& e) { + std::cerr << e.what() << std::endl; + return GENIE_STATUS_ERROR_GENERAL; + } + return GENIE_STATUS_SUCCESS; +} + +GENIE_API +Genie_Status_t GenieDialogConfig_free(const GenieDialogConfig_Handle_t configHandle) { + try { + GENIE_ENSURE(configHandle, GENIE_STATUS_ERROR_INVALID_HANDLE); + { + // Check if the dialog actually exists + auto configObj = genie::Dialog::Config::get(configHandle); + GENIE_ENSURE(configObj, GENIE_STATUS_ERROR_INVALID_HANDLE); + } + genie::Dialog::Config::remove(configHandle); + } catch (const std::exception& e) { + return GENIE_STATUS_ERROR_GENERAL; + } + return GENIE_STATUS_SUCCESS; +} + +GENIE_API +Genie_Status_t GenieDialog_create(const GenieDialogConfig_Handle_t configHandle, + GenieDialog_Handle_t* dialogHandle) { + try { + GENIE_ENSURE(dialogHandle, GENIE_STATUS_ERROR_INVALID_ARGUMENT); + + // Get config object + auto configObj = genie::Dialog::Config::get(configHandle); + GENIE_ENSURE(configObj, GENIE_STATUS_ERROR_INVALID_HANDLE); + + // Create dialog + auto dialog = std::make_shared(configObj); + GENIE_ENSURE(dialog, GENIE_STATUS_ERROR_MEM_ALLOC); + + // Create Handle + *dialogHandle = genie::Dialog::add(dialog); + } catch (const std::exception& e) { + std::cerr << e.what() << std::endl; + return GENIE_STATUS_ERROR_GENERAL; + } + + // Return SUCCESS + return GENIE_STATUS_SUCCESS; +} + +GENIE_API +Genie_Status_t GenieDialog_query(const GenieDialog_Handle_t dialogHandle, + const char* queryStr, + const GenieDialog_SentenceCode_t sentenceCode, + const GenieDialog_QueryCallback_t callback, + const void* userData) { + int32_t status; + + try { + GENIE_ENSURE(dialogHandle, GENIE_STATUS_ERROR_INVALID_HANDLE); + auto dialog = genie::Dialog::get(dialogHandle); + GENIE_ENSURE(dialog, GENIE_STATUS_ERROR_INVALID_HANDLE); + GENIE_ENSURE(queryStr, GENIE_STATUS_ERROR_INVALID_ARGUMENT); + GENIE_ENSURE(callback, GENIE_STATUS_ERROR_INVALID_ARGUMENT); + + switch (sentenceCode) { + case GENIE_DIALOG_SENTENCE_COMPLETE: + case GENIE_DIALOG_SENTENCE_BEGIN: + case GENIE_DIALOG_SENTENCE_CONTINUE: + case GENIE_DIALOG_SENTENCE_END: + case GENIE_DIALOG_SENTENCE_ABORT: + // Do nothing + break; + default: + return GENIE_STATUS_ERROR_INVALID_ARGUMENT; + } + + status = dialog->query(queryStr, sentenceCode, callback, userData); + } catch (const std::exception& e) { + std::cerr << e.what() << std::endl; + return GENIE_STATUS_ERROR_GENERAL; + } + + return status; +} + +GENIE_API +Genie_Status_t GenieDialog_save(const GenieDialog_Handle_t dialogHandle, const char* path) { + int32_t status; + + try { + GENIE_ENSURE(dialogHandle, GENIE_STATUS_ERROR_INVALID_HANDLE); + 
auto dialog = genie::Dialog::get(dialogHandle);
+    GENIE_ENSURE(dialog, GENIE_STATUS_ERROR_INVALID_HANDLE);
+    GENIE_ENSURE(path, GENIE_STATUS_ERROR_INVALID_ARGUMENT);
+    status = dialog->save(path);
+  } catch (const std::exception& e) {
+    std::cerr << e.what() << std::endl;
+    return GENIE_STATUS_ERROR_GENERAL;
+  }
+
+  return status;
+}
+
+GENIE_API
+Genie_Status_t GenieDialog_restore(const GenieDialog_Handle_t dialogHandle, const char* path) {
+  int32_t status;
+
+  try {
+    GENIE_ENSURE(dialogHandle, GENIE_STATUS_ERROR_INVALID_HANDLE);
+    auto dialog = genie::Dialog::get(dialogHandle);
+    GENIE_ENSURE(dialog, GENIE_STATUS_ERROR_INVALID_HANDLE);
+    GENIE_ENSURE(path, GENIE_STATUS_ERROR_INVALID_ARGUMENT);
+    status = dialog->restore(path);
+  } catch (const std::exception& e) {
+    std::cerr << e.what() << std::endl;
+    return GENIE_STATUS_ERROR_GENERAL;
+  }
+
+  return status;
+}
+
+GENIE_API
+Genie_Status_t GenieDialog_reset(const GenieDialog_Handle_t dialogHandle) {
+  try {
+    GENIE_ENSURE(dialogHandle, GENIE_STATUS_ERROR_INVALID_HANDLE);
+    auto dialog = genie::Dialog::get(dialogHandle);
+    GENIE_ENSURE(dialog, GENIE_STATUS_ERROR_INVALID_HANDLE);
+    dialog->reset();
+  } catch (const std::exception& e) {
+    return GENIE_STATUS_ERROR_GENERAL;
+  }
+  return GENIE_STATUS_SUCCESS;
+}
+
+#if defined(GENIE_LORA_FEATURE)
+
+GENIE_API
+Genie_Status_t GenieDialog_applyLora(const GenieDialog_Handle_t dialogHandle,
+                                     const char* engine,
+                                     const char* loraAdapterName) {
+  int32_t status;
+  try {
+    GENIE_ENSURE(dialogHandle, GENIE_STATUS_ERROR_INVALID_HANDLE);
+    auto dialog = genie::Dialog::get(dialogHandle);
+    GENIE_ENSURE(dialog, GENIE_STATUS_ERROR_INVALID_HANDLE);
+    GENIE_ENSURE(engine, GENIE_STATUS_ERROR_INVALID_ARGUMENT);
+    std::string eng(engine);
+    GENIE_ENSURE(loraAdapterName, GENIE_STATUS_ERROR_INVALID_ARGUMENT);
+    std::string loraName(loraAdapterName);
+    status = dialog->applyLora(loraName, eng);
+  } catch (const std::exception& e) {
+    return GENIE_STATUS_ERROR_GENERAL;
+  }
+  return status;
+}
+
+GENIE_API
+Genie_Status_t GenieDialog_setLoraStrength(const GenieDialog_Handle_t dialogHandle,
+                                           const char* engine,
+                                           const char* tensorName,
+                                           const float alpha) {
+  int32_t status;
+  try {
+    GENIE_ENSURE(dialogHandle, GENIE_STATUS_ERROR_INVALID_HANDLE);
+    auto dialog = genie::Dialog::get(dialogHandle);
+    GENIE_ENSURE(dialog, GENIE_STATUS_ERROR_INVALID_HANDLE);
+    GENIE_ENSURE(engine, GENIE_STATUS_ERROR_INVALID_ARGUMENT);
+    std::string eng(engine);
+    GENIE_ENSURE(tensorName, GENIE_STATUS_ERROR_INVALID_ARGUMENT);
+    std::string alphaTensorName(tensorName);
+    GENIE_ENSURE_NOT_EMPTY(alphaTensorName, GENIE_STATUS_ERROR_INVALID_ARGUMENT);
+    status = dialog->applyLoraStrength(alphaTensorName, eng, alpha);
+  } catch (const std::exception& e) {
+    return GENIE_STATUS_ERROR_GENERAL;
+  }
+  return status;
+}
+
+#endif
+
+GENIE_API
+Genie_Status_t GenieDialog_tokenQuery(const GenieDialog_Handle_t dialogHandle,
+                                      const uint32_t* inputTokens,
+                                      const uint32_t numTokens,
+                                      const GenieDialog_SentenceCode_t sentenceCode,
+                                      const GenieDialog_TokenQueryCallback_t callback,
+                                      const void* userData) {
+  // int32_t (not bool) so error codes from tokenQuery() are not truncated.
+  int32_t status;
+  try {
+    GENIE_ENSURE(dialogHandle, GENIE_STATUS_ERROR_INVALID_HANDLE);
+    auto dialog = genie::Dialog::get(dialogHandle);
+    GENIE_ENSURE(dialog, GENIE_STATUS_ERROR_INVALID_HANDLE);
+    GENIE_ENSURE(inputTokens, GENIE_STATUS_ERROR_INVALID_ARGUMENT);
+    GENIE_ENSURE(callback, GENIE_STATUS_ERROR_INVALID_ARGUMENT);
+    status = dialog->tokenQuery(inputTokens, numTokens, sentenceCode, callback, userData);
+  } catch (const std::exception& e)
{ + std::cerr << e.what() << std::endl; + return GENIE_STATUS_ERROR_GENERAL; + } + + return status; +} + +GENIE_API +Genie_Status_t GenieDialog_free(const GenieDialog_Handle_t dialogHandle) { + try { + GENIE_ENSURE(dialogHandle, GENIE_STATUS_ERROR_INVALID_HANDLE); + { + // Check if the dialog actually exists + auto dialog = genie::Dialog::get(dialogHandle); + GENIE_ENSURE(dialog, GENIE_STATUS_ERROR_INVALID_HANDLE); + } + genie::Dialog::remove(dialogHandle); + } catch (const std::exception& e) { + return GENIE_STATUS_ERROR_GENERAL; + } + return GENIE_STATUS_SUCCESS; +} diff --git a/Genie/Genie/src/GenieDialogEmbedding.cpp b/Genie/Genie/src/GenieDialogEmbedding.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e5162cb8e4b57b9523b2e57e3568d46ea261f8e2 --- /dev/null +++ b/Genie/Genie/src/GenieDialogEmbedding.cpp @@ -0,0 +1,41 @@ +//============================================================================= +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================= + +#include "Dialog.hpp" +#include "Exception.hpp" +#include "GenieDialog.h" +#include "Macro.hpp" +#include "Util/HandleManager.hpp" +#include "qualla/detail/json.hpp" + +using namespace genie; + +GENIE_API +Genie_Status_t GenieDialog_embeddingQuery(const GenieDialog_Handle_t dialogHandle, + const void* embeddings, + const uint32_t embeddingsSize, + const GenieDialog_SentenceCode_t sentenceCode, + const GenieDialog_TokenToEmbeddingCallback_t t2eCallback, + const GenieDialog_QueryCallback_t callback, + const void* userData) { + Genie_Status_t status; + try { + GENIE_ENSURE(dialogHandle, GENIE_STATUS_ERROR_INVALID_HANDLE); + auto dialog = genie::Dialog::get(dialogHandle); + GENIE_ENSURE(dialog, GENIE_STATUS_ERROR_INVALID_HANDLE); + GENIE_ENSURE(embeddings, GENIE_STATUS_ERROR_INVALID_ARGUMENT); + GENIE_ENSURE(callback, GENIE_STATUS_ERROR_INVALID_ARGUMENT); + status = dialog->embeddingQuery( + embeddings, embeddingsSize, sentenceCode, t2eCallback, callback, userData); + } catch (const std::exception& e) { + std::cerr << e.what() << std::endl; + return GENIE_STATUS_ERROR_GENERAL; + } + + return status; +} diff --git a/Genie/Genie/src/Macro.hpp b/Genie/Genie/src/Macro.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c50b1585e4cee424b3b744415b4d74233d7a8c31 --- /dev/null +++ b/Genie/Genie/src/Macro.hpp @@ -0,0 +1,101 @@ +//============================================================================ +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================= + +#pragma once + +//====================================================================================================================== +// Error generation macros +//====================================================================================================================== + +#define GENIE_LOG_ERROR(fmt, ...) 
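+// GENIE_LOG_ERROR expands to nothing here, so the GENIE_ENSURE* macros below
+// only affect control flow; their message arguments become visible once this
+// stub is wired to a real logger.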
+ +#define GENIE_ENSURE_MSG(value, return_error, msg) \ + do { \ + if (!(value)) { \ + GENIE_LOG_ERROR(" " msg); \ + return return_error; \ + } \ + } while (0) + +#define GENIE_ENSURE(value, return_error) \ + do { \ + if (!(value)) { \ + GENIE_LOG_ERROR("%s was not true.", #value); \ + return return_error; \ + } \ + } while (0) + +#define GENIE_ENSURE_STATUS(status, return_error) \ + do { \ + if ((status) != GENIE_SUCCESS) { \ + return return_error; \ + } \ + } while (0) + +#define GENIE_ENSURE_EQ(a, b, return_error) \ + do { \ + if ((a) != (b)) { \ + GENIE_LOG_ERROR("%s != %s (%d != %d)", #a, #b, (a), (b)); \ + return return_error; \ + } \ + } while (0) + +#define GENIE_ENSURE_NOT_EMPTY(value, return_error) \ + do { \ + if (value.empty()) { \ + GENIE_LOG_ERROR("%s was not true.", #value); \ + return return_error; \ + } \ + } while (0) +//====================================================================================================================== +// JSON config macros +//====================================================================================================================== + +#define JSON_ENFORCE_OBJECT() \ + if (!item.value().is_object()) { \ + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, \ + "Invalid " + component + " config: " + item.key() + " is not an object"); \ + } + +#define JSON_ENFORCE_ARRAY() \ + if (!item.value().is_array()) { \ + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, \ + "Invalid " + component + " config: " + item.key() + " is not an array"); \ + } + +#define JSON_ENFORCE_ARRAY_OR_OBJECT() \ + if (!item.value().is_array() && !item.value().is_object()) { \ + throw Exception( \ + GENIE_STATUS_ERROR_JSON_SCHEMA, \ + "Invalid " + component + " config: " + item.key() + " is not an array or object"); \ + } + +#define JSON_ENFORCE_NUMERIC() \ + if (!item.value().is_number()) { \ + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, \ + "Invalid " + component + " config: " + item.key() + " is not numeric"); \ + } + +#define JSON_ENFORCE_ARRAY_OR_NUMERIC() \ + if (!item.value().is_number() && !item.value().is_array()) { \ + throw Exception( \ + GENIE_STATUS_ERROR_JSON_SCHEMA, \ + "Invalid " + component + " config: " + item.key() + " is not an array or numeric"); \ + } + +#define JSON_ENFORCE_BOOLEAN() \ + if (!item.value().is_boolean()) { \ + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, \ + "Invalid " + component + " config: " + item.key() + " is not boolean"); \ + } + +#define JSON_ENFORCE_STRING() \ + if (!item.value().is_string()) { \ + throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, \ + "Invalid " + component + " config: " + item.key() + " is not a string"); \ + } diff --git a/Genie/Genie/src/Util/HandleGenerator.hpp b/Genie/Genie/src/Util/HandleGenerator.hpp new file mode 100644 index 0000000000000000000000000000000000000000..09ed32c97015727885b8418441ffb36efe5893c1 --- /dev/null +++ b/Genie/Genie/src/Util/HandleGenerator.hpp @@ -0,0 +1,62 @@ +//============================================================================== +// +// Copyright (c) 2019-2020,2023 Qualcomm Technologies, Inc. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. 
+// +//============================================================================== + +#pragma once + +#include + +namespace qnn { +namespace util { + +typedef std::size_t Handle_t; + +class HandleGenerator final { + static_assert(std::is_integral::value, "Handle must be an integral type"); + static_assert((sizeof(Handle_t) == 8) || (sizeof(Handle_t) == 4), + "Implementation of HandleGenerator::bswap() for sizeof(std::size_t) is required"); + + public: + HandleGenerator(const HandleGenerator&) = delete; + HandleGenerator& operator=(const HandleGenerator&) = delete; + HandleGenerator(HandleGenerator&&) = delete; + HandleGenerator& operator=(HandleGenerator&&) = delete; + + static Handle_t generate(const void* const addr) { + return (bswap((Handle_t)addr) ^ (Handle_t)s_operand); + } + static const void* reverse(const Handle_t handle) { + return (void*)bswap(handle ^ (Handle_t)s_operand); + } + static constexpr Handle_t invalid() { return s_operand; } + + private: + HandleGenerator() {} + + static uint32_t bswap32(const uint32_t val) { + return (val >> 24U) | ((val >> 8U) & 0xff00U) | ((val << 8U) & 0xff0000U) | (val << 24U); + } + + static uint64_t bswap64(const uint64_t val) { + return ((bswap32(val) + 0ULL) << 32U) | bswap32(val >> 32U); + } + + template + static size_t bswap(T val) { + if (sizeof(T) == 4) { + return bswap32(val); + } else { + return bswap64(val); + } + } + + // Magic number generated via "openssl rand -hex 8" + static constexpr Handle_t s_operand = (Handle_t)0xd4c2416534bcdc9b; +}; + +} // namespace util +} // namespace qnn diff --git a/Genie/Genie/src/Util/HandleManager.hpp b/Genie/Genie/src/Util/HandleManager.hpp new file mode 100644 index 0000000000000000000000000000000000000000..375d4dcc01c7e49ba9969b0d2244aef7f5221f1c --- /dev/null +++ b/Genie/Genie/src/Util/HandleManager.hpp @@ -0,0 +1,84 @@ +//============================================================================== +// +// Copyright (c) 2019-2020 Qualcomm Technologies, Inc. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. 
+// +//============================================================================== + +#pragma once + +#include +#include +#include +#include +#include + +#include "HandleGenerator.hpp" + +namespace qnn { +namespace util { + +template +class HandleManager { + public: + HandleManager() = default; + HandleManager(const HandleManager&) = delete; + HandleManager& operator=(const HandleManager&) = delete; + HandleManager(HandleManager&&) = delete; + HandleManager& operator=(HandleManager&&) = delete; + + Handle_t add(std::shared_ptr item) { + std::lock_guard locker(m_itemsMtx); + + if (!item) { + return HandleGenerator::invalid(); + } + + auto handle = HandleGenerator::generate(item.get()); + m_items[handle] = item; + return handle; + } + + Handle_t add(T* item) { return add(std::shared_ptr(item)); } + + Handle_t add(std::weak_ptr item) { return add(item.lock()); } + + std::shared_ptr get(Handle_t handle) { + std::lock_guard locker(m_itemsMtx); + + auto it = m_items.find(handle); + if (it == m_items.end()) { + return std::shared_ptr(nullptr); + } + + return it->second; + } + + typedef std::function>&)> UnaryPredicate_t; + + Handle_t findIf(UnaryPredicate_t pred) const { + auto it = std::find_if(m_items.begin(), m_items.end(), pred); + if (it == m_items.end()) { + return HandleGenerator::invalid(); + } + + return it->first; + } + + size_t remove(Handle_t handle) { + std::lock_guard locker(m_itemsMtx); + return m_items.erase(handle); + } + + void clear() { m_items.clear(); } + + const std::unordered_map>& getItems() const { return m_items; } + + private: + std::unordered_map> m_items; + std::mutex m_itemsMtx; +}; + +} // namespace util +} // namespace qnn diff --git a/Genie/Genie/src/qualla/context.cpp b/Genie/Genie/src/qualla/context.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9a71ce9c1754ea9bdee044d01fab6a99ef79c543 --- /dev/null +++ b/Genie/Genie/src/qualla/context.cpp @@ -0,0 +1,118 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#include +#include +#include +#include + +#include +#include + +namespace qualla { + +Context::Context(Env& env, const std::string& name, const qualla::json& json) + : _name(name), _env(env), _conf(json) { + _env.logger().debug(fmt::format("ctx-new: {} config {}", _name, _conf.dump())); + + qualla::Config conf(json, "context:"); + _size = conf.optional("size", 1024); + _size = conf.optional("n-ctx", _size); // alternative name + _n_vocab = conf.optional("n-vocab", 32000); + _n_embd = conf.optional("n-embd", 1024); + _embedding_length = conf.optional("embedding-length", -1); + _embedding_datatype = conf.optional("embedding-datatype", "float32"); + // For backward compatibility. 
When eot-token is removed, this logic can be simplified + // Currently, EOT is marked as default truncating token if available + int32_t eot_tok = conf.optional("eot-token", -1); + if (eot_tok >= 0) _eos_tok_list.insert(eot_tok); + + const qualla::json eos_conf = conf.optional("eos-token", _eos_tok); + if (eos_conf.is_array() && eos_conf.size() > 0) { + const std::vector& eos_tokens = eos_conf.get>(); + _eos_tok = eos_tokens[0]; + for (const int32_t& eos_tok : eos_tokens) + _eos_tok_list.insert(eos_tok); + } else if (eos_conf.is_number_integer()) { + int32_t eos_tok = eos_conf.get(); + _eos_tok = (eot_tok >= 0) ? eot_tok : eos_tok; + _eos_tok_list.insert(eos_tok); + } + + _pad_tok = conf.optional("pad-token", _eos_tok); +} + +std::unique_ptr Context::create( + Env& env, + const std::string& name, + const qualla::json& conf +) { + return std::make_unique(env, name, conf); +} + +std::unique_ptr Context::create( + Env& env, + const std::string& name, + std::istream& json_stream +) { + return create(env, name, json::parse(json_stream)); +} + +std::unique_ptr Context::create( + Env& env, + const std::string& name, + const std::string& json_str +) { + return create(env, name, json::parse(json_str)); +} + +#ifdef QUALLA_STATIC + +// This is a hack to make sure all core bits are linked in for the static build + +extern void needFileLogger(); +extern void needStdoutLogger(); +extern void needBasicSampler(); +extern void needBasicDialog(); +extern void needKvShareDialog(); +extern void needSpdDialog(); +extern void needSsdDialog(); +extern void needLadeDialog(); +extern void needMultistreamDialog(); + + #ifdef QUALLA_ENGINE_QNN_HTP +extern void needQnnHtpEngine(); + #endif + + #ifdef QUALLA_ENGINE_QNN_CPU +extern void needQnnCpuEngine(); + #endif + +static OnLoad needs([]() { + needStdoutLogger(); + needFileLogger(); + needBasicDialog(); + needBasicSampler(); + needKvShareDialog(); + needSpdDialog(); + needSsdDialog(); + needLadeDialog(); + needMultistreamDialog(); + + #ifdef QUALLA_ENGINE_QNN_HTP + needQnnHtpEngine(); + #endif + + #ifdef QUALLA_ENGINE_QNN_CPU + needQnnCpuEngine(); + #endif +}); + +#endif + +} // namespace qualla diff --git a/Genie/Genie/src/qualla/dialog.cpp b/Genie/Genie/src/qualla/dialog.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ef51c2094fa07125430ad7eb87382ba11a27f3ca --- /dev/null +++ b/Genie/Genie/src/qualla/dialog.cpp @@ -0,0 +1,590 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All rights reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#define __INFO(__fmt, ...) _env->logger().post(Logger::INFO, fmt::format(__fmt, ##__VA_ARGS__)) +#define __WARN(__fmt, ...) _env->logger().post(Logger::WARN, fmt::format(__fmt, ##__VA_ARGS__)) +#define __ERROR(__fmt, ...) _env->logger().post(Logger::ERROR, fmt::format(__fmt, ##__VA_ARGS__)) +#define __KPIS(__fmt, ...) \ + _env->logger().post(Logger::KPIS, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); }) +#define __DEBUG(__fmt, ...) \ + _env->logger().post(Logger::DEBUG, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); }) +#define __TRACE(__fmt, ...) 
\ + _env->logger().post(Logger::TRACE, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); }) + +namespace fs = std::filesystem; + +namespace qualla { + +Dialog::Dialog(std::shared_ptr env, const std::string& name, const qualla::json& json) + : _env(env) { + Timer start; + + + + __DEBUG("dialog-new: {} config {}", name, json.dump()); + + using qc = qualla::Config; + + // Create Gpiomarker and reset the gpio status to low + const qualla::json& gpio_conf = qc::optional(json, "gpio", {}); + _gpio_marker = GpioMarker::create(gpio_conf); + + _gpio_marker->set(); + + // Create the context first + _ctx = Context::create(*_env, name, qc::mandatory(json, "context")); + + // Parse prompt config + const qualla::json& pmt_conf = qc::optional(json, "prompt", {}); + _prompt_type = qc::optional(pmt_conf, "type", "llama2"); + _sys_tags = qc::optional>(pmt_conf, "sys-tags", {"", ""}); + _inst_tags = qc::optional>(pmt_conf, "inst-tags", {"", ""}); + _role_tags = qc::optional>(pmt_conf, "role-tags", {"", ""}); + _sys_prompt = qc::optional(pmt_conf, "sys-prompt", ""); + + const std::vector& stop_sequence = + qc::optional>(pmt_conf, "stop-sequence", {}); + _stop_sequence = SequenceMatchTrie(stop_sequence); + + // Create Tokenizer + // TODO: auto-detect / validate n_vocab with tokenizer vocab + fs::path tok_path = _env->path().models / qc::mandatory(json, "tokenizer"); + _tokenizer = Tokenizer::create(*_ctx, tok_path); + + // Create Sampler(s) + auto add_sampler = [&](const qualla::json& j) { + std::string role = qc::optional(j, "role", "primary"); + _sampler[role] = Sampler::create(*_ctx, j); + }; + + const qualla::json& sam_conf = qc::mandatory(json, "sampler"); + if (sam_conf.is_array()) { + for (auto sc : sam_conf) { + add_sampler(sc); + } + } else + add_sampler(sam_conf); + + + + + // Create Engine(s) + auto add_engine = [&](const qualla::json& j) { + std::string role = qc::optional(j, "role", "primary"); + + _engine[role] = Engine::create(*_ctx, j); + + using FF = Engine::Feature::Flags; + + + if (!_engine[role]->supports(FF::OUTPUT_LOGITS)) + throw std::runtime_error("the engine must output Logits"); + }; + + + + const qualla::json& eng_conf = qc::mandatory(json, "engine"); + + + if (eng_conf.is_array()) { + + for (auto ec : eng_conf) { + add_engine(ec); + } + } else{ + add_engine(eng_conf); + + } + + // Store input type (token, embedding, etc) from the engine. + // This assumes multi-engine usecases use matching input types. 
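+  // (e.g. token-input engines for both the primary and draft roles in a
+  // speculative-decoding setup; that pairing is only an illustration.)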
+ m_inputType = _engine.begin()->second->getInputType(); + + _kpis.init.update(start.elapsed_usec()); +} + +Dialog::~Dialog() {} + +static bool __no_response_query(const std::string&, Sentence::Code) { + return false; +} + +static bool __no_response_token(const int32_t*, const uint32_t, Sentence::Code) { + return false; +} + +static bool __no_response(const std::string&, Sentence::Code) { + return false; +} + +void Dialog::getTopK(std::vector& logits, std::vector>& tokens, size_t topK, float pThreshold, Dialog::Callback callback) { + + auto& sampler = *_sampler["primary"]; + + // Sample top-k logits but with a minimum probability threshold +#if defined(__GNUC__) && !defined(__clang__) + std::span indexed_logits_span(logits); + IndexedLogits indexed_logits(indexed_logits_span, sampler.rng()); +#else + IndexedLogits indexed_logits(std::span{logits.data(),logits.size()}, sampler.rng()); +#endif + indexed_logits.softmax(); + indexed_logits.topK(topK); + + for (int i = 0; i < topK; i++) { + + _last_tok = indexed_logits.indices[i]; + + // Only sample tokens above some probability threshold + // TODO: Modify sampling algorithm as necessary + if (indexed_logits.probs[i] < pThreshold) { + break; + } else if (_ctx->is_eos(_last_tok)) { + callback("", Sentence::CONTINUE); + } else { + tokens.push_back({_last_tok}); + } + } +} + +bool Dialog::query(const std::string& str, Sentence::Code scode, Dialog::Callback callback) { + std::vector p_vec; // prompt tokens + std::string p_str; // prompt string + + p_vec.reserve(1024); + + if (scode == Sentence::COMPLETE || scode == Sentence::BEGIN) { + // Reset prompt/gen counts for new query + _n_prompt = 0; + _n_generated = 0; + _n_previous_prompt = 0; + _n_previous_generated = 0; + + + if (_last_tok >= 0 && !_ctx->is_eos(_last_tok)) p_vec.push_back(_last_tok); + + p_str = _inst_tags[0]; + + if (!_n_queries) { + // First query. Prepend sys-prompt. + p_str += _sys_tags[0] + _sys_prompt + _sys_tags[1]; + } else { + // Add EOS explicitly if the last query was aborted prematurely. 
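+      // (Only emitted when the model defines an EOS token; eos_tok() < 0 means none.)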
+ if (_ctx->eos_tok() >= 0) p_vec.push_back(_ctx->eos_tok()); + } + + // Add BOS + if (_ctx->bos_tok() >= 0) { + p_vec.push_back(_ctx->bos_tok()); + } + } + + // FIXME: make this more generic + if (_prompt_type == "llama3") { + p_str += _sys_tags[0] + _role_tags[1] + _sys_tags[1] + str + _inst_tags[2]; + } else { + p_str += str; + } + + if (scode == Sentence::COMPLETE || scode == Sentence::END) { + if (_prompt_type == "llama3") { + p_str += _sys_tags[0] + _role_tags[2] + _sys_tags[1]; + } else { + p_str += _inst_tags[1]; + } + } + + _env->logger().post(Logger::DEBUG, [&]() { + qualla::json j{{"string", str}, {"prompt", p_str}}; + return fmt::format("dialog-query: {} {}", _ctx->name(), j.dump()); + }); + + _n_queries++; + + _tokenizer->encode(p_str, p_vec); + + __DEBUG("dialog-tokens: {} {}", _ctx->name(), p_vec); + __DEBUG("dialog-text: \"{}\"", p_str); + + if (scode == Sentence::COMPLETE || scode == Sentence::END) { + // Detect stop sequences here + if (!_stop_sequence.empty()) { + _stop_sequence.reset(); + return process(p_vec, [&](const std::string& str, Sentence::Code c) { + // Check for stop sequence and end inference when stop sequence is found + if (_stop_sequence.process_next_string(str)) { + callback(str, c); // Emit sequences until match is complete + return false; + } + + // Else, return normal callback function + return callback(str, c); + }); + } + + return process(p_vec, callback); + } + + return process(p_vec, __no_response); +} + +bool Dialog::query(const std::vector& input, Sentence::Code scode, qualla::DialogCallback& callback) { + std::vector p_vec; // prompt tokens + p_vec.reserve(1024); + + if (scode == Sentence::COMPLETE || scode == Sentence::BEGIN) { + // Reset prompt/gen counts for new query + _n_prompt = 0; + _n_generated = 0; + _n_previous_prompt = 0; + _n_previous_generated = 0; + + if (_last_tok >= 0) + p_vec.push_back(_last_tok); + + // Add EOS explicitly if the last query was aborted prematurely. + if (_n_queries && _last_tok != _ctx->eos_tok()) { + p_vec.push_back(_ctx->eos_tok()); + } + // Add BOS + if (_ctx->bos_tok() >= 0) { + p_vec.push_back(_ctx->bos_tok()); + } + } + + p_vec.insert(p_vec.end(), input.begin(), input.end()); + __DEBUG("dialog-tokens: {} {}", _ctx->name(), p_vec); + + _n_queries++; + + if (scode == Sentence::COMPLETE || scode == Sentence::END) { + return process(p_vec, callback); + } + + DialogCallback callback_return_token(QUALLA_CALLBACK_TYPE_TOKEN); + *(callback_return_token.getTokenCbFunc()) = __no_response_token; + return process(p_vec, callback_return_token); +} + +bool Dialog::query( + std::vector& embedding_vectors, + Sentence::Code scode, + T2ECallback t2eCallback, + Dialog::Callback callback +) { + _n_queries++; + if (scode == Sentence::COMPLETE || scode == Sentence::END) { + return process(embedding_vectors, t2eCallback, callback); + } + // Only process, no output + return process(embedding_vectors, t2eCallback, [&](const std::string&, Sentence::Code) { + return false; + }); +} + +bool Dialog::prime(const std::string& str) { + bool r = query(str, Sentence::COMPLETE, __no_response); + + // End with EOS as we want the primer to be self-contained + _last_tok = _ctx->eos_tok(); + + return r; +} + +bool Dialog::save(const std::string& o_name) { + Timer start; + + // Save using session name unless override is provided + std::string name = o_name.empty() ? 
_ctx->name() : o_name; + fs::path save_path = name; + + if (!_n_past) { + __ERROR("dialog-save: {} : nothing to save yet", name); + return false; + } + + __INFO("dialog-save: saving as {} {}", name, save_path.string()); + + if (!fs::exists(save_path) && !fs::create_directories(save_path)) { + __ERROR("dialog-save: {} : failed to create cache directory", name); + return false; + } + + // Save Dialog state + qualla::json j{ + {"n-past", _n_past}, + {"n-prompt", _n_prompt}, + {"n-generated", _n_generated}, + {"n-queries", _n_queries}, + {"last-tok", _last_tok} + }; + { + fs::path p = save_path / "dialog.json"; + std::ofstream f(p); + f << j; + } + + // Save Engines (mandatory) + for (auto& e : _engine) { + if (!e.second->save(name)) { + __ERROR("dialog-save: {} : unable to save {} engine", name, e.first); + return false; + } + } + + // Save Samplers (optional) + for (auto& s : _sampler) { + if (!s.second->save(name)) { + __WARN("dialog-save: {} : unable to save {} sampler", name, s.first); + } + } + + _kpis.save.update(start.elapsed_usec()); + + return true; +} + +bool Dialog::restore(const std::string& o_name) { + Timer start; + + // Restore using session name unless override is provided + std::string name = o_name.empty() ? _ctx->name() : o_name; + fs::path restore_path = name; + + __INFO("dialog-restore: restoring from {} {}", name, restore_path.string()); + + // Try to restore the Dialog state (optional) + // If this fails we reset everything and try to restore the engine. + qualla::json j{}; + { + fs::path p = restore_path / "dialog.json"; + if (fs::exists(p)) { + std::ifstream f(p); + j = qualla::json::parse(f); + } else { + __DEBUG("dialog-restore: {} : internal state not restored", name); + } + } + + using qc = qualla::Config; + _n_past = qc::optional(j, "n-past", 0); + _n_prompt = qc::optional(j, "n-prompt", 0); + _n_generated = qc::optional(j, "n-generated", 0); + _n_queries = qc::optional(j, "n-queries", 1); + _last_tok = qc::optional(j, "last-tok", _ctx->eos_tok()); + + // Restore Engines (mandatory) + for (auto& e : _engine) { + uint32_t n = e.second->restore(name); + if (!n) { + __ERROR("dialog-restore: {} : unable to restore {} engine", name, e.first); + return false; + } + + // Restore n_past from the engine state + if (_n_past && n != _n_past) { + __WARN("dialog-restore: {} : n-past mismatch : {} engine {} intern {}", + name, + e.first, + _n_past, + n); + // Keep the smaller number + _n_past = std::min(n, _n_past); + } else + _n_past = n; + } + + // Restore Samplers (optional) + for (auto& s : _sampler) { + if (!s.second->restore(name)) { + __WARN("dialog-restore: {} : unable to restore {} sampler", name, s.first); + } + } + + _kpis.reset(); + _kpis.restore.update(start.elapsed_usec()); + + return true; +} + +void Dialog::reset() { + __INFO("dialog-reset: {}", _ctx->name()); + + _n_past = 0; + _n_prompt = 0; + _n_generated = 0; + _n_queries = 0; + _last_tok = -1; + _n_previous_prompt = 0; + _n_previous_generated = 0; + + _kpis.reset(); + + // Reset Engines and Samplers + for (auto& e : _engine) + e.second->reset(); + for (auto& s : _sampler) + s.second->reset(); + + State::clear(); +} + +// Dialog KPIs helpers + +// Get latest KPIs +Dialog::KPIs& Dialog::kpis() { + // Update TPS + if (_n_prompt) { + float t = _kpis.prompt.last_usec / _n_prompt; + _kpis.tps.n_prompt = _n_prompt; + _kpis.tps.prompt = 1000000.0 / (t ? 
t : 1000000.0);
+ }
+
+ if (_n_generated) {
+ float t = _kpis.generate.last_usec / _n_generated;
+ _kpis.tps.n_generate = _n_generated;
+ _kpis.tps.generate = 1000000.0 / (t ? t : 1000000.0);
+ }
+
+ // We could synthesize more KPIs from other layers (engine, sampler, etc)
+ return _kpis;
+}
+
+std::string Dialog::KPIs::dump(std::string_view sep) const {
+ return fmt::format(
+ "init:[{}]{}prompt:[{}]{}generate:[{}]{}save:[{}]{}restore:[{}]{} tps-prompt:{:.2f} tps-generate:{:.2f}",
+ init.dump(),
+ sep,
+ prompt.dump(),
+ sep,
+ generate.dump(),
+ sep,
+ save.dump(),
+ sep,
+ restore.dump(),
+ sep,
+ tps.prompt,
+ tps.generate
+ );
+}
+
+void Dialog::KPIs::reset() {
+ init.reset();
+ prompt.reset();
+ generate.reset();
+ save.reset();
+ restore.reset();
+ tps.prompt = 0.0;
+ tps.generate = 0.0;
+}
+
+// Create API
+
+// Dialog registry : type string + creator function
+using Registry = std::unordered_map<std::string, Dialog::Creator>;
+static std::unique_ptr<Registry> registry;
+
+void Dialog::__register(const std::string& type, Creator func) {
+ if (!registry) registry = std::make_unique<Registry>();
+
+ Registry& r = *registry;
+
+ r[type] = func;
+}
+
+std::unique_ptr<Dialog> Dialog::create(
+ std::shared_ptr<Env> env,
+ const std::string& name,
+ const qualla::json& conf
+) {
+
+ using qc = qualla::Config;
+ std::string type = qc::optional(conf, "type", "basic");
+
+ if (!registry) throw std::runtime_error(type + ": dialog not found");
+
+ Registry& r = *registry;
+
+ if (!r.contains(type)) throw std::runtime_error(type + ": dialog not found");
+
+ return std::unique_ptr<Dialog>(r[type](env, name, conf));
+}
+
+std::unique_ptr<Dialog> Dialog::create(
+ std::shared_ptr<Env> env,
+ const std::string& name,
+ std::istream& json_stream
+) {
+
+ return create(env, name, json::parse(json_stream));
+}
+
+std::unique_ptr<Dialog> Dialog::create(
+ std::shared_ptr<Env> env,
+ const std::string& name,
+ const fs::path& json_path
+) {
+
+ if (!fs::exists(json_path))
+ throw std::runtime_error(json_path.string() + ": file does not exist");
+ std::ifstream ifs(json_path);
+ return create(env, name, ifs);
+}
+
+std::vector<std::string> Dialog::list() {
+ std::vector<std::string> v;
+ if (!registry) return v;
+
+ Registry& r = *registry;
+
+ for (const auto& k : r)
+ v.push_back(k.first);
+ if (!r.contains("basic")) v.push_back("basic"); // default type, always available
+ return v;
+}
+
+bool Dialog::applyLoraAdapter(std::string lora_adapter_name, std::string engine_role) {
+ auto& engine = *_engine[engine_role];
+ if (!engine.applyLoraAdapter(lora_adapter_name)) {
+ __WARN("dialog-applyLoraAdapter: failed for {}", lora_adapter_name);
+ return false;
+ }
+ return true;
+}
+
+bool Dialog::applyLoraStrength(std::string tensor_name, float tensor_val, std::string engine_role) {
+ auto& engine = *_engine[engine_role];
+ if (!engine.applyLoraStrength(tensor_name, tensor_val)) {
+ __WARN("dialog-applyLoraStrength: failed for {}", tensor_name);
+ return false;
+ }
+ return true;
+}
+
+} // namespace qualla
diff --git a/Genie/Genie/src/qualla/dialogs/basic.cpp b/Genie/Genie/src/qualla/dialogs/basic.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..43e02cfd22d4b9e482a2c31801c710ab2365e8da
--- /dev/null
+++ b/Genie/Genie/src/qualla/dialogs/basic.cpp
@@ -0,0 +1,421 @@
+//==============================================================================
+//
+// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+// All rights reserved.
+// Confidential and Proprietary - Qualcomm Technologies, Inc.
+// +//============================================================================== + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include + +namespace fs = std::filesystem; + +#define __INFO(__fmt, ...) _env->logger().post(Logger::INFO, fmt::format(__fmt, ##__VA_ARGS__)) +#define __WARN(__fmt, ...) _env->logger().post(Logger::WARN, fmt::format(__fmt, ##__VA_ARGS__)) +#define __ERROR(__fmt, ...) _env->logger().post(Logger::ERROR, fmt::format(__fmt, ##__VA_ARGS__)) +#define __KPIS(__fmt, ...) \ + _env->logger().post(Logger::KPIS, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); }) +#define __DEBUG(__fmt, ...) \ + _env->logger().post(Logger::DEBUG, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); }) +#define __TRACE(__fmt, ...) \ + _env->logger().post(Logger::TRACE, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); }) + +namespace qualla { + +BasicDialog::BasicDialog(std::shared_ptr env, const std::string& name, const json& conf) : Dialog(env, name, conf) { + if (!_engine.contains("primary")) { + State::fatal("\"primary\" engine not present in config!"); + return; + } +} + +bool BasicDialog::processFollowOnGeneration(std::vector& tokens, std::vector& logits, Dialog::Callback callback){ + + auto& sampler = *_sampler["primary"]; + auto& engine = *_engine["primary"]; + + while (true) { + if (State::canceled()) { + callback("", Sentence::END); + break; + } + // This condition is valid for both tokens and embedding + if (_n_past + 1 > _ctx->size()) { + __WARN("Context limit exceeded ({} + 1 > {})", _n_past, _ctx->size()); + callback("", Sentence::END); + break; + } + if (m_inputType == InputType::TOKENS) { + if (!engine.process(tokens, logits)) + return Dialog::abort("engine processing failed", callback); + } else if(m_inputType == InputType::EMBEDDINGS) { + // Convert tokens to embedding for the processing in the engine. + auto embedBufSize = engine.getEmbeddingBufferSize(); + std::vector embedding; + for(auto &token: tokens){ + std::vector curTokenEmbedding(embedBufSize,0); + m_t2eCallback(token, curTokenEmbedding.data(), embedBufSize); + embedding.insert(embedding.end(), curTokenEmbedding.begin(), curTokenEmbedding.end()); + } + if (!engine.process(embedding, {}, logits)) + return Dialog::abort("engine processing failed", callback); + } + else{ + return Dialog::abort("No valid Input Type is used", callback); + } + tokens[0] = _last_tok = sampler.process(logits); + + _n_past++; + _n_generated++; + + if (!engine.updateKV(_n_past)) return Dialog::abort("context size exceeded", callback); + + if (_ctx->is_eos(_last_tok)) { + callback("", Sentence::END); + break; + } + + if (!callback(_tokenizer->decode(tokens), Sentence::CONTINUE)) break; + } + + return true; +} + +bool BasicDialog::process(std::vector& tokens, Dialog::Callback callback) { + // Check for prev failures and bail out early + if (State::failed()) return false; + + Timer start; + + if(m_inputType != InputType::TOKENS) { + __ERROR("Input type for model is not tokens."); + return false; + } + + _gpio_marker->set(); + + // Vector for storing logits. + // Allocated & filled by the engine. 
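+ // (Flattened layout: the engine typically fills one n_vocab-sized row per
+ // scored position, so a single row for an ordinary decode step.)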
+ std::vector logits; + + State::clear(); + + auto& sampler = *_sampler["primary"]; + auto& engine = *_engine["primary"]; + + using FF = Engine::Feature::Flags; + if (engine.supports(FF::DYNAMIC_LOAD)) engine.load(); + + if (_n_past + tokens.size() > _ctx->size()) { + __WARN("Context limit exceeded ({} + {} > {})", _n_past, tokens.size(), _ctx->size()); + callback("", Sentence::END); + return true; + } + + if (!engine.process(tokens, logits, false)) + return Dialog::abort("engine prompt processing failed", callback); + + _n_prompt += tokens.size(); + _n_past += tokens.size(); + + if (!engine.updateKV(_n_past)) return Dialog::abort("context size exceeded", callback); + + tokens[0] = _last_tok = sampler.process(logits); + tokens.resize(1); + + _n_generated++; + + _gpio_marker->set(); + + _kpis.prompt.update(start.elapsed_usec()); + + // Log latest KPIs + _env->logger().post(Logger::KPIS, kpis().dump(" ")); + + start.reset(); + + if (_ctx->is_eos(_last_tok)) { + callback("", Sentence::END); + return true; + } + + if (!callback(_tokenizer->decode(tokens), Sentence::BEGIN)) return true; + + State::busy(true); + + processFollowOnGeneration(tokens, logits, callback); + + State::busy(false); + + _gpio_marker->set(); + _gpio_marker->reset(); + + _kpis.generate.update(start.elapsed_usec()); + + // Log latest KPIs in a single line + _env->logger().post(Logger::KPIS, kpis().dump(" ")); + + return !State::failed(); +} + +bool BasicDialog::processFollowOnGeneration(std::vector& tokens, std::vector& logits, qualla::DialogCallback callback){ + + auto& sampler = *_sampler["primary"]; + auto& engine = *_engine["primary"]; + + while (true) { + if (State::canceled()) { + callback.callBack(nullptr, 0, Sentence::END, tokenizer()); + break; + } + // This condition is valid for both tokens and embedding + if (_n_past + 1 > _ctx->size()) { + __WARN("Context limit exceeded ({} + 1 > {})", _n_past, _ctx->size()); + callback.callBack(nullptr, 0, Sentence::END, tokenizer()); + break; + } + if (m_inputType == InputType::TOKENS) { + if (!engine.process(tokens, logits)) + return Dialog::abort("engine processing failed", callback); + } else if(m_inputType == InputType::EMBEDDINGS) { + // Convert tokens to embedding for the processing in the engine. + auto embedBufSize = engine.getEmbeddingBufferSize(); + std::vector embedding; + for(auto &token: tokens){ + std::vector curTokenEmbedding(embedBufSize,0); + m_t2eCallback(token, curTokenEmbedding.data(), embedBufSize); + embedding.insert(embedding.end(), curTokenEmbedding.begin(), curTokenEmbedding.end()); + } + if (!engine.process(embedding, {}, logits)) + return Dialog::abort("engine processing failed", callback); + } + else{ + return Dialog::abort("No valid Input Type is used", callback); + } + tokens[0] = _last_tok = sampler.process(logits); + + _n_past++; + _n_generated++; + + if (!engine.updateKV(_n_past)) return Dialog::abort("context size exceeded", callback); + + if (_ctx->is_eos(_last_tok)) { + callback.callBack(nullptr, 0, Sentence::END, tokenizer()); + break; + } + + if (!callback.callBack(tokens.data(), tokens.size(), Sentence::CONTINUE, tokenizer())) break; + } + + return true; +} + +bool BasicDialog::process(std::vector& tokens, qualla::DialogCallback callback) { + // Check for prev failures and bail out early + if (State::failed()) return false; + + Timer start; + + if(m_inputType != InputType::TOKENS) { + __ERROR("Input type for model is not tokens."); + return false; + } + + _gpio_marker->set(); + + // Vector for storing logits. 
+ // Allocated & filled by the engine. + std::vector logits; + + State::clear(); + + auto& sampler = *_sampler["primary"]; + auto& engine = *_engine["primary"]; + + using FF = Engine::Feature::Flags; + if (engine.supports(FF::DYNAMIC_LOAD)) engine.load(); + + if (_n_past + tokens.size() > _ctx->size()) { + __WARN("Context limit exceeded ({} + {} > {})", _n_past, tokens.size(), _ctx->size()); + callback.callBack(nullptr, 0, Sentence::END, tokenizer()); + return true; + } + + if (!engine.process(tokens, logits, false)) { + return Dialog::abort("engine prompt processing failed", callback); + } + + _n_prompt += tokens.size(); + _n_past += tokens.size(); + + if (!engine.updateKV(_n_past)) { + return Dialog::abort("context size exceeded", callback); + } + + tokens[0] = _last_tok = sampler.process(logits); + tokens.resize(1); + + _n_generated++; + + _gpio_marker->set(); + + _kpis.prompt.update(start.elapsed_usec()); + + // Log latest KPIs + _env->logger().post(Logger::KPIS, kpis().dump(" ")); + + start.reset(); + + if (_ctx->is_eos(_last_tok)) { + callback.callBack(nullptr, 0, Sentence::END, tokenizer()); + return true; + } + + if (!callback.callBack(tokens.data(), tokens.size(), Sentence::BEGIN, tokenizer())) + return true; + + State::busy(true); + processFollowOnGeneration(tokens, logits, callback); + State::busy(false); + + _gpio_marker->set(); + _gpio_marker->reset(); + + _kpis.generate.update(start.elapsed_usec()); + + // Log latest KPIs in a single line + _env->logger().post(Logger::KPIS, kpis().dump(" ")); + + return !State::failed(); +} + +bool BasicDialog::process( + std::vector& embedding_vectors, + T2ECallback t2eCallback, + Dialog::Callback callback +) { + Timer start; + + if(m_inputType != InputType::EMBEDDINGS) { + __ERROR("Input type for model is not embeddings."); + return false; + } + + // Vector for storing logits. + // Allocated & filled by the engine. + std::vector logits; + + State::clear(); + + _gpio_marker->set(); + + auto& sampler = *_sampler["primary"]; + auto& engine = *_engine["primary"]; + + // Store the t2e callback for reference during follow-on generation. + m_t2eCallback = t2eCallback; + + size_t embedBufSize = engine.getEmbeddingBufferSize(); + + { + std::vector eosEmbedding(embedBufSize, 0.0); + if (m_t2eCallback) { + m_t2eCallback(_ctx->eos(), eosEmbedding.data(), embedBufSize); + } + // For non-autogenerative usecases (where t2eCallback is not supplied), + // the EOS vector is all zero. This is fine for models with proper + // attention masking support, but may degrade accuracy otherwise. 
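+ // When a t2eCallback is supplied, the EOS row is produced by the same
+ // token-to-embedding mapping used for generated tokens, so the engine sees
+ // a consistent representation of end-of-sequence in embedding space.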
+ if (!engine.cacheEosEmbedding(eosEmbedding)) { + __DEBUG("Failed to set the eos token embedding."); + return false; + } + } + + using FF = Engine::Feature::Flags; + if (engine.supports(FF::DYNAMIC_LOAD)) engine.load(); + + size_t curTokenCount = embedding_vectors.size() / embedBufSize; + _env->logger().post(Logger::KPIS, kpis().dump(" ")); + start.reset(); // Don't include preprocessing time + + if (_n_past + curTokenCount > _ctx->size()) { + __WARN("Context limit exceeded ({} + {} > {})", _n_past, curTokenCount, _ctx->size()); + callback("", Sentence::END); + return true; + } + + if (!engine.process(embedding_vectors, {}, logits)) + return Dialog::abort("engine prompt processing failed", callback); + _n_prompt += curTokenCount; + _n_past += curTokenCount; + + std::vector tokens(1, 0); + + if (!engine.updateKV(_n_past)) return Dialog::abort("context size exceeded", callback); + + tokens[0] = _last_tok = sampler.process(logits); + + _n_generated++; + + _gpio_marker->set(); + + _kpis.prompt.update(start.elapsed_usec()); + + // Log latest KPIs + _env->logger().post(Logger::KPIS, kpis().dump(" ")); + + start.reset(); + + if (_ctx->is_eos(_last_tok)) { + callback("", Sentence::END); + return true; + } + + if (!callback(_tokenizer->decode(tokens), Sentence::BEGIN)) { + return true; + } + + if (!m_t2eCallback) { + callback("", Sentence::END); + return true; + } + + State::busy(true); + processFollowOnGeneration(tokens, logits, callback); + State::busy(false); + + _gpio_marker->set(); + _gpio_marker->reset(); + + _kpis.generate.update(start.elapsed_usec()); + // Log latest KPIs in a single line + _env->logger().post(Logger::KPIS, kpis().dump(" ")); + + return !State::failed(); +} + +// Registrator instance +static OnLoad regy([]() { + Dialog::__register( + "basic", + [](std::shared_ptr env, const std::string& name, const json& conf) { + return (Dialog*)new BasicDialog(env, name, conf); + } + ); +}); + +void needBasicDialog() {} + +} // namespace qualla diff --git a/Genie/Genie/src/qualla/dialogs/kv-share.cpp b/Genie/Genie/src/qualla/dialogs/kv-share.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0afef88e66638fe74864c54c171d464f3b39f637 --- /dev/null +++ b/Genie/Genie/src/qualla/dialogs/kv-share.cpp @@ -0,0 +1,359 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All rights reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace fs = std::filesystem; + +#define __INFO(__fmt, ...) _env->logger().post(Logger::INFO, fmt::format(__fmt, ##__VA_ARGS__)) +#define __WARN(__fmt, ...) _env->logger().post(Logger::WARN, fmt::format(__fmt, ##__VA_ARGS__)) +#define __ERROR(__fmt, ...) _env->logger().post(Logger::ERROR, fmt::format(__fmt, ##__VA_ARGS__)) +#define __KPIS(__fmt, ...) \ + _env->logger().post(Logger::KPIS, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); }) +#define __DEBUG(__fmt, ...) \ + _env->logger().post(Logger::DEBUG, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); }) +#define __TRACE(__fmt, ...) 
\ + _env->logger().post(Logger::TRACE, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); }) + +namespace qualla { + + using qc = qualla::Config; + + class KvShareDialog : public Dialog { + public: + KvShareDialog(std::shared_ptr env, const std::string& name, const json& conf) + : Dialog(env, name, conf) {} + + virtual bool process(std::vector& tokens, Dialog::Callback callback) override; + + virtual bool process(std::vector& tokens, DialogCallback callback) override { + return false; + } + + virtual void reset() override; + + bool convertKV(const fs::path& cache_dir); + + }; + + void KvShareDialog::reset() { + __INFO("dialog-reset: {}", _ctx->name()); + + _n_past = 0; + _n_prompt = 0; + _n_generated = 0; + _n_queries = 0; + _last_tok = -1; + + _kpis.reset(); + + // Reset Samplers + for (auto& s : _sampler) + s.second->reset(); + + // Reset Engines + for (auto& e : _engine) { + e.second->reset(); + e.second->unload(); + } + + State::clear(); + } + + bool KvShareDialog::process(std::vector& tokens, Dialog::Callback callback) { + + // Check for prev failures and bail out early + if (State::failed()) return false; + + Timer start; + + // Vector for storing logits. + // Allocated & filled by the engine. + std::vector logits; + + State::clear(); + + auto& sampler = *_sampler["primary"]; + + auto& p_engine = *_engine["primary"]; // prompt + auto& s_engine = *_engine["secondary"]; // generation + + if (_n_past + tokens.size() > _ctx->size()) { + __WARN("Context limit exceeded ({} + {} > {})", _n_past, tokens.size(), _ctx->size()); + callback("", Sentence::END); + return true; + } + + if (!p_engine.process(tokens, logits)) + return Dialog::abort("engine prompt processing failed", callback); + + _n_prompt += tokens.size(); + _n_past += tokens.size(); + + if (!p_engine.updateKV(_n_past)) return Dialog::abort("context size exceeded", callback); + + tokens[0] = _last_tok = sampler.process(logits); + tokens.resize(1); + + _n_generated++; + + _kpis.prompt.update(start.elapsed_usec()); + // Log latest KPIs + _env->logger().post(Logger::KPIS, kpis().dump(" ")); + + if (_ctx->is_eos(_last_tok)) { + callback("", Sentence::END); + return true; + } + + if (!callback(_tokenizer->decode(tokens), Sentence::BEGIN)) return true; + + __DEBUG("dialog: {} : switching engines", _ctx->name()); + { + // Setup cache dir for saving the engine state + std::string cache_name = _ctx->name() + "-kv-share"; + fs::path cache_dir = _env->path().cache / cache_name; + + if (!fs::exists(cache_dir) && !fs::create_directories(cache_dir)) { + __ERROR("dialog: {} : failed to create cache directory {}", + _ctx->name(), + cache_dir.string()); + return Dialog::abort("engine switch failed", callback); + } + + // Save and unload the primary engine + p_engine.save(cache_name); + p_engine.unload(); + + // The purpose is to save the hyperparams + s_engine.save(cache_name); + + convertKV(cache_dir); + + size_t n = s_engine.restore(cache_name); + + if(!fs::remove_all(cache_dir)) { + __WARN("dialog: {} : cache files not closed/dir not found", _ctx->name()); + } + + if (n != _n_past) { + __WARN("dialog: {} : kv size mismatch {} expected {}", _ctx->name(), n, _n_past); + _n_past = n; + } + + s_engine.updateKV(_n_past); + } + + start.reset(); + + State::busy(true); + + while (true) { + if (State::canceled()) { + callback("", Sentence::END); + break; + } + + if (_n_past + tokens.size() > _ctx->size()) { + __WARN("Context limit exceeded ({} + {} > {})", _n_past, tokens.size(), _ctx->size()); + callback("", Sentence::END); + break; + } + if 
(!s_engine.process(tokens, logits)) + return Dialog::abort("secondary engine processing failed", callback); + + tokens[0] = _last_tok = sampler.process(logits); + + _n_past++; + _n_generated++; + + if (!s_engine.updateKV(_n_past)) return Dialog::abort("context size exceeded", callback); + + if (_ctx->is_eos(_last_tok)) { + callback("", Sentence::END); + break; + } + + if (!callback(_tokenizer->decode(tokens), Sentence::CONTINUE)) break; + } + + State::busy(false); + + _kpis.generate.update(start.elapsed_usec()); + + // Log latest KPIs in a single line + _env->logger().post(Logger::KPIS, kpis().dump(" ")); + + return true; + } + + bool KvShareDialog::convertKV(const fs::path& cache_dir) { + Timer start; + + fs::path nsp_cache_path = cache_dir / "kv-cache.primary.qnn-htp"; + fs::path cpu_cache_path = cache_dir / "kv-cache.secondary.qnn-cpu"; + + __DEBUG("kv-convert: begin converting {} to ", nsp_cache_path.string(), cpu_cache_path.string()); + + std::ifstream nsp_fs(nsp_cache_path, std::ios::in | std::ios::binary); + + if (nsp_fs.fail()) { + __ERROR("kv-convert: error reading file {}", nsp_cache_path.string()); + State::error("failed to read primary kv-cache"); + return false; + } + + // Read spec from nsp file + CacheFileSpec nsp_spec; + nsp_fs.read((char*)&nsp_spec, sizeof(nsp_spec)); + if (nsp_spec.magic != 0xC0DE) { + __ERROR("kv-convert: expected 0xC0DE found {:#x}", nsp_spec.magic); + State::error("invalid format of primary kv-cache"); + return false; + } + + // clang-format off + __DEBUG("kv-convert: load {{ num_tensors {}, magic {}, dtype {}, n_heads {}, embed_dim {} update_size {} }}", + nsp_spec.num_tensors, nsp_spec.magic, int(nsp_spec.dtype), nsp_spec.n_heads, nsp_spec.embed_dim, nsp_spec.update_size); + // clang-format on + + std::fstream cpu_fs(cpu_cache_path, std::ios::in | std::ios::out | std::ios::binary); + + if (cpu_fs.fail()) { + // TODO: replace with proper error handling + __ERROR("kv-convert: failed to write {}", cpu_cache_path.string()); + State::error("failed to save secondary kv-cache"); + return false; + } + + CacheFileSpec cpu_spec; + cpu_fs.read((char*)&cpu_spec, sizeof(cpu_spec)); + if (cpu_spec.magic != 0xC0DE) { + __ERROR("kv-convert: expected 0xC0DE found {:#x}", cpu_spec.magic); + State::error("invalid format of secondary kv-cache"); + return false; + } + + // Set the n_tokens processed during prompt processing and the spec write to file + cpu_spec.update_size = nsp_spec.update_size; + cpu_fs.seekp(std::ios::beg); + cpu_fs.write((char*)&cpu_spec, sizeof(cpu_spec)); + + const uint32_t n_layer = nsp_spec.num_tensors / 2; + const uint32_t n_head = nsp_spec.n_heads; + const uint32_t kv_dim = nsp_spec.embed_dim; + const uint32_t n_tok = nsp_spec.update_size; + + const size_t cache_size = n_layer * n_head * kv_dim * n_tok; + + // Read Key/Value Cache + std::vector key_cache(cache_size); + std::vector value_cache(cache_size); + nsp_fs.read((char*)key_cache.data(), cache_size); + nsp_fs.read((char*)value_cache.data(), cache_size); + + // Read Quantization parameters + std::vector key_scales(n_layer); + std::vector value_scales(n_layer); + nsp_fs.read((char*)key_scales.data(), n_layer * sizeof(double)); + nsp_fs.read((char*)value_scales.data(), n_layer * sizeof(double)); + + nsp_fs.close(); + + // Convert and write on cpu_file + // Dequant and transpose caches + const uint32_t layer_size = n_head * kv_dim * n_tok; + const uint32_t head_size = kv_dim * n_tok; + + // Transpose kvdim * n_tok (QNN-HTP K$) -> n_tok * kvdim (QNN-CPU K$) + // For ScopGPT KV$ Format + 
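+ // Dequantization sketch (assuming symmetric uint8 quantization with a
+ // zero-point of 128): dequant(x) = (float(x) - 128) * scale[layer], which
+ // is the (cache[loc] - 128) * scales[i] math in the loops below.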
__DEBUG("kv-convert: dequantizing keys"); + std::vector dequant_keys(cache_size); + for (uint32_t i = 0; i < n_layer; i++) { + for (uint32_t j = 0; j < n_head; j++) { + for (uint32_t k = 0; k < kv_dim; k++) { + for (uint32_t l = 0; l < n_tok; l++) { + // Interleave K$ + // QNN HTP: [0 2 4 ... 126 1 3 5 ... 127] + // QNN CPU: [0 1 2 ... 63 64 65 ... 127] + const uint32_t interleaved_k = + (2 * k < kv_dim) ? 2 * k : 2 * (k - kv_dim / 2) + 1; + + const uint32_t read_loc = i * layer_size + j * head_size + k * n_tok + l; + const uint32_t write_loc = i * layer_size + j * head_size + l * kv_dim + interleaved_k; + + dequant_keys[write_loc] = + (static_cast(key_cache[read_loc]) - 128) * key_scales[i]; + } + } + } + } + + __DEBUG("kv-convert: dequantizing values"); + std::vector dequant_values(cache_size); + for (uint32_t i = 0; i < n_layer; i++) { + for (uint32_t j = 0; j < n_head; j++) { + for (uint32_t l = 0; l < n_tok; l++) { + for (uint32_t k = 0; k < kv_dim; k++) { + const uint32_t read_loc = i * layer_size + j * head_size + l * kv_dim + k; + const uint32_t write_loc = read_loc; + + dequant_values[write_loc] = + (static_cast(value_cache[read_loc]) - 128) * value_scales[i]; + } + } + } + } + + __DEBUG("kv-convert: storing converted KV to file"); + cpu_fs.write((char *)dequant_keys.data(), dequant_keys.size() * sizeof(float)); + cpu_fs.write((char *)dequant_values.data(), dequant_values.size() * sizeof(float)); + + cpu_fs.flush(); + cpu_fs.close(); + + __DEBUG("kv-convert: done converting {} to {} in {} usec", + nsp_cache_path.string(), + cpu_cache_path.string(), + start.elapsed_usec()); + + return true; + + } + +// Registrator instance + static OnLoad regy([]() { + Dialog::__register( + "kv-share", + [](std::shared_ptr env, const std::string& name, const json& conf) { + return (Dialog*)new KvShareDialog(env, name, conf); + } + ); + }); + + void needKvShareDialog() {} + +} // namespace qualla diff --git a/Genie/Genie/src/qualla/dialogs/lhd-dec.cpp b/Genie/Genie/src/qualla/dialogs/lhd-dec.cpp new file mode 100644 index 0000000000000000000000000000000000000000..95c67c2702e92a26c9514987019153a8530cff82 --- /dev/null +++ b/Genie/Genie/src/qualla/dialogs/lhd-dec.cpp @@ -0,0 +1,481 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace fs = std::filesystem; + +#define __INFO(__fmt, ...) _env->logger().post(Logger::INFO, fmt::format(__fmt, ##__VA_ARGS__)) +#define __WARN(__fmt, ...) _env->logger().post(Logger::WARN, fmt::format(__fmt, ##__VA_ARGS__)) +#define __ERROR(__fmt, ...) _env->logger().post(Logger::ERROR, fmt::format(__fmt, ##__VA_ARGS__)) +#define __KPIS(__fmt, ...) \ + _env->logger().post(Logger::KPIS, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); }) +#define __DEBUG(__fmt, ...) \ + _env->logger().post(Logger::DEBUG, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); }) +#define __TRACE(__fmt, ...) 
\ + _env->logger().post(Logger::TRACE, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); }) + +namespace qualla { + +using qc = qualla::Config; + +LhdDecDialog::LhdDecDialog(std::shared_ptr env, const std::string& name, const json& conf) + : Dialog(env, name, conf) { + + _window = qc::optional(conf, "window", 8); + _ngram = qc::optional(conf, "ngram", 3); + _gcap = qc::optional(conf, "gcap", 8); + + _lhd_mode_str = qc::optional(conf, "lhd-update-mode", "ALWAYS_FWD_ONE"); +} + +bool LhdDecDialog::process(std::vector& tokens, Dialog::Callback callback) { + // Check for prev failures and bail out early + if (State::failed()) return false; + + Timer start; + + // Vector for storing logits. + // Allocated & filled by the engine. + std::vector logits; + std::vector resultTokens; + + State::clear(); + + auto& sampler = *_sampler["primary"]; + auto& engine = *_engine["primary"]; + + using FF = Engine::Feature::Flags; + if (engine.supports(FF::DYNAMIC_LOAD)) engine.load(); + + if (_n_past + tokens.size() > _ctx->size()) { + __WARN("Context limit exceeded ({} + {} > {})", _n_past, tokens.size(), _ctx->size()); + callback("", Sentence::END); + return true; + } + + if (!engine.process(tokens, logits, false)) + return Dialog::abort("engine prompt processing failed", callback); + + _n_prompt += tokens.size(); + _n_past += tokens.size(); + + if (!engine.updateKV(_n_past)) return Dialog::abort("context size exceeded", callback); + + std::vector tokens_tmp(1); + tokens_tmp[0] = _last_tok = sampler.process(logits); + resultTokens.push_back(_last_tok); + + _n_generated++; + + _kpis.prompt.update(start.elapsed_usec()); + + // Log latest KPIs + _env->logger().post(Logger::KPIS, kpis().dump(" ")); + + if (_ctx->is_eos(_last_tok)) { + callback("", Sentence::END); + return true; + } + + // Exit condition : Prediction limit reached OR ctx size limit reached + if (!callback(_tokenizer->decode(tokens_tmp), Sentence::BEGIN)) return true; + + State::busy(true); + + // verification branch init + v_branch.resize(_gcap); + + // n-gram pools + const size_t n_vocab = _ctx->n_vocab(); + ngram_container ngrams_pool(n_vocab, _ngram, _gcap); + + // lookahead branch first level init + lhd_branch.resize(_ngram - 1); + lhd_branch_prev.resize(_window); + + for (int j = 0; j < _ngram - 1; j++) { + lhd_branch[j].resize(_window); + + for (int i = 0; i < _window; i++) { + if (j == 0) { + // initialize with the random token from prompt + lhd_branch[j][i] = tokens[1 + rand() % (tokens.size() - 1)]; + } else { + // initialize with a sequence of increasing numbers + lhd_branch[j][i] = 1000 + i; + } + } + } + + // lookahead branch other level init + while (_level_idx < _ngram - 1) { + + batch.clear(); + attention_map.clear(); + + // fill the first token of the first level + batch.push_back(_last_tok); + attention_map.push_back(-1); + lhd_branch[0][0] = _last_tok; + + // fill the remaining WINDOW - 1 tokens for the first level + for (int i = 1; i < _window; i++) { + batch.push_back(lhd_branch[0][i]); + attention_map.push_back(i - 1); + } + + // fill the rest of the levels + for (int j = 1; j < _ngram - 1; j++) { + for (int i = 0; i < _window; i++) { + batch.push_back(lhd_branch[j][i]); + attention_map.push_back((j - 1) * _window + i); + } + } + + // re-init tokens batch + tokens.resize(_window * (_ngram - 1)); + tokens = batch; + + if (_n_past + tokens.size() > _ctx->size()) { + __WARN("Context limit exceeded ({} + {} > {})", _n_past, tokens.size(), _ctx->size()); + callback("", Sentence::END); + break; + } + + size_t n_tok = 
engine.process(tokens, attention_map, logits, true); + if (n_tok != tokens.size()) + return Dialog::abort("engine lookahead branch processing failed", callback); + + for (int i = 0; i < _window; i++) { + size_t sample_tmp_idx = (_level_idx - 1) * _window + i; + // sampler from logits all + std::span span_logits{logits.data(),logits.size()}; + std::span span_tmp = span_logits.subspan(sample_tmp_idx * n_vocab, n_vocab); + int32_t sampled_tmp_token = sampler.process(span_tmp); + lhd_branch[_level_idx][i] = sampled_tmp_token; + } + + _level_idx++; + } + + if (_lhd_mode_str == "FWD_MAX_HIT") + _lhd_update_mode = FWD_MAX_HIT; + else if (_lhd_mode_str == "FWD_LEVEL") + _lhd_update_mode = FWD_LEVEL; + else + _lhd_update_mode = ALWAYS_FWD_ONE; + + start.reset(); + + while (true) { + if (State::canceled()) { + callback("", Sentence::END); + break; + } + // input batch init + { + batch.clear(); + attention_map.clear(); + + // fill the first token of the first level + batch.push_back(_last_tok); + attention_map.push_back(-1); + // lhd_branch[0][0] = _last_tok; + + // fill the remaining WINDOW - 1 tokens for the first level + for (int i = 1; i < _window; i++) { + batch.push_back(lhd_branch[0][i]); + attention_map.push_back(i - 1); + } + + // fill the rest of the levels + for (int j = 1; j < _ngram - 1; j++) { + for (int i = 0; i < _window; i++) { + batch.push_back(lhd_branch[j][i]); + attention_map.push_back((j - 1) * _window + i); + } + } + + // build verification n-grams(branch) + { + const int g_cur = ngrams_pool.cnt[_last_tok]; + + v_branch.resize(g_cur); + // input_token_batch.size = (_window + g_cur) * (_ngram - 1); + tokens.resize((_window + g_cur) * (_ngram - 1)); + for (int g = 0; g < g_cur; g++) { + v_branch[g].active = true; + v_branch[g].tokens.resize(_ngram); + v_branch[g].i_batch.resize(_ngram); + v_branch[g].seq_id = _window + 1 + g; + v_branch[g].i_batch[0] = 0; + v_branch[g].tokens[0] = _last_tok; + } + + for (int j = 0; j < _ngram - 1; j++) { + for (int g = 0; g < g_cur; g++) { + const int idx = _last_tok * (_ngram - 1) * _gcap + g * (_ngram - 1); + const int32_t t = ngrams_pool.tokens[idx + j]; + v_branch[g].tokens[j + 1] = t; + v_branch[g].i_batch[j + 1] = j + 1; + } + } + + for (int g = 0; g < g_cur; g++) { + for (int j = 0; j < _ngram - 1; j++) { + batch.push_back(v_branch[g].tokens[j + 1]); + if (j == 0) + attention_map.push_back(0); + else + attention_map.push_back(batch.size() - 2); + } + } + } + } + + // re-init tokens batch + std::vector selected(attention_map.size(), false); + tokens = batch; + + if (_n_past + tokens.size() > _ctx->size()) { + __WARN("Context limit exceeded ({} + {} > {})", _n_past, tokens.size(), _ctx->size()); + callback("", Sentence::END); + break; + } + + size_t n_tok = engine.process(tokens, attention_map, logits, true); + if (n_tok != tokens.size()) return Dialog::abort("engine gen processing failed", callback); + + // verification branch seq-id + size_t seq_id_best = 0; + // max hit pos + size_t i_batch_best = 0; + + // Lookahead decoding and verification + for (int v = 0; v < _ngram; ++v) { + int i_batch = 0; + + if (v > 0) { + for (int g = 0; g < (int)v_branch.size(); g++) { + // record the best matched seq and pos + if (v_branch[g].active) { + i_batch = v_branch[g].i_batch[v]; + i_batch_best = i_batch; + seq_id_best = v_branch[g].seq_id; + ++_n_accept; + break; + } + } + + if (i_batch == 0) { + break; + } + } + + size_t sample_idx; + if (seq_id_best == 0) + sample_idx = 0; + else + sample_idx = _window * (_ngram - 1) + (seq_id_best - (_window + 1)) 
* (_ngram - 1) + + i_batch - 1; + + //vector selected set + selected[sample_idx] = true; + + // sampler from logits all + std::span span_logits{logits.data(),logits.size()}; + std::span sample_logit = span_logits.subspan(sample_idx * n_vocab, n_vocab); + _last_tok = sampler.process(sample_logit); + + std::vector tokens_tmp(1); + tokens_tmp[0] = _last_tok; + + resultTokens.push_back(_last_tok); + _n_generated++; + _n_past++; + + if (_ctx->is_eos(_last_tok)) break; + + if (!callback(_tokenizer->decode(tokens_tmp), Sentence::CONTINUE)) return true; + + // if verify pass, check the next sample token until verifing failed + for (int g = 0; g < (int)v_branch.size(); g++) { + // update the n-gram active status + if (v_branch[g].active) { + if (v == _ngram - 1) { + v_branch[g].active = false; + } else { + if (_last_tok != v_branch[g].tokens[v + 1]) { + v_branch[g].active = false; + } + } + } + } + + // update lookahead tokens when v=0 OR verify match + { + for (int i = 0; i < _window; i++) { + lhd_branch_prev[i] = lhd_branch[0][i]; + } + + if (v == 0) { + for (int j = 0; j < _ngram - 2; j++) { + lhd_branch[j] = lhd_branch[j + 1]; + } + + // sample from the last level + for (int i = 0; i < _window; i++) { + size_t sample_idx = (_ngram - 2) * _window + i; + std::span sample_logit = + span_logits.subspan(sample_idx * n_vocab, n_vocab); + lhd_branch[_ngram - 2][i] = sampler.process(sample_logit); + } + } else { + if (_lhd_update_mode == FWD_MAX_HIT) { + // update lookahead branch by foward + for (int j = 0; j < _ngram - 1; j++) { + for (int i = 0; i < _window - v; i++) { + lhd_branch[j][i] = lhd_branch[j][i + 1]; + } + } + } else if (_lhd_update_mode == FWD_LEVEL) { + // update lookahead branch by shifting level + for (int j = 0; j < _ngram - 2; j++) { + lhd_branch[j] = lhd_branch[j + 1]; + } + + for (int i = 0; i < _window; i++) { + // init from the previous level + lhd_branch[_ngram - 2][i] = lhd_branch[0][i]; + } + } + } + } + + // update n-grams pool + // only update n-grams pools when v=0 + if (v == 0) { + std::vector ngram(_ngram - 1); + // n-gram pool generation + for (int f = 0; f < _window; ++f) { + const int ft = lhd_branch_prev[f]; // first token of the n-gram + + for (int j = 0; j < _ngram - 1; ++j) { + ngram[j] = lhd_branch[j][f]; + } + + // filter-out repeating n-grams + { + bool is_unique = true; + + for (int k = 0; k < ngrams_pool.cnt[ft]; ++k) { + // caculate the related idx by the first n-gram token + const int idx = ft * (_ngram - 1) * _gcap + k * (_ngram - 1); + + bool is_match = true; + for (int j = 0; j < _ngram - 1; ++j) { + if (ngrams_pool.tokens[idx + j] != ngram[j]) { + is_match = false; + break; + } + } + + // if n-gram match all, discard one of them + if (is_match) { + is_unique = false; + break; + } + } + + if (!is_unique) { + continue; + } + } + + const int head = ngrams_pool.head[ft]; + const int idx = ft * (_ngram - 1) * _gcap + head * (_ngram - 1); + + for (int i = 0; i < _ngram - 1; i++) { + // update the n-gram pool with new n-gram + ngrams_pool.tokens[idx + i] = ngram[i]; + } + + ngrams_pool.cnt[ft] = std::min(_gcap, ngrams_pool.cnt[ft] + 1); + ngrams_pool.head[ft] = (head + 1) % _gcap; + + ngrams_pool.n_total++; + } + } + } + + if (_lhd_update_mode == FWD_MAX_HIT) { + // std::random_device rd; + // std::mt19937 gen(rd()); + // std::uniform_int_distribution<> dis(0, resultTokens.size() - 1); + + // fill lookahead branch + for (int i = 0; i < _ngram - 1; i++) { + for (int j = _window - i_batch_best; j < _window; j++) { + lhd_branch[i][j] = resultTokens[1 + rand() % 
(resultTokens.size() - 1)]; + // lhd_branch[i][j] = resultTokens[dis(gen)]; + // std::cout << "Fill token = " << lhd_branch[i][j] << std::endl; + } + } + } + + // KV cache management + if (!engine.updateKV(_n_past, selected)) + return Dialog::abort("context size exceeded", callback); + + if (_ctx->is_eos(_last_tok)) { + callback("", Sentence::END); + break; + } + } + + State::busy(false); + + _kpis.generate.update(start.elapsed_usec()); + + // Log latest KPIs in a single line + _env->logger().post(Logger::KPIS, kpis().dump(" ")); + std::cout << std::endl << std::endl << std::flush; + __DEBUG("lhd-dec: n_generated = {} ---------- n_accept = {}", _n_generated, _n_accept); + + return !State::failed(); +} + +// Registrator instance +static OnLoad regy([]() { + Dialog::__register( + "lhd-dec", + [](std::shared_ptr env, const std::string& name, const json& conf) { + return (Dialog*)new LhdDecDialog(env, name, conf); + } + ); +}); + +void needLadeDialog() {} + +} // namespace qualla diff --git a/Genie/Genie/src/qualla/dialogs/multistream.cpp b/Genie/Genie/src/qualla/dialogs/multistream.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b9cf614b1e6ae6b63382ad6f60e1d0e1b299c5ec --- /dev/null +++ b/Genie/Genie/src/qualla/dialogs/multistream.cpp @@ -0,0 +1,300 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include + +namespace fs = std::filesystem; + +#define __INFO(__fmt, ...) _env->logger().post(Logger::INFO, fmt::format(__fmt, ##__VA_ARGS__)) +#define __WARN(__fmt, ...) _env->logger().post(Logger::WARN, fmt::format(__fmt, ##__VA_ARGS__)) +#define __ERROR(__fmt, ...) _env->logger().post(Logger::ERROR, fmt::format(__fmt, ##__VA_ARGS__)) +#define __KPIS(__fmt, ...) \ + _env->logger().post(Logger::KPIS, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); }) +#define __DEBUG(__fmt, ...) \ + _env->logger().post(Logger::DEBUG, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); }) +#define __TRACE(__fmt, ...) 
\ + _env->logger().post(Logger::TRACE, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); }) + +namespace qualla { + +bool MultiStreamDialog::processFollowOnGeneration(std::vector>& streams, std::vector& logits, Dialog::Callback callback) { + + auto& sampler = *_sampler["primary"]; + auto& engine = *_engine["primary"]; + + std::vector> attention_mask(_n_streams); + std::vector streamIndices; + + if (streams.size() == 0) { + callback("\n", Sentence::END); + return true; + } + + for (int i = 0; i < streams.size(); i++) { + // Initialize all attention_masks to attend to all previous tokens + attention_mask[i].resize(_n_past, 1); + streamIndices.push_back(i); + } + + State::busy(true); + + while (true) { + if (State::canceled()) break; + + // If this exceeds context length, truncate all streams and return + if (_n_past + streamIndices.size() > _ctx->size()) { + for (auto stream : streamIndices) + callback(_tokenizer->decode(streams[stream]) + "\n", Sentence::CONTINUE); + break; + } + + // Accumulate input tokens from all streams + std::vector multi_tokens(streamIndices.size()); + + for (int i = 0; i < streamIndices.size(); i++) { + multi_tokens[i] = streams[streamIndices[i]].back(); + + // Also add current iteration to the attention_mask + for (auto _mask_row : streamIndices) + // Set to true iff on diagonal, i.e. attend to itself + attention_mask[streamIndices[i]].push_back((streamIndices[i] == _mask_row) ? 1 : 0); + } + + // Concatenate attention_mask for all active streams + std::vector multi_attn_mask; + multi_attn_mask.reserve((_n_past + streamIndices.size()) * streamIndices.size()); + for (auto i : streamIndices) + multi_attn_mask.insert( + multi_attn_mask.end(), + attention_mask[i].begin(), + attention_mask[i].end() + ); + + // __DEBUG("Multi attention mask = {}", multi_attn_mask); + + if (m_inputType == InputType::TOKENS) { + // Process input tokens for all streams in one batch + if (!engine.process(multi_tokens, multi_attn_mask, logits, true)) + return Dialog::abort("engine gen processing failed", callback); + } else if (m_inputType == InputType::EMBEDDINGS) { + // Accumulate input embeddings from all streams + auto embedBufSize = engine.getEmbeddingBufferSize(); + std::vector multi_embeddings; + + for (auto token : multi_tokens) { + // Convert tokens to embedding for the processing in the engine. 
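+ // Each active stream contributes one embedBufSize-sized row; rows are
+ // concatenated in stream order so the batch layout matches multi_tokens.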
+ std::vector curTokenEmbedding(embedBufSize, 0); + m_t2eCallback(token, curTokenEmbedding.data(), embedBufSize); + multi_embeddings.insert(multi_embeddings.end(), curTokenEmbedding.begin(), curTokenEmbedding.end()); + } + + // Process input tokens for all streams in one batch + if (!engine.process(multi_embeddings, multi_attn_mask, logits, true)) + return Dialog::abort("engine gen processing failed", callback); + } + + // Process all logits independently + std::span logit_span = std::span{logits.data(),logits.size()}; + for (int i = 0; i < streamIndices.size(); i++) { + _last_tok = sampler.process(logit_span.subspan(i * _vocab, _vocab)); + streams[streamIndices[i]].push_back(_last_tok); + } + + _n_past += streamIndices.size(); + _n_generated += streamIndices.size(); + + if (!engine.updateKV(_n_past)) return Dialog::abort("context size exceeded", callback); + + for (auto it = streamIndices.begin(); it != streamIndices.end();) { + int32_t stream = *it; + if (_ctx->is_eos(streams[stream].back())) { + callback(_tokenizer->decode(streams[stream]) + "\n", Sentence::CONTINUE); + it = streamIndices.erase(it); + } else { + ++it; + } + } + + if (streamIndices.size() == 0) break; + } + callback("\n", Sentence::END); + + State::busy(false); + + return true; +} + +bool MultiStreamDialog::process(std::vector& tokens, Dialog::Callback callback) { + // Check for prev failures and bail out early + if (State::failed()) return false; + + Timer start; + + if(m_inputType != InputType::TOKENS) { + __ERROR("Input type for model is not tokens."); + return false; + } + + // Vector for storing logits. + // Allocated & filled by the engine. + std::vector logits; + + State::clear(); + + auto& engine = *_engine["primary"]; + + using FF = Engine::Feature::Flags; + if (engine.supports(FF::DYNAMIC_LOAD)) engine.load(); + + if (_n_past + tokens.size() > _ctx->size()) { + __WARN("Context limit exceeded ({} + {} > {})", _n_past, tokens.size(), _ctx->size()); + callback("", Sentence::END); + return true; + } + + if (!engine.process(tokens, logits, false)) + return Dialog::abort("engine prompt processing failed", callback); + + _n_prompt += tokens.size(); + _n_past += tokens.size(); + + _prompt_len = _n_past; + + if (!engine.updateKV(_n_past)) return Dialog::abort("context size exceeded", callback); + + std::vector> streams; + getTopK(logits, streams, _n_streams, _p_threshold, callback); + + _n_generated += streams.size(); + _kpis.prompt.update(start.elapsed_usec()); + + // Log latest KPIs + _env->logger().post(Logger::KPIS, kpis().dump(" ")); + + start.reset(); + + bool status = processFollowOnGeneration(streams, logits, callback); + + _kpis.generate.update(start.elapsed_usec()); + + // Log latest KPIs in a single line + _env->logger().post(Logger::KPIS, kpis().dump(" ")); + + return status; +} + +bool MultiStreamDialog::process( + std::vector& embedding_vectors, + T2ECallback t2eCallback, + Dialog::Callback callback +) { + // Check for prev failures and bail out early + if (State::failed()) return false; + + Timer start; + + if(m_inputType != InputType::EMBEDDINGS) { + __ERROR("Input type for model is not embeddings."); + return false; + } + + // Vector for storing logits. + // Allocated & filled by the engine. + std::vector logits; + + State::clear(); + + auto& sampler = *_sampler["primary"]; + auto& engine = *_engine["primary"]; + + // Store the t2e callback for reference during follow-on generation. 
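+ // Follow-on generation must map every sampled token back to an embedding,
+ // so the callback has to remain valid beyond this call.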
+ m_t2eCallback = t2eCallback; + + size_t embedBufSize = engine.getEmbeddingBufferSize(); + + { + std::vector eosEmbedding(embedBufSize, 0.0); + if (m_t2eCallback) { + m_t2eCallback(_ctx->eos(), eosEmbedding.data(), embedBufSize); + } + // For non-autogenerative usecases (where t2eCallback is not supplied), + // the EOS vector is all zero. This is fine for models with proper + // attention masking support, but may degrade accuracy otherwise. + if (!engine.cacheEosEmbedding(eosEmbedding)) { + __DEBUG("Failed to set the eos token embedding."); + return false; + } + } + + using FF = Engine::Feature::Flags; + if (engine.supports(FF::DYNAMIC_LOAD)) engine.load(); + + size_t curTokenCount = embedding_vectors.size() / embedBufSize; + if (_n_past + curTokenCount > _ctx->size()) { + __WARN("Context limit exceeded ({} + {} > {})", _n_past, curTokenCount, _ctx->size()); + callback("", Sentence::END); + return true; + } + + if (!engine.process(embedding_vectors, {}, logits)) + return Dialog::abort("engine prompt processing failed", callback); + + _n_prompt += curTokenCount; + _n_past += curTokenCount; + + _prompt_len = _n_past; + + if (!engine.updateKV(_n_past)) return Dialog::abort("context size exceeded", callback); + + std::vector> streams; + getTopK(logits, streams, _n_streams, _p_threshold, callback); + + _n_generated += streams.size(); + _kpis.prompt.update(start.elapsed_usec()); + + // Log latest KPIs + _env->logger().post(Logger::KPIS, kpis().dump(" ")); + + start.reset(); + + bool status = processFollowOnGeneration(streams, logits, callback); + + _kpis.generate.update(start.elapsed_usec()); + + // Log latest KPIs in a single line + _env->logger().post(Logger::KPIS, kpis().dump(" ")); + + return status; +} + +// Registrator instance +static OnLoad regy([]() { + Dialog::__register( + "multistream", + [](std::shared_ptr env, const std::string& name, const json& conf) { + return (Dialog*)new MultiStreamDialog(env, name, conf); + } + ); +}); + +void needMultistreamDialog() {} + +} // namespace qualla diff --git a/Genie/Genie/src/qualla/dialogs/spec-dec.cpp b/Genie/Genie/src/qualla/dialogs/spec-dec.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b5e4e5023f4ff657e413c8af04858fd1e3701b02 --- /dev/null +++ b/Genie/Genie/src/qualla/dialogs/spec-dec.cpp @@ -0,0 +1,458 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All rights reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace fs = std::filesystem; + +#define __INFO(__fmt, ...) _env->logger().post(Logger::INFO, fmt::format(__fmt, ##__VA_ARGS__)) +#define __WARN(__fmt, ...) _env->logger().post(Logger::WARN, fmt::format(__fmt, ##__VA_ARGS__)) +#define __ERROR(__fmt, ...) _env->logger().post(Logger::ERROR, fmt::format(__fmt, ##__VA_ARGS__)) +#define __KPIS(__fmt, ...) \ + _env->logger().post(Logger::KPIS, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); }) +#define __DEBUG(__fmt, ...) \ + _env->logger().post(Logger::DEBUG, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); }) +#define __TRACE(__fmt, ...) 
\ + _env->logger().post(Logger::TRACE, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); }) + +namespace qualla { + +using qc = qualla::Config; + +class SpecDecDialog : public Dialog { + public: + SpecDecDialog(std::shared_ptr env, const std::string& name, const json& conf); + + virtual bool process(std::vector& tokens, Dialog::Callback callback) override; + + virtual bool process(std::vector& tokens, DialogCallback callback) override { + return false; + } + + private: + int32_t _draft_len; // Number of draft tokens + bool _parallel; // Enable parallel processing (where possible) + + Sampler& _d_sampler; // Draft sampler + Sampler& _t_sampler; // Target sampler + + // Token acceptor, called for each accepted token. + // Returns true to continue, false to stop + using Acceptor = std::function; + + // Rejection sampling. + // Returns number of accepted tokens + size_t rejectionSampling( + std::span tokens, + std::span target_logits, + std::span draft_probs, + Acceptor accept + ); + + int32_t sampleFromModifiedDist(std::span src0_dst, std::span src1); +}; + +SpecDecDialog::SpecDecDialog(std::shared_ptr env, const std::string& name, const json& conf) + : Dialog(env, name, conf), + _d_sampler(_sampler.contains("draft") ? *_sampler["draft"] : *_sampler["target"]), + _t_sampler(*_sampler["target"]) { + + _draft_len = qc::optional(conf, "draft-len", 3); + _parallel = qc::optional(conf, "parallel", false); + + // Check all underlying components for correct types an config + // If something is not right we set our error state that can be checked later + + if (!_sampler.contains("target")) { + State::fatal("\"target\" sampler not present in config!"); + return; + } + + if (!_engine.contains("target")) { + State::fatal("\"target\" engine not present in config!"); + return; + } + if (!_engine.contains("draft")) { + State::fatal("\"draft\" engine not present in config!"); + return; + } +} + +int32_t SpecDecDialog::sampleFromModifiedDist(std::span src0_dst, std::span src1) { + // [max(prob_target[x] - prob_draft[x], 0.f) for all x in vocab] + size_t size = src0_dst.size(); + + if (_t_sampler.gumbel()) { + // Avoid going in the denormal zone. + float tiny = 1.1754943508222875e-38; + +#pragma clang loop vectorize(enable) unroll_count(4) + for (size_t i = 0U; i < size; i++) { + float p_src0 = std::exp(src0_dst[i]); + float p_src1 = std::exp(src1[i]); + src0_dst[i] = std::log(std::max(tiny, p_src0 - p_src1)); + } + + // NOTE: The output logps_target is unnormalized since we use Gumbel trick. + // If we use standard multinomial sampling, normalization should be added. + + } else { + float sum = 0.0; // Unlikely to overflow (?) 
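+ // Standard speculative-decoding correction: build the residual
+ // distribution p'(x) = max(p_target(x) - p_draft(x), 0), then renormalize
+ // it before sampling.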
+#pragma clang loop vectorize(enable) unroll_count(4) + for (size_t i = 0U; i < size; i++) { + float num = std::max(0.f, src0_dst[i] - src1[i]); + sum += num; + src0_dst[i] = num; + } + // Normalize +#pragma clang loop vectorize(enable) unroll_count(4) + for (size_t i = 0U; i < size; i++) { + src0_dst[i] /= sum; + } + } + + if (_t_sampler.greedy()) return argmax(src0_dst); + + if (_t_sampler.gumbel()) return sampleUsingGumbelMax(src0_dst, _t_sampler.rng()); + + // Skipping softmax since the probs are already normalized + return sampleFromProbs(src0_dst, _t_sampler.rng()); +} + +size_t SpecDecDialog::rejectionSampling( + std::span tokens, + std::span target_logits, + std::span draft_probs, + Acceptor accept +) { + const size_t n_vocab = _ctx->n_vocab(); + const size_t n_tok = tokens.size(); + + assert(tokens.size() == draft_probs.size() / n_vocab); + assert(target_logits.size() == draft_probs.size() + n_vocab); + + // Rejection sampling: + // For each token in the n_tok tokens sampled from the draft model: + // 1. Determine the probability of that token being accepted by the target model + // 2. Accept the token with probability = prob_target[tok] / prob_draft[tok] (clamped to [0, 1]) + // 3. If the token is rejected, resample a new token from the following distribution: + // [max(prob_target[x] - prob_draft[x], 0.f) for all x in vocab] + int32_t t_tok; + size_t n_accepted = 0; + + std::vector target_probs; + + for (int32_t i = 0; i < n_tok; i++) { + int32_t d_tok = tokens[i]; + + std::span t_span = target_logits.subspan(i * n_vocab, n_vocab); + + if (_t_sampler.greedy()) { + t_tok = _t_sampler.process(t_span); + if (t_tok != d_tok) { + // Reject + break; + } + } else { + target_probs.clear(); + t_tok = _t_sampler.process(t_span, target_probs, false); // only probs, no token + + // Acceptance threshold + double threshold; + float prob_draft = draft_probs[i * n_vocab + d_tok]; + float prob_target = target_probs[d_tok]; + + if (_t_sampler.gumbel()) { + threshold = std::exp(double(prob_target) - double(prob_draft)); + } else { + threshold = double(prob_target) / double(prob_draft); + } + + double r = sampleFromUniform(_t_sampler.rng()); + if (r > threshold) { + // Reject + break; + } + } + // Accepted! + ++n_accepted; + if (!accept(d_tok)) return n_accepted; + } + + // Sample an extra token either from the target distribution or the modified distribution + if (n_accepted == n_tok) { + t_tok = _t_sampler.process(target_logits.subspan(n_tok * n_vocab)); + } else if (!_t_sampler.greedy()) { + // Resample from modified distribution. + t_tok = sampleFromModifiedDist( + std::span{target_probs.data(),target_probs.size()}, draft_probs.subspan(n_accepted * n_vocab, n_vocab) + ); + } // for greedy, t_tok should be already valid from the loop above + + ++n_accepted; + accept(t_tok); + + return n_accepted; +} + +bool SpecDecDialog::process(std::vector& tokens, Dialog::Callback callback) { + + // Check for prev failures and bail out early + if (State::failed()) return false; + + Timer start; + + const size_t n_vocab = _ctx->n_vocab(); + + // Vector for storing logits. + // Allocated & filled by the engine. + std::vector t_logits; + std::vector d_logits; + + bool keep_generating = true; + + // A buffer for tokens to be decoded (one at a time, per the Middleware's request) + std::vector decode_buf(1, 0); + + // Decode new token. 
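+ // This lambda is the Acceptor handed to rejectionSampling(): it forwards
+ // each accepted token through the middleware callback and stops at EOS.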
+    // Return true to continue generation, and false otherwise
+    auto decode_token = [&](int32_t t) {
+        decode_buf[0] = _last_tok = t;
+
+        if (_ctx->is_eos(t)) {
+            keep_generating = false;
+            callback("", Sentence::END);
+        } else {
+            keep_generating = callback(_tokenizer->decode(decode_buf), Sentence::CONTINUE);
+        }
+
+        return keep_generating;
+    };
+
+    State::clear();
+
+    auto& t_engine = *_engine["target"];
+    auto& d_engine = *_engine["draft"];
+
+    if (_n_past + tokens.size() > _ctx->size()) {
+        __WARN("Context limit exceeded ({} + {} > {})", _n_past, tokens.size(), _ctx->size());
+        callback("", Sentence::END);
+        return true;
+    }
+
+    // Step 0: Process the prompt on both the target and draft models.
+    bool d_pmpt, t_pmpt;
+    if (_parallel) {
+        std::thread dt([&]() { d_pmpt = d_engine.process(tokens, d_logits, false); });
+        std::thread tt([&]() { t_pmpt = t_engine.process(tokens, t_logits, false); });
+        dt.join();
+        tt.join();
+    } else {
+        d_pmpt = d_engine.process(tokens, d_logits, false);
+        t_pmpt = t_engine.process(tokens, t_logits, false);
+    }
+
+    if (!d_pmpt) return Dialog::abort("draft engine prompt processing failed", callback);
+    if (!t_pmpt) return Dialog::abort("target engine prompt processing failed", callback);
+
+    // KV state update
+    _n_prompt += tokens.size();
+    _n_past += tokens.size();
+
+    if (!t_engine.updateKV(_n_past)) return Dialog::abort("target context size exceeded", callback);
+    if (!d_engine.updateKV(_n_past)) return Dialog::abort("draft context size exceeded", callback);
+
+    // Sample one token from the target.
+    _last_tok = _t_sampler.process(t_logits);
+
+    _kpis.prompt.update(start.elapsed_usec());
+
+    // Log latest KPIs
+    _env->logger().post(Logger::KPIS, kpis().dump(" "));
+
+    if (!decode_token(_last_tok)) return true;
+
+    // Done with the prompt, start generating
+    start.reset();
+    State::busy(true);
+
+    // Buffers for all the tokens that need to be considered for each iteration
+    std::vector<int32_t> toks_to_target(_draft_len + 1);
+    std::vector<int32_t> toks_to_draft(2);
+
+    // Buffer for all the probability distributions from the draft sampler
+    std::vector<float> d_probs(n_vocab * _draft_len);
+
+    toks_to_target.assign(1, _last_tok);
+    toks_to_draft.assign(1, _last_tok);
+
+    // For keeping track of the number of tokens that were accepted in each iteration.
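+    // Index k counts the iterations in which exactly k + 1 tokens were accepted.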
+    std::vector<size_t> n_accepted_counts(_draft_len + 1, 0);
+
+    // Draft n_past, either in sync with n_past or one token behind (accepted-all)
+    size_t d_n_past = _n_past;
+
+    while (!State::canceled() && keep_generating) {
+        // Step 1: Use the draft model to decode draft_len (aka gamma) tokens, and accumulate probabilities
+        d_probs.clear();
+
+        for (int32_t i = 0; i < _draft_len; i++) {
+            if (d_n_past + toks_to_draft.size() > _ctx->size()) {
+                __WARN("Context limit exceeded ({} + {} > {})",
+                       d_n_past,
+                       toks_to_draft.size(),
+                       _ctx->size());
+                _kpis.generate.update(start.elapsed_usec());
+
+                // Log latest KPIs in a single line
+                _env->logger().post(Logger::KPIS, kpis().dump(" "));
+                callback("", Sentence::END);
+                return true;
+            }
+
+            if (!d_engine.process(toks_to_draft, d_logits))
+                return Dialog::abort("draft engine gen processing failed", callback);
+
+            d_n_past += toks_to_draft.size();
+
+            if (!d_engine.updateKV(d_n_past))
+                return Dialog::abort("draft context size exceeded", callback);
+
+            int32_t token = _d_sampler.process(d_logits, d_probs);
+            toks_to_draft.assign(1, token);
+            toks_to_target.push_back(token);
+
+            if (_ctx->is_eos(token)) break;
+        }
+
+        // Step 2: run the target model on the draft tokens
+        if (_n_past + toks_to_target.size() > _ctx->size()) {
+            __WARN("Context limit exceeded ({} + {} > {})",
+                   _n_past,
+                   toks_to_target.size(),
+                   _ctx->size());
+            callback("", Sentence::END);
+            _kpis.generate.update(start.elapsed_usec());
+
+            // Log latest KPIs in a single line
+            _env->logger().post(Logger::KPIS, kpis().dump(" "));
+            return true;
+        }
+
+        std::vector<int32_t> attention_map(toks_to_target.size());
+        std::iota(attention_map.begin(), attention_map.end(), -1);
+        size_t n_tok_t =
+            t_engine.process(toks_to_target, attention_map, t_logits, true /* all logits */);
+        if (n_tok_t != toks_to_target.size())
+            return Dialog::abort("target engine gen processing failed", callback);
+
+        // Step 3: accept or reject draft tokens
+        size_t n_accepted = rejectionSampling(
+            std::span{toks_to_target.data(), toks_to_target.size()}.subspan(1),
+            std::span{t_logits.data(), t_logits.size()},
+            std::span{d_probs.data(), d_probs.size()},
+            decode_token
+        );
+
+        _n_generated += n_accepted;
+        _n_past += n_accepted;
+
+        // Update stats
+        n_accepted_counts[n_accepted - 1]++;
+
+        // Accepted all?
+        if (n_accepted == _draft_len + 1) {
+            // Grab the last 2 tokens
+            toks_to_draft.assign({toks_to_target[_draft_len], _last_tok});
+            d_n_past = _n_past - 1;
+        } else {
+            // Grab only the last token
+            toks_to_draft.assign(1, _last_tok);
+            d_n_past = _n_past;
+        }
+
+        toks_to_target.assign(1, _last_tok);
+
+        __DEBUG("spec-dec: draft_len {} n_generated {} n_accepted {} n_past {}",
+                _draft_len,
+                _n_generated,
+                n_accepted,
+                _n_past);
+
+        std::vector<bool> selected(attention_map.size(), false);
+        selected[0] = true;  // the first token is always selected
+        auto last_sel = 0;
+        for (int i = n_accepted - 1; i != 0; i = attention_map[i]) {
+            selected[i] = true;
+            last_sel = i > last_sel ?
i : last_sel; + } + selected.resize(last_sel + 1); // trim away rejected tokens + + // Step 4: commit accepted tokens to kv-caches + if (!t_engine.updateKV(_n_past, selected)) + return Dialog::abort("target context size exceeded", callback); + if (!d_engine.updateKV(d_n_past)) + return Dialog::abort("draft context size exceeded", callback); + } + + if (d_n_past != _n_past) { + // The draft engine needs to process one last token to catch up + toks_to_draft.resize(1); + if (!d_engine.process(toks_to_draft)) + return Dialog::abort("draft engine gen processing failed", callback); + if (!d_engine.updateKV(_n_past)) + return Dialog::abort("draft context size exceeded", callback); + } + + State::busy(false); + + _kpis.generate.update(start.elapsed_usec()); + + // Log latest KPIs in a single line + _env->logger().post(Logger::KPIS, kpis().dump(" ")); + __KPIS("spec-dec: accepted counts: {}", n_accepted_counts); + + return true; +} + +// Registrator instance +static OnLoad regy([]() { + Dialog::__register( + "spec-dec", + [](std::shared_ptr env, const std::string& name, const json& conf) { + return (Dialog*)new SpecDecDialog(env, name, conf); + } + ); +}); + +// Register spec-dec sampler for compatibility +static OnLoad sampler_regy([]() { + Sampler::__register("spec-dec", [](Context& ctx, const json& conf) { + return (Sampler*)new BasicSampler(ctx, conf); + }); +}); + +void needSpdDialog() {} + +} // namespace qualla diff --git a/Genie/Genie/src/qualla/dialogs/ssd-q1.cpp b/Genie/Genie/src/qualla/dialogs/ssd-q1.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1e0048f8125451ad27f546e9f0f4210f79e09468 --- /dev/null +++ b/Genie/Genie/src/qualla/dialogs/ssd-q1.cpp @@ -0,0 +1,1046 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All rights reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace fs = std::filesystem; + +#define __INFO(__fmt, ...) _env->logger().post(Logger::INFO, fmt::format(__fmt, ##__VA_ARGS__)) +#define __WARN(__fmt, ...) _env->logger().post(Logger::WARN, fmt::format(__fmt, ##__VA_ARGS__)) +#define __ERROR(__fmt, ...) _env->logger().post(Logger::ERROR, fmt::format(__fmt, ##__VA_ARGS__)) +#define __KPIS(__fmt, ...) \ + _env->logger().post(Logger::KPIS, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); }) +#define __DEBUG(__fmt, ...) \ + _env->logger().post(Logger::DEBUG, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); }) +#define __TRACE(__fmt, ...) 
\
+    _env->logger().post(Logger::TRACE, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
+
+namespace qualla {
+
+using qc = qualla::Config;
+using Logits = std::span<float>;
+
+class SelfSpecDecDialog : public Dialog {
+    enum { VERSION = 1 };
+
+  public:
+    SelfSpecDecDialog(std::shared_ptr<Env> env, const std::string& name, const json& conf);
+
+    virtual bool process(std::vector<int32_t>& tokens, Dialog::Callback callback) override;
+    virtual bool process(std::vector<float>& embedding_vectors, Dialog::T2ECallback t2eCallback, Dialog::Callback callback) override;
+    virtual void reset() override;
+
+    virtual bool process(std::vector<int32_t>& tokens, DialogCallback callback) override {
+        return false;
+    }
+
+    virtual bool save(const std::string& name) override;
+    virtual bool restore(const std::string& name) override;
+
+  private:
+    Sampler& _t_sampler;
+
+    int32_t _vocab;
+
+    std::string _kv_prefix_name{"forecast-prefix"};
+
+    // AR8
+    size_t _draft{1};
+    std::vector<int32_t> _branches{3};
+
+    size_t _forecast_prefix{16};
+    size_t _forecast_token_offset{32000};
+    size_t _forecast_token_count{4};
+
+    // Multistream parameters
+    int32_t _n_streams;
+    float _p_threshold;
+
+    InputType m_inputType{InputType::UNKNOWN};
+
+    bool processFollowOnGeneration(std::vector<int32_t>& tokens, std::vector<float>& logits, Dialog::Callback callback);
+    // Multistream
+    bool processFollowOnGeneration(std::vector<std::vector<int32_t>>& streams, std::vector<float>& logits, Dialog::Callback callback);
+
+    /*
+       Helper function for combining masks for SSD multistream.
+
+       @param mask The attention mask to be tiled.
+       @param streamIndices Indices of streams. The tiling count is equal to the size of this vector.
+       @param pastMap A vector of stream indices for masking all past tokens after the prompt.
+       @param prefixOffset Offset where KV prefix masking begins in each tile.
+       @param finalMask A mask that combines all of the independent masks such that
+                        they can be executed in the same inference.
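+
+       For example, with two streams each per-stream mask lands on the block
+       diagonal of finalMask; in addition, every row attends to the KV prefix
+       (where permitted), the shared prompt, and the past tokens of its own
+       stream.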
+ */ + void tileAttentionMask(const std::vector& mask, const std::vector streamIndices, const std::vector& pastMap, const size_t prefixOffset, std::vector& finalMask); + + std::vector gen_attention_map() const; + auto get_len_flat_sample_tree() const; + auto gen_forecast_tokens(int repeat) const; + + // Sampling and verification + std::vector build_sample_tree( + int32_t last_token, + Logits logits, + const std::vector& indices + ); + std::tuple, std::vector> verify_and_select_longest( + std::span sample_tree, + Logits logits + ); + std::vector sample_to_draft(Logits logits, size_t index, size_t count) { + const auto thislogit = logits.subspan(index * _vocab, _vocab); + IndexedLogits logit(thislogit, _t_sampler.rng()); + logit.topK(count); + return logit.indices; + } + int32_t sample_to_verify(Logits logits, size_t index) { + const auto thislogit = logits.subspan(index * _vocab, _vocab); + if (_t_sampler.greedy()) { + return argmax(thislogit); + } + auto token = _t_sampler.process(thislogit); + return token; + } +}; + +SelfSpecDecDialog::SelfSpecDecDialog( + std::shared_ptr env, + const std::string& name, + const json& conf +) + : Dialog(env, name, conf), _t_sampler(*_sampler["primary"]) { + + auto ssd_version = qc::optional(conf, "ssd-version", 0); + if (ssd_version > SelfSpecDecDialog::VERSION) __WARN("newer ssd-version in config!"); + + _vocab = _ctx->n_vocab(); + + _branches = qc::optional(conf, "branches", _branches); + _draft = _branches.size(); + + _forecast_prefix = qc::optional(conf, "forecast-prefix", _forecast_prefix); + _forecast_token_count = qc::optional(conf, "forecast-token-count", _forecast_token_count); + _forecast_token_offset = _vocab; + + _kv_prefix_name = qc::optional(conf, "forecast-prefix-name", _kv_prefix_name); + + _n_streams = qc::optional(conf, "n-streams", 1); + _p_threshold = qc::optional(conf, "p-threshold", 0.0); + + if (!_engine.contains("primary")) { + State::fatal("\"primary\" engine not present in config!"); + return; + } + + //Get Input Type from the engine + m_inputType = _engine["primary"]->getInputType(); + // Load KV prefix + Timer timer; + size_t n_restored_prefix = _engine["primary"]->restore(_kv_prefix_name); + if (n_restored_prefix != _forecast_prefix) { + // clang-format off + throw std::runtime_error( fmt::format( "SSD : Loaded {} KV$ from {} but expected {} KV$", + n_restored_prefix, _kv_prefix_name, _forecast_prefix ) ); + // clang-format on + } + _n_past = _forecast_prefix; + _kpis.restore.update(timer.elapsed_usec()); +} + +auto SelfSpecDecDialog::get_len_flat_sample_tree() const { + size_t len_flat_sample_tree = 1; + size_t last_tokens = 1; + for (int i = 0; i < _draft; ++i) { + len_flat_sample_tree += last_tokens * _branches[i]; + last_tokens = last_tokens * _branches[i]; + } + return len_flat_sample_tree; +} + +auto SelfSpecDecDialog::gen_forecast_tokens(int repeat) const { + std::vector forecast_tokens(_draft, 0); + std::iota(forecast_tokens.begin(), forecast_tokens.end(), _forecast_token_offset); + + std::vector ret; + for (auto i = 0; i < repeat; ++i) + ret.insert(ret.end(), forecast_tokens.begin(), forecast_tokens.end()); + return ret; +} + +std::vector SelfSpecDecDialog::gen_attention_map() const { + auto len_flat_sample_tree = get_len_flat_sample_tree(); + std::vector attention_map(len_flat_sample_tree + len_flat_sample_tree * _draft, -1); + + auto build_verify_tree = [&attention_map, + this](auto self, int parent_begin, int parent_end, int level) { + if (level == _draft) return; + auto current = parent_end; + for (auto parent = 
parent_begin; parent < parent_end; parent += 1) { + for (auto child = current; child < current + _branches[level]; child += 1) + attention_map[child] = parent; + current += _branches[level]; + } + self(self, parent_end, current, level + 1); + }; + + auto build_forecast_tree = [&attention_map, this](int parent_begin, int parent_end) { + auto current = parent_end; + for (auto parent = parent_begin; parent < parent_end; parent += 1) { + for (auto child = current, current_parent = parent; child < current + _draft; + child += 1) { + attention_map[child] = current_parent; + current_parent = child; + } + current += _draft; + } + }; + + build_verify_tree(build_verify_tree, 0, 1, 0); + build_forecast_tree(0, len_flat_sample_tree); + return attention_map; +} + +std::vector SelfSpecDecDialog::build_sample_tree( + int32_t last_token, + Logits logits, + const std::vector& indices +) { + std::vector tree = {last_token}; + for (auto draft = 0, repeat = 1; draft < _draft; ++draft) { + auto samples = sample_to_draft(logits, indices[draft], _branches[draft]); + for (auto i = 0; i < repeat; ++i) { + tree.insert(tree.end(), samples.begin(), samples.end()); + } + repeat *= _branches[draft]; + } + return tree; +} + +std::tuple, std::vector> SelfSpecDecDialog::verify_and_select_longest( + std::span sample_tree, + Logits logits +) { + std::vector> accepted_all = {{sample_to_verify(logits, 0)}}; + std::vector> node_ids_all = {{0}}; + + std::vector draft_offset(_draft, 0); + draft_offset[0] = 1; + for (int32_t i = 1, draft_count = _branches[0]; i < _draft; ++i) { + draft_offset[i] = draft_offset[i - 1] + draft_count; + draft_count = draft_count * _branches[i]; + } + + size_t longest = 0, longest_size = 1; + auto verify_recursive = [&](auto self, + std::vector accepted, + std::vector node_ids, + int draft, + int offset_in_draft) -> void { + auto target = accepted.back(); + auto branch_base = draft_offset[draft] + offset_in_draft; + for (auto branch = 0; branch < _branches[draft]; ++branch) { + auto ndx_node = branch_base + branch; + if (!_ctx->is_eos(target) && target == sample_tree[ndx_node]) { + auto sample_accepted = sample_to_verify(logits, ndx_node); + accepted_all.push_back(accepted); + accepted_all.back().push_back(sample_accepted); + node_ids_all.push_back(node_ids); + node_ids_all.back().push_back(ndx_node); + if (node_ids_all.back().size() > longest_size) { + longest = node_ids_all.size() - 1; + longest_size = node_ids_all.back().size(); + } + if (draft + 1 < _draft) + self(self, + accepted_all.back(), + node_ids_all.back(), + draft + 1, + (offset_in_draft + branch) * _branches[draft + 1]); + } + } + }; + verify_recursive(verify_recursive, accepted_all.back(), node_ids_all.back(), 0, 0); + return {accepted_all[longest], node_ids_all[longest]}; +} + +void SelfSpecDecDialog::tileAttentionMask(const std::vector& mask, const std::vector streamIndices, const std::vector& pastMap, const size_t prefixOffset, std::vector& tiledMask) { + + const size_t sampleTreeLen = get_len_flat_sample_tree(); + const size_t pastMapLen = pastMap.size(); + const int posVal = 1, negVal = 0; + + const size_t maskSize = mask.size(); + const size_t numTokens = maskSize * streamIndices.size(); + + const size_t rowLength = _n_past + numTokens; + tiledMask.resize(numTokens * rowLength); + + for (int maskIdx = 0; maskIdx < streamIndices.size(); maskIdx++) { + // Number of rows to skip to reach the current tile. 
+ const size_t tileOffset = maskIdx * maskSize; + int32_t* const tileStart = &tiledMask[tileOffset*rowLength + tileOffset + _n_past]; + for (int i = 0; i < maskSize; i++) { + // Pointer to the start of row i of the current mask + int32_t* rowPtr = &tiledMask[(tileOffset + i)*rowLength]; + // Skip kv-prefix attention for rows without speculative tokens. + const int prefixFillVal = (i < prefixOffset) ? negVal : posVal; + std::fill_n(rowPtr, _forecast_prefix, prefixFillVal); + rowPtr += _forecast_prefix; + // Always attend to prompt. + std::fill_n(rowPtr, _n_prompt, posVal); + rowPtr += _n_prompt; + + // Fill in the past valid tokens for this stream. + for (const size_t& pastIdx : pastMap) { + *rowPtr = (pastIdx == streamIndices[maskIdx]) ? posVal : negVal; + rowPtr++; + } + + // Clear the rest of the row. It will mostly consist of 0's. + std::fill_n(rowPtr, rowLength - _n_prompt - _forecast_prefix - pastMapLen, negVal); + // Move to the correct tile. + rowPtr += tileOffset; + // Translate the mask. + const auto tokenId = mask[i]; + if (tokenId > -1) { + std::copy_n(tileStart + (tokenId * rowLength), tokenId + 1, rowPtr); + } + // Always attend to self. + rowPtr[i] = posVal; + } + } +} + +// Takes a vector of tokens and produces a vector of embeddings via the provided T2E callback. +static inline void convertTokensToEmbeddings(std::vector& tokens, + std::vector& embeddings, + size_t embeddingBufferSize, + Dialog::T2ECallback t2eCallback) { + for(auto &token : tokens){ + std::vector embedding(embeddingBufferSize,0); + t2eCallback(token, embedding.data(), embeddingBufferSize); + embeddings.insert(embeddings.end(), embedding.begin(), embedding.end()); + } +} + +bool SelfSpecDecDialog::processFollowOnGeneration(std::vector& tokens, std::vector& logits, Dialog::Callback callback){ + + // Handles the printing of the subsequent generated tokens + bool keep_generating = true; + const size_t context = _ctx->n_ctx(); + + std::vector decode_buf( + 1, 0 + ); // A buffer for tokens to be decoded (one at a time, per the Middleware's request) + auto decode_token = [&](int32_t t) { + if (!keep_generating) return; + // Decode new token. 
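+        // (i.e. run it through the tokenizer and forward the text to the user callback).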
+ // Return true to continue generation, and false otherwise + decode_buf[0] = _last_tok = t; + ++_n_generated; + if (_ctx->is_eos(t)) { + keep_generating = false; + callback("", Sentence::END); + } else { + keep_generating = callback(_tokenizer->decode(decode_buf), Sentence::CONTINUE); + } + return; + }; + // set decode_buf from prompt processing + decode_buf[0] = _last_tok; + + auto& engine = *_engine["primary"]; + + auto update_kv = [&engine, &callback, this](size_t past, const std::vector& selected) { + if (!engine.updateKV(past, selected)) + return Dialog::abort("context size exceeded", callback); + return true; + }; + + + // prepare the next inference + std::vector indices(_draft, 0); + std::iota(indices.begin(), indices.end(), 1); + tokens = build_sample_tree(sample_to_verify(std::span{logits.data(),logits.size()}, 0), std::span{logits.data(),logits.size()}, indices); + decode_token(tokens[0]); + + // Prepare constant options for next inferences + const auto len_flat_sample_tree = get_len_flat_sample_tree(); + const auto forecast_tokens = gen_forecast_tokens(len_flat_sample_tree); + const auto attention_map = gen_attention_map(); + + engine.set({{"kv-prefix-offset", len_flat_sample_tree}}); + + std::vector accepted_counts(_draft + 1, 0); + std::vector selected(attention_map.size(), false); + + while (!State::canceled() && keep_generating) { + + // Append forecast tokens + tokens.insert(tokens.end(), forecast_tokens.begin(), forecast_tokens.end()); + + if (_n_past + tokens.size() > _ctx->size()) { + __WARN("Context limit exceeded ({} + {} > {})", _n_past, tokens.size(), _ctx->size()); + callback("", Sentence::END); + break; + } + + size_t n_tok_t = 0; + + // Bifurcate based on embedding as input or token as input + if (m_inputType == InputType::TOKENS) + n_tok_t = engine.process(tokens, attention_map, logits, true /* all logits */); + else if (m_inputType == InputType::EMBEDDINGS) { + // Convert tokens to embedding for the processing in the engine. 
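+            // Each token is mapped to a dense vector through the user-supplied T2E callback.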
+ auto embedBufSize = engine.getEmbeddingBufferSize(); + std::vector embedding; + for(auto &token: tokens){ + std::vector curTokenEmbedding(embedBufSize,0); + m_t2eCallback(token, curTokenEmbedding.data(), embedBufSize); + embedding.insert(embedding.end(), curTokenEmbedding.begin(), curTokenEmbedding.end()); + } + n_tok_t = engine.process(embedding, attention_map, logits, true /* all logits */); + } else { + return Dialog::abort("No valid Input Type is used", callback); + } + if (n_tok_t != tokens.size()) return Dialog::abort("engine processing failed", callback); + + // Accept tokens + auto [accepted_tokens, accepted_ids] = verify_and_select_longest(std::span{tokens.data(),tokens.size()}, + std::span{logits.data(),logits.size()}); + + // Commit accepted tokens to kv-caches + selected.resize(accepted_ids.back() + 1); // trim away rejected tokens + std::fill(selected.begin(), selected.end(), false); + for (auto id : accepted_ids) + selected[id] = true; + accepted_counts[accepted_tokens.size() - 1] += 1; + _n_past += accepted_tokens.size(); + update_kv(_n_past, selected); + + // Decode tokens + std::for_each(accepted_tokens.begin(), accepted_tokens.end(), decode_token); + + // Prepare new tokens + auto next_draft_offset = len_flat_sample_tree + accepted_ids.back() * _draft; + std::iota(indices.begin(), indices.end(), next_draft_offset); + tokens = build_sample_tree(accepted_tokens.back(), std::span{logits.data(),logits.size()}, indices); + } + + State::busy(false); + + auto total_iteration = std::accumulate(accepted_counts.begin(), accepted_counts.end(), 0); + auto accept_rate = + float(_n_generated - 1) / total_iteration; // -1: exclude first generated token + __KPIS("SSD{{draft:{}, branch:{}, greedy:{}}}: accepted counts: {}, accept rate = {} tokens/iteration", + _draft, + _branches, + _t_sampler.greedy(), + accepted_counts, + accept_rate); + + return true; +} + +// Multistream AR generation +bool SelfSpecDecDialog::processFollowOnGeneration(std::vector>& streams, std::vector& logits, Dialog::Callback callback) { + + auto& sampler = *_sampler["primary"]; + auto& engine = *_engine["primary"]; + + auto update_kv = [&engine, &callback, this](size_t past, const std::vector& selected) { + if (!engine.updateKV(past, selected)) + return Dialog::abort("context size exceeded", callback); + return true; + }; + + std::vector streamIndices(streams.size()); + std::vector past_map(streams.size()); + + std::iota(streamIndices.begin(), streamIndices.end(), 0); + // Since the first inference is done separately, it is + // expected that each stream already has 1 valid AR token. 
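+    // past_map therefore starts with one entry per stream, indexed 0..streams.size()-1.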
+ std::iota(past_map.begin(), past_map.end(), 0); + + bool keep_generating = true; + const size_t context = _ctx->n_ctx(); + + if (streams.size() == 0) { + callback("\n", Sentence::END); + return true; + } + + // Prepare constant options for next inferences + const auto len_flat_sample_tree = get_len_flat_sample_tree(); + const auto forecast_tokens = gen_forecast_tokens(len_flat_sample_tree); + const auto attention_map = gen_attention_map(); + + std::vector> draftStreams(streams.size()); + + for (int i = 0; i < streams.size(); i++) { + // prepare the next inference + std::vector indices(_draft, 0); + std::iota(indices.begin(), indices.end(), 1); + draftStreams[i] = build_sample_tree(sample_to_verify(std::span{logits.data(),logits.size()}, i*(1+_draft)), std::span{logits.data(),logits.size()}, indices); + streams[i].push_back(draftStreams[i][0]); + + } + + std::vector multi_attn_mask; + + std::vector accepted_counts(_draft + 1, 0); + + engine.set({{"kv-prefix-offset", len_flat_sample_tree}}); + + State::busy(true); + + while (true) { + if (State::canceled()) break; + + // If this exceeds context length, truncate all streams and return + if (_n_past + streamIndices.size() > _ctx->size()) { + for (auto stream : streamIndices) + callback(_tokenizer->decode(streams[stream]) + "\n", Sentence::CONTINUE); + break; + } + + // Accumulate input tokens from all streams + std::vector multi_tokens; + for (auto streamIdx : streamIndices) { + multi_tokens.insert(multi_tokens.end(), draftStreams[streamIdx].begin(), draftStreams[streamIdx].end()); + multi_tokens.insert(multi_tokens.end(), forecast_tokens.begin(), forecast_tokens.end()); + } + + if (_n_past + multi_tokens.size() > _ctx->size()) { + __WARN("Context limit exceeded ({} + {} > {})", _n_past, multi_tokens.size(), _ctx->size()); + callback("", Sentence::END); + break; + } + + tileAttentionMask(attention_map, streamIndices, past_map, len_flat_sample_tree, multi_attn_mask); + + size_t n_tok_t = 0; + + if (m_inputType == InputType::TOKENS) { + // Process input tokens for all streams in one batch + n_tok_t = engine.process(multi_tokens, multi_attn_mask, logits, true); + } else if (m_inputType == InputType::EMBEDDINGS) { + // Accumulate input embeddings from all streams + auto embedBufSize = engine.getEmbeddingBufferSize(); + std::vector multi_embeddings; + + convertTokensToEmbeddings(multi_tokens, multi_embeddings, embedBufSize, m_t2eCallback); + + // Process input tokens for all streams in one batch + n_tok_t = engine.process(multi_embeddings, multi_attn_mask, logits, true); + } + if (n_tok_t != multi_tokens.size()) return Dialog::abort("engine processing failed", callback); + + std::vector all_selected; + + // Process all logits independently + std::span logit_span = std::span{logits.data(),logits.size()}; + std::span token_span = std::span{multi_tokens.data(), multi_tokens.size()}; + for (int i = 0; i < streamIndices.size(); i++) { + const size_t streamIdx = streamIndices[i]; + std::vector& stream = streams[streamIdx]; + + const size_t tileStride = draftStreams[streamIdx].size() + forecast_tokens.size(); + + std::span tiled_logits = logit_span.subspan(i * tileStride * _vocab, _vocab); + + // Accept tokens + auto [accepted_tokens, accepted_ids] = verify_and_select_longest(token_span.subspan(i * tileStride, tileStride), + tiled_logits); + + // Commit accepted tokens to kv-caches + std::vector selected(tileStride, false); + for (auto id : accepted_ids) { + selected[id] = true; + past_map.push_back(streamIdx); + } + 
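+            // Append this stream's tile of "keep" flags; the batched KV update
+            // below commits all streams in one call.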
all_selected.insert(all_selected.end(), selected.begin(), selected.end()); + accepted_counts[accepted_tokens.size() - 1] += 1; + _n_past += accepted_tokens.size(); + + // Decode tokens + stream.insert(stream.end(), accepted_tokens.begin(), accepted_tokens.end()); + _n_generated += accepted_tokens.size(); + + // Prepare new tokens + std::vector indices(_draft, 0); + auto next_draft_offset = len_flat_sample_tree + accepted_ids.back() * _draft; + std::iota(indices.begin(), indices.end(), next_draft_offset); + draftStreams[streamIdx] = build_sample_tree(accepted_tokens.back(), tiled_logits, indices); + } + + update_kv(_n_past, all_selected); + for (auto it = streamIndices.begin(); it != streamIndices.end();) { + int32_t stream = *it; + if (_ctx->is_eos(streams[stream].back())) { + callback(_tokenizer->decode(streams[stream]) + "\n", Sentence::CONTINUE); + it = streamIndices.erase(it); + } else { + ++it; + } + } + + if (streamIndices.size() == 0) break; + } + callback("\n", Sentence::END); + + State::busy(false); + + auto total_iteration = std::accumulate(accepted_counts.begin(), accepted_counts.end(), 0); + auto accept_rate = + float(_n_generated - 1) / total_iteration; // -1: exclude first generated token + __KPIS("SSD{{draft:{}, branch:{}, greedy:{}}}: accepted counts: {}, accept rate = {} tokens/iteration", + _draft, + _branches, + _t_sampler.greedy(), + accepted_counts, + accept_rate); + + return true; +} + +// Handle prompt processing and generation will be done processFollowOnGeneration +// Pass t2e callback using setter and remove as an argument. call setter from the base query function of dialog + +bool SelfSpecDecDialog::process(std::vector& embedding, + T2ECallback t2eCallback, + Dialog::Callback callback ){ + + // Check for prev failures and bail out early + if (State::failed()) return false; + + if(m_inputType != InputType::EMBEDDINGS) { + __ERROR("Input type for model is not embeddings."); + return false; + } + + Timer start; + State::clear(); + + std::vector logits; + auto& engine = *_engine["primary"]; + + auto update_kv = [&engine, &callback, this](size_t past, const std::vector& selected) { + if (!engine.updateKV(past, selected)) + return Dialog::abort("context size exceeded", callback); + return true; + }; + + // Store the t2e callback for reference during follow-on generation. 
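+    // Follow-on decode steps must convert freshly sampled tokens back into
+    // input embeddings, so the callback is kept on the dialog.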
+    m_t2eCallback = t2eCallback;
+
+    auto embedBufSize = engine.getEmbeddingBufferSize();
+
+    {
+        std::vector<float> eosEmbedding(embedBufSize, 0.0f);
+        if (m_t2eCallback) {
+            m_t2eCallback(_ctx->eos(), eosEmbedding.data(), embedBufSize);
+        }
+        if (!engine.cacheEosEmbedding(eosEmbedding)) {
+            __DEBUG("Failed to set the eos token embedding.");
+            return false;
+        }
+    }
+
+    using FF = Engine::Feature::Flags;
+    if (engine.supports(FF::DYNAMIC_LOAD)) engine.load();
+
+    _env->logger().post(Logger::KPIS, kpis().dump(" "));
+    start.reset();
+
+    engine.set({{"kv-prefix-skip", _forecast_prefix}});
+
+    std::vector<int32_t> tokens(1, 0);
+
+    // Process prompt
+    // Get the number of tokens in the input
+    size_t curTokensCount = embedding.size() / embedBufSize;
+
+    if (curTokensCount * embedBufSize != embedding.size()) {
+        size_t expectedLength = (curTokensCount + (embedding.size() % embedBufSize != 0)) * embedBufSize;
+        __DEBUG("Input has the wrong length: expected {} but found {}.", expectedLength, embedding.size());
+        return Dialog::abort("input size is not a multiple of the embedding length", callback);
+    }
+
+    _n_prompt += curTokensCount;
+
+    std::vector<int32_t> attention_map(curTokensCount);
+    std::iota(attention_map.begin(), attention_map.end(), -1);
+
+    engine.set({{"kv-prefix-offset", curTokensCount}});  // Do not attend to the prefix
+
+    if (_n_past + curTokensCount > _ctx->size()) {
+        __WARN("Context limit exceeded ({} + {} > {})", _n_past, curTokensCount, _ctx->size());
+        callback("", Sentence::END);
+        return true;
+    }
+
+    if (!engine.process(embedding, attention_map, logits, false))
+        return Dialog::abort("engine prompt processing failed", callback);  // TODO: make this message more generic.
+    _n_past += curTokensCount;
+    update_kv(_n_past, {});
+
+    bool status = true;
+    if (_n_streams <= 1) {
+        tokens[0] = sample_to_verify(std::span{logits.data(), logits.size()}, 0);
+
+        // Decode the first token.
+ _last_tok = tokens[0]; + if (_ctx->is_eos(_last_tok)) { + callback("", Sentence::END); + return true; + } + + if (!callback(_tokenizer->decode(tokens), Sentence::BEGIN)) return true; + //decode_token(tokens[0]); + + if (!m_t2eCallback) { + callback("", Sentence::END); + return true; + } + + // Mark TTFT + _kpis.prompt.update(start.elapsed_usec()); + start.reset(); + State::busy(true); + + // Initial inference for self-speculative decoding pipeline with forecast tokens and prefix + // process separately because logits are required for these tokens + for (int i = 0; i < _draft; ++i) + tokens.push_back(_forecast_token_offset + i); + + attention_map.resize(tokens.size()); + std::iota(attention_map.begin(), attention_map.end(), -1); + engine.set({{"kv-prefix-offset", 1}}); // Prevent the last token from attending + + if (_n_past + tokens.size() > _ctx->size()) { + __WARN("Context limit exceeded ({} + {} > {})", _n_past, tokens.size(), _ctx->size()); + callback("", Sentence::END); + return true; + } + + // Convert tokens to embeddings + // reset embedding vector to make space for the next runs + embedding.clear(); + convertTokensToEmbeddings(tokens, embedding, embedBufSize, m_t2eCallback); + + if (!engine.process(embedding, attention_map, logits, true)) + return Dialog::abort("initial inference for SSD pipeline failed", callback); + + _n_past += 1; + update_kv(_n_past, {}); + + // Use existing as much as possible + status = processFollowOnGeneration(tokens, logits, callback); + } else { + std::vector> streams; + getTopK(logits, streams, _n_streams, _p_threshold, callback); + + if (!m_t2eCallback) { + for (auto& stream : streams) { + if (!callback(_tokenizer->decode(stream) + "\n", Sentence::BEGIN)) return true; + } + callback("", Sentence::END); + return true; + } + + // Mark TTFT + _kpis.prompt.update(start.elapsed_usec()); + start.reset(); + State::busy(true); + + if (streams.size() == 0) { + callback("\n", Sentence::END); + return true; + } + + // Initial inference for self-speculative decoding pipeline with forecast tokens and prefix + // process separately because logits are required for these tokens + attention_map.resize(1 + _draft); + std::iota(attention_map.begin(), attention_map.end(), -1); + + std::vector stream_indices(streams.size()); + std::iota(stream_indices.begin(), stream_indices.end(), 0); + + std::vector multi_attn_mask; + std::vector past_map; + const size_t kvPrefixOffset = 1; + + tileAttentionMask(attention_map, stream_indices, past_map, kvPrefixOffset, multi_attn_mask); + + // Accumulate input tokens from all streams + std::vector multi_tokens; + + multi_tokens.reserve(streams.size() * (1 + _draft)); + for (int i = 0; i < streams.size(); i++) { + multi_tokens.insert(multi_tokens.end(), streams[i].begin(), streams[i].end()); + for (int i = 0; i < _draft; ++i) { + multi_tokens.push_back(_forecast_token_offset + i); + } + } + + // Convert tokens to embeddings + // reset embedding vector to make space for the next runs + embedding.clear(); + convertTokensToEmbeddings(multi_tokens, embedding, embedBufSize, m_t2eCallback); + + if (_n_past + multi_tokens.size() > _ctx->size()) { + __WARN("Context limit exceeded ({} + {} > {})", _n_past, multi_tokens.size(), _ctx->size()); + callback("", Sentence::END); + return true; + } + + if (!engine.process(embedding, multi_attn_mask, logits, true)) + return Dialog::abort("initial inference for SSD pipeline failed", callback); + + std::vector selected(multi_tokens.size(), false); + for (int i = 0; i < multi_tokens.size(); i+=(_draft+1)) { 
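+        // Keep only the first (sampled) token of each stream tile; the _draft
+        // forecast placeholders that follow it are not committed to the KV cache.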
+ selected[i] = true; + } + + _n_past += streams.size(); + update_kv(_n_past, selected); + + status = processFollowOnGeneration(streams, logits, callback); + } + + _kpis.generate.update(start.elapsed_usec()); + _env->logger().post(Logger::KPIS, kpis().dump(" ")); + start.reset(); + + return status; +} + +bool SelfSpecDecDialog::process(std::vector& tokens, Dialog::Callback callback) { + + // Check for prev failures and bail out early + if (State::failed()) return false; + + Timer start; + + if(m_inputType != InputType::TOKENS) { + __ERROR("Input type for model is not tokens."); + return false; + } + + State::clear(); + + std::vector logits; + auto& engine = *_engine["primary"]; + + auto update_kv = [&engine, &callback, this](size_t past, const std::vector& selected) { + if (!engine.updateKV(past, selected)) + return Dialog::abort("context size exceeded", callback); + return true; + }; + + using FF = Engine::Feature::Flags; + if (engine.supports(FF::DYNAMIC_LOAD)) engine.load(); + + _env->logger().post(Logger::KPIS, kpis().dump(" ")); + start.reset(); + + engine.set({{"kv-prefix-skip", _forecast_prefix}}); + + std::vector attention_map(tokens.size()); + std::iota(attention_map.begin(), attention_map.end(), -1); + + // Process prompt + _n_prompt += tokens.size(); + engine.set({{"kv-prefix-offset", tokens.size()}}); // Do not attend prefix + + if (_n_past + tokens.size() > _ctx->size()) { + __WARN("Context limit exceeded ({} + {} > {})", _n_past, tokens.size(), _ctx->size()); + callback("", Sentence::END); + return true; + } + + if (!engine.process(tokens, attention_map, logits, false)) + return Dialog::abort("engine prompt processing failed", callback); + _n_past += tokens.size(); + update_kv(_n_past, {}); + + bool status = true; + if (_n_streams <= 1) { + tokens[0] = sample_to_verify(std::span{logits.data(),logits.size()}, 0); + tokens.resize(1); + + // Decode the first token. 
+ _last_tok = tokens[0]; + if (_ctx->is_eos(_last_tok)) { + callback("", Sentence::END); + return true; + } + + if (!callback(_tokenizer->decode(tokens), Sentence::BEGIN)) return true; + // decode_token(tokens[0]); + + // Mark TTFT + _kpis.prompt.update(start.elapsed_usec()); + start.reset(); + State::busy(true); + + // Initial inference for self-speculative decoding pipeline with forecast tokens and prefix + // process separately because logits are required for these tokens + for (int i = 0; i < _draft; ++i) + tokens.push_back(_forecast_token_offset + i); + + attention_map.resize(tokens.size()); + std::iota(attention_map.begin(), attention_map.end(), -1); + engine.set({{"kv-prefix-offset", 1}}); // Prevent the last token from attending + + if (_n_past + tokens.size() > _ctx->size()) { + __WARN("Context limit exceeded ({} + {} > {})", _n_past, tokens.size(), _ctx->size()); + callback("", Sentence::END); + return true; + } + + if (!engine.process(tokens, attention_map, logits, true)) + return Dialog::abort("initial inference for SSD pipeline failed", callback); + + _n_past += 1; + update_kv(_n_past, {}); + + status = processFollowOnGeneration(tokens, logits, callback); + } else { + std::vector> streams; + getTopK(logits, streams, _n_streams, _p_threshold, callback); + + // Mark TTFT + _kpis.prompt.update(start.elapsed_usec()); + start.reset(); + State::busy(true); + + if (streams.size() == 0) { + callback("\n", Sentence::END); + return true; + } + + // Initial inference for self-speculative decoding pipeline with forecast tokens and prefix + // process separately because logits are required for these tokens + attention_map.resize(1 + _draft); + std::iota(attention_map.begin(), attention_map.end(), -1); + + std::vector stream_indices(streams.size()); + std::iota(stream_indices.begin(), stream_indices.end(), 0); + + std::vector multi_attn_mask; + std::vector past_map; + const size_t kvPrefixOffset = 1; + + tileAttentionMask(attention_map, stream_indices, past_map, kvPrefixOffset, multi_attn_mask); + + // Accumulate input tokens from all streams + std::vector multi_tokens; + + multi_tokens.reserve(streams.size() * (1 + _draft)); + for (int i = 0; i < streams.size(); i++) { + multi_tokens.insert(multi_tokens.end(), streams[i].begin(), streams[i].end()); + for (int i = 0; i < _draft; ++i) { + multi_tokens.push_back(_forecast_token_offset + i); + } + } + + if (_n_past + multi_tokens.size() > _ctx->size()) { + __WARN("Context limit exceeded ({} + {} > {})", _n_past, multi_tokens.size(), _ctx->size()); + callback("", Sentence::END); + return true; + } + + if (!engine.process(multi_tokens, multi_attn_mask, logits, true)) + return Dialog::abort("initial inference for SSD pipeline failed", callback); + + std::vector selected(multi_tokens.size(), false); + for (int i = 0; i < multi_tokens.size(); i+=(_draft+1)) { + selected[i] = true; + } + + _n_past += streams.size(); + update_kv(_n_past, selected); + + status = processFollowOnGeneration(streams, logits, callback); + } + + _kpis.generate.update(start.elapsed_usec()); + _env->logger().post(Logger::KPIS, kpis().dump(" ")); + start.reset(); + + return status; +} + +void SelfSpecDecDialog::reset() { + Dialog::reset(); + _n_past = _forecast_prefix; + size_t n_restored_prefix = _engine["primary"]->restore(_kv_prefix_name); + if (n_restored_prefix != _forecast_prefix) { + // clang-format off + throw std::runtime_error( fmt::format( "SSD : Loaded {} KV$ from {} but expected {} KV$", + n_restored_prefix, _kv_prefix_name, _forecast_prefix ) ); + // clang-format 
on + } +} + +bool SelfSpecDecDialog::save(const std::string& name) { + if (_n_streams > 1) { + throw std::runtime_error("Save is unsupported for multistream dialogs."); + } + return Dialog::save(name); +} + +bool SelfSpecDecDialog::restore(const std::string& name) { + if (_n_streams > 1) { + throw std::runtime_error("Restore is unsupported for multistream dialogs."); + } + return Dialog::restore(name); +} + +// Registrator instance +static OnLoad regy([]() { + Dialog::__register( + "ssd-q1", + [](std::shared_ptr env, const std::string& name, const json& conf) { + return (Dialog*)new SelfSpecDecDialog(env, name, conf); + } + ); +}); + +// Register ssd sampler for compatibility +static OnLoad sampler_regy([]() { + Sampler::__register("basic", [](Context& ctx, const json& conf) { + return (Sampler*)new BasicSampler(ctx, conf); + }); +}); + +void needSsdDialog() {} + +} // namespace qualla diff --git a/Genie/Genie/src/qualla/embedding.cpp b/Genie/Genie/src/qualla/embedding.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9426e5715fa5050b54870222a55692c862e082ab --- /dev/null +++ b/Genie/Genie/src/qualla/embedding.cpp @@ -0,0 +1,190 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include +#include + +namespace fs = std::filesystem; + +namespace qualla { + +Embedding::Embedding(std::shared_ptr env, const std::string& name, const qualla::json& json) + : _name(name), _env(env) { + Timer start; + + _env->logger().debug(fmt::format("embedding-new: {} config {}", name, json.dump())); + + using qc = qualla::Config; + + // Parse prompt config + const qualla::json& pmt_conf = qc::optional(json, "prompt", {}); + _tags = qc::optional>(pmt_conf, "tags", {"", ""}); + + // Create the context first + _ctx = Context::create(*_env, name, qc::optional(json, "context", {})); + + // Create Tokenizer + fs::path tok_path = _env->path().models / qc::mandatory(json, "tokenizer"); + _tokenizer = Tokenizer::create(*_ctx, tok_path); + + // Create Engine + const qualla::json& eng_conf = qc::mandatory(json, "engine"); + _engine = Engine::create(*_ctx, eng_conf); + + // Truncation of input to context + _input_truncation = qc::optional(json, "truncate-input", false); + + using FF = Engine::Feature::Flags; + if (!_engine->supports(FF::OUTPUT_EMBEDDINGS)) + throw std::runtime_error("engine must output embeddings"); + + _kpis.init.update(start.elapsed_usec()); +} + +Embedding::~Embedding() {} + +bool Embedding::process(std::vector& tokens, std::vector& output) { + Timer start; + + State::clear(); + + size_t n = _engine->process(tokens, output, false); + if (!n) { + State::error("engine prompt processing failed"); + return false; + } + + _n_prompt += tokens.size(); + + // Clean the buffer before using + _output_dimensions.clear(); + + uint64_t output_size = 1; + // push number of tokens present in the result. 
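+    // The reported output shape is therefore [n_tokens, n_embd].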
+    _output_dimensions.push_back(n);
+    // push back the dimension of each embedding
+    _output_dimensions.push_back(_ctx->n_embd());
+
+    output_size = n * _ctx->n_embd();
+
+    output.resize(output_size);
+
+    _kpis.prompt.update(start.elapsed_usec());
+
+    // Log latest KPIs in a single line
+    _env->logger().post(Logger::KPIS, kpis().dump(" "));
+
+    return true;
+}
+
+bool Embedding::query(const std::string& str, std::vector<float>& output) {
+    std::string p_str;            // prompt string
+    std::vector<int32_t> p_vec;   // prompt tokens
+
+    p_vec.reserve(_ctx->n_ctx());
+
+    p_str = _tags[0] + str + _tags[1];
+
+    _env->logger().debug(fmt::format("embedding-query: {}", str));
+    _env->logger().debug(fmt::format("embedding-prompt: {}", p_str));
+
+    _n_queries++;
+
+    _tokenizer->encode(p_str, p_vec);
+
+    _env->logger().debug(fmt::format("embedding-tokens: {}", p_vec));
+
+    if (p_vec.size() > _ctx->n_ctx()) {  // Do not allow the input to exceed the context.
+        if (!_input_truncation) {
+            throw std::runtime_error("Input exceeds the context of the model.");
+        }
+        else {
+            p_vec.resize(_ctx->n_ctx());
+        }
+    }
+
+    return process(p_vec, output);
+}
+
+// Embedding KPIs helpers
+
+
+void Embedding::output_dimensions(std::vector<uint64_t>& outputDimensions) {
+    outputDimensions = _output_dimensions;
+}
+
+// Get latest KPIs
+Embedding::KPIs& Embedding::kpis() {
+    // Update TPS
+    if (_n_prompt) {
+        float t = _kpis.prompt.total_usec / _n_prompt;
+        _kpis.tps.prompt = 1000000.0 / (t ? t : 1000000.0);
+    }
+
+    // We could synthesize more KPIs from other layers (engine, sampler, etc)
+    return _kpis;
+}
+
+std::string Embedding::KPIs::dump(std::string_view sep) const {
+    return fmt::format(
+        "init:[{}]{}prompt:[{}]{} tps-prompt:{:.2f}",
+        init.dump(),
+        sep,
+        prompt.dump(),
+        sep,
+        tps.prompt
+    );
+}
+
+void Embedding::KPIs::reset() {
+    init.reset();
+    prompt.reset();
+    tps.prompt = 0.0;
+}
+
+// Create API
+
+std::unique_ptr<Embedding> Embedding::create(
+    std::shared_ptr<Env> env,
+    const std::string& name,
+    const qualla::json& conf
+) {
+    return std::make_unique<Embedding>(env, name, conf);
+}
+
+std::unique_ptr<Embedding> Embedding::create(
+    std::shared_ptr<Env> env,
+    const std::string& name,
+    std::istream& json_stream
+) {
+    return create(env, name, json::parse(json_stream));
+}
+
+std::unique_ptr<Embedding> Embedding::create(
+    std::shared_ptr<Env> env,
+    const std::string& name,
+    const fs::path& json_path
+) {
+    if (!fs::exists(json_path))
+        throw std::runtime_error(json_path.string() + ": file does not exist");
+    std::ifstream ifs(json_path);
+    return create(env, name, ifs);
+}
+
+} // namespace qualla
diff --git a/Genie/Genie/src/qualla/engine.cpp b/Genie/Genie/src/qualla/engine.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..2e2d3ce3db40f1e230259d7128b3c05790cb8543
--- /dev/null
+++ b/Genie/Genie/src/qualla/engine.cpp
@@ -0,0 +1,198 @@
+//==============================================================================
+//
+// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+// All Rights Reserved.
+// Confidential and Proprietary - Qualcomm Technologies, Inc.
+// +//============================================================================== + +#include +#include +#include + +#include +#include +#include +#include +#include + +#include +#include + +namespace qualla { + +Engine::Engine(Context& ctx, const std::string& type, const qualla::json& conf) + : _type(type), _ctx(ctx), _env(ctx.env()) { + _env.logger().debug( + fmt::format("engine-new: {} ctx {} config {}", type, _ctx.name(), conf.dump()) + ); + + using qc = qualla::Config; + _role = qc::optional(conf, "role", "primary"); +} + +Engine::~Engine() {} + +size_t Engine::process( + const std::vector& tokens, + const std::vector& attention_map, + std::vector& output, + bool output_all +) { + _env.logger().error(fmt::format("{}-engine does not support attention_map", _type)); + return 0; +} + +size_t Engine::process(const std::vector& tokens) { + // Derived engines should overwrite this to avoid copying logits + std::vector logits; + return process(tokens, logits); +} + +size_t Engine::process( + std::vector& embeddings, + const std::vector& attention_map, + std::vector& output, + bool output_all +) { + _env.logger().error(fmt::format("{}-engine does not support embedding as input", _type)); + return 0; +} + +bool Engine::updateKV(size_t n_past) { + _env.logger().error(fmt::format("{}-engine does not support sync", _type)); + return false; +} + +bool Engine::updateKV(size_t n_past, const std::vector& selected) { + _env.logger().error(fmt::format("{}-engine does not support sync with selected", _type)); + return false; +} + +size_t Engine::restore(const std::string& name) { + _env.logger().error(fmt::format("{}-engine does not support restore", _type)); + return 0; +} + +bool Engine::save(const std::string& name) { + _env.logger().error(fmt::format("{}-engine does not support save", _type)); + return false; +} + +void Engine::reset() { + _env.logger().error(fmt::format("{}-engine does not support reset", _type)); +} + +bool Engine::load() { + _env.logger().error(fmt::format("{}-engine does not support dynamic load", _type)); + return 0; +} + +bool Engine::unload() { + _env.logger().error(fmt::format("{}-engine does not support dynamic unload", _type)); + return false; +} + +bool Engine::set(qualla::json data) { + _env.logger().error(fmt::format("{}-engine does not support set()", _type)); + return false; +} + +qualla::json Engine::get() { + _env.logger().error(fmt::format("{}-engine does not support get()", _type)); + return false; +} + +bool Engine::cacheEosEmbedding(std::vector& eosEmbedding) { + _env.logger().error(fmt::format("{}-engine does not support cache eos embedding", _type)); + return true; +} + +size_t Engine::getEmbeddingBufferSize() { + _env.logger().error(fmt::format("{}-engine does not support embedding vectors", _type)); + return 0; +} + +qualla::InputType Engine::getInputType(){ + return qualla::InputType::TOKENS; +} + +// Engine KPIs + +std::string Engine::KPIs::dump(std::string_view sep) const { + return fmt::format( + "load:[{}]{}process:[{}]{}update-kv:[{}]{}unload:[{}]", + load.dump(), + sep, + process.dump(), + sep, + update_kv.dump(), + sep, + unload.dump() + ); +} + +void Engine::KPIs::reset() { + load.reset(); + process.reset(); + update_kv.reset(); + unload.reset(); +} + +// Engine registry type string + creator function +using Registry = std::unordered_map; +static std::unique_ptr registry; + +void Engine::__register(const std::string& type, Creator func) { + if (!registry) registry = std::make_unique(); + + Registry& r = *registry; + r[type] = func; +} + 
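+// Illustrative sketch (not part of the original change): an engine makes itself
+// available by registering a creator under a type string, and Engine::create()
+// below resolves that string from the mandatory "type" key of the config.
+// "NullEngine" is a hypothetical class name used purely for illustration:
+//
+//     static OnLoad null_engine_reg([]() {
+//         Engine::__register("null", [](Context& ctx, const qualla::json& conf) {
+//             return (Engine*)new NullEngine(ctx, conf);
+//         });
+//     });
+//
+//     auto engine = Engine::create(ctx, json::parse(R"({ "type": "null" })"));
+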
+std::unique_ptr Engine::create(Context& ctx, const qualla::json& conf) { + using qc = qualla::Config; + + std::string type = qc::mandatory(conf, "type"); + + + if (!registry) throw std::runtime_error(type + ": engine not found"); + + Registry& r = *registry; + + + if (!r.contains(type)) throw std::runtime_error(type + ": engine not found"); + + + return std::unique_ptr(r[type](ctx, conf)); +} + +std::unique_ptr Engine::create(Context& ctx, std::istream& json_stream) { + return create(ctx, json::parse(json_stream)); +} + +std::unique_ptr Engine::create(Context& ctx, const std::string& json_str) { + return create(ctx, json::parse(json_str)); +} + +std::vector Engine::list() { + std::vector v; + if (!registry) return v; + + Registry& r = *registry; + + for (auto k : r) + v.push_back(k.first); + return v; +} + +bool Engine::applyLoraAdapter(std::string lora_adapter_name) { + _env.logger().error(fmt::format("{}-engine does not support LoraAdapter", _type)); + return false; +} +bool Engine::applyLoraStrength(std::string tensor_name, float tensor_val) { + _env.logger().error(fmt::format("{}-engine does not support setLoraStrength", _type)); + return false; +} + +} // namespace qualla diff --git a/Genie/Genie/src/qualla/engines/lib.cpp b/Genie/Genie/src/qualla/engines/lib.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c6faaf917164472fe19d83559379f64d3c952628 --- /dev/null +++ b/Genie/Genie/src/qualla/engines/lib.cpp @@ -0,0 +1,9 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +// Just a stub for building qualla::engines when no built-in engines are enabled diff --git a/Genie/Genie/src/qualla/engines/qnn-api/BackendExtensions.cpp b/Genie/Genie/src/qualla/engines/qnn-api/BackendExtensions.cpp new file mode 100644 index 0000000000000000000000000000000000000000..023834fab96e0bd8b2f3898d4ef6aea7a4276a30 --- /dev/null +++ b/Genie/Genie/src/qualla/engines/qnn-api/BackendExtensions.cpp @@ -0,0 +1,158 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. 
+// +//============================================================================== + +#include "dlwrap.hpp" +#include "BackendExtensions.hpp" +#include "NetRunBackend.hpp" + +BackendExtensions::BackendExtensions( + BackendExtensionsConfigs backendExtensionsConfig, + void* backendLibHandle, + PerfProfile perfProfile, + std::shared_ptr clManager, + bool debug_qnn +) + : m_backendExtensionsLibPath(backendExtensionsConfig.sharedLibraryPath), + m_backendExtensionsConfigPath(backendExtensionsConfig.configFilePath), + m_backendInterface(nullptr), m_isNetRunBackendInterface(false), + m_createBackendInterfaceFn(nullptr), m_destroyBackendInterfaceFn(nullptr), + m_backendLibHandle(backendLibHandle), m_perfProfile(perfProfile), m_clManager(clManager), + m_debugQnn(debug_qnn) { + (void)m_perfProfile; +} + +BackendExtensions::~BackendExtensions() { + if (nullptr != m_backendInterface) { + if (m_isNetRunBackendInterface) { + QNN_DEBUG("Deleting NetRun Backend Interface"); + delete m_backendInterface; + } else { + if (nullptr != m_destroyBackendInterfaceFn) { + QNN_DEBUG("Destroying Backend Interface"); + m_destroyBackendInterfaceFn(m_backendInterface); + } + } + } +} + +bool BackendExtensions::loadFunctionPointers() { + + void* libHandle = dlopen(m_backendExtensionsLibPath.c_str(), RTLD_NOW | RTLD_LOCAL); + if (nullptr == libHandle) { + QNN_ERROR( + "Unable to load backend extensions lib: [%s]. dlerror(): [%s]", + m_backendExtensionsLibPath.c_str(), + dlerror() + ); + return false; + } + m_createBackendInterfaceFn = + (CreateBackendInterfaceFnType_t)dlsym(libHandle, "createBackendInterface"); + m_destroyBackendInterfaceFn = + (DestroyBackendInterfaceFnType_t)dlsym(libHandle, "destroyBackendInterface"); + if (nullptr == m_createBackendInterfaceFn || nullptr == m_destroyBackendInterfaceFn) { + QNN_ERROR("Unable to find symbols. 
dlerror(): [%s]", dlerror()); + return false; + } + + return true; +} + +void BackendExtensions::qnnLogCallback( + const char* fmt, + QnnLog_Level_t level, + uint64_t timestamp, + va_list args +) { + char buffer[1024] = ""; + const char* levelStr = ""; + switch (level) { + case QNN_LOG_LEVEL_ERROR: + levelStr = " ERROR "; + break; + case QNN_LOG_LEVEL_WARN: + levelStr = "WARNING"; + break; + case QNN_LOG_LEVEL_INFO: + levelStr = " INFO "; + break; + case QNN_LOG_LEVEL_DEBUG: + levelStr = " DEBUG "; + break; + case QNN_LOG_LEVEL_VERBOSE: + levelStr = "VERBOSE"; + break; + case QNN_LOG_LEVEL_MAX: + levelStr = "UNKNOWN"; + break; + } + + int pos = snprintf( + buffer, sizeof(buffer), "QNN: [%s] time=%lu:", levelStr, (unsigned long)timestamp + ); + vsnprintf(buffer + pos, sizeof(buffer) - pos, fmt, args); + printf("%s", buffer); +} + +bool BackendExtensions::initialize() { + + QNN_DEBUG("DEBUG: m_backendExtensionsLibPath=%s\n", m_backendExtensionsLibPath.c_str()); + QNN_DEBUG("DEBUG: m_backendExtensionsConfigPath=%s\n", m_backendExtensionsConfigPath.c_str()); + if (m_backendExtensionsLibPath.empty() && m_backendExtensionsConfigPath.empty()) { + QNN_WARN("No BackendExtensions lib provided; initializing NetRunBackend Interface"); + m_isNetRunBackendInterface = true; + m_backendInterface = new NetRunBackend(); + } else { + QNN_DEBUG("Loading supplied backend extensions lib."); + QNN_DEBUG("Backend extensions lib path: %s", m_backendExtensionsLibPath.c_str()); + if (m_backendExtensionsConfigPath.empty()) { + QNN_DEBUG("Backend extensions lib specified without a config file."); + } else { + QNN_DEBUG("Backend extensions config path: %s", m_backendExtensionsConfigPath.c_str()); + } + if (!loadFunctionPointers()) { + QNN_ERROR("Failed to load function pointers."); + return false; + } + if (nullptr != m_createBackendInterfaceFn) { + m_backendInterface = m_createBackendInterfaceFn(); + } + } + if (nullptr == m_backendInterface) { + QNN_ERROR("Unable to load backend extensions interface."); + return false; + } + if (m_debugQnn) { + if (!(m_backendInterface->setupLogging(BackendExtensions::qnnLogCallback, QNN_LOG_LEVEL_VERBOSE))) { + QNN_WARN("Unable to initialize logging in backend extensions."); + } + } + if (!m_backendInterface->initialize(m_backendLibHandle)) { + QNN_ERROR("Unable to initialize backend extensions interface."); + return false; + } + if (!m_backendInterface->setPerfProfile(m_perfProfile)) { + QNN_WARN("Unable to set perf profile in backend extensions interface."); + //return false; + } + if (!m_backendInterface->loadConfig(m_backendExtensionsConfigPath)) { + QNN_ERROR("Unable to load backend extensions interface config."); + return false; + } + + if ((m_clManager != nullptr) && !m_backendInterface->loadCommandLineArgs(m_clManager)) { + QNN_ERROR("Unable to load backend extensions' command line arguments."); + return false; + } + + return true; +} + +IBackend* BackendExtensions::interface() { + return m_backendInterface; +} diff --git a/Genie/Genie/src/qualla/engines/qnn-api/BackendExtensions.hpp b/Genie/Genie/src/qualla/engines/qnn-api/BackendExtensions.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a17b25dbd32315ae1c1d075b536c7009c5d24f42 --- /dev/null +++ b/Genie/Genie/src/qualla/engines/qnn-api/BackendExtensions.hpp @@ -0,0 +1,62 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All Rights Reserved. 
+// Confidential and Proprietary - Qualcomm Technologies, Inc.
+//
+//==============================================================================
+
+#pragma once
+
+#include <memory>
+
+#include "IBackend.hpp"
+#include "QnnConfig.hpp"
+#include "Log.hpp"
+
+// This is a wrapper class that handles resources/state related to the
+// backend extensions interface. It is used by the QnnNetRun library
+// to manage and call into an IBackend interface implementation.
+// Functionality present in this class:
+// 1. Receives the argument string for the backend_extensions
+//    argument from the front end and processes it to open the
+//    backend extensions library.
+// 2. Locates and stores symbols for creating and destroying the
+//    IBackend interface implementation.
+// 3. If there is no backend_extensions argument, this class creates
+//    the dummy IBackend implementation, aka NetRunBackend.
+// 4. Gives QnnNetRun access to the implementation itself through the
+//    interface() function.
+class BackendExtensions final {
+  public:
+    BackendExtensions(
+        BackendExtensionsConfigs backendExtensionsConfig,
+        void* backendLibHandle,
+        PerfProfile perfProfile,
+        std::shared_ptr<ICommandLineManager> clManager =
+            std::shared_ptr<ICommandLineManager>(nullptr),
+        bool debug_qnn = false
+    );
+    ~BackendExtensions();
+    bool initialize();
+    IBackend* interface();
+
+  private:
+    bool loadFunctionPointers();
+    std::string m_backendExtensionsLibPath;
+    std::string m_backendExtensionsConfigPath;
+    IBackend* m_backendInterface;
+    bool m_isNetRunBackendInterface;
+    CreateBackendInterfaceFnType_t m_createBackendInterfaceFn;
+    DestroyBackendInterfaceFnType_t m_destroyBackendInterfaceFn;
+    void* m_backendLibHandle;
+    PerfProfile m_perfProfile;
+    std::shared_ptr<ICommandLineManager> m_clManager;
+    bool m_debugQnn{false};
+    static void qnnLogCallback(
+        const char* fmt,
+        QnnLog_Level_t level,
+        uint64_t timestamp,
+        va_list args
+    );
+};
diff --git a/Genie/Genie/src/qualla/engines/qnn-api/ClientBuffer.cpp b/Genie/Genie/src/qualla/engines/qnn-api/ClientBuffer.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..deefb50fb474551bf95493e4d3caa2ddaafcd8ec
--- /dev/null
+++ b/Genie/Genie/src/qualla/engines/qnn-api/ClientBuffer.cpp
@@ -0,0 +1,122 @@
+//==============================================================================
+//
+// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+// All Rights Reserved.
+// Confidential and Proprietary - Qualcomm Technologies, Inc.
+// +//============================================================================== + +#include "ClientBuffer.hpp" +#include "QnnTypeMacros.hpp" + +void* ClientBuffer::getBuffer(Qnn_Tensor_t* tensor) { + if (!tensor) { + QNN_WARN("getBuffer: received a null pointer to a tensor"); + return nullptr; + } + return QNN_TENSOR_GET_CLIENT_BUF(tensor).data; +} + +size_t ClientBuffer::getBufferSize(Qnn_Tensor_t* tensor) { + if (!tensor) { + QNN_WARN("getBufferSize: received a null pointer to a tensor"); + return 0; + } + return QNN_TENSOR_GET_CLIENT_BUF(tensor).dataSize; +}; + +bool ClientBuffer::allocateTensorBuffer(Qnn_Tensor_t* tensor, size_t tensorDataSize) { + if (!tensor) { + QNN_ERROR("Received nullptr for tensors"); + return false; + } + QNN_TENSOR_SET_MEM_TYPE(tensor, QNN_TENSORMEMTYPE_RAW); + Qnn_ClientBuffer_t clientBuffer; + clientBuffer.data = malloc(tensorDataSize); + if (nullptr == clientBuffer.data) { + QNN_ERROR("mem alloc failed for clientBuffer.data"); + return false; + } + clientBuffer.dataSize = tensorDataSize; + QNN_TENSOR_SET_CLIENT_BUF(tensor, clientBuffer); + return true; +} + +bool ClientBuffer::freeTensorBuffer(Qnn_Tensor_t* tensor) { + if (!tensor) { + QNN_ERROR("Received nullptr for tensors"); + return false; + } + if (QNN_TENSOR_GET_CLIENT_BUF(tensor).data) { + if (m_sameMemoryFreeTensors.find(tensor) == m_sameMemoryFreeTensors.end()) { + free(QNN_TENSOR_GET_CLIENT_BUF(tensor).data); + } + QNN_TENSOR_SET_CLIENT_BUF(tensor, Qnn_ClientBuffer_t({nullptr, 0u})); + QNN_TENSOR_SET_MEM_TYPE(tensor, QNN_TENSORMEMTYPE_UNDEFINED); + } + return true; +} + +bool ClientBuffer::useSameMemory(Qnn_Tensor_t* dest, Qnn_Tensor_t* src) { + if (nullptr == dest || nullptr == src) { + QNN_ERROR("Received nullptr"); + return false; + } + if (false == freeTensorBuffer(dest)) { + return false; + } + + QNN_TENSOR_SET_MEM_TYPE(dest, QNN_TENSOR_GET_MEM_TYPE(src)); + QNN_TENSOR_SET_CLIENT_BUF(dest, QNN_TENSOR_GET_CLIENT_BUF(src)); + m_sameMemoryFreeTensors.insert(dest); + return true; +} + +bool ClientBuffer::useExternalMemory(Qnn_Tensor_t* dest, void* extMem) { + if (nullptr == dest || nullptr == extMem) { + QNN_ERROR("Received nullptr"); + return false; + } + + Qnn_ClientBuffer_t clientBuffer; + clientBuffer.data = extMem; + clientBuffer.dataSize = QNN_TENSOR_GET_CLIENT_BUF(dest).dataSize; + if (false == freeTensorBuffer(dest)) { + return false; + } + + QNN_TENSOR_SET_MEM_TYPE(dest, QNN_TENSORMEMTYPE_RAW); + QNN_TENSOR_SET_CLIENT_BUF(dest, clientBuffer); + m_sameMemoryFreeTensors.insert(dest); + return true; +} + +void* ClientBuffer::allocateTensorFusedBuffer(uint64_t bufferSize, int32_t* fd) { + return nullptr; +} + +bool ClientBuffer::mapFusedBufferOffset( + Qnn_Tensor_t* tensor, + size_t tensorDataSize, + int32_t fd, + uint32_t offset, + uint64_t totalBufferSize, + void* memPointer, + Qnn_ContextHandle_t contextHandle +) { + return false; +} + +bool ClientBuffer::deregisterTensorFusedBuffer(Qnn_Tensor_t* tensor) { + return false; +} + +void ClientBuffer::freeFusedBuffers() {} + +size_t ClientBuffer::getOffset(Qnn_Tensor_t* tensor) { + return 0; +} + +size_t ClientBuffer::getTotalBufferSize(Qnn_Tensor_t* tensor) { + return 0; +} \ No newline at end of file diff --git a/Genie/Genie/src/qualla/engines/qnn-api/ClientBuffer.hpp b/Genie/Genie/src/qualla/engines/qnn-api/ClientBuffer.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6fa5b9fb081e817057abcc80a4fcc479f0ee8e1d --- /dev/null +++ b/Genie/Genie/src/qualla/engines/qnn-api/ClientBuffer.hpp @@ -0,0 +1,85 @@ 
+//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#pragma once + +#include "IBufferAlloc.hpp" +#include "Log.hpp" +#include +#include + +class ClientBuffer final : public IBufferAlloc { + public: + ClientBuffer() {}; + + // Disable copy constructors, r-value referencing, etc + ClientBuffer(const ClientBuffer&) = delete; + + ClientBuffer& operator=(const ClientBuffer&) = delete; + + ClientBuffer(ClientBuffer&&) = delete; + + ClientBuffer& operator=(ClientBuffer&&) = delete; + + bool initialize() override { return true; }; + + void* getBuffer(Qnn_Tensor_t* tensor) override; + + int getFd(Qnn_Tensor_t* tensor) override { + QNN_WARN("getFd: This is not ION memory"); + return -1; + }; + + size_t getOffset(Qnn_Tensor_t* tensor) override; + size_t getBufferSize(Qnn_Tensor_t* tensor) override; + size_t getTotalBufferSize(Qnn_Tensor_t* tensor) override; + + bool allocateTensorBuffer(Qnn_Tensor_t* tensor, size_t tensorDataSize) override; + + bool freeTensorBuffer(Qnn_Tensor_t* tensor) override; + + bool useSameMemory(Qnn_Tensor_t* dest, Qnn_Tensor_t* src) override; + bool useSameMemory(Qnn_Tensor_t* dest, Qnn_Tensor_t* src, int offset) override { return false; } + + bool useExternalMemory(Qnn_Tensor_t* dest, void* extMem) override; + + void* allocateTensorFusedBuffer(uint64_t bufferSize, int32_t* fd) override; + bool allocateBuffers( + const std::map>& allocs_per_chunk, + std::map>& tensor_offsets + ) override { + return false; + }; + + bool mapFusedBufferOffset( + Qnn_Tensor_t* tensor, + size_t tensorDataSize, + int32_t fd, + uint32_t offset, + uint64_t totalBufferSize, + void* memPointer, + Qnn_ContextHandle_t contextHandle + ) override; + bool deregisterTensorFusedBuffer(Qnn_Tensor_t* tensor) override; + void freeFusedBuffers() override; + + bool mapFusedBufferOffset( + Qnn_Tensor_t* tensor, + int alloc_idx, + size_t offset, + Qnn_ContextHandle_t ctx, + size_t size + ) override { + return false; + } + + virtual ~ClientBuffer() {}; + + private: + std::unordered_set m_sameMemoryFreeTensors; +}; diff --git a/Genie/Genie/src/qualla/engines/qnn-api/IBackend.hpp b/Genie/Genie/src/qualla/engines/qnn-api/IBackend.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1f6bf00a8c3a1bfd35b74e8d878e9927347ef67a --- /dev/null +++ b/Genie/Genie/src/qualla/engines/qnn-api/IBackend.hpp @@ -0,0 +1,156 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. 
+// +//============================================================================== + +#pragma once + +#include +#include "ICommandLineManager.hpp" +#include "QnnBackend.h" +#include "QnnContext.h" +#include "QnnGraph.h" +#include "QnnLog.h" +#include "QnnTypeDef.hpp" +#include "QnnProfile.h" +#include "QnnDevice.h" + +// Compile-time definition to check for QNN SDK features using the QNN API version +#define QUALLA_QNN_API_VERSION \ + (QNN_API_VERSION_MAJOR * 10000 + QNN_API_VERSION_MINOR * 100 + QNN_API_VERSION_PATCH) + +const uint32_t g_profilingLevelNotSet = 0; + +enum class PerfProfile { + LOW_BALANCED, + BALANCED, + DEFAULT, + HIGH_PERFORMANCE, + SUSTAINED_HIGH_PERFORMANCE, + BURST, + EXTREME_POWER_SAVER, + LOW_POWER_SAVER, + POWER_SAVER, + HIGH_POWER_SAVER, + SYSTEM_SETTINGS, + NO_USER_INPUT, + CUSTOM, + INVALID +}; + +// This is the interface that enables backend specific extensions in qnn-net-run. +// It is designed as hooks in the timeline of various events in NetRun. +// Backends that intend to implement custom features through qnn-net-run will have +// to implement this interface and add functionality in appropriate methods depending +// on where/when the custom functionality needs to be exercised. +// These functions/hooks will be called through the IBackend interface from within +// qnn-net-run wherever necessary. +class IBackend { + public: + virtual ~IBackend() {} + + virtual bool setupLogging(QnnLog_Callback_t callback, QnnLog_Level_t maxLogLevel) = 0; + + virtual bool initialize(void* backendLibHandle) = 0; + + virtual bool setPerfProfile(PerfProfile perfProfile) = 0; + + virtual QnnProfile_Level_t getProfilingLevel() = 0; + + virtual bool loadConfig(std::string configFile) = 0; + + virtual bool loadCommandLineArgs(std::shared_ptr clManager) = 0; + + virtual bool beforeBackendInitialize( + QnnBackend_Config_t*** customConfigs, + uint32_t* configCount + ) = 0; + + virtual bool afterBackendInitialize() = 0; + + virtual bool beforeContextCreate( + QnnContext_Config_t*** customConfigs, + uint32_t* configCount + ) = 0; + + virtual bool afterContextCreate() = 0; + + virtual bool beforeComposeGraphs( + GraphConfigInfo_t*** customGraphConfigs, + uint32_t* graphCount + ) = 0; + + virtual bool afterComposeGraphs() = 0; + +#if QUALLA_QNN_API_VERSION >= 21700 + virtual bool beforeGraphFinalizeUpdateConfig( + const char* graphName, + Qnn_GraphHandle_t graphHandle, + QnnGraph_Config_t*** customConfigs, + uint32_t* configCount + ) = 0; +#endif + + virtual bool beforeGraphFinalize() = 0; + + virtual bool afterGraphFinalize() = 0; + + virtual bool beforeRegisterOpPackages() = 0; + + virtual bool afterRegisterOpPackages() = 0; + + virtual bool beforeExecute( + const char* graphName, + QnnGraph_Config_t*** customConfigs, + uint32_t* configCount + ) = 0; + + virtual bool afterExecute() = 0; + + virtual bool beforeContextFree() = 0; + + virtual bool afterContextFree() = 0; + + virtual bool beforeBackendTerminate() = 0; + + virtual bool afterBackendTerminate() = 0; + + virtual bool beforeCreateFromBinary( + QnnContext_Config_t*** customConfigs, + uint32_t* configCount + ) = 0; + + virtual bool afterCreateFromBinary() = 0; + +#if QUALLA_QNN_API_VERSION >= 21700 + virtual bool beforeCreateContextsFromBinaryList( + std::map>* + contextKeyToCustomConfigsMap, + QnnContext_Config_t*** commonCustomConfigs, + uint32_t* commonConfigCount + ) = 0; + + virtual bool afterCreateContextsFromBinaryList() = 0; +#endif + + virtual bool beforeCreateDevice(QnnDevice_Config_t*** deviceConfigs, uint32_t* 
configCount) = 0; + + virtual bool afterCreateDevice() = 0; + + virtual bool beforeFreeDevice() = 0; + + virtual bool afterFreeDevice() = 0; +}; + +// These are the function types that the backend extensions shared library is +// expected to expose. The first function helps NetRun obtain a valid implementation +// of IBackend interface and the second is used to destroy the same interface at the end. +// The function names themselves are expected to be these strings: +// 1. "createBackendInterface" +// 2. "destroyBackendInterface" +// These functions need to be tagged with extern "C" and their symbols need to be exposed. +typedef IBackend* (*CreateBackendInterfaceFnType_t)(); +typedef void (*DestroyBackendInterfaceFnType_t)(IBackend*); diff --git a/Genie/Genie/src/qualla/engines/qnn-api/IBufferAlloc.hpp b/Genie/Genie/src/qualla/engines/qnn-api/IBufferAlloc.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6eac516d96a69ce8d07984b7a1ba2899c55997ae --- /dev/null +++ b/Genie/Genie/src/qualla/engines/qnn-api/IBufferAlloc.hpp @@ -0,0 +1,56 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#pragma once +#include "QnnTypes.h" +#include +#include +#include +#include +#include + +class IBufferAlloc { + public: + virtual ~IBufferAlloc() {} + IBufferAlloc() {} + virtual bool initialize() = 0; + virtual void* getBuffer(Qnn_Tensor_t* tensor) = 0; + virtual int getFd(Qnn_Tensor_t* tensor) = 0; + virtual size_t getOffset(Qnn_Tensor_t* tensor) = 0; + virtual size_t getBufferSize(Qnn_Tensor_t* tensor) = 0; + virtual size_t getTotalBufferSize(Qnn_Tensor_t* tensor) = 0; + virtual bool allocateTensorBuffer(Qnn_Tensor_t* tensor, size_t tensorDataSize) = 0; + virtual bool freeTensorBuffer(Qnn_Tensor_t* tensor) = 0; + virtual bool useSameMemory(Qnn_Tensor_t* dest, Qnn_Tensor_t* src) = 0; + virtual bool useSameMemory(Qnn_Tensor_t* dest, Qnn_Tensor_t* src, int offset) = 0; + virtual bool useExternalMemory(Qnn_Tensor_t* dest, void* extMem) = 0; + virtual void* allocateTensorFusedBuffer(uint64_t bufferSize, int32_t* fd) = 0; + virtual bool allocateBuffers( + const std::map>& allocs_per_chunk, + std::map>& tensor_offsets + ) = 0; + virtual bool mapFusedBufferOffset( + Qnn_Tensor_t* tensor, + size_t tensorDataSize, + int32_t fd, + uint32_t offset, + uint64_t totalBufferSize, + void* memPointer, + Qnn_ContextHandle_t contextHandle + ) = 0; + virtual bool mapFusedBufferOffset( + Qnn_Tensor_t* tensor, + int alloc_idx, + size_t offset, + Qnn_ContextHandle_t ctx, + size_t size + ) = 0; + + virtual bool deregisterTensorFusedBuffer(Qnn_Tensor_t* tensor) = 0; + virtual void freeFusedBuffers() = 0; +}; \ No newline at end of file diff --git a/Genie/Genie/src/qualla/engines/qnn-api/ICommandLineManager.hpp b/Genie/Genie/src/qualla/engines/qnn-api/ICommandLineManager.hpp new file mode 100644 index 0000000000000000000000000000000000000000..13150c8b142aab8bcd63b6649104ac04ac1da84f --- /dev/null +++ b/Genie/Genie/src/qualla/engines/qnn-api/ICommandLineManager.hpp @@ -0,0 +1,95 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. 
+// +//============================================================================== + +#pragma once + +#include +#include +#include +#include +#include + +class ICommandLineManager { + public: + enum class Error { SUCCESS, PARSE_FAILURE, UNUSED_ARGUMENTS, OVER_SUBSCRIBED_ARGUMENTS }; + + using ValueList_t = std::vector>; + + /** + * @brief Parses provided command line arguments into key value pairs + * + * @param[in] argc Number of char* arguments in argv + * + * @param[in] argv Pointer to first element of null terminated character arrays + * + * @return Error code: + * - SUCCESS: provided command line arguments match expected format: --key=value, --key + * - PARSE_FAILURE: The provided command line arguments do not match expected format + * + */ + virtual Error parseClArgs(size_t argc, char** argv) = 0; + + /** + * @brief Provides passed values for requested key if available + * + * @param[in] key Key string of option + * + * @return (False, empty) if key is not an available argument + * + */ + virtual std::tuple serveArg(const std::string& key) = 0; + + /** + * @brief Checks whether any provided commandline arguments remain unserved + * + * @return True if unconsumed arguments remain, False otherwise + */ + virtual bool allArgumentsServed() const = 0; + + /** + * @brief Validates command line arguments were correctly utilized + * + * @return Error code: + * - SUCCESS: provided command line arguments were utilized following implementations + * policy + * - UNUSED_ARGUMENTS: Some arguments passed were not consumed + * - OVER_SUBSCRIBED_ARGUMENTS: Some arguments were requested by multiple times + * + */ + virtual Error validateUsage() = 0; + + virtual ~ICommandLineManager() = default; + + static bool isKey(const std::string& arg) { + return (arg.length() > keyPrefix().length()) && (arg.find(keyPrefix()) == 0) && + std::isalpha(arg.at(keyPrefix().length())); + } + + static Error parseKey(const std::string& arg, std::string& keyOut) { + if (!isKey(arg)) { + return Error::PARSE_FAILURE; + } + + auto valueSplit = arg.find(keyValueSplit()); + keyOut = valueSplit != arg.npos ? arg.substr(0, valueSplit) : arg; + return Error::SUCCESS; + } + + static Error parseValue(const std::string& arg, std::string& valueOut) { + auto valueSplit = arg.find(keyValueSplit()); + if (valueSplit == arg.npos || valueSplit == arg.length() - 1) { + return Error::PARSE_FAILURE; + } + valueOut = arg.substr(valueSplit + 1); + return Error::SUCCESS; + } + + private: + static const std::string keyPrefix() { return "--"; }; + static char keyValueSplit() { return '='; }; +}; diff --git a/Genie/Genie/src/qualla/engines/qnn-api/IOTensor.cpp b/Genie/Genie/src/qualla/engines/qnn-api/IOTensor.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9ad51cb4a6b8ab777cf92c928fa19217f653a9a2 --- /dev/null +++ b/Genie/Genie/src/qualla/engines/qnn-api/IOTensor.cpp @@ -0,0 +1,382 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. 
+// +//============================================================================== +#include +#include +#include + +#include "ClientBuffer.hpp" +#include "IBufferAlloc.hpp" +#include "IOTensor.hpp" +#include "RpcMem.hpp" +#include "QnnTypeMacros.hpp" + +#ifdef _WIN32 + #define __strdup _strdup +#else + #define __strdup strdup +#endif + +IOTensor::IOTensor(BufferAlloc bufferAllocIn, QNN_INTERFACE_VER_TYPE* qnnInterface) + : m_bufferAlloc(bufferAllocIn), m_qnnInterface(qnnInterface), + m_bufferManager(new ClientBuffer()) {} + +bool IOTensor::initialize(Qnn_ContextHandle_t contextHandle) { + if (m_bufferAlloc == BufferAlloc::SHARED_BUFFER) { + m_bufferManager = std::unique_ptr(new RpcMem(contextHandle, m_qnnInterface)); + } + + if (true != m_bufferManager->initialize()) { + QNN_ERROR("Failed to initialize buffer manager"); + return false; + } + + return true; +} + +IOTensor::~IOTensor() { + if (m_bufferAlloc == BufferAlloc::SHARED_BUFFER) { + m_bufferManager->freeFusedBuffers(); + } +} + +// Setup details for Qnn_Tensor_t for execution +// based on information in TensorWrapper provided by model.so. +bool IOTensor::setupTensors( + Qnn_Tensor_t** tensors, + std::unordered_map& tensorNameToTensorPointer, + uint32_t tensorCount, + TensorWrapper* tensorWrappers, + std::unordered_map& tensorsSize, + Qnn_ContextHandle_t contextHandle, + bool skipBufferAllocation +) { + + if (nullptr == tensorWrappers) { + QNN_ERROR("tensorWrappers is nullptr"); + return false; + } + if (0 == tensorCount) { + QNN_DEBUG("tensor count is 0. Nothing to setup."); + return true; + } + + *tensors = (Qnn_Tensor_t*)calloc(1, tensorCount * sizeof(Qnn_Tensor_t)); + if (nullptr == *tensors) { + QNN_ERROR("mem alloc failed for *tensors"); + return false; + } + + auto returnStatus = true; + + uint64_t totalBufferSize = 0; + void* memPointer = nullptr; + int32_t fd = -1; + if (m_bufferAlloc == BufferAlloc::SHARED_BUFFER) { + // Calculate the total size of the tensors + for (size_t tensorIdx = 0; tensorIdx < tensorCount; tensorIdx++) { + auto wrapperTensorName = + std::string(GET_TENSOR_WRAPPER_NAME(tensorWrappers[tensorIdx])); + totalBufferSize += tensorsSize[wrapperTensorName]; + } + QNN_DEBUG("Calculated total size %lu", totalBufferSize); + + if (!skipBufferAllocation) { + // Allocate the buffer of this size + memPointer = m_bufferManager->allocateTensorFusedBuffer(totalBufferSize, &fd); + if (memPointer) { + QNN_DEBUG( + "Successfully allocated a buffer of size %lu, pointer %p, fd %d", + (unsigned long)totalBufferSize, + memPointer, + fd + ); + } else { + QNN_ERROR( + "Not able to allocate buffer of size %lu", (unsigned long)totalBufferSize + ); + return false; + } + } + } + + uint64_t offset = 0; + + for (size_t tensorIdx = 0; tensorIdx < tensorCount; tensorIdx++) { + Qnn_Tensor_t wrapperTensor = GET_TENSOR_WRAPPER_TENSOR(tensorWrappers[tensorIdx]); + auto wrapperTensorName = std::string(GET_TENSOR_WRAPPER_NAME(tensorWrappers[tensorIdx])); + if (true == returnStatus) { + (*tensors)[tensorIdx] = QNN_TENSOR_INIT; + returnStatus = deepCopyQnnTensorInfo(((*tensors) + tensorIdx), &wrapperTensor); + } + if (true == returnStatus) { + size_t tensorDataSize = tensorsSize[wrapperTensorName]; + if (m_bufferAlloc == BufferAlloc::SHARED_BUFFER) { + if (!skipBufferAllocation) { + returnStatus = m_bufferManager->mapFusedBufferOffset( + ((*tensors) + tensorIdx), + tensorDataSize, + fd, + offset, + totalBufferSize, + memPointer, + contextHandle + ); + offset += tensorDataSize; + } + } else { + returnStatus = 
m_bufferManager->allocateTensorBuffer( + ((*tensors) + tensorIdx), tensorDataSize + ); + } + } + if (true != returnStatus) { + QNN_ERROR("Failure in setupTensors, cleaning up resources"); + tearDownTensors(*tensors, tensorIdx); + *tensors = nullptr; + QNN_ERROR("Failure in setupTensors, done cleaning up resources"); + return false; + } else { + tensorNameToTensorPointer.insert({wrapperTensorName, ((*tensors) + tensorIdx)}); + // QNN_DEBUG("allocateBuffer successful"); + } + } + + return returnStatus; +} + +// Setup details for all input tensors for graph execution. +bool IOTensor::setupInputTensors( + Qnn_Tensor_t** inputs, + std::unordered_map& tensorNameToTensorPointer, + const GraphInfo_t& graphInfo, + std::unordered_map& inputTensorsSize, + Qnn_ContextHandle_t contextHandle, + bool skipBufferAllocation +) { + + if (true != setupTensors( + inputs, + tensorNameToTensorPointer, + graphInfo.numInputTensors, + (graphInfo.inputTensors), + inputTensorsSize, + contextHandle, + skipBufferAllocation + )) { + QNN_ERROR("Failure in setupInputTensors, cleaning up resources"); + if (nullptr != *inputs) { + QNN_DEBUG("cleaning up input tensors"); + tearDownTensors(*inputs, graphInfo.numInputTensors); + *inputs = nullptr; + } + QNN_ERROR("Failure in setupInputTensors, done cleaning up resources"); + + return false; + } + + return true; +} + +// Setup details for all output tensors for graph execution. +bool IOTensor::setupOutputTensors( + Qnn_Tensor_t** outputs, + std::unordered_map& tensorNameToTensorPointer, + const GraphInfo_t& graphInfo, + std::unordered_map& outputTensorsSize, + Qnn_ContextHandle_t contextHandle, + bool skipBufferAllocation +) { + + if (true != setupTensors( + outputs, + tensorNameToTensorPointer, + graphInfo.numOutputTensors, + (graphInfo.outputTensors), + outputTensorsSize, + contextHandle, + skipBufferAllocation + )) { + QNN_ERROR("Failure in setupOutputTensors, cleaning up resources"); + if (nullptr != *outputs) { + QNN_DEBUG("cleaning up output tensors"); + tearDownTensors(*outputs, graphInfo.numOutputTensors); + *outputs = nullptr; + } + QNN_ERROR("Failure in setupOutputTensors, done cleaning up resources"); + + return false; + } + + return true; +} + +bool IOTensor::mapFusedBufferOffset( + GraphInfo_t* graph_info, + Qnn_ContextHandle_t context_handle, + const std::map>& graph_allocs +) { + std::lock_guard lk(_tmp_lock); // READ COMMENT IN IOTensor.hpp _tmp_lock + + bool ret = true; + for (const bool mode : {true, false}) { + TensorWrapper* tensor_bank = (mode) ? graph_info->inputTensors : graph_info->outputTensors; + uint32_t num_tensors = (mode) ? graph_info->numInputTensors : graph_info->numOutputTensors; + + for (size_t tidx = 0; tidx < num_tensors; tidx++) { + TensorWrapper& tensor_wrapper = tensor_bank[tidx]; + + Qnn_Tensor_t* tensor = &GET_TENSOR_WRAPPER_TENSOR(tensor_wrapper); + std::string tensor_name = std::string(GET_TENSOR_WRAPPER_NAME(tensor_wrapper)); + + if (!graph_allocs.contains(tensor_name)) continue; + auto& [alloc_idx, offset, size] = graph_allocs.at(tensor_name); + ret &= m_bufferManager->mapFusedBufferOffset( + tensor, alloc_idx, offset, context_handle, size + ); + } + } + + return ret; +} + +// Clean up all tensors related data after execution. 
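The SHARED_BUFFER path above sizes one fused allocation up front and then walks the tensors, handing each a disjoint offset into it. Before the teardown helpers below, here is a minimal standalone sketch of that packing step; `packTensors` and `Slice` are illustrative names, not part of this patch:

```cpp
#include <cstdint>
#include <map>
#include <string>

struct Slice { uint64_t offset; uint64_t size; };

// Assign every tensor a contiguous slice of one fused buffer, mirroring the
// totalBufferSize/offset accumulation in setupTensors above. (A real layout
// would preserve the graph's tensor order; std::map sorts by name.)
std::map<std::string, Slice> packTensors(
    const std::map<std::string, uint64_t>& tensorSizes, uint64_t& totalSize) {
    std::map<std::string, Slice> layout;
    totalSize = 0;
    for (const auto& [name, size] : tensorSizes) {
        layout[name] = {totalSize, size};  // this tensor starts at the running total
        totalSize += size;                 // the next tensor begins where this one ends
    }
    return layout;
}
```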
+bool IOTensor::tearDownTensors(Qnn_Tensor_t* tensors, uint32_t tensorCount) { + + if (nullptr != tensors) { + QNN_DEBUG("cleaning up resources for tensors"); + for (size_t tensorIdx = 0; tensorIdx < tensorCount; tensorIdx++) { + // QNN_DEBUG("freeing resources for tensor: %zu", tensorIdx); + if (nullptr != QNN_TENSOR_GET_DIMENSIONS(&tensors[tensorIdx])) { + // QNN_DEBUG("freeing maxDimensions"); + free(QNN_TENSOR_GET_DIMENSIONS(&tensors[tensorIdx])); + } + if (m_bufferAlloc == BufferAlloc::SHARED_BUFFER) { + m_bufferManager->deregisterTensorFusedBuffer(&(tensors[tensorIdx])); + } else { + m_bufferManager->freeTensorBuffer(&(tensors[tensorIdx])); + } + m_freeTensorsPointerSet.insert(&(tensors[tensorIdx])); + } + free(tensors); + tensors = nullptr; + } + + return true; +} + +// Clean up all tensors after execution. +bool IOTensor::tearDownTensors(std::vector& tensors, uint32_t numTensors) { + + for (Qnn_Tensor_t* tensor : tensors) { + tearDownTensors(tensor, numTensors); + } + + return true; +} + +bool IOTensor::tearDownTensors(std::vector& tensors) { + return tearDownTensors(tensors.data(), tensors.size()); +} + +// Clean up all tensors after execution. +bool IOTensor::tearDownTensors( + std::unordered_map& tensors, + std::unordered_map& tensorCountMap +) { + + for (auto& tensor : tensors) { + tearDownTensors(tensor.second, tensorCountMap[tensor.first]); + } + + return true; +} + +// Clean up all tensors after execution. +bool IOTensor::tearDownTensors( + std::vector>& tensors, + std::unordered_map& tensorCountMap +) { + + for (auto& tensor : tensors) { + tearDownTensors(tensor, tensorCountMap); + } + + return true; +} + +bool IOTensor::deepCopyQnnTensorInfo(Qnn_Tensor_t* dest, Qnn_Tensor_t* src) { + + if (nullptr == dest || nullptr == src) { + QNN_ERROR("Received nullptr"); + return false; + } + + // set tensor.version before using QNN_TENSOR_SET macros, as they require the version to be set + // to correctly assign values + dest->version = src->version; + const char* tensorName = QNN_TENSOR_GET_NAME(src); + if (!tensorName) { + QNN_TENSOR_SET_NAME(dest, nullptr); + } else { + QNN_TENSOR_SET_NAME(dest, __strdup(tensorName)); + } + QNN_TENSOR_SET_ID(dest, QNN_TENSOR_GET_ID(src)); + QNN_TENSOR_SET_TYPE(dest, QNN_TENSOR_GET_TYPE(src)); + QNN_TENSOR_SET_DATA_FORMAT(dest, QNN_TENSOR_GET_DATA_FORMAT(src)); + QNN_TENSOR_SET_DATA_TYPE(dest, QNN_TENSOR_GET_DATA_TYPE(src)); + Qnn_QuantizeParams_t qParams = QNN_QUANTIZE_PARAMS_INIT; + qParams.encodingDefinition = QNN_TENSOR_GET_QUANT_PARAMS(src).encodingDefinition; + qParams.quantizationEncoding = QNN_QUANTIZATION_ENCODING_UNDEFINED; + if (QNN_TENSOR_GET_QUANT_PARAMS(src).quantizationEncoding == + QNN_QUANTIZATION_ENCODING_SCALE_OFFSET) { + qParams.quantizationEncoding = QNN_TENSOR_GET_QUANT_PARAMS(src).quantizationEncoding; + qParams.scaleOffsetEncoding = QNN_TENSOR_GET_QUANT_PARAMS(src).scaleOffsetEncoding; + } else if (QNN_TENSOR_GET_QUANT_PARAMS(src).quantizationEncoding == + QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET) { + qParams.quantizationEncoding = QNN_TENSOR_GET_QUANT_PARAMS(src).quantizationEncoding; + qParams.axisScaleOffsetEncoding.axis = + QNN_TENSOR_GET_QUANT_PARAMS(src).axisScaleOffsetEncoding.axis; + qParams.axisScaleOffsetEncoding.numScaleOffsets = + QNN_TENSOR_GET_QUANT_PARAMS(src).axisScaleOffsetEncoding.numScaleOffsets; + if (QNN_TENSOR_GET_QUANT_PARAMS(src).axisScaleOffsetEncoding.numScaleOffsets > 0) { + qParams.axisScaleOffsetEncoding.scaleOffset = (Qnn_ScaleOffset_t*)malloc( + 
QNN_TENSOR_GET_QUANT_PARAMS(src).axisScaleOffsetEncoding.numScaleOffsets * + sizeof(Qnn_ScaleOffset_t) + ); + if (qParams.axisScaleOffsetEncoding.scaleOffset) { + for (size_t idx = 0; + idx < QNN_TENSOR_GET_QUANT_PARAMS(src).axisScaleOffsetEncoding.numScaleOffsets; + idx++) { + qParams.axisScaleOffsetEncoding.scaleOffset[idx].scale = + QNN_TENSOR_GET_QUANT_PARAMS(src) + .axisScaleOffsetEncoding.scaleOffset[idx] + .scale; + qParams.axisScaleOffsetEncoding.scaleOffset[idx].offset = + QNN_TENSOR_GET_QUANT_PARAMS(src) + .axisScaleOffsetEncoding.scaleOffset[idx] + .offset; + } + } + } + } + QNN_TENSOR_SET_QUANT_PARAMS(dest, qParams); + QNN_TENSOR_SET_RANK(dest, QNN_TENSOR_GET_RANK(src)); + QNN_TENSOR_SET_DIMENSIONS(dest, nullptr); + if (QNN_TENSOR_GET_RANK(src) > 0) { + QNN_TENSOR_SET_DIMENSIONS( + dest, (uint32_t*)malloc(QNN_TENSOR_GET_RANK(src) * sizeof(uint32_t)) + ); + if (QNN_TENSOR_GET_DIMENSIONS(dest)) { + memcpy(QNN_TENSOR_GET_DIMENSIONS(dest), + QNN_TENSOR_GET_DIMENSIONS(src), + QNN_TENSOR_GET_RANK(src) * sizeof(uint32_t)); + } + } + + return true; +} diff --git a/Genie/Genie/src/qualla/engines/qnn-api/IOTensor.hpp b/Genie/Genie/src/qualla/engines/qnn-api/IOTensor.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4212bd8af669b9f27bc8a2197f6a7735cc76f066 --- /dev/null +++ b/Genie/Genie/src/qualla/engines/qnn-api/IOTensor.hpp @@ -0,0 +1,170 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "IBufferAlloc.hpp" +#include "QnnTypeDef.hpp" +#include "Log.hpp" +#include "QnnBackend.h" +#include "QnnCommon.h" +#include "QnnContext.h" +#include "QnnGraph.h" +#include "QnnInterface.h" +#include "QnnProperty.h" +#include "QnnTensor.h" +#include "QnnTypes.h" +enum class BufferAlloc { + DEFAULT, // malloc based allocator + SHARED_BUFFER, // shared buffer allocator; actual allocator depends on the platform + INVALID +}; +class IBufferAlloc; +class IOTensor { + public: + IOTensor( + BufferAlloc bufferAllocIn = BufferAlloc::DEFAULT, + QNN_INTERFACE_VER_TYPE* qnnInterface = nullptr + ); + + ~IOTensor(); + + bool initialize(Qnn_ContextHandle_t contextHandle = nullptr); + + bool setupInputTensors( + Qnn_Tensor_t** inputs, + std::unordered_map& tensorNameToTensorPointer, + const GraphInfo_t& graphInfo, + std::unordered_map& inputTensorsSize, + Qnn_ContextHandle_t contextHandle, + bool skipBufferAllocation = false + ); + + bool setupOutputTensors( + Qnn_Tensor_t** outputs, + std::unordered_map& tensorNameToTensorPointer, + const GraphInfo_t& graphInfo, + std::unordered_map& outputTensorsSize, + Qnn_ContextHandle_t contextHandle, + bool skipBufferAllocation = false + ); + + bool tearDownTensors(Qnn_Tensor_t* tensors, uint32_t tensorCount); + + bool tearDownTensors(std::vector& tensors, uint32_t tensorCount); + bool tearDownTensors(std::vector& tensors); + bool tearDownTensors( + std::unordered_map& tensors, + std::unordered_map& tensorCountMap + ); + bool tearDownTensors( + std::vector>& tensors, + std::unordered_map& tensorCountMap + ); + + bool tearDownTensors(const GraphInfo_t* graph_info) { + bool status = true; + if (!tearDownTensors(graph_info->inputTensors, graph_info->numInputTensors)) { + status = false; + 
QNN_ERROR("Failed to tear down input tensors for graph %s", graph_info->graphName); + } + + if (!tearDownTensors(graph_info->outputTensors, graph_info->numOutputTensors)) { + status = false; + QNN_ERROR("Failed to tear down output tensors for graph %s", graph_info->graphName); + } + return status; + } + + void* getBuffer(Qnn_Tensor_t* tensor) { return m_bufferManager->getBuffer(tensor); }; + + int getFd(Qnn_Tensor_t* tensor) { return m_bufferManager->getFd(tensor); }; + + size_t getOffset(Qnn_Tensor_t* tensor) { return m_bufferManager->getOffset(tensor); }; + + size_t getBufferSize(Qnn_Tensor_t* tensor) { return m_bufferManager->getBufferSize(tensor); }; + + size_t getTotalBufferSize(Qnn_Tensor_t* tensor) { + return m_bufferManager->getTotalBufferSize(tensor); + } + + void* allocateTensorFusedBuffer(uint64_t bufferSize, int32_t* fd) { + return m_bufferManager->allocateTensorFusedBuffer(bufferSize, fd); + } + + bool allocateBuffers( + const std::map>& allocs_per_chunk, + std::map>& tensor_offsets + ) { + return m_bufferManager->allocateBuffers(allocs_per_chunk, tensor_offsets); + } + + bool mapFusedBufferOffset( + Qnn_Tensor_t* tensor, + size_t tensorDataSize, + int32_t fd, + uint32_t offset, + uint64_t totalBufferSize, + void* memPointer, + Qnn_ContextHandle_t contextHandle + ) { + return m_bufferManager->mapFusedBufferOffset( + tensor, tensorDataSize, fd, offset, totalBufferSize, memPointer, contextHandle + ); + } + + bool mapFusedBufferOffset( + GraphInfo_t* graph_info, + Qnn_ContextHandle_t context_handle, + const std::map>& graph_allocs + ); + + bool useSameMemory(Qnn_Tensor_t* dest, Qnn_Tensor_t* src) { + return m_bufferManager->useSameMemory(dest, src); + } + + bool useSameMemory(Qnn_Tensor_t* dest, Qnn_Tensor_t* src, int offset) { + return m_bufferManager->useSameMemory(dest, src, offset); + } + + bool useExternalMemory(Qnn_Tensor_t* dest, void* extMem) { + return m_bufferManager->useExternalMemory(dest, extMem); + } + + BufferAlloc getBufferAllocType() { return m_bufferAlloc; } + + std::unordered_set& getFreeTensorsPointerSet() { return m_freeTensorsPointerSet; } + + private: + BufferAlloc m_bufferAlloc; + QNN_INTERFACE_VER_TYPE* m_qnnInterface; + std::unique_ptr m_bufferManager; + std::unordered_set m_freeTensorsPointerSet; + + // There seems to be a race condition in mapFusedBufferOffset because we are + // calling it from multiple threads. Maybe memRegister/memDeRegister is not thread-safe + // Until I figure this out, adding a temporary lock here. TODO: Fix and remove this! + std::mutex _tmp_lock; + + bool deepCopyQnnTensorInfo(Qnn_Tensor_t* dest, Qnn_Tensor_t* src); + bool setupTensors( + Qnn_Tensor_t** tensors, + std::unordered_map& tensorNameToTensorPointer, + uint32_t tensorCount, + TensorWrapper* tensorsInfo, + std::unordered_map& tensorsSize, + Qnn_ContextHandle_t contextHandle, + bool skipBufferAllocation = false + ); +}; \ No newline at end of file diff --git a/Genie/Genie/src/qualla/engines/qnn-api/Log.hpp b/Genie/Genie/src/qualla/engines/qnn-api/Log.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4b551f35140d20aa94674e7827e084f81ac1dc98 --- /dev/null +++ b/Genie/Genie/src/qualla/engines/qnn-api/Log.hpp @@ -0,0 +1,24 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. 
+//
+//==============================================================================
+
+#pragma once
+
+#include <cstdio>
+
+// FIXME: Use logger from qualla::Env
+
+#define QNN_INFO(fmt, ...) fprintf(stderr, "[INFO] " fmt "\n", ##__VA_ARGS__)
+#define QNN_ERROR(fmt, ...) fprintf(stderr, "[ERROR] " fmt "\n", ##__VA_ARGS__)
+#define QNN_WARN(fmt, ...) fprintf(stderr, "[WARN] " fmt "\n", ##__VA_ARGS__)
+
+#if 0
+    // #define NSP_LOG_LEVEL 2
+    #define QNN_DEBUG(fmt, ...) fprintf(stderr, "[DEBUG] " fmt "\n", ##__VA_ARGS__)
+#else
+    #define QNN_DEBUG(fmt, ...)
+#endif
diff --git a/Genie/Genie/src/qualla/engines/qnn-api/NetRunBackend.hpp b/Genie/Genie/src/qualla/engines/qnn-api/NetRunBackend.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..ebb7e6eb111f11404e900523d324c0a55b00881b
--- /dev/null
+++ b/Genie/Genie/src/qualla/engines/qnn-api/NetRunBackend.hpp
@@ -0,0 +1,173 @@
+//==============================================================================
+//
+// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+// All Rights Reserved.
+// Confidential and Proprietary - Qualcomm Technologies, Inc.
+//
+//==============================================================================
+
+#pragma once
+
+#include <memory>
+
+#include "ICommandLineManager.hpp"
+#include "IBackend.hpp"
+
+// This is an implementation of the IBackend interface within qnn-net-run.
+// NetRunBackend provides a dummy implementation of IBackend, as a concrete
+// implementation is needed in case there is no backend extensions library
+// supplied by the user.
+// It is built as part of the QnnNetRun library and is used when there is no
+// user-supplied backend extensions implementation.
+class NetRunBackend final : public IBackend {
+  public:
+    NetRunBackend() {}
+
+    virtual ~NetRunBackend() {}
+
+    virtual bool setupLogging(QnnLog_Callback_t callback, QnnLog_Level_t maxLogLevel) override {
+        ignore(callback);
+        ignore(maxLogLevel);
+        return true;
+    }
+
+    virtual bool initialize(void* backendLibHandle) override {
+        ignore(backendLibHandle);
+        return true;
+    }
+
+    virtual bool setPerfProfile(PerfProfile perfProfile) override {
+        ignore(perfProfile);
+        return true;
+    }
+
+    virtual QnnProfile_Level_t getProfilingLevel() override { return g_profilingLevelNotSet; }
+
+    virtual bool loadConfig(std::string configFile) override {
+        ignore(configFile);
+        return true;
+    }
+
+    virtual bool loadCommandLineArgs(std::shared_ptr<ICommandLineManager> clManager) override {
+        ignore(clManager);
+        return true;
+    }
+
+    virtual bool beforeBackendInitialize(
+        QnnBackend_Config_t*** customConfigs,
+        uint32_t* configCount
+    ) override {
+        ignore(customConfigs);
+        ignore(configCount);
+        return true;
+    }
+
+    virtual bool afterBackendInitialize() override { return true; }
+
+    virtual bool beforeContextCreate(QnnContext_Config_t*** customConfigs, uint32_t* configCount)
+        override {
+        ignore(customConfigs);
+        ignore(configCount);
+        return true;
+    }
+
+    virtual bool afterContextCreate() override { return true; }
+
+    virtual bool beforeComposeGraphs(GraphConfigInfo_t*** customGraphConfigs, uint32_t* graphCount)
+        override {
+        ignore(customGraphConfigs);
+        ignore(graphCount);
+        return true;
+    }
+
+    virtual bool afterComposeGraphs() override { return true; }
+
+#if QUALLA_QNN_API_VERSION >= 21700
+    virtual bool beforeGraphFinalizeUpdateConfig(
+        const char* graphName,
+        Qnn_GraphHandle_t graphHandle,
+        QnnGraph_Config_t*** customConfigs,
+        uint32_t* configCount
+    ) override {
+        ignore(graphName);
+        ignore(graphHandle);
+        ignore(customConfigs);
+        ignore(configCount);
+        return true;
+    }
+#endif
+
+    virtual bool beforeGraphFinalize() override { return true; }
+
+    virtual bool afterGraphFinalize() override { return true; }
+
+    virtual bool beforeRegisterOpPackages() override { return true; }
+
+    virtual bool afterRegisterOpPackages() override { return true; }
+
+    virtual bool beforeExecute(
+        const char* graphName,
+        QnnGraph_Config_t*** customConfigs,
+        uint32_t* configCount
+    ) override {
+        ignore(graphName);
+        ignore(customConfigs);
+        ignore(configCount);
+        return true;
+    }
+
+    virtual bool afterExecute() override { return true; }
+
+    virtual bool beforeContextFree() override { return true; }
+
+    virtual bool afterContextFree() override { return true; }
+
+    virtual bool beforeBackendTerminate() override { return true; }
+
+    virtual bool afterBackendTerminate() override { return true; }
+
+    virtual bool beforeCreateFromBinary(QnnContext_Config_t*** customConfigs, uint32_t* configCount)
+        override {
+        ignore(customConfigs);
+        ignore(configCount);
+        return true;
+    }
+
+    virtual bool afterCreateFromBinary() override { return true; }
+
+#if QUALLA_QNN_API_VERSION >= 21700
+    virtual bool beforeCreateContextsFromBinaryList(
+        std::map>*
+            contextKeyToCustomConfigsMap,
+        QnnContext_Config_t*** commonCustomConfigs,
+        uint32_t* commonConfigCount
+    ) override {
+        ignore(contextKeyToCustomConfigsMap);
+        ignore(commonCustomConfigs);
+        ignore(commonConfigCount);
+        return true;
+    }
+
+    virtual bool afterCreateContextsFromBinaryList() override { return true; }
+#endif
+
+    virtual bool beforeCreateDevice(QnnDevice_Config_t*** deviceConfigs, uint32_t* configCount)
+        override {
+        ignore(deviceConfigs);
+        ignore(configCount);
+        return true;
+    }
+
+    virtual bool afterCreateDevice() override { return true; }
+
+    virtual bool beforeFreeDevice() override { return true; }
+
+    virtual bool afterFreeDevice() override { return true; }
+
+  private:
+    // Utility function to ignore compiler warnings when a variable
+    // is unused. Recommended by Herb Sutter in Sutter's Mill
+    // instead of (void)variable.
+    template <typename T>
+    void ignore(const T&) {}
+};
diff --git a/Genie/Genie/src/qualla/engines/qnn-api/QnnApi.cpp b/Genie/Genie/src/qualla/engines/qnn-api/QnnApi.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..e7abc2b8db5d23ce164b037e9be2a1cfd0597415
--- /dev/null
+++ b/Genie/Genie/src/qualla/engines/qnn-api/QnnApi.cpp
@@ -0,0 +1,2681 @@
+//==============================================================================
+//
+// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+// All Rights Reserved.
+// Confidential and Proprietary - Qualcomm Technologies, Inc.
+// +//============================================================================== + +#include +#if defined(__GNUC__) && !defined(__clang__) +#include +#endif +#ifndef _WIN32 + #include +#endif + +#include "dlwrap.hpp" +#include "QnnApi.hpp" + +#ifdef SPILLFILL + #include "QnnHtpContext.h" + #include "QnnHtpCommon.h" +#endif + +QnnApi::~QnnApi() { + // QNN_DEBUG("Destroying Performance"); + // if (true != destroyPerformance()) { + // QNN_DEBUG("Could not destroy Performance"); + // } + + QNN_DEBUG("Freeing Graphs"); + if (true != freeGraphs()) { + QNN_DEBUG("Could not free Graphs"); + } + + // Free context if not already done + if (m_isContextCreated) { + QNN_DEBUG("Freeing Context"); + if (true != freeContext()) { + QNN_DEBUG("Could not free context"); + } + } + + if (m_profileBackendHandle) { + QNN_DEBUG("Freeing profile handle"); + if (QNN_PROFILE_NO_ERROR != m_qnnInterface.profileFree(m_profileBackendHandle)) + QNN_ERROR("Could not free QNN HTP backend profile handle."); + } + + QNN_DEBUG("Freeing Device"); + if (getDeviceStatus()) { + if (true != freeDevice()) { + QNN_ERROR("Device Free failure"); + } + } + + QNN_DEBUG("Terminating Logging"); + if (m_isLogInitialized) { + terminateLog(); + } + m_isLogInitialized = false; + + // Terminate backend + if (m_isBackendInitialized) { + QNN_DEBUG("Terminating Backend"); + if (true != terminateBackend()) { + QNN_DEBUG("Could not terminate backend"); + } + } + + // Skip dlclose for HTP because it runs its own cleanup routines later. + if (m_backendLibraryHandle && (m_backendId != QNN_BACKEND_ID_HTP)) { + QNN_DEBUG("Closing Backend Lib Handle"); + dlclose(m_backendLibraryHandle); + } + + if (m_libModelHandle) { + QNN_DEBUG("Closing Model Lib Handle"); + dlclose(m_libModelHandle); + } + + if (!m_contextBinBuffersToBeCleared.empty()) { + for (auto& [buffer, bufferSize] : m_contextBinBuffersToBeCleared) { + QNN_DEBUG("Free context bin buffer %p of size %lu", buffer, bufferSize); + if (m_mmapContextBins) { +#ifndef _WIN32 + if (munmap(buffer, bufferSize)) { + QNN_ERROR("Failed to unmap buffer for context"); + } +#endif + } else { + delete[] buffer; + } + } + m_contextBinBuffersToBeCleared.clear(); + } +} + +bool QnnApi::getContextConfigs( + QnnContext_Config_t*** configs, + uint32_t& contextConfigCount, + Qnn_Priority_t contextPriority, + bool graphSwitching, + const std::vector& execSelectGraphs, + bool loadSelectGraphs +) { + std::vector contextConfigPtrsVec; + + if (contextPriority != QNN_PRIORITY_DEFAULT) { + contextConfigPtrsVec.push_back((QnnContext_Config_t*)malloc(sizeof(QnnContext_Config_t))); + contextConfigPtrsVec.back()->option = + QnnContext_ConfigOption_t::QNN_CONTEXT_CONFIG_OPTION_PRIORITY; + contextConfigPtrsVec.back()->priority = contextPriority; + } + + const char** graphNames = nullptr; + + if (loadSelectGraphs && !execSelectGraphs.empty()) { + graphNames = (const char**)malloc(sizeof(const char*) * (execSelectGraphs.size() + 1)); + for (size_t i = 0; i < execSelectGraphs.size(); ++i) { + graphNames[i] = execSelectGraphs[i].c_str(); + } + + graphNames[execSelectGraphs.size()] = nullptr; // NULL termination + contextConfigPtrsVec.push_back((QnnContext_Config_t*)malloc(sizeof(QnnContext_Config_t))); + contextConfigPtrsVec.back()->option = + QnnContext_ConfigOption_t::QNN_CONTEXT_CONFIG_ENABLE_GRAPHS; + contextConfigPtrsVec.back()->enableGraphs = graphNames; + } + + if (graphSwitching) { + contextConfigPtrsVec.push_back((QnnContext_Config_t*)malloc(sizeof(QnnContext_Config_t))); + contextConfigPtrsVec.back()->option = + 
QnnContext_ConfigOption_t::QNN_CONTEXT_CONFIG_MEMORY_LIMIT_HINT; + contextConfigPtrsVec.back()->memoryLimitHint = 1024; + + contextConfigPtrsVec.push_back((QnnContext_Config_t*)malloc(sizeof(QnnContext_Config_t))); + contextConfigPtrsVec.back()->option = + QnnContext_ConfigOption_t::QNN_CONTEXT_CONFIG_PERSISTENT_BINARY; + contextConfigPtrsVec.back()->isPersistentBinary = 1; + } + + contextConfigCount = contextConfigPtrsVec.size(); + + QnnContext_Config_t** contextConfigPtrs = + (QnnContext_Config_t**)malloc(contextConfigCount * sizeof(QnnContext_Config_t*)); + + if (nullptr == contextConfigPtrs) { + QNN_ERROR("Could not allocate memory for allContextConfigs"); + return false; + } + + for (size_t i = 0; i < contextConfigCount; i++) { + contextConfigPtrs[i] = contextConfigPtrsVec[i]; + } + + *configs = contextConfigPtrs; + + return true; +} + +bool QnnApi::mergeAllContextConfigs( + QnnContext_Config_t*** allCustomContextConfigs, + QnnContext_Config_t** customConfigs, + QnnContext_Config_t** contextConfigs, + uint32_t customConfigCount, + uint32_t contextConfigCount +) { + QnnContext_Config_t** allContextConfigs{nullptr}; + if (contextConfigCount + customConfigCount > 0) { + allContextConfigs = (QnnContext_Config_t**)calloc( + (contextConfigCount + customConfigCount + 1), sizeof(QnnContext_Config_t*) + ); + if (nullptr == allContextConfigs) { + QNN_ERROR("Could not allocate memory for allContextConfigs"); + return false; + } + for (size_t cnt = 0; cnt < contextConfigCount; cnt++) { + allContextConfigs[cnt] = contextConfigs[cnt]; + } + for (size_t cnt = 0; cnt < customConfigCount; cnt++) { + allContextConfigs[cnt + contextConfigCount] = customConfigs[cnt]; + } + } + *allCustomContextConfigs = allContextConfigs; + + return true; +} + +bool QnnApi::freeContextConfigs(QnnContext_Config_t** contextConfigs, uint32_t contextConfigCount) { + if (contextConfigs) { + for (size_t i = 0; i < contextConfigCount; i++) { + if (contextConfigs[i]->option == QNN_CONTEXT_CONFIG_ENABLE_GRAPHS) { + free((const char**)contextConfigs[i]->enableGraphs); + } + free(contextConfigs[i]); + } + free(contextConfigs); + } + + return true; +} + +bool QnnApi::setGraphConfigsBeforeExecute( + Qnn_GraphHandle_t graphHandle, + QnnGraph_Config_t** graphConfigs, + uint32_t configCount +) { + if (!graphConfigs || configCount == 0u) { + QNN_ERROR("No graph configs to set"); + return false; + } + + std::vector graphConfigsPointers(configCount + 1, nullptr); + for (size_t idx = 0u; idx < configCount; idx++) { + graphConfigsPointers[idx] = graphConfigs[idx]; + } + if (QNN_SUCCESS != m_qnnInterface.graphSetConfig(graphHandle, graphConfigsPointers.data())) { + QNN_ERROR("Failed to set graph configs."); + return false; + } + + return true; +} + +bool QnnApi::getQnnInterface(std::string backendPath) { + + QnnInterfaceGetProvidersFn_t getInterfaceProviders{nullptr}; + + m_backendLibraryHandle = dlopen(backendPath.c_str(), RTLD_NOW); + if (nullptr == m_backendLibraryHandle) { + QNN_ERROR("Unable to load backend. 
dlerror(): %s", dlerror()); + return false; + } + + // Get QNN Interface + getInterfaceProviders = (QnnInterfaceGetProvidersFn_t + )dlsym(m_backendLibraryHandle, "QnnInterface_getProviders"); + if (nullptr == getInterfaceProviders) { + return false; + } + + uint32_t numProviders{0}; + QnnInterface_t** interfaceProviders{nullptr}; + if (QNN_SUCCESS != + getInterfaceProviders((const QnnInterface_t***)&interfaceProviders, &numProviders)) { + QNN_ERROR("Failed to get interface providers."); + return false; + } + + if (nullptr == interfaceProviders) { + QNN_ERROR("Failed to get interface providers: null interface providers received."); + return false; + } + if (0u == numProviders) { + QNN_ERROR("Failed to get interface providers: 0 interface providers."); + return false; + } + + bool foundValidInterface{false}; + for (size_t pIdx = 0; pIdx < numProviders; pIdx++) { + const Qnn_ApiVersion_t& apiVersion = interfaceProviders[pIdx]->apiVersion; + if ((QNN_API_VERSION_MAJOR == apiVersion.coreApiVersion.major) && + (QNN_API_VERSION_MINOR <= apiVersion.coreApiVersion.minor)) { + foundValidInterface = true; + m_qnnInterface = interfaceProviders[pIdx]->QNN_INTERFACE_VER_NAME; + m_backendId = interfaceProviders[pIdx]->backendId; + break; + } + } + + if (!foundValidInterface) { + QNN_ERROR("Unable to find a valid interface."); + m_backendLibraryHandle = nullptr; + return false; + } + + return true; +} + +bool QnnApi::getQnnSystemInterface(std::string systemLibraryPath) { + QnnSystemInterfaceGetProvidersFn_t getSystemInterfaceProviders{nullptr}; + + void* systemLibraryHandle = dlopen(systemLibraryPath.c_str(), RTLD_NOW); + if (nullptr == systemLibraryHandle) { + QNN_ERROR("Unable to load system library. dlerror(): %s", dlerror()); + return false; + } + + // Get QNN System Interface + getSystemInterfaceProviders = (QnnSystemInterfaceGetProvidersFn_t + )dlsym(systemLibraryHandle, "QnnSystemInterface_getProviders"); + if (nullptr == getSystemInterfaceProviders) { + return false; + } + + uint32_t numProviders{0}; + QnnSystemInterface_t** systemInterfaceProviders{nullptr}; + if (QNN_SUCCESS != + getSystemInterfaceProviders( + (const QnnSystemInterface_t***)&systemInterfaceProviders, &numProviders + )) { + QNN_ERROR("Failed to get system interface providers."); + return false; + } + if (nullptr == systemInterfaceProviders) { + QNN_ERROR( + "Failed to get system interface providers: null system interface providers received." + ); + return false; + } + if (0 == numProviders) { + QNN_ERROR("Failed to get system interface providers: 0 system interface providers."); + return false; + } + + bool foundValidSystemInterface{false}; + for (size_t pIdx = 0; pIdx < numProviders; pIdx++) { + const Qnn_Version_t& systemApiVersion = systemInterfaceProviders[pIdx]->systemApiVersion; + if (QNN_SYSTEM_API_VERSION_MAJOR == systemApiVersion.major && + QNN_SYSTEM_API_VERSION_MINOR <= systemApiVersion.minor) { + foundValidSystemInterface = true; + m_qnnSystemInterface = systemInterfaceProviders[pIdx]->QNN_SYSTEM_INTERFACE_VER_NAME; + break; + } + } + if (!foundValidSystemInterface) { + QNN_ERROR("Unable to find a valid system interface."); + return false; + } + + return true; +} + +bool QnnApi::loadModel(std::string model_path) { + const char* dlsym_error; + + dlerror(); + m_libModelHandle = dlopen(model_path.c_str(), RTLD_NOW); + if (nullptr == m_libModelHandle) { + QNN_ERROR("Unable to load model. dlerror(): %s", dlerror()); + return false; + } + + // Currently model Prefix is fixed. 
If model was prepared with + // custom prefix, we need to change this. + std::string modelPrefix = "QnnModel"; + + std::string modelPrepareFunc = modelPrefix + "_composeGraphs"; + m_composeGraphsFnHandle = + (ComposeGraphsFnHandleType_t)dlsym(m_libModelHandle, modelPrepareFunc.c_str()); + dlsym_error = dlerror(); + if (dlsym_error || nullptr == m_composeGraphsFnHandle) { + m_composeGraphsFnHandle = nullptr; + std::string genaiModelPrepareFunc = "QnnModel_GenAI_composeGraphs"; + m_genaiComposeGraphsFnHandle = (GenAIComposeGraphsFnHandleType_t + )dlsym(m_libModelHandle, genaiModelPrepareFunc.c_str()); + dlsym_error = dlerror(); + if (dlsym_error || nullptr == m_genaiComposeGraphsFnHandle) { + QNN_ERROR("Did not find QnnModel_composeGraph function: %s", dlsym_error); + return false; + } + } + + std::string modelFreeFunc = modelPrefix + "_freeGraphsInfo"; + m_freeGraphInfoFnHandle = + (FreeGraphInfoFnHandleType_t)dlsym(m_libModelHandle, modelFreeFunc.c_str()); + dlsym_error = dlerror(); + if (dlsym_error || nullptr == m_freeGraphInfoFnHandle) { + QNN_ERROR("Did not find QnnModel_freeGraphsInfo function: %s", dlsym_error); + return false; + } + + return true; +} + +void QnnApi::qnnLogCallback( + const char* fmt, + QnnLog_Level_t level, + uint64_t timestamp, + va_list args +) { + char buffer[1024] = ""; + const char* levelStr = ""; + switch (level) { + case QNN_LOG_LEVEL_ERROR: + levelStr = " ERROR "; + break; + case QNN_LOG_LEVEL_WARN: + levelStr = "WARNING"; + break; + case QNN_LOG_LEVEL_INFO: + levelStr = " INFO "; + break; + case QNN_LOG_LEVEL_DEBUG: + levelStr = " DEBUG "; + break; + case QNN_LOG_LEVEL_VERBOSE: + levelStr = "VERBOSE"; + break; + case QNN_LOG_LEVEL_MAX: + levelStr = "UNKNOWN"; + break; + } + + int pos = snprintf( + buffer, sizeof(buffer), "QNN: [%s] time=%lu:", levelStr, (unsigned long)timestamp + ); + vsnprintf(buffer + pos, sizeof(buffer) - pos, fmt, args); + printf("%s", buffer); +} + +bool QnnApi::initializeLogging(const QnnLog_Level_t& logLevel, bool debug_qnn) { + // initialize logging in the backend + if (nullptr != m_qnnInterface.logCreate) { + QnnLog_Callback_t logCallback = nullptr; + if (debug_qnn) logCallback = QnnApi::qnnLogCallback; + + QNN_DEBUG( + "Initializing logging in the backend. Callback: [%p], Log Level: [%d]", + logCallback, + logLevel + ); + if (QNN_SUCCESS != m_qnnInterface.logCreate(logCallback, logLevel, &m_logHandle)) { + QNN_WARN("Unable to initialize logging in the backend."); + } + m_isLogInitialized = true; + } + else { + QNN_WARN("Logging not available in the backend."); + return true; + } + + return true; +} + +void QnnApi::terminateLog() { + // Terminate logging in the backend + if (nullptr != m_qnnInterface.logFree && nullptr != m_logHandle) { + if (QNN_SUCCESS != m_qnnInterface.logFree(m_logHandle)) { + QNN_WARN("Unable to terminate logging in the backend."); + } + } +} + +bool QnnApi::initializeBackendExtensions( + BackendExtensionsConfigs backendExtensionsConfig, + PerfProfile parsedPerfProfile, + bool debug_qnn +) { + + std::unique_ptr backendExtensions(new BackendExtensions( + backendExtensionsConfig, m_backendLibraryHandle, parsedPerfProfile, nullptr, debug_qnn + )); + if (nullptr == backendExtensions) { + QNN_ERROR("Unable to create backend extensions object."); + return false; + } + if (!backendExtensions->initialize()) { + QNN_ERROR("Unable to initialize backend extensions."); + return false; + } + m_backendExtensions = std::move(backendExtensions); + + return true; +} + +// Initialize a QnnBackend. 
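initializeBackend() below, like createContext() further down, has to combine the configs returned by the extension's before-hook with QnnApi's own into a single null-terminated array before calling into QNN. A standalone sketch of that merge pattern; `mergeConfigLists` is illustrative, not an API from this patch:

```cpp
#include <cstdint>
#include <cstdlib>

// Merge two C-style config lists into one calloc'd array whose final slot is
// nullptr, the layout the QnnBackend_Config_t**/QnnContext_Config_t** APIs expect.
template <typename Cfg>
Cfg** mergeConfigLists(Cfg** a, uint32_t aCount, Cfg** b, uint32_t bCount) {
    if (aCount + bCount == 0) return nullptr;
    Cfg** merged = (Cfg**)calloc(aCount + bCount + 1, sizeof(Cfg*));  // +1 keeps the null terminator
    if (merged == nullptr) return nullptr;
    for (uint32_t i = 0; i < aCount; ++i) merged[i] = a[i];
    for (uint32_t i = 0; i < bCount; ++i) merged[aCount + i] = b[i];
    return merged;  // caller releases the array with free(); elements are not owned
}
```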
+bool QnnApi::initializeBackend() { + if (nullptr == m_qnnInterface.backendCreate) { + QNN_ERROR("BackendCreate API is not supported for this backend"); + return false; + } + + QnnBackend_Config_t** customConfigs{nullptr}; + uint32_t customConfigCount{0}; + if (nullptr != m_backendExtensions && m_backendExtensions->interface()) { + if (!m_backendExtensions->interface()->beforeBackendInitialize( + &customConfigs, &customConfigCount + )) { + QNN_ERROR("Extensions Failure in beforeBackendInitialize()"); + return false; + } + } + QnnBackend_Config_t** allBackendConfigs{nullptr}; + if ((m_backendConfigCount + customConfigCount) > 0) { + allBackendConfigs = (QnnBackend_Config_t**)calloc( + (m_backendConfigCount + customConfigCount + 1), sizeof(QnnBackend_Config_t*) + ); + if (nullptr == allBackendConfigs) { + QNN_ERROR("Could not allocate memory for allBackendConfigs"); + return false; + } + for (size_t cnt = 0; cnt < m_backendConfigCount; cnt++) { + allBackendConfigs[cnt] = m_backendConfigs[cnt]; + } + for (size_t cnt = 0; cnt < customConfigCount; cnt++) { + allBackendConfigs[cnt + m_backendConfigCount] = customConfigs[cnt]; + } + } + + auto returnStatus = m_qnnInterface.backendCreate( + m_logHandle, (const QnnBackend_Config_t**)allBackendConfigs, &m_backendHandle + ); + if (QNN_SUCCESS != returnStatus) { + QNN_ERROR( + "Could not initialize backend due to error = %llu", (unsigned long long)returnStatus + ); + if (allBackendConfigs) { + free(allBackendConfigs); + } + return false; + } + QNN_DEBUG("Initialize Backend Returned Status = %llu", (unsigned long long)returnStatus); + + m_isBackendInitialized = true; + if (allBackendConfigs) { + free(allBackendConfigs); + } + + if (nullptr != m_backendExtensions && m_backendExtensions->interface()) { + if (!m_backendExtensions->interface()->afterBackendInitialize()) { + QNN_ERROR("Extensions Failure in afterBackendInitialize()"); + return false; + } + } + + return true; +} + +// Terminate the backend after done. 
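Every lifecycle method in this file follows the same shape: run the extension's before-hook, perform the QNN operation, then run the after-hook, failing fast if any step refuses. Reduced to a hedged sketch (`runStep` and the `std::function` parameters are illustrative, not part of the patch):

```cpp
#include <functional>

// The before/call/after bracketing used by initializeBackend, terminateBackend,
// createDevice, createContext, and the other lifecycle methods in this file.
bool runStep(const std::function<bool()>& before,
             const std::function<bool()>& qnnCall,
             const std::function<bool()>& after) {
    if (before && !before()) return false;  // extension vetoed the step
    if (!qnnCall()) return false;           // the actual QNN API call
    if (after && !after()) return false;    // extension post-processing failed
    return true;
}
```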
+bool QnnApi::terminateBackend() {
+
+ if (nullptr != m_backendExtensions && m_backendExtensions->interface()) {
+ if (!m_backendExtensions->interface()->beforeBackendTerminate()) {
+ QNN_ERROR("Extensions Failure in beforeBackendTerminate()");
+ return false;
+ }
+ }
+ // Terminate backend
+ if (m_isBackendInitialized && nullptr != m_qnnInterface.backendFree) {
+ QNN_DEBUG("Freeing backend");
+ if (QNN_BACKEND_NO_ERROR != m_qnnInterface.backendFree(m_backendHandle)) {
+ QNN_ERROR("Could not free backend");
+ }
+ }
+ m_isBackendInitialized = false;
+
+ if (nullptr != m_backendExtensions && m_backendExtensions->interface()) {
+ if (!m_backendExtensions->interface()->afterBackendTerminate()) {
+ QNN_ERROR("Extensions Failure in afterBackendTerminate()");
+ return false;
+ }
+ }
+
+ return true;
+}
+
+bool QnnApi::createDevice() {
+ QnnDevice_Config_t** deviceConfigs{nullptr};
+ uint32_t configCount{0};
+
+ if (nullptr != m_backendExtensions && m_backendExtensions->interface()) {
+ if (!m_backendExtensions->interface()->beforeCreateDevice(&deviceConfigs, &configCount)) {
+ QNN_ERROR("Extensions Failure in beforeCreateDevice()");
+ return false;
+ }
+ }
+ std::vector<const QnnDevice_Config_t*> deviceConfigPointers(configCount + 1, nullptr);
+ for (size_t idx = 0u; idx < configCount; idx++) {
+ deviceConfigPointers[idx] = deviceConfigs[idx];
+ }
+ if (nullptr != m_qnnInterface.deviceCreate) {
+ auto qnnStatus = m_qnnInterface.deviceCreate(
+ m_logHandle, deviceConfigPointers.data(), &m_deviceHandle
+ );
+ if (QNN_SUCCESS != qnnStatus) {
+ if (QNN_DEVICE_ERROR_UNSUPPORTED_FEATURE == qnnStatus) {
+ QNN_WARN("Device feature unsupported");
+ } else {
+ QNN_ERROR("Failed to create device: %lu", (unsigned long)qnnStatus);
+ return false;
+ }
+ }
+ }
+ if (nullptr != m_backendExtensions && m_backendExtensions->interface()) {
+ if (!m_backendExtensions->interface()->afterCreateDevice()) {
+ QNN_ERROR("Extensions Failure in afterCreateDevice()");
+ return false;
+ }
+ }
+ return true;
+}
+
+bool QnnApi::freeDevice() {
+ if (nullptr != m_backendExtensions && m_backendExtensions->interface()) {
+ if (!m_backendExtensions->interface()->beforeFreeDevice()) {
+ QNN_ERROR("Extensions Failure in beforeFreeDevice()");
+ return false;
+ }
+ }
+ if (nullptr != m_qnnInterface.deviceFree) {
+ auto qnnStatus = m_qnnInterface.deviceFree(m_deviceHandle);
+ if (QNN_SUCCESS != qnnStatus) {
+ if (QNN_DEVICE_ERROR_UNSUPPORTED_FEATURE == qnnStatus) {
+ QNN_WARN("Device feature unsupported");
+ } else {
+ QNN_ERROR("Failed to free device: %lu", (unsigned long)qnnStatus);
+ return false;
+ }
+ }
+ }
+ if (nullptr != m_backendExtensions && m_backendExtensions->interface()) {
+ if (!m_backendExtensions->interface()->afterFreeDevice()) {
+ QNN_ERROR("Extensions Failure in afterFreeDevice()");
+ return false;
+ }
+ }
+ return true;
+}
+
+// Create a Context in a backend.
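+//
+// Illustrative call, assuming ContextConfigs carries at least the priority
+// field consumed below (variable names here are hypothetical):
+//
+//   ContextConfigs ctxCfg;
+//   ctxCfg.priority = QNN_PRIORITY_DEFAULT;
+//   if (!api.createContext(ctxCfg)) { /* handle failure */ }
+//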
+bool QnnApi::createContext(ContextConfigs contextConfig) { + QnnContext_Config_t** customConfigs{nullptr}; + uint32_t customConfigCount{0}; + if (nullptr != m_backendExtensions && m_backendExtensions->interface()) { + if (!m_backendExtensions->interface()->beforeContextCreate( + &customConfigs, &customConfigCount + )) { + QNN_ERROR("Extensions Failure in beforeContextCreate()"); + return false; + } + } + + QnnContext_Config_t** contextConfigs = nullptr; + uint32_t contextConfigCount = 0; + if (true != getContextConfigs(&contextConfigs, contextConfigCount, contextConfig.priority)) { + QNN_ERROR("Couldn't populate context configs"); + return false; + } + + QnnContext_Config_t** allContextConfigs{nullptr}; + if (true != mergeAllContextConfigs( + &allContextConfigs, + customConfigs, + contextConfigs, + customConfigCount, + contextConfigCount + )) { + QNN_ERROR("Error merging custom and context configs"); + return false; + } + + Qnn_ContextHandle_t contextHandle{nullptr}; + if (QNN_CONTEXT_NO_ERROR != m_qnnInterface.contextCreate( + m_backendHandle, + nullptr, + (const QnnContext_Config_t**)allContextConfigs, + &contextHandle + )) { + QNN_ERROR("Could not create context"); + if (allContextConfigs) { + free(allContextConfigs); + } + + return false; + } + + m_contextVec.push_back(contextHandle); + m_isContextCreated = true; + if (allContextConfigs) { + free(allContextConfigs); + } + + if (true != freeContextConfigs(contextConfigs, contextConfigCount)) { + QNN_ERROR("Couldn't free context configs"); + return false; + } + + if (nullptr != m_backendExtensions && m_backendExtensions->interface()) { + if (!m_backendExtensions->interface()->afterContextCreate()) { + QNN_ERROR("Extensions Failure in afterContextCreate()"); + return false; + } + } + + return true; +} + +// Free context after done. +bool QnnApi::freeContext() { + + if (nullptr != m_backendExtensions && m_backendExtensions->interface()) { + if (!m_backendExtensions->interface()->beforeContextFree()) { + QNN_ERROR("Extensions Failure in beforeContextFree()"); + return false; + } + } + for (const auto& context : m_contextVec) { + if (QNN_CONTEXT_NO_ERROR != m_qnnInterface.contextFree(context, nullptr)) { + QNN_ERROR("Could not free context"); + return false; + } + } + m_isContextCreated = false; + + if (nullptr != m_backendExtensions && m_backendExtensions->interface()) { + if (!m_backendExtensions->interface()->afterContextFree()) { + QNN_ERROR("Extensions Failure in afterContextFree()"); + return false; + } + } + + return true; +} + +// Calls composeGraph function in QNN's model.so. +// composeGraphs is supposed to populate graph related +// information in graphsInfo and graphsCount. +// m_debug is the option supplied to composeGraphs to +// say that all intermediate tensors including output tensors +// are expected to be read by the app. 
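+//
+// For reference, the model library entry point resolved earlier is expected to
+// match ComposeGraphsFnHandleType_t (see QnnApi.hpp), roughly:
+//
+//   ModelError_t QnnModel_composeGraphs(
+//       Qnn_BackendHandle_t backend, QNN_INTERFACE_VER_TYPE iface,
+//       Qnn_ContextHandle_t context, const GraphConfigInfo_t** graphConfigs,
+//       uint32_t numGraphConfigs, GraphInfo_t*** graphsInfo, uint32_t* graphsCount,
+//       bool debug, QnnLog_Callback_t logCallback, QnnLog_Level_t logLevel);
+//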
+bool QnnApi::composeGraphs(std::vector graphConfigs) { + GraphConfigInfo_t** customConfigs{nullptr}; + uint32_t customConfigGraphsCount{0}; + if (nullptr != m_backendExtensions && m_backendExtensions->interface()) { + if (!m_backendExtensions->interface()->beforeComposeGraphs( + &customConfigs, &customConfigGraphsCount + )) { + QNN_ERROR("Extensions Failure in beforeComposeGraphs()"); + return false; + } + } + + std::map> graphConfigsPointers; + if (!graphConfigs.empty()) { + for (auto const& inputGraphConfig : graphConfigs) { + // Only reset the memory for this graph, if it has not previously been populated with + // something + if (graphConfigsPointers.find(inputGraphConfig.graphName) == + graphConfigsPointers.end()) { + graphConfigsPointers[inputGraphConfig.graphName] = + std::vector(); + graphConfigsPointers[inputGraphConfig.graphName].reserve(s_graphConfigsReserveCount + ); + } + if (inputGraphConfig.priorityPresent) { + QnnGraph_Config_t* newGraphConfig = + (QnnGraph_Config_t*)malloc(sizeof(QnnGraph_Config_t)); + newGraphConfig->option = QNN_GRAPH_CONFIG_OPTION_PRIORITY; + newGraphConfig->priority = inputGraphConfig.priority; + graphConfigsPointers[inputGraphConfig.graphName].push_back(newGraphConfig); + } + } + } + + if (customConfigs != nullptr && customConfigGraphsCount > 0) { + for (size_t gIdx = 0; gIdx < customConfigGraphsCount; gIdx++) { + auto configPtr = customConfigs[gIdx]->graphConfigs; + if (*configPtr && + (!customConfigs[gIdx]->graphName || strlen(customConfigs[gIdx]->graphName) == 0)) { + QNN_ERROR("Graph configs specified without a graph name in the backend extensions." + ); + return false; + } + if (customConfigs[gIdx]->graphName && strlen(customConfigs[gIdx]->graphName) > 0 && + *configPtr) { + if (graphConfigsPointers.find(customConfigs[gIdx]->graphName) == + graphConfigsPointers.end()) { + graphConfigsPointers[customConfigs[gIdx]->graphName] = + std::vector(); + graphConfigsPointers[customConfigs[gIdx]->graphName].reserve( + s_graphConfigsReserveCount + ); + } + while (*configPtr) { + graphConfigsPointers[customConfigs[gIdx]->graphName].push_back( + (QnnGraph_Config_t*)*configPtr + ); + configPtr++; + } + } + } + } + + GraphConfigInfo_t** graphConfigsInfo{nullptr}; + graphConfigsInfo = + (GraphConfigInfo_t**)calloc(graphConfigsPointers.size(), sizeof(GraphConfigInfo_t*)); + size_t graphIdx{0}; + for (auto const& graphConfig : graphConfigsPointers) { + if (graphConfigsInfo && graphConfig.second.size() > 0) { + graphConfigsInfo[graphIdx] = (GraphConfigInfo_t*)malloc(sizeof(GraphConfigInfo_t)); + graphConfigsInfo[graphIdx]->graphName = (char*)graphConfig.first.c_str(); + graphConfigsInfo[graphIdx]->graphConfigs = (const QnnGraph_Config_t**)calloc( + graphConfig.second.size() + 1, sizeof(QnnGraph_Config_t*) + ); + for (size_t cnt = 0; cnt < graphConfig.second.size(); cnt++) { + graphConfigsInfo[graphIdx]->graphConfigs[cnt] = graphConfig.second[cnt]; + } + } + graphIdx++; + } + + int status = m_composeGraphsFnHandle( + m_backendHandle, + m_qnnInterface, + m_contextVec[0], + (const GraphConfigInfo_t**)graphConfigsInfo, + graphConfigsPointers.size(), + &m_graphsInfo, + &m_graphsCount, + m_DebugModeRequested, + nullptr, + QnnLog_Level_t::QNN_LOG_LEVEL_VERBOSE + ); + + if (graphConfigsInfo) { + for (size_t gIdx = 0; gIdx < graphConfigsPointers.size(); gIdx++) { + if (graphConfigsInfo[gIdx]) { + if (graphConfigsInfo[gIdx]->graphConfigs) { + free(graphConfigsInfo[gIdx]->graphConfigs); + graphConfigsInfo[gIdx]->graphConfigs = nullptr; + graphConfigsInfo[gIdx]->graphName = 
nullptr; + } + free(graphConfigsInfo[gIdx]); + graphConfigsInfo[gIdx] = nullptr; + } + } + free(graphConfigsInfo); + } + + for (auto const& graphConfig : graphConfigsPointers) { + for (size_t cnt = 0; cnt < graphConfig.second.size(); cnt++) { + if (graphConfig.second[cnt]) { + free(graphConfig.second[cnt]); + } + } + // graphConfig.second.clear(); + } + + if (nullptr != m_backendExtensions && m_backendExtensions->interface()) { + if (!m_backendExtensions->interface()->afterComposeGraphs()) { + QNN_ERROR("Extensions Failure in afterComposeGraphs()"); + return false; + } + } + + if (0 != status) { + QNN_ERROR("Failed in composeGraphs()"); + return false; + } + + // For now, we only handle 1 graph for this framework. + if (m_graphsCount != 1) { + QNN_ERROR("Only one graph is supported by framework"); + return false; + } + + return true; +} + +bool QnnApi::composeGraphs( + std::vector graphConfigs, + uint32_t* inputDim, + uint32_t inputRank, + uint32_t* outputDim, + uint32_t outputRank, + uint32_t* kvDim, + uint32_t kvRank, + Qnn_Param_t* params, + uint32_t numParams +) { + ModelError status = m_genaiComposeGraphsFnHandle( + m_backendHandle, + m_qnnInterface, + m_contextVec[0], + nullptr, + 0, + inputDim, + inputRank, + outputDim, + outputRank, + kvDim, + kvRank, + params, + numParams, + &m_graphsInfo, + &m_graphsCount, + m_DebugModeRequested, + nullptr, + QnnLog_Level_t::QNN_LOG_LEVEL_VERBOSE + ); + + if (status == MODEL_NO_ERROR) { + return true; + } + + return false; +} + +bool QnnApi::finalizeGraphs() { + if (nullptr != m_backendExtensions && m_backendExtensions->interface()) { + if (!m_backendExtensions->interface()->beforeGraphFinalize()) { + QNN_ERROR("Extensions Failure in beforeGraphFinalize()"); + return false; + } + } + + for (size_t graphIdx = 0; graphIdx < m_graphsCount; graphIdx++) { + if (QNN_GRAPH_NO_ERROR != + m_qnnInterface.graphFinalize(m_graphsInfo[graphIdx]->graph, nullptr, nullptr)) { + return false; + } + + if (m_profileBackendHandle) { + extractBackendProfilingInfo(m_profileBackendHandle); + } + } + + if (nullptr != m_backendExtensions && m_backendExtensions->interface()) { + if (!m_backendExtensions->interface()->afterGraphFinalize()) { + QNN_ERROR("Extensions Failure in afterGraphFinalize()"); + return false; + } + } + + return true; +} + +bool QnnApi::freeGraphs() { + freeGraphsInfo(&m_graphsInfo, m_graphsCount); + if (m_graphsInfo) { + free(m_graphsInfo); + } + m_graphsInfo = nullptr; + m_graphsCount = 0; + return true; +} + +bool QnnApi::mapAndGetContextBinaryInfo( + const bool use_mmap, + std::shared_ptr& buffer, + const std::string binaryPath, + const uint64_t bufferSize, + const size_t contextIdx, + const bool graphSwitching, + QnnSystemContext_Handle_t sysCtxHandle, + const QnnSystemContext_BinaryInfo_t** binaryInfo +) { + if (use_mmap) { +#ifndef _WIN32 + void* mappedBuffer = nullptr; + if (true != mmapBinaryFile(binaryPath, &mappedBuffer, bufferSize)) { + QNN_ERROR("Failed to read binary data for context index = %zu", contextIdx); + return false; + } + buffer = std::shared_ptr( + static_cast(mappedBuffer), + [graphSwitching, bufferSize](uint8_t* ptr) { + if (!graphSwitching) { + munmap(ptr, bufferSize); + } + } + ); +#else + return false; +#endif + } else { + buffer = std::shared_ptr(new uint8_t[bufferSize], [graphSwitching](uint8_t* ptr) { + if (!graphSwitching) { + delete[] ptr; + } + }); + + if (!buffer) { + QNN_ERROR("Failed to allocate memory for context index = %zu", contextIdx); + return false; + } + if (true != readBinaryFromFile(binaryPath, 
buffer.get(), bufferSize)) {
+ QNN_ERROR("Failed to read binary data for context index = %zu", contextIdx);
+ return false;
+ }
+ }
+
+ if (graphSwitching) {
+ m_contextBinBuffersToBeCleared.push_back({buffer.get(), bufferSize});
+ }
+
+ Qnn_ContextBinarySize_t binaryInfoSize{0};
+ if (QNN_SUCCESS != m_qnnSystemInterface.systemContextGetBinaryInfo(
+ sysCtxHandle,
+ static_cast<void*>(buffer.get()),
+ bufferSize,
+ binaryInfo,
+ &binaryInfoSize
+ )) {
+ QNN_ERROR("Failed to get context binary info for context index = %zu", contextIdx);
+ return false;
+ }
+
+ return true;
+}
+
+bool QnnApi::parseIOTensorsAndAccumulate() {
+
+ for (int gIdx = 0; gIdx < m_graphsCount; gIdx++) {
+ auto graph_info = m_graphsInfo[gIdx];
+ for (bool io : {true, false}) {
+ auto n_tensors = (io) ? graph_info->numInputTensors : graph_info->numOutputTensors;
+ auto tensor_wrappers = (io) ? graph_info->inputTensors : graph_info->outputTensors;
+ for (size_t tensor_idx = 0; tensor_idx < n_tensors; tensor_idx++) {
+
+ TensorWrapper& tensor = tensor_wrappers[tensor_idx];
+ std::string tensor_name = QnnApi::getTensorName(tensor);
+
+ std::vector<size_t> tensor_dims;
+ if (!QnnApi::getTensorShape(tensor_dims, tensor)) {
+ QNN_ERROR("Couldn't get tensor shape : %s", tensor_name.c_str());
+ return false;
+ }
+
+ std::vector<QuantParam> quantParams;
+ if (!QnnApi::getTensorQuantParams(&tensor_wrappers[tensor_idx], quantParams)) {
+ quantParams.emplace_back(0, 0);
+ }
+
+ m_graphtoIOMap[gIdx][tensor_name] =
+ qualla::QnnUtils::Tensor(tensor_wrappers + tensor_idx, tensor_dims, quantParams);
+ }
+ }
+ }
+
+ // Maps tensor_name to context bitVector, each bit representing a context the tensor exists in
+ std::map<std::string, uint32_t> tensor_ctx_map;
+ // Maps a ContextHandle to a one-hot encoded bitVector (e.g. 1, 2, 4, ...)
+ std::map<Qnn_ContextHandle_t, uint32_t> ctx_to_hash;
+
+ // Iterate over all tensors in all GraphVariants to figure out allocations
+ for(int gIdx =0;gIdxsecond.empty()) ? m_contextAllocMap.erase(it) : ++it;
+ }
+
+#if QNN_IO_TENSOR_DEBUG
+ for (auto& [bitvector, nameMap] : m_contextAllocMap) {
+ for (auto& [tname, size] : nameMap)
+ QNN_DEBUG("Context: %d Tensor name: %s Tensor size: %zu", bitvector, tname.c_str(), size);
+ }
+#endif
+ return true;
+}
+
+bool QnnApi::registerTensorsWithBackend(uint32_t& graphIdx) {
+
+ std::map<std::string, std::tuple<uint32_t, size_t, size_t>> graph_allocs;
+ for (auto& [tname, tspec] : m_graphtoIOMap[graphIdx]) {
+
+ // past_key/value inputs are processed along with the corresponding outputs
+ if (tname.starts_with("past_") && tname.ends_with("_in")) continue;
+ auto& [alloc_idx, offset] = m_tensorAllocInfo.at(tname);
+
+ size_t kv_offset = 0;
+ size_t size = tspec.dims.getAlignedSize();
+ if (tname.starts_with("past_")) {
+ auto in_name = tname.substr(0, tname.rfind("_")).append("_in");
+ if (m_graphtoIOMap[graphIdx].count(in_name)) {
+ auto kv_in = m_graphtoIOMap[graphIdx][in_name];
+ kv_offset = kv_in.dims.getAlignedSize();
+ if (m_kvUpdateMethod == POINTER_SHIFT)
+ kv_offset += (tname.starts_with("past_key")) ?
m_ctxSize + : m_ctxSize * m_kvDim; + graph_allocs[in_name] = {alloc_idx, offset, kv_offset}; + } + } + graph_allocs[tname] = {alloc_idx, offset + kv_offset, size}; + } + auto& curContextHandle = m_contextVec[m_graphIdxToContextIdx[graphIdx]]; + if (!m_ioBufferMgr->mapFusedBufferOffset( + m_graphsInfo[graphIdx], curContextHandle, graph_allocs + )) { + QNN_ERROR("Error mapping tensor to allocation buffers"); + return false; + } + +#if QNN_IO_TENSOR_DEBUG +for(auto& [tname, data] : graph_allocs){ + QNN_DEBUG("Tensor Name: %s Alloc Idx: %d Tensor Offset: %zu Tensor Size: %zu",tname.c_str(),get<0>(data),get<1>(data),get<2>(data)); + } +#endif + + return true; + +} +bool QnnApi::createFromBinary( + std::vector cachedBinariesPathVec, + ContextConfigs contextConfig, + int64_t spill_fill_buffer_size, + uint64_t mmap_budget, + bool graphSwitching, + const std::vector& execSelectGraphs, + bool loadSelectGraphs +) { + + // Let backendExtensions populate configs + QnnContext_Config_t** customConfigs{nullptr}; + uint32_t customConfigCount{0}; + if (nullptr != m_backendExtensions && m_backendExtensions->interface()) { + if (!m_backendExtensions->interface()->beforeCreateFromBinary( + &customConfigs, &customConfigCount + )) { + QNN_ERROR("Extensions Failure in beforeCreateFromBinary()"); + return false; + } + } + + QnnContext_Config_t** contextConfigs = nullptr; + uint32_t contextConfigCount = 0; + if (true != getContextConfigs( + &contextConfigs, + contextConfigCount, + contextConfig.priority, + graphSwitching, + execSelectGraphs, + loadSelectGraphs + )) { + QNN_ERROR("Couldn't populate context configs"); + return false; + } + + // Merge BE specific and agnostic configs + QnnContext_Config_t** allContextConfigs{nullptr}; + if (true != mergeAllContextConfigs( + &allContextConfigs, + customConfigs, + contextConfigs, + customConfigCount, + contextConfigCount + )) { + QNN_ERROR("Error merging custom and context configs"); + return false; + } + + if (nullptr == m_qnnSystemInterface.systemContextCreate || + nullptr == m_qnnSystemInterface.systemContextGetBinaryInfo || + nullptr == m_qnnSystemInterface.systemContextFree) { + QNN_ERROR("QNN System function pointers are not populated."); + return false; + } + + graphCountPerContext = getGraphCountPerContext(); + +#ifdef SPILLFILL + Qnn_ContextHandle_t first_contextHandle{nullptr}; + QnnHtpContext_CustomConfig_t customConfigSF; + customConfigSF.option = QNN_HTP_CONTEXT_CONFIG_OPTION_REGISTER_MULTI_CONTEXTS; +#endif + + // Reading Binary Buffer and storing for later use during Deserialization + std::vector> bufferVec(cachedBinariesPathVec.size()); + // Stores sizes of all the Binary Buffers + std::vector allBuffSizes(cachedBinariesPathVec.size()); + // Stores graphs per Contexts + std::vector graphsPerContext(cachedBinariesPathVec.size()); + + for (size_t contextIdx = 0; contextIdx < cachedBinariesPathVec.size(); contextIdx++) { + + auto _start = std::chrono::steady_clock::now(); // context Loading start + uint64_t bufferSize{0}; + std::shared_ptr& buffer{bufferVec[contextIdx]}; + uint32_t graphsCount; + + // read serialized binary into a byte buffer + bufferSize = getFileSize(cachedBinariesPathVec[contextIdx]); + allBuffSizes[contextIdx] = bufferSize; + if (0 == bufferSize) { + QNN_ERROR( + "Received path to an empty file for context index = %zu. 
Nothing to deserialize.", + contextIdx + ); + return false; + } + + // inspect binary info + QnnSystemContext_Handle_t sysCtxHandle{nullptr}; + if (QNN_SUCCESS != m_qnnSystemInterface.systemContextCreate(&sysCtxHandle)) { + QNN_ERROR("Could not create system handle for context index = %zu", contextIdx); + return false; + } + + const QnnSystemContext_BinaryInfo_t* binaryInfo{nullptr}; + if (!mapAndGetContextBinaryInfo( + m_mmapContextBins, + buffer, + cachedBinariesPathVec[contextIdx], + bufferSize, + contextIdx, + graphSwitching, + sysCtxHandle, + &binaryInfo + )) { + QNN_ERROR("Failed to map context Binary for contextIdx: %zu", contextIdx); + return false; + } + + GraphInfo_t** graphsInfo{nullptr}; + if (!copyMetadataToGraphsInfo(binaryInfo, graphsInfo, graphsCount)) { + QNN_ERROR("Failed to copy metadata for graph index = %zu", contextIdx); + freeGraphsInfo(&graphsInfo, graphsCount); + if (contextIdx > 0) freeGraphsInfo(&m_graphsInfo, m_graphsCount); + return false; + } + + if (graphCountPerContext == -1) { + graphCountPerContext = graphsCount; + m_graphsInfo = (GraphInfo_t**)calloc( + graphCountPerContext * cachedBinariesPathVec.size(), sizeof(GraphInfo_t*) + ); + } else if (graphCountPerContext != graphsCount) { + QNN_ERROR( + "Different len(graphs) found in different context files. Found %u vs %u", + graphsCount, + graphCountPerContext + ); + freeGraphsInfo(&graphsInfo, graphsCount); + if (contextIdx > 0) freeGraphsInfo(&m_graphsInfo, m_graphsCount); + return false; + } + + auto _stop = std::chrono::steady_clock::now(); // context Loading stop + QNN_DEBUG( + "Loading contexts[%lu] took: %lld us", + contextIdx, + std::chrono::duration_cast(_stop - _start).count() + ); + graphsPerContext.push_back(graphsCount); + for (int gIdx = 0; gIdx < graphsCount; gIdx++) { + m_graphsInfo[m_graphsCount] = graphsInfo[gIdx]; + m_graphIdxToContextIdx[m_graphsCount] = contextIdx; + m_graphsCount++; + } + m_qnnSystemInterface.systemContextFree(sysCtxHandle); + sysCtxHandle = nullptr; + } + + // Iterate over all the tensors across the graphs Info and build info about the IO space it is requiring. 
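+ //
+ // The grouping idea in miniature (illustrative): each context gets a one-hot
+ // bit, and a tensor is keyed by the OR of the bits of every context it
+ // appears in, e.g.:
+ //
+ //   context 0 -> 0x1, context 1 -> 0x2
+ //   a tensor present in both is grouped under key 0x3 and backed by one
+ //   shared allocation rather than one buffer per context.
+ //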
+ if(false == parseIOTensorsAndAccumulate()){ + QNN_ERROR("Error in parsing the IO tensor info for all context binaries"); + return false; + } + + bool isIOBufferMgrInitialized = false; + + for (size_t contextIdx = 0; contextIdx < cachedBinariesPathVec.size(); contextIdx++) { + + if (nullptr == m_qnnInterface.contextCreateFromBinary) { + QNN_ERROR( + "contextCreateFromBinaryFnHandle is nullptr for context index = %zu", contextIdx + ); + freeGraphsInfo(&m_graphsInfo, m_graphsCount); + return false; + } + + Qnn_ContextHandle_t contextHandle{nullptr}; + + uint32_t customConfigCountSF = 0; + +#ifdef SPILLFILL + if (spill_fill_buffer_size > 0) { + QnnHtpContext_GroupRegistration_t groupInfo{nullptr}; + if (contextIdx == 0) { + groupInfo.firstGroupHandle = 0x0; + } else { + groupInfo.firstGroupHandle = first_contextHandle; + } + groupInfo.maxSpillFillBuffer = spill_fill_buffer_size; + customConfigSF.groupRegistration = groupInfo; + + QnnContext_Config_t** cfgs{nullptr}; + customConfigCountSF = 1; + cfgs = (QnnContext_Config_t**)malloc( + customConfigCountSF * sizeof(QnnContext_Config_t*) + ); + cfgs[0] = (QnnContext_Config_t*)malloc(sizeof(QnnContext_Config_t)); + cfgs[0]->option = QNN_CONTEXT_CONFIG_OPTION_CUSTOM; + cfgs[0]->customConfig = reinterpret_cast(&customConfigSF); + if (true != mergeAllContextConfigs( + &allContextConfigs, + cfgs, + allContextConfigs, + customConfigCountSF, + contextConfigCount + customConfigCount + )) { + QNN_ERROR("Error merging custom and context configs"); + return false; + } + } +#endif + + uint32_t customConfigCountIOMemEstimate = 0; +#if 1 // Adding IO_MEM_ESTIMATION + QnnHtpContext_CustomConfig_t ioMemEstimation; + ioMemEstimation.option = QNN_HTP_CONTEXT_CONFIG_OPTION_IO_MEM_ESTIMATION; + ioMemEstimation.ioMemEstimation = true; + + QnnContext_Config_t** cfgs{nullptr}; + + customConfigCountIOMemEstimate = 1; + + cfgs = (QnnContext_Config_t**)malloc( + customConfigCountIOMemEstimate * sizeof(QnnContext_Config_t*) + ); + cfgs[0] = (QnnContext_Config_t*)malloc(sizeof(QnnContext_Config_t)); + cfgs[0]->option = QNN_CONTEXT_CONFIG_OPTION_CUSTOM; + cfgs[0]->customConfig = + reinterpret_cast(&ioMemEstimation); + if (true != mergeAllContextConfigs( + &allContextConfigs, + cfgs, + allContextConfigs, + customConfigCountIOMemEstimate, + contextConfigCount + customConfigCount + customConfigCountSF + )) { + QNN_ERROR("Error merging custom and context configs"); + return false; + } +#endif + + if (mmap_budget > 0) { + QnnHtpContext_CustomConfig_t customConfigReadBudget; + customConfigReadBudget.option = QNN_HTP_CONTEXT_CONFIG_OPTION_FILE_READ_MEMORY_BUDGET; + customConfigReadBudget.fileReadMemoryBudgetInMb = mmap_budget; + + QnnContext_Config_t** cfgs{nullptr}; + + uint32_t customConfigCountReadBudget = 1; + + cfgs = (QnnContext_Config_t**)malloc( + customConfigCountReadBudget * sizeof(QnnContext_Config_t*) + ); + cfgs[0] = (QnnContext_Config_t*)malloc(sizeof(QnnContext_Config_t)); + cfgs[0]->option = QNN_CONTEXT_CONFIG_OPTION_CUSTOM; + cfgs[0]->customConfig = + reinterpret_cast(&customConfigReadBudget); + if (true != mergeAllContextConfigs( + &allContextConfigs, + cfgs, + allContextConfigs, + customConfigCountReadBudget, + contextConfigCount + customConfigCount + customConfigCountSF + customConfigCountIOMemEstimate + )) { + QNN_ERROR("Error merging custom and context configs"); + return false; + } + } + + + auto start = std::chrono::steady_clock::now(); // context Deserialization starts + + auto errCode = m_qnnInterface.contextCreateFromBinary( + m_backendHandle, + 
m_deviceHandle, + (const QnnContext_Config_t**)allContextConfigs, + (const void*)bufferVec[contextIdx].get(), + allBuffSizes[contextIdx], + &contextHandle, + nullptr // profile handle + + ); + + auto stop = std::chrono::steady_clock::now(); // context Deserialization stops + QNN_DEBUG( + "Initializing context[%lu] with %u graphs took: %lld us", + contextIdx, + graphsPerContext[contextIdx], + std::chrono::duration_cast(stop - start).count() + ); + + if(!isIOBufferMgrInitialized){ + + if (true != m_ioBufferMgr->initialize(contextHandle)) { + QNN_ERROR("qnn-htp: failure to initialize IOTensor"); + return false; + } + + isIOBufferMgrInitialized = true; + + // Calculate total allocation sizes and offset of each tensor within its allocated buffer + if (m_ioBufferMgr->allocateBuffers(m_contextAllocMap, m_tensorAllocInfo) == false){ + QNN_ERROR("Failed to allocate the Memory across the context buffers."); + return false; + } + + } + + if (errCode != QNN_SUCCESS) { + QNN_ERROR( + "Could not create context from binary for context index = %zu : err %d", + contextIdx, + (int)errCode + ); + freeGraphsInfo(&m_graphsInfo, m_graphsCount); + return false; + } + + // Clearing buffer which is deseralized to reduce Memory footprint + bufferVec[contextIdx].reset(); + + if (m_profileBackendHandle) { + extractBackendProfilingInfo(m_profileBackendHandle); + } + + m_contextVec.push_back(contextHandle); + for (int n_graph = 0; n_graph < graphCountPerContext; n_graph++) { + + uint32_t graphIdx = contextIdx*graphCountPerContext + n_graph ; + + GraphInfo_t* cur_graph = m_graphsInfo[graphIdx]; + m_contextMap[cur_graph] = contextHandle; + + if (nullptr == m_qnnInterface.graphRetrieve) { + QNN_ERROR("graphRetrieveFnHandle is nullptr."); + freeGraphsInfo(&m_graphsInfo, m_graphsCount); + return false; + } + + if (!m_graphsInfo || QNN_SUCCESS != m_qnnInterface.graphRetrieve( + contextHandle, + cur_graph->graphName, + &(cur_graph->graph) + )) { + QNN_ERROR("Unable to retrieve graph handle for graph index = %d", graphIdx); + freeGraphsInfo(&m_graphsInfo, m_graphsCount); + return false; + } + + // Register all the Tensors per graph. 
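+ // (registerTensorsWithBackend() hands each tensor's {allocation index,
+ // offset, size} triple to m_ioBufferMgr->mapFusedBufferOffset(), so the
+ // backend reads and writes the fused buffers in place.)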
+ if(false == registerTensorsWithBackend(graphIdx)){ + QNN_ERROR("Unable to MemRegister IO Tensors for graph index = %d", graphIdx); + freeGraphsInfo(&m_graphsInfo, m_graphsCount); + return false; + } + + } + + +#ifdef SPILLFILL + if (spill_fill_buffer_size > 0 && contextIdx == 0) { + first_contextHandle = contextHandle; + } +#endif + + } + + m_isContextCreated = true; + + QNN_DEBUG( + "Initialized %u graphs from %lu contexts", m_graphsCount, cachedBinariesPathVec.size() + ); + + if (true != freeContextConfigs(contextConfigs, contextConfigCount)) { + QNN_ERROR("Couldn't free context configs"); + return false; + } + if (allContextConfigs) { + free(allContextConfigs); + } + + if (nullptr != m_backendExtensions && m_backendExtensions->interface()) { + if (!m_backendExtensions->interface()->afterCreateFromBinary()) { + QNN_ERROR("Extensions Failure in afterCreateFromBinary()"); + return false; + } + } + + return true; +} + +#if QUALLA_QNN_API_VERSION >= 21700 +bool QnnApi::checkCapabilityOfCreateAsync(bool& propRet) { + if (nullptr == m_qnnInterface.propertyHasCapability) { + QNN_ERROR("propertyHasCapability is nullptr......."); + return false; + } + if (QNN_PROPERTY_SUPPORTED == m_qnnInterface.propertyHasCapability( + QNN_PROPERTY_CONTEXT_SUPPORT_CREATE_FROM_BINARY_LIST_ASYNC + )) { + propRet = true; + } else { + propRet = false; + } + return true; +} + +bool freeContextParams(QnnContext_Params_t** context_params_list, uint32_t numParams) { + if (context_params_list == nullptr || *context_params_list == nullptr) { + return false; + } + for (uint32_t i = 0; i < numParams; i++) { + if (nullptr != context_params_list[i]) { + delete context_params_list[i]; + } + } + return true; +} + +void QnnApi::contextNotifyFn( + Qnn_ContextHandle_t context, + Qnn_GraphHandle_t graph, + const char* graph_name, + QnnContext_createFromBinaryAsyncNotifyType_t completeType, + void* notifyParam, + Qnn_ErrorHandle_t status +) { + std::pair* pair = + reinterpret_cast*>(notifyParam); + QnnApi* QnnApi = pair->first; + uint32_t contextId = pair->second; + + if (completeType == + QnnContext_createFromBinaryAsyncNotifyType_t::QNN_CONTEXT_NOTIFY_TYPE_CONTEXT_INIT) { + QnnApi->updateContext(context, contextId); + } else if (completeType == + QnnContext_createFromBinaryAsyncNotifyType_t::QNN_CONTEXT_NOTIFY_TYPE_GRAPH_INIT) { + QnnApi->updateQnnApiGraphsandContextsInfo(graph_name, graph, contextId); + } +} + +bool QnnApi::createFromBinaryListAsync( + std::vector cachedBinariesPathVec, + ContextConfigs contextConfig, + int64_t spill_fill_buffer_size, + uint64_t mmap_budget, + bool graphSwitching, + const std::vector& execSelectGraphs, + bool loadSelectGraphs +) { + auto _start = std::chrono::steady_clock::now(); + + // Let backendExtensions populate configs + QnnContext_Config_t** customConfigs{nullptr}; + uint32_t customConfigCount{0}; + std::map> contextKeyToCustomConfigsMap; + if (nullptr != m_backendExtensions && m_backendExtensions->interface()) { + if (!m_backendExtensions->interface()->beforeCreateContextsFromBinaryList( + &contextKeyToCustomConfigsMap,&customConfigs, &customConfigCount + )) { + QNN_ERROR("Extensions Failure in beforeCreateContextsFromBinaryList()"); + return false; + } + } + + + QnnContext_Config_t** contextConfigs = nullptr; + uint32_t contextConfigCount = 0; + if (true != getContextConfigs( + &contextConfigs, + contextConfigCount, + contextConfig.priority, + graphSwitching, + execSelectGraphs, + loadSelectGraphs + )) { + QNN_ERROR("Couldn't populate context configs"); + return false; + } + + // Merge 
BE specific and agnostic configs
+ QnnContext_Config_t** allContextConfigs{nullptr};
+ if (true != mergeAllContextConfigs(
+ &allContextConfigs,
+ customConfigs,
+ contextConfigs,
+ customConfigCount,
+ contextConfigCount
+ )) {
+ QNN_ERROR("Error merging custom and context configs");
+ return false;
+ }
+
+ if (nullptr == m_qnnSystemInterface.systemContextCreate ||
+ nullptr == m_qnnSystemInterface.systemContextGetBinaryInfo ||
+ nullptr == m_qnnSystemInterface.systemContextFree) {
+ QNN_ERROR("QNN System function pointers are not populated.");
+ return false;
+ }
+
+ graphCountPerContext = getGraphCountPerContext();
+
+ // Null-terminated list of per-context parameters for the async create call
+ std::vector<QnnContext_Params_t*> context_params_list(cachedBinariesPathVec.size() + 1, nullptr);
+ std::vector<std::shared_ptr<uint8_t>> bufferVec(cachedBinariesPathVec.size());
+ // for every context's graph info
+ GraphInfo_t*** graphsInfo =
+ (GraphInfo_t***)calloc(cachedBinariesPathVec.size(), sizeof(GraphInfo_t**));
+ uint32_t graphsTotalNum = 0;
+
+ for (size_t contextIdx = 0; contextIdx < cachedBinariesPathVec.size(); contextIdx++) {
+ auto _startPerContext = std::chrono::steady_clock::now();
+ uint64_t bufferSize{0};
+ std::shared_ptr<uint8_t>& buffer{bufferVec[contextIdx]};
+ uint32_t graphsCount;
+
+ // read serialized binary into a byte buffer
+ bufferSize = getFileSize(cachedBinariesPathVec[contextIdx]);
+ if (0 == bufferSize) {
+ QNN_ERROR(
+ "Received path to an empty file for context index = %zu. Nothing to deserialize.",
+ contextIdx
+ );
+ return false;
+ }
+
+ // inspect binary info
+ QnnSystemContext_Handle_t sysCtxHandle{nullptr};
+ if (QNN_SUCCESS != m_qnnSystemInterface.systemContextCreate(&sysCtxHandle)) {
+ QNN_ERROR("Could not create system handle for context index = %zu", contextIdx);
+ return false;
+ }
+ const QnnSystemContext_BinaryInfo_t* binaryInfo{nullptr};
+ if (!mapAndGetContextBinaryInfo(
+ m_mmapContextBins,
+ buffer,
+ cachedBinariesPathVec[contextIdx],
+ bufferSize,
+ contextIdx,
+ graphSwitching,
+ sysCtxHandle,
+ &binaryInfo
+ )) {
+ QNN_ERROR("Failed to map context Binary.");
+ return false;
+ }
+
+ if (!copyMetadataToGraphsInfo(binaryInfo, graphsInfo[contextIdx], graphsCount)) {
+ QNN_ERROR("Failed to copy metadata for graph index = %zu", contextIdx);
+ freeGraphsInfo(&graphsInfo[contextIdx], graphsCount);
+ freeGraphsInfo(&m_graphsInfo, graphsCount);
+ return false;
+ }
+
+ if (graphCountPerContext == -1) {
+ graphCountPerContext = graphsCount;
+ graphsTotalNum = graphCountPerContext * cachedBinariesPathVec.size();
+ m_graphsInfo = (GraphInfo_t**)calloc(graphsTotalNum, sizeof(GraphInfo_t*));
+
+ } else if (graphCountPerContext != graphsCount) {
+ QNN_ERROR(
+ "Different len(graphs) found in different context files.
Found %u vs %u", + graphsCount, + graphCountPerContext + ); + freeGraphsInfo(&graphsInfo[contextIdx], graphsCount); + freeGraphsInfo(&m_graphsInfo, graphsTotalNum); + return false; + } + for (int gIdx = 0; gIdx < graphsCount; gIdx++) { + int graphIdxOfAll = contextIdx * graphsCount + gIdx; + m_graphsInfo[graphIdxOfAll] = graphsInfo[contextIdx][gIdx]; + m_graphNameToInfo[m_graphsInfo[graphIdxOfAll]->graphName] = m_graphsInfo[graphIdxOfAll]; + } + m_qnnSystemInterface.systemContextFree(sysCtxHandle); + sysCtxHandle = nullptr; + + uint32_t customConfigCountSF = 0; + + if (mmap_budget > 0) { + QnnHtpContext_CustomConfig_t customConfigReadBudget; + customConfigReadBudget.option = QNN_HTP_CONTEXT_CONFIG_OPTION_FILE_READ_MEMORY_BUDGET; + customConfigReadBudget.fileReadMemoryBudgetInMb = mmap_budget; + + QnnContext_Config_t** cfgs{nullptr}; + + uint32_t customConfigCountReadBudget = 1; + + cfgs = (QnnContext_Config_t**)malloc( + customConfigCountReadBudget * sizeof(QnnContext_Config_t*) + ); + cfgs[0] = (QnnContext_Config_t*)malloc(sizeof(QnnContext_Config_t)); + cfgs[0]->option = QNN_CONTEXT_CONFIG_OPTION_CUSTOM; + cfgs[0]->customConfig = + reinterpret_cast(&customConfigReadBudget); + if (true != mergeAllContextConfigs( + &allContextConfigs, + cfgs, + allContextConfigs, + customConfigCountReadBudget, + contextConfigCount + customConfigCount + customConfigCountSF + )) { + QNN_ERROR("Error merging custom and context configs"); + return false; + } + } + + if (m_profileBackendHandle) { + extractBackendProfilingInfo(m_profileBackendHandle); + } + + // passing class QnnApi pointer into callback funtion(notifyFn) + std::pair* notifyParam = + new std::pair(this, (size_t)contextIdx); + QnnContext_Params_t* contextParam = new QnnContext_Params_t{ + .version = QNN_CONTEXT_PARAMS_VERSION_1, + .v1 = + QnnContext_ParamsV1_t{ + (const QnnContext_Config_t**)allContextConfigs, + (const void*)buffer.get(), + bufferSize, + nullptr, + QnnApi::contextNotifyFn, + (void*)notifyParam + } + }; + + context_params_list[contextIdx] = contextParam; + + auto _stop = std::chrono::steady_clock::now(); + QNN_DEBUG( + "Loading contexts[%lu] took: %lld us", + contextIdx, + std::chrono::duration_cast(_stop - _startPerContext).count() + ); + } + + if (nullptr == m_qnnInterface.contextCreateFromBinaryListAsync) { + QNN_ERROR("contextCreateFromBinaryListAsyncFnHandle is nullptr"); + freeGraphsInfo(&m_graphsInfo, graphsTotalNum); + freeContextParams(context_params_list.data(), cachedBinariesPathVec.size()); + return false; + } + + auto start = std::chrono::steady_clock::now(); + + + auto errCode = m_qnnInterface.contextCreateFromBinaryListAsync( + m_backendHandle, + m_deviceHandle, + const_cast(context_params_list.data()), + (const QnnContext_Config_t**)allContextConfigs, + nullptr + ); + + + auto stop = std::chrono::steady_clock::now(); + QNN_DEBUG( + "Initializing %lu context with %u graphs took: %lld us", + cachedBinariesPathVec.size(), + graphsTotalNum, + std::chrono::duration_cast(stop - start).count() + ); + + // Explicitly free the context binary buffers. This ensures that the lifecycle + // of the buffers outlasts the API call where their raw pointers are referenced. 
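+ // (Note: the loop below must take each element by reference; resetting a
+ // by-value copy of a std::shared_ptr would leave the vector's own reference
+ // alive and the underlying buffer still allocated.)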
+ for (auto& contextBinaryBuffer : bufferVec) {
+ QNN_DEBUG("Freeing context binary buffer @%p", contextBinaryBuffer.get());
+ contextBinaryBuffer.reset();
+ }
+
+ if (errCode != QNN_SUCCESS) {
+ QNN_ERROR(
+ "Could not create context from binary List Async for context, err %d", (int)errCode
+ );
+ freeGraphsInfo(&m_graphsInfo, graphsTotalNum);
+ freeContextParams(context_params_list.data(), cachedBinariesPathVec.size());
+ return false;
+ }
+
+ // set graphInfo in m_graphsInfo
+ for (size_t graphIdx = 0; graphIdx < m_graphsCount; graphIdx++) {
+ int contextIdxOfGraphsInfo = graphIdx / graphCountPerContext;
+ uint32_t contextIdxOfCurrGraph = m_graphNameToContextIdx[m_graphsInfo[graphIdx]->graphName];
+ m_graphsInfo[graphIdx] =
+ graphsInfo[contextIdxOfGraphsInfo][graphIdx % graphCountPerContext];
+ m_contextMap[m_graphsInfo[graphIdx]] = m_contextIdtoHandle[contextIdxOfCurrGraph];
+ }
+
+ m_isContextCreated = true;
+
+ if (true != freeContextConfigs(contextConfigs, contextConfigCount)) {
+ QNN_ERROR("Couldn't free context configs");
+ return false;
+ }
+
+ if (true != freeContextParams(context_params_list.data(), cachedBinariesPathVec.size())) {
+ QNN_ERROR("Couldn't free context params list");
+ return false;
+ }
+
+ if (allContextConfigs) {
+ free(allContextConfigs);
+ }
+
+ if (nullptr != m_backendExtensions && m_backendExtensions->interface()) {
+ if (!m_backendExtensions->interface()->afterCreateContextsFromBinaryList()) {
+ QNN_ERROR("Extensions Failure in afterCreateContextsFromBinaryList()");
+ return false;
+ }
+ }
+ return true;
+}
+#endif
+
+static std::vector<std::string> __split(std::string_view str, char delim) {
+ std::vector<std::string> split;
+
+ size_t i = 0, p = 0;
+
+ for (; i <= str.size(); ++i) {
+ if (i == str.size() || str[i] == delim) {
+ split.push_back(std::string(str.data() + p, i - p));
+ p = i + 1;
+ }
+ }
+
+ return split;
+}
+
+bool QnnApi::registerOpPackage(std::string opPackagePath) {
+ const size_t pathIdx = 0;
+ const size_t interfaceProviderIdx = 1;
+ const size_t targetIdx = 2;
+
+ auto opPackage = __split(opPackagePath, ':');
+
+ if (opPackage.size() != 2 && opPackage.size() != 3) {
+ return false;
+ }
+
+ if (nullptr == m_qnnInterface.backendRegisterOpPackage) {
+ return false;
+ }
+
+ const char* target = nullptr;
+ if (opPackage.size() == 3) {
+ target = (char*)opPackage[targetIdx].c_str();
+ }
+
+ auto returnStatus = m_qnnInterface.backendRegisterOpPackage(
+ m_backendHandle,
+ (char*)opPackage[pathIdx].c_str(),
+ (char*)opPackage[interfaceProviderIdx].c_str(),
+ target
+ );
+ if (QNN_SUCCESS != returnStatus) {
+ QNN_ERROR(
+ "Could not register OpPackage with backend due to error = %llu",
+ (unsigned long long)returnStatus
+ );
+ return false;
+ }
+
+ return true;
+}
+
+// Performance Setting for HTP
+bool QnnApi::initializePerformance() {
+
+ QnnDevice_Infrastructure_t deviceInfra = nullptr;
+ if (QNN_SUCCESS != m_qnnInterface.deviceGetInfrastructure(&deviceInfra)) {
+ QNN_ERROR("Failure in deviceGetInfrastructure()");
+ return false;
+ }
+
+ QnnHtpDevice_Infrastructure_t* htpInfra =
+ static_cast<QnnHtpDevice_Infrastructure_t*>(deviceInfra);
+ m_perfInfra = &(htpInfra->perfInfra);
+ uint32_t deviceId = 0;
+ uint32_t coreId = 0;
+ if (QNN_SUCCESS != m_perfInfra->createPowerConfigId(deviceId, coreId, &m_powerConfigId)) {
+ QNN_ERROR("Failure in createPowerConfigId()");
+ return false;
+ }
+
+ return true;
+}
+
+bool QnnApi::destroyPerformance() {
+ if (nullptr != m_perfInfra &&
+ QNN_SUCCESS != m_perfInfra->destroyPowerConfigId(m_powerConfigId)) {
+ QNN_ERROR("Failure in destroyPowerConfigId()");
+ return false;
+ } + + return true; +} + +bool QnnApi::boostPerformance() { + // Initialize the power config and select the voltage corner values for the performance setting. + QnnHtpPerfInfrastructure_PowerConfig_t powerConfig; + memset(&powerConfig, 0, sizeof(powerConfig)); + + powerConfig.option = QNN_HTP_PERF_INFRASTRUCTURE_POWER_CONFIGOPTION_DCVS_V3; + powerConfig.dcvsV3Config.dcvsEnable = 1; + powerConfig.dcvsV3Config.setDcvsEnable = 1; + powerConfig.dcvsV3Config.contextId = m_powerConfigId; + + // refer QnnHtpPerfInfrastructure.h + powerConfig.dcvsV3Config.powerMode = QNN_HTP_PERF_INFRASTRUCTURE_POWERMODE_PERFORMANCE_MODE; + + // Set Sleep-Disable latency parameter + powerConfig.dcvsV3Config.setSleepDisable = 0; + powerConfig.dcvsV3Config.sleepDisable = 0; + + // Set Sleep latency parameter + powerConfig.dcvsV3Config.setSleepLatency = 0; + powerConfig.dcvsV3Config.sleepLatency = 1000; // range 40-2000 micro sec + + // Set Bus Clock Parameters (refer QnnHtpPerfInfrastructure.h) + powerConfig.dcvsV3Config.setBusParams = 1; + powerConfig.dcvsV3Config.busVoltageCornerMin = DCVS_VOLTAGE_VCORNER_TURBO_PLUS; + powerConfig.dcvsV3Config.busVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_TURBO_PLUS; + powerConfig.dcvsV3Config.busVoltageCornerMax = DCVS_VOLTAGE_VCORNER_TURBO_PLUS; + + // set Core Clock Parameters (refer QnnHtpPerfInfrastructure.h) + powerConfig.dcvsV3Config.setCoreParams = 1; + powerConfig.dcvsV3Config.coreVoltageCornerMin = DCVS_VOLTAGE_VCORNER_TURBO_PLUS; + powerConfig.dcvsV3Config.coreVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_TURBO_PLUS; + powerConfig.dcvsV3Config.coreVoltageCornerMax = DCVS_VOLTAGE_VCORNER_TURBO_PLUS; + + // Set power config with different performance parameters + const QnnHtpPerfInfrastructure_PowerConfig_t* powerConfigs[] = {&powerConfig, NULL}; + if (QNN_SUCCESS != m_perfInfra->setPowerConfig(m_powerConfigId, powerConfigs)) { + QNN_ERROR("Failure in setPowerConfig() from boostPerformance"); + return false; + } + + return true; +} + +bool QnnApi::resetPerformance() { + // Initialize the power config and select the voltage corner values for the performance setting. 
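+ // Unlike boostPerformance(), which pins the bus and core corners to
+ // TURBO_PLUS, this restores NOM targets (with TURBO as the ceiling) under
+ // the power-saver DCVS mode.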
+ QnnHtpPerfInfrastructure_PowerConfig_t powerConfig; + memset(&powerConfig, 0, sizeof(powerConfig)); + + powerConfig.option = QNN_HTP_PERF_INFRASTRUCTURE_POWER_CONFIGOPTION_DCVS_V3; + powerConfig.dcvsV3Config.dcvsEnable = 1; + powerConfig.dcvsV3Config.setDcvsEnable = 1; + powerConfig.dcvsV3Config.contextId = m_powerConfigId; + + // refer QnnHtpPerfInfrastructure.h + powerConfig.dcvsV3Config.powerMode = QNN_HTP_PERF_INFRASTRUCTURE_POWERMODE_POWER_SAVER_MODE; + + // Set Sleep-Disable latency parameter + powerConfig.dcvsV3Config.setSleepDisable = 0; + powerConfig.dcvsV3Config.sleepDisable = 0; + + // Set Sleep latency parameter + powerConfig.dcvsV3Config.setSleepLatency = 0; + powerConfig.dcvsV3Config.sleepLatency = 1000; // range 40-2000 micro sec + + // Set Bus Clock Parameters (refer QnnHtpPerfInfrastructure.h) + powerConfig.dcvsV3Config.setBusParams = 1; + powerConfig.dcvsV3Config.busVoltageCornerMin = DCVS_VOLTAGE_VCORNER_NOM; + powerConfig.dcvsV3Config.busVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_NOM; + powerConfig.dcvsV3Config.busVoltageCornerMax = DCVS_VOLTAGE_VCORNER_TURBO; + + // set Core Clock Parameters (refer QnnHtpPerfInfrastructure.h) + powerConfig.dcvsV3Config.setCoreParams = 1; + powerConfig.dcvsV3Config.coreVoltageCornerMin = DCVS_VOLTAGE_VCORNER_NOM; + powerConfig.dcvsV3Config.coreVoltageCornerTarget = DCVS_VOLTAGE_VCORNER_NOM; + powerConfig.dcvsV3Config.coreVoltageCornerMax = DCVS_VOLTAGE_VCORNER_TURBO; + + // Set power config with different performance parameters + const QnnHtpPerfInfrastructure_PowerConfig_t* powerConfigs[] = {&powerConfig, NULL}; + if (QNN_SUCCESS != m_perfInfra->setPowerConfig(m_powerConfigId, powerConfigs)) { + QNN_ERROR("Failure in setPowerConfig() from resetPerformance"); + return false; + } + + return true; +} + +bool QnnApi::initialize( + std::string backendPath, + std::vector modelPathOrCachedBinaryPathVec, + BackendExtensionsConfigs backendExtensionsConfig, + PerfProfile parsedPerfProfile, + ContextConfigs contextConfig, + std::vector graphConfigs, + bool loadFromCachedBinary, + std::string systemLibraryPath, + bool debugModeRequested, + int64_t spill_fill_buffer_size, + bool mmapContextBins, + bool asyncInit, + uint64_t mmap_budget, + bool debug_qnn, + bool graphSwitching, + const std::vector& execSelectGraphs, + bool loadSelectGraphs +) { + if (modelPathOrCachedBinaryPathVec.size() > 1 && false == loadFromCachedBinary) { + QNN_ERROR("Currently only 1 model file is supported for this framework! 
\ + Although multiple context files are supported!"); + return false; + } + + m_mmapContextBins = mmapContextBins; + + // Setting up Debug mode + m_DebugModeRequested = debugModeRequested; + if (m_DebugModeRequested) { + QNN_WARN("Warning: Debug mode set to true."); + } + + // Initialize the QNN run time + if (false == getQnnInterface(backendPath)) { + QNN_ERROR("Qnn getQnnInterface FAILED!"); + return false; + } + + if (loadFromCachedBinary) { + if (false == getQnnSystemInterface(systemLibraryPath)) { + QNN_ERROR("Qnn getQnnSystemInterface FAILED!"); + return false; + } + } else { + if (false == loadModel(modelPathOrCachedBinaryPathVec[0])) { + QNN_ERROR("Loading model FAILED!"); + return false; + } + } + + QnnLog_Level_t logLevel = QNN_LOG_LEVEL_WARN; + if (false == initializeLogging(logLevel, debug_qnn)) { + QNN_ERROR("Unable to Initialize logging in backend"); + return false; + } + + // initialize backend extensions +#ifdef QUALLA_INTERNAL_QNN_SDK + // Initialize backendExtensions only when both backend ext config and backend ext lib are provided + if (!backendExtensionsConfig.configFilePath.empty() && + false == initializeBackendExtensions( + backendExtensionsConfig, parsedPerfProfile, debug_qnn + )) { + QNN_WARN("Failure in initializing backend extensions."); + } +#else + if (false == + initializeBackendExtensions(backendExtensionsConfig, parsedPerfProfile, debug_qnn)) { + QNN_ERROR("Failure in initializing backend extensions."); + return false; + } +#endif + if (false == initializeBackend()) { + QNN_ERROR("Qnn initializeBackend FAILED!"); + return false; + } + if (false == createDevice()) { + QNN_ERROR("Device Creation failure"); + setDeviceStatus(false); + return false; + } else { + setDeviceStatus(true); + } + if (!loadFromCachedBinary) { + if (false == createContext(contextConfig)) { + QNN_ERROR("Qnn createContext FAILED!"); + return false; + } + if (false == composeGraphs(graphConfigs)) { + QNN_ERROR("composeGraphs FAILED!"); + return false; + } + if (false == finalizeGraphs()) { + QNN_ERROR("finalizeGraphs FAILED!"); + return false; + } + } else { + bool cfb_ret = false; + bool asyncCapability = false; +#if QUALLA_QNN_API_VERSION >= 21700 + if(asyncInit == true){ + if (!checkCapabilityOfCreateAsync(asyncCapability)) { + QNN_ERROR("Capabilty checked failed"); + return false; + } + asyncInit = asyncCapability && asyncInit; + } + if (asyncInit == true) { + QNN_INFO("Using create From Binary List Async"); + cfb_ret = createFromBinaryListAsync( + modelPathOrCachedBinaryPathVec, + contextConfig, + spill_fill_buffer_size, + mmap_budget, + graphSwitching, + execSelectGraphs, + loadSelectGraphs + ); + if (cfb_ret == false) { + QNN_ERROR("Create From Binary List Async FAILED!"); + return false; + } + + } else { +#endif + QNN_INFO("Using create From Binary"); + cfb_ret = createFromBinary( + modelPathOrCachedBinaryPathVec, + contextConfig, + spill_fill_buffer_size, + mmap_budget, + graphSwitching, + execSelectGraphs, + loadSelectGraphs + ); + if (false == cfb_ret) { + QNN_ERROR("Create From Binary FAILED!"); + return false; + } + } +#if QUALLA_QNN_API_VERSION >= 21700 + } +#endif + + // if (false == initializePerformance()) { + // QNN_ERROR("initialize Performance FAILED!"); + // return false; + // } + + for (size_t graphIdx = 0; graphIdx < m_graphsCount; graphIdx++) { + m_graphNameToIndex[m_graphsInfo[graphIdx]->graphName] = graphIdx; + } + +#if NSP_LOG_LEVEL > 1 + for (const auto& graphNameIndex : m_graphNameToIndex) { + QNN_DEBUG( + "Found Graph name %s corresponding to index %d", + 
graphNameIndex.first.c_str(), + graphNameIndex.second + ); + } + + fprintf(stderr, "context_handles = ["); + for (auto ctx_handle : m_contextVec) + fprintf(stderr, "%p, ", ctx_handle); + fprintf(stderr, "]\n"); +#endif + return true; +} + +bool QnnApi::initialize( + std::string backendPath, + std::string modelPath, + std::string opPackage, + ContextConfigs contextConfig, + std::vector graphConfigs, + uint32_t* inputDim, + uint32_t inputRank, + uint32_t* outputDim, + uint32_t outputRank, + uint32_t* kvDim, + uint32_t kvRank, + Qnn_Param_t* params, + uint32_t numParams, + bool debugModeRequested +) { + // Setting up Debug mode + m_DebugModeRequested = debugModeRequested; + if (m_DebugModeRequested) { + QNN_WARN("Warning: Debug mode set to true."); + } + + // Initialize the QNN run time + if (false == getQnnInterface(backendPath)) { + QNN_ERROR("Qnn getQnnInterface FAILED!"); + return false; + } + + QnnLog_Level_t logLevel = QNN_LOG_LEVEL_WARN; + if (false == initializeLogging(logLevel, false)) { + QNN_ERROR("Unable to Initialize logging in backend"); + } + + if (false == initializeBackend()) { + QNN_ERROR("Qnn initializeBackend FAILED!"); + return false; + } + + //CPU does not support createDevice. + setDeviceStatus(false); + if (false == registerOpPackage(opPackage)) { + QNN_ERROR("Qnn initializeBackend FAILED!"); + return false; + } + +// Change to 1 to enable QNN Basic profiling +#if 0 + if (false == initProfiling()) { + QNN_ERROR("Profiling init failure"); + return false; + } +#endif + if (false == loadModel(modelPath)) { + QNN_ERROR("Loading model FAILED!"); + return false; + } + if (false == createContext(contextConfig)) { + QNN_ERROR("Qnn createContext FAILED!"); + return false; + } + if (false == composeGraphs( + graphConfigs, inputDim, inputRank, outputDim, outputRank, kvDim, kvRank, params, numParams + )) { + QNN_ERROR("composeGraphs FAILED!"); + return false; + } + if (false == finalizeGraphs()) { + QNN_ERROR("finalizeGraphs FAILED!"); + return false; + } + + for (size_t graphIdx = 0; graphIdx < m_graphsCount; graphIdx++) { + m_graphNameToIndex[m_graphsInfo[graphIdx]->graphName] = graphIdx; + } +#if NSP_LOG_LEVEL > 1 + for (const auto& graphNameIndex : m_graphNameToIndex) { + QNN_DEBUG( + "Found Graph name %s corresponding to index %d", + graphNameIndex.first.c_str(), + graphNameIndex.second + ); + } +#endif + return true; +} + +bool QnnApi::graphExecute( + Qnn_Tensor_t* input, + Qnn_Tensor_t* output, + std::string graphName, + std::map>& timeLogs +) { + QnnGraph_Config_t** customGraphConfigs{nullptr}; + uint32_t configCount{0}; + if (nullptr != m_backendExtensions && m_backendExtensions->interface()) { + if (!m_backendExtensions->interface()->beforeExecute( + graphName.c_str(), &customGraphConfigs, &configCount + )) { + QNN_ERROR("Extensions Failure in beforeExecute()"); + return false; + } + if (customGraphConfigs) { + if (true != setGraphConfigsBeforeExecute( + m_graphsInfo[m_graphNameToIndex[graphName]]->graph, + customGraphConfigs, + configCount + )) { + QNN_ERROR("Failure in setGraphConfigsBeforeExecute()"); + return false; + } + } + } + + // if (true != boostPerformance()) { + // QNN_ERROR("Couldn't boost the performance"); + // return false; + // } + + Qnn_ErrorHandle_t ret = QNN_GRAPH_NO_ERROR; + try { +#if NSP_LOG_LEVEL > 1 + auto start = std::chrono::steady_clock::now(); +#endif + ret = m_qnnInterface.graphExecute( + m_graphsInfo[m_graphNameToIndex[graphName]]->graph, + input, + m_graphsInfo[m_graphNameToIndex[graphName]]->numInputTensors, + output, + 
m_graphsInfo[m_graphNameToIndex[graphName]]->numOutputTensors, + m_profileBackendHandle, + nullptr + ); +#if NSP_LOG_LEVEL > 1 + auto stop = std::chrono::steady_clock::now(); + QNN_DEBUG( + "graphExecute[%s] took: %lld us", + graphName.c_str(), + std::chrono::duration_cast(stop - start).count() + ); +#endif +#if NSP_LOG_LEVEL > 6 + timeLogs[graphName].first += static_cast( + std::chrono::duration_cast(stop - start).count() + ); + timeLogs[graphName].second++; +#endif + + } catch (const std::exception& ex) { + QNN_ERROR("ERROR executing inference ret"); + } catch (...) { + QNN_ERROR("ERROR executing inference ret"); + } + + if (m_profileBackendHandle) { + extractBackendProfilingInfo(m_profileBackendHandle, timeLogs, graphName); + } + + // if (true != resetPerformance()) { + // QNN_ERROR("Couldn't reset the performance"); + // return false; + // } + + if (ret != QNN_GRAPH_NO_ERROR) return false; + + if (nullptr != m_backendExtensions && m_backendExtensions->interface()) { + if (!m_backendExtensions->interface()->afterExecute()) { + QNN_ERROR("Extensions Failure in afterExecute()"); + return false; + } + } + + return true; +} + +bool QnnApi::getTensorQuantParams( + const Qnn_Tensor_t* tensor, + std::vector& quantParamsVec +) { + bool status = false; + auto dataType = QNN_TENSOR_GET_DATA_TYPE(tensor); + auto quantParams = QNN_TENSOR_GET_QUANT_PARAMS(tensor); + if (dataType == QNN_DATATYPE_UFIXED_POINT_8 || dataType == QNN_DATATYPE_SFIXED_POINT_8 || + dataType == QNN_DATATYPE_UFIXED_POINT_16) { + auto quantEncodingType = quantParams.quantizationEncoding; + if (quantEncodingType == + Qnn_QuantizationEncoding_t::QNN_QUANTIZATION_ENCODING_SCALE_OFFSET) { + status = true; + double scale = quantParams.scaleOffsetEncoding.scale; + int32_t offset = quantParams.scaleOffsetEncoding.offset; + quantParamsVec.emplace_back(scale, offset); + } else if (quantEncodingType == + Qnn_QuantizationEncoding_t::QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET) { + status = true; + auto encodingStruct = quantParams.axisScaleOffsetEncoding; + for (uint32_t n = 0; n < encodingStruct.numScaleOffsets; n++) { + auto scaleOffset = encodingStruct.scaleOffset[n]; + quantParamsVec.emplace_back(scaleOffset.scale, scaleOffset.offset); + } + } else { + QNN_ERROR("quant encoding type not supported"); + } + } + return status; +} + +bool QnnApi::getTensorShape(std::vector& tensorDims, const TensorWrapper& tensorWrapper) { + const Qnn_Tensor_t& tensor = GET_TENSOR_WRAPPER_TENSOR(tensorWrapper); + if (false == + fillDims(tensorDims, QNN_TENSOR_GET_DIMENSIONS(tensor), QNN_TENSOR_GET_RANK(tensor))) + return false; + + tensorDims.push_back(getDataTypeSize(QNN_TENSOR_GET_DATA_TYPE(tensor))); + return true; +} + +bool QnnApi::getTensorNameAndShape( + std::string& tensorName, + std::vector& tensorDims, + TensorWrapper& tensorWrapper +) { + Qnn_Tensor_t& tensor = GET_TENSOR_WRAPPER_TENSOR(tensorWrapper); + tensorName = std::string(GET_TENSOR_WRAPPER_NAME(tensorWrapper)); + if (false == + fillDims(tensorDims, QNN_TENSOR_GET_DIMENSIONS(tensor), QNN_TENSOR_GET_RANK(tensor))) + return false; + + tensorDims.push_back(g_qnnDataTypeToSize[QNN_TENSOR_GET_DATA_TYPE(tensor)]); + return true; +} + +bool QnnApi::extractBackendProfilingInfo( + Qnn_ProfileHandle_t profileHandle, + std::map>& timeLogs, + std::string graphName +) { + if (nullptr == m_profileBackendHandle) { + QNN_ERROR("QNN HTP Profile handle is nullptr; may not be initialized."); + return false; + } + const QnnProfile_EventId_t* profileEvents{nullptr}; + uint32_t numEvents{0}; + if 
(QNN_PROFILE_NO_ERROR != + m_qnnInterface.profileGetEvents(profileHandle, &profileEvents, &numEvents)) { + QNN_ERROR("Failure in QNN HTP profile get events."); + return false; + } + QNN_DEBUG("ProfileEvents: [%p], numEvents: [%d]", profileEvents, numEvents); + for (size_t event = 0; event < numEvents; event++) { + extractProfilingEvent(*(profileEvents + event), timeLogs, graphName); + extractProfilingSubEvents(*(profileEvents + event), timeLogs, graphName); + } + return true; +} + +bool QnnApi::extractProfilingSubEvents( + QnnProfile_EventId_t profileEventId, + std::map>& timeLogs, + std::string graphName +) { + const QnnProfile_EventId_t* profileSubEvents{nullptr}; + uint32_t numSubEvents{0}; + if (QNN_PROFILE_NO_ERROR != + m_qnnInterface.profileGetSubEvents(profileEventId, &profileSubEvents, &numSubEvents)) { + QNN_ERROR("Failure in QNN HTP profile get sub events."); + return false; + } + QNN_DEBUG("ProfileSubEvents: [%p], numSubEvents: [%d]", profileSubEvents, numSubEvents); + for (size_t subEvent = 0; subEvent < numSubEvents; subEvent++) { + extractProfilingEvent(*(profileSubEvents + subEvent), timeLogs, graphName); + extractProfilingSubEvents(*(profileSubEvents + subEvent), timeLogs, graphName); + } + return true; +} + +bool QnnApi::extractProfilingEvent( + QnnProfile_EventId_t profileEventId, + std::map>& timeLogs, + std::string graphName +) { + QnnProfile_EventData_t eventData; + if (QNN_PROFILE_NO_ERROR != m_qnnInterface.profileGetEventData(profileEventId, &eventData)) { + QNN_ERROR("Failure in profile get event type."); + return false; + } + + QNN_DEBUG( + "Event Info - Event Type: [%d], Event Value: [%lu], Event Identifier: [%s], Event Unit: [%d]", + eventData.type, + eventData.value, + eventData.identifier, + eventData.unit + ); +#if NSP_LOG_LEVEL > 6 + timeLogs[graphName + "_" + eventData.identifier].first += static_cast(eventData.value); + timeLogs[graphName + "_" + eventData.identifier].second++; +#endif + + return true; +} + +bool QnnApi::extractBackendProfilingInfo(Qnn_ProfileHandle_t profileHandle) { + if (nullptr == m_profileBackendHandle) { + QNN_ERROR("QNN HTP Profile handle is nullptr; may not be initialized."); + return false; + } + const QnnProfile_EventId_t* profileEvents{nullptr}; + uint32_t numEvents{0}; + if (QNN_PROFILE_NO_ERROR != + m_qnnInterface.profileGetEvents(profileHandle, &profileEvents, &numEvents)) { + QNN_ERROR("Failure in QNN HTP profile get events."); + return false; + } + QNN_DEBUG("ProfileEvents: [%p], numEvents: [%d]", profileEvents, numEvents); + for (size_t event = 0; event < numEvents; event++) { + extractProfilingEvent(*(profileEvents + event)); + extractProfilingSubEvents(*(profileEvents + event)); + } + return true; +} + +bool QnnApi::extractProfilingSubEvents(QnnProfile_EventId_t profileEventId) { + const QnnProfile_EventId_t* profileSubEvents{nullptr}; + uint32_t numSubEvents{0}; + if (QNN_PROFILE_NO_ERROR != + m_qnnInterface.profileGetSubEvents(profileEventId, &profileSubEvents, &numSubEvents)) { + QNN_ERROR("Failure in QNN HTP profile get sub events."); + return false; + } + QNN_DEBUG("ProfileSubEvents: [%p], numSubEvents: [%d]", profileSubEvents, numSubEvents); + for (size_t subEvent = 0; subEvent < numSubEvents; subEvent++) { + extractProfilingEvent(*(profileSubEvents + subEvent)); + extractProfilingSubEvents(*(profileSubEvents + subEvent)); + } + return true; +} + +bool QnnApi::extractProfilingEvent(QnnProfile_EventId_t profileEventId) { + QnnProfile_EventData_t eventData; + if (QNN_PROFILE_NO_ERROR != 
+        m_qnnInterface.profileGetEventData(profileEventId, &eventData)) {
+        QNN_ERROR("Failure in profile get event data.");
+        return false;
+    }
+
+    QNN_DEBUG(
+        "Event Info - Event Type: [%d], Event Value: [%lu], Event Identifier: [%s], Event Unit: [%d]",
+        eventData.type,
+        eventData.value,
+        eventData.identifier,
+        eventData.unit
+    );
+
+    return true;
+}
+
+bool QnnApi::applyBinarySection(uint32_t binIndex, std::string binSectionPath, bool useMmap, bool graphSwitch) {
+#if QUALLA_QNN_API_VERSION < 21700
+    QNN_ERROR("LoRA adapters require QNN SDK >= 2.25.1. Please update your libraries");
+    return false;
+#else
+    // assumption: splitNum starts from 0
+    QNN_DEBUG("QnnApi::applyBinarySection %u", binIndex);
+    uint32_t numAdapterGraph = 0;
+    if (nullptr == m_qnnInterface.contextApplyBinarySection) {
+        QNN_ERROR("contextApplyBinarySection interface not supported!");
+        return false;
+    }
+    if (binIndex >= m_graphsCount) {
+        QNN_ERROR("Passed split %u >= base model graph count %u", binIndex, m_graphsCount);
+        return false;
+    }
+    uint64_t bufferSize{0};
+    std::shared_ptr buffer{nullptr};
+    bufferSize = getFileSize(binSectionPath);
+
+    auto graphCountPerContext = getGraphCountPerContext();
+    if (graphCountPerContext <= 0) {
+        QNN_ERROR("graphCountPerContext is <= 0");
+        return false;
+    }
+    const QnnSystemContext_BinaryInfo_t* binaryInfo{nullptr};
+    QnnSystemContext_Handle_t sysCtxHandle{nullptr};
+    if (QNN_SUCCESS != m_qnnSystemInterface.systemContextCreate(&sysCtxHandle)) {
+        QNN_ERROR("Could not create system handle for context index = %u", binIndex);
+        return false;
+    }
+    Qnn_ContextBinarySize_t binaryInfoSize{0};
+
+    if (m_adapterNameToBuffer[binSectionPath]) {
+        buffer = m_adapterNameToBuffer[binSectionPath];
+        if (QNN_SUCCESS != m_qnnSystemInterface.systemContextGetBinaryInfo(
+                sysCtxHandle,
+                static_cast<void*>(buffer.get()),
+                bufferSize,
+                &binaryInfo,
+                &binaryInfoSize
+            )) {
+            QNN_ERROR("Failed to get context binary info for context index = %u", binIndex);
+            return false;
+        }
+    } else {
+        if (!mapAndGetContextBinaryInfo(
+                useMmap,
+                buffer,
+                binSectionPath,
+                bufferSize,
+                binIndex,
+                graphSwitch,
+                sysCtxHandle,
+                &binaryInfo
+            )) {
+            QNN_ERROR("Failed to map context binary for contextIdx: %u", binIndex);
+            return false;
+        }
+        m_adapterNameToBuffer[binSectionPath] = buffer;
+    }
+    numAdapterGraph = getNumGraphInBinary(binaryInfo);
+    if (numAdapterGraph == 0) {
+        QNN_ERROR("numAdapterGraph is 0");
+        return false;
+    }
+    uint32_t contextId = 0;
+    uint32_t graphId = 0;
+    for (auto idx = 0; idx < numAdapterGraph; idx++) {
+        graphId = numAdapterGraph * binIndex + idx;
+        contextId = m_graphIdxToContextIdx[graphId];
+        auto contextHandle = getContextWithId(contextId);
+        auto graphHandle = m_graphsInfo[graphId]->graph;
+        if (contextHandle == nullptr || graphHandle == nullptr) {
+            QNN_ERROR("context handle or graph handle is null for patch no = %u", graphId);
+            return false;
+        }
+
+        QnnContext_Buffer_t qnnBuffer;
+        qnnBuffer.version = QNN_CONTEXT_BUFFER_VERSION_1;
+        qnnBuffer.v1.memType = QNN_CONTEXTMEMTYPE_RAW;
+        qnnBuffer.v1.binaryBuf.dataSize = bufferSize;
+        qnnBuffer.v1.binaryBuf.data = static_cast<void*>(buffer.get());
+
+        auto errorCode = m_qnnInterface.contextApplyBinarySection(
+            contextHandle,
+            graphHandle,
+            QNN_CONTEXT_SECTION_UPDATABLE,
+            &qnnBuffer,
+            nullptr,  // profile handle is null
+            nullptr   // signal handle is null
+        );
+        if (errorCode != QNN_SUCCESS) {
+            QNN_ERROR("Could not apply patch for graph = %u, errorCode = %lu", graphId, errorCode);
+            return false;
+        }
+    }
+    if (updateIOEncodings(buffer, bufferSize, numAdapterGraph * binIndex) == false) {
+        QNN_ERROR("qnn-htp: Adapter updateIOEncodings failed");
+        return false;
+    }
+    m_qnnSystemInterface.systemContextFree(sysCtxHandle);
+    sysCtxHandle = nullptr;
+    return true;
+#endif
+}
+
+bool QnnApi::updateIOEncodings(std::shared_ptr& buffer, uint64_t
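// ---------------------------------------------------------------------------
// Editor's summary, not part of the original source. applyBinarySection()
// above applies one LoRA adapter split in three steps: (1) the adapter binary
// is read, or reused from the m_adapterNameToBuffer cache, and its metadata is
// parsed through the QNN system API; (2) the updatable section is applied to
// each covered graph via contextApplyBinarySection(); (3) updateIOEncodings()
// below refreshes the cached IO quantization encodings, since an adapter may
// ship different scales/offsets than the base model. A hypothetical call site
// (path is illustrative only):
//
//   qnnApi.applyBinarySection(0, "/path/to/adapter_section.bin",
//                             /*useMmap=*/true, /*graphSwitch=*/false);
// ---------------------------------------------------------------------------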
bufferSize,uint32_t graphIndex){ + + QNN_DEBUG("Applying adapter Encodings"); + QnnSystemContext_Handle_t sysCtxHandle{nullptr}; + if (QNN_SUCCESS != m_qnnSystemInterface.systemContextCreate(&sysCtxHandle)) { + QNN_ERROR("Could not create system handle for context index = %zu", graphIndex); + return false; + } + const QnnSystemContext_BinaryInfo_t* binaryInfo{nullptr}; + Qnn_ContextBinarySize_t binaryInfoSize{0}; + if (QNN_SUCCESS != m_qnnSystemInterface.systemContextGetBinaryInfo( + sysCtxHandle, + static_cast(buffer.get()), + bufferSize, + &binaryInfo, + &binaryInfoSize + )) { + QNN_ERROR("Failed to get context binary info for context index = %zu", graphIndex); + return false; + } + if (!updateMetaDataToGraphsInfo(binaryInfo, m_graphsInfo,graphIndex)) { + QNN_ERROR("Failed to copy metadata for graph index = %zu", graphIndex); + return false; + } + m_qnnSystemInterface.systemContextFree(sysCtxHandle); + sysCtxHandle = nullptr; + QNN_DEBUG(" updateIOEncodings success "); + return true; +} diff --git a/Genie/Genie/src/qualla/engines/qnn-api/QnnApi.hpp b/Genie/Genie/src/qualla/engines/qnn-api/QnnApi.hpp new file mode 100644 index 0000000000000000000000000000000000000000..8392265aa950a95015dce233bbeb5c0456ea0665 --- /dev/null +++ b/Genie/Genie/src/qualla/engines/qnn-api/QnnApi.hpp @@ -0,0 +1,429 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#pragma once + +#include "BackendExtensions.hpp" +#include "QnnConfig.hpp" +#include "QnnHtpPerfInfrastructure.h" +#include "QnnHtpDevice.h" +#include "qnn-utils.hpp" +#include "IOTensor.hpp" + +#include +#include + +#define QNN_IO_TENSOR_DEBUG 0 + +enum KVManagerMode { POINTER_SHIFT = 0x0, SHIFT_CONCAT = 0x1 }; + +using qualla::QnnUtils::QuantParam; + +#define QUALLA_QNN_API_VERSION \ + (QNN_API_VERSION_MAJOR * 10000 + QNN_API_VERSION_MINOR * 100 + QNN_API_VERSION_PATCH) + +static std::map g_qnnDataTypeToSize = { + {QNN_DATATYPE_INT_8, 1}, + {QNN_DATATYPE_INT_16, 2}, + {QNN_DATATYPE_INT_32, 4}, + {QNN_DATATYPE_INT_64, 8}, + {QNN_DATATYPE_UINT_8, 1}, + {QNN_DATATYPE_UINT_16, 2}, + {QNN_DATATYPE_UINT_32, 4}, + {QNN_DATATYPE_UINT_64, 8}, + {QNN_DATATYPE_FLOAT_16, 2}, + {QNN_DATATYPE_FLOAT_32, 4}, + {QNN_DATATYPE_SFIXED_POINT_8, 1}, + {QNN_DATATYPE_SFIXED_POINT_16, 2}, + {QNN_DATATYPE_SFIXED_POINT_32, 4}, + {QNN_DATATYPE_UFIXED_POINT_8, 1}, + {QNN_DATATYPE_UFIXED_POINT_16, 2}, + {QNN_DATATYPE_UFIXED_POINT_32, 4}, + {QNN_DATATYPE_BOOL_8, 1}, +}; + +class QnnApi { + private: + const uint32_t s_graphConfigsReserveCount = 16; + + // Model vars + typedef Qnn_ErrorHandle_t (*QnnInterfaceGetProvidersFn_t)( + const QnnInterface_t*** providerList, + uint32_t* numProviders + ); + typedef Qnn_ErrorHandle_t (*QnnSystemInterfaceGetProvidersFn_t)( + const QnnSystemInterface_t*** providerList, + uint32_t* numProviders + ); + + // Graph Related Function Handle Types + typedef ModelError_t (*ComposeGraphsFnHandleType_t)( + Qnn_BackendHandle_t, + QNN_INTERFACE_VER_TYPE, + Qnn_ContextHandle_t, + const GraphConfigInfo_t**, + const uint32_t, + GraphInfo_t***, + uint32_t*, + bool, + QnnLog_Callback_t, + QnnLog_Level_t + ); + + typedef ModelError_t (*GenAIComposeGraphsFnHandleType_t)( + Qnn_BackendHandle_t, + QNN_INTERFACE_VER_TYPE, + Qnn_ContextHandle_t, + const GraphConfigInfo_t**, + const 
uint32_t, + uint32_t* inputDim, + uint32_t inputRank, + uint32_t* outputDim, + uint32_t outputRank, + uint32_t* kvDim, + uint32_t kvRank, + Qnn_Param_t* params, + uint32_t numParam, + GraphInfo_t***, + uint32_t*, + bool, + QnnLog_Callback_t, + QnnLog_Level_t + ); + + typedef ModelError_t (*FreeGraphInfoFnHandleType_t)(GraphInfo_t***, uint32_t); + + void* m_libModelHandle{nullptr}; + void* m_backendHandle{nullptr}; + void* m_backendLibraryHandle{nullptr}; + + QNN_INTERFACE_VER_TYPE m_qnnInterface{nullptr}; + QNN_SYSTEM_INTERFACE_VER_TYPE m_qnnSystemInterface{nullptr}; + std::unique_ptr m_backendExtensions{nullptr}; + ComposeGraphsFnHandleType_t m_composeGraphsFnHandle{nullptr}; + GenAIComposeGraphsFnHandleType_t m_genaiComposeGraphsFnHandle{nullptr}; + FreeGraphInfoFnHandleType_t m_freeGraphInfoFnHandle{nullptr}; + uint32_t m_backendId{0}; + Qnn_LogHandle_t m_logHandle{nullptr}; + Qnn_DeviceHandle_t m_deviceHandle{nullptr}; + + Qnn_ProfileHandle_t m_profileBackendHandle{nullptr}; + + std::vector m_contextVec; + std::unordered_map m_contextMap; + uint32_t m_graphsCount{0}; + int32_t graphCountPerContext{-1}; + GraphInfo_t** m_graphsInfo; + std::unordered_map m_graphNameToIndex; + std::unordered_map m_graphNameToInfo; + std::unordered_map m_graphNameToContextIdx; + std::unordered_map m_contextIdtoHandle; + std::mutex m_updateCallBackMutex; + + // Useful Structure for IO Esimtation + std::unordered_map m_graphtoIOMap; // stores {GraphId -> IOTensorMap} + typedef int CtxBitVector; + std::map> m_contextAllocMap; // stores {Translated ContextId -> {Tensor name, size}} + std::map> m_tensorAllocInfo; // stores {Tensor name -> (fd of RPC buffer, offset)} + std::unordered_map m_graphIdxToContextIdx; // stores {Graph Idx -> Context Idx} + std::unordered_map> m_adapterNameToBuffer; + + uint32_t m_backendConfigCount{0}; + QnnBackend_Config_t** m_backendConfigs{nullptr}; + + QnnHtpDevice_PerfInfrastructure_t* m_perfInfra{nullptr}; + uint32_t m_powerConfigId = 1; + + // Useful Structure for IO Esimtation + IOTensor* m_ioBufferMgr{nullptr}; + int32_t m_ctxSize{-1}; + int32_t m_kvDim{-1}; + bool m_loraWeightEnabled{false}; + bool m_lmHeadWeightInput{false}; + KVManagerMode m_kvUpdateMethod{POINTER_SHIFT}; + + bool m_isLogInitialized{false}; + bool m_isBackendInitialized{false}; + bool m_isContextCreated{false}; + + // Variable to keep track of debug mode + bool m_DebugModeRequested; + bool m_debugQnn{false}; + + // Variable to indicate whether to mmap context bins or read them in memory + bool m_mmapContextBins; + bool m_isDeviceCreated = false; + + std::vector> m_contextBinBuffersToBeCleared; + + void setDeviceStatus(bool status) { m_isDeviceCreated = status; } + bool getDeviceStatus() { return m_isDeviceCreated; } + bool getContextConfigs( + QnnContext_Config_t*** configs, + uint32_t& contextConfigCount, + Qnn_Priority_t contextPriority, + bool graphSwitching = false, + const std::vector& execSelectGraphs = {}, + bool loadSelectGraphs = false + ); + bool mergeAllContextConfigs( + QnnContext_Config_t*** allCustomContextConfigs, + QnnContext_Config_t** customConfigs, + QnnContext_Config_t** contextConfigs, + uint32_t customConfigCount, + uint32_t contextConfigCount + ); + bool freeContextConfigs(QnnContext_Config_t** contextConfigs, uint32_t contextConfigCount); + bool setGraphConfigsBeforeExecute( + Qnn_GraphHandle_t graphHandle, + QnnGraph_Config_t** graphConfigs, + uint32_t configCount + ); + + bool getQnnInterface(std::string backendPath); + bool getQnnSystemInterface(std::string systemLibraryPath); + 
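        // -------------------------------------------------------------------
        // Editor's sketch, not part of the original source. The two resolvers
        // above locate the provider entry point in the backend/system library;
        // the usual pattern (assuming POSIX dlopen/dlsym) looks like:
        //
        //   void* handle = dlopen(backendPath.c_str(), RTLD_NOW | RTLD_LOCAL);
        //   auto getProviders = reinterpret_cast<QnnInterfaceGetProvidersFn_t>(
        //       dlsym(handle, "QnnInterface_getProviders"));
        //   const QnnInterface_t** providers = nullptr;
        //   uint32_t numProviders = 0;
        //   getProviders(&providers, &numProviders);
        //   // then select a provider whose apiVersion matches QNN_API_VERSION
        // -------------------------------------------------------------------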
bool loadModel(std::string model_path); + bool initializeLogging(const QnnLog_Level_t& logLevel, bool debug_qnn); + void terminateLog(); + bool initializeBackendExtensions( + BackendExtensionsConfigs backendExtensionsConfig, + PerfProfile parsedPerfProfile, + bool debug_qnn + ); + bool initializeBackend(); + bool terminateBackend(); + bool createDevice(); + bool freeDevice(); + bool createContext(ContextConfigs contextConfig); + bool freeContext(); + bool composeGraphs(std::vector graphConfigs); + bool composeGraphs( + std::vector graphConfigs, + uint32_t* inputDim, + uint32_t inputRank, + uint32_t* outputDim, + uint32_t outputRank, + uint32_t* kvDim, + uint32_t kvRank, + Qnn_Param_t* params, + uint32_t numParams + ); + bool mapAndGetContextBinaryInfo( + const bool use_mmap, + std::shared_ptr& buffer, + const std::string binaryPath, + const uint64_t bufferSize, + const size_t contextIdx, + const bool graphSwitching, + QnnSystemContext_Handle_t sysCtxHandle, + const QnnSystemContext_BinaryInfo_t** binaryInfo + ); + + bool parseIOTensorsAndAccumulate(); + bool registerTensorsWithBackend(uint32_t& graphIdx); + + bool finalizeGraphs(); + bool initializePerformance(); + bool destroyPerformance(); + bool boostPerformance(); + bool resetPerformance(); + bool checkCapabilityOfCreateAsync(bool& propRet); + + bool initProfiling(); + bool extractBackendProfilingInfo( + Qnn_ProfileHandle_t profileHandle, + std::map>& timeLogs, + std::string graphName + ); + bool extractProfilingSubEvents( + QnnProfile_EventId_t profileEventId, + std::map>& timeLogs, + std::string graphName + ); + bool extractProfilingEvent( + QnnProfile_EventId_t profileEventId, + std::map>& timeLogs, + std::string graphName + ); + bool extractBackendProfilingInfo(Qnn_ProfileHandle_t profileHandle); + bool extractProfilingSubEvents(QnnProfile_EventId_t profileEventId); + bool extractProfilingEvent(QnnProfile_EventId_t profileEventId); + + Qnn_ContextHandle_t getContextWithId(uint32_t contextId) { + return m_contextIdtoHandle[contextId]; + } + + public: + QnnApi() {}; + ~QnnApi(); + + bool freeGraphs(); + static QnnApi& getInstance(); +#if QUALLA_QNN_API_VERSION >= 21700 + static void contextNotifyFn( + Qnn_ContextHandle_t context, + Qnn_GraphHandle_t graph, + const char* graph_name, + QnnContext_createFromBinaryAsyncNotifyType_t completeType, + void* notifyParam, + Qnn_ErrorHandle_t status + ); +#endif + bool createFromBinary( + std::vector cachedBinariesPathVec, + ContextConfigs contextConfig, + int64_t spill_fill_buffer_size = 0, + uint64_t mmap_budget = 0, + bool graphSwitching = false, + const std::vector& execSelectGraphs = {}, + bool loadSelectGraphs = false + ); +#if QUALLA_QNN_API_VERSION >= 21700 + bool createFromBinaryListAsync( + std::vector cachedBinariesPathVec, + ContextConfigs contextConfig, + int64_t spill_fill_buffer_size = 0, + uint64_t mmap_budget = 0, + bool graphSwitching = false, + const std::vector& execSelectGraphs = {}, + bool loadSelectGraphs = false + ); +#endif + bool initialize( + std::string backendPath, + std::vector modelPathOrCachedBinaryPathVec, + BackendExtensionsConfigs backendExtensionsConfig, + PerfProfile parsedPerfProfile = PerfProfile::BURST, + ContextConfigs contextConfig = ContextConfigs(), + std::vector graphConfigs = {}, + bool loadFromCachedBinary = false, + std::string systemLibraryPath = "", + bool debugModeRequested = false, + int64_t spill_fill_buffer_size = 0, + bool mmapContextBins = false, + bool asyncInit = true, + uint64_t mmap_budget = 0, + bool debug_qnn = false, + bool 
graphSwitching = false, + const std::vector& execSelectGraphs = {}, + bool loadSelectGraphs = false + ); + + bool registerOpPackage(std::string opPackagePath); + + void setIOTensorBufferMgr(IOTensor* ioBufferMgr){ + m_ioBufferMgr = ioBufferMgr; + } + + void setKVDim(int32_t kvDim){ + m_kvDim = kvDim; + } + + void setContextSize(int32_t ctxSize){ + m_ctxSize = ctxSize; + } + + void setKVUpdateMethod(KVManagerMode kvUpdateMethod){ + m_kvUpdateMethod = kvUpdateMethod ; + } + + std::map>* getTensorAllocInfo(){ + return &m_tensorAllocInfo; + } + + bool getLmHeadWeightInputEnabled(){ + return m_lmHeadWeightInput; + } + + bool getLoraWeightEnabled(){ + return m_loraWeightEnabled; + } + // Initalize with OpPackage + bool initialize( + std::string backendPath, + std::string modelPath, + std::string opPackage, + ContextConfigs contextConfig, + std::vector graphConfigs, + uint32_t* inputDim, + uint32_t inputRank, + uint32_t* outputDim, + uint32_t outputRank, + uint32_t* kvDim, + uint32_t kvRank, + Qnn_Param_t* params, + uint32_t numParams, + bool debugModeRequested + ); + + bool graphExecute( + Qnn_Tensor_t* input, + Qnn_Tensor_t* output, + std::string graphName, + std::map>& timeLogs + ); + + bool applyBinarySection(uint32_t binIndex, std::string binSectionPath,bool useMmap,bool graphSwitch); + + QNN_INTERFACE_VER_TYPE* getQnnInterfaceVer() { return &m_qnnInterface; }; + GraphInfo_t**& getGraphsInfo() { return m_graphsInfo; }; + uint32_t getGraphsCount() { return m_graphsCount; }; + int32_t getGraphCountPerContext() { return graphCountPerContext; } + std::vector& getContexts() { return m_contextVec; }; + const Qnn_ContextHandle_t getContexts(GraphInfo_t* const graph) { + return m_contextMap.at(graph); + }; + + void updateContext(Qnn_ContextHandle_t context, uint32_t contextId) { + std::lock_guard lock(m_updateCallBackMutex); + m_contextVec.push_back(context); + m_contextIdtoHandle[contextId] = context; + } + + void updateQnnApiGraphsandContextsInfo( + std::string graphName, + Qnn_GraphHandle_t graph, + uint32_t contextId + ) { + // set graph handle to GraphInfo + std::lock_guard lock(m_updateCallBackMutex); + m_graphNameToInfo[graphName]->graph = graph; + m_graphNameToContextIdx[graphName] = contextId; + m_graphsCount++; + } + + static inline size_t getDataTypeSize(const Qnn_DataType_t& datatype) { + return g_qnnDataTypeToSize[datatype]; + } + static inline std::string getTensorName(const TensorWrapper& tensorWrapper) { + return GET_TENSOR_WRAPPER_NAME(tensorWrapper); + } + static bool getTensorQuantParams( + const Qnn_Tensor_t* tensor, + std::vector& quantParamsVec + ); + static bool getTensorShape(std::vector& tensorDims, const TensorWrapper& tensorWrapper); + static inline Qnn_DataType_t getTensorDtype(const Qnn_Tensor_t* tensor) { + return QNN_TENSOR_GET_DATA_TYPE(tensor); + } + + bool getTensorNameAndShape( + std::string& tensorName, + std::vector& tensorDims, + TensorWrapper& tensorWrapper + ); + static void qnnLogCallback( + const char* fmt, + QnnLog_Level_t level, + uint64_t timestamp, + va_list args + ); + bool updateIOEncodings(std::shared_ptr& buffer, + uint64_t bufferSize, + uint32_t graphIndex); +}; diff --git a/Genie/Genie/src/qualla/engines/qnn-api/QnnApiUtils.cpp b/Genie/Genie/src/qualla/engines/qnn-api/QnnApiUtils.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e8e71a2626a7c6e6aa213114fa310ba5450cc486 --- /dev/null +++ b/Genie/Genie/src/qualla/engines/qnn-api/QnnApiUtils.cpp @@ -0,0 +1,636 @@ 
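// ---------------------------------------------------------------------------
// Editor's sketch, not part of the original diff. QnnApiUtils.cpp below owns
// the deep copy of graph and tensor metadata out of a deserialized context
// binary. A typical consumer, under the declarations in QnnApiUtils.hpp and
// with error handling elided, would look like:
//
//   GraphInfo_t** graphs = nullptr;
//   uint32_t graphCount = 0;
//   if (copyMetadataToGraphsInfo(binaryInfo, graphs, graphCount)) {
//       // ... register tensors, finalize and execute graphs ...
//       freeGraphsInfo(&graphs, graphCount);  // releases names, dims, wrappers
//   }
// ---------------------------------------------------------------------------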
+//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#include "QnnApiUtils.hpp" +#include "QnnTypeMacros.hpp" + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#ifdef _WIN32 + #include + #define __open ::_open + #define __strdup ::_strdup +#else + #include + #include + #define __open ::open + #define __strdup ::strdup +#endif + +bool freeQnnTensorWrapper(TensorWrapper& tensorWrapper) { + // free all pointer allocations in struct + if (nullptr != GET_TENSOR_WRAPPER_NAME(tensorWrapper)) { + free((void*)GET_TENSOR_WRAPPER_NAME(tensorWrapper)); + } + + Qnn_Tensor_t& tensor = GET_TENSOR_WRAPPER_TENSOR(tensorWrapper); + free(QNN_TENSOR_GET_DIMENSIONS(tensor)); + return true; +} + +bool freeQnnTensorWrappers(TensorWrapper*& tensorWrappers, uint32_t numTensors) { + // free all pointer allocations in struct + for (size_t i = 0; i < numTensors; i++) { + freeQnnTensorWrapper(tensorWrappers[i]); + } + free(tensorWrappers); + + return true; +} + +bool freeGraphsInfo(GraphInfoPtr_t** graphsInfo, uint32_t numGraphs) { + if (graphsInfo == nullptr || *graphsInfo == nullptr) { + return false; + } + for (uint32_t i = 0; i < numGraphs; i++) { + if (nullptr != (*graphsInfo)[i]) { + free((*graphsInfo)[i]->graphName); + freeQnnTensorWrappers( + (*graphsInfo)[i]->inputTensors, (*graphsInfo)[i]->numInputTensors + ); + freeQnnTensorWrappers( + (*graphsInfo)[i]->outputTensors, (*graphsInfo)[i]->numOutputTensors + ); + } + } + free(**graphsInfo); + free(*graphsInfo); + *graphsInfo = nullptr; + + return true; +} + +bool freeGraphInfo(GraphInfo_t* graphInfo) { + if (graphInfo == nullptr) { + return false; + } + if (nullptr != graphInfo->graphName) { + free(graphInfo->graphName); + } + freeQnnTensorWrappers(graphInfo->inputTensors, graphInfo->numInputTensors); + freeQnnTensorWrappers(graphInfo->outputTensors, graphInfo->numOutputTensors); + free(graphInfo); + return true; +} + +bool updateTensorInfo(const Qnn_Tensor_t* tensorsInfoSrc, + TensorWrapper* tensorWrappers, + uint32_t tensorsCount +){ + for (size_t tIdx = 0; tIdx < tensorsCount; tIdx++) { + QNN_DEBUG("Extracting tensorInfo for tensor Idx: %d", (int)tIdx); + Qnn_Tensor_t& tensor = GET_TENSOR_WRAPPER_TENSOR(tensorWrappers[tIdx]); + + QNN_TENSOR_SET_ID(tensor, QNN_TENSOR_GET_ID(&tensorsInfoSrc[tIdx])); + QNN_TENSOR_SET_TYPE(tensor, QNN_TENSOR_GET_TYPE(&tensorsInfoSrc[tIdx])); + QNN_TENSOR_SET_DATA_FORMAT(tensor, QNN_TENSOR_GET_DATA_FORMAT(&tensorsInfoSrc[tIdx])); + QNN_TENSOR_SET_DATA_TYPE(tensor, QNN_TENSOR_GET_DATA_TYPE(&tensorsInfoSrc[tIdx])); + Qnn_QuantizeParams_t qParams = QNN_QUANTIZE_PARAMS_INIT; + qParams.encodingDefinition = + QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx]).encodingDefinition; + qParams.quantizationEncoding = QNN_QUANTIZATION_ENCODING_UNDEFINED; + if (QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx]).quantizationEncoding == + QNN_QUANTIZATION_ENCODING_SCALE_OFFSET) { + qParams.quantizationEncoding = + QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx]).quantizationEncoding; + qParams.scaleOffsetEncoding = + QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx]).scaleOffsetEncoding; + } else if (QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx]).quantizationEncoding == + QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET) { 
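        // Editor's note (illustration, not original source): unlike the single
        // scale/offset case above, an axis (per-channel) encoding carries one
        // Qnn_ScaleOffset_t per slice along `axis`, e.g. per output channel of
        // a conv weight, so real[c] = scaleOffset[c].scale * (q + scaleOffset[c].offset).
        // The branch below therefore heap-allocates and element-copies the
        // scaleOffset array instead of aliasing the source pointer, keeping the
        // wrapper valid after the source binary info is freed.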
+ qParams.quantizationEncoding = + QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx]).quantizationEncoding; + qParams.axisScaleOffsetEncoding.axis = + QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx]) + .axisScaleOffsetEncoding.axis; + qParams.axisScaleOffsetEncoding.numScaleOffsets = + QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx]) + .axisScaleOffsetEncoding.numScaleOffsets; + if (QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx]) + .axisScaleOffsetEncoding.numScaleOffsets > 0) { + qParams.axisScaleOffsetEncoding.scaleOffset = (Qnn_ScaleOffset_t*)malloc( + QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx]) + .axisScaleOffsetEncoding.numScaleOffsets * + sizeof(Qnn_ScaleOffset_t) + ); + if (qParams.axisScaleOffsetEncoding.scaleOffset) { + for (size_t idx = 0; + idx < QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx]) + .axisScaleOffsetEncoding.numScaleOffsets; + idx++) { + qParams.axisScaleOffsetEncoding.scaleOffset[idx].scale = + QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx]) + .axisScaleOffsetEncoding.scaleOffset[idx] + .scale; + qParams.axisScaleOffsetEncoding.scaleOffset[idx].offset = + QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx]) + .axisScaleOffsetEncoding.scaleOffset[idx] + .offset; + } + } + } + } + QNN_TENSOR_SET_QUANT_PARAMS(tensor, qParams); + QNN_TENSOR_SET_RANK(tensor, QNN_TENSOR_GET_RANK(&tensorsInfoSrc[tIdx])); + if (QNN_TENSOR_GET_RANK(tensorsInfoSrc[tIdx]) > 0) { + if (QNN_TENSOR_GET_DIMENSIONS(tensor)) { + memcpy(QNN_TENSOR_GET_DIMENSIONS(tensor), + QNN_TENSOR_GET_DIMENSIONS(&tensorsInfoSrc[tIdx]), + QNN_TENSOR_GET_RANK(&tensorsInfoSrc[tIdx]) * sizeof(uint32_t)); + } + } + } + return true; +} + +bool copyTensorsInfo( + const Qnn_Tensor_t* tensorsInfoSrc, + TensorWrapper*& tensorWrappers, + uint32_t tensorsCount +) { + + auto returnStatus = true; + tensorWrappers = (TensorWrapper*)calloc(tensorsCount, sizeof(TensorWrapper)); + if (nullptr == tensorWrappers) { + QNN_ERROR("Failed to allocate memory for tensorWrappers."); + return false; + } + if (returnStatus) { + for (size_t tIdx = 0; tIdx < tensorsCount; tIdx++) { + // QNN_DEBUG("Extracting tensorInfo for tensor Idx: %d", (int)tIdx); + Qnn_Tensor_t& tensor = GET_TENSOR_WRAPPER_TENSOR(tensorWrappers[tIdx]); + tensor = QNN_TENSOR_INIT; + + const char* tensorName = QNN_TENSOR_GET_NAME(&tensorsInfoSrc[tIdx]); + if (!tensorName) { + QNN_TENSOR_SET_NAME(tensor, nullptr); + } else { + QNN_TENSOR_SET_NAME(tensor, __strdup(tensorName)); + } + + QNN_TENSOR_SET_ID(tensor, QNN_TENSOR_GET_ID(&tensorsInfoSrc[tIdx])); + QNN_TENSOR_SET_TYPE(tensor, QNN_TENSOR_GET_TYPE(&tensorsInfoSrc[tIdx])); + QNN_TENSOR_SET_DATA_FORMAT(tensor, QNN_TENSOR_GET_DATA_FORMAT(&tensorsInfoSrc[tIdx])); + QNN_TENSOR_SET_DATA_TYPE(tensor, QNN_TENSOR_GET_DATA_TYPE(&tensorsInfoSrc[tIdx])); + Qnn_QuantizeParams_t qParams = QNN_QUANTIZE_PARAMS_INIT; + qParams.encodingDefinition = + QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx]).encodingDefinition; + qParams.quantizationEncoding = QNN_QUANTIZATION_ENCODING_UNDEFINED; + if (QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx]).quantizationEncoding == + QNN_QUANTIZATION_ENCODING_SCALE_OFFSET) { + qParams.quantizationEncoding = + QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx]).quantizationEncoding; + qParams.scaleOffsetEncoding = + QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx]).scaleOffsetEncoding; + } else if (QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx]).quantizationEncoding == + QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET) { + qParams.quantizationEncoding = + 
QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx]).quantizationEncoding; + qParams.axisScaleOffsetEncoding.axis = + QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx]) + .axisScaleOffsetEncoding.axis; + qParams.axisScaleOffsetEncoding.numScaleOffsets = + QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx]) + .axisScaleOffsetEncoding.numScaleOffsets; + if (QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx]) + .axisScaleOffsetEncoding.numScaleOffsets > 0) { + qParams.axisScaleOffsetEncoding.scaleOffset = (Qnn_ScaleOffset_t*)malloc( + QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx]) + .axisScaleOffsetEncoding.numScaleOffsets * + sizeof(Qnn_ScaleOffset_t) + ); + if (qParams.axisScaleOffsetEncoding.scaleOffset) { + for (size_t idx = 0; + idx < QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx]) + .axisScaleOffsetEncoding.numScaleOffsets; + idx++) { + qParams.axisScaleOffsetEncoding.scaleOffset[idx].scale = + QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx]) + .axisScaleOffsetEncoding.scaleOffset[idx] + .scale; + qParams.axisScaleOffsetEncoding.scaleOffset[idx].offset = + QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx]) + .axisScaleOffsetEncoding.scaleOffset[idx] + .offset; + } + } + } + } + QNN_TENSOR_SET_QUANT_PARAMS(tensor, qParams); + QNN_TENSOR_SET_RANK(tensor, QNN_TENSOR_GET_RANK(&tensorsInfoSrc[tIdx])); + QNN_TENSOR_SET_DIMENSIONS(tensor, nullptr); + if (QNN_TENSOR_GET_RANK(tensorsInfoSrc[tIdx]) > 0) { + QNN_TENSOR_SET_DIMENSIONS( + tensor, + (uint32_t*)malloc( + QNN_TENSOR_GET_RANK(&tensorsInfoSrc[tIdx]) * sizeof(uint32_t) + ) + ); + if (QNN_TENSOR_GET_DIMENSIONS(tensor)) { + memcpy(QNN_TENSOR_GET_DIMENSIONS(tensor), + QNN_TENSOR_GET_DIMENSIONS(&tensorsInfoSrc[tIdx]), + QNN_TENSOR_GET_RANK(&tensorsInfoSrc[tIdx]) * sizeof(uint32_t)); + } + } + } + } + + return returnStatus; +} + + +bool updateGraphInfoV1(const QnnSystemContext_GraphInfoV1_t* graphInfoSrc, + GraphInfo_t* graphInfoDst +){ + if (graphInfoSrc->graphInputs) { + if (!updateTensorInfo( + graphInfoSrc->graphInputs, + graphInfoDst->inputTensors, + graphInfoSrc->numGraphInputs + )) { + return false; + } + } + if (graphInfoSrc->graphOutputs) { + if (!updateTensorInfo( + graphInfoSrc->graphOutputs, + graphInfoDst->outputTensors, + graphInfoSrc->numGraphOutputs + )) { + return false; + } + } + return true; +} + + +bool updateGraphInfoV3(const QnnSystemContext_GraphInfoV3_t* graphInfoSrc, + GraphInfo_t* graphInfoDst +){ + if (graphInfoSrc->graphInputs) { + if (!updateTensorInfo( + graphInfoSrc->graphInputs, + graphInfoDst->inputTensors, + graphInfoSrc->numGraphInputs + )) { + return false; + } + } + if (graphInfoSrc->graphOutputs) { + if (!updateTensorInfo( + graphInfoSrc->graphOutputs, + graphInfoDst->outputTensors, + graphInfoSrc->numGraphOutputs + )) { + return false; + } + } + return true; +} + +bool copyGraphsInfoV1( + const QnnSystemContext_GraphInfoV1_t* graphInfoSrc, + GraphInfo_t* graphInfoDst +) { + graphInfoDst->graphName = nullptr; + if (graphInfoSrc->graphName) { + graphInfoDst->graphName = __strdup(graphInfoSrc->graphName); + } + graphInfoDst->inputTensors = nullptr; + graphInfoDst->numInputTensors = 0; + if (graphInfoSrc->graphInputs) { + if (!copyTensorsInfo( + graphInfoSrc->graphInputs, + graphInfoDst->inputTensors, + graphInfoSrc->numGraphInputs + )) { + return false; + } + graphInfoDst->numInputTensors = graphInfoSrc->numGraphInputs; + } + graphInfoDst->outputTensors = nullptr; + graphInfoDst->numOutputTensors = 0; + if (graphInfoSrc->graphOutputs) { + if (!copyTensorsInfo( + graphInfoSrc->graphOutputs, + 
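                // Editor's note (illustration, not original source):
                // copyTensorsInfo() gives the destination wrappers their own
                // allocations (names via __strdup, dimension arrays via
                // malloc), so a GraphInfo_t outlives the system-context handle
                // that produced the source. For example, a wrapper name stays
                // valid even after systemContextFree():
                //
                //   const char* n = GET_TENSOR_WRAPPER_NAME(graphInfoDst->inputTensors[0]);
                //
                // The matching teardown is freeQnnTensorWrappers()/freeGraphsInfo().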
graphInfoDst->outputTensors, + graphInfoSrc->numGraphOutputs + )) { + return false; + } + graphInfoDst->numOutputTensors = graphInfoSrc->numGraphOutputs; + } + return true; +} + +bool copyGraphsInfoV3(const QnnSystemContext_GraphInfoV3_t *graphInfoSrc, + GraphInfo_t *graphInfoDst) { + graphInfoDst->graphName = nullptr; + if (graphInfoSrc->graphName) { + graphInfoDst->graphName = + __strdup(graphInfoSrc->graphName); + } + graphInfoDst->inputTensors = nullptr; + graphInfoDst->numInputTensors = 0; + if (graphInfoSrc->graphInputs) { + if (!copyTensorsInfo( + graphInfoSrc->graphInputs, graphInfoDst->inputTensors, graphInfoSrc->numGraphInputs)) { + return false; + } + graphInfoDst->numInputTensors = graphInfoSrc->numGraphInputs; + } + graphInfoDst->outputTensors = nullptr; + graphInfoDst->numOutputTensors = 0; + if (graphInfoSrc->graphOutputs) { + if (!copyTensorsInfo(graphInfoSrc->graphOutputs, + graphInfoDst->outputTensors, + graphInfoSrc->numGraphOutputs)) { + return false; + } + graphInfoDst->numOutputTensors = graphInfoSrc->numGraphOutputs; + } + return true; +} + +bool updateGraphInfo(const QnnSystemContext_GraphInfo_t* graphsInput, + const uint32_t numGraphs, + GraphInfo_t** graphsInfo, + uint32_t& graphsCount +){ + + for (size_t gIdx = 0; gIdx < numGraphs; gIdx++) { + if (graphsInput[gIdx].version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_1) { + if(updateGraphInfoV1(&graphsInput[gIdx].graphInfoV1, graphsInfo[graphsCount]) == false) { + return false; + } + } + if (graphsInput[gIdx].version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_3) { + if(updateGraphInfoV3(&graphsInput[gIdx].graphInfoV3, graphsInfo[graphsCount]) == false) { + return false; + } + } + graphsCount++; + } + return true; +} + + +bool copyGraphsInfo( + const QnnSystemContext_GraphInfo_t* graphsInput, + const uint32_t numGraphs, + GraphInfo_t**& graphsInfo +) { + + if (!graphsInput) { + QNN_ERROR("Received nullptr for graphsInput."); + return false; + } + auto returnStatus = true; + graphsInfo = (GraphInfo_t**)calloc(numGraphs, sizeof(GraphInfo_t*)); + GraphInfo_t* graphInfoArr = (GraphInfo_t*)calloc(numGraphs, sizeof(GraphInfo_t)); + if (nullptr == graphsInfo || nullptr == graphInfoArr) { + QNN_ERROR("Failure to allocate memory for *graphInfo"); + returnStatus = false; + } + if (true == returnStatus) { + for (size_t gIdx = 0; gIdx < numGraphs; gIdx++) { + QNN_DEBUG("Extracting graphsInfo for graph Idx: %d", (int)gIdx); + if (graphsInput[gIdx].version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_1) { + copyGraphsInfoV1(&graphsInput[gIdx].graphInfoV1, &graphInfoArr[gIdx]); + } + if (graphsInput[gIdx].version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_3) { + copyGraphsInfoV3(&graphsInput[gIdx].graphInfoV3, &graphInfoArr[gIdx]); + } + graphsInfo[gIdx] = graphInfoArr + gIdx; + } + } + if (true != returnStatus) { + QNN_DEBUG("Received an ERROR during extractGraphsInfo. 
Freeing resources."); + if (graphsInfo) { + for (uint32_t gIdx = 0; gIdx < numGraphs; gIdx++) { + if (graphsInfo[gIdx]) { + if (nullptr != graphsInfo[gIdx]->graphName) { + free(graphsInfo[gIdx]->graphName); + graphsInfo[gIdx]->graphName = nullptr; + } + freeQnnTensorWrappers( + graphsInfo[gIdx]->inputTensors, graphsInfo[gIdx]->numInputTensors + ); + freeQnnTensorWrappers( + graphsInfo[gIdx]->outputTensors, graphsInfo[gIdx]->numOutputTensors + ); + } + } + free(*graphsInfo); + } + free(graphsInfo); + graphsInfo = nullptr; + } + + return true; +} + +uint32_t getNumGraphInBinary(const QnnSystemContext_BinaryInfo_t* binaryInfo) +{ + uint32_t numGraph = 0; + if (nullptr == binaryInfo) { + QNN_ERROR("binaryInfo is nullptr."); + return false; + } + if (binaryInfo->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_1) { + numGraph = binaryInfo->contextBinaryInfoV1.numGraphs; + }else if (binaryInfo->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_2) { + numGraph = binaryInfo->contextBinaryInfoV2.numGraphs; + } + else if (binaryInfo->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_3) { + numGraph = binaryInfo->contextBinaryInfoV3.numGraphs; + } + return numGraph; +} + +bool updateMetaDataToGraphsInfo(const QnnSystemContext_BinaryInfo_t* binaryInfo, + GraphInfo_t** graphsInfo, + uint32_t& graphsCount +){ + if (nullptr == binaryInfo) { + QNN_ERROR("binaryInfo is nullptr."); + return false; + } + if (binaryInfo->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_1) { + if (binaryInfo->contextBinaryInfoV1.graphs) { + if (!updateGraphInfo( + binaryInfo->contextBinaryInfoV1.graphs, + binaryInfo->contextBinaryInfoV1.numGraphs, + graphsInfo, + graphsCount + )) { + QNN_ERROR("Failed while copying graphs Info."); + return false; + } + return true; + } + } else if (binaryInfo->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_2) { + if (binaryInfo->contextBinaryInfoV2.graphs) { + if (!updateGraphInfo( + binaryInfo->contextBinaryInfoV2.graphs, + binaryInfo->contextBinaryInfoV2.numGraphs, + graphsInfo, + graphsCount + )) { + QNN_ERROR("Failed while copying graphs Info."); + return false; + } + return true; + } + } else if (binaryInfo->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_3) { + if (binaryInfo->contextBinaryInfoV3.graphs) { + if (!updateGraphInfo( + binaryInfo->contextBinaryInfoV3.graphs, + binaryInfo->contextBinaryInfoV3.numGraphs, + graphsInfo, + graphsCount + )) { + QNN_ERROR("Failed while copying graphs Info."); + return false; + } + return true; + } + } + QNN_ERROR("Unrecognized system context binary info version."); + return false; +} + +bool copyMetadataToGraphsInfo( + const QnnSystemContext_BinaryInfo_t* binaryInfo, + GraphInfo_t**& graphsInfo, + uint32_t& graphsCount +) { + if (nullptr == binaryInfo) { + QNN_ERROR("binaryInfo is nullptr."); + return false; + } + graphsCount = 0; + if (binaryInfo->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_1) { + if (binaryInfo->contextBinaryInfoV1.graphs) { + if (!copyGraphsInfo( + binaryInfo->contextBinaryInfoV1.graphs, + binaryInfo->contextBinaryInfoV1.numGraphs, + graphsInfo + )) { + QNN_ERROR("Failed while copying graphs Info."); + return false; + } + graphsCount = binaryInfo->contextBinaryInfoV1.numGraphs; + return true; + } + } else if (binaryInfo->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_2) { + if (binaryInfo->contextBinaryInfoV2.graphs) { + if (!copyGraphsInfo( + binaryInfo->contextBinaryInfoV2.graphs, + binaryInfo->contextBinaryInfoV2.numGraphs, + graphsInfo + )) { + QNN_ERROR("Failed while copying graphs Info."); + return false; + 
+            }
+            graphsCount = binaryInfo->contextBinaryInfoV2.numGraphs;
+            return true;
+        }
+    } else if (binaryInfo->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_3) {
+        if (binaryInfo->contextBinaryInfoV3.graphs) {
+            if (!copyGraphsInfo(binaryInfo->contextBinaryInfoV3.graphs,
+                                binaryInfo->contextBinaryInfoV3.numGraphs,
+                                graphsInfo)) {
+                QNN_ERROR("Failed while copying graphs info.");
+                return false;
+            }
+            graphsCount = binaryInfo->contextBinaryInfoV3.numGraphs;
+            return true;
+        }
+    }
+    QNN_ERROR("Unrecognized system context binary info version.");
+    return false;
+}
+
+size_t getFileSize(std::string filePath) {
+    std::ifstream in(filePath, std::ifstream::binary);
+    if (!in) {
+        QNN_ERROR("Failed to open input file: %s", filePath.c_str());
+        return 0;
+    }
+    in.seekg(0, in.end);
+    const size_t length = in.tellg();
+    in.seekg(0, in.beg);
+    return length;
+}
+
+bool readBinaryFromFile(std::string filePath, void* buffer, size_t bufferSize) {
+    if (nullptr == buffer) {
+        QNN_ERROR("buffer is nullptr");
+        return false;
+    }
+    std::ifstream in(filePath, std::ifstream::binary);
+    if (!in) {
+        QNN_ERROR("Failed to open input file: %s", filePath.c_str());
+        return false;
+    }
+    if (!in.read(reinterpret_cast<char*>(buffer), bufferSize)) {
+        QNN_ERROR("Failed to read the contents of: %s", filePath.c_str());
+        return false;
+    }
+    return true;
+}
+
+bool mmapBinaryFile(std::string filePath, void** buffer, size_t bufferSize) {
+#ifndef _WIN32
+    int fd = open(filePath.c_str(), O_RDONLY);
+    if (fd < 0) {
+        QNN_ERROR("Failed to open input file: %s err: %s", filePath.c_str(), strerror(errno));
+        return false;
+    }
+    int OFFSET = 0;
+
+    // read the binary file as a memory map
+    *buffer = mmap(nullptr, bufferSize, PROT_READ, MAP_PRIVATE, fd, OFFSET);
+    close(fd);
+    if (MAP_FAILED == *buffer) {
+        QNN_ERROR("Failed to mmap file: %s err: %s", filePath.c_str(), strerror(errno));
+        return false;
+    }
+    if (madvise(*buffer, bufferSize, MADV_NOHUGEPAGE)) {
+        QNN_ERROR("Failed to advise OS on memory usage err: %s", strerror(errno));
+    }
+
+    return true;
+#else
+    return false;
+#endif
+}
+
+bool fillDims(std::vector& dims, uint32_t* inDimensions, uint32_t rank) {
+    if (nullptr == inDimensions) {
+        QNN_ERROR("input dimensions is nullptr");
+        return false;
+    }
+
+    if (rank < 1) {
+        QNN_ERROR("invalid rank : %d", rank);
+        return false;
+    }
+
+    // If rank is less than 4, left-pad the shape with 1s
+    // (looping from rank avoids unsigned underflow when rank > 4)
+    for (size_t r = rank; r < 4; r++) {
+        dims.push_back(1);
+    }
+
+    for (size_t r = 0; r < rank; r++) {
+        dims.push_back(inDimensions[r]);
+    }
+
+    return true;
+}
diff --git a/Genie/Genie/src/qualla/engines/qnn-api/QnnApiUtils.hpp b/Genie/Genie/src/qualla/engines/qnn-api/QnnApiUtils.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..71ccaf9610ce5707c00a627350fe38559372263f
--- /dev/null
+++ b/Genie/Genie/src/qualla/engines/qnn-api/QnnApiUtils.hpp
@@ -0,0 +1,94 @@
+//==============================================================================
+//
+//  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+//  All Rights Reserved.
+//  Confidential and Proprietary - Qualcomm Technologies, Inc.
+//
+//==============================================================================
+
+#include "QnnInterface.h"
+#include "QnnTypes.h"
+#include "System/QnnSystemInterface.h"
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "QnnTypeDef.hpp"
+#include "Log.hpp"
+
+/**
+ * @brief Frees all memory allocated for tensor attributes.
+ *
+ * @param[in] tensorWrapper tensor object to free
+ *
+ * @return Error code
+ */
+bool freeQnnTensorWrapper(TensorWrapper& tensorWrapper);
+
+/**
+ * @brief Loops through and frees all memory allocated for tensor attributes of each tensorWrapper
+ * object.
+ * + * @param[in] tensorWrappers array of tensor objects to free + * + * @param[in] numTensors length of the above tensorWrappers array + * + * @return Error code + */ +bool freeQnnTensorWrappers(TensorWrapper*& tensorWrappers, uint32_t numTensors); + +/** + * @brief A helper function to free memory malloced for communicating the Graph for a model(s) + * + * @param[in] graphsInfo Pointer pointing to location of graph objects + * + * @param[in] numGraphs The number of graph objects the above pointer is pointing to + * + * @return Error code + * + */ +bool freeGraphsInfo(GraphInfoPtr_t** graphsInfo, uint32_t numGraphs); + +bool freeGraphInfo(GraphInfo_t* graphInfo); + +bool copyMetadataToGraphsInfo( + const QnnSystemContext_BinaryInfo_t* binaryInfo, + GraphInfo_t**& graphsInfo, + uint32_t& graphsCount +); + +bool copyGraphsInfo( + const QnnSystemContext_GraphInfo_t* graphsInput, + const uint32_t numGraphs, + GraphInfo_t**& graphsInfo +); + +bool copyGraphsInfoV1( + const QnnSystemContext_GraphInfoV1_t* graphInfoSrc, + GraphInfo_t* graphInfoDst +); + +bool copyTensorsInfo( + const Qnn_Tensor_t* tensorsInfoSrc, + TensorWrapper*& tensorWrappers, + uint32_t tensorsCount +); + +bool fillDims(std::vector& dims, uint32_t* inDimensions, uint32_t rank); +size_t getFileSize(std::string filePath); +bool readBinaryFromFile(std::string filePath, void* buffer, size_t bufferSize); +bool mmapBinaryFile(std::string filePath, void** buffer, size_t bufferSize); +bool updateMetaDataToGraphsInfo(const QnnSystemContext_BinaryInfo_t* binaryInfo,GraphInfo_t** graphsInfo,uint32_t& graphsCount); +bool updateGraphInfo(const QnnSystemContext_GraphInfo_t* graphsInput, + const uint32_t currCount, + GraphInfo_t* graphsInfo); +bool updateGraphInfoV1(const QnnSystemContext_GraphInfoV1_t* graphInfoSrc, + GraphInfo_t* graphInfoDst); +bool updateTensorInfo(const Qnn_Tensor_t* tensorsInfoSrc, + TensorWrapper* tensorWrappers, + uint32_t tensorsCount); +uint32_t getNumGraphInBinary(const QnnSystemContext_BinaryInfo_t* binaryInfo); \ No newline at end of file diff --git a/Genie/Genie/src/qualla/engines/qnn-api/QnnConfig.hpp b/Genie/Genie/src/qualla/engines/qnn-api/QnnConfig.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1690c589197a00352fe29fa969b86c19a5839677 --- /dev/null +++ b/Genie/Genie/src/qualla/engines/qnn-api/QnnConfig.hpp @@ -0,0 +1,44 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. 
+// +//============================================================================== +#pragma once + +#include "QnnGraph.h" +#include "QnnTypes.h" +#include + +struct BackendExtensionsConfigs { + std::string sharedLibraryPath; + std::string configFilePath; + BackendExtensionsConfigs() : sharedLibraryPath(""), configFilePath("") {} + BackendExtensionsConfigs(std::string sharedLibraryPath, std::string configFilePath) + : sharedLibraryPath(sharedLibraryPath), configFilePath(configFilePath) {} +}; + +struct ContextConfigs { + bool priorityPresent; + Qnn_Priority_t priority; + ContextConfigs() : priorityPresent(false), priority(QNN_PRIORITY_UNDEFINED) {} + ContextConfigs(Qnn_Priority_t priority) : priorityPresent(true), priority(priority) {} +}; + +struct GraphConfigs { + std::string graphName; + bool priorityPresent; + Qnn_Priority_t priority; + GraphConfigs() + : graphName(), + priorityPresent(false), priority(QNN_PRIORITY_UNDEFINED) { + } +}; + +struct ConfigOptions { + BackendExtensionsConfigs backendExtensionsConfigs; + ContextConfigs contextConfigs; + std::vector graphConfigs; + ConfigOptions() : backendExtensionsConfigs(), contextConfigs(), graphConfigs() {} +}; diff --git a/Genie/Genie/src/qualla/engines/qnn-api/QnnTypeDef.hpp b/Genie/Genie/src/qualla/engines/qnn-api/QnnTypeDef.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a6ed57fc8bea91e957ee66f7d9918e6d1bb78cf3 --- /dev/null +++ b/Genie/Genie/src/qualla/engines/qnn-api/QnnTypeDef.hpp @@ -0,0 +1,52 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#ifndef QNN_TYPE_DEF_H_ +#define QNN_TYPE_DEF_H_ + +#include "QnnInterface.h" +#include "QnnTypes.h" +#include "Log.hpp" +#include "QnnTypeMacros.hpp" + +typedef enum ModelError { + MODEL_NO_ERROR = 0, + MODEL_TENSOR_ERROR = 1, + MODEL_PARAMS_ERROR = 2, + MODEL_NODES_ERROR = 3, + MODEL_GRAPH_ERROR = 4, + MODEL_CONTEXT_ERROR = 5, + MODEL_GENERATION_ERROR = 6, + MODEL_SETUP_ERROR = 7, + MODEL_INVALID_ARGUMENT_ERROR = 8, + MODEL_FILE_ERROR = 9, + MODEL_MEMORY_ALLOCATE_ERROR = 10, + // Value selected to ensure 32 bits. + MODEL_UNKNOWN_ERROR = 0x7FFFFFFF +} ModelError_t; + +using TensorWrapper = Qnn_Tensor_t; + #define GET_TENSOR_WRAPPER_TENSOR(tensorWrapper) tensorWrapper + #define GET_TENSOR_WRAPPER_NAME(tensorWrapper) QNN_TENSOR_GET_NAME(tensorWrapper) + +typedef struct GraphInfo { + Qnn_GraphHandle_t graph; + char* graphName; + TensorWrapper* inputTensors; + uint32_t numInputTensors; + TensorWrapper* outputTensors; + uint32_t numOutputTensors; +} GraphInfo_t; +typedef GraphInfo_t* GraphInfoPtr_t; + +typedef struct GraphConfigInfo { + char* graphName; + const QnnGraph_Config_t** graphConfigs; +} GraphConfigInfo_t; + +#endif // QNN_TYPE_DEF_H_ diff --git a/Genie/Genie/src/qualla/engines/qnn-api/QnnTypeMacros.hpp b/Genie/Genie/src/qualla/engines/qnn-api/QnnTypeMacros.hpp new file mode 100644 index 0000000000000000000000000000000000000000..cc0548e07031a1a1dae7010164018be3671a51f1 --- /dev/null +++ b/Genie/Genie/src/qualla/engines/qnn-api/QnnTypeMacros.hpp @@ -0,0 +1,702 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All Rights Reserved. 
+// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#pragma once + +#include "QnnTypes.h" + +#define QNN_OP_CFG_VALID(opConfig) ((opConfig).version == QNN_OPCONFIG_VERSION_1) + +inline Qnn_OpConfig_t createQnnOpConfig(const Qnn_OpConfigVersion_t version) { + Qnn_OpConfig_t opConfig = QNN_OPCONFIG_INIT; + opConfig.version = version; + if (version == QNN_OPCONFIG_VERSION_1) { + opConfig.v1 = QNN_OPCONFIG_V1_INIT; + } + return opConfig; +} + +inline const char* getQnnOpConfigName(const Qnn_OpConfig_t& opConfig) { + if (opConfig.version == QNN_OPCONFIG_VERSION_1) { + return opConfig.v1.name; + } + return NULL; +} + +inline const char* getQnnOpConfigName(const Qnn_OpConfig_t* const opConfig) { + return getQnnOpConfigName(*opConfig); +} + +inline const char* getQnnOpConfigPackageName(const Qnn_OpConfig_t& opConfig) { + if (opConfig.version == QNN_OPCONFIG_VERSION_1) { + return opConfig.v1.packageName; + } + return NULL; +} + +inline const char* getQnnOpConfigPackageName(const Qnn_OpConfig_t* const opConfig) { + return getQnnOpConfigPackageName(*opConfig); +} + +inline const char* getQnnOpConfigTypeName(const Qnn_OpConfig_t& opConfig) { + if (opConfig.version == QNN_OPCONFIG_VERSION_1) { + return opConfig.v1.typeName; + } + return NULL; +} + +inline const char* getQnnOpConfigTypeName(const Qnn_OpConfig_t* const opConfig) { + return getQnnOpConfigTypeName(*opConfig); +} + +inline uint32_t getQnnOpConfigNumParams(const Qnn_OpConfig_t& opConfig) { + if (opConfig.version == QNN_OPCONFIG_VERSION_1) { + return opConfig.v1.numOfParams; + } + return 0u; +} + +inline uint32_t getQnnOpConfigNumParams(const Qnn_OpConfig_t* const opConfig) { + return getQnnOpConfigNumParams(*opConfig); +} + +inline const Qnn_Param_t* getQnnOpConfigParams(const Qnn_OpConfig_t& opConfig) { + if (opConfig.version == QNN_OPCONFIG_VERSION_1) { + return opConfig.v1.params; + } + return NULL; +} + +inline const Qnn_Param_t* getQnnOpConfigParams(const Qnn_OpConfig_t* const opConfig) { + return getQnnOpConfigParams(*opConfig); +} + +inline uint32_t getQnnOpConfigNumInputs(const Qnn_OpConfig_t& opConfig) { + if (opConfig.version == QNN_OPCONFIG_VERSION_1) { + return opConfig.v1.numOfInputs; + } + return 0u; +} + +inline uint32_t getQnnOpConfigNumInputs(const Qnn_OpConfig_t* const opConfig) { + return getQnnOpConfigNumInputs(*opConfig); +} + +inline const Qnn_Tensor_t* getQnnOpConfigInputs(const Qnn_OpConfig_t& opConfig) { + if (opConfig.version == QNN_OPCONFIG_VERSION_1) { + return opConfig.v1.inputTensors; + } + return NULL; +} + +inline const Qnn_Tensor_t* getQnnOpConfigInputs(const Qnn_OpConfig_t* const opConfig) { + return getQnnOpConfigInputs(*opConfig); +} + +inline uint32_t getQnnOpConfigNumOutputs(const Qnn_OpConfig_t& opConfig) { + if (opConfig.version == QNN_OPCONFIG_VERSION_1) { + return opConfig.v1.numOfOutputs; + } + return 0u; +} + +inline uint32_t getQnnOpConfigNumOutputs(const Qnn_OpConfig_t* const opConfig) { + return getQnnOpConfigNumOutputs(*opConfig); +} + +inline const Qnn_Tensor_t* getQnnOpConfigOutputs(const Qnn_OpConfig_t& opConfig) { + if (opConfig.version == QNN_OPCONFIG_VERSION_1) { + return opConfig.v1.outputTensors; + } + return NULL; +} + +inline const Qnn_Tensor_t* getQnnOpConfigOutputs(const Qnn_OpConfig_t* const opConfig) { + return getQnnOpConfigOutputs(*opConfig); +} + +inline void setQnnOpConfigName(Qnn_OpConfig_t& opConfig, const char* const name) { + if (opConfig.version == QNN_OPCONFIG_VERSION_1) { 
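        // Editor's note (illustration, not original source): every setter in
        // this header checks the struct's version tag before touching the
        // versioned union member, so callers stay version-agnostic. Typical
        // construction through the macros defined at the end of this file:
        //
        //   Qnn_OpConfig_t op = QNN_OP_CFG_CREATE(QNN_OPCONFIG_VERSION_1);
        //   QNN_OP_CFG_SET_NAME(op, "matmul_0");     // pointers are not copied;
        //   QNN_OP_CFG_SET_TYPE_NAME(op, "MatMul");  // storage must outlive op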
+ opConfig.v1.name = name; + } +} + +inline void setQnnOpConfigName(Qnn_OpConfig_t* const opConfig, const char* const name) { + setQnnOpConfigName(*opConfig, name); +} + +inline void setQnnOpConfigPackageName(Qnn_OpConfig_t& opConfig, const char* const packageName) { + if (opConfig.version == QNN_OPCONFIG_VERSION_1) { + opConfig.v1.packageName = packageName; + } +} + +inline void setQnnOpConfigPackageName( + Qnn_OpConfig_t* const opConfig, + const char* const packageName +) { + setQnnOpConfigPackageName(*opConfig, packageName); +} + +inline void setQnnOpConfigTypeName(Qnn_OpConfig_t& opConfig, const char* const typeName) { + if (opConfig.version == QNN_OPCONFIG_VERSION_1) { + opConfig.v1.typeName = typeName; + } +} + +inline void setQnnOpConfigTypeName(Qnn_OpConfig_t* const opConfig, const char* const typeName) { + setQnnOpConfigTypeName(*opConfig, typeName); +} + +inline void setQnnOpConfigParams( + Qnn_OpConfig_t& opConfig, + uint32_t const numOfParams, + Qnn_Param_t* const params +) { + if (opConfig.version == QNN_OPCONFIG_VERSION_1) { + opConfig.v1.numOfParams = numOfParams; + opConfig.v1.params = params; + } +} + +inline void setQnnOpConfigParams( + Qnn_OpConfig_t* const opConfig, + uint32_t const numOfParams, + Qnn_Param_t* const params +) { + setQnnOpConfigParams(*opConfig, numOfParams, params); +} + +inline void setQnnOpConfigInputs( + Qnn_OpConfig_t& opConfig, + uint32_t const numOfInputs, + Qnn_Tensor_t* const inputTensors +) { + if (opConfig.version == QNN_OPCONFIG_VERSION_1) { + opConfig.v1.numOfInputs = numOfInputs; + opConfig.v1.inputTensors = inputTensors; + } +} + +inline void setQnnOpConfigInputs( + Qnn_OpConfig_t* const opConfig, + uint32_t const numOfInputs, + Qnn_Tensor_t* const inputTensors +) { + setQnnOpConfigInputs(*opConfig, numOfInputs, inputTensors); +} + +inline void setQnnOpConfigOutputs( + Qnn_OpConfig_t& opConfig, + uint32_t const numOfOutputs, + Qnn_Tensor_t* const outputTensors +) { + if (opConfig.version == QNN_OPCONFIG_VERSION_1) { + opConfig.v1.numOfOutputs = numOfOutputs; + opConfig.v1.outputTensors = outputTensors; + } +} + +inline void setQnnOpConfigOutputs( + Qnn_OpConfig_t* const opConfig, + uint32_t const numOfOutputs, + Qnn_Tensor_t* const outputTensors +) { + setQnnOpConfigOutputs(*opConfig, numOfOutputs, outputTensors); +} + +inline Qnn_Tensor_t createQnnTensor(const Qnn_TensorVersion_t version) { + Qnn_Tensor_t tensor = QNN_TENSOR_INIT; + tensor.version = version; + if (version == QNN_TENSOR_VERSION_1) { + tensor.v1 = QNN_TENSOR_V1_INIT; + } else if (version == QNN_TENSOR_VERSION_2) { + tensor.v2 = QNN_TENSOR_V2_INIT; + } + return tensor; +} + +inline uint32_t getQnnTensorId(const Qnn_Tensor_t& tensor) { + // TensorCompatTest justifies no need to check version + return tensor.v1.id; +} + +inline uint32_t getQnnTensorId(const Qnn_Tensor_t* const tensor) { + return getQnnTensorId(*tensor); +} + +inline const char* getQnnTensorName(const Qnn_Tensor_t& tensor) { + // TensorCompatTest justifies no need to check version + return tensor.v1.name; +} + +inline const char* getQnnTensorName(const Qnn_Tensor_t* const tensor) { + return getQnnTensorName(*tensor); +} + +inline Qnn_TensorType_t getQnnTensorType(const Qnn_Tensor_t& tensor) { + // TensorCompatTest justifies no need to check version + return tensor.v1.type; +} + +inline Qnn_TensorType_t getQnnTensorType(const Qnn_Tensor_t* const tensor) { + return getQnnTensorType(*tensor); +} + +inline Qnn_TensorDataFormat_t getQnnTensorDataFormat(const Qnn_Tensor_t& tensor) { + // TensorCompatTest justifies no 
need to check version + return tensor.v1.dataFormat; +} + +inline Qnn_TensorDataFormat_t getQnnTensorDataFormat(const Qnn_Tensor_t* const tensor) { + return getQnnTensorDataFormat(*tensor); +} + +inline Qnn_DataType_t getQnnTensorDataType(const Qnn_Tensor_t& tensor) { + // TensorCompatTest justifies no need to check version + return tensor.v1.dataType; +} + +inline Qnn_DataType_t getQnnTensorDataType(const Qnn_Tensor_t* const tensor) { + return getQnnTensorDataType(*tensor); +} + +inline Qnn_QuantizeParams_t getQnnTensorQuantParams(const Qnn_Tensor_t& tensor) { + // TensorCompatTest justifies no need to check version + return tensor.v1.quantizeParams; +} + +inline Qnn_QuantizeParams_t getQnnTensorQuantParams(const Qnn_Tensor_t* const tensor) { + if (tensor != nullptr) { + return getQnnTensorQuantParams(*tensor); + } + return QNN_QUANTIZE_PARAMS_INIT; +} + +inline uint32_t getQnnTensorRank(const Qnn_Tensor_t& tensor) { + // TensorCompatTest justifies no need to check version + return tensor.v1.rank; +} + +inline uint32_t getQnnTensorRank(const Qnn_Tensor_t* const tensor) { + if (tensor != nullptr) { + return getQnnTensorRank(*tensor); + } + return 0u; +} + +inline uint32_t* getQnnTensorDimensions(const Qnn_Tensor_t& tensor) { + // TensorCompatTest justifies no need to check version + return tensor.v1.dimensions; +} + +inline uint32_t* getQnnTensorDimensions(const Qnn_Tensor_t* const tensor) { + return getQnnTensorDimensions(*tensor); +} + +inline uint8_t* getQnnTensorIsDynamicDimensions(const Qnn_Tensor_t& tensor) { + if (tensor.version == QNN_TENSOR_VERSION_1) { + return NULL; + } else if (tensor.version == QNN_TENSOR_VERSION_2) { + return tensor.v2.isDynamicDimensions; + } + return NULL; +} + +inline uint8_t* getQnnTensorIsDynamicDimensions(const Qnn_Tensor_t* tensor) { + return getQnnTensorIsDynamicDimensions(*tensor); +} + +inline Qnn_SparseParams_t getQnnTensorSparseParams(const Qnn_Tensor_t& tensor) { + if (tensor.version == QNN_TENSOR_VERSION_1) { + return QNN_SPARSE_PARAMS_INIT; + } else if (tensor.version == QNN_TENSOR_VERSION_2) { + return tensor.v2.sparseParams; + } + return QNN_SPARSE_PARAMS_INIT; +} + +inline Qnn_SparseParams_t getQnnTensorSparseParams(const Qnn_Tensor_t* tensor) { + return getQnnTensorSparseParams(*tensor); +} + +inline Qnn_TensorMemType_t getQnnTensorMemType(const Qnn_Tensor_t& tensor) { + // TensorCompatTest justifies no need to check version + return tensor.v1.memType; +} + +inline Qnn_TensorMemType_t getQnnTensorMemType(const Qnn_Tensor_t* const tensor) { + return getQnnTensorMemType(*tensor); +} + +inline Qnn_ClientBuffer_t getQnnTensorClientBuf(const Qnn_Tensor_t& tensor) { + // TensorCompatTest justifies no need to check version + return tensor.v1.clientBuf; +} + +inline Qnn_ClientBuffer_t getQnnTensorClientBuf(const Qnn_Tensor_t* const tensor) { + return getQnnTensorClientBuf(*tensor); +} + +inline Qnn_MemHandle_t getQnnTensorMemHandle(const Qnn_Tensor_t& tensor) { + // TensorCompatTest justifies no need to check version + return tensor.v1.memHandle; +} + +inline Qnn_MemHandle_t getQnnTensorMemHandle(const Qnn_Tensor_t* const tensor) { + return getQnnTensorMemHandle(*tensor); +} + +inline void setQnnTensorId(Qnn_Tensor_t& tensor, const uint32_t id) { + // TensorCompatTest justifies no need to check version + tensor.v1.id = id; +} + +inline void setQnnTensorId(Qnn_Tensor_t* const tensor, const uint32_t id) { + setQnnTensorId(*tensor, id); +} + +inline void setQnnTensorName(Qnn_Tensor_t& tensor, const char* const name) { + // TensorCompatTest justifies 
no need to check version + tensor.v1.name = name; +} + +inline void setQnnTensorName(Qnn_Tensor_t* const tensor, const char* const name) { + setQnnTensorName(*tensor, name); +} + +inline void setQnnTensorType(Qnn_Tensor_t& tensor, const Qnn_TensorType_t type) { + // TensorCompatTest justifies no need to check version + tensor.v1.type = type; +} + +inline void setQnnTensorType(Qnn_Tensor_t* const tensor, const Qnn_TensorType_t type) { + setQnnTensorType(*tensor, type); +} + +inline void setQnnTensorDataFormat(Qnn_Tensor_t& tensor, const Qnn_TensorDataFormat_t dataFormat) { + // TensorCompatTest justifies no need to check version + tensor.v1.dataFormat = dataFormat; +} + +inline void setQnnTensorDataFormat( + Qnn_Tensor_t* const tensor, + const Qnn_TensorDataFormat_t format +) { + setQnnTensorDataFormat(*tensor, format); +} + +inline void setQnnTensorDataType(Qnn_Tensor_t& tensor, const Qnn_DataType_t dataType) { + // TensorCompatTest justifies no need to check version + tensor.v1.dataType = dataType; +} + +inline void setQnnTensorDataType(Qnn_Tensor_t* const tensor, const Qnn_DataType_t dataType) { + setQnnTensorDataType(*tensor, dataType); +} + +inline void setQnnTensorQuantParams( + Qnn_Tensor_t& tensor, + const Qnn_QuantizeParams_t quantizeParams +) { + // TensorCompatTest justifies no need to check version + tensor.v1.quantizeParams = quantizeParams; +} + +inline void setQnnTensorQuantParams(Qnn_Tensor_t* const tensor, const Qnn_QuantizeParams_t params) { + setQnnTensorQuantParams(*tensor, params); +} + +inline void setQnnTensorRank(Qnn_Tensor_t& tensor, const uint32_t rank) { + // TensorCompatTest justifies no need to check version + tensor.v1.rank = rank; +} + +inline void setQnnTensorRank(Qnn_Tensor_t* const tensor, const uint32_t rank) { + setQnnTensorRank(*tensor, rank); +} + +inline void setQnnTensorDimensions(Qnn_Tensor_t& tensor, uint32_t* const dimensions) { + // TensorCompatTest justifies no need to check version + tensor.v1.dimensions = dimensions; +} + +inline void setQnnTensorDimensions(Qnn_Tensor_t* const tensor, uint32_t* const dimensions) { + setQnnTensorDimensions(*tensor, dimensions); +} + +inline void setQnnTensorIsDynamicDimensions( + Qnn_Tensor_t& tensor, + uint8_t* const isDynamicDimensions +) { + if (tensor.version == QNN_TENSOR_VERSION_2) { + tensor.v2.isDynamicDimensions = isDynamicDimensions; + } +} + +inline void setQnnTensorIsDynamicDimensions( + Qnn_Tensor_t* tensor, + uint8_t* const isDynamicDimensions +) { + setQnnTensorIsDynamicDimensions(*tensor, isDynamicDimensions); +} + +inline void setQnnTensorSparseParams(Qnn_Tensor_t& tensor, const Qnn_SparseParams_t sparseParams) { + if (tensor.version == QNN_TENSOR_VERSION_2) { + tensor.v2.sparseParams = sparseParams; + } +} + +inline void setQnnTensorSparseParams(Qnn_Tensor_t* tensor, Qnn_SparseParams_t sparseParams) { + setQnnTensorSparseParams(*tensor, sparseParams); +} + +inline void setQnnTensorMemType(Qnn_Tensor_t& tensor, const Qnn_TensorMemType_t memType) { + // TensorCompatTest justifies no need to check version + tensor.v1.memType = memType; +} + +inline void setQnnTensorMemType(Qnn_Tensor_t* const tensor, const Qnn_TensorMemType_t memType) { + setQnnTensorMemType(*tensor, memType); +} + +inline void setQnnTensorClientBuf(Qnn_Tensor_t& tensor, const Qnn_ClientBuffer_t clientBuf) { + // TensorCompatTest justifies no need to check version + tensor.v1.clientBuf = clientBuf; +} + +inline void setQnnTensorClientBuf(Qnn_Tensor_t* const tensor, const Qnn_ClientBuffer_t clientBuf) { + 
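    // Editor's note (illustration, not original source): a Qnn_Tensor_t
    // carries either a client buffer or a registered memory handle, selected
    // by memType; keep the two in sync when wiring IO, e.g.:
    //
    //   Qnn_ClientBuffer_t buf{};
    //   buf.data = ioBuffer;          // hypothetical host allocation
    //   buf.dataSize = ioBufferSize;  // hypothetical size in bytes
    //   QNN_TENSOR_SET_MEM_TYPE(tensor, QNN_TENSORMEMTYPE_RAW);
    //   QNN_TENSOR_SET_CLIENT_BUF(tensor, buf);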
setQnnTensorClientBuf(*tensor, clientBuf); +} + +inline void setQnnTensorMemHandle(Qnn_Tensor_t& tensor, const Qnn_MemHandle_t memHandle) { + // TensorCompatTest justifies no need to check version + tensor.v1.memHandle = memHandle; +} + +inline void setQnnTensorMemHandle(Qnn_Tensor_t* const tensor, const Qnn_MemHandle_t handle) { + setQnnTensorMemHandle(*tensor, handle); +} + +inline Qnn_TensorSet_t createQnnTensorSet(const Qnn_TensorSetVersion_t version) { + Qnn_TensorSet_t tensorSet = QNN_TENSOR_SET_INIT; + tensorSet.version = version; + if (version == QNN_TENSOR_SET_VERSION_1) { + tensorSet.v1 = QNN_TENSOR_SET_V1_INIT; + } + return tensorSet; +} + +inline uint32_t getQnnTensorSetNumInputs(const Qnn_TensorSet_t& tensorSet) { + if (tensorSet.version == QNN_TENSOR_SET_VERSION_1) { + return tensorSet.v1.numInputs; + } + return 0; +} + +inline uint32_t getQnnTensorSetNumInputs(const Qnn_TensorSet_t* tensorSet) { + return getQnnTensorSetNumInputs(*tensorSet); +} + +inline Qnn_Tensor_t* getQnnTensorSetInputTensors(const Qnn_TensorSet_t& tensorSet) { + if (tensorSet.version == QNN_TENSOR_SET_VERSION_1) { + return tensorSet.v1.inputs; + } + return 0; +} + +inline Qnn_Tensor_t* getQnnTensorSetInputTensors(const Qnn_TensorSet_t* tensorSet) { + return getQnnTensorSetInputTensors(*tensorSet); +} + +inline uint32_t getQnnTensorSetNumOutputs(const Qnn_TensorSet_t& tensorSet) { + if (tensorSet.version == QNN_TENSOR_SET_VERSION_1) { + return tensorSet.v1.numOutputs; + } + return 0; +} + +inline uint32_t getQnnTensorSetNumOutputs(const Qnn_TensorSet_t* tensorSet) { + return getQnnTensorSetNumOutputs(*tensorSet); +} + +inline Qnn_Tensor_t* getQnnTensorSetOutputTensors(const Qnn_TensorSet_t& tensorSet) { + if (tensorSet.version == QNN_TENSOR_SET_VERSION_1) { + return tensorSet.v1.outputs; + } + return 0; +} + +inline Qnn_Tensor_t* getQnnTensorSetOutputTensors(const Qnn_TensorSet_t* tensorSet) { + return getQnnTensorSetOutputTensors(*tensorSet); +} + +inline void setQnnTensorSetInputTensors( + Qnn_TensorSet_t& tensorSet, + Qnn_Tensor_t* inputTensors, + uint32_t const numInputs +) { + if (tensorSet.version == QNN_TENSOR_SET_VERSION_1) { + tensorSet.v1.inputs = inputTensors; + tensorSet.v1.numInputs = numInputs; + } +} + +inline void setQnnTensorSetInputTensors( + Qnn_TensorSet_t* tensorSet, + Qnn_Tensor_t* inputTensors, + uint32_t const numInputs +) { + setQnnTensorSetInputTensors(*tensorSet, inputTensors, numInputs); +} + +inline void setQnnTensorSetOutputTensors( + Qnn_TensorSet_t& tensorSet, + Qnn_Tensor_t* outputTensors, + const uint32_t numOutputs +) { + if (tensorSet.version == QNN_TENSOR_SET_VERSION_1) { + tensorSet.v1.outputs = outputTensors; + tensorSet.v1.numOutputs = numOutputs; + } +} + +inline void setQnnTensorSetOutputTensors( + Qnn_TensorSet_t* tensorSet, + Qnn_Tensor_t* outputTensors, + const uint32_t numOutputs +) { + setQnnTensorSetOutputTensors(*tensorSet, outputTensors, numOutputs); +} + +// Creator for QNN Op Config +#define QNN_OP_CFG_CREATE(version) createQnnOpConfig(version) + +// Accessors for QNN Op Config +#define QNN_OP_CFG_GET_NAME(opConfig) getQnnOpConfigName(opConfig) +#define QNN_OP_CFG_GET_PACKAGE_NAME(opConfig) getQnnOpConfigPackageName(opConfig) +#define QNN_OP_CFG_GET_TYPE_NAME(opConfig) getQnnOpConfigTypeName(opConfig) +#define QNN_OP_CFG_GET_NUM_PARAMS(opConfig) getQnnOpConfigNumParams(opConfig) +#define QNN_OP_CFG_GET_PARAMS(opConfig) getQnnOpConfigParams(opConfig) +#define QNN_OP_CFG_GET_NUM_INPUTS(opConfig) getQnnOpConfigNumInputs(opConfig) +#define 
QNN_OP_CFG_GET_INPUTS(opConfig) getQnnOpConfigInputs(opConfig) +#define QNN_OP_CFG_GET_NUM_OUTPUTS(opConfig) getQnnOpConfigNumOutputs(opConfig) +#define QNN_OP_CFG_GET_OUTPUTS(opConfig) getQnnOpConfigOutputs(opConfig) + +// Modifiers for QNN Op Config +#define QNN_OP_CFG_SET_NAME(opConfig, value) setQnnOpConfigName(opConfig, value) +#define QNN_OP_CFG_SET_PACKAGE_NAME(opConfig, value) setQnnOpConfigPackageName(opConfig, value) +#define QNN_OP_CFG_SET_TYPE_NAME(opConfig, value) setQnnOpConfigTypeName(opConfig, value) +#define QNN_OP_CFG_SET_PARAMS(opConfig, numOfParams, params) \ + setQnnOpConfigParams(opConfig, numOfParams, params) +#define QNN_OP_CFG_SET_INPUTS(opConfig, numOfInputs, inputTensors) \ + setQnnOpConfigInputs(opConfig, numOfInputs, inputTensors) +#define QNN_OP_CFG_SET_OUTPUTS(opConfig, numOfOutputs, outputTensors) \ + setQnnOpConfigOutputs(opConfig, numOfOutputs, outputTensors) + +// Creator for QNN Tensor +#define QNN_TENSOR_CREATE(version) createQnnTensor(version) + +// Accessors for QNN Tensor +#define QNN_TENSOR_GET_ID(tensor) getQnnTensorId(tensor) +#define QNN_TENSOR_GET_NAME(tensor) getQnnTensorName(tensor) +#define QNN_TENSOR_GET_TYPE(tensor) getQnnTensorType(tensor) +#define QNN_TENSOR_GET_DATA_FORMAT(tensor) getQnnTensorDataFormat(tensor) +#define QNN_TENSOR_GET_DATA_TYPE(tensor) getQnnTensorDataType(tensor) +#define QNN_TENSOR_GET_QUANT_PARAMS(tensor) getQnnTensorQuantParams(tensor) +#define QNN_TENSOR_GET_RANK(tensor) getQnnTensorRank(tensor) +#define QNN_TENSOR_GET_DIMENSIONS(tensor) getQnnTensorDimensions(tensor) +#define QNN_TENSOR_GET_IS_DYNAMIC_DIMENSIONS(tensor) getQnnTensorIsDynamicDimensions(tensor) +#define QNN_TENSOR_GET_SPARSE_PARAMS(tensor) getQnnTensorSparseParams(tensor) +#define QNN_TENSOR_GET_MEM_TYPE(tensor) getQnnTensorMemType(tensor) +#define QNN_TENSOR_GET_CLIENT_BUF(tensor) getQnnTensorClientBuf(tensor) +#define QNN_TENSOR_GET_MEM_HANDLE(tensor) getQnnTensorMemHandle(tensor) + +// Modifiers for QNN Tensor +#define QNN_TENSOR_SET_ID(tensor, value) setQnnTensorId(tensor, value) +#define QNN_TENSOR_SET_NAME(tensor, value) setQnnTensorName(tensor, value) +#define QNN_TENSOR_SET_TYPE(tensor, value) setQnnTensorType(tensor, value) +#define QNN_TENSOR_SET_DATA_FORMAT(tensor, value) setQnnTensorDataFormat(tensor, value) +#define QNN_TENSOR_SET_DATA_TYPE(tensor, value) setQnnTensorDataType(tensor, value) +#define QNN_TENSOR_SET_QUANT_PARAMS(tensor, value) setQnnTensorQuantParams(tensor, value) +#define QNN_TENSOR_SET_RANK(tensor, value) setQnnTensorRank(tensor, value) +#define QNN_TENSOR_SET_DIMENSIONS(tensor, value) setQnnTensorDimensions(tensor, value) +#define QNN_TENSOR_SET_IS_DYNAMIC_DIMENSIONS(tensor, value) \ + setQnnTensorIsDynamicDimensions(tensor, value) +#define QNN_TENSOR_SET_SPARSE_PARAMS(tensor, value) setQnnTensorSparseParams(tensor, value) +#define QNN_TENSOR_SET_MEM_TYPE(tensor, value) setQnnTensorMemType(tensor, value) +#define QNN_TENSOR_SET_CLIENT_BUF(tensor, value) setQnnTensorClientBuf(tensor, value) +#define QNN_TENSOR_SET_MEM_HANDLE(tensor, value) setQnnTensorMemHandle(tensor, value) + +// Creator for QNN Tensor Set +#define QNN_TENSORSET_CREATE(version) createQnnTensorSet(version) + +// Accessors for QNN Tensor Set +#define QNN_TENSORSET_GET_NUM_INPUTS(tensorSet) getQnnTensorSetNumInputs(tensorSet) +#define QNN_TENSORSET_GET_INPUT_TENSORS(tensorSet) getQnnTensorSetInputTensors(tensorSet) +#define QNN_TENSORSET_GET_NUM_OUTPUTS(tensorSet) getQnnTensorSetNumOutputs(tensorSet) +#define QNN_TENSORSET_GET_OUTPUT_TENSORS(tensorSet) 
getQnnTensorSetOutputTensors(tensorSet) + +// Modifiers for QNN Tensor Set +#define QNN_TENSORSET_SET_INPUT_TENSORS(tensorSet, inputTensors, numInputs) \ + setQnnTensorSetInputTensors(tensorSet, inputTensors, numInputs) +#define QNN_TENSORSET_SET_OUTPUT_TENSORS(tensorSet, outputTensors, numOutputs) \ + setQnnTensorSetOutputTensors(tensorSet, outputTensors, numOutputs) + +inline bool isQnnTensorV1Compatible(const Qnn_Tensor_t& tensor) { + if (tensor.version == QNN_TENSOR_VERSION_2) { + if (tensor.v2.isDynamicDimensions != NULL) { + return false; + } + + if (tensor.v2.dataFormat == QNN_TENSOR_DATA_FORMAT_SPARSE) { + return false; + } + } + + return true; +} + +inline bool isQnnTensorV1Compatible(const Qnn_Tensor_t* const tensor) { + return isQnnTensorV1Compatible(*tensor); +} + +inline bool isQnnTensorV1Compatible(const Qnn_OpConfig_t& opConfig) { + if ((QNN_OP_CFG_GET_INPUTS(opConfig) != NULL) && (QNN_OP_CFG_GET_NUM_INPUTS(opConfig) > 0u)) { + for (uint32_t tensorIdx = 0u; tensorIdx < QNN_OP_CFG_GET_NUM_INPUTS(opConfig); + tensorIdx++) { + if (!isQnnTensorV1Compatible(QNN_OP_CFG_GET_INPUTS(opConfig)[tensorIdx])) { + return false; + } + } + } + if ((QNN_OP_CFG_GET_OUTPUTS(opConfig) != NULL) && (QNN_OP_CFG_GET_NUM_OUTPUTS(opConfig) > 0u)) { + for (uint32_t tensorIdx = 0u; tensorIdx < QNN_OP_CFG_GET_NUM_OUTPUTS(opConfig); + tensorIdx++) { + if (!isQnnTensorV1Compatible(QNN_OP_CFG_GET_OUTPUTS(opConfig)[tensorIdx])) { + return false; + } + } + } + if ((QNN_OP_CFG_GET_PARAMS(opConfig) != NULL) && (QNN_OP_CFG_GET_NUM_PARAMS(opConfig) > 0)) { + for (uint32_t paramIdx = 0u; paramIdx < QNN_OP_CFG_GET_NUM_PARAMS(opConfig); paramIdx++) { + const Qnn_Param_t& param = QNN_OP_CFG_GET_PARAMS(opConfig)[paramIdx]; + if (QNN_PARAMTYPE_TENSOR == param.paramType) { + if (!isQnnTensorV1Compatible(param.tensorParam)) { + return false; + } + } + } + } + + return true; +} + +inline bool isQnnTensorV1Compatible(const Qnn_OpConfig_t* const opConfig) { + return isQnnTensorV1Compatible(*opConfig); +} diff --git a/Genie/Genie/src/qualla/engines/qnn-api/RpcMem.cpp b/Genie/Genie/src/qualla/engines/qnn-api/RpcMem.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b167620cb80f9b004b0c63236ff689cb88708e36 --- /dev/null +++ b/Genie/Genie/src/qualla/engines/qnn-api/RpcMem.cpp @@ -0,0 +1,481 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#include "QnnMem.h" +#include "QnnHtpMem.h" +#include "RpcMem.hpp" +#include "QnnTypeMacros.hpp" +#include "dlwrap.hpp" + +#define RPCMEM_HEAP_ID_SYSTEM 25 +#define RPCMEM_DEFAULT_FLAGS 1 + +#if 1 + #define TRACE_MEMORY_ALLOC QNN_DEBUG +#else + #define TRACE_MEMORY_ALLOC(fmt, ...) +#endif + +RpcMem::RpcMem(Qnn_ContextHandle_t contextHandle, QNN_INTERFACE_VER_TYPE* qnnInterface) + : m_libCdspRpc(nullptr), m_rpcMemAlloc(nullptr), m_rpcMemFree(nullptr), m_rpcMemToFd(nullptr), + m_qnnInterface(qnnInterface), m_contextHandle(contextHandle) { + (void)m_contextHandle; +} + +bool RpcMem::initialize() { + // On Android, 32-bit and 64-bit libcdsprpc.so can be found at /vendor/lib and /vendor/lib64 respectively. 
+ // On Windows, it's installed into something like this + // c:\Windows\System32\DriverStore\FileRepository\qcnspmcdm8380.inf_arm64_30b9cc995571de6a\libcdsprpc.dll +#ifdef _WIN32 + const char* dsprpc_so = "libcdsprpc.dll"; +#else + const char* dsprpc_so = "libcdsprpc.so"; +#endif + + m_libCdspRpc = dlopen(dsprpc_so, RTLD_NOW | RTLD_LOCAL); + if (nullptr == m_libCdspRpc) { + QNN_ERROR("Unable to load backend. dlerror(): %s", dlerror()); + return false; + } + m_rpcMemAlloc = (RpcMemAllocFn_t)dlsym(m_libCdspRpc, "rpcmem_alloc"); + m_rpcMemFree = (RpcMemFreeFn_t)dlsym(m_libCdspRpc, "rpcmem_free"); + m_rpcMemToFd = (RpcMemToFdFn_t)dlsym(m_libCdspRpc, "rpcmem_to_fd"); + if (nullptr == m_rpcMemAlloc || nullptr == m_rpcMemFree || nullptr == m_rpcMemToFd) { + QNN_ERROR("Unable to access symbols in libcdsprpc. dlerror(): %s", dlerror()); + return false; + } + + return true; +} + +RpcMem::~RpcMem() { + if (m_libCdspRpc) { + QNN_DEBUG("Closing libcdsprpc.so handle"); + dlclose(m_libCdspRpc); + } +} + +RpcMemTensorData* RpcMem::getRpcMemTensorData(Qnn_Tensor_t* tensor) { + if (tensor == nullptr) return nullptr; + Qnn_MemHandle_t mem_handle = QNN_TENSOR_GET_MEM_HANDLE(tensor); + if (mem_handle == nullptr) return nullptr; + return &m_memHandleToRpcMem.at(mem_handle); +} + +void* RpcMem::getBuffer(Qnn_Tensor_t* tensor) { + RpcMemTensorData* data = getRpcMemTensorData(tensor); + if (data == nullptr) { + QNN_ERROR("getBuffer : Couldn't find tensor %p", tensor); + return nullptr; + } + return data->memPointer; +} + +int RpcMem::getFd(Qnn_Tensor_t* tensor) { + RpcMemTensorData* data = getRpcMemTensorData(tensor); + if (data == nullptr) { + QNN_ERROR("getFd : Couldn't find tensor %p", tensor); + return -1; + } + return data->fd; +} + +size_t RpcMem::getOffset(Qnn_Tensor_t* tensor) { + RpcMemTensorData* data = getRpcMemTensorData(tensor); + if (data == nullptr) { + QNN_ERROR("getOffset : Couldn't find tensor %p", tensor); + return 0; + } + return data->offset; +} + +size_t RpcMem::getBufferSize(Qnn_Tensor_t* tensor) { + RpcMemTensorData* data = getRpcMemTensorData(tensor); + if (data == nullptr) { + QNN_ERROR("getBufferSize : Couldn't find tensor %p", tensor); + return 0; + } + return data->size; +}; + +size_t RpcMem::getTotalBufferSize(Qnn_Tensor_t* tensor) { + RpcMemTensorData* data = getRpcMemTensorData(tensor); + if (data == nullptr) { + QNN_ERROR("getTotalBufferSize : Couldn't find tensor %p", tensor); + return 0; + } + return data->totalBufferSize; +} + +bool RpcMem::allocateTensorBuffer(Qnn_Tensor_t* tensor, size_t tensorDataSize) { + if (m_libCdspRpc == nullptr) { + QNN_ERROR("RpcMem not initialized"); + return false; + } + if (!tensor) { + QNN_ERROR("Received nullptr for tensor"); + return false; + } + if (m_tensorToRpcMem.find(tensor) != m_tensorToRpcMem.end()) { + QNN_ERROR("Tensor already allocated"); + return false; + } + + auto memPointer = m_rpcMemAlloc(RPCMEM_HEAP_ID_SYSTEM, RPCMEM_DEFAULT_FLAGS, tensorDataSize); + auto status = true; + if (!memPointer) { + QNN_ERROR("rpcmem_alloc failure"); + status = false; + } + int memfd = -1; + if (status == true) { + memfd = m_rpcMemToFd(memPointer); + if (memfd == -1) { + QNN_ERROR("rpcmem_to_fd failure"); + status = false; + } + } + if (status == true) { + Qnn_MemDescriptor_t memDescriptor = { + {QNN_TENSOR_GET_RANK(tensor), QNN_TENSOR_GET_DIMENSIONS(tensor), nullptr}, + QNN_TENSOR_GET_DATA_TYPE(tensor), + QNN_MEM_TYPE_ION, + {{-1}} + }; + memDescriptor.ionInfo.fd = memfd; + QNN_TENSOR_SET_MEM_TYPE(tensor, QNN_TENSORMEMTYPE_MEMHANDLE); + 
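+        // The mem handle is seeded as null here; on success, memRegister below fills it
+        // in with a backend-owned handle that is stored back on the tensor and later
+        // released via memDeRegister in freeTensorBuffer.
+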
QNN_TENSOR_SET_MEM_HANDLE(tensor, nullptr); + + Qnn_MemHandle_t memHandle = QNN_TENSOR_GET_MEM_HANDLE(tensor); + if (QNN_SUCCESS != m_qnnInterface->memRegister( + m_contextHandle, + &memDescriptor, + 1, + &(memHandle) + )) { + const char* tname = QNN_TENSOR_GET_NAME(tensor); + QNN_ERROR("memRegister fail %s (ctx=%p fd=%d)", tname, m_contextHandle, memfd); + status = false; + } + QNN_TENSOR_SET_MEM_HANDLE(tensor, memHandle); + } + if (status == true) { + m_tensorToRpcMem.insert({tensor, RpcMemTensorData(memfd, memPointer, tensorDataSize)}); + } + if (status == false) { + if (m_rpcMemFree) { + m_rpcMemFree(memPointer); + } + } + return status; +} + +bool RpcMem::freeTensorBuffer(Qnn_Tensor_t* tensor) { + if (!tensor) { + QNN_ERROR("Received nullptr for tensor"); + return false; + } + + if (m_sameMemoryFreeTensors.find(tensor) != m_sameMemoryFreeTensors.end()) { + if (m_tensorToRpcMem.find(tensor) == m_tensorToRpcMem.end()) { + QNN_ERROR("Tensor not found"); + return false; + } + m_tensorToRpcMem.erase(tensor); + } else { + auto memHandle = QNN_TENSOR_GET_MEM_HANDLE(tensor); + if (QNN_SUCCESS != m_qnnInterface->memDeRegister(&memHandle, 1)) { + QNN_ERROR("Failed to deregister ion memory with the backend"); + return false; + } + QNN_TENSOR_SET_MEM_TYPE(tensor, QNN_TENSORMEMTYPE_UNDEFINED); + if (m_tensorToRpcMem.find(tensor) == m_tensorToRpcMem.end()) { + QNN_ERROR("Tensor not found"); + return false; + } + if (m_rpcMemFree) { + m_rpcMemFree(m_tensorToRpcMem[tensor].memPointer); + } + m_tensorToRpcMem.erase(tensor); + } + + return true; +} + +bool RpcMem::useSameMemory(Qnn_Tensor_t* dest, Qnn_Tensor_t* src) { + if (nullptr == dest || nullptr == src) { + QNN_ERROR("Received nullptr"); + return false; + } + if (m_tensorToRpcMem.find(src) == m_tensorToRpcMem.end()) { + QNN_ERROR("Src Tensor not found"); + return false; + } + + if (false == freeTensorBuffer(dest)) { + return false; + } + + QNN_TENSOR_SET_MEM_TYPE(dest, QNN_TENSOR_GET_MEM_TYPE(src)); + QNN_TENSOR_SET_MEM_HANDLE(dest, QNN_TENSOR_GET_MEM_HANDLE(src)); + m_tensorToRpcMem.insert({dest, m_tensorToRpcMem[src]}); + m_sameMemoryFreeTensors.insert(dest); + + return true; +} + +bool RpcMem::useSameMemory(Qnn_Tensor_t* dest, Qnn_Tensor_t* src, int offset) { + if (nullptr == dest || nullptr == src) { + QNN_ERROR("Received nullptr"); + return false; + } + if (m_tensorToRpcMem.find(src) == m_tensorToRpcMem.end()) { + QNN_ERROR("Src Tensor not found"); + return false; + } + + if (false == freeTensorBuffer(dest)) { + return false; + } + + QNN_TENSOR_SET_MEM_TYPE(dest, QNN_TENSOR_GET_MEM_TYPE(src)); + QNN_TENSOR_SET_MEM_HANDLE(dest, QNN_TENSOR_GET_MEM_HANDLE(src)); + m_tensorToRpcMem.insert({dest, m_tensorToRpcMem[src]}); + m_sameMemoryFreeTensors.insert(dest); + + return true; +} + +bool RpcMem::useExternalMemory(Qnn_Tensor_t* dest, void* extMem) { + QNN_ERROR("We don't support external memory feature for shared buffers yet!"); + return false; +} + +void* RpcMem::allocateTensorFusedBuffer(uint64_t bufferSize, int32_t* fd) { + *fd = -1; + if (m_libCdspRpc == nullptr) { + QNN_ERROR("RpcMem not initialized for fused buffer"); + return nullptr; + } + + void* memPointer = m_rpcMemAlloc(RPCMEM_HEAP_ID_SYSTEM, RPCMEM_DEFAULT_FLAGS, bufferSize); + if (!memPointer) { + QNN_ERROR("Not able to allocate fused buffer of size: %lu", (unsigned long)bufferSize); + return nullptr; + } + + m_fusedBuffers.push_back({memPointer, bufferSize}); + QNN_DEBUG( + "Successfully allocated fused buffer at %p with size %lu", + memPointer, + (unsigned long)bufferSize + ); + + 
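+    // rpcmem_to_fd exposes the ION allocation as a file descriptor; this fd is what
+    // gets registered with the backend when tensors are mapped into the fused buffer
+    // via mapFusedBufferOffset.
+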
if ((*fd = m_rpcMemToFd(memPointer)) == -1) {
+        QNN_ERROR(
+            "Not able to get fd for the fused buffer of size: %lu", (unsigned long)bufferSize
+        );
+        return nullptr;
+    }
+
+    QNN_DEBUG("Retrieved fd %d for pointer %p", *fd, memPointer);
+    return memPointer;
+}
+
+bool RpcMem::allocateBuffers(
+    const std::map<int, std::vector<std::pair<std::string, size_t>>>& allocs_per_chunk,
+    std::map<std::string, std::pair<int, size_t>>& tensor_offsets
+) {
+    int alloc_chunk_idx = static_cast<int>(m_fusedBuffers.size());
+    int num_alloc_chunks = 0;
+    size_t total_alloc_size = 0;
+
+    for (auto& [_, tensor_sizes] : allocs_per_chunk) {
+        // Calculate total allocation chunk size
+        size_t alloc_chunk_size = 0;
+        for (const auto& [tensor_name, tensor_size] : tensor_sizes) {
+            tensor_offsets[tensor_name] = {alloc_chunk_idx, alloc_chunk_size};
+            alloc_chunk_size += tensor_size;
+        }
+
+        // Allocate chunk for this unique context set
+        if (alloc_chunk_size == 0) {
+            QNN_ERROR("Unexpected chunk size detected. Please re-check IO allocations");
+            return false;
+        }
+
+        m_fusedFds.push_back(0);
+        if (!allocateTensorFusedBuffer(alloc_chunk_size, &m_fusedFds.back()))  //
+            return false;
+        total_alloc_size += alloc_chunk_size;
+        alloc_chunk_idx++;
+        num_alloc_chunks++;
+    }
+    QNN_INFO(
+        "Allocated total size = %lu across %d buffers",
+        (unsigned long)total_alloc_size,
+        num_alloc_chunks
+    );
+    return true;
+}
+
+bool RpcMem::mapFusedBufferOffset(
+    Qnn_Tensor_t* tensor,
+    size_t tensorDataSize,
+    int32_t fd,
+    uint32_t offset,
+    uint64_t totalBufferSize,
+    void* memPointer,
+    Qnn_ContextHandle_t contextHandle
+) {
+    if (m_libCdspRpc == nullptr) {
+        QNN_ERROR("RpcMem not initialized");
+        return false;
+    }
+    if (!tensor) {
+        QNN_ERROR("Received nullptr for tensor");
+        return false;
+    }
+
+    Qnn_ErrorHandle_t ret;
+    const char* tname = QNN_TENSOR_GET_NAME(tensor);
+
+    // Check if tensor already has a memHandle assigned
+    Qnn_MemHandle_t cur_mem_handle = QNN_TENSOR_GET_MEM_HANDLE(tensor);
+    if (cur_mem_handle != nullptr) {
+        // Check if memHandle is already identical to requested buffer and offset
+        RpcMemTensorData& cur_rpc_mem_data = m_memHandleToRpcMem.at(cur_mem_handle);
+        if (cur_rpc_mem_data.fd == fd && cur_rpc_mem_data.offset == offset) {
+            return true;
+        }
+
+        // Updated offset; deregister the previous mem_handle
+        if (tensorDataSize == 0) tensorDataSize = cur_rpc_mem_data.size;
+        // clang-format off
+        TRACE_MEMORY_ALLOC( "memDeRegister %-20s (fd=%d offset=%lu) memHandle=%p",
+                tname, cur_rpc_mem_data.fd, (unsigned long)cur_rpc_mem_data.offset, cur_mem_handle);
+        // clang-format on
+        m_memHandleToRpcMem.erase(cur_mem_handle);
+        if ((ret = m_qnnInterface->memDeRegister(&cur_mem_handle, 1)) != QNN_SUCCESS) {
+            QNN_ERROR(
+                "memDeRegister ERROR(%lu) - %s memHandle=%p",
+                (unsigned long)ret,
+                tname,
+                cur_mem_handle
+            );
+            return false;
+        }
+    } else {
+        // For initial tensors, we need to check if the tensor can re-use a memHandle
+        // from another tensor in the same context
+        auto memConfig = std::make_tuple(fd, offset, contextHandle);
+        if (memConfigList.contains(memConfig)) {
+            auto& parentTensor = memConfigList[memConfig];
+            Qnn_MemHandle_t parentMemHandle = QNN_TENSOR_GET_MEM_HANDLE(parentTensor);
+            QNN_TENSOR_SET_MEM_TYPE(tensor, QNN_TENSORMEMTYPE_MEMHANDLE);
+            QNN_TENSOR_SET_MEM_HANDLE(tensor, parentMemHandle);
+            TRACE_MEMORY_ALLOC("%-20s : Mapping to memHandle %p", tname, parentMemHandle);
+            return true;
+        }
+    }
+
+    // Register a new memHandle based on function arguments
+    QnnMemHtp_Descriptor_t htp_mem_desciptor = {QNN_HTP_MEM_SHARED_BUFFER, totalBufferSize, {0}};
+    htp_mem_desciptor.sharedBufferConfig.fd = fd;
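+    // fd and offset together address a window inside one large rpcmem allocation:
+    // fd identifies the fused buffer and offset selects this tensor's slice within it,
+    // so many tensors can share a single ION buffer.
+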
htp_mem_desciptor.sharedBufferConfig.offset = offset; + + Qnn_MemDescriptor_t mem_descriptor = { + {QNN_TENSOR_GET_RANK(tensor), QNN_TENSOR_GET_DIMENSIONS(tensor), nullptr}, + QNN_TENSOR_GET_DATA_TYPE(tensor), + QNN_MEM_TYPE_CUSTOM, + {{-1}} + }; + mem_descriptor.customInfo = &htp_mem_desciptor; + + Qnn_MemHandle_t mem_handle = nullptr; + ret = m_qnnInterface->memRegister(contextHandle, &mem_descriptor, 1, &mem_handle); + if (ret != QNN_SUCCESS) { + QNN_ERROR("%-20s (ctx=%p fd=%d offset=%u)", tname, contextHandle, fd, offset); + QNN_ERROR("memRegister ERROR(%lu)", (unsigned long)ret); + return false; + } + + // clang-format off + TRACE_MEMORY_ALLOC("%-20s (ctx=%p fd=%d offset=%u) memPointer=%p memHandle=%p", + tname, contextHandle, fd, offset, ((uint8_t*)memPointer) + offset, mem_handle); + // clang-format on + m_memHandleToRpcMem[mem_handle] = RpcMemTensorData( + fd, ((uint8_t*)memPointer) + offset, tensorDataSize, totalBufferSize, offset + ); + + QNN_TENSOR_SET_MEM_TYPE(tensor, QNN_TENSORMEMTYPE_MEMHANDLE); + QNN_TENSOR_SET_MEM_HANDLE(tensor, mem_handle); + if (cur_mem_handle == nullptr) // Cache memory config for initial memRegisters only + memConfigList[std::make_tuple(fd, offset, contextHandle)] = tensor; + + return true; +} + +bool RpcMem::mapFusedBufferOffset( + Qnn_Tensor_t* tensor, + int alloc_idx, + size_t offset, + Qnn_ContextHandle_t ctx, + size_t size +) { + return mapFusedBufferOffset( + tensor, + size, + m_fusedFds[alloc_idx], + offset, + m_fusedBuffers[alloc_idx].second, + m_fusedBuffers[alloc_idx].first, + ctx + ); +} + +bool RpcMem::deregisterTensorFusedBuffer(Qnn_Tensor_t* tensor) { + if (!tensor) { + QNN_ERROR("Received nullptr for tensor"); + return false; + } + + if (m_tensorToRpcMem.find(tensor) == m_tensorToRpcMem.end()) { + QNN_ERROR("Tensor not found"); + return false; + } + + // We are not freeing memhandles here since they are already freed when + // freeContext() gets called in the destructor of QnnApi class which + // happens before this point + + // Qnn_MemHandle_t memHandle = QNN_TENSOR_GET_MEM_HANDLE(tensor); + // QNN_ERROR("Interface handle %p memhandle %p", m_qnnInterface, memHandle); + // if (QNN_SUCCESS != m_qnnInterface->memDeRegister(&memHandle, 1)) { + // QNN_ERROR("Failed to deregister ion memory with the backend"); + // return false; + // } + + QNN_TENSOR_SET_MEM_TYPE(tensor, QNN_TENSORMEMTYPE_UNDEFINED); + QNN_TENSOR_SET_MEM_HANDLE(tensor, nullptr); + m_tensorToRpcMem.erase(tensor); + return true; +} + +void RpcMem::freeFusedBuffers() { + // for (auto& memHandle : m_orphanedMemHandles) { + // if (QNN_SUCCESS != m_qnnInterface->memDeRegister(&memHandle, 1)) { + // QNN_ERROR("Failed to deregister ion memory with the backend"); + // } + // } + + for (auto& [mem_ptr, buffer_size] : m_fusedBuffers) { + QNN_DEBUG("Freeing fused buffer %p (size=%lu)", mem_ptr, buffer_size); + m_rpcMemFree(mem_ptr); + } +} diff --git a/Genie/Genie/src/qualla/engines/qnn-api/RpcMem.hpp b/Genie/Genie/src/qualla/engines/qnn-api/RpcMem.hpp new file mode 100644 index 0000000000000000000000000000000000000000..abd772a6bc0ba70fcc0d54afbb1aca2edf3840f9 --- /dev/null +++ b/Genie/Genie/src/qualla/engines/qnn-api/RpcMem.hpp @@ -0,0 +1,115 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. 
+//
+//==============================================================================
+
+#pragma once
+
+#include <cstddef>
+#include <cstdint>
+#include <map>
+#include <tuple>
+#include <unordered_map>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "IBufferAlloc.hpp"
+#include "QnnInterface.h"
+#include "Log.hpp"
+
+typedef void* (*RpcMemAllocFn_t)(int, uint32_t, int);
+typedef void (*RpcMemFreeFn_t)(void*);
+typedef int (*RpcMemToFdFn_t)(void*);
+
+struct RpcMemTensorData {
+    int fd;
+    void* memPointer;
+    size_t size;
+    size_t totalBufferSize;
+    size_t offset;
+    RpcMemTensorData() : fd(-1), memPointer(nullptr), size(0), totalBufferSize(0), offset(0) {}
+    RpcMemTensorData(int fdIn, void* memPointerIn, size_t sizeIn)
+        : fd(fdIn), memPointer(memPointerIn), size(sizeIn), totalBufferSize(0), offset(0) {}
+    RpcMemTensorData(
+        int fdIn,
+        void* memPointerIn,
+        size_t sizeIn,
+        size_t totalBufferSizeIn,
+        size_t offsetIn
+    )
+        : fd(fdIn), memPointer(memPointerIn), size(sizeIn), totalBufferSize(totalBufferSizeIn),
+          offset(offsetIn) {}
+};
+
+class RpcMem final : public IBufferAlloc {
+ public:
+    RpcMem(Qnn_ContextHandle_t contextHandle, QNN_INTERFACE_VER_TYPE* qnnInterface);
+    // Disable copy constructors, r-value referencing, etc
+    RpcMem(const RpcMem&) = delete;
+    RpcMem& operator=(const RpcMem&) = delete;
+    RpcMem(RpcMem&&) = delete;
+    RpcMem& operator=(RpcMem&&) = delete;
+    bool initialize() override;
+    void* getBuffer(Qnn_Tensor_t* tensor) override;
+    int getFd(Qnn_Tensor_t* tensor) override;
+
+    size_t getOffset(Qnn_Tensor_t* tensor) override;
+
+    size_t getBufferSize(Qnn_Tensor_t* tensor) override;
+
+    size_t getTotalBufferSize(Qnn_Tensor_t* tensor) override;
+
+    bool allocateTensorBuffer(Qnn_Tensor_t* tensor, size_t tensorDataSize) override;
+
+    bool freeTensorBuffer(Qnn_Tensor_t* tensor) override;
+    bool useSameMemory(Qnn_Tensor_t* dest, Qnn_Tensor_t* src) override;
+    bool useSameMemory(Qnn_Tensor_t* dest, Qnn_Tensor_t* src, int offset) override;
+
+    bool useExternalMemory(Qnn_Tensor_t* dest, void* extMem) override;
+
+    void* allocateTensorFusedBuffer(uint64_t bufferSize, int32_t* fd) override;
+    bool allocateBuffers(
+        const std::map<int, std::vector<std::pair<std::string, size_t>>>& allocs_per_chunk,
+        std::map<std::string, std::pair<int, size_t>>& tensor_offsets
+    ) override;
+
+    bool mapFusedBufferOffset(
+        Qnn_Tensor_t* tensor,
+        size_t tensorDataSize,
+        int32_t fd,
+        uint32_t offset,
+        uint64_t totalBufferSize,
+        void* memPointer,
+        Qnn_ContextHandle_t contextHandle
+    ) override;
+    bool deregisterTensorFusedBuffer(Qnn_Tensor_t* tensor) override;
+    void freeFusedBuffers() override;
+    bool mapFusedBufferOffset(
+        Qnn_Tensor_t* tensor,
+        int alloc_idx,
+        size_t offset,
+        Qnn_ContextHandle_t ctx,
+        size_t size
+    ) override;
+    virtual ~RpcMem();
+
+ private:
+    RpcMemTensorData* getRpcMemTensorData(Qnn_Tensor_t* tensor);
+
+    // Pointer to the dlopen'd libcdsprpc.so shared library which contains
+    // rpcmem_alloc, rpcmem_free, rpcmem_to_fd APIs
+    void* m_libCdspRpc;
+    // Function pointer to rpcmem_alloc
+    RpcMemAllocFn_t m_rpcMemAlloc;
+    // Function pointer to rpcmem_free
+    RpcMemFreeFn_t m_rpcMemFree;
+    // Function pointer to rpcmem_to_fd
+    RpcMemToFdFn_t m_rpcMemToFd;
+    QNN_INTERFACE_VER_TYPE* m_qnnInterface;
+    Qnn_ContextHandle_t m_contextHandle;
+
+    std::unordered_map<Qnn_Tensor_t*, RpcMemTensorData> m_tensorToRpcMem;
+    std::unordered_set<Qnn_Tensor_t*> m_sameMemoryFreeTensors;
+    std::vector<std::pair<void*, uint64_t>> m_fusedBuffers;  // vector of {basePointer, totalSize}
+    std::vector<int32_t> m_fusedFds;
+    std::unordered_set<Qnn_MemHandle_t> m_orphanedMemHandles;
+    std::unordered_map<Qnn_MemHandle_t, RpcMemTensorData> m_memHandleToRpcMem;
+    std::map<std::tuple<int32_t, uint32_t, Qnn_ContextHandle_t>, Qnn_Tensor_t*> memConfigList;
+};
diff --git a/Genie/Genie/src/qualla/engines/qnn-api/dlwrap.cpp b/Genie/Genie/src/qualla/engines/qnn-api/dlwrap.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..17df4ffa6d2f17d9b488ca7630c50bdb95dd79b6
--- /dev/null
+++ b/Genie/Genie/src/qualla/engines/qnn-api/dlwrap.cpp
@@ -0,0 +1,66 @@
+//==============================================================================
+//
+// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+// All Rights Reserved.
+// Confidential and Proprietary - Qualcomm Technologies, Inc.
+//
+//==============================================================================
+
+#ifdef _WIN32
+
+    #pragma warning(disable : 4133 4996)
+
+    #include <windows.h>
+
+    #include <stdint.h>
+    #include <stdio.h>
+    #include <stdlib.h>
+    #include <string.h>
+
+    #include "dlwrap.hpp"
+
+static const char* last_func;
+static long last_err;
+
+void* dlopen(const char* dll, int flags) {
+    HINSTANCE h = LoadLibraryA(dll);
+    if (h == NULL) {
+        last_err = GetLastError();
+        last_func = "dlopen";
+    }
+
+    return h;
+}
+
+int dlclose(void* h) {
+    if (!FreeLibrary((HINSTANCE)h)) {
+        last_err = GetLastError();
+        last_func = "dlclose";
+        return -1;
+    }
+
+    return 0;
+}
+
+void* dlsym(void* h, const char* name) {
+    FARPROC p = GetProcAddress((HINSTANCE)h, name);
+    if (!p) {
+        last_err = GetLastError();
+        last_func = "dlsym";
+    }
+    return (void*)(intptr_t)p;
+}
+
+const char* dlerror(void) {
+    static char str[88];
+
+    if (!last_err) return NULL;
+
+    sprintf(str, "%s error #%ld", last_func, last_err);
+    last_err = 0;
+    last_func = NULL;
+
+    return str;
+}
+
+#endif // _WIN32
diff --git a/Genie/Genie/src/qualla/engines/qnn-api/dlwrap.hpp b/Genie/Genie/src/qualla/engines/qnn-api/dlwrap.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..5170c06b5b3e04c493055985399490d9d7e30605
--- /dev/null
+++ b/Genie/Genie/src/qualla/engines/qnn-api/dlwrap.hpp
@@ -0,0 +1,33 @@
+//==============================================================================
+//
+// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+// All Rights Reserved.
+// Confidential and Proprietary - Qualcomm Technologies, Inc.
+//
+//==============================================================================
+
+#ifndef DLWRAP_HPP
+#define DLWRAP_HPP
+
+#ifndef _WIN32
+
+    // Just include regular dlfcn
+    #include <dlfcn.h>
+
+#else // _WIN32
+
+    // Define basic set of dl functions and flags
+
+    #define RTLD_GLOBAL 0x100
+    #define RTLD_LOCAL 0x000
+    #define RTLD_LAZY 0x000
+    #define RTLD_NOW 0x001
+
+void* dlopen(const char* filename, int flag);
+int dlclose(void* handle);
+void* dlsym(void* handle, const char* name);
+const char* dlerror(void);
+
+#endif // _WIN32
+
+#endif // DLWRAP_HPP
diff --git a/Genie/Genie/src/qualla/engines/qnn-api/qnn-utils.cpp b/Genie/Genie/src/qualla/engines/qnn-api/qnn-utils.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..3d78469c0a21dadf6c106b9b3220386d58a8f8e4
--- /dev/null
+++ b/Genie/Genie/src/qualla/engines/qnn-api/qnn-utils.cpp
@@ -0,0 +1,104 @@
+//==============================================================================
+//
+// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+// All Rights Reserved.
+// Confidential and Proprietary - Qualcomm Technologies, Inc.
+//
+//==============================================================================
+
+#include "qnn-utils.hpp"
+
+#include <filesystem>
+#include <fstream>
+#include <iomanip>
+#include <sstream>
+#include "QnnApi.hpp"
+#include <fmt/format.h>
+
+namespace fs = std::filesystem;
+
+namespace qualla {
+namespace QnnUtils {
+    // Alternate implementation for bw() = lambda x: (10 * ((x & 0xf0)>>4) + (x & 0xf)) // 8
+    int DataType::bw() { return (_dtype == QNN_DATATYPE_UNDEFINED) ?
-1 : QnnApi::getDataTypeSize(_dtype); }
+    int DataType::type() { return (_dtype == QNN_DATATYPE_UNDEFINED) ? -1 : _dtype >> 4; }
+
+    int32_t DataType::val() { return static_cast<int32_t>(_dtype); }
+
+bool writeRawData(void* data, size_t size, const fs::path& path) {
+    auto p = path.parent_path();
+    if (!fs::exists(p) && !fs::create_directories(p)) return false;
+
+    std::ofstream f(path, std::ofstream::binary);
+    f.write((char*)data, size);
+    f.close();
+
+    return true;
+}
+
+bool readRawData(void* data, size_t size, const fs::path& path) {
+    if (fs::file_size(path) != size) {
+        throw std::runtime_error(fmt::format(
+            "file size does not match: {} size {}, buf-size {}",
+            path.string(),
+            fs::file_size(path),
+            size
+        ));
+    }
+
+    std::ifstream f(path, std::ifstream::binary);
+    f.read((char*)data, size);
+    f.close();
+
+    return true;
+}
+
+void getQuantParamString(
+    const std::vector<QuantParam>& quantParam,
+    std::string& scale_string,
+    std::string& offset_string
+) {
+    std::ostringstream scales_s;
+    std::ostringstream offsets_s;
+    for (size_t i = 0; i < quantParam.size(); i++) {
+        if (i != 0) {
+            scales_s << ", ";
+            offsets_s << ", ";
+        }
+        scales_s << std::fixed << std::setprecision(20) << quantParam[i].scale;
+        offsets_s << quantParam[i].offset;
+    }
+    scale_string = scales_s.str();
+    offset_string = offsets_s.str();
+}
+
+const char* DataType::str() {
+    // clang-format off
+    switch (_dtype) {
+        case QNN_DATATYPE_INT_8:            return "QNN_DATATYPE_INT_8";
+        case QNN_DATATYPE_INT_16:           return "QNN_DATATYPE_INT_16";
+        case QNN_DATATYPE_INT_32:           return "QNN_DATATYPE_INT_32";
+        case QNN_DATATYPE_INT_64:           return "QNN_DATATYPE_INT_64";
+        case QNN_DATATYPE_UINT_8:           return "QNN_DATATYPE_UINT_8";
+        case QNN_DATATYPE_UINT_16:          return "QNN_DATATYPE_UINT_16";
+        case QNN_DATATYPE_UINT_32:          return "QNN_DATATYPE_UINT_32";
+        case QNN_DATATYPE_UINT_64:          return "QNN_DATATYPE_UINT_64";
+        case QNN_DATATYPE_FLOAT_16:         return "QNN_DATATYPE_FLOAT_16";
+        case QNN_DATATYPE_FLOAT_32:         return "QNN_DATATYPE_FLOAT_32";
+        case QNN_DATATYPE_FLOAT_64:         return "QNN_DATATYPE_FLOAT_64";
+        case QNN_DATATYPE_SFIXED_POINT_4:   return "QNN_DATATYPE_SFIXED_POINT_4";
+        case QNN_DATATYPE_SFIXED_POINT_8:   return "QNN_DATATYPE_SFIXED_POINT_8";
+        case QNN_DATATYPE_SFIXED_POINT_16:  return "QNN_DATATYPE_SFIXED_POINT_16";
+        case QNN_DATATYPE_SFIXED_POINT_32:  return "QNN_DATATYPE_SFIXED_POINT_32";
+        case QNN_DATATYPE_UFIXED_POINT_4:   return "QNN_DATATYPE_UFIXED_POINT_4";
+        case QNN_DATATYPE_UFIXED_POINT_8:   return "QNN_DATATYPE_UFIXED_POINT_8";
+        case QNN_DATATYPE_UFIXED_POINT_16:  return "QNN_DATATYPE_UFIXED_POINT_16";
+        case QNN_DATATYPE_UFIXED_POINT_32:  return "QNN_DATATYPE_UFIXED_POINT_32";
+        case QNN_DATATYPE_BOOL_8:           return "QNN_DATATYPE_BOOL_8";
+        case QNN_DATATYPE_STRING:           return "QNN_DATATYPE_STRING";
+        default:                            return "QNN_DATATYPE_UNDEFINED";
+    }
+    // clang-format on
+}
+} // namespace QnnUtils
+} // namespace qualla
diff --git a/Genie/Genie/src/qualla/engines/qnn-api/qnn-utils.hpp b/Genie/Genie/src/qualla/engines/qnn-api/qnn-utils.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..ca6efa2441cbf8548697412436ed7474ceb65cb6
--- /dev/null
+++ b/Genie/Genie/src/qualla/engines/qnn-api/qnn-utils.hpp
@@ -0,0 +1,157 @@
+//==============================================================================
+//
+// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+// All Rights Reserved.
+// Confidential and Proprietary - Qualcomm Technologies, Inc.
+//
+//==============================================================================
+
+#pragma once
+
+#ifdef _MSC_VER
+    #pragma warning(disable : 4068)
+#endif
+
+#include <algorithm>
+#include <cstdint>
+#include <filesystem>
+#include <map>
+#include <string>
+#include <vector>
+#include "QnnApiUtils.hpp"
+#include "QnnInterface.h"
+
+namespace qualla {
+
+namespace QnnUtils {
+class DataType {
+ private:
+    Qnn_DataType_t _dtype{QNN_DATATYPE_UNDEFINED};
+
+ public:
+    DataType() = default;
+    DataType(const Qnn_Tensor_t* tensor) : _dtype(QNN_TENSOR_GET_DATA_TYPE(tensor)) {}
+    DataType(Qnn_DataType_t dtype) : _dtype(dtype) {}
+
+    // Enable switch and comparisons
+    constexpr operator Qnn_DataType_t() const { return _dtype; }
+
+    int bw();
+    int type();
+
+    int32_t val();
+
+    const char* str();
+};
+
+bool writeRawData(void* tensorData, size_t tensorSize, const std::filesystem::path& path);
+bool readRawData(void* tensorData, size_t tensorSize, const std::filesystem::path& path);
+
+struct Dims {
+    int32_t batch = 1;
+    int32_t height, width, channel, bitWidth;
+    Dims() : height(0), width(0), channel(0), bitWidth(0) {}
+    Dims(int32_t height, int32_t width, int32_t channel, int32_t bitWidth)
+        : height(height), width(width), channel(channel), bitWidth(bitWidth) {}
+    Dims(std::vector<uint32_t>& tDims)
+        : height((int32_t)tDims[1]), width((int32_t)tDims[2]), channel((int32_t)tDims[3]),
+          bitWidth((int32_t)tDims[4]) {
+        // Hack to mix batch dimension
+        if (tDims[0] != 1 && tDims[1] == 1) height = tDims[0];
+        if (tDims[0] > 1 && tDims[1] != 1) batch = tDims[0];
+    }
+    bool operator==(const Dims& rhs) const {
+        return (height == rhs.height) && (width == rhs.width) && (channel == rhs.channel) &&
+               (bitWidth == rhs.bitWidth);
+    }
+    bool operator!=(const Dims& rhs) const { return !(operator==(rhs)); }
+    size_t getNumElements() const { return (size_t)(height * width * channel); }
+    size_t getSize() const { return (size_t)(batch * height * width * channel * bitWidth); }
+    size_t getAlignedSize() const {
+        size_t size = getSize();
+        if ((size & uint64_t{7}) != uint64_t{0}) {
+            size += (uint64_t{8} - (size & uint64_t{7}));
+        }
+        return size;
+    }
+    int32_t getMaxDim() const { return std::max({height, width, channel}); }
+    Dims T() const { return Dims(width, height, channel, bitWidth); }
+};
+
+struct QuantParam {
+    double scale;
+    int32_t offset;
+    QuantParam() : scale(0.0), offset(0) {}
+    QuantParam(double scale_val, int32_t offset_val) : scale(scale_val), offset(offset_val) {}
+};
+
+struct Tensor {
+    Qnn_Tensor_t* tensor = nullptr;
+    Dims dims;
+    std::vector<QuantParam> quantParam;
+    DataType dtype;
+    Tensor() {}
+    Tensor(Qnn_Tensor_t* tensorVal, Dims dimsVal, std::vector<QuantParam> quantParamVec)
+        : tensor(tensorVal), dims(dimsVal), quantParam(quantParamVec),
+          dtype(QNN_TENSOR_GET_DATA_TYPE(tensorVal)) {}
+};
+
+// Maps tensor name to QnnUtils::Tensor
+typedef std::map<std::string, Tensor> TensorMap;
+
+static inline uint8_t sat_round(const uint16_t x) {
+    const uint16_t rounded = x + 0x80;                // add 0.5
+    const uint16_t corrected = std::max(rounded, x);  // catch unsigned wrap around
+    const uint16_t shifted = corrected >> 8;          // divide by 256
+    return static_cast<uint8_t>(shifted);             // to 8-bit
+}
+
+static inline void downcast_u16_to_u8(uint8_t* dest, const uint16_t* src, size_t nmemb) {
+    for (size_t i = 0; i < nmemb; i++)
+        dest[i] = sat_round(src[i]);
+}
+
+template <typename FloatType, typename IntType>
+static inline void quantizeTensorPtr(
+    FloatType* tensor_float,
+    IntType* tensor_quant,
+    int32_t offset,
+    double scale,
+    size_t nmemb
+) {
+#pragma clang loop vectorize(enable) interleave(enable)
+    for (size_t i = 0; i < nmemb; i++) {
+        double val = tensor_float[i];
+        tensor_quant[i] = static_cast<IntType>(val / scale - offset);
+    }
+}
+
+template <typename FloatType, typename IntType>
+static inline void perWidthQuantizeTensorPtr(
+    FloatType* tensor_float,
+    IntType* tensor_quant,
+    std::vector<QuantParam>& quantParam,
+    int32_t height,
+    int32_t width,
+    int32_t channel
+) {
+    for (int32_t h = 0; h < height; h++) {
+        for (int32_t w = 0; w < width; w++) {
+            double scale = quantParam[w].scale;
+            int32_t offset = quantParam[w].offset;
+#pragma clang loop vectorize(enable) interleave(enable)
+            for (int32_t c = 0; c < channel; c++) {
+                int32_t i = (h * width * channel) + (w * channel) + c;
+                double val = tensor_float[i];
+                tensor_quant[i] = static_cast<IntType>(val / scale - offset);
+            }
+        }
+    }
+}
+
+void getQuantParamString(
+    const std::vector<QuantParam>& quantParam,
+    std::string& scale_string,
+    std::string& offset_string
+);
+
+} // namespace QnnUtils
+} // namespace qualla
diff --git a/Genie/Genie/src/qualla/engines/qnn-cpu.cpp b/Genie/Genie/src/qualla/engines/qnn-cpu.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..48b1ab1aa4d820acd7581dbd6b6afa09aed08b83
--- /dev/null
+++ b/Genie/Genie/src/qualla/engines/qnn-cpu.cpp
@@ -0,0 +1,237 @@
+//==============================================================================
+//
+// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+// All Rights Reserved.
+// Confidential and Proprietary - Qualcomm Technologies, Inc.
+//
+//==============================================================================
+
+#include <cstdint>
+#include <cstring>
+
+#include <filesystem>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include <fmt/format.h>
+
+#include "cpu-model.hpp"
+
+#define __INFO(__fmt, ...) _env.logger().post(Logger::INFO, fmt::format(__fmt, ##__VA_ARGS__))
+#define __WARN(__fmt, ...) _env.logger().post(Logger::WARN, fmt::format(__fmt, ##__VA_ARGS__))
+#define __ERROR(__fmt, ...) _env.logger().post(Logger::ERROR, fmt::format(__fmt, ##__VA_ARGS__))
+#define __KPIS(__fmt, ...) \
+    _env.logger().post(Logger::ENGINE_KPIS, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
+#define __DEBUG(__fmt, ...) \
+    _env.logger().post(Logger::ENGINE_DEBUG, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
+#define __TRACE(__fmt, ...)
\ + _env.logger().post(Logger::ENGINE_TRACE, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); }) + +namespace qualla { + +class QnnCpuEngine : public Engine { + private: + // Model parameters + std::unique_ptr _model; + + public: + QnnCpuEngine(Context& ctx, const qualla::json& json); + ~QnnCpuEngine(); + + virtual size_t process( + const std::vector& tokens, + std::vector& logits, + bool logits_all + ) override; + + virtual size_t process( + const std::vector& tokens, + const std::vector& attention_map, + std::vector& logits, + bool logits_all + ) override; + + virtual bool updateKV(size_t n_past) override; + virtual bool updateKV(size_t n_past, const std::vector& selected) override; + virtual bool save(const std::string& name) override; + virtual size_t restore(const std::string& name) override; + virtual void reset() override; +}; + +namespace fs = std::filesystem; + +QnnCpuEngine::QnnCpuEngine(Context& ctx, const qualla::json& json) : Engine(ctx, "qnn-cpu", json) { + qualla::Timer start; + + using FF = Feature::Flags; + _features = FF::OUTPUT_LOGITS | FF::SAVE_RESTORE | FF::OUTPUT_EMBEDDINGS; + + __DEBUG("qnn-cpu: init start"); + + qualla::Config conf(json, _type + "-engine:"); + + // Parse config + QnnCpuModel::Params p; + + std::string model_output = conf.optional("model-output", "logits"); + if (model_output == "logits") + p.model_output = QnnCpuModel::ModelOutput::LOGITS; + else if (model_output == "embeddings") + p.model_output = QnnCpuModel::ModelOutput::EMBEDDINGS; + else + throw std::runtime_error( + "Only logits and embeddings outputs are supported. Invalid output supplied : " + + model_output + ); + + p.model_basedir = _env.path().models / conf.optional("model-basedir", ""); + p.model_bin_path = conf.mandatory("model-bin-path"); + p.model = conf.mandatory("model"); + p.op_package = conf.mandatory("op-package"); + p.backend_lib = conf.mandatory("backend-lib"); + p.n_threads = conf.optional("n-threads", 6); + p.n_logits = conf.optional("n_logits", 1); + p.n_layer = conf.optional("n_layer", 32); + p.n_embd = conf.optional("n_embd", 4096); + p.n_heads = conf.optional("n_heads", 32); + p.use_mmap = conf.optional("use-mmap", false); + p.ctx_size = _ctx.size(); + p.n_vocab_size = _ctx.n_vocab(); + + _model = std::make_unique(_env, p); + + // Load model + if (true != _model->initializeModel()) { + throw std::runtime_error("Failure to initialize model"); + } + + // Initialize IO Tensor buffers + if (true != _model->initializeIOTensors()) { + throw std::runtime_error("Error in setting up IO Tensors"); + } + + if (true != _model->validateModel()) { + throw std::runtime_error("Error validating model. 
Please check your I/O"); + } + + __DEBUG("qnn-cpu: model has been validated!"); + + if (true != _model->initializeTensorPointers()) { + throw std::runtime_error("Error : Could not find I/O tensors in loaded graphs"); + } + + _kpis.load.update(start.elapsed_usec()); +}; + +QnnCpuEngine::~QnnCpuEngine() { + __DEBUG("qnn-cpu: fini"); +} + +bool QnnCpuEngine::updateKV(size_t n_past) { + qualla::Timer start; + + if (n_past > _ctx.size()) { + __ERROR("qnn-cpu: context size exceeded : n_past {}", n_past); + State::error("context size exceeded"); + return false; + } + + __DEBUG("qnn-cpu: update-kv start : n_past {}", n_past); + + _model->setKVCacheNPast(n_past); + + __DEBUG("qnn-cpu: update-kv complete : {} usec", start.elapsed_usec()); + + _kpis.update_kv.update(start.elapsed_usec()); + + return true; +} + +bool QnnCpuEngine::updateKV(size_t n_past, const std::vector& selected) { + qualla::Timer start; + + if (n_past > _ctx.size()) { + __ERROR("qnn-cpu: context size exceeded : n_past {}", n_past); + State::error("context size exceeded"); + return false; + } + + __DEBUG("qnn-cpu: update-kv start : n_past {}", n_past); + + _model->setKVCacheNPast(n_past); + + __DEBUG("qnn-cpu: update-kv complete : {} usec", start.elapsed_usec()); + + _kpis.update_kv.update(start.elapsed_usec()); + + return true; +} + +size_t QnnCpuEngine::process( + const std::vector& tokens, + std::vector& logits, + bool logits_all = false +) { + qualla::Timer start; + + __DEBUG("qnn-cpu: inference start: n_tokens {}", tokens.size()); + + _model->runInference(tokens, logits_all); + + __DEBUG("qnn-cpu: inference complete : {} usec", start.elapsed_usec()); + + size_t n_tok; + + { + qualla::Timer t; + + __DEBUG("qnn-cpu: get-logits start: all {}", logits_all); + + n_tok = _model->getDequantLogits(logits, logits_all); + + __DEBUG("qnn-cpu: get-logits complete : {} usec", t.elapsed_usec()); + } + + _kpis.process.update(start.elapsed_usec()); + + return n_tok; +} + +size_t QnnCpuEngine::process( + const std::vector& tokens, + const std::vector& attention_map, + std::vector& logits, + bool logits_all = false +) { + return process( + tokens, + logits, + logits_all + ); +} + +size_t QnnCpuEngine::restore(const std::string& name) { + fs::path cache_path = std::filesystem::path(name) / fmt::format("kv-cache.{}.qnn-cpu", _role); + return _model->loadKVCache(cache_path.string()); +} + +bool QnnCpuEngine::save(const std::string& name) { + fs::path cache_path = std::filesystem::path(name) / fmt::format("kv-cache.{}.qnn-cpu", _role); + return _model->saveKVCache(cache_path.string()); +} + +void QnnCpuEngine::reset() { + // It's enough to just drop the KV$ + updateKV(0); +} + +// Registrator instance +static OnLoad regy([]() { + Engine::__register("qnn-cpu", [](Context& ctx, const json& conf) { + return (Engine*)new QnnCpuEngine(ctx, conf); + }); +}); +void needQnnCpuEngine() {} + +} // namespace qualla diff --git a/Genie/Genie/src/qualla/engines/qnn-cpu/cpu-model.cpp b/Genie/Genie/src/qualla/engines/qnn-cpu/cpu-model.cpp new file mode 100644 index 0000000000000000000000000000000000000000..4e71728021596f2ad9d3d4679b5719f38079000c --- /dev/null +++ b/Genie/Genie/src/qualla/engines/qnn-cpu/cpu-model.cpp @@ -0,0 +1,689 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. 
+// +//============================================================================== + +#include "qualla/env.hpp" +#include "qualla/detail/timer.hpp" +#include "qualla/detail/cache-file.hpp" + +#include "fmt/format.h" +#include "fmt/ranges.h" + +#include "qnn-utils.hpp" +#include "cpu-model.hpp" + +#include +#include +#include +#include +#include + +namespace fs = std::filesystem; + +#define __INFO(__fmt, ...) _env.logger().post(Logger::INFO, fmt::format(__fmt, ##__VA_ARGS__)) +#define __WARN(__fmt, ...) _env.logger().post(Logger::WARN, fmt::format(__fmt, ##__VA_ARGS__)) +#define __ERROR(__fmt, ...) _env.logger().post(Logger::ERROR, fmt::format(__fmt, ##__VA_ARGS__)) +#define __KPIS(__fmt, ...) \ + _env.logger().post(Logger::ENGINE_KPIS, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); }) +#define __DEBUG(__fmt, ...) \ + _env.logger().post(Logger::ENGINE_DEBUG, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); }) +#define __TRACE(__fmt, ...) \ + _env.logger().post(Logger::ENGINE_TRACE, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); }) + +namespace qualla { + +QnnCpuModel::QnnCpuModel(Env& env, const Params& params) + : _env(env), model_basedir(params.model_basedir), op_package(params.op_package), + backend_lib(params.backend_lib), model_bin_path(params.model_bin_path), model(params.model), + m_ctx_size(params.ctx_size), m_num_threads(params.n_threads), m_num_tokens(params.ctx_size), + m_numLogits(params.n_logits), m_vocab_size(params.n_vocab_size), m_num_layer(params.n_layer), + m_embd(params.n_embd), m_num_heads(params.n_heads), m_use_mmap(params.use_mmap), + model_output(params.model_output) { + // Initialize QnnAPI + m_qnnApi = std::unique_ptr(new QnnApi()); + m_head_dim = m_embd / m_num_heads; + m_input_dim.push_back(1); + m_input_dim.push_back(m_ctx_size); + // K$, V$ 4D Tensor {n_layer, n_heads, n_ctx, n_head_dim} + m_kv_dim.push_back(m_num_layer); + m_kv_dim.push_back(m_num_heads); + m_kv_dim.push_back(m_ctx_size + 1); + m_kv_dim.push_back(m_head_dim); + if (model_output == ModelOutput::LOGITS) { + m_output_dim.push_back(m_numLogits); + m_output_dim.push_back(m_vocab_size); + } else if (model_output == ModelOutput::EMBEDDINGS) { + m_numLogits = m_ctx_size; + m_output_dim.push_back(m_numLogits); + m_output_dim.push_back(m_embd); + } +} + +QnnCpuModel::~QnnCpuModel() { + // Free Qnn Tensor and their memory + auto start = std::chrono::steady_clock::now(); + if (dequant_logits_ptr != nullptr) free(dequant_logits_ptr); + if (m_ioTensor) { + QNN_DEBUG("Tearing Down Input Tensors Bank"); + for (auto& graph_name : model_order) { + m_ioTensor->tearDownTensors( + m_input_tensors[graph_name], m_input_specs[graph_name].size() + ); + m_ioTensor->tearDownTensors( + m_output_tensors[graph_name], m_output_specs[graph_name].size() + ); + } + } + auto stop = std::chrono::steady_clock::now(); + //QnnUtils::logProfile("Model destruction (cpp) took", start, stop); +} + +// Given a filename, initializeModel load and initializes QNN runtime libraries and the model +bool QnnCpuModel::initializeModel(void) { + // prepare params + Qnn_Param_t params[5]; + params[0].paramType = QNN_PARAMTYPE_SCALAR; + params[0].name = (char*)("model_bin_path"); + params[0].scalarParam.dataType = QNN_DATATYPE_STRING; + params[0].scalarParam.stringValue = model_bin_path.c_str(); + + params[1].paramType = QNN_PARAMTYPE_SCALAR; + params[1].name = (char*)("num_thread"); + params[1].scalarParam.dataType = QNN_DATATYPE_UINT_32; + params[1].scalarParam.uint32Value = m_num_threads; + + params[2].paramType = QNN_PARAMTYPE_SCALAR; 
+ params[2].name = (char*)("num_context"); + params[2].scalarParam.dataType = QNN_DATATYPE_UINT_32; + params[2].scalarParam.uint32Value = m_ctx_size; + + params[3].paramType = QNN_PARAMTYPE_SCALAR; + params[3].name = (char*)("num_last_logits"); + params[3].scalarParam.dataType = QNN_DATATYPE_UINT_32; + params[3].scalarParam.uint32Value = m_numLogits; + + params[4].paramType = QNN_PARAMTYPE_SCALAR; + params[4].name = (char*)("use_mmap"); + params[4].scalarParam.dataType = QNN_DATATYPE_BOOL_8; + params[4].scalarParam.uint32Value = m_use_mmap; + + if (true != m_qnnApi->initialize( + backend_lib, + model, + op_package, + ContextConfigs(), + {}, + m_input_dim.data(), + m_input_dim.size(), + m_output_dim.data(), + m_output_dim.size(), + m_kv_dim.data(), + m_kv_dim.size(), + params, + 5, + false + )) { + QNN_ERROR("Backend library : %s", backend_lib.c_str()); + throw std::runtime_error("QNN initialization failed!"); + } + + // Initialize QNN IO Tensor + m_ioTensor = std::unique_ptr(new IOTensor()); + m_num_graphs = m_qnnApi->getGraphsCount(); + QNN_DEBUG("QNN initialized with %u graph(s)", m_num_graphs); + + auto graphs_info = m_qnnApi->getGraphsInfo(); + for (size_t graph_idx = 0; graph_idx < m_num_graphs; graph_idx++) { + GraphInfo_t* const& graph_info = graphs_info[graph_idx]; + char* graph_name = graph_info->graphName; + std::string graph_str = std::string(graph_name); + + QNN_DEBUG("Loaded graph[%lu] = %s", graph_idx, graph_name); + model_order.push_back(graph_str); + model_context[graph_str] = + m_qnnApi->getContexts()[graph_idx / m_qnnApi->getGraphCountPerContext()]; + } + + // CPU support KV cache mode + m_mode = ExecutionMode::KV_ONLY; + + return true; +} + +// Once the model has been loaded, initialize IO Tensors +// m_ioTensors is initialized by the context for now +bool QnnCpuModel::initializeIOTensors() { + QNN_DEBUG("Create input tensors bank"); + + // Ideally, we should create and initalize m_ioTensor for each context, but we want to + // be able to see/use all the buffers in every contexts so that they can be connected + // with each other. Hence, we are using only the first context to initialize the m_ioTensor + // and use it for all graphs/contexts. 
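+    // Note that the setupInputTensors/setupOutputTensors calls below still receive the
+    // per-graph context handle, so only this initialize() call is pinned to the first
+    // context.
+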
+ if (true != m_ioTensor->initialize(m_qnnApi->getContexts()[0])) { + QNN_ERROR("Failure to initialize IOTensor"); + return false; + } + + // Getting graph info and its count needed for subsequent steps + GraphInfo_t** const& graphsInfo = m_qnnApi->getGraphsInfo(); + + for (size_t graphIdx = 0; graphIdx < m_num_graphs; graphIdx++) { + GraphInfo_t* const& graphInfo = graphsInfo[graphIdx]; + std::string graphName = std::string(graphInfo->graphName); + + // Setup Inputs + { + std::unordered_map inputTensorsSize; + for (size_t tensorIdx = 0; tensorIdx < graphInfo->numInputTensors; tensorIdx++) { + std::string tensor_name; + std::vector tensorDims; + + auto& tensor = graphInfo->inputTensors[tensorIdx]; + m_qnnApi->getTensorNameAndShape(tensor_name, tensorDims, tensor); + std::vector quantParams; + if (!m_qnnApi->getTensorQuantParams(&tensor, quantParams)) { + QNN_DEBUG("Couldn't get tensor quant params : %s", tensor_name.c_str()); + quantParams.emplace_back(0, 0); + } + + auto dims = QnnUtils::Dims(tensorDims); + inputTensorsSize[tensor_name] = dims.getAlignedSize(); + + m_input_specs[graphName][tensor_name] = {&tensor, dims, quantParams}; + } + + Qnn_Tensor_t* tensor_bank = nullptr; + std::unordered_map tensor_ptr_map; + if (true != m_ioTensor->setupInputTensors( + &tensor_bank, + tensor_ptr_map, + *graphInfo, + inputTensorsSize, + m_qnnApi->getContexts()[graphIdx], + false + )) { + QNN_ERROR("Error in setting up Input Tensors for graph %s", graphName.c_str()); + return false; + } + + m_input_tensors[graphName] = tensor_bank; + for (auto& [tensor_name, tensor_ptr] : tensor_ptr_map) { + m_input_specs[graphName][tensor_name].tensor = (Qnn_Tensor_t*)tensor_ptr; + } + } + + // Setup Outputs + { + std::unordered_map outputTensorsSize; + for (size_t tensorIdx = 0; tensorIdx < graphInfo->numOutputTensors; tensorIdx++) { + std::string tensor_name; + std::vector tensorDims; + + auto& tensor = graphInfo->outputTensors[tensorIdx]; + m_qnnApi->getTensorNameAndShape(tensor_name, tensorDims, tensor); + std::vector quantParams; + if (!m_qnnApi->getTensorQuantParams(&tensor, quantParams)) { + QNN_DEBUG("Couldn't get tensor quant params : %s", tensor_name.c_str()); + quantParams.emplace_back(0, 0); + } + + auto dims = QnnUtils::Dims(tensorDims); + outputTensorsSize[tensor_name] = dims.getAlignedSize(); + + m_output_specs[graphName][tensor_name] = {&tensor, dims, quantParams}; + } + + Qnn_Tensor_t* tensor_bank = nullptr; + std::unordered_map tensor_ptr_map; + if (true != m_ioTensor->setupOutputTensors( + &tensor_bank, + tensor_ptr_map, + *graphInfo, + outputTensorsSize, + m_qnnApi->getContexts()[graphIdx], + false + )) { + QNN_ERROR("Error in setting up Output Tensors for graph %s", graphName.c_str()); + return false; + } + + m_output_tensors[graphName] = tensor_bank; + for (auto& [tensor_name, tensor_ptr] : tensor_ptr_map) { + m_output_specs[graphName][tensor_name].tensor = (Qnn_Tensor_t*)tensor_ptr; + } + } + } + +#ifdef DUMP_TENSOR_SPECS + dumpTensorSpecs(); +#endif + + auto stop = std::chrono::steady_clock::now(); + //QnnUtils::logProfile("initializeIoTensors (cpp) took", start, stop); + + return true; +} + +void QnnCpuModel::dumpTensorSpecs() { +#ifdef DEBUG_DUMP_TARGET_PATH + if (true != QnnUtils::CreateDirsIfNotExist(DEBUG_DUMP_TARGET_PATH)) { + throw std::runtime_error( + std::string("Could not create directory : ") + DEBUG_DUMP_TARGET_PATH + ); + } + + static const char* stringFmt = + "\t\t{ \"name\": \"%s\", \"dims\": [1, %d, %d, %d], \"bitwidth\": %d, \"scale\": [%s], \"offset\": [%s] },\n"; + + 
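+    // One spec.<graphName>.json file is emitted per graph; each entry records a tensor's
+    // name, dims, bitwidth, and its quantization scale/offset lists using stringFmt above.
+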
GraphInfo_t** const& graphsInfo = m_qnnApi->getGraphsInfo(); + for (size_t graphIdx = 0; graphIdx < m_num_graphs; graphIdx++) { + GraphInfo_t* const& graphInfo = graphsInfo[graphIdx]; + std::string graphName = std::string(graphInfo->graphName); + + // Create output spec file and open it + char filename[255]; + sprintf(filename, "%s/spec.%s.json", DEBUG_DUMP_TARGET_PATH, graphInfo->graphName); + + FILE* specFile = fopen(filename, "w"); + if (specFile == NULL) { + throw std::runtime_error(std::string("Error opening file : ") + filename); + } + + fprintf(specFile, "{\n\t\"graph_name\" : \"%s\",\n\t\"inputs\" : [\n", graphName.c_str()); + + std::string tensor_name; + std::vector tensorDims; + + for (size_t tensorIdx = 0; tensorIdx < graphInfo->numInputTensors; tensorIdx++) { + auto& tensor = graphInfo->inputTensors[tensorIdx]; + m_qnnApi->getTensorNameAndShape(tensor_name, tensorDims, tensor); + std::string fixed_tensor_name = tensor_name.substr(0, tensor_name.find("_converted")); + QnnUtils::Tensor& spec = m_input_specs[graphName][fixed_tensor_name]; + std::string scales; + std::string offsets; + getQuantParamString(spec.quantParam, scales, offsets); + fprintf(specFile, + stringFmt, + tensor_name.c_str(), + spec.dims.height, + spec.dims.width, + spec.dims.channel, + spec.dims.bitWidth, + scales.c_str(), + offsets.c_str()); + } + + fseek(specFile, -2, SEEK_CUR); // Remove trailing comma + + // Dump out output tensor specs + fprintf(specFile, "\n\t],\n\t\"outputs\" : [\n"); + + for (size_t tensorIdx = 0; tensorIdx < graphInfo->numOutputTensors; tensorIdx++) { + auto& tensor = graphInfo->outputTensors[tensorIdx]; + m_qnnApi->getTensorNameAndShape(tensor_name, tensorDims, tensor); + std::string fixed_tensor_name = tensor_name.substr(0, tensor_name.find("_converted")); + QnnUtils::Tensor& spec = m_output_specs[graphName][fixed_tensor_name]; + std::string scales; + std::string offsets; + getQuantParamString(spec.quantParam, scales, offsets); + fprintf(specFile, + stringFmt, + tensor_name.c_str(), + spec.dims.height, + spec.dims.width, + spec.dims.channel, + spec.dims.bitWidth, + scales.c_str(), + offsets.c_str()); + } + fseek(specFile, -2, SEEK_CUR); // Remove trailing comma + fprintf(specFile, "\n\t]\n}"); + + fclose(specFile); + } +#else + QNN_ERROR( + "Requested dump tensor specs, but DEBUG_DUMP_TARGET_PATH not set. 
Please check nsp-model.h"
+  );
+#endif
+}
+
+template <bool PrintError = true, typename ValType>
+inline bool findTensor(std::unordered_map<std::string, ValType>& map, std::string key) {
+  if (map.find(key) == map.end()) {
+    if constexpr (PrintError == true) QNN_ERROR("Cannot find %s\n", key.c_str());
+    return false;
+  }
+  return true;
+}
+
+template <bool PrintError = true, typename ValType>
+inline ValType* getTensor(std::unordered_map<std::string, ValType>& map, std::string key) {
+  if (map.find(key) == map.end()) {
+    if constexpr (PrintError == true) QNN_ERROR("Cannot find %s\n", key.c_str());
+    return nullptr;
+  }
+  return &map[key];
+}
+
+// Run all validations for the model here so we can exit early
+bool QnnCpuModel::validateModel() {
+  return true;
+}
+
+bool QnnCpuModel::initializeTensorPointers() {
+  auto& input_specs = m_input_specs[model_order.back()];
+  t_input_ids = &input_specs["x0"];
+  t_input_ids_num_token = &input_specs["x1"];
+  t_input_ids_reset_kvcache = &input_specs["x2"];
+  t_input_ids_k_cache = &input_specs["x3"];
+  t_input_ids_v_cache = &input_specs["x4"];
+  t_input_ids_n_past = &input_specs["x5"];
+
+  auto& output_specs = m_output_specs[model_order.back()];
+  t_logits = &output_specs["output_genAI"];
+  t_output_n_past = &output_specs["output_npast"];
+  return true;
+}
+
+void QnnCpuModel::setupInputTensors(const std::vector<uint32_t>& tokens, bool run_bert_mode) {
+  auto start = std::chrono::steady_clock::now();
+
+  size_t num_tokens = m_num_tokens;
+
+  if (tokens.size() > num_tokens) {
+    std::string err_msg = "Called inference with more tokens than model supports: ";
+    err_msg += std::to_string(tokens.size()) + " vs. " + std::to_string(num_tokens);
+    throw std::runtime_error(err_msg);
+  }
+
+  // Grab pointers to buffers for access
+  uint32_t* input_id_buffer = (uint32_t*)getBuffer(t_input_ids);
+  uint32_t* input_id_num_token_buffer = (uint32_t*)getBuffer(t_input_ids_num_token);
+  uint32_t* input_id_reset_kvcache_buffer = (uint32_t*)getBuffer(t_input_ids_reset_kvcache);
+  uint32_t* input_id_n_past_buffer = (uint32_t*)getBuffer(t_input_ids_n_past);
+
+  uint32_t size = 1;
+  for (auto dim : m_input_dim) {
+    size *= dim;
+  }
+
+  std::memset(input_id_buffer, 0, size * sizeof(uint32_t));
+  std::memset(input_id_n_past_buffer, 0, sizeof(uint32_t));
+  std::memset(input_id_num_token_buffer, 0, sizeof(uint32_t));
+  std::memset(input_id_reset_kvcache_buffer, 0, sizeof(uint32_t));
+
+  std::memcpy(input_id_buffer, tokens.data(), tokens.size() * sizeof(uint32_t));
+  *input_id_num_token_buffer = tokens.size();
+  *input_id_n_past_buffer = m_nPast;
+
+  auto stop = std::chrono::steady_clock::now();
+  // QnnUtils::logProfile("setupInputTensors (cpp) took", start, stop);
+}
+
+// Use the QNN API to execute the model
+template <typename T1, typename T2>
+inline bool QnnCpuModel::executeModel(T1& input, T2& output, std::string graph_name) {
+  // With a model instance created and the input tensors populated, run a single
+  // inference on the named graph and collect its outputs.
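+  // Note: the INPUT_DUMP / OUTPUT_DUMP blocks below only dump tensors for the
+  // first five inferences (m_inference_count < 5) to keep debug output bounded.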
+ QNN_DEBUG("Now executing inference for graph %s", graph_name.c_str()); + +#ifdef INPUT_DUMP + if (m_inference_count < 5) dumpTensors(graph_name, true); // Dump input tensors +#endif + + bool ret = m_qnnApi->graphExecute(input, output, graph_name, timeLogs); + + if (ret != true) { + QNN_ERROR("ERROR executing inference: %d for graph %s", ret, graph_name.c_str()); + return false; + } +#ifdef OUTPUT_DUMP + if (m_inference_count < 5) dumpTensors(graph_name, false); // Dump output tensors +#endif + QNN_DEBUG("Execute finished for graph %s", graph_name.c_str()); + + return true; +} + +bool QnnCpuModel::runInferenceHelper( + std::vector& exec_models, + int32_t* wait_time_total, + int32_t* exec_time_total, + bool pipeline_kv_update, + size_t update_size +) { + int32_t exec_time = 0; + int32_t wait_time = 0; + for (auto& graph_name : exec_models) { + { + auto startTime = std::chrono::steady_clock::now(); + if (true != + executeModel(m_input_tensors[graph_name], m_output_tensors[graph_name], graph_name)) + return false; + auto endTime = std::chrono::steady_clock::now(); + exec_time += static_cast( + std::chrono::duration_cast(endTime - startTime) + .count() + ); + } + } + + if (pipeline_kv_update) { + m_nPast += update_size; + } + + *exec_time_total = exec_time; + *wait_time_total = wait_time; + return true; +} + +bool QnnCpuModel::runInference(const std::vector& tokens, bool logits_all) { + __DEBUG("qnn-cpu: run-inference start : n_tokens {}", tokens.size()); + + auto start = std::chrono::steady_clock::now(); + + // Technical note: int32_t can hold upto 596 hours + // Even int16_t should be sufficient here - it holds upto 32.8 seconds + int32_t total_wait_time = 0; + int32_t total_exec_time = 0; + + // Setup inputs for inference + setupInputTensors(tokens, false); + + auto& exec_models = model_order; + if (!runInferenceHelper(exec_models, &total_wait_time, &total_exec_time, false, tokens.size())) + return false; + + prev_run.num_tokens_processed = tokens.size(); + m_inference_count++; + + prev_run.was_bert_mode = false; + prev_run.was_logits_all = logits_all; + + auto stop = std::chrono::steady_clock::now(); + //QnnUtils::logProfile("Run Inference (cpp) took", start, stop); + timeLogs["Run Inference (cpp) "].first += static_cast( + std::chrono::duration_cast(stop - start).count() + ); + timeLogs["Run Inference (cpp) "].second++; + QNN_DEBUG("[TIME] Wait[%d] Exec[%d]\n", total_wait_time, total_exec_time); + return true; +} + +void QnnCpuModel::printFinalLogs() { +#if NSP_LOG_LEVEL > 1 + QNN_DEBUG("Total inference count : %d", m_inference_count); + for (auto& [key, value] : timeLogs) { + QNN_DEBUG("%s : %lf", key.c_str(), value.first / value.second); + } +#endif +} + +bool QnnCpuModel::setKVCacheNPast(size_t n_past) { + if(n_past > m_nPast) { + size_t num_update = n_past - m_nPast; + if (n_past != 0 && num_update > prev_run.num_tokens_processed) { + std::string err_msg = "Requested larger n_past update than #tokens produced by model"; + err_msg += std::to_string(num_update) + " vs. 
" + std::to_string(m_num_tokens); + throw std::runtime_error(err_msg); + } + } + + m_nPast = n_past; + return true; +} + +size_t QnnCpuModel::getDequantLogits(std::vector& dequant_logits, bool logits_all) { + // if model is BERT, always return ALL logits + if (model_output == ModelOutput::EMBEDDINGS) + logits_all = true; + + __DEBUG("qnn-cpu: get-dequant-logits logits_all {}", logits_all); + + auto& logit_spec = m_output_specs[model_order.back()]["output_genAI"]; + float* logitBuf = (float*)getBuffer(logit_spec); + size_t offset = 0; + dequant_logits.clear(); + if (model_output == ModelOutput::LOGITS) { + // if logits_all return [m_numLogits * m_vocab_size] else return [1 * m_vocab_size] + if (!logits_all) { + // Return the last processed token logits i.e. [ ..., [1]] + if (m_numLogits > 1) { + offset = (m_numLogits - 1) * m_vocab_size; + } + } else { + // if m_numLogits > n_tokens_processed, it is left padded, [0, 0, [n_tokens_processed]] + // calculate offset for getting the appropriate logits + if (m_numLogits >= prev_run.num_tokens_processed) { + offset = (m_numLogits - prev_run.num_tokens_processed) * m_vocab_size; + } + } + } +#ifdef DUMP_LOGITS + { + char fname[255]; + sprintf(fname, "%s/logits/%03d", DEBUG_DUMP_TARGET_PATH, m_inference_count); + QnnUtils::writeRawData(getBuffer(logit_spec), getBufferSize(logit_spec), fname); + } +#endif + if (model_output == ModelOutput::LOGITS) { + // logits size = [m_numLogits * m_vocab_size] + // logits might be left padded so, use calculated offset + dequant_logits.reserve((getBufferSize(logit_spec) - (offset * sizeof(float)))); + for (auto i = offset; i < (getBufferSize(logit_spec) / sizeof(float)); ++i) { + dequant_logits.push_back(logitBuf[i]); + } + } else if (model_output == ModelOutput::EMBEDDINGS) { + // embeddings size = [n_tokens_processed * m_embd] + dequant_logits.reserve((prev_run.num_tokens_processed * m_embd * sizeof(float))); + for (auto i = offset; i < ((prev_run.num_tokens_processed * m_embd)); ++i) { + dequant_logits.push_back(logitBuf[i]); + } + } + + return logits_all? 
prev_run.num_tokens_processed : 1; +} + +// TODO: implement save/restore +size_t QnnCpuModel::loadKVCache(const std::string& load_path) { + //TO read the cache file into KV tensor + std::ifstream f(load_path, std::ios::in | std::ios::binary); + if (f.fail()) { + // TODO: replace with proper error handling + __ERROR("qnn-cpu: load-kv errror reading file {}", load_path); + return 0; + } + + CacheFileSpec spec; + f.read((char*)&spec, sizeof(spec)); + if (spec.magic != 0xC0DE) { + __ERROR("qnn-cpu: load-kv expected 0xC0DE found {:#x}", spec.magic); + return 0; + } + // clang-format off + __DEBUG("qnn-cpu: load-kv {{ num_tensors {}, magic {}, dtype {}, n_heads {}, embed_dim {} update_size {} }}", + spec.num_tensors, spec.magic, int(spec.dtype), spec.n_heads, spec.embed_dim, spec.update_size); + // clang-format on + + const int32_t n_valid = static_cast(spec.update_size); + + float* input_id_k_cache_buffer = (float*)getBuffer(t_input_ids_k_cache); + float* input_id_v_cache_buffer = (float*)getBuffer(t_input_ids_v_cache); + + // K$, V$ 4D Tensor {n_layer, n_heads, n_ctx, n_head_dim} + + const size_t copy_size = n_valid * m_head_dim; + const size_t skip_size = (m_ctx_size + 1) * m_head_dim; + + for (int i = 0; i < m_num_layer; i++) { + for(int j = 0; j < m_num_heads; j++) { + f.read((char*)input_id_k_cache_buffer, copy_size * sizeof(float)); + input_id_k_cache_buffer += skip_size; + } + } + + for (int i = 0; i < m_num_layer; i++) { + for(int j = 0; j < m_num_heads; j++) { + f.read((char*)input_id_v_cache_buffer, copy_size * sizeof(float)); + input_id_v_cache_buffer += skip_size; + } + } + + f.close(); + + m_nPast = n_valid; + prev_run.num_tokens_processed = m_nPast; + return spec.update_size; +} + +bool QnnCpuModel::saveKVCache(const std::string& save_path) { + __DEBUG("qnn-cpu: save-kv path {}", save_path); + + std::ofstream f(save_path, std::ios::out | std::ios::binary); + if (f.fail()) { + __ERROR("qnn-cpu: save-kv error opening file : {}", save_path); + throw std::runtime_error("Failed to write to cache file. 
Please re-check path");
+  }
+
+  const uint32_t n_valid = static_cast<uint32_t>(m_nPast);
+  const CacheFileSpec::DataType dtype = CacheFileSpec::DataType::FLOAT32_T;
+
+  // Save the cache file metadata
+  CacheFileSpec spec(m_num_layer * 2, 0xc0de, dtype, 0x0, m_num_heads, m_head_dim, n_valid);
+  f.write((char*)&spec, sizeof(spec));  // as nsp already updated the spec
+  if (n_valid > 0) {
+    // Dump KeyCache and ValueCache
+    float* input_id_k_cache_buffer = (float*)getBuffer(t_input_ids_k_cache);
+    float* input_id_v_cache_buffer = (float*)getBuffer(t_input_ids_v_cache);
+
+    // K$, V$ 4D Tensor {n_layer, n_heads, n_ctx, n_head_dim}
+
+    const size_t copy_size = n_valid * m_head_dim;
+    const size_t skip_size = (m_ctx_size + 1) * m_head_dim;
+    for (int i = 0; i < m_num_layer; i++) {
+      for (int j = 0; j < m_num_heads; j++) {
+        f.write((char*)input_id_k_cache_buffer, copy_size * sizeof(float));
+        input_id_k_cache_buffer += skip_size;
+      }
+    }
+
+    for (int i = 0; i < m_num_layer; i++) {
+      for (int j = 0; j < m_num_heads; j++) {
+        f.write((char*)input_id_v_cache_buffer, copy_size * sizeof(float));
+        input_id_v_cache_buffer += skip_size;
+      }
+    }
+  }
+
+  f.flush();
+  f.close();
+
+  return true;
+}
+
+} // namespace qualla
diff --git a/Genie/Genie/src/qualla/engines/qnn-cpu/cpu-model.hpp b/Genie/Genie/src/qualla/engines/qnn-cpu/cpu-model.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..5d6b606acb4e357752df896fe65d8ae4c0afbe26
--- /dev/null
+++ b/Genie/Genie/src/qualla/engines/qnn-cpu/cpu-model.hpp
@@ -0,0 +1,194 @@
+//==============================================================================
+//
+// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+// All Rights Reserved.
+// Confidential and Proprietary - Qualcomm Technologies, Inc.
+// +//============================================================================== + +#ifndef __QUALLA_QNN_CPU_MODEL_H_ +#define __QUALLA_QNN_CPU_MODEL_H_ + +#include +#include +#include +#include +#include +#include + +#include "qualla/env.hpp" + +#include "QnnApi.hpp" +#include "IOTensor.hpp" +#include "qnn-utils.hpp" + +#define LLAMA_MODEL + +namespace qualla { + +class QnnCpuModel { + enum ExecutionMode { AUTODETECT, BERT_KV, KV_ONLY, BERT_ONLY }; + + Env& _env; + + public: + enum ModelOutput { LOGITS = 0x0, EMBEDDINGS= 0x1 }; + + struct Params { + std::filesystem::path model_basedir; + std::string op_package; + std::string backend_lib; + std::string model_bin_path; + std::string model; + ModelOutput model_output; + + bool use_mmap; + uint32_t ctx_size; + uint32_t n_threads; + size_t n_vocab_size; + uint32_t n_logits; + uint32_t n_layer; + uint32_t n_embd; + uint32_t n_heads; + }; + + const std::filesystem::path model_basedir; + std::vector filename_list; + std::vector model_order; + std::vector bert_model_order; + std::vector kv_model_order; + + std::string op_package; + std::string backend_lib; + std::string model_bin_path; + std::string model; + + long long int spill_fill_buffer_size; + + std::unordered_map model_context; + ModelOutput model_output; + std::map> timeLogs; + std::unique_ptr m_qnnApi; + std::unique_ptr m_ioTensor{nullptr}; + + // Model parameters + + size_t m_ctx_size{1024}; + size_t m_num_layer{0}; + size_t m_embd{0}; + size_t m_num_heads{0}; + size_t m_head_dim{0}; + size_t m_num_tokens{0}; + std::string position_id_path_cos; + std::string position_id_path_sin; + int32_t eos_token_id; + int32_t m_num_threads; + int32_t m_numLogits; + size_t m_vocab_size{32000}; //todo:update vocab size from tokenzier + bool m_use_mmap{false}; + std::vector m_kv_dim; + std::vector m_input_dim; + std::vector m_output_dim; + std::vector m_params; + ExecutionMode m_mode{ExecutionMode::AUTODETECT}; + + // Save some information about the last inference run + struct PreviousRunInfo { + bool was_bert_mode; + size_t num_tokens_processed; + bool was_logits_all; + } prev_run{false, 0}; + + // Model specific variables + uint32_t m_num_graphs; + std::unordered_map m_input_tensors; + std::unordered_map> + m_input_specs; + + std::unordered_map m_output_tensors; + std::unordered_map> + m_output_specs; + + // Store some pointers for easier access + QnnUtils::Tensor* t_logits; + QnnUtils::Tensor* t_output_n_past; + QnnUtils::Tensor* t_input_ids; + QnnUtils::Tensor* t_input_ids_num_token; + QnnUtils::Tensor* t_input_ids_reset_kvcache; + QnnUtils::Tensor* t_input_ids_k_cache; + QnnUtils::Tensor* t_input_ids_v_cache; + QnnUtils::Tensor* t_input_ids_n_past; + float* dequant_logits_ptr{nullptr}; + + // Store pointers for bert + QnnUtils::Tensor* b_logits; + QnnUtils::Tensor* b_input_ids; + QnnUtils::Tensor* b_attn_mask; + +#ifdef LLAMA_MODEL + // LLama specific variables + uint16_t position_id_dims; // Derived from model in initializeTensorPointers + // uint16_t position_ids_sin[1024][64]; + // uint16_t position_ids_cos[1024][64]; // RoPE Embedding tensors. 
Loaded from datafile + std::unique_ptr position_ids_sin; // Initialized in load_precomputed_position_ids + std::unique_ptr position_ids_cos; // Initialized in load_precomputed_position_ids + + QnnUtils::Tensor* t_position_ids_sin; + QnnUtils::Tensor* t_position_ids_cos; +#else + QnnUtils::Tensor* t_position_ids; +#endif + + // n_past defines number of population of kvcache + size_t m_nPast{0}; + + // Keep track of inference count + int m_inference_count = 0; + + QnnCpuModel(Env& env, const Params& params); + ~QnnCpuModel(); + + bool initializeModel(void); + bool validateModel(void); + bool initializeIOTensors(void); + bool initializeTensorPointers(); + + void setupInputTensors(const std::vector& tokens, bool run_bert_mode); + + template + inline bool executeModel(T1& input, T2& output, std::string graph_name); + + void dumpTensors(std::string graph_name, bool dump_input); + void dumpTensorSpecs(); + + void printFinalLogs(); + + bool runInference(const std::vector& tokens, bool logits_all); + bool setKVCacheNPast(size_t n_past); + + size_t getDequantLogits(std::vector& logits, bool logits_all = false); + + size_t loadKVCache(const std::string& save_path); + bool saveKVCache(const std::string& load_path); + + private: + bool m_mmap_context_bins = false; // mmap context binary files instead of reading them in memory + // Internal functions to separate different runInference logic + bool runInferenceHelper( + std::vector& exec_models, + int32_t* wait_time_total, + int32_t* exec_time_total, + bool pipeline_kv_update, + size_t update_size + ); + + inline void* getBuffer(QnnUtils::Tensor& spec) { return m_ioTensor->getBuffer(spec.tensor); } + inline void* getBuffer(QnnUtils::Tensor* spec) { return m_ioTensor->getBuffer(spec->tensor); } + inline size_t getBufferSize(QnnUtils::Tensor& spec) { return spec.dims.getSize(); } + inline size_t getBufferSize(QnnUtils::Tensor* spec) { return spec->dims.getSize(); } + // TODO: Seems to be some issue with m_ioTensor->getBufferSize when sharing buffers +}; + +} // namespace qualla + +#endif // __QUALLA_QNN_CPU_MODEL_HPP_ diff --git a/Genie/Genie/src/qualla/engines/qnn-htp.cpp b/Genie/Genie/src/qualla/engines/qnn-htp.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5e825d009be6464474ca9c25781a5f7837c4d70a --- /dev/null +++ b/Genie/Genie/src/qualla/engines/qnn-htp.cpp @@ -0,0 +1,406 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#include "qnn-htp.hpp" + +#define __INFO(__fmt, ...) _env.logger().post(Logger::INFO, fmt::format(__fmt, ##__VA_ARGS__)) +#define __WARN(__fmt, ...) _env.logger().post(Logger::WARN, fmt::format(__fmt, ##__VA_ARGS__)) +#define __ERROR(__fmt, ...) _env.logger().post(Logger::ERROR, fmt::format(__fmt, ##__VA_ARGS__)) +#define __KPIS(__fmt, ...) \ + _env.logger().post(Logger::ENGINE_KPIS, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); }) +#define __DEBUG(__fmt, ...) \ + _env.logger().post(Logger::ENGINE_DEBUG, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); }) +#define __TRACE(__fmt, ...) 
\ + _env.logger().post(Logger::ENGINE_TRACE, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); }) + +namespace qualla { + +namespace fs = std::filesystem; + +bool NspEngine::load() { + if (_model) return true; + + qualla::Timer start; + + __INFO("qnn-htp: loading model"); + + _model = std::make_unique(_env, _params); + + // Load model + if (true != _model->initializeModel()) { + throw std::runtime_error("Failure to initialize model"); + } + + // Initialize IO Tensor buffers + if (true != _model->initializeIOTensors()) { + throw std::runtime_error("Error in setting up IO Tensors"); + } + + if (true != _model->validateModel()) { + // throw std::runtime_error("Error validating model. Please check your I/O"); + } + + __INFO("qnn-htp: model has been validated!"); + + if (true != _model->initializeKVManager()) { + throw std::runtime_error("Error initializing KVCache managers"); + } + + if (true != _model->initializeTensorPointers()) { + throw std::runtime_error("Error : Could not find I/O tensors in loaded graphs"); + } + + if (true != _model->calculate_rope_embeddings()) { + throw std::runtime_error("Error : Could not load precomputed position ids"); + } + + // Initialize LoRA + if (_model->lora_conf == LoraConfigType::LORA_INPUT_WEIGHT_ENABLE) { + if (true != _model->flushLoraWeightsBuffers()) + throw std::runtime_error("Error : Failed to flush the lora buffers"); + } + + if (true != _model->load_lmhead_weight_as_input()) { + throw std::runtime_error("Error : Could not load lmhead weight input"); + } + + _kpis.load.update(start.elapsed_usec()); + + return true; +} + +bool NspEngine::unload() { + qualla::Timer start; + + __DEBUG("qnn-htp: unloading model"); + _model.reset(nullptr); + + _kpis.unload.update(start.elapsed_usec()); + + return true; +} + +NspEngine::NspEngine(Context& ctx, const qualla::json& json) : Engine(ctx, "qnn-htp", json) { + qualla::Timer start; + + using FF = Feature::Flags; + _features = FF::OUTPUT_LOGITS | FF::SAVE_RESTORE | FF::DYNAMIC_LOAD | FF::OUTPUT_EMBEDDINGS; + + __DEBUG("qnn-htp: init start"); + + qualla::Config conf(json, _type + "-engine:"); + + // Parse config + _params.model_basedir = conf.optional("model-basedir", ""); + if (_params.model_basedir.is_relative()) { + _params.model_basedir = _env.path().models / _params.model_basedir; + _params.model_basedir = _params.model_basedir.make_preferred(); + } + _params.model_list = conf.mandatory>("model-list"); + // Parse model architecture + std::string model_architecture = conf.optional("model-architecture-type", "decoder"); + if (model_architecture == "decoder") + _params.modelArchitectureType = ModelArchitectureType::DECODER; + else if (model_architecture == "encoder") + _params.modelArchitectureType = ModelArchitectureType::ENCODER; + else + throw std::runtime_error( + "Only Encoder and Decoder architectures are supported. 
Invalid architecture supplied : " + + model_architecture + ); + + _params.backend_lib = conf.optional("backend-lib", ""); + _params.backend_ext_conf = conf.optional("backend-ext-conf", ""); + _params.ctx_size = _ctx.size(); + _params.mmap_budget = conf.optional("mmap-budget", 0); + _params.use_mmap = conf.optional("use-mmap", true); + _params.use_async_Init = conf.optional("use-async-Init", true); + _params.spill_fill_bufsize = conf.optional("spill-fill-bufsize", 0); + _params.kv_dim = conf.optional("kv-dim", 128); + _params.n_embd = _ctx.n_embd(); + _params.pad_token = _ctx.pad(); + _params.variant_latency = std::map(); + _params.disable_kv_cache = conf.optional("disable-kv-cache", false); + _params.pooled_output = conf.optional("pooled-output", true); + _params.lmhead_weight_dir = conf.optional("lmhead-weight-dir", ""); + _params.graph_switching = conf.optional("enable-graph-switching", false); + _params.exec_select_graphs = + conf.optional>("execute-select-graphs", {}); + _params.load_select_graphs = conf.optional("load-select-graphs", false); + + qualla::json latencies = conf.optional("latency-map", {}); + for (auto& [variant, latency] : latencies.items()) + _params.variant_latency[std::stoi(variant)] = latency; + _params.kv_update_method = conf.optional( + "kv-update-method", (conf.optional("pos-id-dim", 64) == 40) ? "SHIFT_CONCAT" : "POINTER_SHIFT" + ); + _params.n_threads = conf.optional("n-threads", 4); + if(_params.disable_kv_cache){ + _params.n_threads = 0; + } + _params.poll = conf.optional("poll", false); + + // Positional encodings parameters + if (conf.json.contains("positional-encoding")) { + try { + conf.json["positional-encoding"].get_to(_params.positional_encoding_params); + } catch (const std::runtime_error& e) { + State::fatal(fmt::format("Error in positional-encoding - {}", e.what())); + throw std::runtime_error(State::error()); + } + } else { // For Backward compatibility. May be removed in future releases + // __WARN("Using depracated positional encoding config. Please switch to positional-encoding"); + auto &pos_type = _params.positional_encoding_params; + if(_params.modelArchitectureType == ModelArchitectureType::DECODER) { + pos_type.type = PositionalEncoding::ROPE; + pos_type.rope_params.dims = conf.optional("pos-id-dim", 64); + pos_type.rope_params.dims = conf.optional("pos-id-dims", pos_type.rope_params.dims); + pos_type.rope_params.theta = conf.optional("rope-theta", 10000.0); + pos_type.rope_params.rope_scaling = conf.optional("rope-scaling", RopeScalingParams()); + } + else{ + pos_type.type = PositionalEncoding::ABSOLUTE; + // Other parameters for ENCODER ONLY model doesn't matter. 
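+      // For reference, the deprecated flat keys handled in this fallback are:
+      //   "pos-id-dim" / "pos-id-dims" (RoPE dims), "rope-theta", "rope-scaling".
+      // New configs should use the structured "positional-encoding" object instead
+      // (schema defined by its from_json; e.g. { "type": "rope", ... } - illustrative only).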
+  }
+  }
+  // Default LoRA is Disabled
+  uint8_t lora_version = conf.optional("lora-version", 0);
+  switch (lora_version) {
+    case 0: _params.lora_config_type = LoraConfigType::LORA_DISABLE; break;
+    case 1: _params.lora_config_type = LoraConfigType::LORA_INPUT_WEIGHT_ENABLE; break;
+    case 2: _params.lora_config_type = LoraConfigType::LORA_ADAPTER_WEIGHT_ENABLE; break;
+    default: throw std::runtime_error("Lora Version Undefined."); break;
+  }
+  // LoRA adapter setting
+  qualla::json lora_conf = conf.optional("lora", {});
+  if (lora_conf.size() != 0) {
+    if (lora_conf.is_array()) {
+      for (auto lc : lora_conf) {
+        std::string lnm = lc["adapter-name"];
+        _params.lora_param[lnm].lora_name = lnm;
+        _params.lora_param[lnm].alpha_tensor_name = lc["alpha-tensor-name"];
+        _params.lora_param[lnm].alpha_tensor_val = 0.0f;
+        if (lc.contains("alpha-tensor-value")) {
+          _params.lora_param[lnm].alpha_tensor_val = lc["alpha-tensor-value"];
+        }
+        if (_params.lora_config_type == LoraConfigType::LORA_ADAPTER_WEIGHT_ENABLE) {
+          std::string basedir = "";
+          if (lc.contains("binsection-basedir")) {
+            basedir = lc["binsection-basedir"];
+          }
+          uint32_t n = lc["bin-sections"].size();
+          for (uint32_t i = 0; i < n; i++) {
+            auto binSec = lc["bin-sections"].get<std::vector<std::string>>();
+            fs::path binsection_path = fs::path(binSec[i]);
+            if (binsection_path.is_relative()) binsection_path = basedir / fs::path(binSec[i]);
+            if (!fs::is_regular_file(binsection_path)) {
+              __ERROR("qnn-htp: Can't access Lora binsection adapter : {}",
+                      binsection_path.string());
+              throw std::runtime_error(
+                  "qnn-htp: Can't access adapter file : " + binsection_path.string()
+              );
+            }
+            _params.lora_param[lnm].binsection_list.push_back(binsection_path.string());
+          }
+        }
+        else if (_params.lora_config_type == LoraConfigType::LORA_INPUT_WEIGHT_ENABLE) {
+          _params.lora_param[lnm].path = lc["path"];
+        }
+      }
+    }
+  }
+
+  _params.embedding_length = _ctx.embeddingLength();
+  _params.embedding_datatype = _ctx.embeddingDatatype();
+
+  // cpumask needs to be a string because JSON RFC doesn't allow for hex ints.
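+  // Example: "cpumask": "0xF0" selects cores 4-7 (one bit per core). Base 0 in
+  // std::stoull below auto-detects decimal, octal ("0...") and hex ("0x...") strings.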
+ std::string cpumask = conf.optional("cpumask", "0"); + _params.cpumask = std::stoull(cpumask, nullptr, 0); + + // Debug flags + _params.debug_path = conf.optional("debug-path", "qualla_debug"); + _params.debug_specs = conf.optional("debug-specs", false); + _params.debug_tensors = conf.optional("debug-tensors", false); + _params.debug_outputs = conf.optional("debug-outputs", false); + _params.debug_qnn = conf.optional("debug-qnn", false); + + if (!conf.optional("dynamic-load", false)) { + load(); + } +}; + +NspEngine::~NspEngine() { + unload(); +} + +bool NspEngine::updateKV(size_t n_past) { + return updateKV(n_past, {}); +} + +bool NspEngine::updateKV(size_t n_past, const std::vector& selected) { + if (!_model && !load()) return false; + + qualla::Timer start; + + if (n_past > _ctx.size()) { + __ERROR("qnn-htp: context size exceeded : n_past {}", n_past); + State::error("context size exceeded"); + return false; + } + + if (!_model->setKVCacheNPast(n_past, selected)) { + __ERROR("qnn-htp: Error updating KV$"); + return false; + } + + __DEBUG("qnn-htp: Dispatched KV$ Update (n_past={}) in {} usec", n_past, start.elapsed_usec()); + + _kpis.update_kv.update(start.elapsed_usec()); + + return true; +} + +size_t NspEngine::process( + const std::vector& tokens, + std::vector& logits, + bool logits_all +) { + return process(tokens, {}, logits, logits_all); +} + +size_t NspEngine::process( + const std::vector& tokens, + const std::vector& attention_map, + std::vector& logits, + bool logits_all +) { + if (!_model && !load()) return 0; + + qualla::Timer start; + + size_t n_tok = _model->runInference(tokens, attention_map, logits, logits_all); + if (n_tok == 0) { + State::error("qnn-htp : runInference failed!"); + } + + _kpis.process.update(start.elapsed_usec()); + + return n_tok; +} + +size_t NspEngine::process( + std::vector& embeddings, + const std::vector& attention_map, + std::vector& logits, + bool logits_all +) { + if (!_model && !load()) return 0; + qualla::Timer start; + + __DEBUG("qnn-htp: inference start: n_tokens {}", embeddings.size()); + + size_t n_tok = _model->runInference( + embeddings, attention_map, logits, logits_all + ); + if (n_tok == 0) { + State::error("qnn-htp : runInference failed!"); + } + __DEBUG("qnn-htp: inference complete : {} usec", start.elapsed_usec()); + + _kpis.process.update(start.elapsed_usec()); + + return n_tok; +} + +bool NspEngine::cacheEosEmbedding(std::vector& eosEmbedding) { + if (!_model && !load()) { + return false; + } + return _model->cacheEosEmbedding(eosEmbedding); +}; + +size_t NspEngine::getEmbeddingBufferSize() { + return _model->getEmbeddingBufferSize(); +} + +bool NspEngine::set(qualla::json data) { + bool ret = false; + + if (data.contains("kv-prefix-skip")) { + _model->_size_to_skip_kv_prefix = data["kv-prefix-skip"].get(); + ret = true; + } + + if (data.contains("kv-prefix-offset")) { + _model->_offset_to_apply_kv_prefix = data["kv-prefix-offset"].get(); + ret = true; + } + return ret; +} + +qualla::json NspEngine::get() { + return {{"kv-prefix-skip", _model->_size_to_skip_kv_prefix}, + {"kv-prefix-offset", _model->_offset_to_apply_kv_prefix}}; +} + + +qualla::InputType NspEngine::getInputType(){ + return _model->m_inputType; +} + +size_t NspEngine::restore(const std::string& name) { + if (!_model && !load()) return 0; + + fs::path cache_path = std::filesystem::path(name) / fmt::format("kv-cache.{}.qnn-htp", _role); + return _model->loadKVCache(cache_path.string()); +} + +bool NspEngine::save(const std::string& name) { + if (!_model && 
!load()) return false; + + fs::path cache_path = std::filesystem::path(name) / fmt::format("kv-cache.{}.qnn-htp", _role); + return _model->saveKVCache(cache_path.string()); +} + +void NspEngine::reset() { + if (!_model && !load()) return; + + // It's enough to just drop the KV$ + updateKV(0); +} + +// Registrator instance +static OnLoad regy([]() { + Engine::__register("qnn-htp", [](Context& ctx, const json& conf) { + return (Engine*)new NspEngine(ctx, conf); + }); +}); +void needQnnHtpEngine() {} + +bool NspEngine::applyLoraAdapter(std::string lora_adapter_name) { + + if (!_model) { + __ERROR("qnn-htp: applyLoraAdapter failed model not initialized"); + return false; + } + if (_model->lora_conf == LoraConfigType::LORA_INPUT_WEIGHT_ENABLE) { + return _model->applyLoraWeights(lora_adapter_name); + } + else + return _model->applyLoraAdapter(lora_adapter_name); +} + +bool NspEngine::applyLoraStrength(std::string tensor_name, float tensor_val) { + if (!_model) { + __ERROR("qnn-htp: applyLoraStrength failed model not initialized"); + return false; + } + return _model->applyLoraStrength(tensor_name, tensor_val); +} + +} // namespace qualla diff --git a/Genie/Genie/src/qualla/engines/qnn-htp.hpp b/Genie/Genie/src/qualla/engines/qnn-htp.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4b04bb1911cea93c67b2d6c9831837baee2b9e5e --- /dev/null +++ b/Genie/Genie/src/qualla/engines/qnn-htp.hpp @@ -0,0 +1,88 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#ifndef __QNN_HTP_H__ +#define __QNN_HTP_H__ + +#include +#include + +#include +#include +#include +#include + +#include + +#include "nsp-model.hpp" + +namespace qualla { + +class NspEngine : public Engine { + protected: + QnnNspModel::Params _params; + + std::unique_ptr _model; + + public: + NspEngine(Context& ctx, const qualla::json& json); + virtual ~NspEngine(); + + virtual size_t process( + const std::vector& tokens, + std::vector& logits, + bool logits_all + ) override; + + virtual size_t process( + const std::vector& tokens, + const std::vector& attention_map, + std::vector& logits, + bool logits_all + ) override; + + virtual size_t process( + std::vector& embeddings, + const std::vector& attention_map, + std::vector& logits, + bool logits_all + ) override; + + /** Stores a precomputed EOS embedding vector. 
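+   *  Forwards to the underlying QnnNspModel, loading the model on demand first
+   *  if dynamic-load deferred initialization.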
*/ + virtual bool cacheEosEmbedding(std::vector& eosEmbedding) override; + + void getInputQuantParam(double& scale, int& offset) { + + auto tmp = _model->t_input_ids->quantParam[0]; + scale = tmp.scale; + offset = tmp.offset; + } + + virtual qualla::InputType getInputType() override; + + virtual size_t getEmbeddingBufferSize() override; + + virtual bool updateKV(size_t n_past) override; + virtual bool updateKV(size_t n_past, const std::vector& selected) override; + virtual bool save(const std::string& name) override; + virtual size_t restore(const std::string& name) override; + virtual void reset() override; + + virtual bool set(qualla::json data) override; + virtual qualla::json get() override; + + virtual bool load() override; + virtual bool unload() override; + + virtual bool applyLoraAdapter(std::string lora_adapter_name) override; + virtual bool applyLoraStrength(std::string tensor_name, float tensor_val) override; +}; + +} // namespace qualla + +#endif diff --git a/Genie/Genie/src/qualla/engines/qnn-htp/nsp-graph.cpp b/Genie/Genie/src/qualla/engines/qnn-htp/nsp-graph.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d0de949c4524338e6c4a3f1af773bbf945593684 --- /dev/null +++ b/Genie/Genie/src/qualla/engines/qnn-htp/nsp-graph.cpp @@ -0,0 +1,304 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#include "qualla/detail/timer.hpp" + +#include "nsp-model.hpp" +#include "nsp-graph.hpp" + +#include + +#include "fmt/format.h" +#include "fmt/ranges.h" + +// Copied from threadpool.cpp +#if defined(_WIN32) + #define NOGDI + #include "windows.h" + +static int sched_yield(void) { + Sleep(0); + return 0; +} +#else + #include +#endif + +#define __INFO(__fmt, ...) _env.logger().post(Logger::INFO, fmt::format(__fmt, ##__VA_ARGS__)) +#define __WARN(__fmt, ...) _env.logger().post(Logger::WARN, fmt::format(__fmt, ##__VA_ARGS__)) +#define __ERROR(__fmt, ...) _env.logger().post(Logger::ERROR, fmt::format(__fmt, ##__VA_ARGS__)) +#define __KPIS(__fmt, ...) \ + _env.logger().post(Logger::ENGINE_KPIS, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); }) +#define __DEBUG(__fmt, ...) \ + _env.logger().post(Logger::ENGINE_DEBUG, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); }) +#define __TRACE(__fmt, ...) \ + _env.logger().post(Logger::ENGINE_TRACE, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); }) +#define __KVTRACE(__fmt, ...) \ + _env.logger().post(Logger::KVMANAGER_TRACE, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); }) +namespace qualla { + +// GraphVariant is a self-contained graph. Represents one specific QNN Model +GraphVariant::GraphVariant(GraphInfo_t* g_info, Qnn_ContextHandle_t qnn_ctx, int32_t n_ctx, std::map& layerNames) + : ctx_size(n_ctx), graph_name(g_info->graphName), graph_info(g_info), context_handle(qnn_ctx), m_layerNames(layerNames) { + //TRACE("Parsing %s with ctx_size %d", this->graph_name.c_str(), n_ctx); + + for (bool io : {true, false}) { + uint32_t n_tensors = (io) ? graph_info->numInputTensors : graph_info->numOutputTensors; + auto tensor_wrappers = (io) ? graph_info->inputTensors : graph_info->outputTensors; + auto& tensor_specs = (io) ? 
input_specs : output_specs; + for (size_t tensor_idx = 0; tensor_idx < n_tensors; tensor_idx++) { + + TensorWrapper& tensor = tensor_wrappers[tensor_idx]; + std::string tensor_name = QnnApi::getTensorName(tensor); + + std::vector tensor_dims; + if (!QnnApi::getTensorShape(tensor_dims, tensor)) + throw std::runtime_error("Couldn't get tensor shape : " + tensor_name); + std::vector quantParams; + if (!QnnApi::getTensorQuantParams(&tensor, quantParams)) { + quantParams.emplace_back(0, 0); + } + tensor_specs[tensor_name] = + QnnUtils::Tensor(&tensor, tensor_dims, quantParams); + } + } + + n_tokens = static_cast(determineGraphInputSize()); +} + +// Attempt to determine input size from purely graph IO and context size +// The easiest way is using input_ids. Else, attention_mask/position_ids can also be used +size_t GraphVariant::determineGraphInputSize() { + QnnUtils::Tensor* tensor; + if (m_layerNames[LayerType::INPUT] == "inputs_embeds") { + if (!!(tensor = getInput(m_layerNames[LayerType::ATTN_MASK]))) return tensor->dims.getNumElements() / ctx_size; + } else { + if (!!(tensor = getInput(m_layerNames[LayerType::INPUT]))) return tensor->dims.getNumElements(); + // Use past_key_out tensor to find input size + // The last dimension of past_key_out tensor will always be the input size + for (auto& [tname, qtensor] : output_specs) { + if (!tname.starts_with("past_key")) continue; + return static_cast(qtensor.dims.channel); + } + } + throw std::runtime_error("Unexpected model. Couldn't determine m_num_tokens"); +} + +bool GraphVariant::refreshTensorQuantParams() { + for (bool io : {true, false}) { + uint32_t n_tensors = (io) ? graph_info->numInputTensors : graph_info->numOutputTensors; + auto tensor_wrappers = (io) ? graph_info->inputTensors : graph_info->outputTensors; + auto& tensor_specs = (io) ? 
input_specs : output_specs; + for (size_t tensor_idx = 0; tensor_idx < n_tensors; tensor_idx++) { + + TensorWrapper& tensor = tensor_wrappers[tensor_idx]; + std::string tensor_name = QnnApi::getTensorName(tensor); + std::vector quantParams; + if (!QnnApi::getTensorQuantParams(&tensor, quantParams)) { + quantParams.emplace_back(0, 0); + } + tensor_specs[tensor_name].quantParam = quantParams; + } + } + return true; +} + +QnnNspGraph::QnnNspGraph( + int idx, + Env& env, + int32_t n_ctx, + QnnApi* qnnApi, + IOTensor* ioTensor, + bool threaded +) + : _idx(idx), _env(env), ctx_size(n_ctx), g_qnn_api(qnnApi), g_buffer_mgr(ioTensor), + _threaded(threaded) { + + if (_threaded) { + _lock = new std::mutex(); + _lock_cv = new std::condition_variable(); + } + __DEBUG("qnn-htp: new-NSP-graph : n_ctx {}", n_ctx); +} + +QnnNspGraph::~QnnNspGraph() { + __DEBUG("qnn-htp: del-NSP-graph"); + if (kvmanager != nullptr) delete kvmanager; + if (_threaded) { + delete _lock; + delete _lock_cv; + } +} + +// Parse a loaded GraphInfo_t +bool QnnNspGraph::addGraph(GraphVariant* graph_spec) { + // TRACE("%d", graph_spec->n_tokens); + const int32_t n_tok = graph_spec->n_tokens; + // QNN_DEBUG("Searching for n_tokens=%d count=%lu ctx_size=%d", n_tok, variants.count(n_tok), ctx_size); + if (variants.find(n_tok) != variants.end()) { + printAvailableConfigs(); + __ERROR("qnn-htp: addGraph detected duplicate : {} v {}", n_tok, variants[n_tok]->n_tokens); + throw std::runtime_error("qnn-htp: duplicate graph found, likely overflow occured"); + } + + variants[n_tok] = graph_spec; + return true; +} + +void QnnNspGraph::printAvailableConfigs() { + std::stringstream config_stream; + for (auto& [config, _] : variants) + config_stream << config << ", "; + + __DEBUG("config = [{}]", config_stream.str()); +} + +void QnnNspGraph::dumpTensors(GraphVariant* const variant, bool mode, int n_inference) const { + if (n_inference >= 10) return; + + QnnUtils::TensorMap& tensor_specs = (mode) ? variant->input_specs : variant->output_specs; + std::string prefix = fmt::format("{}/{}/{:03d}", _debug_path, variant->graph_name, n_inference); + for (auto it = tensor_specs.begin(); it != tensor_specs.end(); ++it) { + auto tname = it->first; + auto tspec = it->second; + std::string fname = fmt::format("{}_{}_{}", prefix, (mode) ? 
"in" : "out", tname); + __TRACE("Dumping {} from {:p}", fname, g_buffer_mgr->getBuffer(tspec.tensor)); + QnnUtils::writeRawData(g_buffer_mgr->getBuffer(tspec.tensor), tspec.dims.getSize(), fname); + } +} + +bool QnnNspGraph::registerPointerShift(int32_t variant, int32_t ptr_offset) { + __TRACE("Called QnnNspGraph::registerPointerShift"); + if (_kv_update_method != POINTER_SHIFT) return true; + if (kvmanager->getNumKVTensors() == 0) return true; + qualla::Timer start; + + std::map> allocs; + + qualla::GraphVariant* graph_variant = variants.at(variant); + if (variant == ctx_size) { + // Re-map AR-c model outputs to initial state + for (auto& [tname, tspec] : graph_variant->output_specs) { + if (!tname.starts_with("past_")) continue; // Only process KV$ + auto& [alloc_idx, offset] = tensor_alloc_info->at(tname); + allocs[tname] = {alloc_idx, offset, tspec.dims.getAlignedSize()}; + } + } else { + + // For AR-n models, map input KV$ to appropriate offset + for (auto& [tname, tspec] : graph_variant->input_specs) { + if (!tname.starts_with("past_")) continue; // Only process KV$ + auto out_name = tname.substr(0, tname.rfind("_")).append("_out"); + + auto& [alloc_idx, offset] = tensor_alloc_info->at(out_name); + const bool is_key = tname.starts_with("past_key"); + const int32_t extra_offset = ptr_offset * (is_key ? 1 : kvmanager->_n_embed); + allocs[tname] = {alloc_idx, offset + extra_offset, tspec.dims.getAlignedSize()}; + } + } + + if (!g_buffer_mgr->mapFusedBufferOffset( + graph_variant->graph_info, graph_variant->context_handle, allocs + )) { + __ERROR("Error mapping tensor to allocation buffers"); + return false; + } + + __DEBUG("qnn-htp: pointerShift complete : {} usec", start.elapsed_usec()); + return true; +} + +void QnnNspGraph::registerKVManager(NewNSPKVManager* mgr) { + kvmanager = mgr; + if (mgr->getNumKVTensors() == 0 && _threaded) { + delete _lock; + delete _lock_cv; + _threaded = false; + } + mgr->registerPointerOffsetFn([this](int32_t variant, int32_t ptr_offset) { + return this->registerPointerShift(variant, ptr_offset); + }); +} + +bool QnnNspGraph::execute(int n_tokens, int n_inference, int32_t wait_count) { + GraphVariant* variant = variants.at(n_tokens); // Assume n_tokens exists in variants + run_wait_time = run_exec_time = 0; // Clear out the timer + + qualla::Timer timer; + + waitForLock("QnnNspGraph::execute", wait_count, false); + run_wait_time += timer.elapsed_usec(); + + // Register pointer shift + GraphInfo_t* const graph = variant->graph_info; + + if (_debug_tensors) dumpTensors(variant, true, n_inference); // Dump input tensors + + timer.reset(); // Reset the timer to calculate execution time + std::map> timeLogs; + if (!g_qnn_api->graphExecute( + graph->inputTensors, graph->outputTensors, graph->graphName, timeLogs + )) { + __ERROR("qnn-htp: graph-exec failed for {}", graph->graphName); + return false; + } + + run_exec_time += timer.elapsed_usec(); + + if (_debug_tensors) dumpTensors(variant, false, n_inference); // Dump output tensors + + timer.reset(); + releaseLock("QnnNspGraph::execute"); + run_wait_time += timer.elapsed_usec(); + return true; +} + +void QnnNspGraph::waitForLock(std::string requester) { + if (!_threaded) return; + __KVTRACE("qnn-lock : graph[{}] requested : {}", _idx, requester); + _lock->lock(); + __KVTRACE("qnn-lock : graph[{}] locking : {}", _idx, requester); +} + +void QnnNspGraph::waitForLock(std::string requester, int32_t wait_counter, bool poll) { + if (!_threaded) return; + __KVTRACE("qnn-lock : graph[{}] requested : {} (count={})", 
_idx, requester, wait_counter); + + if (poll) { + _lock->lock(); + // Busy wait until a specific update is complete + while (_counter < wait_counter) { + _lock->unlock(); + sched_yield(); + _lock->lock(); + } + } else { + std::unique_lock lk(*_lock); + _lock_cv->wait(lk, [&] { + __KVTRACE("qnn-lock : graph[{}] trying ({} >= {})", _idx, _counter, wait_counter); + return _counter >= wait_counter; + }); + lk.release(); + } + + __KVTRACE("qnn-lock : graph[{}] locking : {} (count={})", _idx, requester, wait_counter); + return; +} + +void QnnNspGraph::releaseLock(std::string requester) { + if (!_threaded) return; + __KVTRACE("qnn-lock : graph[{}] releasing : {} (count={})", _idx, requester, _counter); + _lock->unlock(); + _lock_cv->notify_one(); +} + +void QnnNspGraph::wakeUpLock() { + if (!_threaded) return; + _lock_cv->notify_one(); +} +} // namespace qualla diff --git a/Genie/Genie/src/qualla/engines/qnn-htp/nsp-graph.hpp b/Genie/Genie/src/qualla/engines/qnn-htp/nsp-graph.hpp new file mode 100644 index 0000000000000000000000000000000000000000..058a66f41bbf0c764aed67b5cd219b34f2146daa --- /dev/null +++ b/Genie/Genie/src/qualla/engines/qnn-htp/nsp-graph.hpp @@ -0,0 +1,142 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#pragma once + +#include "qualla/env.hpp" + +#include "QnnApi.hpp" +#include "IOTensor.hpp" +#include "qnn-utils.hpp" +#include "nsp-kvmanager.hpp" + +namespace qualla { +enum class LayerType { + INPUT, + OUTPUT, + ATTN_MASK, + POS_SIN, + POS_COS, + POS_IDS, + TOKEN_TYPE_IDS, + POOL_OUTPUT, + SEQ_OUTPUT +}; +struct GraphVariant { + int32_t n_tokens; + int32_t ctx_size{-1}; + std::string graph_name; + + // QNN API specific variables + GraphInfo_t* graph_info; + Qnn_ContextHandle_t context_handle; + + QnnUtils::TensorMap input_specs; + QnnUtils::TensorMap output_specs; + + std::map& m_layerNames; + + GraphVariant() = delete; + GraphVariant(GraphInfo_t* g_info, Qnn_ContextHandle_t qnn_ctx, int32_t n_ctx, std::map& layerNames); + QnnUtils::Tensor* getTensor(const std::string& tensor_name) { + QnnUtils::Tensor* ret = getInput(tensor_name); + return (ret != nullptr) ? ret : getOutput(tensor_name); + } + QnnUtils::Tensor* getInput(const std::string& tensor_name) { + return input_specs.contains(tensor_name) ? &input_specs.at(tensor_name) : nullptr; + } + QnnUtils::Tensor* getOutput(const std::string& tensor_name) { + return output_specs.contains(tensor_name) ? 
&output_specs.at(tensor_name) : nullptr; + } + + bool refreshTensorQuantParams(); + + private: + size_t determineGraphInputSize(); +}; + +/** + * The idea behind QnnNspGraph is to represent "common" graphs + * For instance, both BERT-mode and KV$-mode are the same graph with different input sizes + * QnnNspGraph will contain and manage both BERT-split-n and KV$mode-split-n + * I/O tensors are mostly shared between these graphs, and can be managed collectively +*/ +class QnnNspGraph { + private: + int _idx; + Env& _env; + + int32_t ctx_size{-1}; + + // Useful pointers for graph execution (managed by NSPModel) + QnnApi* g_qnn_api; + IOTensor* g_buffer_mgr; + + bool _threaded; + std::mutex* _lock; // Locks whenever KV$ is being used or updated + std::condition_variable* _lock_cv; // Wake up _lock when jobs are complete + + KVManagerMode _kv_update_method{POINTER_SHIFT}; + + int32_t run_wait_time, run_exec_time; // Add more stats into a struct + + // Debug mode settings + bool _debug_specs{false}; + bool _debug_tensors{false}; + std::string _debug_path; + + public: + int32_t _counter{-1}; + NewNSPKVManager* kvmanager{nullptr}; + + // TODO: Remove this reference + std::map>* tensor_alloc_info; + + // Keys represent input_id size (1<=input_size<=ctx_size) + // Values are graph description for that input_id size + std::map variants; + + QnnNspGraph( + int idx, + Env& env, + int32_t n_ctx, + QnnApi* qnnApi, + IOTensor* ioTensor, + bool threaded + ); + ~QnnNspGraph(); + + bool addGraph(GraphVariant* graph_spec); + void printAvailableConfigs(); + void registerKVManager(NewNSPKVManager* mgr); + + // Given an input size, picks the correct model among the ones available + // This is likely not easy to implement as there's implications on KV$ management + size_t getOptimalModelInputSize(size_t n_past, size_t input_size) { return 0; } + + GraphVariant* operator[](int32_t idx) { return variants.at(idx); } + + bool execute(int n_tokens, int n_inference, int32_t wait_count); + const std::pair getExecutionStats() { return {run_wait_time, run_exec_time}; } + + void setDebugMode(bool debug_specs, bool debug_tensors, std::string debug_path) { + _debug_path = debug_path; + _debug_specs = debug_specs; + _debug_tensors = debug_tensors; + } + void dumpTensors(GraphVariant* const variant, bool mode, int n_inference) const; + + // Mutex functions + void wakeUpLock(); + void waitForLock(std::string requester = ""); + void waitForLock(std::string requester, int32_t wait_counter, bool poll); + void releaseLock(std::string requester = ""); + bool registerPointerShift(int32_t variant, int32_t ptr_offset); +}; + +} // namespace qualla diff --git a/Genie/Genie/src/qualla/engines/qnn-htp/nsp-kvdispatcher.cpp b/Genie/Genie/src/qualla/engines/qnn-htp/nsp-kvdispatcher.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e72d9a0773f4e3c38a041902da1a1ccc705fc444 --- /dev/null +++ b/Genie/Genie/src/qualla/engines/qnn-htp/nsp-kvdispatcher.cpp @@ -0,0 +1,319 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#include "nsp-kvdispatcher.hpp" + +#include "fmt/format.h" +#include "fmt/ranges.h" + +#define __ERROR(__fmt, ...) _env.logger().post(Logger::ERROR, fmt::format(__fmt, ##__VA_ARGS__)) +#define __KVTRACE(__fmt, ...) 
\ + _env.logger().post(Logger::KVMANAGER_TRACE, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); }) + +// Copied from threadpool.cpp +#if defined(_WIN32) + #define NOGDI + #include "windows.h" + +static bool __thread_affinity(uint64_t mask) { + HANDLE h = GetCurrentThread(); + DWORD_PTR m = mask; + + m = SetThreadAffinityMask(h, m); + + return m != 0; +} + +static int sched_yield(void) { + Sleep(0); + return 0; +} + +#elif defined(__APPLE__) +static bool __thread_affinity(uint64_t mask) { + return true; +} + +#else // posix? + #include + #include + #include + +static bool __thread_affinity(uint64_t mask) { + cpu_set_t cpuset; + int32_t err; + + CPU_ZERO(&cpuset); + + for (uint32_t i = 0; i < 64; i++) { + if ((1ULL << i) & mask) { + CPU_SET(i, &cpuset); + } + } + + #ifdef __ANDROID__ + err = sched_setaffinity(0, sizeof(cpuset), &cpuset); + if (err < 0) { + err = errno; + } + #else + err = pthread_setaffinity_np(pthread_self(), sizeof(cpuset), &cpuset); + #endif + if (err != 0) { + fprintf(stderr, + "warn: failed to set affinity mask 0x%llx (err %d: %s)\n", + (unsigned long long)mask, + err, + strerror(err)); + return false; + } + + return true; +} + +#endif + +#ifdef _MSC_VER + +static inline void __cpu_relax(void) { + YieldProcessor(); +} + +#else + + #if defined(__aarch64__) + +static inline void __cpu_relax(void) { + __asm__ volatile("yield" ::: "memory"); +} + + #else + +static inline void __cpu_relax(void) { + __asm__ volatile("rep; nop" ::: "memory"); +} + + #endif +#endif + +namespace qualla { + +KVDispatcher::KVDispatcher( + Env& env, + std::vector& graphs, + bool threaded, + uint64_t cpumask +) + : _env(env), _threaded(threaded), _cpumask(cpumask) { + + int32_t idx = 0; + for (QnnNspGraph& graph : graphs) { + if (_threaded) + graph.kvmanager->registerCallback([this](int32_t split) { + return this->workerCallback(split); + }); + + // Initialize new DispatcherState() + bool active = (graph.kvmanager->getNumKVTensors() > 0); + _state.emplace_back(idx, active, false, &graph, KVState(), KVState(), KVState()); + idx++; + } + + if (_threaded) _dispatcher_thread = std::thread(&KVDispatcher::dispatchLoop, this); +} + +KVDispatcher::~KVDispatcher() { + if (_threaded) { + _dispatcher_terminate = true; + _cv.notify_all(); + _dispatcher_thread.join(); + } +} + +int32_t KVDispatcher::process( + int32_t split, + int32_t variant, + int32_t n_past, + const std::vector& selected +) { + DispatcherState& state = _state[split]; + + state.requested.n_past = n_past; + state.requested.variant = variant; + state.requested.selected = selected; + return ++state.requested.counter; +} + +int32_t KVDispatcher::dispatch(int32_t split, int32_t variant, int32_t n_past) { + return dispatch(split, variant, n_past, {}); +} +int32_t KVDispatcher::dispatch( + int32_t split, + int32_t variant, + int32_t n_past, + const std::vector& selected +) { + _variant = variant; + + if (!_threaded) { + if (_state[split].active) + _state[split].graph->kvmanager->dispatchUpdate(n_past, variant, selected); + return 0; + } + + if (!_state[split].active) // Increment current counter and return new value + return _state[split].current.counter = process(split, variant, n_past, selected); + + int32_t updated_idx; + { + std::lock_guard lk(_dispatcher_lock); + updated_idx = process(split, variant, n_past, selected); + _dispatcher_requested = true; + } + + _cv.notify_one(); + return updated_idx; +} + +int32_t KVDispatcher::dispatch(int32_t variant, int32_t n_past) { + return dispatch(variant, n_past, std::vector{}); +} + +int32_t 
KVDispatcher::dispatch(int32_t variant, int32_t n_past, const std::vector& selected) { + _variant = variant; + + if (!_threaded) { + for (auto& s : _state) + if (s.active) s.graph->kvmanager->dispatchUpdate(n_past, variant, selected); + return 0; + } + + int32_t global_updated_idx = -1; + { + std::lock_guard lk(_dispatcher_lock); + + for (auto& s : _state) { + if (!s.active) { + global_updated_idx = + (s.current.counter = process(s.split_idx, variant, n_past, selected)); + continue; + } + + int32_t updated_idx = process(s.split_idx, variant, n_past, selected); + if (global_updated_idx == -1) + global_updated_idx = updated_idx; + else if (global_updated_idx != updated_idx) { + // Something went wrong. States are not in sync + __ERROR("qnn-kv: Dispatcher states out of sync - {} vs {}", + global_updated_idx, + updated_idx); + } + } + _dispatcher_requested = true; + } + + _cv.notify_one(); + return global_updated_idx; +} + +void KVDispatcher::dispatchLoop() { + // if (_cpumask) __thread_affinity(_cpumask); + + //loop dispatch + std::vector dispatch_queue; + dispatch_queue.reserve(_state.size()); + std::unique_lock lk(_dispatcher_lock, std::defer_lock); + + while (true) { + lk.lock(); + _cv.wait(lk, [this] { + return _dispatcher_terminate || _dispatcher_requested || _dispatcher_job_completed; + }); + + // On exit, release all locks + if (_dispatcher_terminate) { + for (auto& s : _state) { + if (s.active && (s.release_lock || s.current.counter != s.queued.counter)) + s.graph->releaseLock("dispatcher_terminate"); + } + lk.unlock(); + break; + } + + __KVTRACE("qnn-kv: Dispatcher ({}, {})", _dispatcher_requested, _dispatcher_job_completed); + + // When a job is complete, release all relevant locks + if (_dispatcher_job_completed) { + for (auto& s : _state) { + if (s.release_lock) { + s.graph->releaseLock("kv-update"); + s.release_lock = false; + } + } + } + + for (auto& s : _state) { + if (!s.active) { + s.current = s.requested; + continue; + } + + auto& current = s.current; + auto& queued = s.queued; + auto& requested = s.requested; + + // There is no new work to be done, OR + // KVManager is already working on a job on this split. Wait for completion. 
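+      // (current = last applied state, queued = job handed to the KVManager,
+      //  requested = newest client request; the counters are monotonically
+      //  increasing ticket numbers. E.g. current==queued==3, requested==4 means
+      //  job #4 still needs dispatching; queued==4, current==3 means it is in flight.)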
+ if (queued.counter == requested.counter || current.counter != queued.counter) continue; + + // Requested change has already been completed + if (current.n_past == requested.n_past && current.variant == requested.variant) { + s.graph->_counter = current.counter = queued.counter = requested.counter; + s.graph->wakeUpLock(); + continue; + } + + // Job has been requested but not yet dispatched + s.queued = s.requested; + dispatch_queue.emplace_back(s.split_idx); + } + + _dispatcher_job_completed = false; // Be ready for next job completion + _dispatcher_requested = false; // Be ready for next job request + + lk.unlock(); + + // Dispatch jobs + for (auto split : dispatch_queue) { + DispatcherState& s = _state[split]; + s.graph->waitForLock("kv-update"); + s.graph->kvmanager->dispatchUpdate( + s.queued.n_past, s.queued.variant, s.queued.selected + ); + } + dispatch_queue.clear(); + } + __KVTRACE("qnn-kv : Dispatcher terminating"); +} + +int32_t KVDispatcher::workerCallback(int32_t split) { + __KVTRACE("qnn-kv : graph[{}] workerCallback()", split); + { + std::lock_guard lk(_dispatcher_lock); + // Update relevant job counters + _state[split].current = _state[split].queued; + _state[split].graph->_counter = _state[split].current.counter; + _state[split].release_lock = true; + _dispatcher_job_completed = true; + } + + _cv.notify_one(); + return _state[split].current.counter; +} + +} // namespace qualla diff --git a/Genie/Genie/src/qualla/engines/qnn-htp/nsp-kvdispatcher.hpp b/Genie/Genie/src/qualla/engines/qnn-htp/nsp-kvdispatcher.hpp new file mode 100644 index 0000000000000000000000000000000000000000..94601ea60b79bf9c70b570d467a90346f7b9d6e3 --- /dev/null +++ b/Genie/Genie/src/qualla/engines/qnn-htp/nsp-kvdispatcher.hpp @@ -0,0 +1,109 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#pragma once + +#include +#include +#include +#include + +#include "qualla/detail/timer.hpp" +#include "qualla/detail/threadpool.hpp" + +#include "nsp-graph.hpp" +#include "nsp-kvmanager.hpp" + +namespace qualla { + +struct KVState { + int32_t counter; + int32_t n_past; + int32_t variant; + std::vector selected; + KVState() : counter(-1), n_past(-1), variant(-1) {} + KVState(int32_t _counter, int32_t _n_past, int32_t _variant) + : counter(_counter), n_past(_n_past), variant(_variant) {} +}; + +struct DispatcherState { + DispatcherState( + int split_idxVal, + bool activeVal, + bool release_lockVal, + QnnNspGraph* graphVal, + KVState currentVal, + KVState queuedVal, + KVState requestedVal + ) + : split_idx(split_idxVal), active(activeVal), release_lock(release_lockVal), + graph(graphVal), current(currentVal), queued(queuedVal), requested(requestedVal) {} + int split_idx; + bool active; // false means inactive, i.e. 
no KV$ to update + bool release_lock; // Set to true when job is complete so we can release the lock + QnnNspGraph* graph; + KVState current; + KVState queued; + KVState requested; +}; + +class KVDispatcher { + private: + Env& _env; + bool _threaded; + bool _poll; // Currently unused + uint64_t _cpumask{0}; + + int32_t _variant{-1}; + + std::vector _state; + + std::thread _dispatcher_thread; + bool _dispatcher_terminate{false}; + bool _dispatcher_requested{false}; + bool _dispatcher_job_completed{false}; + std::mutex _dispatcher_lock; + + std::condition_variable _cv; + + // Function to add jobs to the dispatcher + // @param split Determines which split to update + // @param variant Variant of the model to use for updating + // @param n_past Number of past updates to include in the update + // returns New counter + int32_t process( + int32_t split, + int32_t variant, + int32_t n_past, + const std::vector& selected + ); + + public: + KVDispatcher(Env& env, std::vector& graphs, bool threaded, uint64_t cpumask); + ~KVDispatcher(); + + // dispatch for all splits + int32_t dispatch(int32_t variant, int32_t n_past); + int32_t dispatch(int32_t variant, int32_t n_past, const std::vector& selected); + int32_t dispatch(int32_t split, int32_t variant, int32_t n_past); + int32_t dispatch( + int32_t split, + int32_t variant, + int32_t n_past, + const std::vector& selected + ); + + // Callback function for worker thread to mark update job has been completed + int32_t workerCallback(int32_t split); + + void dispatchLoop(); + + void setVariant(int32_t variant) { _variant = variant; } + int32_t getCurVariant() { return _variant; }; +}; +} // namespace qualla diff --git a/Genie/Genie/src/qualla/engines/qnn-htp/nsp-kvmanager.cpp b/Genie/Genie/src/qualla/engines/qnn-htp/nsp-kvmanager.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7046ba06c94b1c3e12233b496da152c8ae1366b4 --- /dev/null +++ b/Genie/Genie/src/qualla/engines/qnn-htp/nsp-kvmanager.cpp @@ -0,0 +1,558 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#include "qualla/detail/timer.hpp" +#include "qualla/detail/threadpool.hpp" + +#include "nsp-kvmanager.hpp" + +#include "fmt/format.h" +#include "fmt/ranges.h" + +// Copied from threadpool.cpp +#if defined(_WIN32) + #include "windows.h" + +static int sched_yield(void) { + Sleep(0); + return 0; +} +#else + #include +#endif + +#define __ERROR(__fmt, ...) _env.logger().post(Logger::ERROR, fmt::format(__fmt, ##__VA_ARGS__)) +#define __TRACE(__fmt, ...) \ + _env.logger().post(Logger::ENGINE_TRACE, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); }) +#define __KVTRACE(__fmt, ...) \ + _env.logger().post(Logger::KVMANAGER_TRACE, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); }) + +namespace qualla { + +NewNSPKVManager::NewNSPKVManager( + int idx, + Env& env, + ThreadPool* threadpool, + IOTensor* buffer_mgr, + QnnUtils::TensorMap& tensor_specs, + int32_t ctx_size, + int32_t embed_dim, + KVManagerMode mode +) + : _env(env), _mgr_idx(idx), _mode(mode), _n_embed(embed_dim), _n_ctx(ctx_size) { + // Parse KV$ Tensor names here - supports past_{key,value}_{layer_idx}[_h{head_idx}]_{in,out} + // TODO: Enforce tensor order during allocation as well to speed up cache loops(?) 
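+  // The packed key (layer_idx << 16) | head_idx keeps std::map iteration
+  // layer-major and head-minor, e.g. layer 3 / head 2 sorts as 0x00030002.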
std::map<uint32_t, QnnUtils::Tensor*> key_tensors, value_tensors;
+  for (auto& [tname, tensor] : tensor_specs) {
+    auto [tensor_type, layer_idx, head_idx] = parseKVTensorName(tname);
+    if (tensor_type == 0) continue;
+    if (tensor_type == 1)
+      key_tensors[layer_idx << 16 | head_idx] = &tensor;
+    else
+      value_tensors[layer_idx << 16 | head_idx] = &tensor;
+  }
+
+  if (key_tensors.size() + value_tensors.size() == 0) return;
+
+  // Calculate datatype - bitwidth and float vs quantized
+  auto rt = key_tensors.size() == 0 ? value_tensors.begin()->second : key_tensors.begin()->second;
+  _bw = rt->dtype.bw(); // Assume same bitwidth for all tensors
+  if (rt->quantParam[0].offset == 0 && rt->quantParam[0].scale == 0)
+    _pad_value = 0; // For floating point inputs, pad value is 0
+  else // Currently only 8-bit quantization is supported. Will need to change to support 16-bit
+    _pad_value = static_cast<uint8_t>(-rt->quantParam[0].offset);
+
+  // clang-format off
+  __TRACE( "qnn-kv : {} KVManager[{} Key$ + {} Value$] : {}-bit KV$ n_embed={} n_ctx={} mode={}",
+           _mgr_idx, key_tensors.size(), value_tensors.size(), _bw*8, _n_embed, _n_ctx,
+           (_mode==POINTER_SHIFT ? "POINTER_SHIFT" : "SHIFT_CONCAT")
+  );
+  // clang-format on
+
+  _kv_cache.reserve(key_tensors.size() + value_tensors.size());
+  for (auto& [_, tensor] : key_tensors) {
+    void* buffer = buffer_mgr->getBuffer(tensor->tensor);
+    _kv_cache.emplace_back(true, (char*)buffer, (char*)buffer, tensor->dims.height);
+    _key_scales.push_back(tensor->quantParam[0].scale);
+  }
+
+  for (auto& [_, tensor] : value_tensors) {
+    void* buffer = buffer_mgr->getBuffer(tensor->tensor);
+    _kv_cache.emplace_back(false, (char*)buffer, (char*)buffer, tensor->dims.height);
+    _value_scales.push_back(tensor->quantParam[0].scale);
+  }
+
+  // Calculate _max_n_heads
+  for (auto& cache : _kv_cache)
+    _max_n_heads = cache.n_heads > _max_n_heads ? cache.n_heads : _max_n_heads;
+
+  // clang-format off
+  __TRACE( "qnn-kv : {} KVManager[{} Key$ + {} Value$] : n_heads<={} n_embed={} n_ctx={} mode={}",
+           _mgr_idx, key_tensors.size(), value_tensors.size(), _max_n_heads, _n_embed, _n_ctx,
+           (_mode==POINTER_SHIFT ? "POINTER_SHIFT" : "SHIFT_CONCAT")
+  );
+  // clang-format on
+
+  if (threadpool != nullptr && threadpool->size() > 0) {
+    _threadpool = threadpool;
+    n_threads = threadpool->size();
+    _sync = 0;
+
+    _update_jobs.reserve(n_threads + 1);
+    if (_mode == POINTER_SHIFT)
+      _update_jobs.push_back([this] { this->registerPointerOffset(); });
+
+    for (int idx = 0; idx < n_threads; idx++)
+      _update_jobs.push_back([this, idx] { this->runKVUpdateJob(idx); });
+  }
+
+  _callback_fn = [](int32_t a) { return 0; };
+}
+
+NewNSPKVManager::~NewNSPKVManager() {}
+
+// Parse KV$ Tensor names here - supports past_{key,value}_{layer_idx}[_h{head_idx}]_{in,out}
+std::tuple<int, uint16_t, uint16_t> NewNSPKVManager::parseKVTensorName(std::string name) {
+  if (!name.starts_with("past_")) return {0, 0, 0};
+
+  const bool is_key = name.starts_with("past_key");
+  const size_t pos0 = (is_key) ? 9 : 11; // "past_key_" OR "past_value_"
+  const size_t pos1 = name.find('_', pos0);
+  const size_t pos2 = name.find('_', pos1 + 2);
+
+  uint16_t layer_idx = 0, head_idx = 0;
+  layer_idx = static_cast<uint16_t>(std::stoi(name.substr(pos0, pos1 - pos0)));
+  if (pos2 != std::string::npos)
+    head_idx = static_cast<uint16_t>(std::stoi(name.substr(pos1 + 2, pos2 - pos1 - 2)));
+
+  return std::make_tuple(is_key ?
1 : 2, layer_idx, head_idx); +} + +// Switch key cache from AR-m to AR-n (relative to ctx_size) +bool NewNSPKVManager::switchKeyVariant(KVCache cache, int32_t m, int32_t n, int32_t offset) { + const size_t in_cache_dim = (m == _n_ctx) ? _n_ctx : _n_ctx - m; + const size_t out_cache_dim = _n_ctx - n; + const size_t n_heads = cache.n_heads; + + const size_t read_row_size = in_cache_dim * _bw; + const size_t write_row_size = out_cache_dim * _bw; + const size_t offset_size = offset * _bw; + + if (in_cache_dim > out_cache_dim) { + char* read_ptr = cache.buffer + read_row_size - write_row_size + offset_size; + char* write_ptr = cache.buffer + offset_size; + + for (int i = 0; i < n_heads * _n_embed; i++) { + std::memmove(write_ptr, read_ptr, write_row_size); + read_ptr += read_row_size; + write_ptr += write_row_size; + } + } else { + const size_t block_size_delta = write_row_size - read_row_size; + + char* read_ptr = cache.buffer + (n_heads * _n_embed - 1) * read_row_size + offset_size; + char* write_ptr = cache.buffer + (n_heads * _n_embed - 1) * write_row_size + offset_size; + + for (int i = 0; i < n_heads * _n_embed; i++) { + std::memmove(write_ptr + block_size_delta, read_ptr, read_row_size); + std::memset(write_ptr, _pad_value, block_size_delta); + read_ptr -= read_row_size; + write_ptr -= write_row_size; + } + } + + return true; +} + +// Switch value cache from AR-m to AR-n (relative to ctx_size) +bool NewNSPKVManager::switchValueVariant(KVCache cache, int32_t m, int32_t n, int32_t offset) { + const size_t in_cache_dim = (m == _n_ctx) ? _n_ctx : _n_ctx - m; + const size_t out_cache_dim = _n_ctx - n; + const size_t n_heads = cache.n_heads; + + const size_t read_block_size = in_cache_dim * _n_embed * _bw; + const size_t write_block_size = out_cache_dim * _n_embed * _bw; + const size_t offset_size = offset * _n_embed * _bw; + + if (in_cache_dim > out_cache_dim) { + char* read_ptr = cache.buffer + read_block_size - write_block_size + offset_size; + char* write_ptr = cache.buffer + offset_size; + + for (int i = 0; i < n_heads; i++) { + std::memmove(write_ptr, read_ptr, write_block_size); + read_ptr += read_block_size; + write_ptr += write_block_size; + } + } else { + const size_t block_size_delta = write_block_size - read_block_size; + + char* read_ptr = cache.buffer + (n_heads - 1) * read_block_size + offset_size; + char* write_ptr = cache.buffer + (n_heads - 1) * write_block_size + offset_size; + + for (int i = 0; i < n_heads; i++) { + std::memmove(write_ptr + block_size_delta, read_ptr, read_block_size); + std::memset(write_ptr, _pad_value, block_size_delta); + read_ptr -= read_block_size; + write_ptr -= write_block_size; + } + } + + return true; +} + +// clang-format off +bool NewNSPKVManager::updateKey(KVCache cache, int32_t variant, int32_t n_update, int32_t offset, const std::vector& selected) { + // clang-format on + char* dst = cache.buffer; + char* src = cache.output_buffer; + + if (n_update < 0) { + const int32_t n_iter = cache.n_heads * _n_embed; + const int32_t iter_size = (_n_ctx - variant) * _bw; + const int32_t copy_size = -n_update * _bw; + + if (_mode == SHIFT_CONCAT) { + std::memmove(dst + copy_size, dst, n_iter * iter_size - copy_size); + std::memset(dst, _pad_value, copy_size); + } else if (_mode == POINTER_SHIFT) { + char* write_ptr = dst + offset * _bw + iter_size - copy_size; + for (int32_t i = 0; i < n_iter; i++) { + std::memset(write_ptr, _pad_value, copy_size); + write_ptr += iter_size; + } + } + + return true; + } + + const int32_t n_iter = cache.n_heads * _n_embed; + 
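+  // Key cache layout: one row of (_n_ctx - variant) entries per (head, embed) pair.
+  // SHIFT_CONCAT appends the n_update newest positions at the right edge of each row;
+  // POINTER_SHIFT instead writes them at the rolling pointer offset in the scratch
+  // region, so no per-row shift is needed.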
const int32_t iter_size = (_n_ctx - variant) * _bw; + const int32_t copy_size = n_update * _bw; + const int32_t out_size = variant * _bw; + + if (_mode == SHIFT_CONCAT) // Shift KV$ buffer if necessary + std::memmove(dst, dst + copy_size, n_iter * iter_size - copy_size); + + // Concatenate output into the KV$ buffers + char* read_ptr = src; // output_buffer + char* write_ptr = dst + offset * _bw + iter_size - ((_mode == POINTER_SHIFT) ? 0 : copy_size); + + if (selected.empty()) { + for (int32_t i = 0; i < n_iter; i++) { + std::memcpy(write_ptr, read_ptr, copy_size); + write_ptr += iter_size; + read_ptr += out_size; + } + } else { + for (int32_t i = 0; i < n_iter; i++) { + auto wp = write_ptr, rp = read_ptr; + for (auto sel : selected) { + for (int i = 0; i < _bw; i++) { + if (sel) *wp++ = *rp; + ++rp; + } + } + write_ptr += iter_size; + read_ptr += out_size; + } + } + + return true; +} + +// clang-format off +bool NewNSPKVManager::updateValue(KVCache cache, int32_t variant, int32_t n_update, int32_t offset, const std::vector& selected) { + // clang-format on + char* dst = cache.buffer; + char* src = cache.output_buffer; + + if (n_update < 0) { + const int32_t n_iter = cache.n_heads; + const int32_t iter_size = (_n_ctx - variant) * _n_embed * _bw; + const int32_t copy_size = -n_update * _n_embed * _bw; + if (_mode == SHIFT_CONCAT) { + std::memmove(dst + copy_size, dst, cache.n_heads * iter_size - copy_size); + std::memset(dst, _pad_value, copy_size); + } else if (_mode == POINTER_SHIFT) { + char* write_ptr = dst + offset * _n_embed * _bw + iter_size - copy_size; + for (int32_t i = 0; i < n_iter; i++) { + std::memset(write_ptr, _pad_value, copy_size); + write_ptr += iter_size; + } + } + + return true; + } + + const int32_t n_iter = cache.n_heads; + const int32_t iter_size = (_n_ctx - variant) * _n_embed * _bw; + const int32_t copy_size = n_update * _n_embed * _bw; + const int32_t out_size = variant * _n_embed * _bw; + + if (_mode == SHIFT_CONCAT) // Shift KV$ buffer if necessary + std::memmove(dst, dst + copy_size, cache.n_heads * iter_size - copy_size); + + // Concatenate output into the KV$ buffers + char* read_ptr = src; + char* write_ptr = dst + offset * _n_embed * _bw + iter_size; + if (_mode != POINTER_SHIFT) write_ptr -= copy_size; + if (selected.empty()) { + for (int i = 0; i < cache.n_heads; i++) { + std::memcpy(write_ptr, read_ptr, copy_size); + write_ptr += iter_size; + read_ptr += out_size; + } + } else { + for (int i = 0; i < cache.n_heads; i++) { + auto wp = write_ptr, rp = read_ptr; + for (auto sel : selected) { + if (sel) { + std::memcpy(wp, rp, _n_embed * _bw); + wp += _n_embed * _bw; + } + rp += _n_embed * _bw; + } + write_ptr += iter_size; + read_ptr += out_size; + } + } + return true; +} + +bool NewNSPKVManager::registerPointerOffset() { + int32_t variant = _req_state.variant; + int32_t ptr_offset = _req_state.ptr_offset; + __KVTRACE("qnn-kv : graph[{}] pointerShift({} @ AR-{})", _mgr_idx, ptr_offset, variant); + _register_pointer_fn(variant, ptr_offset * _bw); + + if (_threadpool != nullptr) { + const int rem = --_sync; + __KVTRACE("qnn-kv : graph[{}] pointerShift complete ({} remain)", _mgr_idx, rem); + if (rem == 0) updateState(); + } + return true; +} + +bool NewNSPKVManager::updateState() { + // clang-format off + __TRACE("qnn-kv : graph[{}] updateState to AR-{}(n_past={}, ptr={})", _mgr_idx, + _req_state.variant, _req_state.n_past, _req_state.ptr_offset); + // clang-format on + + if (_cur_state.variant != _req_state.variant) { + int idx = 0; + for (KVCache& cache : 
_kv_cache) { + const int32_t dim_size = _n_ctx - _req_state.variant; + cache.output_buffer = cache.buffer + dim_size * cache.n_heads * _n_embed * _bw; + + if (_mode == POINTER_SHIFT) + cache.output_buffer += cache.is_key ? _n_ctx * _bw : _n_ctx * _n_embed * _bw; + } + } + + _cur_state = _req_state; + _counter = _callback_fn(_mgr_idx); + return true; +} + +// Function executes on the threadpool - called once per thread. +// Assumes the lock is properly attained by this point +void NewNSPKVManager::runKVUpdateJob(int thread_idx) { + // clang-format off + __KVTRACE( + "qnn-kv : graph[{}] tid[{}] kv-update started. {} ", + _mgr_idx, thread_idx, modeStr(_req_mode)); + // clang-format on + int job_count = 1 + ((getNumKVTensors() - 1) / n_threads); // Number of jobs per thread + int end_idx = job_count * (thread_idx + 1); + if (end_idx > getNumKVTensors()) end_idx = getNumKVTensors(); + + for (int idx = job_count * thread_idx; idx < end_idx; idx++) { + KVCache& cache = _kv_cache[idx]; + + auto& [variant, n_past, ptr_offset, selected] = _cur_state; + const int32_t n_update = _req_state.n_past - n_past; + + if (cache.is_key) { + if (_req_mode == CLEAR_CACHE) clearBuffer(cache); + if (_req_mode == UPDATE_OUTPUT || _req_mode == UPDATE_AND_SET) { + updateKey(cache, variant, n_update, ptr_offset, _req_state.selected); + } + if (_req_mode == SET_VARIANT || _req_mode == UPDATE_AND_SET) { + switchKeyVariant(cache, variant, _req_state.variant, _req_state.ptr_offset); + } + } else { + if (_req_mode == CLEAR_CACHE) clearBuffer(cache); + if (_req_mode == UPDATE_OUTPUT || _req_mode == UPDATE_AND_SET) { + updateValue(cache, variant, n_update, ptr_offset, _req_state.selected); + } + if (_req_mode == SET_VARIANT || _req_mode == UPDATE_AND_SET) { + switchValueVariant(cache, variant, _req_state.variant, _req_state.ptr_offset); + } + } + } + + if (_threadpool != nullptr) { + const int rem = --_sync; + __KVTRACE("qnn-kv : graph[{}] tid[{}] kv-update ({} remain)", _mgr_idx, thread_idx, rem); + if (rem == 0) updateState(); + } else // Without threading, this is only called once so we can updateState() immediately + updateState(); +} + +void NewNSPKVManager::dispatchUpdate( + int32_t n_past, + int32_t variant, + const std::vector& selected +) { + // clang-format off + __KVTRACE("qnn-kv : graph[{}] dispatchUpdate AR-{}(n_past={}, ptr={}) -> AR-{}(n_past={})", + _mgr_idx, _cur_state.variant, _cur_state.n_past, _cur_state.ptr_offset, variant, n_past); + // clang-format on + + bool skip_update = false; + _req_state = {variant, n_past, _cur_state.ptr_offset, selected}; + + if (_req_state.n_past == 0) { + _req_mode = CLEAR_CACHE; + _req_state.ptr_offset = 0; + + // Nothing to be done iff + // - Requested variant is BERT Mode, i.e. takes no input (new_variant == _n_ctx) + // - Cache is already empty (n_past == 0) + if (_req_state.variant == _n_ctx || _cur_state.n_past == 0) _req_mode = NO_OP; + } else if (_req_state.n_past == _cur_state.n_past) { + _req_mode = SET_VARIANT; + // Nothing needs to be done iff + // - Cache is empty (n_past == 0). Might want to check for BERT->AR-1 + // - Requested variant is already set (new_variant == cur_variant) + // - Requested variant is BERT Mode, i.e. 
takes no input (new_variant == _n_ctx) + if (_cur_state.n_past == 0 || _req_state.variant == _n_ctx || + _req_state.variant == _cur_state.variant) + _req_mode = NO_OP; + if (_req_state.variant == _n_ctx) _req_state.ptr_offset = 0; + + } else if (_req_state.n_past < _cur_state.n_past) { + _req_mode = UPDATE_OUTPUT; + if (_mode == POINTER_SHIFT) + _req_state.ptr_offset -= (_cur_state.n_past - _req_state.n_past); + + } else if (_req_state.variant == _cur_state.variant) { // UPDATE_OUTPUT + _req_mode = UPDATE_OUTPUT; + if (_cur_state.variant == _n_ctx) + _req_mode = NO_OP; + else if (_mode == POINTER_SHIFT) + _req_state.ptr_offset += (_req_state.n_past - _cur_state.n_past); + + } else { + _req_mode = UPDATE_AND_SET; + + if (_cur_state.variant == _n_ctx) + _req_mode = SET_VARIANT; + else if (_req_state.variant == _n_ctx) { + _req_state.n_past = 0; + _req_mode = NO_OP; // If we're switching to BERT-Mode, nothing to do + } + + if (_req_mode == UPDATE_AND_SET && _cur_state.variant != _n_ctx && _mode == POINTER_SHIFT) + _req_state.ptr_offset += (_req_state.n_past - _cur_state.n_past); + } + + // clang-format off + __KVTRACE("qnn-kv : graph[{}] Processing {} AR-{}(n_past={}, ptr={})", + _mgr_idx, modeStr(_req_mode), _req_state.variant, _req_state.n_past, _req_state.ptr_offset); + // clang-format on + + if (_req_mode == NO_OP) { + // TODO: Think about this case a bit more. Any other cases we want to registerPtrOffset()? + bool needs_register_ptr = + (_mode == POINTER_SHIFT && (_cur_state.variant != _req_state.variant || + _cur_state.ptr_offset != _req_state.ptr_offset)); + + if (needs_register_ptr) { + if (_threadpool != nullptr) { + _sync += 1; + registerPointerOffset(); + } else { + registerPointerOffset(); + updateState(); + } + } else + updateState(); + return; + } + + if (_threadpool != nullptr) { + _sync += _update_jobs.size(); + _threadpool->enqueue(_update_jobs); + } else { + runKVUpdateJob(0); + if (_mode == POINTER_SHIFT) registerPointerOffset(); + updateState(); + } +} + +bool NewNSPKVManager::loadCache( + std::ifstream* fs, + bool is_key, + int32_t n_valid, + int32_t variant, + int32_t n_heads +) { + __TRACE("qnn-kv : KVManager[{}] load cache", _mgr_idx); + const size_t cache_dim = (variant == _n_ctx) ? _n_ctx : _n_ctx - variant; + const size_t iter_size = (is_key) ? cache_dim * _bw : cache_dim * _n_embed * _bw; + const size_t copy_size = (is_key) ? n_valid * _bw : n_valid * _n_embed * _bw; + + for (KVCache& cache : _kv_cache) { + if (cache.is_key != is_key) continue; + + clearBuffer(cache); + const int n_iter = (is_key) ? cache.n_heads * _n_embed : cache.n_heads; + char* data = (char*)cache.buffer + iter_size - copy_size; + for (int i = 0; i < n_iter; i++) { + fs->read(data, copy_size); + data += iter_size; // Jump to the next row/block (depending on type) + } + + if (n_heads > cache.n_heads) + fs->seekg((n_heads - cache.n_heads) * _n_embed * n_valid * _bw, std::ios::cur); + } + + _req_state = {variant, n_valid, 0}; + updateState(); + + return true; +} + +bool NewNSPKVManager::dumpCache(std::ofstream* fs, bool is_key, int32_t n_valid, int32_t n_heads) { + __TRACE("qnn-kv : graph[{}] dump cache", _mgr_idx); + const int32_t variant = _cur_state.variant; + const int32_t ptr_offset = _cur_state.ptr_offset; + const size_t cache_dim = (variant == _n_ctx) ? _n_ctx : _n_ctx - variant; + + const size_t iter_size = (is_key) ? cache_dim * _bw : cache_dim * _n_embed * _bw; + const size_t copy_size = (is_key) ? n_valid * _bw : n_valid * _n_embed * _bw; + const size_t offset_size = (is_key) ? 
ptr_offset * _bw : ptr_offset * _n_embed * _bw; + + for (KVCache& cache : _kv_cache) { + if (cache.is_key != is_key) continue; + + const int n_iter = (is_key) ? cache.n_heads * _n_embed : cache.n_heads; + char* data = (char*)cache.buffer + offset_size + iter_size - copy_size; + for (int i = 0; i < n_iter; i++) { + fs->write(data, copy_size); + data += iter_size; // Jump to the next row/block (depending on type) + } + + if (n_heads > cache.n_heads) + fs->seekp((n_heads - cache.n_heads) * _n_embed * n_valid * _bw, std::ios::cur); + } + return true; +} +} // namespace qualla diff --git a/Genie/Genie/src/qualla/engines/qnn-htp/nsp-kvmanager.hpp b/Genie/Genie/src/qualla/engines/qnn-htp/nsp-kvmanager.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f1208aa5364193f472c3bd88e6cbd00a7a813854 --- /dev/null +++ b/Genie/Genie/src/qualla/engines/qnn-htp/nsp-kvmanager.hpp @@ -0,0 +1,163 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. +// +//============================================================================== + +#pragma once + +#include + +#include "QnnApi.hpp" +#include "IOTensor.hpp" +#include "qnn-utils.hpp" + +#include + +#include "qualla/env.hpp" + +namespace qualla { + + +enum KVUpdateMode { + NO_OP = 0x0, + CLEAR_CACHE = 0x1, + SET_VARIANT = 0x2, + UPDATE_OUTPUT = 0x4, + UPDATE_AND_SET = 0x8 +}; + +static std::string modeStr(KVUpdateMode mode) { + if (mode == CLEAR_CACHE) return "CLEAR_CACHE"; + if (mode == SET_VARIANT) return "SET_VARIANT"; + if (mode == UPDATE_OUTPUT) return "UPDATE_OUTPUT"; + if (mode == UPDATE_AND_SET) return "UPDATE_AND_SET"; + return "NO_OP"; +} + +struct KVCache { + bool is_key; + char* buffer; + char* output_buffer; + int32_t n_heads; + KVCache() {} + KVCache(bool is_key_val, char* buffer_val, char* output_buffer_val, int32_t n_heads_val) : + is_key(is_key_val), buffer(buffer_val), output_buffer(output_buffer_val), n_heads(n_heads_val) {} +}; + +class NewNSPKVManager { + private: + Env& _env; + int _mgr_idx; // Identify KVManager in the logs + + ThreadPool* _threadpool{nullptr}; // Threadpool for async background processing + std::atomic_int _sync{0}; + + std::vector> _update_jobs; + + KVManagerMode _mode{POINTER_SHIFT}; + + std::vector _kv_cache; // + std::vector _key_scales, _value_scales; + int32_t _max_n_heads{0}; + + // Caputre states + struct KVManagerState { + int32_t variant; + int32_t n_past; + int32_t ptr_offset; + std::vector selected; + }; + + KVManagerState _cur_state{-1, -1, 0, {}}; + KVManagerState _req_state{-1, -1, 0, {}}; + KVUpdateMode _req_mode{NO_OP}; + + int32_t _counter{-1}; // Auto-increment variable for syncing updates + int32_t n_threads{1}; + + // Variant (n) stores AR-n for which the cache is currently formatted + // The following variables are strictly dependent on variant n. 
Make sure to update accordingly + size_t key_output_offset, value_output_offset; + + // Parse KV$ Tensor names here - supports past_{key,value}_{layer_idx}[_h{head_idx}]_{in,out} + std::tuple parseKVTensorName(std::string name); + + // KV Manager Utility functions + void clearBuffer(KVCache cache) { + std::memset(cache.buffer, _pad_value, cache.n_heads * _n_ctx * _n_embed * _bw); + } + + bool switchKeyVariant(KVCache cache, int32_t m, int32_t n, int32_t ptr_offset); + bool switchValueVariant(KVCache cache, int32_t m, int32_t n, int32_t ptr_offset); + bool updateKey( + KVCache cache, + int32_t variant, + int32_t n_update, + int32_t offset, + const std::vector& selected + ); + bool updateValue( + KVCache cache, + int32_t variant, + int32_t n_update, + int32_t offset, + const std::vector& selected + ); + + // For pointer shift + std::map>* _alloc_info; + bool registerPointerOffset(); // Register offsets for POINTER_SHIFT + + std::function _callback_fn; + std::function _register_pointer_fn; + + public: + uint8_t _pad_value; // Assumes all tensors have a common zero point @ 128 + int8_t _bw{1}; // Bitwidth of KV$ values. Defaults to 8-bit KV$ + int32_t _n_embed{-1}; + int32_t _n_ctx{-1}; + + // clang-format off + NewNSPKVManager( int idx, Env& env, ThreadPool* threadpool, IOTensor* buffer_mgr, + QnnUtils::TensorMap &tensor_specs, int32_t ctx_size, int32_t embed_dim, KVManagerMode mode); + // clang-format on + ~NewNSPKVManager(); + + bool loadCache( + std::ifstream* fs, + bool is_key, + int32_t n_valid, + int32_t variant, + int32_t n_heads + ); + bool dumpCache(std::ofstream* fs, bool is_key, int32_t n_valid, int32_t n_heads); + + bool updateState(); + void runKVUpdateJob(int thread_idx); // Worker thread function + void setTensorAllocInfo(std::map>* alloc_info) { + _alloc_info = alloc_info; + } + void registerCallback(std::function callback_fn) { + _callback_fn = callback_fn; + } + + // TODO: Cleanup and remove this function. KVManager should handle all alloc/register for KV$ + void registerPointerOffsetFn(std::function register_fn) { + _register_pointer_fn = register_fn; + } + + void dispatchUpdate(int32_t new_n_past, int32_t variant, const std::vector& selected); + + const size_t getNumKVTensors() const { return _kv_cache.size(); } + const int32_t getMaxNHeads() const { return _max_n_heads; } + int32_t getCurOffset() { return _cur_state.ptr_offset; } + int32_t getCurVariant() { return _cur_state.variant; } + int32_t getNPast() { return _cur_state.n_past; } + std::vector& getKeyScales() { return _key_scales; } + std::vector& getValueScales() { return _value_scales; } +}; + +} // namespace qualla diff --git a/Genie/Genie/src/qualla/engines/qnn-htp/nsp-model.cpp b/Genie/Genie/src/qualla/engines/qnn-htp/nsp-model.cpp new file mode 100644 index 0000000000000000000000000000000000000000..24da5f8473e41e9fb61bf91b2d4940977873c5d8 --- /dev/null +++ b/Genie/Genie/src/qualla/engines/qnn-htp/nsp-model.cpp @@ -0,0 +1,2626 @@ +//============================================================================== +// +// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +// All Rights Reserved. +// Confidential and Proprietary - Qualcomm Technologies, Inc. 
+// +//============================================================================== + +#define _USE_MATH_DEFINES // Used for M_PI + +#include "qualla/env.hpp" +#include "qualla/detail/timer.hpp" +#include "qualla/detail/cache-file.hpp" + +#include "fmt/format.h" +#include "fmt/ranges.h" +#include "fmt/os.h" +#include +#include "nsp-model.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include "fp16/fp16.h" + +namespace fs = std::filesystem; + +#define __INFO(__fmt, ...) _env.logger().post(Logger::INFO, fmt::format(__fmt, ##__VA_ARGS__)) +#define __WARN(__fmt, ...) _env.logger().post(Logger::WARN, fmt::format(__fmt, ##__VA_ARGS__)) +#define __ERROR(__fmt, ...) _env.logger().post(Logger::ERROR, fmt::format(__fmt, ##__VA_ARGS__)) +#define __KPIS(__fmt, ...) \ + _env.logger().post(Logger::ENGINE_KPIS, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); }) +#define __DEBUG(__fmt, ...) \ + _env.logger().post(Logger::ENGINE_DEBUG, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); }) +#define __TRACE(__fmt, ...) \ + _env.logger().post(Logger::ENGINE_TRACE, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); }) + +namespace qualla { + +QnnNspModel::QnnNspModel(Env& env, const Params& params) + : _env(env), model_basedir(params.model_basedir) { + // Initialize QnnAPI + m_qnnApi = std::unique_ptr(new QnnApi()); + + spill_fill_buffer_size = params.spill_fill_bufsize; + m_kv_dim = params.kv_dim; + m_use_mmap = params.use_mmap; + m_use_async_Init = params.use_async_Init; + mmap_budget = params.mmap_budget; + m_ctx_size = params.ctx_size; + m_pad_token = params.pad_token; + lmhead_weight_dir = params.lmhead_weight_dir; + graph_switching = params.graph_switching; + load_select_graphs = params.load_select_graphs; + lora_conf = params.lora_config_type; + embedding_length = params.embedding_length; + embedding_datatype = params.embedding_datatype; + m_disableKvCache = params.disable_kv_cache; + m_embd_size = params.n_embd; + m_modelArchitectureType = params.modelArchitectureType; + // Positional encoding parameters + m_positional_encoding = params.positional_encoding_params; + if (m_positional_encoding.type == PositionalEncoding::ROPE) // Save m_pos_dim for easy access + m_pos_dim = m_positional_encoding.rope_params.dims; + + // Debug flags + _debug_path = params.debug_path; + _debug_specs = params.debug_specs; + _debug_tensors = params.debug_tensors; + _debug_outputs = params.debug_outputs; + _debug_qnn = params.debug_qnn; + + _backend_lib = params.backend_lib; + _backend_ext_conf = params.backend_ext_conf; + + if (graph_switching && !m_use_mmap) + __WARN("Graph switching with non-mmaped implementation can cause high sustained memory usage" + ); + + variant_latency = params.variant_latency; + + if(m_modelArchitectureType == ModelArchitectureType::ENCODER){ + m_pooled_output = params.pooled_output; + } + + exec_select_graphs = params.exec_select_graphs; + if (!exec_select_graphs.empty()) + __DEBUG("qnn-htp : Execute selected graphs = {}", exec_select_graphs); + + _kv_update_method = (params.kv_update_method == "POINTER_SHIFT") ? POINTER_SHIFT : SHIFT_CONCAT; + __DEBUG("qnn-htp : NSP KV$ Update Method = {}", + (_kv_update_method == POINTER_SHIFT) ? "POINTER_SHIFT" : "SHIFT_CONCAT"); + + // Set up filename list. 
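+  // Each entry is later loaded as a cached context binary; relative paths resolve
+  // against model_basedir, and any missing file aborts construction immediately.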
+ for (auto& i : params.model_list) { + fs::path model_path = fs::path(i); + if (model_path.is_relative()) model_path = model_basedir / fs::path(i); + if (!fs::is_regular_file(model_path)) { + __ERROR("NSPModel: Can't access model file : {}", model_path.string()); + throw std::runtime_error("NSPModel: Can't access model file : " + model_path.string()); + } + model_filelist.push_back(model_path.string()); + } + + if (lora_conf != LoraConfigType::LORA_DISABLE) { + lora_config.insert(params.lora_param.begin(), params.lora_param.end()); + } + + if (params.n_threads > 0) { + _threaded = true; + _cpumask = params.cpumask; + __DEBUG("qnn-htp: starting threadpool : n_threads {} params. {:#x} poll {}", + params.n_threads, + _cpumask, + params.poll); + threadpool.start(params.n_threads, _cpumask, params.poll); + } + + // Initialize QNN IO Tensor + m_ioTensor = std::unique_ptr(new IOTensor( + m_sharedBuffer ? BufferAlloc::SHARED_BUFFER : BufferAlloc::DEFAULT, + m_sharedBuffer ? m_qnnApi->getQnnInterfaceVer() : nullptr + )); + + m_qnnApi->setIOTensorBufferMgr(m_ioTensor.get()); + m_qnnApi->setKVDim(m_kv_dim); + m_qnnApi->setContextSize(m_ctx_size); + m_qnnApi->setKVUpdateMethod(_kv_update_method); + + if (params.debug_specs || params.debug_tensors) { + if (!fs::exists(params.debug_path) && !fs::create_directories(params.debug_path)) + throw std::runtime_error("Could not create debug directory : " + params.debug_path); + } +} + +QnnNspModel::~QnnNspModel() { + qualla::Timer start; + + if (_threaded) { + __DEBUG("qnn-htp: stopping threadpool"); + threadpool.stop(); // Stop Threadpool first + } + + // Free cached RoPE memory + if (rope_sin != nullptr) free(rope_sin); + if (rope_cos != nullptr) free(rope_cos); + + __DEBUG("qnn-htp: model destruct complete: {} usec", start.elapsed_usec()); +} +bool QnnNspModel::float32ToFloat16(uint8_t *out, float *in, size_t numElements) { + if(!numElements) return false; + uint16_t *temp = (uint16_t *)out; + for(size_t i = 0; i < numElements; i++){ + temp[i] = fp16_ieee_from_fp32_value(in[i]); + } + return true; +} +// Given a filename, initializeModel load and initializes QNN runtime libraries and the model +bool QnnNspModel::initializeModel(void) { + qualla::Timer start; + + __DEBUG("qnn-htp: model init start"); + + // Default backends +#ifdef _WIN32 + const std::string m_backend = _backend_lib.empty() ? "QnnHtp.dll" : _backend_lib; + const std::string m_systemLib = "QnnSystem.dll"; + const std::string backendExtensionsLibPath = "QnnHtpNetRunExtensions.dll"; +#else + const std::string m_backend = _backend_lib.empty() ? "libQnnHtp.so" : _backend_lib; + const std::string m_systemLib = "libQnnSystem.so"; + const std::string backendExtensionsLibPath = "libQnnHtpNetRunExtensions.so"; +#endif +#ifdef QUALLA_INTERNAL_QNN_SDK + if (_backend_ext_conf.empty()) { + __INFO("No backend extension config provided"); + } + fs::path m_backendExtensionsConfigPath = fs::path(_backend_ext_conf); +#else + fs::path m_backendExtensionsConfigPath = + _backend_ext_conf.empty() ? 
fs::path("data") / "htp_backend_ext_config.json" + : fs::path(_backend_ext_conf); + + if (m_backendExtensionsConfigPath.is_relative()) + m_backendExtensionsConfigPath = fs::path(model_basedir) / m_backendExtensionsConfigPath; + + if (!fs::is_regular_file(m_backendExtensionsConfigPath)) { + __ERROR("Cannot access {}", m_backendExtensionsConfigPath.string()); + return false; + } +#endif + __INFO("Backend library : {}", m_backend); + __INFO("System library : {}", m_systemLib); + __INFO("Model dir : {}", model_basedir.string()); + __INFO("Model files : {}", model_filelist); + __INFO("Backend extensions lib path : {}", backendExtensionsLibPath); + __INFO("Backend extensions config path : {}", m_backendExtensionsConfigPath.string()); + + if (!m_qnnApi->initialize( + m_backend, + model_filelist, + BackendExtensionsConfigs( + backendExtensionsLibPath, m_backendExtensionsConfigPath.string() + ), + PerfProfile::BURST, + ContextConfigs(Qnn_Priority_t::QNN_PRIORITY_DEFAULT), + {}, // graphConfigs + true, // loadFromCachedBinary + m_systemLib, // systemLibraryPath + false, + spill_fill_buffer_size, + m_use_mmap, + m_use_async_Init, + mmap_budget, + _debug_qnn, + graph_switching, + exec_select_graphs, + load_select_graphs + )) { + __ERROR("qnn-api initialization failed!"); + return false; + } + + int32_t n_splits = 0; + m_num_graphs = m_qnnApi->getGraphsCount(); + + __INFO("qnn-api initialized with {} graph(s)", m_num_graphs); + + GraphInfo_t** graphs_info = m_qnnApi->getGraphsInfo(); + m_variant_list.reserve(m_num_graphs); + std::map> graph_names; + for (size_t graph_idx = 0; graph_idx < m_num_graphs; graph_idx++) { + GraphInfo_t* const graph_info = graphs_info[graph_idx]; + GraphVariant graph(graph_info, m_qnnApi->getContexts(graph_info), m_ctx_size, m_layerNames); + __DEBUG("qnn-htp: Graph {}", graph.graph_name); + + if (!variant_latency.empty() && !variant_latency.contains(graph.n_tokens)) { + __WARN("qnn-htp: Disabling {} based on conf file", graph.graph_name); + continue; + } + + if (exec_select_graphs.size() != 0 && + std::find(exec_select_graphs.begin(), exec_select_graphs.end(), graph.graph_name) == + exec_select_graphs.end()) { + __DEBUG("qnn-htp: Graph {} is not selected to execute based on conf file", + graph.graph_name); + continue; + } + m_variant_list.emplace_back(graph); + n_splits = std::max(n_splits, ++nsp_graph_count[graph.n_tokens]); + graph_names[graph.n_tokens].push_back(graph.graph_name); + m_graph_map[std::string(graph_info->graphName)] = &m_variant_list.back(); + } + + if (exec_select_graphs.size() != 0 && graph_names.empty()) { + __ERROR("No matching graphs based on conf file"); + } + + // Create NSPGraph for each splits + m_nsp_graphs.reserve(n_splits); + for (int idx = 0; idx < n_splits; idx++) { + m_nsp_graphs.emplace_back( + idx, _env, m_ctx_size, m_qnnApi.get(), m_ioTensor.get(), _threaded + ); + m_nsp_graphs.back().setDebugMode(_debug_specs, _debug_tensors, _debug_path); + } + + // Insert all GraphVariants into corresponding NSPGraph + for (auto& [n_tokens, graphs] : graph_names) { + std::sort(graphs.begin(), graphs.end()); + for (int idx = 0; idx < graphs.size(); idx++) + m_nsp_graphs.at(idx).addGraph(m_graph_map.at(graphs[idx])); + } + + if (_debug_specs) dumpTensorSpecs(); + + { + __INFO("qnn-htp: Graphs loaded (AR-n: #splits): {}", nsp_graph_count); + + // Check if latency map matches the graphs loaded + if (!variant_latency.empty()) { + for (auto [variant, latency] : variant_latency) { + if (!nsp_graph_count.contains(variant)) { + __ERROR("Latency map (AR-n: 
#latency_ms): {}", variant_latency);
+          __ERROR("AR-{} present in latency map but not loaded!", variant);
+          __ERROR("Fix latency-map in the conf file, must map from AR-n to latency (ms)");
+          return false;
+        }
+      }
+    }
+  }
+
+  __DEBUG("qnn-htp: Model Init complete: {} usec", start.elapsed_usec());
+
+  return true;
+}
+
+// Once the model has been loaded, initialize IO Tensors
+// m_ioTensors is initialized by the context for now
+bool QnnNspModel::initializeIOTensors() {
+
+  if (m_use_async_Init == false) { // IO Tensor Mem Registration is already done within
+                                   // model initialization by Qnn_API for Sync Init.
+
+    // Set lmHeadWeightsEnabled and loraWeightsEnabled
+    _lmhead_weight_input = m_qnnApi->getLmHeadWeightInputEnabled();
+    _lora_enabled = m_qnnApi->getLoraWeightEnabled();
+    for (auto it = nsp_graph_count.rbegin(); it != nsp_graph_count.rend(); ++it) {
+      for (QnnNspGraph& graph : m_nsp_graphs) {
+        // TensorAllocInfo is added to each NSP graph.
+        // Needed by POINTER_SHIFT registration during execute.
+        graph.tensor_alloc_info = m_qnnApi->getTensorAllocInfo();
+        if (graph.tensor_alloc_info == NULL) {
+          __ERROR("Error Tensor Allocation Failed.");
+          return false;
+        }
+      }
+    }
+    return true;
+  }
+
+  // This path is used when async init is enabled.
+  qualla::Timer start;
+
+  __DEBUG("qnn-htp: init IO tensors start");
+
+  // Ideally, we should create and initialize m_ioTensor for each context, but we want to
+  // be able to see/use all the buffers in every context so that they can be connected
+  // with each other. Hence, we are using only the first context to initialize the m_ioTensor
+  // and use it for all graphs/contexts.
+  __DEBUG("qnn-htp: init IO tensor using {}", m_graph_map.begin()->first);
+  if (true != m_ioTensor->initialize(m_graph_map.begin()->second->context_handle)) {
+    __ERROR("qnn-htp: failure to initialize IOTensor");
+    return false;
+  }
+
+  // Technical note: unordered_map is faster than map, but map makes debug logs easier to read.
+  // The runtime impact shouldn't be very large since max size < #tensors
+
+  typedef int CtxBitVector;
+  // Maps context bitVector to a map{tensor_name -> max_tensor_size}
+  std::map<CtxBitVector, std::map<std::string, size_t>> ctx_alloc_map;
+  // Maps tensor_name to context bitVector, each bit representing a context the tensor exists in
+  std::map<std::string, CtxBitVector> tensor_ctx_map;
+  // Maps a ContextHandle to a one-hot encoded bitVector (e.g. 1, 2, 4, ...)
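+  // OR-ing these one-hot values per tensor yields the set of contexts that share a
+  // tensor name, which becomes the grouping key for ctx_alloc_map below.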
+ std::map ctx_to_hash; + + // Iterate over all tensors in all GraphVariants to figure out allocations + for (auto& variant : m_variant_list) { + // Map the context handle to a hashed bitVector + if (!ctx_to_hash.contains(variant.context_handle)) { + ctx_to_hash[variant.context_handle] = 1 << ctx_to_hash.size(); + } + for (auto& tensor_specs : {variant.input_specs, variant.output_specs}) { + for (auto& [tname, tspec] : tensor_specs) { + size_t size = tspec.dims.getAlignedSize(); + CtxBitVector tcontext = ctx_to_hash[variant.context_handle]; + + // Check if it's LoRA enabled model + if (!_lora_enabled && tname.find("lora") != std::string::npos) _lora_enabled = true; + // Check if graph has lmhead weight input + if (!_lmhead_weight_input && tname.compare("weight") == 0) + _lmhead_weight_input = true; + + // Allocate KV Tensors as in+out + if (tname.starts_with("past_")) { + if (tname.ends_with("_in")) continue; // kv_in is processed along with kv_out + + // For kv_out, add the size of kv_in as well + const std::string tname_in = tname.substr(0, tname.rfind('_')).append("_in"); + if (auto tensor = variant.getInput(tname_in)) + size += tensor->dims.getAlignedSize(); + + d_kv = QnnUtils::DataType(tspec.tensor); + + // Allocate extra buffer for pointer shift + // 1024-n for keys (1024-n)*128 for values + // For aligned size, we might as well use 1024 and 128*1024 + if (_kv_update_method == POINTER_SHIFT) + size += (tname.starts_with("past_key")) ? m_ctx_size * d_kv.bw() + : m_ctx_size * m_kv_dim * d_kv.bw(); + } + + if (tensor_ctx_map.contains(tname)) { // For duplicate tensor names, link them + CtxBitVector context_bitvec = tensor_ctx_map.at(tname); + size = std::max(ctx_alloc_map[context_bitvec][tname], size); + if ((context_bitvec & tcontext) == 0) // Set of contexts needs to be updated + ctx_alloc_map[context_bitvec].erase(tname); + + tcontext |= context_bitvec; + } + + ctx_alloc_map[tcontext][tname] = size; + tensor_ctx_map[tname] = tcontext; + } + } + + // Cleanup is essential in case of very large number of splits + for (auto it = ctx_alloc_map.cbegin(); it != ctx_alloc_map.cend();) + it = (it->second.empty()) ? ctx_alloc_map.erase(it) : ++it; + } + + + + _env.logger().compose(Logger::MALLOC_DEBUG, [&](Logger::Helper w) { + for (auto& [tcontext, tensor_alloc_map] : ctx_alloc_map) { + w.write(fmt::format("qnn-htp: ctx_alloc_map[{}] = {{", tcontext)); + for (auto& [tname, tsize] : tensor_alloc_map) + w.write(fmt::format("\t{} : {},", tname, tsize)); + w.write("}"); + } + }); + + // Calculate total allocation sizes and offset of each tensor within its allocated buffer + if (m_ioTensor->allocateBuffers(ctx_alloc_map, tensor_alloc_info) == false) return false; + + _env.logger().compose(Logger::MALLOC_DEBUG, [&](Logger::Helper w) { + w.write("tensor_alloc_info = {"); + for (auto& [tname, toffset] : tensor_alloc_info) + w.write(fmt::format("\t{}: [{}, {}],", tname, toffset.first, toffset.second)); + w.write("}"); + }); + + // For each variant, map tensor name to its allocated buffer, i/o and offset within the buffer + // TODO: Check why we aren't just looping over all variants here! 
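+  // Note: each past_*_out tensor shares one allocation with its matching past_*_in;
+  // the _out view begins right after the _in region (plus the POINTER_SHIFT scratch
+  // area when that update method is active).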
+ for (auto it = nsp_graph_count.rbegin(); it != nsp_graph_count.rend(); ++it) { + + for (QnnNspGraph& graph : m_nsp_graphs) { + + // TODO: Remove this reference + graph.tensor_alloc_info = &tensor_alloc_info; + + auto variant = graph[it->first]; + + std::map> graph_allocs; + + + for (auto& [tname, tspec] : variant->input_specs) { + if (tname.starts_with("past_")) continue; + auto& [alloc_idx, offset] = tensor_alloc_info.at(tname); + graph_allocs[tname] = {alloc_idx, offset, tspec.dims.getAlignedSize()}; + } + + for (auto& [tname, tspec] : variant->output_specs) { + size_t kv_offset = 0; + size_t size = tspec.dims.getAlignedSize(); + + auto& [alloc_idx, offset] = tensor_alloc_info.at(tname); + if (tname.starts_with("past_")) { + auto in_name = tname.substr(0, tname.rfind("_")).append("_in"); + if (auto kv_in = variant->getInput(in_name)) { + kv_offset = kv_in->dims.getAlignedSize(); + if (_kv_update_method == POINTER_SHIFT) + kv_offset += (tname.starts_with("past_key")) + ? m_ctx_size * d_kv.bw() + : m_ctx_size * m_kv_dim * d_kv.bw(); + graph_allocs[in_name] = {alloc_idx, offset, kv_offset}; + } + } + + graph_allocs[tname] = {alloc_idx, offset + kv_offset, size}; + } + + if (!m_ioTensor->mapFusedBufferOffset( + variant->graph_info, variant->context_handle, graph_allocs + )) { + + __ERROR("Error mapping tensor to allocation buffers"); + return false; + } + } + } + + + + __DEBUG("qnn-htp: init IO tensors complete : {} usec", start.elapsed_usec()); + + return true; +} + +static bool checkShape( + const std::string& tensor_name, + const QnnUtils::Tensor* tensor, + int32_t height, + int32_t width, + int32_t channel, + int32_t bitWidth, + std::vector>& errors +) { + if (tensor == nullptr) return true; + const QnnUtils::Dims& tDims = tensor->dims; + + if ((height == -1 || height == tDims.height) && (width == -1 || width == tDims.width) && + (channel == -1 || channel == tDims.channel) && + (bitWidth == -1 || bitWidth == tDims.bitWidth)) + return true; + + std::stringstream err_msg; + err_msg << "Expected [ " << height << ", " << width << ", " << channel << "] " + << "bitWidth=" << bitWidth << ". Found [ " << tDims.height << ", " << tDims.width + << ", " << tDims.channel << "] " + << "bitWidth=" << tDims.bitWidth; + + errors.push_back({"ShapeError", tensor_name, err_msg.str()}); + return false; +} + +// Run all validations for the model here so we can exit early +bool QnnNspModel::validateModel() { + // Checks we will be running + // 1a. input_ids or inputs_embeds exists in the first split + // 1b. token_type_ids should exists in case of Bert + // 2. logits exists in the last split + // 3. Shapes for all named tensors are correct + // 4. All tensors with identical names (incl kv_in/kv_out) have identical quantization params + // Missing check : Shape of tensor between splits match up + + // Support for 16-bit KV Tensors is temporarily disabled + // If you need this, please refer to past commits (QuaLLA <= v0.3.22) + + // Important : These variables need to be set correctly + // m_vocab_size - Calculated as max(logits.shape) since len() + // m_kv_dim - Calculated in this function before usage + // m_ctx_size - Provided by the user as n_ctx + + std::vector> errors; + + QnnUtils::Tensor* tt; + + //default input type is token + m_inputType = InputType::TOKENS; + + // Check 1 - input layer exists + for (auto& [n_tokens, variant] : m_nsp_graphs.front().variants) { + // Update model expectations for E2T if an inputs_embeds layer is present. 
marks the input Type + if ((tt = variant->getInput("inputs_embeds")) != nullptr) { + m_layerNames[LayerType::INPUT] = "inputs_embeds"; + m_inputType = InputType::EMBEDDINGS; + } + if ((tt = variant->getInput(m_layerNames[LayerType::INPUT])) == nullptr) { + errors.push_back({variant->graph_name, m_layerNames[LayerType::INPUT], "Tensor not found"}); + } else { + input_bitWidth = tt->dtype.bw(); + checkShape(m_layerNames[LayerType::INPUT], tt, -1, -1, -1, input_bitWidth, errors); + + if (embedding_datatype == "float32") { + m_embeddingBufferSize = m_embd_size * sizeof(float); + } else { + m_embeddingBufferSize = m_embd_size * input_bitWidth; + } + + // For embedding inputs, the expected count is multiplied by the embedding size. + size_t expectedElementCount = (m_inputType == InputType::TOKENS) ? n_tokens : n_tokens * m_embd_size; + if (tt->dims.getNumElements() != expectedElementCount) + errors.push_back({variant->graph_name, m_layerNames[LayerType::INPUT], "Wrong input shape"}); + } + } + + // Check 1b - In case of BERT :-> token_type_ids + if(m_modelArchitectureType == ModelArchitectureType::ENCODER) { + for (auto &[n_tokens, variant]: m_nsp_graphs.front().variants) { + if ((tt = variant->getInput(m_layerNames[LayerType::TOKEN_TYPE_IDS])) == nullptr) + errors.push_back({variant->graph_name, m_layerNames[LayerType::TOKEN_TYPE_IDS], "Tensor not found"}); + else { + checkShape(m_layerNames[LayerType::TOKEN_TYPE_IDS], tt, -1, -1, -1, 4, errors); + if (tt->dims.getNumElements() != n_tokens) + errors.push_back({variant->graph_name, m_layerNames[LayerType::TOKEN_TYPE_IDS], + "Wrong token_type_ids shape"}); + } + } + } + + // Check 2 - In case of LLama :-> logits exists + // In case of BERT :-> pooled_output & sequence_outputs exists + for (auto& [n_tokens, variant] : m_nsp_graphs.back().variants) { + if (m_modelArchitectureType == ModelArchitectureType::ENCODER) { + if ((tt = variant->getOutput(m_layerNames[LayerType::POOL_OUTPUT])) == nullptr) + errors.push_back({variant->graph_name, m_layerNames[LayerType::POOL_OUTPUT], "Tensor not found"}); + else { + if (tt->dims.getNumElements() != m_embd_size) + errors.push_back( + {variant->graph_name, m_layerNames[LayerType::POOL_OUTPUT], "Wrong pooled_outputs shape"}); + + } + if (!m_pooled_output) { + if ((tt = variant->getOutput(m_layerNames[LayerType::SEQ_OUTPUT])) == nullptr) + errors.push_back({variant->graph_name, m_layerNames[LayerType::SEQ_OUTPUT], "Tensor not found"}); + else { + if (tt->dims.getNumElements() != n_tokens * m_embd_size) + errors.push_back({variant->graph_name, m_layerNames[LayerType::SEQ_OUTPUT], + "Wrong sequence_output shape"}); + + } + } + } else { + if ((tt = variant->getOutput(m_layerNames[LayerType::OUTPUT])) == nullptr) + errors.push_back({variant->graph_name, m_layerNames[LayerType::OUTPUT], "Tensor not found"}); + else { + if (m_vocab_size == -1) m_vocab_size = tt->dims.getMaxDim(); + if (tt->dims.getNumElements() != m_vocab_size && + tt->dims.getNumElements() != n_tokens * m_vocab_size) + errors.push_back({variant->graph_name, m_layerNames[LayerType::OUTPUT], "Wrong logits shape"}); + } + } + } + + // Check 3 - Shapes for all names tensors are correct + if (m_kv_dim == -1) { // Deduce KV$ embed_dim if not already available + for (auto& variant : m_variant_list) { + for (auto& [tname, tspec] : variant.output_specs) + if (tname.starts_with("past_key")) m_kv_dim = tspec.dims.width; + if (m_kv_dim != -1) break; + } + } + + for (auto& variant : m_variant_list) { + auto& n_tokens = variant.n_tokens; + 
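+    // Decoder graphs expect a causal [1, n_tokens, n_ctx] attention mask, whereas
+    // encoder (BERT-style) graphs take a single [1, 1, n_ctx] validity row.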
if(m_modelArchitectureType == ModelArchitectureType::ENCODER){ + checkShape(m_layerNames[LayerType::ATTN_MASK], variant.getInput(m_layerNames[LayerType::ATTN_MASK]), 1, 1, m_ctx_size, -1, errors); + } + else{ + checkShape(m_layerNames[LayerType::ATTN_MASK], variant.getInput(m_layerNames[LayerType::ATTN_MASK]), 1, n_tokens, m_ctx_size, -1, errors); + } + if (m_positional_encoding.type == PositionalEncoding::ROPE) { + checkShape(m_layerNames[LayerType::POS_SIN], variant.getInput(m_layerNames[LayerType::POS_SIN]), 1, n_tokens, m_pos_dim, -1, errors); + checkShape(m_layerNames[LayerType::POS_COS], variant.getInput(m_layerNames[LayerType::POS_COS]), 1, n_tokens, m_pos_dim, -1, errors); + } else if (m_positional_encoding.type == PositionalEncoding::ABSOLUTE) { + checkShape(m_layerNames[LayerType::POS_IDS], variant.getInput(m_layerNames[LayerType::POS_IDS]), 1, 1, n_tokens, -1, errors); + } else if (m_positional_encoding.type == PositionalEncoding::ALIBI) { + checkShape(m_layerNames[LayerType::POS_IDS], variant.getInput(m_layerNames[LayerType::POS_IDS]), 1, n_tokens, m_ctx_size, -1, errors); + } + + if(m_modelArchitectureType != ModelArchitectureType::ENCODER) { + for (auto &[tname, tspec]: variant.input_specs) { + if (tname.starts_with("past_key")) + checkShape(tname, &tspec, -1, m_kv_dim, m_ctx_size - n_tokens, 1, errors); + else if (tname.starts_with("past_value")) + checkShape(tname, &tspec, -1, m_ctx_size - n_tokens, m_kv_dim, 1, errors); + } + + for (auto &[tname, tspec]: variant.output_specs) { + if (tname.starts_with("past_key")) + checkShape(tname, &tspec, -1, m_kv_dim, n_tokens, 1, errors); + else if (tname.starts_with("past_value")) + checkShape(tname, &tspec, -1, n_tokens, m_kv_dim, 1, errors); + } + } + } + + // skip check in case of BERT architecture since no KV cache tensors are existing + if(m_modelArchitectureType != ModelArchitectureType::ENCODER) { + // Check 4 - Quantization parameter match + std::unordered_map quant_params; + for (auto &variant: m_variant_list) { + for (auto &tensor_specs: {variant.input_specs, variant.output_specs}) { + for (auto &[tname, tspec]: tensor_specs) { + std::string name = (tname.starts_with("past_") && tname.ends_with("_in")) + ? tname.substr(0, tname.rfind("_")).append("_out") + : tname; + if (name.compare(m_layerNames[LayerType::OUTPUT]) == 0) continue; + if (quant_params.contains(name)) { + if (quant_params.at(name).scale != tspec.quantParam[0].scale || + quant_params.at(name).offset != tspec.quantParam[0].offset) + errors.push_back( + {variant.graph_name, + tname, + "Non-identical quantization parameters found for the same tensor"} + ); + } else + quant_params[tname] = {tspec.quantParam[0].scale, tspec.quantParam[0].offset}; + } + } + } + } + + if (errors.size() > 0) { + QNN_ERROR("Model Validation Errors found"); + for (auto& [graph_name, tensor_name, err_msg] : errors) // Log the list of errors + QNN_ERROR("%s : %s - %s", graph_name.c_str(), tensor_name.c_str(), err_msg.c_str()); + QNN_ERROR("Note: -1 means ignore (i.e. no comparison)"); + QNN_ERROR("Check model i/o specs (set dump-specs=true in config) for debugging"); + return false; + } + + return true; +} + +bool QnnNspModel::initializeKVManager() { + + if(m_disableKvCache){ + return true; + } + + // Pick the largest variant + int32_t variant = nsp_graph_count.rbegin()->first; + + int idx = 0; + for (auto& graph : m_nsp_graphs) { + auto& specs = (variant == m_ctx_size) ? graph[variant]->output_specs + : graph[variant]->input_specs; + + ThreadPool* _pool = _threaded ? 
&threadpool : nullptr; + // clang-format off + NewNSPKVManager *manager = new NewNSPKVManager( idx++, _env, _pool, m_ioTensor.get(), + specs, m_ctx_size, m_kv_dim, _kv_update_method); + // clang-format on + graph.registerKVManager(manager); + + if (_kv_update_method == POINTER_SHIFT) + graph.kvmanager->setTensorAllocInfo(&tensor_alloc_info); + } + + _kv_dispatcher = + std::unique_ptr(new KVDispatcher(_env, m_nsp_graphs, _threaded, _cpumask) + ); + _kv_update_count = _kv_dispatcher->dispatch(variant, 0); + + return true; +} + +inline bool QnnNspModel::updateTensorPointer( + GraphVariant& variant, + std::string& key, + QnnUtils::Tensor*& t +) { + QnnUtils::Tensor* tensor_ptr = variant.getInput(key); + if (tensor_ptr == nullptr) return true; + if (t == nullptr) t = tensor_ptr; + if (getBuffer(t) == getBuffer(tensor_ptr)) return true; + + __ERROR("{} has different addresses: {} vs {}", key, (void*)t, (void*)tensor_ptr); + return false; +} + +bool QnnNspModel::initializeTensorPointers() { + // Ideally this needs to be done for all sets of AR-n available, e.g. for AR-1 and AR-1024 + + bool status = true; + for (auto& variant : m_variant_list) { + status &= updateTensorPointer(variant, m_layerNames[LayerType::INPUT], t_input_ids); + status &= updateTensorPointer(variant, m_layerNames[LayerType::ATTN_MASK], t_attn_mask); + status &= updateTensorPointer(variant, m_layerNames[LayerType::POS_SIN], t_position_ids_sin); + status &= updateTensorPointer(variant, m_layerNames[LayerType::POS_COS], t_position_ids_cos); + status &= updateTensorPointer(variant, m_layerNames[LayerType::POS_IDS], t_position_ids); + status &= updateTensorPointer(variant, m_layerNames[LayerType::TOKEN_TYPE_IDS], t_token_type_ids); + } + if (!status) __ERROR("qnn-htp: Error in setting up named tensor pointers."); + + status &= !(!t_input_ids || !t_attn_mask); + if (!t_input_ids) __ERROR("Tensor not found: {}", m_layerNames[LayerType::INPUT]); + if (!t_attn_mask) __ERROR("Tensor not found: {}", m_layerNames[LayerType::ATTN_MASK]); + + if(m_modelArchitectureType == ModelArchitectureType::ENCODER){ // This input only valid for Encoder only model like bert. 
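+    // (Decoder-only models have no token_type_ids input, so this check is skipped.)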
+ status &= !(!t_token_type_ids); + if (!t_token_type_ids) __ERROR("Tensor not found: {}", m_layerNames[LayerType::TOKEN_TYPE_IDS]); + } + + if (m_positional_encoding.type == PositionalEncoding::ROPE) { + status &= !(!t_position_ids_sin || !t_position_ids_cos); + if (!t_position_ids_sin) __ERROR("Tensor not found: {}", m_layerNames[LayerType::POS_SIN]); + if (!t_position_ids_cos) __ERROR("Tensor not found: {}", m_layerNames[LayerType::POS_COS]); + } else if (m_positional_encoding.type == PositionalEncoding::ABSOLUTE) { + status &= !(!t_position_ids); + if (!t_position_ids) __ERROR("Tensor not found: {}", m_layerNames[LayerType::POS_IDS]); + } else if (m_positional_encoding.type == PositionalEncoding::ALIBI) { + status &= !(!t_position_ids); + if (!t_position_ids) __ERROR("Tensor not found: {}", m_layerNames[LayerType::POS_IDS]); + } else { + __ERROR("Unknown Rope Type found for tensor: {}", m_layerNames[LayerType::POS_IDS]); + } + + // Detect activation bitwidth + if (status) { + //Check Input-> Input_ID or Input_Embed + d_input = t_input_ids->dtype; + if (!supported_activations.contains(d_input)) { + __ERROR("Input Tensor: {} as unsupported activation type {}", m_layerNames[LayerType::INPUT], d_input.str()); + status = false; + } + // Check Attention Mask + d_attn_map = t_attn_mask->dtype; + if (!supported_activations.contains(d_attn_map)) { + __ERROR("attention_mask has unsupported type {}", d_attn_map.str()); + status = false; + } + // For Encoder only model, Check for Token_type_ids + if(m_modelArchitectureType == ModelArchitectureType::ENCODER) { + d_token_type = t_token_type_ids->dtype; + if (!supported_activations.contains(d_token_type)) { + __ERROR("token_type_ids has unsupported type {}", d_token_type.str()); + status = false; + } + } + + //For Position_IDs check data bitWidth + if (m_positional_encoding.type == PositionalEncoding::ROPE) + d_pos = t_position_ids_sin->dtype; + else if (m_positional_encoding.type == PositionalEncoding::ABSOLUTE) + d_pos = t_position_ids->dtype; + else if (m_positional_encoding.type == PositionalEncoding::ALIBI) + d_pos = t_position_ids->dtype; + + if (((m_positional_encoding.type == PositionalEncoding::ABSOLUTE || + m_positional_encoding.type == PositionalEncoding::ALIBI) && + d_pos != QNN_DATATYPE_INT_32) || + (m_positional_encoding.type == PositionalEncoding::ROPE && + !supported_activations.contains(d_pos))) { + __ERROR("position encoding tensor has unsupported type {}", d_pos.str()); + status = false; + } + __DEBUG("qnn-htp datatypes: d_input {} d_attn_map {} d_pos {} d_kv {}", + d_input.str(), + d_attn_map.str(), + d_pos.str(), + d_kv.str()); + + if (!status) __ERROR("Only 8-bit, 16-bit and 32-bit activations are supported"); + } + + return status; +} +bool QnnNspModel::setupAttentionMaskFP16(bool pad_left, + int n_tokens, + int n_inputs, + int n_past, + std::span attention_map, + size_t n_skip_prefix, + size_t n_apply_prefix_offset) { + QnnUtils::Dims t_attn_mask_dims = t_attn_mask->dims; + size_t numElements = t_attn_mask_dims.getNumElements(); + size_t bufSize = numElements * 2; // (bitwidth = 16, in bytes: 16/8) + std::vector attn_mask_vec(bufSize); + if (!float32ToFloat16((unsigned char *)attn_mask_vec.data(), (float *) getBuffer(t_attn_mask), numElements)) { + QNN_ERROR("Number of elements is 0"); + return false; + } + // Setup attention mask + { + uint16_t* attn_buffer = (uint16_t*)attn_mask_vec.data(); + const int n_valid = n_past + n_inputs; + + uint16_t pos_val = -1, neg_val = 0; + pos_val = 0; + neg_val = -1000; + + // Clear the 
attention mask + std::fill_n(attn_buffer, n_tokens * m_ctx_size, neg_val); + if (attention_map.empty()) { + uint16_t* cur_ptr = &attn_buffer + [(pad_left) ? (m_ctx_size - n_valid) * (m_ctx_size + 1) + : m_ctx_size - n_past - n_tokens]; + for (int n_masked = n_past + 1; n_masked <= n_valid; n_masked++) { + std::fill_n(cur_ptr, n_masked, pos_val); + cur_ptr += m_ctx_size; + } + } else if (attention_map.size() == n_inputs) { + // Only fill in n_inputs. Rest will be padding + const size_t attn_row_start = m_ctx_size - n_past - n_tokens; + for (int i = 0; i < n_inputs; i++) { + uint16_t* cur_ptr = &attn_buffer[i * m_ctx_size + attn_row_start]; + + cur_ptr[n_past + i] = pos_val; // Attend to itself + if (attention_map[i] < 0) { // If negative, attend to only past tokens + int32_t n_masked = n_past + attention_map[i] + 1; + if (i < n_apply_prefix_offset) { // Skip prefix is needed + cur_ptr += n_skip_prefix; + n_masked -= n_skip_prefix; + } + std::fill_n(cur_ptr, n_masked, pos_val); + + } else { // If positive, copy attention map from (relative to 0th input) parent + const int32_t pidx = attention_map[i]; // Parent token index + uint16_t* parent_ptr = &attn_buffer[pidx * m_ctx_size + attn_row_start]; + std::memcpy(cur_ptr, parent_ptr, (n_past + pidx + 1) * sizeof(uint16_t)); + + // If parent skipped prefix, but this token needs to attend to prefix, add attn + if (i >= n_apply_prefix_offset && pidx < n_apply_prefix_offset) + std::fill_n(cur_ptr, n_skip_prefix, pos_val); + } + } + } else if (attention_map.size() == n_valid * n_inputs) { + uint16_t* cur_ptr = &attn_buffer[m_ctx_size - n_past - n_tokens]; + for (int i = 0; i < n_inputs; i++) { + for (int j = 0; j < n_valid; j++) + cur_ptr[j] = (attention_map[i * n_valid + j] == 0) ? neg_val : pos_val; + cur_ptr += m_ctx_size; + } + } + } + + return true; + +} +template +bool QnnNspModel::setupAttentionMask( + bool pad_left, + int n_tokens, + int n_inputs, + int n_past, + std::span attention_map, + size_t n_skip_prefix, + size_t n_apply_prefix_offset +) { + // Setup attention mask + { + DType* attn_buffer = (DType*)getBuffer(t_attn_mask); + const int n_valid = n_past + n_inputs; + + DType pos_val = -1, neg_val = 0; + + if(m_modelArchitectureType == ModelArchitectureType::ENCODER){ + pos_val = 1; // BGE model is using 1 to set attention mask and 0 to unset. + std::memset(attn_buffer, neg_val, 1 * m_ctx_size * sizeof(DType)); + size_t in_buf_offset = pad_left ? m_ctx_size - n_valid : 0; + DType* cur_ptr = &attn_buffer[in_buf_offset]; + std::fill_n(cur_ptr, n_valid, pos_val); + } + else { + // Clear the attention mask + std::fill_n(attn_buffer, n_tokens * m_ctx_size, neg_val); + if (attention_map.empty()) { + DType *cur_ptr = &attn_buffer + [(pad_left) ? (m_ctx_size - n_valid) * (m_ctx_size + 1) + : m_ctx_size - n_past - n_tokens]; + for (int n_masked = n_past + 1; n_masked <= n_valid; n_masked++) { + std::fill_n(cur_ptr, n_masked, pos_val); + cur_ptr += m_ctx_size; + } + } else if (attention_map.size() == n_inputs) { + // Only fill in n_inputs. 
Rest will be padding + const size_t attn_row_start = m_ctx_size - n_past - n_tokens; + for (int i = 0; i < n_inputs; i++) { + DType *cur_ptr = &attn_buffer[i * m_ctx_size + attn_row_start]; + + cur_ptr[n_past + i] = pos_val; // Attend to itself + if (attention_map[i] < 0) { // If negative, attend to only past tokens + int32_t n_masked = n_past + attention_map[i] + 1; + if (i < n_apply_prefix_offset) { // Skip prefix is needed + cur_ptr += n_skip_prefix; + n_masked -= n_skip_prefix; + } + std::fill_n(cur_ptr, n_masked, pos_val); + + } else { // If positive, copy attention map from (relative to 0th input) parent + const int32_t pidx = attention_map[i]; // Parent token index + DType *parent_ptr = &attn_buffer[pidx * m_ctx_size + attn_row_start]; + std::memcpy(cur_ptr, parent_ptr, (n_past + pidx + 1) * sizeof(DType)); + + // If parent skipped prefix, but this token needs to attend to prefix, add attn + if (i >= n_apply_prefix_offset && pidx < n_apply_prefix_offset) + std::fill_n(cur_ptr, n_skip_prefix, pos_val); + } + } + } else if (attention_map.size() == n_valid * n_inputs) { + DType *cur_ptr = &attn_buffer[m_ctx_size - n_past - n_tokens]; + for (int i = 0; i < n_inputs; i++) { + for (int j = 0; j < n_valid; j++) + cur_ptr[j] = (attention_map[i * n_valid + j] == 0) ? neg_val : pos_val; + cur_ptr += m_ctx_size; + } + } + } + } + + return true; +} + bool QnnNspModel::setupRopePositionEmbeddingFP16( + bool pad_left, + int n_tokens, + int n_inputs, + int n_past, + std::span attention_map, + size_t n_skip_prefix, + size_t n_apply_prefix_offset + ) { + const int n_valid = n_past + n_inputs; + + // Cast RoPE embeddings to proper dtype + // The following two buffers are already converted to fp16 + uint16_t* typed_rope_sin = (uint16_t*)rope_sin; + uint16_t* typed_rope_cos = (uint16_t*)rope_cos; + + // These two need conversion + + QnnUtils::Dims t_position_ids_cos_dims = t_position_ids_cos->dims; + size_t numElements = t_position_ids_cos_dims.getNumElements(); + size_t bufSize = numElements * 2; // (bitwidth = 16, in bytes: 16/8) + std::vector position_ids_cos_vec(bufSize); + if (!float32ToFloat16((unsigned char *)position_ids_cos_vec.data(), (float *) getBuffer(t_position_ids_cos), numElements)) { + QNN_ERROR("Number of elements is 0"); + return false; + } + uint16_t* cos_buffer = (uint16_t*)position_ids_cos_vec.data(); + + QnnUtils::Dims t_position_ids_sin_dims = t_position_ids_sin->dims; + numElements = t_position_ids_sin_dims.getNumElements(); + bufSize = numElements * 2; // (bitwidth = 16, in bytes: 16/8) + std::vector position_ids_sin_vec(bufSize); + if (!float32ToFloat16((unsigned char *)position_ids_sin_vec.data(), (float *) getBuffer(t_position_ids_sin), numElements)) { + QNN_ERROR("Number of elements is 0"); + return false; + } + uint16_t* sin_buffer = (uint16_t*)position_ids_sin_vec.data(); + + // Clear out all position_ids as position_sin/cos[0] + const size_t pos_row_size = m_pos_dim * sizeof(uint16_t); + for (int i = 0; i < n_tokens; i++) { + std::memcpy(&sin_buffer[i * m_pos_dim], typed_rope_sin, pos_row_size); + std::memcpy(&cos_buffer[i * m_pos_dim], typed_rope_cos, pos_row_size); + } + + // Copy in position embeddings [0:(n_valid-1)] to input sin/cos buffer + const size_t pos_buf_offset = m_pos_dim * ((pad_left) ? 
m_ctx_size - n_valid : 0); + if (attention_map.size() == n_inputs) { + // Copy embeddings one by one based on the attention map + std::vector pos_ids(n_inputs, 0); + auto sin = &sin_buffer[pos_buf_offset]; + auto cos = &cos_buffer[pos_buf_offset]; + + // 1st token + pos_ids[0] = m_nPast - n_skip_prefix; + std::memcpy(sin, &typed_rope_sin[pos_ids[0] * m_pos_dim], pos_row_size); + std::memcpy(cos, &typed_rope_cos[pos_ids[0] * m_pos_dim], pos_row_size); + sin += m_pos_dim; + cos += m_pos_dim; + + // Rest + for (int i = 1; i < n_inputs; i++) { + auto parent_index = attention_map[i]; + pos_ids[i] = pos_ids[parent_index] + 1; + std::memcpy(sin, &typed_rope_sin[pos_ids[i] * m_pos_dim], pos_row_size); + std::memcpy(cos, &typed_rope_cos[pos_ids[i] * m_pos_dim], pos_row_size); + sin += m_pos_dim; + cos += m_pos_dim; + } + } else if (attention_map.size() == (n_past + n_inputs) * n_inputs) { + // For now, simply have the same position ID across the variant + auto sin = &sin_buffer[0]; + auto cos = &cos_buffer[0]; + + // Calculate position based on number of items this index is attending to + for (int i = 0; i < n_inputs; i++) { + auto attn_row = attention_map.subspan(i * n_valid, n_valid); + int32_t pos_id = + std::accumulate(attn_row.begin() + n_skip_prefix, attn_row.end(), 0) - attn_row[n_past + i]; + + // __DEBUG("PositionID [ i={}, n_past={}, pos_id={} ]", i, n_past, pos_id); + + std::memcpy(sin, &typed_rope_sin[pos_id * m_pos_dim], pos_row_size); + std::memcpy(cos, &typed_rope_cos[pos_id * m_pos_dim], pos_row_size); + sin += m_pos_dim; + cos += m_pos_dim; + } + } else { + const size_t pos_dat_offset = m_pos_dim * (n_past - n_skip_prefix); + const size_t pos_cpy_amt = pos_row_size * ((pad_left) ? n_valid : n_tokens); + std::memcpy(&sin_buffer[pos_buf_offset], &typed_rope_sin[pos_dat_offset], pos_cpy_amt); + std::memcpy(&cos_buffer[pos_buf_offset], &typed_rope_cos[pos_dat_offset], pos_cpy_amt); + } + + return true; + } +template +bool QnnNspModel::setupRopePositionEmbedding( + bool pad_left, + int n_tokens, + int n_inputs, + int n_past, + std::span attention_map, + size_t n_skip_prefix, + size_t n_apply_prefix_offset +) { + + const int n_valid = n_past + n_inputs; + + // Cast RoPE embeddings to proper dtype + DType* typed_rope_sin = (DType*)rope_sin; + DType* typed_rope_cos = (DType*)rope_cos; + + DType* cos_buffer = (DType*)getBuffer(t_position_ids_cos); + DType* sin_buffer = (DType*)getBuffer(t_position_ids_sin); + + // Clear out all position_ids as position_sin/cos[0] + const size_t pos_row_size = m_pos_dim * sizeof(DType); + for (int i = 0; i < n_tokens; i++) { + std::memcpy(&sin_buffer[i * m_pos_dim], typed_rope_sin, pos_row_size); + std::memcpy(&cos_buffer[i * m_pos_dim], typed_rope_cos, pos_row_size); + } + + // Copy in position embeddings [0:(n_valid-1)] to input sin/cos buffer + const size_t pos_buf_offset = m_pos_dim * ((pad_left) ? 
m_ctx_size - n_valid : 0); + if (attention_map.size() == n_inputs) { + // Copy embeddings one by one based on the attention map + std::vector pos_ids(n_inputs, 0); + auto sin = &sin_buffer[pos_buf_offset]; + auto cos = &cos_buffer[pos_buf_offset]; + + // 1st token + pos_ids[0] = m_nPast - n_skip_prefix; + std::memcpy(sin, &typed_rope_sin[pos_ids[0] * m_pos_dim], pos_row_size); + std::memcpy(cos, &typed_rope_cos[pos_ids[0] * m_pos_dim], pos_row_size); + sin += m_pos_dim; + cos += m_pos_dim; + + // Rest + for (int i = 1; i < n_inputs; i++) { + auto parent_index = attention_map[i]; + pos_ids[i] = pos_ids[parent_index] + 1; + std::memcpy(sin, &typed_rope_sin[pos_ids[i] * m_pos_dim], pos_row_size); + std::memcpy(cos, &typed_rope_cos[pos_ids[i] * m_pos_dim], pos_row_size); + sin += m_pos_dim; + cos += m_pos_dim; + } + } else if (attention_map.size() == (n_past + n_inputs) * n_inputs) { + // For now, simply have the same position ID across the variant + auto sin = &sin_buffer[0]; + auto cos = &cos_buffer[0]; + + // Calculate position based on number of items this index is attending to + for (int i = 0; i < n_inputs; i++) { + auto attn_row = attention_map.subspan(i * n_valid, n_valid); + int32_t pos_id = + std::accumulate(attn_row.begin() + n_skip_prefix, attn_row.end(), 0) - attn_row[n_past + i]; + + // __DEBUG("PositionID [ i={}, n_past={}, pos_id={} ]", i, n_past, pos_id); + + std::memcpy(sin, &typed_rope_sin[pos_id * m_pos_dim], pos_row_size); + std::memcpy(cos, &typed_rope_cos[pos_id * m_pos_dim], pos_row_size); + sin += m_pos_dim; + cos += m_pos_dim; + } + } else { + const size_t pos_dat_offset = m_pos_dim * (n_past - n_skip_prefix); + const size_t pos_cpy_amt = pos_row_size * ((pad_left) ? n_valid : n_tokens); + std::memcpy(&sin_buffer[pos_buf_offset], &typed_rope_sin[pos_dat_offset], pos_cpy_amt); + std::memcpy(&cos_buffer[pos_buf_offset], &typed_rope_cos[pos_dat_offset], pos_cpy_amt); + } + + return true; +} + +template +bool QnnNspModel::setupAlibiPositionEmbedding( + bool pad_left, + int n_tokens, + int n_inputs, + int n_past +) { + DType* alibi_buffer = (DType*)getBuffer(t_position_ids); + + const int n_valid = n_past + n_inputs; + const DType pad_val = m_ctx_size; + + // Clear alibi buffer + std::fill_n(alibi_buffer, n_tokens * m_ctx_size, pad_val); + + // Detect start of past tokens and new tokens based on m_ctx_size and n_tokens (variant) + DType* alibi_past = alibi_buffer; // [0, m_ctx_size-n_tokens) + DType* alibi_new = alibi_buffer + m_ctx_size - n_tokens; // [m_ctx_size-n_tokens, m_ctx_size) + + // For non SMART_MASK, past tokens/KV$ is left-padded and past ptr needs to be offset by padding + alibi_past += m_ctx_size - n_tokens - n_past; + + // For left padded inputs, new pointer needs to be offset by n_tokens - n_inputs + if (pad_left) { + alibi_new += n_tokens - n_inputs; + alibi_past += (n_tokens - n_inputs) * m_ctx_size; + alibi_new += (n_tokens - n_inputs) * m_ctx_size; + } + + // Fill alibi positions from [-n_past-i, -i) and [-i, 0] + for (int i = 0; i < n_inputs; i++) { + std::iota( + std::reverse_iterator(alibi_past + n_past), + std::reverse_iterator(alibi_past), + i + 1 + ); // Fill past tokens + std::iota( + std::reverse_iterator(alibi_new + i + 1), + std::reverse_iterator(alibi_new), + 0 + ); // Fill new tokens + + alibi_past += m_ctx_size; // Update pointers to next row + alibi_new += m_ctx_size; + } + + return true; +} + +bool QnnNspModel::setupInputTensors( + std::span tokens, + int32_t n_past, + std::span attention_map, + size_t n_skip_prefix, + size_t 
n_apply_prefix_offset +) { + qualla::Timer start; + + const int n_tokens = run_info.n_tokens; + const int n_inputs = run_info.n_processed; + const int32_t n_valid = n_past + n_inputs; + __TRACE("qnn-htp: setup-input-tensors with {} tokens for AR-{}", n_inputs, n_tokens); + + const bool pad_left = (n_tokens == m_ctx_size); + if (n_inputs > n_tokens) { + __ERROR("qnn-htp: setup-input-tensors too many tokens: {} on AR-{}", n_inputs, n_tokens); + return false; + } + + // Setup input id tensor + { + uint32_t* input_id_buffer = (uint32_t*)getBuffer(t_input_ids); + std::fill_n(input_id_buffer, n_tokens, static_cast(m_pad_token)); + + size_t in_buf_offset = pad_left ? n_tokens - n_inputs : 0; + std::memcpy(&input_id_buffer[in_buf_offset], tokens.data(), n_inputs * sizeof(uint32_t)); + } + + // clang-format off + switch (d_attn_map) { + case QNN_DATATYPE_UFIXED_POINT_8: + setupAttentionMask(pad_left, n_tokens, n_inputs, n_past, attention_map, n_skip_prefix, n_apply_prefix_offset); break; + case QNN_DATATYPE_UFIXED_POINT_16: + setupAttentionMask(pad_left, n_tokens, n_inputs, n_past, attention_map, n_skip_prefix, n_apply_prefix_offset); break; + case QNN_DATATYPE_INT_32: + setupAttentionMask(pad_left, n_tokens, n_inputs, n_past, attention_map, n_skip_prefix, n_apply_prefix_offset); break; + case QNN_DATATYPE_FLOAT_16: { + setupAttentionMaskFP16(pad_left, n_tokens, n_inputs, n_past, attention_map, n_skip_prefix, + n_apply_prefix_offset); + break; + } + default: __ERROR("Unsupported attention mask dtype {}", d_attn_map.str()); return false; + } + // clang-format on + + // Setup token type IDs + if(m_modelArchitectureType == ModelArchitectureType::ENCODER) { + //BERT Specific + uint32_t *token_type_id_buffer = (uint32_t *) getBuffer(t_token_type_ids); + std::memset(token_type_id_buffer, 0, n_tokens * sizeof(uint32_t)); + } + + // Setup position IDs + if (m_positional_encoding.type == PositionalEncoding::ROPE) { + // clang-format off + switch (d_pos) { + case QNN_DATATYPE_UFIXED_POINT_8: + setupRopePositionEmbedding(pad_left, n_tokens, n_inputs, n_past, attention_map, n_skip_prefix, n_apply_prefix_offset); break; + case QNN_DATATYPE_UFIXED_POINT_16: + setupRopePositionEmbedding(pad_left, n_tokens, n_inputs, n_past, attention_map, n_skip_prefix, n_apply_prefix_offset); break; + case QNN_DATATYPE_FLOAT_16: + setupRopePositionEmbeddingFP16(pad_left, n_tokens, n_inputs, n_past, attention_map, n_skip_prefix, n_apply_prefix_offset); break; + default: __ERROR("Unsupported rope position dtype {}", d_pos.str()); return false; + } + // clang-format on + } else if (m_positional_encoding.type == PositionalEncoding::ABSOLUTE) { + uint32_t* position_id_buffer = (uint32_t*)getBuffer(t_position_ids); + std::memset(position_id_buffer, 0, n_tokens * sizeof(uint32_t)); + + // Fill up position_ids buffer + uint32_t* pos_id_start = &position_id_buffer[pad_left ? 
n_tokens - n_inputs : 0]; + uint32_t* pos_id_end = pos_id_start + n_inputs; + std::iota(pos_id_start, pos_id_end, n_past); + } else if (m_positional_encoding.type == PositionalEncoding::ALIBI) { + setupAlibiPositionEmbedding(pad_left, n_tokens, n_inputs, n_past); + } + + __TRACE("qnn-htp: setup-input-tensors complete : {} usec", start.elapsed_usec()); + return true; +} + + +bool QnnNspModel::setupInputTensors( + std::span embedding, + int32_t n_past, + std::span attention_map, + size_t n_skip_prefix, + size_t n_apply_prefix_offset +) { + qualla::Timer start; + + const int n_tokens = run_info.n_tokens; + const int n_inputs = run_info.n_processed; + const int32_t n_valid = n_past + n_inputs; + __TRACE("qnn-htp: setup-input-tensors with {} tokens for AR-{}", n_inputs, n_tokens); + + const bool pad_left = (n_tokens == m_ctx_size); + if (n_inputs > n_tokens) { + __ERROR("qnn-htp: setup-input-tensors too many tokens: {} on AR-{}", n_inputs, n_tokens); + return false; + } + + // Setup input embeds tensor + { + // Quantize and fill, don't make double copy + size_t in_buf_offset = pad_left ? n_tokens - n_inputs : 0; + size_t startIdx = pad_left ? 0 : n_inputs; + size_t endIdx = pad_left ? in_buf_offset : n_tokens; + + if (embedding_datatype == "float32") { + // First flush the buffer with eos token embedding + for (size_t i = startIdx; i < endIdx; i++) { + quantizeInput((float*)m_eosEmbedding.data(), i*m_embd_size, m_embd_size); + } + + // Quantize the data input vector + quantizeInput((float*)embedding.data(), in_buf_offset*m_embd_size, n_inputs * m_embd_size); + } else if (embedding_datatype == "native") { + // Size of the buffer for one embedding vector. + const size_t embedBufSize = m_embeddingBufferSize; + // First flush the buffer with eos token embedding + uint8_t* embeddingSrc = static_cast(m_eosEmbedding.data()); + for (size_t i = startIdx; i < endIdx; i++) { + std::copy(embeddingSrc, embeddingSrc + embedBufSize, (uint8_t*)getBuffer(t_input_ids) + i*embedBufSize); + } + + // Copy the data input vector + embeddingSrc = static_cast(embedding.data()); + std::copy(embeddingSrc, embeddingSrc + embedding.size(), (uint8_t*)getBuffer(t_input_ids) + in_buf_offset*embedBufSize); + } + } + + // Don't modify attention mask it should work out of the box + // clang-format off + switch (d_attn_map) { + case QNN_DATATYPE_UFIXED_POINT_8: + setupAttentionMask(pad_left, n_tokens, n_inputs, n_past, attention_map, n_skip_prefix, n_apply_prefix_offset); break; + case QNN_DATATYPE_UFIXED_POINT_16: + setupAttentionMask(pad_left, n_tokens, n_inputs, n_past, attention_map, n_skip_prefix, n_apply_prefix_offset); break; + case QNN_DATATYPE_INT_32: + setupAttentionMask(pad_left, n_tokens, n_inputs, n_past, attention_map, n_skip_prefix, n_apply_prefix_offset); break; + case QNN_DATATYPE_FLOAT_16: { + setupAttentionMaskFP16(pad_left, n_tokens, n_inputs, n_past, attention_map, n_skip_prefix, + n_apply_prefix_offset); + break; + } + default: __ERROR("Unsupported attention mask dtype {}", d_attn_map.str()); return false; + } + // clang-format on + + // Setup token type IDs // Will not be + if(m_modelArchitectureType == ModelArchitectureType::ENCODER) { + //BERT Specific + uint32_t *token_type_id_buffer = (uint32_t *) getBuffer(t_token_type_ids); + std::memset(token_type_id_buffer, 0, n_tokens * sizeof(uint32_t)); + } + + // Setup position IDs + if (m_positional_encoding.type == PositionalEncoding::ROPE) { + // clang-format off + switch (d_pos) { + case QNN_DATATYPE_UFIXED_POINT_8: + setupRopePositionEmbedding(pad_left, 
n_tokens, n_inputs, n_past, attention_map, n_skip_prefix, n_apply_prefix_offset); break; + case QNN_DATATYPE_UFIXED_POINT_16: + setupRopePositionEmbedding(pad_left, n_tokens, n_inputs, n_past, attention_map, n_skip_prefix, n_apply_prefix_offset); break; + case QNN_DATATYPE_FLOAT_16: + setupRopePositionEmbeddingFP16(pad_left, n_tokens, n_inputs, n_past, attention_map, n_skip_prefix, n_apply_prefix_offset); break; + default: __ERROR("Unsupported rope position dtype {}", d_pos.str()); return false; + } + // clang-format on + } else if (m_positional_encoding.type == PositionalEncoding::ABSOLUTE) { + uint32_t* position_id_buffer = (uint32_t*)getBuffer(t_position_ids); + std::memset(position_id_buffer, 0, n_tokens * sizeof(uint32_t)); + + // Fill up position_ids buffer + uint32_t* pos_id_start = &position_id_buffer[pad_left ? n_tokens - n_inputs : 0]; + uint32_t* pos_id_end = pos_id_start + n_inputs; + std::iota(pos_id_start, pos_id_end, n_past); + } + + __TRACE("qnn-htp: setup-input-tensors complete : {} usec", start.elapsed_usec()); + return true; +} + +bool QnnNspModel::runInferenceHelper(bool pipeline, int32_t* total_wait, int32_t* total_exec) { + // run_info is set in runInference + int32_t idx = 0; + int32_t wait_kv_update_count = _kv_update_count; + + auto [variant, n_processed, tokens] = run_info; // based on type one of the embedding and token vector will be empty. + for (auto& nsp_graph : m_nsp_graphs) { + //__DEBUG("execute({}, {}, {})", variant, m_inference_count, wait_kv_update_count); + if (!nsp_graph.execute(variant, m_inference_count, wait_kv_update_count)) return false; + auto [cur_wait, cur_exec] = nsp_graph.getExecutionStats(); + + // If we are pipelining execution with KV$Update, dispatch KV$ update jobs + if (pipeline) { + qualla::Timer timer; + + int32_t n_past = static_cast(m_nPast + n_processed); + if(!m_disableKvCache) + _kv_update_count = _kv_dispatcher->dispatch(idx, variant, n_past); + cur_wait += timer.elapsed_usec(); + } + + *total_exec += cur_exec; + *total_wait += cur_wait; + idx++; + } + + if (pipeline) { + if(m_inputType == InputType::TOKENS) // used tokens for processing, save them + token_history.insert(token_history.end(), &tokens[0], &tokens[n_processed]); + else if(m_inputType == InputType::UNKNOWN) + { + __ERROR("Unknown input type found"); + return false; + } + m_nPast += n_processed; + } + + if (_debug_outputs){ + if(m_modelArchitectureType == ModelArchitectureType::ENCODER){ + if(!debugOutputs(m_nsp_graphs.back().variants[run_info.n_tokens]->getOutput(m_layerNames[LayerType::POOL_OUTPUT]), m_layerNames[LayerType::POOL_OUTPUT])){ + __DEBUG("qnn-htp : Failed to save {} tensor", m_layerNames[LayerType::POOL_OUTPUT]); + } + if(!debugOutputs(m_nsp_graphs.back().variants[run_info.n_tokens]->getOutput(m_layerNames[LayerType::SEQ_OUTPUT]), m_layerNames[LayerType::SEQ_OUTPUT])){ + __DEBUG("qnn-htp : Failed to save {} tensor", m_layerNames[LayerType::SEQ_OUTPUT]); + } + } + else { + if(!debugOutputs(m_nsp_graphs.back().variants[variant]->getOutput(m_layerNames[LayerType::OUTPUT]), m_layerNames[LayerType::OUTPUT])) { + __DEBUG("qnn-htp : Failed to save {} tensor", m_layerNames[LayerType::OUTPUT]); + } + } + } + + m_inference_count++; + return true; +} + +bool QnnNspModel::debugOutputs(QnnUtils::Tensor* outTensor, std::string& outTensorName){ + + if(outTensor == NULL){ + __DEBUG("qnn-htp : Encountered NULL Tensor"); + return false; + } + + auto [variant, n_processed, tokens] = run_info; + + int output_bw = outTensor->dtype.bw(); // Detect 8-bit vs 16-bit logits + 
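+  // Illustrative layout note (the numbers are an example, not from the source):
+  // on the full-context variant (variant == m_ctx_size, left-padded inputs) with
+  // m_ctx_size = 2048 and n_processed = 3, the valid output rows are [2045, 2048),
+  // so `offset` below evaluates to 2045; smaller variants are right-padded and
+  // keep their valid rows starting at offset 0.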
uint8_t *output_buffer = (uint8_t *) getBuffer(outTensor); + + int32_t offset = (variant == m_ctx_size) ? (m_ctx_size - n_processed) : 0; + int32_t bufsize = 0; + if(m_modelArchitectureType == ModelArchitectureType::ENCODER){ + bufsize = m_ctx_size * m_embd_size * output_bw; // ctx * embed_size * output_bitwidth + // Bert is saving complete out buffer as it is. + } + else{ + // Reducing buffer to number of processed tokens and each token is of vocab_size + bufsize = n_processed * m_vocab_size * output_bw; // processed_token * vocab_size * output_bitwidth + output_buffer += offset * m_vocab_size * output_bw; // shift output buffer to offset * vocab_size * output_bitwidth + } + + std::string fname = fmt::format("{}/{}/{:03d}", _debug_path, outTensorName, m_inference_count); + QnnUtils::writeRawData(output_buffer, bufsize, fname); + return true; + +} + +int32_t QnnNspModel::selectVariantStrategy(int32_t n_inputs, int32_t n_past, int32_t cur_variant) { + int32_t best_variant = cur_variant; + int32_t best_cost = INT32_MAX; + int32_t switch_cost = 10; // Currently hard-coded to 10ms + + for (auto [variant, latency] : variant_latency) { + // If variant cannot support the n_past, it is a non-starter + // e.g. AR-128 with ctx_size=1024 can only support upto n_past=896 since it uses 128 output + if (n_past + n_inputs > m_ctx_size) continue; + + const int32_t n_iters = 1 + ((n_inputs - 1) / variant); + const int32_t cost = latency * n_iters + ((variant == cur_variant) ? 0 : switch_cost); + if (cost < best_cost) { + best_variant = variant; + best_cost = cost; + } + } + + __DEBUG("qnn-htp : Variant selected AR={} (~ {} ms)", best_variant, best_cost); + return best_variant; +} + +size_t QnnNspModel::runInference( + const std::vector& in_tokens, + const std::vector& attention_map, + std::vector& output, + bool output_all +) { + qualla::Timer start; + + __TRACE("runInference logits_all={} in_tokens={}", output_all, in_tokens); + + if(m_inputType != InputType::TOKENS) { + throw std::runtime_error("Wrong Type of input is supplied for token type query."); + } + + if (in_tokens.size() == 0) return 0; + + // Select variant based on variant_latency, or default to current variant + std::vector tokens(in_tokens); + if (!variant_latency.empty() && !m_disableKvCache) { + const int32_t cur_variant = _kv_dispatcher->getCurVariant(); + const int32_t new_variant = selectVariantStrategy(tokens.size(), m_nPast, cur_variant); + if (cur_variant != new_variant) // Switch variant if necessary + _kv_update_count = _kv_dispatcher->dispatch(new_variant, m_nPast); + } + + // If variant selected in BERT-Mode, append token history to current request + int32_t variant = 0; + if(!m_disableKvCache) + variant = _kv_dispatcher->getCurVariant(); + else + variant = nsp_graph_count.rbegin()->first; // pick largest variant + if (variant == m_ctx_size && m_nPast != 0) + tokens.insert(tokens.begin(), token_history.begin(), token_history.end()); + + const int32_t n_inputs = static_cast(tokens.size()); + const int32_t n_past = static_cast(m_nPast); + const int32_t n_valid = n_past + n_inputs; + run_info.n_tokens = variant; + if (variant != m_ctx_size && m_nPast + variant > m_ctx_size) { + __ERROR("qnn-htp: exceeding ctx_size! 
: {} + {} > {}", m_nPast, variant, m_ctx_size); + return 0; + } + + // Calculate number of batches for run-inference + const int32_t num_iters = 1 + ((n_inputs - 1) / variant); + __DEBUG("qnn-htp: run-inference : {} tokens (AR-{} * {} iters)", n_inputs, variant, num_iters); + + // Validate attention_map size + if (!attention_map.empty() && attention_map.size() != n_inputs && + attention_map.size() != n_inputs * (n_past + n_inputs)) { + // clang-format off + __ERROR("qnn-htp: attention_map must be 1D(n_inputs) or 2D(n_inputs * (n_past + n_inputs))" + "but has size={} for n_past={} n_inputs={}", attention_map.size(), n_past, n_inputs); + // clang-format on + return 0; + } + std::vector chunked_attn_map; + + // Technical note: int32_t can hold upto 596 hours + // Even int16_t should be sufficient here - it holds upto 32.8 seconds + int32_t total_wait = 0; + int32_t total_exec = 0; + + // user choice overwrites the default behaviour in case of Embedding models + if(m_modelArchitectureType == ModelArchitectureType::ENCODER) + output_all = !m_pooled_output; + + // Reset logit accumulator + size_t output_count = output_all ? n_inputs : 1; // actual number of logits + + if(m_modelArchitectureType == ModelArchitectureType::ENCODER) + output.resize(output_count * m_embd_size); + else + output.resize(output_count * m_vocab_size); + + for (int i = 0; i < num_iters; i++) { + const int32_t update_size = std::min(variant, n_inputs - i * variant); + run_info.n_processed = update_size; + run_info.tokens.assign(&tokens[i * variant], &tokens[i * variant + update_size]); + + int32_t n_skip_prefix = + (i * variant < _offset_to_apply_kv_prefix) ? _size_to_skip_kv_prefix : 0; + int32_t n_apply_prefix_offset = 0; + if (i * variant < _offset_to_apply_kv_prefix) + n_apply_prefix_offset = std::min(variant, _offset_to_apply_kv_prefix - i * variant); + + // Chunk inputs and attention mask + std::span tokens_chunk = std::span{tokens.data(),tokens.size()}.subspan(i * variant, update_size); + std::span attn_map_chunk = std::span(); + if (attention_map.size() == n_inputs) { + chunked_attn_map.resize(update_size); + // Take exactly update_size elements. Be mindful to decrease offset already processed + for (int j = 0; j < update_size; j++) + chunked_attn_map[j] = attention_map[i * variant + j] - (i * variant); + attn_map_chunk = std::span{chunked_attn_map.data(),chunked_attn_map.size()}; + } else if (attention_map.size() == n_inputs * (n_past + n_inputs)) { + chunked_attn_map.clear(); + chunked_attn_map.resize(update_size * (m_nPast + update_size)); + + for (int j = 0; j < update_size; j++) { + // Be mindful. m_nPast changes each iteration. + // n_tokens is total #tokens called. update_size is the n_tokens for this iteration + // n_past is the initial m_nPast. n_valid = n_past + n_tokens + std::memcpy( + &chunked_attn_map[j * (m_nPast + update_size)], + &attention_map[i * variant * n_valid + j * n_valid], + (m_nPast + update_size) * sizeof(int32_t) + ); + } + attn_map_chunk = std::span{chunked_attn_map.data(),chunked_attn_map.size()}; + } + + if (!setupInputTensors( + tokens_chunk, + (variant == m_ctx_size) ? 
0 : m_nPast, + attn_map_chunk, + n_skip_prefix, + n_apply_prefix_offset + )) + return 0; + + // Run Inference and pipeline KV$ update iff n_inputs is exactly 1 or we have more batches + bool pipeline = (n_inputs == 1 || i < num_iters - 1); + if (!runInferenceHelper(pipeline, &total_wait, &total_exec)) return 0; + + if (m_modelArchitectureType != ModelArchitectureType::ENCODER && output_all) { + // Accumulate logits + const size_t logit_offset = i * variant * m_vocab_size; + const size_t logit_count = update_size * m_vocab_size; + getDequantLogits(std::span{output.data(), output.size()}.subspan(logit_offset, logit_count), + output_all); + } + } + + // Return last logit if not accumulating + if(m_modelArchitectureType != ModelArchitectureType::ENCODER) { + if(!output_all) + getDequantLogits(std::span{output.data(), output.size()}, output_all); + } + else + getEmbeddings(std::span{output.data(), output.size()}); + + __DEBUG("qnn-htp: run-inference complete : {} usec : wait {} exec {}", + start.elapsed_usec(), + total_wait, + total_exec); + + // threadpool.suspend(); + return output_count; +} + +bool QnnNspModel::quantizeInput(float* in, size_t tensorOffset ,size_t length) { + + if(t_input_ids == nullptr) { + __ERROR("Input Tensor {} not found during execute", m_layerNames[LayerType::INPUT]); + return false; + } + + const auto scale = t_input_ids->quantParam[0].scale; + const auto offset = t_input_ids->quantParam[0].offset; + + // clang-format off + switch (t_input_ids->dtype) { + case QNN_DATATYPE_UFIXED_POINT_8: QnnUtils::quantizeTensorPtr(in, (uint8_t*)getBuffer(t_input_ids) + tensorOffset, offset, scale, length); break; + case QNN_DATATYPE_UFIXED_POINT_16: QnnUtils::quantizeTensorPtr(in, (uint16_t*)getBuffer(t_input_ids) + tensorOffset, offset, scale, length); break; + default: __ERROR("Unsupported alpha tensor dtype {}", t_input_ids->dtype.str()); return false; + } + + return true; +} + +size_t QnnNspModel::getEmbeddingBufferSize() { + return m_embeddingBufferSize; +} + +size_t QnnNspModel::runInference( + std::vector& embedding, + const std::vector& attention_map, + std::vector& output, + bool output_all +) { + qualla::Timer start; + + __DEBUG("qnn-htp: run-inference start : n_Embd {}", embedding.size()); + + if(m_inputType != InputType::EMBEDDINGS) { + throw std::runtime_error("Embedding input type is not supported by the model."); + } + + if (embedding.size() == 0) return true; + + size_t embedBufSize = m_embeddingBufferSize; + // Select variant based on variant_latency, or default to current variant + int32_t curTokenCount = embedding.size() / embedBufSize; + if (!variant_latency.empty() && !m_disableKvCache) { + const int32_t cur_variant = _kv_dispatcher->getCurVariant(); + const int32_t new_variant = selectVariantStrategy(curTokenCount, m_nPast, cur_variant); + if (cur_variant != new_variant) // Switch variant if necessary + _kv_update_count = _kv_dispatcher->dispatch(new_variant, m_nPast); + } + + // If variant selected in BERT-Mode, append token history to current request + const int32_t variant = _kv_dispatcher->getCurVariant(); + + // We will never be maintaining history for the embedding + + const int32_t n_inputs = static_cast(curTokenCount); + const int32_t n_past = static_cast(m_nPast); + const int32_t n_valid = n_past + n_inputs; + run_info.n_tokens = variant; + + if (variant != m_ctx_size && m_nPast + variant > m_ctx_size) { + __ERROR("qnn-htp: exceeding ctx_size! 
: {} + {} > {}", m_nPast, variant, m_ctx_size); + return 0; + } + + const int32_t num_iters = 1 + ((n_inputs - 1) / variant); + __DEBUG("qnn-htp: run-inference : {} tokens (AR-{} * {} iters)", + n_inputs, + variant, + num_iters); + + // Validate attention_map size + if (!attention_map.empty() && attention_map.size() != n_inputs && + attention_map.size() != n_inputs * (n_past + n_inputs)) { + // clang-format off + __ERROR("qnn-htp: attention_map must be 1D(n_inputs) or 2D(n_inputs * (n_past + n_inputs))" + "but has size={} for n_past={} n_inputs={}", attention_map.size(), n_past, n_inputs); + // clang-format on + return 0; + } + std::vector chunked_attn_map; + + // Technical note: int32_t can hold upto 596 hours + // Even int16_t should be sufficient here - it holds upto 32.8 seconds + int32_t total_wait = 0; + int32_t total_exec = 0; + + // Reset logit accumulator + size_t output_count = output_all ? n_inputs : 1; // actual number of logits + + output.resize(output_count * m_vocab_size); + + for (int i = 0; i < num_iters; i++) { + const int32_t update_size = std::min(variant, n_inputs - i * variant); + run_info.n_processed = update_size; + const int32_t startIdx = i * variant * embedBufSize; + + int32_t n_skip_prefix = + (i * variant < _offset_to_apply_kv_prefix) ? _size_to_skip_kv_prefix : 0; + int32_t n_apply_prefix_offset = 0; + if (i * variant < _offset_to_apply_kv_prefix) + n_apply_prefix_offset = std::min(variant, _offset_to_apply_kv_prefix - i * variant); + + // Chunk inputs and attention mask + std::span embedding_chunk = std::span{embedding.data(),embedding.size()}.subspan(startIdx, update_size*embedBufSize); + std::span attn_map_chunk = std::span(); + if (attention_map.size() == n_inputs) { + chunked_attn_map.resize(update_size); + // Take exactly update_size elements. Be mindful to decrease offset already processed + for (int j = 0; j < update_size; j++) + chunked_attn_map[j] = attention_map[i * variant + j] - (i * variant); + attn_map_chunk = std::span{chunked_attn_map.data(),chunked_attn_map.size()}; + } else if (attention_map.size() == n_inputs * (n_past + n_inputs)) { + chunked_attn_map.clear(); + chunked_attn_map.resize(update_size * (m_nPast + update_size)); + + for (int j = 0; j < update_size; j++) { + // Be mindful. m_nPast changes each iteration. + // n_tokens is total #tokens called. update_size is the n_tokens for this iteration + // n_past is the initial m_nPast. n_valid = n_past + n_tokens + std::memcpy( + &chunked_attn_map[j * (m_nPast + update_size)], + &attention_map[i * variant * n_valid + j * n_valid], + (m_nPast + update_size) * sizeof(int32_t) + ); + } + attn_map_chunk = std::span{chunked_attn_map.data(),chunked_attn_map.size()}; + } + + if (!setupInputTensors( + embedding_chunk, + (variant == m_ctx_size) ? 
0 : m_nPast, + attn_map_chunk, + n_skip_prefix, + n_apply_prefix_offset + )) + return 0; + + // Run Inference and pipeline KV$ update iff n_inputs is exactly 1 or we have more batches + bool pipeline = (n_inputs == 1 || i < num_iters - 1); + if (!runInferenceHelper(pipeline, &total_wait, &total_exec)) return 0; + + if (output_all) { + // Accumulate logits + const size_t logit_offset = i * variant * m_vocab_size; + const size_t logit_count = update_size * m_vocab_size; + getDequantLogits(std::span{output.data(), output.size()}.subspan(logit_offset, logit_count), + output_all); + } + } + + // Return last logit if not accumulating + if(!output_all) + getDequantLogits(std::span{output.data(), output.size()}, output_all); + + __DEBUG("qnn-htp: run-inference complete : {} usec : wait {} exec {}", + start.elapsed_usec(), + total_wait, + total_exec); + + return output_count; +} + +bool QnnNspModel::cacheEosEmbedding(std::vector& eosEmbedding) { + m_eosEmbedding = eosEmbedding; + return true; +} + +bool QnnNspModel::setKVCacheNPast(size_t n_past, const std::vector& selected) { + __TRACE("setKVCacheNPast (m_nPast={} -> n_past={})", m_nPast, n_past); + if (n_past == m_nPast && n_past != 0) return true; + + if (m_nPast + run_info.n_processed < n_past) { + __ERROR("qnn-htp: set-kv n_past update larger than number of processed tokens : n_past {} n_proc {}", + n_past, + m_nPast + run_info.n_processed); + return false; + } + + if (m_inputType == InputType::TOKENS) { + if (n_past == 0) { + int32_t new_variant = nsp_graph_count.rbegin()->first; + _kv_update_count = _kv_dispatcher->dispatch(new_variant, 0, selected); + token_history.clear(); + + } else if (n_past < m_nPast) { + auto [variant, update_size, tokens] = run_info; + _kv_update_count = _kv_dispatcher->dispatch(variant, n_past); + token_history.resize(n_past); + } else { + int32_t new_variant = nsp_graph_count.begin()->first; + _kv_update_count = _kv_dispatcher->dispatch(new_variant, n_past, selected); + + auto [variant, update_size, tokens] = run_info; + + if (variant == m_ctx_size) { + token_history.assign(&tokens[0], &tokens[n_past]); + } else if (selected.empty()) { + token_history.insert(token_history.end(), &tokens[0], &tokens[n_past - m_nPast]); + } else { + for (auto i = 0; i < tokens.size(); ++i) { + if (selected[i]) token_history.push_back(tokens[i]); + } + } + } + } + else if (m_inputType == InputType::EMBEDDINGS) { // Don't add embedding history, It is costly maintenance to do. 
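+    // Same three cases as the TOKENS path above, minus the history bookkeeping:
+    //   n_past == 0       -> full reset, dispatch on the largest variant
+    //   n_past <  m_nPast -> rewind the cache to n_past on the current variant
+    //   n_past >  m_nPast -> commit newly processed entries (optionally filtered
+    //                        by `selected`) on the smallest variant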
+    if (n_past == 0) {
+      int32_t new_variant = nsp_graph_count.rbegin()->first;
+      _kv_update_count = _kv_dispatcher->dispatch(new_variant, 0, selected);
+    } else if (n_past < m_nPast) {
+      auto [variant, update_size, tokens] = run_info;
+      _kv_update_count = _kv_dispatcher->dispatch(variant, n_past);
+    } else {
+      int32_t new_variant = nsp_graph_count.begin()->first;
+      _kv_update_count = _kv_dispatcher->dispatch(new_variant, n_past, selected);
+    }
+  }
+  else
+  {
+    __ERROR("Unknown input type.");
+    return false;
+  }
+
+  m_nPast = n_past;
+  return true;
+}
+
+template <typename T = int32_t, typename U>
+inline void deQuantizeOutputs(
+    U* inputs,
+    std::span<float>& outputs,
+    const double scale,
+    const int32_t offset,
+    const int count
+) {
+#pragma clang loop vectorize(enable) interleave(enable)
+  for (int i = 0; i < count; ++i)
+    outputs[i] = ((T)inputs[i] + offset) * scale;
+}
+
+// bitWidth is in bytes: 2 -> fp16 inputs, 4 -> fp32 passthrough
+template <typename U>
+inline void castOutputs(U* inputs, std::span<float>& outputs, const int numElements, const int bitWidth) {
+  if (bitWidth == 2) {
+#pragma clang loop vectorize(enable) interleave(enable)
+    for (int i = 0; i < numElements; ++i)
+      outputs[i] = fp16_ieee_to_fp32_value(inputs[i]);
+  }
+  else if (bitWidth == 4) {
+#pragma clang loop vectorize(enable) interleave(enable)
+    for (size_t i = 0; i < numElements; i++) {
+      outputs[i] = inputs[i];
+    }
+  }
+}
+
+size_t QnnNspModel::getDequantLogits(std::span<float> dequant_logits, bool logits_all) {
+  qualla::Timer start;
+
+  QnnUtils::Tensor* const logit_spec =
+      m_nsp_graphs.back().variants[run_info.n_tokens]->getOutput(m_layerNames[LayerType::OUTPUT]);
+  const int return_size = logits_all ? run_info.n_processed : 1;
+  const auto [scale, offset] = logit_spec->quantParam[0];
+
+  auto d_logits = QnnUtils::DataType(logit_spec->tensor);
+
+  int logit_bw = logit_spec->dtype.bw();
+
+  uint8_t* logit_buffer = (uint8_t*)getBuffer(logit_spec);
+  if (logit_spec->dims.getNumElements() == m_vocab_size) {
+    // BERT Mode graph may return only the last logit
+    // If only one logit is returned, simply return the last logit
+    if (return_size > 1)
+      throw std::runtime_error("Requested all logits, but graph only produces one logit");
+  } else {
+    // If multiple logits are returned, offset to the correct location in the buffer
+    if (run_info.n_tokens == m_ctx_size) {
+      // This was left-padded, logits are at [n_tokens - n_processed, n_tokens]
+      logit_buffer += (run_info.n_tokens - return_size) * m_vocab_size * d_logits.bw();
+    } else if (logits_all == false) {
+      // This was right-padded, logits are at indexes [0, n_processed]
+      logit_buffer += (run_info.n_processed - 1) * m_vocab_size * d_logits.bw();
+    }
+  }
+  const int n_logits = static_cast<int>(m_vocab_size * return_size);
+  __TRACE("qnn-htp: get-logits logits_all={} for {} tokens.
Returning {}*{}", + logits_all, + run_info.n_processed, + return_size, + m_vocab_size); + + switch (d_logits) { + case QNN_DATATYPE_UFIXED_POINT_8: + deQuantizeOutputs((uint8_t*)logit_buffer, dequant_logits, scale, offset, n_logits); + break; + case QNN_DATATYPE_UFIXED_POINT_16: + deQuantizeOutputs((uint16_t*)logit_buffer, dequant_logits, scale, offset, n_logits); + break; + case QNN_DATATYPE_FLOAT_16: { + castOutputs((uint16_t*)logit_buffer, dequant_logits, n_logits, logit_bw); + break; + } + default: + __ERROR("Unsupported logits dtype {}", d_logits.str()); + } + + __DEBUG("qnn-htp: getDequantLogits complete : {} usec (return_size={})", + start.elapsed_usec(), + return_size); + return return_size; +} + +bool QnnNspModel::calculate_rope_embeddings(void) { + if (m_positional_encoding.type != PositionalEncoding::ROPE) return true; + + const size_t nmemb = m_ctx_size * m_pos_dim; + const int pos_bw = d_pos.bw(); + + rope_sin = malloc(nmemb * pos_bw); + rope_cos = malloc(nmemb * pos_bw); + + auto [q_scale, q_offset] = t_position_ids_cos->quantParam[0]; + if (d_pos == QNN_DATATYPE_FLOAT_16) { // If floating point, don't quantize! + q_scale = 1.0; + q_offset = 0; + } + + // Calculate inv_freq array + std::vector inv_freq(m_pos_dim); + const double exponent = 1.0 / static_cast(m_pos_dim); + for (int j = 0; j < m_pos_dim; j++) + inv_freq[j] = 1.0 / pow(rope_theta, j * exponent); + double attention_factor = 1.0; + if (rope_scaling.rope_type == RopeScalingParams::ROPE_LLAMA3) { + // Implemented from HuggingFace + // https://github.com/huggingface/transformers/blob/47c29ccfaf56947d845971a439cbe75a764b63d7/src/transformers/modeling_rope_utils.py#L298 + const double& factor = rope_scaling.llama3_params.factor; + const double& low_freq_factor = rope_scaling.llama3_params.low_freq_factor; + const double& high_freq_factor = rope_scaling.llama3_params.high_freq_factor; + const int& old_context_len = rope_scaling.llama3_params.original_max_position_embeddings; + + const double low_freq_wavelen = old_context_len / low_freq_factor; + const double high_freq_wavelen = old_context_len / high_freq_factor; + + for (int j = 0; j < m_pos_dim; j++) { + const double wavelen = 2 * M_PI / inv_freq[j]; + if (wavelen < high_freq_wavelen) // wavelen < high_freq_wavelen: do nothing + continue; + else if (wavelen > low_freq_wavelen) // wavelen > low_freq_wavelen: divide by factor + inv_freq[j] = 1.0 / static_cast(factor * pow(rope_theta, j * exponent)); + else { // otherwise: interpolate between the two, using a smooth factor + assert(low_freq_wavelen != high_freq_wavelen); + const double smooth = + (static_cast(old_context_len) / wavelen - low_freq_factor) / + (high_freq_factor - low_freq_factor); + inv_freq[j] = ((1 - smooth) * inv_freq[j] / factor + smooth * inv_freq[j]); + } + } + } else if (rope_scaling.rope_type == RopeScalingParams::ROPE_LONGROPE) { + // Validate factor >= 1.0, len(long_factor) == rope-dim and len(short_factor) == rope-dim + const double& factor = rope_scaling.longrope_params.factor; + const int& old_context_len = rope_scaling.longrope_params.original_max_position_embeddings; + + const auto& inv_factors = (m_ctx_size > old_context_len) + ? 
rope_scaling.longrope_params.long_factor + : rope_scaling.longrope_params.short_factor; + + if (inv_factors.size() != m_pos_dim) + throw std::runtime_error(fmt::format( + "long-factor (len={}) and short-factor (len={}) must have length rope-dim={}", + rope_scaling.longrope_params.long_factor.size(), + rope_scaling.longrope_params.short_factor.size(), + m_pos_dim + )); + + for (int j = 0; j < m_pos_dim; j++) + inv_freq[j] = inv_freq[j] / inv_factors[j]; + + attention_factor = + std::sqrt(1.0 + std::log(factor) / std::log(static_cast(old_context_len))); + } + for (int i = 0; i < m_ctx_size; i++) { + for (int j = 0; j < m_pos_dim; j++) { + const double freq = i * inv_freq[j]; + + const double sin_val = ((sin(freq) * attention_factor) / q_scale) - q_offset; + const double cos_val = ((cos(freq) * attention_factor) / q_scale) - q_offset; + + // round() instead of floor() seems to produce an acuracy drop. To debug later + switch (d_pos) { + case QNN_DATATYPE_UFIXED_POINT_8: + ((uint8_t*)rope_sin)[i * m_pos_dim + j] = static_cast(sin_val); + ((uint8_t*)rope_cos)[i * m_pos_dim + j] = static_cast(cos_val); + break; + case QNN_DATATYPE_UFIXED_POINT_16: + ((uint16_t*)rope_sin)[i * m_pos_dim + j] = static_cast(sin_val); + ((uint16_t*)rope_cos)[i * m_pos_dim + j] = static_cast(cos_val); + break; + case QNN_DATATYPE_FLOAT_16: + ((uint16_t *)rope_sin)[i * m_pos_dim + j] = fp16_ieee_from_fp32_value(sin_val); + ((uint16_t*)rope_cos)[i * m_pos_dim + j] = fp16_ieee_from_fp32_value(cos_val); + break; + default: + __ERROR("Unsupported position ids datatype {}", d_pos.str()); + return false; + } + } + } + + if (_debug_tensors) { + std::string dtype = + fmt::format("{}", (d_pos == QNN_DATATYPE_FLOAT_16) ? "f" : "u", pos_bw * 8); + std::string fname_sin = fmt::format("{}/position_ids_sin.{}.dat", _debug_path, pos_bw * 8); + std::string fname_cos = fmt::format("{}/position_ids_cos.{}.dat", _debug_path, pos_bw * 8); + QnnUtils::writeRawData(rope_sin, nmemb * pos_bw, fname_sin); + QnnUtils::writeRawData(rope_cos, nmemb * pos_bw, fname_cos); + } + + return true; +} + +bool QnnNspModel::load_lmhead_weight_as_input(void) { + if (!_lmhead_weight_input) return true; + if (_lmhead_weight_input && lmhead_weight_dir.empty()) { + __ERROR("NSPModel: LMhead weight file not found"); + return false; + } + for (auto& variant : m_variant_list) { + for (auto& [tname, tspec] : variant.input_specs) { + if (tname.compare("weight") == 0) { + // weight tensor file name should be in same format as tensor name present in graph + std::string weight_file = + (model_basedir / fs::path(lmhead_weight_dir) / fs::path(tname + ".raw")) + .string(); + + QnnUtils::Dims dims = tspec.dims; + size_t numElements = dims.getNumElements(); + + size_t size = sizeof(float); + std::vector weight_f32; // Temporary variable to load fp32 values + weight_f32.reserve(numElements); + + FILE* fp = fopen(weight_file.c_str(), "r"); + if (fp == NULL) { + __ERROR("NSPModel: Error opening file: {}", weight_file); + return false; + } + + size_t count = fread(weight_f32.data(), size, numElements, fp); + fclose(fp); + + if (count != numElements) { + __ERROR("NSPModel: Could not load {} - expected file size {}", + weight_file, + numElements * size); + return false; + } + + int8_t* weight_buffer = (int8_t*)getBuffer(tspec); + // Quantize the values, per width quantization + QnnUtils::perWidthQuantizeTensorPtr( + weight_f32.data(), + weight_buffer, + tspec.quantParam, + dims.height, + dims.width, + dims.channel + ); + } + } + } + return true; +} + +bool 
QnnNspModel::flushLoraWeightsBuffers(void){ + if(!_lora_enabled){ + __ERROR("qnn-htp: Model does not support LoRA weights."); + return false; + } + + for (auto& variant : m_variant_list) { + for (auto& [tname, tspec] : variant.input_specs) { + if (tname.find("lora") != std::string::npos) { // find lora weights tensors and flush them out + if(getBuffer(tspec) == nullptr) + return false; + size_t numElements = tspec.dims.getNumElements(); + auto offset = tspec.quantParam[0].offset; + // Since values needs to be quantized so zero is going to get translated. + // clang-format off + switch (tspec.dtype) { + case QNN_DATATYPE_UFIXED_POINT_8: std::fill_n((uint8_t*)getBuffer(tspec), numElements, static_cast(-offset)); break; + case QNN_DATATYPE_UFIXED_POINT_16: std::fill_n((uint16_t*)getBuffer(tspec), numElements, static_cast(-offset)); break; + case QNN_DATATYPE_FLOAT_16:{ + uint16_t *buffer = (uint16_t *)getBuffer(tspec); + for(int i=0;i lora_weights_f32; // Temporary variable to load fp32 values + lora_weights_f32.reserve(numElements); + + FILE* fp = fopen(lora_weights_file.c_str(), "r"); + if (fp == NULL) { + __ERROR("NSPModel: Error opening file: {}", lora_weights_file); + return false; + } + + size_t count = fread(lora_weights_f32.data(), size, numElements, fp); + fclose(fp); + + if (count != numElements) { + __ERROR("NSPModel: Could not load {} - expected file size {}", + lora_weights_file, + numElements * size); + return false; + } + + // Quantize the values + // clang-format off + switch (tspec.dtype) { + case QNN_DATATYPE_UFIXED_POINT_8: QnnUtils::quantizeTensorPtr(lora_weights_f32.data(), (uint8_t*)getBuffer(tspec), offset, scale, numElements); break; + case QNN_DATATYPE_UFIXED_POINT_16: QnnUtils::quantizeTensorPtr(lora_weights_f32.data(), (uint16_t*)getBuffer(tspec), offset, scale, numElements); break; + case QNN_DATATYPE_FLOAT_16: float32ToFloat16((uint8_t *)getBuffer(tspec), lora_weights_f32.data(), numElements); break; + default: __ERROR("Unsupported {} datatype for {} tensor", tspec.dtype.str(), tname); return false; + } + } + } + } + return true; +} + +void QnnNspModel::dumpTensorSpecs() { + static const char* stringFmt = + "\t\t{ \"name\": \"%s\", \"dims\": [1, %d, %d, %d], " + "\"bitwidth\": %d, \"dtype\": \"%s\", \"scale\": [%s], \"offset\": [%s] },\n"; + for (GraphVariant& variant : m_variant_list) { + GraphInfo_t* graph_info = variant.graph_info; + + // Create output spec file and open it + std::string filename = fmt::format("{}/spec.{}.json", _debug_path, graph_info->graphName); + + FILE* specFile = fopen(filename.c_str(), "w"); + if (specFile == NULL) throw std::runtime_error("Error opening file : " + filename); + + fprintf(specFile, "{\n\t\"graph_name\" : \"%s\",\n", variant.graph_name.c_str()); + for (bool io : {true, false}) { + uint32_t n_tensors = (io) ? graph_info->numInputTensors : graph_info->numOutputTensors; + Qnn_Tensor_t* tensor = (io) ? graph_info->inputTensors : graph_info->outputTensors; + QnnUtils::TensorMap& tspecs = (io) ? variant.input_specs : variant.output_specs; + + fprintf(specFile, (io) ? 
"\t\"inputs\" : [\n" : "\t\"outputs\" : [\n"); + while (n_tensors-- > 0) { + std::string tname = QnnApi::getTensorName(*tensor); + auto& [_, dims, quant_params, dtype] = tspecs.at(tname); + auto& [__, h, w, c, bw] = dims; + std::string scales; + std::string offsets; + QnnUtils::getQuantParamString(quant_params, scales, offsets); + // clang-format off + fprintf(specFile, stringFmt, tname.c_str(), h, w, c, bw, dtype.str(), scales.c_str(), offsets.c_str()); + // clang-format on + tensor++; + } + fseek(specFile, -2, SEEK_CUR); // Remove trailing comma + fprintf(specFile, "\n\t],\n"); + } + fseek(specFile, -2, SEEK_CUR); // Remove trailing comma + fprintf(specFile, "\n}"); + fclose(specFile); + } +} + +size_t QnnNspModel::loadKVCache(const std::string& load_path) { + + if(m_disableKvCache){ + __ERROR("KV cache is disabled, loading KV cache is not allowed"); + return false; + } + + std::ifstream fs(load_path, std::ios::in | std::ios::binary); + if (fs.fail()) { + // TODO: replace with proper error handling + __ERROR("qnn-htp: load-kv errror reading file {}", load_path); + return 0; + } + + CacheFileSpec spec; + fs.read((char*)&spec, sizeof(spec)); + if (spec.magic != 0xC0DE) { + __ERROR("qnn-htp: load-kv expected 0xC0DE found {:#x}", spec.magic); + return 0; + } + + bool dtype_check = true; + // clang-format off + switch (d_kv) { + case QNN_DATATYPE_UFIXED_POINT_8: dtype_check = spec.dtype == CacheFileSpec::UINT8_T; break; + case QNN_DATATYPE_UFIXED_POINT_16: dtype_check = spec.dtype == CacheFileSpec::UINT16_T; break; + case QNN_DATATYPE_FLOAT_16: dtype_check = spec.dtype == CacheFileSpec::FLOAT16_T; break; + default: __ERROR("Unsupported KV$ datatype {}", d_kv.str()); return false; + } + // clang-format on + + if (!dtype_check) { + __ERROR("Model has KV$ Dtype {} but found {} in cache", d_kv.str(), int(spec.dtype)); + return false; + } + + // clang-format off + __DEBUG("qnn-htp: load-kv {{ num_tensors {}, magic {}, dtype {}, n_heads {}, embed_dim {} update_size {} }}", + spec.num_tensors, spec.magic, int(spec.dtype), spec.n_heads, spec.embed_dim, spec.update_size); + // clang-format on + + const int32_t n_valid = static_cast(spec.update_size); + const int32_t variant = nsp_graph_count.begin()->first; // Set KVManager to smallest variant + _kv_dispatcher->setVariant(variant); + + // Lock, load KeyCache then ValueCache, unlock + for (auto& nsp_graph : m_nsp_graphs) + nsp_graph.waitForLock("loadKVCache", _kv_update_count, false); + for (auto& nsp_graph : m_nsp_graphs) + nsp_graph.kvmanager->loadCache(&fs, true, n_valid, variant, spec.n_heads); + for (auto& nsp_graph : m_nsp_graphs) + nsp_graph.kvmanager->loadCache(&fs, false, n_valid, variant, spec.n_heads); + for (auto& nsp_graph : m_nsp_graphs) + nsp_graph.releaseLock("loadKVCache"); + + fs.seekg(spec.num_tensors * sizeof(double), std::ios::cur); + + + + // Loading previous runs history input only applicable in case of tokens. + // Embeddings history maintenance is costly in terms of memory and time. 
+ if(m_inputType == InputType::TOKENS) { + token_history.clear(); + token_history.resize(n_valid); + fs.read((char *) token_history.data(), n_valid * sizeof(int32_t)); + } + else if(m_inputType == InputType::UNKNOWN) { + __ERROR("Wrong type of input is found."); + return false; + } + fs.close(); + + m_nPast = n_valid; + return spec.update_size; +} + +bool QnnNspModel::saveKVCache(const std::string& save_path) { + + if(m_disableKvCache){ + __ERROR("KV cache is disabled, saving KV cache is not allowed"); + return false; + } + + std::ofstream fs(save_path, std::ios::out | std::ios::binary); + if (fs.fail()) { + __ERROR("qnn-htp: save-kv error opening file : {}", save_path); + throw std::runtime_error("Failed to write to cache file. Please re-check path"); + } + + const uint16_t n_valid = static_cast(m_nPast); + + auto dtype = CacheFileSpec::UINT8_T; + // clang-format off + switch (d_kv) { + case QNN_DATATYPE_UFIXED_POINT_8: dtype = CacheFileSpec::UINT8_T; break; + case QNN_DATATYPE_UFIXED_POINT_16: dtype = CacheFileSpec::UINT16_T; break; + case QNN_DATATYPE_FLOAT_16: dtype = CacheFileSpec::FLOAT16_T; break; + default: __ERROR("Unsupported KV$ datatype {}", d_kv.str()); return false; + } + // clang-format on + + // Pre-calculate #tensors and n_heads to guide memory allocations + uint32_t n_tensors = 0; + int32_t n_heads = 0; + for (auto& nsp_graph : m_nsp_graphs) { + nsp_graph.waitForLock("saveKVCache", _kv_update_count, false); + n_tensors += nsp_graph.kvmanager->getNumKVTensors(); + n_heads = std::max(n_heads, nsp_graph.kvmanager->getMaxNHeads()); + } + + // Save the cache file metadata + CacheFileSpec file_spec( + n_tensors, 0xc0de, dtype, 0x0, static_cast(n_heads), m_kv_dim, n_valid + ); + fs.write((char*)&file_spec, sizeof(file_spec)); + + // Dump KeyCache and ValueCache + for (auto& nsp_graph : m_nsp_graphs) + nsp_graph.kvmanager->dumpCache(&fs, true, n_valid, n_heads); + for (auto& nsp_graph : m_nsp_graphs) + nsp_graph.kvmanager->dumpCache(&fs, false, n_valid, n_heads); + + // Dump Quantization parameters - Key scales then Value scales + for (auto& nsp_graph : m_nsp_graphs) { + std::vector& key_scales = nsp_graph.kvmanager->getKeyScales(); + fs.write((char*)key_scales.data(), key_scales.size() * sizeof(double)); + } + for (auto& nsp_graph : m_nsp_graphs) { + std::vector& value_scales = nsp_graph.kvmanager->getValueScales(); + fs.write((char*)value_scales.data(), value_scales.size() * sizeof(double)); + } + + // Saving previous runs history input only applicable in case of tokens. + // Embeddings history maintenance is costly in terms of memory and time. 
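+  // Mirror of loadKVCache: the token history is appended after the KV$ payload
+  // and the per-tensor scales, so save and load must agree on this file layout.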
+ if(m_inputType == InputType::TOKENS) + fs.write((char*)token_history.data(), n_valid * sizeof(int32_t)); + else if(m_inputType == InputType::UNKNOWN) { + __ERROR("Wrong type of input is found."); + return false; + } + + // Release the lock + for (auto& nsp_graph : m_nsp_graphs) + nsp_graph.releaseLock("saveKVCache"); + + fs.flush(); + fs.close(); + + return true; +} + +bool QnnNspModel::applyBinarySections(std::vector& binsection_list) { + //apply binarysection for lora config + for (int i = 0; i < binsection_list.size(); i++) { + __DEBUG("qnn-htp: applyBinarySections adapters {}", binsection_list.at(i)); + if (!m_qnnApi->applyBinarySection(i, binsection_list.at(i),m_use_mmap,graph_switching)) { + __ERROR("qnn-htp: Error in applyBinarySections {}", i); + return false; + } + } + return true; +} + +bool QnnNspModel::applyLoraStrength(const std::string& alpha_tensor_name, const float alpha_val) { + if(alpha_tensor_name.empty()) return true; + for (auto& variant : m_variant_list) { + if (!variant.input_specs.contains(alpha_tensor_name)) continue; + + auto& tspec = variant.input_specs.at(alpha_tensor_name); + auto [scale, offset] = tspec.quantParam[0]; + + // clang-format off + switch (tspec.dtype) { + case QNN_DATATYPE_UFIXED_POINT_8: QnnUtils::quantizeTensorPtr(&alpha_val, (uint8_t*)getBuffer(tspec), offset, scale, 1); break; + case QNN_DATATYPE_UFIXED_POINT_16: QnnUtils::quantizeTensorPtr(&alpha_val, (uint16_t*)getBuffer(tspec), offset, scale, 1); break; + case QNN_DATATYPE_FLOAT_16: *(uint16_t *)getBuffer(tspec) = fp16_ieee_from_fp32_value(alpha_val); break; + default: __ERROR("Unsupported alpha tensor dtype {}", tspec.dtype.str()); return false; + } + // clang-format on + __DEBUG("qnn-htp: applyAlphaTensor alpha = {}", alpha_val); + return true; // Each lora bin section should have only one alpha tensor + } + return false; +} + +bool QnnNspModel::applyLoraAdapter(const std::string& lora_adapter_name) { + if (lora_conf != LoraConfigType::LORA_ADAPTER_WEIGHT_ENABLE) { + __ERROR("qnn-htp: Lora config is not enable for adapters"); + return false; + } + + if (!lora_config.contains(lora_adapter_name)) { + __ERROR("qnn-htp: Could not find lora adapters config to apply "); + return false; + } + + if (!applyLoraStrength( + lora_config[lora_adapter_name].alpha_tensor_name, + lora_config[lora_adapter_name].alpha_tensor_val + )) { + __ERROR("qnn-htp: Could not apply Alpha tensor "); + return false; + } + + if (!applyBinarySections(lora_config[lora_adapter_name].binsection_list)) { + __ERROR("qnn-htp: Could not apply binary Sections "); + return false; + } + + for (auto& g : m_nsp_graphs) { + for (auto& [n, variant] : g.variants) { + variant->refreshTensorQuantParams(); + } + } + + return true; +} + +size_t QnnNspModel::getEmbeddings(std::span embds) { + qualla::Timer start; + + QnnUtils::Tensor* output_spec = nullptr; + + if(m_pooled_output) + output_spec = m_nsp_graphs.back().variants[run_info.n_tokens]->getOutput(m_layerNames[LayerType::POOL_OUTPUT]); + else + output_spec = m_nsp_graphs.back().variants[run_info.n_tokens]->getOutput(m_layerNames[LayerType::SEQ_OUTPUT]); + + if(output_spec == nullptr) { + __ERROR("encountered null buffer"); + throw std::runtime_error("Model is not supporting per token embedding"); + } + const auto scale = output_spec->quantParam[0].scale; + const auto offset = output_spec->quantParam[0].offset; + + + auto output_datatype = QnnUtils::DataType(output_spec->tensor); + + int output_bw = output_spec->dtype.bw(); + + uint8_t* output_buffer = 
+
+size_t QnnNspModel::getEmbeddings(std::span<float> embds) {
+  qualla::Timer start;
+
+  QnnUtils::Tensor* output_spec = nullptr;
+
+  if (m_pooled_output)
+    output_spec = m_nsp_graphs.back().variants[run_info.n_tokens]->getOutput(m_layerNames[LayerType::POOL_OUTPUT]);
+  else
+    output_spec = m_nsp_graphs.back().variants[run_info.n_tokens]->getOutput(m_layerNames[LayerType::SEQ_OUTPUT]);
+
+  if (output_spec == nullptr) {
+    __ERROR("encountered null buffer");
+    throw std::runtime_error("Model does not support per-token embeddings");
+  }
+  const auto scale = output_spec->quantParam[0].scale;
+  const auto offset = output_spec->quantParam[0].offset;
+
+  auto output_datatype = QnnUtils::DataType(output_spec->tensor);
+
+  int output_bw = output_spec->dtype.bw();
+
+  uint8_t* output_buffer = (uint8_t*)getBuffer(output_spec);
+
+  const int return_size = m_pooled_output ? 1 : run_info.n_processed;
+
+  if (!m_pooled_output) {
+    // If multiple token embeddings are returned, offset to the correct location in the buffer
+    if (run_info.n_tokens == m_ctx_size) {
+      // This was left-padded, token embeddings are at [n_tokens - n_processed, n_tokens]
+      output_buffer += (run_info.n_tokens - return_size) * m_embd_size * output_bw;
+    } else {
+      // This was right-padded, token embeddings are at indexes [0, n_processed]
+      output_buffer += (run_info.n_processed - 1) * m_embd_size * output_bw;
+    }
+  }
+
+  const int output_len = static_cast<int>(return_size * m_embd_size);
+  __TRACE("qnn-htp: get-embds for {} tokens. scale = {}, offset = {}, Returning {}",
+          run_info.n_processed,
+          scale,
+          offset,
+          output_len);
+
+  switch (output_datatype) {
+    case QNN_DATATYPE_UFIXED_POINT_8:
+      deQuantizeOutputs((uint8_t*)output_buffer, embds, scale, offset, output_len);
+      break;
+    case QNN_DATATYPE_UFIXED_POINT_16:
+      deQuantizeOutputs((uint16_t*)output_buffer, embds, scale, offset, output_len);
+      break;
+    case QNN_DATATYPE_FLOAT_16:
+      castOutputs((uint16_t*)output_buffer, embds, output_len, output_bw);
+      break;
+    case QNN_DATATYPE_FLOAT_32:
+      castOutputs((float*)output_buffer, embds, output_len, output_bw);
+      break;
+    default:
+      __ERROR("Unsupported output datatype");
+  }
+
+  __DEBUG("qnn-htp: getEmbeddings complete : {} usec (return_size={})",
+          start.elapsed_usec(),
+          output_len);
+  return output_len;
+}
+
+// Utility functions to convert structs from/to json for parsing/dumping
+void from_json(const json& j, RopeScalingParams& p) {
+  p.rope_type = Config::optional(j, "rope-type", RopeScalingParams::DEFAULT);
+  if (p.rope_type == RopeScalingParams::ROPE_LLAMA3) {
+    try {
+      j.at("factor").get_to(p.llama3_params.factor);
+      j.at("low-freq-factor").get_to(p.llama3_params.low_freq_factor);
+      j.at("high-freq-factor").get_to(p.llama3_params.high_freq_factor);
+      j.at("original-max-position-embeddings")
+          .get_to(p.llama3_params.original_max_position_embeddings);
+    } catch (const json::exception& e) {
+      // clang-format off
+      throw std::runtime_error(fmt::format("Parsing error for llama3 rope scaling - {}\n"
+          "llama3 requires keys ['original-max-position-embeddings', 'factor', 'low-freq-factor', 'high-freq-factor'].\n"
+          "Found config - {}", e.what(), j.dump()));
+      // clang-format on
+    }
+  } else if (p.rope_type == RopeScalingParams::ROPE_LONGROPE) {
+    try {
+      j.at("original-max-position-embeddings")
+          .get_to(p.longrope_params.original_max_position_embeddings);
+      j.at("long-factor").get_to(p.longrope_params.long_factor);
+      j.at("short-factor").get_to(p.longrope_params.short_factor);
+      if (j.contains("factor"))
+        j.at("factor").get_to(p.longrope_params.factor);
+      else
+        p.longrope_params.factor = j.at("max-position-embeddings").get<double>() /
+                                   p.longrope_params.original_max_position_embeddings;
+    } catch (const json::exception& e) {
+      // clang-format off
+      throw std::runtime_error(fmt::format("Parsing error for longrope scaling - {}\n"
+          "LongRope requires keys ['original-max-position-embeddings', 'factor' or 'max-position-embeddings', 'long-factor', 'short-factor'].\n"
+          "Found config - {}", e.what(), j.dump()));
+      // clang-format on
+    }
+  }
+}
+
+void to_json(json& j, const RopeScalingParams& p) {
+  j["rope-type"] = p.rope_type;
+  if (p.rope_type == RopeScalingParams::ROPE_LLAMA3) {
+    j["factor"] = p.llama3_params.factor;
+    j["low-freq-factor"] = p.llama3_params.low_freq_factor;
+    j["high-freq-factor"] = p.llama3_params.high_freq_factor;
+    j["original-max-position-embeddings"] = p.llama3_params.original_max_position_embeddings;
+  } else if (p.rope_type == RopeScalingParams::ROPE_LONGROPE) {
+    j["factor"] = p.longrope_params.factor;
+    j["long-factor"] = p.longrope_params.long_factor;
+    j["short-factor"] = p.longrope_params.short_factor;
+    j["original-max-position-embeddings"] = p.longrope_params.original_max_position_embeddings;
+  }
+}
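// -----------------------------------------------------------------------------
// [Editorial sketch - not part of the original patch] Example "rope-scaling"
// JSON accepted by from_json() above; the keys match the parser's error
// messages, the numeric values are illustrative only:
//
//   {"rope-type": "llama3", "factor": 8.0, "low-freq-factor": 1.0,
//    "high-freq-factor": 4.0, "original-max-position-embeddings": 8192}
//
//   {"rope-type": "longrope", "original-max-position-embeddings": 4096,
//    "max-position-embeddings": 131072,   // used only when "factor" is absent
//    "long-factor": [1.0, 1.1], "short-factor": [1.0, 1.0]}
// -----------------------------------------------------------------------------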
+
+void from_json(const json& j, PositionalEncoding& p) {
+  p.type = Config::optional(j, "type", PositionalEncoding::ROPE);
+  if (p.type == PositionalEncoding::ROPE) {
+    p.rope_params.dims = Config::mandatory<int32_t>(j, "rope-dim");
+    p.rope_params.theta = Config::optional(j, "rope-theta", 10000);
+    p.rope_params.rope_scaling = Config::optional<RopeScalingParams>(j, "rope-scaling", {});
+  }
+}
+
+void to_json(json& j, const PositionalEncoding& p) {
+  j["type"] = p.type;
+  if (p.type == PositionalEncoding::ROPE) {
+    j["rope-dim"] = p.rope_params.dims;
+    j["rope-theta"] = p.rope_params.theta;
+    j["rope-scaling"] = p.rope_params.rope_scaling;
+  }
+}
+
+} // namespace qualla
diff --git a/Genie/Genie/src/qualla/engines/qnn-htp/nsp-model.hpp b/Genie/Genie/src/qualla/engines/qnn-htp/nsp-model.hpp
new file mode 100644
index 0000000000000000000000000000000000000000..67eee738ed09a690f14ad73a6921499314040415
--- /dev/null
+++ b/Genie/Genie/src/qualla/engines/qnn-htp/nsp-model.hpp
@@ -0,0 +1,424 @@
+//==============================================================================
+//
+// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+// All Rights Reserved.
+// Confidential and Proprietary - Qualcomm Technologies, Inc.
+//
+//==============================================================================
+
+#ifndef __QUALLA_NSP_MODEL_H_
+#define __QUALLA_NSP_MODEL_H_
+
+#include
+#include
+#include
+#include
+#include
+
+#include "qualla/env.hpp"
+#include "qualla/detail/threadpool.hpp"
+
+#include "QnnApi.hpp"
+#include "IOTensor.hpp"
+
+#include "nsp-kvdispatcher.hpp"
+#include "qnn-utils.hpp"
+#include "nsp-graph.hpp"
+
+namespace qualla {
+
+enum ModelArchitectureType : uint8_t {
+  DECODER = 0,
+  ENCODER = 1
+};
+
+enum LoraConfigType : uint8_t {
+  LORA_DISABLE = 0,
+  LORA_INPUT_WEIGHT_ENABLE = 1,
+  LORA_ADAPTER_WEIGHT_ENABLE = 2
+};
+
+static const std::unordered_set<Qnn_DataType_t> supported_activations = {
+  QNN_DATATYPE_UFIXED_POINT_8,
+  QNN_DATATYPE_UFIXED_POINT_16,
+  QNN_DATATYPE_INT_32,
+  QNN_DATATYPE_FLOAT_16
+};
+
+struct RopeScalingParams {
+  enum RopeType { DEFAULT, ROPE_LLAMA3, ROPE_LONGROPE } rope_type = DEFAULT;
+
+  // This should be a union, but running into compilation issues with non-trivial dtor/copy-ctor
+  struct {
+    double factor;
+    double low_freq_factor;
+    double high_freq_factor;
+    int original_max_position_embeddings;
+  } llama3_params;
+
+  struct {
+    double factor;
+    std::vector<double> long_factor;
+    std::vector<double> short_factor;
+    int original_max_position_embeddings;
+  } longrope_params;
+
+  RopeScalingParams() {}
+};
+
+NLOHMANN_JSON_SERIALIZE_ENUM(
+    RopeScalingParams::RopeType,
+    {{RopeScalingParams::DEFAULT, "default"},
+     {RopeScalingParams::ROPE_LLAMA3, "llama3"},
+     {RopeScalingParams::ROPE_LONGROPE, "longrope"}}
+)
+
+struct PositionalEncoding {
+  enum EncodingType : uint8_t { ROPE = 0x0, ABSOLUTE = 0x1, ALIBI = 0x2, UNDEFINED = 0xff } type;
+  struct {
+    int32_t dims;
+    double theta;
+    RopeScalingParams rope_scaling;
+  } rope_params;
+
+  PositionalEncoding() { type = ROPE; }
+};
+
+NLOHMANN_JSON_SERIALIZE_ENUM(
+    PositionalEncoding::EncodingType,
+    {{PositionalEncoding::UNDEFINED, "undefined"},
+     {PositionalEncoding::ROPE, "rope"},
+     {PositionalEncoding::ABSOLUTE, "absolute"},
+     {PositionalEncoding::ALIBI, "alibi"}}
+)
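// [Editorial note - not part of the original patch] NLOHMANN_JSON_SERIALIZE_ENUM
// generates the to_json/from_json pair for the enum, e.g. "rope" <-> ROPE.
// nlohmann::json maps an unrecognized string to the FIRST entry in the list,
// which is presumably why UNDEFINED is listed first here (unknown encoding
// strings parse as UNDEFINED), while an unknown rope-type falls back to DEFAULT.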
+
+void from_json(const json& j, PositionalEncoding& p);
+void to_json(json& j, const PositionalEncoding& p);
+void from_json(const json& j, RopeScalingParams& p);
+void to_json(json& j, const RopeScalingParams& p);
+
+class QnnNspModel {
+ protected:
+  Env& _env;
+
+  // Populated by allocateTensors()
+  // Maps tensor name to allocation block index and block offset
+  std::map> tensor_alloc_info;
+  bool float32ToFloat16(uint8_t* out, float* in, size_t numElements);
+
+  int32_t input_width = 1;
+  int32_t input_channel = 1;
+  int32_t input_bitWidth = 4;
+
+  int32_t embedding_length = -1;
+  std::string embedding_datatype{"float32"};
+
+  // Maps layers to their tensor names.
+  std::map<LayerType, std::string> m_layerNames {
+    {LayerType::INPUT, "input_ids"},
+    {LayerType::OUTPUT, "logits"},
+    {LayerType::TOKEN_TYPE_IDS, "token_type_ids"},
+    {LayerType::POOL_OUTPUT, "pooled_output"},
+    {LayerType::SEQ_OUTPUT, "sequence_output"},
+    {LayerType::ATTN_MASK, "attention_mask"},
+    {LayerType::POS_SIN, "position_ids_sin"},
+    {LayerType::POS_COS, "position_ids_cos"},
+    {LayerType::POS_IDS, "position_ids"}
+  };
+
+  std::vector m_eosEmbedding;
+ public:
+  struct LoraConfig {
+    std::string lora_name;
+    std::vector<std::string> binsection_list; // lorav2 adapter bin filenames
+    std::string path;                         // lorav1 weights directory
+    std::string alpha_tensor_name;            // lorav2 alpha tensor name
+    float alpha_tensor_val;                   // lorav2 alpha tensor value
+  };
+  struct Params {
+    ModelArchitectureType modelArchitectureType; // Model architecture
+    std::filesystem::path model_basedir;         // model basedir
+    std::vector<std::string> model_list;         // model filenames
+    std::map variant_latency;                    // latency for different variants
+    std::vector<std::string> exec_select_graphs; // Execute selected graphs
+    bool load_select_graphs; // Load only the graphs mentioned in exec_select_graphs from the context bin; by default all graphs are loaded
+
+    bool use_mmap;
+    bool use_async_Init;
+    uint64_t mmap_budget;
+    int64_t spill_fill_bufsize;
+    int32_t ctx_size;
+    int32_t kv_dim;
+    int32_t pad_token;
+    size_t n_embd;
+    uint32_t n_threads{0};
+    uint64_t cpumask{0};
+    bool poll{false};
+    std::string backend_lib;
+    std::string backend_ext_conf;
+    std::string debug_path;
+    bool debug_specs;
+    bool debug_tensors;
+    bool debug_outputs;
+    bool debug_qnn;
+    std::string kv_update_method;
+    std::string lmhead_weight_dir;
+    bool graph_switching;
+    LoraConfigType lora_config_type;
+    std::map<std::string, LoraConfig> lora_param;
+    std::string input_layer_name;
+    int32_t embedding_length;
+    std::string embedding_datatype;
+    bool pooled_output;
+    bool disable_kv_cache;
+    // Parameters for positional encodings
+    PositionalEncoding positional_encoding_params;
+  };
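// ---------------------------------------------------------------------------
// [Editorial sketch - not part of the original patch] Minimal population of
// the Params struct above before constructing the model; every value below is
// illustrative (paths and filenames are hypothetical):
//
//   QnnNspModel::Params params{};
//   params.modelArchitectureType = DECODER;
//   params.model_basedir = "/path/to/model";   // hypothetical directory
//   params.model_list    = {"split_1.bin"};    // hypothetical context bin
//   params.ctx_size      = 4096;
//   params.n_embd        = 4096;
//   params.backend_lib   = "libQnnHtp.so";
//   QnnNspModel model(env, params);
// ---------------------------------------------------------------------------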
+
+  const std::filesystem::path model_basedir;
+  std::vector<std::string> model_filelist;
+  std::string lmhead_weight_dir;
+  std::vector<int32_t> token_history;
+  std::map variant_latency;
+  std::vector<std::string> exec_select_graphs;
+  bool load_select_graphs;
+
+  InputType m_inputType{InputType::UNKNOWN};
+
+  LoraConfigType lora_conf;
+  std::map<std::string, LoraConfig> lora_config;
+  // QNN specific variables
+  const bool m_sharedBuffer{true};
+  std::unique_ptr<QnnApi> m_qnnApi;
+  std::unique_ptr<IOTensor> m_ioTensor{nullptr};
+  int64_t spill_fill_buffer_size;
+  bool m_use_mmap{false};
+  bool m_use_async_Init{true};
+  uint64_t mmap_budget;
+  bool graph_switching{false};
+  size_t n_embd;
+
+  bool m_pooled_output{true};
+  bool m_disableKvCache{false};
+  // Model parameters
+  ModelArchitectureType m_modelArchitectureType;
+  int32_t m_ctx_size{-1};
+  int32_t m_vocab_size{-1};
+  int32_t m_kv_dim{-1};
+  int32_t m_embd_size{-1};
+  int32_t m_pad_token{-1};
+
+  size_t m_embeddingBufferSize{0};
+
+  QnnUtils::DataType d_input{QNN_DATATYPE_INT_32}, d_kv{QNN_DATATYPE_UFIXED_POINT_8},
+      d_attn_map{QNN_DATATYPE_UFIXED_POINT_16}, d_token_type{QNN_DATATYPE_INT_32};
+
+  // int32_t attention_mask_bitwidth{2}, position_id_bitwidth{2};
+
+  // Information regarding model execution settings and last inference
+  struct RunInfo {
+    int32_t n_tokens;
+    size_t n_processed;
+
+    std::vector<int32_t> tokens;
+  } run_info{-1, 0, {}};
+
+  // Model specific variables
+  uint32_t m_num_graphs;
+  bool _lora_enabled{false};
+  bool _lmhead_weight_input{false};
+
+  // QnnNspGraph contains all GraphVariants for a specific split (with index=split_idx)
+  std::vector<QnnNspGraph> m_nsp_graphs;
+  // GraphVariant represents one input size within one split (e.g. KV$_split_1)
+  std::vector<GraphVariant> m_variant_list;
+
+  // For ease of usage: Map from graph name to the corresponding GraphVariant
+  std::unordered_map m_graph_map;
+  // This map records how many graphs have been loaded for a particular input size
+  std::map nsp_graph_count;
+
+  bool _threaded{false};
+  uint64_t _cpumask{0};
+  ThreadPool threadpool;
+
+  KVManagerMode _kv_update_method{POINTER_SHIFT};
+
+  int32_t _kv_update_count{0};
+  std::unique_ptr _kv_dispatcher;
+
+  std::string _backend_lib;
+  std::string _backend_ext_conf;
+
+  // Store some pointers for easier access
+  QnnUtils::Tensor* t_input_ids{nullptr};
+  QnnUtils::Tensor* t_attn_mask{nullptr};
+  QnnUtils::Tensor* t_token_type_ids{nullptr};
+
+  // Variables for positional encodings
+  PositionalEncoding m_positional_encoding;
+  QnnUtils::DataType d_pos{QNN_DATATYPE_UFIXED_POINT_16};
+  // PositionalEncodingType::ABSOLUTE OR PositionalEncodingType::ALIBI
+  QnnUtils::Tensor* t_position_ids{nullptr};
+  // PositionalEncodingType::ROPE variables
+  int32_t m_pos_dim{-1};          // Dimension of positional embedding tensor (incl partial_factor)
+  double rope_theta{10000.0};     // Base theta parameter for RoPE calculations
+  void* rope_sin{nullptr};        // Pre-calculated RoPE sin table of size [ctx_size, m_pos_dim]
+  void* rope_cos{nullptr};        // Pre-calculated RoPE cos table of size [ctx_size, m_pos_dim]
+  RopeScalingParams rope_scaling; // RoPE scaling parameters
+
+  QnnUtils::Tensor* t_position_ids_sin{nullptr};
+  QnnUtils::Tensor* t_position_ids_cos{nullptr};
+
+  // n_past tracks how many positions of the KV cache are populated
+  size_t m_nPast{0};
+
+  // Self-Speculative Decoding
+  // This prefix is not for input tokens, but just for special tokens
+  // Only the special tokens from the offset should attend to the KV prefix
+  int32_t _size_to_skip_kv_prefix{0};
+  int32_t _offset_to_apply_kv_prefix{0};
+
+  // Keep track of inference count
+  int m_inference_count = 0;
+
+  // Debug mode settings
+  bool _debug_specs{false};
+  bool _debug_tensors{false};
+  bool _debug_outputs{false};
+  bool _debug_qnn{false};
+  std::string _debug_path;
+
+  QnnNspModel(Env& env, const Params& params);
+
+  ~QnnNspModel();
+
+  bool initializeModel(void);
+  bool validateModel(void);
+  bool initializeIOTensors(void);
+  bool initializeTensorPointers();
+  bool initializeKVManager();
+  bool calculate_rope_embeddings(void);
+  bool load_lmhead_weight_as_input(void);
+  bool flushLoraWeightsBuffers(void);
+
+  template <typename T>
+  bool setupAttentionMask(
+      bool pad_left,
+      int n_tokens,
+      int n_inputs,
+      int n_past,
+      std::span attention_map,
+      size_t n_skip_prefix,
+      size_t n_apply_prefix_offset
+  );
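// [Editorial note - not part of the original patch] T in the templated setup
// helpers appears to be the quantized activation type of the corresponding
// input tensor, mirroring the dtype switch in applyLoraStrength(): e.g.
// setupAttentionMask<uint8_t> for QNN_DATATYPE_UFIXED_POINT_8 and
// setupAttentionMask<uint16_t> for QNN_DATATYPE_UFIXED_POINT_16, while the
// dedicated *FP16 overloads below cover QNN_DATATYPE_FLOAT_16.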
+
+  bool setupAttentionMaskFP16(
+      bool pad_left,
+      int n_tokens,
+      int n_inputs,
+      int n_past,
+      std::span attention_map,
+      size_t n_skip_prefix,
+      size_t n_apply_prefix_offset);
+
+  bool setupRopePositionEmbeddingFP16(
+      bool pad_left,
+      int n_tokens,
+      int n_inputs,
+      int n_past,
+      std::span attention_map,
+      size_t n_skip_prefix,
+      size_t n_apply_prefix_offset
+  );
+
+  template <typename T>
+  bool setupRopePositionEmbedding(
+      bool pad_left,
+      int n_tokens,
+      int n_inputs,
+      int n_past,
+      std::span attention_map,
+      size_t n_skip_prefix,
+      size_t n_apply_prefix_offset
+  );
+
+  template <typename T>
+  bool setupAlibiPositionEmbedding(
+      bool pad_left,
+      int n_tokens,
+      int n_inputs,
+      int n_past
+  );
+
+  bool setupInputTensors(
+      std::span<const int32_t> tokens,
+      int32_t n_past,
+      std::span attention_map,
+      size_t n_skip_prefix,
+      size_t n_apply_prefix_offset
+  );
+
+  bool setupInputTensors(
+      std::span<const float> embedding,
+      int32_t n_past,
+      std::span attention_map,
+      size_t n_skip_prefix,
+      size_t n_apply_prefix_offset
+  );
+
+  bool quantizeInput(float* in, size_t tensorOffset, size_t length);
+
+  size_t getEmbeddingBufferSize();
+
+  size_t runInference(
+      const std::vector<int32_t>& tokens,
+      const std::vector& attention_map,
+      std::vector<float>& output,
+      bool output_all = false
+  );
+
+  size_t runInference(
+      std::vector<float>& embeddings,
+      const std::vector& attention_map,
+      std::vector<float>& output,
+      bool output_all = false
+  );
+
+  bool cacheEosEmbedding(std::vector& eosEmbedding);
+
+  bool setKVCacheNPast(size_t n_past, const std::vector& selected);
+
+  size_t getEmbeddings(std::span<float> embds);
+
+  size_t getDequantLogits(std::span<float> logits, bool logits_all = false);
+
+  bool debugOutputs(QnnUtils::Tensor* outTensor, std::string& outTensorName);
+
+  size_t loadKVCache(const std::string& load_path);
+  bool saveKVCache(const std::string& save_path);
+  bool applyLoraStrength(const std::string& alpha_tensor_name, const float alpha_val);
+  bool applyLoraAdapter(const std::string& lora_adapter_name);
+  bool applyBinarySections(std::vector<std::string>& binsection_list);
+  bool applyLoraWeights(const std::string& lora_weights_name);
+
+ protected:
+  // Internal functions to separate different runInference logic
+  int32_t selectVariantStrategy(int32_t n_inputs, int32_t n_past, int32_t cur_variant);
+  bool runInferenceHelper(bool pipeline, int32_t* total_wait, int32_t* total_exec);
+
+  inline bool updateTensorPointer(GraphVariant& variant, std::string& key, QnnUtils::Tensor*& t);
+  inline void* getBuffer(QnnUtils::Tensor& spec) { return m_ioTensor->getBuffer(spec.tensor); }
+  inline void* getBuffer(QnnUtils::Tensor* spec) { return m_ioTensor->getBuffer(spec->tensor); }
+  inline size_t getBufferSize(QnnUtils::Tensor& spec) { return spec.dims.getSize(); }
+  inline size_t getBufferSize(QnnUtils::Tensor* spec) { return spec->dims.getSize(); }
+
+  void dumpTensorSpecs();
+};
+
+} // namespace qualla
+
+#endif
diff --git a/Genie/Genie/src/qualla/env.cpp b/Genie/Genie/src/qualla/env.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..ef8bf21d23df24df620df2646cd00d22081abc95
--- /dev/null
+++ b/Genie/Genie/src/qualla/env.cpp
@@ -0,0 +1,51 @@
+//==============================================================================
+//
+// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+// All Rights Reserved.
+// Confidential and Proprietary - Qualcomm Technologies, Inc.
+//
+//==============================================================================
+
+#include
+
+#include
+
+namespace fs = std::filesystem;
+
+namespace qualla {
+
+Env::Env(const json& conf) {
+  _path.models = fs::path();
+  _path.cache = fs::path();
+
+  if (conf.contains("path")) {
+    const json& p = conf["path"];
+
+    if (p.contains("models"))
+      _path.models = fs::path(p["models"].get<std::string>()).make_preferred();
+    if (p.contains("cache"))
+      _path.cache = fs::path(p["cache"].get<std::string>()).make_preferred();
+  }
+
+  using qc = qualla::Config;
+
+  // Create logger
+  const qualla::json& log_conf = qc::optional<qualla::json>(conf, "log", {});
+  _logger = Logger::create(log_conf);
+}
+
+Env::~Env() {}
+
+std::shared_ptr<Env> Env::create(const qualla::json& conf) {
+  return std::make_shared<Env>(conf);
+}
+
+std::shared_ptr<Env> Env::create(std::istream& json_stream) {
+  return create(json::parse(json_stream));
+}
+
+std::shared_ptr<Env> Env::create(const std::string& json_str) {
+  return create(json::parse(json_str));
+}
+
+} // namespace qualla
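// ---------------------------------------------------------------------------
// [Editorial sketch - not part of the original patch] Example of the JSON
// shape Env() above consumes; both keys are optional, the paths are
// hypothetical, and the "log" object is whatever Logger::create() accepts:
//
//   {
//     "path": { "models": "/data/models", "cache": "/data/cache" },
//     "log":  { }
//   }
//
//   auto env = qualla::Env::create(R"({"path":{"models":"/data/models"}})");
// ---------------------------------------------------------------------------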
diff --git a/Genie/Genie/src/qualla/gpio-marker.cpp b/Genie/Genie/src/qualla/gpio-marker.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..63f52241f9b5df22257a7a926b851c4dced76b2d
--- /dev/null
+++ b/Genie/Genie/src/qualla/gpio-marker.cpp
@@ -0,0 +1,66 @@
+//==============================================================================
+//
+// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+// All Rights Reserved.
+// Confidential and Proprietary - Qualcomm Technologies, Inc.
+//
+//==============================================================================
+
+#include "qualla/detail/gpio-marker.hpp"
+#include "fmt/format.h"
+
+namespace fs = std::filesystem;
+
+namespace qualla {
+
+GpioMarker::GpioMarker(const json& conf) {
+  // Parse config
+  using qc = qualla::Config;
+
+  _tool_path = qc::optional(conf, "tool-path", "");
+  _command = qc::optional(conf, "command", "");
+  _gpio_num = qc::optional(conf, "gpio-num", -1);
+
+  // The marker is enabled only when a valid tool path is configured
+  _gpio_marker_enable = !_tool_path.empty() && fs::exists(_tool_path);
+  if (_gpio_marker_enable) reset();
+}
+
+GpioMarker::~GpioMarker() {}
+
+void GpioMarker::set() {
+  if (!_gpio_marker_enable) return;
+
+  // Toggle the marker pin by shelling out to the configured tool
+  _gpio_status = !_gpio_status;
+  std::string cmd = fmt::format("{} {} {}={}", _tool_path, _command, _gpio_num, _gpio_status);
+  system(cmd.c_str());
+}
+
+void GpioMarker::reset() {
+  if (!_gpio_marker_enable) return;
+
+  std::string cmd = fmt::format("{} {} {}=0", _tool_path, _command, _gpio_num);
+  system(cmd.c_str());
+  _gpio_status = 0;
+}
+
+std::unique_ptr<GpioMarker> GpioMarker::create(const qualla::json& conf) {
+  return std::make_unique<GpioMarker>(conf);
+}
+
+std::unique_ptr<GpioMarker> GpioMarker::create(std::istream& json_stream) {
+  return create(json::parse(json_stream));
+}
+
+std::unique_ptr<GpioMarker> GpioMarker::create(const std::string& json_str) {
+  return create(json::parse(json_str));
+}
+
+} // namespace qualla
diff --git a/Genie/Genie/src/qualla/include/fmt/core.h b/Genie/Genie/src/qualla/include/fmt/core.h
new file mode 100644
index 0000000000000000000000000000000000000000..f9e3b7d6dc1632c0596194f22218c976deedea54
--- /dev/null
+++ b/Genie/Genie/src/qualla/include/fmt/core.h
@@ -0,0 +1,2922 @@
+// Formatting library for C++ - the core API for char/UTF-8
+//
+// Copyright (c) 2012 - present, Victor Zverovich
+// All rights reserved.
+//
+// For the license information refer to format.h.
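// [Editorial note - not part of the original patch] This is a vendored copy of
// the {fmt} core header; FMT_VERSION 100101 below corresponds to upstream
// fmt 10.1.1. It backs the fmt::format() calls in gpio-marker.cpp and the
// "{}"-style placeholders used by the engine's logging macros, e.g.
// fmt::format("qnn-htp: get-embds for {} tokens", n).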
+ +#ifndef FMT_CORE_H_ +#define FMT_CORE_H_ + +#include // std::byte +#include // std::FILE +#include // std::strlen +#include +#include +#include // std::addressof +#include +#include + +// The fmt library version in the form major * 10000 + minor * 100 + patch. +#define FMT_VERSION 100101 + +#if defined(__clang__) && !defined(__ibmxl__) +# define FMT_CLANG_VERSION (__clang_major__ * 100 + __clang_minor__) +#else +# define FMT_CLANG_VERSION 0 +#endif + +#if defined(__GNUC__) && !defined(__clang__) && !defined(__INTEL_COMPILER) && \ + !defined(__NVCOMPILER) +# define FMT_GCC_VERSION (__GNUC__ * 100 + __GNUC_MINOR__) +#else +# define FMT_GCC_VERSION 0 +#endif + +#ifndef FMT_GCC_PRAGMA +// Workaround _Pragma bug https://gcc.gnu.org/bugzilla/show_bug.cgi?id=59884. +# if FMT_GCC_VERSION >= 504 +# define FMT_GCC_PRAGMA(arg) _Pragma(arg) +# else +# define FMT_GCC_PRAGMA(arg) +# endif +#endif + +#ifdef __ICL +# define FMT_ICC_VERSION __ICL +#elif defined(__INTEL_COMPILER) +# define FMT_ICC_VERSION __INTEL_COMPILER +#else +# define FMT_ICC_VERSION 0 +#endif + +#ifdef _MSC_VER +# define FMT_MSC_VERSION _MSC_VER +# define FMT_MSC_WARNING(...) __pragma(warning(__VA_ARGS__)) +#else +# define FMT_MSC_VERSION 0 +# define FMT_MSC_WARNING(...) +#endif + +#ifdef _MSVC_LANG +# define FMT_CPLUSPLUS _MSVC_LANG +#else +# define FMT_CPLUSPLUS __cplusplus +#endif + +#ifdef __has_feature +# define FMT_HAS_FEATURE(x) __has_feature(x) +#else +# define FMT_HAS_FEATURE(x) 0 +#endif + +#if defined(__has_include) || FMT_ICC_VERSION >= 1600 || FMT_MSC_VERSION > 1900 +# define FMT_HAS_INCLUDE(x) __has_include(x) +#else +# define FMT_HAS_INCLUDE(x) 0 +#endif + +#ifdef __has_cpp_attribute +# define FMT_HAS_CPP_ATTRIBUTE(x) __has_cpp_attribute(x) +#else +# define FMT_HAS_CPP_ATTRIBUTE(x) 0 +#endif + +#define FMT_HAS_CPP14_ATTRIBUTE(attribute) \ + (FMT_CPLUSPLUS >= 201402L && FMT_HAS_CPP_ATTRIBUTE(attribute)) + +#define FMT_HAS_CPP17_ATTRIBUTE(attribute) \ + (FMT_CPLUSPLUS >= 201703L && FMT_HAS_CPP_ATTRIBUTE(attribute)) + +// Check if relaxed C++14 constexpr is supported. +// GCC doesn't allow throw in constexpr until version 6 (bug 67371). +#ifndef FMT_USE_CONSTEXPR +# if (FMT_HAS_FEATURE(cxx_relaxed_constexpr) || FMT_MSC_VERSION >= 1912 || \ + (FMT_GCC_VERSION >= 600 && FMT_CPLUSPLUS >= 201402L)) && \ + !FMT_ICC_VERSION && (!defined(__NVCC__) || FMT_CPLUSPLUS >= 202002L) +# define FMT_USE_CONSTEXPR 1 +# else +# define FMT_USE_CONSTEXPR 0 +# endif +#endif +#if FMT_USE_CONSTEXPR +# define FMT_CONSTEXPR constexpr +#else +# define FMT_CONSTEXPR +#endif + +#if ((FMT_CPLUSPLUS >= 202002L) && \ + (!defined(_GLIBCXX_RELEASE) || _GLIBCXX_RELEASE > 9)) || \ + (FMT_CPLUSPLUS >= 201709L && FMT_GCC_VERSION >= 1002) +# define FMT_CONSTEXPR20 constexpr +#else +# define FMT_CONSTEXPR20 +#endif + +// Check if constexpr std::char_traits<>::{compare,length} are supported. +#if defined(__GLIBCXX__) +# if FMT_CPLUSPLUS >= 201703L && defined(_GLIBCXX_RELEASE) && \ + _GLIBCXX_RELEASE >= 7 // GCC 7+ libstdc++ has _GLIBCXX_RELEASE. +# define FMT_CONSTEXPR_CHAR_TRAITS constexpr +# endif +#elif defined(_LIBCPP_VERSION) && FMT_CPLUSPLUS >= 201703L && \ + _LIBCPP_VERSION >= 4000 +# define FMT_CONSTEXPR_CHAR_TRAITS constexpr +#elif FMT_MSC_VERSION >= 1914 && FMT_CPLUSPLUS >= 201703L +# define FMT_CONSTEXPR_CHAR_TRAITS constexpr +#endif +#ifndef FMT_CONSTEXPR_CHAR_TRAITS +# define FMT_CONSTEXPR_CHAR_TRAITS +#endif + +// Check if exceptions are disabled. 
+#ifndef FMT_EXCEPTIONS +# if (defined(__GNUC__) && !defined(__EXCEPTIONS)) || \ + (FMT_MSC_VERSION && !_HAS_EXCEPTIONS) +# define FMT_EXCEPTIONS 0 +# else +# define FMT_EXCEPTIONS 1 +# endif +#endif + +// Disable [[noreturn]] on MSVC/NVCC because of bogus unreachable code warnings. +#if FMT_EXCEPTIONS && FMT_HAS_CPP_ATTRIBUTE(noreturn) && !FMT_MSC_VERSION && \ + !defined(__NVCC__) +# define FMT_NORETURN [[noreturn]] +#else +# define FMT_NORETURN +#endif + +#ifndef FMT_NODISCARD +# if FMT_HAS_CPP17_ATTRIBUTE(nodiscard) +# define FMT_NODISCARD [[nodiscard]] +# else +# define FMT_NODISCARD +# endif +#endif + +#ifndef FMT_INLINE +# if FMT_GCC_VERSION || FMT_CLANG_VERSION +# define FMT_INLINE inline __attribute__((always_inline)) +# else +# define FMT_INLINE inline +# endif +#endif + +#ifdef _MSC_VER +# define FMT_UNCHECKED_ITERATOR(It) \ + using _Unchecked_type = It // Mark iterator as checked. +#else +# define FMT_UNCHECKED_ITERATOR(It) using unchecked_type = It +#endif + +#ifndef FMT_BEGIN_NAMESPACE +# define FMT_BEGIN_NAMESPACE \ + namespace fmt { \ + inline namespace v10 { +# define FMT_END_NAMESPACE \ + } \ + } +#endif + +#ifndef FMT_EXPORT +# define FMT_EXPORT +# define FMT_BEGIN_EXPORT +# define FMT_END_EXPORT +#endif + +#if !defined(FMT_HEADER_ONLY) && defined(_WIN32) +# ifdef FMT_LIB_EXPORT +# define FMT_API __declspec(dllexport) +# elif defined(FMT_SHARED) +# define FMT_API __declspec(dllimport) +# endif +#else +# if defined(FMT_LIB_EXPORT) || defined(FMT_SHARED) +# if defined(__GNUC__) || defined(__clang__) +# define FMT_API __attribute__((visibility("default"))) +# endif +# endif +#endif +#ifndef FMT_API +# define FMT_API +#endif + +// libc++ supports string_view in pre-c++17. +#if FMT_HAS_INCLUDE() && \ + (FMT_CPLUSPLUS >= 201703L || defined(_LIBCPP_VERSION)) +# include +# define FMT_USE_STRING_VIEW +#elif FMT_HAS_INCLUDE("experimental/string_view") && FMT_CPLUSPLUS >= 201402L +# include +# define FMT_USE_EXPERIMENTAL_STRING_VIEW +#endif + +#ifndef FMT_UNICODE +# define FMT_UNICODE !FMT_MSC_VERSION +#endif + +#ifndef FMT_CONSTEVAL +# if ((FMT_GCC_VERSION >= 1000 || FMT_CLANG_VERSION >= 1101) && \ + (!defined(__apple_build_version__) || \ + __apple_build_version__ >= 14000029L) && \ + FMT_CPLUSPLUS >= 202002L) || \ + (defined(__cpp_consteval) && \ + (!FMT_MSC_VERSION || _MSC_FULL_VER >= 193030704)) +// consteval is broken in MSVC before VS2022 and Apple clang before 14. +# define FMT_CONSTEVAL consteval +# define FMT_HAS_CONSTEVAL +# else +# define FMT_CONSTEVAL +# endif +#endif + +#ifndef FMT_USE_NONTYPE_TEMPLATE_ARGS +# if defined(__cpp_nontype_template_args) && \ + ((FMT_GCC_VERSION >= 903 && FMT_CPLUSPLUS >= 201709L) || \ + __cpp_nontype_template_args >= 201911L) && \ + !defined(__NVCOMPILER) && !defined(__LCC__) +# define FMT_USE_NONTYPE_TEMPLATE_ARGS 1 +# else +# define FMT_USE_NONTYPE_TEMPLATE_ARGS 0 +# endif +#endif + +// Enable minimal optimizations for more compact code in debug mode. +FMT_GCC_PRAGMA("GCC push_options") +#if !defined(__OPTIMIZE__) && !defined(__NVCOMPILER) && !defined(__LCC__) && \ + !defined(__CUDACC__) +FMT_GCC_PRAGMA("GCC optimize(\"Og\")") +#endif + +FMT_BEGIN_NAMESPACE + +// Implementations of enable_if_t and other metafunctions for older systems. 
+template +using enable_if_t = typename std::enable_if::type; +template +using conditional_t = typename std::conditional::type; +template using bool_constant = std::integral_constant; +template +using remove_reference_t = typename std::remove_reference::type; +template +using remove_const_t = typename std::remove_const::type; +template +using remove_cvref_t = typename std::remove_cv>::type; +template struct type_identity { using type = T; }; +template using type_identity_t = typename type_identity::type; +template +using underlying_t = typename std::underlying_type::type; + +// Checks whether T is a container with contiguous storage. +template struct is_contiguous : std::false_type {}; +template +struct is_contiguous> : std::true_type {}; + +struct monostate { + constexpr monostate() {} +}; + +// An enable_if helper to be used in template parameters which results in much +// shorter symbols: https://godbolt.org/z/sWw4vP. Extra parentheses are needed +// to workaround a bug in MSVC 2019 (see #1140 and #1186). +#ifdef FMT_DOC +# define FMT_ENABLE_IF(...) +#else +# define FMT_ENABLE_IF(...) fmt::enable_if_t<(__VA_ARGS__), int> = 0 +#endif + +// This is defined in core.h instead of format.h to avoid injecting in std. +// It is a template to avoid undesirable implicit conversions to std::byte. +#ifdef __cpp_lib_byte +template ::value)> +inline auto format_as(T b) -> unsigned char { + return static_cast(b); +} +#endif + +namespace detail { +// Suppresses "unused variable" warnings with the method described in +// https://herbsutter.com/2009/10/18/mailbag-shutting-up-compiler-warnings/. +// (void)var does not work on many Intel compilers. +template FMT_CONSTEXPR void ignore_unused(const T&...) {} + +constexpr FMT_INLINE auto is_constant_evaluated( + bool default_value = false) noexcept -> bool { +// Workaround for incompatibility between libstdc++ consteval-based +// std::is_constant_evaluated() implementation and clang-14. +// https://github.com/fmtlib/fmt/issues/3247 +#if FMT_CPLUSPLUS >= 202002L && defined(_GLIBCXX_RELEASE) && \ + _GLIBCXX_RELEASE >= 12 && \ + (FMT_CLANG_VERSION >= 1400 && FMT_CLANG_VERSION < 1500) + ignore_unused(default_value); + return __builtin_is_constant_evaluated(); +#elif defined(__cpp_lib_is_constant_evaluated) + ignore_unused(default_value); + return std::is_constant_evaluated(); +#else + return default_value; +#endif +} + +// Suppresses "conditional expression is constant" warnings. +template constexpr FMT_INLINE auto const_check(T value) -> T { + return value; +} + +FMT_NORETURN FMT_API void assert_fail(const char* file, int line, + const char* message); + +#ifndef FMT_ASSERT +# ifdef NDEBUG +// FMT_ASSERT is not empty to avoid -Wempty-body. +# define FMT_ASSERT(condition, message) \ + fmt::detail::ignore_unused((condition), (message)) +# else +# define FMT_ASSERT(condition, message) \ + ((condition) /* void() fails with -Winvalid-constexpr on clang 4.0.1 */ \ + ? (void)0 \ + : fmt::detail::assert_fail(__FILE__, __LINE__, (message))) +# endif +#endif + +#if defined(FMT_USE_STRING_VIEW) +template using std_string_view = std::basic_string_view; +#elif defined(FMT_USE_EXPERIMENTAL_STRING_VIEW) +template +using std_string_view = std::experimental::basic_string_view; +#else +template struct std_string_view {}; +#endif + +#ifdef FMT_USE_INT128 +// Do nothing. +#elif defined(__SIZEOF_INT128__) && !defined(__NVCC__) && \ + !(FMT_CLANG_VERSION && FMT_MSC_VERSION) +# define FMT_USE_INT128 1 +using int128_opt = __int128_t; // An optional native 128-bit integer. 
+using uint128_opt = __uint128_t; +template inline auto convert_for_visit(T value) -> T { + return value; +} +#else +# define FMT_USE_INT128 0 +#endif +#if !FMT_USE_INT128 +enum class int128_opt {}; +enum class uint128_opt {}; +// Reduce template instantiations. +template auto convert_for_visit(T) -> monostate { return {}; } +#endif + +// Casts a nonnegative integer to unsigned. +template +FMT_CONSTEXPR auto to_unsigned(Int value) -> + typename std::make_unsigned::type { + FMT_ASSERT(std::is_unsigned::value || value >= 0, "negative value"); + return static_cast::type>(value); +} + +FMT_CONSTEXPR inline auto is_utf8() -> bool { + FMT_MSC_WARNING(suppress : 4566) constexpr unsigned char section[] = "\u00A7"; + + // Avoid buggy sign extensions in MSVC's constant evaluation mode (#2297). + using uchar = unsigned char; + return FMT_UNICODE || (sizeof(section) == 3 && uchar(section[0]) == 0xC2 && + uchar(section[1]) == 0xA7); +} +} // namespace detail + +/** + An implementation of ``std::basic_string_view`` for pre-C++17. It provides a + subset of the API. ``fmt::basic_string_view`` is used for format strings even + if ``std::string_view`` is available to prevent issues when a library is + compiled with a different ``-std`` option than the client code (which is not + recommended). + */ +FMT_EXPORT +template class basic_string_view { + private: + const Char* data_; + size_t size_; + + public: + using value_type = Char; + using iterator = const Char*; + + constexpr basic_string_view() noexcept : data_(nullptr), size_(0) {} + + /** Constructs a string reference object from a C string and a size. */ + constexpr basic_string_view(const Char* s, size_t count) noexcept + : data_(s), size_(count) {} + + /** + \rst + Constructs a string reference object from a C string computing + the size with ``std::char_traits::length``. + \endrst + */ + FMT_CONSTEXPR_CHAR_TRAITS + FMT_INLINE + basic_string_view(const Char* s) + : data_(s), + size_(detail::const_check(std::is_same::value && + !detail::is_constant_evaluated(true)) + ? std::strlen(reinterpret_cast(s)) + : std::char_traits::length(s)) {} + + /** Constructs a string reference from a ``std::basic_string`` object. */ + template + FMT_CONSTEXPR basic_string_view( + const std::basic_string& s) noexcept + : data_(s.data()), size_(s.size()) {} + + template >::value)> + FMT_CONSTEXPR basic_string_view(S s) noexcept + : data_(s.data()), size_(s.size()) {} + + /** Returns a pointer to the string data. */ + constexpr auto data() const noexcept -> const Char* { return data_; } + + /** Returns the string size. */ + constexpr auto size() const noexcept -> size_t { return size_; } + + constexpr auto begin() const noexcept -> iterator { return data_; } + constexpr auto end() const noexcept -> iterator { return data_ + size_; } + + constexpr auto operator[](size_t pos) const noexcept -> const Char& { + return data_[pos]; + } + + FMT_CONSTEXPR void remove_prefix(size_t n) noexcept { + data_ += n; + size_ -= n; + } + + FMT_CONSTEXPR_CHAR_TRAITS bool starts_with( + basic_string_view sv) const noexcept { + return size_ >= sv.size_ && + std::char_traits::compare(data_, sv.data_, sv.size_) == 0; + } + FMT_CONSTEXPR_CHAR_TRAITS bool starts_with(Char c) const noexcept { + return size_ >= 1 && std::char_traits::eq(*data_, c); + } + FMT_CONSTEXPR_CHAR_TRAITS bool starts_with(const Char* s) const { + return starts_with(basic_string_view(s)); + } + + // Lexicographically compare this string reference to other. 
+ FMT_CONSTEXPR_CHAR_TRAITS auto compare(basic_string_view other) const -> int { + size_t str_size = size_ < other.size_ ? size_ : other.size_; + int result = std::char_traits::compare(data_, other.data_, str_size); + if (result == 0) + result = size_ == other.size_ ? 0 : (size_ < other.size_ ? -1 : 1); + return result; + } + + FMT_CONSTEXPR_CHAR_TRAITS friend auto operator==(basic_string_view lhs, + basic_string_view rhs) + -> bool { + return lhs.compare(rhs) == 0; + } + friend auto operator!=(basic_string_view lhs, basic_string_view rhs) -> bool { + return lhs.compare(rhs) != 0; + } + friend auto operator<(basic_string_view lhs, basic_string_view rhs) -> bool { + return lhs.compare(rhs) < 0; + } + friend auto operator<=(basic_string_view lhs, basic_string_view rhs) -> bool { + return lhs.compare(rhs) <= 0; + } + friend auto operator>(basic_string_view lhs, basic_string_view rhs) -> bool { + return lhs.compare(rhs) > 0; + } + friend auto operator>=(basic_string_view lhs, basic_string_view rhs) -> bool { + return lhs.compare(rhs) >= 0; + } +}; + +FMT_EXPORT +using string_view = basic_string_view; + +/** Specifies if ``T`` is a character type. Can be specialized by users. */ +FMT_EXPORT +template struct is_char : std::false_type {}; +template <> struct is_char : std::true_type {}; + +namespace detail { + +// A base class for compile-time strings. +struct compile_string {}; + +template +struct is_compile_string : std::is_base_of {}; + +template ::value)> +FMT_INLINE auto to_string_view(const Char* s) -> basic_string_view { + return s; +} +template +inline auto to_string_view(const std::basic_string& s) + -> basic_string_view { + return s; +} +template +constexpr auto to_string_view(basic_string_view s) + -> basic_string_view { + return s; +} +template >::value)> +inline auto to_string_view(std_string_view s) -> basic_string_view { + return s; +} +template ::value)> +constexpr auto to_string_view(const S& s) + -> basic_string_view { + return basic_string_view(s); +} +void to_string_view(...); + +// Specifies whether S is a string type convertible to fmt::basic_string_view. +// It should be a constexpr function but MSVC 2017 fails to compile it in +// enable_if and MSVC 2015 fails to compile it as an alias template. +// ADL is intentionally disabled as to_string_view is not an extension point. +template +struct is_string + : std::is_class()))> {}; + +template struct char_t_impl {}; +template struct char_t_impl::value>> { + using result = decltype(to_string_view(std::declval())); + using type = typename result::value_type; +}; + +enum class type { + none_type, + // Integer types should go first, + int_type, + uint_type, + long_long_type, + ulong_long_type, + int128_type, + uint128_type, + bool_type, + char_type, + last_integer_type = char_type, + // followed by floating-point types. + float_type, + double_type, + long_double_type, + last_numeric_type = long_double_type, + cstring_type, + string_type, + pointer_type, + custom_type +}; + +// Maps core type T to the corresponding type enum constant. 
+template +struct type_constant : std::integral_constant {}; + +#define FMT_TYPE_CONSTANT(Type, constant) \ + template \ + struct type_constant \ + : std::integral_constant {} + +FMT_TYPE_CONSTANT(int, int_type); +FMT_TYPE_CONSTANT(unsigned, uint_type); +FMT_TYPE_CONSTANT(long long, long_long_type); +FMT_TYPE_CONSTANT(unsigned long long, ulong_long_type); +FMT_TYPE_CONSTANT(int128_opt, int128_type); +FMT_TYPE_CONSTANT(uint128_opt, uint128_type); +FMT_TYPE_CONSTANT(bool, bool_type); +FMT_TYPE_CONSTANT(Char, char_type); +FMT_TYPE_CONSTANT(float, float_type); +FMT_TYPE_CONSTANT(double, double_type); +FMT_TYPE_CONSTANT(long double, long_double_type); +FMT_TYPE_CONSTANT(const Char*, cstring_type); +FMT_TYPE_CONSTANT(basic_string_view, string_type); +FMT_TYPE_CONSTANT(const void*, pointer_type); + +constexpr bool is_integral_type(type t) { + return t > type::none_type && t <= type::last_integer_type; +} +constexpr bool is_arithmetic_type(type t) { + return t > type::none_type && t <= type::last_numeric_type; +} + +constexpr auto set(type rhs) -> int { return 1 << static_cast(rhs); } +constexpr auto in(type t, int set) -> bool { + return ((set >> static_cast(t)) & 1) != 0; +} + +// Bitsets of types. +enum { + sint_set = + set(type::int_type) | set(type::long_long_type) | set(type::int128_type), + uint_set = set(type::uint_type) | set(type::ulong_long_type) | + set(type::uint128_type), + bool_set = set(type::bool_type), + char_set = set(type::char_type), + float_set = set(type::float_type) | set(type::double_type) | + set(type::long_double_type), + string_set = set(type::string_type), + cstring_set = set(type::cstring_type), + pointer_set = set(type::pointer_type) +}; + +FMT_NORETURN FMT_API void throw_format_error(const char* message); + +struct error_handler { + constexpr error_handler() = default; + + // This function is intentionally not constexpr to give a compile-time error. + FMT_NORETURN void on_error(const char* message) { + throw_format_error(message); + } +}; +} // namespace detail + +/** Throws ``format_error`` with a given message. */ +using detail::throw_format_error; + +/** String's character type. */ +template using char_t = typename detail::char_t_impl::type; + +/** + \rst + Parsing context consisting of a format string range being parsed and an + argument counter for automatic indexing. + You can use the ``format_parse_context`` type alias for ``char`` instead. + \endrst + */ +FMT_EXPORT +template class basic_format_parse_context { + private: + basic_string_view format_str_; + int next_arg_id_; + + FMT_CONSTEXPR void do_check_arg_id(int id); + + public: + using char_type = Char; + using iterator = const Char*; + + explicit constexpr basic_format_parse_context( + basic_string_view format_str, int next_arg_id = 0) + : format_str_(format_str), next_arg_id_(next_arg_id) {} + + /** + Returns an iterator to the beginning of the format string range being + parsed. + */ + constexpr auto begin() const noexcept -> iterator { + return format_str_.begin(); + } + + /** + Returns an iterator past the end of the format string range being parsed. + */ + constexpr auto end() const noexcept -> iterator { return format_str_.end(); } + + /** Advances the begin iterator to ``it``. */ + FMT_CONSTEXPR void advance_to(iterator it) { + format_str_.remove_prefix(detail::to_unsigned(it - begin())); + } + + /** + Reports an error if using the manual argument indexing; otherwise returns + the next argument index and switches to the automatic indexing. 
+ */ + FMT_CONSTEXPR auto next_arg_id() -> int { + if (next_arg_id_ < 0) { + detail::throw_format_error( + "cannot switch from manual to automatic argument indexing"); + return 0; + } + int id = next_arg_id_++; + do_check_arg_id(id); + return id; + } + + /** + Reports an error if using the automatic argument indexing; otherwise + switches to the manual indexing. + */ + FMT_CONSTEXPR void check_arg_id(int id) { + if (next_arg_id_ > 0) { + detail::throw_format_error( + "cannot switch from automatic to manual argument indexing"); + return; + } + next_arg_id_ = -1; + do_check_arg_id(id); + } + FMT_CONSTEXPR void check_arg_id(basic_string_view) {} + FMT_CONSTEXPR void check_dynamic_spec(int arg_id); +}; + +FMT_EXPORT +using format_parse_context = basic_format_parse_context; + +namespace detail { +// A parse context with extra data used only in compile-time checks. +template +class compile_parse_context : public basic_format_parse_context { + private: + int num_args_; + const type* types_; + using base = basic_format_parse_context; + + public: + explicit FMT_CONSTEXPR compile_parse_context( + basic_string_view format_str, int num_args, const type* types, + int next_arg_id = 0) + : base(format_str, next_arg_id), num_args_(num_args), types_(types) {} + + constexpr auto num_args() const -> int { return num_args_; } + constexpr auto arg_type(int id) const -> type { return types_[id]; } + + FMT_CONSTEXPR auto next_arg_id() -> int { + int id = base::next_arg_id(); + if (id >= num_args_) throw_format_error("argument not found"); + return id; + } + + FMT_CONSTEXPR void check_arg_id(int id) { + base::check_arg_id(id); + if (id >= num_args_) throw_format_error("argument not found"); + } + using base::check_arg_id; + + FMT_CONSTEXPR void check_dynamic_spec(int arg_id) { + detail::ignore_unused(arg_id); +#if !defined(__LCC__) + if (arg_id < num_args_ && types_ && !is_integral_type(types_[arg_id])) + throw_format_error("width/precision is not integer"); +#endif + } +}; + +// Extracts a reference to the container from back_insert_iterator. +template +inline auto get_container(std::back_insert_iterator it) + -> Container& { + using base = std::back_insert_iterator; + struct accessor : base { + accessor(base b) : base(b) {} + using base::container; + }; + return *accessor(it).container; +} + +template +FMT_CONSTEXPR auto copy_str(InputIt begin, InputIt end, OutputIt out) + -> OutputIt { + while (begin != end) *out++ = static_cast(*begin++); + return out; +} + +template , U>::value&& is_char::value)> +FMT_CONSTEXPR auto copy_str(T* begin, T* end, U* out) -> U* { + if (is_constant_evaluated()) return copy_str(begin, end, out); + auto size = to_unsigned(end - begin); + if (size > 0) memcpy(out, begin, size * sizeof(U)); + return out + size; +} + +/** + \rst + A contiguous memory buffer with an optional growing ability. It is an internal + class and shouldn't be used directly, only via `~fmt::basic_memory_buffer`. + \endrst + */ +template class buffer { + private: + T* ptr_; + size_t size_; + size_t capacity_; + + protected: + // Don't initialize ptr_ since it is not accessed to save a few cycles. + FMT_MSC_WARNING(suppress : 26495) + buffer(size_t sz) noexcept : size_(sz), capacity_(sz) {} + + FMT_CONSTEXPR20 buffer(T* p = nullptr, size_t sz = 0, size_t cap = 0) noexcept + : ptr_(p), size_(sz), capacity_(cap) {} + + FMT_CONSTEXPR20 ~buffer() = default; + buffer(buffer&&) = default; + + /** Sets the buffer data and capacity. 
*/ + FMT_CONSTEXPR void set(T* buf_data, size_t buf_capacity) noexcept { + ptr_ = buf_data; + capacity_ = buf_capacity; + } + + /** Increases the buffer capacity to hold at least *capacity* elements. */ + virtual FMT_CONSTEXPR20 void grow(size_t capacity) = 0; + + public: + using value_type = T; + using const_reference = const T&; + + buffer(const buffer&) = delete; + void operator=(const buffer&) = delete; + + FMT_INLINE auto begin() noexcept -> T* { return ptr_; } + FMT_INLINE auto end() noexcept -> T* { return ptr_ + size_; } + + FMT_INLINE auto begin() const noexcept -> const T* { return ptr_; } + FMT_INLINE auto end() const noexcept -> const T* { return ptr_ + size_; } + + /** Returns the size of this buffer. */ + constexpr auto size() const noexcept -> size_t { return size_; } + + /** Returns the capacity of this buffer. */ + constexpr auto capacity() const noexcept -> size_t { return capacity_; } + + /** Returns a pointer to the buffer data (not null-terminated). */ + FMT_CONSTEXPR auto data() noexcept -> T* { return ptr_; } + FMT_CONSTEXPR auto data() const noexcept -> const T* { return ptr_; } + + /** Clears this buffer. */ + void clear() { size_ = 0; } + + // Tries resizing the buffer to contain *count* elements. If T is a POD type + // the new elements may not be initialized. + FMT_CONSTEXPR20 void try_resize(size_t count) { + try_reserve(count); + size_ = count <= capacity_ ? count : capacity_; + } + + // Tries increasing the buffer capacity to *new_capacity*. It can increase the + // capacity by a smaller amount than requested but guarantees there is space + // for at least one additional element either by increasing the capacity or by + // flushing the buffer if it is full. + FMT_CONSTEXPR20 void try_reserve(size_t new_capacity) { + if (new_capacity > capacity_) grow(new_capacity); + } + + FMT_CONSTEXPR20 void push_back(const T& value) { + try_reserve(size_ + 1); + ptr_[size_++] = value; + } + + /** Appends data to the end of the buffer. */ + template void append(const U* begin, const U* end); + + template FMT_CONSTEXPR auto operator[](Idx index) -> T& { + return ptr_[index]; + } + template + FMT_CONSTEXPR auto operator[](Idx index) const -> const T& { + return ptr_[index]; + } +}; + +struct buffer_traits { + explicit buffer_traits(size_t) {} + auto count() const -> size_t { return 0; } + auto limit(size_t size) -> size_t { return size; } +}; + +class fixed_buffer_traits { + private: + size_t count_ = 0; + size_t limit_; + + public: + explicit fixed_buffer_traits(size_t limit) : limit_(limit) {} + auto count() const -> size_t { return count_; } + auto limit(size_t size) -> size_t { + size_t n = limit_ > count_ ? limit_ - count_ : 0; + count_ += size; + return size < n ? size : n; + } +}; + +// A buffer that writes to an output iterator when flushed. 
+template +class iterator_buffer final : public Traits, public buffer { + private: + OutputIt out_; + enum { buffer_size = 256 }; + T data_[buffer_size]; + + protected: + FMT_CONSTEXPR20 void grow(size_t) override { + if (this->size() == buffer_size) flush(); + } + + void flush() { + auto size = this->size(); + this->clear(); + out_ = copy_str(data_, data_ + this->limit(size), out_); + } + + public: + explicit iterator_buffer(OutputIt out, size_t n = buffer_size) + : Traits(n), buffer(data_, 0, buffer_size), out_(out) {} + iterator_buffer(iterator_buffer&& other) + : Traits(other), buffer(data_, 0, buffer_size), out_(other.out_) {} + ~iterator_buffer() { flush(); } + + auto out() -> OutputIt { + flush(); + return out_; + } + auto count() const -> size_t { return Traits::count() + this->size(); } +}; + +template +class iterator_buffer final + : public fixed_buffer_traits, + public buffer { + private: + T* out_; + enum { buffer_size = 256 }; + T data_[buffer_size]; + + protected: + FMT_CONSTEXPR20 void grow(size_t) override { + if (this->size() == this->capacity()) flush(); + } + + void flush() { + size_t n = this->limit(this->size()); + if (this->data() == out_) { + out_ += n; + this->set(data_, buffer_size); + } + this->clear(); + } + + public: + explicit iterator_buffer(T* out, size_t n = buffer_size) + : fixed_buffer_traits(n), buffer(out, 0, n), out_(out) {} + iterator_buffer(iterator_buffer&& other) + : fixed_buffer_traits(other), + buffer(std::move(other)), + out_(other.out_) { + if (this->data() != out_) { + this->set(data_, buffer_size); + this->clear(); + } + } + ~iterator_buffer() { flush(); } + + auto out() -> T* { + flush(); + return out_; + } + auto count() const -> size_t { + return fixed_buffer_traits::count() + this->size(); + } +}; + +template class iterator_buffer final : public buffer { + protected: + FMT_CONSTEXPR20 void grow(size_t) override {} + + public: + explicit iterator_buffer(T* out, size_t = 0) : buffer(out, 0, ~size_t()) {} + + auto out() -> T* { return &*this->end(); } +}; + +// A buffer that writes to a container with the contiguous storage. +template +class iterator_buffer, + enable_if_t::value, + typename Container::value_type>> + final : public buffer { + private: + Container& container_; + + protected: + FMT_CONSTEXPR20 void grow(size_t capacity) override { + container_.resize(capacity); + this->set(&container_[0], capacity); + } + + public: + explicit iterator_buffer(Container& c) + : buffer(c.size()), container_(c) {} + explicit iterator_buffer(std::back_insert_iterator out, size_t = 0) + : iterator_buffer(get_container(out)) {} + + auto out() -> std::back_insert_iterator { + return std::back_inserter(container_); + } +}; + +// A buffer that counts the number of code units written discarding the output. +template class counting_buffer final : public buffer { + private: + enum { buffer_size = 256 }; + T data_[buffer_size]; + size_t count_ = 0; + + protected: + FMT_CONSTEXPR20 void grow(size_t) override { + if (this->size() != buffer_size) return; + count_ += this->size(); + this->clear(); + } + + public: + counting_buffer() : buffer(data_, 0, buffer_size) {} + + auto count() -> size_t { return count_ + this->size(); } +}; +} // namespace detail + +template +FMT_CONSTEXPR void basic_format_parse_context::do_check_arg_id(int id) { + // Argument id is only checked at compile-time during parsing because + // formatting has its own validation. 
+ if (detail::is_constant_evaluated() && + (!FMT_GCC_VERSION || FMT_GCC_VERSION >= 1200)) { + using context = detail::compile_parse_context; + if (id >= static_cast(this)->num_args()) + detail::throw_format_error("argument not found"); + } +} + +template +FMT_CONSTEXPR void basic_format_parse_context::check_dynamic_spec( + int arg_id) { + if (detail::is_constant_evaluated() && + (!FMT_GCC_VERSION || FMT_GCC_VERSION >= 1200)) { + using context = detail::compile_parse_context; + static_cast(this)->check_dynamic_spec(arg_id); + } +} + +FMT_EXPORT template class basic_format_arg; +FMT_EXPORT template class basic_format_args; +FMT_EXPORT template class dynamic_format_arg_store; + +// A formatter for objects of type T. +FMT_EXPORT +template +struct formatter { + // A deleted default constructor indicates a disabled formatter. + formatter() = delete; +}; + +// Specifies if T has an enabled formatter specialization. A type can be +// formattable even if it doesn't have a formatter e.g. via a conversion. +template +using has_formatter = + std::is_constructible>; + +// An output iterator that appends to a buffer. +// It is used to reduce symbol sizes for the common case. +class appender : public std::back_insert_iterator> { + using base = std::back_insert_iterator>; + + public: + using std::back_insert_iterator>::back_insert_iterator; + appender(base it) noexcept : base(it) {} + FMT_UNCHECKED_ITERATOR(appender); + + auto operator++() noexcept -> appender& { return *this; } + auto operator++(int) noexcept -> appender { return *this; } +}; + +namespace detail { + +template +constexpr auto has_const_formatter_impl(T*) + -> decltype(typename Context::template formatter_type().format( + std::declval(), std::declval()), + true) { + return true; +} +template +constexpr auto has_const_formatter_impl(...) -> bool { + return false; +} +template +constexpr auto has_const_formatter() -> bool { + return has_const_formatter_impl(static_cast(nullptr)); +} + +template +using buffer_appender = conditional_t::value, appender, + std::back_insert_iterator>>; + +// Maps an output iterator to a buffer. +template +auto get_buffer(OutputIt out) -> iterator_buffer { + return iterator_buffer(out); +} +template , Buf>::value)> +auto get_buffer(std::back_insert_iterator out) -> buffer& { + return get_container(out); +} + +template +FMT_INLINE auto get_iterator(Buf& buf, OutputIt) -> decltype(buf.out()) { + return buf.out(); +} +template +auto get_iterator(buffer&, OutputIt out) -> OutputIt { + return out; +} + +struct view {}; + +template struct named_arg : view { + const Char* name; + const T& value; + named_arg(const Char* n, const T& v) : name(n), value(v) {} +}; + +template struct named_arg_info { + const Char* name; + int id; +}; + +template +struct arg_data { + // args_[0].named_args points to named_args_ to avoid bloating format_args. + // +1 to workaround a bug in gcc 7.5 that causes duplicated-branches warning. + T args_[1 + (NUM_ARGS != 0 ? NUM_ARGS : +1)]; + named_arg_info named_args_[NUM_NAMED_ARGS]; + + template + arg_data(const U&... init) : args_{T(named_args_, NUM_NAMED_ARGS), init...} {} + arg_data(const arg_data& other) = delete; + auto args() const -> const T* { return args_ + 1; } + auto named_args() -> named_arg_info* { return named_args_; } +}; + +template +struct arg_data { + // +1 to workaround a bug in gcc 7.5 that causes duplicated-branches warning. + T args_[NUM_ARGS != 0 ? NUM_ARGS : +1]; + + template + FMT_CONSTEXPR FMT_INLINE arg_data(const U&... 
init) : args_{init...} {} + FMT_CONSTEXPR FMT_INLINE auto args() const -> const T* { return args_; } + FMT_CONSTEXPR FMT_INLINE auto named_args() -> std::nullptr_t { + return nullptr; + } +}; + +template +inline void init_named_args(named_arg_info*, int, int) {} + +template struct is_named_arg : std::false_type {}; +template struct is_statically_named_arg : std::false_type {}; + +template +struct is_named_arg> : std::true_type {}; + +template ::value)> +void init_named_args(named_arg_info* named_args, int arg_count, + int named_arg_count, const T&, const Tail&... args) { + init_named_args(named_args, arg_count + 1, named_arg_count, args...); +} + +template ::value)> +void init_named_args(named_arg_info* named_args, int arg_count, + int named_arg_count, const T& arg, const Tail&... args) { + named_args[named_arg_count++] = {arg.name, arg_count}; + init_named_args(named_args, arg_count + 1, named_arg_count, args...); +} + +template +FMT_CONSTEXPR FMT_INLINE void init_named_args(std::nullptr_t, int, int, + const Args&...) {} + +template constexpr auto count() -> size_t { return B ? 1 : 0; } +template constexpr auto count() -> size_t { + return (B1 ? 1 : 0) + count(); +} + +template constexpr auto count_named_args() -> size_t { + return count::value...>(); +} + +template +constexpr auto count_statically_named_args() -> size_t { + return count::value...>(); +} + +struct unformattable {}; +struct unformattable_char : unformattable {}; +struct unformattable_pointer : unformattable {}; + +template struct string_value { + const Char* data; + size_t size; +}; + +template struct named_arg_value { + const named_arg_info* data; + size_t size; +}; + +template struct custom_value { + using parse_context = typename Context::parse_context_type; + void* value; + void (*format)(void* arg, parse_context& parse_ctx, Context& ctx); +}; + +// A formatting argument value. 
+template class value { + public: + using char_type = typename Context::char_type; + + union { + monostate no_value; + int int_value; + unsigned uint_value; + long long long_long_value; + unsigned long long ulong_long_value; + int128_opt int128_value; + uint128_opt uint128_value; + bool bool_value; + char_type char_value; + float float_value; + double double_value; + long double long_double_value; + const void* pointer; + string_value string; + custom_value custom; + named_arg_value named_args; + }; + + constexpr FMT_INLINE value() : no_value() {} + constexpr FMT_INLINE value(int val) : int_value(val) {} + constexpr FMT_INLINE value(unsigned val) : uint_value(val) {} + constexpr FMT_INLINE value(long long val) : long_long_value(val) {} + constexpr FMT_INLINE value(unsigned long long val) : ulong_long_value(val) {} + FMT_INLINE value(int128_opt val) : int128_value(val) {} + FMT_INLINE value(uint128_opt val) : uint128_value(val) {} + constexpr FMT_INLINE value(float val) : float_value(val) {} + constexpr FMT_INLINE value(double val) : double_value(val) {} + FMT_INLINE value(long double val) : long_double_value(val) {} + constexpr FMT_INLINE value(bool val) : bool_value(val) {} + constexpr FMT_INLINE value(char_type val) : char_value(val) {} + FMT_CONSTEXPR FMT_INLINE value(const char_type* val) { + string.data = val; + if (is_constant_evaluated()) string.size = {}; + } + FMT_CONSTEXPR FMT_INLINE value(basic_string_view val) { + string.data = val.data(); + string.size = val.size(); + } + FMT_INLINE value(const void* val) : pointer(val) {} + FMT_INLINE value(const named_arg_info* args, size_t size) + : named_args{args, size} {} + + template FMT_CONSTEXPR20 FMT_INLINE value(T& val) { + using value_type = remove_const_t; + custom.value = const_cast(std::addressof(val)); + // Get the formatter type through the context to allow different contexts + // have different extension points, e.g. `formatter` for `format` and + // `printf_formatter` for `printf`. + custom.format = format_custom_arg< + value_type, typename Context::template formatter_type>; + } + value(unformattable); + value(unformattable_char); + value(unformattable_pointer); + + private: + // Formats an argument of a custom type, such as a user-defined class. + template + static void format_custom_arg(void* arg, + typename Context::parse_context_type& parse_ctx, + Context& ctx) { + auto f = Formatter(); + parse_ctx.advance_to(f.parse(parse_ctx)); + using qualified_type = + conditional_t(), const T, T>; + ctx.advance_to(f.format(*static_cast(arg), ctx)); + } +}; + +// To minimize the number of types we need to deal with, long is translated +// either to int or to long long depending on its size. +enum { long_short = sizeof(long) == sizeof(int) }; +using long_type = conditional_t; +using ulong_type = conditional_t; + +template struct format_as_result { + template ::value || std::is_class::value)> + static auto map(U*) -> decltype(format_as(std::declval())); + static auto map(...) -> void; + + using type = decltype(map(static_cast(nullptr))); +}; +template using format_as_t = typename format_as_result::type; + +template +struct has_format_as + : bool_constant, void>::value> {}; + +// Maps formatting arguments to core types. +// arg_mapper reports errors by returning unformattable instead of using +// static_assert because it's used in the is_formattable trait. 
+template struct arg_mapper { + using char_type = typename Context::char_type; + + FMT_CONSTEXPR FMT_INLINE auto map(signed char val) -> int { return val; } + FMT_CONSTEXPR FMT_INLINE auto map(unsigned char val) -> unsigned { + return val; + } + FMT_CONSTEXPR FMT_INLINE auto map(short val) -> int { return val; } + FMT_CONSTEXPR FMT_INLINE auto map(unsigned short val) -> unsigned { + return val; + } + FMT_CONSTEXPR FMT_INLINE auto map(int val) -> int { return val; } + FMT_CONSTEXPR FMT_INLINE auto map(unsigned val) -> unsigned { return val; } + FMT_CONSTEXPR FMT_INLINE auto map(long val) -> long_type { return val; } + FMT_CONSTEXPR FMT_INLINE auto map(unsigned long val) -> ulong_type { + return val; + } + FMT_CONSTEXPR FMT_INLINE auto map(long long val) -> long long { return val; } + FMT_CONSTEXPR FMT_INLINE auto map(unsigned long long val) + -> unsigned long long { + return val; + } + FMT_CONSTEXPR FMT_INLINE auto map(int128_opt val) -> int128_opt { + return val; + } + FMT_CONSTEXPR FMT_INLINE auto map(uint128_opt val) -> uint128_opt { + return val; + } + FMT_CONSTEXPR FMT_INLINE auto map(bool val) -> bool { return val; } + + template ::value || + std::is_same::value)> + FMT_CONSTEXPR FMT_INLINE auto map(T val) -> char_type { + return val; + } + template ::value || +#ifdef __cpp_char8_t + std::is_same::value || +#endif + std::is_same::value || + std::is_same::value) && + !std::is_same::value, + int> = 0> + FMT_CONSTEXPR FMT_INLINE auto map(T) -> unformattable_char { + return {}; + } + + FMT_CONSTEXPR FMT_INLINE auto map(float val) -> float { return val; } + FMT_CONSTEXPR FMT_INLINE auto map(double val) -> double { return val; } + FMT_CONSTEXPR FMT_INLINE auto map(long double val) -> long double { + return val; + } + + FMT_CONSTEXPR FMT_INLINE auto map(char_type* val) -> const char_type* { + return val; + } + FMT_CONSTEXPR FMT_INLINE auto map(const char_type* val) -> const char_type* { + return val; + } + template ::value && !std::is_pointer::value && + std::is_same>::value)> + FMT_CONSTEXPR FMT_INLINE auto map(const T& val) + -> basic_string_view { + return to_string_view(val); + } + template ::value && !std::is_pointer::value && + !std::is_same>::value)> + FMT_CONSTEXPR FMT_INLINE auto map(const T&) -> unformattable_char { + return {}; + } + + FMT_CONSTEXPR FMT_INLINE auto map(void* val) -> const void* { return val; } + FMT_CONSTEXPR FMT_INLINE auto map(const void* val) -> const void* { + return val; + } + FMT_CONSTEXPR FMT_INLINE auto map(std::nullptr_t val) -> const void* { + return val; + } + + // Use SFINAE instead of a const T* parameter to avoid a conflict with the + // array overload. + template < + typename T, + FMT_ENABLE_IF( + std::is_pointer::value || std::is_member_pointer::value || + std::is_function::type>::value || + (std::is_array::value && + !std::is_convertible::value))> + FMT_CONSTEXPR auto map(const T&) -> unformattable_pointer { + return {}; + } + + template ::value)> + FMT_CONSTEXPR FMT_INLINE auto map(const T (&values)[N]) -> const T (&)[N] { + return values; + } + + // Only map owning types because mapping views can be unsafe. 
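+  // For example, if format_as returned a view into a temporary, the pointer
+  // stored in the argument could dangle by format time; restricting the
+  // mapped result to arithmetic types below sidesteps that. Illustrative
+  // usage of the format_as extension point:
+  //
+  //   enum class mode { fast };
+  //   auto format_as(mode m) -> int { return static_cast<int>(m); }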
+ template , + FMT_ENABLE_IF(std::is_arithmetic::value)> + FMT_CONSTEXPR FMT_INLINE auto map(const T& val) -> decltype(this->map(U())) { + return map(format_as(val)); + } + + template > + struct formattable : bool_constant() || + (has_formatter::value && + !std::is_const::value)> {}; + + template ::value)> + FMT_CONSTEXPR FMT_INLINE auto do_map(T& val) -> T& { + return val; + } + template ::value)> + FMT_CONSTEXPR FMT_INLINE auto do_map(T&) -> unformattable { + return {}; + } + + template , + FMT_ENABLE_IF((std::is_class::value || std::is_enum::value || + std::is_union::value) && + !is_string::value && !is_char::value && + !is_named_arg::value && + !std::is_arithmetic>::value)> + FMT_CONSTEXPR FMT_INLINE auto map(T& val) -> decltype(this->do_map(val)) { + return do_map(val); + } + + template ::value)> + FMT_CONSTEXPR FMT_INLINE auto map(const T& named_arg) + -> decltype(this->map(named_arg.value)) { + return map(named_arg.value); + } + + auto map(...) -> unformattable { return {}; } +}; + +// A type constant after applying arg_mapper. +template +using mapped_type_constant = + type_constant().map(std::declval())), + typename Context::char_type>; + +enum { packed_arg_bits = 4 }; +// Maximum number of arguments with packed types. +enum { max_packed_args = 62 / packed_arg_bits }; +enum : unsigned long long { is_unpacked_bit = 1ULL << 63 }; +enum : unsigned long long { has_named_args_bit = 1ULL << 62 }; + +template +auto copy_str(InputIt begin, InputIt end, appender out) -> appender { + get_container(out).append(begin, end); + return out; +} +template +auto copy_str(InputIt begin, InputIt end, + std::back_insert_iterator out) + -> std::back_insert_iterator { + get_container(out).append(begin, end); + return out; +} + +template +FMT_CONSTEXPR auto copy_str(R&& rng, OutputIt out) -> OutputIt { + return detail::copy_str(rng.begin(), rng.end(), out); +} + +#if FMT_GCC_VERSION && FMT_GCC_VERSION < 500 +// A workaround for gcc 4.8 to make void_t work in a SFINAE context. +template struct void_t_impl { using type = void; }; +template using void_t = typename void_t_impl::type; +#else +template using void_t = void; +#endif + +template +struct is_output_iterator : std::false_type {}; + +template +struct is_output_iterator< + It, T, + void_t::iterator_category, + decltype(*std::declval() = std::declval())>> + : std::true_type {}; + +template struct is_back_insert_iterator : std::false_type {}; +template +struct is_back_insert_iterator> + : std::true_type {}; + +// A type-erased reference to an std::locale to avoid a heavy include. +class locale_ref { + private: + const void* locale_; // A type-erased pointer to std::locale. 
+ + public: + constexpr FMT_INLINE locale_ref() : locale_(nullptr) {} + template explicit locale_ref(const Locale& loc); + + explicit operator bool() const noexcept { return locale_ != nullptr; } + + template auto get() const -> Locale; +}; + +template constexpr auto encode_types() -> unsigned long long { + return 0; +} + +template +constexpr auto encode_types() -> unsigned long long { + return static_cast(mapped_type_constant::value) | + (encode_types() << packed_arg_bits); +} + +#if defined(__cpp_if_constexpr) +// This type is intentionally undefined, only used for errors +template struct type_is_unformattable_for; +#endif + +template +FMT_CONSTEXPR FMT_INLINE auto make_arg(T& val) -> value { + using arg_type = remove_cvref_t().map(val))>; + + constexpr bool formattable_char = + !std::is_same::value; + static_assert(formattable_char, "Mixing character types is disallowed."); + + // Formatting of arbitrary pointers is disallowed. If you want to format a + // pointer cast it to `void*` or `const void*`. In particular, this forbids + // formatting of `[const] volatile char*` printed as bool by iostreams. + constexpr bool formattable_pointer = + !std::is_same::value; + static_assert(formattable_pointer, + "Formatting of non-void pointers is disallowed."); + + constexpr bool formattable = !std::is_same::value; +#if defined(__cpp_if_constexpr) + if constexpr (!formattable) { + type_is_unformattable_for _; + } +#endif + static_assert( + formattable, + "Cannot format an argument. To make type T formattable provide a " + "formatter specialization: https://fmt.dev/latest/api.html#udt"); + return {arg_mapper().map(val)}; +} + +template +FMT_CONSTEXPR auto make_arg(T& val) -> basic_format_arg { + auto arg = basic_format_arg(); + arg.type_ = mapped_type_constant::value; + arg.value_ = make_arg(val); + return arg; +} + +template +FMT_CONSTEXPR inline auto make_arg(T& val) -> basic_format_arg { + return make_arg(val); +} +} // namespace detail +FMT_BEGIN_EXPORT + +// A formatting argument. It is a trivially copyable/constructible type to +// allow storage in basic_memory_buffer. +template class basic_format_arg { + private: + detail::value value_; + detail::type type_; + + template + friend FMT_CONSTEXPR auto detail::make_arg(T& value) + -> basic_format_arg; + + template + friend FMT_CONSTEXPR auto visit_format_arg(Visitor&& vis, + const basic_format_arg& arg) + -> decltype(vis(0)); + + friend class basic_format_args; + friend class dynamic_format_arg_store; + + using char_type = typename Context::char_type; + + template + friend struct detail::arg_data; + + basic_format_arg(const detail::named_arg_info* args, size_t size) + : value_(args, size) {} + + public: + class handle { + public: + explicit handle(detail::custom_value custom) : custom_(custom) {} + + void format(typename Context::parse_context_type& parse_ctx, + Context& ctx) const { + custom_.format(custom_.value, parse_ctx, ctx); + } + + private: + detail::custom_value custom_; + }; + + constexpr basic_format_arg() : type_(detail::type::none_type) {} + + constexpr explicit operator bool() const noexcept { + return type_ != detail::type::none_type; + } + + auto type() const -> detail::type { return type_; } + + auto is_integral() const -> bool { return detail::is_integral_type(type_); } + auto is_arithmetic() const -> bool { + return detail::is_arithmetic_type(type_); + } +}; + +/** + \rst + Visits an argument dispatching to the appropriate visit method based on + the argument type. 
For example, if the argument type is ``double`` then + ``vis(value)`` will be called with the value of type ``double``. + \endrst + */ +// DEPRECATED! +template +FMT_CONSTEXPR FMT_INLINE auto visit_format_arg( + Visitor&& vis, const basic_format_arg& arg) -> decltype(vis(0)) { + switch (arg.type_) { + case detail::type::none_type: + break; + case detail::type::int_type: + return vis(arg.value_.int_value); + case detail::type::uint_type: + return vis(arg.value_.uint_value); + case detail::type::long_long_type: + return vis(arg.value_.long_long_value); + case detail::type::ulong_long_type: + return vis(arg.value_.ulong_long_value); + case detail::type::int128_type: + return vis(detail::convert_for_visit(arg.value_.int128_value)); + case detail::type::uint128_type: + return vis(detail::convert_for_visit(arg.value_.uint128_value)); + case detail::type::bool_type: + return vis(arg.value_.bool_value); + case detail::type::char_type: + return vis(arg.value_.char_value); + case detail::type::float_type: + return vis(arg.value_.float_value); + case detail::type::double_type: + return vis(arg.value_.double_value); + case detail::type::long_double_type: + return vis(arg.value_.long_double_value); + case detail::type::cstring_type: + return vis(arg.value_.string.data); + case detail::type::string_type: + using sv = basic_string_view; + return vis(sv(arg.value_.string.data, arg.value_.string.size)); + case detail::type::pointer_type: + return vis(arg.value_.pointer); + case detail::type::custom_type: + return vis(typename basic_format_arg::handle(arg.value_.custom)); + } + return vis(monostate()); +} + +// Formatting context. +template class basic_format_context { + private: + OutputIt out_; + basic_format_args args_; + detail::locale_ref loc_; + + public: + using iterator = OutputIt; + using format_arg = basic_format_arg; + using format_args = basic_format_args; + using parse_context_type = basic_format_parse_context; + template using formatter_type = formatter; + + /** The character type for the output. */ + using char_type = Char; + + basic_format_context(basic_format_context&&) = default; + basic_format_context(const basic_format_context&) = delete; + void operator=(const basic_format_context&) = delete; + /** + Constructs a ``basic_format_context`` object. References to the arguments + are stored in the object so make sure they have appropriate lifetimes. + */ + constexpr basic_format_context(OutputIt out, format_args ctx_args, + detail::locale_ref loc = {}) + : out_(out), args_(ctx_args), loc_(loc) {} + + constexpr auto arg(int id) const -> format_arg { return args_.get(id); } + FMT_CONSTEXPR auto arg(basic_string_view name) -> format_arg { + return args_.get(name); + } + FMT_CONSTEXPR auto arg_id(basic_string_view name) -> int { + return args_.get_id(name); + } + auto args() const -> const format_args& { return args_; } + + FMT_CONSTEXPR auto error_handler() -> detail::error_handler { return {}; } + void on_error(const char* message) { error_handler().on_error(message); } + + // Returns an iterator to the beginning of the output range. + FMT_CONSTEXPR auto out() -> iterator { return out_; } + + // Advances the begin iterator to ``it``. 
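+  // (A back-insert iterator is stateless: appending already advanced the
+  // underlying container, so the assignment below is skipped for it.)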
+ void advance_to(iterator it) { + if (!detail::is_back_insert_iterator()) out_ = it; + } + + FMT_CONSTEXPR auto locale() -> detail::locale_ref { return loc_; } +}; + +template +using buffer_context = + basic_format_context, Char>; +using format_context = buffer_context; + +template +using is_formattable = bool_constant>() + .map(std::declval()))>::value>; + +/** + \rst + An array of references to arguments. It can be implicitly converted into + `~fmt::basic_format_args` for passing into type-erased formatting functions + such as `~fmt::vformat`. + \endrst + */ +template +class format_arg_store +#if FMT_GCC_VERSION && FMT_GCC_VERSION < 409 + // Workaround a GCC template argument substitution bug. + : public basic_format_args +#endif +{ + private: + static const size_t num_args = sizeof...(Args); + static constexpr size_t num_named_args = detail::count_named_args(); + static const bool is_packed = num_args <= detail::max_packed_args; + + using value_type = conditional_t, + basic_format_arg>; + + detail::arg_data + data_; + + friend class basic_format_args; + + static constexpr unsigned long long desc = + (is_packed ? detail::encode_types() + : detail::is_unpacked_bit | num_args) | + (num_named_args != 0 + ? static_cast(detail::has_named_args_bit) + : 0); + + public: + template + FMT_CONSTEXPR FMT_INLINE format_arg_store(T&... args) + : +#if FMT_GCC_VERSION && FMT_GCC_VERSION < 409 + basic_format_args(*this), +#endif + data_{detail::make_arg(args)...} { + if (detail::const_check(num_named_args != 0)) + detail::init_named_args(data_.named_args(), 0, 0, args...); + } +}; + +/** + \rst + Constructs a `~fmt::format_arg_store` object that contains references to + arguments and can be implicitly converted to `~fmt::format_args`. `Context` + can be omitted in which case it defaults to `~fmt::format_context`. + See `~fmt::arg` for lifetime considerations. + \endrst + */ +// Arguments are taken by lvalue references to avoid some lifetime issues. +template +constexpr auto make_format_args(T&... args) + -> format_arg_store...> { + return {args...}; +} + +/** + \rst + Returns a named argument to be used in a formatting function. + It should only be used in a call to a formatting function or + `dynamic_format_arg_store::push_back`. + + **Example**:: + + fmt::print("Elapsed time: {s:.2f} seconds", fmt::arg("s", 1.23)); + \endrst + */ +template +inline auto arg(const Char* name, const T& arg) -> detail::named_arg { + static_assert(!detail::is_named_arg(), "nested named arguments"); + return {name, arg}; +} +FMT_END_EXPORT + +/** + \rst + A view of a collection of formatting arguments. To avoid lifetime issues it + should only be used as a parameter type in type-erased functions such as + ``vformat``:: + + void vlog(string_view format_str, format_args args); // OK + format_args args = make_format_args(); // Error: dangling reference + \endrst + */ +template class basic_format_args { + public: + using size_type = int; + using format_arg = basic_format_arg; + + private: + // A descriptor that contains information about formatting arguments. + // If the number of arguments is less or equal to max_packed_args then + // argument types are passed in the descriptor. This reduces binary code size + // per formatting function call. + unsigned long long desc_; + union { + // If is_packed() returns true then argument values are stored in values_; + // otherwise they are stored in args_. 
This is done to improve cache + // locality and reduce compiled code size since storing larger objects + // may require more code (at least on x86-64) even if the same amount of + // data is actually copied to stack. It saves ~10% on the bloat test. + const detail::value* values_; + const format_arg* args_; + }; + + constexpr auto is_packed() const -> bool { + return (desc_ & detail::is_unpacked_bit) == 0; + } + auto has_named_args() const -> bool { + return (desc_ & detail::has_named_args_bit) != 0; + } + + FMT_CONSTEXPR auto type(int index) const -> detail::type { + int shift = index * detail::packed_arg_bits; + unsigned int mask = (1 << detail::packed_arg_bits) - 1; + return static_cast((desc_ >> shift) & mask); + } + + constexpr FMT_INLINE basic_format_args(unsigned long long desc, + const detail::value* values) + : desc_(desc), values_(values) {} + constexpr basic_format_args(unsigned long long desc, const format_arg* args) + : desc_(desc), args_(args) {} + + public: + constexpr basic_format_args() : desc_(0), args_(nullptr) {} + + /** + \rst + Constructs a `basic_format_args` object from `~fmt::format_arg_store`. + \endrst + */ + template + constexpr FMT_INLINE basic_format_args( + const format_arg_store& store) + : basic_format_args(format_arg_store::desc, + store.data_.args()) {} + + /** + \rst + Constructs a `basic_format_args` object from + `~fmt::dynamic_format_arg_store`. + \endrst + */ + constexpr FMT_INLINE basic_format_args( + const dynamic_format_arg_store& store) + : basic_format_args(store.get_types(), store.data()) {} + + /** + \rst + Constructs a `basic_format_args` object from a dynamic set of arguments. + \endrst + */ + constexpr basic_format_args(const format_arg* args, int count) + : basic_format_args(detail::is_unpacked_bit | detail::to_unsigned(count), + args) {} + + /** Returns the argument with the specified id. */ + FMT_CONSTEXPR auto get(int id) const -> format_arg { + format_arg arg; + if (!is_packed()) { + if (id < max_size()) arg = args_[id]; + return arg; + } + if (id >= detail::max_packed_args) return arg; + arg.type_ = type(id); + if (arg.type_ == detail::type::none_type) return arg; + arg.value_ = values_[id]; + return arg; + } + + template + auto get(basic_string_view name) const -> format_arg { + int id = get_id(name); + return id >= 0 ? get(id) : format_arg(); + } + + template + auto get_id(basic_string_view name) const -> int { + if (!has_named_args()) return -1; + const auto& named_args = + (is_packed() ? values_[-1] : args_[-1].value_).named_args; + for (size_t i = 0; i < named_args.size; ++i) { + if (named_args.data[i].name == name) return named_args.data[i].id; + } + return -1; + } + + auto max_size() const -> int { + unsigned long long max_packed = detail::max_packed_args; + return static_cast(is_packed() ? max_packed + : desc_ & ~detail::is_unpacked_bit); + } +}; + +/** An alias to ``basic_format_args``. */ +// A separate type would result in shorter symbols but break ABI compatibility +// between clang and gcc on ARM (#1919). +FMT_EXPORT using format_args = basic_format_args; + +// We cannot use enum classes as bit fields because of a gcc bug, so we put them +// in namespaces instead (https://gcc.gnu.org/bugzilla/show_bug.cgi?id=61414). +// Additionally, if an underlying type is specified, older gcc incorrectly warns +// that the type is too small. Both bugs are fixed in gcc 9.3. 
+#if FMT_GCC_VERSION && FMT_GCC_VERSION < 903 +# define FMT_ENUM_UNDERLYING_TYPE(type) +#else +# define FMT_ENUM_UNDERLYING_TYPE(type) : type +#endif +namespace align { +enum type FMT_ENUM_UNDERLYING_TYPE(unsigned char){none, left, right, center, + numeric}; +} +using align_t = align::type; +namespace sign { +enum type FMT_ENUM_UNDERLYING_TYPE(unsigned char){none, minus, plus, space}; +} +using sign_t = sign::type; + +namespace detail { + +// Workaround an array initialization issue in gcc 4.8. +template struct fill_t { + private: + enum { max_size = 4 }; + Char data_[max_size] = {Char(' '), Char(0), Char(0), Char(0)}; + unsigned char size_ = 1; + + public: + FMT_CONSTEXPR void operator=(basic_string_view s) { + auto size = s.size(); + FMT_ASSERT(size <= max_size, "invalid fill"); + for (size_t i = 0; i < size; ++i) data_[i] = s[i]; + size_ = static_cast(size); + } + + constexpr auto size() const -> size_t { return size_; } + constexpr auto data() const -> const Char* { return data_; } + + FMT_CONSTEXPR auto operator[](size_t index) -> Char& { return data_[index]; } + FMT_CONSTEXPR auto operator[](size_t index) const -> const Char& { + return data_[index]; + } +}; +} // namespace detail + +enum class presentation_type : unsigned char { + none, + dec, // 'd' + oct, // 'o' + hex_lower, // 'x' + hex_upper, // 'X' + bin_lower, // 'b' + bin_upper, // 'B' + hexfloat_lower, // 'a' + hexfloat_upper, // 'A' + exp_lower, // 'e' + exp_upper, // 'E' + fixed_lower, // 'f' + fixed_upper, // 'F' + general_lower, // 'g' + general_upper, // 'G' + chr, // 'c' + string, // 's' + pointer, // 'p' + debug // '?' +}; + +// Format specifiers for built-in and string types. +template struct format_specs { + int width; + int precision; + presentation_type type; + align_t align : 4; + sign_t sign : 3; + bool alt : 1; // Alternate form ('#'). + bool localized : 1; + detail::fill_t fill; + + constexpr format_specs() + : width(0), + precision(-1), + type(presentation_type::none), + align(align::none), + sign(sign::none), + alt(false), + localized(false) {} +}; + +namespace detail { + +enum class arg_id_kind { none, index, name }; + +// An argument reference. +template struct arg_ref { + FMT_CONSTEXPR arg_ref() : kind(arg_id_kind::none), val() {} + + FMT_CONSTEXPR explicit arg_ref(int index) + : kind(arg_id_kind::index), val(index) {} + FMT_CONSTEXPR explicit arg_ref(basic_string_view name) + : kind(arg_id_kind::name), val(name) {} + + FMT_CONSTEXPR auto operator=(int idx) -> arg_ref& { + kind = arg_id_kind::index; + val.index = idx; + return *this; + } + + arg_id_kind kind; + union value { + FMT_CONSTEXPR value(int idx = 0) : index(idx) {} + FMT_CONSTEXPR value(basic_string_view n) : name(n) {} + + int index; + basic_string_view name; + } val; +}; + +// Format specifiers with width and precision resolved at formatting rather +// than parsing time to allow reusing the same parsed specifiers with +// different sets of arguments (precompilation of format strings). +template +struct dynamic_format_specs : format_specs { + arg_ref width_ref; + arg_ref precision_ref; +}; + +// Converts a character to ASCII. Returns '\0' on conversion failure. +template ::value)> +constexpr auto to_ascii(Char c) -> char { + return c <= 0xff ? static_cast(c) : '\0'; +} +template ::value)> +constexpr auto to_ascii(Char c) -> char { + return c <= 0xff ? static_cast(c) : '\0'; +} + +// Returns the number of code units in a code point or 1 on error. 
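+// The magic constant below is a packed 2-bit lookup table indexed by the top
+// five bits of the lead byte: len = ((table >> (2 * (c >> 3))) & 0x3) + 1.
+// Worked example (illustrative): for the 3-byte lead byte 0xe2, c >> 3 is 28,
+// the shift is 56, the table entry is 2 and the result is 3; ASCII lead bytes
+// index the all-zero low entries and yield 1.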
+template +FMT_CONSTEXPR auto code_point_length(const Char* begin) -> int { + if (const_check(sizeof(Char) != 1)) return 1; + auto c = static_cast(*begin); + return static_cast((0x3a55000000000000ull >> (2 * (c >> 3))) & 0x3) + 1; +} + +// Return the result via the out param to workaround gcc bug 77539. +template +FMT_CONSTEXPR auto find(Ptr first, Ptr last, T value, Ptr& out) -> bool { + for (out = first; out != last; ++out) { + if (*out == value) return true; + } + return false; +} + +template <> +inline auto find(const char* first, const char* last, char value, + const char*& out) -> bool { + out = static_cast( + std::memchr(first, value, to_unsigned(last - first))); + return out != nullptr; +} + +// Parses the range [begin, end) as an unsigned integer. This function assumes +// that the range is non-empty and the first character is a digit. +template +FMT_CONSTEXPR auto parse_nonnegative_int(const Char*& begin, const Char* end, + int error_value) noexcept -> int { + FMT_ASSERT(begin != end && '0' <= *begin && *begin <= '9', ""); + unsigned value = 0, prev = 0; + auto p = begin; + do { + prev = value; + value = value * 10 + unsigned(*p - '0'); + ++p; + } while (p != end && '0' <= *p && *p <= '9'); + auto num_digits = p - begin; + begin = p; + if (num_digits <= std::numeric_limits::digits10) + return static_cast(value); + // Check for overflow. + const unsigned max = to_unsigned((std::numeric_limits::max)()); + return num_digits == std::numeric_limits::digits10 + 1 && + prev * 10ull + unsigned(p[-1] - '0') <= max + ? static_cast(value) + : error_value; +} + +FMT_CONSTEXPR inline auto parse_align(char c) -> align_t { + switch (c) { + case '<': + return align::left; + case '>': + return align::right; + case '^': + return align::center; + } + return align::none; +} + +template constexpr auto is_name_start(Char c) -> bool { + return ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z') || c == '_'; +} + +template +FMT_CONSTEXPR auto do_parse_arg_id(const Char* begin, const Char* end, + Handler&& handler) -> const Char* { + Char c = *begin; + if (c >= '0' && c <= '9') { + int index = 0; + constexpr int max = (std::numeric_limits::max)(); + if (c != '0') + index = parse_nonnegative_int(begin, end, max); + else + ++begin; + if (begin == end || (*begin != '}' && *begin != ':')) + throw_format_error("invalid format string"); + else + handler.on_index(index); + return begin; + } + if (!is_name_start(c)) { + throw_format_error("invalid format string"); + return begin; + } + auto it = begin; + do { + ++it; + } while (it != end && (is_name_start(*it) || ('0' <= *it && *it <= '9'))); + handler.on_name({begin, to_unsigned(it - begin)}); + return it; +} + +template +FMT_CONSTEXPR FMT_INLINE auto parse_arg_id(const Char* begin, const Char* end, + Handler&& handler) -> const Char* { + FMT_ASSERT(begin != end, ""); + Char c = *begin; + if (c != '}' && c != ':') return do_parse_arg_id(begin, end, handler); + handler.on_auto(); + return begin; +} + +template struct dynamic_spec_id_handler { + basic_format_parse_context& ctx; + arg_ref& ref; + + FMT_CONSTEXPR void on_auto() { + int id = ctx.next_arg_id(); + ref = arg_ref(id); + ctx.check_dynamic_spec(id); + } + FMT_CONSTEXPR void on_index(int id) { + ref = arg_ref(id); + ctx.check_arg_id(id); + ctx.check_dynamic_spec(id); + } + FMT_CONSTEXPR void on_name(basic_string_view id) { + ref = arg_ref(id); + ctx.check_arg_id(id); + } +}; + +// Parses [integer | "{" [arg_id] "}"]. 
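+// For example (illustrative), in "{:{}.{}f}" each inner "{}" is parsed here
+// and recorded as an arg_ref that is resolved against the argument list at
+// format time: fmt::format("{:{}.{}f}", 3.14, 10, 2) uses width 10 and
+// precision 2.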
+template +FMT_CONSTEXPR auto parse_dynamic_spec(const Char* begin, const Char* end, + int& value, arg_ref& ref, + basic_format_parse_context& ctx) + -> const Char* { + FMT_ASSERT(begin != end, ""); + if ('0' <= *begin && *begin <= '9') { + int val = parse_nonnegative_int(begin, end, -1); + if (val != -1) + value = val; + else + throw_format_error("number is too big"); + } else if (*begin == '{') { + ++begin; + auto handler = dynamic_spec_id_handler{ctx, ref}; + if (begin != end) begin = parse_arg_id(begin, end, handler); + if (begin != end && *begin == '}') return ++begin; + throw_format_error("invalid format string"); + } + return begin; +} + +template +FMT_CONSTEXPR auto parse_precision(const Char* begin, const Char* end, + int& value, arg_ref& ref, + basic_format_parse_context& ctx) + -> const Char* { + ++begin; + if (begin == end || *begin == '}') { + throw_format_error("invalid precision"); + return begin; + } + return parse_dynamic_spec(begin, end, value, ref, ctx); +} + +enum class state { start, align, sign, hash, zero, width, precision, locale }; + +// Parses standard format specifiers. +template +FMT_CONSTEXPR FMT_INLINE auto parse_format_specs( + const Char* begin, const Char* end, dynamic_format_specs& specs, + basic_format_parse_context& ctx, type arg_type) -> const Char* { + auto c = '\0'; + if (end - begin > 1) { + auto next = to_ascii(begin[1]); + c = parse_align(next) == align::none ? to_ascii(*begin) : '\0'; + } else { + if (begin == end) return begin; + c = to_ascii(*begin); + } + + struct { + state current_state = state::start; + FMT_CONSTEXPR void operator()(state s, bool valid = true) { + if (current_state >= s || !valid) + throw_format_error("invalid format specifier"); + current_state = s; + } + } enter_state; + + using pres = presentation_type; + constexpr auto integral_set = sint_set | uint_set | bool_set | char_set; + struct { + const Char*& begin; + dynamic_format_specs& specs; + type arg_type; + + FMT_CONSTEXPR auto operator()(pres type, int set) -> const Char* { + if (!in(arg_type, set)) throw_format_error("invalid format specifier"); + specs.type = type; + return begin + 1; + } + } parse_presentation_type{begin, specs, arg_type}; + + for (;;) { + switch (c) { + case '<': + case '>': + case '^': + enter_state(state::align); + specs.align = parse_align(c); + ++begin; + break; + case '+': + case '-': + case ' ': + enter_state(state::sign, in(arg_type, sint_set | float_set)); + switch (c) { + case '+': + specs.sign = sign::plus; + break; + case '-': + specs.sign = sign::minus; + break; + case ' ': + specs.sign = sign::space; + break; + } + ++begin; + break; + case '#': + enter_state(state::hash, is_arithmetic_type(arg_type)); + specs.alt = true; + ++begin; + break; + case '0': + enter_state(state::zero); + if (!is_arithmetic_type(arg_type)) + throw_format_error("format specifier requires numeric argument"); + if (specs.align == align::none) { + // Ignore 0 if align is specified for compatibility with std::format. 
+ specs.align = align::numeric; + specs.fill[0] = Char('0'); + } + ++begin; + break; + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + case '{': + enter_state(state::width); + begin = parse_dynamic_spec(begin, end, specs.width, specs.width_ref, ctx); + break; + case '.': + enter_state(state::precision, + in(arg_type, float_set | string_set | cstring_set)); + begin = parse_precision(begin, end, specs.precision, specs.precision_ref, + ctx); + break; + case 'L': + enter_state(state::locale, is_arithmetic_type(arg_type)); + specs.localized = true; + ++begin; + break; + case 'd': + return parse_presentation_type(pres::dec, integral_set); + case 'o': + return parse_presentation_type(pres::oct, integral_set); + case 'x': + return parse_presentation_type(pres::hex_lower, integral_set); + case 'X': + return parse_presentation_type(pres::hex_upper, integral_set); + case 'b': + return parse_presentation_type(pres::bin_lower, integral_set); + case 'B': + return parse_presentation_type(pres::bin_upper, integral_set); + case 'a': + return parse_presentation_type(pres::hexfloat_lower, float_set); + case 'A': + return parse_presentation_type(pres::hexfloat_upper, float_set); + case 'e': + return parse_presentation_type(pres::exp_lower, float_set); + case 'E': + return parse_presentation_type(pres::exp_upper, float_set); + case 'f': + return parse_presentation_type(pres::fixed_lower, float_set); + case 'F': + return parse_presentation_type(pres::fixed_upper, float_set); + case 'g': + return parse_presentation_type(pres::general_lower, float_set); + case 'G': + return parse_presentation_type(pres::general_upper, float_set); + case 'c': + return parse_presentation_type(pres::chr, integral_set); + case 's': + return parse_presentation_type(pres::string, + bool_set | string_set | cstring_set); + case 'p': + return parse_presentation_type(pres::pointer, pointer_set | cstring_set); + case '?': + return parse_presentation_type(pres::debug, + char_set | string_set | cstring_set); + case '}': + return begin; + default: { + if (*begin == '}') return begin; + // Parse fill and alignment. + auto fill_end = begin + code_point_length(begin); + if (end - fill_end <= 0) { + throw_format_error("invalid format specifier"); + return begin; + } + if (*begin == '{') { + throw_format_error("invalid fill character '{'"); + return begin; + } + auto align = parse_align(to_ascii(*fill_end)); + enter_state(state::align, align != align::none); + specs.fill = {begin, to_unsigned(fill_end - begin)}; + specs.align = align; + begin = fill_end + 1; + } + } + if (begin == end) return begin; + c = to_ascii(*begin); + } +} + +template +FMT_CONSTEXPR auto parse_replacement_field(const Char* begin, const Char* end, + Handler&& handler) -> const Char* { + struct id_adapter { + Handler& handler; + int arg_id; + + FMT_CONSTEXPR void on_auto() { arg_id = handler.on_arg_id(); } + FMT_CONSTEXPR void on_index(int id) { arg_id = handler.on_arg_id(id); } + FMT_CONSTEXPR void on_name(basic_string_view id) { + arg_id = handler.on_arg_id(id); + } + }; + + ++begin; + if (begin == end) return handler.on_error("invalid format string"), end; + if (*begin == '}') { + handler.on_replacement_field(handler.on_arg_id(), begin); + } else if (*begin == '{') { + handler.on_text(begin, begin + 1); + } else { + auto adapter = id_adapter{handler, 0}; + begin = parse_arg_id(begin, end, adapter); + Char c = begin != end ? 
*begin : Char(); + if (c == '}') { + handler.on_replacement_field(adapter.arg_id, begin); + } else if (c == ':') { + begin = handler.on_format_specs(adapter.arg_id, begin + 1, end); + if (begin == end || *begin != '}') + return handler.on_error("unknown format specifier"), end; + } else { + return handler.on_error("missing '}' in format string"), end; + } + } + return begin + 1; +} + +template +FMT_CONSTEXPR FMT_INLINE void parse_format_string( + basic_string_view format_str, Handler&& handler) { + auto begin = format_str.data(); + auto end = begin + format_str.size(); + if (end - begin < 32) { + // Use a simple loop instead of memchr for small strings. + const Char* p = begin; + while (p != end) { + auto c = *p++; + if (c == '{') { + handler.on_text(begin, p - 1); + begin = p = parse_replacement_field(p - 1, end, handler); + } else if (c == '}') { + if (p == end || *p != '}') + return handler.on_error("unmatched '}' in format string"); + handler.on_text(begin, p); + begin = ++p; + } + } + handler.on_text(begin, end); + return; + } + struct writer { + FMT_CONSTEXPR void operator()(const Char* from, const Char* to) { + if (from == to) return; + for (;;) { + const Char* p = nullptr; + if (!find(from, to, Char('}'), p)) + return handler_.on_text(from, to); + ++p; + if (p == to || *p != '}') + return handler_.on_error("unmatched '}' in format string"); + handler_.on_text(from, p); + from = p + 1; + } + } + Handler& handler_; + } write = {handler}; + while (begin != end) { + // Doing two passes with memchr (one for '{' and another for '}') is up to + // 2.5x faster than the naive one-pass implementation on big format strings. + const Char* p = begin; + if (*begin != '{' && !find(begin + 1, end, Char('{'), p)) + return write(begin, end); + write(begin, p); + begin = parse_replacement_field(p, end, handler); + } +} + +template ::value> struct strip_named_arg { + using type = T; +}; +template struct strip_named_arg { + using type = remove_cvref_t; +}; + +template +FMT_CONSTEXPR auto parse_format_specs(ParseContext& ctx) + -> decltype(ctx.begin()) { + using char_type = typename ParseContext::char_type; + using context = buffer_context; + using mapped_type = conditional_t< + mapped_type_constant::value != type::custom_type, + decltype(arg_mapper().map(std::declval())), + typename strip_named_arg::type>; +#if defined(__cpp_if_constexpr) + if constexpr (std::is_default_constructible_v< + formatter>) { + return formatter().parse(ctx); + } else { + type_is_unformattable_for _; + return ctx.begin(); + } +#else + return formatter().parse(ctx); +#endif +} + +// Checks char specs and returns true iff the presentation type is char-like. +template +FMT_CONSTEXPR auto check_char_specs(const format_specs& specs) -> bool { + if (specs.type != presentation_type::none && + specs.type != presentation_type::chr && + specs.type != presentation_type::debug) { + return false; + } + if (specs.align == align::numeric || specs.sign != sign::none || specs.alt) + throw_format_error("invalid format specifier for char"); + return true; +} + +#if FMT_USE_NONTYPE_TEMPLATE_ARGS +template +constexpr auto get_arg_index_by_name(basic_string_view name) -> int { + if constexpr (is_statically_named_arg()) { + if (name == T::name) return N; + } + if constexpr (sizeof...(Args) > 0) + return get_arg_index_by_name(name); + (void)name; // Workaround an MSVC bug about "unused" parameter. 
+ return -1; +} +#endif + +template +FMT_CONSTEXPR auto get_arg_index_by_name(basic_string_view name) -> int { +#if FMT_USE_NONTYPE_TEMPLATE_ARGS + if constexpr (sizeof...(Args) > 0) + return get_arg_index_by_name<0, Args...>(name); +#endif + (void)name; + return -1; +} + +template class format_string_checker { + private: + using parse_context_type = compile_parse_context; + static constexpr int num_args = sizeof...(Args); + + // Format specifier parsing function. + // In the future basic_format_parse_context will replace compile_parse_context + // here and will use is_constant_evaluated and downcasting to access the data + // needed for compile-time checks: https://godbolt.org/z/GvWzcTjh1. + using parse_func = const Char* (*)(parse_context_type&); + + type types_[num_args > 0 ? static_cast(num_args) : 1]; + parse_context_type context_; + parse_func parse_funcs_[num_args > 0 ? static_cast(num_args) : 1]; + + public: + explicit FMT_CONSTEXPR format_string_checker(basic_string_view fmt) + : types_{mapped_type_constant>::value...}, + context_(fmt, num_args, types_), + parse_funcs_{&parse_format_specs...} {} + + FMT_CONSTEXPR void on_text(const Char*, const Char*) {} + + FMT_CONSTEXPR auto on_arg_id() -> int { return context_.next_arg_id(); } + FMT_CONSTEXPR auto on_arg_id(int id) -> int { + return context_.check_arg_id(id), id; + } + FMT_CONSTEXPR auto on_arg_id(basic_string_view id) -> int { +#if FMT_USE_NONTYPE_TEMPLATE_ARGS + auto index = get_arg_index_by_name(id); + if (index < 0) on_error("named argument is not found"); + return index; +#else + (void)id; + on_error("compile-time checks for named arguments require C++20 support"); + return 0; +#endif + } + + FMT_CONSTEXPR void on_replacement_field(int id, const Char* begin) { + on_format_specs(id, begin, begin); // Call parse() on empty specs. + } + + FMT_CONSTEXPR auto on_format_specs(int id, const Char* begin, const Char*) + -> const Char* { + context_.advance_to(begin); + // id >= 0 check is a workaround for gcc 10 bug (#2065). + return id >= 0 && id < num_args ? parse_funcs_[id](context_) : begin; + } + + FMT_CONSTEXPR void on_error(const char* message) { + throw_format_error(message); + } +}; + +// Reports a compile-time error if S is not a valid format string. +template ::value)> +FMT_INLINE void check_format_string(const S&) { +#ifdef FMT_ENFORCE_COMPILE_STRING + static_assert(is_compile_string::value, + "FMT_ENFORCE_COMPILE_STRING requires all format strings to use " + "FMT_STRING."); +#endif +} +template ::value)> +void check_format_string(S format_str) { + using char_t = typename S::char_type; + FMT_CONSTEXPR auto s = basic_string_view(format_str); + using checker = format_string_checker...>; + FMT_CONSTEXPR bool error = (parse_format_string(s, checker(s)), true); + ignore_unused(error); +} + +template struct vformat_args { + using type = basic_format_args< + basic_format_context>, Char>>; +}; +template <> struct vformat_args { using type = format_args; }; + +// Use vformat_args and avoid type_identity to keep symbols short. +template +void vformat_to(buffer& buf, basic_string_view fmt, + typename vformat_args::type args, locale_ref loc = {}); + +FMT_API void vprint_mojibake(std::FILE*, string_view, format_args); +#ifndef _WIN32 +inline void vprint_mojibake(std::FILE*, string_view, format_args) {} +#endif +} // namespace detail + +FMT_BEGIN_EXPORT + +// A formatter specialization for natively supported types. 
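+// User-defined types are formatted by providing their own specialization; a
+// minimal sketch (illustrative, not part of this header) reusing the
+// string_view formatter:
+//
+//   struct point { double x, y; };
+//   template <> struct fmt::formatter<point> : fmt::formatter<string_view> {
+//     auto format(point p, format_context& ctx) const {
+//       return fmt::formatter<string_view>::format(
+//           fmt::format("({}, {})", p.x, p.y), ctx);
+//     }
+//   };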
+template +struct formatter::value != + detail::type::custom_type>> { + private: + detail::dynamic_format_specs specs_; + + public: + template + FMT_CONSTEXPR auto parse(ParseContext& ctx) -> const Char* { + auto type = detail::type_constant::value; + auto end = + detail::parse_format_specs(ctx.begin(), ctx.end(), specs_, ctx, type); + if (type == detail::type::char_type) detail::check_char_specs(specs_); + return end; + } + + template ::value, + FMT_ENABLE_IF(U == detail::type::string_type || + U == detail::type::cstring_type || + U == detail::type::char_type)> + FMT_CONSTEXPR void set_debug_format(bool set = true) { + specs_.type = set ? presentation_type::debug : presentation_type::none; + } + + template + FMT_CONSTEXPR auto format(const T& val, FormatContext& ctx) const + -> decltype(ctx.out()); +}; + +template struct runtime_format_string { + basic_string_view str; +}; + +/** A compile-time format string. */ +template class basic_format_string { + private: + basic_string_view str_; + + public: + template >::value)> + FMT_CONSTEVAL FMT_INLINE basic_format_string(const S& s) : str_(s) { + static_assert( + detail::count< + (std::is_base_of>::value && + std::is_reference::value)...>() == 0, + "passing views as lvalues is disallowed"); +#ifdef FMT_HAS_CONSTEVAL + if constexpr (detail::count_named_args() == + detail::count_statically_named_args()) { + using checker = + detail::format_string_checker...>; + detail::parse_format_string(str_, checker(s)); + } +#else + detail::check_format_string(s); +#endif + } + basic_format_string(runtime_format_string fmt) : str_(fmt.str) {} + + FMT_INLINE operator basic_string_view() const { return str_; } + FMT_INLINE auto get() const -> basic_string_view { return str_; } +}; + +#if FMT_GCC_VERSION && FMT_GCC_VERSION < 409 +// Workaround broken conversion on older gcc. +template using format_string = string_view; +inline auto runtime(string_view s) -> string_view { return s; } +#else +template +using format_string = basic_format_string...>; +/** + \rst + Creates a runtime format string. + + **Example**:: + + // Check format string at runtime instead of compile-time. + fmt::print(fmt::runtime("{:d}"), "I am not a number"); + \endrst + */ +inline auto runtime(string_view s) -> runtime_format_string<> { return {{s}}; } +#endif + +FMT_API auto vformat(string_view fmt, format_args args) -> std::string; + +/** + \rst + Formats ``args`` according to specifications in ``fmt`` and returns the result + as a string. + + **Example**:: + + #include + std::string message = fmt::format("The answer is {}.", 42); + \endrst +*/ +template +FMT_NODISCARD FMT_INLINE auto format(format_string fmt, T&&... args) + -> std::string { + return vformat(fmt, fmt::make_format_args(args...)); +} + +/** Formats a string and writes the output to ``out``. */ +template ::value)> +auto vformat_to(OutputIt out, string_view fmt, format_args args) -> OutputIt { + auto&& buf = detail::get_buffer(out); + detail::vformat_to(buf, fmt, args, {}); + return detail::get_iterator(buf, out); +} + +/** + \rst + Formats ``args`` according to specifications in ``fmt``, writes the result to + the output iterator ``out`` and returns the iterator past the end of the output + range. `format_to` does not append a terminating null character. + + **Example**:: + + auto out = std::vector(); + fmt::format_to(std::back_inserter(out), "{}", 42); + \endrst + */ +template ::value)> +FMT_INLINE auto format_to(OutputIt out, format_string fmt, T&&... 
args) + -> OutputIt { + return vformat_to(out, fmt, fmt::make_format_args(args...)); +} + +template struct format_to_n_result { + /** Iterator past the end of the output range. */ + OutputIt out; + /** Total (not truncated) output size. */ + size_t size; +}; + +template ::value)> +auto vformat_to_n(OutputIt out, size_t n, string_view fmt, format_args args) + -> format_to_n_result { + using traits = detail::fixed_buffer_traits; + auto buf = detail::iterator_buffer(out, n); + detail::vformat_to(buf, fmt, args, {}); + return {buf.out(), buf.count()}; +} + +/** + \rst + Formats ``args`` according to specifications in ``fmt``, writes up to ``n`` + characters of the result to the output iterator ``out`` and returns the total + (not truncated) output size and the iterator past the end of the output range. + `format_to_n` does not append a terminating null character. + \endrst + */ +template ::value)> +FMT_INLINE auto format_to_n(OutputIt out, size_t n, format_string fmt, + T&&... args) -> format_to_n_result { + return vformat_to_n(out, n, fmt, fmt::make_format_args(args...)); +} + +/** Returns the number of chars in the output of ``format(fmt, args...)``. */ +template +FMT_NODISCARD FMT_INLINE auto formatted_size(format_string fmt, + T&&... args) -> size_t { + auto buf = detail::counting_buffer<>(); + detail::vformat_to(buf, fmt, fmt::make_format_args(args...), {}); + return buf.count(); +} + +FMT_API void vprint(string_view fmt, format_args args); +FMT_API void vprint(std::FILE* f, string_view fmt, format_args args); + +/** + \rst + Formats ``args`` according to specifications in ``fmt`` and writes the output + to ``stdout``. + + **Example**:: + + fmt::print("Elapsed time: {0:.2f} seconds", 1.23); + \endrst + */ +template +FMT_INLINE void print(format_string fmt, T&&... args) { + const auto& vargs = fmt::make_format_args(args...); + return detail::is_utf8() ? vprint(fmt, vargs) + : detail::vprint_mojibake(stdout, fmt, vargs); +} + +/** + \rst + Formats ``args`` according to specifications in ``fmt`` and writes the + output to the file ``f``. + + **Example**:: + + fmt::print(stderr, "Don't {}!", "panic"); + \endrst + */ +template +FMT_INLINE void print(std::FILE* f, format_string fmt, T&&... args) { + const auto& vargs = fmt::make_format_args(args...); + return detail::is_utf8() ? vprint(f, fmt, vargs) + : detail::vprint_mojibake(f, fmt, vargs); +} + +/** + Formats ``args`` according to specifications in ``fmt`` and writes the + output to the file ``f`` followed by a newline. + */ +template +FMT_INLINE void println(std::FILE* f, format_string fmt, T&&... args) { + return fmt::print(f, "{}\n", fmt::format(fmt, std::forward(args)...)); +} + +/** + Formats ``args`` according to specifications in ``fmt`` and writes the output + to ``stdout`` followed by a newline. + */ +template +FMT_INLINE void println(format_string fmt, T&&... args) { + return fmt::println(stdout, fmt, std::forward(args)...); +} + +FMT_END_EXPORT +FMT_GCC_PRAGMA("GCC pop_options") +FMT_END_NAMESPACE + +#ifdef FMT_HEADER_ONLY +# include "format.h" +#endif +#endif // FMT_CORE_H_ diff --git a/Genie/Genie/src/qualla/include/fmt/format-inl.h b/Genie/Genie/src/qualla/include/fmt/format-inl.h new file mode 100644 index 0000000000000000000000000000000000000000..dac2d437a41ab7b0b4e72895212b5a972ada73a9 --- /dev/null +++ b/Genie/Genie/src/qualla/include/fmt/format-inl.h @@ -0,0 +1,1662 @@ +// Formatting library for C++ - implementation +// +// Copyright (c) 2012 - 2016, Victor Zverovich +// All rights reserved. 
+//
+// For the license information refer to format.h.
+
+#ifndef FMT_FORMAT_INL_H_
+#define FMT_FORMAT_INL_H_
+
+#include <algorithm>
+#include <cerrno>  // errno
+#include <climits>
+#include <cmath>
+#include <exception>
+
+#ifndef FMT_STATIC_THOUSANDS_SEPARATOR
+#  include <locale>
+#endif
+
+#ifdef _WIN32
+#  include <io.h>  // _isatty
+#endif
+
+#include "format.h"
+
+FMT_BEGIN_NAMESPACE
+namespace detail {
+
+FMT_FUNC void assert_fail(const char* file, int line, const char* message) {
+  // Use unchecked std::fprintf to avoid triggering another assertion when
+  // writing to stderr fails
+  std::fprintf(stderr, "%s:%d: assertion failed: %s", file, line, message);
+  // Chosen instead of std::abort to satisfy Clang in CUDA mode during device
+  // code pass.
+  std::terminate();
+}
+
+FMT_FUNC void throw_format_error(const char* message) {
+  FMT_THROW(format_error(message));
+}
+
+FMT_FUNC void format_error_code(detail::buffer<char>& out, int error_code,
+                                string_view message) noexcept {
+  // Report error code making sure that the output fits into
+  // inline_buffer_size to avoid dynamic memory allocation and potential
+  // bad_alloc.
+  out.try_resize(0);
+  static const char SEP[] = ": ";
+  static const char ERROR_STR[] = "error ";
+  // Subtract 2 to account for terminating null characters in SEP and ERROR_STR.
+  size_t error_code_size = sizeof(SEP) + sizeof(ERROR_STR) - 2;
+  auto abs_value = static_cast<uint32_or_64_or_128_t<int>>(error_code);
+  if (detail::is_negative(error_code)) {
+    abs_value = 0 - abs_value;
+    ++error_code_size;
+  }
+  error_code_size += detail::to_unsigned(detail::count_digits(abs_value));
+  auto it = buffer_appender<char>(out);
+  if (message.size() <= inline_buffer_size - error_code_size)
+    format_to(it, FMT_STRING("{}{}"), message, SEP);
+  format_to(it, FMT_STRING("{}{}"), ERROR_STR, error_code);
+  FMT_ASSERT(out.size() <= inline_buffer_size, "");
+}
+
+FMT_FUNC void report_error(format_func func, int error_code,
+                           const char* message) noexcept {
+  memory_buffer full_message;
+  func(full_message, error_code, message);
+  // Don't use fwrite_fully because the latter may throw.
+  if (std::fwrite(full_message.data(), full_message.size(), 1, stderr) > 0)
+    std::fputc('\n', stderr);
+}
+
+// A wrapper around fwrite that throws on error.
+inline void fwrite_fully(const void* ptr, size_t size, size_t count,
+                         FILE* stream) {
+  size_t written = std::fwrite(ptr, size, count, stream);
+  if (written < count)
+    FMT_THROW(system_error(errno, FMT_STRING("cannot write to file")));
+}
+
+#ifndef FMT_STATIC_THOUSANDS_SEPARATOR
+template <typename Locale>
+locale_ref::locale_ref(const Locale& loc) : locale_(&loc) {
+  static_assert(std::is_same<Locale, std::locale>::value, "");
+}
+
+template <typename Locale> Locale locale_ref::get() const {
+  static_assert(std::is_same<Locale, std::locale>::value, "");
+  return locale_ ? *static_cast<const std::locale*>(locale_) : std::locale();
+}
+
+template <typename Char>
+FMT_FUNC auto thousands_sep_impl(locale_ref loc) -> thousands_sep_result<Char> {
+  auto& facet = std::use_facet<std::numpunct<Char>>(loc.get<std::locale>());
+  auto grouping = facet.grouping();
+  auto thousands_sep = grouping.empty() ?
Char() : facet.thousands_sep(); + return {std::move(grouping), thousands_sep}; +} +template FMT_FUNC Char decimal_point_impl(locale_ref loc) { + return std::use_facet>(loc.get()) + .decimal_point(); +} +#else +template +FMT_FUNC auto thousands_sep_impl(locale_ref) -> thousands_sep_result { + return {"\03", FMT_STATIC_THOUSANDS_SEPARATOR}; +} +template FMT_FUNC Char decimal_point_impl(locale_ref) { + return '.'; +} +#endif + +FMT_FUNC auto write_loc(appender out, loc_value value, + const format_specs<>& specs, locale_ref loc) -> bool { +#ifndef FMT_STATIC_THOUSANDS_SEPARATOR + auto locale = loc.get(); + // We cannot use the num_put facet because it may produce output in + // a wrong encoding. + using facet = format_facet; + if (std::has_facet(locale)) + return std::use_facet(locale).put(out, value, specs); + return facet(locale).put(out, value, specs); +#endif + return false; +} +} // namespace detail + +template typename Locale::id format_facet::id; + +#ifndef FMT_STATIC_THOUSANDS_SEPARATOR +template format_facet::format_facet(Locale& loc) { + auto& numpunct = std::use_facet>(loc); + grouping_ = numpunct.grouping(); + if (!grouping_.empty()) separator_ = std::string(1, numpunct.thousands_sep()); +} + +template <> +FMT_API FMT_FUNC auto format_facet::do_put( + appender out, loc_value val, const format_specs<>& specs) const -> bool { + return val.visit( + detail::loc_writer<>{out, specs, separator_, grouping_, decimal_point_}); +} +#endif + +FMT_FUNC std::system_error vsystem_error(int error_code, string_view fmt, + format_args args) { + auto ec = std::error_code(error_code, std::generic_category()); + return std::system_error(ec, vformat(fmt, args)); +} + +namespace detail { + +template inline bool operator==(basic_fp x, basic_fp y) { + return x.f == y.f && x.e == y.e; +} + +// Compilers should be able to optimize this into the ror instruction. +FMT_CONSTEXPR inline uint32_t rotr(uint32_t n, uint32_t r) noexcept { + r &= 31; + return (n >> r) | (n << (32 - r)); +} +FMT_CONSTEXPR inline uint64_t rotr(uint64_t n, uint32_t r) noexcept { + r &= 63; + return (n >> r) | (n << (64 - r)); +} + +// Implementation of Dragonbox algorithm: https://github.com/jk-jeon/dragonbox. +namespace dragonbox { +// Computes upper 64 bits of multiplication of a 32-bit unsigned integer and a +// 64-bit unsigned integer. +inline uint64_t umul96_upper64(uint32_t x, uint64_t y) noexcept { + return umul128_upper64(static_cast(x) << 32, y); +} + +// Computes lower 128 bits of multiplication of a 64-bit unsigned integer and a +// 128-bit unsigned integer. +inline uint128_fallback umul192_lower128(uint64_t x, + uint128_fallback y) noexcept { + uint64_t high = x * y.high(); + uint128_fallback high_low = umul128(x, y.low()); + return {high + high_low.high(), high_low.low()}; +} + +// Computes lower 64 bits of multiplication of a 32-bit unsigned integer and a +// 64-bit unsigned integer. +inline uint64_t umul96_lower64(uint32_t x, uint64_t y) noexcept { + return x * y; +} + +// Various fast log computations. +inline int floor_log10_pow2_minus_log10_4_over_3(int e) noexcept { + FMT_ASSERT(e <= 2936 && e >= -2985, "too large exponent"); + return (e * 631305 - 261663) >> 21; +} + +FMT_INLINE_VARIABLE constexpr struct { + uint32_t divisor; + int shift_amount; +} div_small_pow10_infos[] = {{10, 16}, {100, 16}}; + +// Replaces n by floor(n / pow(10, N)) returning true if and only if n is +// divisible by pow(10, N). +// Precondition: n <= pow(10, N + 1). 
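+// Worked example (illustrative) for N = 1: divisor = 10, shift_amount = 16,
+// so magic_number = 2^16 / 10 + 1 = 6554. For n = 70: 70 * 6554 = 458780,
+// 458780 & 0xffff = 28 < 6554, so 70 is divisible by 10, and
+// 458780 >> 16 = 7 = 70 / 10.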
+template +bool check_divisibility_and_divide_by_pow10(uint32_t& n) noexcept { + // The numbers below are chosen such that: + // 1. floor(n/d) = floor(nm / 2^k) where d=10 or d=100, + // 2. nm mod 2^k < m if and only if n is divisible by d, + // where m is magic_number, k is shift_amount + // and d is divisor. + // + // Item 1 is a common technique of replacing division by a constant with + // multiplication, see e.g. "Division by Invariant Integers Using + // Multiplication" by Granlund and Montgomery (1994). magic_number (m) is set + // to ceil(2^k/d) for large enough k. + // The idea for item 2 originates from Schubfach. + constexpr auto info = div_small_pow10_infos[N - 1]; + FMT_ASSERT(n <= info.divisor * 10, "n is too large"); + constexpr uint32_t magic_number = + (1u << info.shift_amount) / info.divisor + 1; + n *= magic_number; + const uint32_t comparison_mask = (1u << info.shift_amount) - 1; + bool result = (n & comparison_mask) < magic_number; + n >>= info.shift_amount; + return result; +} + +// Computes floor(n / pow(10, N)) for small n and N. +// Precondition: n <= pow(10, N + 1). +template uint32_t small_division_by_pow10(uint32_t n) noexcept { + constexpr auto info = div_small_pow10_infos[N - 1]; + FMT_ASSERT(n <= info.divisor * 10, "n is too large"); + constexpr uint32_t magic_number = + (1u << info.shift_amount) / info.divisor + 1; + return (n * magic_number) >> info.shift_amount; +} + +// Computes floor(n / 10^(kappa + 1)) (float) +inline uint32_t divide_by_10_to_kappa_plus_1(uint32_t n) noexcept { + // 1374389535 = ceil(2^37/100) + return static_cast((static_cast(n) * 1374389535) >> 37); +} +// Computes floor(n / 10^(kappa + 1)) (double) +inline uint64_t divide_by_10_to_kappa_plus_1(uint64_t n) noexcept { + // 2361183241434822607 = ceil(2^(64+7)/1000) + return umul128_upper64(n, 2361183241434822607ull) >> 7; +} + +// Various subroutines using pow10 cache +template struct cache_accessor; + +template <> struct cache_accessor { + using carrier_uint = float_info::carrier_uint; + using cache_entry_type = uint64_t; + + static uint64_t get_cached_power(int k) noexcept { + FMT_ASSERT(k >= float_info::min_k && k <= float_info::max_k, + "k is out of range"); + static constexpr const uint64_t pow10_significands[] = { + 0x81ceb32c4b43fcf5, 0xa2425ff75e14fc32, 0xcad2f7f5359a3b3f, + 0xfd87b5f28300ca0e, 0x9e74d1b791e07e49, 0xc612062576589ddb, + 0xf79687aed3eec552, 0x9abe14cd44753b53, 0xc16d9a0095928a28, + 0xf1c90080baf72cb2, 0x971da05074da7bef, 0xbce5086492111aeb, + 0xec1e4a7db69561a6, 0x9392ee8e921d5d08, 0xb877aa3236a4b44a, + 0xe69594bec44de15c, 0x901d7cf73ab0acda, 0xb424dc35095cd810, + 0xe12e13424bb40e14, 0x8cbccc096f5088cc, 0xafebff0bcb24aaff, + 0xdbe6fecebdedd5bf, 0x89705f4136b4a598, 0xabcc77118461cefd, + 0xd6bf94d5e57a42bd, 0x8637bd05af6c69b6, 0xa7c5ac471b478424, + 0xd1b71758e219652c, 0x83126e978d4fdf3c, 0xa3d70a3d70a3d70b, + 0xcccccccccccccccd, 0x8000000000000000, 0xa000000000000000, + 0xc800000000000000, 0xfa00000000000000, 0x9c40000000000000, + 0xc350000000000000, 0xf424000000000000, 0x9896800000000000, + 0xbebc200000000000, 0xee6b280000000000, 0x9502f90000000000, + 0xba43b74000000000, 0xe8d4a51000000000, 0x9184e72a00000000, + 0xb5e620f480000000, 0xe35fa931a0000000, 0x8e1bc9bf04000000, + 0xb1a2bc2ec5000000, 0xde0b6b3a76400000, 0x8ac7230489e80000, + 0xad78ebc5ac620000, 0xd8d726b7177a8000, 0x878678326eac9000, + 0xa968163f0a57b400, 0xd3c21bcecceda100, 0x84595161401484a0, + 0xa56fa5b99019a5c8, 0xcecb8f27f4200f3a, 0x813f3978f8940985, + 0xa18f07d736b90be6, 0xc9f2c9cd04674edf, 
0xfc6f7c4045812297, + 0x9dc5ada82b70b59e, 0xc5371912364ce306, 0xf684df56c3e01bc7, + 0x9a130b963a6c115d, 0xc097ce7bc90715b4, 0xf0bdc21abb48db21, + 0x96769950b50d88f5, 0xbc143fa4e250eb32, 0xeb194f8e1ae525fe, + 0x92efd1b8d0cf37bf, 0xb7abc627050305ae, 0xe596b7b0c643c71a, + 0x8f7e32ce7bea5c70, 0xb35dbf821ae4f38c, 0xe0352f62a19e306f}; + return pow10_significands[k - float_info::min_k]; + } + + struct compute_mul_result { + carrier_uint result; + bool is_integer; + }; + struct compute_mul_parity_result { + bool parity; + bool is_integer; + }; + + static compute_mul_result compute_mul( + carrier_uint u, const cache_entry_type& cache) noexcept { + auto r = umul96_upper64(u, cache); + return {static_cast(r >> 32), + static_cast(r) == 0}; + } + + static uint32_t compute_delta(const cache_entry_type& cache, + int beta) noexcept { + return static_cast(cache >> (64 - 1 - beta)); + } + + static compute_mul_parity_result compute_mul_parity( + carrier_uint two_f, const cache_entry_type& cache, int beta) noexcept { + FMT_ASSERT(beta >= 1, ""); + FMT_ASSERT(beta < 64, ""); + + auto r = umul96_lower64(two_f, cache); + return {((r >> (64 - beta)) & 1) != 0, + static_cast(r >> (32 - beta)) == 0}; + } + + static carrier_uint compute_left_endpoint_for_shorter_interval_case( + const cache_entry_type& cache, int beta) noexcept { + return static_cast( + (cache - (cache >> (num_significand_bits() + 2))) >> + (64 - num_significand_bits() - 1 - beta)); + } + + static carrier_uint compute_right_endpoint_for_shorter_interval_case( + const cache_entry_type& cache, int beta) noexcept { + return static_cast( + (cache + (cache >> (num_significand_bits() + 1))) >> + (64 - num_significand_bits() - 1 - beta)); + } + + static carrier_uint compute_round_up_for_shorter_interval_case( + const cache_entry_type& cache, int beta) noexcept { + return (static_cast( + cache >> (64 - num_significand_bits() - 2 - beta)) + + 1) / + 2; + } +}; + +template <> struct cache_accessor { + using carrier_uint = float_info::carrier_uint; + using cache_entry_type = uint128_fallback; + + static uint128_fallback get_cached_power(int k) noexcept { + FMT_ASSERT(k >= float_info::min_k && k <= float_info::max_k, + "k is out of range"); + + static constexpr const uint128_fallback pow10_significands[] = { +#if FMT_USE_FULL_CACHE_DRAGONBOX + {0xff77b1fcbebcdc4f, 0x25e8e89c13bb0f7b}, + {0x9faacf3df73609b1, 0x77b191618c54e9ad}, + {0xc795830d75038c1d, 0xd59df5b9ef6a2418}, + {0xf97ae3d0d2446f25, 0x4b0573286b44ad1e}, + {0x9becce62836ac577, 0x4ee367f9430aec33}, + {0xc2e801fb244576d5, 0x229c41f793cda740}, + {0xf3a20279ed56d48a, 0x6b43527578c11110}, + {0x9845418c345644d6, 0x830a13896b78aaaa}, + {0xbe5691ef416bd60c, 0x23cc986bc656d554}, + {0xedec366b11c6cb8f, 0x2cbfbe86b7ec8aa9}, + {0x94b3a202eb1c3f39, 0x7bf7d71432f3d6aa}, + {0xb9e08a83a5e34f07, 0xdaf5ccd93fb0cc54}, + {0xe858ad248f5c22c9, 0xd1b3400f8f9cff69}, + {0x91376c36d99995be, 0x23100809b9c21fa2}, + {0xb58547448ffffb2d, 0xabd40a0c2832a78b}, + {0xe2e69915b3fff9f9, 0x16c90c8f323f516d}, + {0x8dd01fad907ffc3b, 0xae3da7d97f6792e4}, + {0xb1442798f49ffb4a, 0x99cd11cfdf41779d}, + {0xdd95317f31c7fa1d, 0x40405643d711d584}, + {0x8a7d3eef7f1cfc52, 0x482835ea666b2573}, + {0xad1c8eab5ee43b66, 0xda3243650005eed0}, + {0xd863b256369d4a40, 0x90bed43e40076a83}, + {0x873e4f75e2224e68, 0x5a7744a6e804a292}, + {0xa90de3535aaae202, 0x711515d0a205cb37}, + {0xd3515c2831559a83, 0x0d5a5b44ca873e04}, + {0x8412d9991ed58091, 0xe858790afe9486c3}, + {0xa5178fff668ae0b6, 0x626e974dbe39a873}, + {0xce5d73ff402d98e3, 0xfb0a3d212dc81290}, + 
{0x80fa687f881c7f8e, 0x7ce66634bc9d0b9a}, + {0xa139029f6a239f72, 0x1c1fffc1ebc44e81}, + {0xc987434744ac874e, 0xa327ffb266b56221}, + {0xfbe9141915d7a922, 0x4bf1ff9f0062baa9}, + {0x9d71ac8fada6c9b5, 0x6f773fc3603db4aa}, + {0xc4ce17b399107c22, 0xcb550fb4384d21d4}, + {0xf6019da07f549b2b, 0x7e2a53a146606a49}, + {0x99c102844f94e0fb, 0x2eda7444cbfc426e}, + {0xc0314325637a1939, 0xfa911155fefb5309}, + {0xf03d93eebc589f88, 0x793555ab7eba27cb}, + {0x96267c7535b763b5, 0x4bc1558b2f3458df}, + {0xbbb01b9283253ca2, 0x9eb1aaedfb016f17}, + {0xea9c227723ee8bcb, 0x465e15a979c1cadd}, + {0x92a1958a7675175f, 0x0bfacd89ec191eca}, + {0xb749faed14125d36, 0xcef980ec671f667c}, + {0xe51c79a85916f484, 0x82b7e12780e7401b}, + {0x8f31cc0937ae58d2, 0xd1b2ecb8b0908811}, + {0xb2fe3f0b8599ef07, 0x861fa7e6dcb4aa16}, + {0xdfbdcece67006ac9, 0x67a791e093e1d49b}, + {0x8bd6a141006042bd, 0xe0c8bb2c5c6d24e1}, + {0xaecc49914078536d, 0x58fae9f773886e19}, + {0xda7f5bf590966848, 0xaf39a475506a899f}, + {0x888f99797a5e012d, 0x6d8406c952429604}, + {0xaab37fd7d8f58178, 0xc8e5087ba6d33b84}, + {0xd5605fcdcf32e1d6, 0xfb1e4a9a90880a65}, + {0x855c3be0a17fcd26, 0x5cf2eea09a550680}, + {0xa6b34ad8c9dfc06f, 0xf42faa48c0ea481f}, + {0xd0601d8efc57b08b, 0xf13b94daf124da27}, + {0x823c12795db6ce57, 0x76c53d08d6b70859}, + {0xa2cb1717b52481ed, 0x54768c4b0c64ca6f}, + {0xcb7ddcdda26da268, 0xa9942f5dcf7dfd0a}, + {0xfe5d54150b090b02, 0xd3f93b35435d7c4d}, + {0x9efa548d26e5a6e1, 0xc47bc5014a1a6db0}, + {0xc6b8e9b0709f109a, 0x359ab6419ca1091c}, + {0xf867241c8cc6d4c0, 0xc30163d203c94b63}, + {0x9b407691d7fc44f8, 0x79e0de63425dcf1e}, + {0xc21094364dfb5636, 0x985915fc12f542e5}, + {0xf294b943e17a2bc4, 0x3e6f5b7b17b2939e}, + {0x979cf3ca6cec5b5a, 0xa705992ceecf9c43}, + {0xbd8430bd08277231, 0x50c6ff782a838354}, + {0xece53cec4a314ebd, 0xa4f8bf5635246429}, + {0x940f4613ae5ed136, 0x871b7795e136be9a}, + {0xb913179899f68584, 0x28e2557b59846e40}, + {0xe757dd7ec07426e5, 0x331aeada2fe589d0}, + {0x9096ea6f3848984f, 0x3ff0d2c85def7622}, + {0xb4bca50b065abe63, 0x0fed077a756b53aa}, + {0xe1ebce4dc7f16dfb, 0xd3e8495912c62895}, + {0x8d3360f09cf6e4bd, 0x64712dd7abbbd95d}, + {0xb080392cc4349dec, 0xbd8d794d96aacfb4}, + {0xdca04777f541c567, 0xecf0d7a0fc5583a1}, + {0x89e42caaf9491b60, 0xf41686c49db57245}, + {0xac5d37d5b79b6239, 0x311c2875c522ced6}, + {0xd77485cb25823ac7, 0x7d633293366b828c}, + {0x86a8d39ef77164bc, 0xae5dff9c02033198}, + {0xa8530886b54dbdeb, 0xd9f57f830283fdfd}, + {0xd267caa862a12d66, 0xd072df63c324fd7c}, + {0x8380dea93da4bc60, 0x4247cb9e59f71e6e}, + {0xa46116538d0deb78, 0x52d9be85f074e609}, + {0xcd795be870516656, 0x67902e276c921f8c}, + {0x806bd9714632dff6, 0x00ba1cd8a3db53b7}, + {0xa086cfcd97bf97f3, 0x80e8a40eccd228a5}, + {0xc8a883c0fdaf7df0, 0x6122cd128006b2ce}, + {0xfad2a4b13d1b5d6c, 0x796b805720085f82}, + {0x9cc3a6eec6311a63, 0xcbe3303674053bb1}, + {0xc3f490aa77bd60fc, 0xbedbfc4411068a9d}, + {0xf4f1b4d515acb93b, 0xee92fb5515482d45}, + {0x991711052d8bf3c5, 0x751bdd152d4d1c4b}, + {0xbf5cd54678eef0b6, 0xd262d45a78a0635e}, + {0xef340a98172aace4, 0x86fb897116c87c35}, + {0x9580869f0e7aac0e, 0xd45d35e6ae3d4da1}, + {0xbae0a846d2195712, 0x8974836059cca10a}, + {0xe998d258869facd7, 0x2bd1a438703fc94c}, + {0x91ff83775423cc06, 0x7b6306a34627ddd0}, + {0xb67f6455292cbf08, 0x1a3bc84c17b1d543}, + {0xe41f3d6a7377eeca, 0x20caba5f1d9e4a94}, + {0x8e938662882af53e, 0x547eb47b7282ee9d}, + {0xb23867fb2a35b28d, 0xe99e619a4f23aa44}, + {0xdec681f9f4c31f31, 0x6405fa00e2ec94d5}, + {0x8b3c113c38f9f37e, 0xde83bc408dd3dd05}, + {0xae0b158b4738705e, 0x9624ab50b148d446}, + {0xd98ddaee19068c76, 
0x3badd624dd9b0958}, + {0x87f8a8d4cfa417c9, 0xe54ca5d70a80e5d7}, + {0xa9f6d30a038d1dbc, 0x5e9fcf4ccd211f4d}, + {0xd47487cc8470652b, 0x7647c32000696720}, + {0x84c8d4dfd2c63f3b, 0x29ecd9f40041e074}, + {0xa5fb0a17c777cf09, 0xf468107100525891}, + {0xcf79cc9db955c2cc, 0x7182148d4066eeb5}, + {0x81ac1fe293d599bf, 0xc6f14cd848405531}, + {0xa21727db38cb002f, 0xb8ada00e5a506a7d}, + {0xca9cf1d206fdc03b, 0xa6d90811f0e4851d}, + {0xfd442e4688bd304a, 0x908f4a166d1da664}, + {0x9e4a9cec15763e2e, 0x9a598e4e043287ff}, + {0xc5dd44271ad3cdba, 0x40eff1e1853f29fe}, + {0xf7549530e188c128, 0xd12bee59e68ef47d}, + {0x9a94dd3e8cf578b9, 0x82bb74f8301958cf}, + {0xc13a148e3032d6e7, 0xe36a52363c1faf02}, + {0xf18899b1bc3f8ca1, 0xdc44e6c3cb279ac2}, + {0x96f5600f15a7b7e5, 0x29ab103a5ef8c0ba}, + {0xbcb2b812db11a5de, 0x7415d448f6b6f0e8}, + {0xebdf661791d60f56, 0x111b495b3464ad22}, + {0x936b9fcebb25c995, 0xcab10dd900beec35}, + {0xb84687c269ef3bfb, 0x3d5d514f40eea743}, + {0xe65829b3046b0afa, 0x0cb4a5a3112a5113}, + {0x8ff71a0fe2c2e6dc, 0x47f0e785eaba72ac}, + {0xb3f4e093db73a093, 0x59ed216765690f57}, + {0xe0f218b8d25088b8, 0x306869c13ec3532d}, + {0x8c974f7383725573, 0x1e414218c73a13fc}, + {0xafbd2350644eeacf, 0xe5d1929ef90898fb}, + {0xdbac6c247d62a583, 0xdf45f746b74abf3a}, + {0x894bc396ce5da772, 0x6b8bba8c328eb784}, + {0xab9eb47c81f5114f, 0x066ea92f3f326565}, + {0xd686619ba27255a2, 0xc80a537b0efefebe}, + {0x8613fd0145877585, 0xbd06742ce95f5f37}, + {0xa798fc4196e952e7, 0x2c48113823b73705}, + {0xd17f3b51fca3a7a0, 0xf75a15862ca504c6}, + {0x82ef85133de648c4, 0x9a984d73dbe722fc}, + {0xa3ab66580d5fdaf5, 0xc13e60d0d2e0ebbb}, + {0xcc963fee10b7d1b3, 0x318df905079926a9}, + {0xffbbcfe994e5c61f, 0xfdf17746497f7053}, + {0x9fd561f1fd0f9bd3, 0xfeb6ea8bedefa634}, + {0xc7caba6e7c5382c8, 0xfe64a52ee96b8fc1}, + {0xf9bd690a1b68637b, 0x3dfdce7aa3c673b1}, + {0x9c1661a651213e2d, 0x06bea10ca65c084f}, + {0xc31bfa0fe5698db8, 0x486e494fcff30a63}, + {0xf3e2f893dec3f126, 0x5a89dba3c3efccfb}, + {0x986ddb5c6b3a76b7, 0xf89629465a75e01d}, + {0xbe89523386091465, 0xf6bbb397f1135824}, + {0xee2ba6c0678b597f, 0x746aa07ded582e2d}, + {0x94db483840b717ef, 0xa8c2a44eb4571cdd}, + {0xba121a4650e4ddeb, 0x92f34d62616ce414}, + {0xe896a0d7e51e1566, 0x77b020baf9c81d18}, + {0x915e2486ef32cd60, 0x0ace1474dc1d122f}, + {0xb5b5ada8aaff80b8, 0x0d819992132456bb}, + {0xe3231912d5bf60e6, 0x10e1fff697ed6c6a}, + {0x8df5efabc5979c8f, 0xca8d3ffa1ef463c2}, + {0xb1736b96b6fd83b3, 0xbd308ff8a6b17cb3}, + {0xddd0467c64bce4a0, 0xac7cb3f6d05ddbdf}, + {0x8aa22c0dbef60ee4, 0x6bcdf07a423aa96c}, + {0xad4ab7112eb3929d, 0x86c16c98d2c953c7}, + {0xd89d64d57a607744, 0xe871c7bf077ba8b8}, + {0x87625f056c7c4a8b, 0x11471cd764ad4973}, + {0xa93af6c6c79b5d2d, 0xd598e40d3dd89bd0}, + {0xd389b47879823479, 0x4aff1d108d4ec2c4}, + {0x843610cb4bf160cb, 0xcedf722a585139bb}, + {0xa54394fe1eedb8fe, 0xc2974eb4ee658829}, + {0xce947a3da6a9273e, 0x733d226229feea33}, + {0x811ccc668829b887, 0x0806357d5a3f5260}, + {0xa163ff802a3426a8, 0xca07c2dcb0cf26f8}, + {0xc9bcff6034c13052, 0xfc89b393dd02f0b6}, + {0xfc2c3f3841f17c67, 0xbbac2078d443ace3}, + {0x9d9ba7832936edc0, 0xd54b944b84aa4c0e}, + {0xc5029163f384a931, 0x0a9e795e65d4df12}, + {0xf64335bcf065d37d, 0x4d4617b5ff4a16d6}, + {0x99ea0196163fa42e, 0x504bced1bf8e4e46}, + {0xc06481fb9bcf8d39, 0xe45ec2862f71e1d7}, + {0xf07da27a82c37088, 0x5d767327bb4e5a4d}, + {0x964e858c91ba2655, 0x3a6a07f8d510f870}, + {0xbbe226efb628afea, 0x890489f70a55368c}, + {0xeadab0aba3b2dbe5, 0x2b45ac74ccea842f}, + {0x92c8ae6b464fc96f, 0x3b0b8bc90012929e}, + {0xb77ada0617e3bbcb, 0x09ce6ebb40173745}, + 
{0xe55990879ddcaabd, 0xcc420a6a101d0516}, + {0x8f57fa54c2a9eab6, 0x9fa946824a12232e}, + {0xb32df8e9f3546564, 0x47939822dc96abfa}, + {0xdff9772470297ebd, 0x59787e2b93bc56f8}, + {0x8bfbea76c619ef36, 0x57eb4edb3c55b65b}, + {0xaefae51477a06b03, 0xede622920b6b23f2}, + {0xdab99e59958885c4, 0xe95fab368e45ecee}, + {0x88b402f7fd75539b, 0x11dbcb0218ebb415}, + {0xaae103b5fcd2a881, 0xd652bdc29f26a11a}, + {0xd59944a37c0752a2, 0x4be76d3346f04960}, + {0x857fcae62d8493a5, 0x6f70a4400c562ddc}, + {0xa6dfbd9fb8e5b88e, 0xcb4ccd500f6bb953}, + {0xd097ad07a71f26b2, 0x7e2000a41346a7a8}, + {0x825ecc24c873782f, 0x8ed400668c0c28c9}, + {0xa2f67f2dfa90563b, 0x728900802f0f32fb}, + {0xcbb41ef979346bca, 0x4f2b40a03ad2ffba}, + {0xfea126b7d78186bc, 0xe2f610c84987bfa9}, + {0x9f24b832e6b0f436, 0x0dd9ca7d2df4d7ca}, + {0xc6ede63fa05d3143, 0x91503d1c79720dbc}, + {0xf8a95fcf88747d94, 0x75a44c6397ce912b}, + {0x9b69dbe1b548ce7c, 0xc986afbe3ee11abb}, + {0xc24452da229b021b, 0xfbe85badce996169}, + {0xf2d56790ab41c2a2, 0xfae27299423fb9c4}, + {0x97c560ba6b0919a5, 0xdccd879fc967d41b}, + {0xbdb6b8e905cb600f, 0x5400e987bbc1c921}, + {0xed246723473e3813, 0x290123e9aab23b69}, + {0x9436c0760c86e30b, 0xf9a0b6720aaf6522}, + {0xb94470938fa89bce, 0xf808e40e8d5b3e6a}, + {0xe7958cb87392c2c2, 0xb60b1d1230b20e05}, + {0x90bd77f3483bb9b9, 0xb1c6f22b5e6f48c3}, + {0xb4ecd5f01a4aa828, 0x1e38aeb6360b1af4}, + {0xe2280b6c20dd5232, 0x25c6da63c38de1b1}, + {0x8d590723948a535f, 0x579c487e5a38ad0f}, + {0xb0af48ec79ace837, 0x2d835a9df0c6d852}, + {0xdcdb1b2798182244, 0xf8e431456cf88e66}, + {0x8a08f0f8bf0f156b, 0x1b8e9ecb641b5900}, + {0xac8b2d36eed2dac5, 0xe272467e3d222f40}, + {0xd7adf884aa879177, 0x5b0ed81dcc6abb10}, + {0x86ccbb52ea94baea, 0x98e947129fc2b4ea}, + {0xa87fea27a539e9a5, 0x3f2398d747b36225}, + {0xd29fe4b18e88640e, 0x8eec7f0d19a03aae}, + {0x83a3eeeef9153e89, 0x1953cf68300424ad}, + {0xa48ceaaab75a8e2b, 0x5fa8c3423c052dd8}, + {0xcdb02555653131b6, 0x3792f412cb06794e}, + {0x808e17555f3ebf11, 0xe2bbd88bbee40bd1}, + {0xa0b19d2ab70e6ed6, 0x5b6aceaeae9d0ec5}, + {0xc8de047564d20a8b, 0xf245825a5a445276}, + {0xfb158592be068d2e, 0xeed6e2f0f0d56713}, + {0x9ced737bb6c4183d, 0x55464dd69685606c}, + {0xc428d05aa4751e4c, 0xaa97e14c3c26b887}, + {0xf53304714d9265df, 0xd53dd99f4b3066a9}, + {0x993fe2c6d07b7fab, 0xe546a8038efe402a}, + {0xbf8fdb78849a5f96, 0xde98520472bdd034}, + {0xef73d256a5c0f77c, 0x963e66858f6d4441}, + {0x95a8637627989aad, 0xdde7001379a44aa9}, + {0xbb127c53b17ec159, 0x5560c018580d5d53}, + {0xe9d71b689dde71af, 0xaab8f01e6e10b4a7}, + {0x9226712162ab070d, 0xcab3961304ca70e9}, + {0xb6b00d69bb55c8d1, 0x3d607b97c5fd0d23}, + {0xe45c10c42a2b3b05, 0x8cb89a7db77c506b}, + {0x8eb98a7a9a5b04e3, 0x77f3608e92adb243}, + {0xb267ed1940f1c61c, 0x55f038b237591ed4}, + {0xdf01e85f912e37a3, 0x6b6c46dec52f6689}, + {0x8b61313bbabce2c6, 0x2323ac4b3b3da016}, + {0xae397d8aa96c1b77, 0xabec975e0a0d081b}, + {0xd9c7dced53c72255, 0x96e7bd358c904a22}, + {0x881cea14545c7575, 0x7e50d64177da2e55}, + {0xaa242499697392d2, 0xdde50bd1d5d0b9ea}, + {0xd4ad2dbfc3d07787, 0x955e4ec64b44e865}, + {0x84ec3c97da624ab4, 0xbd5af13bef0b113f}, + {0xa6274bbdd0fadd61, 0xecb1ad8aeacdd58f}, + {0xcfb11ead453994ba, 0x67de18eda5814af3}, + {0x81ceb32c4b43fcf4, 0x80eacf948770ced8}, + {0xa2425ff75e14fc31, 0xa1258379a94d028e}, + {0xcad2f7f5359a3b3e, 0x096ee45813a04331}, + {0xfd87b5f28300ca0d, 0x8bca9d6e188853fd}, + {0x9e74d1b791e07e48, 0x775ea264cf55347e}, + {0xc612062576589dda, 0x95364afe032a819e}, + {0xf79687aed3eec551, 0x3a83ddbd83f52205}, + {0x9abe14cd44753b52, 0xc4926a9672793543}, + {0xc16d9a0095928a27, 
0x75b7053c0f178294}, + {0xf1c90080baf72cb1, 0x5324c68b12dd6339}, + {0x971da05074da7bee, 0xd3f6fc16ebca5e04}, + {0xbce5086492111aea, 0x88f4bb1ca6bcf585}, + {0xec1e4a7db69561a5, 0x2b31e9e3d06c32e6}, + {0x9392ee8e921d5d07, 0x3aff322e62439fd0}, + {0xb877aa3236a4b449, 0x09befeb9fad487c3}, + {0xe69594bec44de15b, 0x4c2ebe687989a9b4}, + {0x901d7cf73ab0acd9, 0x0f9d37014bf60a11}, + {0xb424dc35095cd80f, 0x538484c19ef38c95}, + {0xe12e13424bb40e13, 0x2865a5f206b06fba}, + {0x8cbccc096f5088cb, 0xf93f87b7442e45d4}, + {0xafebff0bcb24aafe, 0xf78f69a51539d749}, + {0xdbe6fecebdedd5be, 0xb573440e5a884d1c}, + {0x89705f4136b4a597, 0x31680a88f8953031}, + {0xabcc77118461cefc, 0xfdc20d2b36ba7c3e}, + {0xd6bf94d5e57a42bc, 0x3d32907604691b4d}, + {0x8637bd05af6c69b5, 0xa63f9a49c2c1b110}, + {0xa7c5ac471b478423, 0x0fcf80dc33721d54}, + {0xd1b71758e219652b, 0xd3c36113404ea4a9}, + {0x83126e978d4fdf3b, 0x645a1cac083126ea}, + {0xa3d70a3d70a3d70a, 0x3d70a3d70a3d70a4}, + {0xcccccccccccccccc, 0xcccccccccccccccd}, + {0x8000000000000000, 0x0000000000000000}, + {0xa000000000000000, 0x0000000000000000}, + {0xc800000000000000, 0x0000000000000000}, + {0xfa00000000000000, 0x0000000000000000}, + {0x9c40000000000000, 0x0000000000000000}, + {0xc350000000000000, 0x0000000000000000}, + {0xf424000000000000, 0x0000000000000000}, + {0x9896800000000000, 0x0000000000000000}, + {0xbebc200000000000, 0x0000000000000000}, + {0xee6b280000000000, 0x0000000000000000}, + {0x9502f90000000000, 0x0000000000000000}, + {0xba43b74000000000, 0x0000000000000000}, + {0xe8d4a51000000000, 0x0000000000000000}, + {0x9184e72a00000000, 0x0000000000000000}, + {0xb5e620f480000000, 0x0000000000000000}, + {0xe35fa931a0000000, 0x0000000000000000}, + {0x8e1bc9bf04000000, 0x0000000000000000}, + {0xb1a2bc2ec5000000, 0x0000000000000000}, + {0xde0b6b3a76400000, 0x0000000000000000}, + {0x8ac7230489e80000, 0x0000000000000000}, + {0xad78ebc5ac620000, 0x0000000000000000}, + {0xd8d726b7177a8000, 0x0000000000000000}, + {0x878678326eac9000, 0x0000000000000000}, + {0xa968163f0a57b400, 0x0000000000000000}, + {0xd3c21bcecceda100, 0x0000000000000000}, + {0x84595161401484a0, 0x0000000000000000}, + {0xa56fa5b99019a5c8, 0x0000000000000000}, + {0xcecb8f27f4200f3a, 0x0000000000000000}, + {0x813f3978f8940984, 0x4000000000000000}, + {0xa18f07d736b90be5, 0x5000000000000000}, + {0xc9f2c9cd04674ede, 0xa400000000000000}, + {0xfc6f7c4045812296, 0x4d00000000000000}, + {0x9dc5ada82b70b59d, 0xf020000000000000}, + {0xc5371912364ce305, 0x6c28000000000000}, + {0xf684df56c3e01bc6, 0xc732000000000000}, + {0x9a130b963a6c115c, 0x3c7f400000000000}, + {0xc097ce7bc90715b3, 0x4b9f100000000000}, + {0xf0bdc21abb48db20, 0x1e86d40000000000}, + {0x96769950b50d88f4, 0x1314448000000000}, + {0xbc143fa4e250eb31, 0x17d955a000000000}, + {0xeb194f8e1ae525fd, 0x5dcfab0800000000}, + {0x92efd1b8d0cf37be, 0x5aa1cae500000000}, + {0xb7abc627050305ad, 0xf14a3d9e40000000}, + {0xe596b7b0c643c719, 0x6d9ccd05d0000000}, + {0x8f7e32ce7bea5c6f, 0xe4820023a2000000}, + {0xb35dbf821ae4f38b, 0xdda2802c8a800000}, + {0xe0352f62a19e306e, 0xd50b2037ad200000}, + {0x8c213d9da502de45, 0x4526f422cc340000}, + {0xaf298d050e4395d6, 0x9670b12b7f410000}, + {0xdaf3f04651d47b4c, 0x3c0cdd765f114000}, + {0x88d8762bf324cd0f, 0xa5880a69fb6ac800}, + {0xab0e93b6efee0053, 0x8eea0d047a457a00}, + {0xd5d238a4abe98068, 0x72a4904598d6d880}, + {0x85a36366eb71f041, 0x47a6da2b7f864750}, + {0xa70c3c40a64e6c51, 0x999090b65f67d924}, + {0xd0cf4b50cfe20765, 0xfff4b4e3f741cf6d}, + {0x82818f1281ed449f, 0xbff8f10e7a8921a5}, + {0xa321f2d7226895c7, 0xaff72d52192b6a0e}, + 
{0xcbea6f8ceb02bb39, 0x9bf4f8a69f764491}, + {0xfee50b7025c36a08, 0x02f236d04753d5b5}, + {0x9f4f2726179a2245, 0x01d762422c946591}, + {0xc722f0ef9d80aad6, 0x424d3ad2b7b97ef6}, + {0xf8ebad2b84e0d58b, 0xd2e0898765a7deb3}, + {0x9b934c3b330c8577, 0x63cc55f49f88eb30}, + {0xc2781f49ffcfa6d5, 0x3cbf6b71c76b25fc}, + {0xf316271c7fc3908a, 0x8bef464e3945ef7b}, + {0x97edd871cfda3a56, 0x97758bf0e3cbb5ad}, + {0xbde94e8e43d0c8ec, 0x3d52eeed1cbea318}, + {0xed63a231d4c4fb27, 0x4ca7aaa863ee4bde}, + {0x945e455f24fb1cf8, 0x8fe8caa93e74ef6b}, + {0xb975d6b6ee39e436, 0xb3e2fd538e122b45}, + {0xe7d34c64a9c85d44, 0x60dbbca87196b617}, + {0x90e40fbeea1d3a4a, 0xbc8955e946fe31ce}, + {0xb51d13aea4a488dd, 0x6babab6398bdbe42}, + {0xe264589a4dcdab14, 0xc696963c7eed2dd2}, + {0x8d7eb76070a08aec, 0xfc1e1de5cf543ca3}, + {0xb0de65388cc8ada8, 0x3b25a55f43294bcc}, + {0xdd15fe86affad912, 0x49ef0eb713f39ebf}, + {0x8a2dbf142dfcc7ab, 0x6e3569326c784338}, + {0xacb92ed9397bf996, 0x49c2c37f07965405}, + {0xd7e77a8f87daf7fb, 0xdc33745ec97be907}, + {0x86f0ac99b4e8dafd, 0x69a028bb3ded71a4}, + {0xa8acd7c0222311bc, 0xc40832ea0d68ce0d}, + {0xd2d80db02aabd62b, 0xf50a3fa490c30191}, + {0x83c7088e1aab65db, 0x792667c6da79e0fb}, + {0xa4b8cab1a1563f52, 0x577001b891185939}, + {0xcde6fd5e09abcf26, 0xed4c0226b55e6f87}, + {0x80b05e5ac60b6178, 0x544f8158315b05b5}, + {0xa0dc75f1778e39d6, 0x696361ae3db1c722}, + {0xc913936dd571c84c, 0x03bc3a19cd1e38ea}, + {0xfb5878494ace3a5f, 0x04ab48a04065c724}, + {0x9d174b2dcec0e47b, 0x62eb0d64283f9c77}, + {0xc45d1df942711d9a, 0x3ba5d0bd324f8395}, + {0xf5746577930d6500, 0xca8f44ec7ee3647a}, + {0x9968bf6abbe85f20, 0x7e998b13cf4e1ecc}, + {0xbfc2ef456ae276e8, 0x9e3fedd8c321a67f}, + {0xefb3ab16c59b14a2, 0xc5cfe94ef3ea101f}, + {0x95d04aee3b80ece5, 0xbba1f1d158724a13}, + {0xbb445da9ca61281f, 0x2a8a6e45ae8edc98}, + {0xea1575143cf97226, 0xf52d09d71a3293be}, + {0x924d692ca61be758, 0x593c2626705f9c57}, + {0xb6e0c377cfa2e12e, 0x6f8b2fb00c77836d}, + {0xe498f455c38b997a, 0x0b6dfb9c0f956448}, + {0x8edf98b59a373fec, 0x4724bd4189bd5ead}, + {0xb2977ee300c50fe7, 0x58edec91ec2cb658}, + {0xdf3d5e9bc0f653e1, 0x2f2967b66737e3ee}, + {0x8b865b215899f46c, 0xbd79e0d20082ee75}, + {0xae67f1e9aec07187, 0xecd8590680a3aa12}, + {0xda01ee641a708de9, 0xe80e6f4820cc9496}, + {0x884134fe908658b2, 0x3109058d147fdcde}, + {0xaa51823e34a7eede, 0xbd4b46f0599fd416}, + {0xd4e5e2cdc1d1ea96, 0x6c9e18ac7007c91b}, + {0x850fadc09923329e, 0x03e2cf6bc604ddb1}, + {0xa6539930bf6bff45, 0x84db8346b786151d}, + {0xcfe87f7cef46ff16, 0xe612641865679a64}, + {0x81f14fae158c5f6e, 0x4fcb7e8f3f60c07f}, + {0xa26da3999aef7749, 0xe3be5e330f38f09e}, + {0xcb090c8001ab551c, 0x5cadf5bfd3072cc6}, + {0xfdcb4fa002162a63, 0x73d9732fc7c8f7f7}, + {0x9e9f11c4014dda7e, 0x2867e7fddcdd9afb}, + {0xc646d63501a1511d, 0xb281e1fd541501b9}, + {0xf7d88bc24209a565, 0x1f225a7ca91a4227}, + {0x9ae757596946075f, 0x3375788de9b06959}, + {0xc1a12d2fc3978937, 0x0052d6b1641c83af}, + {0xf209787bb47d6b84, 0xc0678c5dbd23a49b}, + {0x9745eb4d50ce6332, 0xf840b7ba963646e1}, + {0xbd176620a501fbff, 0xb650e5a93bc3d899}, + {0xec5d3fa8ce427aff, 0xa3e51f138ab4cebf}, + {0x93ba47c980e98cdf, 0xc66f336c36b10138}, + {0xb8a8d9bbe123f017, 0xb80b0047445d4185}, + {0xe6d3102ad96cec1d, 0xa60dc059157491e6}, + {0x9043ea1ac7e41392, 0x87c89837ad68db30}, + {0xb454e4a179dd1877, 0x29babe4598c311fc}, + {0xe16a1dc9d8545e94, 0xf4296dd6fef3d67b}, + {0x8ce2529e2734bb1d, 0x1899e4a65f58660d}, + {0xb01ae745b101e9e4, 0x5ec05dcff72e7f90}, + {0xdc21a1171d42645d, 0x76707543f4fa1f74}, + {0x899504ae72497eba, 0x6a06494a791c53a9}, + {0xabfa45da0edbde69, 
0x0487db9d17636893}, + {0xd6f8d7509292d603, 0x45a9d2845d3c42b7}, + {0x865b86925b9bc5c2, 0x0b8a2392ba45a9b3}, + {0xa7f26836f282b732, 0x8e6cac7768d7141f}, + {0xd1ef0244af2364ff, 0x3207d795430cd927}, + {0x8335616aed761f1f, 0x7f44e6bd49e807b9}, + {0xa402b9c5a8d3a6e7, 0x5f16206c9c6209a7}, + {0xcd036837130890a1, 0x36dba887c37a8c10}, + {0x802221226be55a64, 0xc2494954da2c978a}, + {0xa02aa96b06deb0fd, 0xf2db9baa10b7bd6d}, + {0xc83553c5c8965d3d, 0x6f92829494e5acc8}, + {0xfa42a8b73abbf48c, 0xcb772339ba1f17fa}, + {0x9c69a97284b578d7, 0xff2a760414536efc}, + {0xc38413cf25e2d70d, 0xfef5138519684abb}, + {0xf46518c2ef5b8cd1, 0x7eb258665fc25d6a}, + {0x98bf2f79d5993802, 0xef2f773ffbd97a62}, + {0xbeeefb584aff8603, 0xaafb550ffacfd8fb}, + {0xeeaaba2e5dbf6784, 0x95ba2a53f983cf39}, + {0x952ab45cfa97a0b2, 0xdd945a747bf26184}, + {0xba756174393d88df, 0x94f971119aeef9e5}, + {0xe912b9d1478ceb17, 0x7a37cd5601aab85e}, + {0x91abb422ccb812ee, 0xac62e055c10ab33b}, + {0xb616a12b7fe617aa, 0x577b986b314d600a}, + {0xe39c49765fdf9d94, 0xed5a7e85fda0b80c}, + {0x8e41ade9fbebc27d, 0x14588f13be847308}, + {0xb1d219647ae6b31c, 0x596eb2d8ae258fc9}, + {0xde469fbd99a05fe3, 0x6fca5f8ed9aef3bc}, + {0x8aec23d680043bee, 0x25de7bb9480d5855}, + {0xada72ccc20054ae9, 0xaf561aa79a10ae6b}, + {0xd910f7ff28069da4, 0x1b2ba1518094da05}, + {0x87aa9aff79042286, 0x90fb44d2f05d0843}, + {0xa99541bf57452b28, 0x353a1607ac744a54}, + {0xd3fa922f2d1675f2, 0x42889b8997915ce9}, + {0x847c9b5d7c2e09b7, 0x69956135febada12}, + {0xa59bc234db398c25, 0x43fab9837e699096}, + {0xcf02b2c21207ef2e, 0x94f967e45e03f4bc}, + {0x8161afb94b44f57d, 0x1d1be0eebac278f6}, + {0xa1ba1ba79e1632dc, 0x6462d92a69731733}, + {0xca28a291859bbf93, 0x7d7b8f7503cfdcff}, + {0xfcb2cb35e702af78, 0x5cda735244c3d43f}, + {0x9defbf01b061adab, 0x3a0888136afa64a8}, + {0xc56baec21c7a1916, 0x088aaa1845b8fdd1}, + {0xf6c69a72a3989f5b, 0x8aad549e57273d46}, + {0x9a3c2087a63f6399, 0x36ac54e2f678864c}, + {0xc0cb28a98fcf3c7f, 0x84576a1bb416a7de}, + {0xf0fdf2d3f3c30b9f, 0x656d44a2a11c51d6}, + {0x969eb7c47859e743, 0x9f644ae5a4b1b326}, + {0xbc4665b596706114, 0x873d5d9f0dde1fef}, + {0xeb57ff22fc0c7959, 0xa90cb506d155a7eb}, + {0x9316ff75dd87cbd8, 0x09a7f12442d588f3}, + {0xb7dcbf5354e9bece, 0x0c11ed6d538aeb30}, + {0xe5d3ef282a242e81, 0x8f1668c8a86da5fb}, + {0x8fa475791a569d10, 0xf96e017d694487bd}, + {0xb38d92d760ec4455, 0x37c981dcc395a9ad}, + {0xe070f78d3927556a, 0x85bbe253f47b1418}, + {0x8c469ab843b89562, 0x93956d7478ccec8f}, + {0xaf58416654a6babb, 0x387ac8d1970027b3}, + {0xdb2e51bfe9d0696a, 0x06997b05fcc0319f}, + {0x88fcf317f22241e2, 0x441fece3bdf81f04}, + {0xab3c2fddeeaad25a, 0xd527e81cad7626c4}, + {0xd60b3bd56a5586f1, 0x8a71e223d8d3b075}, + {0x85c7056562757456, 0xf6872d5667844e4a}, + {0xa738c6bebb12d16c, 0xb428f8ac016561dc}, + {0xd106f86e69d785c7, 0xe13336d701beba53}, + {0x82a45b450226b39c, 0xecc0024661173474}, + {0xa34d721642b06084, 0x27f002d7f95d0191}, + {0xcc20ce9bd35c78a5, 0x31ec038df7b441f5}, + {0xff290242c83396ce, 0x7e67047175a15272}, + {0x9f79a169bd203e41, 0x0f0062c6e984d387}, + {0xc75809c42c684dd1, 0x52c07b78a3e60869}, + {0xf92e0c3537826145, 0xa7709a56ccdf8a83}, + {0x9bbcc7a142b17ccb, 0x88a66076400bb692}, + {0xc2abf989935ddbfe, 0x6acff893d00ea436}, + {0xf356f7ebf83552fe, 0x0583f6b8c4124d44}, + {0x98165af37b2153de, 0xc3727a337a8b704b}, + {0xbe1bf1b059e9a8d6, 0x744f18c0592e4c5d}, + {0xeda2ee1c7064130c, 0x1162def06f79df74}, + {0x9485d4d1c63e8be7, 0x8addcb5645ac2ba9}, + {0xb9a74a0637ce2ee1, 0x6d953e2bd7173693}, + {0xe8111c87c5c1ba99, 0xc8fa8db6ccdd0438}, + {0x910ab1d4db9914a0, 0x1d9c9892400a22a3}, + 
{0xb54d5e4a127f59c8, 0x2503beb6d00cab4c}, + {0xe2a0b5dc971f303a, 0x2e44ae64840fd61e}, + {0x8da471a9de737e24, 0x5ceaecfed289e5d3}, + {0xb10d8e1456105dad, 0x7425a83e872c5f48}, + {0xdd50f1996b947518, 0xd12f124e28f7771a}, + {0x8a5296ffe33cc92f, 0x82bd6b70d99aaa70}, + {0xace73cbfdc0bfb7b, 0x636cc64d1001550c}, + {0xd8210befd30efa5a, 0x3c47f7e05401aa4f}, + {0x8714a775e3e95c78, 0x65acfaec34810a72}, + {0xa8d9d1535ce3b396, 0x7f1839a741a14d0e}, + {0xd31045a8341ca07c, 0x1ede48111209a051}, + {0x83ea2b892091e44d, 0x934aed0aab460433}, + {0xa4e4b66b68b65d60, 0xf81da84d56178540}, + {0xce1de40642e3f4b9, 0x36251260ab9d668f}, + {0x80d2ae83e9ce78f3, 0xc1d72b7c6b42601a}, + {0xa1075a24e4421730, 0xb24cf65b8612f820}, + {0xc94930ae1d529cfc, 0xdee033f26797b628}, + {0xfb9b7cd9a4a7443c, 0x169840ef017da3b2}, + {0x9d412e0806e88aa5, 0x8e1f289560ee864f}, + {0xc491798a08a2ad4e, 0xf1a6f2bab92a27e3}, + {0xf5b5d7ec8acb58a2, 0xae10af696774b1dc}, + {0x9991a6f3d6bf1765, 0xacca6da1e0a8ef2a}, + {0xbff610b0cc6edd3f, 0x17fd090a58d32af4}, + {0xeff394dcff8a948e, 0xddfc4b4cef07f5b1}, + {0x95f83d0a1fb69cd9, 0x4abdaf101564f98f}, + {0xbb764c4ca7a4440f, 0x9d6d1ad41abe37f2}, + {0xea53df5fd18d5513, 0x84c86189216dc5ee}, + {0x92746b9be2f8552c, 0x32fd3cf5b4e49bb5}, + {0xb7118682dbb66a77, 0x3fbc8c33221dc2a2}, + {0xe4d5e82392a40515, 0x0fabaf3feaa5334b}, + {0x8f05b1163ba6832d, 0x29cb4d87f2a7400f}, + {0xb2c71d5bca9023f8, 0x743e20e9ef511013}, + {0xdf78e4b2bd342cf6, 0x914da9246b255417}, + {0x8bab8eefb6409c1a, 0x1ad089b6c2f7548f}, + {0xae9672aba3d0c320, 0xa184ac2473b529b2}, + {0xda3c0f568cc4f3e8, 0xc9e5d72d90a2741f}, + {0x8865899617fb1871, 0x7e2fa67c7a658893}, + {0xaa7eebfb9df9de8d, 0xddbb901b98feeab8}, + {0xd51ea6fa85785631, 0x552a74227f3ea566}, + {0x8533285c936b35de, 0xd53a88958f872760}, + {0xa67ff273b8460356, 0x8a892abaf368f138}, + {0xd01fef10a657842c, 0x2d2b7569b0432d86}, + {0x8213f56a67f6b29b, 0x9c3b29620e29fc74}, + {0xa298f2c501f45f42, 0x8349f3ba91b47b90}, + {0xcb3f2f7642717713, 0x241c70a936219a74}, + {0xfe0efb53d30dd4d7, 0xed238cd383aa0111}, + {0x9ec95d1463e8a506, 0xf4363804324a40ab}, + {0xc67bb4597ce2ce48, 0xb143c6053edcd0d6}, + {0xf81aa16fdc1b81da, 0xdd94b7868e94050b}, + {0x9b10a4e5e9913128, 0xca7cf2b4191c8327}, + {0xc1d4ce1f63f57d72, 0xfd1c2f611f63a3f1}, + {0xf24a01a73cf2dccf, 0xbc633b39673c8ced}, + {0x976e41088617ca01, 0xd5be0503e085d814}, + {0xbd49d14aa79dbc82, 0x4b2d8644d8a74e19}, + {0xec9c459d51852ba2, 0xddf8e7d60ed1219f}, + {0x93e1ab8252f33b45, 0xcabb90e5c942b504}, + {0xb8da1662e7b00a17, 0x3d6a751f3b936244}, + {0xe7109bfba19c0c9d, 0x0cc512670a783ad5}, + {0x906a617d450187e2, 0x27fb2b80668b24c6}, + {0xb484f9dc9641e9da, 0xb1f9f660802dedf7}, + {0xe1a63853bbd26451, 0x5e7873f8a0396974}, + {0x8d07e33455637eb2, 0xdb0b487b6423e1e9}, + {0xb049dc016abc5e5f, 0x91ce1a9a3d2cda63}, + {0xdc5c5301c56b75f7, 0x7641a140cc7810fc}, + {0x89b9b3e11b6329ba, 0xa9e904c87fcb0a9e}, + {0xac2820d9623bf429, 0x546345fa9fbdcd45}, + {0xd732290fbacaf133, 0xa97c177947ad4096}, + {0x867f59a9d4bed6c0, 0x49ed8eabcccc485e}, + {0xa81f301449ee8c70, 0x5c68f256bfff5a75}, + {0xd226fc195c6a2f8c, 0x73832eec6fff3112}, + {0x83585d8fd9c25db7, 0xc831fd53c5ff7eac}, + {0xa42e74f3d032f525, 0xba3e7ca8b77f5e56}, + {0xcd3a1230c43fb26f, 0x28ce1bd2e55f35ec}, + {0x80444b5e7aa7cf85, 0x7980d163cf5b81b4}, + {0xa0555e361951c366, 0xd7e105bcc3326220}, + {0xc86ab5c39fa63440, 0x8dd9472bf3fefaa8}, + {0xfa856334878fc150, 0xb14f98f6f0feb952}, + {0x9c935e00d4b9d8d2, 0x6ed1bf9a569f33d4}, + {0xc3b8358109e84f07, 0x0a862f80ec4700c9}, + {0xf4a642e14c6262c8, 0xcd27bb612758c0fb}, + {0x98e7e9cccfbd7dbd, 
0x8038d51cb897789d},
+        {0xbf21e44003acdd2c, 0xe0470a63e6bd56c4},
+        {0xeeea5d5004981478, 0x1858ccfce06cac75},
+        {0x95527a5202df0ccb, 0x0f37801e0c43ebc9},
+        {0xbaa718e68396cffd, 0xd30560258f54e6bb},
+        {0xe950df20247c83fd, 0x47c6b82ef32a206a},
+        {0x91d28b7416cdd27e, 0x4cdc331d57fa5442},
+        {0xb6472e511c81471d, 0xe0133fe4adf8e953},
+        {0xe3d8f9e563a198e5, 0x58180fddd97723a7},
+        {0x8e679c2f5e44ff8f, 0x570f09eaa7ea7649},
+        {0xb201833b35d63f73, 0x2cd2cc6551e513db},
+        {0xde81e40a034bcf4f, 0xf8077f7ea65e58d2},
+        {0x8b112e86420f6191, 0xfb04afaf27faf783},
+        {0xadd57a27d29339f6, 0x79c5db9af1f9b564},
+        {0xd94ad8b1c7380874, 0x18375281ae7822bd},
+        {0x87cec76f1c830548, 0x8f2293910d0b15b6},
+        {0xa9c2794ae3a3c69a, 0xb2eb3875504ddb23},
+        {0xd433179d9c8cb841, 0x5fa60692a46151ec},
+        {0x849feec281d7f328, 0xdbc7c41ba6bcd334},
+        {0xa5c7ea73224deff3, 0x12b9b522906c0801},
+        {0xcf39e50feae16bef, 0xd768226b34870a01},
+        {0x81842f29f2cce375, 0xe6a1158300d46641},
+        {0xa1e53af46f801c53, 0x60495ae3c1097fd1},
+        {0xca5e89b18b602368, 0x385bb19cb14bdfc5},
+        {0xfcf62c1dee382c42, 0x46729e03dd9ed7b6},
+        {0x9e19db92b4e31ba9, 0x6c07a2c26a8346d2},
+        {0xc5a05277621be293, 0xc7098b7305241886},
+        {0xf70867153aa2db38, 0xb8cbee4fc66d1ea8},
+        {0x9a65406d44a5c903, 0x737f74f1dc043329},
+        {0xc0fe908895cf3b44, 0x505f522e53053ff3},
+        {0xf13e34aabb430a15, 0x647726b9e7c68ff0},
+        {0x96c6e0eab509e64d, 0x5eca783430dc19f6},
+        {0xbc789925624c5fe0, 0xb67d16413d132073},
+        {0xeb96bf6ebadf77d8, 0xe41c5bd18c57e890},
+        {0x933e37a534cbaae7, 0x8e91b962f7b6f15a},
+        {0xb80dc58e81fe95a1, 0x723627bbb5a4adb1},
+        {0xe61136f2227e3b09, 0xcec3b1aaa30dd91d},
+        {0x8fcac257558ee4e6, 0x213a4f0aa5e8a7b2},
+        {0xb3bd72ed2af29e1f, 0xa988e2cd4f62d19e},
+        {0xe0accfa875af45a7, 0x93eb1b80a33b8606},
+        {0x8c6c01c9498d8b88, 0xbc72f130660533c4},
+        {0xaf87023b9bf0ee6a, 0xeb8fad7c7f8680b5},
+        { 0xdb68c2ca82ed2a05,
+          0xa67398db9f6820e2 }
+#else
+        {0xff77b1fcbebcdc4f, 0x25e8e89c13bb0f7b},
+        {0xce5d73ff402d98e3, 0xfb0a3d212dc81290},
+        {0xa6b34ad8c9dfc06f, 0xf42faa48c0ea481f},
+        {0x86a8d39ef77164bc, 0xae5dff9c02033198},
+        {0xd98ddaee19068c76, 0x3badd624dd9b0958},
+        {0xafbd2350644eeacf, 0xe5d1929ef90898fb},
+        {0x8df5efabc5979c8f, 0xca8d3ffa1ef463c2},
+        {0xe55990879ddcaabd, 0xcc420a6a101d0516},
+        {0xb94470938fa89bce, 0xf808e40e8d5b3e6a},
+        {0x95a8637627989aad, 0xdde7001379a44aa9},
+        {0xf1c90080baf72cb1, 0x5324c68b12dd6339},
+        {0xc350000000000000, 0x0000000000000000},
+        {0x9dc5ada82b70b59d, 0xf020000000000000},
+        {0xfee50b7025c36a08, 0x02f236d04753d5b5},
+        {0xcde6fd5e09abcf26, 0xed4c0226b55e6f87},
+        {0xa6539930bf6bff45, 0x84db8346b786151d},
+        {0x865b86925b9bc5c2, 0x0b8a2392ba45a9b3},
+        {0xd910f7ff28069da4, 0x1b2ba1518094da05},
+        {0xaf58416654a6babb, 0x387ac8d1970027b3},
+        {0x8da471a9de737e24, 0x5ceaecfed289e5d3},
+        {0xe4d5e82392a40515, 0x0fabaf3feaa5334b},
+        {0xb8da1662e7b00a17, 0x3d6a751f3b936244},
+        {0x95527a5202df0ccb, 0x0f37801e0c43ebc9},
+        {0xf13e34aabb430a15, 0x647726b9e7c68ff0}
+#endif
+    };
+
+#if FMT_USE_FULL_CACHE_DRAGONBOX
+    return pow10_significands[k - float_info<double>::min_k];
+#else
+    static constexpr const uint64_t powers_of_5_64[] = {
+        0x0000000000000001, 0x0000000000000005, 0x0000000000000019,
+        0x000000000000007d, 0x0000000000000271, 0x0000000000000c35,
+        0x0000000000003d09, 0x000000000001312d, 0x000000000005f5e1,
+        0x00000000001dcd65, 0x00000000009502f9, 0x0000000002e90edd,
+        0x000000000e8d4a51, 0x0000000048c27395, 0x000000016bcc41e9,
+        0x000000071afd498d, 0x0000002386f26fc1, 0x000000b1a2bc2ec5,
+        0x000003782dace9d9, 0x00001158e460913d, 0x000056bc75e2d631,
+        0x0001b1ae4d6e2ef5,
0x000878678326eac9, 0x002a5a058fc295ed,
+        0x00d3c21bcecceda1, 0x0422ca8b0a00a425, 0x14adf4b7320334b9};
+
+    static const int compression_ratio = 27;
+
+    // Compute base index.
+    int cache_index = (k - float_info<double>::min_k) / compression_ratio;
+    int kb = cache_index * compression_ratio + float_info<double>::min_k;
+    int offset = k - kb;
+
+    // Get base cache.
+    uint128_fallback base_cache = pow10_significands[cache_index];
+    if (offset == 0) return base_cache;
+
+    // Compute the required amount of bit-shift.
+    int alpha = floor_log2_pow10(kb + offset) - floor_log2_pow10(kb) - offset;
+    FMT_ASSERT(alpha > 0 && alpha < 64, "shifting error detected");
+
+    // Try to recover the real cache.
+    uint64_t pow5 = powers_of_5_64[offset];
+    uint128_fallback recovered_cache = umul128(base_cache.high(), pow5);
+    uint128_fallback middle_low = umul128(base_cache.low(), pow5);
+
+    recovered_cache += middle_low.high();
+
+    uint64_t high_to_middle = recovered_cache.high() << (64 - alpha);
+    uint64_t middle_to_low = recovered_cache.low() << (64 - alpha);
+
+    recovered_cache =
+        uint128_fallback{(recovered_cache.low() >> alpha) | high_to_middle,
+                         ((middle_low.low() >> alpha) | middle_to_low)};
+    FMT_ASSERT(recovered_cache.low() + 1 != 0, "");
+    return {recovered_cache.high(), recovered_cache.low() + 1};
+#endif
+  }
+
+  struct compute_mul_result {
+    carrier_uint result;
+    bool is_integer;
+  };
+  struct compute_mul_parity_result {
+    bool parity;
+    bool is_integer;
+  };
+
+  static compute_mul_result compute_mul(
+      carrier_uint u, const cache_entry_type& cache) noexcept {
+    auto r = umul192_upper128(u, cache);
+    return {r.high(), r.low() == 0};
+  }
+
+  static uint32_t compute_delta(cache_entry_type const& cache,
+                                int beta) noexcept {
+    return static_cast<uint32_t>(cache.high() >> (64 - 1 - beta));
+  }
+
+  static compute_mul_parity_result compute_mul_parity(
+      carrier_uint two_f, const cache_entry_type& cache, int beta) noexcept {
+    FMT_ASSERT(beta >= 1, "");
+    FMT_ASSERT(beta < 64, "");
+
+    auto r = umul192_lower128(two_f, cache);
+    return {((r.high() >> (64 - beta)) & 1) != 0,
+            ((r.high() << beta) | (r.low() >> (64 - beta))) == 0};
+  }
+
+  static carrier_uint compute_left_endpoint_for_shorter_interval_case(
+      const cache_entry_type& cache, int beta) noexcept {
+    return (cache.high() -
+            (cache.high() >> (num_significand_bits<double>() + 2))) >>
+           (64 - num_significand_bits<double>() - 1 - beta);
+  }
+
+  static carrier_uint compute_right_endpoint_for_shorter_interval_case(
+      const cache_entry_type& cache, int beta) noexcept {
+    return (cache.high() +
+            (cache.high() >> (num_significand_bits<double>() + 1))) >>
+           (64 - num_significand_bits<double>() - 1 - beta);
+  }
+
+  static carrier_uint compute_round_up_for_shorter_interval_case(
+      const cache_entry_type& cache, int beta) noexcept {
+    return ((cache.high() >>
+             (64 - num_significand_bits<double>() - 2 - beta)) +
+            1) /
+           2;
+  }
+};
+
+FMT_FUNC uint128_fallback get_cached_power(int k) noexcept {
+  return cache_accessor<double>::get_cached_power(k);
+}
+
+// Various integer checks
+template <typename T>
+bool is_left_endpoint_integer_shorter_interval(int exponent) noexcept {
+  const int case_shorter_interval_left_endpoint_lower_threshold = 2;
+  const int case_shorter_interval_left_endpoint_upper_threshold = 3;
+  return exponent >= case_shorter_interval_left_endpoint_lower_threshold &&
+         exponent <= case_shorter_interval_left_endpoint_upper_threshold;
+}
+
+// Remove trailing zeros from n and return the number of zeros removed (float)
+FMT_INLINE int remove_trailing_zeros(uint32_t& n, int s = 0) noexcept {
+  FMT_ASSERT(n != 0, "");
+  // Modular inverse of 5 (mod 2^32): (mod_inv_5 * 5) mod 2^32 = 1.
+  constexpr uint32_t mod_inv_5 = 0xcccccccd;
+  constexpr uint32_t mod_inv_25 = 0xc28f5c29;  // = mod_inv_5 * mod_inv_5
+
+  while (true) {
+    auto q = rotr(n * mod_inv_25, 2);
+    if (q > max_value<uint32_t>() / 100) break;
+    n = q;
+    s += 2;
+  }
+  auto q = rotr(n * mod_inv_5, 1);
+  if (q <= max_value<uint32_t>() / 10) {
+    n = q;
+    s |= 1;
+  }
+  return s;
+}
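The rotate trick above relies on a known property: multiplying by the modular inverse of 5 and rotating right by one maps multiples of 10 into [0, floor(2^32/10)] (yielding the exact quotient) and everything else above that bound. A minimal standalone check, with a local rotr32 helper standing in for the header's rotr (the 64-bit path works the same way, after first peeling off a 10^8 factor):

#include <cassert>
#include <cstdint>

static uint32_t rotr32(uint32_t x, int r) {  // r in (0, 32)
  return (x >> r) | (x << (32 - r));
}

int main() {
  constexpr uint32_t mod_inv_5 = 0xcccccccd;
  assert(mod_inv_5 * 5u == 1u);  // inverse of 5 modulo 2^32
  for (uint32_t n : {10u, 250u, 123450u, 7u, 123u}) {
    uint32_t q = rotr32(n * mod_inv_5, 1);
    if (n % 10 == 0)
      assert(q == n / 10 && q <= UINT32_MAX / 10);  // divisible: quotient
    else
      assert(q > UINT32_MAX / 10);  // not divisible: pushed out of range
  }
  return 0;
}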
+
+// Removes trailing zeros and returns the number of zeros removed (double)
+FMT_INLINE int remove_trailing_zeros(uint64_t& n) noexcept {
+  FMT_ASSERT(n != 0, "");
+
+  // This magic number is ceil(2^90 / 10^8).
+  constexpr uint64_t magic_number = 12379400392853802749ull;
+  auto nm = umul128(n, magic_number);
+
+  // Is n divisible by 10^8?
+  if ((nm.high() & ((1ull << (90 - 64)) - 1)) == 0 && nm.low() < magic_number) {
+    // If yes, work with the quotient...
+    auto n32 = static_cast<uint32_t>(nm.high() >> (90 - 64));
+    // ... and use the 32 bit variant of the function
+    int s = remove_trailing_zeros(n32, 8);
+    n = n32;
+    return s;
+  }
+
+  // If n is not divisible by 10^8, work with n itself.
+  constexpr uint64_t mod_inv_5 = 0xcccccccccccccccd;
+  constexpr uint64_t mod_inv_25 = 0x8f5c28f5c28f5c29;  // = mod_inv_5^2
+
+  int s = 0;
+  while (true) {
+    auto q = rotr(n * mod_inv_25, 2);
+    if (q > max_value<uint64_t>() / 100) break;
+    n = q;
+    s += 2;
+  }
+  auto q = rotr(n * mod_inv_5, 1);
+  if (q <= max_value<uint64_t>() / 10) {
+    n = q;
+    s |= 1;
+  }
+
+  return s;
+}
+
+// The main algorithm for shorter interval case
+template <typename T>
+FMT_INLINE decimal_fp<T> shorter_interval_case(int exponent) noexcept {
+  decimal_fp<T> ret_value;
+  // Compute k and beta
+  const int minus_k = floor_log10_pow2_minus_log10_4_over_3(exponent);
+  const int beta = exponent + floor_log2_pow10(-minus_k);
+
+  // Compute xi and zi
+  using cache_entry_type = typename cache_accessor<T>::cache_entry_type;
+  const cache_entry_type cache = cache_accessor<T>::get_cached_power(-minus_k);
+
+  auto xi = cache_accessor<T>::compute_left_endpoint_for_shorter_interval_case(
+      cache, beta);
+  auto zi = cache_accessor<T>::compute_right_endpoint_for_shorter_interval_case(
+      cache, beta);
+
+  // If the left endpoint is not an integer, increase it
+  if (!is_left_endpoint_integer_shorter_interval<T>(exponent)) ++xi;
+
+  // Try the bigger divisor
+  ret_value.significand = zi / 10;
+
+  // If that succeeds, remove trailing zeros if necessary and return
+  if (ret_value.significand * 10 >= xi) {
+    ret_value.exponent = minus_k + 1;
+    ret_value.exponent += remove_trailing_zeros(ret_value.significand);
+    return ret_value;
+  }
+
+  // Otherwise, compute the round-up of y
+  ret_value.significand =
+      cache_accessor<T>::compute_round_up_for_shorter_interval_case(cache,
+                                                                    beta);
+  ret_value.exponent = minus_k;
+
+  // When a tie occurs, choose one of them according to the rule
+  if (exponent >= float_info<T>::shorter_interval_tie_lower_threshold &&
+      exponent <= float_info<T>::shorter_interval_tie_upper_threshold) {
+    ret_value.significand = ret_value.significand % 2 == 0
+                                ? ret_value.significand
+                                : ret_value.significand - 1;
+  } else if (ret_value.significand < xi) {
+    ++ret_value.significand;
+  }
+  return ret_value;
+}
+
+template <typename T> decimal_fp<T> to_decimal(T x) noexcept {
+  // Step 1: integer promotion & Schubfach multiplier calculation.
+
+  using carrier_uint = typename float_info<T>::carrier_uint;
+  using cache_entry_type = typename cache_accessor<T>::cache_entry_type;
+  auto br = bit_cast<carrier_uint>(x);
+
+  // Extract significand bits and exponent bits.
+  const carrier_uint significand_mask =
+      (static_cast<carrier_uint>(1) << num_significand_bits<T>()) - 1;
+  carrier_uint significand = (br & significand_mask);
+  int exponent =
+      static_cast<int>((br & exponent_mask<T>()) >> num_significand_bits<T>());
+
+  if (exponent != 0) {  // Check if normal.
+    exponent -= exponent_bias<T>() + num_significand_bits<T>();
+
+    // Shorter interval case; proceed like Schubfach.
+    // In fact, when exponent == 1 and significand == 0, the interval is
+    // regular. However, it can be shown that the end results are the same
+    // either way.
+    if (significand == 0) return shorter_interval_case<T>(exponent);
+
+    significand |= (static_cast<carrier_uint>(1) << num_significand_bits<T>());
+  } else {
+    // Subnormal case; the interval is always regular.
+    if (significand == 0) return {0, 0};
+    exponent =
+        std::numeric_limits<T>::min_exponent - num_significand_bits<T>() - 1;
+  }
+
+  const bool include_left_endpoint = (significand % 2 == 0);
+  const bool include_right_endpoint = include_left_endpoint;
+
+  // Compute k and beta.
+  const int minus_k = floor_log10_pow2(exponent) - float_info<T>::kappa;
+  const cache_entry_type cache = cache_accessor<T>::get_cached_power(-minus_k);
+  const int beta = exponent + floor_log2_pow10(-minus_k);
+
+  // Compute zi and deltai.
+  // 10^kappa <= deltai < 10^(kappa + 1)
+  const uint32_t deltai = cache_accessor<T>::compute_delta(cache, beta);
+  const carrier_uint two_fc = significand << 1;
+
+  // For the case of binary32, the result of the integer check is not correct
+  // for
+  //   29711844 * 2^-82
+  // = 6.1442653300000000008655037797566933477355632930994033813476... * 10^-18
+  // and
+  //   29711844 * 2^-81
+  // = 1.2288530660000000001731007559513386695471126586198806762695... * 10^-17,
+  // and they are the unique counterexamples. However, since 29711844 is even,
+  // this does not cause any problem for the endpoint calculations; it can only
+  // cause a problem when we need to perform an integer check for the center.
+  // Fortunately, with these inputs, that branch is never executed, so we are
+  // fine.
+  const typename cache_accessor<T>::compute_mul_result z_mul =
+      cache_accessor<T>::compute_mul((two_fc | 1) << beta, cache);
+
+  // Step 2: Try the larger divisor; remove trailing zeros if necessary.
+
+  // Using an upper bound on zi, we might be able to optimize the division
+  // better than the compiler; we are computing zi / big_divisor here.
+  decimal_fp<T> ret_value;
+  ret_value.significand = divide_by_10_to_kappa_plus_1(z_mul.result);
+  uint32_t r = static_cast<uint32_t>(z_mul.result -
+                                     float_info<T>::big_divisor *
+                                         ret_value.significand);
+
+  if (r < deltai) {
+    // Exclude the right endpoint if necessary.
+    if (r == 0 && (z_mul.is_integer & !include_right_endpoint)) {
+      --ret_value.significand;
+      r = float_info<T>::big_divisor;
+      goto small_divisor_case_label;
+    }
+  } else if (r > deltai) {
+    goto small_divisor_case_label;
+  } else {
+    // r == deltai; compare fractional parts.
+    const typename cache_accessor<T>::compute_mul_parity_result x_mul =
+        cache_accessor<T>::compute_mul_parity(two_fc - 1, cache, beta);
+
+    if (!(x_mul.parity | (x_mul.is_integer & include_left_endpoint)))
+      goto small_divisor_case_label;
+  }
+  ret_value.exponent = minus_k + float_info<T>::kappa + 1;
+
+  // We may need to remove trailing zeros.
+  ret_value.exponent += remove_trailing_zeros(ret_value.significand);
+  return ret_value;
+
+  // Step 3: Find the significand with the smaller divisor.
+
+small_divisor_case_label:
+  ret_value.significand *= 10;
+  ret_value.exponent = minus_k + float_info<T>::kappa;
+
+  uint32_t dist = r - (deltai / 2) + (float_info<T>::small_divisor / 2);
+  const bool approx_y_parity =
+      ((dist ^ (float_info<T>::small_divisor / 2)) & 1) != 0;
+
+  // Is dist divisible by 10^kappa?
+  const bool divisible_by_small_divisor =
+      check_divisibility_and_divide_by_pow10<float_info<T>::kappa>(dist);
+
+  // Add dist / 10^kappa to the significand.
+  ret_value.significand += dist;
+
+  if (!divisible_by_small_divisor) return ret_value;
+
+  // Check z^(f) >= epsilon^(f).
+  // We have either yi == zi - epsiloni or yi == (zi - epsiloni) - 1,
+  // where yi == zi - epsiloni if and only if z^(f) >= epsilon^(f).
+  // Since there are only 2 possibilities, we only need to care about the
+  // parity. Also, zi and r should have the same parity since the divisor
+  // is an even number.
+  const auto y_mul = cache_accessor<T>::compute_mul_parity(two_fc, cache, beta);
+
+  // If z^(f) >= epsilon^(f), we might have a tie when z^(f) == epsilon^(f),
+  // or equivalently, when y is an integer.
+  if (y_mul.parity != approx_y_parity)
+    --ret_value.significand;
+  else if (y_mul.is_integer & (ret_value.significand % 2 != 0))
+    --ret_value.significand;
+  return ret_value;
+}
+}  // namespace dragonbox
+}  // namespace detail
+
+template <> struct formatter<detail::bigint> {
+  FMT_CONSTEXPR auto parse(format_parse_context& ctx)
+      -> format_parse_context::iterator {
+    return ctx.begin();
+  }
+
+  auto format(const detail::bigint& n, format_context& ctx) const
+      -> format_context::iterator {
+    auto out = ctx.out();
+    bool first = true;
+    for (auto i = n.bigits_.size(); i > 0; --i) {
+      auto value = n.bigits_[i - 1u];
+      if (first) {
+        out = format_to(out, FMT_STRING("{:x}"), value);
+        first = false;
+        continue;
+      }
+      out = format_to(out, FMT_STRING("{:08x}"), value);
+    }
+    if (n.exp_ > 0)
+      out = format_to(out, FMT_STRING("p{}"),
+                      n.exp_ * detail::bigint::bigit_bits);
+    return out;
+  }
+};
+
+FMT_FUNC detail::utf8_to_utf16::utf8_to_utf16(string_view s) {
+  for_each_codepoint(s, [this](uint32_t cp, string_view) {
+    if (cp == invalid_code_point) FMT_THROW(std::runtime_error("invalid utf8"));
+    if (cp <= 0xFFFF) {
+      buffer_.push_back(static_cast<wchar_t>(cp));
+    } else {
+      cp -= 0x10000;
+      buffer_.push_back(static_cast<wchar_t>(0xD800 + (cp >> 10)));
+      buffer_.push_back(static_cast<wchar_t>(0xDC00 + (cp & 0x3FF)));
+    }
+    return true;
+  });
+  buffer_.push_back(0);
+}
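The surrogate-pair arithmetic in utf8_to_utf16 above follows directly from the UTF-16 encoding rules: subtract 0x10000, then split the 20-bit offset into a high and a low 10-bit half. A minimal standalone check (not part of the vendored sources):

#include <cassert>
#include <cstdint>

int main() {
  // Encode U+1F600 (above the BMP) as a UTF-16 surrogate pair.
  uint32_t cp = 0x1F600;
  cp -= 0x10000;  // 20-bit offset: 0xF600
  uint16_t high = static_cast<uint16_t>(0xD800 + (cp >> 10));   // 0xD83D
  uint16_t low = static_cast<uint16_t>(0xDC00 + (cp & 0x3FF));  // 0xDE00
  assert(high == 0xD83D && low == 0xDE00);
  // Decoding inverts the steps.
  uint32_t decoded =
      0x10000 + ((static_cast<uint32_t>(high - 0xD800) << 10) | (low - 0xDC00));
  assert(decoded == 0x1F600);
  return 0;
}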
+
+FMT_FUNC void format_system_error(detail::buffer<char>& out, int error_code,
+                                  const char* message) noexcept {
+  FMT_TRY {
+    auto ec = std::error_code(error_code, std::generic_category());
+    write(std::back_inserter(out), std::system_error(ec, message).what());
+    return;
+  }
+  FMT_CATCH(...) {}
+  format_error_code(out, error_code, message);
+}
+
+FMT_FUNC void report_system_error(int error_code,
+                                  const char* message) noexcept {
+  report_error(format_system_error, error_code, message);
+}
+
+FMT_FUNC std::string vformat(string_view fmt, format_args args) {
+  // Don't optimize the "{}" case to keep the binary size small and because it
+  // can be better optimized in fmt::format anyway.
+  auto buffer = memory_buffer();
+  detail::vformat_to(buffer, fmt, args);
+  return to_string(buffer);
+}
+
+namespace detail {
+#ifndef _WIN32
+FMT_FUNC bool write_console(std::FILE*, string_view) { return false; }
+#else
+using dword = conditional_t<sizeof(long) == 4, unsigned long, unsigned>;
+extern "C" __declspec(dllimport) int __stdcall WriteConsoleW(  //
+    void*, const void*, dword, dword*, void*);
+
+FMT_FUNC bool write_console(std::FILE* f, string_view text) {
+  auto fd = _fileno(f);
+  if (!_isatty(fd)) return false;
+  auto u16 = utf8_to_utf16(text);
+  auto written = dword();
+  return WriteConsoleW(reinterpret_cast<void*>(_get_osfhandle(fd)),
+                       u16.c_str(), static_cast<dword>(u16.size()), &written,
+                       nullptr) != 0;
+}
+
+// Print assuming legacy (non-Unicode) encoding.
+FMT_FUNC void vprint_mojibake(std::FILE* f, string_view fmt, format_args args) {
+  auto buffer = memory_buffer();
+  detail::vformat_to(buffer, fmt,
+                     basic_format_args<buffer_context<char>>(args));
+  fwrite_fully(buffer.data(), 1, buffer.size(), f);
+}
+#endif
+
+FMT_FUNC void print(std::FILE* f, string_view text) {
+  if (!write_console(f, text)) fwrite_fully(text.data(), 1, text.size(), f);
+}
+}  // namespace detail
+
+FMT_FUNC void vprint(std::FILE* f, string_view fmt, format_args args) {
+  auto buffer = memory_buffer();
+  detail::vformat_to(buffer, fmt, args);
+  detail::print(f, {buffer.data(), buffer.size()});
+}
+
+FMT_FUNC void vprint(string_view fmt, format_args args) {
+  vprint(stdout, fmt, args);
+}
+
+namespace detail {
+
+struct singleton {
+  unsigned char upper;
+  unsigned char lower_count;
+};
+
+inline auto is_printable(uint16_t x, const singleton* singletons,
+                         size_t singletons_size,
+                         const unsigned char* singleton_lowers,
+                         const unsigned char* normal, size_t normal_size)
+    -> bool {
+  auto upper = x >> 8;
+  auto lower_start = 0;
+  for (size_t i = 0; i < singletons_size; ++i) {
+    auto s = singletons[i];
+    auto lower_end = lower_start + s.lower_count;
+    if (upper < s.upper) break;
+    if (upper == s.upper) {
+      for (auto j = lower_start; j < lower_end; ++j) {
+        if (singleton_lowers[j] == (x & 0xff)) return false;
+      }
+    }
+    lower_start = lower_end;
+  }
+
+  auto xsigned = static_cast<int>(x);
+  auto current = true;
+  for (size_t i = 0; i < normal_size; ++i) {
+    auto v = static_cast<int>(normal[i]);
+    auto len = (v & 0x80) != 0 ? (v & 0x7f) << 8 | normal[++i] : v;
+    xsigned -= len;
+    if (xsigned < 0) break;
+    current = !current;
+  }
+  return current;
+}
+
+// This code is generated by support/printable.py.
+FMT_FUNC auto is_printable(uint32_t cp) -> bool { + static constexpr singleton singletons0[] = { + {0x00, 1}, {0x03, 5}, {0x05, 6}, {0x06, 3}, {0x07, 6}, {0x08, 8}, + {0x09, 17}, {0x0a, 28}, {0x0b, 25}, {0x0c, 20}, {0x0d, 16}, {0x0e, 13}, + {0x0f, 4}, {0x10, 3}, {0x12, 18}, {0x13, 9}, {0x16, 1}, {0x17, 5}, + {0x18, 2}, {0x19, 3}, {0x1a, 7}, {0x1c, 2}, {0x1d, 1}, {0x1f, 22}, + {0x20, 3}, {0x2b, 3}, {0x2c, 2}, {0x2d, 11}, {0x2e, 1}, {0x30, 3}, + {0x31, 2}, {0x32, 1}, {0xa7, 2}, {0xa9, 2}, {0xaa, 4}, {0xab, 8}, + {0xfa, 2}, {0xfb, 5}, {0xfd, 4}, {0xfe, 3}, {0xff, 9}, + }; + static constexpr unsigned char singletons0_lower[] = { + 0xad, 0x78, 0x79, 0x8b, 0x8d, 0xa2, 0x30, 0x57, 0x58, 0x8b, 0x8c, 0x90, + 0x1c, 0x1d, 0xdd, 0x0e, 0x0f, 0x4b, 0x4c, 0xfb, 0xfc, 0x2e, 0x2f, 0x3f, + 0x5c, 0x5d, 0x5f, 0xb5, 0xe2, 0x84, 0x8d, 0x8e, 0x91, 0x92, 0xa9, 0xb1, + 0xba, 0xbb, 0xc5, 0xc6, 0xc9, 0xca, 0xde, 0xe4, 0xe5, 0xff, 0x00, 0x04, + 0x11, 0x12, 0x29, 0x31, 0x34, 0x37, 0x3a, 0x3b, 0x3d, 0x49, 0x4a, 0x5d, + 0x84, 0x8e, 0x92, 0xa9, 0xb1, 0xb4, 0xba, 0xbb, 0xc6, 0xca, 0xce, 0xcf, + 0xe4, 0xe5, 0x00, 0x04, 0x0d, 0x0e, 0x11, 0x12, 0x29, 0x31, 0x34, 0x3a, + 0x3b, 0x45, 0x46, 0x49, 0x4a, 0x5e, 0x64, 0x65, 0x84, 0x91, 0x9b, 0x9d, + 0xc9, 0xce, 0xcf, 0x0d, 0x11, 0x29, 0x45, 0x49, 0x57, 0x64, 0x65, 0x8d, + 0x91, 0xa9, 0xb4, 0xba, 0xbb, 0xc5, 0xc9, 0xdf, 0xe4, 0xe5, 0xf0, 0x0d, + 0x11, 0x45, 0x49, 0x64, 0x65, 0x80, 0x84, 0xb2, 0xbc, 0xbe, 0xbf, 0xd5, + 0xd7, 0xf0, 0xf1, 0x83, 0x85, 0x8b, 0xa4, 0xa6, 0xbe, 0xbf, 0xc5, 0xc7, + 0xce, 0xcf, 0xda, 0xdb, 0x48, 0x98, 0xbd, 0xcd, 0xc6, 0xce, 0xcf, 0x49, + 0x4e, 0x4f, 0x57, 0x59, 0x5e, 0x5f, 0x89, 0x8e, 0x8f, 0xb1, 0xb6, 0xb7, + 0xbf, 0xc1, 0xc6, 0xc7, 0xd7, 0x11, 0x16, 0x17, 0x5b, 0x5c, 0xf6, 0xf7, + 0xfe, 0xff, 0x80, 0x0d, 0x6d, 0x71, 0xde, 0xdf, 0x0e, 0x0f, 0x1f, 0x6e, + 0x6f, 0x1c, 0x1d, 0x5f, 0x7d, 0x7e, 0xae, 0xaf, 0xbb, 0xbc, 0xfa, 0x16, + 0x17, 0x1e, 0x1f, 0x46, 0x47, 0x4e, 0x4f, 0x58, 0x5a, 0x5c, 0x5e, 0x7e, + 0x7f, 0xb5, 0xc5, 0xd4, 0xd5, 0xdc, 0xf0, 0xf1, 0xf5, 0x72, 0x73, 0x8f, + 0x74, 0x75, 0x96, 0x2f, 0x5f, 0x26, 0x2e, 0x2f, 0xa7, 0xaf, 0xb7, 0xbf, + 0xc7, 0xcf, 0xd7, 0xdf, 0x9a, 0x40, 0x97, 0x98, 0x30, 0x8f, 0x1f, 0xc0, + 0xc1, 0xce, 0xff, 0x4e, 0x4f, 0x5a, 0x5b, 0x07, 0x08, 0x0f, 0x10, 0x27, + 0x2f, 0xee, 0xef, 0x6e, 0x6f, 0x37, 0x3d, 0x3f, 0x42, 0x45, 0x90, 0x91, + 0xfe, 0xff, 0x53, 0x67, 0x75, 0xc8, 0xc9, 0xd0, 0xd1, 0xd8, 0xd9, 0xe7, + 0xfe, 0xff, + }; + static constexpr singleton singletons1[] = { + {0x00, 6}, {0x01, 1}, {0x03, 1}, {0x04, 2}, {0x08, 8}, {0x09, 2}, + {0x0a, 5}, {0x0b, 2}, {0x0e, 4}, {0x10, 1}, {0x11, 2}, {0x12, 5}, + {0x13, 17}, {0x14, 1}, {0x15, 2}, {0x17, 2}, {0x19, 13}, {0x1c, 5}, + {0x1d, 8}, {0x24, 1}, {0x6a, 3}, {0x6b, 2}, {0xbc, 2}, {0xd1, 2}, + {0xd4, 12}, {0xd5, 9}, {0xd6, 2}, {0xd7, 2}, {0xda, 1}, {0xe0, 5}, + {0xe1, 2}, {0xe8, 2}, {0xee, 32}, {0xf0, 4}, {0xf8, 2}, {0xf9, 2}, + {0xfa, 2}, {0xfb, 1}, + }; + static constexpr unsigned char singletons1_lower[] = { + 0x0c, 0x27, 0x3b, 0x3e, 0x4e, 0x4f, 0x8f, 0x9e, 0x9e, 0x9f, 0x06, 0x07, + 0x09, 0x36, 0x3d, 0x3e, 0x56, 0xf3, 0xd0, 0xd1, 0x04, 0x14, 0x18, 0x36, + 0x37, 0x56, 0x57, 0x7f, 0xaa, 0xae, 0xaf, 0xbd, 0x35, 0xe0, 0x12, 0x87, + 0x89, 0x8e, 0x9e, 0x04, 0x0d, 0x0e, 0x11, 0x12, 0x29, 0x31, 0x34, 0x3a, + 0x45, 0x46, 0x49, 0x4a, 0x4e, 0x4f, 0x64, 0x65, 0x5c, 0xb6, 0xb7, 0x1b, + 0x1c, 0x07, 0x08, 0x0a, 0x0b, 0x14, 0x17, 0x36, 0x39, 0x3a, 0xa8, 0xa9, + 0xd8, 0xd9, 0x09, 0x37, 0x90, 0x91, 0xa8, 0x07, 0x0a, 0x3b, 0x3e, 0x66, + 0x69, 0x8f, 0x92, 0x6f, 0x5f, 0xee, 0xef, 0x5a, 0x62, 0x9a, 
0x9b, 0x27, + 0x28, 0x55, 0x9d, 0xa0, 0xa1, 0xa3, 0xa4, 0xa7, 0xa8, 0xad, 0xba, 0xbc, + 0xc4, 0x06, 0x0b, 0x0c, 0x15, 0x1d, 0x3a, 0x3f, 0x45, 0x51, 0xa6, 0xa7, + 0xcc, 0xcd, 0xa0, 0x07, 0x19, 0x1a, 0x22, 0x25, 0x3e, 0x3f, 0xc5, 0xc6, + 0x04, 0x20, 0x23, 0x25, 0x26, 0x28, 0x33, 0x38, 0x3a, 0x48, 0x4a, 0x4c, + 0x50, 0x53, 0x55, 0x56, 0x58, 0x5a, 0x5c, 0x5e, 0x60, 0x63, 0x65, 0x66, + 0x6b, 0x73, 0x78, 0x7d, 0x7f, 0x8a, 0xa4, 0xaa, 0xaf, 0xb0, 0xc0, 0xd0, + 0xae, 0xaf, 0x79, 0xcc, 0x6e, 0x6f, 0x93, + }; + static constexpr unsigned char normal0[] = { + 0x00, 0x20, 0x5f, 0x22, 0x82, 0xdf, 0x04, 0x82, 0x44, 0x08, 0x1b, 0x04, + 0x06, 0x11, 0x81, 0xac, 0x0e, 0x80, 0xab, 0x35, 0x28, 0x0b, 0x80, 0xe0, + 0x03, 0x19, 0x08, 0x01, 0x04, 0x2f, 0x04, 0x34, 0x04, 0x07, 0x03, 0x01, + 0x07, 0x06, 0x07, 0x11, 0x0a, 0x50, 0x0f, 0x12, 0x07, 0x55, 0x07, 0x03, + 0x04, 0x1c, 0x0a, 0x09, 0x03, 0x08, 0x03, 0x07, 0x03, 0x02, 0x03, 0x03, + 0x03, 0x0c, 0x04, 0x05, 0x03, 0x0b, 0x06, 0x01, 0x0e, 0x15, 0x05, 0x3a, + 0x03, 0x11, 0x07, 0x06, 0x05, 0x10, 0x07, 0x57, 0x07, 0x02, 0x07, 0x15, + 0x0d, 0x50, 0x04, 0x43, 0x03, 0x2d, 0x03, 0x01, 0x04, 0x11, 0x06, 0x0f, + 0x0c, 0x3a, 0x04, 0x1d, 0x25, 0x5f, 0x20, 0x6d, 0x04, 0x6a, 0x25, 0x80, + 0xc8, 0x05, 0x82, 0xb0, 0x03, 0x1a, 0x06, 0x82, 0xfd, 0x03, 0x59, 0x07, + 0x15, 0x0b, 0x17, 0x09, 0x14, 0x0c, 0x14, 0x0c, 0x6a, 0x06, 0x0a, 0x06, + 0x1a, 0x06, 0x59, 0x07, 0x2b, 0x05, 0x46, 0x0a, 0x2c, 0x04, 0x0c, 0x04, + 0x01, 0x03, 0x31, 0x0b, 0x2c, 0x04, 0x1a, 0x06, 0x0b, 0x03, 0x80, 0xac, + 0x06, 0x0a, 0x06, 0x21, 0x3f, 0x4c, 0x04, 0x2d, 0x03, 0x74, 0x08, 0x3c, + 0x03, 0x0f, 0x03, 0x3c, 0x07, 0x38, 0x08, 0x2b, 0x05, 0x82, 0xff, 0x11, + 0x18, 0x08, 0x2f, 0x11, 0x2d, 0x03, 0x20, 0x10, 0x21, 0x0f, 0x80, 0x8c, + 0x04, 0x82, 0x97, 0x19, 0x0b, 0x15, 0x88, 0x94, 0x05, 0x2f, 0x05, 0x3b, + 0x07, 0x02, 0x0e, 0x18, 0x09, 0x80, 0xb3, 0x2d, 0x74, 0x0c, 0x80, 0xd6, + 0x1a, 0x0c, 0x05, 0x80, 0xff, 0x05, 0x80, 0xdf, 0x0c, 0xee, 0x0d, 0x03, + 0x84, 0x8d, 0x03, 0x37, 0x09, 0x81, 0x5c, 0x14, 0x80, 0xb8, 0x08, 0x80, + 0xcb, 0x2a, 0x38, 0x03, 0x0a, 0x06, 0x38, 0x08, 0x46, 0x08, 0x0c, 0x06, + 0x74, 0x0b, 0x1e, 0x03, 0x5a, 0x04, 0x59, 0x09, 0x80, 0x83, 0x18, 0x1c, + 0x0a, 0x16, 0x09, 0x4c, 0x04, 0x80, 0x8a, 0x06, 0xab, 0xa4, 0x0c, 0x17, + 0x04, 0x31, 0xa1, 0x04, 0x81, 0xda, 0x26, 0x07, 0x0c, 0x05, 0x05, 0x80, + 0xa5, 0x11, 0x81, 0x6d, 0x10, 0x78, 0x28, 0x2a, 0x06, 0x4c, 0x04, 0x80, + 0x8d, 0x04, 0x80, 0xbe, 0x03, 0x1b, 0x03, 0x0f, 0x0d, + }; + static constexpr unsigned char normal1[] = { + 0x5e, 0x22, 0x7b, 0x05, 0x03, 0x04, 0x2d, 0x03, 0x66, 0x03, 0x01, 0x2f, + 0x2e, 0x80, 0x82, 0x1d, 0x03, 0x31, 0x0f, 0x1c, 0x04, 0x24, 0x09, 0x1e, + 0x05, 0x2b, 0x05, 0x44, 0x04, 0x0e, 0x2a, 0x80, 0xaa, 0x06, 0x24, 0x04, + 0x24, 0x04, 0x28, 0x08, 0x34, 0x0b, 0x01, 0x80, 0x90, 0x81, 0x37, 0x09, + 0x16, 0x0a, 0x08, 0x80, 0x98, 0x39, 0x03, 0x63, 0x08, 0x09, 0x30, 0x16, + 0x05, 0x21, 0x03, 0x1b, 0x05, 0x01, 0x40, 0x38, 0x04, 0x4b, 0x05, 0x2f, + 0x04, 0x0a, 0x07, 0x09, 0x07, 0x40, 0x20, 0x27, 0x04, 0x0c, 0x09, 0x36, + 0x03, 0x3a, 0x05, 0x1a, 0x07, 0x04, 0x0c, 0x07, 0x50, 0x49, 0x37, 0x33, + 0x0d, 0x33, 0x07, 0x2e, 0x08, 0x0a, 0x81, 0x26, 0x52, 0x4e, 0x28, 0x08, + 0x2a, 0x56, 0x1c, 0x14, 0x17, 0x09, 0x4e, 0x04, 0x1e, 0x0f, 0x43, 0x0e, + 0x19, 0x07, 0x0a, 0x06, 0x48, 0x08, 0x27, 0x09, 0x75, 0x0b, 0x3f, 0x41, + 0x2a, 0x06, 0x3b, 0x05, 0x0a, 0x06, 0x51, 0x06, 0x01, 0x05, 0x10, 0x03, + 0x05, 0x80, 0x8b, 0x62, 0x1e, 0x48, 0x08, 0x0a, 0x80, 0xa6, 0x5e, 0x22, + 0x45, 0x0b, 0x0a, 0x06, 0x0d, 0x13, 0x39, 0x07, 0x0a, 0x36, 0x2c, 0x04, + 0x10, 
0x80, 0xc0, 0x3c, 0x64, 0x53, 0x0c, 0x48, 0x09, 0x0a, 0x46, 0x45,
+      0x1b, 0x48, 0x08, 0x53, 0x1d, 0x39, 0x81, 0x07, 0x46, 0x0a, 0x1d, 0x03,
+      0x47, 0x49, 0x37, 0x03, 0x0e, 0x08, 0x0a, 0x06, 0x39, 0x07, 0x0a, 0x81,
+      0x36, 0x19, 0x80, 0xb7, 0x01, 0x0f, 0x32, 0x0d, 0x83, 0x9b, 0x66, 0x75,
+      0x0b, 0x80, 0xc4, 0x8a, 0xbc, 0x84, 0x2f, 0x8f, 0xd1, 0x82, 0x47, 0xa1,
+      0xb9, 0x82, 0x39, 0x07, 0x2a, 0x04, 0x02, 0x60, 0x26, 0x0a, 0x46, 0x0a,
+      0x28, 0x05, 0x13, 0x82, 0xb0, 0x5b, 0x65, 0x4b, 0x04, 0x39, 0x07, 0x11,
+      0x40, 0x05, 0x0b, 0x02, 0x0e, 0x97, 0xf8, 0x08, 0x84, 0xd6, 0x2a, 0x09,
+      0xa2, 0xf7, 0x81, 0x1f, 0x31, 0x03, 0x11, 0x04, 0x08, 0x81, 0x8c, 0x89,
+      0x04, 0x6b, 0x05, 0x0d, 0x03, 0x09, 0x07, 0x10, 0x93, 0x60, 0x80, 0xf6,
+      0x0a, 0x73, 0x08, 0x6e, 0x17, 0x46, 0x80, 0x9a, 0x14, 0x0c, 0x57, 0x09,
+      0x19, 0x80, 0x87, 0x81, 0x47, 0x03, 0x85, 0x42, 0x0f, 0x15, 0x85, 0x50,
+      0x2b, 0x80, 0xd5, 0x2d, 0x03, 0x1a, 0x04, 0x02, 0x81, 0x70, 0x3a, 0x05,
+      0x01, 0x85, 0x00, 0x80, 0xd7, 0x29, 0x4c, 0x04, 0x0a, 0x04, 0x02, 0x83,
+      0x11, 0x44, 0x4c, 0x3d, 0x80, 0xc2, 0x3c, 0x06, 0x01, 0x04, 0x55, 0x05,
+      0x1b, 0x34, 0x02, 0x81, 0x0e, 0x2c, 0x04, 0x64, 0x0c, 0x56, 0x0a, 0x80,
+      0xae, 0x38, 0x1d, 0x0d, 0x2c, 0x04, 0x09, 0x07, 0x02, 0x0e, 0x06, 0x80,
+      0x9a, 0x83, 0xd8, 0x08, 0x0d, 0x03, 0x0d, 0x03, 0x74, 0x0c, 0x59, 0x07,
+      0x0c, 0x14, 0x0c, 0x04, 0x38, 0x08, 0x0a, 0x06, 0x28, 0x08, 0x22, 0x4e,
+      0x81, 0x54, 0x0c, 0x15, 0x03, 0x03, 0x05, 0x07, 0x09, 0x19, 0x07, 0x07,
+      0x09, 0x03, 0x0d, 0x07, 0x29, 0x80, 0xcb, 0x25, 0x0a, 0x84, 0x06,
+  };
+  auto lower = static_cast<uint16_t>(cp);
+  if (cp < 0x10000) {
+    return is_printable(lower, singletons0,
+                        sizeof(singletons0) / sizeof(*singletons0),
+                        singletons0_lower, normal0, sizeof(normal0));
+  }
+  if (cp < 0x20000) {
+    return is_printable(lower, singletons1,
+                        sizeof(singletons1) / sizeof(*singletons1),
+                        singletons1_lower, normal1, sizeof(normal1));
+  }
+  if (0x2a6de <= cp && cp < 0x2a700) return false;
+  if (0x2b735 <= cp && cp < 0x2b740) return false;
+  if (0x2b81e <= cp && cp < 0x2b820) return false;
+  if (0x2cea2 <= cp && cp < 0x2ceb0) return false;
+  if (0x2ebe1 <= cp && cp < 0x2f800) return false;
+  if (0x2fa1e <= cp && cp < 0x30000) return false;
+  if (0x3134b <= cp && cp < 0xe0100) return false;
+  if (0xe01f0 <= cp && cp < 0x110000) return false;
+  return cp < 0x110000;
+}
+
+}  // namespace detail
+
+FMT_END_NAMESPACE
+
+#endif  // FMT_FORMAT_INL_H_
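The normal0/normal1 tables above are a run-length encoding: lengths alternate between printable and non-printable runs (starting printable), and a byte with the top bit set introduces a two-byte length. A toy sketch decoding a made-up table with the same parity walk as the is_printable helper (the table here is invented for illustration, not real Unicode data):

#include <cassert>
#include <cstddef>
#include <cstdint>

// Runs of 3 printable, 2 non-printable, 5 printable; everything past the
// table decodes as non-printable (parity after the last run).
static bool toy_printable(uint16_t x) {
  static const unsigned char normal[] = {0x03, 0x02, 0x05};
  int xsigned = x;
  bool current = true;
  for (size_t i = 0; i < sizeof(normal); ++i) {
    int v = normal[i];
    int len = (v & 0x80) != 0 ? (v & 0x7f) << 8 | normal[++i] : v;
    xsigned -= len;
    if (xsigned < 0) break;
    current = !current;
  }
  return current;
}

int main() {
  assert(toy_printable(0) && toy_printable(2));        // first run
  assert(!toy_printable(3) && !toy_printable(4));      // second run
  assert(toy_printable(5) && toy_printable(9));        // third run
  assert(!toy_printable(10) && !toy_printable(1000));  // past the table
  return 0;
}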
diff --git a/Genie/Genie/src/qualla/include/fmt/format.h b/Genie/Genie/src/qualla/include/fmt/format.h
new file mode 100644
index 0000000000000000000000000000000000000000..87a34b972ce6af4e2209e4d6cf78e8401e8f0037
--- /dev/null
+++ b/Genie/Genie/src/qualla/include/fmt/format.h
@@ -0,0 +1,4510 @@
+/*
+  Formatting library for C++
+
+  Copyright (c) 2012 - present, Victor Zverovich
+
+  Permission is hereby granted, free of charge, to any person obtaining
+  a copy of this software and associated documentation files (the
+  "Software"), to deal in the Software without restriction, including
+  without limitation the rights to use, copy, modify, merge, publish,
+  distribute, sublicense, and/or sell copies of the Software, and to
+  permit persons to whom the Software is furnished to do so, subject to
+  the following conditions:
+
+  The above copyright notice and this permission notice shall be
+  included in all copies or substantial portions of the Software.
+
+  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+  EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+  MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+  NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
+  LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
+  OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
+  WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+  --- Optional exception to the license ---
+
+  As an exception, if, as a result of your compiling your source code, portions
+  of this Software are embedded into a machine-executable object form of such
+  source code, you may redistribute such embedded portions in such object form
+  without including the above copyright and permission notices.
+ */
+
+#ifndef FMT_FORMAT_H_
+#define FMT_FORMAT_H_
+
+#include <cmath>             // std::signbit
+#include <cstdint>           // uint32_t
+#include <cstring>           // std::memcpy
+#include <initializer_list>  // std::initializer_list
+#include <limits>            // std::numeric_limits
+#include <memory>            // std::uninitialized_copy
+#include <stdexcept>         // std::runtime_error
+#include <system_error>      // std::system_error
+
+#ifdef __cpp_lib_bit_cast
+#  include <bit>  // std::bit_cast
+#endif
+
+#include "core.h"
+
+#if defined __cpp_inline_variables && __cpp_inline_variables >= 201606L
+#  define FMT_INLINE_VARIABLE inline
+#else
+#  define FMT_INLINE_VARIABLE
+#endif
+
+#if FMT_HAS_CPP17_ATTRIBUTE(fallthrough)
+#  define FMT_FALLTHROUGH [[fallthrough]]
+#elif defined(__clang__)
+#  define FMT_FALLTHROUGH [[clang::fallthrough]]
+#elif FMT_GCC_VERSION >= 700 && \
+    (!defined(__EDG_VERSION__) || __EDG_VERSION__ >= 520)
+#  define FMT_FALLTHROUGH [[gnu::fallthrough]]
+#else
+#  define FMT_FALLTHROUGH
+#endif
+
+#ifndef FMT_DEPRECATED
+#  if FMT_HAS_CPP14_ATTRIBUTE(deprecated) || FMT_MSC_VERSION >= 1900
+#    define FMT_DEPRECATED [[deprecated]]
+#  else
+#    if (defined(__GNUC__) && !defined(__LCC__)) || defined(__clang__)
+#      define FMT_DEPRECATED __attribute__((deprecated))
+#    elif FMT_MSC_VERSION
+#      define FMT_DEPRECATED __declspec(deprecated)
+#    else
+#      define FMT_DEPRECATED /* deprecated */
+#    endif
+#  endif
+#endif
+
+#ifndef FMT_NO_UNIQUE_ADDRESS
+#  if FMT_CPLUSPLUS >= 202002L
+#    if FMT_HAS_CPP_ATTRIBUTE(no_unique_address)
+#      define FMT_NO_UNIQUE_ADDRESS [[no_unique_address]]
+// VS2019 v16.10 and later except clang-cl (https://reviews.llvm.org/D110485)
+#    elif (FMT_MSC_VERSION >= 1929) && !FMT_CLANG_VERSION
+#      define FMT_NO_UNIQUE_ADDRESS [[msvc::no_unique_address]]
+#    endif
+#  endif
+#endif
+#ifndef FMT_NO_UNIQUE_ADDRESS
+#  define FMT_NO_UNIQUE_ADDRESS
+#endif
+
+#if FMT_GCC_VERSION || defined(__clang__)
+#  define FMT_VISIBILITY(value) __attribute__((visibility(value)))
+#else
+#  define FMT_VISIBILITY(value)
+#endif
+
+#ifdef __has_builtin
+#  define FMT_HAS_BUILTIN(x) __has_builtin(x)
+#else
+#  define FMT_HAS_BUILTIN(x) 0
+#endif
+
+#if FMT_GCC_VERSION || FMT_CLANG_VERSION
+#  define FMT_NOINLINE __attribute__((noinline))
+#else
+#  define FMT_NOINLINE
+#endif
+
+#ifndef FMT_THROW
+#  if FMT_EXCEPTIONS
+#    if FMT_MSC_VERSION || defined(__NVCC__)
+FMT_BEGIN_NAMESPACE
+namespace detail {
+template <typename Exception> inline void do_throw(const Exception& x) {
+  // Silence unreachable code warnings in MSVC and NVCC because these
+  // are nearly impossible to fix in generic code.
+ volatile bool b = true; + if (b) throw x; +} +} // namespace detail +FMT_END_NAMESPACE +# define FMT_THROW(x) detail::do_throw(x) +# else +# define FMT_THROW(x) throw x +# endif +# else +# define FMT_THROW(x) \ + ::fmt::detail::assert_fail(__FILE__, __LINE__, (x).what()) +# endif +#endif + +#if FMT_EXCEPTIONS +# define FMT_TRY try +# define FMT_CATCH(x) catch (x) +#else +# define FMT_TRY if (true) +# define FMT_CATCH(x) if (false) +#endif + +#ifndef FMT_MAYBE_UNUSED +# if FMT_HAS_CPP17_ATTRIBUTE(maybe_unused) +# define FMT_MAYBE_UNUSED [[maybe_unused]] +# else +# define FMT_MAYBE_UNUSED +# endif +#endif + +#ifndef FMT_USE_USER_DEFINED_LITERALS +// EDG based compilers (Intel, NVIDIA, Elbrus, etc), GCC and MSVC support UDLs. +# if (FMT_HAS_FEATURE(cxx_user_literals) || FMT_GCC_VERSION >= 407 || \ + FMT_MSC_VERSION >= 1900) && \ + (!defined(__EDG_VERSION__) || __EDG_VERSION__ >= /* UDL feature */ 480) +# define FMT_USE_USER_DEFINED_LITERALS 1 +# else +# define FMT_USE_USER_DEFINED_LITERALS 0 +# endif +#endif + +// Defining FMT_REDUCE_INT_INSTANTIATIONS to 1, will reduce the number of +// integer formatter template instantiations to just one by only using the +// largest integer type. This results in a reduction in binary size but will +// cause a decrease in integer formatting performance. +#if !defined(FMT_REDUCE_INT_INSTANTIATIONS) +# define FMT_REDUCE_INT_INSTANTIATIONS 0 +#endif + +// __builtin_clz is broken in clang with Microsoft CodeGen: +// https://github.com/fmtlib/fmt/issues/519. +#if !FMT_MSC_VERSION +# if FMT_HAS_BUILTIN(__builtin_clz) || FMT_GCC_VERSION || FMT_ICC_VERSION +# define FMT_BUILTIN_CLZ(n) __builtin_clz(n) +# endif +# if FMT_HAS_BUILTIN(__builtin_clzll) || FMT_GCC_VERSION || FMT_ICC_VERSION +# define FMT_BUILTIN_CLZLL(n) __builtin_clzll(n) +# endif +#endif + +// __builtin_ctz is broken in Intel Compiler Classic on Windows: +// https://github.com/fmtlib/fmt/issues/2510. +#ifndef __ICL +# if FMT_HAS_BUILTIN(__builtin_ctz) || FMT_GCC_VERSION || FMT_ICC_VERSION || \ + defined(__NVCOMPILER) +# define FMT_BUILTIN_CTZ(n) __builtin_ctz(n) +# endif +# if FMT_HAS_BUILTIN(__builtin_ctzll) || FMT_GCC_VERSION || \ + FMT_ICC_VERSION || defined(__NVCOMPILER) +# define FMT_BUILTIN_CTZLL(n) __builtin_ctzll(n) +# endif +#endif + +#if FMT_MSC_VERSION +# include // _BitScanReverse[64], _BitScanForward[64], _umul128 +#endif + +// Some compilers masquerade as both MSVC and GCC-likes or otherwise support +// __builtin_clz and __builtin_clzll, so only define FMT_BUILTIN_CLZ using the +// MSVC intrinsics if the clz and clzll builtins are not available. +#if FMT_MSC_VERSION && !defined(FMT_BUILTIN_CLZLL) && \ + !defined(FMT_BUILTIN_CTZLL) +FMT_BEGIN_NAMESPACE +namespace detail { +// Avoid Clang with Microsoft CodeGen's -Wunknown-pragmas warning. +# if !defined(__clang__) +# pragma intrinsic(_BitScanForward) +# pragma intrinsic(_BitScanReverse) +# if defined(_WIN64) +# pragma intrinsic(_BitScanForward64) +# pragma intrinsic(_BitScanReverse64) +# endif +# endif + +inline auto clz(uint32_t x) -> int { + unsigned long r = 0; + _BitScanReverse(&r, x); + FMT_ASSERT(x != 0, ""); + // Static analysis complains about using uninitialized data + // "r", but the only way that can happen is if "x" is 0, + // which the callers guarantee to not happen. + FMT_MSC_WARNING(suppress : 6102) + return 31 ^ static_cast(r); +} +# define FMT_BUILTIN_CLZ(n) detail::clz(n) + +inline auto clzll(uint64_t x) -> int { + unsigned long r = 0; +# ifdef _WIN64 + _BitScanReverse64(&r, x); +# else + // Scan the high 32 bits. 
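// [Editorial note, not part of fmt] For r in [0, 63], 63 ^ r equals 63 - r
// (xor against an all-ones six-bit mask), so the xors in clz/clzll turn the
// bit index reported by _BitScanReverse into a leading-zero count without a
// subtraction. Quick sanity check:
static_assert((63 ^ 0) == 63 && (63 ^ 40) == 23 && (63 ^ 63) == 0,
              "63 ^ r == 63 - r for r in [0, 63]");
// The two-step 32-bit scan resumes here: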
+ if (_BitScanReverse(&r, static_cast(x >> 32))) + return 63 ^ static_cast(r + 32); + // Scan the low 32 bits. + _BitScanReverse(&r, static_cast(x)); +# endif + FMT_ASSERT(x != 0, ""); + FMT_MSC_WARNING(suppress : 6102) // Suppress a bogus static analysis warning. + return 63 ^ static_cast(r); +} +# define FMT_BUILTIN_CLZLL(n) detail::clzll(n) + +inline auto ctz(uint32_t x) -> int { + unsigned long r = 0; + _BitScanForward(&r, x); + FMT_ASSERT(x != 0, ""); + FMT_MSC_WARNING(suppress : 6102) // Suppress a bogus static analysis warning. + return static_cast(r); +} +# define FMT_BUILTIN_CTZ(n) detail::ctz(n) + +inline auto ctzll(uint64_t x) -> int { + unsigned long r = 0; + FMT_ASSERT(x != 0, ""); + FMT_MSC_WARNING(suppress : 6102) // Suppress a bogus static analysis warning. +# ifdef _WIN64 + _BitScanForward64(&r, x); +# else + // Scan the low 32 bits. + if (_BitScanForward(&r, static_cast(x))) return static_cast(r); + // Scan the high 32 bits. + _BitScanForward(&r, static_cast(x >> 32)); + r += 32; +# endif + return static_cast(r); +} +# define FMT_BUILTIN_CTZLL(n) detail::ctzll(n) +} // namespace detail +FMT_END_NAMESPACE +#endif + +FMT_BEGIN_NAMESPACE + +template struct disjunction : std::false_type {}; +template struct disjunction
: P {}; +template +struct disjunction + : conditional_t> {}; + +template struct conjunction : std::true_type {}; +template struct conjunction
: P {}; +template +struct conjunction + : conditional_t, P1> {}; + +namespace detail { + +FMT_CONSTEXPR inline void abort_fuzzing_if(bool condition) { + ignore_unused(condition); +#ifdef FMT_FUZZ + if (condition) throw std::runtime_error("fuzzing limit reached"); +#endif +} + +template struct string_literal { + static constexpr CharT value[sizeof...(C)] = {C...}; + constexpr operator basic_string_view() const { + return {value, sizeof...(C)}; + } +}; + +#if FMT_CPLUSPLUS < 201703L +template +constexpr CharT string_literal::value[sizeof...(C)]; +#endif + +template class formatbuf : public Streambuf { + private: + using char_type = typename Streambuf::char_type; + using streamsize = decltype(std::declval().sputn(nullptr, 0)); + using int_type = typename Streambuf::int_type; + using traits_type = typename Streambuf::traits_type; + + buffer& buffer_; + + public: + explicit formatbuf(buffer& buf) : buffer_(buf) {} + + protected: + // The put area is always empty. This makes the implementation simpler and has + // the advantage that the streambuf and the buffer are always in sync and + // sputc never writes into uninitialized memory. A disadvantage is that each + // call to sputc always results in a (virtual) call to overflow. There is no + // disadvantage here for sputn since this always results in a call to xsputn. + + auto overflow(int_type ch) -> int_type override { + if (!traits_type::eq_int_type(ch, traits_type::eof())) + buffer_.push_back(static_cast(ch)); + return ch; + } + + auto xsputn(const char_type* s, streamsize count) -> streamsize override { + buffer_.append(s, s + count); + return count; + } +}; + +// Implementation of std::bit_cast for pre-C++20. +template +FMT_CONSTEXPR20 auto bit_cast(const From& from) -> To { +#ifdef __cpp_lib_bit_cast + if (is_constant_evaluated()) return std::bit_cast(from); +#endif + auto to = To(); + // The cast suppresses a bogus -Wclass-memaccess on GCC. + std::memcpy(static_cast(&to), &from, sizeof(to)); + return to; +} + +inline auto is_big_endian() -> bool { +#ifdef _WIN32 + return false; +#elif defined(__BIG_ENDIAN__) + return true; +#elif defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) + return __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__; +#else + struct bytes { + char data[sizeof(int)]; + }; + return bit_cast(1).data[0] == 0; +#endif +} + +class uint128_fallback { + private: + uint64_t lo_, hi_; + + public: + constexpr uint128_fallback(uint64_t hi, uint64_t lo) : lo_(lo), hi_(hi) {} + constexpr uint128_fallback(uint64_t value = 0) : lo_(value), hi_(0) {} + + constexpr uint64_t high() const noexcept { return hi_; } + constexpr uint64_t low() const noexcept { return lo_; } + + template ::value)> + constexpr explicit operator T() const { + return static_cast(lo_); + } + + friend constexpr auto operator==(const uint128_fallback& lhs, + const uint128_fallback& rhs) -> bool { + return lhs.hi_ == rhs.hi_ && lhs.lo_ == rhs.lo_; + } + friend constexpr auto operator!=(const uint128_fallback& lhs, + const uint128_fallback& rhs) -> bool { + return !(lhs == rhs); + } + friend constexpr auto operator>(const uint128_fallback& lhs, + const uint128_fallback& rhs) -> bool { + return lhs.hi_ != rhs.hi_ ? 
lhs.hi_ > rhs.hi_ : lhs.lo_ > rhs.lo_; + } + friend constexpr auto operator|(const uint128_fallback& lhs, + const uint128_fallback& rhs) + -> uint128_fallback { + return {lhs.hi_ | rhs.hi_, lhs.lo_ | rhs.lo_}; + } + friend constexpr auto operator&(const uint128_fallback& lhs, + const uint128_fallback& rhs) + -> uint128_fallback { + return {lhs.hi_ & rhs.hi_, lhs.lo_ & rhs.lo_}; + } + friend constexpr auto operator~(const uint128_fallback& n) + -> uint128_fallback { + return {~n.hi_, ~n.lo_}; + } + friend auto operator+(const uint128_fallback& lhs, + const uint128_fallback& rhs) -> uint128_fallback { + auto result = uint128_fallback(lhs); + result += rhs; + return result; + } + friend auto operator*(const uint128_fallback& lhs, uint32_t rhs) + -> uint128_fallback { + FMT_ASSERT(lhs.hi_ == 0, ""); + uint64_t hi = (lhs.lo_ >> 32) * rhs; + uint64_t lo = (lhs.lo_ & ~uint32_t()) * rhs; + uint64_t new_lo = (hi << 32) + lo; + return {(hi >> 32) + (new_lo < lo ? 1 : 0), new_lo}; + } + friend auto operator-(const uint128_fallback& lhs, uint64_t rhs) + -> uint128_fallback { + return {lhs.hi_ - (lhs.lo_ < rhs ? 1 : 0), lhs.lo_ - rhs}; + } + FMT_CONSTEXPR auto operator>>(int shift) const -> uint128_fallback { + if (shift == 64) return {0, hi_}; + if (shift > 64) return uint128_fallback(0, hi_) >> (shift - 64); + return {hi_ >> shift, (hi_ << (64 - shift)) | (lo_ >> shift)}; + } + FMT_CONSTEXPR auto operator<<(int shift) const -> uint128_fallback { + if (shift == 64) return {lo_, 0}; + if (shift > 64) return uint128_fallback(lo_, 0) << (shift - 64); + return {hi_ << shift | (lo_ >> (64 - shift)), (lo_ << shift)}; + } + FMT_CONSTEXPR auto operator>>=(int shift) -> uint128_fallback& { + return *this = *this >> shift; + } + FMT_CONSTEXPR void operator+=(uint128_fallback n) { + uint64_t new_lo = lo_ + n.lo_; + uint64_t new_hi = hi_ + n.hi_ + (new_lo < lo_ ? 1 : 0); + FMT_ASSERT(new_hi >= hi_, ""); + lo_ = new_lo; + hi_ = new_hi; + } + FMT_CONSTEXPR void operator&=(uint128_fallback n) { + lo_ &= n.lo_; + hi_ &= n.hi_; + } + + FMT_CONSTEXPR20 uint128_fallback& operator+=(uint64_t n) noexcept { + if (is_constant_evaluated()) { + lo_ += n; + hi_ += (lo_ < n ? 1 : 0); + return *this; + } +#if FMT_HAS_BUILTIN(__builtin_addcll) && !defined(__ibmxl__) + unsigned long long carry; + lo_ = __builtin_addcll(lo_, n, 0, &carry); + hi_ += carry; +#elif FMT_HAS_BUILTIN(__builtin_ia32_addcarryx_u64) && !defined(__ibmxl__) + unsigned long long result; + auto carry = __builtin_ia32_addcarryx_u64(0, lo_, n, &result); + lo_ = result; + hi_ += carry; +#elif defined(_MSC_VER) && defined(_M_X64) + auto carry = _addcarry_u64(0, lo_, n, &lo_); + _addcarry_u64(carry, hi_, 0, &hi_); +#else + lo_ += n; + hi_ += (lo_ < n ? 1 : 0); +#endif + return *this; + } +}; + +using uint128_t = conditional_t; + +#ifdef UINTPTR_MAX +using uintptr_t = ::uintptr_t; +#else +using uintptr_t = uint128_t; +#endif + +// Returns the largest possible value for type T. Same as +// std::numeric_limits::max() but shorter and not affected by the max macro. +template constexpr auto max_value() -> T { + return (std::numeric_limits::max)(); +} +template constexpr auto num_bits() -> int { + return std::numeric_limits::digits; +} +// std::numeric_limits::digits may return 0 for 128-bit ints. +template <> constexpr auto num_bits() -> int { return 128; } +template <> constexpr auto num_bits() -> int { return 128; } + +// A heterogeneous bit_cast used for converting 96-bit long double to uint128_t +// and 128-bit pointers to uint128_fallback. 
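// [Editorial sketch, not part of fmt] The generalized bit_cast below
// reassembles the destination from 32-bit chunks, folding each chunk in with
// a shift; on little-endian hosts the most significant chunk has the highest
// index, so the loop runs backwards. A standalone model of the same idea
// (chunks_to_u64 is an invented name; assumes 32-bit unsigned and a
// little-endian host):
inline uint64_t chunks_to_u64(const unsigned (&chunk)[2]) {
  uint64_t result = 0;
  for (int i = 1; i >= 0; --i)  // most significant chunk first
    result = (result << 32) | chunk[i];
  return result;
}
// chunks_to_u64({0x89abcdefu, 0x01234567u}) == 0x0123456789abcdefull.
// fmt's endianness-aware generalization follows: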
+template sizeof(From))> +inline auto bit_cast(const From& from) -> To { + constexpr auto size = static_cast(sizeof(From) / sizeof(unsigned)); + struct data_t { + unsigned value[static_cast(size)]; + } data = bit_cast(from); + auto result = To(); + if (const_check(is_big_endian())) { + for (int i = 0; i < size; ++i) + result = (result << num_bits()) | data.value[i]; + } else { + for (int i = size - 1; i >= 0; --i) + result = (result << num_bits()) | data.value[i]; + } + return result; +} + +template +FMT_CONSTEXPR20 inline auto countl_zero_fallback(UInt n) -> int { + int lz = 0; + constexpr UInt msb_mask = static_cast(1) << (num_bits() - 1); + for (; (n & msb_mask) == 0; n <<= 1) lz++; + return lz; +} + +FMT_CONSTEXPR20 inline auto countl_zero(uint32_t n) -> int { +#ifdef FMT_BUILTIN_CLZ + if (!is_constant_evaluated()) return FMT_BUILTIN_CLZ(n); +#endif + return countl_zero_fallback(n); +} + +FMT_CONSTEXPR20 inline auto countl_zero(uint64_t n) -> int { +#ifdef FMT_BUILTIN_CLZLL + if (!is_constant_evaluated()) return FMT_BUILTIN_CLZLL(n); +#endif + return countl_zero_fallback(n); +} + +FMT_INLINE void assume(bool condition) { + (void)condition; +#if FMT_HAS_BUILTIN(__builtin_assume) && !FMT_ICC_VERSION + __builtin_assume(condition); +#elif FMT_GCC_VERSION + if (!condition) __builtin_unreachable(); +#endif +} + +// An approximation of iterator_t for pre-C++20 systems. +template +using iterator_t = decltype(std::begin(std::declval())); +template using sentinel_t = decltype(std::end(std::declval())); + +// A workaround for std::string not having mutable data() until C++17. +template +inline auto get_data(std::basic_string& s) -> Char* { + return &s[0]; +} +template +inline auto get_data(Container& c) -> typename Container::value_type* { + return c.data(); +} + +// Attempts to reserve space for n extra characters in the output range. +// Returns a pointer to the reserved range or a reference to it. +template ::value)> +#if FMT_CLANG_VERSION >= 307 && !FMT_ICC_VERSION +__attribute__((no_sanitize("undefined"))) +#endif +inline auto +reserve(std::back_insert_iterator it, size_t n) -> + typename Container::value_type* { + Container& c = get_container(it); + size_t size = c.size(); + c.resize(size + n); + return get_data(c) + size; +} + +template +inline auto reserve(buffer_appender it, size_t n) -> buffer_appender { + buffer& buf = get_container(it); + buf.try_reserve(buf.size() + n); + return it; +} + +template +constexpr auto reserve(Iterator& it, size_t) -> Iterator& { + return it; +} + +template +using reserve_iterator = + remove_reference_t(), 0))>; + +template +constexpr auto to_pointer(OutputIt, size_t) -> T* { + return nullptr; +} +template auto to_pointer(buffer_appender it, size_t n) -> T* { + buffer& buf = get_container(it); + auto size = buf.size(); + if (buf.capacity() < size + n) return nullptr; + buf.try_resize(size + n); + return buf.data() + size; +} + +template ::value)> +inline auto base_iterator(std::back_insert_iterator it, + typename Container::value_type*) + -> std::back_insert_iterator { + return it; +} + +template +constexpr auto base_iterator(Iterator, Iterator it) -> Iterator { + return it; +} + +// is spectacularly slow to compile in C++20 so use a simple fill_n +// instead (#1998). 
+template +FMT_CONSTEXPR auto fill_n(OutputIt out, Size count, const T& value) + -> OutputIt { + for (Size i = 0; i < count; ++i) *out++ = value; + return out; +} +template +FMT_CONSTEXPR20 auto fill_n(T* out, Size count, char value) -> T* { + if (is_constant_evaluated()) { + return fill_n(out, count, value); + } + std::memset(out, value, to_unsigned(count)); + return out + count; +} + +#ifdef __cpp_char8_t +using char8_type = char8_t; +#else +enum char8_type : unsigned char {}; +#endif + +template +FMT_CONSTEXPR FMT_NOINLINE auto copy_str_noinline(InputIt begin, InputIt end, + OutputIt out) -> OutputIt { + return copy_str(begin, end, out); +} + +// A public domain branchless UTF-8 decoder by Christopher Wellons: +// https://github.com/skeeto/branchless-utf8 +/* Decode the next character, c, from s, reporting errors in e. + * + * Since this is a branchless decoder, four bytes will be read from the + * buffer regardless of the actual length of the next character. This + * means the buffer _must_ have at least three bytes of zero padding + * following the end of the data stream. + * + * Errors are reported in e, which will be non-zero if the parsed + * character was somehow invalid: invalid byte sequence, non-canonical + * encoding, or a surrogate half. + * + * The function returns a pointer to the next character. When an error + * occurs, this pointer will be a guess that depends on the particular + * error, but it will always advance at least one byte. + */ +FMT_CONSTEXPR inline auto utf8_decode(const char* s, uint32_t* c, int* e) + -> const char* { + constexpr const int masks[] = {0x00, 0x7f, 0x1f, 0x0f, 0x07}; + constexpr const uint32_t mins[] = {4194304, 0, 128, 2048, 65536}; + constexpr const int shiftc[] = {0, 18, 12, 6, 0}; + constexpr const int shifte[] = {0, 6, 4, 2, 0}; + + int len = "\1\1\1\1\1\1\1\1\1\1\1\1\1\1\1\1\0\0\0\0\0\0\0\0\2\2\2\2\3\3\4" + [static_cast(*s) >> 3]; + // Compute the pointer to the next character early so that the next + // iteration can start working on the next character. Neither Clang + // nor GCC figure out this reordering on their own. + const char* next = s + len + !len; + + using uchar = unsigned char; + + // Assume a four-byte character and load four bytes. Unused bits are + // shifted out. + *c = uint32_t(uchar(s[0]) & masks[len]) << 18; + *c |= uint32_t(uchar(s[1]) & 0x3f) << 12; + *c |= uint32_t(uchar(s[2]) & 0x3f) << 6; + *c |= uint32_t(uchar(s[3]) & 0x3f) << 0; + *c >>= shiftc[len]; + + // Accumulate the various error conditions. + *e = (*c < mins[len]) << 6; // non-canonical encoding + *e |= ((*c >> 11) == 0x1b) << 7; // surrogate half? + *e |= (*c > 0x10FFFF) << 8; // out of range? + *e |= (uchar(s[1]) & 0xc0) >> 2; + *e |= (uchar(s[2]) & 0xc0) >> 4; + *e |= uchar(s[3]) >> 6; + *e ^= 0x2a; // top two bits of each tail byte correct? + *e >>= shifte[len]; + + return next; +} + +constexpr FMT_INLINE_VARIABLE uint32_t invalid_code_point = ~uint32_t(); + +// Invokes f(cp, sv) for every code point cp in s with sv being the string view +// corresponding to the code point. cp is invalid_code_point on error. +template +FMT_CONSTEXPR void for_each_codepoint(string_view s, F f) { + auto decode = [f](const char* buf_ptr, const char* ptr) { + auto cp = uint32_t(); + auto error = 0; + auto end = utf8_decode(buf_ptr, &cp, &error); + bool result = f(error ? invalid_code_point : cp, + string_view(ptr, error ? 1 : to_unsigned(end - buf_ptr))); + return result ? (error ? 
buf_ptr + 1 : end) : nullptr; + }; + auto p = s.data(); + const size_t block_size = 4; // utf8_decode always reads blocks of 4 chars. + if (s.size() >= block_size) { + for (auto end = p + s.size() - block_size + 1; p < end;) { + p = decode(p, p); + if (!p) return; + } + } + if (auto num_chars_left = s.data() + s.size() - p) { + char buf[2 * block_size - 1] = {}; + copy_str(p, p + num_chars_left, buf); + const char* buf_ptr = buf; + do { + auto end = decode(buf_ptr, p); + if (!end) return; + p += end - buf_ptr; + buf_ptr = end; + } while (buf_ptr - buf < num_chars_left); + } +} + +template +inline auto compute_width(basic_string_view s) -> size_t { + return s.size(); +} + +// Computes approximate display width of a UTF-8 string. +FMT_CONSTEXPR inline size_t compute_width(string_view s) { + size_t num_code_points = 0; + // It is not a lambda for compatibility with C++14. + struct count_code_points { + size_t* count; + FMT_CONSTEXPR auto operator()(uint32_t cp, string_view) const -> bool { + *count += detail::to_unsigned( + 1 + + (cp >= 0x1100 && + (cp <= 0x115f || // Hangul Jamo init. consonants + cp == 0x2329 || // LEFT-POINTING ANGLE BRACKET + cp == 0x232a || // RIGHT-POINTING ANGLE BRACKET + // CJK ... Yi except IDEOGRAPHIC HALF FILL SPACE: + (cp >= 0x2e80 && cp <= 0xa4cf && cp != 0x303f) || + (cp >= 0xac00 && cp <= 0xd7a3) || // Hangul Syllables + (cp >= 0xf900 && cp <= 0xfaff) || // CJK Compatibility Ideographs + (cp >= 0xfe10 && cp <= 0xfe19) || // Vertical Forms + (cp >= 0xfe30 && cp <= 0xfe6f) || // CJK Compatibility Forms + (cp >= 0xff00 && cp <= 0xff60) || // Fullwidth Forms + (cp >= 0xffe0 && cp <= 0xffe6) || // Fullwidth Forms + (cp >= 0x20000 && cp <= 0x2fffd) || // CJK + (cp >= 0x30000 && cp <= 0x3fffd) || + // Miscellaneous Symbols and Pictographs + Emoticons: + (cp >= 0x1f300 && cp <= 0x1f64f) || + // Supplemental Symbols and Pictographs: + (cp >= 0x1f900 && cp <= 0x1f9ff)))); + return true; + } + }; + // We could avoid branches by using utf8_decode directly. + for_each_codepoint(s, count_code_points{&num_code_points}); + return num_code_points; +} + +inline auto compute_width(basic_string_view s) -> size_t { + return compute_width( + string_view(reinterpret_cast(s.data()), s.size())); +} + +template +inline auto code_point_index(basic_string_view s, size_t n) -> size_t { + size_t size = s.size(); + return n < size ? n : size; +} + +// Calculates the index of the nth code point in a UTF-8 string. 
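// [Editorial sketch, not part of fmt] The implementations below rely on one
// property of UTF-8: continuation bytes always match 10xxxxxx, so a byte b
// starts a code point iff (b & 0xc0) != 0x80. A standalone counter built on
// the same test (invented name; assumes well-formed UTF-8 input):
inline size_t count_code_points_sketch(const char* s, size_t n) {
  size_t count = 0;
  for (size_t i = 0; i != n; ++i)
    if ((static_cast<unsigned char>(s[i]) & 0xc0) != 0x80) ++count;  // lead bytes only
  return count;
}
// count_code_points_sketch("a\xc3\xa9z", 4) == 3: 'a', U+00E9 (2 bytes), 'z'.
// fmt's index variant follows: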
+inline auto code_point_index(string_view s, size_t n) -> size_t { + const char* data = s.data(); + size_t num_code_points = 0; + for (size_t i = 0, size = s.size(); i != size; ++i) { + if ((data[i] & 0xc0) != 0x80 && ++num_code_points > n) return i; + } + return s.size(); +} + +inline auto code_point_index(basic_string_view s, size_t n) + -> size_t { + return code_point_index( + string_view(reinterpret_cast(s.data()), s.size()), n); +} + +template struct is_integral : std::is_integral {}; +template <> struct is_integral : std::true_type {}; +template <> struct is_integral : std::true_type {}; + +template +using is_signed = + std::integral_constant::is_signed || + std::is_same::value>; + +template +using is_integer = + bool_constant::value && !std::is_same::value && + !std::is_same::value && + !std::is_same::value>; + +#ifndef FMT_USE_FLOAT +# define FMT_USE_FLOAT 1 +#endif +#ifndef FMT_USE_DOUBLE +# define FMT_USE_DOUBLE 1 +#endif +#ifndef FMT_USE_LONG_DOUBLE +# define FMT_USE_LONG_DOUBLE 1 +#endif + +#ifndef FMT_USE_FLOAT128 +# ifdef __clang__ +// Clang emulates GCC, so it has to appear early. +# if FMT_HAS_INCLUDE() +# define FMT_USE_FLOAT128 1 +# endif +# elif defined(__GNUC__) +// GNU C++: +# if defined(_GLIBCXX_USE_FLOAT128) && !defined(__STRICT_ANSI__) +# define FMT_USE_FLOAT128 1 +# endif +# endif +# ifndef FMT_USE_FLOAT128 +# define FMT_USE_FLOAT128 0 +# endif +#endif + +#if FMT_USE_FLOAT128 +using float128 = __float128; +#else +using float128 = void; +#endif +template using is_float128 = std::is_same; + +template +using is_floating_point = + bool_constant::value || is_float128::value>; + +template ::value> +struct is_fast_float : bool_constant::is_iec559 && + sizeof(T) <= sizeof(double)> {}; +template struct is_fast_float : std::false_type {}; + +template +using is_double_double = bool_constant::digits == 106>; + +#ifndef FMT_USE_FULL_CACHE_DRAGONBOX +# define FMT_USE_FULL_CACHE_DRAGONBOX 0 +#endif + +template +template +void buffer::append(const U* begin, const U* end) { + while (begin != end) { + auto count = to_unsigned(end - begin); + try_reserve(size_ + count); + auto free_cap = capacity_ - size_; + if (free_cap < count) count = free_cap; + std::uninitialized_copy_n(begin, count, ptr_ + size_); + size_ += count; + begin += count; + } +} + +template +struct is_locale : std::false_type {}; +template +struct is_locale> : std::true_type {}; +} // namespace detail + +FMT_BEGIN_EXPORT + +// The number of characters to store in the basic_memory_buffer object itself +// to avoid dynamic memory allocation. +enum { inline_buffer_size = 500 }; + +/** + \rst + A dynamically growing memory buffer for trivially copyable/constructible types + with the first ``SIZE`` elements stored in the object itself. + + You can use the ``memory_buffer`` type alias for ``char`` instead. + + **Example**:: + + auto out = fmt::memory_buffer(); + format_to(std::back_inserter(out), "The answer is {}.", 42); + + This will append the following output to the ``out`` object: + + .. code-block:: none + + The answer is 42. + + The output can be converted to an ``std::string`` with ``to_string(out)``. + \endrst + */ +template > +class basic_memory_buffer final : public detail::buffer { + private: + T store_[SIZE]; + + // Don't inherit from Allocator to avoid generating type_info for it. + FMT_NO_UNIQUE_ADDRESS Allocator alloc_; + + // Deallocate memory allocated by the buffer. 
+ FMT_CONSTEXPR20 void deallocate() { + T* data = this->data(); + if (data != store_) alloc_.deallocate(data, this->capacity()); + } + + protected: + FMT_CONSTEXPR20 void grow(size_t size) override { + detail::abort_fuzzing_if(size > 5000); + const size_t max_size = std::allocator_traits::max_size(alloc_); + size_t old_capacity = this->capacity(); + size_t new_capacity = old_capacity + old_capacity / 2; + if (size > new_capacity) + new_capacity = size; + else if (new_capacity > max_size) + new_capacity = size > max_size ? size : max_size; + T* old_data = this->data(); + T* new_data = + std::allocator_traits::allocate(alloc_, new_capacity); + // Suppress a bogus -Wstringop-overflow in gcc 13.1 (#3481). + detail::assume(this->size() <= new_capacity); + // The following code doesn't throw, so the raw pointer above doesn't leak. + std::uninitialized_copy_n(old_data, this->size(), new_data); + this->set(new_data, new_capacity); + // deallocate must not throw according to the standard, but even if it does, + // the buffer already uses the new storage and will deallocate it in + // destructor. + if (old_data != store_) alloc_.deallocate(old_data, old_capacity); + } + + public: + using value_type = T; + using const_reference = const T&; + + FMT_CONSTEXPR20 explicit basic_memory_buffer( + const Allocator& alloc = Allocator()) + : alloc_(alloc) { + this->set(store_, SIZE); + if (detail::is_constant_evaluated()) detail::fill_n(store_, SIZE, T()); + } + FMT_CONSTEXPR20 ~basic_memory_buffer() { deallocate(); } + + private: + // Move data from other to this buffer. + FMT_CONSTEXPR20 void move(basic_memory_buffer& other) { + alloc_ = std::move(other.alloc_); + T* data = other.data(); + size_t size = other.size(), capacity = other.capacity(); + if (data == other.store_) { + this->set(store_, capacity); + detail::copy_str(other.store_, other.store_ + size, store_); + } else { + this->set(data, capacity); + // Set pointer to the inline array so that delete is not called + // when deallocating. + other.set(other.store_, 0); + other.clear(); + } + this->resize(size); + } + + public: + /** + \rst + Constructs a :class:`fmt::basic_memory_buffer` object moving the content + of the other object to it. + \endrst + */ + FMT_CONSTEXPR20 basic_memory_buffer(basic_memory_buffer&& other) noexcept { + move(other); + } + + /** + \rst + Moves the content of the other ``basic_memory_buffer`` object to this one. + \endrst + */ + auto operator=(basic_memory_buffer&& other) noexcept -> basic_memory_buffer& { + FMT_ASSERT(this != &other, ""); + deallocate(); + move(other); + return *this; + } + + // Returns a copy of the allocator associated with this buffer. + auto get_allocator() const -> Allocator { return alloc_; } + + /** + Resizes the buffer to contain *count* elements. If T is a POD type new + elements may not be initialized. + */ + FMT_CONSTEXPR20 void resize(size_t count) { this->try_resize(count); } + + /** Increases the buffer capacity to *new_capacity*. 
*/ + void reserve(size_t new_capacity) { this->try_reserve(new_capacity); } + + // Directly append data into the buffer + using detail::buffer::append; + template + void append(const ContiguousRange& range) { + append(range.data(), range.data() + range.size()); + } +}; + +using memory_buffer = basic_memory_buffer; + +template +struct is_contiguous> : std::true_type { +}; + +FMT_END_EXPORT +namespace detail { +FMT_API bool write_console(std::FILE* f, string_view text); +FMT_API void print(std::FILE*, string_view); +} // namespace detail + +FMT_BEGIN_EXPORT + +// Suppress a misleading warning in older versions of clang. +#if FMT_CLANG_VERSION +# pragma clang diagnostic ignored "-Wweak-vtables" +#endif + +/** An error reported from a formatting function. */ +class FMT_VISIBILITY("default") format_error : public std::runtime_error { + public: + using std::runtime_error::runtime_error; +}; + +namespace detail_exported { +#if FMT_USE_NONTYPE_TEMPLATE_ARGS +template struct fixed_string { + constexpr fixed_string(const Char (&str)[N]) { + detail::copy_str(static_cast(str), + str + N, data); + } + Char data[N] = {}; +}; +#endif + +// Converts a compile-time string to basic_string_view. +template +constexpr auto compile_string_to_view(const Char (&s)[N]) + -> basic_string_view { + // Remove trailing NUL character if needed. Won't be present if this is used + // with a raw character array (i.e. not defined as a string). + return {s, N - (std::char_traits::to_int_type(s[N - 1]) == 0 ? 1 : 0)}; +} +template +constexpr auto compile_string_to_view(detail::std_string_view s) + -> basic_string_view { + return {s.data(), s.size()}; +} +} // namespace detail_exported + +class loc_value { + private: + basic_format_arg value_; + + public: + template ::value)> + loc_value(T value) : value_(detail::make_arg(value)) {} + + template ::value)> + loc_value(T) {} + + template auto visit(Visitor&& vis) -> decltype(vis(0)) { + return visit_format_arg(vis, value_); + } +}; + +// A locale facet that formats values in UTF-8. +// It is parameterized on the locale to avoid the heavy include. +template class format_facet : public Locale::facet { + private: + std::string separator_; + std::string grouping_; + std::string decimal_point_; + + protected: + virtual auto do_put(appender out, loc_value val, + const format_specs<>& specs) const -> bool; + + public: + static FMT_API typename Locale::id id; + + explicit format_facet(Locale& loc); + explicit format_facet(string_view sep = "", + std::initializer_list g = {3}, + std::string decimal_point = ".") + : separator_(sep.data(), sep.size()), + grouping_(g.begin(), g.end()), + decimal_point_(decimal_point) {} + + auto put(appender out, loc_value val, const format_specs<>& specs) const + -> bool { + return do_put(out, val, specs); + } +}; + +namespace detail { + +// Returns true if value is negative, false otherwise. +// Same as `value < 0` but doesn't produce warnings if T is an unsigned type. +template ::value)> +constexpr auto is_negative(T value) -> bool { + return value < 0; +} +template ::value)> +constexpr auto is_negative(T) -> bool { + return false; +} + +template +FMT_CONSTEXPR auto is_supported_floating_point(T) -> bool { + if (std::is_same()) return FMT_USE_FLOAT; + if (std::is_same()) return FMT_USE_DOUBLE; + if (std::is_same()) return FMT_USE_LONG_DOUBLE; + return true; +} + +// Smallest of uint32_t, uint64_t, uint128_t that is large enough to +// represent all values of an integral type T. 
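// [Editorial sketch, not part of fmt] A self-contained model of the carrier
// selection performed below, with the 128-bit tier dropped so it compiles
// everywhere (carrier_t_sketch is an invented name; assumes the usual
// 32/64-bit type widths):
template <typename T>
using carrier_t_sketch =
    conditional_t<std::numeric_limits<T>::digits <= 32, uint32_t, uint64_t>;
static_assert(std::is_same<carrier_t_sketch<int>, uint32_t>::value,
              "int has 31 value bits");
static_assert(std::is_same<carrier_t_sketch<long long>, uint64_t>::value,
              "long long has 63 value bits");
// std::numeric_limits<T>::digits excludes the sign bit, so signed types land
// in the same tier as their unsigned counterparts. fmt's three-tier alias
// follows: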
+template +using uint32_or_64_or_128_t = + conditional_t() <= 32 && !FMT_REDUCE_INT_INSTANTIATIONS, + uint32_t, + conditional_t() <= 64, uint64_t, uint128_t>>; +template +using uint64_or_128_t = conditional_t() <= 64, uint64_t, uint128_t>; + +#define FMT_POWERS_OF_10(factor) \ + factor * 10, (factor)*100, (factor)*1000, (factor)*10000, (factor)*100000, \ + (factor)*1000000, (factor)*10000000, (factor)*100000000, \ + (factor)*1000000000 + +// Converts value in the range [0, 100) to a string. +constexpr const char* digits2(size_t value) { + // GCC generates slightly better code when value is pointer-size. + return &"0001020304050607080910111213141516171819" + "2021222324252627282930313233343536373839" + "4041424344454647484950515253545556575859" + "6061626364656667686970717273747576777879" + "8081828384858687888990919293949596979899"[value * 2]; +} + +// Sign is a template parameter to workaround a bug in gcc 4.8. +template constexpr Char sign(Sign s) { +#if !FMT_GCC_VERSION || FMT_GCC_VERSION >= 604 + static_assert(std::is_same::value, ""); +#endif + return static_cast("\0-+ "[s]); +} + +template FMT_CONSTEXPR auto count_digits_fallback(T n) -> int { + int count = 1; + for (;;) { + // Integer division is slow so do it for a group of four digits instead + // of for every digit. The idea comes from the talk by Alexandrescu + // "Three Optimization Tips for C++". See speed-test for a comparison. + if (n < 10) return count; + if (n < 100) return count + 1; + if (n < 1000) return count + 2; + if (n < 10000) return count + 3; + n /= 10000u; + count += 4; + } +} +#if FMT_USE_INT128 +FMT_CONSTEXPR inline auto count_digits(uint128_opt n) -> int { + return count_digits_fallback(n); +} +#endif + +#ifdef FMT_BUILTIN_CLZLL +// It is a separate function rather than a part of count_digits to workaround +// the lack of static constexpr in constexpr functions. +inline auto do_count_digits(uint64_t n) -> int { + // This has comparable performance to the version by Kendall Willets + // (https://github.com/fmtlib/format-benchmark/blob/master/digits10) + // but uses smaller tables. + // Maps bsr(n) to ceil(log10(pow(2, bsr(n) + 1) - 1)). + static constexpr uint8_t bsr2log10[] = { + 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, + 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 9, 9, 9, 10, 10, 10, + 10, 11, 11, 11, 12, 12, 12, 13, 13, 13, 13, 14, 14, 14, 15, 15, + 15, 16, 16, 16, 16, 17, 17, 17, 18, 18, 18, 19, 19, 19, 19, 20}; + auto t = bsr2log10[FMT_BUILTIN_CLZLL(n | 1) ^ 63]; + static constexpr const uint64_t zero_or_powers_of_10[] = { + 0, 0, FMT_POWERS_OF_10(1U), FMT_POWERS_OF_10(1000000000ULL), + 10000000000000000000ULL}; + return t - (n < zero_or_powers_of_10[t]); +} +#endif + +// Returns the number of decimal digits in n. Leading zeros are not counted +// except for n == 0 in which case count_digits returns 1. +FMT_CONSTEXPR20 inline auto count_digits(uint64_t n) -> int { +#ifdef FMT_BUILTIN_CLZLL + if (!is_constant_evaluated()) { + return do_count_digits(n); + } +#endif + return count_digits_fallback(n); +} + +// Counts the number of digits in n. BITS = log2(radix). +template +FMT_CONSTEXPR auto count_digits(UInt n) -> int { +#ifdef FMT_BUILTIN_CLZ + if (!is_constant_evaluated() && num_bits() == 32) + return (FMT_BUILTIN_CLZ(static_cast(n) | 1) ^ 31) / BITS + 1; +#endif + // Lambda avoids unreachable code warnings from NVHPC. 
+ return [](UInt m) { + int num_digits = 0; + do { + ++num_digits; + } while ((m >>= BITS) != 0); + return num_digits; + }(n); +} + +#ifdef FMT_BUILTIN_CLZ +// It is a separate function rather than a part of count_digits to workaround +// the lack of static constexpr in constexpr functions. +FMT_INLINE auto do_count_digits(uint32_t n) -> int { +// An optimization by Kendall Willets from https://bit.ly/3uOIQrB. +// This increments the upper 32 bits (log10(T) - 1) when >= T is added. +# define FMT_INC(T) (((sizeof(#T) - 1ull) << 32) - T) + static constexpr uint64_t table[] = { + FMT_INC(0), FMT_INC(0), FMT_INC(0), // 8 + FMT_INC(10), FMT_INC(10), FMT_INC(10), // 64 + FMT_INC(100), FMT_INC(100), FMT_INC(100), // 512 + FMT_INC(1000), FMT_INC(1000), FMT_INC(1000), // 4096 + FMT_INC(10000), FMT_INC(10000), FMT_INC(10000), // 32k + FMT_INC(100000), FMT_INC(100000), FMT_INC(100000), // 256k + FMT_INC(1000000), FMT_INC(1000000), FMT_INC(1000000), // 2048k + FMT_INC(10000000), FMT_INC(10000000), FMT_INC(10000000), // 16M + FMT_INC(100000000), FMT_INC(100000000), FMT_INC(100000000), // 128M + FMT_INC(1000000000), FMT_INC(1000000000), FMT_INC(1000000000), // 1024M + FMT_INC(1000000000), FMT_INC(1000000000) // 4B + }; + auto inc = table[FMT_BUILTIN_CLZ(n | 1) ^ 31]; + return static_cast((n + inc) >> 32); +} +#endif + +// Optional version of count_digits for better performance on 32-bit platforms. +FMT_CONSTEXPR20 inline auto count_digits(uint32_t n) -> int { +#ifdef FMT_BUILTIN_CLZ + if (!is_constant_evaluated()) { + return do_count_digits(n); + } +#endif + return count_digits_fallback(n); +} + +template constexpr auto digits10() noexcept -> int { + return std::numeric_limits::digits10; +} +template <> constexpr auto digits10() noexcept -> int { return 38; } +template <> constexpr auto digits10() noexcept -> int { return 38; } + +template struct thousands_sep_result { + std::string grouping; + Char thousands_sep; +}; + +template +FMT_API auto thousands_sep_impl(locale_ref loc) -> thousands_sep_result; +template +inline auto thousands_sep(locale_ref loc) -> thousands_sep_result { + auto result = thousands_sep_impl(loc); + return {result.grouping, Char(result.thousands_sep)}; +} +template <> +inline auto thousands_sep(locale_ref loc) -> thousands_sep_result { + return thousands_sep_impl(loc); +} + +template +FMT_API auto decimal_point_impl(locale_ref loc) -> Char; +template inline auto decimal_point(locale_ref loc) -> Char { + return Char(decimal_point_impl(loc)); +} +template <> inline auto decimal_point(locale_ref loc) -> wchar_t { + return decimal_point_impl(loc); +} + +// Compares two characters for equality. +template auto equal2(const Char* lhs, const char* rhs) -> bool { + return lhs[0] == Char(rhs[0]) && lhs[1] == Char(rhs[1]); +} +inline auto equal2(const char* lhs, const char* rhs) -> bool { + return memcmp(lhs, rhs, 2) == 0; +} + +// Copies two characters from src to dst. +template +FMT_CONSTEXPR20 FMT_INLINE void copy2(Char* dst, const char* src) { + if (!is_constant_evaluated() && sizeof(Char) == sizeof(char)) { + memcpy(dst, src, 2); + return; + } + *dst++ = static_cast(*src++); + *dst = static_cast(*src); +} + +template struct format_decimal_result { + Iterator begin; + Iterator end; +}; + +// Formats a decimal unsigned integer value writing into out pointing to a +// buffer of specified size. The caller must ensure that the buffer is large +// enough. 
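// [Editorial sketch, not part of fmt] format_decimal below fills the buffer
// from the right, retiring two digits per integer division via the digits2
// table, then one final one- or two-digit step. A freestanding version of
// the same loop (to_decimal_sketch is an invented name; assumes <string>,
// which core.h already pulls in):
inline std::string to_decimal_sketch(uint64_t value) {
  char buf[20];  // uint64_t needs at most 20 decimal digits
  char* p = buf + sizeof(buf);
  while (value >= 100) {
    p -= 2;
    const char* d = digits2(static_cast<size_t>(value % 100));
    p[0] = d[0];  // one division yields two output digits
    p[1] = d[1];
    value /= 100;
  }
  if (value >= 10) {
    p -= 2;
    const char* d = digits2(static_cast<size_t>(value));
    p[0] = d[0];
    p[1] = d[1];
  } else {
    *--p = static_cast<char>('0' + value);
  }
  return std::string(p, buf + sizeof(buf));
}
// to_decimal_sketch(1844674407370955161ull) == "1844674407370955161".
// fmt's iterator-based implementation follows: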
+template +FMT_CONSTEXPR20 auto format_decimal(Char* out, UInt value, int size) + -> format_decimal_result { + FMT_ASSERT(size >= count_digits(value), "invalid digit count"); + out += size; + Char* end = out; + while (value >= 100) { + // Integer division is slow so do it for a group of two digits instead + // of for every digit. The idea comes from the talk by Alexandrescu + // "Three Optimization Tips for C++". See speed-test for a comparison. + out -= 2; + copy2(out, digits2(static_cast(value % 100))); + value /= 100; + } + if (value < 10) { + *--out = static_cast('0' + value); + return {out, end}; + } + out -= 2; + copy2(out, digits2(static_cast(value))); + return {out, end}; +} + +template >::value)> +FMT_CONSTEXPR inline auto format_decimal(Iterator out, UInt value, int size) + -> format_decimal_result { + // Buffer is large enough to hold all digits (digits10 + 1). + Char buffer[digits10() + 1] = {}; + auto end = format_decimal(buffer, value, size).end; + return {out, detail::copy_str_noinline(buffer, end, out)}; +} + +template +FMT_CONSTEXPR auto format_uint(Char* buffer, UInt value, int num_digits, + bool upper = false) -> Char* { + buffer += num_digits; + Char* end = buffer; + do { + const char* digits = upper ? "0123456789ABCDEF" : "0123456789abcdef"; + unsigned digit = static_cast(value & ((1 << BASE_BITS) - 1)); + *--buffer = static_cast(BASE_BITS < 4 ? static_cast('0' + digit) + : digits[digit]); + } while ((value >>= BASE_BITS) != 0); + return end; +} + +template +FMT_CONSTEXPR inline auto format_uint(It out, UInt value, int num_digits, + bool upper = false) -> It { + if (auto ptr = to_pointer(out, to_unsigned(num_digits))) { + format_uint(ptr, value, num_digits, upper); + return out; + } + // Buffer should be large enough to hold all digits (digits / BASE_BITS + 1). + char buffer[num_bits() / BASE_BITS + 1]; + format_uint(buffer, value, num_digits, upper); + return detail::copy_str_noinline(buffer, buffer + num_digits, out); +} + +// A converter from UTF-8 to UTF-16. +class utf8_to_utf16 { + private: + basic_memory_buffer buffer_; + + public: + FMT_API explicit utf8_to_utf16(string_view s); + operator basic_string_view() const { return {&buffer_[0], size()}; } + auto size() const -> size_t { return buffer_.size() - 1; } + auto c_str() const -> const wchar_t* { return &buffer_[0]; } + auto str() const -> std::wstring { return {&buffer_[0], size()}; } +}; + +enum class to_utf8_error_policy { abort, replace }; + +// A converter from UTF-16/UTF-32 (host endian) to UTF-8. +template class to_utf8 { + private: + Buffer buffer_; + + public: + to_utf8() {} + explicit to_utf8(basic_string_view s, + to_utf8_error_policy policy = to_utf8_error_policy::abort) { + static_assert(sizeof(WChar) == 2 || sizeof(WChar) == 4, + "Expect utf16 or utf32"); + if (!convert(s, policy)) + FMT_THROW(std::runtime_error(sizeof(WChar) == 2 ? "invalid utf16" + : "invalid utf32")); + } + operator string_view() const { return string_view(&buffer_[0], size()); } + size_t size() const { return buffer_.size() - 1; } + const char* c_str() const { return &buffer_[0]; } + std::string str() const { return std::string(&buffer_[0], size()); } + + // Performs conversion returning a bool instead of throwing exception on + // conversion error. This method may still throw in case of memory allocation + // error. 
+ bool convert(basic_string_view s, + to_utf8_error_policy policy = to_utf8_error_policy::abort) { + if (!convert(buffer_, s, policy)) return false; + buffer_.push_back(0); + return true; + } + static bool convert( + Buffer& buf, basic_string_view s, + to_utf8_error_policy policy = to_utf8_error_policy::abort) { + for (auto p = s.begin(); p != s.end(); ++p) { + uint32_t c = static_cast(*p); + if (sizeof(WChar) == 2 && c >= 0xd800 && c <= 0xdfff) { + // Handle a surrogate pair. + ++p; + if (p == s.end() || (c & 0xfc00) != 0xd800 || (*p & 0xfc00) != 0xdc00) { + if (policy == to_utf8_error_policy::abort) return false; + buf.append(string_view("\xEF\xBF\xBD")); + --p; + } else { + c = (c << 10) + static_cast(*p) - 0x35fdc00; + } + } else if (c < 0x80) { + buf.push_back(static_cast(c)); + } else if (c < 0x800) { + buf.push_back(static_cast(0xc0 | (c >> 6))); + buf.push_back(static_cast(0x80 | (c & 0x3f))); + } else if ((c >= 0x800 && c <= 0xd7ff) || (c >= 0xe000 && c <= 0xffff)) { + buf.push_back(static_cast(0xe0 | (c >> 12))); + buf.push_back(static_cast(0x80 | ((c & 0xfff) >> 6))); + buf.push_back(static_cast(0x80 | (c & 0x3f))); + } else if (c >= 0x10000 && c <= 0x10ffff) { + buf.push_back(static_cast(0xf0 | (c >> 18))); + buf.push_back(static_cast(0x80 | ((c & 0x3ffff) >> 12))); + buf.push_back(static_cast(0x80 | ((c & 0xfff) >> 6))); + buf.push_back(static_cast(0x80 | (c & 0x3f))); + } else { + return false; + } + } + return true; + } +}; + +// Computes 128-bit result of multiplication of two 64-bit unsigned integers. +inline uint128_fallback umul128(uint64_t x, uint64_t y) noexcept { +#if FMT_USE_INT128 + auto p = static_cast(x) * static_cast(y); + return {static_cast(p >> 64), static_cast(p)}; +#elif defined(_MSC_VER) && defined(_M_X64) + auto hi = uint64_t(); + auto lo = _umul128(x, y, &hi); + return {hi, lo}; +#else + const uint64_t mask = static_cast(max_value()); + + uint64_t a = x >> 32; + uint64_t b = x & mask; + uint64_t c = y >> 32; + uint64_t d = y & mask; + + uint64_t ac = a * c; + uint64_t bc = b * c; + uint64_t ad = a * d; + uint64_t bd = b * d; + + uint64_t intermediate = (bd >> 32) + (ad & mask) + (bc & mask); + + return {ac + (intermediate >> 32) + (ad >> 32) + (bc >> 32), + (intermediate << 32) + (bd & mask)}; +#endif +} + +namespace dragonbox { +// Computes floor(log10(pow(2, e))) for e in [-2620, 2620] using the method from +// https://fmt.dev/papers/Dragonbox.pdf#page=28, section 6.1. +inline int floor_log10_pow2(int e) noexcept { + FMT_ASSERT(e <= 2620 && e >= -2620, "too large exponent"); + static_assert((-1 >> 1) == -1, "right shift is not arithmetic"); + return (e * 315653) >> 20; +} + +inline int floor_log2_pow10(int e) noexcept { + FMT_ASSERT(e <= 1233 && e >= -1233, "too large exponent"); + return (e * 1741647) >> 19; +} + +// Computes upper 64 bits of multiplication of two 64-bit unsigned integers. +inline uint64_t umul128_upper64(uint64_t x, uint64_t y) noexcept { +#if FMT_USE_INT128 + auto p = static_cast(x) * static_cast(y); + return static_cast(p >> 64); +#elif defined(_MSC_VER) && defined(_M_X64) + return __umulh(x, y); +#else + return umul128(x, y).high(); +#endif +} + +// Computes upper 128 bits of multiplication of a 64-bit unsigned integer and a +// 128-bit unsigned integer. 
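// [Editorial sketch, not part of fmt] On the fixed-point trick used by
// floor_log10_pow2 above: 315653 is round(2^20 * log10(2)) (log10(2) ~
// 0.30102999..., times 2^20 ~ 315652.8), so (e * 315653) >> 20 evaluates
// floor(e * log10(2)) exactly over the asserted range; the arithmetic right
// shift doubles as a floor for negative e. Spot checks (invented name):
constexpr int floor_log10_pow2_sketch(int e) { return (e * 315653) >> 20; }
static_assert(floor_log10_pow2_sketch(10) == 3, "2^10 = 1024, 4 digits");
static_assert(floor_log10_pow2_sketch(64) == 19, "2^64 ~ 1.8e19");
static_assert(floor_log10_pow2_sketch(-10) == -4, "floor(-3.01) == -4");
// Back to the multiplication helpers; umul192_upper128 follows: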
+inline uint128_fallback umul192_upper128(uint64_t x, + uint128_fallback y) noexcept { + uint128_fallback r = umul128(x, y.high()); + r += umul128_upper64(x, y.low()); + return r; +} + +FMT_API uint128_fallback get_cached_power(int k) noexcept; + +// Type-specific information that Dragonbox uses. +template struct float_info; + +template <> struct float_info { + using carrier_uint = uint32_t; + static const int exponent_bits = 8; + static const int kappa = 1; + static const int big_divisor = 100; + static const int small_divisor = 10; + static const int min_k = -31; + static const int max_k = 46; + static const int shorter_interval_tie_lower_threshold = -35; + static const int shorter_interval_tie_upper_threshold = -35; +}; + +template <> struct float_info { + using carrier_uint = uint64_t; + static const int exponent_bits = 11; + static const int kappa = 2; + static const int big_divisor = 1000; + static const int small_divisor = 100; + static const int min_k = -292; + static const int max_k = 341; + static const int shorter_interval_tie_lower_threshold = -77; + static const int shorter_interval_tie_upper_threshold = -77; +}; + +// An 80- or 128-bit floating point number. +template +struct float_info::digits == 64 || + std::numeric_limits::digits == 113 || + is_float128::value>> { + using carrier_uint = detail::uint128_t; + static const int exponent_bits = 15; +}; + +// A double-double floating point number. +template +struct float_info::value>> { + using carrier_uint = detail::uint128_t; +}; + +template struct decimal_fp { + using significand_type = typename float_info::carrier_uint; + significand_type significand; + int exponent; +}; + +template FMT_API auto to_decimal(T x) noexcept -> decimal_fp; +} // namespace dragonbox + +// Returns true iff Float has the implicit bit which is not stored. +template constexpr bool has_implicit_bit() { + // An 80-bit FP number has a 64-bit significand an no implicit bit. + return std::numeric_limits::digits != 64; +} + +// Returns the number of significand bits stored in Float. The implicit bit is +// not counted since it is not stored. +template constexpr int num_significand_bits() { + // std::numeric_limits may not support __float128. + return is_float128() ? 112 + : (std::numeric_limits::digits - + (has_implicit_bit() ? 1 : 0)); +} + +template +constexpr auto exponent_mask() -> + typename dragonbox::float_info::carrier_uint { + using float_uint = typename dragonbox::float_info::carrier_uint; + return ((float_uint(1) << dragonbox::float_info::exponent_bits) - 1) + << num_significand_bits(); +} +template constexpr auto exponent_bias() -> int { + // std::numeric_limits may not support __float128. + return is_float128() ? 16383 + : std::numeric_limits::max_exponent - 1; +} + +// Writes the exponent exp in the form "[+-]d{2,3}" to buffer. +template +FMT_CONSTEXPR auto write_exponent(int exp, It it) -> It { + FMT_ASSERT(-10000 < exp && exp < 10000, "exponent out of range"); + if (exp < 0) { + *it++ = static_cast('-'); + exp = -exp; + } else { + *it++ = static_cast('+'); + } + if (exp >= 100) { + const char* top = digits2(to_unsigned(exp / 100)); + if (exp >= 1000) *it++ = static_cast(top[0]); + *it++ = static_cast(top[1]); + exp %= 100; + } + const char* d = digits2(to_unsigned(exp)); + *it++ = static_cast(d[0]); + *it++ = static_cast(d[1]); + return it; +} + +// A floating-point number f * pow(2, e) where F is an unsigned type. 
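// [Editorial sketch, not part of fmt] basic_fp::assign below undoes the
// IEEE-754 encoding: the trailing significand bits get the hidden leading 1
// restored, and the exponent is unbiased and shifted by the significand
// width so that n == f * 2^e exactly. A standalone decomposition for double
// (invented names; assumes IEEE-754 binary64 and ignores the
// zero/subnormal/infinity/NaN cases that assign handles):
struct fp_sketch {
  uint64_t f;
  int e;
};
inline fp_sketch decompose_sketch(double n) {
  uint64_t u;
  std::memcpy(&u, &n, sizeof(u));  // pre-C++20 bit_cast
  const uint64_t significand_mask = (uint64_t(1) << 52) - 1;
  int biased_e = static_cast<int>((u >> 52) & 0x7ff);
  uint64_t f = (u & significand_mask) | (uint64_t(1) << 52);  // hidden bit
  return {f, biased_e - 1023 - 52};
}
// decompose_sketch(1.5) yields f == 3 * 2^51 and e == -52, and indeed
// 3 * 2^51 * 2^-52 == 1.5. fmt's generic version follows: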
+template struct basic_fp { + F f; + int e; + + static constexpr const int num_significand_bits = + static_cast(sizeof(F) * num_bits()); + + constexpr basic_fp() : f(0), e(0) {} + constexpr basic_fp(uint64_t f_val, int e_val) : f(f_val), e(e_val) {} + + // Constructs fp from an IEEE754 floating-point number. + template FMT_CONSTEXPR basic_fp(Float n) { assign(n); } + + // Assigns n to this and return true iff predecessor is closer than successor. + template ::value)> + FMT_CONSTEXPR auto assign(Float n) -> bool { + static_assert(std::numeric_limits::digits <= 113, "unsupported FP"); + // Assume Float is in the format [sign][exponent][significand]. + using carrier_uint = typename dragonbox::float_info::carrier_uint; + const auto num_float_significand_bits = + detail::num_significand_bits(); + const auto implicit_bit = carrier_uint(1) << num_float_significand_bits; + const auto significand_mask = implicit_bit - 1; + auto u = bit_cast(n); + f = static_cast(u & significand_mask); + auto biased_e = static_cast((u & exponent_mask()) >> + num_float_significand_bits); + // The predecessor is closer if n is a normalized power of 2 (f == 0) + // other than the smallest normalized number (biased_e > 1). + auto is_predecessor_closer = f == 0 && biased_e > 1; + if (biased_e == 0) + biased_e = 1; // Subnormals use biased exponent 1 (min exponent). + else if (has_implicit_bit()) + f += static_cast(implicit_bit); + e = biased_e - exponent_bias() - num_float_significand_bits; + if (!has_implicit_bit()) ++e; + return is_predecessor_closer; + } + + template ::value)> + FMT_CONSTEXPR auto assign(Float n) -> bool { + static_assert(std::numeric_limits::is_iec559, "unsupported FP"); + return assign(static_cast(n)); + } +}; + +using fp = basic_fp; + +// Normalizes the value converted from double and multiplied by (1 << SHIFT). +template +FMT_CONSTEXPR basic_fp normalize(basic_fp value) { + // Handle subnormals. + const auto implicit_bit = F(1) << num_significand_bits(); + const auto shifted_implicit_bit = implicit_bit << SHIFT; + while ((value.f & shifted_implicit_bit) == 0) { + value.f <<= 1; + --value.e; + } + // Subtract 1 to account for hidden bit. + const auto offset = basic_fp::num_significand_bits - + num_significand_bits() - SHIFT - 1; + value.f <<= offset; + value.e -= offset; + return value; +} + +// Computes lhs * rhs / pow(2, 64) rounded to nearest with half-up tie breaking. +FMT_CONSTEXPR inline uint64_t multiply(uint64_t lhs, uint64_t rhs) { +#if FMT_USE_INT128 + auto product = static_cast<__uint128_t>(lhs) * rhs; + auto f = static_cast(product >> 64); + return (static_cast(product) & (1ULL << 63)) != 0 ? f + 1 : f; +#else + // Multiply 32-bit parts of significands. + uint64_t mask = (1ULL << 32) - 1; + uint64_t a = lhs >> 32, b = lhs & mask; + uint64_t c = rhs >> 32, d = rhs & mask; + uint64_t ac = a * c, bc = b * c, ad = a * d, bd = b * d; + // Compute mid 64-bit of result and round. + uint64_t mid = (bd >> 32) + (ad & mask) + (bc & mask) + (1U << 31); + return ac + (ad >> 32) + (bc >> 32) + (mid >> 32); +#endif +} + +FMT_CONSTEXPR inline fp operator*(fp x, fp y) { + return {multiply(x.f, y.f), x.e + y.e + 64}; +} + +template struct basic_data { + // For checking rounding thresholds. + // The kth entry is chosen to be the smallest integer such that the + // upper 32-bits of 10^(k+1) times it is strictly bigger than 5 * 10^k. 
+ static constexpr uint32_t fractional_part_rounding_thresholds[8] = { + 2576980378U, // ceil(2^31 + 2^32/10^1) + 2190433321U, // ceil(2^31 + 2^32/10^2) + 2151778616U, // ceil(2^31 + 2^32/10^3) + 2147913145U, // ceil(2^31 + 2^32/10^4) + 2147526598U, // ceil(2^31 + 2^32/10^5) + 2147487943U, // ceil(2^31 + 2^32/10^6) + 2147484078U, // ceil(2^31 + 2^32/10^7) + 2147483691U // ceil(2^31 + 2^32/10^8) + }; +}; +// This is a struct rather than an alias to avoid shadowing warnings in gcc. +struct data : basic_data<> {}; + +#if FMT_CPLUSPLUS < 201703L +template +constexpr uint32_t basic_data::fractional_part_rounding_thresholds[]; +#endif + +template () == num_bits()> +using convert_float_result = + conditional_t::value || doublish, double, T>; + +template +constexpr auto convert_float(T value) -> convert_float_result { + return static_cast>(value); +} + +template +FMT_NOINLINE FMT_CONSTEXPR auto fill(OutputIt it, size_t n, + const fill_t& fill) -> OutputIt { + auto fill_size = fill.size(); + if (fill_size == 1) return detail::fill_n(it, n, fill[0]); + auto data = fill.data(); + for (size_t i = 0; i < n; ++i) + it = copy_str(data, data + fill_size, it); + return it; +} + +// Writes the output of f, padded according to format specifications in specs. +// size: output size in code units. +// width: output display width in (terminal) column positions. +template +FMT_CONSTEXPR auto write_padded(OutputIt out, const format_specs& specs, + size_t size, size_t width, F&& f) -> OutputIt { + static_assert(align == align::left || align == align::right, ""); + unsigned spec_width = to_unsigned(specs.width); + size_t padding = spec_width > width ? spec_width - width : 0; + // Shifts are encoded as string literals because static constexpr is not + // supported in constexpr functions. + auto* shifts = align == align::left ? "\x1f\x1f\x00\x01" : "\x00\x1f\x00\x01"; + size_t left_padding = padding >> shifts[specs.align]; + size_t right_padding = padding - left_padding; + auto it = reserve(out, size + padding * specs.fill.size()); + if (left_padding != 0) it = fill(it, left_padding, specs.fill); + it = f(it); + if (right_padding != 0) it = fill(it, right_padding, specs.fill); + return base_iterator(out, it); +} + +template +constexpr auto write_padded(OutputIt out, const format_specs& specs, + size_t size, F&& f) -> OutputIt { + return write_padded(out, specs, size, size, f); +} + +template +FMT_CONSTEXPR auto write_bytes(OutputIt out, string_view bytes, + const format_specs& specs) -> OutputIt { + return write_padded( + out, specs, bytes.size(), [bytes](reserve_iterator it) { + const char* data = bytes.data(); + return copy_str(data, data + bytes.size(), it); + }); +} + +template +auto write_ptr(OutputIt out, UIntPtr value, const format_specs* specs) + -> OutputIt { + int num_digits = count_digits<4>(value); + auto size = to_unsigned(num_digits) + size_t(2); + auto write = [=](reserve_iterator it) { + *it++ = static_cast('0'); + *it++ = static_cast('x'); + return format_uint<4, Char>(it, value, num_digits); + }; + return specs ? write_padded(out, *specs, size, write) + : base_iterator(out, write(reserve(out, size))); +} + +// Returns true iff the code point cp is printable. 
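// [Editorial sketch, not part of fmt] On the "\x1f\x1f\x00\x01" table in
// write_padded above: indexing it by the alignment enum and shifting the
// total padding right splits it without branches. A shift of 0 puts all
// padding on the left (right-aligned output), 31 effectively drops it
// (left-aligned), and 1 halves it (centered, extra fill on the right). The
// same split written out longhand (invented name; assumes enum order none,
// left, right, center with right as the default alignment):
inline void split_padding_sketch(size_t padding, int align, size_t& left,
                                 size_t& right) {
  const char shifts[] = {0, 31, 0, 1};  // none, left, right, center
  left = padding >> shifts[align];
  right = padding - left;
}
// split_padding_sketch(5, /*center=*/3, l, r) gives l == 2, r == 3.
// The printable-code-point check declared next is implemented in
// format-inl.h: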
+FMT_API auto is_printable(uint32_t cp) -> bool; + +inline auto needs_escape(uint32_t cp) -> bool { + return cp < 0x20 || cp == 0x7f || cp == '"' || cp == '\\' || + !is_printable(cp); +} + +template struct find_escape_result { + const Char* begin; + const Char* end; + uint32_t cp; +}; + +template +using make_unsigned_char = + typename conditional_t::value, + std::make_unsigned, + type_identity>::type; + +template +auto find_escape(const Char* begin, const Char* end) + -> find_escape_result { + for (; begin != end; ++begin) { + uint32_t cp = static_cast>(*begin); + if (const_check(sizeof(Char) == 1) && cp >= 0x80) continue; + if (needs_escape(cp)) return {begin, begin + 1, cp}; + } + return {begin, nullptr, 0}; +} + +inline auto find_escape(const char* begin, const char* end) + -> find_escape_result { + if (!is_utf8()) return find_escape(begin, end); + auto result = find_escape_result{end, nullptr, 0}; + for_each_codepoint(string_view(begin, to_unsigned(end - begin)), + [&](uint32_t cp, string_view sv) { + if (needs_escape(cp)) { + result = {sv.begin(), sv.end(), cp}; + return false; + } + return true; + }); + return result; +} + +#define FMT_STRING_IMPL(s, base, explicit) \ + [] { \ + /* Use the hidden visibility as a workaround for a GCC bug (#1973). */ \ + /* Use a macro-like name to avoid shadowing warnings. */ \ + struct FMT_VISIBILITY("hidden") FMT_COMPILE_STRING : base { \ + using char_type FMT_MAYBE_UNUSED = fmt::remove_cvref_t; \ + FMT_MAYBE_UNUSED FMT_CONSTEXPR explicit \ + operator fmt::basic_string_view() const { \ + return fmt::detail_exported::compile_string_to_view(s); \ + } \ + }; \ + return FMT_COMPILE_STRING(); \ + }() + +/** + \rst + Constructs a compile-time format string from a string literal *s*. + + **Example**:: + + // A compile-time error because 'd' is an invalid specifier for strings. 
+ std::string s = fmt::format(FMT_STRING("{:d}"), "foo"); + \endrst + */ +#define FMT_STRING(s) FMT_STRING_IMPL(s, fmt::detail::compile_string, ) + +template +auto write_codepoint(OutputIt out, char prefix, uint32_t cp) -> OutputIt { + *out++ = static_cast('\\'); + *out++ = static_cast(prefix); + Char buf[width]; + fill_n(buf, width, static_cast('0')); + format_uint<4>(buf, cp, width); + return copy_str(buf, buf + width, out); +} + +template +auto write_escaped_cp(OutputIt out, const find_escape_result& escape) + -> OutputIt { + auto c = static_cast(escape.cp); + switch (escape.cp) { + case '\n': + *out++ = static_cast('\\'); + c = static_cast('n'); + break; + case '\r': + *out++ = static_cast('\\'); + c = static_cast('r'); + break; + case '\t': + *out++ = static_cast('\\'); + c = static_cast('t'); + break; + case '"': + FMT_FALLTHROUGH; + case '\'': + FMT_FALLTHROUGH; + case '\\': + *out++ = static_cast('\\'); + break; + default: + if (escape.cp < 0x100) { + return write_codepoint<2, Char>(out, 'x', escape.cp); + } + if (escape.cp < 0x10000) { + return write_codepoint<4, Char>(out, 'u', escape.cp); + } + if (escape.cp < 0x110000) { + return write_codepoint<8, Char>(out, 'U', escape.cp); + } + for (Char escape_char : basic_string_view( + escape.begin, to_unsigned(escape.end - escape.begin))) { + out = write_codepoint<2, Char>(out, 'x', + static_cast(escape_char) & 0xFF); + } + return out; + } + *out++ = c; + return out; +} + +template +auto write_escaped_string(OutputIt out, basic_string_view str) + -> OutputIt { + *out++ = static_cast('"'); + auto begin = str.begin(), end = str.end(); + do { + auto escape = find_escape(begin, end); + out = copy_str(begin, escape.begin, out); + begin = escape.end; + if (!begin) break; + out = write_escaped_cp(out, escape); + } while (begin != end); + *out++ = static_cast('"'); + return out; +} + +template +auto write_escaped_char(OutputIt out, Char v) -> OutputIt { + *out++ = static_cast('\''); + if ((needs_escape(static_cast(v)) && v != static_cast('"')) || + v == static_cast('\'')) { + out = write_escaped_cp( + out, find_escape_result{&v, &v + 1, static_cast(v)}); + } else { + *out++ = v; + } + *out++ = static_cast('\''); + return out; +} + +template +FMT_CONSTEXPR auto write_char(OutputIt out, Char value, + const format_specs& specs) -> OutputIt { + bool is_debug = specs.type == presentation_type::debug; + return write_padded(out, specs, 1, [=](reserve_iterator it) { + if (is_debug) return write_escaped_char(it, value); + *it++ = value; + return it; + }); +} +template +FMT_CONSTEXPR auto write(OutputIt out, Char value, + const format_specs& specs, locale_ref loc = {}) + -> OutputIt { + // char is formatted as unsigned char for consistency across platforms. + using unsigned_type = + conditional_t::value, unsigned char, unsigned>; + return check_char_specs(specs) + ? write_char(out, value, specs) + : write(out, static_cast(value), specs, loc); +} + +// Data for write_int that doesn't depend on output iterator type. It is used to +// avoid template code bloat. 
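// [Editorial sketch, not part of fmt] write_int below packs the number's
// prefix ("-", "0x", ...) into a single unsigned: up to three prefix
// characters sit in the low three bytes, lowest byte emitted first, and the
// character count rides in the top byte (prefix >> 24). How prefix_append
// (defined further below) adds "0x" after an existing "-" (invented name;
// the multi-statement constexpr needs C++14):
constexpr unsigned packed_prefix_sketch() {
  unsigned prefix = 0x01000000 | '-';         // one character: '-'
  unsigned value = unsigned('x') << 8 | '0';  // the two-character "0x"
  prefix |= value << 8;                       // shift past the '-'
  prefix += 2u << 24;                         // two more characters
  return prefix;
}
static_assert(packed_prefix_sketch() >> 24 == 3, "three prefix characters");
static_assert((packed_prefix_sketch() & 0xff) == '-', "'-' is emitted first");
static_assert(((packed_prefix_sketch() >> 8) & 0xff) == '0', "then '0'");
static_assert(((packed_prefix_sketch() >> 16) & 0xff) == 'x', "then 'x'");
// The iterator-independent sizing data follows: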
+
+// Data for write_int that doesn't depend on output iterator type. It is used
+// to avoid template code bloat.
+template <typename Char> struct write_int_data {
+  size_t size;
+  size_t padding;
+
+  FMT_CONSTEXPR write_int_data(int num_digits, unsigned prefix,
+                               const format_specs<Char>& specs)
+      : size((prefix >> 24) + to_unsigned(num_digits)), padding(0) {
+    if (specs.align == align::numeric) {
+      auto width = to_unsigned(specs.width);
+      if (width > size) {
+        padding = width - size;
+        size = width;
+      }
+    } else if (specs.precision > num_digits) {
+      size = (prefix >> 24) + to_unsigned(specs.precision);
+      padding = to_unsigned(specs.precision - num_digits);
+    }
+  }
+};
+
+// Writes an integer in the format
+//   <left-padding><prefix><numeric-padding><digits><right-padding>
+// where <digits> are written by write_digits(it).
+// prefix contains chars in three lower bytes and the size in the fourth byte.
+template <typename Char, typename OutputIt, typename W>
+FMT_CONSTEXPR FMT_INLINE auto write_int(OutputIt out, int num_digits,
+                                        unsigned prefix,
+                                        const format_specs<Char>& specs,
+                                        W write_digits) -> OutputIt {
+  // Slightly faster check for specs.width == 0 && specs.precision == -1.
+  if ((specs.width | (specs.precision + 1)) == 0) {
+    auto it = reserve(out, to_unsigned(num_digits) + (prefix >> 24));
+    if (prefix != 0) {
+      for (unsigned p = prefix & 0xffffff; p != 0; p >>= 8)
+        *it++ = static_cast<Char>(p & 0xff);
+    }
+    return base_iterator(out, write_digits(it));
+  }
+  auto data = write_int_data<Char>(num_digits, prefix, specs);
+  return write_padded<align::right>(
+      out, specs, data.size, [=](reserve_iterator<OutputIt> it) {
+        for (unsigned p = prefix & 0xffffff; p != 0; p >>= 8)
+          *it++ = static_cast<Char>(p & 0xff);
+        it = detail::fill_n(it, data.padding, static_cast<Char>('0'));
+        return write_digits(it);
+      });
+}
+
+template <typename Char> class digit_grouping {
+ private:
+  std::string grouping_;
+  std::basic_string<Char> thousands_sep_;
+
+  struct next_state {
+    std::string::const_iterator group;
+    int pos;
+  };
+  next_state initial_state() const { return {grouping_.begin(), 0}; }
+
+  // Returns the next digit group separator position.
+  int next(next_state& state) const {
+    if (thousands_sep_.empty()) return max_value<int>();
+    if (state.group == grouping_.end()) return state.pos += grouping_.back();
+    if (*state.group <= 0 || *state.group == max_value<char>())
+      return max_value<int>();
+    state.pos += *state.group++;
+    return state.pos;
+  }
+
+ public:
+  explicit digit_grouping(locale_ref loc, bool localized = true) {
+    if (!localized) return;
+    auto sep = thousands_sep<Char>(loc);
+    grouping_ = sep.grouping;
+    if (sep.thousands_sep) thousands_sep_.assign(1, sep.thousands_sep);
+  }
+  digit_grouping(std::string grouping, std::basic_string<Char> sep)
+      : grouping_(std::move(grouping)), thousands_sep_(std::move(sep)) {}
+
+  bool has_separator() const { return !thousands_sep_.empty(); }
+
+  int count_separators(int num_digits) const {
+    int count = 0;
+    auto state = initial_state();
+    while (num_digits > next(state)) ++count;
+    return count;
+  }
+
+  // Applies grouping to digits and write the output to out.
+  template <typename Out, typename C>
+  Out apply(Out out, basic_string_view<C> digits) const {
+    auto num_digits = static_cast<int>(digits.size());
+    auto separators = basic_memory_buffer<int>();
+    separators.push_back(0);
+    auto state = initial_state();
+    while (int i = next(state)) {
+      if (i >= num_digits) break;
+      separators.push_back(i);
+    }
+    for (int i = 0, sep_index = static_cast<int>(separators.size() - 1);
+         i < num_digits; ++i) {
+      if (num_digits - i == separators[sep_index]) {
+        out = copy_str<Char>(thousands_sep_.data(),
+                             thousands_sep_.data() + thousands_sep_.size(),
+                             out);
+        --sep_index;
+      }
+      *out++ = static_cast<Char>(digits[to_unsigned(i)]);
+    }
+    return out;
+  }
+};
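+
+// NOTE (editorial sketch, not part of upstream {fmt}): digit_grouping follows
+// the POSIX grouping convention -- each byte of the grouping string is a group
+// size counted from the least significant digit, with the last size repeating.
+// For example:
+//
+//   auto g = fmt::detail::digit_grouping<char>("\3", ",");
+//   std::string out;
+//   g.apply(std::back_inserter(out), fmt::string_view("1234567"));
+//   // out == "1,234,567"; a grouping of "\3\2" would yield "12,34,567".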
+
+// Writes a decimal integer with digit grouping.
+template <typename OutputIt, typename UInt, typename Char>
+auto write_int(OutputIt out, UInt value, unsigned prefix,
+               const format_specs<Char>& specs,
+               const digit_grouping<Char>& grouping) -> OutputIt {
+  static_assert(std::is_same<uint64_or_128_t<UInt>, UInt>::value, "");
+  int num_digits = count_digits(value);
+  char digits[40];
+  format_decimal(digits, value, num_digits);
+  unsigned size = to_unsigned((prefix != 0 ? 1 : 0) + num_digits +
+                              grouping.count_separators(num_digits));
+  return write_padded<align::right>(
+      out, specs, size, size, [&](reserve_iterator<OutputIt> it) {
+        if (prefix != 0) {
+          char sign = static_cast<char>(prefix);
+          *it++ = static_cast<Char>(sign);
+        }
+        return grouping.apply(it, string_view(digits, to_unsigned(num_digits)));
+      });
+}
+
+// Writes a localized value.
+FMT_API auto write_loc(appender out, loc_value value,
+                       const format_specs<>& specs, locale_ref loc) -> bool;
+template <typename OutputIt, typename Char>
+inline auto write_loc(OutputIt, loc_value, const format_specs<Char>&,
+                      locale_ref) -> bool {
+  return false;
+}
+
+FMT_CONSTEXPR inline void prefix_append(unsigned& prefix, unsigned value) {
+  prefix |= prefix != 0 ? value << 8 : value;
+  prefix += (1u + (value > 0xff ? 1 : 0)) << 24;
+}
+
+template <typename UInt> struct write_int_arg {
+  UInt abs_value;
+  unsigned prefix;
+};
+
+template <typename T>
+FMT_CONSTEXPR auto make_write_int_arg(T value, sign_t sign)
+    -> write_int_arg<uint32_or_64_or_128_t<T>> {
+  auto prefix = 0u;
+  auto abs_value = static_cast<uint32_or_64_or_128_t<T>>(value);
+  if (is_negative(value)) {
+    prefix = 0x01000000 | '-';
+    abs_value = 0 - abs_value;
+  } else {
+    constexpr const unsigned prefixes[4] = {0, 0, 0x1000000u | '+',
+                                            0x1000000u | ' '};
+    prefix = prefixes[sign];
+  }
+  return {abs_value, prefix};
+}
+
+template <typename Char = char> struct loc_writer {
+  buffer_appender<Char> out;
+  const format_specs<Char>& specs;
+  std::basic_string<Char> sep;
+  std::string grouping;
+  std::basic_string<Char> decimal_point;
+
+  template <typename T, FMT_ENABLE_IF(is_integer<T>::value)>
+  auto operator()(T value) -> bool {
+    auto arg = make_write_int_arg(value, specs.sign);
+    write_int(out, static_cast<uint64_or_128_t<T>>(arg.abs_value), arg.prefix,
+              specs, digit_grouping<Char>(grouping, sep));
+    return true;
+  }
+
+  template <typename T, FMT_ENABLE_IF(!is_integer<T>::value)>
+  auto operator()(T) -> bool {
+    return false;
+  }
+};
+
+template <typename Char, typename OutputIt, typename T>
+FMT_CONSTEXPR FMT_INLINE auto write_int(OutputIt out, write_int_arg<T> arg,
+                                        const format_specs<Char>& specs,
+                                        locale_ref) -> OutputIt {
+  static_assert(std::is_same<T, uint32_or_64_or_128_t<T>>::value, "");
+  auto abs_value = arg.abs_value;
+  auto prefix = arg.prefix;
+  switch (specs.type) {
+  case presentation_type::none:
+  case presentation_type::dec: {
+    auto num_digits = count_digits(abs_value);
+    return write_int<Char>(
+        out, num_digits, prefix, specs, [=](reserve_iterator<OutputIt> it) {
+          return format_decimal<Char>(it, abs_value, num_digits).end;
+        });
+  }
+  case presentation_type::hex_lower:
+  case presentation_type::hex_upper: {
+    bool upper = specs.type == presentation_type::hex_upper;
+    if (specs.alt)
+      prefix_append(prefix, unsigned(upper ? 'X' : 'x') << 8 | '0');
+    int num_digits = count_digits<4>(abs_value);
+    return write_int<Char>(
+        out, num_digits, prefix, specs, [=](reserve_iterator<OutputIt> it) {
+          return format_uint<4, Char>(it, abs_value, num_digits, upper);
+        });
+  }
+  case presentation_type::bin_lower:
+  case presentation_type::bin_upper: {
+    bool upper = specs.type == presentation_type::bin_upper;
+    if (specs.alt)
+      prefix_append(prefix, unsigned(upper ? 'B' : 'b') << 8 | '0');
+    int num_digits = count_digits<1>(abs_value);
+    return write_int<Char>(out, num_digits, prefix, specs,
+                           [=](reserve_iterator<OutputIt> it) {
+                             return format_uint<1, Char>(it, abs_value,
+                                                         num_digits);
+                           });
+  }
+  case presentation_type::oct: {
+    int num_digits = count_digits<3>(abs_value);
+    // Octal prefix '0' is counted as a digit, so only add it if precision
+    // is not greater than the number of digits.
+    if (specs.alt && specs.precision <= num_digits && abs_value != 0)
+      prefix_append(prefix, '0');
+    return write_int<Char>(out, num_digits, prefix, specs,
+                           [=](reserve_iterator<OutputIt> it) {
+                             return format_uint<3, Char>(it, abs_value,
+                                                         num_digits);
+                           });
+  }
+  case presentation_type::chr:
+    return write_char(out, static_cast<Char>(abs_value), specs);
+  default:
+    throw_format_error("invalid format specifier");
+  }
+  return out;
+}
+template <typename Char, typename OutputIt, typename T>
+FMT_CONSTEXPR FMT_NOINLINE auto write_int_noinline(
+    OutputIt out, write_int_arg<T> arg, const format_specs<Char>& specs,
+    locale_ref loc) -> OutputIt {
+  return write_int(out, arg, specs, loc);
+}
+template <typename Char, typename OutputIt, typename T,
+          FMT_ENABLE_IF(is_integral<T>::value &&
+                        !std::is_same<T, bool>::value &&
+                        std::is_same<OutputIt, buffer_appender<Char>>::value)>
+FMT_CONSTEXPR FMT_INLINE auto write(OutputIt out, T value,
+                                    const format_specs<Char>& specs,
+                                    locale_ref loc) -> OutputIt {
+  if (specs.localized && write_loc(out, value, specs, loc)) return out;
+  return write_int_noinline(out, make_write_int_arg(value, specs.sign), specs,
+                            loc);
+}
+// An inlined version of write used in format string compilation.
+template <typename Char, typename OutputIt, typename T,
+          FMT_ENABLE_IF(is_integral<T>::value &&
+                        !std::is_same<T, bool>::value &&
+                        !std::is_same<OutputIt, buffer_appender<Char>>::value)>
+FMT_CONSTEXPR FMT_INLINE auto write(OutputIt out, T value,
+                                    const format_specs<Char>& specs,
+                                    locale_ref loc) -> OutputIt {
+  if (specs.localized && write_loc(out, value, specs, loc)) return out;
+  return write_int(out, make_write_int_arg(value, specs.sign), specs, loc);
+}
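+
+// NOTE (editorial sketch, not part of upstream {fmt}): write_int receives the
+// sign and base prefixes packed into one unsigned -- up to three characters in
+// the low bytes (emitted low byte first) and the character count in the top
+// byte. A worked example for "-0x" (negative, alternate hex form):
+//
+//   unsigned prefix = 0x01000000 | '-';               // "-", length 1
+//   prefix_append(prefix, unsigned('x') << 8 | '0');  // "-0x", length 3
+//   // prefix == 0x0378302d: bytes '-', '0', 'x' plus 3 in the top byte.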
+
+// An output iterator that counts the number of objects written to it and
+// discards them.
+class counting_iterator {
+ private:
+  size_t count_;
+
+ public:
+  using iterator_category = std::output_iterator_tag;
+  using difference_type = std::ptrdiff_t;
+  using pointer = void;
+  using reference = void;
+  FMT_UNCHECKED_ITERATOR(counting_iterator);
+
+  struct value_type {
+    template <typename T> FMT_CONSTEXPR void operator=(const T&) {}
+  };
+
+  FMT_CONSTEXPR counting_iterator() : count_(0) {}
+
+  FMT_CONSTEXPR size_t count() const { return count_; }
+
+  FMT_CONSTEXPR counting_iterator& operator++() {
+    ++count_;
+    return *this;
+  }
+  FMT_CONSTEXPR counting_iterator operator++(int) {
+    auto it = *this;
+    ++*this;
+    return it;
+  }
+
+  FMT_CONSTEXPR friend counting_iterator operator+(counting_iterator it,
+                                                   difference_type n) {
+    it.count_ += static_cast<size_t>(n);
+    return it;
+  }
+
+  FMT_CONSTEXPR value_type operator*() const { return {}; }
+};
+
+template <typename Char, typename OutputIt>
+FMT_CONSTEXPR auto write(OutputIt out, basic_string_view<Char> s,
+                         const format_specs<Char>& specs) -> OutputIt {
+  auto data = s.data();
+  auto size = s.size();
+  if (specs.precision >= 0 && to_unsigned(specs.precision) < size)
+    size = code_point_index(s, to_unsigned(specs.precision));
+  bool is_debug = specs.type == presentation_type::debug;
+  size_t width = 0;
+  if (specs.width != 0) {
+    if (is_debug)
+      width = write_escaped_string(counting_iterator{}, s).count();
+    else
+      width = compute_width(basic_string_view<Char>(data, size));
+  }
+  return write_padded(out, specs, size, width,
+                      [=](reserve_iterator<OutputIt> it) {
+                        if (is_debug) return write_escaped_string(it, s);
+                        return copy_str<Char>(data, data + size, it);
+                      });
+}
+template <typename Char, typename OutputIt>
+FMT_CONSTEXPR auto write(OutputIt out,
+                         basic_string_view<type_identity_t<Char>> s,
+                         const format_specs<Char>& specs, locale_ref)
+    -> OutputIt {
+  return write(out, s, specs);
+}
+template <typename Char, typename OutputIt>
+FMT_CONSTEXPR auto write(OutputIt out, const Char* s,
+                         const format_specs<Char>& specs, locale_ref)
+    -> OutputIt {
+  return specs.type != presentation_type::pointer
+             ? write(out, basic_string_view<Char>(s), specs, {})
+             : write_ptr<Char>(out, bit_cast<uintptr_t>(s), &specs);
+}
+
+template <typename Char, typename OutputIt, typename T,
+          FMT_ENABLE_IF(is_integral<T>::value &&
+                        !std::is_same<T, bool>::value &&
+                        !std::is_same<T, Char>::value)>
+FMT_CONSTEXPR auto write(OutputIt out, T value) -> OutputIt {
+  auto abs_value = static_cast<uint32_or_64_or_128_t<T>>(value);
+  bool negative = is_negative(value);
+  // Don't do -abs_value since it trips unsigned-integer-overflow sanitizer.
+  if (negative) abs_value = ~abs_value + 1;
+  int num_digits = count_digits(abs_value);
+  auto size = (negative ? 1 : 0) + static_cast<size_t>(num_digits);
+  auto it = reserve(out, size);
+  if (auto ptr = to_pointer<Char>(it, size)) {
+    if (negative) *ptr++ = static_cast<Char>('-');
+    format_decimal(ptr, abs_value, num_digits);
+    return out;
+  }
+  if (negative) *it++ = static_cast<Char>('-');
+  it = format_decimal<Char>(it, abs_value, num_digits).end;
+  return base_iterator(out, it);
+}
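+
+// NOTE (editorial sketch, not part of upstream {fmt}): in the write() overload
+// above, ~abs_value + 1 is the two's-complement negation computed entirely in
+// unsigned arithmetic. Working modulo 2^8 for illustration:
+//
+//   ~5 + 1 == 0xFA + 1 == 0xFB == 251 == 256 - 5
+//
+// which is the magnitude of -5 without the "0 - x" wraparound that an
+// unsigned-integer-overflow sanitizer would flag.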
+
+// DEPRECATED!
+template <typename Char>
+FMT_CONSTEXPR auto parse_align(const Char* begin, const Char* end,
+                               format_specs<Char>& specs) -> const Char* {
+  FMT_ASSERT(begin != end, "");
+  auto align = align::none;
+  auto p = begin + code_point_length(begin);
+  if (end - p <= 0) p = begin;
+  for (;;) {
+    switch (to_ascii(*p)) {
+    case '<':
+      align = align::left;
+      break;
+    case '>':
+      align = align::right;
+      break;
+    case '^':
+      align = align::center;
+      break;
+    }
+    if (align != align::none) {
+      if (p != begin) {
+        auto c = *begin;
+        if (c == '}') return begin;
+        if (c == '{') {
+          throw_format_error("invalid fill character '{'");
+          return begin;
+        }
+        specs.fill = {begin, to_unsigned(p - begin)};
+        begin = p + 1;
+      } else {
+        ++begin;
+      }
+      break;
+    } else if (p == begin) {
+      break;
+    }
+    p = begin;
+  }
+  specs.align = align;
+  return begin;
+}
+
+// A floating-point presentation format.
+enum class float_format : unsigned char {
+  general,  // General: exponent notation or fixed point based on magnitude.
+  exp,      // Exponent notation with the default precision of 6, e.g. 1.2e-3.
+  fixed,    // Fixed point with the default precision of 6, e.g. 0.0012.
+  hex
+};
+
+struct float_specs {
+  int precision;
+  float_format format : 8;
+  sign_t sign : 8;
+  bool upper : 1;
+  bool locale : 1;
+  bool binary32 : 1;
+  bool showpoint : 1;
+};
+
+template <typename Char, typename ErrorHandler = error_handler>
+FMT_CONSTEXPR auto parse_float_type_spec(const format_specs<Char>& specs,
+                                         ErrorHandler&& eh = {})
+    -> float_specs {
+  auto result = float_specs();
+  result.showpoint = specs.alt;
+  result.locale = specs.localized;
+  switch (specs.type) {
+  case presentation_type::none:
+    result.format = float_format::general;
+    break;
+  case presentation_type::general_upper:
+    result.upper = true;
+    FMT_FALLTHROUGH;
+  case presentation_type::general_lower:
+    result.format = float_format::general;
+    break;
+  case presentation_type::exp_upper:
+    result.upper = true;
+    FMT_FALLTHROUGH;
+  case presentation_type::exp_lower:
+    result.format = float_format::exp;
+    result.showpoint |= specs.precision != 0;
+    break;
+  case presentation_type::fixed_upper:
+    result.upper = true;
+    FMT_FALLTHROUGH;
+  case presentation_type::fixed_lower:
+    result.format = float_format::fixed;
+    result.showpoint |= specs.precision != 0;
+    break;
+  case presentation_type::hexfloat_upper:
+    result.upper = true;
+    FMT_FALLTHROUGH;
+  case presentation_type::hexfloat_lower:
+    result.format = float_format::hex;
+    break;
+  default:
+    eh.on_error("invalid format specifier");
+    break;
+  }
+  return result;
+}
+
+template <typename Char, typename OutputIt>
+FMT_CONSTEXPR20 auto write_nonfinite(OutputIt out, bool isnan,
+                                     format_specs<Char> specs,
+                                     const float_specs& fspecs) -> OutputIt {
+  auto str =
+      isnan ? (fspecs.upper ? "NAN" : "nan") : (fspecs.upper ? "INF" : "inf");
+  constexpr size_t str_size = 3;
+  auto sign = fspecs.sign;
+  auto size = str_size + (sign ? 1 : 0);
+  // Replace '0'-padding with space for non-finite values.
+  const bool is_zero_fill =
+      specs.fill.size() == 1 && *specs.fill.data() == static_cast<Char>('0');
+  if (is_zero_fill) specs.fill[0] = static_cast<Char>(' ');
+  return write_padded(out, specs, size, [=](reserve_iterator<OutputIt> it) {
+    if (sign) *it++ = detail::sign<Char>(sign);
+    return copy_str<Char>(str, str + str_size, it);
+  });
+}
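+
+// NOTE (editorial sketch, not part of upstream {fmt}): parse_float_type_spec
+// above reduces the presentation type to one of the four float_format values:
+//
+//   "{:e}" / "{:E}"  -> float_format::exp   (uppercase also sets fspecs.upper)
+//   "{:f}" / "{:F}"  -> float_format::fixed
+//   "{:a}" / "{:A}"  -> float_format::hex
+//   "{}", "{:g}"     -> float_format::general
+//
+// write_nonfinite then renders nan/inf honoring fspecs.upper and the sign,
+// but replaces '0' fill with spaces so "007nan"-style output cannot occur.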
+
+// A decimal floating-point number significand * pow(10, exp).
+struct big_decimal_fp {
+  const char* significand;
+  int significand_size;
+  int exponent;
+};
+
+constexpr auto get_significand_size(const big_decimal_fp& f) -> int {
+  return f.significand_size;
+}
+template <typename T>
+inline auto get_significand_size(const dragonbox::decimal_fp<T>& f) -> int {
+  return count_digits(f.significand);
+}
+
+template <typename Char, typename OutputIt>
+constexpr auto write_significand(OutputIt out, const char* significand,
+                                 int significand_size) -> OutputIt {
+  return copy_str<Char>(significand, significand + significand_size, out);
+}
+template <typename Char, typename OutputIt, typename UInt>
+inline auto write_significand(OutputIt out, UInt significand,
+                              int significand_size) -> OutputIt {
+  return format_decimal<Char>(out, significand, significand_size).end;
+}
+template <typename Char, typename OutputIt, typename T, typename Grouping>
+FMT_CONSTEXPR20 auto write_significand(OutputIt out, T significand,
+                                       int significand_size, int exponent,
+                                       const Grouping& grouping) -> OutputIt {
+  if (!grouping.has_separator()) {
+    out = write_significand<Char>(out, significand, significand_size);
+    return detail::fill_n(out, exponent, static_cast<Char>('0'));
+  }
+  auto buffer = memory_buffer();
+  write_significand<char>(appender(buffer), significand, significand_size);
+  detail::fill_n(appender(buffer), exponent, '0');
+  return grouping.apply(out, string_view(buffer.data(), buffer.size()));
+}
+
+template <typename Char, typename UInt,
+          FMT_ENABLE_IF(std::is_integral<UInt>::value)>
+inline auto write_significand(Char* out, UInt significand, int significand_size,
+                              int integral_size, Char decimal_point) -> Char* {
+  if (!decimal_point)
+    return format_decimal(out, significand, significand_size).end;
+  out += significand_size + 1;
+  Char* end = out;
+  int floating_size = significand_size - integral_size;
+  for (int i = floating_size / 2; i > 0; --i) {
+    out -= 2;
+    copy2(out, digits2(static_cast<size_t>(significand % 100)));
+    significand /= 100;
+  }
+  if (floating_size % 2 != 0) {
+    *--out = static_cast<Char>('0' + significand % 10);
+    significand /= 10;
+  }
+  *--out = decimal_point;
+  format_decimal(out - integral_size, significand, integral_size);
+  return end;
+}
+
+template <typename OutputIt, typename UInt, typename Char,
+          FMT_ENABLE_IF(!std::is_pointer<remove_reference_t<OutputIt>>::value)>
+inline auto write_significand(OutputIt out, UInt significand,
+                              int significand_size, int integral_size,
+                              Char decimal_point) -> OutputIt {
+  // Buffer is large enough to hold digits (digits10 + 1) and a decimal point.
+  Char buffer[digits10<UInt>() + 2];
+  auto end = write_significand(buffer, significand, significand_size,
+                               integral_size, decimal_point);
+  return detail::copy_str_noinline<Char>(buffer, end, out);
+}
+
+template <typename OutputIt, typename Char>
+FMT_CONSTEXPR auto write_significand(OutputIt out, const char* significand,
+                                     int significand_size, int integral_size,
+                                     Char decimal_point) -> OutputIt {
+  out = detail::copy_str_noinline<Char>(significand,
+                                        significand + integral_size, out);
+  if (!decimal_point) return out;
+  *out++ = decimal_point;
+  return detail::copy_str_noinline<Char>(significand + integral_size,
+                                         significand + significand_size, out);
+}
+
+template <typename OutputIt, typename Char, typename T, typename Grouping>
+FMT_CONSTEXPR20 auto write_significand(OutputIt out, T significand,
+                                       int significand_size, int integral_size,
+                                       Char decimal_point,
+                                       const Grouping& grouping) -> OutputIt {
+  if (!grouping.has_separator()) {
+    return write_significand(out, significand, significand_size, integral_size,
+                             decimal_point);
+  }
+  auto buffer = basic_memory_buffer<Char>();
+  write_significand(buffer_appender<Char>(buffer), significand,
+                    significand_size, integral_size, decimal_point);
+  grouping.apply(
+      out, basic_string_view<Char>(buffer.data(), to_unsigned(integral_size)));
+  return detail::copy_str_noinline<Char>(buffer.data() + integral_size,
+                                         buffer.end(), out);
+}
+
+template <typename OutputIt, typename DecimalFP, typename Char,
+          typename Grouping = digit_grouping<Char>>
+FMT_CONSTEXPR20 auto do_write_float(OutputIt out, const DecimalFP& f,
+                                    const format_specs<Char>& specs,
+                                    float_specs fspecs, locale_ref loc)
+    -> OutputIt {
+  auto significand = f.significand;
+  int significand_size = get_significand_size(f);
+  const Char zero = static_cast<Char>('0');
+  auto sign = fspecs.sign;
+  size_t size = to_unsigned(significand_size) + (sign ? 1 : 0);
+  using iterator = reserve_iterator<OutputIt>;
+
+  Char decimal_point =
+      fspecs.locale ? detail::decimal_point<Char>(loc) : static_cast<Char>('.');
+
+  int output_exp = f.exponent + significand_size - 1;
+  auto use_exp_format = [=]() {
+    if (fspecs.format == float_format::exp) return true;
+    if (fspecs.format != float_format::general) return false;
+    // Use the fixed notation if the exponent is in [exp_lower, exp_upper),
+    // e.g. 0.0001 instead of 1e-04. Otherwise use the exponent notation.
+    const int exp_lower = -4, exp_upper = 16;
+    return output_exp < exp_lower ||
+           output_exp >= (fspecs.precision > 0 ? fspecs.precision : exp_upper);
+  };
+  if (use_exp_format()) {
+    int num_zeros = 0;
+    if (fspecs.showpoint) {
+      num_zeros = fspecs.precision - significand_size;
+      if (num_zeros < 0) num_zeros = 0;
+      size += to_unsigned(num_zeros);
+    } else if (significand_size == 1) {
+      decimal_point = Char();
+    }
+    auto abs_output_exp = output_exp >= 0 ? output_exp : -output_exp;
+    int exp_digits = 2;
+    if (abs_output_exp >= 100) exp_digits = abs_output_exp >= 1000 ? 4 : 3;
+
+    size += to_unsigned((decimal_point ? 1 : 0) + 2 + exp_digits);
+    char exp_char = fspecs.upper ? 'E' : 'e';
+    auto write = [=](iterator it) {
+      if (sign) *it++ = detail::sign<Char>(sign);
+      // Insert a decimal point after the first digit and add an exponent.
+      it = write_significand(it, significand, significand_size, 1,
+                             decimal_point);
+      if (num_zeros > 0) it = detail::fill_n(it, num_zeros, zero);
+      *it++ = static_cast<Char>(exp_char);
+      return write_exponent<Char>(output_exp, it);
+    };
+    return specs.width > 0
+               ? write_padded<align::right>(out, specs, size, write)
+               : base_iterator(out, write(reserve(out, size)));
+  }
+
+  int exp = f.exponent + significand_size;
+  if (f.exponent >= 0) {
+    // 1234e5 -> 123400000[.0+]
+    size += to_unsigned(f.exponent);
+    int num_zeros = fspecs.precision - exp;
+    abort_fuzzing_if(num_zeros > 5000);
+    if (fspecs.showpoint) {
+      ++size;
+      if (num_zeros <= 0 && fspecs.format != float_format::fixed) num_zeros = 0;
+      if (num_zeros > 0) size += to_unsigned(num_zeros);
+    }
+    auto grouping = Grouping(loc, fspecs.locale);
+    size += to_unsigned(grouping.count_separators(exp));
+    return write_padded<align::right>(out, specs, size, [&](iterator it) {
+      if (sign) *it++ = detail::sign<Char>(sign);
+      it = write_significand<Char>(it, significand, significand_size,
+                                   f.exponent, grouping);
+      if (!fspecs.showpoint) return it;
+      *it++ = decimal_point;
+      return num_zeros > 0 ? detail::fill_n(it, num_zeros, zero) : it;
+    });
+  } else if (exp > 0) {
+    // 1234e-2 -> 12.34[0+]
+    int num_zeros = fspecs.showpoint ? fspecs.precision - significand_size : 0;
+    size += 1 + to_unsigned(num_zeros > 0 ? num_zeros : 0);
+    auto grouping = Grouping(loc, fspecs.locale);
+    size += to_unsigned(grouping.count_separators(exp));
+    return write_padded<align::right>(out, specs, size, [&](iterator it) {
+      if (sign) *it++ = detail::sign<Char>(sign);
+      it = write_significand(it, significand, significand_size, exp,
+                             decimal_point, grouping);
+      return num_zeros > 0 ? detail::fill_n(it, num_zeros, zero) : it;
+    });
+  }
+  // 1234e-6 -> 0.001234
+  int num_zeros = -exp;
+  if (significand_size == 0 && fspecs.precision >= 0 &&
+      fspecs.precision < num_zeros) {
+    num_zeros = fspecs.precision;
+  }
+  bool pointy = num_zeros != 0 || significand_size != 0 || fspecs.showpoint;
+  size += 1 + (pointy ? 1 : 0) + to_unsigned(num_zeros);
+  return write_padded<align::right>(out, specs, size, [&](iterator it) {
+    if (sign) *it++ = detail::sign<Char>(sign);
+    *it++ = zero;
+    if (!pointy) return it;
+    *it++ = decimal_point;
+    it = detail::fill_n(it, num_zeros, zero);
+    return write_significand<Char>(it, significand, significand_size);
+  });
+}
+
+template <typename Char> class fallback_digit_grouping {
+ public:
+  constexpr fallback_digit_grouping(locale_ref, bool) {}
+
+  constexpr bool has_separator() const { return false; }
+
+  constexpr int count_separators(int) const { return 0; }
+
+  template <typename Out, typename C>
+  constexpr Out apply(Out out, basic_string_view<C>) const {
+    return out;
+  }
+};
+
+template <typename OutputIt, typename DecimalFP, typename Char>
+FMT_CONSTEXPR20 auto write_float(OutputIt out, const DecimalFP& f,
+                                 const format_specs<Char>& specs,
+                                 float_specs fspecs, locale_ref loc)
+    -> OutputIt {
+  if (is_constant_evaluated()) {
+    return do_write_float<OutputIt, DecimalFP, Char,
+                          fallback_digit_grouping<Char>>(out, f, specs, fspecs,
+                                                         loc);
+  } else {
+    return do_write_float(out, f, specs, fspecs, loc);
+  }
+}
+
+template <typename T> constexpr bool isnan(T value) {
+  return !(value >= value);  // std::isnan doesn't support __float128.
+}
+
+template <typename T, typename Enable = void>
+struct has_isfinite : std::false_type {};
+
+template <typename T>
+struct has_isfinite<T, enable_if_t<sizeof(std::isfinite(T())) != 0>>
+    : std::true_type {};
+
+template <typename T,
+          FMT_ENABLE_IF(std::is_floating_point<T>::value&&
+                            has_isfinite<T>::value)>
+FMT_CONSTEXPR20 bool isfinite(T value) {
+  constexpr T inf = T(std::numeric_limits<double>::infinity());
+  if (is_constant_evaluated())
+    return !detail::isnan(value) && value < inf && value > -inf;
+  return std::isfinite(value);
+}
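+
+// NOTE (editorial, not part of upstream {fmt}): detail::isnan above exploits
+// IEEE 754 ordering -- a NaN compares unordered against everything, including
+// itself, so value >= value is false exactly for NaNs:
+//
+//   detail::isnan(1.0);            // false: 1.0 >= 1.0 holds
+//   detail::isnan(std::nan(""));   // true: NaN >= NaN is false
+//
+// This keeps the check usable for types like __float128 that std::isnan may
+// not accept, which is why the isfinite overloads here rely on it.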
+template <typename T, FMT_ENABLE_IF(!has_isfinite<T>::value)>
+FMT_CONSTEXPR bool isfinite(T value) {
+  T inf = T(std::numeric_limits<double>::infinity());
+  // std::isfinite doesn't support __float128.
+  return !detail::isnan(value) && value < inf && value > -inf;
+}
+
+template <typename T, FMT_ENABLE_IF(is_floating_point<T>::value)>
+FMT_INLINE FMT_CONSTEXPR bool signbit(T value) {
+  if (is_constant_evaluated()) {
+#ifdef __cpp_if_constexpr
+    if constexpr (std::numeric_limits<double>::is_iec559) {
+      auto bits = detail::bit_cast<uint64_t>(static_cast<double>(value));
+      return (bits >> (num_bits<uint64_t>() - 1)) != 0;
+    }
+#endif
+  }
+  return std::signbit(static_cast<double>(value));
+}
+
+inline FMT_CONSTEXPR20 void adjust_precision(int& precision, int exp10) {
+  // Adjust fixed precision by exponent because it is relative to decimal
+  // point.
+  if (exp10 > 0 && precision > max_value<int>() - exp10)
+    FMT_THROW(format_error("number is too big"));
+  precision += exp10;
+}
+
+class bigint {
+ private:
+  // A bigint is stored as an array of bigits (big digits), with bigit at index
+  // 0 being the least significant one.
+  using bigit = uint32_t;
+  using double_bigit = uint64_t;
+  enum { bigits_capacity = 32 };
+  basic_memory_buffer<bigit, bigits_capacity> bigits_;
+  int exp_;
+
+  FMT_CONSTEXPR20 bigit operator[](int index) const {
+    return bigits_[to_unsigned(index)];
+  }
+  FMT_CONSTEXPR20 bigit& operator[](int index) {
+    return bigits_[to_unsigned(index)];
+  }
+
+  static constexpr const int bigit_bits = num_bits<bigit>();
+
+  friend struct formatter<bigint>;
+
+  FMT_CONSTEXPR20 void subtract_bigits(int index, bigit other, bigit& borrow) {
+    auto result = static_cast<double_bigit>((*this)[index]) - other - borrow;
+    (*this)[index] = static_cast<bigit>(result);
+    borrow = static_cast<bigit>(result >> (bigit_bits * 2 - 1));
+  }
+
+  FMT_CONSTEXPR20 void remove_leading_zeros() {
+    int num_bigits = static_cast<int>(bigits_.size()) - 1;
+    while (num_bigits > 0 && (*this)[num_bigits] == 0) --num_bigits;
+    bigits_.resize(to_unsigned(num_bigits + 1));
+  }
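+
+  // NOTE (editorial, not part of upstream {fmt}): a bigint denotes
+  //   (sum over i of bigits_[i] * 2^(32*i)) * 2^(32*exp_),
+  // i.e. exp_ counts implicit least-significant zero bigits, so whole-bigit
+  // shifts are free:
+  //
+  //   bigint b(1);
+  //   b <<= 64;  // bigits_ == {1}, exp_ == 2; no digits are moved.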
+
+  // Computes *this -= other assuming aligned bigints and *this >= other.
+  FMT_CONSTEXPR20 void subtract_aligned(const bigint& other) {
+    FMT_ASSERT(other.exp_ >= exp_, "unaligned bigints");
+    FMT_ASSERT(compare(*this, other) >= 0, "");
+    bigit borrow = 0;
+    int i = other.exp_ - exp_;
+    for (size_t j = 0, n = other.bigits_.size(); j != n; ++i, ++j)
+      subtract_bigits(i, other.bigits_[j], borrow);
+    while (borrow > 0) subtract_bigits(i, 0, borrow);
+    remove_leading_zeros();
+  }
+
+  FMT_CONSTEXPR20 void multiply(uint32_t value) {
+    const double_bigit wide_value = value;
+    bigit carry = 0;
+    for (size_t i = 0, n = bigits_.size(); i < n; ++i) {
+      double_bigit result = bigits_[i] * wide_value + carry;
+      bigits_[i] = static_cast<bigit>(result);
+      carry = static_cast<bigit>(result >> bigit_bits);
+    }
+    if (carry != 0) bigits_.push_back(carry);
+  }
+
+  template <typename UInt, FMT_ENABLE_IF(std::is_same<UInt, uint64_t>::value ||
+                                         std::is_same<UInt, uint128_t>::value)>
+  FMT_CONSTEXPR20 void multiply(UInt value) {
+    using half_uint =
+        conditional_t<std::is_same<UInt, uint128_t>::value, uint64_t, uint32_t>;
+    const int shift = num_bits<half_uint>() - bigit_bits;
+    const UInt lower = static_cast<half_uint>(value);
+    const UInt upper = value >> num_bits<half_uint>();
+    UInt carry = 0;
+    for (size_t i = 0, n = bigits_.size(); i < n; ++i) {
+      UInt result = lower * bigits_[i] + static_cast<bigit>(carry);
+      carry = (upper * bigits_[i] << shift) + (result >> bigit_bits) +
+              (carry >> bigit_bits);
+      bigits_[i] = static_cast<bigit>(result);
+    }
+    while (carry != 0) {
+      bigits_.push_back(static_cast<bigit>(carry));
+      carry >>= bigit_bits;
+    }
+  }
+
+  template <typename UInt, FMT_ENABLE_IF(std::is_same<UInt, uint64_t>::value ||
+                                         std::is_same<UInt, uint128_t>::value)>
+  FMT_CONSTEXPR20 void assign(UInt n) {
+    size_t num_bigits = 0;
+    do {
+      bigits_[num_bigits++] = static_cast<bigit>(n);
+      n >>= bigit_bits;
+    } while (n != 0);
+    bigits_.resize(num_bigits);
+    exp_ = 0;
+  }
+
+ public:
+  FMT_CONSTEXPR20 bigint() : exp_(0) {}
+  explicit bigint(uint64_t n) { assign(n); }
+
+  bigint(const bigint&) = delete;
+  void operator=(const bigint&) = delete;
+
+  FMT_CONSTEXPR20 void assign(const bigint& other) {
+    auto size = other.bigits_.size();
+    bigits_.resize(size);
+    auto data = other.bigits_.data();
+    copy_str<bigit>(data, data + size, bigits_.data());
+    exp_ = other.exp_;
+  }
+
+  template <typename Int> FMT_CONSTEXPR20 void operator=(Int n) {
+    FMT_ASSERT(n > 0, "");
+    assign(uint64_or_128_t<Int>(n));
+  }
+
+  FMT_CONSTEXPR20 int num_bigits() const {
+    return static_cast<int>(bigits_.size()) + exp_;
+  }
+
+  FMT_NOINLINE FMT_CONSTEXPR20 bigint& operator<<=(int shift) {
+    FMT_ASSERT(shift >= 0, "");
+    exp_ += shift / bigit_bits;
+    shift %= bigit_bits;
+    if (shift == 0) return *this;
+    bigit carry = 0;
+    for (size_t i = 0, n = bigits_.size(); i < n; ++i) {
+      bigit c = bigits_[i] >> (bigit_bits - shift);
+      bigits_[i] = (bigits_[i] << shift) + carry;
+      carry = c;
+    }
+    if (carry != 0) bigits_.push_back(carry);
+    return *this;
+  }
+
+  template <typename Int> FMT_CONSTEXPR20 bigint& operator*=(Int value) {
+    FMT_ASSERT(value > 0, "");
+    multiply(uint32_or_64_or_128_t<Int>(value));
+    return *this;
+  }
+
+  friend FMT_CONSTEXPR20 int compare(const bigint& lhs, const bigint& rhs) {
+    int num_lhs_bigits = lhs.num_bigits(), num_rhs_bigits = rhs.num_bigits();
+    if (num_lhs_bigits != num_rhs_bigits)
+      return num_lhs_bigits > num_rhs_bigits ? 1 : -1;
+    int i = static_cast<int>(lhs.bigits_.size()) - 1;
+    int j = static_cast<int>(rhs.bigits_.size()) - 1;
+    int end = i - j;
+    if (end < 0) end = 0;
+    for (; i >= end; --i, --j) {
+      bigit lhs_bigit = lhs[i], rhs_bigit = rhs[j];
+      if (lhs_bigit != rhs_bigit) return lhs_bigit > rhs_bigit ? 1 : -1;
+    }
+    if (i != j) return i > j ? 1 : -1;
+    return 0;
+  }
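+
+  // NOTE (editorial, not part of upstream {fmt}): compare() above first ranks
+  // by num_bigits(), which includes the implicit exp_ zero bigits; e.g. with
+  //
+  //   bigint a(1); a <<= 32;  // bigits {1}, exp_ 1, num_bigits() == 2
+  //   bigint b(2);            // bigits {2}, exp_ 0, num_bigits() == 1
+  //
+  // compare(a, b) returns 1 from the length test alone, without inspecting
+  // any bigit values.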
+
+  // Returns compare(lhs1 + lhs2, rhs).
+  friend FMT_CONSTEXPR20 int add_compare(const bigint& lhs1, const bigint& lhs2,
+                                         const bigint& rhs) {
+    auto minimum = [](int a, int b) { return a < b ? a : b; };
+    auto maximum = [](int a, int b) { return a > b ? a : b; };
+    int max_lhs_bigits = maximum(lhs1.num_bigits(), lhs2.num_bigits());
+    int num_rhs_bigits = rhs.num_bigits();
+    if (max_lhs_bigits + 1 < num_rhs_bigits) return -1;
+    if (max_lhs_bigits > num_rhs_bigits) return 1;
+    auto get_bigit = [](const bigint& n, int i) -> bigit {
+      return i >= n.exp_ && i < n.num_bigits() ? n[i - n.exp_] : 0;
+    };
+    double_bigit borrow = 0;
+    int min_exp = minimum(minimum(lhs1.exp_, lhs2.exp_), rhs.exp_);
+    for (int i = num_rhs_bigits - 1; i >= min_exp; --i) {
+      double_bigit sum =
+          static_cast<double_bigit>(get_bigit(lhs1, i)) + get_bigit(lhs2, i);
+      bigit rhs_bigit = get_bigit(rhs, i);
+      if (sum > rhs_bigit + borrow) return 1;
+      borrow = rhs_bigit + borrow - sum;
+      if (borrow > 1) return -1;
+      borrow <<= bigit_bits;
+    }
+    return borrow != 0 ? -1 : 0;
+  }
+
+  // Assigns pow(10, exp) to this bigint.
+  FMT_CONSTEXPR20 void assign_pow10(int exp) {
+    FMT_ASSERT(exp >= 0, "");
+    if (exp == 0) return *this = 1;
+    // Find the top bit.
+    int bitmask = 1;
+    while (exp >= bitmask) bitmask <<= 1;
+    bitmask >>= 1;
+    // pow(10, exp) = pow(5, exp) * pow(2, exp). First compute pow(5, exp) by
+    // repeated squaring and multiplication.
+    *this = 5;
+    bitmask >>= 1;
+    while (bitmask != 0) {
+      square();
+      if ((exp & bitmask) != 0) *this *= 5;
+      bitmask >>= 1;
+    }
+    *this <<= exp;  // Multiply by pow(2, exp) by shifting.
+  }
+
+  FMT_CONSTEXPR20 void square() {
+    int num_bigits = static_cast<int>(bigits_.size());
+    int num_result_bigits = 2 * num_bigits;
+    basic_memory_buffer<bigit, bigits_capacity> n(std::move(bigits_));
+    bigits_.resize(to_unsigned(num_result_bigits));
+    auto sum = uint128_t();
+    for (int bigit_index = 0; bigit_index < num_bigits; ++bigit_index) {
+      // Compute bigit at position bigit_index of the result by adding
+      // cross-product terms n[i] * n[j] such that i + j == bigit_index.
+      for (int i = 0, j = bigit_index; j >= 0; ++i, --j) {
+        // Most terms are multiplied twice which can be optimized in the future.
+        sum += static_cast<double_bigit>(n[i]) * n[j];
+      }
+      (*this)[bigit_index] = static_cast<bigit>(sum);
+      sum >>= num_bits<bigit>();  // Compute the carry.
+    }
+    // Do the same for the top half.
+    for (int bigit_index = num_bigits; bigit_index < num_result_bigits;
+         ++bigit_index) {
+      for (int j = num_bigits - 1, i = bigit_index - j; i < num_bigits;)
+        sum += static_cast<double_bigit>(n[i++]) * n[j--];
+      (*this)[bigit_index] = static_cast<bigit>(sum);
+      sum >>= num_bits<bigit>();
+    }
+    remove_leading_zeros();
+    exp_ *= 2;
+  }
+
+  // If this bigint has a bigger exponent than other, adds trailing zero to make
+  // exponents equal. This simplifies some operations such as subtraction.
+  FMT_CONSTEXPR20 void align(const bigint& other) {
+    int exp_difference = exp_ - other.exp_;
+    if (exp_difference <= 0) return;
+    int num_bigits = static_cast<int>(bigits_.size());
+    bigits_.resize(to_unsigned(num_bigits + exp_difference));
+    for (int i = num_bigits - 1, j = i + exp_difference; i >= 0; --i, --j)
+      bigits_[j] = bigits_[i];
+    std::uninitialized_fill_n(bigits_.data(), exp_difference, 0);
+    exp_ -= exp_difference;
+  }
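+
+  // NOTE (editorial worked example, not part of upstream {fmt}): assign_pow10
+  // uses the factorization pow(10, e) = pow(5, e) * pow(2, e) -- pow(5, e) by
+  // binary exponentiation, pow(2, e) as a final shift. For e = 10 (binary
+  // 1010), the loop performs:
+  //
+  //   *this = 5;             // top bit
+  //   square();              // 5^2   (bit with value 4 is 0)
+  //   square(); *this *= 5;  // 5^5   (bit with value 2 is 1)
+  //   square();              // 5^10  (bit with value 1 is 0)
+  //   *this <<= 10;          // multiply by 2^10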
+
+  // Divides this bignum by divisor, assigning the remainder to this and
+  // returning the quotient.
+  FMT_CONSTEXPR20 int divmod_assign(const bigint& divisor) {
+    FMT_ASSERT(this != &divisor, "");
+    if (compare(*this, divisor) < 0) return 0;
+    FMT_ASSERT(divisor.bigits_[divisor.bigits_.size() - 1u] != 0, "");
+    align(divisor);
+    int quotient = 0;
+    do {
+      subtract_aligned(divisor);
+      ++quotient;
+    } while (compare(*this, divisor) >= 0);
+    return quotient;
+  }
+};
+
+// format_dragon flags.
+enum dragon {
+  predecessor_closer = 1,
+  fixup = 2,  // Run fixup to correct exp10 which can be off by one.
+  fixed = 4,
+};
+
+// Formats a floating-point number using a variation of the Fixed-Precision
+// Positive Floating-Point Printout ((FPP)^2) algorithm by Steele & White:
+// https://fmt.dev/papers/p372-steele.pdf.
+FMT_CONSTEXPR20 inline void format_dragon(basic_fp<uint128_t> value,
+                                          unsigned flags, int num_digits,
+                                          buffer<char>& buf, int& exp10) {
+  bigint numerator;    // 2 * R in (FPP)^2.
+  bigint denominator;  // 2 * S in (FPP)^2.
+  // lower and upper are differences between value and corresponding boundaries.
+  bigint lower;             // (M^- in (FPP)^2).
+  bigint upper_store;       // upper's value if different from lower.
+  bigint* upper = nullptr;  // (M^+ in (FPP)^2).
+  // Shift numerator and denominator by an extra bit or two (if lower boundary
+  // is closer) to make lower and upper integers. This eliminates multiplication
+  // by 2 during later computations.
+  bool is_predecessor_closer = (flags & dragon::predecessor_closer) != 0;
+  int shift = is_predecessor_closer ? 2 : 1;
+  if (value.e >= 0) {
+    numerator = value.f;
+    numerator <<= value.e + shift;
+    lower = 1;
+    lower <<= value.e;
+    if (is_predecessor_closer) {
+      upper_store = 1;
+      upper_store <<= value.e + 1;
+      upper = &upper_store;
+    }
+    denominator.assign_pow10(exp10);
+    denominator <<= shift;
+  } else if (exp10 < 0) {
+    numerator.assign_pow10(-exp10);
+    lower.assign(numerator);
+    if (is_predecessor_closer) {
+      upper_store.assign(numerator);
+      upper_store <<= 1;
+      upper = &upper_store;
+    }
+    numerator *= value.f;
+    numerator <<= shift;
+    denominator = 1;
+    denominator <<= shift - value.e;
+  } else {
+    numerator = value.f;
+    numerator <<= shift;
+    denominator.assign_pow10(exp10);
+    denominator <<= shift - value.e;
+    lower = 1;
+    if (is_predecessor_closer) {
+      upper_store = 1ULL << 1;
+      upper = &upper_store;
+    }
+  }
+  int even = static_cast<int>((value.f & 1) == 0);
+  if (!upper) upper = &lower;
+  bool shortest = num_digits < 0;
+  if ((flags & dragon::fixup) != 0) {
+    if (add_compare(numerator, *upper, denominator) + even <= 0) {
+      --exp10;
+      numerator *= 10;
+      if (num_digits < 0) {
+        lower *= 10;
+        if (upper != &lower) *upper *= 10;
+      }
+    }
+    if ((flags & dragon::fixed) != 0) adjust_precision(num_digits, exp10 + 1);
+  }
+  // Invariant: value == (numerator / denominator) * pow(10, exp10).
+  if (shortest) {
+    // Generate the shortest representation.
+    num_digits = 0;
+    char* data = buf.data();
+    for (;;) {
+      int digit = numerator.divmod_assign(denominator);
+      bool low = compare(numerator, lower) - even < 0;  // numerator <[=] lower.
+      // numerator + upper >[=] pow10:
+      bool high = add_compare(numerator, *upper, denominator) + even > 0;
+      data[num_digits++] = static_cast<char>('0' + digit);
+      if (low || high) {
+        if (!low) {
+          ++data[num_digits - 1];
+        } else if (high) {
+          int result = add_compare(numerator, numerator, denominator);
+          // Round half to even.
+          if (result > 0 || (result == 0 && (digit % 2) != 0))
+            ++data[num_digits - 1];
+        }
+        buf.try_resize(to_unsigned(num_digits));
+        exp10 -= num_digits - 1;
+        return;
+      }
+      numerator *= 10;
+      lower *= 10;
+      if (upper != &lower) *upper *= 10;
+    }
+  }
+  // Generate the given number of digits.
+  exp10 -= num_digits - 1;
+  if (num_digits <= 0) {
+    denominator *= 10;
+    auto digit = add_compare(numerator, numerator, denominator) > 0 ? '1' : '0';
+    buf.push_back(digit);
+    return;
+  }
+  buf.try_resize(to_unsigned(num_digits));
+  for (int i = 0; i < num_digits - 1; ++i) {
+    int digit = numerator.divmod_assign(denominator);
+    buf[i] = static_cast<char>('0' + digit);
+    numerator *= 10;
+  }
+  int digit = numerator.divmod_assign(denominator);
+  auto result = add_compare(numerator, numerator, denominator);
+  if (result > 0 || (result == 0 && (digit % 2) != 0)) {
+    if (digit == 9) {
+      const auto overflow = '0' + 10;
+      buf[num_digits - 1] = overflow;
+      // Propagate the carry.
+      for (int i = num_digits - 1; i > 0 && buf[i] == overflow; --i) {
+        buf[i] = '0';
+        ++buf[i - 1];
+      }
+      if (buf[0] == overflow) {
+        buf[0] = '1';
+        if ((flags & dragon::fixed) != 0)
+          buf.push_back('0');
+        else
+          ++exp10;
+      }
+      return;
+    }
+    ++digit;
+  }
+  buf[num_digits - 1] = static_cast<char>('0' + digit);
+}
+
+// Formats a floating-point number using the hexfloat format.
+template <typename Float, FMT_ENABLE_IF(!is_double_double<Float>::value)>
+FMT_CONSTEXPR20 void format_hexfloat(Float value, int precision,
+                                     float_specs specs, buffer<char>& buf) {
+  // float is passed as double to reduce the number of instantiations and to
+  // simplify implementation.
+  static_assert(!std::is_same<Float, float>::value, "");
+
+  using info = dragonbox::float_info<Float>;
+
+  // Assume Float is in the format [sign][exponent][significand].
+  using carrier_uint = typename info::carrier_uint;
+
+  constexpr auto num_float_significand_bits =
+      detail::num_significand_bits<Float>();
+
+  basic_fp<carrier_uint> f(value);
+  f.e += num_float_significand_bits;
+  if (!has_implicit_bit<Float>()) --f.e;
+
+  constexpr auto num_fraction_bits =
+      num_float_significand_bits + (has_implicit_bit<Float>() ? 1 : 0);
+  constexpr auto num_xdigits = (num_fraction_bits + 3) / 4;
+
+  constexpr auto leading_shift = ((num_xdigits - 1) * 4);
+  const auto leading_mask = carrier_uint(0xF) << leading_shift;
+  const auto leading_xdigit =
+      static_cast<uint32_t>((f.f & leading_mask) >> leading_shift);
+  if (leading_xdigit > 1) f.e -= (32 - countl_zero(leading_xdigit) - 1);
+
+  int print_xdigits = num_xdigits - 1;
+  if (precision >= 0 && print_xdigits > precision) {
+    const int shift = ((print_xdigits - precision - 1) * 4);
+    const auto mask = carrier_uint(0xF) << shift;
+    const auto v = static_cast<uint32_t>((f.f & mask) >> shift);
+
+    if (v >= 8) {
+      const auto inc = carrier_uint(1) << (shift + 4);
+      f.f += inc;
+      f.f &= ~(inc - 1);
+    }
+
+    // Check long double overflow
+    if (!has_implicit_bit<Float>()) {
+      const auto implicit_bit = carrier_uint(1) << num_float_significand_bits;
+      if ((f.f & implicit_bit) == implicit_bit) {
+        f.f >>= 4;
+        f.e += 4;
+      }
+    }
+
+    print_xdigits = precision;
+  }
+
+  char xdigits[num_bits<carrier_uint>() / 4];
+  detail::fill_n(xdigits, sizeof(xdigits), '0');
+  format_uint<4>(xdigits, f.f, num_xdigits, specs.upper);
+
+  // Remove zero tail
+  while (print_xdigits > 0 && xdigits[print_xdigits] == '0') --print_xdigits;
+
+  buf.push_back('0');
+  buf.push_back(specs.upper ? 'X' : 'x');
+  buf.push_back(xdigits[0]);
+  if (specs.showpoint || print_xdigits > 0 || print_xdigits < precision)
+    buf.push_back('.');
+  buf.append(xdigits + 1, xdigits + 1 + print_xdigits);
+  for (; print_xdigits < precision; ++print_xdigits) buf.push_back('0');
+
+  buf.push_back(specs.upper ? 'P' : 'p');
+
+  uint32_t abs_e;
+  if (f.e < 0) {
+    buf.push_back('-');
+    abs_e = static_cast<uint32_t>(-f.e);
+  } else {
+    buf.push_back('+');
+    abs_e = static_cast<uint32_t>(f.e);
+  }
+  format_decimal<char>(appender(buf), abs_e, detail::count_digits(abs_e));
+}
+
+template <typename Float, FMT_ENABLE_IF(is_double_double<Float>::value)>
+FMT_CONSTEXPR20 void format_hexfloat(Float value, int precision,
+                                     float_specs specs, buffer<char>& buf) {
+  format_hexfloat(static_cast<double>(value), precision, specs, buf);
+}
+
+template <typename Float>
+FMT_CONSTEXPR20 auto format_float(Float value, int precision, float_specs specs,
+                                  buffer<char>& buf) -> int {
+  // float is passed as double to reduce the number of instantiations.
+  static_assert(!std::is_same<Float, float>::value, "");
+  FMT_ASSERT(value >= 0, "value is negative");
+  auto converted_value = convert_float(value);
+
+  const bool fixed = specs.format == float_format::fixed;
+  if (value <= 0) {  // <= instead of == to silence a warning.
+    if (precision <= 0 || !fixed) {
+      buf.push_back('0');
+      return 0;
+    }
+    buf.try_resize(to_unsigned(precision));
+    fill_n(buf.data(), precision, '0');
+    return -precision;
+  }
+
+  int exp = 0;
+  bool use_dragon = true;
+  unsigned dragon_flags = 0;
+  if (!is_fast_float<Float>() || is_constant_evaluated()) {
+    const auto inv_log2_10 = 0.3010299956639812;  // 1 / log2(10)
+    using info = dragonbox::float_info<decltype(converted_value)>;
+    const auto f = basic_fp<typename info::carrier_uint>(converted_value);
+    // Compute exp, an approximate power of 10, such that
+    // 10^(exp - 1) <= value < 10^exp or 10^exp <= value < 10^(exp + 1).
+    // This is based on log10(value) == log2(value) / log2(10) and approximation
+    // of log2(value) by e + num_fraction_bits idea from double-conversion.
+    auto e = (f.e + count_digits<1>(f.f) - 1) * inv_log2_10 - 1e-10;
+    exp = static_cast<int>(e);
+    if (e > exp) ++exp;  // Compute ceil.
+    dragon_flags = dragon::fixup;
+  } else if (precision < 0) {
+    // Use Dragonbox for the shortest format.
+    if (specs.binary32) {
+      auto dec = dragonbox::to_decimal(static_cast<float>(value));
+      write<char>(buffer_appender<char>(buf), dec.significand);
+      return dec.exponent;
+    }
+    auto dec = dragonbox::to_decimal(static_cast<double>(value));
+    write<char>(buffer_appender<char>(buf), dec.significand);
+    return dec.exponent;
+  } else {
+    // Extract significand bits and exponent bits.
+    using info = dragonbox::float_info<double>;
+    auto br = bit_cast<uint64_t>(static_cast<double>(value));
+
+    const uint64_t significand_mask =
+        (static_cast<uint64_t>(1) << num_significand_bits<double>()) - 1;
+    uint64_t significand = (br & significand_mask);
+    int exponent = static_cast<int>((br & exponent_mask<double>()) >>
+                                    num_significand_bits<double>());
+
+    if (exponent != 0) {  // Check if normal.
+      exponent -= exponent_bias<double>() + num_significand_bits<double>();
+      significand |=
+          (static_cast<uint64_t>(1) << num_significand_bits<double>());
+      significand <<= 1;
+    } else {
+      // Normalize subnormal inputs.
+      FMT_ASSERT(significand != 0, "zeros should not appear here");
+      int shift = countl_zero(significand);
+      FMT_ASSERT(shift >= num_bits<uint64_t>() - num_significand_bits<double>(),
+                 "");
+      shift -= (num_bits<uint64_t>() - num_significand_bits<double>() - 2);
+      exponent = (std::numeric_limits<double>::min_exponent -
+                  num_significand_bits<double>()) -
+                 shift;
+      significand <<= shift;
+    }
+
+    // Compute the first several nonzero decimal significand digits.
+    // We call the number we get the first segment.
+    const int k = info::kappa - dragonbox::floor_log10_pow2(exponent);
+    exp = -k;
+    const int beta = exponent + dragonbox::floor_log2_pow10(k);
+    uint64_t first_segment;
+    bool has_more_segments;
+    int digits_in_the_first_segment;
+    {
+      const auto r = dragonbox::umul192_upper128(
+          significand << beta, dragonbox::get_cached_power(k));
+      first_segment = r.high();
+      has_more_segments = r.low() != 0;
+
+      // The first segment can have 18 ~ 19 digits.
+      if (first_segment >= 1000000000000000000ULL) {
+        digits_in_the_first_segment = 19;
+      } else {
+        // When it is of 18-digits, we align it to 19-digits by adding a bogus
+        // zero at the end.
+        digits_in_the_first_segment = 18;
+        first_segment *= 10;
+      }
+    }
+
+    // Compute the actual number of decimal digits to print.
+    if (fixed) adjust_precision(precision, exp + digits_in_the_first_segment);
+
+    // Use Dragon4 only when there might be not enough digits in the first
+    // segment.
+    if (digits_in_the_first_segment > precision) {
+      use_dragon = false;
+
+      if (precision <= 0) {
+        exp += digits_in_the_first_segment;
+
+        if (precision < 0) {
+          // Nothing to do, since all we have are just leading zeros.
+          buf.try_resize(0);
+        } else {
+          // We may need to round-up.
+          buf.try_resize(1);
+          if ((first_segment | static_cast<uint64_t>(has_more_segments)) >
+              5000000000000000000ULL) {
+            buf[0] = '1';
+          } else {
+            buf[0] = '0';
+          }
+        }
+      }  // precision <= 0
+      else {
+        exp += digits_in_the_first_segment - precision;
+
+        // When precision > 0, we divide the first segment into three
+        // subsegments, each with 9, 9, and 0 ~ 1 digits so that each fits
+        // in 32-bits which usually allows faster calculation than in
+        // 64-bits. Since some compiler (e.g. MSVC) doesn't know how to optimize
+        // division-by-constant for large 64-bit divisors, we do it here
+        // manually. The magic number 7922816251426433760 below is equal to
+        // ceil(2^(64+32) / 10^10).
+        const uint32_t first_subsegment = static_cast<uint32_t>(
+            dragonbox::umul128_upper64(first_segment, 7922816251426433760ULL) >>
+            32);
+        const uint64_t second_third_subsegments =
+            first_segment - first_subsegment * 10000000000ULL;
+
+        uint64_t prod;
+        uint32_t digits;
+        bool should_round_up;
+        int number_of_digits_to_print = precision > 9 ? 9 : precision;
+
+        // Print a 9-digits subsegment, either the first or the second.
+        auto print_subsegment = [&](uint32_t subsegment, char* buffer) {
+          int number_of_digits_printed = 0;
+
+          // If we want to print an odd number of digits from the subsegment,
+          if ((number_of_digits_to_print & 1) != 0) {
+            // Convert to 64-bit fixed-point fractional form with 1-digit
+            // integer part. The magic number 720575941 is a good enough
+            // approximation of 2^(32 + 24) / 10^8; see
+            // https://jk-jeon.github.io/posts/2022/12/fixed-precision-formatting/#fixed-length-case
+            // for details.
+            prod = ((subsegment * static_cast<uint64_t>(720575941)) >> 24) + 1;
+            digits = static_cast<uint32_t>(prod >> 32);
+            *buffer = static_cast<char>('0' + digits);
+            number_of_digits_printed++;
+          }
+          // If we want to print an even number of digits from the
+          // first_subsegment,
+          else {
+            // Convert to 64-bit fixed-point fractional form with 2-digits
+            // integer part. The magic number 450359963 is a good enough
+            // approximation of 2^(32 + 20) / 10^7; see
+            // https://jk-jeon.github.io/posts/2022/12/fixed-precision-formatting/#fixed-length-case
+            // for details.
+            prod = ((subsegment * static_cast<uint64_t>(450359963)) >> 20) + 1;
+            digits = static_cast<uint32_t>(prod >> 32);
+            copy2(buffer, digits2(digits));
+            number_of_digits_printed += 2;
+          }
+
+          // Print all digit pairs.
+          while (number_of_digits_printed < number_of_digits_to_print) {
+            prod = static_cast<uint32_t>(prod) * static_cast<uint64_t>(100);
+            digits = static_cast<uint32_t>(prod >> 32);
+            copy2(buffer + number_of_digits_printed, digits2(digits));
+            number_of_digits_printed += 2;
+          }
+        };
+
+        // Print first subsegment.
+        print_subsegment(first_subsegment, buf.data());
+
+        // Perform rounding if the first subsegment is the last subsegment to
+        // print.
+        if (precision <= 9) {
+          // Rounding inside the subsegment.
+          // We round-up if:
+          // - either the fractional part is strictly larger than 1/2, or
+          // - the fractional part is exactly 1/2 and the last digit is odd.
+          // We rely on the following observations:
+          // - If fractional_part >= threshold, then the fractional part is
+          //   strictly larger than 1/2.
+          // - If the MSB of fractional_part is set, then the fractional part
+          //   must be at least 1/2.
+          // - When the MSB of fractional_part is set, either
+          //   second_third_subsegments being nonzero or has_more_segments
+          //   being true means there are further digits not printed, so the
+          //   fractional part is strictly larger than 1/2.
+          if (precision < 9) {
+            uint32_t fractional_part = static_cast<uint32_t>(prod);
+            should_round_up = fractional_part >=
+                                  data::fractional_part_rounding_thresholds
+                                      [8 - number_of_digits_to_print] ||
+                              ((fractional_part >> 31) &
+                               ((digits & 1) | (second_third_subsegments != 0) |
+                                has_more_segments)) != 0;
+          }
+          // Rounding at the subsegment boundary.
+          // In this case, the fractional part is at least 1/2 if and only if
+          // second_third_subsegments >= 5000000000ULL, and is strictly larger
+          // than 1/2 if we further have either second_third_subsegments >
+          // 5000000000ULL or has_more_segments == true.
+          else {
+            should_round_up = second_third_subsegments > 5000000000ULL ||
+                              (second_third_subsegments == 5000000000ULL &&
+                               ((digits & 1) != 0 || has_more_segments));
+          }
+        }
+        // Otherwise, print the second subsegment.
+        else {
+          // Compilers are not aware of how to leverage the maximum value of
+          // second_third_subsegments to find out a better magic number which
+          // allows us to eliminate an additional shift. 1844674407370955162 =
+          // ceil(2^64/10) < ceil(2^64*(10^9/(10^10 - 1))).
+          const uint32_t second_subsegment =
+              static_cast<uint32_t>(dragonbox::umul128_upper64(
+                  second_third_subsegments, 1844674407370955162ULL));
+          const uint32_t third_subsegment =
+              static_cast<uint32_t>(second_third_subsegments) -
+              second_subsegment * 10;
+
+          number_of_digits_to_print = precision - 9;
+          print_subsegment(second_subsegment, buf.data() + 9);
+
+          // Rounding inside the subsegment.
+          if (precision < 18) {
+            // The condition third_subsegment != 0 implies that the segment was
+            // of 19 digits, so in this case the third segment should be
+            // consisting of a genuine digit from the input.
+            uint32_t fractional_part = static_cast<uint32_t>(prod);
+            should_round_up = fractional_part >=
+                                  data::fractional_part_rounding_thresholds
+                                      [8 - number_of_digits_to_print] ||
+                              ((fractional_part >> 31) &
+                               ((digits & 1) | (third_subsegment != 0) |
+                                has_more_segments)) != 0;
+          }
+          // Rounding at the subsegment boundary.
+          else {
+            // In this case, the segment must be of 19 digits, thus
+            // the third subsegment should be consisting of a genuine digit from
+            // the input.
+            should_round_up = third_subsegment > 5 ||
+                              (third_subsegment == 5 &&
+                               ((digits & 1) != 0 || has_more_segments));
+          }
+        }
+
+        // Round-up if necessary.
+        if (should_round_up) {
+          ++buf[precision - 1];
+          for (int i = precision - 1; i > 0 && buf[i] > '9'; --i) {
+            buf[i] = '0';
+            ++buf[i - 1];
+          }
+          if (buf[0] > '9') {
+            buf[0] = '1';
+            if (fixed)
+              buf[precision++] = '0';
+            else
+              ++exp;
+          }
+        }
+        buf.try_resize(to_unsigned(precision));
+      }
+    }  // if (digits_in_the_first_segment > precision)
+    else {
+      // Adjust the exponent for its use in Dragon4.
+      exp += digits_in_the_first_segment - 1;
+    }
+  }
+  if (use_dragon) {
+    auto f = basic_fp<uint128_t>();
+    bool is_predecessor_closer = specs.binary32
+                                     ? f.assign(static_cast<float>(value))
+                                     : f.assign(converted_value);
+    if (is_predecessor_closer) dragon_flags |= dragon::predecessor_closer;
+    if (fixed) dragon_flags |= dragon::fixed;
+    // Limit precision to the maximum possible number of significant digits in
+    // an IEEE754 double because we don't need to generate zeros.
+    const int max_double_digits = 767;
+    if (precision > max_double_digits) precision = max_double_digits;
+    format_dragon(f, dragon_flags, precision, buf, exp);
+  }
+  if (!fixed && !specs.showpoint) {
+    // Remove trailing zeros.
+    auto num_digits = buf.size();
+    while (num_digits > 0 && buf[num_digits - 1] == '0') {
+      --num_digits;
+      ++exp;
+    }
+    buf.try_resize(num_digits);
+  }
+  return exp;
+}
+template <typename Char, typename OutputIt, typename T>
+FMT_CONSTEXPR20 auto write_float(OutputIt out, T value,
+                                 format_specs<Char> specs, locale_ref loc)
+    -> OutputIt {
+  float_specs fspecs = parse_float_type_spec(specs);
+  fspecs.sign = specs.sign;
+  if (detail::signbit(value)) {  // value < 0 is false for NaN so use signbit.
+    fspecs.sign = sign::minus;
+    value = -value;
+  } else if (fspecs.sign == sign::minus) {
+    fspecs.sign = sign::none;
+  }
+
+  if (!detail::isfinite(value))
+    return write_nonfinite(out, detail::isnan(value), specs, fspecs);
+
+  if (specs.align == align::numeric && fspecs.sign) {
+    auto it = reserve(out, 1);
+    *it++ = detail::sign<Char>(fspecs.sign);
+    out = base_iterator(out, it);
+    fspecs.sign = sign::none;
+    if (specs.width != 0) --specs.width;
+  }
+
+  memory_buffer buffer;
+  if (fspecs.format == float_format::hex) {
+    if (fspecs.sign) buffer.push_back(detail::sign<char>(fspecs.sign));
+    format_hexfloat(convert_float(value), specs.precision, fspecs, buffer);
+    return write_bytes<align::right>(out, {buffer.data(), buffer.size()},
+                                     specs);
+  }
+  int precision = specs.precision >= 0 || specs.type == presentation_type::none
+                      ? specs.precision
+                      : 6;
+  if (fspecs.format == float_format::exp) {
+    if (precision == max_value<int>())
+      throw_format_error("number is too big");
+    else
+      ++precision;
+  } else if (fspecs.format != float_format::fixed && precision == 0) {
+    precision = 1;
+  }
+  if (const_check(std::is_same<T, float>())) fspecs.binary32 = true;
+  int exp = format_float(convert_float(value), precision, fspecs, buffer);
+  fspecs.precision = precision;
+  auto f = big_decimal_fp{buffer.data(), static_cast<int>(buffer.size()), exp};
+  return write_float(out, f, specs, fspecs, loc);
+}
+
+template <typename Char, typename OutputIt, typename T,
+          FMT_ENABLE_IF(is_floating_point<T>::value)>
+FMT_CONSTEXPR20 auto write(OutputIt out, T value, format_specs<Char> specs,
+                           locale_ref loc = {}) -> OutputIt {
+  if (const_check(!is_supported_floating_point(value))) return out;
+  return specs.localized && write_loc(out, value, specs, loc)
+             ? out
+             : write_float(out, value, specs, loc);
+}
+
+template <typename Char, typename OutputIt, typename T,
+          FMT_ENABLE_IF(is_fast_float<T>::value)>
+FMT_CONSTEXPR20 auto write(OutputIt out, T value) -> OutputIt {
+  if (is_constant_evaluated()) return write(out, value, format_specs<Char>());
+  if (const_check(!is_supported_floating_point(value))) return out;
+
+  auto fspecs = float_specs();
+  if (detail::signbit(value)) {
+    fspecs.sign = sign::minus;
+    value = -value;
+  }
+
+  constexpr auto specs = format_specs<Char>();
+  using floaty = conditional_t<std::is_same<T, long double>::value, double, T>;
+  using floaty_uint = typename dragonbox::float_info<floaty>::carrier_uint;
+  floaty_uint mask = exponent_mask<floaty>();
+  if ((bit_cast<floaty_uint>(value) & mask) == mask)
+    return write_nonfinite(out, std::isnan(value), specs, fspecs);
+
+  auto dec = dragonbox::to_decimal(static_cast<floaty>(value));
+  return write_float(out, dec, specs, fspecs, {});
+}
+
+template <typename Char, typename OutputIt, typename T,
+          FMT_ENABLE_IF(is_floating_point<T>::value &&
+                        !is_fast_float<T>::value)>
+inline auto write(OutputIt out, T value) -> OutputIt {
+  return write(out, value, format_specs<Char>());
+}
+
+template <typename Char, typename OutputIt>
+auto write(OutputIt out, monostate, format_specs<Char> = {}, locale_ref = {})
+    -> OutputIt {
+  FMT_ASSERT(false, "");
+  return out;
+}
+
+template <typename Char, typename OutputIt>
+FMT_CONSTEXPR auto write(OutputIt out, basic_string_view<Char> value)
+    -> OutputIt {
+  auto it = reserve(out, value.size());
+  it = copy_str_noinline<Char>(value.begin(), value.end(), it);
+  return base_iterator(out, it);
+}
+
+template <typename Char, typename OutputIt, typename T,
+          FMT_ENABLE_IF(has_to_string_view<T>::value)>
+constexpr auto write(OutputIt out, const T& value) -> OutputIt {
+  return write<Char>(out, to_string_view(value));
+}
+
+// FMT_ENABLE_IF() condition separated to workaround an MSVC bug.
+template <
+    typename Char, typename OutputIt, typename T,
+    bool check =
+        std::is_enum<T>::value && !std::is_same<T, Char>::value &&
+        mapped_type_constant<T, basic_format_context<OutputIt, Char>>::value !=
+            type::custom_type,
+    FMT_ENABLE_IF(check)>
+FMT_CONSTEXPR auto write(OutputIt out, T value) -> OutputIt {
+  return write<Char>(out, static_cast<underlying_t<T>>(value));
+}
+
+template <typename Char, typename OutputIt, typename T,
+          FMT_ENABLE_IF(std::is_same<T, bool>::value)>
+FMT_CONSTEXPR auto write(OutputIt out, T value,
+                         const format_specs<Char>& specs = {}, locale_ref = {})
+    -> OutputIt {
+  return specs.type != presentation_type::none &&
+                 specs.type != presentation_type::string
+             ? write(out, value ? 1 : 0, specs, {})
+             : write_bytes(out, value ? "true" : "false", specs);
+}
+
+template <typename Char, typename OutputIt>
+FMT_CONSTEXPR auto write(OutputIt out, Char value) -> OutputIt {
+  auto it = reserve(out, 1);
+  *it++ = value;
+  return base_iterator(out, it);
+}
+
+template <typename Char, typename OutputIt>
+FMT_CONSTEXPR_CHAR_TRAITS auto write(OutputIt out, const Char* value)
+    -> OutputIt {
+  if (value) return write(out, basic_string_view<Char>(value));
+  throw_format_error("string pointer is null");
+  return out;
+}
+
+template <typename Char, typename OutputIt, typename T,
+          FMT_ENABLE_IF(std::is_same<T, void>::value)>
+auto write(OutputIt out, const T* value, const format_specs<Char>& specs = {},
+           locale_ref = {}) -> OutputIt {
+  return write_ptr<Char>(out, bit_cast<uintptr_t>(value), &specs);
+}
+
+// A write overload that handles implicit conversions.
+template <typename Char, typename OutputIt, typename T,
+          typename Context = basic_format_context<OutputIt, Char>>
+FMT_CONSTEXPR auto write(OutputIt out, const T& value) -> enable_if_t<
+    std::is_class<T>::value && !is_string<T>::value &&
+        !is_floating_point<T>::value && !std::is_same<T, Char>::value &&
+        !std::is_same<T, remove_cvref_t<decltype(arg_mapper<Context>().map(
+                              value))>>::value,
+    OutputIt> {
+  return write<Char>(out, arg_mapper<Context>().map(value));
+}
+
+template <typename Char, typename OutputIt, typename T,
+          typename Context = basic_format_context<OutputIt, Char>>
+FMT_CONSTEXPR auto write(OutputIt out, const T& value)
+    -> enable_if_t<mapped_type_constant<T, Context>::value == type::custom_type,
+                   OutputIt> {
+  auto ctx = Context(out, {}, {});
+  return typename Context::template formatter_type<T>().format(value, ctx);
+}
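+
+// NOTE (editorial sketch, not part of upstream {fmt}): the write() overload
+// set above is the non-parsing fast path. Assuming the detail API in this
+// header, formatting an int straight into a buffer looks like:
+//
+//   auto buf = fmt::memory_buffer();
+//   fmt::detail::write<char>(fmt::appender(buf), 42);  // buf holds "42"
+//
+// fmt::to_string() and compiled format strings dispatch through these
+// overloads directly, bypassing the format string parser.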
+
+// An argument visitor that formats the argument and writes it via the output
+// iterator. It's a class and not a generic lambda for compatibility with C++11.
+template <typename Char> struct default_arg_formatter {
+  using iterator = buffer_appender<Char>;
+  using context = buffer_context<Char>;
+
+  iterator out;
+  basic_format_args<context> args;
+  locale_ref loc;
+
+  template <typename T> auto operator()(T value) -> iterator {
+    return write<Char>(out, value);
+  }
+  auto operator()(typename basic_format_arg<context>::handle h) -> iterator {
+    basic_format_parse_context<Char> parse_ctx({});
+    context format_ctx(out, args, loc);
+    h.format(parse_ctx, format_ctx);
+    return format_ctx.out();
+  }
+};
+
+template <typename Char> struct arg_formatter {
+  using iterator = buffer_appender<Char>;
+  using context = buffer_context<Char>;
+
+  iterator out;
+  const format_specs<Char>& specs;
+  locale_ref locale;
+
+  template <typename T>
+  FMT_CONSTEXPR FMT_INLINE auto operator()(T value) -> iterator {
+    return detail::write(out, value, specs, locale);
+  }
+  auto operator()(typename basic_format_arg<context>::handle) -> iterator {
+    // User-defined types are handled separately because they require access
+    // to the parse context.
+    return out;
+  }
+};
+
+template <typename Char> struct custom_formatter {
+  basic_format_parse_context<Char>& parse_ctx;
+  buffer_context<Char>& ctx;
+
+  void operator()(
+      typename basic_format_arg<buffer_context<Char>>::handle h) const {
+    h.format(parse_ctx, ctx);
+  }
+  template <typename T> void operator()(T) const {}
+};
+
+template <typename ErrorHandler> class width_checker {
+ public:
+  explicit FMT_CONSTEXPR width_checker(ErrorHandler& eh) : handler_(eh) {}
+
+  template <typename T, FMT_ENABLE_IF(is_integer<T>::value)>
+  FMT_CONSTEXPR auto operator()(T value) -> unsigned long long {
+    if (is_negative(value)) handler_.on_error("negative width");
+    return static_cast<unsigned long long>(value);
+  }
+
+  template <typename T, FMT_ENABLE_IF(!is_integer<T>::value)>
+  FMT_CONSTEXPR auto operator()(T) -> unsigned long long {
+    handler_.on_error("width is not integer");
+    return 0;
+  }
+
+ private:
+  ErrorHandler& handler_;
+};
+
+template <typename ErrorHandler> class precision_checker {
+ public:
+  explicit FMT_CONSTEXPR precision_checker(ErrorHandler& eh) : handler_(eh) {}
+
+  template <typename T, FMT_ENABLE_IF(is_integer<T>::value)>
+  FMT_CONSTEXPR auto operator()(T value) -> unsigned long long {
+    if (is_negative(value)) handler_.on_error("negative precision");
+    return static_cast<unsigned long long>(value);
+  }
+
+  template <typename T, FMT_ENABLE_IF(!is_integer<T>::value)>
+  FMT_CONSTEXPR auto operator()(T) -> unsigned long long {
+    handler_.on_error("precision is not integer");
+    return 0;
+  }
+
+ private:
+  ErrorHandler& handler_;
+};
+
+template