Spaces:
Sleeping
Sleeping
File size: 4,954 Bytes
72a3513 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
"""Module containing main class for policy network."""
from abc import ABC
from typing import Dict, Tuple, Union

import torch
from pytorch_lightning import LightningModule
from torch import Tensor
from torch.nn import Linear
from torch.nn.functional import binary_cross_entropy_with_logits, cross_entropy, one_hot
from torch_geometric.data.batch import Batch
from torchmetrics.functional.classification import f1_score, recall, specificity

from synplan.ml.networks.modules import MCTSNetwork
class PolicyNetwork(MCTSNetwork, LightningModule, ABC):
    """Policy network that scores reaction rules from a graph embedding.

    The behaviour is selected by ``policy_type``:

    * ``"ranking"``: one linear head; ``forward`` returns a softmax
      distribution over the reaction rules.
    * ``"filtering"``: two linear heads (regular and priority rules);
      ``forward`` returns an independent sigmoid score per rule for each head.
    """

    def __init__(
        self,
        *args,
        n_rules: int,
        vector_dim: int,
        policy_type: str = "ranking",
        **kwargs
    ):
        """Create the linear prediction heads on top of the graph embedder.

        :param n_rules: The number of reaction rules (output dimension of each
            prediction head).
        :param vector_dim: The dimensionality of the graph embedding vector; it
            is forwarded to the parent initializer and used as the input width
            of each head.
        :param policy_type: Either ``"ranking"`` (default) or ``"filtering"``.
            The priority head is only created for ``"filtering"``.
        :param args: Extra positional arguments passed to the parent initializer.
        :param kwargs: Extra keyword arguments passed to the parent initializer.
        """
        super().__init__(vector_dim, *args, **kwargs)
        # Must be called from __init__ itself: Lightning collects the
        # constructor arguments of the calling frame.
        self.save_hyperparameters()
        self.policy_type = policy_type
        self.n_rules = n_rules
        self.y_predictor = Linear(vector_dim, n_rules)
        if self.policy_type == "filtering":
            self.priority_predictor = Linear(vector_dim, n_rules)

    def _unsupported_policy_error(self) -> ValueError:
        """Build the error raised when ``policy_type`` is not a known mode."""
        return ValueError(
            f"Unsupported policy_type {self.policy_type!r}; "
            "expected 'ranking' or 'filtering'."
        )

    @staticmethod
    def _balanced_accuracy_and_f1(
        preds: Tensor, target: Tensor, **task_kwargs
    ) -> Tuple[Tensor, Tensor]:
        """Return balanced accuracy, (recall + specificity) / 2, and the F1 score.

        :param preds: Raw head outputs (logits) handed to torchmetrics.
        :param target: Integer ground-truth labels.
        :param task_kwargs: Task selection forwarded unchanged to every metric
            (``task`` plus ``num_classes`` or ``num_labels``).
        :return: The tuple ``(balanced_accuracy, f1)``.
        """
        balanced_accuracy = (
            recall(preds, target, **task_kwargs)
            + specificity(preds, target, **task_kwargs)
        ) / 2
        f1 = f1_score(preds, target, **task_kwargs)
        return balanced_accuracy, f1

    def forward(self, batch: Batch) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        """Embed the molecular graphs and predict reaction-rule scores.

        :param batch: The input batch of molecular graphs.
        :return: For ``"ranking"``, the softmax distribution over reaction
            rules. For ``"filtering"``, a tuple ``(y, priority)`` with the
            sigmoid scores of the regular and the priority reaction rules.
        :raises ValueError: If ``policy_type`` is neither ``"ranking"`` nor
            ``"filtering"``.
        """
        # `embedder` and `batch_size` are not defined in this class;
        # presumably provided by MCTSNetwork -- TODO confirm.
        x = self.embedder(batch, self.batch_size)
        y = self.y_predictor(x)

        if self.policy_type == "ranking":
            return torch.softmax(y, dim=-1)

        if self.policy_type == "filtering":
            priority = torch.sigmoid(self.priority_predictor(x))
            return torch.sigmoid(y), priority

        # Previously fell through and returned None silently.
        raise self._unsupported_policy_error()

    def _get_loss(self, batch: Batch) -> Dict[str, Tensor]:
        """Compute the loss and classification metrics for one batch.

        Reads ``batch.y_rules`` (and ``batch.y_priority`` in ``"filtering"``
        mode) as ground truth and evaluates the raw head outputs (logits).

        :param batch: The batch of molecular graphs.
        :return: A dictionary with ``loss``, ``balanced_accuracy_y`` and
            ``f1_score_y``; in ``"filtering"`` mode also
            ``balanced_accuracy_priority`` and ``f1_score_priority``, and
            ``loss`` is the sum of the regular and priority losses.
        :raises ValueError: If ``policy_type`` is neither ``"ranking"`` nor
            ``"filtering"``.
        """
        true_y = batch.y_rules.long()
        x = self.embedder(batch, self.batch_size)
        pred_y = self.y_predictor(x)

        if self.policy_type == "ranking":
            task = {"task": "multiclass", "num_classes": self.n_rules}
            # Class-index targets yield the same mean cross-entropy as one-hot
            # probability targets, without materializing a dense one-hot tensor.
            loss = cross_entropy(pred_y, true_y)
            ba_y, f1_y = self._balanced_accuracy_and_f1(pred_y, true_y, **task)
            return {"loss": loss, "balanced_accuracy_y": ba_y, "f1_score_y": f1_y}

        if self.policy_type == "filtering":
            task = {"task": "multilabel", "num_labels": self.n_rules}

            loss_y = binary_cross_entropy_with_logits(pred_y, true_y.float())
            ba_y, f1_y = self._balanced_accuracy_and_f1(pred_y, true_y, **task)

            true_priority = batch.y_priority.float()
            pred_priority = self.priority_predictor(x)
            loss_priority = binary_cross_entropy_with_logits(
                pred_priority, true_priority
            )
            # The loss needs float targets, the metrics need integer labels.
            ba_priority, f1_priority = self._balanced_accuracy_and_f1(
                pred_priority, true_priority.long(), **task
            )

            return {
                "loss": loss_y + loss_priority,
                "balanced_accuracy_y": ba_y,
                "f1_score_y": f1_y,
                "balanced_accuracy_priority": ba_priority,
                "f1_score_priority": f1_priority,
            }

        # Previously crashed with UnboundLocalError on `return metrics`.
        raise self._unsupported_policy_error()
|