Spaces:
Sleeping
Sleeping
| import math | |
| import torch | |
| import torch.nn as nn | |
| from abc import ABC | |
| class ABC_Model(ABC): | |
| def global_average_pooling_2d(self, x, keepdims=False): | |
| x = torch.mean(x.view(x.size(0), x.size(1), -1), -1) | |
| if keepdims: | |
| x = x.view(x.size(0), x.size(1), 1, 1) | |
| return x | |
| def initialize(self, modules): | |
| for m in modules: | |
| if isinstance(m, nn.Conv2d): | |
| # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels | |
| # m.weight.data.normal_(0, math.sqrt(2. / n)) | |
| torch.nn.init.kaiming_normal_(m.weight) | |
| elif isinstance(m, nn.BatchNorm2d): | |
| m.weight.data.fill_(1) | |
| m.bias.data.zero_() | |
| def get_parameter_groups(self, print_fn=print): | |
| groups = ([], [], [], []) | |
| for name, value in self.named_parameters(): | |
| # pretrained weights | |
| if 'model' in name: | |
| if 'weight' in name: | |
| # print_fn(f'pretrained weights : {name}') | |
| groups[0].append(value) | |
| else: | |
| # print_fn(f'pretrained bias : {name}') | |
| groups[1].append(value) | |
| # scracthed weights | |
| else: | |
| if 'weight' in name: | |
| if print_fn is not None: | |
| print_fn(f'scratched weights : {name}') | |
| groups[2].append(value) | |
| else: | |
| if print_fn is not None: | |
| print_fn(f'scratched bias : {name}') | |
| groups[3].append(value) | |
| return groups | |