import os
import logging

import torch
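
# Trade a little float32 matmul precision for much faster matrix multiplies on
# GPUs with Tensor Cores; "medium" is typically accurate enough for training.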
torch.set_float32_matmul_precision("medium")

import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.strategies import DDPStrategy
from pytorch_lightning.utilities.combined_loader import CombinedLoader
from omegaconf import DictConfig
import hydra

from data_loader import DataModule
from trainer import MusicClassifier

def get_latest_version(log_dir):
    """Return the highest-numbered 'version_N' directory in log_dir, or None."""
    if not os.path.isdir(log_dir):
        return None
    version_dirs = [d for d in os.listdir(log_dir) if d.startswith('version_')]
    version_dirs.sort(key=lambda x: int(x.split('_')[-1]))
    return version_dirs[-1] if version_dirs else None


log = logging.getLogger(__name__)
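
# Hydra loads config/train_config.yaml and passes it to main() as a DictConfig;
# every config.* key referenced below is expected to be defined there.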
@hydra.main(version_base=None, config_path="config", config_name="train_config")
def main(config: DictConfig):
    log_base_dir = 'tb_logs/train_audio_classification'

    # Multi-task ("mt") classifiers are trained on the mood and va targets
    # jointly, so each task gets its own checkpoint callback below.
    is_mt = "mt" in config.model.classifier
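    # TensorBoardLogger creates tb_logs/train_audio_classification/version_N
    # for this run and exposes it as logger.log_dir.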
    logger = TensorBoardLogger("tb_logs", name="train_audio_classification")
    logger.log_hyperparams(config)
    train_log_dir = logger.log_dir
    print(f"Logging to {train_log_dir}")
    log.info("Training starts")

    data_module = DataModule(config)
    data_module.setup()
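    # Build one dataloader per dataset, keyed by name, so each multi-dataset
    # batch arrives as a {dataset_name: batch} dict.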
    trainloaders = {dataset_name: loader for dataset_name, loader in zip(config.datasets, data_module.train_dataloader())}
    valloaders = {dataset_name: loader for dataset_name, loader in zip(config.datasets, data_module.val_dataloader())}

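    # "max_size" iterates until the longest loader is exhausted; shorter loaders
    # yield None once they run out, so batches must tolerate missing datasets.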
    combined_train_loader = CombinedLoader(trainloaders, mode="max_size")
    combined_val_loader = CombinedLoader(valloaders, mode="max_size")

    # The logger has already created this run's version directory, so the
    # latest version on disk is the current run; fall back to the logger's
    # directory name on a fresh log tree instead of joining against None.
    latest_version = get_latest_version(log_base_dir) or os.path.basename(train_log_dir)
    val_epoch_file = os.path.join(log_base_dir, latest_version, 'val_epoch.txt')
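    # val_epoch_file is where MusicClassifier is expected to write its
    # per-epoch validation output (inferred from the argument and file names).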
    model = MusicClassifier(config, output_file=val_epoch_file)

    if is_mt:
        # Separate checkpoints track the best mood model and the best va model
        # independently.
        checkpoint_callback_mood = ModelCheckpoint(**config.checkpoint_mood)
        checkpoint_callback_va = ModelCheckpoint(**config.checkpoint_va)
        early_stop_callback = EarlyStopping(**config.earlystopping)

        # Knowledge distillation leaves some parameters (e.g. a frozen teacher)
        # without gradients, so DDP must search for unused parameters; skip
        # that overhead when kd is off.
        trainer = pl.Trainer(
            **config.trainer,
            strategy=DDPStrategy(find_unused_parameters=bool(config.model.kd)),
            callbacks=[checkpoint_callback_mood, checkpoint_callback_va, early_stop_callback],
            logger=logger,
            num_sanity_val_steps=0,
        )
    else:
        checkpoint_callback = ModelCheckpoint(**config.checkpoint)
        # The single-task branch needs its own EarlyStopping instance; the one
        # above is only created when is_mt is True.
        early_stop_callback = EarlyStopping(**config.earlystopping)

        trainer = pl.Trainer(
            **config.trainer,
            callbacks=[checkpoint_callback, early_stop_callback],
            logger=logger,
            num_sanity_val_steps=0,
        )

    trainer.fit(model, combined_train_loader, combined_val_loader)

    # Record the best checkpoint path(s) from the main process only, so DDP
    # workers don't race on the same file.
    if trainer.global_rank == 0:
        best_checkpoint_file = os.path.join(train_log_dir, 'best_checkpoint.txt')
        with open(best_checkpoint_file, 'w') as f:
            if is_mt:
                f.write(f"Best checkpoint (mood): {checkpoint_callback_mood.best_model_path}\n")
                f.write(f"Best checkpoint (va): {checkpoint_callback_va.best_model_path}\n")
            else:
                f.write(f"Best checkpoint: {checkpoint_callback.best_model_path}\n")
            f.write(f"Version: {logger.version}\n")


if __name__ == '__main__':
    main()