"""NETtalk-style grapheme-to-phoneme inference.

Loads character/phoneme vocabularies and trained weights, then maps each
character of a word to a phoneme via a sliding-window MLP classifier.
"""

import os
import json

import torch
import torch.nn as nn
import numpy as np  # retained for compatibility; not required by this module

# --- Config ---
WINDOW_SIZE = 7
HIDDEN_SIZE = 128

# Paths
STATE_DICT_PATH = os.environ.get("NETTALK_STATE_DICT", "nettalk_state_dict.pt")
CHAR_VOCAB_PATH = "char_vocab.json"
PHONE_VOCAB_PATH = "phone_vocab.json"


# --- 1) Load vocabularies (must match training) ---
def load_vocab():
    """Load the vocabularies saved at training time.

    Returns:
        (char_to_idx, idx_to_char, phone_to_idx, idx_to_phone): forward and
        inverse mappings for both the character and phoneme vocabularies.

    Raises:
        FileNotFoundError: if either vocabulary JSON file is missing.
    """
    with open(CHAR_VOCAB_PATH, "r") as f:
        char_to_idx = json.load(f)
    with open(PHONE_VOCAB_PATH, "r") as f:
        phone_to_idx = json.load(f)
    # JSON may round-trip indices as strings; coerce to int for the inverse maps.
    idx_to_char = {int(v): k for k, v in char_to_idx.items()}
    idx_to_phone = {int(v): k for k, v in phone_to_idx.items()}
    return char_to_idx, idx_to_char, phone_to_idx, idx_to_phone


CHAR_TO_IDX, IDX_TO_CHAR, PHONE_TO_IDX, IDX_TO_PHONE = load_vocab()
VOCAB_SIZE = len(CHAR_TO_IDX)
NUM_PHONES = len(PHONE_TO_IDX)


# --- 2) Model architecture (must match training) ---
class PhonemeClassifier(nn.Module):
    """Sliding-window MLP: one window of character indices -> phoneme logits."""

    def __init__(self, vocab_size, hidden_size, num_phones, window_size=WINDOW_SIZE):
        super().__init__()
        self.window_size = window_size
        # padding_idx=0: index 0 (pad / unknown characters) embeds to a fixed zero vector.
        self.embedding = nn.Embedding(vocab_size, hidden_size, padding_idx=0)
        self.fc1 = nn.Linear(hidden_size * window_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_phones)

    def forward(self, x):
        # x: (batch, window_size) long tensor of character indices
        x = self.embedding(x)          # (batch, window_size, hidden)
        x = x.view(x.size(0), -1)      # flatten the window: (batch, window*hidden)
        x = self.relu(self.fc1(x))
        return self.fc2(x)             # raw logits: (batch, num_phones)


# --- 3) Wrapper for inference ---
class NetTALKWrapper:
    """Loads trained weights and converts words to phoneme sequences."""

    def __init__(self, state_dict_path=STATE_DICT_PATH, device=None):
        """Build the model on *device* (auto-selects CUDA/CPU) and load weights.

        Raises:
            FileNotFoundError: if *state_dict_path* does not exist.
            RuntimeError: if the file is not a loadable state_dict.
        """
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device
        self.model = PhonemeClassifier(
            VOCAB_SIZE, HIDDEN_SIZE, NUM_PHONES, WINDOW_SIZE
        ).to(self.device)

        # Load weights safely
        if not os.path.exists(state_dict_path):
            raise FileNotFoundError(f"Missing model weights at {state_dict_path}")
        sd = torch.load(state_dict_path, map_location=self.device)
        # Accept either a bare state_dict or a full checkpoint dict that nests
        # the weights under "model_state_dict" — unwrap BEFORE loading so a
        # failure on the nested dict still raises the informative RuntimeError.
        if isinstance(sd, dict) and "model_state_dict" in sd:
            sd = sd["model_state_dict"]
        try:
            self.model.load_state_dict(sd)
        except Exception as e:
            raise RuntimeError(
                "Failed to load state_dict. Ensure you saved with torch.save(model.state_dict(), ...)"
            ) from e
        self.model.eval()

    def _word_to_windows(self, word):
        """Return a (len(word), WINDOW_SIZE) long tensor of character windows.

        Pads *word* with spaces so every character sits at the center of its
        window; characters missing from the vocabulary map to index 0 (the
        embedding's padding index, i.e. a zero vector).
        """
        pad = " " * (WINDOW_SIZE // 2)
        padded = pad + word.lower() + pad
        windows = [
            [CHAR_TO_IDX.get(ch, 0) for ch in padded[i : i + WINDOW_SIZE]]
            for i in range(len(word))
        ]
        return torch.tensor(windows, dtype=torch.long, device=self.device)

    def predict(self, word):
        """Return one phoneme symbol per character of *word* (list of str).

        Empty / whitespace-only input yields an empty list.
        """
        word = word.strip()
        if not word:
            return []
        windows = self._word_to_windows(word)
        with torch.no_grad():
            logits = self.model(windows)
            # argmax is invariant under softmax, so take it on raw logits.
            preds = torch.argmax(logits, dim=-1).cpu().tolist()
        return [IDX_TO_PHONE[p] for p in preds]

    def predict_string(self, word):
        """Return the predicted phonemes for *word* joined by spaces."""
        return " ".join(self.predict(word))