File size: 4,261 Bytes
e05eed1 98a67a0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
// SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
#include "language_model.h"
#include <codecvt>
#include <locale>
#include <stdexcept>
#include <string>
using namespace std;
// File-scope, immutable instance of the null language model.
// NOTE(review): presumably declared `extern` in language_model.h so other
// translation units can share it — confirm.
const NullLanguageModel_t NullLanguageModel;
// Default constructor: forwards an empty braced initializer `{}` to the
// `LanguageModel` base constructor.
// NOTE(review): presumably this means "no token mapping / empty model" —
// confirm against the LanguageModel constructor in language_model.h.
NullLanguageModel_t::NullLanguageModel_t()
: LanguageModel({})
{
}
/**
 * Takes ownership of `mapping` (token index -> string) and derives the
 * reverse lookup (character -> token index) for every entry whose string is
 * exactly one character long. Entries with any other length get no reverse
 * entry. If several single-character entries share the same character,
 * `emplace` keeps whichever was inserted first (assuming
 * `reverse_token_mapping` is a std map-like container — TODO confirm in
 * language_model.h).
 *
 * @param mapping  Forward token mapping; moved into `token_mapping`.
 */
TokenMappingWrapper::TokenMappingWrapper(token_mapping_t mapping)
    : token_mapping(std::move(mapping))
{
    for (const auto &entry : token_mapping) {
        const auto &text = entry.second;
        if (text.size() != 1) {
            continue; // only single-character tokens can be reverse-mapped
        }
        const wchar_t ch = text.front();
        reverse_token_mapping.emplace(ch, entry.first);
    }
}
/**
 * Factory for a shared `TokenMappingWrapper`.
 *
 * @param tokenMapping  Forward token mapping; ownership is moved into the wrapper.
 * @return A newly allocated wrapper held by `TokenMappingWrapper::Ptr`.
 */
TokenMappingWrapper::Ptr create_token_mapping(token_mapping_t tokenMapping)
{
    TokenMappingWrapper::Ptr wrapper = std::make_shared<TokenMappingWrapper>(std::move(tokenMapping));
    return wrapper;
}
template<typename token_t>
vector<tuple<wstring, float>>
decode_sequences_impl(torch::Tensor tokens, const TokenMappingWrapper *tokenMapping,
c10::optional<torch::Tensor> probs)
{
const token_mapping_t &mapping = tokenMapping->token_mapping;
auto tokensAccess = tokens.accessor<token_t, 2>();
torch::Tensor pTens = probs.value_or(torch::ones({ tokens.size(0) }, torch::kFloat32));
if (pTens.dim() == 1) {
pTens = pTens.unsqueeze(1);
}
auto probsAccess = pTens.accessor<float, 2>();
const int64_t B = tokens.size(0);
const int64_t T = tokens.size(1);
vector<tuple<wstring, float>> ret;
for (int64_t b = 0; b < B; ++b) {
wstring buff;
float logProb = 0.0f; // log 1
bool done = false;
for (int64_t t = 0; t < T && ! done; ++t) {
typename token_mapping_t::key_type tokIdx = tokensAccess[b][t];
if (t < probsAccess.size(1)) {
logProb += log(probsAccess[b][t]);
}
switch (tokIdx) {
case 0:
// Blank char
continue;
case 1:
// End of sequence char
done = true;
break;
case 2:
buff.push_back('^');
break;
default:
auto iter = mapping.find(tokIdx);
if (iter == mapping.end()) {
throw std::runtime_error("The token mapping doesn't contain an entry for index " + to_string(tokIdx));
}
buff += iter->second;
break;
}
}
ret.emplace_back(move(buff), exp(logProb));
}
return ret;
}
vector<tuple<wstring, float>>
decode_sequences(torch::Tensor tokens, const TokenMappingWrapper *tokenMapping,
c10::optional<torch::Tensor> probs)
{
if (tokens.dim() != 2) {
throw std::runtime_error("`tokens` must be 2-dimensions of type B,T!");
}
if (tokenMapping == nullptr) {
throw std::runtime_error("Cannot supply a null token mapping!");
}
const token_mapping_t &mapping = tokenMapping->token_mapping;
if (mapping.empty()) {
throw std::runtime_error("The token mapping hasn't been initialized!");
}
if (probs.has_value()) {
if (probs.value().scalar_type() != torch::kFloat32) {
throw std::runtime_error("If the probability distribution is specified, then it must be of type `torch.float32`");
}
if (probs.value().size(0) != tokens.size(0)) {
throw std::runtime_error("The probability distribution batch size doesn't match the tokens batch size!");
}
if (probs.value().dim() == 2 && probs.value().size(1) != tokens.size(1)) {
throw std::runtime_error("Invalid probability distribution shape!");
}
}
vector<tuple<wstring, float>> ret;
AT_DISPATCH_INTEGRAL_TYPES(
tokens.scalar_type(),
"decode_sequences_impl",
([&] {
ret = decode_sequences_impl<scalar_t>(tokens, tokenMapping, probs);
})
);
return ret;
}
/**
 * Encode a wide string as UTF-8.
 *
 * The width of `wchar_t` is platform dependent. A 32-bit `wchar_t` (e.g.
 * Linux) stores whole Unicode code points (UTF-32). A 16-bit `wchar_t` (e.g.
 * Windows) stores UTF-16 code units, where code points above U+FFFF occupy a
 * surrogate pair. Both layouts are handled.
 *
 * This replaces a `std::wstring_convert<std::codecvt_utf8<wchar_t>>`
 * implementation. That facility is deprecated since C++17 (removed in C++26)
 * and treats a 16-bit `wchar_t` as UCS-2, so it rejected every surrogate pair.
 *
 * @param wstr  Wide string to encode.
 * @return The UTF-8 encoded bytes.
 * @throws std::range_error if `wstr` contains an unpaired surrogate or a value
 *         above U+10FFFF. This is the same exception type the previous
 *         converter threw on invalid input.
 */
std::string ws2s(const std::wstring& wstr)
{
    // 16-bit wchar_t holds UTF-16 code units; 32-bit wchar_t holds code points.
    constexpr bool kUtf16Units = sizeof(wchar_t) == 2;

    std::string out;
    out.reserve(wstr.size()); // exact for ASCII; grows as needed otherwise

    const std::size_t len = wstr.size();
    for (std::size_t i = 0; i < len; ++i) {
        // Unsigned view of the element: a negative 32-bit wchar_t becomes a huge
        // value and is rejected by the range check below.
        char32_t cp = static_cast<char32_t>(wstr[i]);

        if (cp >= 0xD800 && cp <= 0xDFFF) {
            // Surrogates are only valid as a high+low pair of UTF-16 code units.
            const bool isHigh = cp <= 0xDBFF;
            if (!kUtf16Units || !isHigh || i + 1 >= len) {
                throw std::range_error("ws2s: unpaired surrogate in wide string");
            }
            const char32_t low = static_cast<char32_t>(wstr[i + 1]);
            if (low < 0xDC00 || low > 0xDFFF) {
                throw std::range_error("ws2s: unpaired surrogate in wide string");
            }
            cp = 0x10000 + ((cp - 0xD800) << 10) + (low - 0xDC00);
            ++i; // consumed the low surrogate as well
        } else if (cp > 0x10FFFF) {
            throw std::range_error("ws2s: code point out of Unicode range");
        }

        if (cp < 0x80) {
            out.push_back(static_cast<char>(cp));
        } else if (cp < 0x800) {
            out.push_back(static_cast<char>(0xC0 | (cp >> 6)));
            out.push_back(static_cast<char>(0x80 | (cp & 0x3F)));
        } else if (cp < 0x10000) {
            out.push_back(static_cast<char>(0xE0 | (cp >> 12)));
            out.push_back(static_cast<char>(0x80 | ((cp >> 6) & 0x3F)));
            out.push_back(static_cast<char>(0x80 | (cp & 0x3F)));
        } else {
            out.push_back(static_cast<char>(0xF0 | (cp >> 18)));
            out.push_back(static_cast<char>(0x80 | ((cp >> 12) & 0x3F)));
            out.push_back(static_cast<char>(0x80 | ((cp >> 6) & 0x3F)));
            out.push_back(static_cast<char>(0x80 | (cp & 0x3F)));
        }
    }
    return out;
}
|