Spaces:
Runtime error
Runtime error
| "Manages data input pipeline - folderstransformbatch input. Includes support for classification, segmentation and bounding boxes" | |
| from numbers import Integral | |
| from ..torch_core import * | |
| from .image import * | |
| from .transform import * | |
| from ..data_block import * | |
| from ..basic_data import * | |
| from ..layers import * | |
| from .learner import * | |
| from torchvision import transforms as tvt | |
# Public API of this module (order preserved).
__all__ = [
    'get_image_files', 'denormalize', 'get_annotations', 'ImageDataBunch',
    'ImageList', 'normalize', 'normalize_funcs', 'resize_to',
    'channel_view', 'mnist_stats', 'cifar_stats', 'imagenet_stats',
    'imagenet_stats_inception', 'download_images', 'verify_images',
    'bb_pad_collate', 'ImageImageList', 'PointsLabelList',
    'ObjectCategoryList', 'ObjectItemList', 'SegmentationLabelList',
    'SegmentationItemList', 'PointsItemList',
]
# Every file suffix that the `mimetypes` registry maps to an `image/*` MIME type.
image_extensions = {ext for ext, mime in mimetypes.types_map.items() if mime.startswith('image/')}
def get_image_files(c:PathOrStr, check_ext:bool=True, recurse=False)->FilePathList:
    """Return list of files in `c` that are images.

    When `check_ext` is true, only suffixes in `image_extensions` are accepted; otherwise no
    extension filter is passed. `recurse` is forwarded unchanged to `get_files`.
    """
    allowed = image_extensions if check_ext else None
    return get_files(c, extensions=allowed, recurse=recurse)
def get_annotations(fname, prefix=None):
    """Open a COCO style json in `fname` and return the lists of filenames (with maybe `prefix`) and labelled bboxes.

    Each annotation box `bb` is converted to `[bb[1], bb[0], bb[1]+bb[3], bb[0]+bb[2]]`, i.e. from
    `[x, y, width, height]` to `[y, x, y+height, x+width]` assuming the usual COCO box layout.
    Only images that have at least one annotation are returned, in the order of the file's
    `images` section.

    Returns `(filenames, labels)` where `labels[k] == [bboxes, category_names]` for `filenames[k]`.
    Raises `KeyError` if an annotation refers to an unknown `category_id`.
    """
    # Context manager so the file handle is released even if parsing fails.
    with open(fname) as f: annot_dict = json.load(f)
    prefix = '' if prefix is None else prefix
    classes = {o['id']: o['name'] for o in annot_dict['categories']}
    id2bboxes, id2cats = collections.defaultdict(list), collections.defaultdict(list)
    for o in annot_dict['annotations']:
        bb = o['bbox']
        id2bboxes[o['image_id']].append([bb[1], bb[0], bb[3]+bb[1], bb[2]+bb[0]])
        id2cats[o['image_id']].append(classes[o['category_id']])
    id2images = {}
    for o in annot_dict['images']:
        # `in` on a defaultdict does not create entries, so un-annotated images are skipped.
        if o['id'] in id2bboxes: id2images[o['id']] = prefix + o['file_name']
    ids = list(id2images.keys())
    return [id2images[k] for k in ids], [[id2bboxes[k], id2cats[k]] for k in ids]
def bb_pad_collate(samples:BatchSamples, pad_idx:int=0) -> Tuple[FloatTensor, Tuple[LongTensor, LongTensor]]:
    """Collate `samples` of labelled bboxes into one batch, padding with `pad_idx`.

    If the first target is a plain `int`, the batch is handed to `data_collate` unchanged.
    Otherwise each sample's boxes and labels are written right-aligned into a zero-filled
    `(n_samples, max_len, 4)` box tensor and an `(n_samples, max_len)` label tensor whose unused
    slots hold `pad_idx`. Images are stacked along a new leading axis.
    """
    if isinstance(samples[0][1], int): return data_collate(samples)
    n_samples = len(samples)
    max_len = max(len(s[1].data[1]) for s in samples)
    bboxes = torch.zeros(n_samples, max_len, 4)
    labels = torch.zeros(n_samples, max_len).long() + pad_idx
    imgs = []
    for row, sample in enumerate(samples):
        img, target = sample[0], sample[1]
        imgs.append(img.data[None])
        bbs, lbls = target.data
        if bbs.nelement() != 0:
            # Right-align: the last `len(lbls)` slots of this row receive the real entries.
            n_obj = len(lbls)
            bboxes[row, -n_obj:] = bbs
            labels[row, -n_obj:] = tensor(lbls)
    return torch.cat(imgs, 0), (bboxes, labels)
def normalize(x:TensorImage, mean,std:Tensor)->TensorImage:
    "Normalize `x` with `mean` and `std`, which are broadcast over the last two axes of `x`."
    m = mean[..., None, None]
    s = std[..., None, None]
    return (x - m) / s
def denormalize(x:TensorImage, mean,std:Tensor, do_x:bool=True)->TensorImage:
    """Denormalize `x` with `mean` and `std`, returning the result on the CPU.

    When `do_x` is false, `x` is only moved to the CPU and otherwise left untouched.
    """
    if not do_x: return x.cpu()
    return x.cpu().float() * std[..., None, None] + mean[..., None, None]
def _normalize_batch(b:Tuple[Tensor,Tensor], mean:Tensor, std:Tensor, do_x:bool=True, do_y:bool=False)->Tuple[Tensor,Tensor]:
    """Normalize the batch `b = (x, y)`.

    `x` is normalized when `do_x` is set. `y` is normalized only when `do_y` is set and `y`
    is 4-dimensional. `mean`/`std` are first moved to the device of `x`.
    """
    x, y = b
    target_device = x.device
    mean, std = mean.to(target_device), std.to(target_device)
    if do_x: x = normalize(x, mean, std)
    if do_y and len(y.shape) == 4: y = normalize(y, mean, std)
    return x, y
def normalize_funcs(mean:Tensor, std:Tensor, do_x:bool=True, do_y:bool=False)->Tuple[Callable,Callable]:
    """Create the `(normalize, denormalize)` function pair for `mean` and `std`.

    The normalizer is `_normalize_batch` bound to `mean`, `std`, `do_x` and `do_y`. The
    denormalizer is `denormalize` bound to `mean`, `std` and `do_x`. Both statistics are
    converted with `tensor` first.
    """
    mean_t, std_t = tensor(mean), tensor(std)
    norm_fn = partial(_normalize_batch, mean=mean_t, std=std_t, do_x=do_x, do_y=do_y)
    denorm_fn = partial(denormalize, mean=mean_t, std=std_t, do_x=do_x)
    return norm_fn, denorm_fn
# Per-channel `(mean, std)` pairs, in the form consumed by `normalize_funcs(*stats)`.
cifar_stats = ([0.491, 0.482, 0.447], [0.247, 0.243, 0.261])
imagenet_stats = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
imagenet_stats_inception = ([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
mnist_stats = ([0.15, 0.15, 0.15], [0.15, 0.15, 0.15])
def channel_view(x:Tensor)->Tensor:
    """Make channel the first axis of `x` and flatten remaining axes.

    Axes 0 and 1 are swapped, so the output has shape `(x.shape[1], -1)`.
    """
    n_channels = x.shape[1]
    swapped = x.transpose(0, 1).contiguous()
    return swapped.view(n_channels, -1)
class ImageDataBunch(DataBunch):
    "DataBunch suitable for computer vision."
    _square_show = True

    @classmethod
    def create_from_ll(cls, lls:LabelLists, bs:int=64, val_bs:int=None, ds_tfms:Optional[TfmList]=None,
                       num_workers:int=defaults.cpus, dl_tfms:Optional[Collection[Callable]]=None, device:torch.device=None,
                       test:Optional[PathOrStr]=None, collate_fn:Callable=data_collate, size:int=None, no_check:bool=False,
                       resize_method:ResizeMethod=None, mult:int=None, padding_mode:str='reflection',
                       mode:str='bilinear', tfm_y:bool=False)->'ImageDataBunch':
        """Create an `ImageDataBunch` from `LabelLists` `lls` with potential `ds_tfms`.

        - `ds_tfms`, `size`, `resize_method`, `mult`, `padding_mode`, `mode` and `tfm_y` are forwarded to
          `lls.transform`.
        - If `test` is given, it is added with `lls.add_test_folder`.
        - `bs`, `val_bs`, `dl_tfms`, `num_workers`, `collate_fn`, `device` and `no_check` are forwarded to
          `lls.databunch`.
        """
        lls = lls.transform(tfms=ds_tfms, size=size, resize_method=resize_method, mult=mult, padding_mode=padding_mode,
                            mode=mode, tfm_y=tfm_y)
        if test is not None: lls.add_test_folder(test)
        return lls.databunch(bs=bs, val_bs=val_bs, dl_tfms=dl_tfms, num_workers=num_workers, collate_fn=collate_fn,
                             device=device, no_check=no_check)

    @classmethod
    def from_folder(cls, path:PathOrStr, train:PathOrStr='train', valid:PathOrStr='valid',
                    valid_pct=None, seed:int=None, classes:Collection=None, **kwargs:Any)->'ImageDataBunch':
        """Create from imagenet style dataset in `path` with `train`,`valid`,`test` subfolders (or provide `valid_pct`).

        When `valid_pct` is None, the split uses the `train`/`valid` folders. Otherwise a random `valid_pct`
        split (seeded by `seed`) is made. Labels come from `label_from_folder(classes=classes)`. Remaining
        `kwargs` go to `create_from_ll`.
        """
        path=Path(path)
        il = ImageList.from_folder(path)
        if valid_pct is None: src = il.split_by_folder(train=train, valid=valid)
        else: src = il.split_by_rand_pct(valid_pct, seed)
        src = src.label_from_folder(classes=classes)
        return cls.create_from_ll(src, **kwargs)

    @classmethod
    def from_df(cls, path:PathOrStr, df:pd.DataFrame, folder:PathOrStr=None, label_delim:str=None, valid_pct:float=0.2,
                seed:int=None, fn_col:IntsOrStrs=0, label_col:IntsOrStrs=1, suffix:str='', **kwargs:Any)->'ImageDataBunch':
        """Create from a `DataFrame` `df`.

        - Filenames are read from column(s) `fn_col`, prefixed by `folder` and suffixed by `suffix`.
        - Labels are read from `label_col` (split on `label_delim` when given).
        - A random `valid_pct` split is made with `seed`.
        - Remaining `kwargs` go to `create_from_ll`.
        """
        src = (ImageList.from_df(df, path=path, folder=folder, suffix=suffix, cols=fn_col)
                .split_by_rand_pct(valid_pct, seed)
                .label_from_df(label_delim=label_delim, cols=label_col))
        return cls.create_from_ll(src, **kwargs)

    @classmethod
    def from_csv(cls, path:PathOrStr, folder:PathOrStr=None, label_delim:str=None, csv_labels:PathOrStr='labels.csv',
                 valid_pct:float=0.2, seed:int=None, fn_col:int=0, label_col:int=1, suffix:str='', delimiter:str=None,
                 header:Optional[Union[int,str]]='infer', **kwargs:Any)->'ImageDataBunch':
        """Create from a csv file in `path/csv_labels`.

        The file is loaded with `pd.read_csv(header=header, delimiter=delimiter)`, then handled by `from_df`.
        """
        path = Path(path)
        df = pd.read_csv(path/csv_labels, header=header, delimiter=delimiter)
        return cls.from_df(path, df, folder=folder, label_delim=label_delim, valid_pct=valid_pct, seed=seed,
                           fn_col=fn_col, label_col=label_col, suffix=suffix, **kwargs)

    @classmethod
    def from_lists(cls, path:PathOrStr, fnames:FilePathList, labels:Collection[str], valid_pct:float=0.2, seed:int=None,
                   item_cls:Callable=None, **kwargs):
        """Create from list of `fnames` in `path`.

        `labels[k]` is the label of `fnames[k]`. `item_cls` defaults to `ImageList`.
        """
        item_cls = ifnone(item_cls, ImageList)
        fname2label = {f:l for (f,l) in zip(fnames, labels)}
        src = (item_cls(fnames, path=path).split_by_rand_pct(valid_pct, seed)
                .label_from_func(lambda x:fname2label[x]))
        return cls.create_from_ll(src, **kwargs)

    @classmethod
    def from_name_func(cls, path:PathOrStr, fnames:FilePathList, label_func:Callable, valid_pct:float=0.2, seed:int=None,
                       **kwargs):
        "Create from list of `fnames` in `path` with `label_func`, which maps each filename to its label."
        src = ImageList(fnames, path=path).split_by_rand_pct(valid_pct, seed)
        return cls.create_from_ll(src.label_from_func(label_func), **kwargs)

    @classmethod
    def from_name_re(cls, path:PathOrStr, fnames:FilePathList, pat:str, valid_pct:float=0.2, **kwargs):
        """Create from list of `fnames` in `path` with re expression `pat`.

        The label is the first capture group of `pat` searched in the filename (POSIX form for `Path`s).
        The label function raises `AssertionError` when `pat` does not match a filename.
        """
        pat = re.compile(pat)
        def _get_label(fn):
            if isinstance(fn, Path): fn = fn.as_posix()
            res = pat.search(str(fn))
            assert res,f'Failed to find "{pat}" in "{fn}"'
            return res.group(1)
        return cls.from_name_func(path, fnames, _get_label, valid_pct=valid_pct, **kwargs)

    @staticmethod
    def single_from_classes(path:Union[Path, str], classes:Collection[str], ds_tfms:TfmList=None, **kwargs):
        """Create an empty `ImageDataBunch` in `path` with `classes`. Typically used for inference.

        Deprecated: emits a `DeprecationWarning`; prefer `load_learner` after `Learner.export()`.
        """
        warn("""This method is deprecated and will be removed in a future version, use `load_learner` after
             `Learner.export()`""", DeprecationWarning)
        sd = ImageList([], path=path, ignore_empty=True).split_none()
        return sd.label_const(0, label_cls=CategoryList, classes=classes).transform(ds_tfms, **kwargs).databunch()

    def batch_stats(self, funcs:Collection[Callable]=None, ds_type:DatasetType=DatasetType.Train)->List[Tensor]:
        """Grab a batch of data and call reduction function `func` per channel.

        `funcs` defaults to `[torch.mean, torch.std]`. Each function is applied along dim 1 of
        `channel_view` of a non-denormalized CPU batch. The results are returned as a list.
        """
        funcs = ifnone(funcs, [torch.mean,torch.std])
        x = self.one_batch(ds_type=ds_type, denorm=False)[0].cpu()
        return [func(channel_view(x), 1) for func in funcs]

    def normalize(self, stats:Collection[Tensor]=None, do_x:bool=True, do_y:bool=False)->'ImageDataBunch':
        """Add normalize transform using `stats` (defaults to `DataBunch.batch_stats`).

        Sets `self.stats`, `self.norm` and `self.denorm`, registers `self.norm` with `add_tfm`, and
        returns `self`. Raises `Exception` if called a second time.
        """
        if getattr(self,'norm',False): raise Exception('Can not call normalize twice')
        if stats is None: self.stats = self.batch_stats()
        else: self.stats = stats
        self.norm,self.denorm = normalize_funcs(*self.stats, do_x=do_x, do_y=do_y)
        self.add_tfm(self.norm)
        return self
def download_image(url,dest, timeout=4):
    """Download `url` to `dest` (overwriting it) without a progress bar.

    Best effort: any exception is printed as `Error <url> <exception>` instead of being raised.
    """
    # The return value of `download_url` was bound to an unused local before; it is not needed.
    try: download_url(url, dest, overwrite=True, show_progress=False, timeout=timeout)
    except Exception as e: print(f"Error {url} {e}")
def _download_image_inner(dest, url, i, timeout=4):
    """Download `url` into `dest` as `{i:08d}{suffix}`.

    The suffix is the first `.ext` in `url` that is directly followed by `?` or the end of the string.
    It falls back to `.jpg` when no such suffix is found.
    """
    candidates = re.findall(r'\.\w+?(?=(?:\?|$))', url)
    ext = candidates[0] if candidates else '.jpg'
    download_image(url, dest/f"{i:08d}{ext}", timeout=timeout)
def download_images(urls:Collection[str], dest:PathOrStr, max_pics:int=1000, max_workers:int=8, timeout=4):
    """Download images listed in text file `urls` to path `dest`, at most `max_pics`.

    Despite its annotation, `urls` is opened as a file path. The file must hold one URL per line.
    `dest` (and any missing parents) is created if needed. Downloads run through `parallel` with
    `max_workers` workers; each one is best effort and uses `timeout`.
    """
    # Context manager so the URL file is closed right after reading.
    with open(urls) as f: url_list = f.read().strip().split("\n")[:max_pics]
    dest = Path(dest)
    # `parents=True` only changes behaviour where the old call raised FileNotFoundError.
    dest.mkdir(parents=True, exist_ok=True)
    parallel(partial(_download_image_inner, dest, timeout=timeout), url_list, max_workers=max_workers)
def resize_to(img, targ_sz:int, use_min:bool=False):
    """Size to resize to, to hit `targ_sz` at same aspect ratio, in PIL coords (i.e w*h).

    The longer side of `img.size` is scaled to `targ_sz`, or the shorter side when `use_min` is true.
    Both resulting dimensions are truncated with `int`.
    """
    width, height = img.size
    reference = min(width, height) if use_min else max(width, height)
    scale = targ_sz / reference
    return int(width * scale), int(height * scale)
def verify_image(file:Path, idx:int, delete:bool, max_size:Union[int,Tuple[int,int]]=None, dest:Path=None, n_channels:int=3,
                 interp=PIL.Image.BILINEAR, ext:str=None, img_format:str=None, resume:bool=False, **kwargs):
    """Check if the image in `file` exists, maybe resize it and copy it in `dest`.

    1. `file` is opened with every warning turned into an exception.
       - A "Possibly corrupt EXIF data" warning makes the file be re-saved in place when `delete` is true.
       - Without `delete`, only a message is printed.
    2. The image is opened again. A copy is written to `dest/file.name` when either side exceeds
       `max_size` or its channel count differs from `n_channels`:
       - the suffix is replaced by `ext` if given;
       - the image is resized with `resize_to` + `interp` when `max_size` is set;
       - it is converted to RGB when `n_channels == 3`;
       - it is saved with `img.save(dest_fname, img_format, **kwargs)`;
       - with `resume`, an already existing destination file makes the function return early.
    Any exception is printed and, when `delete` is true, `file` is removed.

    `idx` is not used in the body (NOTE(review): presumably it receives the item index from `parallel`).
    NOTE(review): `max_size` is annotated as possibly a tuple, but it is compared directly with
    `img.height`/`img.width`. A tuple would raise and, with `delete=True`, remove `file` - confirm.
    """
    try:
        # deal with partially broken images as indicated by PIL warnings
        with warnings.catch_warnings():
            warnings.filterwarnings('error')
            try:
                with open(file, 'rb') as img_file: PIL.Image.open(img_file)
            except Warning as w:
                if "Possibly corrupt EXIF data" in str(w):
                    if delete: # green light to modify files
                        print(f"{file}: Removing corrupt EXIF data")
                        warnings.simplefilter("ignore")
                        # save EXIF-cleaned up image, which happens automatically
                        PIL.Image.open(file).save(file)
                    else: # keep user's files intact
                        print(f"{file}: Not removing corrupt EXIF data, pass `delete=True` to do that")
                # NOTE(review): the 'error' filter installed above is still active here, so this call
                # re-raises `w` as an exception. The outer `except` then prints it and deletes `file`
                # when `delete` is true. Confirm this is the intended handling of other warnings.
                else: warnings.warn(w)

        # NOTE(review): this PIL image handle is never explicitly closed.
        img = PIL.Image.open(file)
        imgarr = np.array(img)
        # A 2-D array has no channel axis (counted as 1 channel); otherwise axis 2 holds the channels.
        img_channels = 1 if len(imgarr.shape) == 2 else imgarr.shape[2]
        if (max_size is not None and (img.height > max_size or img.width > max_size)) or img_channels != n_channels:
            assert isinstance(dest, Path), "You should provide `dest` Path to save resized image"
            dest_fname = dest/file.name
            if ext is not None: dest_fname=dest_fname.with_suffix(ext)
            # With `resume`, an already written output is kept and the work skipped.
            if resume and os.path.isfile(dest_fname): return
            if max_size is not None:
                new_sz = resize_to(img, max_size)
                img = img.resize(new_sz, resample=interp)
            if n_channels == 3: img = img.convert("RGB")
            img.save(dest_fname, img_format, **kwargs)
    except Exception as e:
        print(f'{e}')
        if delete: file.unlink()
def verify_images(path:PathOrStr, delete:bool=True, max_workers:int=4, max_size:Union[int]=None, recurse:bool=False,
                  dest:PathOrStr='.', n_channels:int=3, interp=PIL.Image.BILINEAR, ext:str=None, img_format:str=None,
                  resume:bool=None, **kwargs):
    """Check if the images in `path` aren't broken, maybe resize them and copy it in `dest`.

    - `dest` is interpreted relative to `path` and created if missing.
    - Every file returned by `get_image_files(path, recurse=recurse)` is checked with `verify_image`,
      using `max_workers` parallel workers.
    - All other options are forwarded to `verify_image`.
    """
    root = Path(path)
    # An unset `resume` combined with the default in-place destination means "never skip".
    if resume is None and dest == '.': resume = False
    out_dir = root/Path(dest)
    os.makedirs(out_dir, exist_ok=True)
    image_files = get_image_files(root, recurse=recurse)
    checker = partial(verify_image, delete=delete, max_size=max_size, dest=out_dir, n_channels=n_channels,
                      interp=interp, ext=ext, img_format=img_format, resume=resume, **kwargs)
    parallel(checker, image_files, max_workers=max_workers)
class ImageList(ItemList):
    "`ItemList` suitable for computer vision."
    _bunch,_square_show,_square_show_res = ImageDataBunch,True,True

    def __init__(self, *args, convert_mode='RGB', after_open:Callable=None, **kwargs):
        """Create the list. `convert_mode` and `after_open` are passed to `open_image` for every item."""
        super().__init__(*args, **kwargs)
        self.convert_mode,self.after_open = convert_mode,after_open
        # Register these attributes in `copy_new` (presumably so derived lists inherit them).
        self.copy_new += ['convert_mode', 'after_open']
        # `sizes` caches the `.size` of every item opened through `get`.
        self.c,self.sizes = 3,{}

    def open(self, fn):
        "Open image in `fn`, subclass and overwrite for custom behavior."
        return open_image(fn, convert_mode=self.convert_mode, after_open=self.after_open)

    def get(self, i):
        "Open item `i` and record its size in `self.sizes[i]`."
        fn = super().get(i)
        res = self.open(fn)
        self.sizes[i] = res.size
        return res

    @classmethod
    def from_folder(cls, path:PathOrStr='.', extensions:Collection[str]=None, **kwargs)->ItemList:
        """Get the list of files in `path` that have an image suffix. `recurse` determines if we search subfolders.

        `extensions` defaults to `image_extensions`. Other `kwargs` (including `recurse`) go to the parent `from_folder`.
        """
        extensions = ifnone(extensions, image_extensions)
        return super().from_folder(path=path, extensions=extensions, **kwargs)

    @classmethod
    def from_df(cls, df:DataFrame, path:PathOrStr, cols:IntsOrStrs=0, folder:PathOrStr=None, suffix:str='', **kwargs)->'ItemList':
        """Get the filenames in `cols` of `df` with `folder` in front of them, `suffix` at the end.

        Each item becomes `"{res.path}/{folder}/{name}{suffix}"`, joined with `os.path.sep`;
        the folder part is omitted when `folder` is None.
        """
        suffix = suffix or ''
        res = super().from_df(df, path=path, cols=cols, **kwargs)
        pref = f'{res.path}{os.path.sep}'
        if folder is not None: pref += f'{folder}{os.path.sep}'
        res.items = np.char.add(np.char.add(pref, res.items.astype(str)), suffix)
        return res

    @classmethod
    def from_csv(cls, path:PathOrStr, csv_name:str, header:str='infer', delimiter:str=None, **kwargs)->'ItemList':
        "Get the filenames in `path/csv_name` opened with `header` (and `delimiter`), then build the list with `from_df`."
        path = Path(path)
        df = pd.read_csv(path/csv_name, header=header, delimiter=delimiter)
        return cls.from_df(df, path=path, **kwargs)

    def reconstruct(self, t:Tensor): return Image(t.float().clamp(min=0,max=1))

    def show_xys(self, xs, ys, imgsize:int=4, figsize:Optional[Tuple[int,int]]=None, **kwargs):
        "Show the `xs` (inputs) and `ys` (targets) on a square grid in a figure of `figsize`."
        rows = int(np.ceil(math.sqrt(len(xs))))
        axs = subplots(rows, rows, imgsize=imgsize, figsize=figsize)
        for x,y,ax in zip(xs, ys, axs.flatten()): x.show(ax=ax, y=y, **kwargs)
        for ax in axs.flatten()[len(xs):]: ax.axis('off')
        plt.tight_layout()

    def show_xyzs(self, xs, ys, zs, imgsize:int=4, figsize:Optional[Tuple[int,int]]=None, **kwargs):
        """Show `xs` (inputs), `ys` (targets) and `zs` (predictions) on a figure of `figsize`.

        With `_square_show_res`, the images go on a square grid, each titled with target and prediction.
        Otherwise each row shows the input with its target on the left and its prediction on the right.
        """
        if self._square_show_res:
            title = 'Ground truth\nPredictions'
            rows = int(np.ceil(math.sqrt(len(xs))))
            axs = subplots(rows, rows, imgsize=imgsize, figsize=figsize, title=title, weight='bold', size=12)
            for x,y,z,ax in zip(xs,ys,zs,axs.flatten()): x.show(ax=ax, title=f'{str(y)}\n{str(z)}', **kwargs)
            for ax in axs.flatten()[len(xs):]: ax.axis('off')
        else:
            title = 'Ground truth/Predictions'
            axs = subplots(len(xs), 2, imgsize=imgsize, figsize=figsize, title=title, weight='bold', size=14)
            for i,(x,y,z) in enumerate(zip(xs,ys,zs)):
                x.show(ax=axs[i,0], y=y, **kwargs)
                x.show(ax=axs[i,1], y=z, **kwargs)
class ObjectCategoryProcessor(MultiCategoryProcessor):
    "`PreProcessor` for labelled bounding boxes."
    def __init__(self, ds:ItemList, pad_idx:int=0):
        super().__init__(ds)
        self.pad_idx = pad_idx
        # Register `pad_idx` alongside the parent's entries in `state_attrs`.
        self.state_attrs.append('pad_idx')

    def process(self, ds:ItemList):
        "Copy `pad_idx` onto `ds`, then run the parent processing."
        ds.pad_idx = self.pad_idx
        super().process(ds)

    def process_one(self,item):
        "Keep the boxes `item[0]` and map each class name in `item[1]` to its index (None when unknown)."
        boxes, names = item[0], item[1]
        return [boxes, [self.c2i.get(name, None) for name in names]]

    def generate_classes(self, items):
        "Generate classes from unique `items` and add `background`."
        found = super().generate_classes([entry[1] for entry in items])
        return ['background', *found]
| def _get_size(xs,i): | |
| size = xs.sizes.get(i,None) | |
| if size is None: | |
| # Image hasn't been accessed yet, so we don't know its size | |
| _ = xs[i] | |
| size = xs.sizes[i] | |
| return size | |
class ObjectCategoryList(MultiCategoryList):
    "`ItemList` for labelled bounding boxes."
    _processor = ObjectCategoryProcessor

    def get(self, i):
        "Build the `ImageBBox` of item `i`, using the size of the matching input `self.x[i]`."
        img_size = _get_size(self.x, i)
        return ImageBBox.create(*img_size, *self.items[i], classes=self.classes, pad_idx=self.pad_idx)

    def analyze_pred(self, pred): return pred

    def reconstruct(self, t, x):
        """Rebuild an `ImageBBox` from padded `(bboxes, labels)` tensors `t`, sized like `x`.

        Leading entries up to the first non-padding label are dropped. Returns None when every
        label equals `pad_idx`.
        """
        bboxes, labels = t
        real = (labels - self.pad_idx).nonzero()
        if len(real) == 0: return None
        start = real.min()
        return ImageBBox.create(*x.size, bboxes[start:], labels=labels[start:], classes=self.classes, scale=False)
class ObjectItemList(ImageList):
    "`ItemList` suitable for object detection."
    # Targets are bounding-box lists.
    _label_cls = ObjectCategoryList
    # Results are shown row by row instead of on a square grid.
    _square_show_res = False
class SegmentationProcessor(PreProcessor):
    "`PreProcessor` that stores the classes for segmentation."
    def __init__(self, ds:ItemList):
        self.classes = ds.classes

    def process(self, ds:ItemList):
        "Give `ds` the stored classes and set `ds.c` to their count."
        n_classes = len(self.classes)
        ds.classes = self.classes
        ds.c = n_classes
class SegmentationLabelList(ImageList):
    "`ItemList` for segmentation masks."
    _processor = SegmentationProcessor

    def __init__(self, items:Iterator, classes:Collection=None, **kwargs):
        super().__init__(items, **kwargs)
        self.copy_new.append('classes')
        # Flattened cross-entropy taken over axis 1 (presumably the class-score axis of predictions).
        loss = CrossEntropyFlat(axis=1)
        self.classes = classes
        self.loss_func = loss

    def open(self, fn):
        "Open `fn` with `open_mask`."
        return open_mask(fn)

    def analyze_pred(self, pred, thresh:float=0.5):
        "Take the argmax of `pred` over dim 0 and re-add a leading axis; `thresh` is unused."
        return pred.argmax(dim=0)[None]

    def reconstruct(self, t:Tensor): return ImageSegment(t)
class SegmentationItemList(ImageList):
    "`ItemList` suitable for segmentation tasks."
    # Targets are masks; results are shown row by row instead of on a square grid.
    _label_cls = SegmentationLabelList
    _square_show_res = False
class PointsProcessor(PreProcessor):
    "`PreProcessor` that stores the number of targets for point regression."
    def __init__(self, ds:ItemList):
        # Number of targets = element count of the first item once flattened.
        first_item = ds.items[0]
        self.c = len(first_item.reshape(-1))

    def process(self, ds:ItemList):
        ds.c = self.c
class PointsLabelList(ItemList):
    "`ItemList` for points."
    _processor = PointsProcessor

    def __init__(self, items:Iterator, **kwargs):
        super().__init__(items, **kwargs)
        self.loss_func = MSELossFlat()

    def get(self, i):
        "Wrap the raw points of item `i` in scaled `ImagePoints`, sized like the matching input image."
        raw = super().get(i)
        flow = FlowField(_get_size(self.x, i), raw)
        return ImagePoints(flow, scale=True)

    def analyze_pred(self, pred, thresh:float=0.5):
        "Reshape `pred` into rows of 2 values; `thresh` is unused."
        return pred.view(-1, 2)

    def reconstruct(self, t, x): return ImagePoints(FlowField(x.size, t), scale=False)
class PointsItemList(ImageList):
    "`ItemList` for `Image` to `ImagePoints` tasks."
    # Targets are point sets; results are shown row by row instead of on a square grid.
    _label_cls = PointsLabelList
    _square_show_res = False
class ImageImageList(ImageList):
    "`ItemList` suitable for `Image` to `Image` tasks."
    _label_cls = ImageList
    _square_show = False
    _square_show_res = False

    def show_xys(self, xs, ys, imgsize:int=4, figsize:Optional[Tuple[int,int]]=None, **kwargs):
        "Show the `xs` (inputs) and `ys`(targets) on a figure of `figsize`, one input/target pair per row."
        axs = subplots(len(xs), 2, imgsize=imgsize, figsize=figsize)
        for row, pair in enumerate(zip(xs, ys)):
            inp, targ = pair
            inp.show(ax=axs[row, 0], **kwargs)
            targ.show(ax=axs[row, 1], **kwargs)
        plt.tight_layout()

    def show_xyzs(self, xs, ys, zs, imgsize:int=4, figsize:Optional[Tuple[int,int]]=None, **kwargs):
        """Show `xs` (inputs), `ys` (targets) and `zs` (predictions) on a figure of `figsize`.

        Columns per row: input, prediction, target.
        """
        title = 'Input / Prediction / Target'
        axs = subplots(len(xs), 3, imgsize=imgsize, figsize=figsize, title=title, weight='bold', size=14)
        for row, (inp, targ, pred) in enumerate(zip(xs, ys, zs)):
            inp.show(ax=axs[row, 0], **kwargs)
            targ.show(ax=axs[row, 2], **kwargs)
            pred.show(ax=axs[row, 1], **kwargs)
def _ll_pre_transform(self, train_tfm:List[Callable], valid_tfm:List[Callable]):
    """Call `train_tfm` and `valid_tfm` after opening image, before converting from `PIL.Image`.

    The composed transforms are stored as `after_open` on the training/validation input lists.
    Returns `self`.
    """
    train_hook = compose(train_tfm)
    self.train.x.after_open = train_hook
    valid_hook = compose(valid_tfm)
    self.valid.x.after_open = valid_hook
    return self
def _db_pre_transform(self, train_tfm:List[Callable], valid_tfm:List[Callable]):
    """Call `train_tfm` and `valid_tfm` after opening image, before converting from `PIL.Image`.

    The composed transforms are stored as `after_open` on the inputs of `train_ds`/`valid_ds`.
    Returns `self`.
    """
    train_hook = compose(train_tfm)
    self.train_ds.x.after_open = train_hook
    valid_hook = compose(valid_tfm)
    self.valid_ds.x.after_open = valid_hook
    return self
def _presize(self, size:int, val_xtra_size:int=32, scale:Tuple[float]=(0.08, 1.0), ratio:Tuple[float]=(0.75, 4./3.),
             interpolation:int=2):
    """Resize images to `size` when they are opened, via `self.pre_transform`.

    - Training images get `RandomResizedCrop(size, scale=scale, ratio=ratio, interpolation=interpolation)`.
    - Validation images are resized to `size + val_xtra_size` and then center-cropped to `size`.
    Returns whatever `pre_transform` returns.
    NOTE(review): newer torchvision versions expect an `InterpolationMode` rather than an int for
    `interpolation`; confirm against the installed version.
    """
    train_tfm = tvt.RandomResizedCrop(size, scale=scale, ratio=ratio, interpolation=interpolation)
    valid_tfm = [tvt.Resize(size + val_xtra_size), tvt.CenterCrop(size)]
    return self.pre_transform(train_tfm, valid_tfm)
# Attach the pre-transform helpers as methods: a class-specific `pre_transform` on each class,
# and the shared `presize` on both.
LabelLists.pre_transform, DataBunch.pre_transform = _ll_pre_transform, _db_pre_transform
LabelLists.presize = DataBunch.presize = _presize