"""
Speech Pathology Classifier Model
This module implements a multi-task speech pathology classifier using Wav2Vec2-XLSR-53
as the feature extractor with a custom classifier head for:
- Fluency scoring (binary classification)
- Articulation classification (4 classes: normal, substitution, omission, distortion)
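
A minimal usage sketch (illustrative only; assumes this module lives at
models/speech_pathology_model.py and that the input is a 16 kHz mono
waveform already scaled to [-1, 1]):

    import torch
    from models.speech_pathology_model import load_speech_pathology_model

    model = load_speech_pathology_model(device="cpu")
    audio = torch.zeros(16000)  # 1 s of silence as a stand-in for real speech
    result = model.predict(audio)
    print(result["fluency_score"], result["articulation_class_name"])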
"""
import logging
import torch
import torch.nn as nn
from torch.nn import functional as F
from transformers import Wav2Vec2Model, Wav2Vec2FeatureExtractor, Wav2Vec2Config
from typing import Dict, Optional, Tuple, List
import os
logger = logging.getLogger(__name__)
class MultiTaskClassifierHead(nn.Module):
"""
Multi-task classifier head for speech pathology diagnosis.
This head takes Wav2Vec2 features and produces:
1. Fluency score (binary: fluent vs disfluent)
2. Articulation classes (4 classes: normal, substitution, omission, distortion)
Architecture:
- Shared feature extractor layers
- Task-specific heads for fluency and articulation
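
    Shape sketch (illustrative; assumes the default hidden_dims=[256, 128],
    a batch size B and a sequence length T):

        head = MultiTaskClassifierHead(input_dim=1024, hidden_dims=[256, 128])
        out = head(torch.randn(B, T, 1024))  # mean-pooled to (B, 1024) inside
        out["fluency_probs"].shape           # (B, 2)
        out["articulation_probs"].shape      # (B, 4)
        out["full_probs"].shape              # (B, 8)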
"""
def __init__(
self,
input_dim: int,
hidden_dims: List[int],
dropout: float = 0.1,
num_articulation_classes: int = 4
):
"""
Initialize the multi-task classifier head.
Args:
input_dim: Input feature dimension from Wav2Vec2 (typically 1024 for large)
hidden_dims: List of hidden layer dimensions, e.g., [256, 128]
dropout: Dropout probability for regularization
num_articulation_classes: Number of articulation classes (default: 4)
"""
super().__init__()
self.num_articulation_classes = num_articulation_classes
        # Build shared feature layers, e.g. 1024 -> 256 -> 128 with the default hidden_dims=[256, 128]
layers = []
prev_dim = input_dim
for hidden_dim in hidden_dims:
layers.extend([
nn.Linear(prev_dim, hidden_dim),
nn.LayerNorm(hidden_dim),
nn.ReLU(),
nn.Dropout(dropout)
])
prev_dim = hidden_dim
self.shared_layers = nn.Sequential(*layers)
shared_output_dim = prev_dim
        # Fluency head: shared_output_dim -> 64 -> 2 (fluent vs. disfluent)
self.fluency_head = nn.Sequential(
nn.Linear(shared_output_dim, 64),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(64, 2), # 2 classes: stutter/normal
)
        # Articulation head: shared_output_dim -> 64 -> 4 (normal/substitution/omission/distortion)
self.articulation_head = nn.Sequential(
nn.Linear(shared_output_dim, 64),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(64, num_articulation_classes), # 4 classes
)
        # Full combined head: shared_output_dim -> 128 -> 8 (joint fluency x articulation classes)
self.full_head = nn.Sequential(
nn.Linear(shared_output_dim, 128),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(128, 8), # 8 classes (combined fluency + articulation)
)
logger.info(
f"Initialized MultiTaskClassifierHead: "
f"input_dim={input_dim}, hidden_dims={hidden_dims}, "
f"articulation_classes={num_articulation_classes}"
)
def forward(
self,
features: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None
) -> Dict[str, torch.Tensor]:
"""
Forward pass through the multi-task head.
Args:
features: Wav2Vec2 features of shape (batch_size, seq_len, feature_dim)
attention_mask: Optional attention mask to mask out padding
Returns:
Dictionary containing:
                - fluency_logits: Binary classification logits (batch_size, 2)
                - articulation_logits: Multi-class logits (batch_size, num_articulation_classes)
                - full_logits: Combined-class logits (batch_size, 8)
                - fluency_probs: Fluency probabilities (batch_size, 2)
                - articulation_probs: Articulation class probabilities (batch_size, num_articulation_classes)
                - full_probs: Combined-class probabilities (batch_size, 8)
                - shared_features: Pooled shared representation (batch_size, hidden_dims[-1])
"""
# Pool features: mean pooling over sequence length (with attention mask if provided)
if attention_mask is not None:
# Expand attention mask to match feature dimensions
mask_expanded = attention_mask.unsqueeze(-1).expand(features.size()).float()
# Sum features where mask is 1, then divide by sum of mask
sum_features = torch.sum(features * mask_expanded, dim=1)
sum_mask = torch.clamp(mask_expanded.sum(dim=1), min=1e-9)
pooled_features = sum_features / sum_mask
else:
# Simple mean pooling
pooled_features = torch.mean(features, dim=1)
# Pass through shared layers
shared_features = self.shared_layers(pooled_features)
# Task-specific heads
fluency_logits = self.fluency_head(shared_features) # (batch, 2)
articulation_logits = self.articulation_head(shared_features) # (batch, 4)
full_logits = self.full_head(shared_features) # (batch, 8)
# Apply activations
fluency_probs = F.softmax(fluency_logits, dim=-1) # (batch, 2)
articulation_probs = F.softmax(articulation_logits, dim=-1) # (batch, 4)
full_probs = F.softmax(full_logits, dim=-1) # (batch, 8)
return {
"fluency_logits": fluency_logits,
"articulation_logits": articulation_logits,
"full_logits": full_logits,
"fluency_probs": fluency_probs,
"articulation_probs": articulation_probs,
"full_probs": full_probs,
"shared_features": shared_features,
}
class SpeechPathologyClassifier(nn.Module):
"""
Speech Pathology Classifier using Wav2Vec2-XLSR-53 with custom multi-task head.
This model combines:
- Wav2Vec2-XLSR-53: Pretrained speech feature extractor
- Custom MultiTaskClassifierHead: For fluency and articulation classification
Outputs:
- Fluency score: Probability of fluent speech (0-1)
- Articulation classes: Probabilities for 4 articulation types
"""
# Articulation class names
ARTICULATION_CLASSES = [
"normal", # Clear, correct articulation
"substitution", # Sound replaced with another (e.g., "wabbit" for "rabbit")
"omission", # Sound omitted (e.g., "ca" for "cat")
"distortion" # Sound distorted but recognizable
]
def __init__(
self,
model_name: str = "facebook/wav2vec2-large-xlsr-53",
        classifier_hidden_dims: Optional[List[int]] = None,
dropout: float = 0.1,
num_articulation_classes: int = 4,
device: Optional[str] = None,
use_fp16: bool = False
):
"""
Initialize the Speech Pathology Classifier.
Args:
model_name: HuggingFace model identifier for Wav2Vec2-XLSR-53
classifier_hidden_dims: List of hidden layer dimensions for classifier
Default: [256, 128]
dropout: Dropout probability for classifier layers
num_articulation_classes: Number of articulation classes (default: 4)
device: Device to run on ("cuda" or "cpu"). Auto-detects if None
use_fp16: Whether to use half-precision (requires CUDA)
Raises:
ValueError: If model_name is invalid or model cannot be loaded
RuntimeError: If CUDA is requested but unavailable
"""
super().__init__()
# Set device
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda" and not torch.cuda.is_available():
logger.warning("CUDA requested but not available. Falling back to CPU.")
device = "cpu"
self.device = torch.device(device)
self.use_fp16 = use_fp16 and device == "cuda"
self.is_trained = False # Track if classifier is trained
if classifier_hidden_dims is None:
classifier_hidden_dims = [256, 128]
logger.info(f"Initializing SpeechPathologyClassifier on {device}")
logger.info(f"Model: {model_name}")
logger.info(f"Classifier hidden dims: {classifier_hidden_dims}")
logger.info(f"FP16: {self.use_fp16}")
try:
# Load Wav2Vec2 model and processor
hf_token = os.getenv("HF_TOKEN")
logger.info("Loading Wav2Vec2 model and feature extractor...")
self.wav2vec2_model = Wav2Vec2Model.from_pretrained(
model_name,
token=hf_token if hf_token else None
)
            # Use Wav2Vec2FeatureExtractor rather than Wav2Vec2Processor: the
            # processor bundles a CTC tokenizer whose vocab file is not shipped
            # with the base pretrained checkpoint.
self.processor = Wav2Vec2FeatureExtractor.from_pretrained(
model_name,
token=hf_token if hf_token else None
)
# Get feature dimension from model config
config: Wav2Vec2Config = self.wav2vec2_model.config
feature_dim = config.hidden_size # Typically 1024 for large models
logger.info(f"Wav2Vec2 feature dimension: {feature_dim}")
# Freeze Wav2Vec2 parameters (optional - can be unfrozen for fine-tuning)
# For inference, we typically keep it frozen
for param in self.wav2vec2_model.parameters():
param.requires_grad = False
logger.info("Wav2Vec2 parameters frozen for inference")
# Initialize custom classifier head
self.classifier_head = MultiTaskClassifierHead(
input_dim=feature_dim,
hidden_dims=classifier_hidden_dims,
dropout=dropout,
num_articulation_classes=num_articulation_classes
)
# Try to load trained weights if available (None = try default paths)
self._load_trained_weights(None)
# Move to device
self.wav2vec2_model = self.wav2vec2_model.to(self.device)
self.classifier_head = self.classifier_head.to(self.device)
# Set to eval mode
self.eval()
            # Convert to FP16 if requested (backbone and head together, so the
            # head receives features in the dtype it expects)
            if self.use_fp16:
                self.wav2vec2_model = self.wav2vec2_model.half()
                self.classifier_head = self.classifier_head.half()
                logger.info("Model converted to FP16")
            logger.info("✅ SpeechPathologyClassifier initialized successfully")
except Exception as e:
logger.error(f"❌ Failed to initialize model: {e}", exc_info=True)
raise RuntimeError(f"Failed to load Wav2Vec2 model: {e}") from e
def _load_trained_weights(self, model_path: Optional[str] = None):
"""
Load trained classifier head weights if available.
Args:
model_path: Optional path to model checkpoint. If None, tries default checkpoint paths.
"""
from pathlib import Path
checkpoint_paths = []
# Add user-provided path
if model_path:
checkpoint_paths.append(Path(model_path))
# Add default checkpoint paths
checkpoint_paths.extend([
Path("models/checkpoints/classifier_head_best.pt"),
Path("models/checkpoints/classifier_head_trained.pt")
])
for checkpoint_path in checkpoint_paths:
if checkpoint_path.exists():
try:
checkpoint = torch.load(checkpoint_path, map_location=self.device)
# Handle both full checkpoint dict and state_dict directly
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
state_dict = checkpoint['model_state_dict']
epoch = checkpoint.get('epoch', 'unknown')
val_acc = checkpoint.get('val_accuracy', 'unknown')
else:
state_dict = checkpoint
epoch = 'unknown'
val_acc = 'unknown'
self.classifier_head.load_state_dict(state_dict)
                    logger.info(f"✅ Loaded trained classifier head from {checkpoint_path}")
logger.info(f" Epoch: {epoch}, Validation Accuracy: {val_acc}")
self.is_trained = True
return
except Exception as e:
                    logger.warning(f"⚠️ Could not load checkpoint {checkpoint_path}: {e}")
continue
# No trained weights found
        logger.warning("⚠️ No trained classifier weights found. Using untrained head (beta mode)")
logger.warning(" To train the classifier, run: python training/train_classifier_head.py")
self.is_trained = False
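
    # Checkpoint layout accepted by _load_trained_weights (inferred from the
    # keys read above; the exact output of the training script is an
    # assumption):
    #     {"model_state_dict": <classifier-head state dict>,
    #      "epoch": int, "val_accuracy": float}
    # or a bare state dict saved via torch.save(head.state_dict(), path).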
def forward(
self,
input_values: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None
) -> Dict[str, torch.Tensor]:
"""
Forward pass through the model.
Args:
input_values: Audio input tensor of shape (batch_size, seq_len)
Should be normalized to [-1, 1] range
attention_mask: Optional attention mask for padding
Returns:
Dictionary containing:
                - fluency_logits / fluency_probs: 2-class fluency outputs
                - articulation_logits / articulation_probs: 4-class articulation outputs
                - full_logits / full_probs: 8-class combined outputs
                - shared_features: Pooled shared representation
                - wav2vec2_features: Raw Wav2Vec2 features (for debugging)
"""
# Extract features using Wav2Vec2
        with torch.no_grad() if not self.training else torch.enable_grad():
            wav2vec2_outputs = self.wav2vec2_model(
                input_values=input_values,
                attention_mask=attention_mask
            )
# Get last hidden state (features)
features = wav2vec2_outputs.last_hidden_state # (batch_size, seq_len, feature_dim)
# Safety check: ensure sequence length is valid (at least 1)
if features.shape[1] < 1:
raise ValueError(
f"Wav2Vec2 output sequence length is too short: {features.shape[1]}. "
f"Input was {input_values.shape}. Try using longer audio segments (>= 500ms)."
)
# Pass through classifier head
outputs = self.classifier_head(features, attention_mask)
# Add raw features for debugging/analysis
outputs["wav2vec2_features"] = features
return outputs
def predict(
self,
audio_array: torch.Tensor,
sample_rate: int = 16000,
return_dict: bool = True
) -> Dict[str, torch.Tensor]:
"""
Predict fluency and articulation for audio input.
Args:
audio_array: Audio tensor of shape (seq_len,) or (batch_size, seq_len)
Should be in range [-1, 1]
sample_rate: Sample rate of audio (should match processor, typically 16000)
return_dict: Whether to return dictionary or tuple
Returns:
Dictionary with predictions:
- fluency_score: Float probability of fluent speech (0-1)
- articulation_class: Integer class index (0-3)
- articulation_class_name: String class name
- articulation_probs: Probabilities for all classes
- confidence: Overall confidence score
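
        Example (a minimal sketch; assumes a 1-second 16 kHz clip in [-1, 1]):

            audio = torch.randn(16000).clamp(-1.0, 1.0)
            result = model.predict(audio)
            result["articulation_class_name"]  # e.g. "normal"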
"""
self.eval()
with torch.no_grad():
# Ensure audio is 2D (batch_size, seq_len)
if audio_array.dim() == 1:
audio_array = audio_array.unsqueeze(0)
            # Move to device (and cast to half precision when the model runs in FP16)
            audio_array = audio_array.to(self.device)
            if self.use_fp16:
                audio_array = audio_array.half()
            # Note: self.processor (a Wav2Vec2FeatureExtractor) can normalize raw
            # audio; here we assume the input is already preprocessed
            # (16 kHz mono, roughly in [-1, 1]).
outputs = self.forward(audio_array)
# Extract predictions
fluency_probs = outputs["fluency_probs"].cpu()
articulation_probs = outputs["articulation_probs"].cpu()
            # Get fluency score: probability of the fluent class (assumes label
            # order 0 = disfluent, 1 = fluent, matching the head's
            # "stutter/normal" comment)
            fluency_score = fluency_probs[0, 1].item()
# Get articulation class (argmax)
articulation_probs_flat = articulation_probs[0] if articulation_probs.dim() > 1 else articulation_probs
articulation_class_idx = torch.argmax(articulation_probs_flat).item()
articulation_class_name = self.ARTICULATION_CLASSES[articulation_class_idx]
articulation_confidence = articulation_probs_flat[articulation_class_idx].item()
            # Overall confidence: average of the two heads' top-class probabilities
            fluency_confidence = fluency_probs[0].max().item()
            overall_confidence = (fluency_confidence + articulation_confidence) / 2.0
if return_dict:
return {
"fluency_score": fluency_score,
"articulation_class": articulation_class_idx,
"articulation_class_name": articulation_class_name,
"articulation_probs": articulation_probs_flat.tolist(),
"confidence": overall_confidence,
}
else:
return fluency_score, articulation_class_idx, articulation_probs_flat.tolist()
def get_articulation_class_name(self, class_idx: int) -> str:
"""Get the name of an articulation class by index."""
if 0 <= class_idx < len(self.ARTICULATION_CLASSES):
return self.ARTICULATION_CLASSES[class_idx]
raise ValueError(f"Invalid articulation class index: {class_idx}")
def unfreeze_wav2vec2(self):
"""Unfreeze Wav2Vec2 parameters for fine-tuning."""
logger.info("Unfreezing Wav2Vec2 parameters for fine-tuning")
for param in self.wav2vec2_model.parameters():
param.requires_grad = True
def freeze_wav2vec2(self):
"""Freeze Wav2Vec2 parameters (default for inference)."""
logger.info("Freezing Wav2Vec2 parameters")
for param in self.wav2vec2_model.parameters():
param.requires_grad = False
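
    # Fine-tuning sketch (illustrative; the optimizer choice and learning rate
    # are assumptions, not values prescribed by this module):
    #     model.unfreeze_wav2vec2()
    #     model.train()
    #     optimizer = torch.optim.AdamW(
    #         [p for p in model.parameters() if p.requires_grad], lr=1e-5
    #     )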
def load_speech_pathology_model(
model_name: str = "facebook/wav2vec2-large-xlsr-53",
    classifier_hidden_dims: Optional[List[int]] = None,
dropout: float = 0.1,
device: Optional[str] = None,
use_fp16: bool = False,
model_path: Optional[str] = None
) -> SpeechPathologyClassifier:
"""
Load or create a SpeechPathologyClassifier instance.
Args:
model_name: HuggingFace model identifier
classifier_hidden_dims: Classifier hidden dimensions
dropout: Dropout probability
device: Device to run on
use_fp16: Whether to use FP16
model_path: Optional path to saved model checkpoint
Returns:
SpeechPathologyClassifier instance
"""
if model_path and os.path.exists(model_path):
logger.info(f"Loading model from checkpoint: {model_path}")
model = SpeechPathologyClassifier(
model_name=model_name,
classifier_hidden_dims=classifier_hidden_dims or [256, 128],
dropout=dropout,
device=device,
use_fp16=use_fp16
)
        checkpoint = torch.load(model_path, map_location=device or "cpu")
        # Accept either a full training checkpoint or a bare state dict
        if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
            model.load_state_dict(checkpoint["model_state_dict"])
        else:
            model.load_state_dict(checkpoint)
        logger.info("✅ Model loaded from checkpoint")
return model
else:
logger.info("Creating new SpeechPathologyClassifier")
return SpeechPathologyClassifier(
model_name=model_name,
classifier_hidden_dims=classifier_hidden_dims or [256, 128],
dropout=dropout,
device=device,
use_fp16=use_fp16
)
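

if __name__ == "__main__":
    # Smoke-test sketch (illustrative only): downloads the pretrained
    # Wav2Vec2-XLSR-53 weights on first run, so it needs network access.
    logging.basicConfig(level=logging.INFO)
    model = load_speech_pathology_model(device="cpu")
    dummy_audio = torch.zeros(16000)  # 1 s of silence at 16 kHz
    print(model.predict(dummy_audio))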