| |
| """ |
| Misc functions, including distributed helpers. |
| """ |
|
|
| import collections |
| import re |
|
|
| from dataclasses import dataclass, field as field_ptr_behaviour, fields, is_dataclass |
| from typing import Any, get_args, get_origin, List, Mapping, Optional, Sequence, Union |
|
|
| import torch |
|
|
|
|
# Type alias for dataclass fields that may hold either a tensor or a list
# (for example a list of tensors, or nested Python numbers) that is later
# turned into a single tensor. convert_my_tensors() compares each field's
# annotation against this alias, so only fields annotated with MyTensor
# (or an Optional/Union of MyTensor and None) are converted.
MyTensor = Union[torch.Tensor, List[Any]]
|
|
|
|
def interpolate(
    input, size=None, scale_factor=None, mode="nearest", align_corners=None
):
    """
    Drop-in replacement for ``torch.nn.functional.interpolate`` that also
    accepts empty inputs, including ones with an empty channel dimension.

    Non-empty tensors are forwarded unchanged. For an empty tensor, at least
    one of the first two dimensions must be non-zero. When the second
    (channel) dimension is the empty one, the first two dimensions are swapped
    before the call and swapped back on the result.

    Raises:
        AssertionError: if ``input`` is empty and both of its first two
            dimensions are zero.
    """
    resize = torch.nn.functional.interpolate
    if input.numel() == 0:
        assert (
            input.shape[0] != 0 or input.shape[1] != 0
        ), "At least one of the two first dimensions must be non zero"
        if input.shape[1] == 0:
            # NOTE(review): presumably works around the interpolation kernels
            # rejecting a zero-size channel dimension while accepting a
            # zero-size leading dimension; confirm for the PyTorch version used.
            swapped = input.transpose(0, 1)
            return resize(
                swapped, size, scale_factor, mode, align_corners
            ).transpose(0, 1)
    # Non-empty input, or an empty input whose channel dimension is non-zero.
    return resize(input, size, scale_factor, mode, align_corners)
|
|
|
|
@dataclass
class BatchedPointer:
    """Batched pointer data.

    Every ``MyTensor`` field is paired with an unannotated ``<name>__type``
    class attribute. Because it has no annotation, ``@dataclass`` does not
    treat it as a field. It holds the dtype used when the field is converted
    to a tensor by ``convert_my_tensors``.
    """

    stage_ids: MyTensor
    stage_ids__type = torch.int64

    query_ids: MyTensor
    query_ids__type = torch.int64

    object_ids: MyTensor
    object_ids__type = torch.int64

    ptr_mask: MyTensor
    ptr_mask__type = torch.bool

    ptr_types: MyTensor
    ptr_types__type = torch.int64
|
|
|
|
@dataclass
class FindStage:
    """Inputs of a find stage (the element type of ``BatchedDatapoint.find_inputs``).

    Every ``MyTensor`` field is paired with an unannotated ``<name>__type``
    class attribute. Because it has no annotation, ``@dataclass`` does not
    treat it as a field. It holds the dtype used when the field is converted
    to a tensor by ``convert_my_tensors``.
    """

    # Image / text ids.
    img_ids: MyTensor
    img_ids__type = torch.int64
    text_ids: MyTensor
    text_ids__type = torch.int64

    # Box inputs. convert_my_tensors stacks lists of tensors for
    # input_boxes and input_boxes_label along dim 1 (dim 0 for other fields).
    input_boxes: MyTensor
    input_boxes__type = torch.float32
    input_boxes_mask: MyTensor
    input_boxes_mask__type = torch.bool
    input_boxes_label: MyTensor
    input_boxes_label__type = torch.int64

    # Point inputs.
    input_points: MyTensor
    input_points__type = torch.float32
    input_points_mask: MyTensor
    input_points_mask__type = torch.bool

    # Not a MyTensor field, so it is never converted. It is the only field
    # with a default, which is why it must stay last.
    object_ids: Union[List[List], None] = None
|
|
|
|
@dataclass
class BatchedFindTarget:
    """Targets of a find stage (the element type of ``BatchedDatapoint.find_targets``).

    Every ``MyTensor`` field is paired with an unannotated ``<name>__type``
    class attribute. Because it has no annotation, ``@dataclass`` does not
    treat it as a field. It holds the dtype used when the field is converted
    to a tensor by ``convert_my_tensors``.

    ``segments``, ``semantic_segments`` and ``is_valid_segment`` accept
    ``None``. They have no default, so callers must still pass them
    explicitly.
    """

    num_boxes: MyTensor
    num_boxes__type = torch.int64

    # Box tensors, all float.
    boxes: MyTensor
    boxes__type = torch.float32
    boxes_padded: MyTensor
    boxes_padded__type = torch.float32
    repeated_boxes: MyTensor
    repeated_boxes__type = torch.float32

    # Nullable, boolean segment data.
    segments: Union[MyTensor, None]
    segments__type = torch.bool
    semantic_segments: Union[MyTensor, None]
    semantic_segments__type = torch.bool
    is_valid_segment: Union[MyTensor, None]
    is_valid_segment__type = torch.bool

    is_exhaustive: MyTensor
    is_exhaustive__type = torch.bool

    # Object ids, plain and padded.
    object_ids: MyTensor
    object_ids__type = torch.int64
    object_ids_padded: MyTensor
    object_ids_padded__type = torch.int64
|
|
|
|
@dataclass
class BatchedInferenceMetadata:
    """Metadata needed to post-process the output of a find stage.

    Every ``MyTensor`` field is paired with an unannotated ``<name>__type``
    class attribute. Because it has no annotation, ``@dataclass`` does not
    treat it as a field. It holds the dtype used when the field is converted
    to a tensor by ``convert_my_tensors``.
    """

    coco_image_id: MyTensor
    coco_image_id__type = torch.int64

    original_image_id: MyTensor
    original_image_id__type = torch.int64

    # NOTE: int32, unlike the int64 fields around it.
    original_category_id: MyTensor
    original_category_id__type = torch.int32

    original_size: MyTensor
    original_size__type = torch.int64

    object_id: MyTensor
    object_id__type = torch.int64

    frame_index: MyTensor
    frame_index__type = torch.int64

    # Plain list rather than MyTensor, so convert_my_tensors leaves it as-is.
    is_conditioning_only: List[Union[bool, None]]
|
|
|
|
@dataclass
class BatchedDatapoint:
    """Top-level batch container.

    It holds the image batch tensor, the find texts, and lists of
    ``FindStage`` inputs, ``BatchedFindTarget`` targets and
    ``BatchedInferenceMetadata`` entries. ``raw_images`` is optional and
    defaults to ``None``.
    """

    img_batch: torch.Tensor
    find_text_batch: List[str]
    find_inputs: List[FindStage]
    find_targets: List[BatchedFindTarget]
    find_metadatas: List[BatchedInferenceMetadata]
    raw_images: Union[List[Any], None] = None
|
|
|
|
def convert_my_tensors(obj):
    """Convert the ``MyTensor``-annotated fields of a dataclass to tensors, in place.

    Nested dataclass instances are converted recursively. A field is converted
    when its annotation is ``MyTensor`` or a ``Union`` of ``MyTensor`` with
    ``None``, wherever ``None`` appears in the union. The target dtype is read
    from the class attribute named ``<field name>__type``. Values are handled
    as follows:

    * a value that is already a tensor is only cast to the target dtype, so
      calling this function again on an already-converted object is safe;
    * a non-empty sequence whose first element is a tensor is combined with
      ``torch.stack``, along dim 1 for ``input_boxes`` / ``input_boxes_label``
      and along dim 0 otherwise;
    * anything else (nested lists of numbers, empty sequences) goes through
      ``torch.as_tensor``.

    Fields whose value is ``None`` and fields with any other annotation are
    left untouched.

    Args:
        obj: a dataclass instance. It is modified in place.

    Returns:
        ``obj`` itself.
    """
    none_type = type(None)
    # Fields whose lists of tensors are stacked along dim 1 instead of dim 0.
    stack_on_dim1 = ("input_boxes", "input_boxes_label")

    for field in fields(obj):
        value = getattr(obj, field.name)

        # is_dataclass() is also True for dataclass *classes*; only recurse
        # into instances.
        if is_dataclass(value) and not isinstance(value, type):
            convert_my_tensors(value)
            continue

        field_type = field.type
        if get_origin(field_type) is Union and none_type in get_args(field_type):
            # Drop NoneType wherever it sits, so Optional[MyTensor] and
            # Union[None, ...] both reduce to MyTensor.
            field_type = Union[
                tuple(arg for arg in get_args(field_type) if arg is not none_type)
            ]

        if field_type != MyTensor or value is None:
            continue

        dtype = getattr(obj, field.name + "__type")
        if isinstance(value, torch.Tensor):
            # Already converted (e.g. a second call): only enforce the dtype.
            converted = value.to(dtype)
        elif len(value) and isinstance(value[0], torch.Tensor):
            stack_dim = 1 if field.name in stack_on_dim1 else 0
            converted = torch.stack(value, dim=stack_dim).to(dtype)
        else:
            converted = torch.as_tensor(value, dtype=dtype)
        setattr(obj, field.name, converted)
    return obj
|
|