""" Speech Pathology Classifier Model This module implements a multi-task speech pathology classifier using Wav2Vec2-XLSR-53 as the feature extractor with a custom classifier head for: - Fluency scoring (binary classification) - Articulation classification (4 classes: normal, substitution, omission, distortion) """ import logging import torch import torch.nn as nn from torch.nn import functional as F from transformers import Wav2Vec2Model, Wav2Vec2FeatureExtractor, Wav2Vec2Config from typing import Dict, Optional, Tuple, List import os logger = logging.getLogger(__name__) class MultiTaskClassifierHead(nn.Module): """ Multi-task classifier head for speech pathology diagnosis. This head takes Wav2Vec2 features and produces: 1. Fluency score (binary: fluent vs disfluent) 2. Articulation classes (4 classes: normal, substitution, omission, distortion) Architecture: - Shared feature extractor layers - Task-specific heads for fluency and articulation """ def __init__( self, input_dim: int, hidden_dims: List[int], dropout: float = 0.1, num_articulation_classes: int = 4 ): """ Initialize the multi-task classifier head. Args: input_dim: Input feature dimension from Wav2Vec2 (typically 1024 for large) hidden_dims: List of hidden layer dimensions, e.g., [256, 128] dropout: Dropout probability for regularization num_articulation_classes: Number of articulation classes (default: 4) """ super().__init__() self.num_articulation_classes = num_articulation_classes # Build shared feature layers: 1024 → 512 → 256 layers = [] prev_dim = input_dim for hidden_dim in hidden_dims: layers.extend([ nn.Linear(prev_dim, hidden_dim), nn.LayerNorm(hidden_dim), nn.ReLU(), nn.Dropout(dropout) ]) prev_dim = hidden_dim self.shared_layers = nn.Sequential(*layers) shared_output_dim = prev_dim # Fluency head: 256 → 64 → 2 (stutter/normal) self.fluency_head = nn.Sequential( nn.Linear(shared_output_dim, 64), nn.ReLU(), nn.Dropout(dropout), nn.Linear(64, 2), # 2 classes: stutter/normal ) # Articulation head: 256 → 64 → 4 (normal/sub/omit/dist) self.articulation_head = nn.Sequential( nn.Linear(shared_output_dim, 64), nn.ReLU(), nn.Dropout(dropout), nn.Linear(64, num_articulation_classes), # 4 classes ) # Full combined head: 256 → 128 → 8 (all classes combined) self.full_head = nn.Sequential( nn.Linear(shared_output_dim, 128), nn.ReLU(), nn.Dropout(dropout), nn.Linear(128, 8), # 8 classes (combined fluency + articulation) ) logger.info( f"Initialized MultiTaskClassifierHead: " f"input_dim={input_dim}, hidden_dims={hidden_dims}, " f"articulation_classes={num_articulation_classes}" ) def forward( self, features: torch.Tensor, attention_mask: Optional[torch.Tensor] = None ) -> Dict[str, torch.Tensor]: """ Forward pass through the multi-task head. 
Args: features: Wav2Vec2 features of shape (batch_size, seq_len, feature_dim) attention_mask: Optional attention mask to mask out padding Returns: Dictionary containing: - fluency_logits: Binary classification logits (batch_size, 1) - articulation_logits: Multi-class logits (batch_size, num_classes) - fluency_probs: Fluency probabilities (batch_size, 1) - articulation_probs: Articulation class probabilities (batch_size, num_classes) """ # Pool features: mean pooling over sequence length (with attention mask if provided) if attention_mask is not None: # Expand attention mask to match feature dimensions mask_expanded = attention_mask.unsqueeze(-1).expand(features.size()).float() # Sum features where mask is 1, then divide by sum of mask sum_features = torch.sum(features * mask_expanded, dim=1) sum_mask = torch.clamp(mask_expanded.sum(dim=1), min=1e-9) pooled_features = sum_features / sum_mask else: # Simple mean pooling pooled_features = torch.mean(features, dim=1) # Pass through shared layers shared_features = self.shared_layers(pooled_features) # Task-specific heads fluency_logits = self.fluency_head(shared_features) # (batch, 2) articulation_logits = self.articulation_head(shared_features) # (batch, 4) full_logits = self.full_head(shared_features) # (batch, 8) # Apply activations fluency_probs = F.softmax(fluency_logits, dim=-1) # (batch, 2) articulation_probs = F.softmax(articulation_logits, dim=-1) # (batch, 4) full_probs = F.softmax(full_logits, dim=-1) # (batch, 8) return { "fluency_logits": fluency_logits, "articulation_logits": articulation_logits, "full_logits": full_logits, "fluency_probs": fluency_probs, "articulation_probs": articulation_probs, "full_probs": full_probs, "shared_features": shared_features, } class SpeechPathologyClassifier(nn.Module): """ Speech Pathology Classifier using Wav2Vec2-XLSR-53 with custom multi-task head. This model combines: - Wav2Vec2-XLSR-53: Pretrained speech feature extractor - Custom MultiTaskClassifierHead: For fluency and articulation classification Outputs: - Fluency score: Probability of fluent speech (0-1) - Articulation classes: Probabilities for 4 articulation types """ # Articulation class names ARTICULATION_CLASSES = [ "normal", # Clear, correct articulation "substitution", # Sound replaced with another (e.g., "wabbit" for "rabbit") "omission", # Sound omitted (e.g., "ca" for "cat") "distortion" # Sound distorted but recognizable ] def __init__( self, model_name: str = "facebook/wav2vec2-large-xlsr-53", classifier_hidden_dims: List[int] = None, dropout: float = 0.1, num_articulation_classes: int = 4, device: Optional[str] = None, use_fp16: bool = False ): """ Initialize the Speech Pathology Classifier. Args: model_name: HuggingFace model identifier for Wav2Vec2-XLSR-53 classifier_hidden_dims: List of hidden layer dimensions for classifier Default: [256, 128] dropout: Dropout probability for classifier layers num_articulation_classes: Number of articulation classes (default: 4) device: Device to run on ("cuda" or "cpu"). Auto-detects if None use_fp16: Whether to use half-precision (requires CUDA) Raises: ValueError: If model_name is invalid or model cannot be loaded RuntimeError: If CUDA is requested but unavailable """ super().__init__() # Set device if device is None: device = "cuda" if torch.cuda.is_available() else "cpu" if device == "cuda" and not torch.cuda.is_available(): logger.warning("CUDA requested but not available. 
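
# Illustrative sketch (not part of the API): exercises the head on random
# features to show the output shapes, including masked mean pooling over a
# padded batch. The dimensions are assumptions matching the defaults used in
# this module (feature_dim=1024 for the large model, hidden_dims=[256, 128]).
def _example_head_shapes() -> None:
    head = MultiTaskClassifierHead(input_dim=1024, hidden_dims=[256, 128])
    features = torch.randn(2, 49, 1024)  # (batch, seq_len, feature_dim)
    mask = torch.ones(2, 49, dtype=torch.long)
    mask[1, 30:] = 0  # second item padded after frame 30
    outputs = head(features, attention_mask=mask)
    assert outputs["fluency_logits"].shape == (2, 2)
    assert outputs["articulation_logits"].shape == (2, 4)
    assert outputs["full_logits"].shape == (2, 8)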
Falling back to CPU.") device = "cpu" self.device = torch.device(device) self.use_fp16 = use_fp16 and device == "cuda" self.is_trained = False # Track if classifier is trained if classifier_hidden_dims is None: classifier_hidden_dims = [256, 128] logger.info(f"Initializing SpeechPathologyClassifier on {device}") logger.info(f"Model: {model_name}") logger.info(f"Classifier hidden dims: {classifier_hidden_dims}") logger.info(f"FP16: {self.use_fp16}") try: # Load Wav2Vec2 model and processor hf_token = os.getenv("HF_TOKEN") logger.info("Loading Wav2Vec2 model and feature extractor...") self.wav2vec2_model = Wav2Vec2Model.from_pretrained( model_name, token=hf_token if hf_token else None ) # Use FeatureExtractor instead of Processor for feature extraction tasks # Processor includes tokenizer which requires vocab file (not available for pre-trained models) self.processor = Wav2Vec2FeatureExtractor.from_pretrained( model_name, token=hf_token if hf_token else None ) # Get feature dimension from model config config: Wav2Vec2Config = self.wav2vec2_model.config feature_dim = config.hidden_size # Typically 1024 for large models logger.info(f"Wav2Vec2 feature dimension: {feature_dim}") # Freeze Wav2Vec2 parameters (optional - can be unfrozen for fine-tuning) # For inference, we typically keep it frozen for param in self.wav2vec2_model.parameters(): param.requires_grad = False logger.info("Wav2Vec2 parameters frozen for inference") # Initialize custom classifier head self.classifier_head = MultiTaskClassifierHead( input_dim=feature_dim, hidden_dims=classifier_hidden_dims, dropout=dropout, num_articulation_classes=num_articulation_classes ) # Try to load trained weights if available (None = try default paths) self._load_trained_weights(None) # Move to device self.wav2vec2_model = self.wav2vec2_model.to(self.device) self.classifier_head = self.classifier_head.to(self.device) # Set to eval mode self.eval() # Convert to FP16 if requested if self.use_fp16: self.wav2vec2_model = self.wav2vec2_model.half() logger.info("Model converted to FP16") logger.info("✅ SpeechPathologyClassifier initialized successfully") except Exception as e: logger.error(f"❌ Failed to initialize model: {e}", exc_info=True) raise RuntimeError(f"Failed to load Wav2Vec2 model: {e}") from e def _load_trained_weights(self, model_path: Optional[str] = None): """ Load trained classifier head weights if available. Args: model_path: Optional path to model checkpoint. If None, tries default checkpoint paths. 
""" from pathlib import Path checkpoint_paths = [] # Add user-provided path if model_path: checkpoint_paths.append(Path(model_path)) # Add default checkpoint paths checkpoint_paths.extend([ Path("models/checkpoints/classifier_head_best.pt"), Path("models/checkpoints/classifier_head_trained.pt") ]) for checkpoint_path in checkpoint_paths: if checkpoint_path.exists(): try: checkpoint = torch.load(checkpoint_path, map_location=self.device) # Handle both full checkpoint dict and state_dict directly if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint: state_dict = checkpoint['model_state_dict'] epoch = checkpoint.get('epoch', 'unknown') val_acc = checkpoint.get('val_accuracy', 'unknown') else: state_dict = checkpoint epoch = 'unknown' val_acc = 'unknown' self.classifier_head.load_state_dict(state_dict) logger.info(f"✅ Loaded trained classifier head from {checkpoint_path}") logger.info(f" Epoch: {epoch}, Validation Accuracy: {val_acc}") self.is_trained = True return except Exception as e: logger.warning(f"⚠️ Could not load checkpoint {checkpoint_path}: {e}") continue # No trained weights found logger.warning("⚠️ No trained classifier weights found. Using untrained head (beta mode)") logger.warning(" To train the classifier, run: python training/train_classifier_head.py") self.is_trained = False def forward( self, input_values: torch.Tensor, attention_mask: Optional[torch.Tensor] = None ) -> Dict[str, torch.Tensor]: """ Forward pass through the model. Args: input_values: Audio input tensor of shape (batch_size, seq_len) Should be normalized to [-1, 1] range attention_mask: Optional attention mask for padding Returns: Dictionary containing: - fluency_logits: Binary classification logits - articulation_logits: Multi-class logits - fluency_probs: Fluency probabilities (0-1) - articulation_probs: Articulation class probabilities - wav2vec2_features: Raw Wav2Vec2 features (for debugging) """ # #region agent log try: with open(r'c:\Users\kpanfas\Desktop\zlaqa\slaq-version-d-to-a\zlaqa-version-b\ai-enginee\zlaqa-version-b-ai-enginee\.cursor\debug.log', 'a') as f: import json, time f.write(json.dumps({"sessionId":"debug-session","runId":"run1","hypothesisId":"D","location":"speech_pathology_model.py:288","message":"Before Wav2Vec2 forward","data":{"input_values_shape":list(input_values.shape)},"timestamp":int(time.time()*1000)}) + '\n') except: pass # #endregion # Extract features using Wav2Vec2 try: with torch.no_grad() if not self.training else torch.enable_grad(): wav2vec2_outputs = self.wav2vec2_model( input_values=input_values, attention_mask=attention_mask ) except Exception as e: # #region agent log try: with open(r'c:\Users\kpanfas\Desktop\zlaqa\slaq-version-d-to-a\zlaqa-version-b\ai-enginee\zlaqa-version-b-ai-enginee\.cursor\debug.log', 'a') as f: import json, time f.write(json.dumps({"sessionId":"debug-session","runId":"run1","hypothesisId":"D","location":"speech_pathology_model.py:288","message":"Wav2Vec2 forward exception","data":{"error":str(e),"error_type":type(e).__name__,"input_shape":list(input_values.shape)},"timestamp":int(time.time()*1000)}) + '\n') except: pass # #endregion raise # Get last hidden state (features) features = wav2vec2_outputs.last_hidden_state # (batch_size, seq_len, feature_dim) # #region agent log try: with open(r'c:\Users\kpanfas\Desktop\zlaqa\slaq-version-d-to-a\zlaqa-version-b\ai-enginee\zlaqa-version-b-ai-enginee\.cursor\debug.log', 'a') as f: import json, time 
f.write(json.dumps({"sessionId":"debug-session","runId":"run1","hypothesisId":"D","location":"speech_pathology_model.py:297","message":"After Wav2Vec2 forward","data":{"features_shape":list(features.shape),"seq_len":features.shape[1] if len(features.shape) > 1 else 0},"timestamp":int(time.time()*1000)}) + '\n') except: pass # #endregion # Safety check: ensure sequence length is valid (at least 1) if features.shape[1] < 1: raise ValueError( f"Wav2Vec2 output sequence length is too short: {features.shape[1]}. " f"Input was {input_values.shape}. Try using longer audio segments (>= 500ms)." ) # Pass through classifier head outputs = self.classifier_head(features, attention_mask) # Add raw features for debugging/analysis outputs["wav2vec2_features"] = features return outputs def predict( self, audio_array: torch.Tensor, sample_rate: int = 16000, return_dict: bool = True ) -> Dict[str, torch.Tensor]: """ Predict fluency and articulation for audio input. Args: audio_array: Audio tensor of shape (seq_len,) or (batch_size, seq_len) Should be in range [-1, 1] sample_rate: Sample rate of audio (should match processor, typically 16000) return_dict: Whether to return dictionary or tuple Returns: Dictionary with predictions: - fluency_score: Float probability of fluent speech (0-1) - articulation_class: Integer class index (0-3) - articulation_class_name: String class name - articulation_probs: Probabilities for all classes - confidence: Overall confidence score """ self.eval() with torch.no_grad(): # Ensure audio is 2D (batch_size, seq_len) if audio_array.dim() == 1: audio_array = audio_array.unsqueeze(0) # Move to device audio_array = audio_array.to(self.device) # Process audio through model # Note: Processor should be used for preprocessing, but for inference # we assume audio is already preprocessed outputs = self.forward(audio_array) # Extract predictions fluency_probs = outputs["fluency_probs"].cpu() articulation_probs = outputs["articulation_probs"].cpu() # Get fluency score (probability of being fluent) fluency_score = fluency_probs.item() if fluency_probs.numel() == 1 else fluency_probs[0].item() # Get articulation class (argmax) articulation_probs_flat = articulation_probs[0] if articulation_probs.dim() > 1 else articulation_probs articulation_class_idx = torch.argmax(articulation_probs_flat).item() articulation_class_name = self.ARTICULATION_CLASSES[articulation_class_idx] articulation_confidence = articulation_probs_flat[articulation_class_idx].item() # Overall confidence (average of fluency and articulation confidences) overall_confidence = (fluency_score + articulation_confidence) / 2.0 if return_dict: return { "fluency_score": fluency_score, "articulation_class": articulation_class_idx, "articulation_class_name": articulation_class_name, "articulation_probs": articulation_probs_flat.tolist(), "confidence": overall_confidence, } else: return fluency_score, articulation_class_idx, articulation_probs_flat.tolist() def get_articulation_class_name(self, class_idx: int) -> str: """Get the name of an articulation class by index.""" if 0 <= class_idx < len(self.ARTICULATION_CLASSES): return self.ARTICULATION_CLASSES[class_idx] raise ValueError(f"Invalid articulation class index: {class_idx}") def unfreeze_wav2vec2(self): """Unfreeze Wav2Vec2 parameters for fine-tuning.""" logger.info("Unfreezing Wav2Vec2 parameters for fine-tuning") for param in self.wav2vec2_model.parameters(): param.requires_grad = True def freeze_wav2vec2(self): """Freeze Wav2Vec2 parameters (default for inference).""" 
logger.info("Freezing Wav2Vec2 parameters") for param in self.wav2vec2_model.parameters(): param.requires_grad = False def load_speech_pathology_model( model_name: str = "facebook/wav2vec2-large-xlsr-53", classifier_hidden_dims: List[int] = None, dropout: float = 0.1, device: Optional[str] = None, use_fp16: bool = False, model_path: Optional[str] = None ) -> SpeechPathologyClassifier: """ Load or create a SpeechPathologyClassifier instance. Args: model_name: HuggingFace model identifier classifier_hidden_dims: Classifier hidden dimensions dropout: Dropout probability device: Device to run on use_fp16: Whether to use FP16 model_path: Optional path to saved model checkpoint Returns: SpeechPathologyClassifier instance """ if model_path and os.path.exists(model_path): logger.info(f"Loading model from checkpoint: {model_path}") model = SpeechPathologyClassifier( model_name=model_name, classifier_hidden_dims=classifier_hidden_dims or [256, 128], dropout=dropout, device=device, use_fp16=use_fp16 ) checkpoint = torch.load(model_path, map_location=device or "cpu") model.load_state_dict(checkpoint["model_state_dict"]) logger.info("✅ Model loaded from checkpoint") return model else: logger.info("Creating new SpeechPathologyClassifier") return SpeechPathologyClassifier( model_name=model_name, classifier_hidden_dims=classifier_hidden_dims or [256, 128], dropout=dropout, device=device, use_fp16=use_fp16 )