| | |
| | import math |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| | from torch import nn |
| |
|
| | from core.utils import get_init_fn |
| | from core.vision_projector.base import BaseProjector |
| |
|
| |
|
class AdaptiveAvgPooling(nn.Module):
    """Downsample a square grid of vision tokens via 2-D adaptive average pooling.

    Input and output layout is (batch, num_tokens, channels); num_tokens must
    be a perfect square so the tokens can be viewed as an h x h spatial grid.
    The token count shrinks by roughly pooling_ratio ** 2.
    """

    def __init__(self, pooling_ratio=2):
        """
        Args:
            pooling_ratio: factor by which each spatial side of the token
                grid is reduced (output side = h // pooling_ratio).
        """
        super().__init__()
        self.pooling_ratio = pooling_ratio

    def forward(self, x):
        """Pool (b, n, c) tokens down to (b, (h // ratio) ** 2, c).

        Raises:
            ValueError: if the token count is not a perfect square.
        """
        b, num_tokens, c = x.shape
        # math.isqrt is exact for integers; int(math.sqrt(...)) can be
        # off-by-one for large token counts due to float rounding.
        h = math.isqrt(num_tokens)
        # Explicit error instead of `assert`: asserts are stripped under
        # `python -O`, which would let a bogus reshape through silently.
        if h * h != num_tokens:
            raise ValueError(
                f"num_tokens ({num_tokens}) must be a perfect square to form "
                f"a spatial grid"
            )

        out_side = h // self.pooling_ratio
        # (b, n, c) -> (b, c, h, h): channels-first spatial layout for pooling.
        x = x.permute(0, 2, 1).reshape(b, -1, h, h)
        x = F.adaptive_avg_pool2d(x, (out_side, out_side))
        # (b, c, h', h') -> (b, h'*h', c): back to token-major layout.
        x = x.flatten(2).transpose(1, 2)

        return x
| |
|
| |
|
class MLPProjector(BaseProjector):
    """Two-layer MLP (Linear -> GELU -> Linear) that maps vision-encoder
    features into the language-model embedding dimension, with an attached
    adaptive average pooler for reducing the vision token count."""

    def __init__(self, args):
        super().__init__()
        self.setup_projector(args)
        self.pooling_ratio = args.pooling_ratio
        self.adaptive_avg_pool = AdaptiveAvgPooling(
            pooling_ratio=args.pooling_ratio
        )
        self.remove_vision_class_token = args.remove_vision_class_token

    def init_tensors(self):
        """Apply the configured init function to both Linear layers.

        The projector is Sequential(Linear, GELU, Linear); indices 0 and 2
        are the two Linear layers. Order (weight then bias, first layer then
        second) matches the original so stateful init RNG streams agree.
        """
        for layer_idx in (0, 2):
            linear = self.projector[layer_idx]
            self.init_method(linear.weight)
            self.init_method(linear.bias)

    def setup_projector(self, args):
        """Build the Linear -> GELU -> Linear stack and select the init fn."""
        self.init_method = get_init_fn(args.mlp_init, args.dim, init_depth=None)
        vision_width = args.vision_model["width"]
        lm_dim = args.dim
        dtype = torch.get_default_dtype()
        self.projector = nn.Sequential(
            nn.Linear(vision_width, lm_dim, bias=True, dtype=dtype),
            nn.GELU(),
            nn.Linear(lm_dim, lm_dim, bias=True, dtype=dtype),
        )
| |
|