| import cv2 |
| import numpy as np |
| from core.utils import get_smart_coords, create_hann_window |
| from core.config import NUM_CLASSES |
|
|
def process_single_rotation(model, img_rot, hann, tile_size, target_size, stride, conf_thresh):
    """Run tiled inference on one image orientation and accumulate weighted masks.

    The image is cut into ``tile_size`` x ``tile_size`` tiles at the offsets
    returned by ``get_smart_coords``. Each tile is resized to ``target_size``,
    passed to ``model``, and every predicted mask is resized back to the tile
    size and added to its class plane, scaled by its confidence and by ``hann``.

    Args:
        model: Callable returning a sequence whose first item exposes
            ``masks.data``, ``boxes.cls`` and ``boxes.conf`` (looks like an
            Ultralytics-style segmentation model -- confirm).
        img_rot: Image array; only its first two dimensions (rows, cols) are read.
        hann: Per-tile weighting window added to the weight map for every tile;
            presumably ``tile_size`` x ``tile_size`` -- confirm against caller.
        tile_size: Side length of the square tiles cut from ``img_rot``.
        target_size: Side length the tile is resized to before inference.
        stride: Step forwarded to ``get_smart_coords``.
        conf_thresh: Confidence threshold forwarded to the model as ``conf``.

    Returns:
        Tuple ``(acc, wt)``: ``acc`` is a float32 array of shape
        ``(NUM_CLASSES, h, w)`` holding the weighted mask sums, ``wt`` is a
        float32 array of shape ``(h, w)`` holding the summed window weights.
    """
    height, width = img_rot.shape[:2]
    class_scores = np.zeros((NUM_CLASSES, height, width), np.float32)
    weight_sum = np.zeros((height, width), np.float32)

    col_starts = get_smart_coords(width, tile_size, stride)
    row_starts = get_smart_coords(height, tile_size, stride)
    net_shape = (target_size, target_size)
    tile_shape = (tile_size, tile_size)

    for top in row_starts:
        bottom = top + tile_size
        for left in col_starts:
            right = left + tile_size

            patch = img_rot[top:bottom, left:right]
            net_input = cv2.resize(patch, net_shape)
            result = model(
                net_input,
                imgsz=target_size,
                conf=conf_thresh,
                verbose=False,
                retina_masks=True,
            )[0]

            # NOTE(review): the source's indentation was lost; the window weight
            # is applied once per tile, independent of whether the model
            # returned any masks -- confirm this matches the intended nesting.
            weight_sum[top:bottom, left:right] += hann

            if result.masks is None:
                continue

            tile_masks = result.masks.data.cpu().numpy()
            class_ids = result.boxes.cls.cpu().numpy().astype(int)
            scores = result.boxes.conf.cpu().numpy()

            for mask, class_id, score in zip(tile_masks, class_ids, scores):
                # Predictions whose class index falls outside the accumulator are dropped.
                if class_id < NUM_CLASSES:
                    mask_at_tile = cv2.resize(mask, tile_shape)
                    class_scores[class_id, top:bottom, left:right] += mask_at_tile * score * hann

    return class_scores, weight_sum
|
|
def run_inference(img, model, use_tta=False, tile_size=512, target_size=640, stride=256, conf_thresh=0.3):
    """Produce per-pixel class scores for an image via tiled inference.

    The image is zero-padded on the bottom/right so that both sides are at
    least ``tile_size``. ``process_single_rotation`` is then run once, or on
    all four 90-degree rotations when ``use_tta`` is true. Each result is
    rotated back and summed, and the total is normalised by the summed weights.

    Args:
        img: Image array; its first two dimensions are taken as (rows, cols).
        model: Model object forwarded to ``process_single_rotation``.
        use_tta: When true, evaluate rotations k = 0..3; otherwise only k = 0.
        tile_size: Tile side length; also the minimum padded image side.
        target_size: Model input side length, forwarded unchanged.
        stride: Tile step, forwarded unchanged.
        conf_thresh: Model confidence threshold, forwarded unchanged.

    Returns:
        Tuple ``(max_p, win_c, orig_h, orig_w)``: the per-pixel maximum
        normalised score and the index of the class achieving it, both cropped
        to the original image size, followed by the original height and width.
    """
    orig_h, orig_w = img.shape[:2]

    # Pad only on the bottom/right so original pixel coordinates are unchanged.
    extra_rows = max(0, tile_size - orig_h)
    extra_cols = max(0, tile_size - orig_w)
    if extra_rows or extra_cols:
        img = cv2.copyMakeBorder(
            img, 0, extra_rows, 0, extra_cols, cv2.BORDER_CONSTANT, value=(0, 0, 0)
        )

    padded_h, padded_w = img.shape[:2]
    total_scores = np.zeros((NUM_CLASSES, padded_h, padded_w), np.float32)
    total_weight = np.zeros((padded_h, padded_w), np.float32)
    window = create_hann_window(tile_size)

    n_rotations = 4 if use_tta else 1
    for quarter_turns in range(n_rotations):
        rotated = np.rot90(img, k=quarter_turns)
        scores, weights = process_single_rotation(
            model, rotated, window, tile_size, target_size, stride, conf_thresh
        )
        # Undo the rotation so every pass lands in the unrotated frame.
        total_scores += np.rot90(scores, k=-quarter_turns, axes=(1, 2))
        total_weight += np.rot90(weights, k=-quarter_turns)

    # Floor the weights to avoid dividing by zero where no weight accumulated.
    total_weight = np.clip(total_weight, 1e-5, None)
    # The (h, w) weight map broadcasts across the leading class axis.
    total_scores /= total_weight

    max_p = total_scores.max(axis=0)[:orig_h, :orig_w]
    win_c = total_scores.argmax(axis=0)[:orig_h, :orig_w]

    return max_p, win_c, orig_h, orig_w