---
license: mit
tags:
- audio
- audio-enhancement
- speech-enhancement
- bandwidth-extension
- codec-repair
- neural-codec
- waveform-processing
- pytorch
library_name: pytorch
pipeline_tag: audio-to-audio
frameworks: PyTorch
language:
- en
---
# Brontes: Synthesis-First Waveform Enhancement
**Brontes** is a time-domain audio enhancement model designed for neural codec repair and bandwidth extension. This is the general pretrained model trained on diverse audio data.
## Model Description
Brontes upsamples and repairs speech degraded by neural codec compression. Unlike conventional Wave U-Net approaches that rely on dense skip connections, Brontes uses a **synthesis-first architecture** with selective deep skips, forcing the model to actively reconstruct rather than copy degraded input details.
### Key Capabilities
- **Neural codec repair** — removes compression artifacts from neural codec outputs
- **Bandwidth extension** — upsamples from 24 kHz to 48 kHz (2× extension)
- **Waveform-domain processing** — operates directly on audio samples, no spectrogram conversion
- **Synthesis-first design** — only the two deepest skips retained, preventing artifact leakage
- **LSTM bottleneck** — captures long-range temporal dependencies at maximum compression
### Model Architecture
- **Type:** Encoder-decoder U-Net with selective skip connections
- **Stages:** 6 encoder stages + 6 decoder stages (4096× total compression)
- **Bottleneck:** Bidirectional LSTM for temporal modeling
- **Parameters:** ~29M
- **Input:** 24 kHz mono audio (codec-degraded)
- **Output:** 48 kHz mono audio (enhanced)
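The structure above can be sketched as a toy module. This is an illustrative skeleton only, under stated assumptions (channel widths, kernel sizes, and the exact upsampling path are invented for brevity), not the real ~29M-parameter implementation:

```python
import torch
import torch.nn as nn

class TinyBrontesSketch(nn.Module):
    """Toy skeleton of the stated structure: 6 stride-4 encoder stages
    (4**6 = 4096x compression), a BiLSTM bottleneck, skip connections
    kept only at the two deepest stages, and a final 2x upsampling for
    24 kHz -> 48 kHz bandwidth extension. Widths and kernel sizes are
    arbitrary illustrative choices.
    """
    def __init__(self, ch=16):
        super().__init__()
        self.encoder = nn.ModuleList(
            nn.Conv1d(1 if i == 0 else ch, ch, kernel_size=8, stride=4, padding=2)
            for i in range(6)
        )
        self.lstm = nn.LSTM(ch, ch // 2, batch_first=True, bidirectional=True)
        self.decoder = nn.ModuleList(
            nn.ConvTranspose1d(ch, ch, kernel_size=8, stride=4, padding=2)
            for _ in range(6)
        )
        # Extra 2x upsampling stage: output is twice the input length.
        self.head = nn.ConvTranspose1d(ch, 1, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        skips = []
        for i, enc in enumerate(self.encoder):
            x = torch.relu(enc(x))
            if i >= 4:                 # keep only the two deepest skips
                skips.append(x)
        x = self.lstm(x.transpose(1, 2))[0].transpose(1, 2)
        for i, dec in enumerate(self.decoder):
            if i < 2:                  # consume the deep skips first
                x = x + skips.pop()
            x = torch.relu(dec(x))
        return self.head(x)

audio = torch.randn(1, 1, 4096)        # a short batch of samples
out = TinyBrontesSketch()(audio)
print(out.shape)                       # torch.Size([1, 1, 8192])
```

Note how 4096 input samples compress to a single latent frame before the LSTM, which is why the bottleneck can afford bidirectional temporal modeling.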
## Intended Use
This is a **general pretrained model** trained on diverse audio data. For best results on a specific domain:

⚠️ **Fine-tuning this model on your target dataset is strongly recommended** (load the pretrained weights with the `--pretrained` flag).
### Primary Use Cases
- Repairing audio degraded by neural codecs (e.g., EnCodec, SoundStream, Lyra)
- Bandwidth extension from narrowband/wideband to fullband
- Speech enhancement and quality improvement
- Post-processing for codec-compressed audio
## Quick Start
For detailed usage instructions, training, and fine-tuning, please see the [GitHub repository](https://github.com/ZDisket/Brontes).
### Basic Inference Example
```python
import torch
import torchaudio
import yaml

from brontes import Brontes

# Setup device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load config
with open('configs/config_brontes_48khz_demucs.yaml', 'r') as f:
    config = yaml.safe_load(f)

# Create model
model = Brontes(unet_config=config['model'].get('unet_config', {})).to(device)

# Load checkpoint
checkpoint = torch.load('path/to/checkpoint.pt', map_location=device)
model.load_state_dict(checkpoint['model'] if 'model' in checkpoint else checkpoint)
model.eval()

# Load audio
audio, sr = torchaudio.load('input.wav')
target_sr = config['dataset']['sample_rate']

# Resample if necessary
if sr != target_sr:
    resampler = torchaudio.transforms.Resample(sr, target_sr)
    audio = resampler(audio)

# Convert to mono and normalize
if audio.shape[0] > 1:
    audio = audio.mean(dim=0, keepdim=True)
max_val = audio.abs().max()
if max_val > 0:
    audio = audio / max_val

# Add batch dimension and process
audio = audio.unsqueeze(0).to(device)
with torch.no_grad():
    output, _, _, _ = model(audio)

# Save output
output = output.squeeze(0).cpu()
if output.abs().max() > 1.0:
    output = output / output.abs().max()
torchaudio.save('output.wav', output, target_sr)
```
Or use the command-line interface:
```bash
python infer_brontes.py \
--config configs/config_brontes_48khz_demucs.yaml \
--checkpoint path/to/checkpoint.pt \
--input input.wav \
--output output.wav
```
## Training Details
### Training Data
The model was trained on diverse audio data including:
- Clean speech recordings
- Codec-degraded audio pairs
- Various acoustic conditions and speakers
### Training Procedure
- **Pretraining:** 10,000 steps of generator-only training
- **Adversarial training:** Multi-Period Discriminator (MPD) + Multi-Band Spectral Discriminator (MBSD)
- **Loss functions:** Multi-scale mel loss, pitch loss, adversarial loss, feature matching
- **Precision:** BF16 mixed precision
- **Framework:** PyTorch with custom training loop
## Fine-tuning Recommendations
To achieve best results on your specific dataset:
1. **Prepare paired data:** Input (degraded) and target (clean) audio pairs
2. **Use the `--pretrained` flag** to load model weights without optimizer state
3. **Train for 10-50k steps** depending on dataset size
4. **Monitor validation loss** to prevent overfitting
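Step 2 amounts to keeping only the model weights from a full training checkpoint and discarding optimizer state. A minimal sketch of that distinction — the `'model'`/`'optimizer'` key names are assumptions based on the inference example above:

```python
import torch

# A full training checkpoint bundles model weights with optimizer state.
# The 'model'/'optimizer' key names are illustrative assumptions.
full_ckpt = {
    'model': {'weight': torch.zeros(3)},
    'optimizer': {'state': {}, 'param_groups': []},
}

# Loading with --pretrained keeps only the weights; optimizer state is
# dropped so fine-tuning starts with a fresh optimizer and schedule.
state_dict = full_ckpt['model'] if 'model' in full_ckpt else full_ckpt
print(sorted(state_dict))  # ['weight']
```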
See the [repository README](https://github.com/ZDisket/Brontes) for detailed fine-tuning instructions.
## Limitations
- **Domain-specific performance:** General model may not perform optimally on highly specialized audio (fine-tuning recommended)
- **Mono audio only:** Currently supports single-channel audio
- **Fixed sample rates:** Designed for 24 kHz input → 48 kHz output
- **Codec-specific artifacts:** Performance may vary across different codec types
- **Long-form audio:** very long files may exceed GPU memory and require chunked processing
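For the long-form limitation, a common workaround is to process the waveform in overlapping chunks and cross-fade the seams. A sketch of just the index bookkeeping — the 10 s chunk and 1 s overlap at 24 kHz are arbitrary illustrative values, not settings from the repository:

```python
def chunk_bounds(n_samples, chunk=240_000, overlap=24_000):
    """Return (start, end) index pairs covering n_samples with overlap.

    chunk=240_000 is 10 s and overlap=24_000 is 1 s at 24 kHz input;
    both are arbitrary illustrative values.
    """
    step = chunk - overlap
    bounds = []
    start = 0
    while start < n_samples:
        end = min(start + chunk, n_samples)
        bounds.append((start, end))
        if end == n_samples:
            break
        start += step
    return bounds

# One minute at 24 kHz -> 1_440_000 samples:
print(chunk_bounds(1_440_000)[:2])  # [(0, 240000), (216000, 456000)]
```

Each chunk would be run through the model independently, with the 1 s overlaps cross-faded in the 48 kHz output to hide boundary artifacts.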
## Ethical Considerations
- This model is designed for audio enhancement and should not be used to create misleading or deceptive content
- Users should respect privacy and consent when processing speech recordings
- Enhanced audio should be clearly labeled as processed when used in sensitive contexts
## License
Both the model weights and code are released under the MIT License.
## Additional Resources
- **GitHub Repository:** [https://github.com/ZDisket/Brontes](https://github.com/ZDisket/Brontes)
- **Technical Report:** See the repository
- **Issues & Support:** [GitHub Issues](https://github.com/ZDisket/Brontes/issues)
## Acknowledgments
Compute resources provided by Hot Aisle and AI at AMD.