# wav2vec2-api / app.py
import gradio as gr
import torch
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import epitran
import re
import editdistance
from jiwer import wer
# --- Device ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using:", device)
# --- WordMap ---
WORD_MAP = {
    'A': {'word': 'Apple', 'phonetic': 'ˈæpəl'},
    'B': {'word': 'Ball', 'phonetic': 'bɔːl'},
    'C': {'word': 'Cat', 'phonetic': 'kæt'},
    'D': {'word': 'Dog', 'phonetic': 'dɒɡ'},
    'E': {'word': 'Egg', 'phonetic': 'ɛɡ'},
    'F': {'word': 'Fish', 'phonetic': 'fɪʃ'},
    'G': {'word': 'Goat', 'phonetic': 'ɡoʊt'},
    'H': {'word': 'Hat', 'phonetic': 'hæt'},
    'I': {'word': 'Ice', 'phonetic': 'aɪs'},
    'J': {'word': 'Jar', 'phonetic': 'dʒɑːr'},
    'K': {'word': 'Kite', 'phonetic': 'kaɪt'},
    'L': {'word': 'Lion', 'phonetic': 'ˈlaɪən'},
    'M': {'word': 'Moon', 'phonetic': 'muːn'},
    'N': {'word': 'Nest', 'phonetic': 'nɛst'},
    'O': {'word': 'Orange', 'phonetic': 'ˈɔːrɪndʒ'},
    'P': {'word': 'Pen', 'phonetic': 'pɛn'},
    'Q': {'word': 'Queen', 'phonetic': 'kwiːn'},
    'R': {'word': 'Rabbit', 'phonetic': 'ˈræbɪt'},
    'S': {'word': 'Sun', 'phonetic': 'sʌn'},
    'T': {'word': 'Tree', 'phonetic': 'triː'},
    'U': {'word': 'Umbrella', 'phonetic': 'ʌmˈbrɛlə'},
    'V': {'word': 'Van', 'phonetic': 'væn'},
    'W': {'word': 'Watch', 'phonetic': 'wɒtʃ'},
    'X': {'word': 'Xylophone', 'phonetic': 'ˈzaɪləfoʊn'},
    'Y': {'word': 'Yarn', 'phonetic': 'jɑːrn'},
    'Z': {'word': 'Zebra', 'phonetic': 'ˈziːbrə'},
}
# --- Load wav2vec2 (smaller + faster than Whisper) ---
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h").to(device).eval()
epi = epitran.Epitran("eng-Latn")
# Pre-compute IPA for the known words. Note: `\w` in Python's re matches IPA
# modifier letters (ˈ, ˌ, ː), so strip stress/length marks explicitly instead.
IPA_CACHE = {v['word'].lower(): re.sub(r'[ˈˌː]', '', v['phonetic']) for v in WORD_MAP.values()}
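# e.g. IPA_CACHE["ball"] == "bɔl" (length mark stripped from 'bɔːl').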
# --- Helpers ---
def transliterate(word):
    """Return an IPA string for `word`, preferring the precomputed cache."""
    word_lower = word.lower()
    if word_lower in IPA_CACHE:
        return IPA_CACHE[word_lower]
    try:
        return epi.transliterate(word_lower)
    except Exception:
        return ""
def transcribe(audio_path):
    """Transcribe an audio file to lowercase text with wav2vec2 CTC decoding."""
    waveform, sr = torchaudio.load(audio_path)
    # Downmix to mono and resample to the 16 kHz the model expects.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    if sr != 16000:
        waveform = torchaudio.functional.resample(waveform, sr, 16000)
    inputs = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        logits = model(**inputs).logits
    pred_ids = torch.argmax(logits, dim=-1)
    return processor.decode(pred_ids[0]).lower()
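# Minimal sketch of direct use ("sample.wav" is a hypothetical local recording):
#   text = transcribe("sample.wav")   # e.g. "apple" (output is illustrative)
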
def analyze(language, reference_text, audio_input, detailed=True):
    try:
        transcription = transcribe(audio_input)
        # A single-letter reference (e.g. "A") stands for its WORD_MAP word ("Apple").
        if reference_text.upper() in WORD_MAP:
            reference_text = WORD_MAP[reference_text.upper()]['word']
        # Snap the transcription to the closest WORD_MAP word by edit distance.
        distances = {entry['word'].lower(): editdistance.eval(transcription, entry['word'].lower())
                     for entry in WORD_MAP.values()}
        closest_word = min(distances, key=distances.get)
        similarity = round(max(0.0, 1 - distances[closest_word] / max(1, len(closest_word))) * 100, 2)
        if not detailed:
            return {"language": language, "reference": reference_text, "transcription": closest_word}
        # Phoneme-level alignment between the reference and the matched word.
        ref_ph = list(transliterate(reference_text))
        obs_ph = list(transliterate(closest_word))
        edits = editdistance.eval(ref_ph, obs_ph)
        phon_acc = round(max(0.0, 1 - edits / max(1, len(ref_ph))) * 100, 2)
        return {
            "language": language,
            "reference": reference_text,
            "transcription": closest_word,
            "metrics": {
                "similarity": similarity,
                "phoneme_accuracy": phon_acc,
                # Lowercase both sides so case alone never inflates WER.
                "asr_word_error_rate": round(wer(reference_text.lower(), closest_word) * 100, 2)
            },
            "alignment": {
                "reference_phonemes": "".join(ref_ph),
                "observed_phonemes": "".join(obs_ph),
                "edit_distance": edits
            }
        }
    except Exception as e:
        return {"error": str(e)}
# --- Gradio UI ---
with gr.Blocks() as demo:
    gr.Markdown("## Fast wav2vec2-based Phoneme Checker")
    with gr.Row():
        lang = gr.Dropdown(["English"], value="English", label="Language")
        ref = gr.Textbox(value="A", label="Reference Word (or letter A\u2013Z)")
    audio = gr.Audio(label="Record Audio", type="filepath")
    detailed = gr.Checkbox(value=True, label="Detailed Mode")
    out = gr.JSON(label="Results")
    demo_btn = gr.Button("Analyze")
    demo_btn.click(analyze, inputs=[lang, ref, audio, detailed], outputs=out)

demo.launch()