phatdo commited on
Commit
baf9bda
·
verified ·
1 Parent(s): 67586db

Update utils/model.py

Browse files
Files changed (1) hide show
  1. utils/model.py +6 -2
utils/model.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import os
2
  import json
3
 
@@ -9,6 +11,7 @@ from model import FastSpeech2, ScheduledOptim
9
 
10
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
 
 
12
  def get_model(args, configs, device, train=False):
13
  (preprocess_config, model_config, train_config) = configs
14
 
@@ -34,6 +37,7 @@ def get_model(args, configs, device, train=False):
34
  model.requires_grad_ = False
35
  return model
36
 
 
37
  def get_model_infer(ckpt_path, configs, device):
38
  (preprocess_config, model_config, train_config) = configs
39
 
@@ -49,7 +53,7 @@ def get_param_num(model):
49
  num_param = sum(param.numel() for param in model.parameters())
50
  return num_param
51
 
52
-
53
  def get_vocoder(config, device):
54
  name = config["vocoder"]["model"]
55
  speaker = config["vocoder"]["speaker"]
@@ -81,7 +85,7 @@ def get_vocoder(config, device):
81
 
82
  return vocoder
83
 
84
-
85
  def vocoder_infer(mels, vocoder, model_config, preprocess_config, lengths=None):
86
  name = model_config["vocoder"]["model"]
87
  with torch.no_grad():
 
1
+ import spaces
2
+
3
  import os
4
  import json
5
 
 
11
 
12
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
 
14
+ @spaces.GPU(duration=10)
15
  def get_model(args, configs, device, train=False):
16
  (preprocess_config, model_config, train_config) = configs
17
 
 
37
  model.requires_grad_ = False
38
  return model
39
 
40
+ @spaces.GPU(duration=10)
41
  def get_model_infer(ckpt_path, configs, device):
42
  (preprocess_config, model_config, train_config) = configs
43
 
 
53
  num_param = sum(param.numel() for param in model.parameters())
54
  return num_param
55
 
56
+ @spaces.GPU(duration=10)
57
  def get_vocoder(config, device):
58
  name = config["vocoder"]["model"]
59
  speaker = config["vocoder"]["speaker"]
 
85
 
86
  return vocoder
87
 
88
+ @spaces.GPU(duration=10)
89
  def vocoder_infer(mels, vocoder, model_config, preprocess_config, lengths=None):
90
  name = model_config["vocoder"]["model"]
91
  with torch.no_grad():