Enorenio commited on
Commit
38429f5
·
verified ·
1 Parent(s): 8c25d86

Add training/model definitions (comments stripped)

Browse files
Files changed (1) hide show
  1. train_global_unet.py +505 -0
train_global_unet.py ADDED
@@ -0,0 +1,505 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import json
4
+ import math
5
+ import random
6
+ import argparse
7
+ from pathlib import Path
8
+ import numpy as np
9
+ from PIL import Image
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ from torch.utils.data import Dataset, DataLoader
14
+ DATA_ROOT = Path('dataset')
15
+ TRAIN_IMG = DATA_ROOT / 'train' / 'images'
16
+ TRAIN_SC = DATA_ROOT / 'train' / 'scribbles'
17
+ TRAIN_GT = DATA_ROOT / 'train' / 'ground_truth'
18
+ TEST_IMG = DATA_ROOT / 'test1' / 'images'
19
+ TEST_SC = DATA_ROOT / 'test1' / 'scribbles'
20
+ TEST_PRED = DATA_ROOT / 'test1' / 'predictions'
21
+ TRAIN_H = int(os.environ.get('TRAIN_H', '384'))
22
+ TRAIN_W = int(os.environ.get('TRAIN_W', '512'))
23
+ ORIG_H, ORIG_W = (375, 500)
24
+ CKPT_DIR = Path(os.environ.get('CKPT_DIR', 'runs_global_unet'))
25
+ CKPT_DIR.mkdir(exist_ok=True)
26
+
27
+ def list_train_pairs():
28
+ pairs = []
29
+ for img_path in sorted(TRAIN_IMG.iterdir()):
30
+ if img_path.name.startswith('.'):
31
+ continue
32
+ stem = img_path.stem
33
+ sc_path = TRAIN_SC / f'{stem}.png'
34
+ gt_path = TRAIN_GT / f'{stem}.png'
35
+ if sc_path.exists() and gt_path.exists():
36
+ pairs.append((stem, img_path, sc_path, gt_path))
37
+ return pairs
38
+
39
+ def list_test_pairs():
40
+ pairs = []
41
+ for img_path in sorted(TEST_IMG.iterdir()):
42
+ if img_path.name.startswith('.'):
43
+ continue
44
+ stem = img_path.stem
45
+ sc_path = TEST_SC / f'{stem}.png'
46
+ if sc_path.exists():
47
+ pairs.append((stem, img_path, sc_path))
48
+ return pairs
49
+
50
+ def list_pseudo_pairs(pseudo_label_method='v3v4'):
51
+ pairs = []
52
+ for setname in ['test1', 'test2']:
53
+ img_dir = Path(f'dataset/{setname}/images')
54
+ sc_dir = Path(f'dataset/{setname}/scribbles')
55
+ gt_dir = Path(f'dataset/{setname}/predictions_{pseudo_label_method}')
56
+ if not gt_dir.exists():
57
+ continue
58
+ for ip in sorted(img_dir.iterdir()):
59
+ if ip.name.startswith('.'):
60
+ continue
61
+ stem = ip.stem
62
+ sp = sc_dir / f'{stem}.png'
63
+ gp = gt_dir / f'{stem}.png'
64
+ if sp.exists() and gp.exists():
65
+ pairs.append((stem, ip, sp, gp))
66
+ return pairs
67
+
68
+ def load_palette():
69
+ any_gt = next(TRAIN_GT.glob('*.png'))
70
+ return Image.open(any_gt).getpalette()
71
+
72
+ def encode_scribble(sc):
73
+ bg_ch = (sc == 0).astype(np.float32)
74
+ fg_ch = (sc == 1).astype(np.float32)
75
+ return np.stack([bg_ch, fg_ch], axis=0)
76
+
77
+ def random_affine(img, sc, gt, rng):
78
+ H, W = img.shape[:2]
79
+ angle = rng.uniform(-12, 12)
80
+ scale = rng.uniform(0.85, 1.2)
81
+ tx = rng.uniform(-0.05, 0.05) * W
82
+ ty = rng.uniform(-0.05, 0.05) * H
83
+ cx, cy = (W / 2, H / 2)
84
+ a = math.radians(angle)
85
+ cos_a, sin_a = (math.cos(a) * scale, math.sin(a) * scale)
86
+ M = np.array([[cos_a, -sin_a, (1 - cos_a) * cx + sin_a * cy + tx], [sin_a, cos_a, (1 - cos_a) * cy - sin_a * cx + ty]], dtype=np.float32)
87
+ import cv2
88
+ img_a = cv2.warpAffine(img, M, (W, H), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT_101)
89
+ sc_a = cv2.warpAffine(sc, M, (W, H), flags=cv2.INTER_NEAREST, borderMode=cv2.BORDER_CONSTANT, borderValue=255)
90
+ gt_a = cv2.warpAffine(gt, M, (W, H), flags=cv2.INTER_NEAREST, borderMode=cv2.BORDER_CONSTANT, borderValue=0)
91
+ return (img_a, sc_a, gt_a)
92
+
93
+ def color_jitter(img, rng):
94
+ img_f = img.astype(np.float32) / 255.0
95
+ img_f = img_f * rng.uniform(0.8, 1.2)
96
+ mean = img_f.mean(axis=(0, 1), keepdims=True)
97
+ img_f = (img_f - mean) * rng.uniform(0.8, 1.2) + mean
98
+ if rng.random() < 0.7:
99
+ gray = img_f.mean(axis=2, keepdims=True)
100
+ img_f = img_f * rng.uniform(0.7, 1.3) + gray * (1 - rng.uniform(0.7, 1.3))
101
+ img_f = np.clip(img_f, 0, 1)
102
+ return (img_f * 255).astype(np.uint8)
103
+
104
+ class ScribbleSegDataset(Dataset):
105
+
106
+ def __init__(self, pairs, train=True, image_size=(TRAIN_H, TRAIN_W), cutmix_p=0.0):
107
+ self.pairs = pairs
108
+ self.train = train
109
+ self.H, self.W = image_size
110
+ self.cutmix_p = cutmix_p
111
+ self.mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
112
+ self.std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
113
+
114
+ def __len__(self):
115
+ return len(self.pairs)
116
+
117
+ def _load_one(self, idx):
118
+ import cv2
119
+ stem, img_p, sc_p, gt_p = self.pairs[idx]
120
+ img = np.array(Image.open(img_p).convert('RGB'))
121
+ sc = np.array(Image.open(sc_p).convert('L'))
122
+ gt = np.array(Image.open(gt_p))
123
+ if img.shape[:2] != (self.H, self.W):
124
+ img = cv2.resize(img, (self.W, self.H), interpolation=cv2.INTER_LINEAR)
125
+ sc = cv2.resize(sc, (self.W, self.H), interpolation=cv2.INTER_NEAREST)
126
+ gt = cv2.resize(gt, (self.W, self.H), interpolation=cv2.INTER_NEAREST)
127
+ return (stem, img, sc, gt)
128
+
129
+ def __getitem__(self, idx):
130
+ stem, img, sc, gt = self._load_one(idx)
131
+ rng = random.Random()
132
+ if self.train:
133
+ if rng.random() < 0.5:
134
+ img = img[:, ::-1, :].copy()
135
+ sc = sc[:, ::-1].copy()
136
+ gt = gt[:, ::-1].copy()
137
+ img, sc, gt = random_affine(img, sc, gt, rng)
138
+ img = color_jitter(img, rng)
139
+ if rng.random() < 0.3:
140
+ drop_mask = (sc != 255) & (np.random.rand(*sc.shape) < 0.3)
141
+ sc = sc.copy()
142
+ sc[drop_mask] = 255
143
+ if self.cutmix_p > 0 and rng.random() < self.cutmix_p:
144
+ j = rng.randint(0, len(self.pairs) - 1)
145
+ _, img2, sc2, gt2 = self._load_one(j)
146
+ rh = rng.randint(int(0.3 * self.H), int(0.6 * self.H))
147
+ rw = rng.randint(int(0.3 * self.W), int(0.6 * self.W))
148
+ ry = rng.randint(0, self.H - rh)
149
+ rx = rng.randint(0, self.W - rw)
150
+ img = img.copy()
151
+ sc = sc.copy()
152
+ gt = gt.copy()
153
+ img[ry:ry + rh, rx:rx + rw] = img2[ry:ry + rh, rx:rx + rw]
154
+ sc[ry:ry + rh, rx:rx + rw] = sc2[ry:ry + rh, rx:rx + rw]
155
+ gt[ry:ry + rh, rx:rx + rw] = gt2[ry:ry + rh, rx:rx + rw]
156
+ img_f = img.astype(np.float32) / 255.0
157
+ img_f = (img_f - self.mean) / self.std
158
+ img_t = torch.from_numpy(img_f.transpose(2, 0, 1))
159
+ sc_enc = encode_scribble(sc)
160
+ sc_t = torch.from_numpy(sc_enc)
161
+ x = torch.cat([img_t, sc_t], dim=0)
162
+ gt_bin = (gt > 0).astype(np.float32)
163
+ y = torch.from_numpy(gt_bin)
164
+ return (x, y, stem)
165
+
166
+ class ConvBlock(nn.Module):
167
+
168
+ def __init__(self, in_ch, out_ch):
169
+ super().__init__()
170
+ self.block = nn.Sequential(nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True), nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True))
171
+
172
+ def forward(self, x):
173
+ return self.block(x)
174
+
175
+ class UNet(nn.Module):
176
+
177
+ def __init__(self, in_ch=5, base=48, out_ch=1):
178
+ super().__init__()
179
+ c1, c2, c3, c4, c5 = (base, base * 2, base * 4, base * 8, base * 16)
180
+ self.enc1 = ConvBlock(in_ch, c1)
181
+ self.enc2 = ConvBlock(c1, c2)
182
+ self.enc3 = ConvBlock(c2, c3)
183
+ self.enc4 = ConvBlock(c3, c4)
184
+ self.bottleneck = ConvBlock(c4, c5)
185
+ self.pool = nn.MaxPool2d(2)
186
+ self.up4 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
187
+ self.dec4 = ConvBlock(c5 + c4, c4)
188
+ self.up3 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
189
+ self.dec3 = ConvBlock(c4 + c3, c3)
190
+ self.up2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
191
+ self.dec2 = ConvBlock(c3 + c2, c2)
192
+ self.up1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
193
+ self.dec1 = ConvBlock(c2 + c1, c1)
194
+ self.head = nn.Conv2d(c1, out_ch, 1)
195
+ self._init_weights()
196
+
197
+ def _init_weights(self):
198
+ for m in self.modules():
199
+ if isinstance(m, nn.Conv2d):
200
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
201
+ if m.bias is not None:
202
+ nn.init.zeros_(m.bias)
203
+ elif isinstance(m, nn.BatchNorm2d):
204
+ nn.init.ones_(m.weight)
205
+ nn.init.zeros_(m.bias)
206
+
207
+ def forward(self, x):
208
+ e1 = self.enc1(x)
209
+ e2 = self.enc2(self.pool(e1))
210
+ e3 = self.enc3(self.pool(e2))
211
+ e4 = self.enc4(self.pool(e3))
212
+ b = self.bottleneck(self.pool(e4))
213
+ d4 = self.dec4(torch.cat([self.up4(b), e4], 1))
214
+ d3 = self.dec3(torch.cat([self.up3(d4), e3], 1))
215
+ d2 = self.dec2(torch.cat([self.up2(d3), e2], 1))
216
+ d1 = self.dec1(torch.cat([self.up1(d2), e1], 1))
217
+ return self.head(d1)
218
+
219
+ def soft_dice_loss(logits, target, eps=1e-06):
220
+ p = torch.sigmoid(logits).squeeze(1)
221
+ inter = (p * target).sum(dim=(1, 2))
222
+ denom = p.sum(dim=(1, 2)) + target.sum(dim=(1, 2))
223
+ dice = (2 * inter + eps) / (denom + eps)
224
+ return 1 - dice.mean()
225
+
226
+ def combined_loss(logits, target):
227
+ bce = F.binary_cross_entropy_with_logits(logits.squeeze(1), target)
228
+ dice = soft_dice_loss(logits, target)
229
+ return 0.5 * bce + 0.5 * dice
230
+
231
+ def compute_iou(pred_bin, gt_bin, cls):
232
+ p = pred_bin == cls
233
+ g = gt_bin == cls
234
+ inter = np.logical_and(p, g).sum()
235
+ union = np.logical_or(p, g).sum()
236
+ return inter / union if union > 0 else 0.0
237
+
238
+ def evaluate_predictions(preds, gts):
239
+ bg, fg = ([], [])
240
+ for p, g in zip(preds, gts):
241
+ bg.append(compute_iou(p, g, 0))
242
+ fg.append(compute_iou(p, g, 1))
243
+ bg = np.mean(bg)
244
+ fg = np.mean(fg)
245
+ return (bg, fg, (bg + fg) / 2)
246
+
247
+ def train_one_fold(train_pairs, val_pairs, epochs, batch_size, lr, fold_id, device, base=48, cutmix_p=0.0):
248
+ train_ds = ScribbleSegDataset(train_pairs, train=True, cutmix_p=cutmix_p)
249
+ val_ds = ScribbleSegDataset(val_pairs, train=False)
250
+ train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True, drop_last=True)
251
+ val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)
252
+ model = UNet(in_ch=5, base=base, out_ch=1).to(device)
253
+ n_params = sum((p.numel() for p in model.parameters()))
254
+ print(f'[fold {fold_id}] U-Net params: {n_params / 1000000.0:.2f}M (base={base}), train={len(train_ds)}, val={len(val_ds)}')
255
+ opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.0001)
256
+ sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs, eta_min=lr / 30)
257
+ scaler = torch.amp.GradScaler('cuda')
258
+ best_miou = -1.0
259
+ best_state = None
260
+ log = []
261
+ patience = 25
262
+ bad_epochs = 0
263
+ for epoch in range(epochs):
264
+ model.train()
265
+ train_loss = 0.0
266
+ n = 0
267
+ for x, y, _ in train_dl:
268
+ x, y = (x.to(device, non_blocking=True), y.to(device, non_blocking=True))
269
+ opt.zero_grad(set_to_none=True)
270
+ with torch.amp.autocast('cuda', dtype=torch.float16):
271
+ logits = model(x)
272
+ loss = combined_loss(logits, y)
273
+ scaler.scale(loss).backward()
274
+ scaler.step(opt)
275
+ scaler.update()
276
+ train_loss += loss.item() * x.size(0)
277
+ n += x.size(0)
278
+ train_loss /= n
279
+ sched.step()
280
+ model.eval()
281
+ all_p, all_g = ([], [])
282
+ with torch.no_grad():
283
+ for x, y, _ in val_dl:
284
+ x = x.to(device, non_blocking=True)
285
+ with torch.amp.autocast('cuda', dtype=torch.float16):
286
+ logits = model(x)
287
+ p = (torch.sigmoid(logits).squeeze(1).float().cpu().numpy() > 0.5).astype(np.uint8)
288
+ g = y.numpy().astype(np.uint8)
289
+ for i in range(p.shape[0]):
290
+ all_p.append(p[i])
291
+ all_g.append(g[i])
292
+ bg, fg, miou = evaluate_predictions(all_p, all_g)
293
+ log.append({'epoch': epoch, 'loss': train_loss, 'val_bg': bg, 'val_fg': fg, 'val_miou': miou})
294
+ print(f'[fold {fold_id} ep {epoch:03d}] loss={train_loss:.4f} val: bg={bg:.4f} fg={fg:.4f} mIoU={miou:.4f} lr={sched.get_last_lr()[0]:.2e}')
295
+ if miou > best_miou:
296
+ best_miou = miou
297
+ best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
298
+ bad_epochs = 0
299
+ else:
300
+ bad_epochs += 1
301
+ if bad_epochs >= patience:
302
+ print(f'[fold {fold_id}] early stopping at epoch {epoch} (best mIoU={best_miou:.4f})')
303
+ break
304
+ fold_dir = CKPT_DIR / f'fold_{fold_id}'
305
+ fold_dir.mkdir(exist_ok=True)
306
+ torch.save(best_state, fold_dir / 'best.pth')
307
+ with open(fold_dir / 'log.json', 'w') as f:
308
+ json.dump(log, f, indent=2)
309
+ return best_miou
310
+
311
+ def cmd_train(args):
312
+ set_seed(args.seed)
313
+ device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
314
+ print(f'Device: {device}')
315
+ pairs = list_train_pairs()
316
+ print(f'Training pairs: {len(pairs)}')
317
+ pseudo_pairs = []
318
+ if getattr(args, 'pseudo_method', None):
319
+ pseudo_pairs = list_pseudo_pairs(args.pseudo_method)
320
+ print(f'Pseudo-labeled pairs ({args.pseudo_method}): {len(pseudo_pairs)}')
321
+ rng = np.random.RandomState(args.seed)
322
+ indices = np.arange(len(pairs))
323
+ rng.shuffle(indices)
324
+ if args.folds == 1:
325
+ n_val = max(1, len(pairs) // 5)
326
+ splits = [(indices[n_val:], indices[:n_val])]
327
+ else:
328
+ fold_arr = np.array_split(indices, args.folds)
329
+ splits = []
330
+ for k in range(args.folds):
331
+ val_idx = fold_arr[k]
332
+ train_idx = np.concatenate([fold_arr[i] for i in range(args.folds) if i != k])
333
+ splits.append((train_idx, val_idx))
334
+ fold_mious = []
335
+ for k, (train_idx, val_idx) in enumerate(splits):
336
+ train_pairs = [pairs[i] for i in train_idx]
337
+ if pseudo_pairs:
338
+ train_pairs = train_pairs + pseudo_pairs
339
+ val_pairs = [pairs[i] for i in val_idx]
340
+ print(f'\n=== Fold {k + 1}/{len(splits)}: train={len(train_pairs)} ({len(train_idx)} real + {len(pseudo_pairs)} pseudo), val={len(val_pairs)} ===')
341
+ miou = train_one_fold(train_pairs, val_pairs, epochs=args.epochs, batch_size=args.batch_size, lr=args.lr, fold_id=k, device=device, base=args.base, cutmix_p=args.cutmix_p)
342
+ fold_mious.append(miou)
343
+ print('\n=== Cross-validation summary ===')
344
+ for k, m in enumerate(fold_mious):
345
+ print(f' fold {k}: {m:.4f}')
346
+ print(f' mean: {np.mean(fold_mious):.4f} (+/- {np.std(fold_mious):.4f})')
347
+
348
+ def tta_predict(model, x, device, scales=(1.0,)):
349
+ model.eval()
350
+ H, W = (x.shape[-2], x.shape[-1])
351
+ probs = []
352
+ with torch.no_grad(), torch.amp.autocast('cuda', dtype=torch.float16):
353
+ for s in scales:
354
+ if s == 1.0:
355
+ xs = x
356
+ else:
357
+ new_h = int(round(H * s / 32) * 32)
358
+ new_w = int(round(W * s / 32) * 32)
359
+ rgb = F.interpolate(x[:, :3], size=(new_h, new_w), mode='bilinear', align_corners=False)
360
+ sc = F.interpolate(x[:, 3:], size=(new_h, new_w), mode='nearest')
361
+ xs = torch.cat([rgb, sc], dim=1)
362
+ p1 = torch.sigmoid(model(xs))
363
+ p2 = torch.sigmoid(model(torch.flip(xs, dims=[3])))
364
+ p2 = torch.flip(p2, dims=[3])
365
+ p = (p1 + p2) / 2
366
+ if p.shape[-2:] != (H, W):
367
+ p = F.interpolate(p, size=(H, W), mode='bilinear', align_corners=False)
368
+ probs.append(p)
369
+ return (sum(probs) / len(probs)).squeeze().float().cpu().numpy()
370
+
371
+ def cmd_predict(args):
372
+ import cv2
373
+ device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
374
+ fold_dirs = sorted(CKPT_DIR.glob('fold_*'))
375
+ fold_dirs = [f for f in fold_dirs if (f / 'best.pth').exists()]
376
+ if not fold_dirs:
377
+ print('No trained models found.')
378
+ sys.exit(1)
379
+ print(f'Ensembling {len(fold_dirs)} folds.')
380
+ models = []
381
+ for fd in fold_dirs:
382
+ m = UNet(in_ch=5, base=args.base, out_ch=1).to(device)
383
+ m.load_state_dict(torch.load(fd / 'best.pth', map_location=device))
384
+ m.eval()
385
+ models.append(m)
386
+ palette = load_palette()
387
+ test_pairs = list_test_pairs()
388
+ TEST_PRED.mkdir(parents=True, exist_ok=True)
389
+ mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
390
+ std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
391
+ for stem, img_p, sc_p in test_pairs:
392
+ img = np.array(Image.open(img_p).convert('RGB'))
393
+ sc = np.array(Image.open(sc_p).convert('L'))
394
+ img_r = cv2.resize(img, (TRAIN_W, TRAIN_H), interpolation=cv2.INTER_LINEAR)
395
+ sc_r = cv2.resize(sc, (TRAIN_W, TRAIN_H), interpolation=cv2.INTER_NEAREST)
396
+ img_f = (img_r.astype(np.float32) / 255.0 - mean) / std
397
+ img_t = torch.from_numpy(img_f.transpose(2, 0, 1))
398
+ sc_t = torch.from_numpy(encode_scribble(sc_r))
399
+ x = torch.cat([img_t, sc_t], dim=0).unsqueeze(0).to(device)
400
+ prob_sum = None
401
+ for m in models:
402
+ p = tta_predict(m, x, device, scales=(0.7, 1.0, 1.3))
403
+ prob_sum = p if prob_sum is None else prob_sum + p
404
+ prob = prob_sum / len(models)
405
+ prob_full = cv2.resize(prob, (ORIG_W, ORIG_H), interpolation=cv2.INTER_LINEAR)
406
+ pred = (prob_full > 0.5).astype(np.uint8)
407
+ pred_snap = pred.copy()
408
+ pred_snap[sc == 0] = 0
409
+ pred_snap[sc == 1] = 1
410
+ out_img = Image.fromarray(pred_snap.astype(np.uint8), mode='P')
411
+ out_img.putpalette(palette)
412
+ out_img.save(TEST_PRED / f'{stem}.png')
413
+ print(f'Wrote {len(test_pairs)} predictions to {TEST_PRED}')
414
+
415
+ def cmd_eval_train(args):
416
+ import cv2
417
+ device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
418
+ pairs = list_train_pairs()
419
+ rng = np.random.RandomState(args.seed)
420
+ indices = np.arange(len(pairs))
421
+ rng.shuffle(indices)
422
+ folds = np.array_split(indices, args.folds)
423
+ mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
424
+ std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
425
+ train_pred_dir = Path('dataset/train/predictions')
426
+ if args.save:
427
+ train_pred_dir.mkdir(exist_ok=True)
428
+ palette = load_palette()
429
+ all_p, all_g = ([], [])
430
+ for k in range(args.folds):
431
+ ckpt = CKPT_DIR / f'fold_{k}' / 'best.pth'
432
+ if not ckpt.exists():
433
+ print(f'skip fold {k} - no checkpoint')
434
+ continue
435
+ model = UNet(in_ch=5, base=args.base, out_ch=1).to(device)
436
+ model.load_state_dict(torch.load(ckpt, map_location=device))
437
+ model.eval()
438
+ val_idx = folds[k]
439
+ for i in val_idx:
440
+ stem, img_p, sc_p, gt_p = pairs[i]
441
+ img = np.array(Image.open(img_p).convert('RGB'))
442
+ sc = np.array(Image.open(sc_p).convert('L'))
443
+ gt = (np.array(Image.open(gt_p)) > 0).astype(np.uint8)
444
+ img_r = cv2.resize(img, (TRAIN_W, TRAIN_H), interpolation=cv2.INTER_LINEAR)
445
+ sc_r = cv2.resize(sc, (TRAIN_W, TRAIN_H), interpolation=cv2.INTER_NEAREST)
446
+ img_f = (img_r.astype(np.float32) / 255.0 - mean) / std
447
+ x = torch.cat([torch.from_numpy(img_f.transpose(2, 0, 1)), torch.from_numpy(encode_scribble(sc_r))], 0).unsqueeze(0).to(device)
448
+ prob = tta_predict(model, x, device, scales=(0.7, 1.0, 1.3))
449
+ prob_full = cv2.resize(prob, (ORIG_W, ORIG_H), interpolation=cv2.INTER_LINEAR)
450
+ pred = (prob_full > 0.5).astype(np.uint8)
451
+ pred[sc == 0] = 0
452
+ pred[sc == 1] = 1
453
+ all_p.append(pred)
454
+ all_g.append(gt)
455
+ if args.save:
456
+ out_img = Image.fromarray(pred.astype(np.uint8), mode='P')
457
+ out_img.putpalette(palette)
458
+ out_img.save(train_pred_dir / f'{stem}.png')
459
+ if args.folds == 1:
460
+ break
461
+ bg, fg, miou = evaluate_predictions(all_p, all_g)
462
+ print(f'Held-out CV: bg={bg:.4f} fg={fg:.4f} mIoU={miou:.4f} (n={len(all_p)} images)')
463
+ if args.save:
464
+ print(f'Saved {len(all_p)} train predictions to {train_pred_dir}')
465
+
466
+ def set_seed(seed):
467
+ random.seed(seed)
468
+ np.random.seed(seed)
469
+ torch.manual_seed(seed)
470
+ torch.cuda.manual_seed_all(seed)
471
+
472
+ def main():
473
+ p = argparse.ArgumentParser()
474
+ sub = p.add_subparsers(dest='cmd')
475
+ pt = sub.add_parser('train')
476
+ pt.add_argument('--epochs', type=int, default=120)
477
+ pt.add_argument('--batch-size', type=int, default=8)
478
+ pt.add_argument('--lr', type=float, default=0.001)
479
+ pt.add_argument('--folds', type=int, default=1)
480
+ pt.add_argument('--seed', type=int, default=42)
481
+ pt.add_argument('--gpu', type=int, default=0)
482
+ pt.add_argument('--base', type=int, default=48, help='U-Net base channel count')
483
+ pt.add_argument('--ckpt-suffix', type=str, default='', help='Suffix for runs_global_unet dir')
484
+ pt.add_argument('--cutmix-p', type=float, default=0.0, help='Probability of CutMix per sample')
485
+ pt.add_argument('--pseudo-method', type=str, default='', help="If set (e.g. 'v3v4'), use that method's predictions on test1+test2 as additional pseudo-labeled training data.")
486
+ pp = sub.add_parser('predict')
487
+ pp.add_argument('--gpu', type=int, default=0)
488
+ pp.add_argument('--base', type=int, default=48)
489
+ pe = sub.add_parser('eval')
490
+ pe.add_argument('--folds', type=int, default=1)
491
+ pe.add_argument('--seed', type=int, default=42)
492
+ pe.add_argument('--gpu', type=int, default=0)
493
+ pe.add_argument('--base', type=int, default=48)
494
+ pe.add_argument('--save', action='store_true', help='Save out-of-fold predictions to dataset/train/predictions/')
495
+ args = p.parse_args()
496
+ if args.cmd == 'train':
497
+ cmd_train(args)
498
+ elif args.cmd == 'predict':
499
+ cmd_predict(args)
500
+ elif args.cmd == 'eval':
501
+ cmd_eval_train(args)
502
+ else:
503
+ p.print_help()
504
+ if __name__ == '__main__':
505
+ main()