File size: 905 Bytes
e9812a3 4ac7ffc e9812a3 ee21cdf e9812a3 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 | import logging
import tensorflow as tf
from functools import lru_cache
from uetasr.searchers import GreedyRNNT, BeamRNNT
@lru_cache(maxsize=5)
def get_searcher(
    searcher_type: str,
    decoder: tf.keras.Model,
    jointer: tf.keras.Model,
    text_decoder: tf.keras.layers.experimental.preprocessing.PreprocessingLayer,
    beam_size: int,
    max_symbols_per_step: int,
):
    """Build (and memoize) an RNN-T searcher for the given configuration.

    The result is cached with ``functools.lru_cache(maxsize=5)``: repeated
    calls with the same arguments return the same searcher object instead of
    constructing a new one. Consequently every argument must be hashable, and
    up to five searchers (and the models they reference) are kept alive by
    the cache.

    Args:
        searcher_type: ``"greedy_search"`` selects ``GreedyRNNT``. Any other
            value selects ``BeamRNNT``; values other than ``"beam_search"``
            additionally log a warning, since they are most likely typos.
        decoder: Decoder model forwarded to the searcher as ``decoder``.
        jointer: Joint model forwarded to the searcher as ``jointer``.
        text_decoder: Preprocessing layer forwarded to the searcher as
            ``text_decoder``.
        beam_size: Beam width passed to ``BeamRNNT`` as ``beam``. It is
            ignored by greedy search but still participates in the cache key.
        max_symbols_per_step: Forwarded to either searcher unchanged.

    Returns:
        A ``GreedyRNNT`` or ``BeamRNNT`` instance constructed with
        ``return_scores=False``.
    """
    common_kwargs = {
        "decoder": decoder,
        "jointer": jointer,
        "text_decoder": text_decoder,
        "return_scores": False,
    }

    if searcher_type == "greedy_search":
        return GreedyRNNT(
            max_symbols_per_step=max_symbols_per_step,
            **common_kwargs,
        )

    if searcher_type != "beam_search":
        # Unrecognized types have always fallen back to beam search. Keep that
        # behaviour for existing callers, but surface likely typos
        # (e.g. "greedy") instead of silently switching algorithms.
        logging.getLogger(__name__).warning(
            "Unrecognized searcher_type %r; falling back to beam search.",
            searcher_type,
        )

    # NOTE(review): alpha is hard-coded to 0.0. Its meaning is defined by
    # BeamRNNT (possibly a length/score weighting term); confirm before
    # exposing it as a parameter.
    return BeamRNNT(
        max_symbols_per_step=max_symbols_per_step,
        beam=beam_size,
        alpha=0.0,
        **common_kwargs,
    )
|