File size: 3,927 Bytes
c5b096b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
"""DiaFoot.AI v2 — Data Composition Ablation.

The most important experiment: prove that adding healthy + non-DFU data helps.

Trains 3 segmentation models:
    (a) DFU-only: Train only on DFU images
    (b) DFU + non-DFU: Train on DFU + non-DFU (current best)
    (c) All: Train on all three classes (including healthy)

Usage:
    python scripts/run_ablation.py --variant dfu_only --device cuda --epochs 50
    python scripts/run_ablation.py --variant dfu_nondfu --device cuda --epochs 50
    python scripts/run_ablation.py --variant all --device cuda --epochs 50
"""

from __future__ import annotations

import argparse
import logging
import sys
from pathlib import Path

import torch

sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from scripts.train import build_dataloaders
from src.models.unetpp import build_unetpp
from src.training.losses import DiceCELoss
from src.training.schedulers import CosineAnnealingWithWarmup
from src.training.trainer import TrainConfig, Trainer

# Per-variant settings, keyed by the value accepted for ``--variant``.
#   classes        -- class names forwarded to ``build_dataloaders`` as
#                     ``filter_classes``; ``None`` means no filter (all classes)
#   checkpoint_dir -- directory handed to ``TrainConfig`` for this variant
#   description    -- human-readable summary logged at start-up
ABLATION_CONFIGS = {
    variant: {
        "classes": classes,
        "checkpoint_dir": checkpoint_dir,
        "description": description,
    }
    for variant, classes, checkpoint_dir, description in (
        (
            "dfu_only",
            ["dfu"],
            "checkpoints/ablation_dfu_only",
            "DFU images only (no negatives)",
        ),
        (
            "dfu_nondfu",
            ["dfu", "non_dfu"],
            "checkpoints/ablation_dfu_nondfu",
            "DFU + non-DFU wounds (current approach)",
        ),
        (
            "all",
            None,
            "checkpoints/ablation_all",
            "All classes including healthy",
        ),
    )
}


def main() -> None:
    """Run one variant of the data composition ablation.

    Parses the command-line arguments, looks up the selected variant in
    ``ABLATION_CONFIGS``, builds the UNet++ model and the dataloaders
    (restricted to the variant's classes), and trains via ``Trainer.fit``.

    The global torch RNG is seeded before the model or the dataloaders are
    created, so any random draws made while constructing them start from the
    same state for every variant. Seeding only after construction would leave
    those draws unseeded and make the variants harder to compare.
    """
    parser = argparse.ArgumentParser(description="Data Composition Ablation")
    parser.add_argument(
        "--variant",
        type=str,
        required=True,
        choices=list(ABLATION_CONFIGS),
    )
    parser.add_argument("--splits-dir", type=str, default="data/splits")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument("--batch-size", type=int, default=16)
    parser.add_argument("--num-workers", type=int, default=8)
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Seed for the global torch RNG (use the same value for all variants).",
    )
    parser.add_argument("--verbose", action="store_true")
    args = parser.parse_args()

    logging.basicConfig(
        level=logging.DEBUG if args.verbose else logging.INFO,
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%H:%M:%S",
    )
    logger = logging.getLogger("ablation")

    config = ABLATION_CONFIGS[args.variant]
    logger.info("Ablation: %s — %s", args.variant, config["description"])

    # Seed first: everything below that draws from the torch RNG (model
    # construction, dataloader iteration, training) must start from a known
    # state for the variants to be comparable.
    torch.manual_seed(args.seed)
    logger.info("Seed: %d", args.seed)

    model = build_unetpp(
        encoder_name="efficientnet-b4",
        encoder_weights="imagenet",
        classes=1,
        decoder_attention_type="scse",
    )

    # ``filter_classes`` is the only thing that differs between variants;
    # ``None`` (the "all" variant) requests no class filtering.
    train_loader, val_loader = build_dataloaders(
        args.splits_dir,
        args.batch_size,
        args.num_workers,
        filter_classes=config["classes"],
    )
    logger.info(
        "Data: %d train, %d val batches",
        len(train_loader),
        len(val_loader),
    )

    loss_fn = DiceCELoss()
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=1e-4,
        weight_decay=1e-2,
    )
    scheduler = CosineAnnealingWithWarmup(
        optimizer,
        warmup_epochs=5,
        max_epochs=args.epochs,
    )

    # Checkpoint selection tracks the lowest validation loss; each variant
    # writes to its own checkpoint directory.
    trainer_config = TrainConfig(
        epochs=args.epochs,
        precision="bf16-mixed",
        compile_model=False,
        gradient_clip=1.0,
        checkpoint_dir=config["checkpoint_dir"],
        monitor_metric="val/loss",
        monitor_mode="min",
        device=args.device,
        early_stopping_patience=15,
    )

    trainer = Trainer(model=model, config=trainer_config)
    trainer.fit(train_loader, val_loader, loss_fn, optimizer, scheduler)
    logger.info("Ablation %s complete.", args.variant)


# Script entry point: run the ablation only when executed directly.
if __name__ == "__main__":
    main()