Spaces:
Sleeping
Sleeping
| import torch | |
| def run_inference(model, model_str, input_tensor, device, tta=True): | |
| # add extra batch dimension | |
| samples = input_tensor.unsqueeze(0) | |
| if tta: | |
| with torch.no_grad(): | |
| bs, n_crops, c, h, w = samples.size() | |
| # flatten out the batch and n_crops dimensions | |
| input = samples.view(-1, c, h, w).to(device) | |
| out = model(input) | |
| out = torch.sigmoid(out) | |
| # average the outputs of the TTA crops | |
| out_mean = out.view(bs, n_crops, -1).mean(1) | |
| return out_mean | |