"""Training script for atctrack: builds the data loader, network, losses and trainer."""
import importlib
import os

# torch is imported explicitly because run() calls torch.device(); previously
# the name was only available if the star import below happened to expose it.
import torch
from torch.nn import BCEWithLogitsLoss
from torch.nn.functional import l1_loss
# distributed training related
from torch.nn.parallel import DistributedDataParallel as DDP

# loss function related
from lib.utils.box_ops import giou_loss
# train pipeline related
from lib.train.trainers import LTRTrainer
# some more advanced functions
# (provides update_settings, build_dataloaders and get_optimizer_scheduler used below)
from .base_functions import *
# network related
from lib.models.atctrack import build_atctrack
# forward propagation related
from lib.train.actors import ATCTrackActor
from ..utils.focal_loss import FocalLoss


def run(settings):
    """Configure and launch a training run described by ``settings``.

    The steps are: load and print the configuration, set up the log file
    path, build the training data loader and the network, optionally create
    a teacher label cache, move the network to CUDA (wrapping it in DDP when
    running distributed), build the losses/actor/optimizer, and finally start
    the trainer.

    Args:
        settings: Training settings object. Attributes read here:
            ``cfg_file``, ``script_name``, ``config_name``, ``save_dir`` and
            ``local_rank`` (``-1`` means non-distributed). Attributes written
            here: ``description``, ``log_file``, ``device``, ``deep_sup``,
            ``distill`` and ``distill_loss_type``. ``update_settings`` may set
            further attributes.

    Raises:
        ValueError: If ``settings.cfg_file`` does not exist, if
            ``cfg.DATA.LOADER`` is not ``"tracking"``, or if
            ``settings.script_name`` is not ``"atctrack"``.
    """
    settings.description = 'Training script for atctrack'

    # update the default configs with config file
    if not os.path.exists(settings.cfg_file):
        raise ValueError("%s doesn't exist." % settings.cfg_file)
    config_module = importlib.import_module("lib.config.%s.config" % settings.script_name)
    cfg = config_module.cfg  # generate cfg from lib.config
    config_module.update_config_from_file(settings.cfg_file)  # update cfg from experiments
    if settings.local_rank in (-1, 0):
        print("New configuration is shown below.")
        for key in cfg.keys():
            print("%s configuration:" % key, cfg[key])
            print('\n')

    # update settings based on cfg
    update_settings(settings, cfg)

    # Record the training log
    log_dir = os.path.join(settings.save_dir, 'logs')
    if settings.local_rank in (-1, 0):
        # exist_ok avoids a check-then-create race: on multi-node runs every
        # node has a local_rank 0 process that may create the same directory.
        os.makedirs(log_dir, exist_ok=True)
    settings.log_file = os.path.join(log_dir, "%s-%s.log" % (settings.script_name, settings.config_name))

    # Build dataloaders
    loader_type = getattr(cfg.DATA, "LOADER", "tracking")
    if loader_type == "tracking":
        loader_train = build_dataloaders(cfg, settings)
    else:
        raise ValueError("illegal DATA LOADER")

    # Create network
    if settings.script_name == "atctrack":
        net = build_atctrack(cfg)
    else:
        raise ValueError("illegal script name")

    # ---- two-stage teacher labeling: initialise persistent cache ----
    teacher_label_cache = None
    if cfg.MODEL.TARGET_STATE.ENABLE and cfg.MODEL.TARGET_STATE.TEACHER_ENABLE:
        # Imported lazily so these modules are only loaded when the teacher
        # labeling feature is switched on in the config.
        from lib.train.data.teacher_label_cache import TeacherLabelCache
        from lib.utils.misc import is_main_process
        cache_dir = os.path.join(settings.save_dir, "teacher_label_cache")
        teacher_label_cache = TeacherLabelCache(cache_dir)
        if is_main_process():
            print(f"[TeacherLabelCache] Loaded {len(teacher_label_cache)} cached entries "
                  f"from {cache_dir} (hit_rate={teacher_label_cache.hit_rate():.1%})")

    # wrap networks to distributed one
    net.cuda()
    if settings.local_rank != -1:
        # NOTE: SyncBatchNorm conversion
        # (torch.nn.SyncBatchNorm.convert_sync_batchnorm) is currently disabled.
        net = DDP(net, device_ids=[settings.local_rank], find_unused_parameters=True)
        settings.device = torch.device("cuda:%d" % settings.local_rank)
    else:
        settings.device = torch.device("cuda:0")
    settings.deep_sup = getattr(cfg.TRAIN, "DEEP_SUPERVISION", False)
    settings.distill = getattr(cfg.TRAIN, "DISTILL", False)
    settings.distill_loss_type = getattr(cfg.TRAIN, "DISTILL_LOSS_TYPE", "KL")

    # ---- attach teacher label cache to encoder ----
    # A non-None cache implies the conditional imports above ran, so
    # is_main_process is guaranteed to be bound inside this branch.
    if teacher_label_cache is not None:
        # DDP keeps the wrapped model under .module; unwrap to reach the encoder.
        base_net = net.module if settings.local_rank != -1 else net
        if getattr(base_net, "target_state_encoder", None) is not None:
            base_net.target_state_encoder.set_teacher_label_cache(teacher_label_cache)
            if is_main_process():
                print("[TeacherLabelCache] Attached to target_state_encoder.")

    # Loss functions and Actors
    if settings.script_name == "atctrack":
        focal_loss = FocalLoss()
        objective = {'giou': giou_loss, 'l1': l1_loss, 'focal': focal_loss, 'cls': BCEWithLogitsLoss()}
        loss_weight = {'giou': cfg.TRAIN.GIOU_WEIGHT, 'l1': cfg.TRAIN.L1_WEIGHT, 'focal': 1., 'cls': 1.0}
        actor = ATCTrackActor(net=net, objective=objective, loss_weight=loss_weight, settings=settings, cfg=cfg)
    else:
        raise ValueError("illegal script name")

    # Optimizer, parameters, and learning rates
    optimizer, lr_scheduler = get_optimizer_scheduler(net, cfg)
    use_amp = getattr(cfg.TRAIN, "AMP", False)
    trainer = LTRTrainer(actor, [loader_train], optimizer, settings, lr_scheduler, use_amp=use_amp)

    # train process
    trainer.train(cfg.TRAIN.EPOCH, load_latest=True, fail_safe=True)