# Source: AuralSAM2 / ref-avs.code / utils / tensorboard.py
# Uploaded by yyliu01 via huggingface_hub (commit c6dfc69).
"""Optional Weights & Biases logging for Ref-AVS training."""
import os
import torchvision
import wandb
class Tensorboard:
    """Thin wrapper around a Weights & Biases run used for training logs.

    Despite the class name, nothing here talks to TensorBoard: every
    method forwards to the ``wandb`` run created in ``__init__``.
    """

    def __init__(self, config):
        """Start a wandb run configured from ``config``.

        Args:
            config: dict-like training config. Reads ``proj_name``,
                ``experiment_name``, ``image_mean`` and ``image_std``
                (indexed directly, so required) plus the optional
                ``wandb_key`` and ``wandb_online``. The whole mapping is
                also handed to wandb as the run config.
        """
        # An explicit key in the config takes precedence over the
        # environment; whichever is found is exported as WANDB_API_KEY.
        key = config.get('wandb_key') or os.environ.get('WANDB_API_KEY', '')
        if key:
            os.environ['WANDB_API_KEY'] = key
        # Online only when explicitly requested; otherwise the run is
        # created in wandb's 'disabled' mode.
        mode = 'online' if config.get('wandb_online', False) else 'disabled'
        self.tensor_board = wandb.init(
            project=config['proj_name'],
            name=config['experiment_name'],
            config=config,
            mode=mode,
            settings=wandb.Settings(code_dir='.'),
        )
        # Undo the input normalization and convert to a PIL image. Not
        # used by the methods of this class; kept for external callers.
        self.restore_transform = torchvision.transforms.Compose([
            DeNormalize(config['image_mean'], config['image_std']),
            torchvision.transforms.ToPILImage(),
        ])

    def upload_wandb_info(self, info_dict):
        """Log every ``(key, value)`` pair of ``info_dict`` as one wandb step.

        All entries go into a single ``log`` call so they share a step.
        Logging them one key at a time advanced the step once per entry,
        which split metrics from the same call across different steps.
        An empty mapping logs nothing.
        """
        # Only ``.items()`` is required of ``info_dict``, as before.
        payload = dict(info_dict.items())
        if payload:
            self.tensor_board.log(payload)

    def upload_wandb_image(self, frames, pseudo_label_from_pred, pseudo_label_from_sam, img_number=4):
        """Log up to ``img_number`` frames together with two label maps.

        Args:
            frames: sliceable batch of images; the first ``n`` are logged
                under ``'image'``.
            pseudo_label_from_pred: tensor batch logged under ``'logits'``;
                its first dimension bounds ``n``.
            pseudo_label_from_sam: tensor batch logged under ``'label'``.
            img_number: maximum number of samples to log.

        ``n = min(pseudo_label_from_pred.shape[0], img_number)``. The
        caller's label tensors are left unmodified.
        """
        n = min(pseudo_label_from_pred.shape[0], img_number)
        frames = frames[:n]
        label = self._prepare_label(pseudo_label_from_sam, n)
        logits = self._prepare_label(pseudo_label_from_pred, n)
        self.tensor_board.log({
            'image': [wandb.Image(j, caption=f'id {i}') for i, j in enumerate(frames)],
            'label': [wandb.Image(j.squeeze(), caption=f'id {i}') for i, j in enumerate(label)],
            'logits': [wandb.Image(j.squeeze(), caption=f'id {i}') for i, j in enumerate(logits)],
        })

    @staticmethod
    def _prepare_label(label, n):
        """Return a float copy of ``label[:n]`` with every 255 replaced by 0.5.

        ``label[:n]`` is a view. When the input is already float,
        ``.float()`` returns that same view, so without the explicit
        ``clone`` the masked write below would overwrite the caller's
        tensor. ``detach`` keeps the copy out of any autograd graph.
        """
        out = label[:n].detach().float().clone()
        # 255 is presumably an ignore value (TODO confirm). Mapping it to
        # 0.5 renders it between 0 and 1 if the labels are binary masks.
        out[out == 255.] = 0.5
        return out

    def finish(self):
        """Close the underlying wandb run."""
        self.tensor_board.finish()
class DeNormalize(object):
    """Undo a per-channel normalization in place: ``x = x * std + mean``.

    Iterating ``tensor`` yields one channel per step, which is paired with
    the matching ``mean``/``std`` entries; pairing stops at the shortest
    of the three sequences (``zip`` semantics).
    """

    def __init__(self, mean, std):
        # Per-channel statistics, stored exactly as given.
        self.mean, self.std = mean, std

    def __call__(self, tensor):
        """Rescale ``tensor`` in place and return that same object."""
        per_channel_stats = zip(self.mean, self.std)
        for channel, (mu, sigma) in zip(tensor, per_channel_stats):
            # In-place ops write through the channel view into ``tensor``.
            channel.mul_(sigma).add_(mu)
        return tensor