File size: 920 Bytes
0ff8d3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
from einops import rearrange
import torch
import torch.nn as nn

class LinearPatchProjector(nn.Module):
    """Cut a latent map into non-overlapping k x k patches and embed each one linearly.

    Shape flow::

        z: [B, C, H, W]  ->  patches: [B, (H/k)*(W/k), C*k*k]  ->  embeds: [B, N, d_model]

    Patches are ordered row-major over the (H/k, W/k) grid, and every patch is
    flattened in (channel, row-in-patch, column-in-patch) order.

    Args:
        latent_ch: Number of channels C of the incoming latent.
        k: Side length of a square patch; H and W must be multiples of it.
        d_model: Width of the output embedding.
    """

    def __init__(self, latent_ch: int, k: int, d_model: int):
        super().__init__()
        self.k = k
        self.patch_dim = latent_ch * k * k
        # The projection keeps its own parameter dtype (fp32 unless the module
        # is converted) for numerical stability; forward() casts inputs to it.
        self.proj = nn.Linear(self.patch_dim, d_model, bias=True)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """Embed every k x k patch of ``z``.

        Args:
            z: Latent tensor of shape [B, C, H, W]. Any dtype that can be cast
                to the projection's parameter dtype is accepted.

        Returns:
            Tensor of shape [B, (H/k)*(W/k), d_model] in the projection's dtype
            (autocast, if active, may still pick the compute dtype).

        Raises:
            ValueError: If H or W is not a multiple of ``k``.
        """
        B, C, H, W = z.shape
        k = self.k
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if H % k != 0 or W % k != 0:
            raise ValueError(f"H,W 必须是 k 的整数倍, got {(H,W)} vs k={k}")
        hs, ws = H // k, W // k
        # Block-wise patchify: split H and W into (grid, within-patch) axes,
        # bring the grid axes to the front, then flatten. ``reshape`` (not
        # ``view``) copies when required, so non-contiguous inputs are safe.
        patches = (
            z.reshape(B, C, hs, k, ws, k)
            .permute(0, 2, 4, 1, 3, 5)
            .reshape(B, hs * ws, C * k * k)
        )
        # Match the layer's parameter dtype so half-precision latents do not
        # hit a dtype-mismatch error; a no-op when the dtypes already agree.
        patches = patches.to(self.proj.weight.dtype)
        return self.proj(patches)