| import argparse |
| import os |
| import json |
| from pathlib import Path |
| import numpy as np |
| from PIL import Image |
| import torch |
| import torch.nn.functional as F |
| import cv2 |
| from train_global_unet import UNet, encode_scribble, list_test_pairs, list_train_pairs, load_palette, evaluate_predictions, tta_predict, TRAIN_H, TRAIN_W, ORIG_H, ORIG_W, TEST_PRED |
|
|
def predict_threshold(prob, sc, predictor):
    """Return the binarisation threshold for one probability map.

    With no predictor the fixed threshold 0.5 is returned. Otherwise five
    summary features are fed to a linear model:

    1. share of foreground scribble pixels (``sc == 1``) among all scribbled
       pixels (``sc`` equal to 0 or 1), or 0.5 when nothing is scribbled;
    2. fraction of pixels with ``prob > 0.5``;
    3. mean of ``prob``;
    4. standard deviation of ``prob``;
    5. mean binary entropy of ``prob`` clipped to ``[1e-6, 1 - 1e-6]``.

    ``features @ weights + bias`` is then clipped to ``predictor['clip']``.

    Args:
        prob: array of foreground probabilities.
        sc: scribble label array; only the values 0 and 1 are counted.
        predictor: ``None`` or a mapping with ``'weights'`` (five numbers),
            ``'bias'`` and ``'clip'`` (``[low, high]``).

    Returns:
        The threshold as a Python float.
    """
    if predictor is None:
        return 0.5
    n_fg = (sc == 1).sum()
    n_bg = (sc == 0).sum()
    n_scribbled = n_fg + n_bg
    fg_share = n_fg / n_scribbled if n_scribbled > 0 else 0.5
    eps = 1e-06
    q = np.clip(prob, eps, 1 - eps)  # avoid log(0) in the entropy term
    mean_entropy = -(q * np.log(q) + (1 - q) * np.log(1 - q)).mean()
    features = np.array([
        fg_share,
        (prob > 0.5).mean(),
        prob.mean(),
        prob.std(),
        mean_entropy,
    ])
    raw = float(features @ np.array(predictor['weights']) + predictor['bias'])
    low, high = predictor['clip'][0], predictor['clip'][1]
    return float(np.clip(raw, low, high))
|
|
def load_threshold_predictor(path='threshold_predictor.json'):
    """Load the optional per-image threshold model from a JSON file.

    Args:
        path: file to read. Relative paths resolve against the current working
            directory. Defaults to ``threshold_predictor.json``.

    Returns:
        The decoded JSON object, or ``None`` when the file does not exist.
        ``predict_threshold`` reads its ``'weights'``, ``'bias'`` and
        ``'clip'`` keys.
    """
    p = Path(path)
    if not p.exists():
        return None
    # Close the handle deterministically instead of leaving it to the GC.
    with p.open(encoding='utf-8') as fh:
        return json.load(fh)
|
|
def load_models(ckpt_specs, device):
    """Instantiate every ``fold_*/best.pth`` checkpoint referenced by ``ckpt_specs``.

    Args:
        ckpt_specs: iterable of ``(ckpt_dir, base, seed)`` tuples. ``base`` is
            passed to ``UNet`` as its width; ``seed`` is only echoed in the log.
        device: device the models are moved to and checkpoints are mapped onto.

    Returns:
        List of UNet models in eval mode, ordered by spec and then by sorted
        fold directory name. Fold directories without ``best.pth`` are skipped.
    """
    loaded = []
    for run_dir, width, run_seed in ckpt_specs:
        weight_files = (d / 'best.pth' for d in sorted(Path(run_dir).glob('fold_*')))
        for weights_path in weight_files:
            if weights_path.exists():
                net = UNet(in_ch=5, base=width, out_ch=1).to(device)
                net.load_state_dict(torch.load(weights_path, map_location=device))
                net.eval()
                loaded.append(net)
                print(f' loaded {weights_path} (base={width}, seed={run_seed})')
    return loaded
|
|
def parse_specs(spec_strs):
    """Parse ``'path:base[:seed]'`` checkpoint specs.

    Args:
        spec_strs: iterable of strings such as ``'runs_global_unet:48'`` or
            ``'runs_v2:64:7'``. ``base`` is the UNet width handed to the model
            constructor. ``seed`` defaults to 42 when omitted.

    Returns:
        List of ``(Path(path), base, seed)`` tuples, in input order.

    Raises:
        ValueError: if a spec does not have exactly 2 or 3 ':'-separated
            fields, or if ``base``/``seed`` is not an integer.
    """
    out = []
    for s in spec_strs:
        parts = s.split(':')
        if not 2 <= len(parts) <= 3:
            raise ValueError(
                f"invalid checkpoint spec {s!r}: expected 'path:base' or 'path:base:seed'"
            )
        try:
            base = int(parts[1])
            seed = int(parts[2]) if len(parts) > 2 else 42
        except ValueError as err:
            raise ValueError(
                f'invalid checkpoint spec {s!r}: base and seed must be integers'
            ) from err
        out.append((Path(parts[0]), base, seed))
    return out
|
|
def filter_small_components(pred, min_area=200):
    """Remove small foreground islands, then fill small background regions.

    Both passes use 8-connectivity. First, every foreground component of
    ``pred`` smaller than ``min_area`` pixels is set to 0. Then every background
    component of that intermediate result smaller than ``min_area`` is set to 1.
    The second pass fills any small background region, including ones touching
    the image border, not only enclosed holes.

    Args:
        pred: binary (0/1) uint8 mask, as ``cv2.connectedComponentsWithStats``
            requires an 8-bit single-channel image.
        min_area: components with strictly fewer pixels than this are changed.

    Returns:
        A new mask. ``pred`` itself is not modified.
    """
    out = pred.copy()

    # Label 0 is the zero-valued region of the labelled image, so skip it and
    # collect every other label whose area is below the limit. One np.isin pass
    # replaces a per-label full-image comparison (O(labels * pixels)).
    _, lab_fg, stats_fg, _ = cv2.connectedComponentsWithStats(pred, connectivity=8)
    small_fg = np.flatnonzero(stats_fg[1:, cv2.CC_STAT_AREA] < min_area) + 1
    out[np.isin(lab_fg, small_fg)] = 0

    # Same procedure on the inverted mask, so background regions become labels >= 1.
    inv = 1 - out
    _, lab_bg, stats_bg, _ = cv2.connectedComponentsWithStats(inv, connectivity=8)
    small_bg = np.flatnonzero(stats_bg[1:, cv2.CC_STAT_AREA] < min_area) + 1
    out[np.isin(lab_bg, small_bg)] = 1
    return out
|
|
def postprocess(pred, min_area=200, close_ks=9):
    """Clean a binary mask with an optional morphological closing and area filtering.

    Args:
        pred: uint8 binary mask.
        min_area: area limit forwarded to ``filter_small_components``.
        close_ks: side of the elliptical structuring element used for
            ``cv2.MORPH_CLOSE``. Values <= 1 skip the closing step.

    Returns:
        The mask returned by ``filter_small_components``.
    """
    closed = pred
    if close_ks > 1:
        ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (close_ks, close_ks))
        closed = cv2.morphologyEx(pred, cv2.MORPH_CLOSE, ellipse)
    return filter_small_components(closed, min_area)
|
|
def predict_test1(models, device):
    """Run the ensemble on every test pair and write palettised PNG masks.

    For each ``(stem, image_path, scribble_path)`` from ``list_test_pairs()``:

    1. The image is resized to ``(TRAIN_W, TRAIN_H)`` and normalised with the
       ImageNet mean/std.
    2. It is concatenated with ``encode_scribble`` of the resized scribble.
    3. The result is passed through ``tta_predict`` for every model, and the
       probabilities are averaged.
    4. The averaged map is resized back to the scribble's resolution and
       thresholded (0.5, or a per-image value when a threshold predictor is
       available).
    5. The mask is cleaned with ``postprocess`` and forced to agree with the
       scribble (0 -> background, 1 -> foreground).

    Each mask is saved as ``TEST_PRED / f'{stem}.png'``.

    Args:
        models: non-empty sequence of models already on ``device``.
        device: torch device the input tensor is moved to.

    Raises:
        ValueError: if ``models`` is empty.
    """
    if not models:
        raise ValueError('predict_test1 needs at least one model')
    palette = load_palette()
    predictor = load_threshold_predictor()
    if predictor is not None:
        # Double quotes inside the f-string keep this line valid before Python 3.12.
        print(f'Using learned per-image threshold predictor (clip={predictor["clip"]})')
    test_pairs = list_test_pairs()
    TEST_PRED.mkdir(parents=True, exist_ok=True)
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    for stem, img_p, sc_p in test_pairs:
        img = np.array(Image.open(img_p).convert('RGB'))
        sc = np.array(Image.open(sc_p).convert('L'))
        img_r = cv2.resize(img, (TRAIN_W, TRAIN_H), interpolation=cv2.INTER_LINEAR)
        sc_r = cv2.resize(sc, (TRAIN_W, TRAIN_H), interpolation=cv2.INTER_NEAREST)
        img_f = (img_r.astype(np.float32) / 255.0 - mean) / std
        x = torch.cat(
            [torch.from_numpy(img_f.transpose(2, 0, 1)), torch.from_numpy(encode_scribble(sc_r))],
            0,
        ).unsqueeze(0).to(device)
        prob_sum = None
        # Pure inference: make sure no autograd graph is built.
        with torch.no_grad():
            for m in models:
                p = tta_predict(m, x, device, scales=(0.7, 1.0, 1.3))
                prob_sum = p if prob_sum is None else prob_sum + p
        prob = prob_sum / len(models)
        # Resize to the scribble's own size so the scribble masks below always
        # line up. This equals (ORIG_W, ORIG_H) when every image has that size.
        out_h, out_w = sc.shape[:2]
        prob_full = cv2.resize(prob, (out_w, out_h), interpolation=cv2.INTER_LINEAR)
        thresh = predict_threshold(prob_full, sc, predictor)
        pred = (prob_full > thresh).astype(np.uint8)
        pred = postprocess(pred, min_area=200, close_ks=9)
        # Scribbled pixels are trusted labels: 0 = background, 1 = foreground.
        pred[sc == 0] = 0
        pred[sc == 1] = 1
        out_img = Image.fromarray(pred.astype(np.uint8), mode='P')
        out_img.putpalette(palette)
        out_img.save(TEST_PRED / f'{stem}.png')
    print(f'Wrote {len(test_pairs)} predictions to {TEST_PRED}')
|
|
def eval_oof_ensemble(ckpt_specs, folds, seed, device, save=False):
    """Score the checkpoint ensemble out-of-fold on the training set.

    For every spec ``(ckpt_dir, base, ckpt_seed)``:

    - The ``fold_<k>/best.pth`` checkpoints are loaded.
    - The training pairs are re-split into ``folds`` folds by shuffling the
      indices with ``np.random.RandomState(ckpt_seed)`` and applying
      ``np.array_split``.

    Each training image ``i`` is predicted by the fold-``k`` model of each spec,
    where ``k`` is the fold of ``i`` under that spec's seed. This is presumably
    the fold that held image ``i`` out during training — TODO confirm it
    mirrors the split used in train_global_unet.

    Inference, thresholding, post-processing and scribble clamping follow the
    same steps as test-time prediction. Images for which no spec has the
    required fold model are skipped.

    Args:
        ckpt_specs: iterable of ``(ckpt_dir, base, ckpt_seed)`` tuples.
        folds: number of folds used to rebuild the split.
        seed: not read by this function (each spec carries its own split
            seed). Kept for interface compatibility.
        device: torch device for the models and input tensors.
        save: when true, also write palettised masks to
            ``dataset/train/predictions/<stem>.png``.

    Returns:
        The mIoU from ``evaluate_predictions``, or ``float('nan')`` when no
        training image could be evaluated.
    """
    pairs = list_train_pairs()

    def fold_assignment(seed_):
        # Rebuild the shuffled K-fold split as a map: training-pair index -> fold id.
        rng = np.random.RandomState(seed_)
        idx = np.arange(len(pairs))
        rng.shuffle(idx)
        return {ii: k for k, chunk in enumerate(np.array_split(idx, folds)) for ii in chunk}

    grouped = []
    for ckpt_dir, base, ckpt_seed in ckpt_specs:
        fold_models = {}
        for fd in sorted(Path(ckpt_dir).glob('fold_*')):
            ckpt = fd / 'best.pth'
            if not ckpt.exists():
                continue
            k = int(fd.name.split('_')[1])  # 'fold_3' -> 3
            m = UNet(in_ch=5, base=base, out_ch=1).to(device)
            m.load_state_dict(torch.load(ckpt, map_location=device))
            m.eval()
            fold_models[k] = m
        grouped.append((fold_models, fold_assignment(ckpt_seed)))
        print(f' {ckpt_dir} has folds: {sorted(fold_models)} (seed={ckpt_seed})')

    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    train_pred_dir = Path('dataset/train/predictions')
    palette = None
    if save:
        # parents=True so a missing intermediate directory does not abort the run.
        train_pred_dir.mkdir(parents=True, exist_ok=True)
        palette = load_palette()
    predictor = load_threshold_predictor()
    if predictor is not None:
        print('Using learned per-image threshold predictor')

    all_p, all_g = [], []
    for i, (stem, img_p, sc_p, gt_p) in enumerate(pairs):
        # Per spec, take the model of the fold image i falls into under that spec's seed.
        ensemble_models = [
            fold_models[fold_of[i]]
            for fold_models, fold_of in grouped
            if fold_of[i] in fold_models
        ]
        if not ensemble_models:
            continue
        img = np.array(Image.open(img_p).convert('RGB'))
        sc = np.array(Image.open(sc_p).convert('L'))
        gt = (np.array(Image.open(gt_p)) > 0).astype(np.uint8)
        img_r = cv2.resize(img, (TRAIN_W, TRAIN_H), interpolation=cv2.INTER_LINEAR)
        sc_r = cv2.resize(sc, (TRAIN_W, TRAIN_H), interpolation=cv2.INTER_NEAREST)
        img_f = (img_r.astype(np.float32) / 255.0 - mean) / std
        x = torch.cat(
            [torch.from_numpy(img_f.transpose(2, 0, 1)), torch.from_numpy(encode_scribble(sc_r))],
            0,
        ).unsqueeze(0).to(device)
        prob_sum = None
        # Pure inference: make sure no autograd graph is built.
        with torch.no_grad():
            for m in ensemble_models:
                p = tta_predict(m, x, device, scales=(0.7, 1.0, 1.3))
                prob_sum = p if prob_sum is None else prob_sum + p
        prob = prob_sum / len(ensemble_models)
        # Resize to the scribble's own size so the scribble masks below always
        # line up. This equals (ORIG_W, ORIG_H) when every image has that size.
        out_h, out_w = sc.shape[:2]
        prob_full = cv2.resize(prob, (out_w, out_h), interpolation=cv2.INTER_LINEAR)
        thresh = predict_threshold(prob_full, sc, predictor)
        pred = (prob_full > thresh).astype(np.uint8)
        pred = postprocess(pred, min_area=200, close_ks=9)
        # Scribbled pixels are trusted labels: 0 = background, 1 = foreground.
        pred[sc == 0] = 0
        pred[sc == 1] = 1
        all_p.append(pred)
        all_g.append(gt)
        if save:
            out_img = Image.fromarray(pred.astype(np.uint8), mode='P')
            out_img.putpalette(palette)
            out_img.save(train_pred_dir / f'{stem}.png')

    if not all_p:
        print('OOF ensemble: no images evaluated (no fold checkpoint matched any training pair)')
        return float('nan')
    bg, fg, miou = evaluate_predictions(all_p, all_g)
    print(f'OOF ensemble: bg={bg:.4f} fg={fg:.4f} mIoU={miou:.4f} (n={len(all_p)})')
    if save:
        print(f'Saved {len(all_p)} OOF predictions to {train_pred_dir}')
    return miou
|
|
def main():
    """Command-line entry point.

    With ``--eval``, score the out-of-fold ensemble on the training set
    (optionally saving masks with ``--save``). Otherwise, load every checkpoint
    and write test1 predictions.
    """
    p = argparse.ArgumentParser()
    p.add_argument(
        '--ckpt-dirs',
        nargs='+',
        required=True,
        help="One or more 'path:base[:seed]' entries, e.g. 'runs_global_unet:48 runs_v2:64:7'; "
             'seed defaults to --seed',
    )
    p.add_argument('--gpu', type=int, default=0)
    p.add_argument('--eval', action='store_true', help='Evaluate out-of-fold ensemble on training set instead of predicting test1')
    p.add_argument('--save', action='store_true', help='With --eval, also save predictions to dataset/train/predictions/')
    p.add_argument('--folds', type=int, default=5)
    p.add_argument(
        '--seed',
        type=int,
        default=42,
        help="Fold-split seed for --ckpt-dirs entries that do not give an explicit ':seed'",
    )
    args = p.parse_args()
    device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
    ckpt_specs = parse_specs(args.ckpt_dirs)
    # parse_specs fills a missing seed with a hard-coded 42. Substitute --seed
    # for those entries so the flag actually takes effect; the default (42)
    # reproduces the previous behaviour. A seed is explicit iff the spec
    # string has a third ':'-separated field.
    ckpt_specs = [
        (path, base, spec_seed if raw.count(':') >= 2 else args.seed)
        for raw, (path, base, spec_seed) in zip(args.ckpt_dirs, ckpt_specs)
    ]
    if args.eval:
        eval_oof_ensemble(ckpt_specs, args.folds, args.seed, device, save=args.save)
        return
    models = load_models(ckpt_specs, device)
    if not models:
        print('No models found.')
        return
    print(f'Total models in ensemble: {len(models)}')
    predict_test1(models, device)


if __name__ == '__main__':
    main()