# File size: 8,894 Bytes
# b63c30f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
import argparse
import os
import json
from pathlib import Path
import numpy as np
from PIL import Image
import torch
import torch.nn.functional as F
import cv2
from train_global_unet import UNet, encode_scribble, list_test_pairs, list_train_pairs, load_palette, evaluate_predictions, tta_predict, TRAIN_H, TRAIN_W, ORIG_H, ORIG_W, TEST_PRED

def predict_threshold(prob, sc, predictor):
    """Choose the binarisation threshold for one probability map.

    Args:
        prob: numpy array of foreground probabilities.
        sc: numpy array of scribble labels; pixels equal to 1 are counted as
            foreground scribbles and pixels equal to 0 as background
            scribbles. Any other value is ignored here.
        predictor: ``None`` or a mapping with ``'weights'`` (five
            coefficients), ``'bias'`` and ``'clip'`` (lower and upper bound).

    Returns:
        ``0.5`` when no predictor is given; otherwise the linear model's
        output for the five image statistics, clipped to ``predictor['clip']``
        and returned as a Python float.
    """
    if predictor is None:
        return 0.5

    # Feature 1: share of scribble pixels marked foreground (0.5 when the
    # image carries no 0/1 scribble pixels at all).
    n_fg = (sc == 1).sum()
    n_bg = (sc == 0).sum()
    n_marked = n_fg + n_bg
    if n_marked > 0:
        fg_share = n_fg / n_marked
    else:
        fg_share = 0.5

    # Feature 5: mean binary entropy; probabilities are clipped away from
    # 0 and 1 so the logarithms stay finite.
    clipped = np.clip(prob, 1e-06, 1 - 1e-06)
    neg_entropy_map = clipped * np.log(clipped) + (1 - clipped) * np.log(1 - clipped)

    features = np.array([
        fg_share,
        (prob > 0.5).mean(),  # fraction predicted foreground at 0.5
        prob.mean(),
        prob.std(),
        -neg_entropy_map.mean(),
    ])

    raw = float(features @ np.array(predictor['weights']) + predictor['bias'])
    bounds = predictor['clip']
    return float(np.clip(raw, bounds[0], bounds[1]))

def load_threshold_predictor():
    """Load the optional threshold predictor from the working directory.

    Looks for ``threshold_predictor.json`` relative to the current working
    directory.

    Returns:
        The parsed JSON content, or ``None`` when the file does not exist.
    """
    path = Path('threshold_predictor.json')
    if not path.exists():
        return None
    # Use a context manager so the handle is closed deterministically instead
    # of being left to the garbage collector.
    with open(path, encoding='utf-8') as fh:
        return json.load(fh)

def load_models(ckpt_specs, device):
    """Instantiate one UNet per ``fold_*/best.pth`` found under each spec.

    Args:
        ckpt_specs: iterable of ``(checkpoint_dir, base, seed)`` tuples.
            ``base`` is forwarded to ``UNet``; ``seed`` is only echoed in the
            log line.
        device: device the weights are mapped to and the model is moved to.

    Returns:
        List of models in eval mode, ordered by spec and then by sorted fold
        directory name. Fold directories without ``best.pth`` are skipped.
    """
    loaded = []
    for run_dir, width, run_seed in ckpt_specs:
        for fold_dir in sorted(Path(run_dir).glob('fold_*')):
            weights = fold_dir / 'best.pth'
            if weights.exists():
                net = UNet(in_ch=5, base=width, out_ch=1).to(device)
                state = torch.load(weights, map_location=device)
                net.load_state_dict(state)
                net.eval()
                loaded.append(net)
                print(f'  loaded {weights}  (base={width}, seed={run_seed})')
    return loaded

def parse_specs(spec_strs):
    """Parse ``path:base[:seed]`` strings into ``(Path, base, seed)`` tuples.

    Args:
        spec_strs: iterable of colon-separated checkpoint specs. The first
            field is the checkpoint directory, the second an integer ``base``
            and the optional third an integer ``seed`` (42 when omitted).
            Fields after the third are ignored.

    Returns:
        List of ``(Path, int, int)`` tuples in input order.

    Raises:
        ValueError: if a spec has no ``base`` field, or ``base``/``seed`` is
            not an integer. The message names the offending spec.
    """
    specs = []
    for spec in spec_strs:
        parts = spec.split(':')
        if len(parts) < 2:
            # Report the bad spec instead of failing with a bare IndexError.
            raise ValueError(
                f"invalid checkpoint spec {spec!r}: expected 'path:base[:seed]'"
            )
        try:
            base = int(parts[1])
            seed = int(parts[2]) if len(parts) > 2 else 42
        except ValueError as err:
            raise ValueError(
                f"invalid checkpoint spec {spec!r}: base and seed must be integers"
            ) from err
        specs.append((Path(parts[0]), base, seed))
    return specs

def filter_small_components(pred, min_area=200):
    """Flip connected regions smaller than ``min_area`` pixels.

    Two passes with 8-connectivity: first, small foreground regions of
    ``pred`` are set to 0; then small background regions of that result are
    set to 1. The input array is not modified.

    Args:
        pred: mask array; the ``1 - mask`` inversion used for the second pass
            assumes values are 0/1 (callers in this file pass a uint8 mask).
        min_area: regions with strictly fewer pixels than this are flipped.

    Returns:
        A cleaned copy of ``pred``.
    """

    def _tiny_regions(binary):
        # Boolean map of pixels belonging to non-zero regions of `binary`
        # whose area is below the limit (label 0 is never considered).
        count, labels, stats, _ = cv2.connectedComponentsWithStats(binary, connectivity=8)
        tiny_ids = [
            label for label in range(1, count)
            if stats[label, cv2.CC_STAT_AREA] < min_area
        ]
        return np.isin(labels, tiny_ids)

    cleaned = pred.copy()
    cleaned[_tiny_regions(pred)] = 0
    # The hole-filling pass must see the mask *after* speck removal.
    cleaned[_tiny_regions(1 - cleaned)] = 1
    return cleaned

def postprocess(pred, min_area=200, close_ks=9):
    """Clean a mask: optional morphological closing, then small-region filter.

    Args:
        pred: mask array handed to OpenCV (callers in this file pass uint8).
        min_area: forwarded to ``filter_small_components``.
        close_ks: side length of the elliptical closing kernel; closing is
            skipped when this is 1 or less.

    Returns:
        The mask returned by ``filter_small_components``.
    """
    closed = pred
    if close_ks > 1:
        ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (close_ks, close_ks))
        closed = cv2.morphologyEx(pred, cv2.MORPH_CLOSE, ellipse)
    return filter_small_components(closed, min_area)

def predict_test1(models, device):
    """Run the ensemble over the test pairs and write one palette PNG each.

    For every ``(stem, image_path, scribble_path)`` from
    ``list_test_pairs()`` the image and scribble are resized to the training
    resolution, every model is evaluated through ``tta_predict`` and the
    outputs are averaged. The averaged map is resized to
    ``(ORIG_W, ORIG_H)``, thresholded, post-processed and saved as
    ``TEST_PRED/<stem>.png``.

    Args:
        models: non-empty sequence of models accepted by ``tta_predict``.
        device: device the input tensor is moved to.
    """
    palette = load_palette()
    predictor = load_threshold_predictor()
    if predictor is not None:
        # Inner double quotes: reusing the enclosing quote character inside an
        # f-string is a SyntaxError on Python versions before 3.12.
        print(f"Using learned per-image threshold predictor (clip={predictor['clip']})")
    test_pairs = list_test_pairs()
    TEST_PRED.mkdir(parents=True, exist_ok=True)
    # Per-channel normalisation constants (the commonly used ImageNet
    # mean/std values).
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    for stem, img_p, sc_p in test_pairs:
        img = np.array(Image.open(img_p).convert('RGB'))
        sc = np.array(Image.open(sc_p).convert('L'))
        img_r = cv2.resize(img, (TRAIN_W, TRAIN_H), interpolation=cv2.INTER_LINEAR)
        # Nearest-neighbour keeps scribble labels discrete.
        sc_r = cv2.resize(sc, (TRAIN_W, TRAIN_H), interpolation=cv2.INTER_NEAREST)
        img_f = (img_r.astype(np.float32) / 255.0 - mean) / std
        # Stack the normalised image (channels first) with the encoded
        # scribble and add a batch dimension.
        x = torch.cat([torch.from_numpy(img_f.transpose(2, 0, 1)), torch.from_numpy(encode_scribble(sc_r))], 0).unsqueeze(0).to(device)
        # Average the per-model outputs.
        # NOTE(review): tta_predict presumably returns a 2-D numpy
        # probability map (it is passed to cv2.resize below) - confirm.
        prob_sum = None
        for m in models:
            p = tta_predict(m, x, device, scales=(0.7, 1.0, 1.3))
            prob_sum = p if prob_sum is None else prob_sum + p
        prob = prob_sum / len(models)
        prob_full = cv2.resize(prob, (ORIG_W, ORIG_H), interpolation=cv2.INTER_LINEAR)
        # Threshold features use the full-resolution scribble, not sc_r.
        thresh = predict_threshold(prob_full, sc, predictor)
        pred = (prob_full > thresh).astype(np.uint8)
        pred = postprocess(pred, min_area=200, close_ks=9)
        # Scribble pixels override the prediction: 0 forces background and
        # 1 forces foreground.
        # NOTE(review): assumes the scribble is already ORIG_H x ORIG_W so
        # the boolean masks match `pred` - confirm against the dataset.
        pred[sc == 0] = 0
        pred[sc == 1] = 1
        out_img = Image.fromarray(pred.astype(np.uint8), mode='P')
        out_img.putpalette(palette)
        out_img.save(TEST_PRED / f'{stem}.png')
    print(f'Wrote {len(test_pairs)} predictions to {TEST_PRED}')

def eval_oof_ensemble(ckpt_specs, folds, seed, device, save=False):
    """Score the ensemble on the training set, one fold model per spec.

    For each checkpoint spec the training indices are shuffled with that
    spec's own seed and split into ``folds`` chunks. Sample ``i`` is then
    predicted only by the ``fold_<k>`` model whose ``k`` is the chunk that
    contains ``i``.
    NOTE(review): this is out-of-fold only if the training script built its
    validation folds with the same shuffle - verify against it.

    Args:
        ckpt_specs: iterable of ``(checkpoint_dir, base, seed)`` tuples.
        folds: number of chunks the shuffled indices are split into.
        seed: accepted for interface compatibility; this function does not
            read it (fold assignment uses each spec's seed).
        device: device for the models and input tensors.
        save: when true, also write each prediction as a palette PNG to
            ``dataset/train/predictions/<stem>.png``.

    Returns:
        The mIoU value returned by ``evaluate_predictions``.
    """
    pairs = list_train_pairs()

    def fold_assignment(seed_):
        # Map every sample index to the id of the chunk it lands in after a
        # seeded shuffle.
        rng = np.random.RandomState(seed_)
        idx = np.arange(len(pairs))
        rng.shuffle(idx)
        fold_arr_ = np.array_split(idx, folds)
        return {ii: k for k in range(folds) for ii in fold_arr_[k]}

    # One (fold id -> model, sample index -> fold id) pair per spec.
    grouped = []
    for ckpt_dir, base, ckpt_seed in ckpt_specs:
        fold_models = {}
        for fd in sorted(Path(ckpt_dir).glob('fold_*')):
            ckpt = fd / 'best.pth'
            if not ckpt.exists():
                continue
            # The fold id is the number after the first underscore in the
            # directory name.
            k = int(fd.name.split('_')[1])
            m = UNet(in_ch=5, base=base, out_ch=1).to(device)
            m.load_state_dict(torch.load(ckpt, map_location=device))
            m.eval()
            fold_models[k] = m
        fold_of_seed = fold_assignment(ckpt_seed)
        grouped.append((fold_models, fold_of_seed))
        print(f'  {ckpt_dir} has folds: {sorted(fold_models)}  (seed={ckpt_seed})')
    # Per-channel normalisation constants (the commonly used ImageNet
    # mean/std values).
    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
    train_pred_dir = Path('dataset/train/predictions')
    if save:
        # parents=True matches how predict_test1 creates its output directory
        # and avoids FileNotFoundError when an intermediate directory is
        # missing.
        train_pred_dir.mkdir(parents=True, exist_ok=True)
        palette = load_palette()
    predictor = load_threshold_predictor()
    if predictor is not None:
        print('Using learned per-image threshold predictor')
    all_p, all_g = ([], [])
    for i, (stem, img_p, sc_p, gt_p) in enumerate(pairs):
        # Collect, per spec, the model assigned to this sample's fold.
        ensemble_models = []
        for fold_models, fold_of_seed in grouped:
            k_dir = fold_of_seed[i]
            if k_dir in fold_models:
                ensemble_models.append(fold_models[k_dir])
        if not ensemble_models:
            # No checkpoint covers this sample; it is left out of the score.
            continue
        img = np.array(Image.open(img_p).convert('RGB'))
        sc = np.array(Image.open(sc_p).convert('L'))
        # Any non-zero ground-truth value counts as foreground.
        gt = (np.array(Image.open(gt_p)) > 0).astype(np.uint8)
        img_r = cv2.resize(img, (TRAIN_W, TRAIN_H), interpolation=cv2.INTER_LINEAR)
        # Nearest-neighbour keeps scribble labels discrete.
        sc_r = cv2.resize(sc, (TRAIN_W, TRAIN_H), interpolation=cv2.INTER_NEAREST)
        img_f = (img_r.astype(np.float32) / 255.0 - mean) / std
        x = torch.cat([torch.from_numpy(img_f.transpose(2, 0, 1)), torch.from_numpy(encode_scribble(sc_r))], 0).unsqueeze(0).to(device)
        prob_sum = None
        for m in ensemble_models:
            p = tta_predict(m, x, device, scales=(0.7, 1.0, 1.3))
            prob_sum = p if prob_sum is None else prob_sum + p
        prob = prob_sum / len(ensemble_models)
        prob_full = cv2.resize(prob, (ORIG_W, ORIG_H), interpolation=cv2.INTER_LINEAR)
        thresh = predict_threshold(prob_full, sc, predictor)
        pred = (prob_full > thresh).astype(np.uint8)
        pred = postprocess(pred, min_area=200, close_ks=9)
        # Scribble pixels override the prediction: 0 forces background and
        # 1 forces foreground.
        # NOTE(review): assumes the scribble is already ORIG_H x ORIG_W -
        # confirm against the dataset.
        pred[sc == 0] = 0
        pred[sc == 1] = 1
        all_p.append(pred)
        all_g.append(gt)
        if save:
            out_img = Image.fromarray(pred.astype(np.uint8), mode='P')
            out_img.putpalette(palette)
            out_img.save(train_pred_dir / f'{stem}.png')
    bg, fg, miou = evaluate_predictions(all_p, all_g)
    print(f'OOF ensemble: bg={bg:.4f}  fg={fg:.4f}  mIoU={miou:.4f}  (n={len(all_p)})')
    if save:
        print(f'Saved {len(all_p)} OOF predictions to {train_pred_dir}')
    return miou

def main():
    """Command-line entry point: evaluate out-of-fold or predict the test set.

    With ``--eval`` the out-of-fold ensemble is scored on the training set
    (``--save`` also writes the predictions). Without it, every checkpoint
    found under ``--ckpt-dirs`` is loaded and test predictions are written.
    CUDA device ``--gpu`` is used when CUDA is available, otherwise the CPU.
    """
    parser = argparse.ArgumentParser()
    # The help text spells out the optional third field that parse_specs
    # accepts.
    parser.add_argument('--ckpt-dirs', nargs='+', required=True, help="One or more 'path:base[:seed]' entries, e.g. 'runs_global_unet:48 runs_v2:64:7' (seed defaults to 42)")
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--eval', action='store_true', help='Evaluate out-of-fold ensemble on training set instead of predicting test1')
    parser.add_argument('--save', action='store_true', help='With --eval, also save predictions to dataset/train/predictions/')
    parser.add_argument('--folds', type=int, default=5)
    # NOTE(review): forwarded to eval_oof_ensemble, which derives folds from
    # each spec's own seed - confirm whether this flag is still needed.
    parser.add_argument('--seed', type=int, default=42)
    args = parser.parse_args()

    device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
    ckpt_specs = parse_specs(args.ckpt_dirs)
    if args.eval:
        eval_oof_ensemble(ckpt_specs, args.folds, args.seed, device, save=args.save)
        return

    models = load_models(ckpt_specs, device)
    if not models:
        print('No models found.')
        return
    print(f'Total models in ensemble: {len(models)}')
    predict_test1(models, device)
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()