Spaces:
Paused
Paused
| #!/usr/bin/python3 | |
| # -*- coding: utf-8 -*- | |
from typing import Optional, Tuple

from toolbox.torchaudio.configuration_utils import PretrainedConfig
class MPNetConfig(PretrainedConfig):
    """Hyper-parameters for an MP-SENet model and its training run.

    Reference config:
    https://github.com/yxlu-0102/MP-SENet/blob/main/config.json

    Every constructor argument is stored as an attribute of the same name.
    The one exception is ``dist_config``, which is normalized (see below).
    Extra keyword arguments are forwarded to ``PretrainedConfig.__init__``.

    Args:
        num_gpus, batch_size, learning_rate, adam_b1, adam_b2, lr_decay,
        seed, num_workers: training-loop settings. ``adam_b1``/``adam_b2``
            are presumably the Adam optimizer's beta coefficients; confirm
            against the trainer that reads them.
        dense_channel, compress_factor, num_tsconformers, beta: model
            settings; their semantics are defined by the code consuming this
            config, not here.
        sample_rate, segment_size, n_fft, hop_size, win_size: audio / STFT
            settings. These are presumably all measured in samples, in which
            case the default ``segment_size`` of 32000 is 2 s at 16 kHz.
            TODO confirm.
        dist_config: distributed-training settings. It is merged over the
            defaults ``{"dist_backend": "nccl",
            "dist_url": "tcp://localhost:54321", "world_size": 1}``. Keys
            the caller omits (all of them when ``None`` or empty) take the
            default value, and keys the caller supplies win. The result is a
            fresh dict, so later mutation of the caller's mapping does not
            leak into this config.
        discriminator_dim, discriminator_in_channel: discriminator settings
            (semantics defined by the consuming code).
        **kwargs: forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    def __init__(self,
                 num_gpus: int = 0,
                 batch_size: int = 4,
                 learning_rate: float = 0.0005,
                 adam_b1: float = 0.8,
                 adam_b2: float = 0.99,
                 lr_decay: float = 0.99,
                 seed: int = 1234,
                 dense_channel: int = 64,
                 compress_factor: float = 0.3,
                 num_tsconformers: int = 4,
                 beta: float = 2.0,
                 sample_rate: int = 16000,
                 segment_size: int = 32000,
                 n_fft: int = 400,
                 hop_size: int = 100,
                 win_size: int = 400,
                 num_workers: int = 4,
                 dist_config: Optional[dict] = None,
                 discriminator_dim: int = 32,
                 discriminator_in_channel: int = 2,
                 **kwargs
                 ):
        super().__init__(**kwargs)

        # Training-loop settings.
        self.num_gpus = num_gpus
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.adam_b1 = adam_b1
        self.adam_b2 = adam_b2
        self.lr_decay = lr_decay
        self.seed = seed

        # Model settings.
        self.dense_channel = dense_channel
        self.compress_factor = compress_factor
        self.num_tsconformers = num_tsconformers
        self.beta = beta

        # Audio / STFT settings.
        self.sample_rate = sample_rate
        self.segment_size = segment_size
        self.n_fft = n_fft
        self.hop_size = hop_size
        self.win_size = win_size

        self.num_workers = num_workers

        # Built fresh on every call, so instances never share one dict.
        default_dist_config = {
            "dist_backend": "nccl",
            "dist_url": "tcp://localhost:54321",
            "world_size": 1,
        }
        # Merge rather than replace. A partial mapping such as
        # {"world_size": 2} keeps the remaining defaults instead of losing
        # them. None / {} still yield the defaults, as before. The caller's
        # dict is copied rather than aliased.
        self.dist_config = {**default_dist_config, **(dist_config or {})}

        # Discriminator settings.
        self.discriminator_dim = discriminator_dim
        self.discriminator_in_channel = discriminator_in_channel
if __name__ == "__main__":
    # Executing this file directly is a no-op; importing it is what makes
    # the config class available.
    ...