File size: 5,950 Bytes
9497911
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151




# GraphGAN + GNN Predictor with Multi-Property Prediction
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool
from rdkit import Chem
import random
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import train_test_split

# ========== Generator ==========
class GraphGenerator(nn.Module):
    """Decode a latent vector into dense node features plus a soft adjacency matrix.

    The adjacency is the sigmoid of pairwise node-feature inner products, so
    it is symmetric with entries in (0, 1).
    """

    def __init__(self, latent_dim, num_node_features, num_nodes):
        super().__init__()
        self.latent_dim = latent_dim
        self.num_nodes = num_nodes
        self.num_node_features = num_node_features
        # One linear projection produces every node's features at once.
        self.fc = nn.Linear(latent_dim, num_nodes * num_node_features)

    def forward(self, z):
        # (batch, num_nodes * F) -> (batch, num_nodes, F)
        feats = self.fc(z).reshape(-1, self.num_nodes, self.num_node_features)
        # Pairwise similarities squashed into (0, 1) act as edge probabilities.
        similarity = feats @ feats.transpose(1, 2)
        return feats, torch.sigmoid(similarity)

# ========== Discriminator ==========
class GraphDiscriminator(nn.Module):
    """Score a graph as real/fake: two GCN layers, mean pooling, sigmoid head."""

    def __init__(self, in_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, 32)
        self.conv2 = GCNConv(32, 64)
        self.fc = nn.Linear(64, 1)

    def forward(self, x, edge_index, batch):
        h = F.relu(self.conv1(x, edge_index))
        h = F.relu(self.conv2(h, edge_index))
        # Collapse node embeddings into one vector per graph in the batch.
        pooled = global_mean_pool(h, batch)
        return torch.sigmoid(self.fc(pooled))

# ========== GNN Property Predictor (multi-property) ==========
class PropertyPredictor(nn.Module):
    """Regress a vector of material properties (default 3) from a molecular graph."""

    def __init__(self, in_channels, out_channels=3):  # e.g., 3 properties
        super().__init__()
        self.conv1 = GCNConv(in_channels, 64)
        self.conv2 = GCNConv(64, 128)
        self.fc = nn.Linear(128, out_channels)

    def forward(self, x, edge_index, batch):
        h = F.relu(self.conv1(x, edge_index))
        h = F.relu(self.conv2(h, edge_index))
        # One pooled embedding per graph -> one property vector per graph.
        pooled = global_mean_pool(h, batch)
        return self.fc(pooled)

# ========== Utility: Convert adjacency to edge_index ==========
def adj_to_edge_index(adj_matrix, threshold=0.5):
    """Binarize a soft adjacency matrix into a COO ``edge_index`` (2, E) tensor.

    Entries strictly above ``threshold`` become directed edges; self-loops
    (diagonal entries) are dropped.
    """
    mask = adj_matrix > threshold
    pairs = mask.nonzero(as_tuple=False).t()
    src, dst = pairs[0], pairs[1]
    return pairs[:, src != dst]

# ========== Convert SMILES to Graph ==========
from rdkit import Chem
from torch_geometric.data import Data

def smiles_to_graph(smiles):
    """Convert a SMILES string into a PyG ``Data`` graph.

    Node features are a single column holding the atomic number; every bond
    contributes both directed edges. Returns ``None`` when RDKit cannot
    parse the SMILES.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    x = torch.tensor([[atom.GetAtomicNum()] for atom in mol.GetAtoms()],
                     dtype=torch.float)
    edges = []
    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        edges.append([i, j])
        edges.append([j, i])
    if edges:
        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    else:
        # Bond-free molecules (e.g. a single atom) need an explicit (2, 0)
        # edge_index; torch.tensor([]).t() yields a 1-D tensor that breaks
        # downstream GCN layers.
        edge_index = torch.empty((2, 0), dtype=torch.long)
    return Data(x=x, edge_index=edge_index)

# ========== Train Property Predictor ==========
def train_property_predictor(model, dataset, targets, epochs=100):
    """Fit ``model`` on (graph, property-vector) pairs with per-sample MSE steps.

    ``dataset`` and ``targets`` are iterated in lockstep; each graph is its
    own single-graph batch. Prints the mean per-sample loss every epoch.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    loss_fn = nn.MSELoss()
    for epoch in range(epochs):
        model.train()
        running = 0.0
        for graph, target in zip(dataset, targets):
            optimizer.zero_grad()
            # Single-graph batch: every node maps to batch index 0.
            batch = torch.zeros(graph.x.size(0), dtype=torch.long)
            pred = model(graph.x, graph.edge_index, batch)
            loss = loss_fn(pred.view(-1), target.view(-1))
            loss.backward()
            optimizer.step()
            running += loss.item()
        print(f"Epoch {epoch}, Loss: {running/len(dataset):.4f}")

# ========== Generate & Screen by Property ==========
def generate_and_screen(generator, predictor, top_k=5, property_index=0):
    """Sample 100 latent vectors, decode candidate graphs, and return the
    ``top_k`` candidates ranked (descending) by the predicted property at
    ``property_index``.

    Returns a list of ``(Data, score, all_property_values)`` tuples, best first.
    Edgeless decoded graphs are skipped.
    """
    generator.eval()
    predictor.eval()
    candidates = []
    # eval() alone does not disable autograd; no_grad() avoids building a
    # useless computation graph while screening 100 samples.
    with torch.no_grad():
        z = torch.randn(100, generator.latent_dim)
        node_feats, adj = generator(z)
        for n, a in zip(node_feats, adj):
            edge_index = adj_to_edge_index(a)
            if edge_index.size(1) == 0:
                continue  # skip graphs with no edges
            data = Data(x=n, edge_index=edge_index)
            batch = torch.zeros(n.size(0), dtype=torch.long)
            score_vec = predictor(data.x, data.edge_index, batch)
            # score_vec is (1, num_properties); rank by one chosen property.
            score = score_vec[0][property_index].item()
            candidates.append((data, score, score_vec.detach().numpy()))
    candidates.sort(key=lambda c: c[1], reverse=True)
    return candidates[:top_k]

# ========== Entry Point ==========
if __name__ == '__main__':
    sample_smiles = ["[Fe]CO", "CC[Si]", "CCCCH", "COC", "CCCl"]

    # Multi-property targets: [conductivity, porosity, surface_area]
    all_targets = torch.tensor([
        [0.3, 0.3, 0.7],
        [0.6, 0.4, 0.6],
        [0.5, 0.5, 0.5],
        [0.9, 0.2, 0.8],
        [0.7, 0.6, 0.7]
    ])

    # Parse each SMILES once, and drop unparsable ones TOGETHER WITH their
    # target row so graphs and targets stay aligned. (Previously only the
    # graphs were filtered, silently pairing survivors with wrong targets,
    # and smiles_to_graph was called twice per SMILES.)
    graph_data, kept_targets = [], []
    for smiles, target in zip(sample_smiles, all_targets):
        graph = smiles_to_graph(smiles)
        if graph is not None:
            graph_data.append(graph)
            kept_targets.append(target)
    targets = torch.stack(kept_targets)

    generator = GraphGenerator(latent_dim=16, num_node_features=1, num_nodes=5)
    discriminator = GraphDiscriminator(in_channels=1)
    predictor = PropertyPredictor(in_channels=1, out_channels=3)

    # Train the predictor on known molecules
    train_property_predictor(predictor, graph_data, targets, epochs=100)

    # Generate and screen candidates based on conductivity (index 0)
    top_candidates = generate_and_screen(generator, predictor, top_k=5, property_index=0)

    # Display results
    for i, (graph, score, all_props) in enumerate(top_candidates):
        print(f"\nCandidate {i+1}:")
        print(f"  Score (Conductivity): {score:.3f}")
        print(f"  All Properties: {all_props}")
        print(f"  Nodes: {graph.x.size(0)}, Edges: {graph.edge_index.size(1)}")