File size: 6,286 Bytes
0c5ca4b
 
 
 
96f87e0
55f6a13
57bf40b
55f6a13
170cf1f
 
55f6a13
 
 
 
 
 
 
 
 
 
 
 
0c5ca4b
 
55f6a13
0c5ca4b
55f6a13
 
 
0c5ca4b
 
 
 
57bf40b
96f87e0
0c5ca4b
 
 
57bf40b
 
 
170cf1f
0c5ca4b
96f87e0
 
55f6a13
 
0c5ca4b
55f6a13
 
0c5ca4b
96f87e0
 
 
 
 
 
55f6a13
 
 
0c5ca4b
55f6a13
 
0c5ca4b
55f6a13
 
 
 
 
 
 
 
 
 
 
0c5ca4b
55f6a13
 
 
 
0c5ca4b
 
170cf1f
0c5ca4b
 
55f6a13
0c5ca4b
55f6a13
 
 
 
 
 
 
 
 
 
0c5ca4b
 
55f6a13
170cf1f
55f6a13
 
0c5ca4b
170cf1f
 
 
0c5ca4b
 
 
 
 
170cf1f
 
0c5ca4b
 
57bf40b
 
170cf1f
0c5ca4b
 
 
 
 
 
 
55f6a13
170cf1f
 
 
0c5ca4b
96f87e0
 
 
 
 
 
 
 
 
 
 
 
 
170cf1f
 
96f87e0
170cf1f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import numpy as np
import time
import torch
import os
from riva.asrlib.decoder.python_decoder import BatchedMappedDecoderCuda, BatchedMappedOnlineDecoderCuda, BatchedMappedDecoderCudaConfig
from typing import List
from test_frame_reducer import FrameReducer

# Forwarded to config.online_opts.use_final_probs below; controls whether the
# online decoder uses final-state probabilities (exact semantics are defined by
# the Riva decoder library — see BatchedMappedDecoderCudaConfig).
USE_FINAL_PROBS=False

def remove_duplicates_and_blank(hyp: List[int],
                                eos: int,
                                blank_id: int = 0) -> List[int]:
    """CTC-style collapse of a label sequence.

    Each run of consecutive identical ids is reduced to a single
    occurrence, and ids equal to ``blank_id`` or ``eos`` are dropped
    from the output entirely.
    """
    collapsed: List[int] = []
    idx = 0
    total = len(hyp)
    while idx < total:
        token = hyp[idx]
        if token != blank_id and token != eos:
            collapsed.append(token)
        # Advance past the remainder of this run of identical tokens.
        while idx < total and hyp[idx] == token:
            idx += 1
    return collapsed

class RivaWFSTDecoder:
    """WFST (TLG) decoder wrapper around NVIDIA Riva's CUDA batched decoders.

    Builds both an offline decoder (``BatchedMappedDecoderCuda``) and an
    online/streaming decoder (``BatchedMappedOnlineDecoderCuda``) over the
    same TLG graph, and exposes batched n-best and MBR decoding helpers.
    """

    def __init__(self, vocab_size, tlg_dir, config_dict=None, beam_size=8):
        """Configure and construct the CUDA decoders.

        Args:
            vocab_size: size of the acoustic model's output vocabulary
                (number of columns in the log-prob matrix fed to decode_*).
            tlg_dir: directory containing ``TLG.fst`` and ``words.txt``.
            config_dict: NOTE(review): accepted but never used — callers
                passing a positional third argument get it silently ignored.
                Either wire it into ``config`` below or remove the parameter.
            beam_size: lattice beam width; also reused as the n-best size.
        """
        config = BatchedMappedDecoderCudaConfig()
        config.online_opts.decoder_opts.lattice_beam = beam_size

        config.online_opts.lattice_postprocessor_opts.acoustic_scale = 10.0 # noqa
        config.n_input_per_chunk = 50
        config.online_opts.decoder_opts.default_beam = 17.0
        config.online_opts.decoder_opts.max_active = 7000
        config.online_opts.determinize_lattice = True
        config.online_opts.max_batch_size = 100
        # Keep twice as many channels as the max batch so channels can be
        # recycled while previous ones are still being post-processed.
        config.online_opts.num_channels = config.online_opts.max_batch_size * 2
        # NOTE(review): assumes 40 ms between acoustic frames — confirm this
        # matches the encoder's subsampling.
        config.online_opts.frame_shift_seconds = 0.04
        config.online_opts.lattice_postprocessor_opts.lm_scale = 5.0
        config.online_opts.lattice_postprocessor_opts.word_ins_penalty = 0.0
        config.online_opts.decoder_opts.blank_penalty = 0.95
        config.online_opts.num_post_processing_worker_threads = 16
        config.online_opts.num_decoder_copy_threads = 4
        config.online_opts.use_final_probs = USE_FINAL_PROBS

        #config.online_opts.decoder_opts.ntokens_pre_allocated = 10_000_000

        config.online_opts.lattice_postprocessor_opts.nbest = beam_size

        # Offline (full-utterance) decoder.
        self.decoder = BatchedMappedDecoderCuda(
            config, os.path.join(tlg_dir, "TLG.fst"),
            os.path.join(tlg_dir, "words.txt"), vocab_size
        )

        # Streaming decoder sharing the same online options and graph.
        self.online_decoder = BatchedMappedOnlineDecoderCuda(
            config.online_opts, os.path.join(tlg_dir, "TLG.fst"),
            os.path.join(tlg_dir, "words.txt"), vocab_size
        )

        # Maps integer word id -> word string (from the OpenFST symbol table).
        self.word_id_to_word_str = load_word_symbols(os.path.join(tlg_dir, "words.txt"))
        self.nbest = beam_size
        self.vocab_size = vocab_size

    def decode_nbest(self, logits, length):
        """Decode a batch and return n-best label-id sequences.

        Args:
            logits: float tensor of per-frame log-probs; converted to
                float32 and made contiguous before decoding.
            length: per-utterance frame counts (moved to CPU int64).

        Returns:
            A list (one entry per utterance) of lists (one per n-best
            hypothesis) of collapsed label ids.
        """
        logits = logits.to(torch.float32).contiguous()
        sequence_lengths_tensor = length.to(torch.long).to('cpu').contiguous()
        results = self.decoder.decode_nbest(logits, sequence_lengths_tensor)
        total_hyps = []
        for nbest_sentences in results:
            nbest_list = []
            for sent in nbest_sentences:
                # subtract 1 to get the label id, since fst decoder adds 1 to the label id
                hyp_ids = [label - 1 for label in sent.ilabels]
                # vocab_size - 1 is treated as an eos-like id and removed.
                new_hyp = remove_duplicates_and_blank(hyp_ids, eos=self.vocab_size-1, blank_id=0)
                nbest_list.append(new_hyp)
            total_hyps.append(nbest_list)
        return total_hyps

    def decode_mbr(self, logits, length):
        """Decode a batch with MBR and return one string per utterance.

        Word strings are concatenated without separators (suited to
        Chinese-style vocabularies where words are characters).
        """
        logits = logits.to(torch.float32).contiguous()
        sequence_lengths_tensor = length.to(torch.long).to('cpu').contiguous()
        results = self.decoder.decode_mbr(logits, sequence_lengths_tensor)
        total_hyps = []
        for sent in results:
            #print([word for word in sent])
            # Each entry appears to be (word, ...) — word[0] is the string.
            hyp = [word[0] for word in sent]
            hyp_zh = "".join(hyp)
            total_hyps.append(hyp_zh)
        return total_hyps



def load_word_symbols(path):
    """Load an OpenFST-style symbol table.

    Each line of the file is ``<word> <integer-id>``; returns a dict
    mapping the integer id to its word string.
    """
    with open(path, "rt") as symtab:
        pairs = (line.rstrip().split() for line in symtab)
        return {int(word_id): word for word, word_id in pairs}

if __name__ == "__main__":
    # Demo driver: decodes a pre-dumped batch of CTC log-probs offline
    # (MBR) and then once through the streaming decoder.
    lang_dir = "../output"  # must contain TLG.fst and words.txt
    data = np.load('./data/input3.npz')
    word_id_to_word_str = load_word_symbols(os.path.join(lang_dir, "words.txt"))
    char_dict = load_word_symbols('./data/words.txt')

    beam_size = 7
    batch_size = 1
    counts = 1

    # ctc_log_probs [1,103,4233]
    ctc_log_probs = torch.from_numpy(data['ctc_log_probs'])
    # Replicate the single utterance to simulate a batch:
    # ctc_log_probs, [batch_size, T, vocab_size]
    ctc_log_probs = ctc_log_probs.repeat(batch_size, 1, 1)
    encoder_out_lens = torch.from_numpy(data['encoder_out_len'])   # encoder_out_lens single element 103
    #encoder_out_lens = torch.from_numpy(data['encoder_out_lens'])   # encoder_out_lens single element 103
    encoder_out_lens = encoder_out_lens.repeat(batch_size)          # [batch_size]
    ctc_log_probs = ctc_log_probs.contiguous().cuda()
    frame_reducer = FrameReducer()

    #ctc_log_probs, encoder_out_lens = frame_reducer(ctc_log_probs, encoder_out_lens.cuda(), ctc_log_probs)

    vocab_size = ctc_log_probs.shape[2]
    # BUG FIX: beam_size was previously passed positionally, binding it to the
    # unused `config_dict` parameter, so the decoder silently kept its default
    # beam of 8. Pass it by keyword so beam_size=7 actually takes effect.
    riva_decoder = RivaWFSTDecoder(vocab_size, lang_dir, beam_size=beam_size)

    decode_start = time.perf_counter()
    for i in range(counts):
        print("ctc_log_probs.shape:", ctc_log_probs.shape)
        total_hyps = riva_decoder.decode_mbr(ctc_log_probs, encoder_out_lens)
        print('mbr', 'use_final_probs:', USE_FINAL_PROBS, total_hyps)
        #total_hyps = riva_decoder.decode_nbest(ctc_log_probs, encoder_out_lens)
        #print('nbest', total_hyps)
    decode_end = time.perf_counter() - decode_start  # elapsed seconds

    # Streaming path: feed the whole utterance as a single chunk per channel
    # (both first and last chunk flags set).
    #chunk_size = 32
    ctc_log_probs_list, is_first_chunk, is_last_chunk = [], [True] * batch_size, [True] * batch_size
    corr_ids = list(range(batch_size))
    for corr_id in corr_ids:
        success = riva_decoder.online_decoder.try_init_corr_id(corr_id)
        assert success
    for i in range(batch_size):
        ctc_log_probs_list.append(ctc_log_probs[i, :, :])
    channels, partial_hypotheses = \
        riva_decoder.online_decoder.decode_batch(corr_ids, ctc_log_probs_list,
                                                 is_first_chunk, is_last_chunk)

    # Placeholder: inspect streaming hypotheses here if needed.
    for j, ph in enumerate(partial_hypotheses):
        #print("streaming word ids", ph.words, ph.score)
        pass

    #print(f"Decode {ctc_log_probs.shape[0] * counts} sentences, cost {decode_end} seconds")