Spaces:
Sleeping
Sleeping
| import torch | |
| from torchvision import transforms | |
| from PIL import Image | |
| import torch.nn.functional as F | |
| import numpy as np | |
| import cv2 | |
| import os | |
def get_clothes_mask(old_label):
    """Return a binary mask of the pixels labelled as clothes.

    Args:
        old_label: Label tensor; elements equal to 3 are selected.

    Returns:
        A CPU ``float32`` tensor with the same shape as ``old_label`` that is
        1.0 where the label equals 3 and 0.0 everywhere else.
    """
    # Compare directly in torch instead of round-tripping through NumPy: the
    # previous ``astype(np.int)`` raises AttributeError on NumPy >= 1.24,
    # where the ``np.int`` alias no longer exists.
    return (old_label.detach().cpu() == 3).float()
def changearm(old_label):
    """Relabel the two arm classes (5 and 6) as class 3.

    Args:
        old_label: Label tensor of any shape.

    Returns:
        A new tensor in which every element equal to 5 or 6 is replaced by 3
        and every other element keeps its value. Integer inputs are promoted
        to floating point because the replacement is done by blending with a
        float mask.
    """
    # Build the mask with torch on the label's own device. The previous
    # NumPy path used ``np.int`` (removed in NumPy 1.24) and always produced
    # CPU masks, which cannot be combined with a label on another device.
    arms = ((old_label == 5) | (old_label == 6)).float()
    return old_label * (1 - arms) + arms * 3
def gen_noise(shape):
    """Create a random noise tensor of the requested shape.

    Args:
        shape: Shape of the array to generate.

    Returns:
        A ``float32`` torch tensor of shape ``shape``.
    """
    buffer = np.zeros(shape, dtype=np.uint8)
    # Draw normally distributed samples (mean 0, stddev 255) into the uint8
    # buffer.
    buffer = cv2.randn(buffer, 0, 255)
    # NOTE(review): the buffer is uint8 (max 255), so dividing by 255 and
    # casting back to uint8 can only leave the values 0 and 1. Kept as-is
    # because callers may rely on it -- confirm whether this is intended.
    quantized = np.asarray(buffer / 255, dtype=np.uint8)
    return torch.tensor(quantized, dtype=torch.float32)
def cross_entropy2d(input, target, weight=None, size_average=True):
    """Compute the pixel-wise cross-entropy between class scores and labels.

    Args:
        input: Tensor of shape (N, C, H, W) holding per-class scores.
        target: Tensor of shape (N, Ht, Wt) holding class indices. Elements
            equal to 250 are ignored.
        weight: Optional per-class weights forwarded to ``F.cross_entropy``.
        size_average: When true (or ``None``) the loss is averaged over the
            non-ignored elements; otherwise it is summed.

    Returns:
        The scalar loss tensor.
    """
    _, c, h, w = input.size()
    _, ht, wt = target.size()

    # Bring the scores to the label resolution when the two disagree.
    if h != ht or w != wt:
        input = F.interpolate(input, size=(ht, wt), mode="bilinear", align_corners=True)

    # (N, C, H, W) -> (N * H * W, C). ``reshape`` is used rather than
    # ``view`` so that non-contiguous targets (e.g. transposed or sliced
    # label maps) are accepted instead of raising.
    input = input.permute(0, 2, 3, 1).reshape(-1, c)
    target = target.reshape(-1)

    # ``size_average`` is deprecated in ``F.cross_entropy``; translate it to
    # the equivalent ``reduction`` value (None behaved like True).
    reduction = "mean" if (size_average is None or size_average) else "sum"
    return F.cross_entropy(
        input, target, weight=weight, reduction=reduction, ignore_index=250
    )
def ndim_tensor2im(image_tensor, imtype=np.uint8, batch=0):
    """Collapse one sample of a multi-channel tensor into an index map.

    Args:
        image_tensor: Tensor indexed as ``image_tensor[batch]``; the first
            axis of the selected sample is the one reduced by argmax.
        imtype: NumPy dtype of the returned array.
        batch: Index of the sample to convert.

    Returns:
        A NumPy array of dtype ``imtype`` holding, for every position, the
        index of the largest value along the sample's first axis.
    """
    scores = image_tensor[batch].cpu().float().numpy()
    return scores.argmax(axis=0).astype(imtype)
def visualize_segmap(input, multi_channel=True, tensor_out=True, batch=0):
    """Render one sample of a segmentation map as a colour image.

    Args:
        input: Segmentation tensor. With ``multi_channel`` it is reduced to
            an index map via ``ndim_tensor2im``; otherwise
            ``input[batch][0]`` is used directly as the index map.
        multi_channel: Whether ``input`` holds one channel per class.
        tensor_out: When true, return an RGB tensor produced by
            ``transforms.ToTensor``; otherwise return the palettised PIL
            image.
        batch: Index of the sample to render.

    Returns:
        Either a tensor (``tensor_out=True``) or a mode-'P' PIL image.
    """
    # Flat list of 20 RGB triplets; colour i is used for label index i.
    palette = [
        0, 0, 0, 128, 0, 0, 254, 0, 0, 0, 85, 0, 169, 0, 51,
        254, 85, 0, 0, 0, 85, 0, 119, 220, 85, 85, 0, 0, 85, 85,
        85, 51, 0, 52, 86, 128, 0, 128, 0, 0, 0, 254, 51, 169, 220,
        0, 254, 254, 85, 254, 169, 169, 254, 85, 254, 254, 0, 254, 169, 0
    ]

    segmap = input.detach()
    if multi_channel:
        label_map = ndim_tensor2im(segmap, batch=batch)
    else:
        label_map = segmap[batch][0].cpu()
    label_map = np.asarray(label_map).astype(np.uint8)

    indexed = Image.fromarray(label_map, 'P')
    indexed.putpalette(palette)

    if not tensor_out:
        return indexed
    to_tensor = transforms.ToTensor()
    return to_tensor(indexed.convert('RGB'))
def pred_to_onehot(prediction, num_classes=13):
    """Convert per-class scores into a one-hot encoding of the argmax class.

    Args:
        prediction: Tensor of shape (N, C, H, W) holding per-class scores.
        num_classes: Number of channels in the returned encoding. Defaults
            to 13, the value that used to be hard-coded here. Every argmax
            index must be smaller than this value.

    Returns:
        A ``float32`` tensor of shape (N, num_classes, H, W), on the same
        device as ``prediction``, that is 1.0 in the channel of the
        highest-scoring class at each position and 0.0 elsewhere.
    """
    batch, _, height, width = prediction.shape
    # keepdim=True leaves the class axis in place so the indices can be
    # handed straight to scatter_.
    winners = torch.argmax(prediction.detach(), dim=1, keepdim=True)
    onehot = torch.zeros(
        batch, num_classes, height, width,
        dtype=torch.float32, device=prediction.device,
    )
    return onehot.scatter_(1, winners, 1.0)
def cal_miou(prediction, target):
    """Compute the intersection-over-union of classes 1 to 8.

    Intersections and unions are accumulated over all samples and over
    classes 1..8 before dividing, so the result is a single pooled ratio
    rather than an average of per-class IoUs.

    Args:
        prediction: Tensor of shape (N, C, H, W) holding per-class scores;
            the predicted class at each position is the argmax over C.
        target: Tensor indexed as ``target[:, c]`` for c in 1..8; non-zero
            entries mark positions belonging to class ``c``.

    Returns:
        ``intersection / union`` as a Python float, or 0.0 when none of the
        classes 1..8 occurs in either the prediction or the target (this
        case used to raise ZeroDivisionError).
    """
    target = target.detach().cpu()
    predicted = torch.argmax(prediction.detach().cpu(), dim=1)

    intersection = 0
    union = 0
    for class_id in range(1, 9):
        pred_mask = predicted == class_id
        target_mask = target[:, class_id] != 0
        intersection += (pred_mask & target_mask).sum().item()
        union += (pred_mask | target_mask).sum().item()

    if union == 0:
        return 0.0
    return intersection / union
def save_images(img_tensors, img_names, save_dir):
    """Write a sequence of image tensors to ``save_dir`` as JPEG files.

    Args:
        img_tensors: Iterable of image tensors. Values are rescaled with
            ``(x + 1) * 0.5 * 255`` and clamped to [0, 255], i.e. the
            tensors are presumably in [-1, 1] -- confirm against callers.
            A leading axis of size 1 is dropped and a leading axis of size 3
            is moved last before the array is handed to PIL.
        img_names: File names, paired positionally with ``img_tensors``.
        save_dir: Directory the files are written into.
    """
    for img_tensor, img_name in zip(img_tensors, img_names):
        scaled = (img_tensor + 1) * 0.5 * 255
        # Always detach before converting to NumPy. The previous code tried
        # ``.numpy()`` first and used a bare ``except:`` to retry with
        # ``.detach()``, which also swallowed unrelated errors.
        array = scaled.detach().cpu().clamp(0, 255).numpy().astype('uint8')

        if array.shape[0] == 1:
            array = array.squeeze(0)
        elif array.shape[0] == 3:
            # (C, H, W) -> (H, W, C)
            array = array.transpose(1, 2, 0)

        im = Image.fromarray(array)
        im.save(os.path.join(save_dir, img_name), format='JPEG')
def create_network(cls, opt):
    """Instantiate, optionally move to GPU, and initialise a network.

    Args:
        cls: Callable invoked as ``cls(opt)`` to build the network. The
            result must provide ``print_network``, ``cuda`` and
            ``init_weights``.
        opt: Options object; ``gpu_ids``, ``init_type`` and
            ``init_variance`` are read from it.

    Returns:
        The constructed network.
    """
    network = cls(opt)
    network.print_network()

    wants_gpu = len(opt.gpu_ids) > 0
    if wants_gpu:
        # GPU ids were requested, so CUDA must actually be usable.
        assert torch.cuda.is_available()
        network.cuda()

    network.init_weights(opt.init_type, opt.init_variance)
    return network