harmonicsnail commited on
Commit
7f6a0d4
·
1 Parent(s): 7e130d1

Added updated files

Browse files
Files changed (3) hide show
  1. model_inference.py +66 -0
  2. nettalk_model.pt +0 -0
  3. requirements.txt +12 -0
model_inference.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # model_inference.py
2
+ import torch
3
+ import numpy as np
4
+
5
class NetTALKWrapper:
    """Thin inference wrapper around a saved NetTALK grapheme-to-phoneme model.

    Loads a serialized checkpoint (either a full pickled ``nn.Module`` or a
    ``state_dict``) and exposes ``predict(word) -> ARPAbet string``.  The
    ``preprocess`` / ``decode_to_arpabet`` helpers are explicit placeholders
    and must be replaced with the real NetTALK pipeline.
    """

    def __init__(self, model_path="nettalk_model.pt", device=None):
        # Pick device automatically unless the caller supplies one.
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device

        # SECURITY NOTE: torch.load unpickles arbitrary Python objects — only
        # load checkpoints from trusted sources.
        try:
            # First attempt: checkpoint holds a fully serialized nn.Module.
            self.model = torch.load(model_path, map_location=self.device)
        except Exception as e:
            # Fallback: user may have saved a state_dict instead.
            print("torch.load failed; try loading state_dict. Error:", e)

            # Example placeholder architecture - REPLACE with your actual model class
            from torch import nn

            class DummyModel(nn.Module):
                def __init__(self):
                    super().__init__()
                    self.dummy = nn.Linear(10, 10)

                def forward(self, x):
                    return torch.randn(1, 20)  # placeholder

            try:
                # Bug fix: the state_dict load now lives INSIDE the try, so a
                # missing/corrupt file surfaces as the clear RuntimeError below
                # instead of escaping as a raw unpickling/IO error.
                sd = torch.load(model_path, map_location="cpu")
                m = DummyModel()
                m.load_state_dict(sd)
                self.model = m.to(self.device)
            except Exception as load_err:
                # Chain the cause so the original failure isn't discarded.
                raise RuntimeError("Could not load model. Please update model_inference.py to use your architecture.") from load_err

        # Inference-only wrapper: disable dropout/batch-norm training behavior.
        self.model.eval()

    # ---- Replace these helper methods with your real preprocess/decoder ----
    def preprocess(self, word: str) -> "torch.Tensor":
        """
        Convert `word` (string) to the input tensor expected by your NetTALK model.

        Placeholder behavior: lowercases the word, maps each character to its
        ord() code point, zero-pads/truncates to a fixed window of 32, and
        returns a float32 tensor of shape (1, 32) on ``self.device``.
        Example NetTALK uses character windowing / one-hot encoding — replace.
        """
        # PLACEHOLDER: map characters to indices, pad/truncate to length L, then to tensor
        # *Replace with your actual preprocessing code*
        max_len = 32
        arr = np.zeros((1, max_len), dtype=np.int64)
        for i, c in enumerate(word.lower()[:max_len]):
            arr[0, i] = ord(c)  # placeholder mapping
        return torch.from_numpy(arr).to(self.device).float()

    def decode_to_arpabet(self, model_output) -> str:
        """
        Convert model raw output to an ARPAbet string (e.g., "HH AH0 L OW1").

        Replace this with your decoder logic (argmax, beam search, label
        mapping, etc).  Placeholder: ignores ``model_output`` and returns a
        fixed dummy token sequence.
        """
        # PLACEHOLDER: just return dummy tokens
        return "AH0 N T EH1 R P AH0 B EH1 T"

    def predict(self, word: str) -> str:
        """Return the ARPAbet transcription for `word` (empty string for blank input)."""
        # Basic sanitization: trim whitespace; blank input short-circuits.
        word = word.strip()
        if not word:
            return ""
        x = self.preprocess(word)
        # Inference only — no autograd bookkeeping needed.
        with torch.no_grad():
            y = self.model(x)
        return self.decode_to_arpabet(y)
nettalk_model.pt ADDED
File without changes
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ gradio>=3.0
3
+ numpy
4
+ scipy
5
+ soundfile
6
+ # Optional TTS backends (pick one):
7
+ # For a fast fallback TTS:
8
+ gTTS
9
+ # For a more advanced phoneme-aware TTS (may require GPU & larger install):
10
+ TTS
11
+ # Helpful: phonemizer if you want alternative phoneme utilities
12
+ phonemizer