|
|
from typing import List, Union, Optional |
|
|
import os |
|
|
import json |
|
|
import numpy as np |
|
|
import librosa |
|
|
from transformers import pipeline |
|
|
|
|
|
|
|
|
# Sampling rate (Hz) that every audio input is loaded at (see librosa.load
# calls below) and that ndarray inputs are assumed to already use.
DEFAULT_SAMPLE_RATE = 16000


# Lazily created module-level singleton; access it via get_predictor()
# rather than reading this directly.
_PREDICTOR_INSTANCE = None
|
|
|
|
|
def get_predictor():
    """
    Get or create the singleton predictor instance.

    Returns:
        Predictor: A shared instance of the Predictor class.
    """
    global _PREDICTOR_INSTANCE
    # Fast path: instance already built on a previous call.
    if _PREDICTOR_INSTANCE is not None:
        return _PREDICTOR_INSTANCE
    # First call: build the (expensive) model-backed predictor once.
    _PREDICTOR_INSTANCE = Predictor()
    return _PREDICTOR_INSTANCE
|
|
|
|
|
class Predictor:
    """Adult/child speaker classifier backed by a wav2vec2 audio-classification pipeline."""

    # Default HuggingFace Hub model used when no local path is supplied.
    DEFAULT_MODEL = "bookbot/wav2vec2-adult-child-cls"

    def __init__(self, model_path: Optional[str] = None):
        """
        Initialize the predictor with a pre-trained model.

        Args:
            model_path: Optional path to a local model. If None, uses the
                default HuggingFace model.
        """
        # Bug fix: model_path used to be accepted but silently ignored;
        # the hard-coded hub id was always loaded.
        self.model = pipeline(
            "audio-classification",
            model=model_path if model_path is not None else self.DEFAULT_MODEL,
        )

    def preprocess(self, input_item: Union[str, np.ndarray]) -> np.ndarray:
        """
        Preprocess an input item (either file path or numpy array).

        Args:
            input_item: Either a file path string or a numpy array of audio
                data (assumed to already be at DEFAULT_SAMPLE_RATE).

        Returns:
            np.ndarray: Mono float32 audio at DEFAULT_SAMPLE_RATE.

        Raises:
            ValueError: If input type is unsupported.
        """
        if isinstance(input_item, str):
            # librosa resamples to the target rate and returns float32 mono.
            audio, _ = librosa.load(input_item, sr=DEFAULT_SAMPLE_RATE)
            return audio
        if isinstance(input_item, np.ndarray):
            # HF audio pipelines expect a single-channel float32 array;
            # normalize the dtype here (no-op copy semantics for float32 input).
            return np.asarray(input_item, dtype=np.float32)
        raise ValueError(f"Unsupported input type: {type(input_item)}")

    def predict(self, input_list: List[Union[str, np.ndarray]]) -> List[int]:
        """
        Predict speaker type (child=0, adult=1) for a list of audio inputs.

        Args:
            input_list: List of inputs, either file paths or numpy arrays.

        Returns:
            List[int]: List of predictions (0=child, 1=adult, -1=unknown).
        """
        processed = [self.preprocess(item) for item in input_list]

        # Bug fix: `sampling_rate=` is not a call kwarg of the
        # audio-classification pipeline; the documented way to convey the
        # rate of raw arrays is the {"raw", "sampling_rate"} dict input.
        inputs = [
            {"raw": audio, "sampling_rate": DEFAULT_SAMPLE_RATE}
            for audio in processed
        ]
        preds = self.model(inputs)

        label_map = {
            "child": 0,
            "adult": 1,
        }

        results = []
        for pred in preds:
            # Batched calls yield one score-sorted list of {label, score}
            # dicts per input; a bare dict may come back for single inputs.
            top = pred[0] if isinstance(pred, list) else pred
            results.append(label_map.get(top["label"].lower(), -1))
        return results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def assign_speaker_for_audio_list(audio_list: List[Union[str, np.ndarray]]) -> List[str]:
    """
    Assigns speaker IDs for a list of audio segments.

    Args:
        audio_list: List of audio inputs (either file paths or numpy arrays,
            assumed to have sampling rate = 16000).

    Returns:
        List[str]: List of speaker IDs corresponding to each audio segment.
            "Child" for child, "Examiner" for adult.
    """
    if not audio_list:
        return []

    # Numeric class -> speaker id; anything else (-1) falls back to "Unknown".
    numeric_to_speaker = {0: "Child", 1: "Examiner"}

    predictions = get_predictor().predict(audio_list)
    return [numeric_to_speaker.get(p, "Unknown") for p in predictions]
|
|
|
|
|
|
|
|
|
|
|
def _segment_sample_range(start, end, n_samples: int, dur_sec: float) -> Optional[tuple]:
    """Map a segment's [start, end] seconds onto clipped sample indices.

    Returns (s_idx, e_idx) suitable for slicing the audio array, or None when
    the timing is missing, non-numeric, inverted, or entirely outside the
    audio's duration.
    """
    if (
        start is None or end is None
        or not isinstance(start, (int, float))
        or not isinstance(end, (int, float))
        or end <= start
        or start >= dur_sec
    ):
        return None

    # Clip to the audible range before converting to sample indices.
    s = max(0.0, float(start))
    e = min(float(end), dur_sec)
    if e <= s:
        return None

    s_idx = max(0, min(int(round(s * DEFAULT_SAMPLE_RATE)), n_samples))
    e_idx = max(0, min(int(round(e * DEFAULT_SAMPLE_RATE)), n_samples))
    if e_idx <= s_idx:
        return None
    return s_idx, e_idx


def assign_speaker(session_id: str):
    """Classify the speaker of every transcription segment of a session.

    Reads session_data/<session_id>/transcription_cunit.json and audio.wav,
    predicts "Child"/"Examiner" per segment ("Unknown" when the segment's
    timing is unusable), and writes the "speaker" field back into the JSON
    file in place.

    Args:
        session_id: Name of the session directory under session_data/.
    """
    base_dir = os.path.join("session_data", session_id)
    json_path = os.path.join(base_dir, "transcription_cunit.json")
    wav_path = os.path.join(base_dir, "audio.wav")

    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    segments = data.get("segments", [])

    # Nothing to label: leave the JSON untouched.
    if not segments:
        return

    audio, _ = librosa.load(wav_path, sr=DEFAULT_SAMPLE_RATE, mono=True)
    n_samples = len(audio)
    dur_sec = n_samples / float(DEFAULT_SAMPLE_RATE)

    model_inputs: List[np.ndarray] = []
    model_indices: List[int] = []

    for i, seg in enumerate(segments):
        rng = _segment_sample_range(seg.get("start"), seg.get("end"), n_samples, dur_sec)
        if rng is None:
            continue  # stays "Unknown" via the default below
        s_idx, e_idx = rng
        snippet = audio[s_idx:e_idx]
        if snippet.size == 0:
            continue
        model_inputs.append(snippet)
        model_indices.append(i)

    # Default everything to "Unknown"; only segments with a usable snippet
    # get a model prediction. (The original code also re-assigned "Unknown"
    # to the skipped indices afterwards, which was a no-op.)
    speakers = ["Unknown"] * len(segments)
    if model_inputs:
        predicted = assign_speaker_for_audio_list(model_inputs)
        for seg_idx, spk in zip(model_indices, predicted):
            speakers[seg_idx] = spk

    for seg, spk in zip(segments, speakers):
        seg["speaker"] = spk

    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(data, f, ensure_ascii=False, indent=2)
|
|
|