Spaces:
Sleeping
Sleeping
Commit ·
c5a6262
1
Parent(s): 2379562
Fixed model_inference.py
Browse files- model_inference.py +33 -35
model_inference.py
CHANGED
|
@@ -1,35 +1,37 @@
|
|
| 1 |
-
# model_inference.py
|
| 2 |
import os
|
|
|
|
| 3 |
import torch
|
| 4 |
import torch.nn as nn
|
| 5 |
import numpy as np
|
| 6 |
|
| 7 |
-
#
|
| 8 |
WINDOW_SIZE = 7
|
| 9 |
HIDDEN_SIZE = 128
|
| 10 |
|
| 11 |
-
#
|
| 12 |
-
CMUDICT_PATH = "cmudict.dict.txt"
|
| 13 |
STATE_DICT_PATH = os.environ.get("NETTALK_STATE_DICT", "nettalk_state_dict.pt")
|
|
|
|
|
|
|
| 14 |
|
| 15 |
-
# --- 1)
|
| 16 |
-
import json
|
| 17 |
-
|
| 18 |
def load_vocab():
|
| 19 |
-
with open("
|
| 20 |
char_to_idx = json.load(f)
|
| 21 |
-
|
| 22 |
-
|
| 23 |
|
| 24 |
-
|
| 25 |
-
|
|
|
|
|
|
|
| 26 |
|
| 27 |
-
CHAR_TO_IDX, IDX_TO_CHAR, PHONE_TO_IDX, IDX_TO_PHONE = build_vocab()
|
| 28 |
|
| 29 |
-
|
|
|
|
| 30 |
NUM_PHONES = len(PHONE_TO_IDX)
|
| 31 |
|
| 32 |
-
|
|
|
|
| 33 |
class PhonemeClassifier(nn.Module):
|
| 34 |
def __init__(self, vocab_size, hidden_size, num_phones, window_size=WINDOW_SIZE):
|
| 35 |
super().__init__()
|
|
@@ -41,63 +43,59 @@ class PhonemeClassifier(nn.Module):
|
|
| 41 |
|
| 42 |
def forward(self, x):
|
| 43 |
# x: (batch, window_size)
|
| 44 |
-
x = self.embedding(x)
|
| 45 |
-
x = x.view(x.size(0), -1)
|
| 46 |
x = self.relu(self.fc1(x))
|
| 47 |
x = self.fc2(x)
|
| 48 |
return x
|
| 49 |
|
| 50 |
-
|
|
|
|
| 51 |
class NetTALKWrapper:
|
| 52 |
def __init__(self, state_dict_path=STATE_DICT_PATH, device=None):
|
| 53 |
if device is None:
|
| 54 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 55 |
self.device = device
|
| 56 |
|
| 57 |
-
# instantiate model with same architecture
|
| 58 |
self.model = PhonemeClassifier(VOCAB_SIZE, HIDDEN_SIZE, NUM_PHONES, WINDOW_SIZE).to(self.device)
|
| 59 |
|
| 60 |
-
#
|
| 61 |
if not os.path.exists(state_dict_path):
|
| 62 |
-
raise FileNotFoundError(f"
|
|
|
|
| 63 |
sd = torch.load(state_dict_path, map_location=self.device)
|
| 64 |
try:
|
| 65 |
-
# sd could be a dict directly (state_dict)
|
| 66 |
self.model.load_state_dict(sd)
|
| 67 |
except Exception as e:
|
| 68 |
-
# If the saved file contains extra keys (e.g., a checkpoint dict), try to extract 'model_state_dict'
|
| 69 |
if isinstance(sd, dict) and "model_state_dict" in sd:
|
| 70 |
self.model.load_state_dict(sd["model_state_dict"])
|
| 71 |
else:
|
| 72 |
-
raise RuntimeError(
|
|
|
|
|
|
|
| 73 |
|
| 74 |
self.model.eval()
|
| 75 |
|
| 76 |
def _word_to_windows(self, word):
|
| 77 |
-
# pad with spaces on both sides
|
| 78 |
pad = " " * (WINDOW_SIZE // 2)
|
| 79 |
padded = pad + word.lower() + pad
|
| 80 |
windows = []
|
| 81 |
for i in range(len(word)):
|
| 82 |
-
w = padded[i:i + WINDOW_SIZE]
|
| 83 |
idxs = [CHAR_TO_IDX.get(ch, 0) for ch in w]
|
| 84 |
windows.append(idxs)
|
| 85 |
-
return torch.tensor(windows, dtype=torch.long, device=self.device)
|
| 86 |
|
| 87 |
def predict(self, word):
|
| 88 |
word = word.strip()
|
| 89 |
if not word:
|
| 90 |
return []
|
| 91 |
-
windows = self._word_to_windows(word)
|
| 92 |
with torch.no_grad():
|
| 93 |
-
logits = self.model(windows)
|
| 94 |
-
|
| 95 |
-
preds = torch.argmax(probs, dim=-1).cpu().numpy().tolist()
|
| 96 |
-
|
| 97 |
-
# map indices to ARPAbet tokens
|
| 98 |
phones = [IDX_TO_PHONE[p] for p in preds]
|
| 99 |
return phones
|
| 100 |
|
| 101 |
def predict_string(self, word):
|
| 102 |
-
|
| 103 |
-
return " ".join(phones)
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
+
import json
|
| 3 |
import torch
|
| 4 |
import torch.nn as nn
|
| 5 |
import numpy as np
|
| 6 |
|
| 7 |
+
# --- Config ---
|
| 8 |
WINDOW_SIZE = 7
|
| 9 |
HIDDEN_SIZE = 128
|
| 10 |
|
| 11 |
+
# Paths
|
|
|
|
| 12 |
STATE_DICT_PATH = os.environ.get("NETTALK_STATE_DICT", "nettalk_state_dict.pt")
|
| 13 |
+
CHAR_VOCAB_PATH = "char_vocab.json"
|
| 14 |
+
PHONE_VOCAB_PATH = "phone_vocab.json"
|
| 15 |
|
| 16 |
+
# --- 1) Load vocabularies (must match training) ---
|
|
|
|
|
|
|
| 17 |
def load_vocab(char_vocab_path=None, phone_vocab_path=None):
    """Load the character and phoneme vocabularies saved at training time.

    Args:
        char_vocab_path: JSON file mapping characters to indices.
            Defaults to the module-level CHAR_VOCAB_PATH.
        phone_vocab_path: JSON file mapping ARPAbet phones to indices.
            Defaults to the module-level PHONE_VOCAB_PATH.

    Returns:
        Tuple (char_to_idx, idx_to_char, phone_to_idx, idx_to_phone).
        The inverse maps use int keys, since JSON object keys are
        always strings on disk.
    """
    # Sentinel defaults keep the zero-argument call backward compatible
    # while letting callers point at alternative vocab files.
    if char_vocab_path is None:
        char_vocab_path = CHAR_VOCAB_PATH
    if phone_vocab_path is None:
        phone_vocab_path = PHONE_VOCAB_PATH

    with open(char_vocab_path, "r") as f:
        char_to_idx = json.load(f)
    with open(phone_vocab_path, "r") as f:
        phone_to_idx = json.load(f)

    # Rebuild inverse maps with int keys so index lookups work directly.
    idx_to_char = {int(v): k for k, v in char_to_idx.items()}
    idx_to_phone = {int(v): k for k, v in phone_to_idx.items()}

    return char_to_idx, idx_to_char, phone_to_idx, idx_to_phone
|
| 27 |
|
|
|
|
| 28 |
|
| 29 |
+
CHAR_TO_IDX, IDX_TO_CHAR, PHONE_TO_IDX, IDX_TO_PHONE = load_vocab()
|
| 30 |
+
VOCAB_SIZE = len(CHAR_TO_IDX)
|
| 31 |
NUM_PHONES = len(PHONE_TO_IDX)
|
| 32 |
|
| 33 |
+
|
| 34 |
+
# --- 2) Model architecture (must match training) ---
|
| 35 |
class PhonemeClassifier(nn.Module):
|
| 36 |
def __init__(self, vocab_size, hidden_size, num_phones, window_size=WINDOW_SIZE):
|
| 37 |
super().__init__()
|
|
|
|
| 43 |
|
| 44 |
def forward(self, x):
    """Map a batch of character windows to phoneme logits.

    x: LongTensor of shape (batch, window_size).
    Returns raw (unnormalized) logits of shape (batch, num_phones).
    """
    embedded = self.embedding(x)
    # Concatenate the per-character embeddings into one feature vector.
    flat = embedded.flatten(start_dim=1)
    hidden = self.relu(self.fc1(flat))
    return self.fc2(hidden)
|
| 51 |
|
| 52 |
+
|
| 53 |
+
# --- 3) Wrapper for inference ---
|
| 54 |
class NetTALKWrapper:
|
| 55 |
def __init__(self, state_dict_path=STATE_DICT_PATH, device=None):
    """Build the model, load trained weights, and switch to eval mode.

    state_dict_path: file produced by torch.save(model.state_dict(), ...)
        or a checkpoint dict containing a 'model_state_dict' entry.
    device: torch.device to run on; auto-selects CUDA when available.

    Raises FileNotFoundError when the weights file is missing and
    RuntimeError when the file cannot be interpreted as a state_dict.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    self.device = device

    # Hyperparameters here must match the training run exactly,
    # otherwise load_state_dict below will reject the weights.
    self.model = PhonemeClassifier(VOCAB_SIZE, HIDDEN_SIZE, NUM_PHONES, WINDOW_SIZE).to(self.device)

    # Fail fast with a clear error when the weights file is absent.
    if not os.path.exists(state_dict_path):
        raise FileNotFoundError(f"Missing model weights at {state_dict_path}")

    checkpoint = torch.load(state_dict_path, map_location=self.device)
    try:
        # Common case: the file holds a bare state_dict.
        self.model.load_state_dict(checkpoint)
    except Exception as err:
        # Fallback: a full training checkpoint wrapping the weights.
        if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
            self.model.load_state_dict(checkpoint["model_state_dict"])
        else:
            raise RuntimeError(
                "Failed to load state_dict. Ensure you saved with torch.save(model.state_dict(), ...)"
            ) from err

    self.model.eval()
|
| 78 |
|
| 79 |
def _word_to_windows(self, word):
    """Slice *word* into overlapping character windows, one per letter.

    The word is lowercased and padded with spaces on both sides so every
    character sits at the center of a WINDOW_SIZE-wide window.

    Returns a LongTensor of shape (len(word), WINDOW_SIZE) on self.device.
    """
    half = WINDOW_SIZE // 2
    padded = " " * half + word.lower() + " " * half
    # NOTE(review): characters absent from the vocab fall back to index 0 —
    # confirm index 0 is the pad/unknown slot used during training.
    rows = [
        [CHAR_TO_IDX.get(ch, 0) for ch in padded[start:start + WINDOW_SIZE]]
        for start in range(len(word))
    ]
    return torch.tensor(rows, dtype=torch.long, device=self.device)
|
| 88 |
|
| 89 |
def predict(self, word):
    """Predict one ARPAbet phone per character of *word*.

    Args:
        word: input word; surrounding whitespace is stripped first.

    Returns:
        List of phone token strings, one per character; an empty or
        whitespace-only input yields [].
    """
    word = word.strip()
    if not word:
        return []
    windows = self._word_to_windows(word)
    with torch.no_grad():
        logits = self.model(windows)
        # softmax is strictly monotonic along the class dim, so argmax over
        # raw logits gives identical predictions without the extra pass.
        preds = torch.argmax(logits, dim=-1).cpu().numpy().tolist()
    # map indices to ARPAbet tokens
    phones = [IDX_TO_PHONE[p] for p in preds]
    return phones
|
| 99 |
|
| 100 |
def predict_string(self, word):
    """Return the predicted phone sequence for *word* as one space-joined string."""
    phones = self.predict(word)
    return " ".join(phones)
|
|
|