# pal-b-large-opt-350m / projector.py
from typing import Sequence

import torch
import torch.nn as nn


class Projector(nn.Module):
    """MLP projection head mapping ``in_dims``-dimensional features to ``out_dims``."""

    in_dims: int
    out_dims: int
    latent_dims: Sequence[int]
    bias: bool
    dropout_p: float
    activation: str
    identity_map: bool
    use_batchnorm: bool

    def __init__(
        self,
        in_dims: int,
        out_dims: int,
        latent_dims: Sequence[int] = (),
        bias: bool = True,
        dropout_p: float = 0.2,
        activation: str = 'relu',
        identity_map: bool = False,
        use_batchnorm: bool = False,
    ):
        super().__init__()
        self.in_dims = in_dims
        self.out_dims = out_dims
        self.bias = bias
        self.dropout_p = dropout_p
        self.latent_dims = latent_dims
        self.act = None
        self.identity_map = identity_map
        self.use_batchnorm = use_batchnorm

        # Map the activation name to the corresponding module class.
        if activation == 'relu':
            self.act = nn.ReLU
        elif activation == 'gelu':
            self.act = nn.GELU
        elif activation == 'linear':
            self.act = nn.Identity
        else:
            raise ValueError(f'no such activation {activation}')

        if identity_map:
            self.identity = nn.Identity()
            # self.alpha = nn.Parameter(torch.tensor(0.5))

        # Hidden stack: Linear -> [BatchNorm1d] -> Dropout -> activation for each
        # latent dimension, followed by a final Linear projection to out_dims.
        layer_dims = [in_dims] + list(latent_dims)
        layers = []
        for i in range(len(layer_dims) - 1):
            layers.append(nn.Linear(layer_dims[i], layer_dims[i + 1], bias=self.bias))
            if self.use_batchnorm:  # Add batch normalization layer if enabled
                layers.append(nn.BatchNorm1d(layer_dims[i + 1]))
            layers.extend([
                nn.Dropout(p=self.dropout_p),
                self.act(),
            ])
        layers.append(nn.Linear(layer_dims[-1], out_dims, bias=self.bias))
        self.layers = nn.Sequential(*layers)

    def forward(self, x) -> torch.Tensor:
        """Forward pass of the projector model.

        Args:
            x: The input features.

        Returns:
            torch.Tensor: The projected features.
        """
        if self.identity_map:
            # Residual connection: add the (identity-mapped) input to the MLP output.
            x = self.identity(x) + self.layers(x)
        else:
            x = self.layers(x)
        return x
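

# --- Usage sketch (not part of the original module) ---
# A minimal, hypothetical example of constructing the projector and running a
# forward pass. The dimensions and hyperparameters below are illustrative
# assumptions, not values taken from the PAL_B_RM_opt checkpoint.
if __name__ == '__main__':
    proj = Projector(
        in_dims=512,            # assumed input feature size
        out_dims=128,           # assumed output projection size
        latent_dims=(256,),     # one hidden layer of width 256
        activation='gelu',
        identity_map=False,     # residual path requires in_dims == out_dims
        use_batchnorm=False,
    )
    features = torch.randn(4, 512)   # batch of 4 feature vectors
    projected = proj(features)       # -> shape (4, 128)
    print(projected.shape)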