import torch
import torch.nn as nn
import torch.nn.functional as F

from .configuration_vora import VoRAConfig

def _get_1d_sincos_pos_embed_from_grid(
    embed_dim: int, pos: torch.Tensor, device: torch.device
) -> torch.Tensor:
    """Compute 1D sine-cosine positional embeddings for the positions in `pos`."""
    omega = torch.arange(embed_dim // 2, dtype=torch.float32, device=device)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D / 2,)
    pos = pos.reshape(-1)  # (M,)
    out = pos[:, None] * omega[None, :]  # (M, D / 2), outer product
    emb_sin, emb_cos = torch.sin(out), torch.cos(out)  # (M, D / 2) each
    emb = torch.cat([emb_sin, emb_cos], dim=1)  # (M, D)
    return emb

def get_sincos_pos_embed(h: int, w: int, embed_dim: int, device: torch.device) -> torch.Tensor:
    """Compute 2D sine-cosine positional embeddings for an `h` x `w` patch grid."""
    assert embed_dim % 2 == 0, embed_dim
    grid_h = torch.arange(h, dtype=torch.float32, device=device)
    grid_w = torch.arange(w, dtype=torch.float32, device=device)
    grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
    grid = torch.stack(grid, dim=0)
    grid = grid.reshape([2, 1, h, w])
    # Half of the channels encode one grid axis, half the other.
    emb_h = _get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0], device)
    emb_w = _get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1], device)
    pos_embed = torch.cat([emb_h, emb_w], dim=1)  # (H * W, D)
    return pos_embed
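
# Usage sketch (hypothetical numbers, not fixed by this file): for 224x224
# inputs with patch_size 14 the token grid is 16x16, so
#   get_sincos_pos_embed(16, 16, 1024, torch.device("cpu"))
# returns a (256, 1024) tensor that broadcasts against (B, 256, 1024) tokens.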

class RMSNorm(nn.Module):
    """Root-mean-square layer normalization: no mean centering, learned scale."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize in float32 for numerical stability, then cast back.
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

    def extra_repr(self) -> str:
        return f"{tuple(self.weight.shape)}, eps={self.eps}"

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
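
# In short, RMSNorm computes weight * x / sqrt(mean(x**2, dim=-1) + eps).
# Example (shape only): RMSNorm(4096)(torch.randn(2, 256, 4096)) -> (2, 256, 4096).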

class VisionEmbedding(nn.Module):
    """Patchify images with a strided convolution, normalize, and add 2D
    sine-cosine position embeddings."""

    def __init__(self, config: VoRAConfig, hidden_size: int = 4096):
        super().__init__()
        self.patch_size = config.patch_size
        self.proj = nn.Conv2d(
            3,
            hidden_size,
            kernel_size=(self.patch_size, self.patch_size),
            stride=(self.patch_size, self.patch_size),
            bias=True,
        )
        self.norm = RMSNorm(hidden_size, eps=1e-05)
        self.embed_dim = hidden_size

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        _, _, H, W = pixel_values.shape
        # (B, 3, H, W) -> (B, N, D) with N = (H // patch_size) * (W // patch_size)
        tokens = self.norm(self.proj(pixel_values).flatten(2).transpose(1, 2))
        pos_embed = get_sincos_pos_embed(
            H // self.patch_size, W // self.patch_size, embed_dim=self.embed_dim, device=tokens.device
        )
        tokens = tokens + pos_embed
        return tokens
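
# Usage sketch (hypothetical config with patch_size=14, input sizes divisible by it):
#   embed = VisionEmbedding(config, hidden_size=4096)
#   tokens = embed(torch.randn(1, 3, 224, 224))  # -> (1, 256, 4096)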

class AIMv2PatchEmbed(nn.Module):
    """AIMv2 patch embedding: strided convolution followed by RMSNorm."""

    def __init__(self, config: VoRAConfig):
        super().__init__()
        self.proj = nn.Conv2d(
            3,
            config.vision_embedding_intermediate_size,
            kernel_size=(config.patch_size, config.patch_size),
            stride=(config.patch_size, config.patch_size),
        )
        self.norm = RMSNorm(config.vision_embedding_intermediate_size, eps=config.rms_norm_eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (B, 3, H, W) -> (B, N, vision_embedding_intermediate_size)
        x = self.proj(x).flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x

class AIMv2ViTPreprocessor(nn.Module):
    """AIMv2-style preprocessor: patch embedding, learned position embeddings
    (bilinearly resized when the token grid is larger than at init), and a
    linear projection into the language model hidden size."""

    def __init__(self, config: VoRAConfig, hidden_size: int = 4096):
        super().__init__()
        num_patches = (config.image_size // config.patch_size) ** 2
        self.config = config
        self.patchifier = AIMv2PatchEmbed(config)
        self.pos_embed = nn.Parameter(torch.zeros((1, num_patches, config.vision_embedding_intermediate_size)))
        self.out_proj = nn.Linear(config.vision_embedding_intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        _, _, H, W = x.shape
        h_token = H // self.config.patch_size
        w_token = W // self.config.patch_size
        tokens = self.patchifier(x)
        _, N, _ = tokens.shape
        pos_embed = self.pos_embed.to(tokens.device)
        if N <= pos_embed.size(1):
            # If N is at most num_patches, add the (truncated) position embeddings directly.
            tokens = tokens + pos_embed[:, :N]
        else:
            # If N exceeds num_patches, resize the position embeddings with bilinear
            # interpolation. Reshape (1, num_patches, C) to a (1, C, side, side) map first.
            side = int(pos_embed.size(1) ** 0.5)
            pos_embed = pos_embed.view(1, side, side, -1).permute(0, 3, 1, 2)
            # Bilinearly resize to the current token grid.
            pos_embed = F.interpolate(
                pos_embed, size=(h_token, w_token), mode="bilinear", align_corners=False
            ).permute(0, 2, 3, 1)
            # Flatten back to (1, N, C); reshape rather than view, since the
            # permuted tensor is not contiguous.
            pos_embed = pos_embed.reshape(1, N, pos_embed.size(-1))
            tokens = tokens + pos_embed
        return self.out_proj(tokens)
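
# Usage sketch (hypothetical config: image_size=224, patch_size=14, hence
# num_patches=256): a 448x448 input produces N=1024 tokens, which triggers the
# bilinear resize of pos_embed from a 16x16 grid to 32x32 above.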

def build_vision_embedding(config: VoRAConfig, hidden_size: int) -> nn.Module:
    """Select the vision embedding module from `config.vision_embedding_type`."""
    if config.vision_embedding_type == "AIMv2":
        return AIMv2ViTPreprocessor(config, hidden_size)
    return VisionEmbedding(config, hidden_size)
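
# Usage sketch (assumes a VoRAConfig exposing the fields used above:
# vision_embedding_type, image_size, patch_size,
# vision_embedding_intermediate_size, rms_norm_eps):
#   vision_embed = build_vision_embedding(config, hidden_size=4096)
#   vision_tokens = vision_embed(pixel_values)  # (B, N, 4096)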