Spaces:
Runtime error
Runtime error
Commit · 7e4b346 · 1 Parent(s): c9ec30f
Remove umx
Browse files
- remfx/models.py +0 -49
remfx/models.py
CHANGED
|
@@ -6,7 +6,6 @@ from torch import Tensor, nn
|
|
| 6 |
from torchaudio.models import HDemucs
|
| 7 |
from auraloss.time import SISDRLoss
|
| 8 |
from auraloss.freq import MultiResolutionSTFTLoss
|
| 9 |
-
from umx.openunmix.model import OpenUnmix, Separator
|
| 10 |
|
| 11 |
from remfx.utils import spectrogram
|
| 12 |
from remfx.tcn import TCN
|
|
@@ -256,54 +255,6 @@ class RemFX(pl.LightningModule):
|
|
| 256 |
return loss
|
| 257 |
|
| 258 |
|
| 259 |
-
class OpenUnmixModel(nn.Module):
|
| 260 |
-
def __init__(
|
| 261 |
-
self,
|
| 262 |
-
n_fft: int = 2048,
|
| 263 |
-
hop_length: int = 512,
|
| 264 |
-
n_channels: int = 1,
|
| 265 |
-
alpha: float = 0.3,
|
| 266 |
-
sample_rate: int = 22050,
|
| 267 |
-
):
|
| 268 |
-
super().__init__()
|
| 269 |
-
self.n_channels = n_channels
|
| 270 |
-
self.n_fft = n_fft
|
| 271 |
-
self.hop_length = hop_length
|
| 272 |
-
self.alpha = alpha
|
| 273 |
-
window = torch.hann_window(n_fft)
|
| 274 |
-
self.register_buffer("window", window)
|
| 275 |
-
|
| 276 |
-
self.num_bins = self.n_fft // 2 + 1
|
| 277 |
-
self.sample_rate = sample_rate
|
| 278 |
-
self.model = OpenUnmix(
|
| 279 |
-
nb_channels=self.n_channels,
|
| 280 |
-
nb_bins=self.num_bins,
|
| 281 |
-
)
|
| 282 |
-
self.separator = Separator(
|
| 283 |
-
target_models={"other": self.model},
|
| 284 |
-
nb_channels=self.n_channels,
|
| 285 |
-
sample_rate=self.sample_rate,
|
| 286 |
-
n_fft=self.n_fft,
|
| 287 |
-
n_hop=self.hop_length,
|
| 288 |
-
)
|
| 289 |
-
self.mrstftloss = MultiResolutionSTFTLoss(
|
| 290 |
-
n_bins=self.num_bins, sample_rate=self.sample_rate
|
| 291 |
-
)
|
| 292 |
-
self.l1loss = nn.L1Loss()
|
| 293 |
-
|
| 294 |
-
def forward(self, batch):
|
| 295 |
-
x, target = batch
|
| 296 |
-
X = spectrogram(x, self.window, self.n_fft, self.hop_length, self.alpha)
|
| 297 |
-
Y = self.model(X)
|
| 298 |
-
sep_out = self.separator(x).squeeze(1)
|
| 299 |
-
loss = self.mrstftloss(sep_out, target) + self.l1loss(sep_out, target) * 100
|
| 300 |
-
|
| 301 |
-
return loss, sep_out
|
| 302 |
-
|
| 303 |
-
def sample(self, x: Tensor) -> Tensor:
|
| 304 |
-
return self.separator(x).squeeze(1)
|
| 305 |
-
|
| 306 |
-
|
| 307 |
class DemucsModel(nn.Module):
|
| 308 |
def __init__(self, sample_rate, **kwargs) -> None:
|
| 309 |
super().__init__()
|
|
|
|
| 6 |
from torchaudio.models import HDemucs
|
| 7 |
from auraloss.time import SISDRLoss
|
| 8 |
from auraloss.freq import MultiResolutionSTFTLoss
|
|
|
|
| 9 |
|
| 10 |
from remfx.utils import spectrogram
|
| 11 |
from remfx.tcn import TCN
|
|
|
|
| 255 |
return loss
|
| 256 |
|
| 257 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 258 |
class DemucsModel(nn.Module):
|
| 259 |
def __init__(self, sample_rate, **kwargs) -> None:
|
| 260 |
super().__init__()
|