import multiprocessing
import os
from pathlib import Path

import numpy as np
import torch
from joblib import Parallel, delayed
from torch_geometric.loader import DataLoader

from data.featurize import (
    smiles_to_graph,
    selfies_to_graph,
    ecfp_to_graph,
    smiles_for_gp,
    selfies_for_gp,
    ecfp_for_gp,
)
from data.loaders import load_dataset
from data.polyatomic_featurize import compressed_topsignal_graph_from_smiles
from models.gnn import GCN, GIN, GAT, GraphSAGE
from models.polyatomic import PolyatomicNet
# Maps a representation name (e.g. CLI --rep) to its graph featurizer.
REPRESENTATIONS = {
    "smiles": smiles_to_graph,
    "selfies": selfies_to_graph,
    "ecfp": ecfp_to_graph,
    "polyatomic": compressed_topsignal_graph_from_smiles,
}

# Featurizers that produce vector inputs for the GP-style models.
GP_FEATURIZERS = {
    "smiles": smiles_for_gp,
    "selfies": selfies_for_gp,
    "ecfp": ecfp_for_gp,
}

# Maps a model name (e.g. CLI --model) to its class.
GNN_MODELS = {
    "gcn": GCN,
    "gin": GIN,
    "gat": GAT,
    "sage": GraphSAGE,
    "polyatomic": PolyatomicNet,
}

# Prefer the "fork" start method so workers inherit loaded modules cheaply.
# "fork" does not exist on Windows, where set_start_method("fork") raises
# ValueError (not RuntimeError) — guard on availability first. A
# RuntimeError means the start method was already set; both situations are
# safe to ignore and fall back to the platform default.
if "fork" in multiprocessing.get_all_start_methods():
    try:
        multiprocessing.set_start_method("fork")
    except RuntimeError:
        pass
|
def featurize_dataset_parallel(X, y, featurizer, n_jobs=None):
    """Featurize molecules into graph objects in parallel.

    Parameters
    ----------
    X : sequence
        Molecule inputs (e.g. SMILES strings) accepted by ``featurizer``.
    y : sequence
        Target values, aligned element-wise with ``X``.
    featurizer : callable
        One of the ``REPRESENTATIONS`` featurizers. Expected to return a
        graph object, or ``None`` when featurization fails for a molecule.
    n_jobs : int, optional
        Worker count; defaults to ``cpu_count() - 2`` (at least 1).

    Returns
    -------
    list
        Featurized graphs with ``.y`` set to the matching target tensor;
        molecules whose featurizer returned ``None`` are dropped.
    """
    if n_jobs is None:
        # Leave a couple of cores free for the parent process / OS.
        n_jobs = max(1, multiprocessing.cpu_count() - 2)

    # The polyatomic featurizer consumes the label during graph
    # construction; every other featurizer takes only the molecule input.
    # NOTE(review): dispatching on __name__ breaks if the featurizer is
    # ever wrapped (functools.partial etc.) — confirm before wrapping.
    if featurizer.__name__ == "compressed_topsignal_graph_from_smiles":
        tasks = (delayed(featurizer)(xi, yi) for xi, yi in zip(X, y))
    else:
        tasks = (delayed(featurizer)(xi) for xi in X)
    results = Parallel(n_jobs=n_jobs, backend="loky", verbose=10)(tasks)

    # Attach targets and drop failed molecules. Indexing by the enumerate
    # position keeps each graph aligned with its y even when some results
    # are None.
    data_list = []
    for i, res in enumerate(results):
        if res is not None:
            res.y = torch.tensor([y[i]], dtype=torch.float)
            data_list.append(res)

    return data_list
| |
|
| |
|
def prepare_and_load_data(args):
    """Load featurized train/test graphs, featurizing and caching on first use.

    The expensive featurization runs once per ``(rep, dataset)`` pair and
    the resulting graph lists are persisted with ``torch.save``; later
    calls load the cached files instead.

    Parameters
    ----------
    args : argparse.Namespace
        Must provide ``args.rep`` (a key of ``REPRESENTATIONS``) and
        ``args.dataset`` (a name understood by ``load_dataset``).

    Returns
    -------
    tuple[list, list]
        ``(train_graphs, test_graphs)``.
    """
    # Cache lives in <repo_root>/datasets, one level above this package.
    cache_dir = Path(__file__).resolve().parent.parent / "datasets"
    # exist_ok avoids the check-then-create race of an explicit
    # os.path.exists() guard.
    cache_dir.mkdir(parents=True, exist_ok=True)

    train_cache_file = cache_dir / f"{args.rep}_data_{args.dataset}.pt"
    test_cache_file = cache_dir / f"{args.rep}_test_data_{args.dataset}.pt"

    print(f"train cache file is: {train_cache_file}")
    print(f"test cache file is: {test_cache_file}")

    if train_cache_file.exists() and test_cache_file.exists():
        print(
            f"INFO: Loading pre-featurized data from cache for dataset '{args.dataset}'..."
        )
        # weights_only=False is required because the cache stores arbitrary
        # graph objects, not plain tensors; only load caches written by
        # this pipeline itself (unpickling untrusted files is unsafe).
        train_graphs = torch.load(train_cache_file, weights_only=False)
        test_graphs = torch.load(test_cache_file, weights_only=False)
        return train_graphs, test_graphs

    print("INFO: No cached data found. Starting one-time featurization process...")
    X_train, X_test, y_train, y_test = load_dataset(args.dataset)
    featurizer = REPRESENTATIONS[args.rep]

    print("Featurizing training set (this may take a while)...")
    train_graphs = featurize_dataset_parallel(X_train, y_train, featurizer)
    torch.save(train_graphs, train_cache_file)
    print(f"Saved featurized training data to {train_cache_file}")

    print("Featurizing test set...")
    test_graphs = featurize_dataset_parallel(X_test, y_test, featurizer)
    torch.save(test_graphs, test_cache_file)
    print(f"Saved featurized test data to {test_cache_file}")

    return train_graphs, test_graphs
| |
|
| |
|
def get_model_instance(args, params, train_graphs):
    """Build and return the model selected by ``args.model``.

    The polyatomic model additionally needs a node-degree histogram
    computed over the training graphs; every other model is constructed
    directly from the sample graph's node-feature dimension.
    """
    model_cls = GNN_MODELS[args.model]
    sample = train_graphs[0]

    if args.model != "polyatomic":
        return model_cls(
            in_channels=sample.num_node_features,
            hidden_channels=params["hidden_dim"],
            out_channels=1,
        )

    from torch_geometric.utils import degree

    print("INFO: Calculating degree vector for polyatomic model...")
    loader = DataLoader(train_graphs, batch_size=params.get("batch_size", 128))
    # Degree histogram over all training nodes; starts at length 32 and is
    # grown whenever a batch contains a node of higher degree. minlength
    # guarantees `counts` is never shorter than the current histogram.
    deg_hist = torch.zeros(32, dtype=torch.long)
    for batch in loader:
        node_deg = degree(
            batch.edge_index[1], num_nodes=batch.num_nodes, dtype=torch.long
        )
        counts = torch.bincount(node_deg, minlength=deg_hist.size(0))
        if counts.size(0) > deg_hist.size(0):
            grown = torch.zeros(counts.size(0), dtype=torch.long)
            grown[: deg_hist.size(0)] = deg_hist
            deg_hist = grown
        deg_hist += counts

    return model_cls(
        node_feat_dim=sample.x.shape[1],
        edge_feat_dim=sample.edge_attr.shape[1],
        graph_feat_dim=sample.graph_feats.shape[0],
        hidden_dim=params["hidden_dim"],
        deg=deg_hist,
    )
| |
|