harmonicsnail commited on
Commit
c5a6262
·
1 Parent(s): 2379562

Fixed model_inference.py: load phoneme vocabulary from phone_vocab.json and clean up inference code

Browse files
Files changed (1) hide show
  1. model_inference.py +33 -35
model_inference.py CHANGED
@@ -1,35 +1,37 @@
1
- # model_inference.py
2
  import os
 
3
  import torch
4
  import torch.nn as nn
5
  import numpy as np
6
 
7
- # Window and hidden sizes must match your training config
8
  WINDOW_SIZE = 7
9
  HIDDEN_SIZE = 128
10
 
11
- # Path to CMU dict in the repo root (must be present)
12
- CMUDICT_PATH = "cmudict.dict.txt"
13
  STATE_DICT_PATH = os.environ.get("NETTALK_STATE_DICT", "nettalk_state_dict.pt")
 
 
14
 
15
- # --- 1) Rebuild vocab from CMUdict (same method you used in notebook) ---
16
- import json
17
-
18
  def load_vocab():
19
- with open("char_vocab.json") as f:
20
  char_to_idx = json.load(f)
21
- idx_to_char = {i: c for c, i in char_to_idx.items()}
22
- return char_to_idx, idx_to_char
23
 
24
- CHAR_TO_IDX, IDX_TO_CHAR = load_vocab()
25
- VOCAB_SIZE = len(CHAR_TO_IDX)
 
 
26
 
27
- CHAR_TO_IDX, IDX_TO_CHAR, PHONE_TO_IDX, IDX_TO_PHONE = build_vocab()
28
 
29
- VOCAB_SIZE = len(CHAR_TO_IDX) # includes PAD token
 
30
  NUM_PHONES = len(PHONE_TO_IDX)
31
 
32
- # --- 2) Architecture matching your notebook ---
 
33
  class PhonemeClassifier(nn.Module):
34
  def __init__(self, vocab_size, hidden_size, num_phones, window_size=WINDOW_SIZE):
35
  super().__init__()
@@ -41,63 +43,59 @@ class PhonemeClassifier(nn.Module):
41
 
42
  def forward(self, x):
43
  # x: (batch, window_size)
44
- x = self.embedding(x) # (batch, window, hidden)
45
- x = x.view(x.size(0), -1) # flatten window
46
  x = self.relu(self.fc1(x))
47
  x = self.fc2(x)
48
  return x
49
 
50
- # --- 3) Wrapper that loads state_dict and provides predict(word) ---
 
51
  class NetTALKWrapper:
52
  def __init__(self, state_dict_path=STATE_DICT_PATH, device=None):
53
  if device is None:
54
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
55
  self.device = device
56
 
57
- # instantiate model with same architecture
58
  self.model = PhonemeClassifier(VOCAB_SIZE, HIDDEN_SIZE, NUM_PHONES, WINDOW_SIZE).to(self.device)
59
 
60
- # Try loading state_dict
61
  if not os.path.exists(state_dict_path):
62
- raise FileNotFoundError(f"State dict not found at {state_dict_path}. Please upload it to the repo or set NETTALK_STATE_DICT env var.")
 
63
  sd = torch.load(state_dict_path, map_location=self.device)
64
  try:
65
- # sd could be a dict directly (state_dict)
66
  self.model.load_state_dict(sd)
67
  except Exception as e:
68
- # If the saved file contains extra keys (e.g., a checkpoint dict), try to extract 'model_state_dict'
69
  if isinstance(sd, dict) and "model_state_dict" in sd:
70
  self.model.load_state_dict(sd["model_state_dict"])
71
  else:
72
- raise RuntimeError("Failed to load state_dict. Ensure you saved with torch.save(model.state_dict(), ...)") from e
 
 
73
 
74
  self.model.eval()
75
 
76
  def _word_to_windows(self, word):
77
- # pad with spaces on both sides
78
  pad = " " * (WINDOW_SIZE // 2)
79
  padded = pad + word.lower() + pad
80
  windows = []
81
  for i in range(len(word)):
82
- w = padded[i:i + WINDOW_SIZE]
83
  idxs = [CHAR_TO_IDX.get(ch, 0) for ch in w]
84
  windows.append(idxs)
85
- return torch.tensor(windows, dtype=torch.long, device=self.device) # (len(word), window_size)
86
 
87
  def predict(self, word):
88
  word = word.strip()
89
  if not word:
90
  return []
91
- windows = self._word_to_windows(word) # (L, window_size)
92
  with torch.no_grad():
93
- logits = self.model(windows) # (L, num_phones)
94
- probs = torch.softmax(logits, dim=-1)
95
- preds = torch.argmax(probs, dim=-1).cpu().numpy().tolist()
96
-
97
- # map indices to ARPAbet tokens
98
  phones = [IDX_TO_PHONE[p] for p in preds]
99
  return phones
100
 
101
  def predict_string(self, word):
102
- phones = self.predict(word)
103
- return " ".join(phones)
 
 
# model_inference.py
#
# Inference utilities for the NetTALK-style grapheme-to-phoneme classifier.

import os
import json

import torch
import torch.nn as nn
import numpy as np

# --- Config ---
# Window and hidden sizes must match the training configuration exactly,
# otherwise the saved state dict will not load.
WINDOW_SIZE = 7
HIDDEN_SIZE = 128

# Paths
# The weights path can be overridden with the NETTALK_STATE_DICT env var.
STATE_DICT_PATH = os.environ.get("NETTALK_STATE_DICT", "nettalk_state_dict.pt")
CHAR_VOCAB_PATH = "char_vocab.json"
PHONE_VOCAB_PATH = "phone_vocab.json"
# --- 1) Load vocabularies (must match training) ---
def load_vocab(char_path="char_vocab.json", phone_path="phone_vocab.json"):
    """Load the character and phoneme vocabularies saved during training.

    Args:
        char_path: JSON file mapping characters to integer indices.
        phone_path: JSON file mapping ARPAbet phones to integer indices.

    Returns:
        Tuple (char_to_idx, idx_to_char, phone_to_idx, idx_to_phone) where
        the idx_to_* dicts are the inverse mappings keyed by int.

    Raises:
        FileNotFoundError: if either vocabulary file is missing.
        json.JSONDecodeError: if a file is not valid JSON.
    """
    with open(char_path, "r", encoding="utf-8") as f:
        char_to_idx = json.load(f)
    with open(phone_path, "r", encoding="utf-8") as f:
        phone_to_idx = json.load(f)

    # JSON object values may deserialize as non-int (e.g. if the file was
    # hand-edited); coerce to int so tensor indexing works.
    idx_to_char = {int(v): k for k, v in char_to_idx.items()}
    idx_to_phone = {int(v): k for k, v in phone_to_idx.items()}

    return char_to_idx, idx_to_char, phone_to_idx, idx_to_phone
27
 
 
# Vocabularies are loaded once at import time so the module exposes
# CHAR_TO_IDX / IDX_TO_CHAR / PHONE_TO_IDX / IDX_TO_PHONE as constants.
CHAR_TO_IDX, IDX_TO_CHAR, PHONE_TO_IDX, IDX_TO_PHONE = load_vocab()

# Derived sizes used to build the model (VOCAB_SIZE includes any PAD slot
# present in the saved char vocabulary).
VOCAB_SIZE = len(CHAR_TO_IDX)
NUM_PHONES = len(PHONE_TO_IDX)
33
+
34
+ # --- 2) Model architecture (must match training) ---
35
  class PhonemeClassifier(nn.Module):
36
  def __init__(self, vocab_size, hidden_size, num_phones, window_size=WINDOW_SIZE):
37
  super().__init__()
 
43
 
44
  def forward(self, x):
45
  # x: (batch, window_size)
46
+ x = self.embedding(x)
47
+ x = x.view(x.size(0), -1)
48
  x = self.relu(self.fc1(x))
49
  x = self.fc2(x)
50
  return x
51
 
52
+
53
# --- 3) Wrapper for inference ---
class NetTALKWrapper:
    """Loads trained weights and converts words to ARPAbet phone sequences."""

    def __init__(self, state_dict_path=STATE_DICT_PATH, device=None):
        """Build the model and load its weights.

        Args:
            state_dict_path: path to a torch state dict, or a checkpoint
                dict containing a 'model_state_dict' entry.
            device: torch.device to run on; defaults to CUDA when available.

        Raises:
            FileNotFoundError: if the weights file does not exist.
            RuntimeError: if the file cannot be interpreted as a state dict.
        """
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device

        # Architecture must match the training configuration exactly,
        # otherwise load_state_dict below will fail on shape mismatch.
        self.model = PhonemeClassifier(VOCAB_SIZE, HIDDEN_SIZE, NUM_PHONES, WINDOW_SIZE).to(self.device)

        # Load weights safely
        if not os.path.exists(state_dict_path):
            raise FileNotFoundError(f"Missing model weights at {state_dict_path}")

        # NOTE(review): torch.load unpickles arbitrary objects; prefer
        # weights_only=True (torch >= 2.0) if the file may be untrusted.
        sd = torch.load(state_dict_path, map_location=self.device)
        try:
            self.model.load_state_dict(sd)
        except Exception as e:
            # The file may be a full checkpoint rather than a bare state dict.
            if isinstance(sd, dict) and "model_state_dict" in sd:
                self.model.load_state_dict(sd["model_state_dict"])
            else:
                raise RuntimeError(
                    "Failed to load state_dict. Ensure you saved with torch.save(model.state_dict(), ...)"
                ) from e

        self.model.eval()

    def _word_to_windows(self, word):
        """Return a (len(word), WINDOW_SIZE) LongTensor of char-index windows.

        The word is lower-cased and padded with spaces on both sides so that
        each character sits at the center of its window.
        """
        pad = " " * (WINDOW_SIZE // 2)
        padded = pad + word.lower() + pad
        windows = []
        for i in range(len(word)):
            w = padded[i : i + WINDOW_SIZE]
            # Unknown characters fall back to index 0 — presumably the
            # PAD/unknown slot; confirm against the training vocab.
            idxs = [CHAR_TO_IDX.get(ch, 0) for ch in w]
            windows.append(idxs)
        return torch.tensor(windows, dtype=torch.long, device=self.device)

    def predict(self, word):
        """Predict one ARPAbet phone per character of `word`.

        Returns an empty list for empty/whitespace-only input.
        """
        word = word.strip()
        if not word:
            return []
        windows = self._word_to_windows(word)
        with torch.no_grad():
            logits = self.model(windows)  # (len(word), num_phones)
            # argmax over logits equals argmax over softmax probabilities,
            # so the softmax pass is unnecessary; .tolist() also avoids the
            # pointless round-trip through numpy.
            preds = torch.argmax(logits, dim=-1).cpu().tolist()

        phones = [IDX_TO_PHONE[p] for p in preds]
        return phones

    def predict_string(self, word):
        """Return the predicted phones joined by single spaces."""
        return " ".join(self.predict(word))