Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,562 Bytes
c8b42eb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 |
"""
Base Prediction Head Class for UniCeption
"""
from dataclasses import dataclass
from typing import Dict, List, Tuple
import torch
import torch.nn as nn
from jaxtyping import Float
from torch import Tensor
@dataclass
class PredictionHeadInput:
    """
    Input container holding a single spatial feature map for a prediction head.
    """

    # Feature map annotated as (batch_size, feat_dim, feat_height, feat_width).
    last_feature: Float[Tensor, "batch_size feat_dim feat_height feat_width"]
@dataclass
class PredictionHeadLayeredInput:
    """
    Input container bundling a list of spatial feature maps with a target output shape.
    """

    # Feature maps, each annotated as (batch_size, feat_dim, feat_height, feat_width).
    list_features: List[Float[Tensor, "batch_size feat_dim feat_height feat_width"]]

    # Pair of integers; presumably the (height, width) of the desired output -- confirm with the consuming head.
    target_output_shape: Tuple[int, int]
@dataclass
class PredictionHeadTokenInput:
    """
    Input container holding a token-shaped (non-spatial) feature tensor for a prediction head.
    """

    # Token features annotated as (batch_size, feat_dim, num_tokens).
    last_feature: Float[Tensor, "batch_size feat_dim num_tokens"]
@dataclass
class PixelTaskOutput:
    """
    Dense, pixel-wise task output stored in BCHW layout.

    The spatial resolution matches that of the input image.
    """

    # Decoded tensor annotated as (batch_size, output_channels, height, width).
    decoded_channels: Float[Tensor, "batch_size output_channels height width"]
@dataclass
class SummaryTaskOutput:
    """
    Per-image summary output: one latent vector for each image, stored in BC layout.
    """

    # Decoded tensor annotated as (batch_size, output_channels).
    decoded_channels: Float[Tensor, "batch_size output_channels"]
@dataclass
class AdaptorInput:
    """
    Input container handed to an adaptor: a channel slice plus an output shape.
    """

    # Channel slice annotated as (batch_size, sliced_channels, height, width).
    adaptor_feature: Float[Tensor, "batch_size sliced_channels height width"]

    # Pair of integers; by its name the (height, width) of the adaptor output.
    output_shape_hw: Tuple[int, int]
@dataclass
class AdaptorOutput:
    """
    Generic adaptor result wrapping a single tensor.
    """

    # Tensor annotated as (batch_size, sliced_channels, ...); trailing dimensions are left unconstrained.
    value: Float[Tensor, "batch_size sliced_channels ..."]
@dataclass
class PredictionHeadOutput:
    """
    Output container of a prediction head: a mapping from string keys to adaptor outputs.
    """

    # String-keyed adaptor results; presumably keyed by adaptor name -- confirm with the producing head.
    adaptor_output: Dict[str, AdaptorOutput]
@dataclass
class MaskAdaptorOutput:
    """
    Adaptor result carrying single-channel mask logits together with a mask tensor.
    """

    # Logits annotated as (batch_size, 1, height, width).
    logits: Float[Tensor, "batch_size 1 height width"]

    # Mask annotated as (batch_size, 1, height, width); presumably derived from `logits` -- confirm with the adaptor.
    mask: Float[Tensor, "batch_size 1 height width"]
@dataclass
class Covariance2DAdaptorOutput:
    """
    Adaptor result describing a per-pixel 2D covariance, its log determinant and its inverse.
    """

    # The 3 channels are s_x^2, s_y^2, and rho_xy.
    covariance: Float[Tensor, "batch_size 3 height width"]

    # Log determinant of the covariance matrix.
    log_det: Float[Tensor, "batch_size 1 height width"]

    # The channels are entries [0,0], [1,1], and [0,1] of the inverse covariance matrix.
    inv_covariance: Float[Tensor, "batch_size 3 height width"]
@dataclass
class RegressionAdaptorOutput:
    """
    Adaptor result carrying a dense regression value only.
    """

    # Regressed values annotated as (batch_size, sliced_channels, height, width).
    value: Float[Tensor, "batch_size sliced_channels height width"]
@dataclass
class RegressionWithConfidenceAdaptorOutput:
    """
    Adaptor result carrying a dense regression value and a single-channel confidence map.
    """

    # Regressed values annotated as (batch_size, sliced_channels, height, width).
    value: Float[Tensor, "batch_size sliced_channels height width"]

    # Confidence annotated as (batch_size, 1, height, width).
    confidence: Float[Tensor, "batch_size 1 height width"]
@dataclass
class RegressionWithMaskAdaptorOutput:
    """
    Adaptor result carrying a dense regression value plus single-channel mask logits and mask.
    """

    # Regressed values annotated as (batch_size, sliced_channels, height, width).
    value: Float[Tensor, "batch_size sliced_channels height width"]

    # Mask logits annotated as (batch_size, 1, height, width).
    logits: Float[Tensor, "batch_size 1 height width"]

    # Mask annotated as (batch_size, 1, height, width); presumably derived from `logits` -- confirm with the adaptor.
    mask: Float[Tensor, "batch_size 1 height width"]
@dataclass
class RegressionWithConfidenceAndMaskAdaptorOutput:
    """
    Adaptor result carrying a dense regression value, a confidence map, mask logits and a mask.
    """

    # Regressed values annotated as (batch_size, sliced_channels, height, width).
    value: Float[Tensor, "batch_size sliced_channels height width"]

    # Confidence annotated as (batch_size, 1, height, width).
    confidence: Float[Tensor, "batch_size 1 height width"]

    # Mask logits annotated as (batch_size, 1, height, width).
    logits: Float[Tensor, "batch_size 1 height width"]

    # Mask annotated as (batch_size, 1, height, width); presumably derived from `logits` -- confirm with the adaptor.
    mask: Float[Tensor, "batch_size 1 height width"]
class UniCeptionPredictionHeadBase(nn.Module):
    """
    Common parent of every UniCeption prediction head.

    It records the head's name and declares the forward interface; `forward`
    itself is left unimplemented here.
    """

    def __init__(self, name: str, *args, **kwargs):
        """
        Set up the underlying module and remember the head's name.

        Args:
            name (str): Name of the prediction head, stored as `self.name`.
            *args: Positional arguments passed through to `nn.Module.__init__`.
            **kwargs: Keyword arguments passed through to `nn.Module.__init__`.
        """
        super().__init__(*args, **kwargs)

        self.name: str = name

    def forward(self, head_input: PredictionHeadInput) -> PredictionHeadOutput:
        """
        Interface that concrete prediction heads are expected to implement.

        Args:
            head_input (PredictionHeadInput): Input to the prediction head.

        Returns:
            head_output (PredictionHeadOutput): Output of the prediction head.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError
class UniCeptionAdaptorBase(nn.Module):
    """
    Common parent of every UniCeption adaptor.

    It records the adaptor's name and the number of channels it requires, and
    declares the forward interface; `forward` itself is left unimplemented here.
    """

    def __init__(self, name: str, required_channels: int, *args, **kwargs):
        """
        Set up the underlying module and remember the adaptor's name and channel count.

        Args:
            name (str): Name of the adaptor, stored as `self.name`.
            required_channels (int): Number of channels the adaptor requires,
                stored as `self.required_channels`.
            *args: Positional arguments passed through to `nn.Module.__init__`.
            **kwargs: Keyword arguments passed through to `nn.Module.__init__`.
        """
        super().__init__(*args, **kwargs)

        self.name: str = name
        self.required_channels: int = required_channels

    def forward(self, adaptor_input: AdaptorInput) -> AdaptorOutput:
        """
        Interface that concrete adaptors are expected to implement.

        Args:
            adaptor_input (AdaptorInput): Input to the adaptor.

        Returns:
            adaptor_output (AdaptorOutput): Output of the adaptor.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError
class AdaptorMap(nn.Module):
    """
    Route consecutive channel slices of one tensor to a set of named adaptors.

    The adaptors are registered in an `nn.ModuleDict` keyed by their `name`, and
    the input tensor is split along dim 1 into one slice per adaptor, sized by
    each adaptor's `required_channels`.
    """

    def __init__(self, *adaptors: UniCeptionAdaptorBase):
        """
        AdaptorMap slices the input tensor and passes it to the corresponding adaptors.

        Args:
            *adaptors (UniCeptionAdaptorBase): Adaptors to register, in the order
                in which they consume channels of the input tensor.

        Raises:
            ValueError: If two adaptors share the same name. Duplicates would
                overwrite each other in the module dict while still being counted
                in `required_channels`, leaving the two out of sync.
        """
        super().__init__()

        seen_names = set()
        duplicate_names = []
        for adaptor in adaptors:
            if adaptor.name in seen_names:
                duplicate_names.append(adaptor.name)
            seen_names.add(adaptor.name)
        if duplicate_names:
            raise ValueError(f"AdaptorMap requires unique adaptor names, got duplicates: {duplicate_names}")

        self.adaptors = nn.ModuleDict({adaptor.name: adaptor for adaptor in adaptors})
        # Total number of channels consumed by all registered adaptors.
        self.required_channels = sum(adaptor.required_channels for adaptor in adaptors)

    def forward(
        self,
        adaptor_input: AdaptorInput,
    ) -> Dict[str, AdaptorOutput]:
        """
        Run the input through the adaptors and return the output.

        Args:
            adaptor_input (AdaptorInput): Input to the adaptors. The tensor is read
                from `adaptor_feature`. For backward compatibility, objects that
                expose the tensor as `decoded_channels` instead are accepted too.

        Returns:
            Dict[str, AdaptorOutput]: Output of the adaptors, from adaptor name to AdaptorOutput.
        """
        # `AdaptorInput` names its tensor `adaptor_feature`; earlier code read
        # `decoded_channels`, so fall back to that name when the former is absent.
        features = getattr(adaptor_input, "adaptor_feature", None)
        if features is None:
            features = adaptor_input.decoded_channels

        # Split the channel dimension into one chunk per adaptor, in registration order.
        split_sizes = [adaptor.required_channels for adaptor in self.adaptors.values()]
        adaptor_features = torch.split(features, split_sizes, dim=1)

        # Forward the caller's requested output shape when the input carries one;
        # otherwise use each chunk's own spatial shape, as before.
        requested_shape_hw = getattr(adaptor_input, "output_shape_hw", None)

        result = {}
        for (adaptor_name, adaptor), chunk in zip(self.adaptors.items(), adaptor_features):
            shape_hw = chunk.shape[2:] if requested_shape_hw is None else requested_shape_hw
            result[adaptor_name] = adaptor(AdaptorInput(chunk, shape_hw))
        return result
|