import os
import sys
import json
import math
import random
import argparse
from pathlib import Path

import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# ---------------------------------------------------------------------------
# Paths and global configuration
# ---------------------------------------------------------------------------
DATA_ROOT = Path('dataset')
TRAIN_IMG = DATA_ROOT / 'train' / 'images'
TRAIN_SC = DATA_ROOT / 'train' / 'scribbles'
TRAIN_GT = DATA_ROOT / 'train' / 'ground_truth'
TEST_IMG = DATA_ROOT / 'test1' / 'images'
TEST_SC = DATA_ROOT / 'test1' / 'scribbles'
TEST_PRED = DATA_ROOT / 'test1' / 'predictions'

# Network input resolution. The U-Net below pools four times and upsamples by
# fixed factors of 2, so both values must be divisible by 16.
TRAIN_H = int(os.environ.get('TRAIN_H', '384'))
TRAIN_W = int(os.environ.get('TRAIN_W', '512'))
# Nominal dataset image size. Kept for reference only: predictions are resized
# back to each image's own resolution rather than assuming this size.
ORIG_H, ORIG_W = (375, 500)

CKPT_DIR = Path(os.environ.get('CKPT_DIR', 'runs_global_unet'))
CKPT_DIR.mkdir(parents=True, exist_ok=True)

# Channel statistics used to normalize RGB input in training and inference.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)
# Resize factors averaged by test-time augmentation in predict/eval.
TTA_SCALES = (0.7, 1.0, 1.3)


# ---------------------------------------------------------------------------
# File discovery
# ---------------------------------------------------------------------------
def list_train_pairs():
    """Return sorted ``(stem, image, scribble, ground_truth)`` path tuples.

    Dot-files are ignored, and an image is skipped unless both
    ``<stem>.png`` scribble and ground-truth files exist.
    """
    pairs = []
    for img_path in sorted(TRAIN_IMG.iterdir()):
        if img_path.name.startswith('.'):
            continue
        stem = img_path.stem
        sc_path = TRAIN_SC / f'{stem}.png'
        gt_path = TRAIN_GT / f'{stem}.png'
        if sc_path.exists() and gt_path.exists():
            pairs.append((stem, img_path, sc_path, gt_path))
    return pairs


def list_test_pairs():
    """Return sorted ``(stem, image, scribble)`` path tuples for test1.

    Dot-files are ignored, and images without a ``<stem>.png`` scribble are
    skipped.
    """
    pairs = []
    for img_path in sorted(TEST_IMG.iterdir()):
        if img_path.name.startswith('.'):
            continue
        stem = img_path.stem
        sc_path = TEST_SC / f'{stem}.png'
        if sc_path.exists():
            pairs.append((stem, img_path, sc_path))
    return pairs


def list_pseudo_pairs(pseudo_label_method='v3v4'):
    """Return ``(stem, image, scribble, pseudo_label)`` tuples from test1/test2.

    Pseudo labels are read from ``<DATA_ROOT>/<set>/predictions_<method>/``.
    A set is skipped entirely when that directory (or its image directory) is
    missing; individual images are skipped when either PNG is missing.
    """
    pairs = []
    for setname in ('test1', 'test2'):
        set_root = DATA_ROOT / setname
        img_dir = set_root / 'images'
        sc_dir = set_root / 'scribbles'
        gt_dir = set_root / f'predictions_{pseudo_label_method}'
        if not (gt_dir.exists() and img_dir.exists()):
            continue
        for ip in sorted(img_dir.iterdir()):
            if ip.name.startswith('.'):
                continue
            stem = ip.stem
            sp = sc_dir / f'{stem}.png'
            gp = gt_dir / f'{stem}.png'
            if sp.exists() and gp.exists():
                pairs.append((stem, ip, sp, gp))
    return pairs


def load_palette():
    """Return the palette of the first training ground-truth PNG.

    Used so written predictions share the ground truth's color palette.

    Raises:
        FileNotFoundError: if ``TRAIN_GT`` contains no PNG files.
    """
    any_gt = next(TRAIN_GT.glob('*.png'), None)
    if any_gt is None:
        raise FileNotFoundError(f'No ground-truth PNGs found in {TRAIN_GT} to read a palette from')
    with Image.open(any_gt) as im:
        return im.getpalette()


# ---------------------------------------------------------------------------
# Input encoding and augmentation
# ---------------------------------------------------------------------------
def encode_scribble(sc):
    """Encode a scribble label map as two float32 indicator channels.

    Channel 0 is 1 where ``sc == 0`` (background scribble) and channel 1 is 1
    where ``sc == 1`` (foreground scribble). Any other value -- this file uses
    255 for unlabeled pixels -- is 0 in both channels. Returns shape (2, H, W).
    """
    bg_ch = (sc == 0).astype(np.float32)
    fg_ch = (sc == 1).astype(np.float32)
    return np.stack([bg_ch, fg_ch], axis=0)


def random_affine(img, sc, gt, rng):
    """Apply one random similarity transform to image, scribble and mask alike.

    The transform rotates by up to +/-12 degrees and scales by 0.85-1.2, both
    about the image center, then translates by up to +/-5% of width/height.
    The image uses bilinear interpolation with reflected borders. Scribble and
    mask use nearest-neighbour; uncovered pixels are filled with 255 in the
    scribble (no scribble class) and 0 in the mask.
    """
    import cv2
    H, W = img.shape[:2]
    angle = rng.uniform(-12, 12)
    scale = rng.uniform(0.85, 1.2)
    tx = rng.uniform(-0.05, 0.05) * W
    ty = rng.uniform(-0.05, 0.05) * H
    cx, cy = W / 2, H / 2
    a = math.radians(angle)
    cos_a, sin_a = math.cos(a) * scale, math.sin(a) * scale
    # Rotation+scale about (cx, cy): t = c - R @ c, plus the random shift.
    M = np.array([[cos_a, -sin_a, (1 - cos_a) * cx + sin_a * cy + tx],
                  [sin_a, cos_a, (1 - cos_a) * cy - sin_a * cx + ty]], dtype=np.float32)
    img_a = cv2.warpAffine(img, M, (W, H), flags=cv2.INTER_LINEAR,
                           borderMode=cv2.BORDER_REFLECT_101)
    sc_a = cv2.warpAffine(sc, M, (W, H), flags=cv2.INTER_NEAREST,
                          borderMode=cv2.BORDER_CONSTANT, borderValue=255)
    gt_a = cv2.warpAffine(gt, M, (W, H), flags=cv2.INTER_NEAREST,
                          borderMode=cv2.BORDER_CONSTANT, borderValue=0)
    return (img_a, sc_a, gt_a)


def color_jitter(img, rng):
    """Randomly jitter brightness, contrast and (70% of the time) saturation.

    Takes and returns a uint8 HxWx3 image; values are clipped to [0, 255].
    """
    img_f = img.astype(np.float32) / 255.0
    img_f = img_f * rng.uniform(0.8, 1.2)  # brightness
    mean = img_f.mean(axis=(0, 1), keepdims=True)
    img_f = (img_f - mean) * rng.uniform(0.8, 1.2) + mean  # contrast
    if rng.random() < 0.7:
        # Saturation: blend towards per-pixel gray with a SINGLE factor.
        # (Two independent factors would also rescale brightness.)
        sat = rng.uniform(0.7, 1.3)
        gray = img_f.mean(axis=2, keepdims=True)
        img_f = gray + (img_f - gray) * sat
    img_f = np.clip(img_f, 0, 1)
    return (img_f * 255).astype(np.uint8)


class ScribbleSegDataset(Dataset):
    """Dataset yielding ``(x, y, stem)`` for scribble-guided binary segmentation.

    ``x`` is a float tensor of shape (5, H, W): three normalized RGB channels
    followed by the two scribble channels from ``encode_scribble``. ``y`` is a
    float (H, W) tensor that is 1 where the label PNG is non-zero.

    Args:
        pairs: ``(stem, image, scribble, label)`` path tuples.
        train: enable flip / affine / color / scribble-dropout / CutMix.
        image_size: ``(H, W)`` every sample is resized to.
        cutmix_p: per-sample probability of pasting a random rectangle from
            another sample (only applied when ``train`` is True).
    """

    def __init__(self, pairs, train=True, image_size=(TRAIN_H, TRAIN_W), cutmix_p=0.0):
        self.pairs = pairs
        self.train = train
        self.H, self.W = image_size
        self.cutmix_p = cutmix_p
        self.mean = IMAGENET_MEAN
        self.std = IMAGENET_STD

    def __len__(self):
        return len(self.pairs)

    def _load_one(self, idx):
        """Load sample ``idx`` as uint8 arrays resized to ``(self.H, self.W)``."""
        import cv2
        stem, img_p, sc_p, gt_p = self.pairs[idx]
        img = np.array(Image.open(img_p).convert('RGB'))
        sc = np.array(Image.open(sc_p).convert('L'))
        gt = np.array(Image.open(gt_p))
        # NOTE(review): only the image size is checked; scribble and label are
        # assumed to share the image's resolution.
        if img.shape[:2] != (self.H, self.W):
            img = cv2.resize(img, (self.W, self.H), interpolation=cv2.INTER_LINEAR)
            sc = cv2.resize(sc, (self.W, self.H), interpolation=cv2.INTER_NEAREST)
            gt = cv2.resize(gt, (self.W, self.H), interpolation=cv2.INTER_NEAREST)
        return (stem, img, sc, gt)

    def __getitem__(self, idx):
        stem, img, sc, gt = self._load_one(idx)
        rng = random.Random()
        if self.train:
            if rng.random() < 0.5:
                img = img[:, ::-1, :].copy()
                sc = sc[:, ::-1].copy()
                gt = gt[:, ::-1].copy()
            img, sc, gt = random_affine(img, sc, gt, rng)
            img = color_jitter(img, rng)
            if rng.random() < 0.3:
                # Scribble dropout: erase ~30% of labeled scribble pixels.
                drop_mask = (sc != 255) & (np.random.rand(*sc.shape) < 0.3)
                sc = sc.copy()
                sc[drop_mask] = 255
            if self.cutmix_p > 0 and rng.random() < self.cutmix_p:
                j = rng.randint(0, len(self.pairs) - 1)
                _, img2, sc2, gt2 = self._load_one(j)
                rh = rng.randint(int(0.3 * self.H), int(0.6 * self.H))
                rw = rng.randint(int(0.3 * self.W), int(0.6 * self.W))
                ry = rng.randint(0, self.H - rh)
                rx = rng.randint(0, self.W - rw)
                img = img.copy()
                sc = sc.copy()
                gt = gt.copy()
                img[ry:ry + rh, rx:rx + rw] = img2[ry:ry + rh, rx:rx + rw]
                sc[ry:ry + rh, rx:rx + rw] = sc2[ry:ry + rh, rx:rx + rw]
                gt[ry:ry + rh, rx:rx + rw] = gt2[ry:ry + rh, rx:rx + rw]
        img_f = img.astype(np.float32) / 255.0
        img_f = (img_f - self.mean) / self.std
        img_t = torch.from_numpy(img_f.transpose(2, 0, 1))
        sc_t = torch.from_numpy(encode_scribble(sc))
        x = torch.cat([img_t, sc_t], dim=0)
        y = torch.from_numpy((gt > 0).astype(np.float32))
        return (x, y, stem)


# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------
class ConvBlock(nn.Module):
    """Two (3x3 conv -> BatchNorm -> ReLU) layers; spatial size preserved."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.block(x)


class UNet(nn.Module):
    """Four-level U-Net with bilinear upsampling and concatenated skips.

    Input height and width must be divisible by 16: four 2x max-pools are
    undone by fixed 2x upsampling before each skip concatenation.
    Outputs raw logits with ``out_ch`` channels.
    """

    def __init__(self, in_ch=5, base=48, out_ch=1):
        super().__init__()
        c1, c2, c3, c4, c5 = (base, base * 2, base * 4, base * 8, base * 16)
        self.enc1 = ConvBlock(in_ch, c1)
        self.enc2 = ConvBlock(c1, c2)
        self.enc3 = ConvBlock(c2, c3)
        self.enc4 = ConvBlock(c3, c4)
        self.bottleneck = ConvBlock(c4, c5)
        self.pool = nn.MaxPool2d(2)
        self.up4 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.dec4 = ConvBlock(c5 + c4, c4)
        self.up3 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.dec3 = ConvBlock(c4 + c3, c3)
        self.up2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.dec2 = ConvBlock(c3 + c2, c2)
        self.up1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
        self.dec1 = ConvBlock(c2 + c1, c1)
        self.head = nn.Conv2d(c1, out_ch, 1)
        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal convs, zero biases, identity BatchNorm."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))
        b = self.bottleneck(self.pool(e4))
        d4 = self.dec4(torch.cat([self.up4(b), e4], 1))
        d3 = self.dec3(torch.cat([self.up3(d4), e3], 1))
        d2 = self.dec2(torch.cat([self.up2(d3), e2], 1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], 1))
        return self.head(d1)


# ---------------------------------------------------------------------------
# Losses and metrics
# ---------------------------------------------------------------------------
def soft_dice_loss(logits, target, eps=1e-06):
    """Return 1 - mean per-image soft Dice; logits (B,1,H,W), target (B,H,W)."""
    p = torch.sigmoid(logits).squeeze(1)
    inter = (p * target).sum(dim=(1, 2))
    denom = p.sum(dim=(1, 2)) + target.sum(dim=(1, 2))
    dice = (2 * inter + eps) / (denom + eps)
    return 1 - dice.mean()


def combined_loss(logits, target):
    """Equal-weight sum of binary cross-entropy (on logits) and soft Dice loss."""
    bce = F.binary_cross_entropy_with_logits(logits.squeeze(1), target)
    dice = soft_dice_loss(logits, target)
    return 0.5 * bce + 0.5 * dice


def compute_iou(pred_bin, gt_bin, cls):
    """IoU of class ``cls`` between two label arrays.

    NOTE(review): returns 0.0 when ``cls`` is absent from both arrays, so a
    correct "nothing here" prediction scores 0. Kept as-is to preserve the
    reported metric; confirm this matches the intended evaluation protocol.
    """
    p = pred_bin == cls
    g = gt_bin == cls
    inter = np.logical_and(p, g).sum()
    union = np.logical_or(p, g).sum()
    return inter / union if union > 0 else 0.0


def evaluate_predictions(preds, gts):
    """Return ``(mean bg IoU, mean fg IoU, their average)`` over image pairs.

    IoUs are computed per image and then averaged. Returns NaNs when there
    are no pairs.
    """
    pairs = list(zip(preds, gts))
    if not pairs:
        nan = float('nan')
        return (nan, nan, nan)
    bg = np.mean([compute_iou(p, g, 0) for p, g in pairs])
    fg = np.mean([compute_iou(p, g, 1) for p, g in pairs])
    return (bg, fg, (bg + fg) / 2)


# ---------------------------------------------------------------------------
# Shared helpers
# ---------------------------------------------------------------------------
def set_seed(seed):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices) global RNGs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def make_splits(n_items, folds, seed):
    """Return ``[(train_idx, val_idx), ...]`` index splits shared by train and eval.

    Indices ``0..n_items-1`` are shuffled with ``np.random.RandomState(seed)``.
    With ``folds == 1`` the first ``max(1, n_items // 5)`` shuffled indices are
    the validation set; otherwise the permutation is cut into ``folds``
    near-equal chunks and each chunk is validation once.
    """
    if folds < 1:
        raise ValueError(f'folds must be >= 1, got {folds}')
    rng = np.random.RandomState(seed)
    indices = np.arange(n_items)
    rng.shuffle(indices)
    if folds == 1:
        n_val = max(1, n_items // 5)
        return [(indices[n_val:], indices[:n_val])]
    chunks = np.array_split(indices, folds)
    return [
        (np.concatenate([chunks[i] for i in range(folds) if i != k]), chunks[k])
        for k in range(folds)
    ]


def resolve_ckpt_dir(suffix=''):
    """Return ``CKPT_DIR`` with ``suffix`` appended to its path (no-op if empty)."""
    return Path(f'{CKPT_DIR}{suffix}') if suffix else CKPT_DIR


def _snapshot(model):
    """Detached CPU copy of ``model``'s state dict."""
    return {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}


# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------
def _train_epoch(model, loader, opt, scaler, device, use_amp):
    """Run one optimization pass over ``loader``; return sample-weighted mean loss."""
    model.train()
    total, count = 0.0, 0
    for x, y, _ in loader:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        opt.zero_grad(set_to_none=True)
        with torch.amp.autocast('cuda', dtype=torch.float16, enabled=use_amp):
            logits = model(x)
            loss = combined_loss(logits, y)
        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()
        total += loss.item() * x.size(0)
        count += x.size(0)
    return total / count


def _validate(model, loader, device, use_amp):
    """Threshold sigmoid outputs at 0.5 and return ``evaluate_predictions`` results."""
    model.eval()
    preds, gts = [], []
    with torch.no_grad():
        for x, y, _ in loader:
            x = x.to(device, non_blocking=True)
            with torch.amp.autocast('cuda', dtype=torch.float16, enabled=use_amp):
                logits = model(x)
            prob = torch.sigmoid(logits).squeeze(1).float().cpu().numpy()
            preds.extend((prob > 0.5).astype(np.uint8))
            gts.extend(y.numpy().astype(np.uint8))
    return evaluate_predictions(preds, gts)


def train_one_fold(train_pairs, val_pairs, epochs, batch_size, lr, fold_id, device,
                   base=48, cutmix_p=0.0, patience=25, ckpt_dir=None):
    """Train one U-Net and save its best-validation-mIoU weights.

    Uses AdamW with cosine LR decay (to ``lr / 30``). Stops early after
    ``patience`` epochs without an mIoU improvement. Mixed precision and
    pinned memory are used only on CUDA devices.

    Writes ``<ckpt_dir>/fold_<fold_id>/best.pth`` and ``log.json``, with
    ``ckpt_dir`` defaulting to ``CKPT_DIR``. Returns the best validation mIoU.

    Raises:
        ValueError: if there are fewer training samples than ``batch_size``
            (with ``drop_last=True`` no batch would ever be produced).
    """
    device = torch.device(device)
    on_cuda = device.type == 'cuda'
    ckpt_dir = CKPT_DIR if ckpt_dir is None else Path(ckpt_dir)

    train_ds = ScribbleSegDataset(train_pairs, train=True, cutmix_p=cutmix_p)
    val_ds = ScribbleSegDataset(val_pairs, train=False)
    if len(train_ds) < batch_size:
        raise ValueError(
            f'fold {fold_id}: {len(train_ds)} training samples is fewer than '
            f'batch_size={batch_size}; no full batch can be formed'
        )
    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4,
                          pin_memory=on_cuda, drop_last=True)
    val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=2,
                        pin_memory=on_cuda)

    model = UNet(in_ch=5, base=base, out_ch=1).to(device)
    n_params = sum(p.numel() for p in model.parameters())
    print(f'[fold {fold_id}] U-Net params: {n_params / 1000000.0:.2f}M (base={base}), train={len(train_ds)}, val={len(val_ds)}')
    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.0001)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs, eta_min=lr / 30)
    scaler = torch.amp.GradScaler('cuda', enabled=on_cuda)

    best_miou = -1.0
    best_state = None
    log = []
    bad_epochs = 0
    for epoch in range(epochs):
        train_loss = _train_epoch(model, train_dl, opt, scaler, device, on_cuda)
        sched.step()
        bg, fg, miou = _validate(model, val_dl, device, on_cuda)
        log.append({'epoch': epoch, 'loss': train_loss, 'val_bg': bg, 'val_fg': fg, 'val_miou': miou})
        print(f'[fold {fold_id} ep {epoch:03d}] loss={train_loss:.4f} val: bg={bg:.4f} fg={fg:.4f} mIoU={miou:.4f} lr={sched.get_last_lr()[0]:.2e}')
        if miou > best_miou:
            best_miou = miou
            best_state = _snapshot(model)
            bad_epochs = 0
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                print(f'[fold {fold_id}] early stopping at epoch {epoch} (best mIoU={best_miou:.4f})')
                break

    if best_state is None:
        # No epoch beat the initial -1.0 (epochs == 0, or every validation
        # mIoU was NaN, e.g. an empty validation set). Save the final weights
        # so a loadable checkpoint still exists instead of pickling None.
        best_state = _snapshot(model)

    fold_dir = ckpt_dir / f'fold_{fold_id}'
    fold_dir.mkdir(parents=True, exist_ok=True)
    torch.save(best_state, fold_dir / 'best.pth')
    with open(fold_dir / 'log.json', 'w') as f:
        json.dump(log, f, indent=2)
    return best_miou


def cmd_train(args):
    """``train`` subcommand: run single-split or k-fold training and summarize mIoU."""
    set_seed(args.seed)
    device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
    print(f'Device: {device}')
    ckpt_dir = resolve_ckpt_dir(getattr(args, 'ckpt_suffix', ''))
    pairs = list_train_pairs()
    print(f'Training pairs: {len(pairs)}')

    pseudo_pairs = []
    pseudo_method = getattr(args, 'pseudo_method', None)
    if pseudo_method:
        pseudo_pairs = list_pseudo_pairs(pseudo_method)
        print(f'Pseudo-labeled pairs ({pseudo_method}): {len(pseudo_pairs)}')

    splits = make_splits(len(pairs), args.folds, args.seed)
    fold_mious = []
    for k, (train_idx, val_idx) in enumerate(splits):
        # Pseudo-labeled images only ever augment training, never validation.
        train_pairs = [pairs[i] for i in train_idx] + pseudo_pairs
        val_pairs = [pairs[i] for i in val_idx]
        print(f'\n=== Fold {k + 1}/{len(splits)}: train={len(train_pairs)} ({len(train_idx)} real + {len(pseudo_pairs)} pseudo), val={len(val_pairs)} ===')
        miou = train_one_fold(train_pairs, val_pairs, epochs=args.epochs,
                              batch_size=args.batch_size, lr=args.lr, fold_id=k,
                              device=device, base=args.base, cutmix_p=args.cutmix_p,
                              ckpt_dir=ckpt_dir)
        fold_mious.append(miou)

    print('\n=== Cross-validation summary ===')
    for k, m in enumerate(fold_mious):
        print(f' fold {k}: {m:.4f}')
    print(f' mean: {np.mean(fold_mious):.4f} (+/- {np.std(fold_mious):.4f})')


# ---------------------------------------------------------------------------
# Inference
# ---------------------------------------------------------------------------
def tta_predict(model, x, device, scales=(1.0,)):
    """Foreground probabilities averaged over scales and horizontal flips.

    ``x`` is a (B, 5, H, W) input. For each scale, the RGB channels are
    resized bilinearly and the scribble channels by nearest neighbour, to
    sizes snapped to multiples of 32. The result is averaged with its
    flipped prediction and resized back to (H, W). The returned NumPy array
    is ``.squeeze()``d, so it is (H, W) for a single image. ``device`` is
    unused; fp16 autocast is enabled when ``x`` lives on a CUDA device.
    """
    model.eval()
    H, W = (x.shape[-2], x.shape[-1])
    use_amp = x.device.type == 'cuda'
    probs = []
    with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.float16, enabled=use_amp):
        for s in scales:
            if s == 1.0:
                xs = x
            else:
                new_h = int(round(H * s / 32) * 32)
                new_w = int(round(W * s / 32) * 32)
                rgb = F.interpolate(x[:, :3], size=(new_h, new_w), mode='bilinear', align_corners=False)
                sc = F.interpolate(x[:, 3:], size=(new_h, new_w), mode='nearest')
                xs = torch.cat([rgb, sc], dim=1)
            p1 = torch.sigmoid(model(xs))
            p2 = torch.flip(torch.sigmoid(model(torch.flip(xs, dims=[3]))), dims=[3])
            p = (p1 + p2) / 2
            if p.shape[-2:] != (H, W):
                p = F.interpolate(p, size=(H, W), mode='bilinear', align_corners=False)
            probs.append(p)
    return (sum(probs) / len(probs)).squeeze().float().cpu().numpy()


def _read_image_and_scribble(img_p, sc_p):
    """Load an RGB image and a single-channel scribble map as uint8 arrays."""
    img = np.array(Image.open(img_p).convert('RGB'))
    sc = np.array(Image.open(sc_p).convert('L'))
    return img, sc


def _prepare_input(img, sc, device):
    """Resize to the training resolution and build a (1, 5, TRAIN_H, TRAIN_W) tensor."""
    import cv2
    img_r = cv2.resize(img, (TRAIN_W, TRAIN_H), interpolation=cv2.INTER_LINEAR)
    sc_r = cv2.resize(sc, (TRAIN_W, TRAIN_H), interpolation=cv2.INTER_NEAREST)
    img_f = (img_r.astype(np.float32) / 255.0 - IMAGENET_MEAN) / IMAGENET_STD
    img_t = torch.from_numpy(img_f.transpose(2, 0, 1))
    sc_t = torch.from_numpy(encode_scribble(sc_r))
    return torch.cat([img_t, sc_t], dim=0).unsqueeze(0).to(device)


def _load_model(ckpt_path, base, device):
    """Build a UNet, load a saved state dict into it, and switch it to eval mode."""
    model = UNet(in_ch=5, base=base, out_ch=1).to(device)
    model.load_state_dict(torch.load(ckpt_path, map_location=device, weights_only=True))
    model.eval()
    return model


def _ensemble_prob(models, x, device):
    """Average the multi-scale TTA probability maps of all ``models``."""
    return sum(tta_predict(m, x, device, scales=TTA_SCALES) for m in models) / len(models)


def _finalize_mask(prob, sc):
    """Turn a training-resolution probability map into the final uint8 mask.

    The map is resized to the scribble's resolution (assumed equal to the
    source image's) and thresholded at 0.5. Scribbled pixels are then forced
    to their annotated label (0 -> background, 1 -> foreground).
    """
    import cv2
    h, w = sc.shape[:2]
    prob_full = cv2.resize(prob, (w, h), interpolation=cv2.INTER_LINEAR)
    pred = (prob_full > 0.5).astype(np.uint8)
    pred[sc == 0] = 0
    pred[sc == 1] = 1
    return pred


def _save_mask(mask, palette, path):
    """Write ``mask`` as a palette-mode PNG using ``palette``."""
    out_img = Image.fromarray(mask.astype(np.uint8), mode='P')
    out_img.putpalette(palette)
    out_img.save(path)


def cmd_predict(args):
    """``predict`` subcommand: ensemble every saved fold over the test1 images.

    Exits with status 1 if no ``fold_*/best.pth`` checkpoint is found.
    """
    device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
    ckpt_dir = resolve_ckpt_dir(getattr(args, 'ckpt_suffix', ''))
    fold_dirs = [f for f in sorted(ckpt_dir.glob('fold_*')) if (f / 'best.pth').exists()]
    if not fold_dirs:
        print('No trained models found.')
        sys.exit(1)
    print(f'Ensembling {len(fold_dirs)} folds.')
    models = [_load_model(fd / 'best.pth', args.base, device) for fd in fold_dirs]

    palette = load_palette()
    test_pairs = list_test_pairs()
    TEST_PRED.mkdir(parents=True, exist_ok=True)
    for stem, img_p, sc_p in test_pairs:
        img, sc = _read_image_and_scribble(img_p, sc_p)
        x = _prepare_input(img, sc, device)
        pred = _finalize_mask(_ensemble_prob(models, x, device), sc)
        _save_mask(pred, palette, TEST_PRED / f'{stem}.png')
    print(f'Wrote {len(test_pairs)} predictions to {TEST_PRED}')


def cmd_eval_train(args):
    """``eval`` subcommand: score each fold's model on its own held-out images.

    Uses the same ``make_splits`` split as training, so with ``--folds 1``
    only the ~20% validation slice is scored, not the images trained on.
    Folds without a checkpoint are skipped. With ``--save``, out-of-fold
    masks are written to ``<DATA_ROOT>/train/predictions/``.
    """
    device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
    ckpt_dir = resolve_ckpt_dir(getattr(args, 'ckpt_suffix', ''))
    pairs = list_train_pairs()
    splits = make_splits(len(pairs), args.folds, args.seed)

    train_pred_dir = DATA_ROOT / 'train' / 'predictions'
    palette = None
    if args.save:
        train_pred_dir.mkdir(parents=True, exist_ok=True)
        palette = load_palette()

    all_p, all_g = [], []
    for k, (_, val_idx) in enumerate(splits):
        ckpt = ckpt_dir / f'fold_{k}' / 'best.pth'
        if not ckpt.exists():
            print(f'skip fold {k} - no checkpoint')
            continue
        model = _load_model(ckpt, args.base, device)
        for i in val_idx:
            stem, img_p, sc_p, gt_p = pairs[i]
            img, sc = _read_image_and_scribble(img_p, sc_p)
            gt = (np.array(Image.open(gt_p)) > 0).astype(np.uint8)
            x = _prepare_input(img, sc, device)
            pred = _finalize_mask(tta_predict(model, x, device, scales=TTA_SCALES), sc)
            all_p.append(pred)
            all_g.append(gt)
            if args.save:
                _save_mask(pred, palette, train_pred_dir / f'{stem}.png')

    bg, fg, miou = evaluate_predictions(all_p, all_g)
    print(f'Held-out CV: bg={bg:.4f} fg={fg:.4f} mIoU={miou:.4f} (n={len(all_p)} images)')
    if args.save:
        print(f'Saved {len(all_p)} train predictions to {train_pred_dir}')


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def main():
    """Parse the command line and dispatch to train / predict / eval."""
    p = argparse.ArgumentParser()
    sub = p.add_subparsers(dest='cmd')

    pt = sub.add_parser('train')
    pt.add_argument('--epochs', type=int, default=120)
    pt.add_argument('--batch-size', type=int, default=8)
    pt.add_argument('--lr', type=float, default=0.001)
    pt.add_argument('--folds', type=int, default=1)
    pt.add_argument('--seed', type=int, default=42)
    pt.add_argument('--gpu', type=int, default=0)
    pt.add_argument('--base', type=int, default=48, help='U-Net base channel count')
    pt.add_argument('--ckpt-suffix', type=str, default='', help='Suffix for runs_global_unet dir')
    pt.add_argument('--cutmix-p', type=float, default=0.0, help='Probability of CutMix per sample')
    pt.add_argument('--pseudo-method', type=str, default='', help="If set (e.g. 'v3v4'), use that method's predictions on test1+test2 as additional pseudo-labeled training data.")

    pp = sub.add_parser('predict')
    pp.add_argument('--gpu', type=int, default=0)
    pp.add_argument('--base', type=int, default=48)
    pp.add_argument('--ckpt-suffix', type=str, default='', help='Suffix for runs_global_unet dir')

    pe = sub.add_parser('eval')
    pe.add_argument('--folds', type=int, default=1)
    pe.add_argument('--seed', type=int, default=42)
    pe.add_argument('--gpu', type=int, default=0)
    pe.add_argument('--base', type=int, default=48)
    pe.add_argument('--ckpt-suffix', type=str, default='', help='Suffix for runs_global_unet dir')
    pe.add_argument('--save', action='store_true', help='Save out-of-fold predictions to dataset/train/predictions/')

    args = p.parse_args()
    handlers = {'train': cmd_train, 'predict': cmd_predict, 'eval': cmd_eval_train}
    handler = handlers.get(args.cmd)
    if handler is None:
        p.print_help()
    else:
        handler(args)


if __name__ == '__main__':
    main()