# SPDX-FileContributor: Karl El Hajal

import torch
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoFeatureExtractor, AutoModel
from scipy.spatial.distance import cdist
from dtw import accelerated_dtw

from src.audio_preprocessing import process_wav, get_red_green_segments


class PronunciationChecker:
    def __init__(self, model_name="microsoft/wavlm-large"):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model_name = model_name
        self.processor = AutoFeatureExtractor.from_pretrained(self.model_name)
        self.model = AutoModel.from_pretrained(self.model_name).eval().to(self.device)
        self.sr = 16000

        # Silero VAD is used to trim leading/trailing silences.
        torch.set_num_threads(1)
        self.vad_model, vad_utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', model='silero_vad')
        (self.get_speech_timestamps, _, self.read_audio, _, _) = vad_utils

    def trim_silences_with_silero_vad(self, audio_path):
        # Keep everything from the first detected speech onset to the last offset
        # (assumes the VAD finds at least one speech segment).
        wav = self.read_audio(audio_path)
        speech_timestamps = self.get_speech_timestamps(wav, self.vad_model)
        start_sample = speech_timestamps[0]['start']
        end_sample = speech_timestamps[-1]['end']
        wav = wav[start_sample:end_sample]
        return wav

    def preprocess_wav(self, wav_path, do_trim_silences=True):
        # Write a 16 kHz copy to a temporary file, then optionally trim silences.
        temp_audio_path = "temp.wav"
        audio_segment = process_wav(wav_path, 16000, do_trim_silences=False)
        audio_segment.export(temp_audio_path, format="wav")
        if do_trim_silences:
            wav = self.trim_silences_with_silero_vad(temp_audio_path)
        else:
            wav = self.read_audio(temp_audio_path)
        return wav, self.sr

    def extract_features(self, wav, layer=None):
        # The feature extractor expects host-side audio, so convert to a CPU numpy
        # array first; the extractor's outputs are moved to the model's device below.
        inputs = self.processor(wav.squeeze().cpu().numpy(), sampling_rate=16000, return_tensors="pt", padding=True)
        inputs = {key: val.to(self.device) for key, val in inputs.items()}
        with torch.no_grad():
            outputs = self.model(**inputs, output_hidden_states=True)
        if layer is None:
            features = outputs.last_hidden_state
        else:
            # hidden_states[0] is the initial embedding; subsequent entries are the
            # outputs of the successive transformer layers.
            features = outputs.hidden_states[layer]
        features = features.squeeze().cpu().numpy()
        return features, wav.squeeze().cpu().numpy(), 16000

    @staticmethod
    def compute_dtw(ref_features, input_features):
        # distance_metric = "euclidean"
        distance_metric = "cosine"
        # Raw frame-to-frame distance matrix, returned alongside the DTW warping path.
        dist_matrix = cdist(ref_features, input_features, metric=distance_metric)
        _, _, _, path = accelerated_dtw(ref_features, input_features, dist=distance_metric)
        return dist_matrix, path

    @staticmethod
    def calculate_red_percentage(red_segments, labels_data, scaling_factor):
        # For each labelled grapheme, compute the fraction of its duration that
        # overlaps with "red" (poorly matched) feature frames.
        red_percentages = []

        def intersection_length(start1, end1, start2, end2):
            overlap_start = max(start1, start2)
            overlap_end = min(end1, end2)
            return max(0, overlap_end - overlap_start)

        for start, end, grapheme, *boolean_labels in labels_data:
            red_intersection = 0.0
            for index in red_segments:
                red_start_time = index * scaling_factor
                red_end_time = (index + 1) * scaling_factor
                red_intersection += intersection_length(start, end, red_start_time, red_end_time)
            total_grapheme_duration = end - start
            red_percentage = red_intersection / total_grapheme_duration
            red_percentages.append(min(red_percentage, 1.))

        return red_percentages

    @staticmethod
    def plot_waveform_with_overlay(wav, sr, dist_matrix, path, wav_type='ref', threshold=0.3,
                                   labels_data=None, input_number=None):
        # WavLM emits one feature frame per 320 samples (20 ms at 16 kHz).
        feature_stride = 320
        scaling_factor = feature_stride / sr
        time_ref = np.linspace(0, len(wav) / sr, len(wav))

        fig, ax = plt.subplots(4, 1, figsize=(15, 12), gridspec_kw={'height_ratios': [5, 1, 1, 1]})
        ax[0].plot(time_ref, wav, label="Waveform", color="blue", alpha=0.7)

        red_segments, green_segments, _ = get_red_green_segments(dist_matrix, path, wav_type=wav_type,
                                                                 threshold=threshold)

        # Shade well-matched (green) and poorly matched (red) frames on the waveform.
        for index in green_segments:
            start_time = index * scaling_factor
            end_time = (index + 1) * scaling_factor
            ax[0].axvspan(start_time, end_time, facecolor=(0, 1, 0), alpha=0.5)
        for index in red_segments:
            start_time = index * scaling_factor
            end_time = (index + 1) * scaling_factor
            ax[0].axvspan(start_time, end_time, facecolor=(1, 0, 0), alpha=0.5)

        ax[0].set_xlabel("Time (s)")
        ax[0].set_ylabel("Amplitude")
        ax[0].set_title("Waveform with DTW Distance Overlay")
        ax[0].legend()

        # Keep all tracks aligned with the waveform's time axis.
        ax[1].set_xlim(ax[0].get_xlim())
        ax[2].set_xlim(ax[0].get_xlim())
        ax[3].set_xlim(ax[0].get_xlim())

        if labels_data:
            red_percentages = PronunciationChecker.calculate_red_percentage(red_segments, labels_data,
                                                                            scaling_factor)
            for index, (start, end, grapheme, *boolean_labels) in enumerate(labels_data):
                for subplot in ax:
                    subplot.axvline(start, color='gray', linestyle='--', alpha=0.7)
                    subplot.axvline(end, color='gray', linestyle='--', alpha=0.7)
                ax[1].text((start + end) / 2, max(wav) * 0.8, grapheme,
                           ha='center', va='center', fontsize=10, color='black')
                # Annotated boolean label: green if this grapheme was marked correct.
                ax[2].barh(input_number, end - start, left=start,
                           color='green' if boolean_labels[input_number] else 'red', alpha=0.7)
                # Model output: opacity encodes the fraction of the grapheme flagged red.
                red_percentage = red_percentages[index]
                ax[3].barh(input_number, end - start, left=start, color='red', alpha=red_percentage)

            ax[2].set_xlabel("Time (s)")
            ax[2].set_yticks(range(1))
            ax[2].set_yticklabels([f"Input {input_number}"])
            ax[2].set_title("Boolean Labels")
            ax[2].grid(False)

        return fig
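

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): a minimal
# end-to-end run comparing two recordings of the same utterance. The file
# names "reference.wav" and "learner.wav" are hypothetical, and layer=8 is an
# arbitrary choice of WavLM transformer layer, not a value this module
# prescribes.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    checker = PronunciationChecker()

    # Preprocess both recordings: convert to 16 kHz and trim silences with VAD.
    ref_wav, sr = checker.preprocess_wav("reference.wav")
    input_wav, _ = checker.preprocess_wav("learner.wav")

    # Extract frame-level features from the same intermediate layer.
    ref_features, ref_np, _ = checker.extract_features(ref_wav, layer=8)
    input_features, _, _ = checker.extract_features(input_wav, layer=8)

    # Align the two utterances with DTW and visualize matched/mismatched frames
    # on the reference waveform (no grapheme labels supplied in this sketch).
    dist_matrix, path = PronunciationChecker.compute_dtw(ref_features, input_features)
    fig = PronunciationChecker.plot_waveform_with_overlay(ref_np, sr, dist_matrix, path,
                                                          wav_type='ref', threshold=0.3)
    fig.savefig("dtw_overlay.png")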