| import os |
| import random |
| import torch |
| import torch.nn as nn |
| import numpy as np |
| import matplotlib.pyplot as plt |
| import time |
| from scipy.signal import find_peaks |
| import wandb |
| from tqdm import tqdm |
| import concurrent.futures |
| from typing import List, Sequence, Union |
| import time |
| from memory_profiler import profile |
|
|
| |
| |
| _dp_matrix_out_dir = None |
|
|
| def set_dp_matrix_out_dir(path): |
| global _dp_matrix_out_dir |
| _dp_matrix_out_dir = path |
|
|
| timit_leehon_39_phonemes = [ |
| 'ao', 'ae', 'ah','aw', 'er', 'ay', |
| 'b', 'sil', 'ch', 'd', 'dh', 'dx', 'eh', 'el', 'm', 'en', 'ng', 'ey', |
| 'f', 'g', 'hh', 'ih', 'iy', 'jh', 'k', 'v', 'w', 'y', 'z', 'sh', 't', 'r', 's', 'th','uh', 'uw', 'oy', 'ow','p' |
| ] |
|
|
| timit_61_phonemes = [ |
| 'aa', 'ae', 'ah', 'ao', 'aw', 'ax', 'ax-h', 'axr', 'ay', |
| 'b', 'bcl', 'ch', 'd', 'dcl', 'dh', 'dx', 'eh', 'el', 'em', 'en', 'eng', 'epi', 'er', 'ey', |
| 'f', 'g', 'gcl', 'h#', 'hh', 'hv', 'ih', 'ix', 'iy', 'jh', 'k', 'kcl', 'm', 'n', 'ng', 'l', |
| 'nx', 'ow', 'oy', 'p', 'pau', 'pcl', 'q', 'r', 's', 'sh', 't', 'tcl', 'th', 'uh', 'uw','ux', |
| 'v', 'w', 'y', 'z', 'zh' |
| ] |
|
|
| |
| |
| phoneme_to_idx_MACRO = {phoneme: idx for idx, phoneme in enumerate(timit_leehon_39_phonemes)} |
|
|
| idx_to_phoneme_MACRO = {idx: phoneme for phoneme, idx in phoneme_to_idx_MACRO.items()} |
| timit_to_leehon_map_MACRO = { |
| 'aa': 'ao', 'ae': 'ae', 'ah': 'ah', 'ao': 'ao', 'aw': 'aw', 'ax': 'ah', 'ax-h': 'ah', 'axr': 'er', 'ay': 'ay', |
| 'b': 'b', 'bcl': 'sil', 'ch': 'ch', 'd': 'd', 'dcl': 'sil', 'dh': 'dh', 'dx': 'dx', 'eh': 'eh', 'el': 'el', |
| 'em': 'm', 'en': 'en', 'eng': 'ng', 'epi': 'sil', 'er': 'er', 'ey': 'ey', 'f': 'f', 'g': 'g', 'gcl': 'sil', |
| 'h#': 'sil', 'hh': 'hh', 'hv': 'hh', 'ih': 'ih', 'ix': 'ih', 'iy': 'iy', 'jh': 'jh', 'k': 'k', 'kcl': 'sil', |
| 'l': 'el', 'm': 'm', 'n': 'en', 'ng': 'ng', 'nx': 'en', 'ow': 'ow', 'oy': 'oy', 'p': 'p', 'pau': 'sil', 'pcl': 'sil', |
| 'q': 't', 'qcl': 'sil', 'r': 'r', 's': 's', 'sh': 'sh', 't': 't', 'tcl': 'sil', 'th': 'th', 'uh': 'uh', 'uw': 'uw', |
| 'ux': 'uw', 'v': 'v', 'w': 'w', 'y': 'y', 'z': 'z', 'zh': 'sh', |
| } |
|
|
| def create_truth_probs_real(segments, phonemes, phoneme_to_index, num_frames): |
| segments = [0] + list(segments) |
| num_phonemes = len(phoneme_to_index) |
| probs_real = torch.zeros((num_frames, num_phonemes), dtype=torch.float32) |
| for seg_idx in range(len(phonemes)): |
| start = int(segments[seg_idx]) |
| end = int(segments[seg_idx + 1]) if seg_idx + 1 < len(segments) else num_frames |
| if end > start: |
| ph_label = phonemes[seg_idx].lower() |
| ph_index = phoneme_to_index.get(ph_label, phoneme_to_index.get('sil', 0)) |
| probs_real[start:end, ph_index] = 1.0 |
| return probs_real |
|
|
|
|
| |
|
|
| |
| def phoneme_alignment_Hard_DP(p_seq, w_phi, original_lengths, len_ratio, derivative_preds_np, probs_real): |
| |
| gamma = 1e-20 |
| T = int(original_lengths[0]) |
| n = len(p_seq) |
| device = derivative_preds_np.device |
|
|
| if isinstance(probs_real, np.ndarray): |
| probs_real = torch.tensor(probs_real, device=device) |
| cumsum_probs = torch.cumsum(probs_real, dim=0) |
| |
| phoneme_mappings = {p.lower(): timit_to_leehon_map_MACRO.get(p.lower(), 'sil') if p.lower() not in timit_leehon_39_phonemes else p.lower() for p in p_seq} |
| derivatives = torch.cat([torch.tensor([0], device=device), torch.diff(derivative_preds_np, dim=0)]) |
|
|
| |
| dp_mat = torch.full((n, T, T), float(-1e9), device=device) |
| p_idx0 = phoneme_to_idx_MACRO[phoneme_mappings[p_seq[0].lower()]] |
|
|
| |
| t_e = torch.arange(T, device=device) |
| dp_mat[0, 0, :] = ( |
| w_phi[0] * compute_phi_1(derivatives, 0, t_e) |
| ) |
|
|
| |
| for i in tqdm(range(1, n)): |
| p_idx = phoneme_to_idx_MACRO[phoneme_mappings[p_seq[i].lower()]] |
| t_start = torch.arange(T, device=device) |
| t_end = torch.arange(T, device=device) |
| t_start_grid, t_end_grid = torch.meshgrid(t_start, t_end, indexing='ij') |
| valid_mask = t_start_grid < t_end_grid |
| |
| phi1_dev = compute_phi_1(derivatives, t_start_grid, t_end_grid) |
| phi2 = compute_phi_2(cumsum_probs, p_idx, t_start_grid, t_end_grid) |
| total_phi = w_phi[0] * phi1_dev |
| |
| prev_scores = torch.full((T, T), float(-1e9), device=device) |
| |
| for t_end_val in range(T): |
| valid_starts = t_start[t_start < t_end_val] |
| if valid_starts.numel() == 0: |
| continue |
| |
| |
| |
| prev = dp_mat[i-1, :valid_starts[-1]+1, valid_starts] |
| max_prev, _ = torch.max(prev, dim=0) |
| prev_scores[valid_starts, t_end_val] = max_prev |
|
|
| dp_mat[i] = torch.where(valid_mask, total_phi + prev_scores, torch.full_like(total_phi, float(-1e9))) |
|
|
| |
| best_start_times = torch.zeros((n), dtype=derivative_preds_np.dtype, device=device) |
| best_prev_t_end = T - 1 |
| |
| for i in range(n): |
| cur_ph = n - 1 - i |
| |
| scores = dp_mat[cur_ph, :, best_prev_t_end] |
| |
| |
| |
| best_t_start = torch.argmax(scores) |
| |
| best_start_times[cur_ph] = best_t_start.to(derivative_preds_np.dtype) |
| best_prev_t_end = int(best_t_start.item()) |
| |
| |
| dp_mat_cpu = dp_mat.detach().cpu() |
| best_start_times_cpu = best_start_times.detach().cpu().numpy() |
| dp_to_plot = dp_mat_cpu.max(dim=1)[0].numpy() |
| masked_dp = np.ma.masked_where(dp_to_plot <= -1e8, dp_to_plot) |
|
|
| plt.figure(figsize=(12, 6)) |
| cmap = plt.cm.viridis |
| cmap.set_bad(color='white') |
| plt.imshow(masked_dp, aspect='auto', origin='lower', cmap=cmap) |
| plt.colorbar(label='Hard DP Score') |
| plt.xlabel('End time (frame)') |
| plt.ylabel('Phoneme index') |
| plt.title('Classic (Hard) DP Matrix with Best Path') |
| plt.plot(best_start_times_cpu, range(len(best_start_times_cpu)), 'r.-', label='Argmax path') |
| plt.legend() |
| plt.tight_layout() |
| plt.savefig('dp_matrix_hard_classic.png') |
| plt.close() |
| |
| return best_start_times |
|
|
| |
|
|
| from scipy.signal import find_peaks |
|
|
| def phoneme_alignment_naive_peak_detection(p_seq, w_phi, original_lengths, len_ratio, derivative_preds_np, probs_real): |
| """ |
| Ablation version: Replaces DP with Naive Scipy Peak Detection. |
| """ |
| gamma = 1e-20 |
| T = int(original_lengths[0]) |
| n = len(p_seq) |
| device = derivative_preds_np.device |
|
|
| |
| if isinstance(probs_real, np.ndarray): |
| probs_real = torch.tensor(probs_real, device=device) |
| |
| |
| |
| cumsum_probs = torch.cumsum(probs_real, dim=0) |
| signal = derivative_preds_np.detach().cpu().numpy().flatten() |
|
|
| |
| |
| peaks, properties = find_peaks(signal, prominence=0.05) |
| |
| peak_heights = signal[peaks] |
| |
| |
| top_indices = np.argsort(peak_heights)[-n:] |
| best_peaks = np.sort(peaks[top_indices]) |
| if len(best_peaks) < n: |
| filler = np.linspace(0, T-1, n) |
| best_peaks = filler |
| best_start_times = torch.tensor(best_peaks, dtype=derivative_preds_np.dtype, device=device) |
|
|
| |
| dp_mat = torch.full((n, T, T), float(-1e9), device=device) |
| for i, peak_time in enumerate(best_peaks): |
| dp_mat[i, :, int(peak_time)] = 1.0 |
|
|
| |
| dp_mat_cpu = dp_mat.detach().cpu() |
| best_start_times_cpu = best_start_times.detach().cpu().numpy() |
| dp_to_plot = dp_mat_cpu.max(dim=1)[0].numpy() |
| |
| masked_dp = np.ma.masked_where(dp_to_plot <= -1e8, dp_to_plot) |
|
|
| plt.figure(figsize=(12, 6)) |
| cmap = plt.cm.viridis |
| cmap.set_bad(color='white') |
| |
| plt.imshow(masked_dp, aspect='auto', origin='lower', cmap=cmap) |
| plt.colorbar(label='Peak Detection (Naive)') |
| plt.xlabel('End time (frame)') |
| plt.ylabel('Phoneme index') |
| plt.title('Naive Peak Detection (Ablation)') |
| plt.plot(best_start_times_cpu, range(len(best_start_times_cpu)), 'r.-', label='Detected Peaks') |
| plt.legend() |
| plt.tight_layout() |
| save_path = 'peak_detection_ablation.png' |
| plt.savefig(save_path) |
| plt.close() |
| |
| print(f"Ablation plot saved as {save_path}") |
|
|
| return best_start_times |
|
|
| |
| def phoneme_alignment(p_seq, w_phi, original_lengths, len_ratio, derivative_preds_np, probs_real): |
| gamma = 1e-20 |
| T = int(original_lengths[0]) |
| n = len(p_seq) |
| device = derivative_preds_np.device |
|
|
| if isinstance(probs_real, np.ndarray): |
| probs_real = torch.tensor(probs_real, device=device) |
| cumsum_probs = torch.cumsum(probs_real, dim=0) |
| |
| |
| |
| phoneme_mappings = {p.lower(): timit_to_leehon_map_MACRO.get(p.lower(), 'sil') if p.lower() not in timit_leehon_39_phonemes else p.lower() for p in p_seq} |
| derivatives = torch.cat([torch.tensor([0], device=derivative_preds_np.device), torch.diff(derivative_preds_np, dim=0)]) |
|
|
| dp_mat = torch.full((n, T, T), float(-1e9), device=device) |
|
|
| p_idx0 = phoneme_to_idx_MACRO[phoneme_mappings[p_seq[0].lower()]] |
|
|
| |
| t_e = torch.arange(T, device=device) |
| dp_mat[0, 0, :] = ( |
| w_phi[0] * compute_phi_1(derivatives, 0, t_e) |
| + w_phi[1] * compute_phi_1(derivatives, 0, t_e) |
|
|
| ) |
|
|
| for i in tqdm(range(1, n)): |
| p_idx = phoneme_to_idx_MACRO[phoneme_mappings[p_seq[i].lower()]] |
| |
| t_start = torch.arange(T, device=device) |
| t_end = torch.arange(T, device=device) |
| t_start_grid, t_end_grid = torch.meshgrid(t_start, t_end, indexing='ij') |
| valid_mask = t_start_grid < t_end_grid |
|
|
| phi1_dev = compute_phi_1(derivatives, t_start_grid, t_end_grid) |
| phi1 = compute_phi_1(derivative_preds_np, t_start_grid, t_end_grid) |
| phi2 = compute_phi_2(cumsum_probs, p_idx, t_start_grid, t_end_grid) |
| total_phi = w_phi[0] * phi1_dev + w_phi[1] * phi2 |
|
|
| |
| |
| |
| |
| |
| |
| col_lse = torch.logsumexp(dp_mat[i-1] / gamma, dim=0) * gamma |
| prev_scores = torch.where( |
| valid_mask, |
| col_lse.unsqueeze(1).expand(T, T), |
| torch.full((T, T), float(-1e9), device=device), |
| ) |
| dp_mat[i] = torch.where(valid_mask, total_phi + prev_scores, torch.full_like(total_phi, float(-1e9))) |
|
|
| |
| best_start_times = torch.zeros((n), dtype=derivative_preds_np.dtype, device=device) |
| best_prev_t_end = T-1 |
| for i in range(n): |
| cur_ph = n-1-i |
| scores = dp_mat[cur_ph, :, best_prev_t_end] |
| soft_weights = torch.softmax(scores / gamma, dim=0) |
| expected_idx = (soft_weights * torch.arange(T, device=device, dtype=derivative_preds_np.dtype)).sum() |
| best_start_times[cur_ph] = expected_idx |
| best_prev_t_end = int(expected_idx.round().item()) |
| |
| |
| |
| |
| if _dp_matrix_out_dir is not None: |
| dp_mat_cpu = dp_mat.detach().cpu() |
| best_start_times_cpu = best_start_times.detach().cpu().numpy() |
| dp_to_plot = dp_mat_cpu.max(dim=1)[0].numpy() |
| masked_dp = np.ma.masked_where(dp_to_plot <= -1e8, dp_to_plot) |
|
|
| plt.figure(figsize=(12, 6)) |
| cmap = plt.cm.viridis |
| cmap.set_bad(color='white') |
| real_min = masked_dp.min() |
| real_max = masked_dp.max() |
| |
| plt.imshow(masked_dp, aspect='auto', origin='lower', cmap=cmap, vmin=real_min, vmax=real_max) |
| plt.colorbar(label='DP Score (max over start)') |
| plt.xlabel('End time (frame)') |
| plt.ylabel('Phoneme index') |
| plt.title('DP Matrix with Best Path') |
| |
| plt.plot(best_start_times_cpu, range(len(best_start_times_cpu)), 'r.-', label='Best start times') |
| plt.legend() |
| plt.tight_layout() |
| _save_path = os.path.join(_dp_matrix_out_dir or '.', 'dp_matrix_with_path.png') |
| plt.savefig(_save_path) |
| print(f"DP matrix with path plot saved as {_save_path}") |
| |
| return best_start_times |
|
|
| def compute_phi_1(derivative_preds_np: torch.Tensor, t_start: Union[torch.Tensor, int], t_end: Union[torch.Tensor, int]) -> torch.Tensor: |
| """ |
| Computes phi_1 for dynamic programming. |
| t_start and t_end can be scalars or tensors of the same shape. |
| Returns a tensor of scores. |
| """ |
| |
| t_start = torch.as_tensor(t_start, device=derivative_preds_np.device) |
| t_end = torch.as_tensor(t_end, device=derivative_preds_np.device) |
| |
| t_start, t_end = torch.broadcast_tensors(t_start, t_end) |
| |
| valid = (t_end < derivative_preds_np.shape[0]-1) & (t_start < derivative_preds_np.shape[0]-1) & (t_end > 0) & (t_start > 0) |
| score = torch.zeros_like(t_start, dtype=derivative_preds_np.dtype, device=derivative_preds_np.device) |
|
|
| eps = 1e-6 |
| tanh_scale = 1e-3 |
| if valid.any(): |
| |
| |
| idx_s = t_start[valid].long() |
| s_center = torch.tanh(tanh_scale * derivative_preds_np[idx_s]) |
| s_prev = torch.tanh(tanh_scale * derivative_preds_np[idx_s -1]) |
| s_next = torch.tanh(tanh_scale * derivative_preds_np[idx_s +1]) |
| delta_prev_s = s_center - s_prev |
| delta_next_s = s_center - s_next |
| scores_zerocross_s = (1-torch.sqrt(s_center**2)) + torch.sqrt(delta_prev_s **2 + eps) + torch.sqrt(delta_next_s**2 + eps) |
| |
| score[valid] += scores_zerocross_s |
|
|
| |
| idx_e = t_end[valid].long() |
| e_center = torch.tanh(tanh_scale * derivative_preds_np[idx_e]) |
| e_prev = torch.tanh(tanh_scale * derivative_preds_np[idx_e -1]) |
| e_next = torch.tanh(tanh_scale * derivative_preds_np[idx_e +1]) |
| delta_prev_e = e_center - e_prev |
| delta_next_e = e_center - e_next |
| scores_zerocross_e = (1-torch.sqrt(e_center**2)) + torch.sqrt(delta_prev_e **2 + eps) + torch.sqrt(delta_next_e**2 + eps) |
| |
| score[valid] += scores_zerocross_e |
| return score |
|
|
| def compute_phi_2(cumsum_probs: torch.Tensor, p: int, t_start: Union[torch.Tensor, int], t_end: Union[torch.Tensor, int]) -> torch.Tensor: |
| """ |
| Computes phi_2 for dynamic programming. |
| t_start and t_end can be scalars or tensors of the same shape. |
| Returns a tensor of scores. |
| """ |
| t_start = torch.as_tensor(t_start, device=cumsum_probs.device) |
| t_end = torch.as_tensor(t_end, device=cumsum_probs.device) |
| t_start, t_end = torch.broadcast_tensors(t_start, t_end) |
| |
| valid = (t_end < cumsum_probs.shape[0]) & (t_start < cumsum_probs.shape[0]) & (t_end > 0) & (t_start >= 0) |
| probs_score = torch.zeros_like(t_start, dtype=cumsum_probs.dtype, device=cumsum_probs.device) |
| |
| probs_score[valid] = cumsum_probs[t_end[valid], p] - torch.where( |
| t_start[valid] > 0, |
| cumsum_probs[t_start[valid], p], |
| torch.zeros_like(t_start[valid], dtype=cumsum_probs.dtype, device=cumsum_probs.device) |
| ) |
| lengths = (t_end - t_start).clamp(min=1) |
| probs_score[valid] = probs_score[valid] / lengths[valid] |
| return (probs_score) |
|
|
| def best_phoneme_for_segments(cumsum_probs: torch.Tensor, t_start: torch.Tensor, t_end: torch.Tensor): |
| """ |
| For each (t_start, t_end) pair (tensors broadcasted to same shape), |
| compute the average probability per phoneme over the segment and return: |
| - max_vals: tensor of shape (pairs,) with the max average prob per pair |
| - max_idx: LongTensor of shape (pairs,) with the argmax phoneme index per pair |
| """ |
| device = cumsum_probs.device |
| t_start = torch.as_tensor(t_start, device=device) |
| t_end = torch.as_tensor(t_end, device=device) |
| t_start, t_end = torch.broadcast_tensors(t_start, t_end) |
|
|
| valid = (t_end < cumsum_probs.shape[0]) & (t_start < cumsum_probs.shape[0]) & (t_end > 0) & (t_start >= 0) |
| max_vals = torch.zeros_like(t_start, dtype=cumsum_probs.dtype, device=device) |
| max_idx = torch.full_like(t_start, -1, dtype=torch.long, device=device) |
|
|
| if not valid.any(): |
| return max_vals, max_idx |
|
|
| idx_end = t_end[valid].long() |
| idx_start = t_start[valid].long() |
| probs_end = cumsum_probs[idx_end] |
| probs_start = torch.zeros_like(probs_end) |
| nonzero_mask = idx_start > 0 |
| if nonzero_mask.any(): |
| probs_start[nonzero_mask] = cumsum_probs[idx_start[nonzero_mask]] |
|
|
| segment_sum = probs_end - probs_start |
| lengths = (t_end[valid] - t_start[valid]).clamp(min=1).unsqueeze(1).to(segment_sum.dtype) |
| segment_mean = segment_sum / lengths |
|
|
| vals, idxs = segment_mean.max(dim=1) |
| max_vals[valid] = vals |
| max_idx[valid] = idxs.long() |
| return max_vals, max_idx |
|
|
| def get_timit_61_phoneme_mappings(): |
| """ |
| Returns the TIMIT 61 phoneme-to-index mapping and the reverse index-to-phoneme mapping. |
| |
| Returns: |
| phoneme_to_idx (dict): Dictionary mapping phonemes to unique indices. |
| idx_to_phoneme (dict): Dictionary mapping indices to their corresponding phonemes. |
| """ |
| |
| timit_61_phonemes = [ |
| 'aa', 'ae', 'ah', 'ao', 'aw', 'ax', 'ax-h', 'axr', 'ay', |
| 'b', 'bcl', 'ch', 'd', 'dcl', 'dh', 'dx', 'eh', 'el', 'em', 'en', 'eng', 'epi', 'er', 'ey', |
| 'f', 'g', 'gcl', 'h#', 'hh', 'hv', 'ih', 'ix', 'iy', 'jh', 'k', 'kcl', 'l', 'm', 'n', 'ng', |
| 'nx', 'ow', 'oy', 'p', 'pau', 'pcl', 'q', 'r', 's', 'sh', 't', 'tcl', 'th', 'uh', 'uw', 'ux', |
| 'v', 'w', 'y', 'z', 'zh' |
| ] |
| timit_leehon_39_phonemes = [ |
| 'ao', 'ae', 'ah','aw', 'er', 'ay', |
| 'b', 'sil', 'ch', 'd', 'dh', 'dx', 'eh', 'el', 'm', 'en', 'ng', 'ey', |
| 'f', 'g', 'hh', 'ih', 'iy', 'jh', 'k', 'v', 'w', 'y', 'z', 'sh', 't', 'r', 's', 'th','uh', 'uw', 'oy', 'ow','p' |
| ] |
| |
| phoneme_to_idx = {phoneme: idx for idx, phoneme in enumerate(timit_leehon_39_phonemes)} |
| idx_to_phoneme = {idx: phoneme for phoneme, idx in phoneme_to_idx.items()} |
|
|
| return phoneme_to_idx, idx_to_phoneme |
|
|
| |
|
|
|
|
| def timit_to_leehon(timit_label): |
| |
| timit_to_leehon_map = { |
| 'aa': 'ao', 'ae': 'ae', 'ah': 'ah', 'ao': 'ao', 'aw': 'aw', 'ax': 'ah', 'ax-h': 'ah', 'axr': 'er', 'ay': 'ay', |
| 'b': 'b', 'bcl': 'sil', 'ch': 'ch', 'd': 'd', 'dcl': 'sil', 'dh': 'dh', 'dx': 'dx', 'eh': 'eh', 'el': 'el', |
| 'em': 'm', 'en': 'en', 'eng': 'ng', 'epi': 'sil', 'er': 'er', 'ey': 'ey', 'f': 'f', 'g': 'g', 'gcl': 'sil', |
| 'h#': 'sil', 'hh': 'hh', 'hv': 'hh', 'ih': 'ih', 'ix': 'ih', 'iy': 'iy', 'jh': 'jh', 'k': 'k', 'kcl': 'sil', |
| 'l': 'el', 'm': 'm', 'n': 'en', 'ng': 'ng', 'nx': 'en', 'ow': 'ow', 'oy': 'oy', 'p': 'p', 'pau': 'sil', 'pcl': 'sil', |
| 'q': 't', 'qcl': 'sil', 'r': 'r', 's': 's', 'sh': 'sh', 't': 't', 'tcl': 'sil', 'th': 'th', 'uh': 'uh', 'uw': 'uw', |
| 'ux': 'uw', 'v': 'v', 'w': 'w', 'y': 'y', 'z': 'z', 'zh': 'sh', '':'sil' |
| } |
| |
| |
| return timit_to_leehon_map.get(timit_label.lower(), None) |
|
|
| def load_phoneme_stats(): |
| phonemes_path = "phonemes_39" |
| stats_path = "phoneme_stats_39.out" |
| |
| |
| with open(phonemes_path, "r") as f: |
| phonemes = [line.strip() for line in f] |
|
|
| |
| with open(stats_path, "r") as f: |
| lines = f.readlines() |
| mu_values = list(map(float, lines[1].strip().split())) |
| sigma_values = list(map(float, lines[2].strip().split())) |
| |
| |
| phoneme_mu_dict = dict(zip(phonemes, mu_values)) |
| phoneme_sigma_dict = dict(zip(phonemes, sigma_values)) |
| |
| return phoneme_mu_dict, phoneme_sigma_dict |
|
|
| |
|
|
|
|
| def get_mu_stats(p): |
| phoneme_mu_dict, _ = load_phoneme_stats() |
| """Return the mu value for the given phoneme p.""" |
| return phoneme_mu_dict.get(p, None) |
|
|
| def get_sigma_stats(p): |
| _, phoneme_sigma_dict = load_phoneme_stats() |
| """Return the mu value for the given phoneme p.""" |
| return phoneme_sigma_dict.get(p, None) |
|
|
| def replicate_first_k_frames(x, k, dim): |
| return torch.cat([x.index_select(dim=dim, index=torch.LongTensor([0] * k).to(x.device)), x], dim=dim) |
|
|
| class LambdaLayer(nn.Module): |
| def __init__(self, lambd): |
| super(LambdaLayer, self).__init__() |
| self.lambd = lambd |
| def forward(self, x): |
| return self.lambd(x) |
|
|
| class PrintShapeLayer(nn.Module): |
| def __init__(self): |
| super(PrintShapeLayer, self).__init__() |
| def forward(self, x): |
| print(x.shape) |
| return x |
|
|
| def length_to_mask(length, max_len=None, dtype=None): |
| """length: B. |
| return B x max_len. |
| If max_len is None, then max of length will be used. |
| """ |
| assert len(length.shape) == 1, 'Length shape should be 1 dimensional.' |
| max_len = max_len or length.max().item() |
| mask = torch.arange(max_len, device=length.device, |
| dtype=length.dtype).expand(len(length), max_len) < length.unsqueeze(1) |
| if dtype is not None: |
| mask = torch.as_tensor(mask, dtype=dtype, device=length.device) |
| return mask |
|
|
| def detect_peaks_worker(xi,w_phi, p_seq, original_lengths, probs_real, len_ratio, width, distance): |
| print(f"num peaks = {len(p_seq)}") |
| print(f"xi type: {type(xi)}") |
| preds_np = xi.requires_grad_(True) |
| median_h = preds_np.median() |
| preds_np = preds_np - median_h |
| derivative_preds_np = preds_np |
| xmin, xmax = xi.min(), xi.max() |
| xi = (xi - xmin) / (xmax - xmin) |
| xi = xi.flatten() |
| |
| peaks = phoneme_alignment(p_seq,w_phi, original_lengths, len_ratio, derivative_preds_np, probs_real) |
| |
| if len(peaks) == 0: |
| peaks = torch.tensor([xi.shape[0] - 1], device=xi.device) |
| |
| return peaks |
|
|
| def detect_peaks(x,w_phi, original_lengths_all, phonemes, len_ratio, probs_real_all): |
| """Detect peaks of next_frame_classifier using multithreading.""" |
| |
| out = [] |
| with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: |
| xi=x |
| p_seq = phonemes |
| original_lengths = original_lengths_all |
| probs_real = probs_real_all |
| if len(xi)!=0: |
| result = detect_peaks_worker(xi, w_phi, p_seq, [original_lengths], probs_real, len_ratio, width=None, distance=None) |
| out.append(result) |
| |
| return out |
|
|
| class PrecisionRecallMetric: |
| def __init__(self): |
| self.precision_counter = 0 |
| self.recall_counter = 0 |
| self.pred_counter = 0 |
| self.gt_counter = 0 |
| self.eps = 1e-5 |
| self.data = [] |
| self.tolerance = 2 |
| self.width_range = [1] |
| self.distance_range = [1] |
|
|
| def get_metrics(self, precision_counter, recall_counter, pred_counter, gt_counter): |
| EPS = 1e-7 |
| |
| precision = precision_counter / (pred_counter + self.eps) |
| recall = recall_counter / (gt_counter + self.eps) |
| f1 = 2 * (precision * recall) / (precision + recall + self.eps) |
| |
| os = recall / (precision + EPS) - 1 |
| r1 = np.sqrt((1 - recall) ** 2 + os ** 2) |
| r2 = (-os + recall - 1) / (np.sqrt(2)) |
| rval = 1 - (np.abs(r1) + np.abs(r2)) / 2 |
|
|
| return precision, recall, f1, rval |
|
|
| def zero(self): |
| self.data = [] |
| |
| def update(self, seg, pos_pred, length,original_lengths_all, probs_all,phonemes_all): |
| for seg_i, pos_pred_i, length_i , original_length, probs,phonemes in zip(seg, pos_pred, length,original_lengths_all,probs_all,phonemes_all): |
| self.data.append((seg_i, pos_pred_i, length_i.item(),[original_length.item()], probs, phonemes)) |
|
|
|
|
| def get_stats(self, width=None, distance=None): |
| print(f"calculating metrics using {len(self.data)} entries") |
| max_rval = -float("inf") |
| min_l1_dist = float("inf") |
| best_params = None |
| segs = list(map(lambda x: x[0], self.data)) |
| length = list(map(lambda x: x[2], self.data)) |
| yhats = list(map(lambda x: x[1], self.data)) |
| original_lengths_all = list(map(lambda x: x[3], self.data)) |
| probs = list(map(lambda x: x[4], self.data)) |
| phonemes = list(map(lambda x: x[5], self.data)) |
|
|
| width_range = self.width_range |
| distance_range = self.distance_range |
|
|
| if width is not None: |
| width_range = [width] |
| distance_range = [distance] |
| sr = 16000 |
| len_ratio = 161.34011627906978 |
|
|
| for width in width_range: |
| for distance in distance_range: |
| for (y, yhat,original_len, phoneme, prob) in zip(segs, yhats, original_lengths_all, phonemes, probs): |
| if isinstance(y,list): |
| y = torch.tensor(y, device=yhat.device, dtype=yhat.dtype) |
| peaks = detect_peaks(x=yhat,w_phi= [0.5,0.5], |
| original_lengths_all = original_len[0], |
| phonemes = phoneme, |
| len_ratio = 161.34011627906978 , |
| probs_real_all = prob) |
| peaks = peaks[0]* len_ratio/sr |
| yhat = peaks |
| yhat = yhat[1:] |
|
|
| if isinstance(y,list): |
| y = torch.tensor(y, device=yhat.device, dtype=yhat.dtype) |
| y = y*len_ratio/sr |
| l1_dist = torch.mean(torch.abs(y - yhat)).item() |
| l2_dist = torch.mean((y - yhat)**2).item() |
| if l1_dist<min_l1_dist: |
| min_l1_dist = l1_dist |
| out = (l1_dist,l2_dist) |
| best_params = width, distance |
| self.zero() |
| print(f"best peak detection params: {best_params} (width, distance)") |
| print(f"best peak detection L1_DIST: {l1_dist}") |
| print(f"best peak detection L2_DIST: {l2_dist}") |
| return out, best_params |
|
|
|
|
| class StatsMeter: |
| def __init__(self): |
| self.data = [] |
|
|
| def update(self, item): |
| if type(item) == list: |
| self.data.extend(item) |
| else: |
| self.data.append(item) |
|
|
| def get_stats(self): |
| data = np.array(self.data) |
| if len(data)==0: |
| return float('nan') |
| mean = data.mean() |
| return mean |
|
|
| def zero(self): |
| self.data.clear() |
| assert len(self.data) == 0, "StatsMeter didn't clear" |
|
|
|
|
| class Timer: |
| def __init__(self, msg): |
| self.msg = msg |
| self.start_time = None |
|
|
| def __enter__(self): |
| self.start_time = time.time() |
| print(f"{self.msg} -- started") |
|
|
| def __exit__(self, exc_type, exc_value, exc_tb): |
| print(f"{self.msg} -- done in {(time.time() - self.start_time)} secs") |
|
|
|
|
| def max_min_norm(x): |
| x -= x.min(-1, keepdim=True)[0] |
| x /= x.max(-1, keepdim=True)[0] |
| return x |
|
|
|
|
| def line(): |
| print(90 * "-") |