File size: 1,610 Bytes
e05eed1
98a67a0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include <cstdint>
#include <memory>
#include <string>
#include <tuple>
#include <unordered_map>
#include <vector>

#include <torch/torch.h>

#include "prefix.h"
#include "log_sum_exp.h"

// Forward mapping: token id -> its (wide-)string surface form.
using token_mapping_t = std::unordered_map<int64_t, std::wstring>;
// Reverse mapping for single-character tokens: character -> token id.
using reverse_token_mapping_t = std::unordered_map<wchar_t, int64_t>;


// Abstract interface for language models used to re-score candidate
// transitions during prefix/beam decoding.
class LanguageModel
{
public:
    // Polymorphic base: virtual destructor, defaulted per modern idiom.
    virtual ~LanguageModel() = default;

    // Score for extending prefix `p` with `nextToken`, in log-probability
    // space (0 means "probability unchanged" — see NullLanguageModel_t).
    virtual float_t ScoreTransition(const Prefix *p, token_t nextToken) const = 0;

    // Read-only access to the token-id -> string mapping this model was
    // constructed with.
    const token_mapping_t &TokenMapping() const { return m_tokenMapping; }

protected:
    // Only derived classes may construct; takes the mapping by value and
    // moves it into place (sink parameter).
    explicit LanguageModel(token_mapping_t tokenMapping)
        : m_tokenMapping(std::move(tokenMapping))
    {}

    token_mapping_t m_tokenMapping;
};


class NullLanguageModel_t
    : public LanguageModel
{
public:
    NullLanguageModel_t();

    virtual float_t ScoreTransition(const Prefix *p, token_t nextToken) const override
    {
        // log P(1)
        // Which means the probability is unchanged
        return 0;
    }
};

extern const NullLanguageModel_t NullLanguageModel;

// Bundles a forward token mapping with its reverse mapping so both share a
// single lifetime; typically handled via the shared_ptr alias below.
struct TokenMappingWrapper
{
    using Ptr = std::shared_ptr<TokenMappingWrapper>;

    // Takes the forward mapping by value; the reverse mapping is presumably
    // populated by the out-of-line constructor definition — confirm in the
    // .cpp. Left non-explicit to preserve any implicit conversions callers
    // may rely on.
    TokenMappingWrapper(token_mapping_t mapping);

    token_mapping_t token_mapping;
    reverse_token_mapping_t reverse_token_mapping;
};

// Factory: wraps `tokenMapping` in a TokenMappingWrapper and returns shared
// ownership of it (defined out of line).
TokenMappingWrapper::Ptr create_token_mapping(token_mapping_t tokenMapping);

// Converts token-id sequences in `tokens` to strings via `tokenMapping`.
// The float paired with each string presumably derives from `probs` when it
// is supplied — confirm against the definition. `probs` defaults to "absent".
std::vector<std::tuple<std::wstring, float>>
    decode_sequences(torch::Tensor tokens, const TokenMappingWrapper *tokenMapping,
                     c10::optional<torch::Tensor> probs = torch::nullopt);