File size: 4,020 Bytes
c6dfc69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
"""End-to-end Ref-AVS: SAM2 visual backbone + AuralFuser fusion + tracking head.

Orchestration follows ``avs.code/v1m.code/model/mymodel.py``.
"""
import torch

from model.visual.sam2.build_sam import build_sam2_visual_predictor
from model.visual.sam2.utils.transforms import SAM2Transforms
from model.aural_fuser import AuralFuser
from transformers import AutoTokenizer, AutoModel


class AVmodel(torch.nn.Module):
    """SAM2 + audio/text fusion (``aural_fuser``) + SAM2 tracking decoder.

    Sub-modules created in ``__init__``:
      * ``v_model``     -- SAM2 visual predictor built by ``build_sam2_visual_predictor``.
      * ``aural_fuser`` -- ``AuralFuser`` that fuses audio (``spect``) and text
        features with the SAM2 backbone features.
      * ``t_model``     -- ``distilbert/distilroberta-base`` text encoder, always
        run under ``torch.no_grad`` (see ``_encode_text``).
    """

    # Downsampling factors of the three backbone feature levels recorded in
    # ``_bb_feat_sizes`` (image_size / 4, / 8, / 16).
    _FEAT_STRIDES = (4, 8, 16)

    def __init__(self, param, mask_threshold=0.0, max_hole_area=0.0, max_sprinkle_area=0.0):
        """Build the visual, fusion and text sub-models.

        Args:
            param: hyper-parameter namespace. This class reads ``image_size``,
                ``sam_config_path``, ``backbone_weight`` and ``local_rank`` from
                it, and passes it unchanged to ``AuralFuser``.
            mask_threshold: stored on the instance and forwarded to ``SAM2Transforms``.
            max_hole_area: forwarded to ``SAM2Transforms``.
            max_sprinkle_area: forwarded to ``SAM2Transforms``.
        """
        super().__init__()
        self.param = param
        self.mask_threshold = mask_threshold
        # (H, W) of each backbone level for a square ``image_size`` input.
        size = int(self.param.image_size)
        self._bb_feat_sizes = [(size // stride, size // stride) for stride in self._FEAT_STRIDES]

        self.v_model = build_sam2_visual_predictor(
            self.param.sam_config_path,
            self.param.backbone_weight,
            apply_postprocessing=True,
            mode='train',
            # Keep SAM2's configured input resolution in sync with param.image_size.
            hydra_overrides_extra=["++model.image_size={}".format(self.param.image_size)],
        )
        self._transforms = SAM2Transforms(
            resolution=self.v_model.image_size,
            mask_threshold=mask_threshold,
            max_hole_area=max_hole_area,
            max_sprinkle_area=max_sprinkle_area,
        )
        self.aural_fuser = AuralFuser(hyp_param=self.param)
        self.text_tokenizer = AutoTokenizer.from_pretrained('distilbert/distilroberta-base')
        self.t_model = AutoModel.from_pretrained('distilbert/distilroberta-base')

    def _encode_text(self, prompts):
        """Return RoBERTa token embeddings for the referring expressions.

        ``prompts`` is star-unpacked into the tokenizer call, so its elements
        become the tokenizer's positional arguments (NOTE(review): presumably a
        single-element sequence holding the batch of strings -- confirm against
        the data loader). Sequences are padded/truncated to 25 tokens and moved
        to GPU ``param.local_rank``. The encoder runs without autograd, so its
        output carries no gradient back into ``t_model``.

        Returns:
            ``last_hidden_state`` of ``t_model`` (presumably (batch, 25, hidden)).
        """
        enc = self.text_tokenizer(
            *prompts,
            max_length=25,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        enc['input_ids'] = enc['input_ids'].cuda(self.param.local_rank, non_blocking=True)
        enc['attention_mask'] = enc['attention_mask'].cuda(self.param.local_rank, non_blocking=True)
        with torch.no_grad():
            return self.t_model(**enc).last_hidden_state

    def forward_frame(self, frame_):
        """Resize one frame batch to ``image_size`` x ``image_size`` and run the SAM2 image encoder.

        Args:
            frame_: 4-D image tensor accepted by bilinear ``interpolate``
                (presumably (N, C, H, W)).

        Returns:
            Whatever ``v_model.image_encoder`` returns for the resized input.
        """
        frame = torch.nn.functional.interpolate(
            frame_, (self.param.image_size, self.param.image_size),
            antialias=True, align_corners=False, mode='bilinear',
        )
        return self.v_model.image_encoder(frame)

    def forward(self, frames, spect, prompts, sam_process=False):
        """Fuse audio+text into the backbone features, then run prompt-free SAM2 tracking.

        Args:
            frames: video frames; ``frames.shape[0]`` is treated as the number
                of frames, and the middle frame is the conditioning frame.
            spect: audio input forwarded to ``aural_fuser``.
            prompts: referring expressions, see ``_encode_text``.
            sam_process: unused; kept for interface compatibility.

        Returns:
            ``(outputs, proj_feats)``: the tracking outputs from
            ``v_model.forward_tracking_wo_prompt`` and the projection features
            returned by ``aural_fuser``.
        """
        text_feats = self._encode_text(prompts)
        backbone_feats = self.v_model.forward_image(frames, pre_compute=False)
        visual_resfeats, audio_resfeats, proj_feats = self.aural_fuser(backbone_feats, spect, text_feats)

        # The tracking head consumes the fused residuals in reverse level order.
        av_feats = (visual_resfeats[::-1], audio_resfeats[::-1])

        num_frames = frames.shape[0]
        backbone_feats = self.v_model.precompute_high_res_features(backbone_feats)
        backbone_feats = self.v_model.dont_prepare_prompt_inputs(
            backbone_feats,
            num_frames=num_frames,
            condition_frame=num_frames // 2,
        )
        outputs = self.v_model.forward_tracking_wo_prompt(backbone_feats, audio_res=av_feats)
        return outputs, proj_feats

    @property
    def device(self) -> torch.device:
        """Device reported by the SAM2 visual model."""
        return self.v_model.device

    def freeze_sam_parameters(self):
        """Freeze SAM2 and the text backbone so only ``aural_fuser`` is trained.

        Both backbones are switched to eval mode, and every one of their
        parameters gets ``requires_grad = False``. The text encoder is frozen
        explicitly as well. Otherwise its parameters would still look trainable
        even though ``_encode_text`` never produces gradients for them, which
        pollutes optimizer parameter groups and trips DDP's unused-parameter
        check.

        NOTE(review): a later ``self.train()`` call switches these sub-modules
        back to train mode; re-apply ``eval()`` on them after it if needed.
        """
        for backbone in (self.v_model, self.t_model):
            backbone.eval()
            for parameter in backbone.parameters():
                parameter.requires_grad = False