palli23 committed
Commit 845e97f · 1 Parent(s): 399a407

diarization1Mæló

Files changed (2):
  1. app.py +94 -55
  2. requirements.txt +7 -7
app.py CHANGED
@@ -1,82 +1,121 @@
  import os
  import gradio as gr
  import spaces
- import tempfile
  import torch

- from transformers import pipeline
- from pyannote.audio import Pipeline
-
- # ==========================================================
- # ZeroGPU SAFE GLOBALS FIX — PYANNOTE 3.1 CHECKPOINT COMPAT
- # ==========================================================
- from torch.serialization import add_safe_globals
- from pyannote.audio.core.task import Specifications
- from pyannote.audio.core.model import Model
-
- add_safe_globals({
-     "Specifications": Specifications,
-     "pyannote.audio.core.task.Specifications": Specifications,
-     "Model": Model,
-     "pyannote.audio.core.model.Model": Model,
- })
-
- ASR_MODEL = "palli23/whisper-small-sam_spjall"
- DIAR_MODEL = "pyannote/speaker-diarization-3.1"
-
- @spaces.GPU(duration=120)
- def transcribe_with_diarization(audio_path):
-     if not audio_path:
-         return "Hladdu upp hljóðskrá."
-
-     # ----------------------------
-     # Load diarization pipeline
-     # (NO token argument!)
-     # ----------------------------
-     diarization = Pipeline.from_pretrained(
-         DIAR_MODEL,
-         cache_dir="/home/user/.cache"
-     ).to("cuda")
-
-     diar = diarization(audio_path)
-
-     # ----------------------------
-     # Whisper ASR
-     # ----------------------------
-     asr = pipeline(
-         task="automatic-speech-recognition",
-         model=ASR_MODEL,
-         device=0
-     )
-
-     output_lines = []
-
-     for turn, _, speaker in diar.itertracks(yield_label=True):
-         with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
-             diar.crop(audio_path, turn).export(tmp.name, format="wav")
-             seg_file = tmp.name
-
-         text = asr(seg_file)["text"].strip()
-         output_lines.append(f"[MÆLENDI {speaker}] {text}")
-
-         os.unlink(seg_file)
-
-     return "\n".join(output_lines) or "Enginn texti fannst."
-
- # ==========================================================
- # UI
- # ==========================================================
  with gr.Blocks() as demo:
-     gr.Markdown("# 🎙️ Íslenskt ASR + mælendagreining (ZeroGPU)")
-     audio = gr.Audio(type="filepath", label="Hlaða inn hljóði (.wav / .mp3)")
-     out = gr.Textbox(lines=25, label="Útskrift")
-
-     btn = gr.Button("Transcribe")
-     btn.click(transcribe_with_diarization, inputs=audio, outputs=out)

  demo.launch(auth=("beta", "beta2025"))
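One pitfall in the removed shim above: torch.serialization.add_safe_globals expects a list of allowed globals (classes, or (callable, name) tuples), not a dict keyed by name, so the dict form would fail before pyannote ever loaded a checkpoint. A minimal sketch of the documented allow-list usage, reusing the same imports as the removed code:

    # Allow-list the pyannote classes for torch's weights_only unpickler.
    # add_safe_globals takes a list of classes, not a dict.
    from torch.serialization import add_safe_globals
    from pyannote.audio.core.task import Specifications
    from pyannote.audio.core.model import Model

    add_safe_globals([Specifications, Model])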
 
  import os
  import gradio as gr
  import spaces
+ import webrtcvad
+ import numpy as np
+ from pydub import AudioSegment
+ from sklearn.cluster import AgglomerativeClustering
+ from transformers import pipeline, Wav2Vec2Processor, Wav2Vec2Model
  import torch
+ import tempfile
+
+ ASR_MODEL = "palli23/whisper-small-sam_spjall"
+
+ # Load a wav2vec2 encoder; its mean-pooled hidden states serve as rough
+ # speaker embeddings (the SpeechBrain ECAPA checkpoint is not loadable
+ # via the Wav2Vec2 classes imported above)
+ EMB_MODEL = "facebook/wav2vec2-base-960h"
+ processor = Wav2Vec2Processor.from_pretrained(EMB_MODEL)
+ embedder = Wav2Vec2Model.from_pretrained(EMB_MODEL)
+
+
+ def audio_to_frames(path, frame_ms=30):
+     # Decode to 16 kHz mono int16 and yield consecutive 30 ms frames
+     audio = AudioSegment.from_file(path).set_channels(1).set_frame_rate(16000)
+     samples = np.array(audio.get_array_of_samples()).astype(np.int16)
+     frame_len = int(16000 * frame_ms / 1000)
+     for i in range(0, len(samples), frame_len):
+         yield samples[i:i + frame_len]
+
+
+ def extract_segments(path):
+     vad = webrtcvad.Vad(2)
+     frames = list(audio_to_frames(path))
+
+     segments = []
+     current = []
+
+     for frame in frames:
+         # webrtcvad needs full 10/20/30 ms frames; skip the short tail frame
+         if len(frame) < 480:
+             continue
+
+         is_speech = vad.is_speech(frame.tobytes(), 16000)
+         if is_speech:
+             current.append(frame)
+         else:
+             if current:
+                 segments.append(np.concatenate(current))
+                 current = []
+
+     if current:
+         segments.append(np.concatenate(current))
+
+     return segments
+
+
+ def embed_audio(segment):
+     # wav2vec2 expects float waveforms in [-1, 1]; scale the int16 samples
+     waveform = segment.astype(np.float32) / 32768.0
+     with torch.no_grad():
+         inputs = processor(waveform, sampling_rate=16000, return_tensors="pt")
+         emb = embedder(**inputs).last_hidden_state.mean(dim=1)
+     return emb[0].numpy()
+
+
+ def cluster_speakers(embeddings, max_speakers=5):
+     # Guard: AgglomerativeClustering needs at least two samples
+     if len(embeddings) < 2:
+         return [0] * len(embeddings)
+
+     # max_speakers is currently unused; distance_threshold alone
+     # determines how many clusters emerge
+     X = np.stack(embeddings)
+     clustering = AgglomerativeClustering(
+         n_clusters=None,
+         distance_threshold=1.0
+     ).fit(X)
+
+     return clustering.labels_
+
+
+ asr = pipeline("automatic-speech-recognition",
+                model=ASR_MODEL, device=0)
+
+
+ @spaces.GPU(duration=120)
+ def diarize_and_transcribe(audio_path):
+     if not audio_path:
+         return "Hladdu upp hljóðskrá."
+
+     # --- STEP 1: VAD speech detection ---
+     segments = extract_segments(audio_path)
+     if not segments:
+         return "Engin tala heyrðist í skránni."
+
+     embeddings = [embed_audio(seg) for seg in segments]
+
+     # --- STEP 2: Speaker clustering ---
+     labels = cluster_speakers(embeddings)
+
+     # --- STEP 3: ASR on each segment ---
+     out = []
+     for seg, spk in zip(segments, labels):
+         with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
+             audio = seg.astype(np.int16).tobytes()
+             temp_audio = AudioSegment(
+                 data=audio,
+                 sample_width=2,
+                 frame_rate=16000,
+                 channels=1
+             )
+             temp_audio.export(f.name, format="wav")
+             seg_path = f.name
+
+         txt = asr(seg_path)["text"].strip()
+         out.append(f"[MÆLENDI {spk}] {txt}")
+         os.unlink(seg_path)
+
+     return "\n".join(out)
+
+
+ # --- Gradio UI ---
  with gr.Blocks() as demo:
+     gr.Markdown("# Íslenskt ASR + VAD mælendagreining (WebRTC)")
+     gr.Markdown("Virkar á ZeroGPU\nHladdu upp .mp3 / .wav (allt að 5 mín)")
+
+     audio = gr.Audio(type="filepath")
+     btn = gr.Button("Transcribe með mælendum")
+     out = gr.Textbox(lines=35)
+
+     btn.click(diarize_and_transcribe, inputs=audio, outputs=out)

  demo.launch(auth=("beta", "beta2025"))
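To smoke-test the new VAD, embedding, and clustering stages outside the Space, a minimal sketch; the file name sample.wav is a placeholder, and the helper functions above (extract_segments, embed_audio, cluster_speakers) are assumed to be in scope:

    # Hypothetical local check of the non-ASR stages.
    segments = extract_segments("sample.wav")   # placeholder path
    print(len(segments), "speech segments")
    if len(segments) >= 2:
        labels = cluster_speakers([embed_audio(s) for s in segments])
        print("speaker labels per segment:", labels)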
requirements.txt CHANGED
@@ -1,7 +1,7 @@
- gradio
- transformers
- torch
- spaces
- pyannote.audio
- librosa
- soundfile
+ torch==2.0.1
+ transformers==4.40.2
+ webrtcvad
+ pydub
+ numpy
+ scikit-learn
+ sentencepiece
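A note on the new webrtcvad dependency: it only accepts 16-bit mono PCM at 8/16/32/48 kHz in exact 10/20/30 ms frames, which is why audio_to_frames fixes frame_ms=30 and extract_segments drops short tail frames. A quick sanity check:

    # 30 ms of silence at 16 kHz: 480 samples of 2 bytes each.
    import webrtcvad

    vad = webrtcvad.Vad(2)              # aggressiveness 0 (lenient) to 3 (strict)
    frame = b"\x00\x00" * 480
    print(vad.is_speech(frame, 16000))  # expected False for pure silence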