# custom-gopt-252-eval/examples/infer_one_audio.py
# Commit 2dd9831 (faeea, verified): "Rewrite Chinese README for one-audio local inference and add single-audio script"
import argparse
import json
import sys
import warnings
from pathlib import Path
import numpy as np
import torch
# Names of the utterance-level scores, in the order predict_scores() pairs them
# with the first five outputs returned by the model.
SCORE_NAMES = ["accuracy", "completeness", "fluency", "prosodic", "total"]
def get_args(argv=None):
    """Parse the command-line options for single-audio inference.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, which is the original behavior.

    Returns:
        argparse.Namespace with the attributes ``audio``, ``bundle_dir``,
        ``repo_root``, ``charsiu_src_dir``, ``device``, ``output_json``,
        ``main_context_tokens`` and ``right_context_tokens``.
    """
    parser = argparse.ArgumentParser(
        description="Run one-audio inference with the bundled Whisper + Charsiu + Streaming GOPT pipeline."
    )
    parser.add_argument("--audio", type=Path, required=True, help="Path to one wav audio file.")
    parser.add_argument(
        "--bundle-dir",
        type=Path,
        # parents[0] is the folder holding this script; parents[1] is one level above it.
        default=Path(__file__).resolve().parents[1],
        help=(
            "Directory containing the streaming_gopt_best, whisper_best_model and "
            "charsiu_en_w2v2_tiny_fc_10ms folders. Defaults to the directory one level "
            "above this script's folder."
        ),
    )
    parser.add_argument(
        "--repo-root",
        type=Path,
        required=True,
        help="Path to the cloned custom-gopt repository root.",
    )
    parser.add_argument(
        "--charsiu-src-dir",
        type=Path,
        required=True,
        help="Path to the official Charsiu repo root or its src directory.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="cuda / cuda:0 / cpu. Defaults to cuda if available.",
    )
    parser.add_argument(
        "--output-json",
        type=Path,
        default=None,
        help="Optional path; when given, the printed JSON result is also written to this file.",
    )
    parser.add_argument(
        "--main-context-tokens",
        type=int,
        default=None,
        help="Override the main context token count. Defaults to the largest choice in the model config.",
    )
    parser.add_argument(
        "--right-context-tokens",
        type=int,
        default=None,
        help="Override the right context token count. Defaults to the largest choice in the model config.",
    )
    return parser.parse_args(argv)
def add_repo_paths(repo_root):
    """Make ``<repo_root>/src`` and ``<repo_root>/src/prep_data`` importable.

    Each directory is inserted at the front of ``sys.path`` unless it is
    already listed. ``src`` is handled first, so when both are new,
    ``src/prep_data`` ends up ahead of ``src``.

    Args:
        repo_root: Path-like object supporting the ``/`` operator.
    """
    source_dir = repo_root / "src"
    for directory in (source_dir, source_dir / "prep_data"):
        entry = str(directory)
        if entry in sys.path:
            continue
        sys.path.insert(0, entry)
def load_asr_pipeline(whisper_model_dir, device):
    """Build the Hugging Face speech-recognition pipeline and its call kwargs.

    Args:
        whisper_model_dir: Directory used as model, tokenizer and feature
            extractor source for the pipeline.
        device: Device spec whose string form is e.g. ``cuda``, ``cuda:1`` or
            ``cpu``.

    Returns:
        Tuple ``(pipe, kwargs)``: the pipeline object and the keyword
        arguments to pass on each call (word-level timestamps, English
        transcription, at most 128 new tokens, cache disabled).
    """
    from transformers import pipeline
    from transformers.utils import logging as hf_logging

    hf_logging.set_verbosity_error()

    device_name = str(device)
    if device_name.startswith("cuda"):
        # "cuda:N" selects GPU N; a bare "cuda" falls back to GPU 0.
        _, has_index, index_text = device_name.partition(":")
        pipe_device = int(index_text) if has_index else 0
        torch_dtype = torch.float16
    else:
        # -1 is the pipeline's CPU device id.
        pipe_device, torch_dtype = -1, torch.float32

    model_path = str(whisper_model_dir)
    asr = pipeline(
        "automatic-speech-recognition",
        model=model_path,
        tokenizer=model_path,
        feature_extractor=model_path,
        framework="pt",
        device=pipe_device,
        dtype=torch_dtype,
    )
    if hasattr(asr.model, "generation_config"):
        asr.model.generation_config.use_cache = False

    generate_kwargs = {
        "language": "english",
        "task": "transcribe",
        "max_new_tokens": 128,
        "use_cache": False,
    }
    call_kwargs = {"return_timestamps": "word", "generate_kwargs": generate_kwargs}
    return asr, call_kwargs
def load_audio_for_asr(audio_path, sample_rate):
    """Read an audio file as a float32 array at ``sample_rate``.

    Input with more than one dimension is averaged over axis 1, and the
    signal is resampled with librosa when the file's rate differs from
    ``sample_rate``.

    Returns:
        Tuple ``(audio, sr)`` of the float32 samples and their sampling rate.
    """
    import librosa
    import soundfile as sf

    samples, native_rate = sf.read(str(audio_path))
    if samples.ndim > 1:
        samples = samples.mean(axis=1)
    samples = np.asarray(samples, dtype=np.float32)

    if native_rate == sample_rate:
        return samples, native_rate

    resampled = librosa.resample(samples, orig_sr=native_rate, target_sr=sample_rate)
    return resampled, sample_rate
def transcribe_audio(asr_pipe, pipe_kwargs, audio_path, sample_rate, normalize_word):
    """Run the ASR pipeline on one file and collect its timed words.

    Args:
        asr_pipe: Callable taking ``{"raw", "sampling_rate"}`` plus keyword
            arguments and returning a dict with ``text`` and ``chunks``.
        pipe_kwargs: Keyword arguments forwarded to ``asr_pipe``.
        audio_path: Audio file to transcribe.
        sample_rate: Sampling rate the audio is loaded at.
        normalize_word: Callable applied to each chunk's text.

    Returns:
        Tuple ``(transcript, words)``. ``transcript`` is the stripped pipeline
        text; ``words`` is a list of ``{"text", "start", "end"}`` dicts.
        Chunks whose normalized text is empty or whose timestamp has a missing
        bound are dropped.
    """
    waveform, rate = load_audio_for_asr(audio_path, sample_rate)
    asr_output = asr_pipe({"raw": waveform, "sampling_rate": rate}, **pipe_kwargs)
    transcript = (asr_output.get("text") or "").strip()

    timed_words = []
    for piece in asr_output.get("chunks", []):
        token = normalize_word(piece.get("text", ""))
        span = piece.get("timestamp") or (None, None)
        # Keep only chunks with usable text and both timestamp bounds.
        if token and span[0] is not None and span[1] is not None:
            timed_words.append({"text": token, "start": float(span[0]), "end": float(span[1])})

    return transcript, timed_words
def load_model(model_dir, repo_root, device):
    """Instantiate the GOPT model described by ``config.json`` and load its weights.

    Args:
        model_dir: Directory containing ``config.json`` and
            ``best_audio_model.pth``.
        repo_root: Repository root whose ``src`` folders provide ``models``.
        device: Target device for the checkpoint and the model.

    Returns:
        Tuple ``(model, cfg)``: the model in eval mode on ``device`` and the
        parsed config dict.

    Raises:
        RuntimeError: If the checkpoint has unexpected keys, or is missing
            keys other than the ``mlp_head_word4`` parameters.
    """
    add_repo_paths(repo_root)
    from models import StreamingGOPT, StreamingGOPTNoPhn

    config_text = (model_dir / "config.json").read_text(encoding="utf-8")
    cfg = json.loads(config_text)
    hparams = cfg["args"]

    if hparams["model"] == "streaming_gopt":
        network_cls = StreamingGOPT
    else:
        network_cls = StreamingGOPTNoPhn
    model = network_cls(
        embed_dim=int(hparams["embed_dim"]),
        num_heads=int(hparams["heads"]),
        depth=int(hparams["depth"]),
        input_dim=int(cfg["input_dim"]),
        seq_len=int(cfg["seq_len"]),
        phn_num=int(cfg["phn_num"]),
    )

    # NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted bundle.
    checkpoint = torch.load(model_dir / "best_audio_model.pth", map_location=device)
    load_report = model.load_state_dict(checkpoint, strict=False)

    # Parameters the checkpoint is allowed to lack; anything else missing is an error.
    optional_keys = {
        "mlp_head_word4.0.weight",
        "mlp_head_word4.0.bias",
        "mlp_head_word4.1.weight",
        "mlp_head_word4.1.bias",
    }
    surplus_keys = sorted(set(load_report.unexpected_keys))
    absent_keys = set(load_report.missing_keys)
    if surplus_keys:
        raise RuntimeError(f"Unexpected checkpoint keys: {surplus_keys}")
    required_but_absent = sorted(absent_keys - optional_keys)
    if required_but_absent:
        raise RuntimeError(f"Missing required checkpoint keys: {required_but_absent}")

    model = model.to(device)
    model.eval()
    return model, cfg
def build_phone_segments(
    audio_path,
    transcript,
    repo_root,
    charsiu_src_dir,
    charsiu_model_dir,
    sample_rate,
    device,
    phn_dict,
    expected_feat_dim,
):
    """Force-align the transcript's phones to the audio and build per-phone features.

    The transcript is converted to phones by the Charsiu processor, frame
    probabilities are computed for the audio, frames rejected by the silence
    mask are dropped, and the remaining frames are force-aligned to the phone
    sequence. One feature vector is produced per phone.

    Args:
        audio_path: Audio file to align.
        transcript: Text whose phones are aligned to the audio.
        repo_root: Repository root providing ``build_charsiu_seq_data``.
        charsiu_src_dir: Charsiu source location handed to the helper loaders.
        charsiu_model_dir: Directory of the Charsiu aligner model.
        sample_rate: Sampling rate used by the aligner.
        device: Device for the aligner.
        phn_dict: Mapping from phone label to GOPT phone id.
        expected_feat_dim: Feature width the GOPT model expects.

    Returns:
        List of dicts with keys ``phone``, ``phone_id``, ``word_id``,
        ``word_text``, ``start_time``, ``end_time`` and ``feature``.

    Raises:
        ValueError: If the transcript yields no usable phones, a phone is
            unknown to either vocabulary, phone counts disagree, there are
            too few non-silence frames, a phone gets no frames, or the
            feature width cannot be reconciled with ``expected_feat_dim``.
    """
    add_repo_paths(repo_root)
    # Evict cached modules with these names before importing the helpers below.
    # NOTE(review): presumably avoids name clashes between this repo's and Charsiu's modules — confirm.
    for stale_name in ("Charsiu", "models", "utils", "processors"):
        sys.modules.pop(stale_name, None)
    from build_charsiu_seq_data import (
        audio_logits,
        build_model_phone_map,
        build_silence_keep_mask,
        import_official_charsiu_forced_align,
        load_official_charsiu_aligner,
        normalize_phone,
        segment_feature,
    )

    charsiu = load_official_charsiu_aligner(
        model_name=str(charsiu_model_dir),
        device=str(device),
        sample_rate=sample_rate,
        sil_threshold=4,
        lang="en",
        charsiu_src_dir=str(charsiu_src_dir),
    )
    phone_to_frame_id, _, _ = build_model_phone_map(charsiu)

    phone_groups, words = charsiu.charsiu_processor.get_phones_and_words(transcript)
    if not phone_groups:
        raise ValueError("ASR transcript cannot be converted into phones.")

    # Flatten the per-word phone groups into one record per scorable phone.
    phone_records = []
    for word_index, word_phones in enumerate(phone_groups):
        if word_index < len(words):
            spoken_word = str(words[word_index]).lower()
        else:
            spoken_word = f"word_{word_index}"
        for raw_phone in word_phones:
            label = normalize_phone(raw_phone)
            if not label or label == "SIL":
                continue
            if label not in phone_to_frame_id:
                raise ValueError(f"Phone not found in Charsiu frame-classifier vocab: {label}")
            if label not in phn_dict:
                raise ValueError(f"Phone not found in GOPT phone vocab: {label}")
            phone_records.append(
                {"phone": label, "word_id": int(word_index), "word_text": spoken_word}
            )
    if not phone_records:
        raise ValueError("No valid phones remained after transcript normalization.")

    all_phone_ids = charsiu.charsiu_processor.get_phone_ids(phone_groups)
    # Drop the first and last ids before aligning.
    # NOTE(review): presumably boundary silence tokens added by the processor — confirm.
    target_phone_ids = list(all_phone_ids[1:-1])
    if len(target_phone_ids) != len(phone_records):
        raise ValueError(
            f"Phone count mismatch after G2P: target_phone_ids={len(target_phone_ids)} flat_records={len(phone_records)}"
        )

    probs, audio_duration = audio_logits(
        audio_path=str(audio_path),
        processor=charsiu.charsiu_processor,
        model=charsiu.aligner,
        sample_rate=sample_rate,
        device=device,
    )
    keep_mask = build_silence_keep_mask(charsiu, probs)
    kept_indices = np.flatnonzero(keep_mask)
    kept_probs = probs[keep_mask]
    if kept_probs.shape[0] < len(target_phone_ids):
        raise ValueError(
            f"Not enough non-silence frames for the recognized phones: frames={kept_probs.shape[0]} phones={len(target_phone_ids)}"
        )

    forced_align = import_official_charsiu_forced_align(str(charsiu_src_dir))
    frame_owner = np.asarray(forced_align(kept_probs, target_phone_ids), dtype=np.int32)
    # Seconds per frame, measured over the unmasked frame sequence.
    frame_step = float(audio_duration) / max(len(probs), 1)

    segments = []
    for position, record in enumerate(phone_records):
        own_frames = np.flatnonzero(frame_owner == position)
        if own_frames.size == 0:
            raise ValueError(f"Empty aligned segment for phone index {position} ({record['phone']}).")

        frame_class_id = int(phone_to_frame_id[record["phone"]])
        base_feature = segment_feature(kept_probs[own_frames], frame_class_id, frame_step).astype(np.float32)

        # The last four entries of the base feature are kept apart from the
        # phone-probability channels; the model width has one more trailing slot.
        expected_prob_dim = (int(expected_feat_dim) - 1) - 4
        current_prob_dim = int(base_feature.shape[0]) - 4
        surplus = current_prob_dim - expected_prob_dim
        if surplus == 1:
            # Current local Charsiu exports an extra [PAD] probability channel at the end.
            # The released GOPT checkpoint was trained without that channel.
            base_feature = np.concatenate([base_feature[:expected_prob_dim], base_feature[-4:]], axis=0)
        elif surplus != 0:
            raise ValueError(
                f"Unexpected phone probability dimension: current={current_prob_dim} expected={expected_prob_dim}"
            )
        feature = np.concatenate([base_feature, np.array([0.0], dtype=np.float32)], axis=0)

        # Map positions in the kept (non-silence) frames back to original frame indices.
        first_frame = int(kept_indices[own_frames[0]])
        frame_after_last = int(kept_indices[own_frames[-1]]) + 1
        segments.append(
            {
                "phone": record["phone"],
                "phone_id": int(phn_dict[record["phone"]]),
                "word_id": int(record["word_id"]),
                "word_text": record["word_text"],
                "start_time": float(first_frame * frame_step),
                "end_time": float(min(audio_duration, frame_after_last * frame_step)),
                "feature": feature,
            }
        )
    return segments
def prepare_model_inputs(segments, seq_len, feat_dim, norm_mean, norm_std, device):
    """Pack phone segments into padded, normalized batch-of-one tensors.

    Args:
        segments: Sequence of dicts with a ``feature`` array and a ``phone_id``.
        seq_len: Number of rows in the padded output.
        feat_dim: Required width of every segment feature.
        norm_mean: Scalar subtracted from the features of real rows.
        norm_std: Scalar the centered features are divided by.
        device: Device the tensors are moved to.

    Returns:
        Tuple ``(x, p)``: a float32 tensor of shape ``(1, seq_len, feat_dim)``
        and an int64 tensor of shape ``(1, seq_len)``. Unused rows keep zero
        features and phone id ``-1``; rows with a negative phone id are not
        normalized.

    Raises:
        ValueError: If there are more segments than ``seq_len`` or a feature
            has the wrong width.
    """
    segment_count = len(segments)
    if segment_count > seq_len:
        raise ValueError(f"Phone sequence too long for this model: {segment_count} > seq_len({seq_len})")

    features = np.zeros((seq_len, feat_dim), dtype=np.float32)
    phone_ids = np.full((seq_len,), -1, dtype=np.int64)
    for row, item in enumerate(segments):
        width = item["feature"].shape[-1]
        if int(width) != feat_dim:
            raise ValueError(
                f"Feature dimension mismatch at segment {row}: got {width}, expected {feat_dim}"
            )
        features[row] = item["feature"]
        phone_ids[row] = int(item["phone_id"])

    # Normalize only rows that carry a real phone; padding rows stay all-zero.
    is_real = phone_ids >= 0
    features[is_real] = (features[is_real] - float(norm_mean)) / float(norm_std)

    batched = [torch.from_numpy(array).unsqueeze(0).to(device) for array in (features, phone_ids)]
    return batched[0], batched[1]
def predict_scores(model, x, p, main_context_tokens, right_context_tokens):
    """Run the model without gradients and return the named utterance scores.

    The first five of the model's ten outputs are concatenated along dim 1,
    multiplied by 5.0 and clipped to ``[0.0, 5.0]``.

    Returns:
        Dict mapping each name in ``SCORE_NAMES`` to a Python float.
    """
    with torch.no_grad():
        outputs = model(
            x,
            p,
            main_context_tokens=int(main_context_tokens),
            right_context_tokens=int(right_context_tokens),
        )
        # Exactly ten outputs are expected; only the first five are used here.
        u1, u2, u3, u4, u5, _, _, _, _, _ = outputs

    stacked = torch.cat([u1, u2, u3, u4, u5], dim=1).squeeze(0)
    scaled = stacked.cpu().numpy() * 5.0
    bounded = np.clip(scaled, 0.0, 5.0)
    return dict(zip(SCORE_NAMES, map(float, bounded.tolist())))
def build_output(audio_path, transcript, asr_words, scores, segments, device, bundle_dir):
    """Assemble the JSON-serializable result payload.

    Args:
        audio_path: Path of the scored audio file.
        transcript: ASR transcript text.
        asr_words: List of dicts that each carry a ``text`` entry.
        scores: Mapping of score name to value; must include ``total``.
        segments: Phone segments; only their count is reported.
        device: Device used for inference, reported as a string.
        bundle_dir: Path of the model bundle directory.

    Returns:
        Dict with status, resolved forward-slash paths, device, transcript,
        scores, segment and word counts, and the lower-cased recognized words.
    """

    def _forward_slashes(path):
        # Resolve, then normalize backslash separators to forward slashes.
        return str(path.resolve()).replace("\\", "/")

    return dict(
        status="ok",
        audio_path=_forward_slashes(audio_path),
        bundle_dir=_forward_slashes(bundle_dir),
        device=str(device),
        transcript=transcript,
        utterance_scores=scores,
        overall_score=float(scores["total"]),
        num_phone_segments=int(len(segments)),
        num_asr_words=int(len(asr_words)),
        recognized_words=[str(entry["text"]).lower() for entry in asr_words],
    )
def main():
    """Score one audio file end to end and print the result as JSON.

    Steps: parse the CLI, load the bundle's inference assets and the GOPT
    model, transcribe the audio, build aligned phone features, run the model,
    then print the payload and optionally write it to ``--output-json``.

    Raises:
        FileNotFoundError: If the audio file does not exist.
        ValueError: If the ASR transcript is empty.
    """
    warnings.filterwarnings("ignore", message=".*return_token_timestamps.*")
    args = get_args()

    bundle_dir = args.bundle_dir.resolve()
    repo_root = args.repo_root.resolve()
    charsiu_src_dir = args.charsiu_src_dir.resolve()

    requested_device = args.device
    if not requested_device:
        requested_device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(requested_device)

    if not args.audio.exists():
        raise FileNotFoundError(f"Audio file not found: {args.audio}")

    # Fixed sub-folder layout of the model bundle.
    gopt_dir = bundle_dir / "streaming_gopt_best"
    whisper_dir = bundle_dir / "whisper_best_model"
    charsiu_dir = bundle_dir / "charsiu_en_w2v2_tiny_fc_10ms"

    assets_text = (gopt_dir / "inference_assets.json").read_text(encoding="utf-8")
    assets = json.loads(assets_text)
    sample_rate = int(assets["sample_rate"])
    norm_mean = float(assets["train_norm_mean"])
    norm_std = float(assets["train_norm_std"])
    phn_dict = {str(label): int(index) for label, index in assets["phn_dict"].items()}

    model, cfg = load_model(gopt_dir, repo_root, device)

    add_repo_paths(repo_root)
    from build_charsiu_seq_data import normalize_word

    asr_pipe, asr_kwargs = load_asr_pipeline(whisper_dir, device)
    transcript, asr_words = transcribe_audio(asr_pipe, asr_kwargs, args.audio, sample_rate, normalize_word)
    if not transcript:
        raise ValueError("Whisper produced an empty transcript.")

    feat_dim = int(cfg["input_dim"])
    segments = build_phone_segments(
        audio_path=args.audio,
        transcript=transcript,
        repo_root=repo_root,
        charsiu_src_dir=charsiu_src_dir,
        charsiu_model_dir=charsiu_dir,
        sample_rate=sample_rate,
        device=device,
        phn_dict=phn_dict,
        expected_feat_dim=feat_dim,
    )

    x, p = prepare_model_inputs(
        segments=segments,
        seq_len=int(cfg["seq_len"]),
        feat_dim=feat_dim,
        norm_mean=norm_mean,
        norm_std=norm_std,
        device=device,
    )

    model_args = cfg["args"]

    def _context_choices(list_key, csv_key):
        # Prefer the explicit list; otherwise parse the comma-separated fallback value.
        explicit = model_args.get(list_key)
        if explicit:
            return explicit
        return [int(item.strip()) for item in str(model_args[csv_key]).split(",") if item.strip()]

    main_context_choices = _context_choices("main_context_token_choices", "main_context_tokens")
    right_context_choices = _context_choices("right_context_token_choices", "right_context_tokens")

    # A CLI override wins; otherwise use the largest configured choice.
    if args.main_context_tokens is not None:
        main_context_tokens = int(args.main_context_tokens)
    else:
        main_context_tokens = int(max(main_context_choices))
    if args.right_context_tokens is not None:
        right_context_tokens = int(args.right_context_tokens)
    else:
        right_context_tokens = int(max(right_context_choices))

    scores = predict_scores(model, x, p, main_context_tokens, right_context_tokens)
    payload = build_output(args.audio, transcript, asr_words, scores, segments, device, bundle_dir)

    result_json = json.dumps(payload, ensure_ascii=False, indent=2)
    print(result_json)
    if args.output_json is not None:
        args.output_json.parent.mkdir(parents=True, exist_ok=True)
        args.output_json.write_text(result_json, encoding="utf-8")
# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()