# Uploaded via upload-large-folder tool (commit f0384a9, verified)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Dict, List, Optional
import torch
from torch import Tensor
def interpolate(
    input: Tensor,
    size: Optional[List[int]] = None,
    scale_factor: Optional[float] = None,
    mode: str = "nearest",
    align_corners: Optional[bool] = None,
) -> Tensor:
    """
    Drop-in replacement for ``nn.functional.interpolate`` that additionally
    accepts tensors whose batch or channel dimension is empty.
    """
    if input.numel() > 0:
        # Non-empty input: defer straight to the stock implementation.
        return torch.nn.functional.interpolate(
            input, size, scale_factor, mode, align_corners
        )

    assert (
        input.shape[0] != 0 or input.shape[1] != 0
    ), "At least one of the two first dimensions must be non zero"

    if input.shape[1] != 0:
        # Empty batch dimension is supported natively by recent PyTorch.
        return torch.nn.functional.interpolate(
            input, size, scale_factor, mode, align_corners
        )

    # Empty channel dimension: PyTorch rejects it, so swap channels with the
    # (supported) batch dimension, interpolate, then swap back.
    swapped = input.transpose(0, 1)
    resized = torch.nn.functional.interpolate(
        swapped, size, scale_factor, mode, align_corners
    )
    return resized.transpose(0, 1)
def targets_to(targets: List[Dict[str, Any]], device):
    """Return copies of the target dicts with tensor values moved to *device*.

    Keys listed in ``metadata_keys`` hold non-tensor metadata and are passed
    through unchanged; the "caption" and "answer_type_mask" entries are
    dropped entirely.
    """
    metadata_keys = {
        "questionId",
        "tokens_positive",
        "tokens",
        "dataset_name",
        "sentence_id",
        "original_img_id",
        "nb_eval",
        "task_id",
        "original_id",
    }
    skipped_keys = ("caption", "answer_type_mask")

    moved = []
    for target in targets:
        entry = {}
        for key, value in target.items():
            if key in skipped_keys:
                continue
            # Metadata stays as-is; everything else is assumed to be a tensor.
            entry[key] = value if key in metadata_keys else value.to(device)
        moved.append(entry)
    return moved