# store-dialogs-qa / pipeline.py
# Diarization + transcription pipeline with domain dictionaries for flower shop ASR.
from __future__ import annotations
import functools
import json
import logging
import os
import re
import tempfile
from pathlib import Path
from typing import Any
import huggingface_hub as _hfh
import librosa
import numpy as np
import soundfile as sf
import torch
# pyannote.audio (<=3.3) passes use_auth_token to huggingface_hub functions
# that no longer accept it. Patch them to convert use_auth_token → token.
for _fn_name in ("hf_hub_download", "model_info", "snapshot_download"):
_orig = getattr(_hfh, _fn_name, None)
if _orig is None:
continue
@functools.wraps(_orig)
def _patched(*args, _orig_fn=_orig, **kwargs):
if "use_auth_token" in kwargs:
kwargs["token"] = kwargs.pop("use_auth_token")
return _orig_fn(*args, **kwargs)
setattr(_hfh, _fn_name, _patched)
from models_config import DIARIZATION_MODELS, TRANSCRIPTION_MODELS
logger = logging.getLogger(__name__)
_model_cache: dict[str, Any] = {}
# ---------------------------------------------------------------------------
# Domain dictionary: initial_prompt + post-processing corrections
# ---------------------------------------------------------------------------
_DICT_DIR = Path(__file__).parent / "dictionary"
def _load_initial_prompt() -> str:
p = _DICT_DIR / "initial_prompt.txt"
if p.exists():
return p.read_text(encoding="utf-8").strip()
return ""
def _load_corrections() -> dict[str, str]:
p = _DICT_DIR / "corrections.json"
if not p.exists():
return {}
try:
data = json.loads(p.read_text(encoding="utf-8"))
return data.get("corrections", {})
except Exception:
return {}
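# Expected corrections.json shape (illustrative values, not the shipped file):
#   {"corrections": {"misheard term": "canonical term", ...}}
# Keys are common misrecognitions; values are the canonical domain spellings.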
_INITIAL_PROMPT: str = _load_initial_prompt()
_CORRECTIONS: dict[str, str] = _load_corrections()
if _INITIAL_PROMPT:
logger.info("Domain initial_prompt loaded (%d chars)", len(_INITIAL_PROMPT))
if _CORRECTIONS:
logger.info("Domain corrections loaded (%d rules)", len(_CORRECTIONS))
def apply_corrections(text: str) -> str:
"""Apply domain-specific ASR corrections to transcribed text."""
if not _CORRECTIONS:
return text
result = text
for wrong, correct in _CORRECTIONS.items():
pattern = re.compile(re.escape(wrong), re.IGNORECASE)
result = pattern.sub(correct, result)
return result
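# Example (hypothetical rule): with _CORRECTIONS == {"эквадор": "Эквадор"},
# apply_corrections("розы эквадор") returns "розы Эквадор". Matching is
# case-insensitive on the literal (regex-escaped) key.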
# ---------------------------------------------------------------------------
# Audio helpers
# ---------------------------------------------------------------------------
def load_audio(path: str, sr: int = 16000) -> tuple[np.ndarray, int]:
audio, out_sr = librosa.load(path, sr=sr, mono=True)
return audio, out_sr
def save_audio_tmp(audio: np.ndarray, sr: int, suffix: str = ".wav") -> str:
fd, tmp_path = tempfile.mkstemp(suffix=suffix)
os.close(fd)
sf.write(tmp_path, audio, sr)
return tmp_path
def trim_audio(
audio: np.ndarray, sr: int, start_sec: float | None, end_sec: float | None
) -> np.ndarray:
start = int((start_sec or 0) * sr)
end = int(end_sec * sr) if end_sec else len(audio)
return audio[start:end]
def parse_time_str(raw: str) -> float | None:
"""Convert '8:30' -> 510.0 seconds. Returns None for '-', '...' or empty."""
if not raw or raw.strip() in ("-", "...", ""):
return None
parts = raw.strip().split(":")
if len(parts) == 2:
return float(parts[0]) * 60 + float(parts[1])
if len(parts) == 3:
return float(parts[0]) * 3600 + float(parts[1]) * 60 + float(parts[2])
return float(raw)
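# Examples: parse_time_str("8:30") -> 510.0, parse_time_str("1:02:03") -> 3723.0,
# parse_time_str("-") -> None; a bare value like "42" is parsed as seconds.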
# ---------------------------------------------------------------------------
# Model loading (lazy, cached)
# ---------------------------------------------------------------------------
def _device() -> torch.device:
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
def _load_diarization(model_id: str, token: str | None) -> Any:
key = f"diar__{model_id}"
if key in _model_cache:
return _model_cache[key]
cfg = DIARIZATION_MODELS[model_id]
if cfg["backend"] == "pyannote":
from pyannote.audio import Pipeline as PyannotePipeline
pipe = PyannotePipeline.from_pretrained(cfg["repo_id"], use_auth_token=token)
pipe.to(_device())
_model_cache[key] = ("pyannote", pipe)
elif cfg["backend"] == "nemo-msdd":
try:
from nemo.collections.asr.models import ClusteringDiarizer
except ImportError:
raise ImportError(
"NeMo MSDD requires nemo_toolkit. "
"Install: pip install nemo_toolkit[asr]"
)
nemo_model = ClusteringDiarizer.from_pretrained(cfg.get("config", "diar_msdd_telephonic"))
_model_cache[key] = ("nemo-msdd", nemo_model)
else:
raise ValueError(f"Unknown diarization backend: {cfg['backend']}")
return _model_cache[key]
def prefetch_model_weights(model_id: str) -> None:
"""Download model weights to local cache (no GPU needed).
Call this before the GPU-decorated function to avoid wasting GPU time on downloads."""
cfg = TRANSCRIPTION_MODELS.get(model_id)
if cfg is None:
return
repo_id = cfg.get("repo_id")
if not repo_id:
return
if cfg["backend"] in ("transformers-whisper", "faster-whisper"):
from huggingface_hub import snapshot_download
logger.info("Prefetching weights for %s (%s)...", model_id, repo_id)
snapshot_download(repo_id, token=os.environ.get("HF_TOKEN"))
logger.info("Prefetch complete for %s", model_id)
def _load_transcription(model_id: str) -> Any:
key = f"trans__{model_id}"
if key in _model_cache:
return _model_cache[key]
cfg = TRANSCRIPTION_MODELS[model_id]
if cfg["backend"] == "faster-whisper":
from faster_whisper import WhisperModel
dev = "cuda" if torch.cuda.is_available() else "cpu"
ct = cfg.get("compute_type", "float16") if dev == "cuda" else "int8"
model = WhisperModel(cfg["repo_id"], device=dev, compute_type=ct)
_model_cache[key] = ("faster-whisper", model, cfg)
elif cfg["backend"] == "transformers-whisper":
from transformers import (
AutoModelForSpeechSeq2Seq,
AutoProcessor,
pipeline as hf_pipeline,
)
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
dev = _device()
model = AutoModelForSpeechSeq2Seq.from_pretrained(
cfg["repo_id"], torch_dtype=dtype, low_cpu_mem_usage=True,
)
model.to(dev)
processor = AutoProcessor.from_pretrained(cfg["repo_id"])
pipe = hf_pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
torch_dtype=dtype,
device=dev,
)
_model_cache[key] = ("transformers-whisper", pipe, cfg)
elif cfg["backend"] == "gigaam":
try:
import gigaam
except ImportError:
raise ImportError(
"gigaam is not installed. Run: pip install gigaam"
)
model = gigaam.load_model(cfg["model_type"])
if torch.cuda.is_available():
model = model.to("cuda")
_model_cache[key] = ("gigaam", model, cfg)
else:
raise ValueError(f"Unknown transcription backend: {cfg['backend']}")
return _model_cache[key]
def unload_models() -> None:
_model_cache.clear()
if torch.cuda.is_available():
torch.cuda.empty_cache()
# ---------------------------------------------------------------------------
# Diarization
# ---------------------------------------------------------------------------
def _merge_diarization_segments(
segments: list[dict], gap: float = 0.3, min_dur: float = 0.3
) -> list[dict]:
"""Merge adjacent segments of the same speaker separated by <= gap seconds.
Drop segments shorter than min_dur after merging."""
if not segments:
return []
merged: list[dict] = [segments[0].copy()]
for seg in segments[1:]:
prev = merged[-1]
if seg["speaker"] == prev["speaker"] and seg["start"] - prev["end"] <= gap:
prev["end"] = max(prev["end"], seg["end"])
else:
merged.append(seg.copy())
return [s for s in merged if s["end"] - s["start"] >= min_dur]
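# Worked example with gap=0.3, min_dur=0.3:
#   [{"speaker": "S0", "start": 0.0, "end": 1.0},
#    {"speaker": "S0", "start": 1.2, "end": 2.0}]
# -> [{"speaker": "S0", "start": 0.0, "end": 2.0}]   (the 0.2 s gap is bridged)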
def run_diarization(
audio_path: str, model_id: str, token: str | None = None,
min_speakers: int | None = None, max_speakers: int | None = None,
merge_gap: float = 0.3, merge_min_dur: float = 0.3,
) -> list[dict]:
backend, model = _load_diarization(model_id, token)
if backend == "pyannote":
diar_kwargs: dict[str, Any] = {}
if min_speakers is not None:
diar_kwargs["min_speakers"] = min_speakers
if max_speakers is not None:
diar_kwargs["max_speakers"] = max_speakers
diarization = model(audio_path, **diar_kwargs)
raw = [
{
"speaker": speaker,
"start": round(turn.start, 2),
"end": round(turn.end, 2),
}
for turn, _, speaker in diarization.itertracks(yield_label=True)
]
logger.info("Diarization raw segments: %d", len(raw))
merged = _merge_diarization_segments(raw, gap=merge_gap, min_dur=merge_min_dur)
logger.info("Diarization merged segments: %d (dropped short)", len(merged))
return merged
if backend == "nemo-msdd":
import tempfile as _tf
nemo_model = model
with _tf.TemporaryDirectory() as tmpdir:
manifest_path = os.path.join(tmpdir, "input.json")
import json as _json
with open(manifest_path, "w") as mf:
_json.dump({"audio_filepath": audio_path, "offset": 0, "duration": None, "label": "infer", "text": "-"}, mf)
nemo_model.diarize(paths2audio_files=[audio_path], batch_size=1)
rttm_path = os.path.join(tmpdir, "pred_rttms", Path(audio_path).stem + ".rttm")
raw = []
if Path(rttm_path).exists():
with open(rttm_path) as rf:
for line in rf:
parts = line.strip().split()
if len(parts) >= 8:
raw.append({
"speaker": parts[7],
"start": round(float(parts[3]), 2),
"end": round(float(parts[3]) + float(parts[4]), 2),
})
merged = _merge_diarization_segments(raw, gap=merge_gap, min_dur=merge_min_dur)
logger.info("NeMo MSDD segments: %d raw, %d merged", len(raw), len(merged))
return merged
raise ValueError(f"Unsupported diarization backend: {backend}")
# ---------------------------------------------------------------------------
# Transcription
# ---------------------------------------------------------------------------
TRANSCRIPTION_STRATEGIES = {
    "hybrid": "Whole file + speakers from diarization (recommended)",
    "per_segment": "Per diarization segment",
}
def run_transcription(
audio_path: str,
model_id: str,
diar_segments: list[dict] | None = None,
strategy: str = "hybrid",
whisper_kwargs: dict | None = None,
) -> list[dict]:
backend, model, cfg = _load_transcription(model_id)
if backend == "faster-whisper":
if strategy == "hybrid" and diar_segments:
return _transcribe_faster_whisper_hybrid(audio_path, model, cfg, diar_segments, whisper_kwargs)
return _transcribe_faster_whisper_per_segment(audio_path, model, cfg, diar_segments, whisper_kwargs)
if backend == "transformers-whisper":
return _transcribe_transformers_whisper(audio_path, model, cfg, diar_segments, strategy)
if backend == "gigaam":
return _transcribe_gigaam(audio_path, model, cfg, diar_segments)
raise ValueError(f"Unsupported transcription backend: {backend}")
_HALLUCINATION_PATTERNS = {
"продолжение следует", "субтитры сделал", "субтитры делал",
"подписывайтесь на канал", "ставьте лайк", "до новых встреч",
"спасибо за просмотр", "всем пока", "subtitles by",
"thank you for watching", "subscribe", "to be continued",
"amara.org", "редактор субтитров",
}
_WHISPER_COMMON = {
"beam_size": 5,
"condition_on_previous_text": False,
"no_speech_threshold": 0.45,
"log_prob_threshold": -0.8,
"compression_ratio_threshold": 2.0,
"repetition_penalty": 1.2,
"temperature": [0.0, 0.2, 0.4],
"initial_prompt": _INITIAL_PROMPT or None,
}
def _is_hallucination(text: str) -> bool:
t = text.lower().strip().rstrip(".!…").strip()
if not t:
return True
for pat in _HALLUCINATION_PATTERNS:
if pat in t:
return True
if len(set(t.split())) <= 1 and len(t.split()) > 2:
return True
return False
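# Examples: _is_hallucination("Продолжение следует...") -> True (stock phrase),
# _is_hallucination("алло алло алло") -> True (one word repeated 3+ times),
# _is_hallucination("добрый день") -> False.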
_WHISPER_HYBRID_DEFAULTS = {
"beam_size": 5,
"condition_on_previous_text": False,
"no_speech_threshold": 0.5,
"log_prob_threshold": -0.7,
"compression_ratio_threshold": 2.4,
"repetition_penalty": 1.1,
"temperature": [0.0, 0.2, 0.4],
"initial_prompt": _INITIAL_PROMPT or None,
}
def _collect_whisper_words(segments_iter, label: str = "") -> list[dict]:
"""Collect word-level timestamps from Whisper, filtering hallucinated segments."""
raw_seg = 0
empty_seg = 0
hall_seg = 0
words: list[dict] = []
for s in segments_iter:
raw_seg += 1
text = s.text.strip()
if not text:
empty_seg += 1
continue
if _is_hallucination(text):
hall_seg += 1
logger.debug("Hallucination filtered: '%s'", text)
continue
if hasattr(s, "words") and s.words:
for w in s.words:
wt = w.word.strip()
if wt:
words.append({"start": round(w.start, 2), "end": round(w.end, 2), "word": wt})
else:
words.append({"start": round(s.start, 2), "end": round(s.end, 2), "word": text})
logger.info(
"Whisper %s: segs=%d, empty=%d, hall=%d, words=%d",
label, raw_seg, empty_seg, hall_seg, len(words),
)
return words
def _words_to_speaker_segments(
words: list[dict], diar_segments: list[dict], max_gap: float = 2.0,
) -> list[dict]:
"""Assign each word a speaker via diarization overlap, then group by speaker."""
if not words:
return []
for w in words:
mid = (w["start"] + w["end"]) / 2
best_spk = "UNKNOWN"
best_ov = 0.0
for d in diar_segments:
if d["start"] <= mid <= d["end"]:
best_spk = d["speaker"]
break
ov = min(w["end"], d["end"]) - max(w["start"], d["start"])
if ov > best_ov:
best_ov = ov
best_spk = d["speaker"]
w["speaker"] = best_spk
segments: list[dict] = []
cur = {"speaker": words[0]["speaker"], "start": words[0]["start"],
"end": words[0]["end"], "words": [words[0]["word"]]}
for w in words[1:]:
if w["speaker"] == cur["speaker"] and w["start"] - cur["end"] < max_gap:
cur["end"] = w["end"]
cur["words"].append(w["word"])
else:
segments.append({
"speaker": cur["speaker"], "start": cur["start"],
"end": cur["end"], "text": " ".join(cur["words"]),
})
cur = {"speaker": w["speaker"], "start": w["start"],
"end": w["end"], "words": [w["word"]]}
segments.append({
"speaker": cur["speaker"], "start": cur["start"],
"end": cur["end"], "text": " ".join(cur["words"]),
})
return segments
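# Example: with diar_segments [{"speaker": "S0", "start": 0.0, "end": 1.0}],
# words [{"start": 0.0, "end": 0.4, "word": "добрый"},
#        {"start": 0.5, "end": 0.9, "word": "день"}]
# yield one segment {"speaker": "S0", "start": 0.0, "end": 0.9, "text": "добрый день"}.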
def _resolve_prompt(params: dict) -> dict:
"""Replace '__domain__' marker with actual prompt, strip None."""
p = params.get("initial_prompt")
if p == "__domain__":
params["initial_prompt"] = _INITIAL_PROMPT or None
if params.get("initial_prompt") is None:
params.pop("initial_prompt", None)
return params
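# Example: _resolve_prompt({"initial_prompt": "__domain__", "beam_size": 5})
# substitutes the loaded domain prompt for the marker; if no prompt file is
# available, the "initial_prompt" key is dropped entirely.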
def _transcribe_faster_whisper_hybrid(
audio_path: str, model: Any, cfg: dict, diar_segments: list[dict],
whisper_kwargs: dict | None = None,
) -> list[dict]:
"""Transcribe full file with Whisper, assign speakers via word-level timestamps."""
lang = cfg.get("language", "ru")
params = _resolve_prompt({**_WHISPER_HYBRID_DEFAULTS, **(whisper_kwargs or {})})
params["word_timestamps"] = True
segments_iter, _info = model.transcribe(
audio_path, language=lang, vad_filter=True, **params,
)
words = _collect_whisper_words(segments_iter, "hybrid-vad")
if len(words) < 3:
logger.warning("VAD produced too few words (%d), retrying without VAD", len(words))
segments_iter2, _ = model.transcribe(
audio_path, language=lang, vad_filter=False, **params,
)
words = _collect_whisper_words(segments_iter2, "hybrid-novad")
return _words_to_speaker_segments(words, diar_segments)
def _transcribe_faster_whisper_per_segment(
audio_path: str, model: Any, cfg: dict, diar_segments: list[dict] | None,
whisper_kwargs: dict | None = None,
) -> list[dict]:
"""Transcribe each diarization segment individually."""
lang = cfg.get("language", "ru")
params = _resolve_prompt({**_WHISPER_COMMON, **(whisper_kwargs or {})})
if not diar_segments:
segments_iter, _info = model.transcribe(
audio_path, language=lang, vad_filter=True, **params,
)
return [
{"start": round(s.start, 2), "end": round(s.end, 2), "text": s.text.strip()}
for s in segments_iter
if s.text.strip() and not _is_hallucination(s.text)
]
audio, sr = load_audio(audio_path)
pad_sec = 0.5
results: list[dict] = []
with tempfile.TemporaryDirectory() as tmpdir:
for i, seg in enumerate(diar_segments):
start_s = max(0, seg["start"] - pad_sec)
end_s = min(len(audio) / sr, seg["end"] + pad_sec)
start_sample = int(start_s * sr)
end_sample = int(end_s * sr)
chunk = audio[start_sample:end_sample]
if len(chunk) < int(sr * 0.15):
results.append({**seg, "text": ""})
continue
seg_path = os.path.join(tmpdir, f"seg_{i}.wav")
sf.write(seg_path, chunk, sr)
try:
segs_iter, info = model.transcribe(
seg_path, language=lang, vad_filter=False,
without_timestamps=True, **params,
)
good_texts: list[str] = []
for s in segs_iter:
t = s.text.strip()
if not t:
continue
if _is_hallucination(t):
continue
good_texts.append(t)
text = " ".join(good_texts)
except Exception as exc:
logger.warning("Whisper failed on segment %d: %s", i, exc)
text = ""
results.append({**seg, "text": text})
return results
def _transcribe_transformers_whisper(
audio_path: str, pipe: Any, cfg: dict,
diar_segments: list[dict] | None, strategy: str,
) -> list[dict]:
"""Transcribe using HuggingFace transformers pipeline (for native HF Whisper checkpoints)."""
lang = cfg.get("language", "ru")
generate_kwargs = {"language": lang, "task": "transcribe"}
    if _INITIAL_PROMPT and hasattr(pipe, "tokenizer") and hasattr(pipe.tokenizer, "get_prompt_ids"):
        generate_kwargs["prompt_ids"] = pipe.tokenizer.get_prompt_ids(
            _INITIAL_PROMPT, return_tensors="np"
        )
use_word_ts = strategy == "hybrid" and diar_segments
result = pipe(
audio_path,
return_timestamps="word" if use_word_ts else True,
generate_kwargs=generate_kwargs,
chunk_length_s=30,
batch_size=8,
)
if use_word_ts:
words: list[dict] = []
for chunk in result.get("chunks", []):
text = chunk.get("text", "").strip()
if not text:
continue
ts = chunk.get("timestamp", (0.0, 0.0))
start = ts[0] if ts[0] is not None else 0.0
end = ts[1] if ts[1] is not None else start
words.append({"start": round(start, 2), "end": round(end, 2), "word": text})
logger.info("Transformers-whisper word-level: %d words", len(words))
if words:
return _words_to_speaker_segments(words, diar_segments)
whisper_segments: list[dict] = []
if not use_word_ts:
for chunk in result.get("chunks", []):
text = chunk.get("text", "").strip()
ts = chunk.get("timestamp", (0.0, 0.0))
start = ts[0] if ts[0] is not None else 0.0
end = ts[1] if ts[1] is not None else start
if not text or _is_hallucination(text):
continue
whisper_segments.append({"start": round(start, 2), "end": round(end, 2), "text": text})
logger.info("Transformers-whisper: %d segments", len(whisper_segments))
if not whisper_segments and result.get("text", "").strip():
whisper_segments = [{"start": 0.0, "end": 0.0, "text": result["text"].strip()}]
if diar_segments:
return _align_speakers(whisper_segments, diar_segments)
return whisper_segments
def _transcribe_gigaam(
audio_path: str, model: Any, cfg: dict, diar_segments: list[dict] | None
) -> list[dict]:
if not diar_segments:
result = model.transcribe(audio_path)
text = result if isinstance(result, str) else (result[0] if result else "")
return [{"start": 0.0, "end": 0.0, "text": text, "speaker": "UNKNOWN"}]
audio, sr = load_audio(audio_path)
results: list[dict] = []
with tempfile.TemporaryDirectory() as tmpdir:
for i, seg in enumerate(diar_segments):
start_sample = int(seg["start"] * sr)
end_sample = int(seg["end"] * sr)
chunk = audio[start_sample:end_sample]
if len(chunk) < int(sr * 0.1):
results.append({**seg, "text": ""})
continue
seg_path = os.path.join(tmpdir, f"seg_{i}.wav")
sf.write(seg_path, chunk, sr)
transcription = model.transcribe(seg_path)
text = transcription if isinstance(transcription, str) else (transcription[0] if transcription else "")
results.append({**seg, "text": text.strip()})
return results
# ---------------------------------------------------------------------------
# Speaker alignment
# ---------------------------------------------------------------------------
def _align_speakers(
trans_segments: list[dict], diar_segments: list[dict]
) -> list[dict]:
"""Assign speaker labels from diarization to transcription segments by overlap."""
results: list[dict] = []
for t in trans_segments:
best_speaker = "UNKNOWN"
best_ov = 0.0
for d in diar_segments:
ov = min(t["end"], d["end"]) - max(t["start"], d["start"])
if ov > best_ov:
best_ov = ov
best_speaker = d["speaker"]
results.append({**t, "speaker": best_speaker})
return results
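# Example: a transcription segment spanning 0.0-3.0 that overlaps diarized
# turns S0 (0.0-2.0, overlap 2.0 s) and S1 (2.0-3.0, overlap 1.0 s) is
# labeled S0; segments overlapping nothing stay "UNKNOWN".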
# ---------------------------------------------------------------------------
# Full pipeline
# ---------------------------------------------------------------------------
def process_file(
audio_path: str,
diar_model_id: str,
trans_model_id: str,
token: str | None = None,
start_sec: float | None = None,
end_sec: float | None = None,
min_speakers: int | None = 2,
max_speakers: int | None = None,
strategy: str = "hybrid",
whisper_kwargs: dict | None = None,
merge_gap: float = 0.3,
merge_min_dur: float = 0.3,
) -> dict:
"""Diarize + transcribe a single audio file. Optionally trim to [start_sec, end_sec]."""
work_path = audio_path
if start_sec is not None or end_sec is not None:
audio, sr = load_audio(audio_path)
audio = trim_audio(audio, sr, start_sec, end_sec)
work_path = save_audio_tmp(audio, sr)
diar_segments = run_diarization(
work_path, diar_model_id, token,
min_speakers=min_speakers, max_speakers=max_speakers,
merge_gap=merge_gap, merge_min_dur=merge_min_dur,
)
trans_segments = run_transcription(
work_path, trans_model_id, diar_segments,
strategy=strategy, whisper_kwargs=whisper_kwargs,
)
for seg in trans_segments:
if seg.get("text"):
seg["text"] = apply_corrections(seg["text"])
if work_path != audio_path:
Path(work_path).unlink(missing_ok=True)
speakers = sorted(set(s.get("speaker", "") for s in diar_segments))
return {
"file": Path(audio_path).stem,
"diarization_model": diar_model_id,
"transcription_model": trans_model_id,
"num_speakers": len(speakers),
"speakers": speakers,
"diarization_segments": diar_segments,
"transcription": trans_segments,
}
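# Minimal usage sketch (model IDs below are placeholders; real values must be
# keys in models_config.DIARIZATION_MODELS / TRANSCRIPTION_MODELS):
#
#   result = process_file(
#       "call.wav",
#       diar_model_id="pyannote-3.1",        # hypothetical key
#       trans_model_id="faster-whisper-lg",  # hypothetical key
#       token=os.environ.get("HF_TOKEN"),
#       start_sec=parse_time_str("0:05"),
#       end_sec=parse_time_str("8:30"),
#   )
#   print(json.dumps(result, ensure_ascii=False, indent=2))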