File size: 7,327 Bytes
c6dfc69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
"""DDP training entry: AV model with SAM2 frozen, AuralFuser trainable, Hydra transforms and loss."""
import os
import torch
import numpy
import random
import argparse
from easydict import EasyDict


def seed_it(seed):
    """Seed every RNG used by training and put cuDNN into deterministic mode.

    Args:
        seed: Integer seed applied to Python's ``random``, NumPy and PyTorch
            (CPU and all CUDA devices).

    Note:
        ``PYTHONHASHSEED`` is read by the interpreter at start-up only, so
        exporting it here affects processes launched afterwards, not the
        string hashing of the current process.
    """
    # The variable CPython honours is PYTHONHASHSEED; "PYTHONSEED" is not
    # recognised by the interpreter and had no effect.
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    numpy.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Keep cuDNN on, but force deterministic kernels and disable the
    # autotuner, which may otherwise pick different algorithms between runs.
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def main(local_rank, ngpus_per_node, hyp_param):
    """Per-process DDP worker: build model, data, loss and metrics, then train and validate every epoch.

    Args:
        local_rank: Process index passed as the first argument by
            ``torch.multiprocessing.spawn``; used both as the distributed rank
            and as the CUDA device index.
        ngpus_per_node: Supplied by the launcher; not read inside this function.
        hyp_param: Attribute-style configuration. It is mutated in place:
            ``local_rank``, ``contrastive_learning`` and ``aural_fuser`` are
            written, and ``image_size`` / ``image_embedding_size`` are
            overridden when ``sam_config_path`` contains ``'hiera_t'``.
    """
    hyp_param.local_rank = local_rank
    # NCCL process group; world size = GPUs on this node
    torch.distributed.init_process_group(
        backend='nccl',
        init_method='env://',
        rank=hyp_param.local_rank,
        world_size=hyp_param.gpus * 1
    )
    # Offset the base seed by the rank so each process gets its own RNG stream
    seed_it(local_rank + hyp_param.seed)

    torch.cuda.set_device(hyp_param.local_rank)

    import model.visual.sam2  # noqa: F401 — registers Hydra `configs` (initialize_config_module)

    from hydra import compose
    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    # Hydra configs under v1m.code/configs (same pattern as training/sam2_training_config.yaml)
    transform_config_path = 'training/sam2_training_config.yaml'

    # Configs whose path contains 'hiera_t' are run at 224x224; the embedding
    # size is derived as image_size / 16
    if 'hiera_t' in hyp_param.sam_config_path:
        hyp_param.image_size = 224
        hyp_param.image_embedding_size = int(hyp_param.image_size / 16)
        print('\n upload image size to be {}x{} \n'.format(224, 224), flush=True)

    # Resolve interpolations, then copy the contrastive-learning section onto
    # hyp_param as a plain Python container
    cfg = compose(config_name=transform_config_path)
    OmegaConf.resolve(cfg)
    hyp_param.contrastive_learning = OmegaConf.to_container(cfg.contrastive_learning, resolve=True)

    # Same treatment for the AuralFuser architecture section
    arch_h = compose(config_name='auralfuser/architecture.yaml')
    OmegaConf.resolve(arch_h)
    hyp_param.aural_fuser = OmegaConf.to_container(arch_h.aural_fuser, resolve=True)

    from model.mymodel import AVmodel
    av_model = AVmodel(hyp_param).cuda(hyp_param.local_rank)

    # NOTE(review): find_unused_parameters=True is presumably needed because
    # some parameters get no gradient in a given step — confirm.
    av_model = torch.nn.parallel.distributed.DistributedDataParallel(av_model, device_ids=[hyp_param.local_rank],
                                                                     find_unused_parameters=True)

    # Optimizer: parameter groups from AuralFuser only (train_* vs VGG backbone)
    # NOTE(review): no lr is passed to AdamW here; presumably manipulate_params
    # sets a per-group lr from hyp_param — confirm.
    from utils.utils import manipulate_params
    parameter_list = manipulate_params(hyp_param, av_model.module.aural_fuser)
    optimiser = torch.optim.AdamW(parameter_list, betas=(0.9, 0.999))

    from dataloader.dataset import AV
    from dataloader.visual.visual_augmentation import Augmentation as VisualAugmentation
    from dataloader.audio.audio_augmentation import Augmentation as AudioAugmentation
    from torch.utils.data.distributed import DistributedSampler

    # Train-time visual transform: first element of the instantiated Hydra `train_transforms`
    compose_api = instantiate(cfg.train_transforms, _recursive_=True)[0]

    audio_augmentation = AudioAugmentation(mono=True)
    train_dataset = AV(split='train', augmentation={"visual": compose_api, "audio": audio_augmentation},
                       param=hyp_param, root_path=hyp_param.data_root_path)


    # Test-time visual transform is built from hyp_param rather than from the Hydra config
    visual_augmentation = VisualAugmentation(hyp_param.image_mean, hyp_param.image_std,
                                             hyp_param.image_size, hyp_param.image_size,
                                             hyp_param.scale_list, ignore_index=hyp_param.ignore_index)

    # Rebinds the name to a fresh instance for the test split; train_dataset
    # keeps the instance created above
    audio_augmentation = AudioAugmentation(mono=True)

    # Each rank draws its own shard; shuffling is re-keyed per epoch via set_epoch below
    random_sampler = DistributedSampler(train_dataset, shuffle=True)
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=hyp_param.batch_size,
                                                   sampler=random_sampler,
                                                   num_workers=hyp_param.num_workers, drop_last=True)

    test_dataset = AV(split='test', augmentation={"visual": visual_augmentation, "audio": audio_augmentation},
                      param=hyp_param, root_path=hyp_param.data_root_path)

    # Validation is also sharded across ranks, unshuffled, one sample per batch
    order_sampler = DistributedSampler(test_dataset, shuffle=False)
    test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, sampler=order_sampler,
                                                  num_workers=hyp_param.num_workers)


    # Loss comes from the Hydra config; the 'all' entry is the one used for training
    criterion = instantiate(cfg.loss, _recursive_=True)['all']
    from utils.tensorboard import Tensorboard
    # Only the process with local_rank <= 0 creates a Tensorboard logger
    tensorboard = Tensorboard(config=hyp_param) if hyp_param.local_rank <= 0 else None

    from trainer.train import Trainer
    from utils.foreground_iou import ForegroundIoU
    from utils.foreground_fscore import ForegroundFScore
    # ForegroundFScore receives the rank clamped to >= 0
    # NOTE(review): presumably a CUDA device index — confirm against ForegroundFScore.
    metrics = {"foreground_iou": ForegroundIoU(), "foreground_f-score": ForegroundFScore(0 if hyp_param.local_rank <= 0 else hyp_param.local_rank)}

    trainer = Trainer(hyp_param, loss=criterion, tensorboard=tensorboard, metrics=metrics)
    

    curr_best = 0.  # checkpoint when IoU (iou_select mode) improves

    for epoch in range(hyp_param.epochs):
        av_model.train()
        # Called after train() on every epoch
        # NOTE(review): presumably re-applies the frozen state of SAM2 that train() would otherwise reset — confirm.
        av_model.module.freeze_sam_parameters()
        random_sampler.set_epoch(epoch)
        trainer.train(epoch=epoch, dataloader=train_dataloader, model=av_model, optimiser=optimiser)

        # Wait for every rank to finish the epoch before validating
        torch.distributed.barrier()
        torch.cuda.empty_cache()

        av_model.eval()
        # Three validation modes: default first mask / IoU-selected mask / IoU + objectness gate
        # Only the 'iou_select' score drives checkpointing; the other two return values are not used here
        curr_results1, _ = trainer.valid(epoch=epoch, dataloader=test_dataloader, model=av_model, process='first_index')
        curr_results, _ = trainer.valid(epoch=epoch, dataloader=test_dataloader, model=av_model, process='iou_select')
        curr_results3, _ = trainer.valid(epoch=epoch, dataloader=test_dataloader, model=av_model, process='iou_occ_select')
        # A single process (local_rank <= 0) saves the AuralFuser weights only; the file is named after the score
        if hyp_param.local_rank <= 0 and curr_results > curr_best:
            curr_best = curr_results
            torch.save(av_model.module.aural_fuser.state_dict(), os.path.join(hyp_param.saved_dir, str(curr_results) + ".pth"))
        # Keep ranks in step while the checkpoint is written
        torch.distributed.barrier()
        torch.cuda.empty_cache()


if __name__ == '__main__':
    # Command-line options as (flags, keyword-arguments) pairs, registered in
    # the order listed.
    cli_options = (
        (('-n', '--nodes'),
         dict(default=1, type=int, metavar='N')),
        (('--local_rank',),
         dict(type=int, default=-1,
              help='multi-process training for DDP')),
        (('-g', '--gpus'),
         dict(default=1, type=int,
              help='number of gpus per node')),
        (('--batch_size',),
         dict(default=1, type=int)),
        (('--epochs',),
         dict(default=80, type=int,
              help="total epochs that used for the training")),
        (('--lr',),
         dict(default=1e-4, type=float,
              help='Default HEAD Learning rate is same as others, '
                   '*Note: in ddp training, lr will automatically times by n_gpu')),
        (('--online',),
         dict(action="store_true",
              help='switch on for visualization; switch off for debug')),
    )

    parser = argparse.ArgumentParser(description='PyTorch Training')
    for flags, options in cli_options:
        parser.add_argument(*flags, **options)
    cli_args = parser.parse_args()

    # Project configuration is imported only after the command line is parsed.
    from configs.config import C

    # Merge project config with the CLI values; CLI keys win on collisions.
    args = EasyDict({**C, **vars(cli_args)})

    # Fixed local rendezvous endpoint, read by init_method='env://' in main().
    os.environ.update({'MASTER_ADDR': '127.0.0.1', 'MASTER_PORT': '9902'})

    # One worker process per GPU; spawn passes the process index as main()'s first argument.
    torch.multiprocessing.spawn(main, args=(args.gpus, args), nprocs=args.gpus)