# Phoneme Recognition Benchmark — Hugging Face Space (Gradio app)
| import os | |
| import time | |
| import torch | |
| import torchaudio | |
| import gradio as gr | |
| from transformers import ( | |
| Wav2Vec2Processor, HubertForCTC, | |
| WhisperProcessor, WhisperForConditionalGeneration | |
| ) | |
| from phonemizer import phonemize | |
| import difflib | |
# === Setup: load all three phoneme-recognition models at import time ===
# Each from_pretrained call below may download weights from the Hugging Face
# Hub on first run; .eval() disables dropout/batch-norm updates for inference.

# 1. Baseline: HuBERT-Large fine-tuned with a CTC head (LibriSpeech 960h).
base_proc = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")
base_model = HubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft").eval()

# 2. Whisper-base: produces text, which is phonemized separately in run_whisper.
whisper_proc = WhisperProcessor.from_pretrained("openai/whisper-base")
whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base").eval()

# 3. The fine-tuned HuBERT being benchmarked.
# NOTE(review): token is read from the HF_TOKEN env var and may be None if
# unset — presumably the repo is gated/private; confirm access requirements.
token = os.environ.get("HF_TOKEN")
your_proc = Wav2Vec2Processor.from_pretrained("tecasoftai/hubert-finetune", token=token)
your_model = HubertForCTC.from_pretrained("tecasoftai/hubert-finetune", token=token).eval()
| # === Helper === | |
def load_audio(filepath, target_sr=16000):
    """Load an audio file and return a 1-D mono waveform at *target_sr* Hz.

    Args:
        filepath: Path to an audio file readable by torchaudio.
        target_sr: Desired sampling rate in Hz; defaults to the 16 kHz that
            all three models expect.

    Returns:
        A 1-D torch.Tensor of audio samples.
    """
    waveform, sr = torchaudio.load(filepath)
    # Downmix multi-channel audio to mono. The original bare .squeeze() left
    # a stereo file as a 2-D (2, T) tensor, which the feature extractors
    # cannot consume as a single waveform.
    if waveform.dim() > 1 and waveform.size(0) > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    if sr != target_sr:
        waveform = torchaudio.functional.resample(waveform, sr, target_sr)
    return waveform.squeeze()
def calc_per(pred, ref):
    """Compute the Phoneme Error Rate (PER) between two phoneme strings.

    Args:
        pred: Space-separated predicted phonemes.
        ref: Space-separated ground-truth phonemes.

    Returns:
        PER as a percentage rounded to 2 decimals; 0.0 when the reference
        is empty.

    Bug fix: the original summed ``tr[-1]`` over non-equal opcodes, but
    ``get_opcodes()`` yields ``(tag, i1, i2, j1, j2)`` — so it was summing
    *end indices into the prediction*, not edit counts. The edit cost of a
    non-equal opcode is ``max(i2 - i1, j2 - j1)`` (covers replace, delete,
    and insert).
    """
    pred_list = pred.strip().split()
    ref_list = ref.strip().split()
    if not ref_list:
        # Avoid division by zero; an empty reference yields 0% by convention.
        return 0.0
    sm = difflib.SequenceMatcher(None, ref_list, pred_list)
    dist = 0
    for tag, i1, i2, j1, j2 in sm.get_opcodes():
        if tag != 'equal':
            dist += max(i2 - i1, j2 - j1)
    return round(100 * dist / len(ref_list), 2)
| # === Inference functions === | |
def run_hubert_base(wav):
    """Run the pretrained HuBERT CTC baseline on a waveform.

    Args:
        wav: 1-D waveform tensor at 16 kHz.

    Returns:
        Tuple of (decoded phoneme string, elapsed seconds).
    """
    t0 = time.time()
    features = base_proc(wav, sampling_rate=16000, return_tensors="pt")
    with torch.no_grad():
        output = base_model(**features)
    # Greedy CTC decoding: pick the argmax token at each frame.
    pred_ids = output.logits.argmax(dim=-1)
    decoded = base_proc.batch_decode(pred_ids)[0]
    return decoded, time.time() - t0
def run_whisper(wav):
    """Transcribe a waveform with Whisper, then phonemize the transcript.

    Args:
        wav: 1-D waveform tensor at 16 kHz.

    Returns:
        Tuple of (espeak phoneme string, elapsed seconds — includes the
        phonemization step).
    """
    t0 = time.time()
    features = whisper_proc(wav, sampling_rate=16000, return_tensors="pt")
    with torch.no_grad():
        pred_ids = whisper_model.generate(features["input_features"])
    transcript = whisper_proc.batch_decode(pred_ids, skip_special_tokens=True)[0]
    # Whisper emits orthographic text; convert it to phonemes via espeak.
    phoneme_str = phonemize(transcript, language='en-us', backend='espeak')
    return phoneme_str, time.time() - t0
def run_your_model(wav):
    """Run the fine-tuned HuBERT model on a waveform.

    Args:
        wav: 1-D waveform tensor at 16 kHz.

    Returns:
        Tuple of (decoded phoneme string, elapsed seconds).
    """
    t0 = time.time()
    features = your_proc(wav, sampling_rate=16000, return_tensors="pt")
    with torch.no_grad():
        output = your_model(**features)
    # Greedy CTC decoding, mirroring the baseline pipeline.
    pred_ids = output.logits.argmax(dim=-1)
    decoded = your_proc.batch_decode(pred_ids)[0]
    return decoded, time.time() - t0
| # === Main Gradio function === | |
def benchmark_all(audio_path, reference_phoneme):
    """Run all three models on one audio file and tabulate the results.

    Args:
        audio_path: Filesystem path to the uploaded audio file.
        reference_phoneme: Space-separated ground-truth phoneme string.

    Returns:
        A list of rows [model name, phoneme output, inference time, PER],
        one per model, in the order shown in the UI.
    """
    wav = load_audio(audio_path)
    # Data-driven loop replaces three copy-pasted stanzas; display names are
    # kept byte-identical to the original output.
    runners = [
        ("HuBERT-Base", run_hubert_base),
        ("Whisper + Phonemizer", run_whisper),
        ("Your HuBERT (fine-tuned)", run_your_model),
    ]
    results = []
    for name, runner in runners:
        phonemes, dur = runner(wav)
        per = calc_per(phonemes, reference_phoneme)
        results.append([name, phonemes, f"{dur:.2f}s", f"{per}%"])
    return results
# === UI ===
# Gradio interface: the user uploads an audio clip and types the reference
# phoneme sequence; benchmark_all returns one table row per model with its
# phoneme output, inference latency, and PER.
demo = gr.Interface(
    fn=benchmark_all,
    inputs=[
        gr.Audio(type="filepath", label="Upload Audio"),
        gr.Textbox(label="Ground-truth Phonemes (space-separated)", placeholder="f ə n ə m aɪ z")
    ],
    outputs=gr.Dataframe(headers=["Model", "Phoneme Output", "Inference Time", "PER (%)"]),
    title="Phoneme Recognition Benchmark",
    description="Compare HuBERT-Base, Whisper, and your fine-tuned model on phoneme recognition."
)

if __name__ == "__main__":
    demo.launch()