"""
scripts/train_cough_agent.py — Train binary Cough classifier on COUGHVID.

Labels: Healthy (0) vs Symptomatic (1) [COVID-19 merged into Symptomatic]
Input:  OPERA-CT 768-dim embeddings
Model:  CoughMLP 768 -> 256 -> 64 -> 2 with BatchNorm + Dropout
Output: saved_models/cough_opera_mlp.pt
        outputs/cough_confusion_matrix.png
        outputs/cough_roc_curve.png
        outputs/results_cough.json
"""
import os
import sys
import json

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    f1_score, recall_score, precision_score, roc_auc_score,
    accuracy_score, confusion_matrix, roc_curve,
)
import matplotlib
matplotlib.use('Agg')  # headless backend: figures are only written to disk
import matplotlib.pyplot as plt
import seaborn as sns

# Make the project root (parent of scripts/) importable. abspath() matters:
# when launched as `python scripts/train_cough_agent.py`, __file__ is relative
# and the double dirname() would otherwise collapse to ''.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# NOTE: data and output paths below are relative to the current working
# directory, so the script is expected to be run from the project root.
os.makedirs('outputs', exist_ok=True)
os.makedirs('saved_models', exist_ok=True)

RANDOM_STATE = 42
EPOCHS = 60
BATCH_SIZE = 64
LR = 1e-3
DROPOUT = 0.3
# Candidate decision thresholds on P(symptomatic): 0.30, 0.31, ..., 0.79.
# Built from integers so each value is the exact 2-decimal float (no
# accumulated error such as 0.35000000000000003 leaking into the checkpoint).
THRESHOLDS = np.arange(30, 80) / 100

torch.manual_seed(RANDOM_STATE)

# ── Device ───────────────────────────────────────────────────────────────────
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

# ── Data ─────────────────────────────────────────────────────────────────────
df = pd.read_csv('data/cough_labels_with_embeddings.csv').dropna(subset=['embedding_path'])
print(f"Total samples: {len(df)}")
print(df['label_str'].value_counts())


def load_embedding(path):
    """Load one pre-computed embedding from a ``.npy`` file as float32."""
    return np.load(path).astype(np.float32)


X = np.stack([load_embedding(p) for p in df['embedding_path']])
y = df['label'].values.astype(np.int64)
# Feature width comes from the data rather than a hard-coded 768, so the
# model and the saved config stay correct if the embedding extractor changes.
input_dim = X.shape[1]

# Train / val / test split (70/15/15) — stratified
X_tv, X_test, y_tv, y_test = train_test_split(
    X, y, test_size=0.15, stratify=y, random_state=RANDOM_STATE)
# 0.15 / 0.85 of the remaining 85% is 15% of the full dataset.
X_train, X_val, y_train, y_val = train_test_split(
    X_tv, y_tv, test_size=0.15 / 0.85, stratify=y_tv, random_state=RANDOM_STATE)
print(f"Train: {len(X_train)} | Val: {len(X_val)} | Test: {len(X_test)}")
print(f"Test positives (symptomatic): {y_test.sum()}")


# ── Dataset ──────────────────────────────────────────────────────────────────
class EmbeddingDataset(Dataset):
    """In-memory dataset of (float32 embedding, int64 label) tensor pairs."""

    def __init__(self, X, y):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.long)

    def __len__(self):
        return len(self.y)

    def __getitem__(self, i):
        return self.X[i], self.y[i]


train_ds = EmbeddingDataset(X_train, y_train)
val_ds = EmbeddingDataset(X_val, y_val)
test_ds = EmbeddingDataset(X_test, y_test)

# Class imbalance is handled by sampling: each training example is drawn with
# probability inversely proportional to its class count, so batches are
# approximately class-balanced.
counts = np.bincount(y_train)
weights = 1.0 / counts[y_train]
sampler = WeightedRandomSampler(weights, len(weights), replacement=True)

# BatchNorm1d raises in train mode on a batch holding a single sample, which
# would be the final batch whenever len(train_ds) % BATCH_SIZE == 1. Drop the
# last batch only in that case so small datasets still train.
train_loader = DataLoader(
    train_ds, batch_size=BATCH_SIZE, sampler=sampler,
    drop_last=(len(train_ds) % BATCH_SIZE == 1))
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)


# ── Model ────────────────────────────────────────────────────────────────────
class CoughMLP(nn.Module):
    """MLP head: input_dim -> 256 -> 64 -> 2 logits, BatchNorm + ReLU + Dropout."""

    def __init__(self, input_dim=768, dropout=0.3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, 256), nn.BatchNorm1d(256), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(256, 64), nn.BatchNorm1d(64), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(64, 2),
        )

    def forward(self, x):
        return self.net(x)


model = CoughMLP(input_dim=input_dim, dropout=DROPOUT).to(device)

# Unweighted cross-entropy: the WeightedRandomSampler above already presents
# both classes in roughly equal proportion. Up-weighting the symptomatic class
# in the loss as well would correct for the imbalance a second time and bias
# the model towards positive predictions.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)


def predict(model, loader):
    """Run ``model`` over ``loader`` in eval mode, without gradients.

    Returns three 1-D numpy arrays in loader order: the softmax probability
    of class 1 (Symptomatic), the argmax class predictions, and the labels.
    """
    model.eval()
    probs, preds, labels = [], [], []
    with torch.no_grad():
        for xb, yb in loader:
            logits = model(xb.to(device))
            probs.append(torch.softmax(logits, dim=1)[:, 1].cpu().numpy())
            preds.append(logits.argmax(dim=1).cpu().numpy())
            labels.append(yb.numpy())
    return np.concatenate(probs), np.concatenate(preds), np.concatenate(labels)


# ── Training ─────────────────────────────────────────────────────────────────
best_val_f1 = -1.0   # below any attainable macro-F1, so epoch 1 always records a state
best_state = None
patience = 12        # epochs without val-F1 improvement before stopping early
no_improve = 0

print("\nTraining...")
for epoch in range(1, EPOCHS + 1):
    model.train()
    for xb, yb in train_loader:
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad()
        loss = criterion(model(xb), yb)
        loss.backward()
        optimizer.step()
    scheduler.step()  # cosine schedule is stepped once per epoch (T_max=EPOCHS)

    # Model selection uses argmax predictions; the threshold is tuned later.
    _, val_preds, val_true = predict(model, val_loader)
    val_f1 = f1_score(val_true, val_preds, average='macro', zero_division=0)
    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        # clone(): state_dict() returns live references to the parameters.
        best_state = {k: v.clone() for k, v in model.state_dict().items()}
        no_improve = 0
    else:
        no_improve += 1

    if epoch % 10 == 0:
        print(f"  Epoch {epoch:3d} | Val F1: {val_f1:.4f} | Best: {best_val_f1:.4f}")
    if no_improve >= patience:
        print(f"  Early stop at epoch {epoch}")
        break

# ── Threshold tuning on validation set ───────────────────────────────────────
if best_state is not None:  # None only if no epoch ran (EPOCHS == 0)
    model.load_state_dict(best_state)

val_probs, _, val_labels = predict(model, val_loader)

best_thresh, best_f1 = 0.5, 0.0
for t in THRESHOLDS:
    preds = (val_probs >= t).astype(int)
    f1 = f1_score(val_labels, preds, average='macro', zero_division=0)
    if f1 > best_f1:  # strict '>' keeps the lowest threshold on ties
        best_f1, best_thresh = f1, float(t)
print(f"\nBest threshold: {best_thresh:.2f}  (Val Macro F1: {best_f1:.4f})")

# ── Test evaluation ──────────────────────────────────────────────────────────
test_probs, _, test_labels = predict(model, test_loader)
test_preds = (test_probs >= best_thresh).astype(int)

# float(): plain Python floats serialize cleanly to JSON and the checkpoint.
acc = float(accuracy_score(test_labels, test_preds))
f1 = float(f1_score(test_labels, test_preds, average='macro', zero_division=0))
rec = float(recall_score(test_labels, test_preds, pos_label=1, zero_division=0))
prec = float(precision_score(test_labels, test_preds, pos_label=1, zero_division=0))
auc = float(roc_auc_score(test_labels, test_probs))
cm = confusion_matrix(test_labels, test_preds)

print("\nTest Results:")
print(f"  Accuracy : {acc:.4f}")
print(f"  Macro F1 : {f1:.4f}")
print(f"  Recall   : {rec:.4f}")
print(f"  Precision: {prec:.4f}")
print(f"  AUROC    : {auc:.4f}")
print(f"  Confusion Matrix:\n{cm}")

# ── Save model ───────────────────────────────────────────────────────────────
torch.save({
    'state_dict': model.state_dict(),
    # Python floats rather than numpy scalars, so the checkpoint can be read
    # with torch.load(..., weights_only=True) (the default since PyTorch 2.6).
    'threshold': float(best_thresh),
    'val_f1': float(best_val_f1),
    'config': {'input_dim': input_dim, 'dropout': DROPOUT},
}, 'saved_models/cough_opera_mlp.pt')
print("\nSaved: saved_models/cough_opera_mlp.pt")

# ── Confusion matrix plot ────────────────────────────────────────────────────
fig, ax = plt.subplots(figsize=(5, 4))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=['Healthy', 'Symptomatic'],
            yticklabels=['Healthy', 'Symptomatic'], ax=ax)
ax.set_xlabel('Predicted')
ax.set_ylabel('Actual')
ax.set_title(f'Cough Agent — Confusion Matrix\nMacro F1={f1:.3f} | AUROC={auc:.3f}')
plt.tight_layout()
fig.savefig('outputs/cough_confusion_matrix.png', dpi=150)
plt.close(fig)
print("Saved: outputs/cough_confusion_matrix.png")

# ── ROC curve ────────────────────────────────────────────────────────────────
fpr, tpr, _ = roc_curve(test_labels, test_probs)
fig, ax = plt.subplots(figsize=(5, 4))
ax.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC (AUC = {auc:.3f})')
ax.plot([0, 1], [0, 1], 'k--', lw=1)  # chance diagonal
ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
ax.set_title('Cough Agent — ROC Curve')
ax.legend(loc='lower right')
plt.tight_layout()
fig.savefig('outputs/cough_roc_curve.png', dpi=150)
plt.close(fig)
print("Saved: outputs/cough_roc_curve.png")

# ── Save results JSON ────────────────────────────────────────────────────────
results = {
    'model': 'CoughMLP (OPERA-CT)',
    'task': 'Cough Classification (Healthy vs Symptomatic)',
    'accuracy': round(acc, 4),
    'f1_macro': round(f1, 4),
    'recall': round(rec, 4),
    'precision': round(prec, 4),
    'auroc': round(auc, 4),
    'threshold': round(best_thresh, 4),
    'train_size': len(X_train),
    'val_size': len(X_val),
    'test_size': len(X_test),
}
with open('outputs/results_cough.json', 'w') as f:
    json.dump(results, f, indent=2)
print("Saved: outputs/results_cough.json")