File size: 2,138 Bytes
5f923cd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
#include "runtime/components/constrained_decoding/constraint_provider_factory.h"

#include <memory>
#include <variant>
#include <vector>

#include "absl/status/status.h"  // from @com_google_absl
#include "absl/status/statusor.h"  // from @com_google_absl
#include "runtime/components/constrained_decoding/constraint_provider.h"
#include "runtime/components/constrained_decoding/constraint_provider_config.h"
#include "runtime/components/constrained_decoding/external_constraint_config.h"
#include "runtime/components/constrained_decoding/external_constraint_provider.h"
#include "runtime/components/constrained_decoding/llg_constraint_config.h"
#include "runtime/components/constrained_decoding/llg_constraint_provider.h"
#include "runtime/components/tokenizer.h"

namespace litert::lm {

// Builds the ConstraintProvider that matches the active alternative of
// `constraint_provider_config`.
//
// - ExternalConstraintConfig: returns a default-constructed
//   ExternalConstraintProvider (tokenizer and stop tokens are not used).
// - LlGuidanceConfig: works on a copy of the config. If `eos_id` is unset, it
//   is filled from the first entry of `stop_token_ids` that consists of exactly
//   one token. The provider is then built via LlgConstraintProvider::Create.
//
// Returns:
// - InvalidArgumentError if an LlGuidanceConfig has no `eos_id` and no
//   single-token stop sequence exists.
// - UnimplementedError for any other config alternative.
absl::StatusOr<std::unique_ptr<ConstraintProvider>> CreateConstraintProvider(
    const ConstraintProviderConfig& constraint_provider_config,
    const Tokenizer& tokenizer,
    const std::vector<std::vector<int>>& stop_token_ids) {
  if (std::get_if<ExternalConstraintConfig>(&constraint_provider_config) !=
      nullptr) {
    return std::make_unique<ExternalConstraintProvider>();
  }

  const auto* const requested_llg_config =
      std::get_if<LlGuidanceConfig>(&constraint_provider_config);
  if (requested_llg_config == nullptr) {
    return absl::UnimplementedError("Unknown ConstraintProviderConfig type.");
  }

  // Copy so that filling in a fallback eos_id never touches the caller's config.
  LlGuidanceConfig resolved_config = *requested_llg_config;
  if (!resolved_config.eos_id.has_value()) {
    // Only a stop sequence made of a single token can stand in for eos_id.
    for (const std::vector<int>& candidate : stop_token_ids) {
      if (candidate.size() != 1) {
        continue;
      }
      resolved_config.eos_id = candidate.front();
      break;
    }
  }
  // Still unset here only if it was unset on entry and no fallback was found.
  if (!resolved_config.eos_id.has_value()) {
    return absl::InvalidArgumentError(
        "LlGuidanceConfig::eos_id wasn't set and no valid stop token was "
        "found in SessionConfig.");
  }
  return LlgConstraintProvider::Create(tokenizer, resolved_config);
}

}  // namespace litert::lm