File size: 7,599 Bytes
c6dfc69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
"""Distributed inference on Ref-AVS (test_s / test_u / test_n); uses Trainer.valid / valid_null like main.py."""
import os
import pathlib
import argparse
import random

import numpy
import torch
from easydict import EasyDict


_real_mkdir = pathlib.Path.mkdir


def _safe_mkdir(self, mode=0o777, parents=False, exist_ok=False):
    try:
        return _real_mkdir(self, mode, parents, exist_ok=exist_ok)
    except PermissionError:
        pass


pathlib.Path.mkdir = _safe_mkdir


def seed_it(seed):
    """Seed every RNG used by this script and force deterministic cuDNN.

    Args:
        seed: integer seed applied to Python's ``random``, NumPy's global RNG
            and PyTorch (CPU and all CUDA devices).

    ``PYTHONHASHSEED`` is exported too. It cannot change hash randomization
    of the interpreter that is already running. Processes started later with
    a fresh interpreter (e.g. the ``spawn`` start method) inherit it.
    """
    random.seed(seed)
    # The interpreter reads PYTHONHASHSEED (not "PYTHONSEED") at startup.
    os.environ["PYTHONHASHSEED"] = str(seed)
    numpy.random.seed(seed)
    torch.manual_seed(seed)
    # Seeds every visible CUDA device, the current one included.
    torch.cuda.manual_seed_all(seed)
    # Trade speed for reproducibility: no autotuned / nondeterministic kernels.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.enabled = True


class _DummyTensorboard:
    """Minimal Tensorboard stub so Trainer.valid / valid_null run without wandb logging."""

    def upload_wandb_info(self, info_dict):
        pass

    def upload_wandb_image(self, *args, **kwargs):
        pass


def _load_inference_weights(av_model, ckpt_path):
    """Load the trained checkpoint at *ckpt_path* into *av_model* in place.

    Two layouts are accepted:

    * a full ``AVmodel`` state_dict (some key starts with ``v_model.`` or
      ``aural_fuser.``), which is loaded into the whole model;
    * an AuralFuser-only state_dict, which is loaded into
      ``av_model.aural_fuser``.

    If every key carries a leading ``module.`` (the prefix that saving a
    ``DistributedDataParallel`` wrapper's state_dict produces), the prefix is
    stripped first. Both loads use ``strict=True``.

    Raises:
        TypeError: if the file does not contain a dict.
    """
    ckpt_sd = torch.load(ckpt_path, map_location="cpu")
    if not isinstance(ckpt_sd, dict):
        raise TypeError("Checkpoint must be a state_dict dictionary.")

    ddp_prefix = "module."
    if ckpt_sd and all(k.startswith(ddp_prefix) for k in ckpt_sd):
        ckpt_sd = {k[len(ddp_prefix):]: v for k, v in ckpt_sd.items()}

    if any(k.startswith(("v_model.", "aural_fuser.")) for k in ckpt_sd):
        av_model.load_state_dict(ckpt_sd, strict=True)
    else:
        av_model.aural_fuser.load_state_dict(ckpt_sd, strict=True)


def _run_inference(hyp_param):
    """Evaluate the checkpoint on test_s / test_u / test_n on the current rank.

    Expects the default process group to be initialized and
    ``hyp_param.local_rank`` to be set (done by ``main``). Rank 0 prints the
    summary.
    """
    seed_it(hyp_param.local_rank + hyp_param.seed)
    torch.cuda.set_device(hyp_param.local_rank)

    import model.visual.sam2  # noqa: F401 — registers Hydra config store
    from hydra import compose
    from omegaconf import OmegaConf

    arch_h = compose(config_name='configs/auralfuser/architecture.yaml')
    OmegaConf.resolve(arch_h)
    hyp_param.aural_fuser = OmegaConf.to_container(arch_h.aural_fuser, resolve=True)

    train_cfg = compose(config_name='configs/training/sam2_training_config.yaml')
    OmegaConf.resolve(train_cfg)
    hyp_param.contrastive_learning = OmegaConf.to_container(train_cfg.contrastive_learning, resolve=True)

    hyp_param.image_size = 1024
    # image_size / 16 (64 for 1024).
    hyp_param.image_embedding_size = hyp_param.image_size // 16

    from model.mymodel import AVmodel
    av_model = AVmodel(hyp_param).cuda(hyp_param.local_rank)
    _load_inference_weights(av_model, hyp_param.inference_ckpt)

    av_model = torch.nn.parallel.DistributedDataParallel(
        av_model, device_ids=[hyp_param.local_rank], find_unused_parameters=False,
    )
    av_model.eval()

    from dataloader.dataset import AV
    from dataloader.visual.visual_augmentation import Augmentation as VisualAugmentation
    from dataloader.audio.audio_augmentation import Augmentation as AudioAugmentation
    from torch.utils.data import DataLoader, Subset
    from torch.utils.data.distributed import DistributedSampler

    visual_aug = VisualAugmentation(
        hyp_param.image_mean, hyp_param.image_std,
        hyp_param.image_size, hyp_param.image_size,
        hyp_param.scale_list, ignore_index=hyp_param.ignore_index,
    )
    audio_aug = AudioAugmentation(mono=True)

    # 0 (or a missing/None value) means "evaluate the full split".
    max_batches = getattr(hyp_param, "inference_max_batches", 0) or 0
    val_batch_size = getattr(hyp_param, "inference_val_batch_size", 4)

    def _test_loader(split):
        """Build an unshuffled, rank-sharded DataLoader for *split*."""
        ds = AV(
            split=split,
            augmentation={"visual": visual_aug, "audio": audio_aug},
            param=hyp_param,
            root_path=hyp_param.data_root_path,
        )
        if max_batches > 0:
            # Debug cap applied *before* sharding: the first
            # max_batches * batch_size samples are split across all ranks.
            n_samples = min(max_batches * val_batch_size, len(ds))
            ds = Subset(ds, range(n_samples))
        # NOTE(review): DistributedSampler pads with repeated samples so every
        # rank gets the same count; metrics may count a few samples twice
        # unless Trainer.valid compensates (not visible here).
        sampler = DistributedSampler(ds, shuffle=False)
        return DataLoader(
            ds,
            batch_size=val_batch_size,
            sampler=sampler,
            num_workers=hyp_param.num_workers,
        )

    test_s_loader = _test_loader('test_s')
    test_u_loader = _test_loader('test_u')
    test_n_loader = _test_loader('test_n')

    from trainer.train import Trainer
    from utils.foreground_iou import ForegroundIoU
    from utils.foreground_fscore import ForegroundFScore
    from utils.foreground_s import ForegroundS

    metrics = {
        "foreground_iou": ForegroundIoU(),
        "foreground_f-score": ForegroundFScore(hyp_param.local_rank),
        "foreground_s": ForegroundS(),
    }
    trainer = Trainer(hyp_param, loss=None, tensorboard=_DummyTensorboard(), metrics=metrics)

    test_s_iou, test_s_f = trainer.valid(
        epoch=0, dataloader=test_s_loader, model=av_model, process='test_s',
    )
    torch.cuda.empty_cache()

    test_u_iou, test_u_f = trainer.valid(
        epoch=0, dataloader=test_u_loader, model=av_model, process='test_u',
    )
    torch.cuda.empty_cache()

    test_n_s = trainer.valid_null(
        epoch=0, dataloader=test_n_loader, model=av_model, process='test_n',
    )
    torch.cuda.empty_cache()

    if hyp_param.local_rank <= 0:
        print("\n========== Ref-AVS inference (same splits / metrics as training valid) ==========")
        print("  test_s   f_iou={}  f_f-score={}".format(test_s_iou, test_s_f))
        print("  test_u   f_iou={}  f_f-score={}".format(test_u_iou, test_u_f))
        print("  test_n   f_s={}".format(test_n_s))
        print("=======================================================\n")


def main(local_rank, ngpus_per_node, hyp_param):
    """Per-process entry point launched by ``torch.multiprocessing.spawn``.

    The process rank equals *local_rank* and the world size is
    ``hyp_param.gpus``, so this assumes a single node.

    Args:
        local_rank: index of this process / GPU, supplied by ``spawn``.
        ngpus_per_node: second ``spawn`` argument (the GPU count). It is
            unused here because the process group reads ``hyp_param.gpus``.
        hyp_param: EasyDict of merged config and CLI arguments. It is
            mutated in place (``local_rank``, ``aural_fuser``,
            ``contrastive_learning``, ``image_size``,
            ``image_embedding_size``).

    Raises:
        ValueError: if ``hyp_param.inference_ckpt`` is empty.
    """
    # Validate before joining the process group and building the large model.
    if not hyp_param.inference_ckpt:
        raise ValueError("--inference_ckpt is required for inference.")

    hyp_param.local_rank = local_rank
    torch.distributed.init_process_group(
        backend='nccl',
        init_method='env://',
        rank=hyp_param.local_rank,
        world_size=hyp_param.gpus,
    )
    _run_inference(hyp_param)
    # Release the NCCL communicators explicitly rather than at interpreter exit.
    torch.distributed.destroy_process_group()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Ref-AVS inference: test_s / test_u / test_n')

    parser.add_argument('--local_rank', type=int, default=-1,
                        help='multi-process training for DDP')
    parser.add_argument('-g', '--gpus', default=1, type=int,
                        help='number of gpus per node')
    parser.add_argument('--batch_size', default=1, type=int,
                        help='unused at inference (validation uses inference_val_batch_size)')
    parser.add_argument('--epochs', default=80, type=int, help='unused')
    parser.add_argument('--lr', default=1e-5, type=float, help='unused')
    parser.add_argument('--online', action='store_true', help='unused')
    parser.add_argument(
        '--inference_ckpt', type=str, required=True,
        help='Trained AuralFuser checkpoint (.pth). SAM2 from backbone_weight in configs.',
    )
    parser.add_argument('--inference_max_batches', type=int, default=0,
                        help='0 = full split; >0 = first N batches per split (debug)')
    parser.add_argument('--inference_val_batch_size', type=int, default=4,
                        help='Validation batch size (default 4, same as main.py _test_loader)')
    parser.add_argument('--master_port', type=int, default=9902,
                        help='Local TCP port for the DDP rendezvous (default 9902); '
                             'change it to run several inference jobs on one machine')

    args = parser.parse_args()
    # spawn(nprocs=0) would start no workers and exit silently; reject it up front.
    if args.gpus < 1:
        parser.error('--gpus must be >= 1')

    # Config defaults first; CLI arguments override them.
    from configs.config import C
    args = EasyDict({**C, **vars(args)})

    # Data and weights are looked up in the parent of this script's directory.
    _repo = pathlib.Path(__file__).resolve().parent
    _workspace = _repo.parent
    args.data_root_path = str(_workspace / 'REFAVS')
    args.backbone_weight = str(_workspace / 'ckpts' / 'sam_ckpts' / 'sam2_hiera_large.pt')
    args.audio.PRETRAINED_VGGISH_MODEL_PATH = str(_workspace / 'ckpts' / 'vggish-10086976.pth')
    args.saved_dir = '/tmp/ref_avs_infer_ckpt'
    pathlib.Path(args.saved_dir).mkdir(parents=True, exist_ok=True)

    # Single-node rendezvous consumed by init_method='env://' in main().
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = str(args.master_port)

    torch.multiprocessing.spawn(main, nprocs=args.gpus, args=(args.gpus, args))