| | import torch |
| | import torch.nn as nn |
| | import os |
| | import time |
| | from tools import mutils |
| |
|
| | saved_grad = None |
| | saved_name = None |
| |
|
| | base_url = './results' |
| | os.makedirs(base_url, exist_ok=True) |
| |
|
| |
|
| | def normalize_tensor_mm(tensor): |
| | return (tensor - tensor.min()) / (tensor.max() - tensor.min()) |
| |
|
| |
|
| | def normalize_tensor_sigmoid(tensor): |
| | return nn.functional.sigmoid(tensor) |
| |
|
| |
|
| | def save_image(tensor, name=None, save_path=None, exit_flag=False, timestamp=False, norm=False): |
| | import torchvision.utils as vutils |
| | os.makedirs(base_url, exist_ok=True) |
| | if norm: |
| | tensor = normalize_tensor_mm(tensor) |
| | grid = vutils.make_grid(tensor.detach().cpu(), nrow=4) |
| |
|
| | if save_path: |
| | vutils.save_image(grid, save_path) |
| | else: |
| | if timestamp: |
| | vutils.save_image(grid, f'{base_url}/{name}_{mutils.get_timestamp()}.png') |
| | else: |
| | vutils.save_image(grid, f'{base_url}/{name}.png') |
| | if exit_flag: |
| | exit(0) |
| |
|
| |
|
| | def save_feature(tensor, name, exit_flag=False, timestamp=False): |
| | import torchvision.utils as vutils |
| | |
| | tensors = [tensor] |
| | titles = ['original', 'min-max', 'sigmoid'] |
| | os.makedirs(base_url, exist_ok=True) |
| | if timestamp: |
| | name += '_' + str(time.time()).replace('.', '') |
| |
|
| | for index, tensor in enumerate(tensors): |
| | _data = tensor.detach().cpu().squeeze(0).unsqueeze(1) |
| | num_per_row = 8 |
| | grid = vutils.make_grid(_data, nrow=num_per_row) |
| | vutils.save_image(grid, f'{base_url}/{name}_{titles[index]}.png') |
| | if exit_flag: |
| | exit(0) |
| |
|