Spaces:
Running
Running
Update PolyFusion/GINE.py
Browse files- PolyFusion/GINE.py +11 -16
PolyFusion/GINE.py
CHANGED
|
@@ -1,11 +1,6 @@
|
|
| 1 |
"""
|
| 2 |
GINE.py
|
| 3 |
-
GINE-based masked pretraining on polymer graphs.
|
| 4 |
-
|
| 5 |
-
This file provides (and uses internally):
|
| 6 |
-
- safe_get, match_edge_attr_to_index (used by CL.py)
|
| 7 |
-
- GineEncoder (used by CL.py AND used internally by MaskedGINE)
|
| 8 |
-
- MaskedGINE training script unchanged in behavior (still predicts masked atoms + hop anchors)
|
| 9 |
"""
|
| 10 |
|
| 11 |
from __future__ import annotations
|
|
@@ -287,7 +282,7 @@ def compute_class_weights(train_atomic: List[torch.Tensor]) -> torch.Tensor:
|
|
| 287 |
|
| 288 |
|
| 289 |
# =============================================================================
|
| 290 |
-
# Encoder wrapper used by
|
| 291 |
# =============================================================================
|
| 292 |
|
| 293 |
class GineBlock(nn.Module):
|
|
@@ -309,7 +304,7 @@ class GineBlock(nn.Module):
|
|
| 309 |
|
| 310 |
class GineEncoder(nn.Module):
|
| 311 |
"""
|
| 312 |
-
Graph encoder
|
| 313 |
- Produces node embeddings via GINE
|
| 314 |
- Provides pooled graph embedding via mean pooling + pool_proj
|
| 315 |
- Provides node_logits(...) for reconstruction (atomic prediction head)
|
|
@@ -338,10 +333,10 @@ class GineEncoder(nn.Module):
|
|
| 338 |
|
| 339 |
self.gnn_layers = nn.ModuleList([GineBlock(node_emb_dim) for _ in range(num_layers)])
|
| 340 |
|
| 341 |
-
# node head for masked-atom reconstruction
|
| 342 |
self.atom_head = nn.Linear(node_emb_dim, MASK_ATOM_ID + 1)
|
| 343 |
|
| 344 |
-
# pooled embedding projection
|
| 345 |
self.pool_proj = nn.Linear(node_emb_dim, emb_dim)
|
| 346 |
|
| 347 |
if class_weights is not None:
|
|
@@ -376,7 +371,7 @@ class GineEncoder(nn.Module):
|
|
| 376 |
|
| 377 |
def forward(self, z, chirality, formal_charge, edge_index, edge_attr, batch=None):
|
| 378 |
"""
|
| 379 |
-
Returns pooled graph embedding (B, emb_dim)
|
| 380 |
Pool = mean over nodes per graph (batch vector).
|
| 381 |
"""
|
| 382 |
if batch is None:
|
|
@@ -401,7 +396,7 @@ class GineEncoder(nn.Module):
|
|
| 401 |
|
| 402 |
|
| 403 |
# =============================================================================
|
| 404 |
-
# Training dataset + collate
|
| 405 |
# =============================================================================
|
| 406 |
|
| 407 |
class PolymerDataset(Dataset):
|
|
@@ -554,7 +549,7 @@ def collate_batch(batch):
|
|
| 554 |
|
| 555 |
|
| 556 |
# =============================================================================
|
| 557 |
-
# Masked pretraining model
|
| 558 |
# =============================================================================
|
| 559 |
|
| 560 |
class MaskedGINE(nn.Module):
|
|
@@ -574,7 +569,7 @@ class MaskedGINE(nn.Module):
|
|
| 574 |
class_weights=None,
|
| 575 |
):
|
| 576 |
super().__init__()
|
| 577 |
-
# Use GineEncoder internally
|
| 578 |
self.encoder = GineEncoder(
|
| 579 |
node_emb_dim=node_emb_dim,
|
| 580 |
edge_emb_dim=edge_emb_dim,
|
|
@@ -595,7 +590,7 @@ class MaskedGINE(nn.Module):
|
|
| 595 |
self.log_var_z = None
|
| 596 |
self.log_var_pos = None
|
| 597 |
|
| 598 |
-
#
|
| 599 |
self.class_weights = getattr(self.encoder, "class_weights", None)
|
| 600 |
|
| 601 |
def forward(
|
|
@@ -885,7 +880,7 @@ def train_and_evaluate(args: argparse.Namespace) -> None:
|
|
| 885 |
except Exception as e:
|
| 886 |
print(f"\nFailed to load best model from {best_model_path}: {e}")
|
| 887 |
|
| 888 |
-
# Final evaluation
|
| 889 |
model.eval()
|
| 890 |
preds_z_all, true_z_all = [], []
|
| 891 |
pred_dists_all, true_dists_all = [], []
|
|
|
|
| 1 |
"""
|
| 2 |
GINE.py
|
| 3 |
+
GINE-based masked pretraining on 2D polymer graphs.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
"""
|
| 5 |
|
| 6 |
from __future__ import annotations
|
|
|
|
| 282 |
|
| 283 |
|
| 284 |
# =============================================================================
|
| 285 |
+
# Encoder wrapper used by MaskedGINE
|
| 286 |
# =============================================================================
|
| 287 |
|
| 288 |
class GineBlock(nn.Module):
|
|
|
|
| 304 |
|
| 305 |
class GineEncoder(nn.Module):
|
| 306 |
"""
|
| 307 |
+
Graph encoder:
|
| 308 |
- Produces node embeddings via GINE
|
| 309 |
- Provides pooled graph embedding via mean pooling + pool_proj
|
| 310 |
- Provides node_logits(...) for reconstruction (atomic prediction head)
|
|
|
|
| 333 |
|
| 334 |
self.gnn_layers = nn.ModuleList([GineBlock(node_emb_dim) for _ in range(num_layers)])
|
| 335 |
|
| 336 |
+
# node head for masked-atom reconstruction
|
| 337 |
self.atom_head = nn.Linear(node_emb_dim, MASK_ATOM_ID + 1)
|
| 338 |
|
| 339 |
+
# pooled embedding projection
|
| 340 |
self.pool_proj = nn.Linear(node_emb_dim, emb_dim)
|
| 341 |
|
| 342 |
if class_weights is not None:
|
|
|
|
| 371 |
|
| 372 |
def forward(self, z, chirality, formal_charge, edge_index, edge_attr, batch=None):
|
| 373 |
"""
|
| 374 |
+
Returns pooled graph embedding (B, emb_dim).
|
| 375 |
Pool = mean over nodes per graph (batch vector).
|
| 376 |
"""
|
| 377 |
if batch is None:
|
|
|
|
| 396 |
|
| 397 |
|
| 398 |
# =============================================================================
|
| 399 |
+
# Training dataset + collate
|
| 400 |
# =============================================================================
|
| 401 |
|
| 402 |
class PolymerDataset(Dataset):
|
|
|
|
| 549 |
|
| 550 |
|
| 551 |
# =============================================================================
|
| 552 |
+
# Masked pretraining model
|
| 553 |
# =============================================================================
|
| 554 |
|
| 555 |
class MaskedGINE(nn.Module):
|
|
|
|
| 569 |
class_weights=None,
|
| 570 |
):
|
| 571 |
super().__init__()
|
| 572 |
+
# Use GineEncoder internally
|
| 573 |
self.encoder = GineEncoder(
|
| 574 |
node_emb_dim=node_emb_dim,
|
| 575 |
edge_emb_dim=edge_emb_dim,
|
|
|
|
| 590 |
self.log_var_z = None
|
| 591 |
self.log_var_pos = None
|
| 592 |
|
| 593 |
+
# Mirror the encoder's class_weights on this module (None if the encoder does not define it)
|
| 594 |
self.class_weights = getattr(self.encoder, "class_weights", None)
|
| 595 |
|
| 596 |
def forward(
|
|
|
|
| 880 |
except Exception as e:
|
| 881 |
print(f"\nFailed to load best model from {best_model_path}: {e}")
|
| 882 |
|
| 883 |
+
# Final evaluation
|
| 884 |
model.eval()
|
| 885 |
preds_z_all, true_z_all = [], []
|
| 886 |
pred_dists_all, true_dists_all = [], []
|