Spaces:
Runtime error
Runtime error
Commit ·
e8eaf47
1
Parent(s): 943f213
Add resume from checkpoint for training
Browse files- remfx/models.py +3 -3
- scripts/train.py +4 -0
remfx/models.py
CHANGED
|
@@ -43,7 +43,7 @@ class RemFXChainInference(pl.LightningModule):
|
|
| 43 |
effects_order = order
|
| 44 |
else:
|
| 45 |
effects_order = self.effect_order
|
| 46 |
-
|
| 47 |
[ALL_EFFECTS[i] for i, effect in enumerate(effect_label) if effect == 1.0]
|
| 48 |
for effect_label in rem_fx_labels
|
| 49 |
]
|
|
@@ -56,7 +56,7 @@ class RemFXChainInference(pl.LightningModule):
|
|
| 56 |
id="input_effected_audio",
|
| 57 |
samples=input_samples.cpu(),
|
| 58 |
sampling_rate=self.sample_rate,
|
| 59 |
-
caption=
|
| 60 |
)
|
| 61 |
log_wandb_audio_batch(
|
| 62 |
logger=self.logger,
|
|
@@ -66,7 +66,7 @@ class RemFXChainInference(pl.LightningModule):
|
|
| 66 |
caption="Target Data",
|
| 67 |
)
|
| 68 |
with torch.no_grad():
|
| 69 |
-
for i, (elem, effects_list) in enumerate(zip(x,
|
| 70 |
elem = elem.unsqueeze(0) # Add batch dim
|
| 71 |
# Get the correct effect by search for names in effects_order
|
| 72 |
effect_list_names = [effect.__name__ for effect in effects_list]
|
|
|
|
| 43 |
effects_order = order
|
| 44 |
else:
|
| 45 |
effects_order = self.effect_order
|
| 46 |
+
effects_present = [
|
| 47 |
[ALL_EFFECTS[i] for i, effect in enumerate(effect_label) if effect == 1.0]
|
| 48 |
for effect_label in rem_fx_labels
|
| 49 |
]
|
|
|
|
| 56 |
id="input_effected_audio",
|
| 57 |
samples=input_samples.cpu(),
|
| 58 |
sampling_rate=self.sample_rate,
|
| 59 |
+
caption="Input Data",
|
| 60 |
)
|
| 61 |
log_wandb_audio_batch(
|
| 62 |
logger=self.logger,
|
|
|
|
| 66 |
caption="Target Data",
|
| 67 |
)
|
| 68 |
with torch.no_grad():
|
| 69 |
+
for i, (elem, effects_list) in enumerate(zip(x, effects_present)):
|
| 70 |
elem = elem.unsqueeze(0) # Add batch dim
|
| 71 |
# Get the correct effect by search for names in effects_order
|
| 72 |
effect_list_names = [effect.__name__ for effect in effects_list]
|
scripts/train.py
CHANGED
|
@@ -16,6 +16,10 @@ def main(cfg: DictConfig):
|
|
| 16 |
log.info(f"Instantiating model <{cfg.model._target_}>.")
|
| 17 |
model = hydra.utils.instantiate(cfg.model, _convert_="partial")
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
# Init all callbacks
|
| 20 |
callbacks = []
|
| 21 |
if "callbacks" in cfg:
|
|
|
|
| 16 |
log.info(f"Instantiating model <{cfg.model._target_}>.")
|
| 17 |
model = hydra.utils.instantiate(cfg.model, _convert_="partial")
|
| 18 |
|
| 19 |
+
if "ckpt_path" in cfg:
|
| 20 |
+
log.info(f"Loading checkpoint from <{cfg.ckpt_path}>.")
|
| 21 |
+
model = model.load_from_checkpoint(cfg.ckpt_path)
|
| 22 |
+
|
| 23 |
# Init all callbacks
|
| 24 |
callbacks = []
|
| 25 |
if "callbacks" in cfg:
|