# NetTalk-style grapheme-to-phoneme inference script.
# (Extraction residue from the hosting page appeared here: "Spaces: Sleeping Sleeping".)
| import os | |
| import json | |
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
# --- Config ---
# Width of the character context window fed to the model; the word is padded
# by WINDOW_SIZE // 2 on each side so every window is centered on one character.
WINDOW_SIZE = 7
# Embedding dimension and hidden-layer width of the classifier MLP.
HIDDEN_SIZE = 128
# Paths
# Weights location is overridable via the NETTALK_STATE_DICT environment variable.
STATE_DICT_PATH = os.environ.get("NETTALK_STATE_DICT", "nettalk_state_dict.pt")
# Vocab JSON files written at training time (token -> integer id).
CHAR_VOCAB_PATH = "char_vocab.json"
PHONE_VOCAB_PATH = "phone_vocab.json"
| # --- 1) Load vocabularies (must match training) --- | |
def load_vocab():
    """Load the character and phoneme vocabularies saved at training time.

    Returns:
        Tuple ``(char_to_idx, idx_to_char, phone_to_idx, idx_to_phone)``:
        the ``*_to_idx`` dicts are loaded verbatim from JSON (token -> id),
        and the ``idx_to_*`` dicts are their inverses keyed by ``int``.

    Raises:
        FileNotFoundError: a vocab file is missing.
        json.JSONDecodeError: a vocab file is not valid JSON.
    """
    def _read_json(path):
        # Explicit encoding: vocab files are UTF-8 JSON; relying on the
        # platform default encoding can mis-decode them on some systems.
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)

    char_to_idx = _read_json(CHAR_VOCAB_PATH)
    phone_to_idx = _read_json(PHONE_VOCAB_PATH)
    # JSON object keys are always strings, so ids may round-trip as str;
    # coerce with int() when building the inverse maps.
    idx_to_char = {int(v): k for k, v in char_to_idx.items()}
    idx_to_phone = {int(v): k for k, v in phone_to_idx.items()}
    return char_to_idx, idx_to_char, phone_to_idx, idx_to_phone
# Load the vocabularies once at import time; their sizes fix the model's
# embedding and output dimensions, which must match the trained weights.
CHAR_TO_IDX, IDX_TO_CHAR, PHONE_TO_IDX, IDX_TO_PHONE = load_vocab()
VOCAB_SIZE = len(CHAR_TO_IDX)
NUM_PHONES = len(PHONE_TO_IDX)
| # --- 2) Model architecture (must match training) --- | |
class PhonemeClassifier(nn.Module):
    """Window-of-characters -> phoneme-logits MLP (NetTalk-style).

    A window of ``window_size`` character ids is embedded, the per-character
    embeddings are concatenated, and a two-layer perceptron emits one logit
    per phoneme class.

    NOTE: the attribute names ``embedding``/``fc1``/``fc2`` are the keys of
    the saved state_dict and must not be renamed.
    """

    def __init__(self, vocab_size, hidden_size, num_phones, window_size=WINDOW_SIZE):
        super().__init__()
        self.window_size = window_size
        # padding_idx=0: id 0 serves as the pad/unknown slot.
        self.embedding = nn.Embedding(vocab_size, hidden_size, padding_idx=0)
        self.fc1 = nn.Linear(hidden_size * window_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_phones)

    def forward(self, x):
        """Map id windows to phoneme logits.

        Args:
            x: LongTensor of shape (batch, window_size).

        Returns:
            FloatTensor of shape (batch, num_phones) — raw logits.
        """
        embedded = self.embedding(x)                  # (batch, window, hidden)
        flat = torch.flatten(embedded, start_dim=1)   # (batch, window * hidden)
        hidden = self.relu(self.fc1(flat))
        return self.fc2(hidden)
| # --- 3) Wrapper for inference --- | |
class NetTALKWrapper:
    """Inference wrapper: a word (str) -> one phoneme symbol per character."""

    def __init__(self, state_dict_path=STATE_DICT_PATH, device=None):
        """Build the model and load trained weights.

        Args:
            state_dict_path: path to weights saved via ``torch.save``; may be
                a bare state_dict or a checkpoint dict holding one under the
                key ``"model_state_dict"``.
            device: ``torch.device`` to run on; defaults to CUDA when
                available, else CPU.

        Raises:
            FileNotFoundError: the weights file does not exist.
            RuntimeError: the file loads but is not a usable state_dict.
        """
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device
        self.model = PhonemeClassifier(VOCAB_SIZE, HIDDEN_SIZE, NUM_PHONES, WINDOW_SIZE).to(self.device)
        if not os.path.exists(state_dict_path):
            raise FileNotFoundError(f"Missing model weights at {state_dict_path}")
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # weights from trusted sources (consider weights_only=True on torch>=2.0).
        sd = torch.load(state_dict_path, map_location=self.device)
        # Accept either a bare state_dict or a full training checkpoint;
        # unwrap the checkpoint form up front instead of catching the
        # resulting load failure.
        if isinstance(sd, dict) and "model_state_dict" in sd:
            sd = sd["model_state_dict"]
        try:
            self.model.load_state_dict(sd)
        except Exception as e:
            raise RuntimeError(
                "Failed to load state_dict. Ensure you saved with torch.save(model.state_dict(), ...)"
            ) from e
        self.model.eval()

    def _word_to_windows(self, word):
        """Return a (len(word), WINDOW_SIZE) LongTensor of character-id windows.

        The word is lowercased and padded with spaces on both sides so each
        window is centered on one character of the word; characters missing
        from the vocab map to id 0 (the embedding's padding index).
        """
        pad = " " * (WINDOW_SIZE // 2)
        padded = pad + word.lower() + pad
        windows = [
            [CHAR_TO_IDX.get(ch, 0) for ch in padded[i : i + WINDOW_SIZE]]
            for i in range(len(word))
        ]
        return torch.tensor(windows, dtype=torch.long, device=self.device)

    def predict(self, word):
        """Predict one phoneme symbol per character of ``word``.

        Returns:
            list[str] of phoneme symbols; ``[]`` for empty or
            whitespace-only input.
        """
        word = word.strip()
        if not word:
            return []
        windows = self._word_to_windows(word)
        with torch.no_grad():
            logits = self.model(windows)
        # argmax(softmax(x)) == argmax(x): softmax is monotonic, so skip
        # the redundant normalization (and the numpy round-trip).
        preds = torch.argmax(logits, dim=-1).cpu().tolist()
        return [IDX_TO_PHONE[p] for p in preds]

    def predict_string(self, word):
        """Like :meth:`predict`, joined into one space-separated string."""
        return " ".join(self.predict(word))