""" Deep-learning classifier (PyTorch) for screening ASD vs non-ASD. Two complementary models are trained: 1. TabularMLP -- a plain multi-layer perceptron that consumes the hand-engineered CHAT features already produced by data_loader.py. Directly comparable to the sklearn baselines in classifier.py. 2. UtteranceLSTM -- a small bidirectional LSTM that reads the *sequence* of CHI utterances from each transcript (one word-count vector per utterance). Shows that a sequence model can pick up temporal speech patterns that aggregate features miss. Both are evaluated with stratified 5-fold CV so the numbers are directly comparable with `src/classifier.py`. Outputs: reports/metrics/deep_learning_results.csv reports/figures/deep_learning_roc.png reports/figures/deep_learning_training_curves.png """ from __future__ import annotations from pathlib import Path from typing import List, Tuple import matplotlib.pyplot as plt import numpy as np import pandas as pd import pylangacq as pla import seaborn as sns import torch import torch.nn as nn from sklearn.impute import SimpleImputer from sklearn.metrics import ( accuracy_score, classification_report, f1_score, roc_auc_score, roc_curve, ) from sklearn.model_selection import StratifiedKFold from sklearn.preprocessing import StandardScaler from torch.utils.data import DataLoader, Dataset try: from src.feature_schema import FEATURES except ModuleNotFoundError: from feature_schema import FEATURES PROJECT_ROOT = Path(__file__).resolve().parent.parent DATA_DIR = PROJECT_ROOT / "data" FLUSBERG_DIR = DATA_DIR / "Flusberg" FIG_DIR = PROJECT_ROOT / "reports" / "figures" METRIC_DIR = PROJECT_ROOT / "reports" / "metrics" FIG_DIR.mkdir(parents=True, exist_ok=True) METRIC_DIR.mkdir(parents=True, exist_ok=True) sns.set_theme(style="whitegrid", context="talk") RANDOM_STATE = 42 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu") print(f"[info] using device: {DEVICE}") 
torch.manual_seed(RANDOM_STATE)
np.random.seed(RANDOM_STATE)


# =========================================================================
# Model 1: Tabular MLP on hand-engineered features
# =========================================================================
class TabularMLP(nn.Module):
    """Two-hidden-layer MLP mapping a feature vector to a single ASD logit.

    Layout: ``in_dim -> hidden -> hidden // 2 -> 1`` with ReLU and dropout
    after each hidden layer. ``forward`` returns raw logits (no sigmoid).
    """

    def __init__(self, in_dim: int, hidden: int = 64, dropout: float = 0.3):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, hidden // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden // 2, 1),
        )

    def forward(self, x):
        # Drop the trailing size-1 dim so logits line up with 1-D targets.
        return self.net(x).squeeze(-1)


def _train_mlp(X_tr, y_tr, X_va, y_va, in_dim: int, epochs: int = 80,
               lr: float = 1e-3, weight_decay: float = 1e-4,
               patience: int = 15) -> Tuple[TabularMLP, List[float], List[float]]:
    """Full-batch train a ``TabularMLP`` with early stopping on (X_va, y_va).

    Args:
        X_tr, y_tr: training features / 0-1 labels (already imputed + scaled).
        X_va, y_va: early-stopping set. Must NOT be the data the model is
            finally evaluated on, otherwise the reported scores are biased.
        in_dim: number of input features.
        epochs: maximum number of full-batch optimisation steps.
        lr, weight_decay: AdamW hyper-parameters.
        patience: stop after this many epochs without an improvement of more
            than 1e-4 in early-stopping loss.

    Returns:
        (model, train_losses, val_losses). The model has the weights from the
        epoch with the lowest early-stopping loss restored.
    """
    model = TabularMLP(in_dim).to(DEVICE)
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    # Class-imbalance weighting: up-weight positives by the neg/pos ratio.
    pos = float((y_tr == 1).sum())
    neg = float((y_tr == 0).sum())
    pos_weight = torch.tensor([neg / max(pos, 1.0)], device=DEVICE)
    loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    X_tr_t = torch.tensor(X_tr, dtype=torch.float32, device=DEVICE)
    y_tr_t = torch.tensor(y_tr, dtype=torch.float32, device=DEVICE)
    X_va_t = torch.tensor(X_va, dtype=torch.float32, device=DEVICE)
    y_va_t = torch.tensor(y_va, dtype=torch.float32, device=DEVICE)

    tr_losses, va_losses = [], []
    best_state, best_va = None, float("inf")
    bad = 0
    for _ in range(epochs):
        model.train()
        opt.zero_grad()
        loss = loss_fn(model(X_tr_t), y_tr_t)
        loss.backward()
        opt.step()
        tr_losses.append(loss.item())

        model.eval()
        with torch.no_grad():
            va_loss = loss_fn(model(X_va_t), y_va_t).item()
        va_losses.append(va_loss)

        if va_loss < best_va - 1e-4:
            best_va = va_loss
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            bad = 0
        else:
            bad += 1
            if bad >= patience:
                break

    if best_state is not None:
        model.load_state_dict(best_state)
    return model, tr_losses, va_losses


def _early_stopping_split(y: np.ndarray, n_splits: int = 5):
    """Return ``(fit_positions, stop_positions)``: a stratified ~(1-1/n)/(1/n) split of *y*.

    Used to carve an early-stopping set out of a training fold so the outer
    validation fold never influences model selection.
    """
    inner = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=RANDOM_STATE)
    return next(inner.split(np.zeros(len(y)), y))


def run_tabular_mlp(df: pd.DataFrame):
    """Cross-validate ``TabularMLP`` on the ``FEATURES`` columns of *df*.

    Stratified 5-fold CV. In each fold the imputer and scaler are fit on the
    training fold only, and early stopping uses an inner stratified split of
    the training fold, so the outer validation fold is touched only for the
    final out-of-fold prediction.

    Returns:
        (y, all_pred, all_prob, fold_curves):
            y -- 0/1 labels (1 where ``df["group"] == "ASD"``), in *df* order;
            all_pred -- out-of-fold hard predictions (threshold 0.5), int;
            all_prob -- out-of-fold ASD probabilities;
            fold_curves -- per fold, (train losses, early-stopping losses).
    """
    X = df[FEATURES].values.astype(np.float32)
    y = (df["group"] == "ASD").astype(int).values

    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=RANDOM_STATE)
    all_pred = np.zeros(len(y), dtype=int)
    all_prob = np.zeros(len(y))
    fold_curves = []

    for fold, (tr_idx, va_idx) in enumerate(skf.split(X, y), 1):
        imp = SimpleImputer(strategy="median")
        sc = StandardScaler()
        X_tr = sc.fit_transform(imp.fit_transform(X[tr_idx]))
        X_va = sc.transform(imp.transform(X[va_idx]))
        y_tr, y_va = y[tr_idx], y[va_idx]

        fit_pos, stop_pos = _early_stopping_split(y_tr)
        model, tr_l, va_l = _train_mlp(
            X_tr[fit_pos], y_tr[fit_pos], X_tr[stop_pos], y_tr[stop_pos],
            # SimpleImputer drops columns that are all-NaN in the training
            # fold, so the width must come from the transformed matrix.
            in_dim=X_tr.shape[1],
        )
        fold_curves.append((tr_l, va_l))

        model.eval()
        with torch.no_grad():
            logits = model(torch.tensor(X_va, dtype=torch.float32, device=DEVICE))
            prob = torch.sigmoid(logits).cpu().numpy()
        all_prob[va_idx] = prob
        all_pred[va_idx] = (prob >= 0.5).astype(int)
        print(f"  fold {fold}  epochs={len(tr_l):3d}  "
              f"stop-loss={min(va_l):.3f}  val-auc={roc_auc_score(y_va, prob):.3f}")

    return y, all_pred, all_prob, fold_curves


# =========================================================================
# Model 2: Utterance-sequence Bi-LSTM
# =========================================================================
# Simple bag-of-features per utterance:
#   [ word_count, has_xxx, has_yyy, is_zero, is_question, has_nonverbal ]
# This stays compatible with any CHAT file and doesn't need a vocabulary.
UTT_FEATS = 6
# Keep at most this many CHI utterances per file (extra ones are dropped);
# padding is done per batch in ``_collate``, to the longest sequence there.
MAX_UTT = 120

# Token strings excluded from the per-utterance word count: punctuation and
# CHAT utterance terminators.
_NON_WORD_TOKENS = frozenset(
    {".", "?", "!", ",", ";", ":", "+...", "+..", "+/.", "+//.", "+/?"}
)


def _utt_vector(utterance, raw_tier: str) -> np.ndarray:
    """Encode one CHI utterance as a ``UTT_FEATS``-long float32 vector.

    Features, in order:
        word_count    -- tokens with a non-empty ``word`` not in ``_NON_WORD_TOKENS``
        has_xxx       -- raw tier contains "xxx"
        has_yyy       -- raw tier contains "yyy"
        is_zero       -- tier is exactly "0" after stripping trailing " .?!"
        is_question   -- tier (right-stripped) ends with "?"
        has_nonverbal -- tier contains an "&=" marker

    Args:
        utterance: a pylangacq utterance; only ``.tokens[*].word`` is read.
        raw_tier: raw text of the utterance's CHI tier.
    """
    word_count = sum(
        1 for t in utterance.tokens
        if t.word and t.word not in _NON_WORD_TOKENS
    )
    has_xxx = int("xxx" in raw_tier)
    has_yyy = int("yyy" in raw_tier)
    stripped = raw_tier.strip().rstrip(" .?!").strip()
    is_zero = int(stripped == "0")
    is_q = int(raw_tier.rstrip().endswith("?"))
    has_nonverb = int("&=" in raw_tier)
    return np.array([word_count, has_xxx, has_yyy, is_zero, is_q, has_nonverb],
                    dtype=np.float32)


def _file_to_sequence(cha_path: Path) -> np.ndarray:
    """Return a ``(n_utts, UTT_FEATS)`` array for the CHI utterances in *cha_path*.

    Only the first ``MAX_UTT`` CHI utterances are kept. A transcript with no
    CHI utterances yields a single all-zero row so every sample still has a
    non-empty sequence.
    """
    try:
        r = pla.read_chat(str(cha_path))
    except Exception:  # noqa: BLE001 -- retry in non-strict mode on any failure
        r = pla.read_chat(str(cha_path), strict=False)
    seq = []
    for u in r.utterances():
        if u.participant != "CHI":
            continue
        raw = u.tiers.get("CHI", "")
        seq.append(_utt_vector(u, raw))
        if len(seq) >= MAX_UTT:
            break
    if not seq:
        return np.zeros((1, UTT_FEATS), dtype=np.float32)
    return np.stack(seq)


def _resolve_cha_path(row: pd.Series) -> Path:
    """Reconstruct the original .cha path from the combined CSV row.

    Raises:
        FileNotFoundError: no Flusberg transcript matches the participant id.
        ValueError: ``row["corpus"]`` is not one of the known corpora.
    """
    corpus = row["corpus"]
    pid = str(row["participant_id"])
    if corpus == "eigsti":
        return DATA_DIR / "Eigsti" / row["group"] / f"{pid}.cha"
    if corpus == "nadig":
        return DATA_DIR / "Nadig" / f"{pid}.cha"
    if corpus == "nyu_emerson":
        return DATA_DIR / "NYU-Emerson" / f"{pid}.cha"
    if corpus == "flusberg":
        # Files live one directory down; the stem is also compared with its
        # leading zeros stripped.
        for path in FLUSBERG_DIR.glob("*/*.cha"):
            if path.stem == pid or path.stem.lstrip("0") == pid:
                return path
        raise FileNotFoundError(
            f"no Flusberg transcript for participant {pid} under {FLUSBERG_DIR}"
        )
    raise ValueError(f"unknown corpus: {corpus}")


class SeqDataset(Dataset):
    """Map-style dataset pairing variable-length sequences with labels."""

    def __init__(self, seqs, labels):
        self.seqs = seqs
        self.labels = labels

    def __len__(self):
        return len(self.seqs)

    def __getitem__(self, i):
        return self.seqs[i], self.labels[i]


def _collate(batch):
    """Zero-pad a batch of ``(seq, label)`` pairs to the longest sequence.

    Returns ``(padded, lens, labels)``: a float32 tensor of shape
    ``(batch, max_len, UTT_FEATS)``, the true lengths, and float32 labels.
    """
    seqs, ys = zip(*batch)
    lens = [s.shape[0] for s in seqs]
    padded = np.zeros((len(seqs), max(lens), UTT_FEATS), dtype=np.float32)
    for i, s in enumerate(seqs):
        padded[i, :s.shape[0]] = s
    return (torch.tensor(padded), torch.tensor(lens),
            torch.tensor(np.asarray(ys), dtype=torch.float32))


class UtteranceLSTM(nn.Module):
    """Single-layer bidirectional LSTM producing one ASD logit per sequence.

    The final forward and backward hidden states are concatenated and passed
    through dropout and a linear head. Padding is ignored via packing.
    """

    def __init__(self, in_dim=UTT_FEATS, hidden=16, dropout=0.25):
        super().__init__()
        self.lstm = nn.LSTM(in_dim, hidden, batch_first=True, bidirectional=True)
        self.drop = nn.Dropout(dropout)
        self.head = nn.Linear(2 * hidden, 1)

    def forward(self, x, lens):
        # pack_padded_sequence requires the lengths on the CPU.
        packed = nn.utils.rnn.pack_padded_sequence(
            x, lens.cpu(), batch_first=True, enforce_sorted=False
        )
        _, (h, _) = self.lstm(packed)  # h: (2, B, H) -- one layer, two directions
        h = torch.cat([h[0], h[1]], dim=-1)  # (B, 2H)
        return self.head(self.drop(h)).squeeze(-1)


def run_lstm(df: pd.DataFrame):
    """Cross-validate ``UtteranceLSTM`` on per-utterance sequences.

    One sequence is built per row of *df* from its .cha transcript. Rows
    whose transcript cannot be resolved or parsed are skipped with a message,
    so the returned arrays may be shorter than *df*.

    Within each stratified outer fold, features are standardised with
    training-fold statistics. Early stopping uses an inner stratified split of
    the training fold, so the outer validation fold is only used for the final
    out-of-fold prediction.

    Returns:
        (labels, all_pred, all_prob): 0/1 labels (1 = ASD), out-of-fold hard
        predictions (threshold 0.5, int), and out-of-fold ASD probabilities.

    Raises:
        RuntimeError: if no transcript could be turned into a sequence.
    """
    print("\n[LSTM] building per-utterance sequences...")
    seqs, labels = [], []
    for _, row in df.iterrows():
        try:
            path = _resolve_cha_path(row)
            seqs.append(_file_to_sequence(path))
            labels.append(int(row["group"] == "ASD"))
        except Exception as e:  # noqa: BLE001
            print(f"  skip {row['participant_id']}: {e}")
    if not seqs:
        raise RuntimeError("[LSTM] no transcripts could be loaded; nothing to train on")

    labels = np.array(labels)
    lens = [s.shape[0] for s in seqs]
    print(f"[LSTM] {len(seqs)} sequences, "
          f"len: min={min(lens)} max={max(lens)} median={int(np.median(lens))}")

    max_epochs, patience = 14, 4
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=RANDOM_STATE)
    all_pred = np.zeros(len(labels), dtype=int)
    all_prob = np.zeros(len(labels))

    for fold, (tr_idx, va_idx) in enumerate(skf.split(np.zeros(len(labels)), labels), 1):
        # Normalise with statistics from the training fold's utterances only.
        train_concat = np.concatenate([seqs[i] for i in tr_idx], axis=0)
        mu = train_concat.mean(0, keepdims=True)
        sd = train_concat.std(0, keepdims=True) + 1e-6
        tr_seqs = [(seqs[i] - mu) / sd for i in tr_idx]
        va_seqs = [(seqs[i] - mu) / sd for i in va_idx]
        y_tr, y_va = labels[tr_idx], labels[va_idx]

        # Inner stratified split of the training fold for early stopping.
        inner = StratifiedKFold(n_splits=5, shuffle=True, random_state=RANDOM_STATE)
        fit_pos, stop_pos = next(inner.split(np.zeros(len(y_tr)), y_tr))
        fit_seqs = [tr_seqs[i] for i in fit_pos]
        stop_seqs = [tr_seqs[i] for i in stop_pos]
        y_fit, y_stop = y_tr[fit_pos], y_tr[stop_pos]

        model = UtteranceLSTM().to(DEVICE)
        opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
        # Class-imbalance weighting: up-weight positives by the neg/pos ratio.
        pos = float((y_fit == 1).sum())
        neg = float((y_fit == 0).sum())
        pw = torch.tensor([neg / max(pos, 1.0)], device=DEVICE)
        loss_fn = nn.BCEWithLogitsLoss(pos_weight=pw)

        fit_loader = DataLoader(SeqDataset(fit_seqs, y_fit), batch_size=16,
                                shuffle=True, collate_fn=_collate)
        stop_loader = DataLoader(SeqDataset(stop_seqs, y_stop), batch_size=32,
                                 shuffle=False, collate_fn=_collate)

        best_state, best_stop, bad = None, float("inf"), 0
        for _ in range(max_epochs):
            model.train()
            for x, lens_b, y in fit_loader:
                x, lens_b, y = x.to(DEVICE), lens_b.to(DEVICE), y.to(DEVICE)
                opt.zero_grad()
                loss = loss_fn(model(x, lens_b), y)
                loss.backward()
                opt.step()

            model.eval()
            with torch.no_grad():
                stop_losses = []
                for x, lens_b, y in stop_loader:
                    x, lens_b, y = x.to(DEVICE), lens_b.to(DEVICE), y.to(DEVICE)
                    stop_losses.append(loss_fn(model(x, lens_b), y).item())
            stop_mean = float(np.mean(stop_losses))
            if stop_mean < best_stop - 1e-4:
                best_stop = stop_mean
                best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
                bad = 0
            else:
                bad += 1
                if bad >= patience:
                    break

        if best_state is not None:
            model.load_state_dict(best_state)

        # Predict the outer validation fold; shuffle=False keeps va_idx order.
        model.eval()
        probs = np.zeros(len(va_idx))
        va_loader = DataLoader(SeqDataset(va_seqs, y_va), batch_size=16,
                               shuffle=False, collate_fn=_collate)
        with torch.no_grad():
            idx = 0
            for x, lens_b, _ in va_loader:
                x, lens_b = x.to(DEVICE), lens_b.to(DEVICE)
                p = torch.sigmoid(model(x, lens_b)).cpu().numpy()
                probs[idx:idx + len(p)] = p
                idx += len(p)
        all_prob[va_idx] = probs
        all_pred[va_idx] = (probs >= 0.5).astype(int)
        print(f"  fold {fold}  val-auc={roc_auc_score(y_va, probs):.3f}")

    return labels, all_pred, all_prob


# =========================================================================
# Runner
# =========================================================================
def _plot_training_curves(fold_curves):
    """Plot per-fold (train, val) loss curves of the MLP and save the figure."""
    fig, ax = plt.subplots(figsize=(9, 5))
    for i, (tr, va) in enumerate(fold_curves, 1):
        # Label only the first fold so the legend shows one entry per colour.
        ax.plot(tr, alpha=0.4, label=f"fold{i} train" if i == 1 else None, color="tab:blue")
        ax.plot(va, alpha=0.7, label=f"fold{i} val" if i == 1 else None, color="tab:orange")
    ax.set_title("TabularMLP training curves (5 folds)")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("BCE loss")
    ax.legend()
    fig.tight_layout()
    out = FIG_DIR / "deep_learning_training_curves.png"
    fig.savefig(out, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f"  saved {out.relative_to(PROJECT_ROOT)}")


def _plot_roc(y, results: dict):
    """Plot one ROC curve per ``{model name: probabilities}`` entry and save it."""
    fig, ax = plt.subplots(figsize=(7, 6))
    for name, prob in results.items():
        fpr, tpr, _ = roc_curve(y, prob)
        auc = roc_auc_score(y, prob)
        ax.plot(fpr, tpr, label=f"{name} (AUC={auc:.3f})", linewidth=2.2)
    ax.plot([0, 1], [0, 1], "k--", alpha=0.4)  # chance diagonal
    ax.set_xlabel("False positive rate")
    ax.set_ylabel("True positive rate")
    ax.set_title("Deep learning: ASD vs non-ASD")
    ax.legend()
    fig.tight_layout()
    out = FIG_DIR / "deep_learning_roc.png"
    fig.savefig(out, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f"  saved {out.relative_to(PROJECT_ROOT)}")


def main() -> None:
    """Run both models, save figures and the metrics CSV, print a summary.

    Raises:
        RuntimeError: if the LSTM's labels do not match the MLP's row for row
            (e.g. because some transcripts were skipped).
    """
    df = pd.read_csv(DATA_DIR / "combined_features.csv").dropna(subset=["group"])
    print(f"Loaded {len(df)} rows.")

    # -- TabularMLP --
    print("\n" + "=" * 70)
    print("Model 1: Tabular MLP (hand-engineered features)")
    print("=" * 70)
    y, mlp_pred, mlp_prob, curves = run_tabular_mlp(df)
    _plot_training_curves(curves)

    # -- UtteranceLSTM --
    print("\n" + "=" * 70)
    print("Model 2: Utterance-sequence Bi-LSTM")
    print("=" * 70)
    y2, lstm_pred, lstm_prob = run_lstm(df)
    # Both models are scored against the same label vector below. This is an
    # explicit check rather than ``assert`` so it still runs under ``python -O``.
    if not np.array_equal(y, y2):
        raise RuntimeError(
            f"label order mismatch: TabularMLP has {len(y)} labels, "
            f"UtteranceLSTM has {len(y2)} (see any 'skip' messages above)"
        )

    rows = []
    for name, pred, prob in [
        ("TabularMLP", mlp_pred, mlp_prob),
        ("UtteranceLSTM", lstm_pred, lstm_prob),
    ]:
        acc = accuracy_score(y, pred)
        f1 = f1_score(y, pred, average="macro")
        auc = roc_auc_score(y, prob)
        rows.append({"model": name, "accuracy": round(acc, 4),
                     "f1_macro": round(f1, 4), "roc_auc": round(auc, 4)})
        print(f"\n[{name}]")
        print(classification_report(y, pred, target_names=["non-ASD", "ASD"], digits=3))

    _plot_roc(y, {"TabularMLP": mlp_prob, "UtteranceLSTM": lstm_prob})

    out = METRIC_DIR / "deep_learning_results.csv"
    pd.DataFrame(rows).to_csv(out, index=False)
    print(f"\n[saved] {out.relative_to(PROJECT_ROOT)}")
    print("\n=== SUMMARY ===")
    print(pd.DataFrame(rows).to_string(index=False))


if __name__ == "__main__":
    main()