Spaces:
Runtime error
Runtime error
| import time | |
| import gradio as gr | |
| from rdkit import Chem | |
| import torch | |
| import os | |
| import pandas as pd | |
| import hashlib | |
| from torch_geometric.loader import DataLoader | |
| import torch | |
| import torch.nn.functional as F | |
| from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts | |
| from torch_geometric.nn import GCNConv, global_mean_pool, GATConv, GAE, GATv2Conv, GraphSAGE, GENConv, GMMConv, \ | |
| GravNetConv, MessagePassing, global_max_pool, global_add_pool, GAT, GINConv, GINEConv, GraphNorm, SAGEConv, RGATConv | |
| from torch.nn.functional import sigmoid | |
| from torch import nn | |
| import numpy as np | |
| import torch.nn.functional as F | |
| from torch_geometric.nn import global_mean_pool, global_max_pool, global_add_pool, MessagePassing | |
| from torch_geometric.utils import add_self_loops | |
| from torch_geometric.data import Data, in_memory_dataset, Dataset, InMemoryDataset | |
| from torch_geometric.loader import DataLoader | |
| import numpy as np | |
| import os | |
| import torch | |
| from torch_geometric.data import Dataset, Data | |
| from torch_geometric.utils import to_networkx, to_dense_adj | |
| import networkx as nx | |
| import pandas as pd | |
| from rdkit import Chem | |
| from rdkit.Chem.rdchem import BondType as BT | |
| from rdkit.Chem import AllChem | |
| from sklearn.preprocessing import OneHotEncoder | |
# --- Featurization vocabularies ---------------------------------------------
# Categorical vocabularies used to index / one-hot encode atom and bond
# properties when molecules are converted to graphs in get_data_list().
CHIRALITY_LIST = [
    Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
    Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
    Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
    Chem.rdchem.ChiralType.CHI_OTHER
]
BOND_LIST = [
    BT.SINGLE,
    BT.DOUBLE,
    BT.TRIPLE,
    BT.AROMATIC
]
BONDDIR_LIST = [
    Chem.rdchem.BondDir.NONE,
    Chem.rdchem.BondDir.ENDUPRIGHT,
    Chem.rdchem.BondDir.ENDDOWNRIGHT
]
hybridization_list = ['OTHER', 'S', 'SP', 'SP2', 'SP3', 'SP3D', 'SP3D2', 'UNSPECIFIED']


def _fit_index_encoder(n_categories):
    """Return a OneHotEncoder fitted on the indices 0..n_categories-1.

    Bug fix: the original used the deprecated ``torch.range`` (inclusive end,
    scheduled for removal); ``np.arange`` with float32 produces the same
    category values. OneHotEncoder requires a 2-D array, hence the reshape.
    """
    encoder = OneHotEncoder()
    encoder.fit(np.arange(n_categories, dtype=np.float32).reshape(-1, 1))
    return encoder


hybridization_encoder = _fit_index_encoder(len(hybridization_list))
atom_list = ['H', 'C', 'O', 'S', 'N', 'P', 'F', 'Cl', 'Br', 'I', 'Si']
atom_encoder = _fit_index_encoder(len(atom_list))
# NOTE(review): "chirarity" looks like a typo for "chirality", but the name is
# kept because other code may reference it.
chirarity_encoder = _fit_index_encoder(len(CHIRALITY_LIST))
def get_data_list(mol_list):
    """Convert RDKit molecules into torch_geometric ``Data`` graphs.

    Each ``Data`` carries:
      x           -- per-atom features: one-hot atom type, one-hot chirality,
                     one-hot hybridization, plus scalar columns (mass, degree,
                     charges, valences, map number, isotope, radicals, ring flag),
      edge_index  -- each bond stored as two directed edges,
      edge_attr   -- 5 bond features (type idx, bond-dir idx, in-ring,
                     aromatic, conjugated),
      fingerprint -- 2048-bit Morgan fingerprint (radius 2).
    """
    data_list = []
    for mol in mol_list:
        # --- collect per-atom scalar / categorical features ---
        type_idx, chirality_idx, weights = [], [], []
        degrees, total_degrees, formal_charges = [], [], []
        hybridization_types, explicit_valences, implicit_valences = [], [], []
        total_valences, atom_map_nums, isotopes = [], [], []
        radical_electrons, inrings = [], []
        for atom in mol.GetAtoms():
            type_idx.append(atom_list.index(atom.GetSymbol()))
            chirality_idx.append(CHIRALITY_LIST.index(atom.GetChiralTag()))
            weights.append(atom.GetMass())
            degrees.append(atom.GetDegree())
            total_degrees.append(atom.GetTotalDegree())
            formal_charges.append(atom.GetFormalCharge())
            hybridization_types.append(hybridization_list.index(str(atom.GetHybridization())))
            explicit_valences.append(atom.GetExplicitValence())
            implicit_valences.append(atom.GetImplicitValence())
            total_valences.append(atom.GetTotalValence())
            atom_map_nums.append(atom.GetAtomMapNum())
            isotopes.append(atom.GetIsotope())
            radical_electrons.append(atom.GetNumRadicalElectrons())
            inrings.append(int(atom.IsInRing()))

        def _column(values):
            # One scalar feature per atom as an (n_atoms, 1) float32 column.
            return torch.tensor(values, dtype=torch.float32).view(-1, 1)

        x = torch.cat([
            torch.tensor(atom_encoder.transform(_column(type_idx)).toarray(), dtype=torch.float32),
            torch.tensor(chirarity_encoder.transform(_column(chirality_idx)).toarray(), dtype=torch.float32),
            _column(weights),
            _column(degrees),
            _column(total_degrees),
            _column(formal_charges),
            torch.tensor(hybridization_encoder.transform(_column(hybridization_types)).toarray(),
                         dtype=torch.float32),
            _column(explicit_valences),
            _column(implicit_valences),
            _column(total_valences),
            _column(atom_map_nums),
            _column(isotopes),
            _column(radical_electrons),
            _column(inrings),
        ], dim=-1)

        # --- bonds: store each bond in both directions with identical
        # features (the original duplicated the feature literal; build once
        # and append twice) ---
        row, col, edge_feat = [], [], []
        for bond in mol.GetBonds():
            start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            row += [start, end]
            col += [end, start]
            feat = [
                BOND_LIST.index(bond.GetBondType()),
                BONDDIR_LIST.index(bond.GetBondDir()),
                float(int(bond.IsInRing())),
                float(int(bond.GetIsAromatic())),
                float(int(bond.GetIsConjugated())),
            ]
            edge_feat.append(feat)
            edge_feat.append(list(feat))  # same features for the reverse edge

        edge_index = torch.tensor([row, col], dtype=torch.long)
        edge_attr = torch.tensor(np.array(edge_feat), dtype=torch.float32)
        fingerprint = torch.tensor(AllChem.GetMorganFingerprintAsBitVect(mol, 2), dtype=torch.float32)
        data_list.append(Data(x=x,
                              edge_index=edge_index,
                              edge_attr=edge_attr,
                              fingerprint=fingerprint))
    return data_list
class GraphTransformerBlock(nn.Module):
    """GAT attention + linear projection with a residual connection and
    LayerNorm, followed by dropout (transformer-style graph block)."""

    def __init__(self, in_channels, out_channels, heads=3, edge_dim=5, dropout=0, **kwargs):
        super(GraphTransformerBlock, self).__init__(**kwargs)
        self.edge_dim = edge_dim
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Multi-head GAT produces heads*out_channels features; `linear`
        # projects them back to out_channels for the residual sum.
        self.conv = GATConv(in_channels, out_channels, heads=heads, edge_dim=edge_dim)
        self.linear = nn.Linear(heads * out_channels, out_channels)
        self.layerNorm = nn.LayerNorm(out_channels)
        self.dropout = dropout

    def forward(self, x, edge_index, edge_attr):
        attended = self.conv(x=x, edge_index=edge_index, edge_attr=edge_attr)
        projected = self.linear(attended)
        normed = self.layerNorm(x + projected)
        return F.dropout(normed, self.dropout, training=self.training)
class GraphTransformerBlock2(nn.Module):
    """Like GraphTransformerBlock, but with a second linear + LayerNorm
    residual stage acting as a small feed-forward sub-block."""

    def __init__(self, in_channels, out_channels, heads=3, edge_dim=5, dropout=0, **kwargs):
        super(GraphTransformerBlock2, self).__init__(**kwargs)
        self.edge_dim = edge_dim
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.conv = GATConv(in_channels, out_channels, heads=heads, edge_dim=edge_dim)
        # Stage 1: project multi-head output back to out_channels.
        self.linear1 = nn.Linear(heads * out_channels, out_channels)
        self.layerNorm1 = nn.LayerNorm(out_channels)
        # Stage 2: feed-forward residual.
        self.linear2 = nn.Linear(out_channels, out_channels)
        self.layerNorm2 = nn.LayerNorm(out_channels)
        self.dropout = dropout

    def forward(self, x, edge_index, edge_attr):
        attended = self.conv(x=x, edge_index=edge_index, edge_attr=edge_attr)
        stage1 = self.layerNorm1(x + self.linear1(attended))
        stage2 = self.layerNorm2(stage1 + self.linear2(stage1))
        return F.dropout(stage2, self.dropout, training=self.training)
class Trainer(object):
    """Minimal training-loop wrapper: model + AdamW + L1 loss.

    Bug fix: the ReduceLROnPlateau scheduler was previously constructed and
    immediately discarded; it is now kept on ``self.scheduler`` so callers
    can step it with a validation metric. (The deprecated ``verbose=False``
    kwarg is dropped — it was the default anyway.)
    """

    def __init__(self, model, lr, device):
        from torch import optim
        self.model = model
        self.optimizer = optim.AdamW(self.model.parameters(), lr=lr)
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='min', factor=0.1, patience=10,
            threshold=0.0001, threshold_mode='rel',
            cooldown=0, min_lr=0, eps=1e-08)
        self.device = device

    def train(self, data_loader):
        """Run one epoch of L1-loss optimisation over ``data_loader``.

        Returns 0, kept for backward compatibility with existing callers.
        """
        criterion = torch.nn.L1Loss()
        for data in data_loader:
            data.to(self.device)
            y_hat = self.model(data)
            loss = criterion(y_hat, data.y)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
        return 0
class Tester(object):
    """Evaluation helper that computes mean absolute error over a loader."""

    def __init__(self, model, device):
        self.model = model
        self.device = device

    def test_regressor(self, data_loader):
        """Return the MAE between model predictions and targets as a float.

        Runs under ``torch.no_grad()``; batches are moved to ``self.device``
        before the forward pass.
        """
        targets, predictions = [], []
        with torch.no_grad():
            for batch in data_loader:
                batch.to(self.device, non_blocking=True)
                predictions.append(self.model(batch))
                targets.append(batch.y)
        y_true = torch.concat(targets)
        y_pred = torch.concat(predictions)
        return torch.abs(y_true - y_pred).mean().item()
class MyNet(nn.Module):
    """Graph-transformer + fingerprint regressor.

    Two branches: (1) nine stacked GraphTransformerBlocks over the molecular
    graph, pooled to a graph embedding; (2) twelve OneDimConvBlocks over the
    2048-bit Morgan fingerprint. Both are projected to ``feat_dim``,
    concatenated, and regressed to a single scalar.
    """

    def __init__(self, emb_dim=512, feat_dim=256, edge_dim=5, heads=3, drop_ratio=0, pool='add'):
        super(MyNet, self).__init__()
        self.emb_dim = emb_dim
        self.feat_dim = feat_dim
        self.drop_ratio = drop_ratio
        # 34 = width of the node feature vector built by the featurizer.
        self.in_linear = nn.Linear(34, emb_dim)
        # Attribute names conv1..conv9 are preserved (instead of a
        # ModuleList) so existing state_dicts keep loading.
        for i in range(1, 10):
            setattr(self, 'conv%d' % i,
                    GraphTransformerBlock(emb_dim, emb_dim, heads=heads, edge_dim=edge_dim))
        poolers = {'mean': global_mean_pool, 'max': global_max_pool, 'add': global_add_pool}
        if pool in poolers:
            self.pool = poolers[pool]
        self.feat_lin = nn.Linear(self.emb_dim, self.feat_dim)
        self.out_lin = nn.Sequential(
            nn.Linear(self.feat_dim, self.feat_dim // 8),
            nn.ReLU(inplace=True),
            nn.Linear(self.feat_dim // 8, self.feat_dim // 64),
            nn.ReLU(inplace=True),
            nn.Linear(self.feat_dim // 64, 1),
        )
        # Fingerprint branch: conv1d1..conv1d12 names also preserved.
        for i in range(1, 13):
            setattr(self, 'conv1d%d' % i, OneDimConvBlock())
        self.preconcat1 = nn.Linear(2048, 1024)
        self.preconcat2 = nn.Linear(1024, self.feat_dim)
        self.afterconcat1 = nn.Linear(2 * self.feat_dim, self.feat_dim)
        self.after_cat_drop = nn.Dropout(self.drop_ratio)

    def forward(self, data):
        """Return a scalar prediction per graph (squeezed 1-D tensor)."""
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        batch = data.batch
        fingerprint = data.fingerprint.reshape(-1, 2048)

        # Graph branch: nine residual GAT blocks with ReLU between them.
        h = self.in_linear(x)
        for i in range(1, 10):
            block = getattr(self, 'conv%d' % i)
            h = F.relu(block(h, edge_index, edge_attr), inplace=True)

        # Fingerprint branch: twelve attention/FFN blocks, then projection.
        for i in range(1, 13):
            fingerprint = getattr(self, 'conv1d%d' % i)(fingerprint)
        fingerprint = self.preconcat2(self.preconcat1(fingerprint))

        h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)
        h = self.feat_lin(self.pool(h, batch))

        merged = torch.concat([h, fingerprint], dim=-1)
        merged = self.after_cat_drop(self.afterconcat1(merged))
        return self.out_lin(merged).squeeze()
class OneDimConvBlock(nn.Module):
    """Transformer-style block over a flat fingerprint vector:
    element-wise attention and an FFN, each with residual + BatchNorm1d."""

    def __init__(self, in_channel=2048, out_channel=2048):
        super().__init__()
        self.attention_conv = OneDimAttention(in_channel, in_channel)
        self.batchnorm1 = torch.nn.BatchNorm1d(in_channel)
        self.batchnorm2 = torch.nn.BatchNorm1d(in_channel)
        # NOTE(review): linear1 is never used in forward(), but it is kept so
        # saved state_dicts still match this module's parameter set.
        self.linear1 = nn.Linear(in_channel, in_channel)
        self.linear2 = nn.Linear(in_channel, out_channel)
        self.ffn = nn.Sequential(
            nn.Linear(in_channel, in_channel),
            nn.ReLU(),
            nn.Linear(in_channel, in_channel),
            nn.ReLU()
        )

    def forward(self, x):
        attn = self.batchnorm1(x + self.attention_conv(x, x, x))
        out = self.batchnorm2(attn + self.ffn(attn))
        # NOTE(review): dropout1d is called without an explicit p, so the
        # default p=0.5 applies during training — confirm this is intended.
        return F.dropout1d(self.linear2(out), training=self.training)
class OneDimAttention(nn.Module):
    """Element-wise attention: weights ``v`` by a softmax over a linear
    projection of the scaled product ``q * k``."""

    def __init__(self, in_size, out_size):
        super().__init__()
        # Kept as a tensor so torch.sqrt can be applied in forward().
        self.in_size = torch.tensor(in_size)
        self.out_size = out_size
        self.linear = nn.Linear(in_size, out_size)

    def forward(self, q, k, v):
        scaled = (q * k) / torch.sqrt(self.in_size)
        weights = F.softmax(self.linear(scaled), dim=-1)
        return weights * v
class MyNetTest(nn.Module):
    """Variant of MyNet that stacks GraphTransformerBlock2 (two-stage
    residual) blocks instead of GraphTransformerBlock; otherwise the same
    graph + fingerprint two-branch regressor."""

    def __init__(self, emb_dim=512, feat_dim=256, edge_dim=5, heads=3, drop_ratio=0, pool='add'):
        super(MyNetTest, self).__init__()
        self.emb_dim = emb_dim
        self.feat_dim = feat_dim
        self.drop_ratio = drop_ratio
        # 34 = width of the node feature vector built by the featurizer.
        self.in_linear = nn.Linear(34, emb_dim)
        # Attribute names conv1..conv9 preserved for state_dict compatibility.
        for i in range(1, 10):
            setattr(self, 'conv%d' % i,
                    GraphTransformerBlock2(emb_dim, emb_dim, heads=heads, edge_dim=edge_dim))
        poolers = {'mean': global_mean_pool, 'max': global_max_pool, 'add': global_add_pool}
        if pool in poolers:
            self.pool = poolers[pool]
        self.feat_lin = nn.Linear(self.emb_dim, self.feat_dim)
        self.out_lin = nn.Sequential(
            nn.Linear(self.feat_dim, self.feat_dim // 8),
            nn.ReLU(inplace=True),
            nn.Linear(self.feat_dim // 8, self.feat_dim // 64),
            nn.ReLU(inplace=True),
            nn.Linear(self.feat_dim // 64, 1),
        )
        # Fingerprint branch: conv1d1..conv1d12 names preserved.
        for i in range(1, 13):
            setattr(self, 'conv1d%d' % i, OneDimConvBlock())
        self.preconcat1 = nn.Linear(2048, 1024)
        self.preconcat2 = nn.Linear(1024, self.feat_dim)
        self.afterconcat1 = nn.Linear(2 * self.feat_dim, self.feat_dim)
        self.after_cat_drop = nn.Dropout(self.drop_ratio)

    def forward(self, data):
        """Return a scalar prediction per graph (squeezed 1-D tensor)."""
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        batch = data.batch
        fingerprint = data.fingerprint.reshape(-1, 2048)

        # Graph branch.
        h = self.in_linear(x)
        for i in range(1, 10):
            block = getattr(self, 'conv%d' % i)
            h = F.relu(block(h, edge_index, edge_attr), inplace=True)

        # Fingerprint branch.
        for i in range(1, 13):
            fingerprint = getattr(self, 'conv1d%d' % i)(fingerprint)
        fingerprint = self.preconcat2(self.preconcat1(fingerprint))

        h = F.dropout(F.relu(h), self.drop_ratio, training=self.training)
        h = self.feat_lin(self.pool(h, batch))

        merged = torch.concat([h, fingerprint], dim=-1)
        merged = self.after_cat_drop(self.afterconcat1(merged))
        return self.out_lin(merged).squeeze()
# --- Load trained weights and prepare the output cache directory ------------
model = MyNet(emb_dim=512, feat_dim=512)
# map_location='cpu' lets the checkpoint load on CPU-only hosts (e.g. a
# Hugging Face Space without a GPU); without it a GPU-saved checkpoint fails.
state = torch.load(os.path.join(os.getcwd(), 'best_state_download_dict.pth'),
                   map_location='cpu')
model.load_state_dict(state)
model.eval()
# exist_ok=True replaces the bare try/except around os.mkdir, which would
# also have silently swallowed real errors such as permission failures.
os.makedirs(os.path.join(os.getcwd(), 'save_df'), exist_ok=True)
def get_rt_from_mol(mol):
    """Predict the retention time of a single RDKit molecule as a float.

    The molecule is featurized, wrapped in a one-element DataLoader batch,
    and passed through the globally loaded ``model``.
    """
    data_list = get_data_list([mol])
    loader = DataLoader(data_list, batch_size=1)
    # next(iter(...)) replaces the original for-loop-with-break idiom.
    batch = next(iter(loader))
    # Inference only: no_grad avoids building an autograd graph.
    with torch.no_grad():
        return model(batch).item()
def pred_file_btyes(file_bytes, progress=gr.Progress()):
    """Predict retention times for every molecule in an uploaded SDF file.

    Results are cached under save_df/ keyed by the MD5 of the upload, so
    re-submitting the same file returns the cached CSV immediately.
    Returns the path of a CSV with columns ['InChI', 'Predicted RT'].
    """
    progress(0, desc='Starting')
    digest = hashlib.md5(str(file_bytes).encode('utf-8')).hexdigest()
    save_dir = os.path.join(os.getcwd(), 'save_df')
    file_name = os.path.join(save_dir, digest + '.csv')
    if os.path.exists(file_name):
        print('该文件已经存在')  # cached result already exists
        return file_name
    # Per-upload temp path: the original always wrote to a shared
    # 'temp.sdf', which races between concurrent requests.
    temp_sdf = os.path.join(save_dir, digest + '.sdf')
    with open(temp_sdf, 'bw') as f:
        f.write(file_bytes)
    sup = Chem.SDMolSupplier(temp_sdf)
    df = pd.DataFrame(columns=['InChI', 'Predicted RT'])
    for mol in progress.tqdm(sup):
        try:
            inchi = Chem.MolToInchi(mol)
            rt = get_rt_from_mol(mol)
            df.loc[len(df)] = [inchi, rt]
        except Exception:
            # SDMolSupplier yields None for unparsable records; skip them.
            # Narrowed from a bare except, which also caught KeyboardInterrupt.
            pass
    df.to_csv(file_name)
    return file_name
# --- Gradio UI --------------------------------------------------------------
# Takes an SDF upload as raw bytes and returns the path of the result CSV.
demo = gr.Interface(
    pred_file_btyes,
    gr.File(type='binary'),
    gr.File(type='filepath'),
    # Typo fix: "Rentention" -> "Retention".
    title='RT-Transformer Retention Time Predictor',
    description='Input SDF Molecule File, Predicted RT output with a CSV File',
)
if __name__ == "__main__":
    demo.launch()