Spaces:
Running
Running
| from urllib.request import urlopen | |
| import torch | |
| from torch import nn | |
| import numpy as np | |
| from skimage.morphology import label | |
| import os | |
| from HD_BET.paths import folder_with_parameter_files | |
def get_params_fname(fold):
    """
    Build the path of the parameter file that belongs to a fold.

    :param fold: fold index, formatted with ``%d`` into ``<fold>.model``
    :return: ``folder_with_parameter_files`` joined with the file name
    """
    model_file = "%d.model" % fold
    return os.path.join(folder_with_parameter_files, model_file)
def maybe_download_parameters(fold=0, force_overwrite=False):
    """
    Downloads the parameters for some fold if it is not present yet.

    :param fold: index of the fold whose parameter file is needed, must be in 0..4
    :param force_overwrite: if True the old parameter file will be deleted (if present) prior to download
    :return: None
    """
    assert 0 <= fold <= 4, "fold must be between 0 and 4"

    # os.makedirs creates every missing parent and tolerates an already
    # existing directory, so no separate isdir() check is required.
    os.makedirs(folder_with_parameter_files, exist_ok=True)

    out_filename = get_params_fname(fold)

    if force_overwrite and os.path.isfile(out_filename):
        os.remove(out_filename)

    if not os.path.isfile(out_filename):
        url = "https://zenodo.org/record/2540695/files/%d.model?download=1" % fold
        print("Downloading", url, "...")
        # Use the response as a context manager so the HTTP connection is
        # closed deterministically instead of being left to the garbage collector.
        with urlopen(url) as response:
            data = response.read()
        with open(out_filename, 'wb') as f:
            f.write(data)
def init_weights(module):
    """
    Initialise a 3D convolution in place; every other module type is left untouched.

    The weight is drawn with Kaiming-normal initialisation (negative slope
    ``a=1e-2``) and the bias, if the layer has one, is set to zero.

    :param module: a module instance; only ``nn.Conv3d`` instances are modified
    :return: None
    """
    if isinstance(module, nn.Conv3d):
        # The underscore variants modify the parameter in place and replace the
        # deprecated nn.init.kaiming_normal / nn.init.constant, so there is no
        # need to assign the result back onto the module.
        nn.init.kaiming_normal_(module.weight, a=1e-2)
        if module.bias is not None:
            nn.init.constant_(module.bias, 0)
def softmax_helper(x):
    """
    Numerically stable softmax over dimension 1 of ``x``.

    :param x: tensor with at least two dimensions
    :return: tensor of the same shape as ``x`` whose entries sum to 1 along dimension 1
    """
    # Subtracting the per-position maximum keeps exp() from overflowing.
    # keepdim=True leaves a size-1 axis that broadcasts against x, so the
    # explicit .repeat() copies of the reduced tensors are not needed.
    x_max = x.max(1, keepdim=True)[0]
    e_x = torch.exp(x - x_max)
    return e_x / e_x.sum(1, keepdim=True)
class SetNetworkToVal(object):
    """
    Callable that switches dropout and normalisation modules into the desired mode.

    Calling an instance with a module sets that module's train flag:
    dropout modules get ``train(use_dropout_sampling)`` and instance/batch
    norm modules get ``train(not norm_use_average)``. Any other module is
    left unchanged.
    """

    # Module families handled by __call__.
    _DROPOUT_TYPES = (nn.Dropout3d, nn.Dropout2d, nn.Dropout)
    _NORM_TYPES = (
        nn.InstanceNorm3d, nn.InstanceNorm2d, nn.InstanceNorm1d,
        nn.BatchNorm2d, nn.BatchNorm3d, nn.BatchNorm1d,
    )

    def __init__(self, use_dropout_sampling=False, norm_use_average=True):
        """
        :param use_dropout_sampling: train flag applied to dropout modules
        :param norm_use_average: if True, norm modules are put into eval mode
        """
        self.norm_use_average = norm_use_average
        self.use_dropout_sampling = use_dropout_sampling

    def __call__(self, module):
        """Set the train flag of ``module`` if it is a dropout or norm layer."""
        if isinstance(module, self._DROPOUT_TYPES):
            module.train(self.use_dropout_sampling)
            return
        if isinstance(module, self._NORM_TYPES):
            module.train(not self.norm_use_average)
def postprocess_prediction(seg):
    """
    Keep only the largest connected foreground component of a segmentation.

    Every non-zero entry of ``seg`` counts as foreground. Connected components
    are computed with full connectivity (``connectivity=mask.ndim``) and all
    entries outside the largest component are set to 0.

    :param seg: array-like segmentation supporting ``!=`` and boolean-mask
        assignment (presumably a numpy array); it is modified in place
    :return: the same ``seg`` object after filtering
    """
    # basically look for connected components and choose the largest one, delete everything else
    print("running postprocessing... ")
    mask = seg != 0

    # An empty segmentation has no component to select; without this guard the
    # argmax below would be taken over an empty sequence and raise ValueError.
    if not mask.any():
        return seg

    lbls = label(mask, connectivity=mask.ndim)

    # bincount yields the size of every label indexed by its id in a single
    # pass. Index 0 is the background, even when no background entry exists,
    # so skipping it never discards a real component.
    sizes = np.bincount(lbls.ravel())
    largest_region = np.argmax(sizes[1:]) + 1

    seg[lbls != largest_region] = 0
    return seg
def subdirs(folder, join=True, prefix=None, suffix=None, sort=True):
    """
    List the directories located directly inside ``folder``.

    :param folder: directory whose entries are inspected
    :param join: if true, return entries joined with ``folder``; otherwise bare names
    :param prefix: if not None, keep only names starting with this prefix
    :param suffix: if not None, keep only names ending with this suffix
    :param sort: if true, sort the result before returning it
    :return: list of matching directory paths or names
    """
    matches = []
    for entry in os.listdir(folder):
        full_path = os.path.join(folder, entry)
        if not os.path.isdir(full_path):
            continue
        if prefix is not None and not entry.startswith(prefix):
            continue
        if suffix is not None and not entry.endswith(suffix):
            continue
        matches.append(full_path if join else entry)
    if sort:
        matches.sort()
    return matches
def subfiles(folder, join=True, prefix=None, suffix=None, sort=True):
    """
    List the regular files located directly inside ``folder``.

    :param folder: directory whose entries are inspected
    :param join: if true, return entries joined with ``folder``; otherwise bare names
    :param prefix: if not None, keep only names starting with this prefix
    :param suffix: if not None, keep only names ending with this suffix
    :param sort: if true, sort the result before returning it
    :return: list of matching file paths or names
    """
    matches = []
    for entry in os.listdir(folder):
        full_path = os.path.join(folder, entry)
        if not os.path.isfile(full_path):
            continue
        if prefix is not None and not entry.startswith(prefix):
            continue
        if suffix is not None and not entry.endswith(suffix):
            continue
        matches.append(full_path if join else entry)
    if sort:
        matches.sort()
    return matches
# Alias: "subfolders" and "subdirs" are easy to mix up, so both names resolve
# to the same function.
subfolders = subdirs
def maybe_mkdir_p(directory):
    """
    Create ``directory`` together with any missing parent directories.

    Nothing happens if the directory already exists. The previous
    hand-rolled implementation dropped the leading "/" of the path and thus
    created absolute paths relative to the current working directory; it
    also silently ignored paths that contained no "/".

    :param directory: path of the directory to create; an empty string is ignored
    :return: None
    """
    if not directory:
        # Keep the historical no-op for an empty path instead of raising.
        return
    os.makedirs(directory, exist_ok=True)