Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| class AveragePooling(nn.Module): | |
| def __init__(self, pooling_size=2, device='cpu'): | |
| super(AveragePooling, self).__init__() | |
| self.pooling_size = pooling_size | |
| self.device = device | |
| self.to(device) | |
| def forward(self, image_features): | |
| batch_size, num_features, dim = image_features.size() | |
| height = width = int(num_features ** 0.5) | |
| image_features = image_features.view(batch_size, height, width, dim) | |
| pooled_features = F.avg_pool2d(image_features.permute(0, 3, 1, 2), kernel_size=self.pooling_size) | |
| pooled_features = pooled_features.permute(0, 2, 3, 1) | |
| pooled_features = pooled_features.view(batch_size, -1, dim) | |
| return pooled_features | |
| class AttentionPooling(nn.Module): | |
| def __init__(self, input_dim, pooling_size=2, device='cpu',dtype=torch.float32): | |
| super(AttentionPooling, self).__init__() | |
| self.pooling_size = pooling_size | |
| self.device = device | |
| self.mlp = nn.Sequential( | |
| nn.Linear(input_dim, input_dim), | |
| nn.ReLU(), | |
| nn.Dropout(0.1), | |
| nn.Linear(input_dim, 1)) | |
| # self.mlp.to(device,dtype) | |
| def forward(self, x): | |
| batch_size, n, dim = x.shape | |
| sqrt_n = int(n ** 0.5) | |
| pooling_size = self.pooling_size | |
| x = x.view(batch_size, sqrt_n, sqrt_n, dim) | |
| pooled_features = [] | |
| for i in range(0, sqrt_n, pooling_size): | |
| for j in range(0, sqrt_n, pooling_size): | |
| region = x[:, i:i+pooling_size, j:j+pooling_size, :] | |
| region = region.reshape(batch_size, -1, dim) | |
| alpha = self.mlp(region) | |
| alpha = torch.softmax(alpha, dim=1) | |
| region_pooled = torch.sum(alpha * region, dim=1) | |
| pooled_features.append(region_pooled) | |
| output = torch.stack(pooled_features, dim=1) | |
| return output | |
| def build_pooling(pooling_type, input_dim=None, pooling_size=2, device='cpu',dtype=torch.float32): | |
| if pooling_type == 'average': | |
| return AveragePooling(pooling_size=pooling_size, device=device) | |
| elif pooling_type == 'attention': | |
| if input_dim is None: | |
| raise ValueError("input_dim must be specified for attention pooling") | |
| return AttentionPooling(input_dim=input_dim, pooling_size=pooling_size, device=device, dtype=dtype) | |
| else: | |
| raise ValueError("Unknown pooling type: {}".format(pooling_type)) | |