kaurm43 commited on
Commit
47269bd
·
verified ·
1 Parent(s): 01a9026

Update PolyFusion/SchNet.py

Browse files
Files changed (1) hide show
  1. PolyFusion/SchNet.py +3 -7
PolyFusion/SchNet.py CHANGED
@@ -1,10 +1,6 @@
1
  """
2
  SchNet.py
3
  SchNet-based masked pretraining on polymer conformer geometry.
4
-
5
- This file provides (and uses internally):
6
- - NodeSchNetWrapper (used by CL.py AND used internally by MaskedSchNet)
7
- - MaskedSchNet training code (behavior preserved)
8
  """
9
 
10
  from __future__ import annotations
@@ -276,12 +272,12 @@ class NodeSchNet(nn.Module):
276
 
277
 
278
  # =============================================================================
279
- # Wrapper used by CL.py AND used internally by MaskedSchNet
280
  # =============================================================================
281
 
282
  class NodeSchNetWrapper(nn.Module):
283
  """
284
- - Produces pooled embedding for CL (mean pooling + pool_proj)
285
  - Provides node_logits(...) for reconstruction
286
  """
287
 
@@ -619,7 +615,7 @@ def train_and_eval(args: argparse.Namespace) -> None:
619
  except Exception as e:
620
  print(f"\nFailed to load best model from {best_model_path}: {e}")
621
 
622
- # Final evaluation (same as your original)
623
  model.eval()
624
  preds_z_all, true_z_all = [], []
625
  pred_dists_all, true_dists_all = [], []
 
1
  """
2
  SchNet.py
3
  SchNet-based masked pretraining on polymer conformer geometry.
 
 
 
 
4
  """
5
 
6
  from __future__ import annotations
 
272
 
273
 
274
  # =============================================================================
275
+ # Wrapper used by MaskedSchNet
276
  # =============================================================================
277
 
278
  class NodeSchNetWrapper(nn.Module):
279
  """
280
+ - Produces pooled embedding (mean pooling + pool_proj)
281
  - Provides node_logits(...) for reconstruction
282
  """
283
 
 
615
  except Exception as e:
616
  print(f"\nFailed to load best model from {best_model_path}: {e}")
617
 
618
+ # Final evaluation
619
  model.eval()
620
  preds_z_all, true_z_all = [], []
621
  pred_dists_all, true_dists_all = [], []