Spaces:
Sleeping
Sleeping
| """ | |
| scripts/train_cough_agent.py — Train binary Cough classifier on COUGHVID. | |
| Labels: Healthy (0) vs Symptomatic (1) [COVID-19 merged into Symptomatic] | |
| Input: OPERA-CT 768-dim embeddings | |
| Model: CoughMLP (binary MLP) 768 -> 256 -> 64 -> 2 with BatchNorm + Dropout | |
| Output: saved_models/cough_opera_mlp.pt | |
| outputs/cough_confusion_matrix.png | |
| outputs/cough_roc_curve.png | |
| outputs/results_cough.json | |
| """ | |
| import os | |
| import sys | |
| import json | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| import torch.nn as nn | |
| from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler | |
| from sklearn.model_selection import train_test_split | |
| from sklearn.metrics import ( | |
| f1_score, recall_score, precision_score, | |
| roc_auc_score, accuracy_score, confusion_matrix, roc_curve | |
| ) | |
| import matplotlib | |
| matplotlib.use('Agg') | |
| import matplotlib.pyplot as plt | |
| import seaborn as sns | |
# Put the parent of this script's directory on sys.path (the module docstring
# places this file under scripts/). abspath keeps the result correct even when
# __file__ is a relative path.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Output directories are created relative to the current working directory.
os.makedirs('outputs', exist_ok=True)
os.makedirs('saved_models', exist_ok=True)

RANDOM_STATE = 42
EPOCHS = 60
BATCH_SIZE = 64
LR = 1e-3
DROPOUT = 0.3

torch.manual_seed(RANDOM_STATE)

# ── Device ───────────────────────────────────────────────────────────────────
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

# ── Data ─────────────────────────────────────────────────────────────────────
# Rows without an embedding path cannot be loaded, so they are dropped here.
df = pd.read_csv('data/cough_labels_with_embeddings.csv').dropna(subset=['embedding_path'])
print(f"Total samples: {len(df)}")
print(df['label_str'].value_counts())
def load_embedding(path):
    """Load one embedding stored as a ``.npy`` file and return it as float32.

    Args:
        path: Location of the ``.npy`` file (anything ``np.load`` accepts).

    Returns:
        The stored array converted to ``np.float32``.
    """
    return np.load(path).astype(np.float32)
# Stack the per-sample embeddings along a new leading axis; np.stack raises
# ValueError if the files do not all share one shape.
X = np.stack([load_embedding(p) for p in df['embedding_path']])
y = df['label'].values.astype(np.int64)

# Train / val / test split (70/15/15) — stratified.
# The second split takes 0.15/0.85 of the remaining 85%, i.e. 15% of the total.
X_tv, X_test, y_tv, y_test = train_test_split(
    X, y, test_size=0.15, stratify=y, random_state=RANDOM_STATE)
X_train, X_val, y_train, y_val = train_test_split(
    X_tv, y_tv, test_size=0.15 / 0.85, stratify=y_tv, random_state=RANDOM_STATE)

print(f"Train: {len(X_train)} | Val: {len(X_val)} | Test: {len(X_test)}")
print(f"Test positives (symptomatic): {y_test.sum()}")
# ── Dataset ──────────────────────────────────────────────────────────────────
class EmbeddingDataset(Dataset):
    """In-memory dataset of (embedding, label) pairs.

    Args:
        X: Array-like of embeddings, one row per sample; stored as float32.
        y: Array-like of integer class labels; stored as int64 (``torch.long``).

    Raises:
        ValueError: If ``X`` and ``y`` do not contain the same number of samples.
    """

    def __init__(self, X, y):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.long)
        # __len__ is driven by y, so a length mismatch would otherwise either
        # silently ignore trailing rows of X or raise IndexError mid-epoch.
        if len(self.X) != len(self.y):
            raise ValueError(
                f"X and y must have the same length, got {len(self.X)} and {len(self.y)}"
            )

    def __len__(self):
        return len(self.y)

    def __getitem__(self, i):
        return self.X[i], self.y[i]
train_ds = EmbeddingDataset(X_train, y_train)
val_ds = EmbeddingDataset(X_val, y_val)
test_ds = EmbeddingDataset(X_test, y_test)

# Weighted sampler for class imbalance: each sample is drawn with probability
# proportional to the inverse frequency of its class in the training split.
counts = np.bincount(y_train)
weights = 1.0 / counts[y_train]
sampler = WeightedRandomSampler(weights, len(weights), replacement=True)

# The model uses BatchNorm1d, which raises in training mode when a batch holds
# a single sample. The sampler yields len(train_ds) samples per epoch, so the
# trailing batch is dropped only when it would contain exactly one sample.
drop_singleton = len(train_ds) > 1 and len(train_ds) % BATCH_SIZE == 1
train_loader = DataLoader(
    train_ds, batch_size=BATCH_SIZE, sampler=sampler, drop_last=drop_singleton)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False)
# ── Model ────────────────────────────────────────────────────────────────────
class CoughMLP(nn.Module):
    """MLP classifier over fixed-size embeddings.

    Each hidden stage is Linear -> BatchNorm1d -> ReLU -> Dropout, followed by
    a final Linear layer that emits one logit per class. With the defaults the
    layer layout (and therefore the ``state_dict`` keys ``net.0`` … ``net.8``)
    is 768 -> 256 -> 64 -> 2.

    Args:
        input_dim: Size of each input embedding.
        dropout: Dropout probability applied after every hidden stage.
        hidden_dims: Widths of the hidden stages, in order.
        num_classes: Number of output logits.
    """

    def __init__(self, input_dim=768, dropout=0.3, hidden_dims=(256, 64), num_classes=2):
        super().__init__()
        layers = []
        width = input_dim
        for hidden in hidden_dims:
            layers.extend([
                nn.Linear(width, hidden),
                nn.BatchNorm1d(hidden),
                nn.ReLU(),
                nn.Dropout(dropout),
            ])
            width = hidden
        layers.append(nn.Linear(width, num_classes))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return raw (unnormalised) class logits for a batch ``x``."""
        return self.net(x)
model = CoughMLP(dropout=DROPOUT).to(device)

# Class-weighted loss: class 0 keeps weight 1.0, class 1 is weighted by
# n_train / (2 * n_positive), where y_train.sum() counts the label-1 samples.
# NOTE(review): the training loader already rebalances classes with a
# WeightedRandomSampler, so this weight compounds that correction — confirm
# the double adjustment is intended.
class_weights = torch.tensor(
    [1.0, len(y_train) / (2 * y_train.sum())], dtype=torch.float32
).to(device)
criterion = nn.CrossEntropyLoss(weight=class_weights)

optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)
# One cosine cycle spanning the full epoch budget.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
# ── Training ─────────────────────────────────────────────────────────────────
best_val_f1 = 0.0
best_state = None
patience = 12      # epochs without validation improvement before stopping
no_improve = 0

print("\nTraining...")
for epoch in range(1, EPOCHS + 1):
    model.train()
    for xb, yb in train_loader:
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad()
        loss = criterion(model(xb), yb)
        loss.backward()
        optimizer.step()
    scheduler.step()

    # Validation: macro F1 of the argmax predictions on the validation split.
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for xb, yb in val_loader:
            logits = model(xb.to(device))
            all_preds.extend(logits.argmax(dim=1).cpu().numpy())
            all_labels.extend(yb.numpy())
    val_f1 = f1_score(all_labels, all_preds, average='macro', zero_division=0)

    # Always checkpoint the first evaluated epoch so best_state is never None
    # after training, even if validation F1 never rises above 0.0.
    if best_state is None or val_f1 > best_val_f1:
        best_val_f1 = val_f1
        # Clone the tensors: state_dict() returns references that later
        # optimizer steps would otherwise overwrite.
        best_state = {k: v.clone() for k, v in model.state_dict().items()}
        no_improve = 0
    else:
        no_improve += 1

    if epoch % 10 == 0:
        print(f" Epoch {epoch:3d} | Val F1: {val_f1:.4f} | Best: {best_val_f1:.4f}")
    if no_improve >= patience:
        print(f" Early stop at epoch {epoch}")
        break
# ── Threshold tuning on validation set ───────────────────────────────────────
# best_state is still None when no epoch recorded a checkpoint; in that case
# keep the model's current weights instead of crashing in load_state_dict.
if best_state is not None:
    model.load_state_dict(best_state)
model.eval()

val_probs, val_labels = [], []
with torch.no_grad():
    for xb, yb in val_loader:
        # Probability of class 1 (symptomatic) for every validation sample.
        probs = torch.softmax(model(xb.to(device)), dim=1)[:, 1].cpu().numpy()
        val_probs.extend(probs)
        val_labels.extend(yb.numpy())
val_probs = np.array(val_probs)
val_labels = np.array(val_labels)

# Sweep thresholds 0.30, 0.31, …, 0.79. Building the grid from integers avoids
# the drifting values and ambiguous length of np.arange with a float step.
best_thresh, best_f1 = 0.5, 0.0
for t in np.arange(30, 80) / 100.0:
    preds = (val_probs >= t).astype(int)
    f1 = f1_score(val_labels, preds, average='macro', zero_division=0)
    # Strict ">" keeps the lowest threshold among ties.
    if f1 > best_f1:
        best_f1, best_thresh = f1, float(t)
print(f"\nBest threshold: {best_thresh:.2f} (Val Macro F1: {best_f1:.4f})")
# ── Test evaluation ──────────────────────────────────────────────────────────
# Make sure BatchNorm/Dropout are in inference mode before scoring.
model.eval()
test_probs, test_labels = [], []
with torch.no_grad():
    for xb, yb in test_loader:
        # Probability of class 1 (symptomatic) for every test sample.
        probs = torch.softmax(model(xb.to(device)), dim=1)[:, 1].cpu().numpy()
        test_probs.extend(probs)
        test_labels.extend(yb.numpy())
test_probs = np.array(test_probs)
test_labels = np.array(test_labels)

# Hard predictions use the threshold tuned on the validation split.
test_preds = (test_probs >= best_thresh).astype(int)

acc = accuracy_score(test_labels, test_preds)
f1 = f1_score(test_labels, test_preds, average='macro', zero_division=0)
rec = recall_score(test_labels, test_preds, pos_label=1, zero_division=0)
prec = precision_score(test_labels, test_preds, pos_label=1, zero_division=0)
# AUROC is computed from the raw probabilities, independent of the threshold.
auc = roc_auc_score(test_labels, test_probs)
cm = confusion_matrix(test_labels, test_preds)

print("\nTest Results:")
print(f" Accuracy : {acc:.4f}")
print(f" Macro F1 : {f1:.4f}")
print(f" Recall : {rec:.4f}")
print(f" Precision: {prec:.4f}")
print(f" AUROC : {auc:.4f}")
print(f" Confusion Matrix:\n{cm}")
# ── Save model ───────────────────────────────────────────────────────────────
torch.save({
    # CPU copies let the checkpoint load on machines without CUDA even when
    # the caller does not pass map_location.
    'state_dict': {k: v.cpu() for k, v in model.state_dict().items()},
    # Plain Python floats: numpy scalars are pickled as numpy objects, which
    # torch.load(weights_only=True) refuses to unpickle.
    'threshold': float(best_thresh),
    'val_f1': float(best_val_f1),
    'config': {'input_dim': 768, 'dropout': DROPOUT},
}, 'saved_models/cough_opera_mlp.pt')
print("\nSaved: saved_models/cough_opera_mlp.pt")
# ── Confusion matrix plot ────────────────────────────────────────────────────
fig, ax = plt.subplots(figsize=(5, 4))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=['Healthy', 'Symptomatic'],
            yticklabels=['Healthy', 'Symptomatic'], ax=ax)
ax.set_xlabel('Predicted')
ax.set_ylabel('Actual')
# The title separator is an em-dash (it was mis-encoded as "β").
ax.set_title(f'Cough Agent — Confusion Matrix\nMacro F1={f1:.3f} | AUROC={auc:.3f}')
fig.tight_layout()
fig.savefig('outputs/cough_confusion_matrix.png', dpi=150)
# Close this specific figure rather than whichever one is current.
plt.close(fig)
print("Saved: outputs/cough_confusion_matrix.png")
# ── ROC curve ────────────────────────────────────────────────────────────────
fpr, tpr, _ = roc_curve(test_labels, test_probs)
fig, ax = plt.subplots(figsize=(5, 4))
ax.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC (AUC = {auc:.3f})')
# Diagonal reference line: the ROC of a random classifier.
ax.plot([0, 1], [0, 1], 'k--', lw=1)
ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
# The title separator is an em-dash (it was mis-encoded as "β").
ax.set_title('Cough Agent — ROC Curve')
ax.legend(loc='lower right')
fig.tight_layout()
fig.savefig('outputs/cough_roc_curve.png', dpi=150)
# Close this specific figure rather than whichever one is current.
plt.close(fig)
print("Saved: outputs/cough_roc_curve.png")
# ── Save results JSON ────────────────────────────────────────────────────────
# Metrics are cast to builtin float before rounding so that json.dump never
# receives a numpy scalar type it cannot serialise.
results = {
    'model': 'CoughMLP (OPERA-CT)',
    'task': 'Cough Classification (Healthy vs Symptomatic)',
    'accuracy': round(float(acc), 4),
    'f1_macro': round(float(f1), 4),
    'recall': round(float(rec), 4),
    'precision': round(float(prec), 4),
    'auroc': round(float(auc), 4),
    'threshold': round(float(best_thresh), 4),
    'train_size': len(X_train),
    'val_size': len(X_val),
    'test_size': len(X_test),
}
with open('outputs/results_cough.json', 'w', encoding='utf-8') as f:
    json.dump(results, f, indent=2)
print("Saved: outputs/results_cough.json")