Initial push: ASR + Diarization pipeline wrapper
Files changed:
- config.json +5 -0
- modeling_asr_with_diarization.py +79 -0
config.json
ADDED
@@ -0,0 +1,5 @@
+{
+  "model_type": "asr_with_diarization",
+  "min_speakers": 2,
+  "max_speakers": 5
+}
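
As a quick sanity check, the config round-trips through the stock PretrainedConfig helpers. A minimal sketch, assuming a local checkout of this repo; the token value is a placeholder supplied at runtime:

from modeling_asr_with_diarization import ASRWithDiarizationConfig

# Load the shipped config.json; keys it omits (hf_token) fall back to
# the defaults declared in __init__.
config = ASRWithDiarizationConfig.from_json_file("config.json")
config.hf_token = "hf_..."  # placeholder token
print(config.min_speakers, config.max_speakers)  # 2 5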
modeling_asr_with_diarization.py
ADDED
@@ -0,0 +1,79 @@
+from transformers import PreTrainedModel, PretrainedConfig
+from pyannote.audio import Pipeline
+from faster_whisper import WhisperModel
+import torchaudio
+import noisereduce as nr
+import os
+import json
+
+
+class ASRWithDiarizationConfig(PretrainedConfig):
+    model_type = "asr_with_diarization"
+
+    def __init__(self, hf_token=None, min_speakers=1, max_speakers=5, **kwargs):
+        super().__init__(**kwargs)
+        self.hf_token = hf_token
+        self.min_speakers = min_speakers
+        self.max_speakers = max_speakers
+
+
+class ASRWithDiarization(PreTrainedModel):
+    config_class = ASRWithDiarizationConfig
+
+    def __init__(self, config):
+        super().__init__(config)
+        self.diarization = Pipeline.from_pretrained(
+            "pyannote/speaker-diarization-3.1",
+            use_auth_token=config.hf_token,
+        )
+        self.asr = WhisperModel("medium", device="cpu", compute_type="int8")
+
+    def forward(self, audio_path, output_dir, base_name):
+        os.makedirs(output_dir, exist_ok=True)
+
+        # --- Diarization: find who speaks when ---
+        diar = self.diarization(
+            audio_path,
+            min_speakers=self.config.min_speakers,
+            max_speakers=self.config.max_speakers,
+        )
+        diar_json = [
+            {"start": t.start, "end": t.end, "speaker": spk}
+            for t, _, spk in diar.itertracks(yield_label=True)
+        ]
+
+        # --- Transcription: run ASR on each speaker turn ---
+        audio, sr = torchaudio.load(audio_path)
+        # faster-whisper assumes 16 kHz when handed a raw array.
+        if sr != 16000:
+            audio = torchaudio.functional.resample(audio, sr, 16000)
+            sr = 16000
+        merged_segments = []
+
+        for seg in diar_json:
+            start, end, spk = seg["start"], seg["end"], seg["speaker"]
+            # Slice the turn out of channel 0 and denoise it.
+            chunk = audio[0, int(start * sr):int(end * sr)].numpy()
+            reduced = nr.reduce_noise(y=chunk, sr=sr)
+
+            segments, _ = self.asr.transcribe(reduced, word_timestamps=True)
+            tokens = []
+            for s in segments:
+                if s.words:
+                    for w in s.words:
+                        # Word timestamps are relative to the chunk; shift
+                        # them back onto the full recording's timeline.
+                        tokens.append(
+                            {"start": start + w.start, "end": start + w.end, "text": w.word}
+                        )
+
+            merged_segments.append(
+                {"speaker": spk, "start": start, "end": end, "tokens": tokens}
+            )
+
+        # Save one JSON file per input recording.
+        out_path = os.path.join(output_dir, f"{base_name}_output.json")
+        with open(out_path, "w") as f:
+            json.dump(merged_segments, f, indent=2)
+
+        return merged_segments
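
For completeness, a hedged end-to-end usage sketch. The class and file names come from the commit above; the token, audio path, and output names are placeholders, and pyannote/speaker-diarization-3.1 is a gated model, so the token must belong to an account that has accepted its terms:

from modeling_asr_with_diarization import (
    ASRWithDiarization,
    ASRWithDiarizationConfig,
)

# Placeholders: supply a real HF token and a real audio file.
config = ASRWithDiarizationConfig(hf_token="hf_...", min_speakers=2, max_speakers=5)
model = ASRWithDiarization(config)

# forward() runs through the usual nn.Module __call__.
segments = model("meeting.wav", output_dir="outputs", base_name="meeting")

for seg in segments:
    text = "".join(t["text"] for t in seg["tokens"])
    print(f"{seg['speaker']} [{seg['start']:.1f}s-{seg['end']:.1f}s]:{text}")

Besides returning the per-speaker segments, forward() also writes the same structure to outputs/meeting_output.json.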