| import os
|
|
|
| from lib.utils.box_ops import giou_loss
|
| from torch.nn.functional import l1_loss
|
| from torch.nn import BCEWithLogitsLoss
|
|
|
| from lib.train.trainers import LTRTrainer
|
|
|
| from torch.nn.parallel import DistributedDataParallel as DDP
|
|
|
| from .base_functions import *
|
|
|
| from lib.models.atctrack import build_atctrack
|
| from lib.models.atctrack import build_atctrack
|
|
|
| from lib.train.actors import ATCTrackActor
|
|
|
| import importlib
|
|
|
| from ..utils.focal_loss import FocalLoss
|
|
|
|
|
def _load_config(settings):
    """Import ``lib.config.<script_name>.config``, overlay ``settings.cfg_file`` and return its ``cfg``.

    On rank -1/0 the resulting configuration is printed, one top-level key at a time.

    Raises:
        ValueError: if ``settings.cfg_file`` does not exist.
    """
    if not os.path.exists(settings.cfg_file):
        raise ValueError("%s doesn't exist." % settings.cfg_file)
    config_module = importlib.import_module("lib.config.%s.config" % settings.script_name)
    cfg = config_module.cfg
    # update_config_from_file mutates the module-level cfg object in place.
    config_module.update_config_from_file(settings.cfg_file)
    if settings.local_rank in (-1, 0):
        print("New configuration is shown below.")
        for key in cfg.keys():
            print("%s configuration:" % key, cfg[key])
            print('\n')
    return cfg


def _prepare_log_file(settings):
    """Ensure ``<save_dir>/logs`` exists (created by rank -1/0 only) and set ``settings.log_file``."""
    log_dir = os.path.join(settings.save_dir, 'logs')
    if settings.local_rank in (-1, 0):
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(log_dir, exist_ok=True)
    settings.log_file = os.path.join(log_dir, "%s-%s.log" % (settings.script_name, settings.config_name))


def _build_teacher_label_cache(cfg, settings):
    """Return a TeacherLabelCache under ``<save_dir>/teacher_label_cache``, or None when disabled.

    The cache is only built when both ``cfg.MODEL.TARGET_STATE.ENABLE`` and
    ``cfg.MODEL.TARGET_STATE.TEACHER_ENABLE`` are truthy.
    """
    if not (cfg.MODEL.TARGET_STATE.ENABLE and cfg.MODEL.TARGET_STATE.TEACHER_ENABLE):
        return None
    # Imported lazily so the dependency is only required when the feature is on.
    from lib.train.data.teacher_label_cache import TeacherLabelCache
    from lib.utils.misc import is_main_process

    cache_dir = os.path.join(settings.save_dir, "teacher_label_cache")
    teacher_label_cache = TeacherLabelCache(cache_dir)
    if is_main_process():
        print(f"[TeacherLabelCache] Loaded {len(teacher_label_cache)} cached entries "
              f"from {cache_dir} (hit_rate={teacher_label_cache.hit_rate():.1%})")
    return teacher_label_cache


def _place_on_device(net, settings):
    """Move ``net`` to CUDA, wrap it in DDP when ``settings.local_rank != -1``, and set ``settings.device``.

    Returns the (possibly DDP-wrapped) network.
    """
    net.cuda()
    if settings.local_rank != -1:
        net = DDP(net, device_ids=[settings.local_rank], find_unused_parameters=True)
        settings.device = torch.device("cuda:%d" % settings.local_rank)
    else:
        settings.device = torch.device("cuda:0")
    return net


def _attach_teacher_label_cache(net, teacher_label_cache):
    """Hand ``teacher_label_cache`` to the model's ``target_state_encoder``, warning if there is none."""
    from lib.utils.misc import is_main_process

    # Unwrap DDP so we reach the user-defined submodules.
    base_net = net.module if isinstance(net, DDP) else net
    encoder = getattr(base_net, "target_state_encoder", None)
    if encoder is None:
        # Previously this case was silent: the cache was loaded but never used.
        if is_main_process():
            print("[TeacherLabelCache] WARNING: model has no target_state_encoder; cache not attached.")
        return
    encoder.set_teacher_label_cache(teacher_label_cache)
    if is_main_process():
        print("[TeacherLabelCache] Attached to target_state_encoder.")


def _build_atctrack_actor(net, cfg, settings):
    """Build the ATCTrackActor with GIoU/L1/focal/BCE objectives and their loss weights."""
    objective = {'giou': giou_loss, 'l1': l1_loss, 'focal': FocalLoss(), 'cls': BCEWithLogitsLoss()}
    loss_weight = {'giou': cfg.TRAIN.GIOU_WEIGHT, 'l1': cfg.TRAIN.L1_WEIGHT, 'focal': 1., 'cls': 1.0}
    return ATCTrackActor(net=net, objective=objective, loss_weight=loss_weight, settings=settings, cfg=cfg)


def run(settings):
    """Train an ATCTrack model as described by ``settings`` and its config file.

    Steps, in order: load/print config, apply it to ``settings``, set up the
    log file, build the training loader and network, optionally load a
    teacher-label cache, move the net to CUDA (DDP when ``local_rank != -1``),
    build the actor, optimizer and scheduler, and run ``LTRTrainer.train``.

    Raises:
        ValueError: if the config file is missing, ``cfg.DATA.LOADER`` is not
            "tracking", or ``settings.script_name`` is not "atctrack".
    """
    settings.description = 'Training script for atctrack'

    cfg = _load_config(settings)
    update_settings(settings, cfg)
    _prepare_log_file(settings)

    loader_type = getattr(cfg.DATA, "LOADER", "tracking")
    if loader_type != "tracking":
        raise ValueError("illegal DATA LOADER")
    loader_train = build_dataloaders(cfg, settings)

    if settings.script_name != "atctrack":
        raise ValueError("illegal script name")
    net = build_atctrack(cfg)

    teacher_label_cache = _build_teacher_label_cache(cfg, settings)

    net = _place_on_device(net, settings)
    settings.deep_sup = getattr(cfg.TRAIN, "DEEP_SUPERVISION", False)
    settings.distill = getattr(cfg.TRAIN, "DISTILL", False)
    settings.distill_loss_type = getattr(cfg.TRAIN, "DISTILL_LOSS_TYPE", "KL")

    if teacher_label_cache is not None:
        _attach_teacher_label_cache(net, teacher_label_cache)

    # script_name was validated above, so only the atctrack actor can be needed here.
    actor = _build_atctrack_actor(net, cfg, settings)

    optimizer, lr_scheduler = get_optimizer_scheduler(net, cfg)
    use_amp = getattr(cfg.TRAIN, "AMP", False)
    trainer = LTRTrainer(actor, [loader_train], optimizer, settings, lr_scheduler, use_amp=use_amp)

    trainer.train(cfg.TRAIN.EPOCH, load_latest=True, fail_safe=True)
|
|
|