"""
Parity check: transformers vs whisper.cpp Q4_K GGML on sample_podcast.wav.

Transcribes the sample once per speaker token (<|speaker1|>, <|speaker2|>) with
both backends, then prints the per-speaker transcripts, the number of segments
whisper-cli emitted, and the CER of the whisper.cpp output measured against the
transformers output.
"""
from pathlib import Path
import os
import re
import subprocess
import sys
import time

import numpy as np
import soundfile as sf
import torch
from transformers import WhisperForConditionalGeneration, WhisperProcessor

ROOT = Path(__file__).parents[2]
WAV = ROOT / "experimental/multi-speaker/hf_space/static/sample_podcast.wav"
WHISPER_CLI = Path(__file__).parent / "whisper.cpp/build/bin/whisper-cli"
GGUF = Path(__file__).parent / "models/ggml-chorus-v1-q4_k.bin"
MODEL = "Trelis/Chorus-v1"
SAMPLE_RATE = 16_000

# whisper-cli prints each segment as "[00:00:00.000 --> 00:00:01.640] text".
_SEGMENT_RE = re.compile(r"^\[.*?\]\s+(.*)$")

# Fail fast on missing inputs before spending time loading the HF model.
_missing = [p for p in (WAV, WHISPER_CLI, GGUF) if not p.exists()]
if _missing:
    sys.exit("[parity] missing required file(s): " + ", ".join(map(str, _missing)))

# Whisper + MPS + fp16 is broken in current transformers — falls back to CPU fp32.
device = "cpu"
dtype = torch.float32
print(f"[parity] device={device}")

token = os.environ.get("HF_TOKEN")
proc = WhisperProcessor.from_pretrained(MODEL, token=token)
model = WhisperForConditionalGeneration.from_pretrained(MODEL, token=token, dtype=dtype).to(device).eval()
model.generation_config.predict_timestamps = True
model.generation_config.max_initial_timestamp_index = 1500

tok = proc.tokenizer
ids = {n: tok.convert_tokens_to_ids(t) for n, t in [
    ("en", "<|en|>"),
    ("transcribe", "<|transcribe|>"),
    ("speaker1", "<|speaker1|>"),
    ("speaker2", "<|speaker2|>"),
]}
# convert_tokens_to_ids may map an unknown token to the unk id instead of failing;
# forcing that id into the decoder would make the whole comparison meaningless.
_bad = [n for n, i in ids.items() if i is None or i == tok.unk_token_id]
if _bad:
    sys.exit(f"[parity] tokenizer for {MODEL} lacks token(s): {', '.join(_bad)}")

arr, sr = sf.read(WAV)
arr = np.asarray(arr, dtype=np.float32)
if arr.ndim > 1:
    arr = arr.mean(axis=1)  # downmix multi-channel audio to mono
if sr != SAMPLE_RATE:
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    raise ValueError(f"expected {SAMPLE_RATE} Hz audio, got {sr} Hz: {WAV}")

# By default the Whisper feature extractor pads/truncates to a single window, so
# transformers only sees the start of longer audio while whisper-cli transcribes
# the whole file — CER would then not be a like-for-like comparison.
_window = getattr(proc.feature_extractor, "n_samples", 30 * SAMPLE_RATE)
if len(arr) > _window:
    print(
        f"[parity] WARNING: audio is {len(arr) / SAMPLE_RATE:.1f}s but transformers only "
        f"sees the first {_window / SAMPLE_RATE:.0f}s; CER will be inflated.",
        file=sys.stderr,
    )
feats = proc.feature_extractor([arr], sampling_rate=SAMPLE_RATE, return_tensors="pt").input_features.to(device).to(dtype)


def hf_transcribe(speaker_id: int) -> str:
    """Transcribe the module-level `feats` with the transformers model.

    Decoder positions 1-3 are forced to <|en|>, <|transcribe|> and
    <|speaker{speaker_id}|>. Returns the decoded text of the first (only)
    sequence with special tokens skipped, stripped of surrounding whitespace.
    Raises KeyError if `speaker_id` has no entry in `ids`.
    """
    forced = [[1, ids["en"]], [2, ids["transcribe"]], [3, ids[f"speaker{speaker_id}"]]]
    with torch.no_grad():
        out = model.generate(feats, forced_decoder_ids=forced, return_timestamps=True, max_new_tokens=444)
    return tok.decode(out[0], skip_special_tokens=True).strip()


def _cli_segments(speaker_id: int) -> list[str]:
    """Run whisper-cli on WAV for `speaker_id` and return the text of each printed segment.

    Raises RuntimeError (including whisper-cli's stderr) if the process exits non-zero.
    """
    r = subprocess.run(
        [str(WHISPER_CLI), "-m", str(GGUF), "-f", str(WAV), "-l", "en",
         "--speaker", str(speaker_id), "-nfa", "-np"],
        capture_output=True, text=True,
    )
    if r.returncode != 0:
        # Without this, a crash or bad flag looks like an empty transcript (100% CER).
        raise RuntimeError(
            f"whisper-cli exited with code {r.returncode} for speaker {speaker_id}:\n{r.stderr.strip()}"
        )
    return [m.group(1) for line in r.stdout.splitlines() if (m := _SEGMENT_RE.match(line.strip()))]


def cli_transcribe(speaker_id: int) -> str:
    """Return whisper-cli's transcript for `speaker_id` as one space-joined string."""
    return " ".join(_cli_segments(speaker_id)).strip()


def normalize(s: str) -> str:
    """Lower-case, replace punctuation (except apostrophes) with spaces, collapse whitespace."""
    s = re.sub(r"[^\w\s']", " ", s.lower())
    s = re.sub(r"\s+", " ", s).strip()
    return s


def cer(ref: str, hyp: str) -> float:
    """Character edit distance / len(ref). Levenshtein on chars.

    Both strings are passed through `normalize` first. An empty normalized
    reference yields 0.0 when the hypothesis is also empty, else 1.0.
    Uses a two-row DP table, so memory is O(len(hyp)).
    """
    ref, hyp = normalize(ref), normalize(hyp)
    if not ref:
        return 1.0 if hyp else 0.0
    m, n = len(ref), len(hyp)
    prev = list(range(n + 1))
    for i in range(1, m + 1):
        cur = [i] + [0] * n
        for j in range(1, n + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            # insertion, deletion, substitution/match
            cur[j] = min(cur[j - 1] + 1, prev[j] + 1, prev[j - 1] + cost)
        prev = cur
    return prev[n] / m


# Derive the label from the actual settings so it cannot drift from `device`/`dtype`.
dtype_name = str(dtype).split(".")[-1]
print(f"\n--- Transformers ({dtype_name} {device}) ---")
hf1 = hf_transcribe(1)
hf2 = hf_transcribe(2)
print(f"S1 [hf]: {hf1}\n")
print(f"S2 [hf]: {hf2}\n")

print("--- whisper.cpp Q4_K ---")
gg_segs1 = _cli_segments(1)
gg_segs2 = _cli_segments(2)
gg1 = " ".join(gg_segs1).strip()
gg2 = " ".join(gg_segs2).strip()
print(f"S1 [gg] ({len(gg_segs1)} segments): {gg1}\n")
print(f"S2 [gg] ({len(gg_segs2)} segments): {gg2}\n")

print("--- Parity (CER of Q4_K vs transformers) ---")
print(f"  speaker1 CER = {cer(hf1, gg1)*100:.2f}%  (len hf={len(normalize(hf1))}, gg={len(normalize(gg1))})")
print(f"  speaker2 CER = {cer(hf2, gg2)*100:.2f}%  (len hf={len(normalize(hf2))}, gg={len(normalize(gg2))})")