# RT-Transformer / app.py
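# Gradio Space that predicts chromatographic retention time (RT) for the
# molecules in an uploaded SDF file, using a graph-transformer model over
# RDKit atom/bond features fused with a Morgan-fingerprint branch.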
import hashlib
import os

import gradio as gr
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch import nn
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GATConv, global_add_pool, global_max_pool, global_mean_pool

from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem.rdchem import BondType as BT
from sklearn.preprocessing import OneHotEncoder
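
# Categorical vocabularies used to index and one-hot encode atom/bond features.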
CHIRALITY_LIST = [
Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
Chem.rdchem.ChiralType.CHI_OTHER
]
BOND_LIST = [
BT.SINGLE,
BT.DOUBLE,
BT.TRIPLE,
BT.AROMATIC
]
BONDDIR_LIST = [
Chem.rdchem.BondDir.NONE,
Chem.rdchem.BondDir.ENDUPRIGHT,
Chem.rdchem.BondDir.ENDDOWNRIGHT
]
hybridization_list = ['OTHER', 'S', 'SP', 'SP2', 'SP3', 'SP3D', 'SP3D2', 'UNSPECIFIED']
atom_list = ['H', 'C', 'O', 'S', 'N', 'P', 'F', 'Cl', 'Br', 'I', 'Si']

# Fit one-hot encoders on the integer index range of each vocabulary.
hybridization_encoder = OneHotEncoder()
hybridization_encoder.fit(np.arange(len(hybridization_list)).reshape(-1, 1))

atom_encoder = OneHotEncoder()
atom_encoder.fit(np.arange(len(atom_list)).reshape(-1, 1))

chirality_encoder = OneHotEncoder()
chirality_encoder.fit(np.arange(len(CHIRALITY_LIST)).reshape(-1, 1))
def get_data_list(mol_list):
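    """Convert RDKit molecules into PyTorch Geometric `Data` objects.

    Each atom gets a 34-dim feature vector (one-hot element/chirality/
    hybridization plus scalar descriptors), each bond two directed edges with
    5 features, and each molecule a 2048-bit Morgan fingerprint (radius 2).
    """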
data_list = []
for mol in mol_list:
mol = Chem.AddHs(mol)
weights = []
type_idx = []
chirality_idx = []
atomic_number = []
degrees = []
total_degrees = []
formal_charges = []
hybridization_types = []
explicit_valences = []
implicit_valences = []
total_valences = []
atom_map_nums = []
isotopes = []
radical_electrons = []
        inrings = []
        for atom in mol.GetAtoms():
type_idx.append(atom_list.index(atom.GetSymbol()))
chirality_idx.append(CHIRALITY_LIST.index(atom.GetChiralTag()))
atomic_number.append(atom.GetAtomicNum())
degrees.append(atom.GetDegree())
weights.append(atom.GetMass())
total_degrees.append(atom.GetTotalDegree())
formal_charges.append(atom.GetFormalCharge())
hybridization_types.append(hybridization_list.index(str(atom.GetHybridization())))
explicit_valences.append(atom.GetExplicitValence())
implicit_valences.append(atom.GetImplicitValence())
total_valences.append(atom.GetTotalValence())
atom_map_nums.append(atom.GetAtomMapNum())
isotopes.append(atom.GetIsotope())
radical_electrons.append(atom.GetNumRadicalElectrons())
inrings.append(int(atom.IsInRing()))
x1 = torch.tensor(type_idx, dtype=torch.float32).view(-1, 1)
x2 = torch.tensor(chirality_idx, dtype=torch.float32).view(-1, 1)
x3 = torch.tensor(weights, dtype=torch.float32).view(-1, 1)
x4 = torch.tensor(degrees, dtype=torch.float32).view(-1, 1)
x5 = torch.tensor(total_degrees, dtype=torch.float32).view(-1, 1)
x6 = torch.tensor(formal_charges, dtype=torch.float32).view(-1, 1)
x7 = torch.tensor(hybridization_types, dtype=torch.float32).view(-1, 1)
x8 = torch.tensor(explicit_valences, dtype=torch.float32).view(-1, 1)
x9 = torch.tensor(implicit_valences, dtype=torch.float32).view(-1, 1)
x10 = torch.tensor(total_valences, dtype=torch.float32).view(-1, 1)
x11 = torch.tensor(atom_map_nums, dtype=torch.float32).view(-1, 1)
x12 = torch.tensor(isotopes, dtype=torch.float32).view(-1, 1)
x13 = torch.tensor(radical_electrons, dtype=torch.float32).view(-1, 1)
x14 = torch.tensor(inrings, dtype=torch.float32).view(-1, 1)
        # Concatenate one-hot blocks (element: 11, chirality: 4, hybridization: 8)
        # with the 11 scalar descriptor columns -> 34 features per atom.
x = torch.cat([torch.tensor(atom_encoder.transform(x1).toarray(), dtype=torch.float32),
                       torch.tensor(chirality_encoder.transform(x2).toarray(), dtype=torch.float32),
x3,
x4,
x5,
x6,
torch.tensor(hybridization_encoder.transform(x7).toarray(), dtype=torch.float32),
x8,
x9,
x10,
x11,
x12,
x13,
x14, ], dim=-1)
row, col, edge_feat = [], [], []
for bond in mol.GetBonds():
start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
row += [start, end]
col += [end, start]
            # Each bond contributes two directed edges sharing one feature vector.
            feat = [
                BOND_LIST.index(bond.GetBondType()),
                BONDDIR_LIST.index(bond.GetBondDir()),
                float(int(bond.IsInRing())),
                float(int(bond.GetIsAromatic())),
                float(int(bond.GetIsConjugated())),
            ]
            edge_feat.append(feat)
            edge_feat.append(feat)
edge_index = torch.tensor([row, col], dtype=torch.long)
edge_attr = torch.tensor(np.array(edge_feat), dtype=torch.float32)
fingerprint = torch.tensor(AllChem.GetMorganFingerprintAsBitVect(mol, 2), dtype=torch.float32)
data = Data(x=x,
edge_index=edge_index,
edge_attr=edge_attr,
fingerprint=fingerprint,)
data_list.append(data)
return data_list
class GraphTransformerBlock(nn.Module):
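    """Transformer-style graph block: multi-head GATConv attention with a
    linear projection, residual connection, LayerNorm, and dropout."""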
def __init__(self, in_channels, out_channels, heads=3, edge_dim=5, dropout=0, **kwargs):
super(GraphTransformerBlock, self).__init__(**kwargs)
self.edge_dim = edge_dim
self.in_channels = in_channels
self.out_channels = out_channels
self.conv = GATConv(in_channels, out_channels, heads=heads, edge_dim=edge_dim)
self.linear = nn.Linear(heads * out_channels, out_channels)
self.layerNorm = nn.LayerNorm(out_channels)
self.dropout = dropout
def forward(self, x, edge_index, edge_attr):
x_gat = self.conv(x=x, edge_index=edge_index, edge_attr=edge_attr)
x_gat = self.linear(x_gat)
x_gat = self.layerNorm(x + x_gat)
return F.dropout(x_gat, self.dropout, training=self.training)
class GraphTransformerBlock2(nn.Module):
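    """Like GraphTransformerBlock, but with a second position-wise linear
    sub-layer and residual/LayerNorm, closer to a full transformer encoder."""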
def __init__(self, in_channels, out_channels, heads=3, edge_dim=5, dropout=0, **kwargs):
super(GraphTransformerBlock2, self).__init__(**kwargs)
self.edge_dim = edge_dim
self.in_channels = in_channels
self.out_channels = out_channels
self.conv = GATConv(in_channels, out_channels, heads=heads, edge_dim=edge_dim)
self.linear1 = nn.Linear(heads * out_channels, out_channels)
self.layerNorm1 = nn.LayerNorm(out_channels)
self.linear2 = nn.Linear(out_channels, out_channels)
self.layerNorm2 = nn.LayerNorm(out_channels)
self.dropout = dropout
def forward(self, x, edge_index, edge_attr):
x_gat = self.conv(x=x, edge_index=edge_index, edge_attr=edge_attr)
x_gat = self.linear1(x_gat)
x_gat = self.layerNorm1(x + x_gat)
linear_ = self.linear2(x_gat)
linear_ = self.layerNorm2(linear_ + x_gat)
return F.dropout(linear_, self.dropout, training=self.training)
class Trainer(object):
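    """Minimal training-loop wrapper (not exercised by the Gradio app)."""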
def __init__(self, model, lr, device):
self.model = model
        from torch import optim
        self.optimizer = optim.AdamW(self.model.parameters(), lr=lr)
        # Plateau scheduler; step with self.scheduler.step(val_loss) after validation.
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='min', factor=0.1, patience=10,
            threshold=0.0001, threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-08)
self.device = device
    def train(self, data_loader):
        """One epoch of L1-loss training; returns the mean batch loss."""
        self.model.train()
        criterion = torch.nn.L1Loss()
        losses = []
        for data in data_loader:
            data = data.to(self.device)
            y_hat = self.model(data)
            loss = criterion(y_hat, data.y)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            losses.append(loss.item())
        return float(np.mean(losses)) if losses else 0.0
class Tester(object):
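    """Evaluation helper reporting mean absolute error (MAE) over a loader."""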
def __init__(self, model, device):
self.model = model
self.device = device
    def test_regressor(self, data_loader):
        self.model.eval()
        y_true = []
        y_pred = []
        with torch.no_grad():
            for data in data_loader:
                data = data.to(self.device, non_blocking=True)
                y_hat = self.model(data)
                y_true.append(data.y)
                y_pred.append(y_hat)
        y_true = torch.concat(y_true)
        y_pred = torch.concat(y_pred)
        mae = torch.abs(y_true - y_pred).mean()
        return mae.item()
class MyNet(nn.Module):
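    """RT-Transformer model: 9 stacked graph-transformer blocks over the
    molecular graph plus 12 attention blocks over the 2048-bit Morgan
    fingerprint; the pooled graph embedding and the fingerprint embedding
    are concatenated and regressed to a single retention-time value."""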
def __init__(self, emb_dim=512, feat_dim=256, edge_dim=5, heads=3, drop_ratio=0, pool='add'):
super(MyNet, self).__init__()
self.emb_dim = emb_dim
self.feat_dim = feat_dim
self.drop_ratio = drop_ratio
self.in_linear = nn.Linear(34, emb_dim)
self.conv1 = GraphTransformerBlock(emb_dim, emb_dim, heads=heads, edge_dim=edge_dim)
self.conv2 = GraphTransformerBlock(emb_dim, emb_dim, heads=heads, edge_dim=edge_dim)
self.conv3 = GraphTransformerBlock(emb_dim, emb_dim, heads=heads, edge_dim=edge_dim)
self.conv4 = GraphTransformerBlock(emb_dim, emb_dim, heads=heads, edge_dim=edge_dim)
self.conv5 = GraphTransformerBlock(emb_dim, emb_dim, heads=heads, edge_dim=edge_dim)
self.conv6 = GraphTransformerBlock(emb_dim, emb_dim, heads=heads, edge_dim=edge_dim)
self.conv7 = GraphTransformerBlock(emb_dim, emb_dim, heads=heads, edge_dim=edge_dim)
self.conv8 = GraphTransformerBlock(emb_dim, emb_dim, heads=heads, edge_dim=edge_dim)
self.conv9 = GraphTransformerBlock(emb_dim, emb_dim, heads=heads, edge_dim=edge_dim)
        if pool == 'mean':
            self.pool = global_mean_pool
        elif pool == 'max':
            self.pool = global_max_pool
        elif pool == 'add':
            self.pool = global_add_pool
        else:
            raise ValueError(f"unsupported pool: {pool!r}")
self.feat_lin = nn.Linear(self.emb_dim, self.feat_dim)
self.out_lin = nn.Sequential(
nn.Linear(self.feat_dim, self.feat_dim // 8),
nn.ReLU(inplace=True),
nn.Linear(self.feat_dim // 8, self.feat_dim // 64),
nn.ReLU(inplace=True),
nn.Linear(self.feat_dim // 64, 1),
)
self.conv1d1 = OneDimConvBlock()
self.conv1d2 = OneDimConvBlock()
self.conv1d3 = OneDimConvBlock()
self.conv1d4 = OneDimConvBlock()
self.conv1d5 = OneDimConvBlock()
self.conv1d6 = OneDimConvBlock()
self.conv1d7 = OneDimConvBlock()
self.conv1d8 = OneDimConvBlock()
self.conv1d9 = OneDimConvBlock()
self.conv1d10 = OneDimConvBlock()
self.conv1d11 = OneDimConvBlock()
self.conv1d12 = OneDimConvBlock()
self.preconcat1 = nn.Linear(2048, 1024)
self.preconcat2 = nn.Linear(1024, self.feat_dim)
self.afterconcat1 = nn.Linear(2 * self.feat_dim, self.feat_dim)
self.after_cat_drop = nn.Dropout(self.drop_ratio)
    def forward(self, data):
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
        fingerprint = data.fingerprint.reshape(-1, 2048)
        # Graph branch: embed the 34-dim atom features, then 9 residual GAT blocks.
        h = self.in_linear(x)
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5,
                     self.conv6, self.conv7, self.conv8, self.conv9):
            h = F.relu(conv(h, edge_index, edge_attr), inplace=True)
        # Fingerprint branch: 12 attention blocks, then project 2048 -> feat_dim.
        for block in (self.conv1d1, self.conv1d2, self.conv1d3, self.conv1d4,
                      self.conv1d5, self.conv1d6, self.conv1d7, self.conv1d8,
                      self.conv1d9, self.conv1d10, self.conv1d11, self.conv1d12):
            fingerprint = block(fingerprint)
        fingerprint = self.preconcat1(fingerprint)
        fingerprint = self.preconcat2(fingerprint)
        # Pool node embeddings per molecule and fuse the two branches.
        h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)
        h = self.pool(h, batch)
        h = self.feat_lin(h)
        concat = torch.concat([h, fingerprint], dim=-1)
        concat = self.afterconcat1(concat)
        concat = self.after_cat_drop(concat)
        out = self.out_lin(concat)
        return out.squeeze()
class OneDimConvBlock(nn.Module):
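    """Transformer-style block over the flat fingerprint vector: elementwise
    self-attention with residual BatchNorm, then a feed-forward sub-layer."""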
def __init__(self, in_channel=2048, out_channel=2048):
super().__init__()
        self.attention_conv = OneDimAttention(in_channel, in_channel)
        self.batchnorm1 = torch.nn.BatchNorm1d(in_channel)
        self.batchnorm2 = torch.nn.BatchNorm1d(in_channel)
        # linear1 is never called in forward(), but it is kept so the published
        # checkpoint's state_dict still loads with strict=True.
        self.linear1 = nn.Linear(in_channel, in_channel)
        self.linear2 = nn.Linear(in_channel, out_channel)
self.ffn = nn.Sequential(
nn.Linear(in_channel, in_channel),
nn.ReLU(),
nn.Linear(in_channel, in_channel),
nn.ReLU()
)
    def forward(self, x):
        # Residual attention sub-layer, then a residual feed-forward sub-layer.
        h = self.attention_conv(x, x, x)
        h = self.batchnorm1(x + h)
        h_new = self.ffn(h)
        h_new = self.batchnorm2(h + h_new)
        # Input is 2-D (batch, channels), so elementwise dropout is appropriate.
        return F.dropout(self.linear2(h_new), training=self.training)
class OneDimAttention(nn.Module):
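    """Lightweight elementwise attention used by OneDimConvBlock."""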
    def __init__(self, in_size, out_size):
        super().__init__()
        self.scale = float(in_size) ** 0.5
        self.out_size = out_size
        self.linear = nn.Linear(in_size, out_size)

    def forward(self, q, k, v):
        # Elementwise (not dot-product) attention: scale q*k, project,
        # softmax over channels, and gate v.
        attention = torch.mul(q, k) / self.scale
        attention = self.linear(attention)
        return torch.mul(F.softmax(attention, dim=-1), v)
class MyNetTest(nn.Module):
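    """Variant of MyNet built from GraphTransformerBlock2; kept for reference
    but unused by the deployed app, which instantiates MyNet."""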
def __init__(self, emb_dim=512, feat_dim=256, edge_dim=5, heads=3, drop_ratio=0, pool='add'):
super(MyNetTest, self).__init__()
self.emb_dim = emb_dim
self.feat_dim = feat_dim
self.drop_ratio = drop_ratio
self.in_linear = nn.Linear(34, emb_dim)
self.conv1 = GraphTransformerBlock2(emb_dim, emb_dim, heads=heads, edge_dim=edge_dim)
self.conv2 = GraphTransformerBlock2(emb_dim, emb_dim, heads=heads, edge_dim=edge_dim)
self.conv3 = GraphTransformerBlock2(emb_dim, emb_dim, heads=heads, edge_dim=edge_dim)
self.conv4 = GraphTransformerBlock2(emb_dim, emb_dim, heads=heads, edge_dim=edge_dim)
self.conv5 = GraphTransformerBlock2(emb_dim, emb_dim, heads=heads, edge_dim=edge_dim)
self.conv6 = GraphTransformerBlock2(emb_dim, emb_dim, heads=heads, edge_dim=edge_dim)
self.conv7 = GraphTransformerBlock2(emb_dim, emb_dim, heads=heads, edge_dim=edge_dim)
self.conv8 = GraphTransformerBlock2(emb_dim, emb_dim, heads=heads, edge_dim=edge_dim)
self.conv9 = GraphTransformerBlock2(emb_dim, emb_dim, heads=heads, edge_dim=edge_dim)
        if pool == 'mean':
            self.pool = global_mean_pool
        elif pool == 'max':
            self.pool = global_max_pool
        elif pool == 'add':
            self.pool = global_add_pool
        else:
            raise ValueError(f"unsupported pool: {pool!r}")
self.feat_lin = nn.Linear(self.emb_dim, self.feat_dim)
self.out_lin = nn.Sequential(
nn.Linear(self.feat_dim, self.feat_dim // 8),
nn.ReLU(inplace=True),
nn.Linear(self.feat_dim // 8, self.feat_dim // 64),
nn.ReLU(inplace=True),
nn.Linear(self.feat_dim // 64, 1),
)
self.conv1d1 = OneDimConvBlock()
self.conv1d2 = OneDimConvBlock()
self.conv1d3 = OneDimConvBlock()
self.conv1d4 = OneDimConvBlock()
self.conv1d5 = OneDimConvBlock()
self.conv1d6 = OneDimConvBlock()
self.conv1d7 = OneDimConvBlock()
self.conv1d8 = OneDimConvBlock()
self.conv1d9 = OneDimConvBlock()
self.conv1d10 = OneDimConvBlock()
self.conv1d11 = OneDimConvBlock()
self.conv1d12 = OneDimConvBlock()
self.preconcat1 = nn.Linear(2048, 1024)
self.preconcat2 = nn.Linear(1024, self.feat_dim)
self.afterconcat1 = nn.Linear(2 * self.feat_dim, self.feat_dim)
self.after_cat_drop = nn.Dropout(self.drop_ratio)
    def forward(self, data):
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
        fingerprint = data.fingerprint.reshape(-1, 2048)
        # Graph branch: embed the 34-dim atom features, then 9 residual GAT blocks.
        h = self.in_linear(x)
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5,
                     self.conv6, self.conv7, self.conv8, self.conv9):
            h = F.relu(conv(h, edge_index, edge_attr), inplace=True)
        # Fingerprint branch: 12 attention blocks, then project 2048 -> feat_dim.
        for block in (self.conv1d1, self.conv1d2, self.conv1d3, self.conv1d4,
                      self.conv1d5, self.conv1d6, self.conv1d7, self.conv1d8,
                      self.conv1d9, self.conv1d10, self.conv1d11, self.conv1d12):
            fingerprint = block(fingerprint)
        fingerprint = self.preconcat1(fingerprint)
        fingerprint = self.preconcat2(fingerprint)
        # Pool node embeddings per molecule and fuse the two branches.
        h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)
        h = self.pool(h, batch)
        h = self.feat_lin(h)
        concat = torch.concat([h, fingerprint], dim=-1)
        concat = self.afterconcat1(concat)
        concat = self.after_cat_drop(concat)
        out = self.out_lin(concat)
        return out.squeeze()
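

# ---------------------------------------------------------------------------
# Inference: load the trained weights and serve predictions through Gradio.
# ---------------------------------------------------------------------------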
model = MyNet(emb_dim=512, feat_dim=512)
# map_location='cpu' lets the checkpoint load on CPU-only hosts.
state = torch.load(os.path.join(os.getcwd(), 'best_state_download_dict.pth'), map_location='cpu')
model.load_state_dict(state)
model.eval()

# Cache directory for previously computed prediction CSVs.
os.makedirs(os.path.join(os.getcwd(), 'save_df'), exist_ok=True)
def get_rt_from_mol(mol):
    """Predict the retention time of a single RDKit molecule."""
    data_list = get_data_list([mol])
    loader = DataLoader(data_list, batch_size=1)
    batch = next(iter(loader))
    with torch.no_grad():
        return model(batch).item()
def pred_file_bytes(file_bytes, progress=gr.Progress()):
    """Predict RTs for every molecule in an uploaded SDF file.

    Results are cached by the MD5 of the file contents, so re-uploading the
    same file returns the previously computed CSV immediately.
    """
    progress(0, desc='Starting')
    file_name = os.path.join(
        os.getcwd(), 'save_df',
        hashlib.md5(file_bytes).hexdigest() + '.csv'
    )
    if os.path.exists(file_name):
        print('File already processed; returning cached predictions.')
        return file_name
    with open('temp.sdf', 'bw') as f:
        f.write(file_bytes)
    sup = Chem.SDMolSupplier('temp.sdf')
    df = pd.DataFrame(columns=['InChI', 'Predicted RT'])
    for mol in progress.tqdm(sup):
        if mol is None:  # SDMolSupplier yields None for unparsable records
            continue
        try:
            inchi = Chem.MolToInchi(mol)
            rt = get_rt_from_mol(mol)
            df.loc[len(df)] = [inchi, rt]
        except Exception:
            # Skip molecules that fail featurization or InChI generation.
            continue
    df.to_csv(file_name, index=False)
    return file_name
demo = gr.Interface(
    pred_file_bytes,
    gr.File(type='binary'),
    gr.File(type='filepath'),
    title='RT-Transformer Retention Time Predictor',
    description='Upload an SDF file of molecules; predicted retention times are returned as a CSV file.',
)
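
# Minimal local smoke test (a sketch: 'example.sdf' is a hypothetical input
# file, and gr.Progress is meant to run inside a Gradio event, so the progress
# callbacks may misbehave when the function is called directly):
#
#     with open('example.sdf', 'rb') as fh:
#         csv_path = pred_file_bytes(fh.read())
#     print(pd.read_csv(csv_path).head())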
if __name__ == "__main__":
demo.launch(server_name="0.0.0.0")