Spaces:
Sleeping
Sleeping
File size: 4,662 Bytes
dd8edb5 cb451b4 f26b2f5 dd8edb5 c173e30 f26b2f5 dd8edb5 f26b2f5 c173e30 f26b2f5 c173e30 fa83439 b088af5 f26b2f5 2222b3b f26b2f5 5dbb1d4 cb451b4 f26b2f5 5dbb1d4 1acef58 f26b2f5 c173e30 1acef58 dd8edb5 f26b2f5 cb451b4 f26b2f5 b088af5 cb451b4 f26b2f5 fa83439 f26b2f5 cb451b4 f26b2f5 b088af5 cb451b4 f26b2f5 b088af5 f26b2f5 cb451b4 f26b2f5 b088af5 f26b2f5 883f9e7 f26b2f5 b088af5 f26b2f5 883f9e7 fa83439 d949c72 f26b2f5 daa79d8 883f9e7 f26b2f5 63e7642 daa79d8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 |
import gradio as gr
import torch
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import epitran
import re
import editdistance
import orjson
from jiwer import wer
# --- Device ---
# Pick the compute device once at import time; the model and all inputs
# are moved here.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using:", device)
# --- WordMap ---
# Closed vocabulary for the demo: one word per letter of the alphabet.
# 'phonetic' is the word's IPA transcription (stress and length marks
# included); IPA_CACHE below derives its cleaned lookup values from it.
WORD_MAP = {
    'A': {'word': 'Apple', 'phonetic': 'ˈæpəl'},
    'B': {'word': 'Ball', 'phonetic': 'bɔːl'},
    'C': {'word': 'Cat', 'phonetic': 'kæt'},
    'D': {'word': 'Dog', 'phonetic': 'dɒɡ'},
    'E': {'word': 'Egg', 'phonetic': 'ɛɡ'},
    'F': {'word': 'Fish', 'phonetic': 'fɪʃ'},
    'G': {'word': 'Goat', 'phonetic': 'ɡoʊt'},
    'H': {'word': 'Hat', 'phonetic': 'hæt'},
    'I': {'word': 'Ice', 'phonetic': 'aɪs'},
    'J': {'word': 'Jar', 'phonetic': 'dʒɑːr'},
    'K': {'word': 'Kite', 'phonetic': 'kaɪt'},
    'L': {'word': 'Lion', 'phonetic': 'ˈlaɪən'},
    'M': {'word': 'Moon', 'phonetic': 'muːn'},
    'N': {'word': 'Nest', 'phonetic': 'nɛst'},
    'O': {'word': 'Orange', 'phonetic': 'ˈɔːrɪndʒ'},
    'P': {'word': 'Pen', 'phonetic': 'pɛn'},
    'Q': {'word': 'Queen', 'phonetic': 'kwiːn'},
    'R': {'word': 'Rabbit', 'phonetic': 'ˈræbɪt'},
    'S': {'word': 'Sun', 'phonetic': 'sʌn'},
    'T': {'word': 'Tree', 'phonetic': 'triː'},
    'U': {'word': 'Umbrella', 'phonetic': 'ʌmˈbrɛlə'},
    'V': {'word': 'Van', 'phonetic': 'væn'},
    'W': {'word': 'Watch', 'phonetic': 'wɒtʃ'},
    'X': {'word': 'Xylophone', 'phonetic': 'ˈzaɪləfoʊn'},
    'Y': {'word': 'Yarn', 'phonetic': 'jɑːrn'},
    'Z': {'word': 'Zebra', 'phonetic': 'ˈziːbrə'}
}
# --- Load wav2vec2 (smaller + faster than Whisper) ---
# Loaded once at import time; from_pretrained downloads weights on first
# run (network side effect). eval() disables dropout for inference.
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h").to(device).eval()
# Grapheme-to-IPA transliterator for English in Latin script.
epi = epitran.Epitran("eng-Latn")
# Precomputed cleaned IPA per vocabulary word, keyed by lowercase word.
# NOTE(review): the regex strips non-word, non-space chars; under Unicode,
# \w appears to keep the ˈ/ː modifier letters — confirm that is intended.
IPA_CACHE = {v['word'].lower(): re.sub(r'[^\w\s]', '', v['phonetic']) for v in WORD_MAP.values()}
# --- Helpers ---
def transliterate(word):
    """Return an IPA string for *word*.

    Cached vocabulary words are served from IPA_CACHE; anything else is
    transliterated with Epitran, falling back to "" on failure.
    """
    key = word.lower()
    cached = IPA_CACHE.get(key)
    if cached is not None:
        return cached
    try:
        return epi.transliterate(key)
    except Exception:
        return ""
def transcribe(audio_path):
    """Transcribe an audio file at *audio_path* to lowercase text with wav2vec2.

    The model expects 16 kHz mono input, so multi-channel recordings are
    downmixed and other sample rates are resampled before inference.
    """
    waveform, sr = torchaudio.load(audio_path)
    # Downmix stereo/multi-channel to mono: squeeze() alone would leave a
    # (channels, samples) tensor, which the processor then treats as a
    # batch of independent clips instead of one recording.
    if waveform.dim() > 1 and waveform.size(0) > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    if sr != 16000:
        waveform = torchaudio.functional.resample(waveform, sr, 16000)
    inputs = processor(waveform.squeeze(), sampling_rate=16000, return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        logits = model(**inputs).logits
    # Greedy CTC decoding: most likely token per frame, then collapse.
    pred_ids = torch.argmax(logits, dim=-1)
    return processor.decode(pred_ids[0]).lower()
def analyze(language, reference_text, audio_input, detailed=True):
    """Transcribe *audio_input*, snap it to the closest WORD_MAP word, and
    score it against *reference_text*.

    Args:
        language: display-only language label echoed back in the result.
        reference_text: the word the speaker was asked to pronounce.
        audio_input: filepath to the recorded audio (None if nothing was
            recorded in the UI).
        detailed: when False, return only the matched transcription;
            otherwise include similarity/phoneme metrics and alignment.

    Returns:
        A JSON-serializable dict; on failure, {"error": <message>}.
    """
    # Gradio passes None when the user never recorded anything — fail with
    # a clear message instead of crashing inside torchaudio.load().
    if audio_input is None:
        return {"error": "No audio provided"}
    try:
        transcription = transcribe(audio_input)
        # Snap the raw transcription to the closest vocabulary word by
        # character-level edit distance.
        distances = {
            entry['word'].lower(): editdistance.eval(transcription, entry['word'].lower())
            for entry in WORD_MAP.values()
        }
        closest_word = min(distances, key=distances.get)
        # Clamp at 0: edit distance can exceed len(closest_word) when the
        # transcription is much longer than the matched word, which would
        # otherwise yield a nonsensical negative percentage.
        similarity = round(max(0.0, 1 - distances[closest_word] / max(1, len(closest_word))) * 100, 2)
        if not detailed:
            return {"language": language, "reference": reference_text, "transcription": closest_word}
        # Phoneme-level alignment between reference and matched word.
        ref_ph = list(transliterate(reference_text))
        obs_ph = list(transliterate(closest_word))
        edits = editdistance.eval(ref_ph, obs_ph)
        phon_acc = round((1 - edits / max(1, len(ref_ph))) * 100, 2)
        return {
            "language": language,
            "reference": reference_text,
            "transcription": closest_word,
            "metrics": {
                "similarity": similarity,
                "phoneme_accuracy": phon_acc,
                "asr_word_error_rate": round(wer(reference_text, closest_word) * 100, 2)
            },
            "alignment": {
                "reference_phonemes": "".join(ref_ph),
                "observed_phonemes": "".join(obs_ph),
                "edit_distance": edits
            }
        }
    except Exception as e:
        # Top-level boundary: surface the failure to the UI as JSON.
        return {"error": str(e)}
# --- Gradio UI ---
# Script entry point: build the interface and start the (blocking) server.
with gr.Blocks() as demo:
    gr.Markdown("## Fast wav2vec2-based Phoneme Checker")
    with gr.Row():
        lang = gr.Dropdown(["English"], value="English", label="Language")
        ref = gr.Textbox(value="A", label="Reference Word")
    # type="filepath" makes Gradio hand analyze() a path on disk (or None
    # if nothing was recorded), matching transcribe()'s expectation.
    audio = gr.Audio(label="Record Audio", type="filepath")
    detailed = gr.Checkbox(value=True, label="Detailed Mode")
    out = gr.JSON(label="Results")
    demo_btn = gr.Button("Analyze")
    # Component values are passed positionally, matching analyze()'s
    # (language, reference_text, audio_input, detailed) signature.
    demo_btn.click(analyze, inputs=[lang, ref, audio, detailed], outputs=out)
demo.launch()
|