| | |
| |
|
| | use crate::{Error, Result}; |
| | use serde::{Deserialize, Serialize}; |
| | use std::path::{Path, PathBuf}; |
| |
|
| | |
/// Top-level configuration for the synthesis pipeline, aggregating all
/// per-component settings. Serialized/deserialized with serde (YAML via
/// `Config::load` / `Config::save`, JSON via `Config::load_json`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
    /// Autoregressive GPT model architecture and token settings.
    pub gpt: GptConfig,
    /// Vocoder (mel -> waveform) selection and runtime options.
    pub vocoder: VocoderConfig,
    /// Semantic-to-mel model checkpoint and audio preprocessing settings.
    pub s2mel: S2MelConfig,
    /// Tokenizer / dataset settings (BPE model, vocabulary size).
    pub dataset: DatasetConfig,
    /// Emotion-conditioning settings.
    pub emotions: EmotionConfig,
    /// Inference/runtime settings (device, precision, sampling).
    pub inference: InferenceConfig,
    /// Root directory where model files are expected to live.
    pub model_dir: PathBuf,
}
| |
|
| | |
/// Architecture and token-space parameters for the autoregressive GPT model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GptConfig {
    /// Number of transformer layers.
    pub layers: usize,
    /// Hidden/model dimension; validated to be divisible by `heads`.
    pub model_dim: usize,
    /// Number of attention heads.
    pub heads: usize,
    /// Maximum number of text tokens accepted per input.
    pub max_text_tokens: usize,
    /// Maximum number of mel tokens generated per output.
    pub max_mel_tokens: usize,
    /// Token id that terminates mel-token generation.
    pub stop_mel_token: usize,
    /// Token id prepended to the text-token sequence.
    pub start_text_token: usize,
    /// Token id prepended to the mel-token sequence.
    pub start_mel_token: usize,
    /// Size of the mel-token codebook (includes special tokens).
    pub num_mel_codes: usize,
    /// Size of the text-token vocabulary.
    pub num_text_tokens: usize,
}
| |
|
| | |
/// Vocoder selection and runtime options.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VocoderConfig {
    /// Vocoder model identifier (e.g. a BigVGAN variant name).
    pub name: String,
    /// Optional explicit checkpoint path; `None` presumably falls back to a
    /// default resolved from `name` — confirm against the loader.
    pub checkpoint: Option<PathBuf>,
    /// Run the vocoder in half precision.
    pub use_fp16: bool,
    /// Enable DeepSpeed acceleration for the vocoder.
    pub use_deepspeed: bool,
}
| |
|
| | |
/// Semantic-to-mel model settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct S2MelConfig {
    /// Path to the s2mel model checkpoint (ONNX by default).
    pub checkpoint: PathBuf,
    /// Audio preprocessing (STFT/mel) parameters used with this model.
    pub preprocess: PreprocessConfig,
}
| |
|
| | |
/// STFT / mel-spectrogram extraction parameters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PreprocessConfig {
    /// Audio sample rate in Hz.
    pub sr: u32,
    /// FFT size in samples.
    pub n_fft: usize,
    /// Hop length between successive frames, in samples.
    pub hop_length: usize,
    /// Analysis window length in samples.
    pub win_length: usize,
    /// Number of mel filterbank channels.
    pub n_mels: usize,
    /// Lower frequency bound of the mel filterbank, in Hz.
    pub fmin: f32,
    /// Upper frequency bound of the mel filterbank, in Hz.
    pub fmax: f32,
}
| |
|
| | |
/// Tokenizer / dataset settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetConfig {
    /// Path to the BPE tokenizer model file.
    pub bpe_model: PathBuf,
    /// Tokenizer vocabulary size (matches `GptConfig::num_text_tokens`
    /// in the defaults).
    pub vocab_size: usize,
}
| |
|
| | |
/// Emotion-conditioning settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmotionConfig {
    /// Number of emotion dimensions.
    pub num_dims: usize,
    /// Per-dimension counts; expected to hold `num_dims` entries
    /// (the defaults provide 8 entries for 8 dims).
    pub num: Vec<usize>,
    /// Optional path to a precomputed emotion matrix (safetensors by default).
    pub matrix_path: Option<PathBuf>,
}
| |
|
| | |
/// Runtime and sampling parameters used at inference time.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferenceConfig {
    /// Compute device identifier (e.g. "cpu").
    pub device: String,
    /// Run inference in half precision.
    pub use_fp16: bool,
    /// Number of items processed per batch.
    pub batch_size: usize,
    /// Top-k sampling cutoff.
    pub top_k: usize,
    /// Nucleus (top-p) sampling threshold; validated to lie in (0, 1].
    pub top_p: f32,
    /// Softmax temperature; validated to be > 0.
    pub temperature: f32,
    /// Penalty applied to repeated tokens during sampling.
    pub repetition_penalty: f32,
    /// Penalty applied per generated token to bias sequence length.
    pub length_penalty: f32,
}
| |
|
| | impl Default for Config { |
| | fn default() -> Self { |
| | Self { |
| | gpt: GptConfig::default(), |
| | vocoder: VocoderConfig::default(), |
| | s2mel: S2MelConfig::default(), |
| | dataset: DatasetConfig::default(), |
| | emotions: EmotionConfig::default(), |
| | inference: InferenceConfig::default(), |
| | model_dir: PathBuf::from("models"), |
| | } |
| | } |
| | } |
| |
|
| | impl Default for GptConfig { |
| | fn default() -> Self { |
| | Self { |
| | layers: 8, |
| | model_dim: 512, |
| | heads: 8, |
| | max_text_tokens: 120, |
| | max_mel_tokens: 250, |
| | stop_mel_token: 8193, |
| | start_text_token: 8192, |
| | start_mel_token: 8192, |
| | num_mel_codes: 8194, |
| | num_text_tokens: 6681, |
| | } |
| | } |
| | } |
| |
|
| | impl Default for VocoderConfig { |
| | fn default() -> Self { |
| | Self { |
| | name: "bigvgan_v2_22khz_80band_256x".into(), |
| | checkpoint: None, |
| | use_fp16: true, |
| | use_deepspeed: false, |
| | } |
| | } |
| | } |
| |
|
| | impl Default for S2MelConfig { |
| | fn default() -> Self { |
| | Self { |
| | checkpoint: PathBuf::from("models/s2mel.onnx"), |
| | preprocess: PreprocessConfig::default(), |
| | } |
| | } |
| | } |
| |
|
| | impl Default for PreprocessConfig { |
| | fn default() -> Self { |
| | Self { |
| | sr: 22050, |
| | n_fft: 1024, |
| | hop_length: 256, |
| | win_length: 1024, |
| | n_mels: 80, |
| | fmin: 0.0, |
| | fmax: 8000.0, |
| | } |
| | } |
| | } |
| |
|
| | impl Default for DatasetConfig { |
| | fn default() -> Self { |
| | Self { |
| | bpe_model: PathBuf::from("models/bpe.model"), |
| | vocab_size: 6681, |
| | } |
| | } |
| | } |
| |
|
| | impl Default for EmotionConfig { |
| | fn default() -> Self { |
| | Self { |
| | num_dims: 8, |
| | num: vec![5, 6, 8, 6, 5, 4, 7, 6], |
| | matrix_path: Some(PathBuf::from("models/emotion_matrix.safetensors")), |
| | } |
| | } |
| | } |
| |
|
| | impl Default for InferenceConfig { |
| | fn default() -> Self { |
| | Self { |
| | device: "cpu".into(), |
| | use_fp16: false, |
| | batch_size: 1, |
| | top_k: 50, |
| | top_p: 0.95, |
| | temperature: 1.0, |
| | repetition_penalty: 1.0, |
| | length_penalty: 1.0, |
| | } |
| | } |
| | } |
| |
|
| | impl Config { |
| | |
| | pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> { |
| | let path = path.as_ref(); |
| | if !path.exists() { |
| | return Err(Error::FileNotFound(path.display().to_string())); |
| | } |
| |
|
| | let content = std::fs::read_to_string(path)?; |
| | let config: Config = serde_yaml::from_str(&content)?; |
| | Ok(config) |
| | } |
| |
|
| | |
| | pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> { |
| | let content = serde_yaml::to_string(self) |
| | .map_err(|e| Error::Config(format!("Failed to serialize config: {}", e)))?; |
| | std::fs::write(path, content)?; |
| | Ok(()) |
| | } |
| |
|
| | |
| | pub fn load_json<P: AsRef<Path>>(path: P) -> Result<Self> { |
| | let path = path.as_ref(); |
| | if !path.exists() { |
| | return Err(Error::FileNotFound(path.display().to_string())); |
| | } |
| |
|
| | let content = std::fs::read_to_string(path)?; |
| | let config: Config = serde_json::from_str(&content)?; |
| | Ok(config) |
| | } |
| |
|
| | |
| | pub fn create_default<P: AsRef<Path>>(path: P) -> Result<Self> { |
| | let config = Config::default(); |
| | config.save(path)?; |
| | Ok(config) |
| | } |
| |
|
| | |
| | pub fn validate(&self) -> Result<()> { |
| | |
| | if !self.model_dir.exists() { |
| | log::warn!( |
| | "Model directory does not exist: {}", |
| | self.model_dir.display() |
| | ); |
| | } |
| |
|
| | |
| | if self.gpt.layers == 0 { |
| | return Err(Error::Config("GPT layers must be > 0".into())); |
| | } |
| | if self.gpt.model_dim == 0 { |
| | return Err(Error::Config("GPT model_dim must be > 0".into())); |
| | } |
| | if self.gpt.heads == 0 { |
| | return Err(Error::Config("GPT heads must be > 0".into())); |
| | } |
| | if self.gpt.model_dim % self.gpt.heads != 0 { |
| | return Err(Error::Config( |
| | "GPT model_dim must be divisible by heads".into(), |
| | )); |
| | } |
| |
|
| | |
| | if self.s2mel.preprocess.sr == 0 { |
| | return Err(Error::Config("Sample rate must be > 0".into())); |
| | } |
| | if self.s2mel.preprocess.n_fft == 0 { |
| | return Err(Error::Config("n_fft must be > 0".into())); |
| | } |
| | if self.s2mel.preprocess.hop_length == 0 { |
| | return Err(Error::Config("hop_length must be > 0".into())); |
| | } |
| |
|
| | |
| | if self.inference.temperature <= 0.0 { |
| | return Err(Error::Config("Temperature must be > 0".into())); |
| | } |
| | if self.inference.top_p <= 0.0 || self.inference.top_p > 1.0 { |
| | return Err(Error::Config("top_p must be in (0, 1]".into())); |
| | } |
| |
|
| | Ok(()) |
| | } |
| | } |
| |
|