|
|
""" |
|
|
Configuration module for Speech Pathology MVP. |
|
|
|
|
|
This module contains all configuration dataclasses for the speech pathology |
|
|
diagnosis system, including audio processing, model inference, API, and training settings. |
|
|
""" |
|
|
|
|
|
from dataclasses import dataclass, field |
|
|
from typing import List, Optional |
|
|
|
|
|
|
|
|
@dataclass
class AudioConfig:
    """Settings for audio preprocessing and VAD (Voice Activity Detection).

    Defaults target 16 kHz speech sliced into 20 ms frames, which matches
    typical phone duration and the input expectations of common speech models.
    """

    # Target sample rate in Hz; 16 kHz is the standard for speech recognition.
    sample_rate: int = 16000
    # Duration of each audio chunk in ms; 20 ms suits phone-level analysis.
    chunk_duration_ms: int = 20
    # VAD aggressiveness mode, 0 (least aggressive, fewer false positives)
    # through 3 (most aggressive, fewer false negatives). Higher values detect
    # more speech but may also pass through noise.
    vad_aggressiveness: int = 3
    # Frame length for analysis in ms; kept equal to chunk_duration_ms.
    frame_length_ms: int = 20
    # Hop between successive frames in ms; overlap improves temporal resolution.
    hop_length_ms: int = 10
    # FFT size for spectral analysis; a power of 2 (512 or 1024) for 16 kHz audio.
    n_fft: int = 512
|
|
|
|
|
|
|
|
@dataclass
class ModelConfig:
    """Settings for the Wav2Vec2 backbone and the downstream classifier head.

    The classifier is an MLP stacked on Wav2Vec2 features, ending in
    ``num_labels`` outputs (articulation and fluency scores).
    """

    # HuggingFace identifier of the multilingual Wav2Vec2-XLSR-53 backbone.
    model_name: str = "facebook/wav2vec2-large-xlsr-53"
    # Hidden layer widths of the classifier MLP. [256, 128] gives:
    # Wav2Vec2 feature dim (1024 for the large model) -> 256 -> 128 -> num_labels.
    classifier_hidden_dims: List[int] = field(default_factory=lambda: [256, 128])
    # Dropout probability applied in the classifier layers (regularization).
    dropout: float = 0.1
    # Number of output labels: articulation and fluency.
    num_labels: int = 2
    # Inference device, "cuda" or "cpu".
    device: str = "cuda"
    # Half-precision inference; needs CUDA and may slightly reduce accuracy.
    use_fp16: bool = False
    # Maximum input length in samples; longer sequences are truncated.
    max_length: int = 250000
|
|
|
|
|
|
|
|
@dataclass
class InferenceConfig:
    """Thresholds and post-processing settings for inference.

    Sliding-window defaults (1000 ms window, 10 ms hop) yield 100 frames per
    second of phone-level predictions.

    NOTE(review): ``frame_rate`` is a plain field, not derived from
    ``hop_size_ms`` at runtime — keep the two consistent when overriding
    (frame_rate == 1000 / hop_size_ms).
    """

    # Scores above this indicate fluent speech; lower = more sensitive.
    fluency_threshold: float = 0.5
    # Scores above this indicate clear articulation; lower = more sensitive.
    articulation_threshold: float = 0.6
    # Chunks shorter than this (ms) are skipped entirely.
    min_chunk_duration_ms: int = 10
    # Window size for temporal smoothing; reduces jitter in frame predictions.
    smoothing_window: int = 5
    # Chunks processed in parallel; higher is faster but uses more memory.
    batch_size: int = 32
    # Sliding window size in ms; 1 second is the minimum for Wav2Vec2 stability.
    window_size_ms: int = 1000
    # Hop between windows in ms; controls temporal resolution.
    hop_size_ms: int = 10
    # Prediction frames per second (should equal 1000 / hop_size_ms).
    frame_rate: float = 100.0
    # Minimum audio length in seconds; must be >= window_size_ms.
    minimum_audio_length: float = 1.0
    # Strategy used for phone-level analysis.
    phone_level_strategy: str = "sliding_window"
|
|
|
|
|
|
|
|
@dataclass
class APIConfig:
    """Settings for the FastAPI REST API and WebSocket server."""

    # Bind address; "0.0.0.0" accepts external connections.
    host: str = "0.0.0.0"
    # REST API port.
    port: int = 8000
    # WebSocket streaming port; kept separate so it can scale independently.
    websocket_port: int = 8001
    # Upload size cap in MB; guards against memory exhaustion.
    max_file_size_mb: int = 100
    # Audio file extensions accepted for upload.
    allowed_extensions: List[str] = field(default_factory=lambda: [".wav", ".mp3", ".flac", ".m4a"])
    # CORS origins; ["*"] allows all — use with caution in production.
    cors_origins: List[str] = field(default_factory=lambda: ["*"])
    # Per-request timeout in seconds; keeps hung requests from blocking the server.
    request_timeout: int = 300
    # Maximum concurrent WebSocket connections.
    max_connections: int = 100
    # WebSocket message chunk size in bytes.
    chunk_size: int = 4096
|
|
|
|
|
|
|
|
@dataclass
class GradioConfig:
    """Settings for the Gradio web interface."""

    # Whether the Gradio UI is served at all.
    enabled: bool = True
    # Gradio server port; distinct from the REST API port.
    port: int = 7860
    # Create a public Gradio share link; handy for testing, exposes the UI publicly.
    share: bool = False
    # Theme name or custom theme configuration.
    theme: str = "default"
    # Title shown in the interface.
    title: str = "Speech Pathology Diagnosis"
    # Description text shown in the interface.
    description: str = "Upload audio for articulation and fluency analysis"
    # Example file paths offered in the interface (none by default).
    examples: List[str] = field(default_factory=list)
|
|
|
|
|
|
|
|
@dataclass
class TrainingConfig:
    """Settings for training a custom classifier head.

    Mirrors the usual HuggingFace-style trainer knobs: optimizer, schedule,
    checkpointing cadence, and data loading.
    """

    # Initial learning rate; 1e-4 to 1e-5 works well for fine-tuning.
    learning_rate: float = 5e-5
    # Per-step batch size; larger = more stable gradients, more memory.
    batch_size: int = 16
    # Training epochs; early stopping should prevent overfitting.
    num_epochs: int = 10
    # L2 regularization strength; penalizes large weights.
    weight_decay: float = 0.01
    # LR scheduler warmup steps (ramp from 0 to the target LR).
    warmup_steps: int = 500
    # Evaluate every N steps.
    eval_steps: int = 500
    # Save a checkpoint every N steps.
    save_steps: int = 1000
    # Where trained models and checkpoints are written.
    output_dir: str = "./models/trained"
    # Where training logs (TensorBoard) are written.
    logging_dir: str = "./logs"
    # Gradient accumulation; effective batch = batch_size * this.
    gradient_accumulation_steps: int = 1
    # Mixed-precision training; faster on modern GPUs.
    fp16: bool = False
    # DataLoader worker processes for parallel preprocessing.
    dataloader_num_workers: int = 4
|
|
|
|
|
|
|
|
@dataclass
class LoggingConfig:
    """Settings for the logging system (level, format, optional file rotation).

    NOTE: the ``format`` field intentionally mirrors the ``logging`` module's
    parameter name, shadowing the builtin only as an attribute.
    """

    # Logging level: DEBUG, INFO, WARNING, ERROR, or CRITICAL.
    level: str = "INFO"
    # Log message format string (standard logging %-style placeholders).
    format: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    # Optional log file path; None means console-only logging.
    file: Optional[str] = None
    # Rotate the log file once it reaches this size (10 MiB).
    max_bytes: int = 10485760
    # Number of rotated backup files to keep.
    backup_count: int = 5
|
|
|
|
|
|
|
|
|
|
|
# Module-level default instances for convenient importing.
# NOTE(review): these are shared mutable singletons — mutating one (e.g. a
# List field) affects every importer that uses it. Construct a fresh config
# object when per-caller isolation is needed.
default_audio_config = AudioConfig()
default_model_config = ModelConfig()
default_inference_config = InferenceConfig()
default_api_config = APIConfig()
default_gradio_config = GradioConfig()
default_training_config = TrainingConfig()
default_logging_config = LoggingConfig()
|
|
|
|
|
|