# RDNet / tools/saver.py
# lime-j's picture
# Upload 89 files
# 347b44e
import torch
import torch.nn as nn
import os
import time
from tools import mutils
# Scratch state shared between calls to save_grid(): the first gradient it
# receives, and the name it was reported under, are parked here so that a
# later call can divide its own gradient by the stored one.
saved_grad, saved_name = None, None

# Default output directory for the save_* helpers in this module. It is
# created at import time so the helpers can write into it straight away.
base_url = './results'
os.makedirs(base_url, exist_ok=True)
def normalize_tensor_mm(tensor):
    """Min-max rescale ``tensor`` so its smallest value maps to 0 and its largest to 1.

    The minimum and maximum are taken over the whole tensor, not per
    channel or per sample.

    Args:
        tensor: A non-empty ``torch.Tensor``.

    Returns:
        A tensor of the same shape as ``tensor``. When every element of the
        input is identical (max == min) the result is all zeros instead of
        the NaNs a plain ``0 / 0`` division would produce.
    """
    low = tensor.min()
    span = tensor.max() - low
    # A constant tensor has zero range; dividing by it would fill the output
    # with NaN. Dividing by one instead leaves the (all-zero) numerator as is.
    if span == 0:
        span = torch.ones_like(span)
    return (tensor - low) / span
def normalize_tensor_sigmoid(tensor):
    """Squash every element of ``tensor`` into (0, 1) with the logistic sigmoid.

    Args:
        tensor: A ``torch.Tensor`` of any shape.

    Returns:
        A tensor of the same shape holding ``1 / (1 + exp(-x))`` per element.
    """
    squashed = torch.sigmoid(tensor)
    return squashed
def save_image(tensor, name=None, save_path=None, exit_flag=False, timestamp=False, nrow=4, split_dir=None):
    """Tile a batch of images into one grid and write it to disk as an image.

    Args:
        tensor: Image batch accepted by ``torchvision.utils.make_grid``; it is
            detached and moved to the CPU before tiling.
        name: File stem used when ``save_path`` is not given. It is formatted
            into the path as-is, so leaving both ``name`` and ``save_path``
            unset produces a file called ``None.png``.
        save_path: Explicit output file. When set it wins over ``name`` and
            ``timestamp``.
        exit_flag: When true, terminate the process with status 0 after saving.
        timestamp: When true (and ``save_path`` is unset), append
            ``mutils.get_timestamp()`` to the file stem.
        nrow: Forwarded to ``make_grid`` as its ``nrow`` argument.
        split_dir: Optional sub-directory of ``base_url`` to write into. The
            directory is created even when ``save_path`` is used.

    Raises:
        SystemExit: If ``exit_flag`` is true.
    """
    # Local imports, in keeping with the rest of this module, so importing the
    # module itself stays cheap.
    import sys

    import torchvision.utils as vutils

    out_dir = os.path.join(base_url, split_dir) if split_dir else base_url
    os.makedirs(out_dir, exist_ok=True)

    grid = vutils.make_grid(tensor.detach().cpu(), nrow=nrow)

    # Decide the destination once, then write exactly once.
    if save_path:
        target = save_path
    elif timestamp:
        target = f'{out_dir}/{name}_{mutils.get_timestamp()}.png'
    else:
        target = f'{out_dir}/{name}.png'
    vutils.save_image(grid, target)

    if exit_flag:
        # sys.exit rather than the `exit` helper injected by the site module,
        # which is meant for interactive sessions and may not be defined.
        sys.exit(0)
def save_feature(tensor, name, exit_flag=False, timestamp=False):
    """Save a feature map as a grid image under ``base_url`` and print the path.

    Args:
        tensor: Feature map to dump. It is detached, moved to the CPU,
            squeezed on dim 0 and given a new dim 1 before tiling --
            presumably a (1, C, H, W) or (C, H, W) map so that every channel
            becomes one grey tile; TODO confirm against callers.
        name: File stem; ``_original`` is appended for the single variant
            that is currently written.
        exit_flag: When true, terminate the process with status 0 afterwards.
        timestamp: When true, append ``_`` plus the digits of ``time.time()``
            (decimal point removed) to ``name``.
    """
    import torchvision.utils as vutils
    # Only the raw tensor is written at present. The normalised variants below
    # are disabled, which is why `titles` has more entries than `tensors`.
    # tensors = [tensor, normalize_tensor_mm(tensor), normalize_tensor_sigmoid(tensor)]
    tensors = [tensor]
    titles = ['original', 'min-max', 'sigmoid']
    if timestamp:
        name += '_' + str(time.time()).replace('.', '')
    for index, tensor in enumerate(tensors):
        # Drop the leading dim and insert a singleton dim at position 1 so
        # make_grid treats each slice along dim 0 as its own image.
        _data = tensor.detach().cpu().squeeze(0).unsqueeze(1)
        num_per_row = 4
        if _data.shape[0] / 4 > 4:
            num_per_row = int(_data.shape[0] / 4)
        # NOTE(review): this assignment is read as an unconditional override,
        # which makes the computation above dead code. The original nesting
        # could not be recovered -- confirm it was not meant to sit inside the
        # `if` above.
        num_per_row = 8
        grid = vutils.make_grid(_data, nrow=num_per_row)
        vutils.save_image(grid, f'{base_url}/{name}_{titles[index]}.png')
        print(f'{base_url}/{name}_{titles[index]}.png')
    if exit_flag:
        exit(0)
def save(tensor, name, exit_flag=False):
    """Write ``tensor`` to ``{base_url}/{name}.png`` as a grid with four tiles per row.

    Args:
        tensor: Tensor to dump. It is detached, moved to the CPU, squeezed on
            dim 0 and given a new singleton dim 1 before being tiled, so each
            slice along the resulting dim 0 becomes one tile.
        name: File stem of the image written below ``base_url``.
        exit_flag: When true, terminate the process with status 0 after saving.

    Raises:
        SystemExit: If ``exit_flag`` is true.
    """
    import sys

    import torchvision.utils as vutils

    tiles = tensor.detach().cpu().squeeze(0).unsqueeze(1)
    grid = vutils.make_grid(tiles, nrow=4)
    vutils.save_image(grid, f'{base_url}/{name}.png')

    if exit_flag:
        # sys.exit rather than the `exit` helper injected by the site module,
        # which is meant for interactive sessions and may not be defined.
        sys.exit(0)
def save_grid_direct(grad, name):
    """Dump a gradient as an image, print summary statistics and plot its histogram.

    The gradient is viewed as ``(1, 8, 320, 320)`` -- so it must hold exactly
    ``8 * 320 * 320`` elements -- and scaled by ``255 / (320 * 320)``.

    Args:
        grad: Gradient tensor to visualise.
        name: File stem handed to ``save`` and label used in the printed stats.
    """
    scaled = grad.view(1, 8, 320, 320) * 255 / (320 * 320)

    # Image dump keeps only the non-negative part, capped at 255.
    save(scaled.clamp(0, 255), name)

    # The statistics and the histogram use a symmetric clip instead.
    clipped = scaled.clamp(-200, 200)
    print(clipped.min().item(), clipped.max().item(), clipped.mean().item())

    flat = clipped.flatten()
    negatives = flat[flat < 0]
    zeros = flat[flat == 0]
    # name, fraction of negative entries, negative count, exact-zero count
    print(name, len(negatives) / len(flat),
          len(negatives), len(zeros))

    import matplotlib.pyplot as plt
    import numpy as np

    density, edges = np.histogram(clipped.cpu().flatten().numpy(), bins=50, density=True)
    # One bar per bin, positioned at the bin's left edge.
    plt.bar(edges[:-1], density)
    plt.show()
def save_grid(grad, name, exit_flag=False):
    """Stash the first gradient seen, then visualise the ratio of the next one to it.

    Behaviour depends on the module-level ``saved_grad``:

    * While it is ``None``, ``grad`` and ``name`` are stored in ``saved_grad``
      and ``saved_name`` and the function returns (or exits if ``exit_flag``).
    * Otherwise ``grad`` is divided element-wise by the stored gradient, the
      ratio is written out through ``save``, summary statistics are printed,
      a 50-bin histogram is shown and the process exits with status 0.

    Args:
        grad: Gradient tensor; its shape is printed on every call.
        name: Label used for the printed statistics and the saved image.
        exit_flag: When true, terminate the process with status 0 at the end.
    """
    global saved_grad, saved_name
    print(grad.shape)
    if saved_grad is None:
        # First call: nothing to compare against yet, so remember this one.
        print(name)
        saved_grad = grad
        saved_name = name
    else:
        # Element-wise ratio to the stored gradient; the 1e-7 offset means a
        # stored value of exactly 0 does not divide by zero.
        module_grad = grad / (saved_grad + 1e-7)
        print(module_grad.max())
        # Image dump: keep ratios in [0, 255] and rescale them to [0, 1].
        save(module_grad.clamp(0, 255) / 255., name)
        # Statistics and histogram use a symmetric clip instead.
        module_grad = module_grad.clamp(-300, 300)
        print(module_grad.min().item(), module_grad.max().item(), module_grad.mean().item())
        module_grad_flat = module_grad.flatten()
        # name, fraction of negative entries, negative count, exact-zero count
        print(name, len(module_grad_flat[module_grad_flat < 0]) / len(module_grad_flat),
              len(module_grad_flat[module_grad_flat < 0]), len(module_grad_flat[module_grad_flat == 0]))
        import matplotlib.pyplot as plt
        import numpy as np
        y, x = np.histogram(module_grad.cpu().flatten().numpy(), bins=50, density=True)
        # One bar per bin, positioned at the bin's left edge.
        plt.bar(x[:-1], y)
        plt.show()
        # NOTE(review): this unconditional exit is read as belonging to the
        # comparison branch; at function level it would make the exit_flag
        # check below unreachable. Confirm against the original layout.
        exit(0)
    if exit_flag:
        exit(0)
def show_grid(grid, name, exit_flag=False):
    """Min-max normalise ``grid`` and display it in a matplotlib window.

    Args:
        grid: Tensor to display. After normalisation it is detached, moved to
            the CPU, squeezed on dim 0 and given a new singleton dim 1, so
            each slice along the resulting dim 0 becomes one tile (four tiles
            per row).
        name: Title of the plot.
        exit_flag: When true, terminate the process with status 0 afterwards.

    Raises:
        SystemExit: If ``exit_flag`` is true.
    """
    import sys

    import matplotlib.pyplot as plt
    import torchvision.transforms as vtrans
    import torchvision.utils as vutils

    low = grid.min()
    span = grid.max() - low
    # A constant input has zero range; dividing by it would turn the whole
    # picture into NaN. Divide by one instead so it renders as all zeros.
    if span == 0:
        span = torch.ones_like(span)
    normalised = (grid - low) / span

    # detach() mirrors save()/save_image(), so tensors still attached to the
    # autograd graph are handled the same way by every helper here.
    tiles = vutils.make_grid(normalised.detach().cpu().squeeze(0).unsqueeze(1), nrow=4)
    plt.imshow(vtrans.ToPILImage()(tiles))
    plt.title(name)
    plt.show()

    if exit_flag:
        # sys.exit rather than the `exit` helper injected by the site module,
        # which is meant for interactive sessions and may not be defined.
        sys.exit(0)
def show_img(img, name, exit_flag=False):
    """Display an image tensor in a matplotlib window.

    Args:
        img: Tensor to display. It is detached, moved to the CPU and squeezed
            on dim 0 before being handed to ``torchvision.utils.make_grid``.
        name: Title of the plot.
        exit_flag: When true, terminate the process with status 0 afterwards.

    Raises:
        SystemExit: If ``exit_flag`` is true.
    """
    import sys

    import matplotlib.pyplot as plt
    import torchvision.transforms as vtrans
    import torchvision.utils as vutils

    # detach() mirrors save()/save_image(), so tensors still attached to the
    # autograd graph are handled the same way by every helper here.
    picture = vutils.make_grid(img.detach().cpu().squeeze(0))
    plt.imshow(vtrans.ToPILImage()(picture))
    plt.title(name)
    plt.show()

    if exit_flag:
        # sys.exit rather than the `exit` helper injected by the site module,
        # which is meant for interactive sessions and may not be defined.
        sys.exit(0)
class SaverBlock(nn.Module):
    """Pass-through module that dumps what flows through it.

    On every forward pass the first element of the input is written to disk
    through ``save_feature``; the input itself is returned unchanged. The
    module has no parameters of its own.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Save ``x[0]`` under a timestamped ``intermediate_`` name and return ``x``."""
        first = x[0]
        save_feature(first, 'intermediate_', timestamp=True)
        return x