# asd-screening-tool / src / classifier.py
# Paoo
# deploy: publish sanitized v0.17 streamlit app
# 8514865
"""
Baseline classifiers for screening: ASD vs TD vs DD.
Two tasks are run:
(A) Binary: ASD vs non-ASD (TD + DD) -> screening use-case
(B) Multi-class: ASD vs DD vs TD -> differential
Models:
- Logistic Regression
- Random Forest
- Support Vector Machine (RBF)
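Evaluation: stratified 5-fold cross-validation, plus leave-one-corpus-out
validation of the LogReg screening model.
Outputs:
reports/metrics/classification_results.csv
reports/metrics/binary_oof_predictions.csv, threshold_metrics.csv,
calibration_bins.csv, decision_curve.csv, subgroup_performance.csv,
leave_one_corpus_out.csv (LogReg inputs for the Model Trust Dashboard)
reports/figures/confusion_matrix_<task>_<model>.png
reports/figures/feature_importance.png
reports/figures/roc_curve_binary.png
artifacts/screening_model.joblib, feature_schema.json, model_card.json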
"""
from __future__ import annotations
import hashlib
import json
from datetime import date
from pathlib import Path
import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.ensemble import RandomForestClassifier
from sklearn.impute import SimpleImputer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
average_precision_score,
brier_score_loss,
ConfusionMatrixDisplay,
RocCurveDisplay,
accuracy_score,
classification_report,
confusion_matrix,
f1_score,
precision_score,
recall_score,
roc_auc_score,
)
from sklearn.model_selection import LeaveOneGroupOut, StratifiedKFold, cross_val_predict
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
try:
from src.feature_schema import (
FEATURES,
UNCERTAIN_HIGH,
UNCERTAIN_LOW,
feature_schema_rows,
)
except ModuleNotFoundError: # running as `python src/classifier.py`
from feature_schema import (
FEATURES,
UNCERTAIN_HIGH,
UNCERTAIN_LOW,
feature_schema_rows,
)
# --- Filesystem layout ------------------------------------------------------
# The project root is two directories above this file (<root>/src/classifier.py).
PROJECT_ROOT = Path(__file__).resolve().parent.parent
DATA_DIR = PROJECT_ROOT / "data"
FIG_DIR = PROJECT_ROOT / "reports" / "figures"
METRIC_DIR = PROJECT_ROOT / "reports" / "metrics"
ARTIFACT_DIR = PROJECT_ROOT / "artifacts"

# Output folders are created at import time so every writer below can assume
# they already exist.
for _output_dir in (FIG_DIR, METRIC_DIR, ARTIFACT_DIR):
    _output_dir.mkdir(parents=True, exist_ok=True)
del _output_dir

sns.set_theme(style="whitegrid", context="talk")

# Shared seed for the CV splitter and every estimator.
RANDOM_STATE = 42
# Version tag written into the model bundle and the model card.
MODEL_VERSION = "v0.17.0-trust-dashboard"
def _build_models():
    """Return fresh, unfitted pipelines keyed by model display name.

    Every pipeline median-imputes missing values, and every classifier uses
    balanced class weights and the shared RANDOM_STATE. LogReg and SVM also
    standardise features; the random forest is left unscaled.
    """
    shared = {"class_weight": "balanced", "random_state": RANDOM_STATE}

    def median_imputer():
        return ("imp", SimpleImputer(strategy="median"))

    def scaler():
        return ("sc", StandardScaler())

    return {
        "LogReg": Pipeline([
            median_imputer(),
            scaler(),
            ("clf", LogisticRegression(max_iter=2000, **shared)),
        ]),
        "RandomForest": Pipeline([
            median_imputer(),
            ("clf", RandomForestClassifier(n_estimators=300, **shared)),
        ]),
        "SVM": Pipeline([
            median_imputer(),
            scaler(),
            ("clf", SVC(kernel="rbf", probability=True, **shared)),
        ]),
    }
def _safe_div(num: float, den: float) -> float:
return float(num / den) if den else 0.0
def _binary_metric_row(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    y_proba: np.ndarray,
    *,
    threshold: float,
) -> dict:
    """Compute screening metrics for the binary task (1 = ASD, 0 = non-ASD).

    Args:
        y_true: ground-truth labels in {0, 1}.
        y_pred: hard predictions in {0, 1}; callers normally pass
            ``y_proba >= threshold``.
        y_proba: predicted probability of class 1.
        threshold: cut-off recorded in the row; it is not applied here.

    Returns:
        Flat dict with threshold-dependent metrics (accuracy, macro-F1,
        sensitivity, specificity, PPV, NPV, confusion counts), probability
        metrics (ROC AUC, PR AUC, Brier score), and the count/rate of
        probabilities inside the [UNCERTAIN_LOW, UNCERTAIN_HIGH) band.

    Raises:
        ValueError: propagated from ``roc_auc_score`` when ``y_true`` holds a
            single class (callers skip such subsets before calling).
    """
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    uncertain = (y_proba >= UNCERTAIN_LOW) & (y_proba < UNCERTAIN_HIGH)
    return {
        "accuracy": accuracy_score(y_true, y_pred),
        # zero_division=0 matches the precision/recall calls below. At extreme
        # thresholds one class is never predicted, which would otherwise emit
        # an UndefinedMetricWarning for every row of the threshold sweep.
        "f1_macro": f1_score(y_true, y_pred, average="macro", zero_division=0),
        "roc_auc": roc_auc_score(y_true, y_proba),
        "pr_auc": average_precision_score(y_true, y_proba),
        "sensitivity": recall_score(y_true, y_pred, pos_label=1, zero_division=0),
        "specificity": _safe_div(tn, tn + fp),
        "ppv": precision_score(y_true, y_pred, pos_label=1, zero_division=0),
        "npv": _safe_div(tn, tn + fn),
        "brier_score": brier_score_loss(y_true, y_proba),
        "threshold": threshold,
        "tp": int(tp),
        "fp": int(fp),
        "tn": int(tn),
        "fn": int(fn),
        "uncertain_count": int(uncertain.sum()),
        "uncertain_rate": float(uncertain.mean()),
    }
def _round_metric_row(row: dict) -> dict:
rounded = {}
for key, value in row.items():
if isinstance(value, (float, np.floating)):
rounded[key] = round(float(value), 4)
else:
rounded[key] = value
return rounded
def _save_confusion_matrix(y_true, y_pred, *, task, name, class_order,
                           display_labels, acc, f1_macro):
    """Plot one confusion matrix and save it as FIG_DIR/confusion_matrix_<task>_<name>.png."""
    cm = confusion_matrix(y_true, y_pred, labels=class_order)
    fig, ax = plt.subplots(figsize=(6, 5))
    try:
        ConfusionMatrixDisplay(cm, display_labels=display_labels).plot(
            ax=ax, cmap="Blues", values_format="d", colorbar=False,
        )
        ax.set_title(f"{task} | {name}\nacc={acc:.3f} f1={f1_macro:.3f}")
        fig.tight_layout()
        out = FIG_DIR / f"confusion_matrix_{task}_{name}.png"
        fig.savefig(out, dpi=150, bbox_inches="tight")
    finally:
        # Always release the figure, even if plotting/saving fails.
        plt.close(fig)
    print(f" saved {out.relative_to(PROJECT_ROOT)}")


def _cv_evaluate(X, y, models, task: str, class_order, display_labels):
    """Run stratified 5-fold CV for each model and report out-of-fold results.

    class_order: labels as they appear in y (e.g. [0, 1] or ['ASD', 'DD', 'TD'])
    display_labels: human-readable names in the same order.

    For ``task == "binary"``, y must be 0/1. Out-of-fold probabilities of
    class 1 are collected, hard labels are ``proba >= 0.5``, and the full
    screening metric set is recorded. For any other task, only accuracy and
    macro-F1 are recorded. For every model a confusion-matrix PNG is saved and
    a classification report is printed.

    Returns:
        (rows, preds, probs): one rounded metric dict per model; out-of-fold
        hard predictions per model; out-of-fold class-1 probabilities per
        model (binary task only, otherwise empty).
    """
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=RANDOM_STATE)
    rows = []
    preds = {}
    probs = {}
    for name, pipe in models.items():
        row = {"task": task, "model": name}
        if task == "binary":
            # A single CV pass on probabilities. Hard labels are derived from
            # them so the recorded "threshold": 0.5 is what was actually
            # applied. SVC(probability=True).predict() is not guaranteed to
            # agree with predict_proba() >= 0.5, and a second
            # cross_val_predict call would refit every fold again.
            y_proba = cross_val_predict(
                pipe, X, y, cv=skf, method="predict_proba", n_jobs=-1
            )[:, 1]
            y_pred = (y_proba >= 0.5).astype(int)
            probs[name] = y_proba
            row.update(_binary_metric_row(
                np.asarray(y), y_pred, y_proba, threshold=0.5,
            ))
        else:
            y_pred = cross_val_predict(pipe, X, y, cv=skf, n_jobs=-1)
            row.update({
                "accuracy": accuracy_score(y, y_pred),
                "f1_macro": f1_score(y, y_pred, average="macro"),
            })
        preds[name] = y_pred
        rows.append(_round_metric_row(row))
        # row still holds the unrounded values used in the figure title.
        _save_confusion_matrix(
            y, y_pred,
            task=task, name=name,
            class_order=class_order, display_labels=display_labels,
            acc=row["accuracy"], f1_macro=row["f1_macro"],
        )
        print(f"\n[{task} / {name}]")
        print(classification_report(y, y_pred, labels=class_order,
                                    target_names=display_labels, digits=3))
    return rows, preds, probs
def _plot_feature_importance(X, y):
    """Fit a 500-tree random forest on all of (X, y) and save a ranked bar chart.

    X columns are assumed to follow FEATURES order. The chart is written to
    FIG_DIR/feature_importance.png, with the most important feature on top.
    """
    forest_pipe = Pipeline([
        ("imp", SimpleImputer(strategy="median")),
        ("clf", RandomForestClassifier(
            n_estimators=500,
            class_weight="balanced",
            random_state=RANDOM_STATE,
        )),
    ])
    forest_pipe.fit(X, y)
    importances = forest_pipe.named_steps["clf"].feature_importances_
    ranking = np.argsort(importances)[::-1]  # descending importance

    fig, ax = plt.subplots(figsize=(9, 6))
    sns.barplot(
        x=importances[ranking],
        y=np.array(FEATURES)[ranking],
        ax=ax,
        color="#4C72B0",
    )
    ax.set_title("Random Forest feature importance (multi-class)")
    ax.set_xlabel("Importance")
    fig.tight_layout()

    target = FIG_DIR / "feature_importance.png"
    fig.savefig(target, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f" saved {target.relative_to(PROJECT_ROOT)}")
def _plot_roc_curves(X, y, probs):
    """Overlay one ROC curve per model and save FIG_DIR/roc_curve_binary.png.

    ``probs`` maps model name to out-of-fold class-1 scores aligned with ``y``.
    ``X`` is accepted for interface compatibility but not used.
    """
    target = FIG_DIR / "roc_curve_binary.png"
    fig, ax = plt.subplots(figsize=(7, 6))
    for model_name, scores in probs.items():
        RocCurveDisplay.from_predictions(y, scores, name=model_name, ax=ax)
    # Dashed chance diagonal for reference.
    ax.plot([0, 1], [0, 1], "k--", alpha=0.4)
    ax.set_title("ROC curves - ASD vs non-ASD (5-fold CV)")
    fig.tight_layout()
    fig.savefig(target, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f" saved {target.relative_to(PROJECT_ROOT)}")
def _data_hash(df: pd.DataFrame) -> str:
payload = df.sort_values(["corpus", "participant_id"]).to_csv(index=False)
return hashlib.sha256(payload.encode("utf-8")).hexdigest()[:16]
def _threshold_table(y_true: np.ndarray, y_proba: np.ndarray) -> pd.DataFrame:
    """Sweep decision thresholds 0.05, 0.10, ..., 0.95 and tabulate metrics.

    Each row is ``_binary_metric_row`` evaluated on ``y_proba >= threshold``,
    with float values rounded to 4 decimals.
    """
    grid = np.round(np.arange(0.05, 0.96, 0.05), 2)
    return pd.DataFrame([
        _round_metric_row(
            _binary_metric_row(
                y_true,
                (y_proba >= cut).astype(int),
                y_proba,
                threshold=float(cut),
            )
        )
        for cut in grid
    ])
def _calibration_bins(y_true: np.ndarray, y_proba: np.ndarray, n_bins: int = 10) -> pd.DataFrame:
bins = np.linspace(0, 1, n_bins + 1)
labels = [f"{bins[i]:.1f}-{bins[i + 1]:.1f}" for i in range(n_bins)]
df = pd.DataFrame({"y_true": y_true, "prob_asd": y_proba})
df["bin"] = pd.cut(df["prob_asd"], bins=bins, labels=labels,
include_lowest=True, right=True)
rows = []
for label, group in df.groupby("bin", observed=False):
if group.empty:
continue
rows.append({
"bin": str(label),
"n": int(len(group)),
"predicted_mean": round(float(group["prob_asd"].mean()), 4),
"observed_rate": round(float(group["y_true"].mean()), 4),
})
return pd.DataFrame(rows)
def _decision_curve(y_true: np.ndarray, y_proba: np.ndarray) -> pd.DataFrame:
n = len(y_true)
prevalence = float(np.mean(y_true))
rows = []
for threshold in np.round(np.arange(0.05, 0.96, 0.05), 2):
y_pred = y_proba >= threshold
tp = int(((y_true == 1) & y_pred).sum())
fp = int(((y_true == 0) & y_pred).sum())
odds = threshold / (1 - threshold)
rows.append({
"threshold": float(threshold),
"model_net_benefit": round(tp / n - fp / n * odds, 4),
"treat_all_net_benefit": round(prevalence - (1 - prevalence) * odds, 4),
"treat_none_net_benefit": 0.0,
})
return pd.DataFrame(rows)
def _subgroup_performance(df: pd.DataFrame, y_true: np.ndarray, y_proba: np.ndarray) -> pd.DataFrame:
    """Binary metrics at threshold 0.5 per corpus, sex and age band.

    ``y_true`` and ``y_proba`` are aligned positionally with the rows of
    ``df``, which must contain ``corpus``, ``sex`` and ``age_months``.
    Subgroups with fewer than 5 rows, or with only one outcome class, are
    skipped (ROC AUC is undefined for a single class). Missing values form
    their own subgroup.

    Returns:
        DataFrame with ``dimension``, ``value``, ``n`` and the
        ``_binary_metric_row`` columns. Floats are rounded to 4 decimals.
    """
    eval_df = df.copy()
    eval_df["y_true"] = y_true
    eval_df["prob_asd"] = y_proba
    eval_df["pred"] = (y_proba >= 0.5).astype(int)
    # Half-open bands [0, 36), [36, 48), [48, 60), [60, 72), [72, 200) so each
    # label matches its range. With the default right=True, a 36-month-old
    # landed in "<36", a 48-month-old in "36-47", and so on. Ages outside
    # [0, 200) or missing become the string "nan".
    eval_df["age_band"] = pd.cut(
        eval_df["age_months"],
        bins=[0, 36, 48, 60, 72, 200],
        labels=["<36", "36-47", "48-59", "60-71", "72+"],
        right=False,
    ).astype(str)
    rows = []
    for dimension in ["corpus", "sex", "age_band"]:
        for value, sub in eval_df.groupby(dimension, dropna=False):
            if len(sub) < 5 or sub["y_true"].nunique() < 2:
                continue
            metrics = _binary_metric_row(
                sub["y_true"].to_numpy(),
                sub["pred"].to_numpy(),
                sub["prob_asd"].to_numpy(),
                threshold=0.5,
            )
            rows.append(_round_metric_row({
                "dimension": dimension,
                "value": str(value),
                "n": len(sub),
                **metrics,
            }))
    return pd.DataFrame(rows)
def _leave_one_corpus_out(df: pd.DataFrame) -> pd.DataFrame:
    """Leave-one-corpus-out validation of the LogReg screening pipeline.

    Each corpus is held out in turn. A fresh LogReg pipeline is fitted on the
    remaining corpora and scored on the held-out one at threshold 0.5.
    Held-out corpora that contain a single outcome class get status
    "skipped_single_class".

    With fewer than two distinct corpora there is nothing to hold out.
    Instead of letting LeaveOneGroupOut raise (which aborted main() before
    the model bundle and model card were written), one
    "skipped_single_corpus" row is returned per corpus present.

    Returns:
        DataFrame with ``held_out_corpus``, ``n_test``, ``status`` and, for
        evaluated corpora, the ``_binary_metric_row`` columns (NaN elsewhere).
    """
    X = df[FEATURES].values
    y = (df["group"] == "ASD").astype(int).values
    # Pool missing corpus labels into one "unknown" group and normalise to
    # str. NaN mixed with string labels would otherwise make the group
    # handling inside LeaveOneGroupOut fail.
    groups = df["corpus"].fillna("unknown").astype(str).to_numpy()
    corpora = pd.unique(groups)
    if len(corpora) < 2:
        return pd.DataFrame(
            [
                {
                    "held_out_corpus": str(corpus),
                    "n_test": int((groups == corpus).sum()),
                    "status": "skipped_single_corpus",
                }
                for corpus in corpora
            ],
            columns=["held_out_corpus", "n_test", "status"],
        )
    rows = []
    for train_idx, test_idx in LeaveOneGroupOut().split(X, y, groups):
        test_corpus = str(groups[test_idx][0])
        if len(np.unique(y[test_idx])) < 2:
            rows.append({
                "held_out_corpus": test_corpus,
                "n_test": int(len(test_idx)),
                "status": "skipped_single_class",
            })
            continue
        pipe = _build_models()["LogReg"]
        pipe.fit(X[train_idx], y[train_idx])
        proba = pipe.predict_proba(X[test_idx])[:, 1]
        pred = (proba >= 0.5).astype(int)
        rows.append(_round_metric_row({
            "held_out_corpus": test_corpus,
            "n_test": int(len(test_idx)),
            "status": "evaluated",
            **_binary_metric_row(y[test_idx], pred, proba, threshold=0.5),
        }))
    return pd.DataFrame(rows)
def _write_model_bundle(df: pd.DataFrame) -> dict:
    """Fit the LogReg screening pipeline on every row and persist it.

    The joblib file ARTIFACT_DIR/screening_model.joblib stores the fitted
    pipeline together with the feature order, decision thresholds and
    training metadata (row/class counts, corpora, data hash). The same dict
    is returned so the caller can reuse it, e.g. for the model card.
    """
    features = df[FEATURES].values
    is_asd = (df["group"] == "ASD").astype(int).values

    model = _build_models()["LogReg"]
    model.fit(features, is_asd)

    thresholds = {
        "uncertain_low": UNCERTAIN_LOW,
        "uncertain_high": UNCERTAIN_HIGH,
        "default_binary": 0.5,
    }
    training_metadata = {
        "trained_on": date.today().isoformat(),
        "n_rows": int(len(df)),
        "n_asd": int(is_asd.sum()),
        "n_non_asd": int((1 - is_asd).sum()),
        "corpora": sorted(df["corpus"].dropna().unique().tolist()),
        "data_hash": _data_hash(df),
    }
    bundle = {
        "model": model,
        "model_version": MODEL_VERSION,
        "features": FEATURES,
        "thresholds": thresholds,
        "training_metadata": training_metadata,
    }

    target = ARTIFACT_DIR / "screening_model.joblib"
    joblib.dump(bundle, target)
    print(f" saved {target.relative_to(PROJECT_ROOT)}")
    return bundle
def _write_json(path: Path, payload: dict) -> None:
    """Write ``payload`` to ``path`` as 2-space-indented UTF-8 JSON and log it.

    Non-ASCII characters are written verbatim (``ensure_ascii=False``).
    """
    text = json.dumps(payload, ensure_ascii=False, indent=2)
    path.write_text(text, encoding="utf-8")
    print(f" saved {path.relative_to(PROJECT_ROOT)}")
def _print_banner(title: str) -> None:
    """Print ``title`` between two 70-character '=' rules, after a blank line."""
    print("\n" + "=" * 70)
    print(title)
    print("=" * 70)


def _write_trust_dashboard_inputs(
    df: pd.DataFrame,
    y_bin: np.ndarray,
    prob_asd: np.ndarray,
) -> None:
    """Write the LogReg out-of-fold CSVs consumed by project_dashboard/.

    Args:
        df: labelled rows, aligned positionally with ``y_bin`` and ``prob_asd``.
        y_bin: binary targets (1 = ASD, 0 = non-ASD).
        prob_asd: out-of-fold LogReg probability of ASD.
    """
    pred_df = df[[
        "participant_id", "corpus", "group", "sex", "age_months",
    ]].copy()
    pred_df["y_true"] = y_bin
    pred_df["prob_asd"] = np.round(prob_asd, 6)
    # Derive the 0.5 label from the stored probability so pred_050 always
    # agrees with prob_asd and with the threshold/subgroup tables, which also
    # use prob >= threshold.
    pred_df["pred_050"] = (prob_asd >= 0.5).astype(int)
    pred_df["uncertainty_zone"] = np.select(
        [
            prob_asd < UNCERTAIN_LOW,
            (prob_asd >= UNCERTAIN_LOW) & (prob_asd < UNCERTAIN_HIGH),
            prob_asd >= UNCERTAIN_HIGH,
        ],
        ["low", "uncertain", "high"],
        default="unknown",
    )
    # Builders run lazily, so each CSV is saved before the next is computed.
    writers = (
        ("binary_oof_predictions.csv", lambda: pred_df),
        ("threshold_metrics.csv", lambda: _threshold_table(y_bin, prob_asd)),
        ("calibration_bins.csv", lambda: _calibration_bins(y_bin, prob_asd)),
        ("decision_curve.csv", lambda: _decision_curve(y_bin, prob_asd)),
        ("subgroup_performance.csv",
         lambda: _subgroup_performance(df, y_bin, prob_asd)),
        ("leave_one_corpus_out.csv", lambda: _leave_one_corpus_out(df)),
    )
    for filename, build in writers:
        out = METRIC_DIR / filename
        build().to_csv(out, index=False)
        print(f"[saved] {out.relative_to(PROJECT_ROOT)}")


def _write_model_artifacts(df: pd.DataFrame) -> None:
    """Write feature_schema.json, the fitted model bundle and model_card.json."""
    _write_json(ARTIFACT_DIR / "feature_schema.json", {
        "features": FEATURES,
        "feature_docs": feature_schema_rows(),
        "thresholds": {
            "uncertain_low": UNCERTAIN_LOW,
            "uncertain_high": UNCERTAIN_HIGH,
        },
    })
    bundle = _write_model_bundle(df)
    _write_json(ARTIFACT_DIR / "model_card.json", {
        "model_version": MODEL_VERSION,
        "model_type": "Logistic Regression with median imputation and standard scaling",
        "intended_use": "ASD screening support and research demo; not diagnostic.",
        "not_intended_use": "Autonomous diagnosis, emergency triage, or replacement for clinician assessment.",
        "inputs": FEATURES,
        "training_metadata": bundle["training_metadata"],
        "thresholds": bundle["thresholds"],
        "reporting_guidance": [
            "TRIPOD+AI prediction model reporting",
            "DECIDE-AI early clinical decision-support evaluation",
            "Model card and dataset-card transparency",
        ],
        "clinical_caveats": [
            "TalkBank/ASDBank cohorts are not a Thai external validation set.",
            "Audio-derived predictions require transcript QA and feature-drift checks.",
            "Probability estimates require calibration review before clinical use.",
        ],
    })


def main() -> None:
    """Run both classification tasks and write every report and artifact.

    Reads DATA_DIR/combined_features.csv and drops rows without a ``group``
    label. Evaluates the binary (ASD vs non-ASD) and multi-class (ASD/DD/TD)
    tasks with cross-validation and saves figures and metrics. Then exports
    the LogReg dashboard inputs, the feature schema, the fitted model bundle
    and the model card.
    """
    csv_path = DATA_DIR / "combined_features.csv"
    df = pd.read_csv(csv_path).dropna(subset=["group"])
    print(f"Loaded {len(df)} rows. Group counts:")
    print(df["group"].value_counts().to_string())
    X = df[FEATURES].values
    all_rows = []

    # ---------------- Binary: ASD vs non-ASD ----------------
    _print_banner("TASK A: Binary ASD (1) vs non-ASD (0)")
    y_bin = (df["group"] == "ASD").astype(int).values
    rows, _, probs = _cv_evaluate(
        X, y_bin, _build_models(),
        task="binary",
        class_order=[0, 1],
        display_labels=["non-ASD", "ASD"],
    )
    all_rows.extend(rows)
    if probs:
        _plot_roc_curves(X, y_bin, probs)

    # ---------------- Multi-class: ASD / DD / TD ----------------
    _print_banner("TASK B: Multi-class ASD vs DD vs TD")
    multi_df = df[df["group"].isin(["ASD", "DD", "TD"])]
    X_m = multi_df[FEATURES].values
    y_m = multi_df["group"].astype(str).to_numpy()
    rows, _, _ = _cv_evaluate(
        X_m, y_m, _build_models(),
        task="multiclass",
        class_order=["ASD", "DD", "TD"],
        display_labels=["ASD", "DD", "TD"],
    )
    all_rows.extend(rows)
    _plot_feature_importance(X_m, y_m)

    # Save results
    results_df = pd.DataFrame(all_rows)
    out = METRIC_DIR / "classification_results.csv"
    results_df.to_csv(out, index=False)
    print(f"\n[saved] {out.relative_to(PROJECT_ROOT)}")

    # Model Trust Dashboard inputs focus on LogReg, the selected interpretable
    # screening model. These CSVs are static assets for project_dashboard/.
    logreg_prob = probs.get("LogReg")
    if logreg_prob is not None:
        _write_trust_dashboard_inputs(df, y_bin, logreg_prob)
    _write_model_artifacts(df)

    print("\n=== SUMMARY ===")
    print(results_df.to_string(index=False))


if __name__ == "__main__":
    main()