# DIARIZATION + ASR integration (add to your app.py)
import os
import tempfile
import torch
from pydub import AudioSegment
import soundfile as sf  # not used in this snippet; drop it if the rest of app.py doesn't need it
from pyannote.audio import Pipeline # pip install pyannote.audio
from transformers import pipeline as hf_pipeline
# --- CONFIG ---
DIAR_PYMODEL = "pyannote/speaker-diarization"  # or pin a specific version
HF_TOKEN = os.environ.get("HF_TOKEN", None)  # set as a secret in your Space settings
DEVICE = 0 if torch.cuda.is_available() else -1  # HF pipeline device: GPU 0, else CPU
# lazily-initialized pipeline caches
DIAR_PIPE = None
ASR_PIPE_CACHE = {}
def get_diar_pipeline():
    """Lazily load and cache the pyannote diarization pipeline."""
    global DIAR_PIPE
    if DIAR_PIPE is None:
        # gated model: pass the HF token explicitly (accept the model's terms on the Hub first)
        DIAR_PIPE = Pipeline.from_pretrained(DIAR_PYMODEL, use_auth_token=HF_TOKEN)
    return DIAR_PIPE
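
# Note (assumption: pyannote.audio 3.x API): the loaded pipeline runs on CPU by default.
# If the Space has a GPU, it can be moved there after loading, e.g.:
#     if torch.cuda.is_available():
#         get_diar_pipeline().to(torch.device("cuda"))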
def get_asr_pipeline(model_id):
    """Lazily create and cache one HF ASR pipeline per model id."""
    if model_id in ASR_PIPE_CACHE:
        return ASR_PIPE_CACHE[model_id]
    p = hf_pipeline("automatic-speech-recognition", model=model_id, device=DEVICE)
    ASR_PIPE_CACHE[model_id] = p
    return p
def diarize_audio_to_segments(audio_path):
    """
    Returns a list of segments: [{'start': float, 'end': float, 'speaker': 'SPEAKER_00'}, ...]
    """
    pipeline = get_diar_pipeline()
    # pyannote expects 16 kHz mono; the pipeline resamples internally if needed
    diarization = pipeline(audio_path)
    segments = []
    # diarization is a pyannote.core.Annotation
    for turn, _, label in diarization.itertracks(yield_label=True):
        segments.append({"start": float(turn.start), "end": float(turn.end), "speaker": label})
    return segments
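
# Optional helper (a sketch, not part of the original flow): diarization often returns
# many short back-to-back turns from the same speaker; merging them before ASR cuts the
# number of chunks transcribed. `max_gap_s` is an illustrative threshold, not from the source.
def merge_adjacent_segments(segments, max_gap_s=0.5):
    merged = []
    for seg in segments:
        prev = merged[-1] if merged else None
        if prev and prev["speaker"] == seg["speaker"] and seg["start"] - prev["end"] <= max_gap_s:
            prev["end"] = seg["end"]  # extend the previous turn instead of starting a new one
        else:
            merged.append(dict(seg))
    return merged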
def extract_audio_segment(orig_path, start_s, end_s):
    """Cut [start_s, end_s] (seconds) out of orig_path into a temp WAV; the caller deletes the file."""
    audio = AudioSegment.from_file(orig_path)
    start_ms, end_ms = int(start_s * 1000), int(end_s * 1000)
    chunk = audio[start_ms:end_ms]
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
    tmp.close()  # pydub opens the path itself; close our handle to avoid leaking descriptors
    chunk.export(tmp.name, format="wav")
    return tmp.name
def diarized_transcribe(audio_path, model_id):
    """
    Runs diarization, then ASR on each speaker segment.
    Returns a list of speaker-attributed segments.
    """
    segments = diarize_audio_to_segments(audio_path)
    asr = get_asr_pipeline(model_id)
    speaker_results = []
    for seg in segments:
        seg_path = extract_audio_segment(audio_path, seg["start"], seg["end"])
        try:
            out = asr(seg_path)  # HF ASR pipelines return a dict with a "text" key
            text = out.get("text", str(out))
        except Exception as e:
            text = f"[ASR error: {e}]"
        finally:
            try:
                os.unlink(seg_path)
            except OSError:
                pass
        speaker_results.append({
            "start": seg["start"],
            "end": seg["end"],
            "speaker": seg["speaker"],
            "text": text,
        })
    return speaker_results
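
# Minimal usage sketch (assumptions: a local "sample.wav" exists and "openai/whisper-small"
# is the ASR checkpoint; swap in whatever model your app exposes).
if __name__ == "__main__":
    for r in diarized_transcribe("sample.wav", "openai/whisper-small"):
        print(f'[{r["start"]:6.1f}-{r["end"]:6.1f}] {r["speaker"]}: {r["text"]}')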