import torch
import torch.nn as nn
import torch.nn.functional as F

from .configuration_vora import VoRAConfig


def _get_1d_sincos_pos_embed_from_grid(
    embed_dim: int, pos: torch.Tensor, device: torch.device
) -> torch.Tensor:
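    """Sine-cosine embedding of scalar positions.

    For each position p and frequency omega_i = 1 / 10000 ** (2i / embed_dim),
    concatenates sin(p * omega) and cos(p * omega) along the channel axis,
    returning an (M, embed_dim) tensor for M flattened positions.
    """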
    omega = torch.arange(embed_dim // 2).float().to(device)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D / 2,)
    pos = pos.reshape(-1)  # (M,)
    out = pos[:, None] * omega[None, :]  # (M, D / 2), outer product
    emb_sin, emb_cos = torch.sin(out).to(device), torch.cos(out).to(device)  # (M, D / 2)
    emb = torch.cat([emb_sin, emb_cos], dim=1)  # (M, D)
    return emb


def get_sincos_pos_embed(h: int, w: int, embed_dim: int, device: torch.device) -> torch.Tensor:
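    """Build fixed 2D sine-cosine positional embeddings for an h x w patch grid.

    Each grid axis is embedded with half of the channels and the two halves are
    concatenated, yielding an (h * w, embed_dim) tensor.
    """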
    assert embed_dim % 2 == 0, embed_dim
    grid_h = torch.arange(h).float().to(device)
    grid_w = torch.arange(w).float().to(device)
    grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
    grid = torch.stack(grid, dim=0).to(device)
    grid = grid.reshape([2, 1, h, w])
    emb_h = _get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0], device)
    emb_w = _get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1], device)
    pos_embed = torch.cat([emb_h, emb_w], dim=1)  # (H * W, D)
    return pos_embed


class RMSNorm(nn.Module):
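    """Root-mean-square layer normalization with a learned per-channel scale.

    The input is normalized in float32 for numerical stability and cast back
    to its original dtype before scaling.
    """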
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

    def extra_repr(self) -> str:
        return f"{tuple(self.weight.shape)}, eps={self.eps}"

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)


class VisionEmbedding(nn.Module):
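    """Patchify pixels with a strided convolution, RMS-normalize the tokens,
    and add fixed 2D sine-cosine positional embeddings computed on the fly
    for the input's actual patch grid.
    """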
    def __init__(self,
        config: VoRAConfig = None,
        hidden_size: int = 4096,
    ):
        super().__init__()
        self.patch_size = config.patch_size
        self.proj = nn.Conv2d(
            3,
            hidden_size,
            kernel_size=(self.patch_size, self.patch_size),
            stride=(self.patch_size, self.patch_size),
            bias=True,
        )
        self.norm = RMSNorm(hidden_size, eps=1e-05)
        self.embed_dim = hidden_size

    def forward(self, pixel_values: torch.Tensor):
        _, _, H, W = pixel_values.shape
        tokens = self.norm(self.proj(pixel_values).flatten(2).transpose(1, 2))
        pos_embed = get_sincos_pos_embed(
            H // self.patch_size, W // self.patch_size, embed_dim=self.embed_dim, device=tokens.device
        )
        tokens = tokens + pos_embed.to(tokens.device)
        return tokens


class AIMv2PatchEmbed(nn.Module):
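    """Non-overlapping convolutional patch embedding (stride = patch_size)
    followed by RMS normalization, used by the AIMv2-style preprocessor below.
    """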
    def __init__(self, config: VoRAConfig):
        super().__init__()
        self.proj = nn.Conv2d(
            3,
            config.vision_embedding_intermediate_size,
            kernel_size=(config.patch_size, config.patch_size),
            stride=(config.patch_size, config.patch_size),
        )
        self.norm = RMSNorm(config.vision_embedding_intermediate_size, eps=config.rms_norm_eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.proj(x).flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x


class AIMv2ViTPreprocessor(nn.Module):
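    """AIMv2-style preprocessor: patch embedding plus a learned positional
    table, bilinearly resized when the input grid exceeds the table, then a
    linear projection to ``hidden_size``.
    """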
    def __init__(self,
        config: VoRAConfig = None,
        hidden_size: int = 4096,
    ):
        super().__init__()
        num_patches = (config.image_size // config.patch_size) ** 2
        self.config = config

        self.patchifier = AIMv2PatchEmbed(config)
        self.pos_embed = nn.Parameter(torch.zeros((1, num_patches, config.vision_embedding_intermediate_size)))
        self.out_proj = nn.Linear(config.vision_embedding_intermediate_size, hidden_size, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C, H, W = x.shape
        h_token = H // self.config.patch_size
        w_token = W // self.config.patch_size
        tokens = self.patchifier(x)
        _, N, _ = tokens.shape
        pos_embed = self.pos_embed.to(tokens.device)

        if N <= pos_embed.size(1):
            # N fits within the learned table: add the first N embeddings directly.
            tokens = tokens + pos_embed[:, :N]
        else:
            # N exceeds num_patches: resize the learned embeddings with bilinear
            # interpolation. Reshape the (assumed square) table to (1, C, s, s).
            side = int(pos_embed.size(1) ** 0.5)
            pos_embed = pos_embed.view(1, side, side, -1).permute(0, 3, 1, 2)
            # Interpolate to the current input's token grid.
            pos_embed = F.interpolate(
                pos_embed, size=(h_token, w_token), mode="bilinear", align_corners=False
            ).permute(0, 2, 3, 1)
            # Flatten back to (1, N, C); use reshape (not view) because the
            # permuted tensor is non-contiguous.
            pos_embed = pos_embed.reshape(1, N, pos_embed.size(-1))
            tokens = tokens + pos_embed

        return self.out_proj(tokens)


def build_vision_embedding(config: VoRAConfig, hidden_size: int) -> nn.Module:
    if config.vision_embedding_type == "AIMv2":
        return AIMv2ViTPreprocessor(config, hidden_size)
    return VisionEmbedding(config, hidden_size)
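

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch, not part of the model). It uses
    # a stand-in namespace exposing only the config fields this module reads;
    # the field values below are assumptions, and a real VoRAConfig would be
    # passed in practice. Running this file directly also requires the package
    # to be importable so the relative VoRAConfig import above resolves.
    from types import SimpleNamespace

    config = SimpleNamespace(
        patch_size=14,
        image_size=224,
        vision_embedding_intermediate_size=1024,
        rms_norm_eps=1e-5,
        vision_embedding_type="AIMv2",
    )
    embed = build_vision_embedding(config, hidden_size=4096)
    pixel_values = torch.randn(2, 3, 224, 224)  # (B, C, H, W)
    tokens = embed(pixel_values)
    # 224 / 14 = 16 patches per side -> 256 tokens of width hidden_size.
    print(tokens.shape)  # expected: torch.Size([2, 256, 4096])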