// Copyright 2025 The ODML Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_ODML_LITERT_LM_RUNTIME_COMPONENTS_MODEL_RESOURCES_H_
#define THIRD_PARTY_ODML_LITERT_LM_RUNTIME_COMPONENTS_MODEL_RESOURCES_H_
// All the loaded model resources the executor needs to hold to avoid the model
// being destroyed.
#include <cstddef>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include "absl/status/status.h" // from @com_google_absl
#include "absl/status/statusor.h" // from @com_google_absl
#include "absl/strings/ascii.h" // from @com_google_absl
#include "absl/strings/str_cat.h" // from @com_google_absl
#include "absl/strings/string_view.h" // from @com_google_absl
#include "litert/cc/litert_model.h" // from @litert
#include "runtime/components/tokenizer.h"
#include "runtime/proto/llm_metadata.pb.h"
#include "runtime/util/scoped_file.h"
namespace litert::lm {
// The kinds of model that can be requested by type.
// Every enumerator carries an explicit numeric value; they are listed here in
// ascending order of that value.
enum class ModelType {
  // Placeholder for uninitialized model type.
  kUnknown = 0,
  // The base model is used for prefill and decode.
  kTfLitePrefillDecode = 1,
  kTfLiteEmbedder = 2,
  kTfLitePerLayerEmbedder = 3,
  kTfLiteAux = 4,
  kTfLiteAudioEncoderHw = 5,
  kTfLiteEndOfAudio = 6,
  kTfLiteVisionAdapter = 7,
  kTfLiteVisionEncoder = 8,
  // Audio frontend model is weight-less Short-Time Fourier Transform (STFT)
  // to convert audio to spectrogram.
  kTfLiteAudioFrontend = 9,
  kTfLiteAudioAdapter = 10,
  // The text decoder model for the artisan gpu.
  kArtisanTextDecoder = 11,
  // The end of vision token model.
  kTfLiteEndOfVision = 12,
  // The MTP drafter model.
  kTfLiteMtpDrafter = 13,
};
// Parses a model type name into its ModelType value. Matching ignores ASCII
// case. Names that are not recognized produce an InvalidArgument error.
inline absl::StatusOr<ModelType> StringToModelType(
    absl::string_view model_type_str) {
  // One accepted (lower-case) spelling and the value it maps to.
  struct NameEntry {
    const char* name;
    ModelType type;
  };
  static constexpr NameEntry kKnownNames[] = {
      {"tf_lite_prefill_decode", ModelType::kTfLitePrefillDecode},
      {"tf_lite_embedder", ModelType::kTfLiteEmbedder},
      {"tf_lite_per_layer_embedder", ModelType::kTfLitePerLayerEmbedder},
      {"tf_lite_aux", ModelType::kTfLiteAux},
      {"tf_lite_audio_frontend", ModelType::kTfLiteAudioFrontend},
      {"tf_lite_audio_encoder_hw", ModelType::kTfLiteAudioEncoderHw},
      {"tf_lite_audio_adapter", ModelType::kTfLiteAudioAdapter},
      {"tf_lite_end_of_audio", ModelType::kTfLiteEndOfAudio},
      {"tf_lite_vision_adapter", ModelType::kTfLiteVisionAdapter},
      {"tf_lite_end_of_vision", ModelType::kTfLiteEndOfVision},
      {"tf_lite_vision_encoder", ModelType::kTfLiteVisionEncoder},
      {"tf_lite_artisan_text_decoder", ModelType::kArtisanTextDecoder},
      {"tf_lite_mtp_drafter", ModelType::kTfLiteMtpDrafter},
  };

  const std::string normalized = absl::AsciiStrToLower(model_type_str);
  for (const NameEntry& entry : kKnownNames) {
    if (normalized == entry.name) {
      return entry.type;
    }
  }
  // The error carries the caller's original spelling, not the lower-cased
  // copy used for matching.
  return absl::InvalidArgumentError(
      absl::StrCat("Unknown model type: ", model_type_str));
}
// Returns the upper-case name of `model_type`. kUnknown yields "UNKNOWN"; a
// value that matches no enumerator yields "INVALID".
inline std::string ModelTypeToString(ModelType model_type) {
  // Start from the fallback so that a value matching no case label keeps it.
  const char* name = "INVALID";
  switch (model_type) {
    case ModelType::kUnknown:
      name = "UNKNOWN";
      break;
    case ModelType::kTfLitePrefillDecode:
      name = "TF_LITE_PREFILL_DECODE";
      break;
    case ModelType::kTfLiteEmbedder:
      name = "TF_LITE_EMBEDDER";
      break;
    case ModelType::kTfLitePerLayerEmbedder:
      name = "TF_LITE_PER_LAYER_EMBEDDER";
      break;
    case ModelType::kTfLiteAux:
      name = "TF_LITE_AUX";
      break;
    case ModelType::kTfLiteAudioFrontend:
      name = "TF_LITE_AUDIO_FRONTEND";
      break;
    case ModelType::kTfLiteAudioEncoderHw:
      name = "TF_LITE_AUDIO_ENCODER_HW";
      break;
    case ModelType::kTfLiteAudioAdapter:
      name = "TF_LITE_AUDIO_ADAPTER";
      break;
    case ModelType::kTfLiteEndOfAudio:
      name = "TF_LITE_END_OF_AUDIO";
      break;
    case ModelType::kTfLiteVisionAdapter:
      name = "TF_LITE_VISION_ADAPTER";
      break;
    case ModelType::kTfLiteEndOfVision:
      name = "TF_LITE_END_OF_VISION";
      break;
    case ModelType::kTfLiteVisionEncoder:
      name = "TF_LITE_VISION_ENCODER";
      break;
    case ModelType::kArtisanTextDecoder:
      name = "TF_LITE_ARTISAN_TEXT_DECODER";
      break;
    case ModelType::kTfLiteMtpDrafter:
      name = "TF_LITE_MTP_DRAFTER";
      break;
  }
  return name;
}
// ModelResources is an interface that manages all the loaded model resources
// that need to be held to avoid the model being destroyed. It provides a way
// to load the models lazily: a model is created the first time it is actually
// used. Before the Get*() functions are called, the models are not created
// yet; once the models are created, they are re-used for all the following
// calls.
//
// It's not thread-safe.
class ModelResources {
public:
virtual ~ModelResources() = default;
// Returns the litert model for `model_type`, creating it if it has not been
// created yet. The model is created from a memory mapped file, so physical
// memory is only allocated when the model is actually used.
virtual absl::StatusOr<const litert::Model*> GetTFLiteModel(
ModelType model_type) = 0;
// Returns the TFLite model buffer for `model_type`. Note that the returned
// string_view is valid only until the ModelResources is destroyed.
// When there is no model for the given model type, an error status is
// returned.
// Prefer GetTFLiteModel() where possible, as this function leaves the model
// lifecycle management to the caller.
virtual absl::StatusOr<absl::string_view> GetTFLiteModelBuffer(
ModelType model_type) = 0;
// Returns a reference to the ScopedFile. This is used for getting the
// external weights that should not be mmapped into memory.
virtual absl::StatusOr<std::reference_wrapper<ScopedFile>>
GetScopedFile() = 0;
// Returns the section start offset and end offset for `model_type`, as a
// (start, end) pair.
virtual absl::StatusOr<std::pair<size_t, size_t>> GetWeightsSectionOffset(
ModelType model_type) = 0;
// Returns the TFLite model backend constraint for `model_type`. When there is
// no constraint for the given model type, std::nullopt is returned.
virtual std::optional<std::string> GetTFLiteModelBackendConstraint(
ModelType model_type) = 0;
// Builds a tokenizer instance from the model and returns it. The instance is
// handed back in a std::unique_ptr, so the caller owns it.
virtual absl::StatusOr<std::unique_ptr<Tokenizer>> GetTokenizer() = 0;
// Returns the llm metadata.
// NOTE(review): the returned pointer is presumably owned by this object and
// valid only while it is alive -- confirm against the implementations.
virtual absl::StatusOr<const proto::LlmMetadata*> GetLlmMetadata() = 0;
};
} // namespace litert::lm
#endif // THIRD_PARTY_ODML_LITERT_LM_RUNTIME_COMPONENTS_MODEL_RESOURCES_H_