# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Dict, List, Optional

import torch
from torch import Tensor


def interpolate(
    input: Tensor,
    size: Optional[List[int]] = None,
    scale_factor: Optional[float] = None,
    mode: str = "nearest",
    align_corners: Optional[bool] = None,
) -> Tensor:
    """Equivalent to nn.functional.interpolate, but with support for empty channel sizes.

    Args:
        input: Tensor to resize; an empty batch or channel dimension is allowed.
        size: Target output size, forwarded to ``torch.nn.functional.interpolate``.
        scale_factor: Spatial size multiplier, forwarded as-is.
        mode: Interpolation algorithm (default ``"nearest"``).
        align_corners: Corner-alignment flag for the linear-family modes.

    Returns:
        The resized tensor; any empty leading dimension is preserved.
    """
    if input.numel() > 0:
        # Non-empty input: regular fast path.
        return torch.nn.functional.interpolate(
            input, size, scale_factor, mode, align_corners
        )

    assert (
        input.shape[0] != 0 or input.shape[1] != 0
    ), "At least one of the two first dimensions must be non zero"

    if input.shape[1] == 0:
        # Pytorch doesn't support null dimension on the channel dimension, so
        # we transpose to fake a null batch dim.
        return torch.nn.functional.interpolate(
            input.transpose(0, 1), size, scale_factor, mode, align_corners
        ).transpose(0, 1)

    # empty batch dimension is now supported in pytorch
    return torch.nn.functional.interpolate(
        input, size, scale_factor, mode, align_corners
    )


# Keys whose values are plain Python objects rather than tensors, so they must
# not be moved with ``.to(device)``. Hoisted to module level as a frozenset so
# the membership test in targets_to is O(1) and the collection is built once
# instead of on every call.
_NON_TENSOR_KEYS = frozenset(
    {
        "questionId",
        "tokens_positive",
        "tokens",
        "dataset_name",
        "sentence_id",
        "original_img_id",
        "nb_eval",
        "task_id",
        "original_id",
    }
)


def targets_to(targets: List[Dict[str, Any]], device) -> List[Dict[str, Any]]:
    """Moves the target dicts to the given device.

    Tensor values are moved with ``.to(device)``; values under the keys in
    ``_NON_TENSOR_KEYS`` are kept unchanged, and the ``"caption"`` and
    ``"answer_type_mask"`` entries are dropped entirely.

    Args:
        targets: List of per-sample target dictionaries.
        device: Destination passed to ``Tensor.to`` (e.g. a ``torch.device``).

    Returns:
        A new list of new dicts; the input dicts are not mutated.
    """
    return [
        {
            k: v.to(device) if k not in _NON_TENSOR_KEYS else v
            for k, v in t.items()
            if k != "caption" and k != "answer_type_mask"
        }
        for t in targets
    ]