File size: 994 Bytes
7b615ae |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
import cv2
import torch
import numpy as np
def predict_with_tta(model, image):
    """Average a model's raw outputs over a few test-time augmentations.

    Each augmentation is applied to ``image``, the model is run on the
    augmented input, and the output is mapped back to the input's orientation
    with the augmentation's *inverse* before the outputs are averaged.

    Args:
        model: Callable that takes the augmented tensor and returns a tensor
            whose dims 2 and 3 follow the spatial layout of its input.
        image: Tensor with at least 4 dims; dim 3 is flipped and dims (2, 3)
            are rotated. Presumably laid out as [B, C, H, W] -- confirm
            against the caller.

    Returns:
        The element-wise mean of the re-aligned model outputs. No softmax or
        argmax is applied; callers receive the averaged raw outputs.

    Note:
        Gradient tracking is not disabled here. Wrap the call in
        ``torch.no_grad()`` when running pure inference.
    """
    # (forward, inverse) pairs. The identity and the flip along dim 3 undo
    # themselves, but a 90-degree rotation must be undone by rotating the
    # other way: applying the forward rotation a second time would leave that
    # prediction rotated by 180 degrees relative to the others.
    augmentations = [
        (lambda x: x, lambda x: x),
        (
            lambda x: torch.flip(x, dims=[3]),
            lambda x: torch.flip(x, dims=[3]),
        ),
        (
            lambda x: torch.rot90(x, 1, [2, 3]),
            lambda x: torch.rot90(x, -1, [2, 3]),
        ),
    ]

    predictions = [
        inverse(model(forward(image))) for forward, inverse in augmentations
    ]

    # Stack along a new leading dim and average it away again.
    return torch.stack(predictions).mean(0)
def refine_mask(mask_tensor):
    """Clean up a mask with a morphological close followed by an open.

    The tensor is copied to the CPU, cast to ``uint8`` and processed by
    OpenCV using a 5x5 elliptical structuring element.

    Args:
        mask_tensor: Tensor holding the mask. Presumably a single 2-D
            [H, W] label/binary mask -- confirm against the caller.

    Returns:
        A tensor built from the processed ``uint8`` array, moved to the
        device ``mask_tensor`` lives on.
    """
    target_device = mask_tensor.device
    mask_np = mask_tensor.cpu().numpy().astype(np.uint8)
    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))

    # Order matters: the close runs first, then the open on its result.
    for operation in (cv2.MORPH_CLOSE, cv2.MORPH_OPEN):
        mask_np = cv2.morphologyEx(mask_np, operation, ellipse)

    return torch.from_numpy(mask_np).to(target_device)
|