Spaces:
Build error
Build error
File size: 452 Bytes
c6a12ae | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 | import torch.nn as nn
class ProjectionBlock(nn.Module):
    """Map feature vectors from one width to another with a normalised MLP.

    The input is layer-normalised over its last dimension (size
    ``input_dim_CLIP``) and then sent through ``Linear -> GELU -> Linear``,
    producing a last dimension of size ``input_dim_phi2``.
    NOTE(review): going by the argument names this presumably projects CLIP
    vision features into the Phi-2 embedding space — confirm with the caller.

    Args:
        input_dim_CLIP: Size of the last dimension of the incoming features.
        input_dim_phi2: Size of the last dimension of the output; also used
            as the hidden width between the two linear layers.
    """

    def __init__(self, input_dim_CLIP: int, input_dim_phi2: int):
        super().__init__()
        # Registered before ``proj`` so parameters()/state_dict() order stays
        # pre_norm.* first, then proj.*.
        self.pre_norm = nn.LayerNorm(input_dim_CLIP)

        # The Sequential indices (0, 1, 2) become checkpoint keys such as
        # ``proj.0.weight``, so the stage order must remain Linear, GELU, Linear.
        input_layer = nn.Linear(input_dim_CLIP, input_dim_phi2)
        activation = nn.GELU()
        output_layer = nn.Linear(input_dim_phi2, input_dim_phi2)
        self.proj = nn.Sequential(input_layer, activation, output_layer)

    def forward(self, x):
        """Normalise ``x`` and run it through the projection MLP.

        Args:
            x: Tensor whose last dimension has size ``input_dim_CLIP``;
                any leading dimensions pass through unchanged.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size ``input_dim_phi2``.
        """
        return self.proj(self.pre_norm(x))