import os
import json
import torch
import torch.nn as nn
import numpy as np
# --- Config ---
# Number of characters in the sliding context window fed to the classifier
# (must match the value used at training time).
WINDOW_SIZE = 7
# Width of the character embedding and of the hidden fully-connected layer.
HIDDEN_SIZE = 128
# Paths
# Weights file location; overridable via the NETTALK_STATE_DICT env var.
STATE_DICT_PATH = os.environ.get("NETTALK_STATE_DICT", "nettalk_state_dict.pt")
# JSON files mapping characters / phoneme symbols to integer indices.
CHAR_VOCAB_PATH = "char_vocab.json"
PHONE_VOCAB_PATH = "phone_vocab.json"
# --- 1) Load vocabularies (must match training) ---
def load_vocab(char_path=None, phone_path=None):
    """Load character and phoneme vocabularies from JSON files.

    Args:
        char_path: path to the character vocab JSON; defaults to
            CHAR_VOCAB_PATH when None.
        phone_path: path to the phoneme vocab JSON; defaults to
            PHONE_VOCAB_PATH when None.

    Returns:
        Tuple (char_to_idx, idx_to_char, phone_to_idx, idx_to_phone) where
        the *_to_idx dicts map symbol -> index and the idx_to_* dicts are
        their inverses with int keys.
    """
    char_path = CHAR_VOCAB_PATH if char_path is None else char_path
    phone_path = PHONE_VOCAB_PATH if phone_path is None else phone_path

    def _read_json(path):
        # Explicit encoding: JSON is UTF-8 by spec; never rely on the
        # platform default encoding.
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)

    char_to_idx = _read_json(char_path)
    phone_to_idx = _read_json(phone_path)
    # JSON object keys are always strings, so cast indices back to int
    # when building the inverse mappings.
    idx_to_char = {int(v): k for k, v in char_to_idx.items()}
    idx_to_phone = {int(v): k for k, v in phone_to_idx.items()}
    return char_to_idx, idx_to_char, phone_to_idx, idx_to_phone
# Load vocabularies once at import time so the model and the wrapper
# below share the same symbol tables.
CHAR_TO_IDX, IDX_TO_CHAR, PHONE_TO_IDX, IDX_TO_PHONE = load_vocab()
# Derived layer sizes; these must match the dimensions used at training.
VOCAB_SIZE = len(CHAR_TO_IDX)
NUM_PHONES = len(PHONE_TO_IDX)
# --- 2) Model architecture (must match training) ---
class PhonemeClassifier(nn.Module):
    """Sliding-window phoneme classifier: embed, concatenate, 2-layer MLP.

    Takes a window of character indices and produces logits over the
    phoneme inventory. The layer attribute names (embedding, fc1, relu,
    fc2) must stay in sync with the training script so that saved
    state_dicts continue to load.
    """

    def __init__(self, vocab_size, hidden_size, num_phones, window_size=WINDOW_SIZE):
        super().__init__()
        self.window_size = window_size
        # padding_idx=0: index 0 is embedded as a fixed all-zero vector.
        self.embedding = nn.Embedding(vocab_size, hidden_size, padding_idx=0)
        self.fc1 = nn.Linear(hidden_size * window_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_phones)

    def forward(self, x):
        """Map x of shape (batch, window_size) to (batch, num_phones) logits."""
        embedded = self.embedding(x)                  # (batch, window, hidden)
        flat = torch.flatten(embedded, start_dim=1)   # concat window embeddings
        hidden = self.relu(self.fc1(flat))
        return self.fc2(hidden)
# --- 3) Wrapper for inference ---
class NetTALKWrapper:
    """Loads a trained PhonemeClassifier and maps words to phoneme lists."""

    def __init__(self, state_dict_path=STATE_DICT_PATH, device=None):
        """Build the model on *device* (auto-detected when None) and load weights.

        Raises:
            FileNotFoundError: if the weights file does not exist.
            RuntimeError: if the file holds neither a raw state_dict nor a
                checkpoint dict containing a "model_state_dict" key.
        """
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device
        self.model = PhonemeClassifier(VOCAB_SIZE, HIDDEN_SIZE, NUM_PHONES, WINDOW_SIZE).to(self.device)
        if not os.path.exists(state_dict_path):
            raise FileNotFoundError(f"Missing model weights at {state_dict_path}")
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # weight files from trusted sources (consider weights_only=True on
        # recent PyTorch versions).
        sd = torch.load(state_dict_path, map_location=self.device)
        try:
            self.model.load_state_dict(sd)
        except Exception as e:
            # Fall back to full training checkpoints that nest the weights.
            if isinstance(sd, dict) and "model_state_dict" in sd:
                self.model.load_state_dict(sd["model_state_dict"])
            else:
                raise RuntimeError(
                    "Failed to load state_dict. Ensure you saved with torch.save(model.state_dict(), ...)"
                ) from e
        self.model.eval()

    def _word_to_windows(self, word):
        """Return a (len(word), window) long tensor of character-index windows."""
        # Use the model's own configured window size (not the module-level
        # constant) so a model built with a non-default window still gets
        # correctly shaped inputs.
        win = self.model.window_size
        pad = " " * (win // 2)
        padded = pad + word.lower() + pad
        # Unknown characters fall back to index 0 — assumes 0 was the
        # padding/unknown index at training time (padding_idx=0 above).
        windows = [
            [CHAR_TO_IDX.get(ch, 0) for ch in padded[i : i + win]]
            for i in range(len(word))
        ]
        return torch.tensor(windows, dtype=torch.long, device=self.device)

    def predict(self, word):
        """Predict one phoneme symbol per character of *word* (stripped).

        Returns an empty list for empty/whitespace-only input.
        """
        word = word.strip()
        if not word:
            return []
        windows = self._word_to_windows(word)
        with torch.no_grad():
            logits = self.model(windows)
        # argmax over raw logits: softmax is monotonic, so applying it
        # first can never change the winning class — skip it.
        preds = torch.argmax(logits, dim=-1).cpu().tolist()
        return [IDX_TO_PHONE[p] for p in preds]

    def predict_string(self, word):
        """Predict phonemes for *word* and join them with single spaces."""
        return " ".join(self.predict(word))