respitriage / scripts /train_binary_agent.py
SujalSha's picture
Upload folder using huggingface_hub
d0ace1e verified
"""
scripts/train_binary_agent.py — Train COPD or Pneumonia binary MLP classifier.
Change DISEASE at the top to switch between agents.
Run twice:
    DISEASE = 'COPD'      → saved_models/copd_opera_mlp.pt
    DISEASE = 'Pneumonia' → saved_models/pneumonia_opera_mlp.pt
Requirements: run scripts/extract_opera_embeddings.py first.
"""
import os
import sys
import json
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, WeightedRandomSampler
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
f1_score, recall_score, precision_score, roc_auc_score, accuracy_score
)
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
from models.mlp_classifier import BinaryMLPClassifier, FocalLoss
from models.embedding_dataset import EmbeddingDataset
# ══════════════════════════════════════════════════════════════════════════════
# CONFIG — set DISEASE to 'COPD' for the first agent, 'Pneumonia' for the second
# ══════════════════════════════════════════════════════════════════════════════
DISEASE = 'Pneumonia'

# Every artefact path is keyed by the lower-cased disease name.
_DISEASE_SLUG = DISEASE.lower()
CSV_PATH = f'data/{_DISEASE_SLUG}_binary_labels_with_embeddings.csv'
MODEL_SAVE_PATH = f'saved_models/{_DISEASE_SLUG}_opera_mlp.pt'
RESULTS_PATH = f'outputs/results_{_DISEASE_SLUG}.json'

# Network shape
INPUT_DIM = 768          # OPERA-CT HT-SAT output dimension
HIDDEN_DIMS = [256, 64]
DROPOUT = 0.3

# Optimisation schedule
BATCH_SIZE = 64
MAX_EPOCHS = 100
PATIENCE = 15            # early stopping: epochs tolerated without a better val F1
LR = 3e-4
WEIGHT_DECAY = 1e-4

# Evaluation / reproducibility
TARGET_RECALL = 0.80     # clinical safety: catch at least 80% of positive cases
RANDOM_STATE = 42        # seed for the train/val/test split
# ══════════════════════════════════════════════════════════════════════════════
# Train on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"[train] Disease: {DISEASE} | Device: {device}")

# ── Load and split dataset ──────────────────────────────────────────────────
# Rows without an embedding path are dropped before splitting.
df = pd.read_csv(CSV_PATH).dropna(subset=['embedding_path'])
print(f"[train] Total samples: {len(df)}")
print(f"[train] Label distribution:\n{df['label'].value_counts()}")

# Stratified 70 / 15 / 15 split: hold out 30 % first, then halve that hold-out
# into validation and test so each subset keeps the overall label ratio.
# NOTE(review): the split is per CSV row; if several rows can come from the
# same patient/recording this may leak between subsets — confirm upstream.
train_df, temp_df = train_test_split(
    df, test_size=0.30, stratify=df['label'], random_state=RANDOM_STATE
)
val_df, test_df = train_test_split(
    temp_df, test_size=0.50, stratify=temp_df['label'], random_state=RANDOM_STATE
)

# Persist each split next to the source CSV; the dataset objects built later
# in this script are constructed from these files.
for split_name, split_df in (('train', train_df), ('val', val_df), ('test', test_df)):
    split_df.to_csv(f'data/{DISEASE.lower()}_{split_name}_split.csv', index=False)

# The separator is a real em dash (the previous literal was mis-encoded).
print(f"[train] Split — Train: {len(train_df)} | Val: {len(val_df)} | Test: {len(test_df)}")
# ── Datasets ────────────────────────────────────────────────────────────────
# One dataset per split file; augmentation is switched on for training only.
_split_prefix = f'data/{DISEASE.lower()}'
train_dataset = EmbeddingDataset(f'{_split_prefix}_train_split.csv', augment=True)
val_dataset = EmbeddingDataset(f'{_split_prefix}_val_split.csv', augment=False)
test_dataset = EmbeddingDataset(f'{_split_prefix}_test_split.csv', augment=False)

# WeightedRandomSampler — every sample is weighted by the inverse of its class
# count, so label 0 and label 1 carry the same total sampling weight and are
# drawn (with replacement) about equally often.
# NOTE(review): assumes EmbeddingDataset keeps the CSV row order so that
# weights[i] lines up with train_dataset[i] — confirm.
train_labels = train_df['label'].values
is_positive = train_labels == 1
n_pos = is_positive.sum()
n_neg = (train_labels == 0).sum()
weights = np.where(is_positive, 1.0 / n_pos, 1.0 / n_neg).astype(np.float64)
sampler = WeightedRandomSampler(
    weights=weights,
    num_samples=len(weights),
    replacement=True,
)

# The sampler decides the training order; evaluation loaders keep file order.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, sampler=sampler)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
# ── Model, loss, optimiser ──────────────────────────────────────────────────
# Classifier head over the pre-computed embeddings, placed on the chosen device.
model = BinaryMLPClassifier(
    input_dim=INPUT_DIM,
    hidden_dims=HIDDEN_DIMS,
    dropout=DROPOUT,
).to(device)

# NOTE(review): the meaning of alpha depends on the project's FocalLoss; since
# training batches are already class-balanced by the sampler, confirm the two
# re-weightings are meant to stack.
criterion = FocalLoss(alpha=0.25, gamma=2.0)

# AdamW with a cosine learning-rate schedule whose period is MAX_EPOCHS steps.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=LR,
    weight_decay=WEIGHT_DECAY,
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=MAX_EPOCHS)
# ── Training loop ────────────────────────────────────────────────────────────
# Start below any reachable F1 so the very first epoch always records a
# checkpoint; starting at 0.0 with a strict '>' could leave best_model_state
# as None and crash when the weights are restored after training.
best_val_f1 = float('-inf')
patience_counter = 0
best_model_state = None

for epoch in range(MAX_EPOCHS):
    # One optimisation pass over the sampler-driven training batches.
    model.train()
    train_loss = 0.0
    for embeddings, labels_batch in train_loader:
        embeddings = embeddings.to(device)
        labels_batch = labels_batch.to(device)

        optimizer.zero_grad()
        logits = model(embeddings)
        loss = criterion(logits, labels_batch)
        loss.backward()
        # Cap the global gradient norm at 1.0 before the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        train_loss += loss.item()

    # The cosine schedule is defined over MAX_EPOCHS steps, so advance it
    # once per epoch rather than once per batch.
    scheduler.step()

    # Validate at the default 0.5 cut-off; column 1 of the softmax output is
    # used as the positive-class probability.
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for embeddings, labels_batch in val_loader:
            logits = model(embeddings.to(device))
            probs = torch.softmax(logits, dim=1)[:, 1]
            preds = (probs > 0.5).long()
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels_batch.numpy())

    val_f1 = f1_score(all_labels, all_preds, average='macro', zero_division=0)
    val_recall = recall_score(all_labels, all_preds, pos_label=1, zero_division=0)

    # Keep a copy of the weights whenever macro-F1 improves; otherwise count
    # the epoch towards early stopping.
    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        best_model_state = {k: v.clone() for k, v in model.state_dict().items()}
        patience_counter = 0
    else:
        patience_counter += 1

    if epoch % 10 == 0:
        print(f" Epoch {epoch:3d} | Loss: {train_loss/len(train_loader):.4f} "
              f"| Val F1: {val_f1:.4f} | Val Recall: {val_recall:.4f}")

    if patience_counter >= PATIENCE:
        print(f"[train] Early stopping at epoch {epoch}")
        break
# ── Threshold tuning on validation set ─────────────────────────────────────
# Restore the best validation checkpoint. If training never recorded one,
# fall back to the current weights instead of crashing on load_state_dict(None).
if best_model_state is None:
    print("[train] WARNING: no best checkpoint was recorded; using current weights")
    best_model_state = {k: v.clone() for k, v in model.state_dict().items()}
model.load_state_dict(best_model_state)
model.eval()
print(f"\n[train] Tuning threshold for recall >= {TARGET_RECALL}...")

# Collect positive-class probabilities (softmax column 1) for the whole
# validation set.
all_probs_val, all_labels_val = [], []
with torch.no_grad():
    for embeddings, labels_batch in val_loader:
        logits = model(embeddings.to(device))
        probs = torch.softmax(logits, dim=1)[:, 1]
        all_probs_val.extend(probs.cpu().numpy())
        all_labels_val.extend(labels_batch.numpy())
all_probs_val = np.array(all_probs_val)
all_labels_val = np.array(all_labels_val)

best_threshold = 0.5
best_threshold_f1 = 0.0
threshold_found = False

# Scan 0.20, 0.21, ..., 0.70. linspace fixes the number of candidates, whereas
# arange with a float step has a length and endpoint that depend on rounding.
for raw_thresh in np.linspace(0.20, 0.70, 51):
    # Plain Python float rounded to the grid: the stored threshold then matches
    # what is printed and is not a NumPy scalar.
    thresh = round(float(raw_thresh), 2)
    preds = (all_probs_val >= thresh).astype(int)
    rec = recall_score(all_labels_val, preds, pos_label=1, zero_division=0)
    f1 = f1_score(all_labels_val, preds, average='macro', zero_division=0)
    # Among thresholds that meet the recall target, keep the best macro-F1.
    if rec >= TARGET_RECALL and f1 > best_threshold_f1:
        best_threshold_f1 = f1
        best_threshold = thresh
        threshold_found = True

if not threshold_found:
    # No candidate met the recall target: keep the 0.5 default, but say so and
    # report its real validation metrics instead of a misleading F1 of 0.
    fallback_preds = (all_probs_val >= best_threshold).astype(int)
    fallback_recall = recall_score(all_labels_val, fallback_preds, pos_label=1, zero_division=0)
    best_threshold_f1 = f1_score(all_labels_val, fallback_preds, average='macro', zero_division=0)
    print(f"[train] WARNING: no threshold in [0.20, 0.70] reaches recall >= {TARGET_RECALL}; "
          f"keeping {best_threshold:.2f} (val recall {fallback_recall:.4f})")

print(f"[train] Best threshold: {best_threshold:.2f} | Val F1: {best_threshold_f1:.4f}")
# ── Final evaluation on held-out test set ──────────────────────────────────
print("\n[train] Evaluating on test set...")

# Make inference mode explicit so this section does not depend on an earlier
# model.eval() call.
model.eval()

# Collect positive-class probabilities (softmax column 1) for the test set.
all_probs_test, all_labels_test = [], []
with torch.no_grad():
    for embeddings, labels_batch in test_loader:
        logits = model(embeddings.to(device))
        probs = torch.softmax(logits, dim=1)[:, 1]
        all_probs_test.extend(probs.cpu().numpy())
        all_labels_test.extend(labels_batch.numpy())
all_probs_test = np.array(all_probs_test)
all_labels_test = np.array(all_labels_test)

# Apply the threshold chosen on the validation set.
all_preds_test = (all_probs_test >= best_threshold).astype(int)

# Everything is cast to plain floats so the dict can be written as JSON.
# zero_division=0 matches the validation-time metric calls and avoids
# undefined-metric warnings when a class is never predicted.
test_results = {
    'disease': DISEASE,
    'threshold': float(best_threshold),
    'accuracy': float(accuracy_score(all_labels_test, all_preds_test)),
    'f1_macro': float(f1_score(all_labels_test, all_preds_test,
                               average='macro', zero_division=0)),
    'recall': float(recall_score(all_labels_test, all_preds_test,
                                 pos_label=1, zero_division=0)),
    'precision': float(precision_score(all_labels_test, all_preds_test,
                                       pos_label=1, zero_division=0)),
    # AUROC is computed from the raw probabilities, not the thresholded labels.
    'auroc': float(roc_auc_score(all_labels_test, all_probs_test)),
}

print("\n[train] Test Results:")
for k, v in test_results.items():
    print(f" {k}: {v}")
# ── Save model + metadata ───────────────────────────────────────────────────
# Create the output folders from the configured paths, so changing
# MODEL_SAVE_PATH / RESULTS_PATH cannot point at a directory that was never made.
for out_path in (MODEL_SAVE_PATH, RESULTS_PATH):
    os.makedirs(os.path.dirname(out_path) or '.', exist_ok=True)

torch.save({
    # Store the weights on the CPU so the checkpoint loads on machines without
    # a GPU without needing map_location.
    'model_state_dict': {k: v.detach().cpu() for k, v in best_model_state.items()},
    # Plain Python float: a NumPy scalar here is pickled as a NumPy object,
    # which torch.load(weights_only=True) does not accept by default.
    'threshold': float(best_threshold),
    'hidden_dims': HIDDEN_DIMS,
    'input_dim': INPUT_DIM,
    'test_results': test_results,
    'disease': DISEASE,
}, MODEL_SAVE_PATH)

# Write the metrics on their own as JSON, with an explicit encoding.
with open(RESULTS_PATH, 'w', encoding='utf-8') as f:
    json.dump(test_results, f, indent=2)

print(f"\n[train] Model saved to {MODEL_SAVE_PATH}")
print(f"[train] Results saved to {RESULTS_PATH}")