Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import torch | |
| from utils import utils_image as util | |
def infer(model, L):
    """Run ``model`` on ``L`` as-is and return whatever it produces."""
    return model(L)
def inferp(model, L, modulo=16):
    """Pad ``L`` so H and W are multiples of ``modulo``, run ``model``, crop back.

    Args:
        model: callable applied to the padded tensor.
        L: tensor whose last two dimensions are height and width.
        modulo: the padded height/width are rounded up to a multiple of this.

    Returns:
        The model output cropped to the original ``(h, w)`` of ``L``.
    """
    height, width = L.size()[-2:]
    # Amount needed on the bottom / right edge to reach the next multiple.
    pad_bottom, pad_right = (
        int(np.ceil(extent / modulo) * modulo - extent)
        for extent in (height, width)
    )
    # ReplicationPad2d takes (left, right, top, bottom); only right/bottom grow.
    pad_layer = torch.nn.ReplicationPad2d((0, pad_right, 0, pad_bottom))
    output = model(pad_layer(L))
    return output[..., :height, :width]
def _padded_forward(model, L, sf=1, modulo=1):
    """Run ``model`` on ``L`` with H and W padded up to multiples of ``modulo``.

    The bottom/right edges are replicated for the padding, and the output is
    cropped to at most ``(h * sf, w * sf)`` so the padded area is discarded.
    """
    h, w = L.size()[-2:]
    pad_bottom, pad_right = -h % modulo, -w % modulo
    if pad_bottom or pad_right:
        L = torch.nn.ReplicationPad2d((0, pad_right, 0, pad_bottom))(L)
    return model(L)[..., :h * sf, :w * sf]


def inferspfn(model, L, refield=32, min_size=256, sf=1, modulo=1):
    """Run ``model`` on ``L`` by recursively splitting it into four overlapping tiles.

    Inputs with ``h * w <= min_size ** 2`` are processed in one pass. Larger
    inputs are cut into four corner tiles, each a bit more than half the
    image per axis (rounded up to a multiple of ``refield``), so neighbouring
    tiles overlap. Each tile is processed directly when
    ``h * w <= 4 * min_size ** 2`` and recursively otherwise; the output is
    assembled from the non-overlapping quadrant of every tile result.

    Args:
        model: callable mapping a tensor to a tensor; its output is expected
            to be ``sf`` times the input size along the last two dimensions.
        L: input tensor; the last two dimensions are height and width.
        refield: tile sizes are rounded up to a multiple of this value.
        min_size: side length of the largest (square-equivalent) area that is
            processed without splitting.
        sf: integer scale factor between model input and output size.
        modulo: every tensor handed to ``model`` is padded so that height and
            width are multiples of this value; the padding is cropped again.

    Returns:
        Tensor with the dtype/device of ``L``, the leading dimensions of the
        model output and spatial size ``(sf * h, sf * w)``.
    """
    h, w = L.size()[-2:]
    if h * w <= min_size ** 2:
        return _padded_forward(model, L, sf, modulo)

    # Tile extent per axis. Clamp to the image size: an unclamped extent
    # larger than the image makes ``h - extent`` negative, which Python
    # treats as an offset from the end and yields a too-small bottom/right
    # tile (shape mismatch when stitching).
    tile_h = min((h // 2 // refield + 1) * refield, h)
    tile_w = min((w // 2 // refield + 1) * refield, w)
    if tile_h == h and tile_w == w:
        # Splitting would reproduce the full image four times and, in the
        # recursive branch, never terminate. Process it in one pass instead.
        return _padded_forward(model, L, sf, modulo)

    rows = (slice(0, tile_h), slice(h - tile_h, h))
    cols = (slice(0, tile_w), slice(w - tile_w, w))
    # Order: top-left, top-right, bottom-left, bottom-right.
    tiles = [L[..., r, c] for r in rows for c in cols]

    if h * w <= 4 * min_size ** 2:
        Es = [_padded_forward(model, tile, sf, modulo) for tile in tiles]
    else:
        Es = [
            inferspfn(model, tile, refield=refield, min_size=min_size, sf=sf, modulo=modulo)
            for tile in tiles
        ]

    # Seam position and size of the bottom/right part, in output coordinates.
    cut_h, cut_w = h // 2 * sf, w // 2 * sf
    rest_h, rest_w = (h - h // 2) * sf, (w - w // 2) * sf

    E = L.new_zeros(tuple(Es[0].shape[:-2]) + (sf * h, sf * w))
    E[..., :cut_h, :cut_w] = Es[0][..., :cut_h, :cut_w]
    E[..., :cut_h, cut_w:] = Es[1][..., :cut_h, -rest_w:]
    E[..., cut_h:, :cut_w] = Es[2][..., -rest_h:, :cut_w]
    E[..., cut_h:, cut_w:] = Es[3][..., -rest_h:, -rest_w:]
    return E
def infersp(model, L, refield=32, min_size=256, sf=1, modulo=1):
    """Split-based inference entry point; forwards every argument to ``inferspfn``."""
    return inferspfn(
        model,
        L,
        refield=refield,
        min_size=min_size,
        sf=sf,
        modulo=modulo,
    )
def inferosp(model, L, refield=32, min_size=256, sf=1, modulo=1):
    """Run ``model`` on four overlapping corner tiles of ``L`` and stitch the results.

    The input is always split exactly once (no recursion). Each tile spans a
    bit more than half the image per axis, rounded up to a multiple of
    ``refield`` and clamped to the image size. The output is assembled from
    the non-overlapping quadrant of every tile result.

    Args:
        model: callable mapping a tensor to a tensor; its output is expected
            to be ``sf`` times the input size along the last two dimensions.
        L: input tensor; the last two dimensions are height and width.
        refield: tile sizes are rounded up to a multiple of this value.
        min_size: unused; kept so the signature matches ``infersp``.
        sf: integer scale factor between model input and output size.
        modulo: tiles are padded (edge replication) so height and width are
            multiples of this value before calling ``model``; the padding is
            cropped from the result.

    Returns:
        Tensor with the dtype/device of ``L``, the leading dimensions of the
        model output and spatial size ``(sf * h, sf * w)``.
    """
    h, w = L.size()[-2:]
    # Clamp to the image size: an unclamped extent larger than the image
    # makes ``h - extent`` negative, which Python treats as an offset from the
    # end and yields a too-small bottom/right tile (shape mismatch below).
    tile_h = min((h // 2 // refield + 1) * refield, h)
    tile_w = min((w // 2 // refield + 1) * refield, w)
    rows = (slice(0, tile_h), slice(h - tile_h, h))
    cols = (slice(0, tile_w), slice(w - tile_w, w))

    # All four tiles share one size, so one padding spec fits them all.
    pad_bottom, pad_right = -tile_h % modulo, -tile_w % modulo
    Es = []
    # Order: top-left, top-right, bottom-left, bottom-right.
    for r in rows:
        for c in cols:
            tile = L[..., r, c]
            if pad_bottom or pad_right:
                tile = torch.nn.ReplicationPad2d((0, pad_right, 0, pad_bottom))(tile)
            Es.append(model(tile)[..., :tile_h * sf, :tile_w * sf])

    # Seam position and size of the bottom/right part, in output coordinates.
    cut_h, cut_w = h // 2 * sf, w // 2 * sf
    rest_h, rest_w = (h - h // 2) * sf, (w - w // 2) * sf

    E = L.new_zeros(tuple(Es[0].shape[:-2]) + (sf * h, sf * w))
    E[..., :cut_h, :cut_w] = Es[0][..., :cut_h, :cut_w]
    E[..., :cut_h, cut_w:] = Es[1][..., :cut_h, -rest_w:]
    E[..., cut_h:, :cut_w] = Es[2][..., -rest_h:, :cut_w]
    E[..., cut_h:, cut_w:] = Es[3][..., -rest_h:, -rest_w:]
    return E
def inference(model, L, mode=0, refield=128, min_size=256, sf=1, modulo=1):
    """Dispatch to one of the inference strategies selected by ``mode``.

    Modes:
        0: ``infer``    - plain forward pass.
        1: ``inferp``   - pad to a multiple of ``modulo``, run, crop.
        2: ``infersp``  - recursive split into overlapping tiles.
        3: ``inferosp`` - a single split into four overlapping tiles.

    ``refield``, ``min_size`` and ``sf`` are only forwarded for modes 2 and 3;
    ``modulo`` is forwarded for modes 1, 2 and 3.

    Raises:
        ValueError: if ``mode`` is not one of 0, 1, 2, 3. (Previously this
            fell through to ``return E`` with ``E`` never assigned.)
    """
    if mode == 0:
        return infer(model, L)
    if mode == 1:
        return inferp(model, L, modulo)
    if mode == 2:
        return infersp(model, L, refield, min_size, sf, modulo)
    if mode == 3:
        return inferosp(model, L, refield, min_size, sf, modulo)
    raise ValueError(
        'unsupported inference mode: {!r} (expected 0, 1, 2 or 3)'.format(mode)
    )
if __name__ == '__main__':
    # Smoke test: push a random batch through every inference mode and print
    # the output shape of each.

    class Net(torch.nn.Module):
        """Minimal stand-in model: one 3x3 convolution that keeps H and W."""

        def __init__(self, in_channels=3, out_channels=3):
            super(Net, self).__init__()
            self.conv = torch.nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1)

        def forward(self, x):
            return self.conv(x)

    # The CUDA timing events that used to be created here were never used,
    # and constructing them fails on CPU-only builds, so they are gone.
    model = Net().eval()

    x = torch.randn((2, 3, 400, 400))

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    with torch.no_grad():
        # `inference` implements modes 0-3 only; the old range(5) also tried
        # the non-existent mode 4.
        for mode in range(4):
            y = inference(model, x, mode)
            print(y.shape)