# sparse-cafm/src/models/simple_z_predictor.py
# Snapshot from HF Space deployment (commit 0917e8d,
# "Minimal HF Space deployment with gradio 5.x fix", by leharris3).
import torch
from typing import Optional
from torch.nn.modules import Transformer, TransformerEncoder
from torchvision.models import vit_b_16, VisionTransformer
H_DIM = 512
H_DIM_VT = 768
N_OUTPUT_TOKENS_VT = 16 * 197
class SimpleZRegressionVisionTransformer(torch.nn.Module):
"""
A VIT transformer encoder + linear regression head.
- X -> z
- y -> z
- (X + y) -> z
"""
def __init__(self):
super(SimpleZRegressionVisionTransformer, self).__init__()
self.vit = vit_b_16()
# TODO: implement checkpoints for
# ImageNet21K
# JBL-300M
self.encoder: TransformerEncoder = self.vit.encoder
self.regression_head: Optional[torch.nn.Linear] = None
self.regression_head_wide: Optional[torch.nn.Linear] = None
def forward(self, x):
r"""
Shape:
- x: :math:`(N, C, H, W)` where `H = W = 224` by default.
"""
# (16, 3, 224, 224)
# (16, 196, 768)
x = self.vit._process_input(x)
n = x.shape[0]
# (16, 1, 768)
batch_class_token = self.vit.class_token.expand(n, -1, -1)
# (16, 197, 768)
x = torch.cat([batch_class_token, x], dim=1)
# (16 * 197, 768)
x: torch.Tensor = self.encoder(x)
# (16 * 197, 1)
x = x.view((x.shape[0], x.shape[1] * x.shape[2]))
if self.regression_head == None:
self.regression_head = torch.nn.Linear(
in_features=x.shape[1], out_features=1
).cuda()
# (16, 1)
x = self.regression_head(x)
x = torch.nn.functional.sigmoid(x)
return x
def forward_two_inputs(self, x, y_hat):
r"""
Use for formulation III: P(z | X, y_hat)
Shape:
- X: :math:`(N, D, H, W)` where `H = W = 224` by default.
- y_hat: :math:`(N, C, H, W)` where `H = W = 224` by default.
"""
## HACK: destroy any information given by y_hat, can we learn to ignore?
# C = 0.00001
# y_hat = y_hat * C
# (16, 3, 224, 224)
# (16, 196, 768)
# preprocess x
x = self.vit._process_input(x)
n = x.shape[0]
# (16, 1, 768)
batch_class_token = self.vit.class_token.expand(n, -1, -1)
# (16, 197, 768)
x = torch.cat([batch_class_token, x], dim=1)
# (16 * 197, 768)
x: torch.Tensor = self.encoder(x)
# (16 * 197, 1)
x = x.view((x.shape[0], x.shape[1] * x.shape[2]))
# (16, 3, 224, 224)
# (16, 196, 768)
# preprocess y
# TODO: this is really lazy, we pre-process x and y separately... but using the same encoder?
y_hat = self.vit._process_input(y_hat)
n = y_hat.shape[0]
# (16, 1, 768)
batch_class_token = self.vit.class_token.expand(n, -1, -1)
# (16, 197, 768)
y_hat = torch.cat([batch_class_token, y_hat], dim=1)
# (16 * 197, 768)
y_hat: torch.Tensor = self.encoder(y_hat)
# (16 * 197, 1)
y_hat = y_hat.view((y_hat.shape[0], y_hat.shape[1] * y_hat.shape[2]))
# x = [x, y_hat]
x = torch.cat([x, y_hat], dim=1)
# create 2x width regression head
if self.regression_head_wide == None:
self.regression_head_wide = torch.nn.Linear(
in_features=x.shape[1], out_features=1
).cuda()
# regress z
x = self.regression_head_wide(x)
# (16, 1)
x = torch.nn.functional.sigmoid(x)
return x
@staticmethod
def get(weights=None):
return SimpleZRegressionVisionTransformer()
class EnsembleZRegressionVisionTransformer(torch.nn.Module):
"""
A VIT transformer encoder + linear regression head.
- X -> z
- y -> z
- (X + y) -> z
"""
def __init__(self):
super(EnsembleZRegressionVisionTransformer, self).__init__()
self.vit_1 = vit_b_16()
self.vit_2 = vit_b_16()
self.encoder1: TransformerEncoder = self.vit_1.encoder
self.encoder2: TransformerEncoder = self.vit_2.encoder
self.regression_head_1: Optional[torch.nn.Linear] = None
self.regression_head_2: Optional[torch.nn.Linear] = None
def _forward(
self,
x: torch.Tensor,
vit: VisionTransformer,
encoder: TransformerEncoder,
head: torch.nn.Linear,
):
r"""
Shape:
- x: :math:`(N, C, H, W)` where `H = W = 224` by default.
"""
# (16, 3, 224, 224)
# (16, 196, 768)
x = vit._process_input(x)
n = x.shape[0]
# (16, 1, 768)
batch_class_token = vit.class_token.expand(n, -1, -1)
# (16, 197, 768)
x = torch.cat([batch_class_token, x], dim=1)
# (16 * 197, 768)
x: torch.Tensor = encoder(x)
# (16 * 197, 1)
x = x.view((x.shape[0], x.shape[1] * x.shape[2]))
if head == None:
head = torch.nn.Linear(in_features=x.shape[1], out_features=1).cuda()
# (16, 1)
x = head(x)
x = torch.nn.functional.sigmoid(x)
return x
def forward(self, x, y_hat):
r"""
Shape:
- X: :math:`(N, D, H, W)` where `H = W = 224` by default.
- y_hat: :math:`(N, C, H, W)` where `H = W = 224` by default.
"""
# TODO:
# models has two trunks, one head
# want to be able to learn to weight the imporance of each of these values... right?
# (z1 * lambda_1) + (z2 * lambda_2) = z_pred
z1 = self._forward(x, self.vit_1, self.encoder1, self.regression_head_1)
z2 = self._forward(y_hat, self.vit_2, self.encoder2, self.regression_head_2)
return (z1 + z2) / 2
@staticmethod
def get(weights=None):
return SimpleZRegressionVisionTransformer()
if __name__ == "__main__":
pass