File size: 5,077 Bytes
2bbfbb7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 |
//! Main TTS pipeline orchestration
//!
//! Coordinates text processing, model inference, and audio synthesis
mod synthesis;
pub use synthesis::{IndexTTS, SynthesisOptions, SynthesisResult};
use crate::{Error, Result};
use std::path::{Path, PathBuf};
/// Pipeline stage enumeration
///
/// One variant per phase of the synthesis pipeline, declared in execution
/// order (see [`PipelineStage::all`]). Display labels live in
/// [`PipelineStage::name`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PipelineStage {
    /// "Text Normalization" — raw input text cleanup (see `crate::text`).
    TextNormalization,
    /// "Tokenization" — text to token ids.
    Tokenization,
    /// "Semantic Encoding".
    SemanticEncoding,
    /// "Speaker Conditioning" — reference-speaker embedding, presumably; confirm in synthesis module.
    SpeakerConditioning,
    /// "GPT Generation" — autoregressive model inference.
    GptGeneration,
    /// "Acoustic Expansion".
    AcousticExpansion,
    /// "Vocoding" — acoustic features to waveform.
    Vocoding,
    /// "Post Processing" — final audio cleanup.
    PostProcessing,
}
impl PipelineStage {
/// Get stage name
pub fn name(&self) -> &'static str {
match self {
PipelineStage::TextNormalization => "Text Normalization",
PipelineStage::Tokenization => "Tokenization",
PipelineStage::SemanticEncoding => "Semantic Encoding",
PipelineStage::SpeakerConditioning => "Speaker Conditioning",
PipelineStage::GptGeneration => "GPT Generation",
PipelineStage::AcousticExpansion => "Acoustic Expansion",
PipelineStage::Vocoding => "Vocoding",
PipelineStage::PostProcessing => "Post Processing",
}
}
/// Get all stages in order
pub fn all() -> Vec<PipelineStage> {
vec![
PipelineStage::TextNormalization,
PipelineStage::Tokenization,
PipelineStage::SemanticEncoding,
PipelineStage::SpeakerConditioning,
PipelineStage::GptGeneration,
PipelineStage::AcousticExpansion,
PipelineStage::Vocoding,
PipelineStage::PostProcessing,
]
}
}
/// Pipeline progress callback
///
/// Invoked as `(stage, progress)`. `progress` is presumably a fraction in
/// `[0.0, 1.0]` — not enforced here; TODO confirm against the synthesis
/// module that invokes it. `Send + Sync` so the callback can cross threads.
pub type ProgressCallback = Box<dyn Fn(PipelineStage, f32) + Send + Sync>;
/// Pipeline configuration
///
/// Built via [`Default`] plus the `with_*` builder methods; check with
/// [`PipelineConfig::validate`] before use.
#[derive(Debug, Clone)]
pub struct PipelineConfig {
    /// Model directory (existence is only warned about, not required, by `validate`)
    pub model_dir: PathBuf,
    /// Use FP16 inference
    pub use_fp16: bool,
    /// Device (cpu, cuda:0, etc.)
    pub device: String,
    /// Enable caching
    pub enable_cache: bool,
    /// Maximum text length (must be > 0; see `validate`)
    pub max_text_length: usize,
    /// Maximum audio duration (seconds; must be > 0, see `validate`)
    pub max_audio_duration: f32,
}
impl Default for PipelineConfig {
fn default() -> Self {
Self {
model_dir: PathBuf::from("models"),
use_fp16: false,
device: "cpu".to_string(),
enable_cache: true,
max_text_length: 500,
max_audio_duration: 30.0,
}
}
}
impl PipelineConfig {
    /// Set the model directory (builder-style).
    pub fn with_model_dir<P: AsRef<Path>>(mut self, path: P) -> Self {
        self.model_dir = path.as_ref().to_path_buf();
        self
    }

    /// Enable or disable FP16 inference (builder-style).
    pub fn with_fp16(mut self, enable: bool) -> Self {
        self.use_fp16 = enable;
        self
    }

    /// Set the inference device, e.g. `"cpu"` or `"cuda:0"` (builder-style).
    ///
    /// Generalized to accept anything convertible into a `String`; existing
    /// `&str` callers compile unchanged.
    pub fn with_device<S: Into<String>>(mut self, device: S) -> Self {
        self.device = device.into();
        self
    }

    /// Validate the configuration.
    ///
    /// A missing model directory only logs a warning (so a config can be
    /// constructed before models are downloaded); numeric limits are hard
    /// errors.
    ///
    /// # Errors
    ///
    /// Returns `Error::Config` when `max_text_length` is zero, or when
    /// `max_audio_duration` is not strictly positive. The check is written
    /// as `!(x > 0.0)` so that NaN is rejected too — the previous
    /// `x <= 0.0` form let NaN slip through validation.
    pub fn validate(&self) -> Result<()> {
        if !self.model_dir.exists() {
            log::warn!(
                "Model directory does not exist: {}",
                self.model_dir.display()
            );
        }
        if self.max_text_length == 0 {
            return Err(Error::Config("max_text_length must be > 0".into()));
        }
        // `!(x > 0.0)` is true for 0, negatives, -inf AND NaN.
        if !(self.max_audio_duration > 0.0) {
            return Err(Error::Config("max_audio_duration must be > 0".into()));
        }
        Ok(())
    }
}
/// Text segmentation for long-form synthesis
///
/// Splits `text` into sentence-aligned segments whose length stays within
/// `max_segment_len` where possible. A single sentence longer than the limit
/// still becomes its own (oversized) segment rather than being split
/// mid-sentence or dropped.
///
/// NOTE(review): lengths are byte counts (`str::len`), not character counts,
/// so multi-byte UTF-8 text hits the limit sooner — this matches the prior
/// behavior; confirm callers intend byte budgets.
pub fn segment_text(text: &str, max_segment_len: usize) -> Vec<String> {
    use crate::text::TextNormalizer;
    let normalizer = TextNormalizer::new();
    let sentences = normalizer.split_sentences(text);
    let mut segments = Vec::new();
    let mut current_segment = String::new();
    for sentence in sentences {
        // +1 accounts for the joining space inserted below. The previous
        // check omitted it, so a segment could end up one byte over the
        // limit after the space was pushed.
        let joined_len = current_segment.len() + 1 + sentence.len();
        if joined_len > max_segment_len && !current_segment.is_empty() {
            // Flush the full segment and start a new one with this sentence.
            segments.push(current_segment.trim().to_string());
            current_segment = sentence;
        } else {
            if !current_segment.is_empty() {
                current_segment.push(' ');
            }
            current_segment.push_str(&sentence);
        }
    }
    // Flush the trailing partial segment, if any.
    if !current_segment.trim().is_empty() {
        segments.push(current_segment.trim().to_string());
    }
    segments
}
/// Concatenate audio segments with silence
///
/// Inserts `silence_duration_ms` of zero samples (at `sample_rate`) between
/// consecutive segments — not before the first or after the last. An empty
/// `segments` slice yields an empty vector.
pub fn concatenate_audio(segments: &[Vec<f32>], silence_duration_ms: u32, sample_rate: u32) -> Vec<f32> {
    let silence_samples = (silence_duration_ms as usize * sample_rate as usize) / 1000;
    // Preallocate exactly: all audio plus one silence gap per boundary.
    // saturating_sub keeps the empty-slice case at zero gaps.
    let total: usize = segments.iter().map(Vec::len).sum::<usize>()
        + silence_samples * segments.len().saturating_sub(1);
    let mut result = Vec::with_capacity(total);
    for (i, segment) in segments.iter().enumerate() {
        result.extend_from_slice(segment);
        // `i + 1 < len` avoids the `len - 1` usize-underflow idiom.
        if i + 1 < segments.len() {
            result.resize(result.len() + silence_samples, 0.0);
        }
    }
    result
}
/// Estimate synthesis duration
///
/// Rough heuristic: number of characters divided by the expected speaking
/// rate `chars_per_second`. Returns seconds; a non-positive rate yields
/// `inf`/`NaN` per IEEE-754 float division, as before.
pub fn estimate_duration(text: &str, chars_per_second: f32) -> f32 {
    let char_count = text.chars().count();
    char_count as f32 / chars_per_second
}
|