Spaces:
Runtime error
Runtime error
| import os | |
| from datetime import datetime | |
| from pathlib import Path | |
| import torch | |
| import typer | |
| from accelerate import Accelerator | |
| from accelerate.utils import LoggerType | |
| from torch import Tensor | |
| from torch.optim import AdamW | |
| # from torch.optim.lr_scheduler import ReduceLROnPlateau | |
| from torch.utils.data import DataLoader | |
| from tqdm import tqdm | |
| from data import MusdbDataset | |
| from splitter import Splitter | |
# Parse DISABLE_TQDM as a boolean flag. Reading the raw string would make any
# non-empty value -- including "0" or "false" -- disable the progress bar.
DISABLE_TQDM = os.environ.get("DISABLE_TQDM", "").strip().lower() not in {
    "",
    "0",
    "false",
    "no",
    "off",
}
# Typer application run by the __main__ guard; local variables are kept out of
# pretty-printed tracebacks.
app = typer.Typer(pretty_exceptions_show_locals=False)
def spectrogram_loss(masked_target: Tensor, original: Tensor) -> Tensor:
    """Return the mean squared error between two spectrogram tensors.

    Args:
        masked_target: The spectrogram estimated by the network for one
            source; in ``train`` this is the per-stem value returned by the
            model.
        original: The reference spectrogram to compare against; in ``train``
            this is the STFT computed from the ground-truth stem waveform.

    Returns:
        A scalar tensor: the element-wise squared difference (after the usual
        broadcasting of the two inputs) averaged over every element.
    """
    residual = masked_target - original
    return residual.pow(2).mean()
def _make_accelerator(fp16: bool, cpu: bool, logging_dir: Path) -> Accelerator:
    """Build an ``Accelerator`` that logs to TensorBoard under ``logging_dir``.

    Uses ``mixed_precision=`` instead of the deprecated ``fp16=`` flag. The log
    directory goes to ``project_dir`` on accelerate versions that replaced
    ``logging_dir`` with it, and to ``logging_dir`` on older ones.
    """
    import inspect

    kwargs = {
        # None keeps accelerate's own default (env / `accelerate config`),
        # matching the old fp16=False behaviour.
        "mixed_precision": "fp16" if fp16 else None,
        "cpu": cpu,
        "log_with": [LoggerType.TENSORBOARD],
    }
    if "project_dir" in inspect.signature(Accelerator).parameters:
        kwargs["project_dir"] = str(logging_dir)
    else:
        kwargs["logging_dir"] = str(logging_dir)
    return Accelerator(**kwargs)


@app.command()
def train(
    dataset: str = "data/musdb18-wav",
    output_dir: str = None,
    fp16: bool = False,
    cpu: bool = True,
    max_steps: int = 100,
    num_train_epochs: int = 1,
    per_device_train_batch_size: int = 1,
    effective_batch_size: int = 4,
    max_grad_norm: float = 0.0,
) -> None:
    """Train a ``Splitter`` model on ``MusdbDataset`` and save it.

    Registered as the Typer command of ``app``, so every parameter is exposed
    as a command-line option.

    Args:
        dataset: Root directory passed to ``MusdbDataset``.
        output_dir: Where the model and tracker logs are written. Defaults to
            ``experiments/<YYYYmmdd-HHMMSS>``.
        fp16: Enable fp16 mixed precision.
        cpu: Force the Accelerator onto the CPU.
        max_steps: Number of optimizer steps to run; when ``<= 0`` the run
            length is ``num_train_epochs`` full epochs instead.
        num_train_epochs: Epochs to run when ``max_steps <= 0``.
        per_device_train_batch_size: Batch size per process (only 1 is
            currently supported).
        effective_batch_size: Target samples per optimizer step, reached via
            gradient accumulation across batches and processes.
        max_grad_norm: Gradient-norm clipping threshold; ``<= 0`` disables it.

    Raises:
        ValueError: If ``per_device_train_batch_size`` is not 1, or the
            dataloader yields no batches.
    """
    # Fail fast with a real exception: an `assert` in the loop is stripped
    # under `python -O` and only fired after all setup work.
    if per_device_train_batch_size != 1:
        raise ValueError("For now limit per_device_train_batch_size to 1.")

    if not output_dir:
        now_str = datetime.now().strftime("%Y%m%d-%H%M%S")
        output_dir = f"experiments/{now_str}"
    output_dir = Path(output_dir)
    logging_dir = output_dir / "tracker_logs"

    accelerator = _make_accelerator(fp16=fp16, cpu=cpu, logging_dir=logging_dir)
    # The tracker joins the run name onto the configured logging dir, so pass a
    # bare name (passing the full path duplicated it for relative dirs).
    accelerator.init_trackers("run")

    train_dataset = MusdbDataset(root=dataset, is_train=True)
    train_dataloader = DataLoader(
        train_dataset,
        shuffle=True,
        batch_size=per_device_train_batch_size,
    )
    model = Splitter(stem_names=list(train_dataset.targets))
    optimizer = AdamW(model.parameters(), lr=1e-3, eps=1e-8)
    model, optimizer, train_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader
    )
    # prepare() may wrap the model (e.g. DistributedDataParallel), which hides
    # custom methods such as compute_stft; keep a handle on the bare module.
    base_model = accelerator.unwrap_model(model)

    batches_per_epoch = len(train_dataloader)
    if batches_per_epoch == 0:
        # Without batches the step counter never advances: infinite loop.
        raise ValueError(f"No training batches found under {dataset!r}.")

    step_batch_size = per_device_train_batch_size * accelerator.num_processes
    gradient_accumulation_steps = max(1, effective_batch_size // step_batch_size)
    # Every epoch ends with an optimizer step even when its last accumulation
    # window is partial, so one epoch is ceil(batches / accumulation) steps.
    steps_per_epoch = -(-batches_per_epoch // gradient_accumulation_steps)
    num_train_steps = (
        max_steps if max_steps > 0 else steps_per_epoch * num_train_epochs
    )
    accelerator.print(f"Num train steps: {num_train_steps}")
    accelerator.print(
        f"Gradient Accumulation Steps: {gradient_accumulation_steps}\nEffective Batch Size: {gradient_accumulation_steps * step_batch_size}"
    )

    global_step = 0
    while global_step < num_train_steps:
        accelerator.wait_for_everyone()
        model.train()
        with tqdm(
            train_dataloader,
            desc="Batch",
            disable=((not accelerator.is_local_main_process) or DISABLE_TQDM),
        ) as batch_iterator:
            for batch_idx, (x_wav, y_target_wavs) in enumerate(batch_iterator):
                # Accumulation window containing this batch; only the last
                # window of an epoch can be shorter than the configured size.
                window_start = batch_idx - batch_idx % gradient_accumulation_steps
                window_size = min(
                    gradient_accumulation_steps, batches_per_epoch - window_start
                )

                predictions = model(x_wav)
                stem_losses = []
                for name, masked_stft in predictions.items():
                    # squeeze() drops the batch dim (batch size is fixed at 1).
                    target_stft, _ = base_model.compute_stft(
                        y_target_wavs[name].squeeze()
                    )
                    loss = spectrogram_loss(
                        masked_target=masked_stft,
                        original=target_stft,
                    )
                    stem_losses.append(loss)
                    # Log a Python float: the TensorBoard tracker ignores tensors.
                    accelerator.log(
                        {f"train-loss-{name}": loss.item()}, step=global_step
                    )

                # Divide by the window size so accumulated gradients average
                # over the window, as one large batch would.
                total_loss = torch.stack(stem_losses).sum() / window_size
                accelerator.print(
                    f"global step: {global_step}\tloss: {total_loss.item():.4f}"
                )
                accelerator.log({"train-loss": total_loss.item()}, step=global_step)
                accelerator.backward(total_loss)

                if batch_idx + 1 < window_start + window_size:
                    continue  # window not finished: keep accumulating

                if max_grad_norm > 0:
                    accelerator.clip_grad_norm_(model.parameters(), max_grad_norm)
                optimizer.step()
                optimizer.zero_grad()
                global_step += 1
                if global_step >= num_train_steps:
                    break  # honour max_steps even mid-epoch

    accelerator.wait_for_everyone()
    accelerator.end_training()
    accelerator.print(f"Saving model to {output_dir}...")
    base_model.save_pretrained(
        output_dir,
        save_function=accelerator.save,
        state_dict=accelerator.get_state_dict(model),
    )
    accelerator.wait_for_everyone()
    accelerator.print("DONE!")
if __name__ == "__main__":
    # Script entry point: hand control to the Typer app, which parses the
    # command-line arguments.
    app()