harmonicsnail commited on
Commit
2379562
·
1 Parent(s): 13ae703

added json dump and implemented fix

Browse files
Files changed (1) hide show
  1. model_inference.py +8 -22
model_inference.py CHANGED
@@ -13,30 +13,16 @@ CMUDICT_PATH = "cmudict.dict.txt"
13
  STATE_DICT_PATH = os.environ.get("NETTALK_STATE_DICT", "nettalk_state_dict.pt")
14
 
15
  # --- 1) Rebuild vocab from CMUdict (same method you used in notebook) ---
16
- def build_vocab(cmudict_path=CMUDICT_PATH):
17
- words = []
18
- phones_all = []
19
- with open(cmudict_path, "r", encoding="utf-8", errors="ignore") as f:
20
- for line in f:
21
- if line.strip() and not line.startswith(";;;"):
22
- parts = line.strip().split()
23
- w = parts[0]
24
- p = parts[1:]
25
- words.append(w)
26
- phones_all.append(p)
27
-
28
- # character vocab from words (include space for padding)
29
- char_vocab = set("".join(words))
30
- char_vocab.add(" ") # ensure space exists
31
- char_to_idx = {c: i+1 for i, c in enumerate(sorted(char_vocab))} # reserve 0 for unknown/pad
32
- char_to_idx["<PAD>"] = 0
33
- idx_to_char = {i: c for c, i in char_to_idx.items()}
34
 
35
- phone_vocab = set(phone for p_list in phones_all for phone in p_list)
36
- phone_to_idx = {p: i for i, p in enumerate(sorted(phone_vocab))}
37
- idx_to_phone = {i: p for p, i in phone_to_idx.items()}
 
 
38
 
39
- return char_to_idx, idx_to_char, phone_to_idx, idx_to_phone
 
40
 
41
  CHAR_TO_IDX, IDX_TO_CHAR, PHONE_TO_IDX, IDX_TO_PHONE = build_vocab()
42
 
 
13
  STATE_DICT_PATH = os.environ.get("NETTALK_STATE_DICT", "nettalk_state_dict.pt")
14
 
15
  # --- 1) Rebuild vocab from CMUdict (same method you used in notebook) ---
16
+ import json
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
+ def load_vocab():
19
+ with open("char_vocab.json") as f:
20
+ char_to_idx = json.load(f)
21
+ idx_to_char = {i: c for c, i in char_to_idx.items()}
22
+ return char_to_idx, idx_to_char
23
 
24
+ CHAR_TO_IDX, IDX_TO_CHAR = load_vocab()
25
+ VOCAB_SIZE = len(CHAR_TO_IDX)
26
 
27
  CHAR_TO_IDX, IDX_TO_CHAR, PHONE_TO_IDX, IDX_TO_PHONE = build_vocab()
28