import os
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.transforms.functional import rotate
import config as c
import sklearn.metrics as sk
import numpy as np
from copy import deepcopy


def stable_cumsum(arr, rtol=1e-05, atol=1e-08):
    """Use high precision for cumsum and check that final value matches sum

    Parameters
    ----------
    arr : array-like
        To be cumulatively summed as flat
    rtol : float
        Relative tolerance, see ``np.allclose``
    atol : float
        Absolute tolerance, see ``np.allclose``

    Returns
    -------
    numpy.ndarray
        Flat float64 cumulative sum of ``arr`` (empty if ``arr`` is empty).

    Raises
    ------
    RuntimeError
        If the last cumulative value disagrees with ``np.sum(arr)`` beyond
        the given tolerances.
    """
    out = np.cumsum(arr, dtype=np.float64)
    expected = np.sum(arr, dtype=np.float64)
    # Guard: ``out[-1]`` would raise IndexError for an empty input.
    if out.size and not np.allclose(out[-1], expected, rtol=rtol, atol=atol):
        raise RuntimeError('cumsum was found to be unstable: '
                           'its last element does not correspond to sum')
    return out


def fpr_and_fdr_at_recall(y_true, y_score, recall_level=0.95, pos_label=None):
    """Return the false-positive rate and score threshold at ``recall_level``.

    The operating point chosen is the one on the ROC curve whose recall
    (true-positive rate) is closest to ``recall_level``.

    Parameters
    ----------
    y_true : array-like, shape (n,)
        Ground-truth labels.
    y_score : array-like, shape (n,)
        Scores; higher means "more likely positive".
    recall_level : float
        Target recall.
    pos_label : scalar, optional
        Label of the positive class. When None, the label set must be one of
        {0, 1}, {-1, 1}, {0}, {-1} or {1}, and 1 is used.

    Returns
    -------
    (fpr, threshold) : tuple
        ``fpr`` is the false positives at the cutoff divided by the number of
        negatives; ``threshold`` is the score at that cutoff. With no
        negatives (or no positives) the divisions below yield inf/nan.

    Raises
    ------
    ValueError
        If the labels are not binary and ``pos_label`` is None.
    """
    # Accept lists as well as arrays: with a plain list, ``y_true == pos_label``
    # is a single False and ``y_score[indices]`` fails.
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score)

    classes = np.unique(y_true)
    if (pos_label is None and
            not (np.array_equal(classes, [0, 1]) or
                 np.array_equal(classes, [-1, 1]) or
                 np.array_equal(classes, [0]) or
                 np.array_equal(classes, [-1]) or
                 np.array_equal(classes, [1]))):
        raise ValueError("Data is not binary and pos_label is not specified")
    elif pos_label is None:
        pos_label = 1.

    # make y_true a boolean vector
    y_true = (y_true == pos_label)

    # sort scores and corresponding truth values (descending; mergesort is stable)
    desc_score_indices = np.argsort(y_score, kind="mergesort")[::-1]
    y_score = y_score[desc_score_indices]
    y_true = y_true[desc_score_indices]

    # y_score typically has many tied values. Here we extract
    # the indices associated with the distinct values. We also
    # concatenate a value for the end of the curve.
    distinct_value_indices = np.where(np.diff(y_score))[0]
    threshold_idxs = np.r_[distinct_value_indices, y_true.size - 1]

    # accumulate the true positives with decreasing threshold
    tps = stable_cumsum(y_true)[threshold_idxs]
    fps = 1 + threshold_idxs - tps  # add one because of zero-based indexing

    thresholds = y_score[threshold_idxs]

    recall = tps / tps[-1]

    # Keep the curve up to the first point reaching full recall, reversed.
    last_ind = tps.searchsorted(tps[-1])
    sl = slice(last_ind, None, -1)  # [last_ind::-1]
    recall, fps, tps, thresholds = (np.r_[recall[sl], 1], np.r_[fps[sl], 0],
                                    np.r_[tps[sl], 0], thresholds[sl])

    # recall[0] == tps[last_ind] / tps[-1] is exactly 1.0, and argmin returns
    # the first minimum, so the appended sentinel (recall 1) is never chosen
    # and thresholds[cutoff] stays in range.
    cutoff = np.argmin(np.abs(recall - recall_level))

    return fps[cutoff] / (np.sum(np.logical_not(y_true))), thresholds[cutoff]


def _color_jitter_transforms():
    """Return ``[ColorJitter(...)]`` if any config jitter strength is > 0, else ``[]``."""
    if c.transf_brightness > 0.0 or c.transf_contrast > 0.0 or c.transf_saturation > 0.0:
        return [transforms.ColorJitter(brightness=c.transf_brightness,
                                       contrast=c.transf_contrast,
                                       saturation=c.transf_saturation)]
    return []


def _build_pipeline(augmentative_transforms):
    """Wrap augmentations as Resize -> augmentations -> ToTensor -> Normalize."""
    tfs = ([transforms.Resize(c.img_size)]
           + augmentative_transforms
           + [transforms.ToTensor(), transforms.Normalize(c.norm_mean, c.norm_std)])
    return transforms.Compose(tfs)


def get_random_transforms():
    """Build the randomized training pipeline from ``config``.

    The pipeline is Resize(``c.img_size``), then ``RandomRotation(180)`` when
    ``c.transf_rotations`` is truthy, then ColorJitter when any of
    ``c.transf_brightness/contrast/saturation`` is > 0, then ToTensor and
    Normalize(``c.norm_mean``, ``c.norm_std``).
    """
    augmentative_transforms = []
    if c.transf_rotations:
        augmentative_transforms += [transforms.RandomRotation(180)]
    augmentative_transforms += _color_jitter_transforms()
    transform_train = _build_pipeline(augmentative_transforms)
    return transform_train


def get_fixed_transforms(degrees):
    """Build the same pipeline as ``get_random_transforms`` with a fixed rotation.

    Every image is rotated by exactly ``degrees``. ColorJitter, when enabled
    in ``config``, is still random.
    """
    def cust_rot(x):
        # Arguments after the angle are passed positionally, as before:
        # (False, False, None). In torchvision's ``rotate`` these are the
        # resampling/interpolation mode, ``expand`` and ``center`` --
        # presumably nearest, no expansion, image centre; confirm against
        # the installed torchvision version.
        return rotate(x, degrees, False, False, None)

    return _build_pipeline([cust_rot] + _color_jitter_transforms())


def t2np(tensor):
    '''pytorch tensor -> numpy array (None passes through unchanged)'''
    # detach() is the supported replacement for the legacy ``.data`` access.
    return tensor.detach().cpu().numpy() if tensor is not None else None


def get_loss(z, jac):
    '''check equation 4 of the paper why this makes sense - oh and just ignore the scaling here

    Mean over the batch of ``0.5 * ||z||^2 - jac`` (standard-normal negative
    log-likelihood up to a constant, minus the log-Jacobian term), divided by
    ``z.shape[1]``. Assumes ``z`` is (batch, dim) and ``jac`` is (batch,) --
    TODO confirm against the caller.
    '''
    return torch.mean(0.5 * torch.sum(z ** 2, dim=(1,)) - jac) / z.shape[1]


# Earlier variant, kept for reference:
# def get_loss_neg_pos(z, jac, labels):
#     '''Loss: positive samples close to a Gaussian, negative samples far from it.'''
#     # standard flow-model generative loss
#     normalizing_loss = torch.mean(0.5 * torch.sum(z ** 2, dim=(1,)) - jac) / z.shape[1]
#     # positive samples (label 0): latent features should be close to the Gaussian
#     positive_loss = normalizing_loss * (labels == 0).float()
#     # negative samples (label 1): latent features should be far from the Gaussian
#     negative_loss = -normalizing_loss * (labels == 1).float()
#     # total loss
#     total_loss = torch.mean(positive_loss + negative_loss)
#     return total_loss


def get_loss_neg_pos(z, jac, labels, target_distribution="gaussian", margin=500):
    """Label-conditional flow loss.

    Per-sample terms:
      * ``labels == 0``: ``0.5 * ||z - 10||^2 - jac`` (unit Gaussian centred at 10)
      * ``labels == 1``: ``0.5 * ||z||^2 - jac``      (unit Gaussian centred at 0)
      * any other label contributes 0.
    The batch mean is divided by ``z.shape[1]``, as in ``get_loss``.

    NOTE(review): ``target_distribution`` and ``margin`` are accepted but
    unused. The commented-out variant above wanted label-0 samples near the
    zero-mean Gaussian; here they are pulled towards mean 10 instead --
    confirm this is intended. Assumes ``z`` is (batch, dim) -- TODO confirm.
    """
    # standard flow-model generative loss, one term per target centre
    # (open question from the author: should these losses all be > 0?)
    loss_sample_pos = 0.5 * torch.sum((z - 10) ** 2, dim=(1,)) - jac
    loss_sample_neg = 0.5 * torch.sum(z ** 2, dim=(1,)) - jac
    positive_loss = loss_sample_pos * (labels == 0).float()
    negative_loss = loss_sample_neg * (labels == 1).float()
    # total loss
    total_loss = torch.mean(positive_loss + negative_loss) / z.shape[1]
    return total_loss


def get_loss_neg_pos_margin(z, jac, labels, margin=500):
    """Margin-based positive/negative flow loss -- BODY LOST, see NOTE.

    NOTE(review): in the copy of this file under review, the text after the
    ``loss_sample`` line was lost. It ran from a ``<`` inside a commented-out
    line up to a later ``>``, most likely stripped as markup, and may have
    swallowed further function definitions. What survived after the gap
    references ``shape_loss`` and ``consistent_loss``, which are never
    defined in the surviving text, so it could not run. It is preserved below
    as comments. Restore the real body from version control.
    """
    # standard flow-model generative loss
    # jac = torch.clamp(jac, min=1e-5, max=1e5)
    # z = torch.clamp(z, min=-1e5, max=1e5)
    loss_sample = 0.5 * torch.sum(z ** 2, dim=(1,))  # should the losses all be > 0?
    # positive_loss = (-loss_sample) * (labels == 0).float()* (loss_sample
    # ---------------- lost span ----------------
    # Surviving fragment after the gap (owning function uncertain):
    #     ... 0.5).float()  # should tend to 0; the smaller the gap, the better
    #     consistent_loss = consistent_loss/len(labels)
    #     # total_loss = shape_loss + consistent_loss * 0.05
    #     total_loss = consistent_loss
    #     return shape_loss, consistent_loss, total_loss
    raise NotImplementedError(
        "get_loss_neg_pos_margin: body was lost in this copy of the file; "
        "restore it from version control")


def get_measures(_pos, _neg, recall_level=0.95):
    """Compute (AUROC, AUPR, FPR at ``recall_level``) with ``_pos`` as class 1.

    Parameters
    ----------
    _pos, _neg : array-like
        Scores of positive (label 1) and negative (label 0) examples; both
        are flattened.
    recall_level : float
        Recall at which the false-positive rate is reported.
    """
    pos = np.asarray(_pos).reshape(-1)
    neg = np.asarray(_neg).reshape(-1)
    # concatenate (not vstack + squeeze) so a single example stays 1-D
    examples = np.concatenate((pos, neg))
    labels = np.zeros(len(examples), dtype=np.int32)
    labels[:len(pos)] += 1  # positives come first

    auroc = sk.roc_auc_score(labels, examples)
    aupr = sk.average_precision_score(labels, examples)
    fpr, _threshold = fpr_and_fdr_at_recall(labels, examples, recall_level)

    return auroc, aupr, fpr


def find_best_threshold(y_true, y_pred):
    """Return the threshold on ``y_pred`` that maximizes accuracy.

    "We assume first half is real 0, and the second half is fake 1": the
    first ``N // 2`` entries are class 0 and the rest class 1. If every
    class-0 score is <= every class-1 score, the midpoint of the gap is
    returned. Otherwise every score is tried as a threshold (predict 1 when
    ``score >= thres``), and the best one is kept; ties go to the later
    candidate.

    NOTE(review): part of this loop body was missing in the copy under review
    (from ``temp[temp`` up to ``= best_acc:``). It has been reconstructed as
    "binarize, compute accuracy, keep best with >=". Confirm against version
    control.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    N = y_true.shape[0]
    neg, pos = y_pred[0:N // 2], y_pred[N // 2:N]
    if neg.max() <= pos.min():  # perfectly separable case
        return (neg.max() + pos.min()) / 2

    best_acc = 0
    best_thres = 0
    for thres in y_pred:
        # Binarize without mutating a copy in place. The usual
        # ``temp[temp>=thres] = 1; temp[temp<thres] = 0`` pattern zeroes the
        # freshly written 1s whenever thres > 1 (e.g. unbounded flow scores).
        acc = np.mean((y_pred >= thres) == y_true)
        if acc >= best_acc:
            best_thres = thres
            best_acc = acc
    return best_thres


def get_loss_neg(z, jac, labels, margin=500):
    """Flow loss variant using ``0.5 * ||z||^2 - jac`` per sample -- TRUNCATED.

    NOTE(review): this definition is cut off in the copy under review, inside
    a commented-out line at ``(loss_sample``. Everything after that point is
    missing. If nothing else follows in the file, the function only computes
    ``loss_sample`` and returns None. Do not rely on it until it is restored.
    """
    # standard flow-model generative loss
    loss_sample = 0.5 * torch.sum(z ** 2, dim=(1,)) - jac  # should the losses all be > 0?
    # positive_loss = (-loss_sample) * (labels == 0).float()* (loss_sample