| | import math |
| | import torch |
| | import numpy as np |
| | from megatron import get_args |
| |
|
def slidingcrops(img, mask, *, crop_size=None, stride=None, ignore_index=None):
    """Cut an image batch and its mask into square sliding-window crops.

    The crop grid starts at the top-left corner and advances ``stride`` pixels
    along each axis. The number of steps is rounded up so the last row/column
    of crops reaches (or overhangs) the bottom/right edge. Overhanging crops
    are padded on the bottom/right so every crop is exactly
    ``crop_size`` x ``crop_size``:
      - image pixels are padded with 0;
      - mask pixels are padded with ``ignore_index``, so padding is never
        mistaken for a real class.

    Args:
        img: image tensor indexed as (n, c, h, w).
        mask: mask tensor indexed as (n, h, w). It is presumably aligned with
            ``img`` in n/h/w (not checked here). Its dtype must be able to
            hold ``ignore_index`` — TODO confirm against the dataset.
        crop_size: side length of the square crop. Defaults to
            ``args.img_h``, which must equal ``args.img_w``.
        stride: step between neighbouring crops. Defaults to
            ``args.seg_stride``. Must satisfy ``0 < stride <= crop_size``
            when more than one crop is needed.
        ignore_index: fill value for padded mask pixels. Defaults to
            ``args.ignore_index``.

    Returns:
        A tuple ``(img_crops, mask_crops, slices_info, (h, w))``.
          - Crops are concatenated along dim 0 in row-major grid order. Each
            group of ``n`` rows belongs to one grid cell.
          - ``slices_info`` holds one ``[sy, ey, sx, ex, sub_h, sub_w]`` entry
            per crop. ``sub_h`` / ``sub_w`` are the crop's extent before
            padding.
          - If the image already is ``crop_size`` x ``crop_size``, ``img`` and
            ``mask`` are returned unchanged with a single slice entry.

    Raises:
        ValueError: if crops are not square, the image is smaller than
            ``crop_size``, or ``stride`` is out of range.
    """
    # Only consult the global Megatron args for values the caller didn't give.
    if crop_size is None or stride is None or ignore_index is None:
        args = get_args()
        if crop_size is None:
            if args.img_h != args.img_w:
                raise ValueError("sliding crops require img_h == img_w")
            crop_size = args.img_h
        if stride is None:
            stride = args.seg_stride
        if ignore_index is None:
            ignore_index = args.ignore_index

    _, _, h, w = img.shape
    if h < crop_size or w < crop_size:
        raise ValueError(
            f"image size {(h, w)} is smaller than crop size {crop_size}")

    # h == w == crop_size: the image is a single crop already.
    if max(h, w) <= crop_size:
        return img, mask, [[0, h, 0, w, h, w]], (h, w)

    if not 0 < stride <= crop_size:
        raise ValueError(
            f"stride must be in (0, {crop_size}], got {stride}")

    # Round up so the final crop touches (or passes) the far edge.
    h_step_num = int(math.ceil((h - crop_size) / float(stride))) + 1
    w_step_num = int(math.ceil((w - crop_size) / float(stride))) + 1

    img_slices, mask_slices, slices_info = [], [], []
    for yy in range(h_step_num):
        for xx in range(w_step_num):
            sy, sx = yy * stride, xx * stride
            ey, ex = sy + crop_size, sx + crop_size
            img_sub = img[:, :, sy:ey, sx:ex]
            mask_sub = mask[:, sy:ey, sx:ex]

            # Crops in the last row/column may overhang the image; keep the
            # real extent so slidingjoins can discard the padding later.
            sub_h, sub_w = img_sub.shape[2:]
            pad = (0, crop_size - sub_w, 0, crop_size - sub_h)
            # Zero-pad the image; pad the mask with ignore_index so padded
            # pixels are excluded rather than labelled as class 0.
            img_sub = torch.nn.functional.pad(img_sub, pad=pad, value=0)
            mask_sub = torch.nn.functional.pad(
                mask_sub, pad=pad, value=ignore_index)

            img_slices.append(img_sub)
            mask_slices.append(mask_sub)
            slices_info.append([sy, ey, sx, ex, sub_h, sub_w])

    return torch.cat(img_slices), torch.cat(mask_slices), slices_info, (h, w)
| |
|
| |
|
def slidingjoins(preds, probs, labels, slices_info, img_size):
    """Stitch per-crop predictions back into full-image prediction maps.

    This is the counterpart of ``slidingcrops``. ``preds``, ``probs`` and
    ``labels`` are stacked along dim 0 in slice-major order: ``num_slices``
    consecutive groups of equal batch size, one group per
    ``slices_info`` entry.

    Where crops overlap, each pixel takes the prediction with the highest
    score. On ties, the prediction from the earlier crop is kept. Anything
    past a crop's recorded ``sub_h`` / ``sub_w`` (the padding) is discarded.

    Args:
        preds: per-crop class predictions, indexed as (N, crop_h, crop_w)
            with N = num_slices * batch.
        probs: per-crop scores with the same layout as ``preds``. Presumably
            the max softmax probability — confirm with the caller.
        labels: per-crop ground truth with the same layout as ``preds``.
        slices_info: list of ``[sy, ey, sx, ex, sub_h, sub_w]`` entries, as
            produced by ``slidingcrops``.
        img_size: ``(h, w)`` of the full image.

    Returns:
        ``(preds, labels)`` as ``torch.int`` tensors of shape
        ``(batch, h, w)``, allocated on ``preds.device``. With a single slice
        the inputs are returned unchanged.

    Raises:
        ValueError: if ``slices_info`` is empty, the stacked tensors don't
            split evenly into ``num_slices`` groups, or a slice falls outside
            ``img_size``.
    """
    num_slices = len(slices_info)
    if num_slices == 0:
        raise ValueError("slices_info must not be empty")

    # One slice means no cropping happened, so there is nothing to stitch.
    if num_slices == 1:
        return preds, labels

    h, w = img_size
    total = preds.shape[0]
    if probs.shape[0] != total or labels.shape[0] != total:
        raise ValueError("preds, probs and labels must share dim 0")
    if total % num_slices != 0:
        raise ValueError(
            f"{total} stacked crops cannot be split into {num_slices} slices")
    # Derive the per-slice batch from the data rather than the configured
    # micro batch size, so a short final batch still splits correctly.
    batch_size = total // num_slices
    device = preds.device

    preds_split = torch.split(preds, batch_size)
    probs_split = torch.split(probs, batch_size)
    labels_split = torch.split(labels, batch_size)

    # -inf guarantees the first crop covering a pixel always wins.
    total_max_probs = torch.full(
        (batch_size, h, w), float("-inf"), dtype=torch.float, device=device)
    total_preds = torch.zeros(
        (batch_size, h, w), dtype=torch.int, device=device)
    total_labels = torch.zeros(
        (batch_size, h, w), dtype=torch.int, device=device)

    for (sy, _ey, sx, _ex, sub_h, sub_w), slice_preds, slice_probs, slice_labels in zip(
            slices_info, preds_split, probs_split, labels_split):
        if sy + sub_h > h or sx + sub_w > w:
            raise ValueError(
                f"slice {[sy, sx, sub_h, sub_w]} exceeds image size {(h, w)}")
        rows = slice(sy, sy + sub_h)
        cols = slice(sx, sx + sub_w)

        curr_max_probs = total_max_probs[:, rows, cols]
        curr_preds = total_preds[:, rows, cols]
        # Drop the padding that slidingcrops added past the real extent.
        local_max_probs = slice_probs[:, :sub_h, :sub_w]
        local_preds = slice_preds[:, :sub_h, :sub_w]

        # Compute both results before writing: curr_* are views into the
        # accumulators being updated.
        result_max_probs = torch.maximum(curr_max_probs, local_max_probs)
        result_preds = torch.where(
            curr_max_probs >= local_max_probs, curr_preds, local_preds)

        total_max_probs[:, rows, cols] = result_max_probs
        total_preds[:, rows, cols] = result_preds
        # Copy labels per batch item (not item 0 broadcast to all).
        total_labels[:, rows, cols] = slice_labels[:, :sub_h, :sub_w]

    return total_preds, total_labels
| |
|
| |
|