Capstone04 committed
Commit b62bd24 · verified · 1 Parent(s): 3e6a376

Initial push: ASR + Diarization pipeline wrapper

Files changed (2)
  1. config.json +5 -0
  2. modeling_asr_with_diarization.py +79 -0
config.json ADDED
@@ -0,0 +1,5 @@
+ {
+   "model_type": "asr_with_diarization",
+   "min_speakers": 2,
+   "max_speakers": 5
+ }
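
For illustration, a minimal sketch of reading this config through the custom config class defined in the next file. The local path ./asr-with-diarization is an assumption (a local clone of this repo), not part of the commit:

# Illustrative sketch: parse config.json with the custom config class.
# Assumes this repo has been cloned locally to ./asr-with-diarization.
from modeling_asr_with_diarization import ASRWithDiarizationConfig

config = ASRWithDiarizationConfig.from_pretrained("./asr-with-diarization")
print(config.min_speakers, config.max_speakers)  # -> 2 5
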
modeling_asr_with_diarization.py ADDED
@@ -0,0 +1,79 @@
+ from transformers import PreTrainedModel, PretrainedConfig
+ from pyannote.audio import Pipeline
+ from faster_whisper import WhisperModel
+ import torchaudio
+ import noisereduce as nr
+ import os
+ import json
+
+
+ class ASRWithDiarizationConfig(PretrainedConfig):
+     model_type = "asr_with_diarization"
+
+     def __init__(self, hf_token=None, min_speakers=1, max_speakers=5, **kwargs):
+         super().__init__(**kwargs)
+         self.hf_token = hf_token
+         self.min_speakers = min_speakers
+         self.max_speakers = max_speakers
+
+
+ class ASRWithDiarization(PreTrainedModel):
+     config_class = ASRWithDiarizationConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+         # Gated model: requires an HF token with access granted
+         self.diarization = Pipeline.from_pretrained(
+             "pyannote/speaker-diarization-3.1",
+             use_auth_token=config.hf_token,
+         )
+         self.asr = WhisperModel("medium", device="cpu", compute_type="int8")
+
+     def forward(self, audio_path, output_dir, base_name):
+         os.makedirs(output_dir, exist_ok=True)
+
+         # --- Diarization: who spoke when ---
+         diar = self.diarization(
+             audio_path,
+             min_speakers=self.config.min_speakers,
+             max_speakers=self.config.max_speakers,
+         )
+         diar_json = [
+             {"start": t.start, "end": t.end, "speaker": spk}
+             for t, _, spk in diar.itertracks(yield_label=True)
+         ]
+
+         # --- Transcription, one chunk per speaker turn ---
+         audio, sr = torchaudio.load(audio_path)
+         # faster-whisper expects 16 kHz float32 when given a raw array
+         if sr != 16000:
+             audio = torchaudio.functional.resample(audio, sr, 16000)
+             sr = 16000
+         merged_segments = []
+
+         for seg in diar_json:
+             start, end, spk = seg["start"], seg["end"], seg["speaker"]
+             chunk = audio[0, int(start * sr):int(end * sr)].numpy()
+             reduced = nr.reduce_noise(y=chunk, sr=sr)
+
+             segments, _ = self.asr.transcribe(reduced, word_timestamps=True)
+             tokens = []
+             for s in segments:
+                 if s.words:  # None when no word timestamps are available
+                     for w in s.words:
+                         # Word times are relative to the chunk; shift them
+                         # back to absolute positions in the source file.
+                         tokens.append(
+                             {"start": start + w.start, "end": start + w.end, "text": w.word}
+                         )
+
+             merged_segments.append(
+                 {"speaker": spk, "start": start, "end": end, "tokens": tokens}
+             )
+
+         # Save the per-speaker transcript as JSON
+         out_path = os.path.join(output_dir, f"{base_name}_output.json")
+         with open(out_path, "w") as f:
+             json.dump(merged_segments, f, indent=2)
+
+         return merged_segments
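
As an end-to-end usage sketch (untested; meeting.wav, the outputs directory, and the HF_TOKEN environment variable are placeholders, and access to pyannote/speaker-diarization-3.1 must be accepted on the Hub first):

import os
from modeling_asr_with_diarization import ASRWithDiarizationConfig, ASRWithDiarization

# Placeholder values: the audio path, output dir, and token are assumptions.
config = ASRWithDiarizationConfig(
    hf_token=os.environ["HF_TOKEN"],
    min_speakers=2,
    max_speakers=5,
)
model = ASRWithDiarization(config)

# forward() runs diarization, per-turn denoising + transcription, and writes
# outputs/meeting_output.json shaped like:
# [{"speaker": "SPEAKER_00", "start": 0.5, "end": 4.2, "tokens": [...]}, ...]
segments = model("meeting.wav", output_dir="outputs", base_name="meeting")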