# Modern_TalkNET / model_inference.py
# "fixed model inference py" — commit c5a6262, author: harmonicsnail
import os
import json
import torch
import torch.nn as nn
import numpy as np
# --- Config ---
# Characters of context fed to the classifier per prediction; must match the
# value used at training time or the loaded weights will not fit the model.
WINDOW_SIZE = 7
# Embedding / hidden-layer width; must also match training.
HIDDEN_SIZE = 128
# Paths
# Weights path is overridable via the NETTALK_STATE_DICT environment variable.
STATE_DICT_PATH = os.environ.get("NETTALK_STATE_DICT", "nettalk_state_dict.pt")
CHAR_VOCAB_PATH = "char_vocab.json"
PHONE_VOCAB_PATH = "phone_vocab.json"
# --- 1) Load vocabularies (must match training) ---
def load_vocab(char_vocab_path=None, phone_vocab_path=None):
    """Load the character and phoneme vocabularies saved at training time.

    Args:
        char_vocab_path: path to the char->index JSON mapping; defaults to
            the module-level CHAR_VOCAB_PATH so existing callers are unchanged.
        phone_vocab_path: path to the phone->index JSON mapping; defaults to
            the module-level PHONE_VOCAB_PATH.

    Returns:
        Tuple (char_to_idx, idx_to_char, phone_to_idx, idx_to_phone), where
        the idx_to_* dicts are integer-keyed inverses of the JSON mappings.
    """
    if char_vocab_path is None:
        char_vocab_path = CHAR_VOCAB_PATH
    if phone_vocab_path is None:
        phone_vocab_path = PHONE_VOCAB_PATH
    # Explicit encoding so the vocab files read identically on all platforms.
    with open(char_vocab_path, "r", encoding="utf-8") as f:
        char_to_idx = json.load(f)
    with open(phone_vocab_path, "r", encoding="utf-8") as f:
        phone_to_idx = json.load(f)
    # JSON may round-trip indices as strings, so coerce values to int when
    # building the inverse lookups.
    idx_to_char = {int(v): k for k, v in char_to_idx.items()}
    idx_to_phone = {int(v): k for k, v in phone_to_idx.items()}
    return char_to_idx, idx_to_char, phone_to_idx, idx_to_phone
# Module-level singletons: the vocabularies are loaded once at import time
# (a missing JSON file raises FileNotFoundError here, not at predict time).
CHAR_TO_IDX, IDX_TO_CHAR, PHONE_TO_IDX, IDX_TO_PHONE = load_vocab()
# Derived sizes drive the embedding table and classifier output layer below.
VOCAB_SIZE = len(CHAR_TO_IDX)
NUM_PHONES = len(PHONE_TO_IDX)
# --- 2) Model architecture (must match training) ---
class PhonemeClassifier(nn.Module):
    """NETtalk-style MLP: a fixed window of character indices -> one phoneme.

    NOTE: the attribute names (embedding / fc1 / relu / fc2) and layer sizes
    must stay exactly as trained, otherwise the saved state_dict cannot load.
    """

    def __init__(self, vocab_size, hidden_size, num_phones, window_size=WINDOW_SIZE):
        super().__init__()
        self.window_size = window_size
        # Index 0 doubles as the pad/unknown character and embeds to zeros.
        self.embedding = nn.Embedding(vocab_size, hidden_size, padding_idx=0)
        self.fc1 = nn.Linear(hidden_size * window_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_phones)

    def forward(self, x):
        """Map (batch, window_size) LongTensor to (batch, num_phones) logits."""
        embedded = self.embedding(x)          # (batch, window, hidden)
        flat = embedded.flatten(start_dim=1)  # (batch, window * hidden)
        return self.fc2(self.relu(self.fc1(flat)))
# --- 3) Wrapper for inference ---
class NetTALKWrapper:
    """Loads a trained PhonemeClassifier and converts words to phoneme lists.

    Usage:
        wrapper = NetTALKWrapper()
        wrapper.predict("hello")         # -> list of phoneme symbols
        wrapper.predict_string("hello")  # -> space-joined phoneme string
    """

    def __init__(self, state_dict_path=STATE_DICT_PATH, device=None):
        """Build the model on *device* and load the trained weights.

        Args:
            state_dict_path: path to a torch-saved state_dict, or to a full
                training checkpoint containing a "model_state_dict" entry.
            device: torch.device to run on; auto-selects CUDA when available.

        Raises:
            FileNotFoundError: if the weights file does not exist.
            RuntimeError: if the weights do not match the model architecture.
        """
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device
        self.model = PhonemeClassifier(VOCAB_SIZE, HIDDEN_SIZE, NUM_PHONES, WINDOW_SIZE).to(self.device)
        # Load weights safely
        if not os.path.exists(state_dict_path):
            raise FileNotFoundError(f"Missing model weights at {state_dict_path}")
        # NOTE(review): torch.load unpickles arbitrary objects; for untrusted
        # checkpoints prefer torch.load(..., weights_only=True) on torch >= 2.0.
        sd = torch.load(state_dict_path, map_location=self.device)
        # Unwrap a full training checkpoint up front instead of first forcing
        # load_state_dict to fail and retrying inside an except block.
        if isinstance(sd, dict) and "model_state_dict" in sd:
            sd = sd["model_state_dict"]
        try:
            self.model.load_state_dict(sd)
        except (RuntimeError, TypeError, KeyError) as e:
            raise RuntimeError(
                "Failed to load state_dict. Ensure you saved with torch.save(model.state_dict(), ...)"
            ) from e
        self.model.eval()

    def _word_to_windows(self, word):
        """Return a (len(word), WINDOW_SIZE) LongTensor of character indices.

        Each row is the WINDOW_SIZE-character context centred on one letter of
        the lower-cased word; the word is space-padded so edge letters get a
        full window. Characters missing from the vocab map to index 0, which
        is also the embedding's padding index.
        """
        half = WINDOW_SIZE // 2
        padded = " " * half + word.lower() + " " * half
        windows = [
            [CHAR_TO_IDX.get(ch, 0) for ch in padded[i : i + WINDOW_SIZE]]
            for i in range(len(word))
        ]
        return torch.tensor(windows, dtype=torch.long, device=self.device)

    def predict(self, word):
        """Return the predicted phoneme symbols for *word*, one per character.

        Surrounding whitespace is stripped; an empty word yields [].
        """
        word = word.strip()
        if not word:
            return []
        windows = self._word_to_windows(word)
        with torch.no_grad():
            logits = self.model(windows)
        # argmax over raw logits equals argmax over softmax(logits), so the
        # softmax in the original code was redundant and is dropped here;
        # .tolist() on the tensor also removes the numpy round-trip.
        preds = logits.argmax(dim=-1).cpu().tolist()
        return [IDX_TO_PHONE[p] for p in preds]

    def predict_string(self, word):
        """Space-joined phoneme string for *word* ("" for empty input)."""
        return " ".join(self.predict(word))