|
|
""" |
|
|
Training Script for Speech Pathology Classifier Head |
|
|
|
|
|
This script fine-tunes the classification head on phoneme-level labeled data. |
|
|
The Wav2Vec2 encoder is frozen; only the classifier head is trained.
|
|
|
|
|
Usage: |
|
|
python training/train_classifier_head.py --config training/config.yaml |
|
|
""" |
|
|
|
|
|
import logging |
|
|
import os |
|
|
import sys |
|
|
import json |
|
|
import yaml |
|
|
import argparse |
|
|
from pathlib import Path |
|
|
from typing import Dict, List, Tuple, Optional, Any |
|
|
from datetime import datetime |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.optim as optim |
|
|
from torch.utils.data import Dataset, DataLoader, random_split |
|
|
import numpy as np |
|
|
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, confusion_matrix |
|
|
import librosa |
|
|
import soundfile as sf |
|
|
|
|
|
|
|
|
sys.path.insert(0, str(Path(__file__).parent.parent)) |
|
|
|
|
|
from models.speech_pathology_model import SpeechPathologyClassifier, MultiTaskClassifierHead |
|
|
from models.phoneme_mapper import PhonemeMapper |
|
|
from inference.inference_pipeline import InferencePipeline |
|
|
from config import default_audio_config, default_model_config, default_inference_config |
|
|
|
|
|
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'

# Configure root logging for the script run, then fetch this module's logger.
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
class PhonemeDataset(Dataset):
    """Dataset for phoneme-level speech pathology training.

    Each item pairs per-frame features, extracted on the fly through
    ``inference_pipeline.get_phone_level_features``, with per-frame integer
    labels taken from the sample record.
    """

    # Feature width of the placeholder returned for unusable samples. The
    # placeholder carries ``valid=False`` and is dropped by ``collate_fn``, so
    # it only needs to be a well-formed 2-D tensor.
    _PLACEHOLDER_FEATURE_DIM = 1024

    def __init__(
        self,
        training_data: List[Dict[str, Any]],
        inference_pipeline: InferencePipeline,
        phoneme_mapper: PhonemeMapper,
        sample_rate: int = 16000,
    ):
        """
        Initialize dataset.

        Args:
            training_data: List of training samples. Each sample must provide
                ``'audio_file'`` (a path librosa can load) and
                ``'frame_labels'`` (a sequence of integer class ids, one per frame).
            inference_pipeline: Pipeline for extracting Wav2Vec2 features
            phoneme_mapper: Mapper for phoneme alignment (stored on the
                instance; not used by ``__getitem__``)
            sample_rate: Rate audio is resampled to when loaded (default 16 kHz)
        """
        self.training_data = training_data
        self.inference_pipeline = inference_pipeline
        self.phoneme_mapper = phoneme_mapper
        self.sample_rate = sample_rate

        logger.info(f"Initialized dataset with {len(training_data)} samples")

    def __len__(self) -> int:
        return len(self.training_data)

    @classmethod
    def _invalid_sample(cls) -> Dict[str, torch.Tensor]:
        """Build the placeholder item used for samples that cannot be used."""
        return {
            'features': torch.zeros(1, cls._PLACEHOLDER_FEATURE_DIM),
            'labels': torch.tensor([0], dtype=torch.long),
            'valid': torch.tensor(False),
        }

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Get a training sample.

        Returns:
            Dict with ``'features'`` (float32 CPU tensor of shape
            (num_frames, feat_dim)), ``'labels'`` (long tensor of length
            num_frames) and ``'valid'`` (bool scalar tensor). Labels are
            padded with class 0 when there are fewer labels than frames, and
            truncated when there are more. Samples whose audio cannot be
            loaded, or whose features are missing or malformed, are returned
            as a placeholder with ``valid`` set to False.
        """
        sample = self.training_data[idx]
        audio_file = sample['audio_file']
        frame_labels = list(sample['frame_labels'])

        # Best-effort loading: a bad file is logged and skipped rather than
        # aborting the whole epoch.
        try:
            audio, _ = librosa.load(audio_file, sr=self.sample_rate)
        except Exception as e:
            logger.error(f"Failed to load {audio_file}: {e}")
            return self._invalid_sample()

        try:
            frame_features, _ = self.inference_pipeline.get_phone_level_features(audio)
            # collate_fn pads with CPU float32 zeros and concatenates/stacks the
            # items, so normalise whatever the pipeline returns (presumably a
            # tensor or array of shape (frames, dim) -- TODO confirm) into a
            # detached float32 CPU tensor.
            features = torch.as_tensor(frame_features).detach().cpu().float()
        except Exception as e:
            logger.error(f"Failed to extract features from {audio_file}: {e}")
            return self._invalid_sample()

        if features.ndim != 2 or features.shape[0] == 0:
            logger.error(
                f"Failed to extract features from {audio_file}: expected non-empty "
                f"(frames, dim) features, got shape {tuple(features.shape)}"
            )
            return self._invalid_sample()

        num_frames = features.shape[0]
        # Align the labels to the feature frames: drop labels beyond the last
        # frame and fill missing trailing frames with class 0.
        aligned_labels = frame_labels[:num_frames] + [0] * max(0, num_frames - len(frame_labels))

        return {
            'features': features,
            'labels': torch.tensor(aligned_labels, dtype=torch.long),
            'valid': torch.tensor(True),
        }
|
|
|
|
|
|
|
|
def collate_fn(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    """Collate function for DataLoader.

    Items whose ``'valid'`` flag is False are dropped. The remaining items are
    right-padded with zeros to the longest sequence in the batch.

    Returns:
        Dict with:
          - ``'features'``: tensor of shape (batch, max_len, feat_dim)
          - ``'labels'``: long tensor of shape (batch, max_len), padded with 0
          - ``'mask'``: bool tensor of shape (batch, max_len). It is True on
            real frames and False on padding. Padded labels are 0, which is
            also a real class id, so consumers should use this mask to keep
            padding out of the loss and metrics.
        If no item is valid, a single one-frame dummy batch is returned whose
        mask is entirely False.
    """
    valid_items = [item for item in batch if item['valid'].item()]

    if not valid_items:
        return {
            'features': torch.zeros(1, 1, 1024),
            'labels': torch.zeros(1, 1, dtype=torch.long),
            'mask': torch.zeros(1, 1, dtype=torch.bool),
        }

    features = [item['features'] for item in valid_items]
    labels = [item['labels'] for item in valid_items]

    # pad_sequence zero-pads along the first (time) dimension.
    padded_features = torch.nn.utils.rnn.pad_sequence(features, batch_first=True)
    padded_labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True)

    lengths = torch.tensor([feat.shape[0] for feat in features])
    mask = torch.arange(padded_features.shape[1]).unsqueeze(0) < lengths.unsqueeze(1)

    return {
        'features': padded_features,
        'labels': padded_labels,
        'mask': mask,
    }
|
|
|
|
|
|
|
|
def calculate_class_weights(dataset: Dataset, num_classes: int = 8) -> torch.Tensor:
    """Calculate inverse-frequency class weights for imbalanced frame labels.

    Every item of ``dataset`` is fetched, so for a PhonemeDataset this runs
    feature extraction once per sample. Labels of items whose ``'valid'`` flag
    is set are counted. A class seen ``count`` times out of ``total`` labels
    gets weight ``total / (num_classes * count)``. Classes that never occur
    keep weight 1.0.

    Args:
        dataset: Indexable dataset (e.g. PhonemeDataset or a ``Subset`` of it)
            whose items are dicts with ``'valid'`` and ``'labels'`` tensors.
        num_classes: Number of classes, i.e. the length of the result.

    Returns:
        Float tensor of shape (num_classes,). It is all ones if no valid
        labels were found.

    Raises:
        ValueError: If any label lies outside ``[0, num_classes)``.
    """
    label_chunks = []
    for i in range(len(dataset)):
        sample = dataset[i]
        if sample['valid'].item():
            label_chunks.append(sample['labels'].reshape(-1).cpu().numpy())

    weights = torch.ones(num_classes)
    if not label_chunks:
        return weights

    all_labels = np.concatenate(label_chunks).astype(np.int64)
    if all_labels.size == 0:
        return weights

    # Guard explicitly: a negative id would otherwise silently index from the end.
    if all_labels.min() < 0 or all_labels.max() >= num_classes:
        raise ValueError(
            f"Labels must be in [0, {num_classes}), got range "
            f"[{all_labels.min()}, {all_labels.max()}]"
        )

    counts = np.bincount(all_labels, minlength=num_classes)
    total = all_labels.size
    for cls in np.flatnonzero(counts):
        weights[int(cls)] = total / (num_classes * counts[cls])

    logger.info(f"Class weights: {weights.tolist()}")
    return weights
|
|
|
|
|
|
|
|
def train_epoch(
    model: nn.Module,
    dataloader: DataLoader,
    optimizer: optim.Optimizer,
    criterion: nn.Module,
    device: torch.device,
    epoch: int,
    max_grad_norm: float = 1.0,
) -> Dict[str, float]:
    """Train the classifier head for one epoch.

    Batch features are flattened to individual frames and fed through
    ``model.classifier_head.shared_layers`` followed by
    ``model.classifier_head.full_head``. If a batch carries a ``'mask'``
    (True on real frames), padded frames are excluded from the loss and the
    accuracy, and batches without any real frame are skipped.

    Args:
        model: Model exposing ``classifier_head`` with ``shared_layers`` and
            ``full_head`` submodules.
        dataloader: Yields dicts with ``'features'`` (batch, seq, dim),
            ``'labels'`` (batch, seq) and optionally ``'mask'`` (batch, seq).
        optimizer: Optimizer over the classifier-head parameters.
        criterion: Frame-level classification loss.
        device: Device the batches are moved to.
        epoch: Epoch number, used only for logging.
        max_grad_norm: Gradient-norm clipping threshold for the head.

    Returns:
        ``{'loss': mean loss per processed batch, 'accuracy': frame accuracy}``.
        Both values are NaN when no batch contained a real frame.
    """
    # Only the head is trained. Keep every other submodule (the frozen
    # encoder) in eval mode so its dropout stays off; the dataset presumably
    # runs that encoder while this loader is iterated.
    model.eval()
    model.classifier_head.train()

    total_loss = 0.0
    num_batches = 0
    pred_chunks = []
    label_chunks = []

    for batch_idx, batch in enumerate(dataloader):
        features = batch['features'].to(device)
        labels = batch['labels'].to(device)

        # (batch, seq, dim) -> (batch*seq, dim): the head classifies frames independently.
        features_flat = features.reshape(-1, features.shape[-1])
        labels_flat = labels.reshape(-1)

        # Padded frames carry label 0, a real class, so drop them when a mask is supplied.
        mask = batch.get('mask')
        if mask is not None:
            mask_flat = mask.to(device).reshape(-1)
            if not mask_flat.any():
                continue
            features_flat = features_flat[mask_flat]
            labels_flat = labels_flat[mask_flat]

        optimizer.zero_grad()

        shared_features = model.classifier_head.shared_layers(features_flat)
        logits = model.classifier_head.full_head(shared_features)
        loss = criterion(logits, labels_flat)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.classifier_head.parameters(), max_norm=max_grad_norm)
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1
        pred_chunks.append(torch.argmax(logits, dim=-1).cpu().numpy())
        label_chunks.append(labels_flat.cpu().numpy())

        if batch_idx % 10 == 0:
            logger.info(f"Epoch {epoch}, Batch {batch_idx}/{len(dataloader)}, Loss: {loss.item():.4f}")

    if num_batches == 0:
        logger.warning(f"Epoch {epoch}: no training batch contained valid frames")
        return {'loss': float('nan'), 'accuracy': float('nan')}

    all_preds = np.concatenate(pred_chunks)
    all_labels = np.concatenate(label_chunks)

    return {
        'loss': total_loss / num_batches,
        'accuracy': float(accuracy_score(all_labels, all_preds)),
    }
|
|
|
|
|
|
|
|
def validate(
    model: nn.Module,
    dataloader: DataLoader,
    criterion: nn.Module,
    device: torch.device,
    num_classes: int = 8,
) -> Dict[str, Any]:
    """Validate model.

    The model is run in eval mode without gradients, using the same
    frame-level head path as ``train_epoch``. If a batch carries a ``'mask'``
    (True on real frames), padded frames are excluded, and batches without
    any real frame are skipped.

    Args:
        model: Model exposing ``classifier_head.shared_layers`` and
            ``classifier_head.full_head``.
        dataloader: Yields dicts with ``'features'``, ``'labels'`` and
            optionally ``'mask'``.
        criterion: Frame-level classification loss.
        device: Device the batches are moved to.
        num_classes: Number of classes; sets the confusion-matrix size.

    Returns:
        Dict with ``'loss'`` (mean per processed batch), ``'accuracy'``, and
        weighted ``'f1_score'``, ``'precision'`` and ``'recall'``. It also
        holds ``'confusion_matrix'`` as a num_classes x num_classes nested
        list. Scalar metrics are NaN and the matrix is all zeros when no
        batch contained a real frame.
    """
    model.eval()

    total_loss = 0.0
    num_batches = 0
    pred_chunks = []
    label_chunks = []

    with torch.no_grad():
        for batch in dataloader:
            features = batch['features'].to(device)
            labels = batch['labels'].to(device)

            features_flat = features.reshape(-1, features.shape[-1])
            labels_flat = labels.reshape(-1)

            # Padded frames carry label 0, a real class, so drop them when a mask is supplied.
            mask = batch.get('mask')
            if mask is not None:
                mask_flat = mask.to(device).reshape(-1)
                if not mask_flat.any():
                    continue
                features_flat = features_flat[mask_flat]
                labels_flat = labels_flat[mask_flat]

            shared_features = model.classifier_head.shared_layers(features_flat)
            logits = model.classifier_head.full_head(shared_features)

            total_loss += criterion(logits, labels_flat).item()
            num_batches += 1

            pred_chunks.append(torch.argmax(logits, dim=-1).cpu().numpy())
            label_chunks.append(labels_flat.cpu().numpy())

    if num_batches == 0:
        nan = float('nan')
        return {
            'loss': nan,
            'accuracy': nan,
            'f1_score': nan,
            'precision': nan,
            'recall': nan,
            'confusion_matrix': [[0] * num_classes for _ in range(num_classes)],
        }

    all_preds = np.concatenate(pred_chunks)
    all_labels = np.concatenate(label_chunks)

    cm = confusion_matrix(all_labels, all_preds, labels=list(range(num_classes)))

    return {
        'loss': total_loss / num_batches,
        'accuracy': float(accuracy_score(all_labels, all_preds)),
        'f1_score': float(f1_score(all_labels, all_preds, average='weighted', zero_division=0)),
        'precision': float(precision_score(all_labels, all_preds, average='weighted', zero_division=0)),
        'recall': float(recall_score(all_labels, all_preds, average='weighted', zero_division=0)),
        'confusion_matrix': cm.tolist(),
    }
|
|
|
|
|
|
|
|
def _save_checkpoint(
    path: Path,
    epoch: int,
    model: nn.Module,
    optimizer: optim.Optimizer,
    scheduler: Any,
    val_metrics: Dict[str, Any],
    best_val_loss: float,
    config: Dict[str, Any],
) -> None:
    """Save classifier-head, optimizer and scheduler state to ``path``.

    Only ``model.classifier_head`` is stored because the encoder is frozen.
    Metrics are cast to plain floats so the file holds no NumPy scalars.
    """
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.classifier_head.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'val_loss': float(val_metrics['loss']),
        'val_accuracy': float(val_metrics['accuracy']),
        'best_val_loss': float(best_val_loss),
        'config': config,
    }, path)


def _load_checkpoint(
    path: Path,
    model: nn.Module,
    optimizer: optim.Optimizer,
    scheduler: Any,
    device: torch.device,
) -> Tuple[int, float]:
    """Restore a checkpoint written by this script.

    Returns:
        ``(next_epoch_index, best_val_loss)``. Older checkpoints without
        scheduler state or ``best_val_loss`` are accepted; the latter falls
        back to the stored ``val_loss``.

    Note: ``torch.load`` unpickles the file, so only resume from trusted
    checkpoints.
    """
    checkpoint = torch.load(path, map_location=device)
    model.classifier_head.load_state_dict(checkpoint['model_state_dict'])
    if 'optimizer_state_dict' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    if 'scheduler_state_dict' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    next_epoch = int(checkpoint.get('epoch', -1)) + 1
    best_val_loss = float(checkpoint.get('best_val_loss', checkpoint.get('val_loss', float('inf'))))
    return next_epoch, best_val_loss


def main():
    """Train the classifier head as configured by ``--config``.

    Steps:
      1. Load the YAML config and the exported JSON training set.
      2. Split the data into train/validation subsets.
      3. Freeze the Wav2Vec2 encoder.
      4. Train ``model.classifier_head`` with class-weighted cross-entropy and
         ReduceLROnPlateau, using optional early stopping.
      5. Write best and periodic checkpoints.

    With ``--resume``, the classifier-head and optimizer state (and scheduler
    state, if present) are restored from a checkpoint written by this script.
    Training then continues from the following epoch.
    """
    parser = argparse.ArgumentParser(description="Train classifier head")
    parser.add_argument('--config', type=str, default='training/config.yaml',
                        help='Path to config file')
    parser.add_argument('--resume', type=str, default=None,
                        help='Resume from checkpoint')
    args = parser.parse_args()

    with open(args.config, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    device = torch.device('cuda' if torch.cuda.is_available() and config['device']['use_cuda'] else 'cpu')
    logger.info(f"Using device: {device}")

    training_file = Path(config['data']['training_dataset'])
    if not training_file.exists():
        logger.error(f"Training dataset not found: {training_file}")
        logger.info("Run scripts/annotation_helper.py to export training data first")
        return

    with open(training_file, 'r', encoding='utf-8') as f:
        training_data = json.load(f)
    logger.info(f"Loaded {len(training_data)} training samples")

    inference_pipeline = InferencePipeline(
        audio_config=default_audio_config,
        model_config=default_model_config,
        inference_config=default_inference_config
    )
    phoneme_mapper = PhonemeMapper(
        frame_duration_ms=20,
        sample_rate=16000
    )
    dataset = PhonemeDataset(training_data, inference_pipeline, phoneme_mapper)

    train_size = int(config['data']['train_split'] * len(dataset))
    val_size = len(dataset) - train_size
    # Both splits need at least one sample: otherwise there is nothing to
    # train on or nothing to monitor for LR scheduling / early stopping.
    if train_size == 0 or val_size == 0:
        logger.error(
            f"Cannot split {len(dataset)} samples into non-empty train/val sets "
            f"(train_split={config['data']['train_split']})"
        )
        return

    train_dataset, val_dataset = random_split(
        dataset,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(config['data']['random_seed'])
    )
    logger.info(f"Train samples: {len(train_dataset)}, Val samples: {len(val_dataset)}")

    train_loader = DataLoader(
        train_dataset,
        batch_size=config['training']['batch_size'],
        shuffle=True,
        collate_fn=collate_fn
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=config['training']['batch_size'],
        shuffle=False,
        collate_fn=collate_fn
    )

    model = inference_pipeline.model

    # Freeze the encoder; only the classifier head receives gradients.
    for param in model.wav2vec2_model.parameters():
        param.requires_grad = False
    for param in model.classifier_head.parameters():
        param.requires_grad = True

    # The datasets extract features through inference_pipeline, whose model is
    # this same object, so presumably its encoder runs while the loaders are
    # iterated. Keep everything except the head in eval mode (encoder dropout
    # off) instead of calling model.train() on the whole model.
    model.eval()
    model.classifier_head.train()
    logger.info("Model prepared: Wav2Vec2 frozen, classifier head trainable")

    # Compute weights from the training split only, so validation labels do
    # not shape the loss.
    class_weights = calculate_class_weights(train_dataset).to(device)

    loss_type = config['training']['loss']['type']
    if loss_type != 'cross_entropy':
        logger.warning(f"Unsupported loss type '{loss_type}'; falling back to cross_entropy")
    criterion = nn.CrossEntropyLoss(weight=class_weights)

    optimizer = optim.Adam(
        model.classifier_head.parameters(),
        lr=config['training']['learning_rate'],
        weight_decay=config['training']['weight_decay']
    )

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=config['training']['scheduler_factor'],
        patience=config['training']['scheduler_patience'],
        min_lr=config['training']['scheduler_min_lr']
    )

    best_val_loss = float('inf')
    patience_counter = 0
    start_epoch = 0

    checkpoint_dir = Path(config['checkpoint']['save_dir'])
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    if args.resume:
        start_epoch, best_val_loss = _load_checkpoint(
            Path(args.resume), model, optimizer, scheduler, device
        )
        logger.info(f"Resumed from {args.resume}; continuing at epoch {start_epoch + 1}")

    num_epochs = config['training']['num_epochs']
    for epoch in range(start_epoch, num_epochs):
        logger.info(f"\n{'='*50}")
        logger.info(f"Epoch {epoch+1}/{num_epochs}")
        logger.info(f"{'='*50}")

        train_metrics = train_epoch(model, train_loader, optimizer, criterion, device, epoch+1)
        logger.info(f"Train - Loss: {train_metrics['loss']:.4f}, Accuracy: {train_metrics['accuracy']:.4f}")

        val_metrics = validate(model, val_loader, criterion, device)
        logger.info(f"Val - Loss: {val_metrics['loss']:.4f}, Accuracy: {val_metrics['accuracy']:.4f}, "
                    f"F1: {val_metrics['f1_score']:.4f}")

        scheduler.step(val_metrics['loss'])

        # Track improvement independently of save_best; otherwise disabling
        # save_best would make every epoch count against early stopping.
        if val_metrics['loss'] < best_val_loss:
            best_val_loss = val_metrics['loss']
            patience_counter = 0
            if config['checkpoint']['save_best']:
                checkpoint_path = checkpoint_dir / config['checkpoint']['best_filename']
                _save_checkpoint(checkpoint_path, epoch, model, optimizer, scheduler,
                                 val_metrics, best_val_loss, config)
                logger.info(f"✅ Saved best checkpoint to {checkpoint_path}")
        else:
            patience_counter += 1

        # Save the periodic checkpoint before the early-stopping check so the
        # final epoch is not lost when training stops early.
        if config['checkpoint']['save_last'] and (epoch + 1) % config['checkpoint']['save_frequency'] == 0:
            checkpoint_path = checkpoint_dir / config['checkpoint']['filename']
            _save_checkpoint(checkpoint_path, epoch, model, optimizer, scheduler,
                             val_metrics, best_val_loss, config)
            logger.info(f"Saved checkpoint to {checkpoint_path}")

        if config['training']['early_stopping']['enabled']:
            if patience_counter >= config['training']['early_stopping']['patience']:
                logger.info(f"Early stopping triggered after {epoch+1} epochs")
                break

    logger.info("\n✅ Training complete!")
    logger.info(f"Best validation loss: {best_val_loss:.4f}")
|
|
|
|
|
|
|
|
# Script entry point: run training only when executed directly, not when imported.
if __name__ == "__main__":
    main()
|
|
|
|
|
|