File size: 3,568 Bytes
5887db6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
"""
Parity check: transformers vs whisper.cpp Q4_K GGML on sample_podcast.wav.
Reports per-speaker transcripts, CER, and segment count.
"""
from pathlib import Path
import os, re, subprocess, sys, time

import numpy as np
import soundfile as sf
import torch
from transformers import WhisperForConditionalGeneration, WhisperProcessor

# Directory two levels above the folder containing this script (the repo root
# that the sample-audio path below is relative to).
ROOT = Path(__file__).parent.parents[1]
# Sample clip taken from the hf_space static assets.
WAV = ROOT / "experimental/multi-speaker/hf_space/static/sample_podcast.wav"
# whisper.cpp CLI binary and the Q4_K GGML model, both located beside this script.
WHISPER_CLI, GGUF = (
    Path(__file__).parent / rel
    for rel in ("whisper.cpp/build/bin/whisper-cli", "models/ggml-chorus-v1-q4_k.bin")
)
MODEL = "Trelis/Chorus-v1"  # model id passed to from_pretrained for the transformers side

# Whisper + MPS + fp16 is broken in current transformers — falls back to CPU fp32.
device = "cpu"
dtype = torch.float32

print(f"[parity] device={device}")
# Optional Hub token read from HF_TOKEN (None when the variable is unset).
token = os.environ.get("HF_TOKEN")
proc = WhisperProcessor.from_pretrained(MODEL, token=token)
model = WhisperForConditionalGeneration.from_pretrained(MODEL, token=token, dtype=dtype).to(device).eval()
model.generation_config.predict_timestamps = True
# NOTE(review): presumably lets the first timestamp land anywhere in the 30 s
# window (1500 positions); confirm against the model's timestamp resolution.
model.generation_config.max_initial_timestamp_index = 1500

tok = proc.tokenizer
# Token ids used to force the decoder prompt (language, task, target speaker).
ids = {n: tok.convert_tokens_to_ids(t) for n, t in [
    ("en", "<|en|>"), ("transcribe", "<|transcribe|>"),
    ("speaker1", "<|speaker1|>"), ("speaker2", "<|speaker2|>"),
]}
# convert_tokens_to_ids does not raise for tokens absent from the vocabulary; it
# yields the unknown-token id (or None). Forcing that id would silently decode
# with the wrong prompt, so fail loudly instead.
_missing = [n for n, i in ids.items() if i is None or i == tok.unk_token_id]
if _missing:
    raise ValueError(f"tokenizer for {MODEL} is missing prompt tokens: {_missing}")

# Load the sample clip and build Whisper input features from it.
arr, sr = sf.read(WAV)
# The feature extractor below is told the audio is 16 kHz, so reject any other
# rate explicitly (an `assert` would be stripped when Python runs with -O).
if sr != 16_000:
    raise ValueError(f"{WAV} has sample rate {sr} Hz; expected 16000 Hz")
arr = np.asarray(arr, dtype=np.float32)
if arr.ndim > 1:
    # Down-mix to mono by averaging over axis 1 (assumed to be the channel axis).
    arr = arr.mean(axis=1)
feats = proc.feature_extractor([arr], sampling_rate=16_000, return_tensors="pt").input_features.to(device).to(dtype)


def hf_transcribe(speaker_id: int) -> str:
    """Transcribe one speaker from the module-level ``feats`` with the transformers model.

    Decoder positions 1, 2 and 3 are forced to the ``en``, ``transcribe`` and
    ``speaker{speaker_id}`` token ids from ``ids``. Any ``speaker_id`` other than
    1 or 2 raises KeyError. Timestamps are requested and at most 444 new tokens
    are generated (presumably the 448-token decoder limit minus the 4-token
    prompt; confirm). Returns the first generated sequence decoded with special
    tokens skipped, with surrounding whitespace stripped.
    """
    prompt = (ids["en"], ids["transcribe"], ids[f"speaker{speaker_id}"])
    forced = [[position, token_id] for position, token_id in enumerate(prompt, start=1)]
    with torch.no_grad():
        generated = model.generate(
            feats,
            forced_decoder_ids=forced,
            return_timestamps=True,
            max_new_tokens=444,
        )
    text = tok.decode(generated[0], skip_special_tokens=True)
    return text.strip()


def cli_transcribe(speaker_id: int) -> str:
    """Transcribe one speaker from WAV with the whisper.cpp CLI and the GGML model.

    Runs WHISPER_CLI with model GGUF on WAV, language ``en``,
    ``--speaker speaker_id`` and the ``-nfa -np`` flags. It keeps the text of every
    stdout line that starts with a bracketed timestamp, joined with single spaces.

    Raises:
        RuntimeError: if whisper-cli exits with a non-zero status. Without this,
            a failed run would look like an empty transcript and be reported as
            a ~100% CER.
    """
    cmd = [str(WHISPER_CLI), "-m", str(GGUF), "-f", str(WAV),
           "-l", "en", "--speaker", str(speaker_id), "-nfa", "-np"]
    r = subprocess.run(cmd, capture_output=True, text=True)
    if r.returncode != 0:
        raise RuntimeError(
            f"whisper-cli failed for speaker {speaker_id} "
            f"(exit code {r.returncode}):\n{r.stderr.strip()}"
        )
    # Segments print as "[00:00:00.000 --> 00:00:01.640]   text"
    segment = re.compile(r"^\[.*?\]\s+(.*)$")
    matches = (segment.match(line.strip()) for line in r.stdout.splitlines())
    return " ".join(m.group(1) for m in matches if m).strip()


def normalize(s: str) -> str:
    """Lower-case *s*, replace every character that is not a word character,
    whitespace or an apostrophe with a space, and collapse whitespace runs to
    single spaces with no leading/trailing whitespace."""
    lowered = s.lower()
    without_punct = re.sub(r"[^\w\s']", " ", lowered)
    return " ".join(without_punct.split())


def cer(ref: str, hyp: str) -> float:
    """Character error rate of *hyp* against *ref*.

    Both strings are normalized first. The result is the character-level
    Levenshtein distance divided by the normalized reference length, so it can
    exceed 1.0 when *hyp* is much longer than *ref*. An empty normalized
    reference yields 0.0 if *hyp* is also empty, else 1.0.
    """
    ref, hyp = normalize(ref), normalize(hyp)
    if not ref:
        return 1.0 if hyp else 0.0
    # One DP row updated in place; `diag` carries the previous row's value at j-1.
    row = list(range(len(hyp) + 1))
    for i, ref_ch in enumerate(ref, start=1):
        diag, row[0] = row[0], i
        for j, hyp_ch in enumerate(hyp, start=1):
            above = row[j]
            row[j] = min(row[j - 1] + 1, above + 1, diag + (ref_ch != hyp_ch))
            diag = above
    return row[-1] / len(ref)


print("\n--- Transformers (bf16 MPS) ---")
hf1 = hf_transcribe(1)
hf2 = hf_transcribe(2)
print(f"S1 [hf]: {hf1}\n")
print(f"S2 [hf]: {hf2}\n")

print("--- whisper.cpp Q4_K ---")
gg1 = cli_transcribe(1)
gg2 = cli_transcribe(2)
print(f"S1 [gg]: {gg1}\n")
print(f"S2 [gg]: {gg2}\n")

print("--- Parity (CER of Q4_K vs transformers) ---")
print(f"  speaker1  CER = {cer(hf1, gg1)*100:.2f}%   (len hf={len(normalize(hf1))}, gg={len(normalize(gg1))})")
print(f"  speaker2  CER = {cer(hf2, gg2)*100:.2f}%   (len hf={len(normalize(hf2))}, gg={len(normalize(gg2))})")