# ATCTrack-VLM / lib/train/admin/tensorboard.py
# Uploaded by SunXiang2025
# Commit message: "Upload ATCTrack-VLM code and selected checkpoints"
# Commit 25986db (verified)
import os
from collections import OrderedDict
# Prefer PyTorch's built-in TensorBoard writer; fall back to tensorboardX when it
# cannot be imported. `except Exception` (not a bare `except:`) so that
# KeyboardInterrupt / SystemExit during the import are not silently swallowed.
try:
    from torch.utils.tensorboard import SummaryWriter
except Exception:
    print('WARNING: You are using tensorboardX instead as torch.utils.tensorboard could not be imported '
          '(your pytorch version may be too old).')
    from tensorboardX import SummaryWriter
class TensorboardWriter:
    """Write training statistics to TensorBoard, one ``SummaryWriter`` per data loader.

    Each loader name gets its own writer whose log directory is
    ``os.path.join(directory, loader_name)``.
    """

    def __init__(self, directory, loader_names):
        """
        Args:
            directory: Root directory under which the per-loader log directories are created.
            loader_names: Iterable of loader names; one writer is created per name.
        """
        self.directory = directory
        # Writers keyed by loader name, kept in the order the names were given.
        self.writer = OrderedDict(
            (name, SummaryWriter(os.path.join(self.directory, name))) for name in loader_names
        )

    def write_info(self, script_name, description):
        """Log the script name and the run description as text under ``<directory>/info``.

        A short-lived writer is used and is closed even if writing the text fails.
        """
        tb_info_writer = SummaryWriter(os.path.join(self.directory, 'info'))
        try:
            tb_info_writer.add_text('Script_name', script_name)
            tb_info_writer.add_text('Description', description)
        finally:
            tb_info_writer.close()

    def write_epoch(self, stats: OrderedDict, epoch: int, ind=-1):
        """Log one scalar per tracked statistic for every loader.

        Args:
            stats: Mapping ``loader_name -> mapping(stat_name -> stat)``. Loaders whose
                entry is ``None`` are skipped. A stat is logged only if it has a
                ``history`` attribute, its ``has_new_data`` attribute is truthy (a
                missing ``has_new_data`` counts as True), and its history is non-empty.
            epoch: Passed to ``add_scalar`` as the step.
            ind: Index into ``stat.history`` selecting the value to log
                (default ``-1``: the last entry).

        Raises:
            KeyError: if a loader that has a loggable stat was not passed to ``__init__``.
        """
        for loader_name, loader_stats in stats.items():
            if loader_stats is None:
                continue
            for var_name, val in loader_stats.items():
                if not hasattr(val, 'history') or not getattr(val, 'has_new_data', True):
                    continue
                # Nothing recorded yet: indexing history would raise IndexError.
                if len(val.history) == 0:
                    continue
                self.writer[loader_name].add_scalar(var_name, val.history[ind], epoch)

    def close(self):
        """Close every per-loader writer created in ``__init__``."""
        for writer in self.writer.values():
            writer.close()