# Source: Ystar124 model repository, commit 9497911
# GraphGAN + GNN Predictor with Multi-Property Prediction
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool
from rdkit import Chem
import random
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import train_test_split
# ========== Generator ==========
class GraphGenerator(nn.Module):
    """Generate node features and a dense soft adjacency matrix from latent noise.

    A single linear layer maps each latent vector to a flat block of node
    features; edge probabilities come from pairwise node-feature dot products
    squashed through a sigmoid.
    """

    def __init__(self, latent_dim, num_node_features, num_nodes):
        super(GraphGenerator, self).__init__()
        self.latent_dim = latent_dim
        self.num_nodes = num_nodes
        self.num_node_features = num_node_features
        # Maps (batch, latent_dim) -> (batch, num_nodes * num_node_features).
        self.fc = nn.Linear(latent_dim, num_nodes * num_node_features)

    def forward(self, z):
        """Map a latent batch z to (node_feats, adj_matrix).

        Returns:
            node_feats: (batch, num_nodes, num_node_features)
            adj_matrix: (batch, num_nodes, num_nodes), entries in (0, 1)
        """
        feats = self.fc(z)
        feats = feats.view(-1, self.num_nodes, self.num_node_features)
        # Pairwise dot products act as unnormalized edge scores.
        adjacency = torch.sigmoid(torch.matmul(feats, feats.transpose(1, 2)))
        return feats, adjacency
# ========== Discriminator ==========
class GraphDiscriminator(nn.Module):
    """Score a graph as real/fake: two GCN layers, mean pooling, sigmoid head."""

    def __init__(self, in_channels):
        super(GraphDiscriminator, self).__init__()
        self.conv1 = GCNConv(in_channels, 32)
        self.conv2 = GCNConv(32, 64)
        self.fc = nn.Linear(64, 1)

    def forward(self, x, edge_index, batch):
        """Return one probability in (0, 1) per graph in the batch."""
        hidden = F.relu(self.conv1(x, edge_index))
        hidden = F.relu(self.conv2(hidden, edge_index))
        # Collapse node embeddings to a single per-graph vector.
        pooled = global_mean_pool(hidden, batch)
        return torch.sigmoid(self.fc(pooled))
# ========== GNN Property Predictor (multi-property) ==========
class PropertyPredictor(nn.Module):
    """GNN regressor mapping a graph to a vector of material properties."""

    def __init__(self, in_channels, out_channels=3):  # e.g., 3 properties
        super(PropertyPredictor, self).__init__()
        self.conv1 = GCNConv(in_channels, 64)
        self.conv2 = GCNConv(64, 128)
        self.fc = nn.Linear(128, out_channels)

    def forward(self, x, edge_index, batch):
        """Return a (num_graphs, out_channels) tensor of predicted properties."""
        hidden = F.relu(self.conv1(x, edge_index))
        hidden = F.relu(self.conv2(hidden, edge_index))
        # One embedding per graph, then a linear head over all properties.
        pooled = global_mean_pool(hidden, batch)
        return self.fc(pooled)
# ========== Utility: Convert adjacency to edge_index ==========
def adj_to_edge_index(adj_matrix, threshold=0.5):
    """Convert a dense adjacency matrix to a COO edge_index tensor.

    Entries strictly greater than `threshold` become directed edges;
    self-loops (source == target) are dropped.
    """
    pairs = (adj_matrix > threshold).nonzero(as_tuple=False)
    edge_index = pairs.t()
    keep = edge_index[0] != edge_index[1]
    return edge_index[:, keep]
# ========== Convert SMILES to Graph ==========
from rdkit import Chem
from torch_geometric.data import Data
def smiles_to_graph(smiles):
    """Parse a SMILES string into a torch_geometric ``Data`` graph.

    Node features are single-element vectors ``[atomic_number]``; every bond
    contributes two directed edges (both orientations).

    Returns:
        Data with ``x`` of shape (num_atoms, 1) and ``edge_index`` of shape
        (2, num_edges), or None when RDKit cannot parse the string.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    node_feats = [[atom.GetAtomicNum()] for atom in mol.GetAtoms()]
    edges = []
    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        edges.append([i, j])
        edges.append([j, i])
    if edges:
        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    else:
        # BUGFIX: torch.tensor([]) is 1-D, so .t() yielded a shape-(0,) tensor
        # for bond-less molecules; GNN layers expect edge_index shaped (2, 0).
        edge_index = torch.empty((2, 0), dtype=torch.long)
    x = torch.tensor(node_feats, dtype=torch.float)
    return Data(x=x, edge_index=edge_index)
# ========== Train Property Predictor ==========
def train_property_predictor(model, dataset, targets, epochs=100):
    """Fit the property predictor on (graph, target-vector) pairs with MSE.

    Each graph is processed as its own single-graph batch (all-zero batch
    vector), and one optimizer step is taken per graph. Prints the mean
    loss after every epoch.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    loss_fn = nn.MSELoss()
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for graph, target in zip(dataset, targets):
            optimizer.zero_grad()
            # Every node belongs to graph 0 in this single-graph batch.
            batch_vec = torch.zeros(graph.x.size(0), dtype=torch.long)
            prediction = model(graph.x, graph.edge_index, batch_vec)
            loss = loss_fn(prediction.view(-1), target.view(-1))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch}, Loss: {total_loss/len(dataset):.4f}")
# ========== Generate & Screen by Property ==========
def generate_and_screen(generator, predictor, top_k=5, property_index=0):
    """Sample 100 graphs from the generator, score them, return the best.

    Graphs with no edges above the threshold are skipped. Candidates are
    ranked (descending) by the predicted property at `property_index`.

    Returns:
        List of up to `top_k` tuples (Data, score, all_properties_ndarray).
    """
    generator.eval()
    predictor.eval()
    candidates = []
    # BUGFIX/perf: inference only — no_grad avoids building autograd graphs
    # for 100 generated samples; outputs are unchanged.
    with torch.no_grad():
        z = torch.randn(100, generator.latent_dim)
        node_feats, adj = generator(z)
        for n, a in zip(node_feats, adj):
            edge_index = adj_to_edge_index(a)
            if edge_index.size(1) == 0:
                continue  # skip graphs with no edges
            data = Data(x=n, edge_index=edge_index)
            batch_vec = torch.zeros(data.x.size(0), dtype=torch.long)
            score_vec = predictor(data.x, data.edge_index, batch_vec)
            # Rank by one property but keep the full vector for display.
            score = score_vec[0][property_index].item()
            candidates.append((data, score, score_vec.detach().numpy()))
    return sorted(candidates, key=lambda c: -c[1])[:top_k]
# ========== Entry Point ==========
if __name__ == '__main__':
    sample_smiles = ["[Fe]CO", "CC[Si]", "CCCCH", "COC", "CCCl"]
    # Multi-property targets: [conductivity, porosity, surface_area]
    targets = torch.tensor([
        [0.3, 0.3, 0.7],
        [0.6, 0.4, 0.6],
        [0.5, 0.5, 0.5],
        [0.9, 0.2, 0.8],
        [0.7, 0.6, 0.7]
    ])
    # BUGFIX: parse each SMILES once and drop the MATCHING target row when
    # RDKit rejects a string. The original parsed twice and kept every target,
    # so one bad SMILES silently mispaired graphs with labels during training.
    graph_data, kept_targets = [], []
    for smiles, target in zip(sample_smiles, targets):
        graph = smiles_to_graph(smiles)
        if graph is not None:
            graph_data.append(graph)
            kept_targets.append(target)
    kept_targets = torch.stack(kept_targets) if kept_targets else targets[:0]
    generator = GraphGenerator(latent_dim=16, num_node_features=1, num_nodes=5)
    discriminator = GraphDiscriminator(in_channels=1)
    predictor = PropertyPredictor(in_channels=1, out_channels=3)
    # Train the predictor on known molecules
    train_property_predictor(predictor, graph_data, kept_targets, epochs=100)
    # Generate and screen candidates based on conductivity (index 0)
    top_candidates = generate_and_screen(generator, predictor, top_k=5, property_index=0)
    # Display results
    for i, (graph, score, all_props) in enumerate(top_candidates):
        print(f"\nCandidate {i+1}:")
        print(f"  Score (Conductivity): {score:.3f}")
        print(f"  All Properties: {all_props}")
        print(f"  Nodes: {graph.x.size(0)}, Edges: {graph.edge_index.size(1)}")