File size: 3,372 Bytes
5f923cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
// Copyright 2025 The ODML Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "runtime/framework/resource_management/utils/resource_manager_utils.h"

#include <algorithm>
#include <cstddef>
#include <iterator>
#include <vector>

#include "absl/status/status.h"  // from @com_google_absl
#include "runtime/util/status_macros.h"  // NOLINT

namespace litert::lm {

namespace {
// Returns the number of leading tokens in `vec` that can be compared by token
// id, i.e. the index of the first negative token. A negative token is a
// placeholder for a multimodal token; at the token-id level there is no way to
// tell whether such a token matches, so it is treated as a mismatch and the
// comparable range ends right before it. When `vec` holds no negative token,
// the result is vec.size().
size_t ComparableLengthExcludingNegativeTokens(const std::vector<int>& vec) {
  const auto first_negative = std::find_if(
      vec.begin(), vec.end(), [](int token) { return token < 0; });
  return static_cast<size_t>(std::distance(vec.begin(), first_negative));
}
}  // namespace

// Drops from the front of `*input_ids` the tokens that are already present in
// `processed_tokens` starting at position `*time_step`.
//
// The prefix of `*input_ids` is compared element-wise against
// `processed_tokens[*time_step ...]`. The comparison stops at the first
// mismatch, at the first negative token in `*input_ids` (see
// ComparableLengthExcludingNegativeTokens), or when either sequence runs out.
// The matching prefix is erased from `*input_ids`, and `*time_step` is advanced
// by the number of erased tokens. If every token of `*input_ids` matches,
// `*input_ids` ends up empty.
//
// Preconditions, enforced with RET_CHECK (a failed check returns early with
// the status produced by the RET_CHECK macros in status_macros.h):
//   - `input_ids` and `time_step` are non-null.
//   - `*time_step` is non-negative and not larger than
//     `processed_tokens.size()`.
//
// Returns absl::OkStatus() when all preconditions hold.
absl::Status RemoveMatchingTokens(const std::vector<int> &processed_tokens,
                                  std::vector<int> *input_ids, int *time_step) {
  RET_CHECK_NE(input_ids, nullptr) << "input_ids is null.";
  RET_CHECK_NE(time_step, nullptr) << "time_step is null.";
  RET_CHECK_GE(*time_step, 0) << "Time step is negative.";
  RET_CHECK_GE(processed_tokens.size(), *time_step)
      << "The processed tokens size is smaller than the time step.";
  // Number of tokens available in processed_tokens from the current time step
  // onwards. The checks above guarantee 0 <= *time_step <= size(), so the
  // cast is lossless and the subtraction cannot wrap around.
  const size_t processed_tokens_comparable_len =
      processed_tokens.size() - static_cast<size_t>(*time_step);
  // Only compare as many elements as both effective sequences provide.
  const size_t comparison_len =
      std::min(processed_tokens_comparable_len,
               ComparableLengthExcludingNegativeTokens(*input_ids));
  // Find the first mismatch between
  //   [input_ids->begin(), input_ids->begin() + comparison_len) and
  //   [processed_tokens.begin() + *time_step, ... + comparison_len).
  const auto input_begin = input_ids->begin();
  const auto mismatch_it = std::mismatch(
      input_begin, input_begin + static_cast<std::ptrdiff_t>(comparison_len),
      processed_tokens.begin() + *time_step);
  // mismatch_it.first points at the first differing element of input_ids, or
  // at input_begin + comparison_len when every compared element matched. The
  // distance is bounded by comparison_len, which in turn is bounded by
  // processed_tokens.size() - *time_step.
  const int matching_tokens =
      static_cast<int>(std::distance(input_begin, mismatch_it.first));
  // Consume the matched prefix and advance the time step accordingly.
  input_ids->erase(input_begin, mismatch_it.first);
  *time_step += matching_tokens;
  return absl::OkStatus();
}

}  // namespace litert::lm