import torch
from typing import Optional
from torch.nn.modules import Transformer, TransformerEncoder
from torchvision.models import vit_b_16, VisionTransformer

H_DIM = 512
H_DIM_VT = 768
# Flattened encoder output width per sample: 197 tokens (1 class + 196 patches) x 768 dims.
# NOTE(review): the name suggests a batch of 16 is baked in; verify against callers.
N_OUTPUT_TOKENS_VT = 16 * 197


class SimpleZRegressionVisionTransformer(torch.nn.Module):
    """A ViT-B/16 transformer encoder with a lazily-built linear regression head.

    Supports three formulations (z is a scalar in (0, 1) via sigmoid):
      - X -> z               (``forward``)
      - y -> z               (``forward``)
      - (X, y_hat) -> z      (``forward_two_inputs``)
    """

    def __init__(self):
        super().__init__()
        self.vit = vit_b_16()
        # TODO: implement checkpoints for
        #   - ImageNet21K
        #   - JBL-300M
        # NOTE(review): annotated TransformerEncoder, but torchvision's ViT uses its
        # own Encoder class — confirm the intended type.
        self.encoder: TransformerEncoder = self.vit.encoder
        # Heads are built lazily on first forward pass, once the flattened
        # feature width is known. They are assigned to attributes of this
        # nn.Module, so their parameters register with the module tree.
        self.regression_head: Optional[torch.nn.Linear] = None
        self.regression_head_wide: Optional[torch.nn.Linear] = None

    def _encode(self, x: torch.Tensor) -> torch.Tensor:
        """Patchify, prepend the class token, run the encoder, flatten.

        Shape: (N, C, 224, 224) -> (N, 197 * 768).
        """
        # (N, 3, 224, 224) -> (N, 196, 768)
        x = self.vit._process_input(x)
        n = x.shape[0]
        # (N, 1, 768)
        batch_class_token = self.vit.class_token.expand(n, -1, -1)
        # (N, 197, 768)
        x = torch.cat([batch_class_token, x], dim=1)
        x = self.encoder(x)
        # Flatten tokens x dims into one feature vector per sample.
        return x.reshape(x.shape[0], -1)

    def forward(self, x):
        r"""Regress a scalar z in (0, 1) from a single image batch.

        Shape:
            - x: :math:`(N, C, H, W)` where `H = W = 224` by default.
        Returns:
            Tensor of shape :math:`(N, 1)`.
        """
        x = self._encode(x)
        if self.regression_head is None:
            # Lazy init; follow the input's device instead of hard-coding .cuda()
            # so CPU-only runs do not crash.
            self.regression_head = torch.nn.Linear(
                in_features=x.shape[1], out_features=1
            ).to(x.device)
        # (N, 1)
        return torch.sigmoid(self.regression_head(x))

    def forward_two_inputs(self, x, y_hat):
        r"""Formulation III: P(z | X, y_hat).

        Shape:
            - x: :math:`(N, D, H, W)` where `H = W = 224` by default.
            - y_hat: :math:`(N, C, H, W)` where `H = W = 224` by default.
        Returns:
            Tensor of shape :math:`(N, 1)`.
        """
        # HACK(experiment): to destroy any information carried by y_hat
        # (can the model learn to ignore it?), scale it by a tiny constant:
        #   y_hat = y_hat * 0.00001
        # TODO: x and y_hat are encoded separately but share one encoder —
        # revisit whether a joint encoding is preferable.
        x = self._encode(x)
        y_hat = self._encode(y_hat)
        # Concatenate both feature vectors: (N, 2 * 197 * 768)
        x = torch.cat([x, y_hat], dim=1)
        if self.regression_head_wide is None:
            # Lazy init of the 2x-width head, on the input's device.
            self.regression_head_wide = torch.nn.Linear(
                in_features=x.shape[1], out_features=1
            ).to(x.device)
        # (N, 1)
        return torch.sigmoid(self.regression_head_wide(x))

    @staticmethod
    def get(weights=None):
        """Factory hook; ``weights`` is currently ignored (no checkpoints yet)."""
        return SimpleZRegressionVisionTransformer()


class EnsembleZRegressionVisionTransformer(torch.nn.Module):
    """Two independent ViT trunks whose scalar predictions are averaged.

    - X     -> z1  (trunk 1)
    - y_hat -> z2  (trunk 2)
    - z = (z1 + z2) / 2
    """

    def __init__(self):
        super().__init__()
        self.vit_1 = vit_b_16()
        self.vit_2 = vit_b_16()
        self.encoder1: TransformerEncoder = self.vit_1.encoder
        self.encoder2: TransformerEncoder = self.vit_2.encoder
        # Built lazily on first forward (see _forward); persisted via setattr
        # so their parameters are registered and trained.
        self.regression_head_1: Optional[torch.nn.Linear] = None
        self.regression_head_2: Optional[torch.nn.Linear] = None

    def _forward(
        self,
        x: torch.Tensor,
        vit: VisionTransformer,
        encoder: TransformerEncoder,
        head: Optional[torch.nn.Linear],
        head_attr: Optional[str] = None,
    ) -> torch.Tensor:
        r"""Run one trunk: encode ``x`` and regress a scalar through ``head``.

        Shape:
            - x: :math:`(N, C, H, W)` where `H = W = 224` by default.
        Args:
            head_attr: attribute name under which a lazily-created head is
                stored on ``self``. BUG FIX: previously the new Linear was only
                bound to the local ``head`` variable, so a fresh, untrained head
                was created on every call and its parameters were never
                registered with the module.
        """
        # (N, 3, 224, 224) -> (N, 196, 768)
        x = vit._process_input(x)
        n = x.shape[0]
        # (N, 1, 768)
        batch_class_token = vit.class_token.expand(n, -1, -1)
        # (N, 197, 768)
        x = torch.cat([batch_class_token, x], dim=1)
        x = encoder(x)
        # (N, 197 * 768)
        x = x.reshape(x.shape[0], -1)
        if head is None:
            # Lazy init on the input's device (not hard-coded .cuda()).
            head = torch.nn.Linear(in_features=x.shape[1], out_features=1).to(
                x.device
            )
            if head_attr is not None:
                # Persist so the head survives across calls and its parameters
                # are picked up by optimizers / state_dict.
                setattr(self, head_attr, head)
        # (N, 1)
        return torch.sigmoid(head(x))

    def forward(self, x, y_hat):
        r"""Average the two trunk predictions.

        Shape:
            - x: :math:`(N, D, H, W)` where `H = W = 224` by default.
            - y_hat: :math:`(N, C, H, W)` where `H = W = 224` by default.
        """
        # TODO: learn per-trunk weights instead of a fixed mean:
        #   z_pred = z1 * lambda_1 + z2 * lambda_2
        z1 = self._forward(
            x, self.vit_1, self.encoder1, self.regression_head_1, "regression_head_1"
        )
        z2 = self._forward(
            y_hat, self.vit_2, self.encoder2, self.regression_head_2, "regression_head_2"
        )
        return (z1 + z2) / 2

    @staticmethod
    def get(weights=None):
        """Factory hook; ``weights`` is currently ignored.

        BUG FIX: previously returned SimpleZRegressionVisionTransformer.
        """
        return EnsembleZRegressionVisionTransformer()


if __name__ == "__main__":
    pass