""" Configuration module for Speech Pathology MVP. This module contains all configuration dataclasses for the speech pathology diagnosis system, including audio processing, model inference, API, and training settings. """ from dataclasses import dataclass, field from typing import List, Optional @dataclass class AudioConfig: """ Configuration for audio processing and VAD (Voice Activity Detection). Attributes: sample_rate: Target sample rate for audio processing in Hz. Standard 16kHz is optimal for speech recognition models. chunk_duration_ms: Duration of each audio chunk in milliseconds. Set to 20ms for phone-level analysis (matches typical phone duration in speech). vad_aggressiveness: VAD aggressiveness mode (0-3). - 0: Least aggressive (fewer false positives) - 1: Moderate - 2: More aggressive - 3: Most aggressive (fewer false negatives) Higher values detect more speech but may include noise. frame_length_ms: Frame length for audio analysis in milliseconds. Typically matches chunk_duration_ms for consistency. hop_length_ms: Hop length between frames in milliseconds. Overlapping frames improve temporal resolution. n_fft: Number of FFT points for spectral analysis. Should be power of 2, typically 512 or 1024 for 16kHz audio. """ sample_rate: int = 16000 chunk_duration_ms: int = 20 vad_aggressiveness: int = 3 frame_length_ms: int = 20 hop_length_ms: int = 10 n_fft: int = 512 @dataclass class ModelConfig: """ Configuration for Wav2Vec2 model and classifier architecture. Attributes: model_name: HuggingFace model identifier for Wav2Vec2-XLSR-53. "facebook/wav2vec2-large-xlsr-53" provides excellent multilingual speech representations. classifier_hidden_dims: List of hidden layer dimensions for the classifier. [256, 128] creates a 2-layer MLP: - Input: Wav2Vec2 feature dim (1024 for large) - Hidden 1: 256 units - Hidden 2: 128 units - Output: 2 (articulation + fluency scores) dropout: Dropout probability for classifier layers. Prevents overfitting during training. num_labels: Number of output labels (2: articulation, fluency). device: Device to run inference on ("cuda" or "cpu"). use_fp16: Whether to use half-precision for faster inference. Requires CUDA and may slightly reduce accuracy. max_length: Maximum sequence length in samples. Longer sequences are truncated to this length. """ model_name: str = "facebook/wav2vec2-large-xlsr-53" classifier_hidden_dims: List[int] = field(default_factory=lambda: [256, 128]) dropout: float = 0.1 num_labels: int = 2 device: str = "cuda" # Will be set to "cpu" if CUDA unavailable use_fp16: bool = False max_length: int = 250000 # ~15.6 seconds at 16kHz @dataclass class InferenceConfig: """ Configuration for inference thresholds and post-processing. Attributes: fluency_threshold: Threshold for fluency classification (0.0-1.0). Scores above this indicate fluent speech. Lower values = more sensitive (detects more issues). articulation_threshold: Threshold for articulation classification (0.0-1.0). Scores above this indicate clear articulation. Lower values = more sensitive (detects more issues). min_chunk_duration_ms: Minimum duration of a chunk to be analyzed. Filters out very short audio segments. smoothing_window: Window size for temporal smoothing of predictions. Reduces jitter in frame-level predictions. batch_size: Number of chunks to process in parallel during inference. Higher values = faster but more memory usage. window_size_ms: Size of sliding window in milliseconds (default: 1000ms = 1 second). Minimum for Wav2Vec2 stability. hop_size_ms: Hop size between windows in milliseconds (default: 10ms). Controls temporal resolution (100 frames/second). frame_rate: Frames per second (calculated from hop_size_ms). minimum_audio_length: Minimum audio length in seconds (must be >= window_size_ms). phone_level_strategy: Strategy for phone-level analysis ("sliding_window"). """ fluency_threshold: float = 0.5 articulation_threshold: float = 0.6 min_chunk_duration_ms: int = 10 smoothing_window: int = 5 batch_size: int = 32 window_size_ms: int = 1000 # 1 second minimum for Wav2Vec2 hop_size_ms: int = 10 # 10ms for phone-level resolution frame_rate: float = 100.0 # 100 frames per second (1/hop_size_ms) minimum_audio_length: float = 1.0 # Must be >= window_size_ms phone_level_strategy: str = "sliding_window" @dataclass class APIConfig: """ Configuration for FastAPI REST API and WebSocket server. Attributes: host: Host address to bind the API server. "0.0.0.0" allows external connections. port: Port number for the REST API server. websocket_port: Port number for the WebSocket streaming server. Separate port allows independent scaling. max_file_size_mb: Maximum uploaded file size in megabytes. Prevents memory exhaustion from large files. allowed_extensions: List of allowed audio file extensions. cors_origins: List of allowed CORS origins. ["*"] allows all origins (use with caution in production). request_timeout: Request timeout in seconds. Prevents hanging requests from blocking server. max_connections: Maximum concurrent WebSocket connections. chunk_size: Size of WebSocket message chunks in bytes. """ host: str = "0.0.0.0" port: int = 8000 websocket_port: int = 8001 max_file_size_mb: int = 100 allowed_extensions: List[str] = field(default_factory=lambda: [".wav", ".mp3", ".flac", ".m4a"]) cors_origins: List[str] = field(default_factory=lambda: ["*"]) request_timeout: int = 300 max_connections: int = 100 chunk_size: int = 4096 @dataclass class GradioConfig: """ Configuration for Gradio web interface. Attributes: enabled: Whether to enable the Gradio interface. port: Port number for Gradio server. Typically runs on a different port than REST API. share: Whether to create a public Gradio link. Useful for testing but exposes the interface publicly. theme: Gradio theme name or custom theme configuration. title: Title displayed in the Gradio interface. description: Description text for the Gradio interface. examples: List of example file paths to include in the interface. """ enabled: bool = True port: int = 7860 share: bool = False theme: str = "default" title: str = "Speech Pathology Diagnosis" description: str = "Upload audio for articulation and fluency analysis" examples: List[str] = field(default_factory=list) @dataclass class TrainingConfig: """ Configuration for model training (if training custom classifier). Attributes: learning_rate: Initial learning rate for optimizer. Lower values (1e-4 to 1e-5) work well for fine-tuning. batch_size: Training batch size. Larger batches = more stable gradients but more memory. num_epochs: Number of training epochs. Early stopping should prevent overfitting. weight_decay: L2 regularization strength. Prevents overfitting by penalizing large weights. warmup_steps: Number of warmup steps for learning rate scheduler. Gradually increases LR from 0 to target value. eval_steps: Evaluate model every N steps during training. save_steps: Save model checkpoint every N steps. output_dir: Directory to save trained models and checkpoints. logging_dir: Directory to save training logs (for TensorBoard). gradient_accumulation_steps: Accumulate gradients over N steps. Effective batch size = batch_size * N. fp16: Whether to use mixed precision training (faster on modern GPUs). dataloader_num_workers: Number of worker processes for data loading. Parallelizes data preprocessing. """ learning_rate: float = 5e-5 batch_size: int = 16 num_epochs: int = 10 weight_decay: float = 0.01 warmup_steps: int = 500 eval_steps: int = 500 save_steps: int = 1000 output_dir: str = "./models/trained" logging_dir: str = "./logs" gradient_accumulation_steps: int = 1 fp16: bool = False dataloader_num_workers: int = 4 @dataclass class LoggingConfig: """ Configuration for logging system. Attributes: level: Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL). format: Log message format string. file: Optional log file path. If None, logs only to console. max_bytes: Maximum log file size before rotation (in bytes). backup_count: Number of backup log files to keep. """ level: str = "INFO" format: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" file: Optional[str] = None max_bytes: int = 10485760 # 10MB backup_count: int = 5 # Default configuration instances default_audio_config = AudioConfig() default_model_config = ModelConfig() default_inference_config = InferenceConfig() default_api_config = APIConfig() default_gradio_config = GradioConfig() default_training_config = TrainingConfig() default_logging_config = LoggingConfig()