"""Gradio app benchmarking three phoneme-recognition pipelines.

Compares HuBERT-Base, Whisper+phonemizer, and a fine-tuned HuBERT model on
the same audio clip, reporting each model's phoneme output, inference time,
and phoneme error rate (PER) against a user-supplied reference.
"""

import difflib
import os
import time

import gradio as gr
import torch
import torchaudio
from phonemizer import phonemize
from transformers import (
    HubertForCTC,
    Wav2Vec2Processor,
    WhisperForConditionalGeneration,
    WhisperProcessor,
)

# === Setup: Load all 3 models ===
# 1. Base HuBERT (CTC head emits phoneme-like character tokens)
base_proc = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")
base_model = HubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft").eval()

# 2. Whisper (text transcription) + phonemizer (text -> phonemes)
whisper_proc = WhisperProcessor.from_pretrained("openai/whisper-base")
whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base").eval()

# 3. My Hubert Model (private repo; token read from the environment)
token = os.environ.get("HF_TOKEN")
your_proc = Wav2Vec2Processor.from_pretrained("tecasoftai/hubert-finetune", token=token)
your_model = HubertForCTC.from_pretrained("tecasoftai/hubert-finetune", token=token).eval()


# === Helper ===
def load_audio(filepath):
    """Load *filepath* as a mono, 16 kHz, 1-D float tensor.

    Resamples when the file's native rate differs from 16 kHz and downmixes
    multi-channel audio by averaging channels.
    """
    waveform, sr = torchaudio.load(filepath)
    if sr != 16000:
        waveform = torchaudio.functional.resample(waveform, sr, 16000)
    # BUG FIX: squeeze() alone leaves stereo as (2, N), which the feature
    # extractors would treat as a batch of two clips — downmix to mono first.
    if waveform.dim() > 1 and waveform.size(0) > 1:
        waveform = waveform.mean(dim=0)
    return waveform.squeeze()


def calc_per(pred, ref):
    """Return the phoneme error rate (%) of *pred* against *ref*.

    Both arguments are space-separated phoneme strings. The edit distance is
    approximated from difflib opcodes: every non-equal opcode contributes the
    larger of its two span lengths, so substitutions, insertions and
    deletions each count once.
    """
    pred_list = pred.strip().split()
    ref_list = ref.strip().split()
    if not ref_list:
        # Empty reference: define PER as 0 rather than dividing by zero.
        return 0.0
    sm = difflib.SequenceMatcher(None, ref_list, pred_list)
    # BUG FIX: the original summed op[-1] (j2 — an *index* into pred_list),
    # not an edit count, which wildly inflated the PER. Opcodes are
    # (tag, i1, i2, j1, j2); the edited span lengths are i2-i1 and j2-j1.
    dist = sum(
        max(i2 - i1, j2 - j1)
        for tag, i1, i2, j1, j2 in sm.get_opcodes()
        if tag != "equal"
    )
    return round(100 * dist / len(ref_list), 2)


# === Inference functions ===
def run_hubert_base(wav):
    """Greedy-decode *wav* with base HuBERT; return (phonemes, seconds)."""
    start = time.time()
    inputs = base_proc(wav, sampling_rate=16000, return_tensors="pt")
    with torch.no_grad():
        logits = base_model(**inputs).logits
    ids = torch.argmax(logits, dim=-1)
    phonemes = base_proc.batch_decode(ids)[0]
    return phonemes, time.time() - start


def run_whisper(wav):
    """Transcribe *wav* with Whisper, then phonemize the transcript.

    Returns (phonemes, seconds). Timing includes both the ASR pass and the
    espeak-backed phonemization so the comparison with end-to-end phoneme
    models is fair.
    """
    start = time.time()
    inputs = whisper_proc(wav, sampling_rate=16000, return_tensors="pt")
    with torch.no_grad():
        ids = whisper_model.generate(inputs["input_features"])
    text = whisper_proc.batch_decode(ids, skip_special_tokens=True)[0]
    phonemes = phonemize(text, language='en-us', backend='espeak')
    return phonemes, time.time() - start


def run_your_model(wav):
    """Greedy-decode *wav* with the fine-tuned HuBERT; return (phonemes, seconds)."""
    start = time.time()
    inputs = your_proc(wav, sampling_rate=16000, return_tensors="pt")
    with torch.no_grad():
        logits = your_model(**inputs).logits
    ids = torch.argmax(logits, dim=-1)
    phonemes = your_proc.batch_decode(ids)[0]
    return phonemes, time.time() - start


# === Main Gradio function ===
def benchmark_all(audio_path, reference_phoneme):
    """Run every model on *audio_path* and score against *reference_phoneme*.

    Returns a list of [model name, phoneme output, inference time, PER] rows
    for the results dataframe.
    """
    wav = load_audio(audio_path)
    results = []

    # 1. HuBERT Base
    phonemes, dur = run_hubert_base(wav)
    per = calc_per(phonemes, reference_phoneme)
    results.append(["HuBERT-Base", phonemes, f"{dur:.2f}s", f"{per}%"])

    # 2. Whisper
    phonemes, dur = run_whisper(wav)
    per = calc_per(phonemes, reference_phoneme)
    results.append(["Whisper + Phonemizer", phonemes, f"{dur:.2f}s", f"{per}%"])

    # 3. My Hubert model
    phonemes, dur = run_your_model(wav)
    per = calc_per(phonemes, reference_phoneme)
    results.append(["Your HuBERT (fine-tuned)", phonemes, f"{dur:.2f}s", f"{per}%"])

    return results


# === UI ===
demo = gr.Interface(
    fn=benchmark_all,
    inputs=[
        gr.Audio(type="filepath", label="Upload Audio"),
        gr.Textbox(label="Ground-truth Phonemes (space-separated)", placeholder="f ə n ə m aɪ z")
    ],
    outputs=gr.Dataframe(headers=["Model", "Phoneme Output", "Inference Time", "PER (%)"]),
    title="Phoneme Recognition Benchmark",
    description="Compare HuBERT-Base, Whisper, and your fine-tuned model on phoneme recognition."
)

if __name__ == "__main__":
    demo.launch()