| from __future__ import annotations |
|
|
| import json |
| import sys |
| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import Iterable |
|
|
| import librosa |
| import matplotlib.pyplot as plt |
| import numpy as np |
| import soundfile as sf |
| from scipy.optimize import linear_sum_assignment |
| from scipy.signal import medfilt |
|
|
# Directory that contains this file; used to anchor the sys.path entries below.
ROOT = Path(__file__).resolve().parent
# Prepend this directory and its ``datasets`` subdirectory to ``sys.path``,
# presumably so project-local modules such as ``feature`` (imported below)
# resolve when the file is run directly. Entries already present are not
# duplicated. Each insertion goes to index 0, so when both are added
# ``ROOT / "datasets"`` ends up ahead of ``ROOT``.
for path in (ROOT, ROOT / "datasets"):
    path_str = str(path)
    if path_str in sys.path:
        continue
    sys.path.insert(0, path_str)
|
|
| import feature as ls_feature |
|
|
|
|
@dataclass
class InferenceResult:
    """Inference outputs bundled with the timing needed to interpret them.

    The field order below is also the positional order of the generated
    constructor, so it must not be rearranged.

    NOTE(review): nothing in this part of the file constructs an
    ``InferenceResult``, so the difference between the plain and ``full_*``
    arrays is not visible here. Presumably the plain arrays are the subset of
    outputs that gets scored; confirm against the producer.
    """

    # Model scores before and (presumably) after the output activation.
    logits: np.ndarray
    probabilities: np.ndarray
    # Unrestricted counterparts of the two arrays above (see NOTE).
    full_logits: np.ndarray
    full_probabilities: np.ndarray
    # Presumably output frames per second, as computed by ``frame_hz()``.
    frame_hz: float
    # Length of the processed audio, in seconds.
    duration_seconds: float
|
|
|
|
def ensure_mono(audio: np.ndarray) -> np.ndarray:
    """Return ``audio`` as float32 samples, averaging channels if needed.

    1-D input is only cast to float32. If it is already float32, the same
    array object is returned without a copy. Any other input is averaged
    over axis 1 (presumably ``(frames, channels)``, as produced by
    ``soundfile.read``), with the mean accumulated in float32.
    """
    if audio.ndim != 1:
        return audio.mean(axis=1, dtype=np.float32)
    return audio.astype(np.float32, copy=False)
|
|
|
|
def load_audio(audio_path: Path) -> tuple[np.ndarray, int]:
    """Decode an audio file and return ``(mono_samples, sample_rate)``.

    Decoding uses ``soundfile.read`` with its default dtype. The samples are
    then passed through ``ensure_mono``, which casts them to float32 and
    averages any channels. The sample rate is returned exactly as soundfile
    reports it; no resampling happens here.
    """
    decoded = sf.read(audio_path)
    samples, native_rate = decoded
    mono = ensure_mono(samples)
    return mono, native_rate
|
|
|
|
def config_from_metadata(metadata: dict) -> dict:
    """Rebuild the nested inference config from a flat metadata mapping.

    Values are coerced to ``int``, except ``feat_type``, which becomes ``str``,
    so string-valued metadata is accepted.

    Args:
        metadata: Flat mapping with ``sample_rate``, ``win_length``,
            ``hop_length``, ``n_fft``, ``n_mels``, ``context_recp``,
            ``subsampling``, ``feat_type`` and ``conv_delay``, plus either
            ``max_speakers`` or ``max_nspks``.

    Returns:
        ``{"data": {"feat": {...}, "context_recp", "subsampling",
        "feat_type", "max_speakers"}, "model": {"params": {"conv_delay"}}}``,
        the layout read by ``extract_features`` and ``frame_hz``.

    Raises:
        KeyError: if a required key is missing, or if neither
            ``max_speakers`` nor ``max_nspks`` is present.
    """
    # Resolve the fallback lazily. Passing it as a ``dict.get`` default would
    # evaluate ``metadata["max_nspks"]`` even when ``max_speakers`` exists,
    # raising KeyError for metadata that only carries ``max_speakers``.
    if "max_speakers" in metadata:
        max_speakers = int(metadata["max_speakers"])
    else:
        # NOTE(review): the "- 2" offset is kept from the original code;
        # presumably ``max_nspks`` counts two extra output classes - confirm.
        max_speakers = int(metadata["max_nspks"]) - 2
    return {
        "data": {
            "feat": {
                "sample_rate": int(metadata["sample_rate"]),
                "win_length": int(metadata["win_length"]),
                "hop_length": int(metadata["hop_length"]),
                "n_fft": int(metadata["n_fft"]),
                "n_mels": int(metadata["n_mels"]),
            },
            "context_recp": int(metadata["context_recp"]),
            "subsampling": int(metadata["subsampling"]),
            "feat_type": str(metadata["feat_type"]),
            "max_speakers": max_speakers,
        },
        "model": {
            "params": {
                "conv_delay": int(metadata["conv_delay"]),
            }
        },
    }
|
|
|
|
def extract_features(audio: np.ndarray, sample_rate: int, config: dict) -> np.ndarray:
    """Turn a waveform into subsampled, context-spliced float32 features.

    Processing steps:

    1. Resample to ``config["data"]["feat"]["sample_rate"]`` if the input
       rate differs (via ``librosa.resample``).
    2. Trim the tail to a whole number of ``hop_length * subsampling``
       samples. A waveform shorter than one such block is left untouched.
    3. Run the project-local ``feature`` pipeline: ``stft`` -> ``transform``
       -> ``splice`` -> ``subsample``.

    Args:
        audio: Waveform samples, presumably mono (see ``ensure_mono``).
        sample_rate: Sampling rate of ``audio`` in Hz.
        config: Nested config as built by ``config_from_metadata``.

    Returns:
        A freshly allocated float32 feature array.
    """
    data_cfg = config["data"]
    feat_cfg = data_cfg["feat"]

    target_rate = int(feat_cfg["sample_rate"])
    if sample_rate != target_rate:
        audio = librosa.resample(audio, orig_sr=sample_rate, target_sr=target_rate)

    hop = int(feat_cfg["hop_length"])
    factor = int(data_cfg["subsampling"])
    block = hop * factor
    # Presumably trimmed so the STFT frame count lines up with the
    # subsampling factor.
    whole_blocks_len = (len(audio) // block) * block
    if whole_blocks_len > 0:
        audio = audio[:whole_blocks_len]

    spectrum = ls_feature.stft(
        audio,
        frame_size=int(feat_cfg["win_length"]),
        frame_shift=hop,
    )
    features = ls_feature.transform(spectrum, str(data_cfg["feat_type"]))
    features = ls_feature.splice(features, int(data_cfg["context_recp"]))
    # The second positional argument of ``subsample`` is presumably a label
    # array; the features are reused as a placeholder and that output dropped.
    features, _ = ls_feature.subsample(features, features, factor)
    return np.array(features, copy=True).astype(np.float32, copy=False)
|
|
|
|
def frame_hz(config: dict) -> float:
    """Return the output frame rate in frames per second.

    One output frame spans ``hop_length * subsampling`` input samples at
    ``config["data"]["feat"]["sample_rate"]``.
    """
    data = config["data"]
    samples_per_frame = data["feat"]["hop_length"] * data["subsampling"]
    return data["feat"]["sample_rate"] / samples_per_frame
|
|
|
|
def parse_rttm(rttm_path: Path) -> tuple[list[dict], list[str]]:
    """Read the ``SPEAKER`` records of an RTTM file.

    Only lines whose first (type) field is ``SPEAKER`` are parsed. Blank
    lines, comment lines and every other RTTM record type are skipped; types
    such as ``SPKR-INFO`` carry ``<NA>`` in the onset/duration columns, which
    ``float()`` cannot parse.

    Args:
        rttm_path: Path to a UTF-8 encoded RTTM file.

    Returns:
        ``(entries, speaker_order)``.

        - ``entries``: one dict per SPEAKER line, in file order. Keys are
          ``recording_id`` (field 2), ``start`` and ``duration`` (fields 4
          and 5, as floats) and ``speaker`` (field 8).
        - ``speaker_order``: the distinct speaker names, in order of first
          appearance.
    """
    entries: list[dict] = []
    speaker_order: list[str] = []
    # Set for O(1) membership; the list keeps first-appearance order.
    seen_speakers: set[str] = set()
    with open(rttm_path, "r", encoding="utf-8") as handle:
        for line in handle:
            fields = line.split()
            if not fields or fields[0] != "SPEAKER":
                continue
            speaker = fields[7]
            if speaker not in seen_speakers:
                seen_speakers.add(speaker)
                speaker_order.append(speaker)
            entries.append(
                {
                    "recording_id": fields[1],
                    "start": float(fields[3]),
                    "duration": float(fields[4]),
                    "speaker": speaker,
                }
            )
    return entries, speaker_order
|
|
|
|
def rttm_to_frame_matrix(entries: list[dict], speakers: list[str], num_frames: int, frame_rate: float) -> np.ndarray:
    """Rasterise RTTM entries into a ``(num_frames, len(speakers))`` matrix.

    Each entry marks its speaker's column with ``1.0`` from
    ``round(start * frame_rate)`` up to, but excluding,
    ``round((start + duration) * frame_rate)``. Python's ``round`` rounds
    ties to even. Segments are clipped to ``[0, num_frames)``.

    Args:
        entries: Dicts with ``start``, ``duration`` (seconds) and ``speaker``,
            e.g. the first result of ``parse_rttm``.
        speakers: Column order; ``speakers[i]`` becomes column ``i``.
        num_frames: Number of rows in the output.
        frame_rate: Frames per second.

    Returns:
        float32 0/1 activity matrix.

    Raises:
        KeyError: if an entry's speaker is not in ``speakers``.
    """
    matrix = np.zeros((num_frames, len(speakers)), dtype=np.float32)
    speaker_to_index = {speaker: index for index, speaker in enumerate(speakers)}
    for entry in entries:
        start = int(round(entry["start"] * frame_rate))
        stop = int(round((entry["start"] + entry["duration"]) * frame_rate))
        column = speaker_to_index[entry["speaker"]]
        # Clamp both ends: a negative index would wrap around to the end of
        # the matrix instead of meaning "before the first frame".
        start = max(start, 0)
        stop = min(stop, num_frames)
        if stop > start:
            matrix[start:stop, column] = 1.0
    return matrix
|
|
|
|
def collar_mask(reference: np.ndarray, collar_frames: int) -> np.ndarray:
    """Return a boolean mask of frames still scored after applying a collar.

    A boundary ``b`` is any frame index at which some reference speaker
    switches on or off, between frame ``b - 1`` and frame ``b``. This
    includes a speaker already active at frame 0 or still active at the last
    frame. For every boundary, frames ``b - collar_frames`` through
    ``b + collar_frames - 1`` are excluded, clipped to the recording.

    Args:
        reference: ``(frames, speakers)`` reference activity matrix.
        collar_frames: Collar half-width in frames; ``<= 0`` disables it.

    Returns:
        Boolean vector of length ``reference.shape[0]``; ``True`` marks a
        scored frame.
    """
    num_frames = reference.shape[0]
    keep = np.ones(num_frames, dtype=bool)
    if collar_frames <= 0:
        return keep
    # Inactive rows before and after make onsets at frame 0 and offsets at
    # the final frame appear as changes, like interior boundaries.
    bordered = np.pad(reference, ((1, 1), (0, 0)), constant_values=0)
    # Row b of the time-diff compares frames b-1 and b. Any speaker changing
    # there makes b a boundary; the union over speakers equals masking each
    # speaker's boundaries separately.
    boundaries = np.flatnonzero((np.diff(bordered, axis=0) != 0).any(axis=1))
    for boundary in boundaries:
        lo = max(0, boundary - collar_frames)
        hi = min(num_frames, boundary + collar_frames)
        keep[lo:hi] = False
    return keep
|
|
|
|
def _pair_cost(pred_column: np.ndarray, ref_column: np.ndarray) -> float:
    """Assignment cost of pairing one predicted track with one reference track.

    The cost is the negated number of frames where both tracks are active.
    Any nonzero value counts as active.

    Why negated overlap: under the frame-level accounting in ``compute_der``,
    the DER numerator is ``sum_t max(n_ref(t), n_sys(t))`` minus the number
    of correctly attributed (frame, speaker) pairs. The first term does not
    depend on the mapping, because unmatched predictions are still scored.
    So the mapping that minimises DER is exactly the one that maximises total
    overlap. A cost built from whole-track totals, such as
    ``max(|ref|, |sys|) - overlap``, is not separable this way and can choose
    a mapping with strictly higher DER.

    Args:
        pred_column: 1-D predicted activity track.
        ref_column: 1-D reference activity track of the same length.

    Returns:
        ``-overlap`` as a Python float (``<= 0``; lower is a better match).
    """
    overlap = int(np.count_nonzero(np.logical_and(pred_column.astype(bool), ref_column.astype(bool))))
    return float(-overlap)
|
|
|
|
def map_predictions(
    prediction_binary: np.ndarray,
    reference_binary: np.ndarray,
    valid_mask: np.ndarray,
) -> tuple[np.ndarray, dict[int, int], list[int]]:
    """Map predicted speaker columns onto reference speaker columns.

    Costs come from ``_pair_cost``, evaluated only on the frames where
    ``valid_mask`` is True and collected into a float32 matrix. The mapping
    is then solved as a linear assignment problem with
    ``scipy.optimize.linear_sum_assignment``.

    Args:
        prediction_binary: ``(frames, num_pred)`` predicted activity.
        reference_binary: ``(frames, num_ref)`` reference activity.
        valid_mask: Boolean frame mask, e.g. from ``collar_mask``.

    Returns:
        ``(mapped, assignment, unmatched)``:

        - ``mapped``: float32 ``(frames, num_ref)``. Column ``r`` holds the
          prediction column assigned to reference ``r``, or zeros if none is.
        - ``assignment``: ``{ref_index: pred_index}`` for every matched pair.
        - ``unmatched``: prediction indices left unassigned, ascending.
    """
    scored_pred = prediction_binary[valid_mask]
    scored_ref = reference_binary[valid_mask]
    total_frames = prediction_binary.shape[0]
    pred_count = prediction_binary.shape[1]
    ref_count = reference_binary.shape[1]

    mapped = np.zeros((total_frames, ref_count), dtype=np.float32)
    assignment: dict[int, int] = {}
    if pred_count == 0 or ref_count == 0:
        # Nothing to pair: every prediction column is unmatched.
        return mapped, assignment, list(range(pred_count))

    cost = np.array(
        [
            [_pair_cost(scored_pred[:, p], scored_ref[:, r]) for r in range(ref_count)]
            for p in range(pred_count)
        ],
        dtype=np.float32,
    )
    rows, cols = linear_sum_assignment(cost)
    pred_rows = rows.tolist()
    for p, r in zip(pred_rows, cols.tolist()):
        mapped[:, r] = prediction_binary[:, p]
        assignment[r] = p

    matched = set(pred_rows)
    unmatched = [p for p in range(pred_count) if p not in matched]
    return mapped, assignment, unmatched
|
|
|
|
def compute_der(
    probabilities: np.ndarray,
    reference_binary: np.ndarray,
    threshold: float,
    median_width: int,
    collar_seconds: float,
    frame_rate: float,
) -> dict:
    """Score frame-level speaker probabilities against a reference.

    Pipeline:

    1. A frame is active when its probability is strictly greater than
       ``threshold``.
    2. If ``median_width > 1``, each speaker track is median-filtered along
       time.
    3. Frames within the collar of a reference boundary are excluded.
    4. Predicted speakers are mapped onto reference speakers with
       ``map_predictions``.
    5. Miss, false-alarm and speaker-confusion counts are accumulated over
       the remaining frames.

    Predictions left unmapped are appended as extra columns with an all-zero
    reference, so their activity is still counted.

    Args:
        probabilities: ``(frames, predicted speakers)`` scores.
        reference_binary: ``(frames, reference speakers)`` 0/1 activity.
        threshold: Activation threshold (strict ``>``).
        median_width: Median-filter length in frames; only applied when
            ``> 1``. ``scipy.signal.medfilt`` requires it to be odd.
        collar_seconds: Collar half-width in seconds.
        frame_rate: Frames per second; converts the collar to frames.

    Returns:
        Dict containing:

        - ``der``: 0.0 when no reference speech is scored.
        - The raw counts: ``speaker_scored``, ``speaker_miss``,
          ``speaker_false_alarm``, ``speaker_error``.
        - The scoring parameters echoed back.
        - Intermediates: ``mapped_binary``, ``mapped_probabilities``,
          ``valid_mask``, ``assignment``, ``unmatched_prediction_indices``.
    """
    # -- binarise and optionally smooth along time --------------------------
    prediction_binary = (probabilities > threshold).astype(np.float32)
    if median_width > 1:
        smoothed = medfilt(prediction_binary, kernel_size=(median_width, 1))
        prediction_binary = smoothed.astype(np.float32)

    # -- collar and speaker mapping -----------------------------------------
    collar_frames = int(round(collar_seconds * frame_rate))
    valid_mask = collar_mask(reference_binary, collar_frames)
    mapped_binary, assignment, unmatched_pred = map_predictions(
        prediction_binary, reference_binary, valid_mask
    )

    num_frames = probabilities.shape[0]
    mapped_probabilities = np.zeros((num_frames, reference_binary.shape[1]), dtype=np.float32)
    if assignment:
        ref_columns = list(assignment.keys())
        pred_columns = list(assignment.values())
        mapped_probabilities[:, ref_columns] = probabilities[:, pred_columns]

    # -- unmatched predictions get empty reference columns ------------------
    if unmatched_pred:
        extra_binary = prediction_binary[:, unmatched_pred]
    else:
        extra_binary = np.zeros((prediction_binary.shape[0], 0), dtype=np.float32)
    empty_reference = np.zeros((reference_binary.shape[0], extra_binary.shape[1]), dtype=np.float32)
    scored_reference = np.concatenate([reference_binary, empty_reference], axis=1)
    scored_prediction = np.concatenate([mapped_binary, extra_binary], axis=1)

    # -- frame-level error accounting over non-collar frames ----------------
    ref_frames = scored_reference[valid_mask]
    sys_frames = scored_prediction[valid_mask]
    n_ref = ref_frames.sum(axis=1)
    n_sys = sys_frames.sum(axis=1)
    miss = np.maximum(n_ref - n_sys, 0).sum()
    false_alarm = np.maximum(n_sys - n_ref, 0).sum()
    correct = np.logical_and(ref_frames == 1, sys_frames == 1).sum(axis=1)
    speaker_error = (np.minimum(n_ref, n_sys) - correct).sum()
    speaker_scored = ref_frames.sum()

    if speaker_scored:
        der = float((miss + false_alarm + speaker_error) / speaker_scored)
    else:
        der = 0.0

    return {
        "der": der,
        "speaker_scored": float(speaker_scored),
        "speaker_miss": float(miss),
        "speaker_false_alarm": float(false_alarm),
        "speaker_error": float(speaker_error),
        "threshold": threshold,
        "median_width": median_width,
        "collar_seconds": collar_seconds,
        "mapped_binary": mapped_binary,
        "mapped_probabilities": mapped_probabilities,
        "valid_mask": valid_mask,
        "assignment": assignment,
        "unmatched_prediction_indices": unmatched_pred,
    }
|
|
|
|
def write_rttm(
    recording_id: str,
    binary_prediction: np.ndarray,
    output_path: Path,
    frame_rate: float,
    speaker_labels: Iterable[str] | None = None,
) -> None:
    """Write frame-level speaker activity as RTTM ``SPEAKER`` records.

    Each run of active frames in column ``i`` of ``binary_prediction``
    becomes one line of the form::

        SPEAKER <recording_id> 1 <onset> <duration> <NA> <NA> <label> <NA> <NA>

    Onset and duration are in seconds (frame index / ``frame_rate``) with
    three decimals. Lines are grouped by column in column order, then sorted
    by time.

    Args:
        recording_id: File identifier written in the second field.
        binary_prediction: ``(frames, speakers)`` activity matrix, expected
            to contain only 0/1 (or bool) values.
        output_path: Destination file (``str`` or ``Path``). It is
            overwritten if it exists; missing parent directories are created.
        frame_rate: Frames per second.
        speaker_labels: One label per column. Any iterable is accepted,
            including a NumPy array. When ``None`` or empty, labels default to
            ``spk00``, ``spk01``, ... Only the first ``len(labels)`` columns
            are written.

    Raises:
        IndexError: if more labels than columns are supplied.
    """
    # ``speaker_labels or [...]`` would call ``bool()`` on the argument, which
    # raises for multi-element NumPy arrays; check for ``None`` explicitly.
    labels = [] if speaker_labels is None else list(speaker_labels)
    if not labels:
        labels = [f"spk{index:02d}" for index in range(binary_prediction.shape[1])]

    records = []
    for speaker_index, speaker in enumerate(labels):
        # Pad with an inactive frame on each side so every run has a rising
        # and a falling edge; for 0/1 input the edges alternate start, stop.
        padded = np.pad(binary_prediction[:, speaker_index], (1, 1), constant_values=0)
        edges = np.flatnonzero(np.diff(padded) != 0)
        for start, stop in zip(edges[::2], edges[1::2]):
            start_seconds = start / frame_rate
            duration_seconds = (stop - start) / frame_rate
            records.append(
                f"SPEAKER {recording_id} 1 {start_seconds:.3f} {duration_seconds:.3f} <NA> <NA> {speaker} <NA> <NA>\n"
            )

    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as handle:
        handle.writelines(records)
|
|
|
|
def save_heatmap(
    reference_binary: np.ndarray,
    mapped_binary: np.ndarray,
    mapped_probabilities: np.ndarray,
    frame_rate: float,
    speaker_labels: list[str],
    output_path: Path,
) -> None:
    """Render reference and mapped prediction as three stacked heatmaps.

    Panels, top to bottom:

    - reference activity (grey scale);
    - mapped binary prediction (grey scale);
    - mapped probability (viridis, with a colorbar).

    All panels use a fixed 0..1 colour range. Rows are speakers, with y-ticks
    labelled by ``speaker_labels``. The shared x-axis spans
    ``0 .. reference_binary.shape[0] / frame_rate`` seconds. The image is
    saved at 200 dpi to ``output_path``, and missing parent directories are
    created.

    The figure is always closed, even if drawing or saving raises, so
    repeated calls do not accumulate open pyplot figures.
    """
    total_seconds = reference_binary.shape[0] / frame_rate
    panels = (
        (reference_binary, "Expected (RTTM)", "Greys"),
        (mapped_binary, "Predicted (Mapped, Binary)", "Greys"),
        (mapped_probabilities, "Predicted (Mapped, Probability)", "viridis"),
    )
    figure, panel_axes = plt.subplots(len(panels), 1, figsize=(16, 8), sharex=True, constrained_layout=True)
    try:
        tick_positions = range(len(speaker_labels))
        for ax, (data, title, colormap) in zip(panel_axes, panels):
            rendered = ax.imshow(
                data.T,
                aspect="auto",
                origin="lower",
                interpolation="nearest",
                # Rows centred on integer speaker indices; x in seconds.
                extent=[0.0, total_seconds, -0.5, data.shape[1] - 0.5],
                cmap=colormap,
                vmin=0.0,
                vmax=1.0,
            )
            ax.set_title(title)
            ax.set_yticks(tick_positions)
            ax.set_yticklabels(speaker_labels)
            if colormap != "Greys":
                figure.colorbar(rendered, ax=ax, fraction=0.02, pad=0.01)
        panel_axes[-1].set_xlabel("Time (seconds)")
        output_path.parent.mkdir(parents=True, exist_ok=True)
        figure.savefig(output_path, dpi=200)
    finally:
        plt.close(figure)
|
|
|
|
def save_json(payload: dict, output_path: Path) -> None:
    """Write ``payload`` to ``output_path`` as UTF-8 JSON.

    The output is indented by 2 spaces with keys sorted. Missing parent
    directories are created, and an existing file is overwritten.
    """
    parent_dir = output_path.parent
    parent_dir.mkdir(exist_ok=True, parents=True)
    with output_path.open("w", encoding="utf-8") as stream:
        json.dump(payload, stream, sort_keys=True, indent=2)
|
|