|
|
import argparse |
|
|
import json |
|
|
from datetime import datetime, timezone |
|
|
from pathlib import Path |
|
|
|
|
|
import numpy as np |
|
|
import onnxruntime as ort |
|
|
|
|
|
from liteasr_ffi import LiteAsrFfi |
|
|
|
|
|
# huggingface_hub is an optional dependency: when it is missing, downloads are
# skipped and the user must place tokenizer/config files manually
# (download_file below enforces this for required files).
try:
    from huggingface_hub import hf_hub_download
except Exception:
    hf_hub_download = None
|
|
|
|
|
# Full language names (as accepted on the CLI via --language) mapped to the
# two-letter codes used inside Whisper tokenizer tokens such as <|ja|>.
LANGUAGE_MAP = {
    "japanese": "ja",
    "english": "en",
    "chinese": "zh",
    "korean": "ko",
    "french": "fr",
    "german": "de",
    "spanish": "es",
    "italian": "it",
    "portuguese": "pt",
    "russian": "ru",
}
|
|
|
|
|
|
|
|
def load_json(path: Path) -> dict | None: |
|
|
if not path.exists(): |
|
|
return None |
|
|
return json.loads(path.read_text(encoding="utf-8")) |
|
|
|
|
|
|
|
|
def save_json(path: Path, payload: dict) -> None:
    """Write payload as pretty-printed, ASCII-escaped JSON with a trailing newline."""
    text = json.dumps(payload, indent=2, ensure_ascii=True)
    path.write_text(text + "\n", encoding="utf-8")
|
|
|
|
|
|
|
|
def resolve_model_id(model_id: str | None, model_name: str | None) -> str: |
|
|
if model_id: |
|
|
return model_id |
|
|
if model_name: |
|
|
return model_name.replace("__", "/") |
|
|
model_file = Path(".model_id") |
|
|
if model_file.exists(): |
|
|
value = model_file.read_text(encoding="utf-8").strip() |
|
|
if value: |
|
|
return value |
|
|
raise SystemExit("Model ID is required. Pass --model-id or --model-name.") |
|
|
|
|
|
|
|
|
def model_name_from_id(model_id: str) -> str:
    """Flatten a Hub repo ID into a filesystem-safe directory name."""
    return "__".join(model_id.split("/"))
|
|
|
|
|
|
|
|
def resolve_audio_path(path_value: str | None) -> Path: |
|
|
if path_value: |
|
|
return Path(path_value) |
|
|
fallback = Path("samples") / "a01.wav" |
|
|
if fallback.exists(): |
|
|
return fallback |
|
|
raise SystemExit("Audio path is required. Pass --audio or add samples/a01.wav.") |
|
|
|
|
|
|
|
|
def download_file(repo_id: str, target_dir: Path, filename: str, required: bool) -> None:
    """Fetch `filename` from the Hub into `target_dir` unless already present.

    With huggingface_hub unavailable, or on download failure, the error is
    fatal only for required files; optional files are skipped silently.
    """
    target_dir.mkdir(parents=True, exist_ok=True)
    destination = target_dir / filename
    if destination.exists():
        return
    if hf_hub_download is None:
        if not required:
            return
        raise SystemExit(
            "huggingface_hub not installed. Install it or place tokenizer/config files manually."
        )
    try:
        hf_hub_download(repo_id=repo_id, filename=filename, local_dir=target_dir)
    except Exception:
        if required:
            raise
|
|
|
|
|
|
|
|
def normalize_token_list(value) -> list[int]:
    """Coerce a scalar-or-list token spec into a clean list of ints.

    Scalars are kept only when non-negative; list entries are kept only when
    numeric (ints/floats, truncated to int). Anything else yields [].
    """
    if isinstance(value, int):
        return [int(value)] if value >= 0 else []
    if isinstance(value, list):
        return [int(entry) for entry in value if isinstance(entry, (int, float))]
    return []
|
|
|
|
|
|
|
|
def parse_forced_decoder_ids(value, decoder_start_token_id: int | None) -> list[int] | None: |
|
|
if not value: |
|
|
return None |
|
|
try: |
|
|
ordered = sorted(value, key=lambda x: x[0]) |
|
|
tokens = [int(item[1]) for item in ordered] |
|
|
if decoder_start_token_id is not None: |
|
|
if not tokens or tokens[0] != int(decoder_start_token_id): |
|
|
tokens = [int(decoder_start_token_id)] + tokens |
|
|
return tokens |
|
|
except Exception: |
|
|
return None |
|
|
|
|
|
|
|
|
def load_vocab(tokenizer_json_path: Path) -> dict[str, int]:
    """Build a token->id mapping from tokenizer.json.

    Merges the base model vocabulary with `added_tokens` entries (added
    tokens win on conflict). Exits when the file is missing or lacks
    model.vocab.
    """
    payload = load_json(tokenizer_json_path)
    if payload is None:
        raise SystemExit(f"tokenizer.json not found: {tokenizer_json_path}")
    raw_vocab = payload.get("model", {}).get("vocab")
    if not isinstance(raw_vocab, dict):
        raise SystemExit("tokenizer.json missing model.vocab")
    vocab: dict[str, int] = {
        token: token_id
        for token, token_id in raw_vocab.items()
        if isinstance(token_id, int)
    }
    for entry in payload.get("added_tokens", []) or []:
        if not isinstance(entry, dict):
            continue
        content = entry.get("content")
        token_id = entry.get("id")
        if isinstance(content, str) and isinstance(token_id, int):
            vocab[content] = token_id
    return vocab
|
|
|
|
|
|
|
|
def remove_optional_prompt_tokens(
    prompt_ids: list[int],
    vocab: dict[str, int],
    language: str,
    omit_language_token: bool,
    omit_notimestamps_token: bool,
) -> list[int]:
    """Optionally strip the language token and <|notimestamps|> from a prompt.

    The language token is looked up first as <|{language}|>, then via the
    LANGUAGE_MAP short code when the full name is not in the vocabulary.
    Tokens not found in `vocab` are left untouched.
    """
    result = list(prompt_ids)
    if omit_language_token:
        lang = language.lower()
        key = f"<|{lang}|>"
        if key not in vocab:
            short = LANGUAGE_MAP.get(lang)
            if short:
                key = f"<|{short}|>"
        lang_id = vocab.get(key)
        if lang_id is not None:
            result = [tok for tok in result if tok != lang_id]
    if omit_notimestamps_token:
        nt_id = vocab.get("<|notimestamps|>")
        if nt_id is not None:
            result = [tok for tok in result if tok != nt_id]
    return result
|
|
|
|
|
|
|
|
def log_softmax(x: np.ndarray) -> np.ndarray:
    """Numerically stable log-softmax over the last axis (max-shift trick)."""
    shifted = x - np.max(x, axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
|
|
|
|
|
|
|
|
def sample_next(logits_slice: np.ndarray, temperature: float) -> int:
    """Pick the next token: greedy argmax at temperature 0, else softmax sampling."""
    if temperature == 0.0:
        return int(np.argmax(logits_slice))
    scaled = logits_slice / float(temperature)
    probs = np.exp(log_softmax(scaled))
    # Renormalize to guard against floating-point drift before np.random.choice.
    probs = probs / probs.sum()
    return int(np.random.choice(len(probs), p=probs))
|
|
|
|
|
|
|
|
def main() -> None:
    """CLI entry point: ONNX Whisper-style decode using Rust FFI helpers.

    Pipeline: parse args -> resolve model/tokenizer/config locations ->
    download any missing tokenizer/config files -> load the three ONNX
    sessions (encoder, decoder, decoder-with-past) -> obtain input features
    (explicit .npy, baseline artifact, or FFI-preprocessed wav) -> build the
    decoder prompt -> greedy/sampled or beam-search decode -> optional
    no-speech gating -> detokenize via FFI and save artifacts + meta.json.
    """
    parser = argparse.ArgumentParser(
        description="ONNX decode with Rust FFI preprocessing/prompt/decode helpers"
    )
    parser.add_argument("--dll-path", help="Path to liteasr_ffi.dll")
    parser.add_argument("--model-id", help="Hugging Face model ID")
    parser.add_argument("--model-name", help="Local model name under models/")
    parser.add_argument("--onnx-dir", help="Directory containing ONNX files")
    parser.add_argument("--tokenizer-dir", help="Directory containing tokenizer files")
    parser.add_argument(
        "--tokenizer-id",
        help="Tokenizer/model ID to fetch tokenizer/config files (defaults to model-id).",
    )
    parser.add_argument(
        "--config-id",
        help="Config/model ID to fetch config.json (defaults to model-id).",
    )
    parser.add_argument("--audio", help="Path to input wav file")
    parser.add_argument("--input-features", help="Path to input_features.npy")
    parser.add_argument("--baseline-features", action="store_true")
    parser.add_argument("--baseline-meta", action="store_true")
    parser.add_argument("--run-name", default="a01")
    parser.add_argument("--out-root", default="artifacts/onnx_py_ffi")
    parser.add_argument("--out-dir", help="Output directory")
    parser.add_argument("--language", default="japanese")
    parser.add_argument("--task", default="transcribe")
    parser.add_argument("--with-timestamps", action="store_true")
    parser.add_argument("--omit-language-token", action="store_true")
    parser.add_argument("--omit-notimestamps-token", action="store_true")
    parser.add_argument("--lite-whisper-prompt", action="store_true")
    parser.add_argument("--max-new-tokens", type=int, default=128)
    parser.add_argument("--num-beams", type=int)
    parser.add_argument("--temperature", type=float)
    parser.add_argument("--do-sample", action="store_true")
    parser.add_argument("--seed", type=int)
    parser.add_argument("--no-speech-threshold", type=float)
    parser.add_argument("--logprob-threshold", type=float)
    parser.add_argument("--dump-step0", action="store_true")
    args = parser.parse_args()

    # FFI wrapper exposes the Rust implementations of wav preprocessing,
    # prompt building, logit suppression and token decoding.
    ffi = LiteAsrFfi(args.dll_path)
    model_id = resolve_model_id(args.model_id, args.model_name)
    model_name = args.model_name or model_name_from_id(model_id)
    # Default layout: models/<model_name>/{onnx,tokenizer}
    onnx_dir = Path(args.onnx_dir) if args.onnx_dir else Path("models") / model_name / "onnx"
    tokenizer_dir = (
        Path(args.tokenizer_dir)
        if args.tokenizer_dir
        else Path("models") / model_name / "tokenizer"
    )

    # efficient-speech (lite-whisper) repos don't ship tokenizer files, so
    # fall back to the matching upstream OpenAI Whisper tokenizer repo.
    tokenizer_id = args.tokenizer_id or model_id
    if model_id.startswith("efficient-speech/") and args.tokenizer_id is None:
        tokenizer_id = (
            "openai/whisper-large-v3-turbo"
            if "turbo" in model_id
            else "openai/whisper-large-v3"
        )
    config_id = args.config_id or model_id

    # tokenizer.json and config.json are mandatory; the other two are
    # best-effort (defaults are applied below when absent).
    download_file(tokenizer_id, tokenizer_dir, "tokenizer.json", required=True)
    download_file(config_id, tokenizer_dir, "config.json", required=True)
    download_file(tokenizer_id, tokenizer_dir, "preprocessor_config.json", required=False)
    download_file(config_id, tokenizer_dir, "generation_config.json", required=False)

    # The three exported ONNX graphs must all be present locally.
    encoder_path = onnx_dir / "encoder_model.onnx"
    decoder_path = onnx_dir / "decoder_model.onnx"
    decoder_past_path = onnx_dir / "decoder_with_past_model.onnx"
    for p in (encoder_path, decoder_path, decoder_past_path):
        if not p.exists():
            raise SystemExit(f"Missing ONNX file: {p}")

    config = load_json(tokenizer_dir / "config.json") or {}
    preprocessor = load_json(tokenizer_dir / "preprocessor_config.json") or {}
    generation = load_json(tokenizer_dir / "generation_config.json") or {}
    tokenizer_json_path = tokenizer_dir / "tokenizer.json"
    vocab = load_vocab(tokenizer_json_path)

    decoder_start_token_id = config.get("decoder_start_token_id")
    eos_token_id = config.get("eos_token_id")
    if decoder_start_token_id is None:
        raise SystemExit("decoder_start_token_id not found in config.json")

    # Defaults follow Whisper conventions (16 kHz audio, 80 mel bins).
    sampling_rate = int(preprocessor.get("sampling_rate", 16000))
    n_mels = int(preprocessor.get("feature_size") or preprocessor.get("n_mels") or 80)
    suppress_tokens = normalize_token_list(generation.get("suppress_tokens"))
    begin_suppress_tokens = normalize_token_list(generation.get("begin_suppress_tokens"))
    num_beams = int(generation.get("num_beams", 1) or 1)
    temperature = float(generation.get("temperature", 1.0) or 1.0)
    do_sample = bool(generation.get("do_sample", False))
    no_speech_threshold = generation.get("no_speech_threshold")
    logprob_threshold = generation.get("logprob_threshold")

    # Optionally replay generation settings recorded by a baseline run so
    # results are comparable; explicit CLI flags (below) still win.
    baseline_meta = None
    if args.baseline_meta:
        meta_path = Path("artifacts") / "baseline" / model_name / args.run_name / "meta.json"
        if meta_path.exists():
            baseline_meta = load_json(meta_path)
        gen = (baseline_meta or {}).get("generation_config") or {}
        if gen.get("num_beams") is not None:
            num_beams = int(gen["num_beams"])
        if gen.get("temperature") is not None:
            temperature = float(gen["temperature"])
        if gen.get("do_sample") is not None:
            do_sample = bool(gen["do_sample"])
        if gen.get("suppress_tokens") is not None:
            suppress_tokens = normalize_token_list(gen.get("suppress_tokens"))
        if gen.get("begin_suppress_tokens") is not None:
            begin_suppress_tokens = normalize_token_list(gen.get("begin_suppress_tokens"))
        if gen.get("no_speech_threshold") is not None:
            no_speech_threshold = gen.get("no_speech_threshold")
        if gen.get("logprob_threshold") is not None:
            logprob_threshold = gen.get("logprob_threshold")

    # CLI overrides take final precedence over config/baseline values.
    if args.num_beams is not None:
        num_beams = int(args.num_beams)
    if args.temperature is not None:
        temperature = float(args.temperature)
    if args.do_sample:
        do_sample = True
    if args.no_speech_threshold is not None:
        no_speech_threshold = float(args.no_speech_threshold)
    if args.logprob_threshold is not None:
        logprob_threshold = float(args.logprob_threshold)
    if args.seed is not None:
        # Seeds the legacy global RNG used by sample_next.
        np.random.seed(args.seed)

    # CPU-only inference sessions for all three graphs.
    encoder_sess = ort.InferenceSession(str(encoder_path), providers=["CPUExecutionProvider"])
    decoder_sess = ort.InferenceSession(str(decoder_path), providers=["CPUExecutionProvider"])
    decoder_past_sess = ort.InferenceSession(
        str(decoder_past_path), providers=["CPUExecutionProvider"]
    )

    # Input features come from (in priority order): an explicit .npy file,
    # the baseline run's saved features, or FFI mel preprocessing of a wav.
    if args.input_features:
        input_features = np.load(Path(args.input_features)).astype(np.float32)
    elif args.baseline_features:
        baseline_path = (
            Path("artifacts") / "baseline" / model_name / args.run_name / "input_features.npy"
        )
        if not baseline_path.exists():
            raise SystemExit(f"baseline input_features not found: {baseline_path}")
        input_features = np.load(baseline_path).astype(np.float32)
    else:
        audio_path = resolve_audio_path(args.audio)
        log_mel = ffi.preprocess_wav(audio_path, sampling_rate, n_mels)
        # Add the leading batch dimension expected by the encoder.
        input_features = log_mel[np.newaxis, :, :].astype(np.float32)

    # All three sources must yield a rank-3 (batch, mels, frames) array.
    # NOTE(review): axis order is assumed from the Whisper export — confirm.
    if input_features.ndim != 3:
        raise SystemExit(f"input_features has unexpected shape: {input_features.shape}")

    # Run the encoder once; the hidden states are reused for every decode step.
    enc_out_names = [o.name for o in encoder_sess.get_outputs()]
    enc_out = encoder_sess.run(
        enc_out_names,
        {encoder_sess.get_inputs()[0].name: input_features},
    )
    enc_map = dict(zip(enc_out_names, enc_out))
    # Prefer the named output; fall back to the first output otherwise.
    encoder_hidden = enc_map.get("encoder_hidden_states", enc_out[0])

    # --lite-whisper-prompt applies model-specific prompt tweaks: "turbo-acc"
    # variants drop <|notimestamps|>, plain "turbo" variants drop the
    # language token (checked in that order since "turbo" matches both).
    omit_language_token = args.omit_language_token
    omit_notimestamps_token = args.omit_notimestamps_token
    if args.lite_whisper_prompt:
        if "turbo-acc" in model_id:
            omit_notimestamps_token = True
        elif "turbo" in model_id:
            omit_language_token = True

    # Prompt IDs: prefer the baseline run's forced_decoder_ids (minus any
    # omitted optional tokens); otherwise build the prompt via the FFI.
    prompt_ids = None
    if baseline_meta:
        prompt_ids = parse_forced_decoder_ids(
            baseline_meta.get("forced_decoder_ids"),
            int(decoder_start_token_id),
        )
    if prompt_ids is None:
        prompt_ids = ffi.build_prompt_ids(
            tokenizer_json_path,
            args.language,
            args.task,
            args.with_timestamps,
            omit_language_token,
            omit_notimestamps_token,
        )
    else:
        prompt_ids = remove_optional_prompt_tokens(
            prompt_ids,
            vocab,
            args.language,
            omit_language_token,
            omit_notimestamps_token,
        )

    dec_out_names = [o.name for o in decoder_sess.get_outputs()]
    dec_past_out_names = [o.name for o in decoder_past_sess.get_outputs()]
    dec_past_inputs = decoder_past_sess.get_inputs()
    # Some exports also feed encoder_hidden_states to the with-past decoder.
    needs_encoder_hidden = any(inp.name == "encoder_hidden_states" for inp in dec_past_inputs)

    # Step 0: run the full decoder over the whole prompt to get the first
    # logits and the initial KV cache ("present.*" outputs).
    input_ids = np.array([prompt_ids], dtype=np.int64)
    outputs = decoder_sess.run(
        dec_out_names,
        {
            decoder_sess.get_inputs()[0].name: input_ids,
            decoder_sess.get_inputs()[1].name: encoder_hidden,
        },
    )
    dec_map = dict(zip(dec_out_names, outputs))
    logits = dec_map.get("logits", outputs[0])
    # Keep an unmodified copy for --dump-step0 before suppression is applied.
    raw_step0_logits = logits.copy()

    # Suppress banned tokens (and begin-suppress tokens at step 0) on the
    # last-position logits via the Rust implementation.
    step0_slice = ffi.apply_suppression(
        np.ascontiguousarray(logits[0, -1], dtype=np.float32),
        suppress_tokens,
        begin_suppress_tokens,
        step=0,
    )
    step0_present = {k: v for k, v in dec_map.items() if k.startswith("present.")}
    present = dict(step0_present)
    token_ids = list(prompt_ids)

    # Track no-speech probability and generated-token log-prob stats for the
    # gating logic after decoding.
    nospeech_id = vocab.get("<|nospeech|>")
    no_speech_prob = None
    sum_logprob = 0.0
    gen_count = 0

    def greedy_or_sample(logits_slice: np.ndarray) -> int:
        # Sampling path when requested or when temperature reshapes the
        # distribution; plain argmax otherwise.
        if do_sample or temperature != 1.0:
            return sample_next(logits_slice, temperature)
        return int(np.argmax(logits_slice))

    if num_beams <= 1:
        # ---- Greedy / sampled decoding ----
        log_probs_first = log_softmax(step0_slice)
        if nospeech_id is not None and 0 <= nospeech_id < log_probs_first.shape[0]:
            # no-speech probability is read from the first decode step.
            no_speech_prob = float(np.exp(log_probs_first[nospeech_id]))
        next_token = greedy_or_sample(step0_slice)
        sum_logprob += float(log_probs_first[next_token])
        gen_count += 1
        token_ids.append(next_token)

        # One token was already emitted at step 0, hence max_new_tokens - 1.
        step = 1
        for _ in range(args.max_new_tokens - 1):
            if eos_token_id is not None and next_token == int(eos_token_id):
                break
            past_inputs = {"input_ids": np.array([[next_token]], dtype=np.int64)}
            if needs_encoder_hidden:
                past_inputs["encoder_hidden_states"] = encoder_hidden
            # Feed the cached KV: each remaining input "past.X" takes the
            # previous step's "present.X" output.
            # NOTE(review): dec_past_inputs[1:] assumes input 0 is input_ids
            # and every later input is a past-KV tensor; if the export lists
            # encoder_hidden_states among these, the present.* lookup would
            # miss — confirm against the model's actual input order.
            for input_meta in dec_past_inputs[1:]:
                name = input_meta.name
                present_name = name.replace("past.", "present.")
                past_inputs[name] = present[present_name]
            outputs = decoder_past_sess.run(dec_past_out_names, past_inputs)
            out_map = dict(zip(dec_past_out_names, outputs))
            logits = out_map.get("logits", outputs[0])
            logits_slice = ffi.apply_suppression(
                np.ascontiguousarray(logits[0, -1], dtype=np.float32),
                suppress_tokens,
                begin_suppress_tokens,
                step=step,
            )
            present = {k: v for k, v in out_map.items() if k.startswith("present.")}
            log_probs = log_softmax(logits_slice)
            next_token = greedy_or_sample(logits_slice)
            sum_logprob += float(log_probs[next_token])
            gen_count += 1
            token_ids.append(next_token)
            step += 1
    else:
        # ---- Beam search ----
        # Scores use temperature-scaled log-probs (temperature <= 0 is
        # treated as 1.0 to avoid division issues).
        first_log_probs = log_softmax(step0_slice / float(temperature if temperature > 0 else 1.0))
        if nospeech_id is not None and 0 <= nospeech_id < first_log_probs.shape[0]:
            no_speech_prob = float(np.exp(first_log_probs[nospeech_id]))
        # Seed beams with the top-k first tokens; all share the step-0 cache.
        top_indices = np.argsort(first_log_probs)[::-1][:num_beams]
        beams = []
        for idx in top_indices:
            beams.append(
                {
                    "tokens": token_ids + [int(idx)],
                    "score": float(first_log_probs[idx]),
                    "present": present,
                }
            )

        for step in range(1, args.max_new_tokens):
            candidates = []
            for beam in beams:
                # Finished beams are carried over unchanged.
                if eos_token_id is not None and beam["tokens"][-1] == int(eos_token_id):
                    candidates.append(beam)
                    continue
                past_inputs = {"input_ids": np.array([[beam["tokens"][-1]]], dtype=np.int64)}
                if needs_encoder_hidden:
                    past_inputs["encoder_hidden_states"] = encoder_hidden
                # Same past->present wiring (and ordering assumption) as in
                # the greedy loop above.
                for input_meta in dec_past_inputs[1:]:
                    name = input_meta.name
                    present_name = name.replace("past.", "present.")
                    past_inputs[name] = beam["present"][present_name]
                outputs = decoder_past_sess.run(dec_past_out_names, past_inputs)
                out_map = dict(zip(dec_past_out_names, outputs))
                logits = out_map.get("logits", outputs[0])
                logits_slice = ffi.apply_suppression(
                    np.ascontiguousarray(logits[0, -1], dtype=np.float32),
                    suppress_tokens,
                    begin_suppress_tokens,
                    step=step,
                )
                present_next = {k: v for k, v in out_map.items() if k.startswith("present.")}
                log_probs = log_softmax(logits_slice / float(temperature if temperature > 0 else 1.0))
                # Expand each live beam by its top-k continuations.
                step_top = np.argsort(log_probs)[::-1][:num_beams]
                for idx in step_top:
                    candidates.append(
                        {
                            "tokens": beam["tokens"] + [int(idx)],
                            "score": beam["score"] + float(log_probs[idx]),
                            "present": present_next,
                        }
                    )
            # Keep the k best candidates; stop when every beam hit EOS.
            # NOTE(review): scores are raw sums (no length normalization),
            # which favors shorter hypotheses — presumably intentional.
            candidates.sort(key=lambda x: x["score"], reverse=True)
            beams = candidates[:num_beams]
            if all(
                eos_token_id is not None and beam["tokens"][-1] == int(eos_token_id)
                for beam in beams
            ):
                break
        best = max(beams, key=lambda x: x["score"])
        token_ids = list(best["tokens"])
        if len(token_ids) > len(prompt_ids):
            sum_logprob = float(best["score"])
            gen_count = len(token_ids) - len(prompt_ids)

    # No-speech gating: drop the generated tokens (keep only the prompt) when
    # the no-speech probability is high and, if a logprob threshold is set,
    # the average log-prob is also below it.
    # NOTE(review): this resembles but may not exactly match openai-whisper's
    # fallback logic — confirm intended semantics.
    avg_logprob = sum_logprob / gen_count if gen_count > 0 else None
    if (
        no_speech_threshold is not None
        and no_speech_prob is not None
        and no_speech_prob > float(no_speech_threshold)
        and (
            logprob_threshold is None
            or (avg_logprob is not None and avg_logprob < float(logprob_threshold))
        )
    ):
        token_ids = list(prompt_ids)

    # Detokenize via the Rust FFI, dropping special tokens.
    transcript = ffi.decode_tokens(
        tokenizer_json_path=tokenizer_json_path,
        token_ids=token_ids,
        skip_special_tokens=True,
    ).strip()

    if args.out_dir:
        out_dir = Path(args.out_dir)
    else:
        out_dir = Path(args.out_root) / model_name / args.run_name
    out_dir.mkdir(parents=True, exist_ok=True)

    # Optional debug dump of pre-suppression step-0 tensors for comparison
    # against other implementations.
    if args.dump_step0:
        step0_dir = out_dir / "step0"
        step0_dir.mkdir(parents=True, exist_ok=True)
        np.save(step0_dir / "decoder_logits_step0.npy", raw_step0_logits)
        np.savez(step0_dir / "present_step0.npz", **step0_present)
        np.save(step0_dir / "encoder_hidden.npy", encoder_hidden)

    np.save(out_dir / "input_features.npy", input_features)
    np.save(out_dir / "token_ids.npy", np.array([token_ids], dtype=np.int64))
    (out_dir / "transcript.txt").write_text(transcript + "\n", encoding="utf-8")

    # Record the effective settings and run metadata alongside the outputs.
    meta = {
        "model_id": model_id,
        "model_name": model_name,
        "dll_path": str(ffi.dll_path),
        "audio": args.audio,
        "sampling_rate": sampling_rate,
        "n_mels": int(n_mels),
        "input_features_shape": list(input_features.shape),
        "prompt_ids": prompt_ids,
        "decoder_start_token_id": int(decoder_start_token_id),
        "eos_token_id": eos_token_id,
        "max_new_tokens": args.max_new_tokens,
        "num_beams": num_beams,
        "temperature": temperature,
        "do_sample": do_sample,
        "no_speech_threshold": no_speech_threshold,
        "logprob_threshold": logprob_threshold,
        "no_speech_prob": no_speech_prob,
        "avg_logprob": avg_logprob,
        "omit_language_token": omit_language_token,
        "omit_notimestamps_token": omit_notimestamps_token,
        "lite_whisper_prompt": args.lite_whisper_prompt,
        "baseline_meta_applied": bool(baseline_meta),
        "token_ids_shape": [1, len(token_ids)],
        "created_at": datetime.now(timezone.utc).isoformat(),
    }
    save_json(out_dir / "meta.json", meta)

    # Re-encode with replacement so printing never fails on a cp932 console
    # (Windows) that cannot represent every transcript character.
    safe_transcript = transcript.encode("cp932", errors="replace").decode("cp932")
    print(f"Saved artifacts to: {out_dir}")
    print(f"Transcript: {safe_transcript}")
|
|
|
|
|
|
|
|
# Script entry point.
if __name__ == "__main__":
    main()
|
|
|