| | """ Classifier head and layer factory |
| | |
| | Hacked together by / Copyright 2020 Ross Wightman |
| | """ |
| | from torch import nn as nn |
| | from torch.nn import functional as F |
| |
|
| | from .adaptive_avgmax_pool import SelectAdaptivePool2d |
| | from .linear import Linear |
| |
|
| |
|
def _create_pool(num_features, num_classes, pool_type='avg', use_conv=False):
    flatten_in_pool = not use_conv  # flatten after pooling when feeding a Linear classifier
    if not pool_type:
        assert num_classes == 0 or use_conv, \
            'Pooling can only be disabled if classifier is also removed or conv classifier is used'
        flatten_in_pool = False  # no pooling, keep the spatial feature map as-is
    global_pool = SelectAdaptivePool2d(pool_type=pool_type, flatten=flatten_in_pool)
    num_pooled_features = num_features * global_pool.feat_mult()
    return global_pool, num_pooled_features


def _create_fc(num_features, num_classes, use_conv=False):
    if num_classes <= 0:
        fc = nn.Identity()  # pass-through, no classifier
    elif use_conv:
        fc = nn.Conv2d(num_features, num_classes, 1, bias=True)
    else:
        # use the local Linear wrapper (imported from .linear) rather than nn.Linear
        fc = Linear(num_features, num_classes, bias=True)
    return fc


def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False):
    global_pool, num_pooled_features = _create_pool(num_features, num_classes, pool_type, use_conv=use_conv)
    fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv)
    return global_pool, fc


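# Illustrative usage sketch (not part of the original module): create_classifier
# returns the pooling and classifier modules separately, so a model can register
# them as e.g. `self.global_pool` / `self.fc`. The channel and class counts below
# are arbitrary example values.
#
#   global_pool, fc = create_classifier(num_features=2048, num_classes=1000, pool_type='avg')
#   # x: (N, 2048, H, W) backbone feature map
#   # logits = fc(global_pool(x))  # global_pool flattens to (N, 2048), fc maps to (N, 1000)

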
class ClassifierHead(nn.Module):
    """Classifier head w/ configurable global pooling and dropout."""

    def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0., use_conv=False):
        super(ClassifierHead, self).__init__()
        self.drop_rate = drop_rate
        self.global_pool, num_pooled_features = _create_pool(in_chs, num_classes, pool_type, use_conv=use_conv)
        self.fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv)
        # a conv classifier keeps spatial dims through fc, so flatten afterwards when pooling is enabled
        self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity()

    def forward(self, x):
        x = self.global_pool(x)
        if self.drop_rate:
            x = F.dropout(x, p=float(self.drop_rate), training=self.training)
        x = self.fc(x)
        x = self.flatten(x)
        return x
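

if __name__ == '__main__':
    # Quick smoke test / usage sketch added for illustration; not part of the
    # original module. The relative imports above mean this file must be run as
    # a module from within its package for this block to execute. Channel and
    # class counts are arbitrary example values.
    import torch

    feats = torch.randn(8, 2048, 7, 7)  # (N, C, H, W) backbone feature map
    head = ClassifierHead(in_chs=2048, num_classes=1000, pool_type='avg', drop_rate=0.2)
    logits = head(feats)
    print(logits.shape)  # expected: torch.Size([8, 1000])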