Capstone04 committed (verified)
Commit 8f40048 · 1 Parent(s): 7044c08

Upload folder using huggingface_hub

README.md ADDED
@@ -0,0 +1,16 @@
+ ---
+ language: en
+ tags:
+ - asr
+ - diarization
+ pipeline_tag: automatic-speech-recognition
+ ---
+ # ASR + Diarization Pipeline
+
+ This package provides an **Automatic Speech Recognition (ASR) + Speaker Diarization** pipeline built on:
+ - [OpenAI Whisper](https://huggingface.co/openai/whisper-medium) for transcription
+ - [pyannote speaker diarization](https://huggingface.co/pyannote/speaker-diarization-3.1) for speaker segmentation
+
+ ## Install
+ ```bash
+ pip install git+https://huggingface.co/Capstone04/asr-diarization-pipeline
+ ```
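+
+ ## Usage
+ A minimal end-to-end sketch (file names are illustrative; the gated pyannote model requires a Hugging Face token in `HF_TOKEN`):
+ ```python
+ import os
+ from asr_diarization import ASR_Diarization
+
+ pipe = ASR_Diarization(os.environ["HF_TOKEN"])
+
+ # __call__ accepts raw audio bytes (or a {"audio": bytes} dict) and returns
+ # a dict with "speakers", "segments", and "evaluation".
+ with open("meeting.wav", "rb") as f:
+     result = pipe(f.read())
+
+ print(result["speakers"])
+ for seg in result["segments"]:
+     print(seg["speaker"], seg["start"], seg["end"],
+           " ".join(tok["text"] for tok in seg["tokens"]))
+ ```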
asr_diarization/__init__.py ADDED
@@ -0,0 +1 @@
+ from .pipeline import ASR_Diarization
asr_diarization/inference.py ADDED
@@ -0,0 +1,13 @@
+ import os
+ from .pipeline import ASR_Diarization
+
+ # Models are loaded once at import time; HF_TOKEN is required for the
+ # gated pyannote diarization model.
+ HF_TOKEN = os.environ.get("HF_TOKEN", None)
+ pipe = ASR_Diarization(HF_TOKEN)
+
+ def inference(inputs):
+     return pipe(inputs)
+
+ def inference_with_eval(inputs, output_dir, base_name, ref_rttm=None, ref_json=None):
+     # Forward the output location so the pipeline writes the hypothesis
+     # RTTM/JSON that evaluate() reads; the result already carries "evaluation".
+     return pipe(inputs, output_dir=output_dir, base_name=base_name,
+                 ref_rttm=ref_rttm, ref_json=ref_json)
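+
+ # Example (hypothetical paths): inference_with_eval writes the hypothesis
+ # RTTM/JSON into output_dir and returns metrics under "evaluation":
+ #
+ #   with open("meeting.wav", "rb") as f:
+ #       out = inference_with_eval(f.read(), "outputs", "meeting",
+ #                                 ref_rttm="refs/meeting.rttm",
+ #                                 ref_json="refs/meeting.json")
+ #   print(out["evaluation"])   # e.g. {"DER": 12.3, "WER_raw": 0.18, ...}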
asr_diarization/pipeline.py ADDED
@@ -0,0 +1,160 @@
+ import os
+ import json
+ import tempfile
+
+ import torch
+ import torchaudio
+ import noisereduce as nr
+ from pyannote.audio import Pipeline
+ from transformers import pipeline as hf_pipeline
+
+ from pyannote.core import Annotation, Segment
+ from pyannote.metrics.diarization import DiarizationErrorRate
+ from jiwer import wer, Compose, ToLowerCase, RemovePunctuation, RemoveMultipleSpaces, Strip
+
+
+ class ASR_Diarization:
+     def __init__(self, HF_TOKEN,
+                  diar_model="pyannote/speaker-diarization-3.1",
+                  asr_model="openai/whisper-medium"):
+         self.HF_TOKEN = HF_TOKEN
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+
+         # Load the (gated) diarization model and move it to the target device.
+         self.diar_pipeline = Pipeline.from_pretrained(diar_model, use_auth_token=HF_TOKEN)
+         self.diar_pipeline.to(torch.device(self.device))
+
+         # Load the ASR model. return_timestamps=True yields segment-level
+         # timestamps; pass return_timestamps="word" for word-level ones.
+         self.asr_pipeline = hf_pipeline(
+             "automatic-speech-recognition",
+             model=asr_model,
+             device=0 if self.device == "cuda" else -1,
+             return_timestamps=True,
+         )
+
+     def run_diarization(self, audio_path):
+         # One dict per speaker turn: {"start": ..., "end": ..., "speaker": ...}.
+         diarization = self.diar_pipeline(audio_path)
+         return [
+             {"start": t.start, "end": t.end, "speaker": spk}
+             for t, _, spk in diarization.itertracks(yield_label=True)
+         ]
+
+     def run_transcription(self, audio_path, diar_json):
+         audio, sr = torchaudio.load(audio_path)
+         merged_segments = []
+         speaker_segments = {}
+
+         for seg in diar_json:
+             start, end, spk = seg["start"], seg["end"], seg["speaker"]
+             start_sample, end_sample = int(start * sr), int(end * sr)
+             # Mono: transcribe the first channel of each diarized turn.
+             chunk = audio[0, start_sample:end_sample].numpy()
+
+             # Denoise the turn, then pass the raw samples together with the
+             # sample rate so the ASR pipeline can resample if needed.
+             reduced = nr.reduce_noise(y=chunk, sr=sr)
+             result = self.asr_pipeline({"raw": reduced, "sampling_rate": sr})
+
+             tokens = []
+             if "chunks" in result:
+                 for word_info in result["chunks"]:
+                     start_ts, end_ts = word_info.get("timestamp", (None, None)) or (None, None)
+                     tokens.append({
+                         "start": start_ts,
+                         "end": end_ts,
+                         "text": word_info["text"],
+                         "tag": "w"
+                     })
+
+             seg_dict = {
+                 "speaker": spk,
+                 "start": start,
+                 "end": end,
+                 "tokens": tokens
+             }
+             merged_segments.append(seg_dict)
+             speaker_segments.setdefault(spk, []).append(seg_dict)
+
+         return merged_segments, list(speaker_segments.keys())
+
+     def run_pipeline(self, audio_path, output_dir=None, base_name=None,
+                      ref_rttm=None, ref_json=None):
+         diar_json = self.run_diarization(audio_path)
+         merged_segments, speakers = self.run_transcription(audio_path, diar_json)
+
+         if output_dir and base_name:
+             os.makedirs(output_dir, exist_ok=True)
+
+             # Save RTTM (10 fields per line, e.g.
+             # "SPEAKER meeting 1 0.500000 2.300000 <NA> <NA> SPEAKER_00 <NA> <NA>")
+             rttm_path = os.path.join(output_dir, f"{base_name}.rttm")
+             with open(rttm_path, "w") as f:
+                 for seg in diar_json:
+                     f.write(
+                         f"SPEAKER {base_name} 1 {seg['start']:.6f} "
+                         f"{seg['end'] - seg['start']:.6f} <NA> <NA> "
+                         f"{seg['speaker']} <NA> <NA>\n"
+                     )
+
+             # Save transcription
+             merged_path = os.path.join(output_dir, f"{base_name}_merged_transcription.json")
+             with open(merged_path, "w") as f:
+                 json.dump(merged_segments, f, indent=2)
+
+         # --- evaluation if references are provided (needs saved outputs) ---
+         eval_results = None
+         if (ref_rttm or ref_json) and output_dir and base_name:
+             eval_results = self.evaluate(output_dir, base_name,
+                                          ref_rttm=ref_rttm, ref_json=ref_json)
+
+         return {
+             "speakers": speakers,
+             "segments": merged_segments,
+             "evaluation": eval_results
+         }
+
+     def evaluate(self, output_dir, base_name, ref_rttm=None, ref_json=None):
+         results = {}
+
+         hyp_rttm = os.path.join(output_dir, f"{base_name}.rttm")
+         hyp_json = os.path.join(output_dir, f"{base_name}_merged_transcription.json")
+
+         if ref_rttm:
+             def load_rttm(path):
+                 ann = Annotation()
+                 with open(path) as f:
+                     for line in f:
+                         if line.startswith("SPEAKER"):
+                             p = line.split()
+                             start, dur, spk = float(p[3]), float(p[4]), p[7]
+                             ann[Segment(start, start + dur)] = spk
+                 return ann
+
+             # Diarization error rate, reported as a percentage.
+             der_score = DiarizationErrorRate()(load_rttm(ref_rttm), load_rttm(hyp_rttm))
+             results["DER"] = round(der_score * 100, 2)
+
+         if ref_json:
+             def load_words(path):
+                 with open(path) as f:
+                     data = json.load(f)
+                 return " ".join(tok["text"] for seg in data for tok in seg["tokens"])
+
+             # WER both raw and after lowercasing/punctuation normalization.
+             ref_text, hyp_text = load_words(ref_json), load_words(hyp_json)
+             transform = Compose([ToLowerCase(), RemovePunctuation(),
+                                  RemoveMultipleSpaces(), Strip()])
+             results["WER_raw"] = round(wer(ref_text, hyp_text), 4)
+             results["WER_normalized"] = round(wer(transform(ref_text), transform(hyp_text)), 4)
+
+         return results if results else None
+
+     def __call__(self, inputs, **kwargs):
+         # Accepts raw bytes, a dict with "audio"/"audio_bytes", or a file path.
+         # Extra kwargs (output_dir, base_name, ref_rttm, ref_json) are
+         # forwarded to run_pipeline.
+         if isinstance(inputs, dict):
+             if "audio_bytes" in inputs:
+                 audio_bytes = inputs["audio_bytes"]
+             elif "audio" in inputs:
+                 audio_bytes = inputs["audio"]
+             else:
+                 raise ValueError("No audio found in inputs")
+         else:
+             audio_bytes = inputs
+
+         # A string is treated as a path to an existing audio file.
+         if isinstance(audio_bytes, str):
+             return self.run_pipeline(audio_bytes, **kwargs)
+
+         # Spill the bytes to a temporary WAV file and clean it up afterwards.
+         with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
+             tmp.write(audio_bytes)
+             tmp_path = tmp.name
+         try:
+             return self.run_pipeline(tmp_path, **kwargs)
+         finally:
+             os.remove(tmp_path)
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ torch
+ torchaudio
+ pyannote.audio
+ transformers
+ noisereduce
+ jiwer
+ librosa
setup.py ADDED
@@ -0,0 +1,16 @@
+ from setuptools import setup, find_packages
+
+ setup(
+     name="asr_diarization",
+     version="0.1.0",
+     packages=find_packages(),
+     install_requires=[
+         "torch",
+         "torchaudio",
+         "pyannote.audio",
+         "transformers",
+         "noisereduce",
+         "jiwer",
+         "librosa"
+     ],
+ )