Spaces:
Running
Running
| // Copyright 2025 The ODML Authors. | |
| // | |
| // Licensed under the Apache License, Version 2.0 (the "License"); | |
| // you may not use this file except in compliance with the License. | |
| // You may obtain a copy of the License at | |
| // | |
| // http://www.apache.org/licenses/LICENSE-2.0 | |
| // | |
| // Unless required by applicable law or agreed to in writing, software | |
| // distributed under the License is distributed on an "AS IS" BASIS, | |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| // See the License for the specific language governing permissions and | |
| // limitations under the License. | |
| // All the loaded model resources the executor needs to hold to avoid the model | |
| // being destroyed. | |
| namespace litert::lm { | |
| enum class ModelType { | |
| kUnknown = 0, // Placeholder for uninitialized model type. | |
| kTfLitePrefillDecode = 1, // The base model is used for prefill and decode. | |
| kTfLiteEmbedder = 2, | |
| kTfLitePerLayerEmbedder = 3, | |
| kTfLiteAux = 4, | |
| kTfLiteAudioFrontend = 9, // Audio frontend model is weight-less Short-Time | |
| // Fourier Transform (STFT) to convert audio to | |
| // spectrogram. | |
| kTfLiteAudioEncoderHw = 5, | |
| kTfLiteAudioAdapter = 10, | |
| kTfLiteEndOfAudio = 6, | |
| kTfLiteVisionAdapter = 7, | |
| kTfLiteEndOfVision = 12, // The end of vision token model. | |
| kTfLiteVisionEncoder = 8, | |
| kArtisanTextDecoder = 11, // The text decoder model for the artisan gpu. | |
| kTfLiteMtpDrafter = 13, // The MTP drafter model. | |
| }; | |
| // Utility function to convert a string to ModelType. It's case insensitive. | |
| inline absl::StatusOr<ModelType> StringToModelType( | |
| absl::string_view model_type_str) { | |
| const std::string lower_case_model_type_str = | |
| absl::AsciiStrToLower(model_type_str); | |
| if (lower_case_model_type_str == "tf_lite_prefill_decode") { | |
| return ModelType::kTfLitePrefillDecode; | |
| } else if (lower_case_model_type_str == "tf_lite_embedder") { | |
| return ModelType::kTfLiteEmbedder; | |
| } else if (lower_case_model_type_str == "tf_lite_per_layer_embedder") { | |
| return ModelType::kTfLitePerLayerEmbedder; | |
| } else if (lower_case_model_type_str == "tf_lite_aux") { | |
| return ModelType::kTfLiteAux; | |
| } else if (lower_case_model_type_str == "tf_lite_audio_frontend") { | |
| return ModelType::kTfLiteAudioFrontend; | |
| } else if (lower_case_model_type_str == "tf_lite_audio_encoder_hw") { | |
| return ModelType::kTfLiteAudioEncoderHw; | |
| } else if (lower_case_model_type_str == "tf_lite_audio_adapter") { | |
| return ModelType::kTfLiteAudioAdapter; | |
| } else if (lower_case_model_type_str == "tf_lite_end_of_audio") { | |
| return ModelType::kTfLiteEndOfAudio; | |
| } else if (lower_case_model_type_str == "tf_lite_vision_adapter") { | |
| return ModelType::kTfLiteVisionAdapter; | |
| } else if (lower_case_model_type_str == "tf_lite_end_of_vision") { | |
| return ModelType::kTfLiteEndOfVision; | |
| } else if (lower_case_model_type_str == "tf_lite_vision_encoder") { | |
| return ModelType::kTfLiteVisionEncoder; | |
| } else if (lower_case_model_type_str == "tf_lite_artisan_text_decoder") { | |
| return ModelType::kArtisanTextDecoder; | |
| } else if (lower_case_model_type_str == "tf_lite_mtp_drafter") { | |
| return ModelType::kTfLiteMtpDrafter; | |
| } else { | |
| return absl::InvalidArgumentError( | |
| absl::StrCat("Unknown model type: ", model_type_str)); | |
| } | |
| } | |
| // Utility function to convert a ModelType to string. | |
| inline std::string ModelTypeToString(ModelType model_type) { | |
| switch (model_type) { | |
| case ModelType::kTfLitePrefillDecode: | |
| return "TF_LITE_PREFILL_DECODE"; | |
| case ModelType::kTfLiteEmbedder: | |
| return "TF_LITE_EMBEDDER"; | |
| case ModelType::kTfLitePerLayerEmbedder: | |
| return "TF_LITE_PER_LAYER_EMBEDDER"; | |
| case ModelType::kTfLiteAux: | |
| return "TF_LITE_AUX"; | |
| case ModelType::kTfLiteAudioFrontend: | |
| return "TF_LITE_AUDIO_FRONTEND"; | |
| case ModelType::kTfLiteAudioEncoderHw: | |
| return "TF_LITE_AUDIO_ENCODER_HW"; | |
| case ModelType::kTfLiteAudioAdapter: | |
| return "TF_LITE_AUDIO_ADAPTER"; | |
| case ModelType::kTfLiteEndOfAudio: | |
| return "TF_LITE_END_OF_AUDIO"; | |
| case ModelType::kTfLiteVisionAdapter: | |
| return "TF_LITE_VISION_ADAPTER"; | |
| case ModelType::kTfLiteEndOfVision: | |
| return "TF_LITE_END_OF_VISION"; | |
| case ModelType::kTfLiteVisionEncoder: | |
| return "TF_LITE_VISION_ENCODER"; | |
| case ModelType::kArtisanTextDecoder: | |
| return "TF_LITE_ARTISAN_TEXT_DECODER"; | |
| case ModelType::kTfLiteMtpDrafter: | |
| return "TF_LITE_MTP_DRAFTER"; | |
| case ModelType::kUnknown: | |
| return "UNKNOWN"; | |
| default: | |
| return "INVALID"; | |
| } | |
| } | |
| // ModelResources is an interface that manages all the loaded model resources | |
| // that need to be hold to avoid the model being destroyed. It provides a way | |
| // to load the models in a lazy way. | |
| // Basically, it will create the models when they are actually used. Before the | |
| // Get*() functions are called, the models are not created yet. And once the | |
| // models are created, they will be re-used for all the following calls. | |
| // | |
| // It's not thread-safe. | |
| class ModelResources { | |
| public: | |
| virtual ~ModelResources() = default; | |
| // Returns the litert model. We will create the model if it is not created | |
| // yet. And the model is created from memory mapped file, so physical memory | |
| // is only allocated when the model is actually used. | |
| virtual absl::StatusOr<const litert::Model*> GetTFLiteModel( | |
| ModelType model_type) = 0; | |
| // Returns the TFLite model buffer. Note that the returned string_view is | |
| // valid only until the ModelResources is destroyed. | |
| // When there is no model for the given model type, it will return an error | |
| // status. | |
| // Prefer to use GetTFLiteModel() if possible, as this function will leave | |
| // the model lifecycle management to the caller. | |
| virtual absl::StatusOr<absl::string_view> GetTFLiteModelBuffer( | |
| ModelType model_type) = 0; | |
| // Returns the reference to the ScopedFile. This is used for the getting the | |
| // external weights that should not be mmapped into the memory. | |
| virtual absl::StatusOr<std::reference_wrapper<ScopedFile>> | |
| GetScopedFile() = 0; | |
| // Returns the section start offset and end offset. | |
| virtual absl::StatusOr<std::pair<size_t, size_t>> GetWeightsSectionOffset( | |
| ModelType model_type) = 0; | |
| // Returns the TFLite model backend constraint. When there is no constraint | |
| // for the given model type, it will return an nullopt. | |
| virtual std::optional<std::string> GetTFLiteModelBackendConstraint( | |
| ModelType model_type) = 0; | |
| // Builds a tokenizer instance from the model and returns it. | |
| virtual absl::StatusOr<std::unique_ptr<Tokenizer>> GetTokenizer() = 0; | |
| // Returns the llm metadata. | |
| virtual absl::StatusOr<const proto::LlmMetadata*> GetLlmMetadata() = 0; | |
| }; | |
| } // namespace litert::lm | |