from typing import Sequence

import torch
import torch.nn as nn
|
|
|
|
class Projector(nn.Module):
    """MLP projection head mapping ``in_dims`` input features to ``out_dims`` outputs.

    Each entry in ``latent_dims`` adds a hidden block of Linear -> [BatchNorm1d]
    -> Dropout -> activation; a final linear layer maps to ``out_dims``. When
    ``identity_map`` is enabled, the input is added to the output as a residual
    connection, which requires ``in_dims == out_dims``.
    """

    in_dims: int
    out_dims: int
    latent_dims: Sequence[int]
    bias: bool
    dropout_p: float
    activation: str
    identity_map: bool
    use_batchnorm: bool
|
|
    def __init__(
        self,
        in_dims: int,
        out_dims: int,
        latent_dims: Sequence[int] = (),
        bias: bool = True,
        dropout_p: float = 0.2,
        activation: str = 'relu',
        identity_map: bool = False,
        use_batchnorm: bool = False,
    ):
        """Build the projection MLP.

        Args:
            in_dims: Size of the input features.
            out_dims: Size of the projected output features.
            latent_dims: Widths of the hidden layers; empty gives a single linear layer.
            bias: Whether the linear layers use a bias term.
            dropout_p: Dropout probability applied in each hidden block.
            activation: One of 'relu', 'gelu', or 'linear'.
            identity_map: If True, add the input to the output as a residual connection.
            use_batchnorm: If True, apply BatchNorm1d after each hidden linear layer.
        """
        super().__init__()
|
|
        self.in_dims = in_dims
        self.out_dims = out_dims
        self.latent_dims = latent_dims
        self.bias = bias
        self.dropout_p = dropout_p
        self.activation = activation
        self.identity_map = identity_map
        self.use_batchnorm = use_batchnorm

        # Resolve the activation layer class from its name.
        if activation == 'relu':
            self.act = nn.ReLU
        elif activation == 'gelu':
            self.act = nn.GELU
        elif activation == 'linear':
            self.act = nn.Identity
        else:
            raise ValueError(f'no such activation: {activation}')
| |
        if identity_map:
            # Residual connection: the raw input is added to the projector output
            # in forward(), which requires in_dims == out_dims.
            self.identity = nn.Identity()
| |
|
|
        # Hidden blocks: Linear -> [BatchNorm1d] -> Dropout -> activation.
        layer_dims = [in_dims] + list(latent_dims)
        layers = []
        for i in range(len(layer_dims) - 1):
            layers.append(nn.Linear(layer_dims[i], layer_dims[i + 1], bias=self.bias))
            if self.use_batchnorm:
                layers.append(nn.BatchNorm1d(layer_dims[i + 1]))
            layers.extend([
                nn.Dropout(p=self.dropout_p),
                self.act(),
            ])

        # Final linear projection to the output dimensionality.
        layers.append(nn.Linear(layer_dims[-1], out_dims, bias=self.bias))
        self.layers = nn.Sequential(*layers)
|
|
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass of the projector model.

        Args:
            x: Input features with ``in_dims`` as the last dimension.

        Returns:
            torch.Tensor: Projected features with ``out_dims`` as the last
            dimension. If ``identity_map`` is enabled, the input is added to
            the MLP output as a residual connection.
        """
        if self.identity_map:
            return self.identity(x) + self.layers(x)
        return self.layers(x)
|
|
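# Illustrative usage sketch: the dimensions, batch size, and hyperparameters
# below are assumed values chosen only to demonstrate the Projector API
# defined above, not values prescribed by this module.
if __name__ == '__main__':
    # Plain MLP projector: 128 -> 256 -> 64 with GELU, dropout, and batch norm.
    projector = Projector(
        in_dims=128,
        out_dims=64,
        latent_dims=(256,),
        activation='gelu',
        dropout_p=0.1,
        use_batchnorm=True,
    )
    features = torch.randn(8, 128)      # dummy batch of 8 feature vectors
    print(projector(features).shape)    # torch.Size([8, 64])

    # Residual variant: identity_map adds the input back to the output,
    # so in_dims must equal out_dims.
    residual_projector = Projector(
        in_dims=128,
        out_dims=128,
        latent_dims=(256,),
        identity_map=True,
    )
    print(residual_projector(features).shape)  # torch.Size([8, 128])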