Spaces:
Running
Running
Update usr/diff/shallow_diffusion_tts.py
Browse files
usr/diff/shallow_diffusion_tts.py
CHANGED
|
@@ -15,7 +15,7 @@ from modules.fastspeech.fs2 import FastSpeech2
|
|
| 15 |
from modules.diffsinger_midi.fs2 import FastSpeech2MIDI
|
| 16 |
from utils.hparams import hparams
|
| 17 |
|
| 18 |
-
|
| 19 |
|
| 20 |
def exists(x):
|
| 21 |
return x is not None
|
|
@@ -227,11 +227,9 @@ class GaussianDiffusion(nn.Module):
|
|
| 227 |
|
| 228 |
return loss
|
| 229 |
|
| 230 |
-
@spaces.GPU(duration=120)
|
| 231 |
def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
|
| 232 |
ref_mels=None, f0=None, uv=None, energy=None, infer=False, **kwargs):
|
| 233 |
-
b, *_, device = *txt_tokens.shape,
|
| 234 |
-
print(f"在{device}上运行")
|
| 235 |
ret = self.fs2(txt_tokens, mel2ph, spk_embed, ref_mels, f0, uv, energy,
|
| 236 |
skip_decoder=(not infer), infer=infer, **kwargs)
|
| 237 |
cond = ret['decoder_inp'].transpose(1, 2)
|
|
@@ -258,7 +256,6 @@ class GaussianDiffusion(nn.Module):
|
|
| 258 |
x = torch.randn(shape, device=device)
|
| 259 |
|
| 260 |
if hparams.get('pndm_speedup'):
|
| 261 |
-
print("pndm_speedup 加速中...")
|
| 262 |
self.noise_list = deque(maxlen=4)
|
| 263 |
iteration_interval = hparams['pndm_speedup']
|
| 264 |
for i in tqdm(reversed(range(0, t, iteration_interval)), desc='sample time step',
|
|
|
|
| 15 |
from modules.diffsinger_midi.fs2 import FastSpeech2MIDI
|
| 16 |
from utils.hparams import hparams
|
| 17 |
|
| 18 |
+
|
| 19 |
|
| 20 |
def exists(x):
|
| 21 |
return x is not None
|
|
|
|
| 227 |
|
| 228 |
return loss
|
| 229 |
|
|
|
|
| 230 |
def forward(self, txt_tokens, mel2ph=None, spk_embed=None,
|
| 231 |
ref_mels=None, f0=None, uv=None, energy=None, infer=False, **kwargs):
|
| 232 |
+
b, *_, device = *txt_tokens.shape, txt_tokens.device
|
|
|
|
| 233 |
ret = self.fs2(txt_tokens, mel2ph, spk_embed, ref_mels, f0, uv, energy,
|
| 234 |
skip_decoder=(not infer), infer=infer, **kwargs)
|
| 235 |
cond = ret['decoder_inp'].transpose(1, 2)
|
|
|
|
| 256 |
x = torch.randn(shape, device=device)
|
| 257 |
|
| 258 |
if hparams.get('pndm_speedup'):
|
|
|
|
| 259 |
self.noise_list = deque(maxlen=4)
|
| 260 |
iteration_interval = hparams['pndm_speedup']
|
| 261 |
for i in tqdm(reversed(range(0, t, iteration_interval)), desc='sample time step',
|