Spaces:
Sleeping
Sleeping
Commit ·
9496a85
1
Parent(s): 2a517e7
Add inference model + Gradio app
Browse files- app.py +24 -78
- model_inference.py +100 -49
- requirements.txt +2 -6
app.py
CHANGED
|
@@ -1,95 +1,41 @@
|
|
| 1 |
# app.py
|
| 2 |
import gradio as gr
|
| 3 |
import os
|
| 4 |
-
import numpy as np
|
| 5 |
-
import soundfile as sf
|
| 6 |
-
import tempfile
|
| 7 |
from model_inference import NetTALKWrapper
|
| 8 |
|
| 9 |
-
#
|
| 10 |
-
|
| 11 |
|
| 12 |
-
#
|
| 13 |
-
|
| 14 |
-
model = NetTALKWrapper(
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
-
|
| 17 |
-
def synthesize_gtts(phoneme_text):
    """Synthesize speech for *phoneme_text* with gTTS and return an audio file path.

    gTTS only emits mp3, so we synthesize to mp3 first and convert to wav via
    pydub when available; otherwise the mp3 path is returned (Gradio accepts
    mp3 as audio). The caller owns (and should eventually delete) the file.
    """
    from gtts import gTTS
    tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    # Close the handle so pydub can write to the path (required on Windows,
    # where an open NamedTemporaryFile cannot be reopened by another writer).
    tmp.close()
    mp3_tmp = tmp.name + ".mp3"
    tts = gTTS(phoneme_text, lang="en")
    tts.save(mp3_tmp)
    try:
        import pydub
        audio = pydub.AudioSegment.from_mp3(mp3_tmp)
        audio.export(tmp.name, format="wav")
    except Exception:
        # fallback: return mp3 (Gradio accepts mp3 as audio)
        return mp3_tmp
    # Conversion succeeded; drop the intermediate mp3 so temp files don't pile up.
    os.remove(mp3_tmp)
    return tmp.name
| 33 |
-
|
| 34 |
-
# optional: Coqui TTS (phoneme-aware) - heavier but can take ARPAbet inputs
def synthesize_coqui(arpabet):
    """Synthesize *arpabet* text with Coqui TTS and return a wav file path.

    Raises RuntimeError if the optional `TTS` package is not installed.
    NOTE(review): the model name below is a placeholder and `phoneme_input=False`
    means the input is treated as plain text, not phonemes — confirm against the
    chosen model's documentation before relying on phoneme-aware synthesis.
    """
    # This requires the `TTS` package and an appropriate model that accepts phoneme input.
    try:
        from TTS.api import TTS
    except Exception as e:
        raise RuntimeError("TTS package not installed or failed to import.") from e

    # choose a model name you installed / that exists; example placeholder:
    tts = TTS(model_name="tts_models/en/ljspeech/tacotron2-DDC", progress_bar=False, gpu=False)
    # Some TTS models accept `phoneme` argument or `phoneme_input=True`. Check the model docs.
    wav = tts.tts(arpabet, speaker=None, phoneme_input=False)
    # wav is a numpy array and sample rate accessible via tts.synthesizer.output_sample_rate
    sr = tts.synthesizer.output_sample_rate if hasattr(tts.synthesizer, "output_sample_rate") else 22050
    tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    sf.write(tmp.name, wav, sr)
    return tmp.name
| 51 |
-
|
| 52 |
-
def predict_and_speak(word):
    """Predict phonemes for *word* and best-effort synthesize audio.

    Returns (phoneme_string_or_error_message, audio_path_or_None). Synthesis
    failures are swallowed deliberately so the phoneme prediction still shows.
    """
    if not word or not word.strip():
        return "Please enter a word", None
    phonemes = model.predict(word)
    audio_path = None
    # Try preferred backend
    # NOTE(review): TTS_BACKEND is defined elsewhere in this file (not visible
    # in this span) — verify it exists before calling.
    try:
        if TTS_BACKEND == "coqui":
            audio_path = synthesize_coqui(phonemes)
        else:
            audio_path = synthesize_gtts(phonemes)
    except Exception as e:
        # If synth fails, still return phonemes and a None audio
        print("Synthesis failed:", e)
        audio_path = None
    # gr.Audio accepts: filename (wav/mp3), numpy array, or (np, sr)
    return phonemes, audio_path
| 72 |
-
|
| 73 |
-
# ---- Gradio UI ----
|
| 74 |
css = """
|
| 75 |
-
|
| 76 |
-
|
| 77 |
"""
|
| 78 |
|
| 79 |
-
with gr.Blocks(css=css, theme=gr.themes.
|
| 80 |
-
gr.Markdown("# 🧠 NetTALK
|
| 81 |
-
gr.Markdown("Enter a word
|
| 82 |
-
|
| 83 |
with gr.Row():
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
run_btn.click(fn=predict_and_speak, inputs=[word_in], outputs=[phoneme_out, audio_out])
|
| 91 |
|
| 92 |
-
|
| 93 |
|
| 94 |
if __name__ == "__main__":
|
| 95 |
-
demo.launch()
|
|
|
|
| 1 |
# app.py
import gradio as gr
import os
from model_inference import NetTALKWrapper

# Optional: set env var NETTALK_STATE_DICT to different filename if needed
STATE_DICT = os.environ.get("NETTALK_STATE_DICT", "nettalk_state_dict.pt")

# Instantiate the model once at import time so every request reuses it.
try:
    model = NetTALKWrapper(state_dict_path=STATE_DICT)
except Exception as e:
    # Gradio will show this on startup logs — helpful for debugging.
    # Chain with `from e` so the original failure keeps its traceback.
    raise RuntimeError(f"Failed to load model: {e}") from e
|
| 15 |
|
| 16 |
+
def predict_phonemes(word: str):
    """Predict ARPAbet phonemes for *word*.

    Returns a (phoneme_string, audio) pair; the audio slot is always None for
    now so a TTS backend can be wired in later without changing the signature.
    """
    # Guard clause: reject empty or whitespace-only input up front.
    if not (word and word.strip()):
        return "Please enter a word", None
    return model.predict_string(word), None
|
| 22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
PAGE_CSS = """
.gradio-container { max-width: 900px; margin: auto; }
body { background: linear-gradient(135deg,#071024,#081226); color: #e6eef8; }
"""

# ---- Gradio UI ----
with gr.Blocks(css=PAGE_CSS, theme=gr.themes.Base()) as demo:
    gr.Markdown("# 🧠 NetTALK phoneme predictor")
    gr.Markdown("Enter a word and get ARPAbet phonemes predicted by the trained model.")
    with gr.Row():
        word_box = gr.Textbox(label="Enter word", placeholder="example: 'computer'", lines=1)
        predict_btn = gr.Button("Predict")
    phoneme_box = gr.Textbox(label="Predicted ARPAbet Phonemes")
    # Hidden placeholder so audio output can be enabled later without
    # changing the callback's (phonemes, audio) return contract.
    out_audio = gr.Audio(label="Synthesized audio (optional)", visible=False)

    predict_btn.click(predict_phonemes, inputs=word_box, outputs=[phoneme_box, out_audio])

# Bind to all interfaces on port 7860 — the address Hugging Face Spaces expects.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)
|
model_inference.py
CHANGED
|
@@ -1,66 +1,117 @@
|
|
| 1 |
# model_inference.py
|
|
|
|
| 2 |
import torch
|
|
|
|
| 3 |
import numpy as np
|
| 4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
class NetTALKWrapper:
|
| 6 |
-
def __init__(self,
|
| 7 |
-
# pick device automatically
|
| 8 |
if device is None:
|
| 9 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 10 |
self.device = device
|
| 11 |
|
| 12 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
try:
|
| 14 |
-
|
|
|
|
| 15 |
except Exception as e:
|
| 16 |
-
#
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
def __init__(self):
|
| 22 |
-
super().__init__()
|
| 23 |
-
self.dummy = nn.Linear(10, 10)
|
| 24 |
-
def forward(self, x):
|
| 25 |
-
return torch.randn(1, 20) # placeholder
|
| 26 |
-
m = DummyModel()
|
| 27 |
-
sd = torch.load(model_path, map_location="cpu")
|
| 28 |
-
try:
|
| 29 |
-
m.load_state_dict(sd)
|
| 30 |
-
self.model = m.to(self.device)
|
| 31 |
-
except Exception:
|
| 32 |
-
raise RuntimeError("Could not load model. Please update model_inference.py to use your architecture.")
|
| 33 |
|
| 34 |
self.model.eval()
|
| 35 |
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
""
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
return torch.from_numpy(arr).to(self.device).float()
|
| 49 |
-
|
| 50 |
-
def decode_to_arpabet(self, model_output):
|
| 51 |
-
"""
|
| 52 |
-
Convert model raw output to an ARPAbet string (e.g., "HH AH0 L OW1").
|
| 53 |
-
Replace this with your decoder logic (argmax, beam search, label mapping, etc).
|
| 54 |
-
"""
|
| 55 |
-
# PLACEHOLDER: just return dummy tokens
|
| 56 |
-
return "AH0 N T EH1 R P AH0 B EH1 T"
|
| 57 |
-
|
| 58 |
-
def predict(self, word: str):
|
| 59 |
-
# basic sanitization
|
| 60 |
word = word.strip()
|
| 61 |
if not word:
|
| 62 |
-
return
|
| 63 |
-
|
| 64 |
with torch.no_grad():
|
| 65 |
-
|
| 66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
# model_inference.py
import os
import torch
import torch.nn as nn
import numpy as np

# Window and hidden sizes must match your training config
WINDOW_SIZE = 7    # characters of context fed to the classifier per prediction
HIDDEN_SIZE = 128  # embedding dim and hidden-layer width

# Path to CMU dict in the repo root (must be present)
CMUDICT_PATH = "cmudict.dict.txt"
# Weights file; override with the NETTALK_STATE_DICT env var.
STATE_DICT_PATH = os.environ.get("NETTALK_STATE_DICT", "nettalk_state_dict.pt")
|
| 14 |
+
|
| 15 |
+
# --- 1) Rebuild vocab from CMUdict (same method you used in notebook) ---
def build_vocab(cmudict_path=CMUDICT_PATH):
    """Build character and phoneme index mappings from a CMU pronouncing dictionary.

    Each non-comment line is expected as "WORD PH1 PH2 ...". Returns
    (char_to_idx, idx_to_char, phone_to_idx, idx_to_phone); character index 0
    is reserved for <PAD>/unknown.
    NOTE(review): this must reproduce the notebook's vocab construction exactly,
    or index<->phone mappings will not match the trained weights — confirm.
    """
    words = []
    phones_all = []
    with open(cmudict_path, "r", encoding="utf-8", errors="ignore") as f:
        for line in f:
            if line.strip() and not line.startswith(";;;"):
                parts = line.strip().split()
                # Skip malformed entries (a word with no pronunciation) instead
                # of letting them pollute the vocab with empty phone lists.
                if len(parts) < 2:
                    continue
                words.append(parts[0])
                phones_all.append(parts[1:])

    # character vocab from words (include space for padding)
    char_vocab = set("".join(words))
    char_vocab.add(" ")  # ensure space exists (used to pad word windows)
    char_to_idx = {c: i + 1 for i, c in enumerate(sorted(char_vocab))}  # reserve 0 for unknown/pad
    char_to_idx["<PAD>"] = 0
    idx_to_char = {i: c for c, i in char_to_idx.items()}

    phone_vocab = set(phone for p_list in phones_all for phone in p_list)
    phone_to_idx = {p: i for i, p in enumerate(sorted(phone_vocab))}
    idx_to_phone = {i: p for p, i in phone_to_idx.items()}

    return char_to_idx, idx_to_char, phone_to_idx, idx_to_phone
|
| 40 |
+
|
| 41 |
+
# Build the vocab once at import time; these globals are shared by the wrapper
# below and must agree with the vocab used when the model was trained.
CHAR_TO_IDX, IDX_TO_CHAR, PHONE_TO_IDX, IDX_TO_PHONE = build_vocab()

VOCAB_SIZE = len(CHAR_TO_IDX)  # includes PAD token
NUM_PHONES = len(PHONE_TO_IDX)
|
| 45 |
+
|
| 46 |
+
# --- 2) Architecture matching your notebook ---
class PhonemeClassifier(nn.Module):
    """NetTALK-style window classifier: a window of character indices -> one phoneme.

    Embeds each character, flattens the window, and applies a 2-layer MLP.
    NOTE: attribute names (embedding, fc1, fc2) must not be renamed — they are
    the keys of the saved state_dict this app loads.
    """

    def __init__(self, vocab_size, hidden_size, num_phones, window_size=WINDOW_SIZE):
        super().__init__()
        self.window_size = window_size
        # padding_idx=0 matches CHAR_TO_IDX["<PAD>"] == 0 (pad embedding stays zero)
        self.embedding = nn.Embedding(vocab_size, hidden_size, padding_idx=0)
        self.fc1 = nn.Linear(hidden_size * window_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_phones)

    def forward(self, x):
        # x: (batch, window_size)
        x = self.embedding(x)  # (batch, window, hidden)
        x = x.view(x.size(0), -1)  # flatten window
        x = self.relu(self.fc1(x))
        x = self.fc2(x)  # raw logits over phoneme classes, (batch, num_phones)
        return x
|
| 63 |
+
|
| 64 |
+
# --- 3) Wrapper that loads state_dict and provides predict(word) ---
class NetTALKWrapper:
    """Loads the trained PhonemeClassifier and exposes word -> phoneme prediction.

    Raises FileNotFoundError when the weights file is missing and RuntimeError
    when the file cannot be interpreted as a (possibly checkpoint-wrapped)
    state_dict.
    """

    def __init__(self, state_dict_path=STATE_DICT_PATH, device=None):
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device

        # instantiate model with same architecture as training
        self.model = PhonemeClassifier(VOCAB_SIZE, HIDDEN_SIZE, NUM_PHONES, WINDOW_SIZE).to(self.device)

        if not os.path.exists(state_dict_path):
            raise FileNotFoundError(f"State dict not found at {state_dict_path}. Please upload it to the repo or set NETTALK_STATE_DICT env var.")
        # NOTE: torch.load unpickles arbitrary objects — only load trusted weight files.
        sd = torch.load(state_dict_path, map_location=self.device)
        try:
            # sd could be a dict directly (the state_dict itself)
            self.model.load_state_dict(sd)
        except Exception as e:
            # Checkpoint files commonly nest the weights under one of these keys.
            nested = None
            if isinstance(sd, dict):
                nested = sd.get("model_state_dict") or sd.get("state_dict")
            if nested is None:
                raise RuntimeError("Failed to load state_dict. Ensure you saved with torch.save(model.state_dict(), ...)") from e
            self.model.load_state_dict(nested)

        self.model.eval()

    def _word_to_windows(self, word):
        """Encode *word* as a (len(word), WINDOW_SIZE) long tensor of char indices."""
        # pad with spaces on both sides so every letter gets a full context window
        pad = " " * (WINDOW_SIZE // 2)
        padded = pad + word.lower() + pad
        windows = []
        for i in range(len(word)):
            w = padded[i:i + WINDOW_SIZE]
            # characters absent from the training vocab map to PAD/unknown (0)
            windows.append([CHAR_TO_IDX.get(ch, 0) for ch in w])
        return torch.tensor(windows, dtype=torch.long, device=self.device)

    def predict(self, word):
        """Return a list of ARPAbet tokens, one per character of *word* ([] if empty)."""
        word = word.strip()
        if not word:
            return []
        windows = self._word_to_windows(word)  # (L, window_size)
        with torch.no_grad():
            logits = self.model(windows)  # (L, num_phones)
            # argmax over logits directly — softmax is monotonic, so applying it
            # first (as before) was redundant work.
            preds = torch.argmax(logits, dim=-1).cpu().numpy().tolist()

        # map class indices to ARPAbet tokens
        return [IDX_TO_PHONE[p] for p in preds]

    def predict_string(self, word):
        """Return the predicted phonemes as a single space-joined string."""
        return " ".join(self.predict(word))
|
requirements.txt
CHANGED
|
@@ -1,12 +1,8 @@
|
|
| 1 |
torch
|
| 2 |
gradio>=3.0
|
| 3 |
numpy
|
| 4 |
-
scipy
|
| 5 |
soundfile
|
| 6 |
-
#
|
| 7 |
-
# For a fast fallback TTS:
|
| 8 |
gTTS
|
| 9 |
-
|
| 10 |
TTS
|
| 11 |
-
# Helpful: phonemizer if you want alternative phoneme utilities
|
| 12 |
-
phonemizer
|
|
|
|
| 1 |
torch
|
| 2 |
gradio>=3.0
|
| 3 |
numpy
|
|
|
|
| 4 |
soundfile
|
| 5 |
+
# optional (for audio synthesis later):
|
|
|
|
| 6 |
gTTS
|
| 7 |
+
pydub
|
| 8 |
TTS
|
|
|
|
|
|