WeNet / wenet /cli /paraformer_model.py
inoryQwQ's picture
First commit
3c50954
Raw
History Blame Contribute Delete
4.41 kB
import io
import os
from typing import Dict, List, Union
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
from wenet.cli.hub import Hub
from wenet.paraformer.search import (gen_timestamps_from_peak,
paraformer_greedy_search)
from wenet.text.paraformer_tokenizer import ParaformerTokenizer
class Paraformer:
    """Inference wrapper around a TorchScript-exported Paraformer model.

    The model directory must contain ``final.zip`` (loaded with
    ``torch.jit.load``) and ``units.txt`` (the symbol table passed to
    ``ParaformerTokenizer``).
    """

    def __init__(self, model_dir: str, resample_rate: int = 16000) -> None:
        """Load the TorchScript model and tokenizer from ``model_dir``.

        Args:
            model_dir: Directory holding ``final.zip`` and ``units.txt``.
            resample_rate: Sample rate (Hz) that input audio is resampled to
                before feature extraction.
        """
        model_path = os.path.join(model_dir, 'final.zip')
        units_path = os.path.join(model_dir, 'units.txt')
        self.model = torch.jit.load(model_path)
        self.resample_rate = resample_rate
        # Starts on CPU; callers such as ``load_model`` may reassign this and
        # move the model to another device.
        self.device = torch.device("cpu")
        self.tokenizer = ParaformerTokenizer(symbol_table=units_path)

    def _compute_feats(self, audio: Union[str, bytes]) -> torch.Tensor:
        """Load one audio input and return its 80-bin Kaldi fbank features.

        Args:
            audio: Path to an audio file, or the raw bytes of an encoded
                audio file.

        Returns:
            The ``kaldi.fbank`` output; its first dimension is the number of
            feature frames.
        """
        if isinstance(audio, bytes):
            with io.BytesIO(audio) as fobj:
                waveform, sample_rate = torchaudio.load(fobj,
                                                        normalize=False)
        else:
            waveform, sample_rate = torchaudio.load(audio, normalize=False)
        # ``normalize=False`` keeps the native sample scale (an integer dtype
        # for integer WAV files). Cast to float *before* resampling:
        # torchaudio's Resample raises TypeError on non-floating tensors.
        waveform = waveform.to(torch.float)
        if sample_rate != self.resample_rate:
            waveform = torchaudio.transforms.Resample(
                orig_freq=sample_rate,
                new_freq=self.resample_rate)(waveform)
        return kaldi.fbank(waveform,
                           num_mel_bins=80,
                           frame_length=25,
                           frame_shift=10,
                           energy_floor=0.0,
                           sample_frequency=self.resample_rate,
                           window_type="hamming")

    def _build_tokens_info(self, res, num_frames) -> List[Dict]:
        """Build per-token text/timing/confidence entries for one result.

        Args:
            res: One entry returned by ``paraformer_greedy_search``.
            num_frames: Frame count reported by the model for this utterance.

        Returns:
            One dict per token with keys ``'token'``, ``'start'``, ``'end'``
            and ``'confidence'``. Tokens without a matching timestamp are
            omitted.
        """
        times = gen_timestamps_from_peak(res.times,
                                         num_frames=num_frames,
                                         frame_rate=0.02)
        return [{
            'token': self.tokenizer.char_dict[tok],
            'start': round(times[j][0], 3),
            'end': round(times[j][1], 3),
            'confidence': round(res.tokens_confidence[j], 2),
        } for j, tok in enumerate(res.tokens[:len(times)])]

    @torch.inference_mode()
    def transcribe_batch(self,
                         audio_files: List[Union[str, bytes]],
                         tokens_info: bool = False) -> List[Dict]:
        """Transcribe a batch of audio inputs.

        Args:
            audio_files: Audio file paths and/or encoded audio bytes.
            tokens_info: When True, each result also carries a ``'tokens'``
                list with per-token text, start/end times and confidence.

        Returns:
            One dict per decoded utterance with keys ``'confidence'`` and
            ``'text'`` (plus ``'tokens'`` when requested). An empty input
            list yields an empty list.
        """
        if not audio_files:
            # pad_sequence cannot pad an empty list; there is nothing to decode.
            return []

        feats_lst = [self._compute_feats(audio) for audio in audio_files]
        feats_lens = [feats.shape[0] for feats in feats_lst]
        feats_tensor = torch.nn.utils.rnn.pad_sequence(
            feats_lst, batch_first=True).to(device=self.device)
        feats_lens_tensor = torch.tensor(feats_lens,
                                         dtype=torch.int64,
                                         device=self.device)

        decoder_out, token_num, tp_alphas, frames = self.model.forward_paraformer(
            feats_tensor, feats_lens_tensor)
        frames = frames.cpu().numpy()
        cif_peaks = self.model.forward_cif_peaks(tp_alphas, token_num)
        results = paraformer_greedy_search(decoder_out, token_num, cif_peaks)

        outputs = []
        for idx, res in enumerate(results):
            result = {
                'confidence': res.confidence,
                'text': self.tokenizer.detokenize(res.tokens)[0],
            }
            if tokens_info:
                result['tokens'] = self._build_tokens_info(res, frames[idx])
            outputs.append(result)
        return outputs

    def transcribe(self,
                   audio_file: Union[str, bytes],
                   tokens_info: bool = False) -> dict:
        """Transcribe a single audio input (path or encoded bytes)."""
        return self.transcribe_batch([audio_file], tokens_info)[0]

    def align(self, audio_file: str, label: str) -> dict:
        """Forced alignment is not supported for Paraformer."""
        raise NotImplementedError("Align is currently not supported")
def load_model(model_dir: str = None,
               gpu: int = -1,
               device: str = "cpu") -> Paraformer:
    """Construct a Paraformer and move its model to the requested device.

    Args:
        model_dir: Model directory. When None, it is obtained from
            ``Hub.get_model_by_lang('paraformer')``.
        gpu: Legacy switch. Any value other than -1 selects ``"cuda"`` and
            overrides ``device``; the index value itself is not used.
        device: Torch device string, used when ``gpu`` is -1.

    Returns:
        The Paraformer instance, with ``device`` set and its model moved there.
    """
    # Keep the historical meaning of ``gpu``: "any GPU", not a device index.
    target_device = "cuda" if gpu != -1 else device
    resolved_dir = (Hub.get_model_by_lang('paraformer')
                    if model_dir is None else model_dir)
    instance = Paraformer(resolved_dir)
    instance.device = torch.device(target_device)
    instance.model.to(target_device)
    return instance