| | |
| | |
| | |
| |
|
| | use crate::{Error, Result}; |
| | use ndarray::{Array, Array2, IxDyn}; |
| | use std::collections::HashMap; |
| | use std::path::Path; |
| |
|
| | use crate::model::OnnxSession; |
| | use super::{Vocoder, snake_activation_vec}; |
| |
|
| | |
/// Hyper-parameters describing a BigVGAN generator.
///
/// Within this module only `sample_rate`, `num_mels` and `upsample_rates` (via
/// [`BigVGANConfig::hop_length`]) are read by the fallback synthesizer; the other fields
/// presumably mirror the architecture of the exported model — confirm against the exporter.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BigVGANConfig {
    /// Output audio sample rate in Hz.
    pub sample_rate: u32,
    /// Number of mel bins (rows of the input spectrogram).
    pub num_mels: usize,
    /// Per-stage upsampling factors; their product is the hop length.
    pub upsample_rates: Vec<usize>,
    /// Per-stage upsampling kernel sizes (not used by the fallback path).
    pub upsample_kernel_sizes: Vec<usize>,
    /// Residual-block kernel sizes (not used by the fallback path).
    pub resblock_kernel_sizes: Vec<usize>,
    /// Dilations for each residual block (not used by the fallback path).
    pub resblock_dilation_sizes: Vec<Vec<usize>>,
    /// Channel count of the first upsampling stage (not used by the fallback path).
    pub upsample_initial_channel: usize,
    /// Anti-aliasing flag (not used by the fallback path).
    pub use_anti_alias: bool,
}

impl Default for BigVGANConfig {
    /// 22 050 Hz, 80-bin configuration with upsampling rates `[8, 8, 2, 2]` (hop length 256).
    fn default() -> Self {
        Self {
            sample_rate: 22050,
            num_mels: 80,
            upsample_rates: vec![8, 8, 2, 2],
            upsample_kernel_sizes: vec![16, 16, 4, 4],
            resblock_kernel_sizes: vec![3, 7, 11],
            resblock_dilation_sizes: vec![vec![1, 3, 5]; 3],
            upsample_initial_channel: 512,
            use_anti_alias: true,
        }
    }
}

impl BigVGANConfig {
    /// Product of all `upsample_rates` (1 when the list is empty).
    pub fn total_upsample_factor(&self) -> usize {
        self.upsample_rates.iter().product()
    }

    /// Number of output samples per mel frame; equal to [`Self::total_upsample_factor`].
    pub fn hop_length(&self) -> usize {
        self.total_upsample_factor()
    }
}
| |
|
| | |
/// Mel-spectrogram-to-waveform vocoder backed by an optional ONNX session.
///
/// When `session` is `None` (see [`BigVGAN::new_fallback`]), synthesis uses a procedural
/// fallback instead of running a model.
pub struct BigVGAN {
    /// Loaded ONNX model; `None` selects the procedural fallback synthesizer.
    session: Option<OnnxSession>,
    /// Hyper-parameters (sample rate, mel bins, hop length, ...).
    config: BigVGANConfig,
}
| |
|
| | impl BigVGAN { |
| | |
| | pub fn load<P: AsRef<Path>>(path: P, config: BigVGANConfig) -> Result<Self> { |
| | let session = OnnxSession::load(path)?; |
| | Ok(Self { |
| | session: Some(session), |
| | config, |
| | }) |
| | } |
| |
|
| | |
| | pub fn new_fallback(config: BigVGANConfig) -> Self { |
| | Self { |
| | session: None, |
| | config, |
| | } |
| | } |
| |
|
| | |
| | pub fn config(&self) -> &BigVGANConfig { |
| | &self.config |
| | } |
| |
|
| | |
| | fn synthesize_fallback(&self, mel: &Array2<f32>) -> Result<Vec<f32>> { |
| | |
| | let num_frames = mel.ncols(); |
| | let hop_length = self.config.hop_length(); |
| | let frame_size = hop_length * 4; |
| |
|
| | let output_length = (num_frames - 1) * hop_length + frame_size; |
| | let mut output = vec![0.0f32; output_length]; |
| | let mut window_sum = vec![0.0f32; output_length]; |
| |
|
| | |
| | let window: Vec<f32> = (0..frame_size) |
| | .map(|n| { |
| | 0.5 * (1.0 - (2.0 * std::f32::consts::PI * n as f32 / frame_size as f32).cos()) |
| | }) |
| | .collect(); |
| |
|
| | |
| | for frame_idx in 0..num_frames { |
| | let start = frame_idx * hop_length; |
| |
|
| | |
| | let mel_frame: Vec<f32> = (0..self.config.num_mels) |
| | .map(|i| mel[[i, frame_idx]]) |
| | .collect(); |
| |
|
| | |
| | let frame = self.generate_frame(&mel_frame, frame_size); |
| |
|
| | |
| | for i in 0..frame_size { |
| | if start + i < output_length { |
| | output[start + i] += frame[i] * window[i]; |
| | window_sum[start + i] += window[i] * window[i]; |
| | } |
| | } |
| | } |
| |
|
| | |
| | for i in 0..output_length { |
| | if window_sum[i] > 1e-8 { |
| | output[i] /= window_sum[i]; |
| | } |
| | } |
| |
|
| | |
| | let output = snake_activation_vec(&output, 0.3); |
| |
|
| | Ok(output) |
| | } |
| |
|
| | |
| | fn generate_frame(&self, mel: &[f32], frame_size: usize) -> Vec<f32> { |
| | use rand::Rng; |
| | let mut rng = rand::thread_rng(); |
| |
|
| | |
| | let energy: f32 = mel.iter().map(|x| x.exp()).sum::<f32>() / mel.len() as f32; |
| | let energy = energy.sqrt().min(2.0); |
| |
|
| | |
| | let mut frame = vec![0.0f32; frame_size]; |
| |
|
| | |
| | for (freq_idx, &mel_val) in mel.iter().enumerate() { |
| | let freq = (freq_idx as f32 / mel.len() as f32) * (self.config.sample_rate as f32 / 2.0); |
| | let amplitude = mel_val.exp().min(1.0) * 0.1; |
| |
|
| | |
| | for i in 0..frame_size { |
| | let t = i as f32 / self.config.sample_rate as f32; |
| | frame[i] += amplitude * (2.0 * std::f32::consts::PI * freq * t).sin(); |
| | } |
| | } |
| |
|
| | |
| | for i in 0..frame_size { |
| | frame[i] += rng.gen_range(-0.1..0.1) * energy * 0.1; |
| | } |
| |
|
| | |
| | let max_abs = frame.iter().map(|x| x.abs()).fold(0.0f32, f32::max); |
| | if max_abs > 1.0 { |
| | for v in frame.iter_mut() { |
| | *v /= max_abs; |
| | } |
| | } |
| |
|
| | frame |
| | } |
| |
|
| | |
| | pub fn post_process(&self, audio: &[f32]) -> Vec<f32> { |
| | use crate::audio::{normalize_audio, apply_fade}; |
| |
|
| | let normalized = normalize_audio(audio); |
| |
|
| | |
| | let fade_samples = (self.config.sample_rate as f32 * 0.01) as usize; |
| | apply_fade(&normalized, fade_samples, fade_samples) |
| | } |
| | } |
| |
|
| | impl Vocoder for BigVGAN { |
| | fn synthesize(&self, mel: &Array2<f32>) -> Result<Vec<f32>> { |
| | if let Some(ref session) = self.session { |
| | |
| | let input = mel.clone().into_shape(IxDyn(&[1, mel.nrows(), mel.ncols()]))?; |
| |
|
| | let mut inputs = HashMap::new(); |
| | inputs.insert("mel".to_string(), input); |
| |
|
| | let outputs = session.run(inputs)?; |
| |
|
| | let audio = outputs |
| | .get("audio") |
| | .ok_or_else(|| Error::Vocoder("Missing audio output".into()))?; |
| |
|
| | |
| | let samples: Vec<f32> = audio.iter().cloned().collect(); |
| |
|
| | Ok(self.post_process(&samples)) |
| | } else { |
| | |
| | let audio = self.synthesize_fallback(mel)?; |
| | Ok(self.post_process(&audio)) |
| | } |
| | } |
| |
|
| | fn sample_rate(&self) -> u32 { |
| | self.config.sample_rate |
| | } |
| |
|
| | fn hop_length(&self) -> usize { |
| | self.config.hop_length() |
| | } |
| | } |
| |
|
| | |
| | pub fn create_bigvgan_22k() -> BigVGAN { |
| | let config = BigVGANConfig { |
| | sample_rate: 22050, |
| | ..Default::default() |
| | }; |
| | BigVGAN::new_fallback(config) |
| | } |
| |
|
| | |
| | pub fn create_bigvgan_24k() -> BigVGAN { |
| | let config = BigVGANConfig { |
| | sample_rate: 24000, |
| | upsample_rates: vec![12, 10, 2, 2], |
| | ..Default::default() |
| | }; |
| | BigVGAN::new_fallback(config) |
| | } |
| |
|
#[cfg(test)]
mod tests {
    use super::*;

    /// The default config upsamples by 8 * 8 * 2 * 2 = 256.
    #[test]
    fn test_bigvgan_config() {
        let cfg = BigVGANConfig::default();
        assert_eq!(cfg.total_upsample_factor(), 256);
        assert_eq!(cfg.hop_length(), 256);
    }

    /// The fallback path must produce non-empty audio for an all-zero 80 x 10 mel.
    #[test]
    fn test_bigvgan_fallback() {
        let vocoder = create_bigvgan_22k();
        assert_eq!(vocoder.sample_rate(), 22050);

        let mel = Array2::zeros((80, 10));
        let audio = vocoder
            .synthesize(&mel)
            .expect("fallback synthesis should succeed");
        assert!(!audio.is_empty());
    }

    /// `generate_frame` returns exactly `frame_size` samples.
    #[test]
    fn test_generate_frame() {
        let vocoder = create_bigvgan_22k();
        let silent_column = vec![0.0f32; 80];
        let frame = vocoder.generate_frame(&silent_column, 256);
        assert_eq!(frame.len(), 256);
    }

    /// Post-processing keeps the length, and the fade-in attenuates the first sample.
    #[test]
    fn test_post_process() {
        let vocoder = create_bigvgan_22k();
        let input = vec![0.5f32; 1000];
        let processed = vocoder.post_process(&input);
        assert_eq!(processed.len(), input.len());
        assert!(processed[0].abs() < 0.1);
    }
}
| |
|