| | |
| | |
| | |
| | |
| | |
| |
|
| | from typing import Any, Dict, List, Optional |
| |
|
| | import torch |
| | from torch import Tensor |
| |
|
| |
|
def interpolate(
    input: Tensor,
    size: Optional[List[int]] = None,
    scale_factor: Optional[float] = None,
    mode: str = "nearest",
    align_corners: Optional[bool] = None,
) -> Tensor:
    """
    Thin wrapper over ``torch.nn.functional.interpolate`` that special-cases
    tensors containing zero elements.

    Tensors with at least one element are forwarded unchanged. For a tensor
    with no elements, dimensions 0 and 1 must not both be empty; when
    dimension 1 is the empty one, dimensions 0 and 1 are swapped for the call
    and swapped back on the result.

    All arguments besides ``input`` are handed to
    ``torch.nn.functional.interpolate`` as-is.
    """
    if input.numel() == 0:
        # Only reached for tensors without elements: refuse the case where
        # both leading dimensions are empty.
        assert not (
            input.shape[0] == 0 and input.shape[1] == 0
        ), "At least one of the two first dimensions must be non zero"

        if input.shape[1] == 0:
            # Put the empty dimension first for the call, then restore the
            # original layout. NOTE(review): presumably the underlying op does
            # not accept an empty dimension 1 directly — confirm.
            swapped = input.transpose(0, 1)
            resized = torch.nn.functional.interpolate(
                swapped,
                size=size,
                scale_factor=scale_factor,
                mode=mode,
                align_corners=align_corners,
            )
            return resized.transpose(0, 1)

    # Either the tensor has elements, or only dimension 0 (or a trailing
    # dimension) is empty: forward the call directly.
    return torch.nn.functional.interpolate(
        input,
        size=size,
        scale_factor=scale_factor,
        mode=mode,
        align_corners=align_corners,
    )
| |
|
| |
|
def targets_to(targets: List[Dict[str, Any]], device):
    """Build new target dicts whose movable values have been sent to ``device``.

    For each dict in ``targets`` a new dict is produced:

    * entries keyed ``"caption"`` or ``"answer_type_mask"`` are omitted;
    * values stored under the pass-through keys listed below are kept as-is;
    * every other value is replaced by ``value.to(device)``.

    The input dicts themselves are not modified.
    """
    # Keys that never appear in the output.
    dropped = ("caption", "answer_type_mask")
    # Keys whose values are copied over without calling ``.to``.
    passthrough = frozenset(
        (
            "questionId",
            "tokens_positive",
            "tokens",
            "dataset_name",
            "sentence_id",
            "original_img_id",
            "nb_eval",
            "task_id",
            "original_id",
        )
    )

    relocated = []
    for target in targets:
        entry = {}
        for key, value in target.items():
            if key in dropped:
                continue
            entry[key] = value if key in passthrough else value.to(device)
        relocated.append(entry)
    return relocated
| |
|