add genie2.29
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- Genie/Genie/GenieSymbols.default +5 -0
- Genie/Genie/make/Android.mk +2 -0
- Genie/Genie/make/Application.mk +1 -1
- Genie/Genie/make/Makefile.linux-x86_64 +14 -5
- Genie/Genie/src/Dialog.cpp +54 -80
- Genie/Genie/src/Dialog.hpp +6 -1
- Genie/Genie/src/Embedding.cpp +740 -0
- Genie/Genie/src/Embedding.hpp +56 -0
- Genie/Genie/src/Exception.hpp +1 -0
- Genie/Genie/src/GenieDialog.cpp +19 -1
- Genie/Genie/src/GenieEmbedding.cpp +118 -0
- Genie/Genie/src/GenieSampler.cpp +93 -0
- Genie/Genie/src/Macro.hpp +2 -0
- Genie/Genie/src/Sampler.cpp +275 -0
- Genie/Genie/src/Sampler.hpp +60 -0
- Genie/Genie/src/qualla/context.cpp +8 -0
- Genie/Genie/src/qualla/dialogs/ssd-q1.cpp +2 -2
- Genie/Genie/src/qualla/engine.cpp +1 -1
- Genie/Genie/src/qualla/engines/qnn-api/DmaBufAllocator.cpp +317 -0
- Genie/Genie/src/qualla/engines/qnn-api/DmaBufAllocator.hpp +128 -0
- Genie/Genie/src/qualla/engines/qnn-api/IBufferAlloc.hpp +15 -1
- Genie/Genie/src/qualla/engines/qnn-api/IOTensor.cpp +76 -1
- Genie/Genie/src/qualla/engines/qnn-api/IOTensor.hpp +25 -0
- Genie/Genie/src/qualla/engines/qnn-api/QnnApi.cpp +369 -90
- Genie/Genie/src/qualla/engines/qnn-api/QnnApi.hpp +9 -0
- Genie/Genie/src/qualla/engines/qnn-api/qnn-utils.hpp +5 -5
- Genie/Genie/src/qualla/engines/qnn-cpu.cpp +55 -3
- Genie/Genie/src/qualla/engines/qnn-cpu/cpu-model.cpp +51 -0
- Genie/Genie/src/qualla/engines/qnn-cpu/cpu-model.hpp +25 -0
- Genie/Genie/src/qualla/engines/qnn-gpu.cpp +193 -0
- Genie/Genie/src/qualla/engines/qnn-gpu/gpu-model.cpp +603 -0
- Genie/Genie/src/qualla/engines/qnn-gpu/gpu-model.hpp +136 -0
- Genie/Genie/src/qualla/engines/qnn-htp.cpp +2 -2
- Genie/Genie/src/qualla/engines/qnn-htp.hpp +1 -1
- Genie/Genie/src/qualla/engines/qnn-htp/nsp-kvmanager.cpp +9 -3
- Genie/Genie/src/qualla/engines/qnn-htp/nsp-kvmanager.hpp +2 -1
- Genie/Genie/src/qualla/engines/qnn-htp/nsp-model.cpp +8 -4
- Genie/Genie/src/qualla/engines/qnn-htp/nsp-model.hpp +4 -6
- Genie/Genie/src/qualla/include/qualla/detail/basic-sampler.hpp +1 -0
- Genie/Genie/src/qualla/include/qualla/dialog.hpp +1 -0
- Genie/Genie/src/qualla/include/qualla/engine.hpp +1 -1
- Genie/Genie/src/qualla/include/qualla/sampler.hpp +1 -0
- Genie/Genie/src/qualla/sampler.cpp +4 -0
- Genie/Genie/src/qualla/samplers/basic.cpp +8 -0
- Genie/Genie/src/qualla/tokenizers/rust/Cargo.lock +26 -26
- Genie/Model/model.cpp +23 -2
- Genie/configs/llama2-7b/llama2-7b-draft-htp-target-htp-spd.json +2 -1
- Genie/configs/llama2-7b/llama2-7b-genaitransformer-lora.json +62 -0
- Genie/configs/llama2-7b/llama2-7b-genaitransformer.json +4 -1
- Genie/configs/llama2-7b/llama2-7b-gpu.json +43 -0
Genie/Genie/GenieSymbols.default
CHANGED
|
@@ -14,6 +14,11 @@
|
|
| 14 |
GenieDialogConfig_free*;
|
| 15 |
GenieDialog_create*;
|
| 16 |
GenieDialog_query*;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
GenieDialog_tokenQuery*;
|
| 18 |
GenieDialog_embeddingQuery*;
|
| 19 |
GenieDialog_save*;
|
|
|
|
| 14 |
GenieDialogConfig_free*;
|
| 15 |
GenieDialog_create*;
|
| 16 |
GenieDialog_query*;
|
| 17 |
+
GenieDialog_getSampler*;
|
| 18 |
+
GenieSampler_applyConfig*;
|
| 19 |
+
GenieSamplerConfig_createFromJson*;
|
| 20 |
+
GenieSamplerConfig_setParam*;
|
| 21 |
+
GenieSamplerConfig_free*;
|
| 22 |
GenieDialog_tokenQuery*;
|
| 23 |
GenieDialog_embeddingQuery*;
|
| 24 |
GenieDialog_save*;
|
Genie/Genie/make/Android.mk
CHANGED
|
@@ -29,6 +29,7 @@ PACKAGE_C_INCLUDES += -I $(LOCAL_PATH)/../../../../include/QNN/HTP
|
|
| 29 |
PACKAGE_C_INCLUDES += -I $(LOCAL_PATH)/../src/qualla/tokenizers
|
| 30 |
PACKAGE_C_INCLUDES += -I $(LOCAL_PATH)/../src/qualla/engines/qnn-api
|
| 31 |
PACKAGE_C_INCLUDES += -I $(LOCAL_PATH)/../src/qualla/engines/qnn-cpu
|
|
|
|
| 32 |
PACKAGE_C_INCLUDES += -I $(LOCAL_PATH)/../src/qualla/engines/qnn-htp
|
| 33 |
|
| 34 |
#========================== Define T2T Lib variables =============================================
|
|
@@ -45,6 +46,7 @@ MY_SRC_FILES += $(wildcard $(LOCAL_PATH)/../src/qualla/dialogs
|
|
| 45 |
MY_SRC_FILES += $(wildcard $(LOCAL_PATH)/../src/qualla/engines/*.cpp)
|
| 46 |
MY_SRC_FILES += $(wildcard $(LOCAL_PATH)/../src/qualla/engines/qnn-api/*.cpp)
|
| 47 |
MY_SRC_FILES += $(wildcard $(LOCAL_PATH)/../src/qualla/engines/qnn-cpu/*.cpp)
|
|
|
|
| 48 |
MY_SRC_FILES += $(wildcard $(LOCAL_PATH)/../src/qualla/engines/qnn-htp/*.cpp)
|
| 49 |
MY_SRC_FILES += $(wildcard $(LOCAL_PATH)/../src/qualla/utils/*.cpp)
|
| 50 |
MY_SRC_FILES += $(wildcard $(LOCAL_PATH)/../src/qualla/loggers/*.cpp)
|
|
|
|
| 29 |
PACKAGE_C_INCLUDES += -I $(LOCAL_PATH)/../src/qualla/tokenizers
|
| 30 |
PACKAGE_C_INCLUDES += -I $(LOCAL_PATH)/../src/qualla/engines/qnn-api
|
| 31 |
PACKAGE_C_INCLUDES += -I $(LOCAL_PATH)/../src/qualla/engines/qnn-cpu
|
| 32 |
+
PACKAGE_C_INCLUDES += -I $(LOCAL_PATH)/../src/qualla/engines/qnn-gpu
|
| 33 |
PACKAGE_C_INCLUDES += -I $(LOCAL_PATH)/../src/qualla/engines/qnn-htp
|
| 34 |
|
| 35 |
#========================== Define T2T Lib variables =============================================
|
|
|
|
| 46 |
MY_SRC_FILES += $(wildcard $(LOCAL_PATH)/../src/qualla/engines/*.cpp)
|
| 47 |
MY_SRC_FILES += $(wildcard $(LOCAL_PATH)/../src/qualla/engines/qnn-api/*.cpp)
|
| 48 |
MY_SRC_FILES += $(wildcard $(LOCAL_PATH)/../src/qualla/engines/qnn-cpu/*.cpp)
|
| 49 |
+
MY_SRC_FILES += $(wildcard $(LOCAL_PATH)/../src/qualla/engines/qnn-gpu/*.cpp)
|
| 50 |
MY_SRC_FILES += $(wildcard $(LOCAL_PATH)/../src/qualla/engines/qnn-htp/*.cpp)
|
| 51 |
MY_SRC_FILES += $(wildcard $(LOCAL_PATH)/../src/qualla/utils/*.cpp)
|
| 52 |
MY_SRC_FILES += $(wildcard $(LOCAL_PATH)/../src/qualla/loggers/*.cpp)
|
Genie/Genie/make/Application.mk
CHANGED
|
@@ -10,5 +10,5 @@ APP_ABI := arm64-v8a
|
|
| 10 |
APP_STL := c++_shared
|
| 11 |
APP_PLATFORM := android-21
|
| 12 |
APP_MODULES := Genie
|
| 13 |
-
APP_CPPFLAGS += -std=c++2a -O3 -Wall -frtti -fexceptions -fvisibility=hidden -DGENIE_API="__attribute__((visibility(\"default\")))" -DSPILLFILL -DQUALLA_ENGINE_QNN_HTP=TRUE -DQUALLA_ENGINE_QNN_CPU=TRUE -DQUALLA_APPS=OFF -DFMT_HEADER_ONLY -DGENIE_SAMPLE -DQUALLA_INTERNAL_QNN_SDK -DGENIE_SSD_FEATURE -DGENIE_SPD_FEATURE -DGENIE_LADE_FEATURE -DGENIE_MULTISTREAM_FEATURE -DGENIE_LORA_FEATURE -DGENIE_E2T_FEATURE
|
| 14 |
APP_LDFLAGS += -lc -lm -ldl -Wl,--version-script=GenieSymbols.default -Wl,--strip-all
|
|
|
|
| 10 |
APP_STL := c++_shared
|
| 11 |
APP_PLATFORM := android-21
|
| 12 |
APP_MODULES := Genie
|
| 13 |
+
APP_CPPFLAGS += -std=c++2a -O3 -Wall -frtti -fexceptions -fvisibility=hidden -DGENIE_API="__attribute__((visibility(\"default\")))" -DSPILLFILL -DQUALLA_ENGINE_QNN_HTP=TRUE -DQUALLA_ENGINE_QNN_CPU=TRUE -DQUALLA_ENGINE_QNN_GPU=TRUE -DQUALLA_APPS=OFF -DFMT_HEADER_ONLY -DGENIE_SAMPLE -DQUALLA_INTERNAL_QNN_SDK -DGENIE_SSD_FEATURE -DGENIE_SPD_FEATURE -DGENIE_LADE_FEATURE -DGENIE_MULTISTREAM_FEATURE -DGENIE_LORA_FEATURE -DGENIE_E2T_FEATURE
|
| 14 |
APP_LDFLAGS += -lc -lm -ldl -Wl,--version-script=GenieSymbols.default -Wl,--strip-all
|
Genie/Genie/make/Makefile.linux-x86_64
CHANGED
|
@@ -17,6 +17,7 @@ SRC_DIR_SAMPLE_DIALOGS := src/qualla/dialogs
|
|
| 17 |
SRC_DIR_GENIE_ENGINES := src/qualla/engines
|
| 18 |
SRC_DIR_GENIE_QNN_API := src/qualla/engines/qnn-api
|
| 19 |
SRC_DIR_GENIE_ENGINES_CPU := src/qualla/engines/qnn-cpu
|
|
|
|
| 20 |
SRC_DIR_GENIE_UTILS := src/qualla/utils
|
| 21 |
#
|
| 22 |
SRC_DIR_GENIE_LOGGERS := src/qualla/loggers
|
|
@@ -29,6 +30,7 @@ SRC_DIR_GENIE := src
|
|
| 29 |
|
| 30 |
# Includes
|
| 31 |
GENIE_ENGINES_CPU_INCLUDE := src/qualla/engines/qnn-cpu
|
|
|
|
| 32 |
GENIE_ENGINES_API_INCLUDE := src/qualla/engines/qnn-api
|
| 33 |
GENIE_ENGINES_HTP_INCLUDE := src/qualla/engines/qnn-htp
|
| 34 |
GENIE_TOKENIZER_INCLUDE := src/qualla/tokenizers
|
|
@@ -62,7 +64,7 @@ endif
|
|
| 62 |
GENIE_all: $(libGenie)
|
| 63 |
|
| 64 |
# Include paths
|
| 65 |
-
INCLUDES += -I$(GENIE_INCLUDE) -I$(QUALLA_INCLUDE) -I$(SRC_DIR_GENIE_TOKENIZERS) -I$(QNN_API_INCLUDE) -I$(GENIE_ENGINES_CPU_INCLUDE) -I$(QNN_API_HTP_INCLUDE) -I$(GENIE_ENGINES_API_INCLUDE) -I$(GENIE_TOKENIZER_INCLUDE) -I$(GENIE_C_API_HEADERS_INCLUDE)
|
| 66 |
|
| 67 |
# set compiler flags
|
| 68 |
COMMON_CXXFLAGS = -std=c++2a -frtti -fPIC -Wall -pg -pthread -nostdinc++ -stdlib=libc++ -idirafter /usr/lib/llvm-14/include/c++/v1 -nostdinc -idirafter /usr/lib/llvm-14/lib/clang/14.0.0/include/ -idirafter /usr/include $(INCLUDES)
|
|
@@ -71,11 +73,11 @@ COMMON_LDFLAGS = -shared -s -fPIC -pthread -L/usr/lib/x86_64-linux-gnu -L./src/
|
|
| 71 |
COMMON_CFLAGS = -nostdinc -idirafter /usr/lib/llvm-14/lib/clang/14.0.0/include/ -idirafter /usr/include
|
| 72 |
|
| 73 |
ifdef QNN_DEBUG_ENABLE
|
| 74 |
-
CXXFLAGS += $(COMMON_CXXFLAGS) -march=x86-64 -O0 -g -DQNN_API="" -DSPILLFILL -DQUALLA_ENGINE_QNN_CPU=TRUE -DQUALLA_APPS=OFF -DFMT_HEADER_ONLY -DGENIE_SAMPLE -DQUALLA_INTERNAL_QNN_SDK -DGENIE_SSD_FEATURE -DGENIE_SPD_FEATURE -DGENIE_LADE_FEATURE -DGENIE_MULTISTREAM_FEATURE -DGENIE_LORA_FEATURE -DGENIE_E2T_FEATURE
|
| 75 |
CFLAGS += $(COMMON_CFLAGS)
|
| 76 |
LDFLAGS += $(COMMON_LDFLAGS)
|
| 77 |
else
|
| 78 |
-
CXXFLAGS += $(COMMON_CXXFLAGS) -march=x86-64 -O3 -Wno-write-strings -fvisibility=hidden -DGENIE_API="__attribute__((visibility(\"default\")))" -DSPILLFILL -DQUALLA_ENGINE_QNN_CPU=TRUE -DQUALLA_APPS=OFF -DFMT_HEADER_ONLY -DGENIE_SAMPLE -DQUALLA_INTERNAL_QNN_SDK -DGENIE_SSD_FEATURE -DGENIE_SPD_FEATURE -DGENIE_LADE_FEATURE -DGENIE_MULTISTREAM_FEATURE -DGENIE_LORA_FEATURE -DGENIE_E2T_FEATURE
|
| 79 |
CFLAGS += $(COMMON_CFLAGS)
|
| 80 |
LDFLAGS += $(COMMON_LDFLAGS) -fvisibility=hidden -flto
|
| 81 |
endif
|
|
@@ -89,6 +91,7 @@ SOURCES_GENIE_QNN_API_CPP := $(wildcard $(SRC_DIR_GENIE_QNN_API)/*.cpp)
|
|
| 89 |
SOURCES_GENIE_ENGINES_CPP := $(filter-out $(SRC_DIR_GENIE_ENGINES)/qnn-htp.cpp, $(wildcard $(SRC_DIR_GENIE_ENGINES)/*.cpp))
|
| 90 |
SOURCES_GENIE_DIALOGS_CPP := $(wildcard $(SRC_DIR_SAMPLE_DIALOGS)/*.cpp)
|
| 91 |
SOURCES_GENIE_ENGINES_CPU_CPP := $(wildcard $(SRC_DIR_GENIE_ENGINES_CPU)/*.cpp)
|
|
|
|
| 92 |
SOURCES_GENIE_UTILS_CPP := $(wildcard $(SRC_DIR_GENIE_UTILS)/*.cpp)
|
| 93 |
|
| 94 |
|
|
@@ -108,6 +111,8 @@ OBJ_DIR_GENIE_ENGINES := $(OBJ_DIR_QUALLA)/engines
|
|
| 108 |
OBJ_DIR_GENIE_UTILS := $(OBJ_DIR_QUALLA)/utils
|
| 109 |
OBJ_DIR_GENIE_ENGINES_CPU := $(OBJ_DIR_QUALLA)/engines/qnn-cpu
|
| 110 |
$(shell mkdir -p $(OBJ_DIR_GENIE_ENGINES_CPU))
|
|
|
|
|
|
|
| 111 |
|
| 112 |
OBJ_DIR_GENIE_LOGGERS := obj/$(QNN_TARGET)/qualla/loggers
|
| 113 |
OBJ_DIR_GENIE_SAMPLERS := obj/$(QNN_TARGET)/qualla/samplers
|
|
@@ -125,6 +130,7 @@ OBJECTS_GENIE_ENGINES := $(patsubst %.cpp,$(OBJ_DIR_GENIE_ENGINES)/%.o,$(foreach
|
|
| 125 |
OBJECTS_GENIE_DIALOGS := $(patsubst %.cpp,$(OBJ_DIR_GENIE_DIALOGS)/%.o,$(foreach x,$(SOURCES_GENIE_DIALOGS_CPP),$(notdir $(x))))
|
| 126 |
OBJECTS_GENIE_UTILS := $(patsubst %.cpp,$(OBJ_DIR_GENIE_UTILS)/%.o,$(foreach x,$(SOURCES_GENIE_UTILS_CPP),$(notdir $(x))))
|
| 127 |
OBJECTS_GENIE_ENGINES_CPU := $(patsubst %.cpp,$(OBJ_DIR_GENIE_ENGINES_CPU)/%.o,$(foreach x,$(SOURCES_GENIE_ENGINES_CPU_CPP),$(notdir $(x))))
|
|
|
|
| 128 |
|
| 129 |
OBJECTS_GENIE_LOGGERS := $(patsubst %.cpp,$(OBJ_DIR_GENIE_LOGGERS)/%.o,$(foreach x,$(SOURCES_GENIE_LOGGERS_CPP),$(notdir $(x))))
|
| 130 |
OBJECTS_GENIE_SAMPLERS := $(patsubst %.cpp,$(OBJ_DIR_GENIE_SAMPLERS)/%.o,$(foreach x,$(SOURCES_GENIE_SAMPLERS_CPP),$(notdir $(x))))
|
|
@@ -157,16 +163,18 @@ $(OBJ_DIR_GENIE_UTILS)/%.o: $(SRC_DIR_GENIE_UTILS)/%.cpp $(CXX) $(CXXFLAGS) -c $
|
|
| 157 |
|
| 158 |
$(OBJ_DIR_GENIE_ENGINES_CPU)/%.o: $(SRC_DIR_GENIE_ENGINES_CPU)/%.cpp $(CXX) $(CXXFLAGS) -c $^ -o $@
|
| 159 |
|
|
|
|
|
|
|
| 160 |
$(OBJ_DIR_GENIE_LOGGERS)/%.o: $(SRC_DIR_GENIE_LOGGERS)/%.cpp $(CXX) $(CXXFLAGS) -c $^ -o $@
|
| 161 |
|
| 162 |
$(OBJ_DIR_GENIE_SAMPLERS)/%.o: $(SRC_DIR_GENIE_SAMPLERS)/%.cpp $(CXX) $(CXXFLAGS) -c $^ -o $@
|
| 163 |
|
| 164 |
|
| 165 |
# set up resources
|
| 166 |
-
directories := $(TARGET_DIR) $(OBJ_DIR_GENIE) $(OBJ_DIR_GENIE_QNN_API) $(OBJ_DIR_QUALLA) $(OBJ_DIR_GENIE_TOKENIZERS) $(OBJ_DIR_GENIE_ENGINES) $(OBJ_DIR_GENIE_DIALOGS) $(OBJ_DIR_GENIE_UTILS) $(OBJ_DIR_GENIE_ENGINES_CPU) $(OBJ_DIR_GENIE_LOGGERS) $(OBJ_DIR_GENIE_SAMPLERS)
|
| 167 |
|
| 168 |
# Compile
|
| 169 |
-
$(libGenie): $(OBJECTS_GENIE) $(OBJECTS_QUALLA) $(OBJECTS_GENIE_QNN_API) $(OBJECTS_GENIE_TOKENIZERS) $(OBJECTS_GENIE_ENGINES) $(OBJECTS_GENIE_DIALOGS) $(OBJECTS_GENIE_UTILS) $(OBJECTS_GENIE_ENGINES_CPU) $(OBJECTS_GENIE_LOGGERS) $(OBJECTS_GENIE_SAMPLERS) | $(directories)
|
| 170 |
$(CXX) $(CXXFLAGS) -shared -o $@ $^ $(LIBS) $(libtokenizers)
|
| 171 |
|
| 172 |
|
|
@@ -179,6 +187,7 @@ $(OBJECTS_GENIE_ENGINES): | $(OBJ_DIR_GENIE_ENGINES)
|
|
| 179 |
$(OBJECTS_GENIE_DIALOGS): | $(OBJ_DIR_GENIE_DIALOGS)
|
| 180 |
$(OBJECTS_GENIE_UTILS): | $(OBJ_DIR_GENIE_UTILS)
|
| 181 |
$(OBJECTS_GENIE_ENGINES_CPU): | $(OBJ_DIR_GENIE_ENGINES_CPU)
|
|
|
|
| 182 |
$(OBJECTS_GENIE_LOGGERS): | $(OBJ_DIR_GENIE_LOGGERS)
|
| 183 |
$(OBJECTS_GENIE_SAMPLERS): | $(OBJ_DIR_GENIE_SAMPLERS)
|
| 184 |
|
|
|
|
| 17 |
SRC_DIR_GENIE_ENGINES := src/qualla/engines
|
| 18 |
SRC_DIR_GENIE_QNN_API := src/qualla/engines/qnn-api
|
| 19 |
SRC_DIR_GENIE_ENGINES_CPU := src/qualla/engines/qnn-cpu
|
| 20 |
+
SRC_DIR_GENIE_ENGINES_GPU := src/qualla/engines/qnn-gpu
|
| 21 |
SRC_DIR_GENIE_UTILS := src/qualla/utils
|
| 22 |
#
|
| 23 |
SRC_DIR_GENIE_LOGGERS := src/qualla/loggers
|
|
|
|
| 30 |
|
| 31 |
# Includes
|
| 32 |
GENIE_ENGINES_CPU_INCLUDE := src/qualla/engines/qnn-cpu
|
| 33 |
+
GENIE_ENGINES_GPU_INCLUDE := src/qualla/engines/qnn-gpu
|
| 34 |
GENIE_ENGINES_API_INCLUDE := src/qualla/engines/qnn-api
|
| 35 |
GENIE_ENGINES_HTP_INCLUDE := src/qualla/engines/qnn-htp
|
| 36 |
GENIE_TOKENIZER_INCLUDE := src/qualla/tokenizers
|
|
|
|
| 64 |
GENIE_all: $(libGenie)
|
| 65 |
|
| 66 |
# Include paths
|
| 67 |
+
INCLUDES += -I$(GENIE_INCLUDE) -I$(QUALLA_INCLUDE) -I$(SRC_DIR_GENIE_TOKENIZERS) -I$(QNN_API_INCLUDE) -I$(GENIE_ENGINES_CPU_INCLUDE) -I$(GENIE_ENGINES_GPU_INCLUDE) -I$(QNN_API_HTP_INCLUDE) -I$(GENIE_ENGINES_API_INCLUDE) -I$(GENIE_TOKENIZER_INCLUDE) -I$(GENIE_C_API_HEADERS_INCLUDE)
|
| 68 |
|
| 69 |
# set compiler flags
|
| 70 |
COMMON_CXXFLAGS = -std=c++2a -frtti -fPIC -Wall -pg -pthread -nostdinc++ -stdlib=libc++ -idirafter /usr/lib/llvm-14/include/c++/v1 -nostdinc -idirafter /usr/lib/llvm-14/lib/clang/14.0.0/include/ -idirafter /usr/include $(INCLUDES)
|
|
|
|
| 73 |
COMMON_CFLAGS = -nostdinc -idirafter /usr/lib/llvm-14/lib/clang/14.0.0/include/ -idirafter /usr/include
|
| 74 |
|
| 75 |
ifdef QNN_DEBUG_ENABLE
|
| 76 |
+
CXXFLAGS += $(COMMON_CXXFLAGS) -march=x86-64 -O0 -g -DQNN_API="" -DSPILLFILL -DQUALLA_ENGINE_QNN_CPU=TRUE -DQUALLA_ENGINE_QNN_GPU=TRUE -DQUALLA_APPS=OFF -DFMT_HEADER_ONLY -DGENIE_SAMPLE -DQUALLA_INTERNAL_QNN_SDK -DGENIE_SSD_FEATURE -DGENIE_SPD_FEATURE -DGENIE_LADE_FEATURE -DGENIE_MULTISTREAM_FEATURE -DGENIE_LORA_FEATURE -DGENIE_E2T_FEATURE
|
| 77 |
CFLAGS += $(COMMON_CFLAGS)
|
| 78 |
LDFLAGS += $(COMMON_LDFLAGS)
|
| 79 |
else
|
| 80 |
+
CXXFLAGS += $(COMMON_CXXFLAGS) -march=x86-64 -O3 -Wno-write-strings -fvisibility=hidden -DGENIE_API="__attribute__((visibility(\"default\")))" -DSPILLFILL -DQUALLA_ENGINE_QNN_CPU=TRUE -DQUALLA_ENGINE_QNN_GPU=TRUE -DQUALLA_APPS=OFF -DFMT_HEADER_ONLY -DGENIE_SAMPLE -DQUALLA_INTERNAL_QNN_SDK -DGENIE_SSD_FEATURE -DGENIE_SPD_FEATURE -DGENIE_LADE_FEATURE -DGENIE_MULTISTREAM_FEATURE -DGENIE_LORA_FEATURE -DGENIE_E2T_FEATURE
|
| 81 |
CFLAGS += $(COMMON_CFLAGS)
|
| 82 |
LDFLAGS += $(COMMON_LDFLAGS) -fvisibility=hidden -flto
|
| 83 |
endif
|
|
|
|
| 91 |
SOURCES_GENIE_ENGINES_CPP := $(filter-out $(SRC_DIR_GENIE_ENGINES)/qnn-htp.cpp, $(wildcard $(SRC_DIR_GENIE_ENGINES)/*.cpp))
|
| 92 |
SOURCES_GENIE_DIALOGS_CPP := $(wildcard $(SRC_DIR_SAMPLE_DIALOGS)/*.cpp)
|
| 93 |
SOURCES_GENIE_ENGINES_CPU_CPP := $(wildcard $(SRC_DIR_GENIE_ENGINES_CPU)/*.cpp)
|
| 94 |
+
SOURCES_GENIE_ENGINES_GPU_CPP := $(wildcard $(SRC_DIR_GENIE_ENGINES_GPU)/*.cpp)
|
| 95 |
SOURCES_GENIE_UTILS_CPP := $(wildcard $(SRC_DIR_GENIE_UTILS)/*.cpp)
|
| 96 |
|
| 97 |
|
|
|
|
| 111 |
OBJ_DIR_GENIE_UTILS := $(OBJ_DIR_QUALLA)/utils
|
| 112 |
OBJ_DIR_GENIE_ENGINES_CPU := $(OBJ_DIR_QUALLA)/engines/qnn-cpu
|
| 113 |
$(shell mkdir -p $(OBJ_DIR_GENIE_ENGINES_CPU))
|
| 114 |
+
OBJ_DIR_GENIE_ENGINES_GPU := $(OBJ_DIR_QUALLA)/engines/qnn-gpu
|
| 115 |
+
$(shell mkdir -p $(OBJ_DIR_GENIE_ENGINES_GPU))
|
| 116 |
|
| 117 |
OBJ_DIR_GENIE_LOGGERS := obj/$(QNN_TARGET)/qualla/loggers
|
| 118 |
OBJ_DIR_GENIE_SAMPLERS := obj/$(QNN_TARGET)/qualla/samplers
|
|
|
|
| 130 |
OBJECTS_GENIE_DIALOGS := $(patsubst %.cpp,$(OBJ_DIR_GENIE_DIALOGS)/%.o,$(foreach x,$(SOURCES_GENIE_DIALOGS_CPP),$(notdir $(x))))
|
| 131 |
OBJECTS_GENIE_UTILS := $(patsubst %.cpp,$(OBJ_DIR_GENIE_UTILS)/%.o,$(foreach x,$(SOURCES_GENIE_UTILS_CPP),$(notdir $(x))))
|
| 132 |
OBJECTS_GENIE_ENGINES_CPU := $(patsubst %.cpp,$(OBJ_DIR_GENIE_ENGINES_CPU)/%.o,$(foreach x,$(SOURCES_GENIE_ENGINES_CPU_CPP),$(notdir $(x))))
|
| 133 |
+
OBJECTS_GENIE_ENGINES_GPU := $(patsubst %.cpp,$(OBJ_DIR_GENIE_ENGINES_GPU)/%.o,$(foreach x,$(SOURCES_GENIE_ENGINES_GPU_CPP),$(notdir $(x))))
|
| 134 |
|
| 135 |
OBJECTS_GENIE_LOGGERS := $(patsubst %.cpp,$(OBJ_DIR_GENIE_LOGGERS)/%.o,$(foreach x,$(SOURCES_GENIE_LOGGERS_CPP),$(notdir $(x))))
|
| 136 |
OBJECTS_GENIE_SAMPLERS := $(patsubst %.cpp,$(OBJ_DIR_GENIE_SAMPLERS)/%.o,$(foreach x,$(SOURCES_GENIE_SAMPLERS_CPP),$(notdir $(x))))
|
|
|
|
| 163 |
|
| 164 |
$(OBJ_DIR_GENIE_ENGINES_CPU)/%.o: $(SRC_DIR_GENIE_ENGINES_CPU)/%.cpp $(CXX) $(CXXFLAGS) -c $^ -o $@
|
| 165 |
|
| 166 |
+
$(OBJ_DIR_GENIE_ENGINES_GPU)/%.o: $(SRC_DIR_GENIE_ENGINES_GPU)/%.cpp $(CXX) $(CXXFLAGS) -c $^ -o $@
|
| 167 |
+
|
| 168 |
$(OBJ_DIR_GENIE_LOGGERS)/%.o: $(SRC_DIR_GENIE_LOGGERS)/%.cpp $(CXX) $(CXXFLAGS) -c $^ -o $@
|
| 169 |
|
| 170 |
$(OBJ_DIR_GENIE_SAMPLERS)/%.o: $(SRC_DIR_GENIE_SAMPLERS)/%.cpp $(CXX) $(CXXFLAGS) -c $^ -o $@
|
| 171 |
|
| 172 |
|
| 173 |
# set up resources
|
| 174 |
+
directories := $(TARGET_DIR) $(OBJ_DIR_GENIE) $(OBJ_DIR_GENIE_QNN_API) $(OBJ_DIR_QUALLA) $(OBJ_DIR_GENIE_TOKENIZERS) $(OBJ_DIR_GENIE_ENGINES) $(OBJ_DIR_GENIE_DIALOGS) $(OBJ_DIR_GENIE_UTILS) $(OBJ_DIR_GENIE_ENGINES_CPU) $(OBJ_DIR_GENIE_ENGINES_GPU) $(OBJ_DIR_GENIE_LOGGERS) $(OBJ_DIR_GENIE_SAMPLERS)
|
| 175 |
|
| 176 |
# Compile
|
| 177 |
+
$(libGenie): $(OBJECTS_GENIE) $(OBJECTS_QUALLA) $(OBJECTS_GENIE_QNN_API) $(OBJECTS_GENIE_TOKENIZERS) $(OBJECTS_GENIE_ENGINES) $(OBJECTS_GENIE_DIALOGS) $(OBJECTS_GENIE_UTILS) $(OBJECTS_GENIE_ENGINES_CPU) $(OBJECTS_GENIE_ENGINES_GPU) $(OBJECTS_GENIE_LOGGERS) $(OBJECTS_GENIE_SAMPLERS) | $(directories)
|
| 178 |
$(CXX) $(CXXFLAGS) -shared -o $@ $^ $(LIBS) $(libtokenizers)
|
| 179 |
|
| 180 |
|
|
|
|
| 187 |
$(OBJECTS_GENIE_DIALOGS): | $(OBJ_DIR_GENIE_DIALOGS)
|
| 188 |
$(OBJECTS_GENIE_UTILS): | $(OBJ_DIR_GENIE_UTILS)
|
| 189 |
$(OBJECTS_GENIE_ENGINES_CPU): | $(OBJ_DIR_GENIE_ENGINES_CPU)
|
| 190 |
+
$(OBJECTS_GENIE_ENGINES_GPU): | $(OBJ_DIR_GENIE_ENGINES_GPU)
|
| 191 |
$(OBJECTS_GENIE_LOGGERS): | $(OBJ_DIR_GENIE_LOGGERS)
|
| 192 |
$(OBJECTS_GENIE_SAMPLERS): | $(OBJ_DIR_GENIE_SAMPLERS)
|
| 193 |
|
Genie/Genie/src/Dialog.cpp
CHANGED
|
@@ -95,81 +95,6 @@ static void translateContextConfig(const qualla::json& genieConfig, qualla::json
|
|
| 95 |
}
|
| 96 |
}
|
| 97 |
|
| 98 |
-
//=============================================================================
|
| 99 |
-
// Sampler::Config functions
|
| 100 |
-
//=============================================================================
|
| 101 |
-
|
| 102 |
-
static void validateSamplerConfig(const qualla::json& config) {
|
| 103 |
-
if (!config.is_object()) {
|
| 104 |
-
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "sampler config is not an object");
|
| 105 |
-
}
|
| 106 |
-
|
| 107 |
-
std::set<std::string> mandatoryFields{"version"};
|
| 108 |
-
for (const auto& field : mandatoryFields) {
|
| 109 |
-
if (!config.contains(field)) {
|
| 110 |
-
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing sampler field: " + field);
|
| 111 |
-
}
|
| 112 |
-
}
|
| 113 |
-
|
| 114 |
-
// component is used in the "ENFORCE" macros
|
| 115 |
-
std::string component = "sampler";
|
| 116 |
-
|
| 117 |
-
for (auto& item : config.items()) {
|
| 118 |
-
if (item.key() == "version") {
|
| 119 |
-
JSON_ENFORCE_NUMERIC();
|
| 120 |
-
if (item.value().get<int>() != 1) {
|
| 121 |
-
throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
|
| 122 |
-
"Invalid sampler config: unsupported version: " + item.value().dump());
|
| 123 |
-
}
|
| 124 |
-
} else if (item.key() == "seed") {
|
| 125 |
-
JSON_ENFORCE_NUMERIC();
|
| 126 |
-
} else if (item.key() == "temp") {
|
| 127 |
-
JSON_ENFORCE_NUMERIC();
|
| 128 |
-
} else if (item.key() == "top-k") {
|
| 129 |
-
JSON_ENFORCE_NUMERIC();
|
| 130 |
-
} else if (item.key() == "top-p") {
|
| 131 |
-
JSON_ENFORCE_NUMERIC();
|
| 132 |
-
} else if (item.key() == "greedy") {
|
| 133 |
-
JSON_ENFORCE_BOOLEAN();
|
| 134 |
-
} else {
|
| 135 |
-
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown sampler config key: " + item.key());
|
| 136 |
-
}
|
| 137 |
-
}
|
| 138 |
-
}
|
| 139 |
-
|
| 140 |
-
static void translateSamplerConfig(const qualla::json& genieConfig, qualla::json& quallaConfig) {
|
| 141 |
-
if (genieConfig["dialog"].contains("sampler")) {
|
| 142 |
-
quallaConfig["sampler"]["type"] = "basic";
|
| 143 |
-
|
| 144 |
-
if (genieConfig["dialog"]["sampler"].contains("seed")) {
|
| 145 |
-
quallaConfig["sampler"]["seed"] = genieConfig["dialog"]["sampler"]["seed"];
|
| 146 |
-
}
|
| 147 |
-
if (genieConfig["dialog"]["sampler"].contains("temp")) {
|
| 148 |
-
quallaConfig["sampler"]["temp"] = genieConfig["dialog"]["sampler"]["temp"];
|
| 149 |
-
}
|
| 150 |
-
|
| 151 |
-
quallaConfig["sampler"]["role"] = "primary";
|
| 152 |
-
#if defined(GENIE_SPD_FEATURE)
|
| 153 |
-
if (genieConfig["dialog"]["type"] == "spd") {
|
| 154 |
-
quallaConfig["sampler"]["role"] = "target";
|
| 155 |
-
}
|
| 156 |
-
#endif
|
| 157 |
-
|
| 158 |
-
if (genieConfig["dialog"]["sampler"].contains("top-k")) {
|
| 159 |
-
quallaConfig["sampler"]["top-k"] = genieConfig["dialog"]["sampler"]["top-k"];
|
| 160 |
-
}
|
| 161 |
-
if (genieConfig["dialog"]["sampler"].contains("top-p")) {
|
| 162 |
-
quallaConfig["sampler"]["top-p"] = genieConfig["dialog"]["sampler"]["top-p"];
|
| 163 |
-
}
|
| 164 |
-
if (genieConfig["dialog"]["sampler"].contains("greedy")) {
|
| 165 |
-
quallaConfig["sampler"]["greedy"] = genieConfig["dialog"]["sampler"]["greedy"];
|
| 166 |
-
}
|
| 167 |
-
if (genieConfig["dialog"]["sampler"].contains("seed")) {
|
| 168 |
-
quallaConfig["sampler"]["seed"] = genieConfig["dialog"]["sampler"]["seed"];
|
| 169 |
-
}
|
| 170 |
-
}
|
| 171 |
-
}
|
| 172 |
-
|
| 173 |
//=============================================================================
|
| 174 |
// Tokenizer::Config functions
|
| 175 |
//=============================================================================
|
|
@@ -322,6 +247,8 @@ static void validateBackendHtpConfig(const qualla::json& config) {
|
|
| 322 |
} else if (item.key() == "rope-theta") {
|
| 323 |
rope_theta_set = true;
|
| 324 |
JSON_ENFORCE_NUMERIC();
|
|
|
|
|
|
|
| 325 |
} else {
|
| 326 |
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown QnnHtp config key: " + item.key());
|
| 327 |
}
|
|
@@ -410,7 +337,7 @@ static void validateBackendConfig(const qualla::json& config) {
|
|
| 410 |
htp = true;
|
| 411 |
} else if (type == "QnnGenAiTransformer") {
|
| 412 |
genai = true;
|
| 413 |
-
} else {
|
| 414 |
throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
|
| 415 |
"Invalid backend config: unsupported type: " + item.value().dump());
|
| 416 |
}
|
|
@@ -629,6 +556,9 @@ static void validateModelLibraryConfig(const qualla::json& config) {
|
|
| 629 |
}
|
| 630 |
} else if (item.key() == "model-bin") {
|
| 631 |
JSON_ENFORCE_STRING();
|
|
|
|
|
|
|
|
|
|
| 632 |
} else {
|
| 633 |
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown library config key: " + item.key());
|
| 634 |
}
|
|
@@ -956,6 +886,10 @@ static void translateEngineConfig(const qualla::json& genieEngineConfig,
|
|
| 956 |
quallaEngineConfig["use-async-Init"] =
|
| 957 |
genieEngineConfig["backend"]["QnnHtp"]["allow-async-init"];
|
| 958 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 959 |
} else if (genieEngineConfig["backend"]["type"] == "QnnGenAiTransformer") {
|
| 960 |
quallaEngineConfig["type"] = "qnn-cpu";
|
| 961 |
quallaEngineConfig["backend-lib"] = getLibName("QnnGenAiTransformer");
|
|
@@ -979,6 +913,8 @@ static void translateEngineConfig(const qualla::json& genieEngineConfig,
|
|
| 979 |
quallaEngineConfig["n_heads"] =
|
| 980 |
genieEngineConfig["backend"]["QnnGenAiTransformer"]["n-heads"];
|
| 981 |
}
|
|
|
|
|
|
|
| 982 |
}
|
| 983 |
|
| 984 |
if (genieEngineConfig["backend"].contains("extensions")) {
|
|
@@ -1020,6 +956,21 @@ static void translateEngineConfig(const qualla::json& genieEngineConfig,
|
|
| 1020 |
quallaEngineConfig["model-bin-path"] = genieEngineConfig["model"]["library"]["model-bin"];
|
| 1021 |
quallaEngineConfig["op-package"] =
|
| 1022 |
getLibName("QnnGenAiTransformerCpuOpPkg") + ":QnnOpPackage_interfaceProvider";
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1023 |
}
|
| 1024 |
if (genieEngineConfig["model"].contains("positional-encoding")) {
|
| 1025 |
quallaEngineConfig["positional-encoding"]["type"] =
|
|
@@ -1424,7 +1375,7 @@ static void validateDialogConfig(const qualla::json& config) {
|
|
| 1424 |
validateTokenizerConfig(item.value());
|
| 1425 |
} else if (item.key() == "sampler") {
|
| 1426 |
JSON_ENFORCE_OBJECT();
|
| 1427 |
-
validateSamplerConfig(item.value());
|
| 1428 |
} else if (item.key() == "engine") {
|
| 1429 |
JSON_ENFORCE_ARRAY_OR_OBJECT();
|
| 1430 |
} else if (item.key() == "embedding") {
|
|
@@ -1550,7 +1501,7 @@ static void translateDialogConfig(const qualla::json& genieConfig, qualla::json&
|
|
| 1550 |
|
| 1551 |
translateContextConfig(genieConfig, quallaConfig);
|
| 1552 |
translateTokenizerConfig(genieConfig, quallaConfig);
|
| 1553 |
-
translateSamplerConfig(genieConfig, quallaConfig);
|
| 1554 |
translateMultiEngineConfig(genieConfig, quallaConfig);
|
| 1555 |
translateEmbeddingConfig(genieConfig, quallaConfig);
|
| 1556 |
}
|
|
@@ -1611,7 +1562,7 @@ Dialog::Config::Config(const char* configStr) {
|
|
| 1611 |
m_config = config;
|
| 1612 |
}
|
| 1613 |
|
| 1614 |
-
qualla::json Dialog::Config::getJson()
|
| 1615 |
|
| 1616 |
//=============================================================================
|
| 1617 |
// Dialog functions
|
|
@@ -1640,6 +1591,27 @@ Dialog::Dialog(std::shared_ptr<Config> config) {
|
|
| 1640 |
if (!m_quallaDialog) {
|
| 1641 |
throw Exception(GENIE_STATUS_ERROR_MEM_ALLOC, "Could not create a dialog object");
|
| 1642 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1643 |
}
|
| 1644 |
|
| 1645 |
static_assert(qualla::Sentence::Code::COMPLETE ==
|
|
@@ -1801,4 +1773,6 @@ int32_t Dialog::tokenQuery(const uint32_t* tokens,
|
|
| 1801 |
kpis.generate.last_usec,
|
| 1802 |
kpis.tps.generate);
|
| 1803 |
return (status) ? (GENIE_STATUS_SUCCESS) : (GENIE_STATUS_ERROR_QUERY_FAILED);
|
| 1804 |
-
}
|
|
|
|
|
|
|
|
|
| 95 |
}
|
| 96 |
}
|
| 97 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
//=============================================================================
|
| 99 |
// Tokenizer::Config functions
|
| 100 |
//=============================================================================
|
|
|
|
| 247 |
} else if (item.key() == "rope-theta") {
|
| 248 |
rope_theta_set = true;
|
| 249 |
JSON_ENFORCE_NUMERIC();
|
| 250 |
+
} else if (item.key() == "enable-graph-switching") {
|
| 251 |
+
JSON_ENFORCE_BOOLEAN();
|
| 252 |
} else {
|
| 253 |
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown QnnHtp config key: " + item.key());
|
| 254 |
}
|
|
|
|
| 337 |
htp = true;
|
| 338 |
} else if (type == "QnnGenAiTransformer") {
|
| 339 |
genai = true;
|
| 340 |
+
} else if (type != "QnnGpu") {
|
| 341 |
throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
|
| 342 |
"Invalid backend config: unsupported type: " + item.value().dump());
|
| 343 |
}
|
|
|
|
| 556 |
}
|
| 557 |
} else if (item.key() == "model-bin") {
|
| 558 |
JSON_ENFORCE_STRING();
|
| 559 |
+
} else if (item.key() == "lora") {
|
| 560 |
+
JSON_ENFORCE_OBJECT();
|
| 561 |
+
validateLoraConfig(item.value());
|
| 562 |
} else {
|
| 563 |
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown library config key: " + item.key());
|
| 564 |
}
|
|
|
|
| 886 |
quallaEngineConfig["use-async-Init"] =
|
| 887 |
genieEngineConfig["backend"]["QnnHtp"]["allow-async-init"];
|
| 888 |
}
|
| 889 |
+
if (genieEngineConfig["backend"]["QnnHtp"].contains("enable-graph-switching")) {
|
| 890 |
+
quallaEngineConfig["enable-graph-switching"] =
|
| 891 |
+
genieEngineConfig["backend"]["QnnHtp"]["enable-graph-switching"];
|
| 892 |
+
}
|
| 893 |
} else if (genieEngineConfig["backend"]["type"] == "QnnGenAiTransformer") {
|
| 894 |
quallaEngineConfig["type"] = "qnn-cpu";
|
| 895 |
quallaEngineConfig["backend-lib"] = getLibName("QnnGenAiTransformer");
|
|
|
|
| 913 |
quallaEngineConfig["n_heads"] =
|
| 914 |
genieEngineConfig["backend"]["QnnGenAiTransformer"]["n-heads"];
|
| 915 |
}
|
| 916 |
+
} else if (genieEngineConfig["backend"]["type"] == "QnnGpu") {
|
| 917 |
+
quallaEngineConfig["type"] = "qnn-gpu";
|
| 918 |
}
|
| 919 |
|
| 920 |
if (genieEngineConfig["backend"].contains("extensions")) {
|
|
|
|
| 956 |
quallaEngineConfig["model-bin-path"] = genieEngineConfig["model"]["library"]["model-bin"];
|
| 957 |
quallaEngineConfig["op-package"] =
|
| 958 |
getLibName("QnnGenAiTransformerCpuOpPkg") + ":QnnOpPackage_interfaceProvider";
|
| 959 |
+
if (genieEngineConfig["model"]["library"].contains("lora")) {
|
| 960 |
+
for (int i = 0; i < genieEngineConfig["model"]["library"]["lora"]["adapters"].size(); i++) {
|
| 961 |
+
quallaEngineConfig["lora"][i]["adapter-name"] =
|
| 962 |
+
genieEngineConfig["model"]["library"]["lora"]["adapters"][i]["name"];
|
| 963 |
+
if (genieEngineConfig["model"]["library"]["lora"].contains("alpha-tensor-name")) {
|
| 964 |
+
quallaEngineConfig["lora"][i]["alpha-tensor-name"] =
|
| 965 |
+
genieEngineConfig["model"]["library"]["lora"]
|
| 966 |
+
["alpha-tensor-name"];
|
| 967 |
+
}
|
| 968 |
+
quallaEngineConfig["lora"][i]["alpha-tensor-value"] = 1.0f;
|
| 969 |
+
quallaEngineConfig["lora"][i]["binsection-basedir"] = "";
|
| 970 |
+
quallaEngineConfig["lora"][i]["bin-sections"] =
|
| 971 |
+
genieEngineConfig["model"]["library"]["lora"]["adapters"][i]["bin-sections"];
|
| 972 |
+
}
|
| 973 |
+
}
|
| 974 |
}
|
| 975 |
if (genieEngineConfig["model"].contains("positional-encoding")) {
|
| 976 |
quallaEngineConfig["positional-encoding"]["type"] =
|
|
|
|
| 1375 |
validateTokenizerConfig(item.value());
|
| 1376 |
} else if (item.key() == "sampler") {
|
| 1377 |
JSON_ENFORCE_OBJECT();
|
| 1378 |
+
Sampler::SamplerConfig::validateSamplerConfig(item.value());
|
| 1379 |
} else if (item.key() == "engine") {
|
| 1380 |
JSON_ENFORCE_ARRAY_OR_OBJECT();
|
| 1381 |
} else if (item.key() == "embedding") {
|
|
|
|
| 1501 |
|
| 1502 |
translateContextConfig(genieConfig, quallaConfig);
|
| 1503 |
translateTokenizerConfig(genieConfig, quallaConfig);
|
| 1504 |
+
Sampler::SamplerConfig::translateSamplerConfig(genieConfig, quallaConfig);
|
| 1505 |
translateMultiEngineConfig(genieConfig, quallaConfig);
|
| 1506 |
translateEmbeddingConfig(genieConfig, quallaConfig);
|
| 1507 |
}
|
|
|
|
| 1562 |
m_config = config;
|
| 1563 |
}
|
| 1564 |
|
| 1565 |
+
qualla::json& Dialog::Config::getJson() { return m_config; }
|
| 1566 |
|
| 1567 |
//=============================================================================
|
| 1568 |
// Dialog functions
|
|
|
|
| 1591 |
if (!m_quallaDialog) {
|
| 1592 |
throw Exception(GENIE_STATUS_ERROR_MEM_ALLOC, "Could not create a dialog object");
|
| 1593 |
}
|
| 1594 |
+
/*
|
| 1595 |
+
* spec-dec has a mandatory "target" sampler and an optional "draft" sampler
|
| 1596 |
+
* Check their availability and pass their references to Dialog Sampler to update with
|
| 1597 |
+
* applyConfig()
|
| 1598 |
+
*/
|
| 1599 |
+
std::shared_ptr<Sampler> sampler;
|
| 1600 |
+
std::vector<std::reference_wrapper<qualla::Sampler>> quallaSamplers;
|
| 1601 |
+
if (quallaConfig["type"] == "spec-dec") {
|
| 1602 |
+
quallaSamplers.push_back(m_quallaDialog->sampler("target"));
|
| 1603 |
+
if (m_quallaDialog->isSamplerPresent("draft"))
|
| 1604 |
+
quallaSamplers.push_back(m_quallaDialog->sampler("draft"));
|
| 1605 |
+
sampler = std::make_shared<Sampler>(config->getJson()["dialog"], quallaSamplers);
|
| 1606 |
+
} else {
|
| 1607 |
+
quallaSamplers.push_back(m_quallaDialog->sampler()); // Default role is "primary"
|
| 1608 |
+
sampler = std::make_shared<Sampler>(config->getJson()["dialog"], quallaSamplers);
|
| 1609 |
+
}
|
| 1610 |
+
m_samplerHandle = Sampler::add(sampler);
|
| 1611 |
+
}
|
| 1612 |
+
|
| 1613 |
+
GenieSampler_Handle_t Dialog::getSamplerHandle(std::shared_ptr<Dialog> dialog) {
|
| 1614 |
+
return dialog->m_samplerHandle;
|
| 1615 |
}
|
| 1616 |
|
| 1617 |
static_assert(qualla::Sentence::Code::COMPLETE ==
|
|
|
|
| 1773 |
kpis.generate.last_usec,
|
| 1774 |
kpis.tps.generate);
|
| 1775 |
return (status) ? (GENIE_STATUS_SUCCESS) : (GENIE_STATUS_ERROR_QUERY_FAILED);
|
| 1776 |
+
}
|
| 1777 |
+
|
| 1778 |
+
Dialog::~Dialog() { Sampler::remove(m_samplerHandle); }
|
Genie/Genie/src/Dialog.hpp
CHANGED
|
@@ -10,11 +10,13 @@
|
|
| 10 |
|
| 11 |
#include <atomic>
|
| 12 |
#include <memory>
|
|
|
|
| 13 |
|
| 14 |
#include "GenieDialog.h"
|
| 15 |
#include "Util/HandleManager.hpp"
|
| 16 |
#include "qualla/dialog.hpp"
|
| 17 |
#include "qualla/DialogCallback.hpp"
|
|
|
|
| 18 |
|
| 19 |
namespace genie {
|
| 20 |
|
|
@@ -33,7 +35,7 @@ class Dialog {
|
|
| 33 |
static void remove(GenieDialogConfig_Handle_t handle);
|
| 34 |
|
| 35 |
Config(const char* configStr);
|
| 36 |
-
qualla::json getJson()
|
| 37 |
|
| 38 |
private:
|
| 39 |
static qnn::util::HandleManager<Config> s_manager;
|
|
@@ -43,10 +45,12 @@ class Dialog {
|
|
| 43 |
static GenieDialog_Handle_t add(std::shared_ptr<Dialog> dialog);
|
| 44 |
static std::shared_ptr<Dialog> get(GenieDialog_Handle_t handle);
|
| 45 |
static void remove(GenieDialog_Handle_t handle);
|
|
|
|
| 46 |
|
| 47 |
qualla::DialogCallback dialogCallback;
|
| 48 |
|
| 49 |
Dialog(std::shared_ptr<Config> config);
|
|
|
|
| 50 |
|
| 51 |
Dialog(const Dialog&) = delete;
|
| 52 |
Dialog& operator=(const Dialog&) = delete;
|
|
@@ -91,5 +95,6 @@ class Dialog {
|
|
| 91 |
uint32_t m_tokenLimit{UINT32_MAX};
|
| 92 |
static qnn::util::HandleManager<Dialog> s_manager;
|
| 93 |
static std::atomic<std::uint32_t> s_nameCounter;
|
|
|
|
| 94 |
};
|
| 95 |
} // namespace genie
|
|
|
|
| 10 |
|
| 11 |
#include <atomic>
|
| 12 |
#include <memory>
|
| 13 |
+
#include <functional>
|
| 14 |
|
| 15 |
#include "GenieDialog.h"
|
| 16 |
#include "Util/HandleManager.hpp"
|
| 17 |
#include "qualla/dialog.hpp"
|
| 18 |
#include "qualla/DialogCallback.hpp"
|
| 19 |
+
#include "Sampler.hpp"
|
| 20 |
|
| 21 |
namespace genie {
|
| 22 |
|
|
|
|
| 35 |
static void remove(GenieDialogConfig_Handle_t handle);
|
| 36 |
|
| 37 |
Config(const char* configStr);
|
| 38 |
+
qualla::json& getJson();
|
| 39 |
|
| 40 |
private:
|
| 41 |
static qnn::util::HandleManager<Config> s_manager;
|
|
|
|
| 45 |
static GenieDialog_Handle_t add(std::shared_ptr<Dialog> dialog);
|
| 46 |
static std::shared_ptr<Dialog> get(GenieDialog_Handle_t handle);
|
| 47 |
static void remove(GenieDialog_Handle_t handle);
|
| 48 |
+
static GenieSampler_Handle_t getSamplerHandle(std::shared_ptr<genie::Dialog> dialog);
|
| 49 |
|
| 50 |
qualla::DialogCallback dialogCallback;
|
| 51 |
|
| 52 |
Dialog(std::shared_ptr<Config> config);
|
| 53 |
+
~Dialog();
|
| 54 |
|
| 55 |
Dialog(const Dialog&) = delete;
|
| 56 |
Dialog& operator=(const Dialog&) = delete;
|
|
|
|
| 95 |
uint32_t m_tokenLimit{UINT32_MAX};
|
| 96 |
static qnn::util::HandleManager<Dialog> s_manager;
|
| 97 |
static std::atomic<std::uint32_t> s_nameCounter;
|
| 98 |
+
GenieSampler_Handle_t m_samplerHandle;
|
| 99 |
};
|
| 100 |
} // namespace genie
|
Genie/Genie/src/Embedding.cpp
ADDED
|
@@ -0,0 +1,740 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//==============================================================================
|
| 2 |
+
//
|
| 3 |
+
// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
|
| 4 |
+
// All Rights Reserved.
|
| 5 |
+
// Confidential and Proprietary - Qualcomm Technologies, Inc.
|
| 6 |
+
//
|
| 7 |
+
//==============================================================================
|
| 8 |
+
|
| 9 |
+
#include <exception>
|
| 10 |
+
#include <set>
|
| 11 |
+
#include <sstream>
|
| 12 |
+
|
| 13 |
+
#include "Embedding.hpp"
|
| 14 |
+
#include "Exception.hpp"
|
| 15 |
+
#include "Macro.hpp"
|
| 16 |
+
#include "qualla/detail/json.hpp"
|
| 17 |
+
#include "qualla/env.hpp"
|
| 18 |
+
|
| 19 |
+
using namespace genie;
|
| 20 |
+
|
| 21 |
+
#ifdef _WIN32
|
| 22 |
+
inline std::string libPrefix = "";
|
| 23 |
+
inline std::string libSuffix = ".dll";
|
| 24 |
+
#else
|
| 25 |
+
inline std::string libPrefix = "lib";
|
| 26 |
+
inline std::string libSuffix = ".so";
|
| 27 |
+
#endif
|
| 28 |
+
|
| 29 |
+
inline std::string getLibName(std::string baseName) { return libPrefix + baseName + libSuffix; }
|
| 30 |
+
|
| 31 |
+
//=============================================================================
|
| 32 |
+
// Context::Config functions
|
| 33 |
+
//=============================================================================
|
| 34 |
+
|
| 35 |
+
static void validateContextConfig(const qualla::json& config) {
|
| 36 |
+
if (!config.is_object()) {
|
| 37 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "context config is not an object");
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
std::set<std::string> mandatoryFields{
|
| 41 |
+
"version", "n-vocab", "ctx-size", "embed-size", "pad-token"};
|
| 42 |
+
for (const auto& field : mandatoryFields) {
|
| 43 |
+
if (!config.contains(field)) {
|
| 44 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing context field: " + field);
|
| 45 |
+
}
|
| 46 |
+
}
|
| 47 |
+
|
| 48 |
+
// component is used in the "ENFORCE" macros
|
| 49 |
+
std::string component = "context";
|
| 50 |
+
|
| 51 |
+
for (auto& item : config.items()) {
|
| 52 |
+
if (item.key() == "version") {
|
| 53 |
+
JSON_ENFORCE_NUMERIC();
|
| 54 |
+
if (item.value().get<int>() != 1) {
|
| 55 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
|
| 56 |
+
"Invalid context config: unsupported version: " + item.value().dump());
|
| 57 |
+
}
|
| 58 |
+
} else if (item.key() == "n-vocab") {
|
| 59 |
+
JSON_ENFORCE_NUMERIC();
|
| 60 |
+
} else if (item.key() == "ctx-size") {
|
| 61 |
+
JSON_ENFORCE_NUMERIC();
|
| 62 |
+
} else if (item.key() == "embed-size") {
|
| 63 |
+
JSON_ENFORCE_NUMERIC();
|
| 64 |
+
} else if (item.key() == "pad-token") {
|
| 65 |
+
JSON_ENFORCE_NUMERIC();
|
| 66 |
+
} else {
|
| 67 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown context config key: " + item.key());
|
| 68 |
+
}
|
| 69 |
+
}
|
| 70 |
+
}
|
| 71 |
+
|
| 72 |
+
static void translateContextConfig(const qualla::json& genieConfig, qualla::json& quallaConfig) {
|
| 73 |
+
quallaConfig["n-vocab"] = genieConfig["n-vocab"];
|
| 74 |
+
quallaConfig["size"] = genieConfig["ctx-size"];
|
| 75 |
+
quallaConfig["n-embd"] = genieConfig["embed-size"];
|
| 76 |
+
quallaConfig["pad-token"] = genieConfig["pad-token"];
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
//=============================================================================
|
| 80 |
+
// Tokenizer::Config functions
|
| 81 |
+
//=============================================================================
|
| 82 |
+
|
| 83 |
+
static void validateTokenizerConfig(const qualla::json& config) {
|
| 84 |
+
if (!config.is_object()) {
|
| 85 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "tokenizer config is not an object");
|
| 86 |
+
}
|
| 87 |
+
|
| 88 |
+
std::set<std::string> mandatoryFields{"version", "path"};
|
| 89 |
+
for (const auto& field : mandatoryFields) {
|
| 90 |
+
if (!config.contains(field)) {
|
| 91 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing tokenizer field: " + field);
|
| 92 |
+
}
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
// component is used in the "ENFORCE" macros
|
| 96 |
+
std::string component = "tokenizer";
|
| 97 |
+
|
| 98 |
+
for (auto& item : config.items()) {
|
| 99 |
+
if (item.key() == "version") {
|
| 100 |
+
JSON_ENFORCE_NUMERIC();
|
| 101 |
+
if (item.value().get<int>() != 1) {
|
| 102 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
|
| 103 |
+
"Invalid tokenizer config: unsupported version: " + item.value().dump());
|
| 104 |
+
}
|
| 105 |
+
} else if (item.key() == "path") {
|
| 106 |
+
JSON_ENFORCE_STRING();
|
| 107 |
+
// Note: the existence of this file is checked by qualla
|
| 108 |
+
} else {
|
| 109 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
|
| 110 |
+
"Unknown tokenizer config key: " + item.key());
|
| 111 |
+
}
|
| 112 |
+
}
|
| 113 |
+
}
|
| 114 |
+
|
| 115 |
+
static void translateTokenizerConfig(const qualla::json& genieConfig, qualla::json& quallaConfig) {
|
| 116 |
+
quallaConfig["tokenizer"] = genieConfig["path"];
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
//=============================================================================
|
| 120 |
+
// Backend::Config functions
|
| 121 |
+
//=============================================================================
|
| 122 |
+
|
| 123 |
+
static void validateBackendHtpConfig(const qualla::json& config) {
|
| 124 |
+
if (!config.is_object()) {
|
| 125 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "QnnHtp config is not an object");
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
std::set<std::string> mandatoryFields{
|
| 129 |
+
"version", "spill-fill-bufsize", "use-mmap", "pooled-output", "allow-async-init"};
|
| 130 |
+
for (const auto& field : mandatoryFields) {
|
| 131 |
+
if (!config.contains(field)) {
|
| 132 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing QnnHtp field: " + field);
|
| 133 |
+
}
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
// component is used in the "ENFORCE" macros
|
| 137 |
+
std::string component = "QnnHtp";
|
| 138 |
+
|
| 139 |
+
for (auto& item : config.items()) {
|
| 140 |
+
if (item.key() == "version") {
|
| 141 |
+
JSON_ENFORCE_NUMERIC();
|
| 142 |
+
if (item.value().get<int>() != 1) {
|
| 143 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
|
| 144 |
+
"Invalid QnnHtp config: unsupported version: " + item.value().dump());
|
| 145 |
+
}
|
| 146 |
+
} else if (item.key() == "spill-fill-bufsize") {
|
| 147 |
+
JSON_ENFORCE_NUMERIC();
|
| 148 |
+
} else if (item.key() == "use-mmap") {
|
| 149 |
+
JSON_ENFORCE_BOOLEAN();
|
| 150 |
+
} else if (item.key() == "pooled-output") {
|
| 151 |
+
JSON_ENFORCE_BOOLEAN();
|
| 152 |
+
} else if (item.key() == "allow-async-init") {
|
| 153 |
+
JSON_ENFORCE_BOOLEAN();
|
| 154 |
+
} else if (item.key() == "disable-kv-cache") {
|
| 155 |
+
JSON_ENFORCE_BOOLEAN();
|
| 156 |
+
} else {
|
| 157 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown QnnHtp config key: " + item.key());
|
| 158 |
+
}
|
| 159 |
+
}
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
static void validateBackendGenaiConfig(const qualla::json& config) {
|
| 163 |
+
if (!config.is_object()) {
|
| 164 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "QnnGenAiTransformer config is not an object");
|
| 165 |
+
}
|
| 166 |
+
|
| 167 |
+
std::set<std::string> mandatoryFields{"version"};
|
| 168 |
+
for (const auto& field : mandatoryFields) {
|
| 169 |
+
if (!config.contains(field)) {
|
| 170 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
|
| 171 |
+
"Missing QnnGenAiTransformer field: " + field);
|
| 172 |
+
}
|
| 173 |
+
}
|
| 174 |
+
|
| 175 |
+
// component is used in the "ENFORCE" macros
|
| 176 |
+
std::string component = "QnnGenAiTransformer";
|
| 177 |
+
|
| 178 |
+
for (auto& item : config.items()) {
|
| 179 |
+
if (item.key() == "version") {
|
| 180 |
+
JSON_ENFORCE_NUMERIC();
|
| 181 |
+
if (item.value().get<int>() != 1) {
|
| 182 |
+
throw Exception(
|
| 183 |
+
GENIE_STATUS_ERROR_JSON_VALUE,
|
| 184 |
+
"Invalid QnnGenAiTransformer config: unsupported version: " + item.value().dump());
|
| 185 |
+
}
|
| 186 |
+
} else if (item.key() == "n-logits") {
|
| 187 |
+
JSON_ENFORCE_NUMERIC();
|
| 188 |
+
} else if (item.key() == "n-layer") {
|
| 189 |
+
JSON_ENFORCE_NUMERIC();
|
| 190 |
+
} else if (item.key() == "n-embd") {
|
| 191 |
+
JSON_ENFORCE_NUMERIC();
|
| 192 |
+
} else if (item.key() == "n-heads") {
|
| 193 |
+
JSON_ENFORCE_NUMERIC();
|
| 194 |
+
} else {
|
| 195 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
|
| 196 |
+
"Unknown QnnGenAiTransformer config key: " + item.key());
|
| 197 |
+
}
|
| 198 |
+
}
|
| 199 |
+
}
|
| 200 |
+
|
| 201 |
+
static void validateBackendConfig(const qualla::json& config) {
|
| 202 |
+
if (!config.is_object()) {
|
| 203 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "backend config is not an object");
|
| 204 |
+
}
|
| 205 |
+
|
| 206 |
+
std::set<std::string> mandatoryFields{"version", "type"};
|
| 207 |
+
for (const auto& field : mandatoryFields) {
|
| 208 |
+
if (!config.contains(field)) {
|
| 209 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing backend field: " + field);
|
| 210 |
+
}
|
| 211 |
+
}
|
| 212 |
+
|
| 213 |
+
// component is used in the "ENFORCE" macros
|
| 214 |
+
std::string component = "backend";
|
| 215 |
+
|
| 216 |
+
std::string type;
|
| 217 |
+
bool htp = false;
|
| 218 |
+
qualla::json htpConfig;
|
| 219 |
+
bool genai = false;
|
| 220 |
+
qualla::json genaiConfig;
|
| 221 |
+
|
| 222 |
+
for (auto& item : config.items()) {
|
| 223 |
+
if (item.key() == "version") {
|
| 224 |
+
JSON_ENFORCE_NUMERIC();
|
| 225 |
+
if (item.value().get<int>() != 1) {
|
| 226 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
|
| 227 |
+
"Invalid backend config: unsupported version: " + item.value().dump());
|
| 228 |
+
}
|
| 229 |
+
} else if (item.key() == "type") {
|
| 230 |
+
JSON_ENFORCE_STRING();
|
| 231 |
+
type = item.value().get<std::string>();
|
| 232 |
+
if (type == "QnnHtp") {
|
| 233 |
+
htp = true;
|
| 234 |
+
} else if (type == "QnnGenAiTransformer") {
|
| 235 |
+
genai = true;
|
| 236 |
+
} else {
|
| 237 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
|
| 238 |
+
"Invalid backend config: unsupported type: " + item.value().dump());
|
| 239 |
+
}
|
| 240 |
+
} else if (item.key() == "extensions") {
|
| 241 |
+
JSON_ENFORCE_STRING();
|
| 242 |
+
} else if (item.key() == "QnnHtp") {
|
| 243 |
+
JSON_ENFORCE_OBJECT();
|
| 244 |
+
htpConfig = item.value();
|
| 245 |
+
} else if (item.key() == "QnnGenAiTransformer") {
|
| 246 |
+
JSON_ENFORCE_OBJECT();
|
| 247 |
+
genaiConfig = item.value();
|
| 248 |
+
} else {
|
| 249 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown backend config key: " + item.key());
|
| 250 |
+
}
|
| 251 |
+
}
|
| 252 |
+
|
| 253 |
+
if (htp) {
|
| 254 |
+
if (!htpConfig.is_object()) {
|
| 255 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing QnnHtp embedding config");
|
| 256 |
+
}
|
| 257 |
+
validateBackendHtpConfig(htpConfig);
|
| 258 |
+
} else {
|
| 259 |
+
if (htpConfig.is_object()) {
|
| 260 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
|
| 261 |
+
"QnnHtp backend config for incorrect backend type: " + type);
|
| 262 |
+
}
|
| 263 |
+
}
|
| 264 |
+
|
| 265 |
+
if (genai) {
|
| 266 |
+
if (!genaiConfig.is_object()) {
|
| 267 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
|
| 268 |
+
"Missing QnnGenAiTransformer embedding config");
|
| 269 |
+
}
|
| 270 |
+
validateBackendGenaiConfig(genaiConfig);
|
| 271 |
+
} else {
|
| 272 |
+
if (genaiConfig.is_object()) {
|
| 273 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
|
| 274 |
+
"QnnGenAiTransformer backend config for incorrect backend type: " + type);
|
| 275 |
+
}
|
| 276 |
+
}
|
| 277 |
+
}
|
| 278 |
+
|
| 279 |
+
//=============================================================================
|
| 280 |
+
// Model::Config functions
|
| 281 |
+
//=============================================================================
|
| 282 |
+
|
| 283 |
+
static void validateModelBinaryConfig(const qualla::json& config) {
|
| 284 |
+
if (!config.is_object()) {
|
| 285 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "binary config is not an object");
|
| 286 |
+
}
|
| 287 |
+
|
| 288 |
+
std::set<std::string> mandatoryFields{"version", "ctx-bins"};
|
| 289 |
+
for (const auto& field : mandatoryFields) {
|
| 290 |
+
if (!config.contains(field)) {
|
| 291 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing binary field: " + field);
|
| 292 |
+
}
|
| 293 |
+
}
|
| 294 |
+
|
| 295 |
+
// component is used in the "ENFORCE" macros
|
| 296 |
+
std::string component = "binary";
|
| 297 |
+
|
| 298 |
+
for (auto& item : config.items()) {
|
| 299 |
+
if (item.key() == "version") {
|
| 300 |
+
JSON_ENFORCE_NUMERIC();
|
| 301 |
+
if (item.value().get<int>() != 1) {
|
| 302 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
|
| 303 |
+
"Invalid binary config: unsupported version: " + item.value().dump());
|
| 304 |
+
}
|
| 305 |
+
} else if (item.key() == "ctx-bins") {
|
| 306 |
+
JSON_ENFORCE_ARRAY();
|
| 307 |
+
for (auto& elem : item.value()) {
|
| 308 |
+
if (!elem.is_string()) {
|
| 309 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_VALUE, "ctx-bins must be an array of strings");
|
| 310 |
+
}
|
| 311 |
+
}
|
| 312 |
+
} else {
|
| 313 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown binary config key: " + item.key());
|
| 314 |
+
}
|
| 315 |
+
}
|
| 316 |
+
}
|
| 317 |
+
|
| 318 |
+
static void validateModelLibraryConfig(const qualla::json& config) {
|
| 319 |
+
if (!config.is_object()) {
|
| 320 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "library config is not an object");
|
| 321 |
+
}
|
| 322 |
+
|
| 323 |
+
std::set<std::string> mandatoryFields{"version", "model-bin"};
|
| 324 |
+
for (const auto& field : mandatoryFields) {
|
| 325 |
+
if (!config.contains(field)) {
|
| 326 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing library field: " + field);
|
| 327 |
+
}
|
| 328 |
+
}
|
| 329 |
+
|
| 330 |
+
// component is used in the "ENFORCE" macros
|
| 331 |
+
std::string component = "library";
|
| 332 |
+
|
| 333 |
+
for (auto& item : config.items()) {
|
| 334 |
+
if (item.key() == "version") {
|
| 335 |
+
JSON_ENFORCE_NUMERIC();
|
| 336 |
+
if (item.value().get<int>() != 1) {
|
| 337 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
|
| 338 |
+
"Invalid library config: unsupported version: " + item.value().dump());
|
| 339 |
+
}
|
| 340 |
+
} else if (item.key() == "model-bin") {
|
| 341 |
+
JSON_ENFORCE_STRING();
|
| 342 |
+
} else {
|
| 343 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown library config key: " + item.key());
|
| 344 |
+
}
|
| 345 |
+
}
|
| 346 |
+
}
|
| 347 |
+
|
| 348 |
+
static void validateModelConfig(const qualla::json& config) {
|
| 349 |
+
if (!config.is_object()) {
|
| 350 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "model config is not an object");
|
| 351 |
+
}
|
| 352 |
+
|
| 353 |
+
std::set<std::string> mandatoryFields{"version", "type"};
|
| 354 |
+
for (const auto& field : mandatoryFields) {
|
| 355 |
+
if (!config.contains(field)) {
|
| 356 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing model field: " + field);
|
| 357 |
+
}
|
| 358 |
+
}
|
| 359 |
+
|
| 360 |
+
// component is used in the "ENFORCE" macros
|
| 361 |
+
std::string component = "model";
|
| 362 |
+
|
| 363 |
+
std::string type;
|
| 364 |
+
bool binary = false;
|
| 365 |
+
qualla::json binaryConfig;
|
| 366 |
+
bool library = false;
|
| 367 |
+
qualla::json libraryConfig;
|
| 368 |
+
|
| 369 |
+
for (auto& item : config.items()) {
|
| 370 |
+
if (item.key() == "version") {
|
| 371 |
+
JSON_ENFORCE_NUMERIC();
|
| 372 |
+
if (item.value().get<int>() != 1) {
|
| 373 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
|
| 374 |
+
"Invalid model config: unsupported version: " + item.value().dump());
|
| 375 |
+
}
|
| 376 |
+
} else if (item.key() == "type") {
|
| 377 |
+
JSON_ENFORCE_STRING();
|
| 378 |
+
type = item.value().get<std::string>();
|
| 379 |
+
if (type == "binary") {
|
| 380 |
+
binary = true;
|
| 381 |
+
} else if (type == "library") {
|
| 382 |
+
library = true;
|
| 383 |
+
} else {
|
| 384 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
|
| 385 |
+
"Invalid model config: unsupported type: " + item.value().dump());
|
| 386 |
+
}
|
| 387 |
+
} else if (item.key() == "binary") {
|
| 388 |
+
JSON_ENFORCE_OBJECT();
|
| 389 |
+
binaryConfig = item.value();
|
| 390 |
+
} else if (item.key() == "library") {
|
| 391 |
+
JSON_ENFORCE_OBJECT();
|
| 392 |
+
libraryConfig = item.value();
|
| 393 |
+
} else {
|
| 394 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown model config key: " + item.key());
|
| 395 |
+
}
|
| 396 |
+
}
|
| 397 |
+
|
| 398 |
+
if (binary) {
|
| 399 |
+
if (!binaryConfig.is_object()) {
|
| 400 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing binary model config");
|
| 401 |
+
}
|
| 402 |
+
validateModelBinaryConfig(binaryConfig);
|
| 403 |
+
} else {
|
| 404 |
+
if (binaryConfig.is_object()) {
|
| 405 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
|
| 406 |
+
"binary model config for incorrect model type: " + type);
|
| 407 |
+
}
|
| 408 |
+
}
|
| 409 |
+
|
| 410 |
+
if (library) {
|
| 411 |
+
if (!libraryConfig.is_object()) {
|
| 412 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing library model config");
|
| 413 |
+
}
|
| 414 |
+
validateModelLibraryConfig(libraryConfig);
|
| 415 |
+
} else {
|
| 416 |
+
if (libraryConfig.is_object()) {
|
| 417 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
|
| 418 |
+
"library model config for incorrect model type: " + type);
|
| 419 |
+
}
|
| 420 |
+
}
|
| 421 |
+
}
|
| 422 |
+
|
| 423 |
+
//=============================================================================
|
| 424 |
+
// Engine::Config functions
|
| 425 |
+
//=============================================================================
|
| 426 |
+
|
| 427 |
+
static void validateEngineConfig(const qualla::json& config) {
|
| 428 |
+
if (!config.is_object()) {
|
| 429 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "engine config is not an object");
|
| 430 |
+
}
|
| 431 |
+
|
| 432 |
+
std::set<std::string> mandatoryFields{"version", "backend", "model"};
|
| 433 |
+
for (const auto& field : mandatoryFields) {
|
| 434 |
+
if (!config.contains(field)) {
|
| 435 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing engine field: " + field);
|
| 436 |
+
}
|
| 437 |
+
}
|
| 438 |
+
|
| 439 |
+
// component is used in the "ENFORCE" macros
|
| 440 |
+
std::string component = "engine";
|
| 441 |
+
|
| 442 |
+
for (auto& item : config.items()) {
|
| 443 |
+
if (item.key() == "version") {
|
| 444 |
+
JSON_ENFORCE_NUMERIC();
|
| 445 |
+
if (item.value().get<int>() != 1) {
|
| 446 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
|
| 447 |
+
"Invalid engine config: unsupported version: " + item.value().dump());
|
| 448 |
+
}
|
| 449 |
+
} else if (item.key() == "backend") {
|
| 450 |
+
JSON_ENFORCE_OBJECT();
|
| 451 |
+
validateBackendConfig(item.value());
|
| 452 |
+
} else if (item.key() == "model") {
|
| 453 |
+
JSON_ENFORCE_OBJECT();
|
| 454 |
+
validateModelConfig(item.value());
|
| 455 |
+
} else if (item.key() == "n-threads") {
|
| 456 |
+
JSON_ENFORCE_NUMERIC();
|
| 457 |
+
} else {
|
| 458 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown engine config key: " + item.key());
|
| 459 |
+
}
|
| 460 |
+
}
|
| 461 |
+
}
|
| 462 |
+
|
| 463 |
+
static void translateEngineConfig(const qualla::json& genieEngineConfig,
|
| 464 |
+
qualla::json& quallaEngineConfig) {
|
| 465 |
+
if (genieEngineConfig["version"] == 1) {
|
| 466 |
+
if (genieEngineConfig.contains("n-threads"))
|
| 467 |
+
quallaEngineConfig["n-threads"] = genieEngineConfig["n-threads"];
|
| 468 |
+
|
| 469 |
+
if (genieEngineConfig["backend"]["type"] == "QnnHtp") {
|
| 470 |
+
quallaEngineConfig["type"] = "qnn-htp";
|
| 471 |
+
quallaEngineConfig["model-architecture-type"] = "encoder",
|
| 472 |
+
quallaEngineConfig["backend-lib"] = getLibName("QnnHtp");
|
| 473 |
+
quallaEngineConfig["use-mmap"] = genieEngineConfig["backend"]["QnnHtp"]["use-mmap"];
|
| 474 |
+
quallaEngineConfig["spill-fill-bufsize"] =
|
| 475 |
+
genieEngineConfig["backend"]["QnnHtp"]["spill-fill-bufsize"];
|
| 476 |
+
quallaEngineConfig["pooled-output"] = genieEngineConfig["backend"]["QnnHtp"]["pooled-output"];
|
| 477 |
+
if (genieEngineConfig["backend"]["QnnHtp"].contains("disable-kv-cache")) {
|
| 478 |
+
quallaEngineConfig["disable-kv-cache"] =
|
| 479 |
+
genieEngineConfig["backend"]["QnnHtp"]["disable-kv-cache"];
|
| 480 |
+
}
|
| 481 |
+
// By default, Qualla will default to the async init path.
|
| 482 |
+
// For now, we are forcing async init off unless explicitly
|
| 483 |
+
// specified in the Genie config. It is HTP specific feature only.
|
| 484 |
+
quallaEngineConfig["use-async-Init"] = false;
|
| 485 |
+
if (genieEngineConfig["backend"]["QnnHtp"].contains("allow-async-init")) {
|
| 486 |
+
quallaEngineConfig["use-async-Init"] =
|
| 487 |
+
genieEngineConfig["backend"]["QnnHtp"]["allow-async-init"];
|
| 488 |
+
}
|
| 489 |
+
} else if (genieEngineConfig["backend"]["type"] == "QnnGenAiTransformer") {
|
| 490 |
+
quallaEngineConfig["type"] = "qnn-cpu";
|
| 491 |
+
quallaEngineConfig["model-output"] = "embeddings";
|
| 492 |
+
quallaEngineConfig["backend-lib"] = getLibName("QnnGenAiTransformer");
|
| 493 |
+
if (genieEngineConfig["backend"]["QnnGenAiTransformer"].contains("n-logits")) {
|
| 494 |
+
quallaEngineConfig["n_logits"] =
|
| 495 |
+
genieEngineConfig["backend"]["QnnGenAiTransformer"]["n-logits"];
|
| 496 |
+
}
|
| 497 |
+
if (genieEngineConfig["backend"]["QnnGenAiTransformer"].contains("n-layer")) {
|
| 498 |
+
quallaEngineConfig["n_layer"] =
|
| 499 |
+
genieEngineConfig["backend"]["QnnGenAiTransformer"]["n-layer"];
|
| 500 |
+
}
|
| 501 |
+
if (genieEngineConfig["backend"]["QnnGenAiTransformer"].contains("n-embd")) {
|
| 502 |
+
quallaEngineConfig["n_embd"] =
|
| 503 |
+
genieEngineConfig["backend"]["QnnGenAiTransformer"]["n-embd"];
|
| 504 |
+
}
|
| 505 |
+
if (genieEngineConfig["backend"]["QnnGenAiTransformer"].contains("n-heads")) {
|
| 506 |
+
quallaEngineConfig["n_heads"] =
|
| 507 |
+
genieEngineConfig["backend"]["QnnGenAiTransformer"]["n-heads"];
|
| 508 |
+
}
|
| 509 |
+
}
|
| 510 |
+
|
| 511 |
+
if (genieEngineConfig["backend"].contains("extensions")) {
|
| 512 |
+
quallaEngineConfig["backend-ext-conf"] = genieEngineConfig["backend"]["extensions"];
|
| 513 |
+
}
|
| 514 |
+
|
| 515 |
+
if (genieEngineConfig["model"]["type"] == "binary") {
|
| 516 |
+
quallaEngineConfig["model-list"] = genieEngineConfig["model"]["binary"]["ctx-bins"];
|
| 517 |
+
} else if (genieEngineConfig["model"]["type"] == "library") {
|
| 518 |
+
quallaEngineConfig["model"] = getLibName("QnnGenAiTransformerModel");
|
| 519 |
+
quallaEngineConfig["model-bin-path"] = genieEngineConfig["model"]["library"]["model-bin"];
|
| 520 |
+
quallaEngineConfig["op-package"] =
|
| 521 |
+
getLibName("QnnGenAiTransformerCpuOpPkg") + ":QnnOpPackage_interfaceProvider";
|
| 522 |
+
}
|
| 523 |
+
}
|
| 524 |
+
}
|
| 525 |
+
|
| 526 |
+
//=============================================================================
|
| 527 |
+
// Prompt::Config functions
|
| 528 |
+
//=============================================================================
|
| 529 |
+
|
| 530 |
+
static void validatePromptConfig(const qualla::json& config) {
|
| 531 |
+
if (!config.is_object()) {
|
| 532 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "prompt config is not an object");
|
| 533 |
+
}
|
| 534 |
+
|
| 535 |
+
std::set<std::string> mandatoryFields{"version", "prompt-template"};
|
| 536 |
+
for (const auto& field : mandatoryFields) {
|
| 537 |
+
if (!config.contains(field)) {
|
| 538 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing prompt field: " + field);
|
| 539 |
+
}
|
| 540 |
+
}
|
| 541 |
+
|
| 542 |
+
// component is used in the "ENFORCE" macros
|
| 543 |
+
std::string component = "prompt";
|
| 544 |
+
|
| 545 |
+
for (auto& item : config.items()) {
|
| 546 |
+
if (item.key() == "version") {
|
| 547 |
+
JSON_ENFORCE_NUMERIC();
|
| 548 |
+
if (item.value().get<int>() != 1) {
|
| 549 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
|
| 550 |
+
"Invalid context config: unsupported version: " + item.value().dump());
|
| 551 |
+
}
|
| 552 |
+
} else if (item.key() == "prompt-template") {
|
| 553 |
+
JSON_ENFORCE_ARRAY();
|
| 554 |
+
for (auto& elem : item.value()) {
|
| 555 |
+
if (!elem.is_string()) {
|
| 556 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_VALUE, "prompt tags must be an array of strings");
|
| 557 |
+
}
|
| 558 |
+
}
|
| 559 |
+
} else {
|
| 560 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown context config key: " + item.key());
|
| 561 |
+
}
|
| 562 |
+
}
|
| 563 |
+
}
|
| 564 |
+
|
| 565 |
+
static void translatePromptConfig(const qualla::json& genieConfig, qualla::json& quallaConfig) {
|
| 566 |
+
quallaConfig["tags"] = genieConfig["prompt-template"];
|
| 567 |
+
}
|
| 568 |
+
|
| 569 |
+
//=============================================================================
|
| 570 |
+
// Embedding::Config functions
|
| 571 |
+
//=============================================================================
|
| 572 |
+
|
| 573 |
+
qnn::util::HandleManager<Embedding::Config> Embedding::Config::s_manager;
|
| 574 |
+
|
| 575 |
+
GenieEmbeddingConfig_Handle_t Embedding::Config::add(std::shared_ptr<Embedding::Config> config) {
|
| 576 |
+
return (GenieEmbeddingConfig_Handle_t)s_manager.add(config);
|
| 577 |
+
}
|
| 578 |
+
|
| 579 |
+
std::shared_ptr<Embedding::Config> Embedding::Config::get(GenieEmbeddingConfig_Handle_t handle) {
|
| 580 |
+
return s_manager.get((qnn::util::Handle_t)handle);
|
| 581 |
+
}
|
| 582 |
+
|
| 583 |
+
void Embedding::Config::remove(GenieEmbeddingConfig_Handle_t handle) {
|
| 584 |
+
s_manager.remove((qnn::util::Handle_t)handle);
|
| 585 |
+
}
|
| 586 |
+
|
| 587 |
+
static void validateEmbeddingConfig(const qualla::json& config) {
|
| 588 |
+
if (!config.is_object()) {
|
| 589 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Embedding config is not an object");
|
| 590 |
+
}
|
| 591 |
+
|
| 592 |
+
std::set<std::string> mandatoryFields{"version", "context", "tokenizer", "engine"};
|
| 593 |
+
for (const auto& field : mandatoryFields) {
|
| 594 |
+
if (!config.contains(field)) {
|
| 595 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing embedding field: " + field);
|
| 596 |
+
}
|
| 597 |
+
}
|
| 598 |
+
|
| 599 |
+
// component is used in the "ENFORCE" macros
|
| 600 |
+
std::string component = "embedding";
|
| 601 |
+
|
| 602 |
+
for (auto& item : config.items()) {
|
| 603 |
+
if (item.key() == "version") {
|
| 604 |
+
JSON_ENFORCE_NUMERIC();
|
| 605 |
+
if (item.value().get<int>() != 1) {
|
| 606 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
|
| 607 |
+
"Invalid embedding config: unsupported version: " + item.value().dump());
|
| 608 |
+
}
|
| 609 |
+
} else if (item.key() == "context") {
|
| 610 |
+
JSON_ENFORCE_OBJECT();
|
| 611 |
+
validateContextConfig(item.value());
|
| 612 |
+
} else if (item.key() == "tokenizer") {
|
| 613 |
+
JSON_ENFORCE_OBJECT();
|
| 614 |
+
validateTokenizerConfig(item.value());
|
| 615 |
+
} else if (item.key() == "prompt") { // optional parameter
|
| 616 |
+
JSON_ENFORCE_OBJECT();
|
| 617 |
+
validatePromptConfig(item.value());
|
| 618 |
+
} else if (item.key() == "truncate-input") { // optional parameter
|
| 619 |
+
JSON_ENFORCE_BOOLEAN();
|
| 620 |
+
} else if (item.key() == "engine") {
|
| 621 |
+
JSON_ENFORCE_OBJECT();
|
| 622 |
+
validateEngineConfig(config["engine"]);
|
| 623 |
+
} else {
|
| 624 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
|
| 625 |
+
"Unknown embedding config key: " + item.key());
|
| 626 |
+
}
|
| 627 |
+
}
|
| 628 |
+
}
|
| 629 |
+
|
| 630 |
+
static void translateEmbeddingConfig(const qualla::json& genieConfig, qualla::json& quallaConfig) {
|
| 631 |
+
translateContextConfig(genieConfig["context"], quallaConfig["context"]);
|
| 632 |
+
translatePromptConfig(genieConfig["prompt"], quallaConfig["prompt"]);
|
| 633 |
+
translateTokenizerConfig(genieConfig["tokenizer"], quallaConfig);
|
| 634 |
+
translateEngineConfig(genieConfig["engine"], quallaConfig["engine"]);
|
| 635 |
+
|
| 636 |
+
if (genieConfig.contains(
|
| 637 |
+
"truncate-input")) { // to allow truncation of input incase it exceeds the context.
|
| 638 |
+
quallaConfig["truncate-input"] = genieConfig["truncate-input"];
|
| 639 |
+
}
|
| 640 |
+
}
|
| 641 |
+
|
| 642 |
+
Embedding::Config::Config(const char* configStr) {
|
| 643 |
+
qualla::json config;
|
| 644 |
+
|
| 645 |
+
{
|
| 646 |
+
std::set<qualla::json> keys;
|
| 647 |
+
|
| 648 |
+
auto callback = [&keys](int depth, qualla::json::parse_event_t event, qualla::json& parsed) {
|
| 649 |
+
if ((depth == 1) && (event == qualla::json::parse_event_t::key)) {
|
| 650 |
+
if (keys.count(parsed) > 0) {
|
| 651 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
|
| 652 |
+
"Multiple embedding config key: " + parsed.dump());
|
| 653 |
+
}
|
| 654 |
+
keys.insert(parsed);
|
| 655 |
+
}
|
| 656 |
+
return true;
|
| 657 |
+
};
|
| 658 |
+
|
| 659 |
+
config = qualla::json::parse(configStr, callback);
|
| 660 |
+
}
|
| 661 |
+
|
| 662 |
+
if (!config.is_object()) {
|
| 663 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Embedding config is not an object");
|
| 664 |
+
}
|
| 665 |
+
|
| 666 |
+
std::set<std::string> mandatoryFields{"embedding"};
|
| 667 |
+
for (const auto& field : mandatoryFields) {
|
| 668 |
+
if (!config.contains(field)) {
|
| 669 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing embedding field: " + field);
|
| 670 |
+
}
|
| 671 |
+
}
|
| 672 |
+
|
| 673 |
+
// component is used in the "ENFORCE" macros
|
| 674 |
+
std::string component = "embedding";
|
| 675 |
+
|
| 676 |
+
for (auto& item : config.items()) {
|
| 677 |
+
if (item.key() == "embedding") {
|
| 678 |
+
JSON_ENFORCE_OBJECT();
|
| 679 |
+
validateEmbeddingConfig(item.value());
|
| 680 |
+
} else {
|
| 681 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
|
| 682 |
+
"Unknown embedding config key: " + item.key());
|
| 683 |
+
}
|
| 684 |
+
}
|
| 685 |
+
m_config = config;
|
| 686 |
+
}
|
| 687 |
+
|
| 688 |
+
qualla::json Embedding::Config::getJson() const { return m_config; }
|
| 689 |
+
|
| 690 |
+
//=============================================================================
|
| 691 |
+
// Embedding functions
|
| 692 |
+
//=============================================================================
|
| 693 |
+
|
| 694 |
+
qnn::util::HandleManager<Embedding> Embedding::s_manager;
|
| 695 |
+
std::atomic<std::uint32_t> Embedding::s_nameCounter{0u};
|
| 696 |
+
|
| 697 |
+
GenieEmbedding_Handle_t Embedding::add(std::shared_ptr<Embedding> embedding) {
|
| 698 |
+
return (GenieEmbedding_Handle_t)s_manager.add(embedding);
|
| 699 |
+
}
|
| 700 |
+
|
| 701 |
+
std::shared_ptr<Embedding> Embedding::get(GenieEmbedding_Handle_t handle) {
|
| 702 |
+
return s_manager.get((qnn::util::Handle_t)handle);
|
| 703 |
+
}
|
| 704 |
+
|
| 705 |
+
void Embedding::remove(GenieEmbedding_Handle_t handle) {
|
| 706 |
+
s_manager.remove((qnn::util::Handle_t)handle);
|
| 707 |
+
}
|
| 708 |
+
|
| 709 |
+
Embedding::Embedding(std::shared_ptr<Config> config) {
|
| 710 |
+
auto env = qualla::Env::create(qualla::json{});
|
| 711 |
+
qualla::json quallaConfig;
|
| 712 |
+
translateEmbeddingConfig(config->getJson()["embedding"], quallaConfig);
|
| 713 |
+
m_quallaEmbedding = qualla::Embedding::create(
|
| 714 |
+
env, "embedding" + std::to_string(s_nameCounter.fetch_add(1u)), quallaConfig);
|
| 715 |
+
if (!m_quallaEmbedding) {
|
| 716 |
+
throw Exception(GENIE_STATUS_ERROR_MEM_ALLOC, "Could not create a embedding object");
|
| 717 |
+
}
|
| 718 |
+
}
|
| 719 |
+
|
| 720 |
+
int32_t Embedding::generate(const char* queryStr,
|
| 721 |
+
GenieEmbedding_GenerateCallback_t callback,
|
| 722 |
+
const void* userData) {
|
| 723 |
+
std::string query(queryStr);
|
| 724 |
+
std::vector<float> outputEmbedding;
|
| 725 |
+
bool status = false;
|
| 726 |
+
status = m_quallaEmbedding->query(query, outputEmbedding);
|
| 727 |
+
if (status) {
|
| 728 |
+
std::vector<uint32_t> dimensions;
|
| 729 |
+
m_quallaEmbedding->output_dimensions(dimensions);
|
| 730 |
+
callback(dimensions.data(), dimensions.size(), outputEmbedding.data(), userData);
|
| 731 |
+
qualla::Embedding::KPIs kpis = m_quallaEmbedding->kpis();
|
| 732 |
+
printf(
|
| 733 |
+
"\n\n[KPIS]:\nInit Time: %zu us\nPrompt Processing Time: %zu us, Prompt Processing Rate : "
|
| 734 |
+
"%f toks/sec\n",
|
| 735 |
+
kpis.init.total_usec,
|
| 736 |
+
kpis.prompt.last_usec,
|
| 737 |
+
kpis.tps.prompt);
|
| 738 |
+
}
|
| 739 |
+
return (status) ? (GENIE_STATUS_SUCCESS) : (GENIE_STATUS_ERROR_GENERATE_FAILED);
|
| 740 |
+
}
|
Genie/Genie/src/Embedding.hpp
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//==============================================================================
|
| 2 |
+
//
|
| 3 |
+
// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
|
| 4 |
+
// All Rights Reserved.
|
| 5 |
+
// Confidential and Proprietary - Qualcomm Technologies, Inc.
|
| 6 |
+
//
|
| 7 |
+
//==============================================================================
|
| 8 |
+
|
| 9 |
+
#pragma once
|
| 10 |
+
|
| 11 |
+
#include <atomic>
|
| 12 |
+
#include <memory>
|
| 13 |
+
|
| 14 |
+
#include "GenieEmbedding.h"
|
| 15 |
+
#include "Util/HandleManager.hpp"
|
| 16 |
+
#include "qualla/embedding.hpp"
|
| 17 |
+
|
| 18 |
+
namespace genie {
|
| 19 |
+
|
| 20 |
+
class Embedding {
|
| 21 |
+
public:
|
| 22 |
+
class Config {
|
| 23 |
+
public:
|
| 24 |
+
static GenieEmbeddingConfig_Handle_t add(std::shared_ptr<Config> config);
|
| 25 |
+
static std::shared_ptr<Config> get(GenieEmbeddingConfig_Handle_t handle);
|
| 26 |
+
static void remove(GenieEmbeddingConfig_Handle_t handle);
|
| 27 |
+
|
| 28 |
+
Config(const char* configStr);
|
| 29 |
+
qualla::json getJson() const;
|
| 30 |
+
|
| 31 |
+
private:
|
| 32 |
+
static qnn::util::HandleManager<Config> s_manager;
|
| 33 |
+
qualla::json m_config;
|
| 34 |
+
};
|
| 35 |
+
|
| 36 |
+
static GenieEmbedding_Handle_t add(std::shared_ptr<Embedding> embedding);
|
| 37 |
+
static std::shared_ptr<Embedding> get(GenieEmbedding_Handle_t handle);
|
| 38 |
+
static void remove(GenieEmbedding_Handle_t handle);
|
| 39 |
+
|
| 40 |
+
Embedding(std::shared_ptr<Config> config);
|
| 41 |
+
|
| 42 |
+
Embedding(const Embedding&) = delete;
|
| 43 |
+
Embedding& operator=(const Embedding&) = delete;
|
| 44 |
+
Embedding(Embedding&&) = delete;
|
| 45 |
+
Embedding& operator=(Embedding&&) = delete;
|
| 46 |
+
|
| 47 |
+
int32_t generate(const char* queryStr,
|
| 48 |
+
GenieEmbedding_GenerateCallback_t callback,
|
| 49 |
+
const void* userData);
|
| 50 |
+
|
| 51 |
+
private:
|
| 52 |
+
std::unique_ptr<qualla::Embedding> m_quallaEmbedding;
|
| 53 |
+
static qnn::util::HandleManager<Embedding> s_manager;
|
| 54 |
+
static std::atomic<std::uint32_t> s_nameCounter;
|
| 55 |
+
};
|
| 56 |
+
} // namespace genie
|
Genie/Genie/src/Exception.hpp
CHANGED
|
@@ -9,6 +9,7 @@
|
|
| 9 |
#pragma once
|
| 10 |
|
| 11 |
#include <exception>
|
|
|
|
| 12 |
#include <string>
|
| 13 |
|
| 14 |
#include "GenieCommon.h"
|
|
|
|
| 9 |
#pragma once
|
| 10 |
|
| 11 |
#include <exception>
|
| 12 |
+
#include <stdexcept>
|
| 13 |
#include <string>
|
| 14 |
|
| 15 |
#include "GenieCommon.h"
|
Genie/Genie/src/GenieDialog.cpp
CHANGED
|
@@ -232,6 +232,24 @@ Genie_Status_t GenieDialog_tokenQuery(const GenieDialog_Handle_t dialogHandle,
|
|
| 232 |
return status;
|
| 233 |
}
|
| 234 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 235 |
GENIE_API
|
| 236 |
Genie_Status_t GenieDialog_free(const GenieDialog_Handle_t dialogHandle) {
|
| 237 |
try {
|
|
@@ -246,4 +264,4 @@ Genie_Status_t GenieDialog_free(const GenieDialog_Handle_t dialogHandle) {
|
|
| 246 |
return GENIE_STATUS_ERROR_GENERAL;
|
| 247 |
}
|
| 248 |
return GENIE_STATUS_SUCCESS;
|
| 249 |
-
}
|
|
|
|
| 232 |
return status;
|
| 233 |
}
|
| 234 |
|
| 235 |
+
GENIE_API
|
| 236 |
+
Genie_Status_t GenieDialog_getSampler(const GenieDialog_Handle_t dialogHandle,
|
| 237 |
+
GenieSampler_Handle_t* dialogSamplerHandle) {
|
| 238 |
+
try {
|
| 239 |
+
GENIE_ENSURE(dialogHandle, GENIE_STATUS_ERROR_INVALID_HANDLE);
|
| 240 |
+
auto dialog = genie::Dialog::get(dialogHandle);
|
| 241 |
+
GENIE_ENSURE(dialog, GENIE_STATUS_ERROR_INVALID_HANDLE);
|
| 242 |
+
GENIE_ENSURE(dialogSamplerHandle, GENIE_STATUS_ERROR_INVALID_ARGUMENT);
|
| 243 |
+
*dialogSamplerHandle = genie::Dialog::getSamplerHandle(dialog);
|
| 244 |
+
GENIE_ENSURE(*dialogSamplerHandle, GENIE_STATUS_ERROR_INVALID_HANDLE);
|
| 245 |
+
} catch (const std::exception& e) {
|
| 246 |
+
std::cerr << e.what() << std::endl;
|
| 247 |
+
return GENIE_STATUS_ERROR_GET_HANDLE_FAILED;
|
| 248 |
+
}
|
| 249 |
+
|
| 250 |
+
return GENIE_STATUS_SUCCESS;
|
| 251 |
+
}
|
| 252 |
+
|
| 253 |
GENIE_API
|
| 254 |
Genie_Status_t GenieDialog_free(const GenieDialog_Handle_t dialogHandle) {
|
| 255 |
try {
|
|
|
|
| 264 |
return GENIE_STATUS_ERROR_GENERAL;
|
| 265 |
}
|
| 266 |
return GENIE_STATUS_SUCCESS;
|
| 267 |
+
}
|
Genie/Genie/src/GenieEmbedding.cpp
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//=============================================================================
|
| 2 |
+
//
|
| 3 |
+
// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
|
| 4 |
+
// All Rights Reserved.
|
| 5 |
+
// Confidential and Proprietary - Qualcomm Technologies, Inc.
|
| 6 |
+
//
|
| 7 |
+
//=============================================================================
|
| 8 |
+
|
| 9 |
+
#include "Embedding.hpp"
|
| 10 |
+
#include "Exception.hpp"
|
| 11 |
+
#include "GenieEmbedding.h"
|
| 12 |
+
#include "Macro.hpp"
|
| 13 |
+
#include "Util/HandleManager.hpp"
|
| 14 |
+
#include "qualla/detail/json.hpp"
|
| 15 |
+
|
| 16 |
+
using namespace genie;
|
| 17 |
+
|
| 18 |
+
GENIE_API
|
| 19 |
+
Genie_Status_t GenieEmbeddingConfig_createFromJson(const char* str,
|
| 20 |
+
GenieEmbeddingConfig_Handle_t* configHandle) {
|
| 21 |
+
try {
|
| 22 |
+
GENIE_ENSURE(str, GENIE_STATUS_ERROR_INVALID_ARGUMENT);
|
| 23 |
+
GENIE_ENSURE(configHandle, GENIE_STATUS_ERROR_INVALID_ARGUMENT);
|
| 24 |
+
auto config = std::make_shared<Embedding::Config>(str);
|
| 25 |
+
GENIE_ENSURE(config, GENIE_STATUS_ERROR_MEM_ALLOC);
|
| 26 |
+
*configHandle = genie::Embedding::Config::add(config);
|
| 27 |
+
} catch (const qualla::json::parse_error& e) {
|
| 28 |
+
std::cerr << e.what() << std::endl;
|
| 29 |
+
return GENIE_STATUS_ERROR_JSON_FORMAT;
|
| 30 |
+
} catch (const Exception& e) {
|
| 31 |
+
std::cerr << e.what() << std::endl;
|
| 32 |
+
return e.status();
|
| 33 |
+
} catch (const std::exception& e) {
|
| 34 |
+
std::cerr << e.what() << std::endl;
|
| 35 |
+
return GENIE_STATUS_ERROR_GENERAL;
|
| 36 |
+
}
|
| 37 |
+
return GENIE_STATUS_SUCCESS;
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
GENIE_API
|
| 41 |
+
Genie_Status_t GenieEmbeddingConfig_free(const GenieEmbeddingConfig_Handle_t configHandle) {
|
| 42 |
+
try {
|
| 43 |
+
GENIE_ENSURE(configHandle, GENIE_STATUS_ERROR_INVALID_HANDLE);
|
| 44 |
+
{
|
| 45 |
+
// Check if the embedding actually exists
|
| 46 |
+
auto configObj = genie::Embedding::Config::get(configHandle);
|
| 47 |
+
GENIE_ENSURE(configObj, GENIE_STATUS_ERROR_INVALID_HANDLE);
|
| 48 |
+
}
|
| 49 |
+
genie::Embedding::Config::remove(configHandle);
|
| 50 |
+
} catch (const std::exception& e) {
|
| 51 |
+
return GENIE_STATUS_ERROR_GENERAL;
|
| 52 |
+
}
|
| 53 |
+
return GENIE_STATUS_SUCCESS;
|
| 54 |
+
}
|
| 55 |
+
|
| 56 |
+
GENIE_API
|
| 57 |
+
Genie_Status_t GenieEmbedding_create(const GenieEmbeddingConfig_Handle_t configHandle,
|
| 58 |
+
GenieEmbedding_Handle_t* embeddingHandle) {
|
| 59 |
+
try {
|
| 60 |
+
GENIE_ENSURE(embeddingHandle, GENIE_STATUS_ERROR_INVALID_ARGUMENT);
|
| 61 |
+
|
| 62 |
+
// Get config object
|
| 63 |
+
auto configObj = genie::Embedding::Config::get(configHandle);
|
| 64 |
+
GENIE_ENSURE(configObj, GENIE_STATUS_ERROR_INVALID_HANDLE);
|
| 65 |
+
|
| 66 |
+
// Create embedding
|
| 67 |
+
auto embedding = std::make_shared<genie::Embedding>(configObj);
|
| 68 |
+
GENIE_ENSURE(embedding, GENIE_STATUS_ERROR_MEM_ALLOC);
|
| 69 |
+
|
| 70 |
+
// Create Handle
|
| 71 |
+
*embeddingHandle = genie::Embedding::add(embedding);
|
| 72 |
+
} catch (const std::exception& e) {
|
| 73 |
+
std::cerr << e.what() << std::endl;
|
| 74 |
+
return GENIE_STATUS_ERROR_GENERAL;
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
// Return SUCCESS
|
| 78 |
+
return GENIE_STATUS_SUCCESS;
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
GENIE_API
|
| 82 |
+
Genie_Status_t GenieEmbedding_generate(const GenieEmbedding_Handle_t embeddingHandle,
|
| 83 |
+
const char* queryStr,
|
| 84 |
+
const GenieEmbedding_GenerateCallback_t callback,
|
| 85 |
+
const void* userData) {
|
| 86 |
+
int32_t status;
|
| 87 |
+
|
| 88 |
+
try {
|
| 89 |
+
GENIE_ENSURE(embeddingHandle, GENIE_STATUS_ERROR_INVALID_HANDLE);
|
| 90 |
+
auto embedding = genie::Embedding::get(embeddingHandle);
|
| 91 |
+
GENIE_ENSURE(embedding, GENIE_STATUS_ERROR_INVALID_HANDLE);
|
| 92 |
+
GENIE_ENSURE(queryStr, GENIE_STATUS_ERROR_INVALID_ARGUMENT);
|
| 93 |
+
GENIE_ENSURE(callback, GENIE_STATUS_ERROR_INVALID_ARGUMENT);
|
| 94 |
+
|
| 95 |
+
status = embedding->generate(queryStr, callback, userData);
|
| 96 |
+
} catch (const std::exception& e) {
|
| 97 |
+
std::cerr << e.what() << std::endl;
|
| 98 |
+
return GENIE_STATUS_ERROR_GENERAL;
|
| 99 |
+
}
|
| 100 |
+
|
| 101 |
+
return status;
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
GENIE_API
|
| 105 |
+
Genie_Status_t GenieEmbedding_free(const GenieEmbedding_Handle_t embeddingHandle) {
|
| 106 |
+
try {
|
| 107 |
+
GENIE_ENSURE(embeddingHandle, GENIE_STATUS_ERROR_INVALID_HANDLE);
|
| 108 |
+
{
|
| 109 |
+
// Check if the embedding actually exists
|
| 110 |
+
auto embedding = genie::Embedding::get(embeddingHandle);
|
| 111 |
+
GENIE_ENSURE(embedding, GENIE_STATUS_ERROR_INVALID_HANDLE);
|
| 112 |
+
}
|
| 113 |
+
genie::Embedding::remove(embeddingHandle);
|
| 114 |
+
} catch (const std::exception& e) {
|
| 115 |
+
return GENIE_STATUS_ERROR_GENERAL;
|
| 116 |
+
}
|
| 117 |
+
return GENIE_STATUS_SUCCESS;
|
| 118 |
+
}
|
Genie/Genie/src/GenieSampler.cpp
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//=============================================================================
|
| 2 |
+
//
|
| 3 |
+
// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
|
| 4 |
+
// All Rights Reserved.
|
| 5 |
+
// Confidential and Proprietary - Qualcomm Technologies, Inc.
|
| 6 |
+
//
|
| 7 |
+
//=============================================================================
|
| 8 |
+
|
| 9 |
+
#include <iostream>
|
| 10 |
+
|
| 11 |
+
#include "Exception.hpp"
|
| 12 |
+
#include "GenieSampler.h"
|
| 13 |
+
#include "Macro.hpp"
|
| 14 |
+
#include "Sampler.hpp"
|
| 15 |
+
#include "Util/HandleManager.hpp"
|
| 16 |
+
#include "qualla/detail/json.hpp"
|
| 17 |
+
|
| 18 |
+
using namespace genie;
|
| 19 |
+
GENIE_API
|
| 20 |
+
Genie_Status_t GenieSamplerConfig_createFromJson(const char* str,
|
| 21 |
+
GenieSamplerConfig_Handle_t* configHandle) {
|
| 22 |
+
try {
|
| 23 |
+
GENIE_ENSURE(str, GENIE_STATUS_ERROR_INVALID_ARGUMENT);
|
| 24 |
+
GENIE_ENSURE(configHandle, GENIE_STATUS_ERROR_INVALID_ARGUMENT);
|
| 25 |
+
auto config = std::make_shared<Sampler::Sampler::SamplerConfig>(str);
|
| 26 |
+
GENIE_ENSURE(config, GENIE_STATUS_ERROR_MEM_ALLOC);
|
| 27 |
+
*configHandle = Sampler::Sampler::SamplerConfig::add(config);
|
| 28 |
+
} catch (const qualla::json::parse_error& e) {
|
| 29 |
+
std::cerr << e.what() << std::endl;
|
| 30 |
+
return GENIE_STATUS_ERROR_JSON_FORMAT;
|
| 31 |
+
} catch (const Exception& e) {
|
| 32 |
+
std::cerr << e.what() << std::endl;
|
| 33 |
+
return e.status();
|
| 34 |
+
} catch (const std::exception& e) {
|
| 35 |
+
std::cerr << e.what() << std::endl;
|
| 36 |
+
return GENIE_STATUS_ERROR_GENERAL;
|
| 37 |
+
}
|
| 38 |
+
return GENIE_STATUS_SUCCESS;
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
GENIE_API
|
| 42 |
+
Genie_Status_t GenieSamplerConfig_free(const GenieSamplerConfig_Handle_t configHandle) {
|
| 43 |
+
try {
|
| 44 |
+
GENIE_ENSURE(configHandle, GENIE_STATUS_ERROR_INVALID_HANDLE);
|
| 45 |
+
{
|
| 46 |
+
// Check if the dialog actually exists
|
| 47 |
+
auto configObj = Sampler::SamplerConfig::get(configHandle);
|
| 48 |
+
GENIE_ENSURE(configObj, GENIE_STATUS_ERROR_INVALID_HANDLE);
|
| 49 |
+
}
|
| 50 |
+
Sampler::SamplerConfig::remove(configHandle);
|
| 51 |
+
} catch (const std::exception& e) {
|
| 52 |
+
return GENIE_STATUS_ERROR_GENERAL;
|
| 53 |
+
}
|
| 54 |
+
return GENIE_STATUS_SUCCESS;
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
GENIE_API
|
| 58 |
+
Genie_Status_t GenieSamplerConfig_setParam(const GenieSamplerConfig_Handle_t configHandle,
|
| 59 |
+
const char* keyStr,
|
| 60 |
+
const char* valueStr) {
|
| 61 |
+
try {
|
| 62 |
+
GENIE_ENSURE(configHandle, GENIE_STATUS_ERROR_INVALID_HANDLE);
|
| 63 |
+
auto samplerConfig = Sampler::SamplerConfig::get(configHandle);
|
| 64 |
+
GENIE_ENSURE(samplerConfig, GENIE_STATUS_ERROR_INVALID_HANDLE);
|
| 65 |
+
samplerConfig->setParam(keyStr, valueStr);
|
| 66 |
+
} catch (const std::exception& e) {
|
| 67 |
+
std::cerr << e.what() << std::endl;
|
| 68 |
+
return GENIE_STATUS_ERROR_SET_PARAMS_FAILED;
|
| 69 |
+
}
|
| 70 |
+
return GENIE_STATUS_SUCCESS;
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
GENIE_API
|
| 74 |
+
Genie_Status_t GenieSampler_applyConfig(const GenieSampler_Handle_t samplerHandle,
|
| 75 |
+
const GenieSamplerConfig_Handle_t configHandle) {
|
| 76 |
+
try {
|
| 77 |
+
GENIE_ENSURE(samplerHandle, GENIE_STATUS_ERROR_INVALID_HANDLE);
|
| 78 |
+
GENIE_ENSURE(configHandle, GENIE_STATUS_ERROR_INVALID_HANDLE);
|
| 79 |
+
|
| 80 |
+
auto sampler = Sampler::get(samplerHandle);
|
| 81 |
+
GENIE_ENSURE(sampler, GENIE_STATUS_ERROR_INVALID_HANDLE);
|
| 82 |
+
|
| 83 |
+
auto samplerConfig = Sampler::SamplerConfig::get(configHandle);
|
| 84 |
+
GENIE_ENSURE(samplerConfig, GENIE_STATUS_ERROR_INVALID_HANDLE);
|
| 85 |
+
|
| 86 |
+
sampler->applyConfig(samplerConfig->getJson());
|
| 87 |
+
|
| 88 |
+
} catch (const std::exception& e) {
|
| 89 |
+
std::cerr << e.what() << std::endl;
|
| 90 |
+
return GENIE_STATUS_ERROR_APPLY_CONFIG_FAILED;
|
| 91 |
+
}
|
| 92 |
+
return GENIE_STATUS_SUCCESS;
|
| 93 |
+
}
|
Genie/Genie/src/Macro.hpp
CHANGED
|
@@ -8,6 +8,8 @@
|
|
| 8 |
|
| 9 |
#pragma once
|
| 10 |
|
|
|
|
|
|
|
| 11 |
//======================================================================================================================
|
| 12 |
// Error generation macros
|
| 13 |
//======================================================================================================================
|
|
|
|
| 8 |
|
| 9 |
#pragma once
|
| 10 |
|
| 11 |
+
#define ENABLE_DEBUG_LOGS 0
|
| 12 |
+
|
| 13 |
//======================================================================================================================
|
| 14 |
// Error generation macros
|
| 15 |
//======================================================================================================================
|
Genie/Genie/src/Sampler.cpp
ADDED
|
@@ -0,0 +1,275 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//==============================================================================
|
| 2 |
+
//
|
| 3 |
+
// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
|
| 4 |
+
// All Rights Reserved.
|
| 5 |
+
// Confidential and Proprietary - Qualcomm Technologies, Inc.
|
| 6 |
+
//
|
| 7 |
+
//==============================================================================
|
| 8 |
+
#include <exception>
|
| 9 |
+
#include <set>
|
| 10 |
+
|
| 11 |
+
#include "Exception.hpp"
|
| 12 |
+
#include "Macro.hpp"
|
| 13 |
+
#include "Sampler.hpp"
|
| 14 |
+
#include "qualla/detail/json.hpp"
|
| 15 |
+
|
| 16 |
+
using namespace genie;
|
| 17 |
+
|
| 18 |
+
//=============================================================================
|
| 19 |
+
// Sampler functions
|
| 20 |
+
//=============================================================================
|
| 21 |
+
|
| 22 |
+
qnn::util::HandleManager<Sampler> Sampler::s_manager;
|
| 23 |
+
|
| 24 |
+
GenieSampler_Handle_t Sampler::add(std::shared_ptr<Sampler> config) {
|
| 25 |
+
return (GenieSampler_Handle_t)s_manager.add(config);
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
std::shared_ptr<Sampler> Sampler::get(GenieSampler_Handle_t handle) {
|
| 29 |
+
return s_manager.get((qnn::util::Handle_t)handle);
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
void Sampler::remove(GenieSampler_Handle_t handle) {
|
| 33 |
+
s_manager.remove((qnn::util::Handle_t)handle);
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
Sampler::Sampler(qualla::json& origJson,
|
| 37 |
+
std::vector<std::reference_wrapper<qualla::Sampler>>& quallaSamplers)
|
| 38 |
+
: m_origJson(origJson), m_quallaSamplers(quallaSamplers) {}
|
| 39 |
+
|
| 40 |
+
void Sampler::applyConfig(qualla::json samplerConfigJson) {
|
| 41 |
+
m_origJson["sampler"]["seed"] = qualla::Config::optional<int32_t>(
|
| 42 |
+
samplerConfigJson["sampler"], "seed", m_origJson["sampler"]["seed"]);
|
| 43 |
+
m_origJson["sampler"]["temp"] = qualla::Config::optional<float>(
|
| 44 |
+
samplerConfigJson["sampler"], "temp", m_origJson["sampler"]["temp"]);
|
| 45 |
+
m_origJson["sampler"]["top-k"] = qualla::Config::optional<size_t>(
|
| 46 |
+
samplerConfigJson["sampler"], "top-k", m_origJson["sampler"]["top-k"]);
|
| 47 |
+
m_origJson["sampler"]["top-p"] = qualla::Config::optional<float>(
|
| 48 |
+
samplerConfigJson["sampler"], "top-p", m_origJson["sampler"]["top-p"]);
|
| 49 |
+
m_origJson["sampler"]["version"] =
|
| 50 |
+
qualla::Config::optional<int32_t>(samplerConfigJson["sampler"], "version", 1);
|
| 51 |
+
m_origJson["sampler"]["type"] = "basic";
|
| 52 |
+
|
| 53 |
+
#if ENABLE_DEBUG_LOGS
|
| 54 |
+
std::cout << "Updated sampler config: " << std::endl;
|
| 55 |
+
std::cout << "temp: " << m_origJson["sampler"]["temp"].get<double>() << std::endl;
|
| 56 |
+
std::cout << "top-k: " << m_origJson["sampler"]["top-k"] << std::endl;
|
| 57 |
+
std::cout << "top-p: " << m_origJson["sampler"]["top-p"].get<double>() << std::endl;
|
| 58 |
+
std::cout << "seed: " << m_origJson["sampler"]["seed"] << std::endl;
|
| 59 |
+
#endif
|
| 60 |
+
// Loop through the live qualla sampler instances and update the parameters
|
| 61 |
+
for (auto& quallaSampler : m_quallaSamplers) {
|
| 62 |
+
quallaSampler.get().applyConfig(m_origJson["sampler"]);
|
| 63 |
+
}
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
//=============================================================================
|
| 67 |
+
// Sampler::SamplerConfig functions
|
| 68 |
+
//=============================================================================
|
| 69 |
+
|
| 70 |
+
qnn::util::HandleManager<Sampler::SamplerConfig> Sampler::SamplerConfig::s_manager;
|
| 71 |
+
|
| 72 |
+
GenieSamplerConfig_Handle_t Sampler::SamplerConfig::add(
|
| 73 |
+
std::shared_ptr<Sampler::SamplerConfig> config) {
|
| 74 |
+
return (GenieSamplerConfig_Handle_t)s_manager.add(config);
|
| 75 |
+
}
|
| 76 |
+
|
| 77 |
+
std::shared_ptr<Sampler::SamplerConfig> Sampler::SamplerConfig::get(
|
| 78 |
+
GenieSamplerConfig_Handle_t handle) {
|
| 79 |
+
return s_manager.get((qnn::util::Handle_t)handle);
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
void Sampler::SamplerConfig::remove(GenieSamplerConfig_Handle_t handle) {
|
| 83 |
+
s_manager.remove((qnn::util::Handle_t)handle);
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
Sampler::SamplerConfig::SamplerConfig(const char* configStr) {
|
| 87 |
+
qualla::json quallaConfig;
|
| 88 |
+
qualla::json config;
|
| 89 |
+
{
|
| 90 |
+
std::set<qualla::json> keys;
|
| 91 |
+
|
| 92 |
+
auto callback = [&keys](int depth, qualla::json::parse_event_t event, qualla::json& parsed) {
|
| 93 |
+
if ((depth == 1) && (event == qualla::json::parse_event_t::key)) {
|
| 94 |
+
if (keys.count(parsed) > 0) {
|
| 95 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
|
| 96 |
+
"Multiple sampler config key: " + parsed.dump());
|
| 97 |
+
}
|
| 98 |
+
keys.insert(parsed);
|
| 99 |
+
}
|
| 100 |
+
return true;
|
| 101 |
+
};
|
| 102 |
+
|
| 103 |
+
config = qualla::json::parse(configStr, callback);
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
if (!config.is_object()) {
|
| 107 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Sampler config is not an object");
|
| 108 |
+
}
|
| 109 |
+
|
| 110 |
+
if (!config.contains("sampler")) {
|
| 111 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing field: sampler");
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
// component is used in the "ENFORCE" macros
|
| 115 |
+
const std::string component = "sampler";
|
| 116 |
+
for (auto& item : config.items()) {
|
| 117 |
+
if (item.key() == "sampler") {
|
| 118 |
+
JSON_ENFORCE_OBJECT();
|
| 119 |
+
validateSamplerConfig(item.value());
|
| 120 |
+
} else {
|
| 121 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown sampler config key: " + item.key());
|
| 122 |
+
}
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
if (config["sampler"].contains("seed"))
|
| 126 |
+
quallaConfig["sampler"]["seed"] = config["sampler"]["seed"];
|
| 127 |
+
if (config["sampler"].contains("temp"))
|
| 128 |
+
quallaConfig["sampler"]["temp"] = config["sampler"]["temp"];
|
| 129 |
+
if (config["sampler"].contains("top-k"))
|
| 130 |
+
quallaConfig["sampler"]["top-k"] = config["sampler"]["top-k"];
|
| 131 |
+
if (config["sampler"].contains("top-p"))
|
| 132 |
+
quallaConfig["sampler"]["top-p"] = config["sampler"]["top-p"];
|
| 133 |
+
if (config["sampler"].contains("greedy"))
|
| 134 |
+
quallaConfig["sampler"]["greedy"] = config["sampler"]["greedy"];
|
| 135 |
+
if (config["sampler"].contains("version"))
|
| 136 |
+
quallaConfig["sampler"]["version"] = config["sampler"]["version"];
|
| 137 |
+
else
|
| 138 |
+
quallaConfig["sampler"]["version"] = 1;
|
| 139 |
+
|
| 140 |
+
quallaConfig["sampler"]["type"] = "basic";
|
| 141 |
+
|
| 142 |
+
m_config = quallaConfig;
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
void Sampler::SamplerConfig::setParam(const std::string& keyStr, const std::string& valueStr) {
|
| 146 |
+
if (!keyStr.empty()) {
|
| 147 |
+
// Case 1: Only the parameter mentioned in keyStr is to be updated by valueStr
|
| 148 |
+
std::set<std::string> validParams = {"seed", "top-p", "top-k", "temp"};
|
| 149 |
+
if (std::find(validParams.begin(), validParams.end(), keyStr) == validParams.end()) {
|
| 150 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Invalid key obtained: " + keyStr);
|
| 151 |
+
}
|
| 152 |
+
try {
|
| 153 |
+
if (keyStr == "seed")
|
| 154 |
+
m_config["sampler"]["seed"] = std::stoi(valueStr);
|
| 155 |
+
else if (keyStr == "top-p")
|
| 156 |
+
m_config["sampler"]["top-p"] = std::stof(valueStr);
|
| 157 |
+
else if (keyStr == "top-k")
|
| 158 |
+
m_config["sampler"]["top-k"] = std::stof(valueStr);
|
| 159 |
+
else if (keyStr == "temp")
|
| 160 |
+
m_config["sampler"]["temp"] = std::stof(valueStr);
|
| 161 |
+
} catch (const std::invalid_argument& e) {
|
| 162 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
|
| 163 |
+
"Invalid value obtained: " + valueStr + " for key: " + keyStr);
|
| 164 |
+
}
|
| 165 |
+
} else {
|
| 166 |
+
// Case 2: User has passed entire json as a string in valueStr
|
| 167 |
+
|
| 168 |
+
if (valueStr.empty())
|
| 169 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Both keyStr and valueStr cannot be empty");
|
| 170 |
+
|
| 171 |
+
qualla::json config = qualla::json::parse(valueStr);
|
| 172 |
+
if (!config.contains("sampler")) {
|
| 173 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing field: sampler");
|
| 174 |
+
}
|
| 175 |
+
|
| 176 |
+
// component is used in the "ENFORCE" macros
|
| 177 |
+
const std::string component = "sampler";
|
| 178 |
+
for (auto& item : config.items()) {
|
| 179 |
+
if (item.key() == "sampler") {
|
| 180 |
+
JSON_ENFORCE_OBJECT();
|
| 181 |
+
validateSamplerConfig(item.value());
|
| 182 |
+
} else {
|
| 183 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
|
| 184 |
+
"Unknown sampler config key: " + item.key());
|
| 185 |
+
}
|
| 186 |
+
}
|
| 187 |
+
|
| 188 |
+
m_config["sampler"]["seed"] =
|
| 189 |
+
qualla::Config::optional<int32_t>(config["sampler"], "seed", m_config["sampler"]["seed"]);
|
| 190 |
+
m_config["sampler"]["temp"] =
|
| 191 |
+
qualla::Config::optional<float>(config["sampler"], "temp", m_config["sampler"]["temp"]);
|
| 192 |
+
m_config["sampler"]["top-k"] =
|
| 193 |
+
qualla::Config::optional<size_t>(config["sampler"], "top-k", m_config["sampler"]["top-k"]);
|
| 194 |
+
m_config["sampler"]["top-p"] =
|
| 195 |
+
qualla::Config::optional<float>(config["sampler"], "top-p", m_config["sampler"]["top-p"]);
|
| 196 |
+
m_config["sampler"]["version"] = qualla::Config::optional<int32_t>(
|
| 197 |
+
config["sampler"], "version", m_config["sampler"]["version"]);
|
| 198 |
+
|
| 199 |
+
m_config["sampler"]["type"] = "basic";
|
| 200 |
+
}
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
void Sampler::SamplerConfig::validateSamplerConfig(const qualla::json& config) {
|
| 204 |
+
if (!config.is_object()) {
|
| 205 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "sampler config is not an object");
|
| 206 |
+
}
|
| 207 |
+
|
| 208 |
+
const std::set<std::string> mandatoryFields{"version"};
|
| 209 |
+
for (const auto& field : mandatoryFields) {
|
| 210 |
+
if (!config.contains(field)) {
|
| 211 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing sampler field: " + field);
|
| 212 |
+
}
|
| 213 |
+
}
|
| 214 |
+
|
| 215 |
+
// component is used in the "ENFORCE" macros
|
| 216 |
+
const std::string component = "sampler";
|
| 217 |
+
|
| 218 |
+
for (auto& item : config.items()) {
|
| 219 |
+
if (item.key() == "version") {
|
| 220 |
+
JSON_ENFORCE_NUMERIC();
|
| 221 |
+
if (item.value().get<int>() != 1) {
|
| 222 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
|
| 223 |
+
"Invalid sampler config: unsupported version: " + item.value().dump());
|
| 224 |
+
}
|
| 225 |
+
} else if (item.key() == "seed") {
|
| 226 |
+
JSON_ENFORCE_NUMERIC();
|
| 227 |
+
} else if (item.key() == "temp") {
|
| 228 |
+
JSON_ENFORCE_NUMERIC();
|
| 229 |
+
} else if (item.key() == "top-k") {
|
| 230 |
+
JSON_ENFORCE_NUMERIC();
|
| 231 |
+
} else if (item.key() == "top-p") {
|
| 232 |
+
JSON_ENFORCE_NUMERIC();
|
| 233 |
+
} else if (item.key() == "greedy") {
|
| 234 |
+
JSON_ENFORCE_BOOLEAN();
|
| 235 |
+
} else {
|
| 236 |
+
throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown sampler config key: " + item.key());
|
| 237 |
+
}
|
| 238 |
+
}
|
| 239 |
+
}
|
| 240 |
+
|
| 241 |
+
void Sampler::SamplerConfig::translateSamplerConfig(const qualla::json& genieConfig,
|
| 242 |
+
qualla::json& quallaConfig) {
|
| 243 |
+
if (genieConfig["dialog"].contains("sampler")) {
|
| 244 |
+
quallaConfig["sampler"]["type"] = "basic";
|
| 245 |
+
|
| 246 |
+
if (genieConfig["dialog"]["sampler"].contains("seed")) {
|
| 247 |
+
quallaConfig["sampler"]["seed"] = genieConfig["dialog"]["sampler"]["seed"];
|
| 248 |
+
}
|
| 249 |
+
if (genieConfig["dialog"]["sampler"].contains("temp")) {
|
| 250 |
+
quallaConfig["sampler"]["temp"] = genieConfig["dialog"]["sampler"]["temp"];
|
| 251 |
+
}
|
| 252 |
+
|
| 253 |
+
quallaConfig["sampler"]["role"] = "primary";
|
| 254 |
+
#if defined(GENIE_SPD_FEATURE)
|
| 255 |
+
if (genieConfig["dialog"]["type"] == "spd") {
|
| 256 |
+
quallaConfig["sampler"]["role"] = "target";
|
| 257 |
+
}
|
| 258 |
+
#endif
|
| 259 |
+
|
| 260 |
+
if (genieConfig["dialog"]["sampler"].contains("top-k")) {
|
| 261 |
+
quallaConfig["sampler"]["top-k"] = genieConfig["dialog"]["sampler"]["top-k"];
|
| 262 |
+
}
|
| 263 |
+
if (genieConfig["dialog"]["sampler"].contains("top-p")) {
|
| 264 |
+
quallaConfig["sampler"]["top-p"] = genieConfig["dialog"]["sampler"]["top-p"];
|
| 265 |
+
}
|
| 266 |
+
if (genieConfig["dialog"]["sampler"].contains("greedy")) {
|
| 267 |
+
quallaConfig["sampler"]["greedy"] = genieConfig["dialog"]["sampler"]["greedy"];
|
| 268 |
+
}
|
| 269 |
+
if (genieConfig["dialog"]["sampler"].contains("seed")) {
|
| 270 |
+
quallaConfig["sampler"]["seed"] = genieConfig["dialog"]["sampler"]["seed"];
|
| 271 |
+
}
|
| 272 |
+
}
|
| 273 |
+
}
|
| 274 |
+
|
| 275 |
+
qualla::json Sampler::SamplerConfig::getJson() const { return m_config; }
|
Genie/Genie/src/Sampler.hpp
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//==============================================================================
|
| 2 |
+
//
|
| 3 |
+
// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
|
| 4 |
+
// All Rights Reserved.
|
| 5 |
+
// Confidential and Proprietary - Qualcomm Technologies, Inc.
|
| 6 |
+
//
|
| 7 |
+
//==============================================================================
|
| 8 |
+
|
| 9 |
+
#pragma once
|
| 10 |
+
#include <memory>
|
| 11 |
+
|
| 12 |
+
#include "GenieSampler.h"
|
| 13 |
+
#include "Util/HandleManager.hpp"
|
| 14 |
+
#include "qualla/env.hpp"
|
| 15 |
+
#include "qualla/sampler.hpp"
|
| 16 |
+
|
| 17 |
+
namespace genie {
|
| 18 |
+
class Sampler {
|
| 19 |
+
public:
|
| 20 |
+
class SamplerConfig {
|
| 21 |
+
public:
|
| 22 |
+
static GenieSamplerConfig_Handle_t add(std::shared_ptr<SamplerConfig> config);
|
| 23 |
+
|
| 24 |
+
static std::shared_ptr<SamplerConfig> get(GenieSamplerConfig_Handle_t handle);
|
| 25 |
+
|
| 26 |
+
static void remove(GenieSamplerConfig_Handle_t handle);
|
| 27 |
+
|
| 28 |
+
static void validateSamplerConfig(const qualla::json& config);
|
| 29 |
+
|
| 30 |
+
static void translateSamplerConfig(const qualla::json& genieConfig, qualla::json& quallaConfig);
|
| 31 |
+
|
| 32 |
+
SamplerConfig(const char* configStr);
|
| 33 |
+
|
| 34 |
+
void setParam(const std::string& keyStr, const std::string& valueStr);
|
| 35 |
+
|
| 36 |
+
qualla::json getJson() const;
|
| 37 |
+
|
| 38 |
+
private:
|
| 39 |
+
static qnn::util::HandleManager<SamplerConfig> s_manager;
|
| 40 |
+
qualla::json m_config;
|
| 41 |
+
};
|
| 42 |
+
|
| 43 |
+
static GenieSampler_Handle_t add(std::shared_ptr<Sampler> sampler);
|
| 44 |
+
static std::shared_ptr<Sampler> get(GenieSampler_Handle_t handle);
|
| 45 |
+
static void remove(GenieSampler_Handle_t handle);
|
| 46 |
+
|
| 47 |
+
Sampler(qualla::json& origJson,
|
| 48 |
+
std::vector<std::reference_wrapper<qualla::Sampler>>& quallaSamplers);
|
| 49 |
+
|
| 50 |
+
void applyConfig(qualla::json samplerConfigJson);
|
| 51 |
+
|
| 52 |
+
const qualla::json& getJson();
|
| 53 |
+
|
| 54 |
+
private:
|
| 55 |
+
qualla::json m_origJson;
|
| 56 |
+
static qnn::util::HandleManager<Sampler> s_manager;
|
| 57 |
+
std::vector<std::reference_wrapper<qualla::Sampler>> m_quallaSamplers;
|
| 58 |
+
};
|
| 59 |
+
|
| 60 |
+
} // namespace genie
|
Genie/Genie/src/qualla/context.cpp
CHANGED
|
@@ -93,6 +93,10 @@ extern void needQnnHtpEngine();
|
|
| 93 |
extern void needQnnCpuEngine();
|
| 94 |
#endif
|
| 95 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
static OnLoad needs([]() {
|
| 97 |
needStdoutLogger();
|
| 98 |
needFileLogger();
|
|
@@ -111,6 +115,10 @@ static OnLoad needs([]() {
|
|
| 111 |
#ifdef QUALLA_ENGINE_QNN_CPU
|
| 112 |
needQnnCpuEngine();
|
| 113 |
#endif
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
});
|
| 115 |
|
| 116 |
#endif
|
|
|
|
| 93 |
extern void needQnnCpuEngine();
|
| 94 |
#endif
|
| 95 |
|
| 96 |
+
#ifdef QUALLA_ENGINE_QNN_GPU
|
| 97 |
+
extern void needQnnGpuEngine();
|
| 98 |
+
#endif
|
| 99 |
+
|
| 100 |
static OnLoad needs([]() {
|
| 101 |
needStdoutLogger();
|
| 102 |
needFileLogger();
|
|
|
|
| 115 |
#ifdef QUALLA_ENGINE_QNN_CPU
|
| 116 |
needQnnCpuEngine();
|
| 117 |
#endif
|
| 118 |
+
|
| 119 |
+
#ifdef QUALLA_ENGINE_QNN_GPU
|
| 120 |
+
needQnnGpuEngine();
|
| 121 |
+
#endif
|
| 122 |
});
|
| 123 |
|
| 124 |
#endif
|
Genie/Genie/src/qualla/dialogs/ssd-q1.cpp
CHANGED
|
@@ -161,7 +161,7 @@ SelfSpecDecDialog::SelfSpecDecDialog(
|
|
| 161 |
m_inputType = _engine["primary"]->getInputType();
|
| 162 |
// Load KV prefix
|
| 163 |
Timer timer;
|
| 164 |
-
size_t n_restored_prefix = _engine["primary"]->restore(_kv_prefix_name);
|
| 165 |
if (n_restored_prefix != _forecast_prefix) {
|
| 166 |
// clang-format off
|
| 167 |
throw std::runtime_error( fmt::format( "SSD : Loaded {} KV$ from {} but expected {} KV$",
|
|
@@ -1001,7 +1001,7 @@ bool SelfSpecDecDialog::process(std::vector<int32_t>& tokens, Dialog::Callback c
|
|
| 1001 |
void SelfSpecDecDialog::reset() {
|
| 1002 |
Dialog::reset();
|
| 1003 |
_n_past = _forecast_prefix;
|
| 1004 |
-
size_t n_restored_prefix = _engine["primary"]->restore(_kv_prefix_name);
|
| 1005 |
if (n_restored_prefix != _forecast_prefix) {
|
| 1006 |
// clang-format off
|
| 1007 |
throw std::runtime_error( fmt::format( "SSD : Loaded {} KV$ from {} but expected {} KV$",
|
|
|
|
| 161 |
m_inputType = _engine["primary"]->getInputType();
|
| 162 |
// Load KV prefix
|
| 163 |
Timer timer;
|
| 164 |
+
size_t n_restored_prefix = _engine["primary"]->restore(_kv_prefix_name, true);
|
| 165 |
if (n_restored_prefix != _forecast_prefix) {
|
| 166 |
// clang-format off
|
| 167 |
throw std::runtime_error( fmt::format( "SSD : Loaded {} KV$ from {} but expected {} KV$",
|
|
|
|
| 1001 |
void SelfSpecDecDialog::reset() {
|
| 1002 |
Dialog::reset();
|
| 1003 |
_n_past = _forecast_prefix;
|
| 1004 |
+
size_t n_restored_prefix = _engine["primary"]->restore(_kv_prefix_name, true);
|
| 1005 |
if (n_restored_prefix != _forecast_prefix) {
|
| 1006 |
// clang-format off
|
| 1007 |
throw std::runtime_error( fmt::format( "SSD : Loaded {} KV$ from {} but expected {} KV$",
|
Genie/Genie/src/qualla/engine.cpp
CHANGED
|
@@ -69,7 +69,7 @@ bool Engine::updateKV(size_t n_past, const std::vector<bool>& selected) {
|
|
| 69 |
return false;
|
| 70 |
}
|
| 71 |
|
| 72 |
-
size_t Engine::restore(const std::string& name) {
|
| 73 |
_env.logger().error(fmt::format("{}-engine does not support restore", _type));
|
| 74 |
return 0;
|
| 75 |
}
|
|
|
|
| 69 |
return false;
|
| 70 |
}
|
| 71 |
|
| 72 |
+
size_t Engine::restore(const std::string& name, bool chooseHigherVariant) {
|
| 73 |
_env.logger().error(fmt::format("{}-engine does not support restore", _type));
|
| 74 |
return 0;
|
| 75 |
}
|
Genie/Genie/src/qualla/engines/qnn-api/DmaBufAllocator.cpp
ADDED
|
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//==============================================================================
|
| 2 |
+
//
|
| 3 |
+
// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
|
| 4 |
+
// All Rights Reserved.
|
| 5 |
+
// Confidential and Proprietary - Qualcomm Technologies, Inc.
|
| 6 |
+
//
|
| 7 |
+
//==============================================================================
|
| 8 |
+
|
| 9 |
+
#include "QnnMem.h"
|
| 10 |
+
#include "DmaBufAllocator.hpp"
|
| 11 |
+
#include "QnnTypeMacros.hpp"
|
| 12 |
+
|
| 13 |
+
#include <dlfcn.h>
|
| 14 |
+
#include <fcntl.h>
|
| 15 |
+
#include <linux/dma-buf.h>
|
| 16 |
+
#include <pthread.h>
|
| 17 |
+
#include <stdlib.h>
|
| 18 |
+
#include <sys/ioctl.h>
|
| 19 |
+
#include <sys/mman.h>
|
| 20 |
+
#include <unistd.h>
|
| 21 |
+
|
| 22 |
+
#include <cstdlib>
|
| 23 |
+
#include <fstream>
|
| 24 |
+
#include <iostream>
|
| 25 |
+
#include <numeric>
|
| 26 |
+
#include <string>
|
| 27 |
+
#include <vector>
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
DmaBufferAllocator::DmaBufferAllocator(Qnn_ContextHandle_t contextHandle, QNN_INTERFACE_VER_TYPE* qnnInterface)
|
| 31 |
+
: m_libDmaBufHeapHandle(nullptr),
|
| 32 |
+
m_dmaBufCreate(nullptr),
|
| 33 |
+
m_dmaBufAlloc(nullptr),
|
| 34 |
+
m_dmaBufDeinit(nullptr),
|
| 35 |
+
m_qnnInterface(qnnInterface),
|
| 36 |
+
m_contextHandle(contextHandle) {}
|
| 37 |
+
|
| 38 |
+
bool DmaBufferAllocator::initialize() {
|
| 39 |
+
// On Android, 32-bit and 64-bit libdmaBufheap.so can be found at /system/lib and /system/lib64
|
| 40 |
+
// respectively.
|
| 41 |
+
m_libDmaBufHeapHandle = dlopen("libdmabufheap.so", RTLD_NOW | RTLD_LOCAL);
|
| 42 |
+
if (nullptr == m_libDmaBufHeapHandle) {
|
| 43 |
+
QNN_ERROR("Unable to load backend. dlerror(): %s", dlerror());
|
| 44 |
+
return false;
|
| 45 |
+
}
|
| 46 |
+
m_dmaBufCreate = (DmaBufCreateFn_t)dlsym(
|
| 47 |
+
m_libDmaBufHeapHandle, "CreateDmabufHeapBufferAllocator");
|
| 48 |
+
m_dmaBufAlloc =
|
| 49 |
+
(DmaBufAllocFn_t)dlsym(m_libDmaBufHeapHandle, "DmabufHeapAlloc");
|
| 50 |
+
m_dmaBufDeinit = (DmaBufDeinitFn_t)dlsym(
|
| 51 |
+
m_libDmaBufHeapHandle, "FreeDmabufHeapBufferAllocator");
|
| 52 |
+
if (nullptr == m_dmaBufCreate || nullptr == m_dmaBufAlloc || nullptr == m_dmaBufDeinit) {
|
| 53 |
+
QNN_ERROR("Unable to access symbols in libdmaBufheap. dlerror(): %s", dlerror());
|
| 54 |
+
return false;
|
| 55 |
+
}
|
| 56 |
+
return true;
|
| 57 |
+
}
|
| 58 |
+
|
| 59 |
+
DmaBufferAllocator::~DmaBufferAllocator() {
|
| 60 |
+
if (m_libDmaBufHeapHandle) {
|
| 61 |
+
dlclose(m_libDmaBufHeapHandle);
|
| 62 |
+
m_libDmaBufHeapHandle = nullptr;
|
| 63 |
+
}
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
DmaBufferData* DmaBufferAllocator::getDmaBufTensorData(Qnn_Tensor_t* tensor) {
|
| 67 |
+
if (tensor == nullptr) return nullptr;
|
| 68 |
+
Qnn_MemHandle_t mem_handle = QNN_TENSOR_GET_MEM_HANDLE(tensor);
|
| 69 |
+
if (mem_handle == nullptr) return nullptr;
|
| 70 |
+
return &m_memHandleToDmaBufMem.at(mem_handle);
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
void* DmaBufferAllocator::getBuffer(Qnn_Tensor_t* tensor) {
|
| 74 |
+
if (!tensor) {
|
| 75 |
+
QNN_WARN("DmaBufferAllocator: getBuffer: received a null pointer to a tensor");
|
| 76 |
+
return nullptr;
|
| 77 |
+
}
|
| 78 |
+
if (m_tensorToDmaBufferData.find(tensor) == m_tensorToDmaBufferData.end()) {
|
| 79 |
+
QNN_ERROR("DmaBufferAllocator: Tensor not found with address = %p", tensor);
|
| 80 |
+
return nullptr;
|
| 81 |
+
}
|
| 82 |
+
DmaBufferData dmaBufferData = m_tensorToDmaBufferData[tensor];
|
| 83 |
+
return dmaBufferData.memPointer;
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
int DmaBufferAllocator::getFd(Qnn_Tensor_t* tensor) {
|
| 89 |
+
DmaBufferData* data = getDmaBufTensorData(tensor);
|
| 90 |
+
if (data == nullptr) {
|
| 91 |
+
QNN_ERROR("DmaBufferAllocator: getFd : Couldn't find tensor %p", tensor);
|
| 92 |
+
return -1;
|
| 93 |
+
}
|
| 94 |
+
return data->fd;
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
size_t DmaBufferAllocator::getOffset(Qnn_Tensor_t* tensor) {
|
| 98 |
+
DmaBufferData* data = getDmaBufTensorData(tensor);
|
| 99 |
+
if (data == nullptr) {
|
| 100 |
+
QNN_ERROR("DmaBufferAllocator: getOffset : Couldn't find tensor %p", tensor);
|
| 101 |
+
return 0;
|
| 102 |
+
}
|
| 103 |
+
return data->offset;
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
size_t DmaBufferAllocator::getBufferSize(Qnn_Tensor_t* tensor) {
|
| 107 |
+
DmaBufferData* data = getDmaBufTensorData(tensor);
|
| 108 |
+
if (data == nullptr) {
|
| 109 |
+
QNN_ERROR("DmaBufferAllocator: getBufferSize : Couldn't find tensor %p", tensor);
|
| 110 |
+
return 0;
|
| 111 |
+
}
|
| 112 |
+
return data->totalBufferSize;
|
| 113 |
+
};
|
| 114 |
+
|
| 115 |
+
size_t DmaBufferAllocator::getTotalBufferSize(Qnn_Tensor_t* tensor) {
|
| 116 |
+
DmaBufferData* data = getDmaBufTensorData(tensor);
|
| 117 |
+
if (data == nullptr) {
|
| 118 |
+
QNN_ERROR("DmaBufferAllocator: getTotalBufferSize : Couldn't find tensor %p", tensor);
|
| 119 |
+
return 0;
|
| 120 |
+
}
|
| 121 |
+
return data->totalBufferSize;
|
| 122 |
+
}
|
| 123 |
+
|
| 124 |
+
bool DmaBufferAllocator::allocateTensorBuffer(Qnn_Tensor_t* tensor, size_t tensorDataSize) {
|
| 125 |
+
if (m_libDmaBufHeapHandle == nullptr) {
|
| 126 |
+
QNN_ERROR("DmaBufferAllocator not initialized");
|
| 127 |
+
return false;
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
if (!tensor) {
|
| 131 |
+
QNN_ERROR("DmaBufferAllocator: Received nullptr for tensor");
|
| 132 |
+
return false;
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
if (m_tensorToDmaBufferData.find(tensor) != m_tensorToDmaBufferData.end()) {
|
| 136 |
+
QNN_ERROR("DmaBufferAllocator: Tensor already allocated");
|
| 137 |
+
return false;
|
| 138 |
+
}
|
| 139 |
+
|
| 140 |
+
void* dmaBufferAllocator = m_dmaBufCreate();
|
| 141 |
+
if (dmaBufferAllocator == nullptr) {
|
| 142 |
+
QNN_ERROR("DmaBufferAllocator: nullptr returned for CreateDmabufHeapBufferAllocator().");
|
| 143 |
+
return false;
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
int fd = m_dmaBufAlloc(dmaBufferAllocator, "qcom,system", tensorDataSize, 0, 0);
|
| 147 |
+
if (fd < 0) {
|
| 148 |
+
QNN_ERROR("DmaBufAlloc returned a invalid file descriptor = %d", fd);
|
| 149 |
+
return false;
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
void* memPointer = mmap(nullptr, tensorDataSize, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
|
| 153 |
+
if (MAP_FAILED == memPointer) {
|
| 154 |
+
printf("DmaBufferAllocator: Unable to open file returned by DmaBufAlloc with mmap");
|
| 155 |
+
return false;
|
| 156 |
+
}
|
| 157 |
+
|
| 158 |
+
Qnn_MemDescriptor_t memDescriptor = {
|
| 159 |
+
{QNN_TENSOR_GET_RANK(tensor), QNN_TENSOR_GET_DIMENSIONS(tensor), nullptr},
|
| 160 |
+
QNN_TENSOR_GET_DATA_TYPE(tensor),
|
| 161 |
+
QNN_MEM_TYPE_DMA_BUF,
|
| 162 |
+
{.dmaBufInfo = {fd, memPointer}}};
|
| 163 |
+
QNN_TENSOR_SET_MEM_TYPE(tensor, QNN_TENSORMEMTYPE_MEMHANDLE);
|
| 164 |
+
QNN_TENSOR_SET_MEM_HANDLE(tensor, nullptr);
|
| 165 |
+
Qnn_MemHandle_t memHandle = QNN_TENSOR_GET_MEM_HANDLE(tensor);
|
| 166 |
+
|
| 167 |
+
if (QNN_SUCCESS !=
|
| 168 |
+
m_qnnInterface->memRegister(m_contextHandle, &memDescriptor, 1, &(memHandle))) {
|
| 169 |
+
QNN_ERROR("DmaBufferAllocator: Failure to register ion memory with the backend");
|
| 170 |
+
return false;
|
| 171 |
+
}
|
| 172 |
+
QNN_DEBUG("DmaBufferAllocator: Memregister successful with handle %p for DMA buffer with size: %zu and fd %d",
|
| 173 |
+
memHandle,
|
| 174 |
+
tensorDataSize,
|
| 175 |
+
fd);
|
| 176 |
+
QNN_TENSOR_SET_MEM_HANDLE(tensor, memHandle);
|
| 177 |
+
m_tensorToDmaBufferData.insert(
|
| 178 |
+
{tensor, DmaBufferData(dmaBufferAllocator, fd, memPointer, tensorDataSize)});
|
| 179 |
+
|
| 180 |
+
return true;
|
| 181 |
+
}
|
| 182 |
+
|
| 183 |
+
bool DmaBufferAllocator::freeTensorBuffer(Qnn_Tensor_t* tensor) {
|
| 184 |
+
if (!tensor) {
|
| 185 |
+
QNN_ERROR("DmaBufferAllocator: Received nullptr for tensor");
|
| 186 |
+
return false;
|
| 187 |
+
}
|
| 188 |
+
auto memHandle = QNN_TENSOR_GET_MEM_HANDLE(tensor);
|
| 189 |
+
if (QNN_SUCCESS != m_qnnInterface->memDeRegister(&memHandle, 1)) {
|
| 190 |
+
QNN_ERROR("DmaBufferAllocator: Failed to deregister custom memory handle with the backend");
|
| 191 |
+
return false;
|
| 192 |
+
}
|
| 193 |
+
if (m_tensorToDmaBufferData.find(tensor) == m_tensorToDmaBufferData.end()) {
|
| 194 |
+
QNN_ERROR("DmaBufferAllocator: Tensor not found with address = %p", tensor);
|
| 195 |
+
return false;
|
| 196 |
+
}
|
| 197 |
+
DmaBufferData dmaBufferData = m_tensorToDmaBufferData[tensor];
|
| 198 |
+
if (!m_dmaBufDeinit) {
|
| 199 |
+
QNN_ERROR("DmaBufferAllocator: DmaBuf Deinit function pointer is null");
|
| 200 |
+
return false;
|
| 201 |
+
}
|
| 202 |
+
munmap(dmaBufferData.memPointer, dmaBufferData.totalBufferSize);
|
| 203 |
+
m_dmaBufDeinit(dmaBufferData.dmaBufferAllocator);
|
| 204 |
+
m_tensorToDmaBufferData.erase(tensor);
|
| 205 |
+
return true;
|
| 206 |
+
}
|
| 207 |
+
|
| 208 |
+
bool DmaBufferAllocator::useSameMemory(Qnn_Tensor_t* dest, Qnn_Tensor_t* src) {
|
| 209 |
+
if (nullptr == dest || nullptr == src) {
|
| 210 |
+
QNN_ERROR("DmaBufferAllocator: Received nullptr");
|
| 211 |
+
return false;
|
| 212 |
+
}
|
| 213 |
+
if (m_tensorToDmaBufferData.find(src) == m_tensorToDmaBufferData.end()) {
|
| 214 |
+
QNN_ERROR("DmaBufferAllocator: Src Tensor not found");
|
| 215 |
+
return false;
|
| 216 |
+
}
|
| 217 |
+
|
| 218 |
+
QNN_TENSOR_SET_MEM_TYPE(dest, QNN_TENSOR_GET_MEM_TYPE(src));
|
| 219 |
+
QNN_TENSOR_SET_MEM_HANDLE(dest, QNN_TENSOR_GET_MEM_HANDLE(src));
|
| 220 |
+
m_tensorToDmaBufferData.insert({dest, m_tensorToDmaBufferData[src]});
|
| 221 |
+
m_sameMemoryFreeTensors.insert(dest);
|
| 222 |
+
return true;
|
| 223 |
+
}
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
bool DmaBufferAllocator::beforeWriteToBuffer(Qnn_Tensor_t* tensor) {
|
| 228 |
+
if (!tensor) {
|
| 229 |
+
QNN_WARN("beforeWriteToBuffer: received a null pointer to a tensor");
|
| 230 |
+
return false;
|
| 231 |
+
}
|
| 232 |
+
if (m_tensorToDmaBufferData.find(tensor) == m_tensorToDmaBufferData.end()) {
|
| 233 |
+
QNN_ERROR("beforeWriteToBuffer: Tensor not found with address = %p", tensor);
|
| 234 |
+
return false;
|
| 235 |
+
}
|
| 236 |
+
DmaBufferData dmaBufferData = m_tensorToDmaBufferData[tensor];
|
| 237 |
+
struct dma_buf_sync buf_sync = {};
|
| 238 |
+
buf_sync.flags = DMA_BUF_SYNC_START | DMA_BUF_SYNC_WRITE;
|
| 239 |
+
auto ioctlReturnValue = ioctl(dmaBufferData.fd, DMA_BUF_IOCTL_SYNC, &buf_sync);
|
| 240 |
+
if (ioctlReturnValue) {
|
| 241 |
+
QNN_ERROR(
|
| 242 |
+
"beforeWriteToBuffer: Error preparing the cache for buffer writes."
|
| 243 |
+
"The DMA_BUF_IOCTL_SYNC operation returned %d",
|
| 244 |
+
ioctlReturnValue);
|
| 245 |
+
return false;
|
| 246 |
+
}
|
| 247 |
+
return true;
|
| 248 |
+
}
|
| 249 |
+
|
| 250 |
+
bool DmaBufferAllocator::afterWriteToBuffer(Qnn_Tensor_t* tensor) {
|
| 251 |
+
if (!tensor) {
|
| 252 |
+
QNN_WARN("afterWriteToBuffer: received a null pointer to a tensor");
|
| 253 |
+
return false;
|
| 254 |
+
}
|
| 255 |
+
if (m_tensorToDmaBufferData.find(tensor) == m_tensorToDmaBufferData.end()) {
|
| 256 |
+
QNN_ERROR("afterWriteToBuffer: Tensor not found with address = %p", tensor);
|
| 257 |
+
return false;
|
| 258 |
+
}
|
| 259 |
+
DmaBufferData dmaBufferData = m_tensorToDmaBufferData[tensor];
|
| 260 |
+
struct dma_buf_sync buf_sync = {};
|
| 261 |
+
buf_sync.flags = DMA_BUF_SYNC_END | DMA_BUF_SYNC_WRITE;
|
| 262 |
+
auto ioctlReturnValue = ioctl(dmaBufferData.fd, DMA_BUF_IOCTL_SYNC, &buf_sync);
|
| 263 |
+
if (ioctlReturnValue) {
|
| 264 |
+
QNN_ERROR(
|
| 265 |
+
"afterWriteToBuffer: Error close the cache after buffer writing."
|
| 266 |
+
"The DMA_BUF_IOCTL_SYNC operation returned %d",
|
| 267 |
+
ioctlReturnValue);
|
| 268 |
+
return false;
|
| 269 |
+
}
|
| 270 |
+
return true;
|
| 271 |
+
}
|
| 272 |
+
|
| 273 |
+
bool DmaBufferAllocator::beforeReadFromBuffer(Qnn_Tensor_t* tensor) {
|
| 274 |
+
if (!tensor) {
|
| 275 |
+
QNN_WARN("beforeReadFromBuffer: received a null pointer to a tensor");
|
| 276 |
+
return false;
|
| 277 |
+
}
|
| 278 |
+
if (m_tensorToDmaBufferData.find(tensor) == m_tensorToDmaBufferData.end()) {
|
| 279 |
+
QNN_ERROR("beforeReadFromBuffer: Tensor not found with address = %p", tensor);
|
| 280 |
+
return false;
|
| 281 |
+
}
|
| 282 |
+
DmaBufferData dmaBufferData = m_tensorToDmaBufferData[tensor];
|
| 283 |
+
struct dma_buf_sync buf_sync = {};
|
| 284 |
+
buf_sync.flags = DMA_BUF_SYNC_START | DMA_BUF_SYNC_READ;
|
| 285 |
+
auto ioctlReturnValue = ioctl(dmaBufferData.fd, DMA_BUF_IOCTL_SYNC, &buf_sync);
|
| 286 |
+
if (ioctlReturnValue) {
|
| 287 |
+
QNN_ERROR(
|
| 288 |
+
"beforeReadFromBuffer: Error preparing the cache for buffer reading."
|
| 289 |
+
"The DMA_BUF_IOCTL_SYNC operation returned %d",
|
| 290 |
+
ioctlReturnValue);
|
| 291 |
+
return false;
|
| 292 |
+
}
|
| 293 |
+
return true;
|
| 294 |
+
}
|
| 295 |
+
|
| 296 |
+
bool DmaBufferAllocator::afterReadFromBuffer(Qnn_Tensor_t* tensor) {
|
| 297 |
+
if (!tensor) {
|
| 298 |
+
QNN_WARN("afterReadFromBuffer: received a null pointer to a tensor");
|
| 299 |
+
return false;
|
| 300 |
+
}
|
| 301 |
+
if (m_tensorToDmaBufferData.find(tensor) == m_tensorToDmaBufferData.end()) {
|
| 302 |
+
QNN_ERROR("afterReadFromBuffer: Tensor not found with address = %p", tensor);
|
| 303 |
+
return false;
|
| 304 |
+
}
|
| 305 |
+
DmaBufferData dmaBufferData = m_tensorToDmaBufferData[tensor];
|
| 306 |
+
struct dma_buf_sync buf_sync = {};
|
| 307 |
+
buf_sync.flags = DMA_BUF_SYNC_END | DMA_BUF_SYNC_READ;
|
| 308 |
+
auto ioctlReturnValue = ioctl(dmaBufferData.fd, DMA_BUF_IOCTL_SYNC, &buf_sync);
|
| 309 |
+
if (ioctlReturnValue) {
|
| 310 |
+
QNN_ERROR(
|
| 311 |
+
"afterReadFromBuffer: Error closing the cache after buffer reading."
|
| 312 |
+
"The DMA_BUF_IOCTL_SYNC operation returned %d",
|
| 313 |
+
ioctlReturnValue);
|
| 314 |
+
return false;
|
| 315 |
+
}
|
| 316 |
+
return true;
|
| 317 |
+
}
|
Genie/Genie/src/qualla/engines/qnn-api/DmaBufAllocator.hpp
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
//==============================================================================
|
| 2 |
+
//
|
| 3 |
+
// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
|
| 4 |
+
// All Rights Reserved.
|
| 5 |
+
// Confidential and Proprietary - Qualcomm Technologies, Inc.
|
| 6 |
+
//
|
| 7 |
+
//==============================================================================
|
| 8 |
+
#pragma once
|
| 9 |
+
|
| 10 |
+
#include <map>
|
| 11 |
+
#include <unordered_map>
|
| 12 |
+
#include <unordered_set>
|
| 13 |
+
#include <vector>
|
| 14 |
+
|
| 15 |
+
#include "IBufferAlloc.hpp"
|
| 16 |
+
#include "QnnInterface.h"
|
| 17 |
+
#include "Log.hpp"
|
| 18 |
+
|
| 19 |
+
typedef void *(*DmaBufCreateFn_t)();
|
| 20 |
+
typedef int (*DmaBufAllocFn_t)(void *, const char *, size_t, unsigned int, size_t);
|
| 21 |
+
typedef void (*DmaBufDeinitFn_t)(void *);
|
| 22 |
+
|
| 23 |
+
struct DmaBufferData {
|
| 24 |
+
void *dmaBufferAllocator;
|
| 25 |
+
int fd;
|
| 26 |
+
void* memPointer;
|
| 27 |
+
size_t totalBufferSize;
|
| 28 |
+
int offset{0};
|
| 29 |
+
DmaBufferData() : dmaBufferAllocator(nullptr), fd(-1), memPointer(nullptr), totalBufferSize(0) {}
|
| 30 |
+
DmaBufferData(void *bufferAllocator, int fdIn, void* memPointerIn, size_t sizeIn)
|
| 31 |
+
: dmaBufferAllocator(bufferAllocator), fd(fdIn), memPointer(memPointerIn), totalBufferSize(sizeIn) {}
|
| 32 |
+
};
|
| 33 |
+
|
| 34 |
+
class DmaBufferAllocator final : public IBufferAlloc {
|
| 35 |
+
public:
|
| 36 |
+
DmaBufferAllocator(Qnn_ContextHandle_t contextHandle, QNN_INTERFACE_VER_TYPE* qnnInterface);
|
| 37 |
+
// Disable copy constructors, r-value referencing, etc
|
| 38 |
+
DmaBufferAllocator(const DmaBufferAllocator&) = delete;
|
| 39 |
+
DmaBufferAllocator& operator=(const DmaBufferAllocator&) = delete;
|
| 40 |
+
DmaBufferAllocator(DmaBufferAllocator&&) = delete;
|
| 41 |
+
DmaBufferAllocator& operator=(DmaBufferAllocator&&) = delete;
|
| 42 |
+
|
| 43 |
+
bool initialize() override;
|
| 44 |
+
void* getBuffer(Qnn_Tensor_t* tensor) override;
|
| 45 |
+
int getFd(Qnn_Tensor_t* tensor) override;
|
| 46 |
+
size_t getOffset(Qnn_Tensor_t* tensor) override;
|
| 47 |
+
size_t getBufferSize(Qnn_Tensor_t* tensor) override;
|
| 48 |
+
size_t getTotalBufferSize(Qnn_Tensor_t* tensor) override;
|
| 49 |
+
|
| 50 |
+
bool freeTensorBuffer(Qnn_Tensor_t* tensor) override;
|
| 51 |
+
|
| 52 |
+
bool allocateTensorBuffer(Qnn_Tensor_t* tensor, size_t tensorDataSize) override;
|
| 53 |
+
bool useSameMemory(Qnn_Tensor_t* dest, Qnn_Tensor_t* src) override;
|
| 54 |
+
|
| 55 |
+
virtual ~DmaBufferAllocator();
|
| 56 |
+
|
| 57 |
+
bool beforeWriteToBuffer(Qnn_Tensor_t *tensor) override;
|
| 58 |
+
bool afterWriteToBuffer(Qnn_Tensor_t *tensor) override;
|
| 59 |
+
bool beforeReadFromBuffer(Qnn_Tensor_t *tensor) override;
|
| 60 |
+
bool afterReadFromBuffer(Qnn_Tensor_t *tensor) override;
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
bool useSameMemory(Qnn_Tensor_t* dest, Qnn_Tensor_t* src, int offset) override {
|
| 64 |
+
QNN_WARN("Offset based tensors not supported!!");
|
| 65 |
+
return false;;
|
| 66 |
+
}
|
| 67 |
+
bool useExternalMemory(Qnn_Tensor_t* dest, void* extMem) override {
|
| 68 |
+
QNN_WARN("External Memory not supported!!");
|
| 69 |
+
return false;;
|
| 70 |
+
}
|
| 71 |
+
void* allocateTensorFusedBuffer(uint64_t bufferSize, int32_t* fd) override {
|
| 72 |
+
QNN_WARN("Fused Buffers not supported\n");
|
| 73 |
+
return nullptr;
|
| 74 |
+
};
|
| 75 |
+
bool allocateBuffers(
|
| 76 |
+
const std::map<int, std::map<std::string, size_t>>& allocs_per_chunk,
|
| 77 |
+
std::map<std::string, std::pair<int, size_t>>& tensor_offsets
|
| 78 |
+
) override {
|
| 79 |
+
QNN_WARN("Fused Buffers not supported\n");
|
| 80 |
+
return false;
|
| 81 |
+
};
|
| 82 |
+
bool mapFusedBufferOffset(
|
| 83 |
+
Qnn_Tensor_t* tensor,
|
| 84 |
+
size_t tensorDataSize,
|
| 85 |
+
int32_t fd,
|
| 86 |
+
uint32_t offset,
|
| 87 |
+
uint64_t totalBufferSize,
|
| 88 |
+
void* memPointer,
|
| 89 |
+
Qnn_ContextHandle_t contextHandle
|
| 90 |
+
) override {
|
| 91 |
+
QNN_WARN("Fused Buffers not supported\n");
|
| 92 |
+
return false;
|
| 93 |
+
};
|
| 94 |
+
bool deregisterTensorFusedBuffer(Qnn_Tensor_t* tensor) override {
|
| 95 |
+
QNN_WARN("Fused Buffers not supported\n");
|
| 96 |
+
return false;
|
| 97 |
+
};
|
| 98 |
+
void freeFusedBuffers() override {
|
| 99 |
+
return;
|
| 100 |
+
};
|
| 101 |
+
bool mapFusedBufferOffset(
|
| 102 |
+
Qnn_Tensor_t* tensor,
|
| 103 |
+
int alloc_idx,
|
| 104 |
+
size_t offset,
|
| 105 |
+
Qnn_ContextHandle_t ctx,
|
| 106 |
+
size_t size
|
| 107 |
+
) override {
|
| 108 |
+
QNN_WARN("Fused Buffers not supported\n");
|
| 109 |
+
return false;
|
| 110 |
+
};
|
| 111 |
+
|
| 112 |
+
private:
|
| 113 |
+
DmaBufferData * getDmaBufTensorData(Qnn_Tensor_t* tensor);
|
| 114 |
+
|
| 115 |
+
// Pointer to the dlopen'd libdmabufheap.so shared library which contains
|
| 116 |
+
// dmaBufCreate, dmaBufAlloc, dmaBufDeinit
|
| 117 |
+
void *m_libDmaBufHeapHandle;
|
| 118 |
+
DmaBufCreateFn_t m_dmaBufCreate;
|
| 119 |
+
DmaBufAllocFn_t m_dmaBufAlloc;
|
| 120 |
+
DmaBufDeinitFn_t m_dmaBufDeinit;
|
| 121 |
+
|
| 122 |
+
QNN_INTERFACE_VER_TYPE* m_qnnInterface;
|
| 123 |
+
Qnn_ContextHandle_t m_contextHandle;
|
| 124 |
+
|
| 125 |
+
std::unordered_map<Qnn_Tensor_t *, DmaBufferData> m_tensorToDmaBufferData;
|
| 126 |
+
std::unordered_set<Qnn_Tensor_t*> m_sameMemoryFreeTensors;
|
| 127 |
+
std::unordered_map<Qnn_MemHandle_t, DmaBufferData> m_memHandleToDmaBufMem;
|
| 128 |
+
};
|
Genie/Genie/src/qualla/engines/qnn-api/IBufferAlloc.hpp
CHANGED
|
@@ -53,4 +53,18 @@ class IBufferAlloc {
|
|
| 53 |
|
| 54 |
virtual bool deregisterTensorFusedBuffer(Qnn_Tensor_t* tensor) = 0;
|
| 55 |
virtual void freeFusedBuffers() = 0;
|
| 56 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
|
| 54 |
virtual bool deregisterTensorFusedBuffer(Qnn_Tensor_t* tensor) = 0;
|
| 55 |
virtual void freeFusedBuffers() = 0;
|
| 56 |
+
|
| 57 |
+
// Functions to sync memory buffers for Read/Write using DmaBuf.
|
| 58 |
+
virtual bool beforeWriteToBuffer(Qnn_Tensor_t *tensor) {
|
| 59 |
+
return false;
|
| 60 |
+
};
|
| 61 |
+
virtual bool afterWriteToBuffer(Qnn_Tensor_t *tensor) {
|
| 62 |
+
return false;
|
| 63 |
+
};
|
| 64 |
+
virtual bool beforeReadFromBuffer(Qnn_Tensor_t *tensor) {
|
| 65 |
+
return false;
|
| 66 |
+
};
|
| 67 |
+
virtual bool afterReadFromBuffer(Qnn_Tensor_t *tensor) {
|
| 68 |
+
return false;
|
| 69 |
+
};
|
| 70 |
+
};
|
Genie/Genie/src/qualla/engines/qnn-api/IOTensor.cpp
CHANGED
|
@@ -10,6 +10,9 @@
|
|
| 10 |
#include <iostream>
|
| 11 |
|
| 12 |
#include "ClientBuffer.hpp"
|
|
|
|
|
|
|
|
|
|
| 13 |
#include "IBufferAlloc.hpp"
|
| 14 |
#include "IOTensor.hpp"
|
| 15 |
#include "RpcMem.hpp"
|
|
@@ -28,6 +31,14 @@ IOTensor::IOTensor(BufferAlloc bufferAllocIn, QNN_INTERFACE_VER_TYPE* qnnInterfa
|
|
| 28 |
bool IOTensor::initialize(Qnn_ContextHandle_t contextHandle) {
|
| 29 |
if (m_bufferAlloc == BufferAlloc::SHARED_BUFFER) {
|
| 30 |
m_bufferManager = std::unique_ptr<IBufferAlloc>(new RpcMem(contextHandle, m_qnnInterface));
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
}
|
| 32 |
|
| 33 |
if (true != m_bufferManager->initialize()) {
|
|
@@ -39,7 +50,7 @@ bool IOTensor::initialize(Qnn_ContextHandle_t contextHandle) {
|
|
| 39 |
}
|
| 40 |
|
| 41 |
IOTensor::~IOTensor() {
|
| 42 |
-
if (m_bufferAlloc == BufferAlloc::SHARED_BUFFER) {
|
| 43 |
m_bufferManager->freeFusedBuffers();
|
| 44 |
}
|
| 45 |
}
|
|
@@ -215,6 +226,70 @@ bool IOTensor::setupOutputTensors(
|
|
| 215 |
return true;
|
| 216 |
}
|
| 217 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 218 |
bool IOTensor::mapFusedBufferOffset(
|
| 219 |
GraphInfo_t* graph_info,
|
| 220 |
Qnn_ContextHandle_t context_handle,
|
|
|
|
| 10 |
#include <iostream>
|
| 11 |
|
| 12 |
#include "ClientBuffer.hpp"
|
| 13 |
+
#ifndef _WIN32
|
| 14 |
+
#include "DmaBufAllocator.hpp"
|
| 15 |
+
#endif
|
| 16 |
#include "IBufferAlloc.hpp"
|
| 17 |
#include "IOTensor.hpp"
|
| 18 |
#include "RpcMem.hpp"
|
|
|
|
| 31 |
bool IOTensor::initialize(Qnn_ContextHandle_t contextHandle) {
|
| 32 |
if (m_bufferAlloc == BufferAlloc::SHARED_BUFFER) {
|
| 33 |
m_bufferManager = std::unique_ptr<IBufferAlloc>(new RpcMem(contextHandle, m_qnnInterface));
|
| 34 |
+
} else if (m_bufferAlloc == BufferAlloc::DMABUF) {
|
| 35 |
+
#ifdef _WIN32
|
| 36 |
+
return false;
|
| 37 |
+
#else
|
| 38 |
+
m_bufferManager =
|
| 39 |
+
std::unique_ptr<IBufferAlloc>(new DmaBufferAllocator(contextHandle, m_qnnInterface)
|
| 40 |
+
);
|
| 41 |
+
#endif
|
| 42 |
}
|
| 43 |
|
| 44 |
if (true != m_bufferManager->initialize()) {
|
|
|
|
| 50 |
}
|
| 51 |
|
| 52 |
IOTensor::~IOTensor() {
|
| 53 |
+
if (m_bufferAlloc == BufferAlloc::SHARED_BUFFER || m_bufferAlloc == BufferAlloc::DMABUF) {
|
| 54 |
m_bufferManager->freeFusedBuffers();
|
| 55 |
}
|
| 56 |
}
|
|
|
|
| 226 |
return true;
|
| 227 |
}
|
| 228 |
|
| 229 |
+
// Setup details for Qnn_Tensor_t for execution.
|
| 230 |
+
// Reuse same memory handle for KV input and output tensor.
|
| 231 |
+
bool IOTensor::setupOutputWithSharedTensors(
|
| 232 |
+
Qnn_Tensor_t** tensors,
|
| 233 |
+
std::unordered_map<std::string, void*>& tensorNameToTensorPointer,
|
| 234 |
+
const GraphInfo_t& graphInfo,
|
| 235 |
+
std::unordered_map<std::string, size_t>& tensorsSize,
|
| 236 |
+
Qnn_ContextHandle_t contextHandle,
|
| 237 |
+
std::unordered_map<std::string, Qnn_Tensor_t*> sharedTensorMap
|
| 238 |
+
) {
|
| 239 |
+
uint32_t tensorCount = graphInfo.numOutputTensors;
|
| 240 |
+
TensorWrapper* tensorWrappers = graphInfo.outputTensors;
|
| 241 |
+
if (nullptr == tensorWrappers) {
|
| 242 |
+
QNN_ERROR("tensorWrappers is nullptr");
|
| 243 |
+
return false;
|
| 244 |
+
}
|
| 245 |
+
|
| 246 |
+
if (0 == tensorCount) {
|
| 247 |
+
QNN_DEBUG("tensor count is 0. Nothing to setup.");
|
| 248 |
+
return true;
|
| 249 |
+
}
|
| 250 |
+
|
| 251 |
+
*tensors = (Qnn_Tensor_t*)calloc(1, tensorCount * sizeof(Qnn_Tensor_t));
|
| 252 |
+
if (nullptr == *tensors) {
|
| 253 |
+
QNN_ERROR("mem alloc failed for *tensors");
|
| 254 |
+
return false;
|
| 255 |
+
}
|
| 256 |
+
|
| 257 |
+
bool returnStatus = true;
|
| 258 |
+
for (size_t tensorIdx = 0; tensorIdx < tensorCount; tensorIdx++) {
|
| 259 |
+
Qnn_Tensor_t wrapperTensor = GET_TENSOR_WRAPPER_TENSOR(tensorWrappers[tensorIdx]);
|
| 260 |
+
auto wrapperTensorName = std::string(GET_TENSOR_WRAPPER_NAME(tensorWrappers[tensorIdx]));
|
| 261 |
+
if (true == returnStatus) {
|
| 262 |
+
(*tensors)[tensorIdx] = QNN_TENSOR_INIT;
|
| 263 |
+
returnStatus = deepCopyQnnTensorInfo(((*tensors) + tensorIdx), &wrapperTensor);
|
| 264 |
+
}
|
| 265 |
+
if (true == returnStatus) {
|
| 266 |
+
if (sharedTensorMap.find(wrapperTensorName) == sharedTensorMap.end()) {
|
| 267 |
+
QNN_DEBUG("IoTensor :: Create Buffer for Tensor %s", wrapperTensorName.c_str());
|
| 268 |
+
size_t tensorDataSize = tensorsSize[wrapperTensorName];
|
| 269 |
+
returnStatus = m_bufferManager->allocateTensorBuffer(
|
| 270 |
+
((*tensors) + tensorIdx), tensorDataSize
|
| 271 |
+
);
|
| 272 |
+
} else {
|
| 273 |
+
std::string inputName = QNN_TENSOR_GET_NAME(sharedTensorMap[wrapperTensorName]);
|
| 274 |
+
QNN_DEBUG("IoTensor :: Reuse Buffer %s for Tensor %s", inputName.c_str(), wrapperTensorName.c_str());
|
| 275 |
+
returnStatus = m_bufferManager->useSameMemory(
|
| 276 |
+
((*tensors) + tensorIdx), sharedTensorMap[wrapperTensorName]
|
| 277 |
+
);
|
| 278 |
+
}
|
| 279 |
+
}
|
| 280 |
+
if (true != returnStatus) {
|
| 281 |
+
QNN_ERROR("Failure in setupTensors, cleaning up resources");
|
| 282 |
+
tearDownTensors(*tensors, tensorIdx);
|
| 283 |
+
*tensors = nullptr;
|
| 284 |
+
QNN_ERROR("Failure in setupTensors, done cleaning up resources");
|
| 285 |
+
break;
|
| 286 |
+
} else {
|
| 287 |
+
tensorNameToTensorPointer.insert({wrapperTensorName, ((*tensors) + tensorIdx)});
|
| 288 |
+
}
|
| 289 |
+
}
|
| 290 |
+
return returnStatus;
|
| 291 |
+
}
|
| 292 |
+
|
| 293 |
bool IOTensor::mapFusedBufferOffset(
|
| 294 |
GraphInfo_t* graph_info,
|
| 295 |
Qnn_ContextHandle_t context_handle,
|
Genie/Genie/src/qualla/engines/qnn-api/IOTensor.hpp
CHANGED
|
@@ -28,6 +28,7 @@
|
|
| 28 |
enum class BufferAlloc {
|
| 29 |
DEFAULT, // malloc based allocator
|
| 30 |
SHARED_BUFFER, // shared buffer allocator; actual allocator depends on the platform
|
|
|
|
| 31 |
INVALID
|
| 32 |
};
|
| 33 |
class IBufferAlloc;
|
|
@@ -60,6 +61,16 @@ class IOTensor {
|
|
| 60 |
bool skipBufferAllocation = false
|
| 61 |
);
|
| 62 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
bool tearDownTensors(Qnn_Tensor_t* tensors, uint32_t tensorCount);
|
| 64 |
|
| 65 |
bool tearDownTensors(std::vector<Qnn_Tensor_t*>& tensors, uint32_t tensorCount);
|
|
@@ -146,6 +157,20 @@ class IOTensor {
|
|
| 146 |
|
| 147 |
std::unordered_set<void*>& getFreeTensorsPointerSet() { return m_freeTensorsPointerSet; }
|
| 148 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 149 |
private:
|
| 150 |
BufferAlloc m_bufferAlloc;
|
| 151 |
QNN_INTERFACE_VER_TYPE* m_qnnInterface;
|
|
|
|
| 28 |
enum class BufferAlloc {
|
| 29 |
DEFAULT, // malloc based allocator
|
| 30 |
SHARED_BUFFER, // shared buffer allocator; actual allocator depends on the platform
|
| 31 |
+
DMABUF, // dma buffer allocator
|
| 32 |
INVALID
|
| 33 |
};
|
| 34 |
class IBufferAlloc;
|
|
|
|
| 61 |
bool skipBufferAllocation = false
|
| 62 |
);
|
| 63 |
|
| 64 |
+
bool setupOutputWithSharedTensors(
|
| 65 |
+
Qnn_Tensor_t** outputs,
|
| 66 |
+
std::unordered_map<std::string, void*>& tensorNameToTensorPointer,
|
| 67 |
+
const GraphInfo_t& graphInfo,
|
| 68 |
+
std::unordered_map<std::string, size_t>& outputTensorsSize,
|
| 69 |
+
Qnn_ContextHandle_t contextHandle,
|
| 70 |
+
std::unordered_map<std::string, Qnn_Tensor_t *> sharedTensorMap
|
| 71 |
+
);
|
| 72 |
+
|
| 73 |
+
|
| 74 |
bool tearDownTensors(Qnn_Tensor_t* tensors, uint32_t tensorCount);
|
| 75 |
|
| 76 |
bool tearDownTensors(std::vector<Qnn_Tensor_t*>& tensors, uint32_t tensorCount);
|
|
|
|
| 157 |
|
| 158 |
std::unordered_set<void*>& getFreeTensorsPointerSet() { return m_freeTensorsPointerSet; }
|
| 159 |
|
| 160 |
+
// Functions to sync memory buffers for Read/Write using DmaBuf.
|
| 161 |
+
bool beforeWriteToBuffer(Qnn_Tensor_t *tensor) {
|
| 162 |
+
return m_bufferManager->beforeWriteToBuffer(tensor);
|
| 163 |
+
}
|
| 164 |
+
bool afterWriteToBuffer(Qnn_Tensor_t *tensor){
|
| 165 |
+
return m_bufferManager->afterWriteToBuffer(tensor);
|
| 166 |
+
}
|
| 167 |
+
bool beforeReadFromBuffer(Qnn_Tensor_t *tensor){
|
| 168 |
+
return m_bufferManager->beforeReadFromBuffer(tensor);
|
| 169 |
+
}
|
| 170 |
+
bool afterReadFromBuffer(Qnn_Tensor_t *tensor){
|
| 171 |
+
return m_bufferManager->afterReadFromBuffer(tensor);
|
| 172 |
+
}
|
| 173 |
+
|
| 174 |
private:
|
| 175 |
BufferAlloc m_bufferAlloc;
|
| 176 |
QNN_INTERFACE_VER_TYPE* m_qnnInterface;
|
Genie/Genie/src/qualla/engines/qnn-api/QnnApi.cpp
CHANGED
|
@@ -106,11 +106,17 @@ bool QnnApi::getContextConfigs(
|
|
| 106 |
) {
|
| 107 |
std::vector<QnnContext_Config_t*> contextConfigPtrsVec;
|
| 108 |
|
| 109 |
-
if (contextPriority
|
| 110 |
-
contextConfigPtrsVec.push_back((QnnContext_Config_t*)malloc(sizeof(QnnContext_Config_t)));
|
| 111 |
contextConfigPtrsVec.back()->option =
|
| 112 |
-
QnnContext_ConfigOption_t::
|
| 113 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
}
|
| 115 |
|
| 116 |
const char** graphNames = nullptr;
|
|
@@ -891,6 +897,8 @@ bool QnnApi::composeGraphs(
|
|
| 891 |
QnnLog_Level_t::QNN_LOG_LEVEL_VERBOSE
|
| 892 |
);
|
| 893 |
|
|
|
|
|
|
|
| 894 |
if (status == MODEL_NO_ERROR) {
|
| 895 |
return true;
|
| 896 |
}
|
|
@@ -1163,33 +1171,6 @@ bool QnnApi::createFromBinary(
|
|
| 1163 |
}
|
| 1164 |
}
|
| 1165 |
|
| 1166 |
-
QnnContext_Config_t** contextConfigs = nullptr;
|
| 1167 |
-
uint32_t contextConfigCount = 0;
|
| 1168 |
-
if (true != getContextConfigs(
|
| 1169 |
-
&contextConfigs,
|
| 1170 |
-
contextConfigCount,
|
| 1171 |
-
contextConfig.priority,
|
| 1172 |
-
graphSwitching,
|
| 1173 |
-
execSelectGraphs,
|
| 1174 |
-
loadSelectGraphs
|
| 1175 |
-
)) {
|
| 1176 |
-
QNN_ERROR("Couldn't populate context configs");
|
| 1177 |
-
return false;
|
| 1178 |
-
}
|
| 1179 |
-
|
| 1180 |
-
// Merge BE specific and agnostic configs
|
| 1181 |
-
QnnContext_Config_t** allContextConfigs{nullptr};
|
| 1182 |
-
if (true != mergeAllContextConfigs(
|
| 1183 |
-
&allContextConfigs,
|
| 1184 |
-
customConfigs,
|
| 1185 |
-
contextConfigs,
|
| 1186 |
-
customConfigCount,
|
| 1187 |
-
contextConfigCount
|
| 1188 |
-
)) {
|
| 1189 |
-
QNN_ERROR("Error merging custom and context configs");
|
| 1190 |
-
return false;
|
| 1191 |
-
}
|
| 1192 |
-
|
| 1193 |
if (nullptr == m_qnnSystemInterface.systemContextCreate ||
|
| 1194 |
nullptr == m_qnnSystemInterface.systemContextGetBinaryInfo ||
|
| 1195 |
nullptr == m_qnnSystemInterface.systemContextFree) {
|
|
@@ -1299,9 +1280,36 @@ bool QnnApi::createFromBinary(
|
|
| 1299 |
}
|
| 1300 |
|
| 1301 |
bool isIOBufferMgrInitialized = false;
|
| 1302 |
-
|
| 1303 |
for (size_t contextIdx = 0; contextIdx < cachedBinariesPathVec.size(); contextIdx++) {
|
| 1304 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1305 |
if (nullptr == m_qnnInterface.contextCreateFromBinary) {
|
| 1306 |
QNN_ERROR(
|
| 1307 |
"contextCreateFromBinaryFnHandle is nullptr for context index = %zu", contextIdx
|
|
@@ -1498,7 +1506,13 @@ bool QnnApi::createFromBinary(
|
|
| 1498 |
first_contextHandle = contextHandle;
|
| 1499 |
}
|
| 1500 |
#endif
|
| 1501 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1502 |
}
|
| 1503 |
|
| 1504 |
m_isContextCreated = true;
|
|
@@ -1507,14 +1521,6 @@ bool QnnApi::createFromBinary(
|
|
| 1507 |
"Initialized %u graphs from %lu contexts", m_graphsCount, cachedBinariesPathVec.size()
|
| 1508 |
);
|
| 1509 |
|
| 1510 |
-
if (true != freeContextConfigs(contextConfigs, contextConfigCount)) {
|
| 1511 |
-
QNN_ERROR("Couldn't free context configs");
|
| 1512 |
-
return false;
|
| 1513 |
-
}
|
| 1514 |
-
if (allContextConfigs) {
|
| 1515 |
-
free(allContextConfigs);
|
| 1516 |
-
}
|
| 1517 |
-
|
| 1518 |
if (nullptr != m_backendExtensions && m_backendExtensions->interface()) {
|
| 1519 |
if (!m_backendExtensions->interface()->afterCreateFromBinary()) {
|
| 1520 |
QNN_ERROR("Extensions Failure in afterCreateFromBinary()");
|
|
@@ -1599,34 +1605,6 @@ bool QnnApi::createFromBinaryListAsync(
|
|
| 1599 |
}
|
| 1600 |
}
|
| 1601 |
|
| 1602 |
-
|
| 1603 |
-
QnnContext_Config_t** contextConfigs = nullptr;
|
| 1604 |
-
uint32_t contextConfigCount = 0;
|
| 1605 |
-
if (true != getContextConfigs(
|
| 1606 |
-
&contextConfigs,
|
| 1607 |
-
contextConfigCount,
|
| 1608 |
-
contextConfig.priority,
|
| 1609 |
-
graphSwitching,
|
| 1610 |
-
execSelectGraphs,
|
| 1611 |
-
loadSelectGraphs
|
| 1612 |
-
)) {
|
| 1613 |
-
QNN_ERROR("Couldn't populate context configs");
|
| 1614 |
-
return false;
|
| 1615 |
-
}
|
| 1616 |
-
|
| 1617 |
-
// Merge BE specific and agnostic configs
|
| 1618 |
-
QnnContext_Config_t** allContextConfigs{nullptr};
|
| 1619 |
-
if (true != mergeAllContextConfigs(
|
| 1620 |
-
&allContextConfigs,
|
| 1621 |
-
customConfigs,
|
| 1622 |
-
contextConfigs,
|
| 1623 |
-
customConfigCount,
|
| 1624 |
-
contextConfigCount
|
| 1625 |
-
)) {
|
| 1626 |
-
QNN_ERROR("Error merging custom and context configs");
|
| 1627 |
-
return false;
|
| 1628 |
-
}
|
| 1629 |
-
|
| 1630 |
if (nullptr == m_qnnSystemInterface.systemContextCreate ||
|
| 1631 |
nullptr == m_qnnSystemInterface.systemContextGetBinaryInfo ||
|
| 1632 |
nullptr == m_qnnSystemInterface.systemContextFree) {
|
|
@@ -1642,6 +1620,8 @@ bool QnnApi::createFromBinaryListAsync(
|
|
| 1642 |
GraphInfo_t*** graphsInfo =
|
| 1643 |
(GraphInfo_t***)calloc(cachedBinariesPathVec.size(), sizeof(GraphInfo_t**));
|
| 1644 |
uint32_t graphsTotalNum = 0;
|
|
|
|
|
|
|
| 1645 |
|
| 1646 |
for (size_t contextIdx = 0; contextIdx < cachedBinariesPathVec.size(); contextIdx++) {
|
| 1647 |
auto _startPerContext = std::chrono::steady_clock::now();
|
|
@@ -1710,17 +1690,41 @@ bool QnnApi::createFromBinaryListAsync(
|
|
| 1710 |
m_qnnSystemInterface.systemContextFree(sysCtxHandle);
|
| 1711 |
sysCtxHandle = nullptr;
|
| 1712 |
|
| 1713 |
-
uint32_t
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1714 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1715 |
if (mmap_budget > 0) {
|
| 1716 |
QnnHtpContext_CustomConfig_t customConfigReadBudget;
|
| 1717 |
customConfigReadBudget.option = QNN_HTP_CONTEXT_CONFIG_OPTION_FILE_READ_MEMORY_BUDGET;
|
| 1718 |
customConfigReadBudget.fileReadMemoryBudgetInMb = mmap_budget;
|
| 1719 |
|
| 1720 |
QnnContext_Config_t** cfgs{nullptr};
|
| 1721 |
-
|
| 1722 |
uint32_t customConfigCountReadBudget = 1;
|
| 1723 |
-
|
| 1724 |
cfgs = (QnnContext_Config_t**)malloc(
|
| 1725 |
customConfigCountReadBudget * sizeof(QnnContext_Config_t*)
|
| 1726 |
);
|
|
@@ -1729,15 +1733,16 @@ bool QnnApi::createFromBinaryListAsync(
|
|
| 1729 |
cfgs[0]->customConfig =
|
| 1730 |
reinterpret_cast<QnnContext_CustomConfig_t>(&customConfigReadBudget);
|
| 1731 |
if (true != mergeAllContextConfigs(
|
| 1732 |
-
&allContextConfigs,
|
| 1733 |
cfgs,
|
| 1734 |
-
allContextConfigs,
|
| 1735 |
customConfigCountReadBudget,
|
| 1736 |
contextConfigCount + customConfigCount + customConfigCountSF
|
| 1737 |
)) {
|
| 1738 |
QNN_ERROR("Error merging custom and context configs");
|
| 1739 |
return false;
|
| 1740 |
}
|
|
|
|
| 1741 |
}
|
| 1742 |
|
| 1743 |
if (m_profileBackendHandle) {
|
|
@@ -1751,7 +1756,7 @@ bool QnnApi::createFromBinaryListAsync(
|
|
| 1751 |
.version = QNN_CONTEXT_PARAMS_VERSION_1,
|
| 1752 |
.v1 =
|
| 1753 |
QnnContext_ParamsV1_t{
|
| 1754 |
-
(const QnnContext_Config_t**)allContextConfigs,
|
| 1755 |
(const void*)buffer.get(),
|
| 1756 |
bufferSize,
|
| 1757 |
nullptr,
|
|
@@ -1778,18 +1783,15 @@ bool QnnApi::createFromBinaryListAsync(
|
|
| 1778 |
}
|
| 1779 |
|
| 1780 |
auto start = std::chrono::steady_clock::now();
|
| 1781 |
-
|
| 1782 |
-
|
| 1783 |
auto errCode = m_qnnInterface.contextCreateFromBinaryListAsync(
|
| 1784 |
m_backendHandle,
|
| 1785 |
m_deviceHandle,
|
| 1786 |
const_cast<const QnnContext_Params_t**>(context_params_list.data()),
|
| 1787 |
-
(const QnnContext_Config_t**)
|
| 1788 |
nullptr
|
| 1789 |
);
|
| 1790 |
-
|
| 1791 |
-
|
| 1792 |
auto stop = std::chrono::steady_clock::now();
|
|
|
|
| 1793 |
QNN_DEBUG(
|
| 1794 |
"Initializing %lu context with %u graphs took: %lld us",
|
| 1795 |
cachedBinariesPathVec.size(),
|
|
@@ -1824,26 +1826,24 @@ bool QnnApi::createFromBinaryListAsync(
|
|
| 1824 |
|
| 1825 |
m_isContextCreated = true;
|
| 1826 |
|
| 1827 |
-
if (true != freeContextConfigs(contextConfigs, contextConfigCount)) {
|
| 1828 |
-
QNN_ERROR("Couldn't free context configs");
|
| 1829 |
-
return false;
|
| 1830 |
-
}
|
| 1831 |
-
|
| 1832 |
if (true != freeContextParams(context_params_list.data(), cachedBinariesPathVec.size())) {
|
| 1833 |
QNN_ERROR("Couldn't free context params list");
|
| 1834 |
return false;
|
| 1835 |
}
|
| 1836 |
|
| 1837 |
-
if (allContextConfigs) {
|
| 1838 |
-
free(allContextConfigs);
|
| 1839 |
-
}
|
| 1840 |
-
|
| 1841 |
if (nullptr != m_backendExtensions && m_backendExtensions->interface()) {
|
| 1842 |
if (!m_backendExtensions->interface()->afterCreateContextsFromBinaryList()) {
|
| 1843 |
QNN_ERROR("Extensions Failure in afterCreateContextsFromBinaryList()");
|
| 1844 |
return false;
|
| 1845 |
}
|
| 1846 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1847 |
return true;
|
| 1848 |
}
|
| 1849 |
#endif
|
|
@@ -2543,6 +2543,64 @@ bool QnnApi::extractProfilingEvent(QnnProfile_EventId_t profileEventId) {
|
|
| 2543 |
return true;
|
| 2544 |
}
|
| 2545 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2546 |
bool QnnApi::applyBinarySection(uint32_t binIndex, std::string binSectionPath,bool useMmap,bool graphSwitch) {
|
| 2547 |
#if QUALLA_QNN_API_VERSION < 21700
|
| 2548 |
QNN_ERROR("LoRA adaptors require QNN SDK >= 2.25.1. Please update your libraries");
|
|
@@ -2650,7 +2708,7 @@ bool QnnApi::applyBinarySection(uint32_t binIndex, std::string binSectionPath,bo
|
|
| 2650 |
#endif
|
| 2651 |
}
|
| 2652 |
|
| 2653 |
-
bool QnnApi::updateIOEncodings(std::shared_ptr<uint8_t>& buffer,uint64_t bufferSize,uint32_t graphIndex){
|
| 2654 |
|
| 2655 |
QNN_DEBUG("Applying adapter Encodings");
|
| 2656 |
QnnSystemContext_Handle_t sysCtxHandle{nullptr};
|
|
@@ -2679,3 +2737,224 @@ bool QnnApi::updateIOEncodings(std::shared_ptr<uint8_t>& buffer,uint64_t buffer
|
|
| 2679 |
QNN_DEBUG(" updateIOEncodings success ");
|
| 2680 |
return true;
|
| 2681 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
) {
|
| 107 |
std::vector<QnnContext_Config_t*> contextConfigPtrsVec;
|
| 108 |
|
| 109 |
+
if (contextPriority == QNN_PRIORITY_UNDEFINED) {
|
| 110 |
+
contextConfigPtrsVec.push_back((QnnContext_Config_t *) malloc(sizeof(QnnContext_Config_t)));
|
| 111 |
contextConfigPtrsVec.back()->option =
|
| 112 |
+
QnnContext_ConfigOption_t::QNN_CONTEXT_CONFIG_UNDEFINED;
|
| 113 |
+
} else {
|
| 114 |
+
if (contextPriority != QNN_PRIORITY_DEFAULT) {
|
| 115 |
+
contextConfigPtrsVec.push_back((QnnContext_Config_t *) malloc(sizeof(QnnContext_Config_t)));
|
| 116 |
+
contextConfigPtrsVec.back()->option =
|
| 117 |
+
QnnContext_ConfigOption_t::QNN_CONTEXT_CONFIG_OPTION_PRIORITY;
|
| 118 |
+
contextConfigPtrsVec.back()->priority = contextPriority;
|
| 119 |
+
}
|
| 120 |
}
|
| 121 |
|
| 122 |
const char** graphNames = nullptr;
|
|
|
|
| 897 |
QnnLog_Level_t::QNN_LOG_LEVEL_VERBOSE
|
| 898 |
);
|
| 899 |
|
| 900 |
+
graphCountPerContext = m_graphsCount;
|
| 901 |
+
|
| 902 |
if (status == MODEL_NO_ERROR) {
|
| 903 |
return true;
|
| 904 |
}
|
|
|
|
| 1171 |
}
|
| 1172 |
}
|
| 1173 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1174 |
if (nullptr == m_qnnSystemInterface.systemContextCreate ||
|
| 1175 |
nullptr == m_qnnSystemInterface.systemContextGetBinaryInfo ||
|
| 1176 |
nullptr == m_qnnSystemInterface.systemContextFree) {
|
|
|
|
| 1280 |
}
|
| 1281 |
|
| 1282 |
bool isIOBufferMgrInitialized = false;
|
|
|
|
| 1283 |
for (size_t contextIdx = 0; contextIdx < cachedBinariesPathVec.size(); contextIdx++) {
|
| 1284 |
|
| 1285 |
+
// Create context configs for each context
|
| 1286 |
+
QnnContext_Config_t** contextConfigs = nullptr;
|
| 1287 |
+
uint32_t contextConfigCount = 0;
|
| 1288 |
+
if (true != getContextConfigs(
|
| 1289 |
+
&contextConfigs,
|
| 1290 |
+
contextConfigCount,
|
| 1291 |
+
contextConfig.priority,
|
| 1292 |
+
graphSwitching,
|
| 1293 |
+
execSelectGraphs,
|
| 1294 |
+
loadSelectGraphs
|
| 1295 |
+
)) {
|
| 1296 |
+
QNN_ERROR("Couldn't populate context configs");
|
| 1297 |
+
return false;
|
| 1298 |
+
}
|
| 1299 |
+
|
| 1300 |
+
// Merge BE specific and agnostic configs
|
| 1301 |
+
QnnContext_Config_t** allContextConfigs{nullptr};
|
| 1302 |
+
if (true != mergeAllContextConfigs(
|
| 1303 |
+
&allContextConfigs,
|
| 1304 |
+
customConfigs,
|
| 1305 |
+
contextConfigs,
|
| 1306 |
+
customConfigCount,
|
| 1307 |
+
contextConfigCount
|
| 1308 |
+
)) {
|
| 1309 |
+
QNN_ERROR("Error merging custom and context configs");
|
| 1310 |
+
return false;
|
| 1311 |
+
}
|
| 1312 |
+
|
| 1313 |
if (nullptr == m_qnnInterface.contextCreateFromBinary) {
|
| 1314 |
QNN_ERROR(
|
| 1315 |
"contextCreateFromBinaryFnHandle is nullptr for context index = %zu", contextIdx
|
|
|
|
| 1506 |
first_contextHandle = contextHandle;
|
| 1507 |
}
|
| 1508 |
#endif
|
| 1509 |
+
if (true != freeContextConfigs(contextConfigs, contextConfigCount)) {
|
| 1510 |
+
QNN_ERROR("Couldn't free context configs");
|
| 1511 |
+
return false;
|
| 1512 |
+
}
|
| 1513 |
+
if (allContextConfigs) {
|
| 1514 |
+
free(allContextConfigs);
|
| 1515 |
+
}
|
| 1516 |
}
|
| 1517 |
|
| 1518 |
m_isContextCreated = true;
|
|
|
|
| 1521 |
"Initialized %u graphs from %lu contexts", m_graphsCount, cachedBinariesPathVec.size()
|
| 1522 |
);
|
| 1523 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1524 |
if (nullptr != m_backendExtensions && m_backendExtensions->interface()) {
|
| 1525 |
if (!m_backendExtensions->interface()->afterCreateFromBinary()) {
|
| 1526 |
QNN_ERROR("Extensions Failure in afterCreateFromBinary()");
|
|
|
|
| 1605 |
}
|
| 1606 |
}
|
| 1607 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1608 |
if (nullptr == m_qnnSystemInterface.systemContextCreate ||
|
| 1609 |
nullptr == m_qnnSystemInterface.systemContextGetBinaryInfo ||
|
| 1610 |
nullptr == m_qnnSystemInterface.systemContextFree) {
|
|
|
|
| 1620 |
GraphInfo_t*** graphsInfo =
|
| 1621 |
(GraphInfo_t***)calloc(cachedBinariesPathVec.size(), sizeof(GraphInfo_t**));
|
| 1622 |
uint32_t graphsTotalNum = 0;
|
| 1623 |
+
std::vector<QnnContext_Config_t**> allContextConfigs{(unsigned int)cachedBinariesPathVec.size(), nullptr};
|
| 1624 |
+
std::vector<uint32_t> allContextConfigsSize{(unsigned int)cachedBinariesPathVec.size()};
|
| 1625 |
|
| 1626 |
for (size_t contextIdx = 0; contextIdx < cachedBinariesPathVec.size(); contextIdx++) {
|
| 1627 |
auto _startPerContext = std::chrono::steady_clock::now();
|
|
|
|
| 1690 |
m_qnnSystemInterface.systemContextFree(sysCtxHandle);
|
| 1691 |
sysCtxHandle = nullptr;
|
| 1692 |
|
| 1693 |
+
uint32_t contextConfigCount = 0;
|
| 1694 |
+
if (true != getContextConfigs(
|
| 1695 |
+
&allContextConfigs[contextIdx],
|
| 1696 |
+
contextConfigCount,
|
| 1697 |
+
contextConfig.priority,
|
| 1698 |
+
graphSwitching,
|
| 1699 |
+
execSelectGraphs,
|
| 1700 |
+
loadSelectGraphs
|
| 1701 |
+
)) {
|
| 1702 |
+
QNN_ERROR("Couldn't populate context configs");
|
| 1703 |
+
return false;
|
| 1704 |
+
}
|
| 1705 |
+
allContextConfigsSize[contextIdx] = contextConfigCount;
|
| 1706 |
|
| 1707 |
+
// Merge BE specific and agnostic configs
|
| 1708 |
+
if (true != mergeAllContextConfigs(
|
| 1709 |
+
&allContextConfigs[contextIdx],
|
| 1710 |
+
customConfigs,
|
| 1711 |
+
allContextConfigs[contextIdx],
|
| 1712 |
+
customConfigCount,
|
| 1713 |
+
contextConfigCount
|
| 1714 |
+
)) {
|
| 1715 |
+
QNN_ERROR("Error merging custom and context configs");
|
| 1716 |
+
return false;
|
| 1717 |
+
}
|
| 1718 |
+
allContextConfigsSize[contextIdx] += customConfigCount;
|
| 1719 |
+
|
| 1720 |
+
uint32_t customConfigCountSF = 0;
|
| 1721 |
if (mmap_budget > 0) {
|
| 1722 |
QnnHtpContext_CustomConfig_t customConfigReadBudget;
|
| 1723 |
customConfigReadBudget.option = QNN_HTP_CONTEXT_CONFIG_OPTION_FILE_READ_MEMORY_BUDGET;
|
| 1724 |
customConfigReadBudget.fileReadMemoryBudgetInMb = mmap_budget;
|
| 1725 |
|
| 1726 |
QnnContext_Config_t** cfgs{nullptr};
|
|
|
|
| 1727 |
uint32_t customConfigCountReadBudget = 1;
|
|
|
|
| 1728 |
cfgs = (QnnContext_Config_t**)malloc(
|
| 1729 |
customConfigCountReadBudget * sizeof(QnnContext_Config_t*)
|
| 1730 |
);
|
|
|
|
| 1733 |
cfgs[0]->customConfig =
|
| 1734 |
reinterpret_cast<QnnContext_CustomConfig_t>(&customConfigReadBudget);
|
| 1735 |
if (true != mergeAllContextConfigs(
|
| 1736 |
+
&allContextConfigs[contextIdx],
|
| 1737 |
cfgs,
|
| 1738 |
+
allContextConfigs[contextIdx],
|
| 1739 |
customConfigCountReadBudget,
|
| 1740 |
contextConfigCount + customConfigCount + customConfigCountSF
|
| 1741 |
)) {
|
| 1742 |
QNN_ERROR("Error merging custom and context configs");
|
| 1743 |
return false;
|
| 1744 |
}
|
| 1745 |
+
allContextConfigsSize[contextIdx] += customConfigCountReadBudget;
|
| 1746 |
}
|
| 1747 |
|
| 1748 |
if (m_profileBackendHandle) {
|
|
|
|
| 1756 |
.version = QNN_CONTEXT_PARAMS_VERSION_1,
|
| 1757 |
.v1 =
|
| 1758 |
QnnContext_ParamsV1_t{
|
| 1759 |
+
(const QnnContext_Config_t**)allContextConfigs[contextIdx],
|
| 1760 |
(const void*)buffer.get(),
|
| 1761 |
bufferSize,
|
| 1762 |
nullptr,
|
|
|
|
| 1783 |
}
|
| 1784 |
|
| 1785 |
auto start = std::chrono::steady_clock::now();
|
|
|
|
|
|
|
| 1786 |
auto errCode = m_qnnInterface.contextCreateFromBinaryListAsync(
|
| 1787 |
m_backendHandle,
|
| 1788 |
m_deviceHandle,
|
| 1789 |
const_cast<const QnnContext_Params_t**>(context_params_list.data()),
|
| 1790 |
+
(const QnnContext_Config_t**)customConfigs,
|
| 1791 |
nullptr
|
| 1792 |
);
|
|
|
|
|
|
|
| 1793 |
auto stop = std::chrono::steady_clock::now();
|
| 1794 |
+
|
| 1795 |
QNN_DEBUG(
|
| 1796 |
"Initializing %lu context with %u graphs took: %lld us",
|
| 1797 |
cachedBinariesPathVec.size(),
|
|
|
|
| 1826 |
|
| 1827 |
m_isContextCreated = true;
|
| 1828 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1829 |
if (true != freeContextParams(context_params_list.data(), cachedBinariesPathVec.size())) {
|
| 1830 |
QNN_ERROR("Couldn't free context params list");
|
| 1831 |
return false;
|
| 1832 |
}
|
| 1833 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1834 |
if (nullptr != m_backendExtensions && m_backendExtensions->interface()) {
|
| 1835 |
if (!m_backendExtensions->interface()->afterCreateContextsFromBinaryList()) {
|
| 1836 |
QNN_ERROR("Extensions Failure in afterCreateContextsFromBinaryList()");
|
| 1837 |
return false;
|
| 1838 |
}
|
| 1839 |
}
|
| 1840 |
+
|
| 1841 |
+
for (size_t contextIdx = 0; contextIdx < cachedBinariesPathVec.size(); contextIdx++) {
|
| 1842 |
+
if (true != freeContextConfigs(allContextConfigs[contextIdx], allContextConfigsSize[contextIdx])) {
|
| 1843 |
+
QNN_ERROR("Couldn't free context configs");
|
| 1844 |
+
return false;
|
| 1845 |
+
}
|
| 1846 |
+
}
|
| 1847 |
return true;
|
| 1848 |
}
|
| 1849 |
#endif
|
|
|
|
| 2543 |
return true;
|
| 2544 |
}
|
| 2545 |
|
| 2546 |
+
// Applies a single updatable (LoRA) binary section from binSectionPath to the
// graph identified by graphId. The owning context is located by integer
// division, assuming every context holds getGraphCountPerContext() graphs.
// Returns false on any validation, I/O, or QNN API failure.
bool QnnApi::applyBinarySection(uint32_t graphId, std::string binSectionPath) {
#if QUALLA_QNN_API_VERSION < 21700
  QNN_ERROR("LoRA adaptors require QNN SDK >= 2.25.1. Please update your libraries");
  return false;
#else
  // assumption splitNum from 0
  QNN_DEBUG("QnnApi::applyBinarySection %d ", graphId);
  if (nullptr == m_qnnInterface.contextApplyBinarySection) {
    QNN_ERROR("contextApplyBinarySection Interface not suported!!");
    return false;
  }
  if (graphId >= m_graphsCount) {
    QNN_ERROR(" Passed split %d base Model graphcount %d ", graphId, m_graphsCount);
    return false;
  }

  // Read the entire section file into a heap buffer.
  uint64_t bufferSize = getFileSize(binSectionPath);
  if (0 == bufferSize) {
    // BUGFIX: previously a missing/empty file produced a zero-length
    // allocation that was handed on to QNN instead of failing early.
    QNN_ERROR("Failed to read binary data for context index = %d", graphId);
    return false;
  }
  // BUGFIX: supply the array deleter. The original
  // std::shared_ptr<uint8_t>(new uint8_t[n]) destroyed the buffer with
  // `delete` instead of `delete[]` — undefined behavior. This now matches
  // the allocation style used by createFromBinary().
  std::shared_ptr<uint8_t> buffer(
      new uint8_t[bufferSize], std::default_delete<uint8_t[]>()
  );
  if (true != readBinaryFromFile(binSectionPath, buffer.get(), bufferSize)) {
    QNN_ERROR("Failed to read binary data for context index = %d", graphId);
    return false;
  }

  // Wrap the raw bytes in the QNN buffer descriptor expected by the API.
  QnnContext_Buffer_t qnnBuffer;
  qnnBuffer.version = QNN_CONTEXT_BUFFER_VERSION_1;
  qnnBuffer.v1.memType = QNN_CONTEXTMEMTYPE_RAW;
  qnnBuffer.v1.binaryBuf.dataSize = bufferSize;
  qnnBuffer.v1.binaryBuf.data = static_cast<void*>(buffer.get());

  auto graphCountPerContext = getGraphCountPerContext();
  if (graphCountPerContext <= 0) {
    QNN_ERROR(" graphCountPerContext is <=0 ");
    return false;
  }

  // Contexts hold graphCountPerContext graphs each, so the context owning
  // graphId is found by integer division.
  auto contextHandle = m_contextVec[graphId / graphCountPerContext];
  auto graphHandle = m_graphsInfo[graphId]->graph;
  if (contextHandle == nullptr || graphHandle == nullptr) {
    QNN_ERROR(" contexthandle or graph handle is null for patch no = %d ", graphId);
    return false;
  }

  auto errorCode = m_qnnInterface.contextApplyBinarySection(
      contextHandle,
      graphHandle,
      QNN_CONTEXT_SECTION_UPDATABLE,
      &qnnBuffer,
      nullptr, // profile handle is null
      nullptr  // signal handle is null
  );
  if (errorCode != QNN_SUCCESS) {
    // BUGFIX: cast the error handle to size_t so it matches the %zu
    // conversion specifier (mismatched printf arguments are UB).
    QNN_ERROR(
        "Could not Apply Patch for graph = %d errocode = %zu ",
        graphId,
        (size_t)errorCode
    );
    return false;
  }
  return true;
#endif
}
|
| 2603 |
+
|
| 2604 |
bool QnnApi::applyBinarySection(uint32_t binIndex, std::string binSectionPath,bool useMmap,bool graphSwitch) {
|
| 2605 |
#if QUALLA_QNN_API_VERSION < 21700
|
| 2606 |
QNN_ERROR("LoRA adaptors require QNN SDK >= 2.25.1. Please update your libraries");
|
|
|
|
| 2708 |
#endif
|
| 2709 |
}
|
| 2710 |
|
| 2711 |
+
bool QnnApi::updateIOEncodings(std::shared_ptr<uint8_t>& buffer,uint64_t bufferSize,uint32_t graphIndex) {
|
| 2712 |
|
| 2713 |
QNN_DEBUG("Applying adapter Encodings");
|
| 2714 |
QnnSystemContext_Handle_t sysCtxHandle{nullptr};
|
|
|
|
| 2737 |
QNN_DEBUG(" updateIOEncodings success ");
|
| 2738 |
return true;
|
| 2739 |
}
|
| 2740 |
+
|
| 2741 |
+
// This is a light weight function of existing ::createFromBinary, used for
// GPU execution to avoid conflicts with HTP use-case and for better readability.
//
// For each cached context binary path: read the file into memory, inspect the
// binary metadata via the QNN system interface, create a QNN context from the
// bytes, and register its graphs into m_graphsInfo / m_contextMap / m_contextVec.
// Finally, retrieve a graph handle for every registered graph.
// Returns false (after freeing any graph metadata accumulated so far) on the
// first failure.
bool QnnApi::createFromBinary(
    std::vector<std::string> cachedBinariesPathVec
) {
  auto _start = std::chrono::steady_clock::now();

  // The system interface is required to parse context binary metadata.
  if (nullptr == m_qnnSystemInterface.systemContextCreate ||
      nullptr == m_qnnSystemInterface.systemContextGetBinaryInfo ||
      nullptr == m_qnnSystemInterface.systemContextFree) {
    QNN_ERROR("QNN System function pointers are not populated.");
    return false;
  }

  // NOTE(review): -1 appears to mean "not yet known"; the first parsed binary
  // fixes the expected graphs-per-context count — confirm getGraphCountPerContext().
  graphCountPerContext = getGraphCountPerContext();

  for (size_t contextIdx = 0; contextIdx < cachedBinariesPathVec.size(); contextIdx++) {
    uint64_t bufferSize{0};
    std::shared_ptr<uint8_t> buffer{nullptr};
    uint32_t graphsCount;

    // read serialized binary into a byte buffer
    bufferSize = getFileSize(cachedBinariesPathVec[contextIdx]);
    if (0 == bufferSize) {
      QNN_ERROR(
          "Received path to an empty file for context index = %zu. Nothing to deserialize.",
          contextIdx
      );
      return false;
    }

    // Array deleter: the buffer is allocated with new[].
    buffer = std::shared_ptr<uint8_t>(
        new uint8_t[bufferSize], std::default_delete<uint8_t[]>()
    );
    if (!buffer) {
      QNN_ERROR("Failed to allocate memory for context index = %zu", contextIdx);
      return false;
    }
    if (true !=
        readBinaryFromFile(cachedBinariesPathVec[contextIdx], buffer.get(), bufferSize)) {
      QNN_ERROR("Failed to read binary data for context index = %zu", contextIdx);
      return false;
    }

    // inspect binary info
    QnnSystemContext_Handle_t sysCtxHandle{nullptr};
    if (QNN_SUCCESS != m_qnnSystemInterface.systemContextCreate(&sysCtxHandle)) {
      QNN_ERROR("Could not create system handle for context index = %zu", contextIdx);
      return false;
    }

    const QnnSystemContext_BinaryInfo_t* binaryInfo{nullptr};
    Qnn_ContextBinarySize_t binaryInfoSize{0};

    if (QNN_SUCCESS != m_qnnSystemInterface.systemContextGetBinaryInfo(
                           sysCtxHandle,
                           static_cast<void*>(buffer.get()),
                           bufferSize,
                           &binaryInfo,
                           &binaryInfoSize
                       )) {
      QNN_ERROR("Failed to get context binary info for context index = %zu", contextIdx);
      return false;
    }

    // Copy graph metadata (names, tensor specs) out of the parsed binary info.
    GraphInfo_t** graphsInfo;
    if (!copyMetadataToGraphsInfo(binaryInfo, graphsInfo, graphsCount)) {
      QNN_ERROR("Failed to copy metadata for graph index = %zu", contextIdx);
      freeGraphsInfo(&graphsInfo, graphsCount);
      // Also release graphs accumulated from earlier iterations.
      if (contextIdx > 0) freeGraphsInfo(&m_graphsInfo, m_graphsCount);
      return false;
    }

    if (graphCountPerContext == -1) {
      // First binary seen: size the flat m_graphsInfo table for all contexts,
      // assuming every context carries the same number of graphs.
      graphCountPerContext = graphsCount;
      m_graphsInfo = (GraphInfo_t**)calloc(
          graphCountPerContext * cachedBinariesPathVec.size(), sizeof(GraphInfo_t*)
      );
    } else if (graphCountPerContext != graphsCount) {
      QNN_ERROR(
          "Different len(graphs) found in different context files. Found %u vs %u",
          graphsCount,
          graphCountPerContext
      );
      freeGraphsInfo(&graphsInfo, graphsCount);
      if (contextIdx > 0) freeGraphsInfo(&m_graphsInfo, m_graphsCount);
      return false;
    }
    m_qnnSystemInterface.systemContextFree(sysCtxHandle);
    sysCtxHandle = nullptr;

    if (nullptr == m_qnnInterface.contextCreateFromBinary) {
      QNN_ERROR(
          "contextCreateFromBinaryFnHandle is nullptr for context index = %zu", contextIdx
      );
      freeGraphsInfo(&graphsInfo, graphsCount);
      if (contextIdx > 0) freeGraphsInfo(&m_graphsInfo, m_graphsCount);
      return false;
    }
    Qnn_ContextHandle_t contextHandle{nullptr};
    auto _stop = std::chrono::steady_clock::now();
    QNN_DEBUG(
        "Loading contexts[%lu] took: %lld us",
        contextIdx,
        std::chrono::duration_cast<std::chrono::microseconds>(_stop - _start).count()
    );

    auto start = std::chrono::steady_clock::now();

    // Deserialize the context on the backend/device; no per-context configs
    // are passed here (third argument is nullptr).
    auto errCode = m_qnnInterface.contextCreateFromBinary(
        m_backendHandle,
        m_deviceHandle,
        nullptr,
        (const void*)buffer.get(),
        bufferSize,
        &contextHandle,
        nullptr // profile handle

    );

    if (errCode != QNN_SUCCESS) {
      QNN_ERROR(
          "Could not create context from binary for context index = %zu : err %d",
          contextIdx,
          (int)errCode
      );
      freeGraphsInfo(&graphsInfo, graphsCount);
      if (contextIdx > 0) freeGraphsInfo(&m_graphsInfo, m_graphsCount);
      return false;
    }

    auto stop = std::chrono::steady_clock::now();
    QNN_DEBUG(
        "Initializing context[%lu] with %u graphs took: %lld us",
        contextIdx,
        graphsCount,
        std::chrono::duration_cast<std::chrono::microseconds>(stop - start).count()
    );

    for (int n_graph = 0; n_graph < graphsCount; n_graph++) {
      // Allocate inputTensors and outputTensors
      GraphInfo_t* cur_graph = graphsInfo[n_graph];

      // Flatten graphs from all contexts into one table and remember which
      // context owns each graph.
      m_graphsInfo[m_graphsCount++] = cur_graph;
      m_contextMap[cur_graph] = contextHandle;
    }
    m_contextVec.push_back(contextHandle);
  }

  m_isContextCreated = true;

  QNN_DEBUG(
      "Initialized %u graphs from %lu contexts", m_graphsCount, cachedBinariesPathVec.size()
  );

  if (nullptr == m_qnnInterface.graphRetrieve) {
    QNN_ERROR("graphRetrieveFnHandle is nullptr.");
    freeGraphsInfo(&m_graphsInfo, m_graphsCount);
    return false;
  }

  // Resolve a live graph handle for every graph by name, from its owning
  // context (located via integer division by graphCountPerContext).
  for (size_t graphIdx = 0; graphIdx < m_graphsCount; graphIdx++) {
    if (!m_graphsInfo || QNN_SUCCESS != m_qnnInterface.graphRetrieve(
                             m_contextVec[graphIdx / graphCountPerContext],
                             m_graphsInfo[graphIdx]->graphName,
                             &(m_graphsInfo[graphIdx]->graph)
                         )) {
      QNN_ERROR("Unable to retrieve graph handle for graph index = %zu", graphIdx);
      freeGraphsInfo(&m_graphsInfo, m_graphsCount);
      return false;
    }
  }

  return true;
}
|
| 2916 |
+
|
| 2917 |
+
bool QnnApi::initialize(
|
| 2918 |
+
std::string backendPath,
|
| 2919 |
+
std::vector<std::string> modelPathOrCachedBinaryPath
|
| 2920 |
+
) {
|
| 2921 |
+
if (modelPathOrCachedBinaryPath.size() != 1) {
|
| 2922 |
+
QNN_ERROR("Multiple Files not supported for now!!");
|
| 2923 |
+
return false;
|
| 2924 |
+
}
|
| 2925 |
+
|
| 2926 |
+
if (false == getQnnInterface(backendPath)) {
|
| 2927 |
+
QNN_ERROR("Qnn getQnnInterface FAILED!");
|
| 2928 |
+
return false;
|
| 2929 |
+
}
|
| 2930 |
+
|
| 2931 |
+
const std::string systemLibraryPath = "libQnnSystem.so";
|
| 2932 |
+
if (false == getQnnSystemInterface(systemLibraryPath)) {
|
| 2933 |
+
QNN_ERROR("Qnn getQnnSystemInterface FAILED!");
|
| 2934 |
+
return false;
|
| 2935 |
+
}
|
| 2936 |
+
|
| 2937 |
+
QnnLog_Level_t logLevel = QNN_LOG_LEVEL_INFO;
|
| 2938 |
+
if (false == initializeLogging(logLevel, false)) {
|
| 2939 |
+
QNN_ERROR("Unable to Initialize logging in backend");
|
| 2940 |
+
return false;
|
| 2941 |
+
}
|
| 2942 |
+
|
| 2943 |
+
// Initialize Backend
|
| 2944 |
+
if (false == initializeBackend()) {
|
| 2945 |
+
QNN_ERROR("Qnn initializeBackend FAILED!");
|
| 2946 |
+
return false;
|
| 2947 |
+
}
|
| 2948 |
+
|
| 2949 |
+
if (false == createFromBinary(modelPathOrCachedBinaryPath)) {
|
| 2950 |
+
QNN_ERROR("Create From Binary FAILED!");
|
| 2951 |
+
return false;
|
| 2952 |
+
}
|
| 2953 |
+
|
| 2954 |
+
for (size_t graphIdx = 0; graphIdx < m_graphsCount; graphIdx++) {
|
| 2955 |
+
m_graphNameToIndex[m_graphsInfo[graphIdx]->graphName] = graphIdx;
|
| 2956 |
+
}
|
| 2957 |
+
QNN_DEBUG("Model Initialized");
|
| 2958 |
+
|
| 2959 |
+
return true;
|
| 2960 |
+
}
|
Genie/Genie/src/qualla/engines/qnn-api/QnnApi.hpp
CHANGED
|
@@ -370,6 +370,8 @@ class QnnApi {
|
|
| 370 |
|
| 371 |
bool applyBinarySection(uint32_t binIndex, std::string binSectionPath,bool useMmap,bool graphSwitch);
|
| 372 |
|
|
|
|
|
|
|
| 373 |
QNN_INTERFACE_VER_TYPE* getQnnInterfaceVer() { return &m_qnnInterface; };
|
| 374 |
GraphInfo_t**& getGraphsInfo() { return m_graphsInfo; };
|
| 375 |
uint32_t getGraphsCount() { return m_graphsCount; };
|
|
@@ -426,4 +428,11 @@ class QnnApi {
|
|
| 426 |
bool updateIOEncodings(std::shared_ptr<uint8_t>& buffer,
|
| 427 |
uint64_t bufferSize,
|
| 428 |
uint32_t graphIndex);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 429 |
};
|
|
|
|
| 370 |
|
| 371 |
bool applyBinarySection(uint32_t binIndex, std::string binSectionPath,bool useMmap,bool graphSwitch);
|
| 372 |
|
| 373 |
+
bool applyBinarySection(uint32_t graphId, std::string binSectionPath);
|
| 374 |
+
|
| 375 |
QNN_INTERFACE_VER_TYPE* getQnnInterfaceVer() { return &m_qnnInterface; };
|
| 376 |
GraphInfo_t**& getGraphsInfo() { return m_graphsInfo; };
|
| 377 |
uint32_t getGraphsCount() { return m_graphsCount; };
|
|
|
|
| 428 |
bool updateIOEncodings(std::shared_ptr<uint8_t>& buffer,
|
| 429 |
uint64_t bufferSize,
|
| 430 |
uint32_t graphIndex);
|
| 431 |
+
|
| 432 |
+
bool createFromBinary(std::vector<std::string> cachedBinariesPathVec);
|
| 433 |
+
|
| 434 |
+
bool initialize(
|
| 435 |
+
std::string backendPath,
|
| 436 |
+
std::vector<std::string> modelPathOrCachedBinaryPath
|
| 437 |
+
);
|
| 438 |
};
|
Genie/Genie/src/qualla/engines/qnn-api/qnn-utils.hpp
CHANGED
|
@@ -46,14 +46,14 @@ bool writeRawData(void* tensorData, size_t tensorSize, const std::filesystem::pa
|
|
| 46 |
bool readRawData(void* tensorData, size_t tensorSize, const std::filesystem::path& path);
|
| 47 |
|
| 48 |
struct Dims {
|
| 49 |
-
|
| 50 |
-
|
| 51 |
Dims() : height(0), width(0), channel(0), bitWidth(0) {}
|
| 52 |
-
Dims(
|
| 53 |
: height(height), width(width), channel(channel), bitWidth(bitWidth) {}
|
| 54 |
Dims(std::vector<size_t>& tDims)
|
| 55 |
-
: height((
|
| 56 |
-
bitWidth((
|
| 57 |
// Hack to mix batch dimension
|
| 58 |
if (tDims[0] != 1 && tDims[1] == 1) height = tDims[0];
|
| 59 |
if (tDims[0] > 1 && tDims[1] != 1) batch = tDims[0];
|
|
|
|
| 46 |
bool readRawData(void* tensorData, size_t tensorSize, const std::filesystem::path& path);
|
| 47 |
|
| 48 |
struct Dims {
|
| 49 |
+
uint32_t batch = 1;
|
| 50 |
+
uint32_t height, width, channel, bitWidth;
|
| 51 |
Dims() : height(0), width(0), channel(0), bitWidth(0) {}
|
| 52 |
+
Dims(uint32_t height, uint32_t width, uint32_t channel, uint32_t bitWidth)
|
| 53 |
: height(height), width(width), channel(channel), bitWidth(bitWidth) {}
|
| 54 |
Dims(std::vector<size_t>& tDims)
|
| 55 |
+
: height((uint32_t)tDims[1]), width((uint32_t)tDims[2]), channel((uint32_t)tDims[3]),
|
| 56 |
+
bitWidth((uint32_t)tDims[4]) {
|
| 57 |
// Hack to mix batch dimension
|
| 58 |
if (tDims[0] != 1 && tDims[1] == 1) height = tDims[0];
|
| 59 |
if (tDims[0] > 1 && tDims[1] != 1) batch = tDims[0];
|
Genie/Genie/src/qualla/engines/qnn-cpu.cpp
CHANGED
|
@@ -55,8 +55,10 @@ class QnnCpuEngine : public Engine {
|
|
| 55 |
virtual bool updateKV(size_t n_past) override;
|
| 56 |
virtual bool updateKV(size_t n_past, const std::vector<bool>& selected) override;
|
| 57 |
virtual bool save(const std::string& name) override;
|
| 58 |
-
virtual size_t restore(const std::string& name) override;
|
| 59 |
virtual void reset() override;
|
|
|
|
|
|
|
| 60 |
};
|
| 61 |
|
| 62 |
namespace fs = std::filesystem;
|
|
@@ -98,7 +100,40 @@ QnnCpuEngine::QnnCpuEngine(Context& ctx, const qualla::json& json) : Engine(ctx,
|
|
| 98 |
p.use_mmap = conf.optional<bool>("use-mmap", false);
|
| 99 |
p.ctx_size = _ctx.size();
|
| 100 |
p.n_vocab_size = _ctx.n_vocab();
|
| 101 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
_model = std::make_unique<QnnCpuModel>(_env, p);
|
| 103 |
|
| 104 |
// Load model
|
|
@@ -211,7 +246,7 @@ size_t QnnCpuEngine::process(
|
|
| 211 |
);
|
| 212 |
}
|
| 213 |
|
| 214 |
-
size_t QnnCpuEngine::restore(const std::string& name) {
|
| 215 |
fs::path cache_path = std::filesystem::path(name) / fmt::format("kv-cache.{}.qnn-cpu", _role);
|
| 216 |
return _model->loadKVCache(cache_path.string());
|
| 217 |
}
|
|
@@ -226,6 +261,23 @@ void QnnCpuEngine::reset() {
|
|
| 226 |
updateKV(0);
|
| 227 |
}
|
| 228 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
// Registrator instance
|
| 230 |
static OnLoad regy([]() {
|
| 231 |
Engine::__register("qnn-cpu", [](Context& ctx, const json& conf) {
|
|
|
|
| 55 |
virtual bool updateKV(size_t n_past) override;
|
| 56 |
virtual bool updateKV(size_t n_past, const std::vector<bool>& selected) override;
|
| 57 |
virtual bool save(const std::string& name) override;
|
| 58 |
+
virtual size_t restore(const std::string& name, bool chooseHigherVariant) override;
|
| 59 |
virtual void reset() override;
|
| 60 |
+
virtual bool applyLoraAdapter(std::string lora_adapter_name) override;
|
| 61 |
+
virtual bool applyLoraStrength(std::string tensor_name, float tensor_val) override;
|
| 62 |
};
|
| 63 |
|
| 64 |
namespace fs = std::filesystem;
|
|
|
|
| 100 |
p.use_mmap = conf.optional<bool>("use-mmap", false);
|
| 101 |
p.ctx_size = _ctx.size();
|
| 102 |
p.n_vocab_size = _ctx.n_vocab();
|
| 103 |
+
p.lora_config_type = LoraConfigType::LORA_DISABLE;
|
| 104 |
+
qualla::json lora_conf = conf.optional<qualla::json>("lora", {});
|
| 105 |
+
if (lora_conf.size() != 0) {
|
| 106 |
+
p.lora_config_type = LoraConfigType::LORA_ADAPTER_WEIGHT_ENABLE;
|
| 107 |
+
if (lora_conf.is_array()) {
|
| 108 |
+
for (auto lc : lora_conf) {
|
| 109 |
+
std::string lnm = lc["adapter-name"];
|
| 110 |
+
p.lora_config[lnm].lora_name = lnm;
|
| 111 |
+
p.lora_config[lnm].alpha_tensor_name = lc["alpha-tensor-name"];
|
| 112 |
+
p.lora_config[lnm].alpha_tensor_val = 0.0f;
|
| 113 |
+
if(lc.contains("alpha-tensor-value")){
|
| 114 |
+
p.lora_config[lnm].alpha_tensor_val = lc["alpha-tensor-value"];
|
| 115 |
+
}
|
| 116 |
+
std::string basedir = "";
|
| 117 |
+
if(lc.contains("binsection-basedir")){
|
| 118 |
+
basedir = lc["binsection-basedir"];
|
| 119 |
+
}
|
| 120 |
+
uint32_t n = lc["bin-sections"].size();
|
| 121 |
+
for (uint32_t i = 0; i < n; i++) {
|
| 122 |
+
auto binSec = lc["bin-sections"].get<std::vector<std::string>>();
|
| 123 |
+
fs::path binsection_path = fs::path(binSec[i]);
|
| 124 |
+
if (binsection_path.is_relative()) binsection_path = basedir / fs::path(binSec[i]);
|
| 125 |
+
if (!fs::is_regular_file(binsection_path)) {
|
| 126 |
+
__ERROR("qnn-cpu: Can't access Lora binsection adapter : {}",
|
| 127 |
+
binsection_path.string());
|
| 128 |
+
throw std::runtime_error(
|
| 129 |
+
"qnn-cpu: Can't open adapter file : " + binsection_path.string()
|
| 130 |
+
);
|
| 131 |
+
}
|
| 132 |
+
p.lora_config[lnm].binsection_list.push_back(binsection_path.string());
|
| 133 |
+
}
|
| 134 |
+
}
|
| 135 |
+
}
|
| 136 |
+
}
|
| 137 |
_model = std::make_unique<QnnCpuModel>(_env, p);
|
| 138 |
|
| 139 |
// Load model
|
|
|
|
| 246 |
);
|
| 247 |
}
|
| 248 |
|
| 249 |
+
size_t QnnCpuEngine::restore(const std::string& name, bool chooseHigherVariant) {
|
| 250 |
fs::path cache_path = std::filesystem::path(name) / fmt::format("kv-cache.{}.qnn-cpu", _role);
|
| 251 |
return _model->loadKVCache(cache_path.string());
|
| 252 |
}
|
|
|
|
| 261 |
updateKV(0);
|
| 262 |
}
|
| 263 |
|
| 264 |
+
// For Lora
|
| 265 |
+
bool QnnCpuEngine::applyLoraAdapter(std::string lora_adapter_name) {
|
| 266 |
+
if (!_model) {
|
| 267 |
+
__ERROR("qnn-cpu: applyLoraAdapter failed, model not initialized");
|
| 268 |
+
return false;
|
| 269 |
+
}
|
| 270 |
+
return _model->applyLoraAdapter(lora_adapter_name);
|
| 271 |
+
}
|
| 272 |
+
|
| 273 |
+
bool QnnCpuEngine::applyLoraStrength(std::string tensor_name, float tensor_val) {
|
| 274 |
+
if (!_model) {
|
| 275 |
+
__ERROR("qnn-cpu: applyLoraStrength failed, model not initialized");
|
| 276 |
+
return false;
|
| 277 |
+
}
|
| 278 |
+
return _model->applyLoraStrength(tensor_name, tensor_val);
|
| 279 |
+
}
|
| 280 |
+
|
| 281 |
// Registrator instance
|
| 282 |
static OnLoad regy([]() {
|
| 283 |
Engine::__register("qnn-cpu", [](Context& ctx, const json& conf) {
|
Genie/Genie/src/qualla/engines/qnn-cpu/cpu-model.cpp
CHANGED
|
@@ -61,6 +61,12 @@ QnnCpuModel::QnnCpuModel(Env& env, const Params& params)
|
|
| 61 |
m_output_dim.push_back(m_numLogits);
|
| 62 |
m_output_dim.push_back(m_embd);
|
| 63 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
}
|
| 65 |
|
| 66 |
QnnCpuModel::~QnnCpuModel() {
|
|
@@ -383,6 +389,7 @@ bool QnnCpuModel::initializeTensorPointers() {
|
|
| 383 |
t_input_ids_k_cache = &input_specs["x3"];
|
| 384 |
t_input_ids_v_cache = &input_specs["x4"];
|
| 385 |
t_input_ids_n_past = &input_specs["x5"];
|
|
|
|
| 386 |
|
| 387 |
auto& output_specs = m_output_specs[model_order.back()];
|
| 388 |
t_logits = &output_specs["output_genAI"];
|
|
@@ -406,6 +413,7 @@ void QnnCpuModel::setupInputTensors(const std::vector<int32_t>& tokens, bool run
|
|
| 406 |
uint32_t* input_id_num_token_buffer = (uint32_t*)getBuffer(t_input_ids_num_token);
|
| 407 |
uint32_t* input_id_reset_kvcache_buffer = (uint32_t*)getBuffer(t_input_ids_reset_kvcache);
|
| 408 |
uint32_t* input_id_n_past_buffer = (uint32_t*)getBuffer(t_input_ids_n_past);
|
|
|
|
| 409 |
|
| 410 |
uint32_t size = 1;
|
| 411 |
for (auto dim : m_input_dim) {
|
|
@@ -420,6 +428,7 @@ void QnnCpuModel::setupInputTensors(const std::vector<int32_t>& tokens, bool run
|
|
| 420 |
std::memcpy(input_id_buffer, tokens.data(), tokens.size() * sizeof(uint32_t));
|
| 421 |
*input_id_num_token_buffer = tokens.size();
|
| 422 |
*input_id_n_past_buffer = m_nPast;
|
|
|
|
| 423 |
|
| 424 |
auto stop = std::chrono::steady_clock::now();
|
| 425 |
// QnnUtils::logProfile("setupInputTensors (cpp) took", start, stop);
|
|
@@ -589,6 +598,48 @@ size_t QnnCpuModel::getDequantLogits(std::vector<float>& dequant_logits, bool lo
|
|
| 589 |
return logits_all? prev_run.num_tokens_processed : 1;
|
| 590 |
}
|
| 591 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 592 |
// TODO: implement save/restore
|
| 593 |
size_t QnnCpuModel::loadKVCache(const std::string& load_path) {
|
| 594 |
//TO read the cache file into KV tensor
|
|
|
|
| 61 |
m_output_dim.push_back(m_numLogits);
|
| 62 |
m_output_dim.push_back(m_embd);
|
| 63 |
}
|
| 64 |
+
m_loraConfigType = params.lora_config_type;
|
| 65 |
+
m_lora_alpha_val = 1.0f;
|
| 66 |
+
|
| 67 |
+
if (m_loraConfigType == LoraConfigType::LORA_ADAPTER_WEIGHT_ENABLE) {
|
| 68 |
+
m_loraConfig.insert(params.lora_config.begin(), params.lora_config.end());
|
| 69 |
+
}
|
| 70 |
}
|
| 71 |
|
| 72 |
QnnCpuModel::~QnnCpuModel() {
|
|
|
|
| 389 |
t_input_ids_k_cache = &input_specs["x3"];
|
| 390 |
t_input_ids_v_cache = &input_specs["x4"];
|
| 391 |
t_input_ids_n_past = &input_specs["x5"];
|
| 392 |
+
t_input_lora_alpha = &input_specs["x6"];
|
| 393 |
|
| 394 |
auto& output_specs = m_output_specs[model_order.back()];
|
| 395 |
t_logits = &output_specs["output_genAI"];
|
|
|
|
| 413 |
uint32_t* input_id_num_token_buffer = (uint32_t*)getBuffer(t_input_ids_num_token);
|
| 414 |
uint32_t* input_id_reset_kvcache_buffer = (uint32_t*)getBuffer(t_input_ids_reset_kvcache);
|
| 415 |
uint32_t* input_id_n_past_buffer = (uint32_t*)getBuffer(t_input_ids_n_past);
|
| 416 |
+
float* input_id_lora_alpha = (float*)getBuffer(t_input_lora_alpha);
|
| 417 |
|
| 418 |
uint32_t size = 1;
|
| 419 |
for (auto dim : m_input_dim) {
|
|
|
|
| 428 |
std::memcpy(input_id_buffer, tokens.data(), tokens.size() * sizeof(uint32_t));
|
| 429 |
*input_id_num_token_buffer = tokens.size();
|
| 430 |
*input_id_n_past_buffer = m_nPast;
|
| 431 |
+
*input_id_lora_alpha = m_lora_alpha_val;
|
| 432 |
|
| 433 |
auto stop = std::chrono::steady_clock::now();
|
| 434 |
// QnnUtils::logProfile("setupInputTensors (cpp) took", start, stop);
|
|
|
|
| 598 |
return logits_all? prev_run.num_tokens_processed : 1;
|
| 599 |
}
|
| 600 |
|
| 601 |
+
bool QnnCpuModel::applyBinarySections(std::vector<std::string>& binsection_list) {
|
| 602 |
+
//apply binary section for lora config
|
| 603 |
+
for (int i = 0; i < binsection_list.size(); i++) {
|
| 604 |
+
__DEBUG("qnn-cpu: applyBinarySections adapters {}", binsection_list.at(i));
|
| 605 |
+
if (!m_qnnApi->applyBinarySection(i, binsection_list.at(i))) {
|
| 606 |
+
__ERROR("qnn-cpu: Error in applyBinarySections {}", i);
|
| 607 |
+
return false;
|
| 608 |
+
}
|
| 609 |
+
}
|
| 610 |
+
return true;
|
| 611 |
+
}
|
| 612 |
+
|
| 613 |
+
// Stores the LoRA strength (alpha) to be written into the model's alpha input
// tensor on the next inference (setupInputTensors copies m_lora_alpha_val into
// the "x6" input buffer). Always succeeds.
// NOTE(review): alpha_tensor_name is currently ignored — the CPU model exposes
// a single alpha input, so only the value is stored; confirm whether
// per-tensor alpha selection is ever required here.
bool QnnCpuModel::applyLoraStrength(const std::string& alpha_tensor_name, const float alpha_val) {
  m_lora_alpha_val = alpha_val;
  return true;
}
|
| 617 |
+
|
| 618 |
+
bool QnnCpuModel::applyLoraAdapter(const std::string& lora_adapter_name) {
|
| 619 |
+
if (m_loraConfigType != LoraConfigType::LORA_ADAPTER_WEIGHT_ENABLE) {
|
| 620 |
+
__ERROR("qnn-cpu: Lora config is not enable for adapters");
|
| 621 |
+
return false;
|
| 622 |
+
}
|
| 623 |
+
|
| 624 |
+
if (!m_loraConfig.contains(lora_adapter_name)) {
|
| 625 |
+
__ERROR("qnn-cpu: Could not find lora adapters config to apply ");
|
| 626 |
+
return false;
|
| 627 |
+
}
|
| 628 |
+
if (!applyLoraStrength(
|
| 629 |
+
m_loraConfig[lora_adapter_name].alpha_tensor_name,
|
| 630 |
+
m_loraConfig[lora_adapter_name].alpha_tensor_val
|
| 631 |
+
)) {
|
| 632 |
+
__ERROR("qnn-cpu: Could not apply Alpha tensor ");
|
| 633 |
+
return false;
|
| 634 |
+
}
|
| 635 |
+
|
| 636 |
+
if (!applyBinarySections(m_loraConfig[lora_adapter_name].binsection_list)) {
|
| 637 |
+
__ERROR("qnn-cpu: Could not apply binary Sections ");
|
| 638 |
+
return false;
|
| 639 |
+
}
|
| 640 |
+
return true;
|
| 641 |
+
}
|
| 642 |
+
|
| 643 |
// TODO: implement save/restore
|
| 644 |
size_t QnnCpuModel::loadKVCache(const std::string& load_path) {
|
| 645 |
//TO read the cache file into KV tensor
|
Genie/Genie/src/qualla/engines/qnn-cpu/cpu-model.hpp
CHANGED
|
@@ -26,6 +26,12 @@
|
|
| 26 |
|
| 27 |
namespace qualla {
|
| 28 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
class QnnCpuModel {
|
| 30 |
enum ExecutionMode { AUTODETECT, BERT_KV, KV_ONLY, BERT_ONLY };
|
| 31 |
|
|
@@ -34,6 +40,13 @@ class QnnCpuModel {
|
|
| 34 |
public:
|
| 35 |
enum ModelOutput { LOGITS = 0x0, EMBEDDINGS= 0x1 };
|
| 36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
struct Params {
|
| 38 |
std::filesystem::path model_basedir;
|
| 39 |
std::string op_package;
|
|
@@ -50,6 +63,8 @@ class QnnCpuModel {
|
|
| 50 |
uint32_t n_layer;
|
| 51 |
uint32_t n_embd;
|
| 52 |
uint32_t n_heads;
|
|
|
|
|
|
|
| 53 |
};
|
| 54 |
|
| 55 |
const std::filesystem::path model_basedir;
|
|
@@ -92,6 +107,11 @@ class QnnCpuModel {
|
|
| 92 |
std::vector<Qnn_Param_t> m_params;
|
| 93 |
ExecutionMode m_mode{ExecutionMode::AUTODETECT};
|
| 94 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
// Save some information about the last inference run
|
| 96 |
struct PreviousRunInfo {
|
| 97 |
bool was_bert_mode;
|
|
@@ -118,6 +138,7 @@ class QnnCpuModel {
|
|
| 118 |
QnnUtils::Tensor* t_input_ids_k_cache;
|
| 119 |
QnnUtils::Tensor* t_input_ids_v_cache;
|
| 120 |
QnnUtils::Tensor* t_input_ids_n_past;
|
|
|
|
| 121 |
float* dequant_logits_ptr{nullptr};
|
| 122 |
|
| 123 |
// Store pointers for bert
|
|
@@ -171,6 +192,10 @@ class QnnCpuModel {
|
|
| 171 |
size_t loadKVCache(const std::string& save_path);
|
| 172 |
bool saveKVCache(const std::string& load_path);
|
| 173 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
private:
|
| 175 |
bool m_mmap_context_bins = false; // mmap context binary files instead of reading them in memory
|
| 176 |
// Internal functions to separate different runInference logic
|
|
|
|
| 26 |
|
| 27 |
namespace qualla {
|
| 28 |
|
| 29 |
+
enum LoraConfigType {
|
| 30 |
+
LORA_DISABLE = 0,
|
| 31 |
+
LORA_INPUT_WEIGHT_ENABLE = 1,
|
| 32 |
+
LORA_ADAPTER_WEIGHT_ENABLE = 2
|
| 33 |
+
};
|
| 34 |
+
|
| 35 |
class QnnCpuModel {
|
| 36 |
enum ExecutionMode { AUTODETECT, BERT_KV, KV_ONLY, BERT_ONLY };
|
| 37 |
|
|
|
|
| 40 |
public:
|
| 41 |
enum ModelOutput { LOGITS = 0x0, EMBEDDINGS= 0x1 };
|
| 42 |
|
| 43 |
+
struct LoraConfig {
|
| 44 |
+
std::string lora_name;
|
| 45 |
+
std::vector<std::string> binsection_list; //loRAv2 adapter bins filenames
|
| 46 |
+
std::string alpha_tensor_name; //loRAv2 alpha tensor names
|
| 47 |
+
float alpha_tensor_val; //loRAv2 alpha tensor values
|
| 48 |
+
};
|
| 49 |
+
|
| 50 |
struct Params {
|
| 51 |
std::filesystem::path model_basedir;
|
| 52 |
std::string op_package;
|
|
|
|
| 63 |
uint32_t n_layer;
|
| 64 |
uint32_t n_embd;
|
| 65 |
uint32_t n_heads;
|
| 66 |
+
LoraConfigType lora_config_type;
|
| 67 |
+
std::map<std::string, LoraConfig> lora_config;
|
| 68 |
};
|
| 69 |
|
| 70 |
const std::filesystem::path model_basedir;
|
|
|
|
| 107 |
std::vector<Qnn_Param_t> m_params;
|
| 108 |
ExecutionMode m_mode{ExecutionMode::AUTODETECT};
|
| 109 |
|
| 110 |
+
// LoRA params and configs
|
| 111 |
+
float m_lora_alpha_val;
|
| 112 |
+
LoraConfigType m_loraConfigType;
|
| 113 |
+
std::map<std::string, LoraConfig> m_loraConfig;
|
| 114 |
+
|
| 115 |
// Save some information about the last inference run
|
| 116 |
struct PreviousRunInfo {
|
| 117 |
bool was_bert_mode;
|
|
|
|
| 138 |
QnnUtils::Tensor* t_input_ids_k_cache;
|
| 139 |
QnnUtils::Tensor* t_input_ids_v_cache;
|
| 140 |
QnnUtils::Tensor* t_input_ids_n_past;
|
| 141 |
+
QnnUtils::Tensor* t_input_lora_alpha;
|
| 142 |
float* dequant_logits_ptr{nullptr};
|
| 143 |
|
| 144 |
// Store pointers for bert
|
|
|
|
| 192 |
size_t loadKVCache(const std::string& save_path);
|
| 193 |
bool saveKVCache(const std::string& load_path);
|
| 194 |
|
| 195 |
+
bool applyLoraStrength(const std::string& alpha_tensor_name, const float alpha_val);
|
| 196 |
+
bool applyLoraAdapter(const std::string& lora_adapter_name);
|
| 197 |
+
bool applyBinarySections(std::vector<std::string>& binsection_list);
|
| 198 |
+
|
| 199 |
private:
|
| 200 |
bool m_mmap_context_bins = false; // mmap context binary files instead of reading them in memory
|
| 201 |
// Internal functions to separate different runInference logic
|
Genie/Genie/src/qualla/engines/qnn-gpu.cpp
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
|
| 2 |
+
// Confidential & Proprietary - Qualcomm Technologies, Inc. ("QTI")
|
| 3 |
+
|
| 4 |
+
#include <vector>
|
| 5 |
+
#include <string>
|
| 6 |
+
|
| 7 |
+
#include <qualla/engine.hpp>
|
| 8 |
+
#include <qualla/detail/config.hpp>
|
| 9 |
+
#include <qualla/detail/timer.hpp>
|
| 10 |
+
#include <qualla/detail/onload.hpp>
|
| 11 |
+
|
| 12 |
+
#include <fmt/format.h>
|
| 13 |
+
|
| 14 |
+
#include "gpu-model.hpp"
|
| 15 |
+
|
| 16 |
+
#define __INFO(__fmt, ...) _env.logger().post(Logger::INFO, fmt::format(__fmt, ##__VA_ARGS__))
|
| 17 |
+
#define __WARN(__fmt, ...) _env.logger().post(Logger::WARN, fmt::format(__fmt, ##__VA_ARGS__))
|
| 18 |
+
#define __ERROR(__fmt, ...) _env.logger().post(Logger::ERROR, fmt::format(__fmt, ##__VA_ARGS__))
|
| 19 |
+
#define __KPIS(__fmt, ...) \
|
| 20 |
+
_env.logger().post(Logger::ENGINE_KPIS, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
|
| 21 |
+
#define __DEBUG(__fmt, ...) \
|
| 22 |
+
_env.logger().post(Logger::ENGINE_DEBUG, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
|
| 23 |
+
#define __TRACE(__fmt, ...) \
|
| 24 |
+
_env.logger().post(Logger::ENGINE_TRACE, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
|
| 25 |
+
|
| 26 |
+
namespace qualla {
|
| 27 |
+
|
| 28 |
+
class GpuEngine : public Engine {
|
| 29 |
+
private:
|
| 30 |
+
QnnGpuModel::Params _params;
|
| 31 |
+
std::unique_ptr<QnnGpuModel> _model;
|
| 32 |
+
|
| 33 |
+
public:
|
| 34 |
+
GpuEngine(Context& ctx, const qualla::json& json);
|
| 35 |
+
~GpuEngine();
|
| 36 |
+
|
| 37 |
+
virtual size_t process(
|
| 38 |
+
const std::vector<int32_t>& tokens,
|
| 39 |
+
std::vector<float>& logits,
|
| 40 |
+
bool logits_all
|
| 41 |
+
) override;
|
| 42 |
+
|
| 43 |
+
virtual bool updateKV(size_t n_past) override;
|
| 44 |
+
virtual bool save(const std::string& name) override;
|
| 45 |
+
virtual size_t restore(const std::string& name, bool chooseHigherVariant) override;
|
| 46 |
+
virtual void reset() override;
|
| 47 |
+
|
| 48 |
+
virtual bool load() override;
|
| 49 |
+
virtual bool unload() override;
|
| 50 |
+
};
|
| 51 |
+
|
| 52 |
+
namespace fs = std::filesystem;
|
| 53 |
+
|
| 54 |
+
GpuEngine::GpuEngine(Context& ctx, const qualla::json& json) : Engine(ctx, "qnn-gpu", json) {
|
| 55 |
+
qualla::Timer start;
|
| 56 |
+
|
| 57 |
+
using FF = Feature::Flags;
|
| 58 |
+
_features = FF::OUTPUT_LOGITS | FF::SAVE_RESTORE | FF::DYNAMIC_LOAD;
|
| 59 |
+
|
| 60 |
+
__DEBUG("Qnn-Gpu : init start");
|
| 61 |
+
|
| 62 |
+
qualla::Config conf(json, _type + "-engine:");
|
| 63 |
+
|
| 64 |
+
// Parse config
|
| 65 |
+
_params.model_basedir = conf.optional<std::string>("model-basedir", "");
|
| 66 |
+
if (_params.model_basedir.is_relative()) {
|
| 67 |
+
_params.model_basedir = _env.path().models / _params.model_basedir;
|
| 68 |
+
_params.model_basedir = _params.model_basedir.make_preferred();
|
| 69 |
+
}
|
| 70 |
+
_params.model_list = conf.mandatory<std::vector<std::string>>("model-list");
|
| 71 |
+
|
| 72 |
+
_params.ctx_size = _ctx.size();
|
| 73 |
+
_params.num_heads = conf.optional<int64_t>("num-heads", 32);
|
| 74 |
+
_params.head_dim = conf.optional<int64_t>("head-dim", 128);
|
| 75 |
+
|
| 76 |
+
if (!conf.optional<bool>("dynamic-load", false)) {
|
| 77 |
+
load();
|
| 78 |
+
}
|
| 79 |
+
};
|
| 80 |
+
|
| 81 |
+
GpuEngine::~GpuEngine() {
|
| 82 |
+
unload();
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
bool GpuEngine::load() {
|
| 86 |
+
#ifdef _WIN32
|
| 87 |
+
// QnnGpu Engine does not support Windows.
|
| 88 |
+
return false;
|
| 89 |
+
#endif
|
| 90 |
+
if (_model) return true;
|
| 91 |
+
|
| 92 |
+
qualla::Timer start;
|
| 93 |
+
bool status = true;
|
| 94 |
+
|
| 95 |
+
__INFO("Qnn-Gpu : Loading Model");
|
| 96 |
+
|
| 97 |
+
_model = std::make_unique<QnnGpuModel>(_env, _params);
|
| 98 |
+
|
| 99 |
+
// Load model
|
| 100 |
+
status = _model->initializeModel();
|
| 101 |
+
if (!status) {
|
| 102 |
+
throw std::runtime_error("Qnn-Gpu :Failure to initialize model");
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
// Initialize IO Tensor buffers
|
| 106 |
+
status = _model->initializeIOTensors();
|
| 107 |
+
if (!status) {
|
| 108 |
+
throw std::runtime_error("Qnn-Gpu :Error in setting up IO Tensors");
|
| 109 |
+
}
|
| 110 |
+
|
| 111 |
+
// Initialize IO Tensor Pointers
|
| 112 |
+
if (true != _model->initializeTensorPointers()) {
|
| 113 |
+
throw std::runtime_error("Qnn-Gpu :Could not find I/O tensors in loaded graphs");
|
| 114 |
+
}
|
| 115 |
+
|
| 116 |
+
// Validate the model
|
| 117 |
+
if (true != _model->validateModel()) {
|
| 118 |
+
throw std::runtime_error("Qnn-Gpu :Model Validation Failed");
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
_kpis.load.update(start.elapsed_usec());
|
| 122 |
+
return true;
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
bool GpuEngine::unload() {
|
| 126 |
+
qualla::Timer start;
|
| 127 |
+
__DEBUG("Qnn-Gpu : Unloading Model");
|
| 128 |
+
_model.reset(nullptr);
|
| 129 |
+
_kpis.unload.update(start.elapsed_usec());
|
| 130 |
+
return true;
|
| 131 |
+
}
|
| 132 |
+
|
| 133 |
+
// KV Cache updation after each inference is handled inside QnnGpu Backend
|
| 134 |
+
// GPU Engine uses same memory handle for each KV input/output to the graph and uses
|
| 135 |
+
// Scatter op to update KV after each inference to the same memory handle.
|
| 136 |
+
bool GpuEngine::updateKV(size_t n_past) {
|
| 137 |
+
return true;
|
| 138 |
+
}
|
| 139 |
+
|
| 140 |
+
size_t GpuEngine::process(
|
| 141 |
+
const std::vector<int32_t>& tokens,
|
| 142 |
+
std::vector<float>& logits,
|
| 143 |
+
bool logits_all
|
| 144 |
+
) {
|
| 145 |
+
if (!_model && !load()) {
|
| 146 |
+
return 0;
|
| 147 |
+
}
|
| 148 |
+
qualla::Timer start;
|
| 149 |
+
size_t n_tok = _model->runInference(tokens, logits, logits_all);
|
| 150 |
+
if (n_tok == 0) {
|
| 151 |
+
State::error("Qnn-Gpu : RunInference Failed!");
|
| 152 |
+
}
|
| 153 |
+
_kpis.process.update(start.elapsed_usec());
|
| 154 |
+
return n_tok;
|
| 155 |
+
}
|
| 156 |
+
|
| 157 |
+
size_t GpuEngine::restore(const std::string& name, bool chooseHigherVariant) {
|
| 158 |
+
if (!_model && !load()) {
|
| 159 |
+
return 0;
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
fs::path cache_path = std::filesystem::path(name) / fmt::format("kv-cache.{}.qnn-gpu", _role);
|
| 163 |
+
return _model->loadKVCache(cache_path.string());
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
bool GpuEngine::save(const std::string& name) {
|
| 167 |
+
if (!_model && !load()) {
|
| 168 |
+
return false;
|
| 169 |
+
}
|
| 170 |
+
|
| 171 |
+
fs::path cache_path = std::filesystem::path(name) / fmt::format("kv-cache.{}.qnn-gpu", _role);
|
| 172 |
+
return _model->saveKVCache(cache_path.string());
|
| 173 |
+
}
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
// Reset requires clearing of KV caches only
|
| 177 |
+
void GpuEngine::reset() {
|
| 178 |
+
if (!_model && !load()) {
|
| 179 |
+
return;
|
| 180 |
+
}
|
| 181 |
+
_model->reset();
|
| 182 |
+
}
|
| 183 |
+
|
| 184 |
+
// Registrator instance
|
| 185 |
+
static OnLoad regy([]() {
|
| 186 |
+
Engine::__register("qnn-gpu", [](Context& ctx, const json& conf) {
|
| 187 |
+
return (Engine*)new GpuEngine(ctx, conf);
|
| 188 |
+
});
|
| 189 |
+
});
|
| 190 |
+
|
| 191 |
+
void needQnnGpuEngine() {}
|
| 192 |
+
|
| 193 |
+
} // namespace qualla
|
Genie/Genie/src/qualla/engines/qnn-gpu/gpu-model.cpp
ADDED
|
@@ -0,0 +1,603 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
|
| 2 |
+
// Confidential & Proprietary - Qualcomm Technologies, Inc. ("QTI")
|
| 3 |
+
|
| 4 |
+
#include <cassert>
|
| 5 |
+
#include <cstring>
|
| 6 |
+
#include <fstream>
|
| 7 |
+
#include <set>
|
| 8 |
+
#include <sstream>
|
| 9 |
+
|
| 10 |
+
#include "fmt/format.h"
|
| 11 |
+
#include "fmt/ranges.h"
|
| 12 |
+
#include "fp16/fp16.h"
|
| 13 |
+
#include "gpu-model.hpp"
|
| 14 |
+
#include "qualla/detail/cache-file.hpp"
|
| 15 |
+
#include "qualla/detail/timer.hpp"
|
| 16 |
+
#include "qualla/env.hpp"
|
| 17 |
+
|
| 18 |
+
namespace fs = std::filesystem;
|
| 19 |
+
|
| 20 |
+
static constexpr uint32_t g_magicNum = 0xC0DE;
|
| 21 |
+
|
| 22 |
+
#define __INFO(__fmt, ...) _env.logger().post(Logger::INFO, fmt::format(__fmt, ##__VA_ARGS__));
|
| 23 |
+
#define __WARN(__fmt, ...) _env.logger().post(Logger::WARN, fmt::format(__fmt, ##__VA_ARGS__));
|
| 24 |
+
#define __ERROR(__fmt, ...) _env.logger().post(Logger::ERROR, fmt::format(__fmt, ##__VA_ARGS__));
|
| 25 |
+
#define __KPIS(__fmt, ...) \
|
| 26 |
+
_env.logger().post(Logger::ENGINE_KPIS, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); });
|
| 27 |
+
#define __DEBUG(__fmt, ...) \
|
| 28 |
+
_env.logger().post(Logger::ENGINE_DEBUG, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); });
|
| 29 |
+
#define __TRACE(__fmt, ...) \
|
| 30 |
+
_env.logger().post(Logger::ENGINE_TRACE, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); });
|
| 31 |
+
|
| 32 |
+
namespace qualla {
|
| 33 |
+
|
| 34 |
+
QnnGpuModel::QnnGpuModel(Env& env, const Params& params)
|
| 35 |
+
: _env(env), _modelBaseDir(params.model_basedir) {
|
| 36 |
+
// Initialize _qnnApi
|
| 37 |
+
_qnnApi = std::unique_ptr<QnnApi>(new QnnApi());
|
| 38 |
+
|
| 39 |
+
_ctxSize = params.ctx_size;
|
| 40 |
+
_numHeads = params.num_heads;
|
| 41 |
+
_headDim = params.head_dim;
|
| 42 |
+
#ifdef _WIN32
|
| 43 |
+
_useDmabufIo = false;
|
| 44 |
+
#else
|
| 45 |
+
_useDmabufIo = true;
|
| 46 |
+
#endif
|
| 47 |
+
// Set up filename list for context binaries.
|
| 48 |
+
for (auto& i : params.model_list) {
|
| 49 |
+
fs::path model_path = _modelBaseDir / fs::path(i);
|
| 50 |
+
if (!fs::is_regular_file(model_path)) {
|
| 51 |
+
__ERROR("Qnn-Gpu-Model : Can't access model file : {}", model_path.string());
|
| 52 |
+
throw std::runtime_error("Qnn-Gpu-Model : Can't access model file : " + model_path.string());
|
| 53 |
+
}
|
| 54 |
+
_modelList.push_back(model_path.string());
|
| 55 |
+
}
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
QnnGpuModel::~QnnGpuModel() { __INFO("Qnn-Gpu-Model : model destruct complete"); }
|
| 59 |
+
|
| 60 |
+
// Given a filename, initializeModel load and initializes QNN runtime libraries and the model
|
| 61 |
+
bool QnnGpuModel::initializeModel(void) {
|
| 62 |
+
qualla::Timer start;
|
| 63 |
+
|
| 64 |
+
__INFO("Qnn-Gpu-Model : Model Init Start");
|
| 65 |
+
|
| 66 |
+
const std::string backend = "libQnnGpu.so";
|
| 67 |
+
|
| 68 |
+
__INFO("Backend Library : {}", backend);
|
| 69 |
+
__INFO("Model Files : {}", _modelList);
|
| 70 |
+
|
| 71 |
+
if (!_qnnApi->initialize(backend, _modelList)) {
|
| 72 |
+
__ERROR("Qnn-Api : Initialization Failed!");
|
| 73 |
+
return false;
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
// Initialize QNN IO Tensor
|
| 77 |
+
if (_useDmabufIo) {
|
| 78 |
+
_ioTensor =
|
| 79 |
+
std::unique_ptr<IOTensor>(new IOTensor(BufferAlloc::DMABUF, _qnnApi->getQnnInterfaceVer()));
|
| 80 |
+
} else {
|
| 81 |
+
_ioTensor = std::unique_ptr<IOTensor>(
|
| 82 |
+
new IOTensor(BufferAlloc::DEFAULT, _qnnApi->getQnnInterfaceVer()));
|
| 83 |
+
}
|
| 84 |
+
_numGraphs = _qnnApi->getGraphsCount();
|
| 85 |
+
__INFO("Qnn-Gpu-Model : initialized with {} graph(s)", _numGraphs);
|
| 86 |
+
|
| 87 |
+
GraphInfo_t** graphs_info = _qnnApi->getGraphsInfo();
|
| 88 |
+
for (size_t graphIdx = 0; graphIdx < _numGraphs; graphIdx++) {
|
| 89 |
+
GraphInfo_t* const graphInfo = graphs_info[graphIdx];
|
| 90 |
+
char* graphName = graphInfo->graphName;
|
| 91 |
+
std::string graphStr = std::string(graphName);
|
| 92 |
+
|
| 93 |
+
_modelOrder.push_back(graphStr);
|
| 94 |
+
}
|
| 95 |
+
__INFO("Qnn-Gpu-Model : model init complete: {} usec", start.elapsed_usec());
|
| 96 |
+
|
| 97 |
+
return true;
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
// Once the model has been loaded, initialize IO Tensors
|
| 101 |
+
// _ioTensors is initialized by the context for now
|
| 102 |
+
bool QnnGpuModel::initializeIOTensors() {
|
| 103 |
+
qualla::Timer start;
|
| 104 |
+
|
| 105 |
+
// For QNN-GPU, we have only one context per model.
|
| 106 |
+
bool status = _ioTensor->initialize(_qnnApi->getContexts().back());
|
| 107 |
+
if (!status) {
|
| 108 |
+
__ERROR("Qnn-Gpu-Model : failure to initialize IOTensor");
|
| 109 |
+
return false;
|
| 110 |
+
}
|
| 111 |
+
// Getting graph info, Hardcoding single graph for now.
|
| 112 |
+
GraphInfo_t** const& graphsInfo = _qnnApi->getGraphsInfo();
|
| 113 |
+
|
| 114 |
+
for (size_t graphIdx = 0; graphIdx < _numGraphs; graphIdx++) {
|
| 115 |
+
GraphInfo_t* const& graphInfo = graphsInfo[graphIdx];
|
| 116 |
+
std::string graphName = std::string(graphInfo->graphName);
|
| 117 |
+
|
| 118 |
+
__DEBUG("Qnn-Gpu-Model : numInputTensors {} numOutputTensors {}",
|
| 119 |
+
graphInfo->numInputTensors,
|
| 120 |
+
graphInfo->numOutputTensors);
|
| 121 |
+
// Setup Inputs
|
| 122 |
+
{
|
| 123 |
+
std::unordered_map<std::string, size_t> inputTensorsSize;
|
| 124 |
+
for (size_t tensorIdx = 0; tensorIdx < graphInfo->numInputTensors; tensorIdx++) {
|
| 125 |
+
std::string tensorName;
|
| 126 |
+
std::vector<size_t> tensorDims;
|
| 127 |
+
auto& tensor = graphInfo->inputTensors[tensorIdx];
|
| 128 |
+
_qnnApi->getTensorNameAndShape(tensorName, tensorDims, tensor);
|
| 129 |
+
auto dims = QnnUtils::Dims(tensorDims);
|
| 130 |
+
inputTensorsSize[tensorName] = dims.getSize();
|
| 131 |
+
__DEBUG("Qnn-Gpu-Model : Input Tensor Info {} {} {} {}",
|
| 132 |
+
tensorIdx,
|
| 133 |
+
tensorName,
|
| 134 |
+
tensorDims,
|
| 135 |
+
inputTensorsSize[tensorName]);
|
| 136 |
+
std::vector<QnnUtils::QuantParam> quantParams;
|
| 137 |
+
if (!_qnnApi->getTensorQuantParams(&tensor, quantParams)) {
|
| 138 |
+
quantParams.emplace_back(0, 0);
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
std::shared_ptr<QnnUtils::Tensor> tensorUtil =
|
| 142 |
+
std::shared_ptr<QnnUtils::Tensor>(new (std::nothrow) QnnUtils::Tensor);
|
| 143 |
+
tensorUtil->dims = dims;
|
| 144 |
+
tensorUtil->dtype = QNN_TENSOR_GET_DATA_TYPE(tensor);
|
| 145 |
+
tensorUtil->quantParam = quantParams;
|
| 146 |
+
_inputSpecs[graphName][tensorName] = tensorUtil;
|
| 147 |
+
}
|
| 148 |
+
|
| 149 |
+
Qnn_Tensor_t* tensor_bank = nullptr;
|
| 150 |
+
std::unordered_map<std::string, void*> tensor_ptr_map;
|
| 151 |
+
if (true != _ioTensor->setupInputTensors(&tensor_bank,
|
| 152 |
+
tensor_ptr_map,
|
| 153 |
+
*graphInfo,
|
| 154 |
+
inputTensorsSize,
|
| 155 |
+
_qnnApi->getContexts()[graphIdx],
|
| 156 |
+
false)) {
|
| 157 |
+
QNN_ERROR("Qnn-Gpu-Model : Error in setting up Input Tensors for graph %s",
|
| 158 |
+
graphName.c_str());
|
| 159 |
+
return false;
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
_inputTensors[graphName] = tensor_bank;
|
| 163 |
+
for (auto& [tensorName, tensor_ptr] : tensor_ptr_map) {
|
| 164 |
+
_inputSpecs[graphName][tensorName]->tensor = (Qnn_Tensor_t*)tensor_ptr;
|
| 165 |
+
}
|
| 166 |
+
__DEBUG("Qnn-Gpu-Model : Input Tensor Allocated for {}", graphName);
|
| 167 |
+
}
|
| 168 |
+
|
| 169 |
+
// Setup Outputs
|
| 170 |
+
{
|
| 171 |
+
std::unordered_map<std::string, size_t> outputTensorsSize;
|
| 172 |
+
std::unordered_map<std::string, Qnn_Tensor_t*> sharedTensorMap;
|
| 173 |
+
for (size_t tensorIdx = 0; tensorIdx < graphInfo->numOutputTensors; tensorIdx++) {
|
| 174 |
+
std::string tensorName;
|
| 175 |
+
std::vector<size_t> tensorDims;
|
| 176 |
+
|
| 177 |
+
auto& tensor = graphInfo->outputTensors[tensorIdx];
|
| 178 |
+
_qnnApi->getTensorNameAndShape(tensorName, tensorDims, tensor);
|
| 179 |
+
|
| 180 |
+
if (tensorName.starts_with("past_")) {
|
| 181 |
+
std::string tensorInName = tensorName.substr(0, tensorName.size() - 3) + "in";
|
| 182 |
+
sharedTensorMap[tensorName] = _inputSpecs[graphName][tensorInName]->tensor;
|
| 183 |
+
|
| 184 |
+
// Update Gpu _kvCache
|
| 185 |
+
auto [type, layer_id] = parseKVTensorName(tensorName);
|
| 186 |
+
_kvCache.push_back(
|
| 187 |
+
GpuKVCache((type == 1), layer_id, _inputSpecs[graphName][tensorInName].get()));
|
| 188 |
+
}
|
| 189 |
+
std::vector<QnnUtils::QuantParam> quantParams;
|
| 190 |
+
if (!_qnnApi->getTensorQuantParams(&tensor, quantParams)) {
|
| 191 |
+
quantParams.emplace_back(0, 0);
|
| 192 |
+
}
|
| 193 |
+
|
| 194 |
+
auto dims = QnnUtils::Dims(tensorDims);
|
| 195 |
+
outputTensorsSize[tensorName] = dims.getAlignedSize();
|
| 196 |
+
|
| 197 |
+
__DEBUG("Qnn-Gpu-Model : Output Tensor Info {} {} {} {}",
|
| 198 |
+
tensorIdx,
|
| 199 |
+
tensorName,
|
| 200 |
+
tensorDims,
|
| 201 |
+
outputTensorsSize[tensorName]);
|
| 202 |
+
std::shared_ptr<QnnUtils::Tensor> tensorUtil =
|
| 203 |
+
std::shared_ptr<QnnUtils::Tensor>(new (std::nothrow) QnnUtils::Tensor);
|
| 204 |
+
tensorUtil->dims = dims;
|
| 205 |
+
tensorUtil->dtype = QNN_TENSOR_GET_DATA_TYPE(tensor);
|
| 206 |
+
tensorUtil->quantParam = quantParams;
|
| 207 |
+
_outputSpecs[graphName][tensorName] = tensorUtil;
|
| 208 |
+
}
|
| 209 |
+
|
| 210 |
+
Qnn_Tensor_t* tensor_bank = nullptr;
|
| 211 |
+
std::unordered_map<std::string, void*> tensor_ptr_map;
|
| 212 |
+
if (_ioTensor->getBufferAllocType() == BufferAlloc::DMABUF) {
|
| 213 |
+
if (true != _ioTensor->setupOutputWithSharedTensors(&tensor_bank,
|
| 214 |
+
tensor_ptr_map,
|
| 215 |
+
*graphInfo,
|
| 216 |
+
outputTensorsSize,
|
| 217 |
+
_qnnApi->getContexts()[graphIdx],
|
| 218 |
+
sharedTensorMap)) {
|
| 219 |
+
QNN_ERROR("Qnn-Gpu-Model : Error in setting up Output Tensors for graph %s",
|
| 220 |
+
graphName.c_str());
|
| 221 |
+
return false;
|
| 222 |
+
}
|
| 223 |
+
} else {
|
| 224 |
+
if (true != _ioTensor->setupOutputTensors(&tensor_bank,
|
| 225 |
+
tensor_ptr_map,
|
| 226 |
+
*graphInfo,
|
| 227 |
+
outputTensorsSize,
|
| 228 |
+
_qnnApi->getContexts()[graphIdx],
|
| 229 |
+
false)) {
|
| 230 |
+
QNN_ERROR("Qnn-Gpu-Model : Error in setting up Input Tensors for graph %s",
|
| 231 |
+
graphName.c_str());
|
| 232 |
+
return false;
|
| 233 |
+
}
|
| 234 |
+
}
|
| 235 |
+
|
| 236 |
+
_outputTensors[graphName] = tensor_bank;
|
| 237 |
+
for (auto& [tensorName, tensor_ptr] : tensor_ptr_map) {
|
| 238 |
+
_outputSpecs[graphName][tensorName]->tensor = (Qnn_Tensor_t*)tensor_ptr;
|
| 239 |
+
}
|
| 240 |
+
|
| 241 |
+
__DEBUG("Qnn-Gpu-Model : Output Tensor Allocated {} {}", graphName, _outputTensors.size());
|
| 242 |
+
}
|
| 243 |
+
}
|
| 244 |
+
auto stop = std::chrono::steady_clock::now();
|
| 245 |
+
return true;
|
| 246 |
+
}
|
| 247 |
+
|
| 248 |
+
bool QnnGpuModel::initializeTensorPointers() {
|
| 249 |
+
auto inputSpec = _inputSpecs[_modelOrder.back()];
|
| 250 |
+
auto outputSpec = _outputSpecs[_modelOrder.back()];
|
| 251 |
+
|
| 252 |
+
t_inputIds = inputSpec[INPUT_IDS].get();
|
| 253 |
+
t_attnMask = inputSpec[ATTN_MASK].get();
|
| 254 |
+
t_positionIds = inputSpec[POS_IDS].get();
|
| 255 |
+
t_logits = outputSpec[LOGITS].get();
|
| 256 |
+
|
| 257 |
+
auto status = !(t_inputIds == nullptr || t_attnMask == nullptr || t_positionIds == nullptr ||
|
| 258 |
+
t_logits == nullptr);
|
| 259 |
+
|
| 260 |
+
if (!status) {
|
| 261 |
+
__ERROR("Qnn-Gpu-Model : error in setting up named tensor pointers for llama.");
|
| 262 |
+
return false;
|
| 263 |
+
}
|
| 264 |
+
return true;
|
| 265 |
+
}
|
| 266 |
+
|
| 267 |
+
bool QnnGpuModel::validateModel() {
|
| 268 |
+
// Validating context Size.
|
| 269 |
+
size_t numInputs = t_inputIds->dims.getNumElements();
|
| 270 |
+
size_t dimMask = t_attnMask->dims.getNumElements();
|
| 271 |
+
size_t modelCtxSize = dimMask / numInputs;
|
| 272 |
+
|
| 273 |
+
if (modelCtxSize != _ctxSize) {
|
| 274 |
+
__ERROR("Qnn-Gpu-Model : Invalid Context Size {} {}.", modelCtxSize, _ctxSize);
|
| 275 |
+
return false;
|
| 276 |
+
}
|
| 277 |
+
return true;
|
| 278 |
+
}
|
| 279 |
+
|
| 280 |
+
void QnnGpuModel::setupInputTensors(const std::vector<int32_t>& tokens) {
|
| 281 |
+
auto start = std::chrono::steady_clock::now();
|
| 282 |
+
|
| 283 |
+
if (tokens.size() > _ctxSize) {
|
| 284 |
+
std::string errMsg = "Called inference with more tokens than model supports: ";
|
| 285 |
+
errMsg += std::to_string(tokens.size()) + " vs. " + std::to_string(_ctxSize);
|
| 286 |
+
throw std::runtime_error(errMsg);
|
| 287 |
+
}
|
| 288 |
+
|
| 289 |
+
// Setup 1. input_ids
|
| 290 |
+
// Index of input tokens in the embedding vocabulary
|
| 291 |
+
uint32_t* inputIdBuffer = (uint32_t*)getBuffer(t_inputIds);
|
| 292 |
+
if (inputIdBuffer) {
|
| 293 |
+
if (_useDmabufIo) {
|
| 294 |
+
_ioTensor->beforeWriteToBuffer(t_inputIds->tensor);
|
| 295 |
+
}
|
| 296 |
+
inputIdBuffer[0] = tokens[0];
|
| 297 |
+
if (_useDmabufIo) {
|
| 298 |
+
_ioTensor->afterWriteToBuffer(t_inputIds->tensor);
|
| 299 |
+
}
|
| 300 |
+
}
|
| 301 |
+
|
| 302 |
+
// Setup 2. attention_mask
|
| 303 |
+
// Mask to avoid performing attention of padding.
|
| 304 |
+
uint32_t* attnMaskBuffer = (uint32_t*)getBuffer(t_attnMask);
|
| 305 |
+
if (attnMaskBuffer) {
|
| 306 |
+
if (_useDmabufIo) {
|
| 307 |
+
_ioTensor->beforeWriteToBuffer(t_attnMask->tensor);
|
| 308 |
+
}
|
| 309 |
+
attnMaskBuffer[_numTokensProcessed] = 1;
|
| 310 |
+
if (_useDmabufIo) {
|
| 311 |
+
_ioTensor->afterWriteToBuffer(t_attnMask->tensor);
|
| 312 |
+
}
|
| 313 |
+
}
|
| 314 |
+
|
| 315 |
+
// Setup 3. position_ids
|
| 316 |
+
// Indices of positions of each input tokens in position embeddings.
|
| 317 |
+
uint32_t* positionIdBuffer = (uint32_t*)getBuffer(t_positionIds);
|
| 318 |
+
if (positionIdBuffer) {
|
| 319 |
+
if (_useDmabufIo) {
|
| 320 |
+
_ioTensor->beforeWriteToBuffer(t_positionIds->tensor);
|
| 321 |
+
}
|
| 322 |
+
positionIdBuffer[0] = (uint32_t)(_numTokensProcessed);
|
| 323 |
+
if (_useDmabufIo) {
|
| 324 |
+
_ioTensor->afterWriteToBuffer(t_positionIds->tensor);
|
| 325 |
+
}
|
| 326 |
+
}
|
| 327 |
+
|
| 328 |
+
auto stop = std::chrono::steady_clock::now();
|
| 329 |
+
}
|
| 330 |
+
|
| 331 |
+
template <class T1, class T2>
|
| 332 |
+
inline bool QnnGpuModel::executeModel(T1& input, T2& output, std::string graphName) {
|
| 333 |
+
bool ret = _qnnApi->graphExecute(input, output, graphName, timeLogs);
|
| 334 |
+
if (ret != true) {
|
| 335 |
+
QNN_ERROR("Qnn-Gpu-Model : Error executing inference: %d for graph %s", ret, graphName.c_str());
|
| 336 |
+
return false;
|
| 337 |
+
}
|
| 338 |
+
QNN_DEBUG("Qnn-Gpu-Model : Execute finished for graph %s", graphName.c_str());
|
| 339 |
+
return true;
|
| 340 |
+
}
|
| 341 |
+
|
| 342 |
+
bool QnnGpuModel::runInferenceHelper(std::vector<std::string>& exec_models,
|
| 343 |
+
int32_t* wait_time_total,
|
| 344 |
+
int32_t* exec_time_total,
|
| 345 |
+
bool pipeline_kv_update,
|
| 346 |
+
size_t update_size) {
|
| 347 |
+
int32_t exec_time = 0;
|
| 348 |
+
int32_t wait_time = 0;
|
| 349 |
+
for (auto& graphName : exec_models) {
|
| 350 |
+
{
|
| 351 |
+
auto start_time = std::chrono::steady_clock::now();
|
| 352 |
+
Qnn_Tensor_t* inputTensors;
|
| 353 |
+
Qnn_Tensor_t* outputTensors;
|
| 354 |
+
try {
|
| 355 |
+
inputTensors = _inputTensors[graphName];
|
| 356 |
+
outputTensors = _outputTensors[graphName];
|
| 357 |
+
} catch (std::exception e) {
|
| 358 |
+
__DEBUG("Qnn-Gpu-Model : Could not find tensors %s", graphName.c_str());
|
| 359 |
+
return false;
|
| 360 |
+
}
|
| 361 |
+
bool status = executeModel(inputTensors, outputTensors, graphName);
|
| 362 |
+
if (!status) {
|
| 363 |
+
return false;
|
| 364 |
+
}
|
| 365 |
+
auto end_time = std::chrono::steady_clock::now();
|
| 366 |
+
exec_time += static_cast<int32_t>(
|
| 367 |
+
std::chrono::duration_cast<std::chrono::microseconds>(end_time - start_time).count());
|
| 368 |
+
}
|
| 369 |
+
}
|
| 370 |
+
|
| 371 |
+
*exec_time_total += exec_time;
|
| 372 |
+
*wait_time_total += wait_time;
|
| 373 |
+
return true;
|
| 374 |
+
}
|
| 375 |
+
|
| 376 |
+
size_t QnnGpuModel::runInference(const std::vector<int32_t>& tokens,
|
| 377 |
+
std::vector<float>& logits,
|
| 378 |
+
bool logits_all) {
|
| 379 |
+
auto start = std::chrono::steady_clock::now();
|
| 380 |
+
int32_t totalWaitTime = 0;
|
| 381 |
+
int32_t totalExecTime = 0;
|
| 382 |
+
|
| 383 |
+
// Setup inputs for inference
|
| 384 |
+
auto& execModels = _modelOrder;
|
| 385 |
+
int numIters = tokens.size();
|
| 386 |
+
for (int i = 0; i < numIters; i++) {
|
| 387 |
+
if (numIters > 1) {
|
| 388 |
+
__DEBUG("Qnn-Gpu-Model : Prompt Processing {} of {} tokens", i + 1, numIters);
|
| 389 |
+
} else {
|
| 390 |
+
__DEBUG("Qnn-Gpu-Model : Token Generation {} of {} tokens", i + 1, numIters);
|
| 391 |
+
}
|
| 392 |
+
std::vector<int32_t> curr_tokens;
|
| 393 |
+
curr_tokens.push_back(tokens[i]);
|
| 394 |
+
setupInputTensors(curr_tokens);
|
| 395 |
+
bool status =
|
| 396 |
+
runInferenceHelper(execModels, &totalWaitTime, &totalExecTime, false, tokens.size());
|
| 397 |
+
if (!status) {
|
| 398 |
+
return 0;
|
| 399 |
+
}
|
| 400 |
+
processLogits(logits, logits_all);
|
| 401 |
+
|
| 402 |
+
// Update the numProcessTokens to updated with Accepted Tokens.
|
| 403 |
+
_numTokensProcessed++;
|
| 404 |
+
}
|
| 405 |
+
|
| 406 |
+
auto stop = std::chrono::steady_clock::now();
|
| 407 |
+
timeLogs["Run Inference (cpp) "].first += static_cast<double>(
|
| 408 |
+
std::chrono::duration_cast<std::chrono::microseconds>(stop - start).count());
|
| 409 |
+
timeLogs["Run Inference (cpp) "].second++;
|
| 410 |
+
QNN_DEBUG("[TIME] Wait[%d] Exec[%d]\n", totalWaitTime, totalExecTime);
|
| 411 |
+
if (!logits_all) {
|
| 412 |
+
return 1;
|
| 413 |
+
}
|
| 414 |
+
return tokens.size();
|
| 415 |
+
}
|
| 416 |
+
|
| 417 |
+
// Parse KV$ Tensor names here - supports past_{key,value}_{layer_idx}[_h0]_{in,out}
|
| 418 |
+
std::tuple<int, int> QnnGpuModel::parseKVTensorName(std::string name) {
|
| 419 |
+
if (!name.starts_with("past_")) return {0, 0};
|
| 420 |
+
|
| 421 |
+
const bool is_key = name.starts_with("past_key");
|
| 422 |
+
const size_t pos0 = (is_key) ? 9 : 11; // "past_key_" OR "past_value_"
|
| 423 |
+
const size_t pos1 = name.find('_', pos0);
|
| 424 |
+
|
| 425 |
+
int layer_idx = static_cast<int>(std::stoi(name.substr(pos0, pos1 - pos0)));
|
| 426 |
+
|
| 427 |
+
return std::make_tuple(is_key ? 1 : 2, layer_idx);
|
| 428 |
+
}
|
| 429 |
+
|
| 430 |
+
size_t QnnGpuModel::loadKVCache(const std::string& load_path) {
|
| 431 |
+
std::ifstream fs(load_path, std::ios::in | std::ios::binary);
|
| 432 |
+
if (fs.fail()) {
|
| 433 |
+
__ERROR("Qnn-Gpu-Model : loadKVCache errror reading file {}", load_path);
|
| 434 |
+
return 0;
|
| 435 |
+
}
|
| 436 |
+
|
| 437 |
+
CacheFileSpec spec;
|
| 438 |
+
fs.read((char*)&spec, sizeof(spec));
|
| 439 |
+
if (spec.magic != g_magicNum) {
|
| 440 |
+
__ERROR("Qnn-Gpu-Model : loadKVCache expected {} found {:#x}", g_magicNum, spec.magic);
|
| 441 |
+
return 0;
|
| 442 |
+
}
|
| 443 |
+
|
| 444 |
+
// clang-format off
|
| 445 |
+
__INFO("Qnn-Gpu-Model : loadKVCache {{ num_tensors {}, magic {}, dtype {}, n_heads {}, embed_dim {} update_size {} }}",
|
| 446 |
+
spec.num_tensors, spec.magic, int(spec.dtype), spec.n_heads, spec.embed_dim, spec.update_size); fflush(stdout);
|
| 447 |
+
// clang-format on
|
| 448 |
+
|
| 449 |
+
_numTokensProcessed = static_cast<size_t>(spec.update_size);
|
| 450 |
+
if (_numTokensProcessed > 0) {
|
| 451 |
+
// Loop over _kvCache tensor and read from file
|
| 452 |
+
for (auto cache : _kvCache) {
|
| 453 |
+
if (_useDmabufIo) {
|
| 454 |
+
_ioTensor->beforeWriteToBuffer(t_inputIds->tensor);
|
| 455 |
+
}
|
| 456 |
+
char* buffer = (char*)getBuffer(cache.tensorUtil);
|
| 457 |
+
if (cache.isKey) {
|
| 458 |
+
// Kye Cache Dims [1, num_heads, head_dim, ctx_size]
|
| 459 |
+
// float16 bits equivalent to uint16_t
|
| 460 |
+
const size_t copySize = _numTokensProcessed;
|
| 461 |
+
const size_t skipSize = _ctxSize;
|
| 462 |
+
for (int i = 0; i < _numHeads; i++) {
|
| 463 |
+
for (int j = 0; j < _headDim; j++) {
|
| 464 |
+
fs.read(buffer, copySize * sizeof(uint16_t));
|
| 465 |
+
buffer += skipSize * sizeof(uint16_t);
|
| 466 |
+
}
|
| 467 |
+
}
|
| 468 |
+
} else {
|
| 469 |
+
// Kye Cache Dims [1, num_heads, ctx_size, head_dim]
|
| 470 |
+
// float16 bits equivalent to uint16_t
|
| 471 |
+
const size_t copySize = _numTokensProcessed * _headDim;
|
| 472 |
+
const size_t skipSize = _ctxSize * _headDim;
|
| 473 |
+
for (int i = 0; i < _numHeads; i++) {
|
| 474 |
+
fs.read(buffer, copySize * sizeof(uint16_t));
|
| 475 |
+
buffer += skipSize * sizeof(uint16_t);
|
| 476 |
+
}
|
| 477 |
+
}
|
| 478 |
+
if (_useDmabufIo) {
|
| 479 |
+
_ioTensor->afterWriteToBuffer(t_inputIds->tensor);
|
| 480 |
+
}
|
| 481 |
+
}
|
| 482 |
+
}
|
| 483 |
+
return _numTokensProcessed;
|
| 484 |
+
}
|
| 485 |
+
|
| 486 |
+
bool QnnGpuModel::saveKVCache(const std::string& save_path) {
|
| 487 |
+
std::ofstream fs(save_path, std::ios::out | std::ios::binary);
|
| 488 |
+
if (fs.fail()) {
|
| 489 |
+
__ERROR("Qnn-Gpu-Model : saveKVCache error opening file : {}", save_path);
|
| 490 |
+
throw std::runtime_error("Failed to write to cache file. Please re-check path");
|
| 491 |
+
}
|
| 492 |
+
|
| 493 |
+
const CacheFileSpec::DataType dtype = CacheFileSpec::DataType::FLOAT16_T;
|
| 494 |
+
|
| 495 |
+
uint32_t numKVTensors = _kvCache.size();
|
| 496 |
+
|
| 497 |
+
// Save the cache file metadata
|
| 498 |
+
CacheFileSpec file_spec(
|
| 499 |
+
numKVTensors, g_magicNum, dtype, 0x0, _numHeads, _headDim, _numTokensProcessed);
|
| 500 |
+
fs.write((char*)&file_spec, sizeof(file_spec));
|
| 501 |
+
|
| 502 |
+
// clang-format off
|
| 503 |
+
__INFO("Qnn-Gpu-Model : saveKVCache {{ num_tensors {}, magic {}, dtype {}, n_heads {}, embed_dim {} update_size {} }}",
|
| 504 |
+
numKVTensors, g_magicNum, int(dtype), _numHeads, _headDim, _numTokensProcessed); fflush(stdout);
|
| 505 |
+
// clang-format on
|
| 506 |
+
|
| 507 |
+
if (_numTokensProcessed > 0) {
|
| 508 |
+
// Loop over _kvCache tensor and write to file
|
| 509 |
+
for (auto cache : _kvCache) {
|
| 510 |
+
if (_useDmabufIo) {
|
| 511 |
+
_ioTensor->beforeReadFromBuffer(t_inputIds->tensor);
|
| 512 |
+
}
|
| 513 |
+
char* buffer = (char*)getBuffer(cache.tensorUtil);
|
| 514 |
+
if (cache.isKey) {
|
| 515 |
+
// Kye Cache Dims [1, num_heads, head_dim, ctx_size]
|
| 516 |
+
// float16 bits equivalent to uint16_t
|
| 517 |
+
const size_t copySize = _numTokensProcessed;
|
| 518 |
+
const size_t skipSize = _ctxSize;
|
| 519 |
+
for (int i = 0; i < _numHeads; i++) {
|
| 520 |
+
for (int j = 0; j < _headDim; j++) {
|
| 521 |
+
fs.write((char*)buffer, copySize * sizeof(uint16_t));
|
| 522 |
+
buffer += skipSize * sizeof(uint16_t);
|
| 523 |
+
}
|
| 524 |
+
}
|
| 525 |
+
} else {
|
| 526 |
+
// Kye Cache Dims [1, num_heads, ctx_size, head_dim]
|
| 527 |
+
// float16 bits equivalent to uint16_t
|
| 528 |
+
const size_t copySize = _numTokensProcessed * _headDim;
|
| 529 |
+
const size_t skipSize = _ctxSize * _headDim;
|
| 530 |
+
for (int i = 0; i < _numHeads; i++) {
|
| 531 |
+
fs.write((char*)buffer, copySize * sizeof(uint16_t));
|
| 532 |
+
buffer += skipSize;
|
| 533 |
+
}
|
| 534 |
+
}
|
| 535 |
+
if (_useDmabufIo) {
|
| 536 |
+
_ioTensor->afterReadFromBuffer(t_inputIds->tensor);
|
| 537 |
+
}
|
| 538 |
+
}
|
| 539 |
+
}
|
| 540 |
+
fs.flush();
|
| 541 |
+
fs.close();
|
| 542 |
+
|
| 543 |
+
return true;
|
| 544 |
+
}
|
| 545 |
+
|
| 546 |
+
// Convert the final graph's fp16 logits into floats and append them to
// `logits`. When logits_all is false the vector is cleared first so only
// the latest token's logits remain. Returns how many logit vectors the
// output vector now holds.
size_t QnnGpuModel::processLogits(std::vector<float>& logits, bool logits_all) {
  auto logitsSpec = _outputSpecs[_modelOrder.back()][LOGITS].get();
  const size_t logitsSize = getNumElements(logitsSpec);
  if (logitsSize == 0) {
    // Guard the division at the end; an empty logits tensor would be UB.
    return 0;
  }
  if (_useDmabufIo) {
    // NOTE(review): syncs t_inputIds rather than the logits tensor being
    // read - verify whether logitsSpec's tensor was intended.
    _ioTensor->beforeReadFromBuffer(t_inputIds->tensor);
  }
  const uint16_t* logitBuf = (const uint16_t*)getBuffer(logitsSpec);

  if (!logits_all) {
    logits.clear();
  }
  const size_t allocateSize = logits.size() + logitsSize;
  logits.reserve(allocateSize);
  // size_t index: the original `auto i = 0` deduced int and compared it
  // against a size_t bound (sign-compare warning, overflow for huge sizes).
  for (size_t i = 0; i < logitsSize; ++i) {
    logits.push_back(fp16_ieee_to_fp32_value(logitBuf[i]));
  }
  if (_useDmabufIo) {
    _ioTensor->afterReadFromBuffer(t_inputIds->tensor);
  }

  return logits.size() / logitsSize;
}
|
| 568 |
+
|
| 569 |
+
bool QnnGpuModel::reset() {
|
| 570 |
+
// Reset Token Counter
|
| 571 |
+
_numTokensProcessed = 0;
|
| 572 |
+
|
| 573 |
+
// Reset Attention Mask
|
| 574 |
+
uint32_t* attnMaskBuffer = (uint32_t*)getBuffer(t_attnMask);
|
| 575 |
+
uint32_t attnMaskSize = getBufferSize(t_attnMask);
|
| 576 |
+
if (attnMaskBuffer) {
|
| 577 |
+
if (_useDmabufIo) {
|
| 578 |
+
_ioTensor->beforeWriteToBuffer(t_attnMask->tensor);
|
| 579 |
+
}
|
| 580 |
+
memset(attnMaskBuffer, 0, attnMaskSize);
|
| 581 |
+
if (_useDmabufIo) {
|
| 582 |
+
_ioTensor->afterWriteToBuffer(t_attnMask->tensor);
|
| 583 |
+
}
|
| 584 |
+
}
|
| 585 |
+
|
| 586 |
+
// Reset KV Cache.
|
| 587 |
+
// TODO : Check if mask_neg -100 is enough to remove
|
| 588 |
+
// effect of KV Cache. Test with mask_neg = -float_inf
|
| 589 |
+
for (auto cache : _kvCache) {
|
| 590 |
+
if (_useDmabufIo) {
|
| 591 |
+
_ioTensor->beforeWriteToBuffer(t_inputIds->tensor);
|
| 592 |
+
}
|
| 593 |
+
char* buffer = (char*)getBuffer(cache.tensorUtil);
|
| 594 |
+
uint32_t bufferSize = getBufferSize(cache.tensorUtil);
|
| 595 |
+
memset(buffer, 0, bufferSize);
|
| 596 |
+
if (_useDmabufIo) {
|
| 597 |
+
_ioTensor->afterWriteToBuffer(t_inputIds->tensor);
|
| 598 |
+
}
|
| 599 |
+
}
|
| 600 |
+
return true;
|
| 601 |
+
}
|
| 602 |
+
|
| 603 |
+
} // namespace qualla
|
Genie/Genie/src/qualla/engines/qnn-gpu/gpu-model.hpp
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
// Confidential & Proprietary - Qualcomm Technologies, Inc. ("QTI")

// Include guard renamed: identifiers beginning with a double underscore
// (the old __QUALLA_QNN_GPU_MODEL_H_) are reserved for the implementation.
#ifndef QUALLA_QNN_GPU_MODEL_H_
#define QUALLA_QNN_GPU_MODEL_H_

#include <atomic>
#include <cstdint>
#include <filesystem>
#include <map>
#include <memory>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>

#include "IOTensor.hpp"
#include "QnnApi.hpp"
#include "qnn-utils.hpp"
#include "qualla/env.hpp"

namespace qualla {

// Well-known tensor names used to look up model I/O specs.
// const added: these are lookup keys and must never be mutated; without
// const each including TU also got its own writable copy.
static const std::string INPUT_IDS = "input_ids";
static const std::string ATTN_MASK = "attention_mask";
static const std::string LOGITS = "logits";
static const std::string POS_IDS = "position_ids";

// Wrapper around a set of QNN GPU graphs forming one LLM: owns the QNN
// context, the I/O tensors, and the on-device KV cache.
class QnnGpuModel {
 public:
  // Construction-time parameters describing the model bundle on disk.
  struct Params {
    std::filesystem::path model_basedir;
    std::vector<std::string> model_list;  // model filenames
    uint32_t ctx_size;
    uint32_t num_heads;
    uint32_t head_dim;
  };

  // One KV-cache entry: key/value flag, QNN tensor id, and backing tensor.
  struct GpuKVCache {
    bool isKey;
    uint32_t tensorId;
    QnnUtils::Tensor* tensorUtil;

    GpuKVCache() {
      isKey = false;
      tensorUtil = nullptr;
      tensorId = 0;
    }
    GpuKVCache(bool _isKey, uint32_t _tensorId, QnnUtils::Tensor* _tensorUtil)
        : isKey(_isKey), tensorId(_tensorId), tensorUtil(_tensorUtil) {}
  };

  // QNN specific variables
  std::unique_ptr<QnnApi> _qnnApi;
  std::unique_ptr<IOTensor> _ioTensor{nullptr};

  // Model Location Storage
  const std::filesystem::path _modelBaseDir;
  std::vector<std::string> _modelList;
  std::vector<std::string> _modelOrder;

  bool _useDmabufIo;

  // Model parameters
  uint32_t _ctxSize{0};
  uint32_t _numHeads{0};
  uint32_t _headDim{0};

  // Information regarding model execution settings and last inference

  // Model specific variables
  uint32_t _numGraphs;
  // I/O tensor information, keyed by graph name then tensor name.
  std::unordered_map<std::string, Qnn_Tensor_t*> _inputTensors;
  std::unordered_map<std::string,
                     std::unordered_map<std::string, std::shared_ptr<QnnUtils::Tensor>>>
      _inputSpecs;

  std::unordered_map<std::string, Qnn_Tensor_t*> _outputTensors;
  std::unordered_map<std::string,
                     std::unordered_map<std::string, std::shared_ptr<QnnUtils::Tensor>>>
      _outputSpecs;

  // Store some pointers for easier access
  QnnUtils::Tensor* t_inputIds{nullptr};
  QnnUtils::Tensor* t_attnMask{nullptr};
  QnnUtils::Tensor* t_positionIds{nullptr};
  QnnUtils::Tensor* t_logits{nullptr};

  // _numTokensProcessed defines number of population of kvcache
  size_t _numTokensProcessed{0};

  std::vector<GpuKVCache> _kvCache;

  // Per-label accumulated microseconds and call count, for KPI logging.
  std::map<std::string, std::pair<double, uint16_t>> timeLogs;

  // Model Constructor
  QnnGpuModel(Env& env, const Params& params);
  ~QnnGpuModel();

  bool initializeModel(void);
  bool initializeIOTensors(void);
  void setupInputTensors(const std::vector<int32_t>& tokens);
  bool initializeTensorPointers();
  bool validateModel();

  template <class T1, class T2>
  inline bool executeModel(T1& input, T2& output, std::string graph_name);

  size_t runInference(const std::vector<int32_t>& tokens,
                      std::vector<float>& logits,
                      bool logits_all = false);

  // Parameter names fixed: load takes the path to read, save the path to
  // write (the original declaration had them swapped).
  size_t loadKVCache(const std::string& load_path);
  bool saveKVCache(const std::string& save_path);
  bool reset();

 private:
  Env& _env;
  // Internal functions to separate different runInference logic
  bool runInferenceHelper(std::vector<std::string>& exec_models,
                          int32_t* wait_time_total,
                          int32_t* exec_time_total,
                          bool pipeline_kv_update,
                          size_t update_size);
  size_t processLogits(std::vector<float>& logits, bool logits_all);
  inline void* getBuffer(QnnUtils::Tensor& spec) { return _ioTensor->getBuffer(spec.tensor); }
  inline void* getBuffer(QnnUtils::Tensor* spec) { return _ioTensor->getBuffer(spec->tensor); }
  inline size_t getBufferSize(QnnUtils::Tensor& spec) { return spec.dims.getSize(); }
  inline size_t getBufferSize(QnnUtils::Tensor* spec) { return spec->dims.getSize(); }
  inline size_t getNumElements(QnnUtils::Tensor& spec) { return spec.dims.getNumElements(); }
  inline size_t getNumElements(QnnUtils::Tensor* spec) { return spec->dims.getNumElements(); }

  // Parse KV$ Tensor names here - supports past_{key,value}_{layer_idx}[_h0]_{in,out}
  std::tuple<int, int> parseKVTensorName(std::string name);
};

}  // namespace qualla

#endif  // QUALLA_QNN_GPU_MODEL_H_
|
Genie/Genie/src/qualla/engines/qnn-htp.cpp
CHANGED
|
@@ -353,11 +353,11 @@ qualla::InputType NspEngine::getInputType(){
|
|
| 353 |
return _model->m_inputType;
|
| 354 |
}
|
| 355 |
|
| 356 |
-
size_t NspEngine::restore(const std::string& name) {
|
| 357 |
if (!_model && !load()) return 0;
|
| 358 |
|
| 359 |
fs::path cache_path = std::filesystem::path(name) / fmt::format("kv-cache.{}.qnn-htp", _role);
|
| 360 |
-
return _model->loadKVCache(cache_path.string());
|
| 361 |
}
|
| 362 |
|
| 363 |
bool NspEngine::save(const std::string& name) {
|
|
|
|
| 353 |
return _model->m_inputType;
|
| 354 |
}
|
| 355 |
|
| 356 |
+
size_t NspEngine::restore(const std::string& name, bool chooseHigherVariant) {
|
| 357 |
if (!_model && !load()) return 0;
|
| 358 |
|
| 359 |
fs::path cache_path = std::filesystem::path(name) / fmt::format("kv-cache.{}.qnn-htp", _role);
|
| 360 |
+
return _model->loadKVCache(cache_path.string(), chooseHigherVariant);
|
| 361 |
}
|
| 362 |
|
| 363 |
bool NspEngine::save(const std::string& name) {
|
Genie/Genie/src/qualla/engines/qnn-htp.hpp
CHANGED
|
@@ -70,7 +70,7 @@ class NspEngine : public Engine {
|
|
| 70 |
virtual bool updateKV(size_t n_past) override;
|
| 71 |
virtual bool updateKV(size_t n_past, const std::vector<bool>& selected) override;
|
| 72 |
virtual bool save(const std::string& name) override;
|
| 73 |
-
virtual size_t restore(const std::string& name) override;
|
| 74 |
virtual void reset() override;
|
| 75 |
|
| 76 |
virtual bool set(qualla::json data) override;
|
|
|
|
| 70 |
virtual bool updateKV(size_t n_past) override;
|
| 71 |
virtual bool updateKV(size_t n_past, const std::vector<bool>& selected) override;
|
| 72 |
virtual bool save(const std::string& name) override;
|
| 73 |
+
virtual size_t restore(const std::string& name, bool chooseHigherVariant) override;
|
| 74 |
virtual void reset() override;
|
| 75 |
|
| 76 |
virtual bool set(qualla::json data) override;
|
Genie/Genie/src/qualla/engines/qnn-htp/nsp-kvmanager.cpp
CHANGED
|
@@ -338,7 +338,7 @@ bool NewNSPKVManager::registerPointerOffset() {
|
|
| 338 |
return true;
|
| 339 |
}
|
| 340 |
|
| 341 |
-
|
| 342 |
// clang-format off
|
| 343 |
__TRACE("qnn-kv : graph[{}] updateState to AR-{}(n_past={}, ptr={})", _mgr_idx,
|
| 344 |
_req_state.variant, _req_state.n_past, _req_state.ptr_offset);
|
|
@@ -354,9 +354,15 @@ bool NewNSPKVManager::updateState() {
|
|
| 354 |
cache.output_buffer += cache.is_key ? _n_ctx * _bw : _n_ctx * _n_embed * _bw;
|
| 355 |
}
|
| 356 |
}
|
| 357 |
-
|
| 358 |
_cur_state = _req_state;
|
|
|
|
|
|
|
| 359 |
_counter = _callback_fn(_mgr_idx);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 360 |
return true;
|
| 361 |
}
|
| 362 |
|
|
@@ -525,7 +531,7 @@ bool NewNSPKVManager::loadCache(
|
|
| 525 |
}
|
| 526 |
|
| 527 |
_req_state = {variant, n_valid, 0};
|
| 528 |
-
|
| 529 |
|
| 530 |
return true;
|
| 531 |
}
|
|
|
|
| 338 |
return true;
|
| 339 |
}
|
| 340 |
|
| 341 |
+
void NewNSPKVManager::updateKVCache(){
|
| 342 |
// clang-format off
|
| 343 |
__TRACE("qnn-kv : graph[{}] updateState to AR-{}(n_past={}, ptr={})", _mgr_idx,
|
| 344 |
_req_state.variant, _req_state.n_past, _req_state.ptr_offset);
|
|
|
|
| 354 |
cache.output_buffer += cache.is_key ? _n_ctx * _bw : _n_ctx * _n_embed * _bw;
|
| 355 |
}
|
| 356 |
}
|
|
|
|
| 357 |
_cur_state = _req_state;
|
| 358 |
+
}
|
| 359 |
+
void NewNSPKVManager::updateKVDispatcher(){
|
| 360 |
_counter = _callback_fn(_mgr_idx);
|
| 361 |
+
}
|
| 362 |
+
|
| 363 |
+
// Commit the pending KV-cache state, then kick the dispatcher so update
// jobs run against the new state. Always reports success.
bool NewNSPKVManager::updateState() {
  updateKVCache();
  updateKVDispatcher();
  return true;
}
|
| 368 |
|
|
|
|
| 531 |
}
|
| 532 |
|
| 533 |
_req_state = {variant, n_valid, 0};
|
| 534 |
+
updateKVCache();
|
| 535 |
|
| 536 |
return true;
|
| 537 |
}
|
Genie/Genie/src/qualla/engines/qnn-htp/nsp-kvmanager.hpp
CHANGED
|
@@ -134,7 +134,8 @@ class NewNSPKVManager {
|
|
| 134 |
int32_t n_heads
|
| 135 |
);
|
| 136 |
bool dumpCache(std::ofstream* fs, bool is_key, int32_t n_valid, int32_t n_heads);
|
| 137 |
-
|
|
|
|
| 138 |
bool updateState();
|
| 139 |
void runKVUpdateJob(int thread_idx); // Worker thread function
|
| 140 |
void setTensorAllocInfo(std::map<std::string, std::pair<int, size_t>>* alloc_info) {
|
|
|
|
| 134 |
int32_t n_heads
|
| 135 |
);
|
| 136 |
bool dumpCache(std::ofstream* fs, bool is_key, int32_t n_valid, int32_t n_heads);
|
| 137 |
+
void updateKVCache();
|
| 138 |
+
void updateKVDispatcher();
|
| 139 |
bool updateState();
|
| 140 |
void runKVUpdateJob(int thread_idx); // Worker thread function
|
| 141 |
void setTensorAllocInfo(std::map<std::string, std::pair<int, size_t>>* alloc_info) {
|
Genie/Genie/src/qualla/engines/qnn-htp/nsp-model.cpp
CHANGED
|
@@ -1960,6 +1960,9 @@ bool QnnNspModel::calculate_rope_embeddings(void) {
|
|
| 1960 |
const size_t nmemb = m_ctx_size * m_pos_dim;
|
| 1961 |
const int pos_bw = d_pos.bw();
|
| 1962 |
|
|
|
|
|
|
|
|
|
|
| 1963 |
rope_sin = malloc(nmemb * pos_bw);
|
| 1964 |
rope_cos = malloc(nmemb * pos_bw);
|
| 1965 |
|
|
@@ -1973,7 +1976,7 @@ bool QnnNspModel::calculate_rope_embeddings(void) {
|
|
| 1973 |
std::vector<double> inv_freq(m_pos_dim);
|
| 1974 |
const double exponent = 1.0 / static_cast<double>(m_pos_dim);
|
| 1975 |
for (int j = 0; j < m_pos_dim; j++)
|
| 1976 |
-
inv_freq[j] = 1.0 / pow(
|
| 1977 |
double attention_factor = 1.0;
|
| 1978 |
if (rope_scaling.rope_type == RopeScalingParams::ROPE_LLAMA3) {
|
| 1979 |
// Implemented from HuggingFace
|
|
@@ -1991,7 +1994,7 @@ bool QnnNspModel::calculate_rope_embeddings(void) {
|
|
| 1991 |
if (wavelen < high_freq_wavelen) // wavelen < high_freq_wavelen: do nothing
|
| 1992 |
continue;
|
| 1993 |
else if (wavelen > low_freq_wavelen) // wavelen > low_freq_wavelen: divide by factor
|
| 1994 |
-
inv_freq[j] = 1.0 / static_cast<double>(factor * pow(
|
| 1995 |
else { // otherwise: interpolate between the two, using a smooth factor
|
| 1996 |
assert(low_freq_wavelen != high_freq_wavelen);
|
| 1997 |
const double smooth =
|
|
@@ -2266,7 +2269,7 @@ void QnnNspModel::dumpTensorSpecs() {
|
|
| 2266 |
}
|
| 2267 |
}
|
| 2268 |
|
| 2269 |
-
size_t QnnNspModel::loadKVCache(const std::string& load_path) {
|
| 2270 |
|
| 2271 |
if(m_disableKvCache){
|
| 2272 |
__ERROR("KV cache is disabled, loading KV cache is not allowed");
|
|
@@ -2308,7 +2311,8 @@ size_t QnnNspModel::loadKVCache(const std::string& load_path) {
|
|
| 2308 |
// clang-format on
|
| 2309 |
|
| 2310 |
const int32_t n_valid = static_cast<int32_t>(spec.update_size);
|
| 2311 |
-
|
|
|
|
| 2312 |
_kv_dispatcher->setVariant(variant);
|
| 2313 |
|
| 2314 |
// Lock, load KeyCache then ValueCache, unlock
|
|
|
|
| 1960 |
const size_t nmemb = m_ctx_size * m_pos_dim;
|
| 1961 |
const int pos_bw = d_pos.bw();
|
| 1962 |
|
| 1963 |
+
const double theta = m_positional_encoding.rope_params.theta;
|
| 1964 |
+
const RopeScalingParams& rope_scaling = m_positional_encoding.rope_params.rope_scaling;
|
| 1965 |
+
|
| 1966 |
rope_sin = malloc(nmemb * pos_bw);
|
| 1967 |
rope_cos = malloc(nmemb * pos_bw);
|
| 1968 |
|
|
|
|
| 1976 |
std::vector<double> inv_freq(m_pos_dim);
|
| 1977 |
const double exponent = 1.0 / static_cast<double>(m_pos_dim);
|
| 1978 |
for (int j = 0; j < m_pos_dim; j++)
|
| 1979 |
+
inv_freq[j] = 1.0 / pow(theta, j * exponent);
|
| 1980 |
double attention_factor = 1.0;
|
| 1981 |
if (rope_scaling.rope_type == RopeScalingParams::ROPE_LLAMA3) {
|
| 1982 |
// Implemented from HuggingFace
|
|
|
|
| 1994 |
if (wavelen < high_freq_wavelen) // wavelen < high_freq_wavelen: do nothing
|
| 1995 |
continue;
|
| 1996 |
else if (wavelen > low_freq_wavelen) // wavelen > low_freq_wavelen: divide by factor
|
| 1997 |
+
inv_freq[j] = 1.0 / static_cast<double>(factor * pow(theta, j * exponent));
|
| 1998 |
else { // otherwise: interpolate between the two, using a smooth factor
|
| 1999 |
assert(low_freq_wavelen != high_freq_wavelen);
|
| 2000 |
const double smooth =
|
|
|
|
| 2269 |
}
|
| 2270 |
}
|
| 2271 |
|
| 2272 |
+
size_t QnnNspModel::loadKVCache(const std::string& load_path, bool chooseHigherVariant) {
|
| 2273 |
|
| 2274 |
if(m_disableKvCache){
|
| 2275 |
__ERROR("KV cache is disabled, loading KV cache is not allowed");
|
|
|
|
| 2311 |
// clang-format on
|
| 2312 |
|
| 2313 |
const int32_t n_valid = static_cast<int32_t>(spec.update_size);
|
| 2314 |
+
int32_t variant = nsp_graph_count.begin()->first; // Set KVManager to smallest variant
|
| 2315 |
+
if(chooseHigherVariant) variant = nsp_graph_count.rbegin()->first; // Ideal for loading KV prefix cache
|
| 2316 |
_kv_dispatcher->setVariant(variant);
|
| 2317 |
|
| 2318 |
// Lock, load KeyCache then ValueCache, unlock
|
Genie/Genie/src/qualla/engines/qnn-htp/nsp-model.hpp
CHANGED
|
@@ -54,14 +54,14 @@ struct RopeScalingParams {
|
|
| 54 |
double low_freq_factor;
|
| 55 |
double high_freq_factor;
|
| 56 |
int original_max_position_embeddings;
|
| 57 |
-
} llama3_params;
|
| 58 |
|
| 59 |
struct {
|
| 60 |
double factor;
|
| 61 |
std::vector<double> long_factor;
|
| 62 |
std::vector<double> short_factor;
|
| 63 |
int original_max_position_embeddings;
|
| 64 |
-
} longrope_params;
|
| 65 |
|
| 66 |
RopeScalingParams() {}
|
| 67 |
};
|
|
@@ -79,7 +79,7 @@ struct PositionalEncoding {
|
|
| 79 |
int32_t dims;
|
| 80 |
double theta;
|
| 81 |
RopeScalingParams rope_scaling;
|
| 82 |
-
} rope_params;
|
| 83 |
|
| 84 |
PositionalEncoding() { type = ROPE; }
|
| 85 |
};
|
|
@@ -265,10 +265,8 @@ class QnnNspModel {
|
|
| 265 |
QnnUtils::Tensor* t_position_ids{nullptr};
|
| 266 |
// PositionalEncodingType::ROPE variables
|
| 267 |
int32_t m_pos_dim{-1}; // Dimension of positional embedding tensor (incl partial_factor)
|
| 268 |
-
double rope_theta{10000.0}; // Base theta parameter for RoPE calculations
|
| 269 |
void* rope_sin{nullptr}; // Pre-calculated RoPE sin table of size [ctx_size, m_pos_dim]
|
| 270 |
void* rope_cos{nullptr}; // Pre-calculated RoPE cos table of size [ctx_size, m_pos_dim]
|
| 271 |
-
RopeScalingParams rope_scaling; // RoPE scaling parameters
|
| 272 |
|
| 273 |
QnnUtils::Tensor* t_position_ids_sin{nullptr};
|
| 274 |
QnnUtils::Tensor* t_position_ids_cos{nullptr};
|
|
@@ -398,7 +396,7 @@ class QnnNspModel {
|
|
| 398 |
|
| 399 |
bool debugOutputs(QnnUtils::Tensor* outTensor, std::string& outTensorName);
|
| 400 |
|
| 401 |
-
size_t loadKVCache(const std::string& load_path);
|
| 402 |
bool saveKVCache(const std::string& save_path);
|
| 403 |
bool applyLoraStrength(const std::string& alpha_tensor_name, const float alpha_val);
|
| 404 |
bool applyLoraAdapter(const std::string& lora_adapter_name);
|
|
|
|
| 54 |
double low_freq_factor;
|
| 55 |
double high_freq_factor;
|
| 56 |
int original_max_position_embeddings;
|
| 57 |
+
} llama3_params {0};
|
| 58 |
|
| 59 |
struct {
|
| 60 |
double factor;
|
| 61 |
std::vector<double> long_factor;
|
| 62 |
std::vector<double> short_factor;
|
| 63 |
int original_max_position_embeddings;
|
| 64 |
+
} longrope_params {0};
|
| 65 |
|
| 66 |
RopeScalingParams() {}
|
| 67 |
};
|
|
|
|
| 79 |
int32_t dims;
|
| 80 |
double theta;
|
| 81 |
RopeScalingParams rope_scaling;
|
| 82 |
+
} rope_params {0};
|
| 83 |
|
| 84 |
PositionalEncoding() { type = ROPE; }
|
| 85 |
};
|
|
|
|
| 265 |
QnnUtils::Tensor* t_position_ids{nullptr};
|
| 266 |
// PositionalEncodingType::ROPE variables
|
| 267 |
int32_t m_pos_dim{-1}; // Dimension of positional embedding tensor (incl partial_factor)
|
|
|
|
| 268 |
void* rope_sin{nullptr}; // Pre-calculated RoPE sin table of size [ctx_size, m_pos_dim]
|
| 269 |
void* rope_cos{nullptr}; // Pre-calculated RoPE cos table of size [ctx_size, m_pos_dim]
|
|
|
|
| 270 |
|
| 271 |
QnnUtils::Tensor* t_position_ids_sin{nullptr};
|
| 272 |
QnnUtils::Tensor* t_position_ids_cos{nullptr};
|
|
|
|
| 396 |
|
| 397 |
bool debugOutputs(QnnUtils::Tensor* outTensor, std::string& outTensorName);
|
| 398 |
|
| 399 |
+
size_t loadKVCache(const std::string& load_path, bool chooseHigherVariant=false);
|
| 400 |
bool saveKVCache(const std::string& save_path);
|
| 401 |
bool applyLoraStrength(const std::string& alpha_tensor_name, const float alpha_val);
|
| 402 |
bool applyLoraAdapter(const std::string& lora_adapter_name);
|
Genie/Genie/src/qualla/include/qualla/detail/basic-sampler.hpp
CHANGED
|
@@ -39,6 +39,7 @@ class BasicSampler : public Sampler {
|
|
| 39 |
virtual bool save(const std::string& name) override;
|
| 40 |
virtual bool restore(const std::string& name) override;
|
| 41 |
virtual void reset() override;
|
|
|
|
| 42 |
|
| 43 |
protected:
|
| 44 |
int32_t _process(std::span<const float> logits, std::vector<float>* probs_out, bool samp_tok);
|
|
|
|
| 39 |
virtual bool save(const std::string& name) override;
|
| 40 |
virtual bool restore(const std::string& name) override;
|
| 41 |
virtual void reset() override;
|
| 42 |
+
virtual void applyConfig(const qualla::json& conf) override;
|
| 43 |
|
| 44 |
protected:
|
| 45 |
int32_t _process(std::span<const float> logits, std::vector<float>* probs_out, bool samp_tok);
|
Genie/Genie/src/qualla/include/qualla/dialog.hpp
CHANGED
|
@@ -107,6 +107,7 @@ class Dialog : public State {
|
|
| 107 |
Tokenizer& tokenizer() { return *_tokenizer; }
|
| 108 |
Sampler& sampler(const std::string& role = "primary") { return *_sampler[role]; }
|
| 109 |
Engine& engine(const std::string& role = "primary") { return *_engine[role]; }
|
|
|
|
| 110 |
|
| 111 |
// Get latest KPIs.
|
| 112 |
// Updates TPS, etc as needed.
|
|
|
|
| 107 |
Tokenizer& tokenizer() { return *_tokenizer; }
|
| 108 |
Sampler& sampler(const std::string& role = "primary") { return *_sampler[role]; }
|
| 109 |
Engine& engine(const std::string& role = "primary") { return *_engine[role]; }
|
| 110 |
+
bool isSamplerPresent(std::string role) { return _sampler.find(role) != _sampler.end(); }
|
| 111 |
|
| 112 |
// Get latest KPIs.
|
| 113 |
// Updates TPS, etc as needed.
|
Genie/Genie/src/qualla/include/qualla/engine.hpp
CHANGED
|
@@ -86,7 +86,7 @@ class Engine : public State {
|
|
| 86 |
QUALLA_API virtual bool updateKV(size_t n_past, const std::vector<bool>& selected);
|
| 87 |
|
| 88 |
QUALLA_API virtual bool save(const std::string& name);
|
| 89 |
-
QUALLA_API virtual size_t restore(const std::string& name);
|
| 90 |
QUALLA_API virtual void reset();
|
| 91 |
|
| 92 |
QUALLA_API virtual bool cacheEosEmbedding(std::vector<uint8_t>& eosEmbedding);
|
|
|
|
| 86 |
QUALLA_API virtual bool updateKV(size_t n_past, const std::vector<bool>& selected);
|
| 87 |
|
| 88 |
QUALLA_API virtual bool save(const std::string& name);
|
| 89 |
+
QUALLA_API virtual size_t restore(const std::string& name, bool chooseHigherVariant=false);
|
| 90 |
QUALLA_API virtual void reset();
|
| 91 |
|
| 92 |
QUALLA_API virtual bool cacheEosEmbedding(std::vector<uint8_t>& eosEmbedding);
|
Genie/Genie/src/qualla/include/qualla/sampler.hpp
CHANGED
|
@@ -54,6 +54,7 @@ class Sampler : public State {
|
|
| 54 |
QUALLA_API virtual bool save(const std::string& name);
|
| 55 |
QUALLA_API virtual bool restore(const std::string& name);
|
| 56 |
QUALLA_API virtual void reset();
|
|
|
|
| 57 |
|
| 58 |
// Get sampler type
|
| 59 |
const std::string& type() const { return _type; }
|
|
|
|
| 54 |
QUALLA_API virtual bool save(const std::string& name);
|
| 55 |
QUALLA_API virtual bool restore(const std::string& name);
|
| 56 |
QUALLA_API virtual void reset();
|
| 57 |
+
QUALLA_API virtual void applyConfig(const qualla::json& conf);
|
| 58 |
|
| 59 |
// Get sampler type
|
| 60 |
const std::string& type() const { return _type; }
|
Genie/Genie/src/qualla/sampler.cpp
CHANGED
|
@@ -84,6 +84,10 @@ std::vector<int32_t> Sampler::process_multiple(
|
|
| 84 |
return {-1};
|
| 85 |
}
|
| 86 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
// Sampler registry
|
| 88 |
|
| 89 |
using Registry = std::unordered_map<std::string, Sampler::Creator>;
|
|
|
|
| 84 |
return {-1};
|
| 85 |
}
|
| 86 |
|
| 87 |
+
// Base implementation of runtime config updates: only BasicSampler
// overrides this meaningfully, so other sampler types just warn and
// ignore the supplied config.
void Sampler::applyConfig(const qualla::json& conf) {
  (void)conf;  // intentionally unused in the default implementation
  // The fmt::format wrapper was a needless allocation: the message has no
  // placeholders, so the literal is passed straight through.
  _env.logger().warn("Basic sampler supports this for now");
}
|
| 90 |
+
|
| 91 |
// Sampler registry
|
| 92 |
|
| 93 |
using Registry = std::unordered_map<std::string, Sampler::Creator>;
|
Genie/Genie/src/qualla/samplers/basic.cpp
CHANGED
|
@@ -221,4 +221,12 @@ static OnLoad regy([]() {
|
|
| 221 |
|
| 222 |
void needBasicSampler() {}
|
| 223 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 224 |
} // namespace qualla
|
|
|
|
| 221 |
|
| 222 |
void needBasicSampler() {}
|
| 223 |
|
| 224 |
+
void BasicSampler::applyConfig(const json& conf) {
|
| 225 |
+
if (conf.contains("seed")) _seed = conf["seed"];
|
| 226 |
+
if (conf.contains("temp")) _temp = conf["temp"];
|
| 227 |
+
|
| 228 |
+
if (conf.contains("top-k")) _top_k = conf["top-k"];
|
| 229 |
+
if (conf.contains("top-p")) _top_p = conf["top-p"];
|
| 230 |
+
}
|
| 231 |
+
|
| 232 |
} // namespace qualla
|
Genie/Genie/src/qualla/tokenizers/rust/Cargo.lock
CHANGED
|
@@ -31,9 +31,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
|
|
| 31 |
|
| 32 |
[[package]]
|
| 33 |
name = "cc"
|
| 34 |
-
version = "1.1
|
| 35 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 36 |
-
checksum = "
|
| 37 |
dependencies = [
|
| 38 |
"shlex",
|
| 39 |
]
|
|
@@ -190,9 +190,9 @@ dependencies = [
|
|
| 190 |
|
| 191 |
[[package]]
|
| 192 |
name = "itoa"
|
| 193 |
-
version = "1.0.
|
| 194 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 195 |
-
checksum = "
|
| 196 |
|
| 197 |
[[package]]
|
| 198 |
name = "lazy_static"
|
|
@@ -202,9 +202,9 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
|
|
| 202 |
|
| 203 |
[[package]]
|
| 204 |
name = "libc"
|
| 205 |
-
version = "0.2.
|
| 206 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 207 |
-
checksum = "
|
| 208 |
|
| 209 |
[[package]]
|
| 210 |
name = "log"
|
|
@@ -322,9 +322,9 @@ dependencies = [
|
|
| 322 |
|
| 323 |
[[package]]
|
| 324 |
name = "proc-macro2"
|
| 325 |
-
version = "1.0.
|
| 326 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 327 |
-
checksum = "
|
| 328 |
dependencies = [
|
| 329 |
"unicode-ident",
|
| 330 |
]
|
|
@@ -413,9 +413,9 @@ dependencies = [
|
|
| 413 |
|
| 414 |
[[package]]
|
| 415 |
name = "regex-automata"
|
| 416 |
-
version = "0.4.
|
| 417 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 418 |
-
checksum = "
|
| 419 |
dependencies = [
|
| 420 |
"aho-corasick",
|
| 421 |
"memchr",
|
|
@@ -436,18 +436,18 @@ checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f"
|
|
| 436 |
|
| 437 |
[[package]]
|
| 438 |
name = "serde"
|
| 439 |
-
version = "1.0.
|
| 440 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 441 |
-
checksum = "
|
| 442 |
dependencies = [
|
| 443 |
"serde_derive",
|
| 444 |
]
|
| 445 |
|
| 446 |
[[package]]
|
| 447 |
name = "serde_derive"
|
| 448 |
-
version = "1.0.
|
| 449 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 450 |
-
checksum = "
|
| 451 |
dependencies = [
|
| 452 |
"proc-macro2",
|
| 453 |
"quote",
|
|
@@ -456,9 +456,9 @@ dependencies = [
|
|
| 456 |
|
| 457 |
[[package]]
|
| 458 |
name = "serde_json"
|
| 459 |
-
version = "1.0.
|
| 460 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 461 |
-
checksum = "
|
| 462 |
dependencies = [
|
| 463 |
"itoa",
|
| 464 |
"memchr",
|
|
@@ -498,9 +498,9 @@ checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
|
|
| 498 |
|
| 499 |
[[package]]
|
| 500 |
name = "syn"
|
| 501 |
-
version = "2.0.
|
| 502 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 503 |
-
checksum = "
|
| 504 |
dependencies = [
|
| 505 |
"proc-macro2",
|
| 506 |
"quote",
|
|
@@ -509,18 +509,18 @@ dependencies = [
|
|
| 509 |
|
| 510 |
[[package]]
|
| 511 |
name = "thiserror"
|
| 512 |
-
version = "1.0.
|
| 513 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 514 |
-
checksum = "
|
| 515 |
dependencies = [
|
| 516 |
"thiserror-impl",
|
| 517 |
]
|
| 518 |
|
| 519 |
[[package]]
|
| 520 |
name = "thiserror-impl"
|
| 521 |
-
version = "1.0.
|
| 522 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 523 |
-
checksum = "
|
| 524 |
dependencies = [
|
| 525 |
"proc-macro2",
|
| 526 |
"quote",
|
|
@@ -529,9 +529,9 @@ dependencies = [
|
|
| 529 |
|
| 530 |
[[package]]
|
| 531 |
name = "tokenizers"
|
| 532 |
-
version = "0.20.
|
| 533 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 534 |
-
checksum = "
|
| 535 |
dependencies = [
|
| 536 |
"aho-corasick",
|
| 537 |
"derive_builder",
|
|
@@ -569,9 +569,9 @@ dependencies = [
|
|
| 569 |
|
| 570 |
[[package]]
|
| 571 |
name = "unicode-ident"
|
| 572 |
-
version = "1.0.
|
| 573 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 574 |
-
checksum = "
|
| 575 |
|
| 576 |
[[package]]
|
| 577 |
name = "unicode-normalization-alignments"
|
|
|
|
| 31 |
|
| 32 |
[[package]]
|
| 33 |
name = "cc"
|
| 34 |
+
version = "1.2.1"
|
| 35 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 36 |
+
checksum = "fd9de9f2205d5ef3fd67e685b0df337994ddd4495e2a28d185500d0e1edfea47"
|
| 37 |
dependencies = [
|
| 38 |
"shlex",
|
| 39 |
]
|
|
|
|
| 190 |
|
| 191 |
[[package]]
|
| 192 |
name = "itoa"
|
| 193 |
+
version = "1.0.13"
|
| 194 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 195 |
+
checksum = "540654e97a3f4470a492cd30ff187bc95d89557a903a2bbf112e2fae98104ef2"
|
| 196 |
|
| 197 |
[[package]]
|
| 198 |
name = "lazy_static"
|
|
|
|
| 202 |
|
| 203 |
[[package]]
|
| 204 |
name = "libc"
|
| 205 |
+
version = "0.2.164"
|
| 206 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 207 |
+
checksum = "433bfe06b8c75da9b2e3fbea6e5329ff87748f0b144ef75306e674c3f6f7c13f"
|
| 208 |
|
| 209 |
[[package]]
|
| 210 |
name = "log"
|
|
|
|
| 322 |
|
| 323 |
[[package]]
|
| 324 |
name = "proc-macro2"
|
| 325 |
+
version = "1.0.91"
|
| 326 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 327 |
+
checksum = "307e3004becf10f5a6e0d59d20f3cd28231b0e0827a96cd3e0ce6d14bc1e4bb3"
|
| 328 |
dependencies = [
|
| 329 |
"unicode-ident",
|
| 330 |
]
|
|
|
|
| 413 |
|
| 414 |
[[package]]
|
| 415 |
name = "regex-automata"
|
| 416 |
+
version = "0.4.9"
|
| 417 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 418 |
+
checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908"
|
| 419 |
dependencies = [
|
| 420 |
"aho-corasick",
|
| 421 |
"memchr",
|
|
|
|
| 436 |
|
| 437 |
[[package]]
|
| 438 |
name = "serde"
|
| 439 |
+
version = "1.0.215"
|
| 440 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 441 |
+
checksum = "6513c1ad0b11a9376da888e3e0baa0077f1aed55c17f50e7b2397136129fb88f"
|
| 442 |
dependencies = [
|
| 443 |
"serde_derive",
|
| 444 |
]
|
| 445 |
|
| 446 |
[[package]]
|
| 447 |
name = "serde_derive"
|
| 448 |
+
version = "1.0.215"
|
| 449 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 450 |
+
checksum = "ad1e866f866923f252f05c889987993144fb74e722403468a4ebd70c3cd756c0"
|
| 451 |
dependencies = [
|
| 452 |
"proc-macro2",
|
| 453 |
"quote",
|
|
|
|
| 456 |
|
| 457 |
[[package]]
|
| 458 |
name = "serde_json"
|
| 459 |
+
version = "1.0.133"
|
| 460 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 461 |
+
checksum = "c7fceb2473b9166b2294ef05efcb65a3db80803f0b03ef86a5fc88a2b85ee377"
|
| 462 |
dependencies = [
|
| 463 |
"itoa",
|
| 464 |
"memchr",
|
|
|
|
| 498 |
|
| 499 |
[[package]]
|
| 500 |
name = "syn"
|
| 501 |
+
version = "2.0.89"
|
| 502 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 503 |
+
checksum = "44d46482f1c1c87acd84dea20c1bf5ebff4c757009ed6bf19cfd36fb10e92c4e"
|
| 504 |
dependencies = [
|
| 505 |
"proc-macro2",
|
| 506 |
"quote",
|
|
|
|
| 509 |
|
| 510 |
[[package]]
|
| 511 |
name = "thiserror"
|
| 512 |
+
version = "1.0.69"
|
| 513 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 514 |
+
checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52"
|
| 515 |
dependencies = [
|
| 516 |
"thiserror-impl",
|
| 517 |
]
|
| 518 |
|
| 519 |
[[package]]
|
| 520 |
name = "thiserror-impl"
|
| 521 |
+
version = "1.0.69"
|
| 522 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 523 |
+
checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1"
|
| 524 |
dependencies = [
|
| 525 |
"proc-macro2",
|
| 526 |
"quote",
|
|
|
|
| 529 |
|
| 530 |
[[package]]
|
| 531 |
name = "tokenizers"
|
| 532 |
+
version = "0.20.3"
|
| 533 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 534 |
+
checksum = "67b67c92f6d705e2a1d106fb0b28c696f9074901a9c656ee5d9f5de204c39bf7"
|
| 535 |
dependencies = [
|
| 536 |
"aho-corasick",
|
| 537 |
"derive_builder",
|
|
|
|
| 569 |
|
| 570 |
[[package]]
|
| 571 |
name = "unicode-ident"
|
| 572 |
+
version = "1.0.14"
|
| 573 |
source = "registry+https://github.com/rust-lang/crates.io-index"
|
| 574 |
+
checksum = "adb9e6ca4f869e1180728b7950e35922a7fc6397f7b641499e8f3ef06e50dc83"
|
| 575 |
|
| 576 |
[[package]]
|
| 577 |
name = "unicode-normalization-alignments"
|
Genie/Model/model.cpp
CHANGED
|
@@ -179,8 +179,29 @@ MODEL_LIB_EXPORT ModelError_t QnnModel_GenAI_composeGraphs(Qnn_BackendHandle_t b
|
|
| 179 |
(Qnn_Tensor_t)tin6),
|
| 180 |
err);
|
| 181 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
/* ADDING NODE FOR genAI */
|
| 183 |
-
const char* inputs_genAI[] = {"x0", "x1", "x2", "x3", "x4", "x5"};
|
| 184 |
|
| 185 |
Qnn_Tensor_t tout;
|
| 186 |
tout.version = QNN_TENSOR_VERSION_1;
|
|
@@ -224,7 +245,7 @@ MODEL_LIB_EXPORT ModelError_t QnnModel_GenAI_composeGraphs(Qnn_BackendHandle_t b
|
|
| 224 |
params, // Node Params
|
| 225 |
numParams, // Num Node Params
|
| 226 |
inputs_genAI, // Input Tensor Names
|
| 227 |
-
|
| 228 |
outputs_genAI, // Output Tensors
|
| 229 |
2 // Num Output Tensors
|
| 230 |
),
|
|
|
|
| 179 |
(Qnn_Tensor_t)tin6),
|
| 180 |
err);
|
| 181 |
|
| 182 |
+
uint32_t input6Dim[1] = {1};
|
| 183 |
+
Qnn_Tensor_t tin7;
|
| 184 |
+
tin7.version = QNN_TENSOR_VERSION_1;
|
| 185 |
+
tin7.v1.id = 0;
|
| 186 |
+
tin7.v1.name = "x6";
|
| 187 |
+
tin7.v1.type = QNN_TENSOR_TYPE_APP_WRITE;
|
| 188 |
+
tin7.v1.dataFormat = QNN_TENSOR_DATA_FORMAT_FLAT_BUFFER;
|
| 189 |
+
tin7.v1.dataType = QNN_DATATYPE_FLOAT_32;
|
| 190 |
+
tin7.v1.quantizeParams.encodingDefinition = QNN_DEFINITION_UNDEFINED;
|
| 191 |
+
tin7.v1.quantizeParams.quantizationEncoding = QNN_QUANTIZATION_ENCODING_UNDEFINED;
|
| 192 |
+
tin7.v1.quantizeParams.scaleOffsetEncoding = {.scale = 0.0000000000000000f,
|
| 193 |
+
.offset = 0};
|
| 194 |
+
tin7.v1.rank = 1;
|
| 195 |
+
tin7.v1.dimensions = input6Dim;
|
| 196 |
+
tin7.v1.memType = QNN_TENSORMEMTYPE_RAW;
|
| 197 |
+
tin7.v1.clientBuf = {.data = nullptr, .dataSize = 0};
|
| 198 |
+
VALIDATE(qnn_model.addTensor(
|
| 199 |
+
"x6", // Node Name
|
| 200 |
+
(Qnn_Tensor_t)tin7),
|
| 201 |
+
err);
|
| 202 |
+
|
| 203 |
/* ADDING NODE FOR genAI */
|
| 204 |
+
const char* inputs_genAI[] = {"x0", "x1", "x2", "x3", "x4", "x5", "x6"};
|
| 205 |
|
| 206 |
Qnn_Tensor_t tout;
|
| 207 |
tout.version = QNN_TENSOR_VERSION_1;
|
|
|
|
| 245 |
params, // Node Params
|
| 246 |
numParams, // Num Node Params
|
| 247 |
inputs_genAI, // Input Tensor Names
|
| 248 |
+
7, // Num Input Tensor Names
|
| 249 |
outputs_genAI, // Output Tensors
|
| 250 |
2 // Num Output Tensors
|
| 251 |
),
|
Genie/configs/llama2-7b/llama2-7b-draft-htp-target-htp-spd.json
CHANGED
|
@@ -43,7 +43,8 @@
|
|
| 43 |
"cpu-mask": "0xe0",
|
| 44 |
"kv-dim": 64,
|
| 45 |
"kv-update-method": "SHIFT_CONCAT",
|
| 46 |
-
"allow-async-init": false
|
|
|
|
| 47 |
},
|
| 48 |
"extensions": "htp_backend_ext_config.json"
|
| 49 |
},
|
|
|
|
| 43 |
"cpu-mask": "0xe0",
|
| 44 |
"kv-dim": 64,
|
| 45 |
"kv-update-method": "SHIFT_CONCAT",
|
| 46 |
+
"allow-async-init": false,
|
| 47 |
+
"enable-graph-switching": false
|
| 48 |
},
|
| 49 |
"extensions": "htp_backend_ext_config.json"
|
| 50 |
},
|
Genie/configs/llama2-7b/llama2-7b-genaitransformer-lora.json
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"dialog" : {
|
| 3 |
+
"version" : 1,
|
| 4 |
+
"type" : "basic",
|
| 5 |
+
"stop-sequence" : [""],
|
| 6 |
+
"max-num-tokens" : 200,
|
| 7 |
+
"context" : {
|
| 8 |
+
"version" : 1,
|
| 9 |
+
"size": 512,
|
| 10 |
+
"n-vocab": 32000,
|
| 11 |
+
"bos-token": 1,
|
| 12 |
+
"eos-token": 2
|
| 13 |
+
},
|
| 14 |
+
"sampler" : {
|
| 15 |
+
"version" : 1,
|
| 16 |
+
"seed" : 100,
|
| 17 |
+
"temp" : 1.2,
|
| 18 |
+
"top-k" : 20,
|
| 19 |
+
"top-p" : 0.75,
|
| 20 |
+
"greedy" : false
|
| 21 |
+
},
|
| 22 |
+
"tokenizer" : {
|
| 23 |
+
"version" : 1,
|
| 24 |
+
"path" : "your/path/to/tokenizer_file.json"
|
| 25 |
+
},
|
| 26 |
+
"engine" : {
|
| 27 |
+
"version" : 1,
|
| 28 |
+
"n-threads" : 6,
|
| 29 |
+
"backend" : {
|
| 30 |
+
"version" : 1,
|
| 31 |
+
"type" : "QnnGenAiTransformer",
|
| 32 |
+
"QnnGenAiTransformer" : {
|
| 33 |
+
"version" : 1,
|
| 34 |
+
"n-layer": 32,
|
| 35 |
+
"n-embd": 4096,
|
| 36 |
+
"n-heads": 32
|
| 37 |
+
}
|
| 38 |
+
},
|
| 39 |
+
"model" : {
|
| 40 |
+
"version" : 1,
|
| 41 |
+
"type" : "library",
|
| 42 |
+
"library" : {
|
| 43 |
+
"version" : 1,
|
| 44 |
+
"model-bin" : "your/path/to/model/file.bin",
|
| 45 |
+
"lora": {
|
| 46 |
+
"version": 1,
|
| 47 |
+
"alpha-tensor-name": "alpha",
|
| 48 |
+
"adapters": [
|
| 49 |
+
{
|
| 50 |
+
"version": 1,
|
| 51 |
+
"name": "lora1",
|
| 52 |
+
"bin-sections": [
|
| 53 |
+
"your/path/to/model/lora/file.bin"
|
| 54 |
+
]
|
| 55 |
+
}
|
| 56 |
+
]
|
| 57 |
+
}
|
| 58 |
+
}
|
| 59 |
+
}
|
| 60 |
+
}
|
| 61 |
+
}
|
| 62 |
+
}
|
Genie/configs/llama2-7b/llama2-7b-genaitransformer.json
CHANGED
|
@@ -30,7 +30,10 @@
|
|
| 30 |
"version" : 1,
|
| 31 |
"type" : "QnnGenAiTransformer",
|
| 32 |
"QnnGenAiTransformer" : {
|
| 33 |
-
"version" : 1
|
|
|
|
|
|
|
|
|
|
| 34 |
}
|
| 35 |
},
|
| 36 |
"model" : {
|
|
|
|
| 30 |
"version" : 1,
|
| 31 |
"type" : "QnnGenAiTransformer",
|
| 32 |
"QnnGenAiTransformer" : {
|
| 33 |
+
"version" : 1,
|
| 34 |
+
"n-layer": 32,
|
| 35 |
+
"n-embd": 4096,
|
| 36 |
+
"n-heads": 32
|
| 37 |
}
|
| 38 |
},
|
| 39 |
"model" : {
|
Genie/configs/llama2-7b/llama2-7b-gpu.json
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"dialog" : {
|
| 3 |
+
"version" : 1,
|
| 4 |
+
"type" : "basic",
|
| 5 |
+
"context" : {
|
| 6 |
+
"version" : 1,
|
| 7 |
+
"size": 1024,
|
| 8 |
+
"n-vocab": 32000,
|
| 9 |
+
"bos-token": 1,
|
| 10 |
+
"eos-token": 2
|
| 11 |
+
},
|
| 12 |
+
"sampler" : {
|
| 13 |
+
"version" : 1,
|
| 14 |
+
"seed" : 42,
|
| 15 |
+
"temp" : 1.1,
|
| 16 |
+
"top-k" : 40,
|
| 17 |
+
"top-p" : 0.95,
|
| 18 |
+
"greedy" : false
|
| 19 |
+
},
|
| 20 |
+
"tokenizer" : {
|
| 21 |
+
"version" : 1,
|
| 22 |
+
"path" : "/path/to/tokenizer.json"
|
| 23 |
+
},
|
| 24 |
+
"engine" : {
|
| 25 |
+
"version" : 1,
|
| 26 |
+
"n-threads" : 3,
|
| 27 |
+
"backend" : {
|
| 28 |
+
"version" : 1,
|
| 29 |
+
"type" : "QnnGpu"
|
| 30 |
+
},
|
| 31 |
+
"model" : {
|
| 32 |
+
"version" : 1,
|
| 33 |
+
"type" : "binary",
|
| 34 |
+
"binary" : {
|
| 35 |
+
"version" : 1,
|
| 36 |
+
"ctx-bins" : [
|
| 37 |
+
"/path/to/model.bin"
|
| 38 |
+
]
|
| 39 |
+
}
|
| 40 |
+
}
|
| 41 |
+
}
|
| 42 |
+
}
|
| 43 |
+
}
|