# LiteASR-ONNX-DLL / ffi_python/onnx_transcribe_ffi.py
# Uploaded via huggingface_hub by zukky (commit 27a58dc, verified).
import argparse
import json
from datetime import datetime, timezone
from pathlib import Path
import numpy as np
import onnxruntime as ort
from liteasr_ffi import LiteAsrFfi
try:
from huggingface_hub import hf_hub_download
except Exception: # pragma: no cover
hf_hub_download = None
# Full language names (as accepted on the CLI via --language) mapped to the
# two-letter codes used inside Whisper special tokens such as <|ja|>.
LANGUAGE_MAP: dict[str, str] = {
    "japanese": "ja",
    "english": "en",
    "chinese": "zh",
    "korean": "ko",
    "french": "fr",
    "german": "de",
    "spanish": "es",
    "italian": "it",
    "portuguese": "pt",
    "russian": "ru",
}
def load_json(path: Path) -> dict | None:
if not path.exists():
return None
return json.loads(path.read_text(encoding="utf-8"))
def save_json(path: Path, payload: dict) -> None:
    """Write *payload* as indented, ASCII-escaped JSON with a trailing newline."""
    serialized = json.dumps(payload, indent=2, ensure_ascii=True)
    path.write_text(serialized + "\n", encoding="utf-8")
def resolve_model_id(model_id: str | None, model_name: str | None) -> str:
if model_id:
return model_id
if model_name:
return model_name.replace("__", "/")
model_file = Path(".model_id")
if model_file.exists():
value = model_file.read_text(encoding="utf-8").strip()
if value:
return value
raise SystemExit("Model ID is required. Pass --model-id or --model-name.")
def model_name_from_id(model_id: str) -> str:
    """Flatten a Hugging Face repo ID into a filesystem-safe directory name."""
    return "__".join(model_id.split("/"))
def resolve_audio_path(path_value: str | None) -> Path:
if path_value:
return Path(path_value)
fallback = Path("samples") / "a01.wav"
if fallback.exists():
return fallback
raise SystemExit("Audio path is required. Pass --audio or add samples/a01.wav.")
def download_file(repo_id: str, target_dir: Path, filename: str, required: bool) -> None:
    """Fetch *filename* from the Hub into *target_dir* unless already present.

    When huggingface_hub is unavailable or the download fails, this exits
    (or re-raises) only if *required* is True; optional files fail silently.
    """
    target_dir.mkdir(parents=True, exist_ok=True)
    destination = target_dir / filename
    if destination.exists():
        return
    if hf_hub_download is None:
        if not required:
            return
        raise SystemExit(
            "huggingface_hub not installed. Install it or place tokenizer/config files manually."
        )
    try:
        hf_hub_download(repo_id=repo_id, filename=filename, local_dir=target_dir)
    except Exception:
        if required:
            raise
def normalize_token_list(value) -> list[int]:
    """Coerce a generation-config token field into a flat list of ints.

    Accepts None, a single int (negative means "unset" and yields an empty
    list), or a list of numbers; non-numeric list entries are skipped and
    any other type yields an empty list.
    """
    if isinstance(value, int):
        return [int(value)] if value >= 0 else []
    if isinstance(value, list):
        return [int(entry) for entry in value if isinstance(entry, (int, float))]
    return []
def parse_forced_decoder_ids(value, decoder_start_token_id: int | None) -> list[int] | None:
if not value:
return None
try:
ordered = sorted(value, key=lambda x: x[0])
tokens = [int(item[1]) for item in ordered]
if decoder_start_token_id is not None:
if not tokens or tokens[0] != int(decoder_start_token_id):
tokens = [int(decoder_start_token_id)] + tokens
return tokens
except Exception:
return None
def load_vocab(tokenizer_json_path: Path) -> dict[str, int]:
    """Build a token -> id map from tokenizer.json (base vocab plus added tokens).

    Added tokens override base-vocab entries with the same content. Exits
    when the file is missing or has no `model.vocab` mapping.
    """
    if not tokenizer_json_path.exists():
        raise SystemExit(f"tokenizer.json not found: {tokenizer_json_path}")
    payload = json.loads(tokenizer_json_path.read_text(encoding="utf-8"))
    raw_vocab = payload.get("model", {}).get("vocab")
    if not isinstance(raw_vocab, dict):
        raise SystemExit("tokenizer.json missing model.vocab")
    vocab: dict[str, int] = {
        token: idx for token, idx in raw_vocab.items() if isinstance(idx, int)
    }
    for entry in payload.get("added_tokens", []) or []:
        if not isinstance(entry, dict):
            continue
        content = entry.get("content")
        idx = entry.get("id")
        if isinstance(content, str) and isinstance(idx, int):
            vocab[content] = idx
    return vocab
def remove_optional_prompt_tokens(
    prompt_ids: list[int],
    vocab: dict[str, int],
    language: str,
    omit_language_token: bool,
    omit_notimestamps_token: bool,
) -> list[int]:
    """Strip the language and/or <|notimestamps|> special tokens from a prompt.

    The language token is looked up as <|language|> first, then via the
    LANGUAGE_MAP short code. Tokens absent from *vocab* are left untouched.
    """
    drop: set[int] = set()
    if omit_language_token:
        key = f"<|{language.lower()}|>"
        if key not in vocab:
            short = LANGUAGE_MAP.get(language.lower())
            if short:
                key = f"<|{short}|>"
        if key in vocab:
            drop.add(vocab[key])
    if omit_notimestamps_token and "<|notimestamps|>" in vocab:
        drop.add(vocab["<|notimestamps|>"])
    return [tid for tid in prompt_ids if tid not in drop]
def log_softmax(x: np.ndarray) -> np.ndarray:
    """Numerically stable log-softmax over the last axis."""
    shifted = x - np.max(x, axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
def sample_next(logits_slice: np.ndarray, temperature: float) -> int:
    """Select the next token id from a 1-D slice of unnormalized logits.

    Greedy argmax when *temperature* is non-positive; otherwise samples from
    the temperature-scaled softmax distribution via np.random (seedable with
    np.random.seed for reproducibility).

    Fix: the original guarded only `temperature == 0.0`, so a negative
    temperature divided the logits by a negative number and inverted the
    distribution (sampling the LEAST likely tokens). Non-positive values now
    fall back to deterministic argmax; positive temperatures are unchanged.

    Args:
        logits_slice: 1-D array of vocabulary logits.
        temperature: softmax temperature; values <= 0 mean greedy decoding.

    Returns:
        The selected token id as a Python int.
    """
    if temperature <= 0.0:
        return int(np.argmax(logits_slice))
    scaled = logits_slice / float(temperature)
    # Stable softmax: shift by the max before exponentiating, then normalize
    # so np.random.choice receives probabilities that sum to exactly 1.
    shifted = scaled - np.max(scaled)
    probs = np.exp(shifted)
    probs = probs / np.sum(probs)
    return int(np.random.choice(len(probs), p=probs))
def main() -> None:
    """CLI entry point: preprocess audio via the Rust FFI helpers, run the
    ONNX encoder + decoder pair (greedy or beam search), and write the
    transcript plus debug artifacts to the output directory.
    """
    parser = argparse.ArgumentParser(
        description="ONNX decode with Rust FFI preprocessing/prompt/decode helpers"
    )
    parser.add_argument("--dll-path", help="Path to liteasr_ffi.dll")
    parser.add_argument("--model-id", help="Hugging Face model ID")
    parser.add_argument("--model-name", help="Local model name under models/")
    parser.add_argument("--onnx-dir", help="Directory containing ONNX files")
    parser.add_argument("--tokenizer-dir", help="Directory containing tokenizer files")
    parser.add_argument(
        "--tokenizer-id",
        help="Tokenizer/model ID to fetch tokenizer/config files (defaults to model-id).",
    )
    parser.add_argument(
        "--config-id",
        help="Config/model ID to fetch config.json (defaults to model-id).",
    )
    parser.add_argument("--audio", help="Path to input wav file")
    parser.add_argument("--input-features", help="Path to input_features.npy")
    parser.add_argument("--baseline-features", action="store_true")
    parser.add_argument("--baseline-meta", action="store_true")
    parser.add_argument("--run-name", default="a01")
    parser.add_argument("--out-root", default="artifacts/onnx_py_ffi")
    parser.add_argument("--out-dir", help="Output directory")
    parser.add_argument("--language", default="japanese")
    parser.add_argument("--task", default="transcribe")
    parser.add_argument("--with-timestamps", action="store_true")
    parser.add_argument("--omit-language-token", action="store_true")
    parser.add_argument("--omit-notimestamps-token", action="store_true")
    parser.add_argument("--lite-whisper-prompt", action="store_true")
    parser.add_argument("--max-new-tokens", type=int, default=128)
    parser.add_argument("--num-beams", type=int)
    parser.add_argument("--temperature", type=float)
    parser.add_argument("--do-sample", action="store_true")
    parser.add_argument("--seed", type=int)
    parser.add_argument("--no-speech-threshold", type=float)
    parser.add_argument("--logprob-threshold", type=float)
    parser.add_argument("--dump-step0", action="store_true")
    args = parser.parse_args()
    ffi = LiteAsrFfi(args.dll_path)
    # --- Resolve model / tokenizer locations -------------------------------
    model_id = resolve_model_id(args.model_id, args.model_name)
    model_name = args.model_name or model_name_from_id(model_id)
    onnx_dir = Path(args.onnx_dir) if args.onnx_dir else Path("models") / model_name / "onnx"
    tokenizer_dir = (
        Path(args.tokenizer_dir)
        if args.tokenizer_dir
        else Path("models") / model_name / "tokenizer"
    )
    tokenizer_id = args.tokenizer_id or model_id
    # efficient-speech (LiteWhisper) repos borrow the upstream OpenAI Whisper
    # tokenizer unless a tokenizer ID was given explicitly.
    if model_id.startswith("efficient-speech/") and args.tokenizer_id is None:
        tokenizer_id = (
            "openai/whisper-large-v3-turbo"
            if "turbo" in model_id
            else "openai/whisper-large-v3"
        )
    config_id = args.config_id or model_id
    # Tokenizer and model config are mandatory; preprocessor/generation
    # configs are best-effort (defaults are applied below when absent).
    download_file(tokenizer_id, tokenizer_dir, "tokenizer.json", required=True)
    download_file(config_id, tokenizer_dir, "config.json", required=True)
    download_file(tokenizer_id, tokenizer_dir, "preprocessor_config.json", required=False)
    download_file(config_id, tokenizer_dir, "generation_config.json", required=False)
    # --- Verify the exported three-session ONNX layout ---------------------
    encoder_path = onnx_dir / "encoder_model.onnx"
    decoder_path = onnx_dir / "decoder_model.onnx"
    decoder_past_path = onnx_dir / "decoder_with_past_model.onnx"
    for p in (encoder_path, decoder_path, decoder_past_path):
        if not p.exists():
            raise SystemExit(f"Missing ONNX file: {p}")
    # --- Load configs and derive decoding parameters -----------------------
    config = load_json(tokenizer_dir / "config.json") or {}
    preprocessor = load_json(tokenizer_dir / "preprocessor_config.json") or {}
    generation = load_json(tokenizer_dir / "generation_config.json") or {}
    tokenizer_json_path = tokenizer_dir / "tokenizer.json"
    vocab = load_vocab(tokenizer_json_path)
    decoder_start_token_id = config.get("decoder_start_token_id")
    eos_token_id = config.get("eos_token_id")
    if decoder_start_token_id is None:
        raise SystemExit("decoder_start_token_id not found in config.json")
    sampling_rate = int(preprocessor.get("sampling_rate", 16000))
    n_mels = int(preprocessor.get("feature_size") or preprocessor.get("n_mels") or 80)
    suppress_tokens = normalize_token_list(generation.get("suppress_tokens"))
    begin_suppress_tokens = normalize_token_list(generation.get("begin_suppress_tokens"))
    num_beams = int(generation.get("num_beams", 1) or 1)
    temperature = float(generation.get("temperature", 1.0) or 1.0)
    do_sample = bool(generation.get("do_sample", False))
    no_speech_threshold = generation.get("no_speech_threshold")
    logprob_threshold = generation.get("logprob_threshold")
    # Precedence for decoding params: generation_config.json < baseline
    # run meta (when --baseline-meta) < explicit CLI flags.
    baseline_meta = None
    if args.baseline_meta:
        meta_path = Path("artifacts") / "baseline" / model_name / args.run_name / "meta.json"
        if meta_path.exists():
            baseline_meta = load_json(meta_path)
        gen = (baseline_meta or {}).get("generation_config") or {}
        if gen.get("num_beams") is not None:
            num_beams = int(gen["num_beams"])
        if gen.get("temperature") is not None:
            temperature = float(gen["temperature"])
        if gen.get("do_sample") is not None:
            do_sample = bool(gen["do_sample"])
        if gen.get("suppress_tokens") is not None:
            suppress_tokens = normalize_token_list(gen.get("suppress_tokens"))
        if gen.get("begin_suppress_tokens") is not None:
            begin_suppress_tokens = normalize_token_list(gen.get("begin_suppress_tokens"))
        if gen.get("no_speech_threshold") is not None:
            no_speech_threshold = gen.get("no_speech_threshold")
        if gen.get("logprob_threshold") is not None:
            logprob_threshold = gen.get("logprob_threshold")
    if args.num_beams is not None:
        num_beams = int(args.num_beams)
    if args.temperature is not None:
        temperature = float(args.temperature)
    if args.do_sample:
        do_sample = True
    if args.no_speech_threshold is not None:
        no_speech_threshold = float(args.no_speech_threshold)
    if args.logprob_threshold is not None:
        logprob_threshold = float(args.logprob_threshold)
    if args.seed is not None:
        np.random.seed(args.seed)
    # --- Create the three ONNX Runtime sessions (CPU only) -----------------
    encoder_sess = ort.InferenceSession(str(encoder_path), providers=["CPUExecutionProvider"])
    decoder_sess = ort.InferenceSession(str(decoder_path), providers=["CPUExecutionProvider"])
    decoder_past_sess = ort.InferenceSession(
        str(decoder_past_path), providers=["CPUExecutionProvider"]
    )
    # --- Obtain input features: explicit .npy, baseline artifact, or FFI
    # log-mel preprocessing of the wav file.
    if args.input_features:
        input_features = np.load(Path(args.input_features)).astype(np.float32)
    elif args.baseline_features:
        baseline_path = (
            Path("artifacts") / "baseline" / model_name / args.run_name / "input_features.npy"
        )
        if not baseline_path.exists():
            raise SystemExit(f"baseline input_features not found: {baseline_path}")
        input_features = np.load(baseline_path).astype(np.float32)
    else:
        audio_path = resolve_audio_path(args.audio)
        log_mel = ffi.preprocess_wav(audio_path, sampling_rate, n_mels)
        # Add the batch dimension expected by the encoder.
        input_features = log_mel[np.newaxis, :, :].astype(np.float32)
    if input_features.ndim != 3:
        raise SystemExit(f"input_features has unexpected shape: {input_features.shape}")
    # --- Encoder pass ------------------------------------------------------
    enc_out_names = [o.name for o in encoder_sess.get_outputs()]
    enc_out = encoder_sess.run(
        enc_out_names,
        {encoder_sess.get_inputs()[0].name: input_features},
    )
    enc_map = dict(zip(enc_out_names, enc_out))
    # Fall back to the first output when the name differs across exports.
    encoder_hidden = enc_map.get("encoder_hidden_states", enc_out[0])
    # --- Build decoder prompt ids ------------------------------------------
    omit_language_token = args.omit_language_token
    omit_notimestamps_token = args.omit_notimestamps_token
    # LiteWhisper variants use slightly different prompts depending on the
    # checkpoint flavor ("turbo-acc" vs plain "turbo").
    if args.lite_whisper_prompt:
        if "turbo-acc" in model_id:
            omit_notimestamps_token = True
        elif "turbo" in model_id:
            omit_language_token = True
    prompt_ids = None
    if baseline_meta:
        prompt_ids = parse_forced_decoder_ids(
            baseline_meta.get("forced_decoder_ids"),
            int(decoder_start_token_id),
        )
    if prompt_ids is None:
        prompt_ids = ffi.build_prompt_ids(
            tokenizer_json_path,
            args.language,
            args.task,
            args.with_timestamps,
            omit_language_token,
            omit_notimestamps_token,
        )
    else:
        # Baseline prompt was taken verbatim; still honor the omit flags.
        prompt_ids = remove_optional_prompt_tokens(
            prompt_ids,
            vocab,
            args.language,
            omit_language_token,
            omit_notimestamps_token,
        )
    dec_out_names = [o.name for o in decoder_sess.get_outputs()]
    dec_past_out_names = [o.name for o in decoder_past_sess.get_outputs()]
    dec_past_inputs = decoder_past_sess.get_inputs()
    # Some exports take encoder_hidden_states on every step, others bake it
    # into the KV cache; detect which variant this model is.
    needs_encoder_hidden = any(inp.name == "encoder_hidden_states" for inp in dec_past_inputs)
    # --- Step 0: full-prompt pass through the no-cache decoder -------------
    input_ids = np.array([prompt_ids], dtype=np.int64)
    outputs = decoder_sess.run(
        dec_out_names,
        {
            decoder_sess.get_inputs()[0].name: input_ids,
            decoder_sess.get_inputs()[1].name: encoder_hidden,
        },
    )
    dec_map = dict(zip(dec_out_names, outputs))
    logits = dec_map.get("logits", outputs[0])
    # Keep the unsuppressed logits for --dump-step0 debugging.
    raw_step0_logits = logits.copy()
    step0_slice = ffi.apply_suppression(
        np.ascontiguousarray(logits[0, -1], dtype=np.float32),
        suppress_tokens,
        begin_suppress_tokens,
        step=0,
    )
    # KV cache produced by step 0, fed into the with-past decoder below.
    step0_present = {k: v for k, v in dec_map.items() if k.startswith("present.")}
    present = dict(step0_present)
    token_ids = list(prompt_ids)
    nospeech_id = vocab.get("<|nospeech|>")
    no_speech_prob = None
    sum_logprob = 0.0
    gen_count = 0
    def greedy_or_sample(logits_slice: np.ndarray) -> int:
        # Sampling kicks in when explicitly requested or when a non-default
        # temperature is set; otherwise plain argmax.
        if do_sample or temperature != 1.0:
            return sample_next(logits_slice, temperature)
        return int(np.argmax(logits_slice))
    if num_beams <= 1:
        # --- Greedy / sampled decoding -------------------------------------
        log_probs_first = log_softmax(step0_slice)
        if nospeech_id is not None and 0 <= nospeech_id < log_probs_first.shape[0]:
            no_speech_prob = float(np.exp(log_probs_first[nospeech_id]))
        next_token = greedy_or_sample(step0_slice)
        sum_logprob += float(log_probs_first[next_token])
        gen_count += 1
        token_ids.append(next_token)
        step = 1
        for _ in range(args.max_new_tokens - 1):
            if eos_token_id is not None and next_token == int(eos_token_id):
                break
            past_inputs = {"input_ids": np.array([[next_token]], dtype=np.int64)}
            if needs_encoder_hidden:
                past_inputs["encoder_hidden_states"] = encoder_hidden
            # NOTE(review): assumes dec_past_inputs[0] is input_ids and that
            # every remaining past input name maps onto a "present.*" output
            # via a "past." -> "present." rename — confirm against the
            # exported model's actual input/output names.
            for input_meta in dec_past_inputs[1:]:
                name = input_meta.name
                present_name = name.replace("past.", "present.")
                past_inputs[name] = present[present_name]
            outputs = decoder_past_sess.run(dec_past_out_names, past_inputs)
            out_map = dict(zip(dec_past_out_names, outputs))
            logits = out_map.get("logits", outputs[0])
            logits_slice = ffi.apply_suppression(
                np.ascontiguousarray(logits[0, -1], dtype=np.float32),
                suppress_tokens,
                begin_suppress_tokens,
                step=step,
            )
            present = {k: v for k, v in out_map.items() if k.startswith("present.")}
            log_probs = log_softmax(logits_slice)
            next_token = greedy_or_sample(logits_slice)
            sum_logprob += float(log_probs[next_token])
            gen_count += 1
            token_ids.append(next_token)
            step += 1
    else:
        # --- Simple beam search --------------------------------------------
        # Scores are summed log-probs (no length penalty).
        first_log_probs = log_softmax(step0_slice / float(temperature if temperature > 0 else 1.0))
        if nospeech_id is not None and 0 <= nospeech_id < first_log_probs.shape[0]:
            no_speech_prob = float(np.exp(first_log_probs[nospeech_id]))
        top_indices = np.argsort(first_log_probs)[::-1][:num_beams]
        beams = []
        for idx in top_indices:
            beams.append(
                {
                    "tokens": token_ids + [int(idx)],
                    "score": float(first_log_probs[idx]),
                    "present": present,
                }
            )
        for step in range(1, args.max_new_tokens):
            candidates = []
            for beam in beams:
                # Finished beams pass through unchanged.
                if eos_token_id is not None and beam["tokens"][-1] == int(eos_token_id):
                    candidates.append(beam)
                    continue
                past_inputs = {"input_ids": np.array([[beam["tokens"][-1]]], dtype=np.int64)}
                if needs_encoder_hidden:
                    past_inputs["encoder_hidden_states"] = encoder_hidden
                # NOTE(review): same "past." -> "present." rename assumption
                # as in the greedy path above.
                for input_meta in dec_past_inputs[1:]:
                    name = input_meta.name
                    present_name = name.replace("past.", "present.")
                    past_inputs[name] = beam["present"][present_name]
                outputs = decoder_past_sess.run(dec_past_out_names, past_inputs)
                out_map = dict(zip(dec_past_out_names, outputs))
                logits = out_map.get("logits", outputs[0])
                logits_slice = ffi.apply_suppression(
                    np.ascontiguousarray(logits[0, -1], dtype=np.float32),
                    suppress_tokens,
                    begin_suppress_tokens,
                    step=step,
                )
                present_next = {k: v for k, v in out_map.items() if k.startswith("present.")}
                log_probs = log_softmax(logits_slice / float(temperature if temperature > 0 else 1.0))
                step_top = np.argsort(log_probs)[::-1][:num_beams]
                for idx in step_top:
                    candidates.append(
                        {
                            "tokens": beam["tokens"] + [int(idx)],
                            "score": beam["score"] + float(log_probs[idx]),
                            "present": present_next,
                        }
                    )
            candidates.sort(key=lambda x: x["score"], reverse=True)
            beams = candidates[:num_beams]
            # Stop once every surviving beam has emitted EOS.
            if all(
                eos_token_id is not None and beam["tokens"][-1] == int(eos_token_id)
                for beam in beams
            ):
                break
        best = max(beams, key=lambda x: x["score"])
        token_ids = list(best["tokens"])
        if len(token_ids) > len(prompt_ids):
            sum_logprob = float(best["score"])
            gen_count = len(token_ids) - len(prompt_ids)
    avg_logprob = sum_logprob / gen_count if gen_count > 0 else None
    # Whisper-style no-speech gating: discard the generated tokens when the
    # no-speech probability is high and (if set) the average logprob is low.
    if (
        no_speech_threshold is not None
        and no_speech_prob is not None
        and no_speech_prob > float(no_speech_threshold)
        and (
            logprob_threshold is None
            or (avg_logprob is not None and avg_logprob < float(logprob_threshold))
        )
    ):
        token_ids = list(prompt_ids)
    transcript = ffi.decode_tokens(
        tokenizer_json_path=tokenizer_json_path,
        token_ids=token_ids,
        skip_special_tokens=True,
    ).strip()
    # --- Persist artifacts --------------------------------------------------
    if args.out_dir:
        out_dir = Path(args.out_dir)
    else:
        out_dir = Path(args.out_root) / model_name / args.run_name
    out_dir.mkdir(parents=True, exist_ok=True)
    if args.dump_step0:
        step0_dir = out_dir / "step0"
        step0_dir.mkdir(parents=True, exist_ok=True)
        np.save(step0_dir / "decoder_logits_step0.npy", raw_step0_logits)
        np.savez(step0_dir / "present_step0.npz", **step0_present)
        np.save(step0_dir / "encoder_hidden.npy", encoder_hidden)
    np.save(out_dir / "input_features.npy", input_features)
    np.save(out_dir / "token_ids.npy", np.array([token_ids], dtype=np.int64))
    (out_dir / "transcript.txt").write_text(transcript + "\n", encoding="utf-8")
    meta = {
        "model_id": model_id,
        "model_name": model_name,
        "dll_path": str(ffi.dll_path),
        "audio": args.audio,
        "sampling_rate": sampling_rate,
        "n_mels": int(n_mels),
        "input_features_shape": list(input_features.shape),
        "prompt_ids": prompt_ids,
        "decoder_start_token_id": int(decoder_start_token_id),
        "eos_token_id": eos_token_id,
        "max_new_tokens": args.max_new_tokens,
        "num_beams": num_beams,
        "temperature": temperature,
        "do_sample": do_sample,
        "no_speech_threshold": no_speech_threshold,
        "logprob_threshold": logprob_threshold,
        "no_speech_prob": no_speech_prob,
        "avg_logprob": avg_logprob,
        "omit_language_token": omit_language_token,
        "omit_notimestamps_token": omit_notimestamps_token,
        "lite_whisper_prompt": args.lite_whisper_prompt,
        "baseline_meta_applied": bool(baseline_meta),
        "token_ids_shape": [1, len(token_ids)],
        "created_at": datetime.now(timezone.utc).isoformat(),
    }
    save_json(out_dir / "meta.json", meta)
    # Re-encode through cp932 so printing cannot fail on a Japanese Windows
    # console; unprintable characters become replacement marks.
    safe_transcript = transcript.encode("cp932", errors="replace").decode("cp932")
    print(f"Saved artifacts to: {out_dir}")
    print(f"Transcript: {safe_transcript}")
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()