// Copyright 2025 The ODML Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_ODML_LITERT_LM_RUNTIME_UTIL_LITERT_LM_LOADER_H_
#define THIRD_PARTY_ODML_LITERT_LM_RUNTIME_UTIL_LITERT_LM_LOADER_H_
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <unordered_map>
#include <utility>
#include <variant>
#include "absl/log/absl_check.h" // from @com_google_absl
#include "absl/log/absl_log.h" // from @com_google_absl
#include "absl/status/status.h" // from @com_google_absl
#include "absl/status/statusor.h" // from @com_google_absl
#include "absl/synchronization/mutex.h" // from @com_google_absl
#include "litert/cc/litert_buffer_ref.h" // from @litert
#include "runtime/components/model_resources.h"
#include "runtime/util/memory_mapped_file.h"
#include "runtime/util/scoped_file.h"
#include "schema/core/litertlm_header_schema_generated.h"
#include "schema/core/litertlm_read.h"
namespace litert::lm {
// Each buffer is keyed by the data type as the major key and the model type
// as the optional secondary key when the data type is TFLiteModel or
// TFLiteWeights.
struct BufferKey {
schema::AnySectionDataType data_type;
std::optional<ModelType>
model_type; // This can be nullopt for data types
// other than TFLiteModel or TFLiteWeights!
// Constructor for common cases (no ModelType needed)
explicit BufferKey(schema::AnySectionDataType type)
: data_type(type), model_type(std::nullopt) {}
// Constructor for TFLiteModel or TFLiteWeights case
explicit BufferKey(schema::AnySectionDataType type, ModelType model_type)
: data_type(type), model_type(model_type) {
ABSL_CHECK(
(type == schema::AnySectionDataType_TFLiteModel ||
type == schema::AnySectionDataType_TFLiteWeights) &&
"ModelType should only be provided for TFLiteModel or TFLiteWeights");
}
// Equality operator (REQUIRED for std::unordered_map, good for std::map)
bool operator==(const BufferKey& other) const {
return data_type == other.data_type && model_type == other.model_type;
}
};
// Hash functor that lets BufferKey be used as a std::unordered_map key.
struct BufferKeyHash {
  // Combines the hash of the data type with the hash of the model type. An
  // absent model type contributes 0 to the combination.
  size_t operator()(const BufferKey& k) const {
    const size_t data_type_hash =
        std::hash<schema::AnySectionDataType>{}(k.data_type);
    const size_t model_type_hash =
        k.model_type.has_value() ? std::hash<ModelType>{}(k.model_type.value())
                                 : size_t{0};
    // A simple hash combine. For more robust hashing, consider
    // boost::hash_combine.
    return data_type_hash ^ (model_type_hash << 1);
  }
};
// A class to load the LiteRT LM model from the .litertlm file. The loader
// reads the model header from the file and maps the sections to the section
// buffers.
class LitertLmLoader {
 public:
  // Creates a LitertLmLoader from the model file. Takes ownership of
  // `model_file` and CHECK-fails if Initialize() returns a non-OK status.
  explicit LitertLmLoader(ScopedFile model_file)
      : model_source_(std::move(model_file)) {
    ABSL_CHECK_OK(Initialize());
  }

  // Creates a LitertLmLoader from an already memory-mapped model file.
  // This is useful when the file is managed externally.
  explicit LitertLmLoader(
      std::shared_ptr<MemoryMappedFile> memory_mapped_model_file);

  // Returns the tokenizer section buffer for the SentencePiece tokenizer.
  // If not found, returns std::nullopt.
  std::optional<litert::BufferRef<uint8_t>> GetSentencePieceTokenizer() {
    return GetSectionBuffer(BufferKey(schema::AnySectionDataType_SP_Tokenizer));
  }

  // Returns the tokenizer section buffer for the HuggingFace tokenizer.
  // If not found, returns std::nullopt.
  std::optional<litert::OwningBufferRef<uint8_t>> GetHuggingFaceTokenizer();

  // Returns the TFLite model section buffer for `model_type`. If the section
  // is not found, logs a warning and returns a default-constructed BufferRef.
  litert::BufferRef<uint8_t> GetTFLiteModel(ModelType model_type) {
    auto section_buffer = GetSectionBuffer(
        BufferKey(schema::AnySectionDataType_TFLiteModel, model_type));
    if (section_buffer.has_value()) {
      return *section_buffer;
    }
    ABSL_LOG(WARNING) << "TFLite model for type: "
                      << ModelTypeToString(model_type)
                      << " not found. Skipping.";
    return litert::BufferRef<uint8_t>();
  }

  // Returns the TFLite weights section buffer for `model_type`. If the section
  // is not found, logs a warning and returns a default-constructed BufferRef.
  litert::BufferRef<uint8_t> GetTFLiteWeights(ModelType model_type) {
    auto section_buffer = GetSectionBuffer(
        BufferKey(schema::AnySectionDataType_TFLiteWeights, model_type));
    if (section_buffer.has_value()) {
      return *section_buffer;
    }
    ABSL_LOG(WARNING) << "TFLite weights for type: "
                      << ModelTypeToString(model_type)
                      << " not found. Skipping.";
    return litert::BufferRef<uint8_t>();
  }

  // Returns the backend constraint recorded for the TFLite model section of
  // `model_type`. If no constraint is recorded, logs a warning and returns
  // std::nullopt.
  std::optional<std::string> GetTFLiteModelBackendConstraint(
      ModelType model_type) {
    // A single find() avoids the double lookup of contains() + operator[] and
    // can never insert an entry into the map.
    const auto constraint_it = section_backend_constraint_.find(
        BufferKey(schema::AnySectionDataType_TFLiteModel, model_type));
    if (constraint_it != section_backend_constraint_.end()) {
      return constraint_it->second;
    }
    ABSL_LOG(WARNING) << "TFLite model type: " << ModelTypeToString(model_type)
                      << " not found for backend constraints. Skipping.";
    return std::nullopt;
  }

  // Returns the LlmMetadataProto section buffer. CHECK-fails with a
  // descriptive message if the section is not found, since the return type
  // cannot express absence.
  litert::BufferRef<uint8_t> GetLlmMetadata() {
    auto metadata_buffer = GetSectionBuffer(
        BufferKey(schema::AnySectionDataType_LlmMetadataProto));
    ABSL_CHECK(metadata_buffer.has_value())
        << "LlmMetadataProto section not found in the model file.";
    return *metadata_buffer;
  }

  // Looks up the location of the section identified by `buffer_key`. The pair
  // presumably holds (begin_offset, end_offset) as stored in
  // section_locations_ -- confirm against the definition.
  absl::StatusOr<std::pair<size_t, size_t>> GetSectionLocation(
      BufferKey buffer_key) const;

  // Returns a reference to the underlying ScopedFile, or an error status.
  // NOTE(review): presumably fails when the loader was created from a
  // memory-mapped file -- confirm against the definition.
  absl::StatusOr<std::reference_wrapper<ScopedFile>> GetScopedFile();

 private:
  // Initializes the LitertLmLoader. Includes reading the model header and
  // recording the section locations for on-demand loading later.
  absl::Status Initialize();

  // Maps the section identified by `buffer_key` that spans
  // [begin_offset, end_offset) in the model file. The caller must hold
  // section_buffers_mutex_.
  // NOTE(review): half-open range is assumed -- confirm against the definition.
  absl::Status MapSection(BufferKey buffer_key, uint64_t begin_offset,
                          uint64_t end_offset)
      ABSL_EXCLUSIVE_LOCKS_REQUIRED(section_buffers_mutex_);

  // Returns the section buffer for the given buffer key. Will map the section
  // if it has not been mapped yet. If not found, returns std::nullopt. The
  // caller must not hold section_buffers_mutex_.
  std::optional<litert::BufferRef<uint8_t>> GetSectionBuffer(
      BufferKey buffer_key) ABSL_LOCKS_EXCLUDED(section_buffers_mutex_);

  // The model file to be loaded, can be either a ScopedFile or a
  // memory-mapped file.
  std::variant<ScopedFile, std::shared_ptr<MemoryMappedFile>> model_source_;

  // The header of the model file. Use this to understand what sections are
  // available and their offsets.
  schema::LitertlmHeader header_;

  // The section locations in the model file. This is populated during
  // initialization and later used to map the section buffers to the section
  // memory mapped files on-demand.
  ::std::unordered_map<
      BufferKey, std::pair</*begin_offset=*/uint64_t, /*end_offset=*/uint64_t>,
      BufferKeyHash>
      section_locations_;

  // Guards section_memory_mapped_files_ and section_buffers_.
  absl::Mutex section_buffers_mutex_;

  // The section memory mapped files - stored here to ensure they are not
  // unmapped while in use. On Windows, these MemoryMappedFiles may contain more
  // than the current section's data because Windows has a data alignment of
  // 64KB but the LiteRT LM file has a 16KB alignment.
  ::std::unordered_map<BufferKey, std::unique_ptr<MemoryMappedFile>,
                       BufferKeyHash>
      section_memory_mapped_files_ ABSL_GUARDED_BY(section_buffers_mutex_);

  // The section buffers. Unlike the section_memory_mapped_files_, these
  // buffers point to only the data of each section, even on Windows.
  ::std::unordered_map<BufferKey, litert::BufferRef<uint8_t>, BufferKeyHash>
      section_buffers_ ABSL_GUARDED_BY(section_buffers_mutex_);

  // Map of all the sections' metadata; for now, focusing on the backend
  // constraints.
  ::std::unordered_map<BufferKey, std::string, BufferKeyHash>
      section_backend_constraint_;
};
} // namespace litert::lm
#endif // THIRD_PARTY_ODML_LITERT_LM_RUNTIME_UTIL_LITERT_LM_LOADER_H_