import axengine as axe
import numpy as np
import librosa
import os
from typing import Union, List, Optional, Tuple
import json
from dataclasses import dataclass, field
import zhconv
import base64


@dataclass
class WhisperConfig:
    """Model hyper-parameters and special token ids used by ``Whisper``.

    All values default to 0 and are filled in by ``Whisper.load_config``
    from the model's JSON config file.
    """

    # Feature-extraction parameters.
    n_mels: int = 0
    sample_rate: int = 0
    n_fft: int = 0
    hop_length: int = 0
    # Special token ids, copied from the JSON config.
    sot: int = 0
    eot: int = 0
    blank_id: int = 0
    no_timestamps: int = 0
    no_speech: int = 0
    translate: int = 0
    transcribe: int = 0
    # Decoder dimensions, copied from the JSON config.
    n_vocab: int = 0
    n_text_ctx: int = 0
    n_text_state: int = 0
    # Prompt fed to the decoder before generation starts:
    # [sot, language token, task token, no_timestamps].
    sot_sequence: np.ndarray = field(
        default_factory=lambda: np.array([0, 0, 0, 0], dtype=np.int32)
    )
    # Used to size the first axis of the self-attention KV cache. Declared
    # last so positional construction of the earlier fields is unaffected.
    n_text_layer: int = 0


class Whisper:
    """Whisper speech recognizer running encoder/decoder ``.axmodel`` files."""

    def __init__(self, model_type: str, model_path: str, language: str, task: str):
        """Load the models, token tables and configuration.

        Args:
            model_type: Model name; used as both sub-directory and file prefix.
            model_path: Directory that contains the ``model_type`` directory.
            language: Language code; must appear in the config's
                ``all_language_codes``.
            task: ``"transcribe"`` selects the transcribe token; any other
                value selects the translate token.
        """
        self.language = language
        self.task = task
        self.encoder, self.decoder, model_config = self.load_model(
            model_type, model_path, language, task
        )
        self.config = self.load_config(model_config)

    def load_model(self, model_type, model_path, language, task):
        """Create the inference sessions and read the config and token files.

        Besides returning its results, this sets ``self.id2token``,
        ``self.lang2token`` and ``self.task2token``.

        Args:
            model_type: Model name; used as both sub-directory and file prefix.
            model_path: Directory that contains the ``model_type`` directory.
            language: Unused here; kept for interface compatibility.
            task: Unused here; kept for interface compatibility.

        Returns:
            ``(encoder_session, decoder_session, model_config_dict)``.

        Raises:
            AssertionError: If any of the four required files is missing.
        """
        encoder_path = f"{model_type}/{model_type}-encoder.axmodel"
        decoder_path = f"{model_type}/{model_type}-decoder.axmodel"
        model_config_file = f"{model_type}/{model_type}_config.json"
        token_file = f"{model_type}/{model_type}-tokens.txt"

        required_files = [
            os.path.join(model_path, i)
            for i in (encoder_path, decoder_path, model_config_file, token_file)
        ]

        # Check file existence before handing paths to the runtime.
        for file_path in required_files:
            assert os.path.exists(file_path), f"{file_path} NOT exist"

        encoder_file, decoder_file, config_file, tokens_file = required_files

        # Load encoder
        encoder = axe.InferenceSession(
            encoder_file, providers=["AxEngineExecutionProvider"]
        )
        # Load decoder
        decoder = axe.InferenceSession(
            decoder_file, providers=["AxEngineExecutionProvider"]
        )

        # Load the model config; the context manager closes the file handle.
        with open(config_file, "r") as f:
            model_config = json.load(f)

        # Both language tables are stored as comma-separated strings.
        model_config["all_language_tokens"] = [
            int(i) for i in model_config["all_language_tokens"].split(",")
        ]
        model_config["all_language_codes"] = model_config[
            "all_language_codes"
        ].split(",")

        # Load tokens and build the lookup tables used by run()/run_mel().
        self.id2token = self.load_tokens(tokens_file)
        self.lang2token = dict(
            zip(
                model_config["all_language_codes"],
                model_config["all_language_tokens"],
            )
        )
        self.task2token = {
            "transcribe": model_config["transcribe"],
            "translate": model_config["translate"],
        }

        return encoder, decoder, model_config

    def load_config(self, model_config):
        """Build a ``WhisperConfig`` instance from the parsed JSON config.

        Args:
            model_config: Dict returned by ``load_model`` (language tables
                already split into lists).

        Returns:
            A new ``WhisperConfig`` whose ``sot_sequence`` reflects
            ``self.language`` and ``self.task``.

        Raises:
            ValueError: If ``self.language`` is not in ``all_language_codes``.
            KeyError: If a required key is missing from ``model_config``.
        """
        # Instantiate the dataclass; assigning on the class itself would
        # share one mutable configuration between every Whisper object.
        config = WhisperConfig()
        config.n_mels = model_config["n_mels"]
        config.sample_rate = 16000
        config.n_fft = 480
        config.hop_length = 160

        config.sot = model_config["sot"]
        config.eot = model_config["eot"]
        config.blank_id = model_config["blank_id"]
        config.no_timestamps = model_config["no_timestamps"]
        config.no_speech = model_config["no_speech"]
        config.translate = model_config["translate"]
        config.transcribe = model_config["transcribe"]
        config.n_vocab = model_config["n_vocab"]
        config.n_text_ctx = model_config["n_text_ctx"]
        config.n_text_state = model_config["n_text_state"]
        config.n_text_layer = model_config["n_text_layer"]

        lang_token = model_config["all_language_tokens"][
            model_config["all_language_codes"].index(self.language)
        ]
        task_token = (
            config.transcribe if self.task == "transcribe" else config.translate
        )
        config.sot_sequence = np.array(
            [config.sot, lang_token, task_token, config.no_timestamps],
            dtype=np.int32,
        )

        return config

    def load_tokens(self, filename):
        """Read a token table of whitespace-separated ``<token> <id>`` lines.

        Args:
            filename: Path of the token file.

        Returns:
            Dict mapping integer id to the token string as stored in the
            file (``run_mel`` base64-decodes these strings).
        """
        tokens = dict()
        with open(filename, "r") as f:
            for line in f:
                parts = line.split()
                if not parts:
                    # Tolerate blank lines such as a trailing newline.
                    continue
                t, i = parts
                tokens[int(i)] = t
        return tokens

    def load_audio(self, audio: str):
        """Load an audio file at the configured sample rate.

        Args:
            audio: Path of the audio file.

        Returns:
            ``(samples, sample_rate)`` with ``samples`` C-contiguous.
        """
        # librosa.load resamples to ``sr`` itself, so no second resampling
        # pass is needed.
        samples, _ = librosa.load(audio, sr=self.config.sample_rate)
        samples = np.ascontiguousarray(samples)
        return samples, self.config.sample_rate

    def compute_feature(self, audio: np.ndarray):
        """Convert a waveform to the normalized log-mel input of the encoder.

        Args:
            audio: 1-D waveform sampled at ``self.config.sample_rate``.

        Returns:
            Array of shape ``(1, n_mels, 3000)``; longer inputs are
            truncated, shorter ones are zero-padded on the right.
        """
        mel = librosa.feature.melspectrogram(
            y=audio,
            sr=self.config.sample_rate,
            n_fft=self.config.n_fft,
            hop_length=self.config.hop_length,
            window="hann",
            center=True,
            pad_mode="reflect",
            power=2.0,
            n_mels=self.config.n_mels,
        )

        # Log-compress, clamp to 8 units below the peak, then rescale.
        log_spec = np.log10(np.maximum(mel, 1e-10))
        log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
        mel = (log_spec + 4.0) / 4.0

        # 3000 frames * hop 160 / 16000 Hz = 30 seconds.
        target = 3000
        if mel.shape[1] > target:
            # Truncate, and zero the last 50 frames so that there is some
            # zero tail padding.
            mel = mel[:, :target]
            mel[:, -50:] = 0

        # Short inputs are zero-padded up to the fixed frame count.
        if mel.shape[1] < target:
            mel = np.concatenate(
                (
                    mel,
                    np.zeros(
                        (self.config.n_mels, target - mel.shape[1]),
                        dtype=np.float32,
                    ),
                ),
                axis=-1,
            )

        return mel[np.newaxis, ...]

    def run_encoder(
        self,
        mel: np.ndarray,
    ) -> List[np.ndarray]:
        """Run the encoder on ``mel`` and return all of its outputs.

        ``run_mel`` unpacks the result as ``(cross_k, cross_v)``.
        """
        cross_kv = self.encoder.run(
            None,
            {
                self.encoder.get_inputs()[0].name: mel,
            },
        )
        return cross_kv

    def run_decoder(self, inputs: List[np.ndarray]) -> List[np.ndarray]:
        """Run the decoder, binding ``inputs`` to its inputs by position.

        Args:
            inputs: Arrays in the same order as the decoder's declared inputs.

        Returns:
            All decoder outputs.
        """
        # Query the input metadata once instead of once per element.
        decoder_inputs = self.decoder.get_inputs()
        feed = {decoder_inputs[i].name: value for i, value in enumerate(inputs)}

        out = self.decoder.run(
            None,
            feed,
        )
        return out

    def get_self_cache(self) -> Tuple[np.ndarray, np.ndarray]:
        """Allocate zeroed self-attention key/value caches.

        Returns:
            ``(self_k, self_v)``, each float32 of shape
            ``(n_text_layer, 1, n_text_ctx, n_text_state)``.
        """
        batch_size = 1
        cache_shape = (
            self.config.n_text_layer,
            batch_size,
            self.config.n_text_ctx,
            self.config.n_text_state,
        )
        self_k = np.zeros(cache_shape, dtype=np.float32)
        self_v = np.zeros(cache_shape, dtype=np.float32)
        return self_k, self_v

    def causal_mask_1d(self, n: int, L: int):
        """
        Returns a 1-D int mask of shape (L,) with:
          0 -> allowed
          1 -> masked (will be converted to -inf later)

        The first ``n`` positions are allowed; the rest are masked.
        """
        mask = np.ones((L,), dtype=np.int32)
        if n > 0:
            mask[:n] = 0
        return mask

    def _decode_step(self, token_id, self_k, self_v, cross_k, cross_v, offset):
        """Run one decoder step and update the caches and offset in place.

        Args:
            token_id: Id of the token fed to the decoder at this step.
            self_k, self_v: Self-attention caches; the slot at the current
                offset is overwritten with this step's output.
            cross_k, cross_v: Encoder outputs, passed through unchanged.
            offset: int32 array of shape ``(1,)`` holding the current
                position; incremented in place.

        Returns:
            The logits produced by the decoder for this step.
        """
        pos = offset.item()
        token = np.array([[token_id]], dtype=np.int32)
        mask = self.causal_mask_1d(pos, self.config.n_text_ctx)
        logits, this_self_k, this_self_v = self.run_decoder(
            [token, self_k, self_v, cross_k, cross_v, offset, mask]
        )
        self_k[:, :, pos : pos + 1, :] = this_self_k
        self_v[:, :, pos : pos + 1, :] = this_self_v
        # In-place add on the ndarray, so the caller's offset advances too.
        offset += 1
        return logits

    def run_mel(self, mel):
        """Greedy-decode text from a feature array.

        Args:
            mel: Encoder input as produced by ``compute_feature``.

        Returns:
            The decoded, stripped text. When ``self.language`` is ``"zh"``
            the text is converted with ``zhconv`` to ``zh-hans``; if that
            conversion fails the unconverted text is returned.
        """
        cross_k, cross_v = self.run_encoder(mel)
        self_k, self_v = self.get_self_cache()
        offset = np.array([0], dtype=np.int32)

        # Feed the prompt; only the logits of its last token are used.
        for t in self.config.sot_sequence:
            logits = self._decode_step(t, self_k, self_v, cross_k, cross_v, offset)

        idx = logits[0, 0].argmax()
        eot = self.config.eot

        # Greedy loop: stop at end-of-text or when the context is full.
        ans = []
        while idx != eot and offset.item() < self.config.n_text_ctx:
            ans.append(int(idx))
            logits = self._decode_step(
                idx, self_k, self_v, cross_k, cross_v, offset
            )
            idx = logits[0, 0].argmax()

        # Token strings are base64-encoded byte sequences; ids missing from
        # the table are skipped.
        s = b"".join(
            base64.b64decode(self.id2token[i]) for i in ans if i in self.id2token
        )
        text = s.decode().strip()

        if self.language == "zh":
            try:
                return zhconv.convert(text, "zh-hans")
            except Exception:
                # Best effort: fall back to the unconverted text.
                return text

        return text

    def run(
        self,
        audio: Union[str, np.ndarray],
        language: Optional[str] = None,
        task: Optional[str] = None,
    ) -> str:
        """Transcribe or translate an audio file or waveform.

        Args:
            audio: Path of an audio file, or a waveform array already at
                ``self.config.sample_rate``.
            language: Optional language code replacing the current one. The
                override persists for subsequent calls.
            task: Optional task (``"transcribe"`` or ``"translate"``)
                replacing the current one. The override persists as well.

        Returns:
            The decoded text.

        Raises:
            ValueError: If ``language`` or ``task`` is not a known key.
        """
        if isinstance(audio, str):
            audio, _ = self.load_audio(audio)

        mel = self.compute_feature(audio)

        if language is not None and self.language != language:
            if language not in self.lang2token:
                raise ValueError(f"Unsupported language: {language}")
            # lang2token is a dict: subscript it, do not call it.
            self.config.sot_sequence[1] = self.lang2token[language]
            # Keep the attribute in sync with the prompt so later
            # comparisons and the zh post-processing use the active language.
            self.language = language

        if task is not None and self.task != task:
            if task not in self.task2token:
                raise ValueError(f"Unsupported task: {task}")
            self.config.sot_sequence[2] = self.task2token[task]
            self.task = task

        return self.run_mel(mel)