from transformers import PreTrainedModel, PretrainedConfig
from pyannote.audio import Pipeline
from faster_whisper import WhisperModel
import torchaudio
import noisereduce as nr
import numpy as np
import os
import json


class ASRWithDiarizationConfig(PretrainedConfig):
    model_type = "asr_with_diarization"

    def __init__(self, hf_token=None, min_speakers=1, max_speakers=5, **kwargs):
        super().__init__(**kwargs)
        self.hf_token = hf_token
        self.min_speakers = min_speakers
        self.max_speakers = max_speakers


class ASRWithDiarization(PreTrainedModel):
    config_class = ASRWithDiarizationConfig

    def __init__(self, config):
        super().__init__(config)
        self.diarization = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-3.1",
            use_auth_token=config.hf_token,
        )
        self.asr = WhisperModel("medium", device="cpu", compute_type="int8")

    def forward(self, audio_path, output_dir, base_name):
        os.makedirs(output_dir, exist_ok=True)

        # --- Diarization: who spoke when ---
        diar = self.diarization(
            audio_path,
            min_speakers=self.config.min_speakers,
            max_speakers=self.config.max_speakers,
        )
        diar_json = [
            {"start": t.start, "end": t.end, "speaker": spk}
            for t, _, spk in diar.itertracks(yield_label=True)
        ]

        # --- Transcription: run ASR on each speaker turn ---
        audio, sr = torchaudio.load(audio_path)
        # faster-whisper expects 16 kHz float32 audio when given a raw array,
        # so resample if the file uses a different rate.
        if sr != 16000:
            audio = torchaudio.functional.resample(audio, sr, 16000)
            sr = 16000

        merged_segments = []
        for seg in diar_json:
            start, end, spk = seg["start"], seg["end"], seg["speaker"]
            # Slice the first channel for this speaker turn.
            chunk = audio[0, int(start * sr):int(end * sr)].numpy()
            if chunk.size == 0:  # skip degenerate zero-length turns
                continue
            reduced = nr.reduce_noise(y=chunk, sr=sr)
            segments, _ = self.asr.transcribe(
                reduced.astype(np.float32), word_timestamps=True
            )
            tokens = []
            for s in segments:
                if s.words:
                    for w in s.words:
                        # Word times are relative to the chunk; shift them
                        # back onto the original audio's timeline.
                        tokens.append(
                            {"start": start + w.start,
                             "end": start + w.end,
                             "text": w.word}
                        )
            merged_segments.append(
                {"speaker": spk, "start": start, "end": end, "tokens": tokens}
            )

        # Save one JSON file of speaker-attributed, word-timed segments.
        out_path = os.path.join(output_dir, f"{base_name}_output.json")
        with open(out_path, "w") as f:
            json.dump(merged_segments, f, indent=2)
        return merged_segments
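

# --- Usage sketch ---
# A minimal, hypothetical example of driving the pipeline end to end; the
# token placeholder, file names, and speaker bounds below are illustrative
# assumptions, not values from the original.
if __name__ == "__main__":
    config = ASRWithDiarizationConfig(
        hf_token="hf_...",  # placeholder: your Hugging Face access token
        min_speakers=1,
        max_speakers=3,
    )
    model = ASRWithDiarization(config)
    # Calling the model invokes forward() via nn.Module.__call__.
    segments = model("meeting.wav", output_dir="outputs", base_name="meeting")
    for seg in segments:
        text = "".join(t["text"] for t in seg["tokens"])
        print(f'{seg["speaker"]} [{seg["start"]:.1f}-{seg["end"]:.1f}]: {text}')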