File size: 994 Bytes
7b615ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import cv2
import torch
import numpy as np

def predict_with_tta(model, image):
    """Run test-time augmentation (TTA) and return the averaged logits.

    The model is evaluated on the original image, a flip along dim 3 and a
    90-degree rotation in the (dim 2, dim 3) plane. Each prediction is mapped
    back to the original orientation with the *inverse* of its transform
    before averaging, so all predictions are spatially aligned.

    Args:
        model: Callable applied to the (augmented) image tensor. It is assumed
            to return dense per-pixel logits laid out like its input, i.e.
            (B, C, H, W) with the same spatial grid — TODO confirm for the
            models used with this helper.
        image: Input tensor of rank >= 4; dims 2 and 3 are treated as the
            spatial (H, W) axes.

    Returns:
        Averaged logits, shape (B, C, H, W), in the original orientation.

    Note:
        Gradients are not disabled here. For pure inference, wrap the call in
        ``torch.no_grad()`` (or ``torch.inference_mode()``) to avoid building
        an autograd graph for every augmented forward pass.
    """
    # (forward, inverse) pairs. A flip is its own inverse, but a +90 degree
    # rotation must be undone with a -90 degree rotation; re-applying the
    # forward transform would leave that prediction rotated by 180 degrees.
    # For non-square inputs the inverse rotation also restores (H, W).
    tta_pairs = [
        (lambda x: x, lambda x: x),
        (lambda x: torch.flip(x, dims=[3]), lambda x: torch.flip(x, dims=[3])),
        (lambda x: torch.rot90(x, 1, [2, 3]), lambda x: torch.rot90(x, -1, [2, 3])),
    ]
    predictions = [inverse(model(forward(image))) for forward, inverse in tta_pairs]
    # Average raw logits (not softmax probabilities) across augmentations.
    return torch.stack(predictions).mean(0)  # [B, C, H, W]

def _close_then_open(plane, kernel):
    """Apply morphological closing followed by opening to one 2-D uint8 plane."""
    closed = cv2.morphologyEx(plane, cv2.MORPH_CLOSE, kernel)
    return cv2.morphologyEx(closed, cv2.MORPH_OPEN, kernel)


def refine_mask(mask_tensor, kernel_size=5):
    """Smooth a mask with morphological closing, then opening.

    Closing (dilate then erode) fills holes/gaps smaller than the kernel;
    the following opening (erode then dilate) removes specks smaller than
    the kernel. Both use an elliptical structuring element.

    Args:
        mask_tensor: Mask tensor whose last two dims are (H, W). Any leading
            dims (e.g. a batch dim) are processed plane by plane. Values are
            cast to uint8, so they should lie in [0, 255]. With more than two
            label values, OpenCV applies grayscale (max/min) morphology —
            presumably intended for binary masks; confirm for multi-class use.
        kernel_size: Side length of the elliptical structuring element
            (default 5, the previously hard-coded value).

    Returns:
        A uint8 tensor with the same shape as ``mask_tensor``, on the same
        device.
    """
    # detach() so masks that still require grad can be converted to numpy.
    mask = mask_tensor.detach().cpu().numpy().astype(np.uint8)
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))

    if mask.ndim <= 2:
        refined = _close_then_open(mask, kernel)
    else:
        # OpenCV would interpret a 3-D array as a multi-channel image, so
        # flatten the leading dims and process each (H, W) plane separately.
        planes = mask.reshape(-1, *mask.shape[-2:])
        refined = np.empty_like(planes)
        for i, plane in enumerate(planes):
            refined[i] = _close_then_open(plane, kernel)
        refined = refined.reshape(mask.shape)

    return torch.from_numpy(refined).to(mask_tensor.device)