scribble-segmentation / train_global_unet.py
Enorenio's picture
Add training/model definitions (comments stripped)
38429f5 verified
import os
import sys
import json
import math
import random
import argparse
from pathlib import Path
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
# Dataset layout: <DATA_ROOT>/<split>/{images,scribbles,ground_truth|predictions}.
DATA_ROOT = Path('dataset')
TRAIN_IMG = DATA_ROOT / 'train' / 'images'
TRAIN_SC = DATA_ROOT / 'train' / 'scribbles'
TRAIN_GT = DATA_ROOT / 'train' / 'ground_truth'
TEST_IMG = DATA_ROOT / 'test1' / 'images'
TEST_SC = DATA_ROOT / 'test1' / 'scribbles'
TEST_PRED = DATA_ROOT / 'test1' / 'predictions'
# Network input resolution (height, width); overridable through the environment.
TRAIN_H = int(os.environ.get('TRAIN_H', '384'))
TRAIN_W = int(os.environ.get('TRAIN_W', '512'))
# Full-resolution mask size (height, width).
ORIG_H, ORIG_W = (375, 500)
# Checkpoint root; each trained fold gets a fold_<k>/ sub-directory.
CKPT_DIR = Path(os.environ.get('CKPT_DIR', 'runs_global_unet'))
# parents=True so a nested CKPT_DIR (e.g. runs/exp1) does not fail at import time.
CKPT_DIR.mkdir(parents=True, exist_ok=True)
def list_train_pairs():
    """Collect labelled training samples.

    Walks ``TRAIN_IMG`` in sorted order, ignoring hidden entries (names that
    start with '.'), and keeps an image only when both ``TRAIN_SC/<stem>.png``
    and ``TRAIN_GT/<stem>.png`` exist.

    Returns:
        List of ``(stem, image_path, scribble_path, gt_path)`` tuples.
    """
    def _candidates():
        for image_file in sorted(TRAIN_IMG.iterdir()):
            if image_file.name.startswith('.'):
                continue
            key = image_file.stem
            yield key, image_file, TRAIN_SC / f'{key}.png', TRAIN_GT / f'{key}.png'

    return [entry for entry in _candidates() if entry[2].exists() and entry[3].exists()]
def list_test_pairs():
    """Collect test1 samples that have a scribble.

    Walks ``TEST_IMG`` in sorted order, ignoring hidden entries (names that
    start with '.'), and keeps an image only when ``TEST_SC/<stem>.png`` exists.

    Returns:
        List of ``(stem, image_path, scribble_path)`` tuples.
    """
    result = []
    visible = (entry for entry in sorted(TEST_IMG.iterdir()) if not entry.name.startswith('.'))
    for image_file in visible:
        scribble_file = TEST_SC / f'{image_file.stem}.png'
        if not scribble_file.exists():
            continue
        result.append((image_file.stem, image_file, scribble_file))
    return result
def list_pseudo_pairs(pseudo_label_method='v3v4', data_root=None):
    """Collect pseudo-labelled samples from the test1 and test2 splits.

    For each split, the predictions in ``<root>/<split>/predictions_<method>``
    act as labels. A split is skipped when that directory, or its ``images``
    directory, is missing. Hidden files (leading '.') and images lacking a
    scribble or a prediction PNG are skipped.

    Args:
        pseudo_label_method: suffix of the predictions directory to read.
        data_root: dataset root; defaults to ``DATA_ROOT`` (the other listers
            use it too, rather than a hard-coded 'dataset' path).

    Returns:
        List of ``(stem, image_path, scribble_path, label_path)`` tuples,
        test1 entries first, each split in sorted filename order.
    """
    root = DATA_ROOT if data_root is None else Path(data_root)
    pairs = []
    for setname in ('test1', 'test2'):
        split_dir = root / setname
        img_dir = split_dir / 'images'
        sc_dir = split_dir / 'scribbles'
        gt_dir = split_dir / f'predictions_{pseudo_label_method}'
        # Without the images dir, iterdir() below would raise.
        if not (gt_dir.exists() and img_dir.is_dir()):
            continue
        for ip in sorted(img_dir.iterdir()):
            if ip.name.startswith('.'):
                continue
            stem = ip.stem
            sp = sc_dir / f'{stem}.png'
            gp = gt_dir / f'{stem}.png'
            if sp.exists() and gp.exists():
                pairs.append((stem, ip, sp, gp))
    return pairs
def load_palette(gt_dir=None):
    """Return the colour palette of a ground-truth PNG, used to write masks.

    Args:
        gt_dir: directory searched for ``*.png`` masks; defaults to ``TRAIN_GT``.
            The alphabetically first file is used, so the choice is deterministic.

    Returns:
        The value of ``getpalette()`` for that PNG.

    Raises:
        FileNotFoundError: if the directory holds no ``*.png`` file. Previously
            this surfaced as a bare ``StopIteration`` from ``next()``.
    """
    search_dir = TRAIN_GT if gt_dir is None else Path(gt_dir)
    candidates = sorted(search_dir.glob('*.png'))
    if not candidates:
        raise FileNotFoundError(f'No *.png ground-truth masks found in {search_dir}')
    with Image.open(candidates[0]) as im:
        return im.getpalette()
def encode_scribble(sc):
    """One-hot encode a scribble map into two float32 planes.

    Plane 0 is 1.0 where ``sc == 0`` (background scribble), plane 1 is 1.0
    where ``sc == 1`` (foreground scribble). Any other value, e.g. 255, is 0.0
    in both planes.

    Returns:
        float32 array of shape ``(2, *sc.shape)``.
    """
    return np.stack((sc == 0, sc == 1), axis=0).astype(np.float32)
def random_affine(img, sc, gt, rng):
    """Warp image, scribble and label with one shared random affine transform.

    Draws, in this order: a rotation angle in [-12, 12] degrees, an isotropic
    scale in [0.85, 1.2], and x/y shifts of up to 5% of the width/height.
    Rotation and scaling are about the image centre.

    Args:
        img: image array (H, W[, C]); bilinear warp with reflected borders.
        sc: scribble map (H, W); nearest-neighbour, uncovered pixels become 255.
        gt: label map (H, W); nearest-neighbour, uncovered pixels become 0.
        rng: object providing ``uniform(a, b)`` (e.g. ``random.Random``).

    Returns:
        ``(img, sc, gt)`` warped to the original ``(H, W)``.
    """
    import cv2
    H, W = img.shape[:2]
    angle = rng.uniform(-12, 12)
    scale = rng.uniform(0.85, 1.2)
    tx = rng.uniform(-0.05, 0.05) * W
    ty = rng.uniform(-0.05, 0.05) * H
    theta = math.radians(angle)
    c = math.cos(theta) * scale
    s = math.sin(theta) * scale
    cx = W / 2
    cy = H / 2
    matrix = np.array(
        [
            [c, -s, (1 - c) * cx + s * cy + tx],
            [s, c, (1 - c) * cy - s * cx + ty],
        ],
        dtype=np.float32,
    )
    out_size = (W, H)
    warped_img = cv2.warpAffine(img, matrix, out_size, flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT_101)
    warped_sc = cv2.warpAffine(sc, matrix, out_size, flags=cv2.INTER_NEAREST, borderMode=cv2.BORDER_CONSTANT, borderValue=255)
    warped_gt = cv2.warpAffine(gt, matrix, out_size, flags=cv2.INTER_NEAREST, borderMode=cv2.BORDER_CONSTANT, borderValue=0)
    return (warped_img, warped_sc, warped_gt)
def color_jitter(img, rng):
    """Randomly perturb brightness, contrast and (with p=0.7) saturation.

    Args:
        img: uint8 array with channels on the last axis (H, W, C).
        rng: object providing ``uniform(a, b)`` and ``random()``
            (e.g. ``random.Random``).

    Returns:
        A new uint8 array of the same shape.
    """
    img_f = img.astype(np.float32) / 255.0
    # Brightness: one global multiplicative gain.
    img_f = img_f * rng.uniform(0.8, 1.2)
    # Contrast: scale each channel's deviation from its spatial mean.
    mean = img_f.mean(axis=(0, 1), keepdims=True)
    img_f = (img_f - mean) * rng.uniform(0.8, 1.2) + mean
    if rng.random() < 0.7:
        # Saturation: blend with the per-pixel grey value using ONE factor
        # (0 -> grey, 1 -> unchanged). Drawing separate factors for the two
        # terms, as before, was not a blend and also changed brightness.
        gray = img_f.mean(axis=2, keepdims=True)
        sat = rng.uniform(0.7, 1.3)
        img_f = img_f * sat + gray * (1 - sat)
    img_f = np.clip(img_f, 0, 1)
    # Round, not truncate, when converting back to uint8.
    return np.rint(img_f * 255).astype(np.uint8)
class ScribbleSegDataset(Dataset):
    """Scribble-conditioned binary segmentation samples.

    Each item is ``(x, y, stem)``:

    * ``x``: float32 tensor ``(5, H, W)``. The first 3 channels are
      ImageNet-normalised RGB; the last 2 are the background/foreground
      scribble planes from ``encode_scribble``.
    * ``y``: float32 tensor ``(H, W)``, 1.0 where the label map is > 0.
    * ``stem``: the sample's file stem.

    Args:
        pairs: sequence of ``(stem, image_path, scribble_path, label_path)``.
        train: when true, apply random flip, affine warp, colour jitter,
            scribble dropout and (optionally) CutMix.
        image_size: ``(H, W)`` every sample is resized to.
        cutmix_p: per-sample probability of pasting a random rectangle from
            another sample (training only).
    """

    def __init__(self, pairs, train=True, image_size=(TRAIN_H, TRAIN_W), cutmix_p=0.0):
        self.pairs = pairs
        self.train = train
        self.H, self.W = image_size
        self.cutmix_p = cutmix_p
        # ImageNet per-channel statistics for RGB normalisation.
        self.mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
        self.std = np.array([0.229, 0.224, 0.225], dtype=np.float32)

    def __len__(self):
        return len(self.pairs)

    def _load_one(self, idx):
        """Load sample ``idx`` as arrays resized to ``(self.H, self.W)``."""
        import cv2
        stem, img_p, sc_p, gt_p = self.pairs[idx]
        img = np.array(Image.open(img_p).convert('RGB'))
        sc = np.array(Image.open(sc_p).convert('L'))
        gt = np.array(Image.open(gt_p))
        # Only the image size is checked; scribble and label are assumed to match it.
        if img.shape[:2] != (self.H, self.W):
            img = cv2.resize(img, (self.W, self.H), interpolation=cv2.INTER_LINEAR)
            sc = cv2.resize(sc, (self.W, self.H), interpolation=cv2.INTER_NEAREST)
            gt = cv2.resize(gt, (self.W, self.H), interpolation=cv2.INTER_NEAREST)
        return (stem, img, sc, gt)

    def _augment(self, img, sc, gt, rng):
        """Return randomly augmented ``(img, sc, gt)``; all draws come from ``rng``."""
        if rng.random() < 0.5:
            # Horizontal flip.
            img = img[:, ::-1, :].copy()
            sc = sc[:, ::-1].copy()
            gt = gt[:, ::-1].copy()
        img, sc, gt = random_affine(img, sc, gt, rng)
        img = color_jitter(img, rng)
        if rng.random() < 0.3:
            # Scribble dropout: reset ~30% of scribble pixels to 255, a value
            # encode_scribble maps to neither plane. The mask comes from a
            # generator seeded by rng, so it is reproducible too.
            mask_rng = np.random.default_rng(rng.getrandbits(64))
            drop_mask = (sc != 255) & (mask_rng.random(sc.shape) < 0.3)
            sc = sc.copy()
            sc[drop_mask] = 255
        if self.cutmix_p > 0 and rng.random() < self.cutmix_p:
            # CutMix: paste one rectangle from another (un-augmented) sample
            # into image, scribble and label alike.
            j = rng.randint(0, len(self.pairs) - 1)
            _, img2, sc2, gt2 = self._load_one(j)
            rh = rng.randint(int(0.3 * self.H), int(0.6 * self.H))
            rw = rng.randint(int(0.3 * self.W), int(0.6 * self.W))
            ry = rng.randint(0, self.H - rh)
            rx = rng.randint(0, self.W - rw)
            img = img.copy()
            sc = sc.copy()
            gt = gt.copy()
            img[ry:ry + rh, rx:rx + rw] = img2[ry:ry + rh, rx:rx + rw]
            sc[ry:ry + rh, rx:rx + rw] = sc2[ry:ry + rh, rx:rx + rw]
            gt[ry:ry + rh, rx:rx + rw] = gt2[ry:ry + rh, rx:rx + rw]
        return (img, sc, gt)

    def __getitem__(self, idx):
        stem, img, sc, gt = self._load_one(idx)
        if self.train:
            # Seed the per-sample RNG from the global `random` state. The old
            # bare `random.Random()` used OS entropy, so augmentation ignored
            # set_seed(); the DataLoader reseeds `random` in each worker.
            rng = random.Random(random.getrandbits(64))
            img, sc, gt = self._augment(img, sc, gt, rng)
        img_f = img.astype(np.float32) / 255.0
        img_f = (img_f - self.mean) / self.std
        img_t = torch.from_numpy(img_f.transpose(2, 0, 1))
        sc_t = torch.from_numpy(encode_scribble(sc))
        x = torch.cat([img_t, sc_t], dim=0)
        y = torch.from_numpy((gt > 0).astype(np.float32))
        return (x, y, stem)
class ConvBlock(nn.Module):
    """Two stacked ``3x3 conv -> BatchNorm -> ReLU`` stages.

    ``padding=1`` keeps the spatial size unchanged. The layers are held in
    ``self.block`` (an ``nn.Sequential``), so state-dict keys are
    ``block.0`` ... ``block.5``; saved checkpoints rely on this layout.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = []
        for stage_in in (in_ch, out_ch):
            layers += [
                nn.Conv2d(stage_in, out_ch, 3, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)
class UNet(nn.Module):
    """Four-level U-Net with bilinear upsampling that outputs per-pixel logits.

    Args:
        in_ch: input channels.
        base: width of the first level; the levels use base, 2x, 4x, 8x and 16x.
        out_ch: number of output logit channels.

    Any input height/width is accepted. Each decoder stage upsamples by 2 and,
    when max-pool flooring left the skip connection a different size (sides
    not divisible by 16), resizes to the skip size before concatenating.
    Previously that case crashed in ``torch.cat``. For divisible sizes the
    result is unchanged, and the module has no extra parameters, so existing
    checkpoints still load.
    """

    def __init__(self, in_ch=5, base=48, out_ch=1):
        super().__init__()
        c1, c2, c3, c4, c5 = (base, base * 2, base * 4, base * 8, base * 16)
        self.enc1 = ConvBlock(in_ch, c1)
        self.enc2 = ConvBlock(c1, c2)
        self.enc3 = ConvBlock(c2, c3)
        self.enc4 = ConvBlock(c3, c4)
        self.bottleneck = ConvBlock(c4, c5)
        self.pool = nn.MaxPool2d(2)
        self.up4 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.dec4 = ConvBlock(c5 + c4, c4)
        self.up3 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.dec3 = ConvBlock(c4 + c3, c3)
        self.up2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.dec2 = ConvBlock(c3 + c2, c2)
        self.up1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.dec1 = ConvBlock(c2 + c1, c1)
        self.head = nn.Conv2d(c1, out_ch, 1)
        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    @staticmethod
    def _up_cat(up, deep, skip):
        """Upsample ``deep`` with ``up``, match ``skip``'s size if needed, concat on channels."""
        upsampled = up(deep)
        if upsampled.shape[-2:] != skip.shape[-2:]:
            upsampled = F.interpolate(upsampled, size=skip.shape[-2:], mode='bilinear', align_corners=False)
        return torch.cat([upsampled, skip], 1)

    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))
        b = self.bottleneck(self.pool(e4))
        d4 = self.dec4(self._up_cat(self.up4, b, e4))
        d3 = self.dec3(self._up_cat(self.up3, d4, e3))
        d2 = self.dec2(self._up_cat(self.up2, d3, e2))
        d1 = self.dec1(self._up_cat(self.up1, d2, e1))
        return self.head(d1)
def soft_dice_loss(logits, target, eps=1e-06):
    """Soft Dice loss averaged over the batch.

    Args:
        logits: raw scores with a singleton channel at dim 1, i.e. ``(B, 1, H, W)``.
        target: ``(B, H, W)`` tensor of 0/1 labels.
        eps: smoothing term added to numerator and denominator.

    Returns:
        Scalar tensor ``1 - mean_b(dice_b)``.
    """
    probs = logits.sigmoid().squeeze(1)
    spatial = (1, 2)
    overlap = torch.sum(probs * target, dim=spatial)
    total = torch.sum(probs, dim=spatial) + torch.sum(target, dim=spatial)
    per_sample = (2 * overlap + eps) / (total + eps)
    return 1 - per_sample.mean()
def combined_loss(logits, target):
    """Equal-weight mix of BCE-with-logits and soft Dice loss.

    ``logits`` carries a singleton channel at dim 1, which is squeezed so the
    BCE term lines up with ``target``.
    """
    weight = 0.5
    bce_term = F.binary_cross_entropy_with_logits(logits.squeeze(1), target)
    dice_term = soft_dice_loss(logits, target)
    return weight * bce_term + weight * dice_term
def compute_iou(pred_bin, gt_bin, cls):
    """Intersection-over-union of label ``cls`` between two label maps.

    Returns 0.0 when ``cls`` appears in neither map (empty union).
    """
    pred_mask = np.equal(pred_bin, cls)
    gt_mask = np.equal(gt_bin, cls)
    union = (pred_mask | gt_mask).sum()
    if not union > 0:
        return 0.0
    return (pred_mask & gt_mask).sum() / union
def evaluate_predictions(preds, gts):
    """Mean per-image IoU for background (0), foreground (1) and their average.

    Args:
        preds: label maps, paired element-wise with ``gts`` (as with ``zip``).
        gts: ground-truth label maps.

    Returns:
        ``(bg_iou, fg_iou, miou)``. Empty inputs yield NaN from ``np.mean``.
    """
    matched = list(zip(preds, gts))
    bg_iou = np.mean([compute_iou(p, g, 0) for p, g in matched])
    fg_iou = np.mean([compute_iou(p, g, 1) for p, g in matched])
    return (bg_iou, fg_iou, (bg_iou + fg_iou) / 2)
def _train_epoch(model, loader, opt, scaler, device, use_amp):
    """Run one optimisation pass over ``loader``; return the per-sample mean loss."""
    model.train()
    loss_sum = 0.0
    seen = 0
    for x, y, _ in loader:
        x, y = (x.to(device, non_blocking=True), y.to(device, non_blocking=True))
        opt.zero_grad(set_to_none=True)
        with torch.amp.autocast('cuda', dtype=torch.float16, enabled=use_amp):
            logits = model(x)
            loss = combined_loss(logits, y)
        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()
        loss_sum += loss.item() * x.size(0)
        seen += x.size(0)
    return loss_sum / seen


def _validate(model, loader, device, use_amp):
    """Return ``(bg_iou, fg_iou, miou)`` of 0.5-thresholded predictions on ``loader``."""
    model.eval()
    all_p, all_g = ([], [])
    with torch.no_grad():
        for x, y, _ in loader:
            x = x.to(device, non_blocking=True)
            with torch.amp.autocast('cuda', dtype=torch.float16, enabled=use_amp):
                logits = model(x)
            p = (torch.sigmoid(logits).squeeze(1).float().cpu().numpy() > 0.5).astype(np.uint8)
            g = y.numpy().astype(np.uint8)
            all_p.extend(p)
            all_g.extend(g)
    return evaluate_predictions(all_p, all_g)


def train_one_fold(train_pairs, val_pairs, epochs, batch_size, lr, fold_id, device, base=48, cutmix_p=0.0):
    """Train one U-Net and keep the weights with the best validation mIoU.

    Uses AdamW with cosine LR decay (to lr/30) and early stopping after 25
    epochs without improvement. Mixed precision (fp16 autocast + GradScaler)
    and pinned memory are used only on CUDA; previously they were hard-wired
    to CUDA, which misconfigured the CPU fallback.

    The best state dict goes to ``CKPT_DIR/fold_<fold_id>/best.pth`` and the
    per-epoch metrics to ``log.json`` beside it.

    Returns:
        The best validation mIoU (``-1.0`` if no epoch ran).

    Raises:
        ValueError: if ``train_pairs`` has fewer than ``batch_size`` samples.
            With ``drop_last=True`` there would be no batch, and the loss
            average divided by zero.
    """
    use_amp = torch.device(device).type == 'cuda'
    train_ds = ScribbleSegDataset(train_pairs, train=True, cutmix_p=cutmix_p)
    val_ds = ScribbleSegDataset(val_pairs, train=False)
    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=use_amp, drop_last=True)
    val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=use_amp)
    if len(train_dl) == 0:
        raise ValueError(f'[fold {fold_id}] {len(train_ds)} training samples is fewer than batch_size={batch_size}')
    model = UNet(in_ch=5, base=base, out_ch=1).to(device)
    n_params = sum((p.numel() for p in model.parameters()))
    print(f'[fold {fold_id}] U-Net params: {n_params / 1000000.0:.2f}M (base={base}), train={len(train_ds)}, val={len(val_ds)}')
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.0001)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs, eta_min=lr / 30)
    scaler = torch.amp.GradScaler('cuda', enabled=use_amp)
    best_miou = -1.0
    best_state = None
    log = []
    patience = 25
    bad_epochs = 0
    for epoch in range(epochs):
        train_loss = _train_epoch(model, train_dl, opt, scaler, device, use_amp)
        sched.step()
        bg, fg, miou = _validate(model, val_dl, device, use_amp)
        log.append({'epoch': epoch, 'loss': train_loss, 'val_bg': bg, 'val_fg': fg, 'val_miou': miou})
        print(f'[fold {fold_id} ep {epoch:03d}] loss={train_loss:.4f} val: bg={bg:.4f} fg={fg:.4f} mIoU={miou:.4f} lr={sched.get_last_lr()[0]:.2e}')
        if miou > best_miou:
            best_miou = miou
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            bad_epochs = 0
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                print(f'[fold {fold_id}] early stopping at epoch {epoch} (best mIoU={best_miou:.4f})')
                break
    if best_state is None:
        # No epoch ran (epochs <= 0): save the current weights rather than
        # None, which load_state_dict could not consume later.
        best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
    fold_dir = CKPT_DIR / f'fold_{fold_id}'
    fold_dir.mkdir(parents=True, exist_ok=True)
    torch.save(best_state, fold_dir / 'best.pth')
    with open(fold_dir / 'log.json', 'w') as f:
        json.dump(log, f, indent=2)
    return best_miou
def cmd_train(args):
    """Train one model per split and report per-fold and mean validation mIoU.

    Splits come from a seeded shuffle of the labelled training pairs. With
    ``--folds 1``, the first ``max(1, n // 5)`` shuffled pairs are held out.
    Otherwise the shuffled indices are cut into ``--folds`` chunks and each
    chunk is held out once. Pseudo-labelled pairs (``--pseudo-method``) are
    added to every training split, never to validation.

    ``--ckpt-suffix`` is appended to the checkpoint directory name (e.g.
    ``runs_global_unet`` -> ``runs_global_unet<suffix>``). It used to be
    ignored, so runs silently overwrote the default directory. ``predict`` and
    ``eval`` have no such flag; point them at the directory via the
    ``CKPT_DIR`` environment variable.

    Raises:
        ValueError: if ``--folds`` is smaller than 1.
    """
    global CKPT_DIR
    set_seed(args.seed)
    suffix = getattr(args, 'ckpt_suffix', '')
    if suffix:
        CKPT_DIR = Path(f'{CKPT_DIR}{suffix}')
        CKPT_DIR.mkdir(parents=True, exist_ok=True)
        print(f'Checkpoint dir: {CKPT_DIR}')
    if args.folds < 1:
        raise ValueError(f'--folds must be >= 1, got {args.folds}')
    device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
    print(f'Device: {device}')
    pairs = list_train_pairs()
    print(f'Training pairs: {len(pairs)}')
    pseudo_pairs = []
    if getattr(args, 'pseudo_method', None):
        pseudo_pairs = list_pseudo_pairs(args.pseudo_method)
        print(f'Pseudo-labeled pairs ({args.pseudo_method}): {len(pseudo_pairs)}')
    rng = np.random.RandomState(args.seed)
    indices = np.arange(len(pairs))
    rng.shuffle(indices)
    if args.folds == 1:
        n_val = max(1, len(pairs) // 5)
        splits = [(indices[n_val:], indices[:n_val])]
    else:
        fold_arr = np.array_split(indices, args.folds)
        splits = []
        for k in range(args.folds):
            val_idx = fold_arr[k]
            train_idx = np.concatenate([fold_arr[i] for i in range(args.folds) if i != k])
            splits.append((train_idx, val_idx))
    fold_mious = []
    for k, (train_idx, val_idx) in enumerate(splits):
        train_pairs = [pairs[i] for i in train_idx]
        if pseudo_pairs:
            train_pairs = train_pairs + pseudo_pairs
        val_pairs = [pairs[i] for i in val_idx]
        print(f'\n=== Fold {k + 1}/{len(splits)}: train={len(train_pairs)} ({len(train_idx)} real + {len(pseudo_pairs)} pseudo), val={len(val_pairs)} ===')
        miou = train_one_fold(train_pairs, val_pairs, epochs=args.epochs, batch_size=args.batch_size, lr=args.lr, fold_id=k, device=device, base=args.base, cutmix_p=args.cutmix_p)
        fold_mious.append(miou)
    print('\n=== Cross-validation summary ===')
    for k, m in enumerate(fold_mious):
        print(f'  fold {k}: {m:.4f}')
    print(f'  mean: {np.mean(fold_mious):.4f} (+/- {np.std(fold_mious):.4f})')
def tta_predict(model, x, device, scales=(1.0,)):
    """Average sigmoid probabilities over multi-scale and horizontal-flip TTA.

    Args:
        model: network mapping ``(B, C, h, w)`` to ``(B, 1, h, w)`` logits.
        x: input batch on the model's device. Channels 0-2 are resized
            bilinearly; the remaining channels use nearest-neighbour so they
            stay binary.
        device: unused; kept for call compatibility.
        scales: resize factors. Non-unit scales are rounded to a multiple of 32
            per side, with a minimum of 32 so tiny inputs do not round to 0.

    Returns:
        float32 numpy array of probabilities at the input size, squeezed, so a
        single-image batch gives ``(H, W)``.

    fp16 autocast is enabled only for CUDA tensors; it was previously requested
    for the 'cuda' device type even when ``x`` lived on the CPU.
    """
    model.eval()
    H, W = (x.shape[-2], x.shape[-1])
    probs = []
    with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.float16, enabled=x.is_cuda):
        for s in scales:
            if s == 1.0:
                xs = x
            else:
                new_h = max(32, int(round(H * s / 32) * 32))
                new_w = max(32, int(round(W * s / 32) * 32))
                rgb = F.interpolate(x[:, :3], size=(new_h, new_w), mode='bilinear', align_corners=False)
                sc = F.interpolate(x[:, 3:], size=(new_h, new_w), mode='nearest')
                xs = torch.cat([rgb, sc], dim=1)
            p1 = torch.sigmoid(model(xs))
            # Predict on the mirrored input and mirror the result back.
            p2 = torch.flip(torch.sigmoid(model(torch.flip(xs, dims=[3]))), dims=[3])
            p = (p1 + p2) / 2
            if p.shape[-2:] != (H, W):
                p = F.interpolate(p, size=(H, W), mode='bilinear', align_corners=False)
            probs.append(p)
    return (sum(probs) / len(probs)).squeeze().float().cpu().numpy()
def cmd_predict(args):
    """Ensemble all fold checkpoints on the test1 split and write palette PNG masks.

    Every ``CKPT_DIR/fold_*/best.pth`` is loaded. Each test image is resized to
    ``(TRAIN_H, TRAIN_W)`` and run through ``tta_predict`` (scales 0.7/1.0/1.3
    with flips) for every model, and the probabilities are averaged. The
    average is resized back to that image's own size and thresholded at 0.5.
    Scribble pixels then override the mask (0 -> background, 1 -> foreground).
    Exits with status 1 if no checkpoint exists.
    """
    import cv2
    device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
    fold_dirs = [fd for fd in sorted(CKPT_DIR.glob('fold_*')) if (fd / 'best.pth').exists()]
    if not fold_dirs:
        print('No trained models found.')
        sys.exit(1)
    print(f'Ensembling {len(fold_dirs)} folds.')
    models = []
    for fd in fold_dirs:
        m = UNet(in_ch=5, base=args.base, out_ch=1).to(device)
        # Checkpoints are plain tensor state dicts; restrict unpickling to weights.
        m.load_state_dict(torch.load(fd / 'best.pth', map_location=device, weights_only=True))
        m.eval()
        models.append(m)
    palette = load_palette()
    test_pairs = list_test_pairs()
    TEST_PRED.mkdir(parents=True, exist_ok=True)
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    for stem, img_p, sc_p in test_pairs:
        img = np.array(Image.open(img_p).convert('RGB'))
        sc = np.array(Image.open(sc_p).convert('L'))
        # Resize back to this image's own size. The fixed ORIG_W x ORIG_H made
        # the scribble indexing below fail for any image of another size.
        orig_h, orig_w = img.shape[:2]
        img_r = cv2.resize(img, (TRAIN_W, TRAIN_H), interpolation=cv2.INTER_LINEAR)
        sc_r = cv2.resize(sc, (TRAIN_W, TRAIN_H), interpolation=cv2.INTER_NEAREST)
        img_f = (img_r.astype(np.float32) / 255.0 - mean) / std
        img_t = torch.from_numpy(img_f.transpose(2, 0, 1))
        sc_t = torch.from_numpy(encode_scribble(sc_r))
        x = torch.cat([img_t, sc_t], dim=0).unsqueeze(0).to(device)
        prob = sum(tta_predict(m, x, device, scales=(0.7, 1.0, 1.3)) for m in models) / len(models)
        prob_full = cv2.resize(prob, (orig_w, orig_h), interpolation=cv2.INTER_LINEAR)
        pred = (prob_full > 0.5).astype(np.uint8)
        pred[sc == 0] = 0
        pred[sc == 1] = 1
        # A 2-D uint8 array gives an 'L' image; putpalette turns it into 'P'.
        out_img = Image.fromarray(pred)
        out_img.putpalette(palette)
        out_img.save(TEST_PRED / f'{stem}.png')
    print(f'Wrote {len(test_pairs)} predictions to {TEST_PRED}')
def cmd_eval_train(args):
    """Score each fold checkpoint on the training images it held out.

    The held-out indices are rebuilt exactly as in ``cmd_train`` (same seed and
    shuffle). With ``--folds 1``, the first ``max(1, n // 5)`` shuffled pairs
    are used; otherwise the indices are split into ``--folds`` chunks and fold
    k is scored on chunk k. Prediction uses the same TTA and scribble override
    as ``cmd_predict``. With ``--save``, masks are written as palette PNGs to
    ``dataset/train/predictions``.
    """
    import cv2
    device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
    pairs = list_train_pairs()
    rng = np.random.RandomState(args.seed)
    indices = np.arange(len(pairs))
    rng.shuffle(indices)
    if args.folds == 1:
        # Mirror cmd_train's single split. Previously every training image was
        # scored, ~80% of which the model had trained on.
        n_val = max(1, len(pairs) // 5)
        val_splits = [indices[:n_val]]
    else:
        val_splits = np.array_split(indices, args.folds)
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    train_pred_dir = DATA_ROOT / 'train' / 'predictions'
    palette = None
    if args.save:
        train_pred_dir.mkdir(parents=True, exist_ok=True)
        palette = load_palette()
    all_p, all_g = ([], [])
    for k, val_idx in enumerate(val_splits):
        ckpt = CKPT_DIR / f'fold_{k}' / 'best.pth'
        if not ckpt.exists():
            print(f'skip fold {k} - no checkpoint')
            continue
        model = UNet(in_ch=5, base=args.base, out_ch=1).to(device)
        model.load_state_dict(torch.load(ckpt, map_location=device, weights_only=True))
        model.eval()
        for i in val_idx:
            stem, img_p, sc_p, gt_p = pairs[i]
            img = np.array(Image.open(img_p).convert('RGB'))
            sc = np.array(Image.open(sc_p).convert('L'))
            gt = (np.array(Image.open(gt_p)) > 0).astype(np.uint8)
            # Use this image's own size, not the fixed ORIG_W x ORIG_H.
            orig_h, orig_w = img.shape[:2]
            img_r = cv2.resize(img, (TRAIN_W, TRAIN_H), interpolation=cv2.INTER_LINEAR)
            sc_r = cv2.resize(sc, (TRAIN_W, TRAIN_H), interpolation=cv2.INTER_NEAREST)
            img_f = (img_r.astype(np.float32) / 255.0 - mean) / std
            x = torch.cat([torch.from_numpy(img_f.transpose(2, 0, 1)), torch.from_numpy(encode_scribble(sc_r))], 0).unsqueeze(0).to(device)
            prob = tta_predict(model, x, device, scales=(0.7, 1.0, 1.3))
            prob_full = cv2.resize(prob, (orig_w, orig_h), interpolation=cv2.INTER_LINEAR)
            pred = (prob_full > 0.5).astype(np.uint8)
            pred[sc == 0] = 0
            pred[sc == 1] = 1
            all_p.append(pred)
            all_g.append(gt)
            if args.save:
                # A 2-D uint8 array gives an 'L' image; putpalette turns it into 'P'.
                out_img = Image.fromarray(pred)
                out_img.putpalette(palette)
                out_img.save(train_pred_dir / f'{stem}.png')
    if not all_p:
        print('No held-out predictions were made (no fold checkpoints found).')
        return
    bg, fg, miou = evaluate_predictions(all_p, all_g)
    print(f'Held-out CV: bg={bg:.4f} fg={fg:.4f} mIoU={miou:.4f} (n={len(all_p)} images)')
    if args.save:
        print(f'Saved {len(all_p)} train predictions to {train_pred_dir}')
def set_seed(seed):
    """Seed Python's ``random``, NumPy's global RNG, and PyTorch (CPU and all CUDA devices)."""
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)
def main():
    """Command-line entry point with ``train``, ``predict`` and ``eval`` sub-commands.

    Prints the top-level help when no sub-command is given.
    """
    parser = argparse.ArgumentParser()
    commands = parser.add_subparsers(dest='cmd')

    train_cli = commands.add_parser('train')
    train_cli.add_argument('--epochs', type=int, default=120)
    train_cli.add_argument('--batch-size', type=int, default=8)
    train_cli.add_argument('--lr', type=float, default=0.001)
    train_cli.add_argument('--folds', type=int, default=1)
    train_cli.add_argument('--seed', type=int, default=42)
    train_cli.add_argument('--gpu', type=int, default=0)
    train_cli.add_argument('--base', type=int, default=48, help='U-Net base channel count')
    train_cli.add_argument('--ckpt-suffix', type=str, default='', help='Suffix for runs_global_unet dir')
    train_cli.add_argument('--cutmix-p', type=float, default=0.0, help='Probability of CutMix per sample')
    train_cli.add_argument('--pseudo-method', type=str, default='', help="If set (e.g. 'v3v4'), use that method's predictions on test1+test2 as additional pseudo-labeled training data.")

    predict_cli = commands.add_parser('predict')
    predict_cli.add_argument('--gpu', type=int, default=0)
    predict_cli.add_argument('--base', type=int, default=48)

    eval_cli = commands.add_parser('eval')
    eval_cli.add_argument('--folds', type=int, default=1)
    eval_cli.add_argument('--seed', type=int, default=42)
    eval_cli.add_argument('--gpu', type=int, default=0)
    eval_cli.add_argument('--base', type=int, default=48)
    eval_cli.add_argument('--save', action='store_true', help='Save out-of-fold predictions to dataset/train/predictions/')

    args = parser.parse_args()
    chosen = args.cmd
    if chosen == 'eval':
        cmd_eval_train(args)
    elif chosen == 'predict':
        cmd_predict(args)
    elif chosen == 'train':
        cmd_train(args)
    else:
        parser.print_help()


if __name__ == '__main__':
    main()