---
license: mit
tags:
- audio
- audio-enhancement
- speech-enhancement
- bandwidth-extension
- codec-repair
- neural-codec
- waveform-processing
- pytorch
library_name: pytorch
pipeline_tag: audio-to-audio
frameworks: PyTorch
language:
- en
---
# Brontes: Synthesis-First Waveform Enhancement
**Brontes** is a time-domain audio enhancement model designed for neural codec repair and bandwidth extension. This is the general pretrained model trained on diverse audio data.
## Model Description
Brontes upsamples and repairs speech degraded by neural codec compression. Unlike conventional Wave U-Net approaches that rely on dense skip connections, Brontes uses a **synthesis-first architecture** with selective deep skips, forcing the model to actively reconstruct rather than copy degraded input details.
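The selective-skip idea can be sketched as a toy forward pass. This is an illustrative assumption, not the actual Brontes code (the function name, stage callables, and the exact point where skips are added are all hypothetical); it only shows the control flow in which just the deepest encoder outputs are reused:

```python
def synthesis_first_forward(x, encoders, decoders, bottleneck, n_deep_skips=2):
    """Toy U-Net forward pass with selective deep skips (illustrative only)."""
    # Run every encoder stage, remembering each stage's output.
    skips = []
    for enc in encoders:
        x = enc(x)
        skips.append(x)
    x = bottleneck(x)
    # Decode; only the n_deep_skips deepest encoder outputs are fed back.
    # Shallower detail must be re-synthesized rather than copied through.
    n = len(decoders)
    for i, dec in enumerate(decoders):
        depth = n - 1 - i  # encoder stage mirrored by this decoder stage
        if depth >= n - n_deep_skips:
            x = x + skips[depth]
        x = dec(x)
    return x
```

With dense skips, a decoder can pass degraded high-frequency detail straight from input to output; dropping the shallow skips removes that shortcut.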
### Key Capabilities
- **Neural codec repair** — removes compression artifacts from neural codec outputs
- **Bandwidth extension** — upsamples from 24 kHz to 48 kHz (2× extension)
- **Waveform-domain processing** — operates directly on audio samples, no spectrogram conversion
- **Synthesis-first design** — only the two deepest skips retained, preventing artifact leakage
- **LSTM bottleneck** — captures long-range temporal dependencies at maximum compression
### Model Architecture
- **Type:** Encoder-decoder U-Net with selective skip connections
- **Stages:** 6 encoder stages + 6 decoder stages (4096× total compression)
- **Bottleneck:** Bidirectional LSTM for temporal modeling
- **Parameters:** ~29M
- **Input:** 24 kHz mono audio (codec-degraded)
- **Output:** 48 kHz mono audio (enhanced)
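The stage and compression figures above are self-consistent if every encoder stage downsamples by the same fixed stride; this is an inference from the numbers, not something the card states:

```python
stages = 6          # encoder stages, per the architecture summary
stride = 4          # inferred per-stage stride (assumption)
print(stride ** stages)  # 4096, matching the stated total compression
```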
## Intended Use
This is a **general pretrained model** trained on diverse audio data. For optimal performance on your specific use case:
⚠️ **We strongly recommend fine-tuning this model on your target dataset** using the `--pretrained` flag.
### Primary Use Cases
- Repairing audio degraded by neural codecs (e.g., EnCodec, SoundStream, Lyra)
- Bandwidth extension from narrowband/wideband to fullband
- Speech enhancement and quality improvement
- Post-processing for codec-compressed audio
## Quick Start
For detailed usage instructions, training, and fine-tuning, please see the [GitHub repository](https://github.com/ZDisket/Brontes).
### Basic Inference Example
```python
import torch
import torchaudio
import yaml

from brontes import Brontes

# Setup device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load config
with open('configs/config_brontes_48khz_demucs.yaml', 'r') as f:
    config = yaml.safe_load(f)

# Create model
model = Brontes(unet_config=config['model'].get('unet_config', {})).to(device)

# Load checkpoint
checkpoint = torch.load('path/to/checkpoint.pt', map_location=device)
model.load_state_dict(checkpoint['model'] if 'model' in checkpoint else checkpoint)
model.eval()

# Load audio
audio, sr = torchaudio.load('input.wav')
target_sr = config['dataset']['sample_rate']

# Resample if necessary
if sr != target_sr:
    resampler = torchaudio.transforms.Resample(sr, target_sr)
    audio = resampler(audio)

# Convert to mono and normalize
if audio.shape[0] > 1:
    audio = audio.mean(dim=0, keepdim=True)
max_val = audio.abs().max()
if max_val > 0:
    audio = audio / max_val

# Add batch dimension and process
audio = audio.unsqueeze(0).to(device)
with torch.no_grad():
    output, _, _, _ = model(audio)

# Save output
output = output.squeeze(0).cpu()
if output.abs().max() > 1.0:
    output = output / output.abs().max()
torchaudio.save('output.wav', output, target_sr)
```
Or use the command-line interface:
```bash
python infer_brontes.py \
--config configs/config_brontes_48khz_demucs.yaml \
--checkpoint path/to/checkpoint.pt \
--input input.wav \
--output output.wav
```
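For long-form audio (see Limitations), processing overlapping chunks and crossfading the seams keeps GPU memory bounded without audible joins. This is a framework-agnostic sketch under stated assumptions: the function name, chunk sizes, and crossfade scheme are illustrative, not part of the Brontes API, and `enhance_fn` stands for a wrapper around the model call from the example above:

```python
import numpy as np

def enhance_chunked(audio, enhance_fn, chunk_len=240000, overlap=24000, upsample=2):
    """Enhance long 1-D mono audio in overlapping chunks, crossfading seams.

    enhance_fn maps a chunk of n input samples to upsample * n output samples
    (e.g. a wrapper around the 24 kHz -> 48 kHz model forward pass).
    """
    hop = chunk_len - overlap
    out = np.zeros(len(audio) * upsample)
    weight = np.zeros_like(out)
    for start in range(0, len(audio), hop):
        enhanced = np.asarray(enhance_fn(audio[start:start + chunk_len]))
        fade = np.ones(len(enhanced))
        ramp = min(overlap * upsample, len(enhanced))
        if ramp > 0:
            # Strictly positive ramps so every sample keeps some weight.
            fade[:ramp] = np.linspace(0.0, 1.0, ramp + 1)[1:]
            fade[-ramp:] = np.linspace(1.0, 0.0, ramp + 1)[:-1]
        o = start * upsample
        out[o:o + len(enhanced)] += enhanced * fade
        weight[o:o + len(enhanced)] += fade
    # Normalize by the accumulated fade weights to undo the crossfade gain.
    return out / np.maximum(weight, 1e-8)
```

The weight normalization makes the crossfade exact wherever chunks overlap, so a constant signal passes through unchanged.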
## Training Details
### Training Data
The model was trained on diverse audio data including:
- Clean speech recordings
- Codec-degraded audio pairs
- Various acoustic conditions and speakers
### Training Procedure
- **Pretraining:** 10,000 steps of generator-only training
- **Adversarial training:** Multi-Period Discriminator (MPD) + Multi-Band Spectral Discriminator (MBSD)
- **Loss functions:** Multi-scale mel loss, pitch loss, adversarial loss, feature matching
- **Precision:** BF16 mixed precision
- **Framework:** PyTorch with custom training loop
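To make the loss list concrete, here is a minimal multi-scale spectral loss in the spirit of the multi-scale mel loss listed above. It is a hedged sketch only: the actual Brontes loss uses mel filterbanks and hyperparameters not given on this card, while this version uses plain linear-frequency magnitudes for brevity:

```python
import numpy as np

def stft_mag(x, n_fft, hop):
    # Naive framed FFT magnitude with a Hann window, enough for illustration.
    frames = [x[i:i + n_fft] * np.hanning(n_fft)
              for i in range(0, len(x) - n_fft + 1, hop)]
    return np.abs(np.fft.rfft(np.stack(frames), axis=1))

def multi_scale_spectral_loss(pred, target, n_ffts=(512, 1024, 2048)):
    # Average log-magnitude L1 distance over several FFT resolutions, so the
    # loss sees both fine temporal detail and broad spectral structure.
    total = 0.0
    for n_fft in n_ffts:
        p = stft_mag(pred, n_fft, n_fft // 4)
        t = stft_mag(target, n_fft, n_fft // 4)
        total += np.mean(np.abs(np.log1p(p) - np.log1p(t)))
    return total / len(n_ffts)
```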
## Fine-tuning Recommendations
To achieve best results on your specific dataset:
1. **Prepare paired data:** Input (degraded) and target (clean) audio pairs
2. **Use the `--pretrained` flag** to load model weights without optimizer state
3. **Train for 10-50k steps** depending on dataset size
4. **Monitor validation loss** to prevent overfitting
See the [repository README](https://github.com/ZDisket/Brontes) for detailed fine-tuning instructions.
## Limitations
- **Domain-specific performance:** General model may not perform optimally on highly specialized audio (fine-tuning recommended)
- **Mono audio only:** Currently supports single-channel audio
- **Fixed sample rates:** Designed for 24 kHz input → 48 kHz output
- **Codec-specific artifacts:** Performance may vary across different codec types
- **Long-form audio:** Very long audio files may require chunking or sufficient GPU memory
## Ethical Considerations
- This model is designed for audio enhancement and should not be used to create misleading or deceptive content
- Users should respect privacy and consent when processing speech recordings
- Enhanced audio should be clearly labeled as processed when used in sensitive contexts
## License
Both the model weights and code are released under the MIT License.
## Additional Resources
- **GitHub Repository:** [https://github.com/ZDisket/Brontes](https://github.com/ZDisket/Brontes)
- **Technical Report:** See the repository
- **Issues & Support:** [GitHub Issues](https://github.com/ZDisket/Brontes/issues)
## Acknowledgments
Compute resources provided by Hot Aisle and AI at AMD.