x-ray-vision / src /run_inference.py
taheera's picture
initial commit
df9c255
raw
history blame contribute delete
565 Bytes
import torch
def run_inference(model, model_str, input_tensor, device, tta=True):
# add extra batch dimension
samples = input_tensor.unsqueeze(0)
if tta:
with torch.no_grad():
bs, n_crops, c, h, w = samples.size()
# flatten out the batch and n_crops dimensions
input = samples.view(-1, c, h, w).to(device)
out = model(input)
out = torch.sigmoid(out)
# average the outputs of the TTA crops
out_mean = out.view(bs, n_crops, -1).mean(1)
return out_mean