| import argparse |
| from glob import glob |
| from unittest import case |
| import dill |
| from argparse import Namespace |
| import torch |
| import torchaudio |
| import torch.nn.functional as F |
| from utils import (max_min_norm, |
| get_timit_61_phoneme_mappings) |
| from next_frame_classifier import NextFrameClassifier |
|
|
| from dataloader import spectral_size |
| import matplotlib.pyplot as plt |
| import numpy as np |
| import os |
|
|
| import dutch_preprocess |
| from utils import timit_to_leehon_map_MACRO, timit_leehon_39_phonemes, timit_61_phonemes |
|
|
| |
| |
# Cache of loaded checkpoints, keyed by checkpoint path. `_load_model` clears it
# before every insert, so it never holds more than one entry at a time.
_MODEL_CACHE = {}


def _load_model(ckpt_path):
    """Load a NextFrameClassifier checkpoint, reusing the cached copy if present.

    Args:
        ckpt_path: Path of the checkpoint file; also used as the cache key.

    Returns:
        A ``(model, peak_detection_params)`` tuple. ``model`` has its weights
        loaded and is switched to eval mode; ``peak_detection_params`` is the
        ``'cpc_1'`` entry of the un-pickled ``ckpt['peak_detection_params']``.

    Raises:
        KeyError: If the checkpoint lacks ``"hparams"``,
            ``"peak_detection_params"``, or both ``"state_dict"`` and
            ``"model_state_dict"``.
    """
    cached = _MODEL_CACHE.get(ckpt_path)
    if cached is not None:
        return cached

    # NOTE(security): torch.load and dill.loads both un-pickle data and can run
    # arbitrary code embedded in the file. Only load checkpoints you trust.
    # The map_location callable returns the storage unchanged.
    ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
    model = NextFrameClassifier(ckpt["hparams"])

    # Weights live under "state_dict", with "model_state_dict" as the fallback
    # key. Only a missing key triggers the fallback; other errors propagate.
    try:
        weights = ckpt["state_dict"]
    except KeyError:
        weights = ckpt["model_state_dict"]

    # Drop every "NFC." occurrence from parameter names before loading them.
    weights = {name.replace("NFC.", ""): tensor for name, tensor in weights.items()}
    model.load_state_dict(weights)
    model.eval()

    peak_detection_params = dill.loads(ckpt['peak_detection_params'])['cpc_1']

    # Evict whatever was cached before so at most one entry is retained.
    _MODEL_CACHE.clear()
    _MODEL_CACHE[ckpt_path] = (model, peak_detection_params)
    return model, peak_detection_params
|
|
def _marker_height(latent):
    """Return the y-value used to draw boundary markers next to `latent`.

    Uses the median magnitude of the samples that exceed 5% of the largest
    magnitude, falls back to the largest magnitude when none do, and to 0.1
    when `latent` is empty or the result is not positive.
    """
    magnitudes = np.abs(latent) if latent.size else np.array([])
    if not magnitudes.size:
        return 0.1
    peak = float(np.max(magnitudes))
    salient = magnitudes[magnitudes > 0.05 * peak]
    height = float(np.median(salient)) if salient.size else peak
    return 0.1 if height <= 0 else height


def _mark_truth_boundaries(frames, labels, text_y):
    """Draw a dashed red line and a rotated label per truth boundary on the current axes."""
    for i, frame in enumerate(frames):
        x = float(frame)
        # Only the first line gets a legend label, so a legend shows one entry.
        plt.axvline(x=x, color='red', linestyle='--', linewidth=1,
                    label='Truth boundary' if i == 0 else "")
        if i < len(labels):
            plt.text(x, text_y, labels[i], color='red', rotation=90,
                     va='top', ha='center', fontsize=8)


def _plot_probability_map(probs, times, truth_labels, out_path):
    """Save the frame-wise softmax probability heat map with truth boundaries overlaid."""
    _, idx_to_phoneme = get_timit_61_phoneme_mappings()
    phoneme_labels = [idx_to_phoneme[i] for i in range(39)]
    probs_real = F.softmax(probs, dim=-1).squeeze(0).detach().numpy()

    plt.figure(figsize=(15, 5))
    try:
        plt.imshow(probs_real.T, aspect='auto', cmap='viridis')
        plt.colorbar(label='Probability')
        plt.xlabel('Frame Index')
        plt.ylabel('Phoneme')
        plt.yticks(ticks=range(39), labels=phoneme_labels)
        plt.title('Frame-wise Label Probability Map')
        _mark_truth_boundaries(times, truth_labels, probs_real.shape[1] + 1)
        plt.savefig(out_path)
    finally:
        # Close the figure even if plotting fails so figures do not pile up.
        plt.close()


def _plot_boundaries(preds, pred_bound, times, truth_labels, n_frames, sr, len_ratio, out_path):
    """Save the latent-score plot with truth and predicted boundary markers."""
    latent = max_min_norm(preds[1][0]).detach().numpy()[0]
    latent = latent - np.median(latent)  # centre the score around its median
    marker_h = _marker_height(latent)

    # Truth markers: boundaries past the last frame are clamped onto it here,
    # but left out of the dashed-line overlay further down.
    last_frame = n_frames - 1
    truth_track = np.zeros(n_frames)
    for t in times:
        truth_track[min(int(t), last_frame)] = marker_h
    visible_times = [t for t in times if float(t) <= last_frame]

    plt.figure(figsize=(12, 6))
    try:
        plt.plot(truth_track, marker='*', linestyle='-', label='Truth boundary marker')
        slope = np.concatenate([[0], np.diff(latent)])
        plt.plot(range(len(slope)), slope,
                 marker='o', label='Derivative of latent score', color='magenta')
        plt.plot(range(len(latent)), latent,
                 marker='*', label='Latent score', color='red')

        # Predicted boundaries are converted to frame indices; predictions
        # beyond the last frame are dropped.
        pred_track = np.zeros(n_frames)
        for pred in pred_bound:
            idx = int(pred * sr / len_ratio)
            if idx <= len(pred_track) - 1:
                pred_track[idx] = marker_h
        plt.plot(range(len(pred_track)), pred_track,
                 marker='^', label='Predicted boundary', linestyle='None')

        # Labels are anchored at the automatic upper y-limit, read before the
        # limits are overridden below.
        _mark_truth_boundaries(visible_times, truth_labels, plt.ylim()[1])

        plt.xlabel('Frame Index')
        plt.ylabel('Score')
        plt.title('Predicted Boundaries')
        plt.legend(loc='upper right')

        y_half = max(marker_h * 4.0, 0.1)
        plt.ylim(-y_half, y_half)
        plt.savefig(out_path)
    finally:
        plt.close()


def _run_inference(wav, ckpt, language, annotation, make_plots):
    """Load audio and annotation, run the model, and optionally write the plots."""
    model, _ = _load_model(ckpt)

    audio, sr = torchaudio.load(wav)
    assert sr == 16000, "model was trained with audio sampled at 16khz, please downsample."
    audio = audio[0]  # first channel only

    # Locate the annotation file: any file next to the wav whose name starts
    # with the wav's stem (text before the first dot) and has the extension.
    base_dir = os.path.dirname(wav)
    stem = os.path.basename(wav).split('.')[0]
    matching_files = glob(os.path.join(base_dir, f"{stem}*.{annotation}"))
    if matching_files:
        phn_path = matching_files[0]
    else:
        print(f"No matching .{annotation} file found. Using default naming convention.")
        # NOTE(review): this swaps every "wav" substring in the path for "phn"
        # (directory names included) and ignores `annotation` — confirm intent.
        phn_path = wav.replace("wav", "phn")

    # Ratio of audio samples to spectral frames, used to map sample offsets to frames.
    audio_len = len(audio)
    len_ratio = audio_len / spectral_size(audio_len)

    # Each non-blank line is whitespace-split: field 1 is read as a sample
    # offset and field 2 as the label. Blank lines are skipped.
    with open(phn_path, "r") as f:
        rows = [line.split() for line in f if line.strip()]
    offsets = [float(row[1]) for row in rows]

    # The offset on the last line is dropped from both boundary lists.
    times = torch.FloatTensor([int(offset / len_ratio) for offset in offsets])[:-1]
    times_sec = torch.FloatTensor([offset / sr for offset in offsets])[:-1]
    phonemes = [row[2].strip() for row in rows]

    # Plots are annotated with the labels exactly as they appear in the file.
    truth_labels = list(phonemes)

    mapped_ph = None
    if language == "dutch":
        mapped_ph = []
        for IFA_ph in phonemes:
            print(f"\nINPUT: {IFA_ph}")
            # Labels that are TIMIT-61 phonemes are first folded through the
            # Lee-Hon map; everything else goes to the pipeline as-is.
            key = IFA_ph.lower()
            query = timit_to_leehon_map_MACRO[key] if key in timit_61_phonemes else IFA_ph
            output = dutch_preprocess.aligner_pipeline(query)
            mapped_ph.append([x["lh39"] for x in output])
            if not output:
                print("Results: None")
        print(phonemes)
        print(f"Dutch IPA to LH39 mapping: {mapped_ph}")
        phonemes = np.hstack(mapped_ph).tolist()

    with torch.no_grad():
        model.eval()
        preds, original_lengths, probs, _, _, preds_peaks, _ = model(
            audio.unsqueeze(0), None, [phonemes], [audio_len / len_ratio])

    out_dir = os.path.dirname(wav)
    base_name = os.path.basename(wav).replace('.wav', '')

    if make_plots:
        _plot_probability_map(probs, times, truth_labels,
                              os.path.join(out_dir, f"{base_name}_probs.png"))

    pred_bound = torch.tensor(preds_peaks[0], dtype=torch.float32)
    print(f"predicted boundaries (s): {pred_bound}")

    if make_plots:
        _plot_boundaries(preds, pred_bound, times, truth_labels,
                         int(original_lengths[0]), sr, len_ratio,
                         os.path.join(out_dir, f"{base_name}_boundaries.png"))

    return pred_bound, times_sec, mapped_ph


def main_predict(wav, ckpt, w_phi, language="english", annotation="phn", no_plots=False):
    """Predict segment boundaries for one wav file and compare with its annotation.

    Args:
        wav: Path to the audio file; it must be sampled at 16 kHz.
        ckpt: Path to the model checkpoint.
        w_phi: Accepted for backward compatibility; the value is not used.
        language: "dutch" routes every label through
            ``dutch_preprocess.aligner_pipeline``; any other value uses the
            labels exactly as read from the annotation file.
        annotation: Extension of the annotation file searched for next to `wav`.
        no_plots: When true, no figures are built or written. ``plt.savefig``
            is also replaced by a no-op for the duration of the call —
            presumably so plots saved by code called from here are suppressed
            as well (NOTE(review): confirm).

    Returns:
        ``(pred_bound, truth_bound, mapped_ph)``: a float32 tensor built from
        the model's first peak list, a tensor of annotation offsets divided by
        the sample rate (last line dropped), and the per-label LH39 mapping
        when ``language == "dutch"`` (otherwise ``None``).

    Raises:
        AssertionError: If the audio is not sampled at 16 kHz.
    """
    print(f"running inference on: {wav}")
    print(f"running inferece using ckpt: {ckpt}")
    print("\n\n", 90 * "-")

    if not no_plots:
        return _run_inference(wav, ckpt, language, annotation, make_plots=True)

    # Restore the real savefig in `finally` so a failure mid-run cannot leave
    # matplotlib patched for the rest of the process.
    original_savefig = plt.savefig
    plt.savefig = lambda *a, **k: None
    try:
        return _run_inference(wav, ckpt, language, annotation, make_plots=False)
    finally:
        plt.savefig = original_savefig
|
|
# Bundled checkpoints are resolved relative to the directory holding this
# script, so the defaults do not depend on the current working directory.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
_PRETRAINED_DIR = os.path.join(_SCRIPT_DIR, "pretrained_models")

# Checkpoint returned by resolve_default_ckpt() for every lang except "multilingual".
DEFAULT_CKPT_ENGLISH = os.path.join(_PRETRAINED_DIR, "falcon_timit_english.pt")

# Checkpoint returned by resolve_default_ckpt() for lang == "multilingual".
DEFAULT_CKPT_MULTILINGUAL = os.path.join(_PRETRAINED_DIR, "falcon_joint_multilingual.pt")
|
|
def resolve_internal_language(lang: str, mode: str, annotation: str) -> str:
    """
    Translate the user-facing (--lang, --mode, --annotation) options into the
    internal `language` value that main_predict() expects.

    'english' = labels are used as-is (no G2P). Returned only for English,
                phoneme-level input with a "phn" annotation (case-insensitive).
    'dutch'   = labels are routed through the G2P pipeline. Returned for every
                other combination: non-English language, word-level mode, or
                any other annotation extension.
    """
    needs_g2p = (
        lang != "english"
        or mode != "phoneme"
        or annotation.lower() != "phn"
    )
    return "dutch" if needs_g2p else "english"
|
|
def resolve_default_ckpt(lang: str) -> str:
    """Return the bundled checkpoint path to use when --ckpt is not given.

    "multilingual" selects DEFAULT_CKPT_MULTILINGUAL; every other value
    selects DEFAULT_CKPT_ENGLISH.
    """
    if lang == "multilingual":
        return DEFAULT_CKPT_MULTILINGUAL
    return DEFAULT_CKPT_ENGLISH
|
|
if __name__ == "__main__":
    # Command-line entry point: parse the options, resolve the checkpoint and
    # the internal language flag, then run inference on a single wav file.
    parser = argparse.ArgumentParser(description='Unsupervised segmentation inference script')
    # --wav is mandatory: without it the run would only fail later, after the
    # checkpoint has been loaded, when the audio loader receives None.
    parser.add_argument('--wav', required=True, help='path to wav file')
    parser.add_argument('--ckpt', default=None,
                        help='Path to checkpoint file. If omitted, uses '
                             'pretrained_models/falcon_timit_english.pt for --lang english '
                             'or falcon_joint_multilingual.pt for --lang multilingual '
                             '(the joint TIMIT+Buckeye model is best for cross-lingual zero-shot).')

    parser.add_argument('--mode', type=str, default='phoneme', choices=['phoneme', 'word'],
                        help='Alignment granularity: "phoneme" = phoneme-level alignment (default). '
                             '"word" = word-level alignment (zero-shot, no additional training).')
    parser.add_argument('--lang', type=str, default='english', choices=['english', 'multilingual'],
                        help='Language setting: "english" = trained English phoneme alignment (default). '
                             '"multilingual" = any non-English language (zero-shot cross-lingual).')
    parser.add_argument('--annotation', type=str, default='phn',
                        help='Annotation file extension to search for (e.g. phn, wrd, word, txt). Default: phn')
    parser.add_argument('--no-plots', action='store_true',
                        help='Skip per-file diagnostic plots (probs/logits/boundaries/dp_matrix) for faster runs.')
    args = parser.parse_args()

    # An explicit --ckpt wins; otherwise use the bundled default for --lang.
    ckpt = args.ckpt or resolve_default_ckpt(args.lang)
    language = resolve_internal_language(args.lang, args.mode, args.annotation)
    main_predict(args.wav, ckpt, w_phi=0.5, language=language,
                 annotation=args.annotation, no_plots=args.no_plots)
|
|