# AuralSAM2 / avs.code/v1m.code/utils/tensorboard.py
# Uploaded by yyliu01 via huggingface_hub (commit c6dfc69, verified).
import os
import PIL
import matplotlib.pyplot as plt
import numpy
import torch
import torchvision
import wandb
# from utils.visualize import show_img
# Label name -> (R, G, B) colour; insertion order is significant for anyone
# flattening it into a palette.
color_map = dict(
    background=(0, 0, 0),
    longitudinal=(128, 0, 0),
    pothole=(0, 128, 0),
    alligator=(128, 128, 0),
    transverse=(128, 0, 128),
    ignore=(255, 255, 255),
)
class Tensorboard:
    """Experiment-logging wrapper around a Weights & Biases run.

    Despite the name, every call goes to the ``wandb`` run created in
    ``__init__``; nothing here writes TensorBoard event files.
    """

    def __init__(self, config):
        """Create the wandb run (online, or disabled) described by ``config``.

        Args:
            config: mapping read for ``proj_name``, ``experiment_name``,
                ``image_mean`` and ``image_std`` (required) plus the optional
                ``wandb_online``, ``wandb_key`` and ``palette`` keys. The whole
                mapping is also passed to ``wandb.init(config=...)``.
        """
        if config.get('wandb_online', False):
            # Explicit key wins; otherwise fall back to the environment.
            key = config.get('wandb_key') or os.environ.get('WANDB_API_KEY', '')
            if key:
                os.environ['WANDB_API_KEY'] = key
                wandb.login(key=key, relogin=False)
            self.tensor_board = wandb.init(project=config['proj_name'], name=config['experiment_name'],
                                           config=config, settings=wandb.Settings(code_dir=""))
        else:
            os.environ.setdefault("WANDB_MODE", "disabled")
            self.tensor_board = wandb.init(project=config['proj_name'], name=config['experiment_name'],
                                           config=config, mode="disabled",
                                           settings=wandb.Settings(code_dir=""))
        # Image panels are only worth building when the run actually uploads.
        self._log_images = bool(config.get('wandb_online', False))
        self.restore_transform = torchvision.transforms.Compose([
            DeNormalize(config['image_mean'], config['image_std']),
            torchvision.transforms.ToPILImage()])
        # Flat [r0, g0, b0, r1, ...] palette used by de_normalize() to colourise
        # label masks. Override with config['palette'].
        # NOTE(review): the fallback flattens the module-level color_map (label
        # index i -> i-th entry); confirm it matches this project's label ids.
        palette = config.get('palette')
        if palette is None:
            palette = [channel for rgb in color_map.values() for channel in rgb]
        self.palette = list(palette)

    def upload_wandb_info(self, info_dict):
        """Log all entries of ``info_dict`` together as a single wandb step.

        One ``log`` call keeps every metric of this call on the same step;
        calling ``log`` once per key would advance the step counter per metric.
        An empty mapping logs nothing.
        """
        if info_dict:
            self.tensor_board.log(dict(info_dict))

    @staticmethod
    def _batched_rgb(t):
        """Return ``t`` as a detached CPU float tensor shaped [N, C, H, W].

        Accepts [C, H, W] (treated as a batch of one) or [N, C, H, W];
        any other rank raises ValueError.
        """
        if not isinstance(t, torch.Tensor):
            t = torch.as_tensor(t)
        t = t.detach().cpu().float()
        if t.dim() == 3:
            return t.unsqueeze(0)
        if t.dim() == 4:
            return t
        raise ValueError("frames must be [C,H,W] or [N,C,H,W], got shape {}".format(tuple(t.shape)))

    @staticmethod
    def _batched_mask(t):
        """Return ``t`` as a detached CPU float tensor shaped [N, H, W].

        Accepts [H, W], [N, H, W] or [N, 1, ..., 1, H, W]. Singleton axes at
        position 1 are squeezed away; a non-singleton extra axis raises
        ValueError (squeezing it would be a no-op and never terminate).
        """
        if not isinstance(t, torch.Tensor):
            t = torch.as_tensor(t)
        t = t.detach().cpu().float()
        while t.dim() > 3:
            if t.shape[1] != 1:
                raise ValueError("masks must be [H,W], [N,H,W] or [N,1,H,W], got shape {}".format(tuple(t.shape)))
            t = t.squeeze(1)
        if t.dim() == 2:
            t = t.unsqueeze(0)
        if t.dim() != 3:
            raise ValueError("masks must be [H,W], [N,H,W] or [N,1,H,W], got shape {}".format(tuple(t.shape)))
        return t

    def upload_wandb_image(self, frames, pseudo_label_from_pred, pseudo_label_from_sam, img_number=4):
        """Log up to ``img_number`` frames together with two sets of masks.

        Does nothing unless the run was created with ``wandb_online``. Logs
        three panels in one step: ``image`` (de-normalised frames, clamped to
        [0, 1]), ``label`` (from ``pseudo_label_from_sam``) and ``logits``
        (from ``pseudo_label_from_pred``). Mask pixels equal to 255 are
        replaced by 0.5 before logging. The caller's tensors are not modified.

        Raises:
            ValueError: if an input has an unsupported shape.
        """
        if not self._log_images:
            return
        frames = self._batched_rgb(frames)
        pred_masks = self._batched_mask(pseudo_label_from_pred)
        sam_masks = self._batched_mask(pseudo_label_from_sam)
        n = min(frames.shape[0], pred_masks.shape[0], sam_masks.shape[0], img_number)
        if n <= 0:
            return
        # clone(): detach().cpu().float() can return the caller's own storage for
        # CPU float inputs, and the replacements below are in place.
        sam_masks = sam_masks[:n].clone()
        pred_masks = pred_masks[:n].clone()
        sam_masks[sam_masks == 255.] = 0.5
        pred_masks[pred_masks == 255.] = 0.5

        denorm = self.restore_transform.transforms[0]
        image_list, label_list, logits_list = [], [], []
        for i in range(n):
            caption = "id {}".format(str(i))
            frame = frames[i].clone()  # denorm works in place
            if frame.shape[0] == 3:
                denorm(frame)
            frame.clamp_(0.0, 1.0)
            image_list.append(wandb.Image(frame, caption=caption))
            # wandb.Image expects torch tensors as [C, H, W] (it permutes CHW→HWC),
            # so each [H, W] mask gets a leading channel axis.
            label_list.append(wandb.Image(sam_masks[i].unsqueeze(0), caption=caption))
            logits_list.append(wandb.Image(pred_masks[i].unsqueeze(0), caption=caption))
        self.tensor_board.log({"image": image_list, "label": label_list, "logits": logits_list})

    def de_normalize(self, image):
        """Convert each element of ``image`` into a PIL image.

        3-D tensors are de-normalised and converted with ``restore_transform``
        (DeNormalize + ToPILImage) on a private copy, so the caller's tensor
        is left untouched. Every other element (tensor of another rank or
        array-like) is treated as a label mask and colourised with
        ``self.palette``.

        Returns:
            list of PIL images, one per input element.
        """
        converted = []
        for item in image:
            if isinstance(item, torch.Tensor) and item.dim() == 3:
                # .cpu() is a no-op for CPU tensors and DeNormalize is in place.
                converted.append(self.restore_transform(item.detach().cpu().clone()))
                continue
            if isinstance(item, torch.Tensor):
                item = item.detach().cpu().numpy()
            converted.append(colorize_mask(numpy.asarray(item), self.palette))
        return converted

    def finish(self):
        """Close the underlying wandb run."""
        self.tensor_board.finish()
class DeNormalize:
    """Inverse of a per-channel ``(x - mean) / std`` normalisation, in place.

    The tensor is walked along its first axis; slice k is paired with
    ``mean[k]`` / ``std[k]``. Iteration stops at the shortest of the three.
    """

    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def __call__(self, tensor):
        """Rescale ``tensor`` in place as ``x * std + mean``; return the same object."""
        for channel, (mu, sigma) in zip(tensor, zip(self.mean, self.std)):
            channel.mul_(sigma).add_(mu)
        return tensor
def pad_palette(palette, size=256 * 3):
    """Return a new flat palette list zero-padded to at least ``size`` entries.

    The input is copied, never modified; palettes already ``size`` long or
    longer are returned unchanged (as a copy).
    """
    padded = list(palette)
    padded.extend([0] * (size - len(padded)))  # negative count -> no padding
    return padded


def colorize_mask(mask, palette):
    """Render a label mask as a palettised ('P' mode) PIL image.

    Args:
        mask: array-like of label ids (presumably [H, W]); cast to uint8, so
            values outside [0, 255] do not survive the cast.
        palette: flat [r0, g0, b0, r1, g1, b1, ...] sequence. A zero-padded
            copy (256 colours) is applied; the caller's list is not modified.

    Returns:
        The palettised PIL image.
    """
    # A bare ``import PIL`` does not load the PIL.Image submodule.
    import PIL.Image
    new_mask = PIL.Image.fromarray(numpy.asarray(mask).astype(numpy.uint8)).convert('P')
    new_mask.putpalette(pad_palette(palette))
    return new_mask