harmonicsnail committed on
Commit
9496a85
·
1 Parent(s): 2a517e7

Add inference model + Gradio app

Browse files
Files changed (3) hide show
  1. app.py +24 -78
  2. model_inference.py +100 -49
  3. requirements.txt +2 -6
app.py CHANGED
@@ -1,95 +1,41 @@
1
  # app.py
2
  import gradio as gr
3
  import os
4
- import numpy as np
5
- import soundfile as sf
6
- import tempfile
7
  from model_inference import NetTALKWrapper
8
 
9
- # choose TTS backend: "gtts" or "coqui" (TTS) or "none"
10
- TTS_BACKEND = os.environ.get("TTS_BACKEND", "gtts")
11
 
12
- # load model once (fast startup if model is cached)
13
- MODEL_PATH = "nettalk_model.pt"
14
- model = NetTALKWrapper(MODEL_PATH)
 
 
 
15
 
16
# optional: simple gTTS-based synth (works by speaking the phoneme string as text)
def synthesize_gtts(phoneme_text):
    """Speak `phoneme_text` via gTTS and return a path to the audio file.

    Returns a .wav path when pydub/ffmpeg can convert gTTS's mp3 output,
    otherwise falls back to returning the .mp3 path (Gradio accepts both).
    """
    from gtts import gTTS
    tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    tmp.close()  # close the handle so pydub can reopen/write it (matters on Windows)
    # gTTS only emits mp3; write it next to the wav target, then convert.
    mp3_tmp = tmp.name + ".mp3"
    gTTS(phoneme_text, lang="en").save(mp3_tmp)
    try:
        import pydub
        pydub.AudioSegment.from_mp3(mp3_tmp).export(tmp.name, format="wav")
    except Exception:
        # pydub/ffmpeg unavailable: fallback to mp3 (Gradio accepts mp3 as audio)
        return mp3_tmp
    os.remove(mp3_tmp)  # conversion succeeded; drop the intermediate mp3
    return tmp.name
33
-
34
# optional: Coqui TTS (phoneme-aware) - heavier but can take ARPAbet inputs
def synthesize_coqui(arpabet):
    """Synthesize `arpabet` with Coqui TTS and return the path to a wav file.

    Raises RuntimeError if the `TTS` package is not installed.
    """
    # This requires the `TTS` package and an appropriate model that accepts phoneme input.
    try:
        from TTS.api import TTS
    except Exception as e:
        raise RuntimeError("TTS package not installed or failed to import.") from e

    # choose a model name you installed / that exists; example placeholder:
    tts = TTS(model_name="tts_models/en/ljspeech/tacotron2-DDC", progress_bar=False, gpu=False)
    # Some TTS models accept `phoneme` argument or `phoneme_input=True`. Check the model docs.
    wav = tts.tts(arpabet, speaker=None, phoneme_input=False)
    # wav is a numpy array; sample rate comes from the synthesizer when it exposes one
    sr = getattr(tts.synthesizer, "output_sample_rate", 22050)
    tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    tmp.close()  # sf.write reopens by name; don't hold a second open handle
    sf.write(tmp.name, wav, sr)
    return tmp.name
51
-
52
def predict_and_speak(word):
    """Predict ARPAbet phonemes for `word` and best-effort synthesize audio.

    Returns (phonemes, audio_path); audio_path is None when synthesis fails.
    """
    if not word or not word.strip():
        return "Please enter a word", None

    phonemes = model.predict(word)

    # Pick the synthesis backend, then try it; failures are non-fatal because
    # the phoneme string alone is still a useful result.
    synth = synthesize_coqui if TTS_BACKEND == "coqui" else synthesize_gtts
    try:
        audio_path = synth(phonemes)
    except Exception as e:
        print("Synthesis failed:", e)
        audio_path = None

    # gr.Audio accepts: filename (wav/mp3), numpy array, or (np, sr)
    return phonemes, audio_path
72
-
73
# ---- Gradio UI ----
page_css = """
body { background: linear-gradient(135deg,#0f172a,#020617); color: #e6eef8; }
.gradio-container { max-width: 900px; margin: auto; padding: 20px; }
"""

with gr.Blocks(css=page_css, theme=gr.themes.Default()) as demo:
    # Header / instructions
    gr.Markdown("# 🧠 NetTALK ARPAbet demo")
    gr.Markdown("Enter a word, get predicted ARPAbet phonemes and a synthesized audio preview.")

    with gr.Row():
        word_box = gr.Textbox(label="Enter word", placeholder="example: 'computer'", lines=1)
        predict_btn = gr.Button("Predict")

    phoneme_box = gr.Textbox(label="Predicted ARPAbet Phonemes")
    audio_box = gr.Audio(label="Synthesized audio (preview)")

    predict_btn.click(fn=predict_and_speak, inputs=[word_box], outputs=[phoneme_box, audio_box])

    gr.Markdown("Tip: Replace `preprocess()` and `decode_to_arpabet()` in `model_inference.py` with your real model code.")

if __name__ == "__main__":
    demo.launch()
 
1
# app.py
import gradio as gr
import os

from model_inference import NetTALKWrapper

# Optional: set env var NETTALK_STATE_DICT to different filename if needed
STATE_DICT = os.environ.get("NETTALK_STATE_DICT", "nettalk_state_dict.pt")

# instantiate the model once at import time so every request reuses it
try:
    model = NetTALKWrapper(state_dict_path=STATE_DICT)
except Exception as e:
    # Gradio will show this on startup logs — chain the original exception
    # (`from e`) so the log keeps the real cause and full traceback.
    raise RuntimeError(f"Failed to load model: {e}") from e
 
16
def predict_phonemes(word: str):
    """Map a word to its predicted ARPAbet phoneme string.

    Returns (phonemes, audio); audio is always None until a TTS backend is added.
    """
    if not word or not word.strip():
        return "Please enter a word", None
    # model.predict_string strips and tokenizes the word itself.
    prediction = model.predict_string(word)
    return prediction, None
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
app_css = """
.gradio-container { max-width: 900px; margin: auto; }
body { background: linear-gradient(135deg,#071024,#081226); color: #e6eef8; }
"""

with gr.Blocks(css=app_css, theme=gr.themes.Base()) as demo:
    gr.Markdown("# 🧠 NetTALK phoneme predictor")
    gr.Markdown("Enter a word and get ARPAbet phonemes predicted by the trained model.")

    with gr.Row():
        word_input = gr.Textbox(label="Enter word", placeholder="example: 'computer'", lines=1)
        predict_button = gr.Button("Predict")

    phoneme_output = gr.Textbox(label="Predicted ARPAbet Phonemes")
    # Hidden slot reserved for a future audio/TTS output.
    audio_output = gr.Audio(label="Synthesized audio (optional)", visible=False)

    predict_button.click(predict_phonemes, inputs=word_input, outputs=[phoneme_output, audio_output])

if __name__ == "__main__":
    # Bind all interfaces on a fixed port — presumably for containerized hosting.
    demo.launch(server_name="0.0.0.0", server_port=7860)
model_inference.py CHANGED
@@ -1,66 +1,117 @@
1
  # model_inference.py
 
2
  import torch
 
3
  import numpy as np
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
class NetTALKWrapper:
    """Placeholder inference wrapper around a saved NetTALK model.

    Tries to load a whole pickled model via torch.load, falling back to
    loading a state_dict into a dummy architecture. `preprocess()` and
    `decode_to_arpabet()` are stubs meant to be replaced with real code.
    """

    def __init__(self, model_path="nettalk_model.pt", device=None):
        # pick device automatically
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device

        # If you saved state_dict, load accordingly:
        try:
            self.model = torch.load(model_path, map_location=self.device)
        except Exception as e:
            # fallback: user may have saved state_dict
            print("torch.load failed; try loading state_dict. Error:", e)
            # Example placeholder architecture - REPLACE with your actual model class
            from torch import nn

            class DummyModel(nn.Module):
                def __init__(self):
                    super().__init__()
                    self.dummy = nn.Linear(10, 10)

                def forward(self, x):
                    return torch.randn(1, 20)  # placeholder

            m = DummyModel()
            sd = torch.load(model_path, map_location="cpu")
            try:
                m.load_state_dict(sd)
                self.model = m.to(self.device)
            except Exception as err:
                # chain the cause so the log shows *why* load_state_dict failed
                raise RuntimeError("Could not load model. Please update model_inference.py to use your architecture.") from err

        self.model.eval()

    # ---- Replace these helper methods with your real preprocess/decoder ----
    def preprocess(self, word: str):
        """Convert `word` (string) to the input tensor expected by the model.

        PLACEHOLDER: maps each character to its ord() code point, zero-padded
        and truncated to 32 positions; returns a (1, 32) float tensor on
        self.device. Replace with your actual preprocessing code.
        """
        max_len = 32
        arr = np.zeros((1, max_len), dtype=np.int64)
        for i, c in enumerate(word.lower()[:max_len]):
            arr[0, i] = ord(c)  # placeholder mapping
        return torch.from_numpy(arr).to(self.device).float()

    def decode_to_arpabet(self, model_output):
        """Convert model raw output to an ARPAbet string (e.g., "HH AH0 L OW1").

        PLACEHOLDER: ignores `model_output` and returns fixed dummy tokens.
        Replace with your decoder logic (argmax, beam search, label mapping, etc).
        """
        return "AH0 N T EH1 R P AH0 B EH1 T"

    def predict(self, word: str):
        """Return the ARPAbet string for `word` ("" for blank input)."""
        # basic sanitization
        word = word.strip()
        if not word:
            return ""
        x = self.preprocess(word)
        with torch.no_grad():
            y = self.model(x)
        return self.decode_to_arpabet(y)
 
 
 
 
 
 
 
 
 
 
1
  # model_inference.py
2
+ import os
3
  import torch
4
+ import torch.nn as nn
5
  import numpy as np
6
 
7
# Window and hidden sizes must match your training config
WINDOW_SIZE = 7
HIDDEN_SIZE = 128

# Path to CMU dict in the repo root (must be present)
CMUDICT_PATH = "cmudict.dict.txt"
STATE_DICT_PATH = os.environ.get("NETTALK_STATE_DICT", "nettalk_state_dict.pt")


# --- 1) Rebuild vocab from CMUdict (same method you used in notebook) ---
def build_vocab(cmudict_path=CMUDICT_PATH):
    """Rebuild character and phoneme vocabularies from a CMUdict file.

    Returns (char_to_idx, idx_to_char, phone_to_idx, idx_to_phone); index 0
    of the character vocab is reserved for the <PAD>/unknown token.
    """
    entries = []
    with open(cmudict_path, "r", encoding="utf-8", errors="ignore") as fh:
        for raw in fh:
            stripped = raw.strip()
            # skip blank lines and ";;;" comment lines
            if not stripped or raw.startswith(";;;"):
                continue
            tokens = stripped.split()
            entries.append((tokens[0], tokens[1:]))

    # character vocab from the words, plus space used for window padding
    chars = {ch for word, _ in entries for ch in word}
    chars.add(" ")
    char_to_idx = {"<PAD>": 0}  # reserve 0 for unknown/pad
    for pos, ch in enumerate(sorted(chars), start=1):
        char_to_idx[ch] = pos
    idx_to_char = {v: k for k, v in char_to_idx.items()}

    phones = {ph for _, plist in entries for ph in plist}
    phone_to_idx = {ph: pos for pos, ph in enumerate(sorted(phones))}
    idx_to_phone = {v: k for k, v in phone_to_idx.items()}

    return char_to_idx, idx_to_char, phone_to_idx, idx_to_phone
40
+
41
# Build the vocabularies once at import time; the model's input/output sizes
# derive from them, so they must match what the training run produced.
if not os.path.exists(CMUDICT_PATH):
    # build_vocab's open() would raise a bare FileNotFoundError; give an
    # actionable message instead (same exception type for callers).
    raise FileNotFoundError(
        f"CMU dictionary not found at {CMUDICT_PATH}; upload it to the repo root."
    )
CHAR_TO_IDX, IDX_TO_CHAR, PHONE_TO_IDX, IDX_TO_PHONE = build_vocab()

VOCAB_SIZE = len(CHAR_TO_IDX)  # includes PAD token
NUM_PHONES = len(PHONE_TO_IDX)
45
+
46
# --- 2) Architecture matching your notebook ---
class PhonemeClassifier(nn.Module):
    """Windowed character classifier: one ARPAbet phone per center letter.

    NOTE: submodule attribute names (embedding, fc1, fc2) are part of the
    saved state_dict and must not be renamed.
    """

    def __init__(self, vocab_size, hidden_size, num_phones, window_size=WINDOW_SIZE):
        super().__init__()
        self.window_size = window_size
        self.embedding = nn.Embedding(vocab_size, hidden_size, padding_idx=0)
        self.fc1 = nn.Linear(hidden_size * window_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_phones)

    def forward(self, x):
        # x: (batch, window_size) int64 char indices -> (batch, num_phones) logits
        embedded = self.embedding(x)            # (batch, window, hidden)
        flat = embedded.flatten(start_dim=1)    # flatten window
        hidden = self.relu(self.fc1(flat))
        return self.fc2(hidden)
63
+
64
# --- 3) Wrapper that loads state_dict and provides predict(word) ---
class NetTALKWrapper:
    """Loads a trained PhonemeClassifier and predicts ARPAbet phones per letter."""

    def __init__(self, state_dict_path=STATE_DICT_PATH, device=None):
        """Build the model and load weights.

        Raises FileNotFoundError if the weights file is missing, and
        RuntimeError if the state_dict does not match the architecture.
        """
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device

        # instantiate model with same architecture; sizes must match training
        self.model = PhonemeClassifier(VOCAB_SIZE, HIDDEN_SIZE, NUM_PHONES, WINDOW_SIZE).to(self.device)

        # Try loading state_dict
        if not os.path.exists(state_dict_path):
            raise FileNotFoundError(f"State dict not found at {state_dict_path}. Please upload it to the repo or set NETTALK_STATE_DICT env var.")
        # NOTE(review): consider torch.load(..., weights_only=True) (torch>=1.13)
        # since the file should only contain tensors — confirm torch version.
        sd = torch.load(state_dict_path, map_location=self.device)
        try:
            # sd could be a dict directly (state_dict)
            self.model.load_state_dict(sd)
        except Exception as e:
            # Checkpoint dicts often nest the weights under 'model_state_dict'
            if isinstance(sd, dict) and "model_state_dict" in sd:
                self.model.load_state_dict(sd["model_state_dict"])
            else:
                raise RuntimeError("Failed to load state_dict. Ensure you saved with torch.save(model.state_dict(), ...)") from e

        self.model.eval()

    def _word_to_windows(self, word):
        """Return a (len(word), WINDOW_SIZE) int64 tensor of char-window indices.

        The word is space-padded on both sides so each letter sits at the
        center of its window; chars absent from the vocab map to 0 (<PAD>).
        NOTE(review): the word is lower-cased here, which assumes the CMUdict
        used for the vocab stores lowercase words — confirm against the file.
        """
        pad = " " * (WINDOW_SIZE // 2)
        padded = pad + word.lower() + pad
        windows = [
            [CHAR_TO_IDX.get(ch, 0) for ch in padded[i:i + WINDOW_SIZE]]
            for i in range(len(word))
        ]
        return torch.tensor(windows, dtype=torch.long, device=self.device)

    def predict(self, word):
        """Predict one ARPAbet token per letter of `word`; [] for empty input."""
        word = word.strip()
        if not word:
            return []
        windows = self._word_to_windows(word)  # (L, window_size)
        with torch.no_grad():
            logits = self.model(windows)  # (L, num_phones)
            # argmax over logits equals argmax over softmax(logits);
            # the softmax was redundant work, so take argmax directly.
            preds = torch.argmax(logits, dim=-1).cpu().tolist()

        # map indices to ARPAbet tokens
        return [IDX_TO_PHONE[p] for p in preds]

    def predict_string(self, word):
        """Predict phonemes for `word` and join them with spaces."""
        return " ".join(self.predict(word))
requirements.txt CHANGED
@@ -1,12 +1,8 @@
1
  torch
2
  gradio>=3.0
3
  numpy
4
- scipy
5
  soundfile
6
- # Optional TTS backends (pick one):
7
- # For a fast fallback TTS:
8
  gTTS
9
- # For a more advanced phoneme-aware TTS (may require GPU & larger install):
10
  TTS
11
- # Helpful: phonemizer if you want alternative phoneme utilities
12
- phonemizer
 
1
  torch
2
  gradio>=3.0
3
  numpy
 
4
  soundfile
5
+ # optional (for audio synthesis later):
 
6
  gTTS
7
+ pydub
8
  TTS