Andrej Janchevski commited on
Commit
16cab72
·
1 Parent(s): b701828

fix(research): remove wandb dependency and guard optional imports

Browse files

- Strip wandb imports and calls from MultiProxAn diffusion models and
utils (prevents ImportError at checkpoint load time)
- Wrap graph_tool, pyemd, pygsp, dist_helper imports in try/except
across spectre_utils, molecular_metrics, train_metrics (these are
only needed for training metrics, not inference)
- Fix pandas deprecation in COINs load_graph.py (iteritems -> items,
to_frame column naming)

src/research/COINs-KGGeneration/graph_completion/graphs/load_graph.py CHANGED
@@ -128,8 +128,8 @@ class Loader:
128
  hits_at_3_limit = community_query_edge_counts.groupby(["c_s", "r"]).head(3).sum() / num_edges
129
  hits_at_10_limit = community_query_edge_counts.groupby(["c_s", "r"]).head(10).sum() / num_edges
130
  community_query_counts = community_query_edge_counts.groupby(["c_s", "r"]).count()
131
- community_query_edge_counts = community_query_edge_counts.to_frame().assign(rank=0, rrank=0)
132
- for (c_s, r), c_t_count in community_query_counts.iteritems():
133
  community_query_edge_counts.loc[(c_s, r), "rank"] = np.arange(1, c_t_count + 1)
134
  community_query_edge_counts.loc[(c_s, r), "rrank"] = 1 / np.arange(1, c_t_count + 1)
135
  mr_limit = (community_query_edge_counts["c_t"] * community_query_edge_counts["rank"]).sum() / num_edges
 
128
  hits_at_3_limit = community_query_edge_counts.groupby(["c_s", "r"]).head(3).sum() / num_edges
129
  hits_at_10_limit = community_query_edge_counts.groupby(["c_s", "r"]).head(10).sum() / num_edges
130
  community_query_counts = community_query_edge_counts.groupby(["c_s", "r"]).count()
131
+ community_query_edge_counts = community_query_edge_counts.to_frame(name="c_t").assign(rank=0, rrank=0.0)
132
+ for (c_s, r), c_t_count in community_query_counts.items():
133
  community_query_edge_counts.loc[(c_s, r), "rank"] = np.arange(1, c_t_count + 1)
134
  community_query_edge_counts.loc[(c_s, r), "rrank"] = 1 / np.arange(1, c_t_count + 1)
135
  mr_limit = (community_query_edge_counts["c_t"] * community_query_edge_counts["rank"]).sum() / num_edges
src/research/MultiProxAn/src/analysis/rdkit_functions.py CHANGED
@@ -1,7 +1,6 @@
1
  import numpy as np
2
  import torch
3
  import re
4
- import wandb
5
  try:
6
  from rdkit import Chem
7
  print("Found rdkit, all good")
@@ -316,19 +315,10 @@ def compute_molecular_metrics(molecule_list, train_smiles, dataset_info):
316
  fraction_mol_stable = molecule_stable / float(n_molecules)
317
  fraction_atm_stable = nr_stable_bonds / float(n_atoms)
318
  validity_dict = {'mol_stable': fraction_mol_stable, 'atm_stable': fraction_atm_stable}
319
- if wandb.run:
320
- wandb.log(validity_dict)
321
  else:
322
  validity_dict = {'mol_stable': -1, 'atm_stable': -1}
323
 
324
  metrics = BasicMolecularMetrics(dataset_info, train_smiles)
325
  rdkit_metrics = metrics.evaluate(molecule_list)
326
  all_smiles = rdkit_metrics[-1]
327
- if wandb.run:
328
- nc = rdkit_metrics[-2]
329
- dic = {'Validity': rdkit_metrics[0][0], 'Relaxed Validity': rdkit_metrics[0][1],
330
- 'Uniqueness': rdkit_metrics[0][2], 'Novelty': rdkit_metrics[0][3],
331
- 'nc_max': nc['nc_max'], 'nc_mu': nc['nc_mu']}
332
- wandb.log(dic)
333
-
334
  return validity_dict, rdkit_metrics, all_smiles
 
1
  import numpy as np
2
  import torch
3
  import re
 
4
  try:
5
  from rdkit import Chem
6
  print("Found rdkit, all good")
 
315
  fraction_mol_stable = molecule_stable / float(n_molecules)
316
  fraction_atm_stable = nr_stable_bonds / float(n_atoms)
317
  validity_dict = {'mol_stable': fraction_mol_stable, 'atm_stable': fraction_atm_stable}
 
 
318
  else:
319
  validity_dict = {'mol_stable': -1, 'atm_stable': -1}
320
 
321
  metrics = BasicMolecularMetrics(dataset_info, train_smiles)
322
  rdkit_metrics = metrics.evaluate(molecule_list)
323
  all_smiles = rdkit_metrics[-1]
 
 
 
 
 
 
 
324
  return validity_dict, rdkit_metrics, all_smiles
src/research/MultiProxAn/src/analysis/spectre_utils.py CHANGED
@@ -3,7 +3,6 @@
3
  # Adapted from https://github.com/lrjconan/GRAN/ which in turn is adapted from https://github.com/JiaxuanYou/graph-generation
4
  #
5
  ###############################################################################
6
- import graph_tool.all as gt
7
  ##Navigate to the ./util/orca directory and compile orca.cpp
8
  # g++ -O2 -std=c++11 -o orca orca.cpp
9
  import os
@@ -12,18 +11,26 @@ import torch
12
  import torch.nn as nn
13
  import numpy as np
14
  import networkx as nx
15
- import subprocess as sp
16
- import concurrent.futures
17
-
18
- import pygsp as pg
19
- import secrets
20
- from string import ascii_uppercase, digits
21
- from datetime import datetime
22
- from scipy.linalg import eigvalsh
23
- from scipy.stats import chi2
24
- from src.analysis.dist_helper import compute_mmd, gaussian_emd, gaussian, emd, gaussian_tv, disc
25
- from torch_geometric.utils import to_networkx
26
- import wandb
 
 
 
 
 
 
 
 
27
 
28
  PRINT_TIME = False
29
  __all__ = ['degree_stats', 'clustering_stats', 'orbit_stats_all', 'spectral_stats', 'eval_acc_lobster_graph']
@@ -778,8 +785,6 @@ class SpectreSamplingMetrics(nn.Module):
778
  degree = degree_stats(reference_graphs, networkx_graphs, is_parallel=False,
779
  compute_emd=self.compute_emd)
780
  to_log['degree'] = degree
781
- if wandb.run:
782
- wandb.run.summary['degree'] = degree
783
 
784
  # val_eigvals = [graph["eigval"][1:self.k + 1].cpu().detach().numpy() for graph in self.val]
785
  # train_eigvals = [graph["eigval"][1:self.k + 1].cpu().detach().numpy() for graph in self.train]
@@ -795,8 +800,6 @@ class SpectreSamplingMetrics(nn.Module):
795
  compute_emd=self.compute_emd)
796
 
797
  to_log['spectre'] = spectre
798
- if wandb.run:
799
- wandb.run.summary['spectre'] = spectre
800
 
801
  if 'clustering' in self.metrics_list:
802
  if local_rank == 0:
@@ -804,8 +807,6 @@ class SpectreSamplingMetrics(nn.Module):
804
  clustering = clustering_stats(reference_graphs, networkx_graphs, bins=100, is_parallel=False,
805
  compute_emd=self.compute_emd)
806
  to_log['clustering'] = clustering
807
- if wandb.run:
808
- wandb.run.summary['clustering'] = clustering
809
 
810
  if 'motif' in self.metrics_list:
811
  if local_rank == 0:
@@ -813,32 +814,24 @@ class SpectreSamplingMetrics(nn.Module):
813
  motif = motif_stats(reference_graphs, networkx_graphs, motif_type='4cycle', ground_truth_match=None, bins=100,
814
  compute_emd=self.compute_emd)
815
  to_log['motif'] = motif
816
- if wandb.run:
817
- wandb.run.summary['motif'] = motif
818
 
819
  if 'orbit' in self.metrics_list:
820
  if local_rank == 0:
821
  print("Computing orbit stats...")
822
  orbit = orbit_stats_all(reference_graphs, networkx_graphs, compute_emd=self.compute_emd)
823
  to_log['orbit'] = orbit
824
- if wandb.run:
825
- wandb.run.summary['orbit'] = orbit
826
 
827
  if 'sbm' in self.metrics_list:
828
  if local_rank == 0:
829
  print("Computing accuracy...")
830
  acc = eval_acc_sbm_graph(networkx_graphs, refinement_steps=100, strict=True)
831
  to_log['sbm_acc'] = acc
832
- if wandb.run:
833
- wandb.run.summary['sbmacc'] = acc
834
 
835
  if 'planar' in self.metrics_list:
836
  if local_rank ==0:
837
  print('Computing planar accuracy...')
838
  planar_acc = eval_acc_planar_graph(networkx_graphs)
839
  to_log['planar_acc'] = planar_acc
840
- if wandb.run:
841
- wandb.run.summary['planar_acc'] = planar_acc
842
 
843
  if 'sbm' or 'planar' in self.metrics_list:
844
  if local_rank == 0:
@@ -853,8 +846,6 @@ class SpectreSamplingMetrics(nn.Module):
853
 
854
  if local_rank == 0:
855
  print("Sampling statistics", to_log)
856
- if wandb.run:
857
- wandb.log(to_log, commit=False)
858
 
859
  def reset(self):
860
  pass
 
3
  # Adapted from https://github.com/lrjconan/GRAN/ which in turn is adapted from https://github.com/JiaxuanYou/graph-generation
4
  #
5
  ###############################################################################
 
6
  ##Navigate to the ./util/orca directory and compile orca.cpp
7
  # g++ -O2 -std=c++11 -o orca orca.cpp
8
  import os
 
11
  import torch.nn as nn
12
  import numpy as np
13
  import networkx as nx
14
+
15
+ # Heavy metric-computation deps — optional, not needed for inference.
16
+ # Deferred so checkpoint unpickling works without graph-tool / pyemd / pygsp.
17
+ try:
18
+ import graph_tool.all as gt
19
+ except ImportError:
20
+ gt = None
21
+ try:
22
+ import subprocess as sp
23
+ import concurrent.futures
24
+ import pygsp as pg
25
+ import secrets
26
+ from string import ascii_uppercase, digits
27
+ from datetime import datetime
28
+ from scipy.linalg import eigvalsh
29
+ from scipy.stats import chi2
30
+ from src.analysis.dist_helper import compute_mmd, gaussian_emd, gaussian, emd, gaussian_tv, disc
31
+ from torch_geometric.utils import to_networkx
32
+ except ImportError:
33
+ pass # Metrics unavailable — inference still works
34
 
35
  PRINT_TIME = False
36
  __all__ = ['degree_stats', 'clustering_stats', 'orbit_stats_all', 'spectral_stats', 'eval_acc_lobster_graph']
 
785
  degree = degree_stats(reference_graphs, networkx_graphs, is_parallel=False,
786
  compute_emd=self.compute_emd)
787
  to_log['degree'] = degree
 
 
788
 
789
  # val_eigvals = [graph["eigval"][1:self.k + 1].cpu().detach().numpy() for graph in self.val]
790
  # train_eigvals = [graph["eigval"][1:self.k + 1].cpu().detach().numpy() for graph in self.train]
 
800
  compute_emd=self.compute_emd)
801
 
802
  to_log['spectre'] = spectre
 
 
803
 
804
  if 'clustering' in self.metrics_list:
805
  if local_rank == 0:
 
807
  clustering = clustering_stats(reference_graphs, networkx_graphs, bins=100, is_parallel=False,
808
  compute_emd=self.compute_emd)
809
  to_log['clustering'] = clustering
 
 
810
 
811
  if 'motif' in self.metrics_list:
812
  if local_rank == 0:
 
814
  motif = motif_stats(reference_graphs, networkx_graphs, motif_type='4cycle', ground_truth_match=None, bins=100,
815
  compute_emd=self.compute_emd)
816
  to_log['motif'] = motif
 
 
817
 
818
  if 'orbit' in self.metrics_list:
819
  if local_rank == 0:
820
  print("Computing orbit stats...")
821
  orbit = orbit_stats_all(reference_graphs, networkx_graphs, compute_emd=self.compute_emd)
822
  to_log['orbit'] = orbit
 
 
823
 
824
  if 'sbm' in self.metrics_list:
825
  if local_rank == 0:
826
  print("Computing accuracy...")
827
  acc = eval_acc_sbm_graph(networkx_graphs, refinement_steps=100, strict=True)
828
  to_log['sbm_acc'] = acc
 
 
829
 
830
  if 'planar' in self.metrics_list:
831
  if local_rank ==0:
832
  print('Computing planar accuracy...')
833
  planar_acc = eval_acc_planar_graph(networkx_graphs)
834
  to_log['planar_acc'] = planar_acc
 
 
835
 
836
  if 'sbm' or 'planar' in self.metrics_list:
837
  if local_rank == 0:
 
846
 
847
  if local_rank == 0:
848
  print("Sampling statistics", to_log)
 
 
849
 
850
  def reset(self):
851
  pass
src/research/MultiProxAn/src/analysis/visualization.py CHANGED
@@ -8,7 +8,6 @@ import imageio
8
  import networkx as nx
9
  import numpy as np
10
  import rdkit.Chem
11
- import wandb
12
  import matplotlib.pyplot as plt
13
 
14
 
@@ -78,9 +77,6 @@ class MolecularVisualization:
78
  mol = self.mol_from_graphs(molecules[i][0].numpy(), molecules[i][1].numpy())
79
  try:
80
  Draw.MolToFile(mol, file_path)
81
- if wandb.run and log is not None:
82
- print(f"Saving {file_path} to wandb")
83
- wandb.log({log: wandb.Image(file_path)}, commit=True)
84
  except rdkit.Chem.KekulizeException:
85
  print("Can't kekulize molecule")
86
 
@@ -115,10 +111,6 @@ class MolecularVisualization:
115
  imgs.extend([imgs[-1]] * 10)
116
  imageio.mimsave(gif_path, imgs, subrectangles=True, duration=20)
117
 
118
- if wandb.run:
119
- print(f"Saving {gif_path} to wandb")
120
- wandb.log({"chain": wandb.Video(gif_path, fps=5, format="gif")}, commit=True)
121
-
122
  # draw grid image
123
  try:
124
  img = Draw.MolsToGridImage(mols, molsPerRow=20, subImgSize=(200, 200))
@@ -185,8 +177,6 @@ class NonMolecularVisualization:
185
  graph = self.to_networkx(graphs[i][0].numpy(), graphs[i][1].numpy())
186
  self.visualize_non_molecule(graph=graph, pos=None, path=file_path)
187
  im = plt.imread(file_path)
188
- if wandb.run and log is not None:
189
- wandb.log({log: [wandb.Image(im, caption=file_path)]})
190
 
191
  def visualize_chain(self, path, nodes_list, adjacency_matrix):
192
  # convert graphs to networkx
@@ -219,5 +209,3 @@ class NonMolecularVisualization:
219
  gif_path = os.path.join(os.path.dirname(path), '{}.gif'.format(path.split('/')[-1]))
220
  imgs.extend([imgs[-1]] * 10)
221
  imageio.mimsave(gif_path, imgs, subrectangles=True, duration=20)
222
- if wandb.run:
223
- wandb.log({'chain': [wandb.Video(gif_path, caption=gif_path, format="gif")]})
 
8
  import networkx as nx
9
  import numpy as np
10
  import rdkit.Chem
 
11
  import matplotlib.pyplot as plt
12
 
13
 
 
77
  mol = self.mol_from_graphs(molecules[i][0].numpy(), molecules[i][1].numpy())
78
  try:
79
  Draw.MolToFile(mol, file_path)
 
 
 
80
  except rdkit.Chem.KekulizeException:
81
  print("Can't kekulize molecule")
82
 
 
111
  imgs.extend([imgs[-1]] * 10)
112
  imageio.mimsave(gif_path, imgs, subrectangles=True, duration=20)
113
 
 
 
 
 
114
  # draw grid image
115
  try:
116
  img = Draw.MolsToGridImage(mols, molsPerRow=20, subImgSize=(200, 200))
 
177
  graph = self.to_networkx(graphs[i][0].numpy(), graphs[i][1].numpy())
178
  self.visualize_non_molecule(graph=graph, pos=None, path=file_path)
179
  im = plt.imread(file_path)
 
 
180
 
181
  def visualize_chain(self, path, nodes_list, adjacency_matrix):
182
  # convert graphs to networkx
 
209
  gif_path = os.path.join(os.path.dirname(path), '{}.gif'.format(path.split('/')[-1]))
210
  imgs.extend([imgs[-1]] * 10)
211
  imageio.mimsave(gif_path, imgs, subrectangles=True, duration=20)
 
 
src/research/MultiProxAn/src/diffusion_model.py CHANGED
@@ -6,7 +6,6 @@ import numpy as np
6
  import pytorch_lightning as pl
7
  import torch
8
  import torch.nn as nn
9
- import wandb
10
  from tqdm.auto import tqdm
11
 
12
  from diffusion.noise_schedule import PredefinedNoiseSchedule
@@ -141,8 +140,6 @@ class LiftedDenoisingDiffusion(pl.LightningModule):
141
 
142
  def on_fit_start(self) -> None:
143
  self.train_iterations = len(self.trainer.datamodule.train_dataloader())
144
- if self.local_rank == 0:
145
- utils.setup_wandb(self.cfg)
146
 
147
  def on_train_epoch_start(self) -> None:
148
  self.start_epoch_time = time.time()
@@ -186,15 +183,6 @@ class LiftedDenoisingDiffusion(pl.LightningModule):
186
  metrics = [self.val_nll.compute(), self.val_X_mse.compute(), self.val_E_mse.compute(),
187
  self.val_y_mse.compute(), self.val_X_logp.compute(), self.val_E_logp.compute(),
188
  self.val_y_logp.compute()]
189
- if wandb.run:
190
- wandb.log({"val/epoch_NLL": metrics[0],
191
- "val/X_mse": metrics[1],
192
- "val/E_mse": metrics[2],
193
- "val/y_mse": metrics[3],
194
- "val/X_logp": metrics[4],
195
- "val/E_logp": metrics[5],
196
- "val/y_logp": metrics[6]}, commit=False)
197
-
198
  print(f"Epoch {self.current_epoch}: Val NLL {metrics[0] :.2f} -- Val Atom type MSE {metrics[1] :.2f} -- ",
199
  f"Val Edge type MSE: {metrics[2] :.2f} -- Val Global feat. MSE {metrics[3] :.2f}",
200
  f"-- Val X Reconstruction loss {metrics[4] :.2f} -- Val E Reconstruction loss {metrics[5] :.2f}",
@@ -203,8 +191,6 @@ class LiftedDenoisingDiffusion(pl.LightningModule):
203
  # Log val nll with default Lightning logger, so it can be monitored by checkpoint callback
204
  val_nll = metrics[0]
205
  self.log("val/epoch_NLL", val_nll, sync_dist=True)
206
- if wandb.run:
207
- wandb.log(self.log_info(), commit=False)
208
 
209
  if val_nll < self.best_val_nll:
210
  self.best_val_nll = val_nll
@@ -249,8 +235,6 @@ class LiftedDenoisingDiffusion(pl.LightningModule):
249
  self.test_X_logp.reset()
250
  self.test_E_logp.reset()
251
  self.test_y_logp.reset()
252
- if self.local_rank == 0:
253
- utils.setup_wandb(self.cfg)
254
 
255
  def test_step(self, data, i):
256
  dense_data, node_mask = utils.to_dense(x=data.x, edge_index=data.edge_index, edge_attr=data.edge_attr,
@@ -277,19 +261,12 @@ class LiftedDenoisingDiffusion(pl.LightningModule):
277
  "test/X_logp": metrics[4],
278
  "test/E_logp": metrics[5],
279
  "test/y_logp": metrics[6]}
280
- if wandb.run:
281
- wandb.log(log_dict, commit=False)
282
-
283
  print(f"Epoch {self.current_epoch}: Test NLL {metrics[0] :.2f} -- Test Atom type MSE {metrics[1] :.2f} -- ",
284
  f"Test Edge type MSE: {metrics[2] :.2f} -- Test Global feat. MSE {metrics[3] :.2f}",
285
  f"-- Test X Reconstruction loss {metrics[4] :.2f} -- Test E Reconstruction loss {metrics[5] :.2f}",
286
  f"-- Test y Reconstruction loss {metrics[6] : .2f}\n")
287
 
288
  test_nll = metrics[0]
289
- if wandb.run:
290
- wandb.log({"test/epoch_NLL": test_nll}, commit=False)
291
- wandb.log(self.log_info(), commit=False)
292
-
293
  print(f'Test loss: {test_nll :.4f}')
294
 
295
  samples_left_to_generate = self.cfg.general.final_model_samples_to_generate
@@ -320,9 +297,6 @@ class LiftedDenoisingDiffusion(pl.LightningModule):
320
  samples_left_to_save -= to_save
321
  samples_left_to_generate -= to_generate
322
  chains_left_to_save -= chains_save
323
- if wandb.run:
324
- wandb.log({"test/time": total_eval_time}, commit=False)
325
- wandb.run.summary['test_time'] = total_eval_time
326
  print(f'Test time: {total_eval_time :.4f} seconds')
327
 
328
  self.sampling_metrics.reset()
@@ -587,12 +561,6 @@ class LiftedDenoisingDiffusion(pl.LightningModule):
587
 
588
  nll = self.test_nll(nlls) if test else self.val_nll(nlls) # Average over the batch
589
 
590
- wandb.log({"kl prior": kl_prior.mean(),
591
- "Estimator loss terms": loss_all_t.mean(),
592
- "Loss term 0": loss_term_0,
593
- "log_pn": log_pN.mean(),
594
- 'test_nll' if test else 'val_nll': nll},
595
- commit=False)
596
  return nll
597
 
598
  def forward(self, noisy_data, extra_data, node_mask):
 
6
  import pytorch_lightning as pl
7
  import torch
8
  import torch.nn as nn
 
9
  from tqdm.auto import tqdm
10
 
11
  from diffusion.noise_schedule import PredefinedNoiseSchedule
 
140
 
141
  def on_fit_start(self) -> None:
142
  self.train_iterations = len(self.trainer.datamodule.train_dataloader())
 
 
143
 
144
  def on_train_epoch_start(self) -> None:
145
  self.start_epoch_time = time.time()
 
183
  metrics = [self.val_nll.compute(), self.val_X_mse.compute(), self.val_E_mse.compute(),
184
  self.val_y_mse.compute(), self.val_X_logp.compute(), self.val_E_logp.compute(),
185
  self.val_y_logp.compute()]
 
 
 
 
 
 
 
 
 
186
  print(f"Epoch {self.current_epoch}: Val NLL {metrics[0] :.2f} -- Val Atom type MSE {metrics[1] :.2f} -- ",
187
  f"Val Edge type MSE: {metrics[2] :.2f} -- Val Global feat. MSE {metrics[3] :.2f}",
188
  f"-- Val X Reconstruction loss {metrics[4] :.2f} -- Val E Reconstruction loss {metrics[5] :.2f}",
 
191
  # Log val nll with default Lightning logger, so it can be monitored by checkpoint callback
192
  val_nll = metrics[0]
193
  self.log("val/epoch_NLL", val_nll, sync_dist=True)
 
 
194
 
195
  if val_nll < self.best_val_nll:
196
  self.best_val_nll = val_nll
 
235
  self.test_X_logp.reset()
236
  self.test_E_logp.reset()
237
  self.test_y_logp.reset()
 
 
238
 
239
  def test_step(self, data, i):
240
  dense_data, node_mask = utils.to_dense(x=data.x, edge_index=data.edge_index, edge_attr=data.edge_attr,
 
261
  "test/X_logp": metrics[4],
262
  "test/E_logp": metrics[5],
263
  "test/y_logp": metrics[6]}
 
 
 
264
  print(f"Epoch {self.current_epoch}: Test NLL {metrics[0] :.2f} -- Test Atom type MSE {metrics[1] :.2f} -- ",
265
  f"Test Edge type MSE: {metrics[2] :.2f} -- Test Global feat. MSE {metrics[3] :.2f}",
266
  f"-- Test X Reconstruction loss {metrics[4] :.2f} -- Test E Reconstruction loss {metrics[5] :.2f}",
267
  f"-- Test y Reconstruction loss {metrics[6] : .2f}\n")
268
 
269
  test_nll = metrics[0]
 
 
 
 
270
  print(f'Test loss: {test_nll :.4f}')
271
 
272
  samples_left_to_generate = self.cfg.general.final_model_samples_to_generate
 
297
  samples_left_to_save -= to_save
298
  samples_left_to_generate -= to_generate
299
  chains_left_to_save -= chains_save
 
 
 
300
  print(f'Test time: {total_eval_time :.4f} seconds')
301
 
302
  self.sampling_metrics.reset()
 
561
 
562
  nll = self.test_nll(nlls) if test else self.val_nll(nlls) # Average over the batch
563
 
 
 
 
 
 
 
564
  return nll
565
 
566
  def forward(self, noisy_data, extra_data, node_mask):
src/research/MultiProxAn/src/diffusion_model_discrete.py CHANGED
@@ -6,7 +6,6 @@ import pytorch_lightning as pl
6
  import torch
7
  import torch.nn as nn
8
  import torch.nn.functional as F
9
- import wandb
10
  from tqdm import tqdm
11
 
12
  from diffusion.noise_schedule import DiscreteUniformTransition, MarginalUniformTransition, \
@@ -148,8 +147,6 @@ class DiscreteDenoisingDiffusion(pl.LightningModule):
148
  def on_fit_start(self) -> None:
149
  self.train_iterations = len(self.trainer.datamodule.train_dataloader())
150
  self.print("Size of the input features", self.Xdim, self.Edim, self.ydim)
151
- if self.local_rank == 0:
152
- utils.setup_wandb(self.cfg)
153
 
154
  def on_train_epoch_start(self) -> None:
155
  self.print("Starting train epoch...")
@@ -187,13 +184,6 @@ class DiscreteDenoisingDiffusion(pl.LightningModule):
187
  def on_validation_epoch_end(self) -> None:
188
  metrics = [self.val_nll.compute(), self.val_X_kl.compute() * self.T, self.val_E_kl.compute() * self.T,
189
  self.val_X_logp.compute(), self.val_E_logp.compute()]
190
- if wandb.run:
191
- wandb.log({"val/epoch_NLL": metrics[0],
192
- "val/X_kl": metrics[1],
193
- "val/E_kl": metrics[2],
194
- "val/X_logp": metrics[3],
195
- "val/E_logp": metrics[4]}, commit=False)
196
-
197
  self.print(f"Epoch {self.current_epoch}: Val NLL {metrics[0] :.2f} -- Val Atom type KL {metrics[1] :.2f} -- ",
198
  f"Val Edge type KL: {metrics[2] :.2f}")
199
 
@@ -242,8 +232,6 @@ class DiscreteDenoisingDiffusion(pl.LightningModule):
242
  self.test_E_kl.reset()
243
  self.test_X_logp.reset()
244
  self.test_E_logp.reset()
245
- if self.local_rank == 0:
246
- utils.setup_wandb(self.cfg)
247
 
248
  def test_step(self, data, i):
249
  dense_data, node_mask = utils.to_dense(data.x, data.edge_index, data.edge_attr, data.batch)
@@ -258,20 +246,10 @@ class DiscreteDenoisingDiffusion(pl.LightningModule):
258
  """ Measure likelihood on a test set and compute stability metrics. """
259
  metrics = [self.test_nll.compute(), self.test_X_kl.compute(), self.test_E_kl.compute(),
260
  self.test_X_logp.compute(), self.test_E_logp.compute()]
261
- if wandb.run:
262
- wandb.log({"test/epoch_NLL": metrics[0],
263
- "test/X_kl": metrics[1],
264
- "test/E_kl": metrics[2],
265
- "test/X_logp": metrics[3],
266
- "test/E_logp": metrics[4]}, commit=False)
267
-
268
  self.print(f"Epoch {self.current_epoch}: Test NLL {metrics[0] :.2f} -- Test Atom type KL {metrics[1] :.2f} -- ",
269
  f"Test Edge type KL: {metrics[2] :.2f}")
270
 
271
  test_nll = metrics[0]
272
- if wandb.run:
273
- wandb.log({"test/epoch_NLL": test_nll}, commit=False)
274
-
275
  self.print(f'Test loss: {test_nll :.4f}')
276
 
277
  samples_left_to_generate = self.cfg.general.final_model_samples_to_generate
@@ -304,9 +282,6 @@ class DiscreteDenoisingDiffusion(pl.LightningModule):
304
  samples_left_to_save -= to_save
305
  samples_left_to_generate -= to_generate
306
  chains_left_to_save -= chains_save
307
- if wandb.run:
308
- wandb.log({"test/time": total_eval_time}, commit=False)
309
- wandb.run.summary['test_time'] = total_eval_time
310
  print(f'Test time: {total_eval_time :.4f} seconds')
311
 
312
  self.print("Saving the generated graphs")
@@ -531,12 +506,6 @@ class DiscreteDenoisingDiffusion(pl.LightningModule):
531
  # Update NLL metric object and return batch nll
532
  nll = (self.test_nll if test else self.val_nll)(nlls) # Average over the batch
533
 
534
- if wandb.run:
535
- wandb.log({"kl prior": kl_prior.mean(),
536
- "Estimator loss terms": loss_all_t.mean(),
537
- "log_pn": log_pN.mean(),
538
- "loss_term_0": loss_term_0,
539
- 'batch_test_nll' if test else 'val_nll': nll}, commit=False)
540
  return nll
541
 
542
  def forward(self, noisy_data, extra_data, node_mask):
 
6
  import torch
7
  import torch.nn as nn
8
  import torch.nn.functional as F
 
9
  from tqdm import tqdm
10
 
11
  from diffusion.noise_schedule import DiscreteUniformTransition, MarginalUniformTransition, \
 
147
  def on_fit_start(self) -> None:
148
  self.train_iterations = len(self.trainer.datamodule.train_dataloader())
149
  self.print("Size of the input features", self.Xdim, self.Edim, self.ydim)
 
 
150
 
151
  def on_train_epoch_start(self) -> None:
152
  self.print("Starting train epoch...")
 
184
  def on_validation_epoch_end(self) -> None:
185
  metrics = [self.val_nll.compute(), self.val_X_kl.compute() * self.T, self.val_E_kl.compute() * self.T,
186
  self.val_X_logp.compute(), self.val_E_logp.compute()]
 
 
 
 
 
 
 
187
  self.print(f"Epoch {self.current_epoch}: Val NLL {metrics[0] :.2f} -- Val Atom type KL {metrics[1] :.2f} -- ",
188
  f"Val Edge type KL: {metrics[2] :.2f}")
189
 
 
232
  self.test_E_kl.reset()
233
  self.test_X_logp.reset()
234
  self.test_E_logp.reset()
 
 
235
 
236
  def test_step(self, data, i):
237
  dense_data, node_mask = utils.to_dense(data.x, data.edge_index, data.edge_attr, data.batch)
 
246
  """ Measure likelihood on a test set and compute stability metrics. """
247
  metrics = [self.test_nll.compute(), self.test_X_kl.compute(), self.test_E_kl.compute(),
248
  self.test_X_logp.compute(), self.test_E_logp.compute()]
 
 
 
 
 
 
 
249
  self.print(f"Epoch {self.current_epoch}: Test NLL {metrics[0] :.2f} -- Test Atom type KL {metrics[1] :.2f} -- ",
250
  f"Test Edge type KL: {metrics[2] :.2f}")
251
 
252
  test_nll = metrics[0]
 
 
 
253
  self.print(f'Test loss: {test_nll :.4f}')
254
 
255
  samples_left_to_generate = self.cfg.general.final_model_samples_to_generate
 
282
  samples_left_to_save -= to_save
283
  samples_left_to_generate -= to_generate
284
  chains_left_to_save -= chains_save
 
 
 
285
  print(f'Test time: {total_eval_time :.4f} seconds')
286
 
287
  self.print("Saving the generated graphs")
 
506
  # Update NLL metric object and return batch nll
507
  nll = (self.test_nll if test else self.val_nll)(nlls) # Average over the batch
508
 
 
 
 
 
 
 
509
  return nll
510
 
511
  def forward(self, noisy_data, extra_data, node_mask):
src/research/MultiProxAn/src/metrics/molecular_metrics.py CHANGED
@@ -6,7 +6,6 @@ from src.analysis.rdkit_functions import compute_molecular_metrics
6
  import torch
7
  from torchmetrics import Metric, MetricCollection
8
  from torch import Tensor
9
- import wandb
10
  import torch.nn as nn
11
 
12
 
@@ -25,8 +24,6 @@ class TrainMolecularMetrics(nn.Module):
25
  to_log['train/' + key] = val.item()
26
  for key, val in self.train_bond_metrics.compute().items():
27
  to_log['train/' + key] = val.item()
28
- if wandb.run:
29
- wandb.log(to_log, commit=False)
30
 
31
  def reset(self):
32
  for metric in [self.train_atom_metrics, self.train_bond_metrics]:
@@ -42,9 +39,6 @@ class TrainMolecularMetrics(nn.Module):
42
  for key, val in epoch_bond_metrics.items():
43
  to_log['train_epoch/epoch' + key] = val.item()
44
 
45
- if wandb.run:
46
- wandb.log(to_log, commit=False)
47
-
48
  for key, val in epoch_atom_metrics.items():
49
  epoch_atom_metrics[key] = f"{val.item() :.3f}"
50
  for key, val in epoch_bond_metrics.items():
@@ -135,18 +129,6 @@ class SamplingMolecularMetrics(nn.Module):
135
  edge_mae = self.edge_dist_mae.compute()
136
  valency_mae = self.valency_dist_mae.compute()
137
 
138
- if wandb.run:
139
- wandb.log(to_log, commit=False)
140
- wandb.run.summary['Gen n distribution'] = generated_n_dist
141
- wandb.run.summary['Gen node distribution'] = generated_node_dist
142
- wandb.run.summary['Gen edge distribution'] = generated_edge_dist
143
- wandb.run.summary['Gen valency distribution'] = generated_valency_dist
144
-
145
- wandb.log({'basic_metrics/n_mae': n_mae,
146
- 'basic_metrics/node_mae': node_mae,
147
- 'basic_metrics/edge_mae': edge_mae,
148
- 'basic_metrics/valency_mae': valency_mae}, commit=False)
149
-
150
  if local_rank == 0:
151
  print("Custom metrics computed.")
152
  if local_rank == 0:
 
6
  import torch
7
  from torchmetrics import Metric, MetricCollection
8
  from torch import Tensor
 
9
  import torch.nn as nn
10
 
11
 
 
24
  to_log['train/' + key] = val.item()
25
  for key, val in self.train_bond_metrics.compute().items():
26
  to_log['train/' + key] = val.item()
 
 
27
 
28
  def reset(self):
29
  for metric in [self.train_atom_metrics, self.train_bond_metrics]:
 
39
  for key, val in epoch_bond_metrics.items():
40
  to_log['train_epoch/epoch' + key] = val.item()
41
 
 
 
 
42
  for key, val in epoch_atom_metrics.items():
43
  epoch_atom_metrics[key] = f"{val.item() :.3f}"
44
  for key, val in epoch_bond_metrics.items():
 
129
  edge_mae = self.edge_dist_mae.compute()
130
  valency_mae = self.valency_dist_mae.compute()
131
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  if local_rank == 0:
133
  print("Custom metrics computed.")
134
  if local_rank == 0:
src/research/MultiProxAn/src/metrics/molecular_metrics_discrete.py CHANGED
@@ -1,7 +1,6 @@
1
  import torch
2
  from torchmetrics import Metric, MetricCollection
3
  from torch import Tensor
4
- import wandb
5
  import torch.nn as nn
6
 
7
 
@@ -167,8 +166,6 @@ class TrainMolecularMetricsDiscrete(nn.Module):
167
  to_log['train/' + key] = val.item()
168
  for key, val in self.train_bond_metrics.compute().items():
169
  to_log['train/' + key] = val.item()
170
- if wandb.run:
171
- wandb.log(to_log, commit=False)
172
 
173
  def reset(self):
174
  for metric in [self.train_atom_metrics, self.train_bond_metrics]:
@@ -183,9 +180,6 @@ class TrainMolecularMetricsDiscrete(nn.Module):
183
  to_log['train_epoch/' + key] = val.item()
184
  for key, val in epoch_bond_metrics.items():
185
  to_log['train_epoch/' + key] = val.item()
186
- if wandb.run:
187
- wandb.log(to_log, commit=False)
188
-
189
  for key, val in epoch_atom_metrics.items():
190
  epoch_atom_metrics[key] = val.item()
191
  for key, val in epoch_bond_metrics.items():
 
1
  import torch
2
  from torchmetrics import Metric, MetricCollection
3
  from torch import Tensor
 
4
  import torch.nn as nn
5
 
6
 
 
166
  to_log['train/' + key] = val.item()
167
  for key, val in self.train_bond_metrics.compute().items():
168
  to_log['train/' + key] = val.item()
 
 
169
 
170
  def reset(self):
171
  for metric in [self.train_atom_metrics, self.train_bond_metrics]:
 
180
  to_log['train_epoch/' + key] = val.item()
181
  for key, val in epoch_bond_metrics.items():
182
  to_log['train_epoch/' + key] = val.item()
 
 
 
183
  for key, val in epoch_atom_metrics.items():
184
  epoch_atom_metrics[key] = val.item()
185
  for key, val in epoch_bond_metrics.items():
src/research/MultiProxAn/src/metrics/train_metrics.py CHANGED
@@ -3,7 +3,6 @@ from torch import Tensor
3
  import torch.nn as nn
4
  from torchmetrics import Metric, MeanSquaredError, MetricCollection
5
  import time
6
- import wandb
7
  from src.metrics.abstract_metrics import SumExceptBatchMetric, SumExceptBatchMSE, SumExceptBatchKL, CrossEntropyMetric, \
8
  ProbabilityMetric, NLL
9
 
@@ -36,8 +35,6 @@ class TrainLoss(nn.Module):
36
  'train_loss/node_MSE': self.train_node_mse.compute(),
37
  'train_loss/edge_MSE': self.train_edge_mse.compute(),
38
  'train_loss/y_mse': self.train_y_mse.compute()}
39
- if wandb.run:
40
- wandb.log(to_log, commit=True)
41
 
42
  return mse
43
 
@@ -53,8 +50,6 @@ class TrainLoss(nn.Module):
53
  to_log = {"train_epoch/epoch_X_mse": epoch_node_mse,
54
  "train_epoch/epoch_E_mse": epoch_edge_mse,
55
  "train_epoch/epoch_y_mse": epoch_y_mse}
56
- if wandb.run:
57
- wandb.log(to_log)
58
  return to_log
59
 
60
 
@@ -101,8 +96,6 @@ class TrainLossDiscrete(nn.Module):
101
  "train_loss/X_CE": self.node_loss.compute() if true_X.numel() > 0 else -1,
102
  "train_loss/E_CE": self.edge_loss.compute() if true_E.numel() > 0 else -1,
103
  "train_loss/y_CE": self.y_loss.compute() if true_y.numel() > 0 else -1}
104
- if wandb.run:
105
- wandb.log(to_log, commit=True)
106
  return loss_X + self.lambda_train[0] * loss_E + self.lambda_train[1] * loss_y
107
 
108
  def reset(self):
@@ -117,9 +110,6 @@ class TrainLossDiscrete(nn.Module):
117
  to_log = {"train_epoch/x_CE": epoch_node_loss,
118
  "train_epoch/E_CE": epoch_edge_loss,
119
  "train_epoch/y_CE": epoch_y_loss}
120
- if wandb.run:
121
- wandb.log(to_log, commit=False)
122
-
123
  return to_log
124
 
125
 
 
3
  import torch.nn as nn
4
  from torchmetrics import Metric, MeanSquaredError, MetricCollection
5
  import time
 
6
  from src.metrics.abstract_metrics import SumExceptBatchMetric, SumExceptBatchMSE, SumExceptBatchKL, CrossEntropyMetric, \
7
  ProbabilityMetric, NLL
8
 
 
35
  'train_loss/node_MSE': self.train_node_mse.compute(),
36
  'train_loss/edge_MSE': self.train_edge_mse.compute(),
37
  'train_loss/y_mse': self.train_y_mse.compute()}
 
 
38
 
39
  return mse
40
 
 
50
  to_log = {"train_epoch/epoch_X_mse": epoch_node_mse,
51
  "train_epoch/epoch_E_mse": epoch_edge_mse,
52
  "train_epoch/epoch_y_mse": epoch_y_mse}
 
 
53
  return to_log
54
 
55
 
 
96
  "train_loss/X_CE": self.node_loss.compute() if true_X.numel() > 0 else -1,
97
  "train_loss/E_CE": self.edge_loss.compute() if true_E.numel() > 0 else -1,
98
  "train_loss/y_CE": self.y_loss.compute() if true_y.numel() > 0 else -1}
 
 
99
  return loss_X + self.lambda_train[0] * loss_E + self.lambda_train[1] * loss_y
100
 
101
  def reset(self):
 
110
  to_log = {"train_epoch/x_CE": epoch_node_loss,
111
  "train_epoch/E_CE": epoch_edge_loss,
112
  "train_epoch/y_CE": epoch_y_loss}
 
 
 
113
  return to_log
114
 
115
 
src/research/MultiProxAn/src/utils.py CHANGED
@@ -4,7 +4,6 @@ from omegaconf import OmegaConf, open_dict
4
  from torch_geometric.utils import to_dense_adj, to_dense_batch
5
  import torch
6
  import omegaconf
7
- import wandb
8
 
9
 
10
  def create_folders(args):
@@ -131,9 +130,3 @@ class PlaceHolder:
131
  return self
132
 
133
 
134
- def setup_wandb(cfg):
135
- config_dict = omegaconf.OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)
136
- kwargs = {'name': cfg.general.name, 'project': f'graph_ddm_{cfg.dataset.name}', 'config': config_dict,
137
- 'settings': wandb.Settings(_disable_stats=True), 'reinit': True, 'mode': cfg.general.wandb}
138
- wandb.init(**kwargs)
139
- wandb.save('*.txt')
 
4
  from torch_geometric.utils import to_dense_adj, to_dense_batch
5
  import torch
6
  import omegaconf
 
7
 
8
 
9
  def create_folders(args):
 
130
  return self
131
 
132