import os
import tempfile

import torch
from pydub import AudioSegment
import soundfile as sf
from pyannote.audio import Pipeline
from transformers import pipeline as hf_pipeline

# Diarization model on the Hugging Face Hub (gated: requires accepting the model's
# terms and supplying an access token).
DIAR_PYMODEL = "pyannote/speaker-diarization"
HF_TOKEN = os.environ.get("HF_TOKEN", None)

# transformers pipeline device: first GPU if available, otherwise CPU.
DEVICE = 0 if torch.cuda.is_available() else -1

# Module-level caches so each model is loaded at most once per process.
DIAR_PIPE = None
ASR_PIPE_CACHE = {}


def get_diar_pipeline():
    """Load the pyannote diarization pipeline lazily and reuse it across calls."""
    global DIAR_PIPE
    if DIAR_PIPE is None:
        DIAR_PIPE = Pipeline.from_pretrained(DIAR_PYMODEL, use_auth_token=HF_TOKEN)
    return DIAR_PIPE


def get_asr_pipeline(model_id):
    """Return a cached transformers ASR pipeline for the given model id."""
    if model_id in ASR_PIPE_CACHE:
        return ASR_PIPE_CACHE[model_id]
    p = hf_pipeline("automatic-speech-recognition", model=model_id, device=DEVICE)
    ASR_PIPE_CACHE[model_id] = p
    return p


def diarize_audio_to_segments(audio_path):
    """
    Returns list of segments: [{'start': float, 'end': float, 'speaker': 'SPEAKER_00'}, ...]
    """
    pipeline = get_diar_pipeline()
    diarization = pipeline(audio_path)
    segments = []
    # Each track is (segment, track_name, speaker_label); keep start/end times and the label.
    for turn, _, label in diarization.itertracks(yield_label=True):
        segments.append({"start": float(turn.start), "end": float(turn.end), "speaker": label})
    return segments


def extract_audio_segment(orig_path, start_s, end_s):
    """Write the [start_s, end_s] slice of the audio to a temporary WAV file and return its path."""
    audio = AudioSegment.from_file(orig_path)
    start_ms, end_ms = int(start_s * 1000), int(end_s * 1000)
    chunk = audio[start_ms:end_ms]
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
    tmp.close()  # close the handle so pydub can write to the path (required on Windows)
    chunk.export(tmp.name, format="wav")
    return tmp.name


def diarized_transcribe(audio_path, model_id):
    """
    Runs diarization then ASR per speaker segment. Returns a list of speaker-attributed segments.
    """
    segments = diarize_audio_to_segments(audio_path)
    asr = get_asr_pipeline(model_id)

    speaker_results = []
    for seg in segments:
        seg_path = extract_audio_segment(audio_path, seg["start"], seg["end"])
        try:
            out = asr(seg_path)
            text = out.get("text", str(out))
        except Exception as e:
            # Keep going if a single segment fails; record the error in place of the text.
            text = f"[ASR error: {e}]"
        finally:
            # Always remove the temporary per-segment WAV file.
            try:
                os.unlink(seg_path)
            except OSError:
                pass
        speaker_results.append({
            "start": seg["start"],
            "end": seg["end"],
            "speaker": seg["speaker"],
            "text": text,
        })

    return speaker_results
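

# Minimal usage sketch (not part of the original module): the audio path "meeting.wav"
# and the Whisper model id below are illustrative placeholders, not project defaults.
if __name__ == "__main__":
    results = diarized_transcribe("meeting.wav", "openai/whisper-small")
    for seg in results:
        print(f"[{seg['start']:.1f}s-{seg['end']:.1f}s] {seg['speaker']}: {seg['text']}")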