(optim) prefetch and speed up eval.
Browse files- SEGMENTATION_PLAN.md +1 -1
- configs/default.yaml +3 -3
- train.py +111 -53
SEGMENTATION_PLAN.md
CHANGED
|
@@ -108,7 +108,7 @@ Focus: segmentation only (no dataset collection or inpainting).
|
|
| 108 |
## Configuration Surface (key)
|
| 109 |
- Backbone/weights: `mit_b2` (pretrained ImageNet-1K).
|
| 110 |
- Sizes: `p=768`, `coarse_train=512`, `coarse_test=1024`, `overlap=128`.
|
| 111 |
-
- Conditioning: `
|
| 112 |
- MinMax: `enable=true`, `kernel=6`.
|
| 113 |
- Label: `coarse_label_downsample='maxpool'`.
|
| 114 |
- Training: `iters=40000`, `batch=8`, `lr=6e-5`, `wd=0.01`, `schedule='poly'`, `power=1.0`.
|
|
|
|
| 108 |
## Configuration Surface (key)
|
| 109 |
- Backbone/weights: `mit_b2` (pretrained ImageNet-1K).
|
| 110 |
- Sizes: `p=768`, `coarse_train=512`, `coarse_test=1024`, `overlap=128`.
|
| 111 |
+
- Conditioning: `cond_from='coarse_logits_1x1'`, `cond_crop='patch'`.
|
| 112 |
- MinMax: `enable=true`, `kernel=6`.
|
| 113 |
- Label: `coarse_label_downsample='maxpool'`.
|
| 114 |
- Training: `iters=40000`, `batch=8`, `lr=6e-5`, `wd=0.01`, `schedule='poly'`, `power=1.0`.
|
configs/default.yaml
CHANGED
|
@@ -37,9 +37,9 @@ optim:
|
|
| 37 |
# training housekeeping
|
| 38 |
seed: 42
|
| 39 |
out_dir: runs/wireseghr
|
| 40 |
-
eval_interval:
|
| 41 |
-
ckpt_interval:
|
| 42 |
-
resume: runs/wireseghr/ckpt_1500.pt # optional
|
| 43 |
|
| 44 |
# dataset paths (placeholders)
|
| 45 |
data:
|
|
|
|
| 37 |
# training housekeeping
|
| 38 |
seed: 42
|
| 39 |
out_dir: runs/wireseghr
|
| 40 |
+
eval_interval: 100
|
| 41 |
+
ckpt_interval: 300
|
| 42 |
+
# resume: runs/wireseghr/ckpt_1500.pt # optional
|
| 43 |
|
| 44 |
# dataset paths (placeholders)
|
| 45 |
data:
|
train.py
CHANGED
|
@@ -14,6 +14,7 @@ from tqdm import tqdm
|
|
| 14 |
import random
|
| 15 |
import torch.backends.cudnn as cudnn
|
| 16 |
import cv2
|
|
|
|
| 17 |
|
| 18 |
from src.wireseghr.model import WireSegHR
|
| 19 |
from src.wireseghr.model.minmax import MinMaxLuminance
|
|
@@ -23,6 +24,46 @@ from src.wireseghr.data.sampler import BalancedPatchSampler
|
|
| 23 |
from src.wireseghr.metrics import compute_metrics
|
| 24 |
|
| 25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
def main():
|
| 27 |
parser = argparse.ArgumentParser(description="WireSegHR training (skeleton)")
|
| 28 |
parser.add_argument(
|
|
@@ -66,6 +107,23 @@ def main():
|
|
| 66 |
train_images = cfg["data"]["train_images"]
|
| 67 |
train_masks = cfg["data"]["train_masks"]
|
| 68 |
dset = WireSegDataset(train_images, train_masks, split="train")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
# Validation and test
|
| 70 |
val_images = cfg["data"].get("val_images", None)
|
| 71 |
val_masks = cfg["data"].get("val_masks", None)
|
|
@@ -120,9 +178,14 @@ def main():
|
|
| 120 |
model.train()
|
| 121 |
step = start_step
|
| 122 |
pbar = tqdm(total=iters - step, initial=0, desc="Train", ncols=100)
|
|
|
|
| 123 |
while step < iters:
|
| 124 |
optim.zero_grad(set_to_none=True)
|
| 125 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
batch = _prepare_batch(
|
| 127 |
imgs, masks, coarse_train, patch_size, sampler, minmax, device
|
| 128 |
)
|
|
@@ -169,6 +232,9 @@ def main():
|
|
| 169 |
|
| 170 |
# Eval & Checkpoint
|
| 171 |
if (step % eval_interval == 0) and (dset_val is not None):
|
|
|
|
|
|
|
|
|
|
| 172 |
model.eval()
|
| 173 |
val_stats = validate(
|
| 174 |
model,
|
|
@@ -181,6 +247,7 @@ def main():
|
|
| 181 |
mm_kernel,
|
| 182 |
patch_size,
|
| 183 |
overlap,
|
|
|
|
| 184 |
)
|
| 185 |
print(
|
| 186 |
f"[Val @ {step}][Fine] IoU={val_stats['iou']:.4f} F1={val_stats['f1']:.4f} P={val_stats['precision']:.4f} R={val_stats['recall']:.4f}"
|
|
@@ -323,8 +390,8 @@ def _prepare_batch(
|
|
| 323 |
)[0]
|
| 324 |
zeros_coarse = torch.zeros(1, coarse_train, coarse_train, device=device)
|
| 325 |
c_t = torch.cat(
|
| 326 |
-
[rgb_coarse_t, y_min_c_t, y_max_c_t, zeros_coarse
|
| 327 |
-
) #
|
| 328 |
xs_coarse.append(c_t)
|
| 329 |
|
| 330 |
# Sample fine patch (CPU mask), then slice GPU min/max and transfer only patches
|
|
@@ -347,8 +414,6 @@ def _prepare_batch(
|
|
| 347 |
)
|
| 348 |
patches_min.append(ymin_patch)
|
| 349 |
patches_max.append(ymax_patch)
|
| 350 |
-
# Binary location mask (ones inside the patch)
|
| 351 |
-
loc_masks.append(np.ones((patch_size, patch_size), dtype=np.float32))
|
| 352 |
yx_list.append((y0, x0))
|
| 353 |
|
| 354 |
x_coarse = torch.stack(xs_coarse, dim=0) # already on device
|
|
@@ -362,7 +427,6 @@ def _prepare_batch(
|
|
| 362 |
"mask_patches": patches_mask,
|
| 363 |
"ymin_patches": patches_min,
|
| 364 |
"ymax_patches": patches_max,
|
| 365 |
-
"loc_patches": loc_masks,
|
| 366 |
"patch_yx": yx_list,
|
| 367 |
"mask_full": masks,
|
| 368 |
}
|
|
@@ -371,9 +435,9 @@ def _prepare_batch(
|
|
| 371 |
def _build_fine_inputs(
|
| 372 |
batch, cond_map: torch.Tensor, device: torch.device
|
| 373 |
) -> torch.Tensor:
|
| 374 |
-
# Build fine input tensor
|
| 375 |
B = cond_map.shape[0]
|
| 376 |
-
P = batch["
|
| 377 |
full_h, full_w = batch["full_h"], batch["full_w"]
|
| 378 |
hc4, wc4 = cond_map.shape[2], cond_map.shape[3]
|
| 379 |
|
|
@@ -382,7 +446,6 @@ def _build_fine_inputs(
|
|
| 382 |
rgb = batch["rgb_patches"][i]
|
| 383 |
ymin = batch["ymin_patches"][i]
|
| 384 |
ymax = batch["ymax_patches"][i]
|
| 385 |
-
loc = batch["loc_patches"][i]
|
| 386 |
y0, x0 = batch["patch_yx"][i]
|
| 387 |
|
| 388 |
# Map full-res patch box to low-res cond grid, crop and upsample to P
|
|
@@ -402,8 +465,7 @@ def _build_fine_inputs(
|
|
| 402 |
) # 3xPxP
|
| 403 |
ymin_t = torch.from_numpy(ymin)[None, ...].to(device).float() # 1xPxP
|
| 404 |
ymax_t = torch.from_numpy(ymax)[None, ...].to(device).float() # 1xPxP
|
| 405 |
-
|
| 406 |
-
x = torch.cat([rgb_t, ymin_t, ymax_t, cond_patch, loc_t], dim=0) # 7xPxP
|
| 407 |
xs.append(x)
|
| 408 |
x_fine = torch.stack(xs, dim=0)
|
| 409 |
return x_fine
|
|
@@ -492,6 +554,7 @@ def validate(
|
|
| 492 |
minmax_kernel: int,
|
| 493 |
fine_patch_size: int,
|
| 494 |
fine_overlap: int,
|
|
|
|
| 495 |
) -> Dict[str, float]:
|
| 496 |
# Coarse-only validation: resize image to coarse_size, predict coarse logits, upsample to full and compute metrics
|
| 497 |
model = model.to(device)
|
|
@@ -558,16 +621,16 @@ def validate(
|
|
| 558 |
for k in coarse_sum:
|
| 559 |
coarse_sum[k] += m_c[k]
|
| 560 |
|
| 561 |
-
# Fine-stage tiled inference and stitching
|
| 562 |
P = fine_patch_size
|
| 563 |
stride = P - fine_overlap
|
| 564 |
assert stride > 0
|
| 565 |
assert H >= P and W >= P
|
| 566 |
-
|
| 567 |
-
|
|
|
|
| 568 |
|
| 569 |
# Prepare min/max on full-res (already computed above as y_min_full/y_max_full)
|
| 570 |
-
# y_min_full, y_max_full exist in this scope from above branch
|
| 571 |
hc4, wc4 = cond_map.shape[2], cond_map.shape[3]
|
| 572 |
|
| 573 |
ys = list(range(0, max(H - P, 0) + 1, stride))
|
|
@@ -577,15 +640,16 @@ def validate(
|
|
| 577 |
if xs[-1] != (W - P):
|
| 578 |
xs.append(W - P)
|
| 579 |
|
|
|
|
| 580 |
for y0 in ys:
|
| 581 |
for x0 in xs:
|
| 582 |
-
|
| 583 |
-
# Build fine input 1x7xP x P
|
| 584 |
-
patch_rgb = img[y0:y1, x0:x1, :]
|
| 585 |
-
ymin_patch = y_min_full[0, 0, y0:y1, x0:x1].detach().cpu().numpy()
|
| 586 |
-
ymax_patch = y_max_full[0, 0, y0:y1, x0:x1].detach().cpu().numpy()
|
| 587 |
-
loc = np.ones((P, P), dtype=np.float32)
|
| 588 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 589 |
# Cond crop mapping (same as training _build_fine_inputs)
|
| 590 |
y0c = (y0 * hc4) // H
|
| 591 |
y1c = ((y1 * hc4) + H - 1) // H
|
|
@@ -596,37 +660,31 @@ def validate(
|
|
| 596 |
cond_sub, size=(P, P), mode="bilinear", align_corners=False
|
| 597 |
).squeeze(1) # 1xPxP
|
| 598 |
|
| 599 |
-
|
| 600 |
-
|
| 601 |
-
|
| 602 |
-
|
| 603 |
-
|
| 604 |
-
|
| 605 |
-
|
| 606 |
-
|
| 607 |
-
|
| 608 |
-
|
| 609 |
-
|
| 610 |
-
|
| 611 |
-
|
| 612 |
-
|
| 613 |
-
|
| 614 |
-
|
| 615 |
-
|
| 616 |
-
|
| 617 |
-
|
| 618 |
-
|
| 619 |
-
|
| 620 |
-
|
| 621 |
-
|
| 622 |
-
|
| 623 |
-
|
| 624 |
-
)
|
| 625 |
-
|
| 626 |
-
prob_sum[y0:y1, x0:x1] += prob_f_up
|
| 627 |
-
weight[y0:y1, x0:x1] += 1.0
|
| 628 |
-
|
| 629 |
-
prob_full = prob_sum / weight
|
| 630 |
pred_fine = (prob_full > prob_thresh).astype(np.uint8)
|
| 631 |
m_f = compute_metrics(pred_fine, mask)
|
| 632 |
for k in metrics_sum:
|
|
@@ -699,7 +757,7 @@ def save_test_visuals(
|
|
| 699 |
align_corners=False,
|
| 700 |
)[0]
|
| 701 |
zeros_c = torch.zeros(1, coarse_size, coarse_size, device=device)
|
| 702 |
-
x_t = torch.cat([rgb_c, y_min_c, y_max_c, zeros_c
|
| 703 |
with autocast(
|
| 704 |
device_type=device.type, enabled=(device.type == "cuda" and amp_flag)
|
| 705 |
):
|
|
|
|
| 14 |
import random
|
| 15 |
import torch.backends.cudnn as cudnn
|
| 16 |
import cv2
|
| 17 |
+
from torch.utils.data import DataLoader
|
| 18 |
|
| 19 |
from src.wireseghr.model import WireSegHR
|
| 20 |
from src.wireseghr.model.minmax import MinMaxLuminance
|
|
|
|
| 24 |
from src.wireseghr.metrics import compute_metrics
|
| 25 |
|
| 26 |
|
| 27 |
+
class SizeBatchSampler:
|
| 28 |
+
"""Batch sampler that groups indices by exact (H, W) so all samples in a batch share size.
|
| 29 |
+
|
| 30 |
+
This enables DataLoader prefetching while preserving the existing assumption
|
| 31 |
+
in `_prepare_batch()` that all items in a batch have the same full resolution.
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
def __init__(self, dset: WireSegDataset, batch_size: int):
|
| 35 |
+
self.dset = dset
|
| 36 |
+
self.batch_size = batch_size
|
| 37 |
+
# Precompute epoch length as the total number of full batches across bins
|
| 38 |
+
bins = self.dset.size_bins
|
| 39 |
+
self._len = 0
|
| 40 |
+
for hw, idxs in bins.items():
|
| 41 |
+
_ = hw # unused, clarity
|
| 42 |
+
self._len += (len(idxs) // self.batch_size)
|
| 43 |
+
|
| 44 |
+
def __len__(self) -> int:
|
| 45 |
+
return self._len
|
| 46 |
+
|
| 47 |
+
def __iter__(self):
|
| 48 |
+
# Create randomized batches per epoch across size bins
|
| 49 |
+
bins = self.dset.size_bins
|
| 50 |
+
keys = list(bins.keys())
|
| 51 |
+
random.shuffle(keys)
|
| 52 |
+
for hw in keys:
|
| 53 |
+
pool = list(bins[hw])
|
| 54 |
+
random.shuffle(pool)
|
| 55 |
+
# Yield only full batches to keep fixed batch size and same-size assumption
|
| 56 |
+
for i in range(0, len(pool) - (len(pool) % self.batch_size), self.batch_size):
|
| 57 |
+
yield pool[i : i + self.batch_size]
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def collate_train(batch: List[Dict]):
|
| 61 |
+
"""Collate function that returns lists of numpy arrays to match existing pipeline."""
|
| 62 |
+
imgs = [b["image"] for b in batch]
|
| 63 |
+
masks = [b["mask"] for b in batch]
|
| 64 |
+
return imgs, masks
|
| 65 |
+
|
| 66 |
+
|
| 67 |
def main():
|
| 68 |
parser = argparse.ArgumentParser(description="WireSegHR training (skeleton)")
|
| 69 |
parser.add_argument(
|
|
|
|
| 107 |
train_images = cfg["data"]["train_images"]
|
| 108 |
train_masks = cfg["data"]["train_masks"]
|
| 109 |
dset = WireSegDataset(train_images, train_masks, split="train")
|
| 110 |
+
# DataLoader with prefetching and size-aware batching
|
| 111 |
+
loader_cfg = cfg.get("loader", {})
|
| 112 |
+
num_workers = int(loader_cfg.get("num_workers", 4))
|
| 113 |
+
prefetch_factor = int(loader_cfg.get("prefetch_factor", 2))
|
| 114 |
+
pin_memory = bool(loader_cfg.get("pin_memory", True))
|
| 115 |
+
persistent_workers = bool(loader_cfg.get("persistent_workers", True)) if num_workers > 0 else False
|
| 116 |
+
batch_sampler = SizeBatchSampler(dset, batch_size)
|
| 117 |
+
loader_kwargs = dict(
|
| 118 |
+
batch_sampler=batch_sampler,
|
| 119 |
+
num_workers=num_workers,
|
| 120 |
+
pin_memory=pin_memory,
|
| 121 |
+
persistent_workers=persistent_workers,
|
| 122 |
+
collate_fn=collate_train,
|
| 123 |
+
)
|
| 124 |
+
if num_workers > 0:
|
| 125 |
+
loader_kwargs["prefetch_factor"] = prefetch_factor
|
| 126 |
+
train_loader = DataLoader(dset, **loader_kwargs)
|
| 127 |
# Validation and test
|
| 128 |
val_images = cfg["data"].get("val_images", None)
|
| 129 |
val_masks = cfg["data"].get("val_masks", None)
|
|
|
|
| 178 |
model.train()
|
| 179 |
step = start_step
|
| 180 |
pbar = tqdm(total=iters - step, initial=0, desc="Train", ncols=100)
|
| 181 |
+
data_iter = iter(train_loader)
|
| 182 |
while step < iters:
|
| 183 |
optim.zero_grad(set_to_none=True)
|
| 184 |
+
try:
|
| 185 |
+
imgs, masks = next(data_iter)
|
| 186 |
+
except StopIteration:
|
| 187 |
+
data_iter = iter(train_loader)
|
| 188 |
+
imgs, masks = next(data_iter)
|
| 189 |
batch = _prepare_batch(
|
| 190 |
imgs, masks, coarse_train, patch_size, sampler, minmax, device
|
| 191 |
)
|
|
|
|
| 232 |
|
| 233 |
# Eval & Checkpoint
|
| 234 |
if (step % eval_interval == 0) and (dset_val is not None):
|
| 235 |
+
# Free training-step tensors before eval to lower peak memory
|
| 236 |
+
del x_fine, logits_coarse, cond_map, logits_fine, y_coarse, y_fine, loss_coarse, loss_fine, loss
|
| 237 |
+
torch.cuda.empty_cache()
|
| 238 |
model.eval()
|
| 239 |
val_stats = validate(
|
| 240 |
model,
|
|
|
|
| 247 |
mm_kernel,
|
| 248 |
patch_size,
|
| 249 |
overlap,
|
| 250 |
+
batch_size,
|
| 251 |
)
|
| 252 |
print(
|
| 253 |
f"[Val @ {step}][Fine] IoU={val_stats['iou']:.4f} F1={val_stats['f1']:.4f} P={val_stats['precision']:.4f} R={val_stats['recall']:.4f}"
|
|
|
|
| 390 |
)[0]
|
| 391 |
zeros_coarse = torch.zeros(1, coarse_train, coarse_train, device=device)
|
| 392 |
c_t = torch.cat(
|
| 393 |
+
[rgb_coarse_t, y_min_c_t, y_max_c_t, zeros_coarse], dim=0
|
| 394 |
+
) # 6xHc x Wc
|
| 395 |
xs_coarse.append(c_t)
|
| 396 |
|
| 397 |
# Sample fine patch (CPU mask), then slice GPU min/max and transfer only patches
|
|
|
|
| 414 |
)
|
| 415 |
patches_min.append(ymin_patch)
|
| 416 |
patches_max.append(ymax_patch)
|
|
|
|
|
|
|
| 417 |
yx_list.append((y0, x0))
|
| 418 |
|
| 419 |
x_coarse = torch.stack(xs_coarse, dim=0) # already on device
|
|
|
|
| 427 |
"mask_patches": patches_mask,
|
| 428 |
"ymin_patches": patches_min,
|
| 429 |
"ymax_patches": patches_max,
|
|
|
|
| 430 |
"patch_yx": yx_list,
|
| 431 |
"mask_full": masks,
|
| 432 |
}
|
|
|
|
| 435 |
def _build_fine_inputs(
|
| 436 |
batch, cond_map: torch.Tensor, device: torch.device
|
| 437 |
) -> torch.Tensor:
|
| 438 |
+
# Build fine input tensor Bx6xP x P; crop cond from low-res map, upsample to P
|
| 439 |
B = cond_map.shape[0]
|
| 440 |
+
P = batch["rgb_patches"][0].shape[0]
|
| 441 |
full_h, full_w = batch["full_h"], batch["full_w"]
|
| 442 |
hc4, wc4 = cond_map.shape[2], cond_map.shape[3]
|
| 443 |
|
|
|
|
| 446 |
rgb = batch["rgb_patches"][i]
|
| 447 |
ymin = batch["ymin_patches"][i]
|
| 448 |
ymax = batch["ymax_patches"][i]
|
|
|
|
| 449 |
y0, x0 = batch["patch_yx"][i]
|
| 450 |
|
| 451 |
# Map full-res patch box to low-res cond grid, crop and upsample to P
|
|
|
|
| 465 |
) # 3xPxP
|
| 466 |
ymin_t = torch.from_numpy(ymin)[None, ...].to(device).float() # 1xPxP
|
| 467 |
ymax_t = torch.from_numpy(ymax)[None, ...].to(device).float() # 1xPxP
|
| 468 |
+
x = torch.cat([rgb_t, ymin_t, ymax_t, cond_patch], dim=0) # 6xPxP
|
|
|
|
| 469 |
xs.append(x)
|
| 470 |
x_fine = torch.stack(xs, dim=0)
|
| 471 |
return x_fine
|
|
|
|
| 554 |
minmax_kernel: int,
|
| 555 |
fine_patch_size: int,
|
| 556 |
fine_overlap: int,
|
| 557 |
+
fine_batch: int,
|
| 558 |
) -> Dict[str, float]:
|
| 559 |
# Coarse-only validation: resize image to coarse_size, predict coarse logits, upsample to full and compute metrics
|
| 560 |
model = model.to(device)
|
|
|
|
| 621 |
for k in coarse_sum:
|
| 622 |
coarse_sum[k] += m_c[k]
|
| 623 |
|
| 624 |
+
# Fine-stage tiled inference and stitching (BATCHED)
|
| 625 |
P = fine_patch_size
|
| 626 |
stride = P - fine_overlap
|
| 627 |
assert stride > 0
|
| 628 |
assert H >= P and W >= P
|
| 629 |
+
# Accumulate on device to avoid CPU<->GPU thrash
|
| 630 |
+
prob_sum_t = torch.zeros((H, W), device=device, dtype=torch.float32)
|
| 631 |
+
weight_t = torch.zeros((H, W), device=device, dtype=torch.float32)
|
| 632 |
|
| 633 |
# Prepare min/max on full-res (already computed above as y_min_full/y_max_full)
|
|
|
|
| 634 |
hc4, wc4 = cond_map.shape[2], cond_map.shape[3]
|
| 635 |
|
| 636 |
ys = list(range(0, max(H - P, 0) + 1, stride))
|
|
|
|
| 640 |
if xs[-1] != (W - P):
|
| 641 |
xs.append(W - P)
|
| 642 |
|
| 643 |
+
coords: List[Tuple[int, int]] = []
|
| 644 |
for y0 in ys:
|
| 645 |
for x0 in xs:
|
| 646 |
+
coords.append((y0, x0))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 647 |
|
| 648 |
+
for i0 in range(0, len(coords), fine_batch):
|
| 649 |
+
batch_coords = coords[i0 : i0 + fine_batch]
|
| 650 |
+
xs_list: List[torch.Tensor] = []
|
| 651 |
+
for (y0, x0) in batch_coords:
|
| 652 |
+
y1, x1 = y0 + P, x0 + P
|
| 653 |
# Cond crop mapping (same as training _build_fine_inputs)
|
| 654 |
y0c = (y0 * hc4) // H
|
| 655 |
y1c = ((y1 * hc4) + H - 1) // H
|
|
|
|
| 660 |
cond_sub, size=(P, P), mode="bilinear", align_corners=False
|
| 661 |
).squeeze(1) # 1xPxP
|
| 662 |
|
| 663 |
+
# Build fine input channels directly from on-device tensors
|
| 664 |
+
rgb_t = t_img[0, :, y0:y1, x0:x1] # 3xPxP
|
| 665 |
+
ymin_t = y_min_full[0, 0, y0:y1, x0:x1].float().unsqueeze(0) # 1xPxP
|
| 666 |
+
ymax_t = y_max_full[0, 0, y0:y1, x0:x1].float().unsqueeze(0) # 1xPxP
|
| 667 |
+
x_f = torch.cat([rgb_t, ymin_t, ymax_t, cond_patch], dim=0).unsqueeze(0)
|
| 668 |
+
xs_list.append(x_f)
|
| 669 |
+
|
| 670 |
+
x_f_batch = torch.cat(xs_list, dim=0) # Bx6xPxP
|
| 671 |
+
|
| 672 |
+
with autocast(
|
| 673 |
+
device_type=device.type,
|
| 674 |
+
enabled=(device.type == "cuda" and amp_flag),
|
| 675 |
+
):
|
| 676 |
+
logits_f = model.forward_fine(x_f_batch)
|
| 677 |
+
prob_f = torch.softmax(logits_f, dim=1)[:, 1:2]
|
| 678 |
+
prob_f_up = F.interpolate(
|
| 679 |
+
prob_f, size=(P, P), mode="bilinear", align_corners=False
|
| 680 |
+
)[:, 0, :, :] # BxPxP
|
| 681 |
+
|
| 682 |
+
for bi, (y0, x0) in enumerate(batch_coords):
|
| 683 |
+
y1, x1 = y0 + P, x0 + P
|
| 684 |
+
prob_sum_t[y0:y1, x0:x1] += prob_f_up[bi]
|
| 685 |
+
weight_t[y0:y1, x0:x1] += 1.0
|
| 686 |
+
|
| 687 |
+
prob_full = (prob_sum_t / weight_t).detach().cpu().numpy()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 688 |
pred_fine = (prob_full > prob_thresh).astype(np.uint8)
|
| 689 |
m_f = compute_metrics(pred_fine, mask)
|
| 690 |
for k in metrics_sum:
|
|
|
|
| 757 |
align_corners=False,
|
| 758 |
)[0]
|
| 759 |
zeros_c = torch.zeros(1, coarse_size, coarse_size, device=device)
|
| 760 |
+
x_t = torch.cat([rgb_c, y_min_c, y_max_c, zeros_c], dim=0).unsqueeze(0)
|
| 761 |
with autocast(
|
| 762 |
device_type=device.type, enabled=(device.type == "cuda" and amp_flag)
|
| 763 |
):
|