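"""Speaker-diarized transcription: pyannote.audio for speaker diarization and
faster-whisper for per-segment ASR, packaged as a transformers model."""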
import json
import os

import noisereduce as nr
import torchaudio
from faster_whisper import WhisperModel
from pyannote.audio import Pipeline
from transformers import PretrainedConfig, PreTrainedModel

class ASRWithDiarizationConfig(PretrainedConfig):
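    """Configuration holding the Hugging Face token and speaker-count bounds."""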
    model_type = "asr_with_diarization"

    def __init__(self, hf_token=None, min_speakers=1, max_speakers=5, **kwargs):
        super().__init__(**kwargs)
        self.hf_token = hf_token
        self.min_speakers = min_speakers
        self.max_speakers = max_speakers


class ASRWithDiarization(PreTrainedModel):
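    """Wraps pyannote speaker diarization and faster-whisper ASR in one model."""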
    config_class = ASRWithDiarizationConfig

    def __init__(self, config):
        super().__init__(config)
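        # pyannote/speaker-diarization-3.1 is gated; a valid HF token is required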
        self.diarization = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-3.1",
            use_auth_token=config.hf_token
        )
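        # "medium" Whisper on CPU with int8 quantization keeps the memory
        # footprint modest; adjust device/compute_type for GPU inference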
        self.asr = WhisperModel("medium", device="cpu", compute_type="int8")

    def forward(self, audio_path, output_dir, base_name):
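        """Diarize audio_path, transcribe each speaker turn, and write
        {base_name}_output.json under output_dir. Returns the merged segments."""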
        os.makedirs(output_dir, exist_ok=True)

        # --- Diarization ---
        diar = self.diarization(
            audio_path,
            min_speakers=self.config.min_speakers,
            max_speakers=self.config.max_speakers
        )
        diar_json = [
            {"start": t.start, "end": t.end, "speaker": spk}
            for t, _, spk in diar.itertracks(yield_label=True)
        ]

        # --- Transcription ---
        audio, sr = torchaudio.load(audio_path)
        # faster-whisper assumes 16 kHz when given a raw array, so resample first
        if sr != 16000:
            audio = torchaudio.functional.resample(audio, sr, 16000)
            sr = 16000
        merged_segments = []

        for seg in diar_json:
            start, end, spk = seg["start"], seg["end"], seg["speaker"]
            chunk = audio[0, int(start * sr):int(end * sr)].numpy()
            if chunk.size == 0:
                continue  # skip degenerate segments that would break noise reduction
            reduced = nr.reduce_noise(y=chunk, sr=sr)

            segments, _ = self.asr.transcribe(reduced, word_timestamps=True)
            tokens = []
            for s in segments:
                # s.words is None unless word_timestamps=True is set
                if s.words:
                    for w in s.words:
                        # Word times are chunk-relative; shift back to file time
                        tokens.append(
                            {"start": start + w.start, "end": start + w.end, "text": w.word}
                        )

            merged_segments.append(
                {"speaker": spk, "start": start, "end": end, "tokens": tokens}
            )

        # Save JSON
        out_path = os.path.join(output_dir, f"{base_name}_output.json")
        with open(out_path, "w", encoding="utf-8") as f:
            json.dump(merged_segments, f, indent=2, ensure_ascii=False)

        return merged_segments
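

# Example usage (a sketch, not part of the original file): the token and file
# paths below are placeholders. pyannote/speaker-diarization-3.1 is a gated
# model, so a real Hugging Face access token must be supplied.
if __name__ == "__main__":
    config = ASRWithDiarizationConfig(
        hf_token="hf_...",         # placeholder token
        min_speakers=1,
        max_speakers=5,
    )
    model = ASRWithDiarization(config)
    results = model(
        audio_path="meeting.wav",  # placeholder input file
        output_dir="outputs",
        base_name="meeting",
    )
    print(f"Wrote {len(results)} diarized segments")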