File size: 10,403 Bytes
79f7931
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
278e294
 
 
 
 
 
 
79f7931
 
 
 
 
 
278e294
 
 
 
 
79f7931
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
"""
Configuration module for Speech Pathology MVP.

This module contains all configuration dataclasses for the speech pathology
diagnosis system, including audio processing, model inference, API, and training settings.
"""

from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class AudioConfig:
    """Audio processing and Voice Activity Detection (VAD) settings.

    Defaults target 16 kHz speech analyzed in 20 ms chunks, which matches
    typical phone durations and the front ends of standard speech models.
    """
    # Target sample rate in Hz; 16 kHz is optimal for speech-recognition models.
    sample_rate: int = 16000
    # Duration of each processed audio chunk in ms; 20 ms suits phone-level analysis.
    chunk_duration_ms: int = 20
    # VAD aggressiveness mode, 0 (least aggressive, fewer false positives)
    # through 3 (most aggressive, fewer false negatives; may admit noise).
    vad_aggressiveness: int = 3
    # Analysis frame length in ms; kept equal to chunk_duration_ms for consistency.
    frame_length_ms: int = 20
    # Hop between successive frames in ms; overlap improves temporal resolution.
    hop_length_ms: int = 10
    # FFT size for spectral analysis; a power of two (512 or 1024 for 16 kHz audio).
    n_fft: int = 512


@dataclass
class ModelConfig:
    """Wav2Vec2 backbone and classifier-head settings.

    The default backbone is the multilingual XLSR-53 large model; its
    1024-dim features feed a small MLP head that emits two scores
    (articulation and fluency).
    """
    # HuggingFace identifier of the pretrained Wav2Vec2 backbone.
    model_name: str = "facebook/wav2vec2-large-xlsr-53"
    # Hidden sizes of the classifier MLP: input 1024 -> 256 -> 128 -> num_labels.
    classifier_hidden_dims: List[int] = field(default_factory=lambda: [256, 128])
    # Dropout probability on classifier layers (regularization during training).
    dropout: float = 0.1
    # Number of outputs: one articulation score, one fluency score.
    num_labels: int = 2
    # Inference device; callers substitute "cpu" when CUDA is unavailable.
    device: str = "cuda"
    # Half-precision inference: faster, needs CUDA, may cost a little accuracy.
    use_fp16: bool = False
    # Longer inputs are truncated to this many samples (~15.6 s at 16 kHz).
    max_length: int = 250000


@dataclass
class InferenceConfig:
    """
    Configuration for inference thresholds and post-processing.

    Attributes:
        fluency_threshold: Threshold for fluency classification (0.0-1.0).
                          Scores above this indicate fluent speech.
                          Lower values = more sensitive (detects more issues).
        articulation_threshold: Threshold for articulation classification (0.0-1.0).
                               Scores above this indicate clear articulation.
                               Lower values = more sensitive (detects more issues).
        min_chunk_duration_ms: Minimum duration of a chunk to be analyzed.
                              Filters out very short audio segments.
        smoothing_window: Window size for temporal smoothing of predictions.
                         Reduces jitter in frame-level predictions.
        batch_size: Number of chunks to process in parallel during inference.
                   Higher values = faster but more memory usage.
        window_size_ms: Size of sliding window in milliseconds (default: 1000ms = 1 second).
                       Minimum for Wav2Vec2 stability.
        hop_size_ms: Hop size between windows in milliseconds (default: 10ms).
                    Controls temporal resolution (100 frames/second).
        frame_rate: Frames per second. Always derived from hop_size_ms in
                   __post_init__ (1000 / hop_size_ms), so it cannot drift
                   out of sync when hop_size_ms is overridden.
        minimum_audio_length: Minimum audio length in seconds. Validated in
                             __post_init__ to be >= window_size_ms.
        phone_level_strategy: Strategy for phone-level analysis ("sliding_window").

    Raises:
        ValueError: On construction, if hop_size_ms is not positive or if
                    minimum_audio_length (seconds) is shorter than
                    window_size_ms (milliseconds).
    """
    fluency_threshold: float = 0.5
    articulation_threshold: float = 0.6
    min_chunk_duration_ms: int = 10
    smoothing_window: int = 5
    batch_size: int = 32
    window_size_ms: int = 1000  # 1 second minimum for Wav2Vec2
    hop_size_ms: int = 10  # 10ms for phone-level resolution
    frame_rate: float = 100.0  # Derived in __post_init__; default matches hop_size_ms=10
    minimum_audio_length: float = 1.0  # Seconds; must cover one full window
    phone_level_strategy: str = "sliding_window"

    def __post_init__(self) -> None:
        """Derive dependent fields and enforce documented invariants."""
        if self.hop_size_ms <= 0:
            raise ValueError(f"hop_size_ms must be positive, got {self.hop_size_ms}")
        # frame_rate is documented as "calculated from hop_size_ms"; derive it
        # here so the two can never be inconsistent (previously a stale 100.0
        # was kept even when hop_size_ms was overridden).
        self.frame_rate = 1000.0 / self.hop_size_ms
        # The sliding-window strategy needs at least one full window of audio.
        if self.minimum_audio_length * 1000.0 < self.window_size_ms:
            raise ValueError(
                f"minimum_audio_length ({self.minimum_audio_length}s) must be >= "
                f"window_size_ms ({self.window_size_ms}ms)"
            )


@dataclass
class APIConfig:
    """Settings for the FastAPI REST server and the WebSocket streaming server.

    REST and WebSocket traffic run on separate ports so each can be
    scaled independently.
    """
    # Bind address; "0.0.0.0" accepts external connections.
    host: str = "0.0.0.0"
    # REST API port.
    port: int = 8000
    # WebSocket streaming port (separate from the REST port).
    websocket_port: int = 8001
    # Reject uploads above this size (MB) to avoid memory exhaustion.
    max_file_size_mb: int = 100
    # Audio file extensions accepted for upload.
    allowed_extensions: List[str] = field(default_factory=lambda: [".wav", ".mp3", ".flac", ".m4a"])
    # Allowed CORS origins; ["*"] is permissive — tighten for production.
    cors_origins: List[str] = field(default_factory=lambda: ["*"])
    # Per-request timeout in seconds; keeps hung requests from blocking the server.
    request_timeout: int = 300
    # Cap on concurrent WebSocket connections.
    max_connections: int = 100
    # Size of WebSocket message chunks in bytes.
    chunk_size: int = 4096


@dataclass
class GradioConfig:
    """Settings for the optional Gradio web interface.

    The UI runs on its own port, separate from the REST API.
    """
    # Whether to start the Gradio interface at all.
    enabled: bool = True
    # Gradio server port (distinct from the REST API port).
    port: int = 7860
    # Create a public share link; handy for testing but exposes the UI publicly.
    share: bool = False
    # Theme name (or custom theme configuration).
    theme: str = "default"
    # Title shown at the top of the interface.
    title: str = "Speech Pathology Diagnosis"
    # Description text shown in the interface.
    description: str = "Upload audio for articulation and fluency analysis"
    # Example audio file paths offered in the UI.
    examples: List[str] = field(default_factory=list)


@dataclass
class TrainingConfig:
    """Hyperparameters for training a custom classifier.

    Defaults follow common fine-tuning practice: a small learning rate
    with warmup, periodic evaluation and checkpointing, and optional
    mixed-precision training.
    """
    # Initial optimizer learning rate; 1e-5..1e-4 works well for fine-tuning.
    learning_rate: float = 5e-5
    # Samples per optimizer step (larger = more stable gradients, more memory).
    batch_size: int = 16
    # Total passes over the training data; early stopping guards overfitting.
    num_epochs: int = 10
    # L2 regularization strength (penalizes large weights).
    weight_decay: float = 0.01
    # Steps over which the LR ramps from 0 up to learning_rate.
    warmup_steps: int = 500
    # Run evaluation every this many steps.
    eval_steps: int = 500
    # Write a checkpoint every this many steps.
    save_steps: int = 1000
    # Where trained models and checkpoints are saved.
    output_dir: str = "./models/trained"
    # Where training logs (TensorBoard) are written.
    logging_dir: str = "./logs"
    # Effective batch size = batch_size * gradient_accumulation_steps.
    gradient_accumulation_steps: int = 1
    # Mixed-precision training; faster on modern GPUs.
    fp16: bool = False
    # Worker processes for parallel data loading/preprocessing.
    dataloader_num_workers: int = 4


@dataclass
class LoggingConfig:
    """Settings for the application logging system.

    Covers the log level, message format, and optional size-based
    rotation of a log file.
    """
    # Minimum severity emitted: DEBUG, INFO, WARNING, ERROR, or CRITICAL.
    level: str = "INFO"
    # Format string passed to logging.Formatter.
    format: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    # Log file path; None means console-only logging.
    file: Optional[str] = None
    # Rotate the log file once it reaches this many bytes (10 MB).
    max_bytes: int = 10485760
    # Number of rotated backup files to retain.
    backup_count: int = 5


# Default configuration instances.
# NOTE: these are module-level singletons — mutating one (e.g. appending to a
# list field) affects every importer. Construct a fresh *Config instance for
# per-call overrides instead.
default_audio_config: AudioConfig = AudioConfig()
default_model_config: ModelConfig = ModelConfig()
default_inference_config: InferenceConfig = InferenceConfig()
default_api_config: APIConfig = APIConfig()
default_gradio_config: GradioConfig = GradioConfig()
default_training_config: TrainingConfig = TrainingConfig()
default_logging_config: LoggingConfig = LoggingConfig()