Enorenio commited on
Commit
b63c30f
·
verified ·
1 Parent(s): 38429f5

Add inference pipeline (comments stripped)

Browse files
Files changed (1) hide show
  1. predict_ensemble.py +202 -0
predict_ensemble.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import json
4
+ from pathlib import Path
5
+ import numpy as np
6
+ from PIL import Image
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import cv2
10
+ from train_global_unet import UNet, encode_scribble, list_test_pairs, list_train_pairs, load_palette, evaluate_predictions, tta_predict, TRAIN_H, TRAIN_W, ORIG_H, ORIG_W, TEST_PRED
11
+
12
+ def predict_threshold(prob, sc, predictor):
13
+ if predictor is None:
14
+ return 0.5
15
+ fg_scrib = (sc == 1).sum()
16
+ bg_scrib = (sc == 0).sum()
17
+ scrib_fg_ratio = fg_scrib / (fg_scrib + bg_scrib) if fg_scrib + bg_scrib > 0 else 0.5
18
+ pred_fg_frac = (prob > 0.5).mean()
19
+ prob_mean = prob.mean()
20
+ prob_std = prob.std()
21
+ p_clip = np.clip(prob, 1e-06, 1 - 1e-06)
22
+ entropy = -(p_clip * np.log(p_clip) + (1 - p_clip) * np.log(1 - p_clip)).mean()
23
+ feats = np.array([scrib_fg_ratio, pred_fg_frac, prob_mean, prob_std, entropy])
24
+ t = float(feats @ np.array(predictor['weights']) + predictor['bias'])
25
+ return float(np.clip(t, predictor['clip'][0], predictor['clip'][1]))
26
+
27
+ def load_threshold_predictor():
28
+ p = Path('threshold_predictor.json')
29
+ if not p.exists():
30
+ return None
31
+ return json.load(open(p))
32
+
33
+ def load_models(ckpt_specs, device):
34
+ models = []
35
+ for ckpt_dir, base, seed in ckpt_specs:
36
+ for fd in sorted(Path(ckpt_dir).glob('fold_*')):
37
+ ckpt = fd / 'best.pth'
38
+ if not ckpt.exists():
39
+ continue
40
+ m = UNet(in_ch=5, base=base, out_ch=1).to(device)
41
+ m.load_state_dict(torch.load(ckpt, map_location=device))
42
+ m.eval()
43
+ models.append(m)
44
+ print(f' loaded {ckpt} (base={base}, seed={seed})')
45
+ return models
46
+
47
+ def parse_specs(spec_strs):
48
+ out = []
49
+ for s in spec_strs:
50
+ parts = s.split(':')
51
+ path = parts[0]
52
+ base = int(parts[1])
53
+ seed = int(parts[2]) if len(parts) > 2 else 42
54
+ out.append((Path(path), base, seed))
55
+ return out
56
+
57
+ def filter_small_components(pred, min_area=200):
58
+ n_fg, lab_fg, stats_fg, _ = cv2.connectedComponentsWithStats(pred, connectivity=8)
59
+ out = pred.copy()
60
+ for i in range(1, n_fg):
61
+ if stats_fg[i, cv2.CC_STAT_AREA] < min_area:
62
+ out[lab_fg == i] = 0
63
+ inv = 1 - out
64
+ n_bg, lab_bg, stats_bg, _ = cv2.connectedComponentsWithStats(inv, connectivity=8)
65
+ for i in range(1, n_bg):
66
+ if stats_bg[i, cv2.CC_STAT_AREA] < min_area:
67
+ out[lab_bg == i] = 1
68
+ return out
69
+
70
+ def postprocess(pred, min_area=200, close_ks=9):
71
+ if close_ks > 1:
72
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (close_ks, close_ks))
73
+ pred = cv2.morphologyEx(pred, cv2.MORPH_CLOSE, kernel)
74
+ pred = filter_small_components(pred, min_area)
75
+ return pred
76
+
77
+ def predict_test1(models, device):
78
+ palette = load_palette()
79
+ predictor = load_threshold_predictor()
80
+ if predictor is not None:
81
+ print(f'Using learned per-image threshold predictor (clip={predictor['clip']})')
82
+ test_pairs = list_test_pairs()
83
+ TEST_PRED.mkdir(parents=True, exist_ok=True)
84
+ mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
85
+ std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
86
+ for stem, img_p, sc_p in test_pairs:
87
+ img = np.array(Image.open(img_p).convert('RGB'))
88
+ sc = np.array(Image.open(sc_p).convert('L'))
89
+ img_r = cv2.resize(img, (TRAIN_W, TRAIN_H), interpolation=cv2.INTER_LINEAR)
90
+ sc_r = cv2.resize(sc, (TRAIN_W, TRAIN_H), interpolation=cv2.INTER_NEAREST)
91
+ img_f = (img_r.astype(np.float32) / 255.0 - mean) / std
92
+ x = torch.cat([torch.from_numpy(img_f.transpose(2, 0, 1)), torch.from_numpy(encode_scribble(sc_r))], 0).unsqueeze(0).to(device)
93
+ prob_sum = None
94
+ for m in models:
95
+ p = tta_predict(m, x, device, scales=(0.7, 1.0, 1.3))
96
+ prob_sum = p if prob_sum is None else prob_sum + p
97
+ prob = prob_sum / len(models)
98
+ prob_full = cv2.resize(prob, (ORIG_W, ORIG_H), interpolation=cv2.INTER_LINEAR)
99
+ thresh = predict_threshold(prob_full, sc, predictor)
100
+ pred = (prob_full > thresh).astype(np.uint8)
101
+ pred = postprocess(pred, min_area=200, close_ks=9)
102
+ pred[sc == 0] = 0
103
+ pred[sc == 1] = 1
104
+ out_img = Image.fromarray(pred.astype(np.uint8), mode='P')
105
+ out_img.putpalette(palette)
106
+ out_img.save(TEST_PRED / f'{stem}.png')
107
+ print(f'Wrote {len(test_pairs)} predictions to {TEST_PRED}')
108
+
109
+ def eval_oof_ensemble(ckpt_specs, folds, seed, device, save=False):
110
+ pairs = list_train_pairs()
111
+
112
+ def fold_assignment(seed_):
113
+ rng = np.random.RandomState(seed_)
114
+ idx = np.arange(len(pairs))
115
+ rng.shuffle(idx)
116
+ fold_arr_ = np.array_split(idx, folds)
117
+ return {ii: k for k in range(folds) for ii in fold_arr_[k]}
118
+ grouped = []
119
+ for ckpt_dir, base, ckpt_seed in ckpt_specs:
120
+ fold_models = {}
121
+ for fd in sorted(Path(ckpt_dir).glob('fold_*')):
122
+ ckpt = fd / 'best.pth'
123
+ if not ckpt.exists():
124
+ continue
125
+ k = int(fd.name.split('_')[1])
126
+ m = UNet(in_ch=5, base=base, out_ch=1).to(device)
127
+ m.load_state_dict(torch.load(ckpt, map_location=device))
128
+ m.eval()
129
+ fold_models[k] = m
130
+ fold_of_seed = fold_assignment(ckpt_seed)
131
+ grouped.append((fold_models, fold_of_seed))
132
+ print(f' {ckpt_dir} has folds: {sorted(fold_models)} (seed={ckpt_seed})')
133
+ mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
134
+ std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
135
+ train_pred_dir = Path('dataset/train/predictions')
136
+ if save:
137
+ train_pred_dir.mkdir(exist_ok=True)
138
+ palette = load_palette()
139
+ predictor = load_threshold_predictor()
140
+ if predictor is not None:
141
+ print(f'Using learned per-image threshold predictor')
142
+ all_p, all_g = ([], [])
143
+ for i, (stem, img_p, sc_p, gt_p) in enumerate(pairs):
144
+ ensemble_models = []
145
+ for fold_models, fold_of_seed in grouped:
146
+ k_dir = fold_of_seed[i]
147
+ if k_dir in fold_models:
148
+ ensemble_models.append(fold_models[k_dir])
149
+ if not ensemble_models:
150
+ continue
151
+ img = np.array(Image.open(img_p).convert('RGB'))
152
+ sc = np.array(Image.open(sc_p).convert('L'))
153
+ gt = (np.array(Image.open(gt_p)) > 0).astype(np.uint8)
154
+ img_r = cv2.resize(img, (TRAIN_W, TRAIN_H), interpolation=cv2.INTER_LINEAR)
155
+ sc_r = cv2.resize(sc, (TRAIN_W, TRAIN_H), interpolation=cv2.INTER_NEAREST)
156
+ img_f = (img_r.astype(np.float32) / 255.0 - mean) / std
157
+ x = torch.cat([torch.from_numpy(img_f.transpose(2, 0, 1)), torch.from_numpy(encode_scribble(sc_r))], 0).unsqueeze(0).to(device)
158
+ prob_sum = None
159
+ for m in ensemble_models:
160
+ p = tta_predict(m, x, device, scales=(0.7, 1.0, 1.3))
161
+ prob_sum = p if prob_sum is None else prob_sum + p
162
+ prob = prob_sum / len(ensemble_models)
163
+ prob_full = cv2.resize(prob, (ORIG_W, ORIG_H), interpolation=cv2.INTER_LINEAR)
164
+ thresh = predict_threshold(prob_full, sc, predictor)
165
+ pred = (prob_full > thresh).astype(np.uint8)
166
+ pred = postprocess(pred, min_area=200, close_ks=9)
167
+ pred[sc == 0] = 0
168
+ pred[sc == 1] = 1
169
+ all_p.append(pred)
170
+ all_g.append(gt)
171
+ if save:
172
+ out_img = Image.fromarray(pred.astype(np.uint8), mode='P')
173
+ out_img.putpalette(palette)
174
+ out_img.save(train_pred_dir / f'{stem}.png')
175
+ bg, fg, miou = evaluate_predictions(all_p, all_g)
176
+ print(f'OOF ensemble: bg={bg:.4f} fg={fg:.4f} mIoU={miou:.4f} (n={len(all_p)})')
177
+ if save:
178
+ print(f'Saved {len(all_p)} OOF predictions to {train_pred_dir}')
179
+ return miou
180
+
181
+ def main():
182
+ p = argparse.ArgumentParser()
183
+ p.add_argument('--ckpt-dirs', nargs='+', required=True, help="One or more 'path:base' entries, e.g. 'runs_global_unet:48 runs_v2:64'")
184
+ p.add_argument('--gpu', type=int, default=0)
185
+ p.add_argument('--eval', action='store_true', help='Evaluate out-of-fold ensemble on training set instead of predicting test1')
186
+ p.add_argument('--save', action='store_true', help='With --eval, also save predictions to dataset/train/predictions/')
187
+ p.add_argument('--folds', type=int, default=5)
188
+ p.add_argument('--seed', type=int, default=42)
189
+ args = p.parse_args()
190
+ device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
191
+ ckpt_specs = parse_specs(args.ckpt_dirs)
192
+ if args.eval:
193
+ eval_oof_ensemble(ckpt_specs, args.folds, args.seed, device, save=args.save)
194
+ else:
195
+ models = load_models(ckpt_specs, device)
196
+ if not models:
197
+ print('No models found.')
198
+ return
199
+ print(f'Total models in ensemble: {len(models)}')
200
+ predict_test1(models, device)
201
+ if __name__ == '__main__':
202
+ main()