| import cv2 | |
| import torch | |
| import numpy as np | |
| def predict_with_tta(model, image): | |
| transforms = [ | |
| lambda x: x, | |
| lambda x: torch.flip(x, dims=[3]), | |
| lambda x: torch.rot90(x, 1, [2, 3]) | |
| ] | |
| predictions = [] | |
| for tf in transforms: | |
| aug = tf(image) | |
| #with torch.no_grad(): | |
| pred = model(aug) | |
| inv_pred = tf(pred) | |
| #predictions.append(torch.softmax(inv_pred, dim=1)) | |
| predictions.append(inv_pred) | |
| #avg_pred = torch.stack(predictions).mean(0) | |
| #return torch.argmax(avg_pred, dim=1).squeeze(0) | |
| avg_logits = torch.stack(predictions).mean(0) # [B, C, H, W] | |
| return avg_logits | |
| def refine_mask(mask_tensor): | |
| mask = mask_tensor.cpu().numpy().astype(np.uint8) | |
| kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5)) | |
| closed = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel) | |
| opened = cv2.morphologyEx(closed, cv2.MORPH_OPEN, kernel) | |
| return torch.from_numpy(opened).to(mask_tensor.device) | |