Spaces:
Running
Running
Update PolyFusion/SchNet.py
Browse files- PolyFusion/SchNet.py +3 -7
PolyFusion/SchNet.py
CHANGED
|
@@ -1,10 +1,6 @@
|
|
| 1 |
"""
|
| 2 |
SchNet.py
|
| 3 |
SchNet-based masked pretraining on polymer conformer geometry.
|
| 4 |
-
|
| 5 |
-
This file provides (and uses internally):
|
| 6 |
-
- NodeSchNetWrapper (used by CL.py AND used internally by MaskedSchNet)
|
| 7 |
-
- MaskedSchNet training code (behavior preserved)
|
| 8 |
"""
|
| 9 |
|
| 10 |
from __future__ import annotations
|
|
@@ -276,12 +272,12 @@ class NodeSchNet(nn.Module):
|
|
| 276 |
|
| 277 |
|
| 278 |
# =============================================================================
|
| 279 |
-
# Wrapper used by
|
| 280 |
# =============================================================================
|
| 281 |
|
| 282 |
class NodeSchNetWrapper(nn.Module):
|
| 283 |
"""
|
| 284 |
-
- Produces pooled embedding
|
| 285 |
- Provides node_logits(...) for reconstruction
|
| 286 |
"""
|
| 287 |
|
|
@@ -619,7 +615,7 @@ def train_and_eval(args: argparse.Namespace) -> None:
|
|
| 619 |
except Exception as e:
|
| 620 |
print(f"\nFailed to load best model from {best_model_path}: {e}")
|
| 621 |
|
| 622 |
-
# Final evaluation
|
| 623 |
model.eval()
|
| 624 |
preds_z_all, true_z_all = [], []
|
| 625 |
pred_dists_all, true_dists_all = [], []
|
|
|
|
| 1 |
"""
|
| 2 |
SchNet.py
|
| 3 |
SchNet-based masked pretraining on polymer conformer geometry.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
"""
|
| 5 |
|
| 6 |
from __future__ import annotations
|
|
|
|
| 272 |
|
| 273 |
|
| 274 |
# =============================================================================
|
| 275 |
+
# Wrapper used by MaskedSchNet
|
| 276 |
# =============================================================================
|
| 277 |
|
| 278 |
class NodeSchNetWrapper(nn.Module):
|
| 279 |
"""
|
| 280 |
+
- Produces pooled embedding (mean pooling + pool_proj)
|
| 281 |
- Provides node_logits(...) for reconstruction
|
| 282 |
"""
|
| 283 |
|
|
|
|
| 615 |
except Exception as e:
|
| 616 |
print(f"\nFailed to load best model from {best_model_path}: {e}")
|
| 617 |
|
| 618 |
+
# Final evaluation
|
| 619 |
model.eval()
|
| 620 |
preds_z_all, true_z_all = [], []
|
| 621 |
pred_dists_all, true_dists_all = [], []
|