# HuggingFace Spaces page header (scraped artifact): "Spaces: Sleeping"
| import gradio as gr | |
| import torch | |
| import torchaudio | |
| from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor | |
| import epitran | |
| import re | |
| import editdistance | |
| import orjson | |
| from jiwer import wer | |
# --- Device ---
# Run on the GPU when torch can see one, otherwise fall back to the CPU;
# the model weights and input tensors are both moved to this device.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print("Using:", device)
# --- WordMap ---
# One exemplar word per letter of the alphabet, with a hand-written IPA form.
_ALPHABET_WORDS = [
    ('A', 'Apple', 'ˈæpəl'),
    ('B', 'Ball', 'bɔːl'),
    ('C', 'Cat', 'kæt'),
    ('D', 'Dog', 'dɒɡ'),
    ('E', 'Egg', 'ɛɡ'),
    ('F', 'Fish', 'fɪʃ'),
    ('G', 'Goat', 'ɡoʊt'),
    ('H', 'Hat', 'hæt'),
    ('I', 'Ice', 'aɪs'),
    ('J', 'Jar', 'dʒɑːr'),
    ('K', 'Kite', 'kaɪt'),
    ('L', 'Lion', 'ˈlaɪən'),
    ('M', 'Moon', 'muːn'),
    ('N', 'Nest', 'nɛst'),
    ('O', 'Orange', 'ˈɔːrɪndʒ'),
    ('P', 'Pen', 'pɛn'),
    ('Q', 'Queen', 'kwiːn'),
    ('R', 'Rabbit', 'ˈræbɪt'),
    ('S', 'Sun', 'sʌn'),
    ('T', 'Tree', 'triː'),
    ('U', 'Umbrella', 'ʌmˈbrɛlə'),
    ('V', 'Van', 'væn'),
    ('W', 'Watch', 'wɒtʃ'),
    ('X', 'Xylophone', 'ˈzaɪləfoʊn'),
    ('Y', 'Yarn', 'jɑːrn'),
    ('Z', 'Zebra', 'ˈziːbrə'),
]
WORD_MAP = {
    letter: {'word': word, 'phonetic': ipa}
    for letter, word, ipa in _ALPHABET_WORDS
}
# --- Load wav2vec2 (smaller + faster than Whisper) ---
# NOTE: from_pretrained downloads the checkpoint from the HuggingFace hub on
# first run; eval() disables dropout since the model is inference-only here.
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h").to(device).eval()
# Grapheme-to-IPA transliterator for English (Latin script).
epi = epitran.Epitran("eng-Latn")
# Pre-cleaned IPA for every WORD_MAP word, keyed by the lowercase word, used
# as a fast path by transliterate().  NOTE(review): with Unicode \w, IPA
# modifier letters such as ˈ and ː are kept by this regex — it mainly strips
# ASCII punctuation; confirm that is the intended normalization.
IPA_CACHE = {v['word'].lower(): re.sub(r'[^\w\s]', '', v['phonetic']) for v in WORD_MAP.values()}
| # --- Helpers --- | |
def transliterate(word):
    """Return the IPA string for *word*, preferring the precomputed cache.

    Falls back to epitran for out-of-vocabulary words and returns "" if
    transliteration fails for any reason.
    """
    key = word.lower()
    cached = IPA_CACHE.get(key)
    if cached is not None:
        return cached
    try:
        return epi.transliterate(key)
    except Exception:
        # Best-effort: epitran can choke on unexpected tokens.
        return ""
def transcribe(audio_path):
    """Transcribe an audio file to lowercase text with wav2vec2 CTC.

    Parameters
    ----------
    audio_path : str
        Path to an audio file readable by torchaudio.

    Returns
    -------
    str
        Greedy (argmax) CTC decoding of the audio, lowercased.
    """
    waveform, sr = torchaudio.load(audio_path)
    # BUG FIX: .squeeze() only collapses mono (1, time) input.  A stereo
    # recording stayed (channels, time) and was fed to the model as a batch
    # of channels, with only channel 0 decoded.  Mix down to mono instead;
    # for mono input mean(dim=0) yields the same 1-D tensor as squeeze().
    if waveform.dim() > 1:
        waveform = waveform.mean(dim=0)
    # wav2vec2-base-960h was trained on 16 kHz audio.
    if sr != 16000:
        waveform = torchaudio.functional.resample(waveform, sr, 16000)
    inputs = processor(waveform, sampling_rate=16000, return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        logits = model(**inputs).logits
    # Greedy decoding: most likely token per frame, then CTC collapse.
    pred_ids = torch.argmax(logits, dim=-1)
    return processor.decode(pred_ids[0]).lower()
def analyze(language, reference_text, audio_input, detailed=True):
    """Transcribe audio, snap it to the closest WORD_MAP word, and score it.

    Returns a JSON-serializable dict; any failure is reported as
    ``{"error": <message>}`` rather than raised.
    """
    try:
        heard = transcribe(audio_input)
        # Snap the raw ASR output onto the known vocabulary by edit distance.
        # min() keeps the first minimum, matching WORD_MAP insertion order.
        vocabulary = [entry['word'].lower() for entry in WORD_MAP.values()]
        closest_word = min(vocabulary, key=lambda w: editdistance.eval(heard, w))
        best_dist = editdistance.eval(heard, closest_word)
        similarity = round((1 - best_dist / max(1, len(closest_word))) * 100, 2)
        if not detailed:
            return {"language": language, "reference": reference_text, "transcription": closest_word}
        # Phoneme-level alignment between the reference and the matched word.
        ref_ph = list(transliterate(reference_text))
        obs_ph = list(transliterate(closest_word))
        edits = editdistance.eval(ref_ph, obs_ph)
        phon_acc = round((1 - edits / max(1, len(ref_ph))) * 100, 2)
        return {
            "language": language,
            "reference": reference_text,
            "transcription": closest_word,
            "metrics": {
                "similarity": similarity,
                "phoneme_accuracy": phon_acc,
                "asr_word_error_rate": round(wer(reference_text, closest_word) * 100, 2),
            },
            "alignment": {
                "reference_phonemes": "".join(ref_ph),
                "observed_phonemes": "".join(obs_ph),
                "edit_distance": edits,
            },
        }
    except Exception as e:
        return {"error": str(e)}
# --- Gradio UI ---
# Minimal front-end: record a word, compare it against a reference letter's
# exemplar word, and display the scoring dict returned by analyze().
with gr.Blocks() as demo:
    gr.Markdown("## Fast wav2vec2-based Phoneme Checker")
    with gr.Row():
        # Only English is supported (epitran is initialized for eng-Latn).
        lang = gr.Dropdown(["English"], value="English", label="Language")
        ref = gr.Textbox(value="A", label="Reference Word")
    # type="filepath" hands analyze() a path on disk, as transcribe() expects.
    audio = gr.Audio(label="Record Audio", type="filepath")
    detailed = gr.Checkbox(value=True, label="Detailed Mode")
    out = gr.JSON(label="Results")
    demo_btn = gr.Button("Analyze")
    demo_btn.click(analyze, inputs=[lang, ref, audio, detailed], outputs=out)
demo.launch()