# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

from abc import ABCMeta, abstractmethod

import torch
import torchvision
from core.vision_encoder import pe
from torch.nn.utils.rnn import pad_sequence

from sam_audio.model.config import (
    PerceptionEncoderConfig,
    VisionEncoderConfig,
)


class RescaleTransform(object):
    """Rescale the images in a sample to a given size.

    Args:
        output_size (tuple or int): Desired output size. If tuple, the output is
            matched to output_size. If int, the output is resized to
            (output_size, output_size).
    """

    def __init__(self, output_size, interpolation):
        assert isinstance(output_size, (int, tuple))
        self.output_size = output_size
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        self.interpolation = interpolation

    def __call__(self, sample):
        # sample: [T, C, H, W]
        sample = torch.nn.functional.interpolate(
            sample.float(), size=self.output_size, mode=self.interpolation.value
        )
        return sample


class VisionEncoder(torch.nn.Module, metaclass=ABCMeta):
    def __init__(self, cfg: VisionEncoderConfig):
        super().__init__()
        self.batch_size = cfg.batch_size
        self.dim = cfg.dim
        self.transform = self.get_transform()

    @torch.no_grad()
    def forward(self, videos: list[torch.Tensor]) -> torch.Tensor:
        """
        Encodes a list of input videos. Each element of the list is a video
        represented as a tensor of shape [T, C, H, W].

        Args:
            videos (list[torch.Tensor]): List of input video tensors to be processed.

        Returns:
            torch.Tensor: Encoded feature representations of the input tensors.
                The output is padded along the time dimension for variable-length
                videos.
        """
        result = []
        for video in videos:
            video = self.transform(video)
            if self.batch_size > 0 and video.size(0) > self.batch_size:
                # Encode long videos in chunks of batch_size frames to bound memory use.
                res = []
                for i in range(0, video.size(0), self.batch_size):
                    res.append(self.encode(video[i : i + self.batch_size]))
                result.append(torch.cat(res, dim=0))
            else:
                result.append(self.encode(video))
        return pad_sequence(result, batch_first=True, padding_value=0.0)

    @abstractmethod
    def encode(self, x: torch.Tensor) -> torch.Tensor:
        pass

    @abstractmethod
    def get_transform(self):
        pass


class PerceptionEncoder(VisionEncoder):
    def __init__(self, cfg: PerceptionEncoderConfig):
        # The base __init__ calls get_transform(), so set the fields it reads first.
        self.normalize_feature = cfg.normalize_feature
        self.interpolation_mode = cfg.interpolation_mode
        self.image_size = cfg.image_size
        super().__init__(cfg)
        self.model = pe.CLIP.from_config(cfg.name)

    def encode(self, x):
        image_features = self.model.encode_image(x, normalize=self.normalize_feature)
        return image_features

    def get_transform(self):
        T = torchvision.transforms
        try:
            interp = getattr(T.InterpolationMode, self.interpolation_mode.upper())
        except AttributeError as err:
            raise ValueError(
                f"Unsupported interpolation_mode: {self.interpolation_mode}"
            ) from err
        crop = [
            T.Resize(
                (self.image_size, self.image_size),
                interpolation=interp,
            )
        ]
        return T.Compose(
            crop
            + [
                T.Lambda(lambda x: x.float() / 255.0),
                T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5], inplace=True),
            ]
        )
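

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the library): a toy
# VisionEncoder subclass that exercises RescaleTransform and the chunked
# encode / pad_sequence behaviour of ``forward`` without loading a real PE
# checkpoint. The SimpleNamespace below is an assumed stand-in for a
# VisionEncoderConfig exposing only the two fields the base class reads
# (batch_size and dim); swap in a real config in actual use.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    class _MeanPoolEncoder(VisionEncoder):
        """Toy encoder: per-frame spatial mean, broadcast to ``dim`` features."""

        def encode(self, x: torch.Tensor) -> torch.Tensor:
            # x: [T, C, H, W] -> [T, dim]
            return x.float().mean(dim=(1, 2, 3)).unsqueeze(-1).repeat(1, self.dim)

        def get_transform(self):
            # Reuse the module's RescaleTransform; 64x64 is an arbitrary toy size.
            return RescaleTransform(
                64, torchvision.transforms.InterpolationMode.BILINEAR
            )

    encoder = _MeanPoolEncoder(SimpleNamespace(batch_size=4, dim=16)).eval()

    # Two uint8 videos of different lengths, each [T, C, H, W].
    videos = [
        torch.randint(0, 256, (10, 3, 120, 160), dtype=torch.uint8),
        torch.randint(0, 256, (6, 3, 120, 160), dtype=torch.uint8),
    ]
    print(encoder(videos).shape)  # padded to the longest clip: [2, 10, 16]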