Spaces:
Runtime error
Runtime error
| import os | |
| import tensorflow as tf | |
| from functools import lru_cache | |
| from huggingface_hub import hf_hub_download | |
| from hyperpyyaml import load_hyperpyyaml | |
| from typing import Union | |
| from decode import get_searcher | |
| os.environ["CUDA_VISIBLE_DEVICES"] = "-1" | |
def _get_checkpoint_filename(
    repo_id: str,
    filename: str,
    local_dir: str = None,
    local_dir_use_symlinks: Union[bool, str] = "auto",
    subfolder: str = "checkpoints"
) -> str:
    """Download a single checkpoint file from a Hugging Face Hub repository.

    All arguments are forwarded unchanged to ``hf_hub_download``.

    Args:
        repo_id: Hub repository identifier to download from.
        filename: Name of the file inside ``subfolder``.
        local_dir: Optional directory handed to ``hf_hub_download`` as
            ``local_dir``; ``None`` leaves the choice to the library.
        local_dir_use_symlinks: Forwarded as ``local_dir_use_symlinks``.
        subfolder: Repository subfolder holding the file; defaults to
            ``"checkpoints"``.

    Returns:
        The value returned by ``hf_hub_download`` (the local file path).
    """
    download_kwargs = {
        "repo_id": repo_id,
        "filename": filename,
        "subfolder": subfolder,
        "local_dir": local_dir,
        "local_dir_use_symlinks": local_dir_use_symlinks,
    }
    return hf_hub_download(**download_kwargs)
def _get_bpe_model_filename(
    repo_id: str,
    filename: str,
    local_dir: str = None,
    local_dir_use_symlinks: Union[bool, str] = "auto",
    subfolder: str = "vocabs"
) -> str:
    """Download a single subword/BPE vocabulary file from the Hub.

    This is a direct pass-through to ``hf_hub_download``; the only thing it
    adds is the ``"vocabs"`` default for ``subfolder``.

    Args:
        repo_id: Hub repository identifier to download from.
        filename: Name of the file inside ``subfolder``.
        local_dir: Optional directory handed to ``hf_hub_download`` as
            ``local_dir``; ``None`` leaves the choice to the library.
        local_dir_use_symlinks: Forwarded as ``local_dir_use_symlinks``.
        subfolder: Repository subfolder holding the file.

    Returns:
        The value returned by ``hf_hub_download`` (the local file path).
    """
    return hf_hub_download(
        local_dir_use_symlinks=local_dir_use_symlinks,
        local_dir=local_dir,
        subfolder=subfolder,
        filename=filename,
        repo_id=repo_id,
    )
def _get_conformer_pre_trained_model(repo_id: str, checkpoint_dir: str = "checkpoints"):
    """Download the model artifacts from the Hub and load the trained weights.

    The checkpoint shards, the subword vocabulary files and ``config.yaml``
    are downloaded next to this source file. The YAML config is then
    instantiated with ``load_hyperpyyaml`` and the checkpoint is loaded into
    the ``model`` object it defines.

    Args:
        repo_id: Hugging Face Hub repository that hosts the artifacts.
        checkpoint_dir: Repository subfolder that holds the checkpoint files.

    Returns:
        A tuple ``(audio_encoder, encoder_model, jointer, decoder,
        text_encoder, model)`` taken from the config keys ``audio_encoder``,
        ``encoder_model``, ``jointer_model``, ``decoder_model``,
        ``text_encoder`` and ``model``.
    """
    checkpoint_name = "avg_top5_27-32.ckpt"

    # Resolve the module directory to an absolute path: ``__file__`` may be
    # relative (or have an empty dirname), and the process may be started
    # from another working directory.
    base_dir = os.path.dirname(os.path.abspath(__file__))

    checkpoint_root = None
    for postfix in ("index", "data-00000-of-00001"):
        local_path = _get_checkpoint_filename(
            repo_id=repo_id,
            filename="{}.{}".format(checkpoint_name, postfix),
            subfolder=checkpoint_dir,
            local_dir=base_dir,
            local_dir_use_symlinks=True,
        )
        print(local_path)
        # Remember where the shards actually landed so the weights are
        # loaded from that exact location.
        checkpoint_root = os.path.dirname(local_path)

    for postfix in ("model", "vocab"):
        local_path = _get_bpe_model_filename(
            repo_id=repo_id,
            filename="subword_vietnamese_500.{}".format(postfix),
            local_dir=base_dir,
            local_dir_use_symlinks=True,
        )
        print(local_path)

    config_path = hf_hub_download(
        repo_id=repo_id,
        filename="config.yaml",
        local_dir=base_dir,
        local_dir_use_symlinks=True,
    )
    # NOTE(review): any relative paths inside config.yaml (e.g. vocabulary
    # files) are presumably still resolved against the working directory —
    # confirm against the config.
    with open(config_path, "r") as f:
        config = load_hyperpyyaml(f)

    encoder_model = config["encoder_model"]
    text_encoder = config["text_encoder"]
    jointer = config["jointer_model"]
    decoder = config["decoder_model"]
    model = config["model"]
    audio_encoder = config["audio_encoder"]

    # Load from the directory the shards were downloaded to rather than a
    # path relative to the current working directory, which only worked when
    # the process happened to be started from this file's directory.
    checkpoint_prefix = os.path.join(checkpoint_root, checkpoint_name)
    model.load_weights(checkpoint_prefix).expect_partial()

    return audio_encoder, encoder_model, jointer, decoder, text_encoder, model
def read_audio(in_filename: str):
    """Read a WAV file and return its samples with a leading batch axis.

    Args:
        in_filename: Path to a WAV file readable by ``tf.io.read_file``.

    Returns:
        The decoded audio tensor with the trailing channel axis squeezed away
        and a new axis inserted at position 0. The sample rate returned by
        the decoder is discarded.
    """
    wav_bytes = tf.io.read_file(in_filename)
    # Ask the decoder for exactly one channel. Without this, a stereo (or any
    # multi-channel) file makes the squeeze below fail because the last axis
    # is not of size 1. Mono files are unaffected.
    samples = tf.audio.decode_wav(wav_bytes, desired_channels=1)[0]
    samples = tf.squeeze(samples, axis=-1)
    return tf.expand_dims(samples, axis=0)
class UETASRModel:
    """Bundle a pre-trained model with a searcher to transcribe WAV files."""

    def __init__(
        self,
        repo_id: str,
        decoding_method: str,
        beam_size: int,
        max_symbols_per_step: int,
    ):
        """Load the pre-trained components and build the searcher.

        Args:
            repo_id: Hub repository passed to
                ``_get_conformer_pre_trained_model``.
            decoding_method: Decoding strategy name passed to
                ``get_searcher``.
            beam_size: Beam size passed to ``get_searcher``.
            max_symbols_per_step: Per-step symbol limit passed to
                ``get_searcher``.
        """
        components = _get_conformer_pre_trained_model(repo_id)
        (
            self.featurizer,
            self.encoder_model,
            jointer_net,
            decoder_net,
            tokenizer,
            self.model,
        ) = components
        self.searcher = get_searcher(
            decoding_method,
            decoder_net,
            jointer_net,
            tokenizer,
            beam_size,
            max_symbols_per_step,
        )

    def predict(self, in_filename: str):
        """Transcribe one WAV file and return the decoded text.

        Args:
            in_filename: Path of the WAV file to transcribe.

        Returns:
            The searcher output, squeezed and converted to ``str`` with
            ``tf.compat.as_str_any``.
        """
        waveform = read_audio(in_filename)
        feats = self.featurizer(waveform)
        if self.model.use_cmvn:
            feats = self.model.cmvn(feats)

        # Single utterance: the mask covers every position along axis 1 of
        # the features, then gains an extra axis at position 1 for the
        # encoder.
        num_frames = tf.shape(feats)[1]
        frame_mask = tf.expand_dims(
            tf.sequence_mask([num_frames], maxlen=num_frames),
            axis=1,
        )

        encoder_outputs, encoder_masks = self.encoder_model(
            feats, frame_mask, training=False
        )

        # Count the positions still marked valid in the encoder's mask to get
        # the length handed to the searcher.
        valid_positions = tf.cast(tf.squeeze(encoder_masks, axis=1), tf.int32)
        features_length = tf.math.reduce_sum(valid_positions, axis=1)

        hypothesis = tf.squeeze(
            self.searcher.infer(encoder_outputs, features_length)
        )
        return tf.compat.as_str_any(hypothesis.numpy())