Spaces:
Runtime error
Runtime error
| import logging | |
| import tensorflow as tf | |
| from functools import lru_cache | |
| from uetasr.searchers import GreedyRNNT, BeamRNNT | |
def get_searcher(
    searcher_type: str,
    decoder: tf.keras.Model,
    jointer: tf.keras.Model,
    text_decoder: tf.keras.layers.experimental.preprocessing.PreprocessingLayer,
    beam_size: int,
    max_symbols_per_step: int,
):
    """Build the RNN-T searcher selected by ``searcher_type``.

    Args:
        searcher_type: ``"greedy_search"`` to build a ``GreedyRNNT`` or
            ``"beam_search"`` to build a ``BeamRNNT``.
        decoder: Forwarded to the searcher as ``decoder``.
        jointer: Forwarded to the searcher as ``jointer``.
        text_decoder: Forwarded to the searcher as ``text_decoder``.
        beam_size: Forwarded to ``BeamRNNT`` as ``beam``; not used for
            ``"greedy_search"``.
        max_symbols_per_step: Forwarded to either searcher unchanged.

    Returns:
        The constructed ``GreedyRNNT`` or ``BeamRNNT`` instance. Both are
        created with ``return_scores=False``.

    Raises:
        ValueError: If ``searcher_type`` is not one of the supported values.
    """
    # Arguments shared by every searcher implementation.
    common_kwargs = {
        "decoder": decoder,
        "jointer": jointer,
        "text_decoder": text_decoder,
        "return_scores": False,
    }

    if searcher_type == "greedy_search":
        return GreedyRNNT(
            max_symbols_per_step=max_symbols_per_step,
            **common_kwargs,
        )

    if searcher_type == "beam_search":
        return BeamRNNT(
            max_symbols_per_step=max_symbols_per_step,
            beam=beam_size,
            alpha=0.0,
            **common_kwargs,
        )

    # Previously this path only logged and then fell through to
    # ``return searcher`` with ``searcher`` unbound, which surfaced as an
    # UnboundLocalError. Fail explicitly with an actionable message instead.
    logging.error("Unknown searcher type: %s", searcher_type)
    raise ValueError(
        f"Unknown searcher type: {searcher_type!r}; "
        "expected 'greedy_search' or 'beam_search'"
    )