add genie
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- Genie/Genie/GenieSymbols.default +31 -0
- Genie/Genie/Makefile +57 -0
- Genie/Genie/README +16 -0
- Genie/Genie/make/Android.mk +56 -0
- Genie/Genie/make/Application.mk +14 -0
- Genie/Genie/make/Makefile.linux-x86_64 +192 -0
- Genie/Genie/src/Dialog.cpp +1804 -0
- Genie/Genie/src/Dialog.hpp +95 -0
- Genie/Genie/src/Exception.hpp +27 -0
- Genie/Genie/src/GenieCommon.cpp +15 -0
- Genie/Genie/src/GenieDialog.cpp +249 -0
- Genie/Genie/src/GenieDialogEmbedding.cpp +41 -0
- Genie/Genie/src/Macro.hpp +101 -0
- Genie/Genie/src/Util/HandleGenerator.hpp +62 -0
- Genie/Genie/src/Util/HandleManager.hpp +84 -0
- Genie/Genie/src/qualla/context.cpp +118 -0
- Genie/Genie/src/qualla/dialog.cpp +590 -0
- Genie/Genie/src/qualla/dialogs/basic.cpp +421 -0
- Genie/Genie/src/qualla/dialogs/kv-share.cpp +359 -0
- Genie/Genie/src/qualla/dialogs/lhd-dec.cpp +481 -0
- Genie/Genie/src/qualla/dialogs/multistream.cpp +300 -0
- Genie/Genie/src/qualla/dialogs/spec-dec.cpp +458 -0
- Genie/Genie/src/qualla/dialogs/ssd-q1.cpp +1046 -0
- Genie/Genie/src/qualla/embedding.cpp +190 -0
- Genie/Genie/src/qualla/engine.cpp +198 -0
- Genie/Genie/src/qualla/engines/lib.cpp +9 -0
- Genie/Genie/src/qualla/engines/qnn-api/BackendExtensions.cpp +158 -0
- Genie/Genie/src/qualla/engines/qnn-api/BackendExtensions.hpp +62 -0
- Genie/Genie/src/qualla/engines/qnn-api/ClientBuffer.cpp +122 -0
- Genie/Genie/src/qualla/engines/qnn-api/ClientBuffer.hpp +85 -0
- Genie/Genie/src/qualla/engines/qnn-api/IBackend.hpp +156 -0
- Genie/Genie/src/qualla/engines/qnn-api/IBufferAlloc.hpp +56 -0
- Genie/Genie/src/qualla/engines/qnn-api/ICommandLineManager.hpp +95 -0
- Genie/Genie/src/qualla/engines/qnn-api/IOTensor.cpp +382 -0
- Genie/Genie/src/qualla/engines/qnn-api/IOTensor.hpp +170 -0
- Genie/Genie/src/qualla/engines/qnn-api/Log.hpp +24 -0
- Genie/Genie/src/qualla/engines/qnn-api/NetRunBackend.hpp +173 -0
- Genie/Genie/src/qualla/engines/qnn-api/QnnApi.cpp +0 -0
- Genie/Genie/src/qualla/engines/qnn-api/QnnApi.hpp +429 -0
- Genie/Genie/src/qualla/engines/qnn-api/QnnApiUtils.cpp +636 -0
- Genie/Genie/src/qualla/engines/qnn-api/QnnApiUtils.hpp +94 -0
- Genie/Genie/src/qualla/engines/qnn-api/QnnConfig.hpp +44 -0
- Genie/Genie/src/qualla/engines/qnn-api/QnnTypeDef.hpp +52 -0
- Genie/Genie/src/qualla/engines/qnn-api/QnnTypeMacros.hpp +702 -0
- Genie/Genie/src/qualla/engines/qnn-api/RpcMem.cpp +481 -0
- Genie/Genie/src/qualla/engines/qnn-api/RpcMem.hpp +115 -0
- Genie/Genie/src/qualla/engines/qnn-api/dlwrap.cpp +66 -0
- Genie/Genie/src/qualla/engines/qnn-api/dlwrap.hpp +33 -0
- Genie/Genie/src/qualla/engines/qnn-api/qnn-utils.cpp +104 -0
- Genie/Genie/src/qualla/engines/qnn-api/qnn-utils.hpp +157 -0
Genie/Genie/GenieSymbols.default
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#=============================================================================
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
|
| 4 |
+
# All Rights Reserved.
|
| 5 |
+
# Confidential and Proprietary - Qualcomm Technologies, Inc.
|
| 6 |
+
#
|
| 7 |
+
#=============================================================================
|
| 8 |
+
{
|
| 9 |
+
global:
|
| 10 |
+
Genie_getApiMajorVersion*;
|
| 11 |
+
Genie_getApiMinorVersion*;
|
| 12 |
+
Genie_getApiPatchVersion*;
|
| 13 |
+
GenieDialogConfig_createFromJson*;
|
| 14 |
+
GenieDialogConfig_free*;
|
| 15 |
+
GenieDialog_create*;
|
| 16 |
+
GenieDialog_query*;
|
| 17 |
+
GenieDialog_tokenQuery*;
|
| 18 |
+
GenieDialog_embeddingQuery*;
|
| 19 |
+
GenieDialog_save*;
|
| 20 |
+
GenieDialog_restore*;
|
| 21 |
+
GenieDialog_reset*;
|
| 22 |
+
GenieDialog_setLoraStrength*;
|
| 23 |
+
GenieDialog_applyLora*;
|
| 24 |
+
GenieDialog_free*;
|
| 25 |
+
GenieEmbeddingConfig_createFromJson*;
|
| 26 |
+
GenieEmbeddingConfig_free*;
|
| 27 |
+
GenieEmbedding_create*;
|
| 28 |
+
GenieEmbedding_generate*;
|
| 29 |
+
GenieEmbedding_free*;
|
| 30 |
+
local: *;
|
| 31 |
+
};
|
Genie/Genie/Makefile
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#=============================================================================
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
|
| 4 |
+
# All Rights Reserved.
|
| 5 |
+
# Confidential and Proprietary - Qualcomm Technologies, Inc.
|
| 6 |
+
#
|
| 7 |
+
#=============================================================================
|
| 8 |
+
|
| 9 |
+
RUST_TARGET := aarch64-linux-android
|
| 10 |
+
RUST_SOURCE_DIR := ./src/qualla/tokenizers/rust
|
| 11 |
+
# specify compiler
|
| 12 |
+
export CXX := clang++-14
|
| 13 |
+
export PATH := $(ANDROID_NDK_ROOT)/toolchains/llvm/prebuilt/linux-x86_64/bin:$(PATH)
|
| 14 |
+
.PHONY: all x86 android clean clean_x86 clean_android
|
| 15 |
+
.DEFAULT: x86
|
| 16 |
+
|
| 17 |
+
all: x86 android
|
| 18 |
+
|
| 19 |
+
x86: build_x86_tokenizer
|
| 20 |
+
@echo "-------------------- Building genie for x86 -------------------- "
|
| 21 |
+
@$(MAKE) -f make/Makefile.linux-x86_64 CPATH="/usr/include/x86_64-linux-gnu" || (echo "-------------------- genie x86 build failed --------------------"; exit 1; )
|
| 22 |
+
@echo "-------------------- genie x86 build succeeded -------------------- "
|
| 23 |
+
|
| 24 |
+
android: check_ndk build_android_tokenizer
|
| 25 |
+
@echo "-------------------- Building genie for android -------------------- "
|
| 26 |
+
@$(ANDROID_NDK_ROOT)/ndk-build APP_ALLOW_MISSING_DEPS=true APP_ABI="arm64-v8a" NDK_PROJECT_PATH=./ NDK_APPLICATION_MK=make/Application.mk APP_BUILD_SCRIPT=make/Android.mk || (echo "-------------------- genie android build failed --------------------"; exit 1; )
|
| 27 |
+
@$(rename_target_dirs)
|
| 28 |
+
@echo "-------------------- genie android build succeeded -------------------- "
|
| 29 |
+
|
| 30 |
+
clean: clean_x86 clean_android
|
| 31 |
+
|
| 32 |
+
clean_x86:
|
| 33 |
+
@$(MAKE) -f make/Makefile.linux-x86_64 clean
|
| 34 |
+
|
| 35 |
+
clean_android:
|
| 36 |
+
if [ -d "lib/aarch64-android" ]; then rm -rf lib/aarch64-android; fi
|
| 37 |
+
if [ -d "obj/local" ]; then rm -rf obj/local; fi
|
| 38 |
+
|
| 39 |
+
# utilities
|
| 40 |
+
rename_target_dirs = \
|
| 41 |
+
@if [ -d ./lib/aarch64-android ]; then rm -rf ./lib/aarch64-android; fi; \
|
| 42 |
+
find ./obj/local -type d -execdir rename 's/arm64-v8a/aarch64-android/' '{}' \+ \
|
| 43 |
+
&& mkdir -p lib \
|
| 44 |
+
&& mv ./obj/local/aarch64-android lib/ \
|
| 45 |
+
&& mv ./libs/arm64-v8a/libc++_shared.so lib/aarch64-android/ \
|
| 46 |
+
&& rm -rf ./libs \
|
| 47 |
+
|
| 48 |
+
check_ndk:
|
| 49 |
+
ifeq ($(ANDROID_NDK_ROOT),)
|
| 50 |
+
$(error ERROR: ANDROID_NDK_ROOT not set, skipping compilation for Android platform(s).)
|
| 51 |
+
endif
|
| 52 |
+
|
| 53 |
+
build_x86_tokenizer: $(RUST_SOURCE_DIR)/Cargo.toml
|
| 54 |
+
cargo build --release --manifest-path=$<
|
| 55 |
+
|
| 56 |
+
build_android_tokenizer: $(RUST_SOURCE_DIR)/Cargo.toml
|
| 57 |
+
cargo build --release --manifest-path=$< --target=$(RUST_TARGET)
|
Genie/Genie/README
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#=============================================================================
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
|
| 4 |
+
# All Rights Reserved.
|
| 5 |
+
# Confidential and Proprietary - Qualcomm Technologies, Inc.
|
| 6 |
+
#
|
| 7 |
+
#=============================================================================
|
| 8 |
+
|
| 9 |
+
Genie library source code example
|
| 10 |
+
---------------------------------
|
| 11 |
+
|
| 12 |
+
The Genie library (libGenie.so / Genie.dll) source code example provides users with an ability to recreate the Genie
|
| 13 |
+
library from source. Note that the Genie library source may be refactored, rewritten, or otherwise modified without
|
| 14 |
+
notice. The Genie C API is the commercially controlled and versioned interface that users should expect to be stable.
|
| 15 |
+
Please refer to the Genie SDK documentation tutorials at ${SDK_ROOT}/doc/Genie/ for more information on how to build the
|
| 16 |
+
sample code.
|
Genie/Genie/make/Android.mk
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#=============================================================================
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
|
| 4 |
+
# All Rights Reserved.
|
| 5 |
+
# Confidential and Proprietary - Qualcomm Technologies, Inc.
|
| 6 |
+
#
|
| 7 |
+
#=============================================================================
|
| 8 |
+
|
| 9 |
+
LOCAL_PATH := $(call my-dir)
|
| 10 |
+
SUPPORTED_TARGET_ABI := arm64-v8a x86 x86_64
|
| 11 |
+
|
| 12 |
+
#============================ Verify Target Info and Application Variables =========================================
|
| 13 |
+
ifneq ($(filter $(TARGET_ARCH_ABI),$(SUPPORTED_TARGET_ABI)),)
|
| 14 |
+
ifneq ($(APP_STL), c++_shared)
|
| 15 |
+
$(error Unsupported APP_STL: "$(APP_STL)")
|
| 16 |
+
endif
|
| 17 |
+
else
|
| 18 |
+
$(error Unsupported TARGET_ARCH_ABI: '$(TARGET_ARCH_ABI)')
|
| 19 |
+
endif
|
| 20 |
+
|
| 21 |
+
#============================ Define Common Variables ===============================================================
|
| 22 |
+
# PACKAGE_C_INCLUDES += -I $(LOCAL_PATH)/../../../../../include/QNN
|
| 23 |
+
# Include paths
|
| 24 |
+
PACKAGE_C_INCLUDES += -I $(LOCAL_PATH)/../include
|
| 25 |
+
PACKAGE_C_INCLUDES += -I $(LOCAL_PATH)/../../../../include/Genie
|
| 26 |
+
PACKAGE_C_INCLUDES += -I $(LOCAL_PATH)/../src/qualla/include
|
| 27 |
+
PACKAGE_C_INCLUDES += -I $(LOCAL_PATH)/../../../../include/QNN
|
| 28 |
+
PACKAGE_C_INCLUDES += -I $(LOCAL_PATH)/../../../../include/QNN/HTP
|
| 29 |
+
PACKAGE_C_INCLUDES += -I $(LOCAL_PATH)/../src/qualla/tokenizers
|
| 30 |
+
PACKAGE_C_INCLUDES += -I $(LOCAL_PATH)/../src/qualla/engines/qnn-api
|
| 31 |
+
PACKAGE_C_INCLUDES += -I $(LOCAL_PATH)/../src/qualla/engines/qnn-cpu
|
| 32 |
+
PACKAGE_C_INCLUDES += -I $(LOCAL_PATH)/../src/qualla/engines/qnn-htp
|
| 33 |
+
|
| 34 |
+
#========================== Define T2T Lib variables =============================================
|
| 35 |
+
include $(CLEAR_VARS)
|
| 36 |
+
LOCAL_MODULE := tokenizers_capi
|
| 37 |
+
LOCAL_SRC_FILES := ../src/qualla/tokenizers/rust/target/aarch64-linux-android/release/libtokenizers_capi.a
|
| 38 |
+
include $(PREBUILT_STATIC_LIBRARY)
|
| 39 |
+
|
| 40 |
+
include $(CLEAR_VARS)
|
| 41 |
+
LOCAL_C_INCLUDES := $(PACKAGE_C_INCLUDES)
|
| 42 |
+
MY_SRC_FILES := $(wildcard $(LOCAL_PATH)/../src/*.cpp)
|
| 43 |
+
MY_SRC_FILES += $(wildcard $(LOCAL_PATH)/../src/qualla/*.cpp)
|
| 44 |
+
MY_SRC_FILES += $(wildcard $(LOCAL_PATH)/../src/qualla/dialogs/*.cpp)
|
| 45 |
+
MY_SRC_FILES += $(wildcard $(LOCAL_PATH)/../src/qualla/engines/*.cpp)
|
| 46 |
+
MY_SRC_FILES += $(wildcard $(LOCAL_PATH)/../src/qualla/engines/qnn-api/*.cpp)
|
| 47 |
+
MY_SRC_FILES += $(wildcard $(LOCAL_PATH)/../src/qualla/engines/qnn-cpu/*.cpp)
|
| 48 |
+
MY_SRC_FILES += $(wildcard $(LOCAL_PATH)/../src/qualla/engines/qnn-htp/*.cpp)
|
| 49 |
+
MY_SRC_FILES += $(wildcard $(LOCAL_PATH)/../src/qualla/utils/*.cpp)
|
| 50 |
+
MY_SRC_FILES += $(wildcard $(LOCAL_PATH)/../src/qualla/loggers/*.cpp)
|
| 51 |
+
MY_SRC_FILES += $(wildcard $(LOCAL_PATH)/../src/qualla/samplers/*.cpp)
|
| 52 |
+
|
| 53 |
+
LOCAL_MODULE := libGenie
|
| 54 |
+
LOCAL_SRC_FILES := $(subst make/,,$(MY_SRC_FILES))
|
| 55 |
+
LOCAL_STATIC_LIBRARIES := tokenizers_capi
|
| 56 |
+
include $(BUILD_SHARED_LIBRARY)
|
Genie/Genie/make/Application.mk
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#=============================================================================
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
|
| 4 |
+
# All Rights Reserved.
|
| 5 |
+
# Confidential and Proprietary - Qualcomm Technologies, Inc.
|
| 6 |
+
#
|
| 7 |
+
#=============================================================================
|
| 8 |
+
|
| 9 |
+
APP_ABI := arm64-v8a
|
| 10 |
+
APP_STL := c++_shared
|
| 11 |
+
APP_PLATFORM := android-21
|
| 12 |
+
APP_MODULES := Genie
|
| 13 |
+
APP_CPPFLAGS += -std=c++2a -O3 -Wall -frtti -fexceptions -fvisibility=hidden -DGENIE_API="__attribute__((visibility(\"default\")))" -DSPILLFILL -DQUALLA_ENGINE_QNN_HTP=TRUE -DQUALLA_ENGINE_QNN_CPU=TRUE -DQUALLA_APPS=OFF -DFMT_HEADER_ONLY -DGENIE_SAMPLE -DQUALLA_INTERNAL_QNN_SDK -DGENIE_SSD_FEATURE -DGENIE_SPD_FEATURE -DGENIE_LADE_FEATURE -DGENIE_MULTISTREAM_FEATURE -DGENIE_LORA_FEATURE -DGENIE_E2T_FEATURE
|
| 14 |
+
APP_LDFLAGS += -lc -lm -ldl -Wl,--version-script=GenieSymbols.default -Wl,--strip-all
|
Genie/Genie/make/Makefile.linux-x86_64
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#=============================================================================
|
| 2 |
+
#
|
| 3 |
+
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
|
| 4 |
+
# All Rights Reserved.
|
| 5 |
+
# Confidential and Proprietary - Qualcomm Technologies, Inc.
|
| 6 |
+
#
|
| 7 |
+
#=============================================================================
|
| 8 |
+
|
| 9 |
+
# define relevant directories
|
| 10 |
+
SRC_DIR := src/qualla
|
| 11 |
+
#
|
| 12 |
+
SRC_DIR_GENIE_TOKENIZERS := src/qualla/tokenizers
|
| 13 |
+
#
|
| 14 |
+
SRC_DIR_SAMPLE_DIALOGS := src/qualla/dialogs
|
| 15 |
+
|
| 16 |
+
# All engines
|
| 17 |
+
SRC_DIR_GENIE_ENGINES := src/qualla/engines
|
| 18 |
+
SRC_DIR_GENIE_QNN_API := src/qualla/engines/qnn-api
|
| 19 |
+
SRC_DIR_GENIE_ENGINES_CPU := src/qualla/engines/qnn-cpu
|
| 20 |
+
SRC_DIR_GENIE_UTILS := src/qualla/utils
|
| 21 |
+
#
|
| 22 |
+
SRC_DIR_GENIE_LOGGERS := src/qualla/loggers
|
| 23 |
+
|
| 24 |
+
#
|
| 25 |
+
SRC_DIR_GENIE_SAMPLERS := src/qualla/samplers
|
| 26 |
+
|
| 27 |
+
#
|
| 28 |
+
SRC_DIR_GENIE := src
|
| 29 |
+
|
| 30 |
+
# Includes
|
| 31 |
+
GENIE_ENGINES_CPU_INCLUDE := src/qualla/engines/qnn-cpu
|
| 32 |
+
GENIE_ENGINES_API_INCLUDE := src/qualla/engines/qnn-api
|
| 33 |
+
GENIE_ENGINES_HTP_INCLUDE := src/qualla/engines/qnn-htp
|
| 34 |
+
GENIE_TOKENIZER_INCLUDE := src/qualla/tokenizers
|
| 35 |
+
|
| 36 |
+
GENIE_INCLUDE := include
|
| 37 |
+
GENIE_C_API_HEADERS_INCLUDE := ../../../include/Genie
|
| 38 |
+
QUALLA_INCLUDE := src/qualla/include
|
| 39 |
+
QNN_API_INCLUDE := ../../../include/QNN/
|
| 40 |
+
QNN_API_HTP_INCLUDE := $(QNN_API_INCLUDE)/HTP
|
| 41 |
+
|
| 42 |
+
AR := /usr/bin/ar
|
| 43 |
+
ARFLAGS := rcs
|
| 44 |
+
# Checking if clang++ is present. If not switch to clang++
|
| 45 |
+
ifeq ($(shell $(CXX) -v 2>&1 | grep -c "clang version"), 0)
|
| 46 |
+
CXX := clang++
|
| 47 |
+
endif
|
| 48 |
+
|
| 49 |
+
QNN_TARGET ?= x86_64-linux-clang
|
| 50 |
+
export TARGET_DIR := ./lib/$(QNN_TARGET)
|
| 51 |
+
|
| 52 |
+
libGenie := $(TARGET_DIR)/libGenie.so
|
| 53 |
+
libtokenizers := src/qualla/tokenizers/rust/target/release/libtokenizers_capi.a
|
| 54 |
+
|
| 55 |
+
# define target architecture if not previously defined, default is x86
|
| 56 |
+
ifndef TARGET_AARCH_VARS
|
| 57 |
+
TARGET_AARCH_VARS:= -march=x86-64
|
| 58 |
+
endif
|
| 59 |
+
|
| 60 |
+
.PHONY: linux_x86_64
|
| 61 |
+
.DEFAULT: linux_x86_64
|
| 62 |
+
GENIE_all: $(libGenie)
|
| 63 |
+
|
| 64 |
+
# Include paths
|
| 65 |
+
INCLUDES += -I$(GENIE_INCLUDE) -I$(QUALLA_INCLUDE) -I$(SRC_DIR_GENIE_TOKENIZERS) -I$(QNN_API_INCLUDE) -I$(GENIE_ENGINES_CPU_INCLUDE) -I$(QNN_API_HTP_INCLUDE) -I$(GENIE_ENGINES_API_INCLUDE) -I$(GENIE_TOKENIZER_INCLUDE) -I$(GENIE_C_API_HEADERS_INCLUDE)
|
| 66 |
+
|
| 67 |
+
# set compiler flags
|
| 68 |
+
COMMON_CXXFLAGS = -std=c++2a -frtti -fPIC -Wall -pg -pthread -nostdinc++ -stdlib=libc++ -idirafter /usr/lib/llvm-14/include/c++/v1 -nostdinc -idirafter /usr/lib/llvm-14/lib/clang/14.0.0/include/ -idirafter /usr/include $(INCLUDES)
|
| 69 |
+
COMMON_LDFLAGS = -shared -s -fPIC -pthread -L/usr/lib/x86_64-linux-gnu -L./src/qualla/tokenizers/rust/target/release
|
| 70 |
+
|
| 71 |
+
COMMON_CFLAGS = -nostdinc -idirafter /usr/lib/llvm-14/lib/clang/14.0.0/include/ -idirafter /usr/include
|
| 72 |
+
|
| 73 |
+
ifdef QNN_DEBUG_ENABLE
|
| 74 |
+
CXXFLAGS += $(COMMON_CXXFLAGS) -march=x86-64 -O0 -g -DQNN_API="" -DSPILLFILL -DQUALLA_ENGINE_QNN_CPU=TRUE -DQUALLA_APPS=OFF -DFMT_HEADER_ONLY -DGENIE_SAMPLE -DQUALLA_INTERNAL_QNN_SDK -DGENIE_SSD_FEATURE -DGENIE_SPD_FEATURE -DGENIE_LADE_FEATURE -DGENIE_MULTISTREAM_FEATURE -DGENIE_LORA_FEATURE -DGENIE_E2T_FEATURE
|
| 75 |
+
CFLAGS += $(COMMON_CFLAGS)
|
| 76 |
+
LDFLAGS += $(COMMON_LDFLAGS)
|
| 77 |
+
else
|
| 78 |
+
CXXFLAGS += $(COMMON_CXXFLAGS) -march=x86-64 -O3 -Wno-write-strings -fvisibility=hidden -DGENIE_API="__attribute__((visibility(\"default\")))" -DSPILLFILL -DQUALLA_ENGINE_QNN_CPU=TRUE -DQUALLA_APPS=OFF -DFMT_HEADER_ONLY -DGENIE_SAMPLE -DQUALLA_INTERNAL_QNN_SDK -DGENIE_SSD_FEATURE -DGENIE_SPD_FEATURE -DGENIE_LADE_FEATURE -DGENIE_MULTISTREAM_FEATURE -DGENIE_LORA_FEATURE -DGENIE_E2T_FEATURE
|
| 79 |
+
CFLAGS += $(COMMON_CFLAGS)
|
| 80 |
+
LDFLAGS += $(COMMON_LDFLAGS) -fvisibility=hidden -flto
|
| 81 |
+
endif
|
| 82 |
+
|
| 83 |
+
# define library sources
|
| 84 |
+
SOURCES_GENIE_CPP := $(wildcard $(SRC_DIR_GENIE)/*.cpp)
|
| 85 |
+
SOURCES := $(wildcard $(SRC_DIR)/*.cpp)
|
| 86 |
+
SOURCES_GENIE_TOKENIZERS := $(wildcard $(SRC_DIR_GENIE_TOKENIZERS)/*.cpp)
|
| 87 |
+
SOURCES_GENIE_QNN_API_CPP := $(wildcard $(SRC_DIR_GENIE_QNN_API)/*.cpp)
|
| 88 |
+
|
| 89 |
+
SOURCES_GENIE_ENGINES_CPP := $(filter-out $(SRC_DIR_GENIE_ENGINES)/qnn-htp.cpp, $(wildcard $(SRC_DIR_GENIE_ENGINES)/*.cpp))
|
| 90 |
+
SOURCES_GENIE_DIALOGS_CPP := $(wildcard $(SRC_DIR_SAMPLE_DIALOGS)/*.cpp)
|
| 91 |
+
SOURCES_GENIE_ENGINES_CPU_CPP := $(wildcard $(SRC_DIR_GENIE_ENGINES_CPU)/*.cpp)
|
| 92 |
+
SOURCES_GENIE_UTILS_CPP := $(wildcard $(SRC_DIR_GENIE_UTILS)/*.cpp)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
SOURCES_GENIE_LOGGERS_CPP := $(wildcard $(SRC_DIR_GENIE_LOGGERS)/*.cpp)
|
| 96 |
+
SOURCES_GENIE_SAMPLERS_CPP := $(wildcard $(SRC_DIR_GENIE_SAMPLERS)/*.cpp)
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
# define object directory
|
| 100 |
+
OBJ_ROOT := obj
|
| 101 |
+
OBJ_DIR_QUALLA := obj/$(QNN_TARGET)/qualla
|
| 102 |
+
OBJ_DIR_GENIE := obj/$(QNN_TARGET)/genie
|
| 103 |
+
OBJ_DIR_GENIE_TOKENIZERS := $(OBJ_DIR_QUALLA)/tokenizers
|
| 104 |
+
OBJ_DIR_GENIE_QNN_API := $(OBJ_DIR_QUALLA)/qnn-api
|
| 105 |
+
|
| 106 |
+
OBJ_DIR_GENIE_DIALOGS := $(OBJ_DIR_QUALLA)/dialogs
|
| 107 |
+
OBJ_DIR_GENIE_ENGINES := $(OBJ_DIR_QUALLA)/engines
|
| 108 |
+
OBJ_DIR_GENIE_UTILS := $(OBJ_DIR_QUALLA)/utils
|
| 109 |
+
OBJ_DIR_GENIE_ENGINES_CPU := $(OBJ_DIR_QUALLA)/engines/qnn-cpu
|
| 110 |
+
$(shell mkdir -p $(OBJ_DIR_GENIE_ENGINES_CPU))
|
| 111 |
+
|
| 112 |
+
OBJ_DIR_GENIE_LOGGERS := obj/$(QNN_TARGET)/qualla/loggers
|
| 113 |
+
OBJ_DIR_GENIE_SAMPLERS := obj/$(QNN_TARGET)/qualla/samplers
|
| 114 |
+
|
| 115 |
+
$(shell mkdir -p $(OBJ_DIR_GENIE))
|
| 116 |
+
$(shell mkdir -p $(OBJ_DIR_GENIE_LOGGERS))
|
| 117 |
+
$(shell mkdir -p $(OBJ_DIR_GENIE_SAMPLERS))
|
| 118 |
+
|
| 119 |
+
# setup object files in object directory
|
| 120 |
+
OBJECTS_GENIE := $(patsubst %.cpp,$(OBJ_DIR_GENIE)/%.o,$(foreach x,$(SOURCES_GENIE_CPP),$(notdir $(x))))
|
| 121 |
+
OBJECTS_QUALLA := $(patsubst %.cpp,$(OBJ_DIR_QUALLA)/%.o,$(foreach x,$(SOURCES),$(notdir $(x))))
|
| 122 |
+
OBJECTS_GENIE_TOKENIZERS := $(patsubst %.cpp,$(OBJ_DIR_GENIE_TOKENIZERS)/%.o,$(foreach x,$(SOURCES_GENIE_TOKENIZERS),$(notdir $(x))))
|
| 123 |
+
OBJECTS_GENIE_QNN_API := $(patsubst %.cpp,$(OBJ_DIR_GENIE_QNN_API)/%.o,$(foreach x,$(SOURCES_GENIE_QNN_API_CPP),$(notdir $(x))))
|
| 124 |
+
OBJECTS_GENIE_ENGINES := $(patsubst %.cpp,$(OBJ_DIR_GENIE_ENGINES)/%.o,$(foreach x,$(SOURCES_GENIE_ENGINES_CPP),$(notdir $(x))))
|
| 125 |
+
OBJECTS_GENIE_DIALOGS := $(patsubst %.cpp,$(OBJ_DIR_GENIE_DIALOGS)/%.o,$(foreach x,$(SOURCES_GENIE_DIALOGS_CPP),$(notdir $(x))))
|
| 126 |
+
OBJECTS_GENIE_UTILS := $(patsubst %.cpp,$(OBJ_DIR_GENIE_UTILS)/%.o,$(foreach x,$(SOURCES_GENIE_UTILS_CPP),$(notdir $(x))))
|
| 127 |
+
OBJECTS_GENIE_ENGINES_CPU := $(patsubst %.cpp,$(OBJ_DIR_GENIE_ENGINES_CPU)/%.o,$(foreach x,$(SOURCES_GENIE_ENGINES_CPU_CPP),$(notdir $(x))))
|
| 128 |
+
|
| 129 |
+
OBJECTS_GENIE_LOGGERS := $(patsubst %.cpp,$(OBJ_DIR_GENIE_LOGGERS)/%.o,$(foreach x,$(SOURCES_GENIE_LOGGERS_CPP),$(notdir $(x))))
|
| 130 |
+
OBJECTS_GENIE_SAMPLERS := $(patsubst %.cpp,$(OBJ_DIR_GENIE_SAMPLERS)/%.o,$(foreach x,$(SOURCES_GENIE_SAMPLERS_CPP),$(notdir $(x))))
|
| 131 |
+
|
| 132 |
+
LIBS=-ldl
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
# Rule to make shared lib
|
| 136 |
+
.PHONY: libGenie
|
| 137 |
+
libGenie: $(libGenie)
|
| 138 |
+
|
| 139 |
+
# Implicit rule to compile and link object files
|
| 140 |
+
$(OBJ_DIR_GENIE)/%.o: $(SRC_DIR_GENIE)/%.cpp
|
| 141 |
+
$(CXX) $(CXXFLAGS) -c $^ -o $@
|
| 142 |
+
|
| 143 |
+
$(OBJ_DIR_QUALLA)/%.o: $(SRC_DIR)/%.cpp
|
| 144 |
+
$(CXX) $(CXXFLAGS) -c $^ -o $@
|
| 145 |
+
|
| 146 |
+
$(OBJ_DIR_GENIE_TOKENIZERS)/%.o: $(SRC_DIR_GENIE_TOKENIZERS)/%.cpp
|
| 147 |
+
$(CXX) $(CXXFLAGS) -c $^ -o $@
|
| 148 |
+
|
| 149 |
+
$(OBJ_DIR_GENIE_QNN_API)/%.o: $(SRC_DIR_GENIE_QNN_API)/%.cpp
|
| 150 |
+
$(CXX) $(CXXFLAGS) -c $^ -o $@
|
| 151 |
+
|
| 152 |
+
$(OBJ_DIR_GENIE_ENGINES)/%.o: $(SRC_DIR_GENIE_ENGINES)/%.cpp $(CXX) $(CXXFLAGS) -c $^ -o $@
|
| 153 |
+
|
| 154 |
+
$(OBJ_DIR_GENIE_DIALOGS)/%.o: $(SRC_DIR_SAMPLE_DIALOGS)/%.cpp $(CXX) $(CXXFLAGS) -c $^ -o $@
|
| 155 |
+
|
| 156 |
+
$(OBJ_DIR_GENIE_UTILS)/%.o: $(SRC_DIR_GENIE_UTILS)/%.cpp $(CXX) $(CXXFLAGS) -c $^ -o $@
|
| 157 |
+
|
| 158 |
+
$(OBJ_DIR_GENIE_ENGINES_CPU)/%.o: $(SRC_DIR_GENIE_ENGINES_CPU)/%.cpp $(CXX) $(CXXFLAGS) -c $^ -o $@
|
| 159 |
+
|
| 160 |
+
$(OBJ_DIR_GENIE_LOGGERS)/%.o: $(SRC_DIR_GENIE_LOGGERS)/%.cpp $(CXX) $(CXXFLAGS) -c $^ -o $@
|
| 161 |
+
|
| 162 |
+
$(OBJ_DIR_GENIE_SAMPLERS)/%.o: $(SRC_DIR_GENIE_SAMPLERS)/%.cpp $(CXX) $(CXXFLAGS) -c $^ -o $@
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
# set up resources
|
| 166 |
+
directories := $(TARGET_DIR) $(OBJ_DIR_GENIE) $(OBJ_DIR_GENIE_QNN_API) $(OBJ_DIR_QUALLA) $(OBJ_DIR_GENIE_TOKENIZERS) $(OBJ_DIR_GENIE_ENGINES) $(OBJ_DIR_GENIE_DIALOGS) $(OBJ_DIR_GENIE_UTILS) $(OBJ_DIR_GENIE_ENGINES_CPU) $(OBJ_DIR_GENIE_LOGGERS) $(OBJ_DIR_GENIE_SAMPLERS)
|
| 167 |
+
|
| 168 |
+
# Compile
|
| 169 |
+
$(libGenie): $(OBJECTS_GENIE) $(OBJECTS_QUALLA) $(OBJECTS_GENIE_QNN_API) $(OBJECTS_GENIE_TOKENIZERS) $(OBJECTS_GENIE_ENGINES) $(OBJECTS_GENIE_DIALOGS) $(OBJECTS_GENIE_UTILS) $(OBJECTS_GENIE_ENGINES_CPU) $(OBJECTS_GENIE_LOGGERS) $(OBJECTS_GENIE_SAMPLERS) | $(directories)
|
| 170 |
+
$(CXX) $(CXXFLAGS) -shared -o $@ $^ $(LIBS) $(libtokenizers)
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
# rule for object directory resource
|
| 174 |
+
$(OBJECTS_GENIE): | $(OBJ_DIR_GENIE)
|
| 175 |
+
$(OBJECTS_QUALLA): | $(OBJ_DIR_QUALLA)
|
| 176 |
+
$(OBJECTS_GENIE_TOKENIZERS): | $(OBJ_DIR_GENIE_TOKENIZERS)
|
| 177 |
+
$(OBJECTS_GENIE_QNN_API): | $(OBJ_DIR_GENIE_QNN_API)
|
| 178 |
+
$(OBJECTS_GENIE_ENGINES): | $(OBJ_DIR_GENIE_ENGINES)
|
| 179 |
+
$(OBJECTS_GENIE_DIALOGS): | $(OBJ_DIR_GENIE_DIALOGS)
|
| 180 |
+
$(OBJECTS_GENIE_UTILS): | $(OBJ_DIR_GENIE_UTILS)
|
| 181 |
+
$(OBJECTS_GENIE_ENGINES_CPU): | $(OBJ_DIR_GENIE_ENGINES_CPU)
|
| 182 |
+
$(OBJECTS_GENIE_LOGGERS): | $(OBJ_DIR_GENIE_LOGGERS)
|
| 183 |
+
$(OBJECTS_GENIE_SAMPLERS): | $(OBJ_DIR_GENIE_SAMPLERS)
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
# rule to create directories
|
| 187 |
+
$(directories):
|
| 188 |
+
mkdir -p $@
|
| 189 |
+
|
| 190 |
+
.PHONY: clean
|
| 191 |
+
clean:
|
| 192 |
+
rm -rf $(OBJ_ROOT) $(TARGET_DIR)
|
Genie/Genie/src/Dialog.cpp
ADDED
|
@@ -0,0 +1,1804 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//==============================================================================
|
| 2 |
+
//
|
| 3 |
+
// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
|
| 4 |
+
// All Rights Reserved.
|
| 5 |
+
// Confidential and Proprietary - Qualcomm Technologies, Inc.
|
| 6 |
+
//
|
| 7 |
+
//==============================================================================
|
| 8 |
+
|
| 9 |
+
#include <exception>
|
| 10 |
+
#include <set>
|
| 11 |
+
#include <sstream>
|
| 12 |
+
|
| 13 |
+
#include "Dialog.hpp"
|
| 14 |
+
#include "Exception.hpp"
|
| 15 |
+
#include "Macro.hpp"
|
| 16 |
+
#include "qualla/detail/json.hpp"
|
| 17 |
+
#include "qualla/env.hpp"
|
| 18 |
+
|
| 19 |
+
using namespace genie;
|
| 20 |
+
|
| 21 |
+
#ifdef _WIN32
inline std::string libPrefix = "";
inline std::string libSuffix = ".dll";
#else
inline std::string libPrefix = "lib";
inline std::string libSuffix = ".so";
#endif

// Builds the platform-specific shared-library file name for `baseName`
// (e.g. "foo" -> "libfoo.so" on POSIX, "foo.dll" on Windows).
// Takes the base name by const reference to avoid an unnecessary copy.
inline std::string getLibName(const std::string& baseName) {
  return libPrefix + baseName + libSuffix;
}
|
| 30 |
+
|
| 31 |
+
//=============================================================================
|
| 32 |
+
// Context::Config functions
|
| 33 |
+
//=============================================================================
|
| 34 |
+
|
| 35 |
+
// Validates the "context" section of a Genie dialog config.
// Mandatory keys: version (must be 1), bos-token, eos-token, size, n-vocab.
// Optional keys: eot-token, pad-token. Any other key is rejected.
// Throws Exception with GENIE_STATUS_ERROR_JSON_SCHEMA for structural problems
// and GENIE_STATUS_ERROR_JSON_VALUE for unsupported values.
static void validateContextConfig(const qualla::json& config) {
  if (!config.is_object()) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "context config is not an object");
  }

  std::set<std::string> mandatoryFields{"version", "bos-token", "eos-token", "size", "n-vocab"};
  for (const auto& field : mandatoryFields) {
    if (!config.contains(field)) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing context field: " + field);
    }
  }

  // component is used in the "ENFORCE" macros
  std::string component = "context";

  // NOTE: the JSON_ENFORCE_* macros (Macro.hpp) reference the loop variable
  // `item` and the local `component`; presumably they throw on a type
  // mismatch -- see Macro.hpp for the exact expansion.
  for (auto& item : config.items()) {
    if (item.key() == "version") {
      JSON_ENFORCE_NUMERIC();
      // Only schema version 1 is supported.
      if (item.value().get<int>() != 1) {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Invalid context config: unsupported version: " + item.value().dump());
      }
    } else if (item.key() == "bos-token") {
      JSON_ENFORCE_NUMERIC();
    } else if (item.key() == "eos-token") {
      // A single token id or an array of ids is accepted here.
      JSON_ENFORCE_ARRAY_OR_NUMERIC();
    } else if (item.key() == "eot-token") {  // optional
      JSON_ENFORCE_NUMERIC();
    } else if (item.key() == "size") {
      JSON_ENFORCE_NUMERIC();
    } else if (item.key() == "n-vocab") {
      JSON_ENFORCE_NUMERIC();
    } else if (item.key() == "pad-token") {  // optional
      JSON_ENFORCE_NUMERIC();
    } else {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown context config key: " + item.key());
    }
  }
}
|
| 74 |
+
|
| 75 |
+
// Copies the optional "context" settings from the Genie dialog config into
// the qualla config. Only keys present in the input are forwarded; values are
// copied verbatim with their JSON types unchanged.
// @param genieConfig  validated Genie config (read-only)
// @param quallaConfig destination config updated in place
static void translateContextConfig(const qualla::json& genieConfig, qualla::json& quallaConfig) {
  if (genieConfig["dialog"].contains("context")) {
    // Hoist the nested lookup once instead of re-walking the JSON tree for
    // every key (the original repeated genieConfig["dialog"]["context"]
    // twelve times).
    const qualla::json& context = genieConfig["dialog"]["context"];
    // Every forwarded key keeps the same name on the qualla side.
    static const char* const kForwardedKeys[] = {
        "bos-token", "eos-token", "eot-token", "size", "n-vocab", "pad-token"};
    for (const char* key : kForwardedKeys) {
      if (context.contains(key)) {
        quallaConfig["context"][key] = context[key];
      }
    }
  }
}
|
| 97 |
+
|
| 98 |
+
//=============================================================================
|
| 99 |
+
// Sampler::Config functions
|
| 100 |
+
//=============================================================================
|
| 101 |
+
|
| 102 |
+
// Validates the "sampler" section of a Genie dialog config.
// Mandatory: version (must be 1). Optional: seed, temp, top-k, top-p
// (numeric) and greedy (boolean). Unknown keys are rejected.
static void validateSamplerConfig(const qualla::json& config) {
  if (!config.is_object()) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "sampler config is not an object");
  }

  std::set<std::string> mandatoryFields{"version"};
  for (const auto& field : mandatoryFields) {
    if (!config.contains(field)) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing sampler field: " + field);
    }
  }

  // component is used in the "ENFORCE" macros
  std::string component = "sampler";

  // JSON_ENFORCE_* macros (Macro.hpp) reference the locals `item` and
  // `component`.
  for (auto& item : config.items()) {
    if (item.key() == "version") {
      JSON_ENFORCE_NUMERIC();
      if (item.value().get<int>() != 1) {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Invalid sampler config: unsupported version: " + item.value().dump());
      }
    } else if (item.key() == "seed") {
      JSON_ENFORCE_NUMERIC();
    } else if (item.key() == "temp") {
      JSON_ENFORCE_NUMERIC();
    } else if (item.key() == "top-k") {
      JSON_ENFORCE_NUMERIC();
    } else if (item.key() == "top-p") {
      JSON_ENFORCE_NUMERIC();
    } else if (item.key() == "greedy") {
      JSON_ENFORCE_BOOLEAN();
    } else {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown sampler config key: " + item.key());
    }
  }
}
|
| 139 |
+
|
| 140 |
+
// Translates the optional "sampler" section of the Genie dialog config into
// qualla's sampler config. Sets type to "basic", forwards the optional
// tuning knobs, and selects the sampler role ("primary", or "target" for
// speculative-decoding dialogs when GENIE_SPD_FEATURE is enabled).
// @param genieConfig  validated Genie config (read-only)
// @param quallaConfig destination config updated in place
static void translateSamplerConfig(const qualla::json& genieConfig, qualla::json& quallaConfig) {
  if (genieConfig["dialog"].contains("sampler")) {
    const qualla::json& sampler = genieConfig["dialog"]["sampler"];
    quallaConfig["sampler"]["type"] = "basic";

    // Forward optional knobs verbatim. The original code copied "seed"
    // twice (before and after the role selection); once is sufficient and
    // the final state is identical.
    static const char* const kForwardedKeys[] = {"seed", "temp", "top-k", "top-p", "greedy"};
    for (const char* key : kForwardedKeys) {
      if (sampler.contains(key)) {
        quallaConfig["sampler"][key] = sampler[key];
      }
    }

    quallaConfig["sampler"]["role"] = "primary";
#if defined(GENIE_SPD_FEATURE)
    // Speculative decoding: this sampler samples for the target model.
    if (genieConfig["dialog"]["type"] == "spd") {
      quallaConfig["sampler"]["role"] = "target";
    }
#endif
  }
}
|
| 172 |
+
|
| 173 |
+
//=============================================================================
|
| 174 |
+
// Tokenizer::Config functions
|
| 175 |
+
//=============================================================================
|
| 176 |
+
|
| 177 |
+
// Validates the "tokenizer" section of a Genie dialog config.
// Mandatory: version (must be 1) and path (string). The existence of the
// tokenizer file itself is deliberately left for qualla to check.
static void validateTokenizerConfig(const qualla::json& config) {
  if (!config.is_object()) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "tokenizer config is not an object");
  }

  std::set<std::string> mandatoryFields{"version", "path"};
  for (const auto& field : mandatoryFields) {
    if (!config.contains(field)) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing tokenizer field: " + field);
    }
  }

  // component is used in the "ENFORCE" macros
  std::string component = "tokenizer";

  for (auto& item : config.items()) {
    if (item.key() == "version") {
      JSON_ENFORCE_NUMERIC();
      if (item.value().get<int>() != 1) {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Invalid tokenizer config: unsupported version: " + item.value().dump());
      }
    } else if (item.key() == "path") {
      JSON_ENFORCE_STRING();
      // Note: the existence of this file is checked by qualla
    } else {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                      "Unknown tokenizer config key: " + item.key());
    }
  }
}
|
| 208 |
+
|
| 209 |
+
// Maps the Genie tokenizer settings onto qualla: qualla consumes the
// tokenizer model path directly as its top-level "tokenizer" entry.
// @param genieConfig  validated Genie config (read-only)
// @param quallaConfig destination config updated in place
static void translateTokenizerConfig(const qualla::json& genieConfig, qualla::json& quallaConfig) {
  const qualla::json& tokenizerPath = genieConfig["dialog"]["tokenizer"]["path"];
  quallaConfig["tokenizer"] = tokenizerPath;
}
|
| 212 |
+
|
| 213 |
+
//=============================================================================
|
| 214 |
+
// Embedding::Config functions
|
| 215 |
+
//=============================================================================
|
| 216 |
+
|
| 217 |
+
// Validates the "embedding" section of a Genie dialog config.
// Mandatory: version (must be 1) and size. Optional: datatype, which must be
// one of "float32" or "native". Unknown keys are rejected.
static void validateEmbeddingConfig(const qualla::json& config) {
  if (!config.is_object()) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "embedding config is not an object");
  }

  std::set<std::string> mandatoryFields{"version", "size"};
  for (const auto& field : mandatoryFields) {
    if (!config.contains(field)) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing embedding field: " + field);
    }
  }

  // component is used in the "ENFORCE" macros
  std::string component = "embedding";

  for (auto& item : config.items()) {
    if (item.key() == "version") {
      JSON_ENFORCE_NUMERIC();
      if (item.value().get<int>() != 1) {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Invalid embedding config: unsupported version: " + item.value().dump());
      }
    } else if (item.key() == "size") {
      JSON_ENFORCE_NUMERIC();
    } else if (item.key() == "datatype") {
      JSON_ENFORCE_STRING();
      const std::set<std::string> supportedTypes = {"float32", "native"};
      // Use the set's own member lookup (O(log n)) rather than the original
      // linear std::find scan over the set's iterators.
      if (supportedTypes.find(item.value().get<std::string>()) == supportedTypes.end()) {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Unknown embedding datatype: " + std::string(item.value()));
      }
    } else {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                      "Unknown embedding config key: " + item.key());
    }
  }
}
|
| 255 |
+
|
| 256 |
+
// Translates the optional "embedding" section of the Genie dialog config
// into qualla context settings: "size" becomes "n-embd" (always required by
// the validator) and "datatype" becomes "embedding-datatype" when present.
// @param genieConfig  validated Genie config (read-only)
// @param quallaConfig destination config updated in place
static void translateEmbeddingConfig(const qualla::json& genieConfig, qualla::json& quallaConfig) {
  if (genieConfig["dialog"].contains("embedding")) {
    const qualla::json& embedding = genieConfig["dialog"]["embedding"];
    quallaConfig["context"]["n-embd"] = embedding["size"];
    if (embedding.contains("datatype")) {
      quallaConfig["context"]["embedding-datatype"] = embedding["datatype"];
    }
  }
}
|
| 266 |
+
|
| 267 |
+
// File-scope flags recording whether the optional "pos-id-dim" / "rope-theta"
// keys were encountered while validating the QnnHtp backend config (set in
// validateBackendHtpConfig below).
// NOTE(review): mutable globals make validation stateful across calls --
// presumably these are consumed later during config translation; verify and
// consider passing them explicitly instead.
bool position_dim_set = false;
bool rope_theta_set = false;
|
| 269 |
+
|
| 270 |
+
//=============================================================================
|
| 271 |
+
// Backend::Config functions
|
| 272 |
+
//=============================================================================
|
| 273 |
+
|
| 274 |
+
// Validates the "QnnHtp" backend sub-config.
// Mandatory: version (must be 1), spill-fill-bufsize, mmap-budget, use-mmap,
// cpu-mask, poll. Optional: pos-id-dim, kv-dim, kv-update-method,
// allow-async-init, rope-theta. On Windows, use-mmap must be false.
// Side effect: sets the file-scope flags position_dim_set / rope_theta_set
// when the corresponding optional keys are present.
static void validateBackendHtpConfig(const qualla::json& config) {
  if (!config.is_object()) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "QnnHtp config is not an object");
  }

  std::set<std::string> mandatoryFields{
      "version", "spill-fill-bufsize", "mmap-budget", "use-mmap", "cpu-mask", "poll"};
  for (const auto& field : mandatoryFields) {
    if (!config.contains(field)) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing QnnHtp field: " + field);
    }
  }

  // component is used in the "ENFORCE" macros
  std::string component = "QnnHtp";

  for (auto& item : config.items()) {
    if (item.key() == "version") {
      JSON_ENFORCE_NUMERIC();
      if (item.value().get<int>() != 1) {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Invalid QnnHtp config: unsupported version: " + item.value().dump());
      }
    } else if (item.key() == "spill-fill-bufsize") {
      JSON_ENFORCE_NUMERIC();
    } else if (item.key() == "mmap-budget") {
      JSON_ENFORCE_NUMERIC();
    } else if (item.key() == "use-mmap") {
      JSON_ENFORCE_BOOLEAN();
#ifdef _WIN32
      // mmap-based model loading is rejected on Windows targets.
      if (item.value() == true) {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Invalid QnnHtp config. use-mmap not supported on target");
      }
#endif
    } else if (item.key() == "pos-id-dim") {
      position_dim_set = true;  // flag read elsewhere -- see file-scope declaration
      JSON_ENFORCE_NUMERIC();
    } else if (item.key() == "cpu-mask") {
      JSON_ENFORCE_STRING();
    } else if (item.key() == "poll") {
      JSON_ENFORCE_BOOLEAN();
    } else if (item.key() == "kv-dim") {
      JSON_ENFORCE_NUMERIC();
    } else if (item.key() == "kv-update-method") {
      JSON_ENFORCE_STRING();
    } else if (item.key() == "allow-async-init") {
      JSON_ENFORCE_BOOLEAN();
    } else if (item.key() == "rope-theta") {
      rope_theta_set = true;  // flag read elsewhere -- see file-scope declaration
      JSON_ENFORCE_NUMERIC();
    } else {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown QnnHtp config key: " + item.key());
    }
  }
}
|
| 330 |
+
|
| 331 |
+
// Validates the "QnnGenAiTransformer" backend sub-config.
// Mandatory: version (must be 1). Optional: use-mmap (must be false on
// Windows), n-logits, n-layer, n-embd, n-heads. Unknown keys are rejected.
static void validateBackendGenaiConfig(const qualla::json& config) {
  if (!config.is_object()) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "QnnGenAiTransformer config is not an object");
  }

  std::set<std::string> mandatoryFields{"version"};
  for (const auto& field : mandatoryFields) {
    if (!config.contains(field)) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                      "Missing QnnGenAiTransformer field: " + field);
    }
  }

  // component is used in the "ENFORCE" macros
  std::string component = "QnnGenAiTransformer";

  for (auto& item : config.items()) {
    if (item.key() == "version") {
      JSON_ENFORCE_NUMERIC();
      if (item.value().get<int>() != 1) {
        throw Exception(
            GENIE_STATUS_ERROR_JSON_VALUE,
            "Invalid QnnGenAiTransformer config: unsupported version: " + item.value().dump());
      }
    } else if (item.key() == "use-mmap") {
      JSON_ENFORCE_BOOLEAN();
#ifdef _WIN32
      // mmap-based model loading is rejected on Windows targets.
      if (item.value() == true) {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Invalid QnnGenAiTransformer config. use-mmap not supported on target");
      }
#endif
    } else if (item.key() == "n-logits") {
      JSON_ENFORCE_NUMERIC();
    } else if (item.key() == "n-layer") {
      JSON_ENFORCE_NUMERIC();
    } else if (item.key() == "n-embd") {
      JSON_ENFORCE_NUMERIC();
    } else if (item.key() == "n-heads") {
      JSON_ENFORCE_NUMERIC();
    } else {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                      "Unknown QnnGenAiTransformer config key: " + item.key());
    }
  }
}
|
| 377 |
+
|
| 378 |
+
// Validates the "backend" section of a Genie dialog config.
// Mandatory: version (must be 1) and type ("QnnHtp" or "QnnGenAiTransformer").
// A type-specific sub-config object with the same name as the type is
// required for that type, and rejected if provided for the other type.
// Delegates sub-config validation to validateBackendHtpConfig /
// validateBackendGenaiConfig.
static void validateBackendConfig(const qualla::json& config) {
  if (!config.is_object()) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "backend config is not an object");
  }

  std::set<std::string> mandatoryFields{"version", "type"};
  for (const auto& field : mandatoryFields) {
    if (!config.contains(field)) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing backend field: " + field);
    }
  }

  // component is used in the "ENFORCE" macros
  std::string component = "backend";

  // Record which backend type was selected and which sub-configs were seen,
  // so the pairing can be cross-checked after the loop.
  std::string type;
  bool htp = false;
  qualla::json htpConfig;
  bool genai = false;
  qualla::json genaiConfig;

  for (auto& item : config.items()) {
    if (item.key() == "version") {
      JSON_ENFORCE_NUMERIC();
      if (item.value().get<int>() != 1) {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Invalid backend config: unsupported version: " + item.value().dump());
      }
    } else if (item.key() == "type") {
      JSON_ENFORCE_STRING();
      type = item.value().get<std::string>();
      if (type == "QnnHtp") {
        htp = true;
      } else if (type == "QnnGenAiTransformer") {
        genai = true;
      } else {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Invalid backend config: unsupported type: " + item.value().dump());
      }
    } else if (item.key() == "extensions") {  // optional backend-extensions path
      JSON_ENFORCE_STRING();
    } else if (item.key() == "QnnHtp") {
      JSON_ENFORCE_OBJECT();
      htpConfig = item.value();
    } else if (item.key() == "QnnGenAiTransformer") {
      JSON_ENFORCE_OBJECT();
      genaiConfig = item.value();
    } else {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown backend config key: " + item.key());
    }
  }

  // type == "QnnHtp" requires a "QnnHtp" sub-config; a "QnnHtp" sub-config
  // with any other type is an error.
  if (htp) {
    if (!htpConfig.is_object()) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing QnnHtp dialog config");
    }
    validateBackendHtpConfig(htpConfig);
  } else {
    if (htpConfig.is_object()) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                      "QnnHtp backend config for incorrect backend type: " + type);
    }
  }

  // Same pairing rule for the QnnGenAiTransformer sub-config.
  if (genai) {
    if (!genaiConfig.is_object()) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing QnnGenAiTransformer dialog config");
    }
    validateBackendGenaiConfig(genaiConfig);
  } else {
    if (genaiConfig.is_object()) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                      "QnnGenAiTransformer backend config for incorrect backend type: " + type);
    }
  }
}
|
| 454 |
+
|
| 455 |
+
//=============================================================================
|
| 456 |
+
// Model::Config functions
|
| 457 |
+
//=============================================================================
|
| 458 |
+
|
| 459 |
+
// Validates one entry of the lora "adapters" array.
// Mandatory: version (must be 1) and name. The LoRA version is inferred from
// the keys present: "bin-sections" implies V2, "path" implies V1. The
// inferred version must agree with `specifiedLoraVersion` (from the parent
// lora config), and an adapter with neither key is invalid.
static void validateLoraAdapterConfig(const qualla::json& config,
                                      LORA_VERSION& specifiedLoraVersion) {
  if (!config.is_object()) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "lora adapter config is not an object");
  }
  const std::set<std::string> mandatoryFields{"version", "name"};
  for (const auto& field : mandatoryFields) {
    if (!config.contains(field)) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing lora adapter field: " + field);
    }
  }

  // component is used in the "ENFORCE" macros
  const std::string component = "lora adapter";
  LORA_VERSION configuredLoraVersion = LORA_VERSION::GENIE_LORA_VERSION_UNDEFINED;
  for (auto& item : config.items()) {
    if (item.key() == "version") {
      JSON_ENFORCE_NUMERIC();
      if (item.value().get<int>() != 1) {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Invalid lora config: unsupported version: " + item.value().dump());
      }
    } else if (item.key() == "name") {
      JSON_ENFORCE_STRING();
    } else if (item.key() == "bin-sections") {
      JSON_ENFORCE_ARRAY();
      configuredLoraVersion = LORA_VERSION::GENIE_LORA_VERSION_V2; // Adapter occurs with V2
      for (auto& elem : item.value()) {
        if (!elem.is_string()) {
          throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                          "bin-sections must be an array of strings");
        }
      }
    } else if (item.key() == "path") {
      configuredLoraVersion = LORA_VERSION::GENIE_LORA_VERSION_V1; // Weights are V1
      JSON_ENFORCE_STRING();
      // Note:all directory validations will done by NSP engine
    } else {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                      "Unknown lora adapter config key: " + item.key());
    }
  }

  // Cross-check the inferred version against the version declared (or
  // defaulted) in the parent lora config.
  if (specifiedLoraVersion == LORA_VERSION::GENIE_LORA_VERSION_V1 &&
      configuredLoraVersion == LORA_VERSION::GENIE_LORA_VERSION_V2) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                    "LoRA Adapters must be used with lora version: 2");
  } else if (specifiedLoraVersion == LORA_VERSION::GENIE_LORA_VERSION_V2 &&
             configuredLoraVersion == LORA_VERSION::GENIE_LORA_VERSION_V1) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                    "LoRA Weights must be used with lora version: 1");
  } else if (configuredLoraVersion == LORA_VERSION::GENIE_LORA_VERSION_UNDEFINED) {
    // Neither "bin-sections" nor "path" was present.
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Invalid lora config.");
  }
}
|
| 514 |
+
|
| 515 |
+
// Validates the "lora" section of a model binary config.
// Mandatory: version (must be 1) and adapters (array, each entry validated by
// validateLoraAdapterConfig). Optional: alpha-tensor-name and lora-version
// (1 or 2; defaults to 2 when absent; any other value is rejected).
static void validateLoraConfig(const qualla::json& config) {
  if (!config.is_object()) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "lora config is not an object");
  }

  const std::set<std::string> mandatoryFields{"version", "adapters"};
  for (const auto& field : mandatoryFields) {
    if (!config.contains(field)) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing lora field: " + field);
    }
  }

  // component is used in the "ENFORCE" macros
  const std::string component = "lora";
  // Resolve the declared LoRA version up front so each adapter can be
  // cross-checked against it.
  LORA_VERSION specifiedLoraVersion = LORA_VERSION::GENIE_LORA_VERSION_V2; // Default is loraV2
  if (config.find("lora-version") != config.end()) {
    switch (static_cast<uint8_t>(config["lora-version"])) {
      case 1:
        specifiedLoraVersion = LORA_VERSION::GENIE_LORA_VERSION_V1;
        break;
      case 2:
        specifiedLoraVersion = LORA_VERSION::GENIE_LORA_VERSION_V2;
        break;
      default:
        specifiedLoraVersion = LORA_VERSION::GENIE_LORA_VERSION_UNDEFINED;
        break;
    }
  }

  for (auto& item : config.items()) {
    if (item.key() == "version") {
      JSON_ENFORCE_NUMERIC();
      if (item.value().get<int>() != 1) {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Invalid lora config: unsupported version: " + item.value().dump());
      }
    } else if (item.key() == "alpha-tensor-name") {
      JSON_ENFORCE_STRING();
    } else if (item.key() == "adapters") {
      JSON_ENFORCE_ARRAY();
      for (auto& elem : item.value()) {
        validateLoraAdapterConfig(elem, specifiedLoraVersion);
      }
    } else if (item.key() == "lora-version") { // Optional
      JSON_ENFORCE_NUMERIC();
    } else {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown lora config key: " + item.key());
    }
  }
  // UNDEFINED can only result from an explicit, out-of-range "lora-version".
  if (specifiedLoraVersion == LORA_VERSION::GENIE_LORA_VERSION_UNDEFINED) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                    "Unsupported lora version: " + to_string(config["lora-version"]));
  }
}
|
| 569 |
+
|
| 570 |
+
// Validates the "binary" model config (pre-compiled context binaries).
// Mandatory: version (must be 1) and ctx-bins (array of strings). Optional:
// lora (validated by validateLoraConfig). Unknown keys are rejected.
static void validateModelBinaryConfig(const qualla::json& config) {
  if (!config.is_object()) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "binary config is not an object");
  }

  std::set<std::string> mandatoryFields{"version", "ctx-bins"};
  for (const auto& field : mandatoryFields) {
    if (!config.contains(field)) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing binary field: " + field);
    }
  }

  // component is used in the "ENFORCE" macros
  std::string component = "binary";

  for (auto& item : config.items()) {
    if (item.key() == "version") {
      JSON_ENFORCE_NUMERIC();
      if (item.value().get<int>() != 1) {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Invalid binary config: unsupported version: " + item.value().dump());
      }
    } else if (item.key() == "ctx-bins") {
      JSON_ENFORCE_ARRAY();
      // Each entry is a path to a context binary; only the type is checked
      // here.
      for (auto& elem : item.value()) {
        if (!elem.is_string()) {
          throw Exception(GENIE_STATUS_ERROR_JSON_VALUE, "ctx-bins must be an array of strings");
        }
      }
    } else if (item.key() == "lora") {
      JSON_ENFORCE_OBJECT();
      validateLoraConfig(item.value());
    } else {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown binary config key: " + item.key());
    }
  }
}
|
| 607 |
+
|
| 608 |
+
// Validates the "library" model config (model shipped as a loadable library).
// Mandatory: version (must be 1) and model-bin (string). Unknown keys are
// rejected.
static void validateModelLibraryConfig(const qualla::json& config) {
  if (!config.is_object()) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "library config is not an object");
  }

  std::set<std::string> mandatoryFields{"version", "model-bin"};
  for (const auto& field : mandatoryFields) {
    if (!config.contains(field)) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing library field: " + field);
    }
  }

  // component is used in the "ENFORCE" macros
  std::string component = "library";

  for (auto& item : config.items()) {
    if (item.key() == "version") {
      JSON_ENFORCE_NUMERIC();
      if (item.value().get<int>() != 1) {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Invalid library config: unsupported version: " + item.value().dump());
      }
    } else if (item.key() == "model-bin") {
      JSON_ENFORCE_STRING();
    } else {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown library config key: " + item.key());
    }
  }
}
|
| 637 |
+
|
| 638 |
+
// Validates the optional "rope-scaling" section of a positional-encoding
// config. A non-object config is silently accepted (the section is optional);
// when it is an object, "rope-type" must be one of llama3/default/longrope,
// the scalar factors must be numeric, and short-/long-factor must be arrays.
// Throws Exception with a GENIE_STATUS_ERROR_JSON_* status on violations.
static void validateRopeScalingConfig(const qualla::json& config) {
  // component is used in the "ENFORCE" macros
  std::string component = "rope-scaling";
  if (config.is_object()) {
    std::string ropeType;
    for (auto& item : config.items()) {
      if (item.key() == "rope-type") {
        JSON_ENFORCE_STRING();
        ropeType = item.value().get<std::string>();
        if (ropeType != "llama3" && ropeType != "default" && ropeType != "longrope") {
          // Fix: add a separator so the offending type is not fused onto the
          // message text ("...not supportedllama2" previously).
          throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Rope type not supported: " + ropeType);
        }
      } else if (item.key() == "factor" || item.key() == "low-freq-factor" ||
                 item.key() == "high-freq-factor" ||
                 item.key() == "original-max-position-embeddings") {
        JSON_ENFORCE_NUMERIC();
      } else if (item.key() == "short-factor" || item.key() == "long-factor") {
        JSON_ENFORCE_ARRAY();
      } else {
        throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                        "Rope scaling parameter not supported " + item.key());
      }
    }
  }
}
|
| 663 |
+
|
| 664 |
+
// Validates the optional "positional-encoding" model section: "type" must be
// rope/absolute/alibi, rope-dim and rope-theta numeric, and any nested
// "rope-scaling" object is validated by validateRopeScalingConfig. Also
// rejects configs that specify positional encoding alongside the legacy
// pos-id-dim / rope-theta backend knobs (tracked by the file-scope
// position_dim_set / rope_theta_set flags — presumably set while validating
// the QnnHtp backend section; confirm against the earlier validators).
static void validatePositionalEncodingConfig(const qualla::json& config) {
  // component is used in the "ENFORCE" macros
  std::string component = "positional-encoding";
  qualla::json ropeScalingConfig;
  if (config.is_object()) {
    for (auto& item : config.items()) {
      if (item.key() == "type") {
        // Fix: enforce the string type before get<std::string>() so a
        // non-string value surfaces as a Genie JSON error instead of a raw
        // json type_error escaping to the caller (matches every sibling
        // validator in this file).
        JSON_ENFORCE_STRING();
        std::string positionEncodingType = item.value().get<std::string>();
        if (positionEncodingType != "rope" && positionEncodingType != "absolute" &&
            positionEncodingType != "alibi") {
          throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "positional-encoding type not supported");
        }
      } else if (item.key() == "rope-dim") {
        JSON_ENFORCE_NUMERIC();
      } else if (item.key() == "rope-theta") {
        JSON_ENFORCE_NUMERIC();
      } else if (item.key() == "rope-scaling") {
        JSON_ENFORCE_OBJECT();
        ropeScalingConfig = item.value();
      } else {
        throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                        "Unknown positional encoding config key: " + item.key());
      }
    }
  }
  if (position_dim_set) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                    "Specify one config from pos-id-dim and positional-encoding");
  }
  if (rope_theta_set) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                    "Specify one config from rope-theta and positional-encoding");
  }
  if (ropeScalingConfig.is_object()) {
    validateRopeScalingConfig(ropeScalingConfig);
  }
}
|
| 701 |
+
|
| 702 |
+
// Validates the "model" section: requires "version" (== 1) and a "type" of
// either "binary" or "library"; the matching sub-config object must be
// present (and is validated), while a sub-config for the non-selected type is
// rejected. An optional "positional-encoding" object is validated when given.
// Throws Exception with a GENIE_STATUS_ERROR_JSON_* status on any violation.
static void validateModelConfig(const qualla::json& config) {
  if (!config.is_object()) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "model config is not an object");
  }

  std::set<std::string> mandatoryFields{"version", "type"};
  for (const auto& field : mandatoryFields) {
    if (!config.contains(field)) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing model field: " + field);
    }
  }

  // component is used in the "ENFORCE" macros
  std::string component = "model";

  std::string type;
  bool binary = false;
  qualla::json binaryConfig;
  bool library = false;
  qualla::json libraryConfig;
  qualla::json positionalEncodingConfig;
  bool positionalEncoding = false;

  for (auto& item : config.items()) {
    const std::string key = item.key();
    if (key == "version") {
      JSON_ENFORCE_NUMERIC();
      if (item.value().get<int>() != 1) {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Invalid model config: unsupported version: " + item.value().dump());
      }
    } else if (key == "type") {
      JSON_ENFORCE_STRING();
      type = item.value().get<std::string>();
      if (type == "binary") {
        binary = true;
      } else if (type == "library") {
        library = true;
      } else {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Invalid model config: unsupported type: " + item.value().dump());
      }
    } else if (key == "binary") {
      JSON_ENFORCE_OBJECT();
      binaryConfig = item.value();
    } else if (key == "library") {
      JSON_ENFORCE_OBJECT();
      libraryConfig = item.value();
    } else if (key == "positional-encoding") {
      JSON_ENFORCE_OBJECT();
      positionalEncodingConfig = item.value();
      positionalEncoding = true;
    } else {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown model config key: " + item.key());
    }
  }

  // A sub-config must be present iff its selector was chosen: validate it
  // when selected, reject it as stray otherwise.
  const auto checkSection = [](bool selected, const qualla::json& section,
                               void (*validate)(const qualla::json&),
                               const std::string& missingMsg, const std::string& strayMsg) {
    if (selected) {
      if (!section.is_object()) {
        throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, missingMsg);
      }
      validate(section);
    } else if (section.is_object()) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, strayMsg);
    }
  };

  checkSection(binary, binaryConfig, &validateModelBinaryConfig,
               "Missing binary model config",
               "binary model config for incorrect model type: " + type);
  checkSection(library, libraryConfig, &validateModelLibraryConfig,
               "Missing library model config",
               "library model config for incorrect model type: " + type);
  checkSection(positionalEncoding, positionalEncodingConfig, &validatePositionalEncodingConfig,
               "Missing Positional encoding config",
               "Positional encoding config for incorrect model type: " + type);
}
|
| 794 |
+
|
| 795 |
+
//=============================================================================
|
| 796 |
+
// Engine::Config functions
|
| 797 |
+
//=============================================================================
|
| 798 |
+
|
| 799 |
+
// Validates one "engine" entry: version/backend/model/n-threads are always
// mandatory; a "role" field additionally becomes mandatory for spd (draft |
// target, behind GENIE_SPD_FEATURE) and kv-share (primary | secondary)
// dialogs. Nested backend and model sections are validated recursively.
// Throws Exception with a GENIE_STATUS_ERROR_JSON_* status on violations.
static void validateEngineConfig(const qualla::json& config, std::string dialogType) {
  if (!config.is_object()) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "engine config is not an object");
  }

  std::set<std::string> mandatoryFields{"version", "backend", "model", "n-threads"};
#if defined(GENIE_SPD_FEATURE)
  if (dialogType == "spd") {
    mandatoryFields.insert("role");
  }
#endif
  if (dialogType == "kv-share") {
    mandatoryFields.insert("role");
  }

  for (const auto& field : mandatoryFields) {
    if (!config.contains(field)) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing engine field: " + field);
    }
  }

  // component is used in the "ENFORCE" macros
  std::string component = "engine";

  for (auto& item : config.items()) {
    const std::string key = item.key();
    if (key == "version") {
      JSON_ENFORCE_NUMERIC();
      if (item.value().get<int>() != 1) {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Invalid engine config: unsupported version: " + item.value().dump());
      }
    } else if (key == "backend") {
      JSON_ENFORCE_OBJECT();
      validateBackendConfig(item.value());
    } else if (key == "model") {
      JSON_ENFORCE_OBJECT();
      validateModelConfig(item.value());
    } else if (key == "n-threads") {
      JSON_ENFORCE_NUMERIC();
#if defined(GENIE_SPD_FEATURE)
    } else if (key == "role" && dialogType == "spd") {
      JSON_ENFORCE_STRING();
      if (item.value() != "draft" && item.value() != "target") {
        throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                        "Unknown value: for engine config key: " + item.key());
      }
#endif
    } else if (key == "role" && dialogType == "kv-share") {
      JSON_ENFORCE_STRING();
      if (item.value() != "primary" && item.value() != "secondary") {
        throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                        "Unknown value: for engine config key: " + item.key());
      }
    } else {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown engine config key: " + item.key());
    }
  }
}
|
| 857 |
+
|
| 858 |
+
// Validates an engine-config array that must hold exactly two engines
// carrying the role pair (roleA, roleB). The role-missing messages are passed
// in verbatim so the historical error text is preserved per dialog type.
static void validateEnginePairConfig(const qualla::json& configs,
                                     const std::string& dialogType,
                                     const char* roleA,
                                     const char* roleB,
                                     const std::string& missingRoleAMsg,
                                     const std::string& missingRoleBMsg) {
  if (configs.size() != 2) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                    "engine config for " + dialogType + " contain invalid number of engines");
  }
  bool roleASeen = false;
  bool roleBSeen = false;
  for (auto& item : configs) {
    validateEngineConfig(item, dialogType);
    if (item["role"] == roleA) {
      roleASeen = true;
    } else if (item["role"] == roleB) {
      roleBSeen = true;
    }
  }
  if (!roleASeen) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, missingRoleAMsg);
  }
  if (!roleBSeen) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, missingRoleBMsg);
  }
}

// Validates the dialog's "engine" entry, which is either a single engine
// object (basic dialogs) or, for spd (behind GENIE_SPD_FEATURE) and kv-share
// dialogs, an array of exactly two engines with complementary roles.
// Refactored: the previously duplicated spd/kv-share pair validation now
// shares validateEnginePairConfig; all error messages are unchanged.
static void validateMultiEngineConfig(const qualla::json& configs, std::string dialogType) {
  if (configs.is_object()) {
    validateEngineConfig(configs, dialogType);
#if defined(GENIE_SPD_FEATURE)
    if (dialogType == "spd") {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "engine config for spd is not an array");
    }
#endif
    if (dialogType == "kv-share") {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "engine config for kv-share is not an array");
    }
#if defined(GENIE_SPD_FEATURE)
  } else if (configs.is_array() && dialogType == "spd") {
    validateEnginePairConfig(configs, dialogType, "draft", "target",
                             "engine config for spd does not contain draft engine",
                             "engine config for spd does not contain target engine");
#endif
  } else if (configs.is_array() && dialogType == "kv-share") {
    validateEnginePairConfig(configs, dialogType, "primary", "secondary",
                             "engine config for kv-share does not contain primary",
                             "engine config for kv-share does not contain secondary");
  } else {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "engine config is not an object or an array");
  }
}
|
| 919 |
+
|
| 920 |
+
// Translates a (already validated) Genie engine config into the flat Qualla
// engine config: backend selection (QnnHtp vs QnnGenAiTransformer), model
// source (context binaries + optional LoRA adapters, or a model library),
// and the optional positional-encoding / rope-scaling parameters.
// Fixes: the LoRA adapter loop compared a signed int against size() (signed/
// unsigned mismatch); repeated deep operator[] chains are hoisted into const
// references (same lookups, one place to read them); the duplicated
// lora-version==1 test is evaluated once.
// Note: only version 1 configs are translated; anything else is silently
// ignored here (the validators have already rejected other versions).
static void translateEngineConfig(const qualla::json& genieEngineConfig,
                                  qualla::json& quallaEngineConfig) {
  if (genieEngineConfig["version"] == 1) {
    // Single-engine dialogs carry no "role"; Qualla expects "primary" then.
    if (genieEngineConfig.contains("role")) {
      quallaEngineConfig["role"] = genieEngineConfig["role"];
    } else {
      quallaEngineConfig["role"] = "primary";
    }

    quallaEngineConfig["n-threads"] = genieEngineConfig["n-threads"];

    const auto& backend = genieEngineConfig["backend"];
    if (backend["type"] == "QnnHtp") {
      const auto& htp = backend["QnnHtp"];
      quallaEngineConfig["type"]               = "qnn-htp";
      quallaEngineConfig["backend-lib"]        = getLibName("QnnHtp");
      quallaEngineConfig["mmap-budget"]        = htp["mmap-budget"];
      quallaEngineConfig["use-mmap"]           = htp["use-mmap"];
      quallaEngineConfig["spill-fill-bufsize"] = htp["spill-fill-bufsize"];
      if (htp.contains("pos-id-dim")) {
        quallaEngineConfig["pos-id-dim"] = htp["pos-id-dim"];
      }
      quallaEngineConfig["cpumask"] = htp["cpu-mask"];
      quallaEngineConfig["poll"]    = htp["poll"];
      quallaEngineConfig["kv-dim"]  = htp["kv-dim"];
      if (htp.contains("rope-theta")) {
        quallaEngineConfig["rope-theta"] = htp["rope-theta"];
      }
      if (htp.contains("kv-update-method")) {
        quallaEngineConfig["kv-update-method"] = htp["kv-update-method"];
      }
      // By default, Qualla will default to the async init path.
      // For now, we are forcing async init off unless explicitly
      // specified in the Genie config. It is HTP specific feature only.
      quallaEngineConfig["use-async-Init"] = false;
      if (htp.contains("allow-async-init")) {
        quallaEngineConfig["use-async-Init"] = htp["allow-async-init"];
      }
    } else if (backend["type"] == "QnnGenAiTransformer") {
      const auto& cpu = backend["QnnGenAiTransformer"];
      quallaEngineConfig["type"]        = "qnn-cpu";
      quallaEngineConfig["backend-lib"] = getLibName("QnnGenAiTransformer");
      // Note the key spelling change: Genie uses dashes, Qualla underscores.
      if (cpu.contains("n-logits")) {
        quallaEngineConfig["n_logits"] = cpu["n-logits"];
      }
      if (cpu.contains("use-mmap")) {
        quallaEngineConfig["use-mmap"] = cpu["use-mmap"];
      }
      if (cpu.contains("n-layer")) {
        quallaEngineConfig["n_layer"] = cpu["n-layer"];
      }
      if (cpu.contains("n-embd")) {
        quallaEngineConfig["n_embd"] = cpu["n-embd"];
      }
      if (cpu.contains("n-heads")) {
        quallaEngineConfig["n_heads"] = cpu["n-heads"];
      }
    }

    if (backend.contains("extensions")) {
      quallaEngineConfig["backend-ext-conf"] = backend["extensions"];
    }

    const auto& model = genieEngineConfig["model"];
    if (model["type"] == "binary") {
      quallaEngineConfig["model-list"] = model["binary"]["ctx-bins"];
      if (model["binary"].contains("lora")) {
        const auto& lora = model["binary"]["lora"];
        // Default to v2 unless the config explicitly pins lora-version 1.
        quallaEngineConfig["lora-version"] =
            static_cast<uint8_t>(LORA_VERSION::GENIE_LORA_VERSION_V2);
        const bool loraV1 = lora.contains("lora-version") && lora["lora-version"] == 1;
        if (loraV1) {
          quallaEngineConfig["lora-version"] = lora["lora-version"];
        }
        for (std::size_t i = 0; i < lora["adapters"].size(); i++) {
          quallaEngineConfig["lora"][i]["adapter-name"] = lora["adapters"][i]["name"];
          quallaEngineConfig["lora"][i]["alpha-tensor-name"] = "";
          if (lora.contains("alpha-tensor-name")) {
            quallaEngineConfig["lora"][i]["alpha-tensor-name"] = lora["alpha-tensor-name"];
          }
          quallaEngineConfig["lora"][i]["alpha-tensor-value"] = 1.0f;
          quallaEngineConfig["lora"][i]["binsection-basedir"] = "";
          // v1 adapters are file paths; v2 adapters reference bin sections.
          if (loraV1) {
            quallaEngineConfig["lora"][i]["path"] = lora["adapters"][i]["path"];
          } else {
            quallaEngineConfig["lora"][i]["bin-sections"] = lora["adapters"][i]["bin-sections"];
          }
        }
      }
    } else if (model["type"] == "library") {
      quallaEngineConfig["model"]          = getLibName("QnnGenAiTransformerModel");
      quallaEngineConfig["model-bin-path"] = model["library"]["model-bin"];
      quallaEngineConfig["op-package"] =
          getLibName("QnnGenAiTransformerCpuOpPkg") + ":QnnOpPackage_interfaceProvider";
    }

    if (model.contains("positional-encoding")) {
      const auto& pe = model["positional-encoding"];
      quallaEngineConfig["positional-encoding"]["type"] = pe["type"];
      if (pe["type"] == "rope") {
        quallaEngineConfig["positional-encoding"]["rope-dim"] = pe["rope-dim"];
        if (pe.contains("rope-theta")) {
          quallaEngineConfig["positional-encoding"]["rope-theta"] = pe["rope-theta"];
        }
        if (pe.contains("rope-scaling")) {
          const auto& rs = pe["rope-scaling"];
          if (rs.contains("rope-type")) {
            auto& qrs = quallaEngineConfig["positional-encoding"]["rope-scaling"];
            qrs["rope-type"] = rs["rope-type"];
            // llama3 and longrope share factor/original-max-position-
            // embeddings; longrope additionally carries short/long factors.
            if (rs["rope-type"] == "llama3") {
              if (rs.contains("factor")) {
                qrs["factor"] = rs["factor"];
              }
              if (rs.contains("low-freq-factor")) {
                qrs["low-freq-factor"] = rs["low-freq-factor"];
              }
              if (rs.contains("high-freq-factor")) {
                qrs["high-freq-factor"] = rs["high-freq-factor"];
              }
              if (rs.contains("original-max-position-embeddings")) {
                qrs["original-max-position-embeddings"] =
                    rs["original-max-position-embeddings"];
              }
            }
            if (rs["rope-type"] == "longrope") {
              if (rs.contains("factor")) {
                qrs["factor"] = rs["factor"];
              }
              if (rs.contains("short-factor")) {
                qrs["short-factor"] = rs["short-factor"];
              }
              if (rs.contains("long-factor")) {
                qrs["long-factor"] = rs["long-factor"];
              }
              if (rs.contains("original-max-position-embeddings")) {
                qrs["original-max-position-embeddings"] =
                    rs["original-max-position-embeddings"];
              }
            }
          }
        }
      }
    }
  }
}
|
| 1101 |
+
|
| 1102 |
+
// Translates the dialog's engine entry into Qualla form: an array of engine
// configs maps to a Qualla "engine" array (one translated entry each); a
// single engine object is translated in place into quallaConfig["engine"].
static void translateMultiEngineConfig(const qualla::json& genieConfig,
                                       qualla::json& quallaConfig) {
  const auto& engines = genieConfig["dialog"]["engine"];
  if (engines.is_array()) {
    quallaConfig["engine"] = qualla::json::array();
    for (auto& item : engines) {
      qualla::json translated;
      translateEngineConfig(item, translated);
      quallaConfig["engine"].push_back(translated);
    }
  } else {
    translateEngineConfig(engines, quallaConfig["engine"]);
  }
}
|
| 1115 |
+
|
| 1116 |
+
//=============================================================================
|
| 1117 |
+
// Dialog::Config functions
|
| 1118 |
+
//=============================================================================
|
| 1119 |
+
|
| 1120 |
+
qnn::util::HandleManager<Dialog::Config> Dialog::Config::s_manager;
|
| 1121 |
+
|
| 1122 |
+
GenieDialogConfig_Handle_t Dialog::Config::add(std::shared_ptr<Dialog::Config> config) {
|
| 1123 |
+
return (GenieDialogConfig_Handle_t)s_manager.add(config);
|
| 1124 |
+
}
|
| 1125 |
+
|
| 1126 |
+
std::shared_ptr<Dialog::Config> Dialog::Config::get(GenieDialogConfig_Handle_t handle) {
|
| 1127 |
+
return s_manager.get((qnn::util::Handle_t)handle);
|
| 1128 |
+
}
|
| 1129 |
+
|
| 1130 |
+
void Dialog::Config::remove(GenieDialogConfig_Handle_t handle) {
|
| 1131 |
+
s_manager.remove((qnn::util::Handle_t)handle);
|
| 1132 |
+
}
|
| 1133 |
+
|
| 1134 |
+
#if defined(GENIE_SSD_FEATURE)
// Validates the "ssd-q1" dialog section: schema and type checks per key plus
// two cross-field constraints — p-threshold is only meaningful with
// multistream decoding (n-streams > 1), and the branches array may not be
// larger than forecast-token-count. Throws Exception on any violation.
static void validateDialogSsdConfig(const qualla::json& config) {
  if (!config.is_object()) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "ssd-q1 config is not an object");
  }

  std::set<std::string> mandatoryFields{"version",
                                        "ssd-version",
                                        "forecast-token-count",
                                        "branches",
                                        "forecast-prefix",
                                        "forecast-prefix-name"};
  for (const auto& field : mandatoryFields) {
    if (!config.contains(field)) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing ssd-q1 field: " + field);
    }
  }

  // component is used in the "ENFORCE" macros
  std::string component = "ssd-q1";

  int branchesSize       = 0;
  int forecastTokenCount = 0;

  // Defaults used by the cross-field checks when the optional keys are absent.
  int nStreams     = 1;
  float pThreshold = 0.0;

  for (auto& item : config.items()) {
    if (item.key() == "version") {
      JSON_ENFORCE_NUMERIC();
      if (item.value().get<int>() != 1) {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Invalid ssd-q1 config: unsupported version: " + item.value().dump());
      }
    } else if (item.key() == "ssd-version") {
      JSON_ENFORCE_NUMERIC();
    } else if (item.key() == "forecast-token-count") {
      JSON_ENFORCE_NUMERIC();
      forecastTokenCount = item.value();
    } else if (item.key() == "branches") {
      JSON_ENFORCE_ARRAY();
      for (auto& elem : item.value()) {
        if (!elem.is_number_integer()) {
          throw Exception(GENIE_STATUS_ERROR_JSON_VALUE, "branches must be an array of integers");
        }
      }
      // Explicit narrowing: size() is unsigned, branchesSize is int.
      branchesSize = static_cast<int>(item.value().size());
    } else if (item.key() == "forecast-prefix") {
      JSON_ENFORCE_NUMERIC();
    } else if (item.key() == "forecast-prefix-name") {
      JSON_ENFORCE_STRING();
    } else if (item.key() == "n-streams") {
      JSON_ENFORCE_NUMERIC();
      nStreams = item.value();
    } else if (item.key() == "p-threshold") {
      JSON_ENFORCE_NUMERIC();
      pThreshold = item.value();
    } else {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown ssd-q1 config key: " + item.key());
    }
  }

  if ((pThreshold > 0.0) && (nStreams <= 1)) {
    throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                    "p-threshold can only be used with multistream (n-streams > 1)");
  }

  // The check is strict '>' (equality allowed); the message previously said
  // "must be less than", contradicting the actual condition.
  if (branchesSize > forecastTokenCount) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                    "Size of branches array must not exceed forecast-token-count");
  }
}
#endif
|
| 1207 |
+
|
| 1208 |
+
#if defined(GENIE_LADE_FEATURE)
// Validates the "lade" (lookahead decoding) dialog section: version must be
// 1, update-mode one of FWD_MAX_HIT / FWD_LEVEL / ALWAYS_FWD_ONE, and
// window/ngram/gcap numeric. All five keys are mandatory; anything else is
// rejected. Throws Exception on any violation.
static void validateDialogLadeConfig(const qualla::json& config) {
  if (!config.is_object()) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "lade config is not an object");
  }

  std::set<std::string> mandatoryFields{"version", "update-mode", "window", "ngram", "gcap"};
  for (const auto& field : mandatoryFields) {
    if (!config.contains(field)) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing lade field: " + field);
    }
  }

  // component is used in the "ENFORCE" macros
  std::string component = "lade";

  for (auto& item : config.items()) {
    const std::string key = item.key();
    if (key == "version") {
      JSON_ENFORCE_NUMERIC();
      if (item.value().get<int>() != 1) {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Invalid lade config: unsupported version: " + item.value().dump());
      }
    } else if (key == "update-mode") {
      JSON_ENFORCE_STRING();
      const std::string mode = item.value().get<std::string>();
      const bool knownMode =
          (mode == "FWD_MAX_HIT") || (mode == "FWD_LEVEL") || (mode == "ALWAYS_FWD_ONE");
      if (!knownMode) {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Invalid lade config: unsupported update-mode: " + item.value().dump());
      }
    } else if (key == "window" || key == "ngram" || key == "gcap") {
      JSON_ENFORCE_NUMERIC();
    } else {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown lade config key: " + item.key());
    }
  }
}
#endif
|
| 1250 |
+
|
| 1251 |
+
#if defined(GENIE_SPD_FEATURE)
// Validates the "spd" (speculative decoding) dialog section: both "version"
// (== 1) and a numeric "draft-len" are mandatory; any other key is rejected.
// Throws Exception with a GENIE_STATUS_ERROR_JSON_* status on violations.
static void validateDialogSpdConfig(const qualla::json& config) {
  if (!config.is_object()) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "spd config is not an object");
  }

  std::set<std::string> mandatoryFields{"version", "draft-len"};
  for (const auto& field : mandatoryFields) {
    if (!config.contains(field)) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing spd field: " + field);
    }
  }

  // component is used in the "ENFORCE" macros
  std::string component = "spd";
  for (auto& item : config.items()) {
    const std::string key = item.key();
    if (key == "version") {
      JSON_ENFORCE_NUMERIC();
      if (item.value().get<int>() != 1) {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Invalid spd config: unsupported version: " + item.value().dump());
      }
    } else if (key == "draft-len") {
      JSON_ENFORCE_NUMERIC();
    } else {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown spd config key: " + item.key());
    }
  }
}
#endif
|
| 1281 |
+
|
| 1282 |
+
#if defined(GENIE_MULTISTREAM_FEATURE)
// Validates the "multistream" dialog section: "version" (== 1) and a numeric
// "n-streams" are mandatory; "p-threshold" is an optional numeric; anything
// else is rejected. Throws Exception on any violation.
static void validateDialogMultistreamConfig(const qualla::json& config) {
  if (!config.is_object()) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "multistream config is not an object");
  }

  std::set<std::string> mandatoryFields{"version", "n-streams"};
  for (const auto& field : mandatoryFields) {
    if (!config.contains(field)) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing multistream field: " + field);
    }
  }

  // component is used in the "ENFORCE" macros
  std::string component = "multistream";

  for (auto& item : config.items()) {
    const std::string key = item.key();
    if (key == "version") {
      JSON_ENFORCE_NUMERIC();
      if (item.value().get<int>() != 1) {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Invalid multistream config: unsupported version: " + item.value().dump());
      }
    } else if (key == "n-streams" || key == "p-threshold") {
      JSON_ENFORCE_NUMERIC();
    } else {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                      "Unknown multistream config key: " + item.key());
    }
  }
}
#endif
|
| 1316 |
+
|
| 1317 |
+
// Validates the top-level "dialog" config object: mandatory fields, the
// feature-specific sub-configs (ssd-q1 / lade / spd / multistream, each behind
// its build flag), and the shared context/tokenizer/sampler/engine/embedding
// sections. Throws genie::Exception on any schema or value violation.
static void validateDialogConfig(const qualla::json& config) {
  if (!config.is_object()) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Dialog config is not an object");
  }

  std::set<std::string> mandatoryFields{"version", "type", "context", "tokenizer", "engine"};
  for (const auto& field : mandatoryFields) {
    if (!config.contains(field)) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing dialog field: " + field);
    }
  }

  // component is used in the "ENFORCE" macros
  std::string component = "dialog";

  std::string dialogType = "basic";
#if defined(GENIE_SSD_FEATURE)
  bool ssdq1 = false;
  qualla::json ssdq1Config;
#endif
#if defined(GENIE_LADE_FEATURE)
  bool lade = false;
  qualla::json ladeConfig;
#endif
#if defined(GENIE_SPD_FEATURE)
  bool spd = false;
  qualla::json spdConfig;
#endif
#if defined(GENIE_MULTISTREAM_FEATURE)
  bool multistream = false;
  qualla::json multistreamConfig;
#endif

  for (auto& item : config.items()) {
    if (item.key() == "version") {
      JSON_ENFORCE_NUMERIC();
      if (item.value().get<int>() != 1) {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Invalid dialog config: unsupported version: " + item.value().dump());
      }
    } else if (item.key() == "type") {
      JSON_ENFORCE_STRING();
      dialogType = item.value();
      if (dialogType == "basic" || dialogType == "kv-share") {
        // Do nothing
#if defined(GENIE_SSD_FEATURE)
      } else if (dialogType == "ssd-q1") {
        ssdq1 = true;
#endif
#if defined(GENIE_LADE_FEATURE)
      } else if (dialogType == "lade") {
        lade = true;
#endif
#if defined(GENIE_SPD_FEATURE)
      } else if (dialogType == "spd") {
        spd = true;
#endif
#if defined(GENIE_MULTISTREAM_FEATURE)
      } else if (dialogType == "multistream") {
        multistream = true;
#endif
      } else {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE, "Invalid dialog type: " + dialogType);
      }
#if defined(GENIE_SSD_FEATURE)
    } else if (item.key() == "ssd-q1") {
      JSON_ENFORCE_OBJECT();
      ssdq1Config = item.value();
      // ssd-q1 validation is done below
#endif
#if defined(GENIE_LADE_FEATURE)
    } else if (item.key() == "lade") {
      JSON_ENFORCE_OBJECT();
      ladeConfig = item.value();
      // lade validation is done below (comment previously said "ssd-q1" — copy-paste slip)
#endif
#if defined(GENIE_SPD_FEATURE)
    } else if (item.key() == "spd") {
      JSON_ENFORCE_OBJECT();
      spdConfig = item.value();
      // spd validation is done below
#endif
#if defined(GENIE_MULTISTREAM_FEATURE)
    } else if (item.key() == "multistream") {
      JSON_ENFORCE_OBJECT();
      multistreamConfig = item.value();
      // multistream validation is done below
#endif
    } else if (item.key() == "stop-sequence") {
      JSON_ENFORCE_ARRAY();
      for (auto& elem : item.value()) {
        if (!elem.is_string()) {
          throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                          "stop-sequence must be an array of strings");
        }
      }
    } else if (item.key() == "max-num-tokens") {
      JSON_ENFORCE_NUMERIC();
      // BUG FIX: the original test was `< 0`, which accepted 0 even though the
      // error message (and a token budget's semantics) require a strictly
      // positive count. Use `<= 0` so 0 is rejected as well.
      if (item.value().get<int>() <= 0) {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "number of tokens must be > 0. provided: " + item.value().dump());
      }
    } else if (item.key() == "context") {
      JSON_ENFORCE_OBJECT();
      validateContextConfig(item.value());
    } else if (item.key() == "tokenizer") {
      JSON_ENFORCE_OBJECT();
      validateTokenizerConfig(item.value());
    } else if (item.key() == "sampler") {
      JSON_ENFORCE_OBJECT();
      validateSamplerConfig(item.value());
    } else if (item.key() == "engine") {
      JSON_ENFORCE_ARRAY_OR_OBJECT();
    } else if (item.key() == "embedding") {
      JSON_ENFORCE_OBJECT();
      validateEmbeddingConfig(item.value());
    } else {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown dialog config key: " + item.key());
    }
  }

  // Engine Verification requires dialogType for engine roles. Since "type" is encountered
  // later than "engine" in loop. Therefore, moving engine validation out of the loop.
  validateMultiEngineConfig(config["engine"], dialogType);

#if defined(GENIE_SSD_FEATURE)
  // Cross-check: the ssd-q1 sub-config must be present iff the type selects it.
  if (ssdq1) {
    if (!ssdq1Config.is_object()) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing ssd-q1 dialog config");
    }
    validateDialogSsdConfig(ssdq1Config);
  } else {
    if (ssdq1Config.is_object()) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                      "ssd-q1 dialog config for incorrect dialog type: " + dialogType);
    }
  }
#endif
#if defined(GENIE_LADE_FEATURE)
  if (lade) {
    if (!ladeConfig.is_object()) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing lade dialog config");
    }
    validateDialogLadeConfig(ladeConfig);
  } else {
    if (ladeConfig.is_object()) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                      "lade dialog config for incorrect dialog type: " + dialogType);
    }
  }
#endif
#if defined(GENIE_SPD_FEATURE)
  if (spd) {
    if (!spdConfig.is_object()) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing spd dialog config");
    }
    validateDialogSpdConfig(spdConfig);
  } else {
    if (spdConfig.is_object()) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                      "spd dialog config for incorrect dialog type: " + dialogType);
    }
  }
#endif
#if defined(GENIE_MULTISTREAM_FEATURE)
  if (multistream) {
    if (!multistreamConfig.is_object()) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing multistream dialog config");
    }
    validateDialogMultistreamConfig(multistreamConfig);
  } else {
    if (multistreamConfig.is_object()) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                      "multistream dialog config for incorrect dialog type: " + dialogType);
    }
  }
#endif
}
|
| 1495 |
+
|
| 1496 |
+
// Maps the (already validated) Genie "dialog" JSON onto the qualla dialog
// config. Genie-facing type names are renamed where qualla uses a different
// identifier (lade -> lhd-dec, spd -> spec-dec); all other types pass through.
// Improvement: the sub-object lookups are hoisted into const references —
// the original re-evaluated genieConfig["dialog"] (and the per-feature
// sub-objects) on every single access.
static void translateDialogConfig(const qualla::json& genieConfig, qualla::json& quallaConfig) {
  const qualla::json& dialog = genieConfig["dialog"];
  if (dialog["version"] == 1) {
    // Translate the dialog type to qualla's naming.
    if (dialog["type"] == "lade") {
      quallaConfig["type"] = "lhd-dec";
    } else if (dialog["type"] == "spd") {
      quallaConfig["type"] = "spec-dec";
    } else if (dialog["type"] == "multistream") {
      quallaConfig["type"] = "multistream";
    } else {
      quallaConfig["type"] = dialog["type"];
    }
#if defined(GENIE_SSD_FEATURE)
    if (dialog["type"] == "ssd-q1") {
      const qualla::json& ssd = dialog["ssd-q1"];
      quallaConfig["ssd-version"] = ssd["ssd-version"];
      quallaConfig["forecast-token-count"] = ssd["forecast-token-count"];
      quallaConfig["branches"] = ssd["branches"];
      quallaConfig["forecast-prefix"] = ssd["forecast-prefix"];
      quallaConfig["forecast-prefix-name"] = ssd["forecast-prefix-name"];

      // Optional tuning knobs are forwarded only when present.
      if (ssd.contains("n-streams")) {
        quallaConfig["n-streams"] = ssd["n-streams"];
      }
      if (ssd.contains("p-threshold")) {
        quallaConfig["p-threshold"] = ssd["p-threshold"];
      }
    }
#endif
#if defined(GENIE_LADE_FEATURE)
    if (dialog["type"] == "lade") {
      const qualla::json& ladeCfg = dialog["lade"];
      quallaConfig["lhd-update-mode"] = ladeCfg["update-mode"];
      quallaConfig["window"] = ladeCfg["window"];
      quallaConfig["ngram"] = ladeCfg["ngram"];
      quallaConfig["gcap"] = ladeCfg["gcap"];
    }
#endif
#if defined(GENIE_SPD_FEATURE)
    if (dialog["type"] == "spd") {
      quallaConfig["draft-len"] = dialog["spd"]["draft-len"];
    }
#endif
#if defined(GENIE_MULTISTREAM_FEATURE)
    if (dialog["type"] == "multistream") {
      const qualla::json& ms = dialog["multistream"];
      quallaConfig["n-streams"] = ms["n-streams"];
      if (ms.contains("p-threshold")) {
        quallaConfig["p-threshold"] = ms["p-threshold"];
      }
    }
#endif
  }
  // stop-sequence is version-independent (kept outside the version check,
  // matching the original control flow).
  if (dialog.contains("stop-sequence")) {
    quallaConfig["prompt"]["stop-sequence"] = dialog["stop-sequence"];
  }

  translateContextConfig(genieConfig, quallaConfig);
  translateTokenizerConfig(genieConfig, quallaConfig);
  translateSamplerConfig(genieConfig, quallaConfig);
  translateMultiEngineConfig(genieConfig, quallaConfig);
  translateEmbeddingConfig(genieConfig, quallaConfig);
}
|
| 1557 |
+
|
| 1558 |
+
// Reads the optional "dialog.max-num-tokens" generation cap from the Genie
// config. Returns UINT32_MAX ("no limit") when the key is absent or the
// config version is not 1.
uint32_t getMaxNumTokens(const qualla::json& genieConfig) {
  uint32_t limit = UINT32_MAX;
  const auto& dialog = genieConfig["dialog"];
  if (dialog["version"] == 1 && dialog.contains("max-num-tokens")) {
    limit = dialog["max-num-tokens"];
  }
  return limit;
}
|
| 1567 |
+
|
| 1568 |
+
// Parses and validates a dialog config JSON string.
// Throws qualla::json::parse_error on malformed JSON and genie::Exception
// (GENIE_STATUS_ERROR_JSON_SCHEMA) on schema violations. The only accepted
// top-level key is "dialog".
Dialog::Config::Config(const char* configStr) {
  qualla::json config;
  // NOTE(review): rope_theta_set / position_dim_set are assigned here but are
  // not declared in the Config class as shown in Dialog.hpp — presumably they
  // exist in another revision of the header; confirm this compiles.
  rope_theta_set = false;
  position_dim_set = false;
  {
    // Track top-level keys seen during parsing so duplicate keys (which
    // nlohmann::json would otherwise silently overwrite) are rejected.
    std::set<qualla::json> keys;

    auto callback = [&keys](int depth, qualla::json::parse_event_t event, qualla::json& parsed) {
      // Only depth-1 key events matter: those are the top-level object keys.
      if ((depth == 1) && (event == qualla::json::parse_event_t::key)) {
        if (keys.count(parsed) > 0) {
          throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                          "Multiple dialog config key: " + parsed.dump());
        }
        keys.insert(parsed);
      }
      // Returning true keeps every parsed element (no filtering).
      return true;
    };

    config = qualla::json::parse(configStr, callback);
  }

  if (!config.is_object()) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Dialog config is not an object");
  }

  // "dialog" is the single mandatory (and only allowed) top-level field.
  std::set<std::string> mandatoryFields{"dialog"};
  for (const auto& field : mandatoryFields) {
    if (!config.contains(field)) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing dialog field: " + field);
    }
  }

  // component is used in the "ENFORCE" macros
  std::string component = "dialog";

  for (auto& item : config.items()) {
    if (item.key() == "dialog") {
      JSON_ENFORCE_OBJECT();
      validateDialogConfig(item.value());
    } else {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown dialog config key: " + item.key());
    }
  }
  // Store only after full validation so a Config object is always valid.
  m_config = config;
}
|
| 1613 |
+
|
| 1614 |
+
qualla::json Dialog::Config::getJson() const { return m_config; }
|
| 1615 |
+
|
| 1616 |
+
//=============================================================================
|
| 1617 |
+
// Dialog functions
|
| 1618 |
+
//=============================================================================
|
| 1619 |
+
|
| 1620 |
+
// Global registry mapping opaque GenieDialog_Handle_t values to Dialog objects.
qnn::util::HandleManager<Dialog> Dialog::s_manager;
// Monotonic counter used to give each qualla dialog a unique name.
std::atomic<std::uint32_t> Dialog::s_nameCounter{0u};

// Registers a dialog and returns its opaque handle for the C API.
GenieDialog_Handle_t Dialog::add(std::shared_ptr<Dialog> dialog) {
  return (GenieDialog_Handle_t)s_manager.add(dialog);
}

// Looks up a dialog by handle; returns nullptr-equivalent if unknown
// (behavior delegated to HandleManager::get).
std::shared_ptr<Dialog> Dialog::get(GenieDialog_Handle_t handle) {
  return s_manager.get((qnn::util::Handle_t)handle);
}

// Unregisters a dialog handle; the object is destroyed when the last
// shared_ptr reference drops.
void Dialog::remove(GenieDialog_Handle_t handle) { s_manager.remove((qnn::util::Handle_t)handle); }
|
| 1632 |
+
|
| 1633 |
+
// Constructs the underlying qualla::Dialog from a validated Config.
// Throws genie::Exception(GENIE_STATUS_ERROR_MEM_ALLOC) if creation fails.
Dialog::Dialog(std::shared_ptr<Config> config) {
  auto env = qualla::Env::create(qualla::json{});
  // Fetch the JSON once: getJson() returns by value, so the original's two
  // separate calls materialized two full copies of the config.
  const qualla::json genieConfig = config->getJson();
  qualla::json quallaConfig;
  translateDialogConfig(genieConfig, quallaConfig);
  m_tokenLimit = getMaxNumTokens(genieConfig);
  // Each dialog gets a unique name ("dialog0", "dialog1", ...) via the
  // atomic counter, so concurrent construction is safe.
  m_quallaDialog = qualla::Dialog::create(
      env, "dialog" + std::to_string(s_nameCounter.fetch_add(1u)), quallaConfig);
  if (!m_quallaDialog) {
    throw Exception(GENIE_STATUS_ERROR_MEM_ALLOC, "Could not create a dialog object");
  }
}
|
| 1644 |
+
|
| 1645 |
+
// Compile-time guarantee that the public GENIE_DIALOG_SENTENCE_* codes are
// numerically identical to qualla::Sentence::Code, so the static_casts
// between the two enums throughout this file are safe.
static_assert(qualla::Sentence::Code::COMPLETE ==
              static_cast<qualla::Sentence::Code>(GENIE_DIALOG_SENTENCE_COMPLETE));
static_assert(qualla::Sentence::Code::BEGIN ==
              static_cast<qualla::Sentence::Code>(GENIE_DIALOG_SENTENCE_BEGIN));
static_assert(qualla::Sentence::Code::CONTINUE ==
              static_cast<qualla::Sentence::Code>(GENIE_DIALOG_SENTENCE_CONTINUE));
static_assert(qualla::Sentence::Code::END ==
              static_cast<qualla::Sentence::Code>(GENIE_DIALOG_SENTENCE_END));
static_assert(qualla::Sentence::Code::ABORT ==
              static_cast<qualla::Sentence::Code>(GENIE_DIALOG_SENTENCE_ABORT));
|
| 1655 |
+
|
| 1656 |
+
int32_t Dialog::query(const char* queryStr,
|
| 1657 |
+
GenieDialog_SentenceCode_t sentenceCode,
|
| 1658 |
+
GenieDialog_QueryCallback_t callback,
|
| 1659 |
+
const void* userData) {
|
| 1660 |
+
std::string query(queryStr);
|
| 1661 |
+
uint32_t genTokenCount = 0u;
|
| 1662 |
+
bool status = m_quallaDialog->query(
|
| 1663 |
+
query,
|
| 1664 |
+
static_cast<qualla::Sentence::Code>(sentenceCode),
|
| 1665 |
+
[&](const std::string& response, qualla::Sentence::Code code) {
|
| 1666 |
+
callback(response.c_str(), static_cast<GenieDialog_SentenceCode_t>(code), userData);
|
| 1667 |
+
bool keepGoing = ++genTokenCount < m_tokenLimit;
|
| 1668 |
+
if (!keepGoing && ((code == qualla::Sentence::Code::BEGIN) ||
|
| 1669 |
+
(code == qualla::Sentence::Code::CONTINUE))) {
|
| 1670 |
+
callback("", GENIE_DIALOG_SENTENCE_END, userData);
|
| 1671 |
+
}
|
| 1672 |
+
return keepGoing;
|
| 1673 |
+
});
|
| 1674 |
+
qualla::Dialog::KPIs kpis = m_quallaDialog->kpis();
|
| 1675 |
+
printf(
|
| 1676 |
+
"\n\n[KPIS]:\nInit Time: %zu us\nPrompt Processing Time: %zu us, Prompt Processing Rate : "
|
| 1677 |
+
"%f toks/sec\n"
|
| 1678 |
+
"Token Generation Time: %zu us, Token Generation Rate: %f toks/sec\n",
|
| 1679 |
+
kpis.init.total_usec,
|
| 1680 |
+
kpis.prompt.last_usec,
|
| 1681 |
+
kpis.tps.prompt,
|
| 1682 |
+
kpis.generate.last_usec,
|
| 1683 |
+
kpis.tps.generate);
|
| 1684 |
+
return (status) ? (GENIE_STATUS_SUCCESS) : (GENIE_STATUS_ERROR_QUERY_FAILED);
|
| 1685 |
+
}
|
| 1686 |
+
|
| 1687 |
+
int32_t Dialog::save(const std::string& name) {
|
| 1688 |
+
return m_quallaDialog->save(name) ? (GENIE_STATUS_SUCCESS) : (GENIE_STATUS_ERROR_QUERY_FAILED);
|
| 1689 |
+
}
|
| 1690 |
+
|
| 1691 |
+
int32_t Dialog::restore(const std::string& name) {
|
| 1692 |
+
return m_quallaDialog->restore(name) ? (GENIE_STATUS_SUCCESS) : (GENIE_STATUS_ERROR_QUERY_FAILED);
|
| 1693 |
+
}
|
| 1694 |
+
|
| 1695 |
+
#if defined(GENIE_E2T_FEATURE)
|
| 1696 |
+
int32_t Dialog::embeddingQuery(const void* embeddings,
|
| 1697 |
+
const uint32_t embeddingsSize,
|
| 1698 |
+
GenieDialog_SentenceCode_t sentenceCode,
|
| 1699 |
+
GenieDialog_TokenToEmbeddingCallback_t t2eCallback,
|
| 1700 |
+
GenieDialog_QueryCallback_t callback,
|
| 1701 |
+
const void* userData) {
|
| 1702 |
+
uint32_t genTokenCount = 0u;
|
| 1703 |
+
|
| 1704 |
+
if (embeddingsSize % m_quallaDialog->getEmbeddingBufferSize() != 0) {
|
| 1705 |
+
throw std::runtime_error(
|
| 1706 |
+
"The embeddings buffer size must be an integer multiple of the embedding vector size in "
|
| 1707 |
+
"bytes.");
|
| 1708 |
+
}
|
| 1709 |
+
|
| 1710 |
+
const uint8_t* embeddingsSrc = static_cast<const uint8_t*>(embeddings);
|
| 1711 |
+
std::vector<uint8_t> embeddingVector(embeddingsSrc, embeddingsSrc + embeddingsSize);
|
| 1712 |
+
|
| 1713 |
+
qualla::Dialog::T2ECallback t2eQuallaCallback{nullptr};
|
| 1714 |
+
if (t2eCallback) {
|
| 1715 |
+
t2eQuallaCallback = [&](const int32_t token, void* embedding, const uint32_t embd_size) {
|
| 1716 |
+
t2eCallback(token, embedding, embd_size, userData);
|
| 1717 |
+
};
|
| 1718 |
+
}
|
| 1719 |
+
|
| 1720 |
+
bool status = m_quallaDialog->query(
|
| 1721 |
+
embeddingVector,
|
| 1722 |
+
static_cast<qualla::Sentence::Code>(sentenceCode),
|
| 1723 |
+
t2eQuallaCallback,
|
| 1724 |
+
[&](const std::string& response, qualla::Sentence::Code code) {
|
| 1725 |
+
callback(response.c_str(), static_cast<GenieDialog_SentenceCode_t>(code), userData);
|
| 1726 |
+
bool keepGoing = ++genTokenCount < m_tokenLimit;
|
| 1727 |
+
if (!keepGoing && ((code == qualla::Sentence::Code::BEGIN) ||
|
| 1728 |
+
(code == qualla::Sentence::Code::CONTINUE))) {
|
| 1729 |
+
callback("", GENIE_DIALOG_SENTENCE_END, userData);
|
| 1730 |
+
}
|
| 1731 |
+
return keepGoing;
|
| 1732 |
+
});
|
| 1733 |
+
qualla::Dialog::KPIs kpis = m_quallaDialog->kpis();
|
| 1734 |
+
printf(
|
| 1735 |
+
"\n\n[KPIS]:\nInit Time: %zu us\nPrompt Processing Time: %zu us, Prompt Processing Rate : "
|
| 1736 |
+
"%f toks/sec\n"
|
| 1737 |
+
"Token Generation Time: %zu us, Token Generation Rate: %f toks/sec\n",
|
| 1738 |
+
kpis.init.total_usec,
|
| 1739 |
+
kpis.prompt.last_usec,
|
| 1740 |
+
kpis.tps.prompt,
|
| 1741 |
+
kpis.generate.last_usec,
|
| 1742 |
+
kpis.tps.generate);
|
| 1743 |
+
return (status) ? (GENIE_STATUS_SUCCESS) : (GENIE_STATUS_ERROR_QUERY_FAILED);
|
| 1744 |
+
}
|
| 1745 |
+
#endif
|
| 1746 |
+
|
| 1747 |
+
void Dialog::reset() { m_quallaDialog->reset(); }
|
| 1748 |
+
|
| 1749 |
+
#if defined(GENIE_LORA_FEATURE)

// Applies the named LoRA adapter to the given engine.
int32_t Dialog::applyLora(std::string loraAdapterName, std::string engine) {
  if (m_quallaDialog->applyLoraAdapter(loraAdapterName, engine)) {
    return GENIE_STATUS_SUCCESS;
  }
  return GENIE_STATUS_ERROR_GENERAL;
}

// Sets the blend strength `alpha` for a LoRA tensor on the given engine.
// (Note the qualla call takes alpha before engine.)
int32_t Dialog::applyLoraStrength(std::string tensorName, std::string engine, float alpha) {
  if (m_quallaDialog->applyLoraStrength(tensorName, alpha, engine)) {
    return GENIE_STATUS_SUCCESS;
  }
  return GENIE_STATUS_ERROR_GENERAL;
}

#endif
|
| 1762 |
+
|
| 1763 |
+
int32_t Dialog::tokenQuery(const uint32_t* tokens,
|
| 1764 |
+
const uint32_t sizeInputTokens,
|
| 1765 |
+
GenieDialog_SentenceCode_t sentenceCode,
|
| 1766 |
+
GenieDialog_TokenQueryCallback_t callback,
|
| 1767 |
+
const void* userData) {
|
| 1768 |
+
std::vector<uint32_t> inputTokens;
|
| 1769 |
+
for (size_t i = 0; i < sizeInputTokens; i++) {
|
| 1770 |
+
inputTokens.push_back(tokens[i]);
|
| 1771 |
+
}
|
| 1772 |
+
uint32_t genTokenCount = 0u;
|
| 1773 |
+
dialogCallback.setCallBackType(qualla::QUALLA_CALLBACK_TYPE_TOKEN);
|
| 1774 |
+
dialogCallback.getTokenCbFunc() = std::make_shared<
|
| 1775 |
+
std::function<bool(const int32_t*, const uint32_t, qualla::Sentence::Code)>>();
|
| 1776 |
+
*(dialogCallback.getTokenCbFunc()) = [&](const int32_t* responseTokens,
|
| 1777 |
+
const uint32_t sizeResponseTokens,
|
| 1778 |
+
qualla::Sentence::Code code) {
|
| 1779 |
+
callback((const uint32_t*)responseTokens,
|
| 1780 |
+
sizeResponseTokens,
|
| 1781 |
+
static_cast<GenieDialog_SentenceCode_t>(code),
|
| 1782 |
+
userData);
|
| 1783 |
+
bool keepGoing = ++genTokenCount < m_tokenLimit;
|
| 1784 |
+
if (!keepGoing &&
|
| 1785 |
+
((code == qualla::Sentence::Code::BEGIN) || (code == qualla::Sentence::Code::CONTINUE))) {
|
| 1786 |
+
callback(nullptr, 0, GENIE_DIALOG_SENTENCE_END, userData);
|
| 1787 |
+
}
|
| 1788 |
+
return keepGoing;
|
| 1789 |
+
};
|
| 1790 |
+
bool status = m_quallaDialog->query((const std::vector<uint32_t>)inputTokens,
|
| 1791 |
+
static_cast<qualla::Sentence::Code>(sentenceCode),
|
| 1792 |
+
dialogCallback);
|
| 1793 |
+
qualla::Dialog::KPIs kpis = m_quallaDialog->kpis();
|
| 1794 |
+
printf(
|
| 1795 |
+
"\n\n[KPIS]:\nInit Time: %zu us\nPrompt Processing Time: %zu us, Prompt Processing Rate : "
|
| 1796 |
+
"%f toks/sec\n"
|
| 1797 |
+
"Token Generation Time: %zu us, Token Generation Rate: %f toks/sec\n",
|
| 1798 |
+
kpis.init.total_usec,
|
| 1799 |
+
kpis.prompt.last_usec,
|
| 1800 |
+
kpis.tps.prompt,
|
| 1801 |
+
kpis.generate.last_usec,
|
| 1802 |
+
kpis.tps.generate);
|
| 1803 |
+
return (status) ? (GENIE_STATUS_SUCCESS) : (GENIE_STATUS_ERROR_QUERY_FAILED);
|
| 1804 |
+
}
|
Genie/Genie/src/Dialog.hpp
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//==============================================================================
|
| 2 |
+
//
|
| 3 |
+
// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
|
| 4 |
+
// All Rights Reserved.
|
| 5 |
+
// Confidential and Proprietary - Qualcomm Technologies, Inc.
|
| 6 |
+
//
|
| 7 |
+
//==============================================================================
|
| 8 |
+
|
| 9 |
+
#pragma once
|
| 10 |
+
|
| 11 |
+
#include <atomic>
|
| 12 |
+
#include <memory>
|
| 13 |
+
|
| 14 |
+
#include "GenieDialog.h"
|
| 15 |
+
#include "Util/HandleManager.hpp"
|
| 16 |
+
#include "qualla/dialog.hpp"
|
| 17 |
+
#include "qualla/DialogCallback.hpp"
|
| 18 |
+
|
| 19 |
+
namespace genie {
|
| 20 |
+
|
| 21 |
+
// LoRA adapter format versions understood by the Genie runtime.
enum LORA_VERSION : uint8_t {
  GENIE_LORA_VERSION_V1 = 0x1,
  GENIE_LORA_VERSION_V2 = 0x2,
  // Sentinel for an unrecognized/unset version.
  GENIE_LORA_VERSION_UNDEFINED = 0xFF
};
|
| 26 |
+
|
| 27 |
+
// C++ implementation behind the GenieDialog C API: wraps a qualla::Dialog and
// exposes handle-based registration plus text/token/embedding query entry
// points. Instances are non-copyable and non-movable (handles refer to a
// stable object).
class Dialog {
 public:
  // Parsed + validated dialog configuration, itself handle-managed.
  class Config {
   public:
    static GenieDialogConfig_Handle_t add(std::shared_ptr<Config> config);
    static std::shared_ptr<Config> get(GenieDialogConfig_Handle_t handle);
    static void remove(GenieDialogConfig_Handle_t handle);

    // Parses/validates the JSON string; throws on malformed or invalid config.
    // NOTE(review): single-argument constructor is not `explicit` — confirm
    // implicit conversion from const char* is intended.
    Config(const char* configStr);
    // Returns a copy of the validated configuration JSON.
    qualla::json getJson() const;

   private:
    static qnn::util::HandleManager<Config> s_manager;
    qualla::json m_config;
  };

  static GenieDialog_Handle_t add(std::shared_ptr<Dialog> dialog);
  static std::shared_ptr<Dialog> get(GenieDialog_Handle_t handle);
  static void remove(GenieDialog_Handle_t handle);

  // Reusable callback holder passed to qualla for token queries.
  // NOTE(review): public and mutated per-call by tokenQuery — not safe for
  // concurrent queries on the same Dialog; confirm single-threaded use.
  qualla::DialogCallback dialogCallback;

  Dialog(std::shared_ptr<Config> config);

  Dialog(const Dialog&) = delete;
  Dialog& operator=(const Dialog&) = delete;
  Dialog(Dialog&&) = delete;
  Dialog& operator=(Dialog&&) = delete;

  // Text-in / text-out streaming query; returns GENIE_STATUS_*.
  int32_t query(const char* queryStr,
                GenieDialog_SentenceCode_t sentenceCode,
                GenieDialog_QueryCallback_t callback,
                const void* userData);

  // Persist / restore dialog state under a caller-chosen name.
  int32_t save(const std::string&);

  int32_t restore(const std::string&);

#if defined(GENIE_E2T_FEATURE)
  // Embedding-in / text-out query; embeddingsSize is in bytes.
  int32_t embeddingQuery(const void* embeddings,
                         const uint32_t embeddingsSize,
                         GenieDialog_SentenceCode_t sentenceCode,
                         GenieDialog_TokenToEmbeddingCallback_t t2eCallback,
                         GenieDialog_QueryCallback_t callback,
                         const void* userData);
#endif

  // Token-ids-in / token-ids-out streaming query.
  int32_t tokenQuery(const uint32_t* tokens,
                     const uint32_t sizeInputTokens,
                     GenieDialog_SentenceCode_t sentenceCode,
                     GenieDialog_TokenQueryCallback_t callback,
                     const void* userData);

  // Clears conversation state.
  void reset();

#if defined(GENIE_LORA_FEATURE)
  int32_t applyLora(std::string loraAdapterName, std::string engine);
  int32_t applyLoraStrength(std::string tensorName, std::string engine, float alpha);
#endif

 private:
  std::unique_ptr<qualla::Dialog> m_quallaDialog;
  // Max callbacks per query; UINT32_MAX means unlimited.
  uint32_t m_tokenLimit{UINT32_MAX};
  static qnn::util::HandleManager<Dialog> s_manager;
  static std::atomic<std::uint32_t> s_nameCounter;
};
|
| 95 |
+
} // namespace genie
|
Genie/Genie/src/Exception.hpp
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//==============================================================================
|
| 2 |
+
//
|
| 3 |
+
// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
|
| 4 |
+
// All Rights Reserved.
|
| 5 |
+
// Confidential and Proprietary - Qualcomm Technologies, Inc.
|
| 6 |
+
//
|
| 7 |
+
//==============================================================================
|
| 8 |
+
|
| 9 |
+
#pragma once
|
| 10 |
+
|
| 11 |
+
#include <exception>
|
| 12 |
+
#include <string>
|
| 13 |
+
|
| 14 |
+
#include "GenieCommon.h"
|
| 15 |
+
|
| 16 |
+
namespace genie {
|
| 17 |
+
|
| 18 |
+
class Exception : public std::runtime_error {
|
| 19 |
+
public:
|
| 20 |
+
Exception(Genie_Status_t status, std::string what) : std::runtime_error(what), m_status(status) {}
|
| 21 |
+
Genie_Status_t status() const { return m_status; }
|
| 22 |
+
|
| 23 |
+
private:
|
| 24 |
+
Genie_Status_t m_status;
|
| 25 |
+
};
|
| 26 |
+
|
| 27 |
+
} // namespace genie
|
Genie/Genie/src/GenieCommon.cpp
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//=============================================================================
|
| 2 |
+
//
|
| 3 |
+
// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
|
| 4 |
+
// All Rights Reserved.
|
| 5 |
+
// Confidential and Proprietary - Qualcomm Technologies, Inc.
|
| 6 |
+
//
|
| 7 |
+
//=============================================================================
|
| 8 |
+
|
| 9 |
+
#include "GenieCommon.h"
|
| 10 |
+
|
| 11 |
+
// C API accessors for the Genie semantic version, sourced from the
// GENIE_API_VERSION_* macros in GenieCommon.h.
uint32_t Genie_getApiMajorVersion(void) { return GENIE_API_VERSION_MAJOR; }

uint32_t Genie_getApiMinorVersion(void) { return GENIE_API_VERSION_MINOR; }

uint32_t Genie_getApiPatchVersion(void) { return GENIE_API_VERSION_PATCH; }
|
Genie/Genie/src/GenieDialog.cpp
ADDED
|
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//=============================================================================
|
| 2 |
+
//
|
| 3 |
+
// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
|
| 4 |
+
// All Rights Reserved.
|
| 5 |
+
// Confidential and Proprietary - Qualcomm Technologies, Inc.
|
| 6 |
+
//
|
| 7 |
+
//=============================================================================
|
| 8 |
+
|
| 9 |
+
#include "Dialog.hpp"
|
| 10 |
+
#include "Exception.hpp"
|
| 11 |
+
#include "GenieDialog.h"
|
| 12 |
+
#include "Macro.hpp"
|
| 13 |
+
#include "Util/HandleManager.hpp"
|
| 14 |
+
#include "qualla/detail/json.hpp"
|
| 15 |
+
|
| 16 |
+
using namespace genie;
|
| 17 |
+
|
| 18 |
+
GENIE_API
|
| 19 |
+
Genie_Status_t GenieDialogConfig_createFromJson(const char* str,
|
| 20 |
+
GenieDialogConfig_Handle_t* configHandle) {
|
| 21 |
+
try {
|
| 22 |
+
GENIE_ENSURE(str, GENIE_STATUS_ERROR_INVALID_ARGUMENT);
|
| 23 |
+
GENIE_ENSURE(configHandle, GENIE_STATUS_ERROR_INVALID_ARGUMENT);
|
| 24 |
+
auto config = std::make_shared<Dialog::Config>(str);
|
| 25 |
+
GENIE_ENSURE(config, GENIE_STATUS_ERROR_MEM_ALLOC);
|
| 26 |
+
*configHandle = genie::Dialog::Config::add(config);
|
| 27 |
+
} catch (const qualla::json::parse_error& e) {
|
| 28 |
+
std::cerr << e.what() << std::endl;
|
| 29 |
+
return GENIE_STATUS_ERROR_JSON_FORMAT;
|
| 30 |
+
} catch (const Exception& e) {
|
| 31 |
+
std::cerr << e.what() << std::endl;
|
| 32 |
+
return e.status();
|
| 33 |
+
} catch (const std::exception& e) {
|
| 34 |
+
std::cerr << e.what() << std::endl;
|
| 35 |
+
return GENIE_STATUS_ERROR_GENERAL;
|
| 36 |
+
}
|
| 37 |
+
return GENIE_STATUS_SUCCESS;
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
GENIE_API
|
| 41 |
+
Genie_Status_t GenieDialogConfig_free(const GenieDialogConfig_Handle_t configHandle) {
|
| 42 |
+
try {
|
| 43 |
+
GENIE_ENSURE(configHandle, GENIE_STATUS_ERROR_INVALID_HANDLE);
|
| 44 |
+
{
|
| 45 |
+
// Check if the dialog actually exists
|
| 46 |
+
auto configObj = genie::Dialog::Config::get(configHandle);
|
| 47 |
+
GENIE_ENSURE(configObj, GENIE_STATUS_ERROR_INVALID_HANDLE);
|
| 48 |
+
}
|
| 49 |
+
genie::Dialog::Config::remove(configHandle);
|
| 50 |
+
} catch (const std::exception& e) {
|
| 51 |
+
return GENIE_STATUS_ERROR_GENERAL;
|
| 52 |
+
}
|
| 53 |
+
return GENIE_STATUS_SUCCESS;
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
GENIE_API
|
| 57 |
+
Genie_Status_t GenieDialog_create(const GenieDialogConfig_Handle_t configHandle,
|
| 58 |
+
GenieDialog_Handle_t* dialogHandle) {
|
| 59 |
+
try {
|
| 60 |
+
GENIE_ENSURE(dialogHandle, GENIE_STATUS_ERROR_INVALID_ARGUMENT);
|
| 61 |
+
|
| 62 |
+
// Get config object
|
| 63 |
+
auto configObj = genie::Dialog::Config::get(configHandle);
|
| 64 |
+
GENIE_ENSURE(configObj, GENIE_STATUS_ERROR_INVALID_HANDLE);
|
| 65 |
+
|
| 66 |
+
// Create dialog
|
| 67 |
+
auto dialog = std::make_shared<genie::Dialog>(configObj);
|
| 68 |
+
GENIE_ENSURE(dialog, GENIE_STATUS_ERROR_MEM_ALLOC);
|
| 69 |
+
|
| 70 |
+
// Create Handle
|
| 71 |
+
*dialogHandle = genie::Dialog::add(dialog);
|
| 72 |
+
} catch (const std::exception& e) {
|
| 73 |
+
std::cerr << e.what() << std::endl;
|
| 74 |
+
return GENIE_STATUS_ERROR_GENERAL;
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
// Return SUCCESS
|
| 78 |
+
return GENIE_STATUS_SUCCESS;
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
GENIE_API
|
| 82 |
+
Genie_Status_t GenieDialog_query(const GenieDialog_Handle_t dialogHandle,
|
| 83 |
+
const char* queryStr,
|
| 84 |
+
const GenieDialog_SentenceCode_t sentenceCode,
|
| 85 |
+
const GenieDialog_QueryCallback_t callback,
|
| 86 |
+
const void* userData) {
|
| 87 |
+
int32_t status;
|
| 88 |
+
|
| 89 |
+
try {
|
| 90 |
+
GENIE_ENSURE(dialogHandle, GENIE_STATUS_ERROR_INVALID_HANDLE);
|
| 91 |
+
auto dialog = genie::Dialog::get(dialogHandle);
|
| 92 |
+
GENIE_ENSURE(dialog, GENIE_STATUS_ERROR_INVALID_HANDLE);
|
| 93 |
+
GENIE_ENSURE(queryStr, GENIE_STATUS_ERROR_INVALID_ARGUMENT);
|
| 94 |
+
GENIE_ENSURE(callback, GENIE_STATUS_ERROR_INVALID_ARGUMENT);
|
| 95 |
+
|
| 96 |
+
switch (sentenceCode) {
|
| 97 |
+
case GENIE_DIALOG_SENTENCE_COMPLETE:
|
| 98 |
+
case GENIE_DIALOG_SENTENCE_BEGIN:
|
| 99 |
+
case GENIE_DIALOG_SENTENCE_CONTINUE:
|
| 100 |
+
case GENIE_DIALOG_SENTENCE_END:
|
| 101 |
+
case GENIE_DIALOG_SENTENCE_ABORT:
|
| 102 |
+
// Do nothing
|
| 103 |
+
break;
|
| 104 |
+
default:
|
| 105 |
+
return GENIE_STATUS_ERROR_INVALID_ARGUMENT;
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
status = dialog->query(queryStr, sentenceCode, callback, userData);
|
| 109 |
+
} catch (const std::exception& e) {
|
| 110 |
+
std::cerr << e.what() << std::endl;
|
| 111 |
+
return GENIE_STATUS_ERROR_GENERAL;
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
return status;
|
| 115 |
+
}
|
| 116 |
+
|
| 117 |
+
GENIE_API
|
| 118 |
+
Genie_Status_t GenieDialog_save(const GenieDialog_Handle_t dialogHandle, const char* path) {
|
| 119 |
+
int32_t status;
|
| 120 |
+
|
| 121 |
+
try {
|
| 122 |
+
GENIE_ENSURE(dialogHandle, GENIE_STATUS_ERROR_INVALID_HANDLE);
|
| 123 |
+
auto dialog = genie::Dialog::get(dialogHandle);
|
| 124 |
+
GENIE_ENSURE(dialog, GENIE_STATUS_ERROR_INVALID_HANDLE);
|
| 125 |
+
GENIE_ENSURE(path, GENIE_STATUS_ERROR_INVALID_ARGUMENT);
|
| 126 |
+
status = dialog->save(path);
|
| 127 |
+
} catch (const std::exception& e) {
|
| 128 |
+
std::cerr << e.what() << std::endl;
|
| 129 |
+
return GENIE_STATUS_ERROR_GENERAL;
|
| 130 |
+
}
|
| 131 |
+
|
| 132 |
+
return status;
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
GENIE_API
|
| 136 |
+
Genie_Status_t GenieDialog_restore(const GenieDialog_Handle_t dialogHandle, const char* path) {
|
| 137 |
+
int32_t status;
|
| 138 |
+
|
| 139 |
+
try {
|
| 140 |
+
GENIE_ENSURE(dialogHandle, GENIE_STATUS_ERROR_INVALID_HANDLE);
|
| 141 |
+
auto dialog = genie::Dialog::get(dialogHandle);
|
| 142 |
+
GENIE_ENSURE(dialog, GENIE_STATUS_ERROR_INVALID_HANDLE);
|
| 143 |
+
GENIE_ENSURE(path, GENIE_STATUS_ERROR_INVALID_ARGUMENT);
|
| 144 |
+
status = dialog->restore(path);
|
| 145 |
+
} catch (const std::exception& e) {
|
| 146 |
+
std::cerr << e.what() << std::endl;
|
| 147 |
+
return GENIE_STATUS_ERROR_GENERAL;
|
| 148 |
+
}
|
| 149 |
+
|
| 150 |
+
return status;
|
| 151 |
+
}
|
| 152 |
+
|
| 153 |
+
GENIE_API
|
| 154 |
+
Genie_Status_t GenieDialog_reset(const GenieDialog_Handle_t dialogHandle) {
|
| 155 |
+
try {
|
| 156 |
+
GENIE_ENSURE(dialogHandle, GENIE_STATUS_ERROR_INVALID_HANDLE);
|
| 157 |
+
auto dialog = genie::Dialog::get(dialogHandle);
|
| 158 |
+
GENIE_ENSURE(dialog, GENIE_STATUS_ERROR_INVALID_HANDLE);
|
| 159 |
+
dialog->reset();
|
| 160 |
+
} catch (const std::exception& e) {
|
| 161 |
+
return GENIE_STATUS_ERROR_GENERAL;
|
| 162 |
+
}
|
| 163 |
+
return GENIE_STATUS_SUCCESS;
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
#if defined(GENIE_LORA_FEATURE)
|
| 167 |
+
|
| 168 |
+
GENIE_API
|
| 169 |
+
Genie_Status_t GenieDialog_applyLora(const GenieDialog_Handle_t dialogHandle,
|
| 170 |
+
const char* engine,
|
| 171 |
+
const char* loraAdapterName) {
|
| 172 |
+
int32_t status;
|
| 173 |
+
try {
|
| 174 |
+
GENIE_ENSURE(dialogHandle, GENIE_STATUS_ERROR_INVALID_HANDLE);
|
| 175 |
+
auto dialog = genie::Dialog::get(dialogHandle);
|
| 176 |
+
GENIE_ENSURE(dialog, GENIE_STATUS_ERROR_INVALID_HANDLE);
|
| 177 |
+
GENIE_ENSURE(engine, GENIE_STATUS_ERROR_INVALID_ARGUMENT);
|
| 178 |
+
std::string eng(engine);
|
| 179 |
+
GENIE_ENSURE(loraAdapterName, GENIE_STATUS_ERROR_INVALID_ARGUMENT);
|
| 180 |
+
std::string loraName(loraAdapterName);
|
| 181 |
+
status = dialog->applyLora(loraName, eng);
|
| 182 |
+
} catch (const std::exception& e) {
|
| 183 |
+
return GENIE_STATUS_ERROR_GENERAL;
|
| 184 |
+
}
|
| 185 |
+
return status;
|
| 186 |
+
}
|
| 187 |
+
|
| 188 |
+
GENIE_API
|
| 189 |
+
Genie_Status_t GenieDialog_setLoraStrength(const GenieDialog_Handle_t dialogHandle,
|
| 190 |
+
const char* engine,
|
| 191 |
+
const char* tensorName,
|
| 192 |
+
const float alpha) {
|
| 193 |
+
int32_t status;
|
| 194 |
+
try {
|
| 195 |
+
GENIE_ENSURE(dialogHandle, GENIE_STATUS_ERROR_INVALID_HANDLE);
|
| 196 |
+
auto dialog = genie::Dialog::get(dialogHandle);
|
| 197 |
+
GENIE_ENSURE(dialog, GENIE_STATUS_ERROR_INVALID_HANDLE);
|
| 198 |
+
GENIE_ENSURE(engine, GENIE_STATUS_ERROR_INVALID_ARGUMENT);
|
| 199 |
+
std::string eng(engine);
|
| 200 |
+
GENIE_ENSURE(tensorName, GENIE_STATUS_ERROR_INVALID_ARGUMENT);
|
| 201 |
+
std::string alphaTensorName(tensorName);
|
| 202 |
+
GENIE_ENSURE_NOT_EMPTY(alphaTensorName, GENIE_STATUS_ERROR_INVALID_ARGUMENT);
|
| 203 |
+
status = dialog->applyLoraStrength(tensorName, eng, alpha);
|
| 204 |
+
} catch (const std::exception& e) {
|
| 205 |
+
return GENIE_STATUS_ERROR_GENERAL;
|
| 206 |
+
}
|
| 207 |
+
return status;
|
| 208 |
+
}
|
| 209 |
+
|
| 210 |
+
#endif
|
| 211 |
+
|
| 212 |
+
GENIE_API
|
| 213 |
+
Genie_Status_t GenieDialog_tokenQuery(const GenieDialog_Handle_t dialogHandle,
|
| 214 |
+
const uint32_t* inputTokens,
|
| 215 |
+
const uint32_t numTokens,
|
| 216 |
+
const GenieDialog_SentenceCode_t sentenceCode,
|
| 217 |
+
const GenieDialog_TokenQueryCallback_t callback,
|
| 218 |
+
const void* userData) {
|
| 219 |
+
bool status;
|
| 220 |
+
try {
|
| 221 |
+
GENIE_ENSURE(dialogHandle, GENIE_STATUS_ERROR_INVALID_HANDLE);
|
| 222 |
+
auto dialog = genie::Dialog::get(dialogHandle);
|
| 223 |
+
GENIE_ENSURE(dialog, GENIE_STATUS_ERROR_INVALID_HANDLE);
|
| 224 |
+
GENIE_ENSURE(inputTokens, GENIE_STATUS_ERROR_INVALID_ARGUMENT);
|
| 225 |
+
GENIE_ENSURE(callback, GENIE_STATUS_ERROR_INVALID_ARGUMENT);
|
| 226 |
+
status = dialog->tokenQuery(inputTokens, numTokens, sentenceCode, callback, userData);
|
| 227 |
+
} catch (const std::exception& e) {
|
| 228 |
+
std::cerr << e.what() << std::endl;
|
| 229 |
+
return GENIE_STATUS_ERROR_GENERAL;
|
| 230 |
+
}
|
| 231 |
+
|
| 232 |
+
return status;
|
| 233 |
+
}
|
| 234 |
+
|
| 235 |
+
GENIE_API
|
| 236 |
+
Genie_Status_t GenieDialog_free(const GenieDialog_Handle_t dialogHandle) {
|
| 237 |
+
try {
|
| 238 |
+
GENIE_ENSURE(dialogHandle, GENIE_STATUS_ERROR_INVALID_HANDLE);
|
| 239 |
+
{
|
| 240 |
+
// Check if the dialog actually exists
|
| 241 |
+
auto dialog = genie::Dialog::get(dialogHandle);
|
| 242 |
+
GENIE_ENSURE(dialog, GENIE_STATUS_ERROR_INVALID_HANDLE);
|
| 243 |
+
}
|
| 244 |
+
genie::Dialog::remove(dialogHandle);
|
| 245 |
+
} catch (const std::exception& e) {
|
| 246 |
+
return GENIE_STATUS_ERROR_GENERAL;
|
| 247 |
+
}
|
| 248 |
+
return GENIE_STATUS_SUCCESS;
|
| 249 |
+
}
|
Genie/Genie/src/GenieDialogEmbedding.cpp
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//=============================================================================
|
| 2 |
+
//
|
| 3 |
+
// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
|
| 4 |
+
// All Rights Reserved.
|
| 5 |
+
// Confidential and Proprietary - Qualcomm Technologies, Inc.
|
| 6 |
+
//
|
| 7 |
+
//=============================================================================
|
| 8 |
+
|
| 9 |
+
#include "Dialog.hpp"
|
| 10 |
+
#include "Exception.hpp"
|
| 11 |
+
#include "GenieDialog.h"
|
| 12 |
+
#include "Macro.hpp"
|
| 13 |
+
#include "Util/HandleManager.hpp"
|
| 14 |
+
#include "qualla/detail/json.hpp"
|
| 15 |
+
|
| 16 |
+
using namespace genie;
|
| 17 |
+
|
| 18 |
+
GENIE_API
|
| 19 |
+
Genie_Status_t GenieDialog_embeddingQuery(const GenieDialog_Handle_t dialogHandle,
|
| 20 |
+
const void* embeddings,
|
| 21 |
+
const uint32_t embeddingsSize,
|
| 22 |
+
const GenieDialog_SentenceCode_t sentenceCode,
|
| 23 |
+
const GenieDialog_TokenToEmbeddingCallback_t t2eCallback,
|
| 24 |
+
const GenieDialog_QueryCallback_t callback,
|
| 25 |
+
const void* userData) {
|
| 26 |
+
Genie_Status_t status;
|
| 27 |
+
try {
|
| 28 |
+
GENIE_ENSURE(dialogHandle, GENIE_STATUS_ERROR_INVALID_HANDLE);
|
| 29 |
+
auto dialog = genie::Dialog::get(dialogHandle);
|
| 30 |
+
GENIE_ENSURE(dialog, GENIE_STATUS_ERROR_INVALID_HANDLE);
|
| 31 |
+
GENIE_ENSURE(embeddings, GENIE_STATUS_ERROR_INVALID_ARGUMENT);
|
| 32 |
+
GENIE_ENSURE(callback, GENIE_STATUS_ERROR_INVALID_ARGUMENT);
|
| 33 |
+
status = dialog->embeddingQuery(
|
| 34 |
+
embeddings, embeddingsSize, sentenceCode, t2eCallback, callback, userData);
|
| 35 |
+
} catch (const std::exception& e) {
|
| 36 |
+
std::cerr << e.what() << std::endl;
|
| 37 |
+
return GENIE_STATUS_ERROR_GENERAL;
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
return status;
|
| 41 |
+
}
|
Genie/Genie/src/Macro.hpp
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//============================================================================
|
| 2 |
+
//
|
| 3 |
+
// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
|
| 4 |
+
// All Rights Reserved.
|
| 5 |
+
// Confidential and Proprietary - Qualcomm Technologies, Inc.
|
| 6 |
+
//
|
| 7 |
+
//=============================================================================
|
| 8 |
+
|
| 9 |
+
#pragma once
|
| 10 |
+
|
| 11 |
+
//======================================================================================================================
|
| 12 |
+
// Error generation macros
|
| 13 |
+
//======================================================================================================================
|
| 14 |
+
|
| 15 |
+
#define GENIE_LOG_ERROR(fmt, ...)
|
| 16 |
+
|
| 17 |
+
#define GENIE_ENSURE_MSG(value, return_error, msg) \
|
| 18 |
+
do { \
|
| 19 |
+
if (!(value)) { \
|
| 20 |
+
GENIE_LOG_ERROR(" " msg); \
|
| 21 |
+
return return_error; \
|
| 22 |
+
} \
|
| 23 |
+
} while (0)
|
| 24 |
+
|
| 25 |
+
#define GENIE_ENSURE(value, return_error) \
|
| 26 |
+
do { \
|
| 27 |
+
if (!(value)) { \
|
| 28 |
+
GENIE_LOG_ERROR("%s was not true.", #value); \
|
| 29 |
+
return return_error; \
|
| 30 |
+
} \
|
| 31 |
+
} while (0)
|
| 32 |
+
|
| 33 |
+
#define GENIE_ENSURE_STATUS(status, return_error) \
|
| 34 |
+
do { \
|
| 35 |
+
if ((status) != GENIE_SUCCESS) { \
|
| 36 |
+
return return_error; \
|
| 37 |
+
} \
|
| 38 |
+
} while (0)
|
| 39 |
+
|
| 40 |
+
#define GENIE_ENSURE_EQ(a, b, return_error) \
|
| 41 |
+
do { \
|
| 42 |
+
if ((a) != (b)) { \
|
| 43 |
+
GENIE_LOG_ERROR("%s != %s (%d != %d)", #a, #b, (a), (b)); \
|
| 44 |
+
return return_error; \
|
| 45 |
+
} \
|
| 46 |
+
} while (0)
|
| 47 |
+
|
| 48 |
+
#define GENIE_ENSURE_NOT_EMPTY(value, return_error) \
|
| 49 |
+
do { \
|
| 50 |
+
if (value.empty()) { \
|
| 51 |
+
GENIE_LOG_ERROR("%s was not true.", #value); \
|
| 52 |
+
return return_error; \
|
| 53 |
+
} \
|
| 54 |
+
} while (0)
|
| 55 |
+
//======================================================================================================================
|
| 56 |
+
// JSON config macros
|
| 57 |
+
//======================================================================================================================
|
| 58 |
+
|
| 59 |
+
#define JSON_ENFORCE_OBJECT() \
|
| 60 |
+
if (!item.value().is_object()) { \
|
| 61 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, \
|
| 62 |
+
"Invalid " + component + " config: " + item.key() + " is not an object"); \
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
#define JSON_ENFORCE_ARRAY() \
|
| 66 |
+
if (!item.value().is_array()) { \
|
| 67 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, \
|
| 68 |
+
"Invalid " + component + " config: " + item.key() + " is not an array"); \
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
#define JSON_ENFORCE_ARRAY_OR_OBJECT() \
|
| 72 |
+
if (!item.value().is_array() && !item.value().is_object()) { \
|
| 73 |
+
throw Exception( \
|
| 74 |
+
GENIE_STATUS_ERROR_JSON_SCHEMA, \
|
| 75 |
+
"Invalid " + component + " config: " + item.key() + " is not an array or object"); \
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
#define JSON_ENFORCE_NUMERIC() \
|
| 79 |
+
if (!item.value().is_number()) { \
|
| 80 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, \
|
| 81 |
+
"Invalid " + component + " config: " + item.key() + " is not numeric"); \
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
#define JSON_ENFORCE_ARRAY_OR_NUMERIC() \
|
| 85 |
+
if (!item.value().is_number() && !item.value().is_array()) { \
|
| 86 |
+
throw Exception( \
|
| 87 |
+
GENIE_STATUS_ERROR_JSON_SCHEMA, \
|
| 88 |
+
"Invalid " + component + " config: " + item.key() + " is not an array or numeric"); \
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
#define JSON_ENFORCE_BOOLEAN() \
|
| 92 |
+
if (!item.value().is_boolean()) { \
|
| 93 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, \
|
| 94 |
+
"Invalid " + component + " config: " + item.key() + " is not boolean"); \
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
#define JSON_ENFORCE_STRING() \
|
| 98 |
+
if (!item.value().is_string()) { \
|
| 99 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, \
|
| 100 |
+
"Invalid " + component + " config: " + item.key() + " is not a string"); \
|
| 101 |
+
}
|
Genie/Genie/src/Util/HandleGenerator.hpp
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//==============================================================================
|
| 2 |
+
//
|
| 3 |
+
// Copyright (c) 2019-2020,2023 Qualcomm Technologies, Inc.
|
| 4 |
+
// All Rights Reserved.
|
| 5 |
+
// Confidential and Proprietary - Qualcomm Technologies, Inc.
|
| 6 |
+
//
|
| 7 |
+
//==============================================================================
|
| 8 |
+
|
| 9 |
+
#pragma once
|
| 10 |
+
|
| 11 |
+
#include <mutex>
|
| 12 |
+
|
| 13 |
+
namespace qnn {
|
| 14 |
+
namespace util {
|
| 15 |
+
|
| 16 |
+
typedef std::size_t Handle_t;
|
| 17 |
+
|
| 18 |
+
class HandleGenerator final {
|
| 19 |
+
static_assert(std::is_integral<Handle_t>::value, "Handle must be an integral type");
|
| 20 |
+
static_assert((sizeof(Handle_t) == 8) || (sizeof(Handle_t) == 4),
|
| 21 |
+
"Implementation of HandleGenerator::bswap() for sizeof(std::size_t) is required");
|
| 22 |
+
|
| 23 |
+
public:
|
| 24 |
+
HandleGenerator(const HandleGenerator&) = delete;
|
| 25 |
+
HandleGenerator& operator=(const HandleGenerator&) = delete;
|
| 26 |
+
HandleGenerator(HandleGenerator&&) = delete;
|
| 27 |
+
HandleGenerator& operator=(HandleGenerator&&) = delete;
|
| 28 |
+
|
| 29 |
+
static Handle_t generate(const void* const addr) {
|
| 30 |
+
return (bswap((Handle_t)addr) ^ (Handle_t)s_operand);
|
| 31 |
+
}
|
| 32 |
+
static const void* reverse(const Handle_t handle) {
|
| 33 |
+
return (void*)bswap(handle ^ (Handle_t)s_operand);
|
| 34 |
+
}
|
| 35 |
+
static constexpr Handle_t invalid() { return s_operand; }
|
| 36 |
+
|
| 37 |
+
private:
|
| 38 |
+
HandleGenerator() {}
|
| 39 |
+
|
| 40 |
+
static uint32_t bswap32(const uint32_t val) {
|
| 41 |
+
return (val >> 24U) | ((val >> 8U) & 0xff00U) | ((val << 8U) & 0xff0000U) | (val << 24U);
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
static uint64_t bswap64(const uint64_t val) {
|
| 45 |
+
return ((bswap32(val) + 0ULL) << 32U) | bswap32(val >> 32U);
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
template <typename T>
|
| 49 |
+
static size_t bswap(T val) {
|
| 50 |
+
if (sizeof(T) == 4) {
|
| 51 |
+
return bswap32(val);
|
| 52 |
+
} else {
|
| 53 |
+
return bswap64(val);
|
| 54 |
+
}
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
// Magic number generated via "openssl rand -hex 8"
|
| 58 |
+
static constexpr Handle_t s_operand = (Handle_t)0xd4c2416534bcdc9b;
|
| 59 |
+
};
|
| 60 |
+
|
| 61 |
+
} // namespace util
|
| 62 |
+
} // namespace qnn
|
Genie/Genie/src/Util/HandleManager.hpp
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//==============================================================================
|
| 2 |
+
//
|
| 3 |
+
// Copyright (c) 2019-2020 Qualcomm Technologies, Inc.
|
| 4 |
+
// All Rights Reserved.
|
| 5 |
+
// Confidential and Proprietary - Qualcomm Technologies, Inc.
|
| 6 |
+
//
|
| 7 |
+
//==============================================================================
|
| 8 |
+
|
| 9 |
+
#pragma once
|
| 10 |
+
|
| 11 |
+
#include <algorithm>
|
| 12 |
+
#include <functional>
|
| 13 |
+
#include <memory>
|
| 14 |
+
#include <mutex>
|
| 15 |
+
#include <unordered_map>
|
| 16 |
+
|
| 17 |
+
#include "HandleGenerator.hpp"
|
| 18 |
+
|
| 19 |
+
namespace qnn {
|
| 20 |
+
namespace util {
|
| 21 |
+
|
| 22 |
+
template <typename T>
|
| 23 |
+
class HandleManager {
|
| 24 |
+
public:
|
| 25 |
+
HandleManager() = default;
|
| 26 |
+
HandleManager(const HandleManager&) = delete;
|
| 27 |
+
HandleManager& operator=(const HandleManager&) = delete;
|
| 28 |
+
HandleManager(HandleManager&&) = delete;
|
| 29 |
+
HandleManager& operator=(HandleManager&&) = delete;
|
| 30 |
+
|
| 31 |
+
Handle_t add(std::shared_ptr<T> item) {
|
| 32 |
+
std::lock_guard<std::mutex> locker(m_itemsMtx);
|
| 33 |
+
|
| 34 |
+
if (!item) {
|
| 35 |
+
return HandleGenerator::invalid();
|
| 36 |
+
}
|
| 37 |
+
|
| 38 |
+
auto handle = HandleGenerator::generate(item.get());
|
| 39 |
+
m_items[handle] = item;
|
| 40 |
+
return handle;
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
Handle_t add(T* item) { return add(std::shared_ptr<T>(item)); }
|
| 44 |
+
|
| 45 |
+
Handle_t add(std::weak_ptr<T> item) { return add(item.lock()); }
|
| 46 |
+
|
| 47 |
+
std::shared_ptr<T> get(Handle_t handle) {
|
| 48 |
+
std::lock_guard<std::mutex> locker(m_itemsMtx);
|
| 49 |
+
|
| 50 |
+
auto it = m_items.find(handle);
|
| 51 |
+
if (it == m_items.end()) {
|
| 52 |
+
return std::shared_ptr<T>(nullptr);
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
return it->second;
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
typedef std::function<bool(const std::pair<Handle_t, std::shared_ptr<T>>&)> UnaryPredicate_t;
|
| 59 |
+
|
| 60 |
+
Handle_t findIf(UnaryPredicate_t pred) const {
|
| 61 |
+
auto it = std::find_if(m_items.begin(), m_items.end(), pred);
|
| 62 |
+
if (it == m_items.end()) {
|
| 63 |
+
return HandleGenerator::invalid();
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
return it->first;
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
size_t remove(Handle_t handle) {
|
| 70 |
+
std::lock_guard<std::mutex> locker(m_itemsMtx);
|
| 71 |
+
return m_items.erase(handle);
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
void clear() { m_items.clear(); }
|
| 75 |
+
|
| 76 |
+
const std::unordered_map<Handle_t, std::shared_ptr<T>>& getItems() const { return m_items; }
|
| 77 |
+
|
| 78 |
+
private:
|
| 79 |
+
std::unordered_map<Handle_t, std::shared_ptr<T>> m_items;
|
| 80 |
+
std::mutex m_itemsMtx;
|
| 81 |
+
};
|
| 82 |
+
|
| 83 |
+
} // namespace util
|
| 84 |
+
} // namespace qnn
|
Genie/Genie/src/qualla/context.cpp
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//==============================================================================
|
| 2 |
+
//
|
| 3 |
+
// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
|
| 4 |
+
// All Rights Reserved.
|
| 5 |
+
// Confidential and Proprietary - Qualcomm Technologies, Inc.
|
| 6 |
+
//
|
| 7 |
+
//==============================================================================
|
| 8 |
+
|
| 9 |
+
#include <qualla/logger.hpp>
|
| 10 |
+
#include <qualla/context.hpp>
|
| 11 |
+
#include <qualla/detail/config.hpp>
|
| 12 |
+
#include <qualla/detail/onload.hpp>
|
| 13 |
+
|
| 14 |
+
#include <fmt/format.h>
|
| 15 |
+
#include <fmt/ranges.h>
|
| 16 |
+
|
| 17 |
+
namespace qualla {
|
| 18 |
+
|
| 19 |
+
Context::Context(Env& env, const std::string& name, const qualla::json& json)
|
| 20 |
+
: _name(name), _env(env), _conf(json) {
|
| 21 |
+
_env.logger().debug(fmt::format("ctx-new: {} config {}", _name, _conf.dump()));
|
| 22 |
+
|
| 23 |
+
qualla::Config conf(json, "context:");
|
| 24 |
+
_size = conf.optional<size_t>("size", 1024);
|
| 25 |
+
_size = conf.optional<size_t>("n-ctx", _size); // alternative name
|
| 26 |
+
_n_vocab = conf.optional<size_t>("n-vocab", 32000);
|
| 27 |
+
_n_embd = conf.optional<size_t>("n-embd", 1024);
|
| 28 |
+
_embedding_length = conf.optional<int32_t>("embedding-length", -1);
|
| 29 |
+
_embedding_datatype = conf.optional<std::string>("embedding-datatype", "float32");
|
| 30 |
+
// For backward compatibility. When eot-token is removed, this logic can be simplified
|
| 31 |
+
// Currently, EOT is marked as default truncating token if available
|
| 32 |
+
int32_t eot_tok = conf.optional<int32_t>("eot-token", -1);
|
| 33 |
+
if (eot_tok >= 0) _eos_tok_list.insert(eot_tok);
|
| 34 |
+
|
| 35 |
+
const qualla::json eos_conf = conf.optional<qualla::json>("eos-token", _eos_tok);
|
| 36 |
+
if (eos_conf.is_array() && eos_conf.size() > 0) {
|
| 37 |
+
const std::vector<int32_t>& eos_tokens = eos_conf.get<std::vector<int32_t>>();
|
| 38 |
+
_eos_tok = eos_tokens[0];
|
| 39 |
+
for (const int32_t& eos_tok : eos_tokens)
|
| 40 |
+
_eos_tok_list.insert(eos_tok);
|
| 41 |
+
} else if (eos_conf.is_number_integer()) {
|
| 42 |
+
int32_t eos_tok = eos_conf.get<int32_t>();
|
| 43 |
+
_eos_tok = (eot_tok >= 0) ? eot_tok : eos_tok;
|
| 44 |
+
_eos_tok_list.insert(eos_tok);
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
_pad_tok = conf.optional<qualla::json>("pad-token", _eos_tok);
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
std::unique_ptr<Context> Context::create(
|
| 51 |
+
Env& env,
|
| 52 |
+
const std::string& name,
|
| 53 |
+
const qualla::json& conf
|
| 54 |
+
) {
|
| 55 |
+
return std::make_unique<Context>(env, name, conf);
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
std::unique_ptr<Context> Context::create(
|
| 59 |
+
Env& env,
|
| 60 |
+
const std::string& name,
|
| 61 |
+
std::istream& json_stream
|
| 62 |
+
) {
|
| 63 |
+
return create(env, name, json::parse(json_stream));
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
std::unique_ptr<Context> Context::create(
|
| 67 |
+
Env& env,
|
| 68 |
+
const std::string& name,
|
| 69 |
+
const std::string& json_str
|
| 70 |
+
) {
|
| 71 |
+
return create(env, name, json::parse(json_str));
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
#ifdef QUALLA_STATIC
|
| 75 |
+
|
| 76 |
+
// This is a hack to make sure all core bits are linked in for the static build
|
| 77 |
+
|
| 78 |
+
extern void needFileLogger();
|
| 79 |
+
extern void needStdoutLogger();
|
| 80 |
+
extern void needBasicSampler();
|
| 81 |
+
extern void needBasicDialog();
|
| 82 |
+
extern void needKvShareDialog();
|
| 83 |
+
extern void needSpdDialog();
|
| 84 |
+
extern void needSsdDialog();
|
| 85 |
+
extern void needLadeDialog();
|
| 86 |
+
extern void needMultistreamDialog();
|
| 87 |
+
|
| 88 |
+
#ifdef QUALLA_ENGINE_QNN_HTP
|
| 89 |
+
extern void needQnnHtpEngine();
|
| 90 |
+
#endif
|
| 91 |
+
|
| 92 |
+
#ifdef QUALLA_ENGINE_QNN_CPU
|
| 93 |
+
extern void needQnnCpuEngine();
|
| 94 |
+
#endif
|
| 95 |
+
|
| 96 |
+
static OnLoad needs([]() {
|
| 97 |
+
needStdoutLogger();
|
| 98 |
+
needFileLogger();
|
| 99 |
+
needBasicDialog();
|
| 100 |
+
needBasicSampler();
|
| 101 |
+
needKvShareDialog();
|
| 102 |
+
needSpdDialog();
|
| 103 |
+
needSsdDialog();
|
| 104 |
+
needLadeDialog();
|
| 105 |
+
needMultistreamDialog();
|
| 106 |
+
|
| 107 |
+
#ifdef QUALLA_ENGINE_QNN_HTP
|
| 108 |
+
needQnnHtpEngine();
|
| 109 |
+
#endif
|
| 110 |
+
|
| 111 |
+
#ifdef QUALLA_ENGINE_QNN_CPU
|
| 112 |
+
needQnnCpuEngine();
|
| 113 |
+
#endif
|
| 114 |
+
});
|
| 115 |
+
|
| 116 |
+
#endif
|
| 117 |
+
|
| 118 |
+
} // namespace qualla
|
Genie/Genie/src/qualla/dialog.cpp
ADDED
|
@@ -0,0 +1,590 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//==============================================================================
|
| 2 |
+
//
|
| 3 |
+
// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
|
| 4 |
+
// All rights reserved.
|
| 5 |
+
// Confidential and Proprietary - Qualcomm Technologies, Inc.
|
| 6 |
+
//
|
| 7 |
+
//==============================================================================
|
| 8 |
+
|
| 9 |
+
#include <qualla/dialog.hpp>
|
| 10 |
+
#include <qualla/logger.hpp>
|
| 11 |
+
#include <qualla/detail/config.hpp>
|
| 12 |
+
#include <qualla/detail/timer.hpp>
|
| 13 |
+
#include <qualla/detail/sampler-utils.hpp>
|
| 14 |
+
|
| 15 |
+
#include <algorithm>
|
| 16 |
+
#include <functional>
|
| 17 |
+
#include <fstream>
|
| 18 |
+
#include <string>
|
| 19 |
+
#include <unordered_map>
|
| 20 |
+
#include <filesystem>
|
| 21 |
+
#include <iostream>
|
| 22 |
+
|
| 23 |
+
#include <fmt/format.h>
|
| 24 |
+
#include <fmt/ranges.h>
|
| 25 |
+
|
| 26 |
+
#define __INFO(__fmt, ...) _env->logger().post(Logger::INFO, fmt::format(__fmt, ##__VA_ARGS__))
|
| 27 |
+
#define __WARN(__fmt, ...) _env->logger().post(Logger::WARN, fmt::format(__fmt, ##__VA_ARGS__))
|
| 28 |
+
#define __ERROR(__fmt, ...) _env->logger().post(Logger::ERROR, fmt::format(__fmt, ##__VA_ARGS__))
|
| 29 |
+
#define __KPIS(__fmt, ...) \
|
| 30 |
+
_env->logger().post(Logger::KPIS, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
|
| 31 |
+
#define __DEBUG(__fmt, ...) \
|
| 32 |
+
_env->logger().post(Logger::DEBUG, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
|
| 33 |
+
#define __TRACE(__fmt, ...) \
|
| 34 |
+
_env->logger().post(Logger::TRACE, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
|
| 35 |
+
|
| 36 |
+
namespace fs = std::filesystem;
|
| 37 |
+
|
| 38 |
+
namespace qualla {
|
| 39 |
+
|
| 40 |
+
// Construct a Dialog from a JSON configuration.
//
// Wires up, in order: GPIO marker, context, prompt template config, stop
// sequences, tokenizer, sampler(s) and engine(s). "context", "tokenizer",
// "sampler" and "engine" are mandatory config keys.
//
// @param env   Shared environment (logger, paths); retained for the Dialog lifetime.
// @param name  Session/context name.
// @param json  Dialog configuration object.
// @throws std::runtime_error if a mandatory key is missing or an engine
//         does not support logit output.
Dialog::Dialog(std::shared_ptr<Env> env, const std::string& name, const qualla::json& json)
    : _env(env) {
    Timer start;

    __DEBUG("dialog-new: {} config {}", name, json.dump());

    using qc = qualla::Config;

    // Create GpioMarker and reset the gpio status to low
    const qualla::json& gpio_conf = qc::optional<qualla::json>(json, "gpio", {});
    _gpio_marker = GpioMarker::create(gpio_conf);

    _gpio_marker->set();

    // Create the context first
    _ctx = Context::create(*_env, name, qc::mandatory<qualla::json>(json, "context"));

    // Parse prompt config
    const qualla::json& pmt_conf = qc::optional<qualla::json>(json, "prompt", {});
    _prompt_type = qc::optional<std::string>(pmt_conf, "type", "llama2");
    _sys_tags    = qc::optional<std::vector<std::string>>(pmt_conf, "sys-tags", {"", ""});
    _inst_tags   = qc::optional<std::vector<std::string>>(pmt_conf, "inst-tags", {"", ""});
    _role_tags   = qc::optional<std::vector<std::string>>(pmt_conf, "role-tags", {"", ""});
    _sys_prompt  = qc::optional<std::string>(pmt_conf, "sys-prompt", "");

    const std::vector<std::string>& stop_sequence =
        qc::optional<std::vector<std::string>>(pmt_conf, "stop-sequence", {});
    _stop_sequence = SequenceMatchTrie(stop_sequence);

    // Create Tokenizer
    // TODO: auto-detect / validate n_vocab with tokenizer vocab
    fs::path tok_path = _env->path().models / qc::mandatory<std::string>(json, "tokenizer");
    _tokenizer = Tokenizer::create(*_ctx, tok_path);

    // Create Sampler(s). The config may be one object or an array of
    // role-tagged objects; "primary" is the default role.
    auto add_sampler = [&](const qualla::json& j) {
        std::string role = qc::optional<std::string>(j, "role", "primary");
        _sampler[role] = Sampler::create(*_ctx, j);
    };

    const qualla::json& sam_conf = qc::mandatory<qualla::json>(json, "sampler");
    if (sam_conf.is_array()) {
        // const-ref iteration: avoid copying each json element per pass
        for (const auto& sc : sam_conf) {
            add_sampler(sc);
        }
    } else {
        add_sampler(sam_conf);
    }

    // Create Engine(s), same single-object-or-array convention as samplers.
    auto add_engine = [&](const qualla::json& j) {
        std::string role = qc::optional<std::string>(j, "role", "primary");

        _engine[role] = Engine::create(*_ctx, j);

        using FF = Engine::Feature::Flags;

        // Dialogs sample from logits, so an engine without logit output is unusable.
        if (!_engine[role]->supports(FF::OUTPUT_LOGITS))
            throw std::runtime_error("the engine must output Logits");
    };

    const qualla::json& eng_conf = qc::mandatory<qualla::json>(json, "engine");
    if (eng_conf.is_array()) {
        // const-ref iteration: avoid copying each json element per pass
        for (const auto& ec : eng_conf) {
            add_engine(ec);
        }
    } else {
        add_engine(eng_conf);
    }

    // Store input type (token, embedding, etc) from the engine.
    // This assumes multi-engine usecases use matching input types.
    m_inputType = _engine.begin()->second->getInputType();

    _kpis.init.update(start.elapsed_usec());
}
|
| 127 |
+
|
| 128 |
+
// Destructor kept out-of-line — NOTE(review): presumably so members' complete
// types are visible here; confirm before moving it into the header.
Dialog::~Dialog() = default;
|
| 129 |
+
|
| 130 |
+
// Discarding sink for string-callback queries: ignore the text, signal stop.
static bool __no_response_query(const std::string&, Sentence::Code) { return false; }
|
| 133 |
+
|
| 134 |
+
// Discarding sink for token-callback queries: ignore the tokens, signal stop.
static bool __no_response_token(const int32_t*, const uint32_t, Sentence::Code) { return false; }
|
| 137 |
+
|
| 138 |
+
// Discarding sink used when a query should be processed but emit nothing.
static bool __no_response(const std::string&, Sentence::Code) { return false; }
|
| 141 |
+
|
| 142 |
+
void Dialog::getTopK(std::vector<float>& logits, std::vector<std::vector<int32_t>>& tokens, size_t topK, float pThreshold, Dialog::Callback callback) {
|
| 143 |
+
|
| 144 |
+
auto& sampler = *_sampler["primary"];
|
| 145 |
+
|
| 146 |
+
// Sample top-k logits but with a minimum probability threshold
|
| 147 |
+
#if defined(__GNUC__) && !defined(__clang__)
|
| 148 |
+
std::span<float> indexed_logits_span(logits);
|
| 149 |
+
IndexedLogits indexed_logits(indexed_logits_span, sampler.rng());
|
| 150 |
+
#else
|
| 151 |
+
IndexedLogits indexed_logits(std::span{logits.data(),logits.size()}, sampler.rng());
|
| 152 |
+
#endif
|
| 153 |
+
indexed_logits.softmax();
|
| 154 |
+
indexed_logits.topK(topK);
|
| 155 |
+
|
| 156 |
+
for (int i = 0; i < topK; i++) {
|
| 157 |
+
|
| 158 |
+
_last_tok = indexed_logits.indices[i];
|
| 159 |
+
|
| 160 |
+
// Only sample tokens above some probability threshold
|
| 161 |
+
// TODO: Modify sampling algorithm as necessary
|
| 162 |
+
if (indexed_logits.probs[i] < pThreshold) {
|
| 163 |
+
break;
|
| 164 |
+
} else if (_ctx->is_eos(_last_tok)) {
|
| 165 |
+
callback("", Sentence::CONTINUE);
|
| 166 |
+
} else {
|
| 167 |
+
tokens.push_back({_last_tok});
|
| 168 |
+
}
|
| 169 |
+
}
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
// Run one text query through the dialog.
//
// Builds the full prompt (system prompt, instruction/role tags, BOS/EOS
// bookkeeping) according to _prompt_type, tokenizes it, and forwards the
// tokens to process(). For partial sentences (BEGIN/CONTINUE) the output is
// suppressed; generation output is only emitted on COMPLETE/END.
//
// @param str      Raw user text for this sentence chunk.
// @param scode    Position of this chunk within the sentence.
// @param callback Receives generated text on COMPLETE/END chunks.
// @return false if processing failed.
bool Dialog::query(const std::string& str, Sentence::Code scode, Dialog::Callback callback) {
    std::vector<int32_t> p_vec; // prompt tokens
    std::string p_str; // prompt string

    p_vec.reserve(1024);

    if (scode == Sentence::COMPLETE || scode == Sentence::BEGIN) {
        // Reset prompt/gen counts for new query
        _n_prompt = 0;
        _n_generated = 0;
        _n_previous_prompt = 0;
        _n_previous_generated = 0;

        // Re-feed the last generated token (it was sampled but never sent to
        // the engine) unless it was EOS.
        if (_last_tok >= 0 && !_ctx->is_eos(_last_tok)) p_vec.push_back(_last_tok);

        p_str = _inst_tags[0];

        if (!_n_queries) {
            // First query. Prepend sys-prompt.
            p_str += _sys_tags[0] + _sys_prompt + _sys_tags[1];
        } else {
            // Add EOS explicitly if the last query was aborted prematurely.
            if (_ctx->eos_tok() >= 0) p_vec.push_back(_ctx->eos_tok());
        }

        // Add BOS
        if (_ctx->bos_tok() >= 0) {
            p_vec.push_back(_ctx->bos_tok());
        }
    }

    // FIXME: make this more generic
    // llama3-style templates wrap the user text in role/sys tags; everything
    // else appends the raw text.
    if (_prompt_type == "llama3") {
        p_str += _sys_tags[0] + _role_tags[1] + _sys_tags[1] + str + _inst_tags[2];
    } else {
        p_str += str;
    }

    // Close the instruction / open the assistant turn once the sentence ends.
    if (scode == Sentence::COMPLETE || scode == Sentence::END) {
        if (_prompt_type == "llama3") {
            p_str += _sys_tags[0] + _role_tags[2] + _sys_tags[1];
        } else {
            p_str += _inst_tags[1];
        }
    }

    // Lazy debug log: the json dump is only built if DEBUG is enabled.
    _env->logger().post(Logger::DEBUG, [&]() {
        qualla::json j{{"string", str}, {"prompt", p_str}};
        return fmt::format("dialog-query: {} {}", _ctx->name(), j.dump());
    });

    _n_queries++;

    // Appends the encoded prompt text after any BOS/EOS tokens added above.
    _tokenizer->encode(p_str, p_vec);

    __DEBUG("dialog-tokens: {} {}", _ctx->name(), p_vec);
    __DEBUG("dialog-text: \"{}\"", p_str);

    if (scode == Sentence::COMPLETE || scode == Sentence::END) {
        // Detect stop sequences here
        if (!_stop_sequence.empty()) {
            _stop_sequence.reset();
            return process(p_vec, [&](const std::string& str, Sentence::Code c) {
                // Check for stop sequence and end inference when stop sequence is found
                if (_stop_sequence.process_next_string(str)) {
                    callback(str, c); // Emit sequences until match is complete
                    return false;
                }

                // Else, return normal callback function
                return callback(str, c);
            });
        }

        return process(p_vec, callback);
    }

    // Mid-sentence chunk: fill the KV cache but emit nothing.
    return process(p_vec, __no_response);
}
|
| 252 |
+
|
| 253 |
+
bool Dialog::query(const std::vector<uint32_t>& input, Sentence::Code scode, qualla::DialogCallback& callback) {
|
| 254 |
+
std::vector<int32_t> p_vec; // prompt tokens
|
| 255 |
+
p_vec.reserve(1024);
|
| 256 |
+
|
| 257 |
+
if (scode == Sentence::COMPLETE || scode == Sentence::BEGIN) {
|
| 258 |
+
// Reset prompt/gen counts for new query
|
| 259 |
+
_n_prompt = 0;
|
| 260 |
+
_n_generated = 0;
|
| 261 |
+
_n_previous_prompt = 0;
|
| 262 |
+
_n_previous_generated = 0;
|
| 263 |
+
|
| 264 |
+
if (_last_tok >= 0)
|
| 265 |
+
p_vec.push_back(_last_tok);
|
| 266 |
+
|
| 267 |
+
// Add EOS explicitly if the last query was aborted prematurely.
|
| 268 |
+
if (_n_queries && _last_tok != _ctx->eos_tok()) {
|
| 269 |
+
p_vec.push_back(_ctx->eos_tok());
|
| 270 |
+
}
|
| 271 |
+
// Add BOS
|
| 272 |
+
if (_ctx->bos_tok() >= 0) {
|
| 273 |
+
p_vec.push_back(_ctx->bos_tok());
|
| 274 |
+
}
|
| 275 |
+
}
|
| 276 |
+
|
| 277 |
+
p_vec.insert(p_vec.end(), input.begin(), input.end());
|
| 278 |
+
__DEBUG("dialog-tokens: {} {}", _ctx->name(), p_vec);
|
| 279 |
+
|
| 280 |
+
_n_queries++;
|
| 281 |
+
|
| 282 |
+
if (scode == Sentence::COMPLETE || scode == Sentence::END) {
|
| 283 |
+
return process(p_vec, callback);
|
| 284 |
+
}
|
| 285 |
+
|
| 286 |
+
DialogCallback callback_return_token(QUALLA_CALLBACK_TYPE_TOKEN);
|
| 287 |
+
*(callback_return_token.getTokenCbFunc()) = __no_response_token;
|
| 288 |
+
return process(p_vec, callback_return_token);
|
| 289 |
+
}
|
| 290 |
+
|
| 291 |
+
bool Dialog::query(
|
| 292 |
+
std::vector<uint8_t>& embedding_vectors,
|
| 293 |
+
Sentence::Code scode,
|
| 294 |
+
T2ECallback t2eCallback,
|
| 295 |
+
Dialog::Callback callback
|
| 296 |
+
) {
|
| 297 |
+
_n_queries++;
|
| 298 |
+
if (scode == Sentence::COMPLETE || scode == Sentence::END) {
|
| 299 |
+
return process(embedding_vectors, t2eCallback, callback);
|
| 300 |
+
}
|
| 301 |
+
// Only process, no output
|
| 302 |
+
return process(embedding_vectors, t2eCallback, [&](const std::string&, Sentence::Code) {
|
| 303 |
+
return false;
|
| 304 |
+
});
|
| 305 |
+
}
|
| 306 |
+
|
| 307 |
+
bool Dialog::prime(const std::string& str) {
|
| 308 |
+
bool r = query(str, Sentence::COMPLETE, __no_response);
|
| 309 |
+
|
| 310 |
+
// End with EOS as we want the primer to be self-contained
|
| 311 |
+
_last_tok = _ctx->eos_tok();
|
| 312 |
+
|
| 313 |
+
return r;
|
| 314 |
+
}
|
| 315 |
+
|
| 316 |
+
bool Dialog::save(const std::string& o_name) {
|
| 317 |
+
Timer start;
|
| 318 |
+
|
| 319 |
+
// Save using session name unless override is provided
|
| 320 |
+
std::string name = o_name.empty() ? _ctx->name() : o_name;
|
| 321 |
+
fs::path save_path = name;
|
| 322 |
+
|
| 323 |
+
if (!_n_past) {
|
| 324 |
+
__ERROR("dialog-save: {} : nothing to save yet", name);
|
| 325 |
+
return false;
|
| 326 |
+
}
|
| 327 |
+
|
| 328 |
+
__INFO("dialog-save: saving as {} {}", name, save_path.string());
|
| 329 |
+
|
| 330 |
+
if (!fs::exists(save_path) && !fs::create_directories(save_path)) {
|
| 331 |
+
__ERROR("dialog-save: {} : failed to create cache directory", name);
|
| 332 |
+
return false;
|
| 333 |
+
}
|
| 334 |
+
|
| 335 |
+
// Save Dialog state
|
| 336 |
+
qualla::json j{
|
| 337 |
+
{"n-past", _n_past},
|
| 338 |
+
{"n-prompt", _n_prompt},
|
| 339 |
+
{"n-generated", _n_generated},
|
| 340 |
+
{"n-queries", _n_queries},
|
| 341 |
+
{"last-tok", _last_tok}
|
| 342 |
+
};
|
| 343 |
+
{
|
| 344 |
+
fs::path p = save_path / "dialog.json";
|
| 345 |
+
std::ofstream f(p);
|
| 346 |
+
f << j;
|
| 347 |
+
}
|
| 348 |
+
|
| 349 |
+
// Save Engines (mandatory)
|
| 350 |
+
for (auto& e : _engine) {
|
| 351 |
+
if (!e.second->save(name)) {
|
| 352 |
+
__ERROR("dialog-save: {} : unable to save {} engine", name, e.first);
|
| 353 |
+
return false;
|
| 354 |
+
}
|
| 355 |
+
}
|
| 356 |
+
|
| 357 |
+
// Save Samplers (optional)
|
| 358 |
+
for (auto& s : _sampler) {
|
| 359 |
+
if (!s.second->save(name)) {
|
| 360 |
+
__WARN("dialog-save: {} : unable to save {} sampler", name, s.first);
|
| 361 |
+
}
|
| 362 |
+
}
|
| 363 |
+
|
| 364 |
+
_kpis.save.update(start.elapsed_usec());
|
| 365 |
+
|
| 366 |
+
return true;
|
| 367 |
+
}
|
| 368 |
+
|
| 369 |
+
bool Dialog::restore(const std::string& o_name) {
|
| 370 |
+
Timer start;
|
| 371 |
+
|
| 372 |
+
// Restore using session name unless override is provided
|
| 373 |
+
std::string name = o_name.empty() ? _ctx->name() : o_name;
|
| 374 |
+
fs::path restore_path = name;
|
| 375 |
+
|
| 376 |
+
__INFO("dialog-restore: restoring from {} {}", name, restore_path.string());
|
| 377 |
+
|
| 378 |
+
// Try to restore the Dialog state (optional)
|
| 379 |
+
// If this fails we reset everything and try to restore the engine.
|
| 380 |
+
qualla::json j{};
|
| 381 |
+
{
|
| 382 |
+
fs::path p = restore_path / "dialog.json";
|
| 383 |
+
if (fs::exists(p)) {
|
| 384 |
+
std::ifstream f(p);
|
| 385 |
+
j = qualla::json::parse(f);
|
| 386 |
+
} else {
|
| 387 |
+
__DEBUG("dialog-restore: {} : internal state not restored", name);
|
| 388 |
+
}
|
| 389 |
+
}
|
| 390 |
+
|
| 391 |
+
using qc = qualla::Config;
|
| 392 |
+
_n_past = qc::optional<uint32_t>(j, "n-past", 0);
|
| 393 |
+
_n_prompt = qc::optional<uint32_t>(j, "n-prompt", 0);
|
| 394 |
+
_n_generated = qc::optional<uint32_t>(j, "n-generated", 0);
|
| 395 |
+
_n_queries = qc::optional<uint32_t>(j, "n-queries", 1);
|
| 396 |
+
_last_tok = qc::optional<int32_t>(j, "last-tok", _ctx->eos_tok());
|
| 397 |
+
|
| 398 |
+
// Restore Engines (mandatory)
|
| 399 |
+
for (auto& e : _engine) {
|
| 400 |
+
uint32_t n = e.second->restore(name);
|
| 401 |
+
if (!n) {
|
| 402 |
+
__ERROR("dialog-restore: {} : unable to restore {} engine", name, e.first);
|
| 403 |
+
return false;
|
| 404 |
+
}
|
| 405 |
+
|
| 406 |
+
// Restore n_past from the engine state
|
| 407 |
+
if (_n_past && n != _n_past) {
|
| 408 |
+
__WARN("dialog-restore: {} : n-past mismatch : {} engine {} intern {}",
|
| 409 |
+
name,
|
| 410 |
+
e.first,
|
| 411 |
+
_n_past,
|
| 412 |
+
n);
|
| 413 |
+
// Keep the smaller number
|
| 414 |
+
_n_past = std::min(n, _n_past);
|
| 415 |
+
} else
|
| 416 |
+
_n_past = n;
|
| 417 |
+
}
|
| 418 |
+
|
| 419 |
+
// Restore Samplers (optional)
|
| 420 |
+
for (auto& s : _sampler) {
|
| 421 |
+
if (!s.second->restore(name)) {
|
| 422 |
+
__WARN("dialog-restore: {} : unable to restore {} sampler", name, s.first);
|
| 423 |
+
}
|
| 424 |
+
}
|
| 425 |
+
|
| 426 |
+
_kpis.reset();
|
| 427 |
+
_kpis.restore.update(start.elapsed_usec());
|
| 428 |
+
|
| 429 |
+
return true;
|
| 430 |
+
}
|
| 431 |
+
|
| 432 |
+
void Dialog::reset() {
|
| 433 |
+
__INFO("dialog-reset: {}", _ctx->name());
|
| 434 |
+
|
| 435 |
+
_n_past = 0;
|
| 436 |
+
_n_prompt = 0;
|
| 437 |
+
_n_generated = 0;
|
| 438 |
+
_n_queries = 0;
|
| 439 |
+
_last_tok = -1;
|
| 440 |
+
_n_previous_prompt = 0;
|
| 441 |
+
_n_previous_generated = 0;
|
| 442 |
+
|
| 443 |
+
_kpis.reset();
|
| 444 |
+
|
| 445 |
+
// Reset Engines and Samplers
|
| 446 |
+
for (auto& e : _engine)
|
| 447 |
+
e.second->reset();
|
| 448 |
+
for (auto& s : _sampler)
|
| 449 |
+
s.second->reset();
|
| 450 |
+
|
| 451 |
+
State::clear();
|
| 452 |
+
}
|
| 453 |
+
|
| 454 |
+
// Dialog KPIs helpers
|
| 455 |
+
|
| 456 |
+
// Get latest KPIs
|
| 457 |
+
// Get latest KPIs, refreshing the tokens-per-second figures from the most
// recent prompt/generate timings before returning.
Dialog::KPIs& Dialog::kpis() {
    // Update TPS. NOTE(review): last_usec / count may be integer division
    // depending on the timer field type — confirm if sub-usec precision matters.
    if (_n_prompt) {
        float per_tok = _kpis.prompt.last_usec / _n_prompt;
        _kpis.tps.n_prompt = _n_prompt;
        // Guard against a zero per-token time (degenerates to 1 tok/s).
        _kpis.tps.prompt = 1000000.0 / (per_tok ? per_tok : 1000000.0);
    }

    if (_n_generated) {
        float per_tok = _kpis.generate.last_usec / _n_generated;
        _kpis.tps.n_generate = _n_generated;
        _kpis.tps.generate = 1000000.0 / (per_tok ? per_tok : 1000000.0);
    }

    // We could synthesize more KPIs from other layers (engine, sampler, etc)
    return _kpis;
}
|
| 474 |
+
|
| 475 |
+
// Render every KPI bucket plus the derived tokens/sec figures as one string,
// with `sep` inserted between buckets.
std::string Dialog::KPIs::dump(std::string_view sep) const {
    return fmt::format(
        "init:[{}]{}prompt:[{}]{}generate:[{}]{}save:[{}]{}restore:[{}]{} tps-prompt:{:.2f} tps-generate:{:.2f}",
        init.dump(),
        sep,
        prompt.dump(),
        sep,
        generate.dump(),
        sep,
        save.dump(),
        sep,
        restore.dump(),
        sep,
        tps.prompt,
        tps.generate
    );
}
|
| 492 |
+
|
| 493 |
+
// Zero every timing bucket and both derived tokens-per-second rates.
void Dialog::KPIs::reset() {
    init.reset();
    prompt.reset();
    generate.reset();
    save.reset();
    restore.reset();

    tps.prompt   = 0.0;
    tps.generate = 0.0;
}
|
| 502 |
+
|
| 503 |
+
// Create API
|
| 504 |
+
|
| 505 |
+
// Dialog registry : type string + creator function
|
| 506 |
+
using Registry = std::unordered_map<std::string, Dialog::Creator>;
|
| 507 |
+
static std::unique_ptr<Registry> registry;
|
| 508 |
+
|
| 509 |
+
void Dialog::__register(const std::string& type, Creator func) {
|
| 510 |
+
if (!registry) registry = std::make_unique<Registry>();
|
| 511 |
+
|
| 512 |
+
Registry& r = *registry;
|
| 513 |
+
|
| 514 |
+
|
| 515 |
+
r[type] = func;
|
| 516 |
+
}
|
| 517 |
+
|
| 518 |
+
std::unique_ptr<Dialog> Dialog::create(
|
| 519 |
+
std::shared_ptr<Env> env,
|
| 520 |
+
const std::string& name,
|
| 521 |
+
const qualla::json& conf
|
| 522 |
+
) {
|
| 523 |
+
|
| 524 |
+
using qc = qualla::Config;
|
| 525 |
+
std::string type = qc::optional<std::string>(conf, "type", "basic");
|
| 526 |
+
|
| 527 |
+
if (!registry) throw std::runtime_error(type + ": dialog not found");
|
| 528 |
+
|
| 529 |
+
Registry& r = *registry;
|
| 530 |
+
|
| 531 |
+
if (!r.contains(type)) throw std::runtime_error(type + ": dialog not found");
|
| 532 |
+
|
| 533 |
+
if (!r.contains(type)) {
|
| 534 |
+
throw std::runtime_error(type + ": dialog not found");
|
| 535 |
+
}
|
| 536 |
+
|
| 537 |
+
return std::unique_ptr<Dialog>(r[type](env, name, conf));
|
| 538 |
+
}
|
| 539 |
+
|
| 540 |
+
std::unique_ptr<Dialog> Dialog::create(
|
| 541 |
+
std::shared_ptr<Env> env,
|
| 542 |
+
const std::string& name,
|
| 543 |
+
std::istream& json_stream
|
| 544 |
+
) {
|
| 545 |
+
|
| 546 |
+
return create(env, name, json::parse(json_stream));
|
| 547 |
+
}
|
| 548 |
+
|
| 549 |
+
std::unique_ptr<Dialog> Dialog::create(
|
| 550 |
+
std::shared_ptr<Env> env,
|
| 551 |
+
const std::string& name,
|
| 552 |
+
const fs::path& json_path
|
| 553 |
+
) {
|
| 554 |
+
|
| 555 |
+
if (!fs::exists(json_path))
|
| 556 |
+
throw std::runtime_error(json_path.string() + ": file does not exist");
|
| 557 |
+
std::ifstream ifs(json_path);
|
| 558 |
+
return create(env, name, ifs);
|
| 559 |
+
}
|
| 560 |
+
|
| 561 |
+
// List all registered dialog types. "basic" is always included exactly once.
//
// @return registered type names ("basic" appended if not already registered).
std::vector<std::string> Dialog::list() {
    std::vector<std::string> v;
    if (!registry) return v;

    Registry& r = *registry;

    // const-ref: the old `for (auto k : r)` copied a pair<string, Creator>
    // (string + std::function) on every iteration.
    for (const auto& entry : r)
        v.push_back(entry.first);

    // Default type, always available — but only append it when it wasn't
    // registered, otherwise it appeared twice in the listing.
    if (std::find(v.begin(), v.end(), "basic") == v.end())
        v.push_back("basic");

    return v;
}
|
| 572 |
+
|
| 573 |
+
bool Dialog::applyLoraAdapter(std::string lora_adapter_name, std::string engine_role) {
|
| 574 |
+
auto& engine = *_engine[engine_role];
|
| 575 |
+
if (!engine.applyLoraAdapter(lora_adapter_name)) {
|
| 576 |
+
__WARN("dialog-applyLoraAdapter: failed for {}", lora_adapter_name);
|
| 577 |
+
return false;
|
| 578 |
+
}
|
| 579 |
+
return true;
|
| 580 |
+
}
|
| 581 |
+
bool Dialog::applyLoraStrength(std::string tensor_name, float tensor_val, std::string engine_role) {
|
| 582 |
+
auto& engine = *_engine[engine_role];
|
| 583 |
+
if (!engine.applyLoraStrength(tensor_name, tensor_val)) {
|
| 584 |
+
__WARN("dialog-applyLoraStrength: failed for {}", tensor_name);
|
| 585 |
+
return false;
|
| 586 |
+
}
|
| 587 |
+
return true;
|
| 588 |
+
}
|
| 589 |
+
|
| 590 |
+
} // namespace qualla
|
Genie/Genie/src/qualla/dialogs/basic.cpp
ADDED
|
@@ -0,0 +1,421 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//==============================================================================
|
| 2 |
+
//
|
| 3 |
+
// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
|
| 4 |
+
// All rights reserved.
|
| 5 |
+
// Confidential and Proprietary - Qualcomm Technologies, Inc.
|
| 6 |
+
//
|
| 7 |
+
//==============================================================================
|
| 8 |
+
|
| 9 |
+
#include <qualla/dialog.hpp>
|
| 10 |
+
#include <qualla/logger.hpp>
|
| 11 |
+
#include <qualla/detail/config.hpp>
|
| 12 |
+
#include <qualla/detail/timer.hpp>
|
| 13 |
+
#include <qualla/detail/onload.hpp>
|
| 14 |
+
#include <qualla/detail/basic-dialog.hpp>
|
| 15 |
+
|
| 16 |
+
#include <functional>
|
| 17 |
+
#include <filesystem>
|
| 18 |
+
#include <string>
|
| 19 |
+
|
| 20 |
+
#include <fmt/format.h>
|
| 21 |
+
#include <fmt/ranges.h>
|
| 22 |
+
|
| 23 |
+
namespace fs = std::filesystem;
|
| 24 |
+
|
| 25 |
+
#define __INFO(__fmt, ...) _env->logger().post(Logger::INFO, fmt::format(__fmt, ##__VA_ARGS__))
|
| 26 |
+
#define __WARN(__fmt, ...) _env->logger().post(Logger::WARN, fmt::format(__fmt, ##__VA_ARGS__))
|
| 27 |
+
#define __ERROR(__fmt, ...) _env->logger().post(Logger::ERROR, fmt::format(__fmt, ##__VA_ARGS__))
|
| 28 |
+
#define __KPIS(__fmt, ...) \
|
| 29 |
+
_env->logger().post(Logger::KPIS, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
|
| 30 |
+
#define __DEBUG(__fmt, ...) \
|
| 31 |
+
_env->logger().post(Logger::DEBUG, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
|
| 32 |
+
#define __TRACE(__fmt, ...) \
|
| 33 |
+
_env->logger().post(Logger::TRACE, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
|
| 34 |
+
|
| 35 |
+
namespace qualla {
|
| 36 |
+
|
| 37 |
+
// A BasicDialog drives exactly one engine, which must be registered under
// the "primary" role; a missing primary engine is a fatal configuration error.
BasicDialog::BasicDialog(std::shared_ptr<Env> env, const std::string& name, const json& conf) : Dialog(env, name, conf) {
    if (_engine.contains("primary")) return;

    State::fatal("\"primary\" engine not present in config!");
}
|
| 43 |
+
|
| 44 |
+
// Autoregressive generation loop (string-callback variant): repeatedly feed
// the last sampled token (tokens holds exactly one token on entry) to the
// engine, sample the next token from the logits, and emit the decoded text
// via `callback` until EOS, cancellation, context exhaustion, or the
// callback asks to stop.
//
// @param tokens   In/out: single-token buffer; overwritten with each new sample.
// @param logits   Scratch logits buffer, filled by the engine each step.
// @param callback Receives decoded text (CONTINUE) and the final END marker.
// @return false only if the engine itself failed (via Dialog::abort);
//         cancellation and context exhaustion still return true.
bool BasicDialog::processFollowOnGeneration(std::vector<int32_t>& tokens, std::vector<float>& logits, Dialog::Callback callback){

    auto& sampler = *_sampler["primary"];
    auto& engine = *_engine["primary"];

    while (true) {
        if (State::canceled()) {
            callback("", Sentence::END);
            break;
        }
        // This condition is valid for both tokens and embedding
        if (_n_past + 1 > _ctx->size()) {
            __WARN("Context limit exceeded ({} + 1 > {})", _n_past, _ctx->size());
            callback("", Sentence::END);
            break;
        }
        if (m_inputType == InputType::TOKENS) {
            if (!engine.process(tokens, logits))
                return Dialog::abort("engine processing failed", callback);
        } else if(m_inputType == InputType::EMBEDDINGS) {
            // Convert tokens to embedding for the processing in the engine.
            // NOTE(review): assumes m_t2eCallback fills exactly embedBufSize
            // bytes per token — confirm against the callback contract.
            auto embedBufSize = engine.getEmbeddingBufferSize();
            std::vector<uint8_t> embedding;
            for(auto &token: tokens){
                std::vector<uint8_t> curTokenEmbedding(embedBufSize,0);
                m_t2eCallback(token, curTokenEmbedding.data(), embedBufSize);
                embedding.insert(embedding.end(), curTokenEmbedding.begin(), curTokenEmbedding.end());
            }
            if (!engine.process(embedding, {}, logits))
                return Dialog::abort("engine processing failed", callback);
        }
        else{
            return Dialog::abort("No valid Input Type is used", callback);
        }
        // Sample the next token and remember it for cross-query bookkeeping.
        tokens[0] = _last_tok = sampler.process(logits);

        _n_past++;
        _n_generated++;

        if (!engine.updateKV(_n_past)) return Dialog::abort("context size exceeded", callback);

        if (_ctx->is_eos(_last_tok)) {
            callback("", Sentence::END);
            break;
        }

        // Callback returning false means the consumer wants to stop early.
        if (!callback(_tokenizer->decode(tokens), Sentence::CONTINUE)) break;
    }

    return true;
}
|
| 95 |
+
|
| 96 |
+
// Process a full prompt (string-callback variant): run the prompt through the
// primary engine in one pass, sample the first generated token, then hand off
// to processFollowOnGeneration() for the rest of the response. Updates
// prompt/generate KPIs and toggles the GPIO marker around each phase.
//
// @param tokens   Prompt tokens; reused (resized to 1) as the generation buffer.
// @param callback Receives BEGIN with the first decoded token, then
//                 CONTINUE/END from the follow-on loop.
// @return false on a prior or new failure; true otherwise (including the
//         benign "context full" and "EOS immediately" early-outs).
bool BasicDialog::process(std::vector<int32_t>& tokens, Dialog::Callback callback) {
    // Check for prev failures and bail out early
    if (State::failed()) return false;

    Timer start;

    // This overload only supports token-input engines.
    if(m_inputType != InputType::TOKENS) {
        __ERROR("Input type for model is not tokens.");
        return false;
    }

    _gpio_marker->set();

    // Vector for storing logits.
    // Allocated & filled by the engine.
    std::vector<float> logits;

    State::clear();

    auto& sampler = *_sampler["primary"];
    auto& engine = *_engine["primary"];

    using FF = Engine::Feature::Flags;
    if (engine.supports(FF::DYNAMIC_LOAD)) engine.load();

    // Refuse prompts that would not fit in the remaining context window.
    if (_n_past + tokens.size() > _ctx->size()) {
        __WARN("Context limit exceeded ({} + {} > {})", _n_past, tokens.size(), _ctx->size());
        callback("", Sentence::END);
        return true;
    }

    if (!engine.process(tokens, logits, false))
        return Dialog::abort("engine prompt processing failed", callback);

    _n_prompt += tokens.size();
    _n_past += tokens.size();

    if (!engine.updateKV(_n_past)) return Dialog::abort("context size exceeded", callback);

    // Sample the first response token; the prompt buffer now becomes the
    // single-token generation buffer.
    tokens[0] = _last_tok = sampler.process(logits);
    tokens.resize(1);

    _n_generated++;

    _gpio_marker->set();

    _kpis.prompt.update(start.elapsed_usec());

    // Log latest KPIs
    _env->logger().post(Logger::KPIS, kpis().dump(" "));

    // Re-arm the timer: everything below is attributed to generation.
    start.reset();

    if (_ctx->is_eos(_last_tok)) {
        callback("", Sentence::END);
        return true;
    }

    // BEGIN carries the first decoded token; a false return means the
    // consumer does not want any more output.
    if (!callback(_tokenizer->decode(tokens), Sentence::BEGIN)) return true;

    State::busy(true);

    processFollowOnGeneration(tokens, logits, callback);

    State::busy(false);

    _gpio_marker->set();
    _gpio_marker->reset();

    _kpis.generate.update(start.elapsed_usec());

    // Log latest KPIs in a single line
    _env->logger().post(Logger::KPIS, kpis().dump(" "));

    return !State::failed();
}
|
| 172 |
+
|
| 173 |
+
// Autoregressive generation loop (DialogCallback variant): identical control
// flow to the string-callback overload above, but emits raw tokens through
// callback.callBack() instead of decoded text.
//
// NOTE(review): `callback` is taken by value (the string overload takes its
// callback by value too, but DialogCallback may be heavier) — confirm whether
// a reference was intended.
//
// @param tokens   In/out: single-token buffer; overwritten with each new sample.
// @param logits   Scratch logits buffer, filled by the engine each step.
// @param callback Token-level callback; END is signalled with a null pointer.
// @return false only if the engine itself failed (via Dialog::abort).
bool BasicDialog::processFollowOnGeneration(std::vector<int32_t>& tokens, std::vector<float>& logits, qualla::DialogCallback callback){

    auto& sampler = *_sampler["primary"];
    auto& engine = *_engine["primary"];

    while (true) {
        if (State::canceled()) {
            callback.callBack(nullptr, 0, Sentence::END, tokenizer());
            break;
        }
        // This condition is valid for both tokens and embedding
        if (_n_past + 1 > _ctx->size()) {
            __WARN("Context limit exceeded ({} + 1 > {})", _n_past, _ctx->size());
            callback.callBack(nullptr, 0, Sentence::END, tokenizer());
            break;
        }
        if (m_inputType == InputType::TOKENS) {
            if (!engine.process(tokens, logits))
                return Dialog::abort("engine processing failed", callback);
        } else if(m_inputType == InputType::EMBEDDINGS) {
            // Convert tokens to embedding for the processing in the engine.
            // NOTE(review): assumes m_t2eCallback fills exactly embedBufSize
            // bytes per token — confirm against the callback contract.
            auto embedBufSize = engine.getEmbeddingBufferSize();
            std::vector<uint8_t> embedding;
            for(auto &token: tokens){
                std::vector<uint8_t> curTokenEmbedding(embedBufSize,0);
                m_t2eCallback(token, curTokenEmbedding.data(), embedBufSize);
                embedding.insert(embedding.end(), curTokenEmbedding.begin(), curTokenEmbedding.end());
            }
            if (!engine.process(embedding, {}, logits))
                return Dialog::abort("engine processing failed", callback);
        }
        else{
            return Dialog::abort("No valid Input Type is used", callback);
        }
        // Sample the next token and remember it for cross-query bookkeeping.
        tokens[0] = _last_tok = sampler.process(logits);

        _n_past++;
        _n_generated++;

        if (!engine.updateKV(_n_past)) return Dialog::abort("context size exceeded", callback);

        if (_ctx->is_eos(_last_tok)) {
            callback.callBack(nullptr, 0, Sentence::END, tokenizer());
            break;
        }

        // Callback returning false means the consumer wants to stop early.
        if (!callback.callBack(tokens.data(), tokens.size(), Sentence::CONTINUE, tokenizer())) break;
    }

    return true;
}
|
| 224 |
+
|
| 225 |
+
// Token-based prompt processing + generation entry point (DialogCallback
// variant). Runs the prompt through the primary engine, samples the first
// token, then delegates auto-regressive decode to
// processFollowOnGeneration(). KPI timers split prompt vs. generate phases;
// the GPIO marker toggles bracket the phases (presumably for external
// latency measurement — confirm with hardware setup).
// Returns false on a pre-existing failure or wrong input type; true
// otherwise unless the generation loop recorded an error state.
bool BasicDialog::process(std::vector<int32_t>& tokens, qualla::DialogCallback callback) {
    // Check for prev failures and bail out early
    if (State::failed()) return false;

    Timer start;

    if(m_inputType != InputType::TOKENS) {
        __ERROR("Input type for model is not tokens.");
        return false;
    }

    _gpio_marker->set();

    // Vector for storing logits.
    // Allocated & filled by the engine.
    std::vector<float> logits;

    State::clear();

    auto& sampler = *_sampler["primary"];
    auto& engine = *_engine["primary"];

    // Lazily load the engine if it supports on-demand loading.
    using FF = Engine::Feature::Flags;
    if (engine.supports(FF::DYNAMIC_LOAD)) engine.load();

    // Refuse prompts that would not fit in the remaining context window.
    if (_n_past + tokens.size() > _ctx->size()) {
        __WARN("Context limit exceeded ({} + {} > {})", _n_past, tokens.size(), _ctx->size());
        callback.callBack(nullptr, 0, Sentence::END, tokenizer());
        return true;
    }

    if (!engine.process(tokens, logits, false)) {
        return Dialog::abort("engine prompt processing failed", callback);
    }

    _n_prompt += tokens.size();
    _n_past += tokens.size();

    if (!engine.updateKV(_n_past)) {
        return Dialog::abort("context size exceeded", callback);
    }

    // Sample the first generated token and shrink `tokens` so the follow-on
    // loop can reuse it as a single-token input buffer.
    tokens[0] = _last_tok = sampler.process(logits);
    tokens.resize(1);

    _n_generated++;

    _gpio_marker->set();

    _kpis.prompt.update(start.elapsed_usec());

    // Log latest KPIs
    _env->logger().post(Logger::KPIS, kpis().dump(" "));

    // Restart the timer so only generation time lands in _kpis.generate.
    start.reset();

    if (_ctx->is_eos(_last_tok)) {
        callback.callBack(nullptr, 0, Sentence::END, tokenizer());
        return true;
    }

    // Consumer may decline further output by returning false.
    if (!callback.callBack(tokens.data(), tokens.size(), Sentence::BEGIN, tokenizer()))
        return true;

    State::busy(true);
    processFollowOnGeneration(tokens, logits, callback);
    State::busy(false);

    _gpio_marker->set();
    _gpio_marker->reset();

    _kpis.generate.update(start.elapsed_usec());

    // Log latest KPIs in a single line
    _env->logger().post(Logger::KPIS, kpis().dump(" "));

    return !State::failed();
}
|
| 303 |
+
|
| 304 |
+
bool BasicDialog::process(
|
| 305 |
+
std::vector<uint8_t>& embedding_vectors,
|
| 306 |
+
T2ECallback t2eCallback,
|
| 307 |
+
Dialog::Callback callback
|
| 308 |
+
) {
|
| 309 |
+
Timer start;
|
| 310 |
+
|
| 311 |
+
if(m_inputType != InputType::EMBEDDINGS) {
|
| 312 |
+
__ERROR("Input type for model is not embeddings.");
|
| 313 |
+
return false;
|
| 314 |
+
}
|
| 315 |
+
|
| 316 |
+
// Vector for storing logits.
|
| 317 |
+
// Allocated & filled by the engine.
|
| 318 |
+
std::vector<float> logits;
|
| 319 |
+
|
| 320 |
+
State::clear();
|
| 321 |
+
|
| 322 |
+
_gpio_marker->set();
|
| 323 |
+
|
| 324 |
+
auto& sampler = *_sampler["primary"];
|
| 325 |
+
auto& engine = *_engine["primary"];
|
| 326 |
+
|
| 327 |
+
// Store the t2e callback for reference during follow-on generation.
|
| 328 |
+
m_t2eCallback = t2eCallback;
|
| 329 |
+
|
| 330 |
+
size_t embedBufSize = engine.getEmbeddingBufferSize();
|
| 331 |
+
|
| 332 |
+
{
|
| 333 |
+
std::vector<uint8_t> eosEmbedding(embedBufSize, 0.0);
|
| 334 |
+
if (m_t2eCallback) {
|
| 335 |
+
m_t2eCallback(_ctx->eos(), eosEmbedding.data(), embedBufSize);
|
| 336 |
+
}
|
| 337 |
+
// For non-autogenerative usecases (where t2eCallback is not supplied),
|
| 338 |
+
// the EOS vector is all zero. This is fine for models with proper
|
| 339 |
+
// attention masking support, but may degrade accuracy otherwise.
|
| 340 |
+
if (!engine.cacheEosEmbedding(eosEmbedding)) {
|
| 341 |
+
__DEBUG("Failed to set the eos token embedding.");
|
| 342 |
+
return false;
|
| 343 |
+
}
|
| 344 |
+
}
|
| 345 |
+
|
| 346 |
+
using FF = Engine::Feature::Flags;
|
| 347 |
+
if (engine.supports(FF::DYNAMIC_LOAD)) engine.load();
|
| 348 |
+
|
| 349 |
+
size_t curTokenCount = embedding_vectors.size() / embedBufSize;
|
| 350 |
+
_env->logger().post(Logger::KPIS, kpis().dump(" "));
|
| 351 |
+
start.reset(); // Don't include preprocessing time
|
| 352 |
+
|
| 353 |
+
if (_n_past + curTokenCount > _ctx->size()) {
|
| 354 |
+
__WARN("Context limit exceeded ({} + {} > {})", _n_past, curTokenCount, _ctx->size());
|
| 355 |
+
callback("", Sentence::END);
|
| 356 |
+
return true;
|
| 357 |
+
}
|
| 358 |
+
|
| 359 |
+
if (!engine.process(embedding_vectors, {}, logits))
|
| 360 |
+
return Dialog::abort("engine prompt processing failed", callback);
|
| 361 |
+
_n_prompt += curTokenCount;
|
| 362 |
+
_n_past += curTokenCount;
|
| 363 |
+
|
| 364 |
+
std::vector<int32_t> tokens(1, 0);
|
| 365 |
+
|
| 366 |
+
if (!engine.updateKV(_n_past)) return Dialog::abort("context size exceeded", callback);
|
| 367 |
+
|
| 368 |
+
tokens[0] = _last_tok = sampler.process(logits);
|
| 369 |
+
|
| 370 |
+
_n_generated++;
|
| 371 |
+
|
| 372 |
+
_gpio_marker->set();
|
| 373 |
+
|
| 374 |
+
_kpis.prompt.update(start.elapsed_usec());
|
| 375 |
+
|
| 376 |
+
// Log latest KPIs
|
| 377 |
+
_env->logger().post(Logger::KPIS, kpis().dump(" "));
|
| 378 |
+
|
| 379 |
+
start.reset();
|
| 380 |
+
|
| 381 |
+
if (_ctx->is_eos(_last_tok)) {
|
| 382 |
+
callback("", Sentence::END);
|
| 383 |
+
return true;
|
| 384 |
+
}
|
| 385 |
+
|
| 386 |
+
if (!callback(_tokenizer->decode(tokens), Sentence::BEGIN)) {
|
| 387 |
+
return true;
|
| 388 |
+
}
|
| 389 |
+
|
| 390 |
+
if (!m_t2eCallback) {
|
| 391 |
+
callback("", Sentence::END);
|
| 392 |
+
return true;
|
| 393 |
+
}
|
| 394 |
+
|
| 395 |
+
State::busy(true);
|
| 396 |
+
processFollowOnGeneration(tokens, logits, callback);
|
| 397 |
+
State::busy(false);
|
| 398 |
+
|
| 399 |
+
_gpio_marker->set();
|
| 400 |
+
_gpio_marker->reset();
|
| 401 |
+
|
| 402 |
+
_kpis.generate.update(start.elapsed_usec());
|
| 403 |
+
// Log latest KPIs in a single line
|
| 404 |
+
_env->logger().post(Logger::KPIS, kpis().dump(" "));
|
| 405 |
+
|
| 406 |
+
return !State::failed();
|
| 407 |
+
}
|
| 408 |
+
|
| 409 |
+
// Registrator instance
|
| 410 |
+
static OnLoad regy([]() {
|
| 411 |
+
Dialog::__register(
|
| 412 |
+
"basic",
|
| 413 |
+
[](std::shared_ptr<Env> env, const std::string& name, const json& conf) {
|
| 414 |
+
return (Dialog*)new BasicDialog(env, name, conf);
|
| 415 |
+
}
|
| 416 |
+
);
|
| 417 |
+
});
|
| 418 |
+
|
| 419 |
+
// No-op symbol, presumably referenced from another translation unit so the
// linker does not drop this object file (which would skip the static OnLoad
// registration) — typical static-registration anchor; confirm with callers.
void needBasicDialog() {}
|
| 420 |
+
|
| 421 |
+
} // namespace qualla
|
Genie/Genie/src/qualla/dialogs/kv-share.cpp
ADDED
|
@@ -0,0 +1,359 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//==============================================================================
|
| 2 |
+
//
|
| 3 |
+
// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
|
| 4 |
+
// All rights reserved.
|
| 5 |
+
// Confidential and Proprietary - Qualcomm Technologies, Inc.
|
| 6 |
+
//
|
| 7 |
+
//==============================================================================
|
| 8 |
+
|
| 9 |
+
#include <qualla/dialog.hpp>
|
| 10 |
+
#include <qualla/sampler.hpp>
|
| 11 |
+
#include <qualla/logger.hpp>
|
| 12 |
+
#include <qualla/detail/config.hpp>
|
| 13 |
+
#include <qualla/detail/timer.hpp>
|
| 14 |
+
#include <qualla/detail/onload.hpp>
|
| 15 |
+
#include <qualla/detail/sampler-utils.hpp>
|
| 16 |
+
#include <qualla/detail/basic-sampler.hpp>
|
| 17 |
+
#include <qualla/detail/cache-file.hpp>
|
| 18 |
+
|
| 19 |
+
#include <functional>
|
| 20 |
+
#include <fstream>
|
| 21 |
+
#include <string>
|
| 22 |
+
#include <unordered_map>
|
| 23 |
+
#include <filesystem>
|
| 24 |
+
#include <random>
|
| 25 |
+
|
| 26 |
+
#include <fmt/format.h>
|
| 27 |
+
#include <fmt/ranges.h>
|
| 28 |
+
|
| 29 |
+
namespace fs = std::filesystem;
|
| 30 |
+
|
| 31 |
+
#define __INFO(__fmt, ...) _env->logger().post(Logger::INFO, fmt::format(__fmt, ##__VA_ARGS__))
|
| 32 |
+
#define __WARN(__fmt, ...) _env->logger().post(Logger::WARN, fmt::format(__fmt, ##__VA_ARGS__))
|
| 33 |
+
#define __ERROR(__fmt, ...) _env->logger().post(Logger::ERROR, fmt::format(__fmt, ##__VA_ARGS__))
|
| 34 |
+
#define __KPIS(__fmt, ...) \
|
| 35 |
+
_env->logger().post(Logger::KPIS, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
|
| 36 |
+
#define __DEBUG(__fmt, ...) \
|
| 37 |
+
_env->logger().post(Logger::DEBUG, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
|
| 38 |
+
#define __TRACE(__fmt, ...) \
|
| 39 |
+
_env->logger().post(Logger::TRACE, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
|
| 40 |
+
|
| 41 |
+
namespace qualla {
|
| 42 |
+
|
| 43 |
+
using qc = qualla::Config;
|
| 44 |
+
|
| 45 |
+
// Dialog that runs prompt processing on one engine ("primary") and token
// generation on another ("secondary"), handing the KV-cache over between
// them via an on-disk conversion step (see convertKV / process).
class KvShareDialog : public Dialog {
  public:
    KvShareDialog(std::shared_ptr<Env> env, const std::string& name, const json& conf)
        : Dialog(env, name, conf) {}

    // String-callback entry point: full prompt + engine-switch + generation.
    virtual bool process(std::vector<int32_t>& tokens, Dialog::Callback callback) override;

    // Token-callback entry point is not supported by this dialog type.
    virtual bool process(std::vector<int32_t>& tokens, DialogCallback callback) override {
        return false;
    }

    // Reset counters, KPIs, samplers, and engines to a fresh state.
    virtual void reset() override;

    // Convert the primary engine's saved KV-cache into the secondary
    // engine's on-disk format (dequantize + transpose); see definition.
    bool convertKV(const fs::path& cache_dir);

};
|
| 61 |
+
|
| 62 |
+
void KvShareDialog::reset() {
|
| 63 |
+
__INFO("dialog-reset: {}", _ctx->name());
|
| 64 |
+
|
| 65 |
+
_n_past = 0;
|
| 66 |
+
_n_prompt = 0;
|
| 67 |
+
_n_generated = 0;
|
| 68 |
+
_n_queries = 0;
|
| 69 |
+
_last_tok = -1;
|
| 70 |
+
|
| 71 |
+
_kpis.reset();
|
| 72 |
+
|
| 73 |
+
// Reset Samplers
|
| 74 |
+
for (auto& s : _sampler)
|
| 75 |
+
s.second->reset();
|
| 76 |
+
|
| 77 |
+
// Reset Engines
|
| 78 |
+
for (auto& e : _engine) {
|
| 79 |
+
e.second->reset();
|
| 80 |
+
e.second->unload();
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
State::clear();
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
// Prompt + generation flow with an engine switch in the middle:
//   1. run the prompt on the primary engine and sample the first token;
//   2. persist the primary engine's KV-cache to a temporary cache dir,
//      convert it to the secondary engine's format (convertKV), and restore
//      it into the secondary engine;
//   3. continue auto-regressive generation on the secondary engine.
// Returns false only on a pre-existing failure; otherwise true (aborts are
// reported through the callback and Dialog::abort()).
bool KvShareDialog::process(std::vector<int32_t>& tokens, Dialog::Callback callback) {

    // Check for prev failures and bail out early
    if (State::failed()) return false;

    Timer start;

    // Vector for storing logits.
    // Allocated & filled by the engine.
    std::vector<float> logits;

    State::clear();

    auto& sampler = *_sampler["primary"];

    auto& p_engine = *_engine["primary"]; // prompt
    auto& s_engine = *_engine["secondary"]; // generation

    if (_n_past + tokens.size() > _ctx->size()) {
        __WARN("Context limit exceeded ({} + {} > {})", _n_past, tokens.size(), _ctx->size());
        callback("", Sentence::END);
        return true;
    }

    if (!p_engine.process(tokens, logits))
        return Dialog::abort("engine prompt processing failed", callback);

    _n_prompt += tokens.size();
    _n_past += tokens.size();

    if (!p_engine.updateKV(_n_past)) return Dialog::abort("context size exceeded", callback);

    // Sample the first token and shrink `tokens` to a one-token buffer for
    // the generation loop below.
    tokens[0] = _last_tok = sampler.process(logits);
    tokens.resize(1);

    _n_generated++;

    _kpis.prompt.update(start.elapsed_usec());
    // Log latest KPIs
    _env->logger().post(Logger::KPIS, kpis().dump(" "));

    if (_ctx->is_eos(_last_tok)) {
        callback("", Sentence::END);
        return true;
    }

    if (!callback(_tokenizer->decode(tokens), Sentence::BEGIN)) return true;

    __DEBUG("dialog: {} : switching engines", _ctx->name());
    {
        // Setup cache dir for saving the engine state
        std::string cache_name = _ctx->name() + "-kv-share";
        fs::path cache_dir = _env->path().cache / cache_name;

        if (!fs::exists(cache_dir) && !fs::create_directories(cache_dir)) {
            __ERROR("dialog: {} : failed to create cache directory {}",
                    _ctx->name(),
                    cache_dir.string());
            return Dialog::abort("engine switch failed", callback);
        }

        // Save and unload the primary engine
        p_engine.save(cache_name);
        p_engine.unload();

        // The purpose is to save the hyperparams
        s_engine.save(cache_name);

        // NOTE(review): convertKV's return value is ignored here; a failed
        // conversion proceeds to restore a stale/garbage cache — confirm
        // whether State::error() inside convertKV is sufficient protection.
        convertKV(cache_dir);

        // restore() returns the number of KV positions recovered.
        size_t n = s_engine.restore(cache_name);

        if(!fs::remove_all(cache_dir)) {
            __WARN("dialog: {} : cache files not closed/dir not found", _ctx->name());
        }

        // Trust the restored cache size if it disagrees with our counter.
        if (n != _n_past) {
            __WARN("dialog: {} : kv size mismatch {} expected {}", _ctx->name(), n, _n_past);
            _n_past = n;
        }

        s_engine.updateKV(_n_past);
    }

    // Time only the generation phase from here on.
    start.reset();

    State::busy(true);

    // NOTE(review): the abort returns inside this loop exit without calling
    // State::busy(false) — verify whether abort()/clear() resets the busy
    // flag elsewhere.
    while (true) {
        if (State::canceled()) {
            callback("", Sentence::END);
            break;
        }

        if (_n_past + tokens.size() > _ctx->size()) {
            __WARN("Context limit exceeded ({} + {} > {})", _n_past, tokens.size(), _ctx->size());
            callback("", Sentence::END);
            break;
        }
        if (!s_engine.process(tokens, logits))
            return Dialog::abort("secondary engine processing failed", callback);

        tokens[0] = _last_tok = sampler.process(logits);

        _n_past++;
        _n_generated++;

        if (!s_engine.updateKV(_n_past)) return Dialog::abort("context size exceeded", callback);

        if (_ctx->is_eos(_last_tok)) {
            callback("", Sentence::END);
            break;
        }

        if (!callback(_tokenizer->decode(tokens), Sentence::CONTINUE)) break;
    }

    State::busy(false);

    _kpis.generate.update(start.elapsed_usec());

    // Log latest KPIs in a single line
    _env->logger().post(Logger::KPIS, kpis().dump(" "));

    return true;
}
|
| 212 |
+
|
| 213 |
+
// Convert the primary engine's saved KV-cache (QNN-HTP layout, quantized
// uint8) into the secondary engine's format (QNN-CPU layout, float32):
//   * keys:   dequantize, de-interleave the channel dimension, and transpose
//             each head from [kv_dim x n_tok] to [n_tok x kv_dim]
//   * values: dequantize in place (already stored [n_tok x kv_dim])
// The secondary cache file's header is patched with the primary's
// update_size before the converted tensors are written.
// Returns true on success; on any I/O or format error the dialog error
// state is set and false is returned.
bool KvShareDialog::convertKV(const fs::path& cache_dir) {
    Timer start;

    fs::path nsp_cache_path = cache_dir / "kv-cache.primary.qnn-htp";
    fs::path cpu_cache_path = cache_dir / "kv-cache.secondary.qnn-cpu";

    // Fix: the original format string had only one "{}" so the destination
    // path argument was silently dropped from this log line.
    __DEBUG("kv-convert: begin converting {} to {}", nsp_cache_path.string(), cpu_cache_path.string());

    std::ifstream nsp_fs(nsp_cache_path, std::ios::in | std::ios::binary);

    if (nsp_fs.fail()) {
        __ERROR("kv-convert: error reading file {}", nsp_cache_path.string());
        State::error("failed to read primary kv-cache");
        return false;
    }

    // Read spec from nsp file
    CacheFileSpec nsp_spec;
    nsp_fs.read((char*)&nsp_spec, sizeof(nsp_spec));
    if (nsp_spec.magic != 0xC0DE) {
        __ERROR("kv-convert: expected 0xC0DE found {:#x}", nsp_spec.magic);
        State::error("invalid format of primary kv-cache");
        return false;
    }

    // clang-format off
    __DEBUG("kv-convert: load {{ num_tensors {}, magic {}, dtype {}, n_heads {}, embed_dim {} update_size {} }}",
        nsp_spec.num_tensors, nsp_spec.magic, int(nsp_spec.dtype), nsp_spec.n_heads, nsp_spec.embed_dim, nsp_spec.update_size);
    // clang-format on

    // The secondary cache file already exists (written by s_engine.save());
    // open read/write so we can patch its header in place.
    std::fstream cpu_fs(cpu_cache_path, std::ios::in | std::ios::out | std::ios::binary);

    if (cpu_fs.fail()) {
        // TODO: replace with proper error handling
        __ERROR("kv-convert: failed to write {}", cpu_cache_path.string());
        State::error("failed to save secondary kv-cache");
        return false;
    }

    CacheFileSpec cpu_spec;
    cpu_fs.read((char*)&cpu_spec, sizeof(cpu_spec));
    if (cpu_spec.magic != 0xC0DE) {
        __ERROR("kv-convert: expected 0xC0DE found {:#x}", cpu_spec.magic);
        State::error("invalid format of secondary kv-cache");
        return false;
    }

    // Set the n_tokens processed during prompt processing and write the
    // patched spec back at the start of the file.
    // Fix: the original `seekp(std::ios::beg)` invoked the single-argument
    // overload with the seekdir enum interpreted as an absolute position
    // (works only because the enum's value happens to be 0); use the
    // explicit (offset, direction) overload to rewind unambiguously.
    cpu_spec.update_size = nsp_spec.update_size;
    cpu_fs.seekp(0, std::ios::beg);
    cpu_fs.write((char*)&cpu_spec, sizeof(cpu_spec));

    const uint32_t n_layer = nsp_spec.num_tensors / 2;
    const uint32_t n_head = nsp_spec.n_heads;
    const uint32_t kv_dim = nsp_spec.embed_dim;
    const uint32_t n_tok = nsp_spec.update_size;

    // Fix: compute all sizes/offsets in size_t so large caches cannot
    // overflow the 32-bit index arithmetic the original used.
    const size_t head_size = size_t(kv_dim) * n_tok;
    const size_t layer_size = size_t(n_head) * head_size;
    const size_t cache_size = size_t(n_layer) * layer_size;

    // Read Key/Value Cache
    std::vector<uint8_t> key_cache(cache_size);
    std::vector<uint8_t> value_cache(cache_size);
    nsp_fs.read((char*)key_cache.data(), cache_size);
    nsp_fs.read((char*)value_cache.data(), cache_size);

    // Read Quantization parameters (one scale per layer for keys/values)
    std::vector<double> key_scales(n_layer);
    std::vector<double> value_scales(n_layer);
    nsp_fs.read((char*)key_scales.data(), n_layer * sizeof(double));
    nsp_fs.read((char*)value_scales.data(), n_layer * sizeof(double));

    // Robustness fix: a truncated file would otherwise silently convert
    // garbage into the secondary cache.
    if (!nsp_fs) {
        __ERROR("kv-convert: truncated primary kv-cache {}", nsp_cache_path.string());
        State::error("failed to read primary kv-cache");
        return false;
    }

    nsp_fs.close();

    // Dequantize + transpose keys:
    // kvdim * n_tok (QNN-HTP K$) -> n_tok * kvdim (QNN-CPU K$)
    __DEBUG("kv-convert: dequantizing keys");
    std::vector<float> dequant_keys(cache_size);
    for (uint32_t i = 0; i < n_layer; i++) {
        for (uint32_t j = 0; j < n_head; j++) {
            for (uint32_t k = 0; k < kv_dim; k++) {
                // Interleave K$ (loop-invariant in l, hoisted out of the
                // innermost loop):
                // QNN HTP: [0 2 4 ... 126 1 3 5 ... 127]
                // QNN CPU: [0 1 2 ... 63 64 65 ... 127]
                const uint32_t interleaved_k =
                    (2 * k < kv_dim) ? 2 * k : 2 * (k - kv_dim / 2) + 1;
                for (uint32_t l = 0; l < n_tok; l++) {
                    const size_t read_loc = i * layer_size + j * head_size + size_t(k) * n_tok + l;
                    const size_t write_loc = i * layer_size + j * head_size + size_t(l) * kv_dim + interleaved_k;

                    dequant_keys[write_loc] =
                        (static_cast<float>(key_cache[read_loc]) - 128) * key_scales[i];
                }
            }
        }
    }

    // Values are already stored [n_tok x kv_dim]; only dequantization needed.
    __DEBUG("kv-convert: dequantizing values");
    std::vector<float> dequant_values(cache_size);
    for (uint32_t i = 0; i < n_layer; i++) {
        for (uint32_t j = 0; j < n_head; j++) {
            for (uint32_t l = 0; l < n_tok; l++) {
                for (uint32_t k = 0; k < kv_dim; k++) {
                    const size_t loc = i * layer_size + j * head_size + size_t(l) * kv_dim + k;

                    dequant_values[loc] =
                        (static_cast<float>(value_cache[loc]) - 128) * value_scales[i];
                }
            }
        }
    }

    __DEBUG("kv-convert: storing converted KV to file");
    cpu_fs.write((char *)dequant_keys.data(), dequant_keys.size() * sizeof(float));
    cpu_fs.write((char *)dequant_values.data(), dequant_values.size() * sizeof(float));

    cpu_fs.flush();
    cpu_fs.close();

    __DEBUG("kv-convert: done converting {} to {} in {} usec",
            nsp_cache_path.string(),
            cpu_cache_path.string(),
            start.elapsed_usec());

    return true;
}
|
| 346 |
+
|
| 347 |
+
// Registrator instance
|
| 348 |
+
static OnLoad regy([]() {
|
| 349 |
+
Dialog::__register(
|
| 350 |
+
"kv-share",
|
| 351 |
+
[](std::shared_ptr<Env> env, const std::string& name, const json& conf) {
|
| 352 |
+
return (Dialog*)new KvShareDialog(env, name, conf);
|
| 353 |
+
}
|
| 354 |
+
);
|
| 355 |
+
});
|
| 356 |
+
|
| 357 |
+
// No-op symbol, presumably referenced from another translation unit so the
// linker does not drop this object file (which would skip the static OnLoad
// registration) — typical static-registration anchor; confirm with callers.
void needKvShareDialog() {}
|
| 358 |
+
|
| 359 |
+
} // namespace qualla
|
Genie/Genie/src/qualla/dialogs/lhd-dec.cpp
ADDED
|
@@ -0,0 +1,481 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//==============================================================================
|
| 2 |
+
//
|
| 3 |
+
// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
|
| 4 |
+
// All Rights Reserved.
|
| 5 |
+
// Confidential and Proprietary - Qualcomm Technologies, Inc.
|
| 6 |
+
//
|
| 7 |
+
//==============================================================================
|
| 8 |
+
|
| 9 |
+
#include <qualla/dialog.hpp>
|
| 10 |
+
#include <qualla/logger.hpp>
|
| 11 |
+
#include <qualla/detail/config.hpp>
|
| 12 |
+
#include <qualla/detail/timer.hpp>
|
| 13 |
+
#include <qualla/detail/onload.hpp>
|
| 14 |
+
#include <qualla/detail/lhd-dialog.hpp>
|
| 15 |
+
|
| 16 |
+
#include <functional>
|
| 17 |
+
#include <filesystem>
|
| 18 |
+
#include <string>
|
| 19 |
+
#include <cmath>
|
| 20 |
+
#include <cstdio>
|
| 21 |
+
#include <random>
|
| 22 |
+
|
| 23 |
+
#include <fmt/format.h>
|
| 24 |
+
#include <fmt/ranges.h>
|
| 25 |
+
|
| 26 |
+
namespace fs = std::filesystem;
|
| 27 |
+
|
| 28 |
+
#define __INFO(__fmt, ...) _env->logger().post(Logger::INFO, fmt::format(__fmt, ##__VA_ARGS__))
|
| 29 |
+
#define __WARN(__fmt, ...) _env->logger().post(Logger::WARN, fmt::format(__fmt, ##__VA_ARGS__))
|
| 30 |
+
#define __ERROR(__fmt, ...) _env->logger().post(Logger::ERROR, fmt::format(__fmt, ##__VA_ARGS__))
|
| 31 |
+
#define __KPIS(__fmt, ...) \
|
| 32 |
+
_env->logger().post(Logger::KPIS, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
|
| 33 |
+
#define __DEBUG(__fmt, ...) \
|
| 34 |
+
_env->logger().post(Logger::DEBUG, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
|
| 35 |
+
#define __TRACE(__fmt, ...) \
|
| 36 |
+
_env->logger().post(Logger::TRACE, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
|
| 37 |
+
|
| 38 |
+
namespace qualla {
|
| 39 |
+
|
| 40 |
+
using qc = qualla::Config;
|
| 41 |
+
|
| 42 |
+
// Lookahead-decoding dialog constructor: reads the lookahead
// hyper-parameters from the config JSON, with defaults.
//   window — number of parallel lookahead positions per level (default 8)
//   ngram  — n-gram length used for speculation (default 3)
//   gcap   — capacity of the n-gram verification pool (default 8)
// (Meanings inferred from their use in process(); confirm against
// lhd-dialog.hpp.)
LhdDecDialog::LhdDecDialog(std::shared_ptr<Env> env, const std::string& name, const json& conf)
    : Dialog(env, name, conf) {

    _window = qc::optional<size_t>(conf, "window", 8);
    _ngram = qc::optional<size_t>(conf, "ngram", 3);
    _gcap = qc::optional<size_t>(conf, "gcap", 8);

    // Lookahead-branch update policy; exact mode strings are defined
    // elsewhere — TODO confirm valid values beyond "ALWAYS_FWD_ONE".
    _lhd_mode_str = qc::optional<std::string>(conf, "lhd-update-mode", "ALWAYS_FWD_ONE");
}
|
| 51 |
+
|
| 52 |
+
bool LhdDecDialog::process(std::vector<int32_t>& tokens, Dialog::Callback callback) {
|
| 53 |
+
// Check for prev failures and bail out early
|
| 54 |
+
if (State::failed()) return false;
|
| 55 |
+
|
| 56 |
+
Timer start;
|
| 57 |
+
|
| 58 |
+
// Vector for storing logits.
|
| 59 |
+
// Allocated & filled by the engine.
|
| 60 |
+
std::vector<float> logits;
|
| 61 |
+
std::vector<int32_t> resultTokens;
|
| 62 |
+
|
| 63 |
+
State::clear();
|
| 64 |
+
|
| 65 |
+
auto& sampler = *_sampler["primary"];
|
| 66 |
+
auto& engine = *_engine["primary"];
|
| 67 |
+
|
| 68 |
+
using FF = Engine::Feature::Flags;
|
| 69 |
+
if (engine.supports(FF::DYNAMIC_LOAD)) engine.load();
|
| 70 |
+
|
| 71 |
+
if (_n_past + tokens.size() > _ctx->size()) {
|
| 72 |
+
__WARN("Context limit exceeded ({} + {} > {})", _n_past, tokens.size(), _ctx->size());
|
| 73 |
+
callback("", Sentence::END);
|
| 74 |
+
return true;
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
if (!engine.process(tokens, logits, false))
|
| 78 |
+
return Dialog::abort("engine prompt processing failed", callback);
|
| 79 |
+
|
| 80 |
+
_n_prompt += tokens.size();
|
| 81 |
+
_n_past += tokens.size();
|
| 82 |
+
|
| 83 |
+
if (!engine.updateKV(_n_past)) return Dialog::abort("context size exceeded", callback);
|
| 84 |
+
|
| 85 |
+
std::vector<int32_t> tokens_tmp(1);
|
| 86 |
+
tokens_tmp[0] = _last_tok = sampler.process(logits);
|
| 87 |
+
resultTokens.push_back(_last_tok);
|
| 88 |
+
|
| 89 |
+
_n_generated++;
|
| 90 |
+
|
| 91 |
+
_kpis.prompt.update(start.elapsed_usec());
|
| 92 |
+
|
| 93 |
+
// Log latest KPIs
|
| 94 |
+
_env->logger().post(Logger::KPIS, kpis().dump(" "));
|
| 95 |
+
|
| 96 |
+
if (_ctx->is_eos(_last_tok)) {
|
| 97 |
+
callback("", Sentence::END);
|
| 98 |
+
return true;
|
| 99 |
+
}
|
| 100 |
+
|
| 101 |
+
// Exit condition : Prediction limit reached OR ctx size limit reached
|
| 102 |
+
if (!callback(_tokenizer->decode(tokens_tmp), Sentence::BEGIN)) return true;
|
| 103 |
+
|
| 104 |
+
State::busy(true);
|
| 105 |
+
|
| 106 |
+
// verification branch init
|
| 107 |
+
v_branch.resize(_gcap);
|
| 108 |
+
|
| 109 |
+
// n-gram pools
|
| 110 |
+
const size_t n_vocab = _ctx->n_vocab();
|
| 111 |
+
ngram_container ngrams_pool(n_vocab, _ngram, _gcap);
|
| 112 |
+
|
| 113 |
+
// lookahead branch first level init
|
| 114 |
+
lhd_branch.resize(_ngram - 1);
|
| 115 |
+
lhd_branch_prev.resize(_window);
|
| 116 |
+
|
| 117 |
+
for (int j = 0; j < _ngram - 1; j++) {
|
| 118 |
+
lhd_branch[j].resize(_window);
|
| 119 |
+
|
| 120 |
+
for (int i = 0; i < _window; i++) {
|
| 121 |
+
if (j == 0) {
|
| 122 |
+
// initialize with the random token from prompt
|
| 123 |
+
lhd_branch[j][i] = tokens[1 + rand() % (tokens.size() - 1)];
|
| 124 |
+
} else {
|
| 125 |
+
// initialize with a sequence of increasing numbers
|
| 126 |
+
lhd_branch[j][i] = 1000 + i;
|
| 127 |
+
}
|
| 128 |
+
}
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
// lookahead branch other level init
|
| 132 |
+
while (_level_idx < _ngram - 1) {
|
| 133 |
+
|
| 134 |
+
batch.clear();
|
| 135 |
+
attention_map.clear();
|
| 136 |
+
|
| 137 |
+
// fill the first token of the first level
|
| 138 |
+
batch.push_back(_last_tok);
|
| 139 |
+
attention_map.push_back(-1);
|
| 140 |
+
lhd_branch[0][0] = _last_tok;
|
| 141 |
+
|
| 142 |
+
// fill the remaining WINDOW - 1 tokens for the first level
|
| 143 |
+
for (int i = 1; i < _window; i++) {
|
| 144 |
+
batch.push_back(lhd_branch[0][i]);
|
| 145 |
+
attention_map.push_back(i - 1);
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
// fill the rest of the levels
|
| 149 |
+
for (int j = 1; j < _ngram - 1; j++) {
|
| 150 |
+
for (int i = 0; i < _window; i++) {
|
| 151 |
+
batch.push_back(lhd_branch[j][i]);
|
| 152 |
+
attention_map.push_back((j - 1) * _window + i);
|
| 153 |
+
}
|
| 154 |
+
}
|
| 155 |
+
|
| 156 |
+
// re-init tokens batch
|
| 157 |
+
tokens.resize(_window * (_ngram - 1));
|
| 158 |
+
tokens = batch;
|
| 159 |
+
|
| 160 |
+
if (_n_past + tokens.size() > _ctx->size()) {
|
| 161 |
+
__WARN("Context limit exceeded ({} + {} > {})", _n_past, tokens.size(), _ctx->size());
|
| 162 |
+
callback("", Sentence::END);
|
| 163 |
+
break;
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
size_t n_tok = engine.process(tokens, attention_map, logits, true);
|
| 167 |
+
if (n_tok != tokens.size())
|
| 168 |
+
return Dialog::abort("engine lookahead branch processing failed", callback);
|
| 169 |
+
|
| 170 |
+
for (int i = 0; i < _window; i++) {
|
| 171 |
+
size_t sample_tmp_idx = (_level_idx - 1) * _window + i;
|
| 172 |
+
// sampler from logits all
|
| 173 |
+
std::span<float> span_logits{logits.data(),logits.size()};
|
| 174 |
+
std::span<float> span_tmp = span_logits.subspan(sample_tmp_idx * n_vocab, n_vocab);
|
| 175 |
+
int32_t sampled_tmp_token = sampler.process(span_tmp);
|
| 176 |
+
lhd_branch[_level_idx][i] = sampled_tmp_token;
|
| 177 |
+
}
|
| 178 |
+
|
| 179 |
+
_level_idx++;
|
| 180 |
+
}
|
| 181 |
+
|
| 182 |
+
if (_lhd_mode_str == "FWD_MAX_HIT")
|
| 183 |
+
_lhd_update_mode = FWD_MAX_HIT;
|
| 184 |
+
else if (_lhd_mode_str == "FWD_LEVEL")
|
| 185 |
+
_lhd_update_mode = FWD_LEVEL;
|
| 186 |
+
else
|
| 187 |
+
_lhd_update_mode = ALWAYS_FWD_ONE;
|
| 188 |
+
|
| 189 |
+
start.reset();
|
| 190 |
+
|
| 191 |
+
while (true) {
|
| 192 |
+
if (State::canceled()) {
|
| 193 |
+
callback("", Sentence::END);
|
| 194 |
+
break;
|
| 195 |
+
}
|
| 196 |
+
// input batch init
|
| 197 |
+
{
|
| 198 |
+
batch.clear();
|
| 199 |
+
attention_map.clear();
|
| 200 |
+
|
| 201 |
+
// fill the first token of the first level
|
| 202 |
+
batch.push_back(_last_tok);
|
| 203 |
+
attention_map.push_back(-1);
|
| 204 |
+
// lhd_branch[0][0] = _last_tok;
|
| 205 |
+
|
| 206 |
+
// fill the remaining WINDOW - 1 tokens for the first level
|
| 207 |
+
for (int i = 1; i < _window; i++) {
|
| 208 |
+
batch.push_back(lhd_branch[0][i]);
|
| 209 |
+
attention_map.push_back(i - 1);
|
| 210 |
+
}
|
| 211 |
+
|
| 212 |
+
// fill the rest of the levels
|
| 213 |
+
for (int j = 1; j < _ngram - 1; j++) {
|
| 214 |
+
for (int i = 0; i < _window; i++) {
|
| 215 |
+
batch.push_back(lhd_branch[j][i]);
|
| 216 |
+
attention_map.push_back((j - 1) * _window + i);
|
| 217 |
+
}
|
| 218 |
+
}
|
| 219 |
+
|
| 220 |
+
// build verification n-grams(branch)
|
| 221 |
+
{
|
| 222 |
+
const int g_cur = ngrams_pool.cnt[_last_tok];
|
| 223 |
+
|
| 224 |
+
v_branch.resize(g_cur);
|
| 225 |
+
// input_token_batch.size = (_window + g_cur) * (_ngram - 1);
|
| 226 |
+
tokens.resize((_window + g_cur) * (_ngram - 1));
|
| 227 |
+
for (int g = 0; g < g_cur; g++) {
|
| 228 |
+
v_branch[g].active = true;
|
| 229 |
+
v_branch[g].tokens.resize(_ngram);
|
| 230 |
+
v_branch[g].i_batch.resize(_ngram);
|
| 231 |
+
v_branch[g].seq_id = _window + 1 + g;
|
| 232 |
+
v_branch[g].i_batch[0] = 0;
|
| 233 |
+
v_branch[g].tokens[0] = _last_tok;
|
| 234 |
+
}
|
| 235 |
+
|
| 236 |
+
for (int j = 0; j < _ngram - 1; j++) {
|
| 237 |
+
for (int g = 0; g < g_cur; g++) {
|
| 238 |
+
const int idx = _last_tok * (_ngram - 1) * _gcap + g * (_ngram - 1);
|
| 239 |
+
const int32_t t = ngrams_pool.tokens[idx + j];
|
| 240 |
+
v_branch[g].tokens[j + 1] = t;
|
| 241 |
+
v_branch[g].i_batch[j + 1] = j + 1;
|
| 242 |
+
}
|
| 243 |
+
}
|
| 244 |
+
|
| 245 |
+
for (int g = 0; g < g_cur; g++) {
|
| 246 |
+
for (int j = 0; j < _ngram - 1; j++) {
|
| 247 |
+
batch.push_back(v_branch[g].tokens[j + 1]);
|
| 248 |
+
if (j == 0)
|
| 249 |
+
attention_map.push_back(0);
|
| 250 |
+
else
|
| 251 |
+
attention_map.push_back(batch.size() - 2);
|
| 252 |
+
}
|
| 253 |
+
}
|
| 254 |
+
}
|
| 255 |
+
}
|
| 256 |
+
|
| 257 |
+
// re-init tokens batch
|
| 258 |
+
std::vector<bool> selected(attention_map.size(), false);
|
| 259 |
+
tokens = batch;
|
| 260 |
+
|
| 261 |
+
if (_n_past + tokens.size() > _ctx->size()) {
|
| 262 |
+
__WARN("Context limit exceeded ({} + {} > {})", _n_past, tokens.size(), _ctx->size());
|
| 263 |
+
callback("", Sentence::END);
|
| 264 |
+
break;
|
| 265 |
+
}
|
| 266 |
+
|
| 267 |
+
size_t n_tok = engine.process(tokens, attention_map, logits, true);
|
| 268 |
+
if (n_tok != tokens.size()) return Dialog::abort("engine gen processing failed", callback);
|
| 269 |
+
|
| 270 |
+
// verification branch seq-id
|
| 271 |
+
size_t seq_id_best = 0;
|
| 272 |
+
// max hit pos
|
| 273 |
+
size_t i_batch_best = 0;
|
| 274 |
+
|
| 275 |
+
// Lookahead decoding and verification
|
| 276 |
+
for (int v = 0; v < _ngram; ++v) {
|
| 277 |
+
int i_batch = 0;
|
| 278 |
+
|
| 279 |
+
if (v > 0) {
|
| 280 |
+
for (int g = 0; g < (int)v_branch.size(); g++) {
|
| 281 |
+
// record the best matched seq and pos
|
| 282 |
+
if (v_branch[g].active) {
|
| 283 |
+
i_batch = v_branch[g].i_batch[v];
|
| 284 |
+
i_batch_best = i_batch;
|
| 285 |
+
seq_id_best = v_branch[g].seq_id;
|
| 286 |
+
++_n_accept;
|
| 287 |
+
break;
|
| 288 |
+
}
|
| 289 |
+
}
|
| 290 |
+
|
| 291 |
+
if (i_batch == 0) {
|
| 292 |
+
break;
|
| 293 |
+
}
|
| 294 |
+
}
|
| 295 |
+
|
| 296 |
+
size_t sample_idx;
|
| 297 |
+
if (seq_id_best == 0)
|
| 298 |
+
sample_idx = 0;
|
| 299 |
+
else
|
| 300 |
+
sample_idx = _window * (_ngram - 1) + (seq_id_best - (_window + 1)) * (_ngram - 1) +
|
| 301 |
+
i_batch - 1;
|
| 302 |
+
|
| 303 |
+
//vector selected set
|
| 304 |
+
selected[sample_idx] = true;
|
| 305 |
+
|
| 306 |
+
// sampler from logits all
|
| 307 |
+
std::span<float> span_logits{logits.data(),logits.size()};
|
| 308 |
+
std::span<float> sample_logit = span_logits.subspan(sample_idx * n_vocab, n_vocab);
|
| 309 |
+
_last_tok = sampler.process(sample_logit);
|
| 310 |
+
|
| 311 |
+
std::vector<int32_t> tokens_tmp(1);
|
| 312 |
+
tokens_tmp[0] = _last_tok;
|
| 313 |
+
|
| 314 |
+
resultTokens.push_back(_last_tok);
|
| 315 |
+
_n_generated++;
|
| 316 |
+
_n_past++;
|
| 317 |
+
|
| 318 |
+
if (_ctx->is_eos(_last_tok)) break;
|
| 319 |
+
|
| 320 |
+
if (!callback(_tokenizer->decode(tokens_tmp), Sentence::CONTINUE)) return true;
|
| 321 |
+
|
| 322 |
+
// if verify pass, check the next sample token until verifing failed
|
| 323 |
+
for (int g = 0; g < (int)v_branch.size(); g++) {
|
| 324 |
+
// update the n-gram active status
|
| 325 |
+
if (v_branch[g].active) {
|
| 326 |
+
if (v == _ngram - 1) {
|
| 327 |
+
v_branch[g].active = false;
|
| 328 |
+
} else {
|
| 329 |
+
if (_last_tok != v_branch[g].tokens[v + 1]) {
|
| 330 |
+
v_branch[g].active = false;
|
| 331 |
+
}
|
| 332 |
+
}
|
| 333 |
+
}
|
| 334 |
+
}
|
| 335 |
+
|
| 336 |
+
// update lookahead tokens when v=0 OR verify match
|
| 337 |
+
{
|
| 338 |
+
for (int i = 0; i < _window; i++) {
|
| 339 |
+
lhd_branch_prev[i] = lhd_branch[0][i];
|
| 340 |
+
}
|
| 341 |
+
|
| 342 |
+
if (v == 0) {
|
| 343 |
+
for (int j = 0; j < _ngram - 2; j++) {
|
| 344 |
+
lhd_branch[j] = lhd_branch[j + 1];
|
| 345 |
+
}
|
| 346 |
+
|
| 347 |
+
// sample from the last level
|
| 348 |
+
for (int i = 0; i < _window; i++) {
|
| 349 |
+
size_t sample_idx = (_ngram - 2) * _window + i;
|
| 350 |
+
std::span<float> sample_logit =
|
| 351 |
+
span_logits.subspan(sample_idx * n_vocab, n_vocab);
|
| 352 |
+
lhd_branch[_ngram - 2][i] = sampler.process(sample_logit);
|
| 353 |
+
}
|
| 354 |
+
} else {
|
| 355 |
+
if (_lhd_update_mode == FWD_MAX_HIT) {
|
| 356 |
+
// update lookahead branch by foward
|
| 357 |
+
for (int j = 0; j < _ngram - 1; j++) {
|
| 358 |
+
for (int i = 0; i < _window - v; i++) {
|
| 359 |
+
lhd_branch[j][i] = lhd_branch[j][i + 1];
|
| 360 |
+
}
|
| 361 |
+
}
|
| 362 |
+
} else if (_lhd_update_mode == FWD_LEVEL) {
|
| 363 |
+
// update lookahead branch by shifting level
|
| 364 |
+
for (int j = 0; j < _ngram - 2; j++) {
|
| 365 |
+
lhd_branch[j] = lhd_branch[j + 1];
|
| 366 |
+
}
|
| 367 |
+
|
| 368 |
+
for (int i = 0; i < _window; i++) {
|
| 369 |
+
// init from the previous level
|
| 370 |
+
lhd_branch[_ngram - 2][i] = lhd_branch[0][i];
|
| 371 |
+
}
|
| 372 |
+
}
|
| 373 |
+
}
|
| 374 |
+
}
|
| 375 |
+
|
| 376 |
+
// update n-grams pool
|
| 377 |
+
// only update n-grams pools when v=0
|
| 378 |
+
if (v == 0) {
|
| 379 |
+
std::vector<int32_t> ngram(_ngram - 1);
|
| 380 |
+
// n-gram pool generation
|
| 381 |
+
for (int f = 0; f < _window; ++f) {
|
| 382 |
+
const int ft = lhd_branch_prev[f]; // first token of the n-gram
|
| 383 |
+
|
| 384 |
+
for (int j = 0; j < _ngram - 1; ++j) {
|
| 385 |
+
ngram[j] = lhd_branch[j][f];
|
| 386 |
+
}
|
| 387 |
+
|
| 388 |
+
// filter-out repeating n-grams
|
| 389 |
+
{
|
| 390 |
+
bool is_unique = true;
|
| 391 |
+
|
| 392 |
+
for (int k = 0; k < ngrams_pool.cnt[ft]; ++k) {
|
| 393 |
+
// caculate the related idx by the first n-gram token
|
| 394 |
+
const int idx = ft * (_ngram - 1) * _gcap + k * (_ngram - 1);
|
| 395 |
+
|
| 396 |
+
bool is_match = true;
|
| 397 |
+
for (int j = 0; j < _ngram - 1; ++j) {
|
| 398 |
+
if (ngrams_pool.tokens[idx + j] != ngram[j]) {
|
| 399 |
+
is_match = false;
|
| 400 |
+
break;
|
| 401 |
+
}
|
| 402 |
+
}
|
| 403 |
+
|
| 404 |
+
// if n-gram match all, discard one of them
|
| 405 |
+
if (is_match) {
|
| 406 |
+
is_unique = false;
|
| 407 |
+
break;
|
| 408 |
+
}
|
| 409 |
+
}
|
| 410 |
+
|
| 411 |
+
if (!is_unique) {
|
| 412 |
+
continue;
|
| 413 |
+
}
|
| 414 |
+
}
|
| 415 |
+
|
| 416 |
+
const int head = ngrams_pool.head[ft];
|
| 417 |
+
const int idx = ft * (_ngram - 1) * _gcap + head * (_ngram - 1);
|
| 418 |
+
|
| 419 |
+
for (int i = 0; i < _ngram - 1; i++) {
|
| 420 |
+
// update the n-gram pool with new n-gram
|
| 421 |
+
ngrams_pool.tokens[idx + i] = ngram[i];
|
| 422 |
+
}
|
| 423 |
+
|
| 424 |
+
ngrams_pool.cnt[ft] = std::min(_gcap, ngrams_pool.cnt[ft] + 1);
|
| 425 |
+
ngrams_pool.head[ft] = (head + 1) % _gcap;
|
| 426 |
+
|
| 427 |
+
ngrams_pool.n_total++;
|
| 428 |
+
}
|
| 429 |
+
}
|
| 430 |
+
}
|
| 431 |
+
|
| 432 |
+
if (_lhd_update_mode == FWD_MAX_HIT) {
|
| 433 |
+
// std::random_device rd;
|
| 434 |
+
// std::mt19937 gen(rd());
|
| 435 |
+
// std::uniform_int_distribution<> dis(0, resultTokens.size() - 1);
|
| 436 |
+
|
| 437 |
+
// fill lookahead branch
|
| 438 |
+
for (int i = 0; i < _ngram - 1; i++) {
|
| 439 |
+
for (int j = _window - i_batch_best; j < _window; j++) {
|
| 440 |
+
lhd_branch[i][j] = resultTokens[1 + rand() % (resultTokens.size() - 1)];
|
| 441 |
+
// lhd_branch[i][j] = resultTokens[dis(gen)];
|
| 442 |
+
// std::cout << "Fill token = " << lhd_branch[i][j] << std::endl;
|
| 443 |
+
}
|
| 444 |
+
}
|
| 445 |
+
}
|
| 446 |
+
|
| 447 |
+
// KV cache management
|
| 448 |
+
if (!engine.updateKV(_n_past, selected))
|
| 449 |
+
return Dialog::abort("context size exceeded", callback);
|
| 450 |
+
|
| 451 |
+
if (_ctx->is_eos(_last_tok)) {
|
| 452 |
+
callback("", Sentence::END);
|
| 453 |
+
break;
|
| 454 |
+
}
|
| 455 |
+
}
|
| 456 |
+
|
| 457 |
+
State::busy(false);
|
| 458 |
+
|
| 459 |
+
_kpis.generate.update(start.elapsed_usec());
|
| 460 |
+
|
| 461 |
+
// Log latest KPIs in a single line
|
| 462 |
+
_env->logger().post(Logger::KPIS, kpis().dump(" "));
|
| 463 |
+
std::cout << std::endl << std::endl << std::flush;
|
| 464 |
+
__DEBUG("lhd-dec: n_generated = {} ---------- n_accept = {}", _n_generated, _n_accept);
|
| 465 |
+
|
| 466 |
+
return !State::failed();
|
| 467 |
+
}
|
| 468 |
+
|
| 469 |
+
// Registrator instance
|
| 470 |
+
static OnLoad regy([]() {
|
| 471 |
+
Dialog::__register(
|
| 472 |
+
"lhd-dec",
|
| 473 |
+
[](std::shared_ptr<Env> env, const std::string& name, const json& conf) {
|
| 474 |
+
return (Dialog*)new LhdDecDialog(env, name, conf);
|
| 475 |
+
}
|
| 476 |
+
);
|
| 477 |
+
});
|
| 478 |
+
|
| 479 |
+
void needLadeDialog() {
    // Intentionally empty. Presumably referenced from elsewhere so the linker
    // keeps this translation unit (and its static registrator) -- TODO confirm
    // at the call site.
}
|
| 480 |
+
|
| 481 |
+
} // namespace qualla
|
Genie/Genie/src/qualla/dialogs/multistream.cpp
ADDED
|
@@ -0,0 +1,300 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//==============================================================================
|
| 2 |
+
//
|
| 3 |
+
// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
|
| 4 |
+
// All Rights Reserved.
|
| 5 |
+
// Confidential and Proprietary - Qualcomm Technologies, Inc.
|
| 6 |
+
//
|
| 7 |
+
//==============================================================================
|
| 8 |
+
|
| 9 |
+
#include <qualla/dialog.hpp>
|
| 10 |
+
#include <qualla/logger.hpp>
|
| 11 |
+
#include <qualla/detail/config.hpp>
|
| 12 |
+
#include <qualla/detail/timer.hpp>
|
| 13 |
+
#include <qualla/detail/onload.hpp>
|
| 14 |
+
#include <qualla/detail/multistream-dialog.hpp>
|
| 15 |
+
#include <qualla/detail/sampler-utils.hpp>
|
| 16 |
+
|
| 17 |
+
#include <functional>
|
| 18 |
+
#include <filesystem>
|
| 19 |
+
#include <string>
|
| 20 |
+
|
| 21 |
+
#include <fmt/format.h>
|
| 22 |
+
#include <fmt/ranges.h>
|
| 23 |
+
|
| 24 |
+
namespace fs = std::filesystem;
|
| 25 |
+
|
| 26 |
+
#define __INFO(__fmt, ...) _env->logger().post(Logger::INFO, fmt::format(__fmt, ##__VA_ARGS__))
|
| 27 |
+
#define __WARN(__fmt, ...) _env->logger().post(Logger::WARN, fmt::format(__fmt, ##__VA_ARGS__))
|
| 28 |
+
#define __ERROR(__fmt, ...) _env->logger().post(Logger::ERROR, fmt::format(__fmt, ##__VA_ARGS__))
|
| 29 |
+
#define __KPIS(__fmt, ...) \
|
| 30 |
+
_env->logger().post(Logger::KPIS, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
|
| 31 |
+
#define __DEBUG(__fmt, ...) \
|
| 32 |
+
_env->logger().post(Logger::DEBUG, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
|
| 33 |
+
#define __TRACE(__fmt, ...) \
|
| 34 |
+
_env->logger().post(Logger::TRACE, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
|
| 35 |
+
|
| 36 |
+
namespace qualla {
|
| 37 |
+
|
| 38 |
+
bool MultiStreamDialog::processFollowOnGeneration(std::vector<std::vector<int32_t>>& streams, std::vector<float>& logits, Dialog::Callback callback) {
|
| 39 |
+
|
| 40 |
+
auto& sampler = *_sampler["primary"];
|
| 41 |
+
auto& engine = *_engine["primary"];
|
| 42 |
+
|
| 43 |
+
std::vector<std::vector<int32_t>> attention_mask(_n_streams);
|
| 44 |
+
std::vector<int32_t> streamIndices;
|
| 45 |
+
|
| 46 |
+
if (streams.size() == 0) {
|
| 47 |
+
callback("\n", Sentence::END);
|
| 48 |
+
return true;
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
for (int i = 0; i < streams.size(); i++) {
|
| 52 |
+
// Initialize all attention_masks to attend to all previous tokens
|
| 53 |
+
attention_mask[i].resize(_n_past, 1);
|
| 54 |
+
streamIndices.push_back(i);
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
State::busy(true);
|
| 58 |
+
|
| 59 |
+
while (true) {
|
| 60 |
+
if (State::canceled()) break;
|
| 61 |
+
|
| 62 |
+
// If this exceeds context length, truncate all streams and return
|
| 63 |
+
if (_n_past + streamIndices.size() > _ctx->size()) {
|
| 64 |
+
for (auto stream : streamIndices)
|
| 65 |
+
callback(_tokenizer->decode(streams[stream]) + "\n", Sentence::CONTINUE);
|
| 66 |
+
break;
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
// Accumulate input tokens from all streams
|
| 70 |
+
std::vector<int32_t> multi_tokens(streamIndices.size());
|
| 71 |
+
|
| 72 |
+
for (int i = 0; i < streamIndices.size(); i++) {
|
| 73 |
+
multi_tokens[i] = streams[streamIndices[i]].back();
|
| 74 |
+
|
| 75 |
+
// Also add current iteration to the attention_mask
|
| 76 |
+
for (auto _mask_row : streamIndices)
|
| 77 |
+
// Set to true iff on diagonal, i.e. attend to itself
|
| 78 |
+
attention_mask[streamIndices[i]].push_back((streamIndices[i] == _mask_row) ? 1 : 0);
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
// Concatenate attention_mask for all active streams
|
| 82 |
+
std::vector<int32_t> multi_attn_mask;
|
| 83 |
+
multi_attn_mask.reserve((_n_past + streamIndices.size()) * streamIndices.size());
|
| 84 |
+
for (auto i : streamIndices)
|
| 85 |
+
multi_attn_mask.insert(
|
| 86 |
+
multi_attn_mask.end(),
|
| 87 |
+
attention_mask[i].begin(),
|
| 88 |
+
attention_mask[i].end()
|
| 89 |
+
);
|
| 90 |
+
|
| 91 |
+
// __DEBUG("Multi attention mask = {}", multi_attn_mask);
|
| 92 |
+
|
| 93 |
+
if (m_inputType == InputType::TOKENS) {
|
| 94 |
+
// Process input tokens for all streams in one batch
|
| 95 |
+
if (!engine.process(multi_tokens, multi_attn_mask, logits, true))
|
| 96 |
+
return Dialog::abort("engine gen processing failed", callback);
|
| 97 |
+
} else if (m_inputType == InputType::EMBEDDINGS) {
|
| 98 |
+
// Accumulate input embeddings from all streams
|
| 99 |
+
auto embedBufSize = engine.getEmbeddingBufferSize();
|
| 100 |
+
std::vector<uint8_t> multi_embeddings;
|
| 101 |
+
|
| 102 |
+
for (auto token : multi_tokens) {
|
| 103 |
+
// Convert tokens to embedding for the processing in the engine.
|
| 104 |
+
std::vector<uint8_t> curTokenEmbedding(embedBufSize, 0);
|
| 105 |
+
m_t2eCallback(token, curTokenEmbedding.data(), embedBufSize);
|
| 106 |
+
multi_embeddings.insert(multi_embeddings.end(), curTokenEmbedding.begin(), curTokenEmbedding.end());
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
// Process input tokens for all streams in one batch
|
| 110 |
+
if (!engine.process(multi_embeddings, multi_attn_mask, logits, true))
|
| 111 |
+
return Dialog::abort("engine gen processing failed", callback);
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
// Process all logits independently
|
| 115 |
+
std::span<float> logit_span = std::span{logits.data(),logits.size()};
|
| 116 |
+
for (int i = 0; i < streamIndices.size(); i++) {
|
| 117 |
+
_last_tok = sampler.process(logit_span.subspan(i * _vocab, _vocab));
|
| 118 |
+
streams[streamIndices[i]].push_back(_last_tok);
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
_n_past += streamIndices.size();
|
| 122 |
+
_n_generated += streamIndices.size();
|
| 123 |
+
|
| 124 |
+
if (!engine.updateKV(_n_past)) return Dialog::abort("context size exceeded", callback);
|
| 125 |
+
|
| 126 |
+
for (auto it = streamIndices.begin(); it != streamIndices.end();) {
|
| 127 |
+
int32_t stream = *it;
|
| 128 |
+
if (_ctx->is_eos(streams[stream].back())) {
|
| 129 |
+
callback(_tokenizer->decode(streams[stream]) + "\n", Sentence::CONTINUE);
|
| 130 |
+
it = streamIndices.erase(it);
|
| 131 |
+
} else {
|
| 132 |
+
++it;
|
| 133 |
+
}
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
if (streamIndices.size() == 0) break;
|
| 137 |
+
}
|
| 138 |
+
callback("\n", Sentence::END);
|
| 139 |
+
|
| 140 |
+
State::busy(false);
|
| 141 |
+
|
| 142 |
+
return true;
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
bool MultiStreamDialog::process(std::vector<int32_t>& tokens, Dialog::Callback callback) {
|
| 146 |
+
// Check for prev failures and bail out early
|
| 147 |
+
if (State::failed()) return false;
|
| 148 |
+
|
| 149 |
+
Timer start;
|
| 150 |
+
|
| 151 |
+
if(m_inputType != InputType::TOKENS) {
|
| 152 |
+
__ERROR("Input type for model is not tokens.");
|
| 153 |
+
return false;
|
| 154 |
+
}
|
| 155 |
+
|
| 156 |
+
// Vector for storing logits.
|
| 157 |
+
// Allocated & filled by the engine.
|
| 158 |
+
std::vector<float> logits;
|
| 159 |
+
|
| 160 |
+
State::clear();
|
| 161 |
+
|
| 162 |
+
auto& engine = *_engine["primary"];
|
| 163 |
+
|
| 164 |
+
using FF = Engine::Feature::Flags;
|
| 165 |
+
if (engine.supports(FF::DYNAMIC_LOAD)) engine.load();
|
| 166 |
+
|
| 167 |
+
if (_n_past + tokens.size() > _ctx->size()) {
|
| 168 |
+
__WARN("Context limit exceeded ({} + {} > {})", _n_past, tokens.size(), _ctx->size());
|
| 169 |
+
callback("", Sentence::END);
|
| 170 |
+
return true;
|
| 171 |
+
}
|
| 172 |
+
|
| 173 |
+
if (!engine.process(tokens, logits, false))
|
| 174 |
+
return Dialog::abort("engine prompt processing failed", callback);
|
| 175 |
+
|
| 176 |
+
_n_prompt += tokens.size();
|
| 177 |
+
_n_past += tokens.size();
|
| 178 |
+
|
| 179 |
+
_prompt_len = _n_past;
|
| 180 |
+
|
| 181 |
+
if (!engine.updateKV(_n_past)) return Dialog::abort("context size exceeded", callback);
|
| 182 |
+
|
| 183 |
+
std::vector<std::vector<int32_t>> streams;
|
| 184 |
+
getTopK(logits, streams, _n_streams, _p_threshold, callback);
|
| 185 |
+
|
| 186 |
+
_n_generated += streams.size();
|
| 187 |
+
_kpis.prompt.update(start.elapsed_usec());
|
| 188 |
+
|
| 189 |
+
// Log latest KPIs
|
| 190 |
+
_env->logger().post(Logger::KPIS, kpis().dump(" "));
|
| 191 |
+
|
| 192 |
+
start.reset();
|
| 193 |
+
|
| 194 |
+
bool status = processFollowOnGeneration(streams, logits, callback);
|
| 195 |
+
|
| 196 |
+
_kpis.generate.update(start.elapsed_usec());
|
| 197 |
+
|
| 198 |
+
// Log latest KPIs in a single line
|
| 199 |
+
_env->logger().post(Logger::KPIS, kpis().dump(" "));
|
| 200 |
+
|
| 201 |
+
return status;
|
| 202 |
+
}
|
| 203 |
+
|
| 204 |
+
bool MultiStreamDialog::process(
|
| 205 |
+
std::vector<uint8_t>& embedding_vectors,
|
| 206 |
+
T2ECallback t2eCallback,
|
| 207 |
+
Dialog::Callback callback
|
| 208 |
+
) {
|
| 209 |
+
// Check for prev failures and bail out early
|
| 210 |
+
if (State::failed()) return false;
|
| 211 |
+
|
| 212 |
+
Timer start;
|
| 213 |
+
|
| 214 |
+
if(m_inputType != InputType::EMBEDDINGS) {
|
| 215 |
+
__ERROR("Input type for model is not embeddings.");
|
| 216 |
+
return false;
|
| 217 |
+
}
|
| 218 |
+
|
| 219 |
+
// Vector for storing logits.
|
| 220 |
+
// Allocated & filled by the engine.
|
| 221 |
+
std::vector<float> logits;
|
| 222 |
+
|
| 223 |
+
State::clear();
|
| 224 |
+
|
| 225 |
+
auto& sampler = *_sampler["primary"];
|
| 226 |
+
auto& engine = *_engine["primary"];
|
| 227 |
+
|
| 228 |
+
// Store the t2e callback for reference during follow-on generation.
|
| 229 |
+
m_t2eCallback = t2eCallback;
|
| 230 |
+
|
| 231 |
+
size_t embedBufSize = engine.getEmbeddingBufferSize();
|
| 232 |
+
|
| 233 |
+
{
|
| 234 |
+
std::vector<uint8_t> eosEmbedding(embedBufSize, 0.0);
|
| 235 |
+
if (m_t2eCallback) {
|
| 236 |
+
m_t2eCallback(_ctx->eos(), eosEmbedding.data(), embedBufSize);
|
| 237 |
+
}
|
| 238 |
+
// For non-autogenerative usecases (where t2eCallback is not supplied),
|
| 239 |
+
// the EOS vector is all zero. This is fine for models with proper
|
| 240 |
+
// attention masking support, but may degrade accuracy otherwise.
|
| 241 |
+
if (!engine.cacheEosEmbedding(eosEmbedding)) {
|
| 242 |
+
__DEBUG("Failed to set the eos token embedding.");
|
| 243 |
+
return false;
|
| 244 |
+
}
|
| 245 |
+
}
|
| 246 |
+
|
| 247 |
+
using FF = Engine::Feature::Flags;
|
| 248 |
+
if (engine.supports(FF::DYNAMIC_LOAD)) engine.load();
|
| 249 |
+
|
| 250 |
+
size_t curTokenCount = embedding_vectors.size() / embedBufSize;
|
| 251 |
+
if (_n_past + curTokenCount > _ctx->size()) {
|
| 252 |
+
__WARN("Context limit exceeded ({} + {} > {})", _n_past, curTokenCount, _ctx->size());
|
| 253 |
+
callback("", Sentence::END);
|
| 254 |
+
return true;
|
| 255 |
+
}
|
| 256 |
+
|
| 257 |
+
if (!engine.process(embedding_vectors, {}, logits))
|
| 258 |
+
return Dialog::abort("engine prompt processing failed", callback);
|
| 259 |
+
|
| 260 |
+
_n_prompt += curTokenCount;
|
| 261 |
+
_n_past += curTokenCount;
|
| 262 |
+
|
| 263 |
+
_prompt_len = _n_past;
|
| 264 |
+
|
| 265 |
+
if (!engine.updateKV(_n_past)) return Dialog::abort("context size exceeded", callback);
|
| 266 |
+
|
| 267 |
+
std::vector<std::vector<int32_t>> streams;
|
| 268 |
+
getTopK(logits, streams, _n_streams, _p_threshold, callback);
|
| 269 |
+
|
| 270 |
+
_n_generated += streams.size();
|
| 271 |
+
_kpis.prompt.update(start.elapsed_usec());
|
| 272 |
+
|
| 273 |
+
// Log latest KPIs
|
| 274 |
+
_env->logger().post(Logger::KPIS, kpis().dump(" "));
|
| 275 |
+
|
| 276 |
+
start.reset();
|
| 277 |
+
|
| 278 |
+
bool status = processFollowOnGeneration(streams, logits, callback);
|
| 279 |
+
|
| 280 |
+
_kpis.generate.update(start.elapsed_usec());
|
| 281 |
+
|
| 282 |
+
// Log latest KPIs in a single line
|
| 283 |
+
_env->logger().post(Logger::KPIS, kpis().dump(" "));
|
| 284 |
+
|
| 285 |
+
return status;
|
| 286 |
+
}
|
| 287 |
+
|
| 288 |
+
// Registrator instance
|
| 289 |
+
static OnLoad regy([]() {
|
| 290 |
+
Dialog::__register(
|
| 291 |
+
"multistream",
|
| 292 |
+
[](std::shared_ptr<Env> env, const std::string& name, const json& conf) {
|
| 293 |
+
return (Dialog*)new MultiStreamDialog(env, name, conf);
|
| 294 |
+
}
|
| 295 |
+
);
|
| 296 |
+
});
|
| 297 |
+
|
| 298 |
+
void needMultistreamDialog() {
    // Intentionally empty. Presumably referenced from elsewhere so the linker
    // keeps this translation unit (and its static registrator) -- TODO confirm
    // at the call site.
}
|
| 299 |
+
|
| 300 |
+
} // namespace qualla
|
Genie/Genie/src/qualla/dialogs/spec-dec.cpp
ADDED
|
@@ -0,0 +1,458 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//==============================================================================
|
| 2 |
+
//
|
| 3 |
+
// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
|
| 4 |
+
// All rights reserved.
|
| 5 |
+
// Confidential and Proprietary - Qualcomm Technologies, Inc.
|
| 6 |
+
//
|
| 7 |
+
//==============================================================================
|
| 8 |
+
|
| 9 |
+
#include <qualla/dialog.hpp>
|
| 10 |
+
#include <qualla/sampler.hpp>
|
| 11 |
+
#include <qualla/logger.hpp>
|
| 12 |
+
#include <qualla/detail/config.hpp>
|
| 13 |
+
#include <qualla/detail/timer.hpp>
|
| 14 |
+
#include <qualla/detail/onload.hpp>
|
| 15 |
+
#include <qualla/detail/sampler-utils.hpp>
|
| 16 |
+
#include <qualla/detail/basic-sampler.hpp>
|
| 17 |
+
|
| 18 |
+
#include <functional>
|
| 19 |
+
#include <fstream>
|
| 20 |
+
#include <string>
|
| 21 |
+
#include <unordered_map>
|
| 22 |
+
#include <filesystem>
|
| 23 |
+
#include <random>
|
| 24 |
+
#include <thread>
|
| 25 |
+
|
| 26 |
+
#include <fmt/format.h>
|
| 27 |
+
#include <fmt/ranges.h>
|
| 28 |
+
|
| 29 |
+
namespace fs = std::filesystem;
|
| 30 |
+
|
| 31 |
+
#define __INFO(__fmt, ...) _env->logger().post(Logger::INFO, fmt::format(__fmt, ##__VA_ARGS__))
|
| 32 |
+
#define __WARN(__fmt, ...) _env->logger().post(Logger::WARN, fmt::format(__fmt, ##__VA_ARGS__))
|
| 33 |
+
#define __ERROR(__fmt, ...) _env->logger().post(Logger::ERROR, fmt::format(__fmt, ##__VA_ARGS__))
|
| 34 |
+
#define __KPIS(__fmt, ...) \
|
| 35 |
+
_env->logger().post(Logger::KPIS, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
|
| 36 |
+
#define __DEBUG(__fmt, ...) \
|
| 37 |
+
_env->logger().post(Logger::DEBUG, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
|
| 38 |
+
#define __TRACE(__fmt, ...) \
|
| 39 |
+
_env->logger().post(Logger::TRACE, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
|
| 40 |
+
|
| 41 |
+
namespace qualla {
|
| 42 |
+
|
| 43 |
+
using qc = qualla::Config;
|
| 44 |
+
|
| 45 |
+
// Speculative-decoding dialog: a small "draft" engine proposes up to
// _draft_len tokens per iteration, which the larger "target" engine verifies
// in a single batched inference via rejection sampling.
class SpecDecDialog : public Dialog {
public:
    SpecDecDialog(std::shared_ptr<Env> env, const std::string& name, const json& conf);

    virtual bool process(std::vector<int32_t>& tokens, Dialog::Callback callback) override;

    // Not supported for this dialog type; always reports failure.
    virtual bool process(std::vector<int32_t>& tokens, DialogCallback callback) override {
        return false;
    }

private:
    int32_t _draft_len; // Number of draft tokens
    bool _parallel;     // Enable parallel processing (where possible)

    Sampler& _d_sampler; // Draft sampler
    Sampler& _t_sampler; // Target sampler

    // Token acceptor, called for each accepted token.
    // Returns true to continue, false to stop
    using Acceptor = std::function<bool(int32_t token)>;

    // Rejection sampling.
    // Returns number of accepted tokens
    size_t rejectionSampling(
        std::span<int32_t> tokens,
        std::span<float> target_logits,
        std::span<float> draft_probs,
        Acceptor accept
    );

    // Sample from the residual distribution max(p_target - p_draft, 0),
    // used when a draft token is rejected. src0_dst is modified in place.
    int32_t sampleFromModifiedDist(std::span<float> src0_dst, std::span<float> src1);
};
|
| 77 |
+
|
| 78 |
+
// Build a speculative-decoding dialog from config.
//
// NOTE(review): the initializer list dereferences _sampler["target"] before
// the body's contains("target") check runs — a config missing the "target"
// sampler would fault here rather than reach State::fatal. Confirm that the
// Dialog base guarantees a "target" entry, or restructure the members.
SpecDecDialog::SpecDecDialog(std::shared_ptr<Env> env, const std::string& name, const json& conf)
    : Dialog(env, name, conf),
      // Fall back to the target sampler when no dedicated draft sampler is configured
      _d_sampler(_sampler.contains("draft") ? *_sampler["draft"] : *_sampler["target"]),
      _t_sampler(*_sampler["target"]) {

    _draft_len = qc::optional<int32_t>(conf, "draft-len", 3);
    _parallel = qc::optional<bool>(conf, "parallel", false);

    // Check all underlying components for correct types and config.
    // If something is not right we set our error state that can be checked later.

    if (!_sampler.contains("target")) {
        State::fatal("\"target\" sampler not present in config!");
        return;
    }

    if (!_engine.contains("target")) {
        State::fatal("\"target\" engine not present in config!");
        return;
    }
    if (!_engine.contains("draft")) {
        State::fatal("\"draft\" engine not present in config!");
        return;
    }
}
|
| 103 |
+
|
| 104 |
+
// Sample a token from the "modified" (residual) distribution used after a
// draft token is rejected: p'(x) ∝ max(p_target(x) - p_draft(x), 0).
//
// src0_dst: target log-probs (Gumbel mode) or probs; overwritten in place
//           with the residual distribution.
// src1:     draft log-probs / probs over the same vocabulary.
// Returns the sampled token id.
int32_t SpecDecDialog::sampleFromModifiedDist(std::span<float> src0_dst, std::span<float> src1) {
    // [max(prob_target[x] - prob_draft[x], 0.f) for all x in vocab]
    size_t size = src0_dst.size();

    if (_t_sampler.gumbel()) {
        // Avoid going in the denormal zone.
        float tiny = 1.1754943508222875e-38f;

#pragma clang loop vectorize(enable) unroll_count(4)
        for (size_t i = 0U; i < size; i++) {
            float p_src0 = std::exp(src0_dst[i]);
            float p_src1 = std::exp(src1[i]);
            src0_dst[i] = std::log(std::max(tiny, p_src0 - p_src1));
        }

        // NOTE: The output logps_target is unnormalized since we use Gumbel trick.
        // If we use standard multinomial sampling, normalization should be added.

    } else {
        float sum = 0.0; // Unlikely to overflow (?)
#pragma clang loop vectorize(enable) unroll_count(4)
        for (size_t i = 0U; i < size; i++) {
            float num = std::max(0.f, src0_dst[i] - src1[i]);
            sum += num;
            src0_dst[i] = num;
        }

        // BUGFIX: when the target and draft distributions coincide pointwise,
        // the residual mass is zero and dividing by `sum` would fill src0_dst
        // with NaNs. In that degenerate case the two distributions are equal,
        // so sampling from the draft distribution (src1) is the correct result.
        if (sum <= 0.f) {
            if (_t_sampler.greedy()) return argmax(src1);
            return sampleFromProbs(src1, _t_sampler.rng());
        }

        // Normalize
#pragma clang loop vectorize(enable) unroll_count(4)
        for (size_t i = 0U; i < size; i++) {
            src0_dst[i] /= sum;
        }
    }

    if (_t_sampler.greedy()) return argmax(src0_dst);

    if (_t_sampler.gumbel()) return sampleUsingGumbelMax(src0_dst, _t_sampler.rng());

    // Skipping softmax since the probs are already normalized
    return sampleFromProbs(src0_dst, _t_sampler.rng());
}
|
| 144 |
+
|
| 145 |
+
// Verify draft-sampled tokens against the target model's logits.
//
// tokens:        the draft tokens, one per draft step.
// target_logits: target logits for every draft position PLUS one extra
//                position (enforced by the size assert below).
// draft_probs:   per-position draft distributions (probs, or unnormalized
//                log-probs in Gumbel mode), n_vocab values per token.
// accept:        invoked once per accepted token; returning false stops early.
//
// Returns the number of accepted tokens. Unless the acceptor stops early,
// this is always >= 1 because one bonus token is sampled at the end — from
// the extra target position when all drafts pass, or from the residual
// distribution on the first rejection.
size_t SpecDecDialog::rejectionSampling(
    std::span<int32_t> tokens,
    std::span<float> target_logits,
    std::span<float> draft_probs,
    Acceptor accept
) {
    const size_t n_vocab = _ctx->n_vocab();
    const size_t n_tok = tokens.size();

    assert(tokens.size() == draft_probs.size() / n_vocab);
    assert(target_logits.size() == draft_probs.size() + n_vocab);

    // Rejection sampling:
    // For each token in the n_tok tokens sampled from the draft model:
    // 1. Determine the probability of that token being accepted by the target model
    // 2. Accept the token with probability = prob_target[tok] / prob_draft[tok] (clamped to [0, 1])
    // 3. If the token is rejected, resample a new token from the following distribution:
    //    [max(prob_target[x] - prob_draft[x], 0.f) for all x in vocab]
    int32_t t_tok;
    size_t n_accepted = 0;

    std::vector<float> target_probs;

    for (int32_t i = 0; i < n_tok; i++) {
        int32_t d_tok = tokens[i];

        // Target logits for this draft position
        std::span<float> t_span = target_logits.subspan(i * n_vocab, n_vocab);

        if (_t_sampler.greedy()) {
            // Greedy: accept only an exact match with the target's own choice
            t_tok = _t_sampler.process(t_span);
            if (t_tok != d_tok) {
                // Reject
                break;
            }
        } else {
            target_probs.clear();
            t_tok = _t_sampler.process(t_span, target_probs, false); // only probs, no token

            // Acceptance threshold
            double threshold;
            float prob_draft = draft_probs[i * n_vocab + d_tok];
            float prob_target = target_probs[d_tok];

            if (_t_sampler.gumbel()) {
                // Gumbel mode stores log-probs, so the ratio becomes a difference
                threshold = std::exp(double(prob_target) - double(prob_draft));
            } else {
                threshold = double(prob_target) / double(prob_draft);
            }

            // Accept with probability min(1, p_target / p_draft)
            double r = sampleFromUniform(_t_sampler.rng());
            if (r > threshold) {
                // Reject
                break;
            }
        }
        // Accepted!
        ++n_accepted;
        if (!accept(d_tok)) return n_accepted;
    }

    // Sample an extra token either from the target distribution or the modified distribution
    if (n_accepted == n_tok) {
        t_tok = _t_sampler.process(target_logits.subspan(n_tok * n_vocab));
    } else if (!_t_sampler.greedy()) {
        // Resample from modified distribution. target_probs still holds the
        // distribution computed at the rejected position (i == n_accepted).
        t_tok = sampleFromModifiedDist(
            std::span{target_probs.data(), target_probs.size()},
            draft_probs.subspan(n_accepted * n_vocab, n_vocab)
        );
    } // for greedy, t_tok should be already valid from the loop above

    ++n_accepted;
    accept(t_tok);

    return n_accepted;
}
|
| 220 |
+
|
| 221 |
+
bool SpecDecDialog::process(std::vector<int32_t>& tokens, Dialog::Callback callback) {
|
| 222 |
+
|
| 223 |
+
// Check for prev failures and bail out early
|
| 224 |
+
if (State::failed()) return false;
|
| 225 |
+
|
| 226 |
+
Timer start;
|
| 227 |
+
|
| 228 |
+
const size_t n_vocab = _ctx->n_vocab();
|
| 229 |
+
|
| 230 |
+
// Vector for storing logits.
|
| 231 |
+
// Allocated & filled by the engine.
|
| 232 |
+
std::vector<float> t_logits;
|
| 233 |
+
std::vector<float> d_logits;
|
| 234 |
+
|
| 235 |
+
bool keep_generating = true;
|
| 236 |
+
|
| 237 |
+
// A buffer for tokens to be decoded (one at a time, per the Middleware's request)
|
| 238 |
+
std::vector<int32_t> decode_buf(1, 0);
|
| 239 |
+
|
| 240 |
+
// Decode new token.
|
| 241 |
+
// Return true to continue generation, and false otherwise
|
| 242 |
+
auto decode_token = [&](int32_t t) {
|
| 243 |
+
decode_buf[0] = _last_tok = t;
|
| 244 |
+
|
| 245 |
+
if (_ctx->is_eos(t)) {
|
| 246 |
+
keep_generating = false;
|
| 247 |
+
callback("", Sentence::END);
|
| 248 |
+
} else {
|
| 249 |
+
keep_generating = callback(_tokenizer->decode(decode_buf), Sentence::CONTINUE);
|
| 250 |
+
}
|
| 251 |
+
|
| 252 |
+
return keep_generating;
|
| 253 |
+
};
|
| 254 |
+
|
| 255 |
+
State::clear();
|
| 256 |
+
|
| 257 |
+
auto& t_engine = *_engine["target"];
|
| 258 |
+
auto& d_engine = *_engine["draft"];
|
| 259 |
+
|
| 260 |
+
if (_n_past + tokens.size() > _ctx->size()) {
|
| 261 |
+
__WARN("Context limit exceeded ({} + {} > {})", _n_past, tokens.size(), _ctx->size());
|
| 262 |
+
callback("", Sentence::END);
|
| 263 |
+
return true;
|
| 264 |
+
}
|
| 265 |
+
|
| 266 |
+
// Step 0: Process the prompt both on the target and draft models.
|
| 267 |
+
bool d_pmpt, t_pmpt;
|
| 268 |
+
if (_parallel) {
|
| 269 |
+
std::thread dt([&]() { d_pmpt = d_engine.process(tokens, d_logits, false); });
|
| 270 |
+
std::thread tt([&]() { t_pmpt = t_engine.process(tokens, t_logits, false); });
|
| 271 |
+
dt.join();
|
| 272 |
+
tt.join();
|
| 273 |
+
} else {
|
| 274 |
+
d_pmpt = d_engine.process(tokens, d_logits, false);
|
| 275 |
+
t_pmpt = t_engine.process(tokens, t_logits, false);
|
| 276 |
+
}
|
| 277 |
+
|
| 278 |
+
if (!d_pmpt) return Dialog::abort("draft engine prompt processing failed", callback);
|
| 279 |
+
if (!t_pmpt) return Dialog::abort("target engine prompt processing failed", callback);
|
| 280 |
+
|
| 281 |
+
// KV state Update
|
| 282 |
+
_n_prompt += tokens.size();
|
| 283 |
+
_n_past += tokens.size();
|
| 284 |
+
|
| 285 |
+
if (!t_engine.updateKV(_n_past)) return Dialog::abort("target context size exceeded", callback);
|
| 286 |
+
if (!d_engine.updateKV(_n_past)) return Dialog::abort("draft context size exceeded", callback);
|
| 287 |
+
|
| 288 |
+
// Sample one token from the target.
|
| 289 |
+
_last_tok = _t_sampler.process(t_logits);
|
| 290 |
+
|
| 291 |
+
_kpis.prompt.update(start.elapsed_usec());
|
| 292 |
+
|
| 293 |
+
// Log latest KPIs
|
| 294 |
+
_env->logger().post(Logger::KPIS, kpis().dump(" "));
|
| 295 |
+
|
| 296 |
+
if (!decode_token(_last_tok)) return true;
|
| 297 |
+
|
| 298 |
+
// Done with the prompt, start generating
|
| 299 |
+
start.reset();
|
| 300 |
+
State::busy(true);
|
| 301 |
+
|
| 302 |
+
// Buffers for all the tokens that need to be considered for each iteration
|
| 303 |
+
std::vector<int32_t> toks_to_target(_draft_len + 1);
|
| 304 |
+
std::vector<int32_t> toks_to_draft(2);
|
| 305 |
+
|
| 306 |
+
// Buffer for all the probability distributions from the draft sampler
|
| 307 |
+
std::vector<float> d_probs(n_vocab * _draft_len);
|
| 308 |
+
|
| 309 |
+
toks_to_target.assign(1, _last_tok);
|
| 310 |
+
toks_to_draft.assign(1, _last_tok);
|
| 311 |
+
|
| 312 |
+
// For keeping track of the number of tokens that were accepted in each iteration.
|
| 313 |
+
std::vector<int32_t> n_accepted_counts(_draft_len + 1, 0);
|
| 314 |
+
|
| 315 |
+
// Draft n_past, either in sync with n_past or one token behind (accepted-all)
|
| 316 |
+
size_t d_n_past = _n_past;
|
| 317 |
+
|
| 318 |
+
while (!State::canceled() && keep_generating) {
|
| 319 |
+
// Step 1: Use draft model to decode draft_len (aka gamma) tokens, and accumulate probabilities
|
| 320 |
+
d_probs.clear();
|
| 321 |
+
|
| 322 |
+
for (int32_t i = 0; i < _draft_len; i++) {
|
| 323 |
+
if (d_n_past + toks_to_draft.size() > _ctx->size()) {
|
| 324 |
+
__WARN("Context limit exceeded ({} + {} > {})",
|
| 325 |
+
d_n_past,
|
| 326 |
+
toks_to_target.size(),
|
| 327 |
+
_ctx->size());
|
| 328 |
+
_kpis.generate.update(start.elapsed_usec());
|
| 329 |
+
|
| 330 |
+
// Log latest KPIs in a single line
|
| 331 |
+
_env->logger().post(Logger::KPIS, kpis().dump(" "));
|
| 332 |
+
callback("", Sentence::END);
|
| 333 |
+
return true;
|
| 334 |
+
}
|
| 335 |
+
|
| 336 |
+
if (!d_engine.process(toks_to_draft, d_logits))
|
| 337 |
+
return Dialog::abort("draft engine gen processing failed", callback);
|
| 338 |
+
|
| 339 |
+
d_n_past += toks_to_draft.size();
|
| 340 |
+
|
| 341 |
+
if (!d_engine.updateKV(d_n_past))
|
| 342 |
+
return Dialog::abort("draft context size exceeded", callback);
|
| 343 |
+
|
| 344 |
+
int32_t token = _d_sampler.process(d_logits, d_probs);
|
| 345 |
+
toks_to_draft.assign(1, token);
|
| 346 |
+
toks_to_target.push_back(token);
|
| 347 |
+
|
| 348 |
+
if (_ctx->is_eos(token)) break;
|
| 349 |
+
}
|
| 350 |
+
|
| 351 |
+
// Step 2: run the target model on the draft tokens
|
| 352 |
+
if (_n_past + toks_to_target.size() > _ctx->size()) {
|
| 353 |
+
__WARN("Context limit exceeded ({} + {} > {})",
|
| 354 |
+
_n_past,
|
| 355 |
+
toks_to_target.size(),
|
| 356 |
+
_ctx->size());
|
| 357 |
+
callback("", Sentence::END);
|
| 358 |
+
_kpis.generate.update(start.elapsed_usec());
|
| 359 |
+
|
| 360 |
+
// Log latest KPIs in a single line
|
| 361 |
+
_env->logger().post(Logger::KPIS, kpis().dump(" "));
|
| 362 |
+
return true;
|
| 363 |
+
}
|
| 364 |
+
|
| 365 |
+
std::vector<int32_t> attention_map(toks_to_target.size());
|
| 366 |
+
std::iota(attention_map.begin(), attention_map.end(), -1);
|
| 367 |
+
size_t n_tok_t =
|
| 368 |
+
t_engine.process(toks_to_target, attention_map, t_logits, true /* all logits */);
|
| 369 |
+
if (n_tok_t != toks_to_target.size())
|
| 370 |
+
return Dialog::abort("target engine gen processing failed", callback);
|
| 371 |
+
|
| 372 |
+
// Step 3: accept or reject draft tokens
|
| 373 |
+
size_t n_accepted = rejectionSampling(
|
| 374 |
+
std::span{toks_to_target.data(),toks_to_target.size()}.subspan(1),
|
| 375 |
+
std::span{t_logits.data(),t_logits.size()}, std::span{d_probs.data(),d_probs.size()}, decode_token
|
| 376 |
+
);
|
| 377 |
+
|
| 378 |
+
_n_generated += n_accepted;
|
| 379 |
+
_n_past += n_accepted;
|
| 380 |
+
|
| 381 |
+
// Update stats
|
| 382 |
+
n_accepted_counts[n_accepted - 1]++;
|
| 383 |
+
|
| 384 |
+
// Accepted all?
|
| 385 |
+
if (n_accepted == _draft_len + 1) {
|
| 386 |
+
// Grab the last 2 tokens
|
| 387 |
+
toks_to_draft.assign({toks_to_target[_draft_len], _last_tok});
|
| 388 |
+
d_n_past = _n_past - 1;
|
| 389 |
+
} else {
|
| 390 |
+
// Grab only the last token
|
| 391 |
+
toks_to_draft.assign(1, _last_tok);
|
| 392 |
+
d_n_past = _n_past;
|
| 393 |
+
}
|
| 394 |
+
|
| 395 |
+
toks_to_target.assign(1, _last_tok);
|
| 396 |
+
|
| 397 |
+
__DEBUG("spec-dec: draft_len {} n_generated {} n_accepted {} n_past {}",
|
| 398 |
+
_draft_len,
|
| 399 |
+
_n_generated,
|
| 400 |
+
n_accepted,
|
| 401 |
+
_n_past);
|
| 402 |
+
|
| 403 |
+
std::vector<bool> selected(attention_map.size(), false);
|
| 404 |
+
selected[0] = true; // first token is selected always
|
| 405 |
+
auto last_sel = 0;
|
| 406 |
+
for (int i = n_accepted - 1; i != 0; i = attention_map[i]) {
|
| 407 |
+
selected[i] = true;
|
| 408 |
+
last_sel = i > last_sel ? i : last_sel;
|
| 409 |
+
}
|
| 410 |
+
selected.resize(last_sel + 1); // trim away rejected tokens
|
| 411 |
+
|
| 412 |
+
// Step 4: commit accepted tokens to kv-caches
|
| 413 |
+
if (!t_engine.updateKV(_n_past, selected))
|
| 414 |
+
return Dialog::abort("target context size exceeded", callback);
|
| 415 |
+
if (!d_engine.updateKV(d_n_past))
|
| 416 |
+
return Dialog::abort("draft context size exceeded", callback);
|
| 417 |
+
}
|
| 418 |
+
|
| 419 |
+
if (d_n_past != _n_past) {
|
| 420 |
+
// The draft engine needs to process one last token to catch up
|
| 421 |
+
toks_to_draft.resize(1);
|
| 422 |
+
if (!d_engine.process(toks_to_draft))
|
| 423 |
+
return Dialog::abort("draft engine gen processing failed", callback);
|
| 424 |
+
if (!d_engine.updateKV(_n_past))
|
| 425 |
+
return Dialog::abort("draft context size exceeded", callback);
|
| 426 |
+
}
|
| 427 |
+
|
| 428 |
+
State::busy(false);
|
| 429 |
+
|
| 430 |
+
_kpis.generate.update(start.elapsed_usec());
|
| 431 |
+
|
| 432 |
+
// Log latest KPIs in a single line
|
| 433 |
+
_env->logger().post(Logger::KPIS, kpis().dump(" "));
|
| 434 |
+
__KPIS("spec-dec: accepted counts: {}", n_accepted_counts);
|
| 435 |
+
|
| 436 |
+
return true;
|
| 437 |
+
}
|
| 438 |
+
|
| 439 |
+
// Registrator instance
|
| 440 |
+
static OnLoad regy([]() {
|
| 441 |
+
Dialog::__register(
|
| 442 |
+
"spec-dec",
|
| 443 |
+
[](std::shared_ptr<Env> env, const std::string& name, const json& conf) {
|
| 444 |
+
return (Dialog*)new SpecDecDialog(env, name, conf);
|
| 445 |
+
}
|
| 446 |
+
);
|
| 447 |
+
});
|
| 448 |
+
|
| 449 |
+
// Register spec-dec sampler for compatibility
|
| 450 |
+
static OnLoad sampler_regy([]() {
|
| 451 |
+
Sampler::__register("spec-dec", [](Context& ctx, const json& conf) {
|
| 452 |
+
return (Sampler*)new BasicSampler(ctx, conf);
|
| 453 |
+
});
|
| 454 |
+
});
|
| 455 |
+
|
| 456 |
+
void needSpdDialog() {}
|
| 457 |
+
|
| 458 |
+
} // namespace qualla
|
Genie/Genie/src/qualla/dialogs/ssd-q1.cpp
ADDED
|
@@ -0,0 +1,1046 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//==============================================================================
|
| 2 |
+
//
|
| 3 |
+
// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
|
| 4 |
+
// All rights reserved.
|
| 5 |
+
// Confidential and Proprietary - Qualcomm Technologies, Inc.
|
| 6 |
+
//
|
| 7 |
+
//==============================================================================
|
| 8 |
+
|
| 9 |
+
#include <qualla/context.hpp>
#include <qualla/dialog.hpp>
#include <qualla/sampler.hpp>
#include <qualla/logger.hpp>
#include <qualla/detail/config.hpp>
#include <qualla/detail/json.hpp>
#include <qualla/detail/timer.hpp>
#include <qualla/detail/onload.hpp>
#include <qualla/detail/sampler-utils.hpp>
#include <qualla/detail/basic-sampler.hpp>

#include <algorithm>
#include <cstdint>
#include <filesystem>
#include <fstream>
#include <functional>
#include <numeric>
#include <random>
#include <span>
#include <string>
#include <tuple>
#include <unordered_map>
#include <vector>

#include <fmt/format.h>
#include <fmt/ranges.h>
|
| 29 |
+
|
| 30 |
+
namespace fs = std::filesystem;
|
| 31 |
+
|
| 32 |
+
#define __INFO(__fmt, ...) _env->logger().post(Logger::INFO, fmt::format(__fmt, ##__VA_ARGS__))
|
| 33 |
+
#define __WARN(__fmt, ...) _env->logger().post(Logger::WARN, fmt::format(__fmt, ##__VA_ARGS__))
|
| 34 |
+
#define __ERROR(__fmt, ...) _env->logger().post(Logger::ERROR, fmt::format(__fmt, ##__VA_ARGS__))
|
| 35 |
+
#define __KPIS(__fmt, ...) \
|
| 36 |
+
_env->logger().post(Logger::KPIS, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
|
| 37 |
+
#define __DEBUG(__fmt, ...) \
|
| 38 |
+
_env->logger().post(Logger::DEBUG, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
|
| 39 |
+
#define __TRACE(__fmt, ...) \
|
| 40 |
+
_env->logger().post(Logger::TRACE, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
|
| 41 |
+
|
| 42 |
+
namespace qualla {
|
| 43 |
+
|
| 44 |
+
using qc = qualla::Config;
|
| 45 |
+
using Logits = std::span<float>;
|
| 46 |
+
|
| 47 |
+
// Self-speculative decoding (SSD) dialog: a single "primary" model both
// drafts (via reserved forecast token ids placed past the real vocabulary)
// and verifies, using a pre-computed KV prefix and a branching sample tree.
class SelfSpecDecDialog : public Dialog {
    enum { VERSION = 1 }; // Highest ssd-version this implementation understands

public:
    SelfSpecDecDialog(std::shared_ptr<Env> env, const std::string& name, const json& conf);

    virtual bool process(std::vector<int32_t>& tokens, Dialog::Callback callback) override;
    virtual bool process(std::vector<uint8_t>& embedding_vectors, Dialog::T2ECallback t2eCallback, Dialog::Callback callback) override;
    virtual void reset() override;

    // Not supported for this dialog type; always reports failure.
    virtual bool process(std::vector<int32_t>& tokens, DialogCallback callback) override {
        return false;
    }

    virtual bool save(const std::string& name) override;
    virtual bool restore(const std::string& name) override;

private:
    Sampler& _t_sampler; // The "primary" sampler from config

    int32_t _vocab; // Vocabulary size, cached from the context

    // Name of the saved KV cache restored as the forecast prefix
    std::string _kv_prefix_name{"forecast-prefix"};

    // AR8
    size_t _draft{1};                 // Draft depth (implied by _branches.size())
    std::vector<size_t> _branches{3}; // Branch factor for each tree level

    size_t _forecast_prefix{16};          // Expected length of the restored KV prefix
    size_t _forecast_token_offset{32000}; // First forecast token id (reset to _vocab in ctor)
    size_t _forecast_token_count{4};

    // Multistream parameters
    int32_t _n_streams;
    float _p_threshold;

    InputType m_inputType{InputType::UNKNOWN};

    bool processFollowOnGeneration(std::vector<int32_t>& tokens, std::vector<float>& logits, Dialog::Callback callback);
    // Multistream
    bool processFollowOnGeneration(std::vector<std::vector<int32_t>>& streams, std::vector<float>& logits, Dialog::Callback callback);

    /*
      Helper function for combining masks for SSD multistream.

      @param mask The attention mask to be tiled
      @param streamIndices Indices of streams. The tiling count is equal to the size of this vector.
      @param pastMap A vector of stream indices for masking all past tokens after the prompt.
      @param prefixOffset Offset where KV prefix masking begins in each tile.
      @param finalMask A mask that combines all of the independent masks such that
                       they can be executed in the same inference.
    */
    void tileAttentionMask(const std::vector<int32_t>& mask, const std::vector<size_t> streamIndices, const std::vector<size_t>& pastMap, const size_t prefixOffset, std::vector<int32_t>& finalMask);

    std::vector<int32_t> gen_attention_map() const;
    auto get_len_flat_sample_tree() const;
    auto gen_forecast_tokens(int repeat) const;

    // Sampling and verification
    std::vector<int32_t> build_sample_tree(
        int32_t last_token,
        Logits logits,
        const std::vector<int32_t>& indices
    );
    std::tuple<std::vector<int32_t>, std::vector<int32_t>> verify_and_select_longest(
        std::span<int32_t> sample_tree,
        Logits logits
    );
    // Top-`count` token ids from the logits row at `index` — draft candidates.
    std::vector<int32_t> sample_to_draft(Logits logits, size_t index, size_t count) {
        const auto thislogit = logits.subspan(index * _vocab, _vocab);
        IndexedLogits logit(thislogit, _t_sampler.rng());
        logit.topK(count);
        return logit.indices;
    }
    // Sample (argmax when greedy) a verification token from logits row `index`.
    int32_t sample_to_verify(Logits logits, size_t index) {
        const auto thislogit = logits.subspan(index * _vocab, _vocab);
        if (_t_sampler.greedy()) {
            return argmax(thislogit);
        }
        auto token = _t_sampler.process(thislogit);
        return token;
    }
};
|
| 130 |
+
|
| 131 |
+
// Configure SSD from `conf` and load the forecast KV prefix.
// Throws std::runtime_error if the restored prefix length does not match the
// configured "forecast-prefix" expectation.
SelfSpecDecDialog::SelfSpecDecDialog(
    std::shared_ptr<Env> env,
    const std::string& name,
    const json& conf
)
    : Dialog(env, name, conf), _t_sampler(*_sampler["primary"]) {

    auto ssd_version = qc::optional<int>(conf, "ssd-version", 0);
    if (ssd_version > SelfSpecDecDialog::VERSION) __WARN("newer ssd-version in config!");

    _vocab = _ctx->n_vocab();

    _branches = qc::optional(conf, "branches", _branches);
    // Draft depth is implied by the number of configured branch factors
    _draft = _branches.size();

    _forecast_prefix = qc::optional(conf, "forecast-prefix", _forecast_prefix);
    _forecast_token_count = qc::optional(conf, "forecast-token-count", _forecast_token_count);
    // Forecast token ids start immediately past the real vocabulary
    _forecast_token_offset = _vocab;

    _kv_prefix_name = qc::optional(conf, "forecast-prefix-name", _kv_prefix_name);

    _n_streams = qc::optional<int32_t>(conf, "n-streams", 1);
    _p_threshold = qc::optional<float>(conf, "p-threshold", 0.0);

    if (!_engine.contains("primary")) {
        State::fatal("\"primary\" engine not present in config!");
        return;
    }

    // Get Input Type from the engine
    m_inputType = _engine["primary"]->getInputType();
    // Load KV prefix
    Timer timer;
    size_t n_restored_prefix = _engine["primary"]->restore(_kv_prefix_name);
    if (n_restored_prefix != _forecast_prefix) {
        // clang-format off
        throw std::runtime_error( fmt::format( "SSD : Loaded {} KV$ from {} but expected {} KV$",
                                  n_restored_prefix, _kv_prefix_name, _forecast_prefix ) );
        // clang-format on
    }
    // The restored prefix counts as past context for all subsequent processing
    _n_past = _forecast_prefix;
    _kpis.restore.update(timer.elapsed_usec());
}
|
| 174 |
+
|
| 175 |
+
auto SelfSpecDecDialog::get_len_flat_sample_tree() const {
|
| 176 |
+
size_t len_flat_sample_tree = 1;
|
| 177 |
+
size_t last_tokens = 1;
|
| 178 |
+
for (int i = 0; i < _draft; ++i) {
|
| 179 |
+
len_flat_sample_tree += last_tokens * _branches[i];
|
| 180 |
+
last_tokens = last_tokens * _branches[i];
|
| 181 |
+
}
|
| 182 |
+
return len_flat_sample_tree;
|
| 183 |
+
}
|
| 184 |
+
|
| 185 |
+
auto SelfSpecDecDialog::gen_forecast_tokens(int repeat) const {
|
| 186 |
+
std::vector<int32_t> forecast_tokens(_draft, 0);
|
| 187 |
+
std::iota(forecast_tokens.begin(), forecast_tokens.end(), _forecast_token_offset);
|
| 188 |
+
|
| 189 |
+
std::vector<int32_t> ret;
|
| 190 |
+
for (auto i = 0; i < repeat; ++i)
|
| 191 |
+
ret.insert(ret.end(), forecast_tokens.begin(), forecast_tokens.end());
|
| 192 |
+
return ret;
|
| 193 |
+
}
|
| 194 |
+
|
| 195 |
+
// Builds the per-token attention map for one flattened sample tree plus its
// forecast tokens. Entry j holds the index of token j's parent within the
// input window, or -1 for the root (no in-window parent).
// Layout: [0, len_flat_sample_tree) is the verify tree (level by level),
// followed by len_flat_sample_tree * _draft forecast entries — a chain of
// _draft forecast tokens hanging off every verify-tree node.
std::vector<int32_t> SelfSpecDecDialog::gen_attention_map() const {
    auto len_flat_sample_tree = get_len_flat_sample_tree();
    // Verify-tree nodes + one _draft-long forecast chain per node; -1 = root.
    std::vector<int32_t> attention_map(len_flat_sample_tree + len_flat_sample_tree * _draft, -1);

    // Recursive lambda (self-passing idiom): nodes of level `level` occupy
    // [parent_begin, parent_end); each parent gets _branches[level]
    // consecutive children appended starting at parent_end.
    auto build_verify_tree = [&attention_map,
                              this](auto self, int parent_begin, int parent_end, int level) {
        if (level == _draft) return;
        auto current = parent_end;
        for (auto parent = parent_begin; parent < parent_end; parent += 1) {
            for (auto child = current; child < current + _branches[level]; child += 1)
                attention_map[child] = parent;
            current += _branches[level];
        }
        // Children of this level become the parents of the next one.
        self(self, parent_end, current, level + 1);
    };

    // Each verify-tree node in [parent_begin, parent_end) gets a linear chain
    // of _draft forecast tokens: first chain element points at the node, each
    // subsequent element points at its predecessor in the chain.
    auto build_forecast_tree = [&attention_map, this](int parent_begin, int parent_end) {
        auto current = parent_end;
        for (auto parent = parent_begin; parent < parent_end; parent += 1) {
            for (auto child = current, current_parent = parent; child < current + _draft;
                 child += 1) {
                attention_map[child] = current_parent;
                current_parent = child;  // chain: next forecast attends to this one
            }
            current += _draft;
        }
    };

    // Root is node 0; expand the verify tree, then hang forecast chains off
    // every verify-tree node.
    build_verify_tree(build_verify_tree, 0, 1, 0);
    build_forecast_tree(0, len_flat_sample_tree);
    return attention_map;
}
|
| 227 |
+
|
| 228 |
+
std::vector<int32_t> SelfSpecDecDialog::build_sample_tree(
|
| 229 |
+
int32_t last_token,
|
| 230 |
+
Logits logits,
|
| 231 |
+
const std::vector<int32_t>& indices
|
| 232 |
+
) {
|
| 233 |
+
std::vector<int32_t> tree = {last_token};
|
| 234 |
+
for (auto draft = 0, repeat = 1; draft < _draft; ++draft) {
|
| 235 |
+
auto samples = sample_to_draft(logits, indices[draft], _branches[draft]);
|
| 236 |
+
for (auto i = 0; i < repeat; ++i) {
|
| 237 |
+
tree.insert(tree.end(), samples.begin(), samples.end());
|
| 238 |
+
}
|
| 239 |
+
repeat *= _branches[draft];
|
| 240 |
+
}
|
| 241 |
+
return tree;
|
| 242 |
+
}
|
| 243 |
+
|
| 244 |
+
// Verifies the speculative sample tree against the target-model logits and
// returns the longest accepted path.
//
// @param sample_tree  flattened tree of speculated tokens (level order)
// @param logits       target logits, one row per tree node
// @return {accepted tokens (incl. the fresh sample after the last match),
//          node indices of the accepted path within sample_tree}
std::tuple<std::vector<int32_t>, std::vector<int32_t>> SelfSpecDecDialog::verify_and_select_longest(
    std::span<int32_t> sample_tree,
    Logits logits
) {
    // Candidate paths; start with the root's verified sample at node 0.
    std::vector<std::vector<int32_t>> accepted_all = {{sample_to_verify(logits, 0)}};
    std::vector<std::vector<int32_t>> node_ids_all = {{0}};

    // draft_offset[i] = index of the first node of draft level i in the
    // flattened tree (level widths multiply by _branches[i] each level).
    std::vector<int32_t> draft_offset(_draft, 0);
    draft_offset[0] = 1;
    for (int32_t i = 1, draft_count = _branches[0]; i < _draft; ++i) {
        draft_offset[i] = draft_offset[i - 1] + draft_count;
        draft_count = draft_count * _branches[i];
    }

    size_t longest = 0, longest_size = 1;  // index/length of the best path so far
    // Depth-first walk: a child node is accepted when the target-model token
    // for its parent (accepted.back()) equals the speculated token at that
    // node. Accepted prefixes are copied so every partial path is retained.
    auto verify_recursive = [&](auto self,
                                std::vector<int32_t> accepted,
                                std::vector<int32_t> node_ids,
                                int draft,
                                int offset_in_draft) -> void {
        auto target = accepted.back();
        auto branch_base = draft_offset[draft] + offset_in_draft;
        for (auto branch = 0; branch < _branches[draft]; ++branch) {
            auto ndx_node = branch_base + branch;
            // EOS never extends a path; otherwise accept on exact match.
            if (!_ctx->is_eos(target) && target == sample_tree[ndx_node]) {
                // The freshly verified sample at this node becomes the next
                // comparison target one level deeper.
                auto sample_accepted = sample_to_verify(logits, ndx_node);
                accepted_all.push_back(accepted);
                accepted_all.back().push_back(sample_accepted);
                node_ids_all.push_back(node_ids);
                node_ids_all.back().push_back(ndx_node);
                if (node_ids_all.back().size() > longest_size) {
                    longest = node_ids_all.size() - 1;
                    longest_size = node_ids_all.back().size();
                }
                if (draft + 1 < _draft)
                    self(self,
                         accepted_all.back(),
                         node_ids_all.back(),
                         draft + 1,
                         // Children of branch b of this level start b
                         // child-groups into the next level.
                         (offset_in_draft + branch) * _branches[draft + 1]);
            }
        }
    };
    verify_recursive(verify_recursive, accepted_all.back(), node_ids_all.back(), 0, 0);
    return {accepted_all[longest], node_ids_all[longest]};
}
|
| 290 |
+
|
| 291 |
+
// Expands the single-tree attention map into a full 2-D binary attention mask
// for a batch of streams: one (maskSize x rowLength) tile per stream placed
// on the block diagonal, with per-stream visibility over past tokens.
//
// @param mask           per-token parent index within one tile (-1 = root)
// @param streamIndices  stream id for each tile, in batch order
// @param pastMap        stream id owning each past (post-prompt) KV slot
// @param prefixOffset   rows below this index do not attend the kv-prefix
// @param tiledMask      output; resized to numTokens * rowLength, row-major
void SelfSpecDecDialog::tileAttentionMask(const std::vector<int32_t>& mask, const std::vector<size_t> streamIndices, const std::vector<size_t>& pastMap, const size_t prefixOffset, std::vector<int32_t>& tiledMask) {

    // NOTE(review): sampleTreeLen is computed but never used below.
    const size_t sampleTreeLen = get_len_flat_sample_tree();
    const size_t pastMapLen = pastMap.size();
    const int posVal = 1, negVal = 0;

    const size_t maskSize = mask.size();
    const size_t numTokens = maskSize * streamIndices.size();

    // Each row covers: [forecast prefix | prompt | past tokens | new tiles].
    const size_t rowLength = _n_past + numTokens;
    tiledMask.resize(numTokens * rowLength);

    for (int maskIdx = 0; maskIdx < streamIndices.size(); maskIdx++) {
        // Number of rows to skip to reach the current tile.
        const size_t tileOffset = maskIdx * maskSize;
        // Top-left corner of this stream's diagonal tile (column _n_past + tileOffset).
        int32_t* const tileStart = &tiledMask[tileOffset*rowLength + tileOffset + _n_past];
        for (int i = 0; i < maskSize; i++) {
            // Pointer to the start of row i of the current mask
            int32_t* rowPtr = &tiledMask[(tileOffset + i)*rowLength];
            // Skip kv-prefix attention for rows without speculative tokens.
            const int prefixFillVal = (i < prefixOffset) ? negVal : posVal;
            std::fill_n(rowPtr, _forecast_prefix, prefixFillVal);
            rowPtr += _forecast_prefix;
            // Always attend to prompt.
            std::fill_n(rowPtr, _n_prompt, posVal);
            rowPtr += _n_prompt;

            // Fill in the past valid tokens for this stream.
            // Only KV slots owned by this row's stream are visible.
            for (const size_t& pastIdx : pastMap) {
                *rowPtr = (pastIdx == streamIndices[maskIdx]) ? posVal : negVal;
                rowPtr++;
            }

            // Clear the rest of the row. It will mostly consist of 0's.
            std::fill_n(rowPtr, rowLength - _n_prompt - _forecast_prefix - pastMapLen, negVal);
            // Move to the correct tile.
            rowPtr += tileOffset;
            // Translate the mask.
            // Inherit the parent row's in-tile visibility (columns 0..tokenId),
            // relying on parent rows (smaller indices) being filled first.
            const auto tokenId = mask[i];
            if (tokenId > -1) {
                std::copy_n(tileStart + (tokenId * rowLength), tokenId + 1, rowPtr);
            }
            // Always attend to self.
            rowPtr[i] = posVal;
        }
    }
}
|
| 338 |
+
|
| 339 |
+
// Takes a vector of tokens and produces a vector of embeddings via the provided T2E callback.
|
| 340 |
+
static inline void convertTokensToEmbeddings(std::vector<int32_t>& tokens,
|
| 341 |
+
std::vector<uint8_t>& embeddings,
|
| 342 |
+
size_t embeddingBufferSize,
|
| 343 |
+
Dialog::T2ECallback t2eCallback) {
|
| 344 |
+
for(auto &token : tokens){
|
| 345 |
+
std::vector<uint8_t> embedding(embeddingBufferSize,0);
|
| 346 |
+
t2eCallback(token, embedding.data(), embeddingBufferSize);
|
| 347 |
+
embeddings.insert(embeddings.end(), embedding.begin(), embedding.end());
|
| 348 |
+
}
|
| 349 |
+
}
|
| 350 |
+
|
| 351 |
+
// Single-stream follow-on generation loop for self-speculative decoding.
// Repeatedly: append forecast tokens, run one engine pass, verify the
// speculated tree against the new logits, commit the longest accepted path
// to the KV cache, emit the accepted tokens via `callback`, and build the
// next sample tree. Runs until EOS, cancellation, callback refusal, or
// context exhaustion.
//
// @param tokens   in/out scratch token buffer (seeded by prompt processing)
// @param logits   in/out logits buffer from the initial inference
// @param callback receives decoded text; returning false stops generation
// @return true on normal completion, false via Dialog::abort on failure
bool SelfSpecDecDialog::processFollowOnGeneration(std::vector<int32_t>& tokens, std::vector<float>& logits, Dialog::Callback callback){

    // Handles the printing of the subsequent generated tokens
    bool keep_generating = true;
    // NOTE(review): `context` is unused in this function.
    const size_t context = _ctx->n_ctx();

    std::vector<int32_t> decode_buf(
        1, 0
    ); // A buffer for tokens to be decoded (one at a time, per the Middleware's request)
    // Emits one token to the callback; flips keep_generating off on EOS or
    // when the callback declines further output.
    auto decode_token = [&](int32_t t) {
        if (!keep_generating) return;
        // Decode new token.
        // Return true to continue generation, and false otherwise
        decode_buf[0] = _last_tok = t;
        ++_n_generated;
        if (_ctx->is_eos(t)) {
            keep_generating = false;
            callback("", Sentence::END);
        } else {
            keep_generating = callback(_tokenizer->decode(decode_buf), Sentence::CONTINUE);
        }
        return;
    };
    // set decode_buf from prompt processing
    decode_buf[0] = _last_tok;

    auto& engine = *_engine["primary"];

    // Commits the accepted token positions to the engine KV cache.
    auto update_kv = [&engine, &callback, this](size_t past, const std::vector<bool>& selected) {
        if (!engine.updateKV(past, selected))
            return Dialog::abort("context size exceeded", callback);
        return true;
    };


    // prepare the next inference
    std::vector<int32_t> indices(_draft, 0);
    std::iota(indices.begin(), indices.end(), 1);
    tokens = build_sample_tree(sample_to_verify(std::span{logits.data(),logits.size()}, 0), std::span{logits.data(),logits.size()}, indices);
    decode_token(tokens[0]);

    // Prepare constant options for next inferences
    const auto len_flat_sample_tree = get_len_flat_sample_tree();
    const auto forecast_tokens = gen_forecast_tokens(len_flat_sample_tree);
    const auto attention_map = gen_attention_map();

    engine.set({{"kv-prefix-offset", len_flat_sample_tree}});

    // accepted_counts[k] = iterations that accepted exactly k+1 tokens (KPI).
    std::vector<int32_t> accepted_counts(_draft + 1, 0);
    std::vector<bool> selected(attention_map.size(), false);

    while (!State::canceled() && keep_generating) {

        // Append forecast tokens
        tokens.insert(tokens.end(), forecast_tokens.begin(), forecast_tokens.end());

        if (_n_past + tokens.size() > _ctx->size()) {
            __WARN("Context limit exceeded ({} + {} > {})", _n_past, tokens.size(), _ctx->size());
            callback("", Sentence::END);
            break;
        }

        size_t n_tok_t = 0;

        // Bifurcate based on embedding as input or token as input
        if (m_inputType == InputType::TOKENS)
            n_tok_t = engine.process(tokens, attention_map, logits, true /* all logits */);
        else if (m_inputType == InputType::EMBEDDINGS) {
            // Convert tokens to embedding for the processing in the engine.
            auto embedBufSize = engine.getEmbeddingBufferSize();
            std::vector<uint8_t> embedding;
            for(auto &token: tokens){
                std::vector<uint8_t> curTokenEmbedding(embedBufSize,0);
                m_t2eCallback(token, curTokenEmbedding.data(), embedBufSize);
                embedding.insert(embedding.end(), curTokenEmbedding.begin(), curTokenEmbedding.end());
            }
            n_tok_t = engine.process(embedding, attention_map, logits, true /* all logits */);
        } else {
            return Dialog::abort("No valid Input Type is used", callback);
        }
        if (n_tok_t != tokens.size()) return Dialog::abort("engine processing failed", callback);

        // Accept tokens
        auto [accepted_tokens, accepted_ids] = verify_and_select_longest(std::span{tokens.data(),tokens.size()},
                                                                         std::span{logits.data(),logits.size()});

        // Commit accepted tokens to kv-caches
        selected.resize(accepted_ids.back() + 1); // trim away rejected tokens
        std::fill(selected.begin(), selected.end(), false);
        for (auto id : accepted_ids)
            selected[id] = true;
        accepted_counts[accepted_tokens.size() - 1] += 1;
        _n_past += accepted_tokens.size();
        // NOTE(review): update_kv's failure result is ignored here — confirm
        // whether a failed KV commit should terminate the loop.
        update_kv(_n_past, selected);

        // Decode tokens
        std::for_each(accepted_tokens.begin(), accepted_tokens.end(), decode_token);

        // Prepare new tokens
        // Forecast chain of the deepest accepted node seeds the next tree.
        auto next_draft_offset = len_flat_sample_tree + accepted_ids.back() * _draft;
        std::iota(indices.begin(), indices.end(), next_draft_offset);
        tokens = build_sample_tree(accepted_tokens.back(), std::span{logits.data(),logits.size()}, indices);
    }

    State::busy(false);

    auto total_iteration = std::accumulate(accepted_counts.begin(), accepted_counts.end(), 0);
    auto accept_rate =
        float(_n_generated - 1) / total_iteration; // -1: exclude first generated token
    __KPIS("SSD{{draft:{}, branch:{}, greedy:{}}}: accepted counts: {}, accept rate = {} tokens/iteration",
           _draft,
           _branches,
           _t_sampler.greedy(),
           accepted_counts,
           accept_rate);

    return true;
}
|
| 469 |
+
|
| 470 |
+
// Multistream AR generation
// Multistream variant of the follow-on generation loop: batches one
// speculative tile (draft tree + forecast tokens) per active stream into a
// single engine pass, verifies each stream's tile independently, and retires
// streams as they hit EOS (their full text is emitted at that point).
//
// @param streams  in/out per-stream accepted token histories
// @param logits   in/out logits buffer from the initial batched inference
// @param callback receives each finished stream's decoded text
// @return true on normal completion, false via Dialog::abort on failure
bool SelfSpecDecDialog::processFollowOnGeneration(std::vector<std::vector<int32_t>>& streams, std::vector<float>& logits, Dialog::Callback callback) {

    // NOTE(review): `sampler` is unused in this function.
    auto& sampler = *_sampler["primary"];
    auto& engine = *_engine["primary"];

    // Commits the accepted token positions to the engine KV cache.
    auto update_kv = [&engine, &callback, this](size_t past, const std::vector<bool>& selected) {
        if (!engine.updateKV(past, selected))
            return Dialog::abort("context size exceeded", callback);
        return true;
    };

    // Active stream ids; finished streams are erased from this list.
    std::vector<size_t> streamIndices(streams.size());
    // past_map[j] = stream id owning the j-th committed (post-prompt) KV slot.
    std::vector<size_t> past_map(streams.size());

    std::iota(streamIndices.begin(), streamIndices.end(), 0);
    // Since the first inference is done separately, it is
    // expected that each stream already has 1 valid AR token.
    std::iota(past_map.begin(), past_map.end(), 0);

    // NOTE(review): keep_generating and context are unused in this function.
    bool keep_generating = true;
    const size_t context = _ctx->n_ctx();

    if (streams.size() == 0) {
        callback("\n", Sentence::END);
        return true;
    }

    // Prepare constant options for next inferences
    const auto len_flat_sample_tree = get_len_flat_sample_tree();
    const auto forecast_tokens = gen_forecast_tokens(len_flat_sample_tree);
    const auto attention_map = gen_attention_map();

    // Per-stream next sample tree to be verified on the next pass.
    std::vector<std::vector<int32_t>> draftStreams(streams.size());

    for (int i = 0; i < streams.size(); i++) {
        // prepare the next inference
        std::vector<int32_t> indices(_draft, 0);
        std::iota(indices.begin(), indices.end(), 1);
        // Stream i's logits tile starts at row i*(1+_draft) of the initial pass.
        draftStreams[i] = build_sample_tree(sample_to_verify(std::span{logits.data(),logits.size()}, i*(1+_draft)), std::span{logits.data(),logits.size()}, indices);
        streams[i].push_back(draftStreams[i][0]);

    }

    std::vector<int32_t> multi_attn_mask;

    // accepted_counts[k] = iterations that accepted exactly k+1 tokens (KPI).
    std::vector<int32_t> accepted_counts(_draft + 1, 0);

    engine.set({{"kv-prefix-offset", len_flat_sample_tree}});

    State::busy(true);

    while (true) {
        if (State::canceled()) break;

        // If this exceeds context length, truncate all streams and return
        if (_n_past + streamIndices.size() > _ctx->size()) {
            for (auto stream : streamIndices)
                callback(_tokenizer->decode(streams[stream]) + "\n", Sentence::CONTINUE);
            break;
        }

        // Accumulate input tokens from all streams
        std::vector<int32_t> multi_tokens;
        for (auto streamIdx : streamIndices) {
            multi_tokens.insert(multi_tokens.end(), draftStreams[streamIdx].begin(), draftStreams[streamIdx].end());
            multi_tokens.insert(multi_tokens.end(), forecast_tokens.begin(), forecast_tokens.end());
        }

        if (_n_past + multi_tokens.size() > _ctx->size()) {
            __WARN("Context limit exceeded ({} + {} > {})", _n_past, multi_tokens.size(), _ctx->size());
            callback("", Sentence::END);
            break;
        }

        // Block-diagonal mask: each stream's tile only sees its own past.
        tileAttentionMask(attention_map, streamIndices, past_map, len_flat_sample_tree, multi_attn_mask);

        size_t n_tok_t = 0;

        if (m_inputType == InputType::TOKENS) {
            // Process input tokens for all streams in one batch
            n_tok_t = engine.process(multi_tokens, multi_attn_mask, logits, true);
        } else if (m_inputType == InputType::EMBEDDINGS) {
            // Accumulate input embeddings from all streams
            auto embedBufSize = engine.getEmbeddingBufferSize();
            std::vector<uint8_t> multi_embeddings;

            convertTokensToEmbeddings(multi_tokens, multi_embeddings, embedBufSize, m_t2eCallback);

            // Process input tokens for all streams in one batch
            n_tok_t = engine.process(multi_embeddings, multi_attn_mask, logits, true);
        }
        if (n_tok_t != multi_tokens.size()) return Dialog::abort("engine processing failed", callback);

        // Concatenated per-tile acceptance flags for the batched KV commit.
        std::vector<bool> all_selected;

        // Process all logits independently
        std::span<float> logit_span = std::span{logits.data(),logits.size()};
        std::span<int32_t> token_span = std::span{multi_tokens.data(), multi_tokens.size()};
        for (int i = 0; i < streamIndices.size(); i++) {
            const size_t streamIdx = streamIndices[i];
            std::vector<int32_t>& stream = streams[streamIdx];

            // One tile = this stream's sample tree plus its forecast tokens.
            const size_t tileStride = draftStreams[streamIdx].size() + forecast_tokens.size();

            std::span<float> tiled_logits = logit_span.subspan(i * tileStride * _vocab, _vocab);

            // Accept tokens
            auto [accepted_tokens, accepted_ids] = verify_and_select_longest(token_span.subspan(i * tileStride, tileStride),
                                                                             tiled_logits);

            // Commit accepted tokens to kv-caches
            std::vector<bool> selected(tileStride, false);
            for (auto id : accepted_ids) {
                selected[id] = true;
                // Record which stream owns each newly committed KV slot.
                past_map.push_back(streamIdx);
            }
            all_selected.insert(all_selected.end(), selected.begin(), selected.end());
            accepted_counts[accepted_tokens.size() - 1] += 1;
            _n_past += accepted_tokens.size();

            // Decode tokens
            stream.insert(stream.end(), accepted_tokens.begin(), accepted_tokens.end());
            _n_generated += accepted_tokens.size();

            // Prepare new tokens
            std::vector<int32_t> indices(_draft, 0);
            auto next_draft_offset = len_flat_sample_tree + accepted_ids.back() * _draft;
            std::iota(indices.begin(), indices.end(), next_draft_offset);
            draftStreams[streamIdx] = build_sample_tree(accepted_tokens.back(), tiled_logits, indices);
        }

        // NOTE(review): update_kv's failure result is ignored here — confirm
        // whether a failed KV commit should terminate the loop.
        update_kv(_n_past, all_selected);
        // Retire streams that ended in EOS, emitting their full text.
        for (auto it = streamIndices.begin(); it != streamIndices.end();) {
            int32_t stream = *it;
            if (_ctx->is_eos(streams[stream].back())) {
                callback(_tokenizer->decode(streams[stream]) + "\n", Sentence::CONTINUE);
                it = streamIndices.erase(it);
            } else {
                ++it;
            }
        }

        if (streamIndices.size() == 0) break;
    }
    callback("\n", Sentence::END);

    State::busy(false);

    auto total_iteration = std::accumulate(accepted_counts.begin(), accepted_counts.end(), 0);
    auto accept_rate =
        float(_n_generated - 1) / total_iteration; // -1: exclude first generated token
    __KPIS("SSD{{draft:{}, branch:{}, greedy:{}}}: accepted counts: {}, accept rate = {} tokens/iteration",
           _draft,
           _branches,
           _t_sampler.greedy(),
           accepted_counts,
           accept_rate);

    return true;
}
|
| 631 |
+
|
| 632 |
+
// Handle prompt processing and generation will be done processFollowOnGeneration
// Pass t2e callback using setter and remove as an argument. call setter from the base query function of dialog

// Embedding-input entry point: validates input shape, caches the EOS
// embedding, processes the prompt, runs the initial SSD inference (with
// forecast tokens), then hands off to processFollowOnGeneration (single- or
// multi-stream depending on _n_streams).
//
// @param embedding   prompt embeddings; REUSED as scratch (cleared) later
// @param t2eCallback token-to-embedding converter, stored in m_t2eCallback
// @param callback    receives decoded output text
// @return true on completion (incl. early EOS), false on failure
bool SelfSpecDecDialog::process(std::vector<uint8_t>& embedding,
                                T2ECallback t2eCallback,
                                Dialog::Callback callback ){

    // Check for prev failures and bail out early
    if (State::failed()) return false;

    if(m_inputType != InputType::EMBEDDINGS) {
        __ERROR("Input type for model is not embeddings.");
        return false;
    }

    Timer start;
    State::clear();

    std::vector<float> logits;
    auto& engine = *_engine["primary"];

    // Commits the accepted token positions to the engine KV cache.
    auto update_kv = [&engine, &callback, this](size_t past, const std::vector<bool>& selected) {
        if (!engine.updateKV(past, selected))
            return Dialog::abort("context size exceeded", callback);
        return true;
    };

    // Store the t2e callback for reference during follow-on generation.
    m_t2eCallback = t2eCallback;

    auto embedBufSize = engine.getEmbeddingBufferSize();

    {
        // Cache the EOS token's embedding so the engine can recognize EOS.
        std::vector<uint8_t> eosEmbedding(embedBufSize, 0.0);
        if (m_t2eCallback) {
            m_t2eCallback(_ctx->eos(), eosEmbedding.data(), embedBufSize);
        }
        if (!engine.cacheEosEmbedding(eosEmbedding)) {
            __DEBUG("Failed to set the eos token embedding.");
            return false;
        }
    }

    using FF = Engine::Feature::Flags;
    if (engine.supports(FF::DYNAMIC_LOAD)) engine.load();

    _env->logger().post(Logger::KPIS, kpis().dump(" "));
    start.reset();

    engine.set({{"kv-prefix-skip", _forecast_prefix}});

    std::vector<int32_t> tokens(1,0);

    // Process prompt
    // get number of tokens in the input
    size_t curTokensCount = embedding.size()/embedBufSize;

    // Input must be a whole number of embedding-sized chunks.
    if(curTokensCount * embedBufSize != embedding.size()){
        size_t expectedLength = (curTokensCount + (embedding.size()%embedBufSize != 0))*embedBufSize;
        __DEBUG("Input is wrong expected {} and found {}.", expectedLength, embedding.size());
        return Dialog::abort("Input is not an multiple for the embedding Length", callback);
    }

    _n_prompt += curTokensCount;

    // Causal map: each token attends to its predecessor; -1 marks the root.
    std::vector<int32_t> attention_map(curTokensCount);
    std::iota(attention_map.begin(), attention_map.end(), -1);

    engine.set({{"kv-prefix-offset", curTokensCount}}); // Do not attend prefix

    if (_n_past + curTokensCount > _ctx->size()) {
        __WARN("Context limit exceeded ({} + {} > {})", _n_past, curTokensCount, _ctx->size());
        callback("", Sentence::END);
        return true;
    }

    if (!engine.process(embedding, attention_map, logits, false))
        return Dialog::abort("engine prompt processing failed", callback); // Change this message also to some generic message.
    _n_past += curTokensCount;
    // NOTE(review): update_kv's failure result is ignored here and below.
    update_kv(_n_past, {});

    bool status = true;
    if (_n_streams <= 1) {
        // --- Single-stream path ---
        tokens[0] = sample_to_verify(std::span{logits.data(),logits.size()}, 0);

        // Decode the first token.
        _last_tok = tokens[0];
        if (_ctx->is_eos(_last_tok)) {
            callback("", Sentence::END);
            return true;
        }

        if (!callback(_tokenizer->decode(tokens), Sentence::BEGIN)) return true;
        //decode_token(tokens[0]);

        // Without a T2E callback no further tokens can be fed back in.
        if (!m_t2eCallback) {
            callback("", Sentence::END);
            return true;
        }

        // Mark TTFT
        _kpis.prompt.update(start.elapsed_usec());
        start.reset();
        State::busy(true);

        // Initial inference for self-speculative decoding pipeline with forecast tokens and prefix
        // process separately because logits are required for these tokens
        for (int i = 0; i < _draft; ++i)
            tokens.push_back(_forecast_token_offset + i);

        attention_map.resize(tokens.size());
        std::iota(attention_map.begin(), attention_map.end(), -1);
        engine.set({{"kv-prefix-offset", 1}}); // Prevent the last token from attending

        if (_n_past + tokens.size() > _ctx->size()) {
            __WARN("Context limit exceeded ({} + {} > {})", _n_past, tokens.size(), _ctx->size());
            callback("", Sentence::END);
            return true;
        }

        // Convert tokens to embeddings
        // reset embedding vector to make space for the next runs
        embedding.clear();
        convertTokensToEmbeddings(tokens, embedding, embedBufSize, m_t2eCallback);

        if (!engine.process(embedding, attention_map, logits, true))
            return Dialog::abort("initial inference for SSD pipeline failed", callback);

        // Only the first (real) token is committed; forecast tokens are scratch.
        _n_past += 1;
        update_kv(_n_past, {});

        // Use existing as much as possible
        status = processFollowOnGeneration(tokens, logits, callback);
    } else {
        // --- Multistream path: branch into _n_streams top-k candidates ---
        std::vector<std::vector<int32_t>> streams;
        getTopK(logits, streams, _n_streams, _p_threshold, callback);

        if (!m_t2eCallback) {
            for (auto& stream : streams) {
                if (!callback(_tokenizer->decode(stream) + "\n", Sentence::BEGIN)) return true;
            }
            callback("", Sentence::END);
            return true;
        }

        // Mark TTFT
        _kpis.prompt.update(start.elapsed_usec());
        start.reset();
        State::busy(true);

        if (streams.size() == 0) {
            callback("\n", Sentence::END);
            return true;
        }

        // Initial inference for self-speculative decoding pipeline with forecast tokens and prefix
        // process separately because logits are required for these tokens
        attention_map.resize(1 + _draft);
        std::iota(attention_map.begin(), attention_map.end(), -1);

        std::vector<size_t> stream_indices(streams.size());
        std::iota(stream_indices.begin(), stream_indices.end(), 0);

        std::vector<int32_t> multi_attn_mask;
        std::vector<size_t> past_map;
        const size_t kvPrefixOffset = 1;

        tileAttentionMask(attention_map, stream_indices, past_map, kvPrefixOffset, multi_attn_mask);

        // Accumulate input tokens from all streams
        std::vector<int32_t> multi_tokens;

        // NOTE(review): reserve assumes each stream holds exactly one token
        // at this point — confirm against getTopK's contract.
        multi_tokens.reserve(streams.size() * (1 + _draft));
        for (int i = 0; i < streams.size(); i++) {
            multi_tokens.insert(multi_tokens.end(), streams[i].begin(), streams[i].end());
            // NOTE(review): inner `i` shadows the outer loop variable; harmless
            // (inner loop only reads _draft) but worth renaming.
            for (int i = 0; i < _draft; ++i) {
                multi_tokens.push_back(_forecast_token_offset + i);
            }
        }

        // Convert tokens to embeddings
        // reset embedding vector to make space for the next runs
        embedding.clear();
        convertTokensToEmbeddings(multi_tokens, embedding, embedBufSize, m_t2eCallback);

        if (_n_past + multi_tokens.size() > _ctx->size()) {
            __WARN("Context limit exceeded ({} + {} > {})", _n_past, multi_tokens.size(), _ctx->size());
            callback("", Sentence::END);
            return true;
        }

        if (!engine.process(embedding, multi_attn_mask, logits, true))
            return Dialog::abort("initial inference for SSD pipeline failed", callback);

        // Keep only each tile's first (real) token; forecast tokens are scratch.
        std::vector<bool> selected(multi_tokens.size(), false);
        for (int i = 0; i < multi_tokens.size(); i+=(_draft+1)) {
            selected[i] = true;
        }

        _n_past += streams.size();
        update_kv(_n_past, selected);

        status = processFollowOnGeneration(streams, logits, callback);
    }

    _kpis.generate.update(start.elapsed_usec());
    _env->logger().post(Logger::KPIS, kpis().dump(" "));
    start.reset();

    return status;
}
|
| 843 |
+
|
| 844 |
+
bool SelfSpecDecDialog::process(std::vector<int32_t>& tokens, Dialog::Callback callback) {
|
| 845 |
+
|
| 846 |
+
// Check for prev failures and bail out early
|
| 847 |
+
if (State::failed()) return false;
|
| 848 |
+
|
| 849 |
+
Timer start;
|
| 850 |
+
|
| 851 |
+
if(m_inputType != InputType::TOKENS) {
|
| 852 |
+
__ERROR("Input type for model is not tokens.");
|
| 853 |
+
return false;
|
| 854 |
+
}
|
| 855 |
+
|
| 856 |
+
State::clear();
|
| 857 |
+
|
| 858 |
+
std::vector<float> logits;
|
| 859 |
+
auto& engine = *_engine["primary"];
|
| 860 |
+
|
| 861 |
+
auto update_kv = [&engine, &callback, this](size_t past, const std::vector<bool>& selected) {
|
| 862 |
+
if (!engine.updateKV(past, selected))
|
| 863 |
+
return Dialog::abort("context size exceeded", callback);
|
| 864 |
+
return true;
|
| 865 |
+
};
|
| 866 |
+
|
| 867 |
+
using FF = Engine::Feature::Flags;
|
| 868 |
+
if (engine.supports(FF::DYNAMIC_LOAD)) engine.load();
|
| 869 |
+
|
| 870 |
+
_env->logger().post(Logger::KPIS, kpis().dump(" "));
|
| 871 |
+
start.reset();
|
| 872 |
+
|
| 873 |
+
engine.set({{"kv-prefix-skip", _forecast_prefix}});
|
| 874 |
+
|
| 875 |
+
std::vector<int32_t> attention_map(tokens.size());
|
| 876 |
+
std::iota(attention_map.begin(), attention_map.end(), -1);
|
| 877 |
+
|
| 878 |
+
// Process prompt
|
| 879 |
+
_n_prompt += tokens.size();
|
| 880 |
+
engine.set({{"kv-prefix-offset", tokens.size()}}); // Do not attend prefix
|
| 881 |
+
|
| 882 |
+
if (_n_past + tokens.size() > _ctx->size()) {
|
| 883 |
+
__WARN("Context limit exceeded ({} + {} > {})", _n_past, tokens.size(), _ctx->size());
|
| 884 |
+
callback("", Sentence::END);
|
| 885 |
+
return true;
|
| 886 |
+
}
|
| 887 |
+
|
| 888 |
+
if (!engine.process(tokens, attention_map, logits, false))
|
| 889 |
+
return Dialog::abort("engine prompt processing failed", callback);
|
| 890 |
+
_n_past += tokens.size();
|
| 891 |
+
update_kv(_n_past, {});
|
| 892 |
+
|
| 893 |
+
bool status = true;
|
| 894 |
+
if (_n_streams <= 1) {
|
| 895 |
+
tokens[0] = sample_to_verify(std::span{logits.data(),logits.size()}, 0);
|
| 896 |
+
tokens.resize(1);
|
| 897 |
+
|
| 898 |
+
// Decode the first token.
|
| 899 |
+
_last_tok = tokens[0];
|
| 900 |
+
if (_ctx->is_eos(_last_tok)) {
|
| 901 |
+
callback("", Sentence::END);
|
| 902 |
+
return true;
|
| 903 |
+
}
|
| 904 |
+
|
| 905 |
+
if (!callback(_tokenizer->decode(tokens), Sentence::BEGIN)) return true;
|
| 906 |
+
// decode_token(tokens[0]);
|
| 907 |
+
|
| 908 |
+
// Mark TTFT
|
| 909 |
+
_kpis.prompt.update(start.elapsed_usec());
|
| 910 |
+
start.reset();
|
| 911 |
+
State::busy(true);
|
| 912 |
+
|
| 913 |
+
// Initial inference for self-speculative decoding pipeline with forecast tokens and prefix
|
| 914 |
+
// process separately because logits are required for these tokens
|
| 915 |
+
for (int i = 0; i < _draft; ++i)
|
| 916 |
+
tokens.push_back(_forecast_token_offset + i);
|
| 917 |
+
|
| 918 |
+
attention_map.resize(tokens.size());
|
| 919 |
+
std::iota(attention_map.begin(), attention_map.end(), -1);
|
| 920 |
+
engine.set({{"kv-prefix-offset", 1}}); // Prevent the last token from attending
|
| 921 |
+
|
| 922 |
+
if (_n_past + tokens.size() > _ctx->size()) {
|
| 923 |
+
__WARN("Context limit exceeded ({} + {} > {})", _n_past, tokens.size(), _ctx->size());
|
| 924 |
+
callback("", Sentence::END);
|
| 925 |
+
return true;
|
| 926 |
+
}
|
| 927 |
+
|
| 928 |
+
if (!engine.process(tokens, attention_map, logits, true))
|
| 929 |
+
return Dialog::abort("initial inference for SSD pipeline failed", callback);
|
| 930 |
+
|
| 931 |
+
_n_past += 1;
|
| 932 |
+
update_kv(_n_past, {});
|
| 933 |
+
|
| 934 |
+
status = processFollowOnGeneration(tokens, logits, callback);
|
| 935 |
+
} else {
|
| 936 |
+
std::vector<std::vector<int32_t>> streams;
|
| 937 |
+
getTopK(logits, streams, _n_streams, _p_threshold, callback);
|
| 938 |
+
|
| 939 |
+
// Mark TTFT
|
| 940 |
+
_kpis.prompt.update(start.elapsed_usec());
|
| 941 |
+
start.reset();
|
| 942 |
+
State::busy(true);
|
| 943 |
+
|
| 944 |
+
if (streams.size() == 0) {
|
| 945 |
+
callback("\n", Sentence::END);
|
| 946 |
+
return true;
|
| 947 |
+
}
|
| 948 |
+
|
| 949 |
+
// Initial inference for self-speculative decoding pipeline with forecast tokens and prefix
|
| 950 |
+
// process separately because logits are required for these tokens
|
| 951 |
+
attention_map.resize(1 + _draft);
|
| 952 |
+
std::iota(attention_map.begin(), attention_map.end(), -1);
|
| 953 |
+
|
| 954 |
+
std::vector<size_t> stream_indices(streams.size());
|
| 955 |
+
std::iota(stream_indices.begin(), stream_indices.end(), 0);
|
| 956 |
+
|
| 957 |
+
std::vector<int32_t> multi_attn_mask;
|
| 958 |
+
std::vector<size_t> past_map;
|
| 959 |
+
const size_t kvPrefixOffset = 1;
|
| 960 |
+
|
| 961 |
+
tileAttentionMask(attention_map, stream_indices, past_map, kvPrefixOffset, multi_attn_mask);
|
| 962 |
+
|
| 963 |
+
// Accumulate input tokens from all streams
|
| 964 |
+
std::vector<int32_t> multi_tokens;
|
| 965 |
+
|
| 966 |
+
multi_tokens.reserve(streams.size() * (1 + _draft));
|
| 967 |
+
for (int i = 0; i < streams.size(); i++) {
|
| 968 |
+
multi_tokens.insert(multi_tokens.end(), streams[i].begin(), streams[i].end());
|
| 969 |
+
for (int i = 0; i < _draft; ++i) {
|
| 970 |
+
multi_tokens.push_back(_forecast_token_offset + i);
|
| 971 |
+
}
|
| 972 |
+
}
|
| 973 |
+
|
| 974 |
+
if (_n_past + multi_tokens.size() > _ctx->size()) {
|
| 975 |
+
__WARN("Context limit exceeded ({} + {} > {})", _n_past, multi_tokens.size(), _ctx->size());
|
| 976 |
+
callback("", Sentence::END);
|
| 977 |
+
return true;
|
| 978 |
+
}
|
| 979 |
+
|
| 980 |
+
if (!engine.process(multi_tokens, multi_attn_mask, logits, true))
|
| 981 |
+
return Dialog::abort("initial inference for SSD pipeline failed", callback);
|
| 982 |
+
|
| 983 |
+
std::vector<bool> selected(multi_tokens.size(), false);
|
| 984 |
+
for (int i = 0; i < multi_tokens.size(); i+=(_draft+1)) {
|
| 985 |
+
selected[i] = true;
|
| 986 |
+
}
|
| 987 |
+
|
| 988 |
+
_n_past += streams.size();
|
| 989 |
+
update_kv(_n_past, selected);
|
| 990 |
+
|
| 991 |
+
status = processFollowOnGeneration(streams, logits, callback);
|
| 992 |
+
}
|
| 993 |
+
|
| 994 |
+
_kpis.generate.update(start.elapsed_usec());
|
| 995 |
+
_env->logger().post(Logger::KPIS, kpis().dump(" "));
|
| 996 |
+
start.reset();
|
| 997 |
+
|
| 998 |
+
return status;
|
| 999 |
+
}
|
| 1000 |
+
|
| 1001 |
+
void SelfSpecDecDialog::reset() {
|
| 1002 |
+
Dialog::reset();
|
| 1003 |
+
_n_past = _forecast_prefix;
|
| 1004 |
+
size_t n_restored_prefix = _engine["primary"]->restore(_kv_prefix_name);
|
| 1005 |
+
if (n_restored_prefix != _forecast_prefix) {
|
| 1006 |
+
// clang-format off
|
| 1007 |
+
throw std::runtime_error( fmt::format( "SSD : Loaded {} KV$ from {} but expected {} KV$",
|
| 1008 |
+
n_restored_prefix, _kv_prefix_name, _forecast_prefix ) );
|
| 1009 |
+
// clang-format on
|
| 1010 |
+
}
|
| 1011 |
+
}
|
| 1012 |
+
|
| 1013 |
+
bool SelfSpecDecDialog::save(const std::string& name) {
|
| 1014 |
+
if (_n_streams > 1) {
|
| 1015 |
+
throw std::runtime_error("Save is unsupported for multistream dialogs.");
|
| 1016 |
+
}
|
| 1017 |
+
return Dialog::save(name);
|
| 1018 |
+
}
|
| 1019 |
+
|
| 1020 |
+
bool SelfSpecDecDialog::restore(const std::string& name) {
|
| 1021 |
+
if (_n_streams > 1) {
|
| 1022 |
+
throw std::runtime_error("Restore is unsupported for multistream dialogs.");
|
| 1023 |
+
}
|
| 1024 |
+
return Dialog::restore(name);
|
| 1025 |
+
}
|
| 1026 |
+
|
| 1027 |
+
// Registrator instance: makes the "ssd-q1" dialog type constructible through
// Dialog::create() as soon as this translation unit's statics are initialized.
static OnLoad regy([]() {
    Dialog::__register(
        "ssd-q1",
        [](std::shared_ptr<Env> env, const std::string& name, const json& conf) {
            return (Dialog*)new SelfSpecDecDialog(env, name, conf);
        }
    );
});

// Register ssd sampler for compatibility
static OnLoad sampler_regy([]() {
    Sampler::__register("basic", [](Context& ctx, const json& conf) {
        return (Sampler*)new BasicSampler(ctx, conf);
    });
});

// No-op hook: referencing this symbol from another translation unit forces the
// linker to keep this file (and its static registrations above) in the binary.
void needSsdDialog() {}
|
| 1045 |
+
|
| 1046 |
+
} // namespace qualla
|
Genie/Genie/src/qualla/embedding.cpp
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//==============================================================================
|
| 2 |
+
//
|
| 3 |
+
// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
|
| 4 |
+
// All Rights Reserved.
|
| 5 |
+
// Confidential and Proprietary - Qualcomm Technologies, Inc.
|
| 6 |
+
//
|
| 7 |
+
//==============================================================================
|
| 8 |
+
|
| 9 |
+
#include <qualla/embedding.hpp>
|
| 10 |
+
#include <qualla/logger.hpp>
|
| 11 |
+
#include <qualla/detail/config.hpp>
|
| 12 |
+
#include <qualla/detail/timer.hpp>
|
| 13 |
+
|
| 14 |
+
#include <functional>
|
| 15 |
+
#include <fstream>
|
| 16 |
+
#include <string>
|
| 17 |
+
#include <unordered_map>
|
| 18 |
+
#include <filesystem>
|
| 19 |
+
|
| 20 |
+
#include <fmt/format.h>
|
| 21 |
+
#include <fmt/ranges.h>
|
| 22 |
+
|
| 23 |
+
namespace fs = std::filesystem;
|
| 24 |
+
|
| 25 |
+
namespace qualla {
|
| 26 |
+
|
| 27 |
+
// Construct an Embedding pipeline from a JSON configuration: context,
// tokenizer and engine are created in that order. Throws if a mandatory
// config key is missing or the engine cannot emit embeddings.
//
// @param env   shared runtime environment (logging, model paths)
// @param name  instance name, forwarded to the context
// @param json  configuration (keys: "prompt", "context", "tokenizer",
//              "engine", "truncate-input")
Embedding::Embedding(std::shared_ptr<Env> env, const std::string& name, const qualla::json& json)
    : _name(name), _env(env) {
    Timer start;

    _env->logger().debug(fmt::format("embedding-new: {} config {}", name, json.dump()));

    using qc = qualla::Config;

    // Parse prompt config: _tags[0]/_tags[1] wrap every query string.
    const qualla::json& pmt_conf = qc::optional<qualla::json>(json, "prompt", {});
    _tags = qc::optional<std::vector<std::string>>(pmt_conf, "tags", {"", ""});

    // Create the context first
    _ctx = Context::create(*_env, name, qc::optional<qualla::json>(json, "context", {}));

    // Create Tokenizer ("tokenizer" is a path relative to the models directory)
    fs::path tok_path = _env->path().models / qc::mandatory<std::string>(json, "tokenizer");
    _tokenizer = Tokenizer::create(*_ctx, tok_path);

    // Create Engine
    const qualla::json& eng_conf = qc::mandatory<qualla::json>(json, "engine");
    _engine = Engine::create(*_ctx, eng_conf);

    // Truncation of input to context (default false: oversized input is rejected)
    _input_truncation = qc::optional<qualla::json>(json, "truncate-input", false);

    using FF = Engine::Feature::Flags;
    if (!_engine->supports(FF::OUTPUT_EMBEDDINGS))
        throw std::runtime_error("engine must output embeddings");

    _kpis.init.update(start.elapsed_usec());
}

Embedding::~Embedding() {}
|
| 61 |
+
|
| 62 |
+
bool Embedding::process(std::vector<int32_t>& tokens, std::vector<float>& output) {
|
| 63 |
+
Timer start;
|
| 64 |
+
|
| 65 |
+
State::clear();
|
| 66 |
+
|
| 67 |
+
size_t n = _engine->process(tokens, output, false);
|
| 68 |
+
if (!n) {
|
| 69 |
+
State::error("engine prompt processing failed");
|
| 70 |
+
return false;
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
_n_prompt += tokens.size();
|
| 74 |
+
|
| 75 |
+
// Clean the buffer before using
|
| 76 |
+
_output_dimensions.clear();
|
| 77 |
+
|
| 78 |
+
uint64_t output_size = 1;
|
| 79 |
+
// push number of tokens present in the result.
|
| 80 |
+
_output_dimensions.push_back(n);
|
| 81 |
+
// push back the dimension of the each embedding
|
| 82 |
+
_output_dimensions.push_back(_ctx->n_embd());
|
| 83 |
+
|
| 84 |
+
output_size = n * _ctx->n_embd();
|
| 85 |
+
|
| 86 |
+
output.resize(output_size);
|
| 87 |
+
|
| 88 |
+
_kpis.prompt.update(start.elapsed_usec());
|
| 89 |
+
|
| 90 |
+
// Log latest KPIs in a single line
|
| 91 |
+
_env->logger().post(Logger::KPIS, kpis().dump(" "));
|
| 92 |
+
|
| 93 |
+
return true;
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
bool Embedding::query(const std::string& str, std::vector<float>& output) {
|
| 97 |
+
std::string p_str; // prompt string
|
| 98 |
+
std::vector<int32_t> p_vec; // prompt tokens
|
| 99 |
+
|
| 100 |
+
p_vec.reserve(_ctx->n_ctx());
|
| 101 |
+
|
| 102 |
+
p_str = _tags[0] + str + _tags[1];
|
| 103 |
+
|
| 104 |
+
_env->logger().debug(fmt::format("embedding-query: {}", str));
|
| 105 |
+
_env->logger().debug(fmt::format("embedding-prompt: {}", p_str));
|
| 106 |
+
|
| 107 |
+
_n_queries++;
|
| 108 |
+
|
| 109 |
+
_tokenizer->encode(p_str, p_vec);
|
| 110 |
+
|
| 111 |
+
_env->logger().debug(fmt::format("embedding-tokens: {}", p_vec));
|
| 112 |
+
|
| 113 |
+
if(p_vec.size() > (_ctx->n_ctx())){ // Condition to not allow input to exceed context.
|
| 114 |
+
if(_input_truncation == false){
|
| 115 |
+
throw std::runtime_error("Input exceeds the context of the model.");
|
| 116 |
+
}
|
| 117 |
+
else{
|
| 118 |
+
p_vec.resize(_ctx->n_ctx());
|
| 119 |
+
}
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
return process(p_vec, output);
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
// Embedding KPIs helpers

// Copy the shape of the most recent process() result into `outputDimensions`:
// [number of token embeddings, embedding dimension].
void Embedding::output_dimensions(std::vector<std::uint32_t>& outputDimensions){
    outputDimensions = _output_dimensions;
}

// Get latest KPIs (refreshes the derived tokens-per-second figure first).
Embedding::KPIs& Embedding::kpis() {
    // Update TPS
    if (_n_prompt) {
        float t = _kpis.prompt.total_usec / _n_prompt;  // avg usec per prompt token
        _kpis.tps.prompt = 1000000.0 / (t ? t : 1000000.0);  // t==0 guard avoids div-by-zero
    }

    // We could synthesize more KPIs from other layers (engine, sampler, etc)
    return _kpis;
}

// Render all embedding KPIs on one line, fields joined by `sep`.
std::string Embedding::KPIs::dump(std::string_view sep) const {
    return fmt::format(
        "init:[{}]{}prompt:[{}]{} tps-prompt:{:.2f}",
        init.dump(),
        sep,
        prompt.dump(),
        sep,
        tps.prompt
    );
}

// Zero all embedding KPIs.
void Embedding::KPIs::reset() {
    init.reset();
    prompt.reset();
    tps.prompt = 0.0;
}
|
| 160 |
+
|
| 161 |
+
// Create API

// Build an Embedding from an already-parsed JSON configuration.
std::unique_ptr<Embedding> Embedding::create(
    std::shared_ptr<Env> env,
    const std::string& name,
    const qualla::json& conf
) {
    return std::make_unique<Embedding>(env, name, conf);
}

// Build an Embedding by parsing JSON configuration from a stream.
std::unique_ptr<Embedding> Embedding::create(
    std::shared_ptr<Env> env,
    const std::string& name,
    std::istream& json_stream
) {
    return create(env, name, json::parse(json_stream));
}

// Build an Embedding from a JSON configuration file on disk.
// Throws if the file does not exist.
std::unique_ptr<Embedding> Embedding::create(
    std::shared_ptr<Env> env,
    const std::string& name,
    const fs::path& json_path
) {
    if (!fs::exists(json_path))
        throw std::runtime_error(json_path.string() + ": file does not exist");
    std::ifstream ifs(json_path);
    return create(env, name, ifs);
}
|
| 189 |
+
|
| 190 |
+
} // namespace qualla
|
Genie/Genie/src/qualla/engine.cpp
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//==============================================================================
|
| 2 |
+
//
|
| 3 |
+
// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
|
| 4 |
+
// All Rights Reserved.
|
| 5 |
+
// Confidential and Proprietary - Qualcomm Technologies, Inc.
|
| 6 |
+
//
|
| 7 |
+
//==============================================================================
|
| 8 |
+
|
| 9 |
+
#include <qualla/engine.hpp>
|
| 10 |
+
#include <qualla/detail/kpi.hpp>
|
| 11 |
+
#include <qualla/detail/config.hpp>
|
| 12 |
+
|
| 13 |
+
#include <functional>
|
| 14 |
+
#include <iostream>
|
| 15 |
+
#include <sstream>
|
| 16 |
+
#include <string>
|
| 17 |
+
#include <unordered_map>
|
| 18 |
+
|
| 19 |
+
#include <fmt/format.h>
|
| 20 |
+
#include <fmt/ranges.h>
|
| 21 |
+
|
| 22 |
+
namespace qualla {
|
| 23 |
+
|
| 24 |
+
// Base-class constructor: records the engine type, owning context and
// environment, and reads the optional "role" config key (default "primary").
Engine::Engine(Context& ctx, const std::string& type, const qualla::json& conf)
    : _type(type), _ctx(ctx), _env(ctx.env()) {
    _env.logger().debug(
        fmt::format("engine-new: {} ctx {} config {}", type, _ctx.name(), conf.dump())
    );

    using qc = qualla::Config;
    _role = qc::optional<std::string>(conf, "role", "primary");
}

Engine::~Engine() {}
|
| 35 |
+
|
| 36 |
+
// Default implementation for engines without attention-map support:
// logs an error and reports zero tokens processed.
size_t Engine::process(
    const std::vector<int32_t>& tokens,
    const std::vector<int32_t>& attention_map,
    std::vector<float>& output,
    bool output_all
) {
    _env.logger().error(fmt::format("{}-engine does not support attention_map", _type));
    return 0;
}

// Convenience overload for callers that do not need the logits.
size_t Engine::process(const std::vector<int32_t>& tokens) {
    // Derived engines should overwrite this to avoid copying logits
    std::vector<float> logits;
    return process(tokens, logits);
}

// Default implementation for engines that cannot take raw embedding bytes
// as input: logs an error and reports zero tokens processed.
size_t Engine::process(
    std::vector<uint8_t>& embeddings,
    const std::vector<int32_t>& attention_map,
    std::vector<float>& output,
    bool output_all
) {
    _env.logger().error(fmt::format("{}-engine does not support embedding as input", _type));
    return 0;
}
|
| 61 |
+
|
| 62 |
+
bool Engine::updateKV(size_t n_past) {
|
| 63 |
+
_env.logger().error(fmt::format("{}-engine does not support sync", _type));
|
| 64 |
+
return false;
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
bool Engine::updateKV(size_t n_past, const std::vector<bool>& selected) {
|
| 68 |
+
_env.logger().error(fmt::format("{}-engine does not support sync with selected", _type));
|
| 69 |
+
return false;
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
size_t Engine::restore(const std::string& name) {
|
| 73 |
+
_env.logger().error(fmt::format("{}-engine does not support restore", _type));
|
| 74 |
+
return 0;
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
bool Engine::save(const std::string& name) {
|
| 78 |
+
_env.logger().error(fmt::format("{}-engine does not support save", _type));
|
| 79 |
+
return false;
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
void Engine::reset() {
|
| 83 |
+
_env.logger().error(fmt::format("{}-engine does not support reset", _type));
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
bool Engine::load() {
|
| 87 |
+
_env.logger().error(fmt::format("{}-engine does not support dynamic load", _type));
|
| 88 |
+
return 0;
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
bool Engine::unload() {
|
| 92 |
+
_env.logger().error(fmt::format("{}-engine does not support dynamic unload", _type));
|
| 93 |
+
return false;
|
| 94 |
+
}
|
| 95 |
+
|
| 96 |
+
bool Engine::set(qualla::json data) {
|
| 97 |
+
_env.logger().error(fmt::format("{}-engine does not support set()", _type));
|
| 98 |
+
return false;
|
| 99 |
+
}
|
| 100 |
+
|
| 101 |
+
qualla::json Engine::get() {
|
| 102 |
+
_env.logger().error(fmt::format("{}-engine does not support get()", _type));
|
| 103 |
+
return false;
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
bool Engine::cacheEosEmbedding(std::vector<uint8_t>& eosEmbedding) {
|
| 107 |
+
_env.logger().error(fmt::format("{}-engine does not support cache eos embedding", _type));
|
| 108 |
+
return true;
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
size_t Engine::getEmbeddingBufferSize() {
|
| 112 |
+
_env.logger().error(fmt::format("{}-engine does not support embedding vectors", _type));
|
| 113 |
+
return 0;
|
| 114 |
+
}
|
| 115 |
+
|
| 116 |
+
qualla::InputType Engine::getInputType(){
|
| 117 |
+
return qualla::InputType::TOKENS;
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
// Engine KPIs

// Render all engine timing KPIs on one line, fields joined by `sep`.
std::string Engine::KPIs::dump(std::string_view sep) const {
    return fmt::format(
        "load:[{}]{}process:[{}]{}update-kv:[{}]{}unload:[{}]",
        load.dump(),
        sep,
        process.dump(),
        sep,
        update_kv.dump(),
        sep,
        unload.dump()
    );
}

// Zero all engine timing KPIs.
void Engine::KPIs::reset() {
    load.reset();
    process.reset();
    update_kv.reset();
    unload.reset();
}
|
| 141 |
+
|
| 142 |
+
// Engine registry type string + creator function
using Registry = std::unordered_map<std::string, Engine::Creator>;
// Lazily allocated on first registration: engines register themselves from
// static OnLoad hooks, so this must not depend on static-init order.
static std::unique_ptr<Registry> registry;

// Register a creator under `type`, overwriting any previous registration.
void Engine::__register(const std::string& type, Creator func) {
    if (!registry) registry = std::make_unique<Registry>();

    Registry& r = *registry;
    r[type] = func;
}
|
| 152 |
+
|
| 153 |
+
// Instantiate the engine type named by conf["type"] via the registry.
// Throws "<type>: engine not found" if nothing was registered under it.
std::unique_ptr<Engine> Engine::create(Context& ctx, const qualla::json& conf) {
    using qc = qualla::Config;

    const std::string type = qc::mandatory<std::string>(conf, "type");

    if (!registry) throw std::runtime_error(type + ": engine not found");

    // Single find() instead of the original contains() + operator[] pair,
    // which performed two hash lookups for the same key.
    auto it = registry->find(type);
    if (it == registry->end()) throw std::runtime_error(type + ": engine not found");

    return std::unique_ptr<Engine>(it->second(ctx, conf));
}
|
| 169 |
+
|
| 170 |
+
// Convenience overload: parse the engine configuration from a stream.
std::unique_ptr<Engine> Engine::create(Context& ctx, std::istream& json_stream) {
    return create(ctx, json::parse(json_stream));
}

// Convenience overload: parse the engine configuration from a JSON string.
std::unique_ptr<Engine> Engine::create(Context& ctx, const std::string& json_str) {
    return create(ctx, json::parse(json_str));
}
|
| 177 |
+
|
| 178 |
+
// Return the names of all registered engine types (empty if none registered).
std::vector<std::string> Engine::list() {
    std::vector<std::string> v;
    if (!registry) return v;

    Registry& r = *registry;

    // Fix: iterate by const reference — the original `for (auto k : r)`
    // copied every {string, Creator} pair. Also reserve up front.
    v.reserve(r.size());
    for (const auto& entry : r)
        v.push_back(entry.first);
    return v;
}
|
| 188 |
+
|
| 189 |
+
// Default stub: engines without LoRA support log an error and fail.
// (Signature matches the header; the by-value string cannot change here.)
bool Engine::applyLoraAdapter(std::string lora_adapter_name) {
    _env.logger().error(fmt::format("{}-engine does not support LoraAdapter", _type));
    return false;
}
// Default stub: engines without LoRA support log an error and fail.
bool Engine::applyLoraStrength(std::string tensor_name, float tensor_val) {
    _env.logger().error(fmt::format("{}-engine does not support setLoraStrength", _type));
    return false;
}
|
| 197 |
+
|
| 198 |
+
} // namespace qualla
|
Genie/Genie/src/qualla/engines/lib.cpp
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//==============================================================================
|
| 2 |
+
//
|
| 3 |
+
// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
|
| 4 |
+
// All Rights Reserved.
|
| 5 |
+
// Confidential and Proprietary - Qualcomm Technologies, Inc.
|
| 6 |
+
//
|
| 7 |
+
//==============================================================================
|
| 8 |
+
|
| 9 |
+
// Just a stub for building qualla::engines when no built-in engines are enabled
|
Genie/Genie/src/qualla/engines/qnn-api/BackendExtensions.cpp
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//==============================================================================
|
| 2 |
+
//
|
| 3 |
+
// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
|
| 4 |
+
// All Rights Reserved.
|
| 5 |
+
// Confidential and Proprietary - Qualcomm Technologies, Inc.
|
| 6 |
+
//
|
| 7 |
+
//==============================================================================
|
| 8 |
+
|
| 9 |
+
#include "dlwrap.hpp"
|
| 10 |
+
#include "BackendExtensions.hpp"
|
| 11 |
+
#include "NetRunBackend.hpp"
|
| 12 |
+
|
| 13 |
+
// Capture the extension-library paths and collaborator handles; all real work
// (dlopen, factory symbol lookup, interface creation) is deferred to
// initialize().
BackendExtensions::BackendExtensions(
    BackendExtensionsConfigs backendExtensionsConfig,
    void* backendLibHandle,
    PerfProfile perfProfile,
    std::shared_ptr<ICommandLineManager> clManager,
    bool debug_qnn
)
    : m_backendExtensionsLibPath(backendExtensionsConfig.sharedLibraryPath),
      m_backendExtensionsConfigPath(backendExtensionsConfig.configFilePath),
      m_backendInterface(nullptr), m_isNetRunBackendInterface(false),
      m_createBackendInterfaceFn(nullptr), m_destroyBackendInterfaceFn(nullptr),
      m_backendLibHandle(backendLibHandle), m_perfProfile(perfProfile), m_clManager(clManager),
      m_debugQnn(debug_qnn) {
    (void)m_perfProfile;  // silence unused-member warning on configs that ignore it
}
|
| 28 |
+
|
| 29 |
+
// Tear down the backend interface. A locally-created NetRun backend is
// deleted directly; an interface obtained from the extensions library goes
// back through that library's destroy entry point.
BackendExtensions::~BackendExtensions() {
    if (m_backendInterface == nullptr) return;

    if (m_isNetRunBackendInterface) {
        QNN_DEBUG("Deleting NetRun Backend Interface");
        delete m_backendInterface;
        return;
    }

    if (m_destroyBackendInterfaceFn != nullptr) {
        QNN_DEBUG("Destroying Backend Interface");
        m_destroyBackendInterfaceFn(m_backendInterface);
    }
}
|
| 42 |
+
|
| 43 |
+
bool BackendExtensions::loadFunctionPointers() {
|
| 44 |
+
|
| 45 |
+
void* libHandle = dlopen(m_backendExtensionsLibPath.c_str(), RTLD_NOW | RTLD_LOCAL);
|
| 46 |
+
if (nullptr == libHandle) {
|
| 47 |
+
QNN_ERROR(
|
| 48 |
+
"Unable to load backend extensions lib: [%s]. dlerror(): [%s]",
|
| 49 |
+
m_backendExtensionsLibPath.c_str(),
|
| 50 |
+
dlerror()
|
| 51 |
+
);
|
| 52 |
+
return false;
|
| 53 |
+
}
|
| 54 |
+
m_createBackendInterfaceFn =
|
| 55 |
+
(CreateBackendInterfaceFnType_t)dlsym(libHandle, "createBackendInterface");
|
| 56 |
+
m_destroyBackendInterfaceFn =
|
| 57 |
+
(DestroyBackendInterfaceFnType_t)dlsym(libHandle, "destroyBackendInterface");
|
| 58 |
+
if (nullptr == m_createBackendInterfaceFn || nullptr == m_destroyBackendInterfaceFn) {
|
| 59 |
+
QNN_ERROR("Unable to find symbols. dlerror(): [%s]", dlerror());
|
| 60 |
+
return false;
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
return true;
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
void BackendExtensions::qnnLogCallback(
|
| 67 |
+
const char* fmt,
|
| 68 |
+
QnnLog_Level_t level,
|
| 69 |
+
uint64_t timestamp,
|
| 70 |
+
va_list args
|
| 71 |
+
) {
|
| 72 |
+
char buffer[1024] = "";
|
| 73 |
+
const char* levelStr = "";
|
| 74 |
+
switch (level) {
|
| 75 |
+
case QNN_LOG_LEVEL_ERROR:
|
| 76 |
+
levelStr = " ERROR ";
|
| 77 |
+
break;
|
| 78 |
+
case QNN_LOG_LEVEL_WARN:
|
| 79 |
+
levelStr = "WARNING";
|
| 80 |
+
break;
|
| 81 |
+
case QNN_LOG_LEVEL_INFO:
|
| 82 |
+
levelStr = " INFO ";
|
| 83 |
+
break;
|
| 84 |
+
case QNN_LOG_LEVEL_DEBUG:
|
| 85 |
+
levelStr = " DEBUG ";
|
| 86 |
+
break;
|
| 87 |
+
case QNN_LOG_LEVEL_VERBOSE:
|
| 88 |
+
levelStr = "VERBOSE";
|
| 89 |
+
break;
|
| 90 |
+
case QNN_LOG_LEVEL_MAX:
|
| 91 |
+
levelStr = "UNKNOWN";
|
| 92 |
+
break;
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
int pos = snprintf(
|
| 96 |
+
buffer, sizeof(buffer), "QNN: [%s] time=%lu:", levelStr, (unsigned long)timestamp
|
| 97 |
+
);
|
| 98 |
+
vsnprintf(buffer + pos, sizeof(buffer) - pos, fmt, args);
|
| 99 |
+
printf("%s", buffer);
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
// Bring up the backend-extensions interface:
//  1. With no extensions lib/config supplied, fall back to the built-in
//     NetRunBackend stub.
//  2. Otherwise dlopen the lib (via loadFunctionPointers) and create the
//     interface through its factory symbol.
//  3. Optionally enable verbose QNN logging, initialize the interface,
//     apply the perf profile (non-fatal on failure), load its config file,
//     and forward any command-line arguments.
// Returns false on any fatal failure.
bool BackendExtensions::initialize() {

    QNN_DEBUG("DEBUG: m_backendExtensionsLibPath=%s\n", m_backendExtensionsLibPath.c_str());
    QNN_DEBUG("DEBUG: m_backendExtensionsConfigPath=%s\n", m_backendExtensionsConfigPath.c_str());
    if (m_backendExtensionsLibPath.empty() && m_backendExtensionsConfigPath.empty()) {
        QNN_WARN("No BackendExtensions lib provided; initializing NetRunBackend Interface");
        m_isNetRunBackendInterface = true;
        m_backendInterface = new NetRunBackend();  // owned; deleted in the destructor
    } else {
        QNN_DEBUG("Loading supplied backend extensions lib.");
        QNN_DEBUG("Backend extensions lib path: %s", m_backendExtensionsLibPath.c_str());
        if (m_backendExtensionsConfigPath.empty()) {
            QNN_DEBUG("Backend extensions lib specified without a config file.");
        } else {
            QNN_DEBUG("Backend extensions config path: %s", m_backendExtensionsConfigPath.c_str());
        }
        if (!loadFunctionPointers()) {
            QNN_ERROR("Failed to load function pointers.");
            return false;
        }
        if (nullptr != m_createBackendInterfaceFn) {
            // Released through m_destroyBackendInterfaceFn in the destructor.
            m_backendInterface = m_createBackendInterfaceFn();
        }
    }
    if (nullptr == m_backendInterface) {
        QNN_ERROR("Unable to load backend extensions interface.");
        return false;
    }
    if (m_debugQnn) {
        // Best-effort: failure to enable verbose logging is not fatal.
        if (!(m_backendInterface->setupLogging(BackendExtensions::qnnLogCallback, QNN_LOG_LEVEL_VERBOSE))) {
            QNN_WARN("Unable to initialize logging in backend extensions.");
        }
    }
    if (!m_backendInterface->initialize(m_backendLibHandle)) {
        QNN_ERROR("Unable to initialize backend extensions interface.");
        return false;
    }
    if (!m_backendInterface->setPerfProfile(m_perfProfile)) {
        // Deliberately non-fatal: a backend may not honor perf profiles.
        QNN_WARN("Unable to set perf profile in backend extensions interface.");
        //return false;
    }
    if (!m_backendInterface->loadConfig(m_backendExtensionsConfigPath)) {
        QNN_ERROR("Unable to load backend extensions interface config.");
        return false;
    }

    if ((m_clManager != nullptr) && !m_backendInterface->loadCommandLineArgs(m_clManager)) {
        QNN_ERROR("Unable to load backend extensions' command line arguments.");
        return false;
    }

    return true;
}
|
| 155 |
+
|
| 156 |
+
// Non-owning accessor for the active backend interface.
// May return null before a successful initialize().
IBackend* BackendExtensions::interface() {
    return m_backendInterface;
}
|
Genie/Genie/src/qualla/engines/qnn-api/BackendExtensions.hpp
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//==============================================================================
|
| 2 |
+
//
|
| 3 |
+
// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
|
| 4 |
+
// All Rights Reserved.
|
| 5 |
+
// Confidential and Proprietary - Qualcomm Technologies, Inc.
|
| 6 |
+
//
|
| 7 |
+
//==============================================================================
|
| 8 |
+
|
| 9 |
+
#pragma once
|
| 10 |
+
|
| 11 |
+
#include <string>
|
| 12 |
+
|
| 13 |
+
#include "IBackend.hpp"
|
| 14 |
+
#include "QnnConfig.hpp"
|
| 15 |
+
#include "Log.hpp"
|
| 16 |
+
|
| 17 |
+
// This is a wrapper class that handles resources/state related to
|
| 18 |
+
// backend extensions interface. This is used by QnnNetRun library
|
| 19 |
+
// to manage and call into an IBackend interface implementation.
|
| 20 |
+
// Functionality present in this class:
|
| 21 |
+
// 1. Receives the argument string related to backend_extensions
|
| 22 |
+
// argument from the front end and processes it to open the
|
| 23 |
+
// backend extensions library.
|
| 24 |
+
// 2. Locates and stores symbols for creating and destroying the
|
| 25 |
+
// IBackend interface implementation.
|
| 26 |
+
// 3. If there is no backend_extensions argument, this class creates
|
| 27 |
+
// the dummy IBackend implementation aka NetRunBackend.
|
| 28 |
+
// 4. Gives QnnNetRun access to the implementation itself through
|
| 29 |
+
// interface() function.
|
| 30 |
+
class BackendExtensions final {
  public:
    // backendExtensionsConfig: paths of the extensions shared library and its
    //   config file (when no extensions library is given, the built-in NetRun
    //   fallback backend is used — see file comment above).
    // backendLibHandle: already-opened backend library handle; forwarded to
    //   IBackend::initialize(). Not owned by this class.
    // perfProfile: performance profile forwarded via IBackend::setPerfProfile().
    // clManager: optional command-line manager whose arguments are offered to
    //   the extensions library.
    // debug_qnn: enables verbose QNN logging when true.
    BackendExtensions(
        BackendExtensionsConfigs backendExtensionsConfig,
        void* backendLibHandle,
        PerfProfile perfProfile,
        std::shared_ptr<ICommandLineManager> clManager =
            std::shared_ptr<ICommandLineManager>(nullptr),
        bool debug_qnn = false
    );
    ~BackendExtensions();
    // Loads the extensions library (or the fallback), resolves the
    // create/destroy symbols and initializes the IBackend implementation.
    bool initialize();
    // Non-owning accessor for the active IBackend implementation.
    IBackend* interface();

  private:
    // Resolves createBackendInterface/destroyBackendInterface symbols from
    // the extensions library.
    bool loadFunctionPointers();
    std::string m_backendExtensionsLibPath;     // extensions .so path
    std::string m_backendExtensionsConfigPath;  // extensions config file path
    IBackend* m_backendInterface;               // active implementation (owned)
    bool m_isNetRunBackendInterface;            // true when using the NetRun fallback
    CreateBackendInterfaceFnType_t m_createBackendInterfaceFn;
    DestroyBackendInterfaceFnType_t m_destroyBackendInterfaceFn;
    void* m_backendLibHandle;                   // backend library handle (not owned)
    PerfProfile m_perfProfile;
    std::shared_ptr<ICommandLineManager> m_clManager;
    bool m_debugQnn{false};
    // Trampoline registered with QnnLog for routing backend log messages.
    static void qnnLogCallback(
        const char* fmt,
        QnnLog_Level_t level,
        uint64_t timestamp,
        va_list args
    );
};
|
Genie/Genie/src/qualla/engines/qnn-api/ClientBuffer.cpp
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//==============================================================================
|
| 2 |
+
//
|
| 3 |
+
// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
|
| 4 |
+
// All Rights Reserved.
|
| 5 |
+
// Confidential and Proprietary - Qualcomm Technologies, Inc.
|
| 6 |
+
//
|
| 7 |
+
//==============================================================================
|
| 8 |
+
|
| 9 |
+
#include "ClientBuffer.hpp"
|
| 10 |
+
#include "QnnTypeMacros.hpp"
|
| 11 |
+
|
| 12 |
+
// Returns the raw data pointer stored in the tensor's client buffer, or
// nullptr (with a warning) when the tensor pointer itself is null.
void* ClientBuffer::getBuffer(Qnn_Tensor_t* tensor) {
    if (tensor == nullptr) {
        QNN_WARN("getBuffer: received a null pointer to a tensor");
        return nullptr;
    }
    Qnn_ClientBuffer_t clientBuf = QNN_TENSOR_GET_CLIENT_BUF(tensor);
    return clientBuf.data;
}
|
| 19 |
+
|
| 20 |
+
// Reports the byte size recorded in the tensor's client buffer, or 0 (with a
// warning) when the tensor pointer is null.
size_t ClientBuffer::getBufferSize(Qnn_Tensor_t* tensor) {
    if (tensor == nullptr) {
        QNN_WARN("getBufferSize: received a null pointer to a tensor");
        return 0;
    }
    Qnn_ClientBuffer_t clientBuf = QNN_TENSOR_GET_CLIENT_BUF(tensor);
    return clientBuf.dataSize;
}
|
| 27 |
+
|
| 28 |
+
// Heap-allocates `tensorDataSize` bytes and installs them as the tensor's
// RAW client buffer. Returns false on a null tensor or malloc failure.
bool ClientBuffer::allocateTensorBuffer(Qnn_Tensor_t* tensor, size_t tensorDataSize) {
    if (tensor == nullptr) {
        QNN_ERROR("Received nullptr for tensors");
        return false;
    }
    QNN_TENSOR_SET_MEM_TYPE(tensor, QNN_TENSORMEMTYPE_RAW);
    Qnn_ClientBuffer_t buf;
    buf.data = malloc(tensorDataSize);
    if (buf.data == nullptr) {
        QNN_ERROR("mem alloc failed for clientBuffer.data");
        return false;
    }
    buf.dataSize = tensorDataSize;
    QNN_TENSOR_SET_CLIENT_BUF(tensor, buf);
    return true;
}
|
| 44 |
+
|
| 45 |
+
// Releases the tensor's client buffer and resets its memory type.
// Tensors registered through useSameMemory()/useExternalMemory() merely
// borrow another allocation, so their data pointer is not free()d.
// Returns false only when `tensor` itself is null.
bool ClientBuffer::freeTensorBuffer(Qnn_Tensor_t* tensor) {
    if (!tensor) {
        QNN_ERROR("Received nullptr for tensors");
        return false;
    }
    if (QNN_TENSOR_GET_CLIENT_BUF(tensor).data) {
        if (m_sameMemoryFreeTensors.find(tensor) == m_sameMemoryFreeTensors.end()) {
            free(QNN_TENSOR_GET_CLIENT_BUF(tensor).data);
        } else {
            // Bug fix: remove the borrow marker once the borrowed buffer is
            // detached. Previously the tensor stayed in the set forever, so a
            // later allocateTensorBuffer() + freeTensorBuffer() on the same
            // tensor skipped the free() and leaked the new allocation.
            m_sameMemoryFreeTensors.erase(tensor);
        }
        QNN_TENSOR_SET_CLIENT_BUF(tensor, Qnn_ClientBuffer_t({nullptr, 0u}));
        QNN_TENSOR_SET_MEM_TYPE(tensor, QNN_TENSORMEMTYPE_UNDEFINED);
    }
    return true;
}
|
| 59 |
+
|
| 60 |
+
// Makes `dest` alias `src`'s client buffer. `dest`'s own allocation (if any)
// is released first, and `dest` is recorded so that a later
// freeTensorBuffer(dest) does not free() memory it does not own.
bool ClientBuffer::useSameMemory(Qnn_Tensor_t* dest, Qnn_Tensor_t* src) {
    if (dest == nullptr || src == nullptr) {
        QNN_ERROR("Received nullptr");
        return false;
    }
    if (!freeTensorBuffer(dest)) {
        return false;
    }

    QNN_TENSOR_SET_MEM_TYPE(dest, QNN_TENSOR_GET_MEM_TYPE(src));
    QNN_TENSOR_SET_CLIENT_BUF(dest, QNN_TENSOR_GET_CLIENT_BUF(src));
    m_sameMemoryFreeTensors.insert(dest);
    return true;
}
|
| 74 |
+
|
| 75 |
+
// Points `dest` at caller-owned memory `extMem` without taking ownership.
// The previously recorded dataSize is preserved; `dest` is marked as a
// borrower so freeTensorBuffer() will not free() the external pointer.
bool ClientBuffer::useExternalMemory(Qnn_Tensor_t* dest, void* extMem) {
    if (dest == nullptr || extMem == nullptr) {
        QNN_ERROR("Received nullptr");
        return false;
    }

    // Capture the size before releasing: freeTensorBuffer() zeroes it.
    const auto keptSize = QNN_TENSOR_GET_CLIENT_BUF(dest).dataSize;
    if (!freeTensorBuffer(dest)) {
        return false;
    }

    Qnn_ClientBuffer_t wrapped;
    wrapped.data = extMem;
    wrapped.dataSize = keptSize;
    QNN_TENSOR_SET_MEM_TYPE(dest, QNN_TENSORMEMTYPE_RAW);
    QNN_TENSOR_SET_CLIENT_BUF(dest, wrapped);
    m_sameMemoryFreeTensors.insert(dest);
    return true;
}
|
| 93 |
+
|
| 94 |
+
// ----- Fused / shared-buffer API stubs ---------------------------------------
// ClientBuffer is a plain heap (malloc) allocator. Fused buffers, file
// descriptors and offset-mapped registration only apply to shared memory
// backends, so every entry point below is an inert stub.

// Fused allocation is unsupported; always signals failure via nullptr.
void* ClientBuffer::allocateTensorFusedBuffer(uint64_t bufferSize, int32_t* fd) {
    return nullptr;
}

// Mapping a tensor onto a fused-buffer slice is unsupported; always false.
bool ClientBuffer::mapFusedBufferOffset(
    Qnn_Tensor_t* tensor,
    size_t tensorDataSize,
    int32_t fd,
    uint32_t offset,
    uint64_t totalBufferSize,
    void* memPointer,
    Qnn_ContextHandle_t contextHandle
) {
    return false;
}

// Nothing is ever registered, so deregistration always reports false.
bool ClientBuffer::deregisterTensorFusedBuffer(Qnn_Tensor_t* tensor) {
    return false;
}

// No fused buffers exist; intentionally a no-op.
void ClientBuffer::freeFusedBuffers() {}

// Heap buffers are not carved out of a larger allocation: offset is always 0.
size_t ClientBuffer::getOffset(Qnn_Tensor_t* tensor) {
    return 0;
}

// No enclosing fused allocation exists, so the total size is reported as 0.
size_t ClientBuffer::getTotalBufferSize(Qnn_Tensor_t* tensor) {
    return 0;
}
|
Genie/Genie/src/qualla/engines/qnn-api/ClientBuffer.hpp
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//==============================================================================
|
| 2 |
+
//
|
| 3 |
+
// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
|
| 4 |
+
// All Rights Reserved.
|
| 5 |
+
// Confidential and Proprietary - Qualcomm Technologies, Inc.
|
| 6 |
+
//
|
| 7 |
+
//==============================================================================
|
| 8 |
+
|
| 9 |
+
#pragma once
|
| 10 |
+
|
| 11 |
+
#include "IBufferAlloc.hpp"
|
| 12 |
+
#include "Log.hpp"
|
| 13 |
+
#include <unordered_set>
|
| 14 |
+
#include <stdlib.h>
|
| 15 |
+
|
| 16 |
+
// Plain heap (malloc/free) implementation of IBufferAlloc. Only per-tensor
// heap allocations are supported; the shared/fused-buffer portion of the
// interface is stubbed out (see ClientBuffer.cpp).
class ClientBuffer final : public IBufferAlloc {
  public:
    ClientBuffer() {};

    // Disable copy constructors, r-value referencing, etc
    ClientBuffer(const ClientBuffer&) = delete;

    ClientBuffer& operator=(const ClientBuffer&) = delete;

    ClientBuffer(ClientBuffer&&) = delete;

    ClientBuffer& operator=(ClientBuffer&&) = delete;

    // Heap allocation needs no setup.
    bool initialize() override { return true; };

    void* getBuffer(Qnn_Tensor_t* tensor) override;

    // Heap memory has no file descriptor; always returns -1.
    int getFd(Qnn_Tensor_t* tensor) override {
        QNN_WARN("getFd: This is not ION memory");
        return -1;
    };

    size_t getOffset(Qnn_Tensor_t* tensor) override;
    size_t getBufferSize(Qnn_Tensor_t* tensor) override;
    size_t getTotalBufferSize(Qnn_Tensor_t* tensor) override;

    bool allocateTensorBuffer(Qnn_Tensor_t* tensor, size_t tensorDataSize) override;

    bool freeTensorBuffer(Qnn_Tensor_t* tensor) override;

    bool useSameMemory(Qnn_Tensor_t* dest, Qnn_Tensor_t* src) override;
    // Offset-based aliasing is unsupported for heap buffers.
    bool useSameMemory(Qnn_Tensor_t* dest, Qnn_Tensor_t* src, int offset) override { return false; }

    bool useExternalMemory(Qnn_Tensor_t* dest, void* extMem) override;

    // Fused-buffer API: unsupported for heap memory (inert stubs).
    void* allocateTensorFusedBuffer(uint64_t bufferSize, int32_t* fd) override;
    bool allocateBuffers(
        const std::map<int, std::map<std::string, size_t>>& allocs_per_chunk,
        std::map<std::string, std::pair<int, size_t>>& tensor_offsets
    ) override {
        return false;
    };

    bool mapFusedBufferOffset(
        Qnn_Tensor_t* tensor,
        size_t tensorDataSize,
        int32_t fd,
        uint32_t offset,
        uint64_t totalBufferSize,
        void* memPointer,
        Qnn_ContextHandle_t contextHandle
    ) override;
    bool deregisterTensorFusedBuffer(Qnn_Tensor_t* tensor) override;
    void freeFusedBuffers() override;

    bool mapFusedBufferOffset(
        Qnn_Tensor_t* tensor,
        int alloc_idx,
        size_t offset,
        Qnn_ContextHandle_t ctx,
        size_t size
    ) override {
        return false;
    }

    virtual ~ClientBuffer() {};

  private:
    // Tensors that merely borrow memory (via useSameMemory/useExternalMemory);
    // freeTensorBuffer() must not free() their data pointers.
    std::unordered_set<Qnn_Tensor_t*> m_sameMemoryFreeTensors;
};
|
Genie/Genie/src/qualla/engines/qnn-api/IBackend.hpp
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//==============================================================================
|
| 2 |
+
//
|
| 3 |
+
// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
|
| 4 |
+
// All Rights Reserved.
|
| 5 |
+
// Confidential and Proprietary - Qualcomm Technologies, Inc.
|
| 6 |
+
//
|
| 7 |
+
//==============================================================================
|
| 8 |
+
|
| 9 |
+
#pragma once
|
| 10 |
+
|
| 11 |
+
#include <map>
|
| 12 |
+
#include "ICommandLineManager.hpp"
|
| 13 |
+
#include "QnnBackend.h"
|
| 14 |
+
#include "QnnContext.h"
|
| 15 |
+
#include "QnnGraph.h"
|
| 16 |
+
#include "QnnLog.h"
|
| 17 |
+
#include "QnnTypeDef.hpp"
|
| 18 |
+
#include "QnnProfile.h"
|
| 19 |
+
#include "QnnDevice.h"
|
| 20 |
+
|
| 21 |
+
// Compile-time definition to check for QNN SDK features using the QNN API version
|
| 22 |
+
#define QUALLA_QNN_API_VERSION \
|
| 23 |
+
(QNN_API_VERSION_MAJOR * 10000 + QNN_API_VERSION_MINOR * 100 + QNN_API_VERSION_PATCH)
|
| 24 |
+
|
| 25 |
+
const uint32_t g_profilingLevelNotSet = 0;
|
| 26 |
+
|
| 27 |
+
// Performance/power profiles accepted by IBackend::setPerfProfile().
// NO_USER_INPUT and INVALID are sentinels (presumably "nothing specified"
// and "unparsable input" — confirm against the config parser); CUSTOM defers
// to backend-specific settings.
enum class PerfProfile {
    LOW_BALANCED,
    BALANCED,
    DEFAULT,
    HIGH_PERFORMANCE,
    SUSTAINED_HIGH_PERFORMANCE,
    BURST,
    EXTREME_POWER_SAVER,
    LOW_POWER_SAVER,
    POWER_SAVER,
    HIGH_POWER_SAVER,
    SYSTEM_SETTINGS,
    NO_USER_INPUT,
    CUSTOM,
    INVALID
};
|
| 43 |
+
|
| 44 |
+
// This is the interface that enables backend specific extensions in qnn-net-run.
|
| 45 |
+
// It is designed as hooks in the timeline of various events in NetRun.
|
| 46 |
+
// Backends that intend to implement custom features through qnn-net-run will have
|
| 47 |
+
// to implement this interface and add functionality in appropriate methods depending
|
| 48 |
+
// on where/when the custom functionality needs to be exercised.
|
| 49 |
+
// These functions/hooks will be called through the IBackend interface from within
|
| 50 |
+
// qnn-net-run wherever necessary.
|
| 51 |
+
class IBackend {
  public:
    virtual ~IBackend() {}

    // Routes backend log output to `callback`, filtered at `maxLogLevel`.
    virtual bool setupLogging(QnnLog_Callback_t callback, QnnLog_Level_t maxLogLevel) = 0;

    // One-time setup; receives the already-opened backend library handle.
    virtual bool initialize(void* backendLibHandle) = 0;

    virtual bool setPerfProfile(PerfProfile perfProfile) = 0;

    // Profiling level the backend wants enabled (g_profilingLevelNotSet when
    // none was configured).
    virtual QnnProfile_Level_t getProfilingLevel() = 0;

    // Loads backend-specific settings from the given config file path.
    virtual bool loadConfig(std::string configFile) = 0;

    // Lets the backend consume any command-line arguments it recognizes.
    virtual bool loadCommandLineArgs(std::shared_ptr<ICommandLineManager> clManager) = 0;

    // The before*/after* pairs below are hooks invoked around the
    // corresponding QNN API calls. before* hooks may inject custom config
    // arrays through the out-parameters (customConfigs/configCount).
    virtual bool beforeBackendInitialize(
        QnnBackend_Config_t*** customConfigs,
        uint32_t* configCount
    ) = 0;

    virtual bool afterBackendInitialize() = 0;

    virtual bool beforeContextCreate(
        QnnContext_Config_t*** customConfigs,
        uint32_t* configCount
    ) = 0;

    virtual bool afterContextCreate() = 0;

    virtual bool beforeComposeGraphs(
        GraphConfigInfo_t*** customGraphConfigs,
        uint32_t* graphCount
    ) = 0;

    virtual bool afterComposeGraphs() = 0;

#if QUALLA_QNN_API_VERSION >= 21700
    // Hook for updating per-graph configs right before finalize (only
    // available from QNN API 2.17.0 onwards).
    virtual bool beforeGraphFinalizeUpdateConfig(
        const char* graphName,
        Qnn_GraphHandle_t graphHandle,
        QnnGraph_Config_t*** customConfigs,
        uint32_t* configCount
    ) = 0;
#endif

    virtual bool beforeGraphFinalize() = 0;

    virtual bool afterGraphFinalize() = 0;

    virtual bool beforeRegisterOpPackages() = 0;

    virtual bool afterRegisterOpPackages() = 0;

    virtual bool beforeExecute(
        const char* graphName,
        QnnGraph_Config_t*** customConfigs,
        uint32_t* configCount
    ) = 0;

    virtual bool afterExecute() = 0;

    virtual bool beforeContextFree() = 0;

    virtual bool afterContextFree() = 0;

    virtual bool beforeBackendTerminate() = 0;

    virtual bool afterBackendTerminate() = 0;

    virtual bool beforeCreateFromBinary(
        QnnContext_Config_t*** customConfigs,
        uint32_t* configCount
    ) = 0;

    virtual bool afterCreateFromBinary() = 0;

#if QUALLA_QNN_API_VERSION >= 21700
    // Per-context custom configs, keyed by context name, plus configs shared
    // by all contexts in the binary list (QNN API >= 2.17.0).
    virtual bool beforeCreateContextsFromBinaryList(
        std::map<std::string, std::tuple<QnnContext_Config_t**, uint32_t>>*
            contextKeyToCustomConfigsMap,
        QnnContext_Config_t*** commonCustomConfigs,
        uint32_t* commonConfigCount
    ) = 0;

    virtual bool afterCreateContextsFromBinaryList() = 0;
#endif

    virtual bool beforeCreateDevice(QnnDevice_Config_t*** deviceConfigs, uint32_t* configCount) = 0;

    virtual bool afterCreateDevice() = 0;

    virtual bool beforeFreeDevice() = 0;

    virtual bool afterFreeDevice() = 0;
};
|
| 147 |
+
|
| 148 |
+
// These are the function types that the backend extensions shared library is
|
| 149 |
+
// expected to expose. The first function helps NetRun obtain a valid implementation
|
| 150 |
+
// of IBackend interface and the second is used to destroy the same interface at the end.
|
| 151 |
+
// The function names themselves are expected to be these strings:
|
| 152 |
+
// 1. "createBackendInterface"
|
| 153 |
+
// 2. "destroyBackendInterface"
|
| 154 |
+
// These functions need to be tagged with extern "C" and their symbols need to be exposed.
|
| 155 |
+
typedef IBackend* (*CreateBackendInterfaceFnType_t)();
|
| 156 |
+
typedef void (*DestroyBackendInterfaceFnType_t)(IBackend*);
|
Genie/Genie/src/qualla/engines/qnn-api/IBufferAlloc.hpp
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//==============================================================================
|
| 2 |
+
//
|
| 3 |
+
// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
|
| 4 |
+
// All Rights Reserved.
|
| 5 |
+
// Confidential and Proprietary - Qualcomm Technologies, Inc.
|
| 6 |
+
//
|
| 7 |
+
//==============================================================================
|
| 8 |
+
|
| 9 |
+
#pragma once
|
| 10 |
+
#include "QnnTypes.h"
|
| 11 |
+
#include <map>
|
| 12 |
+
#include <string>
|
| 13 |
+
#include <vector>
|
| 14 |
+
#include <utility>
|
| 15 |
+
#include <unordered_map>
|
| 16 |
+
|
| 17 |
+
// Abstract allocator for QNN tensor client buffers. Implementations back
// tensor data either with plain heap memory (ClientBuffer) or with shared,
// fd-based memory supporting "fused" allocations — one large buffer that
// several tensors map into at different offsets.
class IBufferAlloc {
  public:
    virtual ~IBufferAlloc() {}
    IBufferAlloc() {}
    // One-time setup; returns false when the allocator cannot be used.
    virtual bool initialize() = 0;
    // Raw data pointer backing the tensor (nullptr on error).
    virtual void* getBuffer(Qnn_Tensor_t* tensor) = 0;
    // File descriptor of the backing memory, or -1 when not fd-based.
    virtual int getFd(Qnn_Tensor_t* tensor) = 0;
    // Byte offset of the tensor inside its (fused) backing buffer.
    virtual size_t getOffset(Qnn_Tensor_t* tensor) = 0;
    virtual size_t getBufferSize(Qnn_Tensor_t* tensor) = 0;
    virtual size_t getTotalBufferSize(Qnn_Tensor_t* tensor) = 0;
    virtual bool allocateTensorBuffer(Qnn_Tensor_t* tensor, size_t tensorDataSize) = 0;
    virtual bool freeTensorBuffer(Qnn_Tensor_t* tensor) = 0;
    // Make `dest` alias `src`'s backing memory (optionally at an offset).
    virtual bool useSameMemory(Qnn_Tensor_t* dest, Qnn_Tensor_t* src) = 0;
    virtual bool useSameMemory(Qnn_Tensor_t* dest, Qnn_Tensor_t* src, int offset) = 0;
    // Point `dest` at caller-owned memory without taking ownership.
    virtual bool useExternalMemory(Qnn_Tensor_t* dest, void* extMem) = 0;
    // Allocate one large fused buffer; returns its base pointer and, through
    // `fd`, the backing file descriptor.
    virtual void* allocateTensorFusedBuffer(uint64_t bufferSize, int32_t* fd) = 0;
    virtual bool allocateBuffers(
        const std::map<int, std::map<std::string, size_t>>& allocs_per_chunk,
        std::map<std::string, std::pair<int, size_t>>& tensor_offsets
    ) = 0;
    // Register `tensor` against the slice [offset, offset + tensorDataSize)
    // of a previously allocated fused buffer.
    virtual bool mapFusedBufferOffset(
        Qnn_Tensor_t* tensor,
        size_t tensorDataSize,
        int32_t fd,
        uint32_t offset,
        uint64_t totalBufferSize,
        void* memPointer,
        Qnn_ContextHandle_t contextHandle
    ) = 0;
    virtual bool mapFusedBufferOffset(
        Qnn_Tensor_t* tensor,
        int alloc_idx,
        size_t offset,
        Qnn_ContextHandle_t ctx,
        size_t size
    ) = 0;

    virtual bool deregisterTensorFusedBuffer(Qnn_Tensor_t* tensor) = 0;
    virtual void freeFusedBuffers() = 0;
};
|
Genie/Genie/src/qualla/engines/qnn-api/ICommandLineManager.hpp
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//==============================================================================
|
| 2 |
+
//
|
| 3 |
+
// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
|
| 4 |
+
// All Rights Reserved.
|
| 5 |
+
// Confidential and Proprietary - Qualcomm Technologies, Inc.
|
| 6 |
+
//
|
| 7 |
+
//==============================================================================
|
| 8 |
+
|
| 9 |
+
#pragma once
|
| 10 |
+
|
| 11 |
+
#include <cctype>
|
| 12 |
+
#include <memory>
|
| 13 |
+
#include <string>
|
| 14 |
+
#include <tuple>
|
| 15 |
+
#include <vector>
|
| 16 |
+
|
| 17 |
+
// Interface for parsing and serving "--key=value" / "--key" style command
// line arguments to consumers (e.g. backend-extensions libraries), with
// bookkeeping to detect unused or doubly-consumed arguments.
class ICommandLineManager {
  public:
    enum class Error { SUCCESS, PARSE_FAILURE, UNUSED_ARGUMENTS, OVER_SUBSCRIBED_ARGUMENTS };

    using ValueList_t = std::vector<std::shared_ptr<const std::string>>;

    /**
     * @brief Parses provided command line arguments into key value pairs
     *
     * @param[in] argc Number of char* arguments in argv
     *
     * @param[in] argv Pointer to first element of null terminated character arrays
     *
     * @return Error code:
     *         - SUCCESS: provided command line arguments match expected format: --key=value, --key
     *         - PARSE_FAILURE: The provided command line arguments do not match expected format
     *
     */
    virtual Error parseClArgs(size_t argc, char** argv) = 0;

    /**
     * @brief Provides passed values for requested key if available
     *
     * @param[in] key Key string of option
     *
     * @return (False, empty) if key is not an available argument
     *
     */
    virtual std::tuple<bool, ValueList_t> serveArg(const std::string& key) = 0;

    /**
     * @brief Checks whether any provided commandline arguments remain unserved
     *
     * @return True if unconsumed arguments remain, False otherwise
     */
    virtual bool allArgumentsServed() const = 0;

    /**
     * @brief Validates command line arguments were correctly utilized
     *
     * @return Error code:
     *         - SUCCESS: provided command line arguments were utilized following implementations
     *           policy
     *         - UNUSED_ARGUMENTS: Some arguments passed were not consumed
     *         - OVER_SUBSCRIBED_ARGUMENTS: Some arguments were requested by multiple times
     *
     */
    virtual Error validateUsage() = 0;

    virtual ~ICommandLineManager() = default;

    // True when `arg` looks like an option key: "--" followed by a letter.
    static bool isKey(const std::string& arg) {
        // Bug fix: cast to unsigned char before std::isalpha(). Passing a
        // plain char holding a negative value (e.g. a UTF-8 byte) to the
        // <cctype> classifiers is undefined behavior.
        return (arg.length() > keyPrefix().length()) && (arg.find(keyPrefix()) == 0) &&
               std::isalpha(static_cast<unsigned char>(arg.at(keyPrefix().length())));
    }

    // Extracts the "--key" portion (before any '=') of a key argument.
    static Error parseKey(const std::string& arg, std::string& keyOut) {
        if (!isKey(arg)) {
            return Error::PARSE_FAILURE;
        }

        auto valueSplit = arg.find(keyValueSplit());
        keyOut = valueSplit != arg.npos ? arg.substr(0, valueSplit) : arg;
        return Error::SUCCESS;
    }

    // Extracts the value portion (after '=') of a "--key=value" argument;
    // fails when there is no '=' or the value is empty.
    static Error parseValue(const std::string& arg, std::string& valueOut) {
        auto valueSplit = arg.find(keyValueSplit());
        if (valueSplit == arg.npos || valueSplit == arg.length() - 1) {
            return Error::PARSE_FAILURE;
        }
        valueOut = arg.substr(valueSplit + 1);
        return Error::SUCCESS;
    }

  private:
    static const std::string keyPrefix() { return "--"; };
    static char keyValueSplit() { return '='; };
};
|
Genie/Genie/src/qualla/engines/qnn-api/IOTensor.cpp
ADDED
|
@@ -0,0 +1,382 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//==============================================================================
|
| 2 |
+
//
|
| 3 |
+
// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
|
| 4 |
+
// All Rights Reserved.
|
| 5 |
+
// Confidential and Proprietary - Qualcomm Technologies, Inc.
|
| 6 |
+
//
|
| 7 |
+
//==============================================================================
|
| 8 |
+
#include <cstring>
|
| 9 |
+
#include <fstream>
|
| 10 |
+
#include <iostream>
|
| 11 |
+
|
| 12 |
+
#include "ClientBuffer.hpp"
|
| 13 |
+
#include "IBufferAlloc.hpp"
|
| 14 |
+
#include "IOTensor.hpp"
|
| 15 |
+
#include "RpcMem.hpp"
|
| 16 |
+
#include "QnnTypeMacros.hpp"
|
| 17 |
+
|
| 18 |
+
#ifdef _WIN32
|
| 19 |
+
#define __strdup _strdup
|
| 20 |
+
#else
|
| 21 |
+
#define __strdup strdup
|
| 22 |
+
#endif
|
| 23 |
+
|
| 24 |
+
// Stores the requested allocation strategy and QNN interface pointer and
// starts with a heap-backed ClientBuffer manager; initialize() swaps in an
// RpcMem manager when SHARED_BUFFER was requested.
IOTensor::IOTensor(BufferAlloc bufferAllocIn, QNN_INTERFACE_VER_TYPE* qnnInterface)
    : m_bufferAlloc(bufferAllocIn), m_qnnInterface(qnnInterface),
      m_bufferManager(new ClientBuffer()) {}
|
| 27 |
+
|
| 28 |
+
// Selects and initializes the buffer manager. When shared buffers were
// requested, the default heap-backed ClientBuffer (set in the constructor)
// is replaced by an RpcMem allocator bound to the given context.
bool IOTensor::initialize(Qnn_ContextHandle_t contextHandle) {
    if (BufferAlloc::SHARED_BUFFER == m_bufferAlloc) {
        m_bufferManager = std::unique_ptr<IBufferAlloc>(new RpcMem(contextHandle, m_qnnInterface));
    }

    const bool managerReady = m_bufferManager->initialize();
    if (!managerReady) {
        QNN_ERROR("Failed to initialize buffer manager");
        return false;
    }

    return true;
}
|
| 40 |
+
|
| 41 |
+
// Releases any fused (shared) allocations held by the buffer manager.
// Only the shared-buffer path frees anything here; ClientBuffer's
// freeFusedBuffers() is a no-op.
IOTensor::~IOTensor() {
    if (m_bufferAlloc == BufferAlloc::SHARED_BUFFER) {
        m_bufferManager->freeFusedBuffers();
    }
}
|
| 46 |
+
|
| 47 |
+
// Setup details for Qnn_Tensor_t for execution
|
| 48 |
+
// based on information in TensorWrapper provided by model.so.
|
| 49 |
+
// Setup details for Qnn_Tensor_t for execution
// based on information in TensorWrapper provided by model.so.
//
// Allocates an array of `tensorCount` Qnn_Tensor_t into *tensors, deep-copies
// each wrapper's tensor metadata, and backs every tensor with either
//   - a slice of one fused shared buffer (SHARED_BUFFER mode), or
//   - an individually allocated buffer (DEFAULT mode).
// Each successfully prepared tensor is recorded in tensorNameToTensorPointer.
// On failure all partially created resources are released and *tensors is
// reset to nullptr.
bool IOTensor::setupTensors(
    Qnn_Tensor_t** tensors,
    std::unordered_map<std::string, void*>& tensorNameToTensorPointer,
    uint32_t tensorCount,
    TensorWrapper* tensorWrappers,
    std::unordered_map<std::string, size_t>& tensorsSize,
    Qnn_ContextHandle_t contextHandle,
    bool skipBufferAllocation
) {

    if (nullptr == tensorWrappers) {
        QNN_ERROR("tensorWrappers is nullptr");
        return false;
    }
    if (0 == tensorCount) {
        QNN_DEBUG("tensor count is 0. Nothing to setup.");
        return true;
    }

    *tensors = (Qnn_Tensor_t*)calloc(1, tensorCount * sizeof(Qnn_Tensor_t));
    if (nullptr == *tensors) {
        QNN_ERROR("mem alloc failed for *tensors");
        return false;
    }

    auto returnStatus = true;

    uint64_t totalBufferSize = 0;
    void* memPointer = nullptr;
    int32_t fd = -1;
    if (m_bufferAlloc == BufferAlloc::SHARED_BUFFER) {
        // Calculate the total size of the tensors
        for (size_t tensorIdx = 0; tensorIdx < tensorCount; tensorIdx++) {
            auto wrapperTensorName =
                std::string(GET_TENSOR_WRAPPER_NAME(tensorWrappers[tensorIdx]));
            totalBufferSize += tensorsSize[wrapperTensorName];
        }
        // BUGFIX: cast for %lu -- uint64_t is not `unsigned long` on all ABIs.
        QNN_DEBUG("Calculated total size %lu", (unsigned long)totalBufferSize);

        if (!skipBufferAllocation) {
            // Allocate one fused buffer covering all tensors; individual
            // tensors are mapped at increasing offsets into it below.
            memPointer = m_bufferManager->allocateTensorFusedBuffer(totalBufferSize, &fd);
            if (memPointer) {
                QNN_DEBUG(
                    "Successfully allocated a buffer of size %lu, pointer %p, fd %d",
                    (unsigned long)totalBufferSize,
                    memPointer,
                    fd
                );
            } else {
                QNN_ERROR(
                    "Not able to allocate buffer of size %lu", (unsigned long)totalBufferSize
                );
                // BUGFIX: previously returned without releasing the tensor
                // array allocated above, leaking it and handing the caller a
                // non-null *tensors of never-initialized entries.
                free(*tensors);
                *tensors = nullptr;
                return false;
            }
        }
    }

    uint64_t offset = 0;

    for (size_t tensorIdx = 0; tensorIdx < tensorCount; tensorIdx++) {
        Qnn_Tensor_t wrapperTensor = GET_TENSOR_WRAPPER_TENSOR(tensorWrappers[tensorIdx]);
        auto wrapperTensorName = std::string(GET_TENSOR_WRAPPER_NAME(tensorWrappers[tensorIdx]));
        if (true == returnStatus) {
            (*tensors)[tensorIdx] = QNN_TENSOR_INIT;
            returnStatus = deepCopyQnnTensorInfo(((*tensors) + tensorIdx), &wrapperTensor);
        }
        if (true == returnStatus) {
            size_t tensorDataSize = tensorsSize[wrapperTensorName];
            if (m_bufferAlloc == BufferAlloc::SHARED_BUFFER) {
                if (!skipBufferAllocation) {
                    // Bind this tensor to its slice of the fused buffer.
                    returnStatus = m_bufferManager->mapFusedBufferOffset(
                        ((*tensors) + tensorIdx),
                        tensorDataSize,
                        fd,
                        offset,
                        totalBufferSize,
                        memPointer,
                        contextHandle
                    );
                    offset += tensorDataSize;
                }
            } else {
                returnStatus = m_bufferManager->allocateTensorBuffer(
                    ((*tensors) + tensorIdx), tensorDataSize
                );
            }
        }
        if (true != returnStatus) {
            QNN_ERROR("Failure in setupTensors, cleaning up resources");
            // NOTE(review): only tensors [0, tensorIdx) are torn down; if the
            // current tensor's metadata deep-copy succeeded but its buffer
            // binding failed, its dimension array is not reclaimed here.
            tearDownTensors(*tensors, tensorIdx);
            *tensors = nullptr;
            QNN_ERROR("Failure in setupTensors, done cleaning up resources");
            return false;
        } else {
            tensorNameToTensorPointer.insert({wrapperTensorName, ((*tensors) + tensorIdx)});
            // QNN_DEBUG("allocateBuffer successful");
        }
    }

    return returnStatus;
}
|
| 151 |
+
|
| 152 |
+
// Setup details for all input tensors for graph execution.
|
| 153 |
+
// Setup details for all input tensors for graph execution.
// Thin wrapper over setupTensors() for the graph's input bank; on failure any
// partially constructed input tensors are torn down and *inputs is nulled.
bool IOTensor::setupInputTensors(
    Qnn_Tensor_t** inputs,
    std::unordered_map<std::string, void*>& tensorNameToTensorPointer,
    const GraphInfo_t& graphInfo,
    std::unordered_map<std::string, size_t>& inputTensorsSize,
    Qnn_ContextHandle_t contextHandle,
    bool skipBufferAllocation
) {
    const bool ok = setupTensors(
        inputs,
        tensorNameToTensorPointer,
        graphInfo.numInputTensors,
        (graphInfo.inputTensors),
        inputTensorsSize,
        contextHandle,
        skipBufferAllocation
    );
    if (ok) {
        return true;
    }

    QNN_ERROR("Failure in setupInputTensors, cleaning up resources");
    if (nullptr != *inputs) {
        QNN_DEBUG("cleaning up input tensors");
        tearDownTensors(*inputs, graphInfo.numInputTensors);
        *inputs = nullptr;
    }
    QNN_ERROR("Failure in setupInputTensors, done cleaning up resources");

    return false;
}
|
| 184 |
+
|
| 185 |
+
// Setup details for all output tensors for graph execution.
|
| 186 |
+
// Setup details for all output tensors for graph execution.
// Thin wrapper over setupTensors() for the graph's output bank; on failure any
// partially constructed output tensors are torn down and *outputs is nulled.
bool IOTensor::setupOutputTensors(
    Qnn_Tensor_t** outputs,
    std::unordered_map<std::string, void*>& tensorNameToTensorPointer,
    const GraphInfo_t& graphInfo,
    std::unordered_map<std::string, size_t>& outputTensorsSize,
    Qnn_ContextHandle_t contextHandle,
    bool skipBufferAllocation
) {
    const bool ok = setupTensors(
        outputs,
        tensorNameToTensorPointer,
        graphInfo.numOutputTensors,
        (graphInfo.outputTensors),
        outputTensorsSize,
        contextHandle,
        skipBufferAllocation
    );
    if (ok) {
        return true;
    }

    QNN_ERROR("Failure in setupOutputTensors, cleaning up resources");
    if (nullptr != *outputs) {
        QNN_DEBUG("cleaning up output tensors");
        tearDownTensors(*outputs, graphInfo.numOutputTensors);
        *outputs = nullptr;
    }
    QNN_ERROR("Failure in setupOutputTensors, done cleaning up resources");

    return false;
}
|
| 217 |
+
|
| 218 |
+
// Map every tensor of a graph that appears in `graph_allocs` onto its slice of
// a pre-allocated fused buffer.
// `graph_allocs` maps tensor name -> (allocation index, byte offset, size);
// tensors without an entry are left untouched. Returns true only if every
// attempted mapping succeeded.
bool IOTensor::mapFusedBufferOffset(
    GraphInfo_t* graph_info,
    Qnn_ContextHandle_t context_handle,
    const std::map<std::string, std::tuple<int, size_t, size_t>>& graph_allocs
) {
    std::lock_guard lk(_tmp_lock); // READ COMMENT IN IOTensor.hpp _tmp_lock

    bool ret = true;
    // Pass 1 (mode == true): the graph's input tensors.
    // Pass 2 (mode == false): the graph's output tensors.
    for (const bool mode : {true, false}) {
        TensorWrapper* tensor_bank = (mode) ? graph_info->inputTensors : graph_info->outputTensors;
        uint32_t num_tensors = (mode) ? graph_info->numInputTensors : graph_info->numOutputTensors;

        for (size_t tidx = 0; tidx < num_tensors; tidx++) {
            TensorWrapper& tensor_wrapper = tensor_bank[tidx];

            Qnn_Tensor_t* tensor = &GET_TENSOR_WRAPPER_TENSOR(tensor_wrapper);
            std::string tensor_name = std::string(GET_TENSOR_WRAPPER_NAME(tensor_wrapper));

            // No planned allocation for this tensor -- skip it.
            if (!graph_allocs.contains(tensor_name)) continue;
            auto& [alloc_idx, offset, size] = graph_allocs.at(tensor_name);
            // Accumulate success across all mappings; keep going on failure so
            // every tensor gets an attempt.
            ret &= m_bufferManager->mapFusedBufferOffset(
                tensor, alloc_idx, offset, context_handle, size
            );
        }
    }

    return ret;
}
|
| 246 |
+
|
| 247 |
+
// Clean up all tensors related data after execution.
|
| 248 |
+
// Clean up all tensors related data after execution.
// For each tensor: frees its dimension array, then releases its data buffer
// (deregistering from the fused allocation in SHARED_BUFFER mode, freeing the
// individual buffer otherwise), and finally frees the tensor array itself.
// Always returns true. Passing nullptr is a no-op.
bool IOTensor::tearDownTensors(Qnn_Tensor_t* tensors, uint32_t tensorCount) {

    if (nullptr != tensors) {
        QNN_DEBUG("cleaning up resources for tensors");
        for (size_t tensorIdx = 0; tensorIdx < tensorCount; tensorIdx++) {
            // QNN_DEBUG("freeing resources for tensor: %zu", tensorIdx);
            if (nullptr != QNN_TENSOR_GET_DIMENSIONS(&tensors[tensorIdx])) {
                // QNN_DEBUG("freeing maxDimensions");
                free(QNN_TENSOR_GET_DIMENSIONS(&tensors[tensorIdx]));
            }
            if (m_bufferAlloc == BufferAlloc::SHARED_BUFFER) {
                m_bufferManager->deregisterTensorFusedBuffer(&(tensors[tensorIdx]));
            } else {
                m_bufferManager->freeTensorBuffer(&(tensors[tensorIdx]));
            }
            // Record the torn-down tensor's address (exposed via
            // getFreeTensorsPointerSet()). NOTE(review): these pointers dangle
            // once `tensors` is freed below -- they must only be compared,
            // never dereferenced.
            m_freeTensorsPointerSet.insert(&(tensors[tensorIdx]));
        }
        free(tensors);
        // BUGFIX(minor): removed a dead `tensors = nullptr;` -- assigning to
        // the by-value parameter had no effect on the caller.
    }

    return true;
}
|
| 271 |
+
|
| 272 |
+
// Clean up all tensors after execution.
|
| 273 |
+
// Clean up all tensors after execution.
// Applies the pointer-array overload to every entry; each entry is treated as
// an array of `numTensors` tensors.
bool IOTensor::tearDownTensors(std::vector<Qnn_Tensor_t*>& tensors, uint32_t numTensors) {
    for (size_t idx = 0; idx < tensors.size(); ++idx) {
        tearDownTensors(tensors[idx], numTensors);
    }
    return true;
}
|
| 281 |
+
|
| 282 |
+
bool IOTensor::tearDownTensors(std::vector<Qnn_Tensor_t>& tensors) {
|
| 283 |
+
return tearDownTensors(tensors.data(), tensors.size());
|
| 284 |
+
}
|
| 285 |
+
|
| 286 |
+
// Clean up all tensors after execution.
|
| 287 |
+
bool IOTensor::tearDownTensors(
|
| 288 |
+
std::unordered_map<std::string, Qnn_Tensor_t*>& tensors,
|
| 289 |
+
std::unordered_map<std::string, uint32_t>& tensorCountMap
|
| 290 |
+
) {
|
| 291 |
+
|
| 292 |
+
for (auto& tensor : tensors) {
|
| 293 |
+
tearDownTensors(tensor.second, tensorCountMap[tensor.first]);
|
| 294 |
+
}
|
| 295 |
+
|
| 296 |
+
return true;
|
| 297 |
+
}
|
| 298 |
+
|
| 299 |
+
// Clean up all tensors after execution.
|
| 300 |
+
bool IOTensor::tearDownTensors(
|
| 301 |
+
std::vector<std::unordered_map<std::string, Qnn_Tensor_t*>>& tensors,
|
| 302 |
+
std::unordered_map<std::string, uint32_t>& tensorCountMap
|
| 303 |
+
) {
|
| 304 |
+
|
| 305 |
+
for (auto& tensor : tensors) {
|
| 306 |
+
tearDownTensors(tensor, tensorCountMap);
|
| 307 |
+
}
|
| 308 |
+
|
| 309 |
+
return true;
|
| 310 |
+
}
|
| 311 |
+
|
| 312 |
+
// Deep-copy tensor metadata (not data buffers) from src into dest: name
// (heap-duplicated), id, type, data format/type, quantization parameters and
// dimensions. Returns false on null arguments or allocation failure; on
// failure dest may hold partially-copied metadata, which callers release via
// tearDownTensors().
bool IOTensor::deepCopyQnnTensorInfo(Qnn_Tensor_t* dest, Qnn_Tensor_t* src) {

    if (nullptr == dest || nullptr == src) {
        QNN_ERROR("Received nullptr");
        return false;
    }

    // set tensor.version before using QNN_TENSOR_SET macros, as they require the version to be set
    // to correctly assign values
    dest->version = src->version;
    const char* tensorName = QNN_TENSOR_GET_NAME(src);
    if (!tensorName) {
        QNN_TENSOR_SET_NAME(dest, nullptr);
    } else {
        QNN_TENSOR_SET_NAME(dest, __strdup(tensorName));
    }
    QNN_TENSOR_SET_ID(dest, QNN_TENSOR_GET_ID(src));
    QNN_TENSOR_SET_TYPE(dest, QNN_TENSOR_GET_TYPE(src));
    QNN_TENSOR_SET_DATA_FORMAT(dest, QNN_TENSOR_GET_DATA_FORMAT(src));
    QNN_TENSOR_SET_DATA_TYPE(dest, QNN_TENSOR_GET_DATA_TYPE(src));
    Qnn_QuantizeParams_t qParams = QNN_QUANTIZE_PARAMS_INIT;
    qParams.encodingDefinition = QNN_TENSOR_GET_QUANT_PARAMS(src).encodingDefinition;
    qParams.quantizationEncoding = QNN_QUANTIZATION_ENCODING_UNDEFINED;
    if (QNN_TENSOR_GET_QUANT_PARAMS(src).quantizationEncoding ==
        QNN_QUANTIZATION_ENCODING_SCALE_OFFSET) {
        // Per-tensor quantization: the encoding struct is trivially copyable.
        qParams.quantizationEncoding = QNN_TENSOR_GET_QUANT_PARAMS(src).quantizationEncoding;
        qParams.scaleOffsetEncoding = QNN_TENSOR_GET_QUANT_PARAMS(src).scaleOffsetEncoding;
    } else if (QNN_TENSOR_GET_QUANT_PARAMS(src).quantizationEncoding ==
               QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET) {
        // Per-axis quantization: the scale/offset array must be deep-copied.
        qParams.quantizationEncoding = QNN_TENSOR_GET_QUANT_PARAMS(src).quantizationEncoding;
        qParams.axisScaleOffsetEncoding.axis =
            QNN_TENSOR_GET_QUANT_PARAMS(src).axisScaleOffsetEncoding.axis;
        qParams.axisScaleOffsetEncoding.numScaleOffsets =
            QNN_TENSOR_GET_QUANT_PARAMS(src).axisScaleOffsetEncoding.numScaleOffsets;
        if (QNN_TENSOR_GET_QUANT_PARAMS(src).axisScaleOffsetEncoding.numScaleOffsets > 0) {
            qParams.axisScaleOffsetEncoding.scaleOffset = (Qnn_ScaleOffset_t*)malloc(
                QNN_TENSOR_GET_QUANT_PARAMS(src).axisScaleOffsetEncoding.numScaleOffsets *
                sizeof(Qnn_ScaleOffset_t)
            );
            // BUGFIX: a failed allocation used to be silently skipped while
            // still returning true, leaving scaleOffset == nullptr with
            // numScaleOffsets > 0.
            if (nullptr == qParams.axisScaleOffsetEncoding.scaleOffset) {
                QNN_ERROR("mem alloc failed for scaleOffset");
                return false;
            }
            for (size_t idx = 0;
                 idx < QNN_TENSOR_GET_QUANT_PARAMS(src).axisScaleOffsetEncoding.numScaleOffsets;
                 idx++) {
                qParams.axisScaleOffsetEncoding.scaleOffset[idx].scale =
                    QNN_TENSOR_GET_QUANT_PARAMS(src)
                        .axisScaleOffsetEncoding.scaleOffset[idx]
                        .scale;
                qParams.axisScaleOffsetEncoding.scaleOffset[idx].offset =
                    QNN_TENSOR_GET_QUANT_PARAMS(src)
                        .axisScaleOffsetEncoding.scaleOffset[idx]
                        .offset;
            }
        }
    }
    QNN_TENSOR_SET_QUANT_PARAMS(dest, qParams);
    QNN_TENSOR_SET_RANK(dest, QNN_TENSOR_GET_RANK(src));
    QNN_TENSOR_SET_DIMENSIONS(dest, nullptr);
    if (QNN_TENSOR_GET_RANK(src) > 0) {
        QNN_TENSOR_SET_DIMENSIONS(
            dest, (uint32_t*)malloc(QNN_TENSOR_GET_RANK(src) * sizeof(uint32_t))
        );
        // BUGFIX: report dimension-array allocation failure instead of
        // returning success with a null dimensions pointer and nonzero rank.
        if (nullptr == QNN_TENSOR_GET_DIMENSIONS(dest)) {
            QNN_ERROR("mem alloc failed for dimensions");
            return false;
        }
        memcpy(QNN_TENSOR_GET_DIMENSIONS(dest),
               QNN_TENSOR_GET_DIMENSIONS(src),
               QNN_TENSOR_GET_RANK(src) * sizeof(uint32_t));
    }

    return true;
}
|
Genie/Genie/src/qualla/engines/qnn-api/IOTensor.hpp
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//==============================================================================
|
| 2 |
+
//
|
| 3 |
+
// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
|
| 4 |
+
// All Rights Reserved.
|
| 5 |
+
// Confidential and Proprietary - Qualcomm Technologies, Inc.
|
| 6 |
+
//
|
| 7 |
+
//==============================================================================
|
| 8 |
+
#pragma once
|
| 9 |
+
|
| 10 |
+
#include <memory>
|
| 11 |
+
#include <queue>
|
| 12 |
+
#include <unordered_map>
|
| 13 |
+
#include <unordered_set>
|
| 14 |
+
#include <vector>
|
| 15 |
+
#include <mutex>
|
| 16 |
+
|
| 17 |
+
#include "IBufferAlloc.hpp"
|
| 18 |
+
#include "QnnTypeDef.hpp"
|
| 19 |
+
#include "Log.hpp"
|
| 20 |
+
#include "QnnBackend.h"
|
| 21 |
+
#include "QnnCommon.h"
|
| 22 |
+
#include "QnnContext.h"
|
| 23 |
+
#include "QnnGraph.h"
|
| 24 |
+
#include "QnnInterface.h"
|
| 25 |
+
#include "QnnProperty.h"
|
| 26 |
+
#include "QnnTensor.h"
|
| 27 |
+
#include "QnnTypes.h"
|
| 28 |
+
// Buffer allocation strategy used by IOTensor.
enum class BufferAlloc {
    DEFAULT,        // malloc based allocator
    SHARED_BUFFER,  // shared buffer allocator; actual allocator depends on the platform
    INVALID
};

class IBufferAlloc;

// Owns the Qnn_Tensor_t metadata and data buffers used for graph execution.
// All actual buffer work is delegated to m_bufferManager (an IBufferAlloc
// implementation selected by m_bufferAlloc at initialize() time).
class IOTensor {
  public:
    IOTensor(
        BufferAlloc bufferAllocIn = BufferAlloc::DEFAULT,
        QNN_INTERFACE_VER_TYPE* qnnInterface = nullptr
    );

    ~IOTensor();

    // Creates the underlying buffer manager; must succeed before any of the
    // setup/teardown calls below are used.
    bool initialize(Qnn_ContextHandle_t contextHandle = nullptr);

    // Allocate and describe all input tensors of a graph. Successful tensors
    // are also recorded in tensorNameToTensorPointer.
    bool setupInputTensors(
        Qnn_Tensor_t** inputs,
        std::unordered_map<std::string, void*>& tensorNameToTensorPointer,
        const GraphInfo_t& graphInfo,
        std::unordered_map<std::string, size_t>& inputTensorsSize,
        Qnn_ContextHandle_t contextHandle,
        bool skipBufferAllocation = false
    );

    // Allocate and describe all output tensors of a graph.
    bool setupOutputTensors(
        Qnn_Tensor_t** outputs,
        std::unordered_map<std::string, void*>& tensorNameToTensorPointer,
        const GraphInfo_t& graphInfo,
        std::unordered_map<std::string, size_t>& outputTensorsSize,
        Qnn_ContextHandle_t contextHandle,
        bool skipBufferAllocation = false
    );

    // Teardown overloads: free dimension arrays and release data buffers for
    // the given tensors.
    bool tearDownTensors(Qnn_Tensor_t* tensors, uint32_t tensorCount);

    bool tearDownTensors(std::vector<Qnn_Tensor_t*>& tensors, uint32_t tensorCount);
    bool tearDownTensors(std::vector<Qnn_Tensor_t>& tensors);
    bool tearDownTensors(
        std::unordered_map<std::string, Qnn_Tensor_t*>& tensors,
        std::unordered_map<std::string, uint32_t>& tensorCountMap
    );
    bool tearDownTensors(
        std::vector<std::unordered_map<std::string, Qnn_Tensor_t*>>& tensors,
        std::unordered_map<std::string, uint32_t>& tensorCountMap
    );

    // Tear down both the input and output tensor banks of a graph. Returns
    // false if either bank fails, after attempting both.
    bool tearDownTensors(const GraphInfo_t* graph_info) {
        bool status = true;
        if (!tearDownTensors(graph_info->inputTensors, graph_info->numInputTensors)) {
            status = false;
            QNN_ERROR("Failed to tear down input tensors for graph %s", graph_info->graphName);
        }

        if (!tearDownTensors(graph_info->outputTensors, graph_info->numOutputTensors)) {
            status = false;
            QNN_ERROR("Failed to tear down output tensors for graph %s", graph_info->graphName);
        }
        return status;
    }

    // Thin accessors forwarding to the buffer manager.
    void* getBuffer(Qnn_Tensor_t* tensor) { return m_bufferManager->getBuffer(tensor); };

    int getFd(Qnn_Tensor_t* tensor) { return m_bufferManager->getFd(tensor); };

    size_t getOffset(Qnn_Tensor_t* tensor) { return m_bufferManager->getOffset(tensor); };

    size_t getBufferSize(Qnn_Tensor_t* tensor) { return m_bufferManager->getBufferSize(tensor); };

    size_t getTotalBufferSize(Qnn_Tensor_t* tensor) {
        return m_bufferManager->getTotalBufferSize(tensor);
    }

    // Allocate one fused buffer large enough for several tensors; *fd receives
    // the backing file descriptor.
    void* allocateTensorFusedBuffer(uint64_t bufferSize, int32_t* fd) {
        return m_bufferManager->allocateTensorFusedBuffer(bufferSize, fd);
    }

    bool allocateBuffers(
        const std::map<int, std::map<std::string, size_t>>& allocs_per_chunk,
        std::map<std::string, std::pair<int, size_t>>& tensor_offsets
    ) {
        return m_bufferManager->allocateBuffers(allocs_per_chunk, tensor_offsets);
    }

    // Bind one tensor to the slice [offset, offset + tensorDataSize) of the
    // fused buffer identified by fd/memPointer.
    bool mapFusedBufferOffset(
        Qnn_Tensor_t* tensor,
        size_t tensorDataSize,
        int32_t fd,
        uint32_t offset,
        uint64_t totalBufferSize,
        void* memPointer,
        Qnn_ContextHandle_t contextHandle
    ) {
        return m_bufferManager->mapFusedBufferOffset(
            tensor, tensorDataSize, fd, offset, totalBufferSize, memPointer, contextHandle
        );
    }

    // Bind every tensor of a graph listed in graph_allocs
    // (name -> (alloc index, offset, size)) to its fused-buffer slice.
    bool mapFusedBufferOffset(
        GraphInfo_t* graph_info,
        Qnn_ContextHandle_t context_handle,
        const std::map<std::string, std::tuple<int, size_t, size_t>>& graph_allocs
    );

    bool useSameMemory(Qnn_Tensor_t* dest, Qnn_Tensor_t* src) {
        return m_bufferManager->useSameMemory(dest, src);
    }

    bool useSameMemory(Qnn_Tensor_t* dest, Qnn_Tensor_t* src, int offset) {
        return m_bufferManager->useSameMemory(dest, src, offset);
    }

    bool useExternalMemory(Qnn_Tensor_t* dest, void* extMem) {
        return m_bufferManager->useExternalMemory(dest, extMem);
    }

    BufferAlloc getBufferAllocType() { return m_bufferAlloc; }

    std::unordered_set<void*>& getFreeTensorsPointerSet() { return m_freeTensorsPointerSet; }

  private:
    BufferAlloc m_bufferAlloc;
    QNN_INTERFACE_VER_TYPE* m_qnnInterface;
    std::unique_ptr<IBufferAlloc> m_bufferManager;
    // Addresses of tensors already passed through tearDownTensors(); they may
    // dangle -- compare only, never dereference.
    std::unordered_set<void*> m_freeTensorsPointerSet;

    // There seems to be a race condition in mapFusedBufferOffset because we are
    // calling it from multiple threads. Maybe memRegister/memDeRegister is not thread-safe
    // Until I figure this out, adding a temporary lock here. TODO: Fix and remove this!
    std::mutex _tmp_lock;

    bool deepCopyQnnTensorInfo(Qnn_Tensor_t* dest, Qnn_Tensor_t* src);
    bool setupTensors(
        Qnn_Tensor_t** tensors,
        std::unordered_map<std::string, void*>& tensorNameToTensorPointer,
        uint32_t tensorCount,
        TensorWrapper* tensorsInfo,
        std::unordered_map<std::string, size_t>& tensorsSize,
        Qnn_ContextHandle_t contextHandle,
        bool skipBufferAllocation = false
    );
};
|
Genie/Genie/src/qualla/engines/qnn-api/Log.hpp
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//==============================================================================
//
// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
// All Rights Reserved.
// Confidential and Proprietary - Qualcomm Technologies, Inc.
//
//==============================================================================

#pragma once

#include <stdio.h>

// FIXME: Use logger from qualla::Env

// Minimal stderr logging macros. `fmt` must be a string literal; it is joined
// to the level prefix by adjacent-literal concatenation.
// BUGFIX: the previous macros stringized the format argument (`#fmt`), which
// wrapped every message in an extra pair of escaped quotes in the output,
// e.g. `[ERROR] "message"` instead of `[ERROR] message`.
#define QNN_INFO(fmt, ...) fprintf(stderr, "[INFO] " fmt "\n", ##__VA_ARGS__)
#define QNN_ERROR(fmt, ...) fprintf(stderr, "[ERROR] " fmt "\n", ##__VA_ARGS__)
#define QNN_WARN(fmt, ...) fprintf(stderr, "[WARN] " fmt "\n", ##__VA_ARGS__)

// Debug logging is compiled out by default; flip the `#if 0` to re-enable.
#if 0
// #define NSP_LOG_LEVEL 2
#define QNN_DEBUG(fmt, ...) fprintf(stderr, "[DEBUG] " fmt "\n", ##__VA_ARGS__)
#else
#define QNN_DEBUG(fmt, ...)
#endif
|
Genie/Genie/src/qualla/engines/qnn-api/NetRunBackend.hpp
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//==============================================================================
|
| 2 |
+
//
|
| 3 |
+
// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
|
| 4 |
+
// All Rights Reserved.
|
| 5 |
+
// Confidential and Proprietary - Qualcomm Technologies, Inc.
|
| 6 |
+
//
|
| 7 |
+
//==============================================================================
|
| 8 |
+
|
| 9 |
+
#pragma once
|
| 10 |
+
|
| 11 |
+
#include <string>
|
| 12 |
+
|
| 13 |
+
#include "ICommandLineManager.hpp"
|
| 14 |
+
#include "IBackend.hpp"
|
| 15 |
+
|
| 16 |
+
// This is an implementation of IBackend interface within qnn-net-run.
|
| 17 |
+
// NetRunBackend provides a dummy implementation of IBackend as a concrete
|
| 18 |
+
// implementation is needed in case there is no backend extensions library
|
| 19 |
+
// supplied by the user.
|
| 20 |
+
// This is built as part of QnnNetRun library and is used in case of no
|
| 21 |
+
// user supplied backend extensions implementation.
|
| 22 |
+
// This is an implementation of IBackend interface within qnn-net-run.
// NetRunBackend provides a dummy implementation of IBackend as a concrete
// implementation is needed in case there is no backend extensions library
// supplied by the user.
// This is built as part of QnnNetRun library and is used in case of no
// user supplied backend extensions implementation.
// Every hook accepts its arguments, ignores them, and reports success.
class NetRunBackend final : public IBackend {
  public:
    NetRunBackend() {}

    virtual ~NetRunBackend() {}

    virtual bool setupLogging(QnnLog_Callback_t callback, QnnLog_Level_t maxLogLevel) override {
        ignore(callback);
        ignore(maxLogLevel);
        return true;
    }

    virtual bool initialize(void* backendLibHandle) override {
        ignore(backendLibHandle);
        return true;
    }

    virtual bool setPerfProfile(PerfProfile perfProfile) override {
        ignore(perfProfile);
        return true;
    }

    // No profiling configured by the dummy backend.
    virtual QnnProfile_Level_t getProfilingLevel() override { return g_profilingLevelNotSet; }

    virtual bool loadConfig(std::string configFile) override {
        ignore(configFile);
        return true;
    }

    virtual bool loadCommandLineArgs(std::shared_ptr<ICommandLineManager> clManager) override {
        ignore(clManager);
        return true;
    }

    // Before/after hooks around each backend lifecycle stage -- all no-ops.
    virtual bool beforeBackendInitialize(
        QnnBackend_Config_t*** customConfigs,
        uint32_t* configCount
    ) override {
        ignore(customConfigs);
        ignore(configCount);
        return true;
    }

    virtual bool afterBackendInitialize() override { return true; }

    virtual bool beforeContextCreate(QnnContext_Config_t*** customConfigs, uint32_t* configCount)
        override {
        ignore(customConfigs);
        ignore(configCount);
        return true;
    }

    virtual bool afterContextCreate() override { return true; }

    virtual bool beforeComposeGraphs(GraphConfigInfo_t*** customGraphConfigs, uint32_t* graphCount)
        override {
        ignore(customGraphConfigs);
        ignore(graphCount);
        return true;
    }

    virtual bool afterComposeGraphs() override { return true; }

#if QUALLA_QNN_API_VERSION >= 21700
    // Hook only available on newer QNN API versions (>= 2.17.0).
    virtual bool beforeGraphFinalizeUpdateConfig(
        const char* graphName,
        Qnn_GraphHandle_t graphHandle,
        QnnGraph_Config_t*** customConfigs,
        uint32_t* configCount
    ) override {
        ignore(graphName);
        ignore(graphHandle);
        ignore(customConfigs);
        ignore(configCount);
        return true;
    }
#endif

    virtual bool beforeGraphFinalize() override { return true; }

    virtual bool afterGraphFinalize() override { return true; }

    virtual bool beforeRegisterOpPackages() override { return true; }

    virtual bool afterRegisterOpPackages() override { return true; }

    virtual bool beforeExecute(
        const char* graphName,
        QnnGraph_Config_t*** customConfigs,
        uint32_t* configCount
    ) override {
        ignore(graphName);
        ignore(customConfigs);
        ignore(configCount);
        return true;
    }

    virtual bool afterExecute() override { return true; }

    virtual bool beforeContextFree() override { return true; }

    virtual bool afterContextFree() override { return true; }

    virtual bool beforeBackendTerminate() override { return true; }

    virtual bool afterBackendTerminate() override { return true; }

    virtual bool beforeCreateFromBinary(QnnContext_Config_t*** customConfigs, uint32_t* configCount)
        override {
        ignore(customConfigs);
        ignore(configCount);
        return true;
    }

    virtual bool afterCreateFromBinary() override { return true; }

#if QUALLA_QNN_API_VERSION >= 21700
    virtual bool beforeCreateContextsFromBinaryList(
        std::map<std::string, std::tuple<QnnContext_Config_t**, uint32_t>>*
            contextKeyToCustomConfigsMap,
        QnnContext_Config_t*** commonCustomConfigs,
        uint32_t* commonConfigCount
    ) override {
        ignore(contextKeyToCustomConfigsMap);
        ignore(commonCustomConfigs);
        ignore(commonConfigCount);
        return true;
    }

    virtual bool afterCreateContextsFromBinaryList() override { return true; }
#endif

    virtual bool beforeCreateDevice(QnnDevice_Config_t*** deviceConfigs, uint32_t* configCount)
        override {
        ignore(deviceConfigs);
        ignore(configCount);
        return true;
    }

    virtual bool afterCreateDevice() override { return true; }

    virtual bool beforeFreeDevice() override { return true; }

    virtual bool afterFreeDevice() override { return true; }

  private:
    // Utility function to ignore compiler warnings when a variable
    // is unused. Recommended by Herb Sutter in Sutter's Mill
    // instead of (void)variable.
    template <typename T>
    void ignore(const T&) {}
};
|
Genie/Genie/src/qualla/engines/qnn-api/QnnApi.cpp
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
Genie/Genie/src/qualla/engines/qnn-api/QnnApi.hpp
ADDED
|
@@ -0,0 +1,429 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//==============================================================================
|
| 2 |
+
//
|
| 3 |
+
// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
|
| 4 |
+
// All Rights Reserved.
|
| 5 |
+
// Confidential and Proprietary - Qualcomm Technologies, Inc.
|
| 6 |
+
//
|
| 7 |
+
//==============================================================================
|
| 8 |
+
|
| 9 |
+
#pragma once
|
| 10 |
+
|
| 11 |
+
#include "BackendExtensions.hpp"
|
| 12 |
+
#include "QnnConfig.hpp"
|
| 13 |
+
#include "QnnHtpPerfInfrastructure.h"
|
| 14 |
+
#include "QnnHtpDevice.h"
|
| 15 |
+
#include "qnn-utils.hpp"
|
| 16 |
+
#include "IOTensor.hpp"
|
| 17 |
+
|
| 18 |
+
#include <memory>
|
| 19 |
+
#include <mutex>
|
| 20 |
+
|
| 21 |
+
#define QNN_IO_TENSOR_DEBUG 0
|
| 22 |
+
|
| 23 |
+
// Strategy for maintaining the KV cache between decode steps.
// NOTE(review): semantics inferred from the enumerator names — POINTER_SHIFT
// presumably advances read pointers in place while SHIFT_CONCAT physically
// shifts/concatenates cache data; confirm against the KV manager implementation.
enum KVManagerMode { POINTER_SHIFT = 0x0, SHIFT_CONCAT = 0x1 };
|
| 24 |
+
|
| 25 |
+
// Pull the quantization-parameter record type into this header's scope.
using qualla::QnnUtils::QuantParam;

// Collapse the QNN API (major, minor, patch) triplet into one comparable
// integer (e.g. 2.17.0 -> 21700) so version-dependent declarations below can
// be gated with plain #if comparisons.
#define QUALLA_QNN_API_VERSION \
    (QNN_API_VERSION_MAJOR * 10000 + QNN_API_VERSION_MINOR * 100 + QNN_API_VERSION_PATCH)
|
| 29 |
+
|
| 30 |
+
// Element size in bytes for each supported QNN data type; used for
// buffer-size arithmetic when allocating/registering tensors.
// NOTE(review): declared `static` in a header, so every translation unit that
// includes this file gets its own mutable copy of the map; consider a C++17
// `inline` variable if the duplication ever matters.
static std::map<Qnn_DataType_t, size_t> g_qnnDataTypeToSize = {
    {QNN_DATATYPE_INT_8, 1},
    {QNN_DATATYPE_INT_16, 2},
    {QNN_DATATYPE_INT_32, 4},
    {QNN_DATATYPE_INT_64, 8},
    {QNN_DATATYPE_UINT_8, 1},
    {QNN_DATATYPE_UINT_16, 2},
    {QNN_DATATYPE_UINT_32, 4},
    {QNN_DATATYPE_UINT_64, 8},
    {QNN_DATATYPE_FLOAT_16, 2},
    {QNN_DATATYPE_FLOAT_32, 4},
    {QNN_DATATYPE_SFIXED_POINT_8, 1},
    {QNN_DATATYPE_SFIXED_POINT_16, 2},
    {QNN_DATATYPE_SFIXED_POINT_32, 4},
    {QNN_DATATYPE_UFIXED_POINT_8, 1},
    {QNN_DATATYPE_UFIXED_POINT_16, 2},
    {QNN_DATATYPE_UFIXED_POINT_32, 4},
    {QNN_DATATYPE_BOOL_8, 1},
};
|
| 49 |
+
|
| 50 |
+
class QnnApi {
|
| 51 |
+
private:
|
| 52 |
+
const uint32_t s_graphConfigsReserveCount = 16;
|
| 53 |
+
|
| 54 |
+
// Model vars
|
| 55 |
+
typedef Qnn_ErrorHandle_t (*QnnInterfaceGetProvidersFn_t)(
|
| 56 |
+
const QnnInterface_t*** providerList,
|
| 57 |
+
uint32_t* numProviders
|
| 58 |
+
);
|
| 59 |
+
typedef Qnn_ErrorHandle_t (*QnnSystemInterfaceGetProvidersFn_t)(
|
| 60 |
+
const QnnSystemInterface_t*** providerList,
|
| 61 |
+
uint32_t* numProviders
|
| 62 |
+
);
|
| 63 |
+
|
| 64 |
+
// Graph Related Function Handle Types
|
| 65 |
+
typedef ModelError_t (*ComposeGraphsFnHandleType_t)(
|
| 66 |
+
Qnn_BackendHandle_t,
|
| 67 |
+
QNN_INTERFACE_VER_TYPE,
|
| 68 |
+
Qnn_ContextHandle_t,
|
| 69 |
+
const GraphConfigInfo_t**,
|
| 70 |
+
const uint32_t,
|
| 71 |
+
GraphInfo_t***,
|
| 72 |
+
uint32_t*,
|
| 73 |
+
bool,
|
| 74 |
+
QnnLog_Callback_t,
|
| 75 |
+
QnnLog_Level_t
|
| 76 |
+
);
|
| 77 |
+
|
| 78 |
+
typedef ModelError_t (*GenAIComposeGraphsFnHandleType_t)(
|
| 79 |
+
Qnn_BackendHandle_t,
|
| 80 |
+
QNN_INTERFACE_VER_TYPE,
|
| 81 |
+
Qnn_ContextHandle_t,
|
| 82 |
+
const GraphConfigInfo_t**,
|
| 83 |
+
const uint32_t,
|
| 84 |
+
uint32_t* inputDim,
|
| 85 |
+
uint32_t inputRank,
|
| 86 |
+
uint32_t* outputDim,
|
| 87 |
+
uint32_t outputRank,
|
| 88 |
+
uint32_t* kvDim,
|
| 89 |
+
uint32_t kvRank,
|
| 90 |
+
Qnn_Param_t* params,
|
| 91 |
+
uint32_t numParam,
|
| 92 |
+
GraphInfo_t***,
|
| 93 |
+
uint32_t*,
|
| 94 |
+
bool,
|
| 95 |
+
QnnLog_Callback_t,
|
| 96 |
+
QnnLog_Level_t
|
| 97 |
+
);
|
| 98 |
+
|
| 99 |
+
typedef ModelError_t (*FreeGraphInfoFnHandleType_t)(GraphInfo_t***, uint32_t);
|
| 100 |
+
|
| 101 |
+
void* m_libModelHandle{nullptr};
|
| 102 |
+
void* m_backendHandle{nullptr};
|
| 103 |
+
void* m_backendLibraryHandle{nullptr};
|
| 104 |
+
|
| 105 |
+
QNN_INTERFACE_VER_TYPE m_qnnInterface{nullptr};
|
| 106 |
+
QNN_SYSTEM_INTERFACE_VER_TYPE m_qnnSystemInterface{nullptr};
|
| 107 |
+
std::unique_ptr<BackendExtensions> m_backendExtensions{nullptr};
|
| 108 |
+
ComposeGraphsFnHandleType_t m_composeGraphsFnHandle{nullptr};
|
| 109 |
+
GenAIComposeGraphsFnHandleType_t m_genaiComposeGraphsFnHandle{nullptr};
|
| 110 |
+
FreeGraphInfoFnHandleType_t m_freeGraphInfoFnHandle{nullptr};
|
| 111 |
+
uint32_t m_backendId{0};
|
| 112 |
+
Qnn_LogHandle_t m_logHandle{nullptr};
|
| 113 |
+
Qnn_DeviceHandle_t m_deviceHandle{nullptr};
|
| 114 |
+
|
| 115 |
+
Qnn_ProfileHandle_t m_profileBackendHandle{nullptr};
|
| 116 |
+
|
| 117 |
+
std::vector<Qnn_ContextHandle_t> m_contextVec;
|
| 118 |
+
std::unordered_map<GraphInfo*, Qnn_ContextHandle_t> m_contextMap;
|
| 119 |
+
uint32_t m_graphsCount{0};
|
| 120 |
+
int32_t graphCountPerContext{-1};
|
| 121 |
+
GraphInfo_t** m_graphsInfo;
|
| 122 |
+
std::unordered_map<std::string, uint32_t> m_graphNameToIndex;
|
| 123 |
+
std::unordered_map<std::string, GraphInfo*> m_graphNameToInfo;
|
| 124 |
+
std::unordered_map<std::string, uint32_t> m_graphNameToContextIdx;
|
| 125 |
+
std::unordered_map<uint32_t, Qnn_ContextHandle_t> m_contextIdtoHandle;
|
| 126 |
+
std::mutex m_updateCallBackMutex;
|
| 127 |
+
|
| 128 |
+
// Useful Structure for IO Esimtation
|
| 129 |
+
std::unordered_map<int,qualla::QnnUtils::TensorMap> m_graphtoIOMap; // stores {GraphId -> IOTensorMap}
|
| 130 |
+
typedef int CtxBitVector;
|
| 131 |
+
std::map<CtxBitVector, std::map<std::string, size_t>> m_contextAllocMap; // stores {Translated ContextId -> {Tensor name, size}}
|
| 132 |
+
std::map<std::string, std::pair<int, size_t>> m_tensorAllocInfo; // stores {Tensor name -> (fd of RPC buffer, offset)}
|
| 133 |
+
std::unordered_map<uint32_t, uint32_t> m_graphIdxToContextIdx; // stores {Graph Idx -> Context Idx}
|
| 134 |
+
std::unordered_map<std::string,std::shared_ptr<uint8_t>> m_adapterNameToBuffer;
|
| 135 |
+
|
| 136 |
+
uint32_t m_backendConfigCount{0};
|
| 137 |
+
QnnBackend_Config_t** m_backendConfigs{nullptr};
|
| 138 |
+
|
| 139 |
+
QnnHtpDevice_PerfInfrastructure_t* m_perfInfra{nullptr};
|
| 140 |
+
uint32_t m_powerConfigId = 1;
|
| 141 |
+
|
| 142 |
+
// Useful Structure for IO Esimtation
|
| 143 |
+
IOTensor* m_ioBufferMgr{nullptr};
|
| 144 |
+
int32_t m_ctxSize{-1};
|
| 145 |
+
int32_t m_kvDim{-1};
|
| 146 |
+
bool m_loraWeightEnabled{false};
|
| 147 |
+
bool m_lmHeadWeightInput{false};
|
| 148 |
+
KVManagerMode m_kvUpdateMethod{POINTER_SHIFT};
|
| 149 |
+
|
| 150 |
+
bool m_isLogInitialized{false};
|
| 151 |
+
bool m_isBackendInitialized{false};
|
| 152 |
+
bool m_isContextCreated{false};
|
| 153 |
+
|
| 154 |
+
// Variable to keep track of debug mode
|
| 155 |
+
bool m_DebugModeRequested;
|
| 156 |
+
bool m_debugQnn{false};
|
| 157 |
+
|
| 158 |
+
// Variable to indicate whether to mmap context bins or read them in memory
|
| 159 |
+
bool m_mmapContextBins;
|
| 160 |
+
bool m_isDeviceCreated = false;
|
| 161 |
+
|
| 162 |
+
std::vector<std::pair<uint8_t*, uint64_t>> m_contextBinBuffersToBeCleared;
|
| 163 |
+
|
| 164 |
+
void setDeviceStatus(bool status) { m_isDeviceCreated = status; }
|
| 165 |
+
bool getDeviceStatus() { return m_isDeviceCreated; }
|
| 166 |
+
bool getContextConfigs(
|
| 167 |
+
QnnContext_Config_t*** configs,
|
| 168 |
+
uint32_t& contextConfigCount,
|
| 169 |
+
Qnn_Priority_t contextPriority,
|
| 170 |
+
bool graphSwitching = false,
|
| 171 |
+
const std::vector<std::string>& execSelectGraphs = {},
|
| 172 |
+
bool loadSelectGraphs = false
|
| 173 |
+
);
|
| 174 |
+
bool mergeAllContextConfigs(
|
| 175 |
+
QnnContext_Config_t*** allCustomContextConfigs,
|
| 176 |
+
QnnContext_Config_t** customConfigs,
|
| 177 |
+
QnnContext_Config_t** contextConfigs,
|
| 178 |
+
uint32_t customConfigCount,
|
| 179 |
+
uint32_t contextConfigCount
|
| 180 |
+
);
|
| 181 |
+
bool freeContextConfigs(QnnContext_Config_t** contextConfigs, uint32_t contextConfigCount);
|
| 182 |
+
bool setGraphConfigsBeforeExecute(
|
| 183 |
+
Qnn_GraphHandle_t graphHandle,
|
| 184 |
+
QnnGraph_Config_t** graphConfigs,
|
| 185 |
+
uint32_t configCount
|
| 186 |
+
);
|
| 187 |
+
|
| 188 |
+
bool getQnnInterface(std::string backendPath);
|
| 189 |
+
bool getQnnSystemInterface(std::string systemLibraryPath);
|
| 190 |
+
bool loadModel(std::string model_path);
|
| 191 |
+
bool initializeLogging(const QnnLog_Level_t& logLevel, bool debug_qnn);
|
| 192 |
+
void terminateLog();
|
| 193 |
+
bool initializeBackendExtensions(
|
| 194 |
+
BackendExtensionsConfigs backendExtensionsConfig,
|
| 195 |
+
PerfProfile parsedPerfProfile,
|
| 196 |
+
bool debug_qnn
|
| 197 |
+
);
|
| 198 |
+
bool initializeBackend();
|
| 199 |
+
bool terminateBackend();
|
| 200 |
+
bool createDevice();
|
| 201 |
+
bool freeDevice();
|
| 202 |
+
bool createContext(ContextConfigs contextConfig);
|
| 203 |
+
bool freeContext();
|
| 204 |
+
bool composeGraphs(std::vector<GraphConfigs> graphConfigs);
|
| 205 |
+
bool composeGraphs(
|
| 206 |
+
std::vector<GraphConfigs> graphConfigs,
|
| 207 |
+
uint32_t* inputDim,
|
| 208 |
+
uint32_t inputRank,
|
| 209 |
+
uint32_t* outputDim,
|
| 210 |
+
uint32_t outputRank,
|
| 211 |
+
uint32_t* kvDim,
|
| 212 |
+
uint32_t kvRank,
|
| 213 |
+
Qnn_Param_t* params,
|
| 214 |
+
uint32_t numParams
|
| 215 |
+
);
|
| 216 |
+
bool mapAndGetContextBinaryInfo(
|
| 217 |
+
const bool use_mmap,
|
| 218 |
+
std::shared_ptr<uint8_t>& buffer,
|
| 219 |
+
const std::string binaryPath,
|
| 220 |
+
const uint64_t bufferSize,
|
| 221 |
+
const size_t contextIdx,
|
| 222 |
+
const bool graphSwitching,
|
| 223 |
+
QnnSystemContext_Handle_t sysCtxHandle,
|
| 224 |
+
const QnnSystemContext_BinaryInfo_t** binaryInfo
|
| 225 |
+
);
|
| 226 |
+
|
| 227 |
+
bool parseIOTensorsAndAccumulate();
|
| 228 |
+
bool registerTensorsWithBackend(uint32_t& graphIdx);
|
| 229 |
+
|
| 230 |
+
bool finalizeGraphs();
|
| 231 |
+
bool initializePerformance();
|
| 232 |
+
bool destroyPerformance();
|
| 233 |
+
bool boostPerformance();
|
| 234 |
+
bool resetPerformance();
|
| 235 |
+
bool checkCapabilityOfCreateAsync(bool& propRet);
|
| 236 |
+
|
| 237 |
+
bool initProfiling();
|
| 238 |
+
bool extractBackendProfilingInfo(
|
| 239 |
+
Qnn_ProfileHandle_t profileHandle,
|
| 240 |
+
std::map<std::string, std::pair<double, uint16_t>>& timeLogs,
|
| 241 |
+
std::string graphName
|
| 242 |
+
);
|
| 243 |
+
bool extractProfilingSubEvents(
|
| 244 |
+
QnnProfile_EventId_t profileEventId,
|
| 245 |
+
std::map<std::string, std::pair<double, uint16_t>>& timeLogs,
|
| 246 |
+
std::string graphName
|
| 247 |
+
);
|
| 248 |
+
bool extractProfilingEvent(
|
| 249 |
+
QnnProfile_EventId_t profileEventId,
|
| 250 |
+
std::map<std::string, std::pair<double, uint16_t>>& timeLogs,
|
| 251 |
+
std::string graphName
|
| 252 |
+
);
|
| 253 |
+
bool extractBackendProfilingInfo(Qnn_ProfileHandle_t profileHandle);
|
| 254 |
+
bool extractProfilingSubEvents(QnnProfile_EventId_t profileEventId);
|
| 255 |
+
bool extractProfilingEvent(QnnProfile_EventId_t profileEventId);
|
| 256 |
+
|
| 257 |
+
Qnn_ContextHandle_t getContextWithId(uint32_t contextId) {
|
| 258 |
+
return m_contextIdtoHandle[contextId];
|
| 259 |
+
}
|
| 260 |
+
|
| 261 |
+
public:
|
| 262 |
+
QnnApi() {};
|
| 263 |
+
~QnnApi();
|
| 264 |
+
|
| 265 |
+
bool freeGraphs();
|
| 266 |
+
static QnnApi& getInstance();
|
| 267 |
+
#if QUALLA_QNN_API_VERSION >= 21700
|
| 268 |
+
static void contextNotifyFn(
|
| 269 |
+
Qnn_ContextHandle_t context,
|
| 270 |
+
Qnn_GraphHandle_t graph,
|
| 271 |
+
const char* graph_name,
|
| 272 |
+
QnnContext_createFromBinaryAsyncNotifyType_t completeType,
|
| 273 |
+
void* notifyParam,
|
| 274 |
+
Qnn_ErrorHandle_t status
|
| 275 |
+
);
|
| 276 |
+
#endif
|
| 277 |
+
bool createFromBinary(
|
| 278 |
+
std::vector<std::string> cachedBinariesPathVec,
|
| 279 |
+
ContextConfigs contextConfig,
|
| 280 |
+
int64_t spill_fill_buffer_size = 0,
|
| 281 |
+
uint64_t mmap_budget = 0,
|
| 282 |
+
bool graphSwitching = false,
|
| 283 |
+
const std::vector<std::string>& execSelectGraphs = {},
|
| 284 |
+
bool loadSelectGraphs = false
|
| 285 |
+
);
|
| 286 |
+
#if QUALLA_QNN_API_VERSION >= 21700
|
| 287 |
+
bool createFromBinaryListAsync(
|
| 288 |
+
std::vector<std::string> cachedBinariesPathVec,
|
| 289 |
+
ContextConfigs contextConfig,
|
| 290 |
+
int64_t spill_fill_buffer_size = 0,
|
| 291 |
+
uint64_t mmap_budget = 0,
|
| 292 |
+
bool graphSwitching = false,
|
| 293 |
+
const std::vector<std::string>& execSelectGraphs = {},
|
| 294 |
+
bool loadSelectGraphs = false
|
| 295 |
+
);
|
| 296 |
+
#endif
|
| 297 |
+
bool initialize(
|
| 298 |
+
std::string backendPath,
|
| 299 |
+
std::vector<std::string> modelPathOrCachedBinaryPathVec,
|
| 300 |
+
BackendExtensionsConfigs backendExtensionsConfig,
|
| 301 |
+
PerfProfile parsedPerfProfile = PerfProfile::BURST,
|
| 302 |
+
ContextConfigs contextConfig = ContextConfigs(),
|
| 303 |
+
std::vector<GraphConfigs> graphConfigs = {},
|
| 304 |
+
bool loadFromCachedBinary = false,
|
| 305 |
+
std::string systemLibraryPath = "",
|
| 306 |
+
bool debugModeRequested = false,
|
| 307 |
+
int64_t spill_fill_buffer_size = 0,
|
| 308 |
+
bool mmapContextBins = false,
|
| 309 |
+
bool asyncInit = true,
|
| 310 |
+
uint64_t mmap_budget = 0,
|
| 311 |
+
bool debug_qnn = false,
|
| 312 |
+
bool graphSwitching = false,
|
| 313 |
+
const std::vector<std::string>& execSelectGraphs = {},
|
| 314 |
+
bool loadSelectGraphs = false
|
| 315 |
+
);
|
| 316 |
+
|
| 317 |
+
bool registerOpPackage(std::string opPackagePath);
|
| 318 |
+
|
| 319 |
+
void setIOTensorBufferMgr(IOTensor* ioBufferMgr){
|
| 320 |
+
m_ioBufferMgr = ioBufferMgr;
|
| 321 |
+
}
|
| 322 |
+
|
| 323 |
+
void setKVDim(int32_t kvDim){
|
| 324 |
+
m_kvDim = kvDim;
|
| 325 |
+
}
|
| 326 |
+
|
| 327 |
+
void setContextSize(int32_t ctxSize){
|
| 328 |
+
m_ctxSize = ctxSize;
|
| 329 |
+
}
|
| 330 |
+
|
| 331 |
+
void setKVUpdateMethod(KVManagerMode kvUpdateMethod){
|
| 332 |
+
m_kvUpdateMethod = kvUpdateMethod ;
|
| 333 |
+
}
|
| 334 |
+
|
| 335 |
+
std::map<std::string, std::pair<int, size_t>>* getTensorAllocInfo(){
|
| 336 |
+
return &m_tensorAllocInfo;
|
| 337 |
+
}
|
| 338 |
+
|
| 339 |
+
bool getLmHeadWeightInputEnabled(){
|
| 340 |
+
return m_lmHeadWeightInput;
|
| 341 |
+
}
|
| 342 |
+
|
| 343 |
+
bool getLoraWeightEnabled(){
|
| 344 |
+
return m_loraWeightEnabled;
|
| 345 |
+
}
|
| 346 |
+
// Initalize with OpPackage
|
| 347 |
+
bool initialize(
|
| 348 |
+
std::string backendPath,
|
| 349 |
+
std::string modelPath,
|
| 350 |
+
std::string opPackage,
|
| 351 |
+
ContextConfigs contextConfig,
|
| 352 |
+
std::vector<GraphConfigs> graphConfigs,
|
| 353 |
+
uint32_t* inputDim,
|
| 354 |
+
uint32_t inputRank,
|
| 355 |
+
uint32_t* outputDim,
|
| 356 |
+
uint32_t outputRank,
|
| 357 |
+
uint32_t* kvDim,
|
| 358 |
+
uint32_t kvRank,
|
| 359 |
+
Qnn_Param_t* params,
|
| 360 |
+
uint32_t numParams,
|
| 361 |
+
bool debugModeRequested
|
| 362 |
+
);
|
| 363 |
+
|
| 364 |
+
bool graphExecute(
|
| 365 |
+
Qnn_Tensor_t* input,
|
| 366 |
+
Qnn_Tensor_t* output,
|
| 367 |
+
std::string graphName,
|
| 368 |
+
std::map<std::string, std::pair<double, uint16_t>>& timeLogs
|
| 369 |
+
);
|
| 370 |
+
|
| 371 |
+
bool applyBinarySection(uint32_t binIndex, std::string binSectionPath,bool useMmap,bool graphSwitch);
|
| 372 |
+
|
| 373 |
+
QNN_INTERFACE_VER_TYPE* getQnnInterfaceVer() { return &m_qnnInterface; };
|
| 374 |
+
GraphInfo_t**& getGraphsInfo() { return m_graphsInfo; };
|
| 375 |
+
uint32_t getGraphsCount() { return m_graphsCount; };
|
| 376 |
+
int32_t getGraphCountPerContext() { return graphCountPerContext; }
|
| 377 |
+
std::vector<Qnn_ContextHandle_t>& getContexts() { return m_contextVec; };
|
| 378 |
+
const Qnn_ContextHandle_t getContexts(GraphInfo_t* const graph) {
|
| 379 |
+
return m_contextMap.at(graph);
|
| 380 |
+
};
|
| 381 |
+
|
| 382 |
+
void updateContext(Qnn_ContextHandle_t context, uint32_t contextId) {
|
| 383 |
+
std::lock_guard<std::mutex> lock(m_updateCallBackMutex);
|
| 384 |
+
m_contextVec.push_back(context);
|
| 385 |
+
m_contextIdtoHandle[contextId] = context;
|
| 386 |
+
}
|
| 387 |
+
|
| 388 |
+
void updateQnnApiGraphsandContextsInfo(
|
| 389 |
+
std::string graphName,
|
| 390 |
+
Qnn_GraphHandle_t graph,
|
| 391 |
+
uint32_t contextId
|
| 392 |
+
) {
|
| 393 |
+
// set graph handle to GraphInfo
|
| 394 |
+
std::lock_guard<std::mutex> lock(m_updateCallBackMutex);
|
| 395 |
+
m_graphNameToInfo[graphName]->graph = graph;
|
| 396 |
+
m_graphNameToContextIdx[graphName] = contextId;
|
| 397 |
+
m_graphsCount++;
|
| 398 |
+
}
|
| 399 |
+
|
| 400 |
+
static inline size_t getDataTypeSize(const Qnn_DataType_t& datatype) {
|
| 401 |
+
return g_qnnDataTypeToSize[datatype];
|
| 402 |
+
}
|
| 403 |
+
static inline std::string getTensorName(const TensorWrapper& tensorWrapper) {
|
| 404 |
+
return GET_TENSOR_WRAPPER_NAME(tensorWrapper);
|
| 405 |
+
}
|
| 406 |
+
static bool getTensorQuantParams(
|
| 407 |
+
const Qnn_Tensor_t* tensor,
|
| 408 |
+
std::vector<QuantParam>& quantParamsVec
|
| 409 |
+
);
|
| 410 |
+
static bool getTensorShape(std::vector<size_t>& tensorDims, const TensorWrapper& tensorWrapper);
|
| 411 |
+
static inline Qnn_DataType_t getTensorDtype(const Qnn_Tensor_t* tensor) {
|
| 412 |
+
return QNN_TENSOR_GET_DATA_TYPE(tensor);
|
| 413 |
+
}
|
| 414 |
+
|
| 415 |
+
bool getTensorNameAndShape(
|
| 416 |
+
std::string& tensorName,
|
| 417 |
+
std::vector<size_t>& tensorDims,
|
| 418 |
+
TensorWrapper& tensorWrapper
|
| 419 |
+
);
|
| 420 |
+
static void qnnLogCallback(
|
| 421 |
+
const char* fmt,
|
| 422 |
+
QnnLog_Level_t level,
|
| 423 |
+
uint64_t timestamp,
|
| 424 |
+
va_list args
|
| 425 |
+
);
|
| 426 |
+
bool updateIOEncodings(std::shared_ptr<uint8_t>& buffer,
|
| 427 |
+
uint64_t bufferSize,
|
| 428 |
+
uint32_t graphIndex);
|
| 429 |
+
};
|
Genie/Genie/src/qualla/engines/qnn-api/QnnApiUtils.cpp
ADDED
|
@@ -0,0 +1,636 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//==============================================================================
|
| 2 |
+
//
|
| 3 |
+
// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
|
| 4 |
+
// All Rights Reserved.
|
| 5 |
+
// Confidential and Proprietary - Qualcomm Technologies, Inc.
|
| 6 |
+
//
|
| 7 |
+
//==============================================================================
|
| 8 |
+
|
| 9 |
+
#include "QnnApiUtils.hpp"
|
| 10 |
+
#include "QnnTypeMacros.hpp"
|
| 11 |
+
|
| 12 |
+
#include <algorithm>
|
| 13 |
+
#include <cstring>
|
| 14 |
+
#include <fstream>
|
| 15 |
+
#include <iostream>
|
| 16 |
+
#include <sstream>
|
| 17 |
+
#include <string>
|
| 18 |
+
#include <tuple>
|
| 19 |
+
|
| 20 |
+
#include <fcntl.h>
|
| 21 |
+
#include <errno.h>
|
| 22 |
+
|
| 23 |
+
#ifdef _WIN32
|
| 24 |
+
#include <windows.h>
|
| 25 |
+
#define __open ::_open
|
| 26 |
+
#define __strdup ::_strdup
|
| 27 |
+
#else
|
| 28 |
+
#include <unistd.h>
|
| 29 |
+
#include <sys/mman.h>
|
| 30 |
+
#define __open ::open
|
| 31 |
+
#define __strdup ::strdup
|
| 32 |
+
#endif
|
| 33 |
+
|
| 34 |
+
// Release all heap allocations owned by one TensorWrapper: the strdup'd name,
// the per-axis scale/offset array (if any), and the dimensions array.
// Always returns true (kept for call-site symmetry with the other free_*).
bool freeQnnTensorWrapper(TensorWrapper& tensorWrapper) {
    // free all pointer allocations in struct
    if (nullptr != GET_TENSOR_WRAPPER_NAME(tensorWrapper)) {
        free((void*)GET_TENSOR_WRAPPER_NAME(tensorWrapper));
    }

    Qnn_Tensor_t& tensor = GET_TENSOR_WRAPPER_TENSOR(tensorWrapper);

    // Fixed leak: copyTensorsInfo() mallocs axisScaleOffsetEncoding.scaleOffset
    // for axis-quantized tensors, but it was never released here.
    Qnn_QuantizeParams_t qParams = QNN_TENSOR_GET_QUANT_PARAMS(&tensor);
    if (qParams.quantizationEncoding == QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET) {
        free(qParams.axisScaleOffsetEncoding.scaleOffset);
    }

    free(QNN_TENSOR_GET_DIMENSIONS(tensor));  // free(nullptr) is a no-op
    return true;
}
|
| 44 |
+
|
| 45 |
+
// Free an array of numTensors TensorWrappers, including each wrapper's own
// allocations, then the array itself.
// Fixed: a null array with numTensors > 0 dereferenced nullptr; now it is
// treated as "nothing to free".
bool freeQnnTensorWrappers(TensorWrapper*& tensorWrappers, uint32_t numTensors) {
    if (nullptr == tensorWrappers) {
        return true;
    }
    // free all pointer allocations in struct
    for (size_t i = 0; i < numTensors; i++) {
        freeQnnTensorWrapper(tensorWrappers[i]);
    }
    free(tensorWrappers);

    return true;
}
|
| 54 |
+
|
| 55 |
+
// Tear down the full graphs table produced by composeGraphs/createFromBinary:
// per-graph names and tensor wrapper arrays, then the struct storage (reached
// through the first entry) and the pointer table itself.
// Returns false when there is nothing to free.
bool freeGraphsInfo(GraphInfoPtr_t** graphsInfo, uint32_t numGraphs) {
    if (nullptr == graphsInfo || nullptr == *graphsInfo) {
        return false;
    }

    for (uint32_t gIdx = 0; gIdx < numGraphs; gIdx++) {
        GraphInfoPtr_t graph = (*graphsInfo)[gIdx];
        if (nullptr == graph) {
            continue;
        }
        free(graph->graphName);
        freeQnnTensorWrappers(graph->inputTensors, graph->numInputTensors);
        freeQnnTensorWrappers(graph->outputTensors, graph->numOutputTensors);
    }

    // Struct storage is released via the first entry, then the outer table;
    // null out the caller's pointer so it cannot be double-freed.
    free(**graphsInfo);
    free(*graphsInfo);
    *graphsInfo = nullptr;

    return true;
}
|
| 76 |
+
|
| 77 |
+
// Free a single standalone GraphInfo: its name, both tensor wrapper arrays,
// and finally the struct itself. Returns false when given nullptr.
bool freeGraphInfo(GraphInfo_t* graphInfo) {
    if (!graphInfo) {
        return false;
    }
    free(graphInfo->graphName);  // free(nullptr) is a no-op, so no guard needed
    freeQnnTensorWrappers(graphInfo->inputTensors, graphInfo->numInputTensors);
    freeQnnTensorWrappers(graphInfo->outputTensors, graphInfo->numOutputTensors);
    free(graphInfo);
    return true;
}
|
| 89 |
+
|
| 90 |
+
// Refresh the metadata of tensorsCount already-allocated TensorWrappers in
// place from the source tensor descriptions: id, type, data format/type,
// quantization parameters, rank, and (into the existing dimensions buffer)
// the dimension values. Names and dimension buffers are NOT (re)allocated
// here — that is copyTensorsInfo()'s job. Always returns true.
// NOTE(review): for axis-quantized tensors a fresh scaleOffset array is
// malloc'd and installed without freeing any previously installed one —
// potential leak if a wrapper is updated twice; confirm intended usage.
bool updateTensorInfo(const Qnn_Tensor_t* tensorsInfoSrc,
                 TensorWrapper* tensorWrappers,
                 uint32_t tensorsCount
){
    for (size_t tIdx = 0; tIdx < tensorsCount; tIdx++) {
        QNN_DEBUG("Extracting tensorInfo for tensor Idx: %d", (int)tIdx);
        Qnn_Tensor_t& tensor = GET_TENSOR_WRAPPER_TENSOR(tensorWrappers[tIdx]);

        // Scalar metadata is copied field by field through the accessor macros.
        QNN_TENSOR_SET_ID(tensor, QNN_TENSOR_GET_ID(&tensorsInfoSrc[tIdx]));
        QNN_TENSOR_SET_TYPE(tensor, QNN_TENSOR_GET_TYPE(&tensorsInfoSrc[tIdx]));
        QNN_TENSOR_SET_DATA_FORMAT(tensor, QNN_TENSOR_GET_DATA_FORMAT(&tensorsInfoSrc[tIdx]));
        QNN_TENSOR_SET_DATA_TYPE(tensor, QNN_TENSOR_GET_DATA_TYPE(&tensorsInfoSrc[tIdx]));

        // Rebuild quantization params; encodings other than (axis-)scale-offset
        // are left as UNDEFINED.
        Qnn_QuantizeParams_t qParams = QNN_QUANTIZE_PARAMS_INIT;
        qParams.encodingDefinition =
            QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx]).encodingDefinition;
        qParams.quantizationEncoding = QNN_QUANTIZATION_ENCODING_UNDEFINED;
        if (QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx]).quantizationEncoding ==
            QNN_QUANTIZATION_ENCODING_SCALE_OFFSET) {
            // Per-tensor scale/offset: plain struct copy suffices.
            qParams.quantizationEncoding =
                QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx]).quantizationEncoding;
            qParams.scaleOffsetEncoding =
                QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx]).scaleOffsetEncoding;
        } else if (QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx]).quantizationEncoding ==
                   QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET) {
            // Per-axis scale/offset: deep-copy the scaleOffset array, since the
            // source buffer is not owned by this wrapper.
            qParams.quantizationEncoding =
                QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx]).quantizationEncoding;
            qParams.axisScaleOffsetEncoding.axis =
                QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx])
                    .axisScaleOffsetEncoding.axis;
            qParams.axisScaleOffsetEncoding.numScaleOffsets =
                QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx])
                    .axisScaleOffsetEncoding.numScaleOffsets;
            if (QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx])
                    .axisScaleOffsetEncoding.numScaleOffsets > 0) {
                qParams.axisScaleOffsetEncoding.scaleOffset = (Qnn_ScaleOffset_t*)malloc(
                    QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx])
                        .axisScaleOffsetEncoding.numScaleOffsets *
                    sizeof(Qnn_ScaleOffset_t)
                );
                // On malloc failure the copy is silently skipped (scaleOffset
                // stays null while numScaleOffsets is nonzero).
                if (qParams.axisScaleOffsetEncoding.scaleOffset) {
                    for (size_t idx = 0;
                         idx < QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx])
                                   .axisScaleOffsetEncoding.numScaleOffsets;
                         idx++) {
                        qParams.axisScaleOffsetEncoding.scaleOffset[idx].scale =
                            QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx])
                                .axisScaleOffsetEncoding.scaleOffset[idx]
                                .scale;
                        qParams.axisScaleOffsetEncoding.scaleOffset[idx].offset =
                            QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx])
                                .axisScaleOffsetEncoding.scaleOffset[idx]
                                .offset;
                    }
                }
            }
        }
        QNN_TENSOR_SET_QUANT_PARAMS(tensor, qParams);
        QNN_TENSOR_SET_RANK(tensor, QNN_TENSOR_GET_RANK(&tensorsInfoSrc[tIdx]));
        // Copy dimension values only into an existing destination buffer;
        // updateTensorInfo never allocates dimensions.
        if (QNN_TENSOR_GET_RANK(tensorsInfoSrc[tIdx]) > 0) {
            if (QNN_TENSOR_GET_DIMENSIONS(tensor)) {
                memcpy(QNN_TENSOR_GET_DIMENSIONS(tensor),
                       QNN_TENSOR_GET_DIMENSIONS(&tensorsInfoSrc[tIdx]),
                       QNN_TENSOR_GET_RANK(&tensorsInfoSrc[tIdx]) * sizeof(uint32_t));
            }
        }
    }
    return true;
}
|
| 158 |
+
|
| 159 |
+
// Allocate a fresh array of tensorsCount TensorWrappers (returned through the
// out-parameter tensorWrappers) and deep-copy each source tensor into it:
// strdup'd name, scalar metadata, quantization parameters (with a malloc'd
// per-axis scale/offset array where applicable), and a malloc'd dimensions
// array. Returns false only when the wrapper array itself cannot be allocated.
// Ownership of all allocations passes to the caller (see freeQnnTensorWrappers).
// NOTE(review): inner malloc failures (name/dims/scaleOffset) are tolerated
// silently — returnStatus is never set false after the initial calloc check.
bool copyTensorsInfo(
    const Qnn_Tensor_t* tensorsInfoSrc,
    TensorWrapper*& tensorWrappers,
    uint32_t tensorsCount
) {

    auto returnStatus = true;
    // calloc zero-fills, so untouched fields of each wrapper start out null/0.
    tensorWrappers = (TensorWrapper*)calloc(tensorsCount, sizeof(TensorWrapper));
    if (nullptr == tensorWrappers) {
        QNN_ERROR("Failed to allocate memory for tensorWrappers.");
        return false;
    }
    if (returnStatus) {
        for (size_t tIdx = 0; tIdx < tensorsCount; tIdx++) {
            // QNN_DEBUG("Extracting tensorInfo for tensor Idx: %d", (int)tIdx);
            Qnn_Tensor_t& tensor = GET_TENSOR_WRAPPER_TENSOR(tensorWrappers[tIdx]);
            tensor = QNN_TENSOR_INIT;

            // Own a private copy of the name (freed by freeQnnTensorWrapper).
            const char* tensorName = QNN_TENSOR_GET_NAME(&tensorsInfoSrc[tIdx]);
            if (!tensorName) {
                QNN_TENSOR_SET_NAME(tensor, nullptr);
            } else {
                QNN_TENSOR_SET_NAME(tensor, __strdup(tensorName));
            }

            QNN_TENSOR_SET_ID(tensor, QNN_TENSOR_GET_ID(&tensorsInfoSrc[tIdx]));
            QNN_TENSOR_SET_TYPE(tensor, QNN_TENSOR_GET_TYPE(&tensorsInfoSrc[tIdx]));
            QNN_TENSOR_SET_DATA_FORMAT(tensor, QNN_TENSOR_GET_DATA_FORMAT(&tensorsInfoSrc[tIdx]));
            QNN_TENSOR_SET_DATA_TYPE(tensor, QNN_TENSOR_GET_DATA_TYPE(&tensorsInfoSrc[tIdx]));

            // Rebuild quantization params; only (axis-)scale-offset encodings
            // are propagated, everything else stays UNDEFINED.
            Qnn_QuantizeParams_t qParams = QNN_QUANTIZE_PARAMS_INIT;
            qParams.encodingDefinition =
                QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx]).encodingDefinition;
            qParams.quantizationEncoding = QNN_QUANTIZATION_ENCODING_UNDEFINED;
            if (QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx]).quantizationEncoding ==
                QNN_QUANTIZATION_ENCODING_SCALE_OFFSET) {
                qParams.quantizationEncoding =
                    QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx]).quantizationEncoding;
                qParams.scaleOffsetEncoding =
                    QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx]).scaleOffsetEncoding;
            } else if (QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx]).quantizationEncoding ==
                       QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET) {
                qParams.quantizationEncoding =
                    QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx]).quantizationEncoding;
                qParams.axisScaleOffsetEncoding.axis =
                    QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx])
                        .axisScaleOffsetEncoding.axis;
                qParams.axisScaleOffsetEncoding.numScaleOffsets =
                    QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx])
                        .axisScaleOffsetEncoding.numScaleOffsets;
                if (QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx])
                        .axisScaleOffsetEncoding.numScaleOffsets > 0) {
                    // Deep-copy the per-axis scale/offset array; this is the
                    // allocation freeQnnTensorWrapper must release.
                    qParams.axisScaleOffsetEncoding.scaleOffset = (Qnn_ScaleOffset_t*)malloc(
                        QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx])
                            .axisScaleOffsetEncoding.numScaleOffsets *
                        sizeof(Qnn_ScaleOffset_t)
                    );
                    if (qParams.axisScaleOffsetEncoding.scaleOffset) {
                        for (size_t idx = 0;
                             idx < QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx])
                                       .axisScaleOffsetEncoding.numScaleOffsets;
                             idx++) {
                            qParams.axisScaleOffsetEncoding.scaleOffset[idx].scale =
                                QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx])
                                    .axisScaleOffsetEncoding.scaleOffset[idx]
                                    .scale;
                            qParams.axisScaleOffsetEncoding.scaleOffset[idx].offset =
                                QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx])
                                    .axisScaleOffsetEncoding.scaleOffset[idx]
                                    .offset;
                        }
                    }
                }
            }
            QNN_TENSOR_SET_QUANT_PARAMS(tensor, qParams);
            QNN_TENSOR_SET_RANK(tensor, QNN_TENSOR_GET_RANK(&tensorsInfoSrc[tIdx]));
            QNN_TENSOR_SET_DIMENSIONS(tensor, nullptr);
            if (QNN_TENSOR_GET_RANK(tensorsInfoSrc[tIdx]) > 0) {
                // Allocate and fill the owned dimensions array; on malloc
                // failure the dims silently remain null (rank still set).
                QNN_TENSOR_SET_DIMENSIONS(
                    tensor,
                    (uint32_t*)malloc(
                        QNN_TENSOR_GET_RANK(&tensorsInfoSrc[tIdx]) * sizeof(uint32_t)
                    )
                );
                if (QNN_TENSOR_GET_DIMENSIONS(tensor)) {
                    memcpy(QNN_TENSOR_GET_DIMENSIONS(tensor),
                           QNN_TENSOR_GET_DIMENSIONS(&tensorsInfoSrc[tIdx]),
                           QNN_TENSOR_GET_RANK(&tensorsInfoSrc[tIdx]) * sizeof(uint32_t));
                }
            }
        }
    }

    return returnStatus;
}
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
bool updateGraphInfoV1(const QnnSystemContext_GraphInfoV1_t* graphInfoSrc,
|
| 256 |
+
GraphInfo_t* graphInfoDst
|
| 257 |
+
){
|
| 258 |
+
if (graphInfoSrc->graphInputs) {
|
| 259 |
+
if (!updateTensorInfo(
|
| 260 |
+
graphInfoSrc->graphInputs,
|
| 261 |
+
graphInfoDst->inputTensors,
|
| 262 |
+
graphInfoSrc->numGraphInputs
|
| 263 |
+
)) {
|
| 264 |
+
return false;
|
| 265 |
+
}
|
| 266 |
+
}
|
| 267 |
+
if (graphInfoSrc->graphOutputs) {
|
| 268 |
+
if (!updateTensorInfo(
|
| 269 |
+
graphInfoSrc->graphOutputs,
|
| 270 |
+
graphInfoDst->outputTensors,
|
| 271 |
+
graphInfoSrc->numGraphOutputs
|
| 272 |
+
)) {
|
| 273 |
+
return false;
|
| 274 |
+
}
|
| 275 |
+
}
|
| 276 |
+
return true;
|
| 277 |
+
}
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
bool updateGraphInfoV3(const QnnSystemContext_GraphInfoV3_t* graphInfoSrc,
|
| 281 |
+
GraphInfo_t* graphInfoDst
|
| 282 |
+
){
|
| 283 |
+
if (graphInfoSrc->graphInputs) {
|
| 284 |
+
if (!updateTensorInfo(
|
| 285 |
+
graphInfoSrc->graphInputs,
|
| 286 |
+
graphInfoDst->inputTensors,
|
| 287 |
+
graphInfoSrc->numGraphInputs
|
| 288 |
+
)) {
|
| 289 |
+
return false;
|
| 290 |
+
}
|
| 291 |
+
}
|
| 292 |
+
if (graphInfoSrc->graphOutputs) {
|
| 293 |
+
if (!updateTensorInfo(
|
| 294 |
+
graphInfoSrc->graphOutputs,
|
| 295 |
+
graphInfoDst->outputTensors,
|
| 296 |
+
graphInfoSrc->numGraphOutputs
|
| 297 |
+
)) {
|
| 298 |
+
return false;
|
| 299 |
+
}
|
| 300 |
+
}
|
| 301 |
+
return true;
|
| 302 |
+
}
|
| 303 |
+
|
| 304 |
+
bool copyGraphsInfoV1(
|
| 305 |
+
const QnnSystemContext_GraphInfoV1_t* graphInfoSrc,
|
| 306 |
+
GraphInfo_t* graphInfoDst
|
| 307 |
+
) {
|
| 308 |
+
graphInfoDst->graphName = nullptr;
|
| 309 |
+
if (graphInfoSrc->graphName) {
|
| 310 |
+
graphInfoDst->graphName = __strdup(graphInfoSrc->graphName);
|
| 311 |
+
}
|
| 312 |
+
graphInfoDst->inputTensors = nullptr;
|
| 313 |
+
graphInfoDst->numInputTensors = 0;
|
| 314 |
+
if (graphInfoSrc->graphInputs) {
|
| 315 |
+
if (!copyTensorsInfo(
|
| 316 |
+
graphInfoSrc->graphInputs,
|
| 317 |
+
graphInfoDst->inputTensors,
|
| 318 |
+
graphInfoSrc->numGraphInputs
|
| 319 |
+
)) {
|
| 320 |
+
return false;
|
| 321 |
+
}
|
| 322 |
+
graphInfoDst->numInputTensors = graphInfoSrc->numGraphInputs;
|
| 323 |
+
}
|
| 324 |
+
graphInfoDst->outputTensors = nullptr;
|
| 325 |
+
graphInfoDst->numOutputTensors = 0;
|
| 326 |
+
if (graphInfoSrc->graphOutputs) {
|
| 327 |
+
if (!copyTensorsInfo(
|
| 328 |
+
graphInfoSrc->graphOutputs,
|
| 329 |
+
graphInfoDst->outputTensors,
|
| 330 |
+
graphInfoSrc->numGraphOutputs
|
| 331 |
+
)) {
|
| 332 |
+
return false;
|
| 333 |
+
}
|
| 334 |
+
graphInfoDst->numOutputTensors = graphInfoSrc->numGraphOutputs;
|
| 335 |
+
}
|
| 336 |
+
return true;
|
| 337 |
+
}
|
| 338 |
+
|
| 339 |
+
bool copyGraphsInfoV3(const QnnSystemContext_GraphInfoV3_t *graphInfoSrc,
|
| 340 |
+
GraphInfo_t *graphInfoDst) {
|
| 341 |
+
graphInfoDst->graphName = nullptr;
|
| 342 |
+
if (graphInfoSrc->graphName) {
|
| 343 |
+
graphInfoDst->graphName =
|
| 344 |
+
__strdup(graphInfoSrc->graphName);
|
| 345 |
+
}
|
| 346 |
+
graphInfoDst->inputTensors = nullptr;
|
| 347 |
+
graphInfoDst->numInputTensors = 0;
|
| 348 |
+
if (graphInfoSrc->graphInputs) {
|
| 349 |
+
if (!copyTensorsInfo(
|
| 350 |
+
graphInfoSrc->graphInputs, graphInfoDst->inputTensors, graphInfoSrc->numGraphInputs)) {
|
| 351 |
+
return false;
|
| 352 |
+
}
|
| 353 |
+
graphInfoDst->numInputTensors = graphInfoSrc->numGraphInputs;
|
| 354 |
+
}
|
| 355 |
+
graphInfoDst->outputTensors = nullptr;
|
| 356 |
+
graphInfoDst->numOutputTensors = 0;
|
| 357 |
+
if (graphInfoSrc->graphOutputs) {
|
| 358 |
+
if (!copyTensorsInfo(graphInfoSrc->graphOutputs,
|
| 359 |
+
graphInfoDst->outputTensors,
|
| 360 |
+
graphInfoSrc->numGraphOutputs)) {
|
| 361 |
+
return false;
|
| 362 |
+
}
|
| 363 |
+
graphInfoDst->numOutputTensors = graphInfoSrc->numGraphOutputs;
|
| 364 |
+
}
|
| 365 |
+
return true;
|
| 366 |
+
}
|
| 367 |
+
|
| 368 |
+
bool updateGraphInfo(const QnnSystemContext_GraphInfo_t* graphsInput,
|
| 369 |
+
const uint32_t numGraphs,
|
| 370 |
+
GraphInfo_t** graphsInfo,
|
| 371 |
+
uint32_t& graphsCount
|
| 372 |
+
){
|
| 373 |
+
|
| 374 |
+
for (size_t gIdx = 0; gIdx < numGraphs; gIdx++) {
|
| 375 |
+
if (graphsInput[gIdx].version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_1) {
|
| 376 |
+
if(updateGraphInfoV1(&graphsInput[gIdx].graphInfoV1, graphsInfo[graphsCount]) == false) {
|
| 377 |
+
return false;
|
| 378 |
+
}
|
| 379 |
+
}
|
| 380 |
+
if (graphsInput[gIdx].version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_3) {
|
| 381 |
+
if(updateGraphInfoV3(&graphsInput[gIdx].graphInfoV3, graphsInfo[graphsCount]) == false) {
|
| 382 |
+
return false;
|
| 383 |
+
}
|
| 384 |
+
}
|
| 385 |
+
graphsCount++;
|
| 386 |
+
}
|
| 387 |
+
return true;
|
| 388 |
+
}
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
bool copyGraphsInfo(
|
| 392 |
+
const QnnSystemContext_GraphInfo_t* graphsInput,
|
| 393 |
+
const uint32_t numGraphs,
|
| 394 |
+
GraphInfo_t**& graphsInfo
|
| 395 |
+
) {
|
| 396 |
+
|
| 397 |
+
if (!graphsInput) {
|
| 398 |
+
QNN_ERROR("Received nullptr for graphsInput.");
|
| 399 |
+
return false;
|
| 400 |
+
}
|
| 401 |
+
auto returnStatus = true;
|
| 402 |
+
graphsInfo = (GraphInfo_t**)calloc(numGraphs, sizeof(GraphInfo_t*));
|
| 403 |
+
GraphInfo_t* graphInfoArr = (GraphInfo_t*)calloc(numGraphs, sizeof(GraphInfo_t));
|
| 404 |
+
if (nullptr == graphsInfo || nullptr == graphInfoArr) {
|
| 405 |
+
QNN_ERROR("Failure to allocate memory for *graphInfo");
|
| 406 |
+
returnStatus = false;
|
| 407 |
+
}
|
| 408 |
+
if (true == returnStatus) {
|
| 409 |
+
for (size_t gIdx = 0; gIdx < numGraphs; gIdx++) {
|
| 410 |
+
QNN_DEBUG("Extracting graphsInfo for graph Idx: %d", (int)gIdx);
|
| 411 |
+
if (graphsInput[gIdx].version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_1) {
|
| 412 |
+
copyGraphsInfoV1(&graphsInput[gIdx].graphInfoV1, &graphInfoArr[gIdx]);
|
| 413 |
+
}
|
| 414 |
+
if (graphsInput[gIdx].version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_3) {
|
| 415 |
+
copyGraphsInfoV3(&graphsInput[gIdx].graphInfoV3, &graphInfoArr[gIdx]);
|
| 416 |
+
}
|
| 417 |
+
graphsInfo[gIdx] = graphInfoArr + gIdx;
|
| 418 |
+
}
|
| 419 |
+
}
|
| 420 |
+
if (true != returnStatus) {
|
| 421 |
+
QNN_DEBUG("Received an ERROR during extractGraphsInfo. Freeing resources.");
|
| 422 |
+
if (graphsInfo) {
|
| 423 |
+
for (uint32_t gIdx = 0; gIdx < numGraphs; gIdx++) {
|
| 424 |
+
if (graphsInfo[gIdx]) {
|
| 425 |
+
if (nullptr != graphsInfo[gIdx]->graphName) {
|
| 426 |
+
free(graphsInfo[gIdx]->graphName);
|
| 427 |
+
graphsInfo[gIdx]->graphName = nullptr;
|
| 428 |
+
}
|
| 429 |
+
freeQnnTensorWrappers(
|
| 430 |
+
graphsInfo[gIdx]->inputTensors, graphsInfo[gIdx]->numInputTensors
|
| 431 |
+
);
|
| 432 |
+
freeQnnTensorWrappers(
|
| 433 |
+
graphsInfo[gIdx]->outputTensors, graphsInfo[gIdx]->numOutputTensors
|
| 434 |
+
);
|
| 435 |
+
}
|
| 436 |
+
}
|
| 437 |
+
free(*graphsInfo);
|
| 438 |
+
}
|
| 439 |
+
free(graphsInfo);
|
| 440 |
+
graphsInfo = nullptr;
|
| 441 |
+
}
|
| 442 |
+
|
| 443 |
+
return true;
|
| 444 |
+
}
|
| 445 |
+
|
| 446 |
+
uint32_t getNumGraphInBinary(const QnnSystemContext_BinaryInfo_t* binaryInfo)
|
| 447 |
+
{
|
| 448 |
+
uint32_t numGraph = 0;
|
| 449 |
+
if (nullptr == binaryInfo) {
|
| 450 |
+
QNN_ERROR("binaryInfo is nullptr.");
|
| 451 |
+
return false;
|
| 452 |
+
}
|
| 453 |
+
if (binaryInfo->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_1) {
|
| 454 |
+
numGraph = binaryInfo->contextBinaryInfoV1.numGraphs;
|
| 455 |
+
}else if (binaryInfo->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_2) {
|
| 456 |
+
numGraph = binaryInfo->contextBinaryInfoV2.numGraphs;
|
| 457 |
+
}
|
| 458 |
+
else if (binaryInfo->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_3) {
|
| 459 |
+
numGraph = binaryInfo->contextBinaryInfoV3.numGraphs;
|
| 460 |
+
}
|
| 461 |
+
return numGraph;
|
| 462 |
+
}
|
| 463 |
+
|
| 464 |
+
bool updateMetaDataToGraphsInfo(const QnnSystemContext_BinaryInfo_t* binaryInfo,
|
| 465 |
+
GraphInfo_t** graphsInfo,
|
| 466 |
+
uint32_t& graphsCount
|
| 467 |
+
){
|
| 468 |
+
if (nullptr == binaryInfo) {
|
| 469 |
+
QNN_ERROR("binaryInfo is nullptr.");
|
| 470 |
+
return false;
|
| 471 |
+
}
|
| 472 |
+
if (binaryInfo->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_1) {
|
| 473 |
+
if (binaryInfo->contextBinaryInfoV1.graphs) {
|
| 474 |
+
if (!updateGraphInfo(
|
| 475 |
+
binaryInfo->contextBinaryInfoV1.graphs,
|
| 476 |
+
binaryInfo->contextBinaryInfoV1.numGraphs,
|
| 477 |
+
graphsInfo,
|
| 478 |
+
graphsCount
|
| 479 |
+
)) {
|
| 480 |
+
QNN_ERROR("Failed while copying graphs Info.");
|
| 481 |
+
return false;
|
| 482 |
+
}
|
| 483 |
+
return true;
|
| 484 |
+
}
|
| 485 |
+
} else if (binaryInfo->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_2) {
|
| 486 |
+
if (binaryInfo->contextBinaryInfoV2.graphs) {
|
| 487 |
+
if (!updateGraphInfo(
|
| 488 |
+
binaryInfo->contextBinaryInfoV2.graphs,
|
| 489 |
+
binaryInfo->contextBinaryInfoV2.numGraphs,
|
| 490 |
+
graphsInfo,
|
| 491 |
+
graphsCount
|
| 492 |
+
)) {
|
| 493 |
+
QNN_ERROR("Failed while copying graphs Info.");
|
| 494 |
+
return false;
|
| 495 |
+
}
|
| 496 |
+
return true;
|
| 497 |
+
}
|
| 498 |
+
} else if (binaryInfo->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_3) {
|
| 499 |
+
if (binaryInfo->contextBinaryInfoV3.graphs) {
|
| 500 |
+
if (!updateGraphInfo(
|
| 501 |
+
binaryInfo->contextBinaryInfoV3.graphs,
|
| 502 |
+
binaryInfo->contextBinaryInfoV3.numGraphs,
|
| 503 |
+
graphsInfo,
|
| 504 |
+
graphsCount
|
| 505 |
+
)) {
|
| 506 |
+
QNN_ERROR("Failed while copying graphs Info.");
|
| 507 |
+
return false;
|
| 508 |
+
}
|
| 509 |
+
return true;
|
| 510 |
+
}
|
| 511 |
+
}
|
| 512 |
+
QNN_ERROR("Unrecognized system context binary info version.");
|
| 513 |
+
return false;
|
| 514 |
+
}
|
| 515 |
+
|
| 516 |
+
bool copyMetadataToGraphsInfo(
|
| 517 |
+
const QnnSystemContext_BinaryInfo_t* binaryInfo,
|
| 518 |
+
GraphInfo_t**& graphsInfo,
|
| 519 |
+
uint32_t& graphsCount
|
| 520 |
+
) {
|
| 521 |
+
if (nullptr == binaryInfo) {
|
| 522 |
+
QNN_ERROR("binaryInfo is nullptr.");
|
| 523 |
+
return false;
|
| 524 |
+
}
|
| 525 |
+
graphsCount = 0;
|
| 526 |
+
if (binaryInfo->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_1) {
|
| 527 |
+
if (binaryInfo->contextBinaryInfoV1.graphs) {
|
| 528 |
+
if (!copyGraphsInfo(
|
| 529 |
+
binaryInfo->contextBinaryInfoV1.graphs,
|
| 530 |
+
binaryInfo->contextBinaryInfoV1.numGraphs,
|
| 531 |
+
graphsInfo
|
| 532 |
+
)) {
|
| 533 |
+
QNN_ERROR("Failed while copying graphs Info.");
|
| 534 |
+
return false;
|
| 535 |
+
}
|
| 536 |
+
graphsCount = binaryInfo->contextBinaryInfoV1.numGraphs;
|
| 537 |
+
return true;
|
| 538 |
+
}
|
| 539 |
+
} else if (binaryInfo->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_2) {
|
| 540 |
+
if (binaryInfo->contextBinaryInfoV2.graphs) {
|
| 541 |
+
if (!copyGraphsInfo(
|
| 542 |
+
binaryInfo->contextBinaryInfoV2.graphs,
|
| 543 |
+
binaryInfo->contextBinaryInfoV2.numGraphs,
|
| 544 |
+
graphsInfo
|
| 545 |
+
)) {
|
| 546 |
+
QNN_ERROR("Failed while copying graphs Info.");
|
| 547 |
+
return false;
|
| 548 |
+
}
|
| 549 |
+
graphsCount = binaryInfo->contextBinaryInfoV2.numGraphs;
|
| 550 |
+
return true;
|
| 551 |
+
}
|
| 552 |
+
} else if (binaryInfo->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_3) {
|
| 553 |
+
if (binaryInfo->contextBinaryInfoV3.graphs) {
|
| 554 |
+
if (!copyGraphsInfo(binaryInfo->contextBinaryInfoV3.graphs,
|
| 555 |
+
binaryInfo->contextBinaryInfoV3.numGraphs,
|
| 556 |
+
graphsInfo)) {
|
| 557 |
+
QNN_ERROR("Failed while copying graphs Info.");
|
| 558 |
+
return false;
|
| 559 |
+
}
|
| 560 |
+
graphsCount = binaryInfo->contextBinaryInfoV3.numGraphs;
|
| 561 |
+
return true;
|
| 562 |
+
}
|
| 563 |
+
}
|
| 564 |
+
QNN_ERROR("Unrecognized system context binary info version.");
|
| 565 |
+
return false;
|
| 566 |
+
}
|
| 567 |
+
|
| 568 |
+
size_t getFileSize(std::string filePath) {
|
| 569 |
+
std::ifstream in(filePath, std::ifstream::binary);
|
| 570 |
+
if (!in) {
|
| 571 |
+
QNN_ERROR("Failed to open input file: %s", filePath.c_str());
|
| 572 |
+
return 0;
|
| 573 |
+
}
|
| 574 |
+
in.seekg(0, in.end);
|
| 575 |
+
const size_t length = in.tellg();
|
| 576 |
+
in.seekg(0, in.beg);
|
| 577 |
+
return length;
|
| 578 |
+
}
|
| 579 |
+
|
| 580 |
+
bool readBinaryFromFile(std::string filePath, void* buffer, size_t bufferSize) {
|
| 581 |
+
if (nullptr == buffer) {
|
| 582 |
+
QNN_ERROR("buffer is nullptr");
|
| 583 |
+
return false;
|
| 584 |
+
}
|
| 585 |
+
std::ifstream in(filePath, std::ifstream::binary);
|
| 586 |
+
if (!in) {
|
| 587 |
+
QNN_ERROR("Failed to open input file: %s", filePath.c_str());
|
| 588 |
+
return false;
|
| 589 |
+
}
|
| 590 |
+
if (!in.read(reinterpret_cast<char*>(buffer), bufferSize)) {
|
| 591 |
+
QNN_ERROR("Failed to read the contents of: %s", filePath.c_str());
|
| 592 |
+
return false;
|
| 593 |
+
}
|
| 594 |
+
return true;
|
| 595 |
+
}
|
| 596 |
+
|
| 597 |
+
bool mmapBinaryFile(std::string filePath, void** buffer, size_t bufferSize) {
|
| 598 |
+
#ifndef _WIN32
|
| 599 |
+
int fd = open(filePath.c_str(), O_RDONLY);
|
| 600 |
+
int OFFSET = 0;
|
| 601 |
+
|
| 602 |
+
// read the binary file as memory map
|
| 603 |
+
*buffer = mmap(nullptr, bufferSize, PROT_READ, MAP_PRIVATE, fd, OFFSET);
|
| 604 |
+
close(fd);
|
| 605 |
+
if (madvise(*buffer, bufferSize, MADV_NOHUGEPAGE)) {
|
| 606 |
+
QNN_ERROR("Failed to advise OS on memory usage err: %s", strerror(errno));
|
| 607 |
+
}
|
| 608 |
+
|
| 609 |
+
return true;
|
| 610 |
+
#else
|
| 611 |
+
return false;
|
| 612 |
+
#endif
|
| 613 |
+
}
|
| 614 |
+
|
| 615 |
+
bool fillDims(std::vector<size_t>& dims, uint32_t* inDimensions, uint32_t rank) {
|
| 616 |
+
if (nullptr == inDimensions) {
|
| 617 |
+
QNN_ERROR("input dimensions is nullptr");
|
| 618 |
+
return false;
|
| 619 |
+
}
|
| 620 |
+
|
| 621 |
+
if (rank < 1) {
|
| 622 |
+
QNN_ERROR("invalid rank : %d", rank);
|
| 623 |
+
return false;
|
| 624 |
+
}
|
| 625 |
+
|
| 626 |
+
// In case, rank is less than 4, we are pushing 1s
|
| 627 |
+
for (size_t r = 0; r < 4 - rank; r++) {
|
| 628 |
+
dims.push_back(1);
|
| 629 |
+
}
|
| 630 |
+
|
| 631 |
+
for (size_t r = 0; r < rank; r++) {
|
| 632 |
+
dims.push_back(inDimensions[r]);
|
| 633 |
+
}
|
| 634 |
+
|
| 635 |
+
return true;
|
| 636 |
+
}
|
Genie/Genie/src/qualla/engines/qnn-api/QnnApiUtils.hpp
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//==============================================================================
//
// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
// All Rights Reserved.
// Confidential and Proprietary - Qualcomm Technologies, Inc.
//
//==============================================================================

// BUGFIX: this header had no include guard (siblings QnnConfig.hpp /
// QnnTypeDef.hpp have one); double inclusion would redeclare everything.
#pragma once

#include "QnnInterface.h"
#include "QnnTypes.h"
#include "System/QnnSystemInterface.h"

#include <iostream>
#include <map>
#include <queue>
#include <string>
#include <unordered_map>
#include <vector>

#include "QnnTypeDef.hpp"
#include "Log.hpp"

/**
 * @brief Frees all memory allocated tensor attributes.
 *
 * @param[in] tensorWrapper tensor object to free
 *
 * @return Error code
 */
bool freeQnnTensorWrapper(TensorWrapper& tensorWrapper);

/**
 * @brief Loops through and frees all memory allocated tensor attributes for each tensorWrapper
 * object.
 *
 * @param[in] tensorWrappers array of tensor objects to free
 *
 * @param[in] numTensors length of the above tensorWrappers array
 *
 * @return Error code
 */
bool freeQnnTensorWrappers(TensorWrapper*& tensorWrappers, uint32_t numTensors);

/**
 * @brief A helper function to free memory malloced for communicating the Graph for a model(s)
 *
 * @param[in] graphsInfo Pointer pointing to location of graph objects
 *
 * @param[in] numGraphs The number of graph objects the above pointer is pointing to
 *
 * @return Error code
 *
 */
bool freeGraphsInfo(GraphInfoPtr_t** graphsInfo, uint32_t numGraphs);

bool freeGraphInfo(GraphInfo_t* graphInfo);

/// Allocates and fills graph records from a versioned context-binary info.
bool copyMetadataToGraphsInfo(
    const QnnSystemContext_BinaryInfo_t* binaryInfo,
    GraphInfo_t**& graphsInfo,
    uint32_t& graphsCount
);

/// Allocates and deep-copies graph records from an array of descriptors.
bool copyGraphsInfo(
    const QnnSystemContext_GraphInfo_t* graphsInput,
    const uint32_t numGraphs,
    GraphInfo_t**& graphsInfo
);

/// Deep-copies a version-1 graph descriptor into a GraphInfo_t.
bool copyGraphsInfoV1(
    const QnnSystemContext_GraphInfoV1_t* graphInfoSrc,
    GraphInfo_t* graphInfoDst
);

/// Deep-copies a version-3 graph descriptor into a GraphInfo_t.
bool copyGraphsInfoV3(
    const QnnSystemContext_GraphInfoV3_t* graphInfoSrc,
    GraphInfo_t* graphInfoDst
);

/// Deep-copies an array of tensor descriptors into freshly allocated wrappers.
bool copyTensorsInfo(
    const Qnn_Tensor_t* tensorsInfoSrc,
    TensorWrapper*& tensorWrappers,
    uint32_t tensorsCount
);

/// Appends a tensor's dimensions to dims, left-padding with 1s up to rank 4.
bool fillDims(std::vector<size_t>& dims, uint32_t* inDimensions, uint32_t rank);
/// Returns a file's size in bytes, or 0 if it cannot be opened.
size_t getFileSize(std::string filePath);
/// Reads bufferSize bytes from a file into a caller-provided buffer.
bool readBinaryFromFile(std::string filePath, void* buffer, size_t bufferSize);
/// Memory-maps a file read-only (POSIX only).
bool mmapBinaryFile(std::string filePath, void** buffer, size_t bufferSize);
/// Updates pre-populated graph records from a versioned context-binary info.
bool updateMetaDataToGraphsInfo(const QnnSystemContext_BinaryInfo_t* binaryInfo,
                                GraphInfo_t** graphsInfo,
                                uint32_t& graphsCount);
// BUGFIX: declaration now matches the definition in QnnApiUtils.cpp — the old
// one declared (graphsInput, currCount, GraphInfo_t*) which never matched and
// would break any translation unit calling through this header.
bool updateGraphInfo(const QnnSystemContext_GraphInfo_t* graphsInput,
                     const uint32_t numGraphs,
                     GraphInfo_t** graphsInfo,
                     uint32_t& graphsCount);
bool updateGraphInfoV1(const QnnSystemContext_GraphInfoV1_t* graphInfoSrc,
                       GraphInfo_t* graphInfoDst);
/// Added for parity with the V3 definition in QnnApiUtils.cpp.
bool updateGraphInfoV3(const QnnSystemContext_GraphInfoV3_t* graphInfoSrc,
                       GraphInfo_t* graphInfoDst);
bool updateTensorInfo(const Qnn_Tensor_t* tensorsInfoSrc,
                      TensorWrapper* tensorWrappers,
                      uint32_t tensorsCount);
/// Reads the graph count out of a versioned context-binary info struct.
uint32_t getNumGraphInBinary(const QnnSystemContext_BinaryInfo_t* binaryInfo);
|
Genie/Genie/src/qualla/engines/qnn-api/QnnConfig.hpp
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//==============================================================================
//
// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
// All Rights Reserved.
// Confidential and Proprietary - Qualcomm Technologies, Inc.
//
//==============================================================================
#pragma once

#include "QnnGraph.h"
#include "QnnTypes.h"
#include <string>  // BUGFIX: std::string was used without including <string>
#include <vector>

/// Paths used to load a backend-extensions shared library and its config.
struct BackendExtensionsConfigs {
  std::string sharedLibraryPath;
  std::string configFilePath;
  BackendExtensionsConfigs() : sharedLibraryPath(""), configFilePath("") {}
  BackendExtensionsConfigs(std::string sharedLibraryPath, std::string configFilePath)
      : sharedLibraryPath(sharedLibraryPath), configFilePath(configFilePath) {}
};

/// Optional context priority; priorityPresent marks whether priority is set.
struct ContextConfigs {
  bool priorityPresent;
  Qnn_Priority_t priority;
  ContextConfigs() : priorityPresent(false), priority(QNN_PRIORITY_UNDEFINED) {}
  ContextConfigs(Qnn_Priority_t priority) : priorityPresent(true), priority(priority) {}
};

/// Per-graph configuration: name plus an optional priority override.
struct GraphConfigs {
  std::string graphName;
  bool priorityPresent;
  Qnn_Priority_t priority;
  GraphConfigs()
      : graphName(),
        priorityPresent(false), priority(QNN_PRIORITY_UNDEFINED) {
  }
};

/// Aggregate of all configuration sections consumed by the QNN engine.
struct ConfigOptions {
  BackendExtensionsConfigs backendExtensionsConfigs;
  ContextConfigs contextConfigs;
  std::vector<GraphConfigs> graphConfigs;
  ConfigOptions() : backendExtensionsConfigs(), contextConfigs(), graphConfigs() {}
};
|
Genie/Genie/src/qualla/engines/qnn-api/QnnTypeDef.hpp
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//==============================================================================
|
| 2 |
+
//
|
| 3 |
+
// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
|
| 4 |
+
// All Rights Reserved.
|
| 5 |
+
// Confidential and Proprietary - Qualcomm Technologies, Inc.
|
| 6 |
+
//
|
| 7 |
+
//==============================================================================
|
| 8 |
+
|
| 9 |
+
#ifndef QNN_TYPE_DEF_H_
|
| 10 |
+
#define QNN_TYPE_DEF_H_
|
| 11 |
+
|
| 12 |
+
#include "QnnInterface.h"
|
| 13 |
+
#include "QnnTypes.h"
|
| 14 |
+
#include "Log.hpp"
|
| 15 |
+
#include "QnnTypeMacros.hpp"
|
| 16 |
+
|
| 17 |
+
typedef enum ModelError {
|
| 18 |
+
MODEL_NO_ERROR = 0,
|
| 19 |
+
MODEL_TENSOR_ERROR = 1,
|
| 20 |
+
MODEL_PARAMS_ERROR = 2,
|
| 21 |
+
MODEL_NODES_ERROR = 3,
|
| 22 |
+
MODEL_GRAPH_ERROR = 4,
|
| 23 |
+
MODEL_CONTEXT_ERROR = 5,
|
| 24 |
+
MODEL_GENERATION_ERROR = 6,
|
| 25 |
+
MODEL_SETUP_ERROR = 7,
|
| 26 |
+
MODEL_INVALID_ARGUMENT_ERROR = 8,
|
| 27 |
+
MODEL_FILE_ERROR = 9,
|
| 28 |
+
MODEL_MEMORY_ALLOCATE_ERROR = 10,
|
| 29 |
+
// Value selected to ensure 32 bits.
|
| 30 |
+
MODEL_UNKNOWN_ERROR = 0x7FFFFFFF
|
| 31 |
+
} ModelError_t;
|
| 32 |
+
|
| 33 |
+
using TensorWrapper = Qnn_Tensor_t;
|
| 34 |
+
#define GET_TENSOR_WRAPPER_TENSOR(tensorWrapper) tensorWrapper
|
| 35 |
+
#define GET_TENSOR_WRAPPER_NAME(tensorWrapper) QNN_TENSOR_GET_NAME(tensorWrapper)
|
| 36 |
+
|
| 37 |
+
typedef struct GraphInfo {
|
| 38 |
+
Qnn_GraphHandle_t graph;
|
| 39 |
+
char* graphName;
|
| 40 |
+
TensorWrapper* inputTensors;
|
| 41 |
+
uint32_t numInputTensors;
|
| 42 |
+
TensorWrapper* outputTensors;
|
| 43 |
+
uint32_t numOutputTensors;
|
| 44 |
+
} GraphInfo_t;
|
| 45 |
+
typedef GraphInfo_t* GraphInfoPtr_t;
|
| 46 |
+
|
| 47 |
+
typedef struct GraphConfigInfo {
|
| 48 |
+
char* graphName;
|
| 49 |
+
const QnnGraph_Config_t** graphConfigs;
|
| 50 |
+
} GraphConfigInfo_t;
|
| 51 |
+
|
| 52 |
+
#endif // QNN_TYPE_DEF_H_
|
Genie/Genie/src/qualla/engines/qnn-api/QnnTypeMacros.hpp
ADDED
|
@@ -0,0 +1,702 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//==============================================================================
|
| 2 |
+
//
|
| 3 |
+
// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
|
| 4 |
+
// All Rights Reserved.
|
| 5 |
+
// Confidential and Proprietary - Qualcomm Technologies, Inc.
|
| 6 |
+
//
|
| 7 |
+
//==============================================================================
|
| 8 |
+
|
| 9 |
+
#pragma once
|
| 10 |
+
|
| 11 |
+
#include "QnnTypes.h"
|
| 12 |
+
|
| 13 |
+
#define QNN_OP_CFG_VALID(opConfig) ((opConfig).version == QNN_OPCONFIG_VERSION_1)
|
| 14 |
+
|
| 15 |
+
/// Builds a zero-initialized Qnn_OpConfig_t for the requested version,
/// initializing the v1 payload when version 1 is asked for.
inline Qnn_OpConfig_t createQnnOpConfig(const Qnn_OpConfigVersion_t version) {
  Qnn_OpConfig_t cfg = QNN_OPCONFIG_INIT;
  cfg.version = version;
  if (version == QNN_OPCONFIG_VERSION_1) cfg.v1 = QNN_OPCONFIG_V1_INIT;
  return cfg;
}
|
| 23 |
+
|
| 24 |
+
inline const char* getQnnOpConfigName(const Qnn_OpConfig_t& opConfig) {
|
| 25 |
+
if (opConfig.version == QNN_OPCONFIG_VERSION_1) {
|
| 26 |
+
return opConfig.v1.name;
|
| 27 |
+
}
|
| 28 |
+
return NULL;
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
inline const char* getQnnOpConfigName(const Qnn_OpConfig_t* const opConfig) {
|
| 32 |
+
return getQnnOpConfigName(*opConfig);
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
inline const char* getQnnOpConfigPackageName(const Qnn_OpConfig_t& opConfig) {
|
| 36 |
+
if (opConfig.version == QNN_OPCONFIG_VERSION_1) {
|
| 37 |
+
return opConfig.v1.packageName;
|
| 38 |
+
}
|
| 39 |
+
return NULL;
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
inline const char* getQnnOpConfigPackageName(const Qnn_OpConfig_t* const opConfig) {
|
| 43 |
+
return getQnnOpConfigPackageName(*opConfig);
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
inline const char* getQnnOpConfigTypeName(const Qnn_OpConfig_t& opConfig) {
|
| 47 |
+
if (opConfig.version == QNN_OPCONFIG_VERSION_1) {
|
| 48 |
+
return opConfig.v1.typeName;
|
| 49 |
+
}
|
| 50 |
+
return NULL;
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
inline const char* getQnnOpConfigTypeName(const Qnn_OpConfig_t* const opConfig) {
|
| 54 |
+
return getQnnOpConfigTypeName(*opConfig);
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
inline uint32_t getQnnOpConfigNumParams(const Qnn_OpConfig_t& opConfig) {
|
| 58 |
+
if (opConfig.version == QNN_OPCONFIG_VERSION_1) {
|
| 59 |
+
return opConfig.v1.numOfParams;
|
| 60 |
+
}
|
| 61 |
+
return 0u;
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
inline uint32_t getQnnOpConfigNumParams(const Qnn_OpConfig_t* const opConfig) {
|
| 65 |
+
return getQnnOpConfigNumParams(*opConfig);
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
inline const Qnn_Param_t* getQnnOpConfigParams(const Qnn_OpConfig_t& opConfig) {
|
| 69 |
+
if (opConfig.version == QNN_OPCONFIG_VERSION_1) {
|
| 70 |
+
return opConfig.v1.params;
|
| 71 |
+
}
|
| 72 |
+
return NULL;
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
inline const Qnn_Param_t* getQnnOpConfigParams(const Qnn_OpConfig_t* const opConfig) {
|
| 76 |
+
return getQnnOpConfigParams(*opConfig);
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
inline uint32_t getQnnOpConfigNumInputs(const Qnn_OpConfig_t& opConfig) {
|
| 80 |
+
if (opConfig.version == QNN_OPCONFIG_VERSION_1) {
|
| 81 |
+
return opConfig.v1.numOfInputs;
|
| 82 |
+
}
|
| 83 |
+
return 0u;
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
inline uint32_t getQnnOpConfigNumInputs(const Qnn_OpConfig_t* const opConfig) {
|
| 87 |
+
return getQnnOpConfigNumInputs(*opConfig);
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
inline const Qnn_Tensor_t* getQnnOpConfigInputs(const Qnn_OpConfig_t& opConfig) {
|
| 91 |
+
if (opConfig.version == QNN_OPCONFIG_VERSION_1) {
|
| 92 |
+
return opConfig.v1.inputTensors;
|
| 93 |
+
}
|
| 94 |
+
return NULL;
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
inline const Qnn_Tensor_t* getQnnOpConfigInputs(const Qnn_OpConfig_t* const opConfig) {
|
| 98 |
+
return getQnnOpConfigInputs(*opConfig);
|
| 99 |
+
}
|
| 100 |
+
|
| 101 |
+
inline uint32_t getQnnOpConfigNumOutputs(const Qnn_OpConfig_t& opConfig) {
|
| 102 |
+
if (opConfig.version == QNN_OPCONFIG_VERSION_1) {
|
| 103 |
+
return opConfig.v1.numOfOutputs;
|
| 104 |
+
}
|
| 105 |
+
return 0u;
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
inline uint32_t getQnnOpConfigNumOutputs(const Qnn_OpConfig_t* const opConfig) {
|
| 109 |
+
return getQnnOpConfigNumOutputs(*opConfig);
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
inline const Qnn_Tensor_t* getQnnOpConfigOutputs(const Qnn_OpConfig_t& opConfig) {
|
| 113 |
+
if (opConfig.version == QNN_OPCONFIG_VERSION_1) {
|
| 114 |
+
return opConfig.v1.outputTensors;
|
| 115 |
+
}
|
| 116 |
+
return NULL;
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
inline const Qnn_Tensor_t* getQnnOpConfigOutputs(const Qnn_OpConfig_t* const opConfig) {
|
| 120 |
+
return getQnnOpConfigOutputs(*opConfig);
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
inline void setQnnOpConfigName(Qnn_OpConfig_t& opConfig, const char* const name) {
|
| 124 |
+
if (opConfig.version == QNN_OPCONFIG_VERSION_1) {
|
| 125 |
+
opConfig.v1.name = name;
|
| 126 |
+
}
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
inline void setQnnOpConfigName(Qnn_OpConfig_t* const opConfig, const char* const name) {
|
| 130 |
+
setQnnOpConfigName(*opConfig, name);
|
| 131 |
+
}
|
| 132 |
+
|
| 133 |
+
inline void setQnnOpConfigPackageName(Qnn_OpConfig_t& opConfig, const char* const packageName) {
|
| 134 |
+
if (opConfig.version == QNN_OPCONFIG_VERSION_1) {
|
| 135 |
+
opConfig.v1.packageName = packageName;
|
| 136 |
+
}
|
| 137 |
+
}
|
| 138 |
+
|
| 139 |
+
inline void setQnnOpConfigPackageName(
|
| 140 |
+
Qnn_OpConfig_t* const opConfig,
|
| 141 |
+
const char* const packageName
|
| 142 |
+
) {
|
| 143 |
+
setQnnOpConfigPackageName(*opConfig, packageName);
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
inline void setQnnOpConfigTypeName(Qnn_OpConfig_t& opConfig, const char* const typeName) {
|
| 147 |
+
if (opConfig.version == QNN_OPCONFIG_VERSION_1) {
|
| 148 |
+
opConfig.v1.typeName = typeName;
|
| 149 |
+
}
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
inline void setQnnOpConfigTypeName(Qnn_OpConfig_t* const opConfig, const char* const typeName) {
|
| 153 |
+
setQnnOpConfigTypeName(*opConfig, typeName);
|
| 154 |
+
}
|
| 155 |
+
|
| 156 |
+
inline void setQnnOpConfigParams(
|
| 157 |
+
Qnn_OpConfig_t& opConfig,
|
| 158 |
+
uint32_t const numOfParams,
|
| 159 |
+
Qnn_Param_t* const params
|
| 160 |
+
) {
|
| 161 |
+
if (opConfig.version == QNN_OPCONFIG_VERSION_1) {
|
| 162 |
+
opConfig.v1.numOfParams = numOfParams;
|
| 163 |
+
opConfig.v1.params = params;
|
| 164 |
+
}
|
| 165 |
+
}
|
| 166 |
+
|
| 167 |
+
inline void setQnnOpConfigParams(
|
| 168 |
+
Qnn_OpConfig_t* const opConfig,
|
| 169 |
+
uint32_t const numOfParams,
|
| 170 |
+
Qnn_Param_t* const params
|
| 171 |
+
) {
|
| 172 |
+
setQnnOpConfigParams(*opConfig, numOfParams, params);
|
| 173 |
+
}
|
| 174 |
+
|
| 175 |
+
inline void setQnnOpConfigInputs(
|
| 176 |
+
Qnn_OpConfig_t& opConfig,
|
| 177 |
+
uint32_t const numOfInputs,
|
| 178 |
+
Qnn_Tensor_t* const inputTensors
|
| 179 |
+
) {
|
| 180 |
+
if (opConfig.version == QNN_OPCONFIG_VERSION_1) {
|
| 181 |
+
opConfig.v1.numOfInputs = numOfInputs;
|
| 182 |
+
opConfig.v1.inputTensors = inputTensors;
|
| 183 |
+
}
|
| 184 |
+
}
|
| 185 |
+
|
| 186 |
+
inline void setQnnOpConfigInputs(
|
| 187 |
+
Qnn_OpConfig_t* const opConfig,
|
| 188 |
+
uint32_t const numOfInputs,
|
| 189 |
+
Qnn_Tensor_t* const inputTensors
|
| 190 |
+
) {
|
| 191 |
+
setQnnOpConfigInputs(*opConfig, numOfInputs, inputTensors);
|
| 192 |
+
}
|
| 193 |
+
|
| 194 |
+
inline void setQnnOpConfigOutputs(
|
| 195 |
+
Qnn_OpConfig_t& opConfig,
|
| 196 |
+
uint32_t const numOfOutputs,
|
| 197 |
+
Qnn_Tensor_t* const outputTensors
|
| 198 |
+
) {
|
| 199 |
+
if (opConfig.version == QNN_OPCONFIG_VERSION_1) {
|
| 200 |
+
opConfig.v1.numOfOutputs = numOfOutputs;
|
| 201 |
+
opConfig.v1.outputTensors = outputTensors;
|
| 202 |
+
}
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
+
inline void setQnnOpConfigOutputs(
|
| 206 |
+
Qnn_OpConfig_t* const opConfig,
|
| 207 |
+
uint32_t const numOfOutputs,
|
| 208 |
+
Qnn_Tensor_t* const outputTensors
|
| 209 |
+
) {
|
| 210 |
+
setQnnOpConfigOutputs(*opConfig, numOfOutputs, outputTensors);
|
| 211 |
+
}
|
| 212 |
+
|
| 213 |
+
/// Build a zero-initialized Qnn_Tensor_t of the requested struct version.
/// The generic initializer is applied first, then the version-specific one
/// overlays the matching union member. Unknown versions get only the
/// generic initializer.
inline Qnn_Tensor_t createQnnTensor(const Qnn_TensorVersion_t version) {
  Qnn_Tensor_t tensor = QNN_TENSOR_INIT;
  tensor.version = version;
  switch (version) {
    case QNN_TENSOR_VERSION_1:
      tensor.v1 = QNN_TENSOR_V1_INIT;
      break;
    case QNN_TENSOR_VERSION_2:
      tensor.v2 = QNN_TENSOR_V2_INIT;
      break;
    default:
      break;
  }
  return tensor;
}
|
| 223 |
+
|
| 224 |
+
inline uint32_t getQnnTensorId(const Qnn_Tensor_t& tensor) {
|
| 225 |
+
// TensorCompatTest justifies no need to check version
|
| 226 |
+
return tensor.v1.id;
|
| 227 |
+
}
|
| 228 |
+
|
| 229 |
+
inline uint32_t getQnnTensorId(const Qnn_Tensor_t* const tensor) {
|
| 230 |
+
return getQnnTensorId(*tensor);
|
| 231 |
+
}
|
| 232 |
+
|
| 233 |
+
inline const char* getQnnTensorName(const Qnn_Tensor_t& tensor) {
|
| 234 |
+
// TensorCompatTest justifies no need to check version
|
| 235 |
+
return tensor.v1.name;
|
| 236 |
+
}
|
| 237 |
+
|
| 238 |
+
inline const char* getQnnTensorName(const Qnn_Tensor_t* const tensor) {
|
| 239 |
+
return getQnnTensorName(*tensor);
|
| 240 |
+
}
|
| 241 |
+
|
| 242 |
+
inline Qnn_TensorType_t getQnnTensorType(const Qnn_Tensor_t& tensor) {
|
| 243 |
+
// TensorCompatTest justifies no need to check version
|
| 244 |
+
return tensor.v1.type;
|
| 245 |
+
}
|
| 246 |
+
|
| 247 |
+
inline Qnn_TensorType_t getQnnTensorType(const Qnn_Tensor_t* const tensor) {
|
| 248 |
+
return getQnnTensorType(*tensor);
|
| 249 |
+
}
|
| 250 |
+
|
| 251 |
+
inline Qnn_TensorDataFormat_t getQnnTensorDataFormat(const Qnn_Tensor_t& tensor) {
|
| 252 |
+
// TensorCompatTest justifies no need to check version
|
| 253 |
+
return tensor.v1.dataFormat;
|
| 254 |
+
}
|
| 255 |
+
|
| 256 |
+
inline Qnn_TensorDataFormat_t getQnnTensorDataFormat(const Qnn_Tensor_t* const tensor) {
|
| 257 |
+
return getQnnTensorDataFormat(*tensor);
|
| 258 |
+
}
|
| 259 |
+
|
| 260 |
+
inline Qnn_DataType_t getQnnTensorDataType(const Qnn_Tensor_t& tensor) {
|
| 261 |
+
// TensorCompatTest justifies no need to check version
|
| 262 |
+
return tensor.v1.dataType;
|
| 263 |
+
}
|
| 264 |
+
|
| 265 |
+
inline Qnn_DataType_t getQnnTensorDataType(const Qnn_Tensor_t* const tensor) {
|
| 266 |
+
return getQnnTensorDataType(*tensor);
|
| 267 |
+
}
|
| 268 |
+
|
| 269 |
+
inline Qnn_QuantizeParams_t getQnnTensorQuantParams(const Qnn_Tensor_t& tensor) {
|
| 270 |
+
// TensorCompatTest justifies no need to check version
|
| 271 |
+
return tensor.v1.quantizeParams;
|
| 272 |
+
}
|
| 273 |
+
|
| 274 |
+
inline Qnn_QuantizeParams_t getQnnTensorQuantParams(const Qnn_Tensor_t* const tensor) {
|
| 275 |
+
if (tensor != nullptr) {
|
| 276 |
+
return getQnnTensorQuantParams(*tensor);
|
| 277 |
+
}
|
| 278 |
+
return QNN_QUANTIZE_PARAMS_INIT;
|
| 279 |
+
}
|
| 280 |
+
|
| 281 |
+
inline uint32_t getQnnTensorRank(const Qnn_Tensor_t& tensor) {
|
| 282 |
+
// TensorCompatTest justifies no need to check version
|
| 283 |
+
return tensor.v1.rank;
|
| 284 |
+
}
|
| 285 |
+
|
| 286 |
+
inline uint32_t getQnnTensorRank(const Qnn_Tensor_t* const tensor) {
|
| 287 |
+
if (tensor != nullptr) {
|
| 288 |
+
return getQnnTensorRank(*tensor);
|
| 289 |
+
}
|
| 290 |
+
return 0u;
|
| 291 |
+
}
|
| 292 |
+
|
| 293 |
+
inline uint32_t* getQnnTensorDimensions(const Qnn_Tensor_t& tensor) {
|
| 294 |
+
// TensorCompatTest justifies no need to check version
|
| 295 |
+
return tensor.v1.dimensions;
|
| 296 |
+
}
|
| 297 |
+
|
| 298 |
+
inline uint32_t* getQnnTensorDimensions(const Qnn_Tensor_t* const tensor) {
|
| 299 |
+
return getQnnTensorDimensions(*tensor);
|
| 300 |
+
}
|
| 301 |
+
|
| 302 |
+
inline uint8_t* getQnnTensorIsDynamicDimensions(const Qnn_Tensor_t& tensor) {
|
| 303 |
+
if (tensor.version == QNN_TENSOR_VERSION_1) {
|
| 304 |
+
return NULL;
|
| 305 |
+
} else if (tensor.version == QNN_TENSOR_VERSION_2) {
|
| 306 |
+
return tensor.v2.isDynamicDimensions;
|
| 307 |
+
}
|
| 308 |
+
return NULL;
|
| 309 |
+
}
|
| 310 |
+
|
| 311 |
+
inline uint8_t* getQnnTensorIsDynamicDimensions(const Qnn_Tensor_t* tensor) {
|
| 312 |
+
return getQnnTensorIsDynamicDimensions(*tensor);
|
| 313 |
+
}
|
| 314 |
+
|
| 315 |
+
inline Qnn_SparseParams_t getQnnTensorSparseParams(const Qnn_Tensor_t& tensor) {
|
| 316 |
+
if (tensor.version == QNN_TENSOR_VERSION_1) {
|
| 317 |
+
return QNN_SPARSE_PARAMS_INIT;
|
| 318 |
+
} else if (tensor.version == QNN_TENSOR_VERSION_2) {
|
| 319 |
+
return tensor.v2.sparseParams;
|
| 320 |
+
}
|
| 321 |
+
return QNN_SPARSE_PARAMS_INIT;
|
| 322 |
+
}
|
| 323 |
+
|
| 324 |
+
inline Qnn_SparseParams_t getQnnTensorSparseParams(const Qnn_Tensor_t* tensor) {
|
| 325 |
+
return getQnnTensorSparseParams(*tensor);
|
| 326 |
+
}
|
| 327 |
+
|
| 328 |
+
inline Qnn_TensorMemType_t getQnnTensorMemType(const Qnn_Tensor_t& tensor) {
|
| 329 |
+
// TensorCompatTest justifies no need to check version
|
| 330 |
+
return tensor.v1.memType;
|
| 331 |
+
}
|
| 332 |
+
|
| 333 |
+
inline Qnn_TensorMemType_t getQnnTensorMemType(const Qnn_Tensor_t* const tensor) {
|
| 334 |
+
return getQnnTensorMemType(*tensor);
|
| 335 |
+
}
|
| 336 |
+
|
| 337 |
+
inline Qnn_ClientBuffer_t getQnnTensorClientBuf(const Qnn_Tensor_t& tensor) {
|
| 338 |
+
// TensorCompatTest justifies no need to check version
|
| 339 |
+
return tensor.v1.clientBuf;
|
| 340 |
+
}
|
| 341 |
+
|
| 342 |
+
inline Qnn_ClientBuffer_t getQnnTensorClientBuf(const Qnn_Tensor_t* const tensor) {
|
| 343 |
+
return getQnnTensorClientBuf(*tensor);
|
| 344 |
+
}
|
| 345 |
+
|
| 346 |
+
inline Qnn_MemHandle_t getQnnTensorMemHandle(const Qnn_Tensor_t& tensor) {
|
| 347 |
+
// TensorCompatTest justifies no need to check version
|
| 348 |
+
return tensor.v1.memHandle;
|
| 349 |
+
}
|
| 350 |
+
|
| 351 |
+
inline Qnn_MemHandle_t getQnnTensorMemHandle(const Qnn_Tensor_t* const tensor) {
|
| 352 |
+
return getQnnTensorMemHandle(*tensor);
|
| 353 |
+
}
|
| 354 |
+
|
| 355 |
+
inline void setQnnTensorId(Qnn_Tensor_t& tensor, const uint32_t id) {
|
| 356 |
+
// TensorCompatTest justifies no need to check version
|
| 357 |
+
tensor.v1.id = id;
|
| 358 |
+
}
|
| 359 |
+
|
| 360 |
+
inline void setQnnTensorId(Qnn_Tensor_t* const tensor, const uint32_t id) {
|
| 361 |
+
setQnnTensorId(*tensor, id);
|
| 362 |
+
}
|
| 363 |
+
|
| 364 |
+
inline void setQnnTensorName(Qnn_Tensor_t& tensor, const char* const name) {
|
| 365 |
+
// TensorCompatTest justifies no need to check version
|
| 366 |
+
tensor.v1.name = name;
|
| 367 |
+
}
|
| 368 |
+
|
| 369 |
+
inline void setQnnTensorName(Qnn_Tensor_t* const tensor, const char* const name) {
|
| 370 |
+
setQnnTensorName(*tensor, name);
|
| 371 |
+
}
|
| 372 |
+
|
| 373 |
+
inline void setQnnTensorType(Qnn_Tensor_t& tensor, const Qnn_TensorType_t type) {
|
| 374 |
+
// TensorCompatTest justifies no need to check version
|
| 375 |
+
tensor.v1.type = type;
|
| 376 |
+
}
|
| 377 |
+
|
| 378 |
+
inline void setQnnTensorType(Qnn_Tensor_t* const tensor, const Qnn_TensorType_t type) {
|
| 379 |
+
setQnnTensorType(*tensor, type);
|
| 380 |
+
}
|
| 381 |
+
|
| 382 |
+
inline void setQnnTensorDataFormat(Qnn_Tensor_t& tensor, const Qnn_TensorDataFormat_t dataFormat) {
|
| 383 |
+
// TensorCompatTest justifies no need to check version
|
| 384 |
+
tensor.v1.dataFormat = dataFormat;
|
| 385 |
+
}
|
| 386 |
+
|
| 387 |
+
inline void setQnnTensorDataFormat(
|
| 388 |
+
Qnn_Tensor_t* const tensor,
|
| 389 |
+
const Qnn_TensorDataFormat_t format
|
| 390 |
+
) {
|
| 391 |
+
setQnnTensorDataFormat(*tensor, format);
|
| 392 |
+
}
|
| 393 |
+
|
| 394 |
+
inline void setQnnTensorDataType(Qnn_Tensor_t& tensor, const Qnn_DataType_t dataType) {
|
| 395 |
+
// TensorCompatTest justifies no need to check version
|
| 396 |
+
tensor.v1.dataType = dataType;
|
| 397 |
+
}
|
| 398 |
+
|
| 399 |
+
inline void setQnnTensorDataType(Qnn_Tensor_t* const tensor, const Qnn_DataType_t dataType) {
|
| 400 |
+
setQnnTensorDataType(*tensor, dataType);
|
| 401 |
+
}
|
| 402 |
+
|
| 403 |
+
inline void setQnnTensorQuantParams(
|
| 404 |
+
Qnn_Tensor_t& tensor,
|
| 405 |
+
const Qnn_QuantizeParams_t quantizeParams
|
| 406 |
+
) {
|
| 407 |
+
// TensorCompatTest justifies no need to check version
|
| 408 |
+
tensor.v1.quantizeParams = quantizeParams;
|
| 409 |
+
}
|
| 410 |
+
|
| 411 |
+
inline void setQnnTensorQuantParams(Qnn_Tensor_t* const tensor, const Qnn_QuantizeParams_t params) {
|
| 412 |
+
setQnnTensorQuantParams(*tensor, params);
|
| 413 |
+
}
|
| 414 |
+
|
| 415 |
+
inline void setQnnTensorRank(Qnn_Tensor_t& tensor, const uint32_t rank) {
|
| 416 |
+
// TensorCompatTest justifies no need to check version
|
| 417 |
+
tensor.v1.rank = rank;
|
| 418 |
+
}
|
| 419 |
+
|
| 420 |
+
inline void setQnnTensorRank(Qnn_Tensor_t* const tensor, const uint32_t rank) {
|
| 421 |
+
setQnnTensorRank(*tensor, rank);
|
| 422 |
+
}
|
| 423 |
+
|
| 424 |
+
inline void setQnnTensorDimensions(Qnn_Tensor_t& tensor, uint32_t* const dimensions) {
|
| 425 |
+
// TensorCompatTest justifies no need to check version
|
| 426 |
+
tensor.v1.dimensions = dimensions;
|
| 427 |
+
}
|
| 428 |
+
|
| 429 |
+
inline void setQnnTensorDimensions(Qnn_Tensor_t* const tensor, uint32_t* const dimensions) {
|
| 430 |
+
setQnnTensorDimensions(*tensor, dimensions);
|
| 431 |
+
}
|
| 432 |
+
|
| 433 |
+
inline void setQnnTensorIsDynamicDimensions(
|
| 434 |
+
Qnn_Tensor_t& tensor,
|
| 435 |
+
uint8_t* const isDynamicDimensions
|
| 436 |
+
) {
|
| 437 |
+
if (tensor.version == QNN_TENSOR_VERSION_2) {
|
| 438 |
+
tensor.v2.isDynamicDimensions = isDynamicDimensions;
|
| 439 |
+
}
|
| 440 |
+
}
|
| 441 |
+
|
| 442 |
+
inline void setQnnTensorIsDynamicDimensions(
|
| 443 |
+
Qnn_Tensor_t* tensor,
|
| 444 |
+
uint8_t* const isDynamicDimensions
|
| 445 |
+
) {
|
| 446 |
+
setQnnTensorIsDynamicDimensions(*tensor, isDynamicDimensions);
|
| 447 |
+
}
|
| 448 |
+
|
| 449 |
+
inline void setQnnTensorSparseParams(Qnn_Tensor_t& tensor, const Qnn_SparseParams_t sparseParams) {
|
| 450 |
+
if (tensor.version == QNN_TENSOR_VERSION_2) {
|
| 451 |
+
tensor.v2.sparseParams = sparseParams;
|
| 452 |
+
}
|
| 453 |
+
}
|
| 454 |
+
|
| 455 |
+
inline void setQnnTensorSparseParams(Qnn_Tensor_t* tensor, Qnn_SparseParams_t sparseParams) {
|
| 456 |
+
setQnnTensorSparseParams(*tensor, sparseParams);
|
| 457 |
+
}
|
| 458 |
+
|
| 459 |
+
inline void setQnnTensorMemType(Qnn_Tensor_t& tensor, const Qnn_TensorMemType_t memType) {
|
| 460 |
+
// TensorCompatTest justifies no need to check version
|
| 461 |
+
tensor.v1.memType = memType;
|
| 462 |
+
}
|
| 463 |
+
|
| 464 |
+
inline void setQnnTensorMemType(Qnn_Tensor_t* const tensor, const Qnn_TensorMemType_t memType) {
|
| 465 |
+
setQnnTensorMemType(*tensor, memType);
|
| 466 |
+
}
|
| 467 |
+
|
| 468 |
+
inline void setQnnTensorClientBuf(Qnn_Tensor_t& tensor, const Qnn_ClientBuffer_t clientBuf) {
|
| 469 |
+
// TensorCompatTest justifies no need to check version
|
| 470 |
+
tensor.v1.clientBuf = clientBuf;
|
| 471 |
+
}
|
| 472 |
+
|
| 473 |
+
inline void setQnnTensorClientBuf(Qnn_Tensor_t* const tensor, const Qnn_ClientBuffer_t clientBuf) {
|
| 474 |
+
setQnnTensorClientBuf(*tensor, clientBuf);
|
| 475 |
+
}
|
| 476 |
+
|
| 477 |
+
inline void setQnnTensorMemHandle(Qnn_Tensor_t& tensor, const Qnn_MemHandle_t memHandle) {
|
| 478 |
+
// TensorCompatTest justifies no need to check version
|
| 479 |
+
tensor.v1.memHandle = memHandle;
|
| 480 |
+
}
|
| 481 |
+
|
| 482 |
+
inline void setQnnTensorMemHandle(Qnn_Tensor_t* const tensor, const Qnn_MemHandle_t handle) {
|
| 483 |
+
setQnnTensorMemHandle(*tensor, handle);
|
| 484 |
+
}
|
| 485 |
+
|
| 486 |
+
/// Build a zero-initialized Qnn_TensorSet_t of the requested struct
/// version; the v1 union member is additionally initialized when the
/// version is recognized.
inline Qnn_TensorSet_t createQnnTensorSet(const Qnn_TensorSetVersion_t version) {
  Qnn_TensorSet_t tensorSet = QNN_TENSOR_SET_INIT;
  tensorSet.version = version;
  switch (version) {
    case QNN_TENSOR_SET_VERSION_1:
      tensorSet.v1 = QNN_TENSOR_SET_V1_INIT;
      break;
    default:
      break;
  }
  return tensorSet;
}
|
| 494 |
+
|
| 495 |
+
inline uint32_t getQnnTensorSetNumInputs(const Qnn_TensorSet_t& tensorSet) {
|
| 496 |
+
if (tensorSet.version == QNN_TENSOR_SET_VERSION_1) {
|
| 497 |
+
return tensorSet.v1.numInputs;
|
| 498 |
+
}
|
| 499 |
+
return 0;
|
| 500 |
+
}
|
| 501 |
+
|
| 502 |
+
inline uint32_t getQnnTensorSetNumInputs(const Qnn_TensorSet_t* tensorSet) {
|
| 503 |
+
return getQnnTensorSetNumInputs(*tensorSet);
|
| 504 |
+
}
|
| 505 |
+
|
| 506 |
+
inline Qnn_Tensor_t* getQnnTensorSetInputTensors(const Qnn_TensorSet_t& tensorSet) {
|
| 507 |
+
if (tensorSet.version == QNN_TENSOR_SET_VERSION_1) {
|
| 508 |
+
return tensorSet.v1.inputs;
|
| 509 |
+
}
|
| 510 |
+
return 0;
|
| 511 |
+
}
|
| 512 |
+
|
| 513 |
+
inline Qnn_Tensor_t* getQnnTensorSetInputTensors(const Qnn_TensorSet_t* tensorSet) {
|
| 514 |
+
return getQnnTensorSetInputTensors(*tensorSet);
|
| 515 |
+
}
|
| 516 |
+
|
| 517 |
+
inline uint32_t getQnnTensorSetNumOutputs(const Qnn_TensorSet_t& tensorSet) {
|
| 518 |
+
if (tensorSet.version == QNN_TENSOR_SET_VERSION_1) {
|
| 519 |
+
return tensorSet.v1.numOutputs;
|
| 520 |
+
}
|
| 521 |
+
return 0;
|
| 522 |
+
}
|
| 523 |
+
|
| 524 |
+
inline uint32_t getQnnTensorSetNumOutputs(const Qnn_TensorSet_t* tensorSet) {
|
| 525 |
+
return getQnnTensorSetNumOutputs(*tensorSet);
|
| 526 |
+
}
|
| 527 |
+
|
| 528 |
+
inline Qnn_Tensor_t* getQnnTensorSetOutputTensors(const Qnn_TensorSet_t& tensorSet) {
|
| 529 |
+
if (tensorSet.version == QNN_TENSOR_SET_VERSION_1) {
|
| 530 |
+
return tensorSet.v1.outputs;
|
| 531 |
+
}
|
| 532 |
+
return 0;
|
| 533 |
+
}
|
| 534 |
+
|
| 535 |
+
inline Qnn_Tensor_t* getQnnTensorSetOutputTensors(const Qnn_TensorSet_t* tensorSet) {
|
| 536 |
+
return getQnnTensorSetOutputTensors(*tensorSet);
|
| 537 |
+
}
|
| 538 |
+
|
| 539 |
+
inline void setQnnTensorSetInputTensors(
|
| 540 |
+
Qnn_TensorSet_t& tensorSet,
|
| 541 |
+
Qnn_Tensor_t* inputTensors,
|
| 542 |
+
uint32_t const numInputs
|
| 543 |
+
) {
|
| 544 |
+
if (tensorSet.version == QNN_TENSOR_SET_VERSION_1) {
|
| 545 |
+
tensorSet.v1.inputs = inputTensors;
|
| 546 |
+
tensorSet.v1.numInputs = numInputs;
|
| 547 |
+
}
|
| 548 |
+
}
|
| 549 |
+
|
| 550 |
+
inline void setQnnTensorSetInputTensors(
|
| 551 |
+
Qnn_TensorSet_t* tensorSet,
|
| 552 |
+
Qnn_Tensor_t* inputTensors,
|
| 553 |
+
uint32_t const numInputs
|
| 554 |
+
) {
|
| 555 |
+
setQnnTensorSetInputTensors(*tensorSet, inputTensors, numInputs);
|
| 556 |
+
}
|
| 557 |
+
|
| 558 |
+
inline void setQnnTensorSetOutputTensors(
|
| 559 |
+
Qnn_TensorSet_t& tensorSet,
|
| 560 |
+
Qnn_Tensor_t* outputTensors,
|
| 561 |
+
const uint32_t numOutputs
|
| 562 |
+
) {
|
| 563 |
+
if (tensorSet.version == QNN_TENSOR_SET_VERSION_1) {
|
| 564 |
+
tensorSet.v1.outputs = outputTensors;
|
| 565 |
+
tensorSet.v1.numOutputs = numOutputs;
|
| 566 |
+
}
|
| 567 |
+
}
|
| 568 |
+
|
| 569 |
+
inline void setQnnTensorSetOutputTensors(
|
| 570 |
+
Qnn_TensorSet_t* tensorSet,
|
| 571 |
+
Qnn_Tensor_t* outputTensors,
|
| 572 |
+
const uint32_t numOutputs
|
| 573 |
+
) {
|
| 574 |
+
setQnnTensorSetOutputTensors(*tensorSet, outputTensors, numOutputs);
|
| 575 |
+
}
|
| 576 |
+
|
| 577 |
+
// Macro wrappers over the inline helpers above. They exist so call sites
// can use the historical QNN_* macro spelling; each one forwards directly
// to the overloaded function of the same purpose (both reference and
// pointer arguments are accepted via overload resolution).

// Creator for QNN Op Config
#define QNN_OP_CFG_CREATE(version) createQnnOpConfig(version)

// Accessors for QNN Op Config
#define QNN_OP_CFG_GET_NAME(opConfig) getQnnOpConfigName(opConfig)
#define QNN_OP_CFG_GET_PACKAGE_NAME(opConfig) getQnnOpConfigPackageName(opConfig)
#define QNN_OP_CFG_GET_TYPE_NAME(opConfig) getQnnOpConfigTypeName(opConfig)
#define QNN_OP_CFG_GET_NUM_PARAMS(opConfig) getQnnOpConfigNumParams(opConfig)
#define QNN_OP_CFG_GET_PARAMS(opConfig) getQnnOpConfigParams(opConfig)
#define QNN_OP_CFG_GET_NUM_INPUTS(opConfig) getQnnOpConfigNumInputs(opConfig)
#define QNN_OP_CFG_GET_INPUTS(opConfig) getQnnOpConfigInputs(opConfig)
#define QNN_OP_CFG_GET_NUM_OUTPUTS(opConfig) getQnnOpConfigNumOutputs(opConfig)
#define QNN_OP_CFG_GET_OUTPUTS(opConfig) getQnnOpConfigOutputs(opConfig)

// Modifiers for QNN Op Config
#define QNN_OP_CFG_SET_NAME(opConfig, value) setQnnOpConfigName(opConfig, value)
#define QNN_OP_CFG_SET_PACKAGE_NAME(opConfig, value) setQnnOpConfigPackageName(opConfig, value)
#define QNN_OP_CFG_SET_TYPE_NAME(opConfig, value) setQnnOpConfigTypeName(opConfig, value)
#define QNN_OP_CFG_SET_PARAMS(opConfig, numOfParams, params) \
  setQnnOpConfigParams(opConfig, numOfParams, params)
#define QNN_OP_CFG_SET_INPUTS(opConfig, numOfInputs, inputTensors) \
  setQnnOpConfigInputs(opConfig, numOfInputs, inputTensors)
#define QNN_OP_CFG_SET_OUTPUTS(opConfig, numOfOutputs, outputTensors) \
  setQnnOpConfigOutputs(opConfig, numOfOutputs, outputTensors)

// Creator for QNN Tensor
#define QNN_TENSOR_CREATE(version) createQnnTensor(version)

// Accessors for QNN Tensor
#define QNN_TENSOR_GET_ID(tensor) getQnnTensorId(tensor)
#define QNN_TENSOR_GET_NAME(tensor) getQnnTensorName(tensor)
#define QNN_TENSOR_GET_TYPE(tensor) getQnnTensorType(tensor)
#define QNN_TENSOR_GET_DATA_FORMAT(tensor) getQnnTensorDataFormat(tensor)
#define QNN_TENSOR_GET_DATA_TYPE(tensor) getQnnTensorDataType(tensor)
#define QNN_TENSOR_GET_QUANT_PARAMS(tensor) getQnnTensorQuantParams(tensor)
#define QNN_TENSOR_GET_RANK(tensor) getQnnTensorRank(tensor)
#define QNN_TENSOR_GET_DIMENSIONS(tensor) getQnnTensorDimensions(tensor)
#define QNN_TENSOR_GET_IS_DYNAMIC_DIMENSIONS(tensor) getQnnTensorIsDynamicDimensions(tensor)
#define QNN_TENSOR_GET_SPARSE_PARAMS(tensor) getQnnTensorSparseParams(tensor)
#define QNN_TENSOR_GET_MEM_TYPE(tensor) getQnnTensorMemType(tensor)
#define QNN_TENSOR_GET_CLIENT_BUF(tensor) getQnnTensorClientBuf(tensor)
#define QNN_TENSOR_GET_MEM_HANDLE(tensor) getQnnTensorMemHandle(tensor)

// Modifiers for QNN Tensor
#define QNN_TENSOR_SET_ID(tensor, value) setQnnTensorId(tensor, value)
#define QNN_TENSOR_SET_NAME(tensor, value) setQnnTensorName(tensor, value)
#define QNN_TENSOR_SET_TYPE(tensor, value) setQnnTensorType(tensor, value)
#define QNN_TENSOR_SET_DATA_FORMAT(tensor, value) setQnnTensorDataFormat(tensor, value)
#define QNN_TENSOR_SET_DATA_TYPE(tensor, value) setQnnTensorDataType(tensor, value)
#define QNN_TENSOR_SET_QUANT_PARAMS(tensor, value) setQnnTensorQuantParams(tensor, value)
#define QNN_TENSOR_SET_RANK(tensor, value) setQnnTensorRank(tensor, value)
#define QNN_TENSOR_SET_DIMENSIONS(tensor, value) setQnnTensorDimensions(tensor, value)
#define QNN_TENSOR_SET_IS_DYNAMIC_DIMENSIONS(tensor, value) \
  setQnnTensorIsDynamicDimensions(tensor, value)
#define QNN_TENSOR_SET_SPARSE_PARAMS(tensor, value) setQnnTensorSparseParams(tensor, value)
#define QNN_TENSOR_SET_MEM_TYPE(tensor, value) setQnnTensorMemType(tensor, value)
#define QNN_TENSOR_SET_CLIENT_BUF(tensor, value) setQnnTensorClientBuf(tensor, value)
#define QNN_TENSOR_SET_MEM_HANDLE(tensor, value) setQnnTensorMemHandle(tensor, value)

// Creator for QNN Tensor Set
#define QNN_TENSORSET_CREATE(version) createQnnTensorSet(version)

// Accessors for QNN Tensor Set
#define QNN_TENSORSET_GET_NUM_INPUTS(tensorSet) getQnnTensorSetNumInputs(tensorSet)
#define QNN_TENSORSET_GET_INPUT_TENSORS(tensorSet) getQnnTensorSetInputTensors(tensorSet)
#define QNN_TENSORSET_GET_NUM_OUTPUTS(tensorSet) getQnnTensorSetNumOutputs(tensorSet)
#define QNN_TENSORSET_GET_OUTPUT_TENSORS(tensorSet) getQnnTensorSetOutputTensors(tensorSet)

// Modifiers for QNN Tensor Set
#define QNN_TENSORSET_SET_INPUT_TENSORS(tensorSet, inputTensors, numInputs) \
  setQnnTensorSetInputTensors(tensorSet, inputTensors, numInputs)
#define QNN_TENSORSET_SET_OUTPUT_TENSORS(tensorSet, outputTensors, numOutputs) \
  setQnnTensorSetOutputTensors(tensorSet, outputTensors, numOutputs)
| 651 |
+
inline bool isQnnTensorV1Compatible(const Qnn_Tensor_t& tensor) {
|
| 652 |
+
if (tensor.version == QNN_TENSOR_VERSION_2) {
|
| 653 |
+
if (tensor.v2.isDynamicDimensions != NULL) {
|
| 654 |
+
return false;
|
| 655 |
+
}
|
| 656 |
+
|
| 657 |
+
if (tensor.v2.dataFormat == QNN_TENSOR_DATA_FORMAT_SPARSE) {
|
| 658 |
+
return false;
|
| 659 |
+
}
|
| 660 |
+
}
|
| 661 |
+
|
| 662 |
+
return true;
|
| 663 |
+
}
|
| 664 |
+
|
| 665 |
+
inline bool isQnnTensorV1Compatible(const Qnn_Tensor_t* const tensor) {
|
| 666 |
+
return isQnnTensorV1Compatible(*tensor);
|
| 667 |
+
}
|
| 668 |
+
|
| 669 |
+
inline bool isQnnTensorV1Compatible(const Qnn_OpConfig_t& opConfig) {
|
| 670 |
+
if ((QNN_OP_CFG_GET_INPUTS(opConfig) != NULL) && (QNN_OP_CFG_GET_NUM_INPUTS(opConfig) > 0u)) {
|
| 671 |
+
for (uint32_t tensorIdx = 0u; tensorIdx < QNN_OP_CFG_GET_NUM_INPUTS(opConfig);
|
| 672 |
+
tensorIdx++) {
|
| 673 |
+
if (!isQnnTensorV1Compatible(QNN_OP_CFG_GET_INPUTS(opConfig)[tensorIdx])) {
|
| 674 |
+
return false;
|
| 675 |
+
}
|
| 676 |
+
}
|
| 677 |
+
}
|
| 678 |
+
if ((QNN_OP_CFG_GET_OUTPUTS(opConfig) != NULL) && (QNN_OP_CFG_GET_NUM_OUTPUTS(opConfig) > 0u)) {
|
| 679 |
+
for (uint32_t tensorIdx = 0u; tensorIdx < QNN_OP_CFG_GET_NUM_OUTPUTS(opConfig);
|
| 680 |
+
tensorIdx++) {
|
| 681 |
+
if (!isQnnTensorV1Compatible(QNN_OP_CFG_GET_OUTPUTS(opConfig)[tensorIdx])) {
|
| 682 |
+
return false;
|
| 683 |
+
}
|
| 684 |
+
}
|
| 685 |
+
}
|
| 686 |
+
if ((QNN_OP_CFG_GET_PARAMS(opConfig) != NULL) && (QNN_OP_CFG_GET_NUM_PARAMS(opConfig) > 0)) {
|
| 687 |
+
for (uint32_t paramIdx = 0u; paramIdx < QNN_OP_CFG_GET_NUM_PARAMS(opConfig); paramIdx++) {
|
| 688 |
+
const Qnn_Param_t& param = QNN_OP_CFG_GET_PARAMS(opConfig)[paramIdx];
|
| 689 |
+
if (QNN_PARAMTYPE_TENSOR == param.paramType) {
|
| 690 |
+
if (!isQnnTensorV1Compatible(param.tensorParam)) {
|
| 691 |
+
return false;
|
| 692 |
+
}
|
| 693 |
+
}
|
| 694 |
+
}
|
| 695 |
+
}
|
| 696 |
+
|
| 697 |
+
return true;
|
| 698 |
+
}
|
| 699 |
+
|
| 700 |
+
inline bool isQnnTensorV1Compatible(const Qnn_OpConfig_t* const opConfig) {
|
| 701 |
+
return isQnnTensorV1Compatible(*opConfig);
|
| 702 |
+
}
|
Genie/Genie/src/qualla/engines/qnn-api/RpcMem.cpp
ADDED
|
@@ -0,0 +1,481 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//==============================================================================
|
| 2 |
+
//
|
| 3 |
+
// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
|
| 4 |
+
// All Rights Reserved.
|
| 5 |
+
// Confidential and Proprietary - Qualcomm Technologies, Inc.
|
| 6 |
+
//
|
| 7 |
+
//==============================================================================
|
| 8 |
+
|
| 9 |
+
#include "QnnMem.h"
|
| 10 |
+
#include "QnnHtpMem.h"
|
| 11 |
+
#include "RpcMem.hpp"
|
| 12 |
+
#include "QnnTypeMacros.hpp"
|
| 13 |
+
#include "dlwrap.hpp"
|
| 14 |
+
|
| 15 |
+
// Heap id passed to rpcmem_alloc. NOTE(review): 25 presumably matches
// RPCMEM_HEAP_ID_SYSTEM from the rpcmem driver headers — confirm against
// the installed libcdsprpc headers.
#define RPCMEM_HEAP_ID_SYSTEM 25
// Allocation flags passed to rpcmem_alloc (presumably RPCMEM_DEFAULT_FLAGS
// from the same headers — confirm).
#define RPCMEM_DEFAULT_FLAGS 1

// Compile-time toggle for allocation tracing: when enabled it aliases the
// QNN_DEBUG logging macro, otherwise it compiles away to nothing.
#if 1
#define TRACE_MEMORY_ALLOC QNN_DEBUG
#else
#define TRACE_MEMORY_ALLOC(fmt, ...)
#endif
|
| 23 |
+
|
| 24 |
+
/// Construct with all dynamically-resolved entry points cleared; the
/// library handle and function pointers are populated by initialize().
RpcMem::RpcMem(Qnn_ContextHandle_t contextHandle, QNN_INTERFACE_VER_TYPE* qnnInterface)
    : m_libCdspRpc(nullptr),
      m_rpcMemAlloc(nullptr),
      m_rpcMemFree(nullptr),
      m_rpcMemToFd(nullptr),
      m_qnnInterface(qnnInterface),
      m_contextHandle(contextHandle) {
  // The context handle is stored for later use but not read yet; reference
  // it so unused-member diagnostics stay quiet.
  static_cast<void>(m_contextHandle);
}
|
| 29 |
+
|
| 30 |
+
bool RpcMem::initialize() {
|
| 31 |
+
// On Android, 32-bit and 64-bit libcdsprpc.so can be found at /vendor/lib and /vendor/lib64 respectively.
|
| 32 |
+
// On Windows, it's installed into something like this
|
| 33 |
+
// c:\Windows\System32\DriverStore\FileRepository\qcnspmcdm8380.inf_arm64_30b9cc995571de6a\libcdsprpc.dll
|
| 34 |
+
#ifdef _WIN32
|
| 35 |
+
const char* dsprpc_so = "libcdsprpc.dll";
|
| 36 |
+
#else
|
| 37 |
+
const char* dsprpc_so = "libcdsprpc.so";
|
| 38 |
+
#endif
|
| 39 |
+
|
| 40 |
+
m_libCdspRpc = dlopen(dsprpc_so, RTLD_NOW | RTLD_LOCAL);
|
| 41 |
+
if (nullptr == m_libCdspRpc) {
|
| 42 |
+
QNN_ERROR("Unable to load backend. dlerror(): %s", dlerror());
|
| 43 |
+
return false;
|
| 44 |
+
}
|
| 45 |
+
m_rpcMemAlloc = (RpcMemAllocFn_t)dlsym(m_libCdspRpc, "rpcmem_alloc");
|
| 46 |
+
m_rpcMemFree = (RpcMemFreeFn_t)dlsym(m_libCdspRpc, "rpcmem_free");
|
| 47 |
+
m_rpcMemToFd = (RpcMemToFdFn_t)dlsym(m_libCdspRpc, "rpcmem_to_fd");
|
| 48 |
+
if (nullptr == m_rpcMemAlloc || nullptr == m_rpcMemFree || nullptr == m_rpcMemToFd) {
|
| 49 |
+
QNN_ERROR("Unable to access symbols in libcdsprpc. dlerror(): %s", dlerror());
|
| 50 |
+
return false;
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
return true;
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
/// Release the dynamically-loaded DSP RPC driver library, if it was
/// successfully opened by initialize().
RpcMem::~RpcMem() {
  if (m_libCdspRpc != nullptr) {
    QNN_DEBUG("Closing libcdsprpc.so handle");
    dlclose(m_libCdspRpc);
  }
}
|
| 62 |
+
|
| 63 |
+
/// Resolve the RpcMem bookkeeping entry for a tensor. Returns nullptr when
/// the tensor is null, carries no mem handle, or the handle was never
/// registered with this RpcMem instance.
RpcMemTensorData* RpcMem::getRpcMemTensorData(Qnn_Tensor_t* tensor) {
  if (tensor == nullptr) return nullptr;
  Qnn_MemHandle_t mem_handle = QNN_TENSOR_GET_MEM_HANDLE(tensor);
  if (mem_handle == nullptr) return nullptr;
  // BUGFIX: .at() throws std::out_of_range for an unregistered handle,
  // but every caller treats a nullptr result as "not found" — look the
  // handle up explicitly instead of throwing.
  auto it = m_memHandleToRpcMem.find(mem_handle);
  if (it == m_memHandleToRpcMem.end()) return nullptr;
  return &it->second;
}
|
| 69 |
+
|
| 70 |
+
void* RpcMem::getBuffer(Qnn_Tensor_t* tensor) {
|
| 71 |
+
RpcMemTensorData* data = getRpcMemTensorData(tensor);
|
| 72 |
+
if (data == nullptr) {
|
| 73 |
+
QNN_ERROR("getBuffer : Couldn't find tensor %p", tensor);
|
| 74 |
+
return nullptr;
|
| 75 |
+
}
|
| 76 |
+
return data->memPointer;
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
// Returns the shared-memory fd backing `tensor`, or -1 (with an error log)
// when the tensor has no registered rpcmem record.
int RpcMem::getFd(Qnn_Tensor_t* tensor) {
  if (RpcMemTensorData* entry = getRpcMemTensorData(tensor)) {
    return entry->fd;
  }
  QNN_ERROR("getFd : Couldn't find tensor %p", tensor);
  return -1;
}
|
| 87 |
+
|
| 88 |
+
// Returns the byte offset of `tensor` within its (possibly fused) buffer,
// or 0 (with an error log) when the tensor has no registered rpcmem record.
size_t RpcMem::getOffset(Qnn_Tensor_t* tensor) {
  if (RpcMemTensorData* entry = getRpcMemTensorData(tensor)) {
    return entry->offset;
  }
  QNN_ERROR("getOffset : Couldn't find tensor %p", tensor);
  return 0;
}
|
| 96 |
+
|
| 97 |
+
// Returns the size in bytes of this tensor's slice of its buffer, or 0
// (with an error log) when the tensor has no registered rpcmem record.
// Fix: removed the stray ';' after the closing brace (a harmless but
// inconsistent empty declaration — no other function in this file has one).
size_t RpcMem::getBufferSize(Qnn_Tensor_t* tensor) {
  RpcMemTensorData* data = getRpcMemTensorData(tensor);
  if (data == nullptr) {
    QNN_ERROR("getBufferSize : Couldn't find tensor %p", tensor);
    return 0;
  }
  return data->size;
}
|
| 105 |
+
|
| 106 |
+
// Returns the size of the whole underlying buffer containing `tensor`
// (relevant for fused allocations), or 0 when the tensor has no
// registered rpcmem record.
size_t RpcMem::getTotalBufferSize(Qnn_Tensor_t* tensor) {
  if (RpcMemTensorData* entry = getRpcMemTensorData(tensor)) {
    return entry->totalBufferSize;
  }
  QNN_ERROR("getTotalBufferSize : Couldn't find tensor %p", tensor);
  return 0;
}
|
| 114 |
+
|
| 115 |
+
// Allocates a dedicated rpcmem buffer of `tensorDataSize` bytes for
// `tensor` and registers it with the QNN backend as an ION mem handle.
// On success the tensor's mem type/handle are set and the allocation is
// recorded in m_tensorToRpcMem; on failure any partial allocation is
// rolled back and false is returned.
// Preconditions: initialize() succeeded; `tensor` not already allocated.
bool RpcMem::allocateTensorBuffer(Qnn_Tensor_t* tensor, size_t tensorDataSize) {
  if (m_libCdspRpc == nullptr) {
    QNN_ERROR("RpcMem not initialized");
    return false;
  }
  if (!tensor) {
    QNN_ERROR("Received nullptr for tensor");
    return false;
  }
  if (m_tensorToRpcMem.find(tensor) != m_tensorToRpcMem.end()) {
    QNN_ERROR("Tensor already allocated");
    return false;
  }

  auto memPointer = m_rpcMemAlloc(RPCMEM_HEAP_ID_SYSTEM, RPCMEM_DEFAULT_FLAGS, tensorDataSize);
  auto status = true;
  if (!memPointer) {
    QNN_ERROR("rpcmem_alloc failure");
    status = false;
  }
  int memfd = -1;
  if (status == true) {
    memfd = m_rpcMemToFd(memPointer);
    if (memfd == -1) {
      QNN_ERROR("rpcmem_to_fd failure");
      status = false;
    }
  }
  if (status == true) {
    // Describe the buffer to QNN as ION memory identified by its fd.
    Qnn_MemDescriptor_t memDescriptor = {
        {QNN_TENSOR_GET_RANK(tensor), QNN_TENSOR_GET_DIMENSIONS(tensor), nullptr},
        QNN_TENSOR_GET_DATA_TYPE(tensor),
        QNN_MEM_TYPE_ION,
        {{-1}}
    };
    memDescriptor.ionInfo.fd = memfd;
    QNN_TENSOR_SET_MEM_TYPE(tensor, QNN_TENSORMEMTYPE_MEMHANDLE);
    QNN_TENSOR_SET_MEM_HANDLE(tensor, nullptr);

    Qnn_MemHandle_t memHandle = QNN_TENSOR_GET_MEM_HANDLE(tensor);
    if (QNN_SUCCESS != m_qnnInterface->memRegister(
            m_contextHandle,
            &memDescriptor,
            1,
            &(memHandle)
        )) {
      const char* tname = QNN_TENSOR_GET_NAME(tensor);
      QNN_ERROR("memRegister fail %s (ctx=%p fd=%d)", tname, m_contextHandle, memfd);
      status = false;
    }
    QNN_TENSOR_SET_MEM_HANDLE(tensor, memHandle);
  }
  if (status == true) {
    m_tensorToRpcMem.insert({tensor, RpcMemTensorData(memfd, memPointer, tensorDataSize)});
  }
  if (status == false) {
    // Roll back the rpcmem allocation. Fix: also guard on memPointer —
    // the alloc itself may have been the failing step, and the old code
    // passed a null pointer to rpcmem_free in that case.
    if (m_rpcMemFree && memPointer) {
      m_rpcMemFree(memPointer);
    }
  }
  return status;
}
|
| 177 |
+
|
| 178 |
+
// Releases the rpcmem backing of `tensor`.
// Two cases:
//  - `tensor` is an alias created by useSameMemory(): only the bookkeeping
//    entry is dropped; the mem handle and the underlying memory belong to
//    the original (parent) tensor and must NOT be deregistered/freed here.
//  - `tensor` owns its buffer: the mem handle is deregistered with the
//    backend and the rpcmem allocation is freed.
// Returns false if the tensor is null, unknown, or deregistration fails.
bool RpcMem::freeTensorBuffer(Qnn_Tensor_t* tensor) {
  if (!tensor) {
    QNN_ERROR("Received nullptr for tensor");
    return false;
  }

  if (m_sameMemoryFreeTensors.find(tensor) != m_sameMemoryFreeTensors.end()) {
    // Alias: the parent tensor still owns the registration and the memory.
    if (m_tensorToRpcMem.find(tensor) == m_tensorToRpcMem.end()) {
      QNN_ERROR("Tensor not found");
      return false;
    }
    m_tensorToRpcMem.erase(tensor);
  } else {
    // Owner: deregister with the backend first, then free the memory.
    auto memHandle = QNN_TENSOR_GET_MEM_HANDLE(tensor);
    if (QNN_SUCCESS != m_qnnInterface->memDeRegister(&memHandle, 1)) {
      QNN_ERROR("Failed to deregister ion memory with the backend");
      return false;
    }
    QNN_TENSOR_SET_MEM_TYPE(tensor, QNN_TENSORMEMTYPE_UNDEFINED);
    if (m_tensorToRpcMem.find(tensor) == m_tensorToRpcMem.end()) {
      QNN_ERROR("Tensor not found");
      return false;
    }
    // m_rpcMemFree may be null if initialize() failed part-way.
    if (m_rpcMemFree) {
      m_rpcMemFree(m_tensorToRpcMem[tensor].memPointer);
    }
    m_tensorToRpcMem.erase(tensor);
  }

  return true;
}
|
| 209 |
+
|
| 210 |
+
// Makes `dest` an alias of `src`'s registered shared memory: dest's prior
// backing (if any) is released, dest takes src's mem type/handle, and dest
// is recorded as an alias so freeing it never frees the shared buffer.
bool RpcMem::useSameMemory(Qnn_Tensor_t* dest, Qnn_Tensor_t* src) {
  if (dest == nullptr || src == nullptr) {
    QNN_ERROR("Received nullptr");
    return false;
  }
  if (m_tensorToRpcMem.count(src) == 0) {
    QNN_ERROR("Src Tensor not found");
    return false;
  }

  // Release whatever dest currently owns before aliasing it to src.
  if (!freeTensorBuffer(dest)) {
    return false;
  }

  QNN_TENSOR_SET_MEM_TYPE(dest, QNN_TENSOR_GET_MEM_TYPE(src));
  QNN_TENSOR_SET_MEM_HANDLE(dest, QNN_TENSOR_GET_MEM_HANDLE(src));
  m_tensorToRpcMem.insert({dest, m_tensorToRpcMem[src]});
  m_sameMemoryFreeTensors.insert(dest);

  return true;
}
|
| 231 |
+
|
| 232 |
+
// Aliases `dest` onto `src`'s registered memory, like the two-argument
// overload.
// NOTE(review): the `offset` parameter is accepted but never used — dest is
// mapped at src's base address, not src + offset. The body is byte-for-byte
// identical to the two-argument overload; confirm whether offset support
// was intended here.
bool RpcMem::useSameMemory(Qnn_Tensor_t* dest, Qnn_Tensor_t* src, int offset) {
  if (nullptr == dest || nullptr == src) {
    QNN_ERROR("Received nullptr");
    return false;
  }
  if (m_tensorToRpcMem.find(src) == m_tensorToRpcMem.end()) {
    QNN_ERROR("Src Tensor not found");
    return false;
  }

  // Release whatever dest currently owns before aliasing it to src.
  if (false == freeTensorBuffer(dest)) {
    return false;
  }

  QNN_TENSOR_SET_MEM_TYPE(dest, QNN_TENSOR_GET_MEM_TYPE(src));
  QNN_TENSOR_SET_MEM_HANDLE(dest, QNN_TENSOR_GET_MEM_HANDLE(src));
  m_tensorToRpcMem.insert({dest, m_tensorToRpcMem[src]});
  m_sameMemoryFreeTensors.insert(dest);

  return true;
}
|
| 253 |
+
|
| 254 |
+
bool RpcMem::useExternalMemory(Qnn_Tensor_t* dest, void* extMem) {
|
| 255 |
+
QNN_ERROR("We don't support external memory feature for shared buffers yet!");
|
| 256 |
+
return false;
|
| 257 |
+
}
|
| 258 |
+
|
| 259 |
+
// Allocates one fused (multi-tensor) rpcmem chunk of `bufferSize` bytes,
// records it in m_fusedBuffers and returns its base pointer; `*fd` receives
// the chunk's shared-memory fd (-1 on failure).
// Fix: the fd is now resolved BEFORE the chunk is recorded. Previously an
// rpcmem_to_fd failure left an fd-less entry in m_fusedBuffers, breaking
// the index pairing with m_fusedFds that allocateBuffers()/the alloc_idx
// overload of mapFusedBufferOffset rely on; the memory is now freed
// immediately on that path instead.
void* RpcMem::allocateTensorFusedBuffer(uint64_t bufferSize, int32_t* fd) {
  *fd = -1;
  if (m_libCdspRpc == nullptr) {
    QNN_ERROR("RpcMem not initialized for fused buffer");
    return nullptr;
  }

  void* memPointer = m_rpcMemAlloc(RPCMEM_HEAP_ID_SYSTEM, RPCMEM_DEFAULT_FLAGS, bufferSize);
  if (!memPointer) {
    QNN_ERROR("Not able to allocate fused buffer of size: %lu", (unsigned long)bufferSize);
    return nullptr;
  }

  if ((*fd = m_rpcMemToFd(memPointer)) == -1) {
    QNN_ERROR(
        "Not able to get fd for the fused buffer of size: %lu", (unsigned long)bufferSize
    );
    if (m_rpcMemFree) m_rpcMemFree(memPointer);  // roll back — nothing was recorded yet
    return nullptr;
  }

  m_fusedBuffers.push_back({memPointer, bufferSize});
  QNN_DEBUG(
      "Successfully allocated fused buffer at %p with size %lu",
      memPointer,
      (unsigned long)bufferSize
  );
  QNN_DEBUG("Retrieved fd %d for pointer %p", *fd, memPointer);
  return memPointer;
}
|
| 289 |
+
|
| 290 |
+
bool RpcMem::allocateBuffers(
|
| 291 |
+
const std::map<int, std::map<std::string, size_t>>& allocs_per_chunk,
|
| 292 |
+
std::map<std::string, std::pair<int, size_t>>& tensor_offsets
|
| 293 |
+
) {
|
| 294 |
+
int alloc_chunk_idx = m_fusedBuffers.size();
|
| 295 |
+
int num_alloc_chunks = 0;
|
| 296 |
+
size_t total_alloc_size = 0;
|
| 297 |
+
|
| 298 |
+
for (auto& [_, tensor_sizes] : allocs_per_chunk) {
|
| 299 |
+
// Calculate total allocation chunk size
|
| 300 |
+
size_t alloc_chunk_size = 0;
|
| 301 |
+
for (const auto& [tensor_name, tensor_size] : tensor_sizes) {
|
| 302 |
+
tensor_offsets[tensor_name] = {alloc_chunk_idx, alloc_chunk_size};
|
| 303 |
+
alloc_chunk_size += tensor_size;
|
| 304 |
+
}
|
| 305 |
+
|
| 306 |
+
// Allocate chunk for this unique context set
|
| 307 |
+
if (alloc_chunk_size <= 0) {
|
| 308 |
+
QNN_ERROR("Unexpected chunk size detected. Please re-check IO allocations");
|
| 309 |
+
return false;
|
| 310 |
+
}
|
| 311 |
+
|
| 312 |
+
m_fusedFds.push_back(0);
|
| 313 |
+
if (!allocateTensorFusedBuffer(alloc_chunk_size, &m_fusedFds.back())) //
|
| 314 |
+
return false;
|
| 315 |
+
total_alloc_size += alloc_chunk_size;
|
| 316 |
+
alloc_chunk_idx++;
|
| 317 |
+
num_alloc_chunks++;
|
| 318 |
+
}
|
| 319 |
+
QNN_INFO(
|
| 320 |
+
"Allocated total size = %lu across %d buffers",
|
| 321 |
+
(unsigned long)total_alloc_size,
|
| 322 |
+
num_alloc_chunks
|
| 323 |
+
);
|
| 324 |
+
return true;
|
| 325 |
+
}
|
| 326 |
+
|
| 327 |
+
// Binds `tensor` to a slice (fd + offset) of a fused shared buffer by
// registering (or reusing) a QNN mem handle for that slice in `contextHandle`.
// Three paths:
//  1. Tensor already has a handle: if it already points at this fd/offset,
//     nothing to do; otherwise the old handle is deregistered and a new one
//     is registered below.
//  2. Tensor has no handle yet: if another tensor in the same context was
//     already registered for this exact (fd, offset, context), its handle is
//     shared and no new registration is made.
//  3. Otherwise a fresh HTP shared-buffer descriptor is registered and
//     recorded in m_memHandleToRpcMem / memConfigList.
// Returns false on any backend register/deregister failure.
bool RpcMem::mapFusedBufferOffset(
    Qnn_Tensor_t* tensor,
    size_t tensorDataSize,
    int32_t fd,
    uint32_t offset,
    uint64_t totalBufferSize,
    void* memPointer,
    Qnn_ContextHandle_t contextHandle
) {
  if (m_libCdspRpc == nullptr) {
    QNN_ERROR("RpcMem not initialized");
    return false;
  }
  if (!tensor) {
    QNN_ERROR("Received nullptr for tensor");
    return false;
  }

  Qnn_ErrorHandle_t ret;
  const char* tname = QNN_TENSOR_GET_NAME(tensor);

  // Check if tensor already has a memHandle assigned
  Qnn_MemHandle_t cur_mem_handle = QNN_TENSOR_GET_MEM_HANDLE(tensor);
  if (cur_mem_handle != nullptr) {
    // Check if memHandle is already identical to requested buffer and offset.
    // NOTE(review): .at() throws std::out_of_range if this handle was never
    // recorded in m_memHandleToRpcMem — confirm that cannot happen here.
    RpcMemTensorData& cur_rpc_mem_data = m_memHandleToRpcMem.at(cur_mem_handle);
    if (cur_rpc_mem_data.fd == fd && cur_rpc_mem_data.offset == offset) {
      return true;
    }

    // updated offset, deregister previous mem_handle.
    // tensorDataSize == 0 means "keep the previously recorded size".
    if (tensorDataSize == 0) tensorDataSize = cur_rpc_mem_data.size;
    // clang-format off
    TRACE_MEMORY_ALLOC( "memDeRegister %-20s (fd=%d offset=%lu) memHandle=%p",
        tname, cur_rpc_mem_data.fd, cur_rpc_mem_data.offset, cur_mem_handle);
    // clang-format on
    // Erase the record first; cur_rpc_mem_data is dangling after this line.
    m_memHandleToRpcMem.erase(cur_mem_handle);
    if ((ret = m_qnnInterface->memDeRegister(&cur_mem_handle, 1)) != QNN_SUCCESS) {
      QNN_ERROR(
          "memDeRegister ERROR(%lu) - %s memHandle=%p",
          (unsigned long)ret,
          tname,
          cur_mem_handle
      );
      return false;
    }
  } else {
    // For inital tensors, we need to check if the tensor can re-use a memHandle
    // from another tensor in the same context
    auto memConfig = std::make_tuple(fd, offset, contextHandle);
    if (memConfigList.contains(memConfig)) {
      auto& parentTensor = memConfigList[memConfig];
      Qnn_MemHandle_t parentMemHandle = QNN_TENSOR_GET_MEM_HANDLE(parentTensor);
      QNN_TENSOR_SET_MEM_TYPE(tensor, QNN_TENSORMEMTYPE_MEMHANDLE);
      QNN_TENSOR_SET_MEM_HANDLE(tensor, parentMemHandle);
      TRACE_MEMORY_ALLOC("%-20s : Mapping to memHandle %p", tname, parentMemHandle);
      return true;
    }
  }

  // Register a new memHandle based on function arguments
  QnnMemHtp_Descriptor_t htp_mem_desciptor = {QNN_HTP_MEM_SHARED_BUFFER, totalBufferSize, {0}};
  htp_mem_desciptor.sharedBufferConfig.fd = fd;
  htp_mem_desciptor.sharedBufferConfig.offset = offset;

  Qnn_MemDescriptor_t mem_descriptor = {
      {QNN_TENSOR_GET_RANK(tensor), QNN_TENSOR_GET_DIMENSIONS(tensor), nullptr},
      QNN_TENSOR_GET_DATA_TYPE(tensor),
      QNN_MEM_TYPE_CUSTOM,
      {{-1}}
  };
  // customInfo points at the stack descriptor above — only valid during the
  // memRegister call below.
  mem_descriptor.customInfo = &htp_mem_desciptor;

  Qnn_MemHandle_t mem_handle = nullptr;
  ret = m_qnnInterface->memRegister(contextHandle, &mem_descriptor, 1, &mem_handle);
  if (ret != QNN_SUCCESS) {
    QNN_ERROR("%-20s (ctx=%p fd=%d offset=%u)", tname, contextHandle, fd, offset);
    QNN_ERROR("memRegister ERROR(%lu)", (unsigned long)ret);
    return false;
  }

  // clang-format off
  TRACE_MEMORY_ALLOC("%-20s (ctx=%p fd=%d offset=%u) memPointer=%p memHandle=%p",
      tname, contextHandle, fd, offset, ((uint8_t*)memPointer) + offset, mem_handle);
  // clang-format on
  m_memHandleToRpcMem[mem_handle] = RpcMemTensorData(
      fd, ((uint8_t*)memPointer) + offset, tensorDataSize, totalBufferSize, offset
  );

  QNN_TENSOR_SET_MEM_TYPE(tensor, QNN_TENSORMEMTYPE_MEMHANDLE);
  QNN_TENSOR_SET_MEM_HANDLE(tensor, mem_handle);
  if (cur_mem_handle == nullptr) // Cache memory config for initial memRegisters only
    memConfigList[std::make_tuple(fd, offset, contextHandle)] = tensor;

  return true;
}
|
| 423 |
+
|
| 424 |
+
// Convenience overload: maps `tensor` at `offset` within the fused chunk
// identified by `alloc_idx` (as returned via allocateBuffers), forwarding
// the chunk's fd, base pointer and total size to the full overload.
bool RpcMem::mapFusedBufferOffset(
    Qnn_Tensor_t* tensor,
    int alloc_idx,
    size_t offset,
    Qnn_ContextHandle_t ctx,
    size_t size
) {
  const auto& [basePointer, totalSize] = m_fusedBuffers[alloc_idx];
  return mapFusedBufferOffset(
      tensor, size, m_fusedFds[alloc_idx], offset, totalSize, basePointer, ctx
  );
}
|
| 441 |
+
|
| 442 |
+
// Detaches `tensor` from its fused buffer: clears its mem type/handle and
// drops the bookkeeping entry. Returns false for null or unknown tensors.
bool RpcMem::deregisterTensorFusedBuffer(Qnn_Tensor_t* tensor) {
  if (tensor == nullptr) {
    QNN_ERROR("Received nullptr for tensor");
    return false;
  }
  if (m_tensorToRpcMem.count(tensor) == 0) {
    QNN_ERROR("Tensor not found");
    return false;
  }

  // Mem handles are intentionally NOT deregistered here: they were already
  // released when freeContext() ran in the QnnApi destructor, which happens
  // before this point.
  QNN_TENSOR_SET_MEM_TYPE(tensor, QNN_TENSORMEMTYPE_UNDEFINED);
  QNN_TENSOR_SET_MEM_HANDLE(tensor, nullptr);
  m_tensorToRpcMem.erase(tensor);
  return true;
}
|
| 469 |
+
|
| 470 |
+
void RpcMem::freeFusedBuffers() {
|
| 471 |
+
// for (auto& memHandle : m_orphanedMemHandles) {
|
| 472 |
+
// if (QNN_SUCCESS != m_qnnInterface->memDeRegister(&memHandle, 1)) {
|
| 473 |
+
// QNN_ERROR("Failed to deregister ion memory with the backend");
|
| 474 |
+
// }
|
| 475 |
+
// }
|
| 476 |
+
|
| 477 |
+
for (auto& [mem_ptr, buffer_size] : m_fusedBuffers) {
|
| 478 |
+
QNN_DEBUG("Freeing fused buffer %p (size=%lu)", mem_ptr, buffer_size);
|
| 479 |
+
m_rpcMemFree(mem_ptr);
|
| 480 |
+
}
|
| 481 |
+
}
|
Genie/Genie/src/qualla/engines/qnn-api/RpcMem.hpp
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//==============================================================================
|
| 2 |
+
//
|
| 3 |
+
// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
|
| 4 |
+
// All Rights Reserved.
|
| 5 |
+
// Confidential and Proprietary - Qualcomm Technologies, Inc.
|
| 6 |
+
//
|
| 7 |
+
//==============================================================================
|
| 8 |
+
|
| 9 |
+
#pragma once
|
| 10 |
+
|
| 11 |
+
#include <unordered_set>
|
| 12 |
+
|
| 13 |
+
#include "IBufferAlloc.hpp"
|
| 14 |
+
#include "QnnInterface.h"
|
| 15 |
+
#include "Log.hpp"
|
| 16 |
+
|
| 17 |
+
// Signatures of the rpcmem entry points resolved from libcdsprpc via dlsym:
// rpcmem_alloc(heapid, flags, size), rpcmem_free(ptr), rpcmem_to_fd(ptr).
// Modernized from `typedef` to `using` (the rest of the codebase is modern
// C++ — structured bindings, std::map::contains, etc.).
using RpcMemAllocFn_t = void* (*)(int, uint32_t, int);
using RpcMemFreeFn_t = void (*)(void*);
using RpcMemToFdFn_t = int (*)(void*);
|
| 20 |
+
|
| 21 |
+
// Bookkeeping record for one rpcmem allocation / registration.
// Fix: the default and three-argument constructors previously left
// totalBufferSize and offset uninitialized, yet getTotalBufferSize() and
// getOffset() read those fields for tensors created through
// allocateTensorBuffer() — undefined behavior. All members are now
// initialized on every path.
struct RpcMemTensorData {
  int fd;                  // shared-memory fd from rpcmem_to_fd (-1 if none)
  void* memPointer;        // CPU-visible address of this slice
  size_t size;             // size of this tensor's slice in bytes
  size_t totalBufferSize;  // size of the whole underlying (fused) buffer
  size_t offset;           // byte offset of this slice within the buffer
  RpcMemTensorData() : fd(-1), memPointer(nullptr), size(0), totalBufferSize(0), offset(0) {}
  RpcMemTensorData(int fdIn, void* memPointerIn, size_t sizeIn)
      : fd(fdIn), memPointer(memPointerIn), size(sizeIn), totalBufferSize(0), offset(0) {}
  RpcMemTensorData(
      int fdIn,
      void* memPointerIn,
      size_t sizeIn,
      size_t totalBufferSizeIn,
      size_t offsetIn
  )
      : fd(fdIn), memPointer(memPointerIn), size(sizeIn), totalBufferSize(totalBufferSizeIn),
        offset(offsetIn) {}
};
|
| 40 |
+
|
| 41 |
+
// RpcMem: IBufferAlloc implementation backed by Qualcomm's rpcmem shared
// memory allocator (libcdsprpc), used to hand zero-copy buffers to the QNN
// backend. Supports both per-tensor allocations and large "fused" chunks
// shared by many tensors at different offsets. Non-copyable, non-movable
// (owns a dlopen handle and backend registrations).
class RpcMem final : public IBufferAlloc {
 public:
  RpcMem(Qnn_ContextHandle_t contextHandle, QNN_INTERFACE_VER_TYPE* qnnInterface);
  // Disable copy constructors, r-value referencing, etc
  RpcMem(const RpcMem&) = delete;
  RpcMem& operator=(const RpcMem&) = delete;
  RpcMem(RpcMem&&) = delete;
  RpcMem& operator=(RpcMem&&) = delete;
  // Loads libcdsprpc and resolves rpcmem_alloc/free/to_fd; false on failure.
  bool initialize() override;
  // Per-tensor accessors; each returns a sentinel (nullptr / -1 / 0) and
  // logs when the tensor has no registered rpcmem backing.
  void* getBuffer(Qnn_Tensor_t* tensor) override;
  int getFd(Qnn_Tensor_t* tensor) override;

  size_t getOffset(Qnn_Tensor_t* tensor) override;

  size_t getBufferSize(Qnn_Tensor_t* tensor) override;

  size_t getTotalBufferSize(Qnn_Tensor_t* tensor) override;

  // Dedicated (non-fused) allocation, registered with the backend as ION.
  bool allocateTensorBuffer(Qnn_Tensor_t* tensor, size_t tensorDataSize) override;

  bool freeTensorBuffer(Qnn_Tensor_t* tensor) override;
  // Alias dest onto src's registered memory (dest is tracked so freeing it
  // never releases the shared backing).
  bool useSameMemory(Qnn_Tensor_t* dest, Qnn_Tensor_t* src) override;
  bool useSameMemory(Qnn_Tensor_t* dest, Qnn_Tensor_t* src, int offset) override;

  // Not supported for shared buffers; always returns false.
  bool useExternalMemory(Qnn_Tensor_t* dest, void* extMem) override;

  // Fused-buffer API: allocate large chunks, then map tensors at offsets.
  void* allocateTensorFusedBuffer(uint64_t bufferSize, int32_t* fd) override;
  bool allocateBuffers(
      const std::map<int, std::map<std::string, size_t>>& allocs_per_chunk,
      std::map<std::string, std::pair<int, size_t>>& tensor_offsets
  ) override;

  bool mapFusedBufferOffset(
      Qnn_Tensor_t* tensor,
      size_t tensorDataSize,
      int32_t fd,
      uint32_t offset,
      uint64_t totalBufferSize,
      void* memPointer,
      Qnn_ContextHandle_t contextHandle
  ) override;
  bool deregisterTensorFusedBuffer(Qnn_Tensor_t* tensor) override;
  void freeFusedBuffers() override;
  bool mapFusedBufferOffset(
      Qnn_Tensor_t* tensor,
      int alloc_idx,
      size_t offset,
      Qnn_ContextHandle_t ctx,
      size_t size
  ) override;
  virtual ~RpcMem();

 private:
  // Looks up the bookkeeping record behind a tensor's mem handle.
  RpcMemTensorData* getRpcMemTensorData(Qnn_Tensor_t* tensor);

  // Pointer to the dlopen'd libcdsprpc.so shared library which contains
  // rpcmem_alloc, rpcmem_free, rpcmem_to_fd APIs
  void* m_libCdspRpc;
  // Function pointer to rpcmem_alloc
  RpcMemAllocFn_t m_rpcMemAlloc;
  // Function pointer to rpcmem_free
  RpcMemFreeFn_t m_rpcMemFree;
  // Function pointer to rpcmem_to_fd
  RpcMemToFdFn_t m_rpcMemToFd;
  QNN_INTERFACE_VER_TYPE* m_qnnInterface;
  Qnn_ContextHandle_t m_contextHandle;

  // Per-tensor allocation records (dedicated buffers and aliases).
  std::unordered_map<Qnn_Tensor_t*, RpcMemTensorData> m_tensorToRpcMem;
  // Tensors that merely alias another tensor's memory; freeing them must
  // not deregister/free the shared backing.
  std::unordered_set<Qnn_Tensor_t*> m_sameMemoryFreeTensors;
  std::vector<std::pair<void*, size_t>> m_fusedBuffers; // vector<<memPointer, bufferSize>>
  // fds paired index-for-index with m_fusedBuffers.
  std::vector<int32_t> m_fusedFds;
  std::unordered_set<Qnn_MemHandle_t> m_orphanedMemHandles;
  // Records keyed by backend mem handle (fused-buffer registrations).
  std::unordered_map<Qnn_MemHandle_t, RpcMemTensorData> m_memHandleToRpcMem;
  // (fd, offset, context) -> first tensor registered there, so later
  // tensors with the same config can share its mem handle.
  std::map<std::tuple<int, size_t, Qnn_ContextHandle_t>, Qnn_Tensor_t*> memConfigList;
};
|
Genie/Genie/src/qualla/engines/qnn-api/dlwrap.cpp
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//==============================================================================
|
| 2 |
+
//
|
| 3 |
+
// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
|
| 4 |
+
// All Rights Reserved.
|
| 5 |
+
// Confidential and Proprietary - Qualcomm Technologies, Inc.
|
| 6 |
+
//
|
| 7 |
+
//==============================================================================
|
| 8 |
+
|
| 9 |
+
#ifdef _WIN32
|
| 10 |
+
|
| 11 |
+
#pragma warning(disable : 4133 4996)
|
| 12 |
+
|
| 13 |
+
#include <inttypes.h>
|
| 14 |
+
#include <stdio.h>
|
| 15 |
+
#include <stdlib.h>
|
| 16 |
+
#include <string.h>
|
| 17 |
+
#include <windows.h>
|
| 18 |
+
#include <wchar.h>
|
| 19 |
+
|
| 20 |
+
#include "dlwrap.hpp"
|
| 21 |
+
|
| 22 |
+
static const char* last_func;
|
| 23 |
+
static long last_err;
|
| 24 |
+
|
| 25 |
+
void* dlopen(const char* dll, int flags) {
|
| 26 |
+
HINSTANCE h = LoadLibraryA(dll);
|
| 27 |
+
if (h == NULL) {
|
| 28 |
+
last_err = GetLastError();
|
| 29 |
+
last_func = "dlopen";
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
return h;
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
int dlclose(void* h) {
|
| 36 |
+
if (!FreeLibrary((HINSTANCE)h)) {
|
| 37 |
+
last_err = GetLastError();
|
| 38 |
+
last_func = "dlclose";
|
| 39 |
+
return -1;
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
return 0;
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
void* dlsym(void* h, const char* name) {
|
| 46 |
+
FARPROC p = GetProcAddress((HINSTANCE)h, name);
|
| 47 |
+
if (!p) {
|
| 48 |
+
last_err = GetLastError();
|
| 49 |
+
last_func = "dlsym";
|
| 50 |
+
}
|
| 51 |
+
return (void*)(intptr_t)p;
|
| 52 |
+
}
|
| 53 |
+
|
| 54 |
+
const char* dlerror(void) {
|
| 55 |
+
static char str[88];
|
| 56 |
+
|
| 57 |
+
if (!last_err) return NULL;
|
| 58 |
+
|
| 59 |
+
sprintf(str, "%s error #%ld", last_func, last_err);
|
| 60 |
+
last_err = 0;
|
| 61 |
+
last_func = NULL;
|
| 62 |
+
|
| 63 |
+
return str;
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
#endif // _WIN32
|
Genie/Genie/src/qualla/engines/qnn-api/dlwrap.hpp
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//==============================================================================
|
| 2 |
+
//
|
| 3 |
+
// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
|
| 4 |
+
// All Rights Reserved.
|
| 5 |
+
// Confidential and Proprietary - Qualcomm Technologies, Inc.
|
| 6 |
+
//
|
| 7 |
+
//==============================================================================
|
| 8 |
+
|
| 9 |
+
#ifndef DLWRAP_HPP
|
| 10 |
+
#define DLWRAP_HPP
|
| 11 |
+
|
| 12 |
+
#ifndef _WIN32
|
| 13 |
+
|
| 14 |
+
// Just include regular dlfcn
|
| 15 |
+
#include <dlfcn.h>
|
| 16 |
+
|
| 17 |
+
#else // _WIN32
|
| 18 |
+
|
| 19 |
+
// Define basic set dl functions and flags
|
| 20 |
+
|
| 21 |
+
// Minimal stand-ins for the POSIX dlopen() mode flags. The Windows loader
// has no equivalent knobs, so the dlopen() shim ignores these values; they
// exist only so shared call sites compile unchanged.
#define RTLD_GLOBAL 0x100
#define RTLD_LOCAL 0x000
#define RTLD_LAZY 0x000
#define RTLD_NOW 0x001

// POSIX-style wrappers implemented with LoadLibrary/GetProcAddress in
// dlwrap.cpp. dlerror() returns NULL when no error is pending.
void* dlopen(const char* filename, int flag);
int dlclose(void* handle);
void* dlsym(void* handle, const char* name);
const char* dlerror(void);
|
| 30 |
+
|
| 31 |
+
#endif // _WIN32
|
| 32 |
+
|
| 33 |
+
#endif // DLWRAP_HPP
|
Genie/Genie/src/qualla/engines/qnn-api/qnn-utils.cpp
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//==============================================================================
|
| 2 |
+
//
|
| 3 |
+
// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
|
| 4 |
+
// All Rights Reserved.
|
| 5 |
+
// Confidential and Proprietary - Qualcomm Technologies, Inc.
|
| 6 |
+
//
|
| 7 |
+
//==============================================================================
|
| 8 |
+
|
| 9 |
+
#include "qnn-utils.hpp"
|
| 10 |
+
|
| 11 |
+
#include <string>
|
| 12 |
+
#include <fstream>
|
| 13 |
+
#include <filesystem>
|
| 14 |
+
#include <sstream>
|
| 15 |
+
#include "QnnApi.hpp"
|
| 16 |
+
#include <fmt/format.h>
|
| 17 |
+
|
| 18 |
+
namespace fs = std::filesystem;
|
| 19 |
+
|
| 20 |
+
namespace qualla {
|
| 21 |
+
namespace QnnUtils {
|
| 22 |
+
// Alternate implementation for bw() = lambda x: (10 * ((x & 0xf0)>>4) + (x & 0xf)) // 8
// Byte width of the element type via QnnApi, or -1 when undefined.
int DataType::bw() { return (_dtype == QNN_DATATYPE_UNDEFINED) ? -1 : QnnApi::getDataTypeSize(_dtype);}
// Upper bits of the raw enum value (see the encoding note above), or -1
// when undefined.
int DataType::type() {return (_dtype == QNN_DATATYPE_UNDEFINED) ? -1 : _dtype >> 4; }

// Raw Qnn_DataType_t enum value as a signed 32-bit integer.
int32_t DataType::val() { return static_cast<int32_t>(_dtype); }
|
| 27 |
+
|
| 28 |
+
// Writes `size` bytes from `data` to `path`, creating missing parent
// directories first. Returns false if a parent directory could not be
// created or the write did not complete.
// Fixes: (1) a bare filename has an empty parent_path(), and the old code
// passed "" to create_directories, which fails — so writing to a plain
// relative filename could never succeed; (2) the stream state is now
// checked instead of unconditionally returning true.
bool writeRawData(void* data, size_t size, const std::filesystem::path& path) {
  auto parent = path.parent_path();
  // Empty parent means "current directory": nothing to create.
  if (!parent.empty() && !std::filesystem::exists(parent) &&
      !std::filesystem::create_directories(parent)) {
    return false;
  }

  std::ofstream f(path, std::ofstream::binary);
  f.write((char*)data, size);
  f.close();

  return f.good();
}
|
| 38 |
+
|
| 39 |
+
// Reads exactly `size` bytes from `path` into `data`.
// Throws std::runtime_error when the on-disk size does not match `size`
// (and propagates std::filesystem::filesystem_error when `path` does not
// exist). Returns false if the read itself fails.
// Fixes: (1) fs::file_size was called twice (two stat calls) — now once;
// (2) the old code returned true even when the read failed — the stream
// state is now checked; (3) the message is built with the standard library
// (ostringstream) instead of the third-party fmt, with identical text.
bool readRawData(void* data, size_t size, const std::filesystem::path& path) {
  const auto on_disk = std::filesystem::file_size(path);
  if (on_disk != size) {
    std::ostringstream msg;
    msg << "file size doesnot match: " << path.string() << " size " << on_disk
        << ", buf-size " << size;
    throw std::runtime_error(msg.str());
  }

  std::ifstream f(path, std::ifstream::binary);
  f.read((char*)data, size);
  const bool ok = !f.fail();
  f.close();

  return ok;
}
|
| 55 |
+
|
| 56 |
+
void getQuantParamString(
|
| 57 |
+
const std::vector<QuantParam>& quantParam,
|
| 58 |
+
std::string& scale_string,
|
| 59 |
+
std::string& offset_string
|
| 60 |
+
) {
|
| 61 |
+
std::ostringstream scales_s;
|
| 62 |
+
std::ostringstream offsets_s;
|
| 63 |
+
for (int i = 0; i < quantParam.size(); i++) {
|
| 64 |
+
if (i != 0) {
|
| 65 |
+
scales_s << ", ";
|
| 66 |
+
offsets_s << ", ";
|
| 67 |
+
}
|
| 68 |
+
scales_s << std::fixed << std::setprecision(20) << quantParam[i].scale;
|
| 69 |
+
offsets_s << quantParam[i].offset;
|
| 70 |
+
}
|
| 71 |
+
scale_string = std::move(scales_s.str());
|
| 72 |
+
offset_string = std::move(offsets_s.str());
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
const char* DataType::str() {
|
| 76 |
+
// clang-format off
|
| 77 |
+
switch (_dtype) {
|
| 78 |
+
case QNN_DATATYPE_INT_8: return "QNN_DATATYPE_INT_8";
|
| 79 |
+
case QNN_DATATYPE_INT_16: return "QNN_DATATYPE_INT_16";
|
| 80 |
+
case QNN_DATATYPE_INT_32: return "QNN_DATATYPE_INT_32";
|
| 81 |
+
case QNN_DATATYPE_INT_64: return "QNN_DATATYPE_INT_64";
|
| 82 |
+
case QNN_DATATYPE_UINT_8: return "QNN_DATATYPE_UINT_8";
|
| 83 |
+
case QNN_DATATYPE_UINT_16: return "QNN_DATATYPE_UINT_16";
|
| 84 |
+
case QNN_DATATYPE_UINT_32: return "QNN_DATATYPE_UINT_32";
|
| 85 |
+
case QNN_DATATYPE_UINT_64: return "QNN_DATATYPE_UINT_64";
|
| 86 |
+
case QNN_DATATYPE_FLOAT_16: return "QNN_DATATYPE_FLOAT_16";
|
| 87 |
+
case QNN_DATATYPE_FLOAT_32: return "QNN_DATATYPE_FLOAT_32";
|
| 88 |
+
case QNN_DATATYPE_FLOAT_64: return "QNN_DATATYPE_FLOAT_64";
|
| 89 |
+
case QNN_DATATYPE_SFIXED_POINT_4: return "QNN_DATATYPE_SFIXED_POINT_4";
|
| 90 |
+
case QNN_DATATYPE_SFIXED_POINT_8: return "QNN_DATATYPE_SFIXED_POINT_8";
|
| 91 |
+
case QNN_DATATYPE_SFIXED_POINT_16: return "QNN_DATATYPE_SFIXED_POINT_16";
|
| 92 |
+
case QNN_DATATYPE_SFIXED_POINT_32: return "QNN_DATATYPE_SFIXED_POINT_32";
|
| 93 |
+
case QNN_DATATYPE_UFIXED_POINT_4: return "QNN_DATATYPE_UFIXED_POINT_4";
|
| 94 |
+
case QNN_DATATYPE_UFIXED_POINT_8: return "QNN_DATATYPE_UFIXED_POINT_8";
|
| 95 |
+
case QNN_DATATYPE_UFIXED_POINT_16: return "QNN_DATATYPE_UFIXED_POINT_16";
|
| 96 |
+
case QNN_DATATYPE_UFIXED_POINT_32: return "QNN_DATATYPE_UFIXED_POINT_32";
|
| 97 |
+
case QNN_DATATYPE_BOOL_8: return "QNN_DATATYPE_BOOL_8";
|
| 98 |
+
case QNN_DATATYPE_STRING: return "QNN_DATATYPE_STRING";
|
| 99 |
+
default: return "QNN_DATATYPE_UNDEFINED";
|
| 100 |
+
}
|
| 101 |
+
// clang-format on
|
| 102 |
+
}
|
| 103 |
+
} // namespace QnnUtils
|
| 104 |
+
} // namespace qualla
|
Genie/Genie/src/qualla/engines/qnn-api/qnn-utils.hpp
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//==============================================================================
|
| 2 |
+
//
|
| 3 |
+
// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
|
| 4 |
+
// All Rights Reserved.
|
| 5 |
+
// Confidential and Proprietary - Qualcomm Technologies, Inc.
|
| 6 |
+
//
|
| 7 |
+
//==============================================================================
|
| 8 |
+
|
| 9 |
+
#pragma once
|
| 10 |
+
|
| 11 |
+
#ifdef _MSC_VER
|
| 12 |
+
#pragma warning(disable : 4068)
|
| 13 |
+
#endif
|
| 14 |
+
|
| 15 |
+
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <filesystem>
#include <map>
#include <string>
#include <vector>

#include "QnnApiUtils.hpp"
#include "QnnInterface.h"
|
| 21 |
+
|
| 22 |
+
namespace qualla {
|
| 23 |
+
|
| 24 |
+
namespace QnnUtils {
|
| 25 |
+
class DataType {
|
| 26 |
+
private:
|
| 27 |
+
Qnn_DataType_t _dtype{QNN_DATATYPE_UNDEFINED};
|
| 28 |
+
|
| 29 |
+
public:
|
| 30 |
+
DataType() = default;
|
| 31 |
+
DataType(const Qnn_Tensor_t* tensor) : _dtype(QNN_TENSOR_GET_DATA_TYPE(tensor)) {}
|
| 32 |
+
DataType(Qnn_DataType_t dtype) : _dtype(dtype) {};
|
| 33 |
+
|
| 34 |
+
// Enable switch and comparisons
|
| 35 |
+
constexpr operator Qnn_DataType_t() const { return _dtype; }
|
| 36 |
+
|
| 37 |
+
int bw();
|
| 38 |
+
int type();
|
| 39 |
+
|
| 40 |
+
int32_t val();
|
| 41 |
+
|
| 42 |
+
const char* str();
|
| 43 |
+
};
|
| 44 |
+
|
| 45 |
+
bool writeRawData(void* tensorData, size_t tensorSize, const std::filesystem::path& path);
|
| 46 |
+
bool readRawData(void* tensorData, size_t tensorSize, const std::filesystem::path& path);
|
| 47 |
+
|
| 48 |
+
struct Dims {
|
| 49 |
+
int32_t batch = 1;
|
| 50 |
+
int32_t height, width, channel, bitWidth;
|
| 51 |
+
Dims() : height(0), width(0), channel(0), bitWidth(0) {}
|
| 52 |
+
Dims(int32_t height, int32_t width, int32_t channel, int32_t bitWidth)
|
| 53 |
+
: height(height), width(width), channel(channel), bitWidth(bitWidth) {}
|
| 54 |
+
Dims(std::vector<size_t>& tDims)
|
| 55 |
+
: height((int32_t)tDims[1]), width((int32_t)tDims[2]), channel((int32_t)tDims[3]),
|
| 56 |
+
bitWidth((int32_t)tDims[4]) {
|
| 57 |
+
// Hack to mix batch dimension
|
| 58 |
+
if (tDims[0] != 1 && tDims[1] == 1) height = tDims[0];
|
| 59 |
+
if (tDims[0] > 1 && tDims[1] != 1) batch = tDims[0];
|
| 60 |
+
}
|
| 61 |
+
bool operator==(const Dims& rhs) const {
|
| 62 |
+
return (height == rhs.height) && (width == rhs.width) && (channel == rhs.channel) &&
|
| 63 |
+
(bitWidth == rhs.bitWidth);
|
| 64 |
+
}
|
| 65 |
+
bool operator!=(const Dims& rhs) const { return !(operator==(rhs)); }
|
| 66 |
+
size_t getNumElements() const { return (size_t)(height * width * channel); }
|
| 67 |
+
size_t getSize() const { return (size_t)(batch * height * width * channel * bitWidth); }
|
| 68 |
+
size_t getAlignedSize() const {
|
| 69 |
+
size_t size = getSize();
|
| 70 |
+
if ((size & uint64_t{7}) != uint64_t{0}) {
|
| 71 |
+
size += (uint64_t{8} - (size & uint64_t{7}));
|
| 72 |
+
}
|
| 73 |
+
return size;
|
| 74 |
+
}
|
| 75 |
+
int32_t getMaxDim() const { return std::max({height, width, channel}); };
|
| 76 |
+
Dims T() const { return Dims(width, height, channel, bitWidth); }
|
| 77 |
+
};
|
| 78 |
+
|
| 79 |
+
struct QuantParam {
|
| 80 |
+
double scale;
|
| 81 |
+
int32_t offset;
|
| 82 |
+
QuantParam() {}
|
| 83 |
+
QuantParam(double scale_val, int32_t offset_val) : scale(scale_val), offset(offset_val) {}
|
| 84 |
+
};
|
| 85 |
+
|
| 86 |
+
struct Tensor {
|
| 87 |
+
Qnn_Tensor_t* tensor = nullptr;
|
| 88 |
+
Dims dims;
|
| 89 |
+
std::vector<QuantParam> quantParam;
|
| 90 |
+
DataType dtype;
|
| 91 |
+
Tensor() {}
|
| 92 |
+
Tensor(Qnn_Tensor_t* tensorVal, Dims dimsVal, std::vector<QuantParam> quantParamVec)
|
| 93 |
+
: tensor(tensorVal), dims(dimsVal), quantParam(quantParamVec),
|
| 94 |
+
dtype(QNN_TENSOR_GET_DATA_TYPE(tensorVal)) {}
|
| 95 |
+
};
|
| 96 |
+
|
| 97 |
+
// Maps tensor name to QnnUtils::Tensor<Qnn_Tensor_t* tensor, dims, quantparams>
|
| 98 |
+
typedef std::map<std::string, Tensor> TensorMap;
|
| 99 |
+
|
| 100 |
+
static inline uint8_t sat_round(const uint16_t x) {
|
| 101 |
+
const uint16_t rounded = x + 0x80; // add 0.5
|
| 102 |
+
const uint16_t corrected = std::max(rounded, x); // catch unsigned wrap around
|
| 103 |
+
const uint16_t shifted = corrected >> 8; // divide by 256
|
| 104 |
+
return static_cast<uint8_t>(shifted); // to 8-bit
|
| 105 |
+
}
|
| 106 |
+
|
| 107 |
+
static inline void downcast_u16_to_u8(uint8_t* dest, const uint16_t* src, size_t nmemb) {
|
| 108 |
+
for (size_t i = 0; i < nmemb; i++)
|
| 109 |
+
dest[i] = sat_round(src[i]);
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
template <typename FloatType, typename IntType>
|
| 113 |
+
static inline void quantizeTensorPtr(
|
| 114 |
+
FloatType* tensor_float,
|
| 115 |
+
IntType* tensor_quant,
|
| 116 |
+
int32_t offset,
|
| 117 |
+
double scale,
|
| 118 |
+
size_t nmemb
|
| 119 |
+
) {
|
| 120 |
+
#pragma clang loop vectorize(enable) interleave(enable)
|
| 121 |
+
for (size_t i = 0; i < nmemb; i++) {
|
| 122 |
+
double val = tensor_float[i];
|
| 123 |
+
tensor_quant[i] = static_cast<IntType>(val / scale - offset);
|
| 124 |
+
}
|
| 125 |
+
}
|
| 126 |
+
|
| 127 |
+
template <typename FloatType, typename IntType>
|
| 128 |
+
static inline void perWidthQuantizeTensorPtr(
|
| 129 |
+
FloatType* tensor_float,
|
| 130 |
+
IntType* tensor_quant,
|
| 131 |
+
std::vector<QnnUtils::QuantParam>& quantParam,
|
| 132 |
+
int32_t height,
|
| 133 |
+
int32_t width,
|
| 134 |
+
int32_t channel
|
| 135 |
+
) {
|
| 136 |
+
for (size_t h = 0; h < height; h++) {
|
| 137 |
+
for (size_t w = 0; w < width; w++) {
|
| 138 |
+
double scale = quantParam[w].scale;
|
| 139 |
+
int32_t offset = quantParam[w].offset;
|
| 140 |
+
#pragma clang loop vectorize(enable) interleave(enable)
|
| 141 |
+
for (size_t c = 0; c < channel; c++) {
|
| 142 |
+
int32_t i = (h * width * channel) + (w * channel) + c;
|
| 143 |
+
double val = tensor_float[i];
|
| 144 |
+
tensor_quant[i] = static_cast<IntType>(val / scale - offset);
|
| 145 |
+
}
|
| 146 |
+
}
|
| 147 |
+
}
|
| 148 |
+
}
|
| 149 |
+
|
| 150 |
+
void getQuantParamString(
|
| 151 |
+
const std::vector<QuantParam>& quantParam,
|
| 152 |
+
std::string& scale_string,
|
| 153 |
+
std::string& offset_string
|
| 154 |
+
);
|
| 155 |
+
|
| 156 |
+
} // namespace QnnUtils
|
| 157 |
+
} // namespace qualla
|