Spaces:
Runtime error
Runtime error
Commit
·
2c19c0f
1
Parent(s):
a31136c
Create utils.py
Browse files
utils.py
ADDED
|
@@ -0,0 +1,361 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# MIT License
|
| 2 |
+
|
| 3 |
+
# Copyright (c) 2022 Intelligent Systems Lab Org
|
| 4 |
+
|
| 5 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
# of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
# in the Software without restriction, including without limitation the rights
|
| 8 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
# copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
# furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
# The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
# copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
# SOFTWARE.
|
| 22 |
+
|
| 23 |
+
# File author: Shariq Farooq Bhat
|
| 24 |
+
|
| 25 |
+
"""Miscellaneous utility functions."""
|
| 26 |
+
|
| 27 |
+
from scipy import ndimage
|
| 28 |
+
|
| 29 |
+
import base64
|
| 30 |
+
import math
|
| 31 |
+
import re
|
| 32 |
+
from io import BytesIO
|
| 33 |
+
|
| 34 |
+
import matplotlib
|
| 35 |
+
import matplotlib.cm
|
| 36 |
+
import numpy as np
|
| 37 |
+
import requests
|
| 38 |
+
import torch
|
| 39 |
+
import torch.distributed as dist
|
| 40 |
+
import torch.nn
|
| 41 |
+
import torch.nn as nn
|
| 42 |
+
import torch.utils.data.distributed
|
| 43 |
+
from PIL import Image
|
| 44 |
+
from torchvision.transforms import ToTensor
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class RunningAverage:
|
| 48 |
+
def __init__(self):
|
| 49 |
+
self.avg = 0
|
| 50 |
+
self.count = 0
|
| 51 |
+
|
| 52 |
+
def append(self, value):
|
| 53 |
+
self.avg = (value + self.count * self.avg) / (self.count + 1)
|
| 54 |
+
self.count += 1
|
| 55 |
+
|
| 56 |
+
def get_value(self):
|
| 57 |
+
return self.avg
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def denormalize(x):
|
| 61 |
+
"""Reverses the imagenet normalization applied to the input.
|
| 62 |
+
Args:
|
| 63 |
+
x (torch.Tensor - shape(N,3,H,W)): input tensor
|
| 64 |
+
Returns:
|
| 65 |
+
torch.Tensor - shape(N,3,H,W): Denormalized input
|
| 66 |
+
"""
|
| 67 |
+
mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(x.device)
|
| 68 |
+
std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(x.device)
|
| 69 |
+
return x * std + mean
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class RunningAverageDict:
|
| 73 |
+
"""A dictionary of running averages."""
|
| 74 |
+
def __init__(self):
|
| 75 |
+
self._dict = None
|
| 76 |
+
|
| 77 |
+
def update(self, new_dict):
|
| 78 |
+
if new_dict is None:
|
| 79 |
+
return
|
| 80 |
+
|
| 81 |
+
if self._dict is None:
|
| 82 |
+
self._dict = dict()
|
| 83 |
+
for key, value in new_dict.items():
|
| 84 |
+
self._dict[key] = RunningAverage()
|
| 85 |
+
|
| 86 |
+
for key, value in new_dict.items():
|
| 87 |
+
self._dict[key].append(value)
|
| 88 |
+
|
| 89 |
+
def get_value(self):
|
| 90 |
+
if self._dict is None:
|
| 91 |
+
return None
|
| 92 |
+
return {key: value.get_value() for key, value in self._dict.items()}
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def colorize(value, vmin=None, vmax=None, cmap='gray_r', invalid_val=-99, invalid_mask=None, background_color=(128, 128, 128, 255), gamma_corrected=False, value_transform=None):
|
| 96 |
+
"""Converts a depth map to a color image.
|
| 97 |
+
Args:
|
| 98 |
+
value (torch.Tensor, numpy.ndarry): Input depth map. Shape: (H, W) or (1, H, W) or (1, 1, H, W). All singular dimensions are squeezed
|
| 99 |
+
vmin (float, optional): vmin-valued entries are mapped to start color of cmap. If None, value.min() is used. Defaults to None.
|
| 100 |
+
vmax (float, optional): vmax-valued entries are mapped to end color of cmap. If None, value.max() is used. Defaults to None.
|
| 101 |
+
cmap (str, optional): matplotlib colormap to use. Defaults to 'magma_r'.
|
| 102 |
+
invalid_val (int, optional): Specifies value of invalid pixels that should be colored as 'background_color'. Defaults to -99.
|
| 103 |
+
invalid_mask (numpy.ndarray, optional): Boolean mask for invalid regions. Defaults to None.
|
| 104 |
+
background_color (tuple[int], optional): 4-tuple RGB color to give to invalid pixels. Defaults to (128, 128, 128, 255).
|
| 105 |
+
gamma_corrected (bool, optional): Apply gamma correction to colored image. Defaults to False.
|
| 106 |
+
value_transform (Callable, optional): Apply transform function to valid pixels before coloring. Defaults to None.
|
| 107 |
+
Returns:
|
| 108 |
+
numpy.ndarray, dtype - uint8: Colored depth map. Shape: (H, W, 4)
|
| 109 |
+
"""
|
| 110 |
+
if isinstance(value, torch.Tensor):
|
| 111 |
+
value = value.detach().cpu().numpy()
|
| 112 |
+
|
| 113 |
+
value = value.squeeze()
|
| 114 |
+
if invalid_mask is None:
|
| 115 |
+
invalid_mask = value == invalid_val
|
| 116 |
+
mask = np.logical_not(invalid_mask)
|
| 117 |
+
|
| 118 |
+
# normalize
|
| 119 |
+
vmin = np.percentile(value[mask],2) if vmin is None else vmin
|
| 120 |
+
vmax = np.percentile(value[mask],85) if vmax is None else vmax
|
| 121 |
+
if vmin != vmax:
|
| 122 |
+
value = (value - vmin) / (vmax - vmin) # vmin..vmax
|
| 123 |
+
else:
|
| 124 |
+
# Avoid 0-division
|
| 125 |
+
value = value * 0.
|
| 126 |
+
|
| 127 |
+
# squeeze last dim if it exists
|
| 128 |
+
# grey out the invalid values
|
| 129 |
+
|
| 130 |
+
value[invalid_mask] = np.nan
|
| 131 |
+
cmapper = matplotlib.cm.get_cmap(cmap)
|
| 132 |
+
if value_transform:
|
| 133 |
+
value = value_transform(value)
|
| 134 |
+
# value = value / value.max()
|
| 135 |
+
value = cmapper(value, bytes=True) # (nxmx4)
|
| 136 |
+
|
| 137 |
+
# img = value[:, :, :]
|
| 138 |
+
img = value[...]
|
| 139 |
+
img[invalid_mask] = background_color
|
| 140 |
+
|
| 141 |
+
# return img.transpose((2, 0, 1))
|
| 142 |
+
if gamma_corrected:
|
| 143 |
+
# gamma correction
|
| 144 |
+
img = img / 255
|
| 145 |
+
img = np.power(img, 2.2)
|
| 146 |
+
img = img * 255
|
| 147 |
+
img = img.astype(np.uint8)
|
| 148 |
+
return img
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def count_parameters(model, include_all=False):
|
| 152 |
+
return sum(p.numel() for p in model.parameters() if p.requires_grad or include_all)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def compute_errors(gt, pred):
|
| 156 |
+
"""Compute metrics for 'pred' compared to 'gt'
|
| 157 |
+
Args:
|
| 158 |
+
gt (numpy.ndarray): Ground truth values
|
| 159 |
+
pred (numpy.ndarray): Predicted values
|
| 160 |
+
gt.shape should be equal to pred.shape
|
| 161 |
+
Returns:
|
| 162 |
+
dict: Dictionary containing the following metrics:
|
| 163 |
+
'a1': Delta1 accuracy: Fraction of pixels that are within a scale factor of 1.25
|
| 164 |
+
'a2': Delta2 accuracy: Fraction of pixels that are within a scale factor of 1.25^2
|
| 165 |
+
'a3': Delta3 accuracy: Fraction of pixels that are within a scale factor of 1.25^3
|
| 166 |
+
'abs_rel': Absolute relative error
|
| 167 |
+
'rmse': Root mean squared error
|
| 168 |
+
'log_10': Absolute log10 error
|
| 169 |
+
'sq_rel': Squared relative error
|
| 170 |
+
'rmse_log': Root mean squared error on the log scale
|
| 171 |
+
'silog': Scale invariant log error
|
| 172 |
+
"""
|
| 173 |
+
thresh = np.maximum((gt / pred), (pred / gt))
|
| 174 |
+
a1 = (thresh < 1.25).mean()
|
| 175 |
+
a2 = (thresh < 1.25 ** 2).mean()
|
| 176 |
+
a3 = (thresh < 1.25 ** 3).mean()
|
| 177 |
+
|
| 178 |
+
abs_rel = np.mean(np.abs(gt - pred) / gt)
|
| 179 |
+
sq_rel = np.mean(((gt - pred) ** 2) / gt)
|
| 180 |
+
|
| 181 |
+
rmse = (gt - pred) ** 2
|
| 182 |
+
rmse = np.sqrt(rmse.mean())
|
| 183 |
+
|
| 184 |
+
rmse_log = (np.log(gt) - np.log(pred)) ** 2
|
| 185 |
+
rmse_log = np.sqrt(rmse_log.mean())
|
| 186 |
+
|
| 187 |
+
err = np.log(pred) - np.log(gt)
|
| 188 |
+
silog = np.sqrt(np.mean(err ** 2) - np.mean(err) ** 2) * 100
|
| 189 |
+
|
| 190 |
+
log_10 = (np.abs(np.log10(gt) - np.log10(pred))).mean()
|
| 191 |
+
return dict(a1=a1, a2=a2, a3=a3, abs_rel=abs_rel, rmse=rmse, log_10=log_10, rmse_log=rmse_log,
|
| 192 |
+
silog=silog, sq_rel=sq_rel)
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def compute_metrics(gt, pred, interpolate=True, garg_crop=False, eigen_crop=True, dataset='nyu', min_depth_eval=0.1, max_depth_eval=10, **kwargs):
|
| 196 |
+
"""Compute metrics of predicted depth maps. Applies cropping and masking as necessary or specified via arguments. Refer to compute_errors for more details on metrics.
|
| 197 |
+
"""
|
| 198 |
+
if 'config' in kwargs:
|
| 199 |
+
config = kwargs['config']
|
| 200 |
+
garg_crop = config.garg_crop
|
| 201 |
+
eigen_crop = config.eigen_crop
|
| 202 |
+
min_depth_eval = config.min_depth_eval
|
| 203 |
+
max_depth_eval = config.max_depth_eval
|
| 204 |
+
|
| 205 |
+
if gt.shape[-2:] != pred.shape[-2:] and interpolate:
|
| 206 |
+
pred = nn.functional.interpolate(
|
| 207 |
+
pred, gt.shape[-2:], mode='bilinear', align_corners=True)
|
| 208 |
+
|
| 209 |
+
pred = pred.squeeze().cpu().numpy()
|
| 210 |
+
pred[pred < min_depth_eval] = min_depth_eval
|
| 211 |
+
pred[pred > max_depth_eval] = max_depth_eval
|
| 212 |
+
pred[np.isinf(pred)] = max_depth_eval
|
| 213 |
+
pred[np.isnan(pred)] = min_depth_eval
|
| 214 |
+
|
| 215 |
+
gt_depth = gt.squeeze().cpu().numpy()
|
| 216 |
+
valid_mask = np.logical_and(
|
| 217 |
+
gt_depth > min_depth_eval, gt_depth < max_depth_eval)
|
| 218 |
+
|
| 219 |
+
if garg_crop or eigen_crop:
|
| 220 |
+
gt_height, gt_width = gt_depth.shape
|
| 221 |
+
eval_mask = np.zeros(valid_mask.shape)
|
| 222 |
+
|
| 223 |
+
if garg_crop:
|
| 224 |
+
eval_mask[int(0.40810811 * gt_height):int(0.99189189 * gt_height),
|
| 225 |
+
int(0.03594771 * gt_width):int(0.96405229 * gt_width)] = 1
|
| 226 |
+
|
| 227 |
+
elif eigen_crop:
|
| 228 |
+
# print("-"*10, " EIGEN CROP ", "-"*10)
|
| 229 |
+
if dataset == 'kitti':
|
| 230 |
+
eval_mask[int(0.3324324 * gt_height):int(0.91351351 * gt_height),
|
| 231 |
+
int(0.0359477 * gt_width):int(0.96405229 * gt_width)] = 1
|
| 232 |
+
else:
|
| 233 |
+
# assert gt_depth.shape == (480, 640), "Error: Eigen crop is currently only valid for (480, 640) images"
|
| 234 |
+
eval_mask[45:471, 41:601] = 1
|
| 235 |
+
else:
|
| 236 |
+
eval_mask = np.ones(valid_mask.shape)
|
| 237 |
+
valid_mask = np.logical_and(valid_mask, eval_mask)
|
| 238 |
+
return compute_errors(gt_depth[valid_mask], pred[valid_mask])
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
#################################### Model uilts ################################################
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def parallelize(config, model, find_unused_parameters=True):
|
| 245 |
+
|
| 246 |
+
if config.gpu is not None:
|
| 247 |
+
torch.cuda.set_device(config.gpu)
|
| 248 |
+
model = model.cuda(config.gpu)
|
| 249 |
+
|
| 250 |
+
config.multigpu = False
|
| 251 |
+
if config.distributed:
|
| 252 |
+
# Use DDP
|
| 253 |
+
config.multigpu = True
|
| 254 |
+
config.rank = config.rank * config.ngpus_per_node + config.gpu
|
| 255 |
+
dist.init_process_group(backend=config.dist_backend, init_method=config.dist_url,
|
| 256 |
+
world_size=config.world_size, rank=config.rank)
|
| 257 |
+
config.batch_size = int(config.batch_size / config.ngpus_per_node)
|
| 258 |
+
# config.batch_size = 8
|
| 259 |
+
config.workers = int(
|
| 260 |
+
(config.num_workers + config.ngpus_per_node - 1) / config.ngpus_per_node)
|
| 261 |
+
print("Device", config.gpu, "Rank", config.rank, "batch size",
|
| 262 |
+
config.batch_size, "Workers", config.workers)
|
| 263 |
+
torch.cuda.set_device(config.gpu)
|
| 264 |
+
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
|
| 265 |
+
model = model.cuda(config.gpu)
|
| 266 |
+
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.gpu], output_device=config.gpu,
|
| 267 |
+
find_unused_parameters=find_unused_parameters)
|
| 268 |
+
|
| 269 |
+
elif config.gpu is None:
|
| 270 |
+
# Use DP
|
| 271 |
+
config.multigpu = True
|
| 272 |
+
model = model.cuda()
|
| 273 |
+
model = torch.nn.DataParallel(model)
|
| 274 |
+
|
| 275 |
+
return model
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
#################################################################################################
|
| 279 |
+
|
| 280 |
+
|
| 281 |
+
#####################################################################################################
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
class colors:
|
| 285 |
+
'''Colors class:
|
| 286 |
+
Reset all colors with colors.reset
|
| 287 |
+
Two subclasses fg for foreground and bg for background.
|
| 288 |
+
Use as colors.subclass.colorname.
|
| 289 |
+
i.e. colors.fg.red or colors.bg.green
|
| 290 |
+
Also, the generic bold, disable, underline, reverse, strikethrough,
|
| 291 |
+
and invisible work with the main class
|
| 292 |
+
i.e. colors.bold
|
| 293 |
+
'''
|
| 294 |
+
reset = '\033[0m'
|
| 295 |
+
bold = '\033[01m'
|
| 296 |
+
disable = '\033[02m'
|
| 297 |
+
underline = '\033[04m'
|
| 298 |
+
reverse = '\033[07m'
|
| 299 |
+
strikethrough = '\033[09m'
|
| 300 |
+
invisible = '\033[08m'
|
| 301 |
+
|
| 302 |
+
class fg:
|
| 303 |
+
black = '\033[30m'
|
| 304 |
+
red = '\033[31m'
|
| 305 |
+
green = '\033[32m'
|
| 306 |
+
orange = '\033[33m'
|
| 307 |
+
blue = '\033[34m'
|
| 308 |
+
purple = '\033[35m'
|
| 309 |
+
cyan = '\033[36m'
|
| 310 |
+
lightgrey = '\033[37m'
|
| 311 |
+
darkgrey = '\033[90m'
|
| 312 |
+
lightred = '\033[91m'
|
| 313 |
+
lightgreen = '\033[92m'
|
| 314 |
+
yellow = '\033[93m'
|
| 315 |
+
lightblue = '\033[94m'
|
| 316 |
+
pink = '\033[95m'
|
| 317 |
+
lightcyan = '\033[96m'
|
| 318 |
+
|
| 319 |
+
class bg:
|
| 320 |
+
black = '\033[40m'
|
| 321 |
+
red = '\033[41m'
|
| 322 |
+
green = '\033[42m'
|
| 323 |
+
orange = '\033[43m'
|
| 324 |
+
blue = '\033[44m'
|
| 325 |
+
purple = '\033[45m'
|
| 326 |
+
cyan = '\033[46m'
|
| 327 |
+
lightgrey = '\033[47m'
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
def printc(text, color):
|
| 331 |
+
print(f"{color}{text}{colors.reset}")
|
| 332 |
+
|
| 333 |
+
############################################
|
| 334 |
+
|
| 335 |
+
def get_image_from_url(url):
|
| 336 |
+
response = requests.get(url)
|
| 337 |
+
img = Image.open(BytesIO(response.content)).convert("RGB")
|
| 338 |
+
return img
|
| 339 |
+
|
| 340 |
+
def url_to_torch(url, size=(384, 384)):
|
| 341 |
+
img = get_image_from_url(url)
|
| 342 |
+
img = img.resize(size, Image.ANTIALIAS)
|
| 343 |
+
img = torch.from_numpy(np.asarray(img)).float()
|
| 344 |
+
img = img.permute(2, 0, 1)
|
| 345 |
+
img.div_(255)
|
| 346 |
+
return img
|
| 347 |
+
|
| 348 |
+
def pil_to_batched_tensor(img):
|
| 349 |
+
return ToTensor()(img).unsqueeze(0)
|
| 350 |
+
|
| 351 |
+
def save_raw_16bit(depth, fpath="raw.png"):
|
| 352 |
+
if isinstance(depth, torch.Tensor):
|
| 353 |
+
depth = depth.squeeze().cpu().numpy()
|
| 354 |
+
|
| 355 |
+
assert isinstance(depth, np.ndarray), "Depth must be a torch tensor or numpy array"
|
| 356 |
+
assert depth.ndim == 2, "Depth must be 2D"
|
| 357 |
+
depth = depth * 256 # scale for 16-bit png
|
| 358 |
+
depth = depth.astype(np.uint16)
|
| 359 |
+
depth = Image.fromarray(depth)
|
| 360 |
+
depth.save(fpath)
|
| 361 |
+
print("Saved raw depth to", fpath)
|