import argparse import json from datetime import datetime, timezone from pathlib import Path import numpy as np import onnxruntime as ort from liteasr_ffi import LiteAsrFfi try: from huggingface_hub import hf_hub_download except Exception: # pragma: no cover hf_hub_download = None LANGUAGE_MAP = { "japanese": "ja", "english": "en", "chinese": "zh", "korean": "ko", "french": "fr", "german": "de", "spanish": "es", "italian": "it", "portuguese": "pt", "russian": "ru", } def load_json(path: Path) -> dict | None: if not path.exists(): return None return json.loads(path.read_text(encoding="utf-8")) def save_json(path: Path, payload: dict) -> None: path.write_text( json.dumps(payload, indent=2, ensure_ascii=True) + "\n", encoding="utf-8", ) def resolve_model_id(model_id: str | None, model_name: str | None) -> str: if model_id: return model_id if model_name: return model_name.replace("__", "/") model_file = Path(".model_id") if model_file.exists(): value = model_file.read_text(encoding="utf-8").strip() if value: return value raise SystemExit("Model ID is required. Pass --model-id or --model-name.") def model_name_from_id(model_id: str) -> str: return model_id.replace("/", "__") def resolve_audio_path(path_value: str | None) -> Path: if path_value: return Path(path_value) fallback = Path("samples") / "a01.wav" if fallback.exists(): return fallback raise SystemExit("Audio path is required. Pass --audio or add samples/a01.wav.") def download_file(repo_id: str, target_dir: Path, filename: str, required: bool) -> None: target_dir.mkdir(parents=True, exist_ok=True) path = target_dir / filename if path.exists(): return if hf_hub_download is None: if required: raise SystemExit( "huggingface_hub not installed. Install it or place tokenizer/config files manually." ) return try: hf_hub_download(repo_id=repo_id, filename=filename, local_dir=target_dir) except Exception: if required: raise def normalize_token_list(value) -> list[int]: if value is None: return [] if isinstance(value, int): return [] if value < 0 else [int(value)] if isinstance(value, list): items = [] for item in value: if isinstance(item, (int, float)): items.append(int(item)) return items return [] def parse_forced_decoder_ids(value, decoder_start_token_id: int | None) -> list[int] | None: if not value: return None try: ordered = sorted(value, key=lambda x: x[0]) tokens = [int(item[1]) for item in ordered] if decoder_start_token_id is not None: if not tokens or tokens[0] != int(decoder_start_token_id): tokens = [int(decoder_start_token_id)] + tokens return tokens except Exception: return None def load_vocab(tokenizer_json_path: Path) -> dict[str, int]: payload = load_json(tokenizer_json_path) if payload is None: raise SystemExit(f"tokenizer.json not found: {tokenizer_json_path}") vocab_obj = payload.get("model", {}).get("vocab") if not isinstance(vocab_obj, dict): raise SystemExit("tokenizer.json missing model.vocab") vocab: dict[str, int] = {} for token, value in vocab_obj.items(): if isinstance(value, int): vocab[token] = value for item in payload.get("added_tokens", []) or []: if isinstance(item, dict): content = item.get("content") idx = item.get("id") if isinstance(content, str) and isinstance(idx, int): vocab[content] = idx return vocab def remove_optional_prompt_tokens( prompt_ids: list[int], vocab: dict[str, int], language: str, omit_language_token: bool, omit_notimestamps_token: bool, ) -> list[int]: out = list(prompt_ids) if omit_language_token: lang_key = f"<|{language.lower()}|>" if lang_key not in vocab: mapped = LANGUAGE_MAP.get(language.lower()) if mapped: lang_key = f"<|{mapped}|>" lang_id = vocab.get(lang_key) if lang_id is not None: out = [tid for tid in out if tid != lang_id] if omit_notimestamps_token: nt_id = vocab.get("<|notimestamps|>") if nt_id is not None: out = [tid for tid in out if tid != nt_id] return out def log_softmax(x: np.ndarray) -> np.ndarray: x_max = np.max(x, axis=-1, keepdims=True) y = x - x_max logsum = np.log(np.sum(np.exp(y), axis=-1, keepdims=True)) return y - logsum def sample_next(logits_slice: np.ndarray, temperature: float) -> int: if temperature == 0.0: return int(np.argmax(logits_slice)) scaled = logits_slice / float(temperature) probs = np.exp(log_softmax(scaled)) probs = probs / np.sum(probs) return int(np.random.choice(len(probs), p=probs)) def main() -> None: parser = argparse.ArgumentParser( description="ONNX decode with Rust FFI preprocessing/prompt/decode helpers" ) parser.add_argument("--dll-path", help="Path to liteasr_ffi.dll") parser.add_argument("--model-id", help="Hugging Face model ID") parser.add_argument("--model-name", help="Local model name under models/") parser.add_argument("--onnx-dir", help="Directory containing ONNX files") parser.add_argument("--tokenizer-dir", help="Directory containing tokenizer files") parser.add_argument( "--tokenizer-id", help="Tokenizer/model ID to fetch tokenizer/config files (defaults to model-id).", ) parser.add_argument( "--config-id", help="Config/model ID to fetch config.json (defaults to model-id).", ) parser.add_argument("--audio", help="Path to input wav file") parser.add_argument("--input-features", help="Path to input_features.npy") parser.add_argument("--baseline-features", action="store_true") parser.add_argument("--baseline-meta", action="store_true") parser.add_argument("--run-name", default="a01") parser.add_argument("--out-root", default="artifacts/onnx_py_ffi") parser.add_argument("--out-dir", help="Output directory") parser.add_argument("--language", default="japanese") parser.add_argument("--task", default="transcribe") parser.add_argument("--with-timestamps", action="store_true") parser.add_argument("--omit-language-token", action="store_true") parser.add_argument("--omit-notimestamps-token", action="store_true") parser.add_argument("--lite-whisper-prompt", action="store_true") parser.add_argument("--max-new-tokens", type=int, default=128) parser.add_argument("--num-beams", type=int) parser.add_argument("--temperature", type=float) parser.add_argument("--do-sample", action="store_true") parser.add_argument("--seed", type=int) parser.add_argument("--no-speech-threshold", type=float) parser.add_argument("--logprob-threshold", type=float) parser.add_argument("--dump-step0", action="store_true") args = parser.parse_args() ffi = LiteAsrFfi(args.dll_path) model_id = resolve_model_id(args.model_id, args.model_name) model_name = args.model_name or model_name_from_id(model_id) onnx_dir = Path(args.onnx_dir) if args.onnx_dir else Path("models") / model_name / "onnx" tokenizer_dir = ( Path(args.tokenizer_dir) if args.tokenizer_dir else Path("models") / model_name / "tokenizer" ) tokenizer_id = args.tokenizer_id or model_id if model_id.startswith("efficient-speech/") and args.tokenizer_id is None: tokenizer_id = ( "openai/whisper-large-v3-turbo" if "turbo" in model_id else "openai/whisper-large-v3" ) config_id = args.config_id or model_id download_file(tokenizer_id, tokenizer_dir, "tokenizer.json", required=True) download_file(config_id, tokenizer_dir, "config.json", required=True) download_file(tokenizer_id, tokenizer_dir, "preprocessor_config.json", required=False) download_file(config_id, tokenizer_dir, "generation_config.json", required=False) encoder_path = onnx_dir / "encoder_model.onnx" decoder_path = onnx_dir / "decoder_model.onnx" decoder_past_path = onnx_dir / "decoder_with_past_model.onnx" for p in (encoder_path, decoder_path, decoder_past_path): if not p.exists(): raise SystemExit(f"Missing ONNX file: {p}") config = load_json(tokenizer_dir / "config.json") or {} preprocessor = load_json(tokenizer_dir / "preprocessor_config.json") or {} generation = load_json(tokenizer_dir / "generation_config.json") or {} tokenizer_json_path = tokenizer_dir / "tokenizer.json" vocab = load_vocab(tokenizer_json_path) decoder_start_token_id = config.get("decoder_start_token_id") eos_token_id = config.get("eos_token_id") if decoder_start_token_id is None: raise SystemExit("decoder_start_token_id not found in config.json") sampling_rate = int(preprocessor.get("sampling_rate", 16000)) n_mels = int(preprocessor.get("feature_size") or preprocessor.get("n_mels") or 80) suppress_tokens = normalize_token_list(generation.get("suppress_tokens")) begin_suppress_tokens = normalize_token_list(generation.get("begin_suppress_tokens")) num_beams = int(generation.get("num_beams", 1) or 1) temperature = float(generation.get("temperature", 1.0) or 1.0) do_sample = bool(generation.get("do_sample", False)) no_speech_threshold = generation.get("no_speech_threshold") logprob_threshold = generation.get("logprob_threshold") baseline_meta = None if args.baseline_meta: meta_path = Path("artifacts") / "baseline" / model_name / args.run_name / "meta.json" if meta_path.exists(): baseline_meta = load_json(meta_path) gen = (baseline_meta or {}).get("generation_config") or {} if gen.get("num_beams") is not None: num_beams = int(gen["num_beams"]) if gen.get("temperature") is not None: temperature = float(gen["temperature"]) if gen.get("do_sample") is not None: do_sample = bool(gen["do_sample"]) if gen.get("suppress_tokens") is not None: suppress_tokens = normalize_token_list(gen.get("suppress_tokens")) if gen.get("begin_suppress_tokens") is not None: begin_suppress_tokens = normalize_token_list(gen.get("begin_suppress_tokens")) if gen.get("no_speech_threshold") is not None: no_speech_threshold = gen.get("no_speech_threshold") if gen.get("logprob_threshold") is not None: logprob_threshold = gen.get("logprob_threshold") if args.num_beams is not None: num_beams = int(args.num_beams) if args.temperature is not None: temperature = float(args.temperature) if args.do_sample: do_sample = True if args.no_speech_threshold is not None: no_speech_threshold = float(args.no_speech_threshold) if args.logprob_threshold is not None: logprob_threshold = float(args.logprob_threshold) if args.seed is not None: np.random.seed(args.seed) encoder_sess = ort.InferenceSession(str(encoder_path), providers=["CPUExecutionProvider"]) decoder_sess = ort.InferenceSession(str(decoder_path), providers=["CPUExecutionProvider"]) decoder_past_sess = ort.InferenceSession( str(decoder_past_path), providers=["CPUExecutionProvider"] ) if args.input_features: input_features = np.load(Path(args.input_features)).astype(np.float32) elif args.baseline_features: baseline_path = ( Path("artifacts") / "baseline" / model_name / args.run_name / "input_features.npy" ) if not baseline_path.exists(): raise SystemExit(f"baseline input_features not found: {baseline_path}") input_features = np.load(baseline_path).astype(np.float32) else: audio_path = resolve_audio_path(args.audio) log_mel = ffi.preprocess_wav(audio_path, sampling_rate, n_mels) input_features = log_mel[np.newaxis, :, :].astype(np.float32) if input_features.ndim != 3: raise SystemExit(f"input_features has unexpected shape: {input_features.shape}") enc_out_names = [o.name for o in encoder_sess.get_outputs()] enc_out = encoder_sess.run( enc_out_names, {encoder_sess.get_inputs()[0].name: input_features}, ) enc_map = dict(zip(enc_out_names, enc_out)) encoder_hidden = enc_map.get("encoder_hidden_states", enc_out[0]) omit_language_token = args.omit_language_token omit_notimestamps_token = args.omit_notimestamps_token if args.lite_whisper_prompt: if "turbo-acc" in model_id: omit_notimestamps_token = True elif "turbo" in model_id: omit_language_token = True prompt_ids = None if baseline_meta: prompt_ids = parse_forced_decoder_ids( baseline_meta.get("forced_decoder_ids"), int(decoder_start_token_id), ) if prompt_ids is None: prompt_ids = ffi.build_prompt_ids( tokenizer_json_path, args.language, args.task, args.with_timestamps, omit_language_token, omit_notimestamps_token, ) else: prompt_ids = remove_optional_prompt_tokens( prompt_ids, vocab, args.language, omit_language_token, omit_notimestamps_token, ) dec_out_names = [o.name for o in decoder_sess.get_outputs()] dec_past_out_names = [o.name for o in decoder_past_sess.get_outputs()] dec_past_inputs = decoder_past_sess.get_inputs() needs_encoder_hidden = any(inp.name == "encoder_hidden_states" for inp in dec_past_inputs) input_ids = np.array([prompt_ids], dtype=np.int64) outputs = decoder_sess.run( dec_out_names, { decoder_sess.get_inputs()[0].name: input_ids, decoder_sess.get_inputs()[1].name: encoder_hidden, }, ) dec_map = dict(zip(dec_out_names, outputs)) logits = dec_map.get("logits", outputs[0]) raw_step0_logits = logits.copy() step0_slice = ffi.apply_suppression( np.ascontiguousarray(logits[0, -1], dtype=np.float32), suppress_tokens, begin_suppress_tokens, step=0, ) step0_present = {k: v for k, v in dec_map.items() if k.startswith("present.")} present = dict(step0_present) token_ids = list(prompt_ids) nospeech_id = vocab.get("<|nospeech|>") no_speech_prob = None sum_logprob = 0.0 gen_count = 0 def greedy_or_sample(logits_slice: np.ndarray) -> int: if do_sample or temperature != 1.0: return sample_next(logits_slice, temperature) return int(np.argmax(logits_slice)) if num_beams <= 1: log_probs_first = log_softmax(step0_slice) if nospeech_id is not None and 0 <= nospeech_id < log_probs_first.shape[0]: no_speech_prob = float(np.exp(log_probs_first[nospeech_id])) next_token = greedy_or_sample(step0_slice) sum_logprob += float(log_probs_first[next_token]) gen_count += 1 token_ids.append(next_token) step = 1 for _ in range(args.max_new_tokens - 1): if eos_token_id is not None and next_token == int(eos_token_id): break past_inputs = {"input_ids": np.array([[next_token]], dtype=np.int64)} if needs_encoder_hidden: past_inputs["encoder_hidden_states"] = encoder_hidden for input_meta in dec_past_inputs[1:]: name = input_meta.name present_name = name.replace("past.", "present.") past_inputs[name] = present[present_name] outputs = decoder_past_sess.run(dec_past_out_names, past_inputs) out_map = dict(zip(dec_past_out_names, outputs)) logits = out_map.get("logits", outputs[0]) logits_slice = ffi.apply_suppression( np.ascontiguousarray(logits[0, -1], dtype=np.float32), suppress_tokens, begin_suppress_tokens, step=step, ) present = {k: v for k, v in out_map.items() if k.startswith("present.")} log_probs = log_softmax(logits_slice) next_token = greedy_or_sample(logits_slice) sum_logprob += float(log_probs[next_token]) gen_count += 1 token_ids.append(next_token) step += 1 else: first_log_probs = log_softmax(step0_slice / float(temperature if temperature > 0 else 1.0)) if nospeech_id is not None and 0 <= nospeech_id < first_log_probs.shape[0]: no_speech_prob = float(np.exp(first_log_probs[nospeech_id])) top_indices = np.argsort(first_log_probs)[::-1][:num_beams] beams = [] for idx in top_indices: beams.append( { "tokens": token_ids + [int(idx)], "score": float(first_log_probs[idx]), "present": present, } ) for step in range(1, args.max_new_tokens): candidates = [] for beam in beams: if eos_token_id is not None and beam["tokens"][-1] == int(eos_token_id): candidates.append(beam) continue past_inputs = {"input_ids": np.array([[beam["tokens"][-1]]], dtype=np.int64)} if needs_encoder_hidden: past_inputs["encoder_hidden_states"] = encoder_hidden for input_meta in dec_past_inputs[1:]: name = input_meta.name present_name = name.replace("past.", "present.") past_inputs[name] = beam["present"][present_name] outputs = decoder_past_sess.run(dec_past_out_names, past_inputs) out_map = dict(zip(dec_past_out_names, outputs)) logits = out_map.get("logits", outputs[0]) logits_slice = ffi.apply_suppression( np.ascontiguousarray(logits[0, -1], dtype=np.float32), suppress_tokens, begin_suppress_tokens, step=step, ) present_next = {k: v for k, v in out_map.items() if k.startswith("present.")} log_probs = log_softmax(logits_slice / float(temperature if temperature > 0 else 1.0)) step_top = np.argsort(log_probs)[::-1][:num_beams] for idx in step_top: candidates.append( { "tokens": beam["tokens"] + [int(idx)], "score": beam["score"] + float(log_probs[idx]), "present": present_next, } ) candidates.sort(key=lambda x: x["score"], reverse=True) beams = candidates[:num_beams] if all( eos_token_id is not None and beam["tokens"][-1] == int(eos_token_id) for beam in beams ): break best = max(beams, key=lambda x: x["score"]) token_ids = list(best["tokens"]) if len(token_ids) > len(prompt_ids): sum_logprob = float(best["score"]) gen_count = len(token_ids) - len(prompt_ids) avg_logprob = sum_logprob / gen_count if gen_count > 0 else None if ( no_speech_threshold is not None and no_speech_prob is not None and no_speech_prob > float(no_speech_threshold) and ( logprob_threshold is None or (avg_logprob is not None and avg_logprob < float(logprob_threshold)) ) ): token_ids = list(prompt_ids) transcript = ffi.decode_tokens( tokenizer_json_path=tokenizer_json_path, token_ids=token_ids, skip_special_tokens=True, ).strip() if args.out_dir: out_dir = Path(args.out_dir) else: out_dir = Path(args.out_root) / model_name / args.run_name out_dir.mkdir(parents=True, exist_ok=True) if args.dump_step0: step0_dir = out_dir / "step0" step0_dir.mkdir(parents=True, exist_ok=True) np.save(step0_dir / "decoder_logits_step0.npy", raw_step0_logits) np.savez(step0_dir / "present_step0.npz", **step0_present) np.save(step0_dir / "encoder_hidden.npy", encoder_hidden) np.save(out_dir / "input_features.npy", input_features) np.save(out_dir / "token_ids.npy", np.array([token_ids], dtype=np.int64)) (out_dir / "transcript.txt").write_text(transcript + "\n", encoding="utf-8") meta = { "model_id": model_id, "model_name": model_name, "dll_path": str(ffi.dll_path), "audio": args.audio, "sampling_rate": sampling_rate, "n_mels": int(n_mels), "input_features_shape": list(input_features.shape), "prompt_ids": prompt_ids, "decoder_start_token_id": int(decoder_start_token_id), "eos_token_id": eos_token_id, "max_new_tokens": args.max_new_tokens, "num_beams": num_beams, "temperature": temperature, "do_sample": do_sample, "no_speech_threshold": no_speech_threshold, "logprob_threshold": logprob_threshold, "no_speech_prob": no_speech_prob, "avg_logprob": avg_logprob, "omit_language_token": omit_language_token, "omit_notimestamps_token": omit_notimestamps_token, "lite_whisper_prompt": args.lite_whisper_prompt, "baseline_meta_applied": bool(baseline_meta), "token_ids_shape": [1, len(token_ids)], "created_at": datetime.now(timezone.utc).isoformat(), } save_json(out_dir / "meta.json", meta) safe_transcript = transcript.encode("cp932", errors="replace").decode("cp932") print(f"Saved artifacts to: {out_dir}") print(f"Transcript: {safe_transcript}") if __name__ == "__main__": main()