# NOTE(review): the following lines are web-page extraction residue (UI labels,
# file size, line-number gutter) — not part of the program; commented out so
# the file parses.
# Spaces: / Sleeping / Sleeping / File size: 5,950 Bytes / 9497911 | / 1..151 |
# GraphGAN + GNN Predictor with Multi-Property Prediction
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool
from rdkit import Chem
import random
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import train_test_split
# ========== Generator ==========
class GraphGenerator(nn.Module):
    """Generate dense node features and a soft adjacency matrix from noise.

    A single linear layer expands the latent code into per-node features;
    the adjacency is the sigmoid of the node-feature Gram matrix, so edge
    probabilities reflect pairwise feature similarity.
    """

    def __init__(self, latent_dim, num_node_features, num_nodes):
        super(GraphGenerator, self).__init__()
        self.latent_dim = latent_dim
        self.num_nodes = num_nodes
        self.num_node_features = num_node_features
        self.fc = nn.Linear(latent_dim, num_nodes * num_node_features)

    def forward(self, z):
        # Expand z to (batch, num_nodes, num_node_features).
        feats = self.fc(z)
        feats = feats.view(-1, self.num_nodes, self.num_node_features)
        # Soft adjacency: sigmoid of pairwise node-feature inner products.
        gram = torch.matmul(feats, feats.transpose(1, 2))
        return feats, torch.sigmoid(gram)
# ========== Discriminator ==========
class GraphDiscriminator(nn.Module):
    """Real/fake graph classifier: two GCN layers, mean pooling, sigmoid head."""

    def __init__(self, in_channels):
        super(GraphDiscriminator, self).__init__()
        self.conv1 = GCNConv(in_channels, 32)
        self.conv2 = GCNConv(32, 64)
        self.fc = nn.Linear(64, 1)

    def forward(self, x, edge_index, batch):
        # Node-level message passing.
        h = F.relu(self.conv1(x, edge_index))
        h = F.relu(self.conv2(h, edge_index))
        # Graph-level readout, then probability that the graph is "real".
        pooled = global_mean_pool(h, batch)
        return torch.sigmoid(self.fc(pooled))
# ========== GNN Property Predictor (multi-property) ==========
class PropertyPredictor(nn.Module):
    """GNN regressor mapping a graph to a vector of material properties.

    Two GCN layers build node embeddings; mean pooling collapses them to a
    graph embedding; a linear head emits one value per property (default 3).
    """

    def __init__(self, in_channels, out_channels=3):  # e.g., 3 properties
        super(PropertyPredictor, self).__init__()
        self.conv1 = GCNConv(in_channels, 64)
        self.conv2 = GCNConv(64, 128)
        self.fc = nn.Linear(128, out_channels)

    def forward(self, x, edge_index, batch):
        h = F.relu(self.conv1(x, edge_index))
        h = F.relu(self.conv2(h, edge_index))
        pooled = global_mean_pool(h, batch)
        # Unbounded regression outputs, one per property.
        return self.fc(pooled)
# ========== Utility: Convert adjacency to edge_index ==========
def adj_to_edge_index(adj_matrix, threshold=0.5):
    """Binarize a soft adjacency matrix into a COO ``edge_index`` (2, E).

    Entries strictly above ``threshold`` become directed edges; self-loops
    (diagonal entries) are dropped.
    """
    src_dst = (adj_matrix > threshold).nonzero(as_tuple=False).t()
    off_diagonal = src_dst[0] != src_dst[1]
    return src_dst[:, off_diagonal]
# ========== Convert SMILES to Graph ==========
from rdkit import Chem
from torch_geometric.data import Data
def smiles_to_graph(smiles):
    """Parse a SMILES string into a PyG ``Data`` graph.

    Each atom becomes a node with one feature (its atomic number); each
    bond contributes two directed edges. Returns ``None`` when RDKit
    cannot parse the string.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    node_feats = [[atom.GetAtomicNum()] for atom in mol.GetAtoms()]
    edges = []
    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        edges.append([i, j])
        edges.append([j, i])  # undirected bond -> both directions
    if edges:
        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    else:
        # Fix: bond-free molecules (single atoms) previously produced a 1-D
        # empty tensor (shape (0,)); GCN layers need an empty (2, 0) index.
        edge_index = torch.empty((2, 0), dtype=torch.long)
    x = torch.tensor(node_feats, dtype=torch.float)
    return Data(x=x, edge_index=edge_index)
# ========== Train Property Predictor ==========
def train_property_predictor(model, dataset, targets, epochs=100):
    """Fit ``model`` to (graph, property-vector) pairs with MSE loss.

    Trains one graph at a time (effective batch size 1) with Adam at
    lr=0.001, printing the mean per-graph loss after every epoch.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    loss_fn = nn.MSELoss()
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for data, y in zip(dataset, targets):
            optimizer.zero_grad()
            # Every node belongs to graph 0: a single-graph "batch".
            batch_vec = torch.zeros(data.x.size(0), dtype=torch.long)
            pred = model(data.x, data.edge_index, batch_vec)
            loss = loss_fn(pred.view(-1), y.view(-1))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch}, Loss: {total_loss/len(dataset):.4f}")
# ========== Generate & Screen by Property ==========
def generate_and_screen(generator, predictor, top_k=5, property_index=0, num_samples=100):
    """Sample candidate graphs and rank them by one predicted property.

    Args:
        generator: trained GraphGenerator (provides ``latent_dim``).
        predictor: trained PropertyPredictor returning a (1, P) vector.
        top_k: number of best candidates to return.
        property_index: which property column to rank by.
        num_samples: how many latent vectors to draw (was hard-coded 100).

    Returns:
        List of ``(Data, score, all_properties_ndarray)`` tuples, best first.
    """
    generator.eval()
    predictor.eval()
    candidates = []
    # Inference only: no_grad avoids building autograd graphs for every sample.
    with torch.no_grad():
        z = torch.randn(num_samples, generator.latent_dim)
        node_feats, adj = generator(z)
        for feats, soft_adj in zip(node_feats, adj):
            edge_index = adj_to_edge_index(soft_adj)
            if edge_index.size(1) == 0:
                continue  # skip graphs with no edges
            data = Data(x=feats, edge_index=edge_index)
            batch_vec = torch.zeros(data.x.size(0), dtype=torch.long)
            score_vec = predictor(data.x, data.edge_index, batch_vec)
            # score_vec is (1, P); extract the requested property as a scalar.
            score = score_vec[0][property_index].item()
            candidates.append((data, score, score_vec.detach().numpy()))
    return sorted(candidates, key=lambda c: -c[1])[:top_k]
# ========== Entry Point ==========
if __name__ == '__main__':
    sample_smiles = ["[Fe]CO", "CC[Si]", "CCCCH", "COC", "CCCl"]
    # Multi-property targets: [conductivity, porosity, surface_area]
    all_targets = torch.tensor([
        [0.3, 0.3, 0.7],
        [0.6, 0.4, 0.6],
        [0.5, 0.5, 0.5],
        [0.9, 0.2, 0.8],
        [0.7, 0.6, 0.7]
    ])
    # Fix: parse each SMILES once and keep graphs aligned with their targets.
    # The original filtered failed graphs but not targets, so a single bad
    # SMILES silently shifted every later graph onto the wrong property row.
    graph_data, kept_targets = [], []
    for smi, tgt in zip(sample_smiles, all_targets):
        g = smiles_to_graph(smi)
        if g is not None:
            graph_data.append(g)
            kept_targets.append(tgt)
    targets = torch.stack(kept_targets)
    generator = GraphGenerator(latent_dim=16, num_node_features=1, num_nodes=5)
    discriminator = GraphDiscriminator(in_channels=1)
    predictor = PropertyPredictor(in_channels=1, out_channels=3)
    # Train the predictor on known molecules
    train_property_predictor(predictor, graph_data, targets, epochs=100)
    # Generate and screen candidates based on conductivity (index 0)
    top_candidates = generate_and_screen(generator, predictor, top_k=5, property_index=0)
    # Display results
    for i, (graph, score, all_props) in enumerate(top_candidates):
        print(f"\nCandidate {i+1}:")
        print(f"  Score (Conductivity): {score:.3f}")
        print(f"  All Properties: {all_props}")
        print(f"  Nodes: {graph.x.size(0)}, Edges: {graph.edge_index.size(1)}")
# (trailing extraction artifact removed)