# SATEv1.5 / speaker / speaker_identification.py
# Author: Shuwei Hou
# Change: update_speaker_id_to_json
# Commit: a213dac
from typing import List, Union, Optional
import os
import json
import numpy as np
import librosa
from transformers import pipeline
# Default sample rate for audio processing (Hz); all audio is loaded/resampled
# to this rate before inference.
DEFAULT_SAMPLE_RATE = 16000
# Singleton pattern to avoid loading the model multiple times; populated
# lazily by get_predictor() so importing this module does not load the model.
_PREDICTOR_INSTANCE: Optional["Predictor"] = None
def get_predictor():
    """
    Return the process-wide shared predictor, creating it on first use.

    Returns:
        Predictor: A shared instance of the Predictor class.
    """
    global _PREDICTOR_INSTANCE
    # Fast path: the model has already been loaded.
    if _PREDICTOR_INSTANCE is not None:
        return _PREDICTOR_INSTANCE
    _PREDICTOR_INSTANCE = Predictor()
    return _PREDICTOR_INSTANCE
class Predictor:
    """Adult/child speaker classifier backed by a HuggingFace audio pipeline."""

    # HuggingFace hub model used when no local model path is supplied.
    DEFAULT_MODEL = "bookbot/wav2vec2-adult-child-cls"

    def __init__(self, model_path: Optional[str] = None):
        """
        Initialize the predictor with a pre-trained model.

        Args:
            model_path: Optional path to a local model. If None, uses the
                default HuggingFace model.
        """
        # BUG FIX: the original ignored `model_path` and always loaded the
        # default hub model, contradicting its own docstring. Honor it here.
        self.model = pipeline(
            "audio-classification",
            model=model_path if model_path is not None else self.DEFAULT_MODEL,
        )

    def preprocess(self, input_item: Union[str, np.ndarray]) -> np.ndarray:
        """
        Preprocess an input item (either file path or numpy array).

        Args:
            input_item: Either a file path string or a numpy array of audio data.

        Returns:
            np.ndarray: Processed audio data as a numpy array. Arrays are
            passed through unchanged (assumed to already be at
            DEFAULT_SAMPLE_RATE).

        Raises:
            ValueError: If input type is unsupported.
        """
        if isinstance(input_item, str):
            # Load and resample the audio file to the model's expected rate.
            audio, _ = librosa.load(input_item, sr=DEFAULT_SAMPLE_RATE)
            return audio
        elif isinstance(input_item, np.ndarray):
            return input_item
        else:
            raise ValueError(f"Unsupported input type: {type(input_item)}")

    def predict(self, input_list: List[Union[str, np.ndarray]]) -> List[int]:
        """
        Predict speaker type (child=0, adult=1) for a list of audio inputs.

        Args:
            input_list: List of inputs, either file paths or numpy arrays.

        Returns:
            List[int]: List of predictions (0=child, 1=adult, -1=unknown).
        """
        # Preprocess all inputs first so the pipeline receives raw waveforms.
        processed = [self.preprocess(item) for item in input_list]
        # Batch inference.
        # NOTE(review): `sampling_rate` as a bare __call__ kwarg may not be
        # accepted by all transformers versions; the documented form is
        # {"raw": array, "sampling_rate": sr} per input — confirm against the
        # pinned transformers release.
        preds = self.model(processed, sampling_rate=DEFAULT_SAMPLE_RATE)
        # Map label to 0 (child) or 1 (adult); anything else maps to -1.
        label_map = {
            "child": 0,
            "adult": 1
        }
        results = []
        for pred in preds:
            # pred can be a list of dicts (top-k); take the top prediction.
            if isinstance(pred, list):
                label = pred[0]["label"]
            else:
                label = pred["label"]
            results.append(label_map.get(label.lower(), -1))  # -1 for unknown label
        return results
# Usage:
# predictor = Predictor("path/to/model")
# predictions = predictor.predict(list_of_inputs)
def assign_speaker_for_audio_list(audio_list: List[Union[str, np.ndarray]]) -> List[str]:
    """
    Assign a speaker ID to every audio segment in the list.

    Args:
        audio_list: List of audio inputs (either file paths or numpy arrays,
            assumed to have sampling rate = 16000).

    Returns:
        List[str]: One speaker ID per segment, in input order: "Child" for a
        child speaker, "Examiner" for an adult, "Unknown" otherwise.
    """
    if not audio_list:
        return []
    # Shared predictor keeps the model loaded only once per process.
    predictor = get_predictor()
    # Numeric codes from the model: 0 = child, 1 = adult, anything else unknown.
    code_to_speaker = {0: "Child", 1: "Examiner"}
    numeric_labels = predictor.predict(audio_list)
    return [code_to_speaker.get(code, "Unknown") for code in numeric_labels]
# you don't have to implement this function
def assign_speaker(session_id: str):
    """
    Annotate each transcription segment of a session with a speaker label.

    Reads ``session_data/<session_id>/transcription_cunit.json`` and the
    matching ``audio.wav``, classifies every segment's audio snippet as
    "Child" or "Examiner", and writes the label back into each segment's
    ``"speaker"`` key. Segments with missing, non-numeric, inverted, or
    out-of-range timestamps are labeled "Unknown". No-op if the JSON has
    no segments.

    Args:
        session_id: Directory name of the session under ``session_data``.
    """
    base_dir = os.path.join("session_data", session_id)
    json_path = os.path.join(base_dir, "transcription_cunit.json")
    wav_path = os.path.join(base_dir, "audio.wav")

    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    segments = data.get("segments", [])
    if not segments:
        return

    # Unused `sr` removed: librosa already resamples to DEFAULT_SAMPLE_RATE.
    audio, _ = librosa.load(wav_path, sr=DEFAULT_SAMPLE_RATE, mono=True)
    n_samples = len(audio)
    dur_sec = n_samples / float(DEFAULT_SAMPLE_RATE)

    # Every segment defaults to "Unknown"; only valid snippets are classified,
    # so no separate bookkeeping of invalid indices is needed.
    speakers = ["Unknown"] * len(segments)
    model_inputs: List[np.ndarray] = []
    model_indices: List[int] = []

    for i, seg in enumerate(segments):
        start = seg.get("start")
        end = seg.get("end")
        # Reject missing/non-numeric times (isinstance covers None), inverted
        # ranges, and segments starting past the end of the audio.
        if (
            not isinstance(start, (int, float))
            or not isinstance(end, (int, float))
            or end <= start
            or start >= dur_sec
        ):
            continue
        # Clip to the actual audio extent and convert to sample indices.
        s = max(0.0, float(start))
        e = min(float(end), dur_sec)
        s_idx = max(0, min(int(round(s * DEFAULT_SAMPLE_RATE)), n_samples))
        e_idx = max(0, min(int(round(e * DEFAULT_SAMPLE_RATE)), n_samples))
        if e_idx <= s_idx:
            continue
        snippet = audio[s_idx:e_idx]
        if snippet.size == 0:
            continue
        model_inputs.append(snippet)
        model_indices.append(i)

    if model_inputs:
        predicted = assign_speaker_for_audio_list(model_inputs)  # ["Child"/"Examiner"/"Unknown"]
        for seg_idx, spk in zip(model_indices, predicted):
            speakers[seg_idx] = spk

    for seg, spk in zip(segments, speakers):
        seg["speaker"] = spk

    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=2)