Spaces:
Runtime error
Runtime error
liuyang
Enhance speaker assignment in transcription: Introduced interval overlap calculations and smoothing techniques for improved accuracy in speaker labeling. Added methods for determining dominant speakers and stabilizing segment boundaries.
f800f63
| import spaces | |
| import boto3 | |
| from botocore.exceptions import NoCredentialsError, ClientError | |
| from botocore.client import Config | |
| import os, pathlib | |
CACHE_ROOT = "/home/user/app/cache"  # any folder you own
# Redirect the torch / pyannote / Hugging Face / matplotlib caches into CACHE_ROOT.
_CACHE_DIRS = {
    "TORCH_HOME": f"{CACHE_ROOT}/torch",
    "XDG_CACHE_HOME": f"{CACHE_ROOT}/xdg",  # torch fallback
    "PYANNOTE_CACHE": f"{CACHE_ROOT}/pyannote",
    "HF_HOME": f"{CACHE_ROOT}/huggingface",
    "TRANSFORMERS_CACHE": f"{CACHE_ROOT}/transformers",
    "MPLCONFIGDIR": f"{CACHE_ROOT}/mpl",
}
os.environ.update(_CACHE_DIRS)
INITIAL_PROMPT = '''
Use normal punctuation; end sentences properly.
'''
# Make sure the cache directories exist. Only the entries set above are created:
# iterating os.environ.values() tried to mkdir the value of *every* environment
# variable (PATH, HOME, tokens, ...), which can raise PermissionError or leave
# junk directories behind.
for _cache_dir in _CACHE_DIRS.values():
    pathlib.Path(_cache_dir).mkdir(parents=True, exist_ok=True)
| import gradio as gr | |
| import torch | |
| import torchaudio | |
| import numpy as np | |
| import pandas as pd | |
| import time | |
| import datetime | |
| import re | |
| import subprocess | |
| import os | |
| import tempfile | |
| import spaces | |
| from faster_whisper import WhisperModel, BatchedInferencePipeline | |
| from faster_whisper.vad import VadOptions | |
| import requests | |
| import base64 | |
| from pyannote.audio import Pipeline, Inference, Model | |
| from pyannote.core import Segment | |
| import os, sys, importlib.util, pathlib, ctypes, tempfile, wave, math | |
| import json | |
| import webrtcvad | |
# Pre-load cuDNN's CNN library from the nvidia-cudnn-cu12 wheel with RTLD_GLOBAL so
# its symbols are visible to native libraries loaded later (presumably the
# CTranslate2 backend of faster-whisper — confirm). Exits the process on failure.
spec = importlib.util.find_spec("nvidia.cudnn")
if spec is None:
    sys.exit("❌ nvidia-cudnn-cu12 wheel not found. Run: pip install nvidia-cudnn-cu12")
# spec.origin is None when nvidia.cudnn is a namespace package (no __init__.py);
# fall back to the package's search location in that case.
if spec.origin is not None:
    cudnn_pkg_dir = pathlib.Path(spec.origin).parent
else:
    cudnn_pkg_dir = pathlib.Path(next(iter(spec.submodule_search_locations)))
cudnn_dir = cudnn_pkg_dir / "lib"
cnn_so = cudnn_dir / "libcudnn_cnn.so.9"
try:
    # str(): a path-like `name` is only documented as accepted by CDLL from 3.12.
    ctypes.CDLL(str(cnn_so), mode=ctypes.RTLD_GLOBAL)
    print(f"✅ Pre-loaded {cnn_so}")
except OSError as e:
    sys.exit(f"❌ Could not load {cnn_so} : {e}")
# Cloudflare R2 (S3-compatible) connection settings, read from the environment.
S3_ENDPOINT = os.getenv("S3_ENDPOINT")
S3_ACCESS_KEY = os.getenv("S3_ACCESS_KEY")
S3_SECRET_KEY = os.getenv("S3_SECRET_KEY")


# Function to upload file to Cloudflare R2
def upload_data_to_r2(data, bucket_name, object_name, content_type='application/octet-stream'):
    """Store an in-memory payload as one object in a Cloudflare R2 bucket.

    :param data: Payload, ``bytes`` or ``str`` (a ``str`` is UTF-8 encoded first).
    :param bucket_name: Target R2 bucket.
    :param object_name: Key to store the payload under.
    :param content_type: MIME type recorded on the object.
    :return: ``True`` on success; ``False`` on any failure (the reason is printed).
    """
    try:
        payload = data.encode('utf-8') if isinstance(data, str) else data
        # Virtual-host addressing, unsigned payloads, and checksums only when the
        # API requires them.
        client_config = Config(
            s3={"addressing_style": "virtual", 'payload_signing_enabled': False},
            signature_version='v4',
            request_checksum_calculation='when_required',
            response_checksum_validation='when_required',
        )
        r2 = boto3.session.Session().client(
            's3',
            endpoint_url=f'https://{S3_ENDPOINT}',
            aws_access_key_id=S3_ACCESS_KEY,
            aws_secret_access_key=S3_SECRET_KEY,
            config=client_config,
        )
        r2.put_object(
            Bucket=bucket_name,
            Key=object_name,
            Body=payload,
            ContentType=content_type,
            ContentLength=len(payload),  # explicit length so the body is not streamed
        )
        print(f"Data uploaded to R2 bucket '{bucket_name}' as '{object_name}'")
        return True
    except NoCredentialsError:
        print("Credentials not available")
    except ClientError as e:
        print(f"Failed to upload data to R2 bucket: {e}")
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
    return False
| from huggingface_hub import snapshot_download | |
# Whisper checkpoint (CTranslate2 format) and its local download directory.
MODEL_REPO = "deepdml/faster-whisper-large-v3-turbo-ct2"
LOCAL_DIR = f"{CACHE_ROOT}/whisper_turbo"

# -----------------------------------------------------------------------------
# Audio preprocessing tunables (used by the VAD head/tail trimming helpers)
# -----------------------------------------------------------------------------
TRIM_THRESHOLD_MS = 10 * 1000  # trim only when head/tail silence reaches 10 s
DEFAULT_PAD_MS = 250           # context kept around detected speech when trimming
FRAME_MS = 30                  # VAD frame length in milliseconds
HANG_MS = 240                  # hangover: keep speech "on" this long into silence
VAD_LEVEL = 2                  # webrtcvad aggressiveness, 0-3
def _decode_chunk_to_pcm(task: dict) -> bytes:
    """Decode one chunk of ``task["source_uri"]`` to raw mono 16 kHz s16le PCM.

    Reads:
      * ``task["source_uri"]``: ffmpeg input (path or URL).
      * ``task["ingest_recipe"]``: optional ``"channel_extract_filter"``, passed to ``-af``.
      * ``task["ffmpeg_seek"]``: ``"pre_ss_sec"`` (input-side seek, clamped at 0),
        ``"post_ss_sec"`` (output-side seek) and ``"t_sec"`` (duration).

    Returns:
        PCM bytes: 2 bytes per sample, one channel, 16 000 samples per second.

    Raises:
        RuntimeError: if ffmpeg exits non-zero (message carries ffmpeg's stderr).
    """
    src = task["source_uri"]
    recipe = task["ingest_recipe"]
    seek = task["ffmpeg_seek"]
    cmd = [
        "ffmpeg", "-nostdin", "-hide_banner", "-v", "error",
        # -ss before -i is an input option: seek in the input before decoding.
        "-ss", f"{max(0.0, float(seek['pre_ss_sec'])):.3f}",
        "-i", src,
        "-map", "0:a:0",
        # -ss after -i is an output option: decode and discard up to this point.
        # Millisecond precision like the other time arguments (was .2f = 10 ms).
        "-ss", f"{float(seek['post_ss_sec']):.3f}",
        "-t", f"{float(seek['t_sec']):.3f}",
    ]
    # Optional L/R channel extraction.
    if recipe.get("channel_extract_filter"):
        cmd += ["-af", recipe["channel_extract_filter"]]
    # Force mono 16 kHz s16le to stdout.
    cmd += ["-ar", "16000", "-ac", "1", "-c:a", "pcm_s16le", "-f", "s16le", "pipe:1"]
    proc = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, check=False)
    if proc.returncode != 0:
        raise RuntimeError(f"ffmpeg failed: {proc.stderr.decode('utf-8', 'ignore')}")
    return proc.stdout
def _find_head_tail_speech_ms(
    pcm: bytes,
    sr: int = 16000,
    frame_ms: int = FRAME_MS,
    vad_level: int = VAD_LEVEL,
    hang_ms: int = HANG_MS,
):
    """Return (first_ms, last_ms) speech boundaries using webrtcvad with hangover.

    ``pcm`` is raw 16-bit (2 bytes/sample) mono PCM at ``sr`` Hz. It is scanned in
    consecutive ``frame_ms`` frames; a trailing partial frame is ignored.
    NOTE(review): webrtcvad only accepts certain frame lengths / sample rates
    (10/20/30 ms) — confirm callers respect that.

    Returns:
        ``(first_ms, last_ms)`` in ms from the start of ``pcm``:
        ``first_ms`` is the start of the first frame classified as speech.
        ``last_ms`` is where the last speech run was closed by the hangover logic,
        or the end of the scanned frames if speech was still active there.
        ``(None, None)`` when ``pcm`` is empty, shorter than one frame, or no frame
        is classified as speech.
    """
    if not pcm:
        return None, None
    vad = webrtcvad.Vad(int(vad_level))
    bpf = 2  # bytes per sample (s16)
    samples_per_ms = sr // 1000  # 16 at the default 16 kHz
    bytes_per_frame = samples_per_ms * bpf * frame_ms
    n_frames = len(pcm) // bytes_per_frame
    if n_frames == 0:
        return None, None
    first_ms, last_ms = None, None
    t_ms = 0  # start time of the current frame
    in_speech = False
    silence_run = 0  # ms of consecutive non-speech since speech was last seen
    # Zero-copy view over whole frames only.
    view = memoryview(pcm)[: n_frames * bytes_per_frame]
    for i in range(n_frames):
        frame = view[i * bytes_per_frame : (i + 1) * bytes_per_frame]
        if vad.is_speech(frame, sr):
            if first_ms is None:
                first_ms = t_ms
            in_speech = True
            silence_run = 0
        else:
            if in_speech:
                silence_run += frame_ms
                if silence_run >= hang_ms:
                    # Close the speech run. Because t_ms is the *start* of the
                    # current frame, this equals (start of the silence run)
                    # + hang_ms - frame_ms.
                    last_ms = t_ms - (silence_run - hang_ms)
                    in_speech = False
                    silence_run = 0
        t_ms += frame_ms
    # Speech still active at the end: extend to the end of the scanned frames.
    if in_speech:
        last_ms = t_ms
    return first_ms, last_ms
def _write_wav(path: str, pcm: bytes, sr: int = 16000):
    """Write raw mono 16-bit ``pcm`` to ``path`` as a WAV file sampled at ``sr`` Hz.

    Missing parent directories are created. A bare filename (no directory part)
    is written relative to the current working directory.
    """
    parent = os.path.dirname(path)
    if parent:  # os.makedirs("") raises FileNotFoundError for bare filenames
        os.makedirs(parent, exist_ok=True)
    with wave.open(path, "wb") as w:
        w.setnchannels(1)
        w.setsampwidth(2)  # s16
        w.setframerate(sr)
        w.writeframes(pcm)
def prepare_and_save_audio_for_model(task: dict, out_dir: str) -> dict:
    """Decode, optionally silence-trim, and save one chunk as a 16 kHz mono WAV.

    1) Decode the chunk to mono 16 kHz PCM (``_decode_chunk_to_pcm``).
    2) Run VAD (``_find_head_tail_speech_ms``) to locate head/tail silence.
    3) Trim only if head or tail silence >= TRIM_THRESHOLD_MS. The kept span is
       padded by DEFAULT_PAD_MS on each side and clamped to the chunk.
    4) Save the (possibly trimmed) WAV into ``out_dir``.
    5) Return timing metadata. ``abs_start_ms`` is the global start of the saved
       file, so WAV-relative timestamps can be shifted back to global time.

    Reads ``task["channel"]``, ``task["chunk"]["idx" / "dur_ms" / "global_offset_ms"]``
    and, optionally, ``task["job_id"]`` (default ``"job"``).

    Returns:
        A dict with the same keys on every path. ``skip`` is True when the chunk
        holds no speech or no audio and should not be transcribed.
    """
    # 0) Names & constants
    sr = 16000
    bpf = 2  # bytes per s16 sample
    samples_per_ms = sr // 1000

    def bytes_from_ms(ms: int) -> int:
        # Whole samples only, so slice boundaries stay 2-byte aligned.
        return int(ms * samples_per_ms) * bpf

    ch = task["channel"]
    ck = task["chunk"]
    job = task.get("job_id", "job")
    idx = str(ck["idx"])
    global_offset_ms = int(ck["global_offset_ms"])

    # 1) Decode chunk
    pcm = _decode_chunk_to_pcm(task)
    # NOTE(review): silence lengths are measured against the *planned* chunk
    # length, not len(pcm); these can differ if the source ends early.
    planned_dur_ms = int(ck["dur_ms"])

    # 2) VAD head/tail detection
    first_ms, last_ms = _find_head_tail_speech_ms(pcm, sr=sr)
    head_sil_ms = int(first_ms) if first_ms is not None else planned_dur_ms
    tail_sil_ms = int(planned_dur_ms - last_ms) if last_ms is not None else planned_dur_ms

    def build_meta(out_wav_path: str, trim_applied: bool, eff_start_ms: int,
                   eff_dur_ms: int, skip: bool) -> dict:
        # Single schema for every return path. The no-speech path previously
        # omitted "job_id" and returned an un-cast "abs_start_ms".
        return {
            "out_wav_path": out_wav_path,
            "sr": sr,
            "trim_applied": trim_applied,
            "trimmed_start_ms": eff_start_ms if trim_applied else 0,
            "head_silence_ms": head_sil_ms,
            "tail_silence_ms": tail_sil_ms,
            "effective_start_ms": eff_start_ms,
            "effective_dur_ms": eff_dur_ms,
            "abs_start_ms": global_offset_ms + eff_start_ms,
            "chunk_idx": idx,
            "channel": ch,
            "job_id": job,
            "skip": skip,
        }

    # 3) Decide trimming (only if head or tail >= threshold)
    trim_applied = False
    eff_start_ms = 0
    eff_end_ms = planned_dur_ms
    trimmed_pcm = pcm
    if head_sil_ms >= TRIM_THRESHOLD_MS or tail_sil_ms >= TRIM_THRESHOLD_MS:
        # No speech found at all: write an empty placeholder and mark as skip.
        if first_ms is None or last_ms is None or last_ms <= first_ms:
            out_wav_path = os.path.join(out_dir, f"{job}_{ch}_{idx}_nospeech.wav")
            _write_wav(out_wav_path, b"", sr)
            return build_meta(out_wav_path, False, 0, 0, True)
        # Apply padding & slice
        start_ms = max(0, int(first_ms) - DEFAULT_PAD_MS)
        end_ms = min(planned_dur_ms, int(last_ms) + DEFAULT_PAD_MS)
        if end_ms > start_ms:
            eff_start_ms = start_ms
            eff_end_ms = end_ms
            trimmed_pcm = pcm[bytes_from_ms(start_ms) : bytes_from_ms(end_ms)]
            trim_applied = True

    # 4) Write WAV to local file (trimmed or original)
    tag = "trim" if trim_applied else "full"
    out_wav_path = os.path.join(out_dir, f"{job}_{ch}_{idx}_{tag}.wav")
    _write_wav(out_wav_path, trimmed_pcm, sr)

    # 5) Return metadata; skip only when nothing was decoded and nothing trimmed.
    skip = not (trim_applied or len(pcm) > 0)
    return build_meta(out_wav_path, trim_applied, eff_start_ms, eff_end_ms - eff_start_ms, skip)
# Fetch the Whisper checkpoint into LOCAL_DIR at import time; later runs reuse it.
# NOTE(review): recent huggingface_hub releases deprecate/ignore both flags below
# — confirm against the pinned version before relying on them.
snapshot_download(
    repo_id=MODEL_REPO,
    local_dir=LOCAL_DIR,
    local_dir_use_symlinks=True,
    resume_download=True,
)
# Directory handed to WhisperModel as its model path.
model_cache_path = LOCAL_DIR

# Lazily populated model singletons -------------------------------------------
_whisper = None           # WhisperModel, built by _load_models()
_batched_whisper = None   # BatchedInferencePipeline wrapping _whisper
_diarizer = None          # pyannote diarization pipeline (None if loading failed)
_embedder = None          # pyannote embedding Inference, built on first use
def _init_diarizer():
    """Build the pyannote 3.1 diarization pipeline on CUDA; return None on any failure.

    Also enables TF32 matmul/cuDNN paths and sets float32 matmul precision to
    'high' before loading.
    """
    try:
        print("Loading diarization model...")
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        torch.set_float32_matmul_precision('high')
        pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-3.1",
            use_auth_token=os.getenv("HF_TOKEN"),
        ).to(torch.device("cuda"))
        print("Diarization model loaded successfully")
        return pipeline
    except Exception as e:
        import traceback
        traceback.print_exc()
        print(f"Could not load diarization model: {e}")
        return None


# Create global diarization pipeline (None when loading failed).
_diarizer = _init_diarizer()
# NOTE(review): original comment says a GPU is guaranteed inside this function;
# no GPU decorator is visible here — confirm how callers obtain the GPU.
def _load_models():
    """Return ``(whisper, batched_whisper, diarizer)``, creating Whisper on first call.

    The fp16 CUDA WhisperModel and its BatchedInferencePipeline are cached in
    module globals. The diarizer is returned as loaded at import time (possibly
    None); it is never (re)loaded here.
    """
    global _whisper, _batched_whisper
    if _whisper is not None:
        return _whisper, _batched_whisper, _diarizer
    print("Loading Whisper model...")
    _whisper = WhisperModel(model_cache_path, device="cuda", compute_type="float16")
    # Batched pipeline used for improved throughput.
    _batched_whisper = BatchedInferencePipeline(model=_whisper)
    print("Whisper model and batched pipeline loaded successfully")
    return _whisper, _batched_whisper, _diarizer
# -----------------------------------------------------------------------------
class WhisperTranscriber:
    """Chunk-level transcription (faster-whisper) and speaker diarization (pyannote).

    Models are deliberately not created in ``__init__``. Methods obtain them lazily
    through the module-level ``_load_models()`` and ``self._load_embedder()``.
    """

    # Minimum audio length fed to the embedding model; shorter clips are zero-padded.
    _MIN_EMBED_DURATION_SEC = 3.0

    def __init__(self):
        # Do **not** create the models here; they are loaded lazily on first use.
        pass

    def preprocess_from_task_json(self, task_json: str) -> dict:
        """Parse a task JSON string and run ``prepare_and_save_audio_for_model``.

        The WAV is written under ``CACHE_ROOT/preprocessed``.

        Raises:
            RuntimeError: if ``task_json`` cannot be parsed as JSON.
        """
        try:
            task = json.loads(task_json)
        except Exception as e:
            raise RuntimeError(f"Invalid JSON: {e}") from e
        out_dir = os.path.join(CACHE_ROOT, "preprocessed")
        os.makedirs(out_dir, exist_ok=True)
        return prepare_and_save_audio_for_model(task, out_dir)

    def transcribe_full_audio(self, audio_path, language=None, translate=False, prompt=None, batch_size=16, base_offset_s: float = 0.0):
        """Transcribe a whole audio file without diarization.

        Uses the batched pipeline when ``batch_size > 1``, else the plain model.
        Every segment/word timestamp is shifted by ``base_offset_s`` (seconds).

        Returns:
            ``(segments, detected_language)``. Each segment dict has start, end,
            text, speaker ("SPEAKER_00"), avg_logprob, words and duration.
        """
        whisper, batched_whisper, _ = _load_models()
        print(f"Transcribing full audio with batch size {batch_size}...")
        start_time = time.time()
        offset = float(base_offset_s)
        options = dict(
            language=language,
            beam_size=5,
            vad_filter=True,
            vad_parameters=VadOptions(
                max_speech_duration_s=whisper.feature_extractor.chunk_length,
                min_speech_duration_ms=150,  # ignore ultra-short blips
                min_silence_duration_ms=150,  # split on short (e.g. Mandarin) pauses
                # Previously swallowed by the trailing comment on the line above,
                # so the library default padding applied instead of 100 ms.
                speech_pad_ms=100,
                threshold=0.25,
                neg_threshold=0.2,
            ),
            word_timestamps=True,
            initial_prompt=prompt,
            condition_on_previous_text=False,  # avoid runaway context
            language_detection_segments=1,
            task="translate" if translate else "transcribe",
        )
        if batch_size > 1:
            segments, transcript_info = batched_whisper.transcribe(
                audio_path, batch_size=batch_size, **options
            )
        else:
            segments, transcript_info = whisper.transcribe(audio_path, **options)
        # `segments` is lazy; materializing it runs the actual decoding.
        segments = list(segments)
        detected_language = transcript_info.language
        print("Detected language: ", detected_language, "segments: ", len(segments))

        results = []
        for seg in segments:
            words_list = [
                {
                    "start": float(word.start) + offset,
                    "end": float(word.end) + offset,
                    "word": word.word,
                    "probability": word.probability,
                    "speaker": "SPEAKER_00",  # no speaker identification here
                }
                for word in (seg.words or [])
            ]
            results.append({
                "start": float(seg.start) + offset,
                "end": float(seg.end) + offset,
                "text": seg.text,
                "speaker": "SPEAKER_00",  # single speaker assumption
                "avg_logprob": seg.avg_logprob,
                "words": words_list,
                "duration": float(seg.end - seg.start),
            })
        transcription_time = time.time() - start_time
        print(f"Full audio transcribed in {transcription_time:.2f} seconds using batch size {batch_size}")
        print(results)
        return results, detected_language

    @staticmethod
    def _pad_to_min_samples(wav, min_samples):
        """Right-pad a (channels, time) tensor with zeros to at least ``min_samples``."""
        if wav.shape[1] >= min_samples:
            return wav
        pad = torch.zeros(wav.shape[0], min_samples - wav.shape[1], dtype=wav.dtype, device=wav.device)
        return torch.cat([wav, pad], dim=1)

    @staticmethod
    def _embedding_to_list(emb):
        """Convert a torch tensor or numpy-like embedding into a plain Python list."""
        if isinstance(emb, torch.Tensor):
            return emb.detach().cpu().tolist()
        return np.asarray(emb).tolist()

    def _single_speaker_diarization(self, audio_path, offset):
        """Fallback when no diarizer is loaded: one SPEAKER_00 turn spanning the file."""
        print("Diarization model not available, creating single speaker segment")
        waveform, sample_rate = torchaudio.load(audio_path)
        duration = waveform.shape[1] / sample_rate
        speaker_embeddings = {}
        try:
            embedder = self._load_embedder()
            min_samples = int(self._MIN_EMBED_DURATION_SEC * sample_rate)
            padded = self._pad_to_min_samples(waveform, min_samples)
            emb = embedder({"waveform": padded, "sample_rate": sample_rate})
            speaker_embeddings["SPEAKER_00"] = emb.squeeze().tolist()
        except Exception as e:
            # Best effort: the embedding is optional, so log instead of failing.
            print(f"Could not compute single-speaker embedding: {e}")
        return [{
            "start": 0.0 + offset,
            "end": duration + offset,
            "speaker": "SPEAKER_00",
        }], 1, speaker_embeddings

    def _speaker_embedding_centroids(self, audio_path, waveform, sample_rate, diarization_list, speakers):
        """Average per-turn embeddings into one L2-normalized centroid per speaker.

        Turns are sliced from the in-memory ``waveform`` using local (un-offset)
        times and zero-padded to the minimum embedding duration. If the embedder
        rejects a slice, a window of that duration is cropped from ``audio_path``.

        Returns:
            ``{speaker: list_of_floats}``; speakers with no usable turn are omitted.
        """
        embedder = self._load_embedder()
        spk_to_embs = {spk: [] for spk in speakers}
        min_dur = self._MIN_EMBED_DURATION_SEC
        min_samples = int(min_dur * sample_rate)
        n_samples = waveform.shape[1]
        audio_duration_sec = float(n_samples) / float(sample_rate)
        for turn, _, speaker in diarization_list:
            seg_start = float(turn.start)
            seg_end = float(turn.end)
            if seg_end <= seg_start:
                continue
            start_sample = max(0, int(seg_start * sample_rate))
            end_sample = min(n_samples, int(seg_end * sample_rate))
            if end_sample <= start_sample:
                continue
            seg_wav = self._pad_to_min_samples(
                waveform[:, start_sample:end_sample].contiguous(), min_samples
            )
            try:
                emb = embedder({"waveform": seg_wav, "sample_rate": sample_rate})
            except Exception:
                # Fallback: crop from the file, widened to the minimum duration.
                desired_end = min(seg_start + min_dur, audio_duration_sec)
                desired_start = max(0.0, desired_end - min_dur)
                emb = embedder.crop(audio_path, Segment(desired_start, desired_end))
            spk_to_embs[speaker].append(emb.squeeze())

        centroids = {}
        for spk, embs in spk_to_embs.items():
            if not embs:
                continue
            try:
                stacked = torch.stack([torch.as_tensor(e) for e in embs], dim=0)
                centroid = stacked.mean(dim=0)
                centroid = centroid / (centroid.norm(p=2) + 1e-12)  # L2 normalize
                centroids[spk] = centroid.cpu().tolist()
            except Exception:
                # Fall back to the first embedding. It may be a numpy array
                # (no .cpu()), so convert generically.
                centroids[spk] = self._embedding_to_list(embs[0])
        return centroids

    def perform_diarization(self, audio_path, num_speakers=None, base_offset_s: float = 0.0):
        """Diarize ``audio_path``; return global-time segments and per-speaker embeddings.

        Returns:
            ``(segments, detected_num_speakers, speaker_embeddings)``. ``segments``
            are ``{"start", "end", "speaker"}`` dicts shifted by ``base_offset_s``.
            ``speaker_embeddings`` maps speaker -> list of floats and may be empty
            if embedding computation fails (the error is printed).
        """
        _, _, diarizer = _load_models()
        offset = float(base_offset_s)
        if diarizer is None:
            return self._single_speaker_diarization(audio_path, offset)

        print("Starting diarization...")
        start_time = time.time()
        waveform, sample_rate = torchaudio.load(audio_path)
        diarization = diarizer(
            {"waveform": waveform, "sample_rate": sample_rate},
            num_speakers=num_speakers,
        )
        diarization_list = list(diarization.itertracks(yield_label=True))
        diarize_segments = [
            {
                "start": float(turn.start) + offset,
                "end": float(turn.end) + offset,
                "speaker": speaker,
            }
            for turn, _, speaker in diarization_list
        ]
        unique_speakers = {seg["speaker"] for seg in diarize_segments}
        detected_num_speakers = len(unique_speakers)

        speaker_embeddings = {}
        try:
            speaker_embeddings = self._speaker_embedding_centroids(
                audio_path, waveform, sample_rate, diarization_list, unique_speakers
            )
        except Exception as e:
            print(f"Error during embedding calculation: {e}")
            print(f"Diarization segments: {diarize_segments}")

        diarization_time = time.time() - start_time
        print(f"Diarization completed in {diarization_time:.2f} seconds")
        return diarize_segments, detected_num_speakers, speaker_embeddings

    def _load_embedder(self):
        """Lazy-load the pyannote speaker-embedding Inference model on CUDA (cached globally)."""
        global _embedder
        if _embedder is None:
            # window="whole": one embedding per provided chunk.
            token = os.getenv("HF_TOKEN")
            model = Model.from_pretrained("pyannote/embedding", use_auth_token=token)
            _embedder = Inference(model, window="whole", device=torch.device("cuda"))
        return _embedder

    def assign_speakers_to_transcription(self, transcription_results, diarization_segments):
        """Label words and segments with diarization speakers (mutates and returns the input).

        Per segment with words:
          1) each word gets the speaker with maximum time overlap (nearest turn
             to its midpoint when nothing overlaps);
          2) an isolated word whose two neighbors agree takes their speaker;
          3) the dominant speaker is the one with the largest summed word duration;
          4) leading/trailing non-dominant runs no longer than
             min(0.5 s, 20 % of the segment) are relabeled to the dominant speaker;
          5) the segment's speaker is the dominant speaker.
        Segments without words take the max-overlap speaker for their interval.
        If ``diarization_segments`` is empty, the input is returned unchanged.
        """
        if not diarization_segments:
            return transcription_results
        turns = [(float(d["start"]), float(d["end"]), d["speaker"]) for d in diarization_segments]

        def speaker_at(t: float):
            # Speaker of the turn containing t, else of the nearest turn.
            for d_start, d_end, spk in turns:
                if d_start <= t < d_end:
                    return spk
            closest = None
            best_dist = float("inf")
            for turn in turns:
                d_start, d_end, _ = turn
                if t < d_start:
                    dist = d_start - t
                elif t > d_end:
                    dist = t - d_end
                else:
                    dist = 0.0
                if dist < best_dist:
                    best_dist = dist
                    closest = turn
            return closest[2] if closest else "SPEAKER_00"

        def best_speaker_for_interval(start_t: float, end_t: float) -> str:
            best_spk = None
            best_ov = -1.0
            for d_start, d_end, spk in turns:
                ov = max(0.0, min(end_t, d_end) - max(start_t, d_start))
                if ov > best_ov:
                    best_ov = ov
                    best_spk = spk
            if best_ov > 0.0 and best_spk is not None:
                return best_spk
            return speaker_at((start_t + end_t) / 2.0)

        def word_dur(w) -> float:
            return max(0.0, float(w["end"]) - float(w["start"]))

        max_boundary_sec = 0.5  # hard cap for how much to relabel at edges
        max_boundary_frac = 0.2  # or up to 20% of the segment duration

        for seg in transcription_results:
            words = seg.get("words")
            if not words:
                seg["speaker"] = best_speaker_for_interval(float(seg["start"]), float(seg["end"]))
                continue

            # 1) Initial assignment by overlap
            for w in words:
                w["speaker"] = best_speaker_for_interval(float(w["start"]), float(w["end"]))

            # 2) Fix isolated outliers (decisions use the unsmoothed labels)
            if len(words) >= 3:
                smoothed = [w["speaker"] for w in words]
                for i in range(1, len(words) - 1):
                    prev_spk = words[i - 1]["speaker"]
                    if prev_spk == words[i + 1]["speaker"] and words[i]["speaker"] != prev_spk:
                        smoothed[i] = prev_spk
                for w, spk in zip(words, smoothed):
                    w["speaker"] = spk

            # 3) Dominant speaker by summed word durations
            speaker_dur = {}
            for w in words:
                spk = w.get("speaker", "SPEAKER_00")
                speaker_dur[spk] = speaker_dur.get(spk, 0.0) + word_dur(w)
            if speaker_dur:
                dominant_speaker = max(speaker_dur.items(), key=lambda kv: kv[1])[0]
            else:
                dominant_speaker = speaker_at((float(seg["start"]) + float(seg["end"])) / 2.0)

            # 4) Boundary stabilization
            seg_duration = max(1e-6, float(seg["end"]) - float(seg["start"]))
            limit = min(max_boundary_sec, max_boundary_frac * seg_duration)

            prefix_dur, prefix_count = 0.0, 0
            for w in words:
                if w.get("speaker") == dominant_speaker:
                    break
                prefix_dur += word_dur(w)
                prefix_count += 1
            if prefix_count > 0 and prefix_dur <= limit:
                for w in words[:prefix_count]:
                    w["speaker"] = dominant_speaker

            suffix_dur, suffix_count = 0.0, 0
            for w in reversed(words):
                if w.get("speaker") == dominant_speaker:
                    break
                suffix_dur += word_dur(w)
                suffix_count += 1
            if suffix_count > 0 and suffix_dur <= limit:
                for w in words[len(words) - suffix_count:]:
                    w["speaker"] = dominant_speaker

            # 5) Final segment speaker
            seg["speaker"] = dominant_speaker
        return transcription_results

    def group_segments_by_speaker(self, segments, max_gap=1.0, max_duration=30.0):
        """Merge consecutive same-speaker segments into new dicts (input is not mutated).

        A segment joins the current group when the speaker matches, the gap is at
        most ``max_gap`` seconds, the group is shorter than ``max_duration``
        seconds, and the group's text does not already end a sentence. ASCII
        ``.!?`` and full-width ``。！？`` count as sentence ends; trailing
        whitespace is ignored. Whitespace in the merged text is collapsed.
        """
        if not segments:
            return segments
        sentence_end_pattern = r"[.!?。！？]+"

        def clone(seg):
            # Copy the words list as well; a shallow .copy() shares it, and
            # extending it would mutate the caller's segment.
            return dict(seg, words=list(seg.get("words") or []))

        grouped_segments = []
        current_group = clone(segments[0])
        for segment in segments[1:]:
            time_gap = segment["start"] - current_group["end"]
            current_duration = current_group["end"] - current_group["start"]
            can_combine = (
                segment["speaker"] == current_group["speaker"]
                and time_gap <= max_gap
                and current_duration < max_duration
                and not re.search(sentence_end_pattern, current_group["text"].rstrip()[-1:])
            )
            if can_combine:
                current_group["end"] = segment["end"]
                current_group["text"] += " " + segment["text"]
                current_group["words"].extend(segment.get("words") or [])
                current_group["duration"] = current_group["end"] - current_group["start"]
            else:
                grouped_segments.append(current_group)
                current_group = clone(segment)
        grouped_segments.append(current_group)

        for segment in grouped_segments:
            segment["text"] = re.sub(r"\s+", " ", segment["text"]).strip()
        return grouped_segments

    @staticmethod
    def _remove_preprocessed_wav(pre_meta):
        """Best-effort deletion of the WAV written by preprocessing, if any."""
        path = (pre_meta or {}).get("out_wav_path")
        if path and os.path.exists(path):
            try:
                os.unlink(path)
            except OSError:
                pass

    def process_audio_full(self, task_json, language=None, translate=False, prompt=None, group_segments=True, batch_size=16):
        """Process a single chunk using task JSON (no diarization).

        Returns a dict with "segments", "language", "num_speakers" (always 1),
        "transcription_method" and "batch_size", or ``{"error": ...}`` on failure.
        The preprocessed WAV is always removed afterwards.
        """
        if not task_json or not str(task_json).strip():
            return {"error": "No JSON provided"}
        pre_meta = None
        try:
            print("Starting full transcription pipeline...")
            print("Preprocessing chunk JSON...")
            pre_meta = self.preprocess_from_task_json(task_json)
            if pre_meta.get("skip"):
                return {"segments": [], "language": "unknown", "num_speakers": 1, "transcription_method": "full_audio_batched", "batch_size": batch_size}
            wav_path = pre_meta["out_wav_path"]
            # abs_start_ms is already the global start of the saved (possibly trimmed) file.
            base_offset_s = float(pre_meta.get("abs_start_ms", 0)) / 1000.0
            transcription_results, detected_language = self.transcribe_full_audio(
                wav_path, language, translate, prompt, batch_size, base_offset_s=base_offset_s
            )
            if group_segments:
                transcription_results = self.group_segments_by_speaker(transcription_results)
            return {
                "segments": transcription_results,
                "language": detected_language,
                "num_speakers": 1,  # single speaker assumption
                "transcription_method": "full_audio_batched",
                "batch_size": batch_size,
            }
        except Exception as e:
            import traceback
            traceback.print_exc()
            return {"error": f"Processing failed: {str(e)}"}
        finally:
            self._remove_preprocessed_wav(pre_meta)

    def process_audio(self, task_json, num_speakers=None, language=None,
                      translate=False, prompt=None, group_segments=True, batch_size=8):
        """Transcribe + diarize a single chunk and upload the result JSON to R2.

        Transcribes the preprocessed audio once, diarizes it, assigns speakers,
        and optionally groups segments. The result is uploaded to bucket
        "intermediate" under ``ai-transcribe/split/{job_id}-{chunk_idx}.json``.

        Returns:
            ``{"filekey": key}`` on success; a skip result when the chunk holds no
            speech; otherwise ``{"error": ...}``. The preprocessed WAV is always removed.
        """
        if not task_json or not str(task_json).strip():
            return {"error": "No JSON provided"}
        pre_meta = None
        try:
            print("Starting new processing pipeline...")
            print("Preprocessing chunk JSON...")
            pre_meta = self.preprocess_from_task_json(task_json)
            if pre_meta.get("skip"):
                return {"segments": [], "language": "unknown", "num_speakers": 0, "transcription_method": "diarized_segments_batched", "batch_size": batch_size}
            wav_path = pre_meta["out_wav_path"]
            base_offset_s = float(pre_meta.get("abs_start_ms", 0)) / 1000.0
            transcription_results, detected_language = self.transcribe_full_audio(
                wav_path, language, translate, prompt, batch_size, base_offset_s=base_offset_s
            )
            diarization_segments, detected_num_speakers, speaker_embeddings = self.perform_diarization(
                wav_path, num_speakers, base_offset_s=base_offset_s
            )
            transcription_results = self.assign_speakers_to_transcription(transcription_results, diarization_segments)
            if group_segments:
                transcription_results = self.group_segments_by_speaker(transcription_results)
            result = {
                "segments": transcription_results,
                "language": detected_language,
                "num_speakers": detected_num_speakers,
                "transcription_method": "diarized_segments_batched",
                "batch_size": batch_size,
                "speaker_embeddings": speaker_embeddings,
            }
            job_id = pre_meta["job_id"]
            task_id = pre_meta["chunk_idx"]
            filekey = f"ai-transcribe/split/{job_id}-{task_id}.json"
            if upload_data_to_r2(json.dumps(result), "intermediate", filekey):
                return {"filekey": filekey}
            return {"error": "Failed to upload to R2"}
        except Exception as e:
            import traceback
            traceback.print_exc()
            return {"error": f"Processing failed: {str(e)}"}
        finally:
            self._remove_preprocessed_wav(pre_meta)


# Initialize transcriber
transcriber = WhisperTranscriber()
def format_segments_for_display(result):
    """Render a pipeline result dict as Markdown for the Gradio output panel.

    ``result`` is either ``{"error": msg}`` or a dict with optional keys
    "segments", "language", "num_speakers", "transcription_method" and
    "batch_size". Each segment needs "start"/"end" (seconds) and "text" and may
    carry "speaker". Times are shown as H:MM:SS, truncated to whole seconds.
    """
    # The emoji/arrow literals were mojibake (UTF-8 decoded as ISO-8859-7) and
    # rendered as garbage in the UI; they are restored here.
    if "error" in result:
        return f"❌ Error: {result['error']}"

    segments = result.get("segments", [])
    parts = [
        "🎯 **Detection Results:**\n",
        f"- Language: {result.get('language', 'unknown')}\n",
        f"- Speakers: {result.get('num_speakers', 1)}\n",
        f"- Segments: {len(segments)}\n",
        f"- Method: {result.get('transcription_method', 'unknown')}\n",
        f"- Batch Size: {result.get('batch_size', 'N/A')}\n\n",
        "📝 **Transcription:**\n\n",
    ]
    for segment in segments:
        start_time = str(datetime.timedelta(seconds=int(segment["start"])))
        end_time = str(datetime.timedelta(seconds=int(segment["end"])))
        speaker = segment.get("speaker", "SPEAKER_00")
        parts.append(f"**{speaker}** ({start_time} → {end_time})\n")
        parts.append(f"{segment['text']}\n\n")
    return "".join(parts)
def process_audio_gradio(task_json, num_speakers, language, translate, prompt, group_segments, use_diarization, batch_size):
    """Gradio click handler: run a transcription and return (display_text, result).

    The "auto" language choice and an empty or whitespace-only prompt are
    passed on as None. With diarization enabled the request goes to
    ``transcriber.process_audio``, where a ``num_speakers`` value that is not
    > 0 becomes None. Otherwise it goes to ``transcriber.process_audio_full``
    and ``num_speakers`` is ignored. The raw result is returned alongside its
    Markdown rendering from ``format_segments_for_display``.
    """
    # Options common to both transcription paths.
    shared_options = {
        "task_json": task_json,
        "language": None if language == "auto" else language,
        "translate": translate,
        "prompt": prompt if prompt and prompt.strip() else None,
        "group_segments": group_segments,
        "batch_size": batch_size,
    }

    if not use_diarization:
        outcome = transcriber.process_audio_full(**shared_options)
    else:
        speaker_count = num_speakers if num_speakers > 0 else None
        outcome = transcriber.process_audio(num_speakers=speaker_count, **shared_options)

    return format_segments_for_display(outcome), outcome
# Create Gradio interface.
# Components are rendered in creation order. The left column holds the
# task-JSON textbox, an "Advanced Settings" accordion and the Transcribe
# button. The right column holds the Markdown transcript and a hidden
# raw-JSON output.
# NOTE(review): the scraped source lost its indentation; the nesting below
# (e.g. process_btn after the accordion, listeners at Blocks level) is
# reconstructed — confirm against the repository.
with gr.Blocks(
    title="ποΈ Whisper Transcription with Speaker Diarization",
    theme="default",
) as demo:
    # The app consumes a pasted per-chunk task JSON (task_json_input), not an
    # uploaded audio file, so the intro text describes that input.
    gr.Markdown("""
# ποΈ Advanced Audio Transcription & Speaker Diarization
Paste a per-chunk task JSON to get accurate transcription with speaker identification, powered by:
- **Faster-Whisper Large V3 Turbo** with batched inference for optimal performance
- **Pyannote 3.1** for speaker diarization
- **ZeroGPU** acceleration for optimal performance
""")
    with gr.Row():
        with gr.Column():
            task_json_input = gr.Textbox(
                label="π§Ύ Paste Task JSON",
                placeholder="Paste the per-chunk task JSON here...",
                lines=16,
            )
            with gr.Accordion("βοΈ Advanced Settings", open=False):
                use_diarization = gr.Checkbox(
                    label="Enable Speaker Diarization",
                    value=True,
                    info="Uncheck for faster transcription without speaker identification",
                )
                batch_size = gr.Slider(
                    minimum=1,
                    maximum=128,
                    value=16,
                    step=1,
                    label="Batch Size",
                    info="Higher values = faster processing but more GPU memory usage. Recommended: 8-24",
                )
                num_speakers = gr.Slider(
                    minimum=0,
                    maximum=20,
                    value=0,
                    step=1,
                    label="Number of Speakers (0 = auto-detect)",
                    visible=True,
                )
                language = gr.Dropdown(
                    choices=["auto", "en", "es", "fr", "de", "it", "pt", "ru", "ja", "ko", "zh"],
                    value="auto",
                    label="Language",
                )
                translate = gr.Checkbox(
                    label="Translate to English",
                    value=False,
                )
                prompt = gr.Textbox(
                    label="Vocabulary Prompt (names, acronyms, etc.)",
                    placeholder="Enter names, technical terms, or context...",
                    lines=2,
                )
                group_segments = gr.Checkbox(
                    label="Group segments by speaker/time",
                    value=True,
                )
            process_btn = gr.Button("π Transcribe Audio", variant="primary")
        with gr.Column():
            output_text = gr.Markdown(
                label="π Transcription Results",
                value="Paste task JSON and click 'Transcribe Audio' to get started!",
            )
            output_json = gr.JSON(
                label="π§ Raw Output (JSON)",
                visible=False,
            )

    # num_speakers is only used on the diarization path, so it is shown only
    # while diarization is enabled.
    use_diarization.change(
        fn=lambda x: gr.update(visible=x),
        inputs=[use_diarization],
        outputs=[num_speakers],
    )

    # Event handlers. The input order must match process_audio_gradio's
    # positional parameters.
    process_btn.click(
        fn=process_audio_gradio,
        inputs=[
            task_json_input,
            num_speakers,
            language,
            translate,
            prompt,
            group_segments,
            use_diarization,
            batch_size,
        ],
        outputs=[output_text, output_json],
    )

    # Usage tips (static help text; no runnable examples are defined here).
    gr.Markdown("### π Usage Tips:")
    gr.Markdown("""
- Paste a single-chunk task JSON matching the preprocess schema
- Batch Size: Higher values (16-24) = faster but uses more GPU memory
- Speaker diarization: Enable for speaker identification (slower)
- Languages: Supports 100+ languages with auto-detection
- Vocabulary: Add names and technical terms in the prompt for better accuracy
""")

if __name__ == "__main__":
    demo.launch(debug=True)