#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/FunAudioLLM/SenseVoice). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)

import os.path
from pathlib import Path
from typing import List, Tuple, Union

import torch
import numpy as np
import axengine as axe
from funasr.utils.postprocess_utils import rich_transcription_postprocess

try:
    import librosa

    HAS_LIBROSA = True
except ImportError:
    HAS_LIBROSA = False
    print("Warning: librosa not found. Please install it using 'pip install librosa'.")

    # Fallback WAV reader used when librosa is unavailable.
    # Assumes 16-bit PCM and does not resample.
    def load_wav_fallback(path, sr=None):
        import wave

        with wave.open(path, "rb") as wf:
            num_frames = wf.getnframes()
            frames = wf.readframes(num_frames)
            waveform = np.frombuffer(frames, dtype=np.int16).astype(np.float32) / 32768.0
            return waveform, wf.getframerate()

from utils.infer_utils import (
    CharTokenizer,
    get_logger,
    read_yaml,
)
from utils.frontend import WavFrontend
from utils.ctc_alignment import ctc_forced_align

logging = get_logger()


def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device=None):
    """Build a (batch, maxlen) mask that is 1 where t < lengths[b] and 0 elsewhere."""
    if maxlen is None:
        maxlen = lengths.max()
    row_vector = torch.arange(0, maxlen, 1).to(lengths.device)
    matrix = torch.unsqueeze(lengths, dim=-1)
    mask = row_vector < matrix
    mask = mask.detach()
    return mask.type(dtype).to(device) if device is not None else mask.type(dtype)
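
# Illustrative example (not part of the pipeline): for a single utterance of
# length 3 padded to maxlen 5,
#     sequence_mask(torch.IntTensor([3]), maxlen=5)
# evaluates to tensor([[1., 1., 1., 0., 0.]]). __call__ below adds a middle
# axis via [:, None, :] to get the (batch, 1, seq_len) mask the model expects.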


class AX_SenseVoiceSmall:
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive
    End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317
    """

    def __init__(
        self,
        model_dir: Union[str, Path] = None,
        batch_size: int = 1,
        seq_len: int = 68,
    ):
        model_file = os.path.join(model_dir, "sensevoice.axmodel")
        config_file = os.path.join(model_dir, "sensevoice/config.yaml")
        cmvn_file = os.path.join(model_dir, "sensevoice/am.mvn")
        config = read_yaml(config_file)
        self.model_dir = model_dir

        # token_list = os.path.join(model_dir, "tokens.json")
        # with open(token_list, "r", encoding="utf-8") as f:
        #     token_list = json.load(f)
        # self.converter = TokenIDConverter(token_list)
        self.tokenizer = CharTokenizer()

        config["frontend_conf"]["cmvn_file"] = cmvn_file
        self.frontend = WavFrontend(**config["frontend_conf"])

        # self.ort_infer = OrtInferSession(
        #     model_file, device_id, intra_op_num_threads=intra_op_num_threads
        # )
        # self.session = axe.InferenceSession(model_file, providers='AxEngineExecutionProvider')
        self.session = axe.InferenceSession(model_file)

        self.batch_size = batch_size
        self.blank_id = 0
        self.seq_len = seq_len

        self.lid_dict = {"auto": 0, "zh": 3, "en": 4, "yue": 7, "ja": 11, "ko": 12, "nospeech": 13}
        self.lid_int_dict = {24884: 3, 24885: 4, 24888: 7, 24892: 11, 24896: 12, 24992: 13}
        self.textnorm_dict = {"withitn": 14, "woitn": 15}
        self.textnorm_int_dict = {25016: 14, 25017: 15}
        self.emo_dict = {"unk": 25009, "happy": 25001, "sad": 25002, "angry": 25003, "neutral": 25004}

    def __call__(
        self,
        wav_content: Union[str, np.ndarray, List[str]],
        language: str,
        withitn: bool,
        position_encoding: np.ndarray,
        tokenizer=None,
        **kwargs,
    ) -> str:
        """Run the full recognition pipeline on the given audio.

        Args:
            wav_content: Audio data, a file path, or a list of file paths
            language: Language key used to pick the language query (see self.lid_dict)
            withitn: Whether to apply ITN (inverse text normalization)
            position_encoding: Position encoding array fed to the model
            tokenizer: Tokenizer used to convert token ids to text
            **kwargs: Additional arguments (unused)
        """
        # Track per-stage timings for metadata
        import time

        meta_data = {}
        time_start = time.perf_counter()

        # Load waveform data
        waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
        waveform_nums = len(waveform_list)
        time_load = time.perf_counter()
        meta_data["load_data"] = f"{time_load - time_start:0.3f}"

        # Load the prompt queries from pre-exported numpy files
        language_query = np.load(os.path.join(self.model_dir, f"{language}.npy"))
        textnorm_query = np.load(
            os.path.join(self.model_dir, "withitn.npy")
            if withitn
            else os.path.join(self.model_dir, "woitn.npy")
        )
        event_emo_query = np.load(os.path.join(self.model_dir, "event_emo.npy"))

        # Concatenate the queries to form the input prefix
        input_query = np.concatenate((language_query, event_emo_query, textnorm_query), axis=1)

        results = ""
        # The model input is fixed-length; 4 frames are reserved for the query prefix
        slice_len = self.seq_len - 4
        time_pre = time.perf_counter()
        meta_data["preprocess"] = f"{time_pre - time_load:0.3f}"

        for beg_idx in range(0, waveform_nums, self.batch_size):
            end_idx = min(waveform_nums, beg_idx + self.batch_size)
            feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
            time_feat = time.perf_counter()
            meta_data["extract_feat"] = f"{time_feat - time_pre:0.3f}"

            for i in range(int(np.ceil(feats.shape[1] / slice_len))):
                # Prepend the query prefix to the current feature slice
                sub_feats = np.concatenate(
                    [input_query, feats[:, i * slice_len : (i + 1) * slice_len, :]], axis=1
                )
                feats_len[0] = sub_feats.shape[1]
                # Zero-pad the last slice up to the fixed sequence length
                if feats_len[0] < self.seq_len:
                    sub_feats = np.concatenate(
                        [sub_feats, np.zeros((1, self.seq_len - feats_len[0], 560), dtype=np.float32)],
                        axis=1,
                    )
                masks = sequence_mask(
                    torch.IntTensor([self.seq_len]), maxlen=self.seq_len, dtype=torch.float32
                )[:, None, :]
                masks = masks.numpy()

                # Run inference
                ctc_logits, encoder_out_lens = self.infer(sub_feats, masks, position_encoding)
                ctc_logits = torch.from_numpy(ctc_logits).float()

                # Greedy CTC decoding for each item in the batch
                b, _, _ = ctc_logits.size()
                for j in range(b):
                    x = ctc_logits[j, : encoder_out_lens[j].item(), :]
                    yseq = x.argmax(dim=-1)
                    yseq = torch.unique_consecutive(yseq, dim=-1)
                    mask = yseq != self.blank_id
                    # Skip the first 4 tokens: <|zh|><|ANGRY|><|Speech|><|withitn|>
                    token_int = yseq[mask].tolist()[4:]

                    # Convert token ids to text when a tokenizer is provided
                    if tokenizer is not None:
                        results += tokenizer.decode(token_int)
                    else:
                        results += str(token_int)
        return results

    def load_data(self, wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
        def load_wav(path: str) -> np.ndarray:
            try:
                if HAS_LIBROSA:
                    waveform, _ = librosa.load(path, sr=fs)
                else:
                    # Use the plain-wave fallback reader defined above
                    waveform, native_sr = load_wav_fallback(path)
                    if fs is not None and native_sr != fs:
                        print(
                            f"Warning: resampling from {native_sr} to {fs} "
                            "is not implemented in fallback mode"
                        )
                return waveform
            except Exception as e:
                print(f"Error loading audio file {path}: {e}")
                # Return a short stretch of silence so downstream code keeps working
                return np.zeros(1600, dtype=np.float32)

        if isinstance(wav_content, np.ndarray):
            return [wav_content]
        if isinstance(wav_content, str):
            return [load_wav(wav_content)]
        if isinstance(wav_content, list):
            return [load_wav(path) for path in wav_content]
        raise TypeError(f"The type of {wav_content} is not in [str, np.ndarray, list]")

    def extract_feat(self, waveform_list: List[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
        feats, feats_len = [], []
        for waveform in waveform_list:
            speech, _ = self.frontend.fbank(waveform)
            feat, feat_len = self.frontend.lfr_cmvn(speech)
            feats.append(feat)
            feats_len.append(feat_len)

        feats = self.pad_feats(feats, np.max(feats_len))
        feats_len = np.array(feats_len).astype(np.int32)
        return feats, feats_len
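
    # NOTE (assumed from the code above, not from model documentation): the
    # frontend's LFR stacking yields 560-dim features (e.g. 7 stacked 80-dim
    # fbank frames), which is why __call__ zero-pads with shape
    # (1, seq_len - n, 560). Each fixed-length window fed to the model is
    # therefore (1, seq_len, 560): 4 query frames plus up to seq_len - 4
    # speech frames.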

    @staticmethod
    def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
        def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray:
            pad_width = ((0, max_feat_len - cur_len), (0, 0))
            return np.pad(feat, pad_width, "constant", constant_values=0)

        feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats]
        feats = np.array(feat_res).astype(np.float32)
        return feats

    def infer(
        self,
        feats: np.ndarray,
        masks: np.ndarray,
        position_encoding: np.ndarray,
    ) -> Tuple[np.ndarray, np.ndarray]:
        # outputs = self.ort_infer([feats, masks, position_encoding])
        outputs = self.session.run(
            None,
            {
                "speech": feats,
                "masks": masks,
                "position_encoding": position_encoding,
            },
        )
        return outputs
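

# ---------------------------------------------------------------------------
# Usage sketch. Everything below is illustrative: "./models" and "./example.wav"
# are hypothetical paths, and "position_encoding.npy" is an assumed file name
# for the pre-exported position encoding (match it to your export pipeline).
# It also assumes the CharTokenizer created in __init__ exposes the decode()
# that __call__ relies on; substitute your own tokenizer otherwise.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model_dir = "./models"  # hypothetical dir holding sensevoice.axmodel and the .npy queries
    model = AX_SenseVoiceSmall(model_dir=model_dir, batch_size=1, seq_len=68)

    # Hypothetical: load the position encoding exported next to the query files
    position_encoding = np.load(os.path.join(model_dir, "position_encoding.npy"))

    raw_text = model(
        "./example.wav",
        language="zh",  # any key of model.lid_dict with a matching <lang>.npy
        withitn=True,
        position_encoding=position_encoding,
        tokenizer=model.tokenizer,
    )
    # rich_transcription_postprocess strips SenseVoice's rich tags (<|zh|>, <|HAPPY|>, ...)
    print(rich_transcription_postprocess(raw_text))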