import os
# loss function related
from lib.utils.box_ops import giou_loss
from torch.nn.functional import l1_loss
from torch.nn import BCEWithLogitsLoss
# train pipeline related
from lib.train.trainers import LTRTrainer
# distributed training related
from torch.nn.parallel import DistributedDataParallel as DDP
# some more advanced functions
from .base_functions import *
# network related
from lib.models.atctrack import build_atctrack
from lib.models.atctrack import build_atctrack
# forward propagation related
from lib.train.actors import ATCTrackActor
# for import modules
import importlib
from ..utils.focal_loss import FocalLoss
def run(settings):
    """Configure and launch an ATCTrack training run.

    Loads ``lib.config.<script_name>.config``, merges the experiment file
    into it, then builds the dataloader, network, loss objectives, actor,
    optimizer and scheduler, and hands everything to ``LTRTrainer``.

    Args:
        settings: Mutable run-settings object. Read here: ``cfg_file``,
            ``script_name``, ``config_name``, ``save_dir`` and
            ``local_rank`` (``-1`` selects the non-DDP path). Written here:
            ``description``, ``log_file``, ``device``, ``deep_sup``,
            ``distill`` and ``distill_loss_type``.

    Raises:
        ValueError: If ``settings.cfg_file`` does not exist, if
            ``cfg.DATA.LOADER`` is not ``"tracking"``, or if
            ``settings.script_name`` is not ``"atctrack"``.
    """
    settings.description = 'Training script for atctrack'

    # Update the default configs with the experiment config file.
    if not os.path.exists(settings.cfg_file):
        raise ValueError(f"{settings.cfg_file} doesn't exist.")
    config_module = importlib.import_module(f"lib.config.{settings.script_name}.config")
    cfg = config_module.cfg  # generate cfg from lib.config
    config_module.update_config_from_file(settings.cfg_file)  # update cfg from experiments
    if settings.local_rank in (-1, 0):
        print("New configuration is shown below.")
        for key in cfg.keys():
            print(f"{key} configuration:", cfg[key])
            # NOTE(review): separator assumed to be printed once per section.
            print('\n')

    # Update settings based on cfg.
    update_settings(settings, cfg)

    # Record the training log. Only rank -1/0 creates the directory;
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    log_dir = os.path.join(settings.save_dir, 'logs')
    if settings.local_rank in (-1, 0):
        os.makedirs(log_dir, exist_ok=True)
    # NOTE(review): the log path is assigned on every rank, not just rank -1/0.
    settings.log_file = os.path.join(log_dir, f"{settings.script_name}-{settings.config_name}.log")

    # Build dataloaders.
    loader_type = getattr(cfg.DATA, "LOADER", "tracking")
    if loader_type != "tracking":
        raise ValueError("illegal DATA LOADER")
    loader_train = build_dataloaders(cfg, settings)

    # Create network. This is the single place where script_name is validated.
    if settings.script_name != "atctrack":
        raise ValueError("illegal script name")
    net = build_atctrack(cfg)

    # ---- two-stage teacher labeling: initialise persistent cache ----
    teacher_label_cache = None
    if cfg.MODEL.TARGET_STATE.ENABLE and cfg.MODEL.TARGET_STATE.TEACHER_ENABLE:
        from lib.train.data.teacher_label_cache import TeacherLabelCache
        from lib.utils.misc import is_main_process
        cache_dir = os.path.join(settings.save_dir, "teacher_label_cache")
        teacher_label_cache = TeacherLabelCache(cache_dir)
        if is_main_process():
            print(f"[TeacherLabelCache] Loaded {len(teacher_label_cache)} cached entries "
                  f"from {cache_dir} (hit_rate={teacher_label_cache.hit_rate():.1%})")

    # Wrap the network into a distributed one.
    # NOTE(review): `torch` is not imported explicitly in the visible imports;
    # presumably it arrives via `from .base_functions import *` - confirm.
    net.cuda()
    if settings.local_rank != -1:
        # net = torch.nn.SyncBatchNorm.convert_sync_batchnorm(net)  # add syncBN converter
        net = DDP(net, device_ids=[settings.local_rank], find_unused_parameters=True)
        settings.device = torch.device("cuda:%d" % settings.local_rank)
    else:
        settings.device = torch.device("cuda:0")
    settings.deep_sup = getattr(cfg.TRAIN, "DEEP_SUPERVISION", False)
    settings.distill = getattr(cfg.TRAIN, "DISTILL", False)
    settings.distill_loss_type = getattr(cfg.TRAIN, "DISTILL_LOSS_TYPE", "KL")

    # ---- attach teacher label cache to encoder ----
    if teacher_label_cache is not None:
        # Imported here as well so this branch does not depend on a name
        # bound inside a different conditional branch.
        from lib.utils.misc import is_main_process
        # When wrapped in DDP the underlying model is reached through `.module`.
        base_net = net.module if settings.local_rank != -1 else net
        if getattr(base_net, "target_state_encoder", None) is not None:
            base_net.target_state_encoder.set_teacher_label_cache(teacher_label_cache)
            if is_main_process():
                print("[TeacherLabelCache] Attached to target_state_encoder.")

    # Loss functions and Actors. script_name was already validated when the
    # network was built, so no second check is needed here.
    focal_loss = FocalLoss()
    objective = {'giou': giou_loss, 'l1': l1_loss, 'focal': focal_loss, 'cls': BCEWithLogitsLoss()}
    loss_weight = {'giou': cfg.TRAIN.GIOU_WEIGHT, 'l1': cfg.TRAIN.L1_WEIGHT, 'focal': 1., 'cls': 1.0}
    actor = ATCTrackActor(net=net, objective=objective, loss_weight=loss_weight, settings=settings, cfg=cfg)

    # Optimizer, parameters, and learning rates.
    optimizer, lr_scheduler = get_optimizer_scheduler(net, cfg)
    use_amp = getattr(cfg.TRAIN, "AMP", False)
    trainer = LTRTrainer(actor, [loader_train], optimizer, settings, lr_scheduler, use_amp=use_amp)

    # Train process.
    trainer.train(cfg.TRAIN.EPOCH, load_latest=True, fail_safe=True)
|