import glob
import os

import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from tqdm import tqdm

import model_io
import utils
from adabins import UnetAdaptiveBins
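# model_io, utils and adabins are local modules from the AdaBins repo
# (checkpoint loading, colorization helpers, and the model definition).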


def _is_pil_image(img):
    return isinstance(img, Image.Image)


def _is_numpy_image(img):
    return isinstance(img, np.ndarray) and (img.ndim in {2, 3})


class ToTensor(object):
    def __init__(self):
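        # Standard ImageNet normalization statistics.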
        self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    def __call__(self, image, target_size=(640, 480)):
        # image = image.resize(target_size)
        image = self.to_tensor(image)
        image = self.normalize(image)
        return image

    def to_tensor(self, pic):
        if not (_is_pil_image(pic) or _is_numpy_image(pic)):
            raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic)))

        if isinstance(pic, np.ndarray):
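            # ndarray input is assumed to be HWC; convert to CHW without rescaling.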
            img = torch.from_numpy(pic.transpose((2, 0, 1)))
            return img

        # handle PIL Image
        if pic.mode == 'I':
            img = torch.from_numpy(np.array(pic, np.int32, copy=False))
        elif pic.mode == 'I;16':
            img = torch.from_numpy(np.array(pic, np.int16, copy=False))
        else:
            img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))

        # PIL image mode: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
        if pic.mode == 'YCbCr':
            nchannel = 3
        elif pic.mode == 'I;16':
            nchannel = 1
        else:
            nchannel = len(pic.mode)
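
        # Reshape the flat byte buffer to HWC, then permute to CHW.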
        img = img.view(pic.size[1], pic.size[0], nchannel)
        img = img.transpose(0, 1).transpose(0, 2).contiguous()
        if isinstance(img, torch.ByteTensor):
            return img.float()
        else:
            return img


class InferenceHelper:
    def __init__(self, models_path, dataset='nyu', device='cuda:0'):
        self.toTensor = ToTensor()
        self.device = device
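        # Per-dataset depth range and the scale factor used for 16-bit PNG output.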
        if dataset == 'nyu':
            self.min_depth = 1e-3
            self.max_depth = 10
            self.saving_factor = 1000  # used to save in 16 bit
            model = UnetAdaptiveBins.build(n_bins=256, min_val=self.min_depth, max_val=self.max_depth)
            pretrained_path = os.path.join(models_path, "AdaBins_nyu.pt")
        elif dataset == 'kitti':
            self.min_depth = 1e-3
            self.max_depth = 80
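            # KITTI convention: 16-bit depth PNGs store metric depth * 256.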
            self.saving_factor = 256
            model = UnetAdaptiveBins.build(n_bins=256, min_val=self.min_depth, max_val=self.max_depth)
            pretrained_path = os.path.join(models_path, "AdaBins_kitti.pt")
        else:
            raise ValueError("dataset can be either 'nyu' or 'kitti' but got {}".format(dataset))

        model, _, _ = model_io.load_checkpoint(pretrained_path, model)
        model.eval()
        self.model = model.to(self.device)

    @torch.no_grad()
    def predict_pil(self, pil_image, visualized=False):
        # pil_image = pil_image.resize((640, 480))
        img = np.asarray(pil_image) / 255.

        img = self.toTensor(img).unsqueeze(0).float().to(self.device)
        bin_centers, pred = self.predict(img)

        if visualized:
            viz = utils.colorize(torch.from_numpy(pred).unsqueeze(0), vmin=None, vmax=None, cmap='magma')
            # pred = np.asarray(pred*1000, dtype='uint16')
            viz = Image.fromarray(viz)
            return bin_centers, pred, viz
        return bin_centers, pred

    @torch.no_grad()
    def predict(self, image):
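        # The model returns the adaptive bin edges and a per-pixel depth map.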
        bins, pred = self.model(image)
        pred = np.clip(pred.cpu().numpy(), self.min_depth, self.max_depth)

        # Test-time augmentation: also predict on a horizontally flipped copy
        image = torch.Tensor(np.array(image.cpu().numpy())[..., ::-1].copy()).to(self.device)
        pred_lr = self.model(image)[-1]
        pred_lr = np.clip(pred_lr.cpu().numpy()[..., ::-1], self.min_depth, self.max_depth)

        # Take the average of the original and mirrored predictions
        final = 0.5 * (pred + pred_lr)
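        # Upsample the averaged prediction back to the network input resolution.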
        final = nn.functional.interpolate(torch.Tensor(final), image.shape[-2:],
                                          mode='bilinear', align_corners=True).cpu().numpy()
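
        # Clamp to the valid depth range and replace non-finite values.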
        final[final < self.min_depth] = self.min_depth
        final[final > self.max_depth] = self.max_depth
        final[np.isinf(final)] = self.max_depth
        final[np.isnan(final)] = self.min_depth
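
        # Convert bin edges to bin centers and keep those inside the valid range.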
        centers = 0.5 * (bins[:, 1:] + bins[:, :-1])
        centers = centers.cpu().squeeze().numpy()
        centers = centers[centers > self.min_depth]
        centers = centers[centers < self.max_depth]

        return centers, final

    @torch.no_grad()
    def predict_dir(self, test_dir, out_dir):
        os.makedirs(out_dir, exist_ok=True)
        transform = ToTensor()
        all_files = glob.glob(os.path.join(test_dir, "*"))
        self.model.eval()
        for f in tqdm(all_files):
            image = np.asarray(Image.open(f), dtype='float32') / 255.
            image = transform(image).unsqueeze(0).to(self.device)
            centers, final = self.predict(image)
            # final = final.squeeze().cpu().numpy()
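            # Scale to the dataset's 16-bit integer encoding before saving.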
            final = (final * self.saving_factor).astype('uint16')
            basename = os.path.basename(f).split('.')[0]
            save_path = os.path.join(out_dir, basename + ".png")
            Image.fromarray(final.squeeze()).save(save_path)

    def to(self, device):
        self.device = device
        self.model.to(device)


if __name__ == '__main__':
    import matplotlib.pyplot as plt
    from time import time

    img = Image.open("test_imgs/classroom__rgb_00283.jpg")
    start = time()
    # models_path is required; "pretrained" is a placeholder directory that
    # should contain the AdaBins_nyu.pt checkpoint, so adjust it to your setup.
    inferHelper = InferenceHelper(models_path="pretrained")
    centers, pred = inferHelper.predict_pil(img)
    print(f"took: {time() - start}s")
    plt.imshow(pred.squeeze(), cmap='magma_r')
    plt.show()