| """ |
| Parity check: transformers vs whisper.cpp Q4_K GGML on sample_podcast.wav. |
| Reports per-speaker transcripts, CER, and normalized transcript lengths. |
| """ |
| from pathlib import Path |
| import os, re, subprocess, sys, time |
|
|
| import numpy as np |
| import soundfile as sf |
| import torch |
| from transformers import WhisperForConditionalGeneration, WhisperProcessor |
|
|
# Model identifier passed to from_pretrained.
MODEL = "Trelis/Chorus-v1"

# Two directories above the one that contains this script.
ROOT = Path(__file__).parents[2]
# Sample recording that both backends transcribe.
WAV = ROOT.joinpath("experimental/multi-speaker/hf_space/static/sample_podcast.wav")

# whisper.cpp CLI binary and quantized weights, both addressed relative to
# this script's own directory.
WHISPER_CLI = Path(__file__).parent.joinpath("whisper.cpp/build/bin/whisper-cli")
GGUF = Path(__file__).parent.joinpath("models/ggml-chorus-v1-q4_k.bin")
|
|
| |
# Inference target and precision for the transformers reference run.
device = "cpu"
dtype = torch.float32

print(f"[parity] device={device}")

# HF_TOKEN is optional: os.environ.get yields None when it is unset.
token = os.environ.get("HF_TOKEN")
proc = WhisperProcessor.from_pretrained(MODEL, token=token)

# Load the weights, move them to the target device, then switch to eval mode.
model = WhisperForConditionalGeneration.from_pretrained(MODEL, token=token, dtype=dtype)
model = model.to(device)
model = model.eval()

# Timestamp settings applied to every generate() call.
# NOTE(review): 1500 presumably lets the first timestamp fall anywhere in the
# 30 s window (1500 x 0.02 s) -- confirm against the Whisper generation docs.
model.generation_config.predict_timestamps = True
model.generation_config.max_initial_timestamp_index = 1500
|
|
tok = proc.tokenizer

# Vocabulary ids of the special tokens used to build the forced decoder prompt.
ids = {
    "en": tok.convert_tokens_to_ids("<|en|>"),
    "transcribe": tok.convert_tokens_to_ids("<|transcribe|>"),
    "speaker1": tok.convert_tokens_to_ids("<|speaker1|>"),
    "speaker2": tok.convert_tokens_to_ids("<|speaker2|>"),
}
|
|
# Read the sample recording: samples plus the file's sample rate.
arr, sr = sf.read(WAV)
# Validate with an explicit exception rather than `assert`, so the check is
# not stripped when Python runs with -O.
if sr != 16_000:
    raise ValueError(f"expected 16 kHz audio, got {sr} Hz: {WAV}")
arr = np.asarray(arr, dtype=np.float32)
if arr.ndim > 1:
    # Multi-channel input: average over axis 1 (the channel axis in
    # soundfile's frames x channels layout) to get a single channel.
    arr = arr.mean(axis=1)
# Input features for the model, moved to the inference device and dtype.
feats = proc.feature_extractor([arr], sampling_rate=16_000, return_tensors="pt").input_features.to(device).to(dtype)
|
|
|
|
def hf_transcribe(speaker_id: int) -> str:
    """Transcribe the sample audio with the transformers model for one speaker.

    Decoder positions 1-3 are forced to the ``<|en|>``, ``<|transcribe|>`` and
    ``<|speaker{speaker_id}|>`` token ids. The first generated sequence is
    decoded with special tokens skipped and surrounding whitespace stripped.
    """
    prompt_tokens = (ids["en"], ids["transcribe"], ids[f"speaker{speaker_id}"])
    forced = [[position, token_id] for position, token_id in enumerate(prompt_tokens, start=1)]
    with torch.no_grad():
        generated = model.generate(
            feats,
            forced_decoder_ids=forced,
            return_timestamps=True,
            max_new_tokens=444,
        )
    text = tok.decode(generated[0], skip_special_tokens=True)
    return text.strip()
|
|
|
|
def cli_transcribe(speaker_id: int) -> str:
    """Transcribe the sample audio with whisper-cli for one speaker.

    Runs the whisper.cpp CLI on ``WAV`` with the quantized model, then joins
    the text part of every ``[...] text`` line on stdout with single spaces.

    Raises:
        RuntimeError: if whisper-cli exits with a non-zero status. Without
            this check a failed run would yield an empty transcript and be
            reported as a CER figure instead of an error.
    """
    cmd = [
        str(WHISPER_CLI), "-m", str(GGUF), "-f", str(WAV),
        "-l", "en", "--speaker", str(speaker_id), "-nfa", "-np",
    ]
    r = subprocess.run(cmd, capture_output=True, text=True)
    if r.returncode != 0:
        raise RuntimeError(
            f"whisper-cli exited with status {r.returncode}: {r.stderr.strip()}"
        )

    # Segment lines look like "[<anything>] text"; other lines are ignored.
    segment_re = re.compile(r"^\[.*?\]\s+(.*)$")
    matches = (segment_re.match(line.strip()) for line in r.stdout.splitlines())
    return " ".join(m.group(1) for m in matches if m).strip()
|
|
|
|
def normalize(s: str) -> str:
    """Lower-case *s*, blank out punctuation, and collapse whitespace.

    Every character that is not a word character, whitespace, or an
    apostrophe becomes a space; runs of whitespace are then squeezed to a
    single space and both ends are trimmed.
    """
    lowered = s.lower()
    without_punct = re.sub(r"[^\w\s']", " ", lowered)
    return re.sub(r"\s+", " ", without_punct).strip()
|
|
|
|
def cer(ref: str, hyp: str) -> float:
    """Return the character error rate of *hyp* measured against *ref*.

    Both strings are passed through ``normalize`` first. The rate is the
    character-level Levenshtein distance divided by the length of the
    normalized reference. An empty reference yields 1.0 when the hypothesis
    is non-empty and 0.0 otherwise.
    """
    ref = normalize(ref)
    hyp = normalize(hyp)
    if not ref:
        return 1.0 if hyp else 0.0

    # Rolling-row DP: row[j] is the edit distance between the reference
    # prefix consumed so far and hyp[:j].
    row = list(range(len(hyp) + 1))
    for i, ref_ch in enumerate(ref, start=1):
        nxt = [i]
        for j, hyp_ch in enumerate(hyp, start=1):
            substitute = row[j - 1] + (0 if ref_ch == hyp_ch else 1)
            nxt.append(min(nxt[j - 1] + 1, row[j] + 1, substitute))
        row = nxt
    return row[-1] / len(ref)
|
|
|
|
# Reference transcripts from transformers. The header is built from the
# dtype/device the model was actually loaded with, so the label cannot drift
# from the configuration.
print(f"\n--- Transformers ({dtype}, {device}) ---")
hf1 = hf_transcribe(1)
hf2 = hf_transcribe(2)
print(f"S1 [hf]: {hf1}\n")
print(f"S2 [hf]: {hf2}\n")

# Transcripts from the quantized whisper.cpp model.
print("--- whisper.cpp Q4_K ---")
gg1 = cli_transcribe(1)
gg2 = cli_transcribe(2)
print(f"S1 [gg]: {gg1}\n")
print(f"S2 [gg]: {gg2}\n")

# CER of each whisper.cpp transcript against the transformers transcript,
# alongside the normalized lengths of both strings.
print("--- Parity (CER of Q4_K vs transformers) ---")
print(f" speaker1 CER = {cer(hf1, gg1)*100:.2f}% (len hf={len(normalize(hf1))}, gg={len(normalize(gg1))})")
print(f" speaker2 CER = {cer(hf2, gg2)*100:.2f}% (len hf={len(normalize(hf2))}, gg={len(normalize(gg2))})")
|
|