# SPDX-FileContributor: Karl El Hajal

import torch
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoFeatureExtractor, AutoModel
from scipy.spatial.distance import cdist
from dtw import accelerated_dtw

from src.audio_preprocessing import process_wav, get_red_green_segments

class PronunciationChecker:
    def __init__(self, model_name="microsoft/wavlm-large"):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model_name = model_name
        self.processor = AutoFeatureExtractor.from_pretrained(self.model_name)
        self.model = AutoModel.from_pretrained(self.model_name).eval().to(self.device)
        self.sr = 16000

        # Silero VAD is used to trim leading/trailing silences before feature extraction.
        torch.set_num_threads(1)
        self.vad_model, vad_utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', model='silero_vad')
        (self.get_speech_timestamps, _, self.read_audio, _, _) = vad_utils
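    # Note: the silero-vad hub entry point returns (model, utils), where utils is, as of
    # recent releases, (get_speech_timestamps, save_audio, read_audio, VADIterator,
    # collect_chunks); only the timestamp helper and the 16 kHz loader are kept here.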
    def trim_silences_with_silero_vad(self, audio_path):
        wav = self.read_audio(audio_path)
        speech_timestamps = self.get_speech_timestamps(
            wav,
            self.vad_model
        )
        if not speech_timestamps:
            # No speech detected: return the audio untrimmed rather than crashing on an empty list.
            return wav
        start_sample = speech_timestamps[0]['start']
        end_sample = speech_timestamps[-1]['end']
        wav = wav[start_sample:end_sample]
        return wav
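    # Illustrative sketch (hypothetical numbers): get_speech_timestamps returns a list of
    # dicts with sample offsets, e.g. [{'start': 1600, 'end': 24000}, {'start': 30400, 'end': 51200}].
    # The slice above then keeps wav[1600:51200], i.e. everything from the first speech
    # onset to the last offset, including any pauses in between.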
    def preprocess_wav(self, wav_path, do_trim_silences=True):
        temp_audio_path = "temp.wav"
        # Resample/normalize with process_wav, then round-trip through a temp file so
        # that silero's read_audio loads it at 16 kHz.
        audio_segment = process_wav(wav_path, self.sr, do_trim_silences=False)
        audio_segment.export(temp_audio_path, format="wav")
        if do_trim_silences:
            wav = self.trim_silences_with_silero_vad(temp_audio_path)
        else:
            wav = self.read_audio(temp_audio_path)
        return wav, self.sr
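    # Assumption: process_wav (from src.audio_preprocessing, not shown here) returns a
    # pydub-style AudioSegment resampled to the target rate, hence the export() call.
    # Writing to a fixed "temp.wav" is not safe under concurrent calls; a
    # tempfile.NamedTemporaryFile would be a drop-in alternative.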
    def extract_features(self, wav, layer=None):
        # The feature extractor expects CPU audio (it converts to numpy internally),
        # so keep the waveform on CPU here and move the batched tensors to the device after.
        inputs = self.processor(wav.squeeze().cpu().numpy(), sampling_rate=self.sr, return_tensors="pt", padding=True)
        inputs = {key: val.to(self.device) for key, val in inputs.items()}
        with torch.no_grad():
            outputs = self.model(**inputs, output_hidden_states=True)
        if layer is None:
            features = outputs.last_hidden_state
        else:
            hidden_states = outputs.hidden_states
            features = hidden_states[layer]
        features = features.squeeze().cpu().numpy()
        return features, wav.squeeze().cpu().numpy(), self.sr
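    # Frame-rate note: WavLM's convolutional front end downsamples by a total stride of
    # 320 samples, so one feature frame covers 20 ms at 16 kHz (a 2 s clip yields ~100
    # frames). plot_waveform_with_overlay below relies on this via feature_stride = 320.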
    @staticmethod
    def compute_dtw(ref_features, input_features):
        # distance_metric = "euclidean"
        distance_metric = "cosine"
        dist_matrix = cdist(ref_features, input_features, metric=distance_metric)
        _, _, _, path = accelerated_dtw(ref_features, input_features, dist=distance_metric)
        return dist_matrix, path
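    # accelerated_dtw (from the `dtw` package) returns (min_distance, cost_matrix,
    # accumulated_cost_matrix, path); only the warp path is needed here. path is a pair
    # of index arrays: frame path[0][k] of the reference is aligned with frame path[1][k]
    # of the input.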
    @staticmethod
    def calculate_red_percentage(red_segments, labels_data, scaling_factor):
        """For each labelled grapheme, compute the fraction of its duration covered by red (mismatched) frames."""
        red_percentages = []

        def intersection_length(start1, end1, start2, end2):
            overlap_start = max(start1, start2)
            overlap_end = min(end1, end2)
            return max(0, overlap_end - overlap_start)

        for start, end, grapheme, *boolean_labels in labels_data:
            red_intersection = 0.0
            for index in red_segments:
                red_start_time = index * scaling_factor
                red_end_time = (index + 1) * scaling_factor
                red_intersection += intersection_length(start, end, red_start_time, red_end_time)
            total_grapheme_duration = end - start
            red_percentage = red_intersection / total_grapheme_duration
            red_percentages.append(min(red_percentage, 1.))
        return red_percentages
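    # Worked example (hypothetical numbers): with scaling_factor = 0.02 s/frame and
    # red_segments = [10, 11], the red span covers 0.20-0.24 s. A grapheme labelled
    # (0.20, 0.30, 'a', ...) overlaps it by 0.04 s of its 0.10 s duration, so its
    # red percentage is 0.4.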
    @staticmethod
    def plot_waveform_with_overlay(wav, sr, dist_matrix, path, wav_type='ref', threshold=0.3, labels_data=None, input_number=None):
        feature_stride = 320  # samples per feature frame (see frame-rate note above)
        scaling_factor = feature_stride / sr
        time_ref = np.linspace(0, len(wav) / sr, len(wav))
        fig, ax = plt.subplots(4, 1, figsize=(15, 12), gridspec_kw={'height_ratios': [5, 1, 1, 1]})
        ax[0].plot(time_ref, wav, label="Waveform", color="blue", alpha=0.7)

        # Shade well-aligned (green) and mismatched (red) frames over the waveform.
        red_segments, green_segments, _ = get_red_green_segments(dist_matrix, path, wav_type=wav_type, threshold=threshold)
        for index in green_segments:
            start_time = index * scaling_factor
            end_time = (index + 1) * scaling_factor
            ax[0].axvspan(start_time, end_time, facecolor=(0, 1, 0), alpha=0.5)
        for index in red_segments:
            start_time = index * scaling_factor
            end_time = (index + 1) * scaling_factor
            ax[0].axvspan(start_time, end_time, facecolor=(1, 0, 0), alpha=0.5)
        ax[0].set_xlabel("Time (s)")
        ax[0].set_ylabel("Amplitude")
        ax[0].set_title("Waveform with DTW Distance Overlay")
        ax[0].legend()

        # Keep the annotation subplots aligned with the waveform's time axis.
        ax[1].set_xlim(ax[0].get_xlim())
        ax[2].set_xlim(ax[0].get_xlim())
        ax[3].set_xlim(ax[0].get_xlim())

        if labels_data:
            red_percentages = PronunciationChecker.calculate_red_percentage(red_segments, labels_data, scaling_factor)
            for index, (start, end, grapheme, *boolean_labels) in enumerate(labels_data):
                for subplot in ax:
                    subplot.axvline(start, color='gray', linestyle='--', alpha=0.7)
                    subplot.axvline(end, color='gray', linestyle='--', alpha=0.7)
                ax[1].text((start + end) / 2, wav.max() * 0.8, grapheme, ha='center', va='center', fontsize=10, color='black')
                ax[2].barh(input_number, end - start, left=start, color='green' if boolean_labels[input_number] else 'red', alpha=0.7)
                red_percentage = red_percentages[index]
                ax[3].barh(input_number, end - start, left=start, color='red', alpha=red_percentage)
            ax[2].set_xlabel("Time (s)")
            ax[2].set_yticks(range(1))
            ax[2].set_yticklabels([f"Input {input_number}"])
            ax[2].set_title("Boolean Labels")
            ax[2].grid(False)
        return fig
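
# Minimal usage sketch (assumptions: "ref.wav" and "attempt.wav" are hypothetical 16 kHz
# recordings of the same utterance; layer=8 is an arbitrary choice of WavLM hidden layer;
# labels_data rows, if supplied, follow the (start_s, end_s, grapheme, *bool_labels)
# layout consumed above).
if __name__ == "__main__":
    checker = PronunciationChecker()
    ref_wav, sr = checker.preprocess_wav("ref.wav")
    input_wav, _ = checker.preprocess_wav("attempt.wav")
    ref_features, ref_np, _ = checker.extract_features(ref_wav, layer=8)
    input_features, _, _ = checker.extract_features(input_wav, layer=8)
    dist_matrix, path = PronunciationChecker.compute_dtw(ref_features, input_features)
    fig = PronunciationChecker.plot_waveform_with_overlay(ref_np, sr, dist_matrix, path,
                                                          wav_type='ref', threshold=0.3)
    fig.savefig("dtw_overlay.png")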