File size: 2,744 Bytes
23a4b9c
 
 
e73b59f
23a4b9c
 
 
 
 
 
 
78084f0
23a4b9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
# DIARIZATION + ASR integration (add to your app.py)
import os
import tempfile
import torch
from pydub import AudioSegment
import soundfile as sf
from pyannote.audio import Pipeline   # pip install pyannote.audio
from transformers import pipeline as hf_pipeline

# --- CONFIG ---
DIAR_PYMODEL = "pyannote/speaker-diarization"   # or a specific version
HF_TOKEN = os.environ.get("HF_TOKEN", None)    # set as a secret in the Space; None means anonymous access
DEVICE = 0 if torch.cuda.is_available() else -1  # HF pipeline convention: GPU index, or -1 for CPU

# create pipelines cache
DIAR_PIPE = None      # lazily-created pyannote diarization pipeline (singleton)
ASR_PIPE_CACHE = {}   # model_id -> HF ASR pipeline, so each model is loaded only once

def get_diar_pipeline():
    """Return the shared pyannote diarization pipeline, loading it on first use."""
    global DIAR_PIPE
    if DIAR_PIPE is not None:
        return DIAR_PIPE
    # First call: instantiate once and memoize in the module-level singleton.
    # HF_TOKEN (possibly None) is forwarded for gated-model authentication.
    DIAR_PIPE = Pipeline.from_pretrained(DIAR_PYMODEL, use_auth_token=HF_TOKEN)
    return DIAR_PIPE

def get_asr_pipeline(model_id):
    """Return a cached HF ASR pipeline for *model_id*, creating it on first request."""
    try:
        return ASR_PIPE_CACHE[model_id]
    except KeyError:
        pipe = hf_pipeline("automatic-speech-recognition", model=model_id, device=DEVICE)
        ASR_PIPE_CACHE[model_id] = pipe
        return pipe

def diarize_audio_to_segments(audio_path):
    """
    Run speaker diarization on *audio_path*.

    Returns list of segments: [{'start': float, 'end': float, 'speaker': 'SPEAKER_00'}, ...]
    """
    # pyannote expects 16k mono; the pipeline resamples internally if needed.
    annotation = get_diar_pipeline()(audio_path)
    # `annotation` is a pyannote.core.Annotation; flatten its labeled tracks
    # into plain dicts ordered as itertracks yields them.
    return [
        {"start": float(turn.start), "end": float(turn.end), "speaker": speaker}
        for turn, _, speaker in annotation.itertracks(yield_label=True)
    ]

def extract_audio_segment(orig_path, start_s, end_s):
    """
    Write the [start_s, end_s] slice (seconds) of *orig_path* to a temp WAV file.

    Returns the temp file's path. The caller is responsible for deleting it
    (diarized_transcribe does so after ASR).
    """
    audio = AudioSegment.from_file(orig_path)
    start_ms, end_ms = int(start_s * 1000), int(end_s * 1000)
    chunk = audio[start_ms:end_ms]
    # Use mkstemp + os.close instead of NamedTemporaryFile(delete=False): the
    # original never closed the NamedTemporaryFile handle, leaking one file
    # descriptor per segment (and on Windows the open handle would prevent
    # pydub from re-opening the path for writing).
    fd, tmp_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    chunk.export(tmp_path, format="wav")
    return tmp_path

def diarized_transcribe(audio_path, model_id):
    """
    Run diarization, then ASR on each speaker segment.

    Returns a list of speaker-attributed segments:
    [{'start': float, 'end': float, 'speaker': str, 'text': str}, ...]

    Per-segment ASR failures are reported inline as '[ASR error: ...]' text
    rather than aborting the whole transcription (deliberate best-effort).
    """
    segments = diarize_audio_to_segments(audio_path)
    asr = get_asr_pipeline(model_id)

    speaker_results = []
    for seg in segments:
        seg_path = extract_audio_segment(audio_path, seg["start"], seg["end"])
        try:
            try:
                out = asr(seg_path)  # HF ASR pipeline returns a dict with "text"
                text = out.get("text", str(out))
            except Exception as e:  # best-effort: keep transcribing other segments
                text = f"[ASR error: {e}]"
            speaker_results.append({
                "start": seg["start"],
                "end": seg["end"],
                "speaker": seg["speaker"],
                "text": text
            })
        finally:
            # Always remove the temp chunk, even if appending raised.
            # Narrow OSError instead of the original bare `except:`, which
            # also swallowed KeyboardInterrupt/SystemExit.
            try:
                os.unlink(seg_path)
            except OSError:
                pass

    return speaker_results