# NOTE(review): removed "Spaces: Sleeping" lines — HuggingFace Spaces page
# residue from extraction, not part of the source file.
import torch
from typing import Optional
from torch.nn.modules import Transformer, TransformerEncoder
from torchvision.models import vit_b_16, VisionTransformer

H_DIM = 512
H_DIM_VT = 768
N_OUTPUT_TOKENS_VT = 16 * 197
class SimpleZRegressionVisionTransformer(torch.nn.Module):
    """
    A ViT-B/16 encoder followed by a lazily-created linear regression head,
    squashed to (0, 1) with a sigmoid.

    Supported formulations:
      - X -> z               (``forward``)
      - y -> z               (``forward``)
      - (X, y_hat) -> z      (``forward_two_inputs``)

    NOTE(review): the regression heads are created on the *first* forward
    pass (once the flattened encoder width is known), so an optimizer built
    before that first call will not include their parameters — confirm the
    training loop runs one forward before constructing the optimizer.
    """

    def __init__(self):
        super().__init__()
        self.vit = vit_b_16()
        # TODO: implement checkpoints for
        #   ImageNet21K
        #   JBL-300M
        self.encoder: TransformerEncoder = self.vit.encoder
        # Heads are built lazily; None until the first forward pass.
        self.regression_head: Optional[torch.nn.Linear] = None
        self.regression_head_wide: Optional[torch.nn.Linear] = None

    def _encode(self, x: torch.Tensor) -> torch.Tensor:
        """Patchify, prepend the class token, encode, and flatten.

        (N, 3, 224, 224) -> (N, 197 * 768)
        """
        # (N, 3, 224, 224) -> (N, 196, 768)
        x = self.vit._process_input(x)
        n = x.shape[0]
        # (N, 1, 768)
        batch_class_token = self.vit.class_token.expand(n, -1, -1)
        # (N, 197, 768)
        x = torch.cat([batch_class_token, x], dim=1)
        x = self.encoder(x)
        # (N, 197 * 768)
        return x.view(x.shape[0], -1)

    def forward(self, x):
        r"""
        Regress z from a single image batch.

        Shape:
            - x: :math:`(N, C, H, W)` where `H = W = 224` by default.
            - returns: :math:`(N, 1)`, values in (0, 1).
        """
        x = self._encode(x)
        if self.regression_head is None:
            # Build the head on the same device as the activations
            # (the original hard-coded .cuda(), breaking CPU runs).
            self.regression_head = torch.nn.Linear(
                in_features=x.shape[1], out_features=1
            ).to(x.device)
        x = self.regression_head(x)
        # torch.nn.functional.sigmoid is deprecated; torch.sigmoid is equivalent.
        return torch.sigmoid(x)

    def forward_two_inputs(self, x, y_hat):
        r"""
        Use for formulation III: P(z | X, y_hat).

        Shape:
            - x: :math:`(N, D, H, W)` where `H = W = 224` by default.
            - y_hat: :math:`(N, C, H, W)` where `H = W = 224` by default.
            - returns: :math:`(N, 1)`, values in (0, 1).
        """
        ## HACK: destroy any information given by y_hat, can we learn to ignore?
        # C = 0.00001
        # y_hat = y_hat * C

        # TODO: this is really lazy, we pre-process x and y separately...
        # but using the same encoder?
        x = self._encode(x)
        y_hat = self._encode(y_hat)
        # Concatenate the two flattened encodings: (N, 2 * 197 * 768).
        x = torch.cat([x, y_hat], dim=1)
        # Create the 2x-width regression head on first use.
        if self.regression_head_wide is None:
            self.regression_head_wide = torch.nn.Linear(
                in_features=x.shape[1], out_features=1
            ).to(x.device)
        # Regress z: (N, 1).
        x = self.regression_head_wide(x)
        return torch.sigmoid(x)
def get(weights=None):
    """Factory for :class:`SimpleZRegressionVisionTransformer`.

    Args:
        weights: accepted for API symmetry with torchvision model factories,
            but currently ignored — no checkpoint loading is implemented yet
            (see the TODO in the class ``__init__``).
    """
    return SimpleZRegressionVisionTransformer()
class EnsembleZRegressionVisionTransformer(torch.nn.Module):
    """
    Two independent ViT-B/16 trunks (one for X, one for y_hat), each with its
    own lazily-created linear regression head; the final prediction is the
    mean of the two sigmoid outputs.

    BUG FIX(review): the original ``_forward`` created the lazy head into a
    *local* variable and never stored it on ``self``, so a fresh
    randomly-initialized Linear was built on every call — the heads were
    never registered, never trained, and the output was nondeterministic.
    The head is now persisted via its attribute name.
    """

    def __init__(self):
        super().__init__()
        self.vit_1 = vit_b_16()
        self.vit_2 = vit_b_16()
        self.encoder1: TransformerEncoder = self.vit_1.encoder
        self.encoder2: TransformerEncoder = self.vit_2.encoder
        # Heads are built lazily on the first forward pass.
        self.regression_head_1: Optional[torch.nn.Linear] = None
        self.regression_head_2: Optional[torch.nn.Linear] = None

    def _forward(
        self,
        x: torch.Tensor,
        vit: VisionTransformer,
        encoder: TransformerEncoder,
        head_attr: str,
    ) -> torch.Tensor:
        r"""Run one trunk + head; creates and persists the head on first use.

        Shape:
            - x: :math:`(N, C, H, W)` where `H = W = 224` by default.
            - returns: :math:`(N, 1)`, values in (0, 1).
        """
        # (N, 3, 224, 224) -> (N, 196, 768)
        x = vit._process_input(x)
        n = x.shape[0]
        # (N, 1, 768)
        batch_class_token = vit.class_token.expand(n, -1, -1)
        # (N, 197, 768)
        x = torch.cat([batch_class_token, x], dim=1)
        x = encoder(x)
        # (N, 197 * 768)
        x = x.view(x.shape[0], -1)
        head = getattr(self, head_attr)
        if head is None:
            # Build on the activations' device (original hard-coded .cuda(),
            # breaking CPU runs) and persist so the head is reused/trained.
            head = torch.nn.Linear(in_features=x.shape[1], out_features=1).to(
                x.device
            )
            setattr(self, head_attr, head)
        # (N, 1)
        x = head(x)
        return torch.sigmoid(x)

    def forward(self, x, y_hat):
        r"""
        Shape:
            - x: :math:`(N, D, H, W)` where `H = W = 224` by default.
            - y_hat: :math:`(N, C, H, W)` where `H = W = 224` by default.
        """
        # TODO:
        #   model has two trunks, one prediction;
        #   want to be able to learn to weight the importance of each value:
        #   (z1 * lambda_1) + (z2 * lambda_2) = z_pred
        z1 = self._forward(x, self.vit_1, self.encoder1, "regression_head_1")
        z2 = self._forward(y_hat, self.vit_2, self.encoder2, "regression_head_2")
        # Fixed equal weighting for now (see TODO above).
        return (z1 + z2) / 2
def get(weights=None):
    """Factory for :class:`EnsembleZRegressionVisionTransformer`.

    BUG FIX(review): this factory sits after the ensemble class but returned
    ``SimpleZRegressionVisionTransformer()`` — an apparent copy-paste error;
    it now returns the ensemble model.

    Args:
        weights: accepted for API symmetry with torchvision model factories,
            but currently ignored — no checkpoint loading is implemented.
    """
    return EnsembleZRegressionVisionTransformer()
# Module is import-only; no script behavior is defined yet.
if __name__ == "__main__":
    pass