Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torchvision import utils | |
| from torchvision.transforms import Resize | |
| from collections import OrderedDict | |
| import numpy as np | |
| import matplotlib.cm as cm | |
| import matplotlib as mpl | |
| from torchvision.transforms import InterpolationMode | |
| from .abs_model import abs_model | |
| from .blocks import * | |
| from .SSN import SSN | |
| from .SSN_v1 import SSN_v1 | |
| from .Loss.Loss import norm_loss, grad_loss | |
| from .Attention_Unet import Attention_Unet | |
class Sparse_PH(abs_model):
    """Training / inference wrapper around a dense-prediction backbone.

    The backbone is chosen by ``opt['model']['backbone']``:
    ``'Default'`` builds ``SSN_v1`` and ``'ATTN'`` builds ``Attention_Unet``.
    This class also owns the optimizer and the weighted loss terms used by
    :meth:`compute_loss`.
    """

    def __init__(self, opt):
        """Read hyper-parameters from ``opt`` and build the model, optimizer and losses.

        Args:
            opt: nested configuration dict. Keys are read from the
                ``'model'``, ``'hyper_params'`` and ``'dataset'`` sections.

        Raises:
            ValueError: if ``opt['model']['backbone']`` is neither
                ``'Default'`` nor ``'ATTN'``.
        """
        # NOTE(review): abs_model.__init__ is not called here (same as the
        # original code); confirm the base class needs no initialisation.
        mid_act = opt['model']['mid_act']
        out_act = opt['model']['out_act']
        in_channels = opt['model']['in_channels']
        out_channels = opt['model']['out_channels']
        resnet = opt['model']['resnet']
        backbone = opt['model']['backbone']

        # Number of samples shown per visualisation grid.
        self.ncols = opt['hyper_params']['n_cols']
        self.focal = opt['model']['focal']
        self.clip = opt['model']['clip']

        # Per-term loss weights. The trailing underscore avoids a clash with
        # the loss objects ``self.norm_loss`` / ``self.grad_loss`` below;
        # ``lap_loss`` is also a weight, despite lacking the underscore.
        self.norm_loss_ = opt['model']['norm_loss']
        self.grad_loss_ = opt['model']['grad_loss']
        self.ggrad_loss_ = opt['model']['ggrad_loss']
        self.lap_loss = opt['model']['lap_loss']

        # Upper bound used when ``clip`` is enabled in :meth:`supervise`.
        self.clip_range = opt['dataset']['linear_scale'] + opt['dataset']['linear_offset']

        if backbone == 'Default':
            self.model = SSN_v1(in_channels=in_channels,
                                out_channels=out_channels,
                                mid_act=mid_act,
                                out_act=out_act,
                                resnet=resnet)
        elif backbone == 'ATTN':
            self.model = Attention_Unet(in_channels, out_channels, mid_act=mid_act, out_act=out_act)
        else:
            # Previously fell through with ``self.model`` unset and crashed
            # with an AttributeError on the next line.
            raise ValueError(
                "Unsupported backbone {!r}; expected 'Default' or 'ATTN'".format(backbone))

        self.optimizer = get_optimizer(opt, self.model)
        self.visualization = {}
        # Filled by compute_loss(); initialised so get_logs() is safe before
        # the first loss computation.
        self.loss_log = {}

        self.norm_loss = norm_loss()
        self.grad_loss = grad_loss()

    def setup_input(self, x):
        """Return ``x`` unchanged (no input preprocessing for this model)."""
        return x

    def forward(self, x):
        """Run the backbone on ``x``."""
        return self.model(x)

    def compute_loss(self, y, pred):
        """Return the weighted sum of all configured loss terms.

        Each term is multiplied by its weight from ``opt['model']``. Each
        weighted term is recorded as a Python float in ``self.loss_log``.
        When ``focal`` is enabled, the summed loss is raised to the third power.

        Args:
            y: ground-truth tensor.
            pred: prediction tensor from :meth:`forward`.

        Returns:
            Scalar loss tensor.
        """
        nloss = self.norm_loss.loss(y, pred) * self.norm_loss_
        gloss = self.grad_loss.loss(pred) * self.grad_loss_
        ggloss = self.grad_loss.gloss(y, pred) * self.ggrad_loss_
        laploss = self.grad_loss.laploss(pred) * self.lap_loss

        total_loss = nloss + gloss + ggloss + laploss

        self.loss_log = {
            'norm_loss': nloss.item(),
            'grad_loss': gloss.item(),
            'grad_l1_loss': ggloss.item(),
            'lap_loss': laploss.item(),
        }

        if self.focal:
            total_loss = torch.pow(total_loss, 3)

        return total_loss

    def supervise(self, input_x, y, is_training: bool) -> float:
        """Run one supervised step and cache tensors for visualisation.

        Args:
            input_x: dict holding the network input under key ``'x'``.
            y: ground-truth tensor. Channels 0 and 1 are stored as
                ``y_fore`` / ``y_back`` for visualisation.
            is_training: when true, back-propagate and step the optimizer.

        Returns:
            The scalar loss as a Python float.
        """
        optimizer = self.optimizer
        x = input_x['x']

        optimizer.zero_grad()
        # Build an autograd graph only when it will be back-propagated, so
        # evaluation passes do not keep activations alive.
        with torch.set_grad_enabled(bool(is_training)):
            pred = self.forward(x)
            if self.clip:
                pred = torch.clip(pred, 0.0, self.clip_range)
            loss = self.compute_loss(y, pred)

        if is_training:
            loss.backward()
            optimizer.step()

        # One visualisation entry per input channel.
        for i in range(x.shape[1]):
            self.visualization['x{}'.format(i)] = x[:, i:i+1].detach()

        self.visualization['y_fore'] = y[:, 0:1].detach()
        self.visualization['y_back'] = y[:, 1:2].detach()
        self.visualization['pred_fore'] = pred[:, 0:1].detach()
        self.visualization['pred_back'] = pred[:, 1:2].detach()

        return loss.item()

    def get_visualize(self) -> OrderedDict:
        """Convert the cached visualisation tensors to colour-mapped numpy images.

        For each entry, at most ``self.ncols`` samples are tiled with
        ``make_grid``. The grid is clipped to [0, 1] and mapped through the
        ``plasma`` colormap.
        """
        nrows = self.ncols
        visualizations = self.visualization
        ret_vis = OrderedDict()
        for k, v in visualizations.items():
            n = min(nrows, v.shape[0])
            plot_v = v[:n]
            # make_grid's ``nrow`` is the number of images per grid row.
            grid = utils.make_grid(plot_v.cpu(), nrow=nrows).numpy().transpose(1, 2, 0)
            ret_vis[k] = self.plasma(np.clip(grid, 0.0, 1.0))

        return ret_vis

    def get_logs(self):
        """Return the per-term loss values recorded by the last compute_loss()."""
        return self.loss_log

    def inference(self, x):
        """Predict on a single HWC numpy image.

        Args:
            x: dict with ``'x'`` (an HWC numpy array) and ``'device'``
                (a torch device to run on).

        Returns:
            The prediction for that image as an HWC numpy array.
        """
        x, device = x['x'], x['device']
        x = torch.from_numpy(x.transpose((2, 0, 1))).unsqueeze(dim=0).float().to(device)
        # The result is detached to numpy anyway; skip graph construction.
        with torch.no_grad():
            pred = self.forward(x)

        pred = pred[0].detach().cpu().numpy().transpose((1, 2, 0))
        return pred

    def batch_inference(self, x):
        """Run the backbone on the batch tensor stored under ``x['x']``."""
        x = x['x']
        pred = self.forward(x)
        return pred

    # ------------------- #
    # Getters & Setters   #
    # ------------------- #
    def get_models(self) -> dict:
        """Return the trainable networks keyed by name."""
        return {'model': self.model}

    def get_optimizers(self) -> dict:
        """Return the optimizers keyed by name."""
        return {'optimizer': self.optimizer}

    def set_models(self, models: dict):
        """Replace the backbone with ``models['model']``.

        Raises:
            ValueError: if ``models`` has no ``'model'`` key.
        """
        if 'model' not in models.keys():
            raise ValueError('{} not in self.model'.format('model'))

        self.model = models['model']

    def set_optimizers(self, optimizer: dict):
        """Replace the optimizer with ``optimizer['optimizer']``."""
        self.optimizer = optimizer['optimizer']

    ####################
    # Personal Methods #
    ####################
    def plasma(self, x):
        """Map channel 0 of an HxWxC array in [0, 1] to an HxWx3 plasma RGB image."""
        norm = mpl.colors.Normalize(vmin=0.0, vmax=1)
        mapper = cm.ScalarMappable(norm=norm, cmap='plasma')
        bimg = mapper.to_rgba(x[:, :, 0])[:, :, :3]
        return bimg