// Copyright 2025 The ODML Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #ifndef THIRD_PARTY_ODML_LITERT_LM_RUNTIME_UTIL_LITERT_LM_LOADER_H_ #define THIRD_PARTY_ODML_LITERT_LM_RUNTIME_UTIL_LITERT_LM_LOADER_H_ #include #include #include #include #include #include #include #include #include #include #include "absl/log/absl_check.h" // from @com_google_absl #include "absl/log/absl_log.h" // from @com_google_absl #include "absl/status/status.h" // from @com_google_absl #include "absl/status/statusor.h" // from @com_google_absl #include "absl/synchronization/mutex.h" // from @com_google_absl #include "litert/cc/litert_buffer_ref.h" // from @litert #include "runtime/components/model_resources.h" #include "runtime/util/memory_mapped_file.h" #include "runtime/util/scoped_file.h" #include "schema/core/litertlm_header_schema_generated.h" #include "schema/core/litertlm_read.h" namespace litert::lm { // Each buffer is keyed by the data type as the major key and the model type // as the optional secondary key when the data type is TFLiteModel or // TFLiteWeights. struct BufferKey { schema::AnySectionDataType data_type; std::optional model_type; // This can be nullopt for data types // other than TFLiteModel or TFLiteWeights! // Constructor for common cases (no ModelType needed) explicit BufferKey(schema::AnySectionDataType type) : data_type(type), model_type(std::nullopt) {} // Constructor for TFLiteModel or TFLiteWeights case explicit BufferKey(schema::AnySectionDataType type, ModelType model_type) : data_type(type), model_type(model_type) { ABSL_CHECK( (type == schema::AnySectionDataType_TFLiteModel || type == schema::AnySectionDataType_TFLiteWeights) && "ModelType should only be provided for TFLiteModel or TFLiteWeights"); } // Equality operator (REQUIRED for std::unordered_map, good for std::map) bool operator==(const BufferKey& other) const { return data_type == other.data_type && model_type == other.model_type; } }; // Hash function for BufferKey struct BufferKeyHash { size_t operator()(const BufferKey& k) const { size_t h1 = std::hash{}(k.data_type); size_t h2 = 0; if (k.model_type.has_value()) { h2 = std::hash{}(k.model_type.value()); } // A simple hash combine. For more robust hashing, consider // boost::hash_combine return h1 ^ (h2 << 1); } }; // A class to load the Litert LM model from the .litertlm file. The loader will // read the model header from and map the sections to the section buffers. class LitertLmLoader { public: // Creates a LitertLmLoader from the model file. The loader will read the // model header from and map the sections to the section buffers. explicit LitertLmLoader(ScopedFile model_file) : model_source_(std::move(model_file)) { ABSL_CHECK_OK(Initialize()); } // Creates a LitertLmLoader from an already memory-mapped model file. // This is useful when the file is managed externally. explicit LitertLmLoader( std::shared_ptr memory_mapped_model_file); // Returns the tokenizer section buffer for the SentencePiece tokenizer. // If not found, returns std::nullopt. std::optional> GetSentencePieceTokenizer() { return GetSectionBuffer(BufferKey(schema::AnySectionDataType_SP_Tokenizer)); } // Returns the tokenizer section buffer for the HuggingFace tokenizer. // If not found, returns std::nullopt. std::optional> GetHuggingFaceTokenizer(); // Returns the TFLite model section buffer. litert::BufferRef GetTFLiteModel(ModelType model_type) { auto optional_section_buffer = GetSectionBuffer( BufferKey(schema::AnySectionDataType_TFLiteModel, model_type)); if (optional_section_buffer.has_value()) { return optional_section_buffer.value(); } ABSL_LOG(WARNING) << "TFLite model for type: " << ModelTypeToString(model_type) << " not found. Skipping."; return litert::BufferRef(); }; litert::BufferRef GetTFLiteWeights(ModelType model_type) { auto optional_section_buffer = GetSectionBuffer( BufferKey(schema::AnySectionDataType_TFLiteWeights, model_type)); if (optional_section_buffer.has_value()) { return optional_section_buffer.value(); } ABSL_LOG(WARNING) << "TFLite weights for type: " << ModelTypeToString(model_type) << " not found. Skipping."; return litert::BufferRef(); }; // Returns the TFLite model section buffer. std::optional GetTFLiteModelBackendConstraint( ModelType model_type) { if (section_backend_constraint_.contains( BufferKey(schema::AnySectionDataType_TFLiteModel, model_type))) { return section_backend_constraint_[BufferKey( schema::AnySectionDataType_TFLiteModel, model_type)]; } ABSL_LOG(WARNING) << "TFLite model type: " << ModelTypeToString(model_type) << " not found for backend constraints. Skipping."; return std::nullopt; }; // Returns the tokenizer section buffer. litert::BufferRef GetLlmMetadata() { return GetSectionBuffer( BufferKey(schema::AnySectionDataType_LlmMetadataProto)) .value(); } absl::StatusOr> GetSectionLocation( BufferKey buffer_key) const; absl::StatusOr> GetScopedFile(); private: // Initializes the LitertLmLoader. Includes reading the model header and // recording the section locations for on-demand loading later. absl::Status Initialize(); absl::Status MapSection(BufferKey buffer_key, uint64_t begin_offset, uint64_t end_offset) ABSL_EXCLUSIVE_LOCKS_REQUIRED(section_buffers_mutex_); // Returns the section buffer for the given buffer key. Will map the section // if it has not been mapped yet. If not found, returns std::nullopt. std::optional> GetSectionBuffer( BufferKey buffer_key) ABSL_LOCKS_EXCLUDED(section_buffers_mutex_); // The model file to be loaded, can be either a ScopedFile or a // memory-mapped file. std::variant> model_source_; // The header of the model file. Use this to understand what sections are // available and their offsets. schema::LitertlmHeader header_; // The section locations in the model file. This is populated during // initialization and later used to map the section buffers to the section // memory mapped files on-demand. ::std::unordered_map< BufferKey, std::pair, BufferKeyHash> section_locations_; absl::Mutex section_buffers_mutex_; // The section memory mapped files - stored here to ensure they are not // unmapped while in use. On Windows, these MemoryMappedFiles may contain more // than the current section's data because Windows has a data alignment of // 64KB but the LiteRT LM file has a 16KB alignment. ::std::unordered_map, BufferKeyHash> section_memory_mapped_files_ ABSL_GUARDED_BY(section_buffers_mutex_); // The section buffers. Unlike the section_memory_mapped_files_, these // buffers point to only the data of the each section, even on Windows. ::std::unordered_map, BufferKeyHash> section_buffers_ ABSL_GUARDED_BY(section_buffers_mutex_); // Map of all the sections' metadata, for now, focusing on the backend // constraints ::std::unordered_map section_backend_constraint_; }; } // namespace litert::lm #endif // THIRD_PARTY_ODML_LITERT_LM_RUNTIME_UTIL_LITERT_LM_LOADER_H_