jstzwjr commited on
Commit
11481cd
·
1 Parent(s): 3cdcd5e
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. Genie/Genie/GenieSymbols.default +31 -0
  2. Genie/Genie/Makefile +57 -0
  3. Genie/Genie/README +16 -0
  4. Genie/Genie/make/Android.mk +56 -0
  5. Genie/Genie/make/Application.mk +14 -0
  6. Genie/Genie/make/Makefile.linux-x86_64 +192 -0
  7. Genie/Genie/src/Dialog.cpp +1804 -0
  8. Genie/Genie/src/Dialog.hpp +95 -0
  9. Genie/Genie/src/Exception.hpp +27 -0
  10. Genie/Genie/src/GenieCommon.cpp +15 -0
  11. Genie/Genie/src/GenieDialog.cpp +249 -0
  12. Genie/Genie/src/GenieDialogEmbedding.cpp +41 -0
  13. Genie/Genie/src/Macro.hpp +101 -0
  14. Genie/Genie/src/Util/HandleGenerator.hpp +62 -0
  15. Genie/Genie/src/Util/HandleManager.hpp +84 -0
  16. Genie/Genie/src/qualla/context.cpp +118 -0
  17. Genie/Genie/src/qualla/dialog.cpp +590 -0
  18. Genie/Genie/src/qualla/dialogs/basic.cpp +421 -0
  19. Genie/Genie/src/qualla/dialogs/kv-share.cpp +359 -0
  20. Genie/Genie/src/qualla/dialogs/lhd-dec.cpp +481 -0
  21. Genie/Genie/src/qualla/dialogs/multistream.cpp +300 -0
  22. Genie/Genie/src/qualla/dialogs/spec-dec.cpp +458 -0
  23. Genie/Genie/src/qualla/dialogs/ssd-q1.cpp +1046 -0
  24. Genie/Genie/src/qualla/embedding.cpp +190 -0
  25. Genie/Genie/src/qualla/engine.cpp +198 -0
  26. Genie/Genie/src/qualla/engines/lib.cpp +9 -0
  27. Genie/Genie/src/qualla/engines/qnn-api/BackendExtensions.cpp +158 -0
  28. Genie/Genie/src/qualla/engines/qnn-api/BackendExtensions.hpp +62 -0
  29. Genie/Genie/src/qualla/engines/qnn-api/ClientBuffer.cpp +122 -0
  30. Genie/Genie/src/qualla/engines/qnn-api/ClientBuffer.hpp +85 -0
  31. Genie/Genie/src/qualla/engines/qnn-api/IBackend.hpp +156 -0
  32. Genie/Genie/src/qualla/engines/qnn-api/IBufferAlloc.hpp +56 -0
  33. Genie/Genie/src/qualla/engines/qnn-api/ICommandLineManager.hpp +95 -0
  34. Genie/Genie/src/qualla/engines/qnn-api/IOTensor.cpp +382 -0
  35. Genie/Genie/src/qualla/engines/qnn-api/IOTensor.hpp +170 -0
  36. Genie/Genie/src/qualla/engines/qnn-api/Log.hpp +24 -0
  37. Genie/Genie/src/qualla/engines/qnn-api/NetRunBackend.hpp +173 -0
  38. Genie/Genie/src/qualla/engines/qnn-api/QnnApi.cpp +0 -0
  39. Genie/Genie/src/qualla/engines/qnn-api/QnnApi.hpp +429 -0
  40. Genie/Genie/src/qualla/engines/qnn-api/QnnApiUtils.cpp +636 -0
  41. Genie/Genie/src/qualla/engines/qnn-api/QnnApiUtils.hpp +94 -0
  42. Genie/Genie/src/qualla/engines/qnn-api/QnnConfig.hpp +44 -0
  43. Genie/Genie/src/qualla/engines/qnn-api/QnnTypeDef.hpp +52 -0
  44. Genie/Genie/src/qualla/engines/qnn-api/QnnTypeMacros.hpp +702 -0
  45. Genie/Genie/src/qualla/engines/qnn-api/RpcMem.cpp +481 -0
  46. Genie/Genie/src/qualla/engines/qnn-api/RpcMem.hpp +115 -0
  47. Genie/Genie/src/qualla/engines/qnn-api/dlwrap.cpp +66 -0
  48. Genie/Genie/src/qualla/engines/qnn-api/dlwrap.hpp +33 -0
  49. Genie/Genie/src/qualla/engines/qnn-api/qnn-utils.cpp +104 -0
  50. Genie/Genie/src/qualla/engines/qnn-api/qnn-utils.hpp +157 -0
Genie/Genie/GenieSymbols.default ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #=============================================================================
2
+ #
3
+ # Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4
+ # All Rights Reserved.
5
+ # Confidential and Proprietary - Qualcomm Technologies, Inc.
6
+ #
7
+ #=============================================================================
8
+ {
9
+ global:
10
+ Genie_getApiMajorVersion*;
11
+ Genie_getApiMinorVersion*;
12
+ Genie_getApiPatchVersion*;
13
+ GenieDialogConfig_createFromJson*;
14
+ GenieDialogConfig_free*;
15
+ GenieDialog_create*;
16
+ GenieDialog_query*;
17
+ GenieDialog_tokenQuery*;
18
+ GenieDialog_embeddingQuery*;
19
+ GenieDialog_save*;
20
+ GenieDialog_restore*;
21
+ GenieDialog_reset*;
22
+ GenieDialog_setLoraStrength*;
23
+ GenieDialog_applyLora*;
24
+ GenieDialog_free*;
25
+ GenieEmbeddingConfig_createFromJson*;
26
+ GenieEmbeddingConfig_free*;
27
+ GenieEmbedding_create*;
28
+ GenieEmbedding_generate*;
29
+ GenieEmbedding_free*;
30
+ local: *;
31
+ };
Genie/Genie/Makefile ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #=============================================================================
2
+ #
3
+ # Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4
+ # All Rights Reserved.
5
+ # Confidential and Proprietary - Qualcomm Technologies, Inc.
6
+ #
7
+ #=============================================================================
8
+
9
+ RUST_TARGET := aarch64-linux-android
10
+ RUST_SOURCE_DIR := ./src/qualla/tokenizers/rust
11
+ # specify compiler
12
+ export CXX := clang++-14
13
+ export PATH := $(ANDROID_NDK_ROOT)/toolchains/llvm/prebuilt/linux-x86_64/bin:$(PATH)
14
+ .PHONY: all x86 android clean clean_x86 clean_android
15
+ .DEFAULT: x86
16
+
17
+ all: x86 android
18
+
19
+ x86: build_x86_tokenizer
20
+ @echo "-------------------- Building genie for x86 -------------------- "
21
+ @$(MAKE) -f make/Makefile.linux-x86_64 CPATH="/usr/include/x86_64-linux-gnu" || (echo "-------------------- genie x86 build failed --------------------"; exit 1; )
22
+ @echo "-------------------- genie x86 build succeeded -------------------- "
23
+
24
+ android: check_ndk build_android_tokenizer
25
+ @echo "-------------------- Building genie for android -------------------- "
26
+ @$(ANDROID_NDK_ROOT)/ndk-build APP_ALLOW_MISSING_DEPS=true APP_ABI="arm64-v8a" NDK_PROJECT_PATH=./ NDK_APPLICATION_MK=make/Application.mk APP_BUILD_SCRIPT=make/Android.mk || (echo "-------------------- genie android build failed --------------------"; exit 1; )
27
+ @$(rename_target_dirs)
28
+ @echo "-------------------- genie android build succeeded -------------------- "
29
+
30
+ clean: clean_x86 clean_android
31
+
32
+ clean_x86:
33
+ @$(MAKE) -f make/Makefile.linux-x86_64 clean
34
+
35
+ clean_android:
36
+ if [ -d "lib/aarch64-android" ]; then rm -rf lib/aarch64-android; fi
37
+ if [ -d "obj/local" ]; then rm -rf obj/local; fi
38
+
39
+ # utilities
40
+ rename_target_dirs = \
41
+ @if [ -d ./lib/aarch64-android ]; then rm -rf ./lib/aarch64-android; fi; \
42
+ find ./obj/local -type d -execdir rename 's/arm64-v8a/aarch64-android/' '{}' \+ \
43
+ && mkdir -p lib \
44
+ && mv ./obj/local/aarch64-android lib/ \
45
+ && mv ./libs/arm64-v8a/libc++_shared.so lib/aarch64-android/ \
46
+ && rm -rf ./libs \
47
+
48
+ check_ndk:
49
+ ifeq ($(ANDROID_NDK_ROOT),)
50
+ $(error ERROR: ANDROID_NDK_ROOT not set, skipping compilation for Android platform(s).)
51
+ endif
52
+
53
+ build_x86_tokenizer: $(RUST_SOURCE_DIR)/Cargo.toml
54
+ cargo build --release --manifest-path=$<
55
+
56
+ build_android_tokenizer: $(RUST_SOURCE_DIR)/Cargo.toml
57
+ cargo build --release --manifest-path=$< --target=$(RUST_TARGET)
Genie/Genie/README ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #=============================================================================
2
+ #
3
+ # Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4
+ # All Rights Reserved.
5
+ # Confidential and Proprietary - Qualcomm Technologies, Inc.
6
+ #
7
+ #=============================================================================
8
+
9
+ Genie library source code example
10
+ ---------------------------------
11
+
12
+ The Genie library (libGenie.so / Genie.dll) source code example provides users with an ability to recreate the Genie
13
+ library from source. Note that the Genie library source may be refactored, rewritten, or otherwise modified without
14
+ notice. The Genie C API is the commercially controlled and versioned interface that users should expect to be stable.
15
+ Please refer to the Genie SDK documentation tutorials at ${SDK_ROOT}/doc/Genie/ for more information on how to build the
16
+ sample code.
Genie/Genie/make/Android.mk ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #=============================================================================
2
+ #
3
+ # Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4
+ # All Rights Reserved.
5
+ # Confidential and Proprietary - Qualcomm Technologies, Inc.
6
+ #
7
+ #=============================================================================
8
+
9
+ LOCAL_PATH := $(call my-dir)
10
+ SUPPORTED_TARGET_ABI := arm64-v8a x86 x86_64
11
+
12
+ #============================ Verify Target Info and Application Variables =========================================
13
+ ifneq ($(filter $(TARGET_ARCH_ABI),$(SUPPORTED_TARGET_ABI)),)
14
+ ifneq ($(APP_STL), c++_shared)
15
+ $(error Unsupported APP_STL: "$(APP_STL)")
16
+ endif
17
+ else
18
+ $(error Unsupported TARGET_ARCH_ABI: '$(TARGET_ARCH_ABI)')
19
+ endif
20
+
21
+ #============================ Define Common Variables ===============================================================
22
+ # PACKAGE_C_INCLUDES += -I $(LOCAL_PATH)/../../../../../include/QNN
23
+ # Include paths
24
+ PACKAGE_C_INCLUDES += -I $(LOCAL_PATH)/../include
25
+ PACKAGE_C_INCLUDES += -I $(LOCAL_PATH)/../../../../include/Genie
26
+ PACKAGE_C_INCLUDES += -I $(LOCAL_PATH)/../src/qualla/include
27
+ PACKAGE_C_INCLUDES += -I $(LOCAL_PATH)/../../../../include/QNN
28
+ PACKAGE_C_INCLUDES += -I $(LOCAL_PATH)/../../../../include/QNN/HTP
29
+ PACKAGE_C_INCLUDES += -I $(LOCAL_PATH)/../src/qualla/tokenizers
30
+ PACKAGE_C_INCLUDES += -I $(LOCAL_PATH)/../src/qualla/engines/qnn-api
31
+ PACKAGE_C_INCLUDES += -I $(LOCAL_PATH)/../src/qualla/engines/qnn-cpu
32
+ PACKAGE_C_INCLUDES += -I $(LOCAL_PATH)/../src/qualla/engines/qnn-htp
33
+
34
+ #========================== Define T2T Lib variables =============================================
35
+ include $(CLEAR_VARS)
36
+ LOCAL_MODULE := tokenizers_capi
37
+ LOCAL_SRC_FILES := ../src/qualla/tokenizers/rust/target/aarch64-linux-android/release/libtokenizers_capi.a
38
+ include $(PREBUILT_STATIC_LIBRARY)
39
+
40
+ include $(CLEAR_VARS)
41
+ LOCAL_C_INCLUDES := $(PACKAGE_C_INCLUDES)
42
+ MY_SRC_FILES := $(wildcard $(LOCAL_PATH)/../src/*.cpp)
43
+ MY_SRC_FILES += $(wildcard $(LOCAL_PATH)/../src/qualla/*.cpp)
44
+ MY_SRC_FILES += $(wildcard $(LOCAL_PATH)/../src/qualla/dialogs/*.cpp)
45
+ MY_SRC_FILES += $(wildcard $(LOCAL_PATH)/../src/qualla/engines/*.cpp)
46
+ MY_SRC_FILES += $(wildcard $(LOCAL_PATH)/../src/qualla/engines/qnn-api/*.cpp)
47
+ MY_SRC_FILES += $(wildcard $(LOCAL_PATH)/../src/qualla/engines/qnn-cpu/*.cpp)
48
+ MY_SRC_FILES += $(wildcard $(LOCAL_PATH)/../src/qualla/engines/qnn-htp/*.cpp)
49
+ MY_SRC_FILES += $(wildcard $(LOCAL_PATH)/../src/qualla/utils/*.cpp)
50
+ MY_SRC_FILES += $(wildcard $(LOCAL_PATH)/../src/qualla/loggers/*.cpp)
51
+ MY_SRC_FILES += $(wildcard $(LOCAL_PATH)/../src/qualla/samplers/*.cpp)
52
+
53
+ LOCAL_MODULE := libGenie
54
+ LOCAL_SRC_FILES := $(subst make/,,$(MY_SRC_FILES))
55
+ LOCAL_STATIC_LIBRARIES := tokenizers_capi
56
+ include $(BUILD_SHARED_LIBRARY)
Genie/Genie/make/Application.mk ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #=============================================================================
2
+ #
3
+ # Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4
+ # All Rights Reserved.
5
+ # Confidential and Proprietary - Qualcomm Technologies, Inc.
6
+ #
7
+ #=============================================================================
8
+
9
+ APP_ABI := arm64-v8a
10
+ APP_STL := c++_shared
11
+ APP_PLATFORM := android-21
12
+ APP_MODULES := Genie
13
+ APP_CPPFLAGS += -std=c++2a -O3 -Wall -frtti -fexceptions -fvisibility=hidden -DGENIE_API="__attribute__((visibility(\"default\")))" -DSPILLFILL -DQUALLA_ENGINE_QNN_HTP=TRUE -DQUALLA_ENGINE_QNN_CPU=TRUE -DQUALLA_APPS=OFF -DFMT_HEADER_ONLY -DGENIE_SAMPLE -DQUALLA_INTERNAL_QNN_SDK -DGENIE_SSD_FEATURE -DGENIE_SPD_FEATURE -DGENIE_LADE_FEATURE -DGENIE_MULTISTREAM_FEATURE -DGENIE_LORA_FEATURE -DGENIE_E2T_FEATURE
14
+ APP_LDFLAGS += -lc -lm -ldl -Wl,--version-script=GenieSymbols.default -Wl,--strip-all
Genie/Genie/make/Makefile.linux-x86_64 ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #=============================================================================
2
+ #
3
+ # Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4
+ # All Rights Reserved.
5
+ # Confidential and Proprietary - Qualcomm Technologies, Inc.
6
+ #
7
+ #=============================================================================
8
+
9
+ # define relevant directories
10
+ SRC_DIR := src/qualla
11
+ #
12
+ SRC_DIR_GENIE_TOKENIZERS := src/qualla/tokenizers
13
+ #
14
+ SRC_DIR_SAMPLE_DIALOGS := src/qualla/dialogs
15
+
16
+ # All engines
17
+ SRC_DIR_GENIE_ENGINES := src/qualla/engines
18
+ SRC_DIR_GENIE_QNN_API := src/qualla/engines/qnn-api
19
+ SRC_DIR_GENIE_ENGINES_CPU := src/qualla/engines/qnn-cpu
20
+ SRC_DIR_GENIE_UTILS := src/qualla/utils
21
+ #
22
+ SRC_DIR_GENIE_LOGGERS := src/qualla/loggers
23
+
24
+ #
25
+ SRC_DIR_GENIE_SAMPLERS := src/qualla/samplers
26
+
27
+ #
28
+ SRC_DIR_GENIE := src
29
+
30
+ # Includes
31
+ GENIE_ENGINES_CPU_INCLUDE := src/qualla/engines/qnn-cpu
32
+ GENIE_ENGINES_API_INCLUDE := src/qualla/engines/qnn-api
33
+ GENIE_ENGINES_HTP_INCLUDE := src/qualla/engines/qnn-htp
34
+ GENIE_TOKENIZER_INCLUDE := src/qualla/tokenizers
35
+
36
+ GENIE_INCLUDE := include
37
+ GENIE_C_API_HEADERS_INCLUDE := ../../../include/Genie
38
+ QUALLA_INCLUDE := src/qualla/include
39
+ QNN_API_INCLUDE := ../../../include/QNN/
40
+ QNN_API_HTP_INCLUDE := $(QNN_API_INCLUDE)/HTP
41
+
42
+ AR := /usr/bin/ar
43
+ ARFLAGS := rcs
44
+ # Check whether $(CXX) is already clang; if it is not, fall back to clang++.
45
+ ifeq ($(shell $(CXX) -v 2>&1 | grep -c "clang version"), 0)
46
+ CXX := clang++
47
+ endif
48
+
49
+ QNN_TARGET ?= x86_64-linux-clang
50
+ export TARGET_DIR := ./lib/$(QNN_TARGET)
51
+
52
+ libGenie := $(TARGET_DIR)/libGenie.so
53
+ libtokenizers := src/qualla/tokenizers/rust/target/release/libtokenizers_capi.a
54
+
55
+ # define target architecture if not previously defined, default is x86
56
+ ifndef TARGET_AARCH_VARS
57
+ TARGET_AARCH_VARS:= -march=x86-64
58
+ endif
59
+
60
+ .PHONY: linux_x86_64
61
+ .DEFAULT: linux_x86_64
62
+ GENIE_all: $(libGenie)
63
+
64
+ # Include paths
65
+ INCLUDES += -I$(GENIE_INCLUDE) -I$(QUALLA_INCLUDE) -I$(SRC_DIR_GENIE_TOKENIZERS) -I$(QNN_API_INCLUDE) -I$(GENIE_ENGINES_CPU_INCLUDE) -I$(QNN_API_HTP_INCLUDE) -I$(GENIE_ENGINES_API_INCLUDE) -I$(GENIE_TOKENIZER_INCLUDE) -I$(GENIE_C_API_HEADERS_INCLUDE)
66
+
67
+ # set compiler flags
68
+ COMMON_CXXFLAGS = -std=c++2a -frtti -fPIC -Wall -pg -pthread -nostdinc++ -stdlib=libc++ -idirafter /usr/lib/llvm-14/include/c++/v1 -nostdinc -idirafter /usr/lib/llvm-14/lib/clang/14.0.0/include/ -idirafter /usr/include $(INCLUDES)
69
+ COMMON_LDFLAGS = -shared -s -fPIC -pthread -L/usr/lib/x86_64-linux-gnu -L./src/qualla/tokenizers/rust/target/release
70
+
71
+ COMMON_CFLAGS = -nostdinc -idirafter /usr/lib/llvm-14/lib/clang/14.0.0/include/ -idirafter /usr/include
72
+
73
+ ifdef QNN_DEBUG_ENABLE
74
+ CXXFLAGS += $(COMMON_CXXFLAGS) -march=x86-64 -O0 -g -DQNN_API="" -DSPILLFILL -DQUALLA_ENGINE_QNN_CPU=TRUE -DQUALLA_APPS=OFF -DFMT_HEADER_ONLY -DGENIE_SAMPLE -DQUALLA_INTERNAL_QNN_SDK -DGENIE_SSD_FEATURE -DGENIE_SPD_FEATURE -DGENIE_LADE_FEATURE -DGENIE_MULTISTREAM_FEATURE -DGENIE_LORA_FEATURE -DGENIE_E2T_FEATURE
75
+ CFLAGS += $(COMMON_CFLAGS)
76
+ LDFLAGS += $(COMMON_LDFLAGS)
77
+ else
78
+ CXXFLAGS += $(COMMON_CXXFLAGS) -march=x86-64 -O3 -Wno-write-strings -fvisibility=hidden -DGENIE_API="__attribute__((visibility(\"default\")))" -DSPILLFILL -DQUALLA_ENGINE_QNN_CPU=TRUE -DQUALLA_APPS=OFF -DFMT_HEADER_ONLY -DGENIE_SAMPLE -DQUALLA_INTERNAL_QNN_SDK -DGENIE_SSD_FEATURE -DGENIE_SPD_FEATURE -DGENIE_LADE_FEATURE -DGENIE_MULTISTREAM_FEATURE -DGENIE_LORA_FEATURE -DGENIE_E2T_FEATURE
79
+ CFLAGS += $(COMMON_CFLAGS)
80
+ LDFLAGS += $(COMMON_LDFLAGS) -fvisibility=hidden -flto
81
+ endif
82
+
83
+ # define library sources
84
+ SOURCES_GENIE_CPP := $(wildcard $(SRC_DIR_GENIE)/*.cpp)
85
+ SOURCES := $(wildcard $(SRC_DIR)/*.cpp)
86
+ SOURCES_GENIE_TOKENIZERS := $(wildcard $(SRC_DIR_GENIE_TOKENIZERS)/*.cpp)
87
+ SOURCES_GENIE_QNN_API_CPP := $(wildcard $(SRC_DIR_GENIE_QNN_API)/*.cpp)
88
+
89
+ SOURCES_GENIE_ENGINES_CPP := $(filter-out $(SRC_DIR_GENIE_ENGINES)/qnn-htp.cpp, $(wildcard $(SRC_DIR_GENIE_ENGINES)/*.cpp))
90
+ SOURCES_GENIE_DIALOGS_CPP := $(wildcard $(SRC_DIR_SAMPLE_DIALOGS)/*.cpp)
91
+ SOURCES_GENIE_ENGINES_CPU_CPP := $(wildcard $(SRC_DIR_GENIE_ENGINES_CPU)/*.cpp)
92
+ SOURCES_GENIE_UTILS_CPP := $(wildcard $(SRC_DIR_GENIE_UTILS)/*.cpp)
93
+
94
+
95
+ SOURCES_GENIE_LOGGERS_CPP := $(wildcard $(SRC_DIR_GENIE_LOGGERS)/*.cpp)
96
+ SOURCES_GENIE_SAMPLERS_CPP := $(wildcard $(SRC_DIR_GENIE_SAMPLERS)/*.cpp)
97
+
98
+
99
+ # define object directory
100
+ OBJ_ROOT := obj
101
+ OBJ_DIR_QUALLA := obj/$(QNN_TARGET)/qualla
102
+ OBJ_DIR_GENIE := obj/$(QNN_TARGET)/genie
103
+ OBJ_DIR_GENIE_TOKENIZERS := $(OBJ_DIR_QUALLA)/tokenizers
104
+ OBJ_DIR_GENIE_QNN_API := $(OBJ_DIR_QUALLA)/qnn-api
105
+
106
+ OBJ_DIR_GENIE_DIALOGS := $(OBJ_DIR_QUALLA)/dialogs
107
+ OBJ_DIR_GENIE_ENGINES := $(OBJ_DIR_QUALLA)/engines
108
+ OBJ_DIR_GENIE_UTILS := $(OBJ_DIR_QUALLA)/utils
109
+ OBJ_DIR_GENIE_ENGINES_CPU := $(OBJ_DIR_QUALLA)/engines/qnn-cpu
110
+ $(shell mkdir -p $(OBJ_DIR_GENIE_ENGINES_CPU))
111
+
112
+ OBJ_DIR_GENIE_LOGGERS := obj/$(QNN_TARGET)/qualla/loggers
113
+ OBJ_DIR_GENIE_SAMPLERS := obj/$(QNN_TARGET)/qualla/samplers
114
+
115
+ $(shell mkdir -p $(OBJ_DIR_GENIE))
116
+ $(shell mkdir -p $(OBJ_DIR_GENIE_LOGGERS))
117
+ $(shell mkdir -p $(OBJ_DIR_GENIE_SAMPLERS))
118
+
119
+ # setup object files in object directory
120
+ OBJECTS_GENIE := $(patsubst %.cpp,$(OBJ_DIR_GENIE)/%.o,$(foreach x,$(SOURCES_GENIE_CPP),$(notdir $(x))))
121
+ OBJECTS_QUALLA := $(patsubst %.cpp,$(OBJ_DIR_QUALLA)/%.o,$(foreach x,$(SOURCES),$(notdir $(x))))
122
+ OBJECTS_GENIE_TOKENIZERS := $(patsubst %.cpp,$(OBJ_DIR_GENIE_TOKENIZERS)/%.o,$(foreach x,$(SOURCES_GENIE_TOKENIZERS),$(notdir $(x))))
123
+ OBJECTS_GENIE_QNN_API := $(patsubst %.cpp,$(OBJ_DIR_GENIE_QNN_API)/%.o,$(foreach x,$(SOURCES_GENIE_QNN_API_CPP),$(notdir $(x))))
124
+ OBJECTS_GENIE_ENGINES := $(patsubst %.cpp,$(OBJ_DIR_GENIE_ENGINES)/%.o,$(foreach x,$(SOURCES_GENIE_ENGINES_CPP),$(notdir $(x))))
125
+ OBJECTS_GENIE_DIALOGS := $(patsubst %.cpp,$(OBJ_DIR_GENIE_DIALOGS)/%.o,$(foreach x,$(SOURCES_GENIE_DIALOGS_CPP),$(notdir $(x))))
126
+ OBJECTS_GENIE_UTILS := $(patsubst %.cpp,$(OBJ_DIR_GENIE_UTILS)/%.o,$(foreach x,$(SOURCES_GENIE_UTILS_CPP),$(notdir $(x))))
127
+ OBJECTS_GENIE_ENGINES_CPU := $(patsubst %.cpp,$(OBJ_DIR_GENIE_ENGINES_CPU)/%.o,$(foreach x,$(SOURCES_GENIE_ENGINES_CPU_CPP),$(notdir $(x))))
128
+
129
+ OBJECTS_GENIE_LOGGERS := $(patsubst %.cpp,$(OBJ_DIR_GENIE_LOGGERS)/%.o,$(foreach x,$(SOURCES_GENIE_LOGGERS_CPP),$(notdir $(x))))
130
+ OBJECTS_GENIE_SAMPLERS := $(patsubst %.cpp,$(OBJ_DIR_GENIE_SAMPLERS)/%.o,$(foreach x,$(SOURCES_GENIE_SAMPLERS_CPP),$(notdir $(x))))
131
+
132
+ LIBS=-ldl
133
+
134
+
135
+ # Rule to make shared lib
136
+ .PHONY: libGenie
137
+ libGenie: $(libGenie)
138
+
139
+ # Implicit rule to compile and link object files
140
+ $(OBJ_DIR_GENIE)/%.o: $(SRC_DIR_GENIE)/%.cpp
141
+ $(CXX) $(CXXFLAGS) -c $^ -o $@
142
+
143
+ $(OBJ_DIR_QUALLA)/%.o: $(SRC_DIR)/%.cpp
144
+ $(CXX) $(CXXFLAGS) -c $^ -o $@
145
+
146
+ $(OBJ_DIR_GENIE_TOKENIZERS)/%.o: $(SRC_DIR_GENIE_TOKENIZERS)/%.cpp
147
+ $(CXX) $(CXXFLAGS) -c $^ -o $@
148
+
149
+ $(OBJ_DIR_GENIE_QNN_API)/%.o: $(SRC_DIR_GENIE_QNN_API)/%.cpp
150
+ $(CXX) $(CXXFLAGS) -c $^ -o $@
151
+
152
+ $(OBJ_DIR_GENIE_ENGINES)/%.o: $(SRC_DIR_GENIE_ENGINES)/%.cpp
+ 	$(CXX) $(CXXFLAGS) -c $^ -o $@
153
+
154
+ $(OBJ_DIR_GENIE_DIALOGS)/%.o: $(SRC_DIR_SAMPLE_DIALOGS)/%.cpp
+ 	$(CXX) $(CXXFLAGS) -c $^ -o $@
155
+
156
+ $(OBJ_DIR_GENIE_UTILS)/%.o: $(SRC_DIR_GENIE_UTILS)/%.cpp
+ 	$(CXX) $(CXXFLAGS) -c $^ -o $@
157
+
158
+ $(OBJ_DIR_GENIE_ENGINES_CPU)/%.o: $(SRC_DIR_GENIE_ENGINES_CPU)/%.cpp
+ 	$(CXX) $(CXXFLAGS) -c $^ -o $@
159
+
160
+ $(OBJ_DIR_GENIE_LOGGERS)/%.o: $(SRC_DIR_GENIE_LOGGERS)/%.cpp
+ 	$(CXX) $(CXXFLAGS) -c $^ -o $@
161
+
162
+ $(OBJ_DIR_GENIE_SAMPLERS)/%.o: $(SRC_DIR_GENIE_SAMPLERS)/%.cpp
+ 	$(CXX) $(CXXFLAGS) -c $^ -o $@
163
+
164
+
165
+ # set up resources
166
+ directories := $(TARGET_DIR) $(OBJ_DIR_GENIE) $(OBJ_DIR_GENIE_QNN_API) $(OBJ_DIR_QUALLA) $(OBJ_DIR_GENIE_TOKENIZERS) $(OBJ_DIR_GENIE_ENGINES) $(OBJ_DIR_GENIE_DIALOGS) $(OBJ_DIR_GENIE_UTILS) $(OBJ_DIR_GENIE_ENGINES_CPU) $(OBJ_DIR_GENIE_LOGGERS) $(OBJ_DIR_GENIE_SAMPLERS)
167
+
168
+ # Compile
169
+ $(libGenie): $(OBJECTS_GENIE) $(OBJECTS_QUALLA) $(OBJECTS_GENIE_QNN_API) $(OBJECTS_GENIE_TOKENIZERS) $(OBJECTS_GENIE_ENGINES) $(OBJECTS_GENIE_DIALOGS) $(OBJECTS_GENIE_UTILS) $(OBJECTS_GENIE_ENGINES_CPU) $(OBJECTS_GENIE_LOGGERS) $(OBJECTS_GENIE_SAMPLERS) | $(directories)
170
+ $(CXX) $(CXXFLAGS) -shared -o $@ $^ $(LIBS) $(libtokenizers)
171
+
172
+
173
+ # rule for object directory resource
174
+ $(OBJECTS_GENIE): | $(OBJ_DIR_GENIE)
175
+ $(OBJECTS_QUALLA): | $(OBJ_DIR_QUALLA)
176
+ $(OBJECTS_GENIE_TOKENIZERS): | $(OBJ_DIR_GENIE_TOKENIZERS)
177
+ $(OBJECTS_GENIE_QNN_API): | $(OBJ_DIR_GENIE_QNN_API)
178
+ $(OBJECTS_GENIE_ENGINES): | $(OBJ_DIR_GENIE_ENGINES)
179
+ $(OBJECTS_GENIE_DIALOGS): | $(OBJ_DIR_GENIE_DIALOGS)
180
+ $(OBJECTS_GENIE_UTILS): | $(OBJ_DIR_GENIE_UTILS)
181
+ $(OBJECTS_GENIE_ENGINES_CPU): | $(OBJ_DIR_GENIE_ENGINES_CPU)
182
+ $(OBJECTS_GENIE_LOGGERS): | $(OBJ_DIR_GENIE_LOGGERS)
183
+ $(OBJECTS_GENIE_SAMPLERS): | $(OBJ_DIR_GENIE_SAMPLERS)
184
+
185
+
186
+ # rule to create directories
187
+ $(directories):
188
+ mkdir -p $@
189
+
190
+ .PHONY: clean
191
+ clean:
192
+ rm -rf $(OBJ_ROOT) $(TARGET_DIR)
Genie/Genie/src/Dialog.cpp ADDED
@@ -0,0 +1,1804 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ //==============================================================================
2
+ //
3
+ // Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4
+ // All Rights Reserved.
5
+ // Confidential and Proprietary - Qualcomm Technologies, Inc.
6
+ //
7
+ //==============================================================================
8
+
9
+ #include <exception>
10
+ #include <set>
11
+ #include <sstream>
12
+
13
+ #include "Dialog.hpp"
14
+ #include "Exception.hpp"
15
+ #include "Macro.hpp"
16
+ #include "qualla/detail/json.hpp"
17
+ #include "qualla/env.hpp"
18
+
19
+ using namespace genie;
20
+
21
// Platform-dependent shared-library naming: "<prefix><base><suffix>".
#ifdef _WIN32
inline std::string libPrefix = "";
inline std::string libSuffix = ".dll";
#else
inline std::string libPrefix = "lib";
inline std::string libSuffix = ".so";
#endif

// Builds the platform-specific file name for a library base name,
// e.g. "Genie" -> "libGenie.so" on Linux/Android, "Genie.dll" on Windows.
inline std::string getLibName(std::string base) {
  std::string name = libPrefix;
  name += base;
  name += libSuffix;
  return name;
}
30
+
31
+ //=============================================================================
32
+ // Context::Config functions
33
+ //=============================================================================
34
+
35
// Validates the "context" section of a Genie dialog configuration.
// Throws Exception(GENIE_STATUS_ERROR_JSON_SCHEMA) when the section is not a
// JSON object, a mandatory field is absent, or an unrecognized key appears,
// and Exception(GENIE_STATUS_ERROR_JSON_VALUE) for an unsupported "version".
static void validateContextConfig(const qualla::json& config) {
  if (!config.is_object()) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "context config is not an object");
  }

  // "eot-token" and "pad-token" are accepted below but are optional here.
  std::set<std::string> mandatoryFields{"version", "bos-token", "eos-token", "size", "n-vocab"};
  for (const auto& field : mandatoryFields) {
    if (!config.contains(field)) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing context field: " + field);
    }
  }

  // component is used in the "ENFORCE" macros
  // NOTE: the JSON_ENFORCE_* macros appear to reference the loop variable
  // `item` and this `component` string by name; keep both names unchanged.
  std::string component = "context";

  for (auto& item : config.items()) {
    if (item.key() == "version") {
      JSON_ENFORCE_NUMERIC();
      // Only schema version 1 is supported.
      if (item.value().get<int>() != 1) {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Invalid context config: unsupported version: " + item.value().dump());
      }
    } else if (item.key() == "bos-token") {
      JSON_ENFORCE_NUMERIC();
    } else if (item.key() == "eos-token") {
      // A model may declare one EOS token or a list of them.
      JSON_ENFORCE_ARRAY_OR_NUMERIC();
    } else if (item.key() == "eot-token") {
      JSON_ENFORCE_NUMERIC();
    } else if (item.key() == "size") {
      JSON_ENFORCE_NUMERIC();
    } else if (item.key() == "n-vocab") {
      JSON_ENFORCE_NUMERIC();
    } else if (item.key() == "pad-token") {
      JSON_ENFORCE_NUMERIC();
    } else {
      // Unknown keys are rejected rather than ignored.
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown context config key: " + item.key());
    }
  }
}
74
+
75
+ static void translateContextConfig(const qualla::json& genieConfig, qualla::json& quallaConfig) {
76
+ if (genieConfig["dialog"].contains("context")) {
77
+ if (genieConfig["dialog"]["context"].contains("bos-token")) {
78
+ quallaConfig["context"]["bos-token"] = genieConfig["dialog"]["context"]["bos-token"];
79
+ }
80
+ if (genieConfig["dialog"]["context"].contains("eos-token")) {
81
+ quallaConfig["context"]["eos-token"] = genieConfig["dialog"]["context"]["eos-token"];
82
+ }
83
+ if (genieConfig["dialog"]["context"].contains("eot-token")) {
84
+ quallaConfig["context"]["eot-token"] = genieConfig["dialog"]["context"]["eot-token"];
85
+ }
86
+ if (genieConfig["dialog"]["context"].contains("size")) {
87
+ quallaConfig["context"]["size"] = genieConfig["dialog"]["context"]["size"];
88
+ }
89
+ if (genieConfig["dialog"]["context"].contains("n-vocab")) {
90
+ quallaConfig["context"]["n-vocab"] = genieConfig["dialog"]["context"]["n-vocab"];
91
+ }
92
+ if (genieConfig["dialog"]["context"].contains("pad-token")) {
93
+ quallaConfig["context"]["pad-token"] = genieConfig["dialog"]["context"]["pad-token"];
94
+ }
95
+ }
96
+ }
97
+
98
+ //=============================================================================
99
+ // Sampler::Config functions
100
+ //=============================================================================
101
+
102
// Validates the "sampler" section of a Genie dialog configuration.
// Throws Exception(GENIE_STATUS_ERROR_JSON_SCHEMA) when the section is not a
// JSON object, "version" is absent, or an unrecognized key appears, and
// Exception(GENIE_STATUS_ERROR_JSON_VALUE) for an unsupported "version".
static void validateSamplerConfig(const qualla::json& config) {
  if (!config.is_object()) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "sampler config is not an object");
  }

  // Only "version" is mandatory; all sampling knobs are optional.
  std::set<std::string> mandatoryFields{"version"};
  for (const auto& field : mandatoryFields) {
    if (!config.contains(field)) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing sampler field: " + field);
    }
  }

  // component is used in the "ENFORCE" macros
  // NOTE: the JSON_ENFORCE_* macros appear to reference the loop variable
  // `item` and this `component` string by name; keep both names unchanged.
  std::string component = "sampler";

  for (auto& item : config.items()) {
    if (item.key() == "version") {
      JSON_ENFORCE_NUMERIC();
      // Only schema version 1 is supported.
      if (item.value().get<int>() != 1) {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Invalid sampler config: unsupported version: " + item.value().dump());
      }
    } else if (item.key() == "seed") {
      JSON_ENFORCE_NUMERIC();
    } else if (item.key() == "temp") {
      JSON_ENFORCE_NUMERIC();
    } else if (item.key() == "top-k") {
      JSON_ENFORCE_NUMERIC();
    } else if (item.key() == "top-p") {
      JSON_ENFORCE_NUMERIC();
    } else if (item.key() == "greedy") {
      JSON_ENFORCE_BOOLEAN();
    } else {
      // Unknown keys are rejected rather than ignored.
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown sampler config key: " + item.key());
    }
  }
}
139
+
140
+ static void translateSamplerConfig(const qualla::json& genieConfig, qualla::json& quallaConfig) {
141
+ if (genieConfig["dialog"].contains("sampler")) {
142
+ quallaConfig["sampler"]["type"] = "basic";
143
+
144
+ if (genieConfig["dialog"]["sampler"].contains("seed")) {
145
+ quallaConfig["sampler"]["seed"] = genieConfig["dialog"]["sampler"]["seed"];
146
+ }
147
+ if (genieConfig["dialog"]["sampler"].contains("temp")) {
148
+ quallaConfig["sampler"]["temp"] = genieConfig["dialog"]["sampler"]["temp"];
149
+ }
150
+
151
+ quallaConfig["sampler"]["role"] = "primary";
152
+ #if defined(GENIE_SPD_FEATURE)
153
+ if (genieConfig["dialog"]["type"] == "spd") {
154
+ quallaConfig["sampler"]["role"] = "target";
155
+ }
156
+ #endif
157
+
158
+ if (genieConfig["dialog"]["sampler"].contains("top-k")) {
159
+ quallaConfig["sampler"]["top-k"] = genieConfig["dialog"]["sampler"]["top-k"];
160
+ }
161
+ if (genieConfig["dialog"]["sampler"].contains("top-p")) {
162
+ quallaConfig["sampler"]["top-p"] = genieConfig["dialog"]["sampler"]["top-p"];
163
+ }
164
+ if (genieConfig["dialog"]["sampler"].contains("greedy")) {
165
+ quallaConfig["sampler"]["greedy"] = genieConfig["dialog"]["sampler"]["greedy"];
166
+ }
167
+ if (genieConfig["dialog"]["sampler"].contains("seed")) {
168
+ quallaConfig["sampler"]["seed"] = genieConfig["dialog"]["sampler"]["seed"];
169
+ }
170
+ }
171
+ }
172
+
173
+ //=============================================================================
174
+ // Tokenizer::Config functions
175
+ //=============================================================================
176
+
177
+ static void validateTokenizerConfig(const qualla::json& config) {
178
+ if (!config.is_object()) {
179
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "tokenizer config is not an object");
180
+ }
181
+
182
+ std::set<std::string> mandatoryFields{"version", "path"};
183
+ for (const auto& field : mandatoryFields) {
184
+ if (!config.contains(field)) {
185
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing tokenizer field: " + field);
186
+ }
187
+ }
188
+
189
+ // component is used in the "ENFORCE" macros
190
+ std::string component = "tokenizer";
191
+
192
+ for (auto& item : config.items()) {
193
+ if (item.key() == "version") {
194
+ JSON_ENFORCE_NUMERIC();
195
+ if (item.value().get<int>() != 1) {
196
+ throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
197
+ "Invalid tokenizer config: unsupported version: " + item.value().dump());
198
+ }
199
+ } else if (item.key() == "path") {
200
+ JSON_ENFORCE_STRING();
201
+ // Note: the existence of this file is checked by qualla
202
+ } else {
203
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
204
+ "Unknown tokenizer config key: " + item.key());
205
+ }
206
+ }
207
+ }
208
+
209
+ static void translateTokenizerConfig(const qualla::json& genieConfig, qualla::json& quallaConfig) {
210
+ quallaConfig["tokenizer"] = genieConfig["dialog"]["tokenizer"]["path"];
211
+ }
212
+
213
+ //=============================================================================
214
+ // Embedding::Config functions
215
+ //=============================================================================
216
+
217
+ static void validateEmbeddingConfig(const qualla::json& config) {
218
+ if (!config.is_object()) {
219
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "embedding config is not an object");
220
+ }
221
+
222
+ std::set<std::string> mandatoryFields{"version", "size"};
223
+ for (const auto& field : mandatoryFields) {
224
+ if (!config.contains(field)) {
225
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing embedding field: " + field);
226
+ }
227
+ }
228
+
229
+ // component is used in the "ENFORCE" macros
230
+ std::string component = "embedding";
231
+
232
+ for (auto& item : config.items()) {
233
+ if (item.key() == "version") {
234
+ JSON_ENFORCE_NUMERIC();
235
+ if (item.value().get<int>() != 1) {
236
+ throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
237
+ "Invalid embedding config: unsupported version: " + item.value().dump());
238
+ }
239
+ } else if (item.key() == "size") {
240
+ JSON_ENFORCE_NUMERIC();
241
+ } else if (item.key() == "datatype") {
242
+ JSON_ENFORCE_STRING();
243
+ const std::set<std::string> supportedTypes = {"float32", "native"};
244
+ if (std::find(supportedTypes.begin(), supportedTypes.end(), item.value()) ==
245
+ supportedTypes.end()) {
246
+ throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
247
+ "Unknown embedding datatype: " + std::string(item.value()));
248
+ }
249
+ } else {
250
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
251
+ "Unknown embedding config key: " + item.key());
252
+ }
253
+ }
254
+ }
255
+
256
+ static void translateEmbeddingConfig(const qualla::json& genieConfig, qualla::json& quallaConfig) {
257
+ if (genieConfig["dialog"].contains("embedding")) {
258
+ quallaConfig["context"]["n-embd"] = genieConfig["dialog"]["embedding"]["size"];
259
+
260
+ if (genieConfig["dialog"]["embedding"].contains("datatype")) {
261
+ quallaConfig["context"]["embedding-datatype"] =
262
+ genieConfig["dialog"]["embedding"]["datatype"];
263
+ }
264
+ }
265
+ }
266
+
267
// File-scope flags set by validateBackendHtpConfig() when the QnnHtp section
// contains "pos-id-dim" / "rope-theta". Read later by
// validatePositionalEncodingConfig() to reject configs that specify the same
// information both in the backend section and in "positional-encoding".
// NOTE(review): these are mutable globals, so validation is not reentrant
// across dialogs — confirm single-threaded use by callers.
bool position_dim_set = false;
bool rope_theta_set = false;
269
+
270
+ //=============================================================================
271
+ // Backend::Config functions
272
+ //=============================================================================
273
+
274
// Validate the "QnnHtp" backend section of an engine config.
// Mandatory: version (== 1), spill-fill-bufsize, mmap-budget, use-mmap,
// cpu-mask, poll. Side effect: sets the file-scope flags position_dim_set /
// rope_theta_set when "pos-id-dim" / "rope-theta" appear, so a later check
// can reject configs that also carry a "positional-encoding" model section.
// Throws Exception on any schema or value violation.
static void validateBackendHtpConfig(const qualla::json& config) {
  if (!config.is_object()) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "QnnHtp config is not an object");
  }

  std::set<std::string> mandatoryFields{
      "version", "spill-fill-bufsize", "mmap-budget", "use-mmap", "cpu-mask", "poll"};
  for (const auto& field : mandatoryFields) {
    if (!config.contains(field)) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing QnnHtp field: " + field);
    }
  }

  // component is used in the "ENFORCE" macros
  std::string component = "QnnHtp";

  for (auto& item : config.items()) {
    if (item.key() == "version") {
      JSON_ENFORCE_NUMERIC();
      // Only schema version 1 is understood.
      if (item.value().get<int>() != 1) {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Invalid QnnHtp config: unsupported version: " + item.value().dump());
      }
    } else if (item.key() == "spill-fill-bufsize") {
      JSON_ENFORCE_NUMERIC();
    } else if (item.key() == "mmap-budget") {
      JSON_ENFORCE_NUMERIC();
    } else if (item.key() == "use-mmap") {
      JSON_ENFORCE_BOOLEAN();
#ifdef _WIN32
      // Memory-mapped model loading is rejected on Windows targets.
      if (item.value() == true) {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Invalid QnnHtp config. use-mmap not supported on target");
      }
#endif
    } else if (item.key() == "pos-id-dim") {
      // Record that the position dimension was given here; conflicts with a
      // model-level "positional-encoding" section (checked later).
      position_dim_set = true;
      JSON_ENFORCE_NUMERIC();
    } else if (item.key() == "cpu-mask") {
      JSON_ENFORCE_STRING();
    } else if (item.key() == "poll") {
      JSON_ENFORCE_BOOLEAN();
    } else if (item.key() == "kv-dim") {
      JSON_ENFORCE_NUMERIC();
    } else if (item.key() == "kv-update-method") {
      JSON_ENFORCE_STRING();
    } else if (item.key() == "allow-async-init") {
      JSON_ENFORCE_BOOLEAN();
    } else if (item.key() == "rope-theta") {
      // Same duplicate-specification guard as pos-id-dim above.
      rope_theta_set = true;
      JSON_ENFORCE_NUMERIC();
    } else {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown QnnHtp config key: " + item.key());
    }
  }
}
330
+
331
+ static void validateBackendGenaiConfig(const qualla::json& config) {
332
+ if (!config.is_object()) {
333
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "QnnGenAiTransformer config is not an object");
334
+ }
335
+
336
+ std::set<std::string> mandatoryFields{"version"};
337
+ for (const auto& field : mandatoryFields) {
338
+ if (!config.contains(field)) {
339
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
340
+ "Missing QnnGenAiTransformer field: " + field);
341
+ }
342
+ }
343
+
344
+ // component is used in the "ENFORCE" macros
345
+ std::string component = "QnnGenAiTransformer";
346
+
347
+ for (auto& item : config.items()) {
348
+ if (item.key() == "version") {
349
+ JSON_ENFORCE_NUMERIC();
350
+ if (item.value().get<int>() != 1) {
351
+ throw Exception(
352
+ GENIE_STATUS_ERROR_JSON_VALUE,
353
+ "Invalid QnnGenAiTransformer config: unsupported version: " + item.value().dump());
354
+ }
355
+ } else if (item.key() == "use-mmap") {
356
+ JSON_ENFORCE_BOOLEAN();
357
+ #ifdef _WIN32
358
+ if (item.value() == true) {
359
+ throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
360
+ "Invalid QnnGenAiTransformer config. use-mmap not supported on target");
361
+ }
362
+ #endif
363
+ } else if (item.key() == "n-logits") {
364
+ JSON_ENFORCE_NUMERIC();
365
+ } else if (item.key() == "n-layer") {
366
+ JSON_ENFORCE_NUMERIC();
367
+ } else if (item.key() == "n-embd") {
368
+ JSON_ENFORCE_NUMERIC();
369
+ } else if (item.key() == "n-heads") {
370
+ JSON_ENFORCE_NUMERIC();
371
+ } else {
372
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
373
+ "Unknown QnnGenAiTransformer config key: " + item.key());
374
+ }
375
+ }
376
+ }
377
+
378
// Validate the "backend" section of an engine config.
// "type" selects either "QnnHtp" or "QnnGenAiTransformer". The sub-object
// matching the selected type must be present and valid, and the sub-object
// for the other type must be absent. Throws Exception on any violation.
static void validateBackendConfig(const qualla::json& config) {
  if (!config.is_object()) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "backend config is not an object");
  }

  std::set<std::string> mandatoryFields{"version", "type"};
  for (const auto& field : mandatoryFields) {
    if (!config.contains(field)) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing backend field: " + field);
    }
  }

  // component is used in the "ENFORCE" macros
  std::string component = "backend";

  // Collected while scanning: which backend type was selected, and the
  // backend-specific sub-objects (validated after the scan).
  std::string type;
  bool htp = false;
  qualla::json htpConfig;
  bool genai = false;
  qualla::json genaiConfig;

  for (auto& item : config.items()) {
    if (item.key() == "version") {
      JSON_ENFORCE_NUMERIC();
      // Only schema version 1 is understood.
      if (item.value().get<int>() != 1) {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Invalid backend config: unsupported version: " + item.value().dump());
      }
    } else if (item.key() == "type") {
      JSON_ENFORCE_STRING();
      type = item.value().get<std::string>();
      if (type == "QnnHtp") {
        htp = true;
      } else if (type == "QnnGenAiTransformer") {
        genai = true;
      } else {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Invalid backend config: unsupported type: " + item.value().dump());
      }
    } else if (item.key() == "extensions") {
      // Path to a backend-extensions config; content checked elsewhere.
      JSON_ENFORCE_STRING();
    } else if (item.key() == "QnnHtp") {
      JSON_ENFORCE_OBJECT();
      htpConfig = item.value();
    } else if (item.key() == "QnnGenAiTransformer") {
      JSON_ENFORCE_OBJECT();
      genaiConfig = item.value();
    } else {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown backend config key: " + item.key());
    }
  }

  // Cross-check: the sub-object must match the selected type exactly.
  if (htp) {
    if (!htpConfig.is_object()) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing QnnHtp dialog config");
    }
    validateBackendHtpConfig(htpConfig);
  } else {
    if (htpConfig.is_object()) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                      "QnnHtp backend config for incorrect backend type: " + type);
    }
  }

  if (genai) {
    if (!genaiConfig.is_object()) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing QnnGenAiTransformer dialog config");
    }
    validateBackendGenaiConfig(genaiConfig);
  } else {
    if (genaiConfig.is_object()) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                      "QnnGenAiTransformer backend config for incorrect backend type: " + type);
    }
  }
}
454
+
455
+ //=============================================================================
456
+ // Model::Config functions
457
+ //=============================================================================
458
+
459
// Validate one entry of the lora "adapters" array.
// specifiedLoraVersion is the version declared (or defaulted) by the
// enclosing lora config; the adapter's own keys imply a version
// ("bin-sections" => V2, "path" => V1) which must agree with it.
// Throws Exception on schema violations or version mismatch.
static void validateLoraAdapterConfig(const qualla::json& config,
                                      LORA_VERSION& specifiedLoraVersion) {
  if (!config.is_object()) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "lora adapter config is not an object");
  }
  const std::set<std::string> mandatoryFields{"version", "name"};
  for (const auto& field : mandatoryFields) {
    if (!config.contains(field)) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing lora adapter field: " + field);
    }
  }

  // component is used in the "ENFORCE" macros
  const std::string component = "lora adapter";
  // Version implied by this adapter's keys; must end up defined and
  // consistent with specifiedLoraVersion.
  LORA_VERSION configuredLoraVersion = LORA_VERSION::GENIE_LORA_VERSION_UNDEFINED;
  for (auto& item : config.items()) {
    if (item.key() == "version") {
      JSON_ENFORCE_NUMERIC();
      // Only schema version 1 is understood.
      if (item.value().get<int>() != 1) {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Invalid lora config: unsupported version: " + item.value().dump());
      }
    } else if (item.key() == "name") {
      JSON_ENFORCE_STRING();
    } else if (item.key() == "bin-sections") {
      JSON_ENFORCE_ARRAY();
      configuredLoraVersion = LORA_VERSION::GENIE_LORA_VERSION_V2;  // Adapter occurs with V2
      for (auto& elem : item.value()) {
        if (!elem.is_string()) {
          throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                          "bin-sections must be an array of strings");
        }
      }
    } else if (item.key() == "path") {
      configuredLoraVersion = LORA_VERSION::GENIE_LORA_VERSION_V1;  // Weights are V1
      JSON_ENFORCE_STRING();
      // Note:all directory validations will done by NSP engine
    } else {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                      "Unknown lora adapter config key: " + item.key());
    }
  }

  // Enforce agreement between the declared lora version and the one implied
  // by the adapter's keys; an adapter with neither key is invalid.
  if (specifiedLoraVersion == LORA_VERSION::GENIE_LORA_VERSION_V1 &&
      configuredLoraVersion == LORA_VERSION::GENIE_LORA_VERSION_V2) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                    "LoRA Adapters must be used with lora version: 2");
  } else if (specifiedLoraVersion == LORA_VERSION::GENIE_LORA_VERSION_V2 &&
             configuredLoraVersion == LORA_VERSION::GENIE_LORA_VERSION_V1) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                    "LoRA Weights must be used with lora version: 1");
  } else if (configuredLoraVersion == LORA_VERSION::GENIE_LORA_VERSION_UNDEFINED) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Invalid lora config.");
  }
}
514
+
515
// Validate the "lora" object of a binary model config.
// Mandatory: "version" (== 1) and "adapters". The optional "lora-version"
// (1 or 2, default 2) selects the adapter flavor; every adapter entry is
// checked against it by validateLoraAdapterConfig().
// Throws Exception on schema or value violations.
static void validateLoraConfig(const qualla::json& config) {
  if (!config.is_object()) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "lora config is not an object");
  }

  const std::set<std::string> mandatoryFields{"version", "adapters"};
  for (const auto& field : mandatoryFields) {
    if (!config.contains(field)) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing lora field: " + field);
    }
  }

  // component is used in the "ENFORCE" macros
  const std::string component = "lora";
  // Resolve the declared lora version before scanning adapters, since each
  // adapter must agree with it. Any value other than 1 or 2 is UNDEFINED and
  // rejected after the scan.
  LORA_VERSION specifiedLoraVersion = LORA_VERSION::GENIE_LORA_VERSION_V2;  // Default is loraV2
  if (config.find("lora-version") != config.end()) {
    switch (static_cast<uint8_t>(config["lora-version"])) {
      case 1:
        specifiedLoraVersion = LORA_VERSION::GENIE_LORA_VERSION_V1;
        break;
      case 2:
        specifiedLoraVersion = LORA_VERSION::GENIE_LORA_VERSION_V2;
        break;
      default:
        specifiedLoraVersion = LORA_VERSION::GENIE_LORA_VERSION_UNDEFINED;
        break;
    }
  }

  for (auto& item : config.items()) {
    if (item.key() == "version") {
      JSON_ENFORCE_NUMERIC();
      // Only schema version 1 is understood.
      if (item.value().get<int>() != 1) {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Invalid lora config: unsupported version: " + item.value().dump());
      }
    } else if (item.key() == "alpha-tensor-name") {
      JSON_ENFORCE_STRING();
    } else if (item.key() == "adapters") {
      JSON_ENFORCE_ARRAY();
      for (auto& elem : item.value()) {
        validateLoraAdapterConfig(elem, specifiedLoraVersion);
      }
    } else if (item.key() == "lora-version") {  // Optional
      JSON_ENFORCE_NUMERIC();
    } else {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown lora config key: " + item.key());
    }
  }
  // Only reachable when "lora-version" was present with a value other than
  // 1 or 2 (the default path never yields UNDEFINED).
  if (specifiedLoraVersion == LORA_VERSION::GENIE_LORA_VERSION_UNDEFINED) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                    "Unsupported lora version: " + to_string(config["lora-version"]));
  }
}
569
+
570
+ static void validateModelBinaryConfig(const qualla::json& config) {
571
+ if (!config.is_object()) {
572
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "binary config is not an object");
573
+ }
574
+
575
+ std::set<std::string> mandatoryFields{"version", "ctx-bins"};
576
+ for (const auto& field : mandatoryFields) {
577
+ if (!config.contains(field)) {
578
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing binary field: " + field);
579
+ }
580
+ }
581
+
582
+ // component is used in the "ENFORCE" macros
583
+ std::string component = "binary";
584
+
585
+ for (auto& item : config.items()) {
586
+ if (item.key() == "version") {
587
+ JSON_ENFORCE_NUMERIC();
588
+ if (item.value().get<int>() != 1) {
589
+ throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
590
+ "Invalid binary config: unsupported version: " + item.value().dump());
591
+ }
592
+ } else if (item.key() == "ctx-bins") {
593
+ JSON_ENFORCE_ARRAY();
594
+ for (auto& elem : item.value()) {
595
+ if (!elem.is_string()) {
596
+ throw Exception(GENIE_STATUS_ERROR_JSON_VALUE, "ctx-bins must be an array of strings");
597
+ }
598
+ }
599
+ } else if (item.key() == "lora") {
600
+ JSON_ENFORCE_OBJECT();
601
+ validateLoraConfig(item.value());
602
+ } else {
603
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown binary config key: " + item.key());
604
+ }
605
+ }
606
+ }
607
+
608
+ static void validateModelLibraryConfig(const qualla::json& config) {
609
+ if (!config.is_object()) {
610
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "library config is not an object");
611
+ }
612
+
613
+ std::set<std::string> mandatoryFields{"version", "model-bin"};
614
+ for (const auto& field : mandatoryFields) {
615
+ if (!config.contains(field)) {
616
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing library field: " + field);
617
+ }
618
+ }
619
+
620
+ // component is used in the "ENFORCE" macros
621
+ std::string component = "library";
622
+
623
+ for (auto& item : config.items()) {
624
+ if (item.key() == "version") {
625
+ JSON_ENFORCE_NUMERIC();
626
+ if (item.value().get<int>() != 1) {
627
+ throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
628
+ "Invalid library config: unsupported version: " + item.value().dump());
629
+ }
630
+ } else if (item.key() == "model-bin") {
631
+ JSON_ENFORCE_STRING();
632
+ } else {
633
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown library config key: " + item.key());
634
+ }
635
+ }
636
+ }
637
+
638
+ static void validateRopeScalingConfig(const qualla::json& config) {
639
+ // component is used in the "ENFORCE" macros
640
+ std::string component = "rope-scaling";
641
+ if (config.is_object()) {
642
+ std::string ropeType;
643
+ for (auto& item : config.items()) {
644
+ if (item.key() == "rope-type") {
645
+ JSON_ENFORCE_STRING();
646
+ ropeType = item.value().get<std::string>();
647
+ if (ropeType != "llama3" && ropeType != "default" && ropeType != "longrope") {
648
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Rope type not supported" + ropeType);
649
+ }
650
+ } else if (item.key() == "factor" || item.key() == "low-freq-factor" ||
651
+ item.key() == "high-freq-factor" ||
652
+ item.key() == "original-max-position-embeddings") {
653
+ JSON_ENFORCE_NUMERIC();
654
+ } else if (item.key() == "short-factor" || item.key() == "long-factor") {
655
+ JSON_ENFORCE_ARRAY();
656
+ } else {
657
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
658
+ "Rope scaling parameter not supported " + item.key());
659
+ }
660
+ }
661
+ }
662
+ }
663
+
664
+ static void validatePositionalEncodingConfig(const qualla::json& config) {
665
+ // component is used in the "ENFORCE" macros
666
+ std::string component = "positional-encoding";
667
+ qualla::json ropeScalingConfig;
668
+ if (config.is_object()) {
669
+ for (auto& item : config.items()) {
670
+ if (item.key() == "type") {
671
+ std::string positionEncodingType = item.value().get<std::string>();
672
+ if (positionEncodingType != "rope" && positionEncodingType != "absolute" &&
673
+ positionEncodingType != "alibi") {
674
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "positional-encoding type not supported");
675
+ }
676
+ } else if (item.key() == "rope-dim") {
677
+ JSON_ENFORCE_NUMERIC();
678
+ } else if (item.key() == "rope-theta") {
679
+ JSON_ENFORCE_NUMERIC();
680
+ } else if (item.key() == "rope-scaling") {
681
+ JSON_ENFORCE_OBJECT();
682
+ ropeScalingConfig = item.value();
683
+ } else {
684
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
685
+ "Unknown positional encoding config key: " + item.key());
686
+ }
687
+ }
688
+ }
689
+ if (position_dim_set) {
690
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
691
+ "Specify one config from pos-id-dim and positional-encoding");
692
+ }
693
+ if (rope_theta_set) {
694
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
695
+ "Specify one config from rope-theta and positional-encoding");
696
+ }
697
+ if (ropeScalingConfig.is_object()) {
698
+ validateRopeScalingConfig(ropeScalingConfig);
699
+ }
700
+ }
701
+
702
// Validate the "model" section of an engine config.
// "type" selects "binary" or "library"; the matching sub-object must be
// present and valid and the other must be absent. An optional
// "positional-encoding" object is validated when present.
// Throws Exception on any violation.
static void validateModelConfig(const qualla::json& config) {
  if (!config.is_object()) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "model config is not an object");
  }

  std::set<std::string> mandatoryFields{"version", "type"};
  for (const auto& field : mandatoryFields) {
    if (!config.contains(field)) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing model field: " + field);
    }
  }

  // component is used in the "ENFORCE" macros
  std::string component = "model";

  // Collected while scanning; cross-checked after the loop.
  std::string type;
  bool binary = false;
  qualla::json binaryConfig;
  bool library = false;
  qualla::json libraryConfig;
  qualla::json positionalEncodingConfig;
  bool positionalEncoding = false;

  for (auto& item : config.items()) {
    if (item.key() == "version") {
      JSON_ENFORCE_NUMERIC();
      // Only schema version 1 is understood.
      if (item.value().get<int>() != 1) {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Invalid model config: unsupported version: " + item.value().dump());
      }
    } else if (item.key() == "type") {
      JSON_ENFORCE_STRING();
      type = item.value().get<std::string>();
      if (type == "binary") {
        binary = true;
      } else if (type == "library") {
        library = true;
      } else {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Invalid model config: unsupported type: " + item.value().dump());
      }
    } else if (item.key() == "binary") {
      JSON_ENFORCE_OBJECT();
      binaryConfig = item.value();
    } else if (item.key() == "library") {
      JSON_ENFORCE_OBJECT();
      libraryConfig = item.value();
    } else if (item.key() == "positional-encoding") {
      JSON_ENFORCE_OBJECT();
      positionalEncodingConfig = item.value();
      positionalEncoding = true;
    } else {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown model config key: " + item.key());
    }
  }

  // The sub-object must match the selected model type exactly.
  if (binary) {
    if (!binaryConfig.is_object()) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing binary model config");
    }
    validateModelBinaryConfig(binaryConfig);
  } else {
    if (binaryConfig.is_object()) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                      "binary model config for incorrect model type: " + type);
    }
  }

  if (library) {
    if (!libraryConfig.is_object()) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing library model config");
    }
    validateModelLibraryConfig(libraryConfig);
  } else {
    if (libraryConfig.is_object()) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                      "library model config for incorrect model type: " + type);
    }
  }

  if (positionalEncoding) {
    if (!positionalEncodingConfig.is_object()) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing Positional encoding config");
    }
    validatePositionalEncodingConfig(positionalEncodingConfig);
  } else {
    if (positionalEncodingConfig.is_object()) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                      "Positional encoding config for incorrect model type: " + type);
    }
  }
}
794
+
795
+ //=============================================================================
796
+ // Engine::Config functions
797
+ //=============================================================================
798
+
799
// Validate one "engine" object of a dialog config.
// Mandatory: version (== 1), backend, model, n-threads. For dialog types
// "spd" (when the feature is compiled in) and "kv-share", "role" is also
// mandatory, with type-specific allowed values (draft/target resp.
// primary/secondary). Throws Exception on any violation.
static void validateEngineConfig(const qualla::json& config, std::string dialogType) {
  if (!config.is_object()) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "engine config is not an object");
  }

  std::set<std::string> mandatoryFields{"version", "backend", "model", "n-threads"};
#if defined(GENIE_SPD_FEATURE)
  if (dialogType == "spd") {
    mandatoryFields.insert("role");
  }
#endif
  if (dialogType == "kv-share") {
    mandatoryFields.insert("role");
  }

  for (const auto& field : mandatoryFields) {
    if (!config.contains(field)) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing engine field: " + field);
    }
  }

  // component is used in the "ENFORCE" macros
  std::string component = "engine";

  for (auto& item : config.items()) {
    if (item.key() == "version") {
      JSON_ENFORCE_NUMERIC();
      // Only schema version 1 is understood.
      if (item.value().get<int>() != 1) {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Invalid engine config: unsupported version: " + item.value().dump());
      }
    } else if (item.key() == "backend") {
      JSON_ENFORCE_OBJECT();
      validateBackendConfig(item.value());
    } else if (item.key() == "model") {
      JSON_ENFORCE_OBJECT();
      validateModelConfig(item.value());
    } else if (item.key() == "n-threads") {
      JSON_ENFORCE_NUMERIC();
#if defined(GENIE_SPD_FEATURE)
    } else if (item.key() == "role" && dialogType == "spd") {
      // Speculative decoding engines are either the draft or the target.
      JSON_ENFORCE_STRING();
      if (item.value() != "draft" && item.value() != "target") {
        throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                        "Unknown value: for engine config key: " + item.key());
      }
#endif
    } else if (item.key() == "role" && dialogType == "kv-share") {
      // KV-share engines are either primary or secondary.
      JSON_ENFORCE_STRING();
      if (item.value() != "primary" && item.value() != "secondary") {
        throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                        "Unknown value: for engine config key: " + item.key());
      }
    } else {
      // NOTE(review): "role" on any other dialog type falls through to this
      // unknown-key rejection — confirm that is intended.
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown engine config key: " + item.key());
    }
  }
}
857
+
858
// Validate the "engine" entry of a dialog config, which may be a single
// engine object or (for "spd" when compiled in, and for "kv-share") an array
// of exactly two engines covering both required roles.
// Throws Exception on any violation.
static void validateMultiEngineConfig(const qualla::json& configs, std::string dialogType) {
  if (configs.is_object()) {
    // Single-engine form; invalid for the two-engine dialog types.
    validateEngineConfig(configs, dialogType);
#if defined(GENIE_SPD_FEATURE)
    if (dialogType == "spd") {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "engine config for spd is not an array");
    }
#endif
    if (dialogType == "kv-share") {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "engine config for kv-share is not an array");
    }
#if defined(GENIE_SPD_FEATURE)
  } else if (configs.is_array() && dialogType == "spd") {
    if (configs.size() != 2) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                      "engine config for spd contain invalid number of engines");
    }
    // [0] = draft seen, [1] = target seen; both roles must be present.
    bool engineRoleMask[2] = {false, false};
    for (auto& item : configs) {
      validateEngineConfig(item, dialogType);
      if (item["role"] == "draft") {
        engineRoleMask[0] = true;
      } else if (item["role"] == "target") {
        engineRoleMask[1] = true;
      }
    }
    if (!engineRoleMask[0]) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                      "engine config for spd does not contain draft engine");
    }
    if (!engineRoleMask[1]) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                      "engine config for spd does not contain target engine");
    }
#endif
  } else if (configs.is_array() && dialogType == "kv-share") {
    if (configs.size() != 2) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                      "engine config for kv-share contain invalid number of engines");
    }
    // [0] = primary seen, [1] = secondary seen; both roles must be present.
    bool engineRoleMask[2] = {false, false};
    for (auto& item : configs) {
      validateEngineConfig(item, dialogType);
      if (item["role"] == "primary") {
        engineRoleMask[0] = true;
      } else if (item["role"] == "secondary") {
        engineRoleMask[1] = true;
      }
    }
    if (!engineRoleMask[0]) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                      "engine config for kv-share does not contain primary");
    }
    if (!engineRoleMask[1]) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                      "engine config for kv-share does not contain secondary");
    }
  } else {
    // Neither a lone object nor a recognized array form.
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "engine config is not an object or an array");
  }
}
919
+
920
+ static void translateEngineConfig(const qualla::json& genieEngineConfig,
921
+ qualla::json& quallaEngineConfig) {
922
+ if (genieEngineConfig["version"] == 1) {
923
+ if (genieEngineConfig.contains("role")) {
924
+ quallaEngineConfig["role"] = genieEngineConfig["role"];
925
+ } else {
926
+ quallaEngineConfig["role"] = "primary";
927
+ }
928
+
929
+ quallaEngineConfig["n-threads"] = genieEngineConfig["n-threads"];
930
+
931
+ if (genieEngineConfig["backend"]["type"] == "QnnHtp") {
932
+ quallaEngineConfig["type"] = "qnn-htp";
933
+ quallaEngineConfig["backend-lib"] = getLibName("QnnHtp");
934
+ quallaEngineConfig["mmap-budget"] = genieEngineConfig["backend"]["QnnHtp"]["mmap-budget"];
935
+ quallaEngineConfig["use-mmap"] = genieEngineConfig["backend"]["QnnHtp"]["use-mmap"];
936
+ quallaEngineConfig["spill-fill-bufsize"] =
937
+ genieEngineConfig["backend"]["QnnHtp"]["spill-fill-bufsize"];
938
+ if (genieEngineConfig["backend"]["QnnHtp"].contains("pos-id-dim")) {
939
+ quallaEngineConfig["pos-id-dim"] = genieEngineConfig["backend"]["QnnHtp"]["pos-id-dim"];
940
+ }
941
+ quallaEngineConfig["cpumask"] = genieEngineConfig["backend"]["QnnHtp"]["cpu-mask"];
942
+ quallaEngineConfig["poll"] = genieEngineConfig["backend"]["QnnHtp"]["poll"];
943
+ quallaEngineConfig["kv-dim"] = genieEngineConfig["backend"]["QnnHtp"]["kv-dim"];
944
+ if (genieEngineConfig["backend"]["QnnHtp"].contains("rope-theta")) {
945
+ quallaEngineConfig["rope-theta"] = genieEngineConfig["backend"]["QnnHtp"]["rope-theta"];
946
+ }
947
+ if (genieEngineConfig["backend"]["QnnHtp"].contains("kv-update-method")) {
948
+ quallaEngineConfig["kv-update-method"] =
949
+ genieEngineConfig["backend"]["QnnHtp"]["kv-update-method"];
950
+ }
951
+ // By default, Qualla will default to the async init path.
952
+ // For now, we are forcing async init off unless explicitly
953
+ // specified in the Genie config. It is HTP specific feature only.
954
+ quallaEngineConfig["use-async-Init"] = false;
955
+ if (genieEngineConfig["backend"]["QnnHtp"].contains("allow-async-init")) {
956
+ quallaEngineConfig["use-async-Init"] =
957
+ genieEngineConfig["backend"]["QnnHtp"]["allow-async-init"];
958
+ }
959
+ } else if (genieEngineConfig["backend"]["type"] == "QnnGenAiTransformer") {
960
+ quallaEngineConfig["type"] = "qnn-cpu";
961
+ quallaEngineConfig["backend-lib"] = getLibName("QnnGenAiTransformer");
962
+ if (genieEngineConfig["backend"]["QnnGenAiTransformer"].contains("n-logits")) {
963
+ quallaEngineConfig["n_logits"] =
964
+ genieEngineConfig["backend"]["QnnGenAiTransformer"]["n-logits"];
965
+ }
966
+ if (genieEngineConfig["backend"]["QnnGenAiTransformer"].contains("use-mmap")) {
967
+ quallaEngineConfig["use-mmap"] =
968
+ genieEngineConfig["backend"]["QnnGenAiTransformer"]["use-mmap"];
969
+ }
970
+ if (genieEngineConfig["backend"]["QnnGenAiTransformer"].contains("n-layer")) {
971
+ quallaEngineConfig["n_layer"] =
972
+ genieEngineConfig["backend"]["QnnGenAiTransformer"]["n-layer"];
973
+ }
974
+ if (genieEngineConfig["backend"]["QnnGenAiTransformer"].contains("n-embd")) {
975
+ quallaEngineConfig["n_embd"] =
976
+ genieEngineConfig["backend"]["QnnGenAiTransformer"]["n-embd"];
977
+ }
978
+ if (genieEngineConfig["backend"]["QnnGenAiTransformer"].contains("n-heads")) {
979
+ quallaEngineConfig["n_heads"] =
980
+ genieEngineConfig["backend"]["QnnGenAiTransformer"]["n-heads"];
981
+ }
982
+ }
983
+
984
+ if (genieEngineConfig["backend"].contains("extensions")) {
985
+ quallaEngineConfig["backend-ext-conf"] = genieEngineConfig["backend"]["extensions"];
986
+ }
987
+
988
+ if (genieEngineConfig["model"]["type"] == "binary") {
989
+ quallaEngineConfig["model-list"] = genieEngineConfig["model"]["binary"]["ctx-bins"];
990
+ if (genieEngineConfig["model"]["binary"].contains("lora")) {
991
+ quallaEngineConfig["lora-version"] =
992
+ static_cast<uint8_t>(LORA_VERSION::GENIE_LORA_VERSION_V2);
993
+ if (genieEngineConfig["model"]["binary"]["lora"].contains("lora-version") &&
994
+ genieEngineConfig["model"]["binary"]["lora"]["lora-version"] == 1) {
995
+ quallaEngineConfig["lora-version"] =
996
+ genieEngineConfig["model"]["binary"]["lora"]["lora-version"];
997
+ }
998
+ for (int i = 0; i < genieEngineConfig["model"]["binary"]["lora"]["adapters"].size(); i++) {
999
+ quallaEngineConfig["lora"][i]["adapter-name"] =
1000
+ genieEngineConfig["model"]["binary"]["lora"]["adapters"][i]["name"];
1001
+ quallaEngineConfig["lora"][i]["alpha-tensor-name"] = "";
1002
+ if (genieEngineConfig["model"]["binary"]["lora"].contains("alpha-tensor-name")) {
1003
+ quallaEngineConfig["lora"][i]["alpha-tensor-name"] =
1004
+ genieEngineConfig["model"]["binary"]["lora"]["alpha-tensor-name"];
1005
+ }
1006
+ quallaEngineConfig["lora"][i]["alpha-tensor-value"] = 1.0f;
1007
+ quallaEngineConfig["lora"][i]["binsection-basedir"] = "";
1008
+ if (genieEngineConfig["model"]["binary"]["lora"].contains("lora-version") &&
1009
+ genieEngineConfig["model"]["binary"]["lora"]["lora-version"] == 1) {
1010
+ quallaEngineConfig["lora"][i]["path"] =
1011
+ genieEngineConfig["model"]["binary"]["lora"]["adapters"][i]["path"];
1012
+ } else {
1013
+ quallaEngineConfig["lora"][i]["bin-sections"] =
1014
+ genieEngineConfig["model"]["binary"]["lora"]["adapters"][i]["bin-sections"];
1015
+ }
1016
+ }
1017
+ }
1018
+ } else if (genieEngineConfig["model"]["type"] == "library") {
1019
+ quallaEngineConfig["model"] = getLibName("QnnGenAiTransformerModel");
1020
+ quallaEngineConfig["model-bin-path"] = genieEngineConfig["model"]["library"]["model-bin"];
1021
+ quallaEngineConfig["op-package"] =
1022
+ getLibName("QnnGenAiTransformerCpuOpPkg") + ":QnnOpPackage_interfaceProvider";
1023
+ }
1024
+ if (genieEngineConfig["model"].contains("positional-encoding")) {
1025
+ quallaEngineConfig["positional-encoding"]["type"] =
1026
+ genieEngineConfig["model"]["positional-encoding"]["type"];
1027
+ if (genieEngineConfig["model"]["positional-encoding"]["type"] == "rope") {
1028
+ quallaEngineConfig["positional-encoding"]["rope-dim"] =
1029
+ genieEngineConfig["model"]["positional-encoding"]["rope-dim"];
1030
+ if (genieEngineConfig["model"]["positional-encoding"].contains("rope-theta")) {
1031
+ quallaEngineConfig["positional-encoding"]["rope-theta"] =
1032
+ genieEngineConfig["model"]["positional-encoding"]["rope-theta"];
1033
+ }
1034
+ if (genieEngineConfig["model"]["positional-encoding"].contains("rope-scaling")) {
1035
+ if (genieEngineConfig["model"]["positional-encoding"]["rope-scaling"].contains(
1036
+ "rope-type")) {
1037
+ quallaEngineConfig["positional-encoding"]["rope-scaling"]["rope-type"] =
1038
+ genieEngineConfig["model"]["positional-encoding"]["rope-scaling"]["rope-type"];
1039
+ if (genieEngineConfig["model"]["positional-encoding"]["rope-scaling"]["rope-type"] ==
1040
+ "llama3") {
1041
+ if (genieEngineConfig["model"]["positional-encoding"]["rope-scaling"].contains(
1042
+ "factor")) {
1043
+ quallaEngineConfig["positional-encoding"]["rope-scaling"]["factor"] =
1044
+ genieEngineConfig["model"]["positional-encoding"]["rope-scaling"]["factor"];
1045
+ }
1046
+ if (genieEngineConfig["model"]["positional-encoding"]["rope-scaling"].contains(
1047
+ "low-freq-factor")) {
1048
+ quallaEngineConfig["positional-encoding"]["rope-scaling"]["low-freq-factor"] =
1049
+ genieEngineConfig["model"]["positional-encoding"]["rope-scaling"]
1050
+ ["low-freq-factor"];
1051
+ }
1052
+ if (genieEngineConfig["model"]["positional-encoding"]["rope-scaling"].contains(
1053
+ "high-freq-factor")) {
1054
+ quallaEngineConfig["positional-encoding"]["rope-scaling"]["high-freq-factor"] =
1055
+ genieEngineConfig["model"]["positional-encoding"]["rope-scaling"]
1056
+ ["high-freq-factor"];
1057
+ }
1058
+ if (genieEngineConfig["model"]["positional-encoding"]["rope-scaling"].contains(
1059
+ "original-max-position-embeddings")) {
1060
+ quallaEngineConfig["positional-encoding"]["rope-scaling"]
1061
+ ["original-max-position-embeddings"] =
1062
+ genieEngineConfig["model"]["positional-encoding"]
1063
+ ["rope-scaling"]
1064
+ ["original-max-position-embeddings"];
1065
+ }
1066
+ }
1067
+ if (genieEngineConfig["model"]["positional-encoding"]["rope-scaling"]["rope-type"] ==
1068
+ "longrope") {
1069
+ if (genieEngineConfig["model"]["positional-encoding"]["rope-scaling"].contains(
1070
+ "factor")) {
1071
+ quallaEngineConfig["positional-encoding"]["rope-scaling"]["factor"] =
1072
+ genieEngineConfig["model"]["positional-encoding"]["rope-scaling"]["factor"];
1073
+ }
1074
+ if (genieEngineConfig["model"]["positional-encoding"]["rope-scaling"].contains(
1075
+ "short-factor")) {
1076
+ quallaEngineConfig["positional-encoding"]["rope-scaling"]["short-factor"] =
1077
+ genieEngineConfig["model"]["positional-encoding"]["rope-scaling"]
1078
+ ["short-factor"];
1079
+ }
1080
+ if (genieEngineConfig["model"]["positional-encoding"]["rope-scaling"].contains(
1081
+ "long-factor")) {
1082
+ quallaEngineConfig["positional-encoding"]["rope-scaling"]["long-factor"] =
1083
+ genieEngineConfig["model"]["positional-encoding"]["rope-scaling"]
1084
+ ["long-factor"];
1085
+ }
1086
+ if (genieEngineConfig["model"]["positional-encoding"]["rope-scaling"].contains(
1087
+ "original-max-position-embeddings")) {
1088
+ quallaEngineConfig["positional-encoding"]["rope-scaling"]
1089
+ ["original-max-position-embeddings"] =
1090
+ genieEngineConfig["model"]["positional-encoding"]
1091
+ ["rope-scaling"]
1092
+ ["original-max-position-embeddings"];
1093
+ }
1094
+ }
1095
+ }
1096
+ }
1097
+ }
1098
+ }
1099
+ }
1100
+ }
1101
+
1102
+ static void translateMultiEngineConfig(const qualla::json& genieConfig,
1103
+ qualla::json& quallaConfig) {
1104
+ if (genieConfig["dialog"]["engine"].is_array()) {
1105
+ quallaConfig["engine"] = qualla::json::array();
1106
+ for (auto& item : genieConfig["dialog"]["engine"]) {
1107
+ qualla::json quallaEngineConfig;
1108
+ translateEngineConfig(item, quallaEngineConfig);
1109
+ quallaConfig["engine"].push_back(quallaEngineConfig);
1110
+ }
1111
+ } else {
1112
+ translateEngineConfig(genieConfig["dialog"]["engine"], quallaConfig["engine"]);
1113
+ }
1114
+ }
1115
+
1116
+ //=============================================================================
1117
+ // Dialog::Config functions
1118
+ //=============================================================================
1119
+
1120
// Process-wide registry mapping opaque GenieDialogConfig_Handle_t values to
// shared Config instances used by the C API surface.
qnn::util::HandleManager<Dialog::Config> Dialog::Config::s_manager;

// Register a Config and return an opaque handle for it.
GenieDialogConfig_Handle_t Dialog::Config::add(std::shared_ptr<Dialog::Config> config) {
  return (GenieDialogConfig_Handle_t)s_manager.add(config);
}

// Look up a previously registered Config; behavior for an unknown handle is
// defined by HandleManager::get.
std::shared_ptr<Dialog::Config> Dialog::Config::get(GenieDialogConfig_Handle_t handle) {
  return s_manager.get((qnn::util::Handle_t)handle);
}

// Drop the registry's reference for this handle.
void Dialog::Config::remove(GenieDialogConfig_Handle_t handle) {
  s_manager.remove((qnn::util::Handle_t)handle);
}
1133
+
1134
#if defined(GENIE_SSD_FEATURE)
// Validate the "ssd-q1" dialog sub-config: enforce mandatory keys, per-key
// types/values, and cross-field constraints. Throws Exception with a
// GENIE_STATUS_ERROR_JSON_* code on the first violation found.
static void validateDialogSsdConfig(const qualla::json& config) {
  if (!config.is_object()) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "ssd-q1 config is not an object");
  }

  std::set<std::string> mandatoryFields{"version",
                                        "ssd-version",
                                        "forecast-token-count",
                                        "branches",
                                        "forecast-prefix",
                                        "forecast-prefix-name"};
  for (const auto& field : mandatoryFields) {
    if (!config.contains(field)) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing ssd-q1 field: " + field);
    }
  }

  // component is used in the "ENFORCE" macros (they reference `item` and
  // `component` by name, so these identifiers must not be renamed).
  std::string component = "ssd-q1";

  // Captured during the key loop for the cross-field checks at the end.
  int branchesSize = 0;
  int forecastTokenCount = 0;

  int nStreams = 1;
  float pThreshold = 0.0;

  for (auto& item : config.items()) {
    if (item.key() == "version") {
      JSON_ENFORCE_NUMERIC();
      // Only config schema version 1 is supported.
      if (item.value().get<int>() != 1) {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Invalid ssd-q1 config: unsupported version: " + item.value().dump());
      }
    } else if (item.key() == "ssd-version") {
      JSON_ENFORCE_NUMERIC();
    } else if (item.key() == "forecast-token-count") {
      JSON_ENFORCE_NUMERIC();
      forecastTokenCount = item.value();
    } else if (item.key() == "branches") {
      JSON_ENFORCE_ARRAY();
      for (auto& elem : item.value()) {
        if (!elem.is_number_integer()) {
          throw Exception(GENIE_STATUS_ERROR_JSON_VALUE, "branches must be an array of integers");
        }
      }
      branchesSize = item.value().size();
    } else if (item.key() == "forecast-prefix") {
      JSON_ENFORCE_NUMERIC();
    } else if (item.key() == "forecast-prefix-name") {
      JSON_ENFORCE_STRING();
    } else if (item.key() == "n-streams") {
      JSON_ENFORCE_NUMERIC();
      nStreams = item.value();
    } else if (item.key() == "p-threshold") {
      JSON_ENFORCE_NUMERIC();
      pThreshold = item.value();
    } else {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown ssd-q1 config key: " + item.key());
    }
  }

  // p-threshold is only meaningful when multiple streams are configured.
  if ((pThreshold > 0.0) && (nStreams <= 1)) {
    throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                    "p-threshold can only be used with multistream (n-streams > 1)");
  }

  // NOTE(review): the condition rejects only strictly-greater sizes (i.e. it
  // permits branchesSize == forecastTokenCount), while the message says
  // "less than" — confirm which of the two is intended.
  if (branchesSize > forecastTokenCount) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                    "Size of branches array must be less than forecast-token-count");
  }
}
#endif
1207
+
1208
#if defined(GENIE_LADE_FEATURE)
// Validate the "lade" (lookahead-decoding) dialog sub-config: mandatory keys,
// per-key types, and the allowed update-mode values. Throws Exception with a
// GENIE_STATUS_ERROR_JSON_* code on the first violation found.
static void validateDialogLadeConfig(const qualla::json& config) {
  if (!config.is_object()) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "lade config is not an object");
  }

  std::set<std::string> mandatoryFields{"version", "update-mode", "window", "ngram", "gcap"};
  for (const auto& field : mandatoryFields) {
    if (!config.contains(field)) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing lade field: " + field);
    }
  }

  // component is used in the "ENFORCE" macros (they reference `item` and
  // `component` by name, so these identifiers must not be renamed).
  std::string component = "lade";

  for (auto& item : config.items()) {
    if (item.key() == "version") {
      JSON_ENFORCE_NUMERIC();
      // Only config schema version 1 is supported.
      if (item.value().get<int>() != 1) {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Invalid lade config: unsupported version: " + item.value().dump());
      }
    } else if (item.key() == "update-mode") {
      JSON_ENFORCE_STRING();
      // update-mode is restricted to the three supported strategies.
      std::string mode = item.value().get<std::string>();
      if ((mode != "FWD_MAX_HIT") && (mode != "FWD_LEVEL") && (mode != "ALWAYS_FWD_ONE")) {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Invalid lade config: unsupported update-mode: " + item.value().dump());
      }
    } else if (item.key() == "window") {
      JSON_ENFORCE_NUMERIC();
    } else if (item.key() == "ngram") {
      JSON_ENFORCE_NUMERIC();
    } else if (item.key() == "gcap") {
      JSON_ENFORCE_NUMERIC();
    } else {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown lade config key: " + item.key());
    }
  }
}
#endif
1250
+
1251
#if defined(GENIE_SPD_FEATURE)
// Validate the "spd" (speculative decoding) dialog sub-config: mandatory keys
// and per-key types. Throws Exception with a GENIE_STATUS_ERROR_JSON_* code
// on the first violation found.
static void validateDialogSpdConfig(const qualla::json& config) {
  if (!config.is_object()) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "spd config is not an object");
  }

  std::set<std::string> mandatoryFields{"version", "draft-len"};
  for (const auto& field : mandatoryFields) {
    if (!config.contains(field)) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing spd field: " + field);
    }
  }

  // component is used in the "ENFORCE" macros (they reference `item` and
  // `component` by name, so these identifiers must not be renamed).
  std::string component = "spd";
  for (auto& item : config.items()) {
    if (item.key() == "version") {
      JSON_ENFORCE_NUMERIC();
      // Only config schema version 1 is supported.
      if (item.value().get<int>() != 1) {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Invalid spd config: unsupported version: " + item.value().dump());
      }
    } else if (item.key() == "draft-len") {
      JSON_ENFORCE_NUMERIC();
    } else {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown spd config key: " + item.key());
    }
  }
}
#endif
1281
+
1282
#if defined(GENIE_MULTISTREAM_FEATURE)
// Validate the "multistream" dialog sub-config: mandatory keys and per-key
// types. Throws Exception with a GENIE_STATUS_ERROR_JSON_* code on the first
// violation found.
static void validateDialogMultistreamConfig(const qualla::json& config) {
  if (!config.is_object()) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "multistream config is not an object");
  }

  std::set<std::string> mandatoryFields{"version", "n-streams"};
  for (const auto& field : mandatoryFields) {
    if (!config.contains(field)) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing multistream field: " + field);
    }
  }

  // component is used in the "ENFORCE" macros (they reference `item` and
  // `component` by name, so these identifiers must not be renamed).
  std::string component = "multistream";

  for (auto& item : config.items()) {
    if (item.key() == "version") {
      JSON_ENFORCE_NUMERIC();
      // Only config schema version 1 is supported.
      if (item.value().get<int>() != 1) {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Invalid multistream config: unsupported version: " + item.value().dump());
      }
    } else if (item.key() == "n-streams") {
      JSON_ENFORCE_NUMERIC();
    } else if (item.key() == "p-threshold") {
      JSON_ENFORCE_NUMERIC();
    } else {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                      "Unknown multistream config key: " + item.key());
    }
  }
}
#endif
1316
+
1317
// Validate the top-level "dialog" config object: mandatory keys, per-key
// types, the dialog "type" value, and — when a speculative/lookahead feature
// is compiled in — that the matching sub-config is present exactly when the
// corresponding dialog type is selected. Throws Exception with a
// GENIE_STATUS_ERROR_JSON_* code on the first violation found.
static void validateDialogConfig(const qualla::json& config) {
  if (!config.is_object()) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Dialog config is not an object");
  }

  std::set<std::string> mandatoryFields{"version", "type", "context", "tokenizer", "engine"};
  for (const auto& field : mandatoryFields) {
    if (!config.contains(field)) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing dialog field: " + field);
    }
  }

  // component is used in the "ENFORCE" macros (they reference `item` and
  // `component` by name, so these identifiers must not be renamed).
  std::string component = "dialog";

  std::string dialogType = "basic";
#if defined(GENIE_SSD_FEATURE)
  // Set when type == "ssd-q1"; the captured sub-config is validated after the loop.
  bool ssdq1 = false;
  qualla::json ssdq1Config;
#endif
#if defined(GENIE_LADE_FEATURE)
  bool lade = false;
  qualla::json ladeConfig;
#endif
#if defined(GENIE_SPD_FEATURE)
  bool spd = false;
  qualla::json spdConfig;
#endif
#if defined(GENIE_MULTISTREAM_FEATURE)
  bool multistream = false;
  qualla::json multistreamConfig;
#endif

  for (auto& item : config.items()) {
    if (item.key() == "version") {
      JSON_ENFORCE_NUMERIC();
      // Only config schema version 1 is supported.
      if (item.value().get<int>() != 1) {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "Invalid dialog config: unsupported version: " + item.value().dump());
      }
    } else if (item.key() == "type") {
      JSON_ENFORCE_STRING();
      dialogType = item.value();
      if (dialogType == "basic" || dialogType == "kv-share") {
        // Do nothing
#if defined(GENIE_SSD_FEATURE)
      } else if (dialogType == "ssd-q1") {
        ssdq1 = true;
#endif
#if defined(GENIE_LADE_FEATURE)
      } else if (dialogType == "lade") {
        lade = true;
#endif
#if defined(GENIE_SPD_FEATURE)
      } else if (dialogType == "spd") {
        spd = true;
#endif
#if defined(GENIE_MULTISTREAM_FEATURE)
      } else if (dialogType == "multistream") {
        multistream = true;
#endif
      } else {
        // Also reached for feature-gated types when that feature is compiled out.
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE, "Invalid dialog type: " + dialogType);
      }
#if defined(GENIE_SSD_FEATURE)
    } else if (item.key() == "ssd-q1") {
      JSON_ENFORCE_OBJECT();
      ssdq1Config = item.value();
      // ssd-q1 validation is done below
#endif
#if defined(GENIE_LADE_FEATURE)
    } else if (item.key() == "lade") {
      JSON_ENFORCE_OBJECT();
      ladeConfig = item.value();
      // lade validation is done below
#endif
#if defined(GENIE_SPD_FEATURE)
    } else if (item.key() == "spd") {
      JSON_ENFORCE_OBJECT();
      spdConfig = item.value();
      // spd validation is done below
#endif
#if defined(GENIE_MULTISTREAM_FEATURE)
    } else if (item.key() == "multistream") {
      JSON_ENFORCE_OBJECT();
      multistreamConfig = item.value();
      // multistream validation is done below
#endif
    } else if (item.key() == "stop-sequence") {
      JSON_ENFORCE_ARRAY();
      for (auto& elem : item.value()) {
        if (!elem.is_string()) {
          throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                          "stop-sequence must be an array of strings");
        }
      }
    } else if (item.key() == "max-num-tokens") {
      JSON_ENFORCE_NUMERIC();
      // NOTE(review): the check rejects only negative values (0 passes) while
      // the message says "must be > 0" — confirm whether 0 is a valid limit.
      if (item.value().get<int>() < 0) {
        throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
                        "number of tokens must be > 0. provided: " + item.value().dump());
      }
    } else if (item.key() == "context") {
      JSON_ENFORCE_OBJECT();
      validateContextConfig(item.value());
    } else if (item.key() == "tokenizer") {
      JSON_ENFORCE_OBJECT();
      validateTokenizerConfig(item.value());
    } else if (item.key() == "sampler") {
      JSON_ENFORCE_OBJECT();
      validateSamplerConfig(item.value());
    } else if (item.key() == "engine") {
      JSON_ENFORCE_ARRAY_OR_OBJECT();
    } else if (item.key() == "embedding") {
      JSON_ENFORCE_OBJECT();
      validateEmbeddingConfig(item.value());
    } else {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown dialog config key: " + item.key());
    }
  }

  // Engine verification requires dialogType for engine roles. Since "type" may
  // be encountered later than "engine" in the loop, engine validation is
  // performed out of the loop, after dialogType is known.
  validateMultiEngineConfig(config["engine"], dialogType);

#if defined(GENIE_SSD_FEATURE)
  // A feature sub-config must be present iff the matching dialog type was selected.
  if (ssdq1) {
    if (!ssdq1Config.is_object()) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing ssd-q1 dialog config");
    }
    validateDialogSsdConfig(ssdq1Config);
  } else {
    if (ssdq1Config.is_object()) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                      "ssd-q1 dialog config for incorrect dialog type: " + dialogType);
    }
  }
#endif
#if defined(GENIE_LADE_FEATURE)
  if (lade) {
    if (!ladeConfig.is_object()) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing lade dialog config");
    }
    validateDialogLadeConfig(ladeConfig);
  } else {
    if (ladeConfig.is_object()) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                      "lade dialog config for incorrect dialog type: " + dialogType);
    }
  }
#endif
#if defined(GENIE_SPD_FEATURE)
  if (spd) {
    if (!spdConfig.is_object()) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing spd dialog config");
    }
    validateDialogSpdConfig(spdConfig);
  } else {
    if (spdConfig.is_object()) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                      "spd dialog config for incorrect dialog type: " + dialogType);
    }
  }
#endif
#if defined(GENIE_MULTISTREAM_FEATURE)
  if (multistream) {
    if (!multistreamConfig.is_object()) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing multistream dialog config");
    }
    validateDialogMultistreamConfig(multistreamConfig);
  } else {
    if (multistreamConfig.is_object()) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                      "multistream dialog config for incorrect dialog type: " + dialogType);
    }
  }
#endif
}
1495
+
1496
// Translate a validated Genie dialog config into the qualla dialog config
// schema: map the dialog type (lade -> lhd-dec, spd -> spec-dec), copy the
// feature-specific keys, then delegate the context/tokenizer/sampler/engine/
// embedding sections to their dedicated translators.
static void translateDialogConfig(const qualla::json& genieConfig, qualla::json& quallaConfig) {
  if (genieConfig["dialog"]["version"] == 1) {
    // Genie type names differ from qualla's for lade and spd.
    if (genieConfig["dialog"]["type"] == "lade") {
      quallaConfig["type"] = "lhd-dec";
    } else if (genieConfig["dialog"]["type"] == "spd") {
      quallaConfig["type"] = "spec-dec";
    } else if (genieConfig["dialog"]["type"] == "multistream") {
      quallaConfig["type"] = "multistream";
    } else {
      quallaConfig["type"] = genieConfig["dialog"]["type"];
    }
#if defined(GENIE_SSD_FEATURE)
    if (genieConfig["dialog"]["type"] == "ssd-q1") {
      quallaConfig["ssd-version"] = genieConfig["dialog"]["ssd-q1"]["ssd-version"];
      quallaConfig["forecast-token-count"] =
          genieConfig["dialog"]["ssd-q1"]["forecast-token-count"];
      quallaConfig["branches"] = genieConfig["dialog"]["ssd-q1"]["branches"];
      quallaConfig["forecast-prefix"] = genieConfig["dialog"]["ssd-q1"]["forecast-prefix"];
      quallaConfig["forecast-prefix-name"] =
          genieConfig["dialog"]["ssd-q1"]["forecast-prefix-name"];

      // Optional multistream tuning keys within ssd-q1.
      if (genieConfig["dialog"]["ssd-q1"].contains("n-streams")) {
        quallaConfig["n-streams"] = genieConfig["dialog"]["ssd-q1"]["n-streams"];
      }
      if (genieConfig["dialog"]["ssd-q1"].contains("p-threshold")) {
        quallaConfig["p-threshold"] = genieConfig["dialog"]["ssd-q1"]["p-threshold"];
      }
    }
#endif
#if defined(GENIE_LADE_FEATURE)
    if (genieConfig["dialog"]["type"] == "lade") {
      quallaConfig["lhd-update-mode"] = genieConfig["dialog"]["lade"]["update-mode"];
      quallaConfig["window"] = genieConfig["dialog"]["lade"]["window"];
      quallaConfig["ngram"] = genieConfig["dialog"]["lade"]["ngram"];
      quallaConfig["gcap"] = genieConfig["dialog"]["lade"]["gcap"];
    }
#endif
#if defined(GENIE_SPD_FEATURE)
    if (genieConfig["dialog"]["type"] == "spd") {
      quallaConfig["draft-len"] = genieConfig["dialog"]["spd"]["draft-len"];
    }
#endif
#if defined(GENIE_MULTISTREAM_FEATURE)
    if (genieConfig["dialog"]["type"] == "multistream") {
      quallaConfig["n-streams"] = genieConfig["dialog"]["multistream"]["n-streams"];
      if (genieConfig["dialog"]["multistream"].contains("p-threshold")) {
        quallaConfig["p-threshold"] = genieConfig["dialog"]["multistream"]["p-threshold"];
      }
    }
#endif
  }
  // stop-sequence lives under "prompt" in the qualla schema.
  if (genieConfig["dialog"].contains("stop-sequence")) {
    quallaConfig["prompt"]["stop-sequence"] = genieConfig["dialog"]["stop-sequence"];
  }

  translateContextConfig(genieConfig, quallaConfig);
  translateTokenizerConfig(genieConfig, quallaConfig);
  translateSamplerConfig(genieConfig, quallaConfig);
  translateMultiEngineConfig(genieConfig, quallaConfig);
  translateEmbeddingConfig(genieConfig, quallaConfig);
}
1557
+
1558
+ uint32_t getMaxNumTokens(const qualla::json& genieConfig) {
1559
+ uint32_t tokenLimit{UINT32_MAX};
1560
+ if (genieConfig["dialog"]["version"] == 1) {
1561
+ if (genieConfig["dialog"].contains("max-num-tokens")) {
1562
+ tokenLimit = genieConfig["dialog"]["max-num-tokens"];
1563
+ }
1564
+ }
1565
+ return tokenLimit;
1566
+ }
1567
+
1568
// Parse and validate a dialog configuration JSON string.
// A parse callback rejects duplicate depth-1 keys (which the parser would
// otherwise silently collapse). The only permitted — and required — top-level
// key is "dialog". Throws Exception with a GENIE_STATUS_ERROR_JSON_* code on
// any schema or value violation.
Dialog::Config::Config(const char* configStr) {
  qualla::json config;
  rope_theta_set = false;
  position_dim_set = false;
  {
    // Keys seen at depth 1 so far, used to detect duplicates during parsing.
    std::set<qualla::json> keys;

    auto callback = [&keys](int depth, qualla::json::parse_event_t event, qualla::json& parsed) {
      if ((depth == 1) && (event == qualla::json::parse_event_t::key)) {
        if (keys.count(parsed) > 0) {
          throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
                          "Multiple dialog config key: " + parsed.dump());
        }
        keys.insert(parsed);
      }
      return true;  // keep every parsed element
    };

    config = qualla::json::parse(configStr, callback);
  }

  if (!config.is_object()) {
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Dialog config is not an object");
  }

  std::set<std::string> mandatoryFields{"dialog"};
  for (const auto& field : mandatoryFields) {
    if (!config.contains(field)) {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing dialog field: " + field);
    }
  }

  // component is used in the "ENFORCE" macros (they reference `item` and
  // `component` by name, so these identifiers must not be renamed).
  std::string component = "dialog";

  for (auto& item : config.items()) {
    if (item.key() == "dialog") {
      JSON_ENFORCE_OBJECT();
      validateDialogConfig(item.value());
    } else {
      throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown dialog config key: " + item.key());
    }
  }
  // Only store the config once it has fully validated.
  m_config = config;
}
1613
+
1614
// Return a copy of the validated configuration JSON.
qualla::json Dialog::Config::getJson() const { return m_config; }
1615
+
1616
+ //=============================================================================
1617
+ // Dialog functions
1618
+ //=============================================================================
1619
+
1620
// Process-wide registry mapping opaque GenieDialog_Handle_t values to shared
// Dialog instances used by the C API surface.
qnn::util::HandleManager<Dialog> Dialog::s_manager;
// Monotonic counter used to give each created qualla dialog a unique name.
std::atomic<std::uint32_t> Dialog::s_nameCounter{0u};

// Register a Dialog and return an opaque handle for it.
GenieDialog_Handle_t Dialog::add(std::shared_ptr<Dialog> dialog) {
  return (GenieDialog_Handle_t)s_manager.add(dialog);
}

// Look up a previously registered Dialog; behavior for an unknown handle is
// defined by HandleManager::get.
std::shared_ptr<Dialog> Dialog::get(GenieDialog_Handle_t handle) {
  return s_manager.get((qnn::util::Handle_t)handle);
}

// Drop the registry's reference for this handle.
void Dialog::remove(GenieDialog_Handle_t handle) { s_manager.remove((qnn::util::Handle_t)handle); }
1632
+
1633
// Construct a Dialog from a validated Config: translate the Genie JSON into
// the qualla schema, record the optional max-num-tokens limit, and create the
// underlying qualla::Dialog with a process-unique name ("dialog0", "dialog1", ...).
// Throws Exception on creation failure.
Dialog::Dialog(std::shared_ptr<Config> config) {
  auto env = qualla::Env::create(qualla::json{});
  qualla::json quallaConfig;
  translateDialogConfig(config->getJson(), quallaConfig);
  m_tokenLimit = getMaxNumTokens(config->getJson());
  // fetch_add keeps the name unique even with concurrent constructions.
  m_quallaDialog = qualla::Dialog::create(
      env, "dialog" + std::to_string(s_nameCounter.fetch_add(1u)), quallaConfig);
  if (!m_quallaDialog) {
    throw Exception(GENIE_STATUS_ERROR_MEM_ALLOC, "Could not create a dialog object");
  }
}
1644
+
1645
// Compile-time guarantee that the public GenieDialog_SentenceCode_t constants
// and the internal qualla::Sentence::Code enumerators share numeric values,
// so the static_casts between the two in the query paths are value-preserving.
static_assert(qualla::Sentence::Code::COMPLETE ==
              static_cast<qualla::Sentence::Code>(GENIE_DIALOG_SENTENCE_COMPLETE));
static_assert(qualla::Sentence::Code::BEGIN ==
              static_cast<qualla::Sentence::Code>(GENIE_DIALOG_SENTENCE_BEGIN));
static_assert(qualla::Sentence::Code::CONTINUE ==
              static_cast<qualla::Sentence::Code>(GENIE_DIALOG_SENTENCE_CONTINUE));
static_assert(qualla::Sentence::Code::END ==
              static_cast<qualla::Sentence::Code>(GENIE_DIALOG_SENTENCE_END));
static_assert(qualla::Sentence::Code::ABORT ==
              static_cast<qualla::Sentence::Code>(GENIE_DIALOG_SENTENCE_ABORT));
1655
+
1656
// Run a text query against the underlying qualla dialog, streaming each
// response fragment to the user callback. Generation stops after m_tokenLimit
// callback invocations; if the limit is hit while the sentence is still open
// (BEGIN/CONTINUE), a final synthetic SENTENCE_END callback is emitted so the
// caller always observes a terminated sentence.
int32_t Dialog::query(const char* queryStr,
                      GenieDialog_SentenceCode_t sentenceCode,
                      GenieDialog_QueryCallback_t callback,
                      const void* userData) {
  std::string query(queryStr);
  uint32_t genTokenCount = 0u;
  bool status = m_quallaDialog->query(
      query,
      static_cast<qualla::Sentence::Code>(sentenceCode),
      [&](const std::string& response, qualla::Sentence::Code code) {
        callback(response.c_str(), static_cast<GenieDialog_SentenceCode_t>(code), userData);
        // Returning false tells qualla to stop generating.
        bool keepGoing = ++genTokenCount < m_tokenLimit;
        if (!keepGoing && ((code == qualla::Sentence::Code::BEGIN) ||
                           (code == qualla::Sentence::Code::CONTINUE))) {
          callback("", GENIE_DIALOG_SENTENCE_END, userData);
        }
        return keepGoing;
      });
  // NOTE(review): KPI reporting via printf from a library entry point is
  // presumably for development/bring-up — consider routing through a logger.
  qualla::Dialog::KPIs kpis = m_quallaDialog->kpis();
  printf(
      "\n\n[KPIS]:\nInit Time: %zu us\nPrompt Processing Time: %zu us, Prompt Processing Rate : "
      "%f toks/sec\n"
      "Token Generation Time: %zu us, Token Generation Rate: %f toks/sec\n",
      kpis.init.total_usec,
      kpis.prompt.last_usec,
      kpis.tps.prompt,
      kpis.generate.last_usec,
      kpis.tps.generate);
  return (status) ? (GENIE_STATUS_SUCCESS) : (GENIE_STATUS_ERROR_QUERY_FAILED);
}
1686
+
1687
+ int32_t Dialog::save(const std::string& name) {
1688
+ return m_quallaDialog->save(name) ? (GENIE_STATUS_SUCCESS) : (GENIE_STATUS_ERROR_QUERY_FAILED);
1689
+ }
1690
+
1691
+ int32_t Dialog::restore(const std::string& name) {
1692
+ return m_quallaDialog->restore(name) ? (GENIE_STATUS_SUCCESS) : (GENIE_STATUS_ERROR_QUERY_FAILED);
1693
+ }
1694
+
1695
#if defined(GENIE_E2T_FEATURE)
// Run an embedding-to-text query: the caller supplies a buffer of one or more
// embedding vectors (its size must be a whole multiple of the dialog's
// embedding vector size in bytes). Streamed response fragments go to
// `callback`; when provided, `t2eCallback` lets the caller map generated
// tokens back to embeddings. The same m_tokenLimit cutoff and synthetic
// SENTENCE_END behavior as Dialog::query() apply.
int32_t Dialog::embeddingQuery(const void* embeddings,
                               const uint32_t embeddingsSize,
                               GenieDialog_SentenceCode_t sentenceCode,
                               GenieDialog_TokenToEmbeddingCallback_t t2eCallback,
                               GenieDialog_QueryCallback_t callback,
                               const void* userData) {
  uint32_t genTokenCount = 0u;

  if (embeddingsSize % m_quallaDialog->getEmbeddingBufferSize() != 0) {
    throw std::runtime_error(
        "The embeddings buffer size must be an integer multiple of the embedding vector size in "
        "bytes.");
  }

  // Copy the caller's raw buffer; qualla consumes a byte vector.
  const uint8_t* embeddingsSrc = static_cast<const uint8_t*>(embeddings);
  std::vector<uint8_t> embeddingVector(embeddingsSrc, embeddingsSrc + embeddingsSize);

  // Only wrap the token-to-embedding callback when the caller supplied one.
  qualla::Dialog::T2ECallback t2eQuallaCallback{nullptr};
  if (t2eCallback) {
    t2eQuallaCallback = [&](const int32_t token, void* embedding, const uint32_t embd_size) {
      t2eCallback(token, embedding, embd_size, userData);
    };
  }

  bool status = m_quallaDialog->query(
      embeddingVector,
      static_cast<qualla::Sentence::Code>(sentenceCode),
      t2eQuallaCallback,
      [&](const std::string& response, qualla::Sentence::Code code) {
        callback(response.c_str(), static_cast<GenieDialog_SentenceCode_t>(code), userData);
        // Returning false tells qualla to stop generating.
        bool keepGoing = ++genTokenCount < m_tokenLimit;
        if (!keepGoing && ((code == qualla::Sentence::Code::BEGIN) ||
                           (code == qualla::Sentence::Code::CONTINUE))) {
          callback("", GENIE_DIALOG_SENTENCE_END, userData);
        }
        return keepGoing;
      });
  // NOTE(review): KPI printf duplicated across the query entry points —
  // candidate for a shared helper.
  qualla::Dialog::KPIs kpis = m_quallaDialog->kpis();
  printf(
      "\n\n[KPIS]:\nInit Time: %zu us\nPrompt Processing Time: %zu us, Prompt Processing Rate : "
      "%f toks/sec\n"
      "Token Generation Time: %zu us, Token Generation Rate: %f toks/sec\n",
      kpis.init.total_usec,
      kpis.prompt.last_usec,
      kpis.tps.prompt,
      kpis.generate.last_usec,
      kpis.tps.generate);
  return (status) ? (GENIE_STATUS_SUCCESS) : (GENIE_STATUS_ERROR_QUERY_FAILED);
}
#endif
1746
+
1747
// Reset the underlying qualla dialog state.
void Dialog::reset() { m_quallaDialog->reset(); }
1748
+
1749
+ #if defined(GENIE_LORA_FEATURE)
1750
+
1751
+ int32_t Dialog::applyLora(std::string loraAdapterName, std::string engine) {
1752
+ bool status = m_quallaDialog->applyLoraAdapter(loraAdapterName, engine);
1753
+ return (status) ? (GENIE_STATUS_SUCCESS) : (GENIE_STATUS_ERROR_GENERAL);
1754
+ }
1755
+
1756
+ int32_t Dialog::applyLoraStrength(std::string tensorName, std::string engine, float alpha) {
1757
+ bool status = m_quallaDialog->applyLoraStrength(tensorName, alpha, engine);
1758
+ return (status) ? (GENIE_STATUS_SUCCESS) : (GENIE_STATUS_ERROR_GENERAL);
1759
+ }
1760
+
1761
+ #endif
1762
+
1763
// Run a token-level query: the caller supplies raw input token IDs and
// receives generated token IDs through `callback`. The response lambda is
// installed on the member dialogCallback (token-callback mode) before
// invoking qualla. The same m_tokenLimit cutoff and synthetic SENTENCE_END
// behavior as Dialog::query() apply.
int32_t Dialog::tokenQuery(const uint32_t* tokens,
                           const uint32_t sizeInputTokens,
                           GenieDialog_SentenceCode_t sentenceCode,
                           GenieDialog_TokenQueryCallback_t callback,
                           const void* userData) {
  std::vector<uint32_t> inputTokens;
  for (size_t i = 0; i < sizeInputTokens; i++) {
    inputTokens.push_back(tokens[i]);
  }
  uint32_t genTokenCount = 0u;
  dialogCallback.setCallBackType(qualla::QUALLA_CALLBACK_TYPE_TOKEN);
  dialogCallback.getTokenCbFunc() = std::make_shared<
      std::function<bool(const int32_t*, const uint32_t, qualla::Sentence::Code)>>();
  // NOTE(review): this lambda captures locals (genTokenCount, callback,
  // userData) by reference yet is stored in the member dialogCallback — safe
  // only if qualla consumes it synchronously within the query() call below
  // and never retains it afterwards; confirm.
  *(dialogCallback.getTokenCbFunc()) = [&](const int32_t* responseTokens,
                                           const uint32_t sizeResponseTokens,
                                           qualla::Sentence::Code code) {
    callback((const uint32_t*)responseTokens,
             sizeResponseTokens,
             static_cast<GenieDialog_SentenceCode_t>(code),
             userData);
    // Returning false tells qualla to stop generating.
    bool keepGoing = ++genTokenCount < m_tokenLimit;
    if (!keepGoing &&
        ((code == qualla::Sentence::Code::BEGIN) || (code == qualla::Sentence::Code::CONTINUE))) {
      callback(nullptr, 0, GENIE_DIALOG_SENTENCE_END, userData);
    }
    return keepGoing;
  };
  bool status = m_quallaDialog->query((const std::vector<uint32_t>)inputTokens,
                                      static_cast<qualla::Sentence::Code>(sentenceCode),
                                      dialogCallback);
  qualla::Dialog::KPIs kpis = m_quallaDialog->kpis();
  printf(
      "\n\n[KPIS]:\nInit Time: %zu us\nPrompt Processing Time: %zu us, Prompt Processing Rate : "
      "%f toks/sec\n"
      "Token Generation Time: %zu us, Token Generation Rate: %f toks/sec\n",
      kpis.init.total_usec,
      kpis.prompt.last_usec,
      kpis.tps.prompt,
      kpis.generate.last_usec,
      kpis.tps.generate);
  return (status) ? (GENIE_STATUS_SUCCESS) : (GENIE_STATUS_ERROR_QUERY_FAILED);
}
Genie/Genie/src/Dialog.hpp ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
//==============================================================================
//
// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
// All Rights Reserved.
// Confidential and Proprietary - Qualcomm Technologies, Inc.
//
//==============================================================================

#pragma once

#include <atomic>
#include <memory>

#include "GenieDialog.h"
#include "Util/HandleManager.hpp"
#include "qualla/dialog.hpp"
#include "qualla/DialogCallback.hpp"

namespace genie {

// Version tag for LoRA adapter handling. 0xFF marks "not configured".
enum LORA_VERSION : uint8_t {
  GENIE_LORA_VERSION_V1        = 0x1,
  GENIE_LORA_VERSION_V2        = 0x2,
  GENIE_LORA_VERSION_UNDEFINED = 0xFF
};

// C++ implementation behind the GenieDialog C API. Instances are registered
// in a process-wide HandleManager and addressed by opaque handles.
class Dialog {
 public:
  // Parsed JSON configuration for a Dialog, also handle-managed so the C API
  // can create/free configs independently of dialogs.
  class Config {
   public:
    // Registers `config` and returns an opaque handle for the C API.
    static GenieDialogConfig_Handle_t add(std::shared_ptr<Config> config);
    // Looks up a previously added config; null shared_ptr if unknown.
    static std::shared_ptr<Config> get(GenieDialogConfig_Handle_t handle);
    // Drops the registry's reference for `handle`.
    static void remove(GenieDialogConfig_Handle_t handle);

    // Parses `configStr` as JSON (may throw on malformed input).
    Config(const char* configStr);
    qualla::json getJson() const;

   private:
    static qnn::util::HandleManager<Config> s_manager;
    qualla::json m_config;  // parsed configuration document
  };

  // Handle registry for Dialog instances (mirrors Config's add/get/remove).
  static GenieDialog_Handle_t add(std::shared_ptr<Dialog> dialog);
  static std::shared_ptr<Dialog> get(GenieDialog_Handle_t handle);
  static void remove(GenieDialog_Handle_t handle);

  // Callback plumbing shared with the qualla layer; configured per query.
  qualla::DialogCallback dialogCallback;

  Dialog(std::shared_ptr<Config> config);

  // Non-copyable, non-movable: owns engine/handle state.
  Dialog(const Dialog&) = delete;
  Dialog& operator=(const Dialog&) = delete;
  Dialog(Dialog&&) = delete;
  Dialog& operator=(Dialog&&) = delete;

  // Text query; streams responses through `callback`.
  int32_t query(const char* queryStr,
                GenieDialog_SentenceCode_t sentenceCode,
                GenieDialog_QueryCallback_t callback,
                const void* userData);

  // Persist dialog state to the given path.
  int32_t save(const std::string&);

  // Restore dialog state from the given path.
  int32_t restore(const std::string&);

#if defined(GENIE_E2T_FEATURE)
  // Embedding-input query (embedding-to-text feature builds only).
  int32_t embeddingQuery(const void* embeddings,
                         const uint32_t embeddingsSize,
                         GenieDialog_SentenceCode_t sentenceCode,
                         GenieDialog_TokenToEmbeddingCallback_t t2eCallback,
                         GenieDialog_QueryCallback_t callback,
                         const void* userData);
#endif

  // Token-input query; streams generated tokens through `callback`.
  int32_t tokenQuery(const uint32_t* tokens,
                     const uint32_t sizeInputTokens,
                     GenieDialog_SentenceCode_t sentenceCode,
                     GenieDialog_TokenQueryCallback_t callback,
                     const void* userData);

  // Clears conversation/KV state so the dialog can be reused.
  void reset();

#if defined(GENIE_LORA_FEATURE)
  int32_t applyLora(std::string loraAdapterName, std::string engine);
  int32_t applyLoraStrength(std::string tensorName, std::string engine, float alpha);
#endif

 private:
  std::unique_ptr<qualla::Dialog> m_quallaDialog;  // backing qualla dialog
  uint32_t m_tokenLimit{UINT32_MAX};               // max generated tokens per query
  static qnn::util::HandleManager<Dialog> s_manager;
  static std::atomic<std::uint32_t> s_nameCounter;  // unique-name generation
};
}  // namespace genie
Genie/Genie/src/Exception.hpp ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ //==============================================================================
2
+ //
3
+ // Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4
+ // All Rights Reserved.
5
+ // Confidential and Proprietary - Qualcomm Technologies, Inc.
6
+ //
7
+ //==============================================================================
8
+
9
+ #pragma once
10
+
11
+ #include <exception>
12
+ #include <string>
13
+
14
+ #include "GenieCommon.h"
15
+
16
+ namespace genie {
17
+
18
+ class Exception : public std::runtime_error {
19
+ public:
20
+ Exception(Genie_Status_t status, std::string what) : std::runtime_error(what), m_status(status) {}
21
+ Genie_Status_t status() const { return m_status; }
22
+
23
+ private:
24
+ Genie_Status_t m_status;
25
+ };
26
+
27
+ } // namespace genie
Genie/Genie/src/GenieCommon.cpp ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
//=============================================================================
//
// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
// All Rights Reserved.
// Confidential and Proprietary - Qualcomm Technologies, Inc.
//
//=============================================================================

#include "GenieCommon.h"

// C API version accessors. Each simply exposes the compile-time version
// macro from GenieCommon.h so callers can check the library they linked.

// Major API version (incompatible API changes).
uint32_t Genie_getApiMajorVersion(void) { return GENIE_API_VERSION_MAJOR; }

// Minor API version (backward-compatible additions).
uint32_t Genie_getApiMinorVersion(void) { return GENIE_API_VERSION_MINOR; }

// Patch API version (backward-compatible fixes).
uint32_t Genie_getApiPatchVersion(void) { return GENIE_API_VERSION_PATCH; }
Genie/Genie/src/GenieDialog.cpp ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ //=============================================================================
2
+ //
3
+ // Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4
+ // All Rights Reserved.
5
+ // Confidential and Proprietary - Qualcomm Technologies, Inc.
6
+ //
7
+ //=============================================================================
8
+
9
+ #include "Dialog.hpp"
10
+ #include "Exception.hpp"
11
+ #include "GenieDialog.h"
12
+ #include "Macro.hpp"
13
+ #include "Util/HandleManager.hpp"
14
+ #include "qualla/detail/json.hpp"
15
+
16
+ using namespace genie;
17
+
18
+ GENIE_API
19
+ Genie_Status_t GenieDialogConfig_createFromJson(const char* str,
20
+ GenieDialogConfig_Handle_t* configHandle) {
21
+ try {
22
+ GENIE_ENSURE(str, GENIE_STATUS_ERROR_INVALID_ARGUMENT);
23
+ GENIE_ENSURE(configHandle, GENIE_STATUS_ERROR_INVALID_ARGUMENT);
24
+ auto config = std::make_shared<Dialog::Config>(str);
25
+ GENIE_ENSURE(config, GENIE_STATUS_ERROR_MEM_ALLOC);
26
+ *configHandle = genie::Dialog::Config::add(config);
27
+ } catch (const qualla::json::parse_error& e) {
28
+ std::cerr << e.what() << std::endl;
29
+ return GENIE_STATUS_ERROR_JSON_FORMAT;
30
+ } catch (const Exception& e) {
31
+ std::cerr << e.what() << std::endl;
32
+ return e.status();
33
+ } catch (const std::exception& e) {
34
+ std::cerr << e.what() << std::endl;
35
+ return GENIE_STATUS_ERROR_GENERAL;
36
+ }
37
+ return GENIE_STATUS_SUCCESS;
38
+ }
39
+
40
+ GENIE_API
41
+ Genie_Status_t GenieDialogConfig_free(const GenieDialogConfig_Handle_t configHandle) {
42
+ try {
43
+ GENIE_ENSURE(configHandle, GENIE_STATUS_ERROR_INVALID_HANDLE);
44
+ {
45
+ // Check if the dialog actually exists
46
+ auto configObj = genie::Dialog::Config::get(configHandle);
47
+ GENIE_ENSURE(configObj, GENIE_STATUS_ERROR_INVALID_HANDLE);
48
+ }
49
+ genie::Dialog::Config::remove(configHandle);
50
+ } catch (const std::exception& e) {
51
+ return GENIE_STATUS_ERROR_GENERAL;
52
+ }
53
+ return GENIE_STATUS_SUCCESS;
54
+ }
55
+
56
+ GENIE_API
57
+ Genie_Status_t GenieDialog_create(const GenieDialogConfig_Handle_t configHandle,
58
+ GenieDialog_Handle_t* dialogHandle) {
59
+ try {
60
+ GENIE_ENSURE(dialogHandle, GENIE_STATUS_ERROR_INVALID_ARGUMENT);
61
+
62
+ // Get config object
63
+ auto configObj = genie::Dialog::Config::get(configHandle);
64
+ GENIE_ENSURE(configObj, GENIE_STATUS_ERROR_INVALID_HANDLE);
65
+
66
+ // Create dialog
67
+ auto dialog = std::make_shared<genie::Dialog>(configObj);
68
+ GENIE_ENSURE(dialog, GENIE_STATUS_ERROR_MEM_ALLOC);
69
+
70
+ // Create Handle
71
+ *dialogHandle = genie::Dialog::add(dialog);
72
+ } catch (const std::exception& e) {
73
+ std::cerr << e.what() << std::endl;
74
+ return GENIE_STATUS_ERROR_GENERAL;
75
+ }
76
+
77
+ // Return SUCCESS
78
+ return GENIE_STATUS_SUCCESS;
79
+ }
80
+
81
+ GENIE_API
82
+ Genie_Status_t GenieDialog_query(const GenieDialog_Handle_t dialogHandle,
83
+ const char* queryStr,
84
+ const GenieDialog_SentenceCode_t sentenceCode,
85
+ const GenieDialog_QueryCallback_t callback,
86
+ const void* userData) {
87
+ int32_t status;
88
+
89
+ try {
90
+ GENIE_ENSURE(dialogHandle, GENIE_STATUS_ERROR_INVALID_HANDLE);
91
+ auto dialog = genie::Dialog::get(dialogHandle);
92
+ GENIE_ENSURE(dialog, GENIE_STATUS_ERROR_INVALID_HANDLE);
93
+ GENIE_ENSURE(queryStr, GENIE_STATUS_ERROR_INVALID_ARGUMENT);
94
+ GENIE_ENSURE(callback, GENIE_STATUS_ERROR_INVALID_ARGUMENT);
95
+
96
+ switch (sentenceCode) {
97
+ case GENIE_DIALOG_SENTENCE_COMPLETE:
98
+ case GENIE_DIALOG_SENTENCE_BEGIN:
99
+ case GENIE_DIALOG_SENTENCE_CONTINUE:
100
+ case GENIE_DIALOG_SENTENCE_END:
101
+ case GENIE_DIALOG_SENTENCE_ABORT:
102
+ // Do nothing
103
+ break;
104
+ default:
105
+ return GENIE_STATUS_ERROR_INVALID_ARGUMENT;
106
+ }
107
+
108
+ status = dialog->query(queryStr, sentenceCode, callback, userData);
109
+ } catch (const std::exception& e) {
110
+ std::cerr << e.what() << std::endl;
111
+ return GENIE_STATUS_ERROR_GENERAL;
112
+ }
113
+
114
+ return status;
115
+ }
116
+
117
+ GENIE_API
118
+ Genie_Status_t GenieDialog_save(const GenieDialog_Handle_t dialogHandle, const char* path) {
119
+ int32_t status;
120
+
121
+ try {
122
+ GENIE_ENSURE(dialogHandle, GENIE_STATUS_ERROR_INVALID_HANDLE);
123
+ auto dialog = genie::Dialog::get(dialogHandle);
124
+ GENIE_ENSURE(dialog, GENIE_STATUS_ERROR_INVALID_HANDLE);
125
+ GENIE_ENSURE(path, GENIE_STATUS_ERROR_INVALID_ARGUMENT);
126
+ status = dialog->save(path);
127
+ } catch (const std::exception& e) {
128
+ std::cerr << e.what() << std::endl;
129
+ return GENIE_STATUS_ERROR_GENERAL;
130
+ }
131
+
132
+ return status;
133
+ }
134
+
135
+ GENIE_API
136
+ Genie_Status_t GenieDialog_restore(const GenieDialog_Handle_t dialogHandle, const char* path) {
137
+ int32_t status;
138
+
139
+ try {
140
+ GENIE_ENSURE(dialogHandle, GENIE_STATUS_ERROR_INVALID_HANDLE);
141
+ auto dialog = genie::Dialog::get(dialogHandle);
142
+ GENIE_ENSURE(dialog, GENIE_STATUS_ERROR_INVALID_HANDLE);
143
+ GENIE_ENSURE(path, GENIE_STATUS_ERROR_INVALID_ARGUMENT);
144
+ status = dialog->restore(path);
145
+ } catch (const std::exception& e) {
146
+ std::cerr << e.what() << std::endl;
147
+ return GENIE_STATUS_ERROR_GENERAL;
148
+ }
149
+
150
+ return status;
151
+ }
152
+
153
+ GENIE_API
154
+ Genie_Status_t GenieDialog_reset(const GenieDialog_Handle_t dialogHandle) {
155
+ try {
156
+ GENIE_ENSURE(dialogHandle, GENIE_STATUS_ERROR_INVALID_HANDLE);
157
+ auto dialog = genie::Dialog::get(dialogHandle);
158
+ GENIE_ENSURE(dialog, GENIE_STATUS_ERROR_INVALID_HANDLE);
159
+ dialog->reset();
160
+ } catch (const std::exception& e) {
161
+ return GENIE_STATUS_ERROR_GENERAL;
162
+ }
163
+ return GENIE_STATUS_SUCCESS;
164
+ }
165
+
166
+ #if defined(GENIE_LORA_FEATURE)
167
+
168
+ GENIE_API
169
+ Genie_Status_t GenieDialog_applyLora(const GenieDialog_Handle_t dialogHandle,
170
+ const char* engine,
171
+ const char* loraAdapterName) {
172
+ int32_t status;
173
+ try {
174
+ GENIE_ENSURE(dialogHandle, GENIE_STATUS_ERROR_INVALID_HANDLE);
175
+ auto dialog = genie::Dialog::get(dialogHandle);
176
+ GENIE_ENSURE(dialog, GENIE_STATUS_ERROR_INVALID_HANDLE);
177
+ GENIE_ENSURE(engine, GENIE_STATUS_ERROR_INVALID_ARGUMENT);
178
+ std::string eng(engine);
179
+ GENIE_ENSURE(loraAdapterName, GENIE_STATUS_ERROR_INVALID_ARGUMENT);
180
+ std::string loraName(loraAdapterName);
181
+ status = dialog->applyLora(loraName, eng);
182
+ } catch (const std::exception& e) {
183
+ return GENIE_STATUS_ERROR_GENERAL;
184
+ }
185
+ return status;
186
+ }
187
+
188
+ GENIE_API
189
+ Genie_Status_t GenieDialog_setLoraStrength(const GenieDialog_Handle_t dialogHandle,
190
+ const char* engine,
191
+ const char* tensorName,
192
+ const float alpha) {
193
+ int32_t status;
194
+ try {
195
+ GENIE_ENSURE(dialogHandle, GENIE_STATUS_ERROR_INVALID_HANDLE);
196
+ auto dialog = genie::Dialog::get(dialogHandle);
197
+ GENIE_ENSURE(dialog, GENIE_STATUS_ERROR_INVALID_HANDLE);
198
+ GENIE_ENSURE(engine, GENIE_STATUS_ERROR_INVALID_ARGUMENT);
199
+ std::string eng(engine);
200
+ GENIE_ENSURE(tensorName, GENIE_STATUS_ERROR_INVALID_ARGUMENT);
201
+ std::string alphaTensorName(tensorName);
202
+ GENIE_ENSURE_NOT_EMPTY(alphaTensorName, GENIE_STATUS_ERROR_INVALID_ARGUMENT);
203
+ status = dialog->applyLoraStrength(tensorName, eng, alpha);
204
+ } catch (const std::exception& e) {
205
+ return GENIE_STATUS_ERROR_GENERAL;
206
+ }
207
+ return status;
208
+ }
209
+
210
+ #endif
211
+
212
+ GENIE_API
213
+ Genie_Status_t GenieDialog_tokenQuery(const GenieDialog_Handle_t dialogHandle,
214
+ const uint32_t* inputTokens,
215
+ const uint32_t numTokens,
216
+ const GenieDialog_SentenceCode_t sentenceCode,
217
+ const GenieDialog_TokenQueryCallback_t callback,
218
+ const void* userData) {
219
+ bool status;
220
+ try {
221
+ GENIE_ENSURE(dialogHandle, GENIE_STATUS_ERROR_INVALID_HANDLE);
222
+ auto dialog = genie::Dialog::get(dialogHandle);
223
+ GENIE_ENSURE(dialog, GENIE_STATUS_ERROR_INVALID_HANDLE);
224
+ GENIE_ENSURE(inputTokens, GENIE_STATUS_ERROR_INVALID_ARGUMENT);
225
+ GENIE_ENSURE(callback, GENIE_STATUS_ERROR_INVALID_ARGUMENT);
226
+ status = dialog->tokenQuery(inputTokens, numTokens, sentenceCode, callback, userData);
227
+ } catch (const std::exception& e) {
228
+ std::cerr << e.what() << std::endl;
229
+ return GENIE_STATUS_ERROR_GENERAL;
230
+ }
231
+
232
+ return status;
233
+ }
234
+
235
+ GENIE_API
236
+ Genie_Status_t GenieDialog_free(const GenieDialog_Handle_t dialogHandle) {
237
+ try {
238
+ GENIE_ENSURE(dialogHandle, GENIE_STATUS_ERROR_INVALID_HANDLE);
239
+ {
240
+ // Check if the dialog actually exists
241
+ auto dialog = genie::Dialog::get(dialogHandle);
242
+ GENIE_ENSURE(dialog, GENIE_STATUS_ERROR_INVALID_HANDLE);
243
+ }
244
+ genie::Dialog::remove(dialogHandle);
245
+ } catch (const std::exception& e) {
246
+ return GENIE_STATUS_ERROR_GENERAL;
247
+ }
248
+ return GENIE_STATUS_SUCCESS;
249
+ }
Genie/Genie/src/GenieDialogEmbedding.cpp ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ //=============================================================================
2
+ //
3
+ // Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4
+ // All Rights Reserved.
5
+ // Confidential and Proprietary - Qualcomm Technologies, Inc.
6
+ //
7
+ //=============================================================================
8
+
9
+ #include "Dialog.hpp"
10
+ #include "Exception.hpp"
11
+ #include "GenieDialog.h"
12
+ #include "Macro.hpp"
13
+ #include "Util/HandleManager.hpp"
14
+ #include "qualla/detail/json.hpp"
15
+
16
+ using namespace genie;
17
+
18
GENIE_API
// Embedding-input query entry point: validates arguments, then forwards to
// Dialog::embeddingQuery, streaming text responses through `callback`.
// NOTE(review): Dialog::embeddingQuery is declared in Dialog.hpp only under
// GENIE_E2T_FEATURE; presumably this translation unit is compiled only in
// feature-enabled builds — verify against the build configuration.
Genie_Status_t GenieDialog_embeddingQuery(const GenieDialog_Handle_t dialogHandle,
                                          const void* embeddings,
                                          const uint32_t embeddingsSize,
                                          const GenieDialog_SentenceCode_t sentenceCode,
                                          const GenieDialog_TokenToEmbeddingCallback_t t2eCallback,
                                          const GenieDialog_QueryCallback_t callback,
                                          const void* userData) {
  Genie_Status_t status;
  try {
    GENIE_ENSURE(dialogHandle, GENIE_STATUS_ERROR_INVALID_HANDLE);
    auto dialog = genie::Dialog::get(dialogHandle);
    GENIE_ENSURE(dialog, GENIE_STATUS_ERROR_INVALID_HANDLE);
    GENIE_ENSURE(embeddings, GENIE_STATUS_ERROR_INVALID_ARGUMENT);
    GENIE_ENSURE(callback, GENIE_STATUS_ERROR_INVALID_ARGUMENT);
    // t2eCallback is intentionally not null-checked; presumably optional —
    // confirm with Dialog::embeddingQuery's contract.
    status = dialog->embeddingQuery(
        embeddings, embeddingsSize, sentenceCode, t2eCallback, callback, userData);
  } catch (const std::exception& e) {
    std::cerr << e.what() << std::endl;
    return GENIE_STATUS_ERROR_GENERAL;
  }

  return status;
}
Genie/Genie/src/Macro.hpp ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
//============================================================================
//
// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
// All Rights Reserved.
// Confidential and Proprietary - Qualcomm Technologies, Inc.
//
//=============================================================================

#pragma once

//======================================================================================================================
// Error generation macros
//======================================================================================================================

// Logging sink for the GENIE_ENSURE* family. Currently compiled out — the
// ENSURE macros validate and return but emit no diagnostics.
#define GENIE_LOG_ERROR(fmt, ...)

// Returns `return_error` from the enclosing function if `value` is falsy,
// logging `msg`. Usable only in functions returning a status.
#define GENIE_ENSURE_MSG(value, return_error, msg) \
  do {                                             \
    if (!(value)) {                                \
      GENIE_LOG_ERROR(" " msg);                    \
      return return_error;                         \
    }                                              \
  } while (0)

// Returns `return_error` from the enclosing function if `value` is falsy.
#define GENIE_ENSURE(value, return_error)          \
  do {                                             \
    if (!(value)) {                                \
      GENIE_LOG_ERROR("%s was not true.", #value); \
      return return_error;                         \
    }                                              \
  } while (0)

// Returns `return_error` if `status` is not GENIE_SUCCESS.
#define GENIE_ENSURE_STATUS(status, return_error) \
  do {                                            \
    if ((status) != GENIE_SUCCESS) {              \
      return return_error;                        \
    }                                             \
  } while (0)

// Returns `return_error` if a != b.
#define GENIE_ENSURE_EQ(a, b, return_error)                   \
  do {                                                        \
    if ((a) != (b)) {                                         \
      GENIE_LOG_ERROR("%s != %s (%d != %d)", #a, #b, (a), (b)); \
    return return_error;                                      \
    }                                                         \
  } while (0)

// Returns `return_error` if the container/string `value` is empty.
// NOTE(review): `value` is expanded unparenthesized before `.empty()`; fine
// for simple identifiers, but complex expressions as arguments would need
// parentheses — confirm all call sites pass plain variables.
#define GENIE_ENSURE_NOT_EMPTY(value, return_error) \
  do {                                              \
    if (value.empty()) {                            \
      GENIE_LOG_ERROR("%s was not true.", #value);  \
      return return_error;                          \
    }                                               \
  } while (0)
//======================================================================================================================
// JSON config macros
//======================================================================================================================
// Each JSON_ENFORCE_* macro expects `item` (a nlohmann-style key/value
// iterator entry) and `component` (a std::string naming the config section)
// in the enclosing scope, and throws genie::Exception with
// GENIE_STATUS_ERROR_JSON_SCHEMA when the value has the wrong JSON type.

#define JSON_ENFORCE_OBJECT()                                                         \
  if (!item.value().is_object()) {                                                    \
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,                                   \
                    "Invalid " + component + " config: " + item.key() + " is not an object"); \
  }

#define JSON_ENFORCE_ARRAY()                                                          \
  if (!item.value().is_array()) {                                                     \
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,                                   \
                    "Invalid " + component + " config: " + item.key() + " is not an array"); \
  }

#define JSON_ENFORCE_ARRAY_OR_OBJECT()                                                \
  if (!item.value().is_array() && !item.value().is_object()) {                        \
    throw Exception(                                                                  \
        GENIE_STATUS_ERROR_JSON_SCHEMA,                                               \
        "Invalid " + component + " config: " + item.key() + " is not an array or object"); \
  }

#define JSON_ENFORCE_NUMERIC()                                                        \
  if (!item.value().is_number()) {                                                    \
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,                                   \
                    "Invalid " + component + " config: " + item.key() + " is not numeric"); \
  }

#define JSON_ENFORCE_ARRAY_OR_NUMERIC()                                               \
  if (!item.value().is_number() && !item.value().is_array()) {                        \
    throw Exception(                                                                  \
        GENIE_STATUS_ERROR_JSON_SCHEMA,                                               \
        "Invalid " + component + " config: " + item.key() + " is not an array or numeric"); \
  }

#define JSON_ENFORCE_BOOLEAN()                                                        \
  if (!item.value().is_boolean()) {                                                   \
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,                                   \
                    "Invalid " + component + " config: " + item.key() + " is not boolean"); \
  }

#define JSON_ENFORCE_STRING()                                                         \
  if (!item.value().is_string()) {                                                    \
    throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,                                   \
                    "Invalid " + component + " config: " + item.key() + " is not a string"); \
  }
Genie/Genie/src/Util/HandleGenerator.hpp ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
//==============================================================================
//
// Copyright (c) 2019-2020,2023 Qualcomm Technologies, Inc.
// All Rights Reserved.
// Confidential and Proprietary - Qualcomm Technologies, Inc.
//
//==============================================================================

#pragma once

#include <cstddef>      // std::size_t (Handle_t)
#include <cstdint>      // uint32_t / uint64_t
#include <mutex>
#include <type_traits>  // std::is_integral

namespace qnn {
namespace util {

typedef std::size_t Handle_t;

/// Obfuscates object addresses into opaque handles (and back) by byte-swapping
/// and XOR-ing with a fixed magic constant. This is obfuscation only, not
/// security: reverse() recovers the original pointer exactly.
class HandleGenerator final {
  static_assert(std::is_integral<Handle_t>::value, "Handle must be an integral type");
  static_assert((sizeof(Handle_t) == 8) || (sizeof(Handle_t) == 4),
                "Implementation of HandleGenerator::bswap() for sizeof(std::size_t) is required");

 public:
  HandleGenerator(const HandleGenerator&) = delete;
  HandleGenerator& operator=(const HandleGenerator&) = delete;
  HandleGenerator(HandleGenerator&&) = delete;
  HandleGenerator& operator=(HandleGenerator&&) = delete;

  /// Maps an object address to an opaque handle. generate(nullptr) yields
  /// invalid(), since bswap(0) ^ s_operand == s_operand.
  static Handle_t generate(const void* const addr) {
    return (bswap((Handle_t)addr) ^ (Handle_t)s_operand);
  }

  /// Inverse of generate(): recovers the original address from a handle.
  static const void* reverse(const Handle_t handle) {
    return (void*)bswap(handle ^ (Handle_t)s_operand);
  }

  /// Sentinel value no valid (non-null) address can map to.
  static constexpr Handle_t invalid() { return s_operand; }

 private:
  HandleGenerator() {}

  // Byte-reverse a 32-bit value.
  static uint32_t bswap32(const uint32_t val) {
    return (val >> 24U) | ((val >> 8U) & 0xff00U) | ((val << 8U) & 0xff0000U) | (val << 24U);
  }

  // Byte-reverse a 64-bit value by swapping each 32-bit half.
  // The `+ 0ULL` promotes to 64-bit before the shift.
  static uint64_t bswap64(const uint64_t val) {
    return ((bswap32(static_cast<uint32_t>(val)) + 0ULL) << 32U) | bswap32(val >> 32U);
  }

  // Dispatch to the correct width for Handle_t.
  template <typename T>
  static size_t bswap(T val) {
    if (sizeof(T) == 4) {
      return bswap32(static_cast<uint32_t>(val));
    } else {
      return bswap64(val);
    }
  }

  // Magic number generated via "openssl rand -hex 8"
  static constexpr Handle_t s_operand = (Handle_t)0xd4c2416534bcdc9b;
};

}  // namespace util
}  // namespace qnn
Genie/Genie/src/Util/HandleManager.hpp ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ //==============================================================================
2
+ //
3
+ // Copyright (c) 2019-2020 Qualcomm Technologies, Inc.
4
+ // All Rights Reserved.
5
+ // Confidential and Proprietary - Qualcomm Technologies, Inc.
6
+ //
7
+ //==============================================================================
8
+
9
+ #pragma once
10
+
11
+ #include <algorithm>
12
+ #include <functional>
13
+ #include <memory>
14
+ #include <mutex>
15
+ #include <unordered_map>
16
+
17
+ #include "HandleGenerator.hpp"
18
+
19
+ namespace qnn {
20
+ namespace util {
21
+
22
+ template <typename T>
23
+ class HandleManager {
24
+ public:
25
+ HandleManager() = default;
26
+ HandleManager(const HandleManager&) = delete;
27
+ HandleManager& operator=(const HandleManager&) = delete;
28
+ HandleManager(HandleManager&&) = delete;
29
+ HandleManager& operator=(HandleManager&&) = delete;
30
+
31
+ Handle_t add(std::shared_ptr<T> item) {
32
+ std::lock_guard<std::mutex> locker(m_itemsMtx);
33
+
34
+ if (!item) {
35
+ return HandleGenerator::invalid();
36
+ }
37
+
38
+ auto handle = HandleGenerator::generate(item.get());
39
+ m_items[handle] = item;
40
+ return handle;
41
+ }
42
+
43
+ Handle_t add(T* item) { return add(std::shared_ptr<T>(item)); }
44
+
45
+ Handle_t add(std::weak_ptr<T> item) { return add(item.lock()); }
46
+
47
+ std::shared_ptr<T> get(Handle_t handle) {
48
+ std::lock_guard<std::mutex> locker(m_itemsMtx);
49
+
50
+ auto it = m_items.find(handle);
51
+ if (it == m_items.end()) {
52
+ return std::shared_ptr<T>(nullptr);
53
+ }
54
+
55
+ return it->second;
56
+ }
57
+
58
+ typedef std::function<bool(const std::pair<Handle_t, std::shared_ptr<T>>&)> UnaryPredicate_t;
59
+
60
+ Handle_t findIf(UnaryPredicate_t pred) const {
61
+ auto it = std::find_if(m_items.begin(), m_items.end(), pred);
62
+ if (it == m_items.end()) {
63
+ return HandleGenerator::invalid();
64
+ }
65
+
66
+ return it->first;
67
+ }
68
+
69
+ size_t remove(Handle_t handle) {
70
+ std::lock_guard<std::mutex> locker(m_itemsMtx);
71
+ return m_items.erase(handle);
72
+ }
73
+
74
+ void clear() { m_items.clear(); }
75
+
76
+ const std::unordered_map<Handle_t, std::shared_ptr<T>>& getItems() const { return m_items; }
77
+
78
+ private:
79
+ std::unordered_map<Handle_t, std::shared_ptr<T>> m_items;
80
+ std::mutex m_itemsMtx;
81
+ };
82
+
83
+ } // namespace util
84
+ } // namespace qnn
Genie/Genie/src/qualla/context.cpp ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ //==============================================================================
2
+ //
3
+ // Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4
+ // All Rights Reserved.
5
+ // Confidential and Proprietary - Qualcomm Technologies, Inc.
6
+ //
7
+ //==============================================================================
8
+
9
+ #include <qualla/logger.hpp>
10
+ #include <qualla/context.hpp>
11
+ #include <qualla/detail/config.hpp>
12
+ #include <qualla/detail/onload.hpp>
13
+
14
+ #include <fmt/format.h>
15
+ #include <fmt/ranges.h>
16
+
17
+ namespace qualla {
18
+
19
// Builds a Context from the "context" section of the dialog JSON config.
// Reads sizing parameters and the EOS/EOT/pad token configuration, applying
// backward-compatibility rules for the deprecated "eot-token" key.
Context::Context(Env& env, const std::string& name, const qualla::json& json)
    : _name(name), _env(env), _conf(json) {
  _env.logger().debug(fmt::format("ctx-new: {} config {}", _name, _conf.dump()));

  qualla::Config conf(json, "context:");
  // "n-ctx" is an alternative spelling of "size"; when present it overrides.
  _size = conf.optional<size_t>("size", 1024);
  _size = conf.optional<size_t>("n-ctx", _size);  // alternative name
  _n_vocab = conf.optional<size_t>("n-vocab", 32000);
  _n_embd = conf.optional<size_t>("n-embd", 1024);
  _embedding_length = conf.optional<int32_t>("embedding-length", -1);
  _embedding_datatype = conf.optional<std::string>("embedding-datatype", "float32");
  // For backward compatibility. When eot-token is removed, this logic can be simplified
  // Currently, EOT is marked as default truncating token if available
  int32_t eot_tok = conf.optional<int32_t>("eot-token", -1);
  if (eot_tok >= 0) _eos_tok_list.insert(eot_tok);

  // "eos-token" may be a single integer or an array of token ids.
  // NOTE(review): the default here is the member _eos_tok as initialized in
  // the class definition (not visible in this file) — confirm its default.
  const qualla::json eos_conf = conf.optional<qualla::json>("eos-token", _eos_tok);
  if (eos_conf.is_array() && eos_conf.size() > 0) {
    // Array form: first entry becomes the primary EOS; all entries are
    // accepted as stop tokens.
    const std::vector<int32_t>& eos_tokens = eos_conf.get<std::vector<int32_t>>();
    _eos_tok = eos_tokens[0];
    for (const int32_t& eos_tok : eos_tokens)
      _eos_tok_list.insert(eos_tok);
  } else if (eos_conf.is_number_integer()) {
    // Scalar form: a configured EOT token (legacy key) takes precedence as
    // the primary EOS, but the scalar still joins the stop list.
    int32_t eos_tok = eos_conf.get<int32_t>();
    _eos_tok = (eot_tok >= 0) ? eot_tok : eos_tok;
    _eos_tok_list.insert(eos_tok);
  }

  // Padding token defaults to the resolved EOS token.
  _pad_tok = conf.optional<qualla::json>("pad-token", _eos_tok);
}
49
+
50
+ std::unique_ptr<Context> Context::create(
51
+ Env& env,
52
+ const std::string& name,
53
+ const qualla::json& conf
54
+ ) {
55
+ return std::make_unique<Context>(env, name, conf);
56
+ }
57
+
58
+ std::unique_ptr<Context> Context::create(
59
+ Env& env,
60
+ const std::string& name,
61
+ std::istream& json_stream
62
+ ) {
63
+ return create(env, name, json::parse(json_stream));
64
+ }
65
+
66
+ std::unique_ptr<Context> Context::create(
67
+ Env& env,
68
+ const std::string& name,
69
+ const std::string& json_str
70
+ ) {
71
+ return create(env, name, json::parse(json_str));
72
+ }
73
+
74
#ifdef QUALLA_STATIC

// This is a hack to make sure all core bits are linked in for the static build
// Each need*() function is defined in the corresponding component's
// translation unit; referencing them here forces the linker to keep those
// objects (and their self-registration side effects) in a static binary.

extern void needFileLogger();
extern void needStdoutLogger();
extern void needBasicSampler();
extern void needBasicDialog();
extern void needKvShareDialog();
extern void needSpdDialog();
extern void needSsdDialog();
extern void needLadeDialog();
extern void needMultistreamDialog();

#ifdef QUALLA_ENGINE_QNN_HTP
extern void needQnnHtpEngine();
#endif

#ifdef QUALLA_ENGINE_QNN_CPU
extern void needQnnCpuEngine();
#endif

// Runs at load time (static-init) and touches every component hook above.
static OnLoad needs([]() {
  needStdoutLogger();
  needFileLogger();
  needBasicDialog();
  needBasicSampler();
  needKvShareDialog();
  needSpdDialog();
  needSsdDialog();
  needLadeDialog();
  needMultistreamDialog();

#ifdef QUALLA_ENGINE_QNN_HTP
  needQnnHtpEngine();
#endif

#ifdef QUALLA_ENGINE_QNN_CPU
  needQnnCpuEngine();
#endif
});

#endif
117
+
118
+ } // namespace qualla
Genie/Genie/src/qualla/dialog.cpp ADDED
@@ -0,0 +1,590 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ //==============================================================================
2
+ //
3
+ // Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4
+ // All rights reserved.
5
+ // Confidential and Proprietary - Qualcomm Technologies, Inc.
6
+ //
7
+ //==============================================================================
8
+
9
+ #include <qualla/dialog.hpp>
10
+ #include <qualla/logger.hpp>
11
+ #include <qualla/detail/config.hpp>
12
+ #include <qualla/detail/timer.hpp>
13
+ #include <qualla/detail/sampler-utils.hpp>
14
+
15
+ #include <algorithm>
16
+ #include <functional>
17
+ #include <fstream>
18
+ #include <string>
19
+ #include <unordered_map>
20
+ #include <filesystem>
21
+ #include <iostream>
22
+
23
+ #include <fmt/format.h>
24
+ #include <fmt/ranges.h>
25
+
26
// Logging shorthands for this translation unit. INFO/WARN/ERROR format
// eagerly; KPIS/DEBUG/TRACE wrap the fmt::format call in a lambda so the
// string is only built if that log level is actually enabled.
#define __INFO(__fmt, ...) _env->logger().post(Logger::INFO, fmt::format(__fmt, ##__VA_ARGS__))
#define __WARN(__fmt, ...) _env->logger().post(Logger::WARN, fmt::format(__fmt, ##__VA_ARGS__))
#define __ERROR(__fmt, ...) _env->logger().post(Logger::ERROR, fmt::format(__fmt, ##__VA_ARGS__))
#define __KPIS(__fmt, ...) \
  _env->logger().post(Logger::KPIS, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
#define __DEBUG(__fmt, ...) \
  _env->logger().post(Logger::DEBUG, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
#define __TRACE(__fmt, ...) \
  _env->logger().post(Logger::TRACE, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })

namespace fs = std::filesystem;
37
+
38
+ namespace qualla {
39
+
40
+ Dialog::Dialog(std::shared_ptr<Env> env, const std::string& name, const qualla::json& json)
41
+ : _env(env) {
42
+ Timer start;
43
+
44
+
45
+
46
+ __DEBUG("dialog-new: {} config {}", name, json.dump());
47
+
48
+ using qc = qualla::Config;
49
+
50
+ // Create Gpiomarker and reset the gpio status to low
51
+ const qualla::json& gpio_conf = qc::optional<qualla::json>(json, "gpio", {});
52
+ _gpio_marker = GpioMarker::create(gpio_conf);
53
+
54
+ _gpio_marker->set();
55
+
56
+ // Create the context first
57
+ _ctx = Context::create(*_env, name, qc::mandatory<qualla::json>(json, "context"));
58
+
59
+ // Parse prompt config
60
+ const qualla::json& pmt_conf = qc::optional<qualla::json>(json, "prompt", {});
61
+ _prompt_type = qc::optional<std::string>(pmt_conf, "type", "llama2");
62
+ _sys_tags = qc::optional<std::vector<std::string>>(pmt_conf, "sys-tags", {"", ""});
63
+ _inst_tags = qc::optional<std::vector<std::string>>(pmt_conf, "inst-tags", {"", ""});
64
+ _role_tags = qc::optional<std::vector<std::string>>(pmt_conf, "role-tags", {"", ""});
65
+ _sys_prompt = qc::optional<std::string>(pmt_conf, "sys-prompt", "");
66
+
67
+ const std::vector<std::string>& stop_sequence =
68
+ qc::optional<std::vector<std::string>>(pmt_conf, "stop-sequence", {});
69
+ _stop_sequence = SequenceMatchTrie(stop_sequence);
70
+
71
+ // Create Tokenizer
72
+ // TODO: auto-detect / validate n_vocab with tokenizer vocab
73
+ fs::path tok_path = _env->path().models / qc::mandatory<std::string>(json, "tokenizer");
74
+ _tokenizer = Tokenizer::create(*_ctx, tok_path);
75
+
76
+ // Create Sampler(s)
77
+ auto add_sampler = [&](const qualla::json& j) {
78
+ std::string role = qc::optional<std::string>(j, "role", "primary");
79
+ _sampler[role] = Sampler::create(*_ctx, j);
80
+ };
81
+
82
+ const qualla::json& sam_conf = qc::mandatory<qualla::json>(json, "sampler");
83
+ if (sam_conf.is_array()) {
84
+ for (auto sc : sam_conf) {
85
+ add_sampler(sc);
86
+ }
87
+ } else
88
+ add_sampler(sam_conf);
89
+
90
+
91
+
92
+
93
+ // Create Engine(s)
94
+ auto add_engine = [&](const qualla::json& j) {
95
+ std::string role = qc::optional<std::string>(j, "role", "primary");
96
+
97
+ _engine[role] = Engine::create(*_ctx, j);
98
+
99
+ using FF = Engine::Feature::Flags;
100
+
101
+
102
+ if (!_engine[role]->supports(FF::OUTPUT_LOGITS))
103
+ throw std::runtime_error("the engine must output Logits");
104
+ };
105
+
106
+
107
+
108
+ const qualla::json& eng_conf = qc::mandatory<qualla::json>(json, "engine");
109
+
110
+
111
+ if (eng_conf.is_array()) {
112
+
113
+ for (auto ec : eng_conf) {
114
+ add_engine(ec);
115
+ }
116
+ } else{
117
+ add_engine(eng_conf);
118
+
119
+ }
120
+
121
+ // Store input type (token, embedding, etc) from the engine.
122
+ // This assumes multi-engine usecases use matching input types.
123
+ m_inputType = _engine.begin()->second->getInputType();
124
+
125
+ _kpis.init.update(start.elapsed_usec());
126
+ }
127
+
128
+ Dialog::~Dialog() {}
129
+
130
+ static bool __no_response_query(const std::string&, Sentence::Code) {
131
+ return false;
132
+ }
133
+
134
+ static bool __no_response_token(const int32_t*, const uint32_t, Sentence::Code) {
135
+ return false;
136
+ }
137
+
138
+ static bool __no_response(const std::string&, Sentence::Code) {
139
+ return false;
140
+ }
141
+
142
+ void Dialog::getTopK(std::vector<float>& logits, std::vector<std::vector<int32_t>>& tokens, size_t topK, float pThreshold, Dialog::Callback callback) {
143
+
144
+ auto& sampler = *_sampler["primary"];
145
+
146
+ // Sample top-k logits but with a minimum probability threshold
147
+ #if defined(__GNUC__) && !defined(__clang__)
148
+ std::span<float> indexed_logits_span(logits);
149
+ IndexedLogits indexed_logits(indexed_logits_span, sampler.rng());
150
+ #else
151
+ IndexedLogits indexed_logits(std::span{logits.data(),logits.size()}, sampler.rng());
152
+ #endif
153
+ indexed_logits.softmax();
154
+ indexed_logits.topK(topK);
155
+
156
+ for (int i = 0; i < topK; i++) {
157
+
158
+ _last_tok = indexed_logits.indices[i];
159
+
160
+ // Only sample tokens above some probability threshold
161
+ // TODO: Modify sampling algorithm as necessary
162
+ if (indexed_logits.probs[i] < pThreshold) {
163
+ break;
164
+ } else if (_ctx->is_eos(_last_tok)) {
165
+ callback("", Sentence::CONTINUE);
166
+ } else {
167
+ tokens.push_back({_last_tok});
168
+ }
169
+ }
170
+ }
171
+
172
// Run one text query turn.
//
// Builds the prompt for this turn — carrying over the last sampled token,
// prepending the system prompt on the first query, and wrapping the text in
// the configured prompt-format tags — then encodes it and hands the tokens
// to process().
//
// str      : user text (may be a fragment when streaming a long prompt).
// scode    : sentence framing. COMPLETE is a whole turn; BEGIN/CONTINUE/END
//            feed the prompt in pieces. Generation only starts on
//            COMPLETE or END.
// callback : receives decoded output; returning false stops generation.
// Returns the result of process(); false indicates a processing failure.
bool Dialog::query(const std::string& str, Sentence::Code scode, Dialog::Callback callback) {
    std::vector<int32_t> p_vec; // prompt tokens
    std::string p_str;          // prompt string

    p_vec.reserve(1024);

    if (scode == Sentence::COMPLETE || scode == Sentence::BEGIN) {
        // Reset prompt/gen counts for new query
        _n_prompt = 0;
        _n_generated = 0;
        _n_previous_prompt = 0;
        _n_previous_generated = 0;

        // Carry the previously sampled token into this prompt (it was never
        // fed back to the engine) — presumably to keep the KV cache in sync
        // with what was actually generated; confirm against engine contract.
        if (_last_tok >= 0 && !_ctx->is_eos(_last_tok)) p_vec.push_back(_last_tok);

        p_str = _inst_tags[0];

        if (!_n_queries) {
            // First query. Prepend sys-prompt.
            p_str += _sys_tags[0] + _sys_prompt + _sys_tags[1];
        } else {
            // Add EOS explicitly if the last query was aborted prematurely.
            if (_ctx->eos_tok() >= 0) p_vec.push_back(_ctx->eos_tok());
        }

        // Add BOS
        if (_ctx->bos_tok() >= 0) {
            p_vec.push_back(_ctx->bos_tok());
        }
    }

    // FIXME: make this more generic
    if (_prompt_type == "llama3") {
        p_str += _sys_tags[0] + _role_tags[1] + _sys_tags[1] + str + _inst_tags[2];
    } else {
        p_str += str;
    }

    // Close the turn with the trailing tags on the final (or only) fragment.
    if (scode == Sentence::COMPLETE || scode == Sentence::END) {
        if (_prompt_type == "llama3") {
            p_str += _sys_tags[0] + _role_tags[2] + _sys_tags[1];
        } else {
            p_str += _inst_tags[1];
        }
    }

    _env->logger().post(Logger::DEBUG, [&]() {
        qualla::json j{{"string", str}, {"prompt", p_str}};
        return fmt::format("dialog-query: {} {}", _ctx->name(), j.dump());
    });

    _n_queries++;

    // Encoded prompt text is appended after the control tokens added above.
    _tokenizer->encode(p_str, p_vec);

    __DEBUG("dialog-tokens: {} {}", _ctx->name(), p_vec);
    __DEBUG("dialog-text: \"{}\"", p_str);

    if (scode == Sentence::COMPLETE || scode == Sentence::END) {
        // Detect stop sequences here
        if (!_stop_sequence.empty()) {
            _stop_sequence.reset();
            return process(p_vec, [&](const std::string& str, Sentence::Code c) {
                // Check for stop sequence and end inference when stop sequence is found
                if (_stop_sequence.process_next_string(str)) {
                    callback(str, c); // Emit sequences until match is complete
                    return false;
                }

                // Else, return normal callback function
                return callback(str, c);
            });
        }

        return process(p_vec, callback);
    }

    // Mid-prompt fragment: ingest tokens without emitting any output.
    return process(p_vec, __no_response);
}
252
+
253
// Run one pre-tokenized query turn (caller supplies encoded tokens).
//
// On COMPLETE/BEGIN the per-query counters are reset and control tokens are
// prepended: the carried-over last sampled token, an explicit EOS when the
// previous turn ended without one, and BOS. Generation happens only on
// COMPLETE/END; otherwise the tokens are ingested via a no-output callback.
//
// NOTE(review): asymmetries vs the string overload — _last_tok is pushed even
// when it is EOS, and the EOS push is not guarded by eos_tok() >= 0 (a model
// without EOS would push -1). Confirm whether these differences are intended.
bool Dialog::query(const std::vector<uint32_t>& input, Sentence::Code scode, qualla::DialogCallback& callback) {
    std::vector<int32_t> p_vec; // prompt tokens
    p_vec.reserve(1024);

    if (scode == Sentence::COMPLETE || scode == Sentence::BEGIN) {
        // Reset prompt/gen counts for new query
        _n_prompt = 0;
        _n_generated = 0;
        _n_previous_prompt = 0;
        _n_previous_generated = 0;

        // Carry the previously sampled token into this prompt
        if (_last_tok >= 0)
            p_vec.push_back(_last_tok);

        // Add EOS explicitly if the last query was aborted prematurely.
        if (_n_queries && _last_tok != _ctx->eos_tok()) {
            p_vec.push_back(_ctx->eos_tok());
        }
        // Add BOS
        if (_ctx->bos_tok() >= 0) {
            p_vec.push_back(_ctx->bos_tok());
        }
    }

    // Append the caller-provided tokens after the control tokens.
    p_vec.insert(p_vec.end(), input.begin(), input.end());
    __DEBUG("dialog-tokens: {} {}", _ctx->name(), p_vec);

    _n_queries++;

    if (scode == Sentence::COMPLETE || scode == Sentence::END) {
        return process(p_vec, callback);
    }

    // Mid-prompt fragment: ingest tokens, emit nothing.
    DialogCallback callback_return_token(QUALLA_CALLBACK_TYPE_TOKEN);
    *(callback_return_token.getTokenCbFunc()) = __no_response_token;
    return process(p_vec, callback_return_token);
}
290
+
291
+ bool Dialog::query(
292
+ std::vector<uint8_t>& embedding_vectors,
293
+ Sentence::Code scode,
294
+ T2ECallback t2eCallback,
295
+ Dialog::Callback callback
296
+ ) {
297
+ _n_queries++;
298
+ if (scode == Sentence::COMPLETE || scode == Sentence::END) {
299
+ return process(embedding_vectors, t2eCallback, callback);
300
+ }
301
+ // Only process, no output
302
+ return process(embedding_vectors, t2eCallback, [&](const std::string&, Sentence::Code) {
303
+ return false;
304
+ });
305
+ }
306
+
307
+ bool Dialog::prime(const std::string& str) {
308
+ bool r = query(str, Sentence::COMPLETE, __no_response);
309
+
310
+ // End with EOS as we want the primer to be self-contained
311
+ _last_tok = _ctx->eos_tok();
312
+
313
+ return r;
314
+ }
315
+
316
// Persist the dialog under `o_name` (or under the context name when empty):
// internal counters as dialog.json plus each engine's and sampler's state.
// Engine saves are mandatory (failure aborts); sampler saves are best-effort.
// Returns false when there is nothing to save yet, the directory cannot be
// created, or an engine fails to save.
bool Dialog::save(const std::string& o_name) {
    Timer start;

    // Save using session name unless override is provided
    std::string name = o_name.empty() ? _ctx->name() : o_name;
    fs::path save_path = name;

    if (!_n_past) {
        __ERROR("dialog-save: {} : nothing to save yet", name);
        return false;
    }

    __INFO("dialog-save: saving as {} {}", name, save_path.string());

    if (!fs::exists(save_path) && !fs::create_directories(save_path)) {
        __ERROR("dialog-save: {} : failed to create cache directory", name);
        return false;
    }

    // Save Dialog state
    qualla::json j{
        {"n-past", _n_past},
        {"n-prompt", _n_prompt},
        {"n-generated", _n_generated},
        {"n-queries", _n_queries},
        {"last-tok", _last_tok}
    };
    {
        // Scoped so the file is flushed/closed before engines save.
        fs::path p = save_path / "dialog.json";
        std::ofstream f(p);
        f << j;
    }

    // Save Engines (mandatory)
    for (auto& e : _engine) {
        if (!e.second->save(name)) {
            __ERROR("dialog-save: {} : unable to save {} engine", name, e.first);
            return false;
        }
    }

    // Save Samplers (optional)
    for (auto& s : _sampler) {
        if (!s.second->save(name)) {
            __WARN("dialog-save: {} : unable to save {} sampler", name, s.first);
        }
    }

    _kpis.save.update(start.elapsed_usec());

    return true;
}
368
+
369
// Restore a previously saved dialog from `o_name` (or from the context name
// when empty). The dialog.json counters are optional — if missing, defaults
// are used and the engine state is restored anyway. Engine restores are
// mandatory; sampler restores are best-effort. On an n-past mismatch between
// the saved counters and an engine, the smaller value wins (conservative:
// never claim more cached context than the engine actually has).
// Returns false if any engine fails to restore.
bool Dialog::restore(const std::string& o_name) {
    Timer start;

    // Restore using session name unless override is provided
    std::string name = o_name.empty() ? _ctx->name() : o_name;
    fs::path restore_path = name;

    __INFO("dialog-restore: restoring from {} {}", name, restore_path.string());

    // Try to restore the Dialog state (optional)
    // If this fails we reset everything and try to restore the engine.
    qualla::json j{};
    {
        fs::path p = restore_path / "dialog.json";
        if (fs::exists(p)) {
            std::ifstream f(p);
            j = qualla::json::parse(f);
        } else {
            __DEBUG("dialog-restore: {} : internal state not restored", name);
        }
    }

    using qc = qualla::Config;
    _n_past = qc::optional<uint32_t>(j, "n-past", 0);
    _n_prompt = qc::optional<uint32_t>(j, "n-prompt", 0);
    _n_generated = qc::optional<uint32_t>(j, "n-generated", 0);
    _n_queries = qc::optional<uint32_t>(j, "n-queries", 1);
    _last_tok = qc::optional<int32_t>(j, "last-tok", _ctx->eos_tok());

    // Restore Engines (mandatory)
    for (auto& e : _engine) {
        // Engine reports how many positions of KV cache it restored.
        uint32_t n = e.second->restore(name);
        if (!n) {
            __ERROR("dialog-restore: {} : unable to restore {} engine", name, e.first);
            return false;
        }

        // Restore n_past from the engine state
        if (_n_past && n != _n_past) {
            __WARN("dialog-restore: {} : n-past mismatch : {} engine {} intern {}",
                   name,
                   e.first,
                   _n_past,
                   n);
            // Keep the smaller number
            _n_past = std::min(n, _n_past);
        } else
            _n_past = n;
    }

    // Restore Samplers (optional)
    for (auto& s : _sampler) {
        if (!s.second->restore(name)) {
            __WARN("dialog-restore: {} : unable to restore {} sampler", name, s.first);
        }
    }

    _kpis.reset();
    _kpis.restore.update(start.elapsed_usec());

    return true;
}
431
+
432
+ void Dialog::reset() {
433
+ __INFO("dialog-reset: {}", _ctx->name());
434
+
435
+ _n_past = 0;
436
+ _n_prompt = 0;
437
+ _n_generated = 0;
438
+ _n_queries = 0;
439
+ _last_tok = -1;
440
+ _n_previous_prompt = 0;
441
+ _n_previous_generated = 0;
442
+
443
+ _kpis.reset();
444
+
445
+ // Reset Engines and Samplers
446
+ for (auto& e : _engine)
447
+ e.second->reset();
448
+ for (auto& s : _sampler)
449
+ s.second->reset();
450
+
451
+ State::clear();
452
+ }
453
+
454
+ // Dialog KPIs helpers
455
+
456
+ // Get latest KPIs
457
+ Dialog::KPIs& Dialog::kpis() {
458
+ // Update TPS
459
+ if (_n_prompt) {
460
+ float t = _kpis.prompt.last_usec / _n_prompt;
461
+ _kpis.tps.n_prompt = _n_prompt;
462
+ _kpis.tps.prompt = 1000000.0 / (t ? t : 1000000.0);
463
+ }
464
+
465
+ if (_n_generated) {
466
+ float t = _kpis.generate.last_usec / _n_generated;
467
+ _kpis.tps.n_generate = _n_generated;
468
+ _kpis.tps.generate = 1000000.0 / (t ? t : 1000000.0);
469
+ }
470
+
471
+ // We could synthesize more KPIs from from other layers (engine, sampler, etc)
472
+ return _kpis;
473
+ }
474
+
475
// Render all KPI sections as one line: per-stage timing blocks separated by
// `sep`, followed by derived tokens-per-second figures. NOTE(review): KPI
// lines are emitted to the logger and may be parsed downstream — keep the
// format string stable.
std::string Dialog::KPIs::dump(std::string_view sep) const {
    return fmt::format(
        "init:[{}]{}prompt:[{}]{}generate:[{}]{}save:[{}]{}restore:[{}]{} tps-prompt:{:.2f} tps-generate:{:.2f}",
        init.dump(),
        sep,
        prompt.dump(),
        sep,
        generate.dump(),
        sep,
        save.dump(),
        sep,
        restore.dump(),
        sep,
        tps.prompt,
        tps.generate
    );
}
492
+
493
+ void Dialog::KPIs::reset() {
494
+ init.reset();
495
+ prompt.reset();
496
+ generate.reset();
497
+ save.reset();
498
+ restore.reset();
499
+ tps.prompt = 0.0;
500
+ tps.generate = 0.0;
501
+ }
502
+
503
+ // Create API
504
+
505
+ // Dialog registry : type string + creator function
506
+ using Registry = std::unordered_map<std::string, Dialog::Creator>;
507
+ static std::unique_ptr<Registry> registry;
508
+
509
+ void Dialog::__register(const std::string& type, Creator func) {
510
+ if (!registry) registry = std::make_unique<Registry>();
511
+
512
+ Registry& r = *registry;
513
+
514
+
515
+ r[type] = func;
516
+ }
517
+
518
+ std::unique_ptr<Dialog> Dialog::create(
519
+ std::shared_ptr<Env> env,
520
+ const std::string& name,
521
+ const qualla::json& conf
522
+ ) {
523
+
524
+ using qc = qualla::Config;
525
+ std::string type = qc::optional<std::string>(conf, "type", "basic");
526
+
527
+ if (!registry) throw std::runtime_error(type + ": dialog not found");
528
+
529
+ Registry& r = *registry;
530
+
531
+ if (!r.contains(type)) throw std::runtime_error(type + ": dialog not found");
532
+
533
+ if (!r.contains(type)) {
534
+ throw std::runtime_error(type + ": dialog not found");
535
+ }
536
+
537
+ return std::unique_ptr<Dialog>(r[type](env, name, conf));
538
+ }
539
+
540
+ std::unique_ptr<Dialog> Dialog::create(
541
+ std::shared_ptr<Env> env,
542
+ const std::string& name,
543
+ std::istream& json_stream
544
+ ) {
545
+
546
+ return create(env, name, json::parse(json_stream));
547
+ }
548
+
549
+ std::unique_ptr<Dialog> Dialog::create(
550
+ std::shared_ptr<Env> env,
551
+ const std::string& name,
552
+ const fs::path& json_path
553
+ ) {
554
+
555
+ if (!fs::exists(json_path))
556
+ throw std::runtime_error(json_path.string() + ": file does not exist");
557
+ std::ifstream ifs(json_path);
558
+ return create(env, name, ifs);
559
+ }
560
+
561
+ std::vector<std::string> Dialog::list() {
562
+ std::vector<std::string> v;
563
+ if (!registry) return v;
564
+
565
+ Registry& r = *registry;
566
+
567
+ for (auto k : r)
568
+ v.push_back(k.first);
569
+ v.push_back("basic"); // default type, always registered
570
+ return v;
571
+ }
572
+
573
+ bool Dialog::applyLoraAdapter(std::string lora_adapter_name, std::string engine_role) {
574
+ auto& engine = *_engine[engine_role];
575
+ if (!engine.applyLoraAdapter(lora_adapter_name)) {
576
+ __WARN("dialog-applyLoraAdapter: failed for {}", lora_adapter_name);
577
+ return false;
578
+ }
579
+ return true;
580
+ }
581
+ bool Dialog::applyLoraStrength(std::string tensor_name, float tensor_val, std::string engine_role) {
582
+ auto& engine = *_engine[engine_role];
583
+ if (!engine.applyLoraStrength(tensor_name, tensor_val)) {
584
+ __WARN("dialog-applyLoraStrength: failed for {}", tensor_name);
585
+ return false;
586
+ }
587
+ return true;
588
+ }
589
+
590
+ } // namespace qualla
Genie/Genie/src/qualla/dialogs/basic.cpp ADDED
@@ -0,0 +1,421 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ //==============================================================================
2
+ //
3
+ // Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4
+ // All rights reserved.
5
+ // Confidential and Proprietary - Qualcomm Technologies, Inc.
6
+ //
7
+ //==============================================================================
8
+
9
+ #include <qualla/dialog.hpp>
10
+ #include <qualla/logger.hpp>
11
+ #include <qualla/detail/config.hpp>
12
+ #include <qualla/detail/timer.hpp>
13
+ #include <qualla/detail/onload.hpp>
14
+ #include <qualla/detail/basic-dialog.hpp>
15
+
16
+ #include <functional>
17
+ #include <filesystem>
18
+ #include <string>
19
+
20
+ #include <fmt/format.h>
21
+ #include <fmt/ranges.h>
22
+
23
+ namespace fs = std::filesystem;
24
+
25
+ #define __INFO(__fmt, ...) _env->logger().post(Logger::INFO, fmt::format(__fmt, ##__VA_ARGS__))
26
+ #define __WARN(__fmt, ...) _env->logger().post(Logger::WARN, fmt::format(__fmt, ##__VA_ARGS__))
27
+ #define __ERROR(__fmt, ...) _env->logger().post(Logger::ERROR, fmt::format(__fmt, ##__VA_ARGS__))
28
+ #define __KPIS(__fmt, ...) \
29
+ _env->logger().post(Logger::KPIS, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
30
+ #define __DEBUG(__fmt, ...) \
31
+ _env->logger().post(Logger::DEBUG, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
32
+ #define __TRACE(__fmt, ...) \
33
+ _env->logger().post(Logger::TRACE, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
34
+
35
+ namespace qualla {
36
+
37
// Construct a BasicDialog. The base Dialog ctor builds the context,
// tokenizer, sampler(s) and engine(s); this subclass additionally requires
// an engine registered under the "primary" role. A missing primary engine is
// recorded as a fatal State (checked via State::failed() in process())
// rather than thrown.
BasicDialog::BasicDialog(std::shared_ptr<Env> env, const std::string& name, const json& conf) : Dialog(env, name, conf) {
    if (!_engine.contains("primary")) {
        State::fatal("\"primary\" engine not present in config!");
        return;
    }
}
43
+
44
// Autoregressive generation loop (string-callback variant), shared by the
// token and embedding input paths. Loops: feed the last sampled token (or
// its embedding via m_t2eCallback), sample the next token, advance the KV
// cache — until EOS, cancellation, context exhaustion, or the callback
// returns false.
//
// tokens : 1-slot buffer holding the token sampled from the prompt pass;
//          reused to carry each newly sampled token.
// logits : engine output buffer, overwritten every step.
// Returns false only on an engine/processing failure (via Dialog::abort).
bool BasicDialog::processFollowOnGeneration(std::vector<int32_t>& tokens, std::vector<float>& logits, Dialog::Callback callback){

    auto& sampler = *_sampler["primary"];
    auto& engine = *_engine["primary"];

    while (true) {
        if (State::canceled()) {
            callback("", Sentence::END);
            break;
        }
        // This condition is valid for both tokens and embedding
        if (_n_past + 1 > _ctx->size()) {
            __WARN("Context limit exceeded ({} + 1 > {})", _n_past, _ctx->size());
            callback("", Sentence::END);
            break;
        }
        if (m_inputType == InputType::TOKENS) {
            if (!engine.process(tokens, logits))
                return Dialog::abort("engine processing failed", callback);
        } else if(m_inputType == InputType::EMBEDDINGS) {
            // Convert tokens to embedding for the processing in the engine.
            auto embedBufSize = engine.getEmbeddingBufferSize();
            std::vector<uint8_t> embedding;
            for(auto &token: tokens){
                std::vector<uint8_t> curTokenEmbedding(embedBufSize,0);
                m_t2eCallback(token, curTokenEmbedding.data(), embedBufSize);
                embedding.insert(embedding.end(), curTokenEmbedding.begin(), curTokenEmbedding.end());
            }
            if (!engine.process(embedding, {}, logits))
                return Dialog::abort("engine processing failed", callback);
        }
        else{
            return Dialog::abort("No valid Input Type is used", callback);
        }
        // Sample the next token; _last_tok also carries across queries.
        tokens[0] = _last_tok = sampler.process(logits);

        _n_past++;
        _n_generated++;

        if (!engine.updateKV(_n_past)) return Dialog::abort("context size exceeded", callback);

        if (_ctx->is_eos(_last_tok)) {
            callback("", Sentence::END);
            break;
        }

        // Stream the decoded token; a false return means the caller wants
        // generation to stop early.
        if (!callback(_tokenizer->decode(tokens), Sentence::CONTINUE)) break;
    }

    return true;
}
95
+
96
// Process a full token prompt, sample the first output token (emitted with
// Sentence::BEGIN), then hand off to processFollowOnGeneration() for the
// rest. Context-limit overflow ends the turn gracefully (callback END,
// return true); engine failures abort and return false.
bool BasicDialog::process(std::vector<int32_t>& tokens, Dialog::Callback callback) {
    // Check for prev failures and bail out early
    if (State::failed()) return false;

    Timer start;

    if(m_inputType != InputType::TOKENS) {
        __ERROR("Input type for model is not tokens.");
        return false;
    }

    _gpio_marker->set();

    // Vector for storing logits.
    // Allocated & filled by the engine.
    std::vector<float> logits;

    State::clear();

    auto& sampler = *_sampler["primary"];
    auto& engine = *_engine["primary"];

    using FF = Engine::Feature::Flags;
    if (engine.supports(FF::DYNAMIC_LOAD)) engine.load();

    if (_n_past + tokens.size() > _ctx->size()) {
        __WARN("Context limit exceeded ({} + {} > {})", _n_past, tokens.size(), _ctx->size());
        callback("", Sentence::END);
        return true;
    }

    if (!engine.process(tokens, logits, false))
        return Dialog::abort("engine prompt processing failed", callback);

    _n_prompt += tokens.size();
    _n_past += tokens.size();

    if (!engine.updateKV(_n_past)) return Dialog::abort("context size exceeded", callback);

    // Sample the first generated token; shrink `tokens` into the 1-slot
    // buffer reused by the generation loop.
    tokens[0] = _last_tok = sampler.process(logits);
    tokens.resize(1);

    _n_generated++;

    _gpio_marker->set();

    _kpis.prompt.update(start.elapsed_usec());

    // Log latest KPIs
    _env->logger().post(Logger::KPIS, kpis().dump(" "));

    start.reset();

    if (_ctx->is_eos(_last_tok)) {
        callback("", Sentence::END);
        return true;
    }

    if (!callback(_tokenizer->decode(tokens), Sentence::BEGIN)) return true;

    State::busy(true);

    processFollowOnGeneration(tokens, logits, callback);

    State::busy(false);

    _gpio_marker->set();
    _gpio_marker->reset();

    _kpis.generate.update(start.elapsed_usec());

    // Log latest KPIs in a single line
    _env->logger().post(Logger::KPIS, kpis().dump(" "));

    return !State::failed();
}
172
+
173
// Autoregressive generation loop, DialogCallback variant: raw token ids are
// delivered via callback.callBack(tokens, count, code, tokenizer) instead of
// decoded text; END is signaled with a null/zero-length token buffer.
// Otherwise mirrors the string-callback overload above.
// NOTE(review): `callback` is taken by value here (by reference elsewhere) —
// signature must match the header; confirm the copy is intended.
bool BasicDialog::processFollowOnGeneration(std::vector<int32_t>& tokens, std::vector<float>& logits, qualla::DialogCallback callback){

    auto& sampler = *_sampler["primary"];
    auto& engine = *_engine["primary"];

    while (true) {
        if (State::canceled()) {
            callback.callBack(nullptr, 0, Sentence::END, tokenizer());
            break;
        }
        // This condition is valid for both tokens and embedding
        if (_n_past + 1 > _ctx->size()) {
            __WARN("Context limit exceeded ({} + 1 > {})", _n_past, _ctx->size());
            callback.callBack(nullptr, 0, Sentence::END, tokenizer());
            break;
        }
        if (m_inputType == InputType::TOKENS) {
            if (!engine.process(tokens, logits))
                return Dialog::abort("engine processing failed", callback);
        } else if(m_inputType == InputType::EMBEDDINGS) {
            // Convert tokens to embedding for the processing in the engine.
            auto embedBufSize = engine.getEmbeddingBufferSize();
            std::vector<uint8_t> embedding;
            for(auto &token: tokens){
                std::vector<uint8_t> curTokenEmbedding(embedBufSize,0);
                m_t2eCallback(token, curTokenEmbedding.data(), embedBufSize);
                embedding.insert(embedding.end(), curTokenEmbedding.begin(), curTokenEmbedding.end());
            }
            if (!engine.process(embedding, {}, logits))
                return Dialog::abort("engine processing failed", callback);
        }
        else{
            return Dialog::abort("No valid Input Type is used", callback);
        }
        // Sample the next token; _last_tok also carries across queries.
        tokens[0] = _last_tok = sampler.process(logits);

        _n_past++;
        _n_generated++;

        if (!engine.updateKV(_n_past)) return Dialog::abort("context size exceeded", callback);

        if (_ctx->is_eos(_last_tok)) {
            callback.callBack(nullptr, 0, Sentence::END, tokenizer());
            break;
        }

        // Stream the raw token id; false from the callback stops generation.
        if (!callback.callBack(tokens.data(), tokens.size(), Sentence::CONTINUE, tokenizer())) break;
    }

    return true;
}
224
+
225
// Process a full token prompt and generate, DialogCallback variant: token
// ids (not decoded text) are delivered through callback.callBack(); END is
// signaled with a null/zero-length buffer. Mirrors the string-callback
// overload above.
bool BasicDialog::process(std::vector<int32_t>& tokens, qualla::DialogCallback callback) {
    // Check for prev failures and bail out early
    if (State::failed()) return false;

    Timer start;

    if(m_inputType != InputType::TOKENS) {
        __ERROR("Input type for model is not tokens.");
        return false;
    }

    _gpio_marker->set();

    // Vector for storing logits.
    // Allocated & filled by the engine.
    std::vector<float> logits;

    State::clear();

    auto& sampler = *_sampler["primary"];
    auto& engine = *_engine["primary"];

    using FF = Engine::Feature::Flags;
    if (engine.supports(FF::DYNAMIC_LOAD)) engine.load();

    if (_n_past + tokens.size() > _ctx->size()) {
        __WARN("Context limit exceeded ({} + {} > {})", _n_past, tokens.size(), _ctx->size());
        callback.callBack(nullptr, 0, Sentence::END, tokenizer());
        return true;
    }

    if (!engine.process(tokens, logits, false)) {
        return Dialog::abort("engine prompt processing failed", callback);
    }

    _n_prompt += tokens.size();
    _n_past += tokens.size();

    if (!engine.updateKV(_n_past)) {
        return Dialog::abort("context size exceeded", callback);
    }

    // Sample the first generated token; shrink `tokens` into the 1-slot
    // buffer reused by the generation loop.
    tokens[0] = _last_tok = sampler.process(logits);
    tokens.resize(1);

    _n_generated++;

    _gpio_marker->set();

    _kpis.prompt.update(start.elapsed_usec());

    // Log latest KPIs
    _env->logger().post(Logger::KPIS, kpis().dump(" "));

    start.reset();

    if (_ctx->is_eos(_last_tok)) {
        callback.callBack(nullptr, 0, Sentence::END, tokenizer());
        return true;
    }

    if (!callback.callBack(tokens.data(), tokens.size(), Sentence::BEGIN, tokenizer()))
        return true;

    State::busy(true);
    processFollowOnGeneration(tokens, logits, callback);
    State::busy(false);

    _gpio_marker->set();
    _gpio_marker->reset();

    _kpis.generate.update(start.elapsed_usec());

    // Log latest KPIs in a single line
    _env->logger().post(Logger::KPIS, kpis().dump(" "));

    return !State::failed();
}
303
+
304
// Process a prompt supplied as raw embedding bytes, then generate.
// The t2e callback (token -> embedding) is stored for the generation loop,
// which must convert each sampled token back into an embedding. Without a
// t2e callback the dialog is single-shot: one token is emitted, then END.
//
// embedding_vectors : concatenated per-token embeddings; the token count is
//                     derived as size / engine embedding buffer size.
// Returns false on engine failure or input-type mismatch.
bool BasicDialog::process(
    std::vector<uint8_t>& embedding_vectors,
    T2ECallback t2eCallback,
    Dialog::Callback callback
) {
    Timer start;

    if(m_inputType != InputType::EMBEDDINGS) {
        __ERROR("Input type for model is not embeddings.");
        return false;
    }

    // Vector for storing logits.
    // Allocated & filled by the engine.
    std::vector<float> logits;

    State::clear();

    _gpio_marker->set();

    auto& sampler = *_sampler["primary"];
    auto& engine = *_engine["primary"];

    // Store the t2e callback for reference during follow-on generation.
    m_t2eCallback = t2eCallback;

    size_t embedBufSize = engine.getEmbeddingBufferSize();

    {
        std::vector<uint8_t> eosEmbedding(embedBufSize, 0.0);
        if (m_t2eCallback) {
            m_t2eCallback(_ctx->eos(), eosEmbedding.data(), embedBufSize);
        }
        // For non-autogenerative usecases (where t2eCallback is not supplied),
        // the EOS vector is all zero. This is fine for models with proper
        // attention masking support, but may degrade accuracy otherwise.
        if (!engine.cacheEosEmbedding(eosEmbedding)) {
            __DEBUG("Failed to set the eos token embedding.");
            return false;
        }
    }

    using FF = Engine::Feature::Flags;
    if (engine.supports(FF::DYNAMIC_LOAD)) engine.load();

    size_t curTokenCount = embedding_vectors.size() / embedBufSize;
    _env->logger().post(Logger::KPIS, kpis().dump(" "));
    start.reset(); // Don't include preprocessing time

    if (_n_past + curTokenCount > _ctx->size()) {
        __WARN("Context limit exceeded ({} + {} > {})", _n_past, curTokenCount, _ctx->size());
        callback("", Sentence::END);
        return true;
    }

    if (!engine.process(embedding_vectors, {}, logits))
        return Dialog::abort("engine prompt processing failed", callback);
    _n_prompt += curTokenCount;
    _n_past += curTokenCount;

    // 1-slot token buffer reused by the generation loop.
    std::vector<int32_t> tokens(1, 0);

    if (!engine.updateKV(_n_past)) return Dialog::abort("context size exceeded", callback);

    tokens[0] = _last_tok = sampler.process(logits);

    _n_generated++;

    _gpio_marker->set();

    _kpis.prompt.update(start.elapsed_usec());

    // Log latest KPIs
    _env->logger().post(Logger::KPIS, kpis().dump(" "));

    start.reset();

    if (_ctx->is_eos(_last_tok)) {
        callback("", Sentence::END);
        return true;
    }

    if (!callback(_tokenizer->decode(tokens), Sentence::BEGIN)) {
        return true;
    }

    // No token-to-embedding converter: cannot feed sampled tokens back, so
    // generation ends after the first token.
    if (!m_t2eCallback) {
        callback("", Sentence::END);
        return true;
    }

    State::busy(true);
    processFollowOnGeneration(tokens, logits, callback);
    State::busy(false);

    _gpio_marker->set();
    _gpio_marker->reset();

    _kpis.generate.update(start.elapsed_usec());
    // Log latest KPIs in a single line
    _env->logger().post(Logger::KPIS, kpis().dump(" "));

    return !State::failed();
}
408
+
409
// Registrator instance: registers the "basic" dialog type with the Dialog
// factory when this translation unit is loaded.
static OnLoad regy([]() {
    Dialog::__register(
        "basic",
        [](std::shared_ptr<Env> env, const std::string& name, const json& conf) {
            return (Dialog*)new BasicDialog(env, name, conf);
        }
    );
});

// Presumably referenced from elsewhere to force this object file to be
// linked in (so the static registrator above actually runs) — confirm.
void needBasicDialog() {}
420
+
421
+ } // namespace qualla
Genie/Genie/src/qualla/dialogs/kv-share.cpp ADDED
@@ -0,0 +1,359 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ //==============================================================================
2
+ //
3
+ // Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4
+ // All rights reserved.
5
+ // Confidential and Proprietary - Qualcomm Technologies, Inc.
6
+ //
7
+ //==============================================================================
8
+
9
+ #include <qualla/dialog.hpp>
10
+ #include <qualla/sampler.hpp>
11
+ #include <qualla/logger.hpp>
12
+ #include <qualla/detail/config.hpp>
13
+ #include <qualla/detail/timer.hpp>
14
+ #include <qualla/detail/onload.hpp>
15
+ #include <qualla/detail/sampler-utils.hpp>
16
+ #include <qualla/detail/basic-sampler.hpp>
17
+ #include <qualla/detail/cache-file.hpp>
18
+
19
+ #include <functional>
20
+ #include <fstream>
21
+ #include <string>
22
+ #include <unordered_map>
23
+ #include <filesystem>
24
+ #include <random>
25
+
26
+ #include <fmt/format.h>
27
+ #include <fmt/ranges.h>
28
+
29
+ namespace fs = std::filesystem;
30
+
31
+ #define __INFO(__fmt, ...) _env->logger().post(Logger::INFO, fmt::format(__fmt, ##__VA_ARGS__))
32
+ #define __WARN(__fmt, ...) _env->logger().post(Logger::WARN, fmt::format(__fmt, ##__VA_ARGS__))
33
+ #define __ERROR(__fmt, ...) _env->logger().post(Logger::ERROR, fmt::format(__fmt, ##__VA_ARGS__))
34
+ #define __KPIS(__fmt, ...) \
35
+ _env->logger().post(Logger::KPIS, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
36
+ #define __DEBUG(__fmt, ...) \
37
+ _env->logger().post(Logger::DEBUG, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
38
+ #define __TRACE(__fmt, ...) \
39
+ _env->logger().post(Logger::TRACE, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
40
+
41
+ namespace qualla {
42
+
43
+ using qc = qualla::Config;
44
+
45
+ class KvShareDialog : public Dialog {
46
+ public:
47
+ KvShareDialog(std::shared_ptr<Env> env, const std::string& name, const json& conf)
48
+ : Dialog(env, name, conf) {}
49
+
50
+ virtual bool process(std::vector<int32_t>& tokens, Dialog::Callback callback) override;
51
+
52
+ virtual bool process(std::vector<int32_t>& tokens, DialogCallback callback) override {
53
+ return false;
54
+ }
55
+
56
+ virtual void reset() override;
57
+
58
+ bool convertKV(const fs::path& cache_dir);
59
+
60
+ };
61
+
62
+ void KvShareDialog::reset() {
63
+ __INFO("dialog-reset: {}", _ctx->name());
64
+
65
+ _n_past = 0;
66
+ _n_prompt = 0;
67
+ _n_generated = 0;
68
+ _n_queries = 0;
69
+ _last_tok = -1;
70
+
71
+ _kpis.reset();
72
+
73
+ // Reset Samplers
74
+ for (auto& s : _sampler)
75
+ s.second->reset();
76
+
77
+ // Reset Engines
78
+ for (auto& e : _engine) {
79
+ e.second->reset();
80
+ e.second->unload();
81
+ }
82
+
83
+ State::clear();
84
+ }
85
+
86
+ bool KvShareDialog::process(std::vector<int32_t>& tokens, Dialog::Callback callback) {
87
+
88
+ // Check for prev failures and bail out early
89
+ if (State::failed()) return false;
90
+
91
+ Timer start;
92
+
93
+ // Vector for storing logits.
94
+ // Allocated & filled by the engine.
95
+ std::vector<float> logits;
96
+
97
+ State::clear();
98
+
99
+ auto& sampler = *_sampler["primary"];
100
+
101
+ auto& p_engine = *_engine["primary"]; // prompt
102
+ auto& s_engine = *_engine["secondary"]; // generation
103
+
104
+ if (_n_past + tokens.size() > _ctx->size()) {
105
+ __WARN("Context limit exceeded ({} + {} > {})", _n_past, tokens.size(), _ctx->size());
106
+ callback("", Sentence::END);
107
+ return true;
108
+ }
109
+
110
+ if (!p_engine.process(tokens, logits))
111
+ return Dialog::abort("engine prompt processing failed", callback);
112
+
113
+ _n_prompt += tokens.size();
114
+ _n_past += tokens.size();
115
+
116
+ if (!p_engine.updateKV(_n_past)) return Dialog::abort("context size exceeded", callback);
117
+
118
+ tokens[0] = _last_tok = sampler.process(logits);
119
+ tokens.resize(1);
120
+
121
+ _n_generated++;
122
+
123
+ _kpis.prompt.update(start.elapsed_usec());
124
+ // Log latest KPIs
125
+ _env->logger().post(Logger::KPIS, kpis().dump(" "));
126
+
127
+ if (_ctx->is_eos(_last_tok)) {
128
+ callback("", Sentence::END);
129
+ return true;
130
+ }
131
+
132
+ if (!callback(_tokenizer->decode(tokens), Sentence::BEGIN)) return true;
133
+
134
+ __DEBUG("dialog: {} : switching engines", _ctx->name());
135
+ {
136
+ // Setup cache dir for saving the engine state
137
+ std::string cache_name = _ctx->name() + "-kv-share";
138
+ fs::path cache_dir = _env->path().cache / cache_name;
139
+
140
+ if (!fs::exists(cache_dir) && !fs::create_directories(cache_dir)) {
141
+ __ERROR("dialog: {} : failed to create cache directory {}",
142
+ _ctx->name(),
143
+ cache_dir.string());
144
+ return Dialog::abort("engine switch failed", callback);
145
+ }
146
+
147
+ // Save and unload the primary engine
148
+ p_engine.save(cache_name);
149
+ p_engine.unload();
150
+
151
+ // The purpose is to save the hyperparams
152
+ s_engine.save(cache_name);
153
+
154
+ convertKV(cache_dir);
155
+
156
+ size_t n = s_engine.restore(cache_name);
157
+
158
+ if(!fs::remove_all(cache_dir)) {
159
+ __WARN("dialog: {} : cache files not closed/dir not found", _ctx->name());
160
+ }
161
+
162
+ if (n != _n_past) {
163
+ __WARN("dialog: {} : kv size mismatch {} expected {}", _ctx->name(), n, _n_past);
164
+ _n_past = n;
165
+ }
166
+
167
+ s_engine.updateKV(_n_past);
168
+ }
169
+
170
+ start.reset();
171
+
172
+ State::busy(true);
173
+
174
+ while (true) {
175
+ if (State::canceled()) {
176
+ callback("", Sentence::END);
177
+ break;
178
+ }
179
+
180
+ if (_n_past + tokens.size() > _ctx->size()) {
181
+ __WARN("Context limit exceeded ({} + {} > {})", _n_past, tokens.size(), _ctx->size());
182
+ callback("", Sentence::END);
183
+ break;
184
+ }
185
+ if (!s_engine.process(tokens, logits))
186
+ return Dialog::abort("secondary engine processing failed", callback);
187
+
188
+ tokens[0] = _last_tok = sampler.process(logits);
189
+
190
+ _n_past++;
191
+ _n_generated++;
192
+
193
+ if (!s_engine.updateKV(_n_past)) return Dialog::abort("context size exceeded", callback);
194
+
195
+ if (_ctx->is_eos(_last_tok)) {
196
+ callback("", Sentence::END);
197
+ break;
198
+ }
199
+
200
+ if (!callback(_tokenizer->decode(tokens), Sentence::CONTINUE)) break;
201
+ }
202
+
203
+ State::busy(false);
204
+
205
+ _kpis.generate.update(start.elapsed_usec());
206
+
207
+ // Log latest KPIs in a single line
208
+ _env->logger().post(Logger::KPIS, kpis().dump(" "));
209
+
210
+ return true;
211
+ }
212
+
213
+ bool KvShareDialog::convertKV(const fs::path& cache_dir) {
214
+ Timer start;
215
+
216
+ fs::path nsp_cache_path = cache_dir / "kv-cache.primary.qnn-htp";
217
+ fs::path cpu_cache_path = cache_dir / "kv-cache.secondary.qnn-cpu";
218
+
219
+ __DEBUG("kv-convert: begin converting {} to {}", nsp_cache_path.string(), cpu_cache_path.string());
220
+
221
+ std::ifstream nsp_fs(nsp_cache_path, std::ios::in | std::ios::binary);
222
+
223
+ if (nsp_fs.fail()) {
224
+ __ERROR("kv-convert: error reading file {}", nsp_cache_path.string());
225
+ State::error("failed to read primary kv-cache");
226
+ return false;
227
+ }
228
+
229
+ // Read spec from nsp file
230
+ CacheFileSpec nsp_spec;
231
+ nsp_fs.read((char*)&nsp_spec, sizeof(nsp_spec));
232
+ if (nsp_spec.magic != 0xC0DE) {
233
+ __ERROR("kv-convert: expected 0xC0DE found {:#x}", nsp_spec.magic);
234
+ State::error("invalid format of primary kv-cache");
235
+ return false;
236
+ }
237
+
238
+ // clang-format off
239
+ __DEBUG("kv-convert: load {{ num_tensors {}, magic {}, dtype {}, n_heads {}, embed_dim {} update_size {} }}",
240
+ nsp_spec.num_tensors, nsp_spec.magic, int(nsp_spec.dtype), nsp_spec.n_heads, nsp_spec.embed_dim, nsp_spec.update_size);
241
+ // clang-format on
242
+
243
+ std::fstream cpu_fs(cpu_cache_path, std::ios::in | std::ios::out | std::ios::binary);
244
+
245
+ if (cpu_fs.fail()) {
246
+ // TODO: replace with proper error handling
247
+ __ERROR("kv-convert: failed to write {}", cpu_cache_path.string());
248
+ State::error("failed to save secondary kv-cache");
249
+ return false;
250
+ }
251
+
252
+ CacheFileSpec cpu_spec;
253
+ cpu_fs.read((char*)&cpu_spec, sizeof(cpu_spec));
254
+ if (cpu_spec.magic != 0xC0DE) {
255
+ __ERROR("kv-convert: expected 0xC0DE found {:#x}", cpu_spec.magic);
256
+ State::error("invalid format of secondary kv-cache");
257
+ return false;
258
+ }
259
+
260
+ // Set the n_tokens processed during prompt processing and write the spec to the file
261
+ cpu_spec.update_size = nsp_spec.update_size;
262
+ cpu_fs.seekp(std::ios::beg);
263
+ cpu_fs.write((char*)&cpu_spec, sizeof(cpu_spec));
264
+
265
+ const uint32_t n_layer = nsp_spec.num_tensors / 2;
266
+ const uint32_t n_head = nsp_spec.n_heads;
267
+ const uint32_t kv_dim = nsp_spec.embed_dim;
268
+ const uint32_t n_tok = nsp_spec.update_size;
269
+
270
+ const size_t cache_size = n_layer * n_head * kv_dim * n_tok;
271
+
272
+ // Read Key/Value Cache
273
+ std::vector<uint8_t> key_cache(cache_size);
274
+ std::vector<uint8_t> value_cache(cache_size);
275
+ nsp_fs.read((char*)key_cache.data(), cache_size);
276
+ nsp_fs.read((char*)value_cache.data(), cache_size);
277
+
278
+ // Read Quantization parameters
279
+ std::vector<double> key_scales(n_layer);
280
+ std::vector<double> value_scales(n_layer);
281
+ nsp_fs.read((char*)key_scales.data(), n_layer * sizeof(double));
282
+ nsp_fs.read((char*)value_scales.data(), n_layer * sizeof(double));
283
+
284
+ nsp_fs.close();
285
+
286
+ // Convert and write on cpu_file
287
+ // Dequant and transpose caches
288
+ const uint32_t layer_size = n_head * kv_dim * n_tok;
289
+ const uint32_t head_size = kv_dim * n_tok;
290
+
291
+ // Transpose kvdim * n_tok (QNN-HTP K$) -> n_tok * kvdim (QNN-CPU K$)
292
+ // For ScopGPT KV$ Format
293
+ __DEBUG("kv-convert: dequantizing keys");
294
+ std::vector<float> dequant_keys(cache_size);
295
+ for (uint32_t i = 0; i < n_layer; i++) {
296
+ for (uint32_t j = 0; j < n_head; j++) {
297
+ for (uint32_t k = 0; k < kv_dim; k++) {
298
+ for (uint32_t l = 0; l < n_tok; l++) {
299
+ // Interleave K$
300
+ // QNN HTP: [0 2 4 ... 126 1 3 5 ... 127]
301
+ // QNN CPU: [0 1 2 ... 63 64 65 ... 127]
302
+ const uint32_t interleaved_k =
303
+ (2 * k < kv_dim) ? 2 * k : 2 * (k - kv_dim / 2) + 1;
304
+
305
+ const uint32_t read_loc = i * layer_size + j * head_size + k * n_tok + l;
306
+ const uint32_t write_loc = i * layer_size + j * head_size + l * kv_dim + interleaved_k;
307
+
308
+ dequant_keys[write_loc] =
309
+ (static_cast<float>(key_cache[read_loc]) - 128) * key_scales[i];
310
+ }
311
+ }
312
+ }
313
+ }
314
+
315
+ __DEBUG("kv-convert: dequantizing values");
316
+ std::vector<float> dequant_values(cache_size);
317
+ for (uint32_t i = 0; i < n_layer; i++) {
318
+ for (uint32_t j = 0; j < n_head; j++) {
319
+ for (uint32_t l = 0; l < n_tok; l++) {
320
+ for (uint32_t k = 0; k < kv_dim; k++) {
321
+ const uint32_t read_loc = i * layer_size + j * head_size + l * kv_dim + k;
322
+ const uint32_t write_loc = read_loc;
323
+
324
+ dequant_values[write_loc] =
325
+ (static_cast<float>(value_cache[read_loc]) - 128) * value_scales[i];
326
+ }
327
+ }
328
+ }
329
+ }
330
+
331
+ __DEBUG("kv-convert: storing converted KV to file");
332
+ cpu_fs.write((char *)dequant_keys.data(), dequant_keys.size() * sizeof(float));
333
+ cpu_fs.write((char *)dequant_values.data(), dequant_values.size() * sizeof(float));
334
+
335
+ cpu_fs.flush();
336
+ cpu_fs.close();
337
+
338
+ __DEBUG("kv-convert: done converting {} to {} in {} usec",
339
+ nsp_cache_path.string(),
340
+ cpu_cache_path.string(),
341
+ start.elapsed_usec());
342
+
343
+ return true;
344
+
345
+ }
346
+
347
+ // Registrator instance
348
+ static OnLoad regy([]() {
349
+ Dialog::__register(
350
+ "kv-share",
351
+ [](std::shared_ptr<Env> env, const std::string& name, const json& conf) {
352
+ return (Dialog*)new KvShareDialog(env, name, conf);
353
+ }
354
+ );
355
+ });
356
+
357
+ void needKvShareDialog() {}
358
+
359
+ } // namespace qualla
Genie/Genie/src/qualla/dialogs/lhd-dec.cpp ADDED
@@ -0,0 +1,481 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ //==============================================================================
2
+ //
3
+ // Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4
+ // All Rights Reserved.
5
+ // Confidential and Proprietary - Qualcomm Technologies, Inc.
6
+ //
7
+ //==============================================================================
8
+
9
+ #include <qualla/dialog.hpp>
10
+ #include <qualla/logger.hpp>
11
+ #include <qualla/detail/config.hpp>
12
+ #include <qualla/detail/timer.hpp>
13
+ #include <qualla/detail/onload.hpp>
14
+ #include <qualla/detail/lhd-dialog.hpp>
15
+
16
+ #include <functional>
17
+ #include <filesystem>
18
+ #include <string>
19
+ #include <cmath>
20
+ #include <cstdio>
21
+ #include <random>
22
+
23
+ #include <fmt/format.h>
24
+ #include <fmt/ranges.h>
25
+
26
+ namespace fs = std::filesystem;
27
+
28
+ #define __INFO(__fmt, ...) _env->logger().post(Logger::INFO, fmt::format(__fmt, ##__VA_ARGS__))
29
+ #define __WARN(__fmt, ...) _env->logger().post(Logger::WARN, fmt::format(__fmt, ##__VA_ARGS__))
30
+ #define __ERROR(__fmt, ...) _env->logger().post(Logger::ERROR, fmt::format(__fmt, ##__VA_ARGS__))
31
+ #define __KPIS(__fmt, ...) \
32
+ _env->logger().post(Logger::KPIS, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
33
+ #define __DEBUG(__fmt, ...) \
34
+ _env->logger().post(Logger::DEBUG, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
35
+ #define __TRACE(__fmt, ...) \
36
+ _env->logger().post(Logger::TRACE, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
37
+
38
+ namespace qualla {
39
+
40
+ using qc = qualla::Config;
41
+
42
+ LhdDecDialog::LhdDecDialog(std::shared_ptr<Env> env, const std::string& name, const json& conf)
43
+ : Dialog(env, name, conf) {
44
+
45
+ _window = qc::optional<size_t>(conf, "window", 8);
46
+ _ngram = qc::optional<size_t>(conf, "ngram", 3);
47
+ _gcap = qc::optional<size_t>(conf, "gcap", 8);
48
+
49
+ _lhd_mode_str = qc::optional<std::string>(conf, "lhd-update-mode", "ALWAYS_FWD_ONE");
50
+ }
51
+
52
+ bool LhdDecDialog::process(std::vector<int32_t>& tokens, Dialog::Callback callback) {
53
+ // Check for prev failures and bail out early
54
+ if (State::failed()) return false;
55
+
56
+ Timer start;
57
+
58
+ // Vector for storing logits.
59
+ // Allocated & filled by the engine.
60
+ std::vector<float> logits;
61
+ std::vector<int32_t> resultTokens;
62
+
63
+ State::clear();
64
+
65
+ auto& sampler = *_sampler["primary"];
66
+ auto& engine = *_engine["primary"];
67
+
68
+ using FF = Engine::Feature::Flags;
69
+ if (engine.supports(FF::DYNAMIC_LOAD)) engine.load();
70
+
71
+ if (_n_past + tokens.size() > _ctx->size()) {
72
+ __WARN("Context limit exceeded ({} + {} > {})", _n_past, tokens.size(), _ctx->size());
73
+ callback("", Sentence::END);
74
+ return true;
75
+ }
76
+
77
+ if (!engine.process(tokens, logits, false))
78
+ return Dialog::abort("engine prompt processing failed", callback);
79
+
80
+ _n_prompt += tokens.size();
81
+ _n_past += tokens.size();
82
+
83
+ if (!engine.updateKV(_n_past)) return Dialog::abort("context size exceeded", callback);
84
+
85
+ std::vector<int32_t> tokens_tmp(1);
86
+ tokens_tmp[0] = _last_tok = sampler.process(logits);
87
+ resultTokens.push_back(_last_tok);
88
+
89
+ _n_generated++;
90
+
91
+ _kpis.prompt.update(start.elapsed_usec());
92
+
93
+ // Log latest KPIs
94
+ _env->logger().post(Logger::KPIS, kpis().dump(" "));
95
+
96
+ if (_ctx->is_eos(_last_tok)) {
97
+ callback("", Sentence::END);
98
+ return true;
99
+ }
100
+
101
+ // Exit condition : Prediction limit reached OR ctx size limit reached
102
+ if (!callback(_tokenizer->decode(tokens_tmp), Sentence::BEGIN)) return true;
103
+
104
+ State::busy(true);
105
+
106
+ // verification branch init
107
+ v_branch.resize(_gcap);
108
+
109
+ // n-gram pools
110
+ const size_t n_vocab = _ctx->n_vocab();
111
+ ngram_container ngrams_pool(n_vocab, _ngram, _gcap);
112
+
113
+ // lookahead branch first level init
114
+ lhd_branch.resize(_ngram - 1);
115
+ lhd_branch_prev.resize(_window);
116
+
117
+ for (int j = 0; j < _ngram - 1; j++) {
118
+ lhd_branch[j].resize(_window);
119
+
120
+ for (int i = 0; i < _window; i++) {
121
+ if (j == 0) {
122
+ // initialize with the random token from prompt
123
+ lhd_branch[j][i] = tokens[1 + rand() % (tokens.size() - 1)];
124
+ } else {
125
+ // initialize with a sequence of increasing numbers
126
+ lhd_branch[j][i] = 1000 + i;
127
+ }
128
+ }
129
+ }
130
+
131
+ // lookahead branch other level init
132
+ while (_level_idx < _ngram - 1) {
133
+
134
+ batch.clear();
135
+ attention_map.clear();
136
+
137
+ // fill the first token of the first level
138
+ batch.push_back(_last_tok);
139
+ attention_map.push_back(-1);
140
+ lhd_branch[0][0] = _last_tok;
141
+
142
+ // fill the remaining WINDOW - 1 tokens for the first level
143
+ for (int i = 1; i < _window; i++) {
144
+ batch.push_back(lhd_branch[0][i]);
145
+ attention_map.push_back(i - 1);
146
+ }
147
+
148
+ // fill the rest of the levels
149
+ for (int j = 1; j < _ngram - 1; j++) {
150
+ for (int i = 0; i < _window; i++) {
151
+ batch.push_back(lhd_branch[j][i]);
152
+ attention_map.push_back((j - 1) * _window + i);
153
+ }
154
+ }
155
+
156
+ // re-init tokens batch
157
+ tokens.resize(_window * (_ngram - 1));
158
+ tokens = batch;
159
+
160
+ if (_n_past + tokens.size() > _ctx->size()) {
161
+ __WARN("Context limit exceeded ({} + {} > {})", _n_past, tokens.size(), _ctx->size());
162
+ callback("", Sentence::END);
163
+ break;
164
+ }
165
+
166
+ size_t n_tok = engine.process(tokens, attention_map, logits, true);
167
+ if (n_tok != tokens.size())
168
+ return Dialog::abort("engine lookahead branch processing failed", callback);
169
+
170
+ for (int i = 0; i < _window; i++) {
171
+ size_t sample_tmp_idx = (_level_idx - 1) * _window + i;
172
+ // sampler from logits all
173
+ std::span<float> span_logits{logits.data(),logits.size()};
174
+ std::span<float> span_tmp = span_logits.subspan(sample_tmp_idx * n_vocab, n_vocab);
175
+ int32_t sampled_tmp_token = sampler.process(span_tmp);
176
+ lhd_branch[_level_idx][i] = sampled_tmp_token;
177
+ }
178
+
179
+ _level_idx++;
180
+ }
181
+
182
+ if (_lhd_mode_str == "FWD_MAX_HIT")
183
+ _lhd_update_mode = FWD_MAX_HIT;
184
+ else if (_lhd_mode_str == "FWD_LEVEL")
185
+ _lhd_update_mode = FWD_LEVEL;
186
+ else
187
+ _lhd_update_mode = ALWAYS_FWD_ONE;
188
+
189
+ start.reset();
190
+
191
+ while (true) {
192
+ if (State::canceled()) {
193
+ callback("", Sentence::END);
194
+ break;
195
+ }
196
+ // input batch init
197
+ {
198
+ batch.clear();
199
+ attention_map.clear();
200
+
201
+ // fill the first token of the first level
202
+ batch.push_back(_last_tok);
203
+ attention_map.push_back(-1);
204
+ // lhd_branch[0][0] = _last_tok;
205
+
206
+ // fill the remaining WINDOW - 1 tokens for the first level
207
+ for (int i = 1; i < _window; i++) {
208
+ batch.push_back(lhd_branch[0][i]);
209
+ attention_map.push_back(i - 1);
210
+ }
211
+
212
+ // fill the rest of the levels
213
+ for (int j = 1; j < _ngram - 1; j++) {
214
+ for (int i = 0; i < _window; i++) {
215
+ batch.push_back(lhd_branch[j][i]);
216
+ attention_map.push_back((j - 1) * _window + i);
217
+ }
218
+ }
219
+
220
+ // build verification n-grams(branch)
221
+ {
222
+ const int g_cur = ngrams_pool.cnt[_last_tok];
223
+
224
+ v_branch.resize(g_cur);
225
+ // input_token_batch.size = (_window + g_cur) * (_ngram - 1);
226
+ tokens.resize((_window + g_cur) * (_ngram - 1));
227
+ for (int g = 0; g < g_cur; g++) {
228
+ v_branch[g].active = true;
229
+ v_branch[g].tokens.resize(_ngram);
230
+ v_branch[g].i_batch.resize(_ngram);
231
+ v_branch[g].seq_id = _window + 1 + g;
232
+ v_branch[g].i_batch[0] = 0;
233
+ v_branch[g].tokens[0] = _last_tok;
234
+ }
235
+
236
+ for (int j = 0; j < _ngram - 1; j++) {
237
+ for (int g = 0; g < g_cur; g++) {
238
+ const int idx = _last_tok * (_ngram - 1) * _gcap + g * (_ngram - 1);
239
+ const int32_t t = ngrams_pool.tokens[idx + j];
240
+ v_branch[g].tokens[j + 1] = t;
241
+ v_branch[g].i_batch[j + 1] = j + 1;
242
+ }
243
+ }
244
+
245
+ for (int g = 0; g < g_cur; g++) {
246
+ for (int j = 0; j < _ngram - 1; j++) {
247
+ batch.push_back(v_branch[g].tokens[j + 1]);
248
+ if (j == 0)
249
+ attention_map.push_back(0);
250
+ else
251
+ attention_map.push_back(batch.size() - 2);
252
+ }
253
+ }
254
+ }
255
+ }
256
+
257
+ // re-init tokens batch
258
+ std::vector<bool> selected(attention_map.size(), false);
259
+ tokens = batch;
260
+
261
+ if (_n_past + tokens.size() > _ctx->size()) {
262
+ __WARN("Context limit exceeded ({} + {} > {})", _n_past, tokens.size(), _ctx->size());
263
+ callback("", Sentence::END);
264
+ break;
265
+ }
266
+
267
+ size_t n_tok = engine.process(tokens, attention_map, logits, true);
268
+ if (n_tok != tokens.size()) return Dialog::abort("engine gen processing failed", callback);
269
+
270
+ // verification branch seq-id
271
+ size_t seq_id_best = 0;
272
+ // max hit pos
273
+ size_t i_batch_best = 0;
274
+
275
+ // Lookahead decoding and verification
276
+ for (int v = 0; v < _ngram; ++v) {
277
+ int i_batch = 0;
278
+
279
+ if (v > 0) {
280
+ for (int g = 0; g < (int)v_branch.size(); g++) {
281
+ // record the best matched seq and pos
282
+ if (v_branch[g].active) {
283
+ i_batch = v_branch[g].i_batch[v];
284
+ i_batch_best = i_batch;
285
+ seq_id_best = v_branch[g].seq_id;
286
+ ++_n_accept;
287
+ break;
288
+ }
289
+ }
290
+
291
+ if (i_batch == 0) {
292
+ break;
293
+ }
294
+ }
295
+
296
+ size_t sample_idx;
297
+ if (seq_id_best == 0)
298
+ sample_idx = 0;
299
+ else
300
+ sample_idx = _window * (_ngram - 1) + (seq_id_best - (_window + 1)) * (_ngram - 1) +
301
+ i_batch - 1;
302
+
303
+ //vector selected set
304
+ selected[sample_idx] = true;
305
+
306
+ // sampler from logits all
307
+ std::span<float> span_logits{logits.data(),logits.size()};
308
+ std::span<float> sample_logit = span_logits.subspan(sample_idx * n_vocab, n_vocab);
309
+ _last_tok = sampler.process(sample_logit);
310
+
311
+ std::vector<int32_t> tokens_tmp(1);
312
+ tokens_tmp[0] = _last_tok;
313
+
314
+ resultTokens.push_back(_last_tok);
315
+ _n_generated++;
316
+ _n_past++;
317
+
318
+ if (_ctx->is_eos(_last_tok)) break;
319
+
320
+ if (!callback(_tokenizer->decode(tokens_tmp), Sentence::CONTINUE)) return true;
321
+
322
+ // if verification passes, check the next sampled token until verification fails
323
+ for (int g = 0; g < (int)v_branch.size(); g++) {
324
+ // update the n-gram active status
325
+ if (v_branch[g].active) {
326
+ if (v == _ngram - 1) {
327
+ v_branch[g].active = false;
328
+ } else {
329
+ if (_last_tok != v_branch[g].tokens[v + 1]) {
330
+ v_branch[g].active = false;
331
+ }
332
+ }
333
+ }
334
+ }
335
+
336
+ // update lookahead tokens when v=0 OR verify match
337
+ {
338
+ for (int i = 0; i < _window; i++) {
339
+ lhd_branch_prev[i] = lhd_branch[0][i];
340
+ }
341
+
342
+ if (v == 0) {
343
+ for (int j = 0; j < _ngram - 2; j++) {
344
+ lhd_branch[j] = lhd_branch[j + 1];
345
+ }
346
+
347
+ // sample from the last level
348
+ for (int i = 0; i < _window; i++) {
349
+ size_t sample_idx = (_ngram - 2) * _window + i;
350
+ std::span<float> sample_logit =
351
+ span_logits.subspan(sample_idx * n_vocab, n_vocab);
352
+ lhd_branch[_ngram - 2][i] = sampler.process(sample_logit);
353
+ }
354
+ } else {
355
+ if (_lhd_update_mode == FWD_MAX_HIT) {
356
+ // update lookahead branch by forwarding
357
+ for (int j = 0; j < _ngram - 1; j++) {
358
+ for (int i = 0; i < _window - v; i++) {
359
+ lhd_branch[j][i] = lhd_branch[j][i + 1];
360
+ }
361
+ }
362
+ } else if (_lhd_update_mode == FWD_LEVEL) {
363
+ // update lookahead branch by shifting level
364
+ for (int j = 0; j < _ngram - 2; j++) {
365
+ lhd_branch[j] = lhd_branch[j + 1];
366
+ }
367
+
368
+ for (int i = 0; i < _window; i++) {
369
+ // init from the previous level
370
+ lhd_branch[_ngram - 2][i] = lhd_branch[0][i];
371
+ }
372
+ }
373
+ }
374
+ }
375
+
376
+ // update n-grams pool
377
+ // only update n-grams pools when v=0
378
+ if (v == 0) {
379
+ std::vector<int32_t> ngram(_ngram - 1);
380
+ // n-gram pool generation
381
+ for (int f = 0; f < _window; ++f) {
382
+ const int ft = lhd_branch_prev[f]; // first token of the n-gram
383
+
384
+ for (int j = 0; j < _ngram - 1; ++j) {
385
+ ngram[j] = lhd_branch[j][f];
386
+ }
387
+
388
+ // filter-out repeating n-grams
389
+ {
390
+ bool is_unique = true;
391
+
392
+ for (int k = 0; k < ngrams_pool.cnt[ft]; ++k) {
393
+ // calculate the related idx by the first n-gram token
394
+ const int idx = ft * (_ngram - 1) * _gcap + k * (_ngram - 1);
395
+
396
+ bool is_match = true;
397
+ for (int j = 0; j < _ngram - 1; ++j) {
398
+ if (ngrams_pool.tokens[idx + j] != ngram[j]) {
399
+ is_match = false;
400
+ break;
401
+ }
402
+ }
403
+
404
+ // if n-gram match all, discard one of them
405
+ if (is_match) {
406
+ is_unique = false;
407
+ break;
408
+ }
409
+ }
410
+
411
+ if (!is_unique) {
412
+ continue;
413
+ }
414
+ }
415
+
416
+ const int head = ngrams_pool.head[ft];
417
+ const int idx = ft * (_ngram - 1) * _gcap + head * (_ngram - 1);
418
+
419
+ for (int i = 0; i < _ngram - 1; i++) {
420
+ // update the n-gram pool with new n-gram
421
+ ngrams_pool.tokens[idx + i] = ngram[i];
422
+ }
423
+
424
+ ngrams_pool.cnt[ft] = std::min(_gcap, ngrams_pool.cnt[ft] + 1);
425
+ ngrams_pool.head[ft] = (head + 1) % _gcap;
426
+
427
+ ngrams_pool.n_total++;
428
+ }
429
+ }
430
+ }
431
+
432
+ if (_lhd_update_mode == FWD_MAX_HIT) {
433
+ // std::random_device rd;
434
+ // std::mt19937 gen(rd());
435
+ // std::uniform_int_distribution<> dis(0, resultTokens.size() - 1);
436
+
437
+ // fill lookahead branch
438
+ for (int i = 0; i < _ngram - 1; i++) {
439
+ for (int j = _window - i_batch_best; j < _window; j++) {
440
+ lhd_branch[i][j] = resultTokens[1 + rand() % (resultTokens.size() - 1)];
441
+ // lhd_branch[i][j] = resultTokens[dis(gen)];
442
+ // std::cout << "Fill token = " << lhd_branch[i][j] << std::endl;
443
+ }
444
+ }
445
+ }
446
+
447
+ // KV cache management
448
+ if (!engine.updateKV(_n_past, selected))
449
+ return Dialog::abort("context size exceeded", callback);
450
+
451
+ if (_ctx->is_eos(_last_tok)) {
452
+ callback("", Sentence::END);
453
+ break;
454
+ }
455
+ }
456
+
457
+ State::busy(false);
458
+
459
+ _kpis.generate.update(start.elapsed_usec());
460
+
461
+ // Log latest KPIs in a single line
462
+ _env->logger().post(Logger::KPIS, kpis().dump(" "));
463
+ std::cout << std::endl << std::endl << std::flush;
464
+ __DEBUG("lhd-dec: n_generated = {} ---------- n_accept = {}", _n_generated, _n_accept);
465
+
466
+ return !State::failed();
467
+ }
468
+
469
+ // Registrator instance
470
+ static OnLoad regy([]() {
471
+ Dialog::__register(
472
+ "lhd-dec",
473
+ [](std::shared_ptr<Env> env, const std::string& name, const json& conf) {
474
+ return (Dialog*)new LhdDecDialog(env, name, conf);
475
+ }
476
+ );
477
+ });
478
+
479
+ void needLadeDialog() {}
480
+
481
+ } // namespace qualla
Genie/Genie/src/qualla/dialogs/multistream.cpp ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ //==============================================================================
2
+ //
3
+ // Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4
+ // All Rights Reserved.
5
+ // Confidential and Proprietary - Qualcomm Technologies, Inc.
6
+ //
7
+ //==============================================================================
8
+
9
+ #include <qualla/dialog.hpp>
10
+ #include <qualla/logger.hpp>
11
+ #include <qualla/detail/config.hpp>
12
+ #include <qualla/detail/timer.hpp>
13
+ #include <qualla/detail/onload.hpp>
14
+ #include <qualla/detail/multistream-dialog.hpp>
15
+ #include <qualla/detail/sampler-utils.hpp>
16
+
17
+ #include <functional>
18
+ #include <filesystem>
19
+ #include <string>
20
+
21
+ #include <fmt/format.h>
22
+ #include <fmt/ranges.h>
23
+
24
+ namespace fs = std::filesystem;
25
+
26
+ #define __INFO(__fmt, ...) _env->logger().post(Logger::INFO, fmt::format(__fmt, ##__VA_ARGS__))
27
+ #define __WARN(__fmt, ...) _env->logger().post(Logger::WARN, fmt::format(__fmt, ##__VA_ARGS__))
28
+ #define __ERROR(__fmt, ...) _env->logger().post(Logger::ERROR, fmt::format(__fmt, ##__VA_ARGS__))
29
+ #define __KPIS(__fmt, ...) \
30
+ _env->logger().post(Logger::KPIS, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
31
+ #define __DEBUG(__fmt, ...) \
32
+ _env->logger().post(Logger::DEBUG, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
33
+ #define __TRACE(__fmt, ...) \
34
+ _env->logger().post(Logger::TRACE, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
35
+
36
+ namespace qualla {
37
+
38
+ bool MultiStreamDialog::processFollowOnGeneration(std::vector<std::vector<int32_t>>& streams, std::vector<float>& logits, Dialog::Callback callback) {
39
+
40
+ auto& sampler = *_sampler["primary"];
41
+ auto& engine = *_engine["primary"];
42
+
43
+ std::vector<std::vector<int32_t>> attention_mask(_n_streams);
44
+ std::vector<int32_t> streamIndices;
45
+
46
+ if (streams.size() == 0) {
47
+ callback("\n", Sentence::END);
48
+ return true;
49
+ }
50
+
51
+ for (int i = 0; i < streams.size(); i++) {
52
+ // Initialize all attention_masks to attend to all previous tokens
53
+ attention_mask[i].resize(_n_past, 1);
54
+ streamIndices.push_back(i);
55
+ }
56
+
57
+ State::busy(true);
58
+
59
+ while (true) {
60
+ if (State::canceled()) break;
61
+
62
+ // If this exceeds context length, truncate all streams and return
63
+ if (_n_past + streamIndices.size() > _ctx->size()) {
64
+ for (auto stream : streamIndices)
65
+ callback(_tokenizer->decode(streams[stream]) + "\n", Sentence::CONTINUE);
66
+ break;
67
+ }
68
+
69
+ // Accumulate input tokens from all streams
70
+ std::vector<int32_t> multi_tokens(streamIndices.size());
71
+
72
+ for (int i = 0; i < streamIndices.size(); i++) {
73
+ multi_tokens[i] = streams[streamIndices[i]].back();
74
+
75
+ // Also add current iteration to the attention_mask
76
+ for (auto _mask_row : streamIndices)
77
+ // Set to true iff on diagonal, i.e. attend to itself
78
+ attention_mask[streamIndices[i]].push_back((streamIndices[i] == _mask_row) ? 1 : 0);
79
+ }
80
+
81
+ // Concatenate attention_mask for all active streams
82
+ std::vector<int32_t> multi_attn_mask;
83
+ multi_attn_mask.reserve((_n_past + streamIndices.size()) * streamIndices.size());
84
+ for (auto i : streamIndices)
85
+ multi_attn_mask.insert(
86
+ multi_attn_mask.end(),
87
+ attention_mask[i].begin(),
88
+ attention_mask[i].end()
89
+ );
90
+
91
+ // __DEBUG("Multi attention mask = {}", multi_attn_mask);
92
+
93
+ if (m_inputType == InputType::TOKENS) {
94
+ // Process input tokens for all streams in one batch
95
+ if (!engine.process(multi_tokens, multi_attn_mask, logits, true))
96
+ return Dialog::abort("engine gen processing failed", callback);
97
+ } else if (m_inputType == InputType::EMBEDDINGS) {
98
+ // Accumulate input embeddings from all streams
99
+ auto embedBufSize = engine.getEmbeddingBufferSize();
100
+ std::vector<uint8_t> multi_embeddings;
101
+
102
+ for (auto token : multi_tokens) {
103
+ // Convert tokens to embedding for the processing in the engine.
104
+ std::vector<uint8_t> curTokenEmbedding(embedBufSize, 0);
105
+ m_t2eCallback(token, curTokenEmbedding.data(), embedBufSize);
106
+ multi_embeddings.insert(multi_embeddings.end(), curTokenEmbedding.begin(), curTokenEmbedding.end());
107
+ }
108
+
109
+ // Process input tokens for all streams in one batch
110
+ if (!engine.process(multi_embeddings, multi_attn_mask, logits, true))
111
+ return Dialog::abort("engine gen processing failed", callback);
112
+ }
113
+
114
+ // Process all logits independently
115
+ std::span<float> logit_span = std::span{logits.data(),logits.size()};
116
+ for (int i = 0; i < streamIndices.size(); i++) {
117
+ _last_tok = sampler.process(logit_span.subspan(i * _vocab, _vocab));
118
+ streams[streamIndices[i]].push_back(_last_tok);
119
+ }
120
+
121
+ _n_past += streamIndices.size();
122
+ _n_generated += streamIndices.size();
123
+
124
+ if (!engine.updateKV(_n_past)) return Dialog::abort("context size exceeded", callback);
125
+
126
+ for (auto it = streamIndices.begin(); it != streamIndices.end();) {
127
+ int32_t stream = *it;
128
+ if (_ctx->is_eos(streams[stream].back())) {
129
+ callback(_tokenizer->decode(streams[stream]) + "\n", Sentence::CONTINUE);
130
+ it = streamIndices.erase(it);
131
+ } else {
132
+ ++it;
133
+ }
134
+ }
135
+
136
+ if (streamIndices.size() == 0) break;
137
+ }
138
+ callback("\n", Sentence::END);
139
+
140
+ State::busy(false);
141
+
142
+ return true;
143
+ }
144
+
145
+ bool MultiStreamDialog::process(std::vector<int32_t>& tokens, Dialog::Callback callback) {
146
+ // Check for prev failures and bail out early
147
+ if (State::failed()) return false;
148
+
149
+ Timer start;
150
+
151
+ if(m_inputType != InputType::TOKENS) {
152
+ __ERROR("Input type for model is not tokens.");
153
+ return false;
154
+ }
155
+
156
+ // Vector for storing logits.
157
+ // Allocated & filled by the engine.
158
+ std::vector<float> logits;
159
+
160
+ State::clear();
161
+
162
+ auto& engine = *_engine["primary"];
163
+
164
+ using FF = Engine::Feature::Flags;
165
+ if (engine.supports(FF::DYNAMIC_LOAD)) engine.load();
166
+
167
+ if (_n_past + tokens.size() > _ctx->size()) {
168
+ __WARN("Context limit exceeded ({} + {} > {})", _n_past, tokens.size(), _ctx->size());
169
+ callback("", Sentence::END);
170
+ return true;
171
+ }
172
+
173
+ if (!engine.process(tokens, logits, false))
174
+ return Dialog::abort("engine prompt processing failed", callback);
175
+
176
+ _n_prompt += tokens.size();
177
+ _n_past += tokens.size();
178
+
179
+ _prompt_len = _n_past;
180
+
181
+ if (!engine.updateKV(_n_past)) return Dialog::abort("context size exceeded", callback);
182
+
183
+ std::vector<std::vector<int32_t>> streams;
184
+ getTopK(logits, streams, _n_streams, _p_threshold, callback);
185
+
186
+ _n_generated += streams.size();
187
+ _kpis.prompt.update(start.elapsed_usec());
188
+
189
+ // Log latest KPIs
190
+ _env->logger().post(Logger::KPIS, kpis().dump(" "));
191
+
192
+ start.reset();
193
+
194
+ bool status = processFollowOnGeneration(streams, logits, callback);
195
+
196
+ _kpis.generate.update(start.elapsed_usec());
197
+
198
+ // Log latest KPIs in a single line
199
+ _env->logger().post(Logger::KPIS, kpis().dump(" "));
200
+
201
+ return status;
202
+ }
203
+
204
+ bool MultiStreamDialog::process(
205
+ std::vector<uint8_t>& embedding_vectors,
206
+ T2ECallback t2eCallback,
207
+ Dialog::Callback callback
208
+ ) {
209
+ // Check for prev failures and bail out early
210
+ if (State::failed()) return false;
211
+
212
+ Timer start;
213
+
214
+ if(m_inputType != InputType::EMBEDDINGS) {
215
+ __ERROR("Input type for model is not embeddings.");
216
+ return false;
217
+ }
218
+
219
+ // Vector for storing logits.
220
+ // Allocated & filled by the engine.
221
+ std::vector<float> logits;
222
+
223
+ State::clear();
224
+
225
+ auto& sampler = *_sampler["primary"];
226
+ auto& engine = *_engine["primary"];
227
+
228
+ // Store the t2e callback for reference during follow-on generation.
229
+ m_t2eCallback = t2eCallback;
230
+
231
+ size_t embedBufSize = engine.getEmbeddingBufferSize();
232
+
233
+ {
234
+ std::vector<uint8_t> eosEmbedding(embedBufSize, 0.0);
235
+ if (m_t2eCallback) {
236
+ m_t2eCallback(_ctx->eos(), eosEmbedding.data(), embedBufSize);
237
+ }
238
+ // For non-autogenerative usecases (where t2eCallback is not supplied),
239
+ // the EOS vector is all zero. This is fine for models with proper
240
+ // attention masking support, but may degrade accuracy otherwise.
241
+ if (!engine.cacheEosEmbedding(eosEmbedding)) {
242
+ __DEBUG("Failed to set the eos token embedding.");
243
+ return false;
244
+ }
245
+ }
246
+
247
+ using FF = Engine::Feature::Flags;
248
+ if (engine.supports(FF::DYNAMIC_LOAD)) engine.load();
249
+
250
+ size_t curTokenCount = embedding_vectors.size() / embedBufSize;
251
+ if (_n_past + curTokenCount > _ctx->size()) {
252
+ __WARN("Context limit exceeded ({} + {} > {})", _n_past, curTokenCount, _ctx->size());
253
+ callback("", Sentence::END);
254
+ return true;
255
+ }
256
+
257
+ if (!engine.process(embedding_vectors, {}, logits))
258
+ return Dialog::abort("engine prompt processing failed", callback);
259
+
260
+ _n_prompt += curTokenCount;
261
+ _n_past += curTokenCount;
262
+
263
+ _prompt_len = _n_past;
264
+
265
+ if (!engine.updateKV(_n_past)) return Dialog::abort("context size exceeded", callback);
266
+
267
+ std::vector<std::vector<int32_t>> streams;
268
+ getTopK(logits, streams, _n_streams, _p_threshold, callback);
269
+
270
+ _n_generated += streams.size();
271
+ _kpis.prompt.update(start.elapsed_usec());
272
+
273
+ // Log latest KPIs
274
+ _env->logger().post(Logger::KPIS, kpis().dump(" "));
275
+
276
+ start.reset();
277
+
278
+ bool status = processFollowOnGeneration(streams, logits, callback);
279
+
280
+ _kpis.generate.update(start.elapsed_usec());
281
+
282
+ // Log latest KPIs in a single line
283
+ _env->logger().post(Logger::KPIS, kpis().dump(" "));
284
+
285
+ return status;
286
+ }
287
+
288
+ // Registrator instance
289
+ static OnLoad regy([]() {
290
+ Dialog::__register(
291
+ "multistream",
292
+ [](std::shared_ptr<Env> env, const std::string& name, const json& conf) {
293
+ return (Dialog*)new MultiStreamDialog(env, name, conf);
294
+ }
295
+ );
296
+ });
297
+
298
+ void needMultistreamDialog() {}
299
+
300
+ } // namespace qualla
Genie/Genie/src/qualla/dialogs/spec-dec.cpp ADDED
@@ -0,0 +1,458 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ //==============================================================================
2
+ //
3
+ // Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4
+ // All rights reserved.
5
+ // Confidential and Proprietary - Qualcomm Technologies, Inc.
6
+ //
7
+ //==============================================================================
8
+
9
+ #include <qualla/dialog.hpp>
10
+ #include <qualla/sampler.hpp>
11
+ #include <qualla/logger.hpp>
12
+ #include <qualla/detail/config.hpp>
13
+ #include <qualla/detail/timer.hpp>
14
+ #include <qualla/detail/onload.hpp>
15
+ #include <qualla/detail/sampler-utils.hpp>
16
+ #include <qualla/detail/basic-sampler.hpp>
17
+
18
+ #include <functional>
19
+ #include <fstream>
20
+ #include <string>
21
+ #include <unordered_map>
22
+ #include <filesystem>
23
+ #include <random>
24
+ #include <thread>
25
+
26
+ #include <fmt/format.h>
27
+ #include <fmt/ranges.h>
28
+
29
+ namespace fs = std::filesystem;
30
+
31
+ #define __INFO(__fmt, ...) _env->logger().post(Logger::INFO, fmt::format(__fmt, ##__VA_ARGS__))
32
+ #define __WARN(__fmt, ...) _env->logger().post(Logger::WARN, fmt::format(__fmt, ##__VA_ARGS__))
33
+ #define __ERROR(__fmt, ...) _env->logger().post(Logger::ERROR, fmt::format(__fmt, ##__VA_ARGS__))
34
+ #define __KPIS(__fmt, ...) \
35
+ _env->logger().post(Logger::KPIS, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
36
+ #define __DEBUG(__fmt, ...) \
37
+ _env->logger().post(Logger::DEBUG, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
38
+ #define __TRACE(__fmt, ...) \
39
+ _env->logger().post(Logger::TRACE, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
40
+
41
+ namespace qualla {
42
+
43
+ using qc = qualla::Config;
44
+
45
+ class SpecDecDialog : public Dialog {
46
+ public:
47
+ SpecDecDialog(std::shared_ptr<Env> env, const std::string& name, const json& conf);
48
+
49
+ virtual bool process(std::vector<int32_t>& tokens, Dialog::Callback callback) override;
50
+
51
+ virtual bool process(std::vector<int32_t>& tokens, DialogCallback callback) override {
52
+ return false;
53
+ }
54
+
55
+ private:
56
+ int32_t _draft_len; // Number of draft tokens
57
+ bool _parallel; // Enable parallel processing (where possible)
58
+
59
+ Sampler& _d_sampler; // Draft sampler
60
+ Sampler& _t_sampler; // Target sampler
61
+
62
+ // Token acceptor, called for each accepted token.
63
+ // Returns true to continue, false to stop
64
+ using Acceptor = std::function<bool(int32_t token)>;
65
+
66
+ // Rejection sampling.
67
+ // Returns number of accepted tokens
68
+ size_t rejectionSampling(
69
+ std::span<int32_t> tokens,
70
+ std::span<float> target_logits,
71
+ std::span<float> draft_probs,
72
+ Acceptor accept
73
+ );
74
+
75
+ int32_t sampleFromModifiedDist(std::span<float> src0_dst, std::span<float> src1);
76
+ };
77
+
78
+ SpecDecDialog::SpecDecDialog(std::shared_ptr<Env> env, const std::string& name, const json& conf)
79
+ : Dialog(env, name, conf),
80
+ _d_sampler(_sampler.contains("draft") ? *_sampler["draft"] : *_sampler["target"]),
81
+ _t_sampler(*_sampler["target"]) {
82
+
83
+ _draft_len = qc::optional<int32_t>(conf, "draft-len", 3);
84
+ _parallel = qc::optional<bool>(conf, "parallel", false);
85
+
86
+ // Check all underlying components for correct types an config
87
+ // If something is not right we set our error state that can be checked later
88
+
89
+ if (!_sampler.contains("target")) {
90
+ State::fatal("\"target\" sampler not present in config!");
91
+ return;
92
+ }
93
+
94
+ if (!_engine.contains("target")) {
95
+ State::fatal("\"target\" engine not present in config!");
96
+ return;
97
+ }
98
+ if (!_engine.contains("draft")) {
99
+ State::fatal("\"draft\" engine not present in config!");
100
+ return;
101
+ }
102
+ }
103
+
104
+ int32_t SpecDecDialog::sampleFromModifiedDist(std::span<float> src0_dst, std::span<float> src1) {
105
+ // [max(prob_target[x] - prob_draft[x], 0.f) for all x in vocab]
106
+ size_t size = src0_dst.size();
107
+
108
+ if (_t_sampler.gumbel()) {
109
+ // Avoid going in the denormal zone.
110
+ float tiny = 1.1754943508222875e-38;
111
+
112
+ #pragma clang loop vectorize(enable) unroll_count(4)
113
+ for (size_t i = 0U; i < size; i++) {
114
+ float p_src0 = std::exp(src0_dst[i]);
115
+ float p_src1 = std::exp(src1[i]);
116
+ src0_dst[i] = std::log(std::max(tiny, p_src0 - p_src1));
117
+ }
118
+
119
+ // NOTE: The output logps_target is unnormalized since we use Gumbel trick.
120
+ // If we use standard multinomial sampling, normalization should be added.
121
+
122
+ } else {
123
+ float sum = 0.0; // Unlikely to overflow (?)
124
+ #pragma clang loop vectorize(enable) unroll_count(4)
125
+ for (size_t i = 0U; i < size; i++) {
126
+ float num = std::max(0.f, src0_dst[i] - src1[i]);
127
+ sum += num;
128
+ src0_dst[i] = num;
129
+ }
130
+ // Normalize
131
+ #pragma clang loop vectorize(enable) unroll_count(4)
132
+ for (size_t i = 0U; i < size; i++) {
133
+ src0_dst[i] /= sum;
134
+ }
135
+ }
136
+
137
+ if (_t_sampler.greedy()) return argmax(src0_dst);
138
+
139
+ if (_t_sampler.gumbel()) return sampleUsingGumbelMax(src0_dst, _t_sampler.rng());
140
+
141
+ // Skipping softmax since the probs are already normalized
142
+ return sampleFromProbs(src0_dst, _t_sampler.rng());
143
+ }
144
+
145
+ size_t SpecDecDialog::rejectionSampling(
146
+ std::span<int32_t> tokens,
147
+ std::span<float> target_logits,
148
+ std::span<float> draft_probs,
149
+ Acceptor accept
150
+ ) {
151
+ const size_t n_vocab = _ctx->n_vocab();
152
+ const size_t n_tok = tokens.size();
153
+
154
+ assert(tokens.size() == draft_probs.size() / n_vocab);
155
+ assert(target_logits.size() == draft_probs.size() + n_vocab);
156
+
157
+ // Rejection sampling:
158
+ // For each token in the n_tok tokens sampled from the draft model:
159
+ // 1. Determine the probability of that token being accepted by the target model
160
+ // 2. Accept the token with probability = prob_target[tok] / prob_draft[tok] (clamped to [0, 1])
161
+ // 3. If the token is rejected, resample a new token from the following distribution:
162
+ // [max(prob_target[x] - prob_draft[x], 0.f) for all x in vocab]
163
+ int32_t t_tok;
164
+ size_t n_accepted = 0;
165
+
166
+ std::vector<float> target_probs;
167
+
168
+ for (int32_t i = 0; i < n_tok; i++) {
169
+ int32_t d_tok = tokens[i];
170
+
171
+ std::span<float> t_span = target_logits.subspan(i * n_vocab, n_vocab);
172
+
173
+ if (_t_sampler.greedy()) {
174
+ t_tok = _t_sampler.process(t_span);
175
+ if (t_tok != d_tok) {
176
+ // Reject
177
+ break;
178
+ }
179
+ } else {
180
+ target_probs.clear();
181
+ t_tok = _t_sampler.process(t_span, target_probs, false); // only probs, no token
182
+
183
+ // Acceptance threshold
184
+ double threshold;
185
+ float prob_draft = draft_probs[i * n_vocab + d_tok];
186
+ float prob_target = target_probs[d_tok];
187
+
188
+ if (_t_sampler.gumbel()) {
189
+ threshold = std::exp(double(prob_target) - double(prob_draft));
190
+ } else {
191
+ threshold = double(prob_target) / double(prob_draft);
192
+ }
193
+
194
+ double r = sampleFromUniform(_t_sampler.rng());
195
+ if (r > threshold) {
196
+ // Reject
197
+ break;
198
+ }
199
+ }
200
+ // Accepted!
201
+ ++n_accepted;
202
+ if (!accept(d_tok)) return n_accepted;
203
+ }
204
+
205
+ // Sample an extra token either from the target distribution or the modified distribution
206
+ if (n_accepted == n_tok) {
207
+ t_tok = _t_sampler.process(target_logits.subspan(n_tok * n_vocab));
208
+ } else if (!_t_sampler.greedy()) {
209
+ // Resample from modified distribution.
210
+ t_tok = sampleFromModifiedDist(
211
+ std::span{target_probs.data(),target_probs.size()}, draft_probs.subspan(n_accepted * n_vocab, n_vocab)
212
+ );
213
+ } // for greedy, t_tok should be already valid from the loop above
214
+
215
+ ++n_accepted;
216
+ accept(t_tok);
217
+
218
+ return n_accepted;
219
+ }
220
+
221
+ bool SpecDecDialog::process(std::vector<int32_t>& tokens, Dialog::Callback callback) {
222
+
223
+ // Check for prev failures and bail out early
224
+ if (State::failed()) return false;
225
+
226
+ Timer start;
227
+
228
+ const size_t n_vocab = _ctx->n_vocab();
229
+
230
+ // Vector for storing logits.
231
+ // Allocated & filled by the engine.
232
+ std::vector<float> t_logits;
233
+ std::vector<float> d_logits;
234
+
235
+ bool keep_generating = true;
236
+
237
+ // A buffer for tokens to be decoded (one at a time, per the Middleware's request)
238
+ std::vector<int32_t> decode_buf(1, 0);
239
+
240
+ // Decode new token.
241
+ // Return true to continue generation, and false otherwise
242
+ auto decode_token = [&](int32_t t) {
243
+ decode_buf[0] = _last_tok = t;
244
+
245
+ if (_ctx->is_eos(t)) {
246
+ keep_generating = false;
247
+ callback("", Sentence::END);
248
+ } else {
249
+ keep_generating = callback(_tokenizer->decode(decode_buf), Sentence::CONTINUE);
250
+ }
251
+
252
+ return keep_generating;
253
+ };
254
+
255
+ State::clear();
256
+
257
+ auto& t_engine = *_engine["target"];
258
+ auto& d_engine = *_engine["draft"];
259
+
260
+ if (_n_past + tokens.size() > _ctx->size()) {
261
+ __WARN("Context limit exceeded ({} + {} > {})", _n_past, tokens.size(), _ctx->size());
262
+ callback("", Sentence::END);
263
+ return true;
264
+ }
265
+
266
+ // Step 0: Process the prompt both on the target and draft models.
267
+ bool d_pmpt, t_pmpt;
268
+ if (_parallel) {
269
+ std::thread dt([&]() { d_pmpt = d_engine.process(tokens, d_logits, false); });
270
+ std::thread tt([&]() { t_pmpt = t_engine.process(tokens, t_logits, false); });
271
+ dt.join();
272
+ tt.join();
273
+ } else {
274
+ d_pmpt = d_engine.process(tokens, d_logits, false);
275
+ t_pmpt = t_engine.process(tokens, t_logits, false);
276
+ }
277
+
278
+ if (!d_pmpt) return Dialog::abort("draft engine prompt processing failed", callback);
279
+ if (!t_pmpt) return Dialog::abort("target engine prompt processing failed", callback);
280
+
281
+ // KV state Update
282
+ _n_prompt += tokens.size();
283
+ _n_past += tokens.size();
284
+
285
+ if (!t_engine.updateKV(_n_past)) return Dialog::abort("target context size exceeded", callback);
286
+ if (!d_engine.updateKV(_n_past)) return Dialog::abort("draft context size exceeded", callback);
287
+
288
+ // Sample one token from the target.
289
+ _last_tok = _t_sampler.process(t_logits);
290
+
291
+ _kpis.prompt.update(start.elapsed_usec());
292
+
293
+ // Log latest KPIs
294
+ _env->logger().post(Logger::KPIS, kpis().dump(" "));
295
+
296
+ if (!decode_token(_last_tok)) return true;
297
+
298
+ // Done with the prompt, start generating
299
+ start.reset();
300
+ State::busy(true);
301
+
302
+ // Buffers for all the tokens that need to be considered for each iteration
303
+ std::vector<int32_t> toks_to_target(_draft_len + 1);
304
+ std::vector<int32_t> toks_to_draft(2);
305
+
306
+ // Buffer for all the probability distributions from the draft sampler
307
+ std::vector<float> d_probs(n_vocab * _draft_len);
308
+
309
+ toks_to_target.assign(1, _last_tok);
310
+ toks_to_draft.assign(1, _last_tok);
311
+
312
+ // For keeping track of the number of tokens that were accepted in each iteration.
313
+ std::vector<int32_t> n_accepted_counts(_draft_len + 1, 0);
314
+
315
+ // Draft n_past, either in sync with n_past or one token behind (accepted-all)
316
+ size_t d_n_past = _n_past;
317
+
318
+ while (!State::canceled() && keep_generating) {
319
+ // Step 1: Use draft model to decode draft_len (aka gamma) tokens, and accumulate probabilities
320
+ d_probs.clear();
321
+
322
+ for (int32_t i = 0; i < _draft_len; i++) {
323
+ if (d_n_past + toks_to_draft.size() > _ctx->size()) {
324
+ __WARN("Context limit exceeded ({} + {} > {})",
325
+ d_n_past,
326
+ toks_to_target.size(),
327
+ _ctx->size());
328
+ _kpis.generate.update(start.elapsed_usec());
329
+
330
+ // Log latest KPIs in a single line
331
+ _env->logger().post(Logger::KPIS, kpis().dump(" "));
332
+ callback("", Sentence::END);
333
+ return true;
334
+ }
335
+
336
+ if (!d_engine.process(toks_to_draft, d_logits))
337
+ return Dialog::abort("draft engine gen processing failed", callback);
338
+
339
+ d_n_past += toks_to_draft.size();
340
+
341
+ if (!d_engine.updateKV(d_n_past))
342
+ return Dialog::abort("draft context size exceeded", callback);
343
+
344
+ int32_t token = _d_sampler.process(d_logits, d_probs);
345
+ toks_to_draft.assign(1, token);
346
+ toks_to_target.push_back(token);
347
+
348
+ if (_ctx->is_eos(token)) break;
349
+ }
350
+
351
+ // Step 2: run the target model on the draft tokens
352
+ if (_n_past + toks_to_target.size() > _ctx->size()) {
353
+ __WARN("Context limit exceeded ({} + {} > {})",
354
+ _n_past,
355
+ toks_to_target.size(),
356
+ _ctx->size());
357
+ callback("", Sentence::END);
358
+ _kpis.generate.update(start.elapsed_usec());
359
+
360
+ // Log latest KPIs in a single line
361
+ _env->logger().post(Logger::KPIS, kpis().dump(" "));
362
+ return true;
363
+ }
364
+
365
+ std::vector<int32_t> attention_map(toks_to_target.size());
366
+ std::iota(attention_map.begin(), attention_map.end(), -1);
367
+ size_t n_tok_t =
368
+ t_engine.process(toks_to_target, attention_map, t_logits, true /* all logits */);
369
+ if (n_tok_t != toks_to_target.size())
370
+ return Dialog::abort("target engine gen processing failed", callback);
371
+
372
+ // Step 3: accept or reject draft tokens
373
+ size_t n_accepted = rejectionSampling(
374
+ std::span{toks_to_target.data(),toks_to_target.size()}.subspan(1),
375
+ std::span{t_logits.data(),t_logits.size()}, std::span{d_probs.data(),d_probs.size()}, decode_token
376
+ );
377
+
378
+ _n_generated += n_accepted;
379
+ _n_past += n_accepted;
380
+
381
+ // Update stats
382
+ n_accepted_counts[n_accepted - 1]++;
383
+
384
+ // Accepted all?
385
+ if (n_accepted == _draft_len + 1) {
386
+ // Grab the last 2 tokens
387
+ toks_to_draft.assign({toks_to_target[_draft_len], _last_tok});
388
+ d_n_past = _n_past - 1;
389
+ } else {
390
+ // Grab only the last token
391
+ toks_to_draft.assign(1, _last_tok);
392
+ d_n_past = _n_past;
393
+ }
394
+
395
+ toks_to_target.assign(1, _last_tok);
396
+
397
+ __DEBUG("spec-dec: draft_len {} n_generated {} n_accepted {} n_past {}",
398
+ _draft_len,
399
+ _n_generated,
400
+ n_accepted,
401
+ _n_past);
402
+
403
+ std::vector<bool> selected(attention_map.size(), false);
404
+ selected[0] = true; // first token is selected always
405
+ auto last_sel = 0;
406
+ for (int i = n_accepted - 1; i != 0; i = attention_map[i]) {
407
+ selected[i] = true;
408
+ last_sel = i > last_sel ? i : last_sel;
409
+ }
410
+ selected.resize(last_sel + 1); // trim away rejected tokens
411
+
412
+ // Step 4: commit accepted tokens to kv-caches
413
+ if (!t_engine.updateKV(_n_past, selected))
414
+ return Dialog::abort("target context size exceeded", callback);
415
+ if (!d_engine.updateKV(d_n_past))
416
+ return Dialog::abort("draft context size exceeded", callback);
417
+ }
418
+
419
+ if (d_n_past != _n_past) {
420
+ // The draft engine needs to process one last token to catch up
421
+ toks_to_draft.resize(1);
422
+ if (!d_engine.process(toks_to_draft))
423
+ return Dialog::abort("draft engine gen processing failed", callback);
424
+ if (!d_engine.updateKV(_n_past))
425
+ return Dialog::abort("draft context size exceeded", callback);
426
+ }
427
+
428
+ State::busy(false);
429
+
430
+ _kpis.generate.update(start.elapsed_usec());
431
+
432
+ // Log latest KPIs in a single line
433
+ _env->logger().post(Logger::KPIS, kpis().dump(" "));
434
+ __KPIS("spec-dec: accepted counts: {}", n_accepted_counts);
435
+
436
+ return true;
437
+ }
438
+
439
+ // Registrator instance
440
+ static OnLoad regy([]() {
441
+ Dialog::__register(
442
+ "spec-dec",
443
+ [](std::shared_ptr<Env> env, const std::string& name, const json& conf) {
444
+ return (Dialog*)new SpecDecDialog(env, name, conf);
445
+ }
446
+ );
447
+ });
448
+
449
+ // Register spec-dec sampler for compatibility
450
+ static OnLoad sampler_regy([]() {
451
+ Sampler::__register("spec-dec", [](Context& ctx, const json& conf) {
452
+ return (Sampler*)new BasicSampler(ctx, conf);
453
+ });
454
+ });
455
+
456
+ void needSpdDialog() {}
457
+
458
+ } // namespace qualla
Genie/Genie/src/qualla/dialogs/ssd-q1.cpp ADDED
@@ -0,0 +1,1046 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ //==============================================================================
2
+ //
3
+ // Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4
+ // All rights reserved.
5
+ // Confidential and Proprietary - Qualcomm Technologies, Inc.
6
+ //
7
+ //==============================================================================
8
+
9
+ #include <qualla/context.hpp>
10
+ #include <qualla/dialog.hpp>
11
+ #include <qualla/sampler.hpp>
12
+ #include <qualla/logger.hpp>
13
+ #include <qualla/detail/config.hpp>
14
+ #include <qualla/detail/json.hpp>
15
+ #include <qualla/detail/timer.hpp>
16
+ #include <qualla/detail/onload.hpp>
17
+ #include <qualla/detail/sampler-utils.hpp>
18
+ #include <qualla/detail/basic-sampler.hpp>
19
+
20
+ #include <functional>
21
+ #include <fstream>
22
+ #include <string>
23
+ #include <unordered_map>
24
+ #include <filesystem>
25
+ #include <random>
26
+
27
+ #include <fmt/format.h>
28
+ #include <fmt/ranges.h>
29
+
30
+ namespace fs = std::filesystem;
31
+
32
+ #define __INFO(__fmt, ...) _env->logger().post(Logger::INFO, fmt::format(__fmt, ##__VA_ARGS__))
33
+ #define __WARN(__fmt, ...) _env->logger().post(Logger::WARN, fmt::format(__fmt, ##__VA_ARGS__))
34
+ #define __ERROR(__fmt, ...) _env->logger().post(Logger::ERROR, fmt::format(__fmt, ##__VA_ARGS__))
35
+ #define __KPIS(__fmt, ...) \
36
+ _env->logger().post(Logger::KPIS, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
37
+ #define __DEBUG(__fmt, ...) \
38
+ _env->logger().post(Logger::DEBUG, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
39
+ #define __TRACE(__fmt, ...) \
40
+ _env->logger().post(Logger::TRACE, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
41
+
42
+ namespace qualla {
43
+
44
+ using qc = qualla::Config;
45
+ using Logits = std::span<float>;
46
+
47
// Self-speculative decoding (SSD) dialog: the model drafts its own candidate
// tokens via reserved "forecast" token IDs and a pre-restored KV prefix, then
// verifies them in the same inference. Supports single-stream and multistream
// (top-K parallel continuations) generation.
class SelfSpecDecDialog : public Dialog {
    // Highest "ssd-version" config value this implementation understands.
    enum { VERSION = 1 };

public:
    SelfSpecDecDialog(std::shared_ptr<Env> env, const std::string& name, const json& conf);

    // Token-input entry point (prompt already tokenized).
    virtual bool process(std::vector<int32_t>& tokens, Dialog::Callback callback) override;
    // Embedding-input entry point; t2eCallback converts a token id to its embedding.
    virtual bool process(std::vector<uint8_t>& embedding_vectors, Dialog::T2ECallback t2eCallback, Dialog::Callback callback) override;
    virtual void reset() override;

    // DialogCallback variant is intentionally unsupported for SSD.
    virtual bool process(std::vector<int32_t>& tokens, DialogCallback callback) override {
        return false;
    }

    virtual bool save(const std::string& name) override;
    virtual bool restore(const std::string& name) override;

private:
    Sampler& _t_sampler; // target sampler ("primary")

    int32_t _vocab; // vocabulary size; also the offset of the forecast token range

    std::string _kv_prefix_name{"forecast-prefix"}; // name of the saved KV prefix to restore

    // AR8
    size_t _draft{1};                // number of draft levels (== _branches.size())
    std::vector<size_t> _branches{3}; // branching factor per draft level

    size_t _forecast_prefix{16};       // length of the restored KV prefix
    size_t _forecast_token_offset{32000}; // overwritten with _vocab in the ctor
    size_t _forecast_token_count{4};

    // Multistream parameters
    int32_t _n_streams;  // >1 enables multistream generation
    float _p_threshold;  // probability threshold for stream selection

    InputType m_inputType{InputType::UNKNOWN}; // TOKENS or EMBEDDINGS, from the engine

    bool processFollowOnGeneration(std::vector<int32_t>& tokens, std::vector<float>& logits, Dialog::Callback callback);
    // Multistream
    bool processFollowOnGeneration(std::vector<std::vector<int32_t>>& streams, std::vector<float>& logits, Dialog::Callback callback);

    /*
       Helper function for combining masks for SSD multistream.

       @param mask The attention mask to be tiled
       @param streamIndices Indices of streams. The tiling count is equal to the size of this vector.
       @param pastMap A vector of stream indices for masking all past tokens after the prompt.
       @param prefixOffset Offset where KV prefix masking begins in each tile.
       @param finalMask A mask that combines all of the independent masks such that
              they can be executed in the same inference.
    */
    void tileAttentionMask(const std::vector<int32_t>& mask, const std::vector<size_t> streamIndices, const std::vector<size_t>& pastMap, const size_t prefixOffset, std::vector<int32_t>& finalMask);

    // Parent-index attention map for the flattened verify+forecast tree.
    std::vector<int32_t> gen_attention_map() const;
    // Node count of the flattened sample tree (root + all branch levels).
    auto get_len_flat_sample_tree() const;
    // `repeat` copies of the forecast-token id sequence [offset .. offset+_draft).
    auto gen_forecast_tokens(int repeat) const;

    // Sampling and verification
    std::vector<int32_t> build_sample_tree(
        int32_t last_token,
        Logits logits,
        const std::vector<int32_t>& indices
    );
    std::tuple<std::vector<int32_t>, std::vector<int32_t>> verify_and_select_longest(
        std::span<int32_t> sample_tree,
        Logits logits
    );
    // Top-`count` token ids from the logits row at `index` (draft candidates).
    std::vector<int32_t> sample_to_draft(Logits logits, size_t index, size_t count) {
        const auto thislogit = logits.subspan(index * _vocab, _vocab);
        IndexedLogits logit(thislogit, _t_sampler.rng());
        logit.topK(count);
        return logit.indices;
    }
    // Single verified token from the logits row at `index` (argmax when greedy,
    // otherwise the configured sampler).
    int32_t sample_to_verify(Logits logits, size_t index) {
        const auto thislogit = logits.subspan(index * _vocab, _vocab);
        if (_t_sampler.greedy()) {
            return argmax(thislogit);
        }
        auto token = _t_sampler.process(thislogit);
        return token;
    }
};
130
+
131
// Construct the SSD dialog: read tree/forecast/multistream parameters from the
// JSON config, then restore the pre-built forecast KV prefix into the primary
// engine. Throws if the restored prefix length does not match the config.
SelfSpecDecDialog::SelfSpecDecDialog(
    std::shared_ptr<Env> env,
    const std::string& name,
    const json& conf
)
    : Dialog(env, name, conf), _t_sampler(*_sampler["primary"]) {

    // Warn (but continue) if the config was written for a newer SSD revision.
    auto ssd_version = qc::optional<int>(conf, "ssd-version", 0);
    if (ssd_version > SelfSpecDecDialog::VERSION) __WARN("newer ssd-version in config!");

    _vocab = _ctx->n_vocab();

    // Draft depth is implied by the per-level branching factors.
    _branches = qc::optional(conf, "branches", _branches);
    _draft = _branches.size();

    _forecast_prefix = qc::optional(conf, "forecast-prefix", _forecast_prefix);
    _forecast_token_count = qc::optional(conf, "forecast-token-count", _forecast_token_count);
    // Forecast token ids live directly after the real vocabulary.
    _forecast_token_offset = _vocab;

    _kv_prefix_name = qc::optional(conf, "forecast-prefix-name", _kv_prefix_name);

    _n_streams = qc::optional<int32_t>(conf, "n-streams", 1);
    _p_threshold = qc::optional<float>(conf, "p-threshold", 0.0);

    if (!_engine.contains("primary")) {
        State::fatal("\"primary\" engine not present in config!");
        return;
    }

    // Get Input Type from the engine
    m_inputType = _engine["primary"]->getInputType();
    // Load KV prefix; a mismatch in restored length is unrecoverable.
    Timer timer;
    size_t n_restored_prefix = _engine["primary"]->restore(_kv_prefix_name);
    if (n_restored_prefix != _forecast_prefix) {
        // clang-format off
        throw std::runtime_error( fmt::format( "SSD : Loaded {} KV$ from {} but expected {} KV$",
                                               n_restored_prefix, _kv_prefix_name, _forecast_prefix ) );
        // clang-format on
    }
    // The restored prefix occupies the first _forecast_prefix cache slots.
    _n_past = _forecast_prefix;
    _kpis.restore.update(timer.elapsed_usec());
}
174
+
175
+ auto SelfSpecDecDialog::get_len_flat_sample_tree() const {
176
+ size_t len_flat_sample_tree = 1;
177
+ size_t last_tokens = 1;
178
+ for (int i = 0; i < _draft; ++i) {
179
+ len_flat_sample_tree += last_tokens * _branches[i];
180
+ last_tokens = last_tokens * _branches[i];
181
+ }
182
+ return len_flat_sample_tree;
183
+ }
184
+
185
+ auto SelfSpecDecDialog::gen_forecast_tokens(int repeat) const {
186
+ std::vector<int32_t> forecast_tokens(_draft, 0);
187
+ std::iota(forecast_tokens.begin(), forecast_tokens.end(), _forecast_token_offset);
188
+
189
+ std::vector<int32_t> ret;
190
+ for (auto i = 0; i < repeat; ++i)
191
+ ret.insert(ret.end(), forecast_tokens.begin(), forecast_tokens.end());
192
+ return ret;
193
+ }
194
+
195
// Build the parent-index attention map for one flattened tile:
// entry j = index of the token that token j attends back to (-1 = none, i.e.
// only the implicit self/past attention applies). Layout is the verify tree
// (len_flat_sample_tree entries) followed by one _draft-long forecast chain
// per tree node.
std::vector<int32_t> SelfSpecDecDialog::gen_attention_map() const {
    auto len_flat_sample_tree = get_len_flat_sample_tree();
    // Verify-tree nodes + one forecast chain of length _draft per node; -1 marks "no parent".
    std::vector<int32_t> attention_map(len_flat_sample_tree + len_flat_sample_tree * _draft, -1);

    // Breadth-first over draft levels: each node on [parent_begin, parent_end)
    // gets _branches[level] children laid out contiguously after parent_end.
    auto build_verify_tree = [&attention_map,
                              this](auto self, int parent_begin, int parent_end, int level) {
        if (level == _draft) return;
        auto current = parent_end;
        for (auto parent = parent_begin; parent < parent_end; parent += 1) {
            for (auto child = current; child < current + _branches[level]; child += 1)
                attention_map[child] = parent;
            current += _branches[level];
        }
        self(self, parent_end, current, level + 1);
    };

    // Each verify-tree node additionally roots a linear chain of _draft
    // forecast tokens; each chain element attends to its predecessor.
    auto build_forecast_tree = [&attention_map, this](int parent_begin, int parent_end) {
        auto current = parent_end;
        for (auto parent = parent_begin; parent < parent_end; parent += 1) {
            for (auto child = current, current_parent = parent; child < current + _draft;
                 child += 1) {
                attention_map[child] = current_parent;
                current_parent = child; // chain: next element hangs off this one
            }
            current += _draft;
        }
    };

    build_verify_tree(build_verify_tree, 0, 1, 0); // root is node 0
    build_forecast_tree(0, len_flat_sample_tree);
    return attention_map;
}
227
+
228
+ std::vector<int32_t> SelfSpecDecDialog::build_sample_tree(
229
+ int32_t last_token,
230
+ Logits logits,
231
+ const std::vector<int32_t>& indices
232
+ ) {
233
+ std::vector<int32_t> tree = {last_token};
234
+ for (auto draft = 0, repeat = 1; draft < _draft; ++draft) {
235
+ auto samples = sample_to_draft(logits, indices[draft], _branches[draft]);
236
+ for (auto i = 0; i < repeat; ++i) {
237
+ tree.insert(tree.end(), samples.begin(), samples.end());
238
+ }
239
+ repeat *= _branches[draft];
240
+ }
241
+ return tree;
242
+ }
243
+
244
// Walk the speculation tree against the freshly computed logits and return the
// longest accepted path: {accepted tokens (incl. one newly sampled token past
// the last match), node indices of the matched path within the flat tile}.
// A drafted node is accepted when it equals the token the verifier samples at
// its parent, and the parent is not EOS.
std::tuple<std::vector<int32_t>, std::vector<int32_t>> SelfSpecDecDialog::verify_and_select_longest(
    std::span<int32_t> sample_tree,
    Logits logits
) {
    // Seed with the root: node 0 always yields one verified token.
    std::vector<std::vector<int32_t>> accepted_all = {{sample_to_verify(logits, 0)}};
    std::vector<std::vector<int32_t>> node_ids_all = {{0}};

    // draft_offset[l] = flat index of the first node on draft level l.
    std::vector<int32_t> draft_offset(_draft, 0);
    draft_offset[0] = 1;
    for (int32_t i = 1, draft_count = _branches[0]; i < _draft; ++i) {
        draft_offset[i] = draft_offset[i - 1] + draft_count;
        draft_count = draft_count * _branches[i];
    }

    size_t longest = 0, longest_size = 1; // index/length of the best path so far
    // DFS: `accepted`/`node_ids` are passed by value on purpose — each branch
    // extends its own copy; results accumulate in accepted_all/node_ids_all.
    auto verify_recursive = [&](auto self,
                                std::vector<int32_t> accepted,
                                std::vector<int32_t> node_ids,
                                int draft,
                                int offset_in_draft) -> void {
        auto target = accepted.back(); // token the verifier produced at the parent
        auto branch_base = draft_offset[draft] + offset_in_draft;
        for (auto branch = 0; branch < _branches[draft]; ++branch) {
            auto ndx_node = branch_base + branch;
            // Accept only if the draft guessed the verifier's token (and not EOS).
            if (!_ctx->is_eos(target) && target == sample_tree[ndx_node]) {
                auto sample_accepted = sample_to_verify(logits, ndx_node);
                accepted_all.push_back(accepted);
                accepted_all.back().push_back(sample_accepted);
                node_ids_all.push_back(node_ids);
                node_ids_all.back().push_back(ndx_node);
                if (node_ids_all.back().size() > longest_size) {
                    longest = node_ids_all.size() - 1;
                    longest_size = node_ids_all.back().size();
                }
                if (draft + 1 < _draft)
                    self(self,
                         accepted_all.back(),
                         node_ids_all.back(),
                         draft + 1,
                         // children of (offset_in_draft + branch) on the next level
                         (offset_in_draft + branch) * _branches[draft + 1]);
            }
        }
    };
    verify_recursive(verify_recursive, accepted_all.back(), node_ids_all.back(), 0, 0);
    return {accepted_all[longest], node_ids_all[longest]};
}
290
+
291
// Tile the per-stream parent-index attention map into one dense 0/1 mask so
// all streams run in a single inference. Output layout: one row per input
// token, rowLength = _n_past + numTokens columns, row-major in `tiledMask`.
// Each row: [KV-prefix cols | prompt cols | past-token cols | per-tile tree cols].
void SelfSpecDecDialog::tileAttentionMask(const std::vector<int32_t>& mask, const std::vector<size_t> streamIndices, const std::vector<size_t>& pastMap, const size_t prefixOffset, std::vector<int32_t>& tiledMask) {

    // NOTE(review): sampleTreeLen is computed but never used below — confirm
    // whether it can be removed.
    const size_t sampleTreeLen = get_len_flat_sample_tree();
    const size_t pastMapLen = pastMap.size();
    const int posVal = 1, negVal = 0; // attend / don't attend

    const size_t maskSize = mask.size();
    // Total new tokens across all tiles (one tile per stream in streamIndices).
    const size_t numTokens = maskSize * streamIndices.size();

    const size_t rowLength = _n_past + numTokens;
    tiledMask.resize(numTokens * rowLength);

    for (int maskIdx = 0; maskIdx < streamIndices.size(); maskIdx++) {
        // Number of rows to skip to reach the current tile.
        const size_t tileOffset = maskIdx * maskSize;
        // Top-left corner of this tile's diagonal block (new-token columns).
        int32_t* const tileStart = &tiledMask[tileOffset*rowLength + tileOffset + _n_past];
        for (int i = 0; i < maskSize; i++) {
            // Pointer to the start of row i of the current mask
            int32_t* rowPtr = &tiledMask[(tileOffset + i)*rowLength];
            // Skip kv-prefix attention for rows without speculative tokens.
            const int prefixFillVal = (i < prefixOffset) ? negVal : posVal;
            std::fill_n(rowPtr, _forecast_prefix, prefixFillVal);
            rowPtr += _forecast_prefix;
            // Always attend to prompt.
            std::fill_n(rowPtr, _n_prompt, posVal);
            rowPtr += _n_prompt;

            // Fill in the past valid tokens for this stream: attend only to
            // past tokens that belong to the same stream.
            for (const size_t& pastIdx : pastMap) {
                *rowPtr = (pastIdx == streamIndices[maskIdx]) ? posVal : negVal;
                rowPtr++;
            }

            // Clear the rest of the row. It will mostly consist of 0's.
            std::fill_n(rowPtr, rowLength - _n_prompt - _forecast_prefix - pastMapLen, negVal);
            // Move to the correct tile.
            rowPtr += tileOffset;
            // Translate the mask: inherit the parent row's new-token columns
            // (tokenId + 1 entries) so a child attends to its whole ancestry.
            const auto tokenId = mask[i];
            if (tokenId > -1) {
                std::copy_n(tileStart + (tokenId * rowLength), tokenId + 1, rowPtr);
            }
            // Always attend to self.
            rowPtr[i] = posVal;
        }
    }
}
338
+
339
+ // Takes a vector of tokens and produces a vector of embeddings via the provided T2E callback.
340
+ static inline void convertTokensToEmbeddings(std::vector<int32_t>& tokens,
341
+ std::vector<uint8_t>& embeddings,
342
+ size_t embeddingBufferSize,
343
+ Dialog::T2ECallback t2eCallback) {
344
+ for(auto &token : tokens){
345
+ std::vector<uint8_t> embedding(embeddingBufferSize,0);
346
+ t2eCallback(token, embedding.data(), embeddingBufferSize);
347
+ embeddings.insert(embeddings.end(), embedding.begin(), embedding.end());
348
+ }
349
+ }
350
+
351
// Single-stream follow-on generation loop. On entry `logits` holds the output
// of the initial (prompt + forecast) inference; each iteration drafts a sample
// tree, runs one inference over tree + forecast tokens, verifies, commits the
// accepted tokens to the KV cache, and streams them to `callback`.
bool SelfSpecDecDialog::processFollowOnGeneration(std::vector<int32_t>& tokens, std::vector<float>& logits, Dialog::Callback callback){

    // Handles the printing of the subsequent generated tokens
    bool keep_generating = true;
    // NOTE(review): `context` is unused below — confirm whether it can be removed.
    const size_t context = _ctx->n_ctx();

    std::vector<int32_t> decode_buf(
        1, 0
    ); // A buffer for tokens to be decoded (one at a time, per the Middleware's request)
    auto decode_token = [&](int32_t t) {
        if (!keep_generating) return;
        // Decode new token.
        // Return true to continue generation, and false otherwise
        decode_buf[0] = _last_tok = t;
        ++_n_generated;
        if (_ctx->is_eos(t)) {
            keep_generating = false;
            callback("", Sentence::END);
        } else {
            keep_generating = callback(_tokenizer->decode(decode_buf), Sentence::CONTINUE);
        }
        return;
    };
    // set decode_buf from prompt processing
    decode_buf[0] = _last_tok;

    auto& engine = *_engine["primary"];

    auto update_kv = [&engine, &callback, this](size_t past, const std::vector<bool>& selected) {
        if (!engine.updateKV(past, selected))
            return Dialog::abort("context size exceeded", callback);
        return true;
    };


    // prepare the next inference: logits rows 1.._draft are the forecast-head
    // outputs of the initial inference.
    std::vector<int32_t> indices(_draft, 0);
    std::iota(indices.begin(), indices.end(), 1);
    tokens = build_sample_tree(sample_to_verify(std::span{logits.data(),logits.size()}, 0), std::span{logits.data(),logits.size()}, indices);
    decode_token(tokens[0]);

    // Prepare constant options for next inferences
    const auto len_flat_sample_tree = get_len_flat_sample_tree();
    const auto forecast_tokens = gen_forecast_tokens(len_flat_sample_tree);
    const auto attention_map = gen_attention_map();

    engine.set({{"kv-prefix-offset", len_flat_sample_tree}});

    // accepted_counts[k] counts iterations that accepted k+1 tokens (KPI only).
    std::vector<int32_t> accepted_counts(_draft + 1, 0);
    std::vector<bool> selected(attention_map.size(), false);

    while (!State::canceled() && keep_generating) {

        // Append forecast tokens
        tokens.insert(tokens.end(), forecast_tokens.begin(), forecast_tokens.end());

        if (_n_past + tokens.size() > _ctx->size()) {
            __WARN("Context limit exceeded ({} + {} > {})", _n_past, tokens.size(), _ctx->size());
            callback("", Sentence::END);
            break;
        }

        size_t n_tok_t = 0;

        // Bifurcate based on embedding as input or token as input
        if (m_inputType == InputType::TOKENS)
            n_tok_t = engine.process(tokens, attention_map, logits, true /* all logits */);
        else if (m_inputType == InputType::EMBEDDINGS) {
            // Convert tokens to embedding for the processing in the engine.
            auto embedBufSize = engine.getEmbeddingBufferSize();
            std::vector<uint8_t> embedding;
            for(auto &token: tokens){
                std::vector<uint8_t> curTokenEmbedding(embedBufSize,0);
                m_t2eCallback(token, curTokenEmbedding.data(), embedBufSize);
                embedding.insert(embedding.end(), curTokenEmbedding.begin(), curTokenEmbedding.end());
            }
            n_tok_t = engine.process(embedding, attention_map, logits, true /* all logits */);
        } else {
            return Dialog::abort("No valid Input Type is used", callback);
        }
        if (n_tok_t != tokens.size()) return Dialog::abort("engine processing failed", callback);

        // Accept tokens
        auto [accepted_tokens, accepted_ids] = verify_and_select_longest(std::span{tokens.data(),tokens.size()},
                                                                         std::span{logits.data(),logits.size()});

        // Commit accepted tokens to kv-caches
        selected.resize(accepted_ids.back() + 1); // trim away rejected tokens
        std::fill(selected.begin(), selected.end(), false);
        for (auto id : accepted_ids)
            selected[id] = true;
        accepted_counts[accepted_tokens.size() - 1] += 1;
        _n_past += accepted_tokens.size();
        update_kv(_n_past, selected);

        // Decode tokens
        std::for_each(accepted_tokens.begin(), accepted_tokens.end(), decode_token);

        // Prepare new tokens: the forecast chain rooted at the last accepted
        // node provides the draft logits for the next tree.
        auto next_draft_offset = len_flat_sample_tree + accepted_ids.back() * _draft;
        std::iota(indices.begin(), indices.end(), next_draft_offset);
        tokens = build_sample_tree(accepted_tokens.back(), std::span{logits.data(),logits.size()}, indices);
    }

    State::busy(false);

    auto total_iteration = std::accumulate(accepted_counts.begin(), accepted_counts.end(), 0);
    auto accept_rate =
        float(_n_generated - 1) / total_iteration; // -1: exclude first generated token
    __KPIS("SSD{{draft:{}, branch:{}, greedy:{}}}: accepted counts: {}, accept rate = {} tokens/iteration",
           _draft,
           _branches,
           _t_sampler.greedy(),
           accepted_counts,
           accept_rate);

    return true;
}
469
+
470
// Multistream AR generation: runs all live streams in one batched inference
// per iteration, using tileAttentionMask to keep their KV attention disjoint.
// Each stream is emitted via `callback` only once it finishes (EOS) or the
// context limit is hit; the loop ends when no streams remain.
bool SelfSpecDecDialog::processFollowOnGeneration(std::vector<std::vector<int32_t>>& streams, std::vector<float>& logits, Dialog::Callback callback) {

    // NOTE(review): `sampler`, `keep_generating` and `context` are unused in
    // this overload — confirm whether they can be removed.
    auto& sampler = *_sampler["primary"];
    auto& engine = *_engine["primary"];

    auto update_kv = [&engine, &callback, this](size_t past, const std::vector<bool>& selected) {
        if (!engine.updateKV(past, selected))
            return Dialog::abort("context size exceeded", callback);
        return true;
    };

    std::vector<size_t> streamIndices(streams.size()); // indices of still-live streams
    std::vector<size_t> past_map(streams.size());      // stream owner of each committed past token

    std::iota(streamIndices.begin(), streamIndices.end(), 0);
    // Since the first inference is done separately, it is
    // expected that each stream already has 1 valid AR token.
    std::iota(past_map.begin(), past_map.end(), 0);

    bool keep_generating = true;
    const size_t context = _ctx->n_ctx();

    if (streams.size() == 0) {
        callback("\n", Sentence::END);
        return true;
    }

    // Prepare constant options for next inferences
    const auto len_flat_sample_tree = get_len_flat_sample_tree();
    const auto forecast_tokens = gen_forecast_tokens(len_flat_sample_tree);
    const auto attention_map = gen_attention_map();

    // Per-stream flattened sample tree for the next inference.
    std::vector<std::vector<int32_t>> draftStreams(streams.size());

    for (int i = 0; i < streams.size(); i++) {
        // prepare the next inference; logits tile i starts at row i*(1+_draft).
        std::vector<int32_t> indices(_draft, 0);
        std::iota(indices.begin(), indices.end(), 1);
        draftStreams[i] = build_sample_tree(sample_to_verify(std::span{logits.data(),logits.size()}, i*(1+_draft)), std::span{logits.data(),logits.size()}, indices);
        streams[i].push_back(draftStreams[i][0]);

    }

    std::vector<int32_t> multi_attn_mask;

    std::vector<int32_t> accepted_counts(_draft + 1, 0);

    engine.set({{"kv-prefix-offset", len_flat_sample_tree}});

    State::busy(true);

    while (true) {
        if (State::canceled()) break;

        // If this exceeds context length, truncate all streams and return
        if (_n_past + streamIndices.size() > _ctx->size()) {
            for (auto stream : streamIndices)
                callback(_tokenizer->decode(streams[stream]) + "\n", Sentence::CONTINUE);
            break;
        }

        // Accumulate input tokens from all streams
        std::vector<int32_t> multi_tokens;
        for (auto streamIdx : streamIndices) {
            multi_tokens.insert(multi_tokens.end(), draftStreams[streamIdx].begin(), draftStreams[streamIdx].end());
            multi_tokens.insert(multi_tokens.end(), forecast_tokens.begin(), forecast_tokens.end());
        }

        if (_n_past + multi_tokens.size() > _ctx->size()) {
            __WARN("Context limit exceeded ({} + {} > {})", _n_past, multi_tokens.size(), _ctx->size());
            callback("", Sentence::END);
            break;
        }

        tileAttentionMask(attention_map, streamIndices, past_map, len_flat_sample_tree, multi_attn_mask);

        size_t n_tok_t = 0;

        if (m_inputType == InputType::TOKENS) {
            // Process input tokens for all streams in one batch
            n_tok_t = engine.process(multi_tokens, multi_attn_mask, logits, true);
        } else if (m_inputType == InputType::EMBEDDINGS) {
            // Accumulate input embeddings from all streams
            auto embedBufSize = engine.getEmbeddingBufferSize();
            std::vector<uint8_t> multi_embeddings;

            convertTokensToEmbeddings(multi_tokens, multi_embeddings, embedBufSize, m_t2eCallback);

            // Process input tokens for all streams in one batch
            n_tok_t = engine.process(multi_embeddings, multi_attn_mask, logits, true);
        }
        if (n_tok_t != multi_tokens.size()) return Dialog::abort("engine processing failed", callback);

        std::vector<bool> all_selected;

        // Process all logits independently
        std::span<float> logit_span = std::span{logits.data(),logits.size()};
        std::span<int32_t> token_span = std::span{multi_tokens.data(), multi_tokens.size()};
        for (int i = 0; i < streamIndices.size(); i++) {
            const size_t streamIdx = streamIndices[i];
            std::vector<int32_t>& stream = streams[streamIdx];

            // Width of one stream's tile: its sample tree + forecast tokens.
            const size_t tileStride = draftStreams[streamIdx].size() + forecast_tokens.size();

            std::span<float> tiled_logits = logit_span.subspan(i * tileStride * _vocab, _vocab);

            // Accept tokens
            auto [accepted_tokens, accepted_ids] = verify_and_select_longest(token_span.subspan(i * tileStride, tileStride),
                                                                             tiled_logits);

            // Commit accepted tokens to kv-caches
            std::vector<bool> selected(tileStride, false);
            for (auto id : accepted_ids) {
                selected[id] = true;
                past_map.push_back(streamIdx); // record this stream as owner of the new past tokens
            }
            all_selected.insert(all_selected.end(), selected.begin(), selected.end());
            accepted_counts[accepted_tokens.size() - 1] += 1;
            _n_past += accepted_tokens.size();

            // Decode tokens
            stream.insert(stream.end(), accepted_tokens.begin(), accepted_tokens.end());
            _n_generated += accepted_tokens.size();

            // Prepare new tokens
            std::vector<int32_t> indices(_draft, 0);
            auto next_draft_offset = len_flat_sample_tree + accepted_ids.back() * _draft;
            std::iota(indices.begin(), indices.end(), next_draft_offset);
            draftStreams[streamIdx] = build_sample_tree(accepted_tokens.back(), tiled_logits, indices);
        }

        update_kv(_n_past, all_selected);
        // Flush finished streams (EOS) to the callback and retire them.
        for (auto it = streamIndices.begin(); it != streamIndices.end();) {
            int32_t stream = *it;
            if (_ctx->is_eos(streams[stream].back())) {
                callback(_tokenizer->decode(streams[stream]) + "\n", Sentence::CONTINUE);
                it = streamIndices.erase(it);
            } else {
                ++it;
            }
        }

        if (streamIndices.size() == 0) break;
    }
    callback("\n", Sentence::END);

    State::busy(false);

    auto total_iteration = std::accumulate(accepted_counts.begin(), accepted_counts.end(), 0);
    auto accept_rate =
        float(_n_generated - 1) / total_iteration; // -1: exclude first generated token
    __KPIS("SSD{{draft:{}, branch:{}, greedy:{}}}: accepted counts: {}, accept rate = {} tokens/iteration",
           _draft,
           _branches,
           _t_sampler.greedy(),
           accepted_counts,
           accept_rate);

    return true;
}
631
+
632
// Embedding-input entry point: handles prompt processing here; generation is
// delegated to processFollowOnGeneration.
// TODO(review): pass the t2e callback via a setter (called from the base
// dialog's query path) instead of taking it as an argument here.
bool SelfSpecDecDialog::process(std::vector<uint8_t>& embedding,
                                T2ECallback t2eCallback,
                                Dialog::Callback callback ){

    // Check for prev failures and bail out early
    if (State::failed()) return false;

    if(m_inputType != InputType::EMBEDDINGS) {
        __ERROR("Input type for model is not embeddings.");
        return false;
    }

    Timer start;
    State::clear();

    std::vector<float> logits;
    auto& engine = *_engine["primary"];

    auto update_kv = [&engine, &callback, this](size_t past, const std::vector<bool>& selected) {
        if (!engine.updateKV(past, selected))
            return Dialog::abort("context size exceeded", callback);
        return true;
    };

    // Store the t2e callback for reference during follow-on generation.
    m_t2eCallback = t2eCallback;

    auto embedBufSize = engine.getEmbeddingBufferSize();

    // Cache the EOS-token embedding with the engine (zero-filled if no callback).
    {
        std::vector<uint8_t> eosEmbedding(embedBufSize, 0.0);
        if (m_t2eCallback) {
            m_t2eCallback(_ctx->eos(), eosEmbedding.data(), embedBufSize);
        }
        if (!engine.cacheEosEmbedding(eosEmbedding)) {
            __DEBUG("Failed to set the eos token embedding.");
            return false;
        }
    }

    using FF = Engine::Feature::Flags;
    if (engine.supports(FF::DYNAMIC_LOAD)) engine.load();

    _env->logger().post(Logger::KPIS, kpis().dump(" "));
    start.reset();

    engine.set({{"kv-prefix-skip", _forecast_prefix}});

    std::vector<int32_t> tokens(1,0);

    // Process prompt
    // get number of tokens in the input
    size_t curTokensCount = embedding.size()/embedBufSize;

    // Input must be a whole number of embedding vectors.
    if(curTokensCount * embedBufSize != embedding.size()){
        size_t expectedLength = (curTokensCount + (embedding.size()%embedBufSize != 0))*embedBufSize;
        __DEBUG("Input is wrong expected {} and found {}.", expectedLength, embedding.size());
        return Dialog::abort("Input is not an multiple for the embedding Length", callback);
    }

    _n_prompt += curTokensCount;

    // Causal map for the prompt: token i attends to token i-1 (root = -1).
    std::vector<int32_t> attention_map(curTokensCount);
    std::iota(attention_map.begin(), attention_map.end(), -1);

    engine.set({{"kv-prefix-offset", curTokensCount}}); // Do not attend prefix

    if (_n_past + curTokensCount > _ctx->size()) {
        __WARN("Context limit exceeded ({} + {} > {})", _n_past, curTokensCount, _ctx->size());
        callback("", Sentence::END);
        return true;
    }

    if (!engine.process(embedding, attention_map, logits, false))
        return Dialog::abort("engine prompt processing failed", callback); // Change this message also to some generic message.
    _n_past += curTokensCount;
    update_kv(_n_past, {});

    bool status = true;
    if (_n_streams <= 1) {
        // --- Single-stream path ---
        tokens[0] = sample_to_verify(std::span{logits.data(),logits.size()}, 0);

        // Decode the first token.
        _last_tok = tokens[0];
        if (_ctx->is_eos(_last_tok)) {
            callback("", Sentence::END);
            return true;
        }

        if (!callback(_tokenizer->decode(tokens), Sentence::BEGIN)) return true;
        //decode_token(tokens[0]);

        // Without a t2e callback, forecast tokens cannot be embedded — stop here.
        if (!m_t2eCallback) {
            callback("", Sentence::END);
            return true;
        }

        // Mark TTFT
        _kpis.prompt.update(start.elapsed_usec());
        start.reset();
        State::busy(true);

        // Initial inference for self-speculative decoding pipeline with forecast tokens and prefix
        // process separately because logits are required for these tokens
        for (int i = 0; i < _draft; ++i)
            tokens.push_back(_forecast_token_offset + i);

        attention_map.resize(tokens.size());
        std::iota(attention_map.begin(), attention_map.end(), -1);
        engine.set({{"kv-prefix-offset", 1}}); // Prevent the last token from attending

        if (_n_past + tokens.size() > _ctx->size()) {
            __WARN("Context limit exceeded ({} + {} > {})", _n_past, tokens.size(), _ctx->size());
            callback("", Sentence::END);
            return true;
        }

        // Convert tokens to embeddings
        // reset embedding vector to make space for the next runs
        embedding.clear();
        convertTokensToEmbeddings(tokens, embedding, embedBufSize, m_t2eCallback);

        if (!engine.process(embedding, attention_map, logits, true))
            return Dialog::abort("initial inference for SSD pipeline failed", callback);

        // Only the first (real) token is committed; forecast tokens are scratch.
        _n_past += 1;
        update_kv(_n_past, {});

        // Use existing as much as possible
        status = processFollowOnGeneration(tokens, logits, callback);
    } else {
        // --- Multistream path: fork the top-K first tokens into parallel streams ---
        std::vector<std::vector<int32_t>> streams;
        getTopK(logits, streams, _n_streams, _p_threshold, callback);

        if (!m_t2eCallback) {
            for (auto& stream : streams) {
                if (!callback(_tokenizer->decode(stream) + "\n", Sentence::BEGIN)) return true;
            }
            callback("", Sentence::END);
            return true;
        }

        // Mark TTFT
        _kpis.prompt.update(start.elapsed_usec());
        start.reset();
        State::busy(true);

        if (streams.size() == 0) {
            callback("\n", Sentence::END);
            return true;
        }

        // Initial inference for self-speculative decoding pipeline with forecast tokens and prefix
        // process separately because logits are required for these tokens
        attention_map.resize(1 + _draft);
        std::iota(attention_map.begin(), attention_map.end(), -1);

        std::vector<size_t> stream_indices(streams.size());
        std::iota(stream_indices.begin(), stream_indices.end(), 0);

        std::vector<int32_t> multi_attn_mask;
        std::vector<size_t> past_map; // no committed per-stream past tokens yet
        const size_t kvPrefixOffset = 1;

        tileAttentionMask(attention_map, stream_indices, past_map, kvPrefixOffset, multi_attn_mask);

        // Accumulate input tokens from all streams
        std::vector<int32_t> multi_tokens;

        multi_tokens.reserve(streams.size() * (1 + _draft));
        for (int i = 0; i < streams.size(); i++) {
            multi_tokens.insert(multi_tokens.end(), streams[i].begin(), streams[i].end());
            // NOTE(review): inner loop shadows the outer `i` — works, but worth renaming.
            for (int i = 0; i < _draft; ++i) {
                multi_tokens.push_back(_forecast_token_offset + i);
            }
        }

        // Convert tokens to embeddings
        // reset embedding vector to make space for the next runs
        embedding.clear();
        convertTokensToEmbeddings(multi_tokens, embedding, embedBufSize, m_t2eCallback);

        if (_n_past + multi_tokens.size() > _ctx->size()) {
            __WARN("Context limit exceeded ({} + {} > {})", _n_past, multi_tokens.size(), _ctx->size());
            callback("", Sentence::END);
            return true;
        }

        if (!engine.process(embedding, multi_attn_mask, logits, true))
            return Dialog::abort("initial inference for SSD pipeline failed", callback);

        // Keep only each stream's first (real) token in the KV cache.
        std::vector<bool> selected(multi_tokens.size(), false);
        for (int i = 0; i < multi_tokens.size(); i+=(_draft+1)) {
            selected[i] = true;
        }

        _n_past += streams.size();
        update_kv(_n_past, selected);

        status = processFollowOnGeneration(streams, logits, callback);
    }

    _kpis.generate.update(start.elapsed_usec());
    _env->logger().post(Logger::KPIS, kpis().dump(" "));
    start.reset();

    return status;
}
843
+
844
+ bool SelfSpecDecDialog::process(std::vector<int32_t>& tokens, Dialog::Callback callback) {
845
+
846
+ // Check for prev failures and bail out early
847
+ if (State::failed()) return false;
848
+
849
+ Timer start;
850
+
851
+ if(m_inputType != InputType::TOKENS) {
852
+ __ERROR("Input type for model is not tokens.");
853
+ return false;
854
+ }
855
+
856
+ State::clear();
857
+
858
+ std::vector<float> logits;
859
+ auto& engine = *_engine["primary"];
860
+
861
+ auto update_kv = [&engine, &callback, this](size_t past, const std::vector<bool>& selected) {
862
+ if (!engine.updateKV(past, selected))
863
+ return Dialog::abort("context size exceeded", callback);
864
+ return true;
865
+ };
866
+
867
+ using FF = Engine::Feature::Flags;
868
+ if (engine.supports(FF::DYNAMIC_LOAD)) engine.load();
869
+
870
+ _env->logger().post(Logger::KPIS, kpis().dump(" "));
871
+ start.reset();
872
+
873
+ engine.set({{"kv-prefix-skip", _forecast_prefix}});
874
+
875
+ std::vector<int32_t> attention_map(tokens.size());
876
+ std::iota(attention_map.begin(), attention_map.end(), -1);
877
+
878
+ // Process prompt
879
+ _n_prompt += tokens.size();
880
+ engine.set({{"kv-prefix-offset", tokens.size()}}); // Do not attend prefix
881
+
882
+ if (_n_past + tokens.size() > _ctx->size()) {
883
+ __WARN("Context limit exceeded ({} + {} > {})", _n_past, tokens.size(), _ctx->size());
884
+ callback("", Sentence::END);
885
+ return true;
886
+ }
887
+
888
+ if (!engine.process(tokens, attention_map, logits, false))
889
+ return Dialog::abort("engine prompt processing failed", callback);
890
+ _n_past += tokens.size();
891
+ update_kv(_n_past, {});
892
+
893
+ bool status = true;
894
+ if (_n_streams <= 1) {
895
+ tokens[0] = sample_to_verify(std::span{logits.data(),logits.size()}, 0);
896
+ tokens.resize(1);
897
+
898
+ // Decode the first token.
899
+ _last_tok = tokens[0];
900
+ if (_ctx->is_eos(_last_tok)) {
901
+ callback("", Sentence::END);
902
+ return true;
903
+ }
904
+
905
+ if (!callback(_tokenizer->decode(tokens), Sentence::BEGIN)) return true;
906
+ // decode_token(tokens[0]);
907
+
908
+ // Mark TTFT
909
+ _kpis.prompt.update(start.elapsed_usec());
910
+ start.reset();
911
+ State::busy(true);
912
+
913
+ // Initial inference for self-speculative decoding pipeline with forecast tokens and prefix
914
+ // process separately because logits are required for these tokens
915
+ for (int i = 0; i < _draft; ++i)
916
+ tokens.push_back(_forecast_token_offset + i);
917
+
918
+ attention_map.resize(tokens.size());
919
+ std::iota(attention_map.begin(), attention_map.end(), -1);
920
+ engine.set({{"kv-prefix-offset", 1}}); // Prevent the last token from attending
921
+
922
+ if (_n_past + tokens.size() > _ctx->size()) {
923
+ __WARN("Context limit exceeded ({} + {} > {})", _n_past, tokens.size(), _ctx->size());
924
+ callback("", Sentence::END);
925
+ return true;
926
+ }
927
+
928
+ if (!engine.process(tokens, attention_map, logits, true))
929
+ return Dialog::abort("initial inference for SSD pipeline failed", callback);
930
+
931
+ _n_past += 1;
932
+ update_kv(_n_past, {});
933
+
934
+ status = processFollowOnGeneration(tokens, logits, callback);
935
+ } else {
936
+ std::vector<std::vector<int32_t>> streams;
937
+ getTopK(logits, streams, _n_streams, _p_threshold, callback);
938
+
939
+ // Mark TTFT
940
+ _kpis.prompt.update(start.elapsed_usec());
941
+ start.reset();
942
+ State::busy(true);
943
+
944
+ if (streams.size() == 0) {
945
+ callback("\n", Sentence::END);
946
+ return true;
947
+ }
948
+
949
+ // Initial inference for self-speculative decoding pipeline with forecast tokens and prefix
950
+ // process separately because logits are required for these tokens
951
+ attention_map.resize(1 + _draft);
952
+ std::iota(attention_map.begin(), attention_map.end(), -1);
953
+
954
+ std::vector<size_t> stream_indices(streams.size());
955
+ std::iota(stream_indices.begin(), stream_indices.end(), 0);
956
+
957
+ std::vector<int32_t> multi_attn_mask;
958
+ std::vector<size_t> past_map;
959
+ const size_t kvPrefixOffset = 1;
960
+
961
+ tileAttentionMask(attention_map, stream_indices, past_map, kvPrefixOffset, multi_attn_mask);
962
+
963
+ // Accumulate input tokens from all streams
964
+ std::vector<int32_t> multi_tokens;
965
+
966
+ multi_tokens.reserve(streams.size() * (1 + _draft));
967
+ for (int i = 0; i < streams.size(); i++) {
968
+ multi_tokens.insert(multi_tokens.end(), streams[i].begin(), streams[i].end());
969
+ for (int i = 0; i < _draft; ++i) {
970
+ multi_tokens.push_back(_forecast_token_offset + i);
971
+ }
972
+ }
973
+
974
+ if (_n_past + multi_tokens.size() > _ctx->size()) {
975
+ __WARN("Context limit exceeded ({} + {} > {})", _n_past, multi_tokens.size(), _ctx->size());
976
+ callback("", Sentence::END);
977
+ return true;
978
+ }
979
+
980
+ if (!engine.process(multi_tokens, multi_attn_mask, logits, true))
981
+ return Dialog::abort("initial inference for SSD pipeline failed", callback);
982
+
983
+ std::vector<bool> selected(multi_tokens.size(), false);
984
+ for (int i = 0; i < multi_tokens.size(); i+=(_draft+1)) {
985
+ selected[i] = true;
986
+ }
987
+
988
+ _n_past += streams.size();
989
+ update_kv(_n_past, selected);
990
+
991
+ status = processFollowOnGeneration(streams, logits, callback);
992
+ }
993
+
994
+ _kpis.generate.update(start.elapsed_usec());
995
+ _env->logger().post(Logger::KPIS, kpis().dump(" "));
996
+ start.reset();
997
+
998
+ return status;
999
+ }
1000
+
1001
// Reset the dialog and re-seed the primary engine's KV cache with the
// pre-computed forecast prefix. Throws if the restored prefix length does not
// match the configured _forecast_prefix, since the engine-side cache would
// otherwise be inconsistent with _n_past.
void SelfSpecDecDialog::reset() {
    Dialog::reset();
    _n_past = _forecast_prefix;
    // Reload the forecast-prefix KV cache identified by _kv_prefix_name.
    size_t n_restored_prefix = _engine["primary"]->restore(_kv_prefix_name);
    if (n_restored_prefix != _forecast_prefix) {
        // clang-format off
        throw std::runtime_error( fmt::format( "SSD : Loaded {} KV$ from {} but expected {} KV$",
                                               n_restored_prefix, _kv_prefix_name, _forecast_prefix ) );
        // clang-format on
    }
}
1012
+
1013
+ bool SelfSpecDecDialog::save(const std::string& name) {
1014
+ if (_n_streams > 1) {
1015
+ throw std::runtime_error("Save is unsupported for multistream dialogs.");
1016
+ }
1017
+ return Dialog::save(name);
1018
+ }
1019
+
1020
+ bool SelfSpecDecDialog::restore(const std::string& name) {
1021
+ if (_n_streams > 1) {
1022
+ throw std::runtime_error("Restore is unsupported for multistream dialogs.");
1023
+ }
1024
+ return Dialog::restore(name);
1025
+ }
1026
+
1027
// Registrator instance: runs at static-initialization time and registers the
// "ssd-q1" dialog type with the Dialog factory so configs can instantiate it.
static OnLoad regy([]() {
    Dialog::__register(
        "ssd-q1",
        [](std::shared_ptr<Env> env, const std::string& name, const json& conf) {
            return (Dialog*)new SelfSpecDecDialog(env, name, conf);
        }
    );
});
1036
+
1037
// Register ssd sampler for compatibility
// NOTE(review): despite the comment above, this registers BasicSampler under
// the "basic" key — presumably so SSD configs referencing the basic sampler
// keep working when this translation unit is linked standalone; confirm intent.
static OnLoad sampler_regy([]() {
    Sampler::__register("basic", [](Context& ctx, const json& conf) {
        return (Sampler*)new BasicSampler(ctx, conf);
    });
});
1043
+
1044
// Linker anchor: referencing this no-op from another translation unit forces
// this object file (and its static OnLoad registrators) to be linked in.
void needSsdDialog() {}
1045
+
1046
+ } // namespace qualla
Genie/Genie/src/qualla/embedding.cpp ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ //==============================================================================
2
+ //
3
+ // Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4
+ // All Rights Reserved.
5
+ // Confidential and Proprietary - Qualcomm Technologies, Inc.
6
+ //
7
+ //==============================================================================
8
+
9
+ #include <qualla/embedding.hpp>
10
+ #include <qualla/logger.hpp>
11
+ #include <qualla/detail/config.hpp>
12
+ #include <qualla/detail/timer.hpp>
13
+
14
+ #include <functional>
15
+ #include <fstream>
16
+ #include <string>
17
+ #include <unordered_map>
18
+ #include <filesystem>
19
+
20
+ #include <fmt/format.h>
21
+ #include <fmt/ranges.h>
22
+
23
+ namespace fs = std::filesystem;
24
+
25
+ namespace qualla {
26
+
27
// Construct an embedding pipeline from a JSON config:
//   - "prompt.tags": optional [prefix, suffix] wrapped around every query
//   - "context":     context configuration (created first; other parts need it)
//   - "tokenizer":   path (relative to the models dir) of the tokenizer model
//   - "engine":      mandatory engine configuration
//   - "truncate-input": if true, over-long inputs are truncated instead of rejected
// Throws if the engine cannot emit embeddings. Records init time in KPIs.
Embedding::Embedding(std::shared_ptr<Env> env, const std::string& name, const qualla::json& json)
    : _name(name), _env(env) {
    Timer start;

    _env->logger().debug(fmt::format("embedding-new: {} config {}", name, json.dump()));

    using qc = qualla::Config;

    // Parse prompt config
    const qualla::json& pmt_conf = qc::optional<qualla::json>(json, "prompt", {});
    _tags = qc::optional<std::vector<std::string>>(pmt_conf, "tags", {"", ""});

    // Create the context first
    _ctx = Context::create(*_env, name, qc::optional<qualla::json>(json, "context", {}));

    // Create Tokenizer
    fs::path tok_path = _env->path().models / qc::mandatory<std::string>(json, "tokenizer");
    _tokenizer = Tokenizer::create(*_ctx, tok_path);

    // Create Engine
    const qualla::json& eng_conf = qc::mandatory<qualla::json>(json, "engine");
    _engine = Engine::create(*_ctx, eng_conf);

    // Truncation of input to context
    _input_truncation = qc::optional<qualla::json>(json, "truncate-input", false);

    // An embedding pipeline is only meaningful with an embedding-output engine.
    using FF = Engine::Feature::Flags;
    if (!_engine->supports(FF::OUTPUT_EMBEDDINGS))
        throw std::runtime_error("engine must output embeddings");

    _kpis.init.update(start.elapsed_usec());
}
59
+
60
+ Embedding::~Embedding() {}
61
+
62
+ bool Embedding::process(std::vector<int32_t>& tokens, std::vector<float>& output) {
63
+ Timer start;
64
+
65
+ State::clear();
66
+
67
+ size_t n = _engine->process(tokens, output, false);
68
+ if (!n) {
69
+ State::error("engine prompt processing failed");
70
+ return false;
71
+ }
72
+
73
+ _n_prompt += tokens.size();
74
+
75
+ // Clean the buffer before using
76
+ _output_dimensions.clear();
77
+
78
+ uint64_t output_size = 1;
79
+ // push number of tokens present in the result.
80
+ _output_dimensions.push_back(n);
81
+ // push back the dimension of the each embedding
82
+ _output_dimensions.push_back(_ctx->n_embd());
83
+
84
+ output_size = n * _ctx->n_embd();
85
+
86
+ output.resize(output_size);
87
+
88
+ _kpis.prompt.update(start.elapsed_usec());
89
+
90
+ // Log latest KPIs in a single line
91
+ _env->logger().post(Logger::KPIS, kpis().dump(" "));
92
+
93
+ return true;
94
+ }
95
+
96
+ bool Embedding::query(const std::string& str, std::vector<float>& output) {
97
+ std::string p_str; // prompt string
98
+ std::vector<int32_t> p_vec; // prompt tokens
99
+
100
+ p_vec.reserve(_ctx->n_ctx());
101
+
102
+ p_str = _tags[0] + str + _tags[1];
103
+
104
+ _env->logger().debug(fmt::format("embedding-query: {}", str));
105
+ _env->logger().debug(fmt::format("embedding-prompt: {}", p_str));
106
+
107
+ _n_queries++;
108
+
109
+ _tokenizer->encode(p_str, p_vec);
110
+
111
+ _env->logger().debug(fmt::format("embedding-tokens: {}", p_vec));
112
+
113
+ if(p_vec.size() > (_ctx->n_ctx())){ // Condition to not allow input to exceed context.
114
+ if(_input_truncation == false){
115
+ throw std::runtime_error("Input exceeds the context of the model.");
116
+ }
117
+ else{
118
+ p_vec.resize(_ctx->n_ctx());
119
+ }
120
+ }
121
+
122
+ return process(p_vec, output);
123
+ }
124
+
125
+ // Embedding KPIs helpers
126
+
127
+
128
+ void Embedding::output_dimensions(std::vector<std::uint32_t>& outputDimensions){
129
+ outputDimensions = _output_dimensions;
130
+ }
131
+
132
+ // Get latest KPIs
133
+ Embedding::KPIs& Embedding::kpis() {
134
+ // Update TPS
135
+ if (_n_prompt) {
136
+ float t = _kpis.prompt.total_usec / _n_prompt;
137
+ _kpis.tps.prompt = 1000000.0 / (t ? t : 1000000.0);
138
+ }
139
+
140
+ // We could synthesize more KPIs from from other layers (engine, sampler, etc)
141
+ return _kpis;
142
+ }
143
+
144
+ std::string Embedding::KPIs::dump(std::string_view sep) const {
145
+ return fmt::format(
146
+ "init:[{}]{}prompt:[{}]{} tps-prompt:{:.2f}",
147
+ init.dump(),
148
+ sep,
149
+ prompt.dump(),
150
+ sep,
151
+ tps.prompt
152
+ );
153
+ }
154
+
155
+ void Embedding::KPIs::reset() {
156
+ init.reset();
157
+ prompt.reset();
158
+ tps.prompt = 0.0;
159
+ }
160
+
161
// Create API

// Build an Embedding directly from a parsed JSON config.
std::unique_ptr<Embedding> Embedding::create(
    std::shared_ptr<Env> env,
    const std::string& name,
    const qualla::json& conf
) {
    return std::make_unique<Embedding>(env, name, conf);
}

// Build an Embedding from a stream of JSON config text.
std::unique_ptr<Embedding> Embedding::create(
    std::shared_ptr<Env> env,
    const std::string& name,
    std::istream& json_stream
) {
    return create(env, name, json::parse(json_stream));
}

// Build an Embedding from a JSON config file on disk.
// Throws std::runtime_error when the file does not exist.
std::unique_ptr<Embedding> Embedding::create(
    std::shared_ptr<Env> env,
    const std::string& name,
    const fs::path& json_path
) {
    if (!fs::exists(json_path))
        throw std::runtime_error(json_path.string() + ": file does not exist");
    std::ifstream ifs(json_path);
    return create(env, name, ifs);
}
189
+
190
+ } // namespace qualla
Genie/Genie/src/qualla/engine.cpp ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ //==============================================================================
2
+ //
3
+ // Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4
+ // All Rights Reserved.
5
+ // Confidential and Proprietary - Qualcomm Technologies, Inc.
6
+ //
7
+ //==============================================================================
8
+
9
+ #include <qualla/engine.hpp>
10
+ #include <qualla/detail/kpi.hpp>
11
+ #include <qualla/detail/config.hpp>
12
+
13
+ #include <functional>
14
+ #include <iostream>
15
+ #include <sstream>
16
+ #include <string>
17
+ #include <unordered_map>
18
+
19
+ #include <fmt/format.h>
20
+ #include <fmt/ranges.h>
21
+
22
+ namespace qualla {
23
+
24
// Base engine constructor: records the engine type, binds the owning context
// and environment, and reads the optional "role" config key (default
// "primary") that dialogs use to address engines.
Engine::Engine(Context& ctx, const std::string& type, const qualla::json& conf)
    : _type(type), _ctx(ctx), _env(ctx.env()) {
    _env.logger().debug(
        fmt::format("engine-new: {} ctx {} config {}", type, _ctx.name(), conf.dump())
    );

    using qc = qualla::Config;
    _role = qc::optional<std::string>(conf, "role", "primary");
}
33
+
34
+ Engine::~Engine() {}
35
+
36
// ---- Default (unsupported) implementations ----------------------------------
// Each virtual below is a base-class fallback: it logs an error identifying
// the engine type and returns a failure/neutral value. Concrete engines
// override the subset of operations they actually support.

size_t Engine::process(
    const std::vector<int32_t>& tokens,
    const std::vector<int32_t>& attention_map,
    std::vector<float>& output,
    bool output_all
) {
    _env.logger().error(fmt::format("{}-engine does not support attention_map", _type));
    return 0;
}

size_t Engine::process(const std::vector<int32_t>& tokens) {
    // Derived engines should overwrite this to avoid copying logits
    std::vector<float> logits;
    return process(tokens, logits);
}

size_t Engine::process(
    std::vector<uint8_t>& embeddings,
    const std::vector<int32_t>& attention_map,
    std::vector<float>& output,
    bool output_all
) {
    _env.logger().error(fmt::format("{}-engine does not support embedding as input", _type));
    return 0;
}

bool Engine::updateKV(size_t n_past) {
    _env.logger().error(fmt::format("{}-engine does not support sync", _type));
    return false;
}

bool Engine::updateKV(size_t n_past, const std::vector<bool>& selected) {
    _env.logger().error(fmt::format("{}-engine does not support sync with selected", _type));
    return false;
}

size_t Engine::restore(const std::string& name) {
    _env.logger().error(fmt::format("{}-engine does not support restore", _type));
    return 0;
}

bool Engine::save(const std::string& name) {
    _env.logger().error(fmt::format("{}-engine does not support save", _type));
    return false;
}

void Engine::reset() {
    _env.logger().error(fmt::format("{}-engine does not support reset", _type));
}

bool Engine::load() {
    _env.logger().error(fmt::format("{}-engine does not support dynamic load", _type));
    // NOTE(review): `0` converts to false — same meaning as the other stubs,
    // just spelled as an int literal.
    return 0;
}

bool Engine::unload() {
    _env.logger().error(fmt::format("{}-engine does not support dynamic unload", _type));
    return false;
}

bool Engine::set(qualla::json data) {
    _env.logger().error(fmt::format("{}-engine does not support set()", _type));
    return false;
}

qualla::json Engine::get() {
    _env.logger().error(fmt::format("{}-engine does not support get()", _type));
    // NOTE(review): this yields json(false), not an empty object — callers
    // presumably test truthiness; confirm before changing.
    return false;
}

bool Engine::cacheEosEmbedding(std::vector<uint8_t>& eosEmbedding) {
    _env.logger().error(fmt::format("{}-engine does not support cache eos embedding", _type));
    // NOTE(review): unlike the other stubs this returns true after logging an
    // error — looks like an intentional best-effort no-op; confirm.
    return true;
}

size_t Engine::getEmbeddingBufferSize() {
    _env.logger().error(fmt::format("{}-engine does not support embedding vectors", _type));
    return 0;
}

// Default input modality: plain token IDs.
qualla::InputType Engine::getInputType(){
    return qualla::InputType::TOKENS;
}
119
+
120
// Engine KPIs

// Render the four engine timing KPIs on one line, separated by `sep`.
std::string Engine::KPIs::dump(std::string_view sep) const {
    return fmt::format(
        "load:[{}]{}process:[{}]{}update-kv:[{}]{}unload:[{}]",
        load.dump(),
        sep,
        process.dump(),
        sep,
        update_kv.dump(),
        sep,
        unload.dump()
    );
}

// Zero all timing accumulators.
void Engine::KPIs::reset() {
    load.reset();
    process.reset();
    update_kv.reset();
    unload.reset();
}
141
+
142
// Engine registry type string + creator function
using Registry = std::unordered_map<std::string, Engine::Creator>;
// Created lazily on first registration: registrators run from static OnLoad
// hooks, so static-initialization order cannot be relied upon.
static std::unique_ptr<Registry> registry;

// Register (or replace) the creator function for engine type `type`.
void Engine::__register(const std::string& type, Creator func) {
    if (!registry) registry = std::make_unique<Registry>();

    Registry& r = *registry;
    r[type] = func;
}
152
+
153
+ std::unique_ptr<Engine> Engine::create(Context& ctx, const qualla::json& conf) {
154
+ using qc = qualla::Config;
155
+
156
+ std::string type = qc::mandatory<std::string>(conf, "type");
157
+
158
+
159
+ if (!registry) throw std::runtime_error(type + ": engine not found");
160
+
161
+ Registry& r = *registry;
162
+
163
+
164
+ if (!r.contains(type)) throw std::runtime_error(type + ": engine not found");
165
+
166
+
167
+ return std::unique_ptr<Engine>(r[type](ctx, conf));
168
+ }
169
+
170
// Build an engine from a stream of JSON config text.
std::unique_ptr<Engine> Engine::create(Context& ctx, std::istream& json_stream) {
    return create(ctx, json::parse(json_stream));
}

// Build an engine from a JSON config string.
std::unique_ptr<Engine> Engine::create(Context& ctx, const std::string& json_str) {
    return create(ctx, json::parse(json_str));
}
177
+
178
+ std::vector<std::string> Engine::list() {
179
+ std::vector<std::string> v;
180
+ if (!registry) return v;
181
+
182
+ Registry& r = *registry;
183
+
184
+ for (auto k : r)
185
+ v.push_back(k.first);
186
+ return v;
187
+ }
188
+
189
// Base-class fallbacks: LoRA adapters are only meaningful for engines that
// override these; log an error and report failure otherwise.
// NOTE(review): both take std::string by value — matches the declaration
// elsewhere, so the signature is left as-is.
bool Engine::applyLoraAdapter(std::string lora_adapter_name) {
    _env.logger().error(fmt::format("{}-engine does not support LoraAdapter", _type));
    return false;
}
bool Engine::applyLoraStrength(std::string tensor_name, float tensor_val) {
    _env.logger().error(fmt::format("{}-engine does not support setLoraStrength", _type));
    return false;
}
197
+
198
+ } // namespace qualla
Genie/Genie/src/qualla/engines/lib.cpp ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ //==============================================================================
2
+ //
3
+ // Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4
+ // All Rights Reserved.
5
+ // Confidential and Proprietary - Qualcomm Technologies, Inc.
6
+ //
7
+ //==============================================================================
8
+
9
+ // Just a stub for building qualla::engines when no built-in engines are enabled
Genie/Genie/src/qualla/engines/qnn-api/BackendExtensions.cpp ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ //==============================================================================
2
+ //
3
+ // Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4
+ // All Rights Reserved.
5
+ // Confidential and Proprietary - Qualcomm Technologies, Inc.
6
+ //
7
+ //==============================================================================
8
+
9
+ #include "dlwrap.hpp"
10
+ #include "BackendExtensions.hpp"
11
+ #include "NetRunBackend.hpp"
12
+
13
// Store configuration and collaborator handles only; the backend interface
// itself is created later, in initialize().
BackendExtensions::BackendExtensions(
    BackendExtensionsConfigs backendExtensionsConfig,
    void* backendLibHandle,
    PerfProfile perfProfile,
    std::shared_ptr<ICommandLineManager> clManager,
    bool debug_qnn
)
    : m_backendExtensionsLibPath(backendExtensionsConfig.sharedLibraryPath),
      m_backendExtensionsConfigPath(backendExtensionsConfig.configFilePath),
      m_backendInterface(nullptr), m_isNetRunBackendInterface(false),
      m_createBackendInterfaceFn(nullptr), m_destroyBackendInterfaceFn(nullptr),
      m_backendLibHandle(backendLibHandle), m_perfProfile(perfProfile), m_clManager(clManager),
      m_debugQnn(debug_qnn) {
    // Silences the unused-member warning; m_perfProfile is consumed in initialize().
    (void)m_perfProfile;
}
28
+
29
+ BackendExtensions::~BackendExtensions() {
30
+ if (nullptr != m_backendInterface) {
31
+ if (m_isNetRunBackendInterface) {
32
+ QNN_DEBUG("Deleting NetRun Backend Interface");
33
+ delete m_backendInterface;
34
+ } else {
35
+ if (nullptr != m_destroyBackendInterfaceFn) {
36
+ QNN_DEBUG("Destroying Backend Interface");
37
+ m_destroyBackendInterfaceFn(m_backendInterface);
38
+ }
39
+ }
40
+ }
41
+ }
42
+
43
+ bool BackendExtensions::loadFunctionPointers() {
44
+
45
+ void* libHandle = dlopen(m_backendExtensionsLibPath.c_str(), RTLD_NOW | RTLD_LOCAL);
46
+ if (nullptr == libHandle) {
47
+ QNN_ERROR(
48
+ "Unable to load backend extensions lib: [%s]. dlerror(): [%s]",
49
+ m_backendExtensionsLibPath.c_str(),
50
+ dlerror()
51
+ );
52
+ return false;
53
+ }
54
+ m_createBackendInterfaceFn =
55
+ (CreateBackendInterfaceFnType_t)dlsym(libHandle, "createBackendInterface");
56
+ m_destroyBackendInterfaceFn =
57
+ (DestroyBackendInterfaceFnType_t)dlsym(libHandle, "destroyBackendInterface");
58
+ if (nullptr == m_createBackendInterfaceFn || nullptr == m_destroyBackendInterfaceFn) {
59
+ QNN_ERROR("Unable to find symbols. dlerror(): [%s]", dlerror());
60
+ return false;
61
+ }
62
+
63
+ return true;
64
+ }
65
+
66
+ void BackendExtensions::qnnLogCallback(
67
+ const char* fmt,
68
+ QnnLog_Level_t level,
69
+ uint64_t timestamp,
70
+ va_list args
71
+ ) {
72
+ char buffer[1024] = "";
73
+ const char* levelStr = "";
74
+ switch (level) {
75
+ case QNN_LOG_LEVEL_ERROR:
76
+ levelStr = " ERROR ";
77
+ break;
78
+ case QNN_LOG_LEVEL_WARN:
79
+ levelStr = "WARNING";
80
+ break;
81
+ case QNN_LOG_LEVEL_INFO:
82
+ levelStr = " INFO ";
83
+ break;
84
+ case QNN_LOG_LEVEL_DEBUG:
85
+ levelStr = " DEBUG ";
86
+ break;
87
+ case QNN_LOG_LEVEL_VERBOSE:
88
+ levelStr = "VERBOSE";
89
+ break;
90
+ case QNN_LOG_LEVEL_MAX:
91
+ levelStr = "UNKNOWN";
92
+ break;
93
+ }
94
+
95
+ int pos = snprintf(
96
+ buffer, sizeof(buffer), "QNN: [%s] time=%lu:", levelStr, (unsigned long)timestamp
97
+ );
98
+ vsnprintf(buffer + pos, sizeof(buffer) - pos, fmt, args);
99
+ printf("%s", buffer);
100
+ }
101
+
102
// Create and configure the backend interface. Ordering matters:
//   1. choose implementation (NetRun fallback vs. dlopen'd extensions lib)
//   2. optionally hook up QNN verbose logging
//   3. initialize the interface with the backend library handle
//   4. apply perf profile (failure is non-fatal), config file, CLI args
// Returns false on any fatal step.
bool BackendExtensions::initialize() {

    QNN_DEBUG("DEBUG: m_backendExtensionsLibPath=%s\n", m_backendExtensionsLibPath.c_str());
    QNN_DEBUG("DEBUG: m_backendExtensionsConfigPath=%s\n", m_backendExtensionsConfigPath.c_str());
    // No extensions lib configured: fall back to the built-in NetRun interface.
    if (m_backendExtensionsLibPath.empty() && m_backendExtensionsConfigPath.empty()) {
        QNN_WARN("No BackendExtensions lib provided; initializing NetRunBackend Interface");
        m_isNetRunBackendInterface = true;
        m_backendInterface = new NetRunBackend();
    } else {
        QNN_DEBUG("Loading supplied backend extensions lib.");
        QNN_DEBUG("Backend extensions lib path: %s", m_backendExtensionsLibPath.c_str());
        if (m_backendExtensionsConfigPath.empty()) {
            QNN_DEBUG("Backend extensions lib specified without a config file.");
        } else {
            QNN_DEBUG("Backend extensions config path: %s", m_backendExtensionsConfigPath.c_str());
        }
        if (!loadFunctionPointers()) {
            QNN_ERROR("Failed to load function pointers.");
            return false;
        }
        if (nullptr != m_createBackendInterfaceFn) {
            m_backendInterface = m_createBackendInterfaceFn();
        }
    }
    if (nullptr == m_backendInterface) {
        QNN_ERROR("Unable to load backend extensions interface.");
        return false;
    }
    if (m_debugQnn) {
        if (!(m_backendInterface->setupLogging(BackendExtensions::qnnLogCallback, QNN_LOG_LEVEL_VERBOSE))) {
            QNN_WARN("Unable to initialize logging in backend extensions.");
        }
    }
    if (!m_backendInterface->initialize(m_backendLibHandle)) {
        QNN_ERROR("Unable to initialize backend extensions interface.");
        return false;
    }
    // Perf-profile failure is deliberately tolerated (warn only).
    if (!m_backendInterface->setPerfProfile(m_perfProfile)) {
        QNN_WARN("Unable to set perf profile in backend extensions interface.");
        //return false;
    }
    if (!m_backendInterface->loadConfig(m_backendExtensionsConfigPath)) {
        QNN_ERROR("Unable to load backend extensions interface config.");
        return false;
    }

    if ((m_clManager != nullptr) && !m_backendInterface->loadCommandLineArgs(m_clManager)) {
        QNN_ERROR("Unable to load backend extensions' command line arguments.");
        return false;
    }

    return true;
}
155
+
156
// Expose the active IBackend implementation (NetRun fallback or the one
// created by the extensions library). Nullptr until initialize() succeeds.
IBackend* BackendExtensions::interface() {
    return m_backendInterface;
}
Genie/Genie/src/qualla/engines/qnn-api/BackendExtensions.hpp ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ //==============================================================================
2
+ //
3
+ // Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4
+ // All Rights Reserved.
5
+ // Confidential and Proprietary - Qualcomm Technologies, Inc.
6
+ //
7
+ //==============================================================================
8
+
9
+ #pragma once
10
+
11
+ #include <string>
12
+
13
+ #include "IBackend.hpp"
14
+ #include "QnnConfig.hpp"
15
+ #include "Log.hpp"
16
+
17
+ // This is a wrapper class that handles resources/state related to
18
+ // backend extensions interface. This is used by QnnNetRun library
19
+ // to manage and call into an IBackend interface implementation.
20
+ // Functionality present in this class:
21
+ // 1. Receives the argument string related to backend_extensions
22
+ // argument from the front end and processes it to open the
23
+ // backend extensions library.
24
+ // 2. Locates and stores symbols for creating and destroying the
25
+ // IBackend interface implementation.
26
+ // 3. If there is no backend_extensions argument, this class creates
27
+ // the dummy IBackend implementation aka NetRunBackend.
28
+ // 4. Gives QnnNetRun access to the implementation itself through
29
+ // interface() function.
30
class BackendExtensions final {
  public:
    // Stores config; creates nothing until initialize() is called.
    BackendExtensions(
        BackendExtensionsConfigs backendExtensionsConfig,
        void* backendLibHandle,
        PerfProfile perfProfile,
        std::shared_ptr<ICommandLineManager> clManager =
            std::shared_ptr<ICommandLineManager>(nullptr),
        bool debug_qnn = false
    );
    ~BackendExtensions();
    // Creates/loads the IBackend implementation; false on failure.
    bool initialize();
    // Access the active IBackend (nullptr before a successful initialize()).
    IBackend* interface();

  private:
    // dlopen the extensions lib and resolve the create/destroy entry points.
    bool loadFunctionPointers();
    std::string m_backendExtensionsLibPath;     // path to extensions .so ("" = use NetRun fallback)
    std::string m_backendExtensionsConfigPath;  // optional config file for the extensions lib
    IBackend* m_backendInterface;               // owned: deleted/destroyed in dtor
    bool m_isNetRunBackendInterface;            // true when using the built-in fallback
    CreateBackendInterfaceFnType_t m_createBackendInterfaceFn;
    DestroyBackendInterfaceFnType_t m_destroyBackendInterfaceFn;
    void* m_backendLibHandle;                   // non-owning backend lib handle
    PerfProfile m_perfProfile;
    std::shared_ptr<ICommandLineManager> m_clManager;
    bool m_debugQnn{false};
    // Static bridge passed to the backend's setupLogging().
    static void qnnLogCallback(
        const char* fmt,
        QnnLog_Level_t level,
        uint64_t timestamp,
        va_list args
    );
};
Genie/Genie/src/qualla/engines/qnn-api/ClientBuffer.cpp ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ //==============================================================================
2
+ //
3
+ // Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4
+ // All Rights Reserved.
5
+ // Confidential and Proprietary - Qualcomm Technologies, Inc.
6
+ //
7
+ //==============================================================================
8
+
9
+ #include "ClientBuffer.hpp"
10
+ #include "QnnTypeMacros.hpp"
11
+
12
// Return the raw data pointer held in the tensor's client buffer, or nullptr
// for a null tensor.
void* ClientBuffer::getBuffer(Qnn_Tensor_t* tensor) {
    if (!tensor) {
        QNN_WARN("getBuffer: received a null pointer to a tensor");
        return nullptr;
    }
    return QNN_TENSOR_GET_CLIENT_BUF(tensor).data;
}

// Return the size in bytes of the tensor's client buffer, or 0 for a null tensor.
size_t ClientBuffer::getBufferSize(Qnn_Tensor_t* tensor) {
    if (!tensor) {
        QNN_WARN("getBufferSize: received a null pointer to a tensor");
        return 0;
    }
    return QNN_TENSOR_GET_CLIENT_BUF(tensor).dataSize;
};
27
+
28
// Allocate a heap buffer of `tensorDataSize` bytes and attach it to the tensor
// as a RAW client buffer. Returns false on a null tensor or malloc failure.
// The buffer is released later via freeTensorBuffer().
bool ClientBuffer::allocateTensorBuffer(Qnn_Tensor_t* tensor, size_t tensorDataSize) {
    if (!tensor) {
        QNN_ERROR("Received nullptr for tensors");
        return false;
    }
    QNN_TENSOR_SET_MEM_TYPE(tensor, QNN_TENSORMEMTYPE_RAW);
    Qnn_ClientBuffer_t clientBuffer;
    clientBuffer.data = malloc(tensorDataSize);
    if (nullptr == clientBuffer.data) {
        QNN_ERROR("mem alloc failed for clientBuffer.data");
        return false;
    }
    clientBuffer.dataSize = tensorDataSize;
    QNN_TENSOR_SET_CLIENT_BUF(tensor, clientBuffer);
    return true;
}
44
+
45
+ bool ClientBuffer::freeTensorBuffer(Qnn_Tensor_t* tensor) {
46
+ if (!tensor) {
47
+ QNN_ERROR("Received nullptr for tensors");
48
+ return false;
49
+ }
50
+ if (QNN_TENSOR_GET_CLIENT_BUF(tensor).data) {
51
+ if (m_sameMemoryFreeTensors.find(tensor) == m_sameMemoryFreeTensors.end()) {
52
+ free(QNN_TENSOR_GET_CLIENT_BUF(tensor).data);
53
+ }
54
+ QNN_TENSOR_SET_CLIENT_BUF(tensor, Qnn_ClientBuffer_t({nullptr, 0u}));
55
+ QNN_TENSOR_SET_MEM_TYPE(tensor, QNN_TENSORMEMTYPE_UNDEFINED);
56
+ }
57
+ return true;
58
+ }
59
+
60
// Alias `dest` onto `src`'s client buffer: frees dest's own buffer first, then
// copies src's mem-type and buffer descriptor, and records dest as non-owning
// so a later freeTensorBuffer(dest) does not double-free src's memory.
bool ClientBuffer::useSameMemory(Qnn_Tensor_t* dest, Qnn_Tensor_t* src) {
    if (nullptr == dest || nullptr == src) {
        QNN_ERROR("Received nullptr");
        return false;
    }
    if (false == freeTensorBuffer(dest)) {
        return false;
    }

    QNN_TENSOR_SET_MEM_TYPE(dest, QNN_TENSOR_GET_MEM_TYPE(src));
    QNN_TENSOR_SET_CLIENT_BUF(dest, QNN_TENSOR_GET_CLIENT_BUF(src));
    m_sameMemoryFreeTensors.insert(dest);
    return true;
}
74
+
75
// Point `dest` at caller-owned memory `extMem`, keeping dest's current
// dataSize. The size is captured BEFORE freeTensorBuffer() zeroes the
// descriptor; dest is recorded as non-owning so the external memory is never
// free()d by this allocator.
bool ClientBuffer::useExternalMemory(Qnn_Tensor_t* dest, void* extMem) {
    if (nullptr == dest || nullptr == extMem) {
        QNN_ERROR("Received nullptr");
        return false;
    }

    Qnn_ClientBuffer_t clientBuffer;
    clientBuffer.data = extMem;
    // Capture the existing size before the descriptor is reset below.
    clientBuffer.dataSize = QNN_TENSOR_GET_CLIENT_BUF(dest).dataSize;
    if (false == freeTensorBuffer(dest)) {
        return false;
    }

    QNN_TENSOR_SET_MEM_TYPE(dest, QNN_TENSORMEMTYPE_RAW);
    QNN_TENSOR_SET_CLIENT_BUF(dest, clientBuffer);
    m_sameMemoryFreeTensors.insert(dest);
    return true;
}
93
+
94
// ---- Fused-buffer API --------------------------------------------------------
// Plain client (malloc) buffers do not support fused/shared allocations; these
// overrides are inert stubs that report failure or neutral values.

void* ClientBuffer::allocateTensorFusedBuffer(uint64_t bufferSize, int32_t* fd) {
    return nullptr;
}

bool ClientBuffer::mapFusedBufferOffset(
    Qnn_Tensor_t* tensor,
    size_t tensorDataSize,
    int32_t fd,
    uint32_t offset,
    uint64_t totalBufferSize,
    void* memPointer,
    Qnn_ContextHandle_t contextHandle
) {
    return false;
}

bool ClientBuffer::deregisterTensorFusedBuffer(Qnn_Tensor_t* tensor) {
    return false;
}

void ClientBuffer::freeFusedBuffers() {}

size_t ClientBuffer::getOffset(Qnn_Tensor_t* tensor) {
    return 0;
}

size_t ClientBuffer::getTotalBufferSize(Qnn_Tensor_t* tensor) {
    return 0;
}
Genie/Genie/src/qualla/engines/qnn-api/ClientBuffer.hpp ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ //==============================================================================
2
+ //
3
+ // Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4
+ // All Rights Reserved.
5
+ // Confidential and Proprietary - Qualcomm Technologies, Inc.
6
+ //
7
+ //==============================================================================
8
+
9
+ #pragma once
10
+
11
+ #include "IBufferAlloc.hpp"
12
+ #include "Log.hpp"
13
+ #include <unordered_set>
14
+ #include <stdlib.h>
15
+
16
// Heap (malloc/free) backed implementation of IBufferAlloc, used when the
// platform shared-memory allocator is not requested. Fused-buffer operations
// are stubbed out — see ClientBuffer.cpp.
class ClientBuffer final : public IBufferAlloc {
 public:
  ClientBuffer() {};

  // Disable copy constructors, r-value referencing, etc
  ClientBuffer(const ClientBuffer&) = delete;

  ClientBuffer& operator=(const ClientBuffer&) = delete;

  ClientBuffer(ClientBuffer&&) = delete;

  ClientBuffer& operator=(ClientBuffer&&) = delete;

  // No setup is required for plain heap allocation.
  bool initialize() override { return true; };

  // Raw data pointer backing `tensor` (see .cpp).
  void* getBuffer(Qnn_Tensor_t* tensor) override;

  // Heap buffers have no file descriptor; warns and returns -1.
  int getFd(Qnn_Tensor_t* tensor) override {
    QNN_WARN("getFd: This is not ION memory");
    return -1;
  };

  size_t getOffset(Qnn_Tensor_t* tensor) override;
  size_t getBufferSize(Qnn_Tensor_t* tensor) override;
  size_t getTotalBufferSize(Qnn_Tensor_t* tensor) override;

  // malloc() a data buffer of tensorDataSize bytes and attach it to `tensor`.
  bool allocateTensorBuffer(Qnn_Tensor_t* tensor, size_t tensorDataSize) override;

  bool freeTensorBuffer(Qnn_Tensor_t* tensor) override;

  // Alias `dest` onto the buffer backing `src`. The offset overload is not
  // supported for heap buffers and always fails.
  bool useSameMemory(Qnn_Tensor_t* dest, Qnn_Tensor_t* src) override;
  bool useSameMemory(Qnn_Tensor_t* dest, Qnn_Tensor_t* src, int offset) override { return false; }

  // Back `dest` with caller-owned memory (never freed by this class).
  bool useExternalMemory(Qnn_Tensor_t* dest, void* extMem) override;

  // --- Fused-buffer API: unsupported for heap allocation; all stubs. ---
  void* allocateTensorFusedBuffer(uint64_t bufferSize, int32_t* fd) override;
  bool allocateBuffers(
      const std::map<int, std::map<std::string, size_t>>& allocs_per_chunk,
      std::map<std::string, std::pair<int, size_t>>& tensor_offsets
  ) override {
    return false;
  };

  bool mapFusedBufferOffset(
      Qnn_Tensor_t* tensor,
      size_t tensorDataSize,
      int32_t fd,
      uint32_t offset,
      uint64_t totalBufferSize,
      void* memPointer,
      Qnn_ContextHandle_t contextHandle
  ) override;
  bool deregisterTensorFusedBuffer(Qnn_Tensor_t* tensor) override;
  void freeFusedBuffers() override;

  bool mapFusedBufferOffset(
      Qnn_Tensor_t* tensor,
      int alloc_idx,
      size_t offset,
      Qnn_ContextHandle_t ctx,
      size_t size
  ) override {
    return false;
  }

  virtual ~ClientBuffer() {};

 private:
  // Tensors whose client buffer is aliased or externally owned; presumably
  // consulted by freeTensorBuffer() so these pointers are never free()d —
  // confirm against the .cpp implementation.
  std::unordered_set<Qnn_Tensor_t*> m_sameMemoryFreeTensors;
};
Genie/Genie/src/qualla/engines/qnn-api/IBackend.hpp ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ //==============================================================================
2
+ //
3
+ // Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4
+ // All Rights Reserved.
5
+ // Confidential and Proprietary - Qualcomm Technologies, Inc.
6
+ //
7
+ //==============================================================================
8
+
9
+ #pragma once
10
+
11
+ #include <map>
12
+ #include "ICommandLineManager.hpp"
13
+ #include "QnnBackend.h"
14
+ #include "QnnContext.h"
15
+ #include "QnnGraph.h"
16
+ #include "QnnLog.h"
17
+ #include "QnnTypeDef.hpp"
18
+ #include "QnnProfile.h"
19
+ #include "QnnDevice.h"
20
+
21
+ // Compile-time definition to check for QNN SDK features using the QNN API version
22
+ #define QUALLA_QNN_API_VERSION \
23
+ (QNN_API_VERSION_MAJOR * 10000 + QNN_API_VERSION_MINOR * 100 + QNN_API_VERSION_PATCH)
24
+
25
+ const uint32_t g_profilingLevelNotSet = 0;
26
+
27
+ enum class PerfProfile {
28
+ LOW_BALANCED,
29
+ BALANCED,
30
+ DEFAULT,
31
+ HIGH_PERFORMANCE,
32
+ SUSTAINED_HIGH_PERFORMANCE,
33
+ BURST,
34
+ EXTREME_POWER_SAVER,
35
+ LOW_POWER_SAVER,
36
+ POWER_SAVER,
37
+ HIGH_POWER_SAVER,
38
+ SYSTEM_SETTINGS,
39
+ NO_USER_INPUT,
40
+ CUSTOM,
41
+ INVALID
42
+ };
43
+
44
+ // This is the interface that enables backend specific extensions in qnn-net-run.
45
+ // It is designed as hooks in the timeline of various events in NetRun.
46
+ // Backends that intend to implement custom features through qnn-net-run will have
47
+ // to implement this interface and add functionality in appropriate methods depending
48
+ // on where/when the custom functionality needs to be exercised.
49
+ // These functions/hooks will be called through the IBackend interface from within
50
+ // qnn-net-run wherever necessary.
51
class IBackend {
 public:
  virtual ~IBackend() {}

  // Route backend logging through the host's QNN log callback, capped at
  // maxLogLevel.
  virtual bool setupLogging(QnnLog_Callback_t callback, QnnLog_Level_t maxLogLevel) = 0;

  // One-time setup; backendLibHandle is the loaded backend library handle.
  virtual bool initialize(void* backendLibHandle) = 0;

  virtual bool setPerfProfile(PerfProfile perfProfile) = 0;

  virtual QnnProfile_Level_t getProfilingLevel() = 0;

  // Load backend-extension settings from a configuration file path.
  virtual bool loadConfig(std::string configFile) = 0;

  virtual bool loadCommandLineArgs(std::shared_ptr<ICommandLineManager> clManager) = 0;

  // --- Lifecycle hooks -------------------------------------------------
  // Each before*/after* pair brackets the corresponding QNN API call in the
  // host. The before* hooks may hand back backend-specific custom configs
  // through the out-parameters. NOTE(review): presumably a false return
  // aborts the surrounding operation — confirm against the NetRun caller.

  virtual bool beforeBackendInitialize(
      QnnBackend_Config_t*** customConfigs,
      uint32_t* configCount
  ) = 0;

  virtual bool afterBackendInitialize() = 0;

  virtual bool beforeContextCreate(
      QnnContext_Config_t*** customConfigs,
      uint32_t* configCount
  ) = 0;

  virtual bool afterContextCreate() = 0;

  virtual bool beforeComposeGraphs(
      GraphConfigInfo_t*** customGraphConfigs,
      uint32_t* graphCount
  ) = 0;

  virtual bool afterComposeGraphs() = 0;

#if QUALLA_QNN_API_VERSION >= 21700
  // Only available from QNN API 2.17.0: lets the backend amend per-graph
  // configs just before finalization.
  virtual bool beforeGraphFinalizeUpdateConfig(
      const char* graphName,
      Qnn_GraphHandle_t graphHandle,
      QnnGraph_Config_t*** customConfigs,
      uint32_t* configCount
  ) = 0;
#endif

  virtual bool beforeGraphFinalize() = 0;

  virtual bool afterGraphFinalize() = 0;

  virtual bool beforeRegisterOpPackages() = 0;

  virtual bool afterRegisterOpPackages() = 0;

  virtual bool beforeExecute(
      const char* graphName,
      QnnGraph_Config_t*** customConfigs,
      uint32_t* configCount
  ) = 0;

  virtual bool afterExecute() = 0;

  virtual bool beforeContextFree() = 0;

  virtual bool afterContextFree() = 0;

  virtual bool beforeBackendTerminate() = 0;

  virtual bool afterBackendTerminate() = 0;

  // Hooks around deserializing a context from a cached binary.
  virtual bool beforeCreateFromBinary(
      QnnContext_Config_t*** customConfigs,
      uint32_t* configCount
  ) = 0;

  virtual bool afterCreateFromBinary() = 0;

#if QUALLA_QNN_API_VERSION >= 21700
  // Hooks around creating multiple contexts from a binary list; per-context
  // custom configs are keyed by context name, with a common fallback set.
  virtual bool beforeCreateContextsFromBinaryList(
      std::map<std::string, std::tuple<QnnContext_Config_t**, uint32_t>>*
          contextKeyToCustomConfigsMap,
      QnnContext_Config_t*** commonCustomConfigs,
      uint32_t* commonConfigCount
  ) = 0;

  virtual bool afterCreateContextsFromBinaryList() = 0;
#endif

  virtual bool beforeCreateDevice(QnnDevice_Config_t*** deviceConfigs, uint32_t* configCount) = 0;

  virtual bool afterCreateDevice() = 0;

  virtual bool beforeFreeDevice() = 0;

  virtual bool afterFreeDevice() = 0;
};
147
+
148
+ // These are the function types that the backend extensions shared library is
149
+ // expected to expose. The first function helps NetRun obtain a valid implementation
150
+ // of IBackend interface and the second is used to destroy the same interface at the end.
151
+ // The function names themselves are expected to be these strings:
152
+ // 1. "createBackendInterface"
153
+ // 2. "destroyBackendInterface"
154
+ // These functions need to be tagged with extern "C" and their symbols need to be exposed.
155
+ typedef IBackend* (*CreateBackendInterfaceFnType_t)();
156
+ typedef void (*DestroyBackendInterfaceFnType_t)(IBackend*);
Genie/Genie/src/qualla/engines/qnn-api/IBufferAlloc.hpp ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ //==============================================================================
2
+ //
3
+ // Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4
+ // All Rights Reserved.
5
+ // Confidential and Proprietary - Qualcomm Technologies, Inc.
6
+ //
7
+ //==============================================================================
8
+
9
+ #pragma once
10
+ #include "QnnTypes.h"
11
+ #include <map>
12
+ #include <string>
13
+ #include <vector>
14
+ #include <utility>
15
+ #include <unordered_map>
16
+
17
// Abstract allocator for QNN tensor data buffers. Concrete implementations
// include the plain heap allocator (ClientBuffer) and a shared-memory
// allocator; IOTensor selects between them at initialization.
class IBufferAlloc {
 public:
  virtual ~IBufferAlloc() {}
  IBufferAlloc() {}
  virtual bool initialize() = 0;
  // Raw data pointer backing `tensor`.
  virtual void* getBuffer(Qnn_Tensor_t* tensor) = 0;
  // File descriptor for shared-memory allocators; -1 where not applicable.
  virtual int getFd(Qnn_Tensor_t* tensor) = 0;
  // Byte offset of the tensor within its (possibly fused) allocation.
  virtual size_t getOffset(Qnn_Tensor_t* tensor) = 0;
  virtual size_t getBufferSize(Qnn_Tensor_t* tensor) = 0;
  // Size of the whole allocation containing this tensor (fused case).
  virtual size_t getTotalBufferSize(Qnn_Tensor_t* tensor) = 0;
  virtual bool allocateTensorBuffer(Qnn_Tensor_t* tensor, size_t tensorDataSize) = 0;
  virtual bool freeTensorBuffer(Qnn_Tensor_t* tensor) = 0;
  // Alias `dest` onto the buffer already backing `src` (optionally at a
  // byte offset into it).
  virtual bool useSameMemory(Qnn_Tensor_t* dest, Qnn_Tensor_t* src) = 0;
  virtual bool useSameMemory(Qnn_Tensor_t* dest, Qnn_Tensor_t* src, int offset) = 0;
  // Back `dest` with caller-owned memory (allocator must not free it).
  virtual bool useExternalMemory(Qnn_Tensor_t* dest, void* extMem) = 0;
  // --- Fused-buffer API: one shared allocation sliced across tensors. ---
  virtual void* allocateTensorFusedBuffer(uint64_t bufferSize, int32_t* fd) = 0;
  // Allocate one chunk per index in allocs_per_chunk and report each
  // tensor's (chunk index, byte offset) in tensor_offsets.
  virtual bool allocateBuffers(
      const std::map<int, std::map<std::string, size_t>>& allocs_per_chunk,
      std::map<std::string, std::pair<int, size_t>>& tensor_offsets
  ) = 0;
  virtual bool mapFusedBufferOffset(
      Qnn_Tensor_t* tensor,
      size_t tensorDataSize,
      int32_t fd,
      uint32_t offset,
      uint64_t totalBufferSize,
      void* memPointer,
      Qnn_ContextHandle_t contextHandle
  ) = 0;
  virtual bool mapFusedBufferOffset(
      Qnn_Tensor_t* tensor,
      int alloc_idx,
      size_t offset,
      Qnn_ContextHandle_t ctx,
      size_t size
  ) = 0;

  virtual bool deregisterTensorFusedBuffer(Qnn_Tensor_t* tensor) = 0;
  virtual void freeFusedBuffers() = 0;
};
Genie/Genie/src/qualla/engines/qnn-api/ICommandLineManager.hpp ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ //==============================================================================
2
+ //
3
+ // Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4
+ // All Rights Reserved.
5
+ // Confidential and Proprietary - Qualcomm Technologies, Inc.
6
+ //
7
+ //==============================================================================
8
+
9
+ #pragma once
10
+
11
+ #include <cctype>
12
+ #include <memory>
13
+ #include <string>
14
+ #include <tuple>
15
+ #include <vector>
16
+
17
// Interface for parsing command line arguments of the form "--key=value" or
// "--key" and serving them to consumers exactly once. The static helpers
// below implement the shared key/value grammar and are usable without a
// concrete implementation.
class ICommandLineManager {
 public:
  enum class Error { SUCCESS, PARSE_FAILURE, UNUSED_ARGUMENTS, OVER_SUBSCRIBED_ARGUMENTS };

  using ValueList_t = std::vector<std::shared_ptr<const std::string>>;

  /**
   * @brief Parses provided command line arguments into key value pairs
   *
   * @param[in] argc Number of char* arguments in argv
   *
   * @param[in] argv Pointer to first element of null terminated character arrays
   *
   * @return Error code:
   *         - SUCCESS: provided command line arguments match expected format: --key=value, --key
   *         - PARSE_FAILURE: The provided command line arguments do not match expected format
   *
   */
  virtual Error parseClArgs(size_t argc, char** argv) = 0;

  /**
   * @brief Provides passed values for requested key if available
   *
   * @param[in] key Key string of option
   *
   * @return (False, empty) if key is not an available argument
   *
   */
  virtual std::tuple<bool, ValueList_t> serveArg(const std::string& key) = 0;

  /**
   * @brief Checks whether any provided commandline arguments remain unserved
   *
   * @return True if unconsumed arguments remain, False otherwise
   */
  virtual bool allArgumentsServed() const = 0;

  /**
   * @brief Validates command line arguments were correctly utilized
   *
   * @return Error code:
   *         - SUCCESS: provided command line arguments were utilized following implementations
   *           policy
   *         - UNUSED_ARGUMENTS: Some arguments passed were not consumed
   *         - OVER_SUBSCRIBED_ARGUMENTS: Some arguments were requested by multiple times
   *
   */
  virtual Error validateUsage() = 0;

  virtual ~ICommandLineManager() = default;

  // True when `arg` looks like an option key: the "--" prefix immediately
  // followed by an alphabetic character.
  static bool isKey(const std::string& arg) {
    // rfind(prefix, 0) anchors the prefix match at position 0 without
    // scanning the rest of the string (find would). The cast to unsigned
    // char is required: passing a plain char with a negative value to
    // std::isalpha is undefined behavior.
    return (arg.length() > keyPrefix().length()) && (arg.rfind(keyPrefix(), 0) == 0) &&
           std::isalpha(static_cast<unsigned char>(arg.at(keyPrefix().length())));
  }

  // Extract the key portion of `arg` ("--key=value" -> "--key"; a bare
  // "--key" is returned unchanged). Fails if `arg` is not a key at all.
  static Error parseKey(const std::string& arg, std::string& keyOut) {
    if (!isKey(arg)) {
      return Error::PARSE_FAILURE;
    }

    auto valueSplit = arg.find(keyValueSplit());
    keyOut = valueSplit != arg.npos ? arg.substr(0, valueSplit) : arg;
    return Error::SUCCESS;
  }

  // Extract the value portion of "--key=value". Fails when there is no '='
  // or when the value is empty ("--key=").
  static Error parseValue(const std::string& arg, std::string& valueOut) {
    auto valueSplit = arg.find(keyValueSplit());
    if (valueSplit == arg.npos || valueSplit == arg.length() - 1) {
      return Error::PARSE_FAILURE;
    }
    valueOut = arg.substr(valueSplit + 1);
    return Error::SUCCESS;
  }

 private:
  static const std::string keyPrefix() { return "--"; };
  static char keyValueSplit() { return '='; };
};
Genie/Genie/src/qualla/engines/qnn-api/IOTensor.cpp ADDED
@@ -0,0 +1,382 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ //==============================================================================
2
+ //
3
+ // Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4
+ // All Rights Reserved.
5
+ // Confidential and Proprietary - Qualcomm Technologies, Inc.
6
+ //
7
+ //==============================================================================
8
+ #include <cstring>
9
+ #include <fstream>
10
+ #include <iostream>
11
+
12
+ #include "ClientBuffer.hpp"
13
+ #include "IBufferAlloc.hpp"
14
+ #include "IOTensor.hpp"
15
+ #include "RpcMem.hpp"
16
+ #include "QnnTypeMacros.hpp"
17
+
18
+ #ifdef _WIN32
19
+ #define __strdup _strdup
20
+ #else
21
+ #define __strdup strdup
22
+ #endif
23
+
24
// Construct with the requested allocator strategy. A malloc-based
// ClientBuffer is installed up front; it is swapped for the shared-memory
// allocator in initialize() when SHARED_BUFFER was requested.
IOTensor::IOTensor(BufferAlloc bufferAllocIn, QNN_INTERFACE_VER_TYPE* qnnInterface)
    : m_bufferAlloc(bufferAllocIn), m_qnnInterface(qnnInterface),
      m_bufferManager(new ClientBuffer()) {}

// Finish allocator setup. In SHARED_BUFFER mode the default ClientBuffer is
// replaced with an RpcMem allocator bound to contextHandle before being
// initialized. Returns false if the selected allocator fails to initialize.
bool IOTensor::initialize(Qnn_ContextHandle_t contextHandle) {
  if (m_bufferAlloc == BufferAlloc::SHARED_BUFFER) {
    m_bufferManager = std::unique_ptr<IBufferAlloc>(new RpcMem(contextHandle, m_qnnInterface));
  }

  if (true != m_bufferManager->initialize()) {
    QNN_ERROR("Failed to initialize buffer manager");
    return false;
  }

  return true;
}

// Shared-buffer mode owns fused allocations that must be released before the
// allocator itself is destroyed; the default allocator has nothing to do.
IOTensor::~IOTensor() {
  if (m_bufferAlloc == BufferAlloc::SHARED_BUFFER) {
    m_bufferManager->freeFusedBuffers();
  }
}
46
+
47
// Setup details for Qnn_Tensor_t for execution
// based on information in TensorWrapper provided by model.so.
//
// Deep-copies each wrapper's tensor metadata into a freshly calloc'd array
// (*tensors), attaches a data buffer to each element — either a per-tensor
// heap buffer, or (SHARED_BUFFER mode) a slice of one fused allocation — and
// records name -> Qnn_Tensor_t* in tensorNameToTensorPointer.
// With skipBufferAllocation set, shared-buffer mode copies metadata only and
// leaves the data mapping to a later mapFusedBufferOffset() call.
// Returns false after tearing down the tensors built so far on any failure.
bool IOTensor::setupTensors(
    Qnn_Tensor_t** tensors,
    std::unordered_map<std::string, void*>& tensorNameToTensorPointer,
    uint32_t tensorCount,
    TensorWrapper* tensorWrappers,
    std::unordered_map<std::string, size_t>& tensorsSize,
    Qnn_ContextHandle_t contextHandle,
    bool skipBufferAllocation
) {

  if (nullptr == tensorWrappers) {
    QNN_ERROR("tensorWrappers is nullptr");
    return false;
  }
  if (0 == tensorCount) {
    QNN_DEBUG("tensor count is 0. Nothing to setup.");
    return true;
  }

  // Zero-initialized so partially-built entries are safe to tear down.
  *tensors = (Qnn_Tensor_t*)calloc(1, tensorCount * sizeof(Qnn_Tensor_t));
  if (nullptr == *tensors) {
    QNN_ERROR("mem alloc failed for *tensors");
    return false;
  }

  auto returnStatus = true;

  uint64_t totalBufferSize = 0;
  void* memPointer = nullptr;
  int32_t fd = -1;
  if (m_bufferAlloc == BufferAlloc::SHARED_BUFFER) {
    // Calculate the total size of the tensors
    for (size_t tensorIdx = 0; tensorIdx < tensorCount; tensorIdx++) {
      auto wrapperTensorName =
          std::string(GET_TENSOR_WRAPPER_NAME(tensorWrappers[tensorIdx]));
      totalBufferSize += tensorsSize[wrapperTensorName];
    }
    QNN_DEBUG("Calculated total size %lu", totalBufferSize);

    if (!skipBufferAllocation) {
      // Allocate the buffer of this size
      memPointer = m_bufferManager->allocateTensorFusedBuffer(totalBufferSize, &fd);
      if (memPointer) {
        QNN_DEBUG(
            "Successfully allocated a buffer of size %lu, pointer %p, fd %d",
            (unsigned long)totalBufferSize,
            memPointer,
            fd
        );
      } else {
        QNN_ERROR(
            "Not able to allocate buffer of size %lu", (unsigned long)totalBufferSize
        );
        return false;
      }
    }
  }

  // Running byte offset of each tensor within the fused allocation.
  uint64_t offset = 0;

  for (size_t tensorIdx = 0; tensorIdx < tensorCount; tensorIdx++) {
    Qnn_Tensor_t wrapperTensor = GET_TENSOR_WRAPPER_TENSOR(tensorWrappers[tensorIdx]);
    auto wrapperTensorName = std::string(GET_TENSOR_WRAPPER_NAME(tensorWrappers[tensorIdx]));
    if (true == returnStatus) {
      (*tensors)[tensorIdx] = QNN_TENSOR_INIT;
      returnStatus = deepCopyQnnTensorInfo(((*tensors) + tensorIdx), &wrapperTensor);
    }
    if (true == returnStatus) {
      size_t tensorDataSize = tensorsSize[wrapperTensorName];
      if (m_bufferAlloc == BufferAlloc::SHARED_BUFFER) {
        if (!skipBufferAllocation) {
          returnStatus = m_bufferManager->mapFusedBufferOffset(
              ((*tensors) + tensorIdx),
              tensorDataSize,
              fd,
              offset,
              totalBufferSize,
              memPointer,
              contextHandle
          );
          offset += tensorDataSize;
        }
      } else {
        returnStatus = m_bufferManager->allocateTensorBuffer(
            ((*tensors) + tensorIdx), tensorDataSize
        );
      }
    }
    if (true != returnStatus) {
      QNN_ERROR("Failure in setupTensors, cleaning up resources");
      // NOTE(review): only tensors [0, tensorIdx) are torn down here; if the
      // current tensor's deepCopyQnnTensorInfo succeeded but its buffer setup
      // failed, its strdup'd name and dimension array leak — confirm.
      tearDownTensors(*tensors, tensorIdx);
      *tensors = nullptr;
      QNN_ERROR("Failure in setupTensors, done cleaning up resources");
      return false;
    } else {
      tensorNameToTensorPointer.insert({wrapperTensorName, ((*tensors) + tensorIdx)});
      // QNN_DEBUG("allocateBuffer successful");
    }
  }

  return returnStatus;
}
151
+
152
+ // Setup details for all input tensors for graph execution.
153
+ bool IOTensor::setupInputTensors(
154
+ Qnn_Tensor_t** inputs,
155
+ std::unordered_map<std::string, void*>& tensorNameToTensorPointer,
156
+ const GraphInfo_t& graphInfo,
157
+ std::unordered_map<std::string, size_t>& inputTensorsSize,
158
+ Qnn_ContextHandle_t contextHandle,
159
+ bool skipBufferAllocation
160
+ ) {
161
+
162
+ if (true != setupTensors(
163
+ inputs,
164
+ tensorNameToTensorPointer,
165
+ graphInfo.numInputTensors,
166
+ (graphInfo.inputTensors),
167
+ inputTensorsSize,
168
+ contextHandle,
169
+ skipBufferAllocation
170
+ )) {
171
+ QNN_ERROR("Failure in setupInputTensors, cleaning up resources");
172
+ if (nullptr != *inputs) {
173
+ QNN_DEBUG("cleaning up input tensors");
174
+ tearDownTensors(*inputs, graphInfo.numInputTensors);
175
+ *inputs = nullptr;
176
+ }
177
+ QNN_ERROR("Failure in setupInputTensors, done cleaning up resources");
178
+
179
+ return false;
180
+ }
181
+
182
+ return true;
183
+ }
184
+
185
+ // Setup details for all output tensors for graph execution.
186
+ bool IOTensor::setupOutputTensors(
187
+ Qnn_Tensor_t** outputs,
188
+ std::unordered_map<std::string, void*>& tensorNameToTensorPointer,
189
+ const GraphInfo_t& graphInfo,
190
+ std::unordered_map<std::string, size_t>& outputTensorsSize,
191
+ Qnn_ContextHandle_t contextHandle,
192
+ bool skipBufferAllocation
193
+ ) {
194
+
195
+ if (true != setupTensors(
196
+ outputs,
197
+ tensorNameToTensorPointer,
198
+ graphInfo.numOutputTensors,
199
+ (graphInfo.outputTensors),
200
+ outputTensorsSize,
201
+ contextHandle,
202
+ skipBufferAllocation
203
+ )) {
204
+ QNN_ERROR("Failure in setupOutputTensors, cleaning up resources");
205
+ if (nullptr != *outputs) {
206
+ QNN_DEBUG("cleaning up output tensors");
207
+ tearDownTensors(*outputs, graphInfo.numOutputTensors);
208
+ *outputs = nullptr;
209
+ }
210
+ QNN_ERROR("Failure in setupOutputTensors, done cleaning up resources");
211
+
212
+ return false;
213
+ }
214
+
215
+ return true;
216
+ }
217
+
218
// Map pre-allocated fused-buffer slices onto every input and output tensor
// of `graph_info`. graph_allocs maps tensor name -> (allocation index, byte
// offset, size); tensors absent from the map are skipped. Returns true only
// if every mapping attempt succeeded.
bool IOTensor::mapFusedBufferOffset(
    GraphInfo_t* graph_info,
    Qnn_ContextHandle_t context_handle,
    const std::map<std::string, std::tuple<int, size_t, size_t>>& graph_allocs
) {
  std::lock_guard lk(_tmp_lock);  // READ COMMENT IN IOTensor.hpp _tmp_lock

  bool ret = true;
  // First pass over inputs (mode == true), second over outputs.
  for (const bool mode : {true, false}) {
    TensorWrapper* tensor_bank = (mode) ? graph_info->inputTensors : graph_info->outputTensors;
    uint32_t num_tensors = (mode) ? graph_info->numInputTensors : graph_info->numOutputTensors;

    for (size_t tidx = 0; tidx < num_tensors; tidx++) {
      TensorWrapper& tensor_wrapper = tensor_bank[tidx];

      Qnn_Tensor_t* tensor = &GET_TENSOR_WRAPPER_TENSOR(tensor_wrapper);
      std::string tensor_name = std::string(GET_TENSOR_WRAPPER_NAME(tensor_wrapper));

      if (!graph_allocs.contains(tensor_name)) continue;
      auto& [alloc_idx, offset, size] = graph_allocs.at(tensor_name);
      // Accumulate success across all tensors; a single failure flips ret.
      ret &= m_bufferManager->mapFusedBufferOffset(
          tensor, alloc_idx, offset, context_handle, size
      );
    }
  }

  return ret;
}
246
+
247
// Clean up all tensors related data after execution.
// Frees each tensor's dimension array and its data buffer (or deregisters
// its fused-buffer mapping in SHARED_BUFFER mode), then frees the array.
// NOTE(review): the name string strdup'd by deepCopyQnnTensorInfo is not
// freed here — presumably because this is also invoked on wrapper tensors
// whose names are not heap-owned by us; confirm, since it leaks for tensors
// built by setupTensors.
bool IOTensor::tearDownTensors(Qnn_Tensor_t* tensors, uint32_t tensorCount) {

  if (nullptr != tensors) {
    QNN_DEBUG("cleaning up resources for tensors");
    for (size_t tensorIdx = 0; tensorIdx < tensorCount; tensorIdx++) {
      // QNN_DEBUG("freeing resources for tensor: %zu", tensorIdx);
      if (nullptr != QNN_TENSOR_GET_DIMENSIONS(&tensors[tensorIdx])) {
        // QNN_DEBUG("freeing maxDimensions");
        free(QNN_TENSOR_GET_DIMENSIONS(&tensors[tensorIdx]));
      }
      if (m_bufferAlloc == BufferAlloc::SHARED_BUFFER) {
        m_bufferManager->deregisterTensorFusedBuffer(&(tensors[tensorIdx]));
      } else {
        m_bufferManager->freeTensorBuffer(&(tensors[tensorIdx]));
      }
      // Record the now-dead pointer so callers can detect already-torn-down
      // tensors via getFreeTensorsPointerSet().
      m_freeTensorsPointerSet.insert(&(tensors[tensorIdx]));
    }
    free(tensors);
    // Note: assigns the local parameter only; the caller's pointer is left
    // dangling and must be nulled by the caller.
    tensors = nullptr;
  }

  return true;
}

// Clean up all tensors after execution.
// Tears down each array in `tensors`, assuming all share the same count.
bool IOTensor::tearDownTensors(std::vector<Qnn_Tensor_t*>& tensors, uint32_t numTensors) {

  for (Qnn_Tensor_t* tensor : tensors) {
    tearDownTensors(tensor, numTensors);
  }

  return true;
}

// Convenience overload for a contiguous vector of tensors.
bool IOTensor::tearDownTensors(std::vector<Qnn_Tensor_t>& tensors) {
  return tearDownTensors(tensors.data(), tensors.size());
}

// Clean up all tensors after execution.
// Each named tensor array is torn down with its count from tensorCountMap.
bool IOTensor::tearDownTensors(
    std::unordered_map<std::string, Qnn_Tensor_t*>& tensors,
    std::unordered_map<std::string, uint32_t>& tensorCountMap
) {

  for (auto& tensor : tensors) {
    tearDownTensors(tensor.second, tensorCountMap[tensor.first]);
  }

  return true;
}

// Clean up all tensors after execution.
// Applies the map overload to every map in the vector.
bool IOTensor::tearDownTensors(
    std::vector<std::unordered_map<std::string, Qnn_Tensor_t*>>& tensors,
    std::unordered_map<std::string, uint32_t>& tensorCountMap
) {

  for (auto& tensor : tensors) {
    tearDownTensors(tensor, tensorCountMap);
  }

  return true;
}
311
+
312
// Deep-copy a tensor's metadata (name, id, type, format, quantization
// params, rank, dimensions) from src into dest. Client-buffer/data contents
// are NOT copied; the caller attaches a buffer afterwards. The name is
// strdup'd and per-axis scale/offset tables and dimensions are malloc'd, so
// dest owns heap memory after a successful call.
bool IOTensor::deepCopyQnnTensorInfo(Qnn_Tensor_t* dest, Qnn_Tensor_t* src) {

  if (nullptr == dest || nullptr == src) {
    QNN_ERROR("Received nullptr");
    return false;
  }

  // set tensor.version before using QNN_TENSOR_SET macros, as they require the version to be set
  // to correctly assign values
  dest->version = src->version;
  const char* tensorName = QNN_TENSOR_GET_NAME(src);
  if (!tensorName) {
    QNN_TENSOR_SET_NAME(dest, nullptr);
  } else {
    // Heap-owned copy; see tearDownTensors for the corresponding (missing)
    // free — the copy outlives src's wrapper.
    QNN_TENSOR_SET_NAME(dest, __strdup(tensorName));
  }
  QNN_TENSOR_SET_ID(dest, QNN_TENSOR_GET_ID(src));
  QNN_TENSOR_SET_TYPE(dest, QNN_TENSOR_GET_TYPE(src));
  QNN_TENSOR_SET_DATA_FORMAT(dest, QNN_TENSOR_GET_DATA_FORMAT(src));
  QNN_TENSOR_SET_DATA_TYPE(dest, QNN_TENSOR_GET_DATA_TYPE(src));
  // Only per-tensor and per-axis scale/offset encodings are copied; any
  // other encoding degrades to UNDEFINED.
  Qnn_QuantizeParams_t qParams = QNN_QUANTIZE_PARAMS_INIT;
  qParams.encodingDefinition = QNN_TENSOR_GET_QUANT_PARAMS(src).encodingDefinition;
  qParams.quantizationEncoding = QNN_QUANTIZATION_ENCODING_UNDEFINED;
  if (QNN_TENSOR_GET_QUANT_PARAMS(src).quantizationEncoding ==
      QNN_QUANTIZATION_ENCODING_SCALE_OFFSET) {
    qParams.quantizationEncoding = QNN_TENSOR_GET_QUANT_PARAMS(src).quantizationEncoding;
    qParams.scaleOffsetEncoding = QNN_TENSOR_GET_QUANT_PARAMS(src).scaleOffsetEncoding;
  } else if (QNN_TENSOR_GET_QUANT_PARAMS(src).quantizationEncoding ==
             QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET) {
    qParams.quantizationEncoding = QNN_TENSOR_GET_QUANT_PARAMS(src).quantizationEncoding;
    qParams.axisScaleOffsetEncoding.axis =
        QNN_TENSOR_GET_QUANT_PARAMS(src).axisScaleOffsetEncoding.axis;
    qParams.axisScaleOffsetEncoding.numScaleOffsets =
        QNN_TENSOR_GET_QUANT_PARAMS(src).axisScaleOffsetEncoding.numScaleOffsets;
    if (QNN_TENSOR_GET_QUANT_PARAMS(src).axisScaleOffsetEncoding.numScaleOffsets > 0) {
      qParams.axisScaleOffsetEncoding.scaleOffset = (Qnn_ScaleOffset_t*)malloc(
          QNN_TENSOR_GET_QUANT_PARAMS(src).axisScaleOffsetEncoding.numScaleOffsets *
          sizeof(Qnn_ScaleOffset_t)
      );
      // NOTE(review): a failed malloc is silently skipped here — dest then
      // claims numScaleOffsets entries with a null scaleOffset pointer;
      // confirm downstream consumers tolerate that.
      if (qParams.axisScaleOffsetEncoding.scaleOffset) {
        for (size_t idx = 0;
             idx < QNN_TENSOR_GET_QUANT_PARAMS(src).axisScaleOffsetEncoding.numScaleOffsets;
             idx++) {
          qParams.axisScaleOffsetEncoding.scaleOffset[idx].scale =
              QNN_TENSOR_GET_QUANT_PARAMS(src)
                  .axisScaleOffsetEncoding.scaleOffset[idx]
                  .scale;
          qParams.axisScaleOffsetEncoding.scaleOffset[idx].offset =
              QNN_TENSOR_GET_QUANT_PARAMS(src)
                  .axisScaleOffsetEncoding.scaleOffset[idx]
                  .offset;
        }
      }
    }
  }
  QNN_TENSOR_SET_QUANT_PARAMS(dest, qParams);
  QNN_TENSOR_SET_RANK(dest, QNN_TENSOR_GET_RANK(src));
  QNN_TENSOR_SET_DIMENSIONS(dest, nullptr);
  if (QNN_TENSOR_GET_RANK(src) > 0) {
    QNN_TENSOR_SET_DIMENSIONS(
        dest, (uint32_t*)malloc(QNN_TENSOR_GET_RANK(src) * sizeof(uint32_t))
    );
    // NOTE(review): on malloc failure dest keeps rank > 0 with null
    // dimensions and the function still returns true — confirm callers
    // handle this (or treat it as should-fail).
    if (QNN_TENSOR_GET_DIMENSIONS(dest)) {
      memcpy(QNN_TENSOR_GET_DIMENSIONS(dest),
             QNN_TENSOR_GET_DIMENSIONS(src),
             QNN_TENSOR_GET_RANK(src) * sizeof(uint32_t));
    }
  }

  return true;
}
Genie/Genie/src/qualla/engines/qnn-api/IOTensor.hpp ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ //==============================================================================
2
+ //
3
+ // Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4
+ // All Rights Reserved.
5
+ // Confidential and Proprietary - Qualcomm Technologies, Inc.
6
+ //
7
+ //==============================================================================
8
+ #pragma once
9
+
10
+ #include <memory>
11
+ #include <queue>
12
+ #include <unordered_map>
13
+ #include <unordered_set>
14
+ #include <vector>
15
+ #include <mutex>
16
+
17
+ #include "IBufferAlloc.hpp"
18
+ #include "QnnTypeDef.hpp"
19
+ #include "Log.hpp"
20
+ #include "QnnBackend.h"
21
+ #include "QnnCommon.h"
22
+ #include "QnnContext.h"
23
+ #include "QnnGraph.h"
24
+ #include "QnnInterface.h"
25
+ #include "QnnProperty.h"
26
+ #include "QnnTensor.h"
27
+ #include "QnnTypes.h"
28
// Strategy for allocating tensor data buffers.
enum class BufferAlloc {
  DEFAULT,        // malloc based allocator
  SHARED_BUFFER,  // shared buffer allocator; actual allocator depends on the platform
  INVALID
};
class IBufferAlloc;
// Owns the allocation, mapping, and teardown of QNN graph input/output
// tensors, delegating raw memory management to an IBufferAlloc chosen by
// the BufferAlloc strategy (heap by default, shared memory on request).
class IOTensor {
 public:
  IOTensor(
      BufferAlloc bufferAllocIn = BufferAlloc::DEFAULT,
      QNN_INTERFACE_VER_TYPE* qnnInterface = nullptr
  );

  ~IOTensor();

  // Must be called before use; binds the shared-memory allocator to
  // contextHandle when SHARED_BUFFER mode is selected.
  bool initialize(Qnn_ContextHandle_t contextHandle = nullptr);

  // Allocate and register all input tensors of graphInfo; sizes are looked
  // up by tensor name in inputTensorsSize.
  bool setupInputTensors(
      Qnn_Tensor_t** inputs,
      std::unordered_map<std::string, void*>& tensorNameToTensorPointer,
      const GraphInfo_t& graphInfo,
      std::unordered_map<std::string, size_t>& inputTensorsSize,
      Qnn_ContextHandle_t contextHandle,
      bool skipBufferAllocation = false
  );

  // Allocate and register all output tensors of graphInfo.
  bool setupOutputTensors(
      Qnn_Tensor_t** outputs,
      std::unordered_map<std::string, void*>& tensorNameToTensorPointer,
      const GraphInfo_t& graphInfo,
      std::unordered_map<std::string, size_t>& outputTensorsSize,
      Qnn_ContextHandle_t contextHandle,
      bool skipBufferAllocation = false
  );

  // Teardown overloads: free dimension arrays and data buffers (or
  // deregister fused mappings) for every tensor, then the arrays.
  bool tearDownTensors(Qnn_Tensor_t* tensors, uint32_t tensorCount);

  bool tearDownTensors(std::vector<Qnn_Tensor_t*>& tensors, uint32_t tensorCount);
  bool tearDownTensors(std::vector<Qnn_Tensor_t>& tensors);
  bool tearDownTensors(
      std::unordered_map<std::string, Qnn_Tensor_t*>& tensors,
      std::unordered_map<std::string, uint32_t>& tensorCountMap
  );
  bool tearDownTensors(
      std::vector<std::unordered_map<std::string, Qnn_Tensor_t*>>& tensors,
      std::unordered_map<std::string, uint32_t>& tensorCountMap
  );

  // Tear down both banks of a graph; returns false if either side fails.
  bool tearDownTensors(const GraphInfo_t* graph_info) {
    bool status = true;
    if (!tearDownTensors(graph_info->inputTensors, graph_info->numInputTensors)) {
      status = false;
      QNN_ERROR("Failed to tear down input tensors for graph %s", graph_info->graphName);
    }

    if (!tearDownTensors(graph_info->outputTensors, graph_info->numOutputTensors)) {
      status = false;
      QNN_ERROR("Failed to tear down output tensors for graph %s", graph_info->graphName);
    }
    return status;
  }

  // --- Thin forwarders to the underlying IBufferAlloc. ---
  void* getBuffer(Qnn_Tensor_t* tensor) { return m_bufferManager->getBuffer(tensor); };

  int getFd(Qnn_Tensor_t* tensor) { return m_bufferManager->getFd(tensor); };

  size_t getOffset(Qnn_Tensor_t* tensor) { return m_bufferManager->getOffset(tensor); };

  size_t getBufferSize(Qnn_Tensor_t* tensor) { return m_bufferManager->getBufferSize(tensor); };

  size_t getTotalBufferSize(Qnn_Tensor_t* tensor) {
    return m_bufferManager->getTotalBufferSize(tensor);
  }

  void* allocateTensorFusedBuffer(uint64_t bufferSize, int32_t* fd) {
    return m_bufferManager->allocateTensorFusedBuffer(bufferSize, fd);
  }

  bool allocateBuffers(
      const std::map<int, std::map<std::string, size_t>>& allocs_per_chunk,
      std::map<std::string, std::pair<int, size_t>>& tensor_offsets
  ) {
    return m_bufferManager->allocateBuffers(allocs_per_chunk, tensor_offsets);
  }

  bool mapFusedBufferOffset(
      Qnn_Tensor_t* tensor,
      size_t tensorDataSize,
      int32_t fd,
      uint32_t offset,
      uint64_t totalBufferSize,
      void* memPointer,
      Qnn_ContextHandle_t contextHandle
  ) {
    return m_bufferManager->mapFusedBufferOffset(
        tensor, tensorDataSize, fd, offset, totalBufferSize, memPointer, contextHandle
    );
  }

  // Map pre-allocated fused slices onto all tensors of a graph (see .cpp).
  bool mapFusedBufferOffset(
      GraphInfo_t* graph_info,
      Qnn_ContextHandle_t context_handle,
      const std::map<std::string, std::tuple<int, size_t, size_t>>& graph_allocs
  );

  bool useSameMemory(Qnn_Tensor_t* dest, Qnn_Tensor_t* src) {
    return m_bufferManager->useSameMemory(dest, src);
  }

  bool useSameMemory(Qnn_Tensor_t* dest, Qnn_Tensor_t* src, int offset) {
    return m_bufferManager->useSameMemory(dest, src, offset);
  }

  bool useExternalMemory(Qnn_Tensor_t* dest, void* extMem) {
    return m_bufferManager->useExternalMemory(dest, extMem);
  }

  BufferAlloc getBufferAllocType() { return m_bufferAlloc; }

  // Pointers recorded by tearDownTensors; lets callers detect tensors that
  // have already been destroyed.
  std::unordered_set<void*>& getFreeTensorsPointerSet() { return m_freeTensorsPointerSet; }

 private:
  BufferAlloc m_bufferAlloc;
  QNN_INTERFACE_VER_TYPE* m_qnnInterface;
  std::unique_ptr<IBufferAlloc> m_bufferManager;
  std::unordered_set<void*> m_freeTensorsPointerSet;

  // There seems to be a race condition in mapFusedBufferOffset because we are
  // calling it from multiple threads. Maybe memRegister/memDeRegister is not thread-safe
  // Until I figure this out, adding a temporary lock here. TODO: Fix and remove this!
  std::mutex _tmp_lock;

  bool deepCopyQnnTensorInfo(Qnn_Tensor_t* dest, Qnn_Tensor_t* src);
  bool setupTensors(
      Qnn_Tensor_t** tensors,
      std::unordered_map<std::string, void*>& tensorNameToTensorPointer,
      uint32_t tensorCount,
      TensorWrapper* tensorsInfo,
      std::unordered_map<std::string, size_t>& tensorsSize,
      Qnn_ContextHandle_t contextHandle,
      bool skipBufferAllocation = false
  );
};
Genie/Genie/src/qualla/engines/qnn-api/Log.hpp ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
//==============================================================================
//
//  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
//  All Rights Reserved.
//  Confidential and Proprietary - Qualcomm Technologies, Inc.
//
//==============================================================================

#pragma once

#include <stdio.h>

// FIXME: Use logger from qualla::Env

// Lightweight stderr logging macros for the qnn-api layer.
//
// Fix: the previous definitions stringized the format argument (#fmt). Since
// every call site passes a quoted string literal (e.g.
// QNN_ERROR("Failed to allocate memory for tensorWrappers.")), stringizing
// re-quoted it and every log line was printed wrapped in literal '"'
// characters. The format literal is now spliced in directly via adjacent
// string-literal concatenation, so messages print as written.
// ##__VA_ARGS__ (GNU/MSVC extension, as before) swallows the comma when no
// varargs are supplied.
#define QNN_INFO(fmt, ...) fprintf(stderr, "[INFO] " fmt "\n", ##__VA_ARGS__)
#define QNN_ERROR(fmt, ...) fprintf(stderr, "[ERROR] " fmt "\n", ##__VA_ARGS__)
#define QNN_WARN(fmt, ...) fprintf(stderr, "[WARN] " fmt "\n", ##__VA_ARGS__)

// Debug logging is compiled out by default; flip the `#if 0` to enable.
#if 0
// #define NSP_LOG_LEVEL 2
#define QNN_DEBUG(fmt, ...) fprintf(stderr, "[DEBUG] " fmt "\n", ##__VA_ARGS__)
#else
#define QNN_DEBUG(fmt, ...)
#endif
Genie/Genie/src/qualla/engines/qnn-api/NetRunBackend.hpp ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ //==============================================================================
2
+ //
3
+ // Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4
+ // All Rights Reserved.
5
+ // Confidential and Proprietary - Qualcomm Technologies, Inc.
6
+ //
7
+ //==============================================================================
8
+
9
+ #pragma once
10
+
11
+ #include <string>
12
+
13
+ #include "ICommandLineManager.hpp"
14
+ #include "IBackend.hpp"
15
+
16
+ // This is an implementation of IBackend interface within qnn-net-run.
17
+ // NetRunBackend provides a dummy implementation of IBackend as a concrete
18
+ // implementation is needed in case there is no backend extensions library
19
+ // supplied by the user.
20
+ // This is built as part of QnnNetRun library and is used in case of no
21
+ // user supplied backend extensions implementation.
22
+ class NetRunBackend final : public IBackend {
23
+ public:
24
+ NetRunBackend() {}
25
+
26
+ virtual ~NetRunBackend() {}
27
+
28
+ virtual bool setupLogging(QnnLog_Callback_t callback, QnnLog_Level_t maxLogLevel) override {
29
+ ignore(callback);
30
+ ignore(maxLogLevel);
31
+ return true;
32
+ }
33
+
34
+ virtual bool initialize(void* backendLibHandle) override {
35
+ ignore(backendLibHandle);
36
+ return true;
37
+ }
38
+
39
+ virtual bool setPerfProfile(PerfProfile perfProfile) override {
40
+ ignore(perfProfile);
41
+ return true;
42
+ }
43
+
44
+ virtual QnnProfile_Level_t getProfilingLevel() override { return g_profilingLevelNotSet; }
45
+
46
+ virtual bool loadConfig(std::string configFile) override {
47
+ ignore(configFile);
48
+ return true;
49
+ }
50
+
51
+ virtual bool loadCommandLineArgs(std::shared_ptr<ICommandLineManager> clManager) override {
52
+ ignore(clManager);
53
+ return true;
54
+ }
55
+
56
+ virtual bool beforeBackendInitialize(
57
+ QnnBackend_Config_t*** customConfigs,
58
+ uint32_t* configCount
59
+ ) override {
60
+ ignore(customConfigs);
61
+ ignore(configCount);
62
+ return true;
63
+ }
64
+
65
+ virtual bool afterBackendInitialize() override { return true; }
66
+
67
+ virtual bool beforeContextCreate(QnnContext_Config_t*** customConfigs, uint32_t* configCount)
68
+ override {
69
+ ignore(customConfigs);
70
+ ignore(configCount);
71
+ return true;
72
+ }
73
+
74
+ virtual bool afterContextCreate() override { return true; }
75
+
76
+ virtual bool beforeComposeGraphs(GraphConfigInfo_t*** customGraphConfigs, uint32_t* graphCount)
77
+ override {
78
+ ignore(customGraphConfigs);
79
+ ignore(graphCount);
80
+ return true;
81
+ }
82
+
83
+ virtual bool afterComposeGraphs() override { return true; }
84
+
85
+ #if QUALLA_QNN_API_VERSION >= 21700
86
+ virtual bool beforeGraphFinalizeUpdateConfig(
87
+ const char* graphName,
88
+ Qnn_GraphHandle_t graphHandle,
89
+ QnnGraph_Config_t*** customConfigs,
90
+ uint32_t* configCount
91
+ ) override {
92
+ ignore(graphName);
93
+ ignore(graphHandle);
94
+ ignore(customConfigs);
95
+ ignore(configCount);
96
+ return true;
97
+ }
98
+ #endif
99
+
100
+ virtual bool beforeGraphFinalize() override { return true; }
101
+
102
+ virtual bool afterGraphFinalize() override { return true; }
103
+
104
+ virtual bool beforeRegisterOpPackages() override { return true; }
105
+
106
+ virtual bool afterRegisterOpPackages() override { return true; }
107
+
108
+ virtual bool beforeExecute(
109
+ const char* graphName,
110
+ QnnGraph_Config_t*** customConfigs,
111
+ uint32_t* configCount
112
+ ) override {
113
+ ignore(graphName);
114
+ ignore(customConfigs);
115
+ ignore(configCount);
116
+ return true;
117
+ }
118
+
119
+ virtual bool afterExecute() override { return true; }
120
+
121
+ virtual bool beforeContextFree() override { return true; }
122
+
123
+ virtual bool afterContextFree() override { return true; }
124
+
125
+ virtual bool beforeBackendTerminate() override { return true; }
126
+
127
+ virtual bool afterBackendTerminate() override { return true; }
128
+
129
+ virtual bool beforeCreateFromBinary(QnnContext_Config_t*** customConfigs, uint32_t* configCount)
130
+ override {
131
+ ignore(customConfigs);
132
+ ignore(configCount);
133
+ return true;
134
+ }
135
+
136
+ virtual bool afterCreateFromBinary() override { return true; }
137
+
138
+ #if QUALLA_QNN_API_VERSION >= 21700
139
+ virtual bool beforeCreateContextsFromBinaryList(
140
+ std::map<std::string, std::tuple<QnnContext_Config_t**, uint32_t>>*
141
+ contextKeyToCustomConfigsMap,
142
+ QnnContext_Config_t*** commonCustomConfigs,
143
+ uint32_t* commonConfigCount
144
+ ) override {
145
+ ignore(contextKeyToCustomConfigsMap);
146
+ ignore(commonCustomConfigs);
147
+ ignore(commonConfigCount);
148
+ return true;
149
+ }
150
+
151
+ virtual bool afterCreateContextsFromBinaryList() override { return true; }
152
+ #endif
153
+
154
+ virtual bool beforeCreateDevice(QnnDevice_Config_t*** deviceConfigs, uint32_t* configCount)
155
+ override {
156
+ ignore(deviceConfigs);
157
+ ignore(configCount);
158
+ return true;
159
+ }
160
+
161
+ virtual bool afterCreateDevice() override { return true; }
162
+
163
+ virtual bool beforeFreeDevice() override { return true; }
164
+
165
+ virtual bool afterFreeDevice() override { return true; }
166
+
167
+ private:
168
+ // Utility function to ignore compiler warnings when a variable
169
+ // is unused. Recommended by Herb Sutter in Sutter's Mill
170
+ // instead of (void)variable.
171
+ template <typename T>
172
+ void ignore(const T&) {}
173
+ };
Genie/Genie/src/qualla/engines/qnn-api/QnnApi.cpp ADDED
The diff for this file is too large to render. See raw diff
 
Genie/Genie/src/qualla/engines/qnn-api/QnnApi.hpp ADDED
@@ -0,0 +1,429 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ //==============================================================================
2
+ //
3
+ // Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4
+ // All Rights Reserved.
5
+ // Confidential and Proprietary - Qualcomm Technologies, Inc.
6
+ //
7
+ //==============================================================================
8
+
9
+ #pragma once
10
+
11
+ #include "BackendExtensions.hpp"
12
+ #include "QnnConfig.hpp"
13
+ #include "QnnHtpPerfInfrastructure.h"
14
+ #include "QnnHtpDevice.h"
15
+ #include "qnn-utils.hpp"
16
+ #include "IOTensor.hpp"
17
+
18
+ #include <memory>
19
+ #include <mutex>
20
+
21
// Set to 1 to enable verbose IO-tensor debug output in this module.
#define QNN_IO_TENSOR_DEBUG 0

// Strategy for maintaining the KV cache between inference steps.
// NOTE(review): exact semantics of POINTER_SHIFT vs SHIFT_CONCAT are not
// visible here — confirm against the KV-manager implementation.
enum KVManagerMode { POINTER_SHIFT = 0x0, SHIFT_CONCAT = 0x1 };

using qualla::QnnUtils::QuantParam;

// Encodes the QNN API version as one comparable integer, e.g. 2.17.0 -> 21700.
#define QUALLA_QNN_API_VERSION \
    (QNN_API_VERSION_MAJOR * 10000 + QNN_API_VERSION_MINOR * 100 + QNN_API_VERSION_PATCH)

// Element size in bytes for each QNN data type; consumed by
// QnnApi::getDataTypeSize(). Types absent from this table yield 0 there.
// NOTE(review): declared `static` in a header, so every translation unit
// that includes this file gets its own copy of the map.
static std::map<Qnn_DataType_t, size_t> g_qnnDataTypeToSize = {
    {QNN_DATATYPE_INT_8, 1},
    {QNN_DATATYPE_INT_16, 2},
    {QNN_DATATYPE_INT_32, 4},
    {QNN_DATATYPE_INT_64, 8},
    {QNN_DATATYPE_UINT_8, 1},
    {QNN_DATATYPE_UINT_16, 2},
    {QNN_DATATYPE_UINT_32, 4},
    {QNN_DATATYPE_UINT_64, 8},
    {QNN_DATATYPE_FLOAT_16, 2},
    {QNN_DATATYPE_FLOAT_32, 4},
    {QNN_DATATYPE_SFIXED_POINT_8, 1},
    {QNN_DATATYPE_SFIXED_POINT_16, 2},
    {QNN_DATATYPE_SFIXED_POINT_32, 4},
    {QNN_DATATYPE_UFIXED_POINT_8, 1},
    {QNN_DATATYPE_UFIXED_POINT_16, 2},
    {QNN_DATATYPE_UFIXED_POINT_32, 4},
    {QNN_DATATYPE_BOOL_8, 1},
};
49
+
50
// Facade over the QNN C API: loads the backend/system libraries, creates the
// device and contexts (from a model library or cached context binaries),
// composes/finalizes graphs, executes them, and manages performance/profiling
// infrastructure. Obtain via getInstance().
class QnnApi {
  private:
    const uint32_t s_graphConfigsReserveCount = 16;

    // Model vars — entry points resolved from dynamically loaded libraries.
    typedef Qnn_ErrorHandle_t (*QnnInterfaceGetProvidersFn_t)(
        const QnnInterface_t*** providerList,
        uint32_t* numProviders
    );
    typedef Qnn_ErrorHandle_t (*QnnSystemInterfaceGetProvidersFn_t)(
        const QnnSystemInterface_t*** providerList,
        uint32_t* numProviders
    );

    // Graph Related Function Handle Types (resolved from the model library).
    typedef ModelError_t (*ComposeGraphsFnHandleType_t)(
        Qnn_BackendHandle_t,
        QNN_INTERFACE_VER_TYPE,
        Qnn_ContextHandle_t,
        const GraphConfigInfo_t**,
        const uint32_t,
        GraphInfo_t***,
        uint32_t*,
        bool,
        QnnLog_Callback_t,
        QnnLog_Level_t
    );

    // GenAI variant: additionally takes input/output/KV tensor dimensions and
    // extra op parameters.
    typedef ModelError_t (*GenAIComposeGraphsFnHandleType_t)(
        Qnn_BackendHandle_t,
        QNN_INTERFACE_VER_TYPE,
        Qnn_ContextHandle_t,
        const GraphConfigInfo_t**,
        const uint32_t,
        uint32_t* inputDim,
        uint32_t inputRank,
        uint32_t* outputDim,
        uint32_t outputRank,
        uint32_t* kvDim,
        uint32_t kvRank,
        Qnn_Param_t* params,
        uint32_t numParam,
        GraphInfo_t***,
        uint32_t*,
        bool,
        QnnLog_Callback_t,
        QnnLog_Level_t
    );

    typedef ModelError_t (*FreeGraphInfoFnHandleType_t)(GraphInfo_t***, uint32_t);

    // dlopen-style handles for the model and backend shared libraries.
    void* m_libModelHandle{nullptr};
    void* m_backendHandle{nullptr};
    void* m_backendLibraryHandle{nullptr};

    QNN_INTERFACE_VER_TYPE m_qnnInterface{nullptr};
    QNN_SYSTEM_INTERFACE_VER_TYPE m_qnnSystemInterface{nullptr};
    std::unique_ptr<BackendExtensions> m_backendExtensions{nullptr};
    ComposeGraphsFnHandleType_t m_composeGraphsFnHandle{nullptr};
    GenAIComposeGraphsFnHandleType_t m_genaiComposeGraphsFnHandle{nullptr};
    FreeGraphInfoFnHandleType_t m_freeGraphInfoFnHandle{nullptr};
    uint32_t m_backendId{0};
    Qnn_LogHandle_t m_logHandle{nullptr};
    Qnn_DeviceHandle_t m_deviceHandle{nullptr};

    Qnn_ProfileHandle_t m_profileBackendHandle{nullptr};

    // Context/graph bookkeeping. m_contextVec owns the order of creation;
    // the maps provide name/pointer-based lookups into the same handles.
    std::vector<Qnn_ContextHandle_t> m_contextVec;
    std::unordered_map<GraphInfo*, Qnn_ContextHandle_t> m_contextMap;
    uint32_t m_graphsCount{0};
    int32_t graphCountPerContext{-1};
    GraphInfo_t** m_graphsInfo;
    std::unordered_map<std::string, uint32_t> m_graphNameToIndex;
    std::unordered_map<std::string, GraphInfo*> m_graphNameToInfo;
    std::unordered_map<std::string, uint32_t> m_graphNameToContextIdx;
    std::unordered_map<uint32_t, Qnn_ContextHandle_t> m_contextIdtoHandle;
    // Guards the async create-from-binary callbacks (updateContext /
    // updateQnnApiGraphsandContextsInfo).
    std::mutex m_updateCallBackMutex;

    // Useful structures for IO estimation.
    std::unordered_map<int,qualla::QnnUtils::TensorMap> m_graphtoIOMap; // stores {GraphId -> IOTensorMap}
    typedef int CtxBitVector;
    std::map<CtxBitVector, std::map<std::string, size_t>> m_contextAllocMap; // stores {Translated ContextId -> {Tensor name, size}}
    std::map<std::string, std::pair<int, size_t>> m_tensorAllocInfo; // stores {Tensor name -> (fd of RPC buffer, offset)}
    std::unordered_map<uint32_t, uint32_t> m_graphIdxToContextIdx; // stores {Graph Idx -> Context Idx}
    std::unordered_map<std::string,std::shared_ptr<uint8_t>> m_adapterNameToBuffer;

    uint32_t m_backendConfigCount{0};
    QnnBackend_Config_t** m_backendConfigs{nullptr};

    // HTP performance-vote infrastructure (see boostPerformance/resetPerformance).
    QnnHtpDevice_PerfInfrastructure_t* m_perfInfra{nullptr};
    uint32_t m_powerConfigId = 1;

    // IO buffer manager and model geometry injected by the engine layer.
    IOTensor* m_ioBufferMgr{nullptr};
    int32_t m_ctxSize{-1};
    int32_t m_kvDim{-1};
    bool m_loraWeightEnabled{false};
    bool m_lmHeadWeightInput{false};
    KVManagerMode m_kvUpdateMethod{POINTER_SHIFT};

    bool m_isLogInitialized{false};
    bool m_isBackendInitialized{false};
    bool m_isContextCreated{false};

    // Variable to keep track of debug mode
    bool m_DebugModeRequested;
    bool m_debugQnn{false};

    // Variable to indicate whether to mmap context bins or read them in memory
    bool m_mmapContextBins;
    bool m_isDeviceCreated = false;

    // Raw (buffer, size) pairs to unmap/free when contexts are torn down.
    std::vector<std::pair<uint8_t*, uint64_t>> m_contextBinBuffersToBeCleared;

    void setDeviceStatus(bool status) { m_isDeviceCreated = status; }
    bool getDeviceStatus() { return m_isDeviceCreated; }
    bool getContextConfigs(
        QnnContext_Config_t*** configs,
        uint32_t& contextConfigCount,
        Qnn_Priority_t contextPriority,
        bool graphSwitching = false,
        const std::vector<std::string>& execSelectGraphs = {},
        bool loadSelectGraphs = false
    );
    bool mergeAllContextConfigs(
        QnnContext_Config_t*** allCustomContextConfigs,
        QnnContext_Config_t** customConfigs,
        QnnContext_Config_t** contextConfigs,
        uint32_t customConfigCount,
        uint32_t contextConfigCount
    );
    bool freeContextConfigs(QnnContext_Config_t** contextConfigs, uint32_t contextConfigCount);
    bool setGraphConfigsBeforeExecute(
        Qnn_GraphHandle_t graphHandle,
        QnnGraph_Config_t** graphConfigs,
        uint32_t configCount
    );

    // Library loading / lifecycle helpers used by the public initialize() paths.
    bool getQnnInterface(std::string backendPath);
    bool getQnnSystemInterface(std::string systemLibraryPath);
    bool loadModel(std::string model_path);
    bool initializeLogging(const QnnLog_Level_t& logLevel, bool debug_qnn);
    void terminateLog();
    bool initializeBackendExtensions(
        BackendExtensionsConfigs backendExtensionsConfig,
        PerfProfile parsedPerfProfile,
        bool debug_qnn
    );
    bool initializeBackend();
    bool terminateBackend();
    bool createDevice();
    bool freeDevice();
    bool createContext(ContextConfigs contextConfig);
    bool freeContext();
    bool composeGraphs(std::vector<GraphConfigs> graphConfigs);
    // GenAI overload: forwards tensor geometry and op params to the model lib.
    bool composeGraphs(
        std::vector<GraphConfigs> graphConfigs,
        uint32_t* inputDim,
        uint32_t inputRank,
        uint32_t* outputDim,
        uint32_t outputRank,
        uint32_t* kvDim,
        uint32_t kvRank,
        Qnn_Param_t* params,
        uint32_t numParams
    );
    bool mapAndGetContextBinaryInfo(
        const bool use_mmap,
        std::shared_ptr<uint8_t>& buffer,
        const std::string binaryPath,
        const uint64_t bufferSize,
        const size_t contextIdx,
        const bool graphSwitching,
        QnnSystemContext_Handle_t sysCtxHandle,
        const QnnSystemContext_BinaryInfo_t** binaryInfo
    );

    bool parseIOTensorsAndAccumulate();
    bool registerTensorsWithBackend(uint32_t& graphIdx);

    bool finalizeGraphs();
    bool initializePerformance();
    bool destroyPerformance();
    bool boostPerformance();
    bool resetPerformance();
    bool checkCapabilityOfCreateAsync(bool& propRet);

    // Profiling extraction. The timeLogs overloads accumulate
    // {event name -> (total time, count)} per graph; the parameterless
    // overloads only traverse/log.
    bool initProfiling();
    bool extractBackendProfilingInfo(
        Qnn_ProfileHandle_t profileHandle,
        std::map<std::string, std::pair<double, uint16_t>>& timeLogs,
        std::string graphName
    );
    bool extractProfilingSubEvents(
        QnnProfile_EventId_t profileEventId,
        std::map<std::string, std::pair<double, uint16_t>>& timeLogs,
        std::string graphName
    );
    bool extractProfilingEvent(
        QnnProfile_EventId_t profileEventId,
        std::map<std::string, std::pair<double, uint16_t>>& timeLogs,
        std::string graphName
    );
    bool extractBackendProfilingInfo(Qnn_ProfileHandle_t profileHandle);
    bool extractProfilingSubEvents(QnnProfile_EventId_t profileEventId);
    bool extractProfilingEvent(QnnProfile_EventId_t profileEventId);

    // NOTE(review): operator[] default-inserts a null handle for unknown ids.
    Qnn_ContextHandle_t getContextWithId(uint32_t contextId) {
        return m_contextIdtoHandle[contextId];
    }

  public:
    QnnApi() {};
    ~QnnApi();

    bool freeGraphs();
    static QnnApi& getInstance();
#if QUALLA_QNN_API_VERSION >= 21700
    // Callback invoked by QnnContext_createFromBinaryListAsync as contexts and
    // graphs become available.
    static void contextNotifyFn(
        Qnn_ContextHandle_t context,
        Qnn_GraphHandle_t graph,
        const char* graph_name,
        QnnContext_createFromBinaryAsyncNotifyType_t completeType,
        void* notifyParam,
        Qnn_ErrorHandle_t status
    );
#endif
    // Creates contexts from cached context binaries (synchronous path).
    bool createFromBinary(
        std::vector<std::string> cachedBinariesPathVec,
        ContextConfigs contextConfig,
        int64_t spill_fill_buffer_size = 0,
        uint64_t mmap_budget = 0,
        bool graphSwitching = false,
        const std::vector<std::string>& execSelectGraphs = {},
        bool loadSelectGraphs = false
    );
#if QUALLA_QNN_API_VERSION >= 21700
    // Async variant; completion is reported through contextNotifyFn.
    bool createFromBinaryListAsync(
        std::vector<std::string> cachedBinariesPathVec,
        ContextConfigs contextConfig,
        int64_t spill_fill_buffer_size = 0,
        uint64_t mmap_budget = 0,
        bool graphSwitching = false,
        const std::vector<std::string>& execSelectGraphs = {},
        bool loadSelectGraphs = false
    );
#endif
    // Full bring-up: backend + (optionally) device, contexts and graphs,
    // either from a model library or from cached binaries
    // (loadFromCachedBinary).
    bool initialize(
        std::string backendPath,
        std::vector<std::string> modelPathOrCachedBinaryPathVec,
        BackendExtensionsConfigs backendExtensionsConfig,
        PerfProfile parsedPerfProfile = PerfProfile::BURST,
        ContextConfigs contextConfig = ContextConfigs(),
        std::vector<GraphConfigs> graphConfigs = {},
        bool loadFromCachedBinary = false,
        std::string systemLibraryPath = "",
        bool debugModeRequested = false,
        int64_t spill_fill_buffer_size = 0,
        bool mmapContextBins = false,
        bool asyncInit = true,
        uint64_t mmap_budget = 0,
        bool debug_qnn = false,
        bool graphSwitching = false,
        const std::vector<std::string>& execSelectGraphs = {},
        bool loadSelectGraphs = false
    );

    bool registerOpPackage(std::string opPackagePath);

    // Non-owning: caller keeps ownership of the IOTensor manager.
    void setIOTensorBufferMgr(IOTensor* ioBufferMgr){
        m_ioBufferMgr = ioBufferMgr;
    }

    void setKVDim(int32_t kvDim){
        m_kvDim = kvDim;
    }

    void setContextSize(int32_t ctxSize){
        m_ctxSize = ctxSize;
    }

    void setKVUpdateMethod(KVManagerMode kvUpdateMethod){
        m_kvUpdateMethod = kvUpdateMethod ;
    }

    // Non-owning pointer into this object's map; valid while QnnApi lives.
    std::map<std::string, std::pair<int, size_t>>* getTensorAllocInfo(){
        return &m_tensorAllocInfo;
    }

    bool getLmHeadWeightInputEnabled(){
        return m_lmHeadWeightInput;
    }

    bool getLoraWeightEnabled(){
        return m_loraWeightEnabled;
    }
    // Initialize with OpPackage
    bool initialize(
        std::string backendPath,
        std::string modelPath,
        std::string opPackage,
        ContextConfigs contextConfig,
        std::vector<GraphConfigs> graphConfigs,
        uint32_t* inputDim,
        uint32_t inputRank,
        uint32_t* outputDim,
        uint32_t outputRank,
        uint32_t* kvDim,
        uint32_t kvRank,
        Qnn_Param_t* params,
        uint32_t numParams,
        bool debugModeRequested
    );

    // Runs one graph by name; profiling timings are accumulated into timeLogs.
    bool graphExecute(
        Qnn_Tensor_t* input,
        Qnn_Tensor_t* output,
        std::string graphName,
        std::map<std::string, std::pair<double, uint16_t>>& timeLogs
    );

    bool applyBinarySection(uint32_t binIndex, std::string binSectionPath,bool useMmap,bool graphSwitch);

    QNN_INTERFACE_VER_TYPE* getQnnInterfaceVer() { return &m_qnnInterface; };
    GraphInfo_t**& getGraphsInfo() { return m_graphsInfo; };
    uint32_t getGraphsCount() { return m_graphsCount; };
    int32_t getGraphCountPerContext() { return graphCountPerContext; }
    std::vector<Qnn_ContextHandle_t>& getContexts() { return m_contextVec; };
    // Throws std::out_of_range if the graph is unknown (map::at).
    const Qnn_ContextHandle_t getContexts(GraphInfo_t* const graph) {
        return m_contextMap.at(graph);
    };

    // Thread-safe: called from async context-creation callbacks.
    void updateContext(Qnn_ContextHandle_t context, uint32_t contextId) {
        std::lock_guard<std::mutex> lock(m_updateCallBackMutex);
        m_contextVec.push_back(context);
        m_contextIdtoHandle[contextId] = context;
    }

    // Thread-safe: records the finalized graph handle and its context index.
    void updateQnnApiGraphsandContextsInfo(
        std::string graphName,
        Qnn_GraphHandle_t graph,
        uint32_t contextId
    ) {
        // set graph handle to GraphInfo
        std::lock_guard<std::mutex> lock(m_updateCallBackMutex);
        m_graphNameToInfo[graphName]->graph = graph;
        m_graphNameToContextIdx[graphName] = contextId;
        m_graphsCount++;
    }

    // NOTE(review): map::operator[] inserts {datatype, 0} for types missing
    // from g_qnnDataTypeToSize, so unknown types report size 0.
    static inline size_t getDataTypeSize(const Qnn_DataType_t& datatype) {
        return g_qnnDataTypeToSize[datatype];
    }
    static inline std::string getTensorName(const TensorWrapper& tensorWrapper) {
        return GET_TENSOR_WRAPPER_NAME(tensorWrapper);
    }
    static bool getTensorQuantParams(
        const Qnn_Tensor_t* tensor,
        std::vector<QuantParam>& quantParamsVec
    );
    static bool getTensorShape(std::vector<size_t>& tensorDims, const TensorWrapper& tensorWrapper);
    static inline Qnn_DataType_t getTensorDtype(const Qnn_Tensor_t* tensor) {
        return QNN_TENSOR_GET_DATA_TYPE(tensor);
    }

    bool getTensorNameAndShape(
        std::string& tensorName,
        std::vector<size_t>& tensorDims,
        TensorWrapper& tensorWrapper
    );
    static void qnnLogCallback(
        const char* fmt,
        QnnLog_Level_t level,
        uint64_t timestamp,
        va_list args
    );
    bool updateIOEncodings(std::shared_ptr<uint8_t>& buffer,
                           uint64_t bufferSize,
                           uint32_t graphIndex);
};
Genie/Genie/src/qualla/engines/qnn-api/QnnApiUtils.cpp ADDED
@@ -0,0 +1,636 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ //==============================================================================
2
+ //
3
+ // Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4
+ // All Rights Reserved.
5
+ // Confidential and Proprietary - Qualcomm Technologies, Inc.
6
+ //
7
+ //==============================================================================
8
+
9
+ #include "QnnApiUtils.hpp"
10
+ #include "QnnTypeMacros.hpp"
11
+
12
+ #include <algorithm>
13
+ #include <cstring>
14
+ #include <fstream>
15
+ #include <iostream>
16
+ #include <sstream>
17
+ #include <string>
18
+ #include <tuple>
19
+
20
+ #include <fcntl.h>
21
+ #include <errno.h>
22
+
23
+ #ifdef _WIN32
24
+ #include <windows.h>
25
+ #define __open ::_open
26
+ #define __strdup ::_strdup
27
+ #else
28
+ #include <unistd.h>
29
+ #include <sys/mman.h>
30
+ #define __open ::open
31
+ #define __strdup ::strdup
32
+ #endif
33
+
34
+ bool freeQnnTensorWrapper(TensorWrapper& tensorWrapper) {
35
+ // free all pointer allocations in struct
36
+ if (nullptr != GET_TENSOR_WRAPPER_NAME(tensorWrapper)) {
37
+ free((void*)GET_TENSOR_WRAPPER_NAME(tensorWrapper));
38
+ }
39
+
40
+ Qnn_Tensor_t& tensor = GET_TENSOR_WRAPPER_TENSOR(tensorWrapper);
41
+ free(QNN_TENSOR_GET_DIMENSIONS(tensor));
42
+ return true;
43
+ }
44
+
45
+ bool freeQnnTensorWrappers(TensorWrapper*& tensorWrappers, uint32_t numTensors) {
46
+ // free all pointer allocations in struct
47
+ for (size_t i = 0; i < numTensors; i++) {
48
+ freeQnnTensorWrapper(tensorWrappers[i]);
49
+ }
50
+ free(tensorWrappers);
51
+
52
+ return true;
53
+ }
54
+
55
// Releases an array of graphs produced by the compose/create path.
// Returns false when there is nothing to free, true otherwise.
//
// NOTE(review): the double free(**graphsInfo) + free(*graphsInfo) implies the
// GraphInfo_t structs live in ONE contiguous allocation addressed by
// (*graphsInfo)[0], with *graphsInfo a separately allocated pointer array
// into it — confirm against the allocation site before changing this.
bool freeGraphsInfo(GraphInfoPtr_t** graphsInfo, uint32_t numGraphs) {
    if (graphsInfo == nullptr || *graphsInfo == nullptr) {
        return false;
    }
    // First free each graph's owned name and tensor-wrapper arrays.
    for (uint32_t i = 0; i < numGraphs; i++) {
        if (nullptr != (*graphsInfo)[i]) {
            free((*graphsInfo)[i]->graphName);
            freeQnnTensorWrappers(
                (*graphsInfo)[i]->inputTensors, (*graphsInfo)[i]->numInputTensors
            );
            freeQnnTensorWrappers(
                (*graphsInfo)[i]->outputTensors, (*graphsInfo)[i]->numOutputTensors
            );
        }
    }
    // Then the contiguous struct block, then the pointer array; null the
    // caller's pointer so it cannot be reused.
    free(**graphsInfo);
    free(*graphsInfo);
    *graphsInfo = nullptr;

    return true;
}
76
+
77
+ bool freeGraphInfo(GraphInfo_t* graphInfo) {
78
+ if (graphInfo == nullptr) {
79
+ return false;
80
+ }
81
+ if (nullptr != graphInfo->graphName) {
82
+ free(graphInfo->graphName);
83
+ }
84
+ freeQnnTensorWrappers(graphInfo->inputTensors, graphInfo->numInputTensors);
85
+ freeQnnTensorWrappers(graphInfo->outputTensors, graphInfo->numOutputTensors);
86
+ free(graphInfo);
87
+ return true;
88
+ }
89
+
90
// Refreshes the metadata of already-constructed tensor wrappers from
// `tensorsInfoSrc`: id, type, data format, data type, quantization params,
// rank, and dimensions. Unlike copyTensorsInfo(), this assumes the wrappers
// exist: the tensor name is left untouched and dimensions are only memcpy'd
// into the existing dimensions buffer (never allocated here). Always returns
// true.
//
// NOTE(review): for axis-quantized tensors a fresh scaleOffset array is
// malloc'd and installed via QNN_TENSOR_SET_QUANT_PARAMS without freeing any
// previously installed array — repeated updates may leak; confirm ownership.
bool updateTensorInfo(const Qnn_Tensor_t* tensorsInfoSrc,
                      TensorWrapper* tensorWrappers,
                      uint32_t tensorsCount
){
    for (size_t tIdx = 0; tIdx < tensorsCount; tIdx++) {
        QNN_DEBUG("Extracting tensorInfo for tensor Idx: %d", (int)tIdx);
        Qnn_Tensor_t& tensor = GET_TENSOR_WRAPPER_TENSOR(tensorWrappers[tIdx]);

        // Plain scalar fields are copied directly.
        QNN_TENSOR_SET_ID(tensor, QNN_TENSOR_GET_ID(&tensorsInfoSrc[tIdx]));
        QNN_TENSOR_SET_TYPE(tensor, QNN_TENSOR_GET_TYPE(&tensorsInfoSrc[tIdx]));
        QNN_TENSOR_SET_DATA_FORMAT(tensor, QNN_TENSOR_GET_DATA_FORMAT(&tensorsInfoSrc[tIdx]));
        QNN_TENSOR_SET_DATA_TYPE(tensor, QNN_TENSOR_GET_DATA_TYPE(&tensorsInfoSrc[tIdx]));
        // Quantization params: only SCALE_OFFSET and AXIS_SCALE_OFFSET are
        // carried over; any other encoding is recorded as UNDEFINED.
        Qnn_QuantizeParams_t qParams = QNN_QUANTIZE_PARAMS_INIT;
        qParams.encodingDefinition =
            QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx]).encodingDefinition;
        qParams.quantizationEncoding = QNN_QUANTIZATION_ENCODING_UNDEFINED;
        if (QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx]).quantizationEncoding ==
            QNN_QUANTIZATION_ENCODING_SCALE_OFFSET) {
            qParams.quantizationEncoding =
                QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx]).quantizationEncoding;
            qParams.scaleOffsetEncoding =
                QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx]).scaleOffsetEncoding;
        } else if (QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx]).quantizationEncoding ==
                   QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET) {
            qParams.quantizationEncoding =
                QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx]).quantizationEncoding;
            qParams.axisScaleOffsetEncoding.axis =
                QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx])
                    .axisScaleOffsetEncoding.axis;
            qParams.axisScaleOffsetEncoding.numScaleOffsets =
                QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx])
                    .axisScaleOffsetEncoding.numScaleOffsets;
            if (QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx])
                    .axisScaleOffsetEncoding.numScaleOffsets > 0) {
                // Deep-copy the per-axis (scale, offset) pairs; the source
                // array belongs to the QNN runtime, not to this wrapper.
                qParams.axisScaleOffsetEncoding.scaleOffset = (Qnn_ScaleOffset_t*)malloc(
                    QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx])
                        .axisScaleOffsetEncoding.numScaleOffsets *
                    sizeof(Qnn_ScaleOffset_t)
                );
                if (qParams.axisScaleOffsetEncoding.scaleOffset) {
                    for (size_t idx = 0;
                         idx < QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx])
                                   .axisScaleOffsetEncoding.numScaleOffsets;
                         idx++) {
                        qParams.axisScaleOffsetEncoding.scaleOffset[idx].scale =
                            QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx])
                                .axisScaleOffsetEncoding.scaleOffset[idx]
                                .scale;
                        qParams.axisScaleOffsetEncoding.scaleOffset[idx].offset =
                            QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx])
                                .axisScaleOffsetEncoding.scaleOffset[idx]
                                .offset;
                    }
                }
            }
        }
        QNN_TENSOR_SET_QUANT_PARAMS(tensor, qParams);
        QNN_TENSOR_SET_RANK(tensor, QNN_TENSOR_GET_RANK(&tensorsInfoSrc[tIdx]));
        // Dimensions are refreshed in place only when a buffer already exists.
        if (QNN_TENSOR_GET_RANK(tensorsInfoSrc[tIdx]) > 0) {
            if (QNN_TENSOR_GET_DIMENSIONS(tensor)) {
                memcpy(QNN_TENSOR_GET_DIMENSIONS(tensor),
                       QNN_TENSOR_GET_DIMENSIONS(&tensorsInfoSrc[tIdx]),
                       QNN_TENSOR_GET_RANK(&tensorsInfoSrc[tIdx]) * sizeof(uint32_t));
            }
        }
    }
    return true;
}
158
+
159
+ bool copyTensorsInfo(
160
+ const Qnn_Tensor_t* tensorsInfoSrc,
161
+ TensorWrapper*& tensorWrappers,
162
+ uint32_t tensorsCount
163
+ ) {
164
+
165
+ auto returnStatus = true;
166
+ tensorWrappers = (TensorWrapper*)calloc(tensorsCount, sizeof(TensorWrapper));
167
+ if (nullptr == tensorWrappers) {
168
+ QNN_ERROR("Failed to allocate memory for tensorWrappers.");
169
+ return false;
170
+ }
171
+ if (returnStatus) {
172
+ for (size_t tIdx = 0; tIdx < tensorsCount; tIdx++) {
173
+ // QNN_DEBUG("Extracting tensorInfo for tensor Idx: %d", (int)tIdx);
174
+ Qnn_Tensor_t& tensor = GET_TENSOR_WRAPPER_TENSOR(tensorWrappers[tIdx]);
175
+ tensor = QNN_TENSOR_INIT;
176
+
177
+ const char* tensorName = QNN_TENSOR_GET_NAME(&tensorsInfoSrc[tIdx]);
178
+ if (!tensorName) {
179
+ QNN_TENSOR_SET_NAME(tensor, nullptr);
180
+ } else {
181
+ QNN_TENSOR_SET_NAME(tensor, __strdup(tensorName));
182
+ }
183
+
184
+ QNN_TENSOR_SET_ID(tensor, QNN_TENSOR_GET_ID(&tensorsInfoSrc[tIdx]));
185
+ QNN_TENSOR_SET_TYPE(tensor, QNN_TENSOR_GET_TYPE(&tensorsInfoSrc[tIdx]));
186
+ QNN_TENSOR_SET_DATA_FORMAT(tensor, QNN_TENSOR_GET_DATA_FORMAT(&tensorsInfoSrc[tIdx]));
187
+ QNN_TENSOR_SET_DATA_TYPE(tensor, QNN_TENSOR_GET_DATA_TYPE(&tensorsInfoSrc[tIdx]));
188
+ Qnn_QuantizeParams_t qParams = QNN_QUANTIZE_PARAMS_INIT;
189
+ qParams.encodingDefinition =
190
+ QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx]).encodingDefinition;
191
+ qParams.quantizationEncoding = QNN_QUANTIZATION_ENCODING_UNDEFINED;
192
+ if (QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx]).quantizationEncoding ==
193
+ QNN_QUANTIZATION_ENCODING_SCALE_OFFSET) {
194
+ qParams.quantizationEncoding =
195
+ QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx]).quantizationEncoding;
196
+ qParams.scaleOffsetEncoding =
197
+ QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx]).scaleOffsetEncoding;
198
+ } else if (QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx]).quantizationEncoding ==
199
+ QNN_QUANTIZATION_ENCODING_AXIS_SCALE_OFFSET) {
200
+ qParams.quantizationEncoding =
201
+ QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx]).quantizationEncoding;
202
+ qParams.axisScaleOffsetEncoding.axis =
203
+ QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx])
204
+ .axisScaleOffsetEncoding.axis;
205
+ qParams.axisScaleOffsetEncoding.numScaleOffsets =
206
+ QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx])
207
+ .axisScaleOffsetEncoding.numScaleOffsets;
208
+ if (QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx])
209
+ .axisScaleOffsetEncoding.numScaleOffsets > 0) {
210
+ qParams.axisScaleOffsetEncoding.scaleOffset = (Qnn_ScaleOffset_t*)malloc(
211
+ QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx])
212
+ .axisScaleOffsetEncoding.numScaleOffsets *
213
+ sizeof(Qnn_ScaleOffset_t)
214
+ );
215
+ if (qParams.axisScaleOffsetEncoding.scaleOffset) {
216
+ for (size_t idx = 0;
217
+ idx < QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx])
218
+ .axisScaleOffsetEncoding.numScaleOffsets;
219
+ idx++) {
220
+ qParams.axisScaleOffsetEncoding.scaleOffset[idx].scale =
221
+ QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx])
222
+ .axisScaleOffsetEncoding.scaleOffset[idx]
223
+ .scale;
224
+ qParams.axisScaleOffsetEncoding.scaleOffset[idx].offset =
225
+ QNN_TENSOR_GET_QUANT_PARAMS(&tensorsInfoSrc[tIdx])
226
+ .axisScaleOffsetEncoding.scaleOffset[idx]
227
+ .offset;
228
+ }
229
+ }
230
+ }
231
+ }
232
+ QNN_TENSOR_SET_QUANT_PARAMS(tensor, qParams);
233
+ QNN_TENSOR_SET_RANK(tensor, QNN_TENSOR_GET_RANK(&tensorsInfoSrc[tIdx]));
234
+ QNN_TENSOR_SET_DIMENSIONS(tensor, nullptr);
235
+ if (QNN_TENSOR_GET_RANK(tensorsInfoSrc[tIdx]) > 0) {
236
+ QNN_TENSOR_SET_DIMENSIONS(
237
+ tensor,
238
+ (uint32_t*)malloc(
239
+ QNN_TENSOR_GET_RANK(&tensorsInfoSrc[tIdx]) * sizeof(uint32_t)
240
+ )
241
+ );
242
+ if (QNN_TENSOR_GET_DIMENSIONS(tensor)) {
243
+ memcpy(QNN_TENSOR_GET_DIMENSIONS(tensor),
244
+ QNN_TENSOR_GET_DIMENSIONS(&tensorsInfoSrc[tIdx]),
245
+ QNN_TENSOR_GET_RANK(&tensorsInfoSrc[tIdx]) * sizeof(uint32_t));
246
+ }
247
+ }
248
+ }
249
+ }
250
+
251
+ return returnStatus;
252
+ }
253
+
254
+
255
+ bool updateGraphInfoV1(const QnnSystemContext_GraphInfoV1_t* graphInfoSrc,
256
+ GraphInfo_t* graphInfoDst
257
+ ){
258
+ if (graphInfoSrc->graphInputs) {
259
+ if (!updateTensorInfo(
260
+ graphInfoSrc->graphInputs,
261
+ graphInfoDst->inputTensors,
262
+ graphInfoSrc->numGraphInputs
263
+ )) {
264
+ return false;
265
+ }
266
+ }
267
+ if (graphInfoSrc->graphOutputs) {
268
+ if (!updateTensorInfo(
269
+ graphInfoSrc->graphOutputs,
270
+ graphInfoDst->outputTensors,
271
+ graphInfoSrc->numGraphOutputs
272
+ )) {
273
+ return false;
274
+ }
275
+ }
276
+ return true;
277
+ }
278
+
279
+
280
+ bool updateGraphInfoV3(const QnnSystemContext_GraphInfoV3_t* graphInfoSrc,
281
+ GraphInfo_t* graphInfoDst
282
+ ){
283
+ if (graphInfoSrc->graphInputs) {
284
+ if (!updateTensorInfo(
285
+ graphInfoSrc->graphInputs,
286
+ graphInfoDst->inputTensors,
287
+ graphInfoSrc->numGraphInputs
288
+ )) {
289
+ return false;
290
+ }
291
+ }
292
+ if (graphInfoSrc->graphOutputs) {
293
+ if (!updateTensorInfo(
294
+ graphInfoSrc->graphOutputs,
295
+ graphInfoDst->outputTensors,
296
+ graphInfoSrc->numGraphOutputs
297
+ )) {
298
+ return false;
299
+ }
300
+ }
301
+ return true;
302
+ }
303
+
304
+ bool copyGraphsInfoV1(
305
+ const QnnSystemContext_GraphInfoV1_t* graphInfoSrc,
306
+ GraphInfo_t* graphInfoDst
307
+ ) {
308
+ graphInfoDst->graphName = nullptr;
309
+ if (graphInfoSrc->graphName) {
310
+ graphInfoDst->graphName = __strdup(graphInfoSrc->graphName);
311
+ }
312
+ graphInfoDst->inputTensors = nullptr;
313
+ graphInfoDst->numInputTensors = 0;
314
+ if (graphInfoSrc->graphInputs) {
315
+ if (!copyTensorsInfo(
316
+ graphInfoSrc->graphInputs,
317
+ graphInfoDst->inputTensors,
318
+ graphInfoSrc->numGraphInputs
319
+ )) {
320
+ return false;
321
+ }
322
+ graphInfoDst->numInputTensors = graphInfoSrc->numGraphInputs;
323
+ }
324
+ graphInfoDst->outputTensors = nullptr;
325
+ graphInfoDst->numOutputTensors = 0;
326
+ if (graphInfoSrc->graphOutputs) {
327
+ if (!copyTensorsInfo(
328
+ graphInfoSrc->graphOutputs,
329
+ graphInfoDst->outputTensors,
330
+ graphInfoSrc->numGraphOutputs
331
+ )) {
332
+ return false;
333
+ }
334
+ graphInfoDst->numOutputTensors = graphInfoSrc->numGraphOutputs;
335
+ }
336
+ return true;
337
+ }
338
+
339
+ bool copyGraphsInfoV3(const QnnSystemContext_GraphInfoV3_t *graphInfoSrc,
340
+ GraphInfo_t *graphInfoDst) {
341
+ graphInfoDst->graphName = nullptr;
342
+ if (graphInfoSrc->graphName) {
343
+ graphInfoDst->graphName =
344
+ __strdup(graphInfoSrc->graphName);
345
+ }
346
+ graphInfoDst->inputTensors = nullptr;
347
+ graphInfoDst->numInputTensors = 0;
348
+ if (graphInfoSrc->graphInputs) {
349
+ if (!copyTensorsInfo(
350
+ graphInfoSrc->graphInputs, graphInfoDst->inputTensors, graphInfoSrc->numGraphInputs)) {
351
+ return false;
352
+ }
353
+ graphInfoDst->numInputTensors = graphInfoSrc->numGraphInputs;
354
+ }
355
+ graphInfoDst->outputTensors = nullptr;
356
+ graphInfoDst->numOutputTensors = 0;
357
+ if (graphInfoSrc->graphOutputs) {
358
+ if (!copyTensorsInfo(graphInfoSrc->graphOutputs,
359
+ graphInfoDst->outputTensors,
360
+ graphInfoSrc->numGraphOutputs)) {
361
+ return false;
362
+ }
363
+ graphInfoDst->numOutputTensors = graphInfoSrc->numGraphOutputs;
364
+ }
365
+ return true;
366
+ }
367
+
368
+ bool updateGraphInfo(const QnnSystemContext_GraphInfo_t* graphsInput,
369
+ const uint32_t numGraphs,
370
+ GraphInfo_t** graphsInfo,
371
+ uint32_t& graphsCount
372
+ ){
373
+
374
+ for (size_t gIdx = 0; gIdx < numGraphs; gIdx++) {
375
+ if (graphsInput[gIdx].version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_1) {
376
+ if(updateGraphInfoV1(&graphsInput[gIdx].graphInfoV1, graphsInfo[graphsCount]) == false) {
377
+ return false;
378
+ }
379
+ }
380
+ if (graphsInput[gIdx].version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_3) {
381
+ if(updateGraphInfoV3(&graphsInput[gIdx].graphInfoV3, graphsInfo[graphsCount]) == false) {
382
+ return false;
383
+ }
384
+ }
385
+ graphsCount++;
386
+ }
387
+ return true;
388
+ }
389
+
390
+
391
+ bool copyGraphsInfo(
392
+ const QnnSystemContext_GraphInfo_t* graphsInput,
393
+ const uint32_t numGraphs,
394
+ GraphInfo_t**& graphsInfo
395
+ ) {
396
+
397
+ if (!graphsInput) {
398
+ QNN_ERROR("Received nullptr for graphsInput.");
399
+ return false;
400
+ }
401
+ auto returnStatus = true;
402
+ graphsInfo = (GraphInfo_t**)calloc(numGraphs, sizeof(GraphInfo_t*));
403
+ GraphInfo_t* graphInfoArr = (GraphInfo_t*)calloc(numGraphs, sizeof(GraphInfo_t));
404
+ if (nullptr == graphsInfo || nullptr == graphInfoArr) {
405
+ QNN_ERROR("Failure to allocate memory for *graphInfo");
406
+ returnStatus = false;
407
+ }
408
+ if (true == returnStatus) {
409
+ for (size_t gIdx = 0; gIdx < numGraphs; gIdx++) {
410
+ QNN_DEBUG("Extracting graphsInfo for graph Idx: %d", (int)gIdx);
411
+ if (graphsInput[gIdx].version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_1) {
412
+ copyGraphsInfoV1(&graphsInput[gIdx].graphInfoV1, &graphInfoArr[gIdx]);
413
+ }
414
+ if (graphsInput[gIdx].version == QNN_SYSTEM_CONTEXT_GRAPH_INFO_VERSION_3) {
415
+ copyGraphsInfoV3(&graphsInput[gIdx].graphInfoV3, &graphInfoArr[gIdx]);
416
+ }
417
+ graphsInfo[gIdx] = graphInfoArr + gIdx;
418
+ }
419
+ }
420
+ if (true != returnStatus) {
421
+ QNN_DEBUG("Received an ERROR during extractGraphsInfo. Freeing resources.");
422
+ if (graphsInfo) {
423
+ for (uint32_t gIdx = 0; gIdx < numGraphs; gIdx++) {
424
+ if (graphsInfo[gIdx]) {
425
+ if (nullptr != graphsInfo[gIdx]->graphName) {
426
+ free(graphsInfo[gIdx]->graphName);
427
+ graphsInfo[gIdx]->graphName = nullptr;
428
+ }
429
+ freeQnnTensorWrappers(
430
+ graphsInfo[gIdx]->inputTensors, graphsInfo[gIdx]->numInputTensors
431
+ );
432
+ freeQnnTensorWrappers(
433
+ graphsInfo[gIdx]->outputTensors, graphsInfo[gIdx]->numOutputTensors
434
+ );
435
+ }
436
+ }
437
+ free(*graphsInfo);
438
+ }
439
+ free(graphsInfo);
440
+ graphsInfo = nullptr;
441
+ }
442
+
443
+ return true;
444
+ }
445
+
446
+ uint32_t getNumGraphInBinary(const QnnSystemContext_BinaryInfo_t* binaryInfo)
447
+ {
448
+ uint32_t numGraph = 0;
449
+ if (nullptr == binaryInfo) {
450
+ QNN_ERROR("binaryInfo is nullptr.");
451
+ return false;
452
+ }
453
+ if (binaryInfo->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_1) {
454
+ numGraph = binaryInfo->contextBinaryInfoV1.numGraphs;
455
+ }else if (binaryInfo->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_2) {
456
+ numGraph = binaryInfo->contextBinaryInfoV2.numGraphs;
457
+ }
458
+ else if (binaryInfo->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_3) {
459
+ numGraph = binaryInfo->contextBinaryInfoV3.numGraphs;
460
+ }
461
+ return numGraph;
462
+ }
463
+
464
+ bool updateMetaDataToGraphsInfo(const QnnSystemContext_BinaryInfo_t* binaryInfo,
465
+ GraphInfo_t** graphsInfo,
466
+ uint32_t& graphsCount
467
+ ){
468
+ if (nullptr == binaryInfo) {
469
+ QNN_ERROR("binaryInfo is nullptr.");
470
+ return false;
471
+ }
472
+ if (binaryInfo->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_1) {
473
+ if (binaryInfo->contextBinaryInfoV1.graphs) {
474
+ if (!updateGraphInfo(
475
+ binaryInfo->contextBinaryInfoV1.graphs,
476
+ binaryInfo->contextBinaryInfoV1.numGraphs,
477
+ graphsInfo,
478
+ graphsCount
479
+ )) {
480
+ QNN_ERROR("Failed while copying graphs Info.");
481
+ return false;
482
+ }
483
+ return true;
484
+ }
485
+ } else if (binaryInfo->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_2) {
486
+ if (binaryInfo->contextBinaryInfoV2.graphs) {
487
+ if (!updateGraphInfo(
488
+ binaryInfo->contextBinaryInfoV2.graphs,
489
+ binaryInfo->contextBinaryInfoV2.numGraphs,
490
+ graphsInfo,
491
+ graphsCount
492
+ )) {
493
+ QNN_ERROR("Failed while copying graphs Info.");
494
+ return false;
495
+ }
496
+ return true;
497
+ }
498
+ } else if (binaryInfo->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_3) {
499
+ if (binaryInfo->contextBinaryInfoV3.graphs) {
500
+ if (!updateGraphInfo(
501
+ binaryInfo->contextBinaryInfoV3.graphs,
502
+ binaryInfo->contextBinaryInfoV3.numGraphs,
503
+ graphsInfo,
504
+ graphsCount
505
+ )) {
506
+ QNN_ERROR("Failed while copying graphs Info.");
507
+ return false;
508
+ }
509
+ return true;
510
+ }
511
+ }
512
+ QNN_ERROR("Unrecognized system context binary info version.");
513
+ return false;
514
+ }
515
+
516
+ bool copyMetadataToGraphsInfo(
517
+ const QnnSystemContext_BinaryInfo_t* binaryInfo,
518
+ GraphInfo_t**& graphsInfo,
519
+ uint32_t& graphsCount
520
+ ) {
521
+ if (nullptr == binaryInfo) {
522
+ QNN_ERROR("binaryInfo is nullptr.");
523
+ return false;
524
+ }
525
+ graphsCount = 0;
526
+ if (binaryInfo->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_1) {
527
+ if (binaryInfo->contextBinaryInfoV1.graphs) {
528
+ if (!copyGraphsInfo(
529
+ binaryInfo->contextBinaryInfoV1.graphs,
530
+ binaryInfo->contextBinaryInfoV1.numGraphs,
531
+ graphsInfo
532
+ )) {
533
+ QNN_ERROR("Failed while copying graphs Info.");
534
+ return false;
535
+ }
536
+ graphsCount = binaryInfo->contextBinaryInfoV1.numGraphs;
537
+ return true;
538
+ }
539
+ } else if (binaryInfo->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_2) {
540
+ if (binaryInfo->contextBinaryInfoV2.graphs) {
541
+ if (!copyGraphsInfo(
542
+ binaryInfo->contextBinaryInfoV2.graphs,
543
+ binaryInfo->contextBinaryInfoV2.numGraphs,
544
+ graphsInfo
545
+ )) {
546
+ QNN_ERROR("Failed while copying graphs Info.");
547
+ return false;
548
+ }
549
+ graphsCount = binaryInfo->contextBinaryInfoV2.numGraphs;
550
+ return true;
551
+ }
552
+ } else if (binaryInfo->version == QNN_SYSTEM_CONTEXT_BINARY_INFO_VERSION_3) {
553
+ if (binaryInfo->contextBinaryInfoV3.graphs) {
554
+ if (!copyGraphsInfo(binaryInfo->contextBinaryInfoV3.graphs,
555
+ binaryInfo->contextBinaryInfoV3.numGraphs,
556
+ graphsInfo)) {
557
+ QNN_ERROR("Failed while copying graphs Info.");
558
+ return false;
559
+ }
560
+ graphsCount = binaryInfo->contextBinaryInfoV3.numGraphs;
561
+ return true;
562
+ }
563
+ }
564
+ QNN_ERROR("Unrecognized system context binary info version.");
565
+ return false;
566
+ }
567
+
568
+ size_t getFileSize(std::string filePath) {
569
+ std::ifstream in(filePath, std::ifstream::binary);
570
+ if (!in) {
571
+ QNN_ERROR("Failed to open input file: %s", filePath.c_str());
572
+ return 0;
573
+ }
574
+ in.seekg(0, in.end);
575
+ const size_t length = in.tellg();
576
+ in.seekg(0, in.beg);
577
+ return length;
578
+ }
579
+
580
+ bool readBinaryFromFile(std::string filePath, void* buffer, size_t bufferSize) {
581
+ if (nullptr == buffer) {
582
+ QNN_ERROR("buffer is nullptr");
583
+ return false;
584
+ }
585
+ std::ifstream in(filePath, std::ifstream::binary);
586
+ if (!in) {
587
+ QNN_ERROR("Failed to open input file: %s", filePath.c_str());
588
+ return false;
589
+ }
590
+ if (!in.read(reinterpret_cast<char*>(buffer), bufferSize)) {
591
+ QNN_ERROR("Failed to read the contents of: %s", filePath.c_str());
592
+ return false;
593
+ }
594
+ return true;
595
+ }
596
+
597
+ bool mmapBinaryFile(std::string filePath, void** buffer, size_t bufferSize) {
598
+ #ifndef _WIN32
599
+ int fd = open(filePath.c_str(), O_RDONLY);
600
+ int OFFSET = 0;
601
+
602
+ // read the binary file as memory map
603
+ *buffer = mmap(nullptr, bufferSize, PROT_READ, MAP_PRIVATE, fd, OFFSET);
604
+ close(fd);
605
+ if (madvise(*buffer, bufferSize, MADV_NOHUGEPAGE)) {
606
+ QNN_ERROR("Failed to advise OS on memory usage err: %s", strerror(errno));
607
+ }
608
+
609
+ return true;
610
+ #else
611
+ return false;
612
+ #endif
613
+ }
614
+
615
+ bool fillDims(std::vector<size_t>& dims, uint32_t* inDimensions, uint32_t rank) {
616
+ if (nullptr == inDimensions) {
617
+ QNN_ERROR("input dimensions is nullptr");
618
+ return false;
619
+ }
620
+
621
+ if (rank < 1) {
622
+ QNN_ERROR("invalid rank : %d", rank);
623
+ return false;
624
+ }
625
+
626
+ // In case, rank is less than 4, we are pushing 1s
627
+ for (size_t r = 0; r < 4 - rank; r++) {
628
+ dims.push_back(1);
629
+ }
630
+
631
+ for (size_t r = 0; r < rank; r++) {
632
+ dims.push_back(inDimensions[r]);
633
+ }
634
+
635
+ return true;
636
+ }
Genie/Genie/src/qualla/engines/qnn-api/QnnApiUtils.hpp ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ //==============================================================================
2
+ //
3
+ // Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4
+ // All Rights Reserved.
5
+ // Confidential and Proprietary - Qualcomm Technologies, Inc.
6
+ //
7
+ //==============================================================================
8
+
9
+ #include "QnnInterface.h"
10
+ #include "QnnTypes.h"
11
+ #include "System/QnnSystemInterface.h"
12
+
13
+ #include <iostream>
14
+ #include <map>
15
+ #include <queue>
16
+ #include <string>
17
+ #include <unordered_map>
18
+ #include <vector>
19
+
20
+ #include "QnnTypeDef.hpp"
21
+ #include "Log.hpp"
22
+
23
+ /**
24
+ * @brief Frees all memory allocated tensor attributes.
25
+ *
26
+ * @param[in] tensorWrapper tensor object to free
27
+ *
28
+ * @return Error code
29
+ */
30
+ bool freeQnnTensorWrapper(TensorWrapper& tensorWrapper);
31
+
32
+ /**
33
+ * @brief Loops through and frees all memory allocated tensor attributes for each tensorWrapper
34
+ * object.
35
+ *
36
+ * @param[in] tensorWrappers array of tensor objects to free
37
+ *
38
+ * @param[in] numTensors length of the above tensorWrappers array
39
+ *
40
+ * @return Error code
41
+ */
42
+ bool freeQnnTensorWrappers(TensorWrapper*& tensorWrappers, uint32_t numTensors);
43
+
44
+ /**
45
+ * @brief A helper function to free memory malloced for communicating the Graph for a model(s)
46
+ *
47
+ * @param[in] graphsInfo Pointer pointing to location of graph objects
48
+ *
49
+ * @param[in] numGraphs The number of graph objects the above pointer is pointing to
50
+ *
51
+ * @return Error code
52
+ *
53
+ */
54
+ bool freeGraphsInfo(GraphInfoPtr_t** graphsInfo, uint32_t numGraphs);
55
+
56
+ bool freeGraphInfo(GraphInfo_t* graphInfo);
57
+
58
+ bool copyMetadataToGraphsInfo(
59
+ const QnnSystemContext_BinaryInfo_t* binaryInfo,
60
+ GraphInfo_t**& graphsInfo,
61
+ uint32_t& graphsCount
62
+ );
63
+
64
+ bool copyGraphsInfo(
65
+ const QnnSystemContext_GraphInfo_t* graphsInput,
66
+ const uint32_t numGraphs,
67
+ GraphInfo_t**& graphsInfo
68
+ );
69
+
70
+ bool copyGraphsInfoV1(
71
+ const QnnSystemContext_GraphInfoV1_t* graphInfoSrc,
72
+ GraphInfo_t* graphInfoDst
73
+ );
74
+
75
+ bool copyTensorsInfo(
76
+ const Qnn_Tensor_t* tensorsInfoSrc,
77
+ TensorWrapper*& tensorWrappers,
78
+ uint32_t tensorsCount
79
+ );
80
+
81
+ bool fillDims(std::vector<size_t>& dims, uint32_t* inDimensions, uint32_t rank);
82
+ size_t getFileSize(std::string filePath);
83
+ bool readBinaryFromFile(std::string filePath, void* buffer, size_t bufferSize);
84
+ bool mmapBinaryFile(std::string filePath, void** buffer, size_t bufferSize);
85
+ bool updateMetaDataToGraphsInfo(const QnnSystemContext_BinaryInfo_t* binaryInfo,GraphInfo_t** graphsInfo,uint32_t& graphsCount);
86
+ bool updateGraphInfo(const QnnSystemContext_GraphInfo_t* graphsInput,
87
+ const uint32_t currCount,
88
+ GraphInfo_t* graphsInfo);
89
+ bool updateGraphInfoV1(const QnnSystemContext_GraphInfoV1_t* graphInfoSrc,
90
+ GraphInfo_t* graphInfoDst);
91
+ bool updateTensorInfo(const Qnn_Tensor_t* tensorsInfoSrc,
92
+ TensorWrapper* tensorWrappers,
93
+ uint32_t tensorsCount);
94
+ uint32_t getNumGraphInBinary(const QnnSystemContext_BinaryInfo_t* binaryInfo);
Genie/Genie/src/qualla/engines/qnn-api/QnnConfig.hpp ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ //==============================================================================
2
+ //
3
+ // Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4
+ // All Rights Reserved.
5
+ // Confidential and Proprietary - Qualcomm Technologies, Inc.
6
+ //
7
+ //==============================================================================
8
+ #pragma once
9
+
10
+ #include "QnnGraph.h"
11
+ #include "QnnTypes.h"
12
+ #include <vector>
13
+
14
+ struct BackendExtensionsConfigs {
15
+ std::string sharedLibraryPath;
16
+ std::string configFilePath;
17
+ BackendExtensionsConfigs() : sharedLibraryPath(""), configFilePath("") {}
18
+ BackendExtensionsConfigs(std::string sharedLibraryPath, std::string configFilePath)
19
+ : sharedLibraryPath(sharedLibraryPath), configFilePath(configFilePath) {}
20
+ };
21
+
22
+ struct ContextConfigs {
23
+ bool priorityPresent;
24
+ Qnn_Priority_t priority;
25
+ ContextConfigs() : priorityPresent(false), priority(QNN_PRIORITY_UNDEFINED) {}
26
+ ContextConfigs(Qnn_Priority_t priority) : priorityPresent(true), priority(priority) {}
27
+ };
28
+
29
+ struct GraphConfigs {
30
+ std::string graphName;
31
+ bool priorityPresent;
32
+ Qnn_Priority_t priority;
33
+ GraphConfigs()
34
+ : graphName(),
35
+ priorityPresent(false), priority(QNN_PRIORITY_UNDEFINED) {
36
+ }
37
+ };
38
+
39
+ struct ConfigOptions {
40
+ BackendExtensionsConfigs backendExtensionsConfigs;
41
+ ContextConfigs contextConfigs;
42
+ std::vector<GraphConfigs> graphConfigs;
43
+ ConfigOptions() : backendExtensionsConfigs(), contextConfigs(), graphConfigs() {}
44
+ };
Genie/Genie/src/qualla/engines/qnn-api/QnnTypeDef.hpp ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
//==============================================================================
//
//  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
//  All Rights Reserved.
//  Confidential and Proprietary - Qualcomm Technologies, Inc.
//
//==============================================================================

#ifndef QNN_TYPE_DEF_H_
#define QNN_TYPE_DEF_H_

#include "QnnInterface.h"
#include "QnnTypes.h"
#include "Log.hpp"
#include "QnnTypeMacros.hpp"

// Error categories shared by the model/graph loading helpers.
typedef enum ModelError {
  MODEL_NO_ERROR               = 0,
  MODEL_TENSOR_ERROR           = 1,
  MODEL_PARAMS_ERROR           = 2,
  MODEL_NODES_ERROR            = 3,
  MODEL_GRAPH_ERROR            = 4,
  MODEL_CONTEXT_ERROR          = 5,
  MODEL_GENERATION_ERROR       = 6,
  MODEL_SETUP_ERROR            = 7,
  MODEL_INVALID_ARGUMENT_ERROR = 8,
  MODEL_FILE_ERROR             = 9,
  MODEL_MEMORY_ALLOCATE_ERROR  = 10,
  // Value selected to ensure 32 bits.
  MODEL_UNKNOWN_ERROR = 0x7FFFFFFF
} ModelError_t;

// A TensorWrapper is currently just a bare Qnn_Tensor_t; these macros keep
// call sites insulated from that representation choice.
using TensorWrapper = Qnn_Tensor_t;
#define GET_TENSOR_WRAPPER_TENSOR(tensorWrapper) tensorWrapper
#define GET_TENSOR_WRAPPER_NAME(tensorWrapper) QNN_TENSOR_GET_NAME(tensorWrapper)

// Per-graph bookkeeping: the graph handle plus deep-copied I/O tensor
// metadata (graphName and the tensor arrays are heap-owned copies made by
// the copyGraphsInfo* helpers and released by freeGraphsInfo/freeGraphInfo).
typedef struct GraphInfo {
  Qnn_GraphHandle_t graph;      // graph handle (populated outside this header)
  char* graphName;              // heap-owned copy of the graph name
  TensorWrapper* inputTensors;  // heap-owned array of input tensor descriptors
  uint32_t numInputTensors;
  TensorWrapper* outputTensors; // heap-owned array of output tensor descriptors
  uint32_t numOutputTensors;
} GraphInfo_t;
typedef GraphInfo_t* GraphInfoPtr_t;

// Optional per-graph configuration overrides, keyed by graph name.
typedef struct GraphConfigInfo {
  char* graphName;
  const QnnGraph_Config_t** graphConfigs;
} GraphConfigInfo_t;

#endif  // QNN_TYPE_DEF_H_
Genie/Genie/src/qualla/engines/qnn-api/QnnTypeMacros.hpp ADDED
@@ -0,0 +1,702 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
//==============================================================================
//
//  Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
//  All Rights Reserved.
//  Confidential and Proprietary - Qualcomm Technologies, Inc.
//
//==============================================================================

#pragma once

#include "QnnTypes.h"

// -----------------------------------------------------------------------------
// Version-dispatching accessors for Qnn_OpConfig_t.
//
// Each accessor comes in a by-reference and a by-pointer overload; the pointer
// overload simply forwards to the reference one. Only QNN_OPCONFIG_VERSION_1
// is handled: getters return NULL/0 for any other version and setters are
// silent no-ops.
// NOTE(review): the pointer overloads dereference without a null check —
// callers must pass non-null pointers.
// -----------------------------------------------------------------------------

// True when the op config uses the (only supported) v1 layout.
#define QNN_OP_CFG_VALID(opConfig) ((opConfig).version == QNN_OPCONFIG_VERSION_1)

// Returns a value-initialized op config stamped with the requested version.
inline Qnn_OpConfig_t createQnnOpConfig(const Qnn_OpConfigVersion_t version) {
    Qnn_OpConfig_t opConfig = QNN_OPCONFIG_INIT;
    opConfig.version = version;
    if (version == QNN_OPCONFIG_VERSION_1) {
        opConfig.v1 = QNN_OPCONFIG_V1_INIT;
    }
    return opConfig;
}

// --- getters ---------------------------------------------------------------

inline const char* getQnnOpConfigName(const Qnn_OpConfig_t& opConfig) {
    if (opConfig.version == QNN_OPCONFIG_VERSION_1) {
        return opConfig.v1.name;
    }
    return NULL;
}

inline const char* getQnnOpConfigName(const Qnn_OpConfig_t* const opConfig) {
    return getQnnOpConfigName(*opConfig);
}

inline const char* getQnnOpConfigPackageName(const Qnn_OpConfig_t& opConfig) {
    if (opConfig.version == QNN_OPCONFIG_VERSION_1) {
        return opConfig.v1.packageName;
    }
    return NULL;
}

inline const char* getQnnOpConfigPackageName(const Qnn_OpConfig_t* const opConfig) {
    return getQnnOpConfigPackageName(*opConfig);
}

inline const char* getQnnOpConfigTypeName(const Qnn_OpConfig_t& opConfig) {
    if (opConfig.version == QNN_OPCONFIG_VERSION_1) {
        return opConfig.v1.typeName;
    }
    return NULL;
}

inline const char* getQnnOpConfigTypeName(const Qnn_OpConfig_t* const opConfig) {
    return getQnnOpConfigTypeName(*opConfig);
}

inline uint32_t getQnnOpConfigNumParams(const Qnn_OpConfig_t& opConfig) {
    if (opConfig.version == QNN_OPCONFIG_VERSION_1) {
        return opConfig.v1.numOfParams;
    }
    return 0u;
}

inline uint32_t getQnnOpConfigNumParams(const Qnn_OpConfig_t* const opConfig) {
    return getQnnOpConfigNumParams(*opConfig);
}

inline const Qnn_Param_t* getQnnOpConfigParams(const Qnn_OpConfig_t& opConfig) {
    if (opConfig.version == QNN_OPCONFIG_VERSION_1) {
        return opConfig.v1.params;
    }
    return NULL;
}

inline const Qnn_Param_t* getQnnOpConfigParams(const Qnn_OpConfig_t* const opConfig) {
    return getQnnOpConfigParams(*opConfig);
}

inline uint32_t getQnnOpConfigNumInputs(const Qnn_OpConfig_t& opConfig) {
    if (opConfig.version == QNN_OPCONFIG_VERSION_1) {
        return opConfig.v1.numOfInputs;
    }
    return 0u;
}

inline uint32_t getQnnOpConfigNumInputs(const Qnn_OpConfig_t* const opConfig) {
    return getQnnOpConfigNumInputs(*opConfig);
}

inline const Qnn_Tensor_t* getQnnOpConfigInputs(const Qnn_OpConfig_t& opConfig) {
    if (opConfig.version == QNN_OPCONFIG_VERSION_1) {
        return opConfig.v1.inputTensors;
    }
    return NULL;
}

inline const Qnn_Tensor_t* getQnnOpConfigInputs(const Qnn_OpConfig_t* const opConfig) {
    return getQnnOpConfigInputs(*opConfig);
}

inline uint32_t getQnnOpConfigNumOutputs(const Qnn_OpConfig_t& opConfig) {
    if (opConfig.version == QNN_OPCONFIG_VERSION_1) {
        return opConfig.v1.numOfOutputs;
    }
    return 0u;
}

inline uint32_t getQnnOpConfigNumOutputs(const Qnn_OpConfig_t* const opConfig) {
    return getQnnOpConfigNumOutputs(*opConfig);
}

inline const Qnn_Tensor_t* getQnnOpConfigOutputs(const Qnn_OpConfig_t& opConfig) {
    if (opConfig.version == QNN_OPCONFIG_VERSION_1) {
        return opConfig.v1.outputTensors;
    }
    return NULL;
}

inline const Qnn_Tensor_t* getQnnOpConfigOutputs(const Qnn_OpConfig_t* const opConfig) {
    return getQnnOpConfigOutputs(*opConfig);
}

// --- setters (no-ops for unsupported versions) -----------------------------

inline void setQnnOpConfigName(Qnn_OpConfig_t& opConfig, const char* const name) {
    if (opConfig.version == QNN_OPCONFIG_VERSION_1) {
        opConfig.v1.name = name;
    }
}

inline void setQnnOpConfigName(Qnn_OpConfig_t* const opConfig, const char* const name) {
    setQnnOpConfigName(*opConfig, name);
}

inline void setQnnOpConfigPackageName(Qnn_OpConfig_t& opConfig, const char* const packageName) {
    if (opConfig.version == QNN_OPCONFIG_VERSION_1) {
        opConfig.v1.packageName = packageName;
    }
}

inline void setQnnOpConfigPackageName(
    Qnn_OpConfig_t* const opConfig,
    const char* const packageName
) {
    setQnnOpConfigPackageName(*opConfig, packageName);
}

inline void setQnnOpConfigTypeName(Qnn_OpConfig_t& opConfig, const char* const typeName) {
    if (opConfig.version == QNN_OPCONFIG_VERSION_1) {
        opConfig.v1.typeName = typeName;
    }
}

inline void setQnnOpConfigTypeName(Qnn_OpConfig_t* const opConfig, const char* const typeName) {
    setQnnOpConfigTypeName(*opConfig, typeName);
}

// Sets the parameter count and array together so they stay consistent.
inline void setQnnOpConfigParams(
    Qnn_OpConfig_t& opConfig,
    uint32_t const numOfParams,
    Qnn_Param_t* const params
) {
    if (opConfig.version == QNN_OPCONFIG_VERSION_1) {
        opConfig.v1.numOfParams = numOfParams;
        opConfig.v1.params      = params;
    }
}

inline void setQnnOpConfigParams(
    Qnn_OpConfig_t* const opConfig,
    uint32_t const numOfParams,
    Qnn_Param_t* const params
) {
    setQnnOpConfigParams(*opConfig, numOfParams, params);
}

// Sets the input count and tensor array together so they stay consistent.
inline void setQnnOpConfigInputs(
    Qnn_OpConfig_t& opConfig,
    uint32_t const numOfInputs,
    Qnn_Tensor_t* const inputTensors
) {
    if (opConfig.version == QNN_OPCONFIG_VERSION_1) {
        opConfig.v1.numOfInputs  = numOfInputs;
        opConfig.v1.inputTensors = inputTensors;
    }
}

inline void setQnnOpConfigInputs(
    Qnn_OpConfig_t* const opConfig,
    uint32_t const numOfInputs,
    Qnn_Tensor_t* const inputTensors
) {
    setQnnOpConfigInputs(*opConfig, numOfInputs, inputTensors);
}

// Sets the output count and tensor array together so they stay consistent.
inline void setQnnOpConfigOutputs(
    Qnn_OpConfig_t& opConfig,
    uint32_t const numOfOutputs,
    Qnn_Tensor_t* const outputTensors
) {
    if (opConfig.version == QNN_OPCONFIG_VERSION_1) {
        opConfig.v1.numOfOutputs  = numOfOutputs;
        opConfig.v1.outputTensors = outputTensors;
    }
}

inline void setQnnOpConfigOutputs(
    Qnn_OpConfig_t* const opConfig,
    uint32_t const numOfOutputs,
    Qnn_Tensor_t* const outputTensors
) {
    setQnnOpConfigOutputs(*opConfig, numOfOutputs, outputTensors);
}
212
+
213
// -----------------------------------------------------------------------------
// Version-dispatching accessors for Qnn_Tensor_t (v1 and v2 layouts).
//
// Fields that are layout-compatible between v1 and v2 are read through the v1
// member without a version check ("TensorCompatTest" below refers to that
// justification). v2-only fields (dynamic dimensions, sparse params) are
// version-checked explicitly.
// NOTE(review): some pointer overloads null-check (QuantParams, Rank) while
// others dereference unconditionally (Id, Name, Dimensions, MemType, ...) —
// callers of the unchecked ones must pass non-null pointers.
// -----------------------------------------------------------------------------

// Returns a value-initialized tensor stamped with the requested version.
inline Qnn_Tensor_t createQnnTensor(const Qnn_TensorVersion_t version) {
    Qnn_Tensor_t tensor = QNN_TENSOR_INIT;
    tensor.version = version;
    if (version == QNN_TENSOR_VERSION_1) {
        tensor.v1 = QNN_TENSOR_V1_INIT;
    } else if (version == QNN_TENSOR_VERSION_2) {
        tensor.v2 = QNN_TENSOR_V2_INIT;
    }
    return tensor;
}

inline uint32_t getQnnTensorId(const Qnn_Tensor_t& tensor) {
    // TensorCompatTest justifies no need to check version
    return tensor.v1.id;
}

inline uint32_t getQnnTensorId(const Qnn_Tensor_t* const tensor) {
    return getQnnTensorId(*tensor);
}

inline const char* getQnnTensorName(const Qnn_Tensor_t& tensor) {
    // TensorCompatTest justifies no need to check version
    return tensor.v1.name;
}

inline const char* getQnnTensorName(const Qnn_Tensor_t* const tensor) {
    return getQnnTensorName(*tensor);
}

inline Qnn_TensorType_t getQnnTensorType(const Qnn_Tensor_t& tensor) {
    // TensorCompatTest justifies no need to check version
    return tensor.v1.type;
}

inline Qnn_TensorType_t getQnnTensorType(const Qnn_Tensor_t* const tensor) {
    return getQnnTensorType(*tensor);
}

inline Qnn_TensorDataFormat_t getQnnTensorDataFormat(const Qnn_Tensor_t& tensor) {
    // TensorCompatTest justifies no need to check version
    return tensor.v1.dataFormat;
}

inline Qnn_TensorDataFormat_t getQnnTensorDataFormat(const Qnn_Tensor_t* const tensor) {
    return getQnnTensorDataFormat(*tensor);
}

inline Qnn_DataType_t getQnnTensorDataType(const Qnn_Tensor_t& tensor) {
    // TensorCompatTest justifies no need to check version
    return tensor.v1.dataType;
}

inline Qnn_DataType_t getQnnTensorDataType(const Qnn_Tensor_t* const tensor) {
    return getQnnTensorDataType(*tensor);
}

inline Qnn_QuantizeParams_t getQnnTensorQuantParams(const Qnn_Tensor_t& tensor) {
    // TensorCompatTest justifies no need to check version
    return tensor.v1.quantizeParams;
}

// Null-tolerant: returns default-initialized params for a null tensor.
inline Qnn_QuantizeParams_t getQnnTensorQuantParams(const Qnn_Tensor_t* const tensor) {
    if (tensor != nullptr) {
        return getQnnTensorQuantParams(*tensor);
    }
    return QNN_QUANTIZE_PARAMS_INIT;
}

inline uint32_t getQnnTensorRank(const Qnn_Tensor_t& tensor) {
    // TensorCompatTest justifies no need to check version
    return tensor.v1.rank;
}

// Null-tolerant: returns rank 0 for a null tensor.
inline uint32_t getQnnTensorRank(const Qnn_Tensor_t* const tensor) {
    if (tensor != nullptr) {
        return getQnnTensorRank(*tensor);
    }
    return 0u;
}

inline uint32_t* getQnnTensorDimensions(const Qnn_Tensor_t& tensor) {
    // TensorCompatTest justifies no need to check version
    return tensor.v1.dimensions;
}

inline uint32_t* getQnnTensorDimensions(const Qnn_Tensor_t* const tensor) {
    return getQnnTensorDimensions(*tensor);
}

// v2-only field: NULL for v1 (and any unknown) tensor versions.
inline uint8_t* getQnnTensorIsDynamicDimensions(const Qnn_Tensor_t& tensor) {
    if (tensor.version == QNN_TENSOR_VERSION_1) {
        return NULL;
    } else if (tensor.version == QNN_TENSOR_VERSION_2) {
        return tensor.v2.isDynamicDimensions;
    }
    return NULL;
}

inline uint8_t* getQnnTensorIsDynamicDimensions(const Qnn_Tensor_t* tensor) {
    return getQnnTensorIsDynamicDimensions(*tensor);
}

// v2-only field: default-initialized params for v1 (and unknown) versions.
inline Qnn_SparseParams_t getQnnTensorSparseParams(const Qnn_Tensor_t& tensor) {
    if (tensor.version == QNN_TENSOR_VERSION_1) {
        return QNN_SPARSE_PARAMS_INIT;
    } else if (tensor.version == QNN_TENSOR_VERSION_2) {
        return tensor.v2.sparseParams;
    }
    return QNN_SPARSE_PARAMS_INIT;
}

inline Qnn_SparseParams_t getQnnTensorSparseParams(const Qnn_Tensor_t* tensor) {
    return getQnnTensorSparseParams(*tensor);
}

inline Qnn_TensorMemType_t getQnnTensorMemType(const Qnn_Tensor_t& tensor) {
    // TensorCompatTest justifies no need to check version
    return tensor.v1.memType;
}

inline Qnn_TensorMemType_t getQnnTensorMemType(const Qnn_Tensor_t* const tensor) {
    return getQnnTensorMemType(*tensor);
}
336
+
337
+ inline Qnn_ClientBuffer_t getQnnTensorClientBuf(const Qnn_Tensor_t& tensor) {
338
+ // TensorCompatTest justifies no need to check version
339
+ return tensor.v1.clientBuf;
340
+ }
341
+
342
+ inline Qnn_ClientBuffer_t getQnnTensorClientBuf(const Qnn_Tensor_t* const tensor) {
343
+ return getQnnTensorClientBuf(*tensor);
344
+ }
345
+
346
+ inline Qnn_MemHandle_t getQnnTensorMemHandle(const Qnn_Tensor_t& tensor) {
347
+ // TensorCompatTest justifies no need to check version
348
+ return tensor.v1.memHandle;
349
+ }
350
+
351
+ inline Qnn_MemHandle_t getQnnTensorMemHandle(const Qnn_Tensor_t* const tensor) {
352
+ return getQnnTensorMemHandle(*tensor);
353
+ }
354
+
355
+ inline void setQnnTensorId(Qnn_Tensor_t& tensor, const uint32_t id) {
356
+ // TensorCompatTest justifies no need to check version
357
+ tensor.v1.id = id;
358
+ }
359
+
360
+ inline void setQnnTensorId(Qnn_Tensor_t* const tensor, const uint32_t id) {
361
+ setQnnTensorId(*tensor, id);
362
+ }
363
+
364
+ inline void setQnnTensorName(Qnn_Tensor_t& tensor, const char* const name) {
365
+ // TensorCompatTest justifies no need to check version
366
+ tensor.v1.name = name;
367
+ }
368
+
369
+ inline void setQnnTensorName(Qnn_Tensor_t* const tensor, const char* const name) {
370
+ setQnnTensorName(*tensor, name);
371
+ }
372
+
373
+ inline void setQnnTensorType(Qnn_Tensor_t& tensor, const Qnn_TensorType_t type) {
374
+ // TensorCompatTest justifies no need to check version
375
+ tensor.v1.type = type;
376
+ }
377
+
378
+ inline void setQnnTensorType(Qnn_Tensor_t* const tensor, const Qnn_TensorType_t type) {
379
+ setQnnTensorType(*tensor, type);
380
+ }
381
+
382
+ inline void setQnnTensorDataFormat(Qnn_Tensor_t& tensor, const Qnn_TensorDataFormat_t dataFormat) {
383
+ // TensorCompatTest justifies no need to check version
384
+ tensor.v1.dataFormat = dataFormat;
385
+ }
386
+
387
+ inline void setQnnTensorDataFormat(
388
+ Qnn_Tensor_t* const tensor,
389
+ const Qnn_TensorDataFormat_t format
390
+ ) {
391
+ setQnnTensorDataFormat(*tensor, format);
392
+ }
393
+
394
+ inline void setQnnTensorDataType(Qnn_Tensor_t& tensor, const Qnn_DataType_t dataType) {
395
+ // TensorCompatTest justifies no need to check version
396
+ tensor.v1.dataType = dataType;
397
+ }
398
+
399
+ inline void setQnnTensorDataType(Qnn_Tensor_t* const tensor, const Qnn_DataType_t dataType) {
400
+ setQnnTensorDataType(*tensor, dataType);
401
+ }
402
+
403
+ inline void setQnnTensorQuantParams(
404
+ Qnn_Tensor_t& tensor,
405
+ const Qnn_QuantizeParams_t quantizeParams
406
+ ) {
407
+ // TensorCompatTest justifies no need to check version
408
+ tensor.v1.quantizeParams = quantizeParams;
409
+ }
410
+
411
+ inline void setQnnTensorQuantParams(Qnn_Tensor_t* const tensor, const Qnn_QuantizeParams_t params) {
412
+ setQnnTensorQuantParams(*tensor, params);
413
+ }
414
+
415
+ inline void setQnnTensorRank(Qnn_Tensor_t& tensor, const uint32_t rank) {
416
+ // TensorCompatTest justifies no need to check version
417
+ tensor.v1.rank = rank;
418
+ }
419
+
420
+ inline void setQnnTensorRank(Qnn_Tensor_t* const tensor, const uint32_t rank) {
421
+ setQnnTensorRank(*tensor, rank);
422
+ }
423
+
424
+ inline void setQnnTensorDimensions(Qnn_Tensor_t& tensor, uint32_t* const dimensions) {
425
+ // TensorCompatTest justifies no need to check version
426
+ tensor.v1.dimensions = dimensions;
427
+ }
428
+
429
+ inline void setQnnTensorDimensions(Qnn_Tensor_t* const tensor, uint32_t* const dimensions) {
430
+ setQnnTensorDimensions(*tensor, dimensions);
431
+ }
432
+
433
+ inline void setQnnTensorIsDynamicDimensions(
434
+ Qnn_Tensor_t& tensor,
435
+ uint8_t* const isDynamicDimensions
436
+ ) {
437
+ if (tensor.version == QNN_TENSOR_VERSION_2) {
438
+ tensor.v2.isDynamicDimensions = isDynamicDimensions;
439
+ }
440
+ }
441
+
442
+ inline void setQnnTensorIsDynamicDimensions(
443
+ Qnn_Tensor_t* tensor,
444
+ uint8_t* const isDynamicDimensions
445
+ ) {
446
+ setQnnTensorIsDynamicDimensions(*tensor, isDynamicDimensions);
447
+ }
448
+
449
+ inline void setQnnTensorSparseParams(Qnn_Tensor_t& tensor, const Qnn_SparseParams_t sparseParams) {
450
+ if (tensor.version == QNN_TENSOR_VERSION_2) {
451
+ tensor.v2.sparseParams = sparseParams;
452
+ }
453
+ }
454
+
455
+ inline void setQnnTensorSparseParams(Qnn_Tensor_t* tensor, Qnn_SparseParams_t sparseParams) {
456
+ setQnnTensorSparseParams(*tensor, sparseParams);
457
+ }
458
+
459
+ inline void setQnnTensorMemType(Qnn_Tensor_t& tensor, const Qnn_TensorMemType_t memType) {
460
+ // TensorCompatTest justifies no need to check version
461
+ tensor.v1.memType = memType;
462
+ }
463
+
464
+ inline void setQnnTensorMemType(Qnn_Tensor_t* const tensor, const Qnn_TensorMemType_t memType) {
465
+ setQnnTensorMemType(*tensor, memType);
466
+ }
467
+
468
+ inline void setQnnTensorClientBuf(Qnn_Tensor_t& tensor, const Qnn_ClientBuffer_t clientBuf) {
469
+ // TensorCompatTest justifies no need to check version
470
+ tensor.v1.clientBuf = clientBuf;
471
+ }
472
+
473
+ inline void setQnnTensorClientBuf(Qnn_Tensor_t* const tensor, const Qnn_ClientBuffer_t clientBuf) {
474
+ setQnnTensorClientBuf(*tensor, clientBuf);
475
+ }
476
+
477
+ inline void setQnnTensorMemHandle(Qnn_Tensor_t& tensor, const Qnn_MemHandle_t memHandle) {
478
+ // TensorCompatTest justifies no need to check version
479
+ tensor.v1.memHandle = memHandle;
480
+ }
481
+
482
+ inline void setQnnTensorMemHandle(Qnn_Tensor_t* const tensor, const Qnn_MemHandle_t handle) {
483
+ setQnnTensorMemHandle(*tensor, handle);
484
+ }
485
+
486
+ inline Qnn_TensorSet_t createQnnTensorSet(const Qnn_TensorSetVersion_t version) {
487
+ Qnn_TensorSet_t tensorSet = QNN_TENSOR_SET_INIT;
488
+ tensorSet.version = version;
489
+ if (version == QNN_TENSOR_SET_VERSION_1) {
490
+ tensorSet.v1 = QNN_TENSOR_SET_V1_INIT;
491
+ }
492
+ return tensorSet;
493
+ }
494
+
495
+ inline uint32_t getQnnTensorSetNumInputs(const Qnn_TensorSet_t& tensorSet) {
496
+ if (tensorSet.version == QNN_TENSOR_SET_VERSION_1) {
497
+ return tensorSet.v1.numInputs;
498
+ }
499
+ return 0;
500
+ }
501
+
502
+ inline uint32_t getQnnTensorSetNumInputs(const Qnn_TensorSet_t* tensorSet) {
503
+ return getQnnTensorSetNumInputs(*tensorSet);
504
+ }
505
+
506
+ inline Qnn_Tensor_t* getQnnTensorSetInputTensors(const Qnn_TensorSet_t& tensorSet) {
507
+ if (tensorSet.version == QNN_TENSOR_SET_VERSION_1) {
508
+ return tensorSet.v1.inputs;
509
+ }
510
+ return 0;
511
+ }
512
+
513
+ inline Qnn_Tensor_t* getQnnTensorSetInputTensors(const Qnn_TensorSet_t* tensorSet) {
514
+ return getQnnTensorSetInputTensors(*tensorSet);
515
+ }
516
+
517
+ inline uint32_t getQnnTensorSetNumOutputs(const Qnn_TensorSet_t& tensorSet) {
518
+ if (tensorSet.version == QNN_TENSOR_SET_VERSION_1) {
519
+ return tensorSet.v1.numOutputs;
520
+ }
521
+ return 0;
522
+ }
523
+
524
+ inline uint32_t getQnnTensorSetNumOutputs(const Qnn_TensorSet_t* tensorSet) {
525
+ return getQnnTensorSetNumOutputs(*tensorSet);
526
+ }
527
+
528
+ inline Qnn_Tensor_t* getQnnTensorSetOutputTensors(const Qnn_TensorSet_t& tensorSet) {
529
+ if (tensorSet.version == QNN_TENSOR_SET_VERSION_1) {
530
+ return tensorSet.v1.outputs;
531
+ }
532
+ return 0;
533
+ }
534
+
535
+ inline Qnn_Tensor_t* getQnnTensorSetOutputTensors(const Qnn_TensorSet_t* tensorSet) {
536
+ return getQnnTensorSetOutputTensors(*tensorSet);
537
+ }
538
+
539
+ inline void setQnnTensorSetInputTensors(
540
+ Qnn_TensorSet_t& tensorSet,
541
+ Qnn_Tensor_t* inputTensors,
542
+ uint32_t const numInputs
543
+ ) {
544
+ if (tensorSet.version == QNN_TENSOR_SET_VERSION_1) {
545
+ tensorSet.v1.inputs = inputTensors;
546
+ tensorSet.v1.numInputs = numInputs;
547
+ }
548
+ }
549
+
550
+ inline void setQnnTensorSetInputTensors(
551
+ Qnn_TensorSet_t* tensorSet,
552
+ Qnn_Tensor_t* inputTensors,
553
+ uint32_t const numInputs
554
+ ) {
555
+ setQnnTensorSetInputTensors(*tensorSet, inputTensors, numInputs);
556
+ }
557
+
558
+ inline void setQnnTensorSetOutputTensors(
559
+ Qnn_TensorSet_t& tensorSet,
560
+ Qnn_Tensor_t* outputTensors,
561
+ const uint32_t numOutputs
562
+ ) {
563
+ if (tensorSet.version == QNN_TENSOR_SET_VERSION_1) {
564
+ tensorSet.v1.outputs = outputTensors;
565
+ tensorSet.v1.numOutputs = numOutputs;
566
+ }
567
+ }
568
+
569
+ inline void setQnnTensorSetOutputTensors(
570
+ Qnn_TensorSet_t* tensorSet,
571
+ Qnn_Tensor_t* outputTensors,
572
+ const uint32_t numOutputs
573
+ ) {
574
+ setQnnTensorSetOutputTensors(*tensorSet, outputTensors, numOutputs);
575
+ }
// ---------------------------------------------------------------------------
// Convenience macros wrapping the overloaded helper functions above, so call
// sites read uniformly whether they hold a struct, a reference, or a pointer.
// ---------------------------------------------------------------------------

// --- QNN Op Config: creator ---
#define QNN_OP_CFG_CREATE(version) createQnnOpConfig(version)

// --- QNN Op Config: accessors ---
#define QNN_OP_CFG_GET_NAME(opConfig) getQnnOpConfigName(opConfig)
#define QNN_OP_CFG_GET_PACKAGE_NAME(opConfig) getQnnOpConfigPackageName(opConfig)
#define QNN_OP_CFG_GET_TYPE_NAME(opConfig) getQnnOpConfigTypeName(opConfig)
#define QNN_OP_CFG_GET_NUM_PARAMS(opConfig) getQnnOpConfigNumParams(opConfig)
#define QNN_OP_CFG_GET_PARAMS(opConfig) getQnnOpConfigParams(opConfig)
#define QNN_OP_CFG_GET_NUM_INPUTS(opConfig) getQnnOpConfigNumInputs(opConfig)
#define QNN_OP_CFG_GET_INPUTS(opConfig) getQnnOpConfigInputs(opConfig)
#define QNN_OP_CFG_GET_NUM_OUTPUTS(opConfig) getQnnOpConfigNumOutputs(opConfig)
#define QNN_OP_CFG_GET_OUTPUTS(opConfig) getQnnOpConfigOutputs(opConfig)

// --- QNN Op Config: modifiers ---
#define QNN_OP_CFG_SET_NAME(opConfig, value) setQnnOpConfigName(opConfig, value)
#define QNN_OP_CFG_SET_PACKAGE_NAME(opConfig, value) setQnnOpConfigPackageName(opConfig, value)
#define QNN_OP_CFG_SET_TYPE_NAME(opConfig, value) setQnnOpConfigTypeName(opConfig, value)
#define QNN_OP_CFG_SET_PARAMS(opConfig, numOfParams, params) \
    setQnnOpConfigParams(opConfig, numOfParams, params)
#define QNN_OP_CFG_SET_INPUTS(opConfig, numOfInputs, inputTensors) \
    setQnnOpConfigInputs(opConfig, numOfInputs, inputTensors)
#define QNN_OP_CFG_SET_OUTPUTS(opConfig, numOfOutputs, outputTensors) \
    setQnnOpConfigOutputs(opConfig, numOfOutputs, outputTensors)

// --- QNN Tensor: creator ---
#define QNN_TENSOR_CREATE(version) createQnnTensor(version)

// --- QNN Tensor: accessors ---
#define QNN_TENSOR_GET_ID(tensor) getQnnTensorId(tensor)
#define QNN_TENSOR_GET_NAME(tensor) getQnnTensorName(tensor)
#define QNN_TENSOR_GET_TYPE(tensor) getQnnTensorType(tensor)
#define QNN_TENSOR_GET_DATA_FORMAT(tensor) getQnnTensorDataFormat(tensor)
#define QNN_TENSOR_GET_DATA_TYPE(tensor) getQnnTensorDataType(tensor)
#define QNN_TENSOR_GET_QUANT_PARAMS(tensor) getQnnTensorQuantParams(tensor)
#define QNN_TENSOR_GET_RANK(tensor) getQnnTensorRank(tensor)
#define QNN_TENSOR_GET_DIMENSIONS(tensor) getQnnTensorDimensions(tensor)
#define QNN_TENSOR_GET_IS_DYNAMIC_DIMENSIONS(tensor) getQnnTensorIsDynamicDimensions(tensor)
#define QNN_TENSOR_GET_SPARSE_PARAMS(tensor) getQnnTensorSparseParams(tensor)
#define QNN_TENSOR_GET_MEM_TYPE(tensor) getQnnTensorMemType(tensor)
#define QNN_TENSOR_GET_CLIENT_BUF(tensor) getQnnTensorClientBuf(tensor)
#define QNN_TENSOR_GET_MEM_HANDLE(tensor) getQnnTensorMemHandle(tensor)

// --- QNN Tensor: modifiers ---
#define QNN_TENSOR_SET_ID(tensor, value) setQnnTensorId(tensor, value)
#define QNN_TENSOR_SET_NAME(tensor, value) setQnnTensorName(tensor, value)
#define QNN_TENSOR_SET_TYPE(tensor, value) setQnnTensorType(tensor, value)
#define QNN_TENSOR_SET_DATA_FORMAT(tensor, value) setQnnTensorDataFormat(tensor, value)
#define QNN_TENSOR_SET_DATA_TYPE(tensor, value) setQnnTensorDataType(tensor, value)
#define QNN_TENSOR_SET_QUANT_PARAMS(tensor, value) setQnnTensorQuantParams(tensor, value)
#define QNN_TENSOR_SET_RANK(tensor, value) setQnnTensorRank(tensor, value)
#define QNN_TENSOR_SET_DIMENSIONS(tensor, value) setQnnTensorDimensions(tensor, value)
#define QNN_TENSOR_SET_IS_DYNAMIC_DIMENSIONS(tensor, value) \
    setQnnTensorIsDynamicDimensions(tensor, value)
#define QNN_TENSOR_SET_SPARSE_PARAMS(tensor, value) setQnnTensorSparseParams(tensor, value)
#define QNN_TENSOR_SET_MEM_TYPE(tensor, value) setQnnTensorMemType(tensor, value)
#define QNN_TENSOR_SET_CLIENT_BUF(tensor, value) setQnnTensorClientBuf(tensor, value)
#define QNN_TENSOR_SET_MEM_HANDLE(tensor, value) setQnnTensorMemHandle(tensor, value)

// --- QNN Tensor Set: creator ---
#define QNN_TENSORSET_CREATE(version) createQnnTensorSet(version)

// --- QNN Tensor Set: accessors ---
#define QNN_TENSORSET_GET_NUM_INPUTS(tensorSet) getQnnTensorSetNumInputs(tensorSet)
#define QNN_TENSORSET_GET_INPUT_TENSORS(tensorSet) getQnnTensorSetInputTensors(tensorSet)
#define QNN_TENSORSET_GET_NUM_OUTPUTS(tensorSet) getQnnTensorSetNumOutputs(tensorSet)
#define QNN_TENSORSET_GET_OUTPUT_TENSORS(tensorSet) getQnnTensorSetOutputTensors(tensorSet)

// --- QNN Tensor Set: modifiers ---
#define QNN_TENSORSET_SET_INPUT_TENSORS(tensorSet, inputTensors, numInputs) \
    setQnnTensorSetInputTensors(tensorSet, inputTensors, numInputs)
#define QNN_TENSORSET_SET_OUTPUT_TENSORS(tensorSet, outputTensors, numOutputs) \
    setQnnTensorSetOutputTensors(tensorSet, outputTensors, numOutputs)
650
+
651
+ inline bool isQnnTensorV1Compatible(const Qnn_Tensor_t& tensor) {
652
+ if (tensor.version == QNN_TENSOR_VERSION_2) {
653
+ if (tensor.v2.isDynamicDimensions != NULL) {
654
+ return false;
655
+ }
656
+
657
+ if (tensor.v2.dataFormat == QNN_TENSOR_DATA_FORMAT_SPARSE) {
658
+ return false;
659
+ }
660
+ }
661
+
662
+ return true;
663
+ }
664
+
665
+ inline bool isQnnTensorV1Compatible(const Qnn_Tensor_t* const tensor) {
666
+ return isQnnTensorV1Compatible(*tensor);
667
+ }
668
+
669
+ inline bool isQnnTensorV1Compatible(const Qnn_OpConfig_t& opConfig) {
670
+ if ((QNN_OP_CFG_GET_INPUTS(opConfig) != NULL) && (QNN_OP_CFG_GET_NUM_INPUTS(opConfig) > 0u)) {
671
+ for (uint32_t tensorIdx = 0u; tensorIdx < QNN_OP_CFG_GET_NUM_INPUTS(opConfig);
672
+ tensorIdx++) {
673
+ if (!isQnnTensorV1Compatible(QNN_OP_CFG_GET_INPUTS(opConfig)[tensorIdx])) {
674
+ return false;
675
+ }
676
+ }
677
+ }
678
+ if ((QNN_OP_CFG_GET_OUTPUTS(opConfig) != NULL) && (QNN_OP_CFG_GET_NUM_OUTPUTS(opConfig) > 0u)) {
679
+ for (uint32_t tensorIdx = 0u; tensorIdx < QNN_OP_CFG_GET_NUM_OUTPUTS(opConfig);
680
+ tensorIdx++) {
681
+ if (!isQnnTensorV1Compatible(QNN_OP_CFG_GET_OUTPUTS(opConfig)[tensorIdx])) {
682
+ return false;
683
+ }
684
+ }
685
+ }
686
+ if ((QNN_OP_CFG_GET_PARAMS(opConfig) != NULL) && (QNN_OP_CFG_GET_NUM_PARAMS(opConfig) > 0)) {
687
+ for (uint32_t paramIdx = 0u; paramIdx < QNN_OP_CFG_GET_NUM_PARAMS(opConfig); paramIdx++) {
688
+ const Qnn_Param_t& param = QNN_OP_CFG_GET_PARAMS(opConfig)[paramIdx];
689
+ if (QNN_PARAMTYPE_TENSOR == param.paramType) {
690
+ if (!isQnnTensorV1Compatible(param.tensorParam)) {
691
+ return false;
692
+ }
693
+ }
694
+ }
695
+ }
696
+
697
+ return true;
698
+ }
699
+
700
+ inline bool isQnnTensorV1Compatible(const Qnn_OpConfig_t* const opConfig) {
701
+ return isQnnTensorV1Compatible(*opConfig);
702
+ }
Genie/Genie/src/qualla/engines/qnn-api/RpcMem.cpp ADDED
@@ -0,0 +1,481 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ //==============================================================================
2
+ //
3
+ // Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4
+ // All Rights Reserved.
5
+ // Confidential and Proprietary - Qualcomm Technologies, Inc.
6
+ //
7
+ //==============================================================================
8
+
9
+ #include "QnnMem.h"
10
+ #include "QnnHtpMem.h"
11
+ #include "RpcMem.hpp"
12
+ #include "QnnTypeMacros.hpp"
13
+ #include "dlwrap.hpp"
14
+
15
+ #define RPCMEM_HEAP_ID_SYSTEM 25
16
+ #define RPCMEM_DEFAULT_FLAGS 1
17
+
18
+ #if 1
19
+ #define TRACE_MEMORY_ALLOC QNN_DEBUG
20
+ #else
21
+ #define TRACE_MEMORY_ALLOC(fmt, ...)
22
+ #endif
23
+
24
+ RpcMem::RpcMem(Qnn_ContextHandle_t contextHandle, QNN_INTERFACE_VER_TYPE* qnnInterface)
25
+ : m_libCdspRpc(nullptr), m_rpcMemAlloc(nullptr), m_rpcMemFree(nullptr), m_rpcMemToFd(nullptr),
26
+ m_qnnInterface(qnnInterface), m_contextHandle(contextHandle) {
27
+ (void)m_contextHandle;
28
+ }
29
+
30
+ bool RpcMem::initialize() {
31
+ // On Android, 32-bit and 64-bit libcdsprpc.so can be found at /vendor/lib and /vendor/lib64 respectively.
32
+ // On Windows, it's installed into something like this
33
+ // c:\Windows\System32\DriverStore\FileRepository\qcnspmcdm8380.inf_arm64_30b9cc995571de6a\libcdsprpc.dll
34
+ #ifdef _WIN32
35
+ const char* dsprpc_so = "libcdsprpc.dll";
36
+ #else
37
+ const char* dsprpc_so = "libcdsprpc.so";
38
+ #endif
39
+
40
+ m_libCdspRpc = dlopen(dsprpc_so, RTLD_NOW | RTLD_LOCAL);
41
+ if (nullptr == m_libCdspRpc) {
42
+ QNN_ERROR("Unable to load backend. dlerror(): %s", dlerror());
43
+ return false;
44
+ }
45
+ m_rpcMemAlloc = (RpcMemAllocFn_t)dlsym(m_libCdspRpc, "rpcmem_alloc");
46
+ m_rpcMemFree = (RpcMemFreeFn_t)dlsym(m_libCdspRpc, "rpcmem_free");
47
+ m_rpcMemToFd = (RpcMemToFdFn_t)dlsym(m_libCdspRpc, "rpcmem_to_fd");
48
+ if (nullptr == m_rpcMemAlloc || nullptr == m_rpcMemFree || nullptr == m_rpcMemToFd) {
49
+ QNN_ERROR("Unable to access symbols in libcdsprpc. dlerror(): %s", dlerror());
50
+ return false;
51
+ }
52
+
53
+ return true;
54
+ }
55
+
56
+ RpcMem::~RpcMem() {
57
+ if (m_libCdspRpc) {
58
+ QNN_DEBUG("Closing libcdsprpc.so handle");
59
+ dlclose(m_libCdspRpc);
60
+ }
61
+ }
62
+
63
+ RpcMemTensorData* RpcMem::getRpcMemTensorData(Qnn_Tensor_t* tensor) {
64
+ if (tensor == nullptr) return nullptr;
65
+ Qnn_MemHandle_t mem_handle = QNN_TENSOR_GET_MEM_HANDLE(tensor);
66
+ if (mem_handle == nullptr) return nullptr;
67
+ return &m_memHandleToRpcMem.at(mem_handle);
68
+ }
69
+
70
+ void* RpcMem::getBuffer(Qnn_Tensor_t* tensor) {
71
+ RpcMemTensorData* data = getRpcMemTensorData(tensor);
72
+ if (data == nullptr) {
73
+ QNN_ERROR("getBuffer : Couldn't find tensor %p", tensor);
74
+ return nullptr;
75
+ }
76
+ return data->memPointer;
77
+ }
78
+
79
+ int RpcMem::getFd(Qnn_Tensor_t* tensor) {
80
+ RpcMemTensorData* data = getRpcMemTensorData(tensor);
81
+ if (data == nullptr) {
82
+ QNN_ERROR("getFd : Couldn't find tensor %p", tensor);
83
+ return -1;
84
+ }
85
+ return data->fd;
86
+ }
87
+
88
+ size_t RpcMem::getOffset(Qnn_Tensor_t* tensor) {
89
+ RpcMemTensorData* data = getRpcMemTensorData(tensor);
90
+ if (data == nullptr) {
91
+ QNN_ERROR("getOffset : Couldn't find tensor %p", tensor);
92
+ return 0;
93
+ }
94
+ return data->offset;
95
+ }
96
+
97
+ size_t RpcMem::getBufferSize(Qnn_Tensor_t* tensor) {
98
+ RpcMemTensorData* data = getRpcMemTensorData(tensor);
99
+ if (data == nullptr) {
100
+ QNN_ERROR("getBufferSize : Couldn't find tensor %p", tensor);
101
+ return 0;
102
+ }
103
+ return data->size;
104
+ };
105
+
106
+ size_t RpcMem::getTotalBufferSize(Qnn_Tensor_t* tensor) {
107
+ RpcMemTensorData* data = getRpcMemTensorData(tensor);
108
+ if (data == nullptr) {
109
+ QNN_ERROR("getTotalBufferSize : Couldn't find tensor %p", tensor);
110
+ return 0;
111
+ }
112
+ return data->totalBufferSize;
113
+ }
114
+
115
+ bool RpcMem::allocateTensorBuffer(Qnn_Tensor_t* tensor, size_t tensorDataSize) {
116
+ if (m_libCdspRpc == nullptr) {
117
+ QNN_ERROR("RpcMem not initialized");
118
+ return false;
119
+ }
120
+ if (!tensor) {
121
+ QNN_ERROR("Received nullptr for tensor");
122
+ return false;
123
+ }
124
+ if (m_tensorToRpcMem.find(tensor) != m_tensorToRpcMem.end()) {
125
+ QNN_ERROR("Tensor already allocated");
126
+ return false;
127
+ }
128
+
129
+ auto memPointer = m_rpcMemAlloc(RPCMEM_HEAP_ID_SYSTEM, RPCMEM_DEFAULT_FLAGS, tensorDataSize);
130
+ auto status = true;
131
+ if (!memPointer) {
132
+ QNN_ERROR("rpcmem_alloc failure");
133
+ status = false;
134
+ }
135
+ int memfd = -1;
136
+ if (status == true) {
137
+ memfd = m_rpcMemToFd(memPointer);
138
+ if (memfd == -1) {
139
+ QNN_ERROR("rpcmem_to_fd failure");
140
+ status = false;
141
+ }
142
+ }
143
+ if (status == true) {
144
+ Qnn_MemDescriptor_t memDescriptor = {
145
+ {QNN_TENSOR_GET_RANK(tensor), QNN_TENSOR_GET_DIMENSIONS(tensor), nullptr},
146
+ QNN_TENSOR_GET_DATA_TYPE(tensor),
147
+ QNN_MEM_TYPE_ION,
148
+ {{-1}}
149
+ };
150
+ memDescriptor.ionInfo.fd = memfd;
151
+ QNN_TENSOR_SET_MEM_TYPE(tensor, QNN_TENSORMEMTYPE_MEMHANDLE);
152
+ QNN_TENSOR_SET_MEM_HANDLE(tensor, nullptr);
153
+
154
+ Qnn_MemHandle_t memHandle = QNN_TENSOR_GET_MEM_HANDLE(tensor);
155
+ if (QNN_SUCCESS != m_qnnInterface->memRegister(
156
+ m_contextHandle,
157
+ &memDescriptor,
158
+ 1,
159
+ &(memHandle)
160
+ )) {
161
+ const char* tname = QNN_TENSOR_GET_NAME(tensor);
162
+ QNN_ERROR("memRegister fail %s (ctx=%p fd=%d)", tname, m_contextHandle, memfd);
163
+ status = false;
164
+ }
165
+ QNN_TENSOR_SET_MEM_HANDLE(tensor, memHandle);
166
+ }
167
+ if (status == true) {
168
+ m_tensorToRpcMem.insert({tensor, RpcMemTensorData(memfd, memPointer, tensorDataSize)});
169
+ }
170
+ if (status == false) {
171
+ if (m_rpcMemFree) {
172
+ m_rpcMemFree(memPointer);
173
+ }
174
+ }
175
+ return status;
176
+ }
177
+
178
+ bool RpcMem::freeTensorBuffer(Qnn_Tensor_t* tensor) {
179
+ if (!tensor) {
180
+ QNN_ERROR("Received nullptr for tensor");
181
+ return false;
182
+ }
183
+
184
+ if (m_sameMemoryFreeTensors.find(tensor) != m_sameMemoryFreeTensors.end()) {
185
+ if (m_tensorToRpcMem.find(tensor) == m_tensorToRpcMem.end()) {
186
+ QNN_ERROR("Tensor not found");
187
+ return false;
188
+ }
189
+ m_tensorToRpcMem.erase(tensor);
190
+ } else {
191
+ auto memHandle = QNN_TENSOR_GET_MEM_HANDLE(tensor);
192
+ if (QNN_SUCCESS != m_qnnInterface->memDeRegister(&memHandle, 1)) {
193
+ QNN_ERROR("Failed to deregister ion memory with the backend");
194
+ return false;
195
+ }
196
+ QNN_TENSOR_SET_MEM_TYPE(tensor, QNN_TENSORMEMTYPE_UNDEFINED);
197
+ if (m_tensorToRpcMem.find(tensor) == m_tensorToRpcMem.end()) {
198
+ QNN_ERROR("Tensor not found");
199
+ return false;
200
+ }
201
+ if (m_rpcMemFree) {
202
+ m_rpcMemFree(m_tensorToRpcMem[tensor].memPointer);
203
+ }
204
+ m_tensorToRpcMem.erase(tensor);
205
+ }
206
+
207
+ return true;
208
+ }
209
+
210
+ bool RpcMem::useSameMemory(Qnn_Tensor_t* dest, Qnn_Tensor_t* src) {
211
+ if (nullptr == dest || nullptr == src) {
212
+ QNN_ERROR("Received nullptr");
213
+ return false;
214
+ }
215
+ if (m_tensorToRpcMem.find(src) == m_tensorToRpcMem.end()) {
216
+ QNN_ERROR("Src Tensor not found");
217
+ return false;
218
+ }
219
+
220
+ if (false == freeTensorBuffer(dest)) {
221
+ return false;
222
+ }
223
+
224
+ QNN_TENSOR_SET_MEM_TYPE(dest, QNN_TENSOR_GET_MEM_TYPE(src));
225
+ QNN_TENSOR_SET_MEM_HANDLE(dest, QNN_TENSOR_GET_MEM_HANDLE(src));
226
+ m_tensorToRpcMem.insert({dest, m_tensorToRpcMem[src]});
227
+ m_sameMemoryFreeTensors.insert(dest);
228
+
229
+ return true;
230
+ }
231
+
232
+ bool RpcMem::useSameMemory(Qnn_Tensor_t* dest, Qnn_Tensor_t* src, int offset) {
233
+ if (nullptr == dest || nullptr == src) {
234
+ QNN_ERROR("Received nullptr");
235
+ return false;
236
+ }
237
+ if (m_tensorToRpcMem.find(src) == m_tensorToRpcMem.end()) {
238
+ QNN_ERROR("Src Tensor not found");
239
+ return false;
240
+ }
241
+
242
+ if (false == freeTensorBuffer(dest)) {
243
+ return false;
244
+ }
245
+
246
+ QNN_TENSOR_SET_MEM_TYPE(dest, QNN_TENSOR_GET_MEM_TYPE(src));
247
+ QNN_TENSOR_SET_MEM_HANDLE(dest, QNN_TENSOR_GET_MEM_HANDLE(src));
248
+ m_tensorToRpcMem.insert({dest, m_tensorToRpcMem[src]});
249
+ m_sameMemoryFreeTensors.insert(dest);
250
+
251
+ return true;
252
+ }
253
+
254
+ bool RpcMem::useExternalMemory(Qnn_Tensor_t* dest, void* extMem) {
255
+ QNN_ERROR("We don't support external memory feature for shared buffers yet!");
256
+ return false;
257
+ }
258
+
259
+ void* RpcMem::allocateTensorFusedBuffer(uint64_t bufferSize, int32_t* fd) {
260
+ *fd = -1;
261
+ if (m_libCdspRpc == nullptr) {
262
+ QNN_ERROR("RpcMem not initialized for fused buffer");
263
+ return nullptr;
264
+ }
265
+
266
+ void* memPointer = m_rpcMemAlloc(RPCMEM_HEAP_ID_SYSTEM, RPCMEM_DEFAULT_FLAGS, bufferSize);
267
+ if (!memPointer) {
268
+ QNN_ERROR("Not able to allocate fused buffer of size: %lu", (unsigned long)bufferSize);
269
+ return nullptr;
270
+ }
271
+
272
+ m_fusedBuffers.push_back({memPointer, bufferSize});
273
+ QNN_DEBUG(
274
+ "Successfully allocated fused buffer at %p with size %lu",
275
+ memPointer,
276
+ (unsigned long)bufferSize
277
+ );
278
+
279
+ if ((*fd = m_rpcMemToFd(memPointer)) == -1) {
280
+ QNN_ERROR(
281
+ "Not able to get fd for the fused buffer of size: %lu", (unsigned long)bufferSize
282
+ );
283
+ return nullptr;
284
+ }
285
+
286
+ QNN_DEBUG("Retrieved fd %d for pointer %p", *fd, memPointer);
287
+ return memPointer;
288
+ }
289
+
290
+ bool RpcMem::allocateBuffers(
291
+ const std::map<int, std::map<std::string, size_t>>& allocs_per_chunk,
292
+ std::map<std::string, std::pair<int, size_t>>& tensor_offsets
293
+ ) {
294
+ int alloc_chunk_idx = m_fusedBuffers.size();
295
+ int num_alloc_chunks = 0;
296
+ size_t total_alloc_size = 0;
297
+
298
+ for (auto& [_, tensor_sizes] : allocs_per_chunk) {
299
+ // Calculate total allocation chunk size
300
+ size_t alloc_chunk_size = 0;
301
+ for (const auto& [tensor_name, tensor_size] : tensor_sizes) {
302
+ tensor_offsets[tensor_name] = {alloc_chunk_idx, alloc_chunk_size};
303
+ alloc_chunk_size += tensor_size;
304
+ }
305
+
306
+ // Allocate chunk for this unique context set
307
+ if (alloc_chunk_size <= 0) {
308
+ QNN_ERROR("Unexpected chunk size detected. Please re-check IO allocations");
309
+ return false;
310
+ }
311
+
312
+ m_fusedFds.push_back(0);
313
+ if (!allocateTensorFusedBuffer(alloc_chunk_size, &m_fusedFds.back())) //
314
+ return false;
315
+ total_alloc_size += alloc_chunk_size;
316
+ alloc_chunk_idx++;
317
+ num_alloc_chunks++;
318
+ }
319
+ QNN_INFO(
320
+ "Allocated total size = %lu across %d buffers",
321
+ (unsigned long)total_alloc_size,
322
+ num_alloc_chunks
323
+ );
324
+ return true;
325
+ }
326
+
327
+ bool RpcMem::mapFusedBufferOffset(
328
+ Qnn_Tensor_t* tensor,
329
+ size_t tensorDataSize,
330
+ int32_t fd,
331
+ uint32_t offset,
332
+ uint64_t totalBufferSize,
333
+ void* memPointer,
334
+ Qnn_ContextHandle_t contextHandle
335
+ ) {
336
+ if (m_libCdspRpc == nullptr) {
337
+ QNN_ERROR("RpcMem not initialized");
338
+ return false;
339
+ }
340
+ if (!tensor) {
341
+ QNN_ERROR("Received nullptr for tensor");
342
+ return false;
343
+ }
344
+
345
+ Qnn_ErrorHandle_t ret;
346
+ const char* tname = QNN_TENSOR_GET_NAME(tensor);
347
+
348
+ // Check if tensor already has a memHandle assigned
349
+ Qnn_MemHandle_t cur_mem_handle = QNN_TENSOR_GET_MEM_HANDLE(tensor);
350
+ if (cur_mem_handle != nullptr) {
351
+ // Check if memHandle is already identical to requested buffer and offset
352
+ RpcMemTensorData& cur_rpc_mem_data = m_memHandleToRpcMem.at(cur_mem_handle);
353
+ if (cur_rpc_mem_data.fd == fd && cur_rpc_mem_data.offset == offset) {
354
+ return true;
355
+ }
356
+
357
+ // updated offset, deregister previous mem_handle
358
+ if (tensorDataSize == 0) tensorDataSize = cur_rpc_mem_data.size;
359
+ // clang-format off
360
+ TRACE_MEMORY_ALLOC( "memDeRegister %-20s (fd=%d offset=%lu) memHandle=%p",
361
+ tname, cur_rpc_mem_data.fd, cur_rpc_mem_data.offset, cur_mem_handle);
362
+ // clang-format on
363
+ m_memHandleToRpcMem.erase(cur_mem_handle);
364
+ if ((ret = m_qnnInterface->memDeRegister(&cur_mem_handle, 1)) != QNN_SUCCESS) {
365
+ QNN_ERROR(
366
+ "memDeRegister ERROR(%lu) - %s memHandle=%p",
367
+ (unsigned long)ret,
368
+ tname,
369
+ cur_mem_handle
370
+ );
371
+ return false;
372
+ }
373
+ } else {
374
+ // For inital tensors, we need to check if the tensor can re-use a memHandle
375
+ // from another tensor in the same context
376
+ auto memConfig = std::make_tuple(fd, offset, contextHandle);
377
+ if (memConfigList.contains(memConfig)) {
378
+ auto& parentTensor = memConfigList[memConfig];
379
+ Qnn_MemHandle_t parentMemHandle = QNN_TENSOR_GET_MEM_HANDLE(parentTensor);
380
+ QNN_TENSOR_SET_MEM_TYPE(tensor, QNN_TENSORMEMTYPE_MEMHANDLE);
381
+ QNN_TENSOR_SET_MEM_HANDLE(tensor, parentMemHandle);
382
+ TRACE_MEMORY_ALLOC("%-20s : Mapping to memHandle %p", tname, parentMemHandle);
383
+ return true;
384
+ }
385
+ }
386
+
387
+ // Register a new memHandle based on function arguments
388
+ QnnMemHtp_Descriptor_t htp_mem_desciptor = {QNN_HTP_MEM_SHARED_BUFFER, totalBufferSize, {0}};
389
+ htp_mem_desciptor.sharedBufferConfig.fd = fd;
390
+ htp_mem_desciptor.sharedBufferConfig.offset = offset;
391
+
392
+ Qnn_MemDescriptor_t mem_descriptor = {
393
+ {QNN_TENSOR_GET_RANK(tensor), QNN_TENSOR_GET_DIMENSIONS(tensor), nullptr},
394
+ QNN_TENSOR_GET_DATA_TYPE(tensor),
395
+ QNN_MEM_TYPE_CUSTOM,
396
+ {{-1}}
397
+ };
398
+ mem_descriptor.customInfo = &htp_mem_desciptor;
399
+
400
+ Qnn_MemHandle_t mem_handle = nullptr;
401
+ ret = m_qnnInterface->memRegister(contextHandle, &mem_descriptor, 1, &mem_handle);
402
+ if (ret != QNN_SUCCESS) {
403
+ QNN_ERROR("%-20s (ctx=%p fd=%d offset=%u)", tname, contextHandle, fd, offset);
404
+ QNN_ERROR("memRegister ERROR(%lu)", (unsigned long)ret);
405
+ return false;
406
+ }
407
+
408
+ // clang-format off
409
+ TRACE_MEMORY_ALLOC("%-20s (ctx=%p fd=%d offset=%u) memPointer=%p memHandle=%p",
410
+ tname, contextHandle, fd, offset, ((uint8_t*)memPointer) + offset, mem_handle);
411
+ // clang-format on
412
+ m_memHandleToRpcMem[mem_handle] = RpcMemTensorData(
413
+ fd, ((uint8_t*)memPointer) + offset, tensorDataSize, totalBufferSize, offset
414
+ );
415
+
416
+ QNN_TENSOR_SET_MEM_TYPE(tensor, QNN_TENSORMEMTYPE_MEMHANDLE);
417
+ QNN_TENSOR_SET_MEM_HANDLE(tensor, mem_handle);
418
+ if (cur_mem_handle == nullptr) // Cache memory config for initial memRegisters only
419
+ memConfigList[std::make_tuple(fd, offset, contextHandle)] = tensor;
420
+
421
+ return true;
422
+ }
423
+
424
+ bool RpcMem::mapFusedBufferOffset(
425
+ Qnn_Tensor_t* tensor,
426
+ int alloc_idx,
427
+ size_t offset,
428
+ Qnn_ContextHandle_t ctx,
429
+ size_t size
430
+ ) {
431
+ return mapFusedBufferOffset(
432
+ tensor,
433
+ size,
434
+ m_fusedFds[alloc_idx],
435
+ offset,
436
+ m_fusedBuffers[alloc_idx].second,
437
+ m_fusedBuffers[alloc_idx].first,
438
+ ctx
439
+ );
440
+ }
441
+
442
+ bool RpcMem::deregisterTensorFusedBuffer(Qnn_Tensor_t* tensor) {
443
+ if (!tensor) {
444
+ QNN_ERROR("Received nullptr for tensor");
445
+ return false;
446
+ }
447
+
448
+ if (m_tensorToRpcMem.find(tensor) == m_tensorToRpcMem.end()) {
449
+ QNN_ERROR("Tensor not found");
450
+ return false;
451
+ }
452
+
453
+ // We are not freeing memhandles here since they are already freed when
454
+ // freeContext() gets called in the destructor of QnnApi class which
455
+ // happens before this point
456
+
457
+ // Qnn_MemHandle_t memHandle = QNN_TENSOR_GET_MEM_HANDLE(tensor);
458
+ // QNN_ERROR("Interface handle %p memhandle %p", m_qnnInterface, memHandle);
459
+ // if (QNN_SUCCESS != m_qnnInterface->memDeRegister(&memHandle, 1)) {
460
+ // QNN_ERROR("Failed to deregister ion memory with the backend");
461
+ // return false;
462
+ // }
463
+
464
+ QNN_TENSOR_SET_MEM_TYPE(tensor, QNN_TENSORMEMTYPE_UNDEFINED);
465
+ QNN_TENSOR_SET_MEM_HANDLE(tensor, nullptr);
466
+ m_tensorToRpcMem.erase(tensor);
467
+ return true;
468
+ }
469
+
470
+ void RpcMem::freeFusedBuffers() {
471
+ // for (auto& memHandle : m_orphanedMemHandles) {
472
+ // if (QNN_SUCCESS != m_qnnInterface->memDeRegister(&memHandle, 1)) {
473
+ // QNN_ERROR("Failed to deregister ion memory with the backend");
474
+ // }
475
+ // }
476
+
477
+ for (auto& [mem_ptr, buffer_size] : m_fusedBuffers) {
478
+ QNN_DEBUG("Freeing fused buffer %p (size=%lu)", mem_ptr, buffer_size);
479
+ m_rpcMemFree(mem_ptr);
480
+ }
481
+ }
Genie/Genie/src/qualla/engines/qnn-api/RpcMem.hpp ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ //==============================================================================
2
+ //
3
+ // Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4
+ // All Rights Reserved.
5
+ // Confidential and Proprietary - Qualcomm Technologies, Inc.
6
+ //
7
+ //==============================================================================
8
+
9
+ #pragma once
10
+
11
+ #include <unordered_set>
12
+
13
+ #include "IBufferAlloc.hpp"
14
+ #include "QnnInterface.h"
15
+ #include "Log.hpp"
16
+
17
+ typedef void* (*RpcMemAllocFn_t)(int, uint32_t, int);
18
+ typedef void (*RpcMemFreeFn_t)(void*);
19
+ typedef int (*RpcMemToFdFn_t)(void*);
20
+
21
// Bookkeeping record for one RPC-memory-backed tensor: the rpcmem file
// descriptor, the host-visible base pointer, the tensor's own size, and (for
// fused buffers) the total buffer size plus the tensor's byte offset in it.
struct RpcMemTensorData {
  int fd{-1};
  void* memPointer{nullptr};
  size_t size{0};
  // In-class initializers: the default and 3-arg constructors previously left
  // totalBufferSize and offset uninitialized (read later as garbage).
  size_t totalBufferSize{0};
  size_t offset{0};
  RpcMemTensorData() = default;
  // Dedicated (non-fused) allocation: total size equals the tensor size.
  RpcMemTensorData(int fdIn, void* memPointerIn, size_t sizeIn)
      : fd(fdIn), memPointer(memPointerIn), size(sizeIn) {}
  // Fused allocation: tensor occupies [offset, offset + size) of a larger buffer.
  RpcMemTensorData(
      int fdIn,
      void* memPointerIn,
      size_t sizeIn,
      size_t totalBufferSizeIn,
      size_t offsetIn
  )
      : fd(fdIn), memPointer(memPointerIn), size(sizeIn), totalBufferSize(totalBufferSizeIn),
        offset(offsetIn) {}
};
40
+
41
// Allocator that backs QNN tensors with RPC shared memory obtained through
// libcdsprpc.so; implements the IBufferAlloc interface. Tracks per-tensor
// allocations, large "fused" multi-tensor buffers, and their registration
// with the QNN backend. Non-copyable / non-movable: it owns the dlopen
// handle and the raw buffers.
// NOTE(review): this header uses std::map, std::vector, std::unordered_map
// and std::tuple but includes none of them directly — it relies on
// transitive includes; consider adding them explicitly.
class RpcMem final : public IBufferAlloc {
 public:
  RpcMem(Qnn_ContextHandle_t contextHandle, QNN_INTERFACE_VER_TYPE* qnnInterface);
  // Disable copy constructors, r-value referencing, etc
  RpcMem(const RpcMem&) = delete;
  RpcMem& operator=(const RpcMem&) = delete;
  RpcMem(RpcMem&&) = delete;
  RpcMem& operator=(RpcMem&&) = delete;
  // Presumably dlopens libcdsprpc.so and resolves the rpcmem_* entry points
  // (see the m_libCdspRpc / m_rpcMem* members below) — confirm in RpcMem.cpp.
  bool initialize() override;
  // Accessors for the RpcMemTensorData recorded for `tensor`.
  void* getBuffer(Qnn_Tensor_t* tensor) override;
  int getFd(Qnn_Tensor_t* tensor) override;

  // Byte offset of `tensor` inside its (possibly fused) buffer.
  size_t getOffset(Qnn_Tensor_t* tensor) override;

  // Size in bytes of the tensor's own region.
  size_t getBufferSize(Qnn_Tensor_t* tensor) override;

  // Size in bytes of the whole underlying buffer (relevant for fused buffers).
  size_t getTotalBufferSize(Qnn_Tensor_t* tensor) override;

  // Allocates a dedicated RPC buffer of `tensorDataSize` bytes for `tensor`.
  bool allocateTensorBuffer(Qnn_Tensor_t* tensor, size_t tensorDataSize) override;

  bool freeTensorBuffer(Qnn_Tensor_t* tensor) override;
  // Make `dest` share `src`'s backing memory (optionally at a byte offset).
  bool useSameMemory(Qnn_Tensor_t* dest, Qnn_Tensor_t* src) override;
  bool useSameMemory(Qnn_Tensor_t* dest, Qnn_Tensor_t* src, int offset) override;

  // Back `dest` with caller-provided memory; ownership stays with the caller.
  bool useExternalMemory(Qnn_Tensor_t* dest, void* extMem) override;

  // Allocates one large "fused" buffer meant to be shared by many tensors;
  // reports the rpcmem file descriptor through `fd`.
  void* allocateTensorFusedBuffer(uint64_t bufferSize, int32_t* fd) override;
  // Allocates one fused chunk per entry of `allocs_per_chunk` and reports each
  // tensor's (chunk index, offset) placement through `tensor_offsets`.
  bool allocateBuffers(
      const std::map<int, std::map<std::string, size_t>>& allocs_per_chunk,
      std::map<std::string, std::pair<int, size_t>>& tensor_offsets
  ) override;

  // Registers `tensor` at byte `offset` within an existing fused buffer
  // identified by (fd, memPointer) under the given context.
  bool mapFusedBufferOffset(
      Qnn_Tensor_t* tensor,
      size_t tensorDataSize,
      int32_t fd,
      uint32_t offset,
      uint64_t totalBufferSize,
      void* memPointer,
      Qnn_ContextHandle_t contextHandle
  ) override;
  bool deregisterTensorFusedBuffer(Qnn_Tensor_t* tensor) override;
  void freeFusedBuffers() override;
  // Convenience overload: selects the fused chunk by index into
  // m_fusedFds / m_fusedBuffers.
  bool mapFusedBufferOffset(
      Qnn_Tensor_t* tensor,
      int alloc_idx,
      size_t offset,
      Qnn_ContextHandle_t ctx,
      size_t size
  ) override;
  virtual ~RpcMem();

 private:
  RpcMemTensorData* getRpcMemTensorData(Qnn_Tensor_t* tensor);

  // Pointer to the dlopen'd libcdsprpc.so shared library which contains
  // rpcmem_alloc, rpcmem_free, rpcmem_to_fd APIs
  void* m_libCdspRpc;
  // Function pointer to rpcmem_alloc
  RpcMemAllocFn_t m_rpcMemAlloc;
  // Function pointer to rpcmem_free
  RpcMemFreeFn_t m_rpcMemFree;
  // Function pointer to rpcmem_to_fd
  RpcMemToFdFn_t m_rpcMemToFd;
  QNN_INTERFACE_VER_TYPE* m_qnnInterface;
  Qnn_ContextHandle_t m_contextHandle;

  // Bookkeeping for every tensor registered through this allocator.
  std::unordered_map<Qnn_Tensor_t*, RpcMemTensorData> m_tensorToRpcMem;
  // Presumably tensors that alias another tensor's memory via useSameMemory()
  // and therefore must not be freed directly — confirm in RpcMem.cpp.
  std::unordered_set<Qnn_Tensor_t*> m_sameMemoryFreeTensors;
  std::vector<std::pair<void*, size_t>> m_fusedBuffers; // vector<<memPointer, bufferSize>>
  std::vector<int32_t> m_fusedFds;
  // Mem handles whose owning tensor went away before deregistration.
  std::unordered_set<Qnn_MemHandle_t> m_orphanedMemHandles;
  std::unordered_map<Qnn_MemHandle_t, RpcMemTensorData> m_memHandleToRpcMem;
  // (fd, offset, context) -> tensor already registered for that mapping
  // (populated by mapFusedBufferOffset in RpcMem.cpp).
  std::map<std::tuple<int, size_t, Qnn_ContextHandle_t>, Qnn_Tensor_t*> memConfigList;
};
Genie/Genie/src/qualla/engines/qnn-api/dlwrap.cpp ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ //==============================================================================
2
+ //
3
+ // Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4
+ // All Rights Reserved.
5
+ // Confidential and Proprietary - Qualcomm Technologies, Inc.
6
+ //
7
+ //==============================================================================
8
+
9
+ #ifdef _WIN32
10
+
11
+ #pragma warning(disable : 4133 4996)
12
+
13
+ #include <inttypes.h>
14
+ #include <stdio.h>
15
+ #include <stdlib.h>
16
+ #include <string.h>
17
+ #include <windows.h>
18
+ #include <wchar.h>
19
+
20
+ #include "dlwrap.hpp"
21
+
22
+ static const char* last_func;
23
+ static long last_err;
24
+
25
+ void* dlopen(const char* dll, int flags) {
26
+ HINSTANCE h = LoadLibraryA(dll);
27
+ if (h == NULL) {
28
+ last_err = GetLastError();
29
+ last_func = "dlopen";
30
+ }
31
+
32
+ return h;
33
+ }
34
+
35
+ int dlclose(void* h) {
36
+ if (!FreeLibrary((HINSTANCE)h)) {
37
+ last_err = GetLastError();
38
+ last_func = "dlclose";
39
+ return -1;
40
+ }
41
+
42
+ return 0;
43
+ }
44
+
45
+ void* dlsym(void* h, const char* name) {
46
+ FARPROC p = GetProcAddress((HINSTANCE)h, name);
47
+ if (!p) {
48
+ last_err = GetLastError();
49
+ last_func = "dlsym";
50
+ }
51
+ return (void*)(intptr_t)p;
52
+ }
53
+
54
+ const char* dlerror(void) {
55
+ static char str[88];
56
+
57
+ if (!last_err) return NULL;
58
+
59
+ sprintf(str, "%s error #%ld", last_func, last_err);
60
+ last_err = 0;
61
+ last_func = NULL;
62
+
63
+ return str;
64
+ }
65
+
66
+ #endif // _WIN32
Genie/Genie/src/qualla/engines/qnn-api/dlwrap.hpp ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ //==============================================================================
2
+ //
3
+ // Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4
+ // All Rights Reserved.
5
+ // Confidential and Proprietary - Qualcomm Technologies, Inc.
6
+ //
7
+ //==============================================================================
8
+
9
+ #ifndef DLWRAP_HPP
10
+ #define DLWRAP_HPP
11
+
12
+ #ifndef _WIN32
13
+
14
+ // Just include regular dlfcn
15
+ #include <dlfcn.h>
16
+
17
+ #else // _WIN32
18
+
19
+ // Define basic set dl functions and flags
20
+
21
+ #define RTLD_GLOBAL 0x100
22
+ #define RTLD_LOCAL 0x000
23
+ #define RTLD_LAZY 0x000
24
+ #define RTLD_NOW 0x001
25
+
26
+ void* dlopen(const char* filename, int flag);
27
+ int dlclose(void* handle);
28
+ void* dlsym(void* handle, const char* name);
29
+ const char* dlerror(void);
30
+
31
+ #endif // _WIN32
32
+
33
+ #endif // DLWRAP_HPP
Genie/Genie/src/qualla/engines/qnn-api/qnn-utils.cpp ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ //==============================================================================
2
+ //
3
+ // Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4
+ // All Rights Reserved.
5
+ // Confidential and Proprietary - Qualcomm Technologies, Inc.
6
+ //
7
+ //==============================================================================
8
+
9
+ #include "qnn-utils.hpp"
10
+
11
+ #include <string>
12
+ #include <fstream>
13
+ #include <filesystem>
14
+ #include <sstream>
15
+ #include "QnnApi.hpp"
16
+ #include <fmt/format.h>
17
+
18
+ namespace fs = std::filesystem;
19
+
20
+ namespace qualla {
21
+ namespace QnnUtils {
22
+ // Alternate implementation for bw() = lambda x: (10 * ((x & 0xf0)>>4) + (x & 0xf)) // 8
23
+ int DataType::bw() { return (_dtype == QNN_DATATYPE_UNDEFINED) ? -1 : QnnApi::getDataTypeSize(_dtype);}
24
+ int DataType::type() {return (_dtype == QNN_DATATYPE_UNDEFINED) ? -1 : _dtype >> 4; }
25
+
26
+ int32_t DataType::val() { return static_cast<int32_t>(_dtype); }
27
+
28
// Writes `size` bytes from `data` to `path`, creating any missing parent
// directories first. Returns true on success, false when the directories
// cannot be created or the file cannot be opened/fully written.
bool writeRawData(void* data, size_t size, const std::filesystem::path& path) {
  const auto parent = path.parent_path();
  // Only attempt creation when a parent component exists: create_directories("")
  // fails, which previously made this function reject bare filenames in the
  // current working directory.
  if (!parent.empty() && !std::filesystem::exists(parent) &&
      !std::filesystem::create_directories(parent)) {
    return false;
  }

  std::ofstream f(path, std::ofstream::binary);
  if (!f.is_open()) return false;
  f.write(static_cast<const char*>(data), static_cast<std::streamsize>(size));
  // Report write failures instead of the original unconditional `true`.
  return f.good();
}
38
+
39
// Reads exactly `size` bytes from `path` into `data`.
// Throws std::runtime_error when the on-disk size differs from `size`;
// returns false when the file cannot be opened or fully read.
bool readRawData(void* data, size_t size, const std::filesystem::path& path) {
  // Stat once and reuse (the original called fs::file_size twice).
  const auto actual = std::filesystem::file_size(path);
  if (actual != size) {
    // Message built with the standard library; also fixes the "doesnot" typo.
    throw std::runtime_error(
        "file size does not match: " + path.string() + " size " +
        std::to_string(actual) + ", buf-size " + std::to_string(size)
    );
  }

  std::ifstream f(path, std::ifstream::binary);
  if (!f.is_open()) return false;
  f.read(static_cast<char*>(data), static_cast<std::streamsize>(size));
  // Report short/failed reads instead of the original unconditional `true`.
  return f.good();
}
55
+
56
+ void getQuantParamString(
57
+ const std::vector<QuantParam>& quantParam,
58
+ std::string& scale_string,
59
+ std::string& offset_string
60
+ ) {
61
+ std::ostringstream scales_s;
62
+ std::ostringstream offsets_s;
63
+ for (int i = 0; i < quantParam.size(); i++) {
64
+ if (i != 0) {
65
+ scales_s << ", ";
66
+ offsets_s << ", ";
67
+ }
68
+ scales_s << std::fixed << std::setprecision(20) << quantParam[i].scale;
69
+ offsets_s << quantParam[i].offset;
70
+ }
71
+ scale_string = std::move(scales_s.str());
72
+ offset_string = std::move(offsets_s.str());
73
+ }
74
+
75
// Returns the canonical QNN enum name for the wrapped datatype; any value not
// listed (including QNN_DATATYPE_UNDEFINED itself) maps to
// "QNN_DATATYPE_UNDEFINED".
const char* DataType::str() {
  // clang-format off
  switch (_dtype) {
    case QNN_DATATYPE_INT_8: return "QNN_DATATYPE_INT_8";
    case QNN_DATATYPE_INT_16: return "QNN_DATATYPE_INT_16";
    case QNN_DATATYPE_INT_32: return "QNN_DATATYPE_INT_32";
    case QNN_DATATYPE_INT_64: return "QNN_DATATYPE_INT_64";
    case QNN_DATATYPE_UINT_8: return "QNN_DATATYPE_UINT_8";
    case QNN_DATATYPE_UINT_16: return "QNN_DATATYPE_UINT_16";
    case QNN_DATATYPE_UINT_32: return "QNN_DATATYPE_UINT_32";
    case QNN_DATATYPE_UINT_64: return "QNN_DATATYPE_UINT_64";
    case QNN_DATATYPE_FLOAT_16: return "QNN_DATATYPE_FLOAT_16";
    case QNN_DATATYPE_FLOAT_32: return "QNN_DATATYPE_FLOAT_32";
    case QNN_DATATYPE_FLOAT_64: return "QNN_DATATYPE_FLOAT_64";
    case QNN_DATATYPE_SFIXED_POINT_4: return "QNN_DATATYPE_SFIXED_POINT_4";
    case QNN_DATATYPE_SFIXED_POINT_8: return "QNN_DATATYPE_SFIXED_POINT_8";
    case QNN_DATATYPE_SFIXED_POINT_16: return "QNN_DATATYPE_SFIXED_POINT_16";
    case QNN_DATATYPE_SFIXED_POINT_32: return "QNN_DATATYPE_SFIXED_POINT_32";
    case QNN_DATATYPE_UFIXED_POINT_4: return "QNN_DATATYPE_UFIXED_POINT_4";
    case QNN_DATATYPE_UFIXED_POINT_8: return "QNN_DATATYPE_UFIXED_POINT_8";
    case QNN_DATATYPE_UFIXED_POINT_16: return "QNN_DATATYPE_UFIXED_POINT_16";
    case QNN_DATATYPE_UFIXED_POINT_32: return "QNN_DATATYPE_UFIXED_POINT_32";
    case QNN_DATATYPE_BOOL_8: return "QNN_DATATYPE_BOOL_8";
    case QNN_DATATYPE_STRING: return "QNN_DATATYPE_STRING";
    default: return "QNN_DATATYPE_UNDEFINED";
  }
  // clang-format on
}
103
+ } // namespace QnnUtils
104
+ } // namespace qualla
Genie/Genie/src/qualla/engines/qnn-api/qnn-utils.hpp ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ //==============================================================================
2
+ //
3
+ // Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4
+ // All Rights Reserved.
5
+ // Confidential and Proprietary - Qualcomm Technologies, Inc.
6
+ //
7
+ //==============================================================================
8
+
9
+ #pragma once
10
+
11
+ #ifdef _MSC_VER
12
+ #pragma warning(disable : 4068)
13
+ #endif
14
+
15
+ #include <string>
16
+ #include <vector>
17
+ #include <algorithm>
18
+ #include <filesystem>
19
+ #include "QnnApiUtils.hpp"
20
+ #include "QnnInterface.h"
21
+
22
+ namespace qualla {
23
+
24
+ namespace QnnUtils {
25
// Thin wrapper around Qnn_DataType_t adding convenience queries (byte width,
// type family, printable name). Implicitly convertible back to
// Qnn_DataType_t so it works directly in switch statements and comparisons.
class DataType {
 private:
  Qnn_DataType_t _dtype{QNN_DATATYPE_UNDEFINED};

 public:
  DataType() = default;
  // Capture the datatype of an existing tensor.
  DataType(const Qnn_Tensor_t* tensor) : _dtype(QNN_TENSOR_GET_DATA_TYPE(tensor)) {}
  DataType(Qnn_DataType_t dtype) : _dtype(dtype) {};

  // Enable switch and comparisons
  constexpr operator Qnn_DataType_t() const { return _dtype; }

  // Byte width of one element, or -1 when undefined (defined in qnn-utils.cpp).
  int bw();
  // High nibble of the datatype encoding, or -1 when undefined.
  int type();

  // Raw enum value as int32_t.
  int32_t val();

  // Canonical enum name, e.g. "QNN_DATATYPE_UINT_8".
  const char* str();
};
44
+
45
+ bool writeRawData(void* tensorData, size_t tensorSize, const std::filesystem::path& path);
46
+ bool readRawData(void* tensorData, size_t tensorSize, const std::filesystem::path& path);
47
+
48
// Tensor geometry helper: batch x height x width x channel plus the element
// byte width; size helpers below derive counts/bytes from these fields.
struct Dims {
  int32_t batch = 1;
  int32_t height, width, channel, bitWidth;
  Dims() : height(0), width(0), channel(0), bitWidth(0) {}
  Dims(int32_t height, int32_t width, int32_t channel, int32_t bitWidth)
      : height(height), width(width), channel(channel), bitWidth(bitWidth) {}
  // NOTE(review): assumes tDims has at least 5 entries, presumably
  // [batch, height, width, channel, bitWidth] — confirm at call sites.
  Dims(std::vector<size_t>& tDims)
      : height((int32_t)tDims[1]), width((int32_t)tDims[2]), channel((int32_t)tDims[3]),
        bitWidth((int32_t)tDims[4]) {
    // Hack to mix batch dimension
    if (tDims[0] != 1 && tDims[1] == 1) height = tDims[0];
    if (tDims[0] > 1 && tDims[1] != 1) batch = tDims[0];
  }
  // Equality deliberately ignores batch: only per-sample geometry is compared.
  bool operator==(const Dims& rhs) const {
    return (height == rhs.height) && (width == rhs.width) && (channel == rhs.channel) &&
           (bitWidth == rhs.bitWidth);
  }
  bool operator!=(const Dims& rhs) const { return !(operator==(rhs)); }
  // Element count for one sample (batch and bitWidth excluded).
  size_t getNumElements() const { return (size_t)(height * width * channel); }
  // Total byte size across the whole batch.
  size_t getSize() const { return (size_t)(batch * height * width * channel * bitWidth); }
  // getSize() rounded up to the next multiple of 8 bytes.
  size_t getAlignedSize() const {
    size_t size = getSize();
    if ((size & uint64_t{7}) != uint64_t{0}) {
      size += (uint64_t{8} - (size & uint64_t{7}));
    }
    return size;
  }
  int32_t getMaxDim() const { return std::max({height, width, channel}); };
  // Transposed copy (height/width swapped); batch resets to the default 1.
  Dims T() const { return Dims(width, height, channel, bitWidth); }
};
78
+
79
// One quantization parameter pair; used by the quantize helpers below as
// q = value / scale - offset.
struct QuantParam {
  // In-class initializers: the original default constructor left both
  // members uninitialized (read as garbage).
  double scale{0.0};
  int32_t offset{0};
  QuantParam() = default;
  QuantParam(double scale_val, int32_t offset_val) : scale(scale_val), offset(offset_val) {}
};
85
+
86
// Bundles a QNN tensor pointer with its geometry, quantization parameters and
// datatype. Non-owning: `tensor` must outlive this struct.
struct Tensor {
  Qnn_Tensor_t* tensor = nullptr;
  Dims dims;
  std::vector<QuantParam> quantParam;
  DataType dtype;
  Tensor() {}
  // dtype is derived from the tensor itself, not passed in.
  Tensor(Qnn_Tensor_t* tensorVal, Dims dimsVal, std::vector<QuantParam> quantParamVec)
      : tensor(tensorVal), dims(dimsVal), quantParam(quantParamVec),
        dtype(QNN_TENSOR_GET_DATA_TYPE(tensorVal)) {}
};
96
+
97
// Maps tensor name -> QnnUtils::Tensor (Qnn_Tensor_t*, dims, quant params, dtype)
typedef std::map<std::string, Tensor> TensorMap;
99
+
100
// Narrows an unsigned 16-bit value to 8 bits with round-to-nearest
// (add half an LSB, then take the high byte), saturating at 0xFF when the
// rounding bias would wrap past 0xFFFF.
static inline uint8_t sat_round(const uint16_t x) {
  uint16_t biased = static_cast<uint16_t>(x + 0x80);  // + 0.5 in fixed point
  if (biased < x) biased = x;                         // wrapped -> clamp high
  return static_cast<uint8_t>(biased >> 8);           // keep high byte (/256)
}

// Element-wise sat_round over `nmemb` values from `src` into `dest`.
static inline void downcast_u16_to_u8(uint8_t* dest, const uint16_t* src, size_t nmemb) {
  for (size_t idx = 0; idx < nmemb; ++idx) {
    dest[idx] = sat_round(src[idx]);
  }
}
111
+
112
// Uniform quantization of `nmemb` elements: q[i] = f[i] / scale - offset,
// truncated into IntType. One (scale, offset) pair for the whole tensor.
template <typename FloatType, typename IntType>
static inline void quantizeTensorPtr(
    FloatType* tensor_float,
    IntType* tensor_quant,
    int32_t offset,
    double scale,
    size_t nmemb
) {
#pragma clang loop vectorize(enable) interleave(enable)
  for (size_t idx = 0; idx < nmemb; ++idx) {
    const double as_double = tensor_float[idx];
    tensor_quant[idx] = static_cast<IntType>(as_double / scale - offset);
  }
}
126
+
127
+ template <typename FloatType, typename IntType>
128
+ static inline void perWidthQuantizeTensorPtr(
129
+ FloatType* tensor_float,
130
+ IntType* tensor_quant,
131
+ std::vector<QnnUtils::QuantParam>& quantParam,
132
+ int32_t height,
133
+ int32_t width,
134
+ int32_t channel
135
+ ) {
136
+ for (size_t h = 0; h < height; h++) {
137
+ for (size_t w = 0; w < width; w++) {
138
+ double scale = quantParam[w].scale;
139
+ int32_t offset = quantParam[w].offset;
140
+ #pragma clang loop vectorize(enable) interleave(enable)
141
+ for (size_t c = 0; c < channel; c++) {
142
+ int32_t i = (h * width * channel) + (w * channel) + c;
143
+ double val = tensor_float[i];
144
+ tensor_quant[i] = static_cast<IntType>(val / scale - offset);
145
+ }
146
+ }
147
+ }
148
+ }
149
+
150
// Renders `quantParam` into two comma-separated strings: scales (fixed
// notation, 20-digit precision) and integer offsets. Defined in qnn-utils.cpp.
void getQuantParamString(
    const std::vector<QuantParam>& quantParam,
    std::string& scale_string,
    std::string& offset_string
);
155
+
156
+ } // namespace QnnUtils
157
+ } // namespace qualla