# DIARIZATION + ASR integration (add to your app.py)
"""Speaker diarization + per-segment ASR helpers.

Pipeline: pyannote.audio assigns speaker turns, pydub slices the audio per
turn, and a Hugging Face ASR pipeline transcribes each slice. Models are
cached at module level so repeated requests don't reload them.
"""
import os
import tempfile

import torch
from pydub import AudioSegment
import soundfile as sf  # kept: may be used elsewhere in app.py
from pyannote.audio import Pipeline  # pip install pyannote.audio
from transformers import pipeline as hf_pipeline

# --- CONFIG ---
DIAR_PYMODEL = "pyannote/speaker-diarization"  # or a specific version
HF_TOKEN = os.environ.get("HF_TOKEN", None)  # set as a secret in Spaces
DEVICE = 0 if torch.cuda.is_available() else -1  # HF pipeline device index

# Lazily-populated model caches (loading these is expensive).
DIAR_PIPE = None
ASR_PIPE_CACHE = {}


def get_diar_pipeline():
    """Return the cached pyannote diarization pipeline, loading it on first use."""
    global DIAR_PIPE
    if DIAR_PIPE is None:
        # Pipeline.from_pretrained will use HF_TOKEN from env automatically
        DIAR_PIPE = Pipeline.from_pretrained(DIAR_PYMODEL, use_auth_token=HF_TOKEN)
    return DIAR_PIPE


def get_asr_pipeline(model_id):
    """Return a cached HF ASR pipeline for *model_id*, creating it on first use."""
    if model_id not in ASR_PIPE_CACHE:
        ASR_PIPE_CACHE[model_id] = hf_pipeline(
            "automatic-speech-recognition", model=model_id, device=DEVICE
        )
    return ASR_PIPE_CACHE[model_id]


def diarize_audio_to_segments(audio_path):
    """Run speaker diarization on *audio_path*.

    Returns a list of segments:
        [{'start': float, 'end': float, 'speaker': 'SPEAKER_00'}, ...]
    """
    pipeline = get_diar_pipeline()
    # pyannote expects 16k mono; Pipeline will resample internally if needed
    diarization = pipeline(audio_path)
    segments = []
    # diarization is a pyannote.core.Annotation
    for turn, _, label in diarization.itertracks(yield_label=True):
        segments.append(
            {"start": float(turn.start), "end": float(turn.end), "speaker": label}
        )
    return segments


def extract_audio_segment(orig_path, start_s, end_s):
    """Write the [start_s, end_s] slice of *orig_path* to a temp WAV file.

    Returns the temp file's path; the caller is responsible for deleting it.
    """
    audio = AudioSegment.from_file(orig_path)
    start_ms, end_ms = int(start_s * 1000), int(end_s * 1000)
    chunk = audio[start_ms:end_ms]
    # mkstemp + close: NamedTemporaryFile(delete=False) would leak an open
    # handle (and exporting to a still-open file fails on Windows).
    fd, tmp_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    chunk.export(tmp_path, format="wav")
    return tmp_path


def diarized_transcribe(audio_path, model_id):
    """Run diarization, then ASR on each speaker turn.

    Returns a list of speaker-attributed segments:
        [{'start': float, 'end': float, 'speaker': str, 'text': str}, ...]
    ASR failures are reported inline in 'text' rather than aborting the run.
    """
    segments = diarize_audio_to_segments(audio_path)
    asr = get_asr_pipeline(model_id)
    speaker_results = []
    for seg in segments:
        seg_path = extract_audio_segment(audio_path, seg["start"], seg["end"])
        try:
            out = asr(seg_path)
            # HF ASR pipelines normally return {"text": ...}; fall back to
            # str() if the return shape differs so we never raise here.
            text = out.get("text", str(out)) if isinstance(out, dict) else str(out)
        except Exception as e:  # best-effort per segment: keep going
            text = f"[ASR error: {e}]"
        finally:
            # Always remove the per-segment temp WAV; ignore missing/locked files.
            try:
                os.unlink(seg_path)
            except OSError:
                pass
        speaker_results.append(
            {
                "start": seg["start"],
                "end": seg["end"],
                "speaker": seg["speaker"],
                "text": text,
            }
        )
    return speaker_results