File size: 3,404 Bytes
9496a85
c5a6262
7f6a0d4
9496a85
7f6a0d4
 
c5a6262
9496a85
 
 
c5a6262
9496a85
c5a6262
 
9496a85
c5a6262
2379562
c5a6262
2379562
c5a6262
 
9496a85
c5a6262
 
 
 
9496a85
 
c5a6262
 
9496a85
 
c5a6262
 
9496a85
 
 
 
 
 
 
 
 
 
 
c5a6262
 
9496a85
 
 
 
c5a6262
 
7f6a0d4
9496a85
7f6a0d4
 
 
 
9496a85
 
c5a6262
9496a85
c5a6262
 
9496a85
7f6a0d4
9496a85
7f6a0d4
9496a85
 
 
c5a6262
 
 
7f6a0d4
 
 
9496a85
 
 
 
 
c5a6262
9496a85
 
c5a6262
9496a85
 
7f6a0d4
 
9496a85
c5a6262
7f6a0d4
c5a6262
 
9496a85
 
 
 
c5a6262
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import os
import json
import torch
import torch.nn as nn
import numpy as np

# --- Config ---
# These must match the values used at training time: the saved
# state_dict's tensor shapes depend on both of them.
WINDOW_SIZE = 7    # characters per context window (padding below assumes an odd value)
HIDDEN_SIZE = 128  # embedding dimension and hidden-layer width

# Paths
# The weights path can be overridden with the NETTALK_STATE_DICT env var.
STATE_DICT_PATH = os.environ.get("NETTALK_STATE_DICT", "nettalk_state_dict.pt")
CHAR_VOCAB_PATH = "char_vocab.json"    # JSON object: character -> index
PHONE_VOCAB_PATH = "phone_vocab.json"  # JSON object: phoneme -> index

# --- 1) Load vocabularies (must match training) ---
def load_vocab(char_path=None, phone_path=None):
    """Load the character and phoneme vocabularies from JSON files.

    Args:
        char_path: path to the char->index JSON; defaults to CHAR_VOCAB_PATH.
        phone_path: path to the phone->index JSON; defaults to PHONE_VOCAB_PATH.

    Returns:
        (char_to_idx, idx_to_char, phone_to_idx, idx_to_phone), where the
        idx_to_* dicts are the inverse mappings keyed by int index.
    """
    char_path = CHAR_VOCAB_PATH if char_path is None else char_path
    phone_path = PHONE_VOCAB_PATH if phone_path is None else phone_path

    # Explicit encoding: JSON is UTF-8 regardless of the platform locale.
    with open(char_path, "r", encoding="utf-8") as f:
        char_to_idx = json.load(f)
    with open(phone_path, "r", encoding="utf-8") as f:
        phone_to_idx = json.load(f)

    # JSON object values may round-trip as strings; coerce indices to int
    # so lookups by model output (an int) work.
    idx_to_char = {int(v): k for k, v in char_to_idx.items()}
    idx_to_phone = {int(v): k for k, v in phone_to_idx.items()}

    return char_to_idx, idx_to_char, phone_to_idx, idx_to_phone


# NOTE: vocabularies are loaded at import time (reads the two JSON files)
# so the sizes below are available when the model is constructed.
CHAR_TO_IDX, IDX_TO_CHAR, PHONE_TO_IDX, IDX_TO_PHONE = load_vocab()
VOCAB_SIZE = len(CHAR_TO_IDX)   # input vocabulary size (characters)
NUM_PHONES = len(PHONE_TO_IDX)  # output class count (phonemes)


# --- 2) Model architecture (must match training) ---
class PhonemeClassifier(nn.Module):
    """Windowed-context MLP: embeds a fixed-width character window and
    maps the concatenated embeddings to phoneme logits.

    Layer/attribute names must not change — they are the state_dict keys.
    """

    def __init__(self, vocab_size, hidden_size, num_phones, window_size=WINDOW_SIZE):
        super().__init__()
        self.window_size = window_size
        # Index 0 is the padding/unknown slot (its embedding stays zero).
        self.embedding = nn.Embedding(vocab_size, hidden_size, padding_idx=0)
        self.fc1 = nn.Linear(hidden_size * window_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_phones)

    def forward(self, x):
        """Map (batch, window_size) char indices to (batch, num_phones) logits."""
        embedded = self.embedding(x)                   # (batch, window, hidden)
        flat = embedded.reshape(embedded.size(0), -1)  # concat window embeddings
        hidden = self.relu(self.fc1(flat))
        return self.fc2(hidden)


# --- 3) Wrapper for inference ---
class NetTALKWrapper:
    """Loads a trained PhonemeClassifier and predicts phonemes for words."""

    def __init__(self, state_dict_path=STATE_DICT_PATH, device=None):
        """Build the model on *device* (auto-selects CUDA if available)
        and load weights from *state_dict_path*.

        Raises:
            FileNotFoundError: if the weights file does not exist.
            RuntimeError: if the file is not a loadable state_dict.
        """
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device

        self.model = PhonemeClassifier(VOCAB_SIZE, HIDDEN_SIZE, NUM_PHONES, WINDOW_SIZE).to(self.device)

        # Load weights safely
        if not os.path.exists(state_dict_path):
            raise FileNotFoundError(f"Missing model weights at {state_dict_path}")

        # SECURITY: torch.load unpickles arbitrary objects — only load
        # checkpoints from trusted sources (consider weights_only=True on
        # torch >= 1.13 if the checkpoint is tensors-only).
        sd = torch.load(state_dict_path, map_location=self.device)
        try:
            self.model.load_state_dict(sd)
        except Exception as e:
            # Fall back to full training checkpoints shaped like
            # {"model_state_dict": ..., "optimizer_state_dict": ..., ...}.
            if isinstance(sd, dict) and "model_state_dict" in sd:
                self.model.load_state_dict(sd["model_state_dict"])
            else:
                raise RuntimeError(
                    "Failed to load state_dict. Ensure you saved with torch.save(model.state_dict(), ...)"
                ) from e

        self.model.eval()

    def _word_to_windows(self, word):
        # One WINDOW_SIZE-wide window centred on each character of the
        # lower-cased word; space padding at the edges maps to index 0
        # (the Embedding's padding_idx) via CHAR_TO_IDX.get's default.
        pad = " " * (WINDOW_SIZE // 2)
        padded = pad + word.lower() + pad
        windows = [
            [CHAR_TO_IDX.get(ch, 0) for ch in padded[i : i + WINDOW_SIZE]]
            for i in range(len(word))
        ]
        return torch.tensor(windows, dtype=torch.long, device=self.device)

    def predict(self, word):
        """Return the predicted phoneme symbol for each character of *word*.

        Empty or whitespace-only input yields an empty list.
        """
        word = word.strip()
        if not word:
            return []
        windows = self._word_to_windows(word)
        with torch.no_grad():
            logits = self.model(windows)
            # argmax over raw logits: softmax is monotonic, so applying it
            # first (as before) cannot change the argmax — dropped as
            # redundant work, along with the numpy round-trip.
            preds = torch.argmax(logits, dim=-1).cpu().tolist()
        return [IDX_TO_PHONE[p] for p in preds]

    def predict_string(self, word):
        """Space-joined phoneme string for *word*."""
        return " ".join(self.predict(word))