File size: 4,662 Bytes
dd8edb5
 
cb451b4
f26b2f5
dd8edb5
 
 
c173e30
f26b2f5
dd8edb5
f26b2f5
c173e30
f26b2f5
c173e30
fa83439
b088af5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f26b2f5
 
 
2222b3b
f26b2f5
5dbb1d4
 
cb451b4
f26b2f5
5dbb1d4
 
 
1acef58
f26b2f5
c173e30
1acef58
dd8edb5
f26b2f5
cb451b4
 
 
f26b2f5
b088af5
cb451b4
f26b2f5
 
 
fa83439
f26b2f5
cb451b4
f26b2f5
 
 
 
 
 
b088af5
cb451b4
f26b2f5
b088af5
f26b2f5
 
 
cb451b4
f26b2f5
 
b088af5
f26b2f5
883f9e7
f26b2f5
 
 
 
 
 
 
 
 
 
 
 
b088af5
 
f26b2f5
883f9e7
fa83439
d949c72
f26b2f5
daa79d8
883f9e7
f26b2f5
 
 
 
 
 
 
 
63e7642
daa79d8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import gradio as gr
import torch
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import epitran
import re
import editdistance
import orjson
from jiwer import wer

# --- Device ---
# Run on GPU when one is visible to torch; everything else falls back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using:", device)

# --- WordMap ---
# Letter -> {example word, hand-curated IPA pronunciation}.
# Serves two roles: the closed vocabulary the ASR transcription is matched
# against in analyze(), and the seed data for the IPA transliteration cache.
WORD_MAP = {
    'A': {'word': 'Apple', 'phonetic': 'ˈæpəl'},
    'B': {'word': 'Ball', 'phonetic': 'bɔːl'},
    'C': {'word': 'Cat', 'phonetic': 'kæt'},
    'D': {'word': 'Dog', 'phonetic': 'dɒɡ'},
    'E': {'word': 'Egg', 'phonetic': 'ɛɡ'},
    'F': {'word': 'Fish', 'phonetic': 'fɪʃ'},
    'G': {'word': 'Goat', 'phonetic': 'ɡoʊt'},
    'H': {'word': 'Hat', 'phonetic': 'hæt'},
    'I': {'word': 'Ice', 'phonetic': 'aɪs'},
    'J': {'word': 'Jar', 'phonetic': 'dʒɑːr'},
    'K': {'word': 'Kite', 'phonetic': 'kaɪt'},
    'L': {'word': 'Lion', 'phonetic': 'ˈlaɪən'},
    'M': {'word': 'Moon', 'phonetic': 'muːn'},
    'N': {'word': 'Nest', 'phonetic': 'nɛst'},
    'O': {'word': 'Orange', 'phonetic': 'ˈɔːrɪndʒ'},
    'P': {'word': 'Pen', 'phonetic': 'pɛn'},
    'Q': {'word': 'Queen', 'phonetic': 'kwiːn'},
    'R': {'word': 'Rabbit', 'phonetic': 'ˈræbɪt'},
    'S': {'word': 'Sun', 'phonetic': 'sʌn'},
    'T': {'word': 'Tree', 'phonetic': 'triː'},
    'U': {'word': 'Umbrella', 'phonetic': 'ʌmˈbrɛlə'},
    'V': {'word': 'Van', 'phonetic': 'væn'},
    'W': {'word': 'Watch', 'phonetic': 'wɒtʃ'},
    'X': {'word': 'Xylophone', 'phonetic': 'ˈzaɪləfoʊn'},
    'Y': {'word': 'Yarn', 'phonetic': 'jɑːrn'},
    'Z': {'word': 'Zebra', 'phonetic': 'ˈziːbrə'}
}

# --- Load wav2vec2 (smaller + faster than Whisper) ---
# Processor handles audio feature extraction and CTC token decoding;
# model produces per-frame logits. Both are downloaded from the HF hub on
# first run, moved to `device`, and frozen in eval mode.
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h").to(device).eval()

# Grapheme-to-IPA transliterator for English in Latin script.
epi = epitran.Epitran("eng-Latn")
# Cache the curated pronunciations so transliterate() can skip Epitran for
# known words. NOTE(review): \w matches Unicode modifier letters (ˈ, ː),
# so this regex removes little from these IPA strings — confirm intent.
IPA_CACHE = {v['word'].lower(): re.sub(r'[^\w\s]', '', v['phonetic']) for v in WORD_MAP.values()}

# --- Helpers ---
def transliterate(word):
    """Return an IPA transliteration of *word*.

    Known WORD_MAP words are served from the precomputed cache; anything
    else is sent through Epitran. Returns "" if transliteration fails.
    """
    key = word.lower()
    cached = IPA_CACHE.get(key)
    if cached is not None:
        return cached
    try:
        return epi.transliterate(key)
    except Exception:
        # Best-effort: an unusable word simply yields no phonemes.
        return ""

def transcribe(audio_path):
    """Transcribe an audio file to lowercase text with wav2vec2.

    Parameters
    ----------
    audio_path : str
        Path to an audio file readable by torchaudio.

    Returns
    -------
    str
        Greedy (argmax) CTC decoding of the model output, lowercased.
    """
    waveform, sr = torchaudio.load(audio_path)
    # Fix: mix multi-channel recordings down to mono. The original squeeze()
    # left a (channels, time) tensor for stereo input, which the processor
    # then treated as a batch of independent clips.
    if waveform.dim() > 1 and waveform.size(0) > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    if sr != 16000:
        # wav2vec2-base-960h expects 16 kHz input.
        waveform = torchaudio.functional.resample(waveform, sr, 16000)
    inputs = processor(
        waveform.squeeze(), sampling_rate=16000, return_tensors="pt", padding=True
    ).to(device)

    with torch.no_grad():
        logits = model(**inputs).logits
    pred_ids = torch.argmax(logits, dim=-1)
    return processor.decode(pred_ids[0]).lower()

def analyze(language, reference_text, audio_input, detailed=True):
    """Transcribe *audio_input* and score it against the reference.

    Parameters
    ----------
    language : str
        Display language (echoed back; only English is offered in the UI).
    reference_text : str
        A single letter (WORD_MAP key, e.g. "A") or a literal reference word.
    audio_input : str
        Filepath of the recorded audio (Gradio ``type="filepath"``).
    detailed : bool
        When False, skip phoneme alignment and return only the match.

    Returns
    -------
    dict
        Result payload for the JSON panel, or ``{"error": ...}`` on failure.
    """
    try:
        transcription = transcribe(audio_input)

        # Fix: the UI default reference is a *letter* ("A") while scoring is
        # done against whole words — resolve it to its WORD_MAP word so the
        # phoneme alignment and WER compare "apple" vs "apple", not "a" vs
        # "apple" (which pinned WER at 100%).
        key = reference_text.strip().upper()
        ref_word = WORD_MAP[key]['word'] if key in WORD_MAP else reference_text

        # Snap the raw transcription to the closest vocabulary word.
        distances = {
            entry['word'].lower(): editdistance.eval(transcription, entry['word'].lower())
            for entry in WORD_MAP.values()
        }
        closest_word = min(distances, key=distances.get)
        # max(1, ...) guards division by zero; clamp not needed since round() keeps sign.
        similarity = round((1 - distances[closest_word] / max(1, len(closest_word))) * 100, 2)

        if not detailed:
            return {"language": language, "reference": reference_text, "transcription": closest_word}

        # Phoneme-level alignment between reference and matched word.
        ref_ph = list(transliterate(ref_word))
        obs_ph = list(transliterate(closest_word))

        edits = editdistance.eval(ref_ph, obs_ph)
        phon_acc = round((1 - edits / max(1, len(ref_ph))) * 100, 2)

        return {
            "language": language,
            "reference": reference_text,
            "transcription": closest_word,
            "metrics": {
                "similarity": similarity,
                "phoneme_accuracy": phon_acc,
                "asr_word_error_rate": round(wer(ref_word.lower(), closest_word) * 100, 2)
            },
            "alignment": {
                "reference_phonemes": "".join(ref_ph),
                "observed_phonemes": "".join(obs_ph),
                "edit_distance": edits
            }
        }
    except Exception as e:
        # Surface failures in the Gradio JSON panel instead of crashing the app.
        return {"error": str(e)}

# --- Gradio UI ---
# Component creation order defines the on-page layout, so it is preserved.
with gr.Blocks() as demo:
    gr.Markdown("## Fast wav2vec2-based Phoneme Checker")

    with gr.Row():
        lang = gr.Dropdown(["English"], value="English", label="Language")
        ref = gr.Textbox(value="A", label="Reference Word")

    audio = gr.Audio(label="Record Audio", type="filepath")
    detailed = gr.Checkbox(value=True, label="Detailed Mode")
    out = gr.JSON(label="Results")

    # Wire the button to the scoring pipeline; results render as JSON.
    run_btn = gr.Button("Analyze")
    run_btn.click(analyze, inputs=[lang, ref, audio, detailed], outputs=out)

demo.launch()