EYEDOL committed · verified
Commit 23a4b9c · 1 Parent(s): 3c85dab

Create app.py

Files changed (1): app.py ADDED (+77, -0)
# DIARIZATION + ASR integration (add to your app.py)
import os
import tempfile

import torch
from pydub import AudioSegment
from pyannote.audio import Pipeline  # pip install pyannote.audio
from transformers import pipeline as hf_pipeline

# --- CONFIG ---
DIAR_PYMODEL = "pyannote/speaker-diarization"  # or a specific version
HF_TOKEN = os.environ.get("HF_TOKEN", None)  # set as a secret in the Space if required
DEVICE = 0 if torch.cuda.is_available() else -1

# Lazily created pipeline caches so each model loads only once per process
DIAR_PIPE = None
ASR_PIPE_CACHE = {}

def get_diar_pipeline():
    global DIAR_PIPE
    if DIAR_PIPE is None:
        # The gated pyannote models need an HF token whose account has accepted the model terms
        DIAR_PIPE = Pipeline.from_pretrained(DIAR_PYMODEL, use_auth_token=HF_TOKEN)
    return DIAR_PIPE

def get_asr_pipeline(model_id):
    if model_id in ASR_PIPE_CACHE:
        return ASR_PIPE_CACHE[model_id]
    p = hf_pipeline("automatic-speech-recognition", model=model_id, device=DEVICE)
    ASR_PIPE_CACHE[model_id] = p
    return p

def diarize_audio_to_segments(audio_path):
    """
    Returns a list of segments: [{'start': float, 'end': float, 'speaker': 'SPEAKER_00'}, ...]
    """
    pipeline = get_diar_pipeline()
    # pyannote expects 16 kHz mono; the Pipeline resamples internally if needed
    diarization = pipeline(audio_path)
    segments = []
    # diarization is a pyannote.core.Annotation
    for turn, _, label in diarization.itertracks(yield_label=True):
        segments.append({"start": float(turn.start), "end": float(turn.end), "speaker": label})
    return segments

def extract_audio_segment(orig_path, start_s, end_s):
    # Cut [start_s, end_s] out of the original file and write it to a temporary WAV
    audio = AudioSegment.from_file(orig_path)
    start_ms, end_ms = int(start_s * 1000), int(end_s * 1000)
    chunk = audio[start_ms:end_ms]
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
    chunk.export(tmp.name, format="wav")
    return tmp.name

def diarized_transcribe(audio_path, model_id):
    """
    Runs diarization, then ASR on each speaker segment.
    Returns a list of speaker-attributed segments.
    """
    segments = diarize_audio_to_segments(audio_path)
    asr = get_asr_pipeline(model_id)

    speaker_results = []
    for seg in segments:
        seg_path = extract_audio_segment(audio_path, seg["start"], seg["end"])
        try:
            out = asr(seg_path)  # the HF ASR pipeline returns a dict with a "text" key
            text = out.get("text", str(out))
        except Exception as e:
            text = f"[ASR error: {e}]"
        speaker_results.append({
            "start": seg["start"],
            "end": seg["end"],
            "speaker": seg["speaker"],
            "text": text,
        })
        try:
            os.unlink(seg_path)
        except OSError:
            pass

    return speaker_results
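
A minimal usage sketch (not part of this commit), assuming the rest of app.py builds a Gradio UI, as is typical for a Space; the `format_transcript` helper, the widget labels, and the `openai/whisper-small` model id are illustrative assumptions, not something this commit defines:

# Hypothetical wiring sketch: a small Gradio UI around diarized_transcribe.
# Assumes gradio is installed; the default ASR model id is only an example.
import gradio as gr

def format_transcript(results):
    # Render the speaker-attributed segments as plain text lines
    return "\n".join(
        f"[{r['start']:.1f}s - {r['end']:.1f}s] {r['speaker']}: {r['text']}"
        for r in results
    )

def run(audio_path, model_id):
    return format_transcript(diarized_transcribe(audio_path, model_id))

demo = gr.Interface(
    fn=run,
    inputs=[
        gr.Audio(type="filepath", label="Audio"),
        gr.Textbox(value="openai/whisper-small", label="ASR model id"),
    ],
    outputs=gr.Textbox(label="Diarized transcript"),
)

if __name__ == "__main__":
    demo.launch()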