# File: src/pronunciation_checker.py (PronunciationChecker project)
# NOTE(review): the lines below were Hugging Face Hub page residue captured by
# the scrape ("karlhajal's picture", "Update src/pronunciation_checker.py",
# commit 863c9ae verified); kept here as a comment so the module parses.
# SPDX-FileContributor: Karl El Hajal
import torch
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoFeatureExtractor, AutoModel
from scipy.spatial.distance import cdist
from dtw import accelerated_dtw
from src.audio_preprocessing import process_wav, get_red_green_segments
class PronunciationChecker:
    """Compare a learner's recording against a reference pronunciation.

    Both utterances are encoded with a self-supervised speech model
    (WavLM-large by default), the frame-level features are aligned with DTW,
    and the per-frame alignment cost is rendered as green (good match) / red
    (poor match) shading over the waveform.
    """

    def __init__(self, model_name="microsoft/wavlm-large"):
        """Load the speech encoder and the Silero voice-activity detector.

        Args:
            model_name: Hugging Face model id of the speech encoder.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model_name = model_name
        self.processor = AutoFeatureExtractor.from_pretrained(self.model_name)
        self.model = AutoModel.from_pretrained(self.model_name).eval().to(self.device)
        # Sample rate the encoder expects; used consistently below.
        self.sr = 16000
        torch.set_num_threads(1)
        self.vad_model, vad_utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', model='silero_vad')
        # Silero's hub entry returns a tuple of helpers; we keep the
        # timestamp detector and the audio reader.
        (self.get_speech_timestamps, _, self.read_audio, _, _) = vad_utils

    def trim_silences_with_silero_vad(self, audio_path):
        """Crop the waveform from the first detected speech start to the
        last detected speech end.

        Args:
            audio_path: path to a 16 kHz mono WAV file.

        Returns:
            1-D torch tensor containing only the speech region (or the
            untrimmed waveform if the VAD detects no speech at all).
        """
        wav = self.read_audio(audio_path)
        speech_timestamps = self.get_speech_timestamps(wav, self.vad_model)
        if not speech_timestamps:
            # No speech found: returning the full waveform is safer than
            # raising IndexError on speech_timestamps[0].
            return wav
        start_sample = speech_timestamps[0]['start']
        end_sample = speech_timestamps[-1]['end']
        return wav[start_sample:end_sample]

    def preprocess_wav(self, wav_path, do_trim_silences=True):
        """Normalize an input file to 16 kHz mono and optionally VAD-trim it.

        Side effect: writes an intermediate "temp.wav" in the working
        directory (overwritten on every call; not safe for concurrent use).

        Args:
            wav_path: path to the source audio file.
            do_trim_silences: when True, strip leading/trailing silence.

        Returns:
            (wav, sample_rate) where wav is a 1-D torch tensor.
        """
        temp_audio_path = "temp.wav"
        # Silence trimming is delegated to the Silero VAD below, so the
        # generic preprocessor runs with trimming disabled.
        audio_segment = process_wav(wav_path, self.sr, do_trim_silences=False)
        audio_segment.export(temp_audio_path, format="wav")
        if do_trim_silences:
            wav = self.trim_silences_with_silero_vad(temp_audio_path)
        else:
            wav = self.read_audio(temp_audio_path)
        return wav, self.sr

    def extract_features(self, wav, layer=None):
        """Encode a waveform into frame-level features.

        Args:
            wav: torch waveform at self.sr (squeezed to 1-D internally).
            layer: index into the encoder's hidden_states; None selects
                the final layer (last_hidden_state).

        Returns:
            (features, wav_numpy, sample_rate) with features as a
            (frames, dim) numpy array.
        """
        # NOTE(review): the feature extractor is handed a tensor already
        # moved to self.device; most HF extractors accept tensors, but
        # confirm this path on GPU.
        inputs = self.processor(wav.squeeze().to(self.device), sampling_rate=self.sr, return_tensors="pt", padding=True)
        inputs = {key: val.to(self.device) for key, val in inputs.items()}
        with torch.no_grad():
            outputs = self.model(**inputs, output_hidden_states=True)
        if layer is None:
            features = outputs.last_hidden_state
        else:
            features = outputs.hidden_states[layer]
        return features.squeeze().cpu().numpy(), wav.squeeze().cpu().numpy(), self.sr

    @staticmethod
    def compute_dtw(ref_features, input_features):
        """Align reference and input feature sequences with DTW.

        Args:
            ref_features: (frames_ref, dim) array.
            input_features: (frames_in, dim) array.

        Returns:
            (dist_matrix, path): the pairwise cosine-distance matrix and
            the DTW warping path (two index arrays).
        """
        distance_metric = "cosine"
        dist_matrix = cdist(ref_features, input_features, metric=distance_metric)
        # accelerated_dtw recomputes distances internally; only its warping
        # path is needed here.
        *_, path = accelerated_dtw(ref_features, input_features, dist=distance_metric)
        return dist_matrix, path

    @staticmethod
    def calculate_red_percentage(red_segments, labels_data, scaling_factor):
        """Fraction of each labeled interval covered by "red" (mismatched)
        feature frames.

        Args:
            red_segments: iterable of feature-frame indices flagged red.
            labels_data: iterable of (start, end, grapheme, *bool_labels)
                with start/end in seconds.
            scaling_factor: seconds per feature frame.

        Returns:
            List of floats clamped to [0, 1], one per label entry.
        """
        def intersection_length(start1, end1, start2, end2):
            # Length of the overlap between [start1, end1] and [start2, end2].
            return max(0, min(end1, end2) - max(start1, start2))

        red_percentages = []
        for start, end, grapheme, *boolean_labels in labels_data:
            duration = end - start
            if duration <= 0:
                # Degenerate (empty or inverted) interval: report 0 instead
                # of dividing by zero.
                red_percentages.append(0.0)
                continue
            red_overlap = sum(
                intersection_length(start, end, idx * scaling_factor, (idx + 1) * scaling_factor)
                for idx in red_segments
            )
            red_percentages.append(min(red_overlap / duration, 1.))
        return red_percentages

    @staticmethod
    def plot_waveform_with_overlay(wav, sr, dist_matrix, path, wav_type='ref', threshold=0.3, labels_data=None, input_number=None):
        """Plot the waveform with DTW match shading and optional label rows.

        Args:
            wav: 1-D waveform samples.
            sr: sample rate of `wav`.
            dist_matrix: pairwise distance matrix from compute_dtw.
            path: DTW warping path from compute_dtw.
            wav_type: which side of the alignment `wav` is ('ref' or input).
            threshold: distance above which a frame is shaded red.
            labels_data: optional (start, end, grapheme, *bool_labels) rows.
            input_number: row index / label selector for the boolean track.

        Returns:
            The matplotlib Figure (caller is responsible for closing it).
        """
        # Encoder hop size in samples (20 ms at 16 kHz) -> seconds per frame.
        feature_stride = 320
        scaling_factor = feature_stride / sr
        time_ref = np.linspace(0, len(wav) / sr, len(wav))
        fig, ax = plt.subplots(4, 1, figsize=(15, 12), gridspec_kw={'height_ratios': [5, 1, 1, 1]})
        ax[0].plot(time_ref, wav, label="Waveform", color="blue", alpha=0.7)
        red_segments, green_segments, _ = get_red_green_segments(dist_matrix, path, wav_type=wav_type, threshold=threshold)
        for index in green_segments:
            ax[0].axvspan(index * scaling_factor, (index + 1) * scaling_factor, facecolor=(0, 1, 0), alpha=0.5)
        for index in red_segments:
            ax[0].axvspan(index * scaling_factor, (index + 1) * scaling_factor, facecolor=(1, 0, 0), alpha=0.5)
        ax[0].set_xlabel("Time (s)")
        ax[0].set_ylabel("Amplitude")
        ax[0].set_title("Waveform with DTW Distance Overlay")
        ax[0].legend()
        # Keep all annotation rows on the same time axis as the waveform.
        for subplot in ax[1:]:
            subplot.set_xlim(ax[0].get_xlim())
        if labels_data:
            red_percentages = PronunciationChecker.calculate_red_percentage(red_segments, labels_data, scaling_factor)
            # Hoisted out of the loop: invariant text height for grapheme labels.
            text_height = max(wav) * 0.8
            for index, (start, end, grapheme, *boolean_labels) in enumerate(labels_data):
                for subplot in ax:
                    subplot.axvline(start, color='gray', linestyle='--', alpha=0.7)
                    subplot.axvline(end, color='gray', linestyle='--', alpha=0.7)
                ax[1].text((start + end) / 2, text_height, grapheme, ha='center', va='center', fontsize=10, color='black')
                ax[2].barh(input_number, end - start, left=start, color='green' if boolean_labels[input_number] else 'red', alpha=0.7)
                # Bar opacity encodes how much of the grapheme was mismatched.
                ax[3].barh(input_number, end - start, left=start, color='red', alpha=red_percentages[index])
        ax[2].set_xlabel("Time (s)")
        ax[2].set_yticks(range(1))
        ax[2].set_yticklabels([f"Input {input_number}"])
        ax[2].set_title("Boolean Labels")
        ax[2].grid(False)
        return fig