# Chorus-v1-GGML / scripts/parity_check.py
# Author: RonanMcGovern
# Commit 5887db6 (verified): Initial upload: f16 + q4_k/q5_k/q8_0 GGML, scripts, whisper.cpp patch
"""
Parity check: transformers vs whisper.cpp Q4_K GGML on sample_podcast.wav.
Reports per-speaker transcripts, CER, and segment count.
"""
from pathlib import Path
import os, re, subprocess, sys, time
import numpy as np
import soundfile as sf
import torch
from transformers import WhisperForConditionalGeneration, WhisperProcessor
ROOT = Path(__file__).parents[2]
WAV = ROOT / "experimental/multi-speaker/hf_space/static/sample_podcast.wav"
WHISPER_CLI = Path(__file__).parent / "whisper.cpp/build/bin/whisper-cli"
GGUF = Path(__file__).parent / "models/ggml-chorus-v1-q4_k.bin"
MODEL = "Trelis/Chorus-v1"

# Whisper + MPS + fp16 is broken in current transformers — falls back to CPU fp32.
device = "cpu"
dtype = torch.float32
print(f"[parity] device={device}")

# Optional auth token; None when HF_TOKEN is unset.
token = os.environ.get("HF_TOKEN")
proc = WhisperProcessor.from_pretrained(MODEL, token=token)
model = WhisperForConditionalGeneration.from_pretrained(MODEL, token=token, dtype=dtype).to(device).eval()
model.generation_config.predict_timestamps = True
# NOTE(review): presumably lets the first predicted timestamp fall anywhere in
# the 30 s window (1500 x 20 ms steps) so a speaker who starts late is not
# clipped — confirm.
model.generation_config.max_initial_timestamp_index = 1500
tok = proc.tokenizer

# Special-token ids used to build the forced decoder prompt (language, task,
# and which speaker to transcribe).
ids = {n: tok.convert_tokens_to_ids(t) for n, t in [
    ("en", "<|en|>"), ("transcribe", "<|transcribe|>"),
    ("speaker1", "<|speaker1|>"), ("speaker2", "<|speaker2|>"),
]}
# convert_tokens_to_ids does not raise for unknown tokens: it returns the unk
# id (or None). Forcing that id would silently yield a wrong/empty transcript,
# so fail loudly if the checkpoint lacks any of the prompt tokens.
_missing_tokens = [n for n, i in ids.items() if i is None or i == tok.unk_token_id]
if _missing_tokens:
    raise RuntimeError(f"{MODEL} tokenizer is missing special tokens: {_missing_tokens}")
# Load the sample and build the log-mel input for the transformers model.
arr, sr = sf.read(WAV)
if sr != 16_000:
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    raise ValueError(f"{WAV}: expected a 16 kHz sample rate, got {sr} Hz")
arr = np.asarray(arr, dtype=np.float32)
if arr.ndim > 1:
    # sf.read returns (frames, channels) for multi-channel files: average to mono.
    arr = arr.mean(axis=1)
# By default the feature extractor pads/truncates to a single 30 s window,
# whereas whisper-cli transcribes the whole file, so for longer audio the two
# backends do not see the same input and the CER below is inflated.
if len(arr) > 30 * 16_000:
    print(
        f"[parity] WARNING: {WAV.name} is {len(arr) / 16_000:.1f}s long; transformers "
        "only sees the first 30s while whisper-cli sees all of it, so CER is not like-for-like.",
        file=sys.stderr,
    )
feats = proc.feature_extractor([arr], sampling_rate=16_000, return_tensors="pt").input_features.to(device).to(dtype)
def hf_transcribe(speaker_id: int) -> str:
    """Transcribe the preloaded ``feats`` for one speaker with the transformers model.

    Decoder positions 1-3 are forced to the English, transcribe, and
    ``speaker{speaker_id}`` tokens; generation runs with timestamps enabled and
    special tokens are dropped when decoding.

    Args:
        speaker_id: Speaker index; ``ids`` must contain ``f"speaker{speaker_id}"``
            (otherwise ``KeyError``).

    Returns:
        The decoded transcript with surrounding whitespace stripped.
    """
    prompt_tokens = (ids["en"], ids["transcribe"], ids[f"speaker{speaker_id}"])
    forced_ids = [[position, token_id] for position, token_id in enumerate(prompt_tokens, start=1)]
    with torch.no_grad():
        # NOTE(review): 444 presumably = Whisper's 448-token decoder limit minus
        # the 4-token prompt — confirm before changing.
        generated = model.generate(
            feats,
            forced_decoder_ids=forced_ids,
            return_timestamps=True,
            max_new_tokens=444,
        )
    text = tok.decode(generated[0], skip_special_tokens=True)
    return text.strip()
def cli_transcribe(speaker_id: int) -> str:
    """Transcribe WAV for one speaker with the patched whisper-cli and the GGML model.

    Args:
        speaker_id: Speaker index passed to the patched CLI's ``--speaker`` flag.

    Returns:
        The text of every timestamped segment printed on stdout, joined with
        single spaces.

    Raises:
        RuntimeError: If whisper-cli exits with a non-zero status. Its stderr is
            included so that a bad flag or unreadable model is not mistaken for
            an empty transcript (which would otherwise show up as 100% CER).
    """
    r = subprocess.run(
        [str(WHISPER_CLI), "-m", str(GGUF), "-f", str(WAV),
         "-l", "en", "--speaker", str(speaker_id), "-nfa", "-np"],
        capture_output=True, text=True,
        # Decode explicitly as UTF-8 instead of the platform locale default, and
        # replace undecodable bytes rather than crash mid-run.
        encoding="utf-8", errors="replace",
    )
    if r.returncode != 0:
        raise RuntimeError(
            f"whisper-cli exited with status {r.returncode} for speaker {speaker_id}:\n"
            f"{r.stderr.strip()}"
        )
    # Segments print as "[00:00:00.000 --> 00:00:01.640] text"; keep only the text.
    out = []
    for line in r.stdout.splitlines():
        m = re.match(r"^\[.*?\]\s+(.*)$", line.strip())
        if m:
            out.append(m.group(1))
    return " ".join(out).strip()
def normalize(s: str) -> str:
    """Canonicalize a transcript for CER scoring.

    Lower-cases *s*, turns every character that is not a word character,
    whitespace, or an apostrophe into a space, then collapses whitespace runs
    to single spaces and trims both ends.

    >>> normalize("  Hello, World!  It's  me. ")
    "hello world it's me"
    """
    without_punct = re.sub(r"[^\w\s']", " ", s.lower())
    return " ".join(without_punct.split())
def cer(ref: str, hyp: str) -> float:
    """Return the character error rate of *hyp* measured against *ref*.

    Both strings go through ``normalize`` first. The result is the
    character-level Levenshtein distance divided by the normalized reference
    length, so it can exceed 1.0 when *hyp* is much longer than *ref*. An empty
    normalized reference scores 0.0 against an empty hypothesis and 1.0
    otherwise.
    """
    a, b = normalize(ref), normalize(hyp)
    if not a:
        return 1.0 if b else 0.0
    # Rolling two-row DP: `row[j]` is the edit distance between the prefix of
    # `a` processed so far and b[:j]; `nxt` is built for the next prefix.
    row = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        nxt = [i]
        for j, cb in enumerate(b, start=1):
            nxt.append(min(
                nxt[j - 1] + 1,           # insertion
                row[j] + 1,               # deletion
                row[j - 1] + (ca != cb),  # substitution (0 on match)
            ))
        row = nxt
    return row[-1] / len(a)
# Run each backend for every speaker, print the transcripts, then report CER of
# the GGML output against the transformers output.
SPEAKERS = (1, 2)

# Label from the actual runtime config (previously hard-coded as "bf16 MPS",
# which did not match the CPU/fp32 setup above).
print(f"\n--- Transformers ({device}, {dtype}) ---")
hf = {s: hf_transcribe(s) for s in SPEAKERS}
for s, text in hf.items():
    print(f"S{s} [hf]: {text}\n")

print("--- whisper.cpp Q4_K ---")
gg = {s: cli_transcribe(s) for s in SPEAKERS}
for s, text in gg.items():
    print(f"S{s} [gg]: {text}\n")

print("--- Parity (CER of Q4_K vs transformers) ---")
for s in SPEAKERS:
    print(
        f" speaker{s} CER = {cer(hf[s], gg[s])*100:.2f}% "
        f"(len hf={len(normalize(hf[s]))}, gg={len(normalize(gg[s]))})"
    )