Spaces:
Sleeping
Sleeping
| # GraphGAN + GNN Predictor with Multi-Property Prediction | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch_geometric.data import Data, DataLoader | |
| from torch_geometric.nn import GCNConv, global_mean_pool | |
| from rdkit import Chem | |
| import random | |
| import matplotlib.pyplot as plt | |
| import pandas as pd | |
| from sklearn.model_selection import train_test_split | |
| # ========== Generator ========== | |
# ========== Generator ==========
class GraphGenerator(nn.Module):
    """Maps a latent vector to node features plus a dense soft adjacency matrix."""

    def __init__(self, latent_dim, num_node_features, num_nodes):
        super(GraphGenerator, self).__init__()
        self.latent_dim = latent_dim
        self.num_nodes = num_nodes
        self.num_node_features = num_node_features
        # One linear projection emits every node's feature vector at once.
        self.fc = nn.Linear(latent_dim, num_nodes * num_node_features)

    def forward(self, z):
        """Return (node_feats [B, N, F], adjacency [B, N, N]) for latent batch z [B, latent_dim]."""
        flat = self.fc(z)
        node_feats = flat.view(-1, self.num_nodes, self.num_node_features)
        # Edge probability is the sigmoid of the endpoint-feature inner product.
        similarity = torch.matmul(node_feats, node_feats.transpose(1, 2))
        adj_matrix = torch.sigmoid(similarity)
        return node_feats, adj_matrix
| # ========== Discriminator ========== | |
# ========== Discriminator ==========
class GraphDiscriminator(nn.Module):
    """Real/fake graph classifier: two GCN layers, mean pooling, sigmoid head."""

    def __init__(self, in_channels):
        super(GraphDiscriminator, self).__init__()
        self.conv1 = GCNConv(in_channels, 32)
        self.conv2 = GCNConv(32, 64)
        self.fc = nn.Linear(64, 1)

    def forward(self, x, edge_index, batch):
        """Return a [num_graphs, 1] probability that each graph in the batch is real."""
        hidden = F.relu(self.conv1(x, edge_index))
        hidden = F.relu(self.conv2(hidden, edge_index))
        # Collapse node embeddings to one vector per graph before scoring.
        pooled = global_mean_pool(hidden, batch)
        return torch.sigmoid(self.fc(pooled))
| # ========== GNN Property Predictor (multi-property) ========== | |
# ========== GNN Property Predictor (multi-property) ==========
class PropertyPredictor(nn.Module):
    """Regresses a vector of graph-level properties (default: 3) from node features."""

    def __init__(self, in_channels, out_channels=3):  # e.g., 3 properties
        super(PropertyPredictor, self).__init__()
        self.conv1 = GCNConv(in_channels, 64)
        self.conv2 = GCNConv(64, 128)
        self.fc = nn.Linear(128, out_channels)

    def forward(self, x, edge_index, batch):
        """Return a [num_graphs, out_channels] tensor of predicted property values."""
        hidden = F.relu(self.conv1(x, edge_index))
        hidden = F.relu(self.conv2(hidden, edge_index))
        # One embedding per graph; the linear head maps it to the property vector.
        pooled = global_mean_pool(hidden, batch)
        return self.fc(pooled)
| # ========== Utility: Convert adjacency to edge_index ========== | |
# ========== Utility: Convert adjacency to edge_index ==========
def adj_to_edge_index(adj_matrix, threshold=0.5):
    """Convert a dense [N, N] soft adjacency matrix into a COO edge_index.

    An edge (i, j) is kept when adj_matrix[i, j] > threshold; self-loops
    are discarded. Returns a [2, E] long tensor.
    """
    mask = adj_matrix > threshold
    edges = mask.nonzero(as_tuple=False).t()
    # Drop self-loops: keep only columns whose source differs from target.
    not_loop = edges[0] != edges[1]
    return edges[:, not_loop]
| # ========== Convert SMILES to Graph ========== | |
| from rdkit import Chem | |
| from torch_geometric.data import Data | |
def smiles_to_graph(smiles):
    """Parse a SMILES string into a torch_geometric ``Data`` graph.

    Node features are a single column holding each atom's atomic number.
    Every bond is added in both directions (undirected graph as a symmetric
    directed edge list). Returns None when RDKit cannot parse the SMILES.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    node_feats = [[atom.GetAtomicNum()] for atom in mol.GetAtoms()]
    edge_index = []
    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        edge_index.append([i, j])
        edge_index.append([j, i])
    if edge_index:
        edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
    else:
        # BUG FIX: a bond-free molecule (single atom) made torch.tensor([])
        # produce a 1-D tensor of shape [0]; GNN layers require shape [2, 0].
        edge_index = torch.empty((2, 0), dtype=torch.long)
    x = torch.tensor(node_feats, dtype=torch.float)
    return Data(x=x, edge_index=edge_index)
| # ========== Train Property Predictor ========== | |
# ========== Train Property Predictor ==========
def train_property_predictor(model, dataset, targets, epochs=100):
    """Train a property-prediction model with per-graph MSE regression.

    Args:
        model: module whose forward signature is (x, edge_index, batch) -> prediction.
        dataset: sequence of graph objects exposing ``.x`` and ``.edge_index``.
        targets: sequence of target tensors, one per graph, same length as dataset.
        epochs: number of full passes over the dataset.

    Raises:
        ValueError: if ``dataset`` is empty or its length differs from ``targets``.
    """
    if len(dataset) == 0:
        # Previously crashed with ZeroDivisionError at the epoch-loss print.
        raise ValueError("dataset must contain at least one graph")
    if len(dataset) != len(targets):
        # zip() would silently drop the unmatched tail, mis-pairing labels.
        raise ValueError("dataset and targets must have the same length")
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    loss_fn = nn.MSELoss()
    for epoch in range(epochs):
        model.train()
        total_loss = 0.0
        for data, y in zip(dataset, targets):
            optimizer.zero_grad()
            # Every node of a single graph belongs to batch index 0.
            batch = torch.zeros(data.x.size(0), dtype=torch.long)
            pred = model(data.x, data.edge_index, batch)
            loss = loss_fn(pred.view(-1), y.view(-1))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch}, Loss: {total_loss/len(dataset):.4f}")
| # ========== Generate & Screen by Property ========== | |
# ========== Generate & Screen by Property ==========
def generate_and_screen(generator, predictor, top_k=5, property_index=0):
    """Sample 100 graphs from the generator and return the top_k by one property.

    Args:
        generator: GraphGenerator-like module exposing ``.latent_dim``.
        predictor: model with forward (x, edge_index, batch) -> [1, num_properties].
        top_k: number of best candidates to return.
        property_index: which predicted property to rank by.

    Returns:
        List of (Data, score, all_property_values) tuples sorted by descending score.
    """
    generator.eval()
    predictor.eval()
    candidates = []
    # FIX: inference only — no_grad avoids building autograd graphs for all
    # 100 samples, cutting memory and compute during screening.
    with torch.no_grad():
        z = torch.randn(100, generator.latent_dim)
        node_feats, adj = generator(z)
        for n, a in zip(node_feats, adj):
            edge_index = adj_to_edge_index(a)
            if edge_index.size(1) == 0:
                continue  # skip graphs with no edges
            data = Data(x=n, edge_index=edge_index)
            batch = torch.zeros(data.x.size(0), dtype=torch.long)
            score_vec = predictor(data.x, data.edge_index, batch)
            # score_vec is [1, num_properties]; rank by the requested property.
            score = score_vec[0][property_index].item()
            candidates.append((data, score, score_vec.numpy()))
    candidates.sort(key=lambda c: c[1], reverse=True)
    return candidates[:top_k]
| # ========== Entry Point ========== | |
# ========== Entry Point ==========
if __name__ == '__main__':
    sample_smiles = ["[Fe]CO", "CC[Si]", "CCCCH", "COC", "CCCl"]
    # Multi-property targets: [conductivity, porosity, surface_area]
    raw_targets = torch.tensor([
        [0.3, 0.3, 0.7],
        [0.6, 0.4, 0.6],
        [0.5, 0.5, 0.5],
        [0.9, 0.2, 0.8],
        [0.7, 0.6, 0.7]
    ])
    # FIX: parse each SMILES once (the original parsed every string twice),
    # and drop the matching target row whenever a SMILES fails to parse —
    # otherwise zip() in training silently mislabels every later molecule.
    graph_data = []
    kept_targets = []
    for smiles, target in zip(sample_smiles, raw_targets):
        graph = smiles_to_graph(smiles)
        if graph is not None:
            graph_data.append(graph)
            kept_targets.append(target)
    targets = torch.stack(kept_targets) if kept_targets else raw_targets[:0]
    generator = GraphGenerator(latent_dim=16, num_node_features=1, num_nodes=5)
    discriminator = GraphDiscriminator(in_channels=1)
    predictor = PropertyPredictor(in_channels=1, out_channels=3)
    # Train the predictor on known molecules
    train_property_predictor(predictor, graph_data, targets, epochs=100)
    # Generate and screen candidates based on conductivity (index 0)
    top_candidates = generate_and_screen(generator, predictor, top_k=5, property_index=0)
    # Display results
    for i, (graph, score, all_props) in enumerate(top_candidates):
        print(f"\nCandidate {i+1}:")
        print(f" Score (Conductivity): {score:.3f}")
        print(f" All Properties: {all_props}")
        print(f" Nodes: {graph.x.size(0)}, Edges: {graph.edge_index.size(1)}")