# uetasr / decode.py
# Last change: "Update decode.py" by thanhtvt (commit ee21cdf, verified); 905 bytes.
import logging
import tensorflow as tf
from functools import lru_cache
from uetasr.searchers import GreedyRNNT, BeamRNNT
@lru_cache(maxsize=5)
def get_searcher(
    searcher_type: str,
    decoder: tf.keras.Model,
    jointer: tf.keras.Model,
    text_decoder: tf.keras.layers.experimental.preprocessing.PreprocessingLayer,
    beam_size: int,
    max_symbols_per_step: int,
):
    """Build (and memoise) an RNN-T searcher used for decoding.

    Args:
        searcher_type: ``"greedy_search"`` selects ``GreedyRNNT``; any other
            value selects ``BeamRNNT``.
        decoder: Passed through to the searcher as ``decoder``.
        jointer: Passed through to the searcher as ``jointer``.
        text_decoder: Passed through to the searcher as ``text_decoder``.
        beam_size: Forwarded to ``BeamRNNT`` as ``beam``; unused by greedy search.
        max_symbols_per_step: Forwarded unchanged to whichever searcher is built.

    Returns:
        A ``GreedyRNNT`` or ``BeamRNNT`` instance constructed with
        ``return_scores=False``.

    Because of ``lru_cache``, all arguments must be hashable. Repeated calls
    with equal arguments return the very same searcher object, so it is
    shared between callers. At most 5 distinct argument combinations are
    retained.
    """
    # Options common to both searcher implementations.
    shared_options = dict(
        decoder=decoder,
        jointer=jointer,
        text_decoder=text_decoder,
        return_scores=False,
        max_symbols_per_step=max_symbols_per_step,
    )

    if searcher_type != "greedy_search":
        # Every type other than "greedy_search" (including unrecognised ones)
        # falls through to beam search.
        # NOTE(review): alpha is pinned to 0.0; its meaning is defined by
        # BeamRNNT, which is not visible here — confirm.
        return BeamRNNT(beam=beam_size, alpha=0.0, **shared_options)

    return GreedyRNNT(**shared_options)