Yongyi Zang commited on
Commit
26ab161
·
1 Parent(s): 6bf7e1d
README.md CHANGED
@@ -1,3 +1,88 @@
1
- # Music Source Restoration Kit
2
 
3
- This repository offers a collection of model implementations, training configurations, and evaluation scripts to help you quickly get started with training and evaluating music source restoration models.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Music Source Restoration Kit
2
 
3
+ This repository offers a collection of model implementations, training configurations, and evaluation scripts to help you quickly get started with training and evaluating music source restoration models.
4
+
5
+ We have designed the repository to be a GAN-based framework; to learn more about the GANs, you can watch [this video](https://www.youtube.com/watch?v=TpMIssRdhco).
6
+
7
+ ## Directory Structure
8
+
9
+ The repository is organized to separate concerns, making it easy to extend and maintain. Click on a directory to learn more about its contents.
10
+
11
+ ```
12
+ MSRKit/
13
+ ├── README.md <- You are here
14
+ ├── config.yaml <- Main configuration file for experiments
15
+ ├── train.py <- Main script to start training
16
+ ├── unwrap.py <- Utility to extract generator weights from a checkpoint
17
+
18
+ ├── data/ <- [Data loading and augmentation](./data/README.md)
19
+
20
+ ├── evaluation/ <- [Evaluation metrics](./evaluation/README.md)
21
+
22
+ ├── losses/ <- [Loss function implementations](./losses/README.md)
23
+
24
+ ├── models/ <- [Top-level generator model architectures](./models/README.md)
25
+
26
+ └── modules/ <- [Core building blocks for models](./modules/README.md)
27
+ ├── discriminator/ <- [Discriminator architectures](./modules/discriminator/README.md)
28
+ └── generator/ <- [Reusable generator components](./modules/generator/README.md)
29
+ ```
30
+
31
+ ## 🚀 Getting Started
32
+
33
+ ### 1. Setup
34
+
35
+ First, clone the repository and install the required dependencies.
36
+
37
+ ```bash
38
+ git clone https://github.com/yongyizang/MSRKit.git
39
+ cd MSRKit
40
+ pip install -r requirements.txt
41
+ ```
42
+
43
+ *Note: The `FAD_CLAP` metric requires `laion-clap`. Please install it via `pip install laion-clap`.*
44
+
45
+ ### 2. Configure Your Experiment
46
+
47
+ Modify the `config.yaml` file to set up your dataset paths, model hyperparameters, and training settings.
48
+
49
+ Key sections to update:
50
+
51
+ - `data.train_dataset.root_directory`: Path to your training data.
52
+ - `data.train_dataset.file_list`: Path to a `.txt` file listing your training samples.
53
+ - `data.val_dataset.root_directory`: Path to your validation data.
54
+ - `data.val_dataset.file_list`: Path to a `.txt` file listing your validation samples.
55
+ - `model`: Choose the generator model and its parameters.
56
+ - `discriminators`: Add and configure one or more discriminators.
57
+ - `trainer`: Set training parameters like `max_steps`, `devices` (GPU IDs), and `precision`.
58
+
59
+ ### 3. Start Training
60
+
61
+ Launch the training process using the `train.py` script and your configuration file.
62
+
63
+ ```bash
64
+ python train.py --config config.yaml
65
+ ```
66
+
67
+ Logs, checkpoints, and audio samples will be saved in the `lightning_logs/` directory.
68
+
69
+ ### 4. Unwrap Generator Weights
70
+
71
+ After training, you may want to use the generator model for inference without the rest of the Lightning module. The `unwrap.py` script extracts the generator's `state_dict` from a checkpoint file.
72
+
73
+ ```bash
74
+ python unwrap.py --ckpt "path/to/your/checkpoint.ckpt" --out "path/to/save/generator.pth"
75
+ ```
76
+
77
+ This creates a clean `.pth` file containing only the generator's weights. This is useful if you want to use the generator model for inference without the rest of the Lightning module, or if you want to fine-tune the generator model on a different dataset.
78
+
79
+ ## Building Your First Model
80
+
81
+ To build your first model, you can reference the model architecture in the `models/` directory. You can also refer to the `modules/` directory for the building blocks used in the model architectures. At a very high level, we have implemented the following processing blocks:
82
+ - Spectral Operations: `Fourier`, `Band`
83
+ - Sequence Modeling Blocks: `RoFormerBlock` (and an example of modified attention pattern, `AttentionRegisterRoFormerBlock`), `RNNBlock`, `ConvNeXt1DBlock`
84
+ - Convolutional Blocks: `ConvNeXt2DBlock`, `ConvNeXt1DBlock`
85
+ - Discriminator Architectures: `MultiPeriodDiscriminator`, `MultiScaleDiscriminator`, `MultiResolutionDiscriminator`, `MultiFrequencyDiscriminator`
86
+
87
+ ## ⚖️ License
88
+ This project is licensed under the MIT License.
data/README.md ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Data Module
2
+
3
+ This directory contains all the necessary components for data loading, processing, and augmentation.
4
+
5
+ ## Files
6
+
7
+ ### `dataset.py`
8
+
9
+ This file defines the `RawStems` dataset class, which is the core of the data pipeline. It dynamically creates training examples by mixing a target stem with other stems based on a specified Signal-to-Noise Ratio (SNR).
10
+
11
+ #### `RawStems`
12
+
13
+ A PyTorch `Dataset` that loads and processes raw audio stems for music source restoration tasks.
14
+
15
+ **`__init__` Arguments:**
16
+
17
+ - `target_stem` (`str`): The name of the target stem folder (e.g., `"Voc"` or `"Gtr_EG"`).
18
+ - `root_directory` (`Union[str, Path]`): The root directory containing subfolders for each song.
19
+ - `file_list` (`Optional[Union[str, Path]]`): Path to a `.txt` file where each line is a path to a song folder, relative to `root_directory`.
20
+ - `sr` (`int`): The target sample rate to load audio at. Default: `44100`.
21
+ - `clip_duration` (`float`): The duration of the audio clips to be extracted, in seconds. Default: `3.0`.
22
+ - `snr_range` (`Tuple[float, float]`): A tuple representing the min and max SNR (in dB) for mixing the target stem with the noise (other stems). Default: `(0.0, 10.0)`.
23
+ - `apply_augmentation` (`bool`): Whether to apply on-the-fly augmentations to the audio. Default: `True`.
24
+
25
+ ### `augment.py`
26
+
27
+ This file implements the audio augmentation pipelines using the `pedalboard` library.
28
+
29
+ #### `StemAugmentation`
30
+
31
+ Applies a chain of augmentations suitable for the *target* audio source before it's mixed. This simulates variations in recording quality and effects.
32
+
33
+ - **Effects include**: Random EQ, Resampling, Compression, Distortion, and Reverb.
34
+
35
+ #### `MixtureAugmentation`
36
+
37
+ Applies a chain of augmentations to the final *mixture* audio. This simulates artifacts that could occur on a fully mixed track.
38
+
39
+ - **Effects include**: Limiting, Resampling, and MP3 a.k.a Codec compression.
evaluation/README.md ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Evaluation Module
2
+
3
+ This directory contains classes for evaluating model performance during validation. All metrics inherit from a base `Metric` class for a consistent interface.
4
+
5
+ ## Files
6
+
7
+ ### `metrics.py`
8
+
9
+ #### `SI_SNR` (Scale-Invariant Signal-to-Noise Ratio)
10
+
11
+ A common metric for audio source separation that measures the quality of the restored signal relative to the original target. It is invariant to the overall scaling of the estimated signal.
12
+
13
+ - `update(pred, target)`: Updates the running statistics with a new batch of predicted and target audio tensors.
14
+ - `compute()`: Calculates the mean and standard deviation of the SI-SNR scores accumulated since the last reset.
15
+ - `reset()`: Clears the accumulated statistics.
16
+
17
+ #### `FAD_CLAP` (Fréchet Audio Distance using CLAP)
18
+
19
+ Measures the Fréchet distance between the distributions of embeddings from the generated audio and the ground truth audio. It uses a pre-trained CLAP (Contrastive Language-Audio Pretraining) model to generate these embeddings, providing a perceptually relevant measure of audio quality and similarity.
20
+
21
+ **Note:** This metric requires the `laion-clap` library. If not installed, it will fall back to using random embeddings, which is not meaningful for evaluation.
22
+
23
+ - `update(pred, target)`: Extracts CLAP embeddings from the predicted and target audio tensors and stores them.
24
+ - `compute()`: Calculates the FAD score between the collected sets of embeddings.
25
+ - `reset()`: Clears the stored embeddings.
26
+
27
+ **`__init__` Arguments:**
28
+
29
+ - `embedding_dim` (`int`): The dimensionality of the embeddings. Should match the CLAP model. Default: `512`.
30
+ - `model_name` (`str`): The name of the CLAP model architecture to use. Default: `'HTSAT-base'`.
31
+ - `ckpt_path` (`Optional[str]`): Optional path to a specific CLAP model checkpoint. If `None`, it uses the default pre-trained weights.
losses/README.md ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Losses Module
2
+
3
+ This directory contains the implementations of various loss functions used for training the generator and discriminators.
4
+
5
+ ## Files
6
+
7
+ ### `gan_loss.py`
8
+
9
+ This file implements adversarial losses for both the generator and discriminator, as well as a feature matching loss.
10
+
11
+ We provide both LSGAN and Hinge GAN implementations. LSGAN and Hinge GAN differ primarily in how they penalize mistakes.
12
+
13
+ - LSGAN uses a "least squares" approach that constantly pushes fake samples toward looking real, with the penalty growing quadratically the further off they are - this means even terrible fakes get strong learning signals, preventing vanishing gradients, but the discriminator never stops pushing even on samples that are already good enough, which can cause instability.
14
+ - Hinge GAN instead creates a "satisfaction zone" where once the discriminator is confident enough about a sample (real or fake), it stops trying to improve its classification - this focuses all the learning on ambiguous samples near the decision boundary. The result: LSGAN provides consistent gradients throughout training but can overshoot and destabilize, while Hinge GAN typically produces sharper images by not wasting effort on already-separated samples, though it risks killing gradients entirely if the discriminator gets too confident too fast.
15
+
16
+ #### `GeneratorLoss`
17
+
18
+ Calculates the adversarial loss for the generator, encouraging it to produce outputs that the discriminator classifies as real.
19
+
20
+ **`__init__` Arguments:**
21
+
22
+ - `gan_type` (`str`): The type of GAN loss to use. Supports `'hinge'` and `'lsgan'` (Least Squares GAN). Default: `'hinge'`.
23
+
24
+ #### `DiscriminatorLoss`
25
+
26
+ Calculates the adversarial loss for the discriminator, training it to distinguish between real and fake (generated) inputs.
27
+
28
+ **`__init__` Arguments:**
29
+
30
+ - `gan_type` (`str`): The type of GAN loss to use. Supports `'hinge'` and `'lsgan'`. Default: `'hinge'`.
31
+
32
+ #### `FeatureMatchingLoss`
33
+
34
+ Calculates the L1 distance between the feature maps of the real and fake inputs from the intermediate layers of the discriminator. This helps stabilize training by matching the statistical properties of the features.
35
+
36
+ -----
37
+
38
+ ### `reconstruction_loss.py`
39
+
40
+ This file implements reconstruction losses that measure the direct difference between the generated audio and the ground truth target audio in various domains.
41
+
42
+ #### `MultiMelSpecReconstructionLoss`
43
+
44
+ Calculates the L1 loss between the log-mel spectrograms of the predicted and target audio. It computes this loss using multiple different STFT configurations (FFT size, hop length, mel bands) and averages the results for a more robust, multi-resolution spectral loss.
45
+
46
+ **`__init__` Arguments:**
47
+
48
+ - `sample_rate` (`int`): The sample rate of the audio.
49
+ - `n_fft` (`List[int]`): A list of FFT sizes for the different STFT resolutions.
50
+ - `hop_length` (`List[int]`): A list of hop lengths corresponding to the FFT sizes.
51
+ - `n_mels` (`List[int]`): A list of the number of mel bands corresponding to the FFT sizes.
52
+
53
+ #### `ComplexSpecReconstructionLoss`
54
+
55
+ Calculates the L1 loss on the magnitude of the complex spectrograms.
56
+
57
+ #### `MultiComplexSpecReconstructionLoss`
58
+
59
+ A multi-resolution version of `ComplexSpecReconstructionLoss`.
60
+
61
+ #### `WaveformReconstructionLoss`
62
+
63
+ Calculates a simple L1 loss directly on the raw audio waveforms.
models/README.md ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Models Module
2
+
3
+ This directory contains the high-level generator architectures. These models define the main structure for transforming a mixed audio waveform into a restored target stem. They process the audio in the spectral domain and utilize various building blocks from the `modules/` directory.
4
+
5
+ All models (currently) first transform the input waveform into a spectrogram, process it in the time-frequency domain, and then convert it back to a waveform using an inverse STFT. They uniformly assumes a mono audio tensor being processed of shape [batch, samples].
6
+
7
+ ## Files
8
+
9
+ ### `MelRoFormer.py`
10
+
11
+ #### `MelRoFormer`
12
+
13
+ A dual-path Transformer-based model that applies attention alternately along the frequency and time axes of the spectrogram. It uses `RoFormerBlock`s, which incorporate Rotary Position Embeddings (RoPE) for effective sequence modeling. This model references https://arxiv.org/abs/2409.04702.
14
+
15
+ **`__init__` Arguments:**
16
+
17
+ - `hidden_channels` (`int`): The number of channels (embedding dimension) used throughout the model.
18
+ - `num_layers` (`int`): The number of layers (a time block + a frequency block is one layer).
19
+ - `num_heads` (`int`): The number of attention heads in each RoFormer block.
20
+ - `window_size` (`int`): The STFT window size.
21
+ - `hop_size` (`int`): The STFT hop size.
22
+ - `sample_rate` (`int`): The sample rate of the input audio.
23
+
24
+ -----
25
+
26
+ ### `MelRNN.py`
27
+
28
+ #### `MelRNN`
29
+
30
+ A dual-path model similar to `MelRoFormer`, but it uses bidirectional GRUs (`RNNBlock`) instead of Transformers for processing the time and frequency axes. This can be a lighter-weight alternative to the attention-based models. This model references (yet deviates from) https://arxiv.org/abs/2209.15174.
31
+
32
+ **`__init__` Arguments:**
33
+
34
+ - `hidden_channels` (`int`): The number of channels (embedding dimension).
35
+ - `num_layers` (`int`): The number of RNN layers.
36
+ - `num_groups` (`int`): The number of groups for the `GroupedRNN` within each `RNNBlock`.
37
+ - `window_size` (`int`): The STFT window size.
38
+ - `hop_size` (`int`): The STFT hop size.
39
+ - `sample_rate` (`int`): The sample rate of the input audio.
40
+
41
+ -----
42
+
43
+ ### `UNet.py`
44
+
45
+ #### `MelUNet`
46
+
47
+ A U-Net architecture that operates on the 2D spectrogram. It uses a series of downsampling and upsampling blocks (`ConvNeXt2DBlock`) with skip connections to capture multi-scale features in the spectrogram.
48
+
49
+ **`__init__` Arguments:**
50
+
51
+ - `hidden_channels` (`int`): The initial number of channels in the network. Channel count doubles with each downsampling step.
52
+ - `num_layers` (`int`): The depth of the U-Net (number of downsampling/upsampling stages).
53
+ - `upsampling_factor` (`int`): The factor for upsampling/downsampling in each block (typically `2`).
54
+ - `window_size` (`int`): The STFT window size.
55
+ - `hop_size` (`int`): The STFT hop size.
56
+ - `sample_rate` (`int`): The sample rate of the input audio.
57
+
58
+ -----
modules/README.md ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modules Directory
2
+
3
+ This directory contains the fundamental building blocks used to construct the larger models and discriminators. It is divided into subdirectories based on function.
4
+
5
+ ## Subdirectories
6
+
7
+ - **`discriminator/`**: Contains complete, stand-alone discriminator architectures.
8
+ - **`generator/`**: Contains reusable neural network layers and blocks (e.g., attention, RNN, ConvNeXt blocks) used in the main generator models.
9
+ - **`spectral_ops.py`**: Includes modules for spectral processing:
10
+ - `Fourier`: A wrapper for `torch.stft` and `torch.istft`.
11
+ - `Band`: A module to split a spectrogram into different frequency bands (e.g., mel scale) for processing and reassemble them.
12
+
13
+ # Discriminator Modules
14
+
15
+ This directory provides a suite of powerful, multi-component discriminators. The training script combines these into a single powerful ensemble discriminator. Each is designed to analyze audio from a different perspective (time, frequency, scale), making the generator's task more challenging and leading to higher-quality results.
16
+
17
+ ## Files
18
+
19
+ ### `MultiPeriodDiscriminator.py`
20
+
21
+ #### `MultiPeriodDiscriminator`
22
+
23
+ This discriminator operates on the raw audio waveform. It consists of several sub-discriminators, each viewing the input signal at a different *period*. For example, a sub-discriminator with `period=2` will reshape the audio into a 2D representation where adjacent samples are folded, allowing it to spot artifacts at that specific frequency. This is highly effective at detecting periodic artifacts.
24
+
25
+ **`__init__` Arguments:**
26
+
27
+ - `nch` (`int`): Number of input channels (e.g., `1` for mono). Default: `1`.
28
+ - `sample_rate` (`int`): Sample rate of the audio. Default: `48000`.
29
+ - `periods` (`List[int]`): A list of periods for each sub-discriminator. Prime numbers are recommended. Default: `[2, 3, 5, 7, 11]`.
30
+ - `norm` (`bool`): Whether to use spectral normalization. Default: `True`.
31
+
32
+ -----
33
+
34
+ ### `MultiScaleDiscriminator.py`
35
+
36
+ #### `MultiScaleDiscriminator`
37
+
38
+ This discriminator also operates on the raw waveform. It contains multiple sub-discriminators that process the audio at different resolutions by downsampling the input. This allows it to identify artifacts at various time scales, from fine-grained details to broader structural issues.
39
+
40
+ **`__init__` Arguments:**
41
+
42
+ - `sample_rate` (`int`): Sample rate of the audio.
43
+ - `downsample_rates` (`List[int]`): A list of factors to downsample the audio for each sub-discriminator. Default: `[2, 4]`.
44
+ - `nch` (`int`): Number of input channels. Default: `1`.
45
+ - `norm` (`bool`): Whether to use spectral normalization. Default: `True`.
46
+
47
+ -----
48
+
49
+ ### `MultiResolutionDiscriminator.py`
50
+
51
+ #### `MultiResolutionDiscriminator`
52
+
53
+ This discriminator operates in the spectral domain. It consists of several sub-discriminators, each analyzing the STFT of the input audio using a different window length. This allows it to detect spectral artifacts across different time-frequency resolutions.
54
+
55
+ **`__init__` Arguments:**
56
+
57
+ - `nch` (`int`): Number of input channels. Default: `1`.
58
+ - `sample_rate` (`int`): Sample rate of the audio. Default: `48000`.
59
+ - `window_lengths` (`List[int]`): A list of STFT window lengths for each sub-discriminator. Default: `[2048, 1024, 512]`.
60
+ - `hop_factor` (`float`): The ratio of hop length to window length. Default: `0.25`.
61
+ - `bands` (`List[Tuple[float, float]]`): Frequency bands to analyze, specified as fractions of the Nyquist frequency.
62
+ - `norm` (`bool`): Whether to use spectral normalization. Default: `True`.
63
+ - `hidden_channels` (`int`): The number of hidden channels in the conv layers. Default: `32`.
64
+
65
+ -----
66
+
67
+ ### `MultiFrequencyDiscriminator.py`
68
+
69
+ #### `MultiFrequencyDiscriminator`
70
+
71
+ This discriminator is similar to `MultiResolutionDiscriminator` but with a different internal architecture focused on capturing features across frequency bands. It also processes the real and imaginary parts of the STFT as separate channels. This discriminator references https://arxiv.org/abs/2210.13438's discriminator architecture.
72
+
73
+ **`__init__` Arguments:**
74
+
75
+ - `nch` (`int`): Number of input channels.
76
+ - `window_sizes` (`List[int]`): A list of STFT window sizes for each sub-discriminator.
77
+ - `hidden_channels` (`int`): The number of base hidden channels. Default: `8`.
78
+ - `sample_rate` (`int`): Sample rate of the audio. Default: `48000`.
79
+ - `norm` (`bool`): Whether to use spectral normalization. Default: `True`.
80
+
81
+ -----
82
+
83
+ # Generator Modules
84
+
85
+ This directory contains reusable building blocks that form the core components of the main generator models in the `/models` directory.
86
+
87
+ ## Files
88
+
89
+ ### `RoFormerBlock.py`
90
+
91
+ #### `RoFormerBlock`
92
+
93
+ A standard Transformer block that uses **Ro**tary **P**osition **E**mbeddings (RoPE) instead of absolute or learned position embeddings. RoPE injects positional information by rotating the query and key vectors, which is particularly effective for sequence modeling. The block consists of a self-attention layer followed by an MLP, with residual connections and RMS normalization.
94
+
95
+ **`__init__` Arguments:**
96
+
97
+ - `n_embd` (`int`): The embedding dimension (number of channels).
98
+ - `n_head` (`int`): The number of attention heads.
99
+ - `max_seq_len` (`int`): The maximum sequence length this block can handle, used to pre-compute the RoPE cache.
100
+ - `rope_base` (`int`): The base value for the rotary position embedding calculation. Default: `10000`.
101
+
102
+ -----
103
+
104
+ ### `AttentionRegisterRoFormerBlock.py`
105
+
106
+ #### `AttentionRegisterRoFormerBlock`
107
+
108
+ An extension of the `RoFormerBlock` that implements **Attention Registers**. This technique adds a small number of learnable "register" tokens to the sequence. These tokens act as a global memory or scratchpad for the attention mechanism, improving its ability to retain and access information across the entire sequence, especially when combined with a sliding window attention mechanism.
109
+
110
+ **`__init__` Arguments:**
111
+
112
+ - *(Inherits from `RoFormerBlock`)*
113
+ - `num_register_tokens` (`int`): The number of register tokens to prepend to the sequence. Default: `0`.
114
+ - `window_size` (`int`): The size of the sliding attention window. If `-1`, full attention is used. Default: `-1`.
115
+
116
+ -----
117
+
118
+ ### `RNNBlock.py`
119
+
120
+ #### `RNNBlock`
121
+
122
+ A block that uses a Recurrent Neural Network (RNN) layer followed by an MLP, with residual connections and RMS normalization. It uses a `GroupedRNN` internally.
123
+
124
+ **`__init__` Arguments:**
125
+
126
+ - `n_embd` (`int`): The embedding dimension.
127
+ - `n_layer` (`int`): The number of layers in the RNN.
128
+ - `n_groups` (`int`): The number of parallel, smaller RNNs to use in the `GroupedRNN`. The embedding dimension is split across these groups.
129
+ - `rnn_type` (`str`): The type of RNN cell to use, either `'gru'` or `'lstm'`. Default: `'gru'`.
130
+ - `bidirectional` (`bool`): Whether to use a bidirectional RNN. Default: `False`.
131
+
132
+ -----
133
+
134
+ ### `ConvNeXt1DBlock.py` & `ConvNeXt2DBlock.py`
135
+
136
+ #### `ConvNeXt1DBlock` / `ConvNeXt2DBlock`
137
+
138
+ Implementations of the ConvNeXt block for 1D and 2D data, respectively. This block is a modern, pure-convolutional architecture that adopts design principles from Vision Transformers. It features a depthwise convolution followed by pointwise convolutions (linear layers) in an inverted bottleneck structure. These blocks can be run in `'normal'` mode (downsampling) or `'transposed'` mode (upsampling).
139
+
140
+ **`__init__` Arguments:**
141
+
142
+ - `kernel_size` (`int` or `tuple`): The kernel size for the depthwise convolution.
143
+ - `stride` (`int` or `tuple`): The stride for the convolution, used for down/up-sampling.
144
+ - `input_dim` (`int`): The number of input channels.
145
+ - `output_dim` (`int`): The number of output channels.
146
+ - `mode` (`str`): Operation mode, either `'normal'` for `ConvNd` or `'transposed'` for `ConvTransposeNd`. Default: `'normal'`.
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ einops==0.8.1
2
+ importlib-metadata==8.0.0
3
+ jaraco.collections==5.1.0
4
+ librosa==0.11.0
5
+ thop==0.1.1.post2209072238
6
+ tomli==2.0.1
7
+ torch==2.8.0
8
+ torchaudio==2.8.0
9
+ pytorch-lightning