import re

import torch.nn as nn

from .projector import Projector
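# ``Connector`` selects the module that maps features of width ``in_dims`` to
# ``out_dims``. ``cnct_arch`` is either "identity" or a string of the form
# "mlp<layers>-<relu|gelu|linear>-dropout<digits>" with an optional
# "-residual", "-batchnorm", or "-nobias" suffix, e.g. "mlp2-gelu-dropout0"
# or "mlp1-relu-dropout2-residual-batchnorm".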


class Connector(nn.Module):
    def __init__(self, in_dims: int, out_dims: int, cnct_arch: str):
        super().__init__()
        # Parse strings such as "mlp2-gelu-dropout0-residual"; see the comment
        # block above for the full grammar.
        pattern = r"mlp(\d+)-(relu|gelu|linear)-dropout(\d+)(-residual-batchnorm|-batchnorm-residual|-residual|-batchnorm|-nobias)?"
        match = re.match(pattern, cnct_arch)
        if match:
            layers = int(match.group(1))
            act = match.group(2)
            # "dropout2" -> 0.2, "dropout75" -> 0.75, "dropout0" -> 0.0
            digits = match.group(3)
            dropout_p = int(digits) / 10 ** len(digits)
            suffix = match.group(4) or ""
            residual = "-residual" in suffix
            batchnorm = "-batchnorm" in suffix
            nobias = "-nobias" in suffix
            # All latent (hidden) widths equal the output width.
            latent_dims = [out_dims] * layers
            self.mlp = Projector(
                in_dims=in_dims,
                out_dims=out_dims,
                latent_dims=latent_dims,
                bias=not nobias,
                dropout_p=dropout_p,
                activation=act,
                identity_map=residual,
                use_batchnorm=batchnorm,
            )
        elif cnct_arch == 'identity':
            self.mlp = nn.Identity()
        else:
            raise ValueError(f'no such connection architecture: {cnct_arch}')

    def forward(self, x):
        # Define forward() instead of overriding __call__ so nn.Module hooks still run.
        return self.mlp(x)


if __name__ == "__main__":
    # Run as a module (e.g. python -m <package>.connector) so the relative
    # import of Projector resolves.
    m = Connector(cnct_arch='identity', in_dims=4096, out_dims=768)
    print(m)
    m = Connector(cnct_arch='mlp1-relu-dropout2-residual', in_dims=4096, out_dims=768)
    print(m)
    m = Connector(cnct_arch='mlp1-relu-dropout2-batchnorm', in_dims=4096, out_dims=768)
    print(m)
    m = Connector(cnct_arch='mlp1-relu-dropout2-residual-batchnorm', in_dims=4096, out_dims=768)
    print(m)
    m = Connector(cnct_arch='mlp3-gelu-dropout2', in_dims=4096, out_dims=768)
    print(m)
    m = Connector(cnct_arch='mlp16-relu-dropout75', in_dims=4096, out_dims=768)
    print(m)
    m = Connector(cnct_arch='mlp0-linear-dropout0', in_dims=4096, out_dims=768)
    print(m)
    m = Connector(cnct_arch='mlp0-linear-dropout0-nobias', in_dims=4096, out_dims=768)
    print(m)
    m = Connector(cnct_arch='mlp2-linear-dropout0-nobias', in_dims=4096, out_dims=768)
    print(m)

    # Parameter count for a square 512 -> 512 connector.
    m = Connector(cnct_arch='mlp2-gelu-dropout0', in_dims=512, out_dims=512)
    count = sum(p.numel() for p in m.parameters())
    print(count)
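    # Hedged sanity check of the forward pass: this assumes Projector, like a
    # stack of nn.Linear layers, maps the last dimension from in_dims to
    # out_dims; adjust the shapes if its forward differs.
    import torch

    x = torch.randn(2, 4096)
    ident = Connector(cnct_arch='identity', in_dims=4096, out_dims=768)
    print(ident(x).shape)  # identity leaves the input untouched: torch.Size([2, 4096])
    proj = Connector(cnct_arch='mlp2-gelu-dropout0', in_dims=4096, out_dims=768)
    print(proj(x).shape)   # expected torch.Size([2, 768]) under the assumption above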