import os
from dataclasses import dataclass
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from sentinelscan.data.dataset import CrackSegDataset
from sentinelscan.data.transforms import train_transforms, val_transforms
from sentinelscan.modeling.unet import UNet
from sentinelscan.modeling.losses import bce_dice_loss, dice_score
@dataclass
class TrainConfig:
    """Paths and hyper-parameters for one crack-segmentation training run."""

    train_images: str = "data/images/train"  # directory of training images
    train_masks: str = "data/masks/train"    # directory of matching training masks
    val_images: str = "data/images/val"      # directory of validation images
    val_masks: str = "data/masks/val"        # directory of matching validation masks
    out_path: str = "models/best.pt"         # checkpoint destination for the best model
    epochs: int = 25                         # number of full passes over the training set
    batch_size: int = 8                      # samples per optimizer step
    lr: float = 1e-3                         # AdamW learning rate
    size: int = 512                          # image size passed to the transform builders
    # NOTE: evaluated once at import time (class-body default), not per instance;
    # harmless here since CUDA availability does not change mid-process.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
def train(cfg: TrainConfig):
    """Train a UNet crack-segmentation model, checkpointing the best weights.

    Runs ``cfg.epochs`` epochs of BCE+Dice optimization, validates after each
    epoch, and saves the model whose validation Dice score is highest to
    ``cfg.out_path``.

    Args:
        cfg: Training configuration (data paths, hyper-parameters, device).
    """
    os.makedirs(os.path.dirname(cfg.out_path), exist_ok=True)

    train_ds = CrackSegDataset(cfg.train_images, cfg.train_masks, transform=train_transforms(cfg.size))
    val_ds = CrackSegDataset(cfg.val_images, cfg.val_masks, transform=val_transforms(cfg.size))
    train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True, num_workers=2, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=cfg.batch_size, shuffle=False, num_workers=2, pin_memory=True)

    model = UNet().to(cfg.device)
    opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr)

    best_dice = -1.0
    for epoch in range(1, cfg.epochs + 1):
        avg_loss = _train_one_epoch(model, train_loader, opt, cfg, epoch)
        mean_dice = _validate(model, val_loader, cfg, epoch)
        print(f"Epoch {epoch}: loss={avg_loss:.4f} val_dice={mean_dice:.4f}")
        if mean_dice > best_dice:
            best_dice = mean_dice
            # vars(cfg) stores plain built-ins in the checkpoint, so it can be
            # loaded later without importing TrainConfig.
            torch.save({"model_state": model.state_dict(), "cfg": vars(cfg)}, cfg.out_path)
            print(f"✅ Saved best model -> {cfg.out_path} (val_dice={best_dice:.4f})")


def _train_one_epoch(model, loader, opt, cfg, epoch):
    """Run one optimization pass over *loader*; return the mean batch loss."""
    model.train()
    running_loss = 0.0
    for images, masks in tqdm(loader, desc=f"Epoch {epoch}/{cfg.epochs} [train]"):
        images = images.to(cfg.device, non_blocking=True)
        masks = masks.to(cfg.device, non_blocking=True)
        opt.zero_grad(set_to_none=True)
        loss = bce_dice_loss(model(images), masks)
        loss.backward()
        opt.step()
        running_loss += loss.item()
    # max(1, ...) guards against an empty loader (division by zero).
    return running_loss / max(1, len(loader))


def _validate(model, loader, cfg, epoch):
    """Return the mean per-batch Dice score over *loader* (0.0 when empty).

    NOTE(review): this averages batch scores unweighted, so a smaller final
    batch is slightly over-weighted — kept to preserve the original metric.
    """
    model.eval()
    dices = []
    with torch.no_grad():
        for images, masks in tqdm(loader, desc=f"Epoch {epoch}/{cfg.epochs} [val]"):
            images = images.to(cfg.device, non_blocking=True)
            masks = masks.to(cfg.device, non_blocking=True)
            # float() converts each score immediately so device tensors are
            # not accumulated for the whole validation set, and the returned
            # mean is a plain Python float.
            dices.append(float(dice_score(model(images), masks)))
    return sum(dices) / max(1, len(dices))
if __name__ == "__main__":
    # Launch a training run with the default configuration when executed
    # directly as a script.
    train(TrainConfig())