import argparse
import os
from dataclasses import dataclass, fields

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from sentinelscan.data.dataset import CrackSegDataset
from sentinelscan.data.transforms import train_transforms, val_transforms
from sentinelscan.modeling.losses import bce_dice_loss, dice_score
from sentinelscan.modeling.unet import UNet
|
|
@dataclass
class TrainConfig:
    """Hyper-parameters and file locations for a single training run.

    Every field has a default, so ``TrainConfig()`` reproduces the stock
    setup; override individual fields by keyword.
    """

    # Image / mask directories for the training and validation splits.
    train_images: str = "data/images/train"
    train_masks: str = "data/masks/train"
    val_images: str = "data/images/val"
    val_masks: str = "data/masks/val"

    # Destination of the best checkpoint (selected by validation dice).
    out_path: str = "models/best.pt"

    # Optimization schedule.
    epochs: int = 25
    batch_size: int = 8
    lr: float = 0.001

    # Forwarded to train_transforms / val_transforms; presumably the
    # resize/crop edge length in pixels — TODO confirm in transforms module.
    size: int = 512

    # Evaluated once, when the class body executes (i.e. at import time).
    device: str = "cpu" if not torch.cuda.is_available() else "cuda"
|
|
def _train_one_epoch(
    model: torch.nn.Module,
    loader: DataLoader,
    opt: torch.optim.Optimizer,
    device: str,
    desc: str,
) -> float:
    """Run one optimization pass over *loader*; return the mean per-batch loss.

    Each batch is moved to *device*, scored with ``bce_dice_loss`` and used for
    one ``opt`` step. Returns 0.0 when *loader* yields no batches.
    """
    model.train()
    running_loss = 0.0
    for images, masks in tqdm(loader, desc=desc):
        images = images.to(device, non_blocking=True)
        masks = masks.to(device, non_blocking=True)

        opt.zero_grad(set_to_none=True)
        logits = model(images)
        loss = bce_dice_loss(logits, masks)
        loss.backward()
        opt.step()

        running_loss += loss.item()
    # max(1, …) guards against division by zero on an empty loader.
    return running_loss / max(1, len(loader))


@torch.no_grad()
def _evaluate(model: torch.nn.Module, loader: DataLoader, device: str, desc: str) -> float:
    """Return the mean per-batch ``dice_score`` of *model* over *loader* (0.0 if empty)."""
    model.eval()
    dices = []
    for images, masks in tqdm(loader, desc=desc):
        images = images.to(device, non_blocking=True)
        masks = masks.to(device, non_blocking=True)
        logits = model(images)
        # float() accepts both a Python number and a 0-dim tensor; it avoids
        # accumulating device tensors and keeps the result a plain float.
        dices.append(float(dice_score(logits, masks)))
    # NOTE(review): this is an unweighted mean over batches, so a smaller final
    # batch counts as much as a full one — confirm that is acceptable.
    return sum(dices) / max(1, len(dices))


def train(cfg: TrainConfig):
    """Train a UNet segmenter and checkpoint the best model by validation dice.

    Builds ``CrackSegDataset`` loaders for the train/val paths in *cfg*,
    optimizes ``UNet`` with AdamW on ``bce_dice_loss`` for ``cfg.epochs``
    epochs, and evaluates mean ``dice_score`` on the validation split after
    each epoch. Whenever validation dice improves, the model ``state_dict``
    and the config fields are saved to ``cfg.out_path``.

    Args:
        cfg: Data paths, hyper-parameters and target device.

    Returns:
        The best mean validation dice seen (``-1.0`` if ``cfg.epochs < 1``).
    """
    out_dir = os.path.dirname(cfg.out_path)
    if out_dir:  # os.makedirs("") raises, e.g. for out_path="best.pt"
        os.makedirs(out_dir, exist_ok=True)

    train_ds = CrackSegDataset(cfg.train_images, cfg.train_masks, transform=train_transforms(cfg.size))
    val_ds = CrackSegDataset(cfg.val_images, cfg.val_masks, transform=val_transforms(cfg.size))

    # Pinned host memory only speeds up host->CUDA copies; on CPU it just warns.
    pin = str(cfg.device).startswith("cuda")
    train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True, num_workers=2, pin_memory=pin)
    val_loader = DataLoader(val_ds, batch_size=cfg.batch_size, shuffle=False, num_workers=2, pin_memory=pin)

    model = UNet().to(cfg.device)
    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr)

    best_dice = -1.0  # below any real dice, so the first epoch always saves
    for epoch in range(1, cfg.epochs + 1):
        tag = f"Epoch {epoch}/{cfg.epochs}"
        avg_loss = _train_one_epoch(model, train_loader, opt, cfg.device, f"{tag} [train]")
        mean_dice = _evaluate(model, val_loader, cfg.device, f"{tag} [val]")

        print(f"Epoch {epoch}: loss={avg_loss:.4f} val_dice={mean_dice:.4f}")

        if mean_dice > best_dice:
            best_dice = mean_dice
            torch.save({"model_state": model.state_dict(), "cfg": cfg.__dict__}, cfg.out_path)
            print(f"✅ Saved best model -> {cfg.out_path} (val_dice={best_dice:.4f})")

    return best_dice
|
|
if __name__ == "__main__":
    # Expose every TrainConfig field as an optional flag (e.g. --epochs 50,
    # --out-path runs/best.pt). Omitted flags keep the dataclass defaults, so
    # running with no arguments behaves exactly like TrainConfig().
    parser = argparse.ArgumentParser(description="Train the UNet crack-segmentation model.")
    for fld in fields(TrainConfig):
        # NOTE: type(default) is correct for the current str/int/float fields;
        # a future bool field would need store_true / BooleanOptionalAction.
        parser.add_argument(
            f"--{fld.name.replace('_', '-')}",
            dest=fld.name,
            type=type(fld.default),
            default=fld.default,
            help="default: %(default)s",
        )
    cfg = TrainConfig(**vars(parser.parse_args()))
    train(cfg)