|
|
import logging |
|
|
import tensorflow as tf |
|
|
from functools import lru_cache |
|
|
from uetasr.searchers import GreedyRNNT, BeamRNNT |
|
|
|
|
|
|
|
|
@lru_cache(maxsize=5)
def get_searcher(
    searcher_type: str,
    decoder: tf.keras.Model,
    jointer: tf.keras.Model,
    text_decoder: tf.keras.layers.experimental.preprocessing.PreprocessingLayer,
    beam_size: int,
    max_symbols_per_step: int,
    alpha: float = 0.0,
):
    """Build (and memoize) an RNN-T searcher for the given configuration.

    The result is cached with ``functools.lru_cache`` (at most 5 entries), so
    every argument must be hashable, and a repeated call with the same
    arguments returns the *same* searcher instance instead of constructing a
    new one. Cached entries hold strong references to ``decoder``,
    ``jointer`` and ``text_decoder`` until they are evicted.

    Note that ``lru_cache`` keys on the arguments exactly as passed, so a call
    that omits ``alpha`` and one that passes ``alpha=0.0`` explicitly occupy
    separate cache entries (they still build equivalent searchers).

    Args:
        searcher_type: ``"greedy_search"`` selects ``GreedyRNNT``; ANY other
            value (including typos) falls through to ``BeamRNNT``.
        decoder: Decoder model forwarded to the searcher.
        jointer: Joint model forwarded to the searcher.
        text_decoder: Layer forwarded to the searcher as ``text_decoder``.
        beam_size: Beam width passed to ``BeamRNNT`` as ``beam``; ignored for
            greedy search.
        max_symbols_per_step: Forwarded unchanged to either searcher.
        alpha: Forwarded as ``alpha`` to ``BeamRNNT``; ignored for greedy
            search. Defaults to ``0.0``, the value that was previously
            hard-coded, so existing callers are unaffected.

    Returns:
        A ``GreedyRNNT`` or ``BeamRNNT`` instance constructed with
        ``return_scores=False``.
    """
    # Arguments shared by both searcher implementations.
    common_kwargs = {
        "decoder": decoder,
        "jointer": jointer,
        "text_decoder": text_decoder,
        "return_scores": False,
    }

    if searcher_type == "greedy_search":
        return GreedyRNNT(
            max_symbols_per_step=max_symbols_per_step,
            **common_kwargs,
        )

    # Any non-greedy type is treated as beam search (original behavior).
    return BeamRNNT(
        max_symbols_per_step=max_symbols_per_step,
        beam=beam_size,
        alpha=alpha,
        **common_kwargs,
    )
|
|
|