"""Multi-crop helpers built on ``torchvision.transforms.functional``."""

import itertools

from torchvision.transforms import functional as F


def _image_hw(image):
    """Return ``(height, width)`` of a PIL image or an image tensor.

    PIL images expose ``size`` as ``(width, height)``. Anything with a
    ``shape`` attribute is treated as a tensor whose last two dimensions
    are height and width (the ``[..., H, W]`` layout torchvision's
    functional API uses for tensors).
    """
    if hasattr(image, "shape"):
        height, width = image.shape[-2:]
        return int(height), int(width)
    width, height = image.size
    return height, width


def _check_ratio(ratio):
    """Raise ``ValueError`` unless ``0 < ratio <= 1``."""
    if not 0 < ratio <= 1:
        raise ValueError(f"ratio must be in (0, 1], got {ratio!r}")


def _crop_len(length, ratio):
    """Return ``ratio * length`` truncated to whole pixels, at least 1."""
    return max(1, int(length * ratio))


def five_crop(image, ratio=0.6):
    """Return the four corner crops and the center crop of *image*.

    Each crop is ``ratio`` of the image height by ``ratio`` of the image
    width, truncated to whole pixels (minimum 1).

    Args:
        image: PIL image, or tensor laid out as ``[..., H, W]``.
        ratio: Crop side length as a fraction of the image side, in (0, 1].

    Returns:
        Whatever ``F.five_crop`` returns for the computed size (documented
        by torchvision as ``(top_left, top_right, bottom_left,
        bottom_right, center)``).

    Raises:
        ValueError: If ``ratio`` is outside (0, 1].
    """
    _check_ratio(ratio)
    height, width = _image_hw(image)
    # Integer sizes: float sizes break tensor slicing inside F.crop.
    size = (_crop_len(height, ratio), _crop_len(width, ratio))
    return F.five_crop(image, size)


def nine_crop(image, ratio=0.4):
    """Return a 3x3 grid of equally sized crops covering *image*.

    Every crop is ``ratio`` of the image height by ``ratio`` of the image
    width, truncated to whole pixels (minimum 1). Along each axis the
    crops are placed flush with the start edge, centred, and flush with
    the end edge.

    Args:
        image: PIL image, or tensor laid out as ``[..., H, W]``.
        ratio: Crop side length as a fraction of the image side, in (0, 1].

    Returns:
        List of nine crops in row-major order: top-left, top-centre,
        top-right, middle-left, ..., bottom-right.

    Raises:
        ValueError: If ``ratio`` is outside (0, 1].
    """
    _check_ratio(ratio)
    height, width = _image_hw(image)
    crop_h = _crop_len(height, ratio)
    crop_w = _crop_len(width, ratio)
    # Compute the crop size once and derive offsets from it, so all nine
    # crops have identical dimensions (and can therefore be stacked).
    tops = (0, (height - crop_h) // 2, height - crop_h)
    lefts = (0, (width - crop_w) // 2, width - crop_w)
    return [
        F.crop(image, top, left, crop_h, crop_w)
        for top, left in itertools.product(tops, lefts)
    ]