from transformers import PreTrainedModel, PretrainedConfig
from pyannote.audio import Pipeline
from faster_whisper import WhisperModel
import torchaudio, noisereduce as nr
import os, json
class ASRWithDiarizationConfig(PretrainedConfig):
    """Configuration for :class:`ASRWithDiarization`.

    Holds the Hugging Face auth token needed to download the gated
    pyannote pipeline, plus the speaker-count bounds forwarded to the
    diarization call.
    """

    model_type = "asr_with_diarization"

    def __init__(self, hf_token=None, min_speakers=1, max_speakers=5, **kwargs):
        # Let PretrainedConfig consume any generic HF config kwargs first.
        super().__init__(**kwargs)
        self.hf_token, self.min_speakers, self.max_speakers = (
            hf_token,
            min_speakers,
            max_speakers,
        )
class ASRWithDiarization(PreTrainedModel):
    """Speaker-attributed transcription: pyannote diarization + faster-whisper ASR.

    Each diarized speaker turn is cut from the audio, denoised, and
    transcribed; word timestamps are shifted back to absolute positions
    in the original file so tokens and segment bounds share one clock.
    """

    config_class = ASRWithDiarizationConfig

    def __init__(self, config):
        super().__init__(config)
        # Gated model: requires an HF token with access granted.
        self.diarization = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-3.1",
            use_auth_token=config.hf_token,
        )
        # CPU int8 keeps the memory footprint modest; adjust for GPU use.
        self.asr = WhisperModel("medium", device="cpu", compute_type="int8")

    def forward(self, audio_path, output_dir, base_name):
        """Diarize and transcribe *audio_path*; write and return the result.

        Args:
            audio_path: path to an audio file readable by torchaudio.
            output_dir: directory for the JSON output (created if missing).
            base_name: output is saved as ``{base_name}_output.json``.

        Returns:
            list[dict]: one entry per speaker turn with keys
            ``speaker``, ``start``, ``end`` and ``tokens`` (word-level
            ``{"start", "end", "text"}`` dicts with absolute timestamps).
        """
        os.makedirs(output_dir, exist_ok=True)

        diar_json = self._diarize(audio_path)
        merged_segments = self._transcribe_turns(audio_path, diar_json)

        out_path = os.path.join(output_dir, f"{base_name}_output.json")
        # ensure_ascii=False keeps non-ASCII transcripts readable;
        # explicit utf-8 avoids platform-default encoding surprises.
        with open(out_path, "w", encoding="utf-8") as f:
            json.dump(merged_segments, f, indent=2, ensure_ascii=False)
        return merged_segments

    def _diarize(self, audio_path):
        """Run the diarization pipeline; return [{start, end, speaker}, ...]."""
        diar = self.diarization(
            audio_path,
            min_speakers=self.config.min_speakers,
            max_speakers=self.config.max_speakers,
        )
        return [
            {"start": turn.start, "end": turn.end, "speaker": spk}
            for turn, _, spk in diar.itertracks(yield_label=True)
        ]

    def _transcribe_turns(self, audio_path, diar_json):
        """Denoise + transcribe each diarized turn with absolute word times."""
        audio, sr = torchaudio.load(audio_path)
        # Mix down to mono instead of dropping every channel but the
        # first, so stereo recordings lose no content.
        mono = audio.mean(dim=0)
        # faster-whisper expects 16 kHz when fed a raw array; feeding the
        # native rate would corrupt both transcription and timestamps.
        target_sr = 16000
        if sr != target_sr:
            mono = torchaudio.functional.resample(mono, sr, target_sr)
            sr = target_sr

        merged_segments = []
        for seg in diar_json:
            start, end, spk = seg["start"], seg["end"], seg["speaker"]
            chunk = mono[int(start * sr):int(end * sr)].numpy()
            tokens = []
            if chunk.size:  # skip degenerate zero-length turns
                reduced = nr.reduce_noise(y=chunk, sr=sr)
                segments, _ = self.asr.transcribe(reduced, word_timestamps=True)
                for s in segments:
                    # `words` can be None when word timestamps are
                    # unavailable; `or []` guards that, unlike hasattr.
                    for w in s.words or []:
                        tokens.append(
                            {
                                # Whisper times are chunk-relative; offset
                                # by the turn start to make them absolute.
                                "start": start + w.start,
                                "end": start + w.end,
                                "text": w.word,
                            }
                        )
            merged_segments.append(
                {"speaker": spk, "start": start, "end": end, "tokens": tokens}
            )
        return merged_segments