File size: 2,929 Bytes
5f923cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
// Copyright 2025 The ODML Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_ODML_LITERT_LM_RUNTIME_COMPONENTS_TOKEN_ID_UTIL_H_
#define THIRD_PARTY_ODML_LITERT_LM_RUNTIME_COMPONENTS_TOKEN_ID_UTIL_H_

#include <vector>

#include "absl/status/status.h"  // from @com_google_absl
#include "absl/status/statusor.h"  // from @com_google_absl
#include "absl/types/span.h"  // from @com_google_absl

namespace litert::lm {

// Preprocesses the token ids before feeding them to the LLM.
// - token_ids: The token ids to be preprocessed. Modified in place (non-const
//   reference).
// - start_token_id: The token id of the start token. It is used to be prepended
//   to the token ids.
// - max_num_tokens: The maximum number of tokens the model can hold (i.e. the
//   length of the kv-cache). Setting it to be 0 means no limit.
// - context_length_ratio_threshold: The threshold of the ratio of input
//   context length to the max_num_tokens. If the ratio is larger than the
//   threshold, the function will return an error, indicating that there is
//   not enough space for the model to generate the output. For example, if
//   the max_num_tokens is 1024 and the context_length_ratio_threshold is 0.9,
//   the function will return an error if the input token length is 922
//   (922 > 1024 * 0.9).
absl::Status PreprocessTokenIds(std::vector<int>& token_ids, int start_token_id,
                                int max_num_tokens,
                                float context_length_ratio_threshold = 0.9f);

// Checks if the stop token is found in the decoded token ids.
// - decoded_token_ids: The decoded token ids. The size of the vector span is
//   the batch size of the decoded token ids generated by the model.
// - stop_token_ids: The token ids of the stop tokens.
// - stop_token_found: The vector to store the information of whether a stop
//   token is found in each of the decoded candidate sequences. The function is
//   responsible for updating the vector to reflect whether a stop token is
//   found in the decoded token ids. The function returns true when all values
//   in the vector are true.
absl::StatusOr<bool> StopTokenFound(absl::Span<const int> decoded_token_ids,
                                    const std::vector<int>& stop_token_ids,
                                    std::vector<bool>& stop_token_found);

}  // namespace litert::lm

#endif  // THIRD_PARTY_ODML_LITERT_LM_RUNTIME_COMPONENTS_TOKEN_ID_UTIL_H_