# asd-screening-tool / src/deep_learning.py
# Paoo
# deploy: publish sanitized v0.17 streamlit app
# 8514865
"""
Deep-learning classifier (PyTorch) for screening ASD vs non-ASD.
Two complementary models are trained:
1. TabularMLP -- a plain multi-layer perceptron that consumes the
hand-engineered CHAT features already produced by
data_loader.py. Directly comparable to the
sklearn baselines in classifier.py.
2. UtteranceLSTM -- a small bidirectional LSTM that reads the
*sequence* of CHI utterances from each transcript
(one word-count vector per utterance). Shows that
a sequence model can pick up temporal speech
patterns that aggregate features miss.
Both are evaluated with stratified 5-fold CV so the numbers are directly
comparable with `src/classifier.py`.
Outputs:
reports/metrics/deep_learning_results.csv
reports/figures/deep_learning_roc.png
reports/figures/deep_learning_training_curves.png
"""
from __future__ import annotations
from pathlib import Path
from typing import List, Tuple
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pylangacq as pla
import seaborn as sns
import torch
import torch.nn as nn
from sklearn.impute import SimpleImputer
from sklearn.metrics import (
accuracy_score,
classification_report,
f1_score,
roc_auc_score,
roc_curve,
)
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, Dataset
try:
from src.feature_schema import FEATURES
except ModuleNotFoundError:
from feature_schema import FEATURES
# Project layout: this file lives in <root>/src/, so the root is two levels up.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
DATA_DIR = PROJECT_ROOT / "data"
FLUSBERG_DIR = DATA_DIR / "Flusberg"
FIG_DIR = PROJECT_ROOT / "reports" / "figures"
METRIC_DIR = PROJECT_ROOT / "reports" / "metrics"

# Output directories are created at import time so later savefig/to_csv
# calls never fail on a missing folder.
for _out_dir in (FIG_DIR, METRIC_DIR):
    _out_dir.mkdir(parents=True, exist_ok=True)
del _out_dir

sns.set_theme(style="whitegrid", context="talk")

RANDOM_STATE = 42

# Prefer CUDA, then Apple-silicon MPS, and fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")
print(f"[info] using device: {DEVICE}")

# Seed both torch and numpy once, module-wide, for reproducible runs.
torch.manual_seed(RANDOM_STATE)
np.random.seed(RANDOM_STATE)
# =========================================================================
# Model 1: Tabular MLP on hand-engineered features
# =========================================================================
class TabularMLP(nn.Module):
    """Two-hidden-layer perceptron that emits one logit per input row.

    Architecture: ``in_dim -> hidden -> hidden // 2 -> 1`` with ReLU and
    dropout after each hidden layer. ``forward`` returns raw logits (no
    sigmoid), which is what ``BCEWithLogitsLoss`` expects.
    """

    def __init__(self, in_dim: int, hidden: int = 64, dropout: float = 0.3):
        super().__init__()
        half = hidden // 2
        layers = [
            nn.Linear(in_dim, hidden), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(hidden, half), nn.ReLU(), nn.Dropout(dropout),
            nn.Linear(half, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        logits = self.net(x)
        # Drop the trailing singleton output dimension: (B, 1) -> (B,).
        return logits.squeeze(-1)
def _train_mlp(X_tr, y_tr, X_va, y_va, in_dim: int,
               epochs: int = 80, lr: float = 1e-3,
               weight_decay: float = 1e-4) -> Tuple[TabularMLP, List[float], List[float]]:
    """Fit a ``TabularMLP`` with full-batch AdamW and early stopping.

    Each epoch takes one optimiser step on the whole training set, then
    scores the validation set. Training stops after 15 consecutive epochs
    in which the validation loss failed to improve by more than 1e-4. The
    weights from the best validation epoch are restored before returning.

    Returns
    -------
    (model, train_losses, val_losses)
        The fitted model (best-epoch weights) and the per-epoch loss
        histories. Each training loss is the one computed in train mode
        (dropout active) for that epoch's optimiser step.
    """
    model = TabularMLP(in_dim).to(DEVICE)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

    # Up-weight positives by this fold's neg/pos ratio to counter class
    # imbalance; max(..., 1.0) avoids dividing by zero if no positives.
    n_pos = float((y_tr == 1).sum())
    n_neg = float((y_tr == 0).sum())
    criterion = nn.BCEWithLogitsLoss(
        pos_weight=torch.tensor([n_neg / max(n_pos, 1.0)], device=DEVICE)
    )

    def as_tensor(arr):
        return torch.tensor(arr, dtype=torch.float32, device=DEVICE)

    xt, yt = as_tensor(X_tr), as_tensor(y_tr)
    xv, yv = as_tensor(X_va), as_tensor(y_va)

    train_hist: List[float] = []
    val_hist: List[float] = []
    best_weights = None
    best_loss = float("inf")
    patience = 15
    stale = 0
    for _epoch in range(epochs):
        model.train()
        optimizer.zero_grad()
        batch_loss = criterion(model(xt), yt)
        batch_loss.backward()
        optimizer.step()
        train_hist.append(batch_loss.item())

        model.eval()
        with torch.no_grad():
            current = criterion(model(xv), yv).item()
        val_hist.append(current)

        improved = current < best_loss - 1e-4
        if not improved:
            stale += 1
            if stale >= patience:
                break
            continue
        best_loss = current
        # Clone so later optimiser steps do not overwrite the snapshot.
        best_weights = {name: t.detach().clone()
                        for name, t in model.state_dict().items()}
        stale = 0

    if best_weights is not None:
        model.load_state_dict(best_weights)
    return model, train_hist, val_hist
def run_tabular_mlp(df: pd.DataFrame):
    """Cross-validate ``TabularMLP`` on the hand-engineered ``FEATURES``.

    Uses stratified 5-fold CV (shuffled with ``RANDOM_STATE``). Median
    imputation and standardisation are fit on each training fold only.

    Returns
    -------
    (y, pred, prob, fold_curves)
        ``y`` holds the 0/1 labels (1 = ``group == "ASD"``) in ``df`` row
        order. ``pred`` holds the out-of-fold hard predictions (threshold
        0.5) and ``prob`` the probabilities, both aligned with ``y``.
        ``fold_curves`` holds one ``(train_losses, val_losses)`` pair per
        fold.

    NOTE(review): ``_train_mlp`` early-stops on the same fold that is then
    scored here, so the reported metrics are somewhat optimistic. An inner
    validation split carved from the training fold would remove that bias.
    """
    X = df[FEATURES].values.astype(np.float32)
    y = (df["group"] == "ASD").astype(int).values
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=RANDOM_STATE)
    all_pred = np.zeros(len(y))
    all_prob = np.zeros(len(y))
    fold_curves = []
    for fold, (tr_idx, va_idx) in enumerate(skf.split(X, y), 1):
        imp = SimpleImputer(strategy="median")
        sc = StandardScaler()
        X_tr = sc.fit_transform(imp.fit_transform(X[tr_idx]))
        X_va = sc.transform(imp.transform(X[va_idx]))
        y_tr, y_va = y[tr_idx], y[va_idx]
        # SimpleImputer drops columns that are entirely NaN in the training
        # fold, so size the network from the transformed matrix rather than
        # from the raw feature count (which would mismatch the first layer).
        model, tr_l, va_l = _train_mlp(X_tr, y_tr, X_va, y_va, in_dim=X_tr.shape[1])
        fold_curves.append((tr_l, va_l))
        model.eval()
        with torch.no_grad():
            logits = model(torch.tensor(X_va, dtype=torch.float32, device=DEVICE))
            prob = torch.sigmoid(logits).cpu().numpy()
        all_prob[va_idx] = prob
        all_pred[va_idx] = (prob >= 0.5).astype(int)
        # Report the minimum validation loss (what early stopping restored),
        # not the last epoch's, which is the worst of the patience window.
        print(f" fold {fold} epochs={len(tr_l):3d} "
              f"min-val-loss={min(va_l):.3f} val-auc={roc_auc_score(y_va, prob):.3f}")
    return y, all_pred, all_prob, fold_curves
# =========================================================================
# Model 2: Utterance-sequence Bi-LSTM
# =========================================================================
# Per-utterance feature vector. No vocabulary is needed, so it works on any
# CHAT file:
#   [ word_count, has_xxx, has_yyy, is_zero, is_question, has_nonverbal ]
UTT_FEATS = 6
MAX_UTT = 120  # keep at most this many CHI utterances per transcript

# Punctuation / terminator tokens that must not be counted as words.
_NON_WORD_TOKENS = frozenset({".", "?", "!", ",", ";", ":",
                              "+...", "+..", "+/.", "+//.", "+/?"})


def _utt_vector(utterance, raw_tier: str) -> np.ndarray:
    """Encode one utterance as a ``UTT_FEATS``-long float32 vector.

    Parameters
    ----------
    utterance
        Parsed utterance whose ``tokens`` each expose a ``word`` attribute.
    raw_tier : str
        Raw text of the speaker's main tier.

    Returns
    -------
    np.ndarray
        ``[word_count, has_xxx, has_yyy, is_zero, is_question,
        has_nonverbal]`` with the following meanings:

        - ``word_count``: tokens whose ``word`` is non-empty and not
          punctuation.
        - ``has_xxx`` / ``has_yyy``: the substring occurs in the tier.
        - ``is_zero``: the tier is just ``0`` once trailing ``.?!`` are
          removed.
        - ``is_question``: the tier ends in ``?``.
        - ``has_nonverbal``: the tier contains ``&=``.
    """
    n_words = len([tok for tok in utterance.tokens
                   if tok.word and tok.word not in _NON_WORD_TOKENS])
    core = raw_tier.strip().rstrip(" .?!").strip()
    flags = (
        "xxx" in raw_tier,
        "yyy" in raw_tier,
        core == "0",
        raw_tier.rstrip().endswith("?"),
        "&=" in raw_tier,
    )
    return np.array([n_words, *map(int, flags)], dtype=np.float32)
def _file_to_sequence(cha_path: Path) -> np.ndarray:
    """Encode the child's utterances from one CHAT file as a feature matrix.

    Only utterances whose participant is ``"CHI"`` are kept, in transcript
    order, and at most ``MAX_UTT`` of them. Each becomes one row built by
    ``_utt_vector`` from the utterance and its ``"CHI"`` tier text. A
    transcript with no child utterances yields a single all-zero row, so
    downstream code never sees an empty sequence.
    """
    try:
        reader = pla.read_chat(str(cha_path))
    except Exception:
        # Retry with strict=False (presumably a more lenient parse) when the
        # default read fails.
        reader = pla.read_chat(str(cha_path), strict=False)

    rows = []
    for utt in reader.utterances():
        if utt.participant == "CHI":
            rows.append(_utt_vector(utt, utt.tiers.get("CHI", "")))
            if len(rows) >= MAX_UTT:
                break
    return np.stack(rows) if rows else np.zeros((1, UTT_FEATS), dtype=np.float32)
def _resolve_cha_path(row: pd.Series) -> Path:
    """Reconstruct the original ``.cha`` transcript path from a CSV row.

    ``row`` must provide ``corpus`` and ``participant_id``. The ``eigsti``
    layout also uses ``group`` as a sub-directory. Flusberg transcripts are
    searched one directory below ``FLUSBERG_DIR``. They match either on the
    exact stem or on the stem with leading zeros removed.

    Raises
    ------
    ValueError
        If ``corpus`` is not one of the known corpora.
    FileNotFoundError
        If no Flusberg transcript matches ``participant_id``.
    """
    corpus = row["corpus"]
    pid = str(row["participant_id"])
    if corpus == "eigsti":
        return DATA_DIR / "Eigsti" / row["group"] / f"{pid}.cha"
    if corpus == "nadig":
        return DATA_DIR / "Nadig" / f"{pid}.cha"
    if corpus == "nyu_emerson":
        return DATA_DIR / "NYU-Emerson" / f"{pid}.cha"
    if corpus == "flusberg":
        # Sort so the match is deterministic; raw glob order depends on the
        # filesystem.
        for path in sorted(FLUSBERG_DIR.glob("*/*.cha")):
            if pid in (path.stem, path.stem.lstrip("0")):
                return path
        # The corpus is known; it is the file that is missing, so do not
        # fall through to the "unknown corpus" error below.
        raise FileNotFoundError(
            f"no Flusberg transcript for participant {pid!r} under {FLUSBERG_DIR}"
        )
    raise ValueError(f"unknown corpus: {corpus}")
class SeqDataset(Dataset):
    """Map-style dataset that pairs each utterance sequence with its label.

    Items are returned as ``(sequence, label)`` tuples exactly as stored.
    Batching and padding are left to the DataLoader's ``collate_fn``.
    """

    def __init__(self, seqs, labels):
        self.seqs, self.labels = seqs, labels

    def __len__(self):
        return len(self.seqs)

    def __getitem__(self, i):
        seq = self.seqs[i]
        label = self.labels[i]
        return seq, label
def _collate(batch):
    """Zero-pad a batch of variable-length utterance sequences.

    Parameters
    ----------
    batch : sequence of (seq, label) pairs
        As yielded by ``SeqDataset``. Each ``seq`` is a 2-D array of shape
        ``(n_utterances, n_features)``, and all share the same
        ``n_features``.

    Returns
    -------
    (padded, lengths, labels)
        - ``padded``: float32 tensor ``(B, max_len, n_features)``, zero-filled
          after each sequence.
        - ``lengths``: int64 tensor of true lengths (for
          ``pack_padded_sequence``).
        - ``labels``: float32 tensor of shape ``(B,)``.
    """
    seqs, ys = zip(*batch)
    lens = [s.shape[0] for s in seqs]
    # Take the feature width from the batch itself rather than a global
    # constant, so the collate works for any per-utterance feature set.
    n_feats = seqs[0].shape[1]
    padded = np.zeros((len(seqs), max(lens), n_feats), dtype=np.float32)
    for i, s in enumerate(seqs):
        padded[i, :s.shape[0]] = s
    return (torch.tensor(padded),
            torch.tensor(lens, dtype=torch.long),
            torch.tensor(np.asarray(ys), dtype=torch.float32))
class UtteranceLSTM(nn.Module):
    """Single-layer bidirectional LSTM that maps an utterance sequence to one logit.

    The final forward and backward hidden states are concatenated, passed
    through dropout, and reduced to one logit by a linear head. There is no
    sigmoid, so pair it with ``BCEWithLogitsLoss``.
    """

    def __init__(self, in_dim=UTT_FEATS, hidden=16, dropout=0.25):
        super().__init__()
        self.lstm = nn.LSTM(in_dim, hidden, batch_first=True, bidirectional=True)
        self.drop = nn.Dropout(dropout)
        self.head = nn.Linear(hidden * 2, 1)

    def forward(self, x, lens):
        """Score a padded ``(B, T, in_dim)`` batch given each row's true length."""
        # Packing makes the LSTM ignore the zero padding; the lengths tensor
        # must live on the CPU.
        packed = nn.utils.rnn.pack_padded_sequence(
            x, lens.cpu(), batch_first=True, enforce_sorted=False
        )
        _, (h_n, _c_n) = self.lstm(packed)
        # One layer, two directions: h_n is (2, B, hidden).
        fwd_last, bwd_last = h_n
        summary = torch.cat([fwd_last, bwd_last], dim=-1)  # (B, 2 * hidden)
        return self.head(self.drop(summary)).squeeze(-1)
def _lstm_build_sequences(df: pd.DataFrame):
    """Load one utterance-feature sequence and 0/1 label (1 = ASD) per row.

    Rows whose transcript cannot be resolved or parsed are reported and
    skipped, so the result may be shorter than ``df``.
    """
    seqs, labels = [], []
    for _, row in df.iterrows():
        try:
            seq = _file_to_sequence(_resolve_cha_path(row))
            label = int(row["group"] == "ASD")
        except Exception as e:  # noqa: BLE001 -- best effort: skip bad files
            print(f" skip {row['participant_id']}: {e}")
            continue
        # Append only once both succeeded so seqs and labels stay aligned.
        seqs.append(seq)
        labels.append(label)
    return seqs, np.array(labels)


def _lstm_normalize_fold(seqs, tr_idx, va_idx):
    """Z-score utterance features using statistics from the training fold only."""
    train_concat = np.concatenate([seqs[i] for i in tr_idx], axis=0)
    mu = train_concat.mean(0, keepdims=True)
    sd = train_concat.std(0, keepdims=True) + 1e-6  # avoid division by zero
    return ([(seqs[i] - mu) / sd for i in tr_idx],
            [(seqs[i] - mu) / sd for i in va_idx])


def _lstm_val_loss(model, loader, loss_fn) -> float:
    """Mean per-sample loss over ``loader``, weighting each batch by its size."""
    model.eval()
    total, count = 0.0, 0
    with torch.no_grad():
        for x, lens_b, y in loader:
            x, y = x.to(DEVICE), y.to(DEVICE)
            total += loss_fn(model(x, lens_b), y).item() * len(y)
            count += len(y)
    return total / count


def _lstm_predict_proba(model, loader) -> np.ndarray:
    """Positive-class probabilities for every sample in ``loader``, in loader order."""
    model.eval()
    chunks = []
    with torch.no_grad():
        for x, lens_b, _ in loader:
            chunks.append(torch.sigmoid(model(x.to(DEVICE), lens_b)).cpu().numpy())
    return np.concatenate(chunks)


def _lstm_fit_fold(tr_seqs, y_tr, va_loader, epochs: int = 14, patience: int = 4):
    """Train one ``UtteranceLSTM`` with early stopping on ``va_loader``.

    Returns the model with the weights of the epoch that had the lowest
    validation loss (improvement threshold 1e-4).
    """
    model = UtteranceLSTM().to(DEVICE)
    opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
    # Up-weight positives by this fold's neg/pos ratio (class imbalance).
    pos = float((y_tr == 1).sum())
    neg = float((y_tr == 0).sum())
    pw = torch.tensor([neg / max(pos, 1.0)], device=DEVICE)
    loss_fn = nn.BCEWithLogitsLoss(pos_weight=pw)
    tr_loader = DataLoader(SeqDataset(tr_seqs, y_tr),
                           batch_size=16, shuffle=True, collate_fn=_collate)

    best_state, best_va, bad = None, float("inf"), 0
    for _ in range(epochs):
        model.train()
        for x, lens_b, y in tr_loader:
            # lens_b stays on the CPU: pack_padded_sequence needs it there.
            x, y = x.to(DEVICE), y.to(DEVICE)
            opt.zero_grad()
            loss = loss_fn(model(x, lens_b), y)
            loss.backward()
            opt.step()
        va_mean = _lstm_val_loss(model, va_loader, loss_fn)
        if va_mean < best_va - 1e-4:
            best_va = va_mean
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            bad = 0
        else:
            bad += 1
            if bad >= patience:
                break
    if best_state is not None:
        model.load_state_dict(best_state)
    return model


def run_lstm(df: pd.DataFrame):
    """Cross-validate ``UtteranceLSTM`` on per-utterance CHI sequences.

    Uses stratified 5-fold CV with ``RANDOM_STATE``. Feature scaling is fit
    on each training fold's utterances only.

    Returns
    -------
    (labels, pred, prob)
        The 0/1 labels of the transcripts that loaded, with out-of-fold hard
        predictions (threshold 0.5) and probabilities aligned to them. If
        any transcript was skipped, these are shorter than ``df``.

    Raises
    ------
    RuntimeError
        If no transcript could be loaded at all.

    NOTE(review): early stopping uses the same fold that is scored, so the
    reported metrics are somewhat optimistic.
    """
    print("\n[LSTM] building per-utterance sequences...")
    seqs, labels = _lstm_build_sequences(df)
    if not seqs:
        raise RuntimeError("[LSTM] no transcripts could be loaded")
    lens = [s.shape[0] for s in seqs]
    print(f"[LSTM] {len(seqs)} sequences, "
          f"len: min={min(lens)} max={max(lens)} median={int(np.median(lens))}")

    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=RANDOM_STATE)
    all_pred = np.zeros(len(labels))
    all_prob = np.zeros(len(labels))
    for fold, (tr_idx, va_idx) in enumerate(skf.split(np.zeros(len(labels)), labels), 1):
        tr_seqs, va_seqs = _lstm_normalize_fold(seqs, tr_idx, va_idx)
        y_tr, y_va = labels[tr_idx], labels[va_idx]
        # shuffle=False keeps validation order, so predictions line up with va_idx.
        va_loader = DataLoader(SeqDataset(va_seqs, y_va),
                               batch_size=32, shuffle=False, collate_fn=_collate)
        model = _lstm_fit_fold(tr_seqs, y_tr, va_loader)
        probs = _lstm_predict_proba(model, va_loader)
        all_prob[va_idx] = probs
        all_pred[va_idx] = (probs >= 0.5).astype(int)
        print(f" fold {fold} val-auc={roc_auc_score(y_va, probs):.3f}")
    return labels, all_pred, all_prob
# =========================================================================
# Runner
# =========================================================================
def _plot_training_curves(fold_curves):
    """Overlay every fold's TabularMLP loss curves and save them as a PNG.

    ``fold_curves`` is a list of ``(train_losses, val_losses)`` pairs, one
    per CV fold. All folds share one colour per split, so the legend has one
    entry per split rather than naming a single fold. The figure is written
    to ``FIG_DIR / "deep_learning_training_curves.png"``.
    """
    fig, ax = plt.subplots(figsize=(9, 5))
    for i, (tr, va) in enumerate(fold_curves):
        first = i == 0
        ax.plot(tr, alpha=0.4, color="tab:blue",
                label="train (all folds)" if first else None)
        ax.plot(va, alpha=0.7, color="tab:orange",
                label="val (all folds)" if first else None)
    ax.set_title(f"TabularMLP training curves ({len(fold_curves)} folds)")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("BCE loss")
    if fold_curves:
        # An empty legend would only produce a "no artists" warning.
        ax.legend()
    fig.tight_layout()
    out = FIG_DIR / "deep_learning_training_curves.png"
    fig.savefig(out, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f" saved {out.relative_to(PROJECT_ROOT)}")
def _plot_roc(y, results: dict):
    """Draw one ROC curve per model, labelled with its AUC, and save the figure.

    ``results`` maps a model name to positive-class scores aligned with
    ``y``. The plot goes to ``FIG_DIR / "deep_learning_roc.png"``.
    """
    fig, ax = plt.subplots(figsize=(7, 6))
    for model_name in results:
        scores = results[model_name]
        fpr, tpr, _thresholds = roc_curve(y, scores)
        model_auc = roc_auc_score(y, scores)
        ax.plot(fpr, tpr, linewidth=2.2,
                label=f"{model_name} (AUC={model_auc:.3f})")
    ax.plot([0, 1], [0, 1], "k--", alpha=0.4)  # chance-level diagonal
    ax.set_xlabel("False positive rate")
    ax.set_ylabel("True positive rate")
    ax.set_title("Deep learning: ASD vs non-ASD")
    ax.legend()
    fig.tight_layout()

    destination = FIG_DIR / "deep_learning_roc.png"
    fig.savefig(destination, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f" saved {destination.relative_to(PROJECT_ROOT)}")
def main() -> None:
    """Run both deep-learning models end to end and write their reports.

    Steps:

    1. Read ``DATA_DIR / "combined_features.csv"``, dropping rows without a
       ``group``.
    2. Cross-validate the TabularMLP and the UtteranceLSTM.
    3. Print a classification report per model and save the ROC and
       training-curve figures.
    4. Write accuracy, macro-F1 and ROC-AUC to
       ``METRIC_DIR / "deep_learning_results.csv"``.

    Raises
    ------
    RuntimeError
        If the two models' label vectors differ (e.g. the LSTM skipped a
        transcript it could not load). Their out-of-fold predictions would
        then not refer to the same rows.
    """
    df = pd.read_csv(DATA_DIR / "combined_features.csv").dropna(subset=["group"])
    print(f"Loaded {len(df)} rows.")

    # -- TabularMLP --
    print("\n" + "=" * 70)
    print("Model 1: Tabular MLP (hand-engineered features)")
    print("=" * 70)
    y, mlp_pred, mlp_prob, curves = run_tabular_mlp(df)
    _plot_training_curves(curves)

    # -- UtteranceLSTM --
    print("\n" + "=" * 70)
    print("Model 2: Utterance-sequence Bi-LSTM")
    print("=" * 70)
    y2, lstm_pred, lstm_prob = run_lstm(df)
    # Explicit check rather than ``assert`` so it still runs under python -O.
    if not np.array_equal(y, y2):
        raise RuntimeError(
            f"label order mismatch: MLP has {len(y)} labels, LSTM has {len(y2)}"
        )

    rows = []
    for name, pred, prob in [
        ("TabularMLP", mlp_pred, mlp_prob),
        ("UtteranceLSTM", lstm_pred, lstm_prob),
    ]:
        acc = accuracy_score(y, pred)
        f1 = f1_score(y, pred, average="macro")
        auc = roc_auc_score(y, prob)
        rows.append({"model": name,
                     "accuracy": round(acc, 4),
                     "f1_macro": round(f1, 4),
                     "roc_auc": round(auc, 4)})
        print(f"\n[{name}]")
        print(classification_report(y, pred, target_names=["non-ASD", "ASD"],
                                    digits=3))

    _plot_roc(y, {"TabularMLP": mlp_prob, "UtteranceLSTM": lstm_prob})

    results = pd.DataFrame(rows)
    out = METRIC_DIR / "deep_learning_results.csv"
    results.to_csv(out, index=False)
    print(f"\n[saved] {out.relative_to(PROJECT_ROOT)}")
    print("\n=== SUMMARY ===")
    print(results.to_string(index=False))


if __name__ == "__main__":
    main()