""" Copyright (c) All Rights Reserved by bowen """ import json import math import os import sys import pathlib from typing import Iterable, List import random import itertools import numpy as np import pandas as pd import tqdm import torch import torch.amp from PIL import Image # from src.data import CocoEvaluator # from src.misc import (MetricLogger, SmoothedValue, reduce_dict) # from src.solver.utils import output_to_smiles, output_to_smiles2 # from src.solver.utils import bbox_to_graph_with_charge, mol_from_graph_with_chiral # from src.misc.draw_box_utils import draw_objs # from sklearn.metrics import f1_score # from src.postprocess.abbreviation_detector import get_ocr_recognition_only # from src.postprocess.utils_dataset import CaptionRemover from skimage.measure import label ######################################add metric postprocess import rdkit from rdkit import Chem from rdkit.Chem import Draw, AllChem from rdkit.Chem import rdchem, RWMol, CombineMols from rdkit import Chem from rdkit.Chem import rdFMCS import copy from paddleocr import PaddleOCR import re from rdkit import DataStructs import matplotlib.pyplot as plt from matplotlib.patches import Rectangle, Circle from scipy.spatial import cKDTree, KDTree from rdkit.Geometry import Point3D import multiprocessing def select_longest_smiles(smiles): # 将 SMILES 以 '.' 
分割为多个部分 components = smiles.split('.') # 选择字符数最多的部分作为主结构 longest_component = max(components, key=len) return longest_component def MCS_mol(mcs): #mcs_smart = mcs.smartsString mcs_mol = Chem.MolFromSmarts(mcs.smartsString) AllChem.Compute2DCoords(mcs_mol) return mcs_mol def g_atompair_matches(pair,mcs): mcs_mol = MCS_mol(mcs) matches0 = pair[0].GetSubstructMatches(mcs_mol, useQueryQueryMatches=True,uniquify=False, maxMatches=1000, useChirality=False) matches1 = pair[1].GetSubstructMatches(mcs_mol, useQueryQueryMatches=True,uniquify=False, maxMatches=1000, useChirality=False) if len(matches0) != len(matches1): matches0=list(matches0) matches1=list(matches1) print( " g_atompair_matches noted: matcher not equal !!") if len(matches0)>len(matches1) and len(matches1) !=0: for i in range(0,len(matches0)): if i < len(matches1): pass else: ii=i % len(matches1) matches1.append(matches1[ii]) else: for i in range(0,len(matches1)): if i < len(matches0) and len(matches0): pass else: ii=i % len(matches0) matches0.append(matches0[ii]) # assert len(matches0) == len(matches1), "matcher not equal break!!" 
class CustomError(Exception):
    """A custom exception for specific errors."""
    pass


# Bond-direction name -> RDKit BondDir enum value.
bond_dirs = {'NONE': Chem.rdchem.BondDir.NONE,
             'ENDUPRIGHT': Chem.rdchem.BondDir.ENDUPRIGHT,
             'BEGINWEDGE': Chem.rdchem.BondDir.BEGINWEDGE,
             'BEGINDASH': Chem.rdchem.BondDir.BEGINDASH,
             'ENDDOWNRIGHT': Chem.rdchem.BondDir.ENDDOWNRIGHT,}
# Bond-type name -> RDKit BondType enum value (re-assigned identically below).
BONDTYPE = {'SINGLE': Chem.rdchem.BondType.SINGLE,
            'DOUBLE': Chem.rdchem.BondType.DOUBLE,
            'TRIPLE': Chem.rdchem.BondType.TRIPLE,
            'AROMATIC': Chem.rdchem.BondType.AROMATIC}
# Same mapping as `bond_dirs`, under an upper-case name.
BOND_DIRS = {'NONE': Chem.rdchem.BondDir.NONE,
             'ENDUPRIGHT': Chem.rdchem.BondDir.ENDUPRIGHT,
             'BEGINWEDGE': Chem.rdchem.BondDir.BEGINWEDGE,
             'BEGINDASH': Chem.rdchem.BondDir.BEGINDASH,
             'ENDDOWNRIGHT': Chem.rdchem.BondDir.ENDDOWNRIGHT,}
BONDDIRECT=['ENDUPRIGHT', 'BEGINWEDGE', 'BEGINDASH', 'ENDDOWNRIGHT']
# Bond order for each bond label / RDKit bond type.
BONDTYPE2ORD={
    'wdge':1,
    'dash':1,
    Chem.rdchem.BondType.SINGLE: 1,
    Chem.rdchem.BondType.DOUBLE: 2,
    Chem.rdchem.BondType.TRIPLE: 3,
    Chem.rdchem.BondType.AROMATIC: 1.5,
}
BONDTYPE={'SINGLE': Chem.BondType.SINGLE,
          'DOUBLE': Chem.BondType.DOUBLE,
          'TRIPLE': Chem.BondType.TRIPLE,
          'AROMATIC': Chem.BondType.AROMATIC}

# Candidate valences per element symbol, in the order they are tried.
VALENCES = {
    "H": [1], "Li": [1], "Be": [2], "B": [3], "C": [4], "N": [3, 5], "O": [2], "F": [1],
    "Na": [1], "Mg": [2], "Al": [3], "Si": [4], "P": [5, 3], "S": [6, 2, 4], "Cl": [1],
    "K": [1], "Ca": [2], "Br": [1], "I": [1],
    "*":[3,4,5,6],
}

ELEMENTS = [
    "H", "He", "Li", "Be", "B", "C", "N", "O", "F", "Ne",
    "Na", "Mg", "Al", "Si", "P", "S", "Cl", "Ar", "K", "Ca",
    "Sc", "Ti", "Ru", "Rh","Rn","Rf", "Cr", "Mn", "Fe", "Co", "Ni", "Cu", "Zn",
    "Ga", "Ge", "As", "Se", "Br", "Kr", "Sr", "Zr",
    "Nb", "Mo", "Tc", "Pd", "Ag", "Cd", "In", "Sn",
    "Sb", "Te", "I", "Xe", "Cs", "Ba", "La", "Ce", "Pr", "Nd",
    "Pm", "Sm", "Eu", "Gd", "Tb", "Dy", "Ho", "Er", "Tm", "Yb",
    "Lu", "Hf", "Ta", "W", "Os", "Ir", "Pt", "Au", "Hg",
    "Tl", "Pb", "Bi", "Po", "At", "Fr", "Ac", "Th",
    "Pa", "Np", "Pu", "Am", "Cm", "Bk", "Cf", "Es", "Fm",
    "Md", "No", "Lr", "Db", "Sg", "Bh", "Hs", "Mt", "Ds",
    "Cn", "Nh", "Fl", "Mc", "Lv", "Og"
]
# "Rg", "Rb", "Re", "Ra" as RGROUP in the Molscribe data
# "V", "Y","U",
# be viewed as C for paddleOCR smt ONELEMENTS ['A','J]
# "Ts" as a chemical group [S](C1=CC=C(C=C1)C)(=O)=O

# Symbols treated as R-group placeholders.  A comma was missing between "R'"
# and 'R1', which silently fused them into the single entry "R'R1".
RGROUP_SYMBOLS = ['R', "R'", 'R1', 'R2', 'R3', 'R4', 'R5', 'R6', 'R7', 'R8', 'R9', 'R10', 'R11', 'R12',
                  'Ra', 'Rb', 'Rc', 'Rd','Re','Rg',
                  'X', 'Y', 'Z', 'Q', 'A', 'E', 'Ar',
                  "V", "Y","U",'M', 'G','L',
                  'Nr','Tt','Uu','Vv','Ww',  # CLEF Nr is not in periodic table
                  'D',  # CLEF as [2H] but not recognised by rdkit chemdraw
                  ]

# Single-letter colour code -> "r,g,b" string.
COLORS = {
    u'c': '0.0,0.75,0.75',
    u'b': '0.0,0.0,1.0',
    u'g': '0.0,0.5,0.0',
    u'y': '0.75,0.75,0',
    u'k': '0.0,0.0,0.0',
    u'r': '1.0,0.0,0.0',
    u'm': '0.75,0,0.75'
}


class Substitution(object):
    '''Define common substitutions for chemical shorthand'''

    def __init__(self, abbrvs, smarts, smiles, probability):
        # `abbrvs` must be a list of spellings that all map to the same group.
        assert type(abbrvs) is list
        self.abbrvs = abbrvs            # list of abbreviation strings
        self.smarts = smarts            # stored as given; not read in this part of the file
        self.smiles = smiles            # SMILES returned when the abbreviation is expanded
        self.probability = probability  # stored as given; not read in this part of the file


# When one abbreviation appears in several entries, the later entry wins in
# ABBREVIATIONS (built below).
SUBSTITUTIONS: List[Substitution] = [
    # abbrvs, smarts, smiles
    # patch4 USPTO, try put the longer one first, as re use match by order
    Substitution(['CH2CH2NSO2CH3'], '[CH2][CH]', '[CH2]CNS(=O)(C)=O', 0.5),
    Substitution(['NHNHCOCF3'], 'NHNHCOCF3', '[NH]NC(=O)C(F)(F)(F)', 0.5),
    Substitution(['CO2CysPr'], 'CO2CysPr', '[C](=O)ON[C@H](C(CCC)=O)CS', 0.5),
    Substitution(['OCH2CHOHCH2'], 'OCH2CHOHCH2', '[O]CC(O)C', 0.5),
    Substitution(['OCH2CHOHCH2OH'], 'OCH2CHOHCH2', '[O]CC(O)CO', 0.5),
    # elif symbol in ['SO2(CH2)3SO2NHCH2CHCH2OH']:smiles='[S](=O)(=O)CCCS(=O)(=O)NC[C]CO'
    Substitution(['SO2(CH2)3SO2NHCH2CHCH2OH'], 'OCH2CHOHCH2', '[S](=O)(=O)CCCS(=O)(=O)NC[C]CO', 0.5),
    Substitution(['NO2', 'O2N'], '[N+](=O)[O-]', "[N+](=O)[O-]", 0.5),
    # Substitution(['CHO', 'OHC'], '[CH1](=O)', "[CH1](=O)", 0.5),
    Substitution(['CO2Et', 'COOEt'], 'C(=O)[OH0;D2][CH2;D2][CH3]', "[C](=O)OCC", 0.5),
    Substitution(['OAc','AcO'], '[OH0;X2]C(=O)[CH3]', "[O]C(=O)C", 0.7),
    Substitution(['NHAc'], '[NH1;D2]C(=O)[CH3]', "[NH]C(=O)C", 0.7),
    Substitution(['Ac'], 'C(=O)[CH3]', "[C](=O)C", 0.1),
    Substitution(['OBz','BzO'], '[OH0;D2]C(=O)[cH0]1[cH][cH][cH][cH][cH]1', "[O]C(=O)c1ccccc1", 0.7),  # Benzoyl
    Substitution(['Bz'], 'C(=O)[cH0]1[cH][cH][cH][cH][cH]1', "[C](=O)c1ccccc1", 0.2),  # Benzoyl
    Substitution(['COOBn','BnO2C'], '[OH0;D2][CH2;D2][cH0]1[cH][cH][cH][cH][cH]1', "[C](=O)OCc1ccccc1", 0.7),  # Benzyl
    Substitution(['OBn','BnO'], '[OH0;D2][CH2;D2][cH0]1[cH][cH][cH][cH][cH]1', "[O]Cc1ccccc1", 0.7),  # Benzyl
    Substitution(['Bn'], '[CH2;D2][cH0]1[cH][cH][cH][cH][cH]1', "[CH2]c1ccccc1", 0.2),  # Benzyl
    Substitution(['NHBn'], '[NH]Cc1ccccc1', "[NH]Cc1ccccc1", 0.2),  # Benzyl
    Substitution(['NBn2'], '[NH]Cc1ccccc1', "[N](Cc1ccccc1)Cc1ccccc1", 0.2),  # Benzyl
    Substitution(['NHBoc','BocHN',"BOCHN"], '[NH1;D2]C(=O)OC([CH3])([CH3])[CH3]', "[NH]C(=O)OC(C)(C)C", 0.6),
    Substitution(['NBoc'], '[NH0;D3]C(=O)OC([CH3])([CH3])[CH3]', "[NH1]C(=O)OC(C)(C)C", 0.6),
    Substitution(['Boc','BOc'], 'C(=O)OC([CH3])([CH3])[CH3]', "[C](=O)OC(C)(C)C", 0.2),
    Substitution(['Cbm'], 'C(=O)[NH2;D1]', "[C](=O)N", 0.2),
    Substitution(['Cbz'], 'C(=O)OC[cH]1[cH][cH][cH1][cH][cH]1', "[C](=O)OCc1ccccc1", 0.4),
    Substitution(['NHCbz'], 'C(=O)OC[cH]1[cH][cH][cH1][cH][cH]1', "[NH]C(=O)OCc1ccccc1", 0.4),
    Substitution(['Cy'], '[CH1;X3]1[CH2][CH2][CH2][CH2][CH2]1', "[CH1]1CCCCC1", 0.3),
    Substitution(['Fmoc'], 'C(=O)O[CH2][CH1]1c([cH1][cH1][cH1][cH1]2)c2c3c1[cH1][cH1][cH1][cH1]3',
                 "[C](=O)OCC1c(cccc2)c2c3c1cccc3", 0.6),
    Substitution(['FmocHN','FmOcHN', 'NHFmoc'], 'C(=O)O[CH2][CH1]1c([cH1][cH1][cH1][cH1]2)c2c3c1[cH1][cH1][cH1][cH1]3',
                 "[NH]C(=O)OCC1c(cccc2)c2c3c1cccc3", 0.6),
    Substitution(['Mes'], '[cH0]1c([CH3])cc([CH3])cc([CH3])1', "[c]1c(C)cc(C)cc(C)1", 0.5),
    Substitution(['OMs','MsO'], '[OH0;D2]S(=O)(=O)[CH3]', "[O]S(=O)(=O)C", 0.7),
    Substitution(['Ms'], 'S(=O)(=O)[CH3]', "[S](=O)(=O)C", 0.2),
    Substitution(['Ph'], '[cH0]1[cH][cH][cH1][cH][cH]1', "[c]1ccccc1", 0.5),
    Substitution(['PMB'], '[CH2;D2][cH0]1[cH1][cH1][cH0](O[CH3])[cH1][cH1]1', "[CH2]c1ccc(OC)cc1", 0.2),
    Substitution(['PMBN'], '[CH2;D2][cH0]1[cH1][cH1][cH0](O[CH3])[cH1][cH1]1', "[N]Cc1ccc(OC)cc1", 0.2),
    Substitution(['Py'], '[cH0]1[n;+0][cH1][cH1][cH1][cH1]1', "[c]1ncccc1", 0.1),
    # Substitution(['SEM','MES'], '[CH2;D2][CH2][Si]([CH3])([CH3])[CH3]', "[CH2]CSi(C)(C)C", 0.2),
    Substitution(['SEM','MES'], '[CH2;D2][O][CH2][CH2][Si]([CH3])([CH3])[CH3]', "[CH2]OCC[Si](C)(C)C", 0.2),  # fix above
    Substitution(['Suc'], 'C(=O)[CH2][CH2]C(=O)[OH]', "[C](=O)CCC(=O)O", 0.2),
    Substitution(['TBS'], '[Si]([CH3])([CH3])C([CH3])([CH3])[CH3]', "[Si](C)(C)C(C)(C)C", 0.5),
    Substitution(['TBZ'], 'C(=S)[cH]1[cH][cH][cH1][cH][cH]1', "[C](=S)c1ccccc1", 0.2),
    Substitution(['OTf'], '[OH0;D2]S(=O)(=O)C(F)(F)F', "[O]S(=O)(=O)C(F)(F)F", 0.7),
    Substitution(['Tf'], 'S(=O)(=O)C(F)(F)F', "[S](=O)(=O)C(F)(F)F", 0.2),
    Substitution(['TFA'], 'C(=O)C(F)(F)F', "[C](=O)C(F)(F)F", 0.3),
    Substitution(['TFAH2N'], 'C(=O)C(F)(F)F', "[NH]C(=O)C(F)(F)F", 0.3),
    Substitution(['TMS'], '[Si]([CH3])([CH3])[CH3]', "[Si](C)(C)C", 0.5),
    Substitution(['Ts'], 'S(=O)(=O)c1[cH1][cH1][cH0]([CH3])[cH1][cH1]1', "[S](=O)(=O)c1ccc(C)cc1", 0.6),  # Ts
    Substitution(['TsO','OTs'], '[O]S(C1=CC=C(C=C1)C)(=O)=O', "[O]S(C1=CC=C(C=C1)C)(=O)=O", 0.6),  # Ts
    Substitution(['COCH3'], '[OH0;D2][CH3;D1]', "[C](=O)C", 0.3),
    # Alkyl chains
    Substitution(['OMe', 'MeO','H;CO', 'CH3O','OCH3', 'H3CO'], '[OH0;D2][CH3;D1]', "[O]C", 0.3),
    Substitution(['SMe', 'MeS'], '[SH0;D2][CH3;D1]', "[S]C", 0.3),
    Substitution(['NMe', 'MeN'], '[N;X3][CH3;D1]', "[N]C", 0.3),  # modified as [NH] not wanted
    Substitution(['NMe2', 'Me2N'], '[N;X3](C)[CH3;D1]', "[N](C)C", 0.3),  # modified as [NH] not wanted
    Substitution(['Me'], '[CH3;D1]', "[CH3]", 0.1),
    Substitution(['OEt', 'EtO','C2H5O','OC2H5'], '[OH0;D2][CH2;D2][CH3]', "[O]CC", 0.5),
    Substitution(['MeOH2C','CH2OMe'], '[CH2;D2]O[CH3]', "[CH2]OC", 0.5),
    Substitution(['Et', 'CH2CH3','CH3CH2'], '[CH2;D2][CH3]', "[CH2]C", 0.3),
    Substitution(['Pr', 'nPr', 'n-Pr'], '[CH2;D2][CH2;D2][CH3]', "[CH2]CC", 0.3),
    Substitution(['Bu', 'nBu', 'n-Bu'], '[CH2;D2][CH2;D2][CH2;D2][CH3]', "[CH2]CCC", 0.3),
    # Substitution(['nBu', 'n-Bu'], '[CH2;D2][CH2;D2][CH2;D2][CH3]', "[CH2]CCC", 0.3),
    # Branched
    Substitution(['iPr', 'i-Pr'], '[CH1;D3]([CH3])[CH3]', "[CH1](C)C", 0.2),
    Substitution(['iBu', 'i-Bu'], '[CH2;D2][CH1;D3]([CH3])[CH3]', "[CH2]C(C)C", 0.2),
    Substitution(['OiBu'], '[OH0;D2][CH2;D2][CH1;D3]([CH3])[CH3]', "[O]CC(C)C", 0.2),
    Substitution(['OtBu','tBuO'], '[OH0;D2][CH0]([CH3])([CH3])[CH3]', "[O]C(C)(C)C", 0.6),
    Substitution(['tBu', 't-Bu'], '[CH0]([CH3])([CH3])[CH3]', "[C](C)(C)C", 0.3),
    # Other shorthands (MIGHT NOT WANT ALL OF THESE)
    Substitution(['CF3', 'F3C'], '[CH0;D4](F)(F)F', "[C](F)(F)F", 0.5),
    Substitution(['NCF3', 'F3CN'], '[N;X3][CH0;D4](F)(F)F', "[NH]C(F)(F)F", 0.5),
    Substitution(['OCF3', 'F3CO'], '[OH0;X2][CH0;D4](F)(F)F', "[O]C(F)(F)F", 0.5),
    Substitution(['OCCl3', 'Cl3CO'], '[OH0;X2][CH0;D4](Cl)(Cl)Cl', "[O]C(Cl)(Cl)Cl", 0.5),
    Substitution(['SCF3', 'F3CS'], '[SH0;X2][CH0;D4](F)(F)F', "[S]C(F)(F)F", 0.5),
    Substitution(['CCl3'], '[CH0;D4](Cl)(Cl)Cl', "[C](Cl)(Cl)Cl", 0.5),
    Substitution(['CO2H', 'HO2C', 'COOH'], 'C(=O)[OH]', "[C](=O)O", 0.5),  # COOH
    Substitution(['CO2NH4','COONH4','H4NOOC','H4NO2C'], 'C(=O)[OH]', "[C](=O)ON", 0.5),  # COOH
    Substitution([ 'COO-','CO2-'], 'C(=O)[OH]', "[C](=O)[O-]", 0.5),  # COOH
    # Substitution([ 'COO'], 'C(=O)[OH]', "[C](=O)O", 0.5), # COOH
    Substitution(['CN', 'NC'], 'C#[ND1]', "[C]#N", 0.5),
    # Substitution(['OCH3', 'H3CO'], '[OH0;D2][CH3]', "[O]C", 0.4),
    # TODO if need just add it here
    Substitution(['N3'], '[N]=[N+]=[N-]', "[N]=[N+]=[N-]", 0.4),  # ACS image dataset has
    # [N-]=[N+]
    Substitution(['N2+Cl-','Cl-N2+'], '[N+]#[N].[Cl-]', "[N+]#[N].[Cl-]", 0.4),  # ACS image dataset has
    Substitution(['N2'], '[N]=[N-]', "[N]=[N-]", 0.4),  # ACS image dataset has
    Substitution(['N2H'], '[N]=[N-]', "[N]=[NH]", 0.4),  # ACS image dataset has
    Substitution(['NO','N=O','O=N','ON'], '[N]=[O]', "[N]=O", 0.4),  # ACS image dataset has
    Substitution(['NCH3'], '[N]C', "[NH]C", 0.4),  # ACS image dataset has
    Substitution(['NOMe'], '[N]OC', "[N]OC", 0.4),  # ACS image dataset has
    Substitution(['OCH2'], '[O]C', "[O]C", 0.4),  # FORMULA_REGEX
    Substitution(['C=O','O=C'], '[C]=[O]', "[C]=O", 0.4),  # ACS image dataset has
    Substitution(['NPh','PhN'], 'NC1=CC=CC=C1', "[N]C1=CC=CC=C1", 0.4),  # ACS image dataset has
    Substitution(['NHPh','PhNH','PhHN'], 'NC1=CC=CC=C1', "[NH]C1=CC=CC=C1", 0.4),  # ACS image dataset has
    Substitution(['TMSO','OSMT'], 'O[Si](C)(C)C', "[O][Si](C)(C)C", 0.5),
    Substitution(['SPh','PhS'], 'SC1=CC=CC=C1', "[S]C1=CC=CC=C1", 0.4),  # ACS image dataset has
    Substitution(['SO3H'], 'S(=O)(=O)[OH]', "[S](=O)(=O)O", 0.4),
    Substitution(['SO3NH2','SO3NH4','H4NO3S'], 'S(=O)(=O)[OH]', "[S](=O)(=O)ON", 0.4),
    Substitution(['SO3'], 'S(=O)(=O)[OH]', "[S](=O)(=O)[O-]", 0.4),
    Substitution(['SO2CF3'], '[S](=O)(=O)C(F)(F)F', "[S](=O)(=O)C(F)(F)F", 0.5),
    Substitution(['SO2Cl'], '[S](=O)(=O)Cl', "[S](=O)(=O)Cl", 0.5),
    Substitution(['SO2F'], '[S](=O)(=O)F', "[S](=O)(=O)F", 0.5),
    Substitution(['SO2'], '[S](=O)(=O)', "[S](=O)(=O)", 0.5),
    Substitution(['SO2NH'], '[S](=O)(=O)[N]', "[S](=O)(=O)[N]", 0.5),  # US07323045-20080129-C00062 may lead wrong connext
    Substitution(['SO2NH2'], '[S](=O)(=O)[NH2]', "[S](=O)(=O)[NH2]", 0.5),
    Substitution(['SO2Me','SO2CH3'], '[S](=O)(=O)C', "[S](=O)(=O)C", 0.5),
    Substitution(['NHO2S'], '[S](=O)(=O)[N]', "[N][S](=O)(=O)", 0.5),  # US07323045-20080129-C00062 may lead wrong connext
    Substitution(['OSO2Me'], '[O]S(=O)(=O)C', "[O]S(=O)(=O)C", 0.5),
    Substitution(['NHSO2Me'], '[NH]S(=O)(=O)C', "[NH]S(=O)(=O)C", 0.5),
    Substitution(['SOCH3','SOMe'], '[S](=O)(=O)', "[S](=O)C", 0.5),
    Substitution(['P+Ph3Br-'], '[P+](C1=CC=CC=C1)(C2=CC=CC=C2)C3=CC=CC=C3', "[P+](C1=CC=CC=C1)(C2=CC=CC=C2)C3=CC=CC=C3", 0.5),
    Substitution(['N+Ph3Br-'], '[N+](C1=CC=CC=C1)(C2=CC=CC=C2)C3=CC=CC=C3', "[N+](C1=CC=CC=C1)(C2=CC=CC=C2)C3=CC=CC=C3", 0.5),
    Substitution(['PPh2'], "[P](C1=CC=CC=C1)C2=CC=CC=C2", "[P](C1=CC=CC=C1)C2=CC=CC=C2", 0.5),
    # Substitution(['BOcHN',"BOCHN"], "[NH]C(OC(C)(C)C)=O", "[NH]C(OC(C)(C)C)=O", 0.5),
    Substitution(['CO2Me', 'COOMe'], 'C(=O)[OH0;D2][CH3]', "[C](=O)OC", 0.5),
    Substitution(['ONa', 'NaO'], '[O][Na]', "[O][Na]", 0.5),
    Substitution(['OTBDMS', 'TBDMSO'], "[O][Si](C)(C)C(C)(C)C", "[O][Si](C)(C)C(C)(C)C", 0.5),
    Substitution(['CONH2'], '[C](O)(N)', "[C](=O)[NH2]", 0.5),
    Substitution(['NHNH2'], '[NH2;D1]', "[NH]N", 0.1),
    Substitution(['CONH'], 'CONH', '[C](=O)N', 0.5),
    Substitution(['CH3CONH'], '[NH]C(=O)C', '[NH]C(=O)C', 0.5),
    Substitution(['NH3Cl'], '[NH]Cl', '[NH]Cl', 0.5),
    Substitution(['SAc','AcS'], '[S]C(C)=O', "[S]C(C)=O", 0.5),
    Substitution(['OAll'], '[O]CC=C', '[O]CC=C', 0.5),
    # Substitution(['Tos'], '[Si](C)(C)C', '[Si](C)(C)C', 0.5),#NOTE different case ?? @@acs dataset ,we use the SO2here
    Substitution(['Tos','TOs'], '[Si](C)(C)C', '[S](=O)(=O)C(C=C1)=CC=C1C', 0.5),  # NOTE different case ??
    Substitution(['OTos','OTOs','soTO'], '[Si](C)(C)C', '[O]S(=O)(=O)C(C=C1)=CC=C1C', 0.5),  # NOTE different case ??
    Substitution(['TsN'], '[N]S(C1=CC=C(C=C1)C)(=O)=O', '[N]S(C1=CC=C(C=C1)C)(=O)=O', 0.5),
    Substitution(['Ts'], '[S](C1=CC=C(C=C1)C)(=O)=O', '[S](C1=CC=C(C=C1)C)(=O)=O', 0.5),
    Substitution(['COCF3'], '[C](=O)C(F)(F)(F)', '[C](=O)C(F)(F)(F)', 0.5),
    Substitution(['CF2', 'F2C'], '[C;D4](F)(F)', "[C](F)(F)", 0.5),
    Substitution(['PMB'], '[CH2]C1=CC=C(C=C1)OC', '[CH2]C1=CC=C(C=C1)OC', 0.5),
    Substitution(['NHCOtBu'], '[NH]C(C(C)(C)C)=O','[NH]C(C(C)(C)C)=O', 0.5),
    Substitution(['OCN'], '[N]=C=O', "[N]=C=O", 0.5),
    Substitution(['Me3Si'], '[Si](C)(C)(C)', "[Si](C)(C)(C)", 0.5),
    Substitution(['PhO','OPh'], '[O]C1=CC=CC=C1', "[O]C1=CC=CC=C1", 0.5),
    Substitution(['Allyl'], '[CH2]C=C', '[CH2]C=C', 0.5),
    Substitution(['C7H3'], '[C]#CC#CC#CC', '[C]#CC#CC#CC', 0.5),
    Substitution(['C5H11'], '[CH2]CCCC', '[CH2]CCCC', 0.5),
    Substitution(['R1R2N'], "[N]([*])[*]", "[N]([*])[*]", 0.5),
    Substitution(['CO2R'], '[C](=O)O*', '[C](=O)O*', 0.5),
    Substitution(['CCl3CH2O2C'], '[C](=O)OCC(Cl)(Cl)Cl', '[C](=O)OCC(Cl)(Cl)Cl', 0.5),
    Substitution(['NHOH'], '[NH]O', '[NH]O', 0.5),
    Substitution(['CO2'], '[C](=O)[O]', '[C](=O)[O]', 0.5),
    Substitution(['O2C'], '[C](=O)[O]', '[O][C](=O)', 0.5),  # NOTE the direction matters
    Substitution(['PPh3'], '[P](C1=CC=CC=C1)(C2=CC=CC=C2)C3=CC=CC=C3', '[P](C1=CC=CC=C1)(C2=CC=CC=C2)C3=CC=CC=C3', 0.5),
    Substitution(['TfO'], '[C](=O)[O]', '[O]S(=O)(C(F)(F)F)=O', 0.5),
    Substitution(['OCH2Ph'], '[O]CC1=CC=CC=C1', '[O]CC1=CC=CC=C1', 0.5),
    Substitution(['OCH2CF3'], '[O]CC(F)(F)(F)', '[O]CC(F)(F)(F)', 0.5),
    Substitution(['COOCH2Ph'], '[C](=O)OCC1=CC=CC=C1', '[C](=O)OCC1=CC=CC=C1', 0.5),
    Substitution(['OCH2OC2H5'], '[C](=O)C(C)(C)C', '[O]COCC', 0.5),
    Substitution(['Trt'], '[C](C1=CC=CC=C1)(C2=CC=CC=C2)C3=CC=CC=C3', '[C](C1=CC=CC=C1)(C2=CC=CC=C2)C3=CC=CC=C3', 0.5),
    Substitution(['SF5'], '[S](F)(F)(F)(F)F', '[S](F)(F)(F)(F)F', 0.5),
    # Substitution(['CH2CH'], '[CH2][CH]', '[CH2][CH]', 0.5),
    # Substitution(['CH2CH2'], '[CH2][CH2]', '[CH2][CH2]', 0.5),
    # SIMPLE abbv
    Substitution(['S*'], '[S]*', '[S]*', 0.5),
    # Was ['N*, NH*'] -- one string holding both spellings, so neither 'N*' nor
    # 'NH*' could ever be looked up.
    Substitution(['N*', 'NH*'], '[NH]*', '[NH]*', 0.5),
    Substitution(['C*','CH2*'], '[C]*', '[CH2]*', 0.5),
    Substitution(['P*',"PH*"], '[P]*', '[PH]*', 0.5),
    Substitution(['O*'], '[O]*', '[O]*', 0.5),
    # () effect
    Substitution(['N(CH3)2'], '[N](C)(C)', "[N](C)(C)", 0.5),
    Substitution(['(C2H5)2N','Et2N'], '[N](C)(C)', "[N](CC)(CC)", 0.5),
    Substitution(['B(OH)2'], '[B](O)O', "[B](O)O", 0.5),
    Substitution(['CO2C(CH3)3'], '[C](=O)C(C)(C)C', '[C](=O)C(C)(C)C', 0.5),
    Substitution(['P(O)(OEt)2', 'P(OEt)2(O)'], "[P](OCC)(=O)CCO", "[P](OCC)(=O)OCC", 0.5),
    Substitution(['(CH2)16Me'], '[CH2]CCCCCCCCCCCCCCCC', "[CH2]CCCCCCCCCCCCCCCC", 0.3),
    Substitution(['(CH2)11Me'], '[CH2]CCCCCCCCCCC', "[CH2]CCCCCCCCCCC", 0.3),
    Substitution(['N(H)Et','Et(H)N'], '[NH]CC', '[NH]CC', 0.5),
    Substitution(['N(H)Me','Me(H)N'], '[NH]C', '[NH]C', 0.5),
]

# Abbreviation string -> its Substitution (later entries override earlier ones).
ABBREVIATIONS = {abbrv: sub for sub in SUBSTITUTIONS for abbrv in sub.abbrvs}
def extract_abbreviation_key(item):
    """Unwrap nested lists by repeatedly taking element 0.

    Non-list values are returned unchanged.  For a list, the first element is
    followed until a non-list value is reached.  An empty list at any depth
    yields ``None`` (this used to raise IndexError).
    """
    while isinstance(item, list):
        if not item:
            return None
        item = item[0]
    return item


def clean_unpaired_brackets(text):
    """Drop closing brackets that have no matching opener.

    Only ``()`` and ``[]`` are considered.  A ``)`` or ``]`` is kept when the most
    recent still-open bracket is its partner and dropped otherwise.  Opening
    brackets are always kept, even when they are never closed.

    Returns:
        The filtered string.
    """
    result = []
    stack = []  # currently open brackets, innermost last
    bracket_pairs = {')': '(', ']': '['}
    opening_brackets = {'(', '['}

    for char in text:
        if char in opening_brackets:
            stack.append(char)
            result.append(char)
        elif char in bracket_pairs:
            if stack and stack[-1] == bracket_pairs[char]:
                stack.pop()
                result.append(char)
            # else: unpaired closing bracket -- skip it
        else:
            result.append(char)
    return ''.join(result)


# NOTE(review): a commented-out draft `del_unpairebrackets`, meant to also strip
# unclosed opening brackets, used to follow here.  It was never active code.
def replace_c1(text):
    """Replace an OCR-misread 'C1' with 'Cl' unless another digit follows it."""
    # Negative lookahead: 'C1' must not be followed by another digit.
    return re.sub(r'C1(?!\d)', 'Cl', text)


def transform_formula(formula):
    """Rewrite ``C<n>...Hg...`` as ``C<n>...H<2n+1>...``.

    The formula must start with 'C' and a number; the first 'Hg' after it is
    replaced by 'H' followed by ``2 * n + 1``.  Formulas that do not match are
    returned unchanged.
    """
    match = re.match(r'C(\d+)(.*?)Hg(.*)', formula)
    if not match:
        return formula
    n = int(match.group(1))
    prefix = match.group(2)  # text between the carbon count and 'Hg' (may be empty)
    suffix = match.group(3)  # text after 'Hg', e.g. 'O2'
    g_new = n * 2 + 1
    return f"C{n}{prefix}H{g_new}{suffix}"


def Cg_transform_formula(formula):
    """Rewrite ``CgH<n>...`` as ``C<(n-1)//2>H<n>...``.

    The formula must start with 'CgH' and a number.  Everything after the
    hydrogen count is kept; the previous pattern ended in a lazy ``(.*?)``,
    which always matched the empty string and silently dropped that suffix.
    Formulas that do not match are returned unchanged.
    """
    match = re.match(r'CgH(\d+)(.*)', formula)
    if not match:
        return formula
    n = int(match.group(1))
    suffix = match.group(2)  # text after the hydrogen count, e.g. 'O2'
    g_new = (n - 1) // 2
    return f"C{g_new}H{n}{suffix}"


# NOTE(review): `normalize_ocr_text` is cut off in SOURCE part-way through a
# regex string literal, so it cannot be restored here.  The visible text is
# preserved verbatim below; the remainder of the function is missing.
#
# def normalize_ocr_text(text, replacement_map):
#     """Normalize OCR text using the predefined mapping rules"""
#     if 'C1'in text:
#         text=replace_c1(text)
#     if 'Hg' in text:
#         text= transform_formula(text)
#     if 'Cg' in text:
#         text= Cg_transform_formula(text)
#     if 'Q' in text:
#         pattern = r'Q([A-Z])(\w+)'
#         replacement = r'O\1\2'
#         text = re.sub(pattern, replacement, text)
#     if text in ELEMENTS:
#         return text
#     #remove space
#     if ' ' in text:
#         text = text.replace(" ", "")
#     if any(c in text for c in '0oO'):
#         # Step 1: Replace 'o' or 'O' with '0' when after a digit and before a letter or end of string
#         # text = re.sub(r'(?<=[1-9])[oO](?=[a-zA-GI-Z]|$)', '0', text)
#         text = re.sub(r'(?      <-- SOURCE ends here
# NOTE(review): SOURCE is missing text immediately before this point.  What
# follows is the tail of a definition whose `def` line and the start of its
# docstring are not visible, so it is preserved verbatim, not rebuilt:
#
#     [('C', 2), ('H', 4), ('O', 1)] """
#     tokens = FORMULA_REGEX.findall(formula)
#     # if ''.join(tokens) != formula:
#     #     tokens = FORMULA_REGEX_BACKUP.findall(formula)
#     return _parse_tokens(tokens)


def _expand_carbon(elements: list):
    """Flatten ``(element, count)`` pairs into an ordered list of atoms.

    A carbon run ``C<a>X<b>`` with ``a > 1`` spreads the ``b`` atoms of the next
    element evenly over the carbons; any remainder is appended after the last
    carbon.  A nested list element is expanded recursively and the resulting
    list is appended ``count`` times.  Pairs with a count above 100000 are
    skipped.

    Example: [('C', 2), ('H', 4), ('O', 1)] -> ['C', 'H', 'H', 'C', 'H', 'H', 'O']

    Returns:
        The expanded list, or ``False`` when nothing was produced.
    """
    expanded = []
    i = 0
    while i < len(elements):
        elt, num = elements[i]
        # skip unreasonable number of atoms
        if num > 100000:
            i += 1
            continue
        # expand carbon sequence
        if elt == 'C' and num > 1 and i + 1 < len(elements):
            next_elt, next_num = elements[i + 1]
            if next_num > 100000:
                # Only the carbon entry is skipped here; the oversized follower
                # is then dropped by the check at the top of the loop.
                i += 1
                continue
            quotient, remainder = next_num // num, next_num % num
            for _ in range(num):
                expanded.append('C')
                for _ in range(quotient):
                    expanded.append(next_elt)
            for _ in range(remainder):
                expanded.append(next_elt)
            i += 2
        # recurse if `elt` itself is a list (nested formula)
        elif isinstance(elt, list):
            new_elt = _expand_carbon(elt)
            for _ in range(num):
                expanded.append(new_elt)
            i += 1
        # simplest case: simply append `elt` `num` times
        else:
            for _ in range(num):
                expanded.append(elt)
            i += 1
    if not expanded:
        return False
    return expanded


def replace_bracket(match):
    """``re.sub`` callback deciding how one ``[...]`` group is rewritten.

    ``match.group(1)`` is the text inside the brackets.  Intended for use as
    ``re.sub(r'\\[([^\\[\\]]+)\\]', replace_bracket, smi)``.

    Rules, in order:
      1. content contains a digit, '+' or '-': keep ``[content]`` unchanged;
      2. content is exactly 'H': keep ``[H]``;
      3. content has two or more characters and contains 'H': drop the brackets
         and every 'H' character;
      4. otherwise: drop the brackets only.
    """
    content = match.group(1)
    if re.search(r'\d|\+|-', content):
        return f'[{content}]'
    elif content == 'H':
        return '[H]'
    elif len(content) >= 2 and 'H' in content:
        return ''.join([ch for ch in content if ch != 'H'])
    else:
        return content


def formula_regex(abbrev):
    """Tokenise ``abbrev`` and expand it into an atom list (MolScribe style).

    Uses the module-level ``FORMULA_REGEX`` and ``_parse_tokens``, then
    ``_expand_carbon``.

    Returns:
        The expanded atom list, or ``False`` when the expansion is empty.
    """
    tokens = FORMULA_REGEX.findall(abbrev)
    abbrev_exp = _expand_carbon(_parse_tokens(tokens))
    # _expand_carbon reports "nothing" as False, never as [], so test truthiness.
    if not abbrev_exp:
        return False
    return abbrev_exp
def _expand_abbreviationMS(abbrev):
    """Expand one symbol of a condensed formula into SMILES.

    Used by ``_condensed_formula_list_to_smiles``.  Lookup order:
      * a key of ``ABBREVIATIONS``: that entry's ``smiles``;
      * an R-group symbol (listed in ``RGROUP_SYMBOLS``, or a listed first
        character followed only by digits): ``[<digits>*]`` when everything
        after the first character is digits, else ``*``;
      * anything else: the symbol wrapped in brackets.
    """
    if abbrev in ABBREVIATIONS:
        return ABBREVIATIONS[abbrev].smiles
    # abbrev[:1] instead of abbrev[0] so an empty symbol does not raise IndexError.
    if abbrev in RGROUP_SYMBOLS or (abbrev[:1] in RGROUP_SYMBOLS and abbrev[1:].isdigit()):
        if abbrev[1:].isdigit():
            return f'[{abbrev[1:]}*]'
        return '*'
    return f'[{abbrev}]'


def _get_bond_symb(bond_num):
    """Return the SMILES bond symbol for a bond order.

    0 -> '.', 1 -> '', 2 -> '=', 3 -> '#'.  Any other value prints a warning and
    returns ''.
    """
    if bond_num == 0:
        return '.'
    elif bond_num == 1:
        return ''
    elif bond_num == 2:
        return '='
    elif bond_num == 3:
        return '#'
    else:
        print(f"check this val {bond_num} !!!" )
        return ''


def _condensed_formula_list_to_smiles(formula_list, start_bond, end_bond=None, direction=None):
    """Convert a condensed formula, given as a list of symbols, to SMILES.

    Args:
        formula_list: e.g. ['C', 'H', 'H', 'N', ['C', 'H', 'H', 'H'], ['C', 'H', 'H', 'H']]
            for CH2N(CH3)2; nested lists are nested condensed formulas.
        start_bond: number of bonds attached to the beginning of the formula.
        end_bond: number of bonds attached to the end (not checked when None).
        direction: 1 (left to right), -1 (right to left) or None (try 1, then -1).

    Returns:
        ``(smiles, bonds_left, num_trials, success)``.  ``bonds_left`` is the
        number of bonds remaining at the end of the formula.  When ``direction``
        is None and both directions fail, ``smiles`` and ``bonds_left`` are None.
    """
    # `direction` not specified: try left to right; if that fails, right to left
    if direction is None:
        num_trials = 1
        for dir_choice in [1, -1]:
            smiles, bonds_left, trials, success = _condensed_formula_list_to_smiles(
                formula_list, start_bond, end_bond, dir_choice)
            num_trials += trials
            if success:
                return smiles, bonds_left, num_trials, success
        return None, None, num_trials, False
    assert direction == 1 or direction == -1

    def dfs(smiles, bonds_left, cur_idx, add_idx):
        """Depth-first search over valence choices.

        smiles: SMILES string built so far.
        bonds_left: bonds still free on the current atom.
        cur_idx: index in `formula_list` of the current atom (the one that
            subsequent atoms attach to).
        add_idx: index in `formula_list` of the atom to attach next.
        An "atom" may itself be a nested condensed formula (CH3 in CH2N(CH3)2).
        """
        num_trials = 1
        # end of formula: return result
        if (direction == 1 and add_idx == len(formula_list)) or (direction == -1 and add_idx == -1):
            if end_bond is not None and end_bond != bonds_left:
                return smiles, bonds_left, num_trials, False
            return smiles, bonds_left, num_trials, True

        # no more bonds but there are atoms remaining: conversion failed
        if bonds_left <= 0:
            return smiles, bonds_left, num_trials, False
        to_add = formula_list[add_idx]  # atom to be added to current atom

        if isinstance(to_add, list):
            # "atom" added is a nested condensed formula: assume valence of 1
            if bonds_left > 1:
                # it does not use up the remaining bonds of the current atom
                add_str, val, trials, success = _condensed_formula_list_to_smiles(to_add, 1, None, direction)
                if val > 0:
                    add_str = _get_bond_symb(val + 1) + add_str
                num_trials += trials
                if not success:
                    return smiles, bonds_left, num_trials, False
                # parenthesise it as a branch and stay on the current atom
                result = dfs(smiles + f'({add_str})', bonds_left - 1, cur_idx, add_idx + direction)
            else:
                # it uses up the remaining bonds and becomes the new current atom
                add_str, bonds_left, trials, success = _condensed_formula_list_to_smiles(to_add, 1, None, direction)
                num_trials += trials
                if not success:
                    return smiles, bonds_left, num_trials, False
                result = dfs(smiles + add_str, bonds_left, add_idx, add_idx + direction)
            smiles, bonds_left, trials, success = result
            num_trials += trials
            return smiles, bonds_left, num_trials, success

        # Neither a symbol nor a nested formula (e.g. the False that
        # `_expand_carbon` returns for an empty expansion): give up.  This guard
        # used to sit above the list branch, which made that branch unreachable.
        if not isinstance(to_add, str):
            return smiles, bonds_left, num_trials, False

        # atom added is a single symbol: try each of its possible valences
        for val in VALENCES.get(to_add, [1]):
            add_str = _expand_abbreviationMS(to_add)  # expand if the symbol is an abbreviation
            if bonds_left > val:
                # does not use up the current atom's bonds: add as a branch
                if cur_idx >= 0:
                    add_str = _get_bond_symb(val) + add_str
                result = dfs(smiles + f'({add_str})', bonds_left - val, cur_idx, add_idx + direction)
            else:
                # uses up the current atom's bonds: it becomes the new current atom
                if cur_idx >= 0:
                    add_str = _get_bond_symb(bonds_left) + add_str
                result = dfs(smiles + add_str, val - bonds_left, add_idx, add_idx + direction)
            trials, success = result[2:]
            num_trials += trials
            if success:
                return result[0], result[1], num_trials, success
            if num_trials > 10000:
                break
        return smiles, bonds_left, num_trials, False

    cur_idx = -1 if direction == 1 else len(formula_list)
    add_idx = 0 if direction == 1 else len(formula_list) - 1
    return dfs('', start_bond, cur_idx, add_idx)
def swap_paren_bracket(text):
    """Turn a leading ``(A)[B]`` into ``[B](A)``.

    Strings that do not start with '(' or do not begin with a ``(...)[...]``
    pair are returned unchanged.  Text after the swapped pair is kept; it used
    to be discarded.
    """
    if not text.startswith('('):
        return text
    # (...) immediately followed by [...], anchored at the start of the string
    pattern = r'^\((.*?)\)\[(.*?)\]'
    match = re.match(pattern, text)
    if match:
        return f'[{match.group(2)}]({match.group(1)})' + text[match.end():]
    return text
def convert_ch2_string(s):
    """Expand ``(CH2)<n>`` into a SMILES carbon chain.

    ``(CH2)1`` -> '[CH2]'; ``(CH2)<n>`` -> '[CH2]' followed by ``n - 1`` 'C'.
    ``(CH2)<letters>`` (a variable count such as '(CH2)m') is printed and
    returned unchanged, as is any string that is not exactly of this form.
    """
    pattern = r'\(CH2\)(\d+|[a-zA-Z]+)'
    match = re.fullmatch(pattern, s)
    if not match:
        return s

    suffix = match.group(1)
    if suffix.isdigit():
        n = int(suffix)
        if n == 1:
            return '[CH2]'
        return '[CH2]' + 'C' * (n - 1)
    # variable repeat count, e.g. (CH2)m: cannot be expanded
    var = suffix
    print(var,s)
    return s


def process_string_joinused(s):
    """Strip hydrogens from a leading bracket atom.

    If ``s`` starts with ``[...]`` whose content is longer than one character
    and contains 'H', every 'H' together with the digits directly after it is
    removed from that content.  Otherwise ``s`` is returned unchanged.
    """
    match = re.match(r'^\[([^\]]*)\](.*)$', s)
    if not match:
        return s
    content, rest = match.groups()
    if len(content) > 1 and 'H' in content:
        new_content = re.sub(r'H\d*', '', content)
        return f'[{new_content}]{rest}'
    return s


def all_elements_in_dict(lst, dictionary):
    """Check recursively that every leaf of a (possibly nested) list is a key.

    :param lst: list to check; may contain nested lists
    :param dictionary: container tested with ``in``
    :return: True if every non-list element is in ``dictionary``, else False
    """
    for element in lst:
        if isinstance(element, list):
            if not all_elements_in_dict(element, dictionary):
                return False
        elif element not in dictionary:
            return False
    return True


def expand_cf2_to_smiles(input_string):
    """Expand ``(CF2)<n><X>`` into SMILES.

    ``n`` CF2 units are written as '[C](F)(F)' followed by ``n - 1`` times
    'C(F)(F)', then the tail group ``X`` in brackets.  ``X`` must start with a
    letter: previously a purely numeric count such as '(CF2)12' was split into
    n=1 and tail '2'.  Strings that do not match are returned unchanged.
    """
    pattern = r'\(CF2\)(\d+)([A-Za-z][A-Za-z0-9]*)'
    match = re.match(pattern, input_string)
    if not match:
        return input_string

    n = int(match.group(1))
    tail_group = f"[{match.group(2)}]"
    cf2_unit = 'C(F)(F)'
    if n > 0:
        return '[C](F)(F)' + cf2_unit * (n - 1) + tail_group
    return tail_group


def find_repeating_unit_and_smiles(s):
    """Detect a string made of one unit repeated two or more times.

    Returns:
        ``(smiles, repeat_count, unit)`` when ``s`` is a whole number of
        repetitions of its shortest repeating unit, else ``(None, 0, None)``.
        SMILES are only known for the units 'CH2', 'CF2' and 'SO2'; for any
        other unit a message is printed and ``smiles`` is ''.
        # NOTE(review): callers that only test ``repeat_count > 0`` then receive
        # an empty SMILES string -- confirm that is intended.
    """
    match = re.fullmatch(r'(.+?)(?:\1)+', s)
    if not match:
        return None, 0, None

    unit = match.group(1)
    repeat_count = len(s) // len(unit)
    if unit == "CH2":
        smiles_unit = "C"
        smi_init = "[CH2]"
    elif unit == "CF2":
        smiles_unit = "C(F)(F)"
        smi_init = "[C](F)(F)"
    elif unit == "SO2":
        smiles_unit = "S(=O)(=O)"
        smi_init = "[S](=O)(=O)"
    else:
        smiles_unit, smi_init = '', ''
        print(f'please add the repateat patter here !!! for: {s}')
    # first unit carries the attachment atom, the rest are plain repeats
    smiles = smi_init + smiles_unit * (repeat_count - 1)
    return smiles, repeat_count, unit
# Hand-maintained expansions for condensed formulas (and recurring OCR misreads
# of them) that the generic condensed-formula parser does not resolve.
# A value of ``None`` marks a symbol that is recognised but has no expansion
# yet: it still short-circuits the generic parser and yields ``None``.
# TODO: replace this table with a proper symbol -> SMILES converter.
_HARDCODED_SYMBOL_SMILES = {
    'CH2CH': '[CH2][CH]',
    'CHCH2': '[CH2][CH]',
    'PH3C': '[CH2]P',
    'l23I': '[I]',
    'HC': '[CH]',
    '(HC': '[CH]',
    'NHzBrH': '[NH2].Br',
    'NH2BrH': '[NH2].Br',
    '(C3H6O)7CH3': "[O]CCC" + "OCCC" * 6 + 'C',
    'NH2.HCl': '[NH2].Cl',
    'CH2CH2CH2CH': '[CH2]CC[CH]',
    '(CH2CH2CH2CH-)': '[CH2]CC[CH]',
    '3(CHCHCHCH272': '[CH]CC[CH2]',
    'CH3CH': '[CH]C',
    'CH2CO2CH3': '[CH2]C(=O)OC',
    'CO2C': '[C](=O)O[C]',
    'CH3CCH3': '[C](C)(C)',
    'CH3CO': '[C](=O)C',
    'CH3OCH2': '[CH2]OC',
    '(co)': '[C](=O)',
    '(CO)': '[C](=O)',
    'Ar2P(O)': '[P](*)(*)(=O)',
    'PhO2S': '[S](=O)(=O)c1ccccc1',
    'CO': '[C](=O)',
    'OC': '[C](=O)',
    'CH2O': '[CH2][O]',
    'P*Ph3': '[P+](c1ccccc1)(c1ccccc1)(c1ccccc1)',
    'P+Ph3': '[P+](c1ccccc1)(c1ccccc1)(c1ccccc1)',
    'NHP(O)Ph2': '[NH]P(=O)(c1ccccc1)c1ccccc1',
    'CH;CH2C(O)0CHCH3': '[CH2]CC(=O)OCC',
    'CH2CH2C(O)OCHCH3': '[CH2]CC(=O)OCC',
    'CH2CH2CH': '[CH2][CH2][CH]',
    'H2CH2CHC': '[CH2][CH2][CH]',
    'CHCH2CH2': '[CH2][CH2][CH]',
    'HCH2CH2CH2C': '[CH]CC[CH2]',
    'H2C': '[CH2]',
    'CH2': '[CH2]',
    'CH2CH2': '[CH2][CH2]',
    'CHO': '[CH](=O)',
    'OHC': '[CH](=O)',
    'NH2': '[NH2]',
    'H2N': '[NH2]',
    '(CF2)8H': "[C](F)(F)C(F)(F)C(F)(F)C(F)(F)C(F)(F)C(F)(F)C(F)(F)C(F)(F)",
    'CH2CH2C(O)OCH2CH3': '[CH2]CC(=O)OCC',
    'CH2CH2C(O)0CH2CH3': '[CH2]CC(=O)OCC',
    'CF3CF2CF2CF2SO3': '[S](=O)(=O)([O-])C(F)(F)C(F)(F)C(F)(F)C(F)(F)(F)',
    'S[O]a': '[S](=O)',
    'COCl': '[C](=O)Cl',
    'OCF2H': '[O]C(F)(F)',
    'CF2O': '[C](F)(F)[O]',
    'OF2C': '[O][C](F)(F)',
    'CF2CF2H': '[C](F)(F)C(F)(F)',
    'CH2CH2O': '[CH2]CO',
    'OCH2CH2OH': '[O]CCO',
    'EtO2CHN': '[N]C(=O)OCC',
    'OCH2CHOHCH2NH': '[O]CC(O)CN',
    'OCH2CHCH2CCH3': '[O]C[CH]C[C]C',
    '(H4NO)2': '[O]NON',
    'SO2NHCH2CH': '[S](=O)(=O)NC[CH]',
    'N(SO2CH3)2': '[N](S(=O)(=O)C)(S(=O)(=O)C)',
    'OCH2CH': '[O]C[CH]',
    'EtO2C': 'C(=O)OCC',
    'CH2CH2C': '[CH2]C[C]',
    'NHSO2CH3': '[NH]S(=O)(=O)C',
    'COCOOCH2CH3': 'C(=O)C(=O)OCC',
    # Recognised, but no expansion is defined for these yet.
    '(CH2)10': None,
    'OCH2CHOHCH2': None,
}


def get_smiles_from_symbol(symbol, mol, bonds):
    """Convert a symbol (abbreviation or condensed formula) to SMILES.

    Resolution order, first hit wins:

    1. exact entry in ``ABBREVIATIONS``;
    2. R-group symbols (``[n*]`` when a numeric suffix is present, else ``*``);
    3. ``convert_ch2_string``, ``expand_cf2_to_smiles`` and
       ``find_repeating_unit_and_smiles`` for repeated-unit formulas;
    4. the hand-maintained ``_HARDCODED_SYMBOL_SMILES`` table;
    5. the generic ``_condensed_formula_list_to_smiles`` parser, driven by the
       summed bond order of ``bonds``;
    6. concatenation of abbreviation SMILES when every parsed token is a
       known abbreviation.

    Args:
        symbol: Text label to expand.
        mol: Molecule that owns the placeholder atom (not read here; kept for
            interface compatibility).
        bonds: Bonds of the placeholder atom; each must provide
            ``GetBondTypeAsDouble()``.

    Returns:
        The SMILES string, or ``None`` when the symbol is empty, longer than
        20 characters, or cannot be resolved.
    """
    if symbol in ABBREVIATIONS:
        return ABBREVIATIONS[symbol].smiles
    if not symbol:
        # Nothing to expand; also keeps ``symbol[0]`` below from raising.
        return None
    if symbol in RGROUP_SYMBOLS or (symbol[0] in RGROUP_SYMBOLS and symbol[1:].isdigit()):
        if symbol[1:].isdigit():
            return f'[{symbol[1:]}*]'
        return '*'
    if len(symbol) > 20:
        return None

    smiles = convert_ch2_string(symbol)
    if smiles != symbol:
        return smiles
    if '(CF2)' in symbol:
        return expand_cf2_to_smiles(symbol)
    smiles, repeat_count, unit = find_repeating_unit_and_smiles(symbol)
    if repeat_count > 0:
        return smiles

    # Symbols the generic parser below is known to mishandle.
    if symbol in _HARDCODED_SYMBOL_SMILES:
        return _HARDCODED_SYMBOL_SMILES[symbol]

    # TODO: aromatic bonds contribute 1.5 each before truncation.
    total_bonds = int(sum(bond.GetBondTypeAsDouble() for bond in bonds))
    formula_list = _expand_carbon(_parse_formula(symbol))
    all_in_dict = all_elements_in_dict(formula_list, ABBREVIATIONS)
    # The total_bonds / bonds_left mechanism is known to be incomplete, hence
    # the special cases handled above.
    smiles, bonds_left, num_trails, success = _condensed_formula_list_to_smiles(
        formula_list, total_bonds, None)
    if success:
        return swap_paren_bracket(smiles)
    if all_in_dict:
        # Every token is an abbreviation: join their SMILES in order.
        key = extract_abbreviation_key(formula_list[0])
        if key in ABBREVIATIONS:
            smiles = ABBREVIATIONS[key].smiles
        else:
            print(f"Abbreviation {key} not found in ABBREVIATIONS.")
            smiles = ''
        for token in formula_list[1:]:
            smiles += process_string_joinused(ABBREVIATIONS[token].smiles)
        return smiles
    return None
def abbrev2smile(abbrev, abbrev_exp, mol, idx):
    """Expand the label of placeholder atom ``idx`` into a SMILES fragment.

    The bonds of the placeholder atom are handed to
    ``get_smiles_from_symbol`` together with the label.

    Args:
        abbrev: Label text to expand.
        abbrev_exp: Pre-parsed form of ``abbrev``; only echoed in the failure
            message.
        mol: Molecule containing the placeholder atom.
        idx: Index of the placeholder atom in ``mol``.

    Returns:
        The expanded SMILES, or ``'[*]'`` when expansion yields nothing.
    """
    placeholder_bonds = mol.GetAtomWithIdx(idx).GetBonds()
    sub_smi = get_smiles_from_symbol(abbrev, mol, placeholder_bonds)
    if not sub_smi:
        print(f"failed expanding {abbrev},{abbrev_exp}\n{sub_smi}\t{idx}")
        return '[*]'
    return sub_smi
def replace_cg_notation(astr):
    """Rewrite every ``CgH<n>`` in ``astr`` as ``C<m>H<n>``.

    ``m`` is derived from the hydrogen count as ``(n - 1) // 2``, i.e. the
    group is treated as an alkyl chain C(m)H(2m+1).
    """
    def replacer(match):
        h_count = int(match.group(1))
        c_count = (h_count - 1) // 2
        return f'C{c_count}H{h_count}'

    return re.sub(r'CgH(\d+)', replacer, astr)


def _expand_abbreviation(abbrev, mol, idx):
    """Expand an abbreviation into its SMILES; also converts [Rn] to [n*].

    Args:
        abbrev: Label text of the placeholder atom.
        mol: Molecule containing the placeholder atom (forwarded to
            ``abbrev2smile``).
        idx: Index of the placeholder atom in ``mol``.

    Returns:
        A SMILES fragment; ``'[*]'`` when nothing matches.
    """
    if abbrev in ABBREVIATIONS:
        return ABBREVIATIONS[abbrev].smiles

    # Pattern-based expanders, tried in priority order. Each one is evaluated
    # once and its result reused instead of being recomputed for the return.
    for expander in (N_C_H_expand, C_F_expand, C_H_expand2, C_H_expand, C_H_affixExpand):
        expanded = expander(abbrev)
        if expanded:
            return expanded

    # ``abbrev and ...`` keeps an empty label from raising on ``abbrev[0]``.
    if abbrev in RGROUP_SYMBOLS or (
            abbrev and abbrev[0] in RGROUP_SYMBOLS and abbrev[1:].isdigit()):
        if abbrev[1:].isdigit():
            return f'[{abbrev[1:]}*]'
        # An R-group symbol without a numeric suffix goes straight to the
        # generic fallback below, skipping the element / formula branches.
    elif abbrev in ELEMENTS:
        return f'[{abbrev}]'
    else:
        abbrev_exp = formula_regex(abbrev)
        if abbrev_exp:
            # Last resort: the MolScribe-style condensed-formula route.
            return abbrev2smile(abbrev, abbrev_exp, mol, idx)

    # Fallback: split an optional numeric prefix from the remainder.
    match = re.match(r'^(\d+)?(.*)', abbrev)
    if match:
        numeric_part, remaining_part = match.groups()
        if remaining_part in ELEMENTS:
            return f'[{abbrev}]'
        elif numeric_part:
            return f'[{numeric_part}*]'
        else:
            print(f"fixme !!!@@@@: {abbrev}")
    return '[*]'


def count_current_bonds(mol, atom_idx):
    """Count current bonds (including bond order) for an atom."""
    atom = mol.GetAtomWithIdx(atom_idx)
    return sum(bond.GetBondTypeAsDouble() for bond in atom.GetBonds())
Chem.MolFromSmiles(group_smiles) # 获取官能团的子分子 try: submol_rw = Chem.RWMol(submol) # 转换为可编辑的 RWMol except Exception as e: print(f"abbver: {group}") print(f'try to convert {group_smiles} to sub_mol') print(e) if debug_not: print(f"Failed to convert {group_smiles} to sub_mol, using placeholder [*] instead.") submol = Chem.MolFromSmiles('[*]') submol_rw = Chem.RWMol(submol) else: raise e#NOTE use it when debugging with adding abber and fixing rules in det_engine.py # 1. 识别 submol 的 anchor atoms(连接点) anchor_atoms = [0]#always use the fisrt atom as anchor atom for atom in submol_rw.GetAtoms(): # 具有自由基的原子或标记为连接点的原子(例如 [*]) if atom.GetNumRadicalElectrons() > 0 and atom.GetIdx() not in anchor_atoms:# or atom.GetSymbol() == '*': anchor_atoms.append(atom.GetIdx()) # 2. 复制主分子 new_mol = Chem.RWMol(mol) placeholder_idx = idx # 3. 记录 placeholder (*) 原子的邻居及其键类型 bonds_info = [] for bond in new_mol.GetBonds(): if bond.GetBeginAtomIdx() == placeholder_idx: bonds_info.append({ "neighbor": bond.GetEndAtomIdx(), "bond_type": bond.GetBondType() }) elif bond.GetEndAtomIdx() == placeholder_idx: bonds_info.append({ "neighbor": bond.GetBeginAtomIdx(), "bond_type": bond.GetBondType() }) # 4. 断开 placeholder 的所有键 for bond_info in bonds_info: new_mol.RemoveBond(placeholder_idx, bond_info["neighbor"]) # 5. 删除 placeholder 原子 new_mol.RemoveAtom(placeholder_idx) # 6. 重新计算邻居索引(删除后索引变化) adjusted_bonds_info = [] for bond_info in bonds_info: neighbor = bond_info["neighbor"] if neighbor < placeholder_idx: adjusted_neighbor = neighbor else: adjusted_neighbor = neighbor - 1 # 索引因删除原子而减 1 adjusted_bonds_info.append({ "neighbor": adjusted_neighbor, "bond_type": bond_info["bond_type"] }) # 7. 合并 submol new_mol = Chem.RWMol(Chem.CombineMols(new_mol, submol_rw)) # 8. 计算 submol 的 anchor atoms 在合并后的索引 submol_atom_offset = new_mol.GetNumAtoms() - submol_rw.GetNumAtoms() new_anchor_indices = [submol_atom_offset + anchor_idx for anchor_idx in anchor_atoms] # 9. 
重新连接官能团,使用原始键类型 if len(new_anchor_indices) == 1: # 单连接点情况:所有邻居连接到唯一的 anchor atom anchor_idx = new_anchor_indices[0] for bond_info in adjusted_bonds_info: neighbor = bond_info["neighbor"] bond_type = bond_info["bond_type"] new_mol.AddBond(neighbor, anchor_idx, bond_type) # 重置自由基电子数 a1 = new_mol.GetAtomWithIdx(neighbor) a2 = new_mol.GetAtomWithIdx(anchor_idx) a1.SetNumRadicalElectrons(0) a2.SetNumRadicalElectrons(0) else: # # 多连接点情况:先尝试按顺序连接, 如果* 连* 会存在多种合理价态的不同分子情况 # 多连接点情况:根据邻居数量和 anchor atoms 分配连接 if len(adjusted_bonds_info) > len(new_anchor_indices): print(adjusted_bonds_info,' <---adjusted_bonds_info') print(new_anchor_indices,'<---new_anchor_indices') # raise ValueError(f"Too many neighbors ({len(adjusted_bonds_info)}) for submol with {len(new_anchor_indices)} anchor atoms.") # for i, bond_info in enumerate(adjusted_bonds_info): # # 按顺序将邻居连接到 anchor atoms # anchor_idx = new_anchor_indices[i % len(new_anchor_indices)] # neighbor = bond_info["neighbor"] # bond_type = bond_info["bond_type"] # new_mol.AddBond(neighbor, anchor_idx, bond_type) # # 重置自由基电子数 # a1 = new_mol.GetAtomWithIdx(neighbor) # a2 = new_mol.GetAtomWithIdx(anchor_idx) # a1.SetNumRadicalElectrons(0) # a2.SetNumRadicalElectrons(0) # 跟踪每个 anchor 的当前成键数 anchor_bond_counts = {idx: new_mol.GetAtomWithIdx(idx).GetTotalValence() for idx in new_anchor_indices} print(anchor_bond_counts,'<---anchor_bond_counts') # max_valence = {6: 4, 7: 3, 8: 2} # 示例:C=4, N=3, O=2,需根据实际原子类型扩展 adjusted_bonds_info = sorted(adjusted_bonds_info, key=lambda x: x['neighbor']) if mol.GetNumConformers() > 0:#as some mol may not have the conf dispite pass the 2d assign process pos_0 = mol.GetConformer().GetAtomPosition(adjusted_bonds_info[0]['neighbor']) pos_1 = mol.GetConformer().GetAtomPosition(adjusted_bonds_info[-1]['neighbor']) print(pos_0.x,pos_1.x,"xxx",adjusted_bonds_info) # if group =='SO2NH': # if pos_0.x 0.1: if orig_valid and scaled_valid: if orig_score >= scaled_score and orig_text: return orig_text, orig_score, 
cropped_img_orig elif scaled_text: return scaled_text, scaled_score, cropped_img_scaled elif orig_valid and not scaled_valid: return orig_text, orig_score, cropped_img_orig elif scaled_valid and not orig_valid: return scaled_text, scaled_score, cropped_img_scaled else: print(f"Both texts are invalid: orig_text='{orig_text}', scaled_text='{scaled_text}'") if orig_score >= scaled_score: return orig_text, orig_score, cropped_img_orig else: return scaled_text, scaled_score, cropped_img_scaled # 如果分差小于0.1,选择更合理的化学表达式 else: # 如果只有一个有效,选择有效的 if orig_valid and not scaled_valid: return orig_text, orig_score, cropped_img_orig elif scaled_valid and not orig_valid: return scaled_text, scaled_score, cropped_img_scaled # 如果都有效,比较长度 elif orig_valid and scaled_valid: if orig_text in ABBREVIATIONS and scaled_text not in ABBREVIATIONS: if N_C_H_expand(scaled_text) or C_F_expand(scaled_text) or C_H_expand2(scaled_text) or C_H_expand(scaled_text): if len(scaled_text)> len(orig_text): return scaled_text, scaled_score, cropped_img_scaled return orig_text, orig_score, cropped_img_orig elif orig_text not in ABBREVIATIONS and scaled_text in ABBREVIATIONS: if N_C_H_expand(orig_text) or C_F_expand(orig_text) or C_H_expand2(orig_text) or C_H_expand(orig_text): if len(orig_text)> len(scaled_text): return orig_text, orig_score, cropped_img_orig return scaled_text, scaled_score, cropped_img_scaled elif orig_text not in ABBREVIATIONS and scaled_text not in ABBREVIATIONS: if len(orig_text) > len(scaled_text): return orig_text, orig_score, cropped_img_orig else: if len(orig_text) == len(scaled_text): if orig_score >= scaled_score : return orig_text, orig_score, cropped_img_orig else: return scaled_text, scaled_score, cropped_img_scaled return scaled_text, scaled_score, cropped_img_scaled elif orig_text in ABBREVIATIONS and scaled_text in ABBREVIATIONS: if len(orig_text) >= len(scaled_text): return orig_text, orig_score, cropped_img_orig else: return scaled_text, scaled_score, cropped_img_scaled # 
如果都不有效,优先选择 orig(若存在) elif orig_text: return orig_text, orig_score, cropped_img_orig elif scaled_text: return scaled_text, scaled_score, cropped_img_scaled # 默认返回 scaled(若存在) return scaled_text, scaled_score, cropped_img_scaled if scaled_text else (None, None, None) # def expandABB(mol,ABBREVIATIONS, placeholder_atoms):# RGROUP_SYMBOLS, ELEMENTS): # mols = [mol] # # Process placeholders in reverse order to avoid index issues # for idx in sorted(placeholder_atoms.keys(), reverse=True): # group = placeholder_atoms[idx] # group_smiles = _expand_abbreviation(group)# ABBREVIATIONS, RGROUP_SYMBOLS, ELEMENTS) # try: # submol = Chem.MolFromSmiles(group_smiles) # if not submol: # raise ValueError(f"Invalid SMILES for group {group}: {group_smiles}") # submol_rw = RWMol(submol) # except Exception as e: # print(f"Error processing SMILES for group {group}: {e}") # continue # # Create a new editable molecule # new_mol = RWMol(mol) # placeholder_idx = idx # # Get neighbors of the placeholder atom # neighbors = [nb.GetIdx() for nb in new_mol.GetAtomWithIdx(placeholder_idx).GetNeighbors()] # # Identify anchor atoms in submol (atoms marked as [*] or with isotope labels) # anchor_atoms = [] # for atom in submol.GetAtoms(): # if atom.GetNumRadicalElectrons() > 0: # #atom.GetSymbol() == '*' or atom.GetIsotope() > 0: # anchor_atoms.append(atom.GetIdx()) # # Validate number of anchor atoms vs. 
neighbors # if len(anchor_atoms) != len(neighbors): # print(f"Warning: Mismatch between anchor atoms ({len(anchor_atoms)}) and neighbors ({len(neighbors)}) for group {group}") # print(len(anchor_atoms), len(neighbors)) # if len(anchor_atoms)==0: # anchor_atoms.append(0)# use the first atom of submol as default such as PPh3 # # Remove bonds involving the placeholder atom # bonds_to_remove = [(bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()) # for bond in new_mol.GetBonds() # if bond.GetBeginAtomIdx() == placeholder_idx or bond.GetEndAtomIdx() == placeholder_idx] # for bond in bonds_to_remove: # new_mol.RemoveBond(bond[0], bond[1]) # # Remove the placeholder atom # new_mol.RemoveAtom(placeholder_idx) # # Adjust neighbor indices after atom removal # new_neighbors = [n - 1 if n > placeholder_idx else n for n in neighbors] # # Combine molecules # new_mol = RWMol(CombineMols(new_mol, submol_rw)) # # Connect anchor atoms to neighbors # submol_offset = new_mol.GetNumAtoms() - submol.GetNumAtoms() # for anchor_idx, neighbor_idx in zip(anchor_atoms, new_neighbors): # new_anchor_idx = submol_offset + anchor_idx # new_mol.AddBond(neighbor_idx, new_anchor_idx, Chem.BondType.SINGLE) # # Reset radical electrons # new_mol.GetAtomWithIdx(neighbor_idx).SetNumRadicalElectrons(0) # new_mol.GetAtomWithIdx(new_anchor_idx).SetNumRadicalElectrons(0) # mol = new_mol # mols.append(mol) # # Generate final SMILES # try: # modified_smiles = Chem.MolToSmiles(mols[-1]) # except Exception as e: # print(f"Error generating SMILES: {e}") # return mols[-1], None # return mols[-1], modified_smiles # def _expand_abbreviation(abbrev): # """ # Expand abbreviation into its SMILES; also converts [Rn] to [n*] # Used in `_condensed_formula_list_to_smiles` when encountering abbrev. 
in condensed formula # """ # if abbrev in ABBREVIATIONS: # return ABBREVIATIONS[abbrev].smiles # elif abbrev in RGROUP_SYMBOLS or (abbrev[0] == 'R' and abbrev[1:].isdigit()): # if abbrev[1:].isdigit(): # return f'[{abbrev[1:]}*]' # elif abbrev in ELEMENTS:#ocr tool need this # return f'[{abbrev}]' # # try abbrev # match = re.match(r'^(\d+)?(.*)', abbrev) # if match: # numeric_part, remaining_part = match.groups() # if remaining_part in ELEMENTS: # return f'[{abbrev}]' # else: # if numeric_part: # abbrev=f'[{numeric_part}*]' # return '[*]' # def expandABB(mol,ABBREVIATIONS, placeholder_atoms): # mols = [mol] # # **第三步: 替换 * 并合并官能团** # # 逆序遍历 placeholder_atoms,确保删除后不会影响后续索引 # for idx in sorted(placeholder_atoms.keys(), reverse=True): # group = placeholder_atoms[idx] # 获取官能团名称 # # print(idx, group) # group=_expand_abbreviation(group) # submol = Chem.MolFromSmiles(group) # 获取官能团的子分子 # submol_rw = RWMol(submol) # 让 submol 变成可编辑的 RWMol # anchor_atom_idx = 0 # 选择 `submol` 的第一个原子作为连接点 as defined in ABBREVIATIONS # # **1. 复制主分子** # new_mol = RWMol(mol) # # **2. 计算 `*` 在 `new_mol` 中的索引** # placeholder_idx = idx # # **3. 记录 `*` 原子的邻居** # neighbors = [nb.GetIdx() for nb in new_mol.GetAtomWithIdx(placeholder_idx).GetNeighbors()] # # **4. 断开 `*` 的所有键** # bonds_to_remove = [] # 记录要断开的键 # for bond in new_mol.GetBonds(): # if bond.GetBeginAtomIdx() == placeholder_idx or bond.GetEndAtomIdx() == placeholder_idx: # bonds_to_remove.append((bond.GetBeginAtomIdx(), bond.GetEndAtomIdx())) # for bond in bonds_to_remove: # new_mol.RemoveBond(bond[0], bond[1]) # # **5. 删除 `*` 原子** # new_mol.RemoveAtom(placeholder_idx) # # **6. 重新计算 `neighbors`(删除后索引变化)** # new_neighbors = [] # for neighbor in neighbors: # if neighbor < placeholder_idx: # new_neighbors.append(neighbor) # else: # new_neighbors.append(neighbor - 1) # 因为删除了一个原子,所有索引 -1 # # **7. 合并 `submol`** # new_mol = RWMol(CombineMols(new_mol, submol_rw)) # # **8. 
# Helper function to check if two boxes overlap
def boxes_overlap(box1, box2):
    """Return True when two ``(x1, y1, x2, y2)`` boxes intersect or touch."""
    x1, y1, x2, y2 = box1
    bx1, by1, bx2, by2 = box2
    return not (x2 < bx1 or x1 > bx2 or y2 < by1 or y1 > by2)


def boxes_overlap2(atombonx, bondbox, fraction=0.7):
    """Locate the point on a bond box that lies away from an atom box.

    The bond box is treated as the segment between its corners
    ``(bx1, by1)`` and ``(bx2, by2)``. The returned point lies on the line
    from the bond-box centre towards the corner judged to be outside the atom
    box, at ``fraction`` of the centre-to-corner distance.

    Args:
        atombonx: ``(x1, y1, x2, y2)`` of the atom box.
        bondbox: ``(bx1, by1, bx2, by2)`` of the bond box.
        fraction: How far to move from the bond centre towards the chosen
            corner (0 = centre, 1 = corner). Defaults to 0.7.

    Returns:
        ``(x, y)``. When the bond box lies entirely inside the atom box no
        free end can be identified and the bond-box centre is returned.
    """
    x1, y1, x2, y2 = atombonx
    bx1, by1, bx2, by2 = bondbox

    bond_center_x = (bx1 + bx2) / 2
    bond_center_y = (by1 + by2) / 2
    atom_center_x = (x1 + x2) / 2
    atom_center_y = (y1 + y2) / 2

    def distance_to_atom_center(px, py):
        return ((px - atom_center_x) ** 2 + (py - atom_center_y) ** 2) ** 0.5

    def point_towards(far_x, far_y):
        # Scale the centre -> corner vector by ``fraction``.
        dx = far_x - bond_center_x
        dy = far_y - bond_center_y
        return bond_center_x + fraction * dx, bond_center_y + fraction * dy

    def farther_corner():
        # Ties go to the first corner.
        dist1 = distance_to_atom_center(bx1, by1)
        dist2 = distance_to_atom_center(bx2, by2)
        return (bx2, by2) if dist2 > dist1 else (bx1, by1)

    def inside_atom_box(px, py):
        return x1 <= px <= x2 and y1 <= py <= y2

    # Completely disjoint: use the corner farther from the atom centre.
    if bx2 < x1 or bx1 > x2 or by2 < y1 or by1 > y2:
        return point_towards(*farther_corner())

    # Bond box fully contained in the atom box: no free end to pick.
    if bx1 >= x1 and bx2 <= x2 and by1 >= y1 and by2 <= y2:
        return bond_center_x, bond_center_y

    # Exactly one corner inside the atom box: the opposite corner is free.
    if inside_atom_box(bx1, by1):
        return point_towards(bx2, by2)
    if inside_atom_box(bx2, by2):
        return point_towards(bx1, by1)

    # Boxes intersect but neither corner is inside the atom box.
    return point_towards(*farther_corner())
charge_labels = [19, 20, 21, 22, 23]


def outputbox_update(output, charge_labels, bond_labels, lab2idx):
    """Append a synthetic 'H' detection for every bond touching one atom only.

    Each bond box is compared against all atom boxes (detections that are
    neither bonds nor charges). When exactly one atom box overlaps the bond
    box, a small square box centred on the point returned by
    ``boxes_overlap2`` is appended with score 1.0 and class ``lab2idx['H']``.

    Args:
        output: Dict of numpy arrays; ``'bbox'``, ``'bbox_centers'``,
            ``'scores'`` and ``'pred_classes'`` are read and extended.
        charge_labels: Class ids treated as charges (shadows the module-level
            list of the same name).
        bond_labels: Class ids treated as bonds.
        lab2idx: Mapping from label text to class id; must contain ``'H'``.

    Returns:
        ``(new_output, single_odd_b2a)``: a deep copy of ``output`` with the
        extra detections appended, and a dict mapping the bond's position
        within the bond subset to a one-element list holding the atom's
        position within the atom subset.
    """
    pred_classes = output['pred_classes']
    # dtype=bool keeps the masks valid index arrays even when there are no
    # detections (a bare np.array([]) is float64 and cannot be used to index).
    bonds_mask = np.array([cls in bond_labels for cls in pred_classes], dtype=bool)
    atoms_mask = np.array(
        [cls not in bond_labels and cls not in charge_labels for cls in pred_classes],
        dtype=bool)
    bond_bbox = output['bbox'][bonds_mask]
    atom_bbox = output['bbox'][atoms_mask]

    b_len = 3  # half side length of the synthetic hydrogen box
    new_atoms = []
    single_odd_b2a = dict()
    for bi, bb in enumerate(bond_bbox):
        overlapped_atoms = []
        overlapped_abox = []
        for ai, aa in enumerate(atom_bbox):
            # TODO: use the atom/bond box overlap to build the bond-atom
            # mapping and then construct the molecule from it.
            if boxes_overlap(bb, aa):
                overlapped_atoms.append(ai)
                overlapped_abox.append(aa)
        if len(overlapped_atoms) != 1:
            continue
        single_odd_b2a[bi] = overlapped_atoms
        # Place the hydrogen on the part of the bond box away from the atom.
        hx, hy = boxes_overlap2(overlapped_abox[0], bb)
        new_atoms.append({
            'bbox': np.array([hx - b_len, hy - b_len,
                              hx + b_len, hy + b_len]).reshape(-1, 4),
            'bbox_centers': np.array([hx, hy]).reshape(-1, 2),
            'scores': np.array([1.0]),
            'pred_classes': np.array([lab2idx['H']]),
        })

    output2_ = copy.deepcopy(output)
    for boxout in new_atoms:
        for k, arr in boxout.items():
            value_or_row = output2_[k]
            if arr.ndim == 1:
                output2_[k] = np.append(value_or_row, arr)
            elif arr.ndim >= 2:
                output2_[k] = np.concatenate([value_or_row, arr], axis=0)
            else:
                print('errprs, unkown conditions !!!@')
    return output2_, single_odd_b2a
def _hydrogens_without_heavy_neighbor(mol):
    """Return indices of H atoms in ``mol`` with no non-hydrogen neighbour.

    Covers isolated hydrogens and hydrogens bonded only to other hydrogens.
    Indices are listed in the iteration order of ``mol.GetAtoms()``.
    """
    return [
        atom.GetIdx()
        for atom in mol.GetAtoms()
        if atom.GetSymbol() == 'H'
        and all(nb.GetSymbol() == 'H' for nb in atom.GetNeighbors())
    ]


def remove_unconnected_hydrogens(mol):
    """Remove hydrogens that are not bonded to any heavy atom.

    Args:
        mol: RDKit Mol object (not modified).

    Returns:
        An ``RWMol`` copy of ``mol`` with those hydrogens removed.
    """
    molexp = Chem.RWMol(mol)
    # Remove from the highest index down so earlier removals do not shift the
    # indices still to be removed.
    for atom_idx in sorted(_hydrogens_without_heavy_neighbor(molexp), reverse=True):
        molexp.RemoveAtom(atom_idx)
    return molexp


def remove_unconnected_hydrogens2(mol):
    """Remove hydrogens not bonded to a heavy atom and report where they were.

    Args:
        mol: RDKit Mol object with at least one conformer (not modified).

    Returns:
        ``(rw_mol, removed_h_coords)``: the edited ``RWMol`` copy and the
        ``(x, y, z)`` positions of the removed hydrogens, taken from the
        conformer returned by ``GetConformer()`` and listed in ascending
        atom order.
    """
    rw_mol = Chem.RWMol(mol)
    conformer = rw_mol.GetConformer()
    stray = _hydrogens_without_heavy_neighbor(rw_mol)

    # Record coordinates before any atom is removed.
    removed_h_coords = []
    for atom_idx in stray:
        pos = conformer.GetAtomPosition(atom_idx)
        removed_h_coords.append((pos.x, pos.y, pos.z))

    for atom_idx in sorted(stray, reverse=True):
        rw_mol.RemoveAtom(atom_idx)
    return rw_mol, removed_h_coords


def detect_unconnected_hydrogens(mol):
    """Return indices of hydrogens not bonded to any heavy atom.

    Only the atom graph is inspected, so the molecule is neither copied nor
    required to carry a conformer.

    Args:
        mol: RDKit Mol object (or anything exposing ``GetAtoms()``).

    Returns:
        Atom indices sorted from highest to lowest, ready for sequential
        ``RemoveAtom`` calls.
    """
    return sorted(_hydrogens_without_heavy_neighbor(mol), reverse=True)
def _plot_boxes_and_centers(boxes, centers):
    """Draw boxes (blue rectangles) and centres (red circles); return the figure."""
    fig, ax = plt.subplots(figsize=(10, 10))
    for x1, y1, x2, y2 in boxes:
        ax.add_patch(Rectangle((x1, y1), x2 - x1, y2 - y1,
                               linewidth=1, edgecolor='blue', facecolor='none'))
    for center in centers:
        ax.add_patch(Circle(center, radius=5, edgecolor='red',
                            facecolor='none', linewidth=1))
    # Axis limits need at least one box; min()/max() of an empty array raises.
    if len(boxes) > 0:
        ax.set_xlim(min(boxes[:, 0].min(), centers[:, 0].min()) - 10,
                    max(boxes[:, 2].max(), centers[:, 0].max()) + 10)
        ax.set_ylim(min(boxes[:, 1].min(), centers[:, 1].min()) - 10,
                    max(boxes[:, 3].max(), centers[:, 1].max()) + 10)
    ax.set_title("Filtered Boxes and Centers")
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.set_aspect('equal', adjustable='box')
    ax.grid(True, linestyle='--', alpha=0.7)
    return fig


def view_box_center2(bond_bbox, bond_centers, bond_scores, bond_classes,
                     overlap_dist_thresh=5.0, max_centers_per_box=5,
                     plot_view=False):
    """Filter bond detections by centre proximity and box crowding.

    Step 1: among centres closer than ``overlap_dist_thresh`` to each other,
    the one with the strictly higher score wins; on equal scores the later
    one wins. A centre that has lost stops suppressing others.
    Step 2: boxes that contain more than ``max_centers_per_box`` of the
    surviving centres are dropped.

    Args:
        bond_bbox: Array of ``[x1, y1, x2, y2]`` boxes.
        bond_centers: Array of ``[x, y]`` centres, one per box.
        bond_scores: Array of scores, one per box.
        bond_classes: Array of class ids, one per box.
        overlap_dist_thresh: Distance below which two centres are considered
            overlapping.
        max_centers_per_box: Maximum number of centres a box may contain.
        plot_view: When True, also build a matplotlib figure of the result.

    Returns:
        ``(bboxes, centers, scores, classes, fig)``; ``fig`` is ``None``
        unless ``plot_view`` is True.

    Raises:
        AssertionError: If boxes, centres and scores differ in length.
    """
    assert len(bond_bbox) == len(bond_centers) == len(bond_scores), "Input arrays must have equal length"
    bond_bbox = np.asarray(bond_bbox)
    bond_centers = np.asarray(bond_centers)
    bond_scores = np.asarray(bond_scores)
    bond_classes = np.asarray(bond_classes)
    n = len(bond_bbox)

    # Step 1: resolve overlapping centres.
    keep_centers = np.ones(n, dtype=bool)
    for i in range(n):
        if not keep_centers[i]:
            continue
        for j in range(i + 1, n):
            if not keep_centers[j]:
                continue
            dist = np.sqrt(np.sum((bond_centers[i] - bond_centers[j]) ** 2))
            if dist < overlap_dist_thresh:
                if bond_scores[i] > bond_scores[j]:
                    keep_centers[j] = False
                else:
                    keep_centers[i] = False
                    # i has been suppressed; it must not eliminate others.
                    break

    bond_bbox = bond_bbox[keep_centers]
    bond_centers = bond_centers[keep_centers]
    bond_scores = bond_scores[keep_centers]
    bond_classes = bond_classes[keep_centers]
    n = len(bond_bbox)

    # Step 2: drop boxes crowded with too many centres.
    keep_boxes = np.ones(n, dtype=bool)
    for i in range(n):
        x1, y1, x2, y2 = bond_bbox[i]
        centers_in_box = np.sum((bond_centers[:, 0] >= x1) & (bond_centers[:, 0] <= x2) &
                                (bond_centers[:, 1] >= y1) & (bond_centers[:, 1] <= y2))
        if centers_in_box > max_centers_per_box:
            keep_boxes[i] = False

    final_bond_bbox = bond_bbox[keep_boxes]
    final_bond_centers = bond_centers[keep_boxes]
    final_bond_scores = bond_scores[keep_boxes]
    final_bond_classes = bond_classes[keep_boxes]

    fig = _plot_boxes_and_centers(final_bond_bbox, final_bond_centers) if plot_view else None
    return final_bond_bbox, final_bond_centers, final_bond_scores, final_bond_classes, fig
def calculate_iou(box1, box2):
    """Compute the IoU (Intersection over Union) of two boxes.

    Args:
        box1, box2: Boxes as ``[x1, y1, x2, y2]``.

    Returns:
        The IoU value, or 0 when the union has no area.
    """
    x1 = max(box1[0], box2[0])
    y1 = max(box1[1], box2[1])
    x2 = min(box1[2], box2[2])
    y2 = min(box1[3], box2[3])

    intersection = max(0, x2 - x1) * max(0, y2 - y1)
    area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
    union = area1 + area2 - intersection

    return intersection / union if union > 0 else 0


def nms_per_class(labels, boxes, scores, iou_thresh=0.5):
    """Apply non-maximum suppression separately within each class.

    Within a class, boxes are visited from highest to lowest score; a box is
    dropped when its IoU with an already kept box of the same class is at
    least ``iou_thresh``.

    Args:
        labels: Array of class labels.
        boxes: Array of ``[x1, y1, x2, y2]`` boxes.
        scores: Array of scores.
        iou_thresh: IoU threshold.

    Returns:
        Dict with the surviving ``'labels'``, ``'boxes'`` and ``'scores'``,
        grouped by ascending label and ordered by descending score within
        each label. Empty input yields empty arrays.
    """
    labels = np.asarray(labels)
    boxes = np.asarray(boxes)
    scores = np.asarray(scores)

    kept_indices = []
    for label in np.unique(labels):
        class_indices = np.where(labels == label)[0]
        # Sort this class's detections from highest to lowest score.
        order = np.argsort(scores[class_indices])[::-1]
        class_indices = class_indices[order]
        class_boxes = boxes[class_indices]

        while len(class_indices) > 0:
            # The best remaining box is always kept.
            kept_indices.append(class_indices[0])
            if len(class_indices) == 1:
                break
            ious = np.array([calculate_iou(class_boxes[0], box) for box in class_boxes[1:]])
            survivors = ious < iou_thresh
            class_boxes = class_boxes[1:][survivors]
            class_indices = class_indices[1:][survivors]

    # An explicit integer dtype keeps the index array usable when nothing was
    # kept (np.array([]) would be float64 and is rejected as an index).
    kept_indices = np.array(kept_indices, dtype=int)
    return {
        'labels': labels[kept_indices],
        'boxes': boxes[kept_indices],
        'scores': scores[kept_indices]
    }


def get_overlap_region(box1, box2):
    """Get the overlapping region of two boxes.

    Args:
        box1, box2: ``[x_min, y_min, x_max, y_max]``.

    Returns:
        ``(x_min, y_min, x_max, y_max)`` of the overlap region, or ``None``
        when the boxes do not overlap with positive area.
    """
    x1 = max(box1[0], box2[0])
    y1 = max(box1[1], box2[1])
    x2 = min(box1[2], box2[2])
    y2 = min(box1[3], box2[3])

    if x2 <= x1 or y2 <= y1:
        return None  # No overlap
    return (x1, y1, x2, y2)
def are_bond_connected(box1, box2, bond_bboxes, bond_iou_threshold=0.1):
    """Check if two atom boxes are connected by a bond box.

    A bond box counts when its IoU with both atom boxes exceeds
    ``bond_iou_threshold`` and its centre lies inside the region where the
    two atom boxes overlap.

    Args:
        box1, box2: Atom boxes to check.
        bond_bboxes: Iterable of bond boxes.
        bond_iou_threshold: IoU threshold for the initial bond overlap test.

    Returns:
        True if such a bond box exists; False otherwise, including when the
        atom boxes do not overlap at all.
    """
    overlap_region = get_overlap_region(box1, box2)
    if overlap_region is None:
        return False  # No overlap between atom boxes

    ox_min, oy_min, ox_max, oy_max = overlap_region
    for bond_box in bond_bboxes:
        # Preliminary IoU check against both atoms.
        if (calculate_iou(box1, bond_box) > bond_iou_threshold
                and calculate_iou(box2, bond_box) > bond_iou_threshold):
            bond_center_x = (bond_box[0] + bond_box[2]) / 2
            bond_center_y = (bond_box[1] + bond_box[3]) / 2
            if (ox_min <= bond_center_x <= ox_max and
                    oy_min <= bond_center_y <= oy_max):
                return True
    return False


def calculate_iou(box1, box2):
    """Compute the IoU of two ``[x_min, y_min, x_max, y_max]`` boxes.

    Returns 0 when the union has no area.
    """
    x1 = max(box1[0], box2[0])
    y1 = max(box1[1], box2[1])
    x2 = min(box1[2], box2[2])
    y2 = min(box1[3], box2[3])

    intersection = max(0, x2 - x1) * max(0, y2 - y1)
    area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
    union = area1 + area2 - intersection

    return intersection / union if union > 0 else 0


def nms(atom_bboxes, atom_scores, atom_classes, iou_threshold=0.5):
    """Apply class-agnostic non-maximum suppression (NMS).

    Boxes are visited from highest to lowest score. Each visited box is kept
    and every remaining box whose IoU with it exceeds ``iou_threshold`` is
    discarded.

    Args:
        atom_bboxes: Sequence of boxes ``[x_min, y_min, x_max, y_max]``.
        atom_scores: Sequence of confidence scores, one per box.
        atom_classes: Sequence of class labels, one per box.
        iou_threshold: IoU above which a lower-scored box is suppressed.

    Returns:
        ``(kept_bboxes, kept_classes, kept_scores)`` as numpy arrays, in
        descending score order.
    """
    # Candidate indices, best score first.
    indices = np.argsort(atom_scores)[::-1]
    keep_indices = []

    while len(indices) > 0:
        current_idx = indices[0]
        keep_indices.append(current_idx)

        remaining = indices[1:]
        ious = np.array([calculate_iou(atom_bboxes[current_idx], atom_bboxes[idx])
                         for idx in remaining])
        # Filter with a mask so the score ordering of the survivors is kept;
        # a set difference would re-sort them by index.
        indices = remaining[ious <= iou_threshold]

    kept_bboxes = np.array([atom_bboxes[i] for i in keep_indices])
    kept_classes = np.array([atom_classes[i] for i in keep_indices])
    kept_scores = np.array([atom_scores[i] for i in keep_indices])

    return kept_bboxes, kept_classes, kept_scores
列表,包含每个边界框的类别 iou_threshold: IoU 阈值,用于判断是否抑制 返回: 保留的边界框、类别和置信度的索引 """ # 按置信度排序,获取索引 indices = np.argsort(atom_scores)[::-1] # 从高到低排序 keep_indices = [] while len(indices) > 0: # 使用 len(indices) 替代 indices.size # 保留当前最高置信度的框 current_idx = indices[0] keep_indices.append(current_idx) # 计算当前框与其他框的 IoU ious = np.array([calculate_iou(atom_bboxes[current_idx], atom_bboxes[idx]) for idx in indices[1:]]) # 找出 IoU > threshold 的索引(相对于 indices[1:] 的偏移) suppress_indices = indices[1:][ious > iou_threshold] # 更新 indices,去除当前框和被抑制的框 indices = np.setdiff1d(indices, np.concatenate(([current_idx], suppress_indices))) # 调试信息 # print(f"Current idx: {current_idx}, rmoved: {suppress_indices}, Remaining: {indices}") # print(f"Current idx: {current_idx}, rmoved: {suppress_indices}, IOU: {ious}") # 返回保留的框、类别和置信度 kept_bboxes = np.array([atom_bboxes[i] for i in keep_indices]) kept_classes = np.array([atom_classes[i] for i in keep_indices]) kept_scores = np.array([atom_scores[i] for i in keep_indices]) return kept_bboxes, kept_classes, kept_scores def count_bond_overlaps(box, bond_bboxes, bond_iou_threshold=0.1): """ Count how many bond boxes overlap with an atom box. 
def count_bond_overlaps(box, bond_bboxes, bond_iou_threshold=0.01):
    """Count how many bond boxes overlap with an atom box.

    This function used to be defined twice in a row; only the second
    definition (default threshold 0.01) was ever effective, so that is the
    one kept here.

    Args:
        box: Atom box [x_min, y_min, x_max, y_max].
        bond_bboxes: Iterable of bond boxes in the same format.
        bond_iou_threshold: A bond box counts when its IoU with ``box`` is
            strictly greater than this value.

    Returns:
        int: Number of overlapping bond boxes.
    """
    return sum(1 for bond_box in bond_bboxes
               if calculate_iou(box, bond_box) > bond_iou_threshold)


def count_atom_overlaps(box, all_bboxes, exclude_idx, min_iou=0.01):
    """Count how many other atom boxes overlap with this box.

    Args:
        box: Atom box [x_min, y_min, x_max, y_max].
        all_bboxes: All atom boxes, including ``box`` itself.
        exclude_idx: Position of ``box`` inside ``all_bboxes`` (skipped).
        min_iou: A box counts when its IoU with ``box`` is strictly greater
            than this value.

    Returns:
        int: Number of overlapping atom boxes.
    """
    return sum(1 for i, other_box in enumerate(all_bboxes)
               if i != exclude_idx and calculate_iou(box, other_box) > min_iou)


def merge_low_iou_boxes(kept_bboxes, kept_classes, kept_scores, bond_bboxes,
                        merge_threshold=0.5, score_threshold=0.7,
                        bond_iou_threshold=0.01, high_iou_threshold=0.8,
                        large_score_threshold=0.5):
    """Merge or filter overlapping atom boxes.

    The work is done in three passes:

    0. Drop boxes that are larger than the median area, score below
       ``large_score_threshold``, overlap at least two other atom boxes and
       at least three bond boxes.
    1. For groups of boxes with IoU above ``high_iou_threshold``, keep only
       the highest-scoring member.
    2. For each remaining box, gather neighbours with a small IoU
       (0.05 <= IoU < ``merge_threshold``) and a score below 0.7, then either
       keep all of them (bond-connected), keep the best one, or fuse them
       into one enclosing box, depending on the group's best score.

    Args:
        kept_bboxes: Atom boxes [x_min, y_min, x_max, y_max].
        kept_classes: Class labels, one per box.
        kept_scores: Confidence scores, one per box.
        bond_bboxes: Bond boxes in the same format as atom boxes.
        merge_threshold: Upper IoU bound for the low-IoU grouping in pass 2.
        score_threshold: A group whose best score reaches this value is
            filtered rather than fused.
        bond_iou_threshold: IoU threshold used for bond overlap checks.
        high_iou_threshold: IoU above which boxes count as duplicates.
        large_score_threshold: Score below which a large box may be dropped.

    Returns:
        tuple: (merged_bboxes, merged_classes, merged_scores) as NumPy
        arrays. With zero or one input box the inputs are returned unchanged.
    """
    if len(kept_bboxes) <= 1:
        return kept_bboxes, kept_classes, kept_scores

    kept_bboxes = np.array(kept_bboxes)
    kept_classes = np.array(kept_classes)
    kept_scores = np.array(kept_scores)
    bond_bboxes = np.array(bond_bboxes)

    # Bounds used when collecting low-IoU neighbours in pass 2.
    min_merge_iou = 0.05
    neighbour_max_score = 0.7

    # Pass 0: remove large, low-score boxes that straddle several atoms/bonds.
    areas = (kept_bboxes[:, 2] - kept_bboxes[:, 0]) * (kept_bboxes[:, 3] - kept_bboxes[:, 1])
    median_area = np.median(areas)
    keep_mask = np.ones(len(kept_bboxes), dtype=bool)
    for i in range(len(kept_bboxes)):
        if kept_scores[i] < large_score_threshold:
            atom_overlaps = count_atom_overlaps(kept_bboxes[i], kept_bboxes, i)
            bond_overlaps = count_bond_overlaps(kept_bboxes[i], bond_bboxes, bond_iou_threshold)
            is_large = areas[i] > median_area  # "large" means above the median
            if is_large and atom_overlaps >= 2 and bond_overlaps >= 3:
                keep_mask[i] = False
                print(f"Removed large low-score box idx {i}: score {kept_scores[i]}, "
                      f"area {areas[i]}, atom overlaps {atom_overlaps}, bond overlaps {bond_overlaps}")

    kept_bboxes = kept_bboxes[keep_mask]
    print(f"afterRemoved large low-score atom box::{len(kept_bboxes)} ")
    kept_classes = kept_classes[keep_mask]
    kept_scores = kept_scores[keep_mask]
    if len(kept_bboxes) == 0:
        return np.array([]), np.array([]), np.array([])

    merged_bboxes = []
    merged_classes = []
    merged_scores = []
    used_indices = set()
    box_count = len(kept_bboxes)

    # Pass 1: collapse near-duplicate boxes onto their best-scoring member.
    for i in range(box_count):
        if i in used_indices:
            continue
        high_iou_group = [i]
        for j in range(box_count):
            if j in used_indices or j == i:
                continue
            if calculate_iou(kept_bboxes[i], kept_bboxes[j]) > high_iou_threshold:
                high_iou_group.append(j)
        if len(high_iou_group) > 1:
            group_scores = kept_scores[high_iou_group]
            max_score_idx = high_iou_group[np.argmax(group_scores)]
            merged_bboxes.append(kept_bboxes[max_score_idx])
            merged_classes.append(kept_classes[max_score_idx])
            merged_scores.append(kept_scores[max_score_idx])
            used_indices.update(high_iou_group)
            print(f"Merged high-IoU (> {high_iou_threshold}) boxes: {high_iou_group}, "
                  f"kept index: {max_score_idx}")

    # Pass 2: handle every box that pass 1 did not consume.
    for i in range(box_count):
        if i in used_indices:
            continue
        group_indices = [i]
        for j in range(box_count):
            if j in used_indices or j == i:
                continue
            iou = calculate_iou(kept_bboxes[i], kept_bboxes[j])
            # Only weakly overlapping, not-yet-confident neighbours join the group.
            if min_merge_iou <= iou < merge_threshold and kept_scores[j] < neighbour_max_score:
                group_indices.append(j)

        group_scores = kept_scores[group_indices]
        group_classes = kept_classes[group_indices]
        group_bboxes = kept_bboxes[group_indices]
        max_score = np.max(group_scores)

        if max_score >= score_threshold:
            # Confident group: keep everything if consecutive members are
            # bond-connected, otherwise keep the member touching most bonds.
            bond_connected = False
            if len(group_indices) > 1:
                for idx1, idx2 in zip(group_indices[:-1], group_indices[1:]):
                    if are_bond_connected(kept_bboxes[idx1], kept_bboxes[idx2],
                                          bond_bboxes, bond_iou_threshold):
                        bond_connected = True
                        break
            if bond_connected:
                for idx in group_indices:
                    merged_bboxes.append(kept_bboxes[idx])
                    merged_classes.append(kept_classes[idx])
                    merged_scores.append(kept_scores[idx])
                print(f"Kept all bond-connected boxes: {group_indices}")
            else:
                bond_overlap_counts = [count_bond_overlaps(kept_bboxes[idx], bond_bboxes,
                                                           bond_iou_threshold)
                                       for idx in group_indices]
                max_overlaps = max(bond_overlap_counts)
                candidates = [idx for idx, count in zip(group_indices, bond_overlap_counts)
                              if count == max_overlaps]
                # Ties on bond overlaps are broken by score.
                best_idx = max(candidates, key=lambda idx: kept_scores[idx])
                merged_bboxes.append(kept_bboxes[best_idx])
                merged_classes.append(kept_classes[best_idx])
                merged_scores.append(kept_scores[best_idx])
        else:
            if len(group_indices) == 1:
                merged_bboxes.append(kept_bboxes[i])
                merged_classes.append(kept_classes[i])
                merged_scores.append(kept_scores[i])
                print(f"Merged lower IOU @@ ONLY ONE box {i}")
            else:
                # Low-confidence group: fuse into the enclosing box and keep
                # the class/score of the best member.
                new_bbox = [
                    np.min(group_bboxes[:, 0]),  # x_min
                    np.min(group_bboxes[:, 1]),  # y_min
                    np.max(group_bboxes[:, 2]),  # x_max
                    np.max(group_bboxes[:, 3]),  # y_max
                ]
                merged_bboxes.append(new_bbox)
                merged_classes.append(group_classes[np.argmax(group_scores)])
                merged_scores.append(max_score)
                print(f"Merged low-score boxes: {group_indices}")
        used_indices.update(group_indices)

    print(f"after processs low IOU atom box::{len(merged_bboxes)} ")
    return (np.array(merged_bboxes), np.array(merged_classes), np.array(merged_scores))
def refine_boxes(atom_bboxes, atom_scores, atom_classes, bond_bboxes,
                 nms_iou_threshold=0.5, merge_threshold=0.5, score_threshold=0.5,
                 bond_iou_threshold=0.01, high_iou_threshold=0.8):
    """Iteratively apply NMS and merging until the box count stops shrinking.

    Args:
        atom_bboxes, atom_scores, atom_classes: Initial atom box data.
        bond_bboxes: Bond box data.
        nms_iou_threshold: IoU threshold forwarded to ``nms``.
        merge_threshold, score_threshold, bond_iou_threshold,
        high_iou_threshold: Forwarded to ``merge_low_iou_boxes``.

    Returns:
        tuple: (final_bboxes, final_scores, final_classes). Note the order:
        scores come before classes, unlike ``nms`` and
        ``merge_low_iou_boxes`` which return classes before scores.
    """
    current_bboxes = np.array(atom_bboxes)
    current_classes = np.array(atom_classes)
    current_scores = np.array(atom_scores)
    prev_count = len(current_bboxes) + 1  # guarantees at least one iteration
    iteration = 0

    # Stop as soon as a round fails to reduce the number of boxes.
    while len(current_bboxes) < prev_count:
        print(f"\nIteration {iteration}: Starting with {len(current_bboxes)} boxes")
        prev_count = len(current_bboxes)

        kept_bboxes, kept_classes, kept_scores = nms(
            current_bboxes, current_scores, current_classes,
            iou_threshold=nms_iou_threshold
        )
        print(f"After NMS: {len(kept_bboxes)} boxes")

        merged_bboxes, merged_classes, merged_scores = merge_low_iou_boxes(
            kept_bboxes, kept_classes, kept_scores, bond_bboxes,
            merge_threshold=merge_threshold,
            score_threshold=score_threshold,
            bond_iou_threshold=bond_iou_threshold,
            high_iou_threshold=high_iou_threshold
        )
        print(f"After merge: {len(merged_bboxes)} boxes")

        current_bboxes = merged_bboxes
        current_classes = merged_classes
        current_scores = merged_scores
        iteration += 1

    print(f"Converged after {iteration} iterations with {len(current_bboxes)} boxes")
    return current_bboxes, current_scores, current_classes


def merge_low_iou_boxes_old(kept_bboxes, kept_classes, kept_scores, merge_threshold=0.3):
    """Legacy merge of weakly overlapping boxes.

    Each unused box is grouped with the later boxes whose IoU with it lies
    strictly between 0.01 and ``merge_threshold``. If the best score in the
    group exceeds 0.5 only that best box is kept; otherwise the group is
    fused into one enclosing box with class 0 and the group's best score.

    Args:
        kept_bboxes: Atom boxes [x_min, y_min, x_max, y_max].
        kept_classes: Class labels, one per box.
        kept_scores: Confidence scores, one per box.
        merge_threshold: Upper IoU bound for grouping.

    Returns:
        tuple: (merged_bboxes, merged_classes, merged_scores) as NumPy
        arrays. With zero or one input box the inputs are returned unchanged.
    """
    if len(kept_bboxes) <= 1:
        return kept_bboxes, kept_classes, kept_scores

    # The grouping below uses index-list (fancy) indexing, which plain Python
    # lists do not support, so coerce the inputs first.
    kept_bboxes = np.asarray(kept_bboxes)
    kept_classes = np.asarray(kept_classes)
    kept_scores = np.asarray(kept_scores)

    merged_bboxes = []
    merged_classes = []
    merged_scores = []
    used_indices = set()

    for i in range(len(kept_bboxes)):
        if i in used_indices:
            continue
        # Collect the later boxes that overlap box i only slightly.
        current_indices = [i]
        for j in range(i + 1, len(kept_bboxes)):
            if j in used_indices:
                continue
            iou = calculate_iou(kept_bboxes[i], kept_bboxes[j])
            if iou < merge_threshold and iou > 0.01:
                current_indices.append(j)

        scores = kept_scores[current_indices]
        bboxes = kept_bboxes[current_indices]
        max_score = np.max(scores)
        max_score_idx = current_indices[np.argmax(scores)]

        if max_score > 0.5:
            # Confident enough: keep only the best-scoring box.
            merged_bboxes.append(kept_bboxes[max_score_idx])
            merged_classes.append(kept_classes[max_score_idx])
            merged_scores.append(kept_scores[max_score_idx])
        else:
            # Fuse into the enclosing box.
            new_bbox = [
                np.min(bboxes[:, 0]),  # x_min
                np.min(bboxes[:, 1]),  # y_min
                np.max(bboxes[:, 2]),  # x_max
                np.max(bboxes[:, 3]),  # y_max
            ]
            merged_bboxes.append(new_bbox)
            merged_classes.append(0)  # replaced with class 0 ("*" per original note)
            merged_scores.append(max_score)

        used_indices.update(current_indices)

    return np.array(merged_bboxes), np.array(merged_classes), np.array(merged_scores)
############################################################################################################################################################
# MolScribe-style evaluation helpers
from SmilesPE.pretokenizer import atomwise_tokenizer


def canonicalize_smiles(smiles, ignore_chiral=False, ignore_cistrans=False, replace_rgroup=True):
    """Canonicalize a SMILES string with RDKit.

    Args:
        smiles: Input SMILES. Anything that is not a non-empty string yields
            ``('', False)``.
        ignore_chiral: When True, canonicalize with ``useChiral=False``.
        ignore_cistrans: When True, strip '/' and '\\' bond markers first.
        replace_rgroup: When True, rewrite bracket tokens before parsing:
            ``[R<digits>]`` becomes ``[<digits>*]`` and any other bracket
            token RDKit cannot parse as an atom becomes ``*``.

    Returns:
        tuple: (canonical_smiles, success). On failure the (possibly
        rewritten) input SMILES is returned with ``success=False``.
    """
    if not isinstance(smiles, str) or smiles == '':
        return '', False
    if ignore_cistrans:
        smiles = smiles.replace('/', '').replace('\\', '')
    if replace_rgroup:
        tokens = atomwise_tokenizer(smiles)
        for j, token in enumerate(tokens):
            if token[0] == '[' and token[-1] == ']':
                symbol = token[1:-1]
                # symbol[:1] instead of symbol[0]: an empty "[]" token must
                # not raise IndexError; it falls through to the '*' branch.
                if symbol[:1] == 'R' and symbol[1:].isdigit():
                    tokens[j] = f'[{symbol[1:]}*]'
                elif Chem.AtomFromSmiles(token) is None:
                    tokens[j] = '*'
        smiles = ''.join(tokens)
    try:
        canon_smiles = Chem.CanonSmiles(smiles, useChiral=(not ignore_chiral))
        success = True
    except Exception:
        # Best effort: hand back the uncanonicalized string and flag failure.
        canon_smiles = smiles
        success = False
    return canon_smiles, success
def convert_smiles_to_canonsmiles(
        smiles_list, ignore_chiral=False, ignore_cistrans=False, replace_rgroup=True, num_workers=16):
    """Canonicalize many SMILES strings in a process pool.

    Args:
        smiles_list: Iterable of SMILES strings.
        ignore_chiral, ignore_cistrans, replace_rgroup: Forwarded to
            ``canonicalize_smiles``.
        num_workers: Size of the multiprocessing pool.

    Returns:
        tuple: (list of canonical SMILES, mean success rate). For empty
        input ``([], 0.0)`` is returned without starting a pool.
    """
    jobs = [(smiles, ignore_chiral, ignore_cistrans, replace_rgroup) for smiles in smiles_list]
    if not jobs:
        # zip(*[]) below would fail to unpack into two names.
        return [], 0.0
    with multiprocessing.Pool(num_workers) as p:
        results = p.starmap(canonicalize_smiles, jobs, chunksize=128)
    canon_smiles, success = zip(*results)
    return list(canon_smiles), np.mean(success)


def tanimoto_similarity(smiles1, smiles2):
    """Return the similarity of the RDKit fingerprints of two SMILES.

    Uses ``DataStructs.FingerprintSimilarity`` on ``Chem.RDKFingerprint``
    fingerprints. Any failure (for example an unparsable SMILES) yields 0.
    """
    try:
        mol1 = Chem.MolFromSmiles(smiles1)
        mol2 = Chem.MolFromSmiles(smiles2)
        fp1 = Chem.RDKFingerprint(mol1)
        fp2 = Chem.RDKFingerprint(mol2)
        return DataStructs.FingerprintSimilarity(fp1, fp2)
    except Exception:
        return 0


def compute_tanimoto_similarities(gold_smiles, pred_smiles, num_workers=16):
    """Compute pairwise ``tanimoto_similarity`` for aligned SMILES lists.

    Args:
        gold_smiles: Reference SMILES.
        pred_smiles: Predicted SMILES, aligned with ``gold_smiles``.
        num_workers: Size of the multiprocessing pool.

    Returns:
        list: One similarity per (gold, pred) pair.
    """
    pairs = list(zip(gold_smiles, pred_smiles))
    if not pairs:
        return []
    with multiprocessing.Pool(num_workers) as p:
        similarities = p.starmap(tanimoto_similarity, pairs)
    return similarities


class SmilesEvaluator(object):
    """Exact-match evaluation of predicted SMILES against gold SMILES."""

    # Sentinel substituted for empty/invalid gold entries so they can never
    # equal an empty prediction.
    _EMPTY_GOLD = "<empty>"

    def __init__(self, gold_smiles, num_workers=16, tanimoto=False):
        """Pre-compute the canonical forms of the gold SMILES.

        Args:
            gold_smiles: Reference SMILES strings.
            num_workers: Pool size used for all canonicalization calls.
            tanimoto: When True, ``evaluate`` also reports mean fingerprint
                similarity.
        """
        self.gold_smiles = gold_smiles
        self.num_workers = num_workers
        self.tanimoto = tanimoto
        # Canonical form with cis/trans markers removed (chirality kept).
        self.gold_smiles_cistrans, _ = convert_smiles_to_canonsmiles(gold_smiles,
                                                                     ignore_cistrans=True,
                                                                     num_workers=num_workers)
        # Canonical form with both chirality and cis/trans removed.
        self.gold_smiles_chiral, _ = convert_smiles_to_canonsmiles(gold_smiles,
                                                                   ignore_chiral=True, ignore_cistrans=True,
                                                                   num_workers=num_workers)
        self.gold_smiles_cistrans = self._replace_empty(self.gold_smiles_cistrans)
        self.gold_smiles_chiral = self._replace_empty(self.gold_smiles_chiral)

    def _replace_empty(self, smiles_list):
        """Replace empty SMILES in the gold, otherwise it will be considered
        correct if both pred and gold is empty."""
        return [smiles if isinstance(smiles, str) and smiles != "" else self._EMPTY_GOLD
                for smiles in smiles_list]

    def evaluate(self, pred_smiles, include_details=False):
        """Score predictions against the gold SMILES.

        Args:
            pred_smiles: Predicted SMILES, aligned with the gold list.
            include_details: When True, also return the per-sample boolean
                match array under ``'canon_smiles_details'``.

        Returns:
            dict with keys ``'canon_smiles'`` (match ignoring cis/trans),
            ``'graph'`` (match ignoring chirality and cis/trans),
            ``'chiral'`` (match restricted to gold entries containing '@',
            or -1 if there are none) and optionally ``'tanimoto'``.
        """
        results = {}
        if self.tanimoto:
            results['tanimoto'] = np.mean(compute_tanimoto_similarities(
                self.gold_smiles, pred_smiles, num_workers=self.num_workers))
        # Ignore double bond cis/trans
        pred_smiles_cistrans, _ = convert_smiles_to_canonsmiles(pred_smiles,
                                                                ignore_cistrans=True,
                                                                num_workers=self.num_workers)
        results['canon_smiles'] = np.mean(np.array(self.gold_smiles_cistrans) == np.array(pred_smiles_cistrans))
        if include_details:
            results['canon_smiles_details'] = (np.array(self.gold_smiles_cistrans) == np.array(pred_smiles_cistrans))
        # Ignore chirality (Graph exact match)
        pred_smiles_chiral, _ = convert_smiles_to_canonsmiles(pred_smiles,
                                                              ignore_chiral=True, ignore_cistrans=True,
                                                              num_workers=self.num_workers)
        results['graph'] = np.mean(np.array(self.gold_smiles_chiral) == np.array(pred_smiles_chiral))
        # Evaluate on molecules with chiral centers
        chiral = np.array([[g, p] for g, p in zip(self.gold_smiles_cistrans, pred_smiles_cistrans) if '@' in g])
        results['chiral'] = np.mean(chiral[:, 0] == chiral[:, 1]) if len(chiral) > 0 else -1
        return results
############################################################################################################################################################
def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, max_norm: float = 0, **kwargs):
    """Run one training epoch and return the averaged logged metrics.

    Args:
        model: Network to train; called as ``model(samples, targets)``.
        criterion: Loss module returning a dict of loss terms.
        data_loader: Iterable yielding ``(samples, targets)`` batches, where
            ``targets`` is a list of dicts of tensors.
        optimizer: Optimizer stepping the model parameters.
        device: Device the batches are moved to.
        epoch: Epoch number, used only in the log header.
        max_norm: Gradient-norm clipping threshold; clipping is skipped
            when it is not positive.
        **kwargs: Optional ``print_freq`` (default 10), ``ema`` (object with
            ``update(model)``) and ``scaler`` (gradient scaler; enables the
            mixed-precision path when given).

    Returns:
        dict: ``{meter_name: global_avg}`` for every meter in the logger.

    NOTE(review): MetricLogger, SmoothedValue and reduce_dict come from
    ``src.misc``, whose import is commented out at the top of this file;
    this function needs that import restored to run.
    """
    model.train()
    criterion.train()
    metric_logger = MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
    # metric_logger.add_meter('class_error', SmoothedValue(window_size=1, fmt='{value:.2f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = kwargs.get('print_freq', 10)

    ema = kwargs.get('ema', None)
    scaler = kwargs.get('scaler', None)

    # autocast expects the bare device type ("cuda", "cpu"); str(device)
    # would give e.g. "cuda:0" for an indexed device, which it rejects.
    device_type = torch.device(device).type

    for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
        samples = samples.to(device)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        if scaler is not None:
            # Mixed precision: forward under autocast, loss in full precision.
            with torch.autocast(device_type=device_type, cache_enabled=True):
                outputs = model(samples, targets)

            with torch.autocast(device_type=device_type, enabled=False):
                loss_dict = criterion(outputs, targets)

            loss = sum(loss_dict.values())
            scaler.scale(loss).backward()

            if max_norm > 0:
                # Gradients must be unscaled before clipping by norm.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
        else:
            outputs = model(samples, targets)
            loss_dict = criterion(outputs, targets)

            loss = sum(loss_dict.values())
            optimizer.zero_grad()
            loss.backward()

            if max_norm > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

            optimizer.step()

        # ema
        if ema is not None:
            ema.update(model)

        loss_dict_reduced = reduce_dict(loss_dict)
        loss_value = sum(loss_dict_reduced.values())

        if not math.isfinite(loss_value):
            # A NaN/inf loss aborts the whole process.
            print("Loss is {}, stopping training".format(loss_value))
            print(loss_dict_reduced)
            sys.exit(1)

        metric_logger.update(loss=loss_value, **loss_dict_reduced)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
def remove_bond_directions_if_no_chiral(mol):
    """Clear wedge/dash directions on single bonds when no chiral center exists.

    Args:
        mol: RDKit molecule, or None.

    Returns:
        The same molecule object (modified in place), or None when ``mol``
        is None. Molecules that have at least one chiral center (assigned or
        unassigned) are returned untouched.
    """
    if mol is None:
        return None

    # Any chiral center, including unassigned ones, means the bond
    # directions must be preserved.
    if Chem.FindMolChiralCenters(mol, includeUnassigned=True):
        return mol

    single = Chem.BondType.SINGLE
    for bond in mol.GetBonds():
        if bond.GetBondType() == single:
            bond.SetBondDir(Chem.BondDir.NONE)
    return mol
#######################################################################################
def molExpanding(mol_rebuit, placeholder_atoms, wdbs, bond_dirs, alignmol=False):
    """Expand abbreviation placeholders in a rebuilt molecule.

    Args:
        mol_rebuit: Molecule containing placeholder atoms; it is deep-copied
            and not modified.
        placeholder_atoms: Placeholder description passed to ``expandABB``.
        wdbs: Iterable of bond records; ``b[0]``/``b[1]`` are atom indices in
            the pre-expansion molecule and ``b[3]`` is a key into
            ``bond_dirs``. Only used when ``alignmol`` is True.
        bond_dirs: Mapping from ``b[3]`` to an RDKit bond direction.
        alignmol: When True, align the expanded molecule onto the original
            via a maximum common substructure and copy the bond directions
            across using that atom mapping.

    Returns:
        tuple: (smiles, mol). The SMILES is produced before the final stereo
        assignment calls below run on the returned molecule.
    """
    cm = copy.deepcopy(mol_rebuit)
    expand_mol, expand_smiles = expandABB(cm, ABBREVIATIONS, placeholder_atoms)
    rdm = copy.deepcopy(expand_mol)
    AllChem.Compute2DCoords(rdm)
    target_mol, ref_mol = rdm, cm

    if alignmol:
        mcs = rdFMCS.FindMCS([target_mol, ref_mol],  # larger, smaller order
                             atomCompare=rdFMCS.AtomCompare.CompareAny,
                             # bondCompare=rdFMCS.BondCompare.CompareAny,
                             ringCompare=rdFMCS.RingCompare.IgnoreRingFusion,
                             matchChiralTag=False,
                             )
        atommaping_pairs = g_atompair_matches([target_mol, ref_mol], mcs)
        # presumably a sequence of (expanded_idx, original_idx) pairs — TODO confirm
        atomMap = atommaping_pairs[0]
        try:
            rdkit.Chem.rdMolAlign.AlignMol(prbMol=target_mol, refMol=ref_mol,
                                           atomMap=atomMap, maxIters=2000000)
        except Exception as e:
            # Alignment is best effort; the atom map is still used below.
            print(atomMap, "@@@@")
            print(e)

        # Map pre-expansion atom index -> expanded atom index.
        p2c = {pre: cur for cur, pre in atomMap}
        for b in wdbs:  # add bond direction
            p0, p1 = int(b[0]), int(b[1])
            # Atoms outside the common substructure have no counterpart.
            if p0 in p2c and p1 in p2c:
                b_ = target_mol.GetBondBetweenAtoms(p2c[p0], p2c[p1])
                if b_:
                    b_.SetBondDir(bond_dirs[b[3]])
        # Generated before the assignment steps below, so stereo set through
        # bond directions is not reflected in this string yet.
        expandStero_smi = Chem.MolToSmiles(target_mol)
    else:
        expandStero_smi = expand_smiles

    m = target_mol.GetMol()
    # Chem.SanitizeMol(m)
    Chem.DetectBondStereochemistry(m)
    Chem.AssignChiralTypesFromBondDirs(m)
    Chem.AssignStereochemistry(m)
    return expandStero_smi, m


def remove_backslash_and_slash(input_string):
    """Return ``input_string`` with every '\\' and '/' removed."""
    return input_string.replace("\\", "").replace("/", "")


def remove_number_before_star(input_string):
    """Strip digits that directly precede a '*' (except a trailing '*').

    The '*' is left alone when the character two positions before it is a
    letter, e.g. the ring-closure digit in ``c1*`` belongs to the atom
    ``c`` and must stay. A '*' in the last position is never processed.

    Example: ``*c1c(*)c(*)c(C(*)(*)C(C)C)c(*)c1*`` is returned unchanged.

    Args:
        input_string: SMILES-like string.

    Returns:
        str: The string with the qualifying digit runs removed.
    """
    chars = list(input_string)
    last = len(chars) - 1
    for i, ch in enumerate(chars):
        if ch != '*' or i == last:
            continue
        j = i - 1
        # The previous loop hit `continue` here without advancing, which
        # spun forever; it also indexed chars[j - 1] with j < 1, wrapping
        # around to the end of the string.
        if j >= 1 and chars[j - 1].isalpha():
            continue
        while j >= 0 and chars[j].isdigit():
            chars[j] = ''
            j -= 1
    return ''.join(chars)
def remove_SP(input_string):
    """Strip square-planar / trigonal-bipyramidal stereo tags from a SMILES.

    ``@SP1``..``@SP3`` tags are always removed. If the string still contains
    ``@TB``, bracket atoms are rewritten to keep only the text before their
    first '@'.

    NOTE(review): the ``@TB`` rewrite is applied to every bracket atom with
    at most one '@', so e.g. ``[C@H]`` also becomes ``[C]`` (dropping the H)
    whenever ``@TB`` occurs anywhere in the string — confirm this is
    intended.
    """
    input_string = re.sub(r'@SP[1-3]', '', input_string)
    if '@TB' in input_string:
        input_string = re.sub(r'\[([^@]*)@?[A-Z0-9]*\]', r'[\1]', input_string)
    return input_string


def rdkit_canonicalize_smiles(smiles):
    """Canonicalize a SMILES (chirality ignored), mapping R-groups to '*'.

    A bracket token is replaced by '*' when its content starts with an
    element-like prefix immediately followed by digits (e.g. ``[R1]``),
    when it is listed in ``RGROUP_SYMBOLS``, or when RDKit cannot parse it
    as an atom.

    Args:
        smiles: Input SMILES string.

    Returns:
        tuple: (canonical_smiles, success). On failure the rewritten,
        uncanonicalized SMILES is returned with ``success=False``.
    """
    # Anchored with re.match on purpose: a search would also hit tokens such
    # as [BH2], which must stay untouched.
    Aad_string = r'([A-Z][a-z]*)([0-9]+)'
    tokens = atomwise_tokenizer(smiles)
    for j, token in enumerate(tokens):
        if token[0] == '[' and token[-1] == ']':
            symbol = token[1:-1]
            matches = re.match(Aad_string, symbol)
            if matches:
                letters, numbers = matches.groups()
                print(f"{letters} {numbers}")
                tokens[j] = '*'
            elif symbol in RGROUP_SYMBOLS:
                tokens[j] = '*'
            elif Chem.AtomFromSmiles(token) is None:
                tokens[j] = '*'
    smiles = ''.join(tokens)
    try:
        canon_smiles = Chem.CanonSmiles(smiles, useChiral=False)
        success = True
    except Exception:
        # Best effort: hand back the rewritten string and flag failure.
        canon_smiles = smiles
        success = False
    return canon_smiles, success


def NoRadical_Smi(smi):
    """Return a SMILES in which radical electrons are replaced by hydrogens.

    Args:
        smi: Input SMILES.

    Returns:
        str: SMILES regenerated by RDKit after clearing radicals. If RDKit
        cannot parse ``smi`` the input is returned unchanged.
    """
    aa = Chem.MolFromSmiles(smi)
    if aa is None:
        return smi
    for atom in aa.GetAtoms():
        if atom.GetNumRadicalElectrons() > 0:
            # Drop the radical and add explicit hydrogens to fill the gap
            # between total and explicit valence.
            atom.SetNumRadicalElectrons(0)
            atom.SetNumExplicitHs(atom.GetTotalValence() - atom.GetExplicitValence())
    return Chem.MolToSmiles(aa)


import logging
def check_and_fix_valence(smiles_or_list):
    """Check atom valences and repair unusual nitrogen valences.

    Accepts a SMILES string or a two-item list ``[smiles, other_part]``.
    Valences of C, N, O, H and F are compared against fixed expected values
    and every mismatch is recorded as a warning. When a nitrogen mismatch is
    found, explicit hydrogens are added (valence < 3) or removed
    (valence > 3, if enough explicit hydrogens exist) on every nitrogen.

    Args:
        smiles_or_list: SMILES string, or ``[smiles, other_part]`` where
            ``other_part`` may or may not be a valid SMILES.

    Returns:
        tuple: (corrected_smiles_or_list, warnings). The first item is a
        list ``[smiles, other_part]`` when a truthy ``other_part`` was
        supplied, otherwise a string. If the main SMILES is empty or cannot
        be parsed, the input is returned unchanged.

    NOTE(review): a list input whose ``other_part`` is empty/None yields a
    bare string on success but the original list on parse failure — confirm
    what callers expect.
    NOTE(review): formal charge is not considered, so a charged nitrogen
    with valence 4 is also treated as unusual — confirm this is intended.
    """
    # Configures the root logger on every call (no-op once handlers exist).
    logging.basicConfig(level=logging.WARNING)
    warnings = []

    # Expected total valences for the atoms that are checked.
    standard_valences = {
        'C': [4],
        'N': [3],  # neutral nitrogen (amines, amides)
        'O': [2],
        'H': [1],
        'F': [1]
    }

    if isinstance(smiles_or_list, list):
        smiles, other_part = smiles_or_list
    else:
        smiles, other_part = smiles_or_list, None

    mol = Chem.MolFromSmiles(smiles, sanitize=False) if smiles else None
    if mol is None:
        warnings.append(f"Invalid SMILES: {smiles}")
        return smiles_or_list, warnings

    # other_part may be a plain suffix/prefix rather than a SMILES.
    other_part_mol = None
    if other_part:
        try:
            other_part_mol = Chem.MolFromSmiles(other_part, sanitize=False)
        except Exception:
            other_part_mol = None

    def process_molecule(mol, is_other_part=False):
        """Check one molecule; return (molecule, was_corrected)."""
        corrected = False
        nitrogen_flagged = False
        prefix = "other_part" if is_other_part else "SMILES"

        # Unsanitized molecules need their valence caches computed first.
        mol.UpdatePropertyCache(strict=False)

        for atom in mol.GetAtoms():
            symbol = atom.GetSymbol()
            valence = atom.GetTotalValence()
            expected_valences = standard_valences.get(symbol, [valence])
            if valence not in expected_valences:
                warnings.append(f"Unusual valence in {prefix} for {symbol}: {valence} (expected {expected_valences})")
                if symbol == 'N':
                    nitrogen_flagged = True

        # Fix nitrogen valence issues by adjusting hydrogens.
        if nitrogen_flagged:
            rw_mol = Chem.RWMol(mol)
            for atom in rw_mol.GetAtoms():
                if atom.GetSymbol() != 'N':
                    continue
                valence = atom.GetTotalValence()
                if valence < 3:
                    hydrogens_needed = 3 - valence
                    atom.SetNumExplicitHs(atom.GetNumExplicitHs() + hydrogens_needed)
                    corrected = True
                elif valence > 3:
                    hydrogens_to_remove = valence - 3
                    current_hydrogens = atom.GetNumExplicitHs()
                    if current_hydrogens >= hydrogens_to_remove:
                        atom.SetNumExplicitHs(current_hydrogens - hydrogens_to_remove)
                        corrected = True
                    else:
                        warnings.append(f"Cannot reduce N valence in {prefix} to 3 without removing non-H bonds")
            if corrected:
                mol = rw_mol.GetMol()

        if not corrected:
            return mol, False

        # With catchErrors=True SanitizeMol does not raise; it returns the
        # failed operation instead, so the return value must be inspected.
        try:
            failed_op = Chem.SanitizeMol(mol, catchErrors=True)
        except Exception as e:
            warnings.append(f"Failed to sanitize {prefix} after correction: {str(e)}")
            return mol, False
        if failed_op != Chem.SanitizeFlags.SANITIZE_NONE:
            warnings.append(f"Failed to sanitize {prefix} after correction: {failed_op}")
            return mol, False
        return mol, True

    mol, mol_corrected = process_molecule(mol)
    # Only regenerate the SMILES when something was actually changed.
    corrected_smiles = Chem.MolToSmiles(mol) if mol_corrected else smiles

    corrected_other_part = other_part
    if other_part_mol:
        other_part_mol, other_corrected = process_molecule(other_part_mol, is_other_part=True)
        corrected_other_part = Chem.MolToSmiles(other_part_mol) if other_corrected else other_part

    if other_part:
        return [corrected_smiles, corrected_other_part], warnings
    return corrected_smiles, warnings
def molfpsim(original_smiles, test_smiles):
    """Fingerprint similarities between a reference and a predicted SMILES.

    Both inputs are reduced to their longest '.'-separated component,
    stripped of cis/trans markers, isotope-labelled dummies (``[1*]`` ->
    ``[*]``) and SP/TB stereo tags, then canonicalized. The prediction
    additionally goes through ``check_and_fix_valence``.

    Args:
        original_smiles: Reference SMILES.
        test_smiles: Predicted SMILES.

    Returns:
        tuple: (morgan_dice, rdk_similarity) — Dice similarity of Morgan
        fingerprints (radius 2, no chirality) and
        ``DataStructs.FingerprintSimilarity`` of RDKit fingerprints.
        ``(0.0, 0.0)`` if either normalized SMILES cannot be parsed.
    """
    # Only the longest fragment is compared (desalting).
    test_smiles = select_longest_smiles(test_smiles)
    original_smiles = select_longest_smiles(original_smiles)
    test_smiles, _ = check_and_fix_valence(test_smiles)

    original_smiles = remove_backslash_and_slash(original_smiles)
    test_smiles = remove_backslash_and_slash(test_smiles)
    original_smiles = re.sub(r'\[(\d+)\*', '[*', original_smiles)  # [1*] --> [*]
    test_smiles = re.sub(r'\[(\d+)\*', '[*', test_smiles)
    original_smiles = remove_SP(original_smiles)
    test_smiles = remove_SP(test_smiles)

    rd_smi_ori, success1 = rdkit_canonicalize_smiles(original_smiles)  # R --> *
    if "S" in rd_smi_ori and success1:
        # Replace radical electrons by hydrogens on the reference.
        rd_smi_ori = NoRadical_Smi(rd_smi_ori)
    rd_smi, _ = rdkit_canonicalize_smiles(test_smiles)

    mol1 = Chem.MolFromSmiles(rd_smi_ori)
    mol2 = Chem.MolFromSmiles(rd_smi)
    if mol1 is None or mol2 is None:
        return 0.0, 0.0

    # GetMorganFingerprint requires the radius argument; it was previously
    # omitted, which made the call fail. Radius 2 is used here.
    morganfps1 = AllChem.GetMorganFingerprint(mol1, 2, useChirality=False)
    morganfps2 = AllChem.GetMorganFingerprint(mol2, 2, useChirality=False)
    morgan_tani = DataStructs.DiceSimilarity(morganfps1, morganfps2)

    fp1 = Chem.RDKFingerprint(mol1)
    fp2 = Chem.RDKFingerprint(mol2)
    tanimoto = DataStructs.FingerprintSimilarity(fp1, fp2)
    return morgan_tani, tanimoto


def comparing_smiles2(original_smiles, test_smiles):
    """Decide whether a predicted SMILES matches the reference.

    Both inputs are stripped of cis/trans markers, isotope-labelled dummies
    and SP/TB stereo tags, then canonicalized without chirality. The pair
    matches when any of the canonical SMILES, the kekulized SMILES or the
    InChI strings are equal. If the reference could not be canonicalized
    the pair is reported as a match.

    Args:
        original_smiles: Reference SMILES.
        test_smiles: Predicted SMILES.

    Returns:
        bool: True when the two are considered the same structure.
    """
    original_smiles = remove_backslash_and_slash(original_smiles)
    test_smiles = remove_backslash_and_slash(test_smiles)
    original_smiles = re.sub(r'\[(\d+)\*', '[*', original_smiles)  # [1*] --> [*]
    test_smiles = re.sub(r'\[(\d+)\*', '[*', test_smiles)
    original_smiles = remove_SP(original_smiles)
    test_smiles = remove_SP(test_smiles)

    rd_smi_ori, success1 = rdkit_canonicalize_smiles(original_smiles)  # R --> *
    if "S" in rd_smi_ori and success1:
        # Replace radical electrons by hydrogens on the reference.
        rd_smi_ori = NoRadical_Smi(rd_smi_ori)
    rd_smi, success2 = rdkit_canonicalize_smiles(test_smiles)
    original_smiles, test_smiles = rd_smi_ori, rd_smi

    try:
        original_mol = Chem.MolFromSmiles(original_smiles)
        # The rebuilt prediction may not pass RDKit sanitization on load.
        test_mol = Chem.MolFromSmiles(test_smiles, sanitize=False)

        if success2 and success1:
            # Fast path: compare RDKit's own output for both molecules.
            RDarom_smi = Chem.MolToSmiles(original_mol)
            RDarom_smi_test = Chem.MolToSmiles(test_mol)
            if RDarom_smi == RDarom_smi_test:
                return True
            else:
                print(f'smiles ori,pred after Chem.CanonSmiles:\n{RDarom_smi}\n{RDarom_smi_test}\n')

        if original_mol:
            Chem.SanitizeMol(original_mol)
            keku_smi_ori = Chem.MolToSmiles(original_mol, kekuleSmiles=True)
        else:
            keku_smi_ori = original_smiles
        if test_mol:
            Chem.SanitizeMol(test_mol)
            keku_smi = Chem.MolToSmiles(test_mol, kekuleSmiles=True)
        else:
            keku_smi = test_smiles

        if '*' not in keku_smi:
            keku_inch_ori = Chem.MolToInchi(Chem.MolFromSmiles(keku_smi_ori))
            keku_inch_test = Chem.MolToInchi(Chem.MolFromSmiles(keku_smi))
        else:
            # No InChI comparison for dummy atoms: unequal sentinels.
            keku_inch_ori = 1
            keku_inch_test = 2

        rd_smi = Chem.MolToSmiles(test_mol)
        rd_smi_ori = Chem.MolToSmiles(original_mol)
    except Exception as e:
        print(f"comparing_smiles@@@ kekulize or SanitizeMol problems")
        print(e, "!!!!!!!\n")
        # Unequal sentinels so only the canonical SMILES can still match.
        keku_inch_ori = 1
        keku_inch_test = 2
        keku_smi = 1
        keku_smi_ori = 2

    # MolScribe rule: a reference that is invalid even after '*' replacement
    # is not held against the prediction.
    if not success1:
        rd_smi_ori = rd_smi

    # The reference may be written in kekulized style, hence three checks.
    if rd_smi_ori == rd_smi or keku_smi_ori == keku_smi or keku_inch_ori == keku_inch_test:
        return True
    else:
        return False
def smiles12_comparing(original_smiles,test_smiles):
    """Decide whether a predicted SMILES matches the reference.

    Both inputs are stripped of cis/trans markers, isotope-labelled dummies
    (``[1*]`` -> ``[*]``) and SP/TB stereo tags, then canonicalized without
    chirality. The pair matches when any of the canonical SMILES, the
    kekulized SMILES or the InChI strings are equal. If the reference could
    not be canonicalized the pair is reported as a match.

    Args:
        original_smiles: Reference SMILES.
        test_smiles: Predicted SMILES.

    Returns:
        bool: True when the two are considered the same structure.
    """
    # Same clean-up for both sides, reference first at every stage.
    gold = remove_backslash_and_slash(original_smiles)
    pred = remove_backslash_and_slash(test_smiles)
    gold = re.sub(r'\[(\d+)\*', '[*', gold)
    pred = re.sub(r'\[(\d+)\*', '[*', pred)
    gold = remove_SP(gold)
    pred = remove_SP(pred)

    gold_smi, gold_ok = rdkit_canonicalize_smiles(gold)
    pred_smi, _pred_ok = rdkit_canonicalize_smiles(pred)

    # These two are what the final check compares; they keep the values
    # above unless RDKit regenerates them successfully below.
    gold_canon, pred_canon = gold_smi, pred_smi

    try:
        gold_mol = Chem.MolFromSmiles(gold_smi)
        # The rebuilt prediction may not pass RDKit sanitization on load.
        pred_mol = Chem.MolFromSmiles(pred_smi, sanitize=False)

        if gold_mol:
            Chem.SanitizeMol(gold_mol)
            gold_kekule = Chem.MolToSmiles(gold_mol, kekuleSmiles=True)
        else:
            gold_kekule = gold_smi

        if pred_mol:
            Chem.SanitizeMol(pred_mol)
            pred_kekule = Chem.MolToSmiles(pred_mol, kekuleSmiles=True)
        else:
            pred_kekule = pred_smi

        if '*' in pred_kekule:
            # No InChI comparison for dummy atoms: unequal sentinels.
            gold_inchi, pred_inchi = 1, 2
        else:
            gold_inchi = Chem.MolToInchi(Chem.MolFromSmiles(gold_kekule))
            pred_inchi = Chem.MolToInchi(Chem.MolFromSmiles(pred_kekule))

        pred_canon = Chem.MolToSmiles(pred_mol)
        gold_canon = Chem.MolToSmiles(gold_mol)
    except Exception as e:
        print(f"comparing_smiles@@@ kekulize or SanitizeMol problems")
        print(e,"!!!!!!!\n")
        # Unequal sentinels so only the canonical SMILES can still match.
        gold_inchi, pred_inchi = 1, 2
        pred_kekule, gold_kekule = 1, 2

    # MolScribe rule: a reference that is invalid even after '*' replacement
    # is not held against the prediction.
    if not gold_ok:
        gold_canon = pred_canon

    # The reference may be written in kekulized style, hence three checks.
    return bool(gold_canon == pred_canon
                or gold_kekule == pred_kekule
                or gold_inchi == pred_inchi)
if original_mol: Chem.SanitizeMol(original_mol) keku_smi_ori=Chem.MolToSmiles(original_mol,kekuleSmiles=True) else: keku_smi_ori=original_smiles if test_mol: Chem.SanitizeMol(test_mol) keku_smi=Chem.MolToSmiles(test_mol,kekuleSmiles=True) else: keku_smi=test_smiles if '*' not in keku_smi: keku_inch_ori= Chem.MolToInchi(Chem.MolFromSmiles(keku_smi_ori)) keku_inch_test= Chem.MolToInchi(Chem.MolFromSmiles(keku_smi)) else: keku_inch_ori= 1 keku_inch_test= 2 rd_smi=Chem.MolToSmiles(test_mol)#need improve the acc rd_smi_ori=Chem.MolToSmiles(original_mol) except Exception as e:#TODO fixme here print(f"comparing_smiles@@@ kekulize or SanitizeMol problems")# original_smiles,test_smiles\n{original_smiles}\n{test_smiles}") print(e,"!!!!!!!\n") keku_inch_ori= 1 keku_inch_test= 2 keku_smi=1 keku_smi_ori=2 #add molscribe rules here if not success1:#ori smiles still invaild even after * replaced rd_smi_ori = rd_smi # else: # if canon_smiles1 == canon_smiles2: # rd_smi_ori = rd_smi # else: if rd_smi_ori == rd_smi or keku_smi_ori == keku_smi or keku_inch_ori==keku_inch_test :#as orinial smiles may use kekuleSmiles style return True else:return False def comparing_smiles(new_row,test_smiles):#I2M use the coordinates, so 2D coformation should be always original_smiles=new_row['SMILESori'] original_smiles = remove_backslash_and_slash(original_smiles)#c/s test_smiles = remove_backslash_and_slash(test_smiles) original_smiles = re.sub(r'\[(\d+)\*', '[*',original_smiles)#[1*]-->[*] test_smiles = re.sub(r'\[(\d+)\*', '[*',test_smiles) original_smiles = remove_SP(original_smiles)#additional complex space stero from coordinates, most not used test_smiles = remove_SP(test_smiles) rd_smi_ori, success1=rdkit_canonicalize_smiles(original_smiles) rd_smi, success2=rdkit_canonicalize_smiles(test_smiles) original_smiles,test_smiles=rd_smi_ori,rd_smi try: original_mol = Chem.MolFromSmiles(original_smiles)#considering whe nmmet abbrev test_mol = Chem.MolFromSmiles(test_smiles,sanitize=False)#as build 
mol may not sanitized for rdkit if original_mol: Chem.SanitizeMol(original_mol) keku_smi_ori=Chem.MolToSmiles(original_mol,kekuleSmiles=True) else: keku_smi_ori=original_smiles if test_mol: Chem.SanitizeMol(test_mol) keku_smi=Chem.MolToSmiles(test_mol,kekuleSmiles=True) else: keku_smi=test_smiles if '*' not in keku_smi: keku_inch_ori= Chem.MolToInchi(Chem.MolFromSmiles(keku_smi_ori)) keku_inch_test= Chem.MolToInchi(Chem.MolFromSmiles(keku_smi)) else: keku_inch_ori= 1 keku_inch_test= 2 rd_smi=Chem.MolToSmiles(test_mol)#need improve the acc rd_smi_ori=Chem.MolToSmiles(original_mol) except Exception as e:#TODO fixme here print(f"comparing_smiles@@@ kekulize or SanitizeMol problems")# original_smiles,test_smiles\n{original_smiles}\n{test_smiles}") print(new_row) print(e,"!!!!!!!\n") keku_inch_ori= 1 keku_inch_test= 2 keku_smi=1 keku_smi_ori=2 #add molscribe rules here if not success1:#ori smiles still invaild even after * replaced rd_smi_ori = rd_smi # else: # if canon_smiles1 == canon_smiles2: # rd_smi_ori = rd_smi # else: if rd_smi_ori == rd_smi or keku_smi_ori == keku_smi or keku_inch_ori==keku_inch_test :#as orinial smiles may use kekuleSmiles style return True else:return False def bbox2center(bbox): x_center = (bbox[:, 0] + bbox[:, 2]) / 2 y_center = (bbox[:, 1] + bbox[:, 3]) / 2 # center_coords = torch.stack((x_center, y_center), dim=1) centers = np.stack((x_center, y_center), axis=1) return centers import cv2 BONDDIRECT=['ENDUPRIGHT', 'BEGINWEDGE', 'BEGINDASH', 'ENDDOWNRIGHT'] def reorder_bond_bbox(bond_bbox, single_atom_bond): # 分离普通索引和需要后置的索引 normal_indices = [] special_indices = [] # 获取需要后置的 key keys_to_move = set(single_atom_bond.keys()) # 分类所有索引 for i in range(len(bond_bbox)): if i in keys_to_move: special_indices.append(i) else: normal_indices.append(i) # 新顺序:普通索引在前,特殊索引在后 new_order = normal_indices + special_indices # 重排 bond_bbox reordered_bbox = [bond_bbox[i] for i in new_order] return reordered_bbox def boxes_overlap(box1, box2): """ 检查两个边界框是否重叠 box1, 
box2: [x1, y1, x2, y2] """ return not (box1[2] < box2[0] or box1[0] > box2[2] or box1[3] < box2[1] or box1[1] > box2[3]) def calculate_center(box): """ 计算边界框的中心点 """ return np.array([(box[0] + box[2]) / 2, (box[1] + box[3]) / 2]) def merge_boxes(box1, box2): """ 合并两个边界框,返回新边界框 [x1, y1, x2, y2] """ return [ min(box1[0], box2[0]), min(box1[1], box2[1]), max(box1[2], box2[2]), max(box1[3], box2[3]) ] def get_merged_box(boxes): """Calculate the smallest box encompassing all given boxes.""" x_mins = [box[0] for box in boxes] y_mins = [box[1] for box in boxes] x_maxs = [box[2] for box in boxes] y_maxs = [box[3] for box in boxes] return [min(x_mins), min(y_mins), max(x_maxs), max(y_maxs)] def box_area(box): """Calculate the area of a box.""" return (box[2] - box[0]) * (box[3] - box[1]) def Newbox_(atom_bbox,bond_bbox, lab2idx): #add H atom box when on direction bond new_atoms=[] b_len=3 single_odd_b2a=dict() for bi,bb in enumerate(bond_bbox): overlapped_atoms = [] overlapped_abox=[] for ai,aa in enumerate(atom_bbox): overlap_flag=boxes_overlap(bb, aa)#TODO use tghe atom bond box overlap get bond atom mapping,then built mol if overlap_flag: # print(bb, aa,overlap_flag) overlapped_atoms.append(ai) overlapped_abox.append(aa) if len(overlapped_atoms) == 1: single_odd_b2a[bi]=overlapped_atoms # Compute the non-overlapping part of the bond box to place hydrogen non_overlapping_x,non_overlapping_y=boxes_overlap2(overlapped_abox[0], bb) new_atom_out={'bbox': np.array([non_overlapping_x - b_len, non_overlapping_y - b_len, non_overlapping_x + b_len, non_overlapping_y + b_len]).reshape(-1,4), 'bbox_centers': np.array([non_overlapping_x,non_overlapping_y]).reshape(-1,2), 'scores': np.array([1.0]), 'pred_classes': np.array([lab2idx['H']])} new_atoms.append(new_atom_out) return new_atoms, single_odd_b2a def has_boxes(data): #TO CHECK OCR detct used or not return isinstance(data, list) and len(data) > 0 and all( isinstance(item, list) and len(item) == 2 and isinstance(item[0], list) and 
len(item[0]) == 4 for item in data ) def AtomBox2bondBox(atom_box,bond_bbox): b_nei=[] overlap=True for bi,bb in enumerate(bond_bbox): overlap_flag=boxes_overlap(bb, atom_box)#TODO use tghe atom bond box overlap get bond atom mapping,then built mol if overlap_flag: b_nei.append(bi) if len(b_nei)==0: # delt_hei.append(hei) overlap=False return overlap, b_nei import torchvision.transforms.v2 as T def image_to_tensor(image_path,debug=True): image = Image.open(image_path) w, h = image.size # 处理灰度或其他模式 if image.mode == "L": if debug: print("检测到灰度图像 (1 通道),转换为 RGB...") image = image.convert("RGB") elif image.mode != "RGB": if debug: print(f"检测到 {image.mode} 模式,转换为 RGB...") image = image.convert("RGB") # Define a transform to convert the image to a tensor and normalize it transform = T.Compose([ T.Resize((640, 640)), # 调整大小 # T.ToImageTensor(), # 转换为 PyTorch Tensor T.ToTensor(), lambda x: x.to(torch.float32), # 手动转换数据类型# T.ConvertDtype(dtype=torch.float32), # 转换数据类型 ]) # Apply the transform to the image tensor = transform(image) return tensor,w,h # from src.zoo.rtdetr.rtdetr_postprocessor import RTDETRPostProcessor @torch.no_grad() def evaluate_x(model: torch.nn.Module, criterion: torch.nn.Module, postprocessors, data_loader, device, outcsv_filename=f'/home/jovyan/rt-detr/rt-detr/output/output_charge_CLEF.csv', visual_check=False, other2ppsocr=True, getacc=False, ): postprocessor2=RTDETRPostProcessor(num_classes=30, use_focal_loss=True, num_top_queries=300, remap_mscoco_category=False) output_directory = os.path.dirname(outcsv_filename) prefix_f = os.path.basename(outcsv_filename).split('.')[0] if other2ppsocr: ocr = PaddleOCR( use_angle_cls=True, lang='latin',use_space_char=True,use_debug=False, use_gpu=True if cv2.cuda.getCudaEnabledDeviceCount() > 0 else False) ocr2 = ocr2 = PaddleOCR(use_angle_cls=True,use_gpu =False,use_debug=False, rec_algorithm='SVTR_LCNet', rec_model_dir='/nfs_home/bowen/.paddleocr/whl/rec/en/en_PP-OCRv4_rec_infer', lang="en") 
outcsv_filename=f"{output_directory}/{prefix_f}_withOCR.csv" if visual_check: output_directory = os.path.dirname(outcsv_filename) prefix_f = os.path.basename(outcsv_filename).split('.')[0] ima_checkdir=f"{output_directory}/{prefix_f}_Boxed" os.makedirs(ima_checkdir, exist_ok=True) if getacc: acc_summary=f"{outcsv_filename}.I2Msummary.txt" flogout = open(f'{acc_summary}' , 'w') failed=[] mydiff=[] simRD=0 sim=0 mysum=0 model.eval() criterion.eval() metric_logger = MetricLogger(delimiter=" ") header = 'Infering:' res_smiles = [] idx_to_labels23={0:'other',1:'C',2:'O',3:'N',4:'Cl',5:'Br',6:'S',7:'F',8:'B', 9:'I',10:'P',11:'*',12:'Si',13:'NONE',14:'BEGINWEDGE',15:'BEGINDASH', 16:'=',17:'#',18:'-4',19:'-2',20:'-1',21:'1',22:'2',} idx_to_labels30 = {0:'other',1:'C',2:'O',3:'N',4:'Cl',5:'Br',6:'S',7:'F',8:'B', 9:'I',10:'P',11:'H',12:'Si',13:'NONE',14:'BEGINWEDGE',15:'BEGINDASH', 16:'=',17:'#',18:'-4',19:'-2',20:'-1',21:'1',22:'2', 23:'CF3',#NOTE rdkit get element not supporting group 24:'CN', 25:'Me', 26:'CO2Et', 27:'R', 28:'Ph', 29:'*', } bond_labels = [13,14,15,16,17] if postprocessors.num_classes==23: # print(data["categories"]) print(f'usage idx_to_labels23',idx_to_labels23) idx_to_labels = idx_to_labels23 elif postprocessors.num_classes==30: # print(data["categories"])#NOTE 11 is H not * now print(f'usage idx_to_labels30',idx_to_labels30) idx_to_labels = idx_to_labels30 else: print(f"error unkown ways@@@@@@@@@@@!!!!!!!!!!idx_to_labels::{len(idx_to_labels)}\n{idx_to_labels}") abrevie={"[23*]":'CF3', "[24*]":'CN', "[25*]":'Me', "[26*]":'CO2Et', "[27*]":'R', "[28*]":'Ph', "[29*]":'3~7UP', } # idx_to_labels = idx_to_labels23 lab2idx={ v:k for k,v in idx_to_labels.items() } smiles_data = pd.DataFrame({'file_name': [], 'SMILESori':[], 'SMILESpre':[], 'SMILESexp':[], } ) output_dict = {} output_ori={} filtered_output_dict = {} box_thresh=0.1 # for samples, targets in metric_logger.log_every(data_loader, 10, header): # samples = samples.to(device) # # targets = [{k: 
v.to(device) for k, v in t.items()} for t in targets] # outputs = model(samples) # # orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)#.to(device) # orig_target_sizes = targets["orig_size"].to(device) # results = postprocessors(outputs, orig_target_sizes)#RTDETRPostProcessor@@src/zoo/rtertr # for i_, z in enumerate(zip(targets['image_id'], results)): # ti, output=z # output_dict[ti.item()] = [ # output, # targets['img_path'][i_], # targets['SMILES'][i_], # ] # output_ori[ti.item()] =[ # targets['img_path'][i_], # targets['SMILES'][i_], # ] # print(len(output_ori),len(output_dict)) for samples, targets in metric_logger.log_every(data_loader, 10, header): # orig_target_sizes = targets["orig_size"].to(device) for i_, ti in enumerate(targets['image_id']): output_dict[ti.item()] = [ targets['img_path'][i_], targets['SMILES'][i_], ] for key, value in output_dict.items(): image_path = value[0] SMILESori = value[1] # selected_indices = value['scores'] > 0.5#may be >=0.5 cut off, as used the sigmoid? 
# selected_indices = value[0]['scores'] > box_thresh # true_count = selected_indices.sum().item() #testing here image_path='/cadd_data/samba_share/from_docker/data/work_space/ori/real/acs/ol020229e-Scheme-c3-10.png' tensor,w,h = image_to_tensor(image_path) tensor=tensor.unsqueeze(0).to(device) print(tensor.size()) # Output tensor shape (C x H x W) ori_size=torch.Tensor([w,h]).long().unsqueeze(0).to(device) outputs = model(tensor) result_ = postprocessor2(outputs, ori_size) # result_ = postprocessors(outputs, ori_size) score_=result_[0]['scores'] boxe_=result_[0]['boxes'] label_=result_[0]['labels'] #---------------------------------################################ selected_indices =score_ > box_thresh true_count = selected_indices.sum().item() output={ 'labels': label_[selected_indices].to("cpu").numpy(), 'boxes': boxe_[selected_indices].to("cpu").numpy(), 'scores': score_[selected_indices].to("cpu").numpy() } img_ori = Image.open(image_path).convert('RGB') w_ori, h_ori = img_ori.size # 获取原始图像的尺寸 print(w_ori, h_ori, "orignianl vs 1000,1000") print(f"selected_indices 中 True 的数量: {true_count}") print(f"before nms_per_class, :: box 的数量:{len(output['labels'])}") output = nms_per_class(output['labels'], output['boxes'], output['scores'], iou_thresh=0.5) print(f"after nms_per_class, :: box 的数量:{len(output['labels'])}") # filtered_output_dict={image_path: output} x_center = (output["boxes"][:, 0] + output["boxes"][:, 2]) / 2 y_center = (output["boxes"][:, 1] + output["boxes"][:, 3]) / 2 # center_coords = torch.stack((x_center, y_center), dim=1) center_coords = np.stack((x_center, y_center), axis=1) # center_coords=np.stack((x_center, y_center)).reshape(-1,2)#NOTE not do this, mix element order shits #TODO split atom_charge \ bond drawing output = {'bbox': output["boxes"],#.to("cpu").numpy(), 'bbox_centers': center_coords,#.to("cpu").numpy(), 'scores': output["scores"],#.to("cpu").numpy(), 'pred_classes': output["labels"],#.to("cpu").numpy() } 
############################################################################################################################ img_ori = Image.open(image_path).convert('RGB') w_ori, h_ori = img_ori.size # 获取原始图像的尺寸 print(w_ori, h_ori, "orignianl vs 1000,1000") # 计算缩放比例 scale_x = 1000 / w_ori scale_y = 1000 / h_ori img_ori_1k = img_ori.resize((1000,1000)) img = Image.open(image_path).convert('RGB') img = img.resize((1000,1000)) # atom_bondBox_check=True print(f"from oupt socore > {box_thresh},get box {len(output['bbox'])} after nms_per_class ") # split into atom bond charge nms, then mergd , then box2 mol NOTE charege and bond confidence at least >10% charge_mask = np.array([True if ins in charge_labels and output['scores'][i]>0.1 else False for i, ins in enumerate(output['pred_classes'])]) charges_bbox=output['bbox'][charge_mask] charges_centers= output['bbox_centers'][charge_mask] charges_classes= output['pred_classes'][charge_mask] charges_scores= output['scores'][charge_mask] charges_bbox, charges_centers, charges_scores,charges_classes,figc =view_box_center2(charges_bbox, charges_centers, charges_scores,charges_classes, overlap_dist_thresh=5.0, max_centers_per_box=5) #view_box_center2 help remove large box if boxscore small than 0.5 # bonds_mask2 = np.array([True if ins in bond_labels else False for ins in output['pred_classes']]) # bonds_mask= output['scores'][bonds_mask2]>=0.1# TODO fix me, as training bond box overlap with bondbox,aussme bond socre make sense bonds_mask = np.array([True if ins in bond_labels and output['scores'][i]>0.2 else False for i, ins in enumerate(output['pred_classes'])]) bond_bbox=output['bbox'][bonds_mask] bond_centers= output['bbox_centers'][bonds_mask] bond_classes= output['pred_classes'][bonds_mask] bond_scores= output['scores'][bonds_mask] # bond_bbox2, bond_centers2, bond_scores2,bond_classes2,fig=view_box_center2(bond_bbox, bond_centers, bond_scores,bond_classes, overlap_dist_thresh=5.0, max_centers_per_box=5) bond_bbox, 
bond_centers, bond_scores,bond_classes,fig =view_box_center2(bond_bbox, bond_centers, bond_scores,bond_classes, overlap_dist_thresh=5.0, max_centers_per_box=3) bond_bbox, bond_classes, bond_scores = nms(bond_bbox, bond_scores,bond_classes, iou_threshold=0.5) heavy_mask= np.array([True if ins not in bond_labels and ins not in charge_labels and ins != lab2idx['H'] else False for ins in output['pred_classes']]) h_mask= np.array([True if ins not in bond_labels and ins not in charge_labels and ins == lab2idx['H'] else False for ins in output['pred_classes']]) #TODO fix me if heavy or H all need this view_box_center2 filtering heavy_bbox = output['bbox'][heavy_mask] heavy_classes = output['pred_classes'][heavy_mask] heavy_centers= output['bbox_centers'][heavy_mask] heavy_scores= output['scores'][heavy_mask] heavy_bbox, heavy_centers, heavy_scores,heavy_classes,fighv =view_box_center2(heavy_bbox, heavy_centers, heavy_scores,heavy_classes, overlap_dist_thresh=5.0, max_centers_per_box=5) #TODO del isolated C without bond box overlap delt_hei=[] for hei,hebox in enumerate(heavy_bbox): he_class=idx_to_labels[heavy_classes[hei]] b_nei=[] if he_class in ['C']:#TODO add other cases for bi,bb in enumerate(bond_bbox): overlap_flag=boxes_overlap(bb, hebox)#TODO use tghe atom bond box overlap get bond atom mapping,then built mol if overlap_flag: b_nei.append(bi) if len(b_nei)==0: delt_hei.append(hei) n = len(heavy_scores) # 更新数量 keep_boxes = np.ones(n, dtype=bool) keep_boxes[delt_hei]=False heavy_bbox, heavy_centers, heavy_scores,heavy_classes=heavy_bbox[keep_boxes], heavy_centers[keep_boxes], heavy_scores[keep_boxes],heavy_classes[keep_boxes] h_bbox = output['bbox'][h_mask] h_centers= output['bbox_centers'][h_mask] h_classes= output['pred_classes'][h_mask] h_scores= output['scores'][h_mask] h_bbox, h_centers, h_scores,h_classes,figh =view_box_center2(h_bbox, h_centers, h_scores,h_classes, overlap_dist_thresh=5.0, max_centers_per_box=5) #NOTE need keep the order heavy atom first 
then following with Hs # atoms_mask = np.array([True if ins not in bond_labels and ins not in charge_labels else False for ins in output['pred_classes']]) # atom_bbox=output['bbox'][atoms_mask] # atom_classes=output['pred_classes'][atoms_mask] # 合并 bbox,保持重原子在前,氢原子在后 atom_bbox = np.concatenate([heavy_bbox, h_bbox], axis=0) atom_classes = np.concatenate([heavy_classes, h_classes], axis=0) # atom_centers = np.concatenate([heavy_centers, h_centers], axis=0) atom_scores = np.concatenate([heavy_scores, h_scores], axis=0) #TODO nms checking # kept_bboxes, kept_classes, kept_scores=nms(atom_bbox, atom_scores, atom_classes, iou_threshold=0.5) # # kept_bboxes, kept_classes, kept_scores=nms_atomBox(atom_bbox, atom_scores, atom_classes, iou_threshold=0.5) # merged_bboxes, merged_classes, merged_scores = merge_low_iou_boxes(kept_bboxes, kept_classes, kept_scores, merge_threshold=0.5, score_threshold=0.7) # print(f'ater nms kept_box {len(kept_bboxes)}, followd merge_low_iou_boxes kept_box:: {len(merged_bboxes)}') # atom_bbox, atom_classes, atom_scores=merged_bboxes, merged_classes, merged_scores atom_bbox, atom_scores, atom_classes = refine_boxes(atom_bbox, atom_scores, atom_classes, bond_bbox) x_center = (atom_bbox[:, 0] + atom_bbox[:, 2]) / 2 y_center = (atom_bbox[:, 1] + atom_bbox[:, 3]) / 2 # center_coords = torch.stack((x_center, y_center), dim=1) center_coords = np.stack((x_center, y_center), axis=1) atom_centers=center_coords print(f"before NMS :: heavy box {len(heavy_bbox)} ---- H box {len(h_bbox)}---bond box{len(bond_bbox)}") print(f"after NMS+view_box_center2 :: atom box {len(atom_bbox)} bond box {len(bond_bbox)} charge box {len(charges_bbox)} ") # print(f"bond box with only single atom box overlap:: {single_odd_bi}") print(f"atom box afte NMS and merge_low_iou_boxes") print(f"get box {len(output['bbox'])} with NMS") print(f"atom score >0.1 bond score >0.2, then folllowed with NMS") print(f"bond_bbox nums::",bond_bbox.shape,len(bond_bbox)) print(f" OCR will start 
involved ")# #check if ODD single-bonds with only one atom exisits, try add the atoms box for this bond new_atoms, single_odd_b2a= Newbox_(atom_bbox,bond_bbox, lab2idx ) print(f"new_atoms number {len(new_atoms)}\n{new_atoms}") if len(new_atoms)>0: for boxout in new_atoms: for k,arr in boxout.items(): value_or_row=output[k] if arr.ndim == 1: output[k]=np.append(value_or_row, arr) elif arr.ndim >= 2: output[k] = np.concatenate([value_or_row, arr], axis=0) else: print('errprs, unkown conditions !!!@') #NOTE try to use OCR to help postprocess box adding and del # 加载图像 OCR image = cv2.imread(image_path) image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # 预处理图像突出下标 gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) _, thresh = cv2.threshold(gray, 150, 255, cv2.THRESH_BINARY_INV) # print(_, thresh) kernel = np.ones((2, 2), np.uint8) dilated = cv2.dilate(thresh, kernel, iterations=1) # cv2.imwrite("preprocessed.jpg", dilated)#NOTE comment if need checking # result = ocr.ocr("preprocessed.jpg", cls=True) # ocr.ocr(image_npocr, cls=True, det=False) result = ocr.ocr(dilated, cls=True) # 直接传递 NumPy 数组 # 解析结果 text_boxes = [] text_contents = [] confidences = [] for line in result: print(line) if line: for box_info in line: box = box_info[0] x_coords = [point[0] for point in box] y_coords = [point[1] for point in box] text_box = [min(x_coords), min(y_coords), max(x_coords), max(y_coords)] text = box_info[1][0] text_boxes.append(text_box) text_contents.append(text) confidences.append(box_info[1][1]) print("Detected text boxes:", text_boxes) print("Detected text contents:", text_contents) print("Confidences:", confidences) #after whole img OCRed # Initialize dictionaries and lists ai2text = {} ai2relplace = {} ai2rdkitlab_unknown = {} non_overlapping_texts = [] # Build initial KDTree tree = cKDTree(atom_centers) # Collect indices to delete after the loop to keep tree valid during processing indices_to_delete = set() # Process each OCR text box for ti, text_box in 
enumerate(text_boxes): text_center = calculate_center(text_box) ocr_text = text_contents[ti] # Normalize OCR text if ocr_text in ['OH', 'HO']: ocr_text = 'O' elif ocr_text in ['SH', 'HS']: ocr_text = 'S' elif ocr_text in ['NH', 'HN']: ocr_text = 'N' elif ocr_text in ['CH', 'HC']: ocr_text = 'C' elif ocr_text == '0': ocr_text = 'O' elif ocr_text == 'L': ocr_text = 'Li' elif ocr_text[-1]=='-': if ocr_text[:-1] in ABBREVIATIONS: ocr_text=ocr_text[:-1] # Find all overlapping atom boxes overlapping_indices = [] for idx in range(len(atom_bbox)): if idx not in indices_to_delete and boxes_overlap(atom_bbox[idx], text_box): overlapping_indices.append(idx) if overlapping_indices: # If there are overlapping atom boxes, merge them if len(overlapping_indices) > 1: # Get the smallest box encompassing all overlapping atom boxes overlapping_boxes = [atom_bbox[idx] for idx in overlapping_indices] merged_box = get_merged_box(overlapping_boxes) overlapping_indices_atomboxclass=[idx_to_labels[atom_classes[i]] for i in overlapping_indices] print(f"Merging {len(overlapping_indices)} atom boxes overlapping with OCR text: {ocr_text}") print(f" {overlapping_indices} boxes type{overlapping_indices_atomboxclass} merged as OCR text: {ocr_text}") merged_area = box_area(merged_box) text_area = box_area(text_box) final_box = merged_box if merged_area >= text_area else text_box else: # If only one overlap, use the text box directly final_box = text_box # Use the OCR text box as the merged box primary_idx = overlapping_indices[0] # atom_bbox[primary_idx] = text_box # Update the primary atom box atom_bbox[primary_idx] = final_box # Update class and dictionaries based on OCR text if ocr_text in ABBREVIATIONS: ai2relplace[primary_idx] = ocr_text atom_classes[primary_idx] = 0 if ocr_text in lab2idx: atom_classes[primary_idx] = lab2idx[ocr_text] elif ocr_text in ['H', 'C', 'O', 'N', 'Cl', 'Br', 'S', 'F', 'B', 'I', 'P', 'Si']: atom_classes[primary_idx] = lab2idx[ocr_text] elif ocr_text in RGROUP_SYMBOLS 
or (ocr_text[0] == 'R' and ocr_text[1:].isdigit()): atom_classes[primary_idx] = 0 else: ai2rdkitlab_unknown[primary_idx] = ocr_text atom_classes[primary_idx] = 0 ai2text[primary_idx] = ocr_text # Mark redundant indices for deletion indices_to_delete.update(overlapping_indices[1:]) else: # No overlap: record the text box and nearest atom index distance, nearest_idx = tree.query(text_center) if nearest_idx not in indices_to_delete: # Only record if nearest_idx is still valid print(f"No overlap for OCR text '{ocr_text}', nearest atom box index: {nearest_idx}") non_overlapping_texts.append({ 'text': ocr_text, 'text_box': text_box, 'nearest_atom_idx': nearest_idx, 'distance': distance }) #set up atom_ocr match atom_class atom_ocr=[] for i,ai in enumerate(atom_classes): if i in ai2text: atom_ocr.append(ai2text[i]) # elif i in ai2rdkitlab_unknown: # atom_ocr.append(ai2rdkitlab_unknown[i]) else: atom_ocr.append(idx_to_labels[ai]) print(f"atom class + ocr presented as symbols::\n{atom_ocr}") atom_ocr=np.array(atom_ocr) # Perform deletions after the loop if indices_to_delete: indices_to_keep = np.setdiff1d(np.arange(len(atom_bbox)), list(indices_to_delete)) atom_bbox = atom_bbox[indices_to_keep] atom_classes = atom_classes[indices_to_keep] atom_centers = atom_centers[indices_to_keep] atom_scores = atom_scores[indices_to_keep] atom_ocr= atom_ocr[indices_to_keep] # Adjust dictionary indices for d in [ai2text, ai2relplace, ai2rdkitlab_unknown]: d_new = {} for old_idx, value in d.items(): new_idx = np.where(indices_to_keep == old_idx)[0][0] if old_idx in indices_to_keep else None if new_idx is not None: d_new[new_idx] = value d.clear() d.update(d_new) # Adjust nearest_atom_idx in non_overlapping_texts for entry in non_overlapping_texts: old_idx = entry['nearest_atom_idx'] if old_idx in indices_to_keep: entry['nearest_atom_idx'] = np.where(indices_to_keep == old_idx)[0][0] else: entry['nearest_atom_idx'] = -1 # Mark as invalid if the nearest atom was deleted # Rebuild KDTree if 
needed for further use tree = cKDTree(atom_centers) # Final output print("Whole img with OCR :: ai2relplace, ai2rdkitlab_unknown:", [ai2relplace, ai2rdkitlab_unknown]) print(f"Adjusted ai ocr_text: {ai2text}") print(f"Atom box num: {len(atom_bbox)}:: {[idx_to_labels[i] for i in atom_classes]}") print("Non-overlapping OCR text boxes:", non_overlapping_texts) #for all heavy atom labels, consider N3 pred as N, or other cases, I2M not good as paddle on ABC atomcorp_img = Image.open(image_path).convert('RGB') atomcorp_img1k=atomcorp_img.resize([1000,1000]) text_contents_star=[] text_confidences_star=[] text_boxes_star=[] boxid2del=dict() ocr_discrepancies = {} # New dictionary to record OCR vs. AI mismatches print(f"has atom_bbox number {len(atom_bbox)}") for i,box in enumerate(atom_bbox):#split ocr image # if i in ai2text: continue #may be need comment this, if splited OCR acc better!! abox =box* [scale_x, scale_y, scale_x, scale_y] cropped_img=atomcorp_img1k.crop(abox)#if use the small ori image will not get infos image_npocr = np.array(cropped_img) result_ocr= ocr2.ocr(image_npocr, det=False)#,cls=True,use_debug=False, det=False)#det fale not box but get rcongized more # result_ocr= ocr.ocr(image_npocr, cls=True, det=False)#,cls=True, det=False)#det fale not box but get rcongized more if result_ocr: for line in result_ocr: # print(f"Atom box--- {i}, OCR result---: {line}") if line: box_flag=has_boxes(line) for box_info in line: # print(len(box_info)) if not box_flag: text=box_info[0] #[^a-zA-Z0-9\*\-\+] 表示匹配除了字母、数字、*、- 和 + 之外的所有字符。 text=re.sub(r'[^a-zA-Z0-9,\*\-\+]', '', text)#remove special chars score_=box_info[1] text_contents_star.append(text) text_confidences_star.append(score_) else:#when paddleOCRuse detection model get text box info box = box_info[0] x_coords = [point[0] for point in box] y_coords = [point[1] for point in box] text_box = [min(x_coords), min(y_coords), max(x_coords), max(y_coords)] text = box_info[1][0] text=re.sub(r'[^a-zA-Z0-9,\*\-\+]', '', 
text)#remove special chars text_boxes_star.append(text_box) text_contents_star.append(text) score_=box_info[1][1] text_confidences_star.append(score_) if i in ai2text:#ocr 全img vs split img # print(f'from whole img ocr atom box {i}----from whole img::{ai2text[i]}') if ai2text[i] != text: text=ai2text[i] if len(ai2text[i])>=len(text) else text print(f"Atom box {i}@@ OCR text: {text}, score: {score_}, AI class: {idx_to_labels[atom_classes[i]]}, AI score: {atom_scores[i]}") # Normalize OCR text if text in ['OH', 'HO']: text = 'O' elif text in ['SH', 'HS']: text = 'S' elif text in ['NH', 'HN']: text = 'N' elif text in ['CH', 'HC']: text = 'C' elif text == '0': text = 'O' elif text == 'L': text = 'Li' elif '-' in text: if text[:-1] in ABBREVIATIONS: text=text[:-1] # Check if OCR text is a single character and not a valid element is_single_char = len(text) == 1 ai_pred = idx_to_labels[atom_classes[i]] #TOD add more simpfiled if text=='0': atom_classes[i]=lab2idx['O'] elif text in ['H', 'C', 'O', 'N', 'Cl', 'Br', 'S', 'F', 'B', 'I', 'P', 'Si']: atom_classes[i]=lab2idx[text]#need update to keep H following Heavy # elif # ocr recongnized on lable C as other things chars elif is_single_char and text not in ELEMENTS and ai_pred == 'C': # Do not replace AI prediction, just record discrepancy ocr_discrepancies[i] = { 'ocr_text': text, 'ocr_score': score_, 'ai_class': ai_pred, 'ai_score': atom_scores[i] } else: overlap, b_nei=AtomBox2bondBox(atom_bbox[i],bond_bbox) if not overlap: if text not in ELEMENTS and text not in ABBREVIATIONS: # print(f"new cases::{text} for atombox {i} {atom_bbox[i]}check how to fix it !!!") # print(f'OCR text:: {text} score ::{box_info}||atom clss::{idx_to_labels[atom_classes[i]]} {atom_scores[i]}') if text != idx_to_labels[atom_classes[i]]: boxid2del[i]= [text,idx_to_labels[atom_classes[i]]]#will delt this atom box infos else: if text != idx_to_labels[atom_classes[i]]: if atom_scores[i]<=score_: if text in RGROUP_SYMBOLS or text in ABBREVIATIONS: 
ai2relplace[i]=text atom_classes[i]=0 if text in lab2idx and lab2idx[text] in list(range(23,29)):atom_classes[i]=lab2idx[text] elif text in ['H', 'C', 'O', 'N', 'Cl', 'Br', 'S', 'F', 'B', 'I', 'P', 'Si']: atom_classes[i]=lab2idx[text] else: ai2relplace[i]=text atom_classes[i]=0 # 按照 value 的第一个元素(假设是字符串)的长度进行排序,长度大的排前 boxid2del = dict(sorted(boxid2del.items(), key=lambda item: item[0], reverse=True)) print(f"considering del box",boxid2del) print("after split img OCR:: ai2relplace,ai2rdkitlab_unknown",[ai2relplace,ai2rdkitlab_unknown]) print(f"considering delet atomb box :{boxid2del}") syms=[] for i in range(len(atom_classes)): if i in ai2relplace: syms.append(ai2relplace[i]) elif i in ai2rdkitlab_unknown:syms.append(ai2rdkitlab_unknown[i]) else: syms.append(idx_to_labels[atom_classes[i]]) print(f"atombox {atom_classes}:: number {len(atom_classes)}\n",[idx_to_labels[i] for i in atom_classes]) print(f" {syms}") #chedck isolated box, if need add bond box between the isolated box or not isolated_ais = [] # 第一步:构建 bond 到 atom 的映射,并计算 distance_threshold bond_distances = [] singleAtomBond=dict() for bi, bb in enumerate(bond_bbox): overlapped_atoms = [] overlapped_abox = [] for ai, aa in enumerate(atom_bbox): overlap_flag = boxes_overlap(bb, aa) if overlap_flag: overlapped_atoms.append(ai) overlapped_abox.append(aa) # if bi not in b2a.keys(): # b2a[bi] = [ai] # else: # b2a[bi].append(ai) if len(overlapped_atoms) == 2: center1 = calculate_center(atom_bbox[overlapped_atoms[0]]) center2 = calculate_center(atom_bbox[overlapped_atoms[1]]) distance = np.linalg.norm(center1 - center2) bond_distances.append(distance) # print(f"Bond {bi} connects atoms {overlapped_atoms}, distance: {distance:.2f}") elif len(overlapped_atoms) == 1: print(f"single bond - atom still exists for bond {bi}, need porcess this !!") if bi not in singleAtomBond: singleAtomBond[bi]=overlapped_atoms#considering use the add H box for solve TODO # 动态计算 distance_threshold distance_threshold = max(bond_distances) 
if bond_distances else 100.0 # 默认值 10 如果无 bond distance_threshold_min = min(bond_distances) if bond_distances else 100.0 # 默认值 10 如果无 bond print(f"Calculated distance_threshold center based: {distance_threshold:.2f}") # 第二步:构建 atom 到 bond 的映射,并检测孤立原子 a2b=dict() for ai, aa in enumerate(atom_bbox): b_nei = [] for bi, bb in enumerate(bond_bbox): overlap_flag = boxes_overlap(bb, aa) if overlap_flag: b_nei.append(bi) a2b[ai] = b_nei if a2b[ai] ==[]: if ai not in isolated_ais: isolated_ais.append(ai) isolated_ais=sorted(isolated_ais,reverse=True)#avoid delte atom with index errors print(f"isolated_ais atom box {isolated_ais}\n ", [idx_to_labels[i] for i in atom_classes[isolated_ais]]) # 第三步:处理孤立原子,尝试合并或删除 updated_atom_bbox = atom_bbox.copy() updated_atom_classes = atom_classes.copy() updated_atom_scores = atom_scores.copy() print(f"atom bbox num {len(atom_bbox)}")#ttt new_bond_bbox=[] deleted_ais=[] del4boxid2del=set() for isolated_ai in isolated_ais: isolated_box = atom_bbox[isolated_ai] isolated_center = calculate_center(isolated_box) nearest_distance = float('inf') nearest_ai = -1 # 找到最近的非孤立原子 for ai, aa in enumerate(atom_bbox): if ai not in isolated_ais and ai != isolated_ai: center = calculate_center(aa) distance = np.linalg.norm(isolated_center - center) if distance < nearest_distance: nearest_distance = distance nearest_ai = ai # 合并或删除逻辑 if nearest_ai != -1: if nearest_distance<=distance_threshold_min or (nearest_distance <=distance_threshold and nearest_distance>=distance_threshold_min):#this the centers dist not bond length nearest_box = atom_bbox[nearest_ai] nearest_class = atom_classes[nearest_ai] nearest_center = calculate_center(nearest_box) if isolated_ai in boxid2del: textocr2del=boxid2del[isolated_ai][0] else: textocr2del=None #NOTE based ont the class and ovelap bond box to adjust overlap1,bondnei=AtomBox2bondBox(nearest_box,bond_bbox) if len(bondnei)==1:#could be add two other bond, add bond box # if textocr2del in [',', '+', '-'] or not any(c.isupper() 
for c in textocr2del): if textocr2del is not None and not any(c.isupper() for c in textocr2del): # del4boxid2del.add(isolated_ai) deleted_ais.append(isolated_ai) pass else: new_bc = (isolated_center + nearest_center)*0.5 new_bondbox=np.array([new_bc[0] - nearest_distance*0.5, new_bc[1] - nearest_distance*0.5, new_bc[0] + nearest_distance*0.5, new_bc[1] + nearest_distance*0.5] ) new_bond_bbox.append(new_bondbox.reshape(-1,4)) print(f'add a new bond box new_bc for two atom boxes {isolated_ai} ---- {nearest_ai}::\n {idx_to_labels[atom_classes[isolated_ai]]} --- {idx_to_labels[atom_classes[nearest_ai]]}') else:#TODO fix me when get the case with >=2 bonds need add bond also try: new_box = merge_boxes(isolated_box, nearest_box) updated_atom_bbox[nearest_ai] = new_box chosed_score_ = max(atom_scores[isolated_ai], atom_scores[nearest_ai]) updated_atom_scores[nearest_ai] = chosed_score_ except Exception as e: print(f"file_name@: {image_path}\n SMILES in csv:\n{SMILESori}") print(e) print('nearest_ai ', nearest_ai) check2=True if check2: padding=5 # box_thresh=0.3 atombox_img=draw_objs(copy.deepcopy(img), atom_bbox* [scale_x, scale_y, scale_x, scale_y], atom_classes, atom_scores , category_index=idx_to_labels, box_thresh=box_thresh, line_thickness=3, font='arial.ttf', font_size=10) bonbox_img=draw_objs(copy.deepcopy(img), bond_bbox* [scale_x, scale_y, scale_x, scale_y], bond_classes, bond_scores , category_index=idx_to_labels, box_thresh=0.01, line_thickness=3, font='arial.ttf', font_size=10) # Get sizes of the individual images atom_width, atom_height = atombox_img.size bon_width, bon_height = bonbox_img.size combined_width = atom_width + bon_width + padding * 3 combined_height = max(atom_height, bon_height) + padding * 2 combined_img = Image.new('RGB', (combined_width, combined_height), color=(255, 255, 255)) # White background # Paste the images onto the new canvas combined_img.paste(atombox_img, (padding, padding)) # Top-left combined_img.paste(bonbox_img, (atom_width + 
padding * 2, padding)) print(f"atom box afte NMS and merge_low_iou_boxes") combined_img.save(f"tttttttttttttttttttttttBoxed.png" ) raise Exception("@debug this!!\n") if chosed_score_>=0.5: if chosed_score_==atom_scores[isolated_ai]: updated_atom_classes[nearest_ai] = 0 # mrege replaced with * # else: # updated_atom_classes[nearest_ai] = atom_classes[nearest_ai] # 保留较高 score 的类别 updated_atom_bbox = np.delete(updated_atom_bbox, isolated_ai, axis=0)#after mreged need del it # updated_atom_bbox = np.delete(updated_atom_bbox, isolated_ai, axis=0) updated_atom_classes = np.delete(updated_atom_classes, isolated_ai) updated_atom_scores = np.delete(updated_atom_scores, isolated_ai) print(f"Merged atom {isolated_ai} into {nearest_ai}, new box: {new_box}") isolated_ais.remove(isolated_ai) deleted_ais.append(isolated_ai) # elif nearest_distance<=distance_threshold_min:#very close,mrege with nearest one elif atom_scores[isolated_ai] < 0.5: # 删除低分孤立原子 updated_atom_bbox = np.delete(updated_atom_bbox, isolated_ai, axis=0) updated_atom_classes = np.delete(updated_atom_classes, isolated_ai) updated_atom_scores = np.delete(updated_atom_scores, isolated_ai) print(f"DELET isolated atom {isolated_ai} with score {atom_scores[isolated_ai]}") deleted_ais.append(isolated_ai) # 更新索引,因为数组维度变化 isolated_ais = [i if i < isolated_ai else i - 1 for i in isolated_ais if i != isolated_ai] else: print(f"KEEP isolated atom {isolated_ai} with score {atom_scores[isolated_ai]} >= 0.5") else: if atom_scores[isolated_ai] < 0.5: updated_atom_bbox = np.delete(updated_atom_bbox, isolated_ai, axis=0) updated_atom_classes = np.delete(updated_atom_classes, isolated_ai) updated_atom_scores = np.delete(updated_atom_scores, isolated_ai) print(f"DELET isolated atom {isolated_ai} with score {atom_scores[isolated_ai]}") deleted_ais.append(isolated_ai) isolated_ais = [i if i < isolated_ai else i - 1 for i in isolated_ais if i != isolated_ai] else: print(f"KEEP isolated atom {isolated_ai} with score 
{atom_scores[isolated_ai]} >= 0.5") if len(new_bond_bbox)>0: for i,bond_box in enumerate(new_bond_bbox): bond_bbox= np.concatenate([bond_bbox,bond_box],axis=0) bond_scores= np.concatenate((bond_scores,np.array([0.9])),axis=0) bond_classes= np.concatenate([bond_classes,np.array([13])],axis=0) #reset bond center x_center = (bond_bbox[:, 0] + bond_bbox[:, 2]) / 2 y_center = (bond_bbox[:, 1] + bond_bbox[:, 3]) / 2 # center_coords = torch.stack((x_center, y_center), dim=1) center_coords = np.stack((x_center, y_center), axis=1) bond_centers=center_coords #del the additional atom box that not connected by bond box also mismatch other rules if len(deleted_ais) > 0: # 如果有需要删除的索引 print(f"will delete atom box with idx :: {deleted_ais}") # 使用 np.delete 一次性删除所有指定的行 atom_classes = np.delete(atom_classes, deleted_ais, axis=0) atom_scores = np.delete(atom_scores, deleted_ais, axis=0) atom_bbox = np.delete(atom_bbox, deleted_ais, axis=0) atom_ocr = np.delete(atom_ocr, deleted_ais, axis=0) # eles=[idx_to_labels[i] for i in atom_classes] # print(eles,len(eles)) cur_atomSymbols=[idx_to_labels[i] for i in atom_classes] ocr_wholeImg=[] for i in atom_classes: if i in ai2relplace: ocr_wholeImg.append(ai2relplace[i]) elif i in ai2rdkitlab_unknown: ocr_wholeImg.append(ai2rdkitlab_unknown[i]) else: ocr_wholeImg.append(idx_to_labels[i]) print("ai2relplace,ai2rdkitlab_unknown",ai2relplace,ai2rdkitlab_unknown) print("cur_atomSymbols:",cur_atomSymbols) print(" atomSymbolsOCR:",ocr_wholeImg) # 找到 'H' 的索引, H after Heavy h_indices = np.where(atom_classes == lab2idx['H'])[0] non_h_indices = np.where(atom_classes != lab2idx['H'])[0] # print(h_indices,non_h_indices) # 重新排序 new_order = np.concatenate((non_h_indices, h_indices)).astype(np.int64) # newid2old_Hafter={ i:j for i,j in enumerate(new_order)} # old2newid_Hafter={ j:i for i,j in enumerate(new_order)} atom_classes = atom_classes[new_order] atom_bbox = atom_bbox[new_order] atom_scores = atom_scores[new_order] x_center = (atom_bbox[:, 0] + 
atom_bbox[:, 2]) / 2 y_center = (atom_bbox[:, 1] + atom_bbox[:, 3]) / 2 # center_coords = torch.stack((x_center, y_center), dim=1) center_coords = np.stack((x_center, y_center), axis=1) atom_centers=center_coords#TODO 记得把 abbve idx label same reoder or mapping then bond #bond box reoder like atom box, let the singleAtomBond later bond_bbox = reorder_bond_bbox(bond_bbox, singleAtomBond) bond_classes = reorder_bond_bbox(bond_classes, singleAtomBond) bond_scores = reorder_bond_bbox(bond_scores, singleAtomBond) bond_centers = reorder_bond_bbox(bond_centers, singleAtomBond) # 第二步:构建 atom 到 bond 的映射,并检测孤立原子 a2b=dict() for ai, aa in enumerate(atom_bbox): b_nei = [] for bi, bb in enumerate(bond_bbox): overlap_flag = boxes_overlap(bb, aa) if overlap_flag: b_nei.append(bi) a2b[ai] = b_nei if a2b[ai] ==[]: if ai not in isolated_ais: isolated_ais.append(ai) b2a=dict() for bi,bb in enumerate(bond_bbox): overlapped_atoms = [] overlapped_abox=[] for ai,aa in enumerate(atom_bbox): overlap_flag=boxes_overlap(bb, aa)#TODO use tghe atom bond box overlap get bond atom mapping,then built mol if overlap_flag: # print(bb, aa,overlap_flag) overlapped_atoms.append(ai) overlapped_abox.append(aa) if bi not in b2a.keys(): b2a[bi]=[ai] else: # vais=b2a[bi] b2a[bi].append(ai) if len(overlapped_atoms) == 1: print(f"single bond -atom still exists {overlapped_atoms}") #c2a a2c #charge atom idx maping if len(charges_classes) > 0: # print(charges_bbox,charges_classes,len(charges_classes)) kdt = cKDTree(atom_centers) atid_list=list(range(len(atom_centers))) used_charge_indices=set() c2a=dict() for i, (x,y) in enumerate(charges_centers): overlapped_abox=[] cc=charges_bbox[i] for ai, aa in enumerate(atom_bbox): overlap_flag=boxes_overlap(cc, aa) ac_iou=calculate_iou(cc, aa) charge_=charges_classes[i] charge_score=charges_scores[i] if overlap_flag: if i in c2a: c2a[i].append(ai) else: c2a[i]=[ai] if ai not in atid_list: print(f"Warning: ai {ai} is out of range for atom_list.") continue # 跳过当前循环迭代 # 
idx_to_labels[charges_classes[0]] a2c=dict() for ci,v in c2a.items(): charge_=idx_to_labels[charges_classes[ci]] if len(v)==1: a2c[v[0]]=ci else: for ai in v: ats=idx_to_labels[atom_classes[ai]] if ats=='other': ats='*' if ats in ['F','Cl','I','Br','O'] and int(charge_)<0: a2c[ai]=ci elif ats in ['N','H','P'] and int(charge_)>0: a2c[ai]=ci else: print(f'unusuaal case charge {charge_} with atom {ats}!!') print(f"all a2b b2a a2c c2a done, start mol built") #finsh the update of box back to the output for retraining used output={ 'bbox': np.concatenate([atom_bbox, bond_bbox,charges_bbox], axis=0), 'bbox_centers': np.concatenate([atom_centers, bond_centers,charges_centers],axis=0), 'scores': np.concatenate([atom_scores, bond_scores, charges_scores],axis=0), 'pred_classes': np.concatenate([atom_classes, bond_classes, charges_classes],axis=0), 'image_path': image_path } # boxinfo boxinfor={ 'bbox': output['bbox'], 'scores': output['scores'],#TODO use same vocabl ? 'pred_classes': output['pred_classes'],#[ lab2idx[x] for x in output['pred_classes']],#changet it back to character 'image_path': image_path } #split agin for buit mol charge_mask = np.array([True if ins in charge_labels else False for ins in output['pred_classes']]) charges_bbox=output['bbox'][charge_mask] charges_centers=bbox2center(charges_bbox) # charges_centers= output['bbox_centers'][charge_mask] charges_classes= output['pred_classes'][charge_mask] charges_scores= output['scores'][charge_mask] charges_bbox, charges_centers, charges_scores,charges_classes,figc =view_box_center2(charges_bbox, charges_centers, charges_scores,charges_classes, overlap_dist_thresh=5.0, max_centers_per_box=5) #view_box_center2 help remove large box if boxscore small than 0.5 # bonds_mask2 = np.array([True if ins in bond_labels else False for ins in output['pred_classes']]) # bonds_mask= output['scores'][bonds_mask2]>=0.1# TODO fix me, as training bond box overlap with bondbox,aussme bond socre make sense bonds_mask = 
np.array([True if ins in bond_labels and output['scores'][i]>0.2 else False for i, ins in enumerate(output['pred_classes'])]) bond_bbox=output['bbox'][bonds_mask] bond_centers=bbox2center(bond_bbox) # bond_centers= output['bbox_centers'][bonds_mask] bond_classes= output['pred_classes'][bonds_mask] bond_scores= output['scores'][bonds_mask] print(f"before view_box_center2 bond nums {len(bond_scores)}") # bond_bbox2, bond_centers2, bond_scores2,bond_classes2,fig=view_box_center2(bond_bbox, bond_centers, bond_scores,bond_classes, overlap_dist_thresh=5.0, max_centers_per_box=5) bond_bbox, bond_centers, bond_scores,bond_classes,fig =view_box_center2(bond_bbox, bond_centers, bond_scores,bond_classes, overlap_dist_thresh=5.0, max_centers_per_box=3) print(f"after view_box_center2 bond nums {len(bond_scores)}") heavy_mask= np.array([True if ins not in bond_labels and ins not in charge_labels and ins != lab2idx['H'] else False for ins in output['pred_classes']]) h_mask= np.array([True if ins not in bond_labels and ins not in charge_labels and ins == lab2idx['H'] else False for ins in output['pred_classes']]) #TODO fix me if heavy or H all need this view_box_center2 filtering heavy_bbox = output['bbox'][heavy_mask] # heavy_classes = output['pred_classes'][heavy_mask] heavy_centers=bbox2center(heavy_bbox) # heavy_centers= output['bbox_centers'][heavy_mask] heavy_scores= output['scores'][heavy_mask] heavy_classes = output['pred_classes'][heavy_mask] heavy_bbox, heavy_centers, heavy_scores,heavy_classes,fighv =view_box_center2(heavy_bbox, heavy_centers, heavy_scores,heavy_classes, overlap_dist_thresh=5.0, max_centers_per_box=5) ###########################start build mol ########################## rwmol_ = Chem.RWMol() boxi2ai = {} # 预测索引 -> RDKit 索引 placeholder_atoms=dict() J=0 for i, (bbox, a) in enumerate(zip(atom_bboxes, atom_classes)): a2labl=False a=replace_cg_notation(a) # print(a,'atom box class label') if a in ['H', 'C', 'O', 'N', 'Cl', 'Br', 'S', 'F', 'B', 'I', 'P', 
'Si']:# '*', I2M's defined atom types # if a=='H':continue#skip H fristly,only with heavy atom then addH ad = Chem.Atom(a)#TODO consider non chemical group and label for using #TODO add pd rdkit known elemetns here elif a in ELEMENTS: ad = Chem.Atom(a) elif a in ABBREVIATIONS : ad = Chem.Atom("*") placeholder_atoms[i] = a # 记录非标准原但有定义的官能团 类型及其位置, a2labl=True else: if N_C_H_expand(a): ad = Chem.Atom("*") placeholder_atoms[i] = a # 记录非标准原但有定义的官能团 类型及其位置, a2labl=True elif C_H_expand(a): ad = Chem.Atom("*") placeholder_atoms[i] = a # 记录非标准原但有定义的官能团 类型及其位置, a2labl=True elif C_H_expand2(a): ad = Chem.Atom("*") placeholder_atoms[i] = a # 记录非标准原但有定义的官能团 类型及其位置, a2labl=True elif formula_regex(a): ad = Chem.Atom("*") placeholder_atoms[i] = a # 记录非标准原但有定义的官能团 类型及其位置, a2labl=True else: ad = Chem.Atom("*") if a not in ['*',"other"]: a2labl=True # placeholder_atoms[idx] = a # atom = Chem.Atom(symbol) rwmol_.AddAtom(ad) boxi2ai[J] = rwmol_.GetNumAtoms() - 1 if a2labl: rwmol_.GetAtomWithIdx(J).SetProp("atomLabel", f"{a}")#mol set with label, mol_rebuild not J+=1 # 使用 KDTree 构建重原子间的键(如果提供了 bond_bbox) if len(charges_classes) > 0: for k,v in a2c.items(): fc=int(idx_to_labels[charges_classes[v]]) rwmol_.GetAtomWithIdx(k).SetFormalCharge(fc) # print(f"mol with heavy atoms number {i+1}, max heavy atom id {i}") print(f"mol with atoms number {i+1}, max atom id {i}") print(f"mol with bond box number {len(bond_classes)}") print(f"placeholder_atoms@@ {placeholder_atoms}") #重原子 skeleton mol bonds=dict() existing_bonds = set() b2aa=dict() singleAtomBond=[] bondWithdirct=[] # tree_heavy = KDTree(heavy_centers)#TODO before add bond consdiering reodering bond ?? 
tree_atom = KDTree(atom_centers)#TODO as atom bond are all reodered to kee H last if len(idx_to_labels)==30: _margin=0#ad this version bond dynamicaly changed for bi, (bbox, idx_) in enumerate(zip(bond_bbox, bond_classes)):#not work for cross-bond, longer bond, as the center of bond may be close to as atoms not it two atoms bond_type = idx_to_labels[idx_] if len(idx_to_labels)==23: if idx_to_labels[bond_type] in ['-','SINGLE', 'NONE', 'ENDUPRIGHT', 'BEGINWEDGE', 'BEGINDASH', 'ENDDOWNRIGHT']: _margin = 5 else: _margin = 8 anchor_positions = (bbox + [_margin, _margin, -_margin, -_margin]).reshape([2, -1]) oposite_anchor_positions = anchor_positions.copy() oposite_anchor_positions[:, 1] = oposite_anchor_positions[:, 1][::-1] # Upper left, lower right, lower left, upper right # x1y1, x2y2, x1y2, x2y1 : dinuogl lines anchor_positions = np.concatenate([anchor_positions, oposite_anchor_positions]) # print(f"anchor_positions {anchor_positions.shape}\n{anchor_positions}") dists, neighbours = tree_atom.query(anchor_positions, k=1) if np.argmin((dists[0] + dists[1], dists[2] + dists[3])) == 0: # visualize setup begin_idx, end_idx = neighbours[:2] else: # visualize setup begin_idx, end_idx = neighbours[2:] atom1_idx = boxi2ai[begin_idx] atom2_idx = boxi2ai[end_idx] if atom1_idx == atom2_idx:#NOTE when bond with only one terminal atom, other side H not used print(f"attempt to add self-bond:{bi} atomIdx1 == atomIdx2 ::{[atom1_idx, atom2_idx]}") print(f"for bond bi {bi} H atom may involbed dists:",dists) print(neighbours) print("anchor_positions",anchor_positions) else: if bond_type in ['-', 'NONE', 'ENDUPRIGHT', 'BEGINWEDGE', 'BEGINDASH', 'ENDDOWNRIGHT']: if bond_type in BONDDIRECT: bonds[bi] = (atom1_idx, atom2_idx, 'SINGLE', bond_type) bondWithdirct.append(bi) else: bonds[bi] = (atom1_idx, atom2_idx, 'SINGLE', None) bond_type=BONDTYPE['SINGLE'] elif bond_type == '=': bonds[bi] = (atom1_idx, atom2_idx, 'DOUBLE', None) bond_type=BONDTYPE['DOUBLE'] elif bond_type == '#': 
bonds[bi] = (atom1_idx, atom2_idx, 'TRIPLE', None) bond_type=BONDTYPE['TRIPLE'] else: print(f'unkown bond type relaced with single@@ {bond_type}') bonds[bi] = (atom1_idx, atom2_idx, 'SINGLE', None) bond_type=BONDTYPE['SINGLE'] # 检查价态 atom1 = rwmol_.GetAtomWithIdx(atom1_idx) atom2 = rwmol_.GetAtomWithIdx(atom2_idx) val1 = sum(b.GetBondTypeAsDouble() for b in atom1.GetBonds()) val2 = sum(b.GetBondTypeAsDouble() for b in atom2.GetBonds()) max_val1 = max(VALENCES[atom1.GetSymbol()]) max_val2 = max(VALENCES[atom2.GetSymbol()]) # bond_order = bond_type.AsDouble() bond_order=BONDTYPE2ORD[bond_type] if val1 + bond_order <= max_val1 and val2 + bond_order <= max_val2: bond1 = rwmol_.GetBondBetweenAtoms(atom1_idx, atom2_idx) bond2 = rwmol_.GetBondBetweenAtoms(atom2_idx, atom1_idx) if bond1 or bond2: # print(f'bond exists for {[atom1_idx, atom2_idx]}') pass # if (atom1_idx, atom2_idx) not in existing_bonds and (atom2_idx, atom1_idx) not in existing_bonds: else: # print(atom1_idx, atom2_idx, bond_type,[ bi, idx_to_labels[idx_] ]) rwmol_.AddBond(atom1_idx, atom2_idx, bond_type) else: print(f"Skipping bond {bi}: Exceeds valence.") existing_bonds.add((atom1_idx, atom2_idx)) b2aa[bi]=sorted([atom1_idx, atom2_idx]) if len(bond_bbox)==1 and len(atom_bbox)==2: ca1='[*:0][C:2]#[C:3][*:1]'#acs phC#CpH rwmol_ = Chem.RWMol() ats= ['*','*','C','C'] for ia in ats: a=Chem.Atom(ia) id_=rwmol_.AddAtom(a) # print(ia,id_) rwmol_.AddBond(2, 3, Chem.BondType.TRIPLE) rwmol_.AddBond(0, 2, Chem.BondType.SINGLE) rwmol_.AddBond(1, 3, Chem.BondType.SINGLE) # Chem.MolFromSmiles(ca1) for i in range(len(atom_classes)): atom_classes[i]=lab2idx['*'] AllChem.Compute2DCoords(rwmol_) else: rwmol_=copy.deepcopy(rwmol_) print(f"placeholder_atoms {placeholder_atoms}") #assign 2D coords mol = rwmol_.GetMol() mol.RemoveAllConformers() conf = Chem.Conformer(mol.GetNumAtoms()) # conf.Set3D(True) # for i, (x, y) in enumerate(heavy_centers): for i, (x, y) in enumerate(atom_centers): x, y=float(x),float(y) 
conf.SetAtomPosition(i, (x, y, 0))#TODO why some time need -y, just display same as ori? mol.AddConformer(conf) # Chem.SanitizeMol(mol) Chem.AssignStereochemistryFrom3D(mol) rwmol_=Chem.RWMol(mol) #as afte H a\lso didthis skeleton_mol=copy.deepcopy(rwmol_) print(skeleton_mol.GetNumBonds()) chiral_centers_aids = Chem.FindMolChiralCenters(mol, includeUnassigned=True) # H realted post-process heavyNumber=len(heavy_centers) print(f'mol with heavy number atoms {heavyNumber}, max id {heavyNumber-1}') onlyHeayMol=copy.deepcopy(rwmol_) chiral_centers = Chem.FindMolChiralCenters( rwmol_, includeUnassigned=True, includeCIP=False, useLegacyImplementation=False) chiral_center_ids = [idx for idx, _ in chiral_centers] Hais=[] Hais_bt=[] Hbd=[] # H_existing_bonds = set() for bi, ais in b2a.items():#from box overlap bt=bond_classes[bi]# in [14,15]#directon bond for ai in ais: if ai>heavyNumber-1: if bt in [14,15]:#directon bond Hais.append(ais)#NOTE ais ai increasing order as two for loop increasing print(f"within H bond box id {bi} bond direction {idx_to_labels[bt]} atoms box id {ais} ") Hais_bt.append(idx_to_labels[bt]) Hbd.append(bi) # print(bonds[bi] ) # add Hbonds with direction H_existing_bonds = set() ha2boxa=dict() for ais, bt in zip(Hais,Hais_bt): idx_2=ais[-1] idx_1=ais[0] hbond=rwmol_.GetBondBetweenAtoms(idx_1,idx_2) if hbond is not None: if idx_1 in chiral_center_ids:#if not in the chiral atom, will not set bond directions hbond.SetBondDir(BOND_DIRS[bt]) else: had = Chem.Atom("H") addHatom_idx = rwmol_.AddAtom(had) ha2boxa[addHatom_idx]=idx_2 # print(idx_2,addHatom_idx)#Note if detected H box will lead idx_2 != addHatom_idx atom= rwmol_.GetAtomWithIdx(idx_1) max_val=max(VALENCES[atom.GetSymbol()]) val = sum(b.GetBondTypeAsDouble() for b in atom.GetBonds()) if (idx_1, addHatom_idx) not in H_existing_bonds and (addHatom_idx, idx_1) not in H_existing_bonds: if val<=max_val-1: # print(f"atom id {idx_1} val {val} max_val {max_val}") print(idx_1, addHatom_idx)#let check bond 
exist or not!! rwmol_.AddBond(idx_1,addHatom_idx, Chem.BondType.SINGLE)#BOND_DIRS[bt] b=rwmol_.GetBondBetweenAtoms(idx_1,addHatom_idx) if idx_1 in chiral_center_ids:#if not in the chiral atom, will not set bond directions b.SetBondDir(BOND_DIRS[bt])#############Note can be done in the following tree H_existing_bonds.add((idx_1,addHatom_idx)) i if len(ha2boxa)>0:#consider Hnow #use box coords assign 2D, remove extra Hs also update box rwmol_.RemoveAllConformers()# conf = Chem.Conformer(rwmol_.GetNumAtoms()) conf.Set3D(True) coords2d=[] for i, (x, y) in enumerate(heavy_centers): position = Point3D(float(x), float(y), 0.) # Create a Point3D object with x, y, and z=0 conf.SetAtomPosition(i, position) coords2d.append([x,y]) for k,v in ha2boxa.items(): x,y=atom_centers[v] position = Point3D(float(x), float(y), 0.) # Create a Point3D object with x, y, and z=0 conf.SetAtomPosition(k, position) coords2d.append([x,y]) rwmol_.AddConformer(conf) additonalH=detect_unconnected_hydrogens(rwmol_) if len(additonalH)>0: rwmol_,rmovedAtomcoords=remove_unconnected_hydrogens2(rwmol_) #NOTE 留给将来WEB开发用will dercease h atom,but the box have not updated TODO fix me this in feature activate learning #update atom box infors if len(rmovedAtomcoords)>0:#update box infors delbb=[] kdt = cKDTree(atom_centers) for i, (x,y,z) in enumerate(rmovedAtomcoords):#z=0 dist, idx_=kdt.query([x,y], k=1) delbb.append(idx_) mask = np.ones(len(atom_classes), dtype=bool) # 初始化为 True mask[delbb] = False atom_bbox = atom_bbox[mask] atom_classes = atom_classes[mask] atom_centers = atom_centers[mask] # mol# mol_rebuit=copy.deepcopy(mol) mol=copy.deepcopy(rwmol_) conf=mol.GetConformers()[0] mola2xy=dict() mola2d=[] for i,a in enumerate(mol.GetAtoms()): x,y,z=conf.GetAtomPosition(i) mola2xy[i]=[x,y] mola2d.append([x,y]) # print( x,y,z) kdt = cKDTree(mola2d) chiral_centers = Chem.FindMolChiralCenters( mol, includeUnassigned=True, includeCIP=False, useLegacyImplementation=False) chiral_center_ids = [idx for idx, _ in 
chiral_centers] for bi,bcent in enumerate(bond_centers): if bi in bondWithdirct :#and bi not in Hbd:#Note as set Hbd previously dists, a1a2 = kdt.query(bcent, k=2) a1,a2=sorted(a1a2) a1,a2=int(a1),int(a2) bt= mol.GetBondBetweenAtoms(a1, a2)#RDKit 的键是无向的,返回的是同一个 Bond 对象 if bt: # 获取键的当前起点和终点 current_begin = bt.GetBeginAtomIdx() current_end = bt.GetEndAtomIdx() bond_dir=bond_dirs[idx_to_labels[bond_classes[bi]]] if bond_dir == rdchem.BondDir.BEGINWEDGE: reverse_dir = rdchem.BondDir.BEGINDASH elif bond_dir == rdchem.BondDir.BEGINDASH: reverse_dir = rdchem.BondDir.BEGINWEDGE # else: # reverse_dir= rdchem.BondDir.BEGINWEDGE if a1 in chiral_center_ids: if current_begin == a1: bt.SetBondDir(bond_dir) print(f'a1 dir') else: # 如果手性原子是终点,反转方向(例如用相反的楔形键) bt.SetBondDir(reverse_dir) print(f'a1 reverse_dir') # print(f'set bond direction a1a2 {[bi, a1,a2]}') # bt.SetBondDir(bond_dirs[idx_to_labels[bond_classes[bi]]]) elif a2 in chiral_center_ids: if current_begin == a2: bt.SetBondDir(bond_dir) print(f'a2 dir {bond_dir} {reverse_dir}') else: # 如果手性原子是终点,反转方向(例如用相反的楔形键),but not work, just remove and add mol.RemoveBond(current_begin, current_end) mol.AddBond(current_end, current_begin, bt.GetBondType()) bond = mol.GetBondBetweenAtoms(current_end, current_begin) bond.SetBondDir(bond_dir) print(f'a2 reverse_dir {bond_dir} {reverse_dir}') # bt= mol.GetBondBetweenAtoms(a2, a1) # print(f'set bond direction a2a1 {[bi, a2,a1]}') # bt.SetBondDir(bond_dirs[idx_to_labels[bond_classes[bi]]]) else: print('bond stro not with chiral atom???, will ignore this stero bond infors') print(f"{[bi, bond_dir, current_begin,current_end]}") # beginatom=mol.GetAtomWithIdx(current_begin) # Endatom=mol.GetAtomWithIdx(current_end) # beginatom_neis=len(beginatom.GetBonds()) # Endatom_neis=len(Endatom.GetBonds()) try: mol_rebuit=mol.GetMol() conf = mol_rebuit.GetConformer() Chem.WedgeMolBonds(mol_rebuit,conf)# Chem.DetectBondStereochemistry(mol_rebuit) Chem.AssignChiralTypesFromBondDirs(mol_rebuit) 
Chem.AssignStereochemistry(mol_rebuit) # smiH=Chem.MolToSmiles(mol_rebuit) print(F"smiH\n",smiH) # canon_smilesH = Chem.CanonSmiles(smiH) # print(F"canon_smilesH\n",canon_smilesH) # rdkit_coni_smiH=Chem.MolToSmiles(Chem.MolFromSmiles(smiH)) # print(f"Chem.MolToSmiles(Chem.MolFromSmiles(smiH))\n {rdkit_coni_smiH}") # mol = rdkit.Chem.RWMol(mol_rebuit) other2ppsocr=True if other2ppsocr: print() need_cut=[] ppstr=[] ppstr_score=[] crops=[] index_token=dict() expan=0#NOTE this control how much the part of bond in crop_Img for i_,(heav_c,heav_box) in enumerate(zip(atom_classes,atom_bbox)): if lab2idx['*']==heav_c or lab2idx['other']==heav_c or lab2idx['Cl']==heav_c: need_cut.append(i_) a=heav_box+np.array([-expan,-expan,expan,expan]) # print(heav_box.shape,a.shape) box=a * [scale_x, scale_y, scale_x, scale_y]#TODO need the fix as w h may not equal!! # print(a,box,[scale_x, scale_y, scale_x, scale_y]) cropped_img = img_ori_1k.crop(box) crops.append(cropped_img) image_npocr = np.array(cropped_img) result_ocr= ocr2.ocr(image_npocr, det=False) s_, score_ =result_ocr[0][0] s_previos=atom_ocr[i_] if s_previos != "other" : s_=s_previos if len(s_previos)>=len(s_) else s_ print(f'ocr::idx:{i_}',s_, score_ ) if score_<=0.1:# process cropped_img and try again # print(s_, "xxx",score_) s_='*' if s_=='+' or s_=='-': s_="*" if len(s_)>1: s_=re.sub(r'[^a-zA-Z0-9,\*\-\+]', '', s_)#remove special chars if re.match(r'^\d+$', s_): s_=f'{s_}*'#number+ * # print(f'why only numbers ? 
{s_}') if s_=='L':s_='Li' elif s_=='0':s_='O' elif s_ in ['N,+ CI','N,+ Cl' ,'N,+Cl','N,+CI','N+CI']:s_='N2+Cl-' elif s_ in ['NO,','O,N' ]:s_='NO2' match = re.match(r'^(\d+)?(.*)', s_) # print(s_,'xxxx') if match: numeric_part, remaining_part = match.groups() fc_=mol.GetAtomWithIdx(i_).GetFormalCharge() if remaining_part in ELEMENTS: new_atom = Chem.Atom(remaining_part) mol.ReplaceAtom(i_, new_atom) print(i_, remaining_part,"@@@") elif remaining_part in ABBREVIATIONS:# can be expanded with placeholder_atoms placeholder_atoms[i_]=s_# such 2Na will be get for rdkit elif remaining_part=='OH': new_atom = Chem.Atom("O") mol.ReplaceAtom(i_, new_atom) elif remaining_part=='SH': new_atom = Chem.Atom("S") mol.ReplaceAtom(i_, new_atom) elif remaining_part=='NH': new_atom = Chem.Atom("N") mol.ReplaceAtom(i_, new_atom) mol.GetAtomWithIdx(i_).SetFormalCharge(fc_) index_token[i_]=f'{s_}:{i_}' print(f"idx:{i_}, atm: <{idx_to_labels[heav_c]}> --- [{s_}:{i_}] with score:{score_} ||previousOCR:: {atom_ocr[i_]}") if s_ in ELEMENTS : new_atom = Chem.Atom(s_) mol.ReplaceAtom(i_, new_atom) mol.GetAtomWithIdx(i_).SetProp("atomLabel", f"{s_}")#mol set with label, mol_rebuit not ppstr.append(s_) ppstr_score.append(score_) if s_ in ABBREVIATIONS.keys(): placeholder_atoms[i_]=s_ # bond_dirs_rev={v:k for k,v in bond_dirs.items()} wdbs=[] for b in mol.GetBonds(): bd=b.GetBondDir() bt=b.GetBondType() # print(bd) if bd ==bond_dirs['BEGINDASH'] or bd==bond_dirs['BEGINWEDGE']: a1,a2=b.GetBeginAtomIdx(), b.GetEndAtomIdx() wdbs.append([a1,a2,bt,bond_dirs_rev[bd]]) #expand mol if exists # if len(placeholder_atoms)>0:### cm=copy.deepcopy(mol) # print(placeholder_atoms) expand_mol, expand_smiles= expandABB(cm,ABBREVIATIONS, placeholder_atoms) SMILESpre=expand_smiles rdm=copy.deepcopy(expand_mol) target_mol, ref_mol=rdm, cm AllChem.Compute2DCoords(target_mol) pair=[target_mol, ref_mol] mcs=rdFMCS.FindMCS([target_mol, ref_mol], # larger,small order # atomCompare=rdFMCS.AtomCompare.CompareAny, 
bondCompare=rdFMCS.BondCompare.CompareAny, ringCompare=rdFMCS.RingCompare.IgnoreRingFusion, matchChiralTag=False, ) mcs_mol = Chem.MolFromSmarts(mcs.smartsString) AllChem.Compute2DCoords(mcs_mol) matches0 = pair[0].GetSubstructMatches(mcs_mol, useQueryQueryMatches=True,uniquify=False, maxMatches=1000, useChirality=False) matches1 = pair[1].GetSubstructMatches(mcs_mol, useQueryQueryMatches=True,uniquify=False, maxMatches=1000, useChirality=False) if len(matches0) != len(matches1): matches0=list(matches0) matches1=list(matches1) # print( "noted: matcher not equal !!") if len(matches0)>len(matches1): for i in range(0,len(matches0)): if i < len(matches1): pass else: ii=i % len(matches1) matches1.append(matches1[ii]) else: for i in range(0,len(matches1)): if i < len(matches0): pass else: ii=i % len(matches0) matches0.append(matches0[ii]) assert len(matches0) == len(matches1), "matcher not equal break!!" atommaping_pairs=[list(zip(matches0[i],matches1[i])) for i in range(0,len(matches0))] atomMap=atommaping_pairs[0] rmsd2=rdkit.Chem.rdMolAlign.AlignMol(prbMol=target_mol, refMol=ref_mol, atomMap=atomMap,maxIters=2000000) print(f"rmsd {rmsd2}") #ocr_mol ocr_mol = copy.deepcopy(target_mol) AllChem.Compute2DCoords(ocr_mol) ocr_smi = Chem.MolToSmiles(ocr_mol) molexp=ocr_mol expandStero_smi, success= rdkit_canonicalize_smiles(ocr_smi) # expandStero_smi = Chem.CanonSmiles(ocr_smi)#, useChiral=(not ignore_chiral)) # TODO #[3H] 2H prpared box for training are too smalled, need adjust if visual_check: boxed_img = draw_objs(img, atom_bbox, atom_classes, atom_scores, category_index=idx_to_labels, box_thresh=0.5, line_thickness=3, font='arial.ttf', font_size=10) opts = Draw.MolDrawOptions() opts.addAtomIndices = False opts.addStereoAnnotation = False img_ori = Image.open(image_path).convert('RGB') img_ori_1k = img_ori.resize((1000,1000)) if other2ppsocr: img_rebuit = Draw.MolToImage(ocr_mol, options=opts,size=(1000, 1000)) else: img_rebuit = Draw.MolToImage(ocr_mol, 
options=opts,size=(1000, 1000)) combined_img = Image.new('RGB', (img_ori_1k.width + boxed_img.width + img_rebuit.width, img_ori_1k.height)) combined_img.paste(img_ori_1k, (0, 0)) combined_img.paste(boxed_img, (img_ori_1k.width, 0)) combined_img.paste(img_rebuit, (img_ori_1k.width + boxed_img.width, 0)) imprefix=os.path.basename(image_path).split('.')[0] combined_img.save(f"{ima_checkdir}/{imprefix}Boxed.png") new_row = {'file_name':image_path, "SMILESori":SMILESori, 'SMILESpre':SMILESpre, 'SMILESexp':expandStero_smi, } smiles_data = smiles_data._append(new_row, ignore_index=True) #accu similarity calculation if getacc: sameWithOutStero=comparing_smiles(new_row,SMILESpre)#try to ingnore cis chiral, as 2d coords including all the infos sameWithOutStero_exp=comparing_smiles(new_row,expandStero_smi)#this ignore chairity and *number be * NOTE if (type(SMILESori)!=type('a')) or (type(SMILESpre)!=type('a')): if sameWithOutStero or sameWithOutStero_exp: mysum += 1 else: print(f"smiles problems\n{SMILESori}\n{SMILESpre}\n{image_path}") failed.append([SMILESori,SMILESpre,image_path]) mydiff.append([SMILESori,SMILESpre,image_path]) continue mol1 = Chem.MolFromSmiles(SMILESori)#TODO considering smiles with rdkit not recongized in real data if mol1 is None: rd_smi_ori, success1_=rdkit_canonicalize_smiles(SMILESori) mol1=Chem.MolFromSmiles(rd_smi_ori) if (mol_rebuit is None) or (mol1 is None): if sameWithOutStero or sameWithOutStero_exp: mysum += 1 else: print(f'get rdkit mol None\n{SMILESori}\n{SMILESpre}\n{image_path}') failed.append([SMILESori,SMILESpre,image_path]) mydiff.append([SMILESori,SMILESpre,image_path]) continue if mol1: rdk_smi1=Chem.MolToSmiles(mol1) else: rdk_smi1=SMILESori if mol_rebuit: rdk_smi2=Chem.MolToSmiles(mol_rebuit) else: rdk_smi2='' # if rdk_smi1==rdk_smi2 or rdk_smi1==expandStero_smi or sameWithOutStero:#also considering the abbre in Ori if rdk_smi1==rdk_smi2 or rdk_smi1==expandStero_smi: mysum += 1 else: if sameWithOutStero or sameWithOutStero_exp: 
mysum += 1 else: mydiff.append([SMILESori,SMILESpre,image_path]) if visual_check: combined_img.save(f"{ima_checkdir}/{imprefix}Boxed_diff{len(mydiff)}.png") try: morganfps1 = AllChem.GetMorganFingerprint(mol1, 3,useChirality=True) morganfps2 = AllChem.GetMorganFingerprint(mol_rebuit, 3,useChirality=True) morgan_tani = DataStructs.DiceSimilarity(morganfps1, morganfps2) fp1 = Chem.RDKFingerprint(mol1) fp2 = Chem.RDKFingerprint(mol_rebuit) tanimoto = DataStructs.FingerprintSimilarity(fp1, fp2) if expandStero_smi!= '': fp3 = Chem.RDKFingerprint(molexp) morganfps3 = AllChem.GetMorganFingerprint(molexp, 3,useChirality=True) morgan_tani3 = DataStructs.DiceSimilarity(morganfps1, morganfps3) tanimoto3 = DataStructs.FingerprintSimilarity(fp1, fp3) if morgan_tani3> morgan_tani or tanimoto3> tanimoto : sim+=morgan_tani3 simRD+=tanimoto3 else: simRD+=tanimoto sim+=morgan_tani except Exception as e: print(f"mol to fingerprint erros") simRD+=0 sim+=0 continue except Exception as e: print(f"file_name@: {image_path}\n SMILES in csv:\n{SMILESori}") raise Exception("@debug this!!\n") if getacc: sim_100 = 100*sim/len(smiles_data) simrd100 = 100*simRD/len(smiles_data) flogout.write(f"rdkit concanlized==smiles:{100*mysum/len(smiles_data)}%\n") flogout.write(f"failed:{len(failed)}\n totoal saved in csv : {len(smiles_data)}\n") flogout.write(f"avarage similarity morgan tanimoto: RDKFp tanimoto:: {sim_100}%, {simrd100}% \n")#morgan_tani considering chiraty flogout.write(f'I2M@@:: match--{mysum},unmatch--{len(mydiff)},failed--{len(failed)},correct %{100*mysum/len(smiles_data)} \n') #molscribe evalutate from src.solver.evaluate import SmilesEvaluator evaluator = SmilesEvaluator(smiles_data['SMILESori'], tanimoto=False) res_pre=evaluator.evaluate(smiles_data['SMILESpre']) res_exp=evaluator.evaluate(smiles_data['SMILESexp']) flogout.write(f'MolScribe style evaluation@SMILESpre:: {str(res_pre)} \n') flogout.write(f'MolScribe style evaluation@SMILESexp:: {str(res_exp)} \n') flogout.close() 
class RTDETRPostProcessor(nn.Module):
    """Convert raw detector outputs into per-image labels, boxes and scores.

    The module reads ``outputs['pred_logits']`` and ``outputs['pred_boxes']``,
    converts the boxes from ``cxcywh`` to ``xyxy``, scales them by the
    per-image target size and keeps at most ``num_top_queries`` detections.

    Args:
        classes_dict: Mapping ``category id -> class name``. Only its length
            (number of classes) and its key order (used for label remapping)
            are consumed here. Defaults to the atom/bond/charge vocabulary
            below.
        use_focal_loss: If true, scores are per-class sigmoids and the top-k
            is taken over all (query, class) pairs; otherwise a softmax over
            classes is used and the last class is dropped.
        num_top_queries: Maximum number of detections kept per image.
        remap_mscoco_category: If true, contiguous label indices are mapped
            back to the keys of ``classes_dict`` before returning.
    """

    __share__ = ['num_classes', 'use_focal_loss', 'num_top_queries', 'remap_mscoco_category']

    def __init__(self, classes_dict=None, use_focal_loss=True, num_top_queries=300,
                 remap_mscoco_category=False) -> None:
        super().__init__()
        self.use_focal_loss = use_focal_loss
        if classes_dict is None:
            classes_dict = {
                0: 'other', 1: 'C', 2: 'O', 3: 'N', 4: 'Cl', 5: 'Br', 6: 'S', 7: 'F', 8: 'B',
                9: 'I', 10: 'P', 11: 'H', 12: 'Si',
                # bond
                13: 'single', 14: 'wdge', 15: 'dash',
                16: '=', 17: '#', 18: ':',  # aromatic
                # charge
                19: '-4', 20: '-2',
                21: '-1',  # -
                22: '+1',  # +
                23: '+2',
            }
        self.num_top_queries = num_top_queries
        self.num_classes = len(classes_dict)
        self.remap_mscoco_category = remap_mscoco_category
        self.deploy_mode = False
        # contiguous label index -> original category id (key of classes_dict)
        self.mscoco_label2category = dict(enumerate(classes_dict))

    def extra_repr(self) -> str:
        """Return the configuration summary shown in the module repr."""
        return (f'use_focal_loss={self.use_focal_loss}, num_classes={self.num_classes}, '
                f'num_top_queries={self.num_top_queries}')

    def forward(self, outputs, orig_target_sizes):
        """Post-process one batch of predictions.

        Args:
            outputs: Dict holding ``'pred_logits'`` and ``'pred_boxes'``.
                Assumed to be ``(batch, queries, classes)`` and
                ``(batch, queries, 4)`` in ``cxcywh`` format -- TODO confirm
                against the model head.
            orig_target_sizes: Per-image sizes of shape ``(batch, 2)``; it is
                tiled to four values and multiplied onto the ``xyxy`` boxes
                (presumably ``(width, height)`` -- verify against the caller).

        Returns:
            In deploy mode the tuple ``(labels, boxes, scores)``; otherwise a
            list with one ``dict(labels=..., boxes=..., scores=...)`` per image.
        """
        logits, boxes = outputs['pred_logits'], outputs['pred_boxes']

        bbox_pred = torchvision.ops.box_convert(boxes, in_fmt='cxcywh', out_fmt='xyxy')
        bbox_pred *= orig_target_sizes.repeat(1, 2).unsqueeze(1)

        if self.use_focal_loss:
            flat_scores = torch.sigmoid(logits).flatten(1)
            # Never ask topk for more candidates than exist (it would raise).
            k = min(self.num_top_queries, flat_scores.shape[1])
            scores, index = torch.topk(flat_scores, k, dim=-1)
            # A flat index encodes (query, class) as query * num_classes + class.
            labels = index % self.num_classes
            index = index // self.num_classes
            boxes = bbox_pred.gather(
                dim=1, index=index.unsqueeze(-1).repeat(1, 1, bbox_pred.shape[-1]))
        else:
            # dim must be explicit: the deprecated implicit choice for a 3-D
            # tensor is dim 0, which would normalise across the batch.
            scores = F.softmax(logits, dim=-1)[:, :, :-1]
            scores, labels = scores.max(dim=-1)
            boxes = bbox_pred
            if scores.shape[1] > self.num_top_queries:
                scores, index = torch.topk(scores, self.num_top_queries, dim=-1)
                labels = torch.gather(labels, dim=1, index=index)
                boxes = torch.gather(
                    boxes, dim=1, index=index.unsqueeze(-1).tile(1, 1, boxes.shape[-1]))

        # TODO for onnx export
        if self.deploy_mode:
            return labels, boxes, scores

        # TODO
        if self.remap_mscoco_category:
            remapped = [self.mscoco_label2category[int(lab)] for lab in labels.flatten().tolist()]
            labels = torch.tensor(remapped).to(boxes.device).reshape(labels.shape)

        return [dict(labels=lab, boxes=box, scores=sco)
                for lab, box, sco in zip(labels, boxes, scores)]

    def deploy(self, ):
        """Switch to eval mode and make ``forward`` return raw tensors."""
        self.eval()
        self.deploy_mode = True
        return self

    @property
    def iou_types(self, ):
        """IoU types this post-processor produces results for."""
        return ('bbox', )