# CFPVesselSeg — train.py
# Imported from commit e99a83c ("add src") by farrell236.
import argparse
from pathlib import Path
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
from augmentations import get_train_transforms, get_val_transforms
from datasets.FIVES import FIVESDataset
from models import build_model
from losses import BCEDiceLoss, compute_dice_score
def train_one_epoch(model, loader, optimizer, scaler, criterion, device, use_amp=True):
    """Run one training epoch and return (mean loss, mean dice) over *loader*.

    Mixed precision (autocast + GradScaler) is active only when *use_amp* is
    True AND the device is CUDA; otherwise autocast is disabled and the scaler
    calls are effectively pass-throughs.

    Args:
        model: segmentation network; called as ``model(images)`` -> logits.
        loader: yields dicts with "image" and "label" tensors.
        optimizer: torch optimizer stepped once per batch.
        scaler: ``torch.amp.GradScaler`` (may be disabled).
        criterion: loss callable taking (logits, labels).
        device: target ``torch.device``.
        use_amp: enable autocast on CUDA.

    Returns:
        Tuple of (epoch-average loss, epoch-average dice) as floats.
    """
    model.train()
    running_loss = 0.0
    running_dice = 0.0
    pbar = tqdm(loader, desc="Train", leave=False)
    # Fix: count steps explicitly instead of reading tqdm's internal counter
    # (``pbar.n``) — when exactly tqdm increments ``n`` relative to the loop
    # body is an implementation detail, so the displayed running averages
    # could silently divide by the wrong step count.
    for step, batch in enumerate(pbar, start=1):
        images = batch["image"].to(device)
        labels = batch["label"].to(device)
        optimizer.zero_grad(set_to_none=True)
        with torch.amp.autocast("cuda", enabled=use_amp and device.type == "cuda"):
            logits = model(images)
            loss = criterion(logits, labels)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        # detach(): dice is a metric only — keep it out of the autograd graph.
        dice = compute_dice_score(logits.detach(), labels)
        running_loss += loss.item()
        # float() normalizes a possible 0-dim tensor return into a plain
        # Python float so the accumulator and return value stay off-device.
        running_dice += float(dice)
        pbar.set_postfix(
            loss=f"{running_loss / step:.4f}",
            dice=f"{running_dice / step:.4f}",
        )
    return running_loss / len(loader), running_dice / len(loader)
@torch.no_grad()
def validate(model, loader, criterion, device, use_amp=True):
    """Evaluate on *loader* and return (mean loss, mean dice); no gradients.

    Mirrors ``train_one_epoch`` minus the optimizer/scaler steps. Autocast is
    active only when *use_amp* is True AND the device is CUDA.

    Returns:
        Tuple of (epoch-average loss, epoch-average dice) as floats.
    """
    model.eval()
    running_loss = 0.0
    running_dice = 0.0
    pbar = tqdm(loader, desc="Val", leave=False)
    # Fix: count steps explicitly instead of reading tqdm's internal counter
    # (``pbar.n``), whose update timing is an implementation detail of tqdm
    # (same fix as in train_one_epoch).
    for step, batch in enumerate(pbar, start=1):
        images = batch["image"].to(device)
        labels = batch["label"].to(device)
        with torch.amp.autocast("cuda", enabled=use_amp and device.type == "cuda"):
            logits = model(images)
            loss = criterion(logits, labels)
        dice = compute_dice_score(logits, labels)
        running_loss += loss.item()
        # float() normalizes a possible 0-dim tensor metric to a Python float.
        running_dice += float(dice)
        pbar.set_postfix(
            loss=f"{running_loss / step:.4f}",
            dice=f"{running_dice / step:.4f}",
        )
    return running_loss / len(loader), running_dice / len(loader)
def save_checkpoint(path, model, optimizer, epoch, best_dice, args):
    """Serialize training state to *path*, creating parent directories.

    The checkpoint is a dict containing the epoch number, model and optimizer
    state dicts, the best validation Dice so far, and the parsed CLI args
    (as a plain dict, via ``vars``).
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    state = {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "best_dice": best_dice,
        "args": vars(args),
    }
    torch.save(state, target)
def main(args: argparse.Namespace) -> None:
    """End-to-end training driver for FIVES vessel segmentation.

    Builds datasets/loaders, model, loss, and optimizer from *args*, trains
    for ``args.epochs`` epochs, and writes checkpoints under
    ``args.output_dir``: ``best.pt`` (highest validation Dice so far),
    periodic ``epoch_XXX.pt`` snapshots, and ``last.pt`` after training.
    """
    # Prefer GPU when available; AMP is only enabled on CUDA downstream.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_dataset = FIVESDataset(
        root=args.data_root,
        split="train",
        transform=get_train_transforms(image_size=args.image_size),
    )
    # NOTE(review): validation reads the "test" split — confirm FIVES has no
    # separate val split, otherwise model selection leaks test data.
    val_dataset = FIVESDataset(
        root=args.data_root,
        split="test",
        transform=get_val_transforms(image_size=args.image_size),
    )
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=True,
    )
    # Single-channel output: binary vessel mask predicted from RGB fundus images.
    model = build_model(
        model_name=args.model,
        num_classes=1,
        in_channels=3,
        image_size=args.image_size,
        backbone=args.backbone,
        pretrained=not args.no_pretrained,
        base_channels=args.base_channels,
        dropout=args.dropout,
    ).to(device)
    criterion = BCEDiceLoss(
        bce_weight=args.bce_weight,
        dice_weight=args.dice_weight,
    )
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.lr,
        weight_decay=args.weight_decay,
    )
    # GradScaler is a no-op when disabled (CPU or --amp not set).
    scaler = torch.amp.GradScaler(enabled=args.amp and device.type == "cuda")
    # Sentinel below any possible Dice so the first epoch always checkpoints.
    best_dice = -1.0
    print(f"Device: {device}")
    print(f"Train samples: {len(train_dataset)}")
    print(f"Val samples: {len(val_dataset)}")
    print(f"Image size: {args.image_size}")
    print(f"Batch size: {args.batch_size}")
    print(f"Pretrained: {not args.no_pretrained}")
    # Epochs are 1-indexed for human-readable logs and checkpoint filenames.
    for epoch in range(1, args.epochs + 1):
        print(f"\nEpoch [{epoch:03d}/{args.epochs}]")
        train_loss, train_dice = train_one_epoch(
            model=model,
            loader=train_loader,
            optimizer=optimizer,
            scaler=scaler,
            criterion=criterion,
            device=device,
            use_amp=args.amp,
        )
        val_loss, val_dice = validate(
            model=model,
            loader=val_loader,
            criterion=criterion,
            device=device,
            use_amp=args.amp,
        )
        print(
            f"train_loss={train_loss:.4f} "
            f"train_dice={train_dice:.4f} "
            f"val_loss={val_loss:.4f} "
            f"val_dice={val_dice:.4f}"
        )
        # Model selection: keep the checkpoint with the best validation Dice.
        if val_dice > best_dice:
            best_dice = val_dice
            save_checkpoint(
                Path(args.output_dir) / "best.pt",
                model,
                optimizer,
                epoch,
                best_dice,
                args,
            )
            print(f"Saved best checkpoint: val_dice={best_dice:.4f}")
        # Periodic snapshot every --save-every epochs (independent of best).
        if epoch % args.save_every == 0:
            save_checkpoint(
                Path(args.output_dir) / f"epoch_{epoch:03d}.pt",
                model,
                optimizer,
                epoch,
                best_dice,
                args,
            )
    # Final state regardless of validation performance (resume/inspection).
    save_checkpoint(
        Path(args.output_dir) / "last.pt",
        model,
        optimizer,
        args.epochs,
        best_dice,
        args,
    )
    print("Training complete.")
    print(f"Best val Dice: {best_dice:.4f}")
def parse_args():
    """Build the CLI parser and return the parsed arguments namespace."""
    p = argparse.ArgumentParser(description="Train retinal vessel segmentation model on FIVES.")
    # Paths
    p.add_argument("--data-root", type=str, required=True)
    p.add_argument("--output-dir", type=str, default="checkpoints/fives")
    # Input pipeline
    p.add_argument("--image-size", type=int, default=512)
    p.add_argument("--epochs", type=int, default=100)
    p.add_argument("--batch-size", type=int, default=4)
    p.add_argument("--num-workers", type=int, default=4)
    # Model architecture
    p.add_argument("--model", type=str, default="resunet", choices=["resunet", "deeplabv3", "vit"])
    p.add_argument("--backbone", type=str, default="resnet50")
    p.add_argument("--base-channels", type=int, default=32)
    p.add_argument("--dropout", type=float, default=0.0)
    p.add_argument("--no-pretrained", action="store_true")
    # Optimization
    p.add_argument("--lr", type=float, default=1e-4)
    p.add_argument("--weight-decay", type=float, default=1e-4)
    p.add_argument("--bce-weight", type=float, default=1.0)
    p.add_argument("--dice-weight", type=float, default=1.0)
    # Checkpointing / precision
    p.add_argument("--save-every", type=int, default=25)
    p.add_argument("--amp", action="store_true")
    return p.parse_args()
if __name__ == "__main__":
    # CLI entry point: parse arguments, then run the full training loop.
    main(parse_args())