import logging
from typing import Any, Mapping

import torch
from torch import nn

from mmaudio.ext.synchformer.motionformer import MotionFormer


class Synchformer(nn.Module):
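    """Visual frontend of Synchformer.

    Only the MotionFormer feature extractor is kept; it maps batches of video
    segments to per-segment visual features used for audio-visual
    synchronization.
    """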

    def __init__(self):
        super().__init__()

        self.vfeat_extractor = MotionFormer(extract_features=True,
                                            factorize_space_time=True,
                                            agg_space_module='TransformerEncoderLayer',
                                            agg_time_module='torch.nn.Identity',
                                            add_global_repr=False)

    def forward(self, vis):
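        """Extract per-segment visual features.

        Args:
            vis: video tensor of shape (B, S, Tv, C, H, W), i.e. batch size,
                number of segments, frames per segment, channels, height, width.

        Returns:
            Per-segment features from the MotionFormer extractor.
        """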
        B, S, Tv, C, H, W = vis.shape
        # reorder to (B, S, C, Tv, H, W) for the feature extractor
        vis = vis.permute(0, 1, 3, 2, 4, 5)

        vis = self.vfeat_extractor(vis)
        return vis

    def load_state_dict(self, sd: Mapping[str, Any], strict: bool = True):
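        # The full Synchformer checkpoint may contain more than the visual
        # frontend; keep only the 'vfeat_extractor' entries so strict loading
        # matches the modules defined here.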
        sd = {k: v for k, v in sd.items() if k.startswith('vfeat_extractor')}

        return super().load_state_dict(sd, strict)


if __name__ == "__main__":
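    # Smoke test: load the exported weights and push a dummy batch of
    # 2 clips x 7 segments x 16 RGB frames at 224x224 through the extractor.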
    model = Synchformer().cuda().eval()
    sd = torch.load('./ext_weights/synchformer_state_dict.pth', weights_only=True)
    model.load_state_dict(sd)

    vid = torch.randn(2, 7, 16, 3, 224, 224).cuda()
    features = model(vid).detach().cpu()
    print(features.shape)