|
|
from fireredasr.data.asr_feat import ASRFeatExtractor |
|
|
from fireredasr.tokenizer.aed_tokenizer import ChineseCharEnglishSpmTokenizer |
|
|
|
|
|
import axengine as axe |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
import numpy as np |
|
|
from torch import Tensor |
|
|
from typing import Tuple, List, Dict |
|
|
import os |
|
|
import time |
|
|
import torchaudio |
|
|
from concurrent.futures import ThreadPoolExecutor, as_completed |
|
|
|
|
|
try:
    # The soundfile backend is required for reliable wav decoding.
    torchaudio.set_audio_backend("soundfile")
except Exception as e:
    # BUGFIX: the apt package name was misspelled ("libsnffile1");
    # the soundfile backend needs the libsndfile1 system library.
    print("Please run apt install libsndfile1 first")
    raise e
|
|
|
|
|
from silero_vad_axera import load_silero_vad, read_audio, get_speech_timestamps |
|
|
|
|
|
INF = 1e10 |
|
|
|
|
|
|
|
|
def to_numpy(tensor):
    """Convert a torch tensor to a NumPy array; NumPy arrays pass through unchanged."""
    if isinstance(tensor, np.ndarray):
        return tensor
    # Gradient-tracking tensors must be detached before conversion.
    source = tensor.detach() if tensor.requires_grad else tensor
    return source.cpu().numpy()
|
|
|
|
|
|
|
|
def set_finished_beam_score_to_zero(scores, is_finished):
    """Freeze the scores of finished hypotheses.

    A finished hypothesis contributes 0 through its first beam slot and -INF
    through the remaining slots, so its cumulative score stays fixed and its
    non-top expansions can never win a later top-k selection.
    """
    num_rows, beam = scores.size()
    done = is_finished.float()
    frozen = torch.full((num_rows, beam), -INF).float()
    frozen[:, 0] = 0.0
    return scores * (1 - done) + frozen * done
|
|
|
|
|
|
|
|
def set_finished_beam_y_to_eos(ys, is_finished, eos_id):
    """Replace the tokens of finished hypotheses with EOS; keep the rest unchanged."""
    done = is_finished.long()
    # Select ys where not finished, eos_id where finished.
    return ys * (1 - done) + eos_id * done
|
|
|
|
|
|
|
|
def expand_for_beam_search(n_layer_cross_k, beam_size):
    """Replicate encoder outputs beam_size times along the batch axis.

    An input of shape (num_layer, batch, Ti, dim) becomes
    (num_layer, batch * beam_size, Ti, dim), with each batch item's copies
    laid out contiguously (batch-major ordering).
    """
    # np.repeat on the batch axis produces the same contiguous ordering as the
    # original expand_dims + tile + reshape sequence.
    return np.repeat(n_layer_cross_k, beam_size, axis=1)
|
|
|
|
|
|
|
|
class FireRedASRAxModel: |
|
|
def __init__(self, |
|
|
encoder_path: str, |
|
|
decoder_loop_path: str, |
|
|
cmvn_file: str, |
|
|
dict_file: str, |
|
|
spm_model_path: str, |
|
|
providers=["AxEngineExecutionProvider"], |
|
|
decode_max_len=128, |
|
|
audio_dur=10): |
|
|
|
|
|
|
|
|
|
|
|
self.decode_max_len = decode_max_len |
|
|
self.sample_rate = 16000 |
|
|
self.decoder_hidden_dim = 1280 |
|
|
self.audio_dur = audio_dur |
|
|
self.max_feat_len = self.calc_feat_len(audio_dur) |
|
|
self.num_decoder_blocks = 16 |
|
|
self.blank_id = 0 |
|
|
self.sos_id = 3 |
|
|
self.eos_id = 4 |
|
|
self.pad_id = 2 |
|
|
|
|
|
self.feature_extractor = ASRFeatExtractor(cmvn_file) |
|
|
self.tokenizer = ChineseCharEnglishSpmTokenizer(dict_file, spm_model_path) |
|
|
|
|
|
self.init_encoder(encoder_path, providers) |
|
|
self.init_decoder_loop(decoder_loop_path, providers) |
|
|
self.pe = self.init_pe(decoder_loop_path) |
|
|
|
|
|
self.vad_model = load_silero_vad() |
|
|
|
|
|
|
|
|
self._preallocated_memory() |
|
|
|
|
|
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') |
|
|
|
|
|
|
|
|
def calc_feat_len(self, audio_dur): |
|
|
import math |
|
|
|
|
|
sample_rate = self.sample_rate |
|
|
frame_length = 25 * sample_rate / 1000 |
|
|
frame_shift = 10 * sample_rate / 1000 |
|
|
length = math.floor((audio_dur * sample_rate - frame_length) / frame_shift) + 1 |
|
|
return length |
|
|
|
|
|
def init_encoder(self, encoder_path, providers=None): |
|
|
self.encoder = axe.InferenceSession(encoder_path, providers=providers) |
|
|
|
|
|
def init_decoder_loop(self, decoder_path, providers=None): |
|
|
self.decoder_loop = axe.InferenceSession(decoder_path, providers=providers) |
|
|
|
|
|
def init_pe(self, decoder_path): |
|
|
decoder_path = os.path.dirname(decoder_path) |
|
|
decoder_path = os.path.join(decoder_path, "pe.npy") |
|
|
|
|
|
return np.load(decoder_path) |
|
|
|
|
|
def run_encoder( |
|
|
self, input: np.ndarray, input_length: np.ndarray |
|
|
) -> Tuple[Tensor, Tensor, Tensor]: |
|
|
n_layer_cross_k, n_layer_cross_v, cross_attn_mask = self.encoder.run( |
|
|
None, {"encoder_input": input, "encoder_input_lengths": input_length} |
|
|
) |
|
|
return (n_layer_cross_k, n_layer_cross_v, cross_attn_mask) |
|
|
|
|
|
def decode_loop_one_token( |
|
|
self, |
|
|
tokens: np.ndarray, |
|
|
n_layer_self_k_cache: np.ndarray, |
|
|
n_layer_self_v_cache: np.ndarray, |
|
|
n_layer_cross_k_cache: np.ndarray, |
|
|
n_layer_cross_v_cache: np.ndarray, |
|
|
pe: np.ndarray, |
|
|
self_attn_mask: np.ndarray, |
|
|
cross_attn_mask: np.ndarray, |
|
|
) -> Tuple[Tensor, Tensor, Tensor]: |
|
|
( |
|
|
logits, |
|
|
out_n_layer_self_k_cache, |
|
|
out_n_layer_self_v_cache, |
|
|
) = self.decoder_loop.run( |
|
|
None, |
|
|
{ |
|
|
"tokens": tokens, |
|
|
"in_n_layer_self_k_cache": n_layer_self_k_cache, |
|
|
"in_n_layer_self_v_cache": n_layer_self_v_cache, |
|
|
"n_layer_cross_k": n_layer_cross_k_cache, |
|
|
"n_layer_cross_v": n_layer_cross_v_cache, |
|
|
"pe": pe, |
|
|
"self_attn_mask": self_attn_mask, |
|
|
"cross_attn_mask": cross_attn_mask, |
|
|
}, |
|
|
) |
|
|
return (logits, out_n_layer_self_k_cache, out_n_layer_self_v_cache) |
|
|
|
|
|
def _preallocated_memory(self): |
|
|
"""预分配常用内存空间""" |
|
|
|
|
|
self.self_attn_mask_templates = {} |
|
|
for offset in range(self.decode_max_len): |
|
|
mask = np.zeros((1, 1, self.decode_max_len), dtype=np.float32) |
|
|
mask[:, :, :self.decode_max_len - offset - 1] = -np.inf |
|
|
self.self_attn_mask_templates[offset] = mask |
|
|
|
|
|
|
|
|
self.beam_scores_template = torch.tensor( |
|
|
[0.0] + [-INF] * (self.decode_max_len - 1) |
|
|
).float() |
|
|
|
|
|
def transcribe( |
|
|
self, |
|
|
batch_wav_path: List[str], |
|
|
beam_size: int = 1, |
|
|
nbest: int = 1, |
|
|
use_parallel: bool = False |
|
|
) -> List[Dict]: |
|
|
"""优化后的转录方法""" |
|
|
|
|
|
|
|
|
chunks = self._optimized_vad_split(batch_wav_path[0]) |
|
|
|
|
|
if use_parallel and len(chunks) > 1: |
|
|
return self._parallel_transcribe(chunks, beam_size, nbest) |
|
|
else: |
|
|
return self._sequential_transcribe(chunks, beam_size, nbest) |
|
|
|
|
|
def _optimized_vad_split(self, wav_path: str) -> List[torch.Tensor]: |
|
|
"""优化的VAD分块处理""" |
|
|
import torchaudio |
|
|
|
|
|
|
|
|
try: |
|
|
wav, sr = torchaudio.load(wav_path) |
|
|
if sr != self.sample_rate: |
|
|
wav = torchaudio.functional.resample(wav, sr, self.sample_rate) |
|
|
except: |
|
|
|
|
|
from silero_vad import read_audio |
|
|
wav = read_audio(wav_path, sampling_rate=self.sample_rate) |
|
|
wav = wav.unsqueeze(0) |
|
|
|
|
|
wav = wav.squeeze(0) |
|
|
|
|
|
|
|
|
max_chunk_samples = int(self.sample_rate * self.audio_dur) |
|
|
if wav.shape[0] < max_chunk_samples: |
|
|
return [wav] |
|
|
|
|
|
|
|
|
speech_timestamps = get_speech_timestamps( |
|
|
wav, |
|
|
self.vad_model, |
|
|
threshold=0.5, |
|
|
min_speech_duration_ms=250, |
|
|
min_silence_duration_ms=100, |
|
|
return_seconds=False, |
|
|
) |
|
|
|
|
|
|
|
|
return self._optimized_collect_chunks(wav, speech_timestamps) |
|
|
|
|
|
def _optimized_collect_chunks( |
|
|
self, |
|
|
wav: torch.Tensor, |
|
|
speech_timestamps: List[Dict] |
|
|
) -> List[torch.Tensor]: |
|
|
"""优化的分块合并算法""" |
|
|
max_chunk_samples = int(self.sample_rate * self.audio_dur) |
|
|
chunks = [] |
|
|
current_chunk = [] |
|
|
current_length = 0 |
|
|
|
|
|
for ts in speech_timestamps: |
|
|
start, end = ts["start"], ts["end"] |
|
|
chunk_length = end - start |
|
|
|
|
|
if current_length + chunk_length <= max_chunk_samples: |
|
|
current_chunk.append((start, end)) |
|
|
current_length += chunk_length |
|
|
else: |
|
|
if current_chunk: |
|
|
|
|
|
merged = torch.cat([wav[s:e] for s, e in current_chunk]) |
|
|
chunks.append(merged) |
|
|
|
|
|
if chunk_length > max_chunk_samples: |
|
|
|
|
|
num_splits = (chunk_length + max_chunk_samples - 1) // max_chunk_samples |
|
|
for i in range(num_splits): |
|
|
s = start + i * max_chunk_samples |
|
|
e = min(start + (i + 1) * max_chunk_samples, end) |
|
|
chunks.append(wav[s:e]) |
|
|
current_chunk = [] |
|
|
current_length = 0 |
|
|
else: |
|
|
current_chunk = [(start, end)] |
|
|
current_length = chunk_length |
|
|
|
|
|
|
|
|
if current_chunk: |
|
|
merged = torch.cat([wav[s:e] for s, e in current_chunk]) |
|
|
chunks.append(merged) |
|
|
|
|
|
return chunks |
|
|
|
|
|
def _optimized_decode_loop( |
|
|
self, |
|
|
n_layer_cross_k: np.ndarray, |
|
|
n_layer_cross_v: np.ndarray, |
|
|
cross_attn_mask: np.ndarray, |
|
|
beam_size: int, |
|
|
nbest: int |
|
|
) -> List[Dict]: |
|
|
"""优化的解码循环""" |
|
|
|
|
|
num_layer, batch_size, Ti, encoder_out_dim = n_layer_cross_k.shape |
|
|
encoder_out_length = cross_attn_mask.shape[-1] |
|
|
|
|
|
n_layer_cross_k = expand_for_beam_search(n_layer_cross_k, beam_size) |
|
|
n_layer_cross_v = expand_for_beam_search(n_layer_cross_v, beam_size) |
|
|
|
|
|
batch_size, Ti, encoder_out_length = cross_attn_mask.shape |
|
|
|
|
|
|
|
|
expanded = np.expand_dims(cross_attn_mask, axis=1) |
|
|
|
|
|
tiled = np.tile(expanded, (1, beam_size, 1, 1)) |
|
|
|
|
|
cross_attn_mask = tiled.reshape(beam_size * batch_size, Ti, encoder_out_length) |
|
|
|
|
|
|
|
|
n_layer_self_k_cache, n_layer_self_v_cache = self._optimized_init_self_cache( |
|
|
batch_size, beam_size |
|
|
) |
|
|
|
|
|
|
|
|
tokens = torch.full( |
|
|
(beam_size * batch_size, 1), |
|
|
self.sos_id, |
|
|
dtype=torch.int32, device=self.device |
|
|
) |
|
|
scores = self.beam_scores_template[:beam_size].repeat(batch_size).view( |
|
|
batch_size * beam_size, 1 |
|
|
).to(self.device) |
|
|
is_finished = torch.zeros_like(scores, dtype=torch.bool, device=self.device) |
|
|
|
|
|
|
|
|
prediction_tokens = tokens.clone() |
|
|
|
|
|
pe_np = self.pe |
|
|
|
|
|
for offset in range(self.decode_max_len): |
|
|
|
|
|
self_attn_mask = np.repeat( |
|
|
self.self_attn_mask_templates[offset], |
|
|
beam_size * batch_size, |
|
|
axis=0 |
|
|
) |
|
|
|
|
|
|
|
|
logits, n_layer_self_k_cache, n_layer_self_v_cache = ( |
|
|
self.decode_loop_one_token( |
|
|
tokens.cpu().numpy().astype(np.int32), |
|
|
n_layer_self_k_cache, |
|
|
n_layer_self_v_cache, |
|
|
n_layer_cross_k, |
|
|
n_layer_cross_v, |
|
|
pe_np[offset], |
|
|
self_attn_mask, |
|
|
cross_attn_mask |
|
|
) |
|
|
) |
|
|
|
|
|
logits = torch.from_numpy(logits).to(self.device).squeeze(1) |
|
|
t_scores = F.log_softmax(logits, dim=-1) |
|
|
|
|
|
|
|
|
tokens, scores, prediction_tokens, n_layer_self_k_cache, n_layer_self_v_cache, is_finished = ( |
|
|
self._optimized_beam_search( |
|
|
t_scores, tokens, scores, prediction_tokens, |
|
|
n_layer_self_k_cache, n_layer_self_v_cache, |
|
|
is_finished, beam_size, batch_size |
|
|
) |
|
|
) |
|
|
|
|
|
if is_finished.all(): |
|
|
break |
|
|
|
|
|
|
|
|
return self.extract_results_numpy_vectorized(scores.numpy(), prediction_tokens.numpy(), batch_size, beam_size, nbest) |
|
|
|
|
|
|
|
|
    def _optimized_beam_search(
        self,
        t_scores: torch.Tensor,
        tokens: torch.Tensor,
        scores: torch.Tensor,
        prediction_tokens: torch.Tensor,
        n_layer_self_k_cache: torch.Tensor,
        n_layer_self_v_cache: torch.Tensor,
        is_finished: torch.Tensor,
        beam_size: int,
        batch_size: int
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """One pruning step of batched beam search.

        Expands each live hypothesis with its top-`beam_size` next tokens,
        accumulates log-probabilities, and keeps the best `beam_size`
        hypotheses per batch item. Finished hypotheses are frozen: they
        contribute a single 0-score EOS continuation so their cumulative
        score never changes and their other expansions can never win.

        Args:
            t_scores: (batch*beam, vocab) log-softmax scores for this step.
            tokens: (batch*beam, 1) last token of each hypothesis.
            scores: (batch*beam, 1) cumulative log-prob per hypothesis.
            prediction_tokens: (batch*beam, t) token history including SOS.
            n_layer_self_k_cache / n_layer_self_v_cache: decoder KV caches.
                NOTE(review): annotated as torch.Tensor, but in practice these
                come back from the axengine decoder session (NumPy arrays);
                the fancy indexing below then relies on NumPy accepting CPU
                torch tensors as indices — confirm.
            is_finished: (batch*beam, 1) bool flags.

        Returns:
            Updated (tokens, scores, prediction_tokens, k_cache, v_cache,
            is_finished), all reordered to follow the surviving hypotheses.
        """
        # Top-B candidate continuations per hypothesis.
        t_topB_scores, t_topB_ys = torch.topk(t_scores, k=beam_size, dim=1)

        if is_finished.any():
            # Finished rows: slot 0 contributes 0 (score stays frozen), the
            # remaining slots get -INF, and the emitted token is forced to EOS.
            # is_finished is (NB, 1) and broadcasts across the beam columns.
            t_topB_scores.masked_fill_(is_finished, 0.0)
            t_topB_scores[:, 1:].masked_fill_(is_finished, -INF)
            t_topB_ys.masked_fill_(is_finished, self.eos_id)

        # Accumulate: (NB, 1) + (NB, B) broadcasts to (NB, B).
        scores = scores + t_topB_scores

        # Keep the best beam_size of the beam*beam candidates per batch item.
        scores_2d = scores.view(batch_size, beam_size * beam_size)
        top_scores, top_ids = torch.topk(scores_2d, k=beam_size, dim=1)
        scores = top_scores.view(-1, 1)

        # Map each winner back to the flat row index of the parent hypothesis
        # it extends (candidate i came from parent row i // beam_size).
        topB_row_number_in_each_B_rows_of_ys = torch.div(top_ids, beam_size, rounding_mode='floor')
        stride = beam_size * torch.arange(batch_size, device=self.device).view(batch_size, 1)
        topB_row_number_in_ys = (topB_row_number_in_each_B_rows_of_ys + stride).view(-1)

        # Winning continuation token for each surviving hypothesis.
        tokens = torch.gather(
            t_topB_ys.view(batch_size, beam_size * beam_size),
            dim=1,
            index=top_ids,
        ).view(beam_size * batch_size, 1)

        # Reorder histories to the surviving parents and append the new token.
        prediction_tokens = torch.cat([
            prediction_tokens[topB_row_number_in_ys],
            tokens
        ], dim=1)

        # Reorder the per-layer KV caches to follow the surviving parents.
        for i in range(n_layer_self_k_cache.shape[0]):
            n_layer_self_k_cache[i] = n_layer_self_k_cache[i][topB_row_number_in_ys]
            n_layer_self_v_cache[i] = n_layer_self_v_cache[i][topB_row_number_in_ys]

        # A hypothesis is finished once its newest token is EOS.
        is_finished = tokens.eq(self.eos_id)

        return tokens, scores, prediction_tokens, n_layer_self_k_cache, n_layer_self_v_cache, is_finished
|
|
|
|
|
def _optimized_init_self_cache( |
|
|
self, batch_size: int, beam_size: int |
|
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
|
"""优化的self cache初始化""" |
|
|
shape = ( |
|
|
self.num_decoder_blocks, |
|
|
batch_size * beam_size, |
|
|
self.decode_max_len, |
|
|
self.decoder_hidden_dim |
|
|
) |
|
|
n_layer_self_k_cache = np.zeros(shape, dtype=np.float32) |
|
|
n_layer_self_v_cache = np.zeros(shape, dtype=np.float32) |
|
|
return n_layer_self_k_cache, n_layer_self_v_cache |
|
|
|
|
|
def _extract_results( |
|
|
self, |
|
|
scores: torch.Tensor, |
|
|
prediction_tokens: torch.Tensor, |
|
|
batch_size: int, |
|
|
beam_size: int, |
|
|
nbest: int |
|
|
) -> List[Dict]: |
|
|
"""提取结果""" |
|
|
scores = scores.view(batch_size, beam_size) |
|
|
valid_lengths = torch.sum( |
|
|
torch.ne(prediction_tokens.view(batch_size, beam_size, -1), self.eos_id), |
|
|
dim=-1 |
|
|
).int() |
|
|
|
|
|
nbest_scores, nbest_ids = torch.topk(scores, k=nbest, dim=1) |
|
|
index = nbest_ids + beam_size * torch.arange(batch_size, device=self.device).unsqueeze(1) |
|
|
|
|
|
nbest_tokens = prediction_tokens.view(batch_size * beam_size, -1)[index.view(-1)] |
|
|
nbest_tokens = nbest_tokens.view(batch_size, nbest_ids.size(1), -1) |
|
|
|
|
|
results = [] |
|
|
for j, score in enumerate(nbest_scores[0]): |
|
|
hyp = { |
|
|
"token_ids": nbest_tokens[0, j, 1:valid_lengths[0, nbest_ids[0, j]]], |
|
|
"score": score, |
|
|
} |
|
|
results.append(hyp) |
|
|
|
|
|
return results |
|
|
|
|
|
|
|
|
def extract_results_numpy_vectorized( |
|
|
self, |
|
|
scores: np.ndarray, |
|
|
prediction_tokens: np.ndarray, |
|
|
batch_size: int, |
|
|
beam_size: int, |
|
|
nbest: int, |
|
|
eos_id: int = 4 |
|
|
) -> List[Dict]: |
|
|
"""向量化版本的NumPy实现""" |
|
|
|
|
|
|
|
|
scores_2d = scores.reshape(batch_size, beam_size) |
|
|
tokens_3d = prediction_tokens.reshape(batch_size, beam_size, -1) |
|
|
|
|
|
|
|
|
valid_lengths = np.sum(tokens_3d != eos_id, axis=-1).astype(np.int32) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
partitioned_indices = np.argpartition(-scores_2d, nbest-1, axis=1)[:, :nbest] |
|
|
|
|
|
|
|
|
nbest_scores = np.take_along_axis(scores_2d, partitioned_indices, axis=1) |
|
|
sorted_order = np.argsort(-nbest_scores, axis=1) |
|
|
|
|
|
|
|
|
nbest_ids = np.take_along_axis(partitioned_indices, sorted_order, axis=1) |
|
|
nbest_scores = np.take_along_axis(nbest_scores, sorted_order, axis=1) |
|
|
|
|
|
|
|
|
batch_indices = np.arange(batch_size)[:, np.newaxis] |
|
|
global_indices = nbest_ids + beam_size * batch_indices |
|
|
flat_global_indices = global_indices.reshape(-1) |
|
|
|
|
|
|
|
|
flat_tokens = prediction_tokens.reshape(-1, prediction_tokens.shape[-1]) |
|
|
nbest_tokens = flat_tokens[flat_global_indices] |
|
|
nbest_tokens = nbest_tokens.reshape(batch_size, nbest, -1) |
|
|
|
|
|
|
|
|
nbest_valid_lengths = np.take_along_axis(valid_lengths, nbest_ids, axis=1) |
|
|
|
|
|
|
|
|
results = [] |
|
|
|
|
|
for b in range(batch_size): |
|
|
batch_results = [] |
|
|
for j in range(nbest): |
|
|
valid_len = nbest_valid_lengths[b, j] |
|
|
|
|
|
|
|
|
token_ids = nbest_tokens[b, j, 1:valid_len] |
|
|
|
|
|
hyp = { |
|
|
"token_ids": token_ids.tolist(), |
|
|
"score": float(nbest_scores[b, j]), |
|
|
} |
|
|
batch_results.append(hyp) |
|
|
|
|
|
|
|
|
|
|
|
if b == 0: |
|
|
results = batch_results |
|
|
|
|
|
return results |
|
|
|
|
|
|
|
|
def _sequential_transcribe( |
|
|
self, |
|
|
chunks: List[torch.Tensor], |
|
|
beam_size: int, |
|
|
nbest: int |
|
|
) -> Dict: |
|
|
"""顺序转录(单线程)""" |
|
|
tokens = [] |
|
|
wav_durations = [] |
|
|
transcribe_duration = 0 |
|
|
|
|
|
for chunk in chunks: |
|
|
|
|
|
feats, lengths, wav_duration = self._optimized_feature_extraction(chunk) |
|
|
wav_durations.append(wav_duration) |
|
|
|
|
|
|
|
|
start_time = time.time() |
|
|
n_layer_cross_k, n_layer_cross_v, cross_attn_mask = self.run_encoder( |
|
|
feats, lengths.numpy().astype(np.int32) |
|
|
) |
|
|
|
|
|
nbest_hyps = self._optimized_decode_loop( |
|
|
n_layer_cross_k, n_layer_cross_v, cross_attn_mask, beam_size, nbest |
|
|
) |
|
|
|
|
|
tokens.extend([int(id) for id in nbest_hyps[0]["token_ids"]]) |
|
|
transcribe_duration += time.time() - start_time |
|
|
|
|
|
text = self.tokenizer.detokenize(tokens) |
|
|
return {"text": text}, wav_durations, transcribe_duration |
|
|
|
|
|
def _parallel_transcribe( |
|
|
self, |
|
|
chunks: List[torch.Tensor], |
|
|
beam_size: int, |
|
|
nbest: int |
|
|
) -> Dict: |
|
|
"""并行转录(多线程)""" |
|
|
import threading |
|
|
|
|
|
results = [] |
|
|
lock = threading.Lock() |
|
|
|
|
|
def process_chunk(chunk_idx, chunk): |
|
|
try: |
|
|
|
|
|
feats, lengths, wav_duration = self._optimized_feature_extraction(chunk) |
|
|
|
|
|
|
|
|
n_layer_cross_k, n_layer_cross_v, cross_attn_mask = self.run_encoder( |
|
|
feats, lengths.astype(np.int32) |
|
|
) |
|
|
|
|
|
|
|
|
nbest_hyps = self._optimized_decode_loop( |
|
|
n_layer_cross_k, n_layer_cross_v, cross_attn_mask, beam_size, nbest |
|
|
) |
|
|
|
|
|
with lock: |
|
|
results.append({ |
|
|
'chunk_idx': chunk_idx, |
|
|
'tokens': [int(id) for id in nbest_hyps[0]["token_ids"].cpu()], |
|
|
'duration': wav_duration |
|
|
}) |
|
|
except Exception as e: |
|
|
print(f"Error processing chunk {chunk_idx}: {e}") |
|
|
|
|
|
|
|
|
with ThreadPoolExecutor(max_workers=min(4, len(chunks))) as executor: |
|
|
futures = [] |
|
|
for i, chunk in enumerate(chunks): |
|
|
future = executor.submit(process_chunk, i, chunk) |
|
|
futures.append(future) |
|
|
|
|
|
|
|
|
for future in as_completed(futures): |
|
|
future.result() |
|
|
|
|
|
|
|
|
results.sort(key=lambda x: x['chunk_idx']) |
|
|
tokens = [] |
|
|
wav_durations = [] |
|
|
|
|
|
for result in results: |
|
|
tokens.extend(result['tokens']) |
|
|
wav_durations.append(result['duration']) |
|
|
|
|
|
text = self.tokenizer.detokenize(tokens) |
|
|
return {"text": text}, wav_durations, 0 |
|
|
|
|
|
def _optimized_feature_extraction( |
|
|
self, |
|
|
chunk: torch.Tensor |
|
|
) -> Tuple[np.ndarray, np.ndarray, float]: |
|
|
"""优化的特征提取""" |
|
|
chunk = (chunk.clamp(-1, 1) * 32768).to(torch.int16) |
|
|
feats, lengths, wav_duration = self.feature_extractor.run_chunk( |
|
|
chunk, self.sample_rate |
|
|
) |
|
|
|
|
|
|
|
|
if feats.shape[1] < self.max_feat_len: |
|
|
pad_width = ((0, 0), (0, self.max_feat_len - feats.shape[1]), (0, 0)) |
|
|
feats = np.pad(feats, pad_width, mode='constant', constant_values=0) |
|
|
|
|
|
feats = feats[:, :self.max_feat_len, :] |
|
|
lengths = np.minimum(lengths, self.max_feat_len) |
|
|
|
|
|
return feats, lengths, wav_duration |