# GraphGAN + GNN Predictor with Multi-Property Prediction
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool
from rdkit import Chem
import random
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import train_test_split


# ========== Generator ==========
class GraphGenerator(nn.Module):
    """Map a latent vector to a dense graph: node features + soft adjacency.

    The adjacency is sigmoid(X @ X^T), i.e. a symmetric matrix of edge
    probabilities in [0, 1]; thresholding is left to `adj_to_edge_index`.
    """

    def __init__(self, latent_dim, num_node_features, num_nodes):
        super(GraphGenerator, self).__init__()
        self.latent_dim = latent_dim
        self.num_nodes = num_nodes
        self.num_node_features = num_node_features
        self.fc = nn.Linear(latent_dim, num_nodes * num_node_features)

    def forward(self, z):
        # z: (batch, latent_dim) -> node_feats: (batch, num_nodes, num_node_features)
        node_feats = self.fc(z).view(-1, self.num_nodes, self.num_node_features)
        # Inner products between node embeddings give a symmetric soft adjacency.
        adj_matrix = torch.sigmoid(torch.matmul(node_feats, node_feats.transpose(1, 2)))
        return node_feats, adj_matrix


# ========== Discriminator ==========
class GraphDiscriminator(nn.Module):
    """Two-layer GCN that scores a graph as real/fake in (0, 1)."""

    def __init__(self, in_channels):
        super(GraphDiscriminator, self).__init__()
        self.conv1 = GCNConv(in_channels, 32)
        self.conv2 = GCNConv(32, 64)
        self.fc = nn.Linear(64, 1)

    def forward(self, x, edge_index, batch):
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        # Pool node embeddings into one vector per graph in the batch.
        x = global_mean_pool(x, batch)
        return torch.sigmoid(self.fc(x))


# ========== GNN Property Predictor (multi-property) ==========
class PropertyPredictor(nn.Module):
    """Two-layer GCN regressor that outputs `out_channels` properties per graph."""

    def __init__(self, in_channels, out_channels=3):  # e.g., 3 properties
        super(PropertyPredictor, self).__init__()
        self.conv1 = GCNConv(in_channels, 64)
        self.conv2 = GCNConv(64, 128)
        self.fc = nn.Linear(128, out_channels)

    def forward(self, x, edge_index, batch):
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = global_mean_pool(x, batch)
        return self.fc(x)  # Vector of predicted properties


# ========== Utility: Convert adjacency to edge_index ==========
def adj_to_edge_index(adj_matrix, threshold=0.5):
    """Threshold a soft adjacency matrix into a (2, E) edge_index tensor.

    Self-loops are removed. Because the generator's adjacency is symmetric,
    both directions of each edge survive the threshold.
    """
    edge_index = (adj_matrix > threshold).nonzero(as_tuple=False).t()
    edge_index = edge_index[:, edge_index[0] != edge_index[1]]  # drop self-loops
    return edge_index


# ========== Convert SMILES to Graph ==========
def smiles_to_graph(smiles):
    """Parse a SMILES string into a PyG `Data` object, or None if invalid.

    Node features are the atomic numbers (shape (N, 1)); each bond
    contributes both directed edges.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    node_feats = [[atom.GetAtomicNum()] for atom in mol.GetAtoms()]
    edge_index = []
    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        edge_index.append([i, j])
        edge_index.append([j, i])
    if edge_index:
        edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
    else:
        # BUG FIX: a bond-less molecule previously produced a malformed 1-D
        # tensor from torch.tensor([]).t(); emit a well-formed empty (2, 0)
        # edge_index instead.
        edge_index = torch.empty((2, 0), dtype=torch.long)
    x = torch.tensor(node_feats, dtype=torch.float)
    return Data(x=x, edge_index=edge_index)


# ========== Train Property Predictor ==========
def train_property_predictor(model, dataset, targets, epochs=100):
    """Fit `model` with per-graph MSE against rows of `targets`.

    `dataset` and `targets` must be aligned 1:1 (graph i <-> target row i).
    Raises ValueError on an empty dataset instead of dividing by zero.
    """
    if len(dataset) == 0:
        raise ValueError("train_property_predictor: empty dataset")
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    loss_fn = nn.MSELoss()
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for data, y in zip(dataset, targets):
            optimizer.zero_grad()
            # Single-graph batch: every node belongs to graph 0.
            pred = model(data.x, data.edge_index,
                         torch.zeros(data.x.size(0), dtype=torch.long))
            loss = loss_fn(pred.view(-1), y.view(-1))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch}, Loss: {total_loss/len(dataset):.4f}")


# ========== Generate & Screen by Property ==========
def generate_and_screen(generator, predictor, top_k=5, property_index=0):
    """Sample 100 graphs from `generator`, score with `predictor`, keep best.

    Returns a list of up to `top_k` tuples
    (Data, score_for_property_index, all_property_values_as_numpy),
    sorted by the selected property, descending.
    """
    generator.eval()
    predictor.eval()
    candidates = []
    # Inference only — skip autograd bookkeeping.
    with torch.no_grad():
        z = torch.randn(100, generator.latent_dim)
        node_feats, adj = generator(z)
        for n, a in zip(node_feats, adj):
            edge_index = adj_to_edge_index(a)
            if edge_index.size(1) == 0:
                continue  # skip graphs with no edges
            data = Data(x=n, edge_index=edge_index)
            score_vec = predictor(data.x, data.edge_index,
                                  torch.zeros(data.x.size(0), dtype=torch.long))
            # Extract the single requested property as the ranking score;
            # score_vec is (1, out_channels).
            score = score_vec[0][property_index].item()
            candidates.append((data, score, score_vec.detach().numpy()))
    candidates = sorted(candidates, key=lambda x: -x[1])[:top_k]
    return candidates


# ========== Entry Point ==========
if __name__ == '__main__':
    sample_smiles = ["[Fe]CO", "CC[Si]", "CCCCH", "COC", "CCCl"]
    # Multi-property targets: [conductivity, porosity, surface_area]
    raw_targets = [
        [0.3, 0.3, 0.7],
        [0.6, 0.4, 0.6],
        [0.5, 0.5, 0.5],
        [0.9, 0.2, 0.8],
        [0.7, 0.6, 0.7],
    ]
    # BUG FIX: filter graphs AND their targets together. The original filtered
    # only the graphs, so one invalid SMILES (e.g. "CCCCH") silently shifted
    # every following graph onto the wrong target row. Also parse each SMILES
    # once instead of twice.
    pairs = [(smiles_to_graph(s), t) for s, t in zip(sample_smiles, raw_targets)]
    pairs = [(g, t) for g, t in pairs if g is not None]
    graph_data = [g for g, _ in pairs]
    targets = torch.tensor([t for _, t in pairs])

    generator = GraphGenerator(latent_dim=16, num_node_features=1, num_nodes=5)
    discriminator = GraphDiscriminator(in_channels=1)
    predictor = PropertyPredictor(in_channels=1, out_channels=3)

    # Train the predictor on known molecules
    train_property_predictor(predictor, graph_data, targets, epochs=100)

    # Generate and screen candidates based on conductivity (index 0)
    top_candidates = generate_and_screen(generator, predictor, top_k=5,
                                         property_index=0)

    # Display results
    for i, (graph, score, all_props) in enumerate(top_candidates):
        print(f"\nCandidate {i+1}:")
        print(f"  Score (Conductivity): {score:.3f}")
        print(f"  All Properties: {all_props}")
        print(f"  Nodes: {graph.x.size(0)}, Edges: {graph.edge_index.size(1)}")