|
|
import gc
import json
import os
import re
import shutil
import subprocess
import tempfile
import traceback
import unicodedata
from collections import Counter
from difflib import SequenceMatcher
from pathlib import Path
from typing import Any, Dict, List, Tuple

import gradio as gr
import numpy as np
import pandas as pd
import soundfile as sf
import torch
from faster_whisper import WhisperModel
from pyannote.audio import Pipeline
|
|
# Device selection: both the ASR model and the diarization pipeline run on
# CUDA when torch reports it as available, otherwise on the CPU.
GPU_AVAILABLE = torch.cuda.is_available()
ASR_DEVICE = DIAR_DEVICE = "cuda" if GPU_AVAILABLE else "cpu"
|
|
# Decoding options forwarded to WhisperModel.transcribe by transcribe_window.
BEAM_SIZE = 5
BEST_OF = 5
PATIENCE = 1.0
TEMPERATURES = [0.0, 0.2, 0.4]

# Windowing of detected speech (all values in seconds).
WINDOW_SECONDS = 28.0        # upper bound on the length of one transcription window
WINDOW_GAP_SECONDS = 1.2     # intervals separated by at most this gap may be merged
WINDOW_PAD_SECONDS = 0.35    # padding added around each detected speech interval

# Energy-based speech detection (seconds).
MIN_SPEECH_SECONDS = 0.18    # shorter active runs are discarded
MIN_SILENCE_SECONDS = 0.35   # silences up to this length inside speech are bridged

# Limits used by split_segment_by_words when breaking up long segments.
MAX_SEGMENT_SECONDS = 7.0
MAX_SEGMENT_WORDS = 30
|
|
# Substrings (already lower-cased and punctuation-free) that mark a segment as
# a prompt leak or a stock hallucination; matched by looks_bad_text against
# the output of normalize_for_compare.
BAD_PHRASES = [
    "transcribe exactly",
    "hindi must be written only in devanagari script",
    "english must be written only in latin script",
    "never use urdu arabic or perso arabic script",
    "thank you for watching",
    "subscribe",
]

# Matches any single character in the code point ranges U+0600-U+06FF,
# U+0750-U+077F or U+08A0-U+08FF.
URDU_ARABIC_SCRIPT_RE = re.compile(r"[\u0600-\u06FF\u0750-\u077F\u08A0-\u08FF]")
|
|
def cleanup_torch():
    """Run the garbage collector and, if CUDA is available, free cached GPU memory.

    ``torch.cuda.ipc_collect()`` is attempted on a best-effort basis: any
    exception it raises is ignored.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    try:
        torch.cuda.ipc_collect()
    except Exception:
        # Deliberately ignored: cleanup must never abort the caller.
        pass
|
|
def compute_type_for_model(asr_model_name: str) -> str:
    """Pick the compute type string passed to ``WhisperModel``.

    Returns ``"int8"`` whenever ``ASR_DEVICE`` is not ``"cuda"``. On CUDA it
    returns ``"int8_float16"`` for the ``"large-v3"`` model and ``"float16"``
    for every other model name.
    """
    on_gpu = ASR_DEVICE == "cuda"
    if on_gpu:
        return "int8_float16" if asr_model_name == "large-v3" else "float16"
    return "int8"
|
|
def run_cmd(cmd):
    """Run *cmd* (a list of strings) and return the completed process.

    stdout and stderr are captured as text.

    Raises:
        RuntimeError: if the process exits with a non-zero status; the message
            contains the space-joined command and the captured stderr.
    """
    completed = subprocess.run(cmd, capture_output=True, text=True)
    if completed.returncode == 0:
        return completed
    joined = " ".join(cmd)
    raise RuntimeError(f"Command failed:\n{joined}\n\nSTDERR:\n{completed.stderr}")
|
|
def ffprobe_duration(input_path: Path):
    """Return the container duration of *input_path* in seconds, or ``None``.

    ``None`` is returned for every kind of probe failure: ffprobe cannot be
    launched (e.g. it is not installed), it exits with a non-zero status, or
    its output cannot be parsed as a float.
    """
    cmd = [
        "ffprobe",
        "-v", "error",
        "-show_entries", "format=duration",
        "-of", "default=noprint_wrappers=1:nokey=1",
        str(input_path),
    ]
    try:
        result = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
    except OSError:
        # Missing or non-executable ffprobe binary: report "unknown duration"
        # like any other probe failure instead of propagating the error.
        return None
    if result.returncode != 0:
        return None
    try:
        return float(result.stdout.strip())
    except ValueError:
        # Empty or non-numeric output.
        return None
|
|
def to_wav_16k(input_path: Path, output_path: Path, enhance_audio: bool):
    """Convert *input_path* to a mono 16 kHz 16-bit PCM WAV using ffmpeg.

    When *enhance_audio* is true, a high-pass, low-pass and ``dynaudnorm``
    filter are applied before resampling. Returns *output_path*. Errors from
    ffmpeg surface as the ``RuntimeError`` raised by ``run_cmd``.
    """
    filters = []
    if enhance_audio:
        filters += [
            "highpass=f=80",
            "lowpass=f=7600",
            "dynaudnorm=f=150:g=15:p=0.90",
        ]
    # The resample filter is always last in the chain.
    filters.append("aresample=async=1:first_pts=0")

    ffmpeg_args = [
        "ffmpeg", "-y",
        "-i", str(input_path),
        "-vn",
        "-ac", "1",
        "-ar", "16000",
        "-c:a", "pcm_s16le",
        "-af", ",".join(filters),
        str(output_path),
    ]
    run_cmd(ffmpeg_args)
    return output_path
|
|
def load_waveform_for_pyannote(wav_path: Path):
    """Read *wav_path* and return the dict passed to the diarization pipeline.

    The audio is read as float32; multi-dimensional data is averaged over
    axis 1. The result has a ``"waveform"`` tensor with a leading dimension
    of size 1 and an integer ``"sample_rate"``.
    """
    samples, rate = sf.read(str(wav_path), dtype="float32")
    mono = samples.mean(axis=1) if samples.ndim > 1 else samples
    return {
        "waveform": torch.from_numpy(mono).unsqueeze(0),
        "sample_rate": int(rate),
    }
|
|
def load_audio_np(audio_path: Path):
    """Load *audio_path* as a float32 NumPy array plus its sample rate.

    Multi-dimensional data is averaged over axis 1.

    Raises:
        ValueError: if the sample rate is not 16000 or the file has no samples.
    """
    samples, rate = sf.read(str(audio_path), dtype="float32")
    if samples.ndim > 1:
        samples = np.mean(samples, axis=1).astype(np.float32)
    samples = np.asarray(samples, dtype=np.float32)

    # The sample-rate check intentionally comes before the emptiness check.
    if rate != 16000:
        raise ValueError(f"Expected 16k WAV after ffmpeg conversion, got {rate}")
    if len(samples) == 0:
        raise ValueError("Audio file is empty")
    return samples, rate
|
|
def normalize_spaces(text):
    """Collapse every run of whitespace in *text* to one space and trim the ends.

    A falsy *text* (``None`` or ``""``) yields ``""``. Newlines and carriage
    returns are whitespace and are therefore covered by the single regex.
    """
    collapsed = re.sub(r"\s+", " ", text or "")
    return collapsed.strip()
|
|
def normalize_for_compare(text):
    """Return a case-folded, punctuation-free form of *text* for comparisons.

    Every character that is not a letter, combining mark or number (Unicode
    general categories ``L*``, ``M*``, ``N*``) is treated as a separator, and
    runs of separators collapse to a single space. Underscores are separators.
    A falsy *text* yields ``""``.

    Combining marks are kept on purpose: the regex class ``\\W`` matches
    Devanagari vowel signs and the virama (categories ``Mc``/``Mn``), which
    would split Hindi words such as "नमस्ते" into fragments and distort word
    counts and repetition checks built on top of this function.
    """
    import unicodedata  # local import keeps the block self-contained

    folded = (text or "").casefold()
    kept = "".join(
        ch if unicodedata.category(ch)[0] in "LMN" else " "
        for ch in folded
    )
    # split()/join collapses the separator runs and trims both ends.
    return " ".join(kept.split())
|
|
def looks_bad_text(text):
    """Return True if *text* is empty after normalization or contains a bad phrase.

    The text is normalized with ``normalize_for_compare`` and then checked for
    any entry of ``BAD_PHRASES`` as a substring.
    """
    comparable = normalize_for_compare(text)
    if comparable:
        return any(phrase in comparable for phrase in BAD_PHRASES)
    return True
|
|
def contains_urdu_or_arabic_script(text):
    """Return True if *text* has at least one character matched by ``URDU_ARABIC_SCRIPT_RE``.

    A falsy *text* is treated as the empty string.
    """
    candidate = text or ""
    return URDU_ARABIC_SCRIPT_RE.search(candidate) is not None
|
|
def similarity(a: str, b: str) -> float:
    """Return the ``difflib.SequenceMatcher`` ratio of *a* and *b* in ``[0.0, 1.0]``.

    Two empty strings count as identical (1.0); exactly one empty string
    counts as completely different (0.0).
    """
    # SequenceMatcher comes from the module-level ``from difflib import
    # SequenceMatcher`` rather than being re-imported on every call.
    if not a and not b:
        return 1.0
    if not a or not b:
        return 0.0
    return SequenceMatcher(None, a, b).ratio()
|
|
def text_has_bad_repetition(text):
    """Heuristically detect looping / degenerate transcript text.

    The text is normalized with ``normalize_for_compare`` and split into
    words. Texts with fewer than 8 words are never flagged. Otherwise the
    function returns True when either:

    * for some n-gram size from 1 up to ``min(5, len(words) // 2)``, the
      word list cut into consecutive non-overlapping n-grams starting at
      index 0 contains the same n-gram three times in a row; or
    * there are at least 12 words and the most frequent word accounts for
      at least 45% of them.
    """
    tokens = normalize_for_compare(text).split()
    total = len(tokens)
    if total < 8:
        return False

    size_limit = min(6, total // 2 + 1)
    for size in range(1, size_limit):
        streak = 1
        last_gram = None
        for offset in range(0, total - size + 1, size):
            gram = tuple(tokens[offset:offset + size])
            streak = streak + 1 if gram == last_gram else 1
            if streak >= 3:
                return True
            last_gram = gram

    if total < 12:
        return False
    most_common_count = max(Counter(tokens).values())
    return most_common_count / total >= 0.45
|
|
def safe_float(value: Any, default: float = 0.0) -> float:
    """Convert *value* with ``float()``; return *default* if the conversion raises."""
    try:
        converted = float(value)
    except Exception:
        return default
    return converted
|
|
def format_hhmmss_mmm(seconds):
    """Format *seconds* as ``HH:MM:SS.mmm``.

    Negative inputs are clamped to zero and the value is rounded to the
    nearest millisecond. Hours are not wrapped, so they may exceed two digits.
    """
    clamped = max(0.0, float(seconds))
    whole_seconds, millis = divmod(int(round(clamped * 1000.0)), 1000)
    whole_minutes, secs = divmod(whole_seconds, 60)
    hours, minutes = divmod(whole_minutes, 60)
    return f"{hours:02d}:{minutes:02d}:{secs:02d}.{millis:03d}"
|
|
def preflight(media_file, asr_model_name, language, enhance_audio, num_speakers, min_speakers, max_speakers):
    """Build a plain-text environment / settings report for the UI.

    The report lists device selection, the chosen ASR settings, presence of
    ``HF_TOKEN``, availability of ffmpeg/ffprobe on PATH and the torch
    version. If *media_file* is given, its name, size and (when ffprobe can
    determine it) duration are appended; a failure while inspecting the file
    is reported as a line in the output instead of being raised.

    Returns:
        The report as a single newline-joined string.
    """
    report = [
        "=== PREFLIGHT ===",
        f"GPU available: {GPU_AVAILABLE}",
        f"ASR device: {ASR_DEVICE}",
        f"Diarization device: {DIAR_DEVICE}",
        "Diarization model: pyannote/speaker-diarization-community-1",
        f"ASR model: {asr_model_name}",
        f"ASR compute type: {compute_type_for_model(asr_model_name)}",
        f"Language: {language}",
        f"Enhance audio: {enhance_audio}",
        f"HF_TOKEN present: {bool(os.getenv('HF_TOKEN'))}",
        f"ffmpeg found: {shutil.which('ffmpeg') is not None}",
        f"ffprobe found: {shutil.which('ffprobe') is not None}",
        f"torch version: {torch.__version__}",
        f"Speaker controls -> num:{num_speakers} min:{min_speakers} max:{max_speakers}",
        "Repo-style transcription logic is active.",
    ]

    if media_file is None:
        report.append("No media file uploaded yet.")
        return "\n".join(report)

    try:
        media_path = Path(media_file)
        megabytes = media_path.stat().st_size / (1024 * 1024)
        duration = ffprobe_duration(media_path)
        report.extend([
            f"Uploaded file: {media_path.name}",
            f"File size: {megabytes:.2f} MB",
        ])
        if duration is not None:
            report.append(f"Estimated duration: {duration:.2f} sec")
            if duration > 1800:
                report.append("Warning: long file on T4 small. Start with medium.")
    except Exception as e:
        report.append(f"File inspection failed: {e}")
    return "\n".join(report)
|
|
| |
def frame_rms(audio: np.ndarray, sample_rate: int, frame_ms: float = 30.0, hop_ms: float = 10.0) -> Tuple[np.ndarray, np.ndarray]:
    """Compute the RMS level of *audio* over sliding frames.

    Args:
        audio: 1-D sample array.
        sample_rate: samples per second.
        frame_ms: frame length in milliseconds.
        hop_ms: distance between frame starts in milliseconds.

    Returns:
        ``(times, rms)`` as float32 arrays of equal length: the start time of
        each frame in seconds and its RMS value. Audio shorter than one frame
        is zero-padded to a single frame starting at 0.0.
    """
    def window_rms(samples):
        # The small epsilon keeps the value strictly positive for silent frames.
        return float(np.sqrt(np.mean(samples * samples) + 1e-12))

    frame_len = max(1, int(sample_rate * frame_ms / 1000.0))
    hop_len = max(1, int(sample_rate * hop_ms / 1000.0))

    if len(audio) < frame_len:
        padded = np.pad(audio, (0, frame_len - len(audio)))
        return (
            np.array([0.0], dtype=np.float32),
            np.array([window_rms(padded)], dtype=np.float32),
        )

    offsets = np.arange(0, len(audio) - frame_len + 1, hop_len, dtype=np.int64)
    levels = np.fromiter(
        (window_rms(audio[offset:offset + frame_len]) for offset in offsets),
        dtype=np.float32,
        count=len(offsets),
    )
    frame_times = offsets.astype(np.float32) / float(sample_rate)
    return frame_times, levels
|
|
def fill_short_silences(active: np.ndarray, max_gap_frames: int) -> np.ndarray:
    """Mark short inactive gaps between two active runs as active.

    A gap is filled only when it has an active frame on both sides and is at
    most *max_gap_frames* long; gaps touching either end of the array are
    left alone. Returns a modified copy, or *active* itself when
    *max_gap_frames* is not positive or the array is empty.
    """
    if max_gap_frames <= 0 or len(active) == 0:
        return active

    filled = active.copy()
    total = len(filled)

    # +1 in the padded difference marks the first frame of a silent run,
    # -1 marks the index just past its last frame.
    silent = np.logical_not(filled).astype(np.int8)
    edges = np.diff(np.concatenate(([0], silent, [0])))
    gap_starts = np.flatnonzero(edges == 1)
    gap_ends = np.flatnonzero(edges == -1)

    for gap_start, gap_end in zip(gap_starts, gap_ends):
        touches_edge = gap_start == 0 or gap_end == total
        if not touches_edge and (gap_end - gap_start) <= max_gap_frames:
            filled[gap_start:gap_end] = True
    return filled
|
|
def remove_short_speech(active: np.ndarray, min_speech_frames: int) -> np.ndarray:
    """Clear active runs shorter than *min_speech_frames*.

    Returns a modified copy, or *active* itself when *min_speech_frames* is
    1 or less or the array is empty.
    """
    if min_speech_frames <= 1 or len(active) == 0:
        return active

    kept = active.copy()

    # +1 in the padded difference marks the first frame of an active run,
    # -1 marks the index just past its last frame.
    voiced = kept.astype(bool).astype(np.int8)
    edges = np.diff(np.concatenate(([0], voiced, [0])))
    run_starts = np.flatnonzero(edges == 1)
    run_ends = np.flatnonzero(edges == -1)

    for run_start, run_end in zip(run_starts, run_ends):
        if (run_end - run_start) < min_speech_frames:
            kept[run_start:run_end] = False
    return kept
|
|
def detect_speech_intervals(audio: np.ndarray, sample_rate: int, total_duration: float) -> List[Tuple[float, float]]:
    """Find speech-like intervals in *audio* with an adaptive energy threshold.

    Frame levels from ``frame_rms`` are converted to dB and thresholded
    relative to the 20th/50th/75th/90th percentiles. Short silences are
    bridged and short bursts removed, each remaining run is padded by
    ``WINDOW_PAD_SECONDS``, and neighbouring intervals are merged when the gap
    is at most ``WINDOW_GAP_SECONDS`` and the merged span does not exceed
    ``WINDOW_SECONDS``.

    Returns:
        A list of ``(start, end)`` pairs in seconds. If no run survives, the
        whole file ``[(0.0, total_duration)]`` is returned.
    """
    frame_times, frame_levels = frame_rms(audio, sample_rate)
    level_db = 20.0 * np.log10(np.maximum(frame_levels, 1e-8))
    p20, p50, p75, p90 = (float(np.percentile(level_db, q)) for q in (20, 50, 75, 90))

    # Start above the quiet floor / median, then cap the threshold when the
    # upper percentiles show a wide dynamic range.
    threshold = max(p20 + 6.0, p50 + 2.5)
    if p75 > p20 + 8.0:
        threshold = min(threshold, p75 - 2.0)
    if p90 > p20 + 12.0:
        threshold = min(threshold, p90 - 8.0)

    # Matches the default 10 ms hop used by frame_rms above.
    hop_seconds = 0.010
    speech_mask = level_db >= threshold
    speech_mask = fill_short_silences(speech_mask, max_gap_frames=int(MIN_SILENCE_SECONDS / hop_seconds))
    speech_mask = remove_short_speech(speech_mask, min_speech_frames=max(1, int(MIN_SPEECH_SECONDS / hop_seconds)))

    intervals = []
    last_time_index = len(frame_times) - 1

    def close_run(first_idx, past_idx):
        # 0.03 is added to the start time of the last frame before padding;
        # presumably the 30 ms default frame length of frame_rms.
        begin = max(0.0, float(frame_times[first_idx]) - WINDOW_PAD_SECONDS)
        finish = min(total_duration, float(frame_times[min(past_idx - 1, last_time_index)]) + 0.03 + WINDOW_PAD_SECONDS)
        if finish - begin >= MIN_SPEECH_SECONDS:
            intervals.append((begin, finish))

    run_start = None
    for idx, is_speech in enumerate(speech_mask):
        if is_speech:
            if run_start is None:
                run_start = idx
        elif run_start is not None:
            close_run(run_start, idx)
            run_start = None
    if run_start is not None:
        close_run(run_start, len(speech_mask))

    if not intervals:
        return [(0.0, total_duration)]

    merged = []
    for start, end in intervals:
        if merged:
            prev_start, prev_end = merged[-1]
            close_enough = start - prev_end <= WINDOW_GAP_SECONDS
            fits_window = (end - prev_start) <= WINDOW_SECONDS
            if close_enough and fits_window:
                merged[-1] = (prev_start, max(prev_end, end))
                continue
        merged.append((start, end))
    return merged
|
|
| def split_long_intervals(intervals: List[Tuple[float, float]], total_duration: float) -> List[Tuple[float, float]]: |
| windows = [] |
| for start, end in intervals: |
| duration = end - start |
| if duration <= WINDOW_SECONDS: |
| windows.append((start, end)) |
| continue |
| cursor = start |
| overlap = min(1.0, max(0.0, WINDOW_PAD_SECONDS)) |
| while cursor < end: |
| win_end = min(end, cursor + WINDOW_SECONDS) |
| windows.append((max(0.0, cursor), min(total_duration, win_end))) |
| if win_end >= end: |
| break |
| cursor = max(cursor + 1.0, win_end - overlap) |
| cleaned = [] |
| for start, end in windows: |
| if end - start < 0.08: |
| continue |
| if cleaned and abs(start - cleaned[-1][0]) < 0.05 and abs(end - cleaned[-1][1]) < 0.05: |
| continue |
| cleaned.append((round(start, 3), round(end, 3))) |
| return cleaned |
|
|
def word_list_from_segment(seg: Any, base_offset: float, window_start: float, window_end: float) -> List[Dict[str, Any]]:
    """Extract word timings from *seg* and shift them by *base_offset*.

    Words lacking a start or end time are skipped, as are words lying more
    than 0.20 s outside ``[window_start, window_end]`` after shifting.

    Returns:
        A list of ``{"start", "end", "word"}`` dicts; times are clamped at 0
        and rounded to two decimals, and a missing word text becomes ``""``.
    """
    collected = []
    for item in getattr(seg, "words", None) or []:
        rel_start = getattr(item, "start", None)
        rel_end = getattr(item, "end", None)
        if rel_start is None or rel_end is None:
            continue

        abs_start = float(rel_start) + base_offset
        abs_end = float(rel_end) + base_offset
        outside = abs_end < window_start - 0.20 or abs_start > window_end + 0.20
        if outside:
            continue

        collected.append({
            "start": round(max(0.0, abs_start), 2),
            "end": round(max(0.0, abs_end), 2),
            "word": str(getattr(item, "word", "") or ""),
        })
    return collected
|
|
def transcribe_window(model: WhisperModel, audio: np.ndarray, sample_rate: int, start: float, end: float, language: str) -> List[Dict[str, Any]]:
    """Transcribe the ``[start, end]`` slice of *audio* with *model*.

    Args:
        model: object exposing ``transcribe(audio, **options)``.
        audio: full recording as a 1-D sample array.
        sample_rate: samples per second of *audio*.
        start, end: window bounds in seconds within the recording.
        language: language code passed to the model.

    Returns:
        A list of ``{"start", "end", "text", "words"}`` dicts with times in
        recording coordinates. Segments that are empty, contain characters
        matched by ``contains_urdu_or_arabic_script`` or have a non-positive
        duration are skipped. A slice shorter than 0.08 s returns ``[]``.
    """
    first_sample = max(0, int(start * sample_rate))
    last_sample = min(len(audio), int(end * sample_rate))
    clip = audio[first_sample:last_sample]
    if len(clip) < int(0.08 * sample_rate):
        return []

    prompt = (
        "This is an Indian meeting conversation containing only Hindi, Hinglish, and English. "
        "Transcribe exactly. Do not translate. "
        "Hindi must be written only in Devanagari script. "
        "English must be written only in Latin script. "
        "Never use Urdu, Arabic, or Perso-Arabic script. "
        "Preserve names, product terms, technical terms, repository names, GitHub terms, and code-mixed speech exactly."
    )
    options: Dict[str, Any] = dict(
        language=language,
        beam_size=BEAM_SIZE,
        best_of=BEST_OF,
        patience=PATIENCE,
        temperature=TEMPERATURES,
        condition_on_previous_text=False,
        vad_filter=False,
        word_timestamps=True,
        task="transcribe",
        initial_prompt=prompt,
        no_speech_threshold=0.82,
        log_prob_threshold=-1.35,
        compression_ratio_threshold=2.55,
        hallucination_silence_threshold=1.2,
    )
    try:
        segments_iter, _ = model.transcribe(clip, **options)
    except TypeError:
        # Retry once without the options listed below when the first call is
        # rejected with a TypeError (presumably an unsupported keyword).
        for unsupported in ("hallucination_silence_threshold", "best_of", "patience", "initial_prompt"):
            options.pop(unsupported, None)
        segments_iter, _ = model.transcribe(clip, **options)

    collected = []
    for seg in segments_iter:
        text = normalize_spaces(str(getattr(seg, "text", "") or ""))
        if not text or contains_urdu_or_arabic_script(text):
            continue

        # Segment times are relative to the clip; shift into recording time.
        abs_start = float(getattr(seg, "start", 0.0)) + start
        abs_end = float(getattr(seg, "end", 0.0)) + start
        if abs_end <= abs_start:
            continue

        # Clamp into the window while keeping a minimum 10 ms duration.
        abs_start = max(0.0, min(abs_start, end))
        abs_end = max(abs_start + 0.01, min(abs_end, end))
        collected.append({
            "start": round(abs_start, 2),
            "end": round(abs_end, 2),
            "text": text,
            "words": word_list_from_segment(seg, start, start, end),
        })
    return collected
|
|
def split_segment_by_words(seg: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Split an over-long segment into shorter pieces at word boundaries.

    A segment without word timings, or one within both ``MAX_SEGMENT_SECONDS``
    and ``MAX_SEGMENT_WORDS``, is returned unchanged as a one-element list
    holding a shallow copy. Otherwise words are accumulated and a piece is
    emitted when the running piece hits either limit, or when it is at least
    0.9 s long and the current word ends with ``. ? ! , ।`` or is followed by
    a pause of 0.45 s or more.

    Returns:
        Copies of *seg* with ``start``, ``end``, ``text`` and ``words``
        replaced per piece; falls back to ``[dict(seg)]`` if nothing is emitted.
    """
    words = seg.get("words") or []
    if not words:
        return [dict(seg)]

    seg_start = safe_float(seg.get("start"))
    seg_end = safe_float(seg.get("end"), seg_start)
    within_limits = (seg_end - seg_start) <= MAX_SEGMENT_SECONDS and len(words) <= MAX_SEGMENT_WORDS
    if within_limits:
        return [dict(seg)]

    pieces = []
    pending = []

    def emit_pending():
        if not pending:
            return
        # Word texts are concatenated without a separator and then
        # whitespace-normalized.
        joined = normalize_spaces("".join(str(w.get("word", "")) for w in pending))
        if joined:
            piece = dict(seg)
            piece.update(
                start=round(safe_float(pending[0].get("start")), 2),
                end=round(safe_float(pending[-1].get("end")), 2),
                text=joined,
                words=[dict(w) for w in pending],
            )
            pieces.append(piece)
        pending.clear()

    last_index = len(words) - 1
    for idx, word in enumerate(words):
        pending.append(word)
        span = safe_float(pending[-1].get("end")) - safe_float(pending[0].get("start"))

        if idx < last_index:
            gap_after = max(0.0, safe_float(words[idx + 1].get("start")) - safe_float(word.get("end")))
        else:
            gap_after = 0.0

        token = str(word.get("word", "")).strip()
        at_boundary = token.endswith((".", "?", "!", ",", "।")) or gap_after >= 0.45
        over_limit = span >= MAX_SEGMENT_SECONDS or len(pending) >= MAX_SEGMENT_WORDS
        if over_limit or (at_boundary and span >= 0.9):
            emit_pending()

    emit_pending()
    return pieces or [dict(seg)]
|
|
def dedupe_transcript_segments(segments: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Drop unusable segments and fold near-duplicates into earlier ones.

    Segments are sorted by ``(start, end)``. A segment is discarded when its
    text is empty, contains characters matched by
    ``contains_urdu_or_arabic_script``, is flagged by
    ``text_has_bad_repetition``, or has a non-positive duration. Each
    remaining segment is compared with up to six previously kept segments; it
    counts as a duplicate of the first one that is close in time (overlap
    ratio >= 0.35 or near a shared boundary) and textually equal, contained,
    or at least 0.94 similar. A duplicate extends the kept segment's time
    span and replaces its text (and words, if any) when it is longer.

    Returns:
        New dicts with ``start``, ``end``, ``text`` and, when present, ``words``.
    """
    if not segments:
        return []

    ordered = sorted(segments, key=lambda s: (safe_float(s.get("start")), safe_float(s.get("end"))))

    kept = []
    for raw in ordered:
        text = normalize_spaces(str(raw.get("text", "")))
        if not text:
            continue
        if contains_urdu_or_arabic_script(text) or text_has_bad_repetition(text):
            continue

        seg = dict(raw)
        seg["text"] = text
        start = safe_float(seg.get("start"))
        end = safe_float(seg.get("end"), start)
        if end <= start:
            continue

        curr_norm = normalize_for_compare(text)
        match = None
        # Only the six most recently kept segments are candidates.
        for candidate in kept[-6:]:
            cand_norm = normalize_for_compare(str(candidate.get("text", "")))
            if not cand_norm or not curr_norm:
                continue
            cand_start = safe_float(candidate.get("start"))
            cand_end = safe_float(candidate.get("end"), cand_start)

            shared = max(0.0, min(cand_end, end) - max(cand_start, start))
            shortest = max(0.01, min(cand_end - cand_start, end - start))
            overlap_ratio = shared / shortest
            near_boundary = (
                abs(start - cand_start) <= 1.25
                or abs(end - cand_end) <= 1.25
                or start - cand_end <= 0.8
            )
            text_matches = (
                curr_norm == cand_norm
                or curr_norm in cand_norm
                or cand_norm in curr_norm
                or similarity(curr_norm, cand_norm) >= 0.94
            )
            if (overlap_ratio >= 0.35 or near_boundary) and text_matches:
                match = candidate
                break

        if match is None:
            kept.append(seg)
            continue

        # Prefer the longer wording, and always widen the time span.
        if len(curr_norm) > len(normalize_for_compare(str(match.get("text", "")))):
            match["text"] = text
            if seg.get("words"):
                match["words"] = seg.get("words")
        match["start"] = round(min(safe_float(match.get("start")), start), 2)
        match["end"] = round(max(safe_float(match.get("end")), end), 2)

    final = []
    for seg in kept:
        item = {
            "start": round(safe_float(seg.get("start")), 2),
            "end": round(safe_float(seg.get("end")), 2),
            "text": normalize_spaces(str(seg.get("text", ""))),
        }
        if seg.get("words"):
            item["words"] = seg.get("words")
        final.append(item)
    return final
|
|
def transcribe_audio_chunked_repo_style(wav_path: Path, asr_model_name: str, language_choice: str):
    """Transcribe *wav_path* window by window and return cleaned segments.

    Speech intervals are detected, split into windows, and each window is
    transcribed with a freshly loaded ``WhisperModel``. Long segments are
    split by words and the combined list is de-duplicated and sorted.

    Args:
        wav_path: 16 kHz WAV file (validated by ``load_audio_np``).
        asr_model_name: model name passed to ``WhisperModel``.
        language_choice: language code; ``"auto"`` is mapped to ``"hi"``.

    Returns:
        ``(segments, window_count, language_used)``.

    Raises:
        RuntimeError: if there was at least one window and transcription
            failed for every one of them. A failure in only some windows is
            logged to stderr and those windows are skipped.
    """
    audio, sample_rate = load_audio_np(wav_path)
    total_duration = len(audio) / float(sample_rate)
    intervals = detect_speech_intervals(audio, sample_rate, total_duration)
    windows = split_long_intervals(intervals, total_duration)

    model = WhisperModel(
        asr_model_name,
        device=ASR_DEVICE,
        compute_type=compute_type_for_model(asr_model_name),
        cpu_threads=4 if ASR_DEVICE == "cpu" else 2,
        num_workers=1,
    )

    lang = "hi" if language_choice == "auto" else language_choice

    all_segments = []
    failed_windows = 0
    last_error = None
    try:
        for start, end in windows:
            try:
                window_segments = transcribe_window(model, audio, sample_rate, start, end, lang)
            except Exception as exc:
                # Best-effort per window, but leave a trace in the logs so a
                # systematic failure is not mistaken for silence.
                failed_windows += 1
                last_error = exc
                traceback.print_exc()
                continue
            for seg in window_segments:
                all_segments.extend(split_segment_by_words(seg))
    finally:
        # Release the model even if something unexpected escapes the loop.
        del model
        cleanup_torch()

    if windows and failed_windows == len(windows):
        raise RuntimeError(
            f"Transcription failed for all {len(windows)} speech windows"
        ) from last_error

    results = dedupe_transcript_segments(all_segments)
    results.sort(key=lambda x: (float(x["start"]), float(x["end"])))
    return results, len(windows), lang
|
|
def choose_speaker_for_word(word_start, word_end, diar_df):
    """Return the speaker label whose turn best matches a word's time span.

    Args:
        word_start, word_end: word bounds in seconds.
        diar_df: DataFrame with ``start``, ``end`` and ``speaker`` columns,
            one row per diarization turn. It is not modified.

    Returns:
        The speaker of the turn with the largest positive overlap with the
        word. If no turn overlaps, the speaker of the turn whose nearest
        boundary is closest to the word's midpoint. ``"UNKNOWN_SPEAKER"`` if
        *diar_df* is empty. On exact ties the earliest row wins.
    """
    if diar_df.empty:
        return "UNKNOWN_SPEAKER"

    # Work on plain arrays: this runs once per word, so copying the frame and
    # applying a Python lambda per row would dominate the run time.
    turn_starts = diar_df["start"].to_numpy(dtype=float)
    turn_ends = diar_df["end"].to_numpy(dtype=float)
    speakers = diar_df["speaker"]

    overlap = np.minimum(word_end, turn_ends) - np.maximum(word_start, turn_starts)
    best = int(np.argmax(overlap))
    if overlap[best] > 0:
        return str(speakers.iloc[best])

    mid = (word_start + word_end) / 2.0
    distance = np.minimum(np.abs(mid - turn_starts), np.abs(mid - turn_ends))
    return str(speakers.iloc[int(np.argmin(distance))])
|
|
def assign_speaker_to_segment(segment, diar_df):
    """Pick the speaker for *segment* by majority vote over its words.

    Each word in ``segment["words"]`` votes for the speaker returned by
    ``choose_speaker_for_word``; the speaker with the most votes wins (the
    first one seen on a tie). When the segment has no word timings, the
    segment's own ``start``/``end`` span is matched against the diarization
    instead, so such segments are not left unlabeled while turns exist.

    Returns:
        A speaker label, or ``"UNKNOWN_SPEAKER"`` when neither word timings
        nor a segment span are available (or the diarization is empty).
    """
    votes = Counter(
        choose_speaker_for_word(float(w["start"]), float(w["end"]), diar_df)
        for w in segment.get("words") or []
    )
    if votes:
        return max(votes, key=votes.get)

    seg_start = segment.get("start")
    seg_end = segment.get("end")
    if seg_start is not None and seg_end is not None:
        return choose_speaker_for_word(float(seg_start), float(seg_end), diar_df)
    return "UNKNOWN_SPEAKER"
|
|
def merge_adjacent_same_speaker(segments):
    """Merge consecutive segments that share the same ``speaker``.

    A merged segment keeps the first segment's start, takes the larger end,
    and has the non-empty texts joined with a single space. Segments are
    merged regardless of the time gap between them. Input dicts are not
    modified; the result contains copies.
    """
    if not segments:
        return []

    combined = []
    for seg in segments:
        tail = combined[-1] if combined else None
        if tail is None or seg["speaker"] != tail["speaker"]:
            combined.append(dict(seg))
            continue
        tail["end"] = max(float(tail["end"]), float(seg["end"]))
        if seg["text"]:
            tail["text"] = normalize_spaces(tail["text"] + " " + seg["text"])
    return combined
|
|
def process_media(media_file, asr_model_name, language, enhance_audio, filter_known_bad, num_speakers, min_speakers, max_speakers, progress=gr.Progress(track_tqdm=False)):
    """Run the full pipeline: convert, transcribe, diarize, label and export.

    Args:
        media_file: path of the uploaded media file.
        asr_model_name: model name for transcription.
        language: language choice for transcription.
        enhance_audio: apply the ffmpeg enhancement filters before ASR.
        filter_known_bad: drop segments flagged by ``looks_bad_text`` or
            ``text_has_bad_repetition``.
        num_speakers: exact speaker count; when positive it takes precedence
            over *min_speakers* / *max_speakers*.
        min_speakers, max_speakers: optional bounds on the speaker count.
        progress: Gradio progress callback.

    Returns:
        ``(preview_text, table, txt_path, json_path, csv_path)``. On any
        failure inside the pipeline the preview holds the traceback, the
        table is ``[]`` and the three paths are ``None``.

    Raises:
        gr.Error: if no file was uploaded or ``HF_TOKEN`` is not set.
    """
    if media_file is None:
        raise gr.Error("Please upload a media file.")
    hf_token = (os.getenv("HF_TOKEN") or "").strip()
    if not hf_token:
        raise gr.Error("Missing HF_TOKEN Space Secret.")

    work_root = Path(tempfile.mkdtemp(prefix="diarized_c1_"))
    out_dir = work_root / "outputs"
    out_dir.mkdir(parents=True, exist_ok=True)
    input_path = Path(media_file)
    wav_path = out_dir / "input_16k.wav"

    try:
        progress(0.05, desc="Preparing audio")
        to_wav_16k(input_path, wav_path, enhance_audio=enhance_audio)

        progress(0.16, desc=f"Repo-style transcription: {asr_model_name}")
        raw_segments, window_count, used_language = transcribe_audio_chunked_repo_style(wav_path, asr_model_name, language)

        if filter_known_bad:
            filtered = []
            for seg in raw_segments:
                t = normalize_spaces(seg.get("text", ""))
                if not t:
                    continue
                if looks_bad_text(t):
                    continue
                if text_has_bad_repetition(t):
                    continue
                seg = dict(seg)
                seg["text"] = t
                filtered.append(seg)
            raw_segments = filtered

        word_count = sum(len(seg.get("words", []) or []) for seg in raw_segments)

        progress(0.56, desc="Loading diarization model")
        pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-community-1", token=hf_token)
        try:
            if DIAR_DEVICE == "cuda":
                pipeline.to(torch.device("cuda"))

            # An exact speaker count overrides the min/max bounds.
            diar_kwargs = {}
            if num_speakers and int(num_speakers) > 0:
                diar_kwargs["num_speakers"] = int(num_speakers)
            else:
                if min_speakers and int(min_speakers) > 0:
                    diar_kwargs["min_speakers"] = int(min_speakers)
                if max_speakers and int(max_speakers) > 0:
                    diar_kwargs["max_speakers"] = int(max_speakers)

            progress(0.70, desc="Running diarization")
            media = load_waveform_for_pyannote(wav_path)
            output = pipeline(media, **diar_kwargs)
        finally:
            # Free the pipeline whether or not inference succeeded.
            del pipeline
            cleanup_torch()

        # Use the first of these attributes that the pipeline output exposes;
        # otherwise treat the output itself as the annotation.
        if hasattr(output, "exclusive_speaker_diarization"):
            diarization = output.exclusive_speaker_diarization
        elif hasattr(output, "speaker_diarization"):
            diarization = output.speaker_diarization
        else:
            diarization = output

        diar_rows = []
        for turn, _, speaker in diarization.itertracks(yield_label=True):
            diar_rows.append({"start": float(turn.start), "end": float(turn.end), "speaker": str(speaker)})
        # Explicit columns keep sort_values from raising KeyError when the
        # diarization produced no turns at all.
        diar_df = (
            pd.DataFrame(diar_rows, columns=["start", "end", "speaker"])
            .sort_values(["start", "end"])
            .reset_index(drop=True)
        )

        progress(0.84, desc="Assigning speakers to raw segments")
        assigned = []
        for seg in raw_segments:
            speaker = assign_speaker_to_segment(seg, diar_df)
            assigned.append({
                "speaker": speaker,
                "start": float(seg["start"]),
                "end": float(seg["end"]),
                "text": seg["text"],
            })

        cleaned = merge_adjacent_same_speaker(assigned)

        # Number speakers in order of first appearance in the transcript.
        raw_speakers = []
        for r in cleaned:
            if r["speaker"] not in raw_speakers:
                raw_speakers.append(r["speaker"])
        speaker_map = {spk: f"Speaker {i:02d}" for i, spk in enumerate(raw_speakers, start=1)}

        final_rows = []
        for seg in cleaned:
            final_rows.append({
                "speaker": speaker_map[seg["speaker"]],
                "start": float(seg["start"]),
                "end": float(seg["end"]),
                "start_hhmmss": format_hhmmss_mmm(seg["start"]),
                "end_hhmmss": format_hhmmss_mmm(seg["end"]),
                "text": seg["text"],
            })

        df = pd.DataFrame(final_rows)
        txt_path = out_dir / "speaker_transcript.txt"
        json_path = out_dir / "speaker_transcript.json"
        csv_path = out_dir / "speaker_transcript.csv"

        df.to_csv(csv_path, index=False)
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(final_rows, f, ensure_ascii=False, indent=2)
        with open(txt_path, "w", encoding="utf-8") as f:
            for _, row in df.iterrows():
                f.write(f"{row['speaker']}: {row['start_hhmmss']} - {row['end_hhmmss']}\n")
                f.write(f"Text: {row['text']}\n\n")

        preview_lines = [
            "=== RUN SUMMARY ===",
            f"ASR model used: {asr_model_name}",
            f"Repo-style language used: {used_language}",
            f"ASR device used: {ASR_DEVICE}",
            f"Diarization device used: {DIAR_DEVICE}",
            f"Speech windows: {window_count}",
            f"Raw transcript segments: {len(raw_segments)}",
            f"Raw transcript words: {word_count}",
            f"Diarization segments: {len(diar_df)}",
            f"Final cleaned diarized segments: {len(df)}",
            f"Detected speakers: {len(raw_speakers)}",
            "",
        ]
        for _, row in df.head(20).iterrows():
            preview_lines.append(f"{row['speaker']}: {row['start_hhmmss']} - {row['end_hhmmss']}")
            preview_lines.append(f"Text: {row['text']}")
            preview_lines.append("")

        progress(1.0, desc="Done")
        return "\n".join(preview_lines), df, str(txt_path), str(json_path), str(csv_path)
    except Exception:
        failure_report = "=== FAILURE ===\n" + traceback.format_exc()
        # Nothing is returned from the work directory on failure, so remove it
        # rather than leaving the converted WAV behind for every failed run.
        shutil.rmtree(work_root, ignore_errors=True)
        return failure_report, [], None, None, None
|
|
# ---------------------------------------------------------------------------
# Gradio UI: inputs on the left, diagnostics and downloadable outputs on the
# right. "Run preflight" calls preflight(); the primary button calls
# process_media().
# ---------------------------------------------------------------------------
with gr.Blocks(title="Diarized Speaker Segments Community-1") as demo:
    gr.Markdown(
        """
# Diarized Speaker Segments Community-1
Uses **attached-repo transcription logic** plus **pyannote/speaker-diarization-community-1**.

Cleanup rule:
- if adjacent speaker segments are the same, merge them
- otherwise do not touch them

Notes:
- default ASR model is **medium**
- **large-v3** is available for comparison
- default language is **hi** to mimic the attached repo behavior
"""
    )
    with gr.Row():
        # Left column: upload and settings.
        with gr.Column():
            media_file = gr.File(label="Upload video/audio", type="filepath")
            asr_model_name = gr.Dropdown(
                choices=["medium", "large-v3"],
                value="medium",
                label="ASR model",
                info="Default is medium. large-v3 is available for comparison.",
            )
            language = gr.Dropdown(
                choices=["hi", "auto", "en"],
                value="hi",
                label="Language",
                info="Default is hi to mimic the attached repo transcription behavior.",
            )
            enhance_audio = gr.Checkbox(value=True, label="Enhance audio before transcription")
            filter_known_bad = gr.Checkbox(value=True, label="Filter obvious hallucination / prompt-leak phrases")
            with gr.Row():
                num_speakers = gr.Number(label="Exact number of speakers (optional)", value=None, precision=0)
                min_speakers = gr.Number(label="Min speakers (optional)", value=1, precision=0)
                max_speakers = gr.Number(label="Max speakers (optional)", value=8, precision=0)
            with gr.Row():
                preflight_btn = gr.Button("Run preflight")
                run_btn = gr.Button("Generate diarized transcript", variant="primary")
        # Right column: report, table and output files.
        with gr.Column():
            preview = gr.Textbox(label="Diagnostics / Preview", lines=24)
            table = gr.Dataframe(label="Diarized transcript segments", wrap=True, interactive=False)
            txt_file = gr.File(label="TXT output")
            json_file = gr.File(label="JSON output")
            csv_file = gr.File(label="CSV output")

    # The input lists follow the positional parameter order of the handlers.
    preflight_btn.click(
        fn=preflight,
        inputs=[media_file, asr_model_name, language, enhance_audio, num_speakers, min_speakers, max_speakers],
        outputs=[preview],
    )
    run_btn.click(
        fn=process_media,
        inputs=[media_file, asr_model_name, language, enhance_audio, filter_known_bad, num_speakers, min_speakers, max_speakers],
        outputs=[preview, table, txt_file, json_file, csv_file],
    )


if __name__ == "__main__":
    # Listen on all interfaces on port 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860)
|
|