|
|
""" |
|
|
Speech Pathology Classifier Model |
|
|
|
|
|
This module implements a multi-task speech pathology classifier using Wav2Vec2-XLSR-53 |
|
|
as the feature extractor with a custom classifier head for: |
|
|
- Fluency scoring (binary classification) |
|
|
- Articulation classification (4 classes: normal, substitution, omission, distortion) |
|
|
""" |
|
|
|
|
|
import logging |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torch.nn import functional as F |
|
|
from transformers import Wav2Vec2Model, Wav2Vec2FeatureExtractor, Wav2Vec2Config |
|
|
from typing import Dict, Optional, Tuple, List |
|
|
import os |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
class MultiTaskClassifierHead(nn.Module):
    """
    Multi-task classifier head for speech pathology diagnosis.

    Consumes frame-level Wav2Vec2 features, mean-pools them over time
    (using a masked average when a padding mask is supplied), pushes the
    pooled vector through a shared MLP trunk, and emits logits and
    softmax probabilities for three outputs:

    1. Fluency (2 classes: fluent vs disfluent)
    2. Articulation (``num_articulation_classes`` classes, default 4:
       normal, substitution, omission, distortion)
    3. A "full" 8-way output (presumably the joint fluency x articulation
       label space -- TODO confirm against the training code)
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dims: List[int],
        dropout: float = 0.1,
        num_articulation_classes: int = 4
    ):
        """
        Build the shared trunk and the task-specific heads.

        Args:
            input_dim: Feature dimension produced by Wav2Vec2
                (typically 1024 for the large model).
            hidden_dims: Sizes of the shared hidden layers, e.g. [256, 128].
            dropout: Dropout probability used throughout for regularization.
            num_articulation_classes: Number of articulation classes (default: 4).
        """
        super().__init__()

        self.num_articulation_classes = num_articulation_classes

        # Shared trunk: one Linear -> LayerNorm -> ReLU -> Dropout group
        # per requested hidden dimension.
        trunk = []
        in_features = input_dim
        for out_features in hidden_dims:
            trunk += [
                nn.Linear(in_features, out_features),
                nn.LayerNorm(out_features),
                nn.ReLU(),
                nn.Dropout(dropout),
            ]
            in_features = out_features
        self.shared_layers = nn.Sequential(*trunk)
        trunk_dim = in_features

        # Task-specific heads: small two-layer MLPs on top of the trunk.
        self.fluency_head = nn.Sequential(
            nn.Linear(trunk_dim, 64),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(64, 2),
        )

        self.articulation_head = nn.Sequential(
            nn.Linear(trunk_dim, 64),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(64, num_articulation_classes),
        )

        # Fixed 8-way output head; its label semantics are not defined in
        # this module -- verify against the training pipeline.
        self.full_head = nn.Sequential(
            nn.Linear(trunk_dim, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, 8),
        )

        logger.info(
            f"Initialized MultiTaskClassifierHead: "
            f"input_dim={input_dim}, hidden_dims={hidden_dims}, "
            f"articulation_classes={num_articulation_classes}"
        )

    def forward(
        self,
        features: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Run the head on a batch of frame-level features.

        Args:
            features: Wav2Vec2 features, shape (batch_size, seq_len, feature_dim).
            attention_mask: Optional (batch_size, seq_len) mask whose nonzero
                entries mark valid frames; it must align with ``seq_len``
                (i.e. be a feature-frame mask, not a raw-audio-sample mask).

        Returns:
            Dictionary with:
            - fluency_logits / fluency_probs: (batch_size, 2)
            - articulation_logits / articulation_probs:
              (batch_size, num_articulation_classes)
            - full_logits / full_probs: (batch_size, 8)
            - shared_features: trunk output fed to the task heads
        """
        # Mean-pool over the time axis. With a mask this is a masked
        # average so padded frames do not dilute the representation; the
        # clamp guards against division by zero for all-padding rows.
        if attention_mask is None:
            pooled = features.mean(dim=1)
        else:
            mask = attention_mask.unsqueeze(-1).expand(features.size()).float()
            total = (features * mask).sum(dim=1)
            counts = mask.sum(dim=1).clamp(min=1e-9)
            pooled = total / counts

        shared = self.shared_layers(pooled)

        fluency_logits = self.fluency_head(shared)
        articulation_logits = self.articulation_head(shared)
        full_logits = self.full_head(shared)

        return {
            "fluency_logits": fluency_logits,
            "articulation_logits": articulation_logits,
            "full_logits": full_logits,
            "fluency_probs": F.softmax(fluency_logits, dim=-1),
            "articulation_probs": F.softmax(articulation_logits, dim=-1),
            "full_probs": F.softmax(full_logits, dim=-1),
            "shared_features": shared,
        }
|
|
|
|
|
|
|
|
class SpeechPathologyClassifier(nn.Module):
    """
    Speech Pathology Classifier using Wav2Vec2-XLSR-53 with a custom multi-task head.

    This model combines:
    - Wav2Vec2-XLSR-53: pretrained speech feature extractor (frozen by default)
    - MultiTaskClassifierHead: trainable fluency / articulation classifier

    Outputs:
    - Fluency score: probability of fluent speech (0-1)
    - Articulation classes: probabilities for the 4 articulation types
    """

    # Index order must match the label encoding used at training time.
    ARTICULATION_CLASSES = [
        "normal",
        "substitution",
        "omission",
        "distortion"
    ]

    def __init__(
        self,
        model_name: str = "facebook/wav2vec2-large-xlsr-53",
        classifier_hidden_dims: Optional[List[int]] = None,
        dropout: float = 0.1,
        num_articulation_classes: int = 4,
        device: Optional[str] = None,
        use_fp16: bool = False,
        model_path: Optional[str] = None
    ):
        """
        Initialize the Speech Pathology Classifier.

        Args:
            model_name: HuggingFace model identifier for Wav2Vec2-XLSR-53.
            classifier_hidden_dims: Hidden layer dimensions for the classifier
                head. Default: [256, 128].
            dropout: Dropout probability for classifier layers.
            num_articulation_classes: Number of articulation classes (default: 4).
            device: Device to run on ("cuda" or "cpu"). Auto-detects if None.
            use_fp16: Run the Wav2Vec2 backbone in half precision (CUDA only).
            model_path: Optional explicit checkpoint path for the classifier
                head; default checkpoint locations are also tried.

        Raises:
            RuntimeError: If the Wav2Vec2 model cannot be loaded.
        """
        super().__init__()

        # Resolve the execution device; degrade gracefully to CPU when CUDA
        # is requested but unavailable.
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        if device == "cuda" and not torch.cuda.is_available():
            logger.warning("CUDA requested but not available. Falling back to CPU.")
            device = "cpu"

        self.device = torch.device(device)
        # FP16 only makes sense on CUDA.
        self.use_fp16 = use_fp16 and device == "cuda"
        self.is_trained = False

        if classifier_hidden_dims is None:
            classifier_hidden_dims = [256, 128]

        logger.info(f"Initializing SpeechPathologyClassifier on {device}")
        logger.info(f"Model: {model_name}")
        logger.info(f"Classifier hidden dims: {classifier_hidden_dims}")
        logger.info(f"FP16: {self.use_fp16}")

        try:
            # Optional HF auth token for gated/private models.
            hf_token = os.getenv("HF_TOKEN")

            logger.info("Loading Wav2Vec2 model and feature extractor...")
            self.wav2vec2_model = Wav2Vec2Model.from_pretrained(
                model_name,
                token=hf_token if hf_token else None
            )
            self.processor = Wav2Vec2FeatureExtractor.from_pretrained(
                model_name,
                token=hf_token if hf_token else None
            )

            config: Wav2Vec2Config = self.wav2vec2_model.config
            feature_dim = config.hidden_size  # 1024 for the large XLSR model
            logger.info(f"Wav2Vec2 feature dimension: {feature_dim}")

            # Freeze the backbone: only the classifier head is trainable.
            for param in self.wav2vec2_model.parameters():
                param.requires_grad = False
            logger.info("Wav2Vec2 parameters frozen for inference")

            self.classifier_head = MultiTaskClassifierHead(
                input_dim=feature_dim,
                hidden_dims=classifier_hidden_dims,
                dropout=dropout,
                num_articulation_classes=num_articulation_classes
            )

            # Load trained head weights if a checkpoint is available.
            self._load_trained_weights(model_path)

            self.wav2vec2_model = self.wav2vec2_model.to(self.device)
            self.classifier_head = self.classifier_head.to(self.device)

            # Default to inference mode.
            self.eval()

            if self.use_fp16:
                # Only the backbone is halved; the classifier head stays in
                # FP32 and forward() casts the features back to float32.
                self.wav2vec2_model = self.wav2vec2_model.half()
                logger.info("Model converted to FP16")

            logger.info("SpeechPathologyClassifier initialized successfully")

        except Exception as e:
            logger.error(f"Failed to initialize model: {e}", exc_info=True)
            raise RuntimeError(f"Failed to load Wav2Vec2 model: {e}") from e

    def _load_trained_weights(self, model_path: Optional[str] = None):
        """
        Load trained classifier-head weights from the first available checkpoint.

        Args:
            model_path: Optional explicit checkpoint path. The default
                checkpoint locations are tried afterwards.

        Sets ``self.is_trained``; the model still runs (untrained, "beta
        mode") when no checkpoint is found.
        """
        from pathlib import Path

        checkpoint_paths = []

        if model_path:
            checkpoint_paths.append(Path(model_path))

        # Default locations produced by the training script.
        checkpoint_paths.extend([
            Path("models/checkpoints/classifier_head_best.pt"),
            Path("models/checkpoints/classifier_head_trained.pt")
        ])

        for checkpoint_path in checkpoint_paths:
            if not checkpoint_path.exists():
                continue
            try:
                # SECURITY: torch.load unpickles arbitrary objects -- only
                # load checkpoints from trusted sources.
                checkpoint = torch.load(checkpoint_path, map_location=self.device)

                # Support both full training checkpoints and raw state dicts.
                if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
                    state_dict = checkpoint['model_state_dict']
                    epoch = checkpoint.get('epoch', 'unknown')
                    val_acc = checkpoint.get('val_accuracy', 'unknown')
                else:
                    state_dict = checkpoint
                    epoch = 'unknown'
                    val_acc = 'unknown'

                self.classifier_head.load_state_dict(state_dict)
                logger.info(f"Loaded trained classifier head from {checkpoint_path}")
                logger.info(f"  Epoch: {epoch}, Validation Accuracy: {val_acc}")
                self.is_trained = True
                return
            except Exception as e:
                # A bad checkpoint should not be fatal; try the next one.
                logger.warning(f"Could not load checkpoint {checkpoint_path}: {e}")

        logger.warning("No trained classifier weights found. Using untrained head (beta mode)")
        logger.warning("  To train the classifier, run: python training/train_classifier_head.py")
        self.is_trained = False

    def forward(
        self,
        input_values: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass through the model.

        Args:
            input_values: Audio tensor of shape (batch_size, num_samples),
                normalized to the [-1, 1] range.
            attention_mask: Optional sample-level padding mask of shape
                (batch_size, num_samples).

        Returns:
            Dictionary containing:
            - fluency_logits / fluency_probs: (batch_size, 2)
            - articulation_logits / articulation_probs: (batch_size, num_classes)
            - full_logits / full_probs: (batch_size, 8)
            - shared_features: classifier trunk output
            - wav2vec2_features: raw Wav2Vec2 features (for debugging)

        Raises:
            ValueError: If the audio is too short to yield any Wav2Vec2 frames.
        """
        # Match the backbone's dtype when running in half precision.
        if self.use_fp16 and input_values.dtype == torch.float32:
            input_values = input_values.half()

        # The frozen backbone never needs gradients during inference; allow
        # them only when the whole module is in training mode.
        with torch.no_grad() if not self.training else torch.enable_grad():
            wav2vec2_outputs = self.wav2vec2_model(
                input_values=input_values,
                attention_mask=attention_mask
            )

        # (batch_size, num_frames, hidden_size); cast back to FP32 so the
        # FP32 classifier head can consume it (no-op when not using FP16).
        features = wav2vec2_outputs.last_hidden_state.float()

        if features.shape[1] < 1:
            raise ValueError(
                f"Wav2Vec2 output sequence length is too short: {features.shape[1]}. "
                f"Input was {input_values.shape}. Try using longer audio segments (>= 500ms)."
            )

        # The incoming mask is per audio sample, but the features are per
        # frame (the conv feature extractor downsamples ~320x), so convert
        # it to frame resolution before masked pooling.
        # NOTE: _get_feature_vector_attention_mask is a private transformers
        # API -- revisit if the library is upgraded.
        feature_mask = None
        if attention_mask is not None:
            feature_mask = self.wav2vec2_model._get_feature_vector_attention_mask(
                features.shape[1], attention_mask
            )

        outputs = self.classifier_head(features, feature_mask)

        # Expose raw backbone features for debugging/inspection.
        outputs["wav2vec2_features"] = features

        return outputs

    def predict(
        self,
        audio_array: torch.Tensor,
        sample_rate: int = 16000,
        return_dict: bool = True
    ):
        """
        Predict fluency and articulation for a single audio input.

        Args:
            audio_array: Audio tensor of shape (num_samples,) or
                (batch_size, num_samples) in [-1, 1]; only the first batch
                item is reported.
            sample_rate: Sample rate of the audio. Currently informational
                only -- no resampling is performed, so audio must already be
                at the processor's rate (typically 16000).
            return_dict: Return a dict (True) or a tuple (False).

        Returns:
            When return_dict is True, a dict with:
            - fluency_score: float probability of fluent speech (0-1)
            - articulation_class: integer class index (0-3)
            - articulation_class_name: string class name
            - articulation_probs: probabilities for all classes
            - confidence: overall confidence score
            Otherwise the tuple
            (fluency_score, articulation_class, articulation_probs).
        """
        self.eval()

        with torch.no_grad():
            # Promote a single clip to a batch of one.
            if audio_array.dim() == 1:
                audio_array = audio_array.unsqueeze(0)

            audio_array = audio_array.to(self.device)

            outputs = self.forward(audio_array)

            fluency_probs = outputs["fluency_probs"].cpu()
            articulation_probs = outputs["articulation_probs"].cpu()

            # fluency_probs is a (batch, 2) softmax output; report the
            # probability of class index 1 for the first item.
            # Assumes label 1 == "fluent" -- confirm against training labels.
            fluency_score = fluency_probs[0, 1].item()

            articulation_probs_flat = articulation_probs[0] if articulation_probs.dim() > 1 else articulation_probs
            articulation_class_idx = int(torch.argmax(articulation_probs_flat).item())
            articulation_class_name = self.ARTICULATION_CLASSES[articulation_class_idx]
            articulation_confidence = articulation_probs_flat[articulation_class_idx].item()

            # Simple average of the two per-task confidences.
            overall_confidence = (fluency_score + articulation_confidence) / 2.0

            if return_dict:
                return {
                    "fluency_score": fluency_score,
                    "articulation_class": articulation_class_idx,
                    "articulation_class_name": articulation_class_name,
                    "articulation_probs": articulation_probs_flat.tolist(),
                    "confidence": overall_confidence,
                }
            else:
                return fluency_score, articulation_class_idx, articulation_probs_flat.tolist()

    def get_articulation_class_name(self, class_idx: int) -> str:
        """Get the name of an articulation class by index.

        Raises:
            ValueError: If ``class_idx`` is out of range.
        """
        if 0 <= class_idx < len(self.ARTICULATION_CLASSES):
            return self.ARTICULATION_CLASSES[class_idx]
        raise ValueError(f"Invalid articulation class index: {class_idx}")

    def unfreeze_wav2vec2(self):
        """Unfreeze Wav2Vec2 parameters for fine-tuning."""
        logger.info("Unfreezing Wav2Vec2 parameters for fine-tuning")
        for param in self.wav2vec2_model.parameters():
            param.requires_grad = True

    def freeze_wav2vec2(self):
        """Freeze Wav2Vec2 parameters (default for inference)."""
        logger.info("Freezing Wav2Vec2 parameters")
        for param in self.wav2vec2_model.parameters():
            param.requires_grad = False
|
|
|
|
|
|
|
|
def load_speech_pathology_model(
    model_name: str = "facebook/wav2vec2-large-xlsr-53",
    classifier_hidden_dims: Optional[List[int]] = None,
    dropout: float = 0.1,
    device: Optional[str] = None,
    use_fp16: bool = False,
    model_path: Optional[str] = None
) -> SpeechPathologyClassifier:
    """
    Load or create a SpeechPathologyClassifier instance.

    Args:
        model_name: HuggingFace model identifier.
        classifier_hidden_dims: Classifier hidden dimensions
            (default: [256, 128]).
        dropout: Dropout probability.
        device: Device to run on ("cuda" or "cpu"); auto-detected if None.
        use_fp16: Whether to use FP16.
        model_path: Optional path to a saved model checkpoint; ignored with
            a log message if the file does not exist.

    Returns:
        SpeechPathologyClassifier instance, with checkpoint weights applied
        when ``model_path`` points to an existing file.
    """
    # The model is built identically either way; the checkpoint (if any)
    # is applied on top afterwards.
    model = SpeechPathologyClassifier(
        model_name=model_name,
        classifier_hidden_dims=classifier_hidden_dims or [256, 128],
        dropout=dropout,
        device=device,
        use_fp16=use_fp16
    )

    if model_path and os.path.exists(model_path):
        logger.info(f"Loading model from checkpoint: {model_path}")
        # SECURITY: torch.load unpickles arbitrary objects -- only load
        # checkpoints from trusted sources.
        checkpoint = torch.load(model_path, map_location=device or "cpu")
        # Accept both full training checkpoints ({"model_state_dict": ...})
        # and raw state dicts, mirroring _load_trained_weights.
        if isinstance(checkpoint, dict) and "model_state_dict" in checkpoint:
            state_dict = checkpoint["model_state_dict"]
        else:
            state_dict = checkpoint
        model.load_state_dict(state_dict)
        logger.info("Model loaded from checkpoint")
    else:
        logger.info("Creating new SpeechPathologyClassifier")

    return model
|
|
|
|
|
|