File size: 2,831 Bytes
f698f1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import os
from dataclasses import dataclass
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from sentinelscan.data.dataset import CrackSegDataset
from sentinelscan.data.transforms import train_transforms, val_transforms
from sentinelscan.modeling.unet import UNet
from sentinelscan.modeling.losses import bce_dice_loss, dice_score

@dataclass
class TrainConfig:
    """Settings read by ``train``: data locations, checkpoint path and hyper-parameters."""

    # Image / mask directories handed to CrackSegDataset for the training split.
    train_images: str = "data/images/train"
    train_masks: str = "data/masks/train"
    # Image / mask directories handed to CrackSegDataset for the validation split.
    val_images: str = "data/images/val"
    val_masks: str = "data/masks/val"
    # File the best-scoring checkpoint is written to with torch.save.
    out_path: str = "models/best.pt"
    # Number of passes over the training loader.
    epochs: int = 25
    # Batch size used for both the training and the validation DataLoader.
    batch_size: int = 8
    # Learning rate given to the AdamW optimizer.
    lr: float = 1e-3
    # Passed to train_transforms / val_transforms; presumably the target image
    # side length in pixels -- confirm against the transforms module.
    size: int = 512
    # NOTE: this default is evaluated once, when the class body runs, not per instance.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

def _train_one_epoch(model, loader, opt, device, desc):
    """Run one optimisation pass over ``loader``; return the mean per-batch loss."""
    model.train()
    running_loss = 0.0

    for images, masks in tqdm(loader, desc=desc):
        images = images.to(device, non_blocking=True)
        masks = masks.to(device, non_blocking=True)

        opt.zero_grad(set_to_none=True)
        logits = model(images)
        loss = bce_dice_loss(logits, masks)
        loss.backward()
        opt.step()

        running_loss += loss.item()

    # max(1, ...) keeps an empty loader from dividing by zero.
    return running_loss / max(1, len(loader))


def _validate(model, loader, device, desc):
    """Return the mean of the per-batch ``dice_score`` values over ``loader``."""
    model.eval()
    dices = []
    with torch.no_grad():
        for images, masks in tqdm(loader, desc=desc):
            images = images.to(device, non_blocking=True)
            masks = masks.to(device, non_blocking=True)
            logits = model(images)
            dices.append(dice_score(logits, masks))

    # Unweighted mean over batches; yields 0 when the loader is empty.
    return sum(dices) / max(1, len(dices))


def train(cfg: TrainConfig):
    """Train a ``UNet`` and keep the checkpoint with the best validation dice.

    For each of ``cfg.epochs`` epochs the model is optimised with AdamW on the
    training loader using ``bce_dice_loss``, then scored on the validation
    loader with ``dice_score``. Whenever the mean validation dice improves,
    a dict with the keys ``"model_state"`` and ``"cfg"`` is written to
    ``cfg.out_path`` via ``torch.save``.

    Args:
        cfg: Data locations, output path and hyper-parameters for the run.

    Returns:
        None. Progress is printed and the best checkpoint is written to disk.
    """
    # os.path.dirname() returns "" for a bare filename such as "best.pt", and
    # os.makedirs("") raises FileNotFoundError, so only create a directory
    # when the path actually has one.
    out_dir = os.path.dirname(cfg.out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    train_ds = CrackSegDataset(cfg.train_images, cfg.train_masks, transform=train_transforms(cfg.size))
    val_ds = CrackSegDataset(cfg.val_images, cfg.val_masks, transform=val_transforms(cfg.size))

    # Pinned host memory only helps host-to-accelerator copies; on a CPU run it
    # is wasted work and makes the DataLoader emit a warning.
    pin_memory = not str(cfg.device).startswith("cpu")

    train_loader = DataLoader(
        train_ds, batch_size=cfg.batch_size, shuffle=True, num_workers=2, pin_memory=pin_memory
    )
    val_loader = DataLoader(
        val_ds, batch_size=cfg.batch_size, shuffle=False, num_workers=2, pin_memory=pin_memory
    )

    model = UNet().to(cfg.device)
    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr)

    # Start below any score the first epoch can produce so it is always saved.
    best_dice = -1.0

    for epoch in range(1, cfg.epochs + 1):
        avg_loss = _train_one_epoch(
            model, train_loader, opt, cfg.device, f"Epoch {epoch}/{cfg.epochs} [train]"
        )
        mean_dice = _validate(
            model, val_loader, cfg.device, f"Epoch {epoch}/{cfg.epochs} [val]"
        )

        print(f"Epoch {epoch}: loss={avg_loss:.4f} val_dice={mean_dice:.4f}")

        if mean_dice > best_dice:
            best_dice = mean_dice
            torch.save({"model_state": model.state_dict(), "cfg": cfg.__dict__}, cfg.out_path)
            print(f"✅ Saved best model -> {cfg.out_path} (val_dice={best_dice:.4f})")

if __name__ == "__main__":
    # Script entry point: run training with the default configuration.
    train(TrainConfig())