Spaces:
Runtime error
Runtime error
simplify (#1)
Browse files- simplify (fd5dde034da333a19a1b05e6a56a80a5f9803b61)
- Utils/JDC/model.py +1 -1
- app.py +2 -3
- models.py +11 -76
Utils/JDC/model.py
CHANGED
|
@@ -134,7 +134,7 @@ class JDCNet(nn.Module):
|
|
| 134 |
# sizes: (b, 31, 722), (b, 31, 2)
|
| 135 |
# classifier output consists of predicted pitch classes per frame
|
| 136 |
# detector output consists of: (isvoice, notvoice) estimates per frame
|
| 137 |
-
return torch.abs(classifier_out.squeeze()), GAN_feature, poolblock_out
|
| 138 |
|
| 139 |
@staticmethod
|
| 140 |
def init_weights(m):
|
|
|
|
| 134 |
# sizes: (b, 31, 722), (b, 31, 2)
|
| 135 |
# classifier output consists of predicted pitch classes per frame
|
| 136 |
# detector output consists of: (isvoice, notvoice) estimates per frame
|
| 137 |
+
return torch.abs(classifier_out.squeeze(-1)), GAN_feature, poolblock_out
|
| 138 |
|
| 139 |
@staticmethod
|
| 140 |
def init_weights(m):
|
app.py
CHANGED
|
@@ -13,7 +13,6 @@ from transformers import WavLMModel
|
|
| 13 |
from env import AttrDict
|
| 14 |
from meldataset import mel_spectrogram, MAX_WAV_VALUE
|
| 15 |
from models import Generator
|
| 16 |
-
from stft import TorchSTFT
|
| 17 |
from Utils.JDC.model import JDCNet
|
| 18 |
|
| 19 |
|
|
@@ -38,7 +37,6 @@ h = AttrDict(json_config)
|
|
| 38 |
# load models
|
| 39 |
F0_model = JDCNet(num_class=1, seq_len=192)
|
| 40 |
generator = Generator(h, F0_model).to(device)
|
| 41 |
-
stft = TorchSTFT(filter_length=h.gen_istft_n_fft, hop_length=h.gen_istft_hop_size, win_length=h.gen_istft_n_fft).to(device)
|
| 42 |
|
| 43 |
state_dict_g = torch.load(ptfile, map_location=device)
|
| 44 |
generator.load_state_dict(state_dict_g['generator'], strict=True)
|
|
@@ -84,6 +82,7 @@ def convert(tgt_spk, src_wav, f0_shift=0):
|
|
| 84 |
spk_emb = torch.from_numpy(spk_emb).unsqueeze(0).to(device)
|
| 85 |
|
| 86 |
f0_mean_tgt = f0_stats[tgt_spk]["mean"]
|
|
|
|
| 87 |
|
| 88 |
# src
|
| 89 |
wav, sr = librosa.load(src_wav, sr=16000)
|
|
@@ -98,7 +97,7 @@ def convert(tgt_spk, src_wav, f0_shift=0):
|
|
| 98 |
f0 = generator.get_f0(mel, f0_mean_tgt)
|
| 99 |
f0 = tune_f0(f0, f0_shift)
|
| 100 |
x = generator.get_x(x, spk_emb, spk_id)
|
| 101 |
-
y = generator.infer(x, f0, stft)
|
| 102 |
|
| 103 |
audio = y.squeeze()
|
| 104 |
audio = audio / torch.max(torch.abs(audio)) * 0.95
|
|
|
|
| 13 |
from env import AttrDict
|
| 14 |
from meldataset import mel_spectrogram, MAX_WAV_VALUE
|
| 15 |
from models import Generator
|
|
|
|
| 16 |
from Utils.JDC.model import JDCNet
|
| 17 |
|
| 18 |
|
|
|
|
| 37 |
# load models
|
| 38 |
F0_model = JDCNet(num_class=1, seq_len=192)
|
| 39 |
generator = Generator(h, F0_model).to(device)
|
|
|
|
| 40 |
|
| 41 |
state_dict_g = torch.load(ptfile, map_location=device)
|
| 42 |
generator.load_state_dict(state_dict_g['generator'], strict=True)
|
|
|
|
| 82 |
spk_emb = torch.from_numpy(spk_emb).unsqueeze(0).to(device)
|
| 83 |
|
| 84 |
f0_mean_tgt = f0_stats[tgt_spk]["mean"]
|
| 85 |
+
f0_mean_tgt = torch.FloatTensor([f0_mean_tgt]).unsqueeze(0).to(device)
|
| 86 |
|
| 87 |
# src
|
| 88 |
wav, sr = librosa.load(src_wav, sr=16000)
|
|
|
|
| 97 |
f0 = generator.get_f0(mel, f0_mean_tgt)
|
| 98 |
f0 = tune_f0(f0, f0_shift)
|
| 99 |
x = generator.get_x(x, spk_emb, spk_id)
|
| 100 |
+
y = generator.infer(x, f0)
|
| 101 |
|
| 102 |
audio = y.squeeze()
|
| 103 |
audio = audio / torch.max(torch.abs(audio)) * 0.95
|
models.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
import math
|
| 2 |
import torch
|
| 3 |
import torch.nn.functional as F
|
| 4 |
import torch.nn as nn
|
|
@@ -486,9 +485,6 @@ class Generator(torch.nn.Module):
|
|
| 486 |
g = g + spk_emb.unsqueeze(-1)
|
| 487 |
|
| 488 |
f0, _, _ = self.F0_model(mel.unsqueeze(1))
|
| 489 |
-
if len(f0.shape) == 1:
|
| 490 |
-
f0 = f0.unsqueeze(0)
|
| 491 |
-
|
| 492 |
f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
|
| 493 |
|
| 494 |
har_source, _, _ = self.m_source(f0)
|
|
@@ -526,28 +522,21 @@ class Generator(torch.nn.Module):
|
|
| 526 |
|
| 527 |
return spec, phase
|
| 528 |
|
| 529 |
-
def get_f0(self, mel, f0_mean_tgt, voiced_threshold=10):
|
| 530 |
f0, _, _ = self.F0_model(mel.unsqueeze(1))
|
| 531 |
-
|
| 532 |
voiced = f0 > voiced_threshold
|
| 533 |
|
| 534 |
lf0 = torch.log(f0)
|
| 535 |
-
|
| 536 |
-
|
|
|
|
|
|
|
| 537 |
f0_adj = torch.exp(lf0_adj)
|
| 538 |
|
| 539 |
-
|
| 540 |
-
|
| 541 |
-
|
| 542 |
-
|
| 543 |
-
f0_adj = self.interp_f0(f0_adj.unsqueeze(0), voiced.unsqueeze(0)).squeeze(0)
|
| 544 |
-
energy = torch.sum(mel.squeeze(0), dim=0) # simple vad
|
| 545 |
-
unsilent = energy > -700
|
| 546 |
-
unsilent = unsilent | voiced
|
| 547 |
-
f0_adj = torch.where(unsilent, f0_adj, 0)
|
| 548 |
-
|
| 549 |
-
if len(f0_adj.shape) == 1:
|
| 550 |
-
f0_adj = f0_adj.unsqueeze(0)
|
| 551 |
|
| 552 |
return f0_adj
|
| 553 |
|
|
@@ -562,7 +551,7 @@ class Generator(torch.nn.Module):
|
|
| 562 |
|
| 563 |
return x
|
| 564 |
|
| 565 |
-
def infer(self, x, f0, stft):
|
| 566 |
f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
|
| 567 |
|
| 568 |
har_source, _, _ = self.m_source(f0)
|
|
@@ -593,62 +582,8 @@ class Generator(torch.nn.Module):
|
|
| 593 |
spec = torch.exp(x[:,:self.post_n_fft // 2 + 1, :])
|
| 594 |
phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :])
|
| 595 |
|
| 596 |
-
y = stft.inverse(spec, phase)
|
| 597 |
-
|
| 598 |
return y
|
| 599 |
-
|
| 600 |
-
def interp_f0(self, pitch, voiced):
|
| 601 |
-
"""Fill unvoiced regions via linear interpolation"""
|
| 602 |
-
|
| 603 |
-
# Handle no voiced frames
|
| 604 |
-
if not voiced.any():
|
| 605 |
-
return pitch
|
| 606 |
-
|
| 607 |
-
# Pitch is linear in base-2 log-space
|
| 608 |
-
pitch = torch.log2(pitch)
|
| 609 |
-
|
| 610 |
-
# Anchor endpoints
|
| 611 |
-
pitch[..., 0] = pitch[voiced][..., 0]
|
| 612 |
-
pitch[..., -1] = pitch[voiced][..., -1]
|
| 613 |
-
voiced[..., 0] = True
|
| 614 |
-
voiced[..., -1] = True
|
| 615 |
-
|
| 616 |
-
# Interpolate
|
| 617 |
-
pitch[~voiced] = self.interp(
|
| 618 |
-
torch.where(~voiced[0])[0][None],
|
| 619 |
-
torch.where(voiced[0])[0][None],
|
| 620 |
-
pitch[voiced][None])
|
| 621 |
-
|
| 622 |
-
return 2 ** pitch
|
| 623 |
-
|
| 624 |
-
@staticmethod
|
| 625 |
-
def interp(x, xp, fp):
|
| 626 |
-
"""1D linear interpolation for monotonically increasing sample points"""
|
| 627 |
-
# Handle edge cases
|
| 628 |
-
if xp.shape[-1] == 0:
|
| 629 |
-
return x
|
| 630 |
-
if xp.shape[-1] == 1:
|
| 631 |
-
return torch.full(
|
| 632 |
-
x.shape,
|
| 633 |
-
fp.squeeze(),
|
| 634 |
-
device=fp.device,
|
| 635 |
-
dtype=fp.dtype)
|
| 636 |
-
|
| 637 |
-
# Get slope and intercept using right-side first-differences
|
| 638 |
-
m = (fp[:, 1:] - fp[:, :-1]) / (xp[:, 1:] - xp[:, :-1])
|
| 639 |
-
b = fp[:, :-1] - (m.mul(xp[:, :-1]))
|
| 640 |
-
|
| 641 |
-
# Get indices to sample slope and intercept
|
| 642 |
-
indicies = torch.sum(torch.ge(x[:, :, None], xp[:, None, :]), -1) - 1
|
| 643 |
-
indicies = torch.clamp(indicies, 0, m.shape[-1] - 1)
|
| 644 |
-
line_idx = torch.linspace(
|
| 645 |
-
0,
|
| 646 |
-
indicies.shape[0],
|
| 647 |
-
1,
|
| 648 |
-
device=indicies.device).to(torch.long).expand(indicies.shape)
|
| 649 |
-
|
| 650 |
-
# Interpolate
|
| 651 |
-
return m[line_idx, indicies].mul(x) + b[line_idx, indicies]
|
| 652 |
|
| 653 |
def remove_weight_norm(self):
|
| 654 |
print('Removing weight norm...')
|
|
|
|
|
|
|
| 1 |
import torch
|
| 2 |
import torch.nn.functional as F
|
| 3 |
import torch.nn as nn
|
|
|
|
| 485 |
g = g + spk_emb.unsqueeze(-1)
|
| 486 |
|
| 487 |
f0, _, _ = self.F0_model(mel.unsqueeze(1))
|
|
|
|
|
|
|
|
|
|
| 488 |
f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
|
| 489 |
|
| 490 |
har_source, _, _ = self.m_source(f0)
|
|
|
|
| 522 |
|
| 523 |
return spec, phase
|
| 524 |
|
| 525 |
+
def get_f0(self, mel, f0_mean_tgt, voiced_threshold=10):
|
| 526 |
f0, _, _ = self.F0_model(mel.unsqueeze(1))
|
|
|
|
| 527 |
voiced = f0 > voiced_threshold
|
| 528 |
|
| 529 |
lf0 = torch.log(f0)
|
| 530 |
+
lf0_ = lf0 * voiced.float()
|
| 531 |
+
lf0_mean = lf0_.sum(1) / voiced.float().sum(1)
|
| 532 |
+
lf0_mean = lf0_mean.unsqueeze(1)
|
| 533 |
+
lf0_adj = lf0 - lf0_mean + torch.log(f0_mean_tgt)
|
| 534 |
f0_adj = torch.exp(lf0_adj)
|
| 535 |
|
| 536 |
+
energy = mel.sum(1)
|
| 537 |
+
unsilent = energy > -700
|
| 538 |
+
unsilent = unsilent | voiced # simple vad
|
| 539 |
+
f0_adj = f0_adj * unsilent.float()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 540 |
|
| 541 |
return f0_adj
|
| 542 |
|
|
|
|
| 551 |
|
| 552 |
return x
|
| 553 |
|
| 554 |
+
def infer(self, x, f0):
|
| 555 |
f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
|
| 556 |
|
| 557 |
har_source, _, _ = self.m_source(f0)
|
|
|
|
| 582 |
spec = torch.exp(x[:,:self.post_n_fft // 2 + 1, :])
|
| 583 |
phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :])
|
| 584 |
|
| 585 |
+
y = self.stft.inverse(spec, phase)
|
|
|
|
| 586 |
return y
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 587 |
|
| 588 |
def remove_weight_norm(self):
|
| 589 |
print('Removing weight norm...')
|