import os
from dataclasses import asdict, dataclass

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from sentinelscan.data.dataset import CrackSegDataset
from sentinelscan.data.transforms import train_transforms, val_transforms
from sentinelscan.modeling.unet import UNet
from sentinelscan.modeling.losses import bce_dice_loss, dice_score


@dataclass
class TrainConfig:
    """Settings consumed by :func:`train`."""

    # Image / mask locations handed to CrackSegDataset for each split.
    train_images: str = "data/images/train"
    train_masks: str = "data/masks/train"
    val_images: str = "data/images/val"
    val_masks: str = "data/masks/val"
    # Where the best checkpoint (highest validation Dice) is written.
    out_path: str = "models/best.pt"
    epochs: int = 25
    batch_size: int = 8
    # Learning rate passed to AdamW.
    lr: float = 1e-3
    # Forwarded to train_transforms / val_transforms
    # (presumably the target image side length -- confirm in transforms).
    size: int = 512
    # NOTE: evaluated once, when the class body runs, not per instance.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"


def _train_one_epoch(model, loader, opt, device, desc: str) -> float:
    """Run one optimisation pass over *loader*; return the mean per-batch loss."""
    model.train()
    running_loss = 0.0
    for images, masks in tqdm(loader, desc=desc):
        images = images.to(device, non_blocking=True)
        masks = masks.to(device, non_blocking=True)

        opt.zero_grad(set_to_none=True)
        loss = bce_dice_loss(model(images), masks)
        loss.backward()
        opt.step()

        running_loss += loss.item()
    # max(1, ...) keeps an empty loader from dividing by zero.
    return running_loss / max(1, len(loader))


def _evaluate(model, loader, device, desc: str) -> float:
    """Return the mean of the per-batch Dice scores over *loader* (0.0 if empty)."""
    model.eval()
    dices = []
    with torch.no_grad():
        for images, masks in tqdm(loader, desc=desc):
            images = images.to(device, non_blocking=True)
            masks = masks.to(device, non_blocking=True)
            # float() turns a 0-dim tensor into a plain Python number so the
            # running best score never holds on to a (possibly GPU) tensor.
            # Assumes dice_score returns a scalar per batch -- TODO confirm.
            dices.append(float(dice_score(model(images), masks)))
    return sum(dices) / max(1, len(dices))


def train(cfg: TrainConfig) -> float:
    """Train a UNet and checkpoint the weights with the best validation Dice.

    After every epoch the model is evaluated on the validation split; whenever
    the mean Dice improves, ``{"model_state", "cfg"}`` is saved to
    ``cfg.out_path`` (overwriting the previous best).

    Args:
        cfg: Paths and hyper-parameters for the run.

    Returns:
        The best mean validation Dice observed across all epochs
        (``-1.0`` if ``cfg.epochs`` is less than 1).
    """
    # dirname() is "" for a bare filename, and os.makedirs("") raises
    # FileNotFoundError -- only create a directory when there is one.
    out_dir = os.path.dirname(cfg.out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    train_ds = CrackSegDataset(
        cfg.train_images, cfg.train_masks, transform=train_transforms(cfg.size)
    )
    val_ds = CrackSegDataset(
        cfg.val_images, cfg.val_masks, transform=val_transforms(cfg.size)
    )

    # Pinned host memory only helps host->GPU copies; skip it on CPU runs.
    use_pinned = torch.device(cfg.device).type == "cuda"
    train_loader = DataLoader(
        train_ds,
        batch_size=cfg.batch_size,
        shuffle=True,
        num_workers=2,
        pin_memory=use_pinned,
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=cfg.batch_size,
        shuffle=False,
        num_workers=2,
        pin_memory=use_pinned,
    )

    model = UNet().to(cfg.device)
    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr)

    # Starts below zero so the first evaluated epoch is always saved.
    best_dice = -1.0
    for epoch in range(1, cfg.epochs + 1):
        avg_loss = _train_one_epoch(
            model, train_loader, opt, cfg.device,
            desc=f"Epoch {epoch}/{cfg.epochs} [train]",
        )
        mean_dice = _evaluate(
            model, val_loader, cfg.device,
            desc=f"Epoch {epoch}/{cfg.epochs} [val]",
        )
        print(f"Epoch {epoch}: loss={avg_loss:.4f} val_dice={mean_dice:.4f}")

        if mean_dice > best_dice:
            best_dice = mean_dice
            torch.save(
                {"model_state": model.state_dict(), "cfg": asdict(cfg)},
                cfg.out_path,
            )
            print(f"✅ Saved best model -> {cfg.out_path} (val_dice={best_dice:.4f})")

    return best_dice


if __name__ == "__main__":
    train(TrainConfig())