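"""Diarization + transcription helpers.

Lazily loads and caches diarization backends (pyannote, NeMo MSDD) and
transcription backends (faster-whisper, transformers Whisper, GigaAM),
applies a domain dictionary (initial prompt + post-hoc corrections), and
aligns transcribed text with speaker turns.
"""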
from __future__ import annotations

import json
import logging
import os
import re
import tempfile
from pathlib import Path
from typing import Any

import huggingface_hub as _hfh
import librosa
import numpy as np
import soundfile as sf
import torch

# pyannote.audio (<=3.3) passes use_auth_token to huggingface_hub functions
# that no longer accept it. Patch them to convert use_auth_token → token.
for _fn_name in ("hf_hub_download", "model_info", "snapshot_download"):
    _orig = getattr(_hfh, _fn_name, None)
    if _orig is None:
        continue

    def _patched(*args, _orig_fn=_orig, **kwargs):
        if "use_auth_token" in kwargs:
            kwargs["token"] = kwargs.pop("use_auth_token")
        return _orig_fn(*args, **kwargs)

    setattr(_hfh, _fn_name, _patched)

from models_config import DIARIZATION_MODELS, TRANSCRIPTION_MODELS

logger = logging.getLogger(__name__)

_model_cache: dict[str, Any] = {}

# ---------------------------------------------------------------------------
# Domain dictionary: initial_prompt + post-processing corrections
# ---------------------------------------------------------------------------
_DICT_DIR = Path(__file__).parent / "dictionary"


def _load_initial_prompt() -> str:
    p = _DICT_DIR / "initial_prompt.txt"
    if p.exists():
        return p.read_text(encoding="utf-8").strip()
    return ""


def _load_corrections() -> dict[str, str]:
    p = _DICT_DIR / "corrections.json"
    if not p.exists():
        return {}
    try:
        data = json.loads(p.read_text(encoding="utf-8"))
        return data.get("corrections", {})
    except Exception:
        return {}


_INITIAL_PROMPT: str = _load_initial_prompt()
_CORRECTIONS: dict[str, str] = _load_corrections()
if _INITIAL_PROMPT:
    logger.info("Domain initial_prompt loaded (%d chars)", len(_INITIAL_PROMPT))
if _CORRECTIONS:
    logger.info("Domain corrections loaded (%d rules)", len(_CORRECTIONS))

def apply_corrections(text: str) -> str:
    """Apply domain-specific ASR corrections to transcribed text."""
    if not _CORRECTIONS:
        return text
    result = text
    for wrong, correct in _CORRECTIONS.items():
        pattern = re.compile(re.escape(wrong), re.IGNORECASE)
        result = pattern.sub(correct, result)
    return result
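
# Expected dictionary/corrections.json layout (the mapping below is a
# hypothetical illustration; the real rules live in the file):
#
#   {"corrections": {"визпер": "Whisper", "пай эннот": "pyannote"}}
#
# Each key is matched case-insensitively as a literal substring and replaced.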

# ---------------------------------------------------------------------------
# Audio helpers
# ---------------------------------------------------------------------------
def load_audio(path: str, sr: int = 16000) -> tuple[np.ndarray, int]:
    audio, out_sr = librosa.load(path, sr=sr, mono=True)
    return audio, out_sr


def save_audio_tmp(audio: np.ndarray, sr: int, suffix: str = ".wav") -> str:
    fd, tmp_path = tempfile.mkstemp(suffix=suffix)
    os.close(fd)
    sf.write(tmp_path, audio, sr)
    return tmp_path

def trim_audio(
    audio: np.ndarray, sr: int, start_sec: float | None, end_sec: float | None
) -> np.ndarray:
    start = int((start_sec or 0) * sr)
    end = int(end_sec * sr) if end_sec else len(audio)
    return audio[start:end]


def parse_time_str(raw: str) -> float | None:
    """Convert '8:30' -> 510.0 seconds. Returns None for '-', '...' or empty."""
    if not raw or raw.strip() in ("-", "...", ""):
        return None
    parts = raw.strip().split(":")
    if len(parts) == 2:
        return float(parts[0]) * 60 + float(parts[1])
    if len(parts) == 3:
        return float(parts[0]) * 3600 + float(parts[1]) * 60 + float(parts[2])
    return float(raw)
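
# Examples: parse_time_str("8:30") == 510.0, parse_time_str("1:02:03") == 3723.0,
# parse_time_str("90") == 90.0, parse_time_str("-") is None. Malformed input
# such as "a:b" raises ValueError from float().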

# ---------------------------------------------------------------------------
# Model loading (lazy, cached)
# ---------------------------------------------------------------------------
def _device() -> torch.device:
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

def _load_diarization(model_id: str, token: str | None) -> Any:
    key = f"diar__{model_id}"
    if key in _model_cache:
        return _model_cache[key]
    cfg = DIARIZATION_MODELS[model_id]
    if cfg["backend"] == "pyannote":
        from pyannote.audio import Pipeline as PyannotePipeline

        pipe = PyannotePipeline.from_pretrained(cfg["repo_id"], use_auth_token=token)
        pipe.to(_device())
        _model_cache[key] = ("pyannote", pipe)
    elif cfg["backend"] == "nemo-msdd":
        try:
            from nemo.collections.asr.models import ClusteringDiarizer
        except ImportError:
            raise ImportError(
                "NeMo MSDD requires nemo_toolkit. "
                "Install: pip install nemo_toolkit[asr]"
            )
        nemo_model = ClusteringDiarizer.from_pretrained(cfg.get("config", "diar_msdd_telephonic"))
        _model_cache[key] = ("nemo-msdd", nemo_model)
    else:
        raise ValueError(f"Unknown diarization backend: {cfg['backend']}")
    return _model_cache[key]

def prefetch_model_weights(model_id: str) -> None:
    """Download model weights to the local cache (no GPU needed).

    Call this before the GPU-decorated function to avoid wasting GPU time
    on downloads.
    """
    cfg = TRANSCRIPTION_MODELS.get(model_id)
    if cfg is None:
        return
    repo_id = cfg.get("repo_id")
    if not repo_id:
        return
    if cfg["backend"] in ("transformers-whisper", "faster-whisper"):
        from huggingface_hub import snapshot_download

        logger.info("Prefetching weights for %s (%s)...", model_id, repo_id)
        snapshot_download(repo_id, token=os.environ.get("HF_TOKEN"))
        logger.info("Prefetch complete for %s", model_id)

def _load_transcription(model_id: str) -> Any:
    key = f"trans__{model_id}"
    if key in _model_cache:
        return _model_cache[key]
    cfg = TRANSCRIPTION_MODELS[model_id]
    if cfg["backend"] == "faster-whisper":
        from faster_whisper import WhisperModel

        dev = "cuda" if torch.cuda.is_available() else "cpu"
        ct = cfg.get("compute_type", "float16") if dev == "cuda" else "int8"
        model = WhisperModel(cfg["repo_id"], device=dev, compute_type=ct)
        _model_cache[key] = ("faster-whisper", model, cfg)
    elif cfg["backend"] == "transformers-whisper":
        from transformers import (
            AutoModelForSpeechSeq2Seq,
            AutoProcessor,
            pipeline as hf_pipeline,
        )

        dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        dev = _device()
        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            cfg["repo_id"], torch_dtype=dtype, low_cpu_mem_usage=True,
        )
        model.to(dev)
        processor = AutoProcessor.from_pretrained(cfg["repo_id"])
        pipe = hf_pipeline(
            "automatic-speech-recognition",
            model=model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            torch_dtype=dtype,
            device=dev,
        )
        _model_cache[key] = ("transformers-whisper", pipe, cfg)
    elif cfg["backend"] == "gigaam":
        try:
            import gigaam
        except ImportError:
            raise ImportError(
                "gigaam is not installed. Run: pip install gigaam"
            )
        model = gigaam.load_model(cfg["model_type"])
        if torch.cuda.is_available():
            model = model.to("cuda")
        _model_cache[key] = ("gigaam", model, cfg)
    else:
        raise ValueError(f"Unknown transcription backend: {cfg['backend']}")
    return _model_cache[key]

def unload_models() -> None:
    _model_cache.clear()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# ---------------------------------------------------------------------------
# Diarization
# ---------------------------------------------------------------------------
def _merge_diarization_segments(
    segments: list[dict], gap: float = 0.3, min_dur: float = 0.3
) -> list[dict]:
    """Merge adjacent segments of the same speaker separated by <= gap seconds.

    Drop segments shorter than min_dur after merging.
    """
    if not segments:
        return []
    merged: list[dict] = [segments[0].copy()]
    for seg in segments[1:]:
        prev = merged[-1]
        if seg["speaker"] == prev["speaker"] and seg["start"] - prev["end"] <= gap:
            prev["end"] = max(prev["end"], seg["end"])
        else:
            merged.append(seg.copy())
    return [s for s in merged if s["end"] - s["start"] >= min_dur]
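
# Illustration (hypothetical timestamps): with gap=0.3 and min_dur=0.3,
#   [{"speaker": "S0", "start": 0.0, "end": 1.0},
#    {"speaker": "S0", "start": 1.2, "end": 2.0},
#    {"speaker": "S1", "start": 2.1, "end": 2.2}]
# merges the two S0 turns (0.2 s apart) into one 0.0-2.0 segment and drops
# the 0.1 s S1 segment as shorter than min_dur.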

def run_diarization(
    audio_path: str, model_id: str, token: str | None = None,
    min_speakers: int | None = None, max_speakers: int | None = None,
    merge_gap: float = 0.3, merge_min_dur: float = 0.3,
) -> list[dict]:
    backend, model = _load_diarization(model_id, token)
    if backend == "pyannote":
        diar_kwargs: dict[str, Any] = {}
        if min_speakers is not None:
            diar_kwargs["min_speakers"] = min_speakers
        if max_speakers is not None:
            diar_kwargs["max_speakers"] = max_speakers
        diarization = model(audio_path, **diar_kwargs)
        raw = [
            {
                "speaker": speaker,
                "start": round(turn.start, 2),
                "end": round(turn.end, 2),
            }
            for turn, _, speaker in diarization.itertracks(yield_label=True)
        ]
        logger.info("Diarization raw segments: %d", len(raw))
        merged = _merge_diarization_segments(raw, gap=merge_gap, min_dur=merge_min_dur)
        logger.info("Diarization merged segments: %d (dropped short)", len(merged))
        return merged
    if backend == "nemo-msdd":
        nemo_model = model
        with tempfile.TemporaryDirectory() as tmpdir:
            # NOTE: the manifest is written for reference but is not passed to
            # diarize(), which is called with paths2audio_files directly.
            manifest_path = os.path.join(tmpdir, "input.json")
            with open(manifest_path, "w") as mf:
                json.dump(
                    {"audio_filepath": audio_path, "offset": 0, "duration": None,
                     "label": "infer", "text": "-"},
                    mf,
                )
            nemo_model.diarize(paths2audio_files=[audio_path], batch_size=1)
            # Predicted RTTMs are expected under tmpdir/pred_rttms (NeMo writes
            # them to its configured output directory).
            rttm_path = os.path.join(tmpdir, "pred_rttms", Path(audio_path).stem + ".rttm")
            raw = []
            if Path(rttm_path).exists():
                with open(rttm_path) as rf:
                    for line in rf:
                        # RTTM columns: type file chan tbeg tdur ortho stype name ...
                        parts = line.strip().split()
                        if len(parts) >= 8:
                            raw.append({
                                "speaker": parts[7],
                                "start": round(float(parts[3]), 2),
                                "end": round(float(parts[3]) + float(parts[4]), 2),
                            })
            merged = _merge_diarization_segments(raw, gap=merge_gap, min_dur=merge_min_dur)
            logger.info("NeMo MSDD segments: %d raw, %d merged", len(raw), len(merged))
            return merged
    raise ValueError(f"Unsupported diarization backend: {backend}")

# ---------------------------------------------------------------------------
# Transcription
# ---------------------------------------------------------------------------
TRANSCRIPTION_STRATEGIES = {
    "hybrid": "Whole file + speakers from diarization (recommended)",
    "per_segment": "Per diarization segment",
}

def run_transcription(
    audio_path: str,
    model_id: str,
    diar_segments: list[dict] | None = None,
    strategy: str = "hybrid",
    whisper_kwargs: dict | None = None,
) -> list[dict]:
    backend, model, cfg = _load_transcription(model_id)
    if backend == "faster-whisper":
        if strategy == "hybrid" and diar_segments:
            return _transcribe_faster_whisper_hybrid(audio_path, model, cfg, diar_segments, whisper_kwargs)
        return _transcribe_faster_whisper_per_segment(audio_path, model, cfg, diar_segments, whisper_kwargs)
    if backend == "transformers-whisper":
        return _transcribe_transformers_whisper(audio_path, model, cfg, diar_segments, strategy)
    if backend == "gigaam":
        return _transcribe_gigaam(audio_path, model, cfg, diar_segments)
    raise ValueError(f"Unsupported transcription backend: {backend}")

# Literal, lowercase substrings matched against transcribed text; the Russian
# entries target common Whisper hallucinations on Russian audio.
_HALLUCINATION_PATTERNS = {
    "продолжение следует", "субтитры сделал", "субтитры делал",
    "подписывайтесь на канал", "ставьте лайк", "до новых встреч",
    "спасибо за просмотр", "всем пока", "subtitles by",
    "thank you for watching", "subscribe", "to be continued",
    "amara.org", "редактор субтитров",
}

_WHISPER_COMMON = {
    "beam_size": 5,
    "condition_on_previous_text": False,
    "no_speech_threshold": 0.45,
    "log_prob_threshold": -0.8,
    "compression_ratio_threshold": 2.0,
    "repetition_penalty": 1.2,
    "temperature": [0.0, 0.2, 0.4],
    "initial_prompt": _INITIAL_PROMPT or None,
}

def _is_hallucination(text: str) -> bool:
    t = text.lower().strip().rstrip(".!…").strip()
    if not t:
        return True
    for pat in _HALLUCINATION_PATTERNS:
        if pat in t:
            return True
    # A single word repeated three or more times is a looping artifact.
    if len(set(t.split())) <= 1 and len(t.split()) > 2:
        return True
    return False

_WHISPER_HYBRID_DEFAULTS = {
    "beam_size": 5,
    "condition_on_previous_text": False,
    "no_speech_threshold": 0.5,
    "log_prob_threshold": -0.7,
    "compression_ratio_threshold": 2.4,
    "repetition_penalty": 1.1,
    "temperature": [0.0, 0.2, 0.4],
    "initial_prompt": _INITIAL_PROMPT or None,
}
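
# _WHISPER_COMMON above is used for per-segment decoding; these hybrid defaults
# (used when decoding the whole file in one pass) raise the no-speech/log-prob
# thresholds and relax the compression-ratio and repetition limits.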

def _collect_whisper_words(segments_iter, label: str = "") -> list[dict]:
    """Collect word-level timestamps from Whisper, filtering hallucinated segments."""
    raw_seg = 0
    empty_seg = 0
    hall_seg = 0
    words: list[dict] = []
    for s in segments_iter:
        raw_seg += 1
        text = s.text.strip()
        if not text:
            empty_seg += 1
            continue
        if _is_hallucination(text):
            hall_seg += 1
            logger.debug("Hallucination filtered: '%s'", text)
            continue
        if hasattr(s, "words") and s.words:
            for w in s.words:
                wt = w.word.strip()
                if wt:
                    words.append({"start": round(w.start, 2), "end": round(w.end, 2), "word": wt})
        else:
            words.append({"start": round(s.start, 2), "end": round(s.end, 2), "word": text})
    logger.info(
        "Whisper %s: segs=%d, empty=%d, hall=%d, words=%d",
        label, raw_seg, empty_seg, hall_seg, len(words),
    )
    return words

def _words_to_speaker_segments(
    words: list[dict], diar_segments: list[dict], max_gap: float = 2.0,
) -> list[dict]:
    """Assign each word a speaker via diarization overlap, then group by speaker."""
    if not words:
        return []
    for w in words:
        mid = (w["start"] + w["end"]) / 2
        best_spk = "UNKNOWN"
        best_ov = 0.0
        for d in diar_segments:
            # A segment containing the word midpoint wins outright;
            # otherwise fall back to the largest time overlap.
            if d["start"] <= mid <= d["end"]:
                best_spk = d["speaker"]
                break
            ov = min(w["end"], d["end"]) - max(w["start"], d["start"])
            if ov > best_ov:
                best_ov = ov
                best_spk = d["speaker"]
        w["speaker"] = best_spk
    segments: list[dict] = []
    cur = {"speaker": words[0]["speaker"], "start": words[0]["start"],
           "end": words[0]["end"], "words": [words[0]["word"]]}
    for w in words[1:]:
        if w["speaker"] == cur["speaker"] and w["start"] - cur["end"] < max_gap:
            cur["end"] = w["end"]
            cur["words"].append(w["word"])
        else:
            segments.append({
                "speaker": cur["speaker"], "start": cur["start"],
                "end": cur["end"], "text": " ".join(cur["words"]),
            })
            cur = {"speaker": w["speaker"], "start": w["start"],
                   "end": w["end"], "words": [w["word"]]}
    segments.append({
        "speaker": cur["speaker"], "start": cur["start"],
        "end": cur["end"], "text": " ".join(cur["words"]),
    })
    return segments
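
# Illustration (hypothetical words): given diarization S0=[0.0, 1.0] and
# S1=[1.0, 2.0], the words ("hi", 0.1-0.3), ("there", 0.4-0.6), ("yes", 1.2-1.4)
# become two segments: S0 "hi there" (0.1-0.6) and S1 "yes" (1.2-1.4).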

def _resolve_prompt(params: dict) -> dict:
    """Replace the '__domain__' marker with the actual prompt, strip None."""
    p = params.get("initial_prompt")
    if p == "__domain__":
        params["initial_prompt"] = _INITIAL_PROMPT or None
    if params.get("initial_prompt") is None:
        params.pop("initial_prompt", None)
    return params
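
# Callers can pass whisper_kwargs={"initial_prompt": "__domain__"} to request
# the prompt from dictionary/initial_prompt.txt, an explicit string to override
# it, or None to disable prompting (the key is then dropped entirely).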

def _transcribe_faster_whisper_hybrid(
    audio_path: str, model: Any, cfg: dict, diar_segments: list[dict],
    whisper_kwargs: dict | None = None,
) -> list[dict]:
    """Transcribe the full file with Whisper, assign speakers via word-level timestamps."""
    lang = cfg.get("language", "ru")
    params = _resolve_prompt({**_WHISPER_HYBRID_DEFAULTS, **(whisper_kwargs or {})})
    params["word_timestamps"] = True
    segments_iter, _info = model.transcribe(
        audio_path, language=lang, vad_filter=True, **params,
    )
    words = _collect_whisper_words(segments_iter, "hybrid-vad")
    if len(words) < 3:
        logger.warning("VAD produced too few words (%d), retrying without VAD", len(words))
        segments_iter2, _ = model.transcribe(
            audio_path, language=lang, vad_filter=False, **params,
        )
        words = _collect_whisper_words(segments_iter2, "hybrid-novad")
    return _words_to_speaker_segments(words, diar_segments)

def _transcribe_faster_whisper_per_segment(
    audio_path: str, model: Any, cfg: dict, diar_segments: list[dict] | None,
    whisper_kwargs: dict | None = None,
) -> list[dict]:
    """Transcribe each diarization segment individually."""
    lang = cfg.get("language", "ru")
    params = _resolve_prompt({**_WHISPER_COMMON, **(whisper_kwargs or {})})
    if not diar_segments:
        segments_iter, _info = model.transcribe(
            audio_path, language=lang, vad_filter=True, **params,
        )
        return [
            {"start": round(s.start, 2), "end": round(s.end, 2), "text": s.text.strip()}
            for s in segments_iter
            if s.text.strip() and not _is_hallucination(s.text)
        ]
    audio, sr = load_audio(audio_path)
    pad_sec = 0.5  # context padding around each diarization segment
    results: list[dict] = []
    with tempfile.TemporaryDirectory() as tmpdir:
        for i, seg in enumerate(diar_segments):
            start_s = max(0, seg["start"] - pad_sec)
            end_s = min(len(audio) / sr, seg["end"] + pad_sec)
            start_sample = int(start_s * sr)
            end_sample = int(end_s * sr)
            chunk = audio[start_sample:end_sample]
            if len(chunk) < int(sr * 0.15):  # skip chunks shorter than 150 ms
                results.append({**seg, "text": ""})
                continue
            seg_path = os.path.join(tmpdir, f"seg_{i}.wav")
            sf.write(seg_path, chunk, sr)
            try:
                segs_iter, _info = model.transcribe(
                    seg_path, language=lang, vad_filter=False,
                    without_timestamps=True, **params,
                )
                good_texts: list[str] = []
                for s in segs_iter:
                    t = s.text.strip()
                    if not t or _is_hallucination(t):
                        continue
                    good_texts.append(t)
                text = " ".join(good_texts)
            except Exception as exc:
                logger.warning("Whisper failed on segment %d: %s", i, exc)
                text = ""
            results.append({**seg, "text": text})
    return results

def _transcribe_transformers_whisper(
    audio_path: str, pipe: Any, cfg: dict,
    diar_segments: list[dict] | None, strategy: str,
) -> list[dict]:
    """Transcribe using the HuggingFace transformers pipeline (for native HF Whisper checkpoints)."""
    lang = cfg.get("language", "ru")
    generate_kwargs = {"language": lang, "task": "transcribe"}
    if _INITIAL_PROMPT and hasattr(pipe, "tokenizer") and hasattr(pipe.tokenizer, "get_prompt_ids"):
        generate_kwargs["prompt_ids"] = pipe.tokenizer.get_prompt_ids(
            _INITIAL_PROMPT, return_tensors="np"
        )
    use_word_ts = strategy == "hybrid" and diar_segments
    result = pipe(
        audio_path,
        return_timestamps="word" if use_word_ts else True,
        generate_kwargs=generate_kwargs,
        chunk_length_s=30,
        batch_size=8,
    )
    if use_word_ts:
        words: list[dict] = []
        for chunk in result.get("chunks", []):
            text = chunk.get("text", "").strip()
            if not text:
                continue
            ts = chunk.get("timestamp", (0.0, 0.0))
            start = ts[0] if ts[0] is not None else 0.0
            end = ts[1] if ts[1] is not None else start
            words.append({"start": round(start, 2), "end": round(end, 2), "word": text})
        logger.info("Transformers-whisper word-level: %d words", len(words))
        if words:
            return _words_to_speaker_segments(words, diar_segments)
    whisper_segments: list[dict] = []
    if not use_word_ts:
        for chunk in result.get("chunks", []):
            text = chunk.get("text", "").strip()
            ts = chunk.get("timestamp", (0.0, 0.0))
            start = ts[0] if ts[0] is not None else 0.0
            end = ts[1] if ts[1] is not None else start
            if not text or _is_hallucination(text):
                continue
            whisper_segments.append({"start": round(start, 2), "end": round(end, 2), "text": text})
    logger.info("Transformers-whisper: %d segments", len(whisper_segments))
    if not whisper_segments and result.get("text", "").strip():
        whisper_segments = [{"start": 0.0, "end": 0.0, "text": result["text"].strip()}]
    if diar_segments:
        return _align_speakers(whisper_segments, diar_segments)
    return whisper_segments

def _transcribe_gigaam(
    audio_path: str, model: Any, cfg: dict, diar_segments: list[dict] | None
) -> list[dict]:
    if not diar_segments:
        result = model.transcribe(audio_path)
        text = result if isinstance(result, str) else (result[0] if result else "")
        return [{"start": 0.0, "end": 0.0, "text": text, "speaker": "UNKNOWN"}]
    audio, sr = load_audio(audio_path)
    results: list[dict] = []
    with tempfile.TemporaryDirectory() as tmpdir:
        for i, seg in enumerate(diar_segments):
            start_sample = int(seg["start"] * sr)
            end_sample = int(seg["end"] * sr)
            chunk = audio[start_sample:end_sample]
            if len(chunk) < int(sr * 0.1):  # skip chunks shorter than 100 ms
                results.append({**seg, "text": ""})
                continue
            seg_path = os.path.join(tmpdir, f"seg_{i}.wav")
            sf.write(seg_path, chunk, sr)
            transcription = model.transcribe(seg_path)
            text = transcription if isinstance(transcription, str) else (transcription[0] if transcription else "")
            results.append({**seg, "text": text.strip()})
    return results

# ---------------------------------------------------------------------------
# Speaker alignment
# ---------------------------------------------------------------------------
def _align_speakers(
    trans_segments: list[dict], diar_segments: list[dict]
) -> list[dict]:
    """Assign speaker labels from diarization to transcription segments by overlap."""
    results: list[dict] = []
    for t in trans_segments:
        best_speaker = "UNKNOWN"
        best_ov = 0.0
        for d in diar_segments:
            ov = min(t["end"], d["end"]) - max(t["start"], d["start"])
            if ov > best_ov:
                best_ov = ov
                best_speaker = d["speaker"]
        results.append({**t, "speaker": best_speaker})
    return results

# ---------------------------------------------------------------------------
# Full pipeline
# ---------------------------------------------------------------------------
def process_file(
    audio_path: str,
    diar_model_id: str,
    trans_model_id: str,
    token: str | None = None,
    start_sec: float | None = None,
    end_sec: float | None = None,
    min_speakers: int | None = 2,
    max_speakers: int | None = None,
    strategy: str = "hybrid",
    whisper_kwargs: dict | None = None,
    merge_gap: float = 0.3,
    merge_min_dur: float = 0.3,
) -> dict:
    """Diarize + transcribe a single audio file. Optionally trim to [start_sec, end_sec]."""
    work_path = audio_path
    if start_sec is not None or end_sec is not None:
        audio, sr = load_audio(audio_path)
        audio = trim_audio(audio, sr, start_sec, end_sec)
        work_path = save_audio_tmp(audio, sr)
    diar_segments = run_diarization(
        work_path, diar_model_id, token,
        min_speakers=min_speakers, max_speakers=max_speakers,
        merge_gap=merge_gap, merge_min_dur=merge_min_dur,
    )
    trans_segments = run_transcription(
        work_path, trans_model_id, diar_segments,
        strategy=strategy, whisper_kwargs=whisper_kwargs,
    )
    for seg in trans_segments:
        if seg.get("text"):
            seg["text"] = apply_corrections(seg["text"])
    if work_path != audio_path:
        Path(work_path).unlink(missing_ok=True)
    speakers = sorted(set(s.get("speaker", "") for s in diar_segments))
    return {
        "file": Path(audio_path).stem,
        "diarization_model": diar_model_id,
        "transcription_model": trans_model_id,
        "num_speakers": len(speakers),
        "speakers": speakers,
        "diarization_segments": diar_segments,
        "transcription": trans_segments,
    }
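
# Example usage (model ids are hypothetical; the real keys live in models_config):
#
#   result = process_file(
#       "meeting.wav",
#       diar_model_id="pyannote-3.1",
#       trans_model_id="whisper-large-v3",
#       token=os.environ.get("HF_TOKEN"),
#       start_sec=parse_time_str("0:30"),
#       end_sec=parse_time_str("10:00"),
#   )
#   print(result["num_speakers"], result["speakers"])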