Spaces:
Runtime error
Runtime error
igashov
commited on
Commit
·
d1da608
1
Parent(s):
b0ab0d5
fix size_nn
Browse files- app.py +1 -1
- src/linker_size_lightning.py +6 -3
app.py
CHANGED
|
@@ -72,7 +72,7 @@ print('Loaded diffusion model')
|
|
| 72 |
|
| 73 |
|
| 74 |
def sample_fn(_data):
|
| 75 |
-
output, _ = size_nn.forward(_data)
|
| 76 |
probabilities = torch.softmax(output, dim=1)
|
| 77 |
distribution = torch.distributions.Categorical(probs=probabilities)
|
| 78 |
samples = distribution.sample()
|
|
|
|
| 72 |
|
| 73 |
|
| 74 |
def sample_fn(_data):
|
| 75 |
+
output, _ = size_nn.forward(_data, return_loss=False)
|
| 76 |
probabilities = torch.softmax(output, dim=1)
|
| 77 |
distribution = torch.distributions.Categorical(probs=probabilities)
|
| 78 |
samples = distribution.sample()
|
src/linker_size_lightning.py
CHANGED
|
@@ -79,7 +79,7 @@ class SizeClassifier(pl.LightningModule):
|
|
| 79 |
def test_dataloader(self):
|
| 80 |
return get_dataloader(self.test_dataset, self.batch_size, collate_fn=collate_with_fragment_edges)
|
| 81 |
|
| 82 |
-
def forward(self, data):
|
| 83 |
h = data['one_hot']
|
| 84 |
x = data['positions']
|
| 85 |
fragment_mask = data['fragment_mask']
|
|
@@ -103,8 +103,11 @@ class SizeClassifier(pl.LightningModule):
|
|
| 103 |
output = self.gnn.forward(h, edges, distances, fragment_mask, distance_edge_mask)
|
| 104 |
output = output.view(bs, n_nodes, -1).mean(1)
|
| 105 |
|
| 106 |
-
|
| 107 |
-
|
|
|
|
|
|
|
|
|
|
| 108 |
|
| 109 |
return output, loss
|
| 110 |
|
|
|
|
| 79 |
def test_dataloader(self):
|
| 80 |
return get_dataloader(self.test_dataset, self.batch_size, collate_fn=collate_with_fragment_edges)
|
| 81 |
|
| 82 |
+
def forward(self, data, return_loss=True):
|
| 83 |
h = data['one_hot']
|
| 84 |
x = data['positions']
|
| 85 |
fragment_mask = data['fragment_mask']
|
|
|
|
| 103 |
output = self.gnn.forward(h, edges, distances, fragment_mask, distance_edge_mask)
|
| 104 |
output = output.view(bs, n_nodes, -1).mean(1)
|
| 105 |
|
| 106 |
+
if return_loss:
|
| 107 |
+
true = self.get_true_labels(linker_mask)
|
| 108 |
+
loss = cross_entropy(output, true, weight=self.loss_weights)
|
| 109 |
+
else:
|
| 110 |
+
loss = None
|
| 111 |
|
| 112 |
return output, loss
|
| 113 |
|