# AuralSAM2 — avs.code/v2.code/model/mymodel.py
# Uploaded by yyliu01 via huggingface_hub (commit c6dfc69).
import logging
from typing import List, Optional, Tuple, Union
import numpy
import numpy as np
import torch
from PIL.Image import Image
from model.visual.sam2.modeling.sam2_base import SAM2Base
from model.visual.sam2.modeling.backbones.hieradet import Hiera
from model.visual.sam2.modeling.backbones.image_encoder import FpnNeck
from model.visual.sam2.modeling.backbones.image_encoder import ImageEncoder
from model.visual.sam2.modeling.position_encoding import PositionEmbeddingSine
from model.visual.sam2.modeling.memory_attention import MemoryAttention
from model.visual.sam2.modeling.memory_attention import MemoryAttentionLayer
from model.visual.sam2.modeling.sam.transformer import RoPEAttention
from model.visual.sam2.modeling.memory_encoder import MemoryEncoder
from model.visual.sam2.modeling.memory_encoder import MaskDownSampler
from model.visual.sam2.modeling.memory_encoder import Fuser
from model.visual.sam2.modeling.memory_encoder import CXBlock
from model.visual.sam2.utils.transforms import SAM2Transforms
from model.visual.sam2.modeling.backbones.hieradet import do_pool
from model.visual.sam2.modeling.backbones.utils import (
PatchEmbed,
window_partition,
window_unpartition,
)
class AVmodel(torch.nn.Module):
    """End-to-end AV segmentation: SAM2 visual backbone + AuralFuser audio-visual fusion + tracking head.

    ``self.v_model`` (a SAM2 visual predictor) extracts multi-level visual
    features, ``self.aural_fuser`` combines them with the audio input into
    residual features, and SAM2's prompt-free tracking path
    (``forward_tracking_wo_prompt``) consumes those residuals.
    """

    def __init__(self, param, mask_threshold=0.0, max_hole_area=0.0, max_sprinkle_area=0.0):
        """Build the SAM2 visual predictor and the audio-visual fuser.

        Args:
            param: Hyper-parameter object. This constructor reads
                ``image_size``, ``sam_config_path`` and ``backbone_weight``
                from it and passes it to ``AuralFuser`` as ``hyp_param``.
            mask_threshold: Stored on ``self.mask_threshold`` and forwarded to
                ``SAM2Transforms``.
            max_hole_area: Forwarded to ``SAM2Transforms``.
            max_sprinkle_area: Forwarded to ``SAM2Transforms``.
        """
        super().__init__()
        self.param = param
        self.mask_threshold = mask_threshold
        # (H, W) sizes at strides 4, 8 and 16 of the square input resolution.
        # Not read anywhere inside this class.
        self._bb_feat_sizes = [
            (int(self.param.image_size / stride), int(self.param.image_size / stride))
            for stride in (4, 8, 16)
        ]

        # Project imports kept local to __init__ as in the original code
        # (NOTE(review): presumably to avoid an import cycle — confirm).
        from model.visual.sam2.build_sam import build_sam2_visual_predictor

        self.v_model = build_sam2_visual_predictor(
            self.param.sam_config_path,
            self.param.backbone_weight,
            apply_postprocessing=True,
            mode='train',
        )
        self._transforms = SAM2Transforms(
            resolution=self.v_model.image_size,
            mask_threshold=mask_threshold,
            max_hole_area=max_hole_area,
            max_sprinkle_area=max_sprinkle_area,
        )

        from model.aural_fuser import AuralFuser

        self.aural_fuser = AuralFuser(hyp_param=self.param)

    def _prepare_backbone_features(self, backbone_out):
        """Select the last ``v_model.num_feature_levels`` levels and flatten them.

        Args:
            backbone_out: Dict with equal-length lists under ``"backbone_fpn"``
                and ``"vision_pos_enc"``. The last two dims of each entry are
                treated as spatial (H, W). NOTE(review): entries are assumed to
                be (B, C, H, W) — confirm.

        Returns:
            ``(backbone_out, vision_feats, vision_pos_embeds, feat_sizes)``:
            a shallow copy of the input dict, the selected feature maps and
            positional encodings reshaped with ``flatten(2).permute(2, 0, 1)``
            (i.e. (H*W, B, C) for a (B, C, H, W) input), and the (H, W) size
            of each selected level.
        """
        backbone_out = backbone_out.copy()
        assert len(backbone_out["backbone_fpn"]) == len(backbone_out["vision_pos_enc"])
        assert len(backbone_out["backbone_fpn"]) >= self.v_model.num_feature_levels

        num_levels = self.v_model.num_feature_levels
        feature_maps = backbone_out["backbone_fpn"][-num_levels:]
        pos_maps = backbone_out["vision_pos_enc"][-num_levels:]

        feat_sizes = [(pos.shape[-2], pos.shape[-1]) for pos in pos_maps]
        vision_feats = [feat.flatten(2).permute(2, 0, 1) for feat in feature_maps]
        vision_pos_embeds = [pos.flatten(2).permute(2, 0, 1) for pos in pos_maps]
        return backbone_out, vision_feats, vision_pos_embeds, feat_sizes

    def forward_frame(self, frame_):
        """Resize ``frame_`` to ``image_size`` x ``image_size`` and run the SAM2 image encoder.

        Bilinear ``torch.nn.functional.interpolate`` requires a 4-D
        (N, C, H, W) input. Returns whatever ``v_model.image_encoder`` returns.
        """
        size = (self.param.image_size, self.param.image_size)
        frame = torch.nn.functional.interpolate(
            frame_, size, mode='bilinear', align_corners=False, antialias=True
        )
        return self.v_model.image_encoder(frame)

    def forward(self, frames, spect, prompts, sam_process=False):
        """Fuse audio into FPN features, then run SAM2 tracking without prompts.

        Args:
            frames: Video frames; ``frames.shape[0]`` is used as the number of frames.
            spect: Audio input handed to ``self.aural_fuser`` (presumably a
                spectrogram — confirm).
            prompts: Not used by this method.
            sam_process: Reserved for the prompt path; not used by this method.

        Returns:
            ``(outputs, proj_feats)``: the result of
            ``v_model.forward_tracking_wo_prompt`` and the third element
            returned by ``self.aural_fuser``.
        """
        backbone_feats = self.v_model.forward_image(frames, pre_compute=False)
        visual_resfeats, audio_resfeats, proj_feats = self.aural_fuser(backbone_feats, spect)
        # Both residual lists are reversed before reaching the tracker
        # (NOTE(review): presumably the fuser emits levels in the opposite
        # order to what forward_tracking_wo_prompt expects — confirm).
        av_feats = (visual_resfeats[::-1], audio_resfeats[::-1])

        backbone_feats = self.v_model.precompute_high_res_features(backbone_feats)
        num_frames = frames.shape[0]
        # cond_frame is the middle frame while training and frame 0 otherwise.
        # Floor division keeps the index an int without a float round-trip.
        cond_frame = num_frames // 2 if self.training else 0
        backbone_feats = self.v_model.dont_prepare_prompt_inputs(
            backbone_feats, num_frames=num_frames, cond_frame=cond_frame
        )
        outputs = self.v_model.forward_tracking_wo_prompt(backbone_feats, audio_res=av_feats)
        return outputs, proj_feats

    @property
    def device(self) -> torch.device:
        """Device reported by the wrapped SAM2 model (``v_model.device``)."""
        return self.v_model.device

    def freeze_sam_parameters(self):
        """Put the SAM2 model in eval mode and disable gradients on all its parameters.

        NOTE: ``torch.nn.Module.train()`` recurses into submodules, so a later
        ``self.train()`` switches ``v_model`` back to training mode; only the
        ``requires_grad`` flags set here persist across that call.
        """
        self.v_model.eval()
        for parameter in self.v_model.parameters():
            parameter.requires_grad = False