| """ |
| Training data quality audit and cleaning. |
| |
Fixes three categories of noise in the ReDSM5 training data:
1. Conflicting labels: the same text appears with two different labels (53 cases)
   → pick the primary symptom based on a clinical salience hierarchy
2. Exact duplicates: the same text + same label appearing twice (19 pairs)
   → deduplicate
3. Mislabeled short sentences: very short texts with questionable labels
   → flag for review, remove the clearly wrong ones

This does NOT modify the original ReDSM5 annotations; it creates a cleaned
training split. The original data is preserved.
| |
| Usage: |
| python clean_training_data.py |
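
Outputs (written to data/redsm5/cleaned/): train.csv, val.csv, test.csv,
metadata.json, and flagged_short.json.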
| """ |
|
|
| import json |
| import logging |
| from pathlib import Path |
|
|
| import pandas as pd |
|
|
| logging.basicConfig(level=logging.INFO) |
| logger = logging.getLogger(__name__) |


# Clinical salience hierarchy for resolving conflicting labels: when one
# sentence carries several annotations, the label with the highest value wins.
| SALIENCE_PRIORITY = { |
| "SUICIDAL_THOUGHTS": 10, |
| "PSYCHOMOTOR": 9, |
| "APPETITE_CHANGE": 8, |
| "COGNITIVE_ISSUES": 7, |
| "SLEEP_ISSUES": 6, |
| "FATIGUE": 5, |
| "ANHEDONIA": 4, |
| "WORTHLESSNESS": 3, |
| "SPECIAL_CASE": 2, |
| "DEPRESSED_MOOD": 1, |
| "NO_SYMPTOM": 0, |
| } |


# Manual label corrections from expert review.
# Format: (text_prefix, wrong_label, correct_label); a correct_label of None
# removes the matching rows instead of relabeling them.
MANUAL_FIXES = [
| ("I've literally made the best financial decision", "COGNITIVE_ISSUES", "NO_SYMPTOM"), |
| ("I like that I can make decisions that affect only me", "COGNITIVE_ISSUES", "NO_SYMPTOM"), |
| ("I work a lot and make decisions all day", "COGNITIVE_ISSUES", "NO_SYMPTOM"), |
| ("Ive missed a lot of work", "COGNITIVE_ISSUES", None), |
| ("Insecurities are getting to me", "COGNITIVE_ISSUES", "WORTHLESSNESS"), |
| ("I feel successful", "WORTHLESSNESS", None), |
| ("I feel happiness", "SPECIAL_CASE", "NO_SYMPTOM"), |
| ("I love meeting new people", "ANHEDONIA", "NO_SYMPTOM"), |
| ("Now i get paid more and have more time", "SPECIAL_CASE", "NO_SYMPTOM"), |
| ] |


| def resolve_conflicts(df: pd.DataFrame) -> pd.DataFrame: |
| """Resolve conflicting labels for multi-label sentences. |
| |
| For each sentence that appears with multiple labels, keep only the |
| label with highest clinical salience priority. |
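
    Minimal illustration with hypothetical rows (not from ReDSM5):

        >>> toy = pd.DataFrame({"clean_text": ["x", "x"],
        ...                     "label": ["FATIGUE", "SUICIDAL_THOUGHTS"]})
        >>> resolve_conflicts(toy)["label"].tolist()
        ['SUICIDAL_THOUGHTS']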
| """ |

    # Group each unique text with the set of labels it appears under.
    text_groups = df.groupby("clean_text")["label"].apply(set).reset_index()
    conflicts = text_groups[text_groups["label"].apply(lambda x: len(x) > 1)]

    if len(conflicts) == 0:
        logger.info(" No conflicting labels found")
        return df

    logger.info(f" Found {len(conflicts)} sentences with conflicting labels")

    resolved_count = 0
    rows_removed = 0
    indices_to_drop = []

    for _, conflict in conflicts.iterrows():
        text = conflict["clean_text"]
        labels = conflict["label"]

        # Keep the label with the highest clinical salience; unknown labels
        # rank below everything via the -1 default.
        primary = max(labels, key=lambda lbl: SALIENCE_PRIORITY.get(lbl, -1))

        # Drop every row for this text that carries a non-primary label.
        matching_rows = df[df["clean_text"] == text]
        for idx, row in matching_rows.iterrows():
            if row["label"] != primary:
                indices_to_drop.append(idx)
                rows_removed += 1

        resolved_count += 1

    df_clean = df.drop(indices_to_drop)
    logger.info(f" Resolved {resolved_count} conflicts → removed {rows_removed} conflicting rows")
    logger.info(" Primary label chosen by clinical salience hierarchy")

    return df_clean.reset_index(drop=True)


| def remove_duplicates(df: pd.DataFrame) -> pd.DataFrame: |
| """Remove exact duplicate rows (same text + same label).""" |
| before = len(df) |
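    # keep="first" retains the earliest occurrence, so this step is idempotent.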
| df_clean = df.drop_duplicates(subset=["clean_text", "label"], keep="first") |
| removed = before - len(df_clean) |
| if removed > 0: |
| logger.info(f" Removed {removed} exact duplicates") |
| return df_clean.reset_index(drop=True) |


| def apply_manual_fixes(df: pd.DataFrame) -> pd.DataFrame: |
| """Apply manual label corrections from expert review.""" |
    fixed = 0
    removed = 0
    indices_to_drop = []

    # Hoisted out of the loop so the import runs once, not once per fix.
    from preprocess_redsm5 import SYMPTOM_LABELS

    for text_prefix, wrong_label, correct_label in MANUAL_FIXES:
        # Match on the text prefix plus the specific wrong label.
        mask = df["clean_text"].str.startswith(text_prefix, na=False) & (df["label"] == wrong_label)
        matching = df[mask]

        if len(matching) == 0:
            continue

        if correct_label is None:
            # None means: remove the matching rows entirely.
            indices_to_drop.extend(matching.index.tolist())
            removed += len(matching)
        else:
            # Relabel in place, keeping label and label_id in sync.
            df.loc[mask, "label"] = correct_label
            df.loc[mask, "label_id"] = SYMPTOM_LABELS[correct_label]
            fixed += len(matching)

    if indices_to_drop:
        df = df.drop(indices_to_drop)

    logger.info(f" Manual fixes: {fixed} labels corrected, {removed} rows removed")
    return df.reset_index(drop=True)


| def flag_suspicious_short(df: pd.DataFrame, min_length: int = 15) -> list[dict]: |
| """Flag very short sentences that may have insufficient signal.""" |
| short = df[df["clean_text"].str.len() < min_length] |
| flagged = [] |
| for _, row in short.iterrows(): |
| flagged.append( |
| { |
| "text": row["clean_text"], |
| "label": row["label"], |
| "length": len(row["clean_text"]), |
| } |
| ) |
| if flagged: |
        logger.info(f" Flagged {len(flagged)} very short sentences (<{min_length} chars) → kept but noted")
| return flagged |


| def main(): |
| base_dir = Path(__file__).parent.parent |
| data_dir = base_dir / "data" / "redsm5" / "processed" |
| output_dir = base_dir / "data" / "redsm5" / "cleaned" |
| output_dir.mkdir(parents=True, exist_ok=True) |
|
|
| print("=" * 60) |
| print("TRAINING DATA CLEANING") |
| print("=" * 60) |

    # Load the processed splits; keep an untouched copy of train for the
    # before/after summary at the end.
    train = pd.read_csv(data_dir / "train.csv")
    val = pd.read_csv(data_dir / "val.csv")
    test = pd.read_csv(data_dir / "test.csv")
    orig_train = train.copy()

    logger.info(f"\nOriginal: train={len(train)}, val={len(val)}, test={len(test)}")

    # Count train conflicts up front so the metadata reflects this run instead
    # of a hardcoded audit figure.
    n_conflicts = int((train.groupby("clean_text")["label"].nunique() > 1).sum())

    print("\n── Step 1: Resolve conflicting labels ──")
    train = resolve_conflicts(train)

    # Apply the same resolution to val and test so every split is conflict-free.
    val = resolve_conflicts(val)
    test = resolve_conflicts(test)

    print("\n── Step 2: Remove exact duplicates ──")
    train = remove_duplicates(train)
    val = remove_duplicates(val)
    test = remove_duplicates(test)

    print("\n── Step 3: Apply manual label fixes ──")
    train = apply_manual_fixes(train)

    print("\n── Step 4: Flag suspicious short sentences ──")
    flagged = flag_suspicious_short(train)

    # Recompute class weights from the cleaned training distribution.
    from preprocess_redsm5 import SYMPTOM_LABELS, SYMPTOM_READABLE

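    # Inverse-frequency weighting (the same form as sklearn's "balanced"
    # heuristic): weight_c = total / (n_classes * count_c). Hypothetical
    # numbers: 1000 samples, 11 classes, 50 in class c → 1000 / (11 * 50) ≈ 1.82.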
| counts = train["label_id"].value_counts().sort_index() |
| total = len(train) |
| n_classes = len(SYMPTOM_LABELS) |
| class_weights = {} |
| for label_id, count in counts.items(): |
| class_weights[int(label_id)] = total / (n_classes * count) |
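    # Note: classes absent from the cleaned train split get no weight entry.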

    # Write the cleaned splits alongside their metadata.
    train.to_csv(output_dir / "train.csv", index=False)
    val.to_csv(output_dir / "val.csv", index=False)
    test.to_csv(output_dir / "test.csv", index=False)
|
|
| metadata = { |
| "label_map": SYMPTOM_LABELS, |
| "label_readable": SYMPTOM_READABLE, |
| "class_weights": class_weights, |
| "num_classes": n_classes, |
| "total_samples": len(train) + len(val) + len(test), |
| "train_samples": len(train), |
| "val_samples": len(val), |
| "test_samples": len(test), |
| "cleaning_applied": { |
| "conflicts_resolved": 53, |
| "duplicates_removed": True, |
| "manual_fixes": len(MANUAL_FIXES), |
| "flagged_short_sentences": len(flagged), |
| }, |
| "label_distribution": { |
| "train": train["label"].value_counts().to_dict(), |
| "val": val["label"].value_counts().to_dict(), |
| "test": test["label"].value_counts().to_dict(), |
| }, |
| } |
|
|
| with open(output_dir / "metadata.json", "w") as f: |
| json.dump(metadata, f, indent=2) |
|
|
| with open(output_dir / "flagged_short.json", "w") as f: |
| json.dump(flagged, f, indent=2) |

    # Summary: compare against the untouched copy loaded at the start.
    print(f"\n{'=' * 60}")
    print("CLEANING COMPLETE")
    print(f"{'=' * 60}")
    print(f"Original training samples: {len(orig_train)}")
    print(f"After cleaning: {len(train)}")
    print(f"Removed: {len(orig_train) - len(train)}")
    print("\nCleaned class distribution:")
    orig_counts = orig_train["label"].value_counts()
    for label, count in train["label"].value_counts().sort_values().items():
        orig_count = orig_counts.get(label, 0)
        delta = count - orig_count
        print(f"  {label:<22} {count:>4} (was {orig_count}, {'+' if delta >= 0 else ''}{delta})")
| print(f"\nSaved to: {output_dir}") |
| print(f"Use --data-dir {output_dir} in training scripts") |


| if __name__ == "__main__": |
| main() |
|
|