Spaces:
Runtime error
Runtime error
Commit ·
d8f7979
1
Parent(s): 12e94ae
Restore classifier, move shell scripts to scripts
Browse files
- README.md +12 -12
- remfx/classifier.py +19 -21
- remfx/models.py +1 -46
- scripts/chain_inference.py +2 -0
- download_ckpts.sh → scripts/download_ckpts.sh +0 -0
- download_eval_datasets.sh → scripts/download_eval_datasets.sh +0 -0
- eval.sh → scripts/eval.sh +3 -3
- remfx_detect.sh → scripts/remfx_detect.sh +1 -1
README.md
CHANGED
|
@@ -16,12 +16,12 @@ This repo can be used for many different tasks. Here are some examples.
|
|
| 16 |
## Run RemFX Detect on a single file
|
| 17 |
First, download the checkpoints from [zenodo](https://zenodo.org/record/8179396)
|
| 18 |
```
|
| 19 |
-
|
| 20 |
-
|
| 21 |
```
|
| 22 |
## Download the [General Purpose Audio Effect Removal evaluation datasets](https://zenodo.org/record/8187288)
|
| 23 |
```
|
| 24 |
-
|
| 25 |
```
|
| 26 |
|
| 27 |
## Download the starter datasets
|
|
@@ -73,28 +73,28 @@ Also note that the training assumes you have a GPU. To train on CPU, set `accele
|
|
| 73 |
First download the General Purpose Audio Effect Removal evaluation datasets (see above).
|
| 74 |
To use the pretrained RemFX model, download the checkpoints
|
| 75 |
```
|
| 76 |
-
|
| 77 |
```
|
| 78 |
Then run the evaluation script, choosing the RemFX configuration (one of `remfx_oracle`, `remfx_detect`, or `remfx_all`) and N, the number of effects to remove.
|
| 79 |
```
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
|
| 87 |
```
|
| 88 |
To eval a custom monolithic model, first train a model (see Training)
|
| 89 |
Then run the evaluation script, with the config used and checkpoint_path.
|
| 90 |
```
|
| 91 |
-
|
| 92 |
```
|
| 93 |
|
| 94 |
To eval a custom effect-specific model as part of the inference chain, first train a model (see Training), then edit `cfg/exp/remfx_{desired_configuration}.yaml -> ckpts -> {effect}`.
|
| 95 |
Then run the evaluation script.
|
| 96 |
```
|
| 97 |
-
|
| 98 |
```
|
| 99 |
|
| 100 |
The script assumes that RemFX_eval_datasets is in the top-level directory.
|
|
|
|
| 16 |
## Run RemFX Detect on a single file
|
| 17 |
First, download the checkpoints from [zenodo](https://zenodo.org/record/8179396)
|
| 18 |
```
|
| 19 |
+
scripts/download_ckpts.sh
|
| 20 |
+
scripts/remfx_detect.sh wet.wav -o dry.wav
|
| 21 |
```
|
| 22 |
## Download the [General Purpose Audio Effect Removal evaluation datasets](https://zenodo.org/record/8187288)
|
| 23 |
```
|
| 24 |
+
scripts/download_eval_datasets.sh
|
| 25 |
```
|
| 26 |
|
| 27 |
## Download the starter datasets
|
|
|
|
| 73 |
First download the General Purpose Audio Effect Removal evaluation datasets (see above).
|
| 74 |
To use the pretrained RemFX model, download the checkpoints
|
| 75 |
```
|
| 76 |
+
scripts/download_ckpts.sh
|
| 77 |
```
|
| 78 |
Then run the evaluation script, choosing the RemFX configuration (one of `remfx_oracle`, `remfx_detect`, or `remfx_all`) and N, the number of effects to remove.
|
| 79 |
```
|
| 80 |
+
scripts/eval.sh remfx_detect 0-0
|
| 81 |
+
scripts/eval.sh remfx_detect 1-1
|
| 82 |
+
scripts/eval.sh remfx_detect 2-2
|
| 83 |
+
scripts/eval.sh remfx_detect 3-3
|
| 84 |
+
scripts/eval.sh remfx_detect 4-4
|
| 85 |
+
scripts/eval.sh remfx_detect 5-5
|
| 86 |
|
| 87 |
```
|
| 88 |
To eval a custom monolithic model, first train a model (see Training)
|
| 89 |
Then run the evaluation script, with the config used and checkpoint_path.
|
| 90 |
```
|
| 91 |
+
scripts/eval.sh distortion_aug 0-0 -ckpt "logs/ckpts/2023-07-26-10-10-27/epoch\=05-valid_loss\=8.623.ckpt"
|
| 92 |
```
|
| 93 |
|
| 94 |
To eval a custom effect-specific model as part of the inference chain, first train a model (see Training), then edit `cfg/exp/remfx_{desired_configuration}.yaml -> ckpts -> {effect}`.
|
| 95 |
Then run the evaluation script.
|
| 96 |
```
|
| 97 |
+
scripts/eval.sh remfx_detect 0-0
|
| 98 |
```
|
| 99 |
|
| 100 |
The script assumes that RemFX_eval_datasets is in the top-level directory.
|
remfx/classifier.py
CHANGED
|
@@ -1,11 +1,9 @@
|
|
| 1 |
import torch
|
| 2 |
import torchaudio
|
| 3 |
import torch.nn as nn
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
# import hearbaseline.vggish
|
| 8 |
-
# import hearbaseline.wav2vec2
|
| 9 |
|
| 10 |
import wav2clip_hear
|
| 11 |
import panns_hear
|
|
@@ -173,10 +171,10 @@ class Cnn14(nn.Module):
|
|
| 173 |
|
| 174 |
self.fc1 = nn.Linear(2048, 2048, bias=True)
|
| 175 |
|
| 176 |
-
self.fc_audioset = nn.Linear(2048, num_classes, bias=True)
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
|
| 181 |
self.init_weight()
|
| 182 |
|
|
@@ -192,7 +190,7 @@ class Cnn14(nn.Module):
|
|
| 192 |
def init_weight(self):
|
| 193 |
init_bn(self.bn0)
|
| 194 |
init_layer(self.fc1)
|
| 195 |
-
init_layer(self.fc_audioset)
|
| 196 |
|
| 197 |
def forward(self, x: torch.Tensor, train: bool = False):
|
| 198 |
"""
|
|
@@ -212,12 +210,12 @@ class Cnn14(nn.Module):
|
|
| 212 |
# axs[1].imshow(x[0, :, :, :].detach().squeeze().cpu().numpy())
|
| 213 |
# plt.savefig("spec_augment.png", dpi=300)
|
| 214 |
|
| 215 |
-
x = x.permute(0, 2, 1, 3)
|
| 216 |
-
x = self.bn0(x)
|
| 217 |
-
x = x.permute(0, 2, 1, 3)
|
| 218 |
|
| 219 |
# apply standardization
|
| 220 |
-
|
| 221 |
|
| 222 |
x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")
|
| 223 |
x = F.dropout(x, p=0.2, training=train)
|
|
@@ -239,13 +237,13 @@ class Cnn14(nn.Module):
|
|
| 239 |
x = F.dropout(x, p=0.5, training=train)
|
| 240 |
x = F.relu_(self.fc1(x))
|
| 241 |
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
|
|
|
|
|
|
| 245 |
|
| 246 |
-
|
| 247 |
-
return clipwise_output
|
| 248 |
-
# return outputs
|
| 249 |
|
| 250 |
|
| 251 |
class ConvBlock(nn.Module):
|
|
@@ -296,4 +294,4 @@ class ConvBlock(nn.Module):
|
|
| 296 |
else:
|
| 297 |
raise Exception("Incorrect argument!")
|
| 298 |
|
| 299 |
-
return x
|
|
|
|
| 1 |
import torch
|
| 2 |
import torchaudio
|
| 3 |
import torch.nn as nn
|
| 4 |
+
import hearbaseline
|
| 5 |
+
import hearbaseline.vggish
|
| 6 |
+
import hearbaseline.wav2vec2
|
|
|
|
|
|
|
| 7 |
|
| 8 |
import wav2clip_hear
|
| 9 |
import panns_hear
|
|
|
|
| 171 |
|
| 172 |
self.fc1 = nn.Linear(2048, 2048, bias=True)
|
| 173 |
|
| 174 |
+
# self.fc_audioset = nn.Linear(2048, num_classes, bias=True)
|
| 175 |
+
self.heads = torch.nn.ModuleList()
|
| 176 |
+
for _ in range(num_classes):
|
| 177 |
+
self.heads.append(nn.Linear(2048, 1, bias=True))
|
| 178 |
|
| 179 |
self.init_weight()
|
| 180 |
|
|
|
|
| 190 |
def init_weight(self):
|
| 191 |
init_bn(self.bn0)
|
| 192 |
init_layer(self.fc1)
|
| 193 |
+
# init_layer(self.fc_audioset)
|
| 194 |
|
| 195 |
def forward(self, x: torch.Tensor, train: bool = False):
|
| 196 |
"""
|
|
|
|
| 210 |
# axs[1].imshow(x[0, :, :, :].detach().squeeze().cpu().numpy())
|
| 211 |
# plt.savefig("spec_augment.png", dpi=300)
|
| 212 |
|
| 213 |
+
# x = x.permute(0, 2, 1, 3)
|
| 214 |
+
# x = self.bn0(x)
|
| 215 |
+
# x = x.permute(0, 2, 1, 3)
|
| 216 |
|
| 217 |
# apply standardization
|
| 218 |
+
x = (x - x.mean(dim=0, keepdim=True)) / x.std(dim=0, keepdim=True)
|
| 219 |
|
| 220 |
x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")
|
| 221 |
x = F.dropout(x, p=0.2, training=train)
|
|
|
|
| 237 |
x = F.dropout(x, p=0.5, training=train)
|
| 238 |
x = F.relu_(self.fc1(x))
|
| 239 |
|
| 240 |
+
outputs = []
|
| 241 |
+
for head in self.heads:
|
| 242 |
+
outputs.append(torch.sigmoid(head(x)))
|
| 243 |
+
|
| 244 |
+
# clipwise_output = self.fc_audioset(x)
|
| 245 |
|
| 246 |
+
return outputs
|
|
|
|
|
|
|
| 247 |
|
| 248 |
|
| 249 |
class ConvBlock(nn.Module):
|
|
|
|
| 294 |
else:
|
| 295 |
raise Exception("Incorrect argument!")
|
| 296 |
|
| 297 |
+
return x
|
remfx/models.py
CHANGED
|
@@ -143,17 +143,8 @@ class RemFXChainInference(pl.LightningModule):
|
|
| 143 |
prog_bar=True,
|
| 144 |
sync_dist=True,
|
| 145 |
)
|
| 146 |
-
# print(f"Input_{metric}", negate * self.metrics[metric](x, y))
|
| 147 |
-
# print(f"test_{metric}", negate * self.metrics[metric](output, y))
|
| 148 |
-
# self.output_str += f"{negate * self.metrics[metric](x, y).item():.4f},{negate * self.metrics[metric](output, y).item():.4f},"
|
| 149 |
-
# self.output_str += "\n"
|
| 150 |
return loss
|
| 151 |
|
| 152 |
-
def on_test_end(self) -> None:
|
| 153 |
-
pass
|
| 154 |
-
# with open("output.csv", "w") as f:
|
| 155 |
-
# f.write(self.output_str)
|
| 156 |
-
|
| 157 |
def sample(self, batch):
|
| 158 |
return self.forward(batch, 0)[1]
|
| 159 |
|
|
@@ -438,7 +429,6 @@ def mixup(x: torch.Tensor, y: torch.Tensor, alpha: float = 1.0):
|
|
| 438 |
|
| 439 |
return mixed_x, mixed_y, lam
|
| 440 |
|
| 441 |
-
|
| 442 |
class FXClassifier(pl.LightningModule):
|
| 443 |
def __init__(
|
| 444 |
self,
|
|
@@ -458,42 +448,7 @@ class FXClassifier(pl.LightningModule):
|
|
| 458 |
self.mixup = mixup
|
| 459 |
self.label_smoothing = label_smoothing
|
| 460 |
|
| 461 |
-
self.loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=label_smoothing)
|
| 462 |
self.loss_fn = torch.nn.BCELoss()
|
| 463 |
-
|
| 464 |
-
if False:
|
| 465 |
-
self.train_f1 = torchmetrics.classification.MultilabelF1Score(
|
| 466 |
-
5, average="none", multidim_average="global"
|
| 467 |
-
)
|
| 468 |
-
self.val_f1 = torchmetrics.classification.MultilabelF1Score(
|
| 469 |
-
5, average="none", multidim_average="global"
|
| 470 |
-
)
|
| 471 |
-
self.test_f1 = torchmetrics.classification.MultilabelF1Score(
|
| 472 |
-
5, average="none", multidim_average="global"
|
| 473 |
-
)
|
| 474 |
-
|
| 475 |
-
self.train_f1_avg = torchmetrics.classification.MultilabelF1Score(
|
| 476 |
-
5, threshold=0.5, average="macro", multidim_average="global"
|
| 477 |
-
)
|
| 478 |
-
self.val_f1_avg = torchmetrics.classification.MultilabelF1Score(
|
| 479 |
-
5, threshold=0.5, average="macro", multidim_average="global"
|
| 480 |
-
)
|
| 481 |
-
self.test_f1_avg = torchmetrics.classification.MultilabelF1Score(
|
| 482 |
-
5, threshold=0.5, average="macro", multidim_average="global"
|
| 483 |
-
)
|
| 484 |
-
|
| 485 |
-
self.metrics = {
|
| 486 |
-
"train": self.train_acc,
|
| 487 |
-
"valid": self.val_acc,
|
| 488 |
-
"test": self.test_acc,
|
| 489 |
-
}
|
| 490 |
-
|
| 491 |
-
self.avg_metrics = {
|
| 492 |
-
"train": self.train_f1_avg,
|
| 493 |
-
"valid": self.val_f1_avg,
|
| 494 |
-
"test": self.test_f1_avg,
|
| 495 |
-
}
|
| 496 |
-
|
| 497 |
self.metrics = torch.nn.ModuleDict()
|
| 498 |
for effect in self.effects:
|
| 499 |
self.metrics[f"train_{effect}_acc"] = torchmetrics.classification.Accuracy(
|
|
@@ -578,4 +533,4 @@ class FXClassifier(pl.LightningModule):
|
|
| 578 |
lr=self.lr,
|
| 579 |
weight_decay=self.lr_weight_decay,
|
| 580 |
)
|
| 581 |
-
return optimizer
|
|
|
|
| 143 |
prog_bar=True,
|
| 144 |
sync_dist=True,
|
| 145 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
return loss
|
| 147 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
def sample(self, batch):
|
| 149 |
return self.forward(batch, 0)[1]
|
| 150 |
|
|
|
|
| 429 |
|
| 430 |
return mixed_x, mixed_y, lam
|
| 431 |
|
|
|
|
| 432 |
class FXClassifier(pl.LightningModule):
|
| 433 |
def __init__(
|
| 434 |
self,
|
|
|
|
| 448 |
self.mixup = mixup
|
| 449 |
self.label_smoothing = label_smoothing
|
| 450 |
|
|
|
|
| 451 |
self.loss_fn = torch.nn.BCELoss()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 452 |
self.metrics = torch.nn.ModuleDict()
|
| 453 |
for effect in self.effects:
|
| 454 |
self.metrics[f"train_{effect}_acc"] = torchmetrics.classification.Accuracy(
|
|
|
|
| 533 |
lr=self.lr,
|
| 534 |
weight_decay=self.lr_weight_decay,
|
| 535 |
)
|
| 536 |
+
return optimizer
|
scripts/chain_inference.py
CHANGED
|
@@ -45,6 +45,7 @@ def main(cfg: DictConfig):
|
|
| 45 |
|
| 46 |
logger = hydra.utils.instantiate(cfg.logger, _convert_="partial")
|
| 47 |
log.info(f"Instantiating trainer <{cfg.trainer._target_}>.")
|
|
|
|
| 48 |
trainer = hydra.utils.instantiate(
|
| 49 |
cfg.trainer, callbacks=callbacks, logger=logger, _convert_="partial"
|
| 50 |
)
|
|
@@ -68,6 +69,7 @@ def main(cfg: DictConfig):
|
|
| 68 |
shuffle_effect_order=cfg.inference_effects_shuffle,
|
| 69 |
use_all_effect_models=cfg.inference_use_all_effect_models,
|
| 70 |
)
|
|
|
|
| 71 |
trainer.test(model=inference_model, datamodule=datamodule)
|
| 72 |
|
| 73 |
|
|
|
|
| 45 |
|
| 46 |
logger = hydra.utils.instantiate(cfg.logger, _convert_="partial")
|
| 47 |
log.info(f"Instantiating trainer <{cfg.trainer._target_}>.")
|
| 48 |
+
cfg.trainer.accelerator = "gpu" if torch.cuda.is_available() else "cpu"
|
| 49 |
trainer = hydra.utils.instantiate(
|
| 50 |
cfg.trainer, callbacks=callbacks, logger=logger, _convert_="partial"
|
| 51 |
)
|
|
|
|
| 69 |
shuffle_effect_order=cfg.inference_effects_shuffle,
|
| 70 |
use_all_effect_models=cfg.inference_use_all_effect_models,
|
| 71 |
)
|
| 72 |
+
|
| 73 |
trainer.test(model=inference_model, datamodule=datamodule)
|
| 74 |
|
| 75 |
|
download_ckpts.sh → scripts/download_ckpts.sh
RENAMED
|
File without changes
|
download_eval_datasets.sh → scripts/download_eval_datasets.sh
RENAMED
|
File without changes
|
eval.sh → scripts/eval.sh
RENAMED
|
@@ -1,13 +1,13 @@
|
|
| 1 |
#! /bin/bash
|
| 2 |
|
| 3 |
# Example usage:
|
| 4 |
-
#
|
| 5 |
-
#
|
| 6 |
# First 2 arguments are required, third argument is optional
|
| 7 |
|
| 8 |
# Default value for the optional parameter
|
| 9 |
ckpt_path=""
|
| 10 |
-
|
| 11 |
# Function to display script usage
|
| 12 |
function display_usage {
|
| 13 |
echo "Usage: $0 <experiment> <dataset> [-ckpt {ckpt_path}]"
|
|
|
|
| 1 |
#! /bin/bash
|
| 2 |
|
| 3 |
# Example usage:
|
| 4 |
+
# scripts/eval.sh remfx_detect 0-0
|
| 5 |
+
# scripts/eval.sh distortion_aug 0-0 -ckpt logs/ckpts/2023-01-21-12-21-44
|
| 6 |
# First 2 arguments are required, third argument is optional
|
| 7 |
|
| 8 |
# Default value for the optional parameter
|
| 9 |
ckpt_path=""
|
| 10 |
+
export DATASET_ROOT=RemFX_eval_datasets
|
| 11 |
# Function to display script usage
|
| 12 |
function display_usage {
|
| 13 |
echo "Usage: $0 <experiment> <dataset> [-ckpt {ckpt_path}]"
|
remfx_detect.sh → scripts/remfx_detect.sh
RENAMED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
#! /bin/bash
|
| 2 |
|
| 3 |
# Example usage:
|
| 4 |
-
#
|
| 5 |
# first argument is required, second argument is optional
|
| 6 |
|
| 7 |
# Check if first argument is empty
|
|
|
|
| 1 |
#! /bin/bash
|
| 2 |
|
| 3 |
# Example usage:
|
| 4 |
+
# scripts/remfx_detect.sh wet.wav -o examples/output.wav
|
| 5 |
# first argument is required, second argument is optional
|
| 6 |
|
| 7 |
# Check if first argument is empty
|