| | import os |
| | import tensorflow as tf |
| | from functools import lru_cache |
| | from huggingface_hub import hf_hub_download |
| | from hyperpyyaml import load_hyperpyyaml |
| | from typing import Union |
| |
|
| | from decode import get_searcher |
| |
|
| | os.environ["CUDA_VISIBLE_DEVICES"] = "-1" |
| |
|
| |
|
def _get_checkpoint_filename(
    repo_id: str,
    filename: str,
    local_dir: str = None,
    local_dir_use_symlinks: Union[bool, str] = "auto",
    subfolder: str = "checkpoints"
) -> str:
    """Download a single checkpoint file from a Hugging Face Hub repository.

    Args:
        repo_id: Identifier of the Hub repository to download from.
        filename: Name of the file to fetch inside ``subfolder``.
        local_dir: Passed through to ``hf_hub_download`` as ``local_dir``.
        local_dir_use_symlinks: Passed through to ``hf_hub_download``.
        subfolder: Repository subfolder that holds the file.

    Returns:
        The value returned by ``hf_hub_download`` for the requested file.
    """
    # All arguments are forwarded verbatim; this helper only supplies the
    # "checkpoints" default for ``subfolder``.
    download_kwargs = {
        "repo_id": repo_id,
        "filename": filename,
        "subfolder": subfolder,
        "local_dir": local_dir,
        "local_dir_use_symlinks": local_dir_use_symlinks,
    }
    return hf_hub_download(**download_kwargs)
| |
|
| |
|
def _get_bpe_model_filename(
    repo_id: str,
    filename: str,
    local_dir: str = None,
    local_dir_use_symlinks: Union[bool, str] = "auto",
    subfolder: str = "vocabs"
) -> str:
    """Download a single BPE vocabulary file from a Hugging Face Hub repository.

    Args:
        repo_id: Identifier of the Hub repository to download from.
        filename: Name of the file to fetch inside ``subfolder``.
        local_dir: Passed through to ``hf_hub_download`` as ``local_dir``.
        local_dir_use_symlinks: Passed through to ``hf_hub_download``.
        subfolder: Repository subfolder that holds the file.

    Returns:
        The value returned by ``hf_hub_download`` for the requested file.
    """
    # All arguments are forwarded verbatim; this helper only supplies the
    # "vocabs" default for ``subfolder``.
    download_kwargs = {
        "repo_id": repo_id,
        "filename": filename,
        "subfolder": subfolder,
        "local_dir": local_dir,
        "local_dir_use_symlinks": local_dir_use_symlinks,
    }
    return hf_hub_download(**download_kwargs)
| |
|
| |
|
@lru_cache(maxsize=1)
def _get_conformer_pre_trained_model(repo_id: str, checkpoint_dir: str = "checkpoints"):
    """Download the pre-trained model assets and build the inference components.

    The checkpoint shards, the subword vocabulary files and ``config.yaml``
    are fetched from the Hub repository into the directory containing this
    module. The config is then loaded with HyperPyYAML and the checkpoint
    weights are restored into the model it defines.

    The result is memoized by ``lru_cache(maxsize=1)``, so repeated calls with
    the same arguments reuse the already-built objects.

    Args:
        repo_id: Identifier of the Hub repository holding the assets.
        checkpoint_dir: Repository subfolder (and local sub-directory) that
            contains the checkpoint files.

    Returns:
        A tuple ``(audio_encoder, encoder_model, jointer, decoder,
        text_encoder, model)`` taken from the corresponding config entries.
    """
    # Single base directory used for both downloading and loading, so the two
    # can never disagree regardless of the current working directory.
    module_dir = os.path.dirname(__file__)
    checkpoint_name = "avg_top5_27-32.ckpt"
    vocab_name = "subword_vietnamese_500"

    # A checkpoint is made of an index file plus a data shard.
    for postfix in ["index", "data-00000-of-00001"]:
        tmp = _get_checkpoint_filename(
            repo_id=repo_id,
            filename="{}.{}".format(checkpoint_name, postfix),
            subfolder=checkpoint_dir,
            local_dir=module_dir,
            local_dir_use_symlinks=True,
        )
        print(tmp)

    for postfix in ["model", "vocab"]:
        tmp = _get_bpe_model_filename(
            repo_id=repo_id,
            filename="{}.{}".format(vocab_name, postfix),
            local_dir=module_dir,
            local_dir_use_symlinks=True,
        )
        print(tmp)

    config_path = hf_hub_download(
        repo_id=repo_id,
        filename="config.yaml",
        local_dir=module_dir,
        local_dir_use_symlinks=True,
    )

    # NOTE(review): HyperPyYAML can instantiate Python objects referenced in
    # the YAML, so ``repo_id`` should point at a trusted repository.
    # Read as UTF-8 explicitly instead of relying on the platform default.
    with open(config_path, "r", encoding="utf-8") as f:
        config = load_hyperpyyaml(f)

    encoder_model = config["encoder_model"]
    text_encoder = config["text_encoder"]
    jointer = config["jointer_model"]
    decoder = config["decoder_model"]

    model = config["model"]
    audio_encoder = config["audio_encoder"]

    # Load from where the files were downloaded (module_dir/checkpoint_dir),
    # not from a path relative to the current working directory.
    checkpoint_prefix = os.path.join(module_dir, checkpoint_dir, checkpoint_name)
    model.load_weights(checkpoint_prefix).expect_partial()

    return audio_encoder, encoder_model, jointer, decoder, text_encoder, model
| |
|
| |
|
def read_audio(in_filename: str):
    """Read a WAV file and return its samples with a leading axis of size 1.

    Args:
        in_filename: Path of the WAV file to decode.

    Returns:
        The decoded samples with the last axis squeezed out and a new first
        axis added.

    NOTE(review): squeezing the last axis only succeeds when that axis has
    size 1 - presumably the input is expected to be mono; confirm.
    """
    wav_bytes = tf.io.read_file(in_filename)
    # decode_wav returns a pair; the first element holds the samples.
    samples = tf.audio.decode_wav(wav_bytes)[0]
    single_channel = tf.squeeze(samples, axis=-1)
    return tf.expand_dims(single_channel, axis=0)
| |
|
| |
|
class UETASRModel:
    """Speech recognizer built from pre-trained components and a searcher."""

    def __init__(
        self,
        repo_id: str,
        decoding_method: str,
        beam_size: int,
        max_symbols_per_step: int,
    ):
        """Load the pre-trained components and create the decoding searcher.

        Args:
            repo_id: Hub repository passed to the pre-trained model loader.
            decoding_method: Forwarded to ``get_searcher``.
            beam_size: Forwarded to ``get_searcher``.
            max_symbols_per_step: Forwarded to ``get_searcher``.
        """
        (
            self.featurizer,
            self.encoder_model,
            jointer,
            decoder,
            text_encoder,
            self.model,
        ) = _get_conformer_pre_trained_model(repo_id)

        # The jointer, decoder and text encoder are only needed by the
        # searcher, so they are not kept as attributes.
        self.searcher = get_searcher(
            decoding_method,
            decoder,
            jointer,
            text_encoder,
            beam_size,
            max_symbols_per_step,
        )

    def predict(self, in_filename: str):
        """Transcribe one WAV file.

        Args:
            in_filename: Path of the WAV file to transcribe.

        Returns:
            The searcher output converted to ``str``.
        """
        waveform = read_audio(in_filename)
        features = self.featurizer(waveform)
        if self.model.use_cmvn:
            features = self.model.cmvn(features)

        # Single utterance: the mask covers every position along axis 1.
        num_frames = tf.shape(features)[1]
        mask = tf.expand_dims(
            tf.sequence_mask([num_frames], maxlen=num_frames),
            axis=1,
        )
        encoder_outputs, encoder_masks = self.encoder_model(
            features, mask, training=False)

        # Count the positions still marked valid in the mask the encoder returned.
        valid_positions = tf.cast(tf.squeeze(encoder_masks, axis=1), tf.int32)
        features_length = tf.math.reduce_sum(valid_positions, axis=1)

        hypothesis = self.searcher.infer(encoder_outputs, features_length)
        hypothesis = tf.squeeze(hypothesis)
        return tf.compat.as_str_any(hypothesis.numpy())
| |
|