yolo-test / core /inference.py
kuzheren
update
e62c44b
import cv2
import numpy as np
from core.utils import get_smart_coords, create_hann_window
from core.config import NUM_CLASSES
def process_single_rotation(model, img_rot, hann, tile_size, target_size, stride, conf_thresh):
    """Accumulate window-weighted per-class mask scores over the tiles of one image.

    The image is walked tile by tile. Each tile is resized to
    ``target_size`` x ``target_size``, passed to ``model``, and every returned
    mask is resized back to ``tile_size`` x ``tile_size`` and added to the
    accumulator of its class, scaled by the detection confidence and ``hann``.

    Args:
        model: Callable invoked as ``model(image, imgsz=..., conf=...,
            verbose=False, retina_masks=True)``. The first element of its
            result must expose ``.masks`` (or ``None``) and ``.boxes`` with
            ``cls`` and ``conf`` tensors. Presumably an Ultralytics
            segmentation model; confirm against the caller.
        img_rot: Image array whose first two axes are height and width.
        hann: Blending window added to the weights of each tile. It must
            broadcast against a ``tile_size`` x ``tile_size`` slice.
        tile_size: Side length of the square tiles cut from ``img_rot``.
        target_size: Side length the tile is resized to before inference.
        stride: Step handed to ``get_smart_coords`` when laying out tiles.
        conf_thresh: Confidence threshold forwarded to ``model``.

    Returns:
        A tuple ``(acc, wt)``: a float32 array of shape
        ``(NUM_CLASSES, h, w)`` with the weighted score sums, and a float32
        array of shape ``(h, w)`` with the summed window weights.
    """
    height, width = img_rot.shape[:2]
    class_scores = np.zeros((NUM_CLASSES, height, width), np.float32)
    weight_sum = np.zeros((height, width), np.float32)

    col_starts = get_smart_coords(width, tile_size, stride)
    row_starts = get_smart_coords(height, tile_size, stride)

    net_dims = (target_size, target_size)
    tile_dims = (tile_size, tile_size)

    for top in row_starts:
        rows = slice(top, top + tile_size)
        for left in col_starts:
            cols = slice(left, left + tile_size)

            patch = cv2.resize(img_rot[rows, cols], net_dims)
            result = model(
                patch,
                imgsz=target_size,
                conf=conf_thresh,
                verbose=False,
                retina_masks=True,
            )[0]

            detections = ()
            if result.masks is not None:
                detections = zip(
                    result.masks.data.cpu().numpy(),
                    result.boxes.cls.cpu().numpy().astype(int),
                    result.boxes.conf.cpu().numpy(),
                )

            for mask, class_id, score in detections:
                # Ignore classes the accumulator has no channel for.
                if class_id < NUM_CLASSES:
                    tile_mask = cv2.resize(mask, tile_dims)
                    class_scores[class_id, rows, cols] += tile_mask * score * hann

            # Every visited tile contributes to the weights, with or without
            # detections, so tiles that found nothing lower the averaged score.
            weight_sum[rows, cols] += hann

    return class_scores, weight_sum
def run_inference(img, model, use_tta=False, tile_size=512, target_size=640, stride=256, conf_thresh=0.3):
    """Run tiled inference on an image and return per-pixel best score and class.

    Images smaller than ``tile_size`` in either dimension are padded with
    black on the bottom and right so at least one full tile fits. With
    ``use_tta`` the image is processed at 0, 90, 180 and 270 degree
    rotations, and each result is rotated back before being summed. The
    summed scores are divided by the summed window weights and cropped back
    to the original size.

    Args:
        img: Image array whose first two axes are height and width.
        model: Model object forwarded to ``process_single_rotation``.
        use_tta: When true, use four rotations instead of one.
        tile_size: Side length of the square tiles.
        target_size: Side length tiles are resized to before inference.
        stride: Step between tiles.
        conf_thresh: Confidence threshold forwarded to the model.

    Returns:
        A tuple ``(max_p, win_c, orig_h, orig_w)``: the highest normalised
        class score per pixel, the index of the class holding it, and the
        original image height and width. Pixels where every class score is
        zero get class index 0.
    """
    orig_h, orig_w = img.shape[:2]

    extra_rows = max(0, tile_size - orig_h)
    extra_cols = max(0, tile_size - orig_w)
    if extra_rows or extra_cols:
        img = cv2.copyMakeBorder(
            img, 0, extra_rows, 0, extra_cols, cv2.BORDER_CONSTANT, value=(0, 0, 0)
        )

    padded_h, padded_w = img.shape[:2]
    score_sum = np.zeros((NUM_CLASSES, padded_h, padded_w), np.float32)
    weight_sum = np.zeros((padded_h, padded_w), np.float32)
    window = create_hann_window(tile_size)

    quarter_turns = range(4) if use_tta else range(1)
    for turns in quarter_turns:
        rotated = np.rot90(img, k=turns)
        scores, weights = process_single_rotation(
            model, rotated, window, tile_size, target_size, stride, conf_thresh
        )
        # Undo the rotation so results line up with the unrotated image.
        score_sum += np.rot90(scores, k=-turns, axes=(1, 2))
        weight_sum += np.rot90(weights, k=-turns)

    # A floor on the weights avoids dividing by zero where they vanish.
    weight_sum = np.clip(weight_sum, 1e-5, None)
    # The (h, w) weights broadcast over the leading class axis.
    score_sum /= weight_sum

    crop = (slice(None, orig_h), slice(None, orig_w))
    max_p = score_sum.max(axis=0)[crop]
    win_c = score_sum.argmax(axis=0)[crop]
    return max_p, win_c, orig_h, orig_w