Spaces:
Sleeping
Sleeping
| """ | |
| Baseline classifiers for screening: ASD vs TD vs DD. | |
| Two tasks are run: | |
| (A) Binary: ASD vs non-ASD (TD + DD) -> screening use-case | |
| (B) Multi-class: ASD vs DD vs TD -> differential | |
| Models: | |
| - Logistic Regression | |
| - Random Forest | |
| - Support Vector Machine (RBF) | |
| Evaluation: stratified 5-fold cross-validation. | |
| Outputs: | |
| reports/metrics/classification_results.csv | |
| reports/figures/confusion_matrix_<task>_<model>.png | |
| reports/figures/feature_importance.png | |
| reports/figures/roc_curve_binary.png | |
| """ | |
| from __future__ import annotations | |
| import hashlib | |
| import json | |
| from datetime import date | |
| from pathlib import Path | |
| import joblib | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import pandas as pd | |
| import seaborn as sns | |
| from sklearn.ensemble import RandomForestClassifier | |
| from sklearn.impute import SimpleImputer | |
| from sklearn.linear_model import LogisticRegression | |
| from sklearn.metrics import ( | |
| average_precision_score, | |
| brier_score_loss, | |
| ConfusionMatrixDisplay, | |
| RocCurveDisplay, | |
| accuracy_score, | |
| classification_report, | |
| confusion_matrix, | |
| f1_score, | |
| precision_score, | |
| recall_score, | |
| roc_auc_score, | |
| ) | |
| from sklearn.model_selection import LeaveOneGroupOut, StratifiedKFold, cross_val_predict | |
| from sklearn.pipeline import Pipeline | |
| from sklearn.preprocessing import StandardScaler | |
| from sklearn.svm import SVC | |
| try: | |
| from src.feature_schema import ( | |
| FEATURES, | |
| UNCERTAIN_HIGH, | |
| UNCERTAIN_LOW, | |
| feature_schema_rows, | |
| ) | |
| except ModuleNotFoundError: # running as `python src/classifier.py` | |
| from feature_schema import ( | |
| FEATURES, | |
| UNCERTAIN_HIGH, | |
| UNCERTAIN_LOW, | |
| feature_schema_rows, | |
| ) | |
# Project root: the parent of the directory that contains this file.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
# Input data and the three output locations, all relative to the project root.
DATA_DIR = PROJECT_ROOT / "data"
FIG_DIR = PROJECT_ROOT / "reports" / "figures"
METRIC_DIR = PROJECT_ROOT / "reports" / "metrics"
ARTIFACT_DIR = PROJECT_ROOT / "artifacts"
# NOTE: the output directories are created as an import-time side effect, so
# the writers below never have to check for their existence.
FIG_DIR.mkdir(parents=True, exist_ok=True)
METRIC_DIR.mkdir(parents=True, exist_ok=True)
ARTIFACT_DIR.mkdir(parents=True, exist_ok=True)
# Global plotting theme for every figure produced by this module.
sns.set_theme(style="whitegrid", context="talk")
# Shared seed for the CV splitter and every estimator built in this module.
RANDOM_STATE = 42
# Version tag written into the saved model bundle and the model card.
MODEL_VERSION = "v0.17.0-trust-dashboard"
def _build_models():
    """Return fresh, unfitted pipelines for the three baseline classifiers.

    The mapping is keyed "LogReg", "RandomForest" and "SVM", in that order.
    Every pipeline begins with median imputation; the logistic-regression and
    SVM pipelines also standardise features before the estimator. All three
    estimators use balanced class weights and the module-wide random seed.
    """
    shared = {"class_weight": "balanced", "random_state": RANDOM_STATE}

    def _imputer_step():
        return ("imp", SimpleImputer(strategy="median"))

    def _scaler_step():
        return ("sc", StandardScaler())

    logreg = LogisticRegression(max_iter=2000, **shared)
    forest = RandomForestClassifier(n_estimators=300, **shared)
    svm = SVC(kernel="rbf", probability=True, **shared)

    return {
        "LogReg": Pipeline([_imputer_step(), _scaler_step(), ("clf", logreg)]),
        "RandomForest": Pipeline([_imputer_step(), ("clf", forest)]),
        "SVM": Pipeline([_imputer_step(), _scaler_step(), ("clf", svm)]),
    }
def _safe_div(num: float, den: float) -> float:
    """Return ``num / den`` as a float, or 0.0 when the denominator is falsy."""
    if not den:
        return 0.0
    return float(num / den)
def _binary_metric_row(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    y_proba: np.ndarray,
    *,
    threshold: float,
) -> dict:
    """Compute the binary screening metrics for one set of predictions.

    Args:
        y_true: Ground-truth labels encoded as 0 / 1.
        y_pred: Hard predictions encoded as 0 / 1.
        y_proba: Predicted probability of the positive class (label 1).
        threshold: Decision threshold to record in the row. It is stored
            as-is; ``y_pred`` is NOT recomputed from it here.

    Returns:
        A dict with discrimination metrics (accuracy, macro F1, ROC AUC,
        PR AUC), screening metrics (sensitivity, specificity, PPV, NPV),
        the Brier score, the raw confusion-matrix counts, and the count and
        rate of probabilities inside [UNCERTAIN_LOW, UNCERTAIN_HIGH).
    """
    # Accept lists / Series as well as arrays: the element-wise comparisons
    # below require array semantics.
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    y_proba = np.asarray(y_proba)

    # labels=[0, 1] pins the 2x2 layout even if one class is absent.
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    uncertain = (y_proba >= UNCERTAIN_LOW) & (y_proba < UNCERTAIN_HIGH)
    return {
        "accuracy": accuracy_score(y_true, y_pred),
        # zero_division=0 matches the recall/precision calls below and keeps
        # threshold sweeps (all-negative / all-positive predictions) from
        # emitting UndefinedMetricWarning; the resulting value is unchanged.
        "f1_macro": f1_score(y_true, y_pred, average="macro", zero_division=0),
        "roc_auc": roc_auc_score(y_true, y_proba),
        "pr_auc": average_precision_score(y_true, y_proba),
        "sensitivity": recall_score(y_true, y_pred, pos_label=1, zero_division=0),
        "specificity": _safe_div(tn, tn + fp),
        "ppv": precision_score(y_true, y_pred, pos_label=1, zero_division=0),
        "npv": _safe_div(tn, tn + fn),
        "brier_score": brier_score_loss(y_true, y_proba),
        "threshold": threshold,
        "tp": int(tp),
        "fp": int(fp),
        "tn": int(tn),
        "fn": int(fn),
        "uncertain_count": int(uncertain.sum()),
        "uncertain_rate": float(uncertain.mean()),
    }
def _round_metric_row(row: dict) -> dict:
    """Return a copy of ``row`` with every float value rounded to 4 decimals.

    Python floats and NumPy floating scalars are converted to plain ``float``
    and rounded; every other value is passed through untouched.
    """
    return {
        key: round(float(value), 4)
        if isinstance(value, (float, np.floating))
        else value
        for key, value in row.items()
    }
def _cv_evaluate(X, y, models, task: str, class_order, display_labels):
    """Run stratified 5-fold CV for each model and report the results.

    For every model this collects out-of-fold predictions, builds one metric
    row, saves a confusion-matrix figure and prints a classification report.

    Args:
        X: Feature matrix.
        y: Labels as they appear in the data (e.g. [0, 1] or
            ['ASD', 'DD', 'TD']).
        models: Mapping of model name -> unfitted pipeline.
        task: Task tag; "binary" additionally collects out-of-fold
            probabilities and the full binary metric set.
        class_order: Label order used for the confusion matrix and report.
        display_labels: Human-readable names in the same order.

    Returns:
        ``(rows, preds, probs)``: rounded metric rows, out-of-fold hard
        predictions per model, and (binary task only) out-of-fold
        positive-class probabilities per model.
    """
    splitter = StratifiedKFold(n_splits=5, shuffle=True, random_state=RANDOM_STATE)
    is_binary = task == "binary"
    rows, preds, probs = [], {}, {}

    for name, pipe in models.items():
        y_pred = cross_val_predict(pipe, X, y, cv=splitter, n_jobs=-1)
        preds[name] = y_pred
        acc = accuracy_score(y, y_pred)
        f1_macro = f1_score(y, y_pred, average="macro")

        if is_binary:
            # Second CV pass on the same splitter to obtain probabilities.
            fold_proba = cross_val_predict(
                pipe, X, y, cv=splitter, method="predict_proba", n_jobs=-1
            )
            y_proba = fold_proba[:, 1]
            probs[name] = y_proba
            metrics = _binary_metric_row(
                np.asarray(y), np.asarray(y_pred), np.asarray(y_proba),
                threshold=0.5,
            )
        else:
            metrics = {"accuracy": acc, "f1_macro": f1_macro}
        rows.append(_round_metric_row({"task": task, "model": name, **metrics}))

        # Confusion-matrix figure for this task/model pair.
        matrix = confusion_matrix(y, y_pred, labels=class_order)
        fig, ax = plt.subplots(figsize=(6, 5))
        display = ConfusionMatrixDisplay(matrix, display_labels=display_labels)
        display.plot(ax=ax, cmap="Blues", values_format="d", colorbar=False)
        ax.set_title(f"{task} | {name}\nacc={acc:.3f} f1={f1_macro:.3f}")
        fig.tight_layout()
        out = FIG_DIR / f"confusion_matrix_{task}_{name}.png"
        fig.savefig(out, dpi=150, bbox_inches="tight")
        plt.close(fig)
        print(f" saved {out.relative_to(PROJECT_ROOT)}")

        print(f"\n[{task} / {name}]")
        report = classification_report(
            y, y_pred, labels=class_order,
            target_names=display_labels, digits=3,
        )
        print(report)

    return rows, preds, probs
def _plot_feature_importance(X, y):
    """Fit a random forest on all rows and save a feature-importance bar plot.

    Args:
        X: Feature matrix whose columns follow the order of ``FEATURES``.
        y: Class labels.

    The figure is written to ``FIG_DIR / "feature_importance.png"``.
    """
    pipe = Pipeline([
        ("imp", SimpleImputer(strategy="median")),
        ("clf", RandomForestClassifier(n_estimators=500,
                                       class_weight="balanced",
                                       random_state=RANDOM_STATE)),
    ])
    pipe.fit(X, y)
    imp = pipe.named_steps["clf"].feature_importances_

    # SimpleImputer discards columns that were entirely missing at fit time
    # (their learned statistic is NaN), so the forest may see fewer columns
    # than FEATURES. Keep only the surviving names so that every bar is
    # labelled with the feature it actually belongs to.
    kept = ~pd.isna(np.asarray(pipe.named_steps["imp"].statistics_))
    names = np.array(FEATURES)[kept]

    order = np.argsort(imp)[::-1]  # most important first
    feats = names[order]
    vals = imp[order]

    fig, ax = plt.subplots(figsize=(9, 6))
    sns.barplot(x=vals, y=feats, ax=ax, color="#4C72B0")
    ax.set_title("Random Forest feature importance (multi-class)")
    ax.set_xlabel("Importance")
    fig.tight_layout()
    out = FIG_DIR / "feature_importance.png"
    fig.savefig(out, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f" saved {out.relative_to(PROJECT_ROOT)}")
def _plot_roc_curves(X, y, probs):
    """Overlay one ROC curve per model and save the figure.

    Args:
        X: Unused; kept so the call signature stays the same.
        y: Binary ground-truth labels.
        probs: Mapping of model name -> positive-class probabilities.

    The figure is written to ``FIG_DIR / "roc_curve_binary.png"``.
    """
    fig, ax = plt.subplots(figsize=(7, 6))
    for model_name in probs:
        RocCurveDisplay.from_predictions(
            y, probs[model_name], name=model_name, ax=ax
        )
    # Chance diagonal for reference.
    ax.plot([0, 1], [0, 1], "k--", alpha=0.4)
    ax.set_title("ROC curves - ASD vs non-ASD (5-fold CV)")
    fig.tight_layout()

    out = FIG_DIR / "roc_curve_binary.png"
    fig.savefig(out, dpi=150, bbox_inches="tight")
    plt.close(fig)
    print(f" saved {out.relative_to(PROJECT_ROOT)}")
def _data_hash(df: pd.DataFrame) -> str:
    """Return a short fingerprint of ``df``.

    Rows are sorted by ("corpus", "participant_id"), serialised to CSV
    without the index, and hashed with SHA-256; the first 16 hex characters
    of the digest are returned.
    """
    canonical = df.sort_values(by=["corpus", "participant_id"])
    digest = hashlib.sha256()
    digest.update(canonical.to_csv(index=False).encode("utf-8"))
    return digest.hexdigest()[:16]
def _threshold_table(y_true: np.ndarray, y_proba: np.ndarray) -> pd.DataFrame:
    """Tabulate the binary metrics at thresholds 0.05, 0.10, ..., 0.95.

    At each threshold the hard prediction is ``y_proba >= threshold``; the
    metric row is computed, rounded, and becomes one DataFrame row.
    """
    cutoffs = np.round(np.arange(0.05, 0.96, 0.05), 2)

    def _row_at(cutoff):
        hard_pred = (y_proba >= cutoff).astype(int)
        metrics = _binary_metric_row(
            y_true, hard_pred, y_proba, threshold=float(cutoff)
        )
        return _round_metric_row(metrics)

    return pd.DataFrame([_row_at(cutoff) for cutoff in cutoffs])
def _calibration_bins(y_true: np.ndarray, y_proba: np.ndarray, n_bins: int = 10) -> pd.DataFrame:
    """Summarise calibration in equal-width probability bins over [0, 1].

    Args:
        y_true: Binary ground-truth labels (0 / 1).
        y_proba: Predicted probability of the positive class.
        n_bins: Number of equal-width bins (must be >= 1).

    Returns:
        One row per non-empty bin with columns ``bin`` (label "lo-hi"),
        ``n``, ``predicted_mean`` and ``observed_rate``. The columns are
        present even when no bin is populated.

    Raises:
        ValueError: If ``n_bins`` is smaller than 1.
    """
    if n_bins < 1:
        raise ValueError("n_bins must be >= 1")

    edges = np.linspace(0, 1, n_bins + 1)

    # Pick the smallest label precision that represents every edge exactly
    # (capped at 6). A fixed one-decimal format yields duplicate labels as
    # soon as the bin width is finer than 0.1, which pd.cut rejects. For the
    # default n_bins=10 this still resolves to one decimal.
    decimals = 1
    while decimals < 6 and not np.allclose(
        np.round(edges, decimals), edges, rtol=0.0, atol=1e-9
    ):
        decimals += 1
    labels = [
        f"{lo:.{decimals}f}-{hi:.{decimals}f}"
        for lo, hi in zip(edges[:-1], edges[1:])
    ]

    frame = pd.DataFrame({"y_true": y_true, "prob_asd": y_proba})
    frame["bin"] = pd.cut(frame["prob_asd"], bins=edges, labels=labels,
                          include_lowest=True, right=True)

    rows = []
    # observed=False iterates every category, so empty bins are skipped
    # explicitly below.
    for label, group in frame.groupby("bin", observed=False):
        if group.empty:
            continue
        rows.append({
            "bin": str(label),
            "n": int(len(group)),
            "predicted_mean": round(float(group["prob_asd"].mean()), 4),
            "observed_rate": round(float(group["y_true"].mean()), 4),
        })
    return pd.DataFrame(
        rows, columns=["bin", "n", "predicted_mean", "observed_rate"]
    )
def _decision_curve(y_true: np.ndarray, y_proba: np.ndarray) -> pd.DataFrame:
    """Compute net benefit at thresholds 0.05, 0.10, ..., 0.95.

    For each threshold ``t`` the model's net benefit is
    ``TP/n - FP/n * t/(1-t)``; the treat-all reference is
    ``prevalence - (1 - prevalence) * t/(1-t)`` and treat-none is 0.
    Values are rounded to 4 decimals.
    """
    n = len(y_true)
    prevalence = float(np.mean(y_true))
    is_case = y_true == 1
    is_control = y_true == 0

    def _row_at(threshold):
        flagged = y_proba >= threshold
        tp = int((is_case & flagged).sum())
        fp = int((is_control & flagged).sum())
        odds = threshold / (1 - threshold)
        return {
            "threshold": float(threshold),
            "model_net_benefit": round(tp / n - fp / n * odds, 4),
            "treat_all_net_benefit": round(prevalence - (1 - prevalence) * odds, 4),
            "treat_none_net_benefit": 0.0,
        }

    cutoffs = np.round(np.arange(0.05, 0.96, 0.05), 2)
    return pd.DataFrame([_row_at(cutoff) for cutoff in cutoffs])
def _age_band(age_months: pd.Series) -> pd.Series:
    """Map ages in months to the string bands "<36" ... "72+".

    Bands are left-closed so that the labels are literally true: an age of
    36 falls in "36-47", 48 in "48-59", and so on. (pd.cut's default
    right-closed intervals would put those boundary ages one band too low.)
    """
    return pd.cut(
        age_months,
        bins=[0, 36, 48, 60, 72, 200],
        labels=["<36", "36-47", "48-59", "60-71", "72+"],
        include_lowest=True,
        right=False,
    ).astype(str)


def _subgroup_performance(df: pd.DataFrame, y_true: np.ndarray, y_proba: np.ndarray) -> pd.DataFrame:
    """Compute binary metrics per corpus, sex and age band at threshold 0.5.

    Args:
        df: Participant table with "corpus", "sex" and "age_months" columns,
            row-aligned with ``y_true`` / ``y_proba``.
        y_true: Binary ground-truth labels (0 / 1).
        y_proba: Predicted probability of the positive class.

    Returns:
        One rounded metric row per subgroup, tagged with ``dimension``,
        ``value`` and ``n``. Subgroups with fewer than 5 rows or with only
        one class present are omitted.
    """
    rows = []
    eval_df = df.copy()
    eval_df["y_true"] = y_true
    eval_df["prob_asd"] = y_proba
    eval_df["pred"] = (y_proba >= 0.5).astype(int)
    eval_df["age_band"] = _age_band(eval_df["age_months"])

    for dimension in ["corpus", "sex", "age_band"]:
        # dropna=False keeps rows whose subgroup value is missing as a group.
        for value, sub in eval_df.groupby(dimension, dropna=False):
            # Too small, or single-class (ROC AUC would be undefined).
            if len(sub) < 5 or sub["y_true"].nunique() < 2:
                continue
            metrics = _binary_metric_row(
                sub["y_true"].to_numpy(),
                sub["pred"].to_numpy(),
                sub["prob_asd"].to_numpy(),
                threshold=0.5,
            )
            rows.append(_round_metric_row({
                "dimension": dimension,
                "value": str(value),
                "n": len(sub),
                **metrics,
            }))
    return pd.DataFrame(rows)
def _leave_one_corpus_out(df: pd.DataFrame) -> pd.DataFrame:
    """Evaluate the LogReg pipeline with each corpus held out in turn.

    Args:
        df: Participant table with the ``FEATURES`` columns plus "group"
            and "corpus".

    Returns:
        One row per held-out corpus. ``status`` is "evaluated" (binary
        metrics at threshold 0.5 are included), "skipped_single_class"
        (the held-out corpus has only one class) or
        "skipped_single_class_train" (the remaining corpora have only one
        class, so the model cannot be fitted). When fewer than two corpora
        exist there is nothing to hold out and an empty frame is returned.
    """
    X = df[FEATURES].values
    y = (df["group"] == "ASD").astype(int).values
    groups = df["corpus"].values

    # LeaveOneGroupOut raises on fewer than two groups; report "no folds"
    # instead of aborting the whole run.
    if df["corpus"].nunique(dropna=False) < 2:
        return pd.DataFrame(columns=["held_out_corpus", "n_test", "status"])

    rows = []
    for train_idx, test_idx in LeaveOneGroupOut().split(X, y, groups):
        test_corpus = str(groups[test_idx][0])
        n_test = int(len(test_idx))

        if len(np.unique(y[test_idx])) < 2:
            rows.append({
                "held_out_corpus": test_corpus,
                "n_test": n_test,
                "status": "skipped_single_class",
            })
            continue

        # If all cases of one class live in the held-out corpus, the training
        # fold is single-class and LogisticRegression.fit would raise.
        if len(np.unique(y[train_idx])) < 2:
            rows.append({
                "held_out_corpus": test_corpus,
                "n_test": n_test,
                "status": "skipped_single_class_train",
            })
            continue

        pipe = _build_models()["LogReg"]  # fresh, unfitted pipeline per fold
        pipe.fit(X[train_idx], y[train_idx])
        proba = pipe.predict_proba(X[test_idx])[:, 1]
        pred = (proba >= 0.5).astype(int)
        rows.append(_round_metric_row({
            "held_out_corpus": test_corpus,
            "n_test": n_test,
            "status": "evaluated",
            **_binary_metric_row(y[test_idx], pred, proba, threshold=0.5),
        }))
    return pd.DataFrame(rows)
def _write_model_bundle(df: pd.DataFrame) -> dict:
    """Fit the LogReg pipeline on every row and persist it with metadata.

    The bundle (fitted model, version tag, feature list, thresholds and
    training metadata) is dumped to ``ARTIFACT_DIR / "screening_model.joblib"``
    and also returned to the caller.
    """
    features = df[FEATURES].values
    is_asd = (df["group"] == "ASD").astype(int).values

    model = _build_models()["LogReg"]
    model.fit(features, is_asd)

    thresholds = {
        "uncertain_low": UNCERTAIN_LOW,
        "uncertain_high": UNCERTAIN_HIGH,
        "default_binary": 0.5,
    }
    training_metadata = {
        "trained_on": date.today().isoformat(),
        "n_rows": int(len(df)),
        "n_asd": int(is_asd.sum()),
        "n_non_asd": int((1 - is_asd).sum()),
        "corpora": sorted(df["corpus"].dropna().unique().tolist()),
        "data_hash": _data_hash(df),
    }
    bundle = {
        "model": model,
        "model_version": MODEL_VERSION,
        "features": FEATURES,
        "thresholds": thresholds,
        "training_metadata": training_metadata,
    }

    destination = ARTIFACT_DIR / "screening_model.joblib"
    joblib.dump(bundle, destination)
    print(f" saved {destination.relative_to(PROJECT_ROOT)}")
    return bundle
def _write_json(path: Path, payload: dict) -> None:
    """Write ``payload`` to ``path`` as indented UTF-8 JSON and log the path.

    Non-ASCII characters are written verbatim rather than escaped.
    """
    serialized = json.dumps(payload, ensure_ascii=False, indent=2)
    path.write_text(serialized, encoding="utf-8")
    shown = path.relative_to(PROJECT_ROOT)
    print(f" saved {shown}")
def main() -> None:
    """Run both classification tasks and write every report and artifact.

    Steps, in order: load ``combined_features.csv`` (rows without a group are
    dropped); cross-validate the binary and multi-class tasks; save the
    combined metric table; export the LogReg dashboard CSVs; run the
    leave-one-corpus-out check; write the feature schema, the fitted model
    bundle and the model card; finally print a summary table.
    """
    df = pd.read_csv(DATA_DIR / "combined_features.csv").dropna(subset=["group"])
    print(f"Loaded {len(df)} rows. Group counts:")
    print(df["group"].value_counts().to_string())

    X = df[FEATURES].values
    banner = "=" * 70

    def _save_metric_csv(frame: pd.DataFrame, filename: str) -> None:
        # Write one dashboard table under METRIC_DIR and log where it went.
        target = METRIC_DIR / filename
        frame.to_csv(target, index=False)
        print(f"[saved] {target.relative_to(PROJECT_ROOT)}")

    # ---------------- Binary: ASD vs non-ASD ----------------
    print("\n" + banner)
    print("TASK A: Binary ASD (1) vs non-ASD (0)")
    print(banner)
    y_bin = (df["group"] == "ASD").astype(int).values
    binary_rows, preds, probs = _cv_evaluate(
        X, y_bin, _build_models(),
        task="binary",
        class_order=[0, 1],
        display_labels=["non-ASD", "ASD"],
    )
    if probs:
        _plot_roc_curves(X, y_bin, probs)

    # ---------------- Multi-class: ASD / DD / TD ----------------
    print("\n" + banner)
    print("TASK B: Multi-class ASD vs DD vs TD")
    print(banner)
    multi_df = df[df["group"].isin(["ASD", "DD", "TD"])]
    X_m = multi_df[FEATURES].values
    y_m = multi_df["group"].astype(str).to_numpy()
    multi_rows, _, _ = _cv_evaluate(
        X_m, y_m, _build_models(),
        task="multiclass",
        class_order=["ASD", "DD", "TD"],
        display_labels=["ASD", "DD", "TD"],
    )
    _plot_feature_importance(X_m, y_m)

    # Combined metric table for both tasks.
    results_df = pd.DataFrame([*binary_rows, *multi_rows])
    out = METRIC_DIR / "classification_results.csv"
    results_df.to_csv(out, index=False)
    print(f"\n[saved] {out.relative_to(PROJECT_ROOT)}")

    # Model Trust Dashboard inputs focus on LogReg, the selected interpretable
    # screening model. These CSVs are static assets for project_dashboard/.
    logreg_prob = probs.get("LogReg")
    logreg_pred = preds.get("LogReg")
    if logreg_prob is not None and logreg_pred is not None:
        identity_cols = ["participant_id", "corpus", "group", "sex", "age_months"]
        pred_df = df[identity_cols].copy()
        pred_df["y_true"] = y_bin
        pred_df["prob_asd"] = np.round(logreg_prob, 6)
        pred_df["pred_050"] = logreg_pred

        below = logreg_prob < UNCERTAIN_LOW
        between = (logreg_prob >= UNCERTAIN_LOW) & (logreg_prob < UNCERTAIN_HIGH)
        above = logreg_prob >= UNCERTAIN_HIGH
        pred_df["uncertainty_zone"] = np.select(
            [below, between, above],
            ["low", "uncertain", "high"],
            default="unknown",
        )

        _save_metric_csv(pred_df, "binary_oof_predictions.csv")
        _save_metric_csv(_threshold_table(y_bin, logreg_prob), "threshold_metrics.csv")
        _save_metric_csv(_calibration_bins(y_bin, logreg_prob), "calibration_bins.csv")
        _save_metric_csv(_decision_curve(y_bin, logreg_prob), "decision_curve.csv")
        _save_metric_csv(
            _subgroup_performance(df, y_bin, logreg_prob),
            "subgroup_performance.csv",
        )

    _save_metric_csv(_leave_one_corpus_out(df), "leave_one_corpus_out.csv")

    threshold_doc = {
        "uncertain_low": UNCERTAIN_LOW,
        "uncertain_high": UNCERTAIN_HIGH,
    }
    _write_json(ARTIFACT_DIR / "feature_schema.json", {
        "features": FEATURES,
        "feature_docs": feature_schema_rows(),
        "thresholds": threshold_doc,
    })

    bundle = _write_model_bundle(df)

    model_card = {
        "model_version": MODEL_VERSION,
        "model_type": "Logistic Regression with median imputation and standard scaling",
        "intended_use": "ASD screening support and research demo; not diagnostic.",
        "not_intended_use": "Autonomous diagnosis, emergency triage, or replacement for clinician assessment.",
        "inputs": FEATURES,
        "training_metadata": bundle["training_metadata"],
        "thresholds": bundle["thresholds"],
        "reporting_guidance": [
            "TRIPOD+AI prediction model reporting",
            "DECIDE-AI early clinical decision-support evaluation",
            "Model card and dataset-card transparency",
        ],
        "clinical_caveats": [
            "TalkBank/ASDBank cohorts are not a Thai external validation set.",
            "Audio-derived predictions require transcript QA and feature-drift checks.",
            "Probability estimates require calibration review before clinical use.",
        ],
    }
    _write_json(ARTIFACT_DIR / "model_card.json", model_card)

    print("\n=== SUMMARY ===")
    print(results_df.to_string(index=False))


if __name__ == "__main__":
    main()
| if __name__ == "__main__": | |
| main() | |