File size: 5,077 Bytes
2bbfbb7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
//! Main TTS pipeline orchestration
//!
//! Coordinates text processing, model inference, and audio synthesis

mod synthesis;

pub use synthesis::{IndexTTS, SynthesisOptions, SynthesisResult};

use crate::{Error, Result};
use std::path::{Path, PathBuf};

/// Pipeline stage enumeration
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PipelineStage {
    TextNormalization,
    Tokenization,
    SemanticEncoding,
    SpeakerConditioning,
    GptGeneration,
    AcousticExpansion,
    Vocoding,
    PostProcessing,
}

impl PipelineStage {
    /// Human-readable display name for this stage.
    pub fn name(&self) -> &'static str {
        match self {
            Self::TextNormalization => "Text Normalization",
            Self::Tokenization => "Tokenization",
            Self::SemanticEncoding => "Semantic Encoding",
            Self::SpeakerConditioning => "Speaker Conditioning",
            Self::GptGeneration => "GPT Generation",
            Self::AcousticExpansion => "Acoustic Expansion",
            Self::Vocoding => "Vocoding",
            Self::PostProcessing => "Post Processing",
        }
    }

    /// Every stage, listed in pipeline execution order.
    pub fn all() -> Vec<PipelineStage> {
        [
            Self::TextNormalization,
            Self::Tokenization,
            Self::SemanticEncoding,
            Self::SpeakerConditioning,
            Self::GptGeneration,
            Self::AcousticExpansion,
            Self::Vocoding,
            Self::PostProcessing,
        ]
        .to_vec()
    }
}

/// Pipeline progress callback, invoked with the current [`PipelineStage`]
/// and a progress value (f32; presumably a 0.0..=1.0 fraction — confirm
/// against the invoking code). Boxed and `Send + Sync` so it can be stored
/// and called across threads.
pub type ProgressCallback = Box<dyn Fn(PipelineStage, f32) + Send + Sync>;

/// Pipeline configuration
#[derive(Debug, Clone)]
pub struct PipelineConfig {
    /// Directory containing the model files
    pub model_dir: PathBuf,
    /// Run inference in half precision (FP16)
    pub use_fp16: bool,
    /// Inference device identifier (e.g. "cpu", "cuda:0")
    pub device: String,
    /// Whether result caching is enabled
    pub enable_cache: bool,
    /// Upper bound on input text length (units — chars vs bytes — depend on
    /// the consuming code; confirm at use site)
    pub max_text_length: usize,
    /// Upper bound on produced audio duration, in seconds
    pub max_audio_duration: f32,
}

impl Default for PipelineConfig {
    /// Conservative defaults: FP32 on CPU, caching enabled,
    /// text capped at 500 and audio at 30 seconds.
    fn default() -> Self {
        PipelineConfig {
            model_dir: "models".into(),
            use_fp16: false,
            device: String::from("cpu"),
            enable_cache: true,
            max_text_length: 500,
            max_audio_duration: 30.0,
        }
    }
}

impl PipelineConfig {
    /// Set the model directory (builder style).
    pub fn with_model_dir<P: AsRef<Path>>(mut self, path: P) -> Self {
        self.model_dir = path.as_ref().to_path_buf();
        self
    }

    /// Enable or disable FP16 inference (builder style).
    pub fn with_fp16(mut self, enable: bool) -> Self {
        self.use_fp16 = enable;
        self
    }

    /// Set the inference device, e.g. "cpu" or "cuda:0" (builder style).
    ///
    /// Accepts anything convertible into a `String`, so both `&str` and
    /// owned `String` callers work (backward compatible with the previous
    /// `&str` signature).
    pub fn with_device(mut self, device: impl Into<String>) -> Self {
        self.device = device.into();
        self
    }

    /// Validate configuration values.
    ///
    /// A missing model directory only produces a warning log, not an error.
    ///
    /// # Errors
    /// Returns `Error::Config` when `max_text_length` is zero or when
    /// `max_audio_duration` is not strictly positive.
    pub fn validate(&self) -> Result<()> {
        if !self.model_dir.exists() {
            log::warn!(
                "Model directory does not exist: {}",
                self.model_dir.display()
            );
        }

        if self.max_text_length == 0 {
            return Err(Error::Config("max_text_length must be > 0".into()));
        }

        if self.max_audio_duration <= 0.0 {
            return Err(Error::Config("max_audio_duration must be > 0".into()));
        }

        Ok(())
    }
}

/// Text segmentation for long-form synthesis.
///
/// Splits `text` into sentences (via the project's `TextNormalizer`) and
/// greedily packs consecutive sentences into segments of at most
/// `max_segment_len` bytes. A single sentence longer than the limit still
/// becomes its own segment. Whitespace-trimmed; empty trailing content is
/// dropped.
pub fn segment_text(text: &str, max_segment_len: usize) -> Vec<String> {
    use crate::text::TextNormalizer;

    let normalizer = TextNormalizer::new();
    let sentences = normalizer.split_sentences(text);

    let mut segments = Vec::new();
    let mut current_segment = String::new();

    for sentence in sentences {
        // Account for the joining space we would insert; the original check
        // omitted it, letting segments exceed the limit by one byte.
        let sep_len = if current_segment.is_empty() { 0 } else { 1 };
        if !current_segment.is_empty()
            && current_segment.len() + sep_len + sentence.len() > max_segment_len
        {
            // Current segment is full — flush it and start a new one.
            segments.push(current_segment.trim().to_string());
            current_segment = sentence;
        } else {
            if sep_len == 1 {
                current_segment.push(' ');
            }
            current_segment.push_str(&sentence);
        }
    }

    // Flush whatever remains, skipping pure whitespace.
    if !current_segment.trim().is_empty() {
        segments.push(current_segment.trim().to_string());
    }

    segments
}

/// Concatenate audio segments, inserting `silence_duration_ms` of silence
/// (zero samples at `sample_rate`) between consecutive segments.
///
/// Returns an empty vector for an empty slice; a single segment is returned
/// unchanged (no trailing silence is appended).
pub fn concatenate_audio(segments: &[Vec<f32>], silence_duration_ms: u32, sample_rate: u32) -> Vec<f32> {
    let silence_samples = (silence_duration_ms as usize * sample_rate as usize) / 1000;

    // The exact output size is known up front — preallocate instead of
    // repeatedly growing the vector.
    let total_len: usize = segments.iter().map(Vec::len).sum::<usize>()
        + silence_samples * segments.len().saturating_sub(1);
    let mut result = Vec::with_capacity(total_len);

    for (i, segment) in segments.iter().enumerate() {
        if i > 0 {
            // Gap goes before every segment except the first.
            result.resize(result.len() + silence_samples, 0.0f32);
        }
        result.extend_from_slice(segment);
    }

    result
}

/// Estimate synthesis duration in seconds from the character count of
/// `text` and an assumed speaking rate of `chars_per_second`.
pub fn estimate_duration(text: &str, chars_per_second: f32) -> f32 {
    let char_count = text.chars().count();
    char_count as f32 / chars_per_second
}