RT-Transformer / app.py
Xue-Jun's picture
Update app.py
75c50a7
raw
history blame
22.2 kB
import time
import gradio as gr
from rdkit import Chem
import torch
import os
import pandas as pd
import hashlib
from torch_geometric.loader import DataLoader
import torch
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch_geometric.nn import GCNConv, global_mean_pool, GATConv, GAE, GATv2Conv, GraphSAGE, GENConv, GMMConv, \
GravNetConv, MessagePassing, global_max_pool, global_add_pool, GAT, GINConv, GINEConv, GraphNorm, SAGEConv, RGATConv
from torch.nn.functional import sigmoid
from torch import nn
import numpy as np
import torch.nn.functional as F
from torch_geometric.nn import global_mean_pool, global_max_pool, global_add_pool, MessagePassing
from torch_geometric.utils import add_self_loops
from torch_geometric.data import Data, in_memory_dataset, Dataset, InMemoryDataset
from torch_geometric.loader import DataLoader
import numpy as np
import os
import torch
from torch_geometric.data import Dataset, Data
from torch_geometric.utils import to_networkx, to_dense_adj
import networkx as nx
import pandas as pd
from rdkit import Chem
from rdkit.Chem.rdchem import BondType as BT
from rdkit.Chem import AllChem
from sklearn.preprocessing import OneHotEncoder
# Closed vocabularies used to turn RDKit enum values into integer indices.
# The position of a value inside each list is the index fed to the encoders /
# edge features below, so the order must not change.
CHIRALITY_LIST = [
    Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
    Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
    Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
    Chem.rdchem.ChiralType.CHI_OTHER,
]
BOND_LIST = [
    BT.SINGLE,
    BT.DOUBLE,
    BT.TRIPLE,
    BT.AROMATIC,
]
BONDDIR_LIST = [
    Chem.rdchem.BondDir.NONE,
    Chem.rdchem.BondDir.ENDUPRIGHT,
    Chem.rdchem.BondDir.ENDDOWNRIGHT,
]

# One-hot encoders are fitted on the index range 0..len(vocabulary)-1 as a
# float32 column. torch.arange(n, dtype=float32) yields exactly the values the
# deprecated torch.range(0, n - 1) produced, without relying on an API that is
# scheduled for removal.
hybridization_list = ['OTHER', 'S', 'SP', 'SP2', 'SP3', 'SP3D', 'SP3D2', 'UNSPECIFIED']
hybridization_encoder = OneHotEncoder()
hybridization_encoder.fit(torch.arange(len(hybridization_list), dtype=torch.float32).unsqueeze(-1))

atom_list = ['H', 'C', 'O', 'S', 'N', 'P', 'F', 'Cl', 'Br', 'I', 'Si']
atom_encoder = OneHotEncoder()
atom_encoder.fit(torch.arange(len(atom_list), dtype=torch.float32).unsqueeze(-1))

chirarity_encoder = OneHotEncoder()
chirarity_encoder.fit(torch.arange(len(CHIRALITY_LIST), dtype=torch.float32).unsqueeze(-1))
def get_data_list(mol_list):
    """Convert RDKit molecules into a list of torch_geometric ``Data`` graphs.

    Explicit hydrogens are added to every molecule before featurisation.

    Node feature matrix ``x`` (one row per atom, 34 columns, in this order):
      * one-hot element symbol over ``atom_list`` (11)
      * one-hot chiral tag over ``CHIRALITY_LIST`` (4)
      * mass, degree, total degree, formal charge (4)
      * one-hot hybridization over ``hybridization_list`` (8)
      * explicit valence, implicit valence, total valence, atom-map number,
        isotope, radical electron count, in-ring flag (7)

    Every bond contributes two directed edges (both directions) that share the
    same 5-value feature row: bond-type index, bond-direction index, in-ring,
    aromatic and conjugated flags.

    Args:
        mol_list: iterable of RDKit molecule objects.

    Returns:
        list of ``Data`` objects with ``x``, ``edge_index``, ``edge_attr`` and
        ``fingerprint`` (radius-2 Morgan bit vector as a float tensor).

    Raises:
        ValueError: if an atom symbol, chiral tag, hybridization, bond type or
            bond direction is not present in the corresponding vocabulary
            (raised by ``list.index``).
    """
    data_list = []
    for mol in mol_list:
        mol = Chem.AddHs(mol)

        weights = []
        type_idx = []
        chirality_idx = []
        degrees = []
        total_degrees = []
        formal_charges = []
        hybridization_types = []
        explicit_valences = []
        implicit_valences = []
        total_valences = []
        atom_map_nums = []
        isotopes = []
        radical_electrons = []
        inrings = []
        for atom in mol.GetAtoms():
            type_idx.append(atom_list.index(atom.GetSymbol()))
            chirality_idx.append(CHIRALITY_LIST.index(atom.GetChiralTag()))
            degrees.append(atom.GetDegree())
            weights.append(atom.GetMass())
            total_degrees.append(atom.GetTotalDegree())
            formal_charges.append(atom.GetFormalCharge())
            hybridization_types.append(hybridization_list.index(str(atom.GetHybridization())))
            explicit_valences.append(atom.GetExplicitValence())
            implicit_valences.append(atom.GetImplicitValence())
            total_valences.append(atom.GetTotalValence())
            atom_map_nums.append(atom.GetAtomMapNum())
            isotopes.append(atom.GetIsotope())
            radical_electrons.append(atom.GetNumRadicalElectrons())
            inrings.append(int(atom.IsInRing()))

        # Column order matters: the model's input layer is sized for exactly
        # this 34-wide layout.
        x = torch.cat([
            _one_hot_column(atom_encoder, type_idx),
            _one_hot_column(chirarity_encoder, chirality_idx),
            _float_column(weights),
            _float_column(degrees),
            _float_column(total_degrees),
            _float_column(formal_charges),
            _one_hot_column(hybridization_encoder, hybridization_types),
            _float_column(explicit_valences),
            _float_column(implicit_valences),
            _float_column(total_valences),
            _float_column(atom_map_nums),
            _float_column(isotopes),
            _float_column(radical_electrons),
            _float_column(inrings),
        ], dim=-1)

        row, col, edge_feat = [], [], []
        for bond in mol.GetBonds():
            start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            row += [start, end]
            col += [end, start]
            bond_feat = [
                BOND_LIST.index(bond.GetBondType()),
                BONDDIR_LIST.index(bond.GetBondDir()),
                float(int(bond.IsInRing())),
                float(int(bond.GetIsAromatic())),
                float(int(bond.GetIsConjugated())),
            ]
            # Same features for the start->end and end->start edges.
            edge_feat += [bond_feat, bond_feat]

        edge_index = torch.tensor([row, col], dtype=torch.long)
        # reshape keeps a (0, 5) matrix for molecules without bonds instead of
        # collapsing to a 1-D empty tensor.
        edge_attr = torch.tensor(edge_feat, dtype=torch.float32).reshape(-1, 5)
        fingerprint = torch.tensor(AllChem.GetMorganFingerprintAsBitVect(mol, 2), dtype=torch.float32)

        data_list.append(Data(x=x,
                              edge_index=edge_index,
                              edge_attr=edge_attr,
                              fingerprint=fingerprint))
    return data_list


def _float_column(values):
    """Return ``values`` as a float32 column tensor of shape (len(values), 1)."""
    return torch.tensor(values, dtype=torch.float32).view(-1, 1)


def _one_hot_column(encoder, indices):
    """One-hot encode integer ``indices`` with a fitted encoder as float32."""
    return torch.tensor(encoder.transform(_float_column(indices)).toarray(), dtype=torch.float32)
class GraphTransformerBlock(nn.Module):
    """Edge-aware graph attention layer with a residual connection.

    Pipeline: ``GATConv`` (multi-head, uses edge features) -> linear projection
    from ``heads * out_channels`` back to ``out_channels`` -> residual add with
    the input -> ``LayerNorm`` -> dropout.

    Because the input is added to the projected output, ``in_channels`` must
    equal ``out_channels`` for the residual sum to be well-formed.
    """

    def __init__(self, in_channels, out_channels, heads=3, edge_dim=5, dropout=0, **kwargs):
        super().__init__(**kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.edge_dim = edge_dim
        self.dropout = dropout

        self.conv = GATConv(in_channels, out_channels, heads=heads, edge_dim=edge_dim)
        self.linear = nn.Linear(heads * out_channels, out_channels)
        self.layerNorm = nn.LayerNorm(out_channels)

    def forward(self, x, edge_index, edge_attr):
        """Apply attention, projection, residual norm and dropout to ``x``."""
        attended = self.conv(x=x, edge_index=edge_index, edge_attr=edge_attr)
        projected = self.linear(attended)
        normed = self.layerNorm(x + projected)
        return F.dropout(normed, self.dropout, training=self.training)
class GraphTransformerBlock2(nn.Module):
    """Graph attention layer followed by a residual linear sub-layer.

    First sub-layer: ``GATConv`` with edge features -> projection from
    ``heads * out_channels`` to ``out_channels`` -> residual add with the
    input -> ``LayerNorm``.
    Second sub-layer: linear map -> residual add with the first sub-layer's
    output -> ``LayerNorm``. Dropout is applied to the final result.

    The first residual sum requires ``in_channels == out_channels``.
    """

    def __init__(self, in_channels, out_channels, heads=3, edge_dim=5, dropout=0, **kwargs):
        super().__init__(**kwargs)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.edge_dim = edge_dim
        self.dropout = dropout

        self.conv = GATConv(in_channels, out_channels, heads=heads, edge_dim=edge_dim)
        self.linear1 = nn.Linear(heads * out_channels, out_channels)
        self.layerNorm1 = nn.LayerNorm(out_channels)
        self.linear2 = nn.Linear(out_channels, out_channels)
        self.layerNorm2 = nn.LayerNorm(out_channels)

    def forward(self, x, edge_index, edge_attr):
        """Run both residual sub-layers on ``x`` and return the result."""
        attended = self.linear1(self.conv(x=x, edge_index=edge_index, edge_attr=edge_attr))
        first = self.layerNorm1(x + attended)
        second = self.layerNorm2(self.linear2(first) + first)
        return F.dropout(second, self.dropout, training=self.training)
class Trainer(object):
    """Runs L1-loss training epochs for a model with an AdamW optimizer.

    Attributes:
        model: the module being optimised.
        optimizer: ``AdamW`` over ``model.parameters()``.
        scheduler: ``ReduceLROnPlateau`` bound to ``optimizer``. It is only
            created here; the caller decides when to call
            ``scheduler.step(metric)``.
        device: device each batch is moved to before the forward pass.
    """

    def __init__(self, model, lr, device):
        from torch import optim

        self.model = model
        self.device = device
        self.optimizer = optim.AdamW(self.model.parameters(), lr=lr)
        # Keep a reference: previously the scheduler was constructed and
        # immediately discarded, so it could never be stepped. The deprecated
        # ``verbose`` argument is omitted; its default is equivalent.
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='min', factor=0.1, patience=10,
            threshold=0.0001, threshold_mode='rel',
            cooldown=0, min_lr=0, eps=1e-08)

    def train(self, data_loader):
        """Run one pass over ``data_loader``, updating the model per batch.

        Each batch must expose ``.to(device)`` and a target attribute ``.y``;
        the model is called with the batch object itself.

        Returns:
            Always ``0`` (kept for backward compatibility with callers).
        """
        criterion = torch.nn.L1Loss()
        for data in data_loader:
            # Use the returned object so batches whose ``.to`` is not
            # in-place are still moved to the target device.
            data = data.to(self.device)
            y_hat = self.model(data)
            loss = criterion(y_hat, data.y)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
        return 0
class Tester(object):
    """Evaluates a regression model on a data loader."""

    def __init__(self, model, device):
        self.model = model
        self.device = device

    def test_regressor(self, data_loader):
        """Return the mean absolute error of the model over ``data_loader``.

        Each batch is moved to ``self.device``, passed to the model as a whole
        object, and its ``.y`` attribute is used as the target. Predictions
        and targets from all batches are concatenated before averaging.

        Returns:
            The MAE as a Python float.
        """
        targets, predictions = [], []
        with torch.no_grad():
            for batch in data_loader:
                batch.to(self.device, non_blocking=True)
                output = self.model(batch)
                targets.append(batch.y)
                predictions.append(output)
            all_targets = torch.cat(targets)
            all_predictions = torch.cat(predictions)
            abs_error = (all_targets - all_predictions).abs()
        return abs_error.mean().item()
class MyNet(nn.Module):
    """Two-branch regressor combining a molecular graph and its fingerprint.

    Graph branch: a linear embedding of the 34-wide node features, nine
    ``GraphTransformerBlock`` layers (``conv1`` .. ``conv9``), a global
    pooling readout and ``feat_lin``.
    Fingerprint branch: twelve ``OneDimConvBlock`` layers (``conv1d1`` ..
    ``conv1d12``) over a 2048-wide vector, then two linear layers down to
    ``feat_dim``.
    Both branches are concatenated, mixed by ``afterconcat1``, passed through
    dropout and reduced to a single value by ``out_lin``.

    Sub-module attribute names are part of the checkpoint layout and must stay
    as they are.

    Args:
        emb_dim: width of the node embeddings.
        feat_dim: width of each branch's output before concatenation.
        edge_dim: number of edge features passed to the graph blocks.
        heads: attention heads per graph block.
        drop_ratio: dropout probability used after the graph stack and after
            the concatenation layer.
        pool: ``'mean'``, ``'max'`` or ``'add'``; any other value leaves
            ``self.pool`` undefined.
    """

    def __init__(self, emb_dim=512, feat_dim=256, edge_dim=5, heads=3, drop_ratio=0, pool='add'):
        super().__init__()
        self.emb_dim = emb_dim
        self.feat_dim = feat_dim
        self.drop_ratio = drop_ratio

        self.in_linear = nn.Linear(34, emb_dim)
        # conv1 .. conv9, created in ascending order.
        for layer_no in range(1, 10):
            setattr(self, f'conv{layer_no}',
                    GraphTransformerBlock(emb_dim, emb_dim, heads=heads, edge_dim=edge_dim))

        readouts = {
            'mean': global_mean_pool,
            'max': global_max_pool,
            'add': global_add_pool,
        }
        if pool in readouts:
            self.pool = readouts[pool]

        self.feat_lin = nn.Linear(self.emb_dim, self.feat_dim)
        hidden_width = self.feat_dim // 8
        bottleneck_width = self.feat_dim // 64
        self.out_lin = nn.Sequential(
            nn.Linear(self.feat_dim, hidden_width),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_width, bottleneck_width),
            nn.ReLU(inplace=True),
            nn.Linear(bottleneck_width, 1),
        )

        # conv1d1 .. conv1d12, created in ascending order.
        for layer_no in range(1, 13):
            setattr(self, f'conv1d{layer_no}', OneDimConvBlock())

        self.preconcat1 = nn.Linear(2048, 1024)
        self.preconcat2 = nn.Linear(1024, self.feat_dim)
        self.afterconcat1 = nn.Linear(2 * self.feat_dim, self.feat_dim)
        self.after_cat_drop = nn.Dropout(self.drop_ratio)

    def forward(self, data):
        """Predict one value per graph in ``data``.

        ``data`` must provide ``x``, ``edge_index``, ``edge_attr``, ``batch``
        and ``fingerprint``; the fingerprint is reshaped to rows of 2048
        values (presumably one row per graph -- confirm against the loader).

        Returns:
            The ``out_lin`` output with all size-1 dimensions squeezed away.
        """
        node_feats = data.x
        edge_index = data.edge_index
        edge_attr = data.edge_attr
        graph_ids = data.batch
        fp = data.fingerprint.reshape(-1, 2048)

        # Graph branch.
        h = self.in_linear(node_feats)
        for layer_no in range(1, 10):
            graph_block = getattr(self, f'conv{layer_no}')
            h = F.relu(graph_block(h, edge_index, edge_attr), inplace=True)

        # Fingerprint branch.
        for layer_no in range(1, 13):
            fp_block = getattr(self, f'conv1d{layer_no}')
            fp = fp_block(fp)
        fp = self.preconcat2(self.preconcat1(fp))

        # Readout of the graph branch.
        h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)
        h = self.feat_lin(self.pool(h, graph_ids))

        # Fuse both branches and regress.
        fused = torch.cat([h, fp], dim=-1)
        fused = self.after_cat_drop(self.afterconcat1(fused))
        return self.out_lin(fused).squeeze()
class OneDimConvBlock(nn.Module):
    """Residual block for fixed-width feature vectors.

    Element-wise self-attention (``OneDimAttention`` with q = k = v = input)
    is added back to the input and batch-normalised; a two-layer ReLU
    feed-forward network is then added residually and batch-normalised again;
    finally ``linear2`` maps to ``out_channel`` and ``dropout1d`` is applied.

    ``linear1`` is created but not used in ``forward``; it is kept so the
    module's parameter set stays the same.
    """

    def __init__(self, in_channel=2048, out_channel=2048):
        super().__init__()
        self.attention_conv = OneDimAttention(in_channel, in_channel)
        self.batchnorm1 = nn.BatchNorm1d(in_channel)
        self.batchnorm2 = nn.BatchNorm1d(in_channel)
        self.linear1 = nn.Linear(in_channel, in_channel)
        self.linear2 = nn.Linear(in_channel, out_channel)
        self.ffn = nn.Sequential(
            nn.Linear(in_channel, in_channel),
            nn.ReLU(),
            nn.Linear(in_channel, in_channel),
            nn.ReLU(),
        )

    def forward(self, x):
        """Apply attention, feed-forward, projection and dropout to ``x``."""
        attended = self.attention_conv(x, x, x)
        normed = self.batchnorm1(x + attended)
        refined = self.batchnorm2(normed + self.ffn(normed))
        projected = self.linear2(refined)
        return F.dropout1d(projected, training=self.training)
class OneDimAttention(nn.Module):
    """Element-wise gating that mimics scaled attention on flat vectors.

    The element-wise product ``q * k`` is divided by ``sqrt(in_size)``, passed
    through a linear layer, soft-maxed over the last dimension and used to
    weight ``v`` element-wise.
    """

    def __init__(self, in_size, out_size):
        super().__init__()
        # Stored as a tensor so the square root can be taken with torch ops.
        self.in_size = torch.tensor(in_size)
        self.out_size = out_size
        self.linear = nn.Linear(in_size, out_size)

    def forward(self, q, k, v):
        """Return ``softmax(linear(q * k / sqrt(in_size))) * v``."""
        scaled_scores = (q * k) / self.in_size.sqrt()
        weights = F.softmax(self.linear(scaled_scores), dim=-1)
        return weights * v
class MyNetTest(nn.Module):
    """Variant of the two-branch regressor built on ``GraphTransformerBlock2``.

    Graph branch: a linear embedding of the 34-wide node features, nine
    ``GraphTransformerBlock2`` layers (``conv1`` .. ``conv9``), a global
    pooling readout and ``feat_lin``.
    Fingerprint branch: twelve ``OneDimConvBlock`` layers (``conv1d1`` ..
    ``conv1d12``) over a 2048-wide vector, then two linear layers down to
    ``feat_dim``.
    Both branches are concatenated, mixed by ``afterconcat1``, passed through
    dropout and reduced to a single value by ``out_lin``.

    Args:
        emb_dim: width of the node embeddings.
        feat_dim: width of each branch's output before concatenation.
        edge_dim: number of edge features passed to the graph blocks.
        heads: attention heads per graph block.
        drop_ratio: dropout probability used after the graph stack and after
            the concatenation layer.
        pool: ``'mean'``, ``'max'`` or ``'add'``; any other value leaves
            ``self.pool`` undefined.
    """

    def __init__(self, emb_dim=512, feat_dim=256, edge_dim=5, heads=3, drop_ratio=0, pool='add'):
        super().__init__()
        self.emb_dim = emb_dim
        self.feat_dim = feat_dim
        self.drop_ratio = drop_ratio

        self.in_linear = nn.Linear(34, emb_dim)
        # conv1 .. conv9, created in ascending order.
        for layer_no in range(1, 10):
            setattr(self, f'conv{layer_no}',
                    GraphTransformerBlock2(emb_dim, emb_dim, heads=heads, edge_dim=edge_dim))

        readouts = {
            'mean': global_mean_pool,
            'max': global_max_pool,
            'add': global_add_pool,
        }
        if pool in readouts:
            self.pool = readouts[pool]

        self.feat_lin = nn.Linear(self.emb_dim, self.feat_dim)
        hidden_width = self.feat_dim // 8
        bottleneck_width = self.feat_dim // 64
        self.out_lin = nn.Sequential(
            nn.Linear(self.feat_dim, hidden_width),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_width, bottleneck_width),
            nn.ReLU(inplace=True),
            nn.Linear(bottleneck_width, 1),
        )

        # conv1d1 .. conv1d12, created in ascending order.
        for layer_no in range(1, 13):
            setattr(self, f'conv1d{layer_no}', OneDimConvBlock())

        self.preconcat1 = nn.Linear(2048, 1024)
        self.preconcat2 = nn.Linear(1024, self.feat_dim)
        self.afterconcat1 = nn.Linear(2 * self.feat_dim, self.feat_dim)
        self.after_cat_drop = nn.Dropout(self.drop_ratio)

    def forward(self, data):
        """Predict one value per graph in ``data``.

        ``data`` must provide ``x``, ``edge_index``, ``edge_attr``, ``batch``
        and ``fingerprint``; the fingerprint is reshaped to rows of 2048
        values (presumably one row per graph -- confirm against the loader).

        Returns:
            The ``out_lin`` output with all size-1 dimensions squeezed away.
        """
        node_feats = data.x
        edge_index = data.edge_index
        edge_attr = data.edge_attr
        graph_ids = data.batch
        fp = data.fingerprint.reshape(-1, 2048)

        # Graph branch.
        h = self.in_linear(node_feats)
        for layer_no in range(1, 10):
            graph_block = getattr(self, f'conv{layer_no}')
            h = F.relu(graph_block(h, edge_index, edge_attr), inplace=True)

        # Fingerprint branch.
        for layer_no in range(1, 13):
            fp_block = getattr(self, f'conv1d{layer_no}')
            fp = fp_block(fp)
        fp = self.preconcat2(self.preconcat1(fp))

        # Readout of the graph branch.
        h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)
        h = self.feat_lin(self.pool(h, graph_ids))

        # Fuse both branches and regress.
        fused = torch.cat([h, fp], dim=-1)
        fused = self.after_cat_drop(self.afterconcat1(fused))
        return self.out_lin(fused).squeeze()
# Build the network and restore the trained weights from the working
# directory. The model is never moved off the CPU in this file, so the
# checkpoint is mapped to CPU explicitly; this lets a checkpoint that was
# saved from a GPU load on a CPU-only host.
model = MyNet(emb_dim=512, feat_dim=512)
state = torch.load(os.path.join(os.getcwd(), 'best_state_download_dict.pth'), map_location='cpu')
model.load_state_dict(state)
model.eval()

# Directory where prediction CSVs are cached. An already existing directory is
# fine; other failures (e.g. permissions) are no longer silently swallowed.
os.makedirs(os.path.join(os.getcwd(), 'save_df'), exist_ok=True)
def get_rt_from_mol(mol):
    """Predict the retention time of a single RDKit molecule.

    The molecule is featurised with ``get_data_list``, wrapped in a
    one-element batch and passed through the module-level ``model``.

    Args:
        mol: RDKit molecule object.

    Returns:
        The model output as a Python float.
    """
    data_list = get_data_list([mol])
    loader = DataLoader(data_list, batch_size=1)
    # Exactly one graph was built, so the first batch is the only batch.
    batch = next(iter(loader))
    # Pure inference: skip autograd bookkeeping.
    with torch.no_grad():
        return model(batch).item()
def pred_file_btyes(file_bytes, progress=gr.Progress()):
    """Predict retention times for every molecule in an uploaded SDF file.

    Results are cached on disk: the CSV path is derived from an MD5 digest of
    the uploaded bytes, and an existing CSV is returned without recomputation.

    Args:
        file_bytes: raw content of the SDF file.
        progress: Gradio progress tracker used to report status.

    Returns:
        Path of a CSV file with the columns ``InChI`` and ``Predicted RT``.
        Molecules that cannot be parsed, converted to InChI or featurised are
        skipped.
    """
    import tempfile

    progress(0, desc='Starting')
    file_name = os.path.join(
        os.path.join(os.getcwd(), 'save_df'),
        (hashlib.md5(str(file_bytes).encode('utf-8')).hexdigest() + '.csv')
    )
    if os.path.exists(file_name):
        # Message text: "this file already exists".
        print('该文件已经存在')
        return file_name

    rows = []
    # A per-request scratch directory: a fixed 'temp.sdf' in the working
    # directory would be overwritten by concurrent requests.
    with tempfile.TemporaryDirectory() as scratch_dir:
        sdf_path = os.path.join(scratch_dir, 'upload.sdf')
        with open(sdf_path, 'wb') as f:
            f.write(file_bytes)
        sup = Chem.SDMolSupplier(sdf_path)
        for mol in progress.tqdm(sup):
            if mol is None:
                # The supplier yields None for records it cannot parse.
                continue
            try:
                inchi = Chem.MolToInchi(mol)
                rt = get_rt_from_mol(mol)
            except Exception as exc:
                # Best effort: one bad molecule must not abort the whole file.
                print(f'Skipping molecule: {exc}')
                continue
            rows.append([inchi, rt])
        # Release the supplier's file handle before the directory is removed.
        del sup

    df = pd.DataFrame(rows, columns=['InChI', 'Predicted RT'])
    df.to_csv(file_name)
    return file_name
# Gradio UI: takes an SDF upload as raw bytes and returns the path of the
# generated CSV for download.
demo = gr.Interface(
    fn=pred_file_btyes,
    inputs=gr.File(type='binary'),
    outputs=gr.File(type='filepath'),
    # Spelling of "Retention" corrected in the user-facing title.
    title='RT-Transformer Retention Time Predictor',
    description='Input SDF Molecule File,Predicted RT output with a CSV File',
)

if __name__ == "__main__":
    demo.launch()