import re

import editdistance
import epitran
import gradio as gr
import torch
import torchaudio
from jiwer import wer
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

# --- Device ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using:", device)

# --- Word map: one reference word per letter, with its IPA form ---
WORD_MAP = {
    'A': {'word': 'Apple', 'phonetic': 'ˈæpəl'},
    'B': {'word': 'Ball', 'phonetic': 'bɔːl'},
    'C': {'word': 'Cat', 'phonetic': 'kæt'},
    'D': {'word': 'Dog', 'phonetic': 'dɒɡ'},
    'E': {'word': 'Egg', 'phonetic': 'ɛɡ'},
    'F': {'word': 'Fish', 'phonetic': 'fɪʃ'},
    'G': {'word': 'Goat', 'phonetic': 'ɡoʊt'},
    'H': {'word': 'Hat', 'phonetic': 'hæt'},
    'I': {'word': 'Ice', 'phonetic': 'aɪs'},
    'J': {'word': 'Jar', 'phonetic': 'dʒɑːr'},
    'K': {'word': 'Kite', 'phonetic': 'kaɪt'},
    'L': {'word': 'Lion', 'phonetic': 'ˈlaɪən'},
    'M': {'word': 'Moon', 'phonetic': 'muːn'},
    'N': {'word': 'Nest', 'phonetic': 'nɛst'},
    'O': {'word': 'Orange', 'phonetic': 'ˈɔːrɪndʒ'},
    'P': {'word': 'Pen', 'phonetic': 'pɛn'},
    'Q': {'word': 'Queen', 'phonetic': 'kwiːn'},
    'R': {'word': 'Rabbit', 'phonetic': 'ˈræbɪt'},
    'S': {'word': 'Sun', 'phonetic': 'sʌn'},
    'T': {'word': 'Tree', 'phonetic': 'triː'},
    'U': {'word': 'Umbrella', 'phonetic': 'ʌmˈbrɛlə'},
    'V': {'word': 'Van', 'phonetic': 'væn'},
    'W': {'word': 'Watch', 'phonetic': 'wɒtʃ'},
    'X': {'word': 'Xylophone', 'phonetic': 'ˈzaɪləfoʊn'},
    'Y': {'word': 'Yarn', 'phonetic': 'jɑːrn'},
    'Z': {'word': 'Zebra', 'phonetic': 'ˈziːbrə'},
}

# --- Load wav2vec2 (smaller and faster than Whisper) ---
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h").to(device).eval()

epi = epitran.Epitran("eng-Latn")

# Pre-computed IPA for the word map, with punctuation stripped
IPA_CACHE = {
    v['word'].lower(): re.sub(r'[^\w\s]', '', v['phonetic'])
    for v in WORD_MAP.values()
}


# --- Helpers ---
def transliterate(word):
    """Return an IPA transcription, preferring the cached word map."""
    word_lower = word.lower()
    if word_lower in IPA_CACHE:
        return IPA_CACHE[word_lower]
    try:
        return epi.transliterate(word_lower)
    except Exception:
        return ""


def transcribe(audio_path):
    """Run wav2vec2 CTC decoding on an audio file; return lowercase text."""
    waveform, sr = torchaudio.load(audio_path)
    if waveform.size(0) > 1:  # mix multi-channel audio down to mono
        waveform = waveform.mean(dim=0, keepdim=True)
    if sr != 16000:
        waveform = torchaudio.functional.resample(waveform, sr, 16000)
    inputs = processor(
        waveform.squeeze(0), sampling_rate=16000, return_tensors="pt", padding=True
    ).to(device)
    with torch.no_grad():
        logits = model(**inputs).logits
    pred_ids = torch.argmax(logits, dim=-1)
    return processor.decode(pred_ids[0]).lower()


def analyze(language, reference_text, audio_input, detailed=True):
    try:
        # A single letter refers to its WORD_MAP entry (e.g. "A" -> "Apple")
        if reference_text.upper() in WORD_MAP:
            reference_text = WORD_MAP[reference_text.upper()]['word']

        transcription = transcribe(audio_input)

        # Match the transcription to the closest word in WORD_MAP
        distances = {
            entry['word'].lower(): editdistance.eval(transcription, entry['word'].lower())
            for entry in WORD_MAP.values()
        }
        closest_word = min(distances, key=distances.get)
        similarity = round(
            (1 - distances[closest_word] / max(1, len(closest_word))) * 100, 2
        )

        if not detailed:
            return {
                "language": language,
                "reference": reference_text,
                "transcription": closest_word,
            }

        # Phoneme-level alignment
        ref_ph = list(transliterate(reference_text))
        obs_ph = list(transliterate(closest_word))
        edits = editdistance.eval(ref_ph, obs_ph)
        phon_acc = round((1 - edits / max(1, len(ref_ph))) * 100, 2)

        return {
            "language": language,
            "reference": reference_text,
            "transcription": closest_word,
            "metrics": {
                "similarity": similarity,
                "phoneme_accuracy": phon_acc,
                # lowercase the reference so WER is not inflated by case
                "asr_word_error_rate": round(
                    wer(reference_text.lower(), closest_word) * 100, 2
                ),
            },
            "alignment": {
                "reference_phonemes": "".join(ref_ph),
                "observed_phonemes": "".join(obs_ph),
                "edit_distance": edits,
            },
        }
    except Exception as e:
        return {"error": str(e)}


# --- Gradio UI ---
with gr.Blocks() as demo:
    gr.Markdown("## Fast wav2vec2-based Phoneme Checker")
    with gr.Row():
        lang = gr.Dropdown(["English"], value="English", label="Language")
        ref = gr.Textbox(value="A", label="Reference Word")
    audio = gr.Audio(label="Record Audio", type="filepath")
    detailed = gr.Checkbox(value=True, label="Detailed Mode")
    out = gr.JSON(label="Results")
    demo_btn = gr.Button("Analyze")
    demo_btn.click(analyze, inputs=[lang, ref, audio, detailed], outputs=out)

demo.launch()