File size: 6,350 Bytes
14114e8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 |
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
"""
Misc functions, including distributed helpers.
"""
import collections
import re
from dataclasses import dataclass, field as field_ptr_behaviour, fields, is_dataclass
from typing import Any, get_args, get_origin, List, Mapping, Optional, Sequence, Union
import torch
# A field of this type starts life as a plain Python list and is packed into a
# torch.Tensor in place by `convert_my_tensors` (dtype taken from the matching
# `<field>__type` class attribute of the owning dataclass).
MyTensor = Union[torch.Tensor, List[Any]]
def interpolate(
    input, size=None, scale_factor=None, mode="nearest", align_corners=None
):
    # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
    """
    Drop-in replacement for ``torch.nn.functional.interpolate`` that also
    accepts inputs whose channel dimension is empty.
    """
    resize = torch.nn.functional.interpolate
    # Fast path: anything with actual elements is handled by pytorch directly.
    if input.numel() > 0:
        return resize(input, size, scale_factor, mode, align_corners)
    assert (
        input.shape[0] != 0 or input.shape[1] != 0
    ), "At least one of the two first dimensions must be non zero"
    if input.shape[1] != 0:
        # Only the batch dimension is empty — pytorch supports that natively.
        return resize(input, size, scale_factor, mode, align_corners)
    # A zero-sized channel dimension is not supported by pytorch, so swap it
    # with the batch dimension (faking an empty batch), resize, and swap back.
    swapped = input.transpose(0, 1)
    resized = resize(swapped, size, scale_factor, mode, align_corners)
    return resized.transpose(0, 1)
@dataclass
class BatchedPointer:
    """Batched pointers associating queries with objects across stages.

    Each ``MyTensor`` field is built as a Python list and later packed into a
    tensor by ``convert_my_tensors``.  The ``<name>__type`` class attributes
    are deliberately left un-annotated so the dataclass machinery ignores
    them; they only record the dtype each field is converted to.
    """

    # Stage index for each pointer — presumably indexes into the find stages
    # of the datapoint; TODO confirm against the collation code.
    stage_ids: MyTensor
    stage_ids__type = torch.long
    # Query index for each pointer.
    query_ids: MyTensor
    query_ids__type = torch.long
    # Object id each pointer resolves to.
    object_ids: MyTensor
    object_ids__type = torch.long
    # Boolean mask over pointer entries (looks like a padding/validity mask;
    # verify against the consumer).
    ptr_mask: MyTensor
    ptr_mask__type = torch.bool
    # Categorical type of each pointer.
    ptr_types: MyTensor
    ptr_types__type = torch.long
@dataclass
class FindStage:
    """Inputs for one "find" stage of a batch.

    ``MyTensor`` fields start as Python lists and are packed into tensors by
    ``convert_my_tensors`` using the dtype in the matching ``<name>__type``
    attribute.  Note: ``input_boxes`` and ``input_boxes_label`` are stacked
    along dim 1 (not dim 0) by that conversion.
    """

    # Image ids this stage refers to.
    img_ids: MyTensor
    img_ids__type = torch.long
    # Text ids this stage refers to.
    text_ids: MyTensor
    text_ids__type = torch.long
    # Box prompts — assumed normalized like the targets' CxCywh boxes; TODO confirm.
    input_boxes: MyTensor
    input_boxes__type = torch.float
    # Validity mask for the (padded) box prompts.
    input_boxes_mask: MyTensor
    input_boxes_mask__type = torch.bool
    # Label associated with each box prompt.
    input_boxes_label: MyTensor
    input_boxes_label__type = torch.long
    # Point prompts.
    input_points: MyTensor
    input_points__type = torch.float
    # Validity mask for the (padded) point prompts.
    input_points_mask: MyTensor
    input_points_mask__type = torch.bool
    # We track the object ids referred to by this query.
    # This is beneficial for tracking in videos without the need for pointers.
    object_ids: Optional[List[List]] = None  # List of objects per query
@dataclass
class BatchedFindTarget:
    """Ground-truth targets for the find queries of a batch.

    ``MyTensor`` fields are converted in place by ``convert_my_tensors``;
    the ``<name>__type`` attributes (non-fields) give the target dtype.
    Boxes exist in both packed (``boxes``) and padded (``boxes_padded``)
    layouts; object ids mirror that duality.
    """

    # The number of boxes in each find query
    num_boxes: MyTensor
    num_boxes__type = torch.long
    # Target boxes in normalized CxCywh format
    boxes: MyTensor
    boxes__type = torch.float
    # Target boxes in normalized CxCywh format but in padded representation
    # as used in BinaryHungarianMatcherV2 (unlike the packed ones in `boxes`)
    boxes_padded: MyTensor
    boxes_padded__type = torch.float
    # For hybrid matching, we repeat the boxes
    repeated_boxes: MyTensor
    repeated_boxes__type = torch.float
    # Target Segmentation masks
    segments: Optional[MyTensor]
    segments__type = torch.bool
    # Target Semantic Segmentation masks
    semantic_segments: Optional[MyTensor]
    semantic_segments__type = torch.bool
    # Per-segment validity flags; exact semantics defined by the consumer.
    is_valid_segment: Optional[MyTensor]
    is_valid_segment__type = torch.bool
    # Whether annotations are exhaustive for each query
    is_exhaustive: MyTensor
    is_exhaustive__type = torch.bool
    # The object id for each ground-truth box, in both packed and padded representations
    object_ids: MyTensor
    object_ids__type = torch.long
    object_ids_padded: MyTensor
    object_ids_padded__type = torch.long
@dataclass
class BatchedInferenceMetadata:
    """All metadata required to post-process a find stage.

    ``MyTensor`` fields are converted in place by ``convert_my_tensors``
    using the dtype recorded in the ``<name>__type`` class attributes.
    """

    # Coco id that corresponds to the "image" for evaluation by the coco evaluator
    coco_image_id: MyTensor
    coco_image_id__type = torch.long
    # id in the original dataset, such that we can use the original evaluator
    original_image_id: MyTensor
    original_image_id__type = torch.long
    # Original category id (if we want to use the original evaluator)
    original_category_id: MyTensor
    original_category_id__type = torch.int
    # Size of the raw image (height, width)
    original_size: MyTensor
    original_size__type = torch.long
    # id of the object in the media (track_id for a video)
    object_id: MyTensor
    object_id__type = torch.long
    # index of the frame in the media (0 in the case of a single-frame media)
    frame_index: MyTensor
    frame_index__type = torch.long
    # Adding for relations inference
    # get_text_input: List[Optional[str]]
    # Adding for TA conditional inference
    # NOTE: plain List field (not MyTensor) — left untouched by convert_my_tensors.
    is_conditioning_only: List[Optional[bool]]
@dataclass
class BatchedDatapoint:
    """A fully collated batch: images, prompt text, and per-stage inputs,
    targets and metadata.

    NOTE(review): ``convert_my_tensors`` only recurses into field *values*
    that are dataclasses — it does not walk into these ``List[...]`` fields,
    so the nested stages are presumably converted individually by the caller;
    verify against the collation code.
    """

    # Batched image tensor.
    img_batch: torch.Tensor
    # Text inputs — presumably one string per find query; TODO confirm.
    find_text_batch: List[str]
    # Per-stage model inputs.
    find_inputs: List[FindStage]
    # Per-stage ground-truth targets.
    find_targets: List[BatchedFindTarget]
    # Per-stage post-processing metadata.
    find_metadatas: List[BatchedInferenceMetadata]
    # Original images, kept only when provided (e.g. for visualization).
    raw_images: Optional[List[Any]] = None
def convert_my_tensors(obj):
    """Convert every ``MyTensor`` field of a dataclass to a tensor, in place.

    For each field annotated ``MyTensor`` (or ``Optional[MyTensor]``) whose
    value is not ``None``, the value is replaced by a tensor with the dtype
    taken from the sibling ``<field>__type`` class attribute:

    * a non-empty list of tensors is ``torch.stack``-ed — along dim 1 for
      ``input_boxes`` / ``input_boxes_label``, along dim 0 otherwise;
    * any other value (e.g. a list of scalars, or an empty list) goes through
      ``torch.as_tensor``.

    Field values that are themselves dataclasses are converted recursively.
    All other fields are left untouched.  Returns ``obj`` for chaining.
    """

    def _strip_optional(tp):
        # Union[X, ..., None] -> Union[X, ...]; robust to NoneType appearing
        # anywhere in the Union args (their order is not guaranteed).
        if get_origin(tp) is Union and type(None) in get_args(tp):
            return Union[tuple(a for a in get_args(tp) if a is not type(None))]
        return tp

    # Fields stacked along dim 1 rather than dim 0 (see FindStage).
    stack_dim_one = ("input_boxes", "input_boxes_label")

    for f in fields(obj):
        value = getattr(obj, f.name)
        if is_dataclass(value):
            convert_my_tensors(value)
            continue
        if _strip_optional(f.type) != MyTensor or value is None:
            continue
        dtype = getattr(obj, f.name + "__type")
        if len(value) and isinstance(value[0], torch.Tensor):
            dim = 1 if f.name in stack_dim_one else 0
            setattr(obj, f.name, torch.stack(value, dim=dim).to(dtype))
        else:
            setattr(obj, f.name, torch.as_tensor(value, dtype=dtype))
    return obj
|