File size: 3,892 Bytes
70af406
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

from abc import ABCMeta, abstractmethod

import torch
import torchvision
from core.vision_encoder import pe
from torch.nn.utils.rnn import pad_sequence

from sam_audio.model.config import (
    PerceptionEncoderConfig,
    VisionEncoderConfig,
)


class RescaleTransform:
    """Rescale a video tensor to a fixed spatial size.

    Note: despite resembling torchvision's ``Resize``, an ``int`` input does
    NOT preserve aspect ratio here — it produces a square output of
    ``(output_size, output_size)`` (see ``__init__``).

    Args:
        output_size (tuple or int): Desired output size ``(H, W)``. If an
            int ``s`` is given, the output is a square ``(s, s)``.
        interpolation: Enum-like object whose ``.value`` is a mode string
            accepted by ``torch.nn.functional.interpolate`` (e.g.
            ``torchvision.transforms.InterpolationMode``).
    """

    def __init__(self, output_size, interpolation):
        # Validate eagerly with a real exception (``assert`` is stripped
        # under ``python -O``).
        if not isinstance(output_size, (int, tuple)):
            raise TypeError(
                f"output_size must be int or tuple, got {type(output_size).__name__}"
            )
        self.output_size = output_size
        if isinstance(output_size, int):
            # Int input means a square target, not smaller-edge matching.
            self.output_size = (output_size, output_size)
        self.interpolation = interpolation

    def __call__(self, sample):
        """Resize ``sample`` of shape [T, C, H, W] to [T, C, *output_size].

        The input is cast to float because ``interpolate`` does not support
        integer dtypes for most modes.
        """
        return torch.nn.functional.interpolate(
            sample.float(), size=self.output_size, mode=self.interpolation.value
        )


class VisionEncoder(torch.nn.Module, metaclass=ABCMeta):
    """Abstract base class for video feature encoders.

    Subclasses implement :meth:`encode` (frames -> features) and
    :meth:`get_transform` (preprocessing applied per video before encoding).
    """

    def __init__(self, cfg: "VisionEncoderConfig"):
        super().__init__()
        # Max frames encoded per call; <= 0 disables chunking.
        self.batch_size = cfg.batch_size
        self.dim = cfg.dim
        self.transform = self.get_transform()

    @torch.no_grad()
    def forward(self, videos: list[torch.Tensor]) -> torch.Tensor:
        """Encode a list of videos, each a tensor of shape [T, C, H, W].

        Args:
            videos (list[torch.Tensor]): Input videos; T may differ per video.

        Returns:
            torch.Tensor: Encoded features, zero-padded along the time
                dimension to the longest video in the list.
        """
        features = [self._encode_video(self.transform(v)) for v in videos]
        return pad_sequence(features, batch_first=True, padding_value=0.0)

    def _encode_video(self, video: torch.Tensor) -> torch.Tensor:
        # Encode one (already transformed) video, chunking along time when
        # a positive batch_size is configured and the video exceeds it —
        # keeps peak memory bounded for long videos.
        num_frames = video.size(0)
        if self.batch_size <= 0 or num_frames <= self.batch_size:
            return self.encode(video)
        chunks = [
            self.encode(video[start : start + self.batch_size])
            for start in range(0, num_frames, self.batch_size)
        ]
        return torch.cat(chunks, dim=0)

    @abstractmethod
    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch of frames to feature vectors."""

    @abstractmethod
    def get_transform(self):
        """Return the preprocessing callable applied to each raw video."""


class PerceptionEncoder(VisionEncoder):
    """Vision encoder backed by a Perception Encoder (PE) CLIP model."""

    def __init__(self, cfg: "PerceptionEncoderConfig"):
        # These attributes must exist before super().__init__(), because the
        # base constructor calls self.get_transform(), which reads them.
        self.normalize_feature = cfg.normalize_feature
        self.interpolation_mode = cfg.interpolation_mode
        self.image_size = cfg.image_size
        super().__init__(cfg)
        self.model = pe.CLIP.from_config(cfg.name)

    def encode(self, x):
        """Encode a batch of frames with the CLIP image tower."""
        return self.model.encode_image(x, normalize=self.normalize_feature)

    def get_transform(self):
        """Build the preprocessing pipeline: resize, scale to [0, 1], normalize."""
        transforms = torchvision.transforms
        # Map the configured mode name onto torchvision's enum; reject
        # anything the enum doesn't define.
        try:
            interp = getattr(
                transforms.InterpolationMode, self.interpolation_mode.upper()
            )
        except AttributeError as err:
            raise ValueError(
                f"Unsupported interpolation_mode: {self.interpolation_mode}"
            ) from err

        return transforms.Compose(
            [
                transforms.Resize(
                    (self.image_size, self.image_size),
                    interpolation=interp,
                ),
                transforms.Lambda(lambda frames: frames.float() / 255.0),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5], inplace=True),
            ]
        )