# Whisper / python / whisper_ax.py
# Source: inoryQwQ, commit 798e40d — "Update cpp bins, python scripts, add English readme"
import axengine as axe
import numpy as np
import librosa
import os
from typing import Union, List
import json
from dataclasses import dataclass, field
import zhconv
import base64
@dataclass
class WhisperConfig:
    """Runtime configuration for an exported Whisper model.

    All fields are placeholders until populated from the model's JSON
    config file; the zero defaults are never meaningful on their own.
    """

    n_mels: int = 0  # number of mel filterbank bins expected by the encoder
    sample_rate: int = 0  # audio sample rate in Hz
    n_fft: int = 0  # FFT window size for the mel spectrogram
    hop_length: int = 0  # hop length (samples) between spectrogram frames
    sot: int = 0  # start-of-transcript token id
    eot: int = 0  # end-of-transcript token id
    blank_id: int = 0  # blank token id
    no_timestamps: int = 0  # "no timestamps" token id
    no_speech: int = 0  # "no speech" token id
    translate: int = 0  # translate-task token id
    transcribe: int = 0  # transcribe-task token id
    n_vocab: int = 0  # vocabulary size
    n_text_ctx: int = 0  # decoder context length (max tokens)
    n_text_state: int = 0  # decoder hidden state width
    # FIX: declared explicitly — load_config() assigns this field, but the
    # dataclass previously never declared it.
    n_text_layer: int = 0  # number of decoder layers
    # Decoder prompt: [sot, language, task, no_timestamps].
    sot_sequence: np.ndarray = field(
        default_factory=lambda: np.array([0, 0, 0, 0], dtype=np.int32)
    )
class Whisper:
    """Whisper speech-to-text inference on axengine (.axmodel) models.

    Wraps an exported encoder/decoder pair plus their sidecar token and
    config files, and exposes a single ``run()`` entry point that maps an
    audio file (or waveform) to the decoded text.
    """

    def __init__(self, model_type: str, model_path: str, language: str, task: str):
        """Load the model files found under ``model_path``/``model_type``.

        Args:
            model_type: model family prefix used to build the file names
                (e.g. "small" -> "small/small-encoder.axmodel").
            model_path: directory containing the ``model_type`` subdirectory.
            language: language code (a key of the model's language table).
            task: "transcribe" or "translate".
        """
        self.language = language
        self.task = task
        self.encoder, self.decoder, model_config = self.load_model(
            model_type, model_path, language, task
        )
        self.config = self.load_config(model_config)

    def load_model(self, model_type, model_path, language, task):
        """Open the encoder/decoder sessions and parse the sidecar files.

        Side effects: populates ``self.id2token``, ``self.lang2token`` and
        ``self.task2token``.

        Returns:
            (encoder_session, decoder_session, model_config_dict)

        Raises:
            FileNotFoundError: if any required model file is missing.
        """
        encoder_path = f"{model_type}/{model_type}-encoder.axmodel"
        decoder_path = f"{model_type}/{model_type}-decoder.axmodel"
        model_config_file = f"{model_type}/{model_type}_config.json"
        token_file = f"{model_type}/{model_type}-tokens.txt"
        required_files = [
            os.path.join(model_path, i)
            for i in (encoder_path, decoder_path, model_config_file, token_file)
        ]
        # FIX: raise instead of assert — asserts are stripped under `python -O`.
        for file_path in required_files:
            if not os.path.exists(file_path):
                raise FileNotFoundError(f"{file_path} NOT exist")
        # Load encoder
        encoder = axe.InferenceSession(
            required_files[0], providers=["AxEngineExecutionProvider"]
        )
        # Load decoder main
        decoder = axe.InferenceSession(
            required_files[1], providers=["AxEngineExecutionProvider"]
        )
        # FIX: use a context manager; json.load(open(...)) leaked the handle.
        with open(required_files[2], "r") as f:
            model_config = json.load(f)
        # The JSON stores these two tables as comma-separated strings.
        model_config["all_language_tokens"] = [
            int(i) for i in model_config["all_language_tokens"].split(",")
        ]
        model_config["all_language_codes"] = model_config[
            "all_language_codes"
        ].split(",")
        self.id2token = self.load_tokens(required_files[3])
        self.lang2token = dict(
            zip(
                model_config["all_language_codes"],
                model_config["all_language_tokens"],
            )
        )
        self.task2token = {
            "transcribe": model_config["transcribe"],
            "translate": model_config["translate"],
        }
        return encoder, decoder, model_config

    def load_config(self, model_config):
        """Build the runtime WhisperConfig from the parsed model config.

        The sample-rate / FFT parameters are fixed by the exported model.
        """
        # FIX: instantiate the dataclass. The original `config = WhisperConfig`
        # (no parentheses) mutated the class object itself, so every Whisper
        # instance in the process silently shared one configuration.
        config = WhisperConfig()
        config.n_mels = model_config["n_mels"]
        config.sample_rate = 16000
        config.n_fft = 480
        config.hop_length = 160
        config.sot = model_config["sot"]
        config.eot = model_config["eot"]
        config.blank_id = model_config["blank_id"]
        config.no_timestamps = model_config["no_timestamps"]
        config.no_speech = model_config["no_speech"]
        config.translate = model_config["translate"]
        config.transcribe = model_config["transcribe"]
        config.n_vocab = model_config["n_vocab"]
        config.n_text_ctx = model_config["n_text_ctx"]
        config.n_text_state = model_config["n_text_state"]
        config.n_text_layer = model_config["n_text_layer"]
        lang_token = model_config["all_language_tokens"][
            model_config["all_language_codes"].index(self.language)
        ]
        task_token = (
            config.transcribe if self.task == "transcribe" else config.translate
        )
        # Decoder prompt: <sot> <language> <task> <no_timestamps>.
        config.sot_sequence = np.array(
            [config.sot, lang_token, task_token, config.no_timestamps], dtype=np.int32
        )
        return config

    def load_tokens(self, filename):
        """Parse a "<base64-token> <id>" per-line file into {id: token}."""
        tokens = dict()
        with open(filename, "r") as f:
            for line in f:
                t, i = line.split()
                tokens[int(i)] = t
        return tokens

    def load_audio(self, audio: str):
        """Load ``audio`` as mono float samples at the model's sample rate.

        Returns (samples, sample_rate).
        """
        # librosa.load(sr=...) already resamples, so the branch below is a
        # defensive no-op in practice.
        samples, sample_rate = librosa.load(audio, sr=self.config.sample_rate)
        if sample_rate != self.config.sample_rate:
            samples = librosa.resample(
                samples, orig_sr=sample_rate, target_sr=self.config.sample_rate
            )
        samples = np.ascontiguousarray(samples)
        return samples, self.config.sample_rate

    def compute_feature(self, audio: np.ndarray):
        """Compute the (1, n_mels, 3000) log-mel feature the encoder expects."""
        mel = librosa.feature.melspectrogram(
            y=audio,
            sr=self.config.sample_rate,
            n_fft=self.config.n_fft,
            hop_length=self.config.hop_length,
            window="hann",
            center=True,
            pad_mode="reflect",
            power=2.0,
            n_mels=self.config.n_mels,
        )
        # Whisper-style log compression and normalization to roughly [-1, 1].
        log_spec = np.log10(np.maximum(mel, 1e-10))
        log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
        mel = (log_spec + 4.0) / 4.0
        target = 3000  # fixed encoder input length (30 s at 100 frames/s)
        if mel.shape[1] > target:
            # -50 so that there are some zero tail paddings.
            mel = mel[:, :target]
            mel[:, -50:] = 0
        # We don't need to pad it to 30 seconds now!
        if mel.shape[1] < target:
            mel = np.concatenate(
                (
                    mel,
                    np.zeros(
                        (self.config.n_mels, target - mel.shape[1]), dtype=np.float32
                    ),
                ),
                axis=-1,
            )
        return mel[np.newaxis, ...]

    def run_encoder(
        self,
        mel: np.ndarray,
    ) -> List[np.ndarray]:
        """Run the encoder; returns the cross-attention key/value tensors."""
        cross_kv = self.encoder.run(
            None,
            {
                self.encoder.get_inputs()[0].name: mel,
            },
        )
        return cross_kv

    def run_decoder(self, inputs: List[np.ndarray]) -> List[np.ndarray]:
        """Run one decoder step; ``inputs`` follow the session's input order."""
        feed = {
            self.decoder.get_inputs()[i].name: inputs[i] for i in range(len(inputs))
        }
        out = self.decoder.run(
            None,
            feed,
        )
        return out

    def get_self_cache(self) -> List[np.ndarray]:
        """Allocate zeroed self-attention K/V caches for a fresh decode."""
        batch_size = 1
        self_k = np.zeros(
            (
                self.config.n_text_layer,
                batch_size,
                self.config.n_text_ctx,
                self.config.n_text_state,
            ),
            dtype=np.float32,
        )
        self_v = np.zeros(
            (
                self.config.n_text_layer,
                batch_size,
                self.config.n_text_ctx,
                self.config.n_text_state,
            ),
            dtype=np.float32,
        )
        return self_k, self_v

    def causal_mask_1d(self, n: int, L: int):
        """
        Returns a 1-D int mask of shape (L,) with:
            0 -> allowed     (the first n positions, i.e. already-decoded)
            1 -> masked (will be converted to -inf later)
        """
        mask = np.ones((L,), dtype=np.int32)
        if n > 0:
            mask[:n] = 0
        return mask

    def run_mel(self, mel):
        """Greedy-decode one mel feature; returns the decoded text.

        Feeds the SOT prompt token-by-token, then decodes greedily until
        EOT or the context limit, updating the self-attention cache at the
        current offset after each step.
        """
        cross_k, cross_v = self.run_encoder(mel)
        self_k, self_v = self.get_self_cache()
        offset = np.array([0], dtype=np.int32)
        # Prime the decoder with the SOT prompt sequence.
        for t in self.config.sot_sequence:
            token = np.array([[t]], dtype=np.int32)  # sot
            mask = self.causal_mask_1d(offset.item(), self.config.n_text_ctx)
            logits, this_self_k, this_self_v = self.run_decoder(
                [token] + [self_k, self_v] + [cross_k, cross_v] + [offset, mask]
            )
            self_k[:, :, offset.item() : offset.item() + 1, :] = this_self_k
            self_v[:, :, offset.item() : offset.item() + 1, :] = this_self_v
            offset += 1
        idx = logits[0, 0].argmax()
        eot = self.config.eot
        ans = []
        # Greedy decoding loop, bounded by the decoder context length.
        while idx != eot and offset.item() < self.config.n_text_ctx:
            ans.append(idx)
            token = np.array([[idx]], dtype=np.int32)
            mask = self.causal_mask_1d(offset.item(), self.config.n_text_ctx)
            logits, this_self_k, this_self_v = self.run_decoder(
                [token] + [self_k, self_v] + [cross_k, cross_v] + [offset, mask]
            )
            self_k[:, :, offset.item() : offset.item() + 1, :] = this_self_k
            self_v[:, :, offset.item() : offset.item() + 1, :] = this_self_v
            offset += 1
            idx = logits[0, 0].argmax()
        # Tokens are stored base64-encoded in the token table; concatenate
        # the raw bytes first, then decode as UTF-8.
        s = b""
        for i in ans:
            if i in self.id2token:
                s += base64.b64decode(self.id2token[i])
        text = s.decode().strip()
        if self.language == "zh":
            # Best-effort conversion of traditional to simplified Chinese.
            # FIX: narrowed the bare `except:` (which also swallowed
            # KeyboardInterrupt/SystemExit) to `except Exception`.
            try:
                sim_zh = zhconv.convert(text, "zh-hans")
                return sim_zh
            except Exception:
                return text
        return text

    def run(
        self, audio: Union[str, np.ndarray], language: str = None, task: str = None
    ) -> str:
        """Transcribe (or translate) ``audio``.

        Args:
            audio: path to an audio file, or an already-loaded waveform at
                the model's sample rate.
            language: optional language-code override for this and later calls.
            task: optional task override ("transcribe"/"translate").
        """
        if isinstance(audio, str):
            audio, sample_rate = self.load_audio(audio)
        mel = self.compute_feature(audio)
        if language is not None and self.language != language:
            # FIX: lang2token/task2token are dicts; the original called them
            # like functions, so any override raised TypeError.
            self.config.sot_sequence[1] = self.lang2token[language]
            # Keep self.language in sync so the zh post-processing in
            # run_mel() matches the language actually being decoded.
            self.language = language
        if task is not None and self.task != task:
            self.config.sot_sequence[2] = self.task2token[task]
            self.task = task
        return self.run_mel(mel)