drixo committed on
Commit
90a21e7
·
verified ·
1 Parent(s): f4d6261

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +20 -8
inference.py CHANGED
@@ -1,29 +1,41 @@
1
  import torch
2
  import torchaudio
 
3
  from .model import RealtimeTTS
4
  from .tokenizer import TTSTokenizer
5
  from .config import TTSConfig
6
 
 
7
  class TTSInference:
8
- def __init__(self, model_path, tokenizer_path, device="cpu"):
9
- self.device = device
 
 
 
10
  self.config = TTSConfig()
11
 
12
- self.model = RealtimeTTS(self.config).to(device)
13
- self.model.load_state_dict(torch.load(model_path, map_location=device))
 
 
14
  self.model.eval()
15
 
16
  self.tokenizer = TTSTokenizer(tokenizer_path)
17
 
18
- self.vocoder = torchaudio.pipelines.HIFIGAN_VOCODER_V3.get_model().to(device)
 
 
 
 
19
 
20
  @torch.no_grad()
21
- def synthesize(self, text):
22
  tokens = self.tokenizer.encode(text)
23
  tokens = torch.tensor(tokens).unsqueeze(0).to(self.device)
24
 
25
- # Dummy mel input for autoregressive decoding
26
- mel_input = torch.zeros(1, 200, self.config.d_model).to(self.device)
 
27
 
28
  mel = self.model(tokens, mel_input)
29
 
 
1
  import torch
2
  import torchaudio
3
+
4
  from .model import RealtimeTTS
5
  from .tokenizer import TTSTokenizer
6
  from .config import TTSConfig
7
 
8
+
9
  class TTSInference:
10
+ def __init__(self, model_path, tokenizer_path, device=None):
11
+ self.device = device or (
12
+ "cuda" if torch.cuda.is_available() else "cpu"
13
+ )
14
+
15
  self.config = TTSConfig()
16
 
17
+ self.model = RealtimeTTS(self.config).to(self.device)
18
+ self.model.load_state_dict(
19
+ torch.load(model_path, map_location=self.device)
20
+ )
21
  self.model.eval()
22
 
23
  self.tokenizer = TTSTokenizer(tokenizer_path)
24
 
25
+ self.vocoder = (
26
+ torchaudio.pipelines.HIFIGAN_VOCODER_V3
27
+ .get_model()
28
+ .to(self.device)
29
+ )
30
 
31
  @torch.no_grad()
32
+ def synthesize(self, text: str):
33
  tokens = self.tokenizer.encode(text)
34
  tokens = torch.tensor(tokens).unsqueeze(0).to(self.device)
35
 
36
+ mel_input = torch.zeros(
37
+ 1, tokens.size(1), self.config.d_model
38
+ ).to(self.device)
39
 
40
  mel = self.model(tokens, mel_input)
41