kaurm43 committed on
Commit
01a9026
·
verified ·
1 Parent(s): 58d199f

Update PolyFusion/GINE.py

Browse files
Files changed (1) hide show
  1. PolyFusion/GINE.py +11 -16
PolyFusion/GINE.py CHANGED
@@ -1,11 +1,6 @@
1
  """
2
  GINE.py
3
- GINE-based masked pretraining on polymer graphs.
4
-
5
- This file provides (and uses internally):
6
- - safe_get, match_edge_attr_to_index (used by CL.py)
7
- - GineEncoder (used by CL.py AND used internally by MaskedGINE)
8
- - MaskedGINE training script unchanged in behavior (still predicts masked atoms + hop anchors)
9
  """
10
 
11
  from __future__ import annotations
@@ -287,7 +282,7 @@ def compute_class_weights(train_atomic: List[torch.Tensor]) -> torch.Tensor:
287
 
288
 
289
  # =============================================================================
290
- # Encoder wrapper used by CL.py AND used internally by MaskedGINE
291
  # =============================================================================
292
 
293
  class GineBlock(nn.Module):
@@ -309,7 +304,7 @@ class GineBlock(nn.Module):
309
 
310
  class GineEncoder(nn.Module):
311
  """
312
- Graph encoder for CL.py:
313
  - Produces node embeddings via GINE
314
  - Provides pooled graph embedding via mean pooling + pool_proj
315
  - Provides node_logits(...) for reconstruction (atomic prediction head)
@@ -338,10 +333,10 @@ class GineEncoder(nn.Module):
338
 
339
  self.gnn_layers = nn.ModuleList([GineBlock(node_emb_dim) for _ in range(num_layers)])
340
 
341
- # node head for masked-atom reconstruction (used in CL.py)
342
  self.atom_head = nn.Linear(node_emb_dim, MASK_ATOM_ID + 1)
343
 
344
- # pooled embedding projection for CL.py
345
  self.pool_proj = nn.Linear(node_emb_dim, emb_dim)
346
 
347
  if class_weights is not None:
@@ -376,7 +371,7 @@ class GineEncoder(nn.Module):
376
 
377
  def forward(self, z, chirality, formal_charge, edge_index, edge_attr, batch=None):
378
  """
379
- Returns pooled graph embedding (B, emb_dim) for CL.py.
380
  Pool = mean over nodes per graph (batch vector).
381
  """
382
  if batch is None:
@@ -401,7 +396,7 @@ class GineEncoder(nn.Module):
401
 
402
 
403
  # =============================================================================
404
- # Training dataset + collate (unchanged behavior)
405
  # =============================================================================
406
 
407
  class PolymerDataset(Dataset):
@@ -554,7 +549,7 @@ def collate_batch(batch):
554
 
555
 
556
  # =============================================================================
557
- # Masked pretraining model (uses GineEncoder internally)
558
  # =============================================================================
559
 
560
  class MaskedGINE(nn.Module):
@@ -574,7 +569,7 @@ class MaskedGINE(nn.Module):
574
  class_weights=None,
575
  ):
576
  super().__init__()
577
- # Use GineEncoder internally (so wrapper is used here too)
578
  self.encoder = GineEncoder(
579
  node_emb_dim=node_emb_dim,
580
  edge_emb_dim=edge_emb_dim,
@@ -595,7 +590,7 @@ class MaskedGINE(nn.Module):
595
  self.log_var_z = None
596
  self.log_var_pos = None
597
 
598
- # expose class_weights same way as before
599
  self.class_weights = getattr(self.encoder, "class_weights", None)
600
 
601
  def forward(
@@ -885,7 +880,7 @@ def train_and_evaluate(args: argparse.Namespace) -> None:
885
  except Exception as e:
886
  print(f"\nFailed to load best model from {best_model_path}: {e}")
887
 
888
- # Final evaluation (same as your original)
889
  model.eval()
890
  preds_z_all, true_z_all = [], []
891
  pred_dists_all, true_dists_all = [], []
 
1
  """
2
  GINE.py
3
+ GINE-based masked pretraining on polymer 2D graphs.
 
 
 
 
 
4
  """
5
 
6
  from __future__ import annotations
 
282
 
283
 
284
  # =============================================================================
285
+ # Encoder wrapper used by MaskedGINE
286
  # =============================================================================
287
 
288
  class GineBlock(nn.Module):
 
304
 
305
  class GineEncoder(nn.Module):
306
  """
307
+ Graph encoder:
308
  - Produces node embeddings via GINE
309
  - Provides pooled graph embedding via mean pooling + pool_proj
310
  - Provides node_logits(...) for reconstruction (atomic prediction head)
 
333
 
334
  self.gnn_layers = nn.ModuleList([GineBlock(node_emb_dim) for _ in range(num_layers)])
335
 
336
+ # node head for masked-atom reconstruction
337
  self.atom_head = nn.Linear(node_emb_dim, MASK_ATOM_ID + 1)
338
 
339
+ # pooled embedding projection
340
  self.pool_proj = nn.Linear(node_emb_dim, emb_dim)
341
 
342
  if class_weights is not None:
 
371
 
372
  def forward(self, z, chirality, formal_charge, edge_index, edge_attr, batch=None):
373
  """
374
+ Returns pooled graph embedding (B, emb_dim).
375
  Pool = mean over nodes per graph (batch vector).
376
  """
377
  if batch is None:
 
396
 
397
 
398
  # =============================================================================
399
+ # Training dataset + collate
400
  # =============================================================================
401
 
402
  class PolymerDataset(Dataset):
 
549
 
550
 
551
  # =============================================================================
552
+ # Masked pretraining model
553
  # =============================================================================
554
 
555
  class MaskedGINE(nn.Module):
 
569
  class_weights=None,
570
  ):
571
  super().__init__()
572
+ # Use GineEncoder internally
573
  self.encoder = GineEncoder(
574
  node_emb_dim=node_emb_dim,
575
  edge_emb_dim=edge_emb_dim,
 
590
  self.log_var_z = None
591
  self.log_var_pos = None
592
 
593
+ # class_weights
594
  self.class_weights = getattr(self.encoder, "class_weights", None)
595
 
596
  def forward(
 
880
  except Exception as e:
881
  print(f"\nFailed to load best model from {best_model_path}: {e}")
882
 
883
+ # Final evaluation
884
  model.eval()
885
  preds_z_all, true_z_all = [], []
886
  pred_dists_all, true_dists_all = [], []