| """ |
Comprehensive evaluation script for the DSM-5 symptom classifier.

Generates all metrics, tabular artifacts (CSV/JSON), a markdown report, and
error analysis for the capstone report.

Usage:
    python evaluate_redsm5.py [--model-path PATH] [--data-dir PATH] [--output-dir PATH]
                              [--model-name NAME] [--max-length N] [--batch-size N]
| """ |
|
|
| import argparse |
| import json |
| import logging |
| from pathlib import Path |
|
|
| import pandas as pd |
| import torch |
| from sklearn.metrics import accuracy_score, confusion_matrix, precision_recall_fscore_support |
| from torch.utils.data import DataLoader |
| from transformers import AutoTokenizer |
|
|
# Configure the root logger at module load so INFO-level progress messages
# emitted through `logger` below are shown.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
| |
| from train_redsm5_model import SymptomClassifier, SymptomDataset, collate_fn |
|
|
|
|
def per_symptom_breakdown(labels, preds, label_names) -> pd.DataFrame:
    """Build a per-class precision / recall / F1 / support table.

    Args:
        labels: True class ids, one per sample.
        preds: Predicted class ids, aligned with ``labels``.
        label_names: Class names ordered so that position ``i`` names class id ``i``.

    Returns:
        A DataFrame with one row per entry of ``label_names`` and the columns
        ``Symptom``, ``Precision``, ``Recall``, ``F1-Score`` (each rounded to
        4 decimals) and ``Support``.
    """
    class_ids = list(range(len(label_names)))
    # average=None yields one score per class id listed in `labels=`;
    # zero_division=0 reports 0 for classes with an undefined score.
    precision, recall, f_score, counts = precision_recall_fscore_support(
        labels, preds, average=None, labels=class_ids, zero_division=0
    )
    table = [
        {
            "Symptom": label_names[idx],
            "Precision": round(precision[idx], 4),
            "Recall": round(recall[idx], 4),
            "F1-Score": round(f_score[idx], 4),
            "Support": int(counts[idx]),
        }
        for idx in class_ids
    ]
    return pd.DataFrame(table)
|
|
|
|
def error_analysis(test_df, preds, label_names) -> pd.DataFrame:
    """Collect every misclassified example with its true and predicted label.

    Args:
        test_df: DataFrame with the columns ``clean_text``, ``label_id`` and
            ``post_id``; rows are matched to ``preds`` by position.
        preds: Predicted class ids, one per row of ``test_df``.
        label_names: Class names ordered so that position ``i`` names class id ``i``.

    Returns:
        A DataFrame with the columns ``sentence_text``, ``true_label``,
        ``predicted_label`` and ``post_id``, one row per misclassification.
        The columns are present even when there are no errors, so downstream
        CSV output always has a header.
    """
    columns = ["sentence_text", "true_label", "predicted_label", "post_id"]

    # Extract each column once instead of calling test_df.iloc[i] per row,
    # which builds a new Series for every sample.
    texts = test_df["clean_text"].tolist()
    true_ids = test_df["label_id"].tolist()
    post_ids = test_df["post_id"].tolist()

    errors = [
        {
            "sentence_text": text,
            "true_label": label_names[true_id],
            "predicted_label": label_names[pred_id],
            "post_id": post_id,
        }
        for text, true_id, pred_id, post_id in zip(texts, true_ids, preds, post_ids)
        if true_id != pred_id
    ]
    # Passing `columns` keeps the schema stable for an empty error list.
    return pd.DataFrame(errors, columns=columns)
|
|
|
|
def _severity_from_count(count: int) -> str:
    """Map a number of distinct symptoms to a coarse severity bucket."""
    if count == 0:
        return "none"
    if count <= 2:
        return "mild"
    if count <= 4:
        return "moderate"
    return "severe"


def post_level_evaluation(test_df, preds, label_names, no_symptom_label: str = "NO_SYMPTOM") -> dict:
    """Aggregate sentence-level predictions per post and score the aggregate.

    For every ``post_id`` the distinct true and predicted symptom labels
    (everything except ``no_symptom_label``) are collected; their counts are
    bucketed into severities and compared.

    Args:
        test_df: DataFrame with the columns ``post_id`` and ``label`` (the true
            label name); rows are matched to ``preds`` by position. The input
            frame is not modified.
        preds: Predicted class ids, one per row of ``test_df``.
        label_names: Class names ordered so that position ``i`` names class id ``i``.
        no_symptom_label: Label name that marks the absence of a symptom.

    Returns:
        A dict of plain ints/floats with the keys ``num_posts``,
        ``severity_accuracy``, ``symptom_count_mae``, ``avg_true_symptoms``,
        ``avg_pred_symptoms`` and ``avg_symptom_overlap``. When there are no
        posts to evaluate, ``num_posts`` is 0 and every metric is 0.0.
    """
    frame = test_df.copy()
    frame["predicted_label"] = [label_names[p] for p in preds]

    records = []
    for _post_id, group in frame.groupby("post_id"):
        true_symptoms = set(group.loc[group["label"] != no_symptom_label, "label"].unique())
        pred_symptoms = set(
            group.loc[group["predicted_label"] != no_symptom_label, "predicted_label"].unique()
        )
        records.append(
            {
                "true_symptom_count": len(true_symptoms),
                "pred_symptom_count": len(pred_symptoms),
                "symptom_overlap": len(true_symptoms & pred_symptoms),
            }
        )

    # Without this guard an empty test set fails on the column lookups below.
    if not records:
        return {
            "num_posts": 0,
            "severity_accuracy": 0.0,
            "symptom_count_mae": 0.0,
            "avg_true_symptoms": 0.0,
            "avg_pred_symptoms": 0.0,
            "avg_symptom_overlap": 0.0,
        }

    # A list of records keeps the count columns as integers; building from a
    # dict and transposing would turn every column into object dtype.
    per_post = pd.DataFrame(records)

    true_severity = per_post["true_symptom_count"].apply(_severity_from_count)
    pred_severity = per_post["pred_symptom_count"].apply(_severity_from_count)
    severity_accuracy = (true_severity == pred_severity).mean()

    count_mae = (per_post["true_symptom_count"] - per_post["pred_symptom_count"]).abs().mean()

    # Cast to builtin floats so the result is trivially JSON-serializable.
    return {
        "num_posts": len(per_post),
        "severity_accuracy": round(float(severity_accuracy), 4),
        "symptom_count_mae": round(float(count_mae), 4),
        "avg_true_symptoms": round(float(per_post["true_symptom_count"].mean()), 2),
        "avg_pred_symptoms": round(float(per_post["pred_symptom_count"].mean()), 2),
        "avg_symptom_overlap": round(float(per_post["symptom_overlap"].mean()), 2),
    }
|
|
|
|
def baseline_comparison(our_micro_f1: float) -> pd.DataFrame:
    """Compare our micro-F1 against the published ReDSM5 baselines.

    Args:
        our_micro_f1: Micro-averaged F1 of the evaluated model.

    Returns:
        A DataFrame with the columns ``Model``, ``Micro-F1`` (rounded to 4
        decimals) and ``Delta vs Ours``. The delta is ``our_micro_f1`` minus
        the baseline score, formatted with an explicit sign; the final row is
        our own model and shows "—" instead of a delta.
    """
    published = {
        "SVM (TF-IDF)": 0.39,
        "CNN (10 epochs)": 0.25,
        "BERT (fine-tuned)": 0.51,
        "LLaMA 3.2-1B (fine-tuned)": 0.54,
    }
    # The "+" format flag always prints the sign, so a baseline that exactly
    # ties our score reads "+0.0000" rather than being mistaken for our row.
    rows = [
        {
            "Model": name,
            "Micro-F1": round(score, 4),
            "Delta vs Ours": f"{our_micro_f1 - score:+.4f}",
        }
        for name, score in published.items()
    ]
    # Our own row is appended separately so the dash never depends on the
    # numeric value of a delta.
    rows.append(
        {
            "Model": "DepScreen DistilBERT (ours)",
            "Micro-F1": round(our_micro_f1, 4),
            "Delta vs Ours": "—",
        }
    )
    return pd.DataFrame(rows)
|
|
|
|
def generate_markdown_report(
    test_metrics: dict,
    per_class_df: pd.DataFrame,
    error_df: pd.DataFrame,
    post_metrics: dict,
    baseline_df: pd.DataFrame,
    model_name: str,
    output_path: Path,
):
    """Write the markdown evaluation report to ``output_path``.

    Args:
        test_metrics: Overall metrics; must contain ``accuracy``, ``micro_f1``,
            ``macro_f1``, ``micro_precision`` and ``micro_recall``.
            ``total_samples`` is optional and rendered as "?" when absent.
        per_class_df: Per-symptom table, rendered as a markdown table.
        error_df: Misclassified examples; only the first 20 rows are rendered.
        post_metrics: Post-level metrics; must contain ``num_posts``,
            ``severity_accuracy``, ``symptom_count_mae``, ``avg_true_symptoms``,
            ``avg_pred_symptoms`` and ``avg_symptom_overlap``.
        baseline_df: Baseline comparison table, rendered as a markdown table.
        model_name: Model identifier shown in the report header.
        output_path: Destination file; overwritten if it exists.
    """
    lines = [
        "# DepScreen — Model Evaluation Report",
        "",
        f"**Model**: {model_name}",
        f"**Date**: {pd.Timestamp.now().strftime('%Y-%m-%d')}",
        "",
        "## Overall Metrics",
        "",
        "| Metric | Value |",
        "|--------|-------|",
        f"| Accuracy | {test_metrics['accuracy']:.4f} |",
        f"| Micro-F1 | {test_metrics['micro_f1']:.4f} |",
        f"| Macro-F1 | {test_metrics['macro_f1']:.4f} |",
        f"| Micro-Precision | {test_metrics['micro_precision']:.4f} |",
        f"| Micro-Recall | {test_metrics['micro_recall']:.4f} |",
        "",
        "## Per-Symptom Breakdown",
        "",
        per_class_df.to_markdown(index=False),
        "",
        "## Baseline Comparison",
        "",
        baseline_df.to_markdown(index=False),
        "",
        "## Post-Level Aggregation",
        "",
        "| Metric | Value |",
        "|--------|-------|",
        f"| Posts evaluated | {post_metrics['num_posts']} |",
        f"| Severity classification accuracy | {post_metrics['severity_accuracy']:.2%} |",
        f"| Symptom count MAE | {post_metrics['symptom_count_mae']:.2f} |",
        f"| Avg true symptoms/post | {post_metrics['avg_true_symptoms']} |",
        f"| Avg predicted symptoms/post | {post_metrics['avg_pred_symptoms']} |",
        f"| Avg symptom overlap/post | {post_metrics['avg_symptom_overlap']} |",
        "",
        "## Error Analysis",
        "",
        f"Total misclassifications: {len(error_df)} / {test_metrics.get('total_samples', '?')}",
        "",
        "### Sample Errors (first 20)",
        "",
    ]

    if len(error_df) > 0:
        lines.append(error_df.head(20).to_markdown(index=False))
    else:
        lines.append("No errors found.")

    # NOTE: these notes are fixed text, not derived from the metrics passed in.
    lines.extend(
        [
            "",
            "## Notes",
            "",
            "- PSYCHOMOTOR has only 1 test sample — F1 unreliable for this class",
            "- APPETITE_CHANGE has only 4 test samples — F1 should be interpreted with caution",
            "- NO_SYMPTOM has low recall (model prefers detecting symptoms) — safe bias for screening",
            "- All baselines are from the ReDSM5 paper (CIKM 2025, arXiv:2508.03399)",
        ]
    )

    # The report contains non-ASCII text (em dashes, arbitrary sentence text),
    # so write UTF-8 explicitly instead of relying on the locale's encoding.
    with open(output_path, "w", encoding="utf-8") as f:
        f.write("\n".join(lines))

    print(f"Report saved to: {output_path}")
|
|
|
|
def _select_device() -> torch.device:
    """Pick the compute device, preferring MPS, then CUDA, then CPU."""
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


def _run_inference(model, loader, device):
    """Run ``model`` over ``loader`` and return (labels, preds, mean per-sample loss)."""
    # Sum the loss per batch and divide by the sample count at the end; a mean
    # of per-batch means is biased whenever the last batch is smaller.
    criterion = torch.nn.CrossEntropyLoss(reduction="sum")
    all_labels = []
    all_preds = []
    summed_loss = 0.0

    with torch.no_grad():
        for batch in loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)
            logits = model(input_ids, attention_mask)
            summed_loss += criterion(logits, labels).item()
            # .tolist() stores builtin ints rather than numpy scalars.
            all_preds.extend(torch.argmax(logits, dim=1).cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

    mean_loss = summed_loss / len(all_labels) if all_labels else 0.0
    return all_labels, all_preds, mean_loss


def main() -> None:
    """Evaluate the trained symptom classifier on the test split.

    Parses the command-line options, loads ``metadata.json``, the model weights
    and ``test.csv``, runs inference, prints summary tables, and writes these
    artifacts to the output directory: ``per_symptom_f1.csv``,
    ``error_analysis.csv``, ``baseline_comparison.csv``,
    ``post_level_metrics.json``, ``test_metrics.json``,
    ``confusion_matrix.csv`` and ``evaluation_report.md``.

    Raises:
        ValueError: If the metadata's label map and class count disagree, or
            if the test set has no rows.
    """
    parser = argparse.ArgumentParser(description="Evaluate DSM-5 symptom classifier")
    parser.add_argument("--model-path", type=str, default=None)
    parser.add_argument("--data-dir", type=str, default=None)
    parser.add_argument("--output-dir", type=str, default=None)
    parser.add_argument("--model-name", type=str, default="distilbert-base-uncased")
    parser.add_argument("--max-length", type=int, default=128)
    parser.add_argument("--batch-size", type=int, default=16)
    args = parser.parse_args()

    # Defaults are resolved relative to the directory above this script's folder.
    base_dir = Path(__file__).parent.parent
    data_dir = Path(args.data_dir) if args.data_dir else base_dir / "data" / "redsm5" / "processed"
    model_path = Path(args.model_path) if args.model_path else base_dir / "models" / "symptom_classifier.pt"
    output_dir = Path(args.output_dir) if args.output_dir else base_dir / "evaluation"
    output_dir.mkdir(parents=True, exist_ok=True)

    device = _select_device()

    with open(data_dir / "metadata.json", encoding="utf-8") as f:
        metadata = json.load(f)
    label_map = metadata["label_map"]
    # Order the names by class id so that label_names[i] names class i.
    label_names = [name for name, _ in sorted(label_map.items(), key=lambda x: x[1])]
    num_classes = metadata["num_classes"]
    # Fail before inference; a mismatch would otherwise only surface when the
    # confusion matrix is labelled at the very end.
    if len(label_names) != num_classes:
        raise ValueError(
            f"metadata.json is inconsistent: label_map has {len(label_names)} entries "
            f"but num_classes is {num_classes}"
        )

    logger.info("Loading model from %s", model_path)
    model = SymptomClassifier(num_classes=num_classes, model_name=args.model_name)
    # NOTE(review): torch.load unpickles the checkpoint; only load trusted files,
    # and consider weights_only=True if the installed torch version supports it.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()

    test_df = pd.read_csv(data_dir / "test.csv")
    if test_df.empty:
        raise ValueError(f"Test set is empty: {data_dir / 'test.csv'}")

    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    test_dataset = SymptomDataset(
        test_df["clean_text"].tolist(),
        test_df["label_id"].tolist(),
        tokenizer,
        args.max_length,
    )
    # Loader worker processes are disabled when running on MPS.
    num_workers = 0 if device.type == "mps" else 2
    test_loader = DataLoader(
        test_dataset,
        batch_size=args.batch_size,
        collate_fn=collate_fn,
        num_workers=num_workers,
    )

    logger.info("Running evaluation...")
    all_labels, all_preds, mean_loss = _run_inference(model, test_loader, device)

    accuracy = accuracy_score(all_labels, all_preds)
    # zero_division=0 matches the per-symptom breakdown and avoids warnings for
    # classes that are never predicted.
    micro_p, micro_r, micro_f1, _ = precision_recall_fscore_support(
        all_labels, all_preds, average="micro", zero_division=0
    )
    macro_p, macro_r, macro_f1, _ = precision_recall_fscore_support(
        all_labels, all_preds, average="macro", zero_division=0
    )

    # Builtin floats keep the JSON dump free of numpy types.
    test_metrics = {
        "accuracy": float(accuracy),
        "micro_f1": float(micro_f1),
        "micro_precision": float(micro_p),
        "micro_recall": float(micro_r),
        "macro_f1": float(macro_f1),
        "macro_precision": float(macro_p),
        "macro_recall": float(macro_r),
        "total_samples": len(all_labels),
        "loss": mean_loss,
    }

    per_class_df = per_symptom_breakdown(all_labels, all_preds, label_names)
    print("\n" + "=" * 60)
    print("Per-Symptom F1 Breakdown:")
    print("=" * 60)
    print(per_class_df.to_string(index=False))

    error_df = error_analysis(test_df, all_preds, label_names)
    print(f"\nMisclassified: {len(error_df)} / {len(all_labels)} ({len(error_df) / len(all_labels):.1%})")

    post_metrics = post_level_evaluation(test_df, all_preds, label_names)
    print(f"\nPost-level severity accuracy: {post_metrics['severity_accuracy']:.2%}")
    print(f"Symptom count MAE: {post_metrics['symptom_count_mae']:.2f}")

    baseline_df = baseline_comparison(test_metrics["micro_f1"])
    print("\n" + "=" * 60)
    print("Baseline Comparison:")
    print("=" * 60)
    print(baseline_df.to_string(index=False))

    per_class_df.to_csv(output_dir / "per_symptom_f1.csv", index=False)
    error_df.to_csv(output_dir / "error_analysis.csv", index=False)
    baseline_df.to_csv(output_dir / "baseline_comparison.csv", index=False)

    with open(output_dir / "post_level_metrics.json", "w", encoding="utf-8") as f:
        json.dump(post_metrics, f, indent=2)

    with open(output_dir / "test_metrics.json", "w", encoding="utf-8") as f:
        json.dump(test_metrics, f, indent=2)

    # Rows are true classes and columns are predicted classes, both in class-id order.
    cm = confusion_matrix(all_labels, all_preds, labels=list(range(num_classes)))
    cm_df = pd.DataFrame(cm, index=label_names, columns=label_names)
    cm_df.to_csv(output_dir / "confusion_matrix.csv")

    generate_markdown_report(
        test_metrics,
        per_class_df,
        error_df,
        post_metrics,
        baseline_df,
        args.model_name,
        output_dir / "evaluation_report.md",
    )

    print(f"\nAll evaluation artifacts saved to: {output_dir}")
|
|
|
|
# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|