depscreen / ml /scripts /preprocess_redsm5.py
halsabbah's picture
Add CI/CD pipelines and code quality tooling
3187428
"""
ReDSM5 preprocessing pipeline for sentence-level DSM-5 symptom classification.
Loads the ReDSM5 dataset (1,484 posts, 2,058 annotations), creates an 11-class
sentence-level classification dataset, and splits by post_id to prevent data leakage.
Classes:
9 DSM-5 symptoms + SPECIAL_CASE + NO_SYMPTOM
Usage:
python preprocess_redsm5.py [--redsm5-dir PATH] [--output-dir PATH]
"""
import argparse
import json
import re
from pathlib import Path
import pandas as pd
from sklearn.model_selection import GroupShuffleSplit
# ── Label configuration ──────────────────────────────────────────────────────
# Class names in label-id order: a class's integer id is its position here.
# The nine DSM-5 symptoms come first, then SPECIAL_CASE, then NO_SYMPTOM.
_LABEL_ORDER = (
    "DEPRESSED_MOOD",
    "ANHEDONIA",
    "APPETITE_CHANGE",
    "SLEEP_ISSUES",
    "PSYCHOMOTOR",
    "FATIGUE",
    "WORTHLESSNESS",
    "COGNITIVE_ISSUES",
    "SUICIDAL_THOUGHTS",
    "SPECIAL_CASE",
    "NO_SYMPTOM",
)

# Maps class name -> integer label id (0..10).
SYMPTOM_LABELS = {name: index for index, name in enumerate(_LABEL_ORDER)}

# Maps class name -> human-readable display name (same key order as above).
SYMPTOM_READABLE = dict(
    DEPRESSED_MOOD="Depressed Mood",
    ANHEDONIA="Loss of Interest / Pleasure",
    APPETITE_CHANGE="Appetite / Weight Change",
    SLEEP_ISSUES="Sleep Disturbance",
    PSYCHOMOTOR="Psychomotor Changes",
    FATIGUE="Fatigue / Loss of Energy",
    WORTHLESSNESS="Worthlessness / Guilt",
    COGNITIVE_ISSUES="Difficulty Concentrating",
    SUICIDAL_THOUGHTS="Suicidal Ideation",
    SPECIAL_CASE="Other Clinical Indicator",
    NO_SYMPTOM="No Symptom Detected",
)
# ── Text cleaning ─────────────────────────────────────────────────────────────
# Patterns are compiled once because clean_sentence runs on every sentence.
_URL_RE = re.compile(r"http\S+|www\.\S+")
# The lookbehind rejects matches that start inside a word or a path segment,
# so ordinary text such as "you/them" or "or/and" is left alone.  The optional
# leading slash covers the "/u/name" and "/r/name" spellings.
_USER_RE = re.compile(r"(?<![\w/])/?u/[\w-]+")
_SUBREDDIT_RE = re.compile(r"(?<![\w/])/?r/\w+")
_WHITESPACE_RE = re.compile(r"\s+")
# Curly quotes and long dashes mapped to plain ASCII in a single pass.
_PUNCT_TABLE = str.maketrans(
    {
        "\u2018": "'",
        "\u2019": "'",
        "\u201c": '"',
        "\u201d": '"',
        "\u2014": " -- ",
        "\u2013": " - ",
    }
)


def clean_sentence(text: str) -> str:
    """Clean a single sentence for model input.

    Steps, in order: URLs become ``[URL]``, Reddit user mentions become
    ``[USER]``, subreddit mentions become ``[SUBREDDIT]``, curly quotes and
    en/em dashes are mapped to ASCII, and runs of whitespace collapse to a
    single space (leading/trailing whitespace is removed).

    Args:
        text: Raw sentence text. Any non-string value (e.g. NaN from pandas)
            is treated as missing.

    Returns:
        The cleaned sentence, or ``""`` when ``text`` is not a string.
    """
    if not isinstance(text, str):
        return ""
    # URLs go first so that "/r/..." or "/u/..." path segments inside a link
    # are swallowed by the URL placeholder rather than tagged separately.
    text = _URL_RE.sub("[URL]", text)
    text = _USER_RE.sub("[USER]", text)
    text = _SUBREDDIT_RE.sub("[SUBREDDIT]", text)
    text = text.translate(_PUNCT_TABLE)
    return _WHITESPACE_RE.sub(" ", text).strip()
def split_into_sentences(text: str) -> list[str]:
    """Split Reddit-style text into sentences with a simple rule.

    A boundary is whitespace that follows ``.``, ``!`` or ``?`` and precedes
    an uppercase ASCII letter or a double quote. Pieces shorter than five
    characters are discarded. If every piece is discarded but the whole
    (stripped) text has at least five characters, the whole text is returned
    as one sentence.

    Args:
        text: Raw post text; non-string or blank input yields no sentences.

    Returns:
        The list of kept sentences (possibly empty).
    """
    if not isinstance(text, str):
        return []
    whole = text.strip()
    if not whole:
        return []

    # Abbreviations such as "Dr." are split imperfectly; accepted trade-off.
    pieces = (chunk.strip() for chunk in re.split(r"(?<=[.!?])\s+(?=[A-Z\"])", text))
    kept = [chunk for chunk in pieces if len(chunk) >= 5]
    if kept:
        return kept

    # Nothing survived the length filter: fall back to the whole text.
    return [whole] if len(whole) >= 5 else []
# ── Data loading ──────────────────────────────────────────────────────────────
def load_data(redsm5_dir: Path) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Read the ReDSM5 posts and annotations CSV files from a directory.

    Args:
        redsm5_dir: Directory expected to contain ``redsm5_posts.csv`` and
            ``redsm5_annotations.csv``.

    Returns:
        ``(posts, annotations)`` as DataFrames.

    Raises:
        FileNotFoundError: If either CSV is missing. Both files are checked
            (posts first) before anything is read.
    """
    posts_path, annot_path = (redsm5_dir / name for name in ("redsm5_posts.csv", "redsm5_annotations.csv"))

    if not posts_path.exists():
        raise FileNotFoundError(f"Posts file not found: {posts_path}")
    if not annot_path.exists():
        raise FileNotFoundError(f"Annotations file not found: {annot_path}")

    posts, annotations = pd.read_csv(posts_path), pd.read_csv(annot_path)
    print(f"Loaded {len(posts)} posts, {len(annotations)} annotations")
    return posts, annotations
# ── Positive samples ──────────────────────────────────────────────────────────
def create_positive_samples(annotations: pd.DataFrame) -> pd.DataFrame:
    """Build the symptom-present training rows from ``status == 1`` annotations.

    Repeated ``(sentence_id, DSM5_symptom)`` pairs are collapsed to their
    first occurrence. Each remaining row gets ``label`` (the symptom name),
    ``label_id`` (looked up in ``SYMPTOM_LABELS``) and ``clean_text`` (via
    ``clean_sentence``). Rows whose cleaned text is shorter than five
    characters are dropped. Progress and per-symptom counts are printed.

    Args:
        annotations: Annotation table; the code reads the columns ``status``,
            ``sentence_id``, ``DSM5_symptom``, ``sentence_text`` and
            ``post_id``.

    Returns:
        A freshly indexed DataFrame with columns ``post_id``, ``sentence_id``,
        ``sentence_text``, ``clean_text``, ``label`` and ``label_id``.
    """
    output_columns = ["post_id", "sentence_id", "sentence_text", "clean_text", "label", "label_id"]

    symptom_present = annotations.loc[annotations["status"] == 1]
    before = len(symptom_present)
    unique_pairs = symptom_present.drop_duplicates(subset=["sentence_id", "DSM5_symptom"], keep="first")
    after = len(unique_pairs)
    if after != before:
        print(f" Deduplicated {before - after} duplicate (sentence, symptom) pairs")

    # assign() returns a new frame, so the caller's annotations are untouched.
    labelled = unique_pairs.assign(
        label=unique_pairs["DSM5_symptom"],
        label_id=unique_pairs["DSM5_symptom"].map(SYMPTOM_LABELS),
        clean_text=unique_pairs["sentence_text"].apply(clean_sentence),
    )

    # Discard rows that are (nearly) empty once cleaned.
    positives = labelled.loc[labelled["clean_text"].str.len() >= 5]

    print(f" Positive samples: {len(positives)}")
    print(" Per symptom:")
    for symptom, count in positives["label"].value_counts().items():
        print(f" {symptom}: {count}")

    return positives.loc[:, output_columns].reset_index(drop=True)
# ── Negative samples ─────────────────────────────────────────────────────────
def create_negative_samples(
    posts: pd.DataFrame,
    annotations: pd.DataFrame,
    max_negatives: int = 400,
) -> pd.DataFrame:
    """Create NO_SYMPTOM training samples from two sources.

    1. Sentences that only appear with status=0 (never status=1) in
       ``annotations``; one row is kept per ``sentence_id``.
    2. Sentences extracted with ``split_into_sentences`` from posts that have
       no annotation at all; these get synthetic ids
       ``"<post_id>_neg_<index>"``.

    The combined rows are cleaned with ``clean_sentence``, rows whose cleaned
    text is shorter than five characters are dropped, and the result is
    down-sampled (fixed seed 42) to at most ``max_negatives`` rows.

    Args:
        posts: Post table; the code reads the columns ``post_id`` and ``text``.
        annotations: Annotation table; the code reads ``status``,
            ``sentence_id``, ``post_id`` and ``sentence_text``.
        max_negatives: Upper bound on the number of returned rows, used to
            keep NO_SYMPTOM from dominating the other classes.

    Returns:
        A freshly indexed DataFrame with columns ``post_id``, ``sentence_id``,
        ``sentence_text``, ``clean_text``, ``label`` and ``label_id``. It is
        empty (but has all six columns) when no negative sentence exists.
    """
    sentence_columns = ["post_id", "sentence_id", "sentence_text"]

    # Source 1: true negative sentences (only status=0, never status=1).
    positive_sentence_ids = set(annotations[annotations["status"] == 1]["sentence_id"].unique())
    all_sentence_ids = set(annotations["sentence_id"].unique())
    negative_only_ids = all_sentence_ids - positive_sentence_ids
    neg_from_annotations = annotations[annotations["sentence_id"].isin(negative_only_ids)].drop_duplicates(
        subset=["sentence_id"], keep="first"
    )
    print(f" Negative sentences from annotations (status=0 only): {len(neg_from_annotations)}")

    # Source 2: sentences from posts that were never annotated.
    annotated_post_ids = set(annotations["post_id"].unique())
    all_post_ids = set(posts["post_id"].unique())
    unannotated_post_ids = all_post_ids - annotated_post_ids
    unannotated_posts = posts[posts["post_id"].isin(unannotated_post_ids)]
    neg_from_posts_rows = []
    for post_id, post_text in zip(unannotated_posts["post_id"], unannotated_posts["text"]):
        for i, sent in enumerate(split_into_sentences(post_text)):
            neg_from_posts_rows.append(
                {
                    "post_id": post_id,
                    "sentence_id": f"{post_id}_neg_{i}",
                    "sentence_text": sent,
                }
            )
    print(f" Negative sentences from unannotated posts: {len(neg_from_posts_rows)}")

    # Combine both sources, annotation rows first.  Passing `columns`
    # guarantees the three columns exist even when there are no rows at all;
    # without it an empty frame has no "sentence_text" column and the
    # cleaning step below raises KeyError.
    neg_combined_rows = neg_from_annotations[sentence_columns].to_dict("records") + neg_from_posts_rows
    negatives = pd.DataFrame(neg_combined_rows, columns=sentence_columns)

    negatives["clean_text"] = negatives["sentence_text"].apply(clean_sentence)
    # copy() so the label columns below are written to an independent frame
    # rather than to a filtered view of `negatives`.
    negatives = negatives[negatives["clean_text"].str.len() >= 5].copy()

    # Cap negatives to prevent class domination.
    if len(negatives) > max_negatives:
        negatives = negatives.sample(n=max_negatives, random_state=42)
        print(f" Capped NO_SYMPTOM to {max_negatives} samples")

    negatives["label"] = "NO_SYMPTOM"
    negatives["label_id"] = SYMPTOM_LABELS["NO_SYMPTOM"]
    print(f" Total negative samples: {len(negatives)}")
    return negatives[["post_id", "sentence_id", "sentence_text", "clean_text", "label", "label_id"]].reset_index(
        drop=True
    )
# ── Splitting ─────────────────────────────────────────────────────────────────
def split_by_post_id(
    df: pd.DataFrame,
    test_size: float = 0.10,
    val_size: float = 0.10,
    random_state: int = 42,
) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """Partition ``df`` into train/val/test so that no post spans two splits.

    Two successive ``GroupShuffleSplit`` passes grouped on ``post_id`` are
    used: the first carves off the test set, the second carves the validation
    set out of what remains. Both passes use the same ``random_state``.

    Args:
        df: Sample table with a ``post_id`` column.
        test_size: Fraction of the whole dataset requested for the test split.
        val_size: Fraction of the whole dataset requested for validation; it
            is rescaled relative to the data left after the test split.
        random_state: Seed passed to both splitters.

    Returns:
        ``(train, val, test)``, each with a fresh 0-based index.
    """

    def _hold_out(frame: pd.DataFrame, fraction: float) -> tuple[pd.DataFrame, pd.DataFrame]:
        # One grouped shuffle split: returns (kept rows, held-out rows).
        splitter = GroupShuffleSplit(n_splits=1, test_size=fraction, random_state=random_state)
        kept_idx, held_idx = next(splitter.split(frame, groups=frame["post_id"]))
        return frame.iloc[kept_idx], frame.iloc[held_idx]

    train_val, test = _hold_out(df, test_size)
    # val_size is a share of the full dataset; express it as a share of train+val.
    train, val = _hold_out(train_val, val_size / (1 - test_size))

    return (
        train.reset_index(drop=True),
        val.reset_index(drop=True),
        test.reset_index(drop=True),
    )
def compute_class_weights(train_df: pd.DataFrame) -> dict[int, float]:
    """Return inverse-frequency weights for the classes present in ``train_df``.

    Each weight is ``total_rows / (number_of_present_classes * class_count)``.
    Classes that do not occur in ``train_df["label_id"]`` get no entry.

    Args:
        train_df: Training split with a ``label_id`` column.

    Returns:
        Mapping of integer label id to weight, in ascending label-id order.
    """
    per_class = train_df["label_id"].value_counts().sort_index()
    total, n_classes = len(train_df), len(per_class)
    return {int(label_id): total / (n_classes * count) for label_id, count in per_class.items()}
# ── Main pipeline ─────────────────────────────────────────────────────────────
def main():
    """Run the full ReDSM5 preprocessing pipeline from the command line.

    Steps: parse arguments, load the two CSV files, build positive and
    NO_SYMPTOM samples, shuffle, split by ``post_id``, verify that no post is
    shared between splits, compute class weights on the training split, and
    write ``train.csv``, ``val.csv``, ``test.csv`` and ``metadata.json`` to
    the output directory (created if missing).

    Command-line options:
        --redsm5-dir: Directory with the input CSVs (default:
            ``<project root>/redsm5``).
        --output-dir: Destination directory (default:
            ``../data/redsm5/processed`` relative to this script's directory).
        --max-negatives: Cap on NO_SYMPTOM samples (default: 400).

    Raises:
        FileNotFoundError: Propagated from ``load_data`` when an input CSV is
            missing.
        AssertionError: If the same ``post_id`` appears in two splits.
    """
    parser = argparse.ArgumentParser(description="Preprocess ReDSM5 dataset")
    parser.add_argument("--redsm5-dir", type=str, default=None, help="Path to redsm5 directory (with CSV files)")
    parser.add_argument("--output-dir", type=str, default=None, help="Path to output directory for processed splits")
    parser.add_argument("--max-negatives", type=int, default=400, help="Maximum NO_SYMPTOM samples (default: 400)")
    args = parser.parse_args()

    # Resolve paths
    script_path = Path(__file__)
    project_root = script_path.parent.parent.parent.parent  # backend/ml/scripts → project root
    redsm5_dir = Path(args.redsm5_dir) if args.redsm5_dir else project_root / "redsm5"
    output_dir = (
        Path(args.output_dir) if args.output_dir else (script_path.parent.parent / "data" / "redsm5" / "processed")
    )
    output_dir.mkdir(parents=True, exist_ok=True)

    print("=" * 60)
    print("ReDSM5 Preprocessing Pipeline")
    print("=" * 60)
    print(f"Input: {redsm5_dir}")
    print(f"Output: {output_dir}")

    # ── Load ──
    print("\n── Loading data ──")
    posts, annotations = load_data(redsm5_dir)

    # ── Create samples ──
    print("\n── Creating positive samples (status=1) ──")
    positives = create_positive_samples(annotations)
    print("\n── Creating negative samples (NO_SYMPTOM) ──")
    negatives = create_negative_samples(posts, annotations, max_negatives=args.max_negatives)

    # ── Combine ──
    combined = pd.concat([positives, negatives], ignore_index=True)
    combined = combined.sample(frac=1, random_state=42).reset_index(drop=True)  # Shuffle
    print(f"\nTotal dataset: {len(combined)} samples across {combined['post_id'].nunique()} posts")

    # ── Split ──
    print("\n── Splitting by post_id (80/10/10) ──")
    train, val, test = split_by_post_id(combined)
    print(f" Train: {len(train)} samples ({train['post_id'].nunique()} posts)")
    print(f" Val: {len(val)} samples ({val['post_id'].nunique()} posts)")
    print(f" Test: {len(test)} samples ({test['post_id'].nunique()} posts)")

    # Verify no post leakage.  An explicit raise is used instead of `assert`
    # so the check still runs under `python -O`, which strips assert statements.
    split_posts = {
        "train": set(train["post_id"]),
        "val": set(val["post_id"]),
        "test": set(test["post_id"]),
    }
    for left, right in (("train", "val"), ("train", "test"), ("val", "test")):
        if split_posts[left] & split_posts[right]:
            raise AssertionError(f"Post leakage: {left} ∩ {right}")
    print(" ✓ No post_id leakage across splits")

    # ── Per-split label distribution ──
    print("\n── Label distribution per split ──")
    for name, split in [("Train", train), ("Val", val), ("Test", test)]:
        print(f"\n {name}:")
        for label, count in split["label"].value_counts().sort_index().items():
            print(f" {label}: {count}")

    # ── Class weights ──
    class_weights = compute_class_weights(train)
    # Inverse map built once instead of scanning SYMPTOM_LABELS per class.
    id_to_label = {label_id: label_name for label_name, label_id in SYMPTOM_LABELS.items()}
    print("\n── Class weights (inverse frequency) ──")
    for label_id, weight in sorted(class_weights.items()):
        print(f" {id_to_label[label_id]} ({label_id}): {weight:.3f}")

    # ── Save ──
    print("\n── Saving splits ──")
    train.to_csv(output_dir / "train.csv", index=False)
    val.to_csv(output_dir / "val.csv", index=False)
    test.to_csv(output_dir / "test.csv", index=False)
    metadata = {
        "label_map": SYMPTOM_LABELS,
        "label_readable": SYMPTOM_READABLE,
        "class_weights": class_weights,
        "num_classes": len(SYMPTOM_LABELS),
        "total_samples": len(combined),
        "train_samples": len(train),
        "val_samples": len(val),
        "test_samples": len(test),
        "train_posts": train["post_id"].nunique(),
        "val_posts": val["post_id"].nunique(),
        "test_posts": test["post_id"].nunique(),
        "label_distribution": {
            "train": train["label"].value_counts().to_dict(),
            "val": val["label"].value_counts().to_dict(),
            "test": test["label"].value_counts().to_dict(),
        },
    }
    # Note: json.dump writes the integer class-weight keys as strings.
    with open(output_dir / "metadata.json", "w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=2)

    print(f"\nSaved to {output_dir}:")
    print(f" train.csv ({len(train)} rows)")
    print(f" val.csv ({len(val)} rows)")
    print(f" test.csv ({len(test)} rows)")
    print(" metadata.json")
    print("\n" + "=" * 60)
    print("Preprocessing complete!")
    print("Next step: python train_redsm5_model.py")
    print("=" * 60)


if __name__ == "__main__":
    main()