|
|
'''Some helper functions for PyTorch training, including:

- get_mean_and_std: calculate the per-channel mean and std of a dataset.

- init_params: net parameter initialization.

- mkdir_p / save_checkpoint: filesystem and checkpoint helpers.

- torch_accuracy / AverageMeter / get_vid_module_dict: training utilities.

'''
|
|
import errno |
|
|
import os |
|
|
import sys |
|
|
import time |
|
|
import math |
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import torch.nn.init as init |
|
|
from torch.autograd import Variable |
|
|
|
|
|
# Public API of this module (the names exported by `from <module> import *`).
__all__ = ['get_mean_and_std', 'init_params', 'mkdir_p', 'save_checkpoint', 'torch_accuracy', 'AverageMeter','get_vid_module_dict']
|
|
|
|
|
|
|
|
def get_mean_and_std(dataset):
    '''Compute the per-channel mean and std of a 3-channel image dataset.

    Iterates the dataset one sample at a time and averages each sample's
    per-channel statistics, so the returned std is the mean of per-image
    stds (not the std over all pixels pooled together).

    Args:
        dataset: a torch Dataset yielding (image, target) pairs where
            image has shape (3, H, W).

    Returns:
        (mean, std): two 1-D float tensors of length 3.
    '''
    # batch_size=1 with no shuffling and no workers: order and workers are
    # irrelevant for a full deterministic pass, so avoid the overhead.
    # (Original code also bound an unused `trainloader` alias here.)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1,
                                             shuffle=False, num_workers=0)

    mean = torch.zeros(3)
    std = torch.zeros(3)
    print('==> Computing mean and std..')
    for inputs, _targets in dataloader:
        for i in range(3):
            mean[i] += inputs[:, i, :, :].mean()
            std[i] += inputs[:, i, :, :].std()
    mean.div_(len(dataset))
    std.div_(len(dataset))
    return mean, std
|
|
|
|
|
def init_params(net):
    '''Init layer parameters in-place.

    Conv2d: Kaiming-normal weights (fan_out), zero bias.
    BatchNorm2d: weight 1, bias 0.
    Linear: small-normal weights (std=1e-3), zero bias.

    Args:
        net: an nn.Module whose submodules are initialized in-place.
    '''
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            # Underscore variants are the non-deprecated in-place API.
            init.kaiming_normal_(m.weight, mode='fan_out')
            # BUG FIX: `if m.bias:` raises "Boolean value of Tensor ... is
            # ambiguous" for any multi-element bias; test for presence instead.
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            init.constant_(m.weight, 1)
            init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.normal_(m.weight, std=1e-3)
            if m.bias is not None:
                init.constant_(m.bias, 0)
|
|
|
|
|
def mkdir_p(path):
    '''Create `path` (including parents), like `mkdir -p`.

    An already-existing directory is silently accepted; any other OSError
    (permissions, path exists as a regular file, ...) propagates.
    '''
    try:
        os.makedirs(path)
    except OSError as exc:
        # Re-raise unless the failure is just "directory already exists".
        if exc.errno != errno.EEXIST or not os.path.isdir(path):
            raise
|
|
|
|
|
|
|
|
def save_checkpoint(state, checkpoint, filename='checkpoint.pth.tar'):
    '''Serialize `state` to `<checkpoint>/<filename>` with torch.save.

    Args:
        state: any torch.save-able object (typically a dict of tensors/scalars).
        checkpoint: destination directory (must already exist).
        filename: file name inside `checkpoint`.
    '''
    torch.save(state, os.path.join(checkpoint, filename))
|
|
|
|
|
|
|
|
def torch_accuracy(output, target, topk=(1,)):
    '''Compute top-k accuracies (in percent) for classification output.

    Args:
        output: (batch, num_classes) score/logit tensor.
        target: (batch,) class-index tensor, or a (batch, num_classes)
            tensor whose per-row argmax is taken as the label.
        topk: iterable of k values to report.

    Returns:
        List of 1-element float tensors, one accuracy (0..100) per k.

    Raises:
        ValueError: if `target` is neither 1- nor 2-dimensional.
    '''
    topn = max(topk)
    batch_size = output.size(0)

    # pred: (topn, batch) predicted class indices, best first.
    _, pred = output.topk(topn, 1, True, True)
    pred = pred.t()

    if target.dim() == 1:
        is_correct = pred.eq(target.view(1, -1).expand_as(pred))
    elif target.dim() == 2:
        # Dense / one-hot targets: the argmax per row is the true label.
        is_correct = pred.eq(target.max(1)[1].expand_as(pred))
    else:
        # Original code fell through to a NameError here; fail loudly instead.
        raise ValueError('target must be 1- or 2-dimensional, got dim={}'.format(target.dim()))

    ans = []
    for k in topk:
        # reshape, not view: the slice can be non-contiguous on newer PyTorch
        # (same fix as in the official pytorch/examples ImageNet code).
        correct_k = is_correct[:k].reshape(-1).float().sum(0, keepdim=True)
        ans.append(correct_k.mul_(100.0 / batch_size))

    return ans
|
|
|
|
|
def get_vid_module_dict(model, hook_layers):
    '''Build mean/variance adaptation heads for the named hooked layers.

    For each parameter of `model` named '<layer>.weight' with <layer> in
    `hook_layers`, creates a pair of adaptation modules sized to that
    layer's output channels (first dimension of the weight).

    Args:
        model: an nn.Module to scan via named_parameters().
        hook_layers: iterable of layer-name strings (without '.weight').

    Returns:
        Dict mapping 'feature_maps<i>_mean' / 'feature_maps<i>_var' to
        the adaptation modules, where <i> counts matched layers in
        named_parameters() order.
    '''
    vid_module_dict = {}
    weight_names = [layer_name + '.weight' for layer_name in hook_layers]
    # (Removed unused `layer_names = model._modules.keys()` from the original.)
    i = 0
    for name, p in model.named_parameters():
        if name in weight_names:
            key = 'feature_maps' + str(i)
            channels = p.shape[0]  # out-channel count of the hooked weight
            mean, var = get_mean_and_variance(channels)
            vid_module_dict[key + '_mean'] = mean
            vid_module_dict[key + '_var'] = var
            i += 1

    return vid_module_dict
|
|
|
|
|
|
|
|
def get_mean_and_variance(in_channels):
    '''Build a (mean, variance) pair of adaptation heads for `in_channels` maps.

    Both heads are channel-preserving 1x1-conv stacks; the variance head
    ends in Softplus so its output is strictly positive.

    Args:
        in_channels: number of input (and output) channels.

    Returns:
        (mean, var): two nn.Sequential modules.
    '''
    # (Removed unused locals `var_adap_avg_pool` and `eps` from the original.)
    out_channels = in_channels

    mean = get_adaptation_layer(in_channels, out_channels, False)

    var = get_adaptation_layer(in_channels, out_channels, False)
    # Softplus keeps the predicted variance positive. Note the registered
    # name is str(len(var)+1), matching the original's (slightly odd) naming.
    var.add_module(str(len(var)+1), nn.Softplus())

    return mean, var
|
|
|
|
|
|
|
|
def get_adaptation_layer(in_channels, out_channels, adap_avg_pool):
    '''Return a small 1x1-conv adaptation head: Conv -> ReLU -> Conv.

    Spatial size is preserved (kernel 1, stride 1, padding 0); only the
    channel count changes from `in_channels` to `out_channels`.

    NOTE(review): `adap_avg_pool` is accepted but currently ignored —
    confirm with callers whether pooling support was ever intended.
    '''
    bottleneck = nn.Conv2d(in_channels=in_channels, out_channels=in_channels,
                           kernel_size=1, stride=1, padding=0)
    projection = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                           kernel_size=1, stride=1, padding=0)
    return nn.Sequential(bottleneck, nn.ReLU(), projection)
|
|
|
|
|
class AverageMeter(object):
    '''Tracks the latest value, running sum, count, and mean of a metric.'''

    name = 'No name'  # class-level default; __init__ always shadows it

    def __init__(self, name='No name'):
        self.name = name
        self.reset()

    def reset(self):
        '''Clear every accumulated statistic back to zero.'''
        self.sum = 0   # running weighted sum of observed values
        self.mean = 0  # running mean (sum / num)
        self.num = 0   # total observation count
        self.now = 0   # most recently observed value

    def update(self, mean_var, count=1):
        '''Fold `count` observations of value `mean_var` into the stats.

        NaN inputs are clamped to 1e6 (with a warning print) so a single
        bad batch cannot poison the running mean.
        '''
        if math.isnan(mean_var):
            print('Avgmeter getting Nan!')
            mean_var = 1e6
        self.now = mean_var
        self.num += count
        self.sum += mean_var * count
        self.mean = float(self.sum) / self.num
|
|
|
|
|
|