File size: 2,477 Bytes
596aaa1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
"""
prepare_data.py  —  organise raw CBIS-DDSM images into train/val folder structure.

If your downloaded images are already in  data/train/benign  etc., skip this.

Usage
-----
python prepare_data.py --images /path/to/raw/images --csv /path/to/labels.csv

CSV must have columns:  file_path, pathology
  pathology values:  BENIGN, MALIGNANT  (or benign, malignant)

Output
------
data/
  train/benign/  train/malignant/
  val/benign/    val/malignant/
"""

import argparse
import os
import shutil
import random

# Share of the shuffled, labelled records that prepare() places in the "train"
# split (the count is truncated to an int); the remaining records go to "val".
TRAIN_RATIO: float = 0.85


def prepare(
    images_dir: str,
    csv_path: str,
    output_dir: str,
    seed: int = 42,
    train_ratio: "float | None" = None,
) -> None:
    """Copy labelled images into ``<output_dir>/{train,val}/{benign,malignant}/``.

    Each CSV row's ``file_path`` is joined onto *images_dir*. The label is
    taken from the ``pathology`` column, falling back to ``label``:
    ``BENIGN`` / ``BENIGN_WITHOUT_CALLBACK`` map to ``benign`` and
    ``MALIGNANT`` maps to ``malignant`` (case-insensitive). Rows with any
    other label, or whose image file does not exist, are skipped. A repeated
    (image, label) pair is copied only once.

    The surviving records are shuffled with a private ``random.Random(seed)``,
    so the global ``random`` state is left untouched. The first
    ``int(n * train_ratio)`` records go to ``train`` and the rest to ``val``.

    Source images that share a basename get a ``_<n>`` suffix in the
    destination folder instead of overwriting each other.

    Args:
        images_dir: Directory that the CSV ``file_path`` entries are relative to.
        csv_path: Path to the labels CSV (UTF-8; a leading BOM is tolerated).
        output_dir: Root directory to create and populate. Files already in
            it are not removed, so a re-run with a different seed can leave
            stale images from an earlier split.
        seed: Seed for the train/val shuffle.
        train_ratio: Fraction of records assigned to ``train``, in [0, 1].
            Defaults to the module-level ``TRAIN_RATIO``.

    Raises:
        ValueError: If *train_ratio* is outside [0, 1].
    """
    import csv
    from collections import Counter

    ratio = TRAIN_RATIO if train_ratio is None else train_ratio
    if not 0.0 <= ratio <= 1.0:
        raise ValueError(f"train_ratio must be between 0 and 1, got {ratio!r}")

    records: list[tuple[str, str]] = []
    seen = set()
    missing = 0
    # utf-8-sig strips a leading BOM. Without it the first header would read
    # "\ufefffile_path" and no row would ever match the "file_path" key.
    with open(csv_path, newline="", encoding="utf-8-sig") as f:
        for row in csv.DictReader(f):
            # DictReader fills short rows with None, so guard before .strip().
            raw = (row.get("pathology") or row.get("label") or "").strip().lower()
            if raw in ("benign", "benign_without_callback"):
                label = "benign"
            elif raw == "malignant":
                label = "malignant"
            else:
                continue  # skip unknown labels

            img_path = os.path.join(images_dir, (row.get("file_path") or "").strip())
            if not os.path.isfile(img_path):
                missing += 1
                continue
            record = (img_path, label)
            if record in seen:
                continue  # the CSV may list the same image on several rows
            seen.add(record)
            records.append(record)

    print(f"Found {len(records)} labelled images")
    if missing:
        print(f"Skipped {missing} labelled rows whose image file was not found")

    random.Random(seed).shuffle(records)

    cut = int(len(records) * ratio)
    splits = {"train": records[:cut], "val": records[cut:]}

    for split_name, items in splits.items():
        # Destination names already used in this run, per label folder.
        taken = {}
        for label in ("benign", "malignant"):
            os.makedirs(os.path.join(output_dir, split_name, label), exist_ok=True)
            taken[label] = set()
        for src, label in items:
            fname = os.path.basename(src)
            stem, ext = os.path.splitext(fname)
            n = 1
            # Different source folders can hold files with the same basename;
            # without a suffix the later copy would overwrite the earlier one.
            # Compare case-insensitively so "A.dcm" and "a.dcm" do not clobber
            # each other on case-insensitive filesystems.
            while fname.lower() in taken[label]:
                fname = f"{stem}_{n}{ext}"
                n += 1
            taken[label].add(fname.lower())
            shutil.copy2(src, os.path.join(output_dir, split_name, label, fname))
        tally = Counter(label for _, label in items)
        counts = {lbl: tally[lbl] for lbl in ("benign", "malignant")}
        print(f"{split_name}: {counts}")

    print(f"Data prepared in {output_dir}/")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--images", required=True, help="Directory containing raw image files")
    parser.add_argument("--csv", required=True, help="CSV file with file_path and pathology columns")
    parser.add_argument("--output", default="data", help="Output directory")
    args = parser.parse_args()
    prepare(args.images, args.csv, args.output)