| | import torch
|
| | from speechbrain.inference.interfaces import Pretrained
|
| | import torchaudio
|
| |
|
| |
|
| | def merge_overlapping_segments(segments):
|
| | """
|
| | Merges segments that overlap or are contiguous, ensuring each speaker segment is represented once.
|
| |
|
| | Args:
|
| | segments (list of tuples): List of tuples representing (start, end, label) of segments.
|
| |
|
| | Returns:
|
| | list of tuples: Merged list of segments.
|
| | """
|
| | if not segments:
|
| | return []
|
| | merged = [segments[0]]
|
| | for current in segments[1:]:
|
| | prev = merged[-1]
|
| | if current[0] <= prev[1]:
|
| | if current[2] == prev[2]:
|
| | merged[-1] = (prev[0], max(prev[1], current[1]), prev[2])
|
| | else:
|
| | merged.append(current)
|
| | else:
|
| | merged.append(current)
|
| | return merged
|
| |
|
| |
|
| | def refine_transitions(aggregated_predictions):
|
| | """
|
| | Refines transitions between speaker segments to enhance accuracy.
|
| |
|
| | Args:
|
| | aggregated_predictions (list of tuples): The aggregated predictions with potential overlaps.
|
| |
|
| | Returns:
|
| | list of tuples: Predictions with adjusted transitions.
|
| | """
|
| | refined_predictions = []
|
| | for i in range(len(aggregated_predictions)):
|
| | if i == 0:
|
| | refined_predictions.append(aggregated_predictions[i])
|
| | continue
|
| |
|
| | current_start, current_end, current_label = aggregated_predictions[i]
|
| | prev_start, prev_end, prev_label = aggregated_predictions[i - 1]
|
| |
|
| | if current_start - prev_end <= 1.0:
|
| | new_start = prev_end
|
| | else:
|
| | new_start = current_start
|
| |
|
| | refined_predictions.append((new_start, current_end, current_label))
|
| |
|
| | return refined_predictions
|
| |
|
| |
|
| | def refine_transitions_with_confidence(aggregated_predictions, segment_confidences):
|
| | """
|
| | Refines transitions between segments based on confidence levels.
|
| |
|
| | Args:
|
| | aggregated_predictions (list of tuples): Initial aggregated predictions.
|
| | segment_confidences (list of float): Confidence scores corresponding to each segment.
|
| |
|
| | Returns:
|
| | list of tuples: Refined segment predictions.
|
| | """
|
| | refined_predictions = []
|
| | for i in range(len(aggregated_predictions)):
|
| | if i == 0:
|
| | refined_predictions.append(aggregated_predictions[i])
|
| | continue
|
| |
|
| | current_start, current_end, current_label = aggregated_predictions[i]
|
| | prev_start, prev_end, prev_label, prev_confidence = refined_predictions[-1] + (segment_confidences[i - 1],)
|
| |
|
| | current_confidence = segment_confidences[i]
|
| |
|
| | if current_label != prev_label:
|
| | if prev_confidence < current_confidence:
|
| | transition_point = current_start
|
| | else:
|
| | transition_point = prev_end
|
| | refined_predictions[-1] = (prev_start, transition_point, prev_label)
|
| | refined_predictions.append((transition_point, current_end, current_label))
|
| | else:
|
| | if prev_confidence < current_confidence:
|
| | refined_predictions[-1] = (prev_start, current_end, current_label)
|
| | else:
|
| | refined_predictions.append((current_start, current_end, current_label))
|
| |
|
| | return refined_predictions
|
| |
|
| |
|
| | def aggregate_segments_with_overlap(segment_predictions):
|
| | """
|
| | Aggregates overlapping segments into single segments based on speaker labels.
|
| |
|
| | Args:
|
| | segment_predictions (list of tuples): List of tuples representing (start, end, label) of segments.
|
| |
|
| | Returns:
|
| | list of tuples: Aggregated segments.
|
| | """
|
| | aggregated_predictions = []
|
| | last_start, last_end, last_label = segment_predictions[0]
|
| |
|
| | for start, end, label in segment_predictions[1:]:
|
| | if label == last_label and start <= last_end:
|
| | last_end = max(last_end, end)
|
| | else:
|
| | aggregated_predictions.append((last_start, last_end, last_label))
|
| | last_start, last_end, last_label = start, end, label
|
| |
|
| | aggregated_predictions.append((last_start, last_end, last_label))
|
| |
|
| | merged = merge_overlapping_segments(aggregated_predictions)
|
| | return merged
|
| |
|
| |
|
| | class SpeakerCounter(Pretrained):
|
| | """
|
| | A class for counting speakers in an audio file, built upon the SpeechBrain Pretrained class.
|
| | This class integrates several preprocessing and prediction modules to handle speaker diarization tasks.
|
| | """
|
| |
|
| | def __init__(self, *args, **kwargs):
|
| | """
|
| | Initialize the SpeakerCounter with standard and custom parameters.
|
| | Args:
|
| | *args: Variable length argument list.
|
| | **kwargs: Arbitrary keyword arguments.
|
| | """
|
| | super().__init__(*args, **kwargs)
|
| | self.sample_rate = self.hparams.sample_rate
|
| |
|
| | MODULES_NEEDED = [
|
| | "compute_features",
|
| | "mean_var_norm",
|
| | "embedding_model",
|
| | "classifier",
|
| | ]
|
| |
|
| | def resample_waveform(self, waveform, orig_sample_rate):
|
| | """
|
| | Resamples the input waveform to the target sample rate specified in the object.
|
| |
|
| | Args:
|
| | waveform (Tensor): The input waveform tensor.
|
| | orig_sample_rate (int): The original sample rate of the waveform.
|
| |
|
| | Returns:
|
| | Tensor: The resampled waveform.
|
| | """
|
| | if orig_sample_rate != self.sample_rate:
|
| | resample_transform = torchaudio.transforms.Resample(orig_freq=orig_sample_rate, new_freq=self.sample_rate)
|
| | waveform = resample_transform(waveform)
|
| | return waveform
|
| |
|
| | def encode_batch(self, wavs, wav_lens=None):
|
| | """
|
| | Encodes a batch of waveforms into embeddings using the loaded models.
|
| |
|
| | Args:
|
| | wavs (Tensor): Batch of waveforms.
|
| | wav_lens (Tensor, optional): Lengths of the waveforms for normalization.
|
| |
|
| | Returns:
|
| | Tensor: Batch of embeddings.
|
| | """
|
| | if len(wavs.shape) == 1:
|
| | wavs = wavs.unsqueeze(0)
|
| |
|
| | if wav_lens is None:
|
| | wav_lens = torch.ones(wavs.shape[0], device=self.device)
|
| |
|
| | wavs, wav_lens = wavs.to(self.device), wav_lens.to(self.device)
|
| | wavs = wavs.float()
|
| |
|
| |
|
| | feats = self.mods.compute_features(wavs)
|
| | feats = self.mods.mean_var_norm(feats, wav_lens)
|
| | embeddings = self.mods.embedding_model(feats, wav_lens)
|
| | return embeddings
|
| |
|
| | def create_segments(self, waveform, segment_length, overlap):
|
| | """
|
| | Creates segments from a single waveform for batch processing.
|
| |
|
| | Args:
|
| | waveform (Tensor): Input waveform tensor.
|
| | segment_length (float): Length of each segment in seconds.
|
| | overlap (float): Overlap between segments in seconds.
|
| |
|
| | Returns:
|
| | tuple: (segments, segment_times) where segments is a list of tensors, and segment_times
|
| | is a list of (start, end) times.
|
| | """
|
| | num_samples = waveform.shape[1]
|
| | segment_samples = int(segment_length * self.sample_rate)
|
| | overlap_samples = int(overlap * self.sample_rate)
|
| | step_samples = segment_samples - overlap_samples
|
| | segments = []
|
| | segment_times = []
|
| |
|
| | for start in range(0, num_samples - segment_samples + 1, step_samples):
|
| | end = start + segment_samples
|
| | segments.append(waveform[:, start:end])
|
| | start_time = start / self.sample_rate
|
| | end_time = end / self.sample_rate
|
| | segment_times.append((start_time, end_time))
|
| |
|
| | return segments, segment_times
|
| |
|
| | def classify_file(self, path, segment_length=2.0, overlap=1.47):
|
| | """
|
| | Processes an audio file to classify and count speakers within segments.
|
| | Utilizes multiple stages of processing to handle overlapping speech and transitions.
|
| |
|
| | Args:
|
| | path (str): Path to the audio file.
|
| | segment_length (float): Length of each segment in seconds.
|
| | overlap (float): Overlap between segments in seconds.
|
| |
|
| | Outputs:
|
| | Writes the number of speakers in each segment to a text file.
|
| | """
|
| | waveform, osr = torchaudio.load(path)
|
| | waveform = self.resample_waveform(waveform, osr)
|
| |
|
| | segments, segment_times = self.create_segments(waveform, segment_length, overlap)
|
| | segment_predictions = []
|
| |
|
| | for segment, (start_time, end_time) in zip(segments, segment_times):
|
| | rel_length = torch.tensor([1.0])
|
| | emb = self.encode_batch(segment, rel_length)
|
| | out_prob = self.mods.classifier(emb).squeeze(1)
|
| | score, index = torch.max(out_prob, dim=-1)
|
| | text_lab = index.item()
|
| | segment_predictions.append((start_time, end_time, text_lab))
|
| |
|
| | aggregated_predictions = aggregate_segments_with_overlap(segment_predictions)
|
| | refined_predictions = refine_transitions(aggregated_predictions)
|
| | preds = refine_transitions_with_confidence(aggregated_predictions, refined_predictions)
|
| |
|
| | with open("sample_segment_predictions.txt", "w") as file:
|
| | for start_time, end_time, prediction in preds:
|
| | speaker_text = "no speech" if str(prediction) == "0" else (
|
| | "1 speaker" if str(prediction) == "1" else f"{prediction} speakers")
|
| | print(f"{start_time:.2f}-{end_time:.2f} has {speaker_text}")
|
| | file.write(f"{start_time:.2f}-{end_time:.2f} has {speaker_text}\n")
|
| |
|
| | def forward(self, wavs, wav_lens=None):
|
| | """
|
| | Forward pass for classifying audio using preloaded modules.
|
| |
|
| | Args:
|
| | wavs (Tensor): Input waveforms.
|
| | wav_lens (Tensor, optional): Lengths of the input waveforms.
|
| |
|
| | Returns:
|
| | Output from classify_file method.
|
| | """
|
| | return self.classify_file(wavs, wav_lens)
|
| |
|