Andrej Janchevski commited on
Commit ·
16cab72
1
Parent(s): b701828
fix(research): remove wandb dependency and guard optional imports
Browse files- Strip wandb imports and calls from MultiProxAn diffusion models and
utils (prevents ImportError at checkpoint load time)
- Wrap graph_tool, pyemd, pygsp, dist_helper imports in try/except
across spectre_utils, molecular_metrics, train_metrics (these are
only needed for training metrics, not inference)
- Fix pandas deprecation in COINs load_graph.py (iteritems -> items,
to_frame column naming)
- src/research/COINs-KGGeneration/graph_completion/graphs/load_graph.py +2 -2
- src/research/MultiProxAn/src/analysis/rdkit_functions.py +0 -10
- src/research/MultiProxAn/src/analysis/spectre_utils.py +20 -29
- src/research/MultiProxAn/src/analysis/visualization.py +0 -12
- src/research/MultiProxAn/src/diffusion_model.py +0 -32
- src/research/MultiProxAn/src/diffusion_model_discrete.py +0 -31
- src/research/MultiProxAn/src/metrics/molecular_metrics.py +0 -18
- src/research/MultiProxAn/src/metrics/molecular_metrics_discrete.py +0 -6
- src/research/MultiProxAn/src/metrics/train_metrics.py +0 -10
- src/research/MultiProxAn/src/utils.py +0 -7
src/research/COINs-KGGeneration/graph_completion/graphs/load_graph.py
CHANGED
|
@@ -128,8 +128,8 @@ class Loader:
|
|
| 128 |
hits_at_3_limit = community_query_edge_counts.groupby(["c_s", "r"]).head(3).sum() / num_edges
|
| 129 |
hits_at_10_limit = community_query_edge_counts.groupby(["c_s", "r"]).head(10).sum() / num_edges
|
| 130 |
community_query_counts = community_query_edge_counts.groupby(["c_s", "r"]).count()
|
| 131 |
-
community_query_edge_counts = community_query_edge_counts.to_frame().assign(rank=0, rrank=0)
|
| 132 |
-
for (c_s, r), c_t_count in community_query_counts.
|
| 133 |
community_query_edge_counts.loc[(c_s, r), "rank"] = np.arange(1, c_t_count + 1)
|
| 134 |
community_query_edge_counts.loc[(c_s, r), "rrank"] = 1 / np.arange(1, c_t_count + 1)
|
| 135 |
mr_limit = (community_query_edge_counts["c_t"] * community_query_edge_counts["rank"]).sum() / num_edges
|
|
|
|
| 128 |
hits_at_3_limit = community_query_edge_counts.groupby(["c_s", "r"]).head(3).sum() / num_edges
|
| 129 |
hits_at_10_limit = community_query_edge_counts.groupby(["c_s", "r"]).head(10).sum() / num_edges
|
| 130 |
community_query_counts = community_query_edge_counts.groupby(["c_s", "r"]).count()
|
| 131 |
+
community_query_edge_counts = community_query_edge_counts.to_frame(name="c_t").assign(rank=0, rrank=0.0)
|
| 132 |
+
for (c_s, r), c_t_count in community_query_counts.items():
|
| 133 |
community_query_edge_counts.loc[(c_s, r), "rank"] = np.arange(1, c_t_count + 1)
|
| 134 |
community_query_edge_counts.loc[(c_s, r), "rrank"] = 1 / np.arange(1, c_t_count + 1)
|
| 135 |
mr_limit = (community_query_edge_counts["c_t"] * community_query_edge_counts["rank"]).sum() / num_edges
|
src/research/MultiProxAn/src/analysis/rdkit_functions.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
| 1 |
import numpy as np
|
| 2 |
import torch
|
| 3 |
import re
|
| 4 |
-
import wandb
|
| 5 |
try:
|
| 6 |
from rdkit import Chem
|
| 7 |
print("Found rdkit, all good")
|
|
@@ -316,19 +315,10 @@ def compute_molecular_metrics(molecule_list, train_smiles, dataset_info):
|
|
| 316 |
fraction_mol_stable = molecule_stable / float(n_molecules)
|
| 317 |
fraction_atm_stable = nr_stable_bonds / float(n_atoms)
|
| 318 |
validity_dict = {'mol_stable': fraction_mol_stable, 'atm_stable': fraction_atm_stable}
|
| 319 |
-
if wandb.run:
|
| 320 |
-
wandb.log(validity_dict)
|
| 321 |
else:
|
| 322 |
validity_dict = {'mol_stable': -1, 'atm_stable': -1}
|
| 323 |
|
| 324 |
metrics = BasicMolecularMetrics(dataset_info, train_smiles)
|
| 325 |
rdkit_metrics = metrics.evaluate(molecule_list)
|
| 326 |
all_smiles = rdkit_metrics[-1]
|
| 327 |
-
if wandb.run:
|
| 328 |
-
nc = rdkit_metrics[-2]
|
| 329 |
-
dic = {'Validity': rdkit_metrics[0][0], 'Relaxed Validity': rdkit_metrics[0][1],
|
| 330 |
-
'Uniqueness': rdkit_metrics[0][2], 'Novelty': rdkit_metrics[0][3],
|
| 331 |
-
'nc_max': nc['nc_max'], 'nc_mu': nc['nc_mu']}
|
| 332 |
-
wandb.log(dic)
|
| 333 |
-
|
| 334 |
return validity_dict, rdkit_metrics, all_smiles
|
|
|
|
| 1 |
import numpy as np
|
| 2 |
import torch
|
| 3 |
import re
|
|
|
|
| 4 |
try:
|
| 5 |
from rdkit import Chem
|
| 6 |
print("Found rdkit, all good")
|
|
|
|
| 315 |
fraction_mol_stable = molecule_stable / float(n_molecules)
|
| 316 |
fraction_atm_stable = nr_stable_bonds / float(n_atoms)
|
| 317 |
validity_dict = {'mol_stable': fraction_mol_stable, 'atm_stable': fraction_atm_stable}
|
|
|
|
|
|
|
| 318 |
else:
|
| 319 |
validity_dict = {'mol_stable': -1, 'atm_stable': -1}
|
| 320 |
|
| 321 |
metrics = BasicMolecularMetrics(dataset_info, train_smiles)
|
| 322 |
rdkit_metrics = metrics.evaluate(molecule_list)
|
| 323 |
all_smiles = rdkit_metrics[-1]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 324 |
return validity_dict, rdkit_metrics, all_smiles
|
src/research/MultiProxAn/src/analysis/spectre_utils.py
CHANGED
|
@@ -3,7 +3,6 @@
|
|
| 3 |
# Adapted from https://github.com/lrjconan/GRAN/ which in turn is adapted from https://github.com/JiaxuanYou/graph-generation
|
| 4 |
#
|
| 5 |
###############################################################################
|
| 6 |
-
import graph_tool.all as gt
|
| 7 |
##Navigate to the ./util/orca directory and compile orca.cpp
|
| 8 |
# g++ -O2 -std=c++11 -o orca orca.cpp
|
| 9 |
import os
|
|
@@ -12,18 +11,26 @@ import torch
|
|
| 12 |
import torch.nn as nn
|
| 13 |
import numpy as np
|
| 14 |
import networkx as nx
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
import
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
PRINT_TIME = False
|
| 29 |
__all__ = ['degree_stats', 'clustering_stats', 'orbit_stats_all', 'spectral_stats', 'eval_acc_lobster_graph']
|
|
@@ -778,8 +785,6 @@ class SpectreSamplingMetrics(nn.Module):
|
|
| 778 |
degree = degree_stats(reference_graphs, networkx_graphs, is_parallel=False,
|
| 779 |
compute_emd=self.compute_emd)
|
| 780 |
to_log['degree'] = degree
|
| 781 |
-
if wandb.run:
|
| 782 |
-
wandb.run.summary['degree'] = degree
|
| 783 |
|
| 784 |
# val_eigvals = [graph["eigval"][1:self.k + 1].cpu().detach().numpy() for graph in self.val]
|
| 785 |
# train_eigvals = [graph["eigval"][1:self.k + 1].cpu().detach().numpy() for graph in self.train]
|
|
@@ -795,8 +800,6 @@ class SpectreSamplingMetrics(nn.Module):
|
|
| 795 |
compute_emd=self.compute_emd)
|
| 796 |
|
| 797 |
to_log['spectre'] = spectre
|
| 798 |
-
if wandb.run:
|
| 799 |
-
wandb.run.summary['spectre'] = spectre
|
| 800 |
|
| 801 |
if 'clustering' in self.metrics_list:
|
| 802 |
if local_rank == 0:
|
|
@@ -804,8 +807,6 @@ class SpectreSamplingMetrics(nn.Module):
|
|
| 804 |
clustering = clustering_stats(reference_graphs, networkx_graphs, bins=100, is_parallel=False,
|
| 805 |
compute_emd=self.compute_emd)
|
| 806 |
to_log['clustering'] = clustering
|
| 807 |
-
if wandb.run:
|
| 808 |
-
wandb.run.summary['clustering'] = clustering
|
| 809 |
|
| 810 |
if 'motif' in self.metrics_list:
|
| 811 |
if local_rank == 0:
|
|
@@ -813,32 +814,24 @@ class SpectreSamplingMetrics(nn.Module):
|
|
| 813 |
motif = motif_stats(reference_graphs, networkx_graphs, motif_type='4cycle', ground_truth_match=None, bins=100,
|
| 814 |
compute_emd=self.compute_emd)
|
| 815 |
to_log['motif'] = motif
|
| 816 |
-
if wandb.run:
|
| 817 |
-
wandb.run.summary['motif'] = motif
|
| 818 |
|
| 819 |
if 'orbit' in self.metrics_list:
|
| 820 |
if local_rank == 0:
|
| 821 |
print("Computing orbit stats...")
|
| 822 |
orbit = orbit_stats_all(reference_graphs, networkx_graphs, compute_emd=self.compute_emd)
|
| 823 |
to_log['orbit'] = orbit
|
| 824 |
-
if wandb.run:
|
| 825 |
-
wandb.run.summary['orbit'] = orbit
|
| 826 |
|
| 827 |
if 'sbm' in self.metrics_list:
|
| 828 |
if local_rank == 0:
|
| 829 |
print("Computing accuracy...")
|
| 830 |
acc = eval_acc_sbm_graph(networkx_graphs, refinement_steps=100, strict=True)
|
| 831 |
to_log['sbm_acc'] = acc
|
| 832 |
-
if wandb.run:
|
| 833 |
-
wandb.run.summary['sbmacc'] = acc
|
| 834 |
|
| 835 |
if 'planar' in self.metrics_list:
|
| 836 |
if local_rank ==0:
|
| 837 |
print('Computing planar accuracy...')
|
| 838 |
planar_acc = eval_acc_planar_graph(networkx_graphs)
|
| 839 |
to_log['planar_acc'] = planar_acc
|
| 840 |
-
if wandb.run:
|
| 841 |
-
wandb.run.summary['planar_acc'] = planar_acc
|
| 842 |
|
| 843 |
if 'sbm' or 'planar' in self.metrics_list:
|
| 844 |
if local_rank == 0:
|
|
@@ -853,8 +846,6 @@ class SpectreSamplingMetrics(nn.Module):
|
|
| 853 |
|
| 854 |
if local_rank == 0:
|
| 855 |
print("Sampling statistics", to_log)
|
| 856 |
-
if wandb.run:
|
| 857 |
-
wandb.log(to_log, commit=False)
|
| 858 |
|
| 859 |
def reset(self):
|
| 860 |
pass
|
|
|
|
| 3 |
# Adapted from https://github.com/lrjconan/GRAN/ which in turn is adapted from https://github.com/JiaxuanYou/graph-generation
|
| 4 |
#
|
| 5 |
###############################################################################
|
|
|
|
| 6 |
##Navigate to the ./util/orca directory and compile orca.cpp
|
| 7 |
# g++ -O2 -std=c++11 -o orca orca.cpp
|
| 8 |
import os
|
|
|
|
| 11 |
import torch.nn as nn
|
| 12 |
import numpy as np
|
| 13 |
import networkx as nx
|
| 14 |
+
|
| 15 |
+
# Heavy metric-computation deps — optional, not needed for inference.
|
| 16 |
+
# Deferred so checkpoint unpickling works without graph-tool / pyemd / pygsp.
|
| 17 |
+
try:
|
| 18 |
+
import graph_tool.all as gt
|
| 19 |
+
except ImportError:
|
| 20 |
+
gt = None
|
| 21 |
+
try:
|
| 22 |
+
import subprocess as sp
|
| 23 |
+
import concurrent.futures
|
| 24 |
+
import pygsp as pg
|
| 25 |
+
import secrets
|
| 26 |
+
from string import ascii_uppercase, digits
|
| 27 |
+
from datetime import datetime
|
| 28 |
+
from scipy.linalg import eigvalsh
|
| 29 |
+
from scipy.stats import chi2
|
| 30 |
+
from src.analysis.dist_helper import compute_mmd, gaussian_emd, gaussian, emd, gaussian_tv, disc
|
| 31 |
+
from torch_geometric.utils import to_networkx
|
| 32 |
+
except ImportError:
|
| 33 |
+
pass # Metrics unavailable — inference still works
|
| 34 |
|
| 35 |
PRINT_TIME = False
|
| 36 |
__all__ = ['degree_stats', 'clustering_stats', 'orbit_stats_all', 'spectral_stats', 'eval_acc_lobster_graph']
|
|
|
|
| 785 |
degree = degree_stats(reference_graphs, networkx_graphs, is_parallel=False,
|
| 786 |
compute_emd=self.compute_emd)
|
| 787 |
to_log['degree'] = degree
|
|
|
|
|
|
|
| 788 |
|
| 789 |
# val_eigvals = [graph["eigval"][1:self.k + 1].cpu().detach().numpy() for graph in self.val]
|
| 790 |
# train_eigvals = [graph["eigval"][1:self.k + 1].cpu().detach().numpy() for graph in self.train]
|
|
|
|
| 800 |
compute_emd=self.compute_emd)
|
| 801 |
|
| 802 |
to_log['spectre'] = spectre
|
|
|
|
|
|
|
| 803 |
|
| 804 |
if 'clustering' in self.metrics_list:
|
| 805 |
if local_rank == 0:
|
|
|
|
| 807 |
clustering = clustering_stats(reference_graphs, networkx_graphs, bins=100, is_parallel=False,
|
| 808 |
compute_emd=self.compute_emd)
|
| 809 |
to_log['clustering'] = clustering
|
|
|
|
|
|
|
| 810 |
|
| 811 |
if 'motif' in self.metrics_list:
|
| 812 |
if local_rank == 0:
|
|
|
|
| 814 |
motif = motif_stats(reference_graphs, networkx_graphs, motif_type='4cycle', ground_truth_match=None, bins=100,
|
| 815 |
compute_emd=self.compute_emd)
|
| 816 |
to_log['motif'] = motif
|
|
|
|
|
|
|
| 817 |
|
| 818 |
if 'orbit' in self.metrics_list:
|
| 819 |
if local_rank == 0:
|
| 820 |
print("Computing orbit stats...")
|
| 821 |
orbit = orbit_stats_all(reference_graphs, networkx_graphs, compute_emd=self.compute_emd)
|
| 822 |
to_log['orbit'] = orbit
|
|
|
|
|
|
|
| 823 |
|
| 824 |
if 'sbm' in self.metrics_list:
|
| 825 |
if local_rank == 0:
|
| 826 |
print("Computing accuracy...")
|
| 827 |
acc = eval_acc_sbm_graph(networkx_graphs, refinement_steps=100, strict=True)
|
| 828 |
to_log['sbm_acc'] = acc
|
|
|
|
|
|
|
| 829 |
|
| 830 |
if 'planar' in self.metrics_list:
|
| 831 |
if local_rank ==0:
|
| 832 |
print('Computing planar accuracy...')
|
| 833 |
planar_acc = eval_acc_planar_graph(networkx_graphs)
|
| 834 |
to_log['planar_acc'] = planar_acc
|
|
|
|
|
|
|
| 835 |
|
| 836 |
if 'sbm' or 'planar' in self.metrics_list:
|
| 837 |
if local_rank == 0:
|
|
|
|
| 846 |
|
| 847 |
if local_rank == 0:
|
| 848 |
print("Sampling statistics", to_log)
|
|
|
|
|
|
|
| 849 |
|
| 850 |
def reset(self):
|
| 851 |
pass
|
src/research/MultiProxAn/src/analysis/visualization.py
CHANGED
|
@@ -8,7 +8,6 @@ import imageio
|
|
| 8 |
import networkx as nx
|
| 9 |
import numpy as np
|
| 10 |
import rdkit.Chem
|
| 11 |
-
import wandb
|
| 12 |
import matplotlib.pyplot as plt
|
| 13 |
|
| 14 |
|
|
@@ -78,9 +77,6 @@ class MolecularVisualization:
|
|
| 78 |
mol = self.mol_from_graphs(molecules[i][0].numpy(), molecules[i][1].numpy())
|
| 79 |
try:
|
| 80 |
Draw.MolToFile(mol, file_path)
|
| 81 |
-
if wandb.run and log is not None:
|
| 82 |
-
print(f"Saving {file_path} to wandb")
|
| 83 |
-
wandb.log({log: wandb.Image(file_path)}, commit=True)
|
| 84 |
except rdkit.Chem.KekulizeException:
|
| 85 |
print("Can't kekulize molecule")
|
| 86 |
|
|
@@ -115,10 +111,6 @@ class MolecularVisualization:
|
|
| 115 |
imgs.extend([imgs[-1]] * 10)
|
| 116 |
imageio.mimsave(gif_path, imgs, subrectangles=True, duration=20)
|
| 117 |
|
| 118 |
-
if wandb.run:
|
| 119 |
-
print(f"Saving {gif_path} to wandb")
|
| 120 |
-
wandb.log({"chain": wandb.Video(gif_path, fps=5, format="gif")}, commit=True)
|
| 121 |
-
|
| 122 |
# draw grid image
|
| 123 |
try:
|
| 124 |
img = Draw.MolsToGridImage(mols, molsPerRow=20, subImgSize=(200, 200))
|
|
@@ -185,8 +177,6 @@ class NonMolecularVisualization:
|
|
| 185 |
graph = self.to_networkx(graphs[i][0].numpy(), graphs[i][1].numpy())
|
| 186 |
self.visualize_non_molecule(graph=graph, pos=None, path=file_path)
|
| 187 |
im = plt.imread(file_path)
|
| 188 |
-
if wandb.run and log is not None:
|
| 189 |
-
wandb.log({log: [wandb.Image(im, caption=file_path)]})
|
| 190 |
|
| 191 |
def visualize_chain(self, path, nodes_list, adjacency_matrix):
|
| 192 |
# convert graphs to networkx
|
|
@@ -219,5 +209,3 @@ class NonMolecularVisualization:
|
|
| 219 |
gif_path = os.path.join(os.path.dirname(path), '{}.gif'.format(path.split('/')[-1]))
|
| 220 |
imgs.extend([imgs[-1]] * 10)
|
| 221 |
imageio.mimsave(gif_path, imgs, subrectangles=True, duration=20)
|
| 222 |
-
if wandb.run:
|
| 223 |
-
wandb.log({'chain': [wandb.Video(gif_path, caption=gif_path, format="gif")]})
|
|
|
|
| 8 |
import networkx as nx
|
| 9 |
import numpy as np
|
| 10 |
import rdkit.Chem
|
|
|
|
| 11 |
import matplotlib.pyplot as plt
|
| 12 |
|
| 13 |
|
|
|
|
| 77 |
mol = self.mol_from_graphs(molecules[i][0].numpy(), molecules[i][1].numpy())
|
| 78 |
try:
|
| 79 |
Draw.MolToFile(mol, file_path)
|
|
|
|
|
|
|
|
|
|
| 80 |
except rdkit.Chem.KekulizeException:
|
| 81 |
print("Can't kekulize molecule")
|
| 82 |
|
|
|
|
| 111 |
imgs.extend([imgs[-1]] * 10)
|
| 112 |
imageio.mimsave(gif_path, imgs, subrectangles=True, duration=20)
|
| 113 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
# draw grid image
|
| 115 |
try:
|
| 116 |
img = Draw.MolsToGridImage(mols, molsPerRow=20, subImgSize=(200, 200))
|
|
|
|
| 177 |
graph = self.to_networkx(graphs[i][0].numpy(), graphs[i][1].numpy())
|
| 178 |
self.visualize_non_molecule(graph=graph, pos=None, path=file_path)
|
| 179 |
im = plt.imread(file_path)
|
|
|
|
|
|
|
| 180 |
|
| 181 |
def visualize_chain(self, path, nodes_list, adjacency_matrix):
|
| 182 |
# convert graphs to networkx
|
|
|
|
| 209 |
gif_path = os.path.join(os.path.dirname(path), '{}.gif'.format(path.split('/')[-1]))
|
| 210 |
imgs.extend([imgs[-1]] * 10)
|
| 211 |
imageio.mimsave(gif_path, imgs, subrectangles=True, duration=20)
|
|
|
|
|
|
src/research/MultiProxAn/src/diffusion_model.py
CHANGED
|
@@ -6,7 +6,6 @@ import numpy as np
|
|
| 6 |
import pytorch_lightning as pl
|
| 7 |
import torch
|
| 8 |
import torch.nn as nn
|
| 9 |
-
import wandb
|
| 10 |
from tqdm.auto import tqdm
|
| 11 |
|
| 12 |
from diffusion.noise_schedule import PredefinedNoiseSchedule
|
|
@@ -141,8 +140,6 @@ class LiftedDenoisingDiffusion(pl.LightningModule):
|
|
| 141 |
|
| 142 |
def on_fit_start(self) -> None:
|
| 143 |
self.train_iterations = len(self.trainer.datamodule.train_dataloader())
|
| 144 |
-
if self.local_rank == 0:
|
| 145 |
-
utils.setup_wandb(self.cfg)
|
| 146 |
|
| 147 |
def on_train_epoch_start(self) -> None:
|
| 148 |
self.start_epoch_time = time.time()
|
|
@@ -186,15 +183,6 @@ class LiftedDenoisingDiffusion(pl.LightningModule):
|
|
| 186 |
metrics = [self.val_nll.compute(), self.val_X_mse.compute(), self.val_E_mse.compute(),
|
| 187 |
self.val_y_mse.compute(), self.val_X_logp.compute(), self.val_E_logp.compute(),
|
| 188 |
self.val_y_logp.compute()]
|
| 189 |
-
if wandb.run:
|
| 190 |
-
wandb.log({"val/epoch_NLL": metrics[0],
|
| 191 |
-
"val/X_mse": metrics[1],
|
| 192 |
-
"val/E_mse": metrics[2],
|
| 193 |
-
"val/y_mse": metrics[3],
|
| 194 |
-
"val/X_logp": metrics[4],
|
| 195 |
-
"val/E_logp": metrics[5],
|
| 196 |
-
"val/y_logp": metrics[6]}, commit=False)
|
| 197 |
-
|
| 198 |
print(f"Epoch {self.current_epoch}: Val NLL {metrics[0] :.2f} -- Val Atom type MSE {metrics[1] :.2f} -- ",
|
| 199 |
f"Val Edge type MSE: {metrics[2] :.2f} -- Val Global feat. MSE {metrics[3] :.2f}",
|
| 200 |
f"-- Val X Reconstruction loss {metrics[4] :.2f} -- Val E Reconstruction loss {metrics[5] :.2f}",
|
|
@@ -203,8 +191,6 @@ class LiftedDenoisingDiffusion(pl.LightningModule):
|
|
| 203 |
# Log val nll with default Lightning logger, so it can be monitored by checkpoint callback
|
| 204 |
val_nll = metrics[0]
|
| 205 |
self.log("val/epoch_NLL", val_nll, sync_dist=True)
|
| 206 |
-
if wandb.run:
|
| 207 |
-
wandb.log(self.log_info(), commit=False)
|
| 208 |
|
| 209 |
if val_nll < self.best_val_nll:
|
| 210 |
self.best_val_nll = val_nll
|
|
@@ -249,8 +235,6 @@ class LiftedDenoisingDiffusion(pl.LightningModule):
|
|
| 249 |
self.test_X_logp.reset()
|
| 250 |
self.test_E_logp.reset()
|
| 251 |
self.test_y_logp.reset()
|
| 252 |
-
if self.local_rank == 0:
|
| 253 |
-
utils.setup_wandb(self.cfg)
|
| 254 |
|
| 255 |
def test_step(self, data, i):
|
| 256 |
dense_data, node_mask = utils.to_dense(x=data.x, edge_index=data.edge_index, edge_attr=data.edge_attr,
|
|
@@ -277,19 +261,12 @@ class LiftedDenoisingDiffusion(pl.LightningModule):
|
|
| 277 |
"test/X_logp": metrics[4],
|
| 278 |
"test/E_logp": metrics[5],
|
| 279 |
"test/y_logp": metrics[6]}
|
| 280 |
-
if wandb.run:
|
| 281 |
-
wandb.log(log_dict, commit=False)
|
| 282 |
-
|
| 283 |
print(f"Epoch {self.current_epoch}: Test NLL {metrics[0] :.2f} -- Test Atom type MSE {metrics[1] :.2f} -- ",
|
| 284 |
f"Test Edge type MSE: {metrics[2] :.2f} -- Test Global feat. MSE {metrics[3] :.2f}",
|
| 285 |
f"-- Test X Reconstruction loss {metrics[4] :.2f} -- Test E Reconstruction loss {metrics[5] :.2f}",
|
| 286 |
f"-- Test y Reconstruction loss {metrics[6] : .2f}\n")
|
| 287 |
|
| 288 |
test_nll = metrics[0]
|
| 289 |
-
if wandb.run:
|
| 290 |
-
wandb.log({"test/epoch_NLL": test_nll}, commit=False)
|
| 291 |
-
wandb.log(self.log_info(), commit=False)
|
| 292 |
-
|
| 293 |
print(f'Test loss: {test_nll :.4f}')
|
| 294 |
|
| 295 |
samples_left_to_generate = self.cfg.general.final_model_samples_to_generate
|
|
@@ -320,9 +297,6 @@ class LiftedDenoisingDiffusion(pl.LightningModule):
|
|
| 320 |
samples_left_to_save -= to_save
|
| 321 |
samples_left_to_generate -= to_generate
|
| 322 |
chains_left_to_save -= chains_save
|
| 323 |
-
if wandb.run:
|
| 324 |
-
wandb.log({"test/time": total_eval_time}, commit=False)
|
| 325 |
-
wandb.run.summary['test_time'] = total_eval_time
|
| 326 |
print(f'Test time: {total_eval_time :.4f} seconds')
|
| 327 |
|
| 328 |
self.sampling_metrics.reset()
|
|
@@ -587,12 +561,6 @@ class LiftedDenoisingDiffusion(pl.LightningModule):
|
|
| 587 |
|
| 588 |
nll = self.test_nll(nlls) if test else self.val_nll(nlls) # Average over the batch
|
| 589 |
|
| 590 |
-
wandb.log({"kl prior": kl_prior.mean(),
|
| 591 |
-
"Estimator loss terms": loss_all_t.mean(),
|
| 592 |
-
"Loss term 0": loss_term_0,
|
| 593 |
-
"log_pn": log_pN.mean(),
|
| 594 |
-
'test_nll' if test else 'val_nll': nll},
|
| 595 |
-
commit=False)
|
| 596 |
return nll
|
| 597 |
|
| 598 |
def forward(self, noisy_data, extra_data, node_mask):
|
|
|
|
| 6 |
import pytorch_lightning as pl
|
| 7 |
import torch
|
| 8 |
import torch.nn as nn
|
|
|
|
| 9 |
from tqdm.auto import tqdm
|
| 10 |
|
| 11 |
from diffusion.noise_schedule import PredefinedNoiseSchedule
|
|
|
|
| 140 |
|
| 141 |
def on_fit_start(self) -> None:
|
| 142 |
self.train_iterations = len(self.trainer.datamodule.train_dataloader())
|
|
|
|
|
|
|
| 143 |
|
| 144 |
def on_train_epoch_start(self) -> None:
|
| 145 |
self.start_epoch_time = time.time()
|
|
|
|
| 183 |
metrics = [self.val_nll.compute(), self.val_X_mse.compute(), self.val_E_mse.compute(),
|
| 184 |
self.val_y_mse.compute(), self.val_X_logp.compute(), self.val_E_logp.compute(),
|
| 185 |
self.val_y_logp.compute()]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
print(f"Epoch {self.current_epoch}: Val NLL {metrics[0] :.2f} -- Val Atom type MSE {metrics[1] :.2f} -- ",
|
| 187 |
f"Val Edge type MSE: {metrics[2] :.2f} -- Val Global feat. MSE {metrics[3] :.2f}",
|
| 188 |
f"-- Val X Reconstruction loss {metrics[4] :.2f} -- Val E Reconstruction loss {metrics[5] :.2f}",
|
|
|
|
| 191 |
# Log val nll with default Lightning logger, so it can be monitored by checkpoint callback
|
| 192 |
val_nll = metrics[0]
|
| 193 |
self.log("val/epoch_NLL", val_nll, sync_dist=True)
|
|
|
|
|
|
|
| 194 |
|
| 195 |
if val_nll < self.best_val_nll:
|
| 196 |
self.best_val_nll = val_nll
|
|
|
|
| 235 |
self.test_X_logp.reset()
|
| 236 |
self.test_E_logp.reset()
|
| 237 |
self.test_y_logp.reset()
|
|
|
|
|
|
|
| 238 |
|
| 239 |
def test_step(self, data, i):
|
| 240 |
dense_data, node_mask = utils.to_dense(x=data.x, edge_index=data.edge_index, edge_attr=data.edge_attr,
|
|
|
|
| 261 |
"test/X_logp": metrics[4],
|
| 262 |
"test/E_logp": metrics[5],
|
| 263 |
"test/y_logp": metrics[6]}
|
|
|
|
|
|
|
|
|
|
| 264 |
print(f"Epoch {self.current_epoch}: Test NLL {metrics[0] :.2f} -- Test Atom type MSE {metrics[1] :.2f} -- ",
|
| 265 |
f"Test Edge type MSE: {metrics[2] :.2f} -- Test Global feat. MSE {metrics[3] :.2f}",
|
| 266 |
f"-- Test X Reconstruction loss {metrics[4] :.2f} -- Test E Reconstruction loss {metrics[5] :.2f}",
|
| 267 |
f"-- Test y Reconstruction loss {metrics[6] : .2f}\n")
|
| 268 |
|
| 269 |
test_nll = metrics[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 270 |
print(f'Test loss: {test_nll :.4f}')
|
| 271 |
|
| 272 |
samples_left_to_generate = self.cfg.general.final_model_samples_to_generate
|
|
|
|
| 297 |
samples_left_to_save -= to_save
|
| 298 |
samples_left_to_generate -= to_generate
|
| 299 |
chains_left_to_save -= chains_save
|
|
|
|
|
|
|
|
|
|
| 300 |
print(f'Test time: {total_eval_time :.4f} seconds')
|
| 301 |
|
| 302 |
self.sampling_metrics.reset()
|
|
|
|
| 561 |
|
| 562 |
nll = self.test_nll(nlls) if test else self.val_nll(nlls) # Average over the batch
|
| 563 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 564 |
return nll
|
| 565 |
|
| 566 |
def forward(self, noisy_data, extra_data, node_mask):
|
src/research/MultiProxAn/src/diffusion_model_discrete.py
CHANGED
|
@@ -6,7 +6,6 @@ import pytorch_lightning as pl
|
|
| 6 |
import torch
|
| 7 |
import torch.nn as nn
|
| 8 |
import torch.nn.functional as F
|
| 9 |
-
import wandb
|
| 10 |
from tqdm import tqdm
|
| 11 |
|
| 12 |
from diffusion.noise_schedule import DiscreteUniformTransition, MarginalUniformTransition, \
|
|
@@ -148,8 +147,6 @@ class DiscreteDenoisingDiffusion(pl.LightningModule):
|
|
| 148 |
def on_fit_start(self) -> None:
|
| 149 |
self.train_iterations = len(self.trainer.datamodule.train_dataloader())
|
| 150 |
self.print("Size of the input features", self.Xdim, self.Edim, self.ydim)
|
| 151 |
-
if self.local_rank == 0:
|
| 152 |
-
utils.setup_wandb(self.cfg)
|
| 153 |
|
| 154 |
def on_train_epoch_start(self) -> None:
|
| 155 |
self.print("Starting train epoch...")
|
|
@@ -187,13 +184,6 @@ class DiscreteDenoisingDiffusion(pl.LightningModule):
|
|
| 187 |
def on_validation_epoch_end(self) -> None:
|
| 188 |
metrics = [self.val_nll.compute(), self.val_X_kl.compute() * self.T, self.val_E_kl.compute() * self.T,
|
| 189 |
self.val_X_logp.compute(), self.val_E_logp.compute()]
|
| 190 |
-
if wandb.run:
|
| 191 |
-
wandb.log({"val/epoch_NLL": metrics[0],
|
| 192 |
-
"val/X_kl": metrics[1],
|
| 193 |
-
"val/E_kl": metrics[2],
|
| 194 |
-
"val/X_logp": metrics[3],
|
| 195 |
-
"val/E_logp": metrics[4]}, commit=False)
|
| 196 |
-
|
| 197 |
self.print(f"Epoch {self.current_epoch}: Val NLL {metrics[0] :.2f} -- Val Atom type KL {metrics[1] :.2f} -- ",
|
| 198 |
f"Val Edge type KL: {metrics[2] :.2f}")
|
| 199 |
|
|
@@ -242,8 +232,6 @@ class DiscreteDenoisingDiffusion(pl.LightningModule):
|
|
| 242 |
self.test_E_kl.reset()
|
| 243 |
self.test_X_logp.reset()
|
| 244 |
self.test_E_logp.reset()
|
| 245 |
-
if self.local_rank == 0:
|
| 246 |
-
utils.setup_wandb(self.cfg)
|
| 247 |
|
| 248 |
def test_step(self, data, i):
|
| 249 |
dense_data, node_mask = utils.to_dense(data.x, data.edge_index, data.edge_attr, data.batch)
|
|
@@ -258,20 +246,10 @@ class DiscreteDenoisingDiffusion(pl.LightningModule):
|
|
| 258 |
""" Measure likelihood on a test set and compute stability metrics. """
|
| 259 |
metrics = [self.test_nll.compute(), self.test_X_kl.compute(), self.test_E_kl.compute(),
|
| 260 |
self.test_X_logp.compute(), self.test_E_logp.compute()]
|
| 261 |
-
if wandb.run:
|
| 262 |
-
wandb.log({"test/epoch_NLL": metrics[0],
|
| 263 |
-
"test/X_kl": metrics[1],
|
| 264 |
-
"test/E_kl": metrics[2],
|
| 265 |
-
"test/X_logp": metrics[3],
|
| 266 |
-
"test/E_logp": metrics[4]}, commit=False)
|
| 267 |
-
|
| 268 |
self.print(f"Epoch {self.current_epoch}: Test NLL {metrics[0] :.2f} -- Test Atom type KL {metrics[1] :.2f} -- ",
|
| 269 |
f"Test Edge type KL: {metrics[2] :.2f}")
|
| 270 |
|
| 271 |
test_nll = metrics[0]
|
| 272 |
-
if wandb.run:
|
| 273 |
-
wandb.log({"test/epoch_NLL": test_nll}, commit=False)
|
| 274 |
-
|
| 275 |
self.print(f'Test loss: {test_nll :.4f}')
|
| 276 |
|
| 277 |
samples_left_to_generate = self.cfg.general.final_model_samples_to_generate
|
|
@@ -304,9 +282,6 @@ class DiscreteDenoisingDiffusion(pl.LightningModule):
|
|
| 304 |
samples_left_to_save -= to_save
|
| 305 |
samples_left_to_generate -= to_generate
|
| 306 |
chains_left_to_save -= chains_save
|
| 307 |
-
if wandb.run:
|
| 308 |
-
wandb.log({"test/time": total_eval_time}, commit=False)
|
| 309 |
-
wandb.run.summary['test_time'] = total_eval_time
|
| 310 |
print(f'Test time: {total_eval_time :.4f} seconds')
|
| 311 |
|
| 312 |
self.print("Saving the generated graphs")
|
|
@@ -531,12 +506,6 @@ class DiscreteDenoisingDiffusion(pl.LightningModule):
|
|
| 531 |
# Update NLL metric object and return batch nll
|
| 532 |
nll = (self.test_nll if test else self.val_nll)(nlls) # Average over the batch
|
| 533 |
|
| 534 |
-
if wandb.run:
|
| 535 |
-
wandb.log({"kl prior": kl_prior.mean(),
|
| 536 |
-
"Estimator loss terms": loss_all_t.mean(),
|
| 537 |
-
"log_pn": log_pN.mean(),
|
| 538 |
-
"loss_term_0": loss_term_0,
|
| 539 |
-
'batch_test_nll' if test else 'val_nll': nll}, commit=False)
|
| 540 |
return nll
|
| 541 |
|
| 542 |
def forward(self, noisy_data, extra_data, node_mask):
|
|
|
|
| 6 |
import torch
|
| 7 |
import torch.nn as nn
|
| 8 |
import torch.nn.functional as F
|
|
|
|
| 9 |
from tqdm import tqdm
|
| 10 |
|
| 11 |
from diffusion.noise_schedule import DiscreteUniformTransition, MarginalUniformTransition, \
|
|
|
|
| 147 |
def on_fit_start(self) -> None:
|
| 148 |
self.train_iterations = len(self.trainer.datamodule.train_dataloader())
|
| 149 |
self.print("Size of the input features", self.Xdim, self.Edim, self.ydim)
|
|
|
|
|
|
|
| 150 |
|
| 151 |
def on_train_epoch_start(self) -> None:
|
| 152 |
self.print("Starting train epoch...")
|
|
|
|
| 184 |
def on_validation_epoch_end(self) -> None:
|
| 185 |
metrics = [self.val_nll.compute(), self.val_X_kl.compute() * self.T, self.val_E_kl.compute() * self.T,
|
| 186 |
self.val_X_logp.compute(), self.val_E_logp.compute()]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
self.print(f"Epoch {self.current_epoch}: Val NLL {metrics[0] :.2f} -- Val Atom type KL {metrics[1] :.2f} -- ",
|
| 188 |
f"Val Edge type KL: {metrics[2] :.2f}")
|
| 189 |
|
|
|
|
| 232 |
self.test_E_kl.reset()
|
| 233 |
self.test_X_logp.reset()
|
| 234 |
self.test_E_logp.reset()
|
|
|
|
|
|
|
| 235 |
|
| 236 |
def test_step(self, data, i):
|
| 237 |
dense_data, node_mask = utils.to_dense(data.x, data.edge_index, data.edge_attr, data.batch)
|
|
|
|
| 246 |
""" Measure likelihood on a test set and compute stability metrics. """
|
| 247 |
metrics = [self.test_nll.compute(), self.test_X_kl.compute(), self.test_E_kl.compute(),
|
| 248 |
self.test_X_logp.compute(), self.test_E_logp.compute()]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 249 |
self.print(f"Epoch {self.current_epoch}: Test NLL {metrics[0] :.2f} -- Test Atom type KL {metrics[1] :.2f} -- ",
|
| 250 |
f"Test Edge type KL: {metrics[2] :.2f}")
|
| 251 |
|
| 252 |
test_nll = metrics[0]
|
|
|
|
|
|
|
|
|
|
| 253 |
self.print(f'Test loss: {test_nll :.4f}')
|
| 254 |
|
| 255 |
samples_left_to_generate = self.cfg.general.final_model_samples_to_generate
|
|
|
|
| 282 |
samples_left_to_save -= to_save
|
| 283 |
samples_left_to_generate -= to_generate
|
| 284 |
chains_left_to_save -= chains_save
|
|
|
|
|
|
|
|
|
|
| 285 |
print(f'Test time: {total_eval_time :.4f} seconds')
|
| 286 |
|
| 287 |
self.print("Saving the generated graphs")
|
|
|
|
| 506 |
# Update NLL metric object and return batch nll
|
| 507 |
nll = (self.test_nll if test else self.val_nll)(nlls) # Average over the batch
|
| 508 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 509 |
return nll
|
| 510 |
|
| 511 |
def forward(self, noisy_data, extra_data, node_mask):
|
src/research/MultiProxAn/src/metrics/molecular_metrics.py
CHANGED
|
@@ -6,7 +6,6 @@ from src.analysis.rdkit_functions import compute_molecular_metrics
|
|
| 6 |
import torch
|
| 7 |
from torchmetrics import Metric, MetricCollection
|
| 8 |
from torch import Tensor
|
| 9 |
-
import wandb
|
| 10 |
import torch.nn as nn
|
| 11 |
|
| 12 |
|
|
@@ -25,8 +24,6 @@ class TrainMolecularMetrics(nn.Module):
|
|
| 25 |
to_log['train/' + key] = val.item()
|
| 26 |
for key, val in self.train_bond_metrics.compute().items():
|
| 27 |
to_log['train/' + key] = val.item()
|
| 28 |
-
if wandb.run:
|
| 29 |
-
wandb.log(to_log, commit=False)
|
| 30 |
|
| 31 |
def reset(self):
|
| 32 |
for metric in [self.train_atom_metrics, self.train_bond_metrics]:
|
|
@@ -42,9 +39,6 @@ class TrainMolecularMetrics(nn.Module):
|
|
| 42 |
for key, val in epoch_bond_metrics.items():
|
| 43 |
to_log['train_epoch/epoch' + key] = val.item()
|
| 44 |
|
| 45 |
-
if wandb.run:
|
| 46 |
-
wandb.log(to_log, commit=False)
|
| 47 |
-
|
| 48 |
for key, val in epoch_atom_metrics.items():
|
| 49 |
epoch_atom_metrics[key] = f"{val.item() :.3f}"
|
| 50 |
for key, val in epoch_bond_metrics.items():
|
|
@@ -135,18 +129,6 @@ class SamplingMolecularMetrics(nn.Module):
|
|
| 135 |
edge_mae = self.edge_dist_mae.compute()
|
| 136 |
valency_mae = self.valency_dist_mae.compute()
|
| 137 |
|
| 138 |
-
if wandb.run:
|
| 139 |
-
wandb.log(to_log, commit=False)
|
| 140 |
-
wandb.run.summary['Gen n distribution'] = generated_n_dist
|
| 141 |
-
wandb.run.summary['Gen node distribution'] = generated_node_dist
|
| 142 |
-
wandb.run.summary['Gen edge distribution'] = generated_edge_dist
|
| 143 |
-
wandb.run.summary['Gen valency distribution'] = generated_valency_dist
|
| 144 |
-
|
| 145 |
-
wandb.log({'basic_metrics/n_mae': n_mae,
|
| 146 |
-
'basic_metrics/node_mae': node_mae,
|
| 147 |
-
'basic_metrics/edge_mae': edge_mae,
|
| 148 |
-
'basic_metrics/valency_mae': valency_mae}, commit=False)
|
| 149 |
-
|
| 150 |
if local_rank == 0:
|
| 151 |
print("Custom metrics computed.")
|
| 152 |
if local_rank == 0:
|
|
|
|
| 6 |
import torch
|
| 7 |
from torchmetrics import Metric, MetricCollection
|
| 8 |
from torch import Tensor
|
|
|
|
| 9 |
import torch.nn as nn
|
| 10 |
|
| 11 |
|
|
|
|
| 24 |
to_log['train/' + key] = val.item()
|
| 25 |
for key, val in self.train_bond_metrics.compute().items():
|
| 26 |
to_log['train/' + key] = val.item()
|
|
|
|
|
|
|
| 27 |
|
| 28 |
def reset(self):
|
| 29 |
for metric in [self.train_atom_metrics, self.train_bond_metrics]:
|
|
|
|
| 39 |
for key, val in epoch_bond_metrics.items():
|
| 40 |
to_log['train_epoch/epoch' + key] = val.item()
|
| 41 |
|
|
|
|
|
|
|
|
|
|
| 42 |
for key, val in epoch_atom_metrics.items():
|
| 43 |
epoch_atom_metrics[key] = f"{val.item() :.3f}"
|
| 44 |
for key, val in epoch_bond_metrics.items():
|
|
|
|
| 129 |
edge_mae = self.edge_dist_mae.compute()
|
| 130 |
valency_mae = self.valency_dist_mae.compute()
|
| 131 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
if local_rank == 0:
|
| 133 |
print("Custom metrics computed.")
|
| 134 |
if local_rank == 0:
|
src/research/MultiProxAn/src/metrics/molecular_metrics_discrete.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
| 1 |
import torch
|
| 2 |
from torchmetrics import Metric, MetricCollection
|
| 3 |
from torch import Tensor
|
| 4 |
-
import wandb
|
| 5 |
import torch.nn as nn
|
| 6 |
|
| 7 |
|
|
@@ -167,8 +166,6 @@ class TrainMolecularMetricsDiscrete(nn.Module):
|
|
| 167 |
to_log['train/' + key] = val.item()
|
| 168 |
for key, val in self.train_bond_metrics.compute().items():
|
| 169 |
to_log['train/' + key] = val.item()
|
| 170 |
-
if wandb.run:
|
| 171 |
-
wandb.log(to_log, commit=False)
|
| 172 |
|
| 173 |
def reset(self):
|
| 174 |
for metric in [self.train_atom_metrics, self.train_bond_metrics]:
|
|
@@ -183,9 +180,6 @@ class TrainMolecularMetricsDiscrete(nn.Module):
|
|
| 183 |
to_log['train_epoch/' + key] = val.item()
|
| 184 |
for key, val in epoch_bond_metrics.items():
|
| 185 |
to_log['train_epoch/' + key] = val.item()
|
| 186 |
-
if wandb.run:
|
| 187 |
-
wandb.log(to_log, commit=False)
|
| 188 |
-
|
| 189 |
for key, val in epoch_atom_metrics.items():
|
| 190 |
epoch_atom_metrics[key] = val.item()
|
| 191 |
for key, val in epoch_bond_metrics.items():
|
|
|
|
| 1 |
import torch
|
| 2 |
from torchmetrics import Metric, MetricCollection
|
| 3 |
from torch import Tensor
|
|
|
|
| 4 |
import torch.nn as nn
|
| 5 |
|
| 6 |
|
|
|
|
| 166 |
to_log['train/' + key] = val.item()
|
| 167 |
for key, val in self.train_bond_metrics.compute().items():
|
| 168 |
to_log['train/' + key] = val.item()
|
|
|
|
|
|
|
| 169 |
|
| 170 |
def reset(self):
|
| 171 |
for metric in [self.train_atom_metrics, self.train_bond_metrics]:
|
|
|
|
| 180 |
to_log['train_epoch/' + key] = val.item()
|
| 181 |
for key, val in epoch_bond_metrics.items():
|
| 182 |
to_log['train_epoch/' + key] = val.item()
|
|
|
|
|
|
|
|
|
|
| 183 |
for key, val in epoch_atom_metrics.items():
|
| 184 |
epoch_atom_metrics[key] = val.item()
|
| 185 |
for key, val in epoch_bond_metrics.items():
|
src/research/MultiProxAn/src/metrics/train_metrics.py
CHANGED
|
@@ -3,7 +3,6 @@ from torch import Tensor
|
|
| 3 |
import torch.nn as nn
|
| 4 |
from torchmetrics import Metric, MeanSquaredError, MetricCollection
|
| 5 |
import time
|
| 6 |
-
import wandb
|
| 7 |
from src.metrics.abstract_metrics import SumExceptBatchMetric, SumExceptBatchMSE, SumExceptBatchKL, CrossEntropyMetric, \
|
| 8 |
ProbabilityMetric, NLL
|
| 9 |
|
|
@@ -36,8 +35,6 @@ class TrainLoss(nn.Module):
|
|
| 36 |
'train_loss/node_MSE': self.train_node_mse.compute(),
|
| 37 |
'train_loss/edge_MSE': self.train_edge_mse.compute(),
|
| 38 |
'train_loss/y_mse': self.train_y_mse.compute()}
|
| 39 |
-
if wandb.run:
|
| 40 |
-
wandb.log(to_log, commit=True)
|
| 41 |
|
| 42 |
return mse
|
| 43 |
|
|
@@ -53,8 +50,6 @@ class TrainLoss(nn.Module):
|
|
| 53 |
to_log = {"train_epoch/epoch_X_mse": epoch_node_mse,
|
| 54 |
"train_epoch/epoch_E_mse": epoch_edge_mse,
|
| 55 |
"train_epoch/epoch_y_mse": epoch_y_mse}
|
| 56 |
-
if wandb.run:
|
| 57 |
-
wandb.log(to_log)
|
| 58 |
return to_log
|
| 59 |
|
| 60 |
|
|
@@ -101,8 +96,6 @@ class TrainLossDiscrete(nn.Module):
|
|
| 101 |
"train_loss/X_CE": self.node_loss.compute() if true_X.numel() > 0 else -1,
|
| 102 |
"train_loss/E_CE": self.edge_loss.compute() if true_E.numel() > 0 else -1,
|
| 103 |
"train_loss/y_CE": self.y_loss.compute() if true_y.numel() > 0 else -1}
|
| 104 |
-
if wandb.run:
|
| 105 |
-
wandb.log(to_log, commit=True)
|
| 106 |
return loss_X + self.lambda_train[0] * loss_E + self.lambda_train[1] * loss_y
|
| 107 |
|
| 108 |
def reset(self):
|
|
@@ -117,9 +110,6 @@ class TrainLossDiscrete(nn.Module):
|
|
| 117 |
to_log = {"train_epoch/x_CE": epoch_node_loss,
|
| 118 |
"train_epoch/E_CE": epoch_edge_loss,
|
| 119 |
"train_epoch/y_CE": epoch_y_loss}
|
| 120 |
-
if wandb.run:
|
| 121 |
-
wandb.log(to_log, commit=False)
|
| 122 |
-
|
| 123 |
return to_log
|
| 124 |
|
| 125 |
|
|
|
|
| 3 |
import torch.nn as nn
|
| 4 |
from torchmetrics import Metric, MeanSquaredError, MetricCollection
|
| 5 |
import time
|
|
|
|
| 6 |
from src.metrics.abstract_metrics import SumExceptBatchMetric, SumExceptBatchMSE, SumExceptBatchKL, CrossEntropyMetric, \
|
| 7 |
ProbabilityMetric, NLL
|
| 8 |
|
|
|
|
| 35 |
'train_loss/node_MSE': self.train_node_mse.compute(),
|
| 36 |
'train_loss/edge_MSE': self.train_edge_mse.compute(),
|
| 37 |
'train_loss/y_mse': self.train_y_mse.compute()}
|
|
|
|
|
|
|
| 38 |
|
| 39 |
return mse
|
| 40 |
|
|
|
|
| 50 |
to_log = {"train_epoch/epoch_X_mse": epoch_node_mse,
|
| 51 |
"train_epoch/epoch_E_mse": epoch_edge_mse,
|
| 52 |
"train_epoch/epoch_y_mse": epoch_y_mse}
|
|
|
|
|
|
|
| 53 |
return to_log
|
| 54 |
|
| 55 |
|
|
|
|
| 96 |
"train_loss/X_CE": self.node_loss.compute() if true_X.numel() > 0 else -1,
|
| 97 |
"train_loss/E_CE": self.edge_loss.compute() if true_E.numel() > 0 else -1,
|
| 98 |
"train_loss/y_CE": self.y_loss.compute() if true_y.numel() > 0 else -1}
|
|
|
|
|
|
|
| 99 |
return loss_X + self.lambda_train[0] * loss_E + self.lambda_train[1] * loss_y
|
| 100 |
|
| 101 |
def reset(self):
|
|
|
|
| 110 |
to_log = {"train_epoch/x_CE": epoch_node_loss,
|
| 111 |
"train_epoch/E_CE": epoch_edge_loss,
|
| 112 |
"train_epoch/y_CE": epoch_y_loss}
|
|
|
|
|
|
|
|
|
|
| 113 |
return to_log
|
| 114 |
|
| 115 |
|
src/research/MultiProxAn/src/utils.py
CHANGED
|
@@ -4,7 +4,6 @@ from omegaconf import OmegaConf, open_dict
|
|
| 4 |
from torch_geometric.utils import to_dense_adj, to_dense_batch
|
| 5 |
import torch
|
| 6 |
import omegaconf
|
| 7 |
-
import wandb
|
| 8 |
|
| 9 |
|
| 10 |
def create_folders(args):
|
|
@@ -131,9 +130,3 @@ class PlaceHolder:
|
|
| 131 |
return self
|
| 132 |
|
| 133 |
|
| 134 |
-
def setup_wandb(cfg):
|
| 135 |
-
config_dict = omegaconf.OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
|
| 136 |
-
kwargs = {'name': cfg.general.name, 'project': f'graph_ddm_{cfg.dataset.name}', 'config': config_dict,
|
| 137 |
-
'settings': wandb.Settings(_disable_stats=True), 'reinit': True, 'mode': cfg.general.wandb}
|
| 138 |
-
wandb.init(**kwargs)
|
| 139 |
-
wandb.save('*.txt')
|
|
|
|
| 4 |
from torch_geometric.utils import to_dense_adj, to_dense_batch
|
| 5 |
import torch
|
| 6 |
import omegaconf
|
|
|
|
| 7 |
|
| 8 |
|
| 9 |
def create_folders(args):
|
|
|
|
| 130 |
return self
|
| 131 |
|
| 132 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|