|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.nn.functional as F
|
|
|
import numpy as np
|
|
|
|
|
|
from typing import Union, Tuple
|
|
|
from PIL import Image
|
|
|
|
|
|
from .layers import attn, layer_norm, mlp
|
|
|
from .image_crops import overlap_crop_image
|
|
|
from .config import VisionConfig
|
|
|
|
|
|
if torch.backends.mps.is_available():

    def adaptive_avg_pool2d(input, output_size):
        """Adaptive average pooling that runs on the CPU for MPS tensors.

        MPS inputs are copied to the CPU, pooled there, and the result is
        moved back (presumably to work around limited MPS support for this
        op -- confirm against the targeted PyTorch version).

        The result is always returned on the device ``input`` came from.
        Tensors that are not on MPS (e.g. a model running on the CPU of an
        MPS-capable machine) are pooled in place without any device hop, so
        they are never silently moved onto "mps".
        """
        if input.device.type != "mps":
            return F.adaptive_avg_pool2d(input, output_size)
        return F.adaptive_avg_pool2d(input.to("cpu"), output_size).to(input.device)

else:
    adaptive_avg_pool2d = F.adaptive_avg_pool2d


# Anything accepted as a ``device=`` argument in this module.
DeviceLike = Union[str, torch.device, int]
|
|
|
|
|
|
|
|
|
def prepare_crops(
    image: Image.Image, config: VisionConfig, device: DeviceLike
) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """Split an image into overlapping crops and normalise them for the encoder.

    The image is converted to RGB, cropped by ``overlap_crop_image`` using
    ``config.max_crops`` and ``config.overlap_margin``, and the crop stack is
    moved to ``device`` as bfloat16.

    Returns:
        A ``(crops, tiling)`` pair: the normalised crop tensor and the
        ``"tiling"`` entry reported by ``overlap_crop_image``.
    """
    rgb_pixels = np.array(image.convert("RGB"))
    cropped = overlap_crop_image(
        rgb_pixels,
        max_crops=config.max_crops,
        overlap_margin=config.overlap_margin,
    )

    # Move the last axis to position 1 (presumably NHWC -> NCHW; confirm
    # against overlap_crop_image's output layout).
    channels_first = np.transpose(cropped["crops"], (0, 3, 1, 2))

    pixels = torch.from_numpy(channels_first).to(device=device, dtype=torch.bfloat16)
    # In-place on the freshly converted tensor: x / 255, then (x - 0.5) / 0.5.
    pixels.div_(255.0)
    pixels.sub_(0.5)
    pixels.div_(0.5)

    return pixels, cropped["tiling"]
|
|
|
|
|
|
|
|
|
def create_patches(x, patch_size):
    """Cut a (B, C, H, W) tensor into flattened square patches.

    Returns a tensor of shape
    ``(B, (H // patch_size) * (W // patch_size), C * patch_size * patch_size)``.
    Patches are ordered row-major over the patch grid, and each patch is
    flattened in (channel, row, column) order.
    """
    batch, channels, height, width = x.shape
    n_rows = height // patch_size
    n_cols = width // patch_size

    # (B, C, H, W) -> (B, C, rows, P, cols, P): split both spatial axes.
    tiled = x.reshape(batch, channels, n_rows, patch_size, n_cols, patch_size)
    # -> (B, rows, cols, C, P, P): bring the grid axes in front of the patch content.
    tiled = tiled.permute(0, 2, 4, 1, 3, 5)
    # -> (B, rows * cols, C * P * P): one flat vector per patch.
    return tiled.reshape(
        batch, n_rows * n_cols, channels * patch_size * patch_size
    )
|
|
|
|
|
|
|
|
|
def vision_encoder(input_BCHW: torch.Tensor, w: nn.Module, config: VisionConfig):
    """Run the patch-based transformer encoder over a batch of image crops.

    The input is cut into patches of size ``config.enc_patch_size``, embedded
    with ``w.patch_emb``, offset by ``w.pos_emb``, passed through every block
    in ``w.blocks`` (pre-norm attention and MLP, each with a residual
    connection), and finally normalised with ``w.post_ln``.

    Args:
        input_BCHW: 4-D tensor; ``create_patches`` unpacks it as (B, C, H, W).
        w: module holding ``patch_emb``, ``pos_emb``, ``blocks`` and ``post_ln``.
        config: supplies ``enc_patch_size`` and ``enc_n_heads``.
    """
    patches = create_patches(input_BCHW, config.enc_patch_size)
    tokens = w.patch_emb(patches) + w.pos_emb

    for layer in w.blocks:
        attn_out = attn(
            layer_norm(tokens, layer.ln1), layer.attn, n_heads=config.enc_n_heads
        )
        tokens = tokens + attn_out
        mlp_out = mlp(layer_norm(tokens, layer.ln2), layer.mlp)
        tokens = tokens + mlp_out

    return layer_norm(tokens, w.post_ln)
|
|
|
|
|
|
|
|
|
def vision_projection(
    global_features: torch.Tensor,
    reconstructed: torch.Tensor,
    w: nn.Module,
    config: VisionConfig,
):
    """Fuse global and pooled local features and project them with ``w.proj_mlp``.

    ``reconstructed`` is made channels-first, average-pooled down to the
    encoder's patch grid, flattened to one row per grid cell, concatenated
    with ``global_features`` along the last dimension, and fed through the
    projection MLP.

    The pooled grid side is ``config.crop_size // config.enc_patch_size``,
    the same expression ``build_vision_model`` uses for the positional
    embedding. It is no longer taken from ``config.enc_n_layers`` (the
    transformer depth) with a hard-coded row count of 729, which only worked
    because depth and grid side were both 27.

    Args:
        global_features: per-patch features to concatenate with the pooled
            ones; assumed to be (grid * grid, enc_dim) -- TODO confirm with caller.
        reconstructed: 3-D tensor whose last dimension is ``config.enc_dim``
            (presumably (H, W, enc_dim) -- TODO confirm with caller).
        w: module holding ``proj_mlp``.
        config: supplies ``crop_size``, ``enc_patch_size`` and ``enc_dim``.

    Returns:
        The output of ``mlp`` applied to the concatenated features.
    """
    grid_size = config.crop_size // config.enc_patch_size

    # Pooling works on the trailing two dims, so put the feature dim first.
    pooled = adaptive_avg_pool2d(
        reconstructed.permute(2, 0, 1), output_size=(grid_size, grid_size)
    )
    # Back to feature-last, then one row per grid cell. ``reshape`` returns a
    # view when the layout allows it and copies otherwise.
    pooled = pooled.permute(1, 2, 0).reshape(grid_size * grid_size, config.enc_dim)

    final_features = torch.cat([global_features, pooled], dim=-1)
    return mlp(final_features, w.proj_mlp)
|
|
|
|
|
|
|
|
|
def build_vision_model(config: VisionConfig, dtype: torch.dtype):
    """Create the (untrained) module tree that holds the vision weights.

    The returned ``nn.ModuleDict`` contains:
      * ``patch_emb``: linear patch embedding,
      * ``blocks``: ``config.enc_n_layers`` transformer blocks, each with
        ``ln1``, ``attn`` (``qkv``, ``proj``), ``ln2`` and ``mlp`` (``fc1``, ``fc2``),
      * ``post_ln``: final layer norm,
      * ``proj_mlp``: two-layer projection (``fc1``, ``fc2``) taking ``2 * enc_dim`` inputs,
      * ``pos_emb``: zero-initialised parameter of shape ``(1, num_patches, enc_dim)``.

    All layers are created with the given ``dtype``.
    """
    dim = config.enc_dim
    patch_dim = config.enc_patch_size * config.enc_patch_size * config.in_channels
    grid_size = config.crop_size // config.enc_patch_size
    num_patches = grid_size * grid_size

    def linear(n_in, n_out):
        return nn.Linear(n_in, n_out, dtype=dtype)

    def norm():
        return nn.LayerNorm(dim, dtype=dtype)

    def encoder_block():
        # Layers are created in the same order as the dict keys so that
        # random initialisation consumes the RNG in a fixed sequence.
        return nn.ModuleDict(
            {
                "ln1": norm(),
                "attn": nn.ModuleDict(
                    {
                        "qkv": linear(dim, 3 * dim),
                        "proj": linear(dim, dim),
                    }
                ),
                "ln2": norm(),
                "mlp": nn.ModuleDict(
                    {
                        "fc1": linear(dim, config.enc_ff_dim),
                        "fc2": linear(config.enc_ff_dim, dim),
                    }
                ),
            }
        )

    vision = nn.ModuleDict(
        {
            "patch_emb": linear(patch_dim, dim),
            "blocks": nn.ModuleList(
                [encoder_block() for _ in range(config.enc_n_layers)]
            ),
            "post_ln": norm(),
            "proj_mlp": nn.ModuleDict(
                {
                    "fc1": linear(dim * 2, config.proj_inner_dim),
                    "fc2": linear(config.proj_inner_dim, config.proj_out_dim),
                }
            ),
        }
    )
    # Assigned as an attribute (not a dict entry) so it registers as a
    # parameter of the container itself.
    vision.pos_emb = nn.Parameter(torch.zeros(1, num_patches, dim, dtype=dtype))
    return vision
|
|
|
|