"""Ensemble inference for the scribble-supervised U-Net.

Loads one or more folds of trained checkpoints, averages their TTA
probabilities, binarizes with a (optionally learned) per-image threshold,
post-processes the mask, and either writes test1 predictions or evaluates
an out-of-fold ensemble on the training set.
"""
import argparse
import os
import json
from pathlib import Path

import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F
import cv2

from train_global_unet import (
    UNet,
    encode_scribble,
    list_test_pairs,
    list_train_pairs,
    load_palette,
    evaluate_predictions,
    tta_predict,
    TRAIN_H,
    TRAIN_W,
    ORIG_H,
    ORIG_W,
    TEST_PRED,
)

# ImageNet normalization constants used at train time (must match training).
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)


def predict_threshold(prob, sc, predictor):
    """Predict a per-image binarization threshold via a learned linear model.

    Args:
        prob: float probability map (full resolution), values in [0, 1].
        sc: scribble map at the same resolution; 1 = foreground scribble,
            0 = background scribble (other values = unlabeled).
        predictor: dict with 'weights' (5 floats), 'bias', and 'clip'
            (lo, hi), or None to fall back to a fixed 0.5 threshold.

    Returns:
        Threshold as a float, clipped to predictor['clip'].
    """
    if predictor is None:
        return 0.5
    fg_scrib = (sc == 1).sum()
    bg_scrib = (sc == 0).sum()
    # Fraction of scribbled pixels marked foreground (0.5 if no scribbles).
    scrib_fg_ratio = fg_scrib / (fg_scrib + bg_scrib) if fg_scrib + bg_scrib > 0 else 0.5
    pred_fg_frac = (prob > 0.5).mean()
    prob_mean = prob.mean()
    prob_std = prob.std()
    # Mean binary entropy of the probability map (clipped for log stability).
    p_clip = np.clip(prob, 1e-06, 1 - 1e-06)
    entropy = -(p_clip * np.log(p_clip) + (1 - p_clip) * np.log(1 - p_clip)).mean()
    feats = np.array([scrib_fg_ratio, pred_fg_frac, prob_mean, prob_std, entropy])
    t = float(feats @ np.array(predictor['weights']) + predictor['bias'])
    return float(np.clip(t, predictor['clip'][0], predictor['clip'][1]))


def load_threshold_predictor():
    """Load the learned threshold predictor, or None if not present.

    Reads 'threshold_predictor.json' from the working directory.
    """
    p = Path('threshold_predictor.json')
    if not p.exists():
        return None
    # Fixed: use a context manager so the file handle is closed.
    with open(p) as f:
        return json.load(f)


def load_models(ckpt_specs, device):
    """Load every fold checkpoint from each spec into eval-mode models.

    Args:
        ckpt_specs: iterable of (ckpt_dir, base, seed) tuples as produced
            by parse_specs().
        device: torch device to place the models on.

    Returns:
        List of UNet models in eval() mode (missing fold dirs are skipped).
    """
    models = []
    for ckpt_dir, base, seed in ckpt_specs:
        for fd in sorted(Path(ckpt_dir).glob('fold_*')):
            ckpt = fd / 'best.pth'
            if not ckpt.exists():
                continue
            m = UNet(in_ch=5, base=base, out_ch=1).to(device)
            m.load_state_dict(torch.load(ckpt, map_location=device))
            m.eval()
            models.append(m)
            print(f' loaded {ckpt} (base={base}, seed={seed})')
    return models


def parse_specs(spec_strs):
    """Parse 'path:base[:seed]' strings into (Path, base, seed) tuples.

    The seed component is optional and defaults to 42.
    """
    out = []
    for s in spec_strs:
        parts = s.split(':')
        path = parts[0]
        base = int(parts[1])
        seed = int(parts[2]) if len(parts) > 2 else 42
        out.append((Path(path), base, seed))
    return out


def filter_small_components(pred, min_area=200):
    """Remove small connected components from a binary mask, both polarities.

    First deletes foreground blobs smaller than min_area, then fills
    background holes smaller than min_area.

    Args:
        pred: uint8 binary mask (0/1).
        min_area: minimum component area, in pixels, to keep.

    Returns:
        Cleaned uint8 mask.
    """
    n_fg, lab_fg, stats_fg, _ = cv2.connectedComponentsWithStats(pred, connectivity=8)
    out = pred.copy()
    for i in range(1, n_fg):  # label 0 is the background of the labeling
        if stats_fg[i, cv2.CC_STAT_AREA] < min_area:
            out[lab_fg == i] = 0
    # Repeat on the inverted mask to fill small holes.
    inv = 1 - out
    n_bg, lab_bg, stats_bg, _ = cv2.connectedComponentsWithStats(inv, connectivity=8)
    for i in range(1, n_bg):
        if stats_bg[i, cv2.CC_STAT_AREA] < min_area:
            out[lab_bg == i] = 1
    return out


def postprocess(pred, min_area=200, close_ks=9):
    """Morphological close then small-component filtering of a binary mask.

    Args:
        pred: uint8 binary mask (0/1).
        min_area: passed through to filter_small_components().
        close_ks: elliptical closing kernel size; <= 1 disables closing.
    """
    if close_ks > 1:
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (close_ks, close_ks))
        pred = cv2.morphologyEx(pred, cv2.MORPH_CLOSE, kernel)
    pred = filter_small_components(pred, min_area)
    return pred


def predict_test1(models, device):
    """Run the full ensemble on test1 and write palette PNG predictions.

    For each test pair: resize to training resolution, run multi-scale TTA
    for every model, average probabilities, resize back to the original
    resolution, threshold, post-process, and force scribbled pixels to
    their scribble label before saving to TEST_PRED.
    """
    palette = load_palette()
    predictor = load_threshold_predictor()
    if predictor is not None:
        # Fixed: original reused single quotes inside a single-quoted
        # f-string, a SyntaxError on Python < 3.12 (PEP 701).
        print(f"Using learned per-image threshold predictor (clip={predictor['clip']})")
    test_pairs = list_test_pairs()
    TEST_PRED.mkdir(parents=True, exist_ok=True)
    mean = np.array(_IMAGENET_MEAN, dtype=np.float32)
    std = np.array(_IMAGENET_STD, dtype=np.float32)
    for stem, img_p, sc_p in test_pairs:
        img = np.array(Image.open(img_p).convert('RGB'))
        sc = np.array(Image.open(sc_p).convert('L'))
        img_r = cv2.resize(img, (TRAIN_W, TRAIN_H), interpolation=cv2.INTER_LINEAR)
        # Nearest-neighbor keeps the scribble label values intact.
        sc_r = cv2.resize(sc, (TRAIN_W, TRAIN_H), interpolation=cv2.INTER_NEAREST)
        img_f = (img_r.astype(np.float32) / 255.0 - mean) / std
        x = torch.cat(
            [torch.from_numpy(img_f.transpose(2, 0, 1)),
             torch.from_numpy(encode_scribble(sc_r))], 0,
        ).unsqueeze(0).to(device)
        # Average multi-scale TTA probabilities over all ensemble members.
        prob_sum = None
        for m in models:
            p = tta_predict(m, x, device, scales=(0.7, 1.0, 1.3))
            prob_sum = p if prob_sum is None else prob_sum + p
        prob = prob_sum / len(models)
        prob_full = cv2.resize(prob, (ORIG_W, ORIG_H), interpolation=cv2.INTER_LINEAR)
        thresh = predict_threshold(prob_full, sc, predictor)
        pred = (prob_full > thresh).astype(np.uint8)
        pred = postprocess(pred, min_area=200, close_ks=9)
        # Scribbles are ground truth at their pixels: enforce them.
        pred[sc == 0] = 0
        pred[sc == 1] = 1
        out_img = Image.fromarray(pred.astype(np.uint8), mode='P')
        out_img.putpalette(palette)
        out_img.save(TEST_PRED / f'{stem}.png')
    print(f'Wrote {len(test_pairs)} predictions to {TEST_PRED}')


def eval_oof_ensemble(ckpt_specs, folds, seed, device, save=False):
    """Evaluate an out-of-fold ensemble on the training set.

    Each checkpoint group contributes, for training image i, only the model
    of the fold that held image i out during that group's training run
    (reconstructed from the group's own seed). Images with no available
    held-out model are skipped.

    Args:
        ckpt_specs: list of (ckpt_dir, base, seed) tuples.
        folds: number of CV folds used at training time.
        seed: kept for interface compatibility; fold assignment uses each
            spec's own seed.
        device: torch device.
        save: if True, also write predictions to dataset/train/predictions/.

    Returns:
        Mean IoU over evaluated images.
    """
    pairs = list_train_pairs()

    def fold_assignment(seed_):
        # Reproduce the train-time shuffled fold split for a given seed.
        rng = np.random.RandomState(seed_)
        idx = np.arange(len(pairs))
        rng.shuffle(idx)
        fold_arr_ = np.array_split(idx, folds)
        return {ii: k for k in range(folds) for ii in fold_arr_[k]}

    grouped = []
    for ckpt_dir, base, ckpt_seed in ckpt_specs:
        fold_models = {}
        for fd in sorted(Path(ckpt_dir).glob('fold_*')):
            ckpt = fd / 'best.pth'
            if not ckpt.exists():
                continue
            k = int(fd.name.split('_')[1])
            m = UNet(in_ch=5, base=base, out_ch=1).to(device)
            m.load_state_dict(torch.load(ckpt, map_location=device))
            m.eval()
            fold_models[k] = m
        fold_of_seed = fold_assignment(ckpt_seed)
        grouped.append((fold_models, fold_of_seed))
        print(f' {ckpt_dir} has folds: {sorted(fold_models)} (seed={ckpt_seed})')
    mean = np.array(_IMAGENET_MEAN, dtype=np.float32)
    std = np.array(_IMAGENET_STD, dtype=np.float32)
    train_pred_dir = Path('dataset/train/predictions')
    if save:
        train_pred_dir.mkdir(exist_ok=True)
    palette = load_palette()
    predictor = load_threshold_predictor()
    if predictor is not None:
        print(f'Using learned per-image threshold predictor')
    all_p, all_g = ([], [])
    for i, (stem, img_p, sc_p, gt_p) in enumerate(pairs):
        # Collect, per checkpoint group, the model that held image i out.
        ensemble_models = []
        for fold_models, fold_of_seed in grouped:
            k_dir = fold_of_seed[i]
            if k_dir in fold_models:
                ensemble_models.append(fold_models[k_dir])
        if not ensemble_models:
            continue
        img = np.array(Image.open(img_p).convert('RGB'))
        sc = np.array(Image.open(sc_p).convert('L'))
        gt = (np.array(Image.open(gt_p)) > 0).astype(np.uint8)
        img_r = cv2.resize(img, (TRAIN_W, TRAIN_H), interpolation=cv2.INTER_LINEAR)
        sc_r = cv2.resize(sc, (TRAIN_W, TRAIN_H), interpolation=cv2.INTER_NEAREST)
        img_f = (img_r.astype(np.float32) / 255.0 - mean) / std
        x = torch.cat(
            [torch.from_numpy(img_f.transpose(2, 0, 1)),
             torch.from_numpy(encode_scribble(sc_r))], 0,
        ).unsqueeze(0).to(device)
        prob_sum = None
        for m in ensemble_models:
            p = tta_predict(m, x, device, scales=(0.7, 1.0, 1.3))
            prob_sum = p if prob_sum is None else prob_sum + p
        prob = prob_sum / len(ensemble_models)
        prob_full = cv2.resize(prob, (ORIG_W, ORIG_H), interpolation=cv2.INTER_LINEAR)
        thresh = predict_threshold(prob_full, sc, predictor)
        pred = (prob_full > thresh).astype(np.uint8)
        pred = postprocess(pred, min_area=200, close_ks=9)
        pred[sc == 0] = 0
        pred[sc == 1] = 1
        all_p.append(pred)
        all_g.append(gt)
        if save:
            out_img = Image.fromarray(pred.astype(np.uint8), mode='P')
            out_img.putpalette(palette)
            out_img.save(train_pred_dir / f'{stem}.png')
    bg, fg, miou = evaluate_predictions(all_p, all_g)
    print(f'OOF ensemble: bg={bg:.4f} fg={fg:.4f} mIoU={miou:.4f} (n={len(all_p)})')
    if save:
        print(f'Saved {len(all_p)} OOF predictions to {train_pred_dir}')
    return miou


def main():
    """CLI entry point: predict test1, or evaluate the OOF ensemble."""
    p = argparse.ArgumentParser()
    p.add_argument(
        '--ckpt-dirs', nargs='+', required=True,
        help="One or more 'path:base' entries, e.g. 'runs_global_unet:48 runs_v2:64'",
    )
    p.add_argument('--gpu', type=int, default=0)
    p.add_argument(
        '--eval', action='store_true',
        help='Evaluate out-of-fold ensemble on training set instead of predicting test1',
    )
    p.add_argument(
        '--save', action='store_true',
        help='With --eval, also save predictions to dataset/train/predictions/',
    )
    p.add_argument('--folds', type=int, default=5)
    p.add_argument('--seed', type=int, default=42)
    args = p.parse_args()
    device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
    ckpt_specs = parse_specs(args.ckpt_dirs)
    if args.eval:
        eval_oof_ensemble(ckpt_specs, args.folds, args.seed, device, save=args.save)
    else:
        models = load_models(ckpt_specs, device)
        if not models:
            print('No models found.')
            return
        print(f'Total models in ensemble: {len(models)}')
        predict_test1(models, device)


if __name__ == '__main__':
    main()