"""
Training Script for Speech Pathology Classifier Head

This script fine-tunes the classification head on phoneme-level labeled data.
Wav2Vec2 encoder is frozen; only the classifier head is trained.

Usage:
    python training/train_classifier_head.py --config training/config.yaml
    python training/train_classifier_head.py --config training/config.yaml --resume path/to/checkpoint.pt
"""

import logging
import os
import sys
import json
import yaml
import argparse
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Any
from datetime import datetime

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, confusion_matrix
import librosa
import soundfile as sf

# Add project root to path
sys.path.insert(0, str(Path(__file__).parent.parent))

from models.speech_pathology_model import SpeechPathologyClassifier, MultiTaskClassifierHead
from models.phoneme_mapper import PhonemeMapper
from inference.inference_pipeline import InferencePipeline
from config import default_audio_config, default_model_config, default_inference_config

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Number of output classes produced by ``classifier_head.full_head``.
NUM_CLASSES = 8
# Feature width used for the placeholder tensors of unreadable samples.
FEATURE_DIM = 1024
# Label value assigned to padded frames; such frames are excluded from the
# loss and from every metric. Matches the default ``ignore_index`` of
# ``nn.CrossEntropyLoss``.
IGNORE_INDEX = -100


class PhonemeDataset(Dataset):
    """Dataset for phoneme-level speech pathology training."""

    def __init__(
        self,
        training_data: List[Dict[str, Any]],
        inference_pipeline: InferencePipeline,
        phoneme_mapper: PhonemeMapper
    ):
        """
        Initialize dataset.

        Args:
            training_data: List of training samples with frame labels
            inference_pipeline: Pipeline for extracting Wav2Vec2 features
            phoneme_mapper: Mapper for phoneme alignment
        """
        self.training_data = training_data
        self.inference_pipeline = inference_pipeline
        self.phoneme_mapper = phoneme_mapper
        logger.info(f"Initialized dataset with {len(training_data)} samples")

    def __len__(self) -> int:
        return len(self.training_data)

    @staticmethod
    def _invalid_sample() -> Dict[str, torch.Tensor]:
        """Return the placeholder emitted when a sample cannot be loaded.

        ``collate_fn`` drops every item whose ``valid`` flag is False, so the
        placeholder content never reaches the model.
        """
        return {
            'features': torch.zeros(1, FEATURE_DIM),
            'labels': torch.tensor([0], dtype=torch.long),
            'valid': torch.tensor(False)
        }

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        Get a training sample.

        Args:
            idx: Index into ``training_data``.

        Returns:
            Dict with ``features`` (one row per frame), ``labels`` (one class
            id per frame, same length as ``features``) and ``valid`` (False
            when loading or feature extraction failed).
        """
        sample = self.training_data[idx]
        audio_file = sample['audio_file']
        frame_labels = list(sample['frame_labels'])

        # Load audio
        try:
            audio, sr = librosa.load(audio_file, sr=16000)
        except Exception as e:
            # Best effort: a single unreadable file must not abort training.
            logger.error(f"Failed to load {audio_file}: {e}")
            return self._invalid_sample()

        # Extract Wav2Vec2 features
        try:
            frame_features, frame_times = self.inference_pipeline.get_phone_level_features(audio)

            # Align labels with features
            num_features = len(frame_features)
            num_labels = len(frame_labels)

            # Pad or truncate labels to match features.
            # NOTE(review): missing trailing annotations are filled with
            # class 0, as before — confirm class 0 is the intended default.
            if num_labels < num_features:
                frame_labels = frame_labels + [0] * (num_features - num_labels)
            elif num_labels > num_features:
                frame_labels = frame_labels[:num_features]

            # The pipeline is expected to return a tensor already; detach so
            # no autograd graph from the frozen encoder is carried around.
            features_tensor = torch.as_tensor(frame_features).detach()
            labels_tensor = torch.tensor(frame_labels, dtype=torch.long)

            return {
                'features': features_tensor,
                'labels': labels_tensor,
                'valid': torch.tensor(True)
            }
        except Exception as e:
            logger.error(f"Failed to extract features from {audio_file}: {e}")
            return self._invalid_sample()


def collate_fn(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    """
    Collate function for DataLoader.

    Invalid samples are dropped and the remaining sequences are padded to the
    longest one in the batch. Padded feature rows are zeros and padded labels
    are ``IGNORE_INDEX`` so they can be excluded from loss and metrics.

    Args:
        batch: Items produced by ``PhonemeDataset.__getitem__``.

    Returns:
        Dict with ``features`` of shape (batch, max_len, feat_dim) and
        ``labels`` of shape (batch, max_len). If no item is valid, a single
        placeholder frame labelled ``IGNORE_INDEX`` is returned.
    """
    # Filter out invalid samples
    valid_batch = [item for item in batch if item['valid'].item()]

    if not valid_batch:
        # Placeholder batch; its only frame is marked as ignored so the
        # training/validation loops skip it instead of learning from it.
        return {
            'features': torch.zeros(1, 1, FEATURE_DIM),
            'labels': torch.full((1, 1), IGNORE_INDEX, dtype=torch.long)
        }

    # Pad to same length
    max_len = max(item['features'].shape[0] for item in valid_batch)

    padded_features = []
    padded_labels = []
    for item in valid_batch:
        feat = item['features']
        lab = item['labels']
        padding = max_len - feat.shape[0]
        if padding > 0:
            # new_zeros/new_full keep the dtype and device of the source
            # tensors, so torch.cat cannot fail on a mismatch.
            feat = torch.cat([feat, feat.new_zeros(padding, feat.shape[1])])
            lab = torch.cat([lab, lab.new_full((padding,), IGNORE_INDEX)])
        padded_features.append(feat)
        padded_labels.append(lab)

    return {
        'features': torch.stack(padded_features),
        'labels': torch.stack(padded_labels)
    }


def calculate_class_weights(dataset: Dataset, num_classes: int = NUM_CLASSES) -> torch.Tensor:
    """
    Calculate class weights for imbalanced data.

    Uses inverse frequency weighting: ``total / (num_classes * count)`` for
    every class that occurs; classes that never occur keep weight 1.0.

    Args:
        dataset: A ``PhonemeDataset`` (or a subset of one) yielding items
            with ``labels`` and ``valid`` entries.
        num_classes: Number of classes the weight vector covers.

    Returns:
        Float tensor of shape (num_classes,).
    """
    all_labels: List[int] = []
    # Every item access runs feature extraction, so this pass is as costly as
    # one epoch of data loading.
    for i in range(len(dataset)):
        sample = dataset[i]
        if sample['valid'].item():
            all_labels.extend(sample['labels'].tolist())

    weights = torch.ones(num_classes)
    if not all_labels:
        return weights

    labels = np.asarray(all_labels, dtype=np.int64)
    in_range = labels[(labels >= 0) & (labels < num_classes)]
    if in_range.size != labels.size:
        # Out-of-range ids would otherwise raise IndexError (or, if negative,
        # silently overwrite another class's weight).
        logger.warning(
            f"Ignoring {labels.size - in_range.size} labels outside [0, {num_classes})"
        )
    if in_range.size == 0:
        return weights

    counts = np.bincount(in_range, minlength=num_classes)
    total = int(in_range.size)
    for cls, count in enumerate(counts):
        if count > 0:
            weights[cls] = total / (num_classes * int(count))  # Inverse frequency weighting

    logger.info(f"Class weights: {weights.tolist()}")
    return weights


def _valid_frames(
    batch: Dict[str, torch.Tensor],
    device: torch.device
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Flatten a collated batch and drop frames labelled ``IGNORE_INDEX``."""
    features = batch['features'].to(device)  # (batch, seq_len, feat_dim)
    labels = batch['labels'].to(device)  # (batch, seq_len)

    features_flat = features.reshape(-1, features.shape[-1])
    labels_flat = labels.reshape(-1)

    keep = labels_flat != IGNORE_INDEX
    return features_flat[keep], labels_flat[keep]


def _head_logits(model: nn.Module, features: torch.Tensor) -> torch.Tensor:
    """Run frame features through the shared layers and the full head."""
    shared_features = model.classifier_head.shared_layers(features)
    return model.classifier_head.full_head(shared_features)


def train_epoch(
    model: nn.Module,
    dataloader: DataLoader,
    optimizer: optim.Optimizer,
    criterion: nn.Module,
    device: torch.device,
    epoch: int
) -> Dict[str, float]:
    """
    Train for one epoch.

    Args:
        model: Model exposing ``classifier_head`` with ``shared_layers`` and
            ``full_head``.
        dataloader: Loader yielding batches built by ``collate_fn``.
        optimizer: Optimizer over the classifier head parameters.
        criterion: Loss taking (logits, class ids).
        device: Device the batch tensors are moved to.
        epoch: 1-based epoch number, used for logging only.

    Returns:
        Dict with mean ``loss`` over the processed batches and frame-level
        ``accuracy`` (both 0.0 if no batch contained a labelled frame).
    """
    # Only the head trains. The rest of the model stays in eval mode so that
    # train-only layers (e.g. dropout) in the frozen encoder do not perturb
    # the features the dataset extracts while this epoch is running.
    model.eval()
    model.classifier_head.train()

    total_loss = 0.0
    num_batches = 0
    all_preds: List[int] = []
    all_labels: List[int] = []

    for batch_idx, batch in enumerate(dataloader):
        features, labels = _valid_frames(batch, device)
        if labels.numel() == 0:
            # Nothing but padding / failed samples: a loss over zero targets
            # would be NaN and corrupt the weights.
            logger.warning(f"Epoch {epoch}, Batch {batch_idx}: no valid frames, skipping")
            continue

        # Forward pass
        optimizer.zero_grad()
        logits = _head_logits(model, features)  # (num_valid_frames, num_classes)

        # Calculate loss
        loss = criterion(logits, labels)

        # Backward pass
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.classifier_head.parameters(), max_norm=1.0)
        optimizer.step()

        # Metrics
        total_loss += loss.item()
        num_batches += 1
        all_preds.extend(torch.argmax(logits, dim=-1).cpu().tolist())
        all_labels.extend(labels.cpu().tolist())

        if batch_idx % 10 == 0:
            logger.info(f"Epoch {epoch}, Batch {batch_idx}/{len(dataloader)}, Loss: {loss.item():.4f}")

    avg_loss = total_loss / num_batches if num_batches else 0.0
    accuracy = float(accuracy_score(all_labels, all_preds)) if all_labels else 0.0

    return {
        'loss': avg_loss,
        'accuracy': accuracy
    }


def validate(
    model: nn.Module,
    dataloader: DataLoader,
    criterion: nn.Module,
    device: torch.device
) -> Dict[str, float]:
    """
    Validate model.

    Args:
        model: Model exposing ``classifier_head`` with ``shared_layers`` and
            ``full_head``.
        dataloader: Loader yielding batches built by ``collate_fn``.
        criterion: Loss taking (logits, class ids).
        device: Device the batch tensors are moved to.

    Returns:
        Dict with ``loss``, ``accuracy``, weighted ``f1_score``,
        ``precision``, ``recall`` and the ``confusion_matrix`` as nested
        lists. When no labelled frame was evaluated, ``loss`` is NaN (so it
        never counts as an improvement) and the other metrics are 0.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    all_preds: List[int] = []
    all_labels: List[int] = []

    with torch.no_grad():
        for batch in dataloader:
            features, labels = _valid_frames(batch, device)
            if labels.numel() == 0:
                continue

            # Forward pass
            logits = _head_logits(model, features)
            loss = criterion(logits, labels)

            total_loss += loss.item()
            num_batches += 1
            all_preds.extend(torch.argmax(logits, dim=-1).cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

    if not all_labels:
        logger.warning("Validation saw no valid frames")
        return {
            'loss': float('nan'),
            'accuracy': 0.0,
            'f1_score': 0.0,
            'precision': 0.0,
            'recall': 0.0,
            'confusion_matrix': np.zeros((NUM_CLASSES, NUM_CLASSES), dtype=int).tolist()
        }

    avg_loss = total_loss / num_batches
    accuracy = accuracy_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0)
    precision = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
    recall = recall_score(all_labels, all_preds, average='weighted', zero_division=0)

    # Per-class metrics
    cm = confusion_matrix(all_labels, all_preds, labels=list(range(NUM_CLASSES)))

    return {
        'loss': avg_loss,
        'accuracy': accuracy,
        'f1_score': f1,
        'precision': precision,
        'recall': recall,
        'confusion_matrix': cm.tolist()
    }


def _save_checkpoint(
    checkpoint_path: Path,
    epoch: int,
    model: nn.Module,
    optimizer: optim.Optimizer,
    val_metrics: Dict[str, Any],
    config: Dict[str, Any]
) -> None:
    """Write classifier-head weights, optimizer state and metrics to disk."""
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.classifier_head.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'val_loss': val_metrics['loss'],
        'val_accuracy': val_metrics['accuracy'],
        'config': config
    }, checkpoint_path)


def main():
    """
    Entry point: parse arguments, build data and model, run the training loop.

    Reads all hyperparameters from the YAML file given by ``--config``. With
    ``--resume`` the classifier head, optimizer state, epoch counter and
    reference validation loss are restored from a checkpoint written by this
    script.
    """
    parser = argparse.ArgumentParser(description="Train classifier head")
    parser.add_argument('--config', type=str, default='training/config.yaml',
                        help='Path to config file')
    parser.add_argument('--resume', type=str, default=None,
                        help='Resume from checkpoint')
    args = parser.parse_args()

    # Load config
    with open(args.config, 'r') as f:
        config = yaml.safe_load(f)

    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() and config['device']['use_cuda'] else 'cpu')
    logger.info(f"Using device: {device}")

    # Load training data
    training_file = Path(config['data']['training_dataset'])
    if not training_file.exists():
        logger.error(f"Training dataset not found: {training_file}")
        logger.info("Run scripts/annotation_helper.py to export training data first")
        return

    with open(training_file, 'r') as f:
        training_data = json.load(f)

    logger.info(f"Loaded {len(training_data)} training samples")

    # Initialize inference pipeline for feature extraction
    inference_pipeline = InferencePipeline(
        audio_config=default_audio_config,
        model_config=default_model_config,
        inference_config=default_inference_config
    )

    # Initialize phoneme mapper
    phoneme_mapper = PhonemeMapper(
        frame_duration_ms=20,
        sample_rate=16000
    )

    # Create dataset
    dataset = PhonemeDataset(training_data, inference_pipeline, phoneme_mapper)

    # Split dataset
    train_size = int(config['data']['train_split'] * len(dataset))
    val_size = len(dataset) - train_size
    if train_size <= 0 or val_size <= 0:
        # An empty split would otherwise surface later as an obscure sampler
        # error or a division by zero in validation.
        logger.error(
            f"Invalid split: {train_size} train / {val_size} val samples; "
            "need at least one sample in each"
        )
        return

    train_dataset, val_dataset = random_split(
        dataset,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(config['data']['random_seed'])
    )

    logger.info(f"Train samples: {len(train_dataset)}, Val samples: {len(val_dataset)}")

    # Create data loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=config['training']['batch_size'],
        shuffle=True,
        collate_fn=collate_fn
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=config['training']['batch_size'],
        shuffle=False,
        collate_fn=collate_fn
    )

    # Load model. Only the classifier head is put in training mode; the
    # frozen encoder stays in eval mode because it is still used for feature
    # extraction by the dataset.
    # NOTE(review): the model is used on whatever device the pipeline placed
    # it on — confirm that matches `device` above.
    model = inference_pipeline.model
    model.eval()
    model.classifier_head.train()

    # Freeze Wav2Vec2 (should already be frozen, but ensure it)
    for param in model.wav2vec2_model.parameters():
        param.requires_grad = False

    # Unfreeze classifier head
    for param in model.classifier_head.parameters():
        param.requires_grad = True

    logger.info("Model prepared: Wav2Vec2 frozen, classifier head trainable")

    # Calculate class weights from the training split only, so validation
    # labels do not influence the loss.
    class_weights = calculate_class_weights(train_dataset)
    class_weights = class_weights.to(device)

    # Loss function
    loss_type = config['training']['loss']['type']
    if loss_type != 'cross_entropy':
        # Focal loss implementation would go here
        logger.warning(f"Loss type '{loss_type}' is not implemented; falling back to cross_entropy")
    criterion = nn.CrossEntropyLoss(weight=class_weights, ignore_index=IGNORE_INDEX)

    # Optimizer
    optimizer = optim.Adam(
        model.classifier_head.parameters(),
        lr=config['training']['learning_rate'],
        weight_decay=config['training']['weight_decay']
    )

    # Scheduler
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=config['training']['scheduler_factor'],
        patience=config['training']['scheduler_patience'],
        min_lr=config['training']['scheduler_min_lr']
    )

    # Training state
    best_val_loss = float('inf')
    patience_counter = 0
    start_epoch = 0
    checkpoint_dir = Path(config['checkpoint']['save_dir'])
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # Resume from checkpoint
    if args.resume:
        resume_path = Path(args.resume)
        if not resume_path.exists():
            logger.error(f"Checkpoint not found: {resume_path}")
            return
        checkpoint = torch.load(resume_path, map_location=device)
        model.classifier_head.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # Checkpoints store the 0-based index of the last finished epoch.
        start_epoch = checkpoint.get('epoch', -1) + 1
        # The checkpoint's own validation loss becomes the bar to beat.
        best_val_loss = checkpoint.get('val_loss', float('inf'))
        logger.info(f"Resumed from {resume_path} at epoch {start_epoch + 1}")

    num_epochs = config['training']['num_epochs']
    for epoch in range(start_epoch, num_epochs):
        logger.info(f"\n{'='*50}")
        logger.info(f"Epoch {epoch+1}/{num_epochs}")
        logger.info(f"{'='*50}")

        # Train
        train_metrics = train_epoch(model, train_loader, optimizer, criterion, device, epoch+1)
        logger.info(f"Train - Loss: {train_metrics['loss']:.4f}, Accuracy: {train_metrics['accuracy']:.4f}")

        # Validate
        val_metrics = validate(model, val_loader, criterion, device)
        logger.info(f"Val - Loss: {val_metrics['loss']:.4f}, Accuracy: {val_metrics['accuracy']:.4f}, "
                    f"F1: {val_metrics['f1_score']:.4f}")

        # Scheduler step
        scheduler.step(val_metrics['loss'])

        # Track improvement independently of whether the best checkpoint is
        # saved; otherwise early stopping would fire after `patience` epochs
        # whenever `save_best` is disabled.
        if val_metrics['loss'] < best_val_loss:
            best_val_loss = val_metrics['loss']
            patience_counter = 0
            if config['checkpoint']['save_best']:
                checkpoint_path = checkpoint_dir / config['checkpoint']['best_filename']
                _save_checkpoint(checkpoint_path, epoch, model, optimizer, val_metrics, config)
                logger.info(f"✅ Saved best checkpoint to {checkpoint_path}")
        else:
            patience_counter += 1

        # Save last checkpoint (before the early-stopping check so the final
        # epoch is not lost when training stops early)
        if config['checkpoint']['save_last'] and (epoch + 1) % config['checkpoint']['save_frequency'] == 0:
            checkpoint_path = checkpoint_dir / config['checkpoint']['filename']
            _save_checkpoint(checkpoint_path, epoch, model, optimizer, val_metrics, config)
            logger.info(f"Saved checkpoint to {checkpoint_path}")

        # Early stopping
        if config['training']['early_stopping']['enabled']:
            if patience_counter >= config['training']['early_stopping']['patience']:
                logger.info(f"Early stopping triggered after {epoch+1} epochs")
                break

    logger.info("\n✅ Training complete!")
    logger.info(f"Best validation loss: {best_val_loss:.4f}")


if __name__ == "__main__":
    main()