depscreen / ml /scripts /evaluate_redsm5.py
halsabbah's picture
Add CI/CD pipelines and code quality tooling
3187428
"""
Comprehensive evaluation script for the DSM-5 symptom classifier.
Generates all metrics, visualizations, and error analysis for the capstone report.
Usage:
python evaluate_redsm5.py [--model-path PATH] [--data-dir PATH] [--output-dir PATH]
"""
import argparse
import json
import logging
from pathlib import Path
import pandas as pd
import torch
from sklearn.metrics import accuracy_score, confusion_matrix, precision_recall_fscore_support
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Reuse model and dataset from training script
from train_redsm5_model import SymptomClassifier, SymptomDataset, collate_fn
def per_symptom_breakdown(labels, preds, label_names) -> pd.DataFrame:
"""Create a per-symptom F1 breakdown table."""
p, r, f1, support = precision_recall_fscore_support(
labels, preds, average=None, labels=list(range(len(label_names))), zero_division=0
)
rows = []
for i, name in enumerate(label_names):
rows.append(
{
"Symptom": name,
"Precision": round(p[i], 4),
"Recall": round(r[i], 4),
"F1-Score": round(f1[i], 4),
"Support": int(support[i]),
}
)
return pd.DataFrame(rows)
def error_analysis(test_df, preds, label_names) -> pd.DataFrame:
"""Identify misclassified examples with true vs predicted labels."""
errors = []
for i, (true_label, pred_label) in enumerate(zip(test_df["label_id"], preds)):
if true_label != pred_label:
errors.append(
{
"sentence_text": test_df.iloc[i]["clean_text"],
"true_label": label_names[true_label],
"predicted_label": label_names[pred_label],
"post_id": test_df.iloc[i]["post_id"],
}
)
return pd.DataFrame(errors)
def post_level_evaluation(test_df, preds, label_names) -> dict:
"""Aggregate sentence predictions to post level and evaluate."""
test_df = test_df.copy()
test_df["predicted_label"] = [label_names[p] for p in preds]
results_per_post = {}
for post_id, group in test_df.groupby("post_id"):
true_symptoms = set(group[group["label"] != "NO_SYMPTOM"]["label"].unique())
pred_symptoms = set(group[group["predicted_label"] != "NO_SYMPTOM"]["predicted_label"].unique())
results_per_post[post_id] = {
"true_symptom_count": len(true_symptoms),
"pred_symptom_count": len(pred_symptoms),
"true_symptoms": sorted(true_symptoms),
"pred_symptoms": sorted(pred_symptoms),
"symptom_overlap": len(true_symptoms & pred_symptoms),
}
df = pd.DataFrame(results_per_post).T
# Severity accuracy (based on count thresholds)
def severity(count):
if count == 0:
return "none"
if count <= 2:
return "mild"
if count <= 4:
return "moderate"
return "severe"
df["true_severity"] = df["true_symptom_count"].apply(severity)
df["pred_severity"] = df["pred_symptom_count"].apply(severity)
severity_accuracy = (df["true_severity"] == df["pred_severity"]).mean()
# Average symptom count error
count_mae = (df["true_symptom_count"] - df["pred_symptom_count"]).abs().mean()
return {
"num_posts": len(df),
"severity_accuracy": round(severity_accuracy, 4),
"symptom_count_mae": round(count_mae, 4),
"avg_true_symptoms": round(df["true_symptom_count"].mean(), 2),
"avg_pred_symptoms": round(df["pred_symptom_count"].mean(), 2),
"avg_symptom_overlap": round(df["symptom_overlap"].mean(), 2),
}
def baseline_comparison(our_micro_f1: float) -> pd.DataFrame:
"""Compare to published ReDSM5 baselines."""
baselines = {
"SVM (TF-IDF)": 0.39,
"CNN (10 epochs)": 0.25,
"BERT (fine-tuned)": 0.51,
"LLaMA 3.2-1B (fine-tuned)": 0.54,
"DepScreen DistilBERT (ours)": our_micro_f1,
}
rows = []
for name, score in baselines.items():
delta = our_micro_f1 - score if name != "DepScreen DistilBERT (ours)" else 0
rows.append(
{
"Model": name,
"Micro-F1": round(score, 4),
"Delta vs Ours": f"+{delta:.4f}" if delta > 0 else ("—" if delta == 0 else f"{delta:.4f}"),
}
)
return pd.DataFrame(rows)
def generate_markdown_report(
test_metrics: dict,
per_class_df: pd.DataFrame,
error_df: pd.DataFrame,
post_metrics: dict,
baseline_df: pd.DataFrame,
model_name: str,
output_path: Path,
):
"""Generate a markdown evaluation report."""
lines = [
"# DepScreen — Model Evaluation Report",
"",
f"**Model**: {model_name}",
f"**Date**: {pd.Timestamp.now().strftime('%Y-%m-%d')}",
"",
"## Overall Metrics",
"",
"| Metric | Value |",
"|--------|-------|",
f"| Accuracy | {test_metrics['accuracy']:.4f} |",
f"| Micro-F1 | {test_metrics['micro_f1']:.4f} |",
f"| Macro-F1 | {test_metrics['macro_f1']:.4f} |",
f"| Micro-Precision | {test_metrics['micro_precision']:.4f} |",
f"| Micro-Recall | {test_metrics['micro_recall']:.4f} |",
"",
"## Per-Symptom Breakdown",
"",
per_class_df.to_markdown(index=False),
"",
"## Baseline Comparison",
"",
baseline_df.to_markdown(index=False),
"",
"## Post-Level Aggregation",
"",
"| Metric | Value |",
"|--------|-------|",
f"| Posts evaluated | {post_metrics['num_posts']} |",
f"| Severity classification accuracy | {post_metrics['severity_accuracy']:.2%} |",
f"| Symptom count MAE | {post_metrics['symptom_count_mae']:.2f} |",
f"| Avg true symptoms/post | {post_metrics['avg_true_symptoms']} |",
f"| Avg predicted symptoms/post | {post_metrics['avg_pred_symptoms']} |",
f"| Avg symptom overlap/post | {post_metrics['avg_symptom_overlap']} |",
"",
"## Error Analysis",
"",
f"Total misclassifications: {len(error_df)} / {test_metrics.get('total_samples', '?')}",
"",
"### Sample Errors (first 20)",
"",
]
if len(error_df) > 0:
sample = error_df.head(20)
lines.append(sample.to_markdown(index=False))
else:
lines.append("No errors found.")
lines.extend(
[
"",
"## Notes",
"",
"- PSYCHOMOTOR has only 1 test sample — F1 unreliable for this class",
"- APPETITE_CHANGE has only 4 test samples — F1 should be interpreted with caution",
"- NO_SYMPTOM has low recall (model prefers detecting symptoms) — safe bias for screening",
"- All baselines are from the ReDSM5 paper (CIKM 2025, arXiv:2508.03399)",
]
)
with open(output_path, "w") as f:
f.write("\n".join(lines))
print(f"Report saved to: {output_path}")
def main():
parser = argparse.ArgumentParser(description="Evaluate DSM-5 symptom classifier")
parser.add_argument("--model-path", type=str, default=None)
parser.add_argument("--data-dir", type=str, default=None)
parser.add_argument("--output-dir", type=str, default=None)
parser.add_argument("--model-name", type=str, default="distilbert-base-uncased")
parser.add_argument("--max-length", type=int, default=128)
parser.add_argument("--batch-size", type=int, default=16)
args = parser.parse_args()
base_dir = Path(__file__).parent.parent
data_dir = Path(args.data_dir) if args.data_dir else base_dir / "data" / "redsm5" / "processed"
model_path = Path(args.model_path) if args.model_path else base_dir / "models" / "symptom_classifier.pt"
output_dir = Path(args.output_dir) if args.output_dir else base_dir / "evaluation"
output_dir.mkdir(parents=True, exist_ok=True)
# Device
if torch.backends.mps.is_available():
device = torch.device("mps")
elif torch.cuda.is_available():
device = torch.device("cuda")
else:
device = torch.device("cpu")
# Load metadata
with open(data_dir / "metadata.json") as f:
metadata = json.load(f)
label_map = metadata["label_map"]
label_names = [name for name, _ in sorted(label_map.items(), key=lambda x: x[1])]
num_classes = metadata["num_classes"]
# Load model
logger.info(f"Loading model from {model_path}")
model = SymptomClassifier(num_classes=num_classes, model_name=args.model_name)
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
model.eval()
# Load test data
test_df = pd.read_csv(data_dir / "test.csv")
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
test_dataset = SymptomDataset(
test_df["clean_text"].tolist(),
test_df["label_id"].tolist(),
tokenizer,
args.max_length,
)
num_workers = 0 if device.type == "mps" else 2
test_loader = DataLoader(
test_dataset,
batch_size=args.batch_size,
collate_fn=collate_fn,
num_workers=num_workers,
)
# Evaluate
logger.info("Running evaluation...")
criterion = torch.nn.CrossEntropyLoss()
all_preds = []
all_labels = []
total_loss = 0
with torch.no_grad():
for batch in test_loader:
input_ids = batch["input_ids"].to(device)
attention_mask = batch["attention_mask"].to(device)
labels = batch["label"].to(device)
logits = model(input_ids, attention_mask)
loss = criterion(logits, labels)
total_loss += loss.item()
preds = torch.argmax(logits, dim=1)
all_preds.extend(preds.cpu().numpy())
all_labels.extend(labels.cpu().numpy())
# Metrics
accuracy = accuracy_score(all_labels, all_preds)
micro_p, micro_r, micro_f1, _ = precision_recall_fscore_support(all_labels, all_preds, average="micro")
macro_p, macro_r, macro_f1, _ = precision_recall_fscore_support(all_labels, all_preds, average="macro")
test_metrics = {
"accuracy": accuracy,
"micro_f1": micro_f1,
"micro_precision": micro_p,
"micro_recall": micro_r,
"macro_f1": macro_f1,
"macro_precision": macro_p,
"macro_recall": macro_r,
"total_samples": len(all_labels),
"loss": total_loss / len(test_loader),
}
# Per-symptom breakdown
per_class_df = per_symptom_breakdown(all_labels, all_preds, label_names)
print("\n" + "=" * 60)
print("Per-Symptom F1 Breakdown:")
print("=" * 60)
print(per_class_df.to_string(index=False))
# Error analysis
error_df = error_analysis(test_df, all_preds, label_names)
print(f"\nMisclassified: {len(error_df)} / {len(all_labels)} ({len(error_df) / len(all_labels):.1%})")
# Post-level
post_metrics = post_level_evaluation(test_df, all_preds, label_names)
print(f"\nPost-level severity accuracy: {post_metrics['severity_accuracy']:.2%}")
print(f"Symptom count MAE: {post_metrics['symptom_count_mae']:.2f}")
# Baseline comparison
baseline_df = baseline_comparison(micro_f1)
print("\n" + "=" * 60)
print("Baseline Comparison:")
print("=" * 60)
print(baseline_df.to_string(index=False))
# Save everything
per_class_df.to_csv(output_dir / "per_symptom_f1.csv", index=False)
error_df.to_csv(output_dir / "error_analysis.csv", index=False)
baseline_df.to_csv(output_dir / "baseline_comparison.csv", index=False)
with open(output_dir / "post_level_metrics.json", "w") as f:
json.dump(post_metrics, f, indent=2)
with open(output_dir / "test_metrics.json", "w") as f:
json.dump(test_metrics, f, indent=2, default=str)
# Confusion matrix as CSV
cm = confusion_matrix(all_labels, all_preds, labels=list(range(num_classes)))
cm_df = pd.DataFrame(cm, index=label_names, columns=label_names)
cm_df.to_csv(output_dir / "confusion_matrix.csv")
# Markdown report
generate_markdown_report(
test_metrics,
per_class_df,
error_df,
post_metrics,
baseline_df,
args.model_name,
output_dir / "evaluation_report.md",
)
print(f"\nAll evaluation artifacts saved to: {output_dir}")
if __name__ == "__main__":
main()