jstzwjr commited on
Commit
c71c7c5
·
1 Parent(s): 11481cd

add genie2.29

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. Genie/Genie/GenieSymbols.default +5 -0
  2. Genie/Genie/make/Android.mk +2 -0
  3. Genie/Genie/make/Application.mk +1 -1
  4. Genie/Genie/make/Makefile.linux-x86_64 +14 -5
  5. Genie/Genie/src/Dialog.cpp +54 -80
  6. Genie/Genie/src/Dialog.hpp +6 -1
  7. Genie/Genie/src/Embedding.cpp +740 -0
  8. Genie/Genie/src/Embedding.hpp +56 -0
  9. Genie/Genie/src/Exception.hpp +1 -0
  10. Genie/Genie/src/GenieDialog.cpp +19 -1
  11. Genie/Genie/src/GenieEmbedding.cpp +118 -0
  12. Genie/Genie/src/GenieSampler.cpp +93 -0
  13. Genie/Genie/src/Macro.hpp +2 -0
  14. Genie/Genie/src/Sampler.cpp +275 -0
  15. Genie/Genie/src/Sampler.hpp +60 -0
  16. Genie/Genie/src/qualla/context.cpp +8 -0
  17. Genie/Genie/src/qualla/dialogs/ssd-q1.cpp +2 -2
  18. Genie/Genie/src/qualla/engine.cpp +1 -1
  19. Genie/Genie/src/qualla/engines/qnn-api/DmaBufAllocator.cpp +317 -0
  20. Genie/Genie/src/qualla/engines/qnn-api/DmaBufAllocator.hpp +128 -0
  21. Genie/Genie/src/qualla/engines/qnn-api/IBufferAlloc.hpp +15 -1
  22. Genie/Genie/src/qualla/engines/qnn-api/IOTensor.cpp +76 -1
  23. Genie/Genie/src/qualla/engines/qnn-api/IOTensor.hpp +25 -0
  24. Genie/Genie/src/qualla/engines/qnn-api/QnnApi.cpp +369 -90
  25. Genie/Genie/src/qualla/engines/qnn-api/QnnApi.hpp +9 -0
  26. Genie/Genie/src/qualla/engines/qnn-api/qnn-utils.hpp +5 -5
  27. Genie/Genie/src/qualla/engines/qnn-cpu.cpp +55 -3
  28. Genie/Genie/src/qualla/engines/qnn-cpu/cpu-model.cpp +51 -0
  29. Genie/Genie/src/qualla/engines/qnn-cpu/cpu-model.hpp +25 -0
  30. Genie/Genie/src/qualla/engines/qnn-gpu.cpp +193 -0
  31. Genie/Genie/src/qualla/engines/qnn-gpu/gpu-model.cpp +603 -0
  32. Genie/Genie/src/qualla/engines/qnn-gpu/gpu-model.hpp +136 -0
  33. Genie/Genie/src/qualla/engines/qnn-htp.cpp +2 -2
  34. Genie/Genie/src/qualla/engines/qnn-htp.hpp +1 -1
  35. Genie/Genie/src/qualla/engines/qnn-htp/nsp-kvmanager.cpp +9 -3
  36. Genie/Genie/src/qualla/engines/qnn-htp/nsp-kvmanager.hpp +2 -1
  37. Genie/Genie/src/qualla/engines/qnn-htp/nsp-model.cpp +8 -4
  38. Genie/Genie/src/qualla/engines/qnn-htp/nsp-model.hpp +4 -6
  39. Genie/Genie/src/qualla/include/qualla/detail/basic-sampler.hpp +1 -0
  40. Genie/Genie/src/qualla/include/qualla/dialog.hpp +1 -0
  41. Genie/Genie/src/qualla/include/qualla/engine.hpp +1 -1
  42. Genie/Genie/src/qualla/include/qualla/sampler.hpp +1 -0
  43. Genie/Genie/src/qualla/sampler.cpp +4 -0
  44. Genie/Genie/src/qualla/samplers/basic.cpp +8 -0
  45. Genie/Genie/src/qualla/tokenizers/rust/Cargo.lock +26 -26
  46. Genie/Model/model.cpp +23 -2
  47. Genie/configs/llama2-7b/llama2-7b-draft-htp-target-htp-spd.json +2 -1
  48. Genie/configs/llama2-7b/llama2-7b-genaitransformer-lora.json +62 -0
  49. Genie/configs/llama2-7b/llama2-7b-genaitransformer.json +4 -1
  50. Genie/configs/llama2-7b/llama2-7b-gpu.json +43 -0
Genie/Genie/GenieSymbols.default CHANGED
@@ -14,6 +14,11 @@
14
  GenieDialogConfig_free*;
15
  GenieDialog_create*;
16
  GenieDialog_query*;
 
 
 
 
 
17
  GenieDialog_tokenQuery*;
18
  GenieDialog_embeddingQuery*;
19
  GenieDialog_save*;
 
14
  GenieDialogConfig_free*;
15
  GenieDialog_create*;
16
  GenieDialog_query*;
17
+ GenieDialog_getSampler*;
18
+ GenieSampler_applyConfig*;
19
+ GenieSamplerConfig_createFromJson*;
20
+ GenieSamplerConfig_setParam*;
21
+ GenieSamplerConfig_free*;
22
  GenieDialog_tokenQuery*;
23
  GenieDialog_embeddingQuery*;
24
  GenieDialog_save*;
Genie/Genie/make/Android.mk CHANGED
@@ -29,6 +29,7 @@ PACKAGE_C_INCLUDES += -I $(LOCAL_PATH)/../../../../include/QNN/HTP
29
  PACKAGE_C_INCLUDES += -I $(LOCAL_PATH)/../src/qualla/tokenizers
30
  PACKAGE_C_INCLUDES += -I $(LOCAL_PATH)/../src/qualla/engines/qnn-api
31
  PACKAGE_C_INCLUDES += -I $(LOCAL_PATH)/../src/qualla/engines/qnn-cpu
 
32
  PACKAGE_C_INCLUDES += -I $(LOCAL_PATH)/../src/qualla/engines/qnn-htp
33
 
34
  #========================== Define T2T Lib variables =============================================
@@ -45,6 +46,7 @@ MY_SRC_FILES += $(wildcard $(LOCAL_PATH)/../src/qualla/dialogs
45
  MY_SRC_FILES += $(wildcard $(LOCAL_PATH)/../src/qualla/engines/*.cpp)
46
  MY_SRC_FILES += $(wildcard $(LOCAL_PATH)/../src/qualla/engines/qnn-api/*.cpp)
47
  MY_SRC_FILES += $(wildcard $(LOCAL_PATH)/../src/qualla/engines/qnn-cpu/*.cpp)
 
48
  MY_SRC_FILES += $(wildcard $(LOCAL_PATH)/../src/qualla/engines/qnn-htp/*.cpp)
49
  MY_SRC_FILES += $(wildcard $(LOCAL_PATH)/../src/qualla/utils/*.cpp)
50
  MY_SRC_FILES += $(wildcard $(LOCAL_PATH)/../src/qualla/loggers/*.cpp)
 
29
  PACKAGE_C_INCLUDES += -I $(LOCAL_PATH)/../src/qualla/tokenizers
30
  PACKAGE_C_INCLUDES += -I $(LOCAL_PATH)/../src/qualla/engines/qnn-api
31
  PACKAGE_C_INCLUDES += -I $(LOCAL_PATH)/../src/qualla/engines/qnn-cpu
32
+ PACKAGE_C_INCLUDES += -I $(LOCAL_PATH)/../src/qualla/engines/qnn-gpu
33
  PACKAGE_C_INCLUDES += -I $(LOCAL_PATH)/../src/qualla/engines/qnn-htp
34
 
35
  #========================== Define T2T Lib variables =============================================
 
46
  MY_SRC_FILES += $(wildcard $(LOCAL_PATH)/../src/qualla/engines/*.cpp)
47
  MY_SRC_FILES += $(wildcard $(LOCAL_PATH)/../src/qualla/engines/qnn-api/*.cpp)
48
  MY_SRC_FILES += $(wildcard $(LOCAL_PATH)/../src/qualla/engines/qnn-cpu/*.cpp)
49
+ MY_SRC_FILES += $(wildcard $(LOCAL_PATH)/../src/qualla/engines/qnn-gpu/*.cpp)
50
  MY_SRC_FILES += $(wildcard $(LOCAL_PATH)/../src/qualla/engines/qnn-htp/*.cpp)
51
  MY_SRC_FILES += $(wildcard $(LOCAL_PATH)/../src/qualla/utils/*.cpp)
52
  MY_SRC_FILES += $(wildcard $(LOCAL_PATH)/../src/qualla/loggers/*.cpp)
Genie/Genie/make/Application.mk CHANGED
@@ -10,5 +10,5 @@ APP_ABI := arm64-v8a
10
  APP_STL := c++_shared
11
  APP_PLATFORM := android-21
12
  APP_MODULES := Genie
13
- APP_CPPFLAGS += -std=c++2a -O3 -Wall -frtti -fexceptions -fvisibility=hidden -DGENIE_API="__attribute__((visibility(\"default\")))" -DSPILLFILL -DQUALLA_ENGINE_QNN_HTP=TRUE -DQUALLA_ENGINE_QNN_CPU=TRUE -DQUALLA_APPS=OFF -DFMT_HEADER_ONLY -DGENIE_SAMPLE -DQUALLA_INTERNAL_QNN_SDK -DGENIE_SSD_FEATURE -DGENIE_SPD_FEATURE -DGENIE_LADE_FEATURE -DGENIE_MULTISTREAM_FEATURE -DGENIE_LORA_FEATURE -DGENIE_E2T_FEATURE
14
  APP_LDFLAGS += -lc -lm -ldl -Wl,--version-script=GenieSymbols.default -Wl,--strip-all
 
10
  APP_STL := c++_shared
11
  APP_PLATFORM := android-21
12
  APP_MODULES := Genie
13
+ APP_CPPFLAGS += -std=c++2a -O3 -Wall -frtti -fexceptions -fvisibility=hidden -DGENIE_API="__attribute__((visibility(\"default\")))" -DSPILLFILL -DQUALLA_ENGINE_QNN_HTP=TRUE -DQUALLA_ENGINE_QNN_CPU=TRUE -DQUALLA_ENGINE_QNN_GPU=TRUE -DQUALLA_APPS=OFF -DFMT_HEADER_ONLY -DGENIE_SAMPLE -DQUALLA_INTERNAL_QNN_SDK -DGENIE_SSD_FEATURE -DGENIE_SPD_FEATURE -DGENIE_LADE_FEATURE -DGENIE_MULTISTREAM_FEATURE -DGENIE_LORA_FEATURE -DGENIE_E2T_FEATURE
14
  APP_LDFLAGS += -lc -lm -ldl -Wl,--version-script=GenieSymbols.default -Wl,--strip-all
Genie/Genie/make/Makefile.linux-x86_64 CHANGED
@@ -17,6 +17,7 @@ SRC_DIR_SAMPLE_DIALOGS := src/qualla/dialogs
17
  SRC_DIR_GENIE_ENGINES := src/qualla/engines
18
  SRC_DIR_GENIE_QNN_API := src/qualla/engines/qnn-api
19
  SRC_DIR_GENIE_ENGINES_CPU := src/qualla/engines/qnn-cpu
 
20
  SRC_DIR_GENIE_UTILS := src/qualla/utils
21
  #
22
  SRC_DIR_GENIE_LOGGERS := src/qualla/loggers
@@ -29,6 +30,7 @@ SRC_DIR_GENIE := src
29
 
30
  # Includes
31
  GENIE_ENGINES_CPU_INCLUDE := src/qualla/engines/qnn-cpu
 
32
  GENIE_ENGINES_API_INCLUDE := src/qualla/engines/qnn-api
33
  GENIE_ENGINES_HTP_INCLUDE := src/qualla/engines/qnn-htp
34
  GENIE_TOKENIZER_INCLUDE := src/qualla/tokenizers
@@ -62,7 +64,7 @@ endif
62
  GENIE_all: $(libGenie)
63
 
64
  # Include paths
65
- INCLUDES += -I$(GENIE_INCLUDE) -I$(QUALLA_INCLUDE) -I$(SRC_DIR_GENIE_TOKENIZERS) -I$(QNN_API_INCLUDE) -I$(GENIE_ENGINES_CPU_INCLUDE) -I$(QNN_API_HTP_INCLUDE) -I$(GENIE_ENGINES_API_INCLUDE) -I$(GENIE_TOKENIZER_INCLUDE) -I$(GENIE_C_API_HEADERS_INCLUDE)
66
 
67
  # set compiler flags
68
  COMMON_CXXFLAGS = -std=c++2a -frtti -fPIC -Wall -pg -pthread -nostdinc++ -stdlib=libc++ -idirafter /usr/lib/llvm-14/include/c++/v1 -nostdinc -idirafter /usr/lib/llvm-14/lib/clang/14.0.0/include/ -idirafter /usr/include $(INCLUDES)
@@ -71,11 +73,11 @@ COMMON_LDFLAGS = -shared -s -fPIC -pthread -L/usr/lib/x86_64-linux-gnu -L./src/
71
  COMMON_CFLAGS = -nostdinc -idirafter /usr/lib/llvm-14/lib/clang/14.0.0/include/ -idirafter /usr/include
72
 
73
  ifdef QNN_DEBUG_ENABLE
74
- CXXFLAGS += $(COMMON_CXXFLAGS) -march=x86-64 -O0 -g -DQNN_API="" -DSPILLFILL -DQUALLA_ENGINE_QNN_CPU=TRUE -DQUALLA_APPS=OFF -DFMT_HEADER_ONLY -DGENIE_SAMPLE -DQUALLA_INTERNAL_QNN_SDK -DGENIE_SSD_FEATURE -DGENIE_SPD_FEATURE -DGENIE_LADE_FEATURE -DGENIE_MULTISTREAM_FEATURE -DGENIE_LORA_FEATURE -DGENIE_E2T_FEATURE
75
  CFLAGS += $(COMMON_CFLAGS)
76
  LDFLAGS += $(COMMON_LDFLAGS)
77
  else
78
- CXXFLAGS += $(COMMON_CXXFLAGS) -march=x86-64 -O3 -Wno-write-strings -fvisibility=hidden -DGENIE_API="__attribute__((visibility(\"default\")))" -DSPILLFILL -DQUALLA_ENGINE_QNN_CPU=TRUE -DQUALLA_APPS=OFF -DFMT_HEADER_ONLY -DGENIE_SAMPLE -DQUALLA_INTERNAL_QNN_SDK -DGENIE_SSD_FEATURE -DGENIE_SPD_FEATURE -DGENIE_LADE_FEATURE -DGENIE_MULTISTREAM_FEATURE -DGENIE_LORA_FEATURE -DGENIE_E2T_FEATURE
79
  CFLAGS += $(COMMON_CFLAGS)
80
  LDFLAGS += $(COMMON_LDFLAGS) -fvisibility=hidden -flto
81
  endif
@@ -89,6 +91,7 @@ SOURCES_GENIE_QNN_API_CPP := $(wildcard $(SRC_DIR_GENIE_QNN_API)/*.cpp)
89
  SOURCES_GENIE_ENGINES_CPP := $(filter-out $(SRC_DIR_GENIE_ENGINES)/qnn-htp.cpp, $(wildcard $(SRC_DIR_GENIE_ENGINES)/*.cpp))
90
  SOURCES_GENIE_DIALOGS_CPP := $(wildcard $(SRC_DIR_SAMPLE_DIALOGS)/*.cpp)
91
  SOURCES_GENIE_ENGINES_CPU_CPP := $(wildcard $(SRC_DIR_GENIE_ENGINES_CPU)/*.cpp)
 
92
  SOURCES_GENIE_UTILS_CPP := $(wildcard $(SRC_DIR_GENIE_UTILS)/*.cpp)
93
 
94
 
@@ -108,6 +111,8 @@ OBJ_DIR_GENIE_ENGINES := $(OBJ_DIR_QUALLA)/engines
108
  OBJ_DIR_GENIE_UTILS := $(OBJ_DIR_QUALLA)/utils
109
  OBJ_DIR_GENIE_ENGINES_CPU := $(OBJ_DIR_QUALLA)/engines/qnn-cpu
110
  $(shell mkdir -p $(OBJ_DIR_GENIE_ENGINES_CPU))
 
 
111
 
112
  OBJ_DIR_GENIE_LOGGERS := obj/$(QNN_TARGET)/qualla/loggers
113
  OBJ_DIR_GENIE_SAMPLERS := obj/$(QNN_TARGET)/qualla/samplers
@@ -125,6 +130,7 @@ OBJECTS_GENIE_ENGINES := $(patsubst %.cpp,$(OBJ_DIR_GENIE_ENGINES)/%.o,$(foreach
125
  OBJECTS_GENIE_DIALOGS := $(patsubst %.cpp,$(OBJ_DIR_GENIE_DIALOGS)/%.o,$(foreach x,$(SOURCES_GENIE_DIALOGS_CPP),$(notdir $(x))))
126
  OBJECTS_GENIE_UTILS := $(patsubst %.cpp,$(OBJ_DIR_GENIE_UTILS)/%.o,$(foreach x,$(SOURCES_GENIE_UTILS_CPP),$(notdir $(x))))
127
  OBJECTS_GENIE_ENGINES_CPU := $(patsubst %.cpp,$(OBJ_DIR_GENIE_ENGINES_CPU)/%.o,$(foreach x,$(SOURCES_GENIE_ENGINES_CPU_CPP),$(notdir $(x))))
 
128
 
129
  OBJECTS_GENIE_LOGGERS := $(patsubst %.cpp,$(OBJ_DIR_GENIE_LOGGERS)/%.o,$(foreach x,$(SOURCES_GENIE_LOGGERS_CPP),$(notdir $(x))))
130
  OBJECTS_GENIE_SAMPLERS := $(patsubst %.cpp,$(OBJ_DIR_GENIE_SAMPLERS)/%.o,$(foreach x,$(SOURCES_GENIE_SAMPLERS_CPP),$(notdir $(x))))
@@ -157,16 +163,18 @@ $(OBJ_DIR_GENIE_UTILS)/%.o: $(SRC_DIR_GENIE_UTILS)/%.cpp $(CXX) $(CXXFLAGS) -c $
157
 
158
  $(OBJ_DIR_GENIE_ENGINES_CPU)/%.o: $(SRC_DIR_GENIE_ENGINES_CPU)/%.cpp $(CXX) $(CXXFLAGS) -c $^ -o $@
159
 
 
 
160
  $(OBJ_DIR_GENIE_LOGGERS)/%.o: $(SRC_DIR_GENIE_LOGGERS)/%.cpp $(CXX) $(CXXFLAGS) -c $^ -o $@
161
 
162
  $(OBJ_DIR_GENIE_SAMPLERS)/%.o: $(SRC_DIR_GENIE_SAMPLERS)/%.cpp $(CXX) $(CXXFLAGS) -c $^ -o $@
163
 
164
 
165
  # set up resources
166
- directories := $(TARGET_DIR) $(OBJ_DIR_GENIE) $(OBJ_DIR_GENIE_QNN_API) $(OBJ_DIR_QUALLA) $(OBJ_DIR_GENIE_TOKENIZERS) $(OBJ_DIR_GENIE_ENGINES) $(OBJ_DIR_GENIE_DIALOGS) $(OBJ_DIR_GENIE_UTILS) $(OBJ_DIR_GENIE_ENGINES_CPU) $(OBJ_DIR_GENIE_LOGGERS) $(OBJ_DIR_GENIE_SAMPLERS)
167
 
168
  # Compile
169
- $(libGenie): $(OBJECTS_GENIE) $(OBJECTS_QUALLA) $(OBJECTS_GENIE_QNN_API) $(OBJECTS_GENIE_TOKENIZERS) $(OBJECTS_GENIE_ENGINES) $(OBJECTS_GENIE_DIALOGS) $(OBJECTS_GENIE_UTILS) $(OBJECTS_GENIE_ENGINES_CPU) $(OBJECTS_GENIE_LOGGERS) $(OBJECTS_GENIE_SAMPLERS) | $(directories)
170
  $(CXX) $(CXXFLAGS) -shared -o $@ $^ $(LIBS) $(libtokenizers)
171
 
172
 
@@ -179,6 +187,7 @@ $(OBJECTS_GENIE_ENGINES): | $(OBJ_DIR_GENIE_ENGINES)
179
  $(OBJECTS_GENIE_DIALOGS): | $(OBJ_DIR_GENIE_DIALOGS)
180
  $(OBJECTS_GENIE_UTILS): | $(OBJ_DIR_GENIE_UTILS)
181
  $(OBJECTS_GENIE_ENGINES_CPU): | $(OBJ_DIR_GENIE_ENGINES_CPU)
 
182
  $(OBJECTS_GENIE_LOGGERS): | $(OBJ_DIR_GENIE_LOGGERS)
183
  $(OBJECTS_GENIE_SAMPLERS): | $(OBJ_DIR_GENIE_SAMPLERS)
184
 
 
17
  SRC_DIR_GENIE_ENGINES := src/qualla/engines
18
  SRC_DIR_GENIE_QNN_API := src/qualla/engines/qnn-api
19
  SRC_DIR_GENIE_ENGINES_CPU := src/qualla/engines/qnn-cpu
20
+ SRC_DIR_GENIE_ENGINES_GPU := src/qualla/engines/qnn-gpu
21
  SRC_DIR_GENIE_UTILS := src/qualla/utils
22
  #
23
  SRC_DIR_GENIE_LOGGERS := src/qualla/loggers
 
30
 
31
  # Includes
32
  GENIE_ENGINES_CPU_INCLUDE := src/qualla/engines/qnn-cpu
33
+ GENIE_ENGINES_GPU_INCLUDE := src/qualla/engines/qnn-gpu
34
  GENIE_ENGINES_API_INCLUDE := src/qualla/engines/qnn-api
35
  GENIE_ENGINES_HTP_INCLUDE := src/qualla/engines/qnn-htp
36
  GENIE_TOKENIZER_INCLUDE := src/qualla/tokenizers
 
64
  GENIE_all: $(libGenie)
65
 
66
  # Include paths
67
+ INCLUDES += -I$(GENIE_INCLUDE) -I$(QUALLA_INCLUDE) -I$(SRC_DIR_GENIE_TOKENIZERS) -I$(QNN_API_INCLUDE) -I$(GENIE_ENGINES_CPU_INCLUDE) -I$(GENIE_ENGINES_GPU_INCLUDE) -I$(QNN_API_HTP_INCLUDE) -I$(GENIE_ENGINES_API_INCLUDE) -I$(GENIE_TOKENIZER_INCLUDE) -I$(GENIE_C_API_HEADERS_INCLUDE)
68
 
69
  # set compiler flags
70
  COMMON_CXXFLAGS = -std=c++2a -frtti -fPIC -Wall -pg -pthread -nostdinc++ -stdlib=libc++ -idirafter /usr/lib/llvm-14/include/c++/v1 -nostdinc -idirafter /usr/lib/llvm-14/lib/clang/14.0.0/include/ -idirafter /usr/include $(INCLUDES)
 
73
  COMMON_CFLAGS = -nostdinc -idirafter /usr/lib/llvm-14/lib/clang/14.0.0/include/ -idirafter /usr/include
74
 
75
  ifdef QNN_DEBUG_ENABLE
76
+ CXXFLAGS += $(COMMON_CXXFLAGS) -march=x86-64 -O0 -g -DQNN_API="" -DSPILLFILL -DQUALLA_ENGINE_QNN_CPU=TRUE -DQUALLA_ENGINE_QNN_GPU=TRUE -DQUALLA_APPS=OFF -DFMT_HEADER_ONLY -DGENIE_SAMPLE -DQUALLA_INTERNAL_QNN_SDK -DGENIE_SSD_FEATURE -DGENIE_SPD_FEATURE -DGENIE_LADE_FEATURE -DGENIE_MULTISTREAM_FEATURE -DGENIE_LORA_FEATURE -DGENIE_E2T_FEATURE
77
  CFLAGS += $(COMMON_CFLAGS)
78
  LDFLAGS += $(COMMON_LDFLAGS)
79
  else
80
+ CXXFLAGS += $(COMMON_CXXFLAGS) -march=x86-64 -O3 -Wno-write-strings -fvisibility=hidden -DGENIE_API="__attribute__((visibility(\"default\")))" -DSPILLFILL -DQUALLA_ENGINE_QNN_CPU=TRUE -DQUALLA_ENGINE_QNN_GPU=TRUE -DQUALLA_APPS=OFF -DFMT_HEADER_ONLY -DGENIE_SAMPLE -DQUALLA_INTERNAL_QNN_SDK -DGENIE_SSD_FEATURE -DGENIE_SPD_FEATURE -DGENIE_LADE_FEATURE -DGENIE_MULTISTREAM_FEATURE -DGENIE_LORA_FEATURE -DGENIE_E2T_FEATURE
81
  CFLAGS += $(COMMON_CFLAGS)
82
  LDFLAGS += $(COMMON_LDFLAGS) -fvisibility=hidden -flto
83
  endif
 
91
  SOURCES_GENIE_ENGINES_CPP := $(filter-out $(SRC_DIR_GENIE_ENGINES)/qnn-htp.cpp, $(wildcard $(SRC_DIR_GENIE_ENGINES)/*.cpp))
92
  SOURCES_GENIE_DIALOGS_CPP := $(wildcard $(SRC_DIR_SAMPLE_DIALOGS)/*.cpp)
93
  SOURCES_GENIE_ENGINES_CPU_CPP := $(wildcard $(SRC_DIR_GENIE_ENGINES_CPU)/*.cpp)
94
+ SOURCES_GENIE_ENGINES_GPU_CPP := $(wildcard $(SRC_DIR_GENIE_ENGINES_GPU)/*.cpp)
95
  SOURCES_GENIE_UTILS_CPP := $(wildcard $(SRC_DIR_GENIE_UTILS)/*.cpp)
96
 
97
 
 
111
  OBJ_DIR_GENIE_UTILS := $(OBJ_DIR_QUALLA)/utils
112
  OBJ_DIR_GENIE_ENGINES_CPU := $(OBJ_DIR_QUALLA)/engines/qnn-cpu
113
  $(shell mkdir -p $(OBJ_DIR_GENIE_ENGINES_CPU))
114
+ OBJ_DIR_GENIE_ENGINES_GPU := $(OBJ_DIR_QUALLA)/engines/qnn-gpu
115
+ $(shell mkdir -p $(OBJ_DIR_GENIE_ENGINES_GPU))
116
 
117
  OBJ_DIR_GENIE_LOGGERS := obj/$(QNN_TARGET)/qualla/loggers
118
  OBJ_DIR_GENIE_SAMPLERS := obj/$(QNN_TARGET)/qualla/samplers
 
130
  OBJECTS_GENIE_DIALOGS := $(patsubst %.cpp,$(OBJ_DIR_GENIE_DIALOGS)/%.o,$(foreach x,$(SOURCES_GENIE_DIALOGS_CPP),$(notdir $(x))))
131
  OBJECTS_GENIE_UTILS := $(patsubst %.cpp,$(OBJ_DIR_GENIE_UTILS)/%.o,$(foreach x,$(SOURCES_GENIE_UTILS_CPP),$(notdir $(x))))
132
  OBJECTS_GENIE_ENGINES_CPU := $(patsubst %.cpp,$(OBJ_DIR_GENIE_ENGINES_CPU)/%.o,$(foreach x,$(SOURCES_GENIE_ENGINES_CPU_CPP),$(notdir $(x))))
133
+ OBJECTS_GENIE_ENGINES_GPU := $(patsubst %.cpp,$(OBJ_DIR_GENIE_ENGINES_GPU)/%.o,$(foreach x,$(SOURCES_GENIE_ENGINES_GPU_CPP),$(notdir $(x))))
134
 
135
  OBJECTS_GENIE_LOGGERS := $(patsubst %.cpp,$(OBJ_DIR_GENIE_LOGGERS)/%.o,$(foreach x,$(SOURCES_GENIE_LOGGERS_CPP),$(notdir $(x))))
136
  OBJECTS_GENIE_SAMPLERS := $(patsubst %.cpp,$(OBJ_DIR_GENIE_SAMPLERS)/%.o,$(foreach x,$(SOURCES_GENIE_SAMPLERS_CPP),$(notdir $(x))))
 
163
 
164
  $(OBJ_DIR_GENIE_ENGINES_CPU)/%.o: $(SRC_DIR_GENIE_ENGINES_CPU)/%.cpp $(CXX) $(CXXFLAGS) -c $^ -o $@
165
 
166
+ $(OBJ_DIR_GENIE_ENGINES_GPU)/%.o: $(SRC_DIR_GENIE_ENGINES_GPU)/%.cpp $(CXX) $(CXXFLAGS) -c $^ -o $@
167
+
168
  $(OBJ_DIR_GENIE_LOGGERS)/%.o: $(SRC_DIR_GENIE_LOGGERS)/%.cpp $(CXX) $(CXXFLAGS) -c $^ -o $@
169
 
170
  $(OBJ_DIR_GENIE_SAMPLERS)/%.o: $(SRC_DIR_GENIE_SAMPLERS)/%.cpp $(CXX) $(CXXFLAGS) -c $^ -o $@
171
 
172
 
173
  # set up resources
174
+ directories := $(TARGET_DIR) $(OBJ_DIR_GENIE) $(OBJ_DIR_GENIE_QNN_API) $(OBJ_DIR_QUALLA) $(OBJ_DIR_GENIE_TOKENIZERS) $(OBJ_DIR_GENIE_ENGINES) $(OBJ_DIR_GENIE_DIALOGS) $(OBJ_DIR_GENIE_UTILS) $(OBJ_DIR_GENIE_ENGINES_CPU) $(OBJ_DIR_GENIE_ENGINES_GPU) $(OBJ_DIR_GENIE_LOGGERS) $(OBJ_DIR_GENIE_SAMPLERS)
175
 
176
  # Compile
177
+ $(libGenie): $(OBJECTS_GENIE) $(OBJECTS_QUALLA) $(OBJECTS_GENIE_QNN_API) $(OBJECTS_GENIE_TOKENIZERS) $(OBJECTS_GENIE_ENGINES) $(OBJECTS_GENIE_DIALOGS) $(OBJECTS_GENIE_UTILS) $(OBJECTS_GENIE_ENGINES_CPU) $(OBJECTS_GENIE_ENGINES_GPU) $(OBJECTS_GENIE_LOGGERS) $(OBJECTS_GENIE_SAMPLERS) | $(directories)
178
  $(CXX) $(CXXFLAGS) -shared -o $@ $^ $(LIBS) $(libtokenizers)
179
 
180
 
 
187
  $(OBJECTS_GENIE_DIALOGS): | $(OBJ_DIR_GENIE_DIALOGS)
188
  $(OBJECTS_GENIE_UTILS): | $(OBJ_DIR_GENIE_UTILS)
189
  $(OBJECTS_GENIE_ENGINES_CPU): | $(OBJ_DIR_GENIE_ENGINES_CPU)
190
+ $(OBJECTS_GENIE_ENGINES_GPU): | $(OBJ_DIR_GENIE_ENGINES_GPU)
191
  $(OBJECTS_GENIE_LOGGERS): | $(OBJ_DIR_GENIE_LOGGERS)
192
  $(OBJECTS_GENIE_SAMPLERS): | $(OBJ_DIR_GENIE_SAMPLERS)
193
 
Genie/Genie/src/Dialog.cpp CHANGED
@@ -95,81 +95,6 @@ static void translateContextConfig(const qualla::json& genieConfig, qualla::json
95
  }
96
  }
97
 
98
- //=============================================================================
99
- // Sampler::Config functions
100
- //=============================================================================
101
-
102
- static void validateSamplerConfig(const qualla::json& config) {
103
- if (!config.is_object()) {
104
- throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "sampler config is not an object");
105
- }
106
-
107
- std::set<std::string> mandatoryFields{"version"};
108
- for (const auto& field : mandatoryFields) {
109
- if (!config.contains(field)) {
110
- throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing sampler field: " + field);
111
- }
112
- }
113
-
114
- // component is used in the "ENFORCE" macros
115
- std::string component = "sampler";
116
-
117
- for (auto& item : config.items()) {
118
- if (item.key() == "version") {
119
- JSON_ENFORCE_NUMERIC();
120
- if (item.value().get<int>() != 1) {
121
- throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
122
- "Invalid sampler config: unsupported version: " + item.value().dump());
123
- }
124
- } else if (item.key() == "seed") {
125
- JSON_ENFORCE_NUMERIC();
126
- } else if (item.key() == "temp") {
127
- JSON_ENFORCE_NUMERIC();
128
- } else if (item.key() == "top-k") {
129
- JSON_ENFORCE_NUMERIC();
130
- } else if (item.key() == "top-p") {
131
- JSON_ENFORCE_NUMERIC();
132
- } else if (item.key() == "greedy") {
133
- JSON_ENFORCE_BOOLEAN();
134
- } else {
135
- throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown sampler config key: " + item.key());
136
- }
137
- }
138
- }
139
-
140
- static void translateSamplerConfig(const qualla::json& genieConfig, qualla::json& quallaConfig) {
141
- if (genieConfig["dialog"].contains("sampler")) {
142
- quallaConfig["sampler"]["type"] = "basic";
143
-
144
- if (genieConfig["dialog"]["sampler"].contains("seed")) {
145
- quallaConfig["sampler"]["seed"] = genieConfig["dialog"]["sampler"]["seed"];
146
- }
147
- if (genieConfig["dialog"]["sampler"].contains("temp")) {
148
- quallaConfig["sampler"]["temp"] = genieConfig["dialog"]["sampler"]["temp"];
149
- }
150
-
151
- quallaConfig["sampler"]["role"] = "primary";
152
- #if defined(GENIE_SPD_FEATURE)
153
- if (genieConfig["dialog"]["type"] == "spd") {
154
- quallaConfig["sampler"]["role"] = "target";
155
- }
156
- #endif
157
-
158
- if (genieConfig["dialog"]["sampler"].contains("top-k")) {
159
- quallaConfig["sampler"]["top-k"] = genieConfig["dialog"]["sampler"]["top-k"];
160
- }
161
- if (genieConfig["dialog"]["sampler"].contains("top-p")) {
162
- quallaConfig["sampler"]["top-p"] = genieConfig["dialog"]["sampler"]["top-p"];
163
- }
164
- if (genieConfig["dialog"]["sampler"].contains("greedy")) {
165
- quallaConfig["sampler"]["greedy"] = genieConfig["dialog"]["sampler"]["greedy"];
166
- }
167
- if (genieConfig["dialog"]["sampler"].contains("seed")) {
168
- quallaConfig["sampler"]["seed"] = genieConfig["dialog"]["sampler"]["seed"];
169
- }
170
- }
171
- }
172
-
173
  //=============================================================================
174
  // Tokenizer::Config functions
175
  //=============================================================================
@@ -322,6 +247,8 @@ static void validateBackendHtpConfig(const qualla::json& config) {
322
  } else if (item.key() == "rope-theta") {
323
  rope_theta_set = true;
324
  JSON_ENFORCE_NUMERIC();
 
 
325
  } else {
326
  throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown QnnHtp config key: " + item.key());
327
  }
@@ -410,7 +337,7 @@ static void validateBackendConfig(const qualla::json& config) {
410
  htp = true;
411
  } else if (type == "QnnGenAiTransformer") {
412
  genai = true;
413
- } else {
414
  throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
415
  "Invalid backend config: unsupported type: " + item.value().dump());
416
  }
@@ -629,6 +556,9 @@ static void validateModelLibraryConfig(const qualla::json& config) {
629
  }
630
  } else if (item.key() == "model-bin") {
631
  JSON_ENFORCE_STRING();
 
 
 
632
  } else {
633
  throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown library config key: " + item.key());
634
  }
@@ -956,6 +886,10 @@ static void translateEngineConfig(const qualla::json& genieEngineConfig,
956
  quallaEngineConfig["use-async-Init"] =
957
  genieEngineConfig["backend"]["QnnHtp"]["allow-async-init"];
958
  }
 
 
 
 
959
  } else if (genieEngineConfig["backend"]["type"] == "QnnGenAiTransformer") {
960
  quallaEngineConfig["type"] = "qnn-cpu";
961
  quallaEngineConfig["backend-lib"] = getLibName("QnnGenAiTransformer");
@@ -979,6 +913,8 @@ static void translateEngineConfig(const qualla::json& genieEngineConfig,
979
  quallaEngineConfig["n_heads"] =
980
  genieEngineConfig["backend"]["QnnGenAiTransformer"]["n-heads"];
981
  }
 
 
982
  }
983
 
984
  if (genieEngineConfig["backend"].contains("extensions")) {
@@ -1020,6 +956,21 @@ static void translateEngineConfig(const qualla::json& genieEngineConfig,
1020
  quallaEngineConfig["model-bin-path"] = genieEngineConfig["model"]["library"]["model-bin"];
1021
  quallaEngineConfig["op-package"] =
1022
  getLibName("QnnGenAiTransformerCpuOpPkg") + ":QnnOpPackage_interfaceProvider";
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1023
  }
1024
  if (genieEngineConfig["model"].contains("positional-encoding")) {
1025
  quallaEngineConfig["positional-encoding"]["type"] =
@@ -1424,7 +1375,7 @@ static void validateDialogConfig(const qualla::json& config) {
1424
  validateTokenizerConfig(item.value());
1425
  } else if (item.key() == "sampler") {
1426
  JSON_ENFORCE_OBJECT();
1427
- validateSamplerConfig(item.value());
1428
  } else if (item.key() == "engine") {
1429
  JSON_ENFORCE_ARRAY_OR_OBJECT();
1430
  } else if (item.key() == "embedding") {
@@ -1550,7 +1501,7 @@ static void translateDialogConfig(const qualla::json& genieConfig, qualla::json&
1550
 
1551
  translateContextConfig(genieConfig, quallaConfig);
1552
  translateTokenizerConfig(genieConfig, quallaConfig);
1553
- translateSamplerConfig(genieConfig, quallaConfig);
1554
  translateMultiEngineConfig(genieConfig, quallaConfig);
1555
  translateEmbeddingConfig(genieConfig, quallaConfig);
1556
  }
@@ -1611,7 +1562,7 @@ Dialog::Config::Config(const char* configStr) {
1611
  m_config = config;
1612
  }
1613
 
1614
- qualla::json Dialog::Config::getJson() const { return m_config; }
1615
 
1616
  //=============================================================================
1617
  // Dialog functions
@@ -1640,6 +1591,27 @@ Dialog::Dialog(std::shared_ptr<Config> config) {
1640
  if (!m_quallaDialog) {
1641
  throw Exception(GENIE_STATUS_ERROR_MEM_ALLOC, "Could not create a dialog object");
1642
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1643
  }
1644
 
1645
  static_assert(qualla::Sentence::Code::COMPLETE ==
@@ -1801,4 +1773,6 @@ int32_t Dialog::tokenQuery(const uint32_t* tokens,
1801
  kpis.generate.last_usec,
1802
  kpis.tps.generate);
1803
  return (status) ? (GENIE_STATUS_SUCCESS) : (GENIE_STATUS_ERROR_QUERY_FAILED);
1804
- }
 
 
 
95
  }
96
  }
97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  //=============================================================================
99
  // Tokenizer::Config functions
100
  //=============================================================================
 
247
  } else if (item.key() == "rope-theta") {
248
  rope_theta_set = true;
249
  JSON_ENFORCE_NUMERIC();
250
+ } else if (item.key() == "enable-graph-switching") {
251
+ JSON_ENFORCE_BOOLEAN();
252
  } else {
253
  throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown QnnHtp config key: " + item.key());
254
  }
 
337
  htp = true;
338
  } else if (type == "QnnGenAiTransformer") {
339
  genai = true;
340
+ } else if (type != "QnnGpu") {
341
  throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
342
  "Invalid backend config: unsupported type: " + item.value().dump());
343
  }
 
556
  }
557
  } else if (item.key() == "model-bin") {
558
  JSON_ENFORCE_STRING();
559
+ } else if (item.key() == "lora") {
560
+ JSON_ENFORCE_OBJECT();
561
+ validateLoraConfig(item.value());
562
  } else {
563
  throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown library config key: " + item.key());
564
  }
 
886
  quallaEngineConfig["use-async-Init"] =
887
  genieEngineConfig["backend"]["QnnHtp"]["allow-async-init"];
888
  }
889
+ if (genieEngineConfig["backend"]["QnnHtp"].contains("enable-graph-switching")) {
890
+ quallaEngineConfig["enable-graph-switching"] =
891
+ genieEngineConfig["backend"]["QnnHtp"]["enable-graph-switching"];
892
+ }
893
  } else if (genieEngineConfig["backend"]["type"] == "QnnGenAiTransformer") {
894
  quallaEngineConfig["type"] = "qnn-cpu";
895
  quallaEngineConfig["backend-lib"] = getLibName("QnnGenAiTransformer");
 
913
  quallaEngineConfig["n_heads"] =
914
  genieEngineConfig["backend"]["QnnGenAiTransformer"]["n-heads"];
915
  }
916
+ } else if (genieEngineConfig["backend"]["type"] == "QnnGpu") {
917
+ quallaEngineConfig["type"] = "qnn-gpu";
918
  }
919
 
920
  if (genieEngineConfig["backend"].contains("extensions")) {
 
956
  quallaEngineConfig["model-bin-path"] = genieEngineConfig["model"]["library"]["model-bin"];
957
  quallaEngineConfig["op-package"] =
958
  getLibName("QnnGenAiTransformerCpuOpPkg") + ":QnnOpPackage_interfaceProvider";
959
+ if (genieEngineConfig["model"]["library"].contains("lora")) {
960
+ for (int i = 0; i < genieEngineConfig["model"]["library"]["lora"]["adapters"].size(); i++) {
961
+ quallaEngineConfig["lora"][i]["adapter-name"] =
962
+ genieEngineConfig["model"]["library"]["lora"]["adapters"][i]["name"];
963
+ if (genieEngineConfig["model"]["library"]["lora"].contains("alpha-tensor-name")) {
964
+ quallaEngineConfig["lora"][i]["alpha-tensor-name"] =
965
+ genieEngineConfig["model"]["library"]["lora"]
966
+ ["alpha-tensor-name"];
967
+ }
968
+ quallaEngineConfig["lora"][i]["alpha-tensor-value"] = 1.0f;
969
+ quallaEngineConfig["lora"][i]["binsection-basedir"] = "";
970
+ quallaEngineConfig["lora"][i]["bin-sections"] =
971
+ genieEngineConfig["model"]["library"]["lora"]["adapters"][i]["bin-sections"];
972
+ }
973
+ }
974
  }
975
  if (genieEngineConfig["model"].contains("positional-encoding")) {
976
  quallaEngineConfig["positional-encoding"]["type"] =
 
1375
  validateTokenizerConfig(item.value());
1376
  } else if (item.key() == "sampler") {
1377
  JSON_ENFORCE_OBJECT();
1378
+ Sampler::SamplerConfig::validateSamplerConfig(item.value());
1379
  } else if (item.key() == "engine") {
1380
  JSON_ENFORCE_ARRAY_OR_OBJECT();
1381
  } else if (item.key() == "embedding") {
 
1501
 
1502
  translateContextConfig(genieConfig, quallaConfig);
1503
  translateTokenizerConfig(genieConfig, quallaConfig);
1504
+ Sampler::SamplerConfig::translateSamplerConfig(genieConfig, quallaConfig);
1505
  translateMultiEngineConfig(genieConfig, quallaConfig);
1506
  translateEmbeddingConfig(genieConfig, quallaConfig);
1507
  }
 
1562
  m_config = config;
1563
  }
1564
 
1565
+ qualla::json& Dialog::Config::getJson() { return m_config; }
1566
 
1567
  //=============================================================================
1568
  // Dialog functions
 
1591
  if (!m_quallaDialog) {
1592
  throw Exception(GENIE_STATUS_ERROR_MEM_ALLOC, "Could not create a dialog object");
1593
  }
1594
+ /*
1595
+ * spec-dec has a mandatory "target" sampler and an optional "draft" sampler
1596
+ * Check their availability and pass their references to Dialog Sampler to update with
1597
+ * applyConfig()
1598
+ */
1599
+ std::shared_ptr<Sampler> sampler;
1600
+ std::vector<std::reference_wrapper<qualla::Sampler>> quallaSamplers;
1601
+ if (quallaConfig["type"] == "spec-dec") {
1602
+ quallaSamplers.push_back(m_quallaDialog->sampler("target"));
1603
+ if (m_quallaDialog->isSamplerPresent("draft"))
1604
+ quallaSamplers.push_back(m_quallaDialog->sampler("draft"));
1605
+ sampler = std::make_shared<Sampler>(config->getJson()["dialog"], quallaSamplers);
1606
+ } else {
1607
+ quallaSamplers.push_back(m_quallaDialog->sampler()); // Default role is "primary"
1608
+ sampler = std::make_shared<Sampler>(config->getJson()["dialog"], quallaSamplers);
1609
+ }
1610
+ m_samplerHandle = Sampler::add(sampler);
1611
+ }
1612
+
1613
+ GenieSampler_Handle_t Dialog::getSamplerHandle(std::shared_ptr<Dialog> dialog) {
1614
+ return dialog->m_samplerHandle;
1615
  }
1616
 
1617
  static_assert(qualla::Sentence::Code::COMPLETE ==
 
1773
  kpis.generate.last_usec,
1774
  kpis.tps.generate);
1775
  return (status) ? (GENIE_STATUS_SUCCESS) : (GENIE_STATUS_ERROR_QUERY_FAILED);
1776
+ }
1777
+
1778
+ Dialog::~Dialog() { Sampler::remove(m_samplerHandle); }
Genie/Genie/src/Dialog.hpp CHANGED
@@ -10,11 +10,13 @@
10
 
11
  #include <atomic>
12
  #include <memory>
 
13
 
14
  #include "GenieDialog.h"
15
  #include "Util/HandleManager.hpp"
16
  #include "qualla/dialog.hpp"
17
  #include "qualla/DialogCallback.hpp"
 
18
 
19
  namespace genie {
20
 
@@ -33,7 +35,7 @@ class Dialog {
33
  static void remove(GenieDialogConfig_Handle_t handle);
34
 
35
  Config(const char* configStr);
36
- qualla::json getJson() const;
37
 
38
  private:
39
  static qnn::util::HandleManager<Config> s_manager;
@@ -43,10 +45,12 @@ class Dialog {
43
  static GenieDialog_Handle_t add(std::shared_ptr<Dialog> dialog);
44
  static std::shared_ptr<Dialog> get(GenieDialog_Handle_t handle);
45
  static void remove(GenieDialog_Handle_t handle);
 
46
 
47
  qualla::DialogCallback dialogCallback;
48
 
49
  Dialog(std::shared_ptr<Config> config);
 
50
 
51
  Dialog(const Dialog&) = delete;
52
  Dialog& operator=(const Dialog&) = delete;
@@ -91,5 +95,6 @@ class Dialog {
91
  uint32_t m_tokenLimit{UINT32_MAX};
92
  static qnn::util::HandleManager<Dialog> s_manager;
93
  static std::atomic<std::uint32_t> s_nameCounter;
 
94
  };
95
  } // namespace genie
 
10
 
11
  #include <atomic>
12
  #include <memory>
13
+ #include <functional>
14
 
15
  #include "GenieDialog.h"
16
  #include "Util/HandleManager.hpp"
17
  #include "qualla/dialog.hpp"
18
  #include "qualla/DialogCallback.hpp"
19
+ #include "Sampler.hpp"
20
 
21
  namespace genie {
22
 
 
35
  static void remove(GenieDialogConfig_Handle_t handle);
36
 
37
  Config(const char* configStr);
38
+ qualla::json& getJson();
39
 
40
  private:
41
  static qnn::util::HandleManager<Config> s_manager;
 
45
  static GenieDialog_Handle_t add(std::shared_ptr<Dialog> dialog);
46
  static std::shared_ptr<Dialog> get(GenieDialog_Handle_t handle);
47
  static void remove(GenieDialog_Handle_t handle);
48
+ static GenieSampler_Handle_t getSamplerHandle(std::shared_ptr<genie::Dialog> dialog);
49
 
50
  qualla::DialogCallback dialogCallback;
51
 
52
  Dialog(std::shared_ptr<Config> config);
53
+ ~Dialog();
54
 
55
  Dialog(const Dialog&) = delete;
56
  Dialog& operator=(const Dialog&) = delete;
 
95
  uint32_t m_tokenLimit{UINT32_MAX};
96
  static qnn::util::HandleManager<Dialog> s_manager;
97
  static std::atomic<std::uint32_t> s_nameCounter;
98
+ GenieSampler_Handle_t m_samplerHandle;
99
  };
100
  } // namespace genie
Genie/Genie/src/Embedding.cpp ADDED
@@ -0,0 +1,740 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ //==============================================================================
2
+ //
3
+ // Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4
+ // All Rights Reserved.
5
+ // Confidential and Proprietary - Qualcomm Technologies, Inc.
6
+ //
7
+ //==============================================================================
8
+
9
+ #include <exception>
10
+ #include <set>
11
+ #include <sstream>
12
+
13
+ #include "Embedding.hpp"
14
+ #include "Exception.hpp"
15
+ #include "Macro.hpp"
16
+ #include "qualla/detail/json.hpp"
17
+ #include "qualla/env.hpp"
18
+
19
+ using namespace genie;
20
+
21
+ #ifdef _WIN32
22
+ inline std::string libPrefix = "";
23
+ inline std::string libSuffix = ".dll";
24
+ #else
25
+ inline std::string libPrefix = "lib";
26
+ inline std::string libSuffix = ".so";
27
+ #endif
28
+
29
+ inline std::string getLibName(std::string baseName) { return libPrefix + baseName + libSuffix; }
30
+
31
+ //=============================================================================
32
+ // Context::Config functions
33
+ //=============================================================================
34
+
35
+ static void validateContextConfig(const qualla::json& config) {
36
+ if (!config.is_object()) {
37
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "context config is not an object");
38
+ }
39
+
40
+ std::set<std::string> mandatoryFields{
41
+ "version", "n-vocab", "ctx-size", "embed-size", "pad-token"};
42
+ for (const auto& field : mandatoryFields) {
43
+ if (!config.contains(field)) {
44
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing context field: " + field);
45
+ }
46
+ }
47
+
48
+ // component is used in the "ENFORCE" macros
49
+ std::string component = "context";
50
+
51
+ for (auto& item : config.items()) {
52
+ if (item.key() == "version") {
53
+ JSON_ENFORCE_NUMERIC();
54
+ if (item.value().get<int>() != 1) {
55
+ throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
56
+ "Invalid context config: unsupported version: " + item.value().dump());
57
+ }
58
+ } else if (item.key() == "n-vocab") {
59
+ JSON_ENFORCE_NUMERIC();
60
+ } else if (item.key() == "ctx-size") {
61
+ JSON_ENFORCE_NUMERIC();
62
+ } else if (item.key() == "embed-size") {
63
+ JSON_ENFORCE_NUMERIC();
64
+ } else if (item.key() == "pad-token") {
65
+ JSON_ENFORCE_NUMERIC();
66
+ } else {
67
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown context config key: " + item.key());
68
+ }
69
+ }
70
+ }
71
+
72
+ static void translateContextConfig(const qualla::json& genieConfig, qualla::json& quallaConfig) {
73
+ quallaConfig["n-vocab"] = genieConfig["n-vocab"];
74
+ quallaConfig["size"] = genieConfig["ctx-size"];
75
+ quallaConfig["n-embd"] = genieConfig["embed-size"];
76
+ quallaConfig["pad-token"] = genieConfig["pad-token"];
77
+ }
78
+
79
+ //=============================================================================
80
+ // Tokenizer::Config functions
81
+ //=============================================================================
82
+
83
+ static void validateTokenizerConfig(const qualla::json& config) {
84
+ if (!config.is_object()) {
85
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "tokenizer config is not an object");
86
+ }
87
+
88
+ std::set<std::string> mandatoryFields{"version", "path"};
89
+ for (const auto& field : mandatoryFields) {
90
+ if (!config.contains(field)) {
91
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing tokenizer field: " + field);
92
+ }
93
+ }
94
+
95
+ // component is used in the "ENFORCE" macros
96
+ std::string component = "tokenizer";
97
+
98
+ for (auto& item : config.items()) {
99
+ if (item.key() == "version") {
100
+ JSON_ENFORCE_NUMERIC();
101
+ if (item.value().get<int>() != 1) {
102
+ throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
103
+ "Invalid tokenizer config: unsupported version: " + item.value().dump());
104
+ }
105
+ } else if (item.key() == "path") {
106
+ JSON_ENFORCE_STRING();
107
+ // Note: the existence of this file is checked by qualla
108
+ } else {
109
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
110
+ "Unknown tokenizer config key: " + item.key());
111
+ }
112
+ }
113
+ }
114
+
115
+ static void translateTokenizerConfig(const qualla::json& genieConfig, qualla::json& quallaConfig) {
116
+ quallaConfig["tokenizer"] = genieConfig["path"];
117
+ }
118
+
119
+ //=============================================================================
120
+ // Backend::Config functions
121
+ //=============================================================================
122
+
123
+ static void validateBackendHtpConfig(const qualla::json& config) {
124
+ if (!config.is_object()) {
125
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "QnnHtp config is not an object");
126
+ }
127
+
128
+ std::set<std::string> mandatoryFields{
129
+ "version", "spill-fill-bufsize", "use-mmap", "pooled-output", "allow-async-init"};
130
+ for (const auto& field : mandatoryFields) {
131
+ if (!config.contains(field)) {
132
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing QnnHtp field: " + field);
133
+ }
134
+ }
135
+
136
+ // component is used in the "ENFORCE" macros
137
+ std::string component = "QnnHtp";
138
+
139
+ for (auto& item : config.items()) {
140
+ if (item.key() == "version") {
141
+ JSON_ENFORCE_NUMERIC();
142
+ if (item.value().get<int>() != 1) {
143
+ throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
144
+ "Invalid QnnHtp config: unsupported version: " + item.value().dump());
145
+ }
146
+ } else if (item.key() == "spill-fill-bufsize") {
147
+ JSON_ENFORCE_NUMERIC();
148
+ } else if (item.key() == "use-mmap") {
149
+ JSON_ENFORCE_BOOLEAN();
150
+ } else if (item.key() == "pooled-output") {
151
+ JSON_ENFORCE_BOOLEAN();
152
+ } else if (item.key() == "allow-async-init") {
153
+ JSON_ENFORCE_BOOLEAN();
154
+ } else if (item.key() == "disable-kv-cache") {
155
+ JSON_ENFORCE_BOOLEAN();
156
+ } else {
157
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown QnnHtp config key: " + item.key());
158
+ }
159
+ }
160
+ }
161
+
162
+ static void validateBackendGenaiConfig(const qualla::json& config) {
163
+ if (!config.is_object()) {
164
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "QnnGenAiTransformer config is not an object");
165
+ }
166
+
167
+ std::set<std::string> mandatoryFields{"version"};
168
+ for (const auto& field : mandatoryFields) {
169
+ if (!config.contains(field)) {
170
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
171
+ "Missing QnnGenAiTransformer field: " + field);
172
+ }
173
+ }
174
+
175
+ // component is used in the "ENFORCE" macros
176
+ std::string component = "QnnGenAiTransformer";
177
+
178
+ for (auto& item : config.items()) {
179
+ if (item.key() == "version") {
180
+ JSON_ENFORCE_NUMERIC();
181
+ if (item.value().get<int>() != 1) {
182
+ throw Exception(
183
+ GENIE_STATUS_ERROR_JSON_VALUE,
184
+ "Invalid QnnGenAiTransformer config: unsupported version: " + item.value().dump());
185
+ }
186
+ } else if (item.key() == "n-logits") {
187
+ JSON_ENFORCE_NUMERIC();
188
+ } else if (item.key() == "n-layer") {
189
+ JSON_ENFORCE_NUMERIC();
190
+ } else if (item.key() == "n-embd") {
191
+ JSON_ENFORCE_NUMERIC();
192
+ } else if (item.key() == "n-heads") {
193
+ JSON_ENFORCE_NUMERIC();
194
+ } else {
195
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
196
+ "Unknown QnnGenAiTransformer config key: " + item.key());
197
+ }
198
+ }
199
+ }
200
+
201
+ static void validateBackendConfig(const qualla::json& config) {
202
+ if (!config.is_object()) {
203
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "backend config is not an object");
204
+ }
205
+
206
+ std::set<std::string> mandatoryFields{"version", "type"};
207
+ for (const auto& field : mandatoryFields) {
208
+ if (!config.contains(field)) {
209
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing backend field: " + field);
210
+ }
211
+ }
212
+
213
+ // component is used in the "ENFORCE" macros
214
+ std::string component = "backend";
215
+
216
+ std::string type;
217
+ bool htp = false;
218
+ qualla::json htpConfig;
219
+ bool genai = false;
220
+ qualla::json genaiConfig;
221
+
222
+ for (auto& item : config.items()) {
223
+ if (item.key() == "version") {
224
+ JSON_ENFORCE_NUMERIC();
225
+ if (item.value().get<int>() != 1) {
226
+ throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
227
+ "Invalid backend config: unsupported version: " + item.value().dump());
228
+ }
229
+ } else if (item.key() == "type") {
230
+ JSON_ENFORCE_STRING();
231
+ type = item.value().get<std::string>();
232
+ if (type == "QnnHtp") {
233
+ htp = true;
234
+ } else if (type == "QnnGenAiTransformer") {
235
+ genai = true;
236
+ } else {
237
+ throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
238
+ "Invalid backend config: unsupported type: " + item.value().dump());
239
+ }
240
+ } else if (item.key() == "extensions") {
241
+ JSON_ENFORCE_STRING();
242
+ } else if (item.key() == "QnnHtp") {
243
+ JSON_ENFORCE_OBJECT();
244
+ htpConfig = item.value();
245
+ } else if (item.key() == "QnnGenAiTransformer") {
246
+ JSON_ENFORCE_OBJECT();
247
+ genaiConfig = item.value();
248
+ } else {
249
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown backend config key: " + item.key());
250
+ }
251
+ }
252
+
253
+ if (htp) {
254
+ if (!htpConfig.is_object()) {
255
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing QnnHtp embedding config");
256
+ }
257
+ validateBackendHtpConfig(htpConfig);
258
+ } else {
259
+ if (htpConfig.is_object()) {
260
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
261
+ "QnnHtp backend config for incorrect backend type: " + type);
262
+ }
263
+ }
264
+
265
+ if (genai) {
266
+ if (!genaiConfig.is_object()) {
267
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
268
+ "Missing QnnGenAiTransformer embedding config");
269
+ }
270
+ validateBackendGenaiConfig(genaiConfig);
271
+ } else {
272
+ if (genaiConfig.is_object()) {
273
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
274
+ "QnnGenAiTransformer backend config for incorrect backend type: " + type);
275
+ }
276
+ }
277
+ }
278
+
279
+ //=============================================================================
280
+ // Model::Config functions
281
+ //=============================================================================
282
+
283
+ static void validateModelBinaryConfig(const qualla::json& config) {
284
+ if (!config.is_object()) {
285
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "binary config is not an object");
286
+ }
287
+
288
+ std::set<std::string> mandatoryFields{"version", "ctx-bins"};
289
+ for (const auto& field : mandatoryFields) {
290
+ if (!config.contains(field)) {
291
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing binary field: " + field);
292
+ }
293
+ }
294
+
295
+ // component is used in the "ENFORCE" macros
296
+ std::string component = "binary";
297
+
298
+ for (auto& item : config.items()) {
299
+ if (item.key() == "version") {
300
+ JSON_ENFORCE_NUMERIC();
301
+ if (item.value().get<int>() != 1) {
302
+ throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
303
+ "Invalid binary config: unsupported version: " + item.value().dump());
304
+ }
305
+ } else if (item.key() == "ctx-bins") {
306
+ JSON_ENFORCE_ARRAY();
307
+ for (auto& elem : item.value()) {
308
+ if (!elem.is_string()) {
309
+ throw Exception(GENIE_STATUS_ERROR_JSON_VALUE, "ctx-bins must be an array of strings");
310
+ }
311
+ }
312
+ } else {
313
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown binary config key: " + item.key());
314
+ }
315
+ }
316
+ }
317
+
318
+ static void validateModelLibraryConfig(const qualla::json& config) {
319
+ if (!config.is_object()) {
320
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "library config is not an object");
321
+ }
322
+
323
+ std::set<std::string> mandatoryFields{"version", "model-bin"};
324
+ for (const auto& field : mandatoryFields) {
325
+ if (!config.contains(field)) {
326
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing library field: " + field);
327
+ }
328
+ }
329
+
330
+ // component is used in the "ENFORCE" macros
331
+ std::string component = "library";
332
+
333
+ for (auto& item : config.items()) {
334
+ if (item.key() == "version") {
335
+ JSON_ENFORCE_NUMERIC();
336
+ if (item.value().get<int>() != 1) {
337
+ throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
338
+ "Invalid library config: unsupported version: " + item.value().dump());
339
+ }
340
+ } else if (item.key() == "model-bin") {
341
+ JSON_ENFORCE_STRING();
342
+ } else {
343
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown library config key: " + item.key());
344
+ }
345
+ }
346
+ }
347
+
348
+ static void validateModelConfig(const qualla::json& config) {
349
+ if (!config.is_object()) {
350
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "model config is not an object");
351
+ }
352
+
353
+ std::set<std::string> mandatoryFields{"version", "type"};
354
+ for (const auto& field : mandatoryFields) {
355
+ if (!config.contains(field)) {
356
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing model field: " + field);
357
+ }
358
+ }
359
+
360
+ // component is used in the "ENFORCE" macros
361
+ std::string component = "model";
362
+
363
+ std::string type;
364
+ bool binary = false;
365
+ qualla::json binaryConfig;
366
+ bool library = false;
367
+ qualla::json libraryConfig;
368
+
369
+ for (auto& item : config.items()) {
370
+ if (item.key() == "version") {
371
+ JSON_ENFORCE_NUMERIC();
372
+ if (item.value().get<int>() != 1) {
373
+ throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
374
+ "Invalid model config: unsupported version: " + item.value().dump());
375
+ }
376
+ } else if (item.key() == "type") {
377
+ JSON_ENFORCE_STRING();
378
+ type = item.value().get<std::string>();
379
+ if (type == "binary") {
380
+ binary = true;
381
+ } else if (type == "library") {
382
+ library = true;
383
+ } else {
384
+ throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
385
+ "Invalid model config: unsupported type: " + item.value().dump());
386
+ }
387
+ } else if (item.key() == "binary") {
388
+ JSON_ENFORCE_OBJECT();
389
+ binaryConfig = item.value();
390
+ } else if (item.key() == "library") {
391
+ JSON_ENFORCE_OBJECT();
392
+ libraryConfig = item.value();
393
+ } else {
394
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown model config key: " + item.key());
395
+ }
396
+ }
397
+
398
+ if (binary) {
399
+ if (!binaryConfig.is_object()) {
400
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing binary model config");
401
+ }
402
+ validateModelBinaryConfig(binaryConfig);
403
+ } else {
404
+ if (binaryConfig.is_object()) {
405
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
406
+ "binary model config for incorrect model type: " + type);
407
+ }
408
+ }
409
+
410
+ if (library) {
411
+ if (!libraryConfig.is_object()) {
412
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing library model config");
413
+ }
414
+ validateModelLibraryConfig(libraryConfig);
415
+ } else {
416
+ if (libraryConfig.is_object()) {
417
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
418
+ "library model config for incorrect model type: " + type);
419
+ }
420
+ }
421
+ }
422
+
423
+ //=============================================================================
424
+ // Engine::Config functions
425
+ //=============================================================================
426
+
427
+ static void validateEngineConfig(const qualla::json& config) {
428
+ if (!config.is_object()) {
429
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "engine config is not an object");
430
+ }
431
+
432
+ std::set<std::string> mandatoryFields{"version", "backend", "model"};
433
+ for (const auto& field : mandatoryFields) {
434
+ if (!config.contains(field)) {
435
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing engine field: " + field);
436
+ }
437
+ }
438
+
439
+ // component is used in the "ENFORCE" macros
440
+ std::string component = "engine";
441
+
442
+ for (auto& item : config.items()) {
443
+ if (item.key() == "version") {
444
+ JSON_ENFORCE_NUMERIC();
445
+ if (item.value().get<int>() != 1) {
446
+ throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
447
+ "Invalid engine config: unsupported version: " + item.value().dump());
448
+ }
449
+ } else if (item.key() == "backend") {
450
+ JSON_ENFORCE_OBJECT();
451
+ validateBackendConfig(item.value());
452
+ } else if (item.key() == "model") {
453
+ JSON_ENFORCE_OBJECT();
454
+ validateModelConfig(item.value());
455
+ } else if (item.key() == "n-threads") {
456
+ JSON_ENFORCE_NUMERIC();
457
+ } else {
458
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown engine config key: " + item.key());
459
+ }
460
+ }
461
+ }
462
+
463
+ static void translateEngineConfig(const qualla::json& genieEngineConfig,
464
+ qualla::json& quallaEngineConfig) {
465
+ if (genieEngineConfig["version"] == 1) {
466
+ if (genieEngineConfig.contains("n-threads"))
467
+ quallaEngineConfig["n-threads"] = genieEngineConfig["n-threads"];
468
+
469
+ if (genieEngineConfig["backend"]["type"] == "QnnHtp") {
470
+ quallaEngineConfig["type"] = "qnn-htp";
471
+ quallaEngineConfig["model-architecture-type"] = "encoder",
472
+ quallaEngineConfig["backend-lib"] = getLibName("QnnHtp");
473
+ quallaEngineConfig["use-mmap"] = genieEngineConfig["backend"]["QnnHtp"]["use-mmap"];
474
+ quallaEngineConfig["spill-fill-bufsize"] =
475
+ genieEngineConfig["backend"]["QnnHtp"]["spill-fill-bufsize"];
476
+ quallaEngineConfig["pooled-output"] = genieEngineConfig["backend"]["QnnHtp"]["pooled-output"];
477
+ if (genieEngineConfig["backend"]["QnnHtp"].contains("disable-kv-cache")) {
478
+ quallaEngineConfig["disable-kv-cache"] =
479
+ genieEngineConfig["backend"]["QnnHtp"]["disable-kv-cache"];
480
+ }
481
+ // By default, Qualla will default to the async init path.
482
+ // For now, we are forcing async init off unless explicitly
483
+ // specified in the Genie config. It is HTP specific feature only.
484
+ quallaEngineConfig["use-async-Init"] = false;
485
+ if (genieEngineConfig["backend"]["QnnHtp"].contains("allow-async-init")) {
486
+ quallaEngineConfig["use-async-Init"] =
487
+ genieEngineConfig["backend"]["QnnHtp"]["allow-async-init"];
488
+ }
489
+ } else if (genieEngineConfig["backend"]["type"] == "QnnGenAiTransformer") {
490
+ quallaEngineConfig["type"] = "qnn-cpu";
491
+ quallaEngineConfig["model-output"] = "embeddings";
492
+ quallaEngineConfig["backend-lib"] = getLibName("QnnGenAiTransformer");
493
+ if (genieEngineConfig["backend"]["QnnGenAiTransformer"].contains("n-logits")) {
494
+ quallaEngineConfig["n_logits"] =
495
+ genieEngineConfig["backend"]["QnnGenAiTransformer"]["n-logits"];
496
+ }
497
+ if (genieEngineConfig["backend"]["QnnGenAiTransformer"].contains("n-layer")) {
498
+ quallaEngineConfig["n_layer"] =
499
+ genieEngineConfig["backend"]["QnnGenAiTransformer"]["n-layer"];
500
+ }
501
+ if (genieEngineConfig["backend"]["QnnGenAiTransformer"].contains("n-embd")) {
502
+ quallaEngineConfig["n_embd"] =
503
+ genieEngineConfig["backend"]["QnnGenAiTransformer"]["n-embd"];
504
+ }
505
+ if (genieEngineConfig["backend"]["QnnGenAiTransformer"].contains("n-heads")) {
506
+ quallaEngineConfig["n_heads"] =
507
+ genieEngineConfig["backend"]["QnnGenAiTransformer"]["n-heads"];
508
+ }
509
+ }
510
+
511
+ if (genieEngineConfig["backend"].contains("extensions")) {
512
+ quallaEngineConfig["backend-ext-conf"] = genieEngineConfig["backend"]["extensions"];
513
+ }
514
+
515
+ if (genieEngineConfig["model"]["type"] == "binary") {
516
+ quallaEngineConfig["model-list"] = genieEngineConfig["model"]["binary"]["ctx-bins"];
517
+ } else if (genieEngineConfig["model"]["type"] == "library") {
518
+ quallaEngineConfig["model"] = getLibName("QnnGenAiTransformerModel");
519
+ quallaEngineConfig["model-bin-path"] = genieEngineConfig["model"]["library"]["model-bin"];
520
+ quallaEngineConfig["op-package"] =
521
+ getLibName("QnnGenAiTransformerCpuOpPkg") + ":QnnOpPackage_interfaceProvider";
522
+ }
523
+ }
524
+ }
525
+
526
+ //=============================================================================
527
+ // Prompt::Config functions
528
+ //=============================================================================
529
+
530
+ static void validatePromptConfig(const qualla::json& config) {
531
+ if (!config.is_object()) {
532
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "prompt config is not an object");
533
+ }
534
+
535
+ std::set<std::string> mandatoryFields{"version", "prompt-template"};
536
+ for (const auto& field : mandatoryFields) {
537
+ if (!config.contains(field)) {
538
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing prompt field: " + field);
539
+ }
540
+ }
541
+
542
+ // component is used in the "ENFORCE" macros
543
+ std::string component = "prompt";
544
+
545
+ for (auto& item : config.items()) {
546
+ if (item.key() == "version") {
547
+ JSON_ENFORCE_NUMERIC();
548
+ if (item.value().get<int>() != 1) {
549
+ throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
550
+ "Invalid context config: unsupported version: " + item.value().dump());
551
+ }
552
+ } else if (item.key() == "prompt-template") {
553
+ JSON_ENFORCE_ARRAY();
554
+ for (auto& elem : item.value()) {
555
+ if (!elem.is_string()) {
556
+ throw Exception(GENIE_STATUS_ERROR_JSON_VALUE, "prompt tags must be an array of strings");
557
+ }
558
+ }
559
+ } else {
560
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown context config key: " + item.key());
561
+ }
562
+ }
563
+ }
564
+
565
+ static void translatePromptConfig(const qualla::json& genieConfig, qualla::json& quallaConfig) {
566
+ quallaConfig["tags"] = genieConfig["prompt-template"];
567
+ }
568
+
569
+ //=============================================================================
570
+ // Embedding::Config functions
571
+ //=============================================================================
572
+
573
+ qnn::util::HandleManager<Embedding::Config> Embedding::Config::s_manager;
574
+
575
+ GenieEmbeddingConfig_Handle_t Embedding::Config::add(std::shared_ptr<Embedding::Config> config) {
576
+ return (GenieEmbeddingConfig_Handle_t)s_manager.add(config);
577
+ }
578
+
579
+ std::shared_ptr<Embedding::Config> Embedding::Config::get(GenieEmbeddingConfig_Handle_t handle) {
580
+ return s_manager.get((qnn::util::Handle_t)handle);
581
+ }
582
+
583
+ void Embedding::Config::remove(GenieEmbeddingConfig_Handle_t handle) {
584
+ s_manager.remove((qnn::util::Handle_t)handle);
585
+ }
586
+
587
+ static void validateEmbeddingConfig(const qualla::json& config) {
588
+ if (!config.is_object()) {
589
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Embedding config is not an object");
590
+ }
591
+
592
+ std::set<std::string> mandatoryFields{"version", "context", "tokenizer", "engine"};
593
+ for (const auto& field : mandatoryFields) {
594
+ if (!config.contains(field)) {
595
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing embedding field: " + field);
596
+ }
597
+ }
598
+
599
+ // component is used in the "ENFORCE" macros
600
+ std::string component = "embedding";
601
+
602
+ for (auto& item : config.items()) {
603
+ if (item.key() == "version") {
604
+ JSON_ENFORCE_NUMERIC();
605
+ if (item.value().get<int>() != 1) {
606
+ throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
607
+ "Invalid embedding config: unsupported version: " + item.value().dump());
608
+ }
609
+ } else if (item.key() == "context") {
610
+ JSON_ENFORCE_OBJECT();
611
+ validateContextConfig(item.value());
612
+ } else if (item.key() == "tokenizer") {
613
+ JSON_ENFORCE_OBJECT();
614
+ validateTokenizerConfig(item.value());
615
+ } else if (item.key() == "prompt") { // optional parameter
616
+ JSON_ENFORCE_OBJECT();
617
+ validatePromptConfig(item.value());
618
+ } else if (item.key() == "truncate-input") { // optional parameter
619
+ JSON_ENFORCE_BOOLEAN();
620
+ } else if (item.key() == "engine") {
621
+ JSON_ENFORCE_OBJECT();
622
+ validateEngineConfig(config["engine"]);
623
+ } else {
624
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
625
+ "Unknown embedding config key: " + item.key());
626
+ }
627
+ }
628
+ }
629
+
630
+ static void translateEmbeddingConfig(const qualla::json& genieConfig, qualla::json& quallaConfig) {
631
+ translateContextConfig(genieConfig["context"], quallaConfig["context"]);
632
+ translatePromptConfig(genieConfig["prompt"], quallaConfig["prompt"]);
633
+ translateTokenizerConfig(genieConfig["tokenizer"], quallaConfig);
634
+ translateEngineConfig(genieConfig["engine"], quallaConfig["engine"]);
635
+
636
+ if (genieConfig.contains(
637
+ "truncate-input")) { // to allow truncation of input incase it exceeds the context.
638
+ quallaConfig["truncate-input"] = genieConfig["truncate-input"];
639
+ }
640
+ }
641
+
642
+ Embedding::Config::Config(const char* configStr) {
643
+ qualla::json config;
644
+
645
+ {
646
+ std::set<qualla::json> keys;
647
+
648
+ auto callback = [&keys](int depth, qualla::json::parse_event_t event, qualla::json& parsed) {
649
+ if ((depth == 1) && (event == qualla::json::parse_event_t::key)) {
650
+ if (keys.count(parsed) > 0) {
651
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
652
+ "Multiple embedding config key: " + parsed.dump());
653
+ }
654
+ keys.insert(parsed);
655
+ }
656
+ return true;
657
+ };
658
+
659
+ config = qualla::json::parse(configStr, callback);
660
+ }
661
+
662
+ if (!config.is_object()) {
663
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Embedding config is not an object");
664
+ }
665
+
666
+ std::set<std::string> mandatoryFields{"embedding"};
667
+ for (const auto& field : mandatoryFields) {
668
+ if (!config.contains(field)) {
669
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing embedding field: " + field);
670
+ }
671
+ }
672
+
673
+ // component is used in the "ENFORCE" macros
674
+ std::string component = "embedding";
675
+
676
+ for (auto& item : config.items()) {
677
+ if (item.key() == "embedding") {
678
+ JSON_ENFORCE_OBJECT();
679
+ validateEmbeddingConfig(item.value());
680
+ } else {
681
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
682
+ "Unknown embedding config key: " + item.key());
683
+ }
684
+ }
685
+ m_config = config;
686
+ }
687
+
688
+ qualla::json Embedding::Config::getJson() const { return m_config; }
689
+
690
+ //=============================================================================
691
+ // Embedding functions
692
+ //=============================================================================
693
+
694
+ qnn::util::HandleManager<Embedding> Embedding::s_manager;
695
+ std::atomic<std::uint32_t> Embedding::s_nameCounter{0u};
696
+
697
+ GenieEmbedding_Handle_t Embedding::add(std::shared_ptr<Embedding> embedding) {
698
+ return (GenieEmbedding_Handle_t)s_manager.add(embedding);
699
+ }
700
+
701
+ std::shared_ptr<Embedding> Embedding::get(GenieEmbedding_Handle_t handle) {
702
+ return s_manager.get((qnn::util::Handle_t)handle);
703
+ }
704
+
705
+ void Embedding::remove(GenieEmbedding_Handle_t handle) {
706
+ s_manager.remove((qnn::util::Handle_t)handle);
707
+ }
708
+
709
+ Embedding::Embedding(std::shared_ptr<Config> config) {
710
+ auto env = qualla::Env::create(qualla::json{});
711
+ qualla::json quallaConfig;
712
+ translateEmbeddingConfig(config->getJson()["embedding"], quallaConfig);
713
+ m_quallaEmbedding = qualla::Embedding::create(
714
+ env, "embedding" + std::to_string(s_nameCounter.fetch_add(1u)), quallaConfig);
715
+ if (!m_quallaEmbedding) {
716
+ throw Exception(GENIE_STATUS_ERROR_MEM_ALLOC, "Could not create a embedding object");
717
+ }
718
+ }
719
+
720
+ int32_t Embedding::generate(const char* queryStr,
721
+ GenieEmbedding_GenerateCallback_t callback,
722
+ const void* userData) {
723
+ std::string query(queryStr);
724
+ std::vector<float> outputEmbedding;
725
+ bool status = false;
726
+ status = m_quallaEmbedding->query(query, outputEmbedding);
727
+ if (status) {
728
+ std::vector<uint32_t> dimensions;
729
+ m_quallaEmbedding->output_dimensions(dimensions);
730
+ callback(dimensions.data(), dimensions.size(), outputEmbedding.data(), userData);
731
+ qualla::Embedding::KPIs kpis = m_quallaEmbedding->kpis();
732
+ printf(
733
+ "\n\n[KPIS]:\nInit Time: %zu us\nPrompt Processing Time: %zu us, Prompt Processing Rate : "
734
+ "%f toks/sec\n",
735
+ kpis.init.total_usec,
736
+ kpis.prompt.last_usec,
737
+ kpis.tps.prompt);
738
+ }
739
+ return (status) ? (GENIE_STATUS_SUCCESS) : (GENIE_STATUS_ERROR_GENERATE_FAILED);
740
+ }
Genie/Genie/src/Embedding.hpp ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ //==============================================================================
2
+ //
3
+ // Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4
+ // All Rights Reserved.
5
+ // Confidential and Proprietary - Qualcomm Technologies, Inc.
6
+ //
7
+ //==============================================================================
8
+
9
+ #pragma once
10
+
11
+ #include <atomic>
12
+ #include <memory>
13
+
14
+ #include "GenieEmbedding.h"
15
+ #include "Util/HandleManager.hpp"
16
+ #include "qualla/embedding.hpp"
17
+
18
+ namespace genie {
19
+
20
+ class Embedding {
21
+ public:
22
+ class Config {
23
+ public:
24
+ static GenieEmbeddingConfig_Handle_t add(std::shared_ptr<Config> config);
25
+ static std::shared_ptr<Config> get(GenieEmbeddingConfig_Handle_t handle);
26
+ static void remove(GenieEmbeddingConfig_Handle_t handle);
27
+
28
+ Config(const char* configStr);
29
+ qualla::json getJson() const;
30
+
31
+ private:
32
+ static qnn::util::HandleManager<Config> s_manager;
33
+ qualla::json m_config;
34
+ };
35
+
36
+ static GenieEmbedding_Handle_t add(std::shared_ptr<Embedding> embedding);
37
+ static std::shared_ptr<Embedding> get(GenieEmbedding_Handle_t handle);
38
+ static void remove(GenieEmbedding_Handle_t handle);
39
+
40
+ Embedding(std::shared_ptr<Config> config);
41
+
42
+ Embedding(const Embedding&) = delete;
43
+ Embedding& operator=(const Embedding&) = delete;
44
+ Embedding(Embedding&&) = delete;
45
+ Embedding& operator=(Embedding&&) = delete;
46
+
47
+ int32_t generate(const char* queryStr,
48
+ GenieEmbedding_GenerateCallback_t callback,
49
+ const void* userData);
50
+
51
+ private:
52
+ std::unique_ptr<qualla::Embedding> m_quallaEmbedding;
53
+ static qnn::util::HandleManager<Embedding> s_manager;
54
+ static std::atomic<std::uint32_t> s_nameCounter;
55
+ };
56
+ } // namespace genie
Genie/Genie/src/Exception.hpp CHANGED
@@ -9,6 +9,7 @@
9
  #pragma once
10
 
11
  #include <exception>
 
12
  #include <string>
13
 
14
  #include "GenieCommon.h"
 
9
  #pragma once
10
 
11
  #include <exception>
12
+ #include <stdexcept>
13
  #include <string>
14
 
15
  #include "GenieCommon.h"
Genie/Genie/src/GenieDialog.cpp CHANGED
@@ -232,6 +232,24 @@ Genie_Status_t GenieDialog_tokenQuery(const GenieDialog_Handle_t dialogHandle,
232
  return status;
233
  }
234
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
  GENIE_API
236
  Genie_Status_t GenieDialog_free(const GenieDialog_Handle_t dialogHandle) {
237
  try {
@@ -246,4 +264,4 @@ Genie_Status_t GenieDialog_free(const GenieDialog_Handle_t dialogHandle) {
246
  return GENIE_STATUS_ERROR_GENERAL;
247
  }
248
  return GENIE_STATUS_SUCCESS;
249
- }
 
232
  return status;
233
  }
234
 
235
+ GENIE_API
236
+ Genie_Status_t GenieDialog_getSampler(const GenieDialog_Handle_t dialogHandle,
237
+ GenieSampler_Handle_t* dialogSamplerHandle) {
238
+ try {
239
+ GENIE_ENSURE(dialogHandle, GENIE_STATUS_ERROR_INVALID_HANDLE);
240
+ auto dialog = genie::Dialog::get(dialogHandle);
241
+ GENIE_ENSURE(dialog, GENIE_STATUS_ERROR_INVALID_HANDLE);
242
+ GENIE_ENSURE(dialogSamplerHandle, GENIE_STATUS_ERROR_INVALID_ARGUMENT);
243
+ *dialogSamplerHandle = genie::Dialog::getSamplerHandle(dialog);
244
+ GENIE_ENSURE(*dialogSamplerHandle, GENIE_STATUS_ERROR_INVALID_HANDLE);
245
+ } catch (const std::exception& e) {
246
+ std::cerr << e.what() << std::endl;
247
+ return GENIE_STATUS_ERROR_GET_HANDLE_FAILED;
248
+ }
249
+
250
+ return GENIE_STATUS_SUCCESS;
251
+ }
252
+
253
  GENIE_API
254
  Genie_Status_t GenieDialog_free(const GenieDialog_Handle_t dialogHandle) {
255
  try {
 
264
  return GENIE_STATUS_ERROR_GENERAL;
265
  }
266
  return GENIE_STATUS_SUCCESS;
267
+ }
Genie/Genie/src/GenieEmbedding.cpp ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ //=============================================================================
2
+ //
3
+ // Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4
+ // All Rights Reserved.
5
+ // Confidential and Proprietary - Qualcomm Technologies, Inc.
6
+ //
7
+ //=============================================================================
8
+
9
+ #include "Embedding.hpp"
10
+ #include "Exception.hpp"
11
+ #include "GenieEmbedding.h"
12
+ #include "Macro.hpp"
13
+ #include "Util/HandleManager.hpp"
14
+ #include "qualla/detail/json.hpp"
15
+
16
+ using namespace genie;
17
+
18
+ GENIE_API
19
+ Genie_Status_t GenieEmbeddingConfig_createFromJson(const char* str,
20
+ GenieEmbeddingConfig_Handle_t* configHandle) {
21
+ try {
22
+ GENIE_ENSURE(str, GENIE_STATUS_ERROR_INVALID_ARGUMENT);
23
+ GENIE_ENSURE(configHandle, GENIE_STATUS_ERROR_INVALID_ARGUMENT);
24
+ auto config = std::make_shared<Embedding::Config>(str);
25
+ GENIE_ENSURE(config, GENIE_STATUS_ERROR_MEM_ALLOC);
26
+ *configHandle = genie::Embedding::Config::add(config);
27
+ } catch (const qualla::json::parse_error& e) {
28
+ std::cerr << e.what() << std::endl;
29
+ return GENIE_STATUS_ERROR_JSON_FORMAT;
30
+ } catch (const Exception& e) {
31
+ std::cerr << e.what() << std::endl;
32
+ return e.status();
33
+ } catch (const std::exception& e) {
34
+ std::cerr << e.what() << std::endl;
35
+ return GENIE_STATUS_ERROR_GENERAL;
36
+ }
37
+ return GENIE_STATUS_SUCCESS;
38
+ }
39
+
40
+ GENIE_API
41
+ Genie_Status_t GenieEmbeddingConfig_free(const GenieEmbeddingConfig_Handle_t configHandle) {
42
+ try {
43
+ GENIE_ENSURE(configHandle, GENIE_STATUS_ERROR_INVALID_HANDLE);
44
+ {
45
+ // Check if the embedding actually exists
46
+ auto configObj = genie::Embedding::Config::get(configHandle);
47
+ GENIE_ENSURE(configObj, GENIE_STATUS_ERROR_INVALID_HANDLE);
48
+ }
49
+ genie::Embedding::Config::remove(configHandle);
50
+ } catch (const std::exception& e) {
51
+ return GENIE_STATUS_ERROR_GENERAL;
52
+ }
53
+ return GENIE_STATUS_SUCCESS;
54
+ }
55
+
56
+ GENIE_API
57
+ Genie_Status_t GenieEmbedding_create(const GenieEmbeddingConfig_Handle_t configHandle,
58
+ GenieEmbedding_Handle_t* embeddingHandle) {
59
+ try {
60
+ GENIE_ENSURE(embeddingHandle, GENIE_STATUS_ERROR_INVALID_ARGUMENT);
61
+
62
+ // Get config object
63
+ auto configObj = genie::Embedding::Config::get(configHandle);
64
+ GENIE_ENSURE(configObj, GENIE_STATUS_ERROR_INVALID_HANDLE);
65
+
66
+ // Create embedding
67
+ auto embedding = std::make_shared<genie::Embedding>(configObj);
68
+ GENIE_ENSURE(embedding, GENIE_STATUS_ERROR_MEM_ALLOC);
69
+
70
+ // Create Handle
71
+ *embeddingHandle = genie::Embedding::add(embedding);
72
+ } catch (const std::exception& e) {
73
+ std::cerr << e.what() << std::endl;
74
+ return GENIE_STATUS_ERROR_GENERAL;
75
+ }
76
+
77
+ // Return SUCCESS
78
+ return GENIE_STATUS_SUCCESS;
79
+ }
80
+
81
+ GENIE_API
82
+ Genie_Status_t GenieEmbedding_generate(const GenieEmbedding_Handle_t embeddingHandle,
83
+ const char* queryStr,
84
+ const GenieEmbedding_GenerateCallback_t callback,
85
+ const void* userData) {
86
+ int32_t status;
87
+
88
+ try {
89
+ GENIE_ENSURE(embeddingHandle, GENIE_STATUS_ERROR_INVALID_HANDLE);
90
+ auto embedding = genie::Embedding::get(embeddingHandle);
91
+ GENIE_ENSURE(embedding, GENIE_STATUS_ERROR_INVALID_HANDLE);
92
+ GENIE_ENSURE(queryStr, GENIE_STATUS_ERROR_INVALID_ARGUMENT);
93
+ GENIE_ENSURE(callback, GENIE_STATUS_ERROR_INVALID_ARGUMENT);
94
+
95
+ status = embedding->generate(queryStr, callback, userData);
96
+ } catch (const std::exception& e) {
97
+ std::cerr << e.what() << std::endl;
98
+ return GENIE_STATUS_ERROR_GENERAL;
99
+ }
100
+
101
+ return status;
102
+ }
103
+
104
+ GENIE_API
105
+ Genie_Status_t GenieEmbedding_free(const GenieEmbedding_Handle_t embeddingHandle) {
106
+ try {
107
+ GENIE_ENSURE(embeddingHandle, GENIE_STATUS_ERROR_INVALID_HANDLE);
108
+ {
109
+ // Check if the embedding actually exists
110
+ auto embedding = genie::Embedding::get(embeddingHandle);
111
+ GENIE_ENSURE(embedding, GENIE_STATUS_ERROR_INVALID_HANDLE);
112
+ }
113
+ genie::Embedding::remove(embeddingHandle);
114
+ } catch (const std::exception& e) {
115
+ return GENIE_STATUS_ERROR_GENERAL;
116
+ }
117
+ return GENIE_STATUS_SUCCESS;
118
+ }
Genie/Genie/src/GenieSampler.cpp ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ //=============================================================================
2
+ //
3
+ // Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4
+ // All Rights Reserved.
5
+ // Confidential and Proprietary - Qualcomm Technologies, Inc.
6
+ //
7
+ //=============================================================================
8
+
9
+ #include <iostream>
10
+
11
+ #include "Exception.hpp"
12
+ #include "GenieSampler.h"
13
+ #include "Macro.hpp"
14
+ #include "Sampler.hpp"
15
+ #include "Util/HandleManager.hpp"
16
+ #include "qualla/detail/json.hpp"
17
+
18
+ using namespace genie;
19
+ GENIE_API
20
+ Genie_Status_t GenieSamplerConfig_createFromJson(const char* str,
21
+ GenieSamplerConfig_Handle_t* configHandle) {
22
+ try {
23
+ GENIE_ENSURE(str, GENIE_STATUS_ERROR_INVALID_ARGUMENT);
24
+ GENIE_ENSURE(configHandle, GENIE_STATUS_ERROR_INVALID_ARGUMENT);
25
+ auto config = std::make_shared<Sampler::Sampler::SamplerConfig>(str);
26
+ GENIE_ENSURE(config, GENIE_STATUS_ERROR_MEM_ALLOC);
27
+ *configHandle = Sampler::Sampler::SamplerConfig::add(config);
28
+ } catch (const qualla::json::parse_error& e) {
29
+ std::cerr << e.what() << std::endl;
30
+ return GENIE_STATUS_ERROR_JSON_FORMAT;
31
+ } catch (const Exception& e) {
32
+ std::cerr << e.what() << std::endl;
33
+ return e.status();
34
+ } catch (const std::exception& e) {
35
+ std::cerr << e.what() << std::endl;
36
+ return GENIE_STATUS_ERROR_GENERAL;
37
+ }
38
+ return GENIE_STATUS_SUCCESS;
39
+ }
40
+
41
+ GENIE_API
42
+ Genie_Status_t GenieSamplerConfig_free(const GenieSamplerConfig_Handle_t configHandle) {
43
+ try {
44
+ GENIE_ENSURE(configHandle, GENIE_STATUS_ERROR_INVALID_HANDLE);
45
+ {
46
+ // Check if the dialog actually exists
47
+ auto configObj = Sampler::SamplerConfig::get(configHandle);
48
+ GENIE_ENSURE(configObj, GENIE_STATUS_ERROR_INVALID_HANDLE);
49
+ }
50
+ Sampler::SamplerConfig::remove(configHandle);
51
+ } catch (const std::exception& e) {
52
+ return GENIE_STATUS_ERROR_GENERAL;
53
+ }
54
+ return GENIE_STATUS_SUCCESS;
55
+ }
56
+
57
+ GENIE_API
58
+ Genie_Status_t GenieSamplerConfig_setParam(const GenieSamplerConfig_Handle_t configHandle,
59
+ const char* keyStr,
60
+ const char* valueStr) {
61
+ try {
62
+ GENIE_ENSURE(configHandle, GENIE_STATUS_ERROR_INVALID_HANDLE);
63
+ auto samplerConfig = Sampler::SamplerConfig::get(configHandle);
64
+ GENIE_ENSURE(samplerConfig, GENIE_STATUS_ERROR_INVALID_HANDLE);
65
+ samplerConfig->setParam(keyStr, valueStr);
66
+ } catch (const std::exception& e) {
67
+ std::cerr << e.what() << std::endl;
68
+ return GENIE_STATUS_ERROR_SET_PARAMS_FAILED;
69
+ }
70
+ return GENIE_STATUS_SUCCESS;
71
+ }
72
+
73
+ GENIE_API
74
+ Genie_Status_t GenieSampler_applyConfig(const GenieSampler_Handle_t samplerHandle,
75
+ const GenieSamplerConfig_Handle_t configHandle) {
76
+ try {
77
+ GENIE_ENSURE(samplerHandle, GENIE_STATUS_ERROR_INVALID_HANDLE);
78
+ GENIE_ENSURE(configHandle, GENIE_STATUS_ERROR_INVALID_HANDLE);
79
+
80
+ auto sampler = Sampler::get(samplerHandle);
81
+ GENIE_ENSURE(sampler, GENIE_STATUS_ERROR_INVALID_HANDLE);
82
+
83
+ auto samplerConfig = Sampler::SamplerConfig::get(configHandle);
84
+ GENIE_ENSURE(samplerConfig, GENIE_STATUS_ERROR_INVALID_HANDLE);
85
+
86
+ sampler->applyConfig(samplerConfig->getJson());
87
+
88
+ } catch (const std::exception& e) {
89
+ std::cerr << e.what() << std::endl;
90
+ return GENIE_STATUS_ERROR_APPLY_CONFIG_FAILED;
91
+ }
92
+ return GENIE_STATUS_SUCCESS;
93
+ }
Genie/Genie/src/Macro.hpp CHANGED
@@ -8,6 +8,8 @@
8
 
9
  #pragma once
10
 
 
 
11
  //======================================================================================================================
12
  // Error generation macros
13
  //======================================================================================================================
 
8
 
9
  #pragma once
10
 
11
+ #define ENABLE_DEBUG_LOGS 0
12
+
13
  //======================================================================================================================
14
  // Error generation macros
15
  //======================================================================================================================
Genie/Genie/src/Sampler.cpp ADDED
@@ -0,0 +1,275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ //==============================================================================
2
+ //
3
+ // Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4
+ // All Rights Reserved.
5
+ // Confidential and Proprietary - Qualcomm Technologies, Inc.
6
+ //
7
+ //==============================================================================
8
+ #include <exception>
9
+ #include <set>
10
+
11
+ #include "Exception.hpp"
12
+ #include "Macro.hpp"
13
+ #include "Sampler.hpp"
14
+ #include "qualla/detail/json.hpp"
15
+
16
+ using namespace genie;
17
+
18
+ //=============================================================================
19
+ // Sampler functions
20
+ //=============================================================================
21
+
22
+ qnn::util::HandleManager<Sampler> Sampler::s_manager;
23
+
24
+ GenieSampler_Handle_t Sampler::add(std::shared_ptr<Sampler> config) {
25
+ return (GenieSampler_Handle_t)s_manager.add(config);
26
+ }
27
+
28
+ std::shared_ptr<Sampler> Sampler::get(GenieSampler_Handle_t handle) {
29
+ return s_manager.get((qnn::util::Handle_t)handle);
30
+ }
31
+
32
+ void Sampler::remove(GenieSampler_Handle_t handle) {
33
+ s_manager.remove((qnn::util::Handle_t)handle);
34
+ }
35
+
36
+ Sampler::Sampler(qualla::json& origJson,
37
+ std::vector<std::reference_wrapper<qualla::Sampler>>& quallaSamplers)
38
+ : m_origJson(origJson), m_quallaSamplers(quallaSamplers) {}
39
+
40
+ void Sampler::applyConfig(qualla::json samplerConfigJson) {
41
+ m_origJson["sampler"]["seed"] = qualla::Config::optional<int32_t>(
42
+ samplerConfigJson["sampler"], "seed", m_origJson["sampler"]["seed"]);
43
+ m_origJson["sampler"]["temp"] = qualla::Config::optional<float>(
44
+ samplerConfigJson["sampler"], "temp", m_origJson["sampler"]["temp"]);
45
+ m_origJson["sampler"]["top-k"] = qualla::Config::optional<size_t>(
46
+ samplerConfigJson["sampler"], "top-k", m_origJson["sampler"]["top-k"]);
47
+ m_origJson["sampler"]["top-p"] = qualla::Config::optional<float>(
48
+ samplerConfigJson["sampler"], "top-p", m_origJson["sampler"]["top-p"]);
49
+ m_origJson["sampler"]["version"] =
50
+ qualla::Config::optional<int32_t>(samplerConfigJson["sampler"], "version", 1);
51
+ m_origJson["sampler"]["type"] = "basic";
52
+
53
+ #if ENABLE_DEBUG_LOGS
54
+ std::cout << "Updated sampler config: " << std::endl;
55
+ std::cout << "temp: " << m_origJson["sampler"]["temp"].get<double>() << std::endl;
56
+ std::cout << "top-k: " << m_origJson["sampler"]["top-k"] << std::endl;
57
+ std::cout << "top-p: " << m_origJson["sampler"]["top-p"].get<double>() << std::endl;
58
+ std::cout << "seed: " << m_origJson["sampler"]["seed"] << std::endl;
59
+ #endif
60
+ // Loop through the live qualla sampler instances and update the parameters
61
+ for (auto& quallaSampler : m_quallaSamplers) {
62
+ quallaSampler.get().applyConfig(m_origJson["sampler"]);
63
+ }
64
+ }
65
+
66
+ //=============================================================================
67
+ // Sampler::SamplerConfig functions
68
+ //=============================================================================
69
+
70
+ qnn::util::HandleManager<Sampler::SamplerConfig> Sampler::SamplerConfig::s_manager;
71
+
72
+ GenieSamplerConfig_Handle_t Sampler::SamplerConfig::add(
73
+ std::shared_ptr<Sampler::SamplerConfig> config) {
74
+ return (GenieSamplerConfig_Handle_t)s_manager.add(config);
75
+ }
76
+
77
+ std::shared_ptr<Sampler::SamplerConfig> Sampler::SamplerConfig::get(
78
+ GenieSamplerConfig_Handle_t handle) {
79
+ return s_manager.get((qnn::util::Handle_t)handle);
80
+ }
81
+
82
+ void Sampler::SamplerConfig::remove(GenieSamplerConfig_Handle_t handle) {
83
+ s_manager.remove((qnn::util::Handle_t)handle);
84
+ }
85
+
86
+ Sampler::SamplerConfig::SamplerConfig(const char* configStr) {
87
+ qualla::json quallaConfig;
88
+ qualla::json config;
89
+ {
90
+ std::set<qualla::json> keys;
91
+
92
+ auto callback = [&keys](int depth, qualla::json::parse_event_t event, qualla::json& parsed) {
93
+ if ((depth == 1) && (event == qualla::json::parse_event_t::key)) {
94
+ if (keys.count(parsed) > 0) {
95
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
96
+ "Multiple sampler config key: " + parsed.dump());
97
+ }
98
+ keys.insert(parsed);
99
+ }
100
+ return true;
101
+ };
102
+
103
+ config = qualla::json::parse(configStr, callback);
104
+ }
105
+
106
+ if (!config.is_object()) {
107
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Sampler config is not an object");
108
+ }
109
+
110
+ if (!config.contains("sampler")) {
111
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing field: sampler");
112
+ }
113
+
114
+ // component is used in the "ENFORCE" macros
115
+ const std::string component = "sampler";
116
+ for (auto& item : config.items()) {
117
+ if (item.key() == "sampler") {
118
+ JSON_ENFORCE_OBJECT();
119
+ validateSamplerConfig(item.value());
120
+ } else {
121
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown sampler config key: " + item.key());
122
+ }
123
+ }
124
+
125
+ if (config["sampler"].contains("seed"))
126
+ quallaConfig["sampler"]["seed"] = config["sampler"]["seed"];
127
+ if (config["sampler"].contains("temp"))
128
+ quallaConfig["sampler"]["temp"] = config["sampler"]["temp"];
129
+ if (config["sampler"].contains("top-k"))
130
+ quallaConfig["sampler"]["top-k"] = config["sampler"]["top-k"];
131
+ if (config["sampler"].contains("top-p"))
132
+ quallaConfig["sampler"]["top-p"] = config["sampler"]["top-p"];
133
+ if (config["sampler"].contains("greedy"))
134
+ quallaConfig["sampler"]["greedy"] = config["sampler"]["greedy"];
135
+ if (config["sampler"].contains("version"))
136
+ quallaConfig["sampler"]["version"] = config["sampler"]["version"];
137
+ else
138
+ quallaConfig["sampler"]["version"] = 1;
139
+
140
+ quallaConfig["sampler"]["type"] = "basic";
141
+
142
+ m_config = quallaConfig;
143
+ }
144
+
145
+ void Sampler::SamplerConfig::setParam(const std::string& keyStr, const std::string& valueStr) {
146
+ if (!keyStr.empty()) {
147
+ // Case 1: Only the parameter mentioned in keyStr is to be updated by valueStr
148
+ std::set<std::string> validParams = {"seed", "top-p", "top-k", "temp"};
149
+ if (std::find(validParams.begin(), validParams.end(), keyStr) == validParams.end()) {
150
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Invalid key obtained: " + keyStr);
151
+ }
152
+ try {
153
+ if (keyStr == "seed")
154
+ m_config["sampler"]["seed"] = std::stoi(valueStr);
155
+ else if (keyStr == "top-p")
156
+ m_config["sampler"]["top-p"] = std::stof(valueStr);
157
+ else if (keyStr == "top-k")
158
+ m_config["sampler"]["top-k"] = std::stof(valueStr);
159
+ else if (keyStr == "temp")
160
+ m_config["sampler"]["temp"] = std::stof(valueStr);
161
+ } catch (const std::invalid_argument& e) {
162
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
163
+ "Invalid value obtained: " + valueStr + " for key: " + keyStr);
164
+ }
165
+ } else {
166
+ // Case 2: User has passed entire json as a string in valueStr
167
+
168
+ if (valueStr.empty())
169
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Both keyStr and valueStr cannot be empty");
170
+
171
+ qualla::json config = qualla::json::parse(valueStr);
172
+ if (!config.contains("sampler")) {
173
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing field: sampler");
174
+ }
175
+
176
+ // component is used in the "ENFORCE" macros
177
+ const std::string component = "sampler";
178
+ for (auto& item : config.items()) {
179
+ if (item.key() == "sampler") {
180
+ JSON_ENFORCE_OBJECT();
181
+ validateSamplerConfig(item.value());
182
+ } else {
183
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA,
184
+ "Unknown sampler config key: " + item.key());
185
+ }
186
+ }
187
+
188
+ m_config["sampler"]["seed"] =
189
+ qualla::Config::optional<int32_t>(config["sampler"], "seed", m_config["sampler"]["seed"]);
190
+ m_config["sampler"]["temp"] =
191
+ qualla::Config::optional<float>(config["sampler"], "temp", m_config["sampler"]["temp"]);
192
+ m_config["sampler"]["top-k"] =
193
+ qualla::Config::optional<size_t>(config["sampler"], "top-k", m_config["sampler"]["top-k"]);
194
+ m_config["sampler"]["top-p"] =
195
+ qualla::Config::optional<float>(config["sampler"], "top-p", m_config["sampler"]["top-p"]);
196
+ m_config["sampler"]["version"] = qualla::Config::optional<int32_t>(
197
+ config["sampler"], "version", m_config["sampler"]["version"]);
198
+
199
+ m_config["sampler"]["type"] = "basic";
200
+ }
201
+ }
202
+
203
+ void Sampler::SamplerConfig::validateSamplerConfig(const qualla::json& config) {
204
+ if (!config.is_object()) {
205
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "sampler config is not an object");
206
+ }
207
+
208
+ const std::set<std::string> mandatoryFields{"version"};
209
+ for (const auto& field : mandatoryFields) {
210
+ if (!config.contains(field)) {
211
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Missing sampler field: " + field);
212
+ }
213
+ }
214
+
215
+ // component is used in the "ENFORCE" macros
216
+ const std::string component = "sampler";
217
+
218
+ for (auto& item : config.items()) {
219
+ if (item.key() == "version") {
220
+ JSON_ENFORCE_NUMERIC();
221
+ if (item.value().get<int>() != 1) {
222
+ throw Exception(GENIE_STATUS_ERROR_JSON_VALUE,
223
+ "Invalid sampler config: unsupported version: " + item.value().dump());
224
+ }
225
+ } else if (item.key() == "seed") {
226
+ JSON_ENFORCE_NUMERIC();
227
+ } else if (item.key() == "temp") {
228
+ JSON_ENFORCE_NUMERIC();
229
+ } else if (item.key() == "top-k") {
230
+ JSON_ENFORCE_NUMERIC();
231
+ } else if (item.key() == "top-p") {
232
+ JSON_ENFORCE_NUMERIC();
233
+ } else if (item.key() == "greedy") {
234
+ JSON_ENFORCE_BOOLEAN();
235
+ } else {
236
+ throw Exception(GENIE_STATUS_ERROR_JSON_SCHEMA, "Unknown sampler config key: " + item.key());
237
+ }
238
+ }
239
+ }
240
+
241
+ void Sampler::SamplerConfig::translateSamplerConfig(const qualla::json& genieConfig,
242
+ qualla::json& quallaConfig) {
243
+ if (genieConfig["dialog"].contains("sampler")) {
244
+ quallaConfig["sampler"]["type"] = "basic";
245
+
246
+ if (genieConfig["dialog"]["sampler"].contains("seed")) {
247
+ quallaConfig["sampler"]["seed"] = genieConfig["dialog"]["sampler"]["seed"];
248
+ }
249
+ if (genieConfig["dialog"]["sampler"].contains("temp")) {
250
+ quallaConfig["sampler"]["temp"] = genieConfig["dialog"]["sampler"]["temp"];
251
+ }
252
+
253
+ quallaConfig["sampler"]["role"] = "primary";
254
+ #if defined(GENIE_SPD_FEATURE)
255
+ if (genieConfig["dialog"]["type"] == "spd") {
256
+ quallaConfig["sampler"]["role"] = "target";
257
+ }
258
+ #endif
259
+
260
+ if (genieConfig["dialog"]["sampler"].contains("top-k")) {
261
+ quallaConfig["sampler"]["top-k"] = genieConfig["dialog"]["sampler"]["top-k"];
262
+ }
263
+ if (genieConfig["dialog"]["sampler"].contains("top-p")) {
264
+ quallaConfig["sampler"]["top-p"] = genieConfig["dialog"]["sampler"]["top-p"];
265
+ }
266
+ if (genieConfig["dialog"]["sampler"].contains("greedy")) {
267
+ quallaConfig["sampler"]["greedy"] = genieConfig["dialog"]["sampler"]["greedy"];
268
+ }
269
+ if (genieConfig["dialog"]["sampler"].contains("seed")) {
270
+ quallaConfig["sampler"]["seed"] = genieConfig["dialog"]["sampler"]["seed"];
271
+ }
272
+ }
273
+ }
274
+
275
+ qualla::json Sampler::SamplerConfig::getJson() const { return m_config; }
Genie/Genie/src/Sampler.hpp ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ //==============================================================================
2
+ //
3
+ // Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4
+ // All Rights Reserved.
5
+ // Confidential and Proprietary - Qualcomm Technologies, Inc.
6
+ //
7
+ //==============================================================================
8
+
9
+ #pragma once
10
+ #include <memory>
11
+
12
+ #include "GenieSampler.h"
13
+ #include "Util/HandleManager.hpp"
14
+ #include "qualla/env.hpp"
15
+ #include "qualla/sampler.hpp"
16
+
17
+ namespace genie {
18
+ class Sampler {
19
+ public:
20
+ class SamplerConfig {
21
+ public:
22
+ static GenieSamplerConfig_Handle_t add(std::shared_ptr<SamplerConfig> config);
23
+
24
+ static std::shared_ptr<SamplerConfig> get(GenieSamplerConfig_Handle_t handle);
25
+
26
+ static void remove(GenieSamplerConfig_Handle_t handle);
27
+
28
+ static void validateSamplerConfig(const qualla::json& config);
29
+
30
+ static void translateSamplerConfig(const qualla::json& genieConfig, qualla::json& quallaConfig);
31
+
32
+ SamplerConfig(const char* configStr);
33
+
34
+ void setParam(const std::string& keyStr, const std::string& valueStr);
35
+
36
+ qualla::json getJson() const;
37
+
38
+ private:
39
+ static qnn::util::HandleManager<SamplerConfig> s_manager;
40
+ qualla::json m_config;
41
+ };
42
+
43
+ static GenieSampler_Handle_t add(std::shared_ptr<Sampler> sampler);
44
+ static std::shared_ptr<Sampler> get(GenieSampler_Handle_t handle);
45
+ static void remove(GenieSampler_Handle_t handle);
46
+
47
+ Sampler(qualla::json& origJson,
48
+ std::vector<std::reference_wrapper<qualla::Sampler>>& quallaSamplers);
49
+
50
+ void applyConfig(qualla::json samplerConfigJson);
51
+
52
+ const qualla::json& getJson();
53
+
54
+ private:
55
+ qualla::json m_origJson;
56
+ static qnn::util::HandleManager<Sampler> s_manager;
57
+ std::vector<std::reference_wrapper<qualla::Sampler>> m_quallaSamplers;
58
+ };
59
+
60
+ } // namespace genie
Genie/Genie/src/qualla/context.cpp CHANGED
@@ -93,6 +93,10 @@ extern void needQnnHtpEngine();
93
  extern void needQnnCpuEngine();
94
  #endif
95
 
 
 
 
 
96
  static OnLoad needs([]() {
97
  needStdoutLogger();
98
  needFileLogger();
@@ -111,6 +115,10 @@ static OnLoad needs([]() {
111
  #ifdef QUALLA_ENGINE_QNN_CPU
112
  needQnnCpuEngine();
113
  #endif
 
 
 
 
114
  });
115
 
116
  #endif
 
93
  extern void needQnnCpuEngine();
94
  #endif
95
 
96
+ #ifdef QUALLA_ENGINE_QNN_GPU
97
+ extern void needQnnGpuEngine();
98
+ #endif
99
+
100
  static OnLoad needs([]() {
101
  needStdoutLogger();
102
  needFileLogger();
 
115
  #ifdef QUALLA_ENGINE_QNN_CPU
116
  needQnnCpuEngine();
117
  #endif
118
+
119
+ #ifdef QUALLA_ENGINE_QNN_GPU
120
+ needQnnGpuEngine();
121
+ #endif
122
  });
123
 
124
  #endif
Genie/Genie/src/qualla/dialogs/ssd-q1.cpp CHANGED
@@ -161,7 +161,7 @@ SelfSpecDecDialog::SelfSpecDecDialog(
161
  m_inputType = _engine["primary"]->getInputType();
162
  // Load KV prefix
163
  Timer timer;
164
- size_t n_restored_prefix = _engine["primary"]->restore(_kv_prefix_name);
165
  if (n_restored_prefix != _forecast_prefix) {
166
  // clang-format off
167
  throw std::runtime_error( fmt::format( "SSD : Loaded {} KV$ from {} but expected {} KV$",
@@ -1001,7 +1001,7 @@ bool SelfSpecDecDialog::process(std::vector<int32_t>& tokens, Dialog::Callback c
1001
  void SelfSpecDecDialog::reset() {
1002
  Dialog::reset();
1003
  _n_past = _forecast_prefix;
1004
- size_t n_restored_prefix = _engine["primary"]->restore(_kv_prefix_name);
1005
  if (n_restored_prefix != _forecast_prefix) {
1006
  // clang-format off
1007
  throw std::runtime_error( fmt::format( "SSD : Loaded {} KV$ from {} but expected {} KV$",
 
161
  m_inputType = _engine["primary"]->getInputType();
162
  // Load KV prefix
163
  Timer timer;
164
+ size_t n_restored_prefix = _engine["primary"]->restore(_kv_prefix_name, true);
165
  if (n_restored_prefix != _forecast_prefix) {
166
  // clang-format off
167
  throw std::runtime_error( fmt::format( "SSD : Loaded {} KV$ from {} but expected {} KV$",
 
1001
  void SelfSpecDecDialog::reset() {
1002
  Dialog::reset();
1003
  _n_past = _forecast_prefix;
1004
+ size_t n_restored_prefix = _engine["primary"]->restore(_kv_prefix_name, true);
1005
  if (n_restored_prefix != _forecast_prefix) {
1006
  // clang-format off
1007
  throw std::runtime_error( fmt::format( "SSD : Loaded {} KV$ from {} but expected {} KV$",
Genie/Genie/src/qualla/engine.cpp CHANGED
@@ -69,7 +69,7 @@ bool Engine::updateKV(size_t n_past, const std::vector<bool>& selected) {
69
  return false;
70
  }
71
 
72
- size_t Engine::restore(const std::string& name) {
73
  _env.logger().error(fmt::format("{}-engine does not support restore", _type));
74
  return 0;
75
  }
 
69
  return false;
70
  }
71
 
72
+ size_t Engine::restore(const std::string& name, bool chooseHigherVariant) {
73
  _env.logger().error(fmt::format("{}-engine does not support restore", _type));
74
  return 0;
75
  }
Genie/Genie/src/qualla/engines/qnn-api/DmaBufAllocator.cpp ADDED
@@ -0,0 +1,317 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ //==============================================================================
2
+ //
3
+ // Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4
+ // All Rights Reserved.
5
+ // Confidential and Proprietary - Qualcomm Technologies, Inc.
6
+ //
7
+ //==============================================================================
8
+
9
+ #include "QnnMem.h"
10
+ #include "DmaBufAllocator.hpp"
11
+ #include "QnnTypeMacros.hpp"
12
+
13
+ #include <dlfcn.h>
14
+ #include <fcntl.h>
15
+ #include <linux/dma-buf.h>
16
+ #include <pthread.h>
17
+ #include <stdlib.h>
18
+ #include <sys/ioctl.h>
19
+ #include <sys/mman.h>
20
+ #include <unistd.h>
21
+
22
+ #include <cstdlib>
23
+ #include <fstream>
24
+ #include <iostream>
25
+ #include <numeric>
26
+ #include <string>
27
+ #include <vector>
28
+
29
+
30
+ DmaBufferAllocator::DmaBufferAllocator(Qnn_ContextHandle_t contextHandle, QNN_INTERFACE_VER_TYPE* qnnInterface)
31
+ : m_libDmaBufHeapHandle(nullptr),
32
+ m_dmaBufCreate(nullptr),
33
+ m_dmaBufAlloc(nullptr),
34
+ m_dmaBufDeinit(nullptr),
35
+ m_qnnInterface(qnnInterface),
36
+ m_contextHandle(contextHandle) {}
37
+
38
+ bool DmaBufferAllocator::initialize() {
39
+ // On Android, 32-bit and 64-bit libdmaBufheap.so can be found at /system/lib and /system/lib64
40
+ // respectively.
41
+ m_libDmaBufHeapHandle = dlopen("libdmabufheap.so", RTLD_NOW | RTLD_LOCAL);
42
+ if (nullptr == m_libDmaBufHeapHandle) {
43
+ QNN_ERROR("Unable to load backend. dlerror(): %s", dlerror());
44
+ return false;
45
+ }
46
+ m_dmaBufCreate = (DmaBufCreateFn_t)dlsym(
47
+ m_libDmaBufHeapHandle, "CreateDmabufHeapBufferAllocator");
48
+ m_dmaBufAlloc =
49
+ (DmaBufAllocFn_t)dlsym(m_libDmaBufHeapHandle, "DmabufHeapAlloc");
50
+ m_dmaBufDeinit = (DmaBufDeinitFn_t)dlsym(
51
+ m_libDmaBufHeapHandle, "FreeDmabufHeapBufferAllocator");
52
+ if (nullptr == m_dmaBufCreate || nullptr == m_dmaBufAlloc || nullptr == m_dmaBufDeinit) {
53
+ QNN_ERROR("Unable to access symbols in libdmaBufheap. dlerror(): %s", dlerror());
54
+ return false;
55
+ }
56
+ return true;
57
+ }
58
+
59
+ DmaBufferAllocator::~DmaBufferAllocator() {
60
+ if (m_libDmaBufHeapHandle) {
61
+ dlclose(m_libDmaBufHeapHandle);
62
+ m_libDmaBufHeapHandle = nullptr;
63
+ }
64
+ }
65
+
66
+ DmaBufferData* DmaBufferAllocator::getDmaBufTensorData(Qnn_Tensor_t* tensor) {
67
+ if (tensor == nullptr) return nullptr;
68
+ Qnn_MemHandle_t mem_handle = QNN_TENSOR_GET_MEM_HANDLE(tensor);
69
+ if (mem_handle == nullptr) return nullptr;
70
+ return &m_memHandleToDmaBufMem.at(mem_handle);
71
+ }
72
+
73
+ void* DmaBufferAllocator::getBuffer(Qnn_Tensor_t* tensor) {
74
+ if (!tensor) {
75
+ QNN_WARN("DmaBufferAllocator: getBuffer: received a null pointer to a tensor");
76
+ return nullptr;
77
+ }
78
+ if (m_tensorToDmaBufferData.find(tensor) == m_tensorToDmaBufferData.end()) {
79
+ QNN_ERROR("DmaBufferAllocator: Tensor not found with address = %p", tensor);
80
+ return nullptr;
81
+ }
82
+ DmaBufferData dmaBufferData = m_tensorToDmaBufferData[tensor];
83
+ return dmaBufferData.memPointer;
84
+ }
85
+
86
+
87
+
88
+ int DmaBufferAllocator::getFd(Qnn_Tensor_t* tensor) {
89
+ DmaBufferData* data = getDmaBufTensorData(tensor);
90
+ if (data == nullptr) {
91
+ QNN_ERROR("DmaBufferAllocator: getFd : Couldn't find tensor %p", tensor);
92
+ return -1;
93
+ }
94
+ return data->fd;
95
+ }
96
+
97
+ size_t DmaBufferAllocator::getOffset(Qnn_Tensor_t* tensor) {
98
+ DmaBufferData* data = getDmaBufTensorData(tensor);
99
+ if (data == nullptr) {
100
+ QNN_ERROR("DmaBufferAllocator: getOffset : Couldn't find tensor %p", tensor);
101
+ return 0;
102
+ }
103
+ return data->offset;
104
+ }
105
+
106
+ size_t DmaBufferAllocator::getBufferSize(Qnn_Tensor_t* tensor) {
107
+ DmaBufferData* data = getDmaBufTensorData(tensor);
108
+ if (data == nullptr) {
109
+ QNN_ERROR("DmaBufferAllocator: getBufferSize : Couldn't find tensor %p", tensor);
110
+ return 0;
111
+ }
112
+ return data->totalBufferSize;
113
+ };
114
+
115
+ size_t DmaBufferAllocator::getTotalBufferSize(Qnn_Tensor_t* tensor) {
116
+ DmaBufferData* data = getDmaBufTensorData(tensor);
117
+ if (data == nullptr) {
118
+ QNN_ERROR("DmaBufferAllocator: getTotalBufferSize : Couldn't find tensor %p", tensor);
119
+ return 0;
120
+ }
121
+ return data->totalBufferSize;
122
+ }
123
+
124
+ bool DmaBufferAllocator::allocateTensorBuffer(Qnn_Tensor_t* tensor, size_t tensorDataSize) {
125
+ if (m_libDmaBufHeapHandle == nullptr) {
126
+ QNN_ERROR("DmaBufferAllocator not initialized");
127
+ return false;
128
+ }
129
+
130
+ if (!tensor) {
131
+ QNN_ERROR("DmaBufferAllocator: Received nullptr for tensor");
132
+ return false;
133
+ }
134
+
135
+ if (m_tensorToDmaBufferData.find(tensor) != m_tensorToDmaBufferData.end()) {
136
+ QNN_ERROR("DmaBufferAllocator: Tensor already allocated");
137
+ return false;
138
+ }
139
+
140
+ void* dmaBufferAllocator = m_dmaBufCreate();
141
+ if (dmaBufferAllocator == nullptr) {
142
+ QNN_ERROR("DmaBufferAllocator: nullptr returned for CreateDmabufHeapBufferAllocator().");
143
+ return false;
144
+ }
145
+
146
+ int fd = m_dmaBufAlloc(dmaBufferAllocator, "qcom,system", tensorDataSize, 0, 0);
147
+ if (fd < 0) {
148
+ QNN_ERROR("DmaBufAlloc returned a invalid file descriptor = %d", fd);
149
+ return false;
150
+ }
151
+
152
+ void* memPointer = mmap(nullptr, tensorDataSize, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
153
+ if (MAP_FAILED == memPointer) {
154
+ printf("DmaBufferAllocator: Unable to open file returned by DmaBufAlloc with mmap");
155
+ return false;
156
+ }
157
+
158
+ Qnn_MemDescriptor_t memDescriptor = {
159
+ {QNN_TENSOR_GET_RANK(tensor), QNN_TENSOR_GET_DIMENSIONS(tensor), nullptr},
160
+ QNN_TENSOR_GET_DATA_TYPE(tensor),
161
+ QNN_MEM_TYPE_DMA_BUF,
162
+ {.dmaBufInfo = {fd, memPointer}}};
163
+ QNN_TENSOR_SET_MEM_TYPE(tensor, QNN_TENSORMEMTYPE_MEMHANDLE);
164
+ QNN_TENSOR_SET_MEM_HANDLE(tensor, nullptr);
165
+ Qnn_MemHandle_t memHandle = QNN_TENSOR_GET_MEM_HANDLE(tensor);
166
+
167
+ if (QNN_SUCCESS !=
168
+ m_qnnInterface->memRegister(m_contextHandle, &memDescriptor, 1, &(memHandle))) {
169
+ QNN_ERROR("DmaBufferAllocator: Failure to register ion memory with the backend");
170
+ return false;
171
+ }
172
+ QNN_DEBUG("DmaBufferAllocator: Memregister successful with handle %p for DMA buffer with size: %zu and fd %d",
173
+ memHandle,
174
+ tensorDataSize,
175
+ fd);
176
+ QNN_TENSOR_SET_MEM_HANDLE(tensor, memHandle);
177
+ m_tensorToDmaBufferData.insert(
178
+ {tensor, DmaBufferData(dmaBufferAllocator, fd, memPointer, tensorDataSize)});
179
+
180
+ return true;
181
+ }
182
+
183
+ bool DmaBufferAllocator::freeTensorBuffer(Qnn_Tensor_t* tensor) {
184
+ if (!tensor) {
185
+ QNN_ERROR("DmaBufferAllocator: Received nullptr for tensor");
186
+ return false;
187
+ }
188
+ auto memHandle = QNN_TENSOR_GET_MEM_HANDLE(tensor);
189
+ if (QNN_SUCCESS != m_qnnInterface->memDeRegister(&memHandle, 1)) {
190
+ QNN_ERROR("DmaBufferAllocator: Failed to deregister custom memory handle with the backend");
191
+ return false;
192
+ }
193
+ if (m_tensorToDmaBufferData.find(tensor) == m_tensorToDmaBufferData.end()) {
194
+ QNN_ERROR("DmaBufferAllocator: Tensor not found with address = %p", tensor);
195
+ return false;
196
+ }
197
+ DmaBufferData dmaBufferData = m_tensorToDmaBufferData[tensor];
198
+ if (!m_dmaBufDeinit) {
199
+ QNN_ERROR("DmaBufferAllocator: DmaBuf Deinit function pointer is null");
200
+ return false;
201
+ }
202
+ munmap(dmaBufferData.memPointer, dmaBufferData.totalBufferSize);
203
+ m_dmaBufDeinit(dmaBufferData.dmaBufferAllocator);
204
+ m_tensorToDmaBufferData.erase(tensor);
205
+ return true;
206
+ }
207
+
208
+ bool DmaBufferAllocator::useSameMemory(Qnn_Tensor_t* dest, Qnn_Tensor_t* src) {
209
+ if (nullptr == dest || nullptr == src) {
210
+ QNN_ERROR("DmaBufferAllocator: Received nullptr");
211
+ return false;
212
+ }
213
+ if (m_tensorToDmaBufferData.find(src) == m_tensorToDmaBufferData.end()) {
214
+ QNN_ERROR("DmaBufferAllocator: Src Tensor not found");
215
+ return false;
216
+ }
217
+
218
+ QNN_TENSOR_SET_MEM_TYPE(dest, QNN_TENSOR_GET_MEM_TYPE(src));
219
+ QNN_TENSOR_SET_MEM_HANDLE(dest, QNN_TENSOR_GET_MEM_HANDLE(src));
220
+ m_tensorToDmaBufferData.insert({dest, m_tensorToDmaBufferData[src]});
221
+ m_sameMemoryFreeTensors.insert(dest);
222
+ return true;
223
+ }
224
+
225
+
226
+
227
+ bool DmaBufferAllocator::beforeWriteToBuffer(Qnn_Tensor_t* tensor) {
228
+ if (!tensor) {
229
+ QNN_WARN("beforeWriteToBuffer: received a null pointer to a tensor");
230
+ return false;
231
+ }
232
+ if (m_tensorToDmaBufferData.find(tensor) == m_tensorToDmaBufferData.end()) {
233
+ QNN_ERROR("beforeWriteToBuffer: Tensor not found with address = %p", tensor);
234
+ return false;
235
+ }
236
+ DmaBufferData dmaBufferData = m_tensorToDmaBufferData[tensor];
237
+ struct dma_buf_sync buf_sync = {};
238
+ buf_sync.flags = DMA_BUF_SYNC_START | DMA_BUF_SYNC_WRITE;
239
+ auto ioctlReturnValue = ioctl(dmaBufferData.fd, DMA_BUF_IOCTL_SYNC, &buf_sync);
240
+ if (ioctlReturnValue) {
241
+ QNN_ERROR(
242
+ "beforeWriteToBuffer: Error preparing the cache for buffer writes."
243
+ "The DMA_BUF_IOCTL_SYNC operation returned %d",
244
+ ioctlReturnValue);
245
+ return false;
246
+ }
247
+ return true;
248
+ }
249
+
250
+ bool DmaBufferAllocator::afterWriteToBuffer(Qnn_Tensor_t* tensor) {
251
+ if (!tensor) {
252
+ QNN_WARN("afterWriteToBuffer: received a null pointer to a tensor");
253
+ return false;
254
+ }
255
+ if (m_tensorToDmaBufferData.find(tensor) == m_tensorToDmaBufferData.end()) {
256
+ QNN_ERROR("afterWriteToBuffer: Tensor not found with address = %p", tensor);
257
+ return false;
258
+ }
259
+ DmaBufferData dmaBufferData = m_tensorToDmaBufferData[tensor];
260
+ struct dma_buf_sync buf_sync = {};
261
+ buf_sync.flags = DMA_BUF_SYNC_END | DMA_BUF_SYNC_WRITE;
262
+ auto ioctlReturnValue = ioctl(dmaBufferData.fd, DMA_BUF_IOCTL_SYNC, &buf_sync);
263
+ if (ioctlReturnValue) {
264
+ QNN_ERROR(
265
+ "afterWriteToBuffer: Error close the cache after buffer writing."
266
+ "The DMA_BUF_IOCTL_SYNC operation returned %d",
267
+ ioctlReturnValue);
268
+ return false;
269
+ }
270
+ return true;
271
+ }
272
+
273
+ bool DmaBufferAllocator::beforeReadFromBuffer(Qnn_Tensor_t* tensor) {
274
+ if (!tensor) {
275
+ QNN_WARN("beforeReadFromBuffer: received a null pointer to a tensor");
276
+ return false;
277
+ }
278
+ if (m_tensorToDmaBufferData.find(tensor) == m_tensorToDmaBufferData.end()) {
279
+ QNN_ERROR("beforeReadFromBuffer: Tensor not found with address = %p", tensor);
280
+ return false;
281
+ }
282
+ DmaBufferData dmaBufferData = m_tensorToDmaBufferData[tensor];
283
+ struct dma_buf_sync buf_sync = {};
284
+ buf_sync.flags = DMA_BUF_SYNC_START | DMA_BUF_SYNC_READ;
285
+ auto ioctlReturnValue = ioctl(dmaBufferData.fd, DMA_BUF_IOCTL_SYNC, &buf_sync);
286
+ if (ioctlReturnValue) {
287
+ QNN_ERROR(
288
+ "beforeReadFromBuffer: Error preparing the cache for buffer reading."
289
+ "The DMA_BUF_IOCTL_SYNC operation returned %d",
290
+ ioctlReturnValue);
291
+ return false;
292
+ }
293
+ return true;
294
+ }
295
+
296
+ bool DmaBufferAllocator::afterReadFromBuffer(Qnn_Tensor_t* tensor) {
297
+ if (!tensor) {
298
+ QNN_WARN("afterReadFromBuffer: received a null pointer to a tensor");
299
+ return false;
300
+ }
301
+ if (m_tensorToDmaBufferData.find(tensor) == m_tensorToDmaBufferData.end()) {
302
+ QNN_ERROR("afterReadFromBuffer: Tensor not found with address = %p", tensor);
303
+ return false;
304
+ }
305
+ DmaBufferData dmaBufferData = m_tensorToDmaBufferData[tensor];
306
+ struct dma_buf_sync buf_sync = {};
307
+ buf_sync.flags = DMA_BUF_SYNC_END | DMA_BUF_SYNC_READ;
308
+ auto ioctlReturnValue = ioctl(dmaBufferData.fd, DMA_BUF_IOCTL_SYNC, &buf_sync);
309
+ if (ioctlReturnValue) {
310
+ QNN_ERROR(
311
+ "afterReadFromBuffer: Error closing the cache after buffer reading."
312
+ "The DMA_BUF_IOCTL_SYNC operation returned %d",
313
+ ioctlReturnValue);
314
+ return false;
315
+ }
316
+ return true;
317
+ }
Genie/Genie/src/qualla/engines/qnn-api/DmaBufAllocator.hpp ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ //==============================================================================
2
+ //
3
+ // Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4
+ // All Rights Reserved.
5
+ // Confidential and Proprietary - Qualcomm Technologies, Inc.
6
+ //
7
+ //==============================================================================
8
+ #pragma once
9
+
10
+ #include <map>
11
+ #include <unordered_map>
12
+ #include <unordered_set>
13
+ #include <vector>
14
+
15
+ #include "IBufferAlloc.hpp"
16
+ #include "QnnInterface.h"
17
+ #include "Log.hpp"
18
+
19
+ typedef void *(*DmaBufCreateFn_t)();
20
+ typedef int (*DmaBufAllocFn_t)(void *, const char *, size_t, unsigned int, size_t);
21
+ typedef void (*DmaBufDeinitFn_t)(void *);
22
+
23
// Bookkeeping record for one DMA-BUF backed allocation.
struct DmaBufferData {
  void* dmaBufferAllocator;  // opaque handle from CreateDmabufHeapBufferAllocator()
  int fd;                    // file descriptor returned by DmabufHeapAlloc()
  void* memPointer;          // CPU-visible mmap()ed address of the buffer
  size_t totalBufferSize;    // size of the mapping, in bytes
  int offset{0};             // not set by this allocator; stays 0
  // Default record: "no allocation".
  DmaBufferData() : dmaBufferAllocator(nullptr), fd(-1), memPointer(nullptr), totalBufferSize(0) {}
  // Record for a live allocation.
  DmaBufferData(void* bufferAllocator, int fdIn, void* memPointerIn, size_t sizeIn)
      : dmaBufferAllocator(bufferAllocator),
        fd(fdIn),
        memPointer(memPointerIn),
        totalBufferSize(sizeIn) {}
};
33
+
34
+ class DmaBufferAllocator final : public IBufferAlloc {
35
+ public:
36
+ DmaBufferAllocator(Qnn_ContextHandle_t contextHandle, QNN_INTERFACE_VER_TYPE* qnnInterface);
37
+ // Disable copy constructors, r-value referencing, etc
38
+ DmaBufferAllocator(const DmaBufferAllocator&) = delete;
39
+ DmaBufferAllocator& operator=(const DmaBufferAllocator&) = delete;
40
+ DmaBufferAllocator(DmaBufferAllocator&&) = delete;
41
+ DmaBufferAllocator& operator=(DmaBufferAllocator&&) = delete;
42
+
43
+ bool initialize() override;
44
+ void* getBuffer(Qnn_Tensor_t* tensor) override;
45
+ int getFd(Qnn_Tensor_t* tensor) override;
46
+ size_t getOffset(Qnn_Tensor_t* tensor) override;
47
+ size_t getBufferSize(Qnn_Tensor_t* tensor) override;
48
+ size_t getTotalBufferSize(Qnn_Tensor_t* tensor) override;
49
+
50
+ bool freeTensorBuffer(Qnn_Tensor_t* tensor) override;
51
+
52
+ bool allocateTensorBuffer(Qnn_Tensor_t* tensor, size_t tensorDataSize) override;
53
+ bool useSameMemory(Qnn_Tensor_t* dest, Qnn_Tensor_t* src) override;
54
+
55
+ virtual ~DmaBufferAllocator();
56
+
57
+ bool beforeWriteToBuffer(Qnn_Tensor_t *tensor) override;
58
+ bool afterWriteToBuffer(Qnn_Tensor_t *tensor) override;
59
+ bool beforeReadFromBuffer(Qnn_Tensor_t *tensor) override;
60
+ bool afterReadFromBuffer(Qnn_Tensor_t *tensor) override;
61
+
62
+
63
+ bool useSameMemory(Qnn_Tensor_t* dest, Qnn_Tensor_t* src, int offset) override {
64
+ QNN_WARN("Offset based tensors not supported!!");
65
+ return false;;
66
+ }
67
+ bool useExternalMemory(Qnn_Tensor_t* dest, void* extMem) override {
68
+ QNN_WARN("External Memory not supported!!");
69
+ return false;;
70
+ }
71
+ void* allocateTensorFusedBuffer(uint64_t bufferSize, int32_t* fd) override {
72
+ QNN_WARN("Fused Buffers not supported\n");
73
+ return nullptr;
74
+ };
75
+ bool allocateBuffers(
76
+ const std::map<int, std::map<std::string, size_t>>& allocs_per_chunk,
77
+ std::map<std::string, std::pair<int, size_t>>& tensor_offsets
78
+ ) override {
79
+ QNN_WARN("Fused Buffers not supported\n");
80
+ return false;
81
+ };
82
+ bool mapFusedBufferOffset(
83
+ Qnn_Tensor_t* tensor,
84
+ size_t tensorDataSize,
85
+ int32_t fd,
86
+ uint32_t offset,
87
+ uint64_t totalBufferSize,
88
+ void* memPointer,
89
+ Qnn_ContextHandle_t contextHandle
90
+ ) override {
91
+ QNN_WARN("Fused Buffers not supported\n");
92
+ return false;
93
+ };
94
+ bool deregisterTensorFusedBuffer(Qnn_Tensor_t* tensor) override {
95
+ QNN_WARN("Fused Buffers not supported\n");
96
+ return false;
97
+ };
98
+ void freeFusedBuffers() override {
99
+ return;
100
+ };
101
+ bool mapFusedBufferOffset(
102
+ Qnn_Tensor_t* tensor,
103
+ int alloc_idx,
104
+ size_t offset,
105
+ Qnn_ContextHandle_t ctx,
106
+ size_t size
107
+ ) override {
108
+ QNN_WARN("Fused Buffers not supported\n");
109
+ return false;
110
+ };
111
+
112
+ private:
113
+ DmaBufferData * getDmaBufTensorData(Qnn_Tensor_t* tensor);
114
+
115
+ // Pointer to the dlopen'd libdmabufheap.so shared library which contains
116
+ // dmaBufCreate, dmaBufAlloc, dmaBufDeinit
117
+ void *m_libDmaBufHeapHandle;
118
+ DmaBufCreateFn_t m_dmaBufCreate;
119
+ DmaBufAllocFn_t m_dmaBufAlloc;
120
+ DmaBufDeinitFn_t m_dmaBufDeinit;
121
+
122
+ QNN_INTERFACE_VER_TYPE* m_qnnInterface;
123
+ Qnn_ContextHandle_t m_contextHandle;
124
+
125
+ std::unordered_map<Qnn_Tensor_t *, DmaBufferData> m_tensorToDmaBufferData;
126
+ std::unordered_set<Qnn_Tensor_t*> m_sameMemoryFreeTensors;
127
+ std::unordered_map<Qnn_MemHandle_t, DmaBufferData> m_memHandleToDmaBufMem;
128
+ };
Genie/Genie/src/qualla/engines/qnn-api/IBufferAlloc.hpp CHANGED
@@ -53,4 +53,18 @@ class IBufferAlloc {
53
 
54
  virtual bool deregisterTensorFusedBuffer(Qnn_Tensor_t* tensor) = 0;
55
  virtual void freeFusedBuffers() = 0;
56
- };
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
  virtual bool deregisterTensorFusedBuffer(Qnn_Tensor_t* tensor) = 0;
55
  virtual void freeFusedBuffers() = 0;
56
+
57
+ // Functions to sync memory buffers for Read/Write using DmaBuf.
58
+ virtual bool beforeWriteToBuffer(Qnn_Tensor_t *tensor) {
59
+ return false;
60
+ };
61
+ virtual bool afterWriteToBuffer(Qnn_Tensor_t *tensor) {
62
+ return false;
63
+ };
64
+ virtual bool beforeReadFromBuffer(Qnn_Tensor_t *tensor) {
65
+ return false;
66
+ };
67
+ virtual bool afterReadFromBuffer(Qnn_Tensor_t *tensor) {
68
+ return false;
69
+ };
70
+ };
Genie/Genie/src/qualla/engines/qnn-api/IOTensor.cpp CHANGED
@@ -10,6 +10,9 @@
10
  #include <iostream>
11
 
12
  #include "ClientBuffer.hpp"
 
 
 
13
  #include "IBufferAlloc.hpp"
14
  #include "IOTensor.hpp"
15
  #include "RpcMem.hpp"
@@ -28,6 +31,14 @@ IOTensor::IOTensor(BufferAlloc bufferAllocIn, QNN_INTERFACE_VER_TYPE* qnnInterfa
28
  bool IOTensor::initialize(Qnn_ContextHandle_t contextHandle) {
29
  if (m_bufferAlloc == BufferAlloc::SHARED_BUFFER) {
30
  m_bufferManager = std::unique_ptr<IBufferAlloc>(new RpcMem(contextHandle, m_qnnInterface));
 
 
 
 
 
 
 
 
31
  }
32
 
33
  if (true != m_bufferManager->initialize()) {
@@ -39,7 +50,7 @@ bool IOTensor::initialize(Qnn_ContextHandle_t contextHandle) {
39
  }
40
 
41
  IOTensor::~IOTensor() {
42
- if (m_bufferAlloc == BufferAlloc::SHARED_BUFFER) {
43
  m_bufferManager->freeFusedBuffers();
44
  }
45
  }
@@ -215,6 +226,70 @@ bool IOTensor::setupOutputTensors(
215
  return true;
216
  }
217
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
218
  bool IOTensor::mapFusedBufferOffset(
219
  GraphInfo_t* graph_info,
220
  Qnn_ContextHandle_t context_handle,
 
10
  #include <iostream>
11
 
12
  #include "ClientBuffer.hpp"
13
+ #ifndef _WIN32
14
+ #include "DmaBufAllocator.hpp"
15
+ #endif
16
  #include "IBufferAlloc.hpp"
17
  #include "IOTensor.hpp"
18
  #include "RpcMem.hpp"
 
31
  bool IOTensor::initialize(Qnn_ContextHandle_t contextHandle) {
32
  if (m_bufferAlloc == BufferAlloc::SHARED_BUFFER) {
33
  m_bufferManager = std::unique_ptr<IBufferAlloc>(new RpcMem(contextHandle, m_qnnInterface));
34
+ } else if (m_bufferAlloc == BufferAlloc::DMABUF) {
35
+ #ifdef _WIN32
36
+ return false;
37
+ #else
38
+ m_bufferManager =
39
+ std::unique_ptr<IBufferAlloc>(new DmaBufferAllocator(contextHandle, m_qnnInterface)
40
+ );
41
+ #endif
42
  }
43
 
44
  if (true != m_bufferManager->initialize()) {
 
50
  }
51
 
52
  IOTensor::~IOTensor() {
53
+ if (m_bufferAlloc == BufferAlloc::SHARED_BUFFER || m_bufferAlloc == BufferAlloc::DMABUF) {
54
  m_bufferManager->freeFusedBuffers();
55
  }
56
  }
 
226
  return true;
227
  }
228
 
229
+ // Setup details for Qnn_Tensor_t for execution.
230
+ // Reuse same memory handle for KV input and output tensor.
231
+ bool IOTensor::setupOutputWithSharedTensors(
232
+ Qnn_Tensor_t** tensors,
233
+ std::unordered_map<std::string, void*>& tensorNameToTensorPointer,
234
+ const GraphInfo_t& graphInfo,
235
+ std::unordered_map<std::string, size_t>& tensorsSize,
236
+ Qnn_ContextHandle_t contextHandle,
237
+ std::unordered_map<std::string, Qnn_Tensor_t*> sharedTensorMap
238
+ ) {
239
+ uint32_t tensorCount = graphInfo.numOutputTensors;
240
+ TensorWrapper* tensorWrappers = graphInfo.outputTensors;
241
+ if (nullptr == tensorWrappers) {
242
+ QNN_ERROR("tensorWrappers is nullptr");
243
+ return false;
244
+ }
245
+
246
+ if (0 == tensorCount) {
247
+ QNN_DEBUG("tensor count is 0. Nothing to setup.");
248
+ return true;
249
+ }
250
+
251
+ *tensors = (Qnn_Tensor_t*)calloc(1, tensorCount * sizeof(Qnn_Tensor_t));
252
+ if (nullptr == *tensors) {
253
+ QNN_ERROR("mem alloc failed for *tensors");
254
+ return false;
255
+ }
256
+
257
+ bool returnStatus = true;
258
+ for (size_t tensorIdx = 0; tensorIdx < tensorCount; tensorIdx++) {
259
+ Qnn_Tensor_t wrapperTensor = GET_TENSOR_WRAPPER_TENSOR(tensorWrappers[tensorIdx]);
260
+ auto wrapperTensorName = std::string(GET_TENSOR_WRAPPER_NAME(tensorWrappers[tensorIdx]));
261
+ if (true == returnStatus) {
262
+ (*tensors)[tensorIdx] = QNN_TENSOR_INIT;
263
+ returnStatus = deepCopyQnnTensorInfo(((*tensors) + tensorIdx), &wrapperTensor);
264
+ }
265
+ if (true == returnStatus) {
266
+ if (sharedTensorMap.find(wrapperTensorName) == sharedTensorMap.end()) {
267
+ QNN_DEBUG("IoTensor :: Create Buffer for Tensor %s", wrapperTensorName.c_str());
268
+ size_t tensorDataSize = tensorsSize[wrapperTensorName];
269
+ returnStatus = m_bufferManager->allocateTensorBuffer(
270
+ ((*tensors) + tensorIdx), tensorDataSize
271
+ );
272
+ } else {
273
+ std::string inputName = QNN_TENSOR_GET_NAME(sharedTensorMap[wrapperTensorName]);
274
+ QNN_DEBUG("IoTensor :: Reuse Buffer %s for Tensor %s", inputName.c_str(), wrapperTensorName.c_str());
275
+ returnStatus = m_bufferManager->useSameMemory(
276
+ ((*tensors) + tensorIdx), sharedTensorMap[wrapperTensorName]
277
+ );
278
+ }
279
+ }
280
+ if (true != returnStatus) {
281
+ QNN_ERROR("Failure in setupTensors, cleaning up resources");
282
+ tearDownTensors(*tensors, tensorIdx);
283
+ *tensors = nullptr;
284
+ QNN_ERROR("Failure in setupTensors, done cleaning up resources");
285
+ break;
286
+ } else {
287
+ tensorNameToTensorPointer.insert({wrapperTensorName, ((*tensors) + tensorIdx)});
288
+ }
289
+ }
290
+ return returnStatus;
291
+ }
292
+
293
  bool IOTensor::mapFusedBufferOffset(
294
  GraphInfo_t* graph_info,
295
  Qnn_ContextHandle_t context_handle,
Genie/Genie/src/qualla/engines/qnn-api/IOTensor.hpp CHANGED
@@ -28,6 +28,7 @@
28
  enum class BufferAlloc {
29
  DEFAULT, // malloc based allocator
30
  SHARED_BUFFER, // shared buffer allocator; actual allocator depends on the platform
 
31
  INVALID
32
  };
33
  class IBufferAlloc;
@@ -60,6 +61,16 @@ class IOTensor {
60
  bool skipBufferAllocation = false
61
  );
62
 
 
 
 
 
 
 
 
 
 
 
63
  bool tearDownTensors(Qnn_Tensor_t* tensors, uint32_t tensorCount);
64
 
65
  bool tearDownTensors(std::vector<Qnn_Tensor_t*>& tensors, uint32_t tensorCount);
@@ -146,6 +157,20 @@ class IOTensor {
146
 
147
  std::unordered_set<void*>& getFreeTensorsPointerSet() { return m_freeTensorsPointerSet; }
148
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  private:
150
  BufferAlloc m_bufferAlloc;
151
  QNN_INTERFACE_VER_TYPE* m_qnnInterface;
 
28
  enum class BufferAlloc {
29
  DEFAULT, // malloc based allocator
30
  SHARED_BUFFER, // shared buffer allocator; actual allocator depends on the platform
31
+ DMABUF, // dma buffer allocator
32
  INVALID
33
  };
34
  class IBufferAlloc;
 
61
  bool skipBufferAllocation = false
62
  );
63
 
64
+ bool setupOutputWithSharedTensors(
65
+ Qnn_Tensor_t** outputs,
66
+ std::unordered_map<std::string, void*>& tensorNameToTensorPointer,
67
+ const GraphInfo_t& graphInfo,
68
+ std::unordered_map<std::string, size_t>& outputTensorsSize,
69
+ Qnn_ContextHandle_t contextHandle,
70
+ std::unordered_map<std::string, Qnn_Tensor_t *> sharedTensorMap
71
+ );
72
+
73
+
74
  bool tearDownTensors(Qnn_Tensor_t* tensors, uint32_t tensorCount);
75
 
76
  bool tearDownTensors(std::vector<Qnn_Tensor_t*>& tensors, uint32_t tensorCount);
 
157
 
158
  std::unordered_set<void*>& getFreeTensorsPointerSet() { return m_freeTensorsPointerSet; }
159
 
160
+ // Functions to sync memory buffers for Read/Write using DmaBuf.
161
+ bool beforeWriteToBuffer(Qnn_Tensor_t *tensor) {
162
+ return m_bufferManager->beforeWriteToBuffer(tensor);
163
+ }
164
+ bool afterWriteToBuffer(Qnn_Tensor_t *tensor){
165
+ return m_bufferManager->afterWriteToBuffer(tensor);
166
+ }
167
+ bool beforeReadFromBuffer(Qnn_Tensor_t *tensor){
168
+ return m_bufferManager->beforeReadFromBuffer(tensor);
169
+ }
170
+ bool afterReadFromBuffer(Qnn_Tensor_t *tensor){
171
+ return m_bufferManager->afterReadFromBuffer(tensor);
172
+ }
173
+
174
  private:
175
  BufferAlloc m_bufferAlloc;
176
  QNN_INTERFACE_VER_TYPE* m_qnnInterface;
Genie/Genie/src/qualla/engines/qnn-api/QnnApi.cpp CHANGED
@@ -106,11 +106,17 @@ bool QnnApi::getContextConfigs(
106
  ) {
107
  std::vector<QnnContext_Config_t*> contextConfigPtrsVec;
108
 
109
- if (contextPriority != QNN_PRIORITY_DEFAULT) {
110
- contextConfigPtrsVec.push_back((QnnContext_Config_t*)malloc(sizeof(QnnContext_Config_t)));
111
  contextConfigPtrsVec.back()->option =
112
- QnnContext_ConfigOption_t::QNN_CONTEXT_CONFIG_OPTION_PRIORITY;
113
- contextConfigPtrsVec.back()->priority = contextPriority;
 
 
 
 
 
 
114
  }
115
 
116
  const char** graphNames = nullptr;
@@ -891,6 +897,8 @@ bool QnnApi::composeGraphs(
891
  QnnLog_Level_t::QNN_LOG_LEVEL_VERBOSE
892
  );
893
 
 
 
894
  if (status == MODEL_NO_ERROR) {
895
  return true;
896
  }
@@ -1163,33 +1171,6 @@ bool QnnApi::createFromBinary(
1163
  }
1164
  }
1165
 
1166
- QnnContext_Config_t** contextConfigs = nullptr;
1167
- uint32_t contextConfigCount = 0;
1168
- if (true != getContextConfigs(
1169
- &contextConfigs,
1170
- contextConfigCount,
1171
- contextConfig.priority,
1172
- graphSwitching,
1173
- execSelectGraphs,
1174
- loadSelectGraphs
1175
- )) {
1176
- QNN_ERROR("Couldn't populate context configs");
1177
- return false;
1178
- }
1179
-
1180
- // Merge BE specific and agnostic configs
1181
- QnnContext_Config_t** allContextConfigs{nullptr};
1182
- if (true != mergeAllContextConfigs(
1183
- &allContextConfigs,
1184
- customConfigs,
1185
- contextConfigs,
1186
- customConfigCount,
1187
- contextConfigCount
1188
- )) {
1189
- QNN_ERROR("Error merging custom and context configs");
1190
- return false;
1191
- }
1192
-
1193
  if (nullptr == m_qnnSystemInterface.systemContextCreate ||
1194
  nullptr == m_qnnSystemInterface.systemContextGetBinaryInfo ||
1195
  nullptr == m_qnnSystemInterface.systemContextFree) {
@@ -1299,9 +1280,36 @@ bool QnnApi::createFromBinary(
1299
  }
1300
 
1301
  bool isIOBufferMgrInitialized = false;
1302
-
1303
  for (size_t contextIdx = 0; contextIdx < cachedBinariesPathVec.size(); contextIdx++) {
1304
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1305
  if (nullptr == m_qnnInterface.contextCreateFromBinary) {
1306
  QNN_ERROR(
1307
  "contextCreateFromBinaryFnHandle is nullptr for context index = %zu", contextIdx
@@ -1498,7 +1506,13 @@ bool QnnApi::createFromBinary(
1498
  first_contextHandle = contextHandle;
1499
  }
1500
  #endif
1501
-
 
 
 
 
 
 
1502
  }
1503
 
1504
  m_isContextCreated = true;
@@ -1507,14 +1521,6 @@ bool QnnApi::createFromBinary(
1507
  "Initialized %u graphs from %lu contexts", m_graphsCount, cachedBinariesPathVec.size()
1508
  );
1509
 
1510
- if (true != freeContextConfigs(contextConfigs, contextConfigCount)) {
1511
- QNN_ERROR("Couldn't free context configs");
1512
- return false;
1513
- }
1514
- if (allContextConfigs) {
1515
- free(allContextConfigs);
1516
- }
1517
-
1518
  if (nullptr != m_backendExtensions && m_backendExtensions->interface()) {
1519
  if (!m_backendExtensions->interface()->afterCreateFromBinary()) {
1520
  QNN_ERROR("Extensions Failure in afterCreateFromBinary()");
@@ -1599,34 +1605,6 @@ bool QnnApi::createFromBinaryListAsync(
1599
  }
1600
  }
1601
 
1602
-
1603
- QnnContext_Config_t** contextConfigs = nullptr;
1604
- uint32_t contextConfigCount = 0;
1605
- if (true != getContextConfigs(
1606
- &contextConfigs,
1607
- contextConfigCount,
1608
- contextConfig.priority,
1609
- graphSwitching,
1610
- execSelectGraphs,
1611
- loadSelectGraphs
1612
- )) {
1613
- QNN_ERROR("Couldn't populate context configs");
1614
- return false;
1615
- }
1616
-
1617
- // Merge BE specific and agnostic configs
1618
- QnnContext_Config_t** allContextConfigs{nullptr};
1619
- if (true != mergeAllContextConfigs(
1620
- &allContextConfigs,
1621
- customConfigs,
1622
- contextConfigs,
1623
- customConfigCount,
1624
- contextConfigCount
1625
- )) {
1626
- QNN_ERROR("Error merging custom and context configs");
1627
- return false;
1628
- }
1629
-
1630
  if (nullptr == m_qnnSystemInterface.systemContextCreate ||
1631
  nullptr == m_qnnSystemInterface.systemContextGetBinaryInfo ||
1632
  nullptr == m_qnnSystemInterface.systemContextFree) {
@@ -1642,6 +1620,8 @@ bool QnnApi::createFromBinaryListAsync(
1642
  GraphInfo_t*** graphsInfo =
1643
  (GraphInfo_t***)calloc(cachedBinariesPathVec.size(), sizeof(GraphInfo_t**));
1644
  uint32_t graphsTotalNum = 0;
 
 
1645
 
1646
  for (size_t contextIdx = 0; contextIdx < cachedBinariesPathVec.size(); contextIdx++) {
1647
  auto _startPerContext = std::chrono::steady_clock::now();
@@ -1710,17 +1690,41 @@ bool QnnApi::createFromBinaryListAsync(
1710
  m_qnnSystemInterface.systemContextFree(sysCtxHandle);
1711
  sysCtxHandle = nullptr;
1712
 
1713
- uint32_t customConfigCountSF = 0;
 
 
 
 
 
 
 
 
 
 
 
 
1714
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1715
  if (mmap_budget > 0) {
1716
  QnnHtpContext_CustomConfig_t customConfigReadBudget;
1717
  customConfigReadBudget.option = QNN_HTP_CONTEXT_CONFIG_OPTION_FILE_READ_MEMORY_BUDGET;
1718
  customConfigReadBudget.fileReadMemoryBudgetInMb = mmap_budget;
1719
 
1720
  QnnContext_Config_t** cfgs{nullptr};
1721
-
1722
  uint32_t customConfigCountReadBudget = 1;
1723
-
1724
  cfgs = (QnnContext_Config_t**)malloc(
1725
  customConfigCountReadBudget * sizeof(QnnContext_Config_t*)
1726
  );
@@ -1729,15 +1733,16 @@ bool QnnApi::createFromBinaryListAsync(
1729
  cfgs[0]->customConfig =
1730
  reinterpret_cast<QnnContext_CustomConfig_t>(&customConfigReadBudget);
1731
  if (true != mergeAllContextConfigs(
1732
- &allContextConfigs,
1733
  cfgs,
1734
- allContextConfigs,
1735
  customConfigCountReadBudget,
1736
  contextConfigCount + customConfigCount + customConfigCountSF
1737
  )) {
1738
  QNN_ERROR("Error merging custom and context configs");
1739
  return false;
1740
  }
 
1741
  }
1742
 
1743
  if (m_profileBackendHandle) {
@@ -1751,7 +1756,7 @@ bool QnnApi::createFromBinaryListAsync(
1751
  .version = QNN_CONTEXT_PARAMS_VERSION_1,
1752
  .v1 =
1753
  QnnContext_ParamsV1_t{
1754
- (const QnnContext_Config_t**)allContextConfigs,
1755
  (const void*)buffer.get(),
1756
  bufferSize,
1757
  nullptr,
@@ -1778,18 +1783,15 @@ bool QnnApi::createFromBinaryListAsync(
1778
  }
1779
 
1780
  auto start = std::chrono::steady_clock::now();
1781
-
1782
-
1783
  auto errCode = m_qnnInterface.contextCreateFromBinaryListAsync(
1784
  m_backendHandle,
1785
  m_deviceHandle,
1786
  const_cast<const QnnContext_Params_t**>(context_params_list.data()),
1787
- (const QnnContext_Config_t**)allContextConfigs,
1788
  nullptr
1789
  );
1790
-
1791
-
1792
  auto stop = std::chrono::steady_clock::now();
 
1793
  QNN_DEBUG(
1794
  "Initializing %lu context with %u graphs took: %lld us",
1795
  cachedBinariesPathVec.size(),
@@ -1824,26 +1826,24 @@ bool QnnApi::createFromBinaryListAsync(
1824
 
1825
  m_isContextCreated = true;
1826
 
1827
- if (true != freeContextConfigs(contextConfigs, contextConfigCount)) {
1828
- QNN_ERROR("Couldn't free context configs");
1829
- return false;
1830
- }
1831
-
1832
  if (true != freeContextParams(context_params_list.data(), cachedBinariesPathVec.size())) {
1833
  QNN_ERROR("Couldn't free context params list");
1834
  return false;
1835
  }
1836
 
1837
- if (allContextConfigs) {
1838
- free(allContextConfigs);
1839
- }
1840
-
1841
  if (nullptr != m_backendExtensions && m_backendExtensions->interface()) {
1842
  if (!m_backendExtensions->interface()->afterCreateContextsFromBinaryList()) {
1843
  QNN_ERROR("Extensions Failure in afterCreateContextsFromBinaryList()");
1844
  return false;
1845
  }
1846
  }
 
 
 
 
 
 
 
1847
  return true;
1848
  }
1849
  #endif
@@ -2543,6 +2543,64 @@ bool QnnApi::extractProfilingEvent(QnnProfile_EventId_t profileEventId) {
2543
  return true;
2544
  }
2545
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2546
  bool QnnApi::applyBinarySection(uint32_t binIndex, std::string binSectionPath,bool useMmap,bool graphSwitch) {
2547
  #if QUALLA_QNN_API_VERSION < 21700
2548
  QNN_ERROR("LoRA adaptors require QNN SDK >= 2.25.1. Please update your libraries");
@@ -2650,7 +2708,7 @@ bool QnnApi::applyBinarySection(uint32_t binIndex, std::string binSectionPath,bo
2650
  #endif
2651
  }
2652
 
2653
- bool QnnApi::updateIOEncodings(std::shared_ptr<uint8_t>& buffer,uint64_t bufferSize,uint32_t graphIndex){
2654
 
2655
  QNN_DEBUG("Applying adapter Encodings");
2656
  QnnSystemContext_Handle_t sysCtxHandle{nullptr};
@@ -2679,3 +2737,224 @@ bool QnnApi::updateIOEncodings(std::shared_ptr<uint8_t>& buffer,uint64_t buffer
2679
  QNN_DEBUG(" updateIOEncodings success ");
2680
  return true;
2681
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  ) {
107
  std::vector<QnnContext_Config_t*> contextConfigPtrsVec;
108
 
109
+ if (contextPriority == QNN_PRIORITY_UNDEFINED) {
110
+ contextConfigPtrsVec.push_back((QnnContext_Config_t *) malloc(sizeof(QnnContext_Config_t)));
111
  contextConfigPtrsVec.back()->option =
112
+ QnnContext_ConfigOption_t::QNN_CONTEXT_CONFIG_UNDEFINED;
113
+ } else {
114
+ if (contextPriority != QNN_PRIORITY_DEFAULT) {
115
+ contextConfigPtrsVec.push_back((QnnContext_Config_t *) malloc(sizeof(QnnContext_Config_t)));
116
+ contextConfigPtrsVec.back()->option =
117
+ QnnContext_ConfigOption_t::QNN_CONTEXT_CONFIG_OPTION_PRIORITY;
118
+ contextConfigPtrsVec.back()->priority = contextPriority;
119
+ }
120
  }
121
 
122
  const char** graphNames = nullptr;
 
897
  QnnLog_Level_t::QNN_LOG_LEVEL_VERBOSE
898
  );
899
 
900
+ graphCountPerContext = m_graphsCount;
901
+
902
  if (status == MODEL_NO_ERROR) {
903
  return true;
904
  }
 
1171
  }
1172
  }
1173
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1174
  if (nullptr == m_qnnSystemInterface.systemContextCreate ||
1175
  nullptr == m_qnnSystemInterface.systemContextGetBinaryInfo ||
1176
  nullptr == m_qnnSystemInterface.systemContextFree) {
 
1280
  }
1281
 
1282
  bool isIOBufferMgrInitialized = false;
 
1283
  for (size_t contextIdx = 0; contextIdx < cachedBinariesPathVec.size(); contextIdx++) {
1284
 
1285
+ // Create context configs for each context
1286
+ QnnContext_Config_t** contextConfigs = nullptr;
1287
+ uint32_t contextConfigCount = 0;
1288
+ if (true != getContextConfigs(
1289
+ &contextConfigs,
1290
+ contextConfigCount,
1291
+ contextConfig.priority,
1292
+ graphSwitching,
1293
+ execSelectGraphs,
1294
+ loadSelectGraphs
1295
+ )) {
1296
+ QNN_ERROR("Couldn't populate context configs");
1297
+ return false;
1298
+ }
1299
+
1300
+ // Merge BE specific and agnostic configs
1301
+ QnnContext_Config_t** allContextConfigs{nullptr};
1302
+ if (true != mergeAllContextConfigs(
1303
+ &allContextConfigs,
1304
+ customConfigs,
1305
+ contextConfigs,
1306
+ customConfigCount,
1307
+ contextConfigCount
1308
+ )) {
1309
+ QNN_ERROR("Error merging custom and context configs");
1310
+ return false;
1311
+ }
1312
+
1313
  if (nullptr == m_qnnInterface.contextCreateFromBinary) {
1314
  QNN_ERROR(
1315
  "contextCreateFromBinaryFnHandle is nullptr for context index = %zu", contextIdx
 
1506
  first_contextHandle = contextHandle;
1507
  }
1508
  #endif
1509
+ if (true != freeContextConfigs(contextConfigs, contextConfigCount)) {
1510
+ QNN_ERROR("Couldn't free context configs");
1511
+ return false;
1512
+ }
1513
+ if (allContextConfigs) {
1514
+ free(allContextConfigs);
1515
+ }
1516
  }
1517
 
1518
  m_isContextCreated = true;
 
1521
  "Initialized %u graphs from %lu contexts", m_graphsCount, cachedBinariesPathVec.size()
1522
  );
1523
 
 
 
 
 
 
 
 
 
1524
  if (nullptr != m_backendExtensions && m_backendExtensions->interface()) {
1525
  if (!m_backendExtensions->interface()->afterCreateFromBinary()) {
1526
  QNN_ERROR("Extensions Failure in afterCreateFromBinary()");
 
1605
  }
1606
  }
1607
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1608
  if (nullptr == m_qnnSystemInterface.systemContextCreate ||
1609
  nullptr == m_qnnSystemInterface.systemContextGetBinaryInfo ||
1610
  nullptr == m_qnnSystemInterface.systemContextFree) {
 
1620
  GraphInfo_t*** graphsInfo =
1621
  (GraphInfo_t***)calloc(cachedBinariesPathVec.size(), sizeof(GraphInfo_t**));
1622
  uint32_t graphsTotalNum = 0;
1623
+ std::vector<QnnContext_Config_t**> allContextConfigs{(unsigned int)cachedBinariesPathVec.size(), nullptr};
1624
+ std::vector<uint32_t> allContextConfigsSize{(unsigned int)cachedBinariesPathVec.size()};
1625
 
1626
  for (size_t contextIdx = 0; contextIdx < cachedBinariesPathVec.size(); contextIdx++) {
1627
  auto _startPerContext = std::chrono::steady_clock::now();
 
1690
  m_qnnSystemInterface.systemContextFree(sysCtxHandle);
1691
  sysCtxHandle = nullptr;
1692
 
1693
+ uint32_t contextConfigCount = 0;
1694
+ if (true != getContextConfigs(
1695
+ &allContextConfigs[contextIdx],
1696
+ contextConfigCount,
1697
+ contextConfig.priority,
1698
+ graphSwitching,
1699
+ execSelectGraphs,
1700
+ loadSelectGraphs
1701
+ )) {
1702
+ QNN_ERROR("Couldn't populate context configs");
1703
+ return false;
1704
+ }
1705
+ allContextConfigsSize[contextIdx] = contextConfigCount;
1706
 
1707
+ // Merge BE specific and agnostic configs
1708
+ if (true != mergeAllContextConfigs(
1709
+ &allContextConfigs[contextIdx],
1710
+ customConfigs,
1711
+ allContextConfigs[contextIdx],
1712
+ customConfigCount,
1713
+ contextConfigCount
1714
+ )) {
1715
+ QNN_ERROR("Error merging custom and context configs");
1716
+ return false;
1717
+ }
1718
+ allContextConfigsSize[contextIdx] += customConfigCount;
1719
+
1720
+ uint32_t customConfigCountSF = 0;
1721
  if (mmap_budget > 0) {
1722
  QnnHtpContext_CustomConfig_t customConfigReadBudget;
1723
  customConfigReadBudget.option = QNN_HTP_CONTEXT_CONFIG_OPTION_FILE_READ_MEMORY_BUDGET;
1724
  customConfigReadBudget.fileReadMemoryBudgetInMb = mmap_budget;
1725
 
1726
  QnnContext_Config_t** cfgs{nullptr};
 
1727
  uint32_t customConfigCountReadBudget = 1;
 
1728
  cfgs = (QnnContext_Config_t**)malloc(
1729
  customConfigCountReadBudget * sizeof(QnnContext_Config_t*)
1730
  );
 
1733
  cfgs[0]->customConfig =
1734
  reinterpret_cast<QnnContext_CustomConfig_t>(&customConfigReadBudget);
1735
  if (true != mergeAllContextConfigs(
1736
+ &allContextConfigs[contextIdx],
1737
  cfgs,
1738
+ allContextConfigs[contextIdx],
1739
  customConfigCountReadBudget,
1740
  contextConfigCount + customConfigCount + customConfigCountSF
1741
  )) {
1742
  QNN_ERROR("Error merging custom and context configs");
1743
  return false;
1744
  }
1745
+ allContextConfigsSize[contextIdx] += customConfigCountReadBudget;
1746
  }
1747
 
1748
  if (m_profileBackendHandle) {
 
1756
  .version = QNN_CONTEXT_PARAMS_VERSION_1,
1757
  .v1 =
1758
  QnnContext_ParamsV1_t{
1759
+ (const QnnContext_Config_t**)allContextConfigs[contextIdx],
1760
  (const void*)buffer.get(),
1761
  bufferSize,
1762
  nullptr,
 
1783
  }
1784
 
1785
  auto start = std::chrono::steady_clock::now();
 
 
1786
  auto errCode = m_qnnInterface.contextCreateFromBinaryListAsync(
1787
  m_backendHandle,
1788
  m_deviceHandle,
1789
  const_cast<const QnnContext_Params_t**>(context_params_list.data()),
1790
+ (const QnnContext_Config_t**)customConfigs,
1791
  nullptr
1792
  );
 
 
1793
  auto stop = std::chrono::steady_clock::now();
1794
+
1795
  QNN_DEBUG(
1796
  "Initializing %lu context with %u graphs took: %lld us",
1797
  cachedBinariesPathVec.size(),
 
1826
 
1827
  m_isContextCreated = true;
1828
 
 
 
 
 
 
1829
  if (true != freeContextParams(context_params_list.data(), cachedBinariesPathVec.size())) {
1830
  QNN_ERROR("Couldn't free context params list");
1831
  return false;
1832
  }
1833
 
 
 
 
 
1834
  if (nullptr != m_backendExtensions && m_backendExtensions->interface()) {
1835
  if (!m_backendExtensions->interface()->afterCreateContextsFromBinaryList()) {
1836
  QNN_ERROR("Extensions Failure in afterCreateContextsFromBinaryList()");
1837
  return false;
1838
  }
1839
  }
1840
+
1841
+ for (size_t contextIdx = 0; contextIdx < cachedBinariesPathVec.size(); contextIdx++) {
1842
+ if (true != freeContextConfigs(allContextConfigs[contextIdx], allContextConfigsSize[contextIdx])) {
1843
+ QNN_ERROR("Couldn't free context configs");
1844
+ return false;
1845
+ }
1846
+ }
1847
  return true;
1848
  }
1849
  #endif
 
2543
  return true;
2544
  }
2545
 
2546
+ bool QnnApi::applyBinarySection(uint32_t graphId, std::string binSectionPath) {
2547
+ #if QUALLA_QNN_API_VERSION < 21700
2548
+ QNN_ERROR("LoRA adaptors require QNN SDK >= 2.25.1. Please update your libraries");
2549
+ return false;
2550
+ #else
2551
+ // assumption splitNum from 0
2552
+ QNN_DEBUG("QnnApi::applyBinarySection %d ", graphId);
2553
+ if (nullptr == m_qnnInterface.contextApplyBinarySection) {
2554
+ QNN_ERROR("contextApplyBinarySection Interface not suported!!");
2555
+ return false;
2556
+ }
2557
+ if (graphId >= m_graphsCount) {
2558
+ QNN_ERROR(" Passed split %d base Model graphcount %d ", graphId, m_graphsCount);
2559
+ return false;
2560
+ }
2561
+ uint64_t bufferSize{0};
2562
+ std::shared_ptr<uint8_t> buffer{nullptr};
2563
+ bufferSize = getFileSize(binSectionPath);
2564
+ buffer = std::shared_ptr<uint8_t>(new uint8_t[bufferSize]);
2565
+ if (true != readBinaryFromFile(binSectionPath, buffer.get(), bufferSize)) {
2566
+ QNN_ERROR("Failed to read binary data for context index = %d", graphId);
2567
+ return false;
2568
+ }
2569
+
2570
+ QnnContext_Buffer_t qnnBuffer;
2571
+ qnnBuffer.version = QNN_CONTEXT_BUFFER_VERSION_1;
2572
+ qnnBuffer.v1.memType = QNN_CONTEXTMEMTYPE_RAW;
2573
+ qnnBuffer.v1.binaryBuf.dataSize = bufferSize;
2574
+ qnnBuffer.v1.binaryBuf.data = static_cast<void*>(buffer.get());
2575
+ auto graphCountPerContext = getGraphCountPerContext();
2576
+ if (graphCountPerContext <= 0) {
2577
+ QNN_ERROR(" graphCountPerContext is <=0 ");
2578
+ return false;
2579
+ }
2580
+
2581
+ auto contextHandle = m_contextVec[graphId / graphCountPerContext];
2582
+ auto graphHandle = m_graphsInfo[graphId]->graph;
2583
+ if (contextHandle == nullptr || graphHandle == nullptr) {
2584
+ QNN_ERROR(" contexthandle or graph handle is null for patch no = %d ", graphId);
2585
+ return false;
2586
+ }
2587
+
2588
+ auto errorCode = m_qnnInterface.contextApplyBinarySection(
2589
+ contextHandle,
2590
+ graphHandle,
2591
+ QNN_CONTEXT_SECTION_UPDATABLE,
2592
+ &qnnBuffer,
2593
+ nullptr, //profile handle is null
2594
+ nullptr //singal handle is null
2595
+ );
2596
+ if (errorCode != QNN_SUCCESS) {
2597
+ QNN_ERROR("Could not Apply Patch for graph = %d errocode = %zu ", graphId, errorCode);
2598
+ return false;
2599
+ }
2600
+ return true;
2601
+ #endif
2602
+ }
2603
+
2604
  bool QnnApi::applyBinarySection(uint32_t binIndex, std::string binSectionPath,bool useMmap,bool graphSwitch) {
2605
  #if QUALLA_QNN_API_VERSION < 21700
2606
  QNN_ERROR("LoRA adaptors require QNN SDK >= 2.25.1. Please update your libraries");
 
2708
  #endif
2709
  }
2710
 
2711
+ bool QnnApi::updateIOEncodings(std::shared_ptr<uint8_t>& buffer,uint64_t bufferSize,uint32_t graphIndex) {
2712
 
2713
  QNN_DEBUG("Applying adapter Encodings");
2714
  QnnSystemContext_Handle_t sysCtxHandle{nullptr};
 
2737
  QNN_DEBUG(" updateIOEncodings success ");
2738
  return true;
2739
  }
2740
+
2741
+ // This is a light weight function of existing ::createFromBinary, used for
2742
+ // GPU execution to avoid conflicts with HTP use-case and for better readability.
2743
+ bool QnnApi::createFromBinary(
2744
+ std::vector<std::string> cachedBinariesPathVec
2745
+ ) {
2746
+ auto _start = std::chrono::steady_clock::now();
2747
+
2748
+ if (nullptr == m_qnnSystemInterface.systemContextCreate ||
2749
+ nullptr == m_qnnSystemInterface.systemContextGetBinaryInfo ||
2750
+ nullptr == m_qnnSystemInterface.systemContextFree) {
2751
+ QNN_ERROR("QNN System function pointers are not populated.");
2752
+ return false;
2753
+ }
2754
+
2755
+ graphCountPerContext = getGraphCountPerContext();
2756
+
2757
+ for (size_t contextIdx = 0; contextIdx < cachedBinariesPathVec.size(); contextIdx++) {
2758
+ uint64_t bufferSize{0};
2759
+ std::shared_ptr<uint8_t> buffer{nullptr};
2760
+ uint32_t graphsCount;
2761
+
2762
+ // read serialized binary into a byte buffer
2763
+ bufferSize = getFileSize(cachedBinariesPathVec[contextIdx]);
2764
+ if (0 == bufferSize) {
2765
+ QNN_ERROR(
2766
+ "Received path to an empty file for context index = %zu. Nothing to deserialize.",
2767
+ contextIdx
2768
+ );
2769
+ return false;
2770
+ }
2771
+
2772
+ buffer = std::shared_ptr<uint8_t>(
2773
+ new uint8_t[bufferSize], std::default_delete<uint8_t[]>()
2774
+ );
2775
+ if (!buffer) {
2776
+ QNN_ERROR("Failed to allocate memory for context index = %zu", contextIdx);
2777
+ return false;
2778
+ }
2779
+ if (true !=
2780
+ readBinaryFromFile(cachedBinariesPathVec[contextIdx], buffer.get(), bufferSize)) {
2781
+ QNN_ERROR("Failed to read binary data for context index = %zu", contextIdx);
2782
+ return false;
2783
+ }
2784
+
2785
+ // inspect binary info
2786
+ QnnSystemContext_Handle_t sysCtxHandle{nullptr};
2787
+ if (QNN_SUCCESS != m_qnnSystemInterface.systemContextCreate(&sysCtxHandle)) {
2788
+ QNN_ERROR("Could not create system handle for context index = %zu", contextIdx);
2789
+ return false;
2790
+ }
2791
+
2792
+ const QnnSystemContext_BinaryInfo_t* binaryInfo{nullptr};
2793
+ Qnn_ContextBinarySize_t binaryInfoSize{0};
2794
+
2795
+ if (QNN_SUCCESS != m_qnnSystemInterface.systemContextGetBinaryInfo(
2796
+ sysCtxHandle,
2797
+ static_cast<void*>(buffer.get()),
2798
+ bufferSize,
2799
+ &binaryInfo,
2800
+ &binaryInfoSize
2801
+ )) {
2802
+ QNN_ERROR("Failed to get context binary info for context index = %zu", contextIdx);
2803
+ return false;
2804
+ }
2805
+
2806
+ GraphInfo_t** graphsInfo;
2807
+ if (!copyMetadataToGraphsInfo(binaryInfo, graphsInfo, graphsCount)) {
2808
+ QNN_ERROR("Failed to copy metadata for graph index = %zu", contextIdx);
2809
+ freeGraphsInfo(&graphsInfo, graphsCount);
2810
+ if (contextIdx > 0) freeGraphsInfo(&m_graphsInfo, m_graphsCount);
2811
+ return false;
2812
+ }
2813
+
2814
+ if (graphCountPerContext == -1) {
2815
+ graphCountPerContext = graphsCount;
2816
+ m_graphsInfo = (GraphInfo_t**)calloc(
2817
+ graphCountPerContext * cachedBinariesPathVec.size(), sizeof(GraphInfo_t*)
2818
+ );
2819
+ } else if (graphCountPerContext != graphsCount) {
2820
+ QNN_ERROR(
2821
+ "Different len(graphs) found in different context files. Found %u vs %u",
2822
+ graphsCount,
2823
+ graphCountPerContext
2824
+ );
2825
+ freeGraphsInfo(&graphsInfo, graphsCount);
2826
+ if (contextIdx > 0) freeGraphsInfo(&m_graphsInfo, m_graphsCount);
2827
+ return false;
2828
+ }
2829
+ m_qnnSystemInterface.systemContextFree(sysCtxHandle);
2830
+ sysCtxHandle = nullptr;
2831
+
2832
+ if (nullptr == m_qnnInterface.contextCreateFromBinary) {
2833
+ QNN_ERROR(
2834
+ "contextCreateFromBinaryFnHandle is nullptr for context index = %zu", contextIdx
2835
+ );
2836
+ freeGraphsInfo(&graphsInfo, graphsCount);
2837
+ if (contextIdx > 0) freeGraphsInfo(&m_graphsInfo, m_graphsCount);
2838
+ return false;
2839
+ }
2840
+ Qnn_ContextHandle_t contextHandle{nullptr};
2841
+ auto _stop = std::chrono::steady_clock::now();
2842
+ QNN_DEBUG(
2843
+ "Loading contexts[%lu] took: %lld us",
2844
+ contextIdx,
2845
+ std::chrono::duration_cast<std::chrono::microseconds>(_stop - _start).count()
2846
+ );
2847
+
2848
+ auto start = std::chrono::steady_clock::now();
2849
+
2850
+ auto errCode = m_qnnInterface.contextCreateFromBinary(
2851
+ m_backendHandle,
2852
+ m_deviceHandle,
2853
+ nullptr,
2854
+ (const void*)buffer.get(),
2855
+ bufferSize,
2856
+ &contextHandle,
2857
+ nullptr // profile handle
2858
+
2859
+ );
2860
+
2861
+ if (errCode != QNN_SUCCESS) {
2862
+ QNN_ERROR(
2863
+ "Could not create context from binary for context index = %zu : err %d",
2864
+ contextIdx,
2865
+ (int)errCode
2866
+ );
2867
+ freeGraphsInfo(&graphsInfo, graphsCount);
2868
+ if (contextIdx > 0) freeGraphsInfo(&m_graphsInfo, m_graphsCount);
2869
+ return false;
2870
+ }
2871
+
2872
+ auto stop = std::chrono::steady_clock::now();
2873
+ QNN_DEBUG(
2874
+ "Initializing context[%lu] with %u graphs took: %lld us",
2875
+ contextIdx,
2876
+ graphsCount,
2877
+ std::chrono::duration_cast<std::chrono::microseconds>(stop - start).count()
2878
+ );
2879
+
2880
+ for (int n_graph = 0; n_graph < graphsCount; n_graph++) {
2881
+ // Allocate inputTensors and outputTensors
2882
+ GraphInfo_t* cur_graph = graphsInfo[n_graph];
2883
+
2884
+ m_graphsInfo[m_graphsCount++] = cur_graph;
2885
+ m_contextMap[cur_graph] = contextHandle;
2886
+ }
2887
+ m_contextVec.push_back(contextHandle);
2888
+ }
2889
+
2890
+ m_isContextCreated = true;
2891
+
2892
+ QNN_DEBUG(
2893
+ "Initialized %u graphs from %lu contexts", m_graphsCount, cachedBinariesPathVec.size()
2894
+ );
2895
+
2896
+ if (nullptr == m_qnnInterface.graphRetrieve) {
2897
+ QNN_ERROR("graphRetrieveFnHandle is nullptr.");
2898
+ freeGraphsInfo(&m_graphsInfo, m_graphsCount);
2899
+ return false;
2900
+ }
2901
+
2902
+ for (size_t graphIdx = 0; graphIdx < m_graphsCount; graphIdx++) {
2903
+ if (!m_graphsInfo || QNN_SUCCESS != m_qnnInterface.graphRetrieve(
2904
+ m_contextVec[graphIdx / graphCountPerContext],
2905
+ m_graphsInfo[graphIdx]->graphName,
2906
+ &(m_graphsInfo[graphIdx]->graph)
2907
+ )) {
2908
+ QNN_ERROR("Unable to retrieve graph handle for graph index = %zu", graphIdx);
2909
+ freeGraphsInfo(&m_graphsInfo, m_graphsCount);
2910
+ return false;
2911
+ }
2912
+ }
2913
+
2914
+ return true;
2915
+ }
2916
+
2917
+ bool QnnApi::initialize(
2918
+ std::string backendPath,
2919
+ std::vector<std::string> modelPathOrCachedBinaryPath
2920
+ ) {
2921
+ if (modelPathOrCachedBinaryPath.size() != 1) {
2922
+ QNN_ERROR("Multiple Files not supported for now!!");
2923
+ return false;
2924
+ }
2925
+
2926
+ if (false == getQnnInterface(backendPath)) {
2927
+ QNN_ERROR("Qnn getQnnInterface FAILED!");
2928
+ return false;
2929
+ }
2930
+
2931
+ const std::string systemLibraryPath = "libQnnSystem.so";
2932
+ if (false == getQnnSystemInterface(systemLibraryPath)) {
2933
+ QNN_ERROR("Qnn getQnnSystemInterface FAILED!");
2934
+ return false;
2935
+ }
2936
+
2937
+ QnnLog_Level_t logLevel = QNN_LOG_LEVEL_INFO;
2938
+ if (false == initializeLogging(logLevel, false)) {
2939
+ QNN_ERROR("Unable to Initialize logging in backend");
2940
+ return false;
2941
+ }
2942
+
2943
+ // Initialize Backend
2944
+ if (false == initializeBackend()) {
2945
+ QNN_ERROR("Qnn initializeBackend FAILED!");
2946
+ return false;
2947
+ }
2948
+
2949
+ if (false == createFromBinary(modelPathOrCachedBinaryPath)) {
2950
+ QNN_ERROR("Create From Binary FAILED!");
2951
+ return false;
2952
+ }
2953
+
2954
+ for (size_t graphIdx = 0; graphIdx < m_graphsCount; graphIdx++) {
2955
+ m_graphNameToIndex[m_graphsInfo[graphIdx]->graphName] = graphIdx;
2956
+ }
2957
+ QNN_DEBUG("Model Initialized");
2958
+
2959
+ return true;
2960
+ }
Genie/Genie/src/qualla/engines/qnn-api/QnnApi.hpp CHANGED
@@ -370,6 +370,8 @@ class QnnApi {
370
 
371
  bool applyBinarySection(uint32_t binIndex, std::string binSectionPath,bool useMmap,bool graphSwitch);
372
 
 
 
373
  QNN_INTERFACE_VER_TYPE* getQnnInterfaceVer() { return &m_qnnInterface; };
374
  GraphInfo_t**& getGraphsInfo() { return m_graphsInfo; };
375
  uint32_t getGraphsCount() { return m_graphsCount; };
@@ -426,4 +428,11 @@ class QnnApi {
426
  bool updateIOEncodings(std::shared_ptr<uint8_t>& buffer,
427
  uint64_t bufferSize,
428
  uint32_t graphIndex);
 
 
 
 
 
 
 
429
  };
 
370
 
371
  bool applyBinarySection(uint32_t binIndex, std::string binSectionPath,bool useMmap,bool graphSwitch);
372
 
373
+ bool applyBinarySection(uint32_t graphId, std::string binSectionPath);
374
+
375
  QNN_INTERFACE_VER_TYPE* getQnnInterfaceVer() { return &m_qnnInterface; };
376
  GraphInfo_t**& getGraphsInfo() { return m_graphsInfo; };
377
  uint32_t getGraphsCount() { return m_graphsCount; };
 
428
  bool updateIOEncodings(std::shared_ptr<uint8_t>& buffer,
429
  uint64_t bufferSize,
430
  uint32_t graphIndex);
431
+
432
+ bool createFromBinary(std::vector<std::string> cachedBinariesPathVec);
433
+
434
+ bool initialize(
435
+ std::string backendPath,
436
+ std::vector<std::string> modelPathOrCachedBinaryPath
437
+ );
438
  };
Genie/Genie/src/qualla/engines/qnn-api/qnn-utils.hpp CHANGED
@@ -46,14 +46,14 @@ bool writeRawData(void* tensorData, size_t tensorSize, const std::filesystem::pa
46
  bool readRawData(void* tensorData, size_t tensorSize, const std::filesystem::path& path);
47
 
48
  struct Dims {
49
- int32_t batch = 1;
50
- int32_t height, width, channel, bitWidth;
51
  Dims() : height(0), width(0), channel(0), bitWidth(0) {}
52
- Dims(int32_t height, int32_t width, int32_t channel, int32_t bitWidth)
53
  : height(height), width(width), channel(channel), bitWidth(bitWidth) {}
54
  Dims(std::vector<size_t>& tDims)
55
- : height((int32_t)tDims[1]), width((int32_t)tDims[2]), channel((int32_t)tDims[3]),
56
- bitWidth((int32_t)tDims[4]) {
57
  // Hack to mix batch dimension
58
  if (tDims[0] != 1 && tDims[1] == 1) height = tDims[0];
59
  if (tDims[0] > 1 && tDims[1] != 1) batch = tDims[0];
 
46
  bool readRawData(void* tensorData, size_t tensorSize, const std::filesystem::path& path);
47
 
48
  struct Dims {
49
+ uint32_t batch = 1;
50
+ uint32_t height, width, channel, bitWidth;
51
  Dims() : height(0), width(0), channel(0), bitWidth(0) {}
52
+ Dims(uint32_t height, uint32_t width, uint32_t channel, uint32_t bitWidth)
53
  : height(height), width(width), channel(channel), bitWidth(bitWidth) {}
54
  Dims(std::vector<size_t>& tDims)
55
+ : height((uint32_t)tDims[1]), width((uint32_t)tDims[2]), channel((uint32_t)tDims[3]),
56
+ bitWidth((uint32_t)tDims[4]) {
57
  // Hack to mix batch dimension
58
  if (tDims[0] != 1 && tDims[1] == 1) height = tDims[0];
59
  if (tDims[0] > 1 && tDims[1] != 1) batch = tDims[0];
Genie/Genie/src/qualla/engines/qnn-cpu.cpp CHANGED
@@ -55,8 +55,10 @@ class QnnCpuEngine : public Engine {
55
  virtual bool updateKV(size_t n_past) override;
56
  virtual bool updateKV(size_t n_past, const std::vector<bool>& selected) override;
57
  virtual bool save(const std::string& name) override;
58
- virtual size_t restore(const std::string& name) override;
59
  virtual void reset() override;
 
 
60
  };
61
 
62
  namespace fs = std::filesystem;
@@ -98,7 +100,40 @@ QnnCpuEngine::QnnCpuEngine(Context& ctx, const qualla::json& json) : Engine(ctx,
98
  p.use_mmap = conf.optional<bool>("use-mmap", false);
99
  p.ctx_size = _ctx.size();
100
  p.n_vocab_size = _ctx.n_vocab();
101
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
  _model = std::make_unique<QnnCpuModel>(_env, p);
103
 
104
  // Load model
@@ -211,7 +246,7 @@ size_t QnnCpuEngine::process(
211
  );
212
  }
213
 
214
- size_t QnnCpuEngine::restore(const std::string& name) {
215
  fs::path cache_path = std::filesystem::path(name) / fmt::format("kv-cache.{}.qnn-cpu", _role);
216
  return _model->loadKVCache(cache_path.string());
217
  }
@@ -226,6 +261,23 @@ void QnnCpuEngine::reset() {
226
  updateKV(0);
227
  }
228
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
  // Registrator instance
230
  static OnLoad regy([]() {
231
  Engine::__register("qnn-cpu", [](Context& ctx, const json& conf) {
 
55
  virtual bool updateKV(size_t n_past) override;
56
  virtual bool updateKV(size_t n_past, const std::vector<bool>& selected) override;
57
  virtual bool save(const std::string& name) override;
58
+ virtual size_t restore(const std::string& name, bool chooseHigherVariant) override;
59
  virtual void reset() override;
60
+ virtual bool applyLoraAdapter(std::string lora_adapter_name) override;
61
+ virtual bool applyLoraStrength(std::string tensor_name, float tensor_val) override;
62
  };
63
 
64
  namespace fs = std::filesystem;
 
100
  p.use_mmap = conf.optional<bool>("use-mmap", false);
101
  p.ctx_size = _ctx.size();
102
  p.n_vocab_size = _ctx.n_vocab();
103
+ p.lora_config_type = LoraConfigType::LORA_DISABLE;
104
+ qualla::json lora_conf = conf.optional<qualla::json>("lora", {});
105
+ if (lora_conf.size() != 0) {
106
+ p.lora_config_type = LoraConfigType::LORA_ADAPTER_WEIGHT_ENABLE;
107
+ if (lora_conf.is_array()) {
108
+ for (auto lc : lora_conf) {
109
+ std::string lnm = lc["adapter-name"];
110
+ p.lora_config[lnm].lora_name = lnm;
111
+ p.lora_config[lnm].alpha_tensor_name = lc["alpha-tensor-name"];
112
+ p.lora_config[lnm].alpha_tensor_val = 0.0f;
113
+ if(lc.contains("alpha-tensor-value")){
114
+ p.lora_config[lnm].alpha_tensor_val = lc["alpha-tensor-value"];
115
+ }
116
+ std::string basedir = "";
117
+ if(lc.contains("binsection-basedir")){
118
+ basedir = lc["binsection-basedir"];
119
+ }
120
+ uint32_t n = lc["bin-sections"].size();
121
+ for (uint32_t i = 0; i < n; i++) {
122
+ auto binSec = lc["bin-sections"].get<std::vector<std::string>>();
123
+ fs::path binsection_path = fs::path(binSec[i]);
124
+ if (binsection_path.is_relative()) binsection_path = basedir / fs::path(binSec[i]);
125
+ if (!fs::is_regular_file(binsection_path)) {
126
+ __ERROR("qnn-cpu: Can't access Lora binsection adapter : {}",
127
+ binsection_path.string());
128
+ throw std::runtime_error(
129
+ "qnn-cpu: Can't open adapter file : " + binsection_path.string()
130
+ );
131
+ }
132
+ p.lora_config[lnm].binsection_list.push_back(binsection_path.string());
133
+ }
134
+ }
135
+ }
136
+ }
137
  _model = std::make_unique<QnnCpuModel>(_env, p);
138
 
139
  // Load model
 
246
  );
247
  }
248
 
249
+ size_t QnnCpuEngine::restore(const std::string& name, bool chooseHigherVariant) {
250
  fs::path cache_path = std::filesystem::path(name) / fmt::format("kv-cache.{}.qnn-cpu", _role);
251
  return _model->loadKVCache(cache_path.string());
252
  }
 
261
  updateKV(0);
262
  }
263
 
264
+ // For Lora
265
+ bool QnnCpuEngine::applyLoraAdapter(std::string lora_adapter_name) {
266
+ if (!_model) {
267
+ __ERROR("qnn-cpu: applyLoraAdapter failed, model not initialized");
268
+ return false;
269
+ }
270
+ return _model->applyLoraAdapter(lora_adapter_name);
271
+ }
272
+
273
+ bool QnnCpuEngine::applyLoraStrength(std::string tensor_name, float tensor_val) {
274
+ if (!_model) {
275
+ __ERROR("qnn-cpu: applyLoraStrength failed, model not initialized");
276
+ return false;
277
+ }
278
+ return _model->applyLoraStrength(tensor_name, tensor_val);
279
+ }
280
+
281
  // Registrator instance
282
  static OnLoad regy([]() {
283
  Engine::__register("qnn-cpu", [](Context& ctx, const json& conf) {
Genie/Genie/src/qualla/engines/qnn-cpu/cpu-model.cpp CHANGED
@@ -61,6 +61,12 @@ QnnCpuModel::QnnCpuModel(Env& env, const Params& params)
61
  m_output_dim.push_back(m_numLogits);
62
  m_output_dim.push_back(m_embd);
63
  }
 
 
 
 
 
 
64
  }
65
 
66
  QnnCpuModel::~QnnCpuModel() {
@@ -383,6 +389,7 @@ bool QnnCpuModel::initializeTensorPointers() {
383
  t_input_ids_k_cache = &input_specs["x3"];
384
  t_input_ids_v_cache = &input_specs["x4"];
385
  t_input_ids_n_past = &input_specs["x5"];
 
386
 
387
  auto& output_specs = m_output_specs[model_order.back()];
388
  t_logits = &output_specs["output_genAI"];
@@ -406,6 +413,7 @@ void QnnCpuModel::setupInputTensors(const std::vector<int32_t>& tokens, bool run
406
  uint32_t* input_id_num_token_buffer = (uint32_t*)getBuffer(t_input_ids_num_token);
407
  uint32_t* input_id_reset_kvcache_buffer = (uint32_t*)getBuffer(t_input_ids_reset_kvcache);
408
  uint32_t* input_id_n_past_buffer = (uint32_t*)getBuffer(t_input_ids_n_past);
 
409
 
410
  uint32_t size = 1;
411
  for (auto dim : m_input_dim) {
@@ -420,6 +428,7 @@ void QnnCpuModel::setupInputTensors(const std::vector<int32_t>& tokens, bool run
420
  std::memcpy(input_id_buffer, tokens.data(), tokens.size() * sizeof(uint32_t));
421
  *input_id_num_token_buffer = tokens.size();
422
  *input_id_n_past_buffer = m_nPast;
 
423
 
424
  auto stop = std::chrono::steady_clock::now();
425
  // QnnUtils::logProfile("setupInputTensors (cpp) took", start, stop);
@@ -589,6 +598,48 @@ size_t QnnCpuModel::getDequantLogits(std::vector<float>& dequant_logits, bool lo
589
  return logits_all? prev_run.num_tokens_processed : 1;
590
  }
591
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
592
  // TODO: implement save/restore
593
  size_t QnnCpuModel::loadKVCache(const std::string& load_path) {
594
  //TO read the cache file into KV tensor
 
61
  m_output_dim.push_back(m_numLogits);
62
  m_output_dim.push_back(m_embd);
63
  }
64
+ m_loraConfigType = params.lora_config_type;
65
+ m_lora_alpha_val = 1.0f;
66
+
67
+ if (m_loraConfigType == LoraConfigType::LORA_ADAPTER_WEIGHT_ENABLE) {
68
+ m_loraConfig.insert(params.lora_config.begin(), params.lora_config.end());
69
+ }
70
  }
71
 
72
  QnnCpuModel::~QnnCpuModel() {
 
389
  t_input_ids_k_cache = &input_specs["x3"];
390
  t_input_ids_v_cache = &input_specs["x4"];
391
  t_input_ids_n_past = &input_specs["x5"];
392
+ t_input_lora_alpha = &input_specs["x6"];
393
 
394
  auto& output_specs = m_output_specs[model_order.back()];
395
  t_logits = &output_specs["output_genAI"];
 
413
  uint32_t* input_id_num_token_buffer = (uint32_t*)getBuffer(t_input_ids_num_token);
414
  uint32_t* input_id_reset_kvcache_buffer = (uint32_t*)getBuffer(t_input_ids_reset_kvcache);
415
  uint32_t* input_id_n_past_buffer = (uint32_t*)getBuffer(t_input_ids_n_past);
416
+ float* input_id_lora_alpha = (float*)getBuffer(t_input_lora_alpha);
417
 
418
  uint32_t size = 1;
419
  for (auto dim : m_input_dim) {
 
428
  std::memcpy(input_id_buffer, tokens.data(), tokens.size() * sizeof(uint32_t));
429
  *input_id_num_token_buffer = tokens.size();
430
  *input_id_n_past_buffer = m_nPast;
431
+ *input_id_lora_alpha = m_lora_alpha_val;
432
 
433
  auto stop = std::chrono::steady_clock::now();
434
  // QnnUtils::logProfile("setupInputTensors (cpp) took", start, stop);
 
598
  return logits_all? prev_run.num_tokens_processed : 1;
599
  }
600
 
601
+ bool QnnCpuModel::applyBinarySections(std::vector<std::string>& binsection_list) {
602
+ //apply binary section for lora config
603
+ for (int i = 0; i < binsection_list.size(); i++) {
604
+ __DEBUG("qnn-cpu: applyBinarySections adapters {}", binsection_list.at(i));
605
+ if (!m_qnnApi->applyBinarySection(i, binsection_list.at(i))) {
606
+ __ERROR("qnn-cpu: Error in applyBinarySections {}", i);
607
+ return false;
608
+ }
609
+ }
610
+ return true;
611
+ }
612
+
613
+ bool QnnCpuModel::applyLoraStrength(const std::string& alpha_tensor_name, const float alpha_val) {
614
+ m_lora_alpha_val = alpha_val;
615
+ return true;
616
+ }
617
+
618
+ bool QnnCpuModel::applyLoraAdapter(const std::string& lora_adapter_name) {
619
+ if (m_loraConfigType != LoraConfigType::LORA_ADAPTER_WEIGHT_ENABLE) {
620
+ __ERROR("qnn-cpu: Lora config is not enable for adapters");
621
+ return false;
622
+ }
623
+
624
+ if (!m_loraConfig.contains(lora_adapter_name)) {
625
+ __ERROR("qnn-cpu: Could not find lora adapters config to apply ");
626
+ return false;
627
+ }
628
+ if (!applyLoraStrength(
629
+ m_loraConfig[lora_adapter_name].alpha_tensor_name,
630
+ m_loraConfig[lora_adapter_name].alpha_tensor_val
631
+ )) {
632
+ __ERROR("qnn-cpu: Could not apply Alpha tensor ");
633
+ return false;
634
+ }
635
+
636
+ if (!applyBinarySections(m_loraConfig[lora_adapter_name].binsection_list)) {
637
+ __ERROR("qnn-cpu: Could not apply binary Sections ");
638
+ return false;
639
+ }
640
+ return true;
641
+ }
642
+
643
  // TODO: implement save/restore
644
  size_t QnnCpuModel::loadKVCache(const std::string& load_path) {
645
  //TO read the cache file into KV tensor
Genie/Genie/src/qualla/engines/qnn-cpu/cpu-model.hpp CHANGED
@@ -26,6 +26,12 @@
26
 
27
  namespace qualla {
28
 
 
 
 
 
 
 
29
  class QnnCpuModel {
30
  enum ExecutionMode { AUTODETECT, BERT_KV, KV_ONLY, BERT_ONLY };
31
 
@@ -34,6 +40,13 @@ class QnnCpuModel {
34
  public:
35
  enum ModelOutput { LOGITS = 0x0, EMBEDDINGS= 0x1 };
36
 
 
 
 
 
 
 
 
37
  struct Params {
38
  std::filesystem::path model_basedir;
39
  std::string op_package;
@@ -50,6 +63,8 @@ class QnnCpuModel {
50
  uint32_t n_layer;
51
  uint32_t n_embd;
52
  uint32_t n_heads;
 
 
53
  };
54
 
55
  const std::filesystem::path model_basedir;
@@ -92,6 +107,11 @@ class QnnCpuModel {
92
  std::vector<Qnn_Param_t> m_params;
93
  ExecutionMode m_mode{ExecutionMode::AUTODETECT};
94
 
 
 
 
 
 
95
  // Save some information about the last inference run
96
  struct PreviousRunInfo {
97
  bool was_bert_mode;
@@ -118,6 +138,7 @@ class QnnCpuModel {
118
  QnnUtils::Tensor* t_input_ids_k_cache;
119
  QnnUtils::Tensor* t_input_ids_v_cache;
120
  QnnUtils::Tensor* t_input_ids_n_past;
 
121
  float* dequant_logits_ptr{nullptr};
122
 
123
  // Store pointers for bert
@@ -171,6 +192,10 @@ class QnnCpuModel {
171
  size_t loadKVCache(const std::string& save_path);
172
  bool saveKVCache(const std::string& load_path);
173
 
 
 
 
 
174
  private:
175
  bool m_mmap_context_bins = false; // mmap context binary files instead of reading them in memory
176
  // Internal functions to separate different runInference logic
 
26
 
27
  namespace qualla {
28
 
29
+ enum LoraConfigType {
30
+ LORA_DISABLE = 0,
31
+ LORA_INPUT_WEIGHT_ENABLE = 1,
32
+ LORA_ADAPTER_WEIGHT_ENABLE = 2
33
+ };
34
+
35
  class QnnCpuModel {
36
  enum ExecutionMode { AUTODETECT, BERT_KV, KV_ONLY, BERT_ONLY };
37
 
 
40
  public:
41
  enum ModelOutput { LOGITS = 0x0, EMBEDDINGS= 0x1 };
42
 
43
+ struct LoraConfig {
44
+ std::string lora_name;
45
+ std::vector<std::string> binsection_list; //loRAv2 adapter bins filenames
46
+ std::string alpha_tensor_name; //loRAv2 alpha tensor names
47
+ float alpha_tensor_val; //loRAv2 alpha tensor values
48
+ };
49
+
50
  struct Params {
51
  std::filesystem::path model_basedir;
52
  std::string op_package;
 
63
  uint32_t n_layer;
64
  uint32_t n_embd;
65
  uint32_t n_heads;
66
+ LoraConfigType lora_config_type;
67
+ std::map<std::string, LoraConfig> lora_config;
68
  };
69
 
70
  const std::filesystem::path model_basedir;
 
107
  std::vector<Qnn_Param_t> m_params;
108
  ExecutionMode m_mode{ExecutionMode::AUTODETECT};
109
 
110
+ // LoRA params and configs
111
+ float m_lora_alpha_val;
112
+ LoraConfigType m_loraConfigType;
113
+ std::map<std::string, LoraConfig> m_loraConfig;
114
+
115
  // Save some information about the last inference run
116
  struct PreviousRunInfo {
117
  bool was_bert_mode;
 
138
  QnnUtils::Tensor* t_input_ids_k_cache;
139
  QnnUtils::Tensor* t_input_ids_v_cache;
140
  QnnUtils::Tensor* t_input_ids_n_past;
141
+ QnnUtils::Tensor* t_input_lora_alpha;
142
  float* dequant_logits_ptr{nullptr};
143
 
144
  // Store pointers for bert
 
192
  size_t loadKVCache(const std::string& save_path);
193
  bool saveKVCache(const std::string& load_path);
194
 
195
+ bool applyLoraStrength(const std::string& alpha_tensor_name, const float alpha_val);
196
+ bool applyLoraAdapter(const std::string& lora_adapter_name);
197
+ bool applyBinarySections(std::vector<std::string>& binsection_list);
198
+
199
  private:
200
  bool m_mmap_context_bins = false; // mmap context binary files instead of reading them in memory
201
  // Internal functions to separate different runInference logic
Genie/Genie/src/qualla/engines/qnn-gpu.cpp ADDED
@@ -0,0 +1,193 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
2
+ // Confidential & Proprietary - Qualcomm Technologies, Inc. ("QTI")
3
+
4
+ #include <vector>
5
+ #include <string>
6
+
7
+ #include <qualla/engine.hpp>
8
+ #include <qualla/detail/config.hpp>
9
+ #include <qualla/detail/timer.hpp>
10
+ #include <qualla/detail/onload.hpp>
11
+
12
+ #include <fmt/format.h>
13
+
14
+ #include "gpu-model.hpp"
15
+
16
+ #define __INFO(__fmt, ...) _env.logger().post(Logger::INFO, fmt::format(__fmt, ##__VA_ARGS__))
17
+ #define __WARN(__fmt, ...) _env.logger().post(Logger::WARN, fmt::format(__fmt, ##__VA_ARGS__))
18
+ #define __ERROR(__fmt, ...) _env.logger().post(Logger::ERROR, fmt::format(__fmt, ##__VA_ARGS__))
19
+ #define __KPIS(__fmt, ...) \
20
+ _env.logger().post(Logger::ENGINE_KPIS, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
21
+ #define __DEBUG(__fmt, ...) \
22
+ _env.logger().post(Logger::ENGINE_DEBUG, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
23
+ #define __TRACE(__fmt, ...) \
24
+ _env.logger().post(Logger::ENGINE_TRACE, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); })
25
+
26
+ namespace qualla {
27
+
28
+ class GpuEngine : public Engine {
29
+ private:
30
+ QnnGpuModel::Params _params;
31
+ std::unique_ptr<QnnGpuModel> _model;
32
+
33
+ public:
34
+ GpuEngine(Context& ctx, const qualla::json& json);
35
+ ~GpuEngine();
36
+
37
+ virtual size_t process(
38
+ const std::vector<int32_t>& tokens,
39
+ std::vector<float>& logits,
40
+ bool logits_all
41
+ ) override;
42
+
43
+ virtual bool updateKV(size_t n_past) override;
44
+ virtual bool save(const std::string& name) override;
45
+ virtual size_t restore(const std::string& name, bool chooseHigherVariant) override;
46
+ virtual void reset() override;
47
+
48
+ virtual bool load() override;
49
+ virtual bool unload() override;
50
+ };
51
+
52
+ namespace fs = std::filesystem;
53
+
54
+ GpuEngine::GpuEngine(Context& ctx, const qualla::json& json) : Engine(ctx, "qnn-gpu", json) {
55
+ qualla::Timer start;
56
+
57
+ using FF = Feature::Flags;
58
+ _features = FF::OUTPUT_LOGITS | FF::SAVE_RESTORE | FF::DYNAMIC_LOAD;
59
+
60
+ __DEBUG("Qnn-Gpu : init start");
61
+
62
+ qualla::Config conf(json, _type + "-engine:");
63
+
64
+ // Parse config
65
+ _params.model_basedir = conf.optional<std::string>("model-basedir", "");
66
+ if (_params.model_basedir.is_relative()) {
67
+ _params.model_basedir = _env.path().models / _params.model_basedir;
68
+ _params.model_basedir = _params.model_basedir.make_preferred();
69
+ }
70
+ _params.model_list = conf.mandatory<std::vector<std::string>>("model-list");
71
+
72
+ _params.ctx_size = _ctx.size();
73
+ _params.num_heads = conf.optional<int64_t>("num-heads", 32);
74
+ _params.head_dim = conf.optional<int64_t>("head-dim", 128);
75
+
76
+ if (!conf.optional<bool>("dynamic-load", false)) {
77
+ load();
78
+ }
79
+ };
80
+
81
+ GpuEngine::~GpuEngine() {
82
+ unload();
83
+ }
84
+
85
+ bool GpuEngine::load() {
86
+ #ifdef _WIN32
87
+ // QnnGpu Engine does not support Windows.
88
+ return false;
89
+ #endif
90
+ if (_model) return true;
91
+
92
+ qualla::Timer start;
93
+ bool status = true;
94
+
95
+ __INFO("Qnn-Gpu : Loading Model");
96
+
97
+ _model = std::make_unique<QnnGpuModel>(_env, _params);
98
+
99
+ // Load model
100
+ status = _model->initializeModel();
101
+ if (!status) {
102
+ throw std::runtime_error("Qnn-Gpu :Failure to initialize model");
103
+ }
104
+
105
+ // Initialize IO Tensor buffers
106
+ status = _model->initializeIOTensors();
107
+ if (!status) {
108
+ throw std::runtime_error("Qnn-Gpu :Error in setting up IO Tensors");
109
+ }
110
+
111
+ // Initialize IO Tensor Pointers
112
+ if (true != _model->initializeTensorPointers()) {
113
+ throw std::runtime_error("Qnn-Gpu :Could not find I/O tensors in loaded graphs");
114
+ }
115
+
116
+ // Validate the model
117
+ if (true != _model->validateModel()) {
118
+ throw std::runtime_error("Qnn-Gpu :Model Validation Failed");
119
+ }
120
+
121
+ _kpis.load.update(start.elapsed_usec());
122
+ return true;
123
+ }
124
+
125
+ bool GpuEngine::unload() {
126
+ qualla::Timer start;
127
+ __DEBUG("Qnn-Gpu : Unloading Model");
128
+ _model.reset(nullptr);
129
+ _kpis.unload.update(start.elapsed_usec());
130
+ return true;
131
+ }
132
+
133
+ // KV Cache updation after each inference is handled inside QnnGpu Backend
134
+ // GPU Engine uses same memory handle for each KV input/output to the graph and uses
135
+ // Scatter op to update KV after each inference to the same memory handle.
136
+ bool GpuEngine::updateKV(size_t n_past) {
137
+ return true;
138
+ }
139
+
140
+ size_t GpuEngine::process(
141
+ const std::vector<int32_t>& tokens,
142
+ std::vector<float>& logits,
143
+ bool logits_all
144
+ ) {
145
+ if (!_model && !load()) {
146
+ return 0;
147
+ }
148
+ qualla::Timer start;
149
+ size_t n_tok = _model->runInference(tokens, logits, logits_all);
150
+ if (n_tok == 0) {
151
+ State::error("Qnn-Gpu : RunInference Failed!");
152
+ }
153
+ _kpis.process.update(start.elapsed_usec());
154
+ return n_tok;
155
+ }
156
+
157
+ size_t GpuEngine::restore(const std::string& name, bool chooseHigherVariant) {
158
+ if (!_model && !load()) {
159
+ return 0;
160
+ }
161
+
162
+ fs::path cache_path = std::filesystem::path(name) / fmt::format("kv-cache.{}.qnn-gpu", _role);
163
+ return _model->loadKVCache(cache_path.string());
164
+ }
165
+
166
+ bool GpuEngine::save(const std::string& name) {
167
+ if (!_model && !load()) {
168
+ return false;
169
+ }
170
+
171
+ fs::path cache_path = std::filesystem::path(name) / fmt::format("kv-cache.{}.qnn-gpu", _role);
172
+ return _model->saveKVCache(cache_path.string());
173
+ }
174
+
175
+
176
+ // Reset requires clearing of KV caches only
177
+ void GpuEngine::reset() {
178
+ if (!_model && !load()) {
179
+ return;
180
+ }
181
+ _model->reset();
182
+ }
183
+
184
+ // Registrator instance
185
+ static OnLoad regy([]() {
186
+ Engine::__register("qnn-gpu", [](Context& ctx, const json& conf) {
187
+ return (Engine*)new GpuEngine(ctx, conf);
188
+ });
189
+ });
190
+
191
+ void needQnnGpuEngine() {}
192
+
193
+ } // namespace qualla
Genie/Genie/src/qualla/engines/qnn-gpu/gpu-model.cpp ADDED
@@ -0,0 +1,603 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
2
+ // Confidential & Proprietary - Qualcomm Technologies, Inc. ("QTI")
3
+
4
+ #include <cassert>
5
+ #include <cstring>
6
+ #include <fstream>
7
+ #include <set>
8
+ #include <sstream>
9
+
10
+ #include "fmt/format.h"
11
+ #include "fmt/ranges.h"
12
+ #include "fp16/fp16.h"
13
+ #include "gpu-model.hpp"
14
+ #include "qualla/detail/cache-file.hpp"
15
+ #include "qualla/detail/timer.hpp"
16
+ #include "qualla/env.hpp"
17
+
18
+ namespace fs = std::filesystem;
19
+
20
+ static constexpr uint32_t g_magicNum = 0xC0DE;
21
+
22
+ #define __INFO(__fmt, ...) _env.logger().post(Logger::INFO, fmt::format(__fmt, ##__VA_ARGS__));
23
+ #define __WARN(__fmt, ...) _env.logger().post(Logger::WARN, fmt::format(__fmt, ##__VA_ARGS__));
24
+ #define __ERROR(__fmt, ...) _env.logger().post(Logger::ERROR, fmt::format(__fmt, ##__VA_ARGS__));
25
+ #define __KPIS(__fmt, ...) \
26
+ _env.logger().post(Logger::ENGINE_KPIS, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); });
27
+ #define __DEBUG(__fmt, ...) \
28
+ _env.logger().post(Logger::ENGINE_DEBUG, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); });
29
+ #define __TRACE(__fmt, ...) \
30
+ _env.logger().post(Logger::ENGINE_TRACE, [&]() { return fmt::format(__fmt, ##__VA_ARGS__); });
31
+
32
+ namespace qualla {
33
+
34
+ QnnGpuModel::QnnGpuModel(Env& env, const Params& params)
35
+ : _env(env), _modelBaseDir(params.model_basedir) {
36
+ // Initialize _qnnApi
37
+ _qnnApi = std::unique_ptr<QnnApi>(new QnnApi());
38
+
39
+ _ctxSize = params.ctx_size;
40
+ _numHeads = params.num_heads;
41
+ _headDim = params.head_dim;
42
+ #ifdef _WIN32
43
+ _useDmabufIo = false;
44
+ #else
45
+ _useDmabufIo = true;
46
+ #endif
47
+ // Set up filename list for context binaries.
48
+ for (auto& i : params.model_list) {
49
+ fs::path model_path = _modelBaseDir / fs::path(i);
50
+ if (!fs::is_regular_file(model_path)) {
51
+ __ERROR("Qnn-Gpu-Model : Can't access model file : {}", model_path.string());
52
+ throw std::runtime_error("Qnn-Gpu-Model : Can't access model file : " + model_path.string());
53
+ }
54
+ _modelList.push_back(model_path.string());
55
+ }
56
+ }
57
+
58
+ QnnGpuModel::~QnnGpuModel() { __INFO("Qnn-Gpu-Model : model destruct complete"); }
59
+
60
+ // Given a filename, initializeModel load and initializes QNN runtime libraries and the model
61
+ bool QnnGpuModel::initializeModel(void) {
62
+ qualla::Timer start;
63
+
64
+ __INFO("Qnn-Gpu-Model : Model Init Start");
65
+
66
+ const std::string backend = "libQnnGpu.so";
67
+
68
+ __INFO("Backend Library : {}", backend);
69
+ __INFO("Model Files : {}", _modelList);
70
+
71
+ if (!_qnnApi->initialize(backend, _modelList)) {
72
+ __ERROR("Qnn-Api : Initialization Failed!");
73
+ return false;
74
+ }
75
+
76
+ // Initialize QNN IO Tensor
77
+ if (_useDmabufIo) {
78
+ _ioTensor =
79
+ std::unique_ptr<IOTensor>(new IOTensor(BufferAlloc::DMABUF, _qnnApi->getQnnInterfaceVer()));
80
+ } else {
81
+ _ioTensor = std::unique_ptr<IOTensor>(
82
+ new IOTensor(BufferAlloc::DEFAULT, _qnnApi->getQnnInterfaceVer()));
83
+ }
84
+ _numGraphs = _qnnApi->getGraphsCount();
85
+ __INFO("Qnn-Gpu-Model : initialized with {} graph(s)", _numGraphs);
86
+
87
+ GraphInfo_t** graphs_info = _qnnApi->getGraphsInfo();
88
+ for (size_t graphIdx = 0; graphIdx < _numGraphs; graphIdx++) {
89
+ GraphInfo_t* const graphInfo = graphs_info[graphIdx];
90
+ char* graphName = graphInfo->graphName;
91
+ std::string graphStr = std::string(graphName);
92
+
93
+ _modelOrder.push_back(graphStr);
94
+ }
95
+ __INFO("Qnn-Gpu-Model : model init complete: {} usec", start.elapsed_usec());
96
+
97
+ return true;
98
+ }
99
+
100
+ // Once the model has been loaded, initialize IO Tensors
101
+ // _ioTensors is initialized by the context for now
102
+ bool QnnGpuModel::initializeIOTensors() {
103
+ qualla::Timer start;
104
+
105
+ // For QNN-GPU, we have only one context per model.
106
+ bool status = _ioTensor->initialize(_qnnApi->getContexts().back());
107
+ if (!status) {
108
+ __ERROR("Qnn-Gpu-Model : failure to initialize IOTensor");
109
+ return false;
110
+ }
111
+ // Getting graph info, Hardcoding single graph for now.
112
+ GraphInfo_t** const& graphsInfo = _qnnApi->getGraphsInfo();
113
+
114
+ for (size_t graphIdx = 0; graphIdx < _numGraphs; graphIdx++) {
115
+ GraphInfo_t* const& graphInfo = graphsInfo[graphIdx];
116
+ std::string graphName = std::string(graphInfo->graphName);
117
+
118
+ __DEBUG("Qnn-Gpu-Model : numInputTensors {} numOutputTensors {}",
119
+ graphInfo->numInputTensors,
120
+ graphInfo->numOutputTensors);
121
+ // Setup Inputs
122
+ {
123
+ std::unordered_map<std::string, size_t> inputTensorsSize;
124
+ for (size_t tensorIdx = 0; tensorIdx < graphInfo->numInputTensors; tensorIdx++) {
125
+ std::string tensorName;
126
+ std::vector<size_t> tensorDims;
127
+ auto& tensor = graphInfo->inputTensors[tensorIdx];
128
+ _qnnApi->getTensorNameAndShape(tensorName, tensorDims, tensor);
129
+ auto dims = QnnUtils::Dims(tensorDims);
130
+ inputTensorsSize[tensorName] = dims.getSize();
131
+ __DEBUG("Qnn-Gpu-Model : Input Tensor Info {} {} {} {}",
132
+ tensorIdx,
133
+ tensorName,
134
+ tensorDims,
135
+ inputTensorsSize[tensorName]);
136
+ std::vector<QnnUtils::QuantParam> quantParams;
137
+ if (!_qnnApi->getTensorQuantParams(&tensor, quantParams)) {
138
+ quantParams.emplace_back(0, 0);
139
+ }
140
+
141
+ std::shared_ptr<QnnUtils::Tensor> tensorUtil =
142
+ std::shared_ptr<QnnUtils::Tensor>(new (std::nothrow) QnnUtils::Tensor);
143
+ tensorUtil->dims = dims;
144
+ tensorUtil->dtype = QNN_TENSOR_GET_DATA_TYPE(tensor);
145
+ tensorUtil->quantParam = quantParams;
146
+ _inputSpecs[graphName][tensorName] = tensorUtil;
147
+ }
148
+
149
+ Qnn_Tensor_t* tensor_bank = nullptr;
150
+ std::unordered_map<std::string, void*> tensor_ptr_map;
151
+ if (true != _ioTensor->setupInputTensors(&tensor_bank,
152
+ tensor_ptr_map,
153
+ *graphInfo,
154
+ inputTensorsSize,
155
+ _qnnApi->getContexts()[graphIdx],
156
+ false)) {
157
+ QNN_ERROR("Qnn-Gpu-Model : Error in setting up Input Tensors for graph %s",
158
+ graphName.c_str());
159
+ return false;
160
+ }
161
+
162
+ _inputTensors[graphName] = tensor_bank;
163
+ for (auto& [tensorName, tensor_ptr] : tensor_ptr_map) {
164
+ _inputSpecs[graphName][tensorName]->tensor = (Qnn_Tensor_t*)tensor_ptr;
165
+ }
166
+ __DEBUG("Qnn-Gpu-Model : Input Tensor Allocated for {}", graphName);
167
+ }
168
+
169
+ // Setup Outputs
170
+ {
171
+ std::unordered_map<std::string, size_t> outputTensorsSize;
172
+ std::unordered_map<std::string, Qnn_Tensor_t*> sharedTensorMap;
173
+ for (size_t tensorIdx = 0; tensorIdx < graphInfo->numOutputTensors; tensorIdx++) {
174
+ std::string tensorName;
175
+ std::vector<size_t> tensorDims;
176
+
177
+ auto& tensor = graphInfo->outputTensors[tensorIdx];
178
+ _qnnApi->getTensorNameAndShape(tensorName, tensorDims, tensor);
179
+
180
+ if (tensorName.starts_with("past_")) {
181
+ std::string tensorInName = tensorName.substr(0, tensorName.size() - 3) + "in";
182
+ sharedTensorMap[tensorName] = _inputSpecs[graphName][tensorInName]->tensor;
183
+
184
+ // Update Gpu _kvCache
185
+ auto [type, layer_id] = parseKVTensorName(tensorName);
186
+ _kvCache.push_back(
187
+ GpuKVCache((type == 1), layer_id, _inputSpecs[graphName][tensorInName].get()));
188
+ }
189
+ std::vector<QnnUtils::QuantParam> quantParams;
190
+ if (!_qnnApi->getTensorQuantParams(&tensor, quantParams)) {
191
+ quantParams.emplace_back(0, 0);
192
+ }
193
+
194
+ auto dims = QnnUtils::Dims(tensorDims);
195
+ outputTensorsSize[tensorName] = dims.getAlignedSize();
196
+
197
+ __DEBUG("Qnn-Gpu-Model : Output Tensor Info {} {} {} {}",
198
+ tensorIdx,
199
+ tensorName,
200
+ tensorDims,
201
+ outputTensorsSize[tensorName]);
202
+ std::shared_ptr<QnnUtils::Tensor> tensorUtil =
203
+ std::shared_ptr<QnnUtils::Tensor>(new (std::nothrow) QnnUtils::Tensor);
204
+ tensorUtil->dims = dims;
205
+ tensorUtil->dtype = QNN_TENSOR_GET_DATA_TYPE(tensor);
206
+ tensorUtil->quantParam = quantParams;
207
+ _outputSpecs[graphName][tensorName] = tensorUtil;
208
+ }
209
+
210
+ Qnn_Tensor_t* tensor_bank = nullptr;
211
+ std::unordered_map<std::string, void*> tensor_ptr_map;
212
+ if (_ioTensor->getBufferAllocType() == BufferAlloc::DMABUF) {
213
+ if (true != _ioTensor->setupOutputWithSharedTensors(&tensor_bank,
214
+ tensor_ptr_map,
215
+ *graphInfo,
216
+ outputTensorsSize,
217
+ _qnnApi->getContexts()[graphIdx],
218
+ sharedTensorMap)) {
219
+ QNN_ERROR("Qnn-Gpu-Model : Error in setting up Output Tensors for graph %s",
220
+ graphName.c_str());
221
+ return false;
222
+ }
223
+ } else {
224
+ if (true != _ioTensor->setupOutputTensors(&tensor_bank,
225
+ tensor_ptr_map,
226
+ *graphInfo,
227
+ outputTensorsSize,
228
+ _qnnApi->getContexts()[graphIdx],
229
+ false)) {
230
+ QNN_ERROR("Qnn-Gpu-Model : Error in setting up Input Tensors for graph %s",
231
+ graphName.c_str());
232
+ return false;
233
+ }
234
+ }
235
+
236
+ _outputTensors[graphName] = tensor_bank;
237
+ for (auto& [tensorName, tensor_ptr] : tensor_ptr_map) {
238
+ _outputSpecs[graphName][tensorName]->tensor = (Qnn_Tensor_t*)tensor_ptr;
239
+ }
240
+
241
+ __DEBUG("Qnn-Gpu-Model : Output Tensor Allocated {} {}", graphName, _outputTensors.size());
242
+ }
243
+ }
244
+ auto stop = std::chrono::steady_clock::now();
245
+ return true;
246
+ }
247
+
248
+ bool QnnGpuModel::initializeTensorPointers() {
249
+ auto inputSpec = _inputSpecs[_modelOrder.back()];
250
+ auto outputSpec = _outputSpecs[_modelOrder.back()];
251
+
252
+ t_inputIds = inputSpec[INPUT_IDS].get();
253
+ t_attnMask = inputSpec[ATTN_MASK].get();
254
+ t_positionIds = inputSpec[POS_IDS].get();
255
+ t_logits = outputSpec[LOGITS].get();
256
+
257
+ auto status = !(t_inputIds == nullptr || t_attnMask == nullptr || t_positionIds == nullptr ||
258
+ t_logits == nullptr);
259
+
260
+ if (!status) {
261
+ __ERROR("Qnn-Gpu-Model : error in setting up named tensor pointers for llama.");
262
+ return false;
263
+ }
264
+ return true;
265
+ }
266
+
267
+ bool QnnGpuModel::validateModel() {
268
+ // Validating context Size.
269
+ size_t numInputs = t_inputIds->dims.getNumElements();
270
+ size_t dimMask = t_attnMask->dims.getNumElements();
271
+ size_t modelCtxSize = dimMask / numInputs;
272
+
273
+ if (modelCtxSize != _ctxSize) {
274
+ __ERROR("Qnn-Gpu-Model : Invalid Context Size {} {}.", modelCtxSize, _ctxSize);
275
+ return false;
276
+ }
277
+ return true;
278
+ }
279
+
280
+ void QnnGpuModel::setupInputTensors(const std::vector<int32_t>& tokens) {
281
+ auto start = std::chrono::steady_clock::now();
282
+
283
+ if (tokens.size() > _ctxSize) {
284
+ std::string errMsg = "Called inference with more tokens than model supports: ";
285
+ errMsg += std::to_string(tokens.size()) + " vs. " + std::to_string(_ctxSize);
286
+ throw std::runtime_error(errMsg);
287
+ }
288
+
289
+ // Setup 1. input_ids
290
+ // Index of input tokens in the embedding vocabulary
291
+ uint32_t* inputIdBuffer = (uint32_t*)getBuffer(t_inputIds);
292
+ if (inputIdBuffer) {
293
+ if (_useDmabufIo) {
294
+ _ioTensor->beforeWriteToBuffer(t_inputIds->tensor);
295
+ }
296
+ inputIdBuffer[0] = tokens[0];
297
+ if (_useDmabufIo) {
298
+ _ioTensor->afterWriteToBuffer(t_inputIds->tensor);
299
+ }
300
+ }
301
+
302
+ // Setup 2. attention_mask
303
+ // Mask to avoid performing attention of padding.
304
+ uint32_t* attnMaskBuffer = (uint32_t*)getBuffer(t_attnMask);
305
+ if (attnMaskBuffer) {
306
+ if (_useDmabufIo) {
307
+ _ioTensor->beforeWriteToBuffer(t_attnMask->tensor);
308
+ }
309
+ attnMaskBuffer[_numTokensProcessed] = 1;
310
+ if (_useDmabufIo) {
311
+ _ioTensor->afterWriteToBuffer(t_attnMask->tensor);
312
+ }
313
+ }
314
+
315
+ // Setup 3. position_ids
316
+ // Indices of positions of each input tokens in position embeddings.
317
+ uint32_t* positionIdBuffer = (uint32_t*)getBuffer(t_positionIds);
318
+ if (positionIdBuffer) {
319
+ if (_useDmabufIo) {
320
+ _ioTensor->beforeWriteToBuffer(t_positionIds->tensor);
321
+ }
322
+ positionIdBuffer[0] = (uint32_t)(_numTokensProcessed);
323
+ if (_useDmabufIo) {
324
+ _ioTensor->afterWriteToBuffer(t_positionIds->tensor);
325
+ }
326
+ }
327
+
328
+ auto stop = std::chrono::steady_clock::now();
329
+ }
330
+
331
+ template <class T1, class T2>
332
+ inline bool QnnGpuModel::executeModel(T1& input, T2& output, std::string graphName) {
333
+ bool ret = _qnnApi->graphExecute(input, output, graphName, timeLogs);
334
+ if (ret != true) {
335
+ QNN_ERROR("Qnn-Gpu-Model : Error executing inference: %d for graph %s", ret, graphName.c_str());
336
+ return false;
337
+ }
338
+ QNN_DEBUG("Qnn-Gpu-Model : Execute finished for graph %s", graphName.c_str());
339
+ return true;
340
+ }
341
+
342
+ bool QnnGpuModel::runInferenceHelper(std::vector<std::string>& exec_models,
343
+ int32_t* wait_time_total,
344
+ int32_t* exec_time_total,
345
+ bool pipeline_kv_update,
346
+ size_t update_size) {
347
+ int32_t exec_time = 0;
348
+ int32_t wait_time = 0;
349
+ for (auto& graphName : exec_models) {
350
+ {
351
+ auto start_time = std::chrono::steady_clock::now();
352
+ Qnn_Tensor_t* inputTensors;
353
+ Qnn_Tensor_t* outputTensors;
354
+ try {
355
+ inputTensors = _inputTensors[graphName];
356
+ outputTensors = _outputTensors[graphName];
357
+ } catch (std::exception e) {
358
+ __DEBUG("Qnn-Gpu-Model : Could not find tensors %s", graphName.c_str());
359
+ return false;
360
+ }
361
+ bool status = executeModel(inputTensors, outputTensors, graphName);
362
+ if (!status) {
363
+ return false;
364
+ }
365
+ auto end_time = std::chrono::steady_clock::now();
366
+ exec_time += static_cast<int32_t>(
367
+ std::chrono::duration_cast<std::chrono::microseconds>(end_time - start_time).count());
368
+ }
369
+ }
370
+
371
+ *exec_time_total += exec_time;
372
+ *wait_time_total += wait_time;
373
+ return true;
374
+ }
375
+
376
+ size_t QnnGpuModel::runInference(const std::vector<int32_t>& tokens,
377
+ std::vector<float>& logits,
378
+ bool logits_all) {
379
+ auto start = std::chrono::steady_clock::now();
380
+ int32_t totalWaitTime = 0;
381
+ int32_t totalExecTime = 0;
382
+
383
+ // Setup inputs for inference
384
+ auto& execModels = _modelOrder;
385
+ int numIters = tokens.size();
386
+ for (int i = 0; i < numIters; i++) {
387
+ if (numIters > 1) {
388
+ __DEBUG("Qnn-Gpu-Model : Prompt Processing {} of {} tokens", i + 1, numIters);
389
+ } else {
390
+ __DEBUG("Qnn-Gpu-Model : Token Generation {} of {} tokens", i + 1, numIters);
391
+ }
392
+ std::vector<int32_t> curr_tokens;
393
+ curr_tokens.push_back(tokens[i]);
394
+ setupInputTensors(curr_tokens);
395
+ bool status =
396
+ runInferenceHelper(execModels, &totalWaitTime, &totalExecTime, false, tokens.size());
397
+ if (!status) {
398
+ return 0;
399
+ }
400
+ processLogits(logits, logits_all);
401
+
402
+ // Update the numProcessTokens to updated with Accepted Tokens.
403
+ _numTokensProcessed++;
404
+ }
405
+
406
+ auto stop = std::chrono::steady_clock::now();
407
+ timeLogs["Run Inference (cpp) "].first += static_cast<double>(
408
+ std::chrono::duration_cast<std::chrono::microseconds>(stop - start).count());
409
+ timeLogs["Run Inference (cpp) "].second++;
410
+ QNN_DEBUG("[TIME] Wait[%d] Exec[%d]\n", totalWaitTime, totalExecTime);
411
+ if (!logits_all) {
412
+ return 1;
413
+ }
414
+ return tokens.size();
415
+ }
416
+
417
+ // Parse KV$ Tensor names here - supports past_{key,value}_{layer_idx}[_h0]_{in,out}
418
+ std::tuple<int, int> QnnGpuModel::parseKVTensorName(std::string name) {
419
+ if (!name.starts_with("past_")) return {0, 0};
420
+
421
+ const bool is_key = name.starts_with("past_key");
422
+ const size_t pos0 = (is_key) ? 9 : 11; // "past_key_" OR "past_value_"
423
+ const size_t pos1 = name.find('_', pos0);
424
+
425
+ int layer_idx = static_cast<int>(std::stoi(name.substr(pos0, pos1 - pos0)));
426
+
427
+ return std::make_tuple(is_key ? 1 : 2, layer_idx);
428
+ }
429
+
430
+ size_t QnnGpuModel::loadKVCache(const std::string& load_path) {
431
+ std::ifstream fs(load_path, std::ios::in | std::ios::binary);
432
+ if (fs.fail()) {
433
+ __ERROR("Qnn-Gpu-Model : loadKVCache errror reading file {}", load_path);
434
+ return 0;
435
+ }
436
+
437
+ CacheFileSpec spec;
438
+ fs.read((char*)&spec, sizeof(spec));
439
+ if (spec.magic != g_magicNum) {
440
+ __ERROR("Qnn-Gpu-Model : loadKVCache expected {} found {:#x}", g_magicNum, spec.magic);
441
+ return 0;
442
+ }
443
+
444
+ // clang-format off
445
+ __INFO("Qnn-Gpu-Model : loadKVCache {{ num_tensors {}, magic {}, dtype {}, n_heads {}, embed_dim {} update_size {} }}",
446
+ spec.num_tensors, spec.magic, int(spec.dtype), spec.n_heads, spec.embed_dim, spec.update_size); fflush(stdout);
447
+ // clang-format on
448
+
449
+ _numTokensProcessed = static_cast<size_t>(spec.update_size);
450
+ if (_numTokensProcessed > 0) {
451
+ // Loop over _kvCache tensor and read from file
452
+ for (auto cache : _kvCache) {
453
+ if (_useDmabufIo) {
454
+ _ioTensor->beforeWriteToBuffer(t_inputIds->tensor);
455
+ }
456
+ char* buffer = (char*)getBuffer(cache.tensorUtil);
457
+ if (cache.isKey) {
458
+ // Kye Cache Dims [1, num_heads, head_dim, ctx_size]
459
+ // float16 bits equivalent to uint16_t
460
+ const size_t copySize = _numTokensProcessed;
461
+ const size_t skipSize = _ctxSize;
462
+ for (int i = 0; i < _numHeads; i++) {
463
+ for (int j = 0; j < _headDim; j++) {
464
+ fs.read(buffer, copySize * sizeof(uint16_t));
465
+ buffer += skipSize * sizeof(uint16_t);
466
+ }
467
+ }
468
+ } else {
469
+ // Kye Cache Dims [1, num_heads, ctx_size, head_dim]
470
+ // float16 bits equivalent to uint16_t
471
+ const size_t copySize = _numTokensProcessed * _headDim;
472
+ const size_t skipSize = _ctxSize * _headDim;
473
+ for (int i = 0; i < _numHeads; i++) {
474
+ fs.read(buffer, copySize * sizeof(uint16_t));
475
+ buffer += skipSize * sizeof(uint16_t);
476
+ }
477
+ }
478
+ if (_useDmabufIo) {
479
+ _ioTensor->afterWriteToBuffer(t_inputIds->tensor);
480
+ }
481
+ }
482
+ }
483
+ return _numTokensProcessed;
484
+ }
485
+
486
+ bool QnnGpuModel::saveKVCache(const std::string& save_path) {
487
+ std::ofstream fs(save_path, std::ios::out | std::ios::binary);
488
+ if (fs.fail()) {
489
+ __ERROR("Qnn-Gpu-Model : saveKVCache error opening file : {}", save_path);
490
+ throw std::runtime_error("Failed to write to cache file. Please re-check path");
491
+ }
492
+
493
+ const CacheFileSpec::DataType dtype = CacheFileSpec::DataType::FLOAT16_T;
494
+
495
+ uint32_t numKVTensors = _kvCache.size();
496
+
497
+ // Save the cache file metadata
498
+ CacheFileSpec file_spec(
499
+ numKVTensors, g_magicNum, dtype, 0x0, _numHeads, _headDim, _numTokensProcessed);
500
+ fs.write((char*)&file_spec, sizeof(file_spec));
501
+
502
+ // clang-format off
503
+ __INFO("Qnn-Gpu-Model : saveKVCache {{ num_tensors {}, magic {}, dtype {}, n_heads {}, embed_dim {} update_size {} }}",
504
+ numKVTensors, g_magicNum, int(dtype), _numHeads, _headDim, _numTokensProcessed); fflush(stdout);
505
+ // clang-format on
506
+
507
+ if (_numTokensProcessed > 0) {
508
+ // Loop over _kvCache tensor and write to file
509
+ for (auto cache : _kvCache) {
510
+ if (_useDmabufIo) {
511
+ _ioTensor->beforeReadFromBuffer(t_inputIds->tensor);
512
+ }
513
+ char* buffer = (char*)getBuffer(cache.tensorUtil);
514
+ if (cache.isKey) {
515
+ // Kye Cache Dims [1, num_heads, head_dim, ctx_size]
516
+ // float16 bits equivalent to uint16_t
517
+ const size_t copySize = _numTokensProcessed;
518
+ const size_t skipSize = _ctxSize;
519
+ for (int i = 0; i < _numHeads; i++) {
520
+ for (int j = 0; j < _headDim; j++) {
521
+ fs.write((char*)buffer, copySize * sizeof(uint16_t));
522
+ buffer += skipSize * sizeof(uint16_t);
523
+ }
524
+ }
525
+ } else {
526
+ // Kye Cache Dims [1, num_heads, ctx_size, head_dim]
527
+ // float16 bits equivalent to uint16_t
528
+ const size_t copySize = _numTokensProcessed * _headDim;
529
+ const size_t skipSize = _ctxSize * _headDim;
530
+ for (int i = 0; i < _numHeads; i++) {
531
+ fs.write((char*)buffer, copySize * sizeof(uint16_t));
532
+ buffer += skipSize;
533
+ }
534
+ }
535
+ if (_useDmabufIo) {
536
+ _ioTensor->afterReadFromBuffer(t_inputIds->tensor);
537
+ }
538
+ }
539
+ }
540
+ fs.flush();
541
+ fs.close();
542
+
543
+ return true;
544
+ }
545
+
546
+ size_t QnnGpuModel::processLogits(std::vector<float>& logits, bool logits_all) {
547
+ auto logitsSpec = _outputSpecs[_modelOrder.back()][LOGITS].get();
548
+ size_t logitsSize = getNumElements(logitsSpec);
549
+ if (_useDmabufIo) {
550
+ _ioTensor->beforeReadFromBuffer(t_inputIds->tensor);
551
+ }
552
+ uint16_t* logitBuf = (uint16_t*)getBuffer(logitsSpec);
553
+
554
+ if (!logits_all) {
555
+ logits.clear();
556
+ }
557
+ size_t allocateSize = logits.size() + logitsSize;
558
+ logits.reserve(allocateSize);
559
+ for (auto i = 0; i < logitsSize; ++i) {
560
+ logits.push_back(fp16_ieee_to_fp32_value(logitBuf[i]));
561
+ }
562
+ if (_useDmabufIo) {
563
+ _ioTensor->afterReadFromBuffer(t_inputIds->tensor);
564
+ }
565
+
566
+ return logits.size() / logitsSize;
567
+ }
568
+
569
+ bool QnnGpuModel::reset() {
570
+ // Reset Token Counter
571
+ _numTokensProcessed = 0;
572
+
573
+ // Reset Attention Mask
574
+ uint32_t* attnMaskBuffer = (uint32_t*)getBuffer(t_attnMask);
575
+ uint32_t attnMaskSize = getBufferSize(t_attnMask);
576
+ if (attnMaskBuffer) {
577
+ if (_useDmabufIo) {
578
+ _ioTensor->beforeWriteToBuffer(t_attnMask->tensor);
579
+ }
580
+ memset(attnMaskBuffer, 0, attnMaskSize);
581
+ if (_useDmabufIo) {
582
+ _ioTensor->afterWriteToBuffer(t_attnMask->tensor);
583
+ }
584
+ }
585
+
586
+ // Reset KV Cache.
587
+ // TODO : Check if mask_neg -100 is enough to remove
588
+ // effect of KV Cache. Test with mask_neg = -float_inf
589
+ for (auto cache : _kvCache) {
590
+ if (_useDmabufIo) {
591
+ _ioTensor->beforeWriteToBuffer(t_inputIds->tensor);
592
+ }
593
+ char* buffer = (char*)getBuffer(cache.tensorUtil);
594
+ uint32_t bufferSize = getBufferSize(cache.tensorUtil);
595
+ memset(buffer, 0, bufferSize);
596
+ if (_useDmabufIo) {
597
+ _ioTensor->afterWriteToBuffer(t_inputIds->tensor);
598
+ }
599
+ }
600
+ return true;
601
+ }
602
+
603
+ } // namespace qualla
Genie/Genie/src/qualla/engines/qnn-gpu/gpu-model.hpp ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
2
+ // Confidential & Proprietary - Qualcomm Technologies, Inc. ("QTI")
3
+
4
+ #ifndef __QUALLA_QNN_GPU_MODEL_H_
5
+ #define __QUALLA_QNN_GPU_MODEL_H_
6
+
7
+ #include <atomic>
8
+ #include <filesystem>
9
+ #include <string>
10
+ #include <vector>
11
+
12
+ #include "IOTensor.hpp"
13
+ #include "QnnApi.hpp"
14
+ #include "qnn-utils.hpp"
15
+ #include "qualla/env.hpp"
16
+
17
+ namespace qualla {
18
+
19
+ // Maintain a list of named tensors for
20
+ static std::string INPUT_IDS = "input_ids";
21
+ static std::string ATTN_MASK = "attention_mask";
22
+ static std::string LOGITS = "logits";
23
+ static std::string POS_IDS = "position_ids";
24
+
25
+ class QnnGpuModel {
26
+ public:
27
+ struct Params {
28
+ std::filesystem::path model_basedir;
29
+ std::vector<std::string> model_list; // model filenames
30
+ uint32_t ctx_size;
31
+ uint32_t num_heads;
32
+ uint32_t head_dim;
33
+ };
34
+
35
+ struct GpuKVCache {
36
+ bool isKey;
37
+ uint32_t tensorId;
38
+ QnnUtils::Tensor* tensorUtil;
39
+
40
+ GpuKVCache() {
41
+ isKey = false;
42
+ tensorUtil = nullptr;
43
+ tensorId = 0;
44
+ }
45
+ GpuKVCache(bool _isKey, uint32_t _tensorId, QnnUtils::Tensor* _tensorUtil)
46
+ : isKey(_isKey), tensorId(_tensorId), tensorUtil(_tensorUtil) {}
47
+ };
48
+
49
+ // QNN specific variables
50
+ std::unique_ptr<QnnApi> _qnnApi;
51
+ std::unique_ptr<IOTensor> _ioTensor{nullptr};
52
+
53
+ // Model Location Storage
54
+ const std::filesystem::path _modelBaseDir;
55
+ std::vector<std::string> _modelList;
56
+ std::vector<std::string> _modelOrder;
57
+
58
+ bool _useDmabufIo;
59
+
60
+ // Model parameters
61
+ uint32_t _ctxSize{0};
62
+ uint32_t _numHeads{0};
63
+ uint32_t _headDim{0};
64
+
65
+ // Information regarding model execution settings and last inference
66
+
67
+ // Model specific variables
68
+ uint32_t _numGraphs;
69
+ // I/O Tensor Informations
70
+ std::unordered_map<std::string, Qnn_Tensor_t*> _inputTensors;
71
+ std::unordered_map<std::string,
72
+ std::unordered_map<std::string, std::shared_ptr<QnnUtils::Tensor>>>
73
+ _inputSpecs;
74
+
75
+ std::unordered_map<std::string, Qnn_Tensor_t*> _outputTensors;
76
+ std::unordered_map<std::string,
77
+ std::unordered_map<std::string, std::shared_ptr<QnnUtils::Tensor>>>
78
+ _outputSpecs;
79
+
80
+ // Store some pointers for easier access
81
+ QnnUtils::Tensor* t_inputIds{nullptr};
82
+ QnnUtils::Tensor* t_attnMask{nullptr};
83
+ QnnUtils::Tensor* t_positionIds{nullptr};
84
+ QnnUtils::Tensor* t_logits{nullptr};
85
+
86
+ // _numTokensProcessed defines number of population of kvcache
87
+ size_t _numTokensProcessed{0};
88
+
89
+ std::vector<GpuKVCache> _kvCache;
90
+
91
+ std::map<std::string, std::pair<double, uint16_t>> timeLogs;
92
+
93
+ // Model Constructor
94
+ QnnGpuModel(Env& env, const Params& params);
95
+ ~QnnGpuModel();
96
+
97
+ bool initializeModel(void);
98
+ bool initializeIOTensors(void);
99
+ void setupInputTensors(const std::vector<int32_t>& tokens);
100
+ bool initializeTensorPointers();
101
+ bool validateModel();
102
+
103
+ template <class T1, class T2>
104
+ inline bool executeModel(T1& input, T2& output, std::string graph_name);
105
+
106
+ size_t runInference(const std::vector<int32_t>& tokens,
107
+ std::vector<float>& logits,
108
+ bool logits_all = false);
109
+
110
+ size_t loadKVCache(const std::string& save_path);
111
+ bool saveKVCache(const std::string& load_path);
112
+ bool reset();
113
+
114
+ private:
115
+ Env& _env;
116
+ // Internal functions to separate different runInference logic
117
+ bool runInferenceHelper(std::vector<std::string>& exec_models,
118
+ int32_t* wait_time_total,
119
+ int32_t* exec_time_total,
120
+ bool pipeline_kv_update,
121
+ size_t update_size);
122
+ size_t processLogits(std::vector<float>& logits, bool logits_all);
123
+ inline void* getBuffer(QnnUtils::Tensor& spec) { return _ioTensor->getBuffer(spec.tensor); }
124
+ inline void* getBuffer(QnnUtils::Tensor* spec) { return _ioTensor->getBuffer(spec->tensor); }
125
+ inline size_t getBufferSize(QnnUtils::Tensor& spec) { return spec.dims.getSize(); }
126
+ inline size_t getBufferSize(QnnUtils::Tensor* spec) { return spec->dims.getSize(); }
127
+ inline size_t getNumElements(QnnUtils::Tensor& spec) { return spec.dims.getNumElements(); }
128
+ inline size_t getNumElements(QnnUtils::Tensor* spec) { return spec->dims.getNumElements(); }
129
+
130
+ // Parse KV$ Tensor names here - supports past_{key,value}_{layer_idx}[_h0]_{in,out}
131
+ std::tuple<int, int> parseKVTensorName(std::string name);
132
+ };
133
+
134
+ } // namespace qualla
135
+
136
+ #endif
Genie/Genie/src/qualla/engines/qnn-htp.cpp CHANGED
@@ -353,11 +353,11 @@ qualla::InputType NspEngine::getInputType(){
353
  return _model->m_inputType;
354
  }
355
 
356
- size_t NspEngine::restore(const std::string& name) {
357
  if (!_model && !load()) return 0;
358
 
359
  fs::path cache_path = std::filesystem::path(name) / fmt::format("kv-cache.{}.qnn-htp", _role);
360
- return _model->loadKVCache(cache_path.string());
361
  }
362
 
363
  bool NspEngine::save(const std::string& name) {
 
353
  return _model->m_inputType;
354
  }
355
 
356
+ size_t NspEngine::restore(const std::string& name, bool chooseHigherVariant) {
357
  if (!_model && !load()) return 0;
358
 
359
  fs::path cache_path = std::filesystem::path(name) / fmt::format("kv-cache.{}.qnn-htp", _role);
360
+ return _model->loadKVCache(cache_path.string(), chooseHigherVariant);
361
  }
362
 
363
  bool NspEngine::save(const std::string& name) {
Genie/Genie/src/qualla/engines/qnn-htp.hpp CHANGED
@@ -70,7 +70,7 @@ class NspEngine : public Engine {
70
  virtual bool updateKV(size_t n_past) override;
71
  virtual bool updateKV(size_t n_past, const std::vector<bool>& selected) override;
72
  virtual bool save(const std::string& name) override;
73
- virtual size_t restore(const std::string& name) override;
74
  virtual void reset() override;
75
 
76
  virtual bool set(qualla::json data) override;
 
70
  virtual bool updateKV(size_t n_past) override;
71
  virtual bool updateKV(size_t n_past, const std::vector<bool>& selected) override;
72
  virtual bool save(const std::string& name) override;
73
+ virtual size_t restore(const std::string& name, bool chooseHigherVariant) override;
74
  virtual void reset() override;
75
 
76
  virtual bool set(qualla::json data) override;
Genie/Genie/src/qualla/engines/qnn-htp/nsp-kvmanager.cpp CHANGED
@@ -338,7 +338,7 @@ bool NewNSPKVManager::registerPointerOffset() {
338
  return true;
339
  }
340
 
341
- bool NewNSPKVManager::updateState() {
342
  // clang-format off
343
  __TRACE("qnn-kv : graph[{}] updateState to AR-{}(n_past={}, ptr={})", _mgr_idx,
344
  _req_state.variant, _req_state.n_past, _req_state.ptr_offset);
@@ -354,9 +354,15 @@ bool NewNSPKVManager::updateState() {
354
  cache.output_buffer += cache.is_key ? _n_ctx * _bw : _n_ctx * _n_embed * _bw;
355
  }
356
  }
357
-
358
  _cur_state = _req_state;
 
 
359
  _counter = _callback_fn(_mgr_idx);
 
 
 
 
 
360
  return true;
361
  }
362
 
@@ -525,7 +531,7 @@ bool NewNSPKVManager::loadCache(
525
  }
526
 
527
  _req_state = {variant, n_valid, 0};
528
- updateState();
529
 
530
  return true;
531
  }
 
338
  return true;
339
  }
340
 
341
+ void NewNSPKVManager::updateKVCache(){
342
  // clang-format off
343
  __TRACE("qnn-kv : graph[{}] updateState to AR-{}(n_past={}, ptr={})", _mgr_idx,
344
  _req_state.variant, _req_state.n_past, _req_state.ptr_offset);
 
354
  cache.output_buffer += cache.is_key ? _n_ctx * _bw : _n_ctx * _n_embed * _bw;
355
  }
356
  }
 
357
  _cur_state = _req_state;
358
+ }
359
+ void NewNSPKVManager::updateKVDispatcher(){
360
  _counter = _callback_fn(_mgr_idx);
361
+ }
362
+
363
+ bool NewNSPKVManager::updateState() {
364
+ updateKVCache();
365
+ updateKVDispatcher();
366
  return true;
367
  }
368
 
 
531
  }
532
 
533
  _req_state = {variant, n_valid, 0};
534
+ updateKVCache();
535
 
536
  return true;
537
  }
Genie/Genie/src/qualla/engines/qnn-htp/nsp-kvmanager.hpp CHANGED
@@ -134,7 +134,8 @@ class NewNSPKVManager {
134
  int32_t n_heads
135
  );
136
  bool dumpCache(std::ofstream* fs, bool is_key, int32_t n_valid, int32_t n_heads);
137
-
 
138
  bool updateState();
139
  void runKVUpdateJob(int thread_idx); // Worker thread function
140
  void setTensorAllocInfo(std::map<std::string, std::pair<int, size_t>>* alloc_info) {
 
134
  int32_t n_heads
135
  );
136
  bool dumpCache(std::ofstream* fs, bool is_key, int32_t n_valid, int32_t n_heads);
137
+ void updateKVCache();
138
+ void updateKVDispatcher();
139
  bool updateState();
140
  void runKVUpdateJob(int thread_idx); // Worker thread function
141
  void setTensorAllocInfo(std::map<std::string, std::pair<int, size_t>>* alloc_info) {
Genie/Genie/src/qualla/engines/qnn-htp/nsp-model.cpp CHANGED
@@ -1960,6 +1960,9 @@ bool QnnNspModel::calculate_rope_embeddings(void) {
1960
  const size_t nmemb = m_ctx_size * m_pos_dim;
1961
  const int pos_bw = d_pos.bw();
1962
 
 
 
 
1963
  rope_sin = malloc(nmemb * pos_bw);
1964
  rope_cos = malloc(nmemb * pos_bw);
1965
 
@@ -1973,7 +1976,7 @@ bool QnnNspModel::calculate_rope_embeddings(void) {
1973
  std::vector<double> inv_freq(m_pos_dim);
1974
  const double exponent = 1.0 / static_cast<double>(m_pos_dim);
1975
  for (int j = 0; j < m_pos_dim; j++)
1976
- inv_freq[j] = 1.0 / pow(rope_theta, j * exponent);
1977
  double attention_factor = 1.0;
1978
  if (rope_scaling.rope_type == RopeScalingParams::ROPE_LLAMA3) {
1979
  // Implemented from HuggingFace
@@ -1991,7 +1994,7 @@ bool QnnNspModel::calculate_rope_embeddings(void) {
1991
  if (wavelen < high_freq_wavelen) // wavelen < high_freq_wavelen: do nothing
1992
  continue;
1993
  else if (wavelen > low_freq_wavelen) // wavelen > low_freq_wavelen: divide by factor
1994
- inv_freq[j] = 1.0 / static_cast<double>(factor * pow(rope_theta, j * exponent));
1995
  else { // otherwise: interpolate between the two, using a smooth factor
1996
  assert(low_freq_wavelen != high_freq_wavelen);
1997
  const double smooth =
@@ -2266,7 +2269,7 @@ void QnnNspModel::dumpTensorSpecs() {
2266
  }
2267
  }
2268
 
2269
- size_t QnnNspModel::loadKVCache(const std::string& load_path) {
2270
 
2271
  if(m_disableKvCache){
2272
  __ERROR("KV cache is disabled, loading KV cache is not allowed");
@@ -2308,7 +2311,8 @@ size_t QnnNspModel::loadKVCache(const std::string& load_path) {
2308
  // clang-format on
2309
 
2310
  const int32_t n_valid = static_cast<int32_t>(spec.update_size);
2311
- const int32_t variant = nsp_graph_count.begin()->first; // Set KVManager to smallest variant
 
2312
  _kv_dispatcher->setVariant(variant);
2313
 
2314
  // Lock, load KeyCache then ValueCache, unlock
 
1960
  const size_t nmemb = m_ctx_size * m_pos_dim;
1961
  const int pos_bw = d_pos.bw();
1962
 
1963
+ const double theta = m_positional_encoding.rope_params.theta;
1964
+ const RopeScalingParams& rope_scaling = m_positional_encoding.rope_params.rope_scaling;
1965
+
1966
  rope_sin = malloc(nmemb * pos_bw);
1967
  rope_cos = malloc(nmemb * pos_bw);
1968
 
 
1976
  std::vector<double> inv_freq(m_pos_dim);
1977
  const double exponent = 1.0 / static_cast<double>(m_pos_dim);
1978
  for (int j = 0; j < m_pos_dim; j++)
1979
+ inv_freq[j] = 1.0 / pow(theta, j * exponent);
1980
  double attention_factor = 1.0;
1981
  if (rope_scaling.rope_type == RopeScalingParams::ROPE_LLAMA3) {
1982
  // Implemented from HuggingFace
 
1994
  if (wavelen < high_freq_wavelen) // wavelen < high_freq_wavelen: do nothing
1995
  continue;
1996
  else if (wavelen > low_freq_wavelen) // wavelen > low_freq_wavelen: divide by factor
1997
+ inv_freq[j] = 1.0 / static_cast<double>(factor * pow(theta, j * exponent));
1998
  else { // otherwise: interpolate between the two, using a smooth factor
1999
  assert(low_freq_wavelen != high_freq_wavelen);
2000
  const double smooth =
 
2269
  }
2270
  }
2271
 
2272
+ size_t QnnNspModel::loadKVCache(const std::string& load_path, bool chooseHigherVariant) {
2273
 
2274
  if(m_disableKvCache){
2275
  __ERROR("KV cache is disabled, loading KV cache is not allowed");
 
2311
  // clang-format on
2312
 
2313
  const int32_t n_valid = static_cast<int32_t>(spec.update_size);
2314
+ int32_t variant = nsp_graph_count.begin()->first; // Set KVManager to smallest variant
2315
+ if(chooseHigherVariant) variant = nsp_graph_count.rbegin()->first; // Ideal for loading KV prefix cache
2316
  _kv_dispatcher->setVariant(variant);
2317
 
2318
  // Lock, load KeyCache then ValueCache, unlock
Genie/Genie/src/qualla/engines/qnn-htp/nsp-model.hpp CHANGED
@@ -54,14 +54,14 @@ struct RopeScalingParams {
54
  double low_freq_factor;
55
  double high_freq_factor;
56
  int original_max_position_embeddings;
57
- } llama3_params;
58
 
59
  struct {
60
  double factor;
61
  std::vector<double> long_factor;
62
  std::vector<double> short_factor;
63
  int original_max_position_embeddings;
64
- } longrope_params;
65
 
66
  RopeScalingParams() {}
67
  };
@@ -79,7 +79,7 @@ struct PositionalEncoding {
79
  int32_t dims;
80
  double theta;
81
  RopeScalingParams rope_scaling;
82
- } rope_params;
83
 
84
  PositionalEncoding() { type = ROPE; }
85
  };
@@ -265,10 +265,8 @@ class QnnNspModel {
265
  QnnUtils::Tensor* t_position_ids{nullptr};
266
  // PositionalEncodingType::ROPE variables
267
  int32_t m_pos_dim{-1}; // Dimension of positional embedding tensor (incl partial_factor)
268
- double rope_theta{10000.0}; // Base theta parameter for RoPE calculations
269
  void* rope_sin{nullptr}; // Pre-calculated RoPE sin table of size [ctx_size, m_pos_dim]
270
  void* rope_cos{nullptr}; // Pre-calculated RoPE cos table of size [ctx_size, m_pos_dim]
271
- RopeScalingParams rope_scaling; // RoPE scaling parameters
272
 
273
  QnnUtils::Tensor* t_position_ids_sin{nullptr};
274
  QnnUtils::Tensor* t_position_ids_cos{nullptr};
@@ -398,7 +396,7 @@ class QnnNspModel {
398
 
399
  bool debugOutputs(QnnUtils::Tensor* outTensor, std::string& outTensorName);
400
 
401
- size_t loadKVCache(const std::string& load_path);
402
  bool saveKVCache(const std::string& save_path);
403
  bool applyLoraStrength(const std::string& alpha_tensor_name, const float alpha_val);
404
  bool applyLoraAdapter(const std::string& lora_adapter_name);
 
54
  double low_freq_factor;
55
  double high_freq_factor;
56
  int original_max_position_embeddings;
57
+ } llama3_params {0};
58
 
59
  struct {
60
  double factor;
61
  std::vector<double> long_factor;
62
  std::vector<double> short_factor;
63
  int original_max_position_embeddings;
64
+ } longrope_params {0};
65
 
66
  RopeScalingParams() {}
67
  };
 
79
  int32_t dims;
80
  double theta;
81
  RopeScalingParams rope_scaling;
82
+ } rope_params {0};
83
 
84
  PositionalEncoding() { type = ROPE; }
85
  };
 
265
  QnnUtils::Tensor* t_position_ids{nullptr};
266
  // PositionalEncodingType::ROPE variables
267
  int32_t m_pos_dim{-1}; // Dimension of positional embedding tensor (incl partial_factor)
 
268
  void* rope_sin{nullptr}; // Pre-calculated RoPE sin table of size [ctx_size, m_pos_dim]
269
  void* rope_cos{nullptr}; // Pre-calculated RoPE cos table of size [ctx_size, m_pos_dim]
 
270
 
271
  QnnUtils::Tensor* t_position_ids_sin{nullptr};
272
  QnnUtils::Tensor* t_position_ids_cos{nullptr};
 
396
 
397
  bool debugOutputs(QnnUtils::Tensor* outTensor, std::string& outTensorName);
398
 
399
+ size_t loadKVCache(const std::string& load_path, bool chooseHigherVariant=false);
400
  bool saveKVCache(const std::string& save_path);
401
  bool applyLoraStrength(const std::string& alpha_tensor_name, const float alpha_val);
402
  bool applyLoraAdapter(const std::string& lora_adapter_name);
Genie/Genie/src/qualla/include/qualla/detail/basic-sampler.hpp CHANGED
@@ -39,6 +39,7 @@ class BasicSampler : public Sampler {
39
  virtual bool save(const std::string& name) override;
40
  virtual bool restore(const std::string& name) override;
41
  virtual void reset() override;
 
42
 
43
  protected:
44
  int32_t _process(std::span<const float> logits, std::vector<float>* probs_out, bool samp_tok);
 
39
  virtual bool save(const std::string& name) override;
40
  virtual bool restore(const std::string& name) override;
41
  virtual void reset() override;
42
+ virtual void applyConfig(const qualla::json& conf) override;
43
 
44
  protected:
45
  int32_t _process(std::span<const float> logits, std::vector<float>* probs_out, bool samp_tok);
Genie/Genie/src/qualla/include/qualla/dialog.hpp CHANGED
@@ -107,6 +107,7 @@ class Dialog : public State {
107
  Tokenizer& tokenizer() { return *_tokenizer; }
108
  Sampler& sampler(const std::string& role = "primary") { return *_sampler[role]; }
109
  Engine& engine(const std::string& role = "primary") { return *_engine[role]; }
 
110
 
111
  // Get latest KPIs.
112
  // Updates TPS, etc as needed.
 
107
  Tokenizer& tokenizer() { return *_tokenizer; }
108
  Sampler& sampler(const std::string& role = "primary") { return *_sampler[role]; }
109
  Engine& engine(const std::string& role = "primary") { return *_engine[role]; }
110
+ bool isSamplerPresent(std::string role) { return _sampler.find(role) != _sampler.end(); }
111
 
112
  // Get latest KPIs.
113
  // Updates TPS, etc as needed.
Genie/Genie/src/qualla/include/qualla/engine.hpp CHANGED
@@ -86,7 +86,7 @@ class Engine : public State {
86
  QUALLA_API virtual bool updateKV(size_t n_past, const std::vector<bool>& selected);
87
 
88
  QUALLA_API virtual bool save(const std::string& name);
89
- QUALLA_API virtual size_t restore(const std::string& name);
90
  QUALLA_API virtual void reset();
91
 
92
  QUALLA_API virtual bool cacheEosEmbedding(std::vector<uint8_t>& eosEmbedding);
 
86
  QUALLA_API virtual bool updateKV(size_t n_past, const std::vector<bool>& selected);
87
 
88
  QUALLA_API virtual bool save(const std::string& name);
89
+ QUALLA_API virtual size_t restore(const std::string& name, bool chooseHigherVariant=false);
90
  QUALLA_API virtual void reset();
91
 
92
  QUALLA_API virtual bool cacheEosEmbedding(std::vector<uint8_t>& eosEmbedding);
Genie/Genie/src/qualla/include/qualla/sampler.hpp CHANGED
@@ -54,6 +54,7 @@ class Sampler : public State {
54
  QUALLA_API virtual bool save(const std::string& name);
55
  QUALLA_API virtual bool restore(const std::string& name);
56
  QUALLA_API virtual void reset();
 
57
 
58
  // Get sampler type
59
  const std::string& type() const { return _type; }
 
54
  QUALLA_API virtual bool save(const std::string& name);
55
  QUALLA_API virtual bool restore(const std::string& name);
56
  QUALLA_API virtual void reset();
57
+ QUALLA_API virtual void applyConfig(const qualla::json& conf);
58
 
59
  // Get sampler type
60
  const std::string& type() const { return _type; }
Genie/Genie/src/qualla/sampler.cpp CHANGED
@@ -84,6 +84,10 @@ std::vector<int32_t> Sampler::process_multiple(
84
  return {-1};
85
  }
86
 
 
 
 
 
87
  // Sampler registry
88
 
89
  using Registry = std::unordered_map<std::string, Sampler::Creator>;
 
84
  return {-1};
85
  }
86
 
87
+ void Sampler::applyConfig(const qualla::json& conf) {
88
+ _env.logger().warn(fmt::format("Basic sampler supports this for now"));
89
+ }
90
+
91
  // Sampler registry
92
 
93
  using Registry = std::unordered_map<std::string, Sampler::Creator>;
Genie/Genie/src/qualla/samplers/basic.cpp CHANGED
@@ -221,4 +221,12 @@ static OnLoad regy([]() {
221
 
222
  void needBasicSampler() {}
223
 
 
 
 
 
 
 
 
 
224
  } // namespace qualla
 
221
 
222
  void needBasicSampler() {}
223
 
224
+ void BasicSampler::applyConfig(const json& conf) {
225
+ if (conf.contains("seed")) _seed = conf["seed"];
226
+ if (conf.contains("temp")) _temp = conf["temp"];
227
+
228
+ if (conf.contains("top-k")) _top_k = conf["top-k"];
229
+ if (conf.contains("top-p")) _top_p = conf["top-p"];
230
+ }
231
+
232
  } // namespace qualla
Genie/Genie/src/qualla/tokenizers/rust/Cargo.lock CHANGED
@@ -31,9 +31,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
31
 
32
  [[package]]
33
  name = "cc"
34
- version = "1.1.34"
35
  source = "registry+https://github.com/rust-lang/crates.io-index"
36
- checksum = "67b9470d453346108f93a59222a9a1a5724db32d0a4727b7ab7ace4b4d822dc9"
37
  dependencies = [
38
  "shlex",
39
  ]
@@ -190,9 +190,9 @@ dependencies = [
190
 
191
  [[package]]
192
  name = "itoa"
193
- version = "1.0.11"
194
  source = "registry+https://github.com/rust-lang/crates.io-index"
195
- checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b"
196
 
197
  [[package]]
198
  name = "lazy_static"
@@ -202,9 +202,9 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
202
 
203
  [[package]]
204
  name = "libc"
205
- version = "0.2.161"
206
  source = "registry+https://github.com/rust-lang/crates.io-index"
207
- checksum = "8e9489c2807c139ffd9c1794f4af0ebe86a828db53ecdc7fea2111d0fed085d1"
208
 
209
  [[package]]
210
  name = "log"
@@ -322,9 +322,9 @@ dependencies = [
322
 
323
  [[package]]
324
  name = "proc-macro2"
325
- version = "1.0.89"
326
  source = "registry+https://github.com/rust-lang/crates.io-index"
327
- checksum = "f139b0662de085916d1fb67d2b4169d1addddda1919e696f3252b740b629986e"
328
  dependencies = [
329
  "unicode-ident",
330
  ]
@@ -413,9 +413,9 @@ dependencies = [
413
 
414
  [[package]]
415
  name = "regex-automata"
416
- version = "0.4.8"
417
  source = "registry+https://github.com/rust-lang/crates.io-index"
418
- checksum = "368758f23274712b504848e9d5a6f010445cc8b87a7cdb4d7cbee666c1288da3"
419
  dependencies = [
420
  "aho-corasick",
421
  "memchr",
@@ -436,18 +436,18 @@ checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f"
436
 
437
  [[package]]
438
  name = "serde"
439
- version = "1.0.214"
440
  source = "registry+https://github.com/rust-lang/crates.io-index"
441
- checksum = "f55c3193aca71c12ad7890f1785d2b73e1b9f63a0bbc353c08ef26fe03fc56b5"
442
  dependencies = [
443
  "serde_derive",
444
  ]
445
 
446
  [[package]]
447
  name = "serde_derive"
448
- version = "1.0.214"
449
  source = "registry+https://github.com/rust-lang/crates.io-index"
450
- checksum = "de523f781f095e28fa605cdce0f8307e451cc0fd14e2eb4cd2e98a355b147766"
451
  dependencies = [
452
  "proc-macro2",
453
  "quote",
@@ -456,9 +456,9 @@ dependencies = [
456
 
457
  [[package]]
458
  name = "serde_json"
459
- version = "1.0.132"
460
  source = "registry+https://github.com/rust-lang/crates.io-index"
461
- checksum = "d726bfaff4b320266d395898905d0eba0345aae23b54aee3a737e260fd46db03"
462
  dependencies = [
463
  "itoa",
464
  "memchr",
@@ -498,9 +498,9 @@ checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
498
 
499
  [[package]]
500
  name = "syn"
501
- version = "2.0.87"
502
  source = "registry+https://github.com/rust-lang/crates.io-index"
503
- checksum = "25aa4ce346d03a6dcd68dd8b4010bcb74e54e62c90c573f394c46eae99aba32d"
504
  dependencies = [
505
  "proc-macro2",
506
  "quote",
@@ -509,18 +509,18 @@ dependencies = [
509
 
510
  [[package]]
511
  name = "thiserror"
512
- version = "1.0.66"
513
  source = "registry+https://github.com/rust-lang/crates.io-index"
514
- checksum = "5d171f59dbaa811dbbb1aee1e73db92ec2b122911a48e1390dfe327a821ddede"
515
  dependencies = [
516
  "thiserror-impl",
517
  ]
518
 
519
  [[package]]
520
  name = "thiserror-impl"
521
- version = "1.0.66"
522
  source = "registry+https://github.com/rust-lang/crates.io-index"
523
- checksum = "b08be0f17bd307950653ce45db00cd31200d82b624b36e181337d9c7d92765b5"
524
  dependencies = [
525
  "proc-macro2",
526
  "quote",
@@ -529,9 +529,9 @@ dependencies = [
529
 
530
  [[package]]
531
  name = "tokenizers"
532
- version = "0.20.1"
533
  source = "registry+https://github.com/rust-lang/crates.io-index"
534
- checksum = "b172ffa9a2e5c31bbddc940cd5725d933ced983a9333bbebc4c7eda3bbce1557"
535
  dependencies = [
536
  "aho-corasick",
537
  "derive_builder",
@@ -569,9 +569,9 @@ dependencies = [
569
 
570
  [[package]]
571
  name = "unicode-ident"
572
- version = "1.0.13"
573
  source = "registry+https://github.com/rust-lang/crates.io-index"
574
- checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe"
575
 
576
  [[package]]
577
  name = "unicode-normalization-alignments"
 
31
 
32
  [[package]]
33
  name = "cc"
34
+ version = "1.2.1"
35
  source = "registry+https://github.com/rust-lang/crates.io-index"
36
+ checksum = "fd9de9f2205d5ef3fd67e685b0df337994ddd4495e2a28d185500d0e1edfea47"
37
  dependencies = [
38
  "shlex",
39
  ]
 
190
 
191
  [[package]]
192
  name = "itoa"
193
+ version = "1.0.13"
194
  source = "registry+https://github.com/rust-lang/crates.io-index"
195
+ checksum = "540654e97a3f4470a492cd30ff187bc95d89557a903a2bbf112e2fae98104ef2"
196
 
197
  [[package]]
198
  name = "lazy_static"
 
202
 
203
  [[package]]
204
  name = "libc"
205
+ version = "0.2.164"
206
  source = "registry+https://github.com/rust-lang/crates.io-index"
207
+ checksum = "433bfe06b8c75da9b2e3fbea6e5329ff87748f0b144ef75306e674c3f6f7c13f"
208
 
209
  [[package]]
210
  name = "log"
 
322
 
323
  [[package]]
324
  name = "proc-macro2"
325
+ version = "1.0.91"
326
  source = "registry+https://github.com/rust-lang/crates.io-index"
327
+ checksum = "307e3004becf10f5a6e0d59d20f3cd28231b0e0827a96cd3e0ce6d14bc1e4bb3"
328
  dependencies = [
329
  "unicode-ident",
330
  ]
 
413
 
414
  [[package]]
415
  name = "regex-automata"
416
+ version = "0.4.9"
417
  source = "registry+https://github.com/rust-lang/crates.io-index"
418
+ checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908"
419
  dependencies = [
420
  "aho-corasick",
421
  "memchr",
 
436
 
437
  [[package]]
438
  name = "serde"
439
+ version = "1.0.215"
440
  source = "registry+https://github.com/rust-lang/crates.io-index"
441
+ checksum = "6513c1ad0b11a9376da888e3e0baa0077f1aed55c17f50e7b2397136129fb88f"
442
  dependencies = [
443
  "serde_derive",
444
  ]
445
 
446
  [[package]]
447
  name = "serde_derive"
448
+ version = "1.0.215"
449
  source = "registry+https://github.com/rust-lang/crates.io-index"
450
+ checksum = "ad1e866f866923f252f05c889987993144fb74e722403468a4ebd70c3cd756c0"
451
  dependencies = [
452
  "proc-macro2",
453
  "quote",
 
456
 
457
  [[package]]
458
  name = "serde_json"
459
+ version = "1.0.133"
460
  source = "registry+https://github.com/rust-lang/crates.io-index"
461
+ checksum = "c7fceb2473b9166b2294ef05efcb65a3db80803f0b03ef86a5fc88a2b85ee377"
462
  dependencies = [
463
  "itoa",
464
  "memchr",
 
498
 
499
  [[package]]
500
  name = "syn"
501
+ version = "2.0.89"
502
  source = "registry+https://github.com/rust-lang/crates.io-index"
503
+ checksum = "44d46482f1c1c87acd84dea20c1bf5ebff4c757009ed6bf19cfd36fb10e92c4e"
504
  dependencies = [
505
  "proc-macro2",
506
  "quote",
 
509
 
510
  [[package]]
511
  name = "thiserror"
512
+ version = "1.0.69"
513
  source = "registry+https://github.com/rust-lang/crates.io-index"
514
+ checksum = "b6aaf5339b578ea85b50e080feb250a3e8ae8cfcdff9a461c9ec2904bc923f52"
515
  dependencies = [
516
  "thiserror-impl",
517
  ]
518
 
519
  [[package]]
520
  name = "thiserror-impl"
521
+ version = "1.0.69"
522
  source = "registry+https://github.com/rust-lang/crates.io-index"
523
+ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1"
524
  dependencies = [
525
  "proc-macro2",
526
  "quote",
 
529
 
530
  [[package]]
531
  name = "tokenizers"
532
+ version = "0.20.3"
533
  source = "registry+https://github.com/rust-lang/crates.io-index"
534
+ checksum = "67b67c92f6d705e2a1d106fb0b28c696f9074901a9c656ee5d9f5de204c39bf7"
535
  dependencies = [
536
  "aho-corasick",
537
  "derive_builder",
 
569
 
570
  [[package]]
571
  name = "unicode-ident"
572
+ version = "1.0.14"
573
  source = "registry+https://github.com/rust-lang/crates.io-index"
574
+ checksum = "adb9e6ca4f869e1180728b7950e35922a7fc6397f7b641499e8f3ef06e50dc83"
575
 
576
  [[package]]
577
  name = "unicode-normalization-alignments"
Genie/Model/model.cpp CHANGED
@@ -179,8 +179,29 @@ MODEL_LIB_EXPORT ModelError_t QnnModel_GenAI_composeGraphs(Qnn_BackendHandle_t b
179
  (Qnn_Tensor_t)tin6),
180
  err);
181
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
  /* ADDING NODE FOR genAI */
183
- const char* inputs_genAI[] = {"x0", "x1", "x2", "x3", "x4", "x5"};
184
 
185
  Qnn_Tensor_t tout;
186
  tout.version = QNN_TENSOR_VERSION_1;
@@ -224,7 +245,7 @@ MODEL_LIB_EXPORT ModelError_t QnnModel_GenAI_composeGraphs(Qnn_BackendHandle_t b
224
  params, // Node Params
225
  numParams, // Num Node Params
226
  inputs_genAI, // Input Tensor Names
227
- 6, // Num Input Tensor Names
228
  outputs_genAI, // Output Tensors
229
  2 // Num Output Tensors
230
  ),
 
179
  (Qnn_Tensor_t)tin6),
180
  err);
181
 
182
+ uint32_t input6Dim[1] = {1};
183
+ Qnn_Tensor_t tin7;
184
+ tin7.version = QNN_TENSOR_VERSION_1;
185
+ tin7.v1.id = 0;
186
+ tin7.v1.name = "x6";
187
+ tin7.v1.type = QNN_TENSOR_TYPE_APP_WRITE;
188
+ tin7.v1.dataFormat = QNN_TENSOR_DATA_FORMAT_FLAT_BUFFER;
189
+ tin7.v1.dataType = QNN_DATATYPE_FLOAT_32;
190
+ tin7.v1.quantizeParams.encodingDefinition = QNN_DEFINITION_UNDEFINED;
191
+ tin7.v1.quantizeParams.quantizationEncoding = QNN_QUANTIZATION_ENCODING_UNDEFINED;
192
+ tin7.v1.quantizeParams.scaleOffsetEncoding = {.scale = 0.0000000000000000f,
193
+ .offset = 0};
194
+ tin7.v1.rank = 1;
195
+ tin7.v1.dimensions = input6Dim;
196
+ tin7.v1.memType = QNN_TENSORMEMTYPE_RAW;
197
+ tin7.v1.clientBuf = {.data = nullptr, .dataSize = 0};
198
+ VALIDATE(qnn_model.addTensor(
199
+ "x6", // Node Name
200
+ (Qnn_Tensor_t)tin7),
201
+ err);
202
+
203
  /* ADDING NODE FOR genAI */
204
+ const char* inputs_genAI[] = {"x0", "x1", "x2", "x3", "x4", "x5", "x6"};
205
 
206
  Qnn_Tensor_t tout;
207
  tout.version = QNN_TENSOR_VERSION_1;
 
245
  params, // Node Params
246
  numParams, // Num Node Params
247
  inputs_genAI, // Input Tensor Names
248
+ 7, // Num Input Tensor Names
249
  outputs_genAI, // Output Tensors
250
  2 // Num Output Tensors
251
  ),
Genie/configs/llama2-7b/llama2-7b-draft-htp-target-htp-spd.json CHANGED
@@ -43,7 +43,8 @@
43
  "cpu-mask": "0xe0",
44
  "kv-dim": 64,
45
  "kv-update-method": "SHIFT_CONCAT",
46
- "allow-async-init": false
 
47
  },
48
  "extensions": "htp_backend_ext_config.json"
49
  },
 
43
  "cpu-mask": "0xe0",
44
  "kv-dim": 64,
45
  "kv-update-method": "SHIFT_CONCAT",
46
+ "allow-async-init": false,
47
+ "enable-graph-switching": false
48
  },
49
  "extensions": "htp_backend_ext_config.json"
50
  },
Genie/configs/llama2-7b/llama2-7b-genaitransformer-lora.json ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "dialog" : {
3
+ "version" : 1,
4
+ "type" : "basic",
5
+ "stop-sequence" : [""],
6
+ "max-num-tokens" : 200,
7
+ "context" : {
8
+ "version" : 1,
9
+ "size": 512,
10
+ "n-vocab": 32000,
11
+ "bos-token": 1,
12
+ "eos-token": 2
13
+ },
14
+ "sampler" : {
15
+ "version" : 1,
16
+ "seed" : 100,
17
+ "temp" : 1.2,
18
+ "top-k" : 20,
19
+ "top-p" : 0.75,
20
+ "greedy" : false
21
+ },
22
+ "tokenizer" : {
23
+ "version" : 1,
24
+ "path" : "your/path/to/tokenizer_file.json"
25
+ },
26
+ "engine" : {
27
+ "version" : 1,
28
+ "n-threads" : 6,
29
+ "backend" : {
30
+ "version" : 1,
31
+ "type" : "QnnGenAiTransformer",
32
+ "QnnGenAiTransformer" : {
33
+ "version" : 1,
34
+ "n-layer": 32,
35
+ "n-embd": 4096,
36
+ "n-heads": 32
37
+ }
38
+ },
39
+ "model" : {
40
+ "version" : 1,
41
+ "type" : "library",
42
+ "library" : {
43
+ "version" : 1,
44
+ "model-bin" : "your/path/to/model/file.bin",
45
+ "lora": {
46
+ "version": 1,
47
+ "alpha-tensor-name": "alpha",
48
+ "adapters": [
49
+ {
50
+ "version": 1,
51
+ "name": "lora1",
52
+ "bin-sections": [
53
+ "your/path/to/model/lora/file.bin"
54
+ ]
55
+ }
56
+ ]
57
+ }
58
+ }
59
+ }
60
+ }
61
+ }
62
+ }
Genie/configs/llama2-7b/llama2-7b-genaitransformer.json CHANGED
@@ -30,7 +30,10 @@
30
  "version" : 1,
31
  "type" : "QnnGenAiTransformer",
32
  "QnnGenAiTransformer" : {
33
- "version" : 1
 
 
 
34
  }
35
  },
36
  "model" : {
 
30
  "version" : 1,
31
  "type" : "QnnGenAiTransformer",
32
  "QnnGenAiTransformer" : {
33
+ "version" : 1,
34
+ "n-layer": 32,
35
+ "n-embd": 4096,
36
+ "n-heads": 32
37
  }
38
  },
39
  "model" : {
Genie/configs/llama2-7b/llama2-7b-gpu.json ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "dialog" : {
3
+ "version" : 1,
4
+ "type" : "basic",
5
+ "context" : {
6
+ "version" : 1,
7
+ "size": 1024,
8
+ "n-vocab": 32000,
9
+ "bos-token": 1,
10
+ "eos-token": 2
11
+ },
12
+ "sampler" : {
13
+ "version" : 1,
14
+ "seed" : 42,
15
+ "temp" : 1.1,
16
+ "top-k" : 40,
17
+ "top-p" : 0.95,
18
+ "greedy" : false
19
+ },
20
+ "tokenizer" : {
21
+ "version" : 1,
22
+ "path" : "/path/to/tokenizer.json"
23
+ },
24
+ "engine" : {
25
+ "version" : 1,
26
+ "n-threads" : 3,
27
+ "backend" : {
28
+ "version" : 1,
29
+ "type" : "QnnGpu"
30
+ },
31
+ "model" : {
32
+ "version" : 1,
33
+ "type" : "binary",
34
+ "binary" : {
35
+ "version" : 1,
36
+ "ctx-bins" : [
37
+ "/path/to/model.bin"
38
+ ]
39
+ }
40
+ }
41
+ }
42
+ }
43
+ }