Spaces:
Sleeping
Sleeping
| import json | |
| from pathlib import Path | |
| import numpy as np | |
| import torch | |
| import scipy.sparse as sp | |
| def scipy_to_torch_sparse(scp_matrix): | |
| values = scp_matrix.data | |
| indices = np.vstack((scp_matrix.row, scp_matrix.col)) | |
| i = torch.LongTensor(indices) | |
| v = torch.FloatTensor(values) | |
| shape = scp_matrix.shape | |
| sparse_tensor = torch.sparse.FloatTensor(i, v, torch.Size(shape)) | |
| return sparse_tensor | |
| def load_config(dataset_path, hyperparameters = None): | |
| config = {} | |
| config['DATASET'] = dataset_path | |
| with open(Path(dataset_path) / "config.json") as f: | |
| data_config = json.load(f) | |
| if hyperparameters is None: | |
| hyperparameters = {} | |
| hyperparameters['latents'] = 64 | |
| hyperparameters['initial_filters'] = 16 | |
| # Data augmentation is left to the dataset construction info | |
| hyperparameters['flip_h'] = data_config['flip_h'] | |
| hyperparameters['flip_v'] = data_config['flip_v'] | |
| hyperparameters['transpose'] = data_config['transpose'] | |
| hyperparameters['rotate'] = data_config['rotate'] | |
| config.update({ | |
| 'organs': data_config['organs'], | |
| 'organ_names': data_config['organ_names'], | |
| 'resolutions': data_config['resolutions'], | |
| 'inputsize': data_config['inputsize'], | |
| 'device': torch.device("cuda:0" if torch.cuda.is_available() else "cpu") | |
| }) | |
| config.update(hyperparameters) | |
| num_pooling = len(config['resolutions']) | |
| config['filters'] = [config['initial_filters'] * (i+1) // 2 for i in range(num_pooling+1)] | |
| config['filters'] = [x for x in config['filters'] for _ in (0, 1)] | |
| config['filters'][0] = 2 | |
| if config['naive']: | |
| adj_path = "Naive" | |
| else: | |
| adj_path = "NonNaive" | |
| A_ = [] | |
| for res in config['resolutions']: | |
| A = np.load(Path(dataset_path) / adj_path / f"adj_{res}_block_diagonal.npy") | |
| A = sp.csc_matrix(A).tocoo() | |
| A_.extend([A.copy()]) | |
| A_.append(A_[-1]) | |
| A_t = [scipy_to_torch_sparse(x).to(config['device']) for x in A_] | |
| config['n_nodes'] = [A.shape[0] for A in A_] | |
| D_ = [] | |
| for res in ['to_' + x for x in config['resolutions'][1:]]: | |
| D = np.load(Path(dataset_path) / adj_path / f"downsampling_{res}.npy") | |
| D_.append(sp.csc_matrix(D).tocoo()) | |
| D_t = [scipy_to_torch_sparse(x).to(config['device']) for x in D_] | |
| U_ = [] | |
| for res in ['to_' + x for x in config['resolutions'][:-1]]: | |
| U = np.load(Path(dataset_path) / adj_path / f"upsampling_{res}.npy") | |
| U_.append(sp.csc_matrix(U).tocoo()) | |
| U_t = [scipy_to_torch_sparse(x).to(config['device']) for x in U_] | |
| return config, D_t, U_t, A_t |