diff --git "a/det_engine.py" "b/det_engine.py" new file mode 100644--- /dev/null +++ "b/det_engine.py" @@ -0,0 +1,5456 @@ +""" +Copyright (c) All Rights Reserved +by bowen +""" + +import json +import math +import os +import sys +import pathlib +from typing import Iterable, List +import random +import itertools + +import numpy as np +import pandas as pd +import tqdm +import torch +import torch.amp +from PIL import Image +# from src.data import CocoEvaluator +# from src.misc import (MetricLogger, SmoothedValue, reduce_dict) +# from src.solver.utils import output_to_smiles, output_to_smiles2 +# from src.solver.utils import bbox_to_graph_with_charge, mol_from_graph_with_chiral +# from src.misc.draw_box_utils import draw_objs + +# from sklearn.metrics import f1_score +# from src.postprocess.abbreviation_detector import get_ocr_recognition_only +# from src.postprocess.utils_dataset import CaptionRemover +from skimage.measure import label +######################################add metric postprocess +import rdkit +from rdkit import Chem +from rdkit.Chem import Draw, AllChem +from rdkit.Chem import rdchem, RWMol, CombineMols +from rdkit import Chem +from rdkit.Chem import rdFMCS +import copy +from paddleocr import PaddleOCR +import re +from rdkit import DataStructs +import matplotlib.pyplot as plt +from matplotlib.patches import Rectangle, Circle +from scipy.spatial import cKDTree, KDTree +from rdkit.Geometry import Point3D +import multiprocessing + + + +def select_longest_smiles(smiles): + # 将 SMILES 以 '.' 
分割为多个部分 + components = smiles.split('.') + # 选择字符数最多的部分作为主结构 + longest_component = max(components, key=len) + return longest_component + +def MCS_mol(mcs): + #mcs_smart = mcs.smartsString + mcs_mol = Chem.MolFromSmarts(mcs.smartsString) + AllChem.Compute2DCoords(mcs_mol) + return mcs_mol + +def g_atompair_matches(pair,mcs): + mcs_mol = MCS_mol(mcs) + matches0 = pair[0].GetSubstructMatches(mcs_mol, useQueryQueryMatches=True,uniquify=False, maxMatches=1000, useChirality=False) + matches1 = pair[1].GetSubstructMatches(mcs_mol, useQueryQueryMatches=True,uniquify=False, maxMatches=1000, useChirality=False) + if len(matches0) != len(matches1): + matches0=list(matches0) + matches1=list(matches1) + print( " g_atompair_matches noted: matcher not equal !!") + if len(matches0)>len(matches1) and len(matches1) !=0: + for i in range(0,len(matches0)): + if i < len(matches1): + pass + else: + ii=i % len(matches1) + matches1.append(matches1[ii]) + else: + for i in range(0,len(matches1)): + if i < len(matches0) and len(matches0): + pass + else: + ii=i % len(matches0) + matches0.append(matches0[ii]) + # assert len(matches0) == len(matches1), "matcher not equal break!!" 
+ if len(matches0) != len(matches1): + atommaping_pairs=[[]] + else:atommaping_pairs=[list(zip(matches0[i],matches1[i])) for i in range(0,len(matches0))] + return atommaping_pairs + + +class CustomError(Exception): + """A custom exception for specific errors.""" + pass + +bond_dirs = {'NONE': Chem.rdchem.BondDir.NONE, + 'ENDUPRIGHT': Chem.rdchem.BondDir.ENDUPRIGHT, + 'BEGINWEDGE': Chem.rdchem.BondDir.BEGINWEDGE, + 'BEGINDASH': Chem.rdchem.BondDir.BEGINDASH, + 'ENDDOWNRIGHT': Chem.rdchem.BondDir.ENDDOWNRIGHT,} + +BONDTYPE = {'SINGLE': Chem.rdchem.BondType.SINGLE, + 'DOUBLE': Chem.rdchem.BondType.DOUBLE, + 'TRIPLE': Chem.rdchem.BondType.TRIPLE, + 'AROMATIC': Chem.rdchem.BondType.AROMATIC} +BOND_DIRS = {'NONE': Chem.rdchem.BondDir.NONE, + 'ENDUPRIGHT': Chem.rdchem.BondDir.ENDUPRIGHT, + 'BEGINWEDGE': Chem.rdchem.BondDir.BEGINWEDGE, + 'BEGINDASH': Chem.rdchem.BondDir.BEGINDASH, + 'ENDDOWNRIGHT': Chem.rdchem.BondDir.ENDDOWNRIGHT,} +BONDDIRECT=['ENDUPRIGHT', 'BEGINWEDGE', 'BEGINDASH', 'ENDDOWNRIGHT'] + + +BONDTYPE2ORD={ + 'wdge':1, + 'dash':1, + Chem.rdchem.BondType.SINGLE: 1, + Chem.rdchem.BondType.DOUBLE: 2, + Chem.rdchem.BondType.TRIPLE: 3, + Chem.rdchem.BondType.AROMATIC: 1.5, + } + +BONDTYPE={'SINGLE': Chem.BondType.SINGLE, + 'DOUBLE': Chem.BondType.DOUBLE, + 'TRIPLE': Chem.BondType.TRIPLE, + 'AROMATIC': Chem.BondType.AROMATIC} + +VALENCES = { + "H": [1], "Li": [1], "Be": [2], "B": [3], "C": [4], "N": [3, 5], "O": [2], "F": [1], + "Na": [1], "Mg": [2], "Al": [3], "Si": [4], "P": [5, 3], "S": [6, 2, 4], "Cl": [1], "K": [1], "Ca": [2], + "Br": [1], "I": [1], "*":[3,4,5,6], +} + +ELEMENTS = [ + "H", "He", "Li", "Be", "B", "C", "N", "O", "F", "Ne", + "Na", "Mg", "Al", "Si", "P", "S", "Cl", "Ar", "K", "Ca", + "Sc", "Ti", "Ru", "Rh","Rn","Rf", "Cr", "Mn", "Fe", "Co", "Ni", "Cu", "Zn", + "Ga", "Ge", "As", "Se", "Br", "Kr", "Sr", "Zr", + "Nb", "Mo", "Tc", "Pd", "Ag", "Cd", "In", "Sn", + "Sb", "Te", "I", "Xe", "Cs", "Ba", "La", "Ce", "Pr", "Nd", + "Pm", "Sm", "Eu", "Gd", 
"Tb", "Dy", "Ho", "Er", "Tm", "Yb", + "Lu", "Hf", "Ta", "W", "Os", "Ir", "Pt", "Au", "Hg", + "Tl", "Pb", "Bi", "Po", "At", "Fr", "Ac", "Th", + "Pa", "Np", "Pu", "Am", "Cm", "Bk", "Cf", "Es", "Fm", + "Md", "No", "Lr", "Db", "Sg", "Bh", "Hs", "Mt", "Ds", + "Cn", "Nh", "Fl", "Mc", "Lv", "Og" +] + # "Rg", "Rb", "Re", "Ra"as RGROUP in the Molscribe data + #"V", "Y","U", # be viewed as C for paddleOCR smt ONELEMENTS ['A','J] + #"Ts" #as a chemical group [S](C1=CC=C(C=C1)C)(=O)=O +RGROUP_SYMBOLS = ['R',"R'" 'R1', 'R2', 'R3', 'R4', 'R5', 'R6', 'R7', 'R8', 'R9', 'R10', 'R11', 'R12', + 'Ra', 'Rb', 'Rc', 'Rd','Re','Rg', 'X', 'Y', 'Z', 'Q', 'A', 'E', 'Ar', + "V", "Y","U",'M', 'G','L', + 'Nr','Tt','Uu','Vv','Ww',#CLEF Nr is not in periodic table + 'D',#CLEF as [2H] but not recongited by rdkit chemdraw + ] + +COLORS = { + u'c': '0.0,0.75,0.75', u'b': '0.0,0.0,1.0', u'g': '0.0,0.5,0.0', u'y': '0.75,0.75,0', + u'k': '0.0,0.0,0.0', u'r': '1.0,0.0,0.0', u'm': '0.75,0,0.75' +} + +class Substitution(object): + '''Define common substitutions for chemical shorthand''' + def __init__(self, abbrvs, smarts, smiles, probability): + assert type(abbrvs) is list + self.abbrvs = abbrvs + self.smarts = smarts + self.smiles = smiles + self.probability = probability + +SUBSTITUTIONS: List[Substitution] = [ + #abbrvs, smarts, smiles + #patch4 USPTO,try put the longer one first, as re use match by order + Substitution(['CH2CH2NSO2CH3'], '[CH2][CH]', '[CH2]CNS(=O)(C)=O', 0.5), + Substitution(['NHNHCOCF3'], 'NHNHCOCF3', '[NH]NC(=O)C(F)(F)(F)', 0.5), + Substitution(['CO2CysPr'], 'CO2CysPr', '[C](=O)ON[C@H](C(CCC)=O)CS', 0.5), + Substitution(['OCH2CHOHCH2'], 'OCH2CHOHCH2', '[O]CC(O)C', 0.5), + Substitution(['OCH2CHOHCH2OH'], 'OCH2CHOHCH2', '[O]CC(O)CO', 0.5), + # elif symbol in ['SO2(CH2)3SO2NHCH2CHCH2OH']:smiles='[S](=O)(=O)CCCS(=O)(=O)NC[C]CO' + Substitution(['SO2(CH2)3SO2NHCH2CHCH2OH'], 'OCH2CHOHCH2', '[S](=O)(=O)CCCS(=O)(=O)NC[C]CO', 0.5), + + + + + Substitution(['NO2', 'O2N'], '[N+](=O)[O-]', 
"[N+](=O)[O-]", 0.5), + # Substitution(['CHO', 'OHC'], '[CH1](=O)', "[CH1](=O)", 0.5), + Substitution(['CO2Et', 'COOEt'], 'C(=O)[OH0;D2][CH2;D2][CH3]', "[C](=O)OCC", 0.5), + + Substitution(['OAc','AcO'], '[OH0;X2]C(=O)[CH3]', "[O]C(=O)C", 0.7), + Substitution(['NHAc'], '[NH1;D2]C(=O)[CH3]', "[NH]C(=O)C", 0.7), + Substitution(['Ac'], 'C(=O)[CH3]', "[C](=O)C", 0.1), + + Substitution(['OBz','BzO'], '[OH0;D2]C(=O)[cH0]1[cH][cH][cH][cH][cH]1', "[O]C(=O)c1ccccc1", 0.7), # Benzoyl + Substitution(['Bz'], 'C(=O)[cH0]1[cH][cH][cH][cH][cH]1', "[C](=O)c1ccccc1", 0.2), # Benzoyl + + Substitution(['COOBn','BnO2C'], '[OH0;D2][CH2;D2][cH0]1[cH][cH][cH][cH][cH]1', "[C](=O)OCc1ccccc1", 0.7), # Benzyl + Substitution(['OBn','BnO'], '[OH0;D2][CH2;D2][cH0]1[cH][cH][cH][cH][cH]1', "[O]Cc1ccccc1", 0.7), # Benzyl + Substitution(['Bn'], '[CH2;D2][cH0]1[cH][cH][cH][cH][cH]1', "[CH2]c1ccccc1", 0.2), # Benzyl + Substitution(['NHBn'], '[NH]Cc1ccccc1', "[NH]Cc1ccccc1", 0.2), # Benzyl + Substitution(['NBn2'], '[NH]Cc1ccccc1', "[N](Cc1ccccc1)Cc1ccccc1", 0.2), # Benzyl + + Substitution(['NHBoc','BocHN',"BOCHN"], '[NH1;D2]C(=O)OC([CH3])([CH3])[CH3]', "[NH]C(=O)OC(C)(C)C", 0.6), + Substitution(['NBoc'], '[NH0;D3]C(=O)OC([CH3])([CH3])[CH3]', "[NH1]C(=O)OC(C)(C)C", 0.6), + Substitution(['Boc','BOc'], 'C(=O)OC([CH3])([CH3])[CH3]', "[C](=O)OC(C)(C)C", 0.2), + + Substitution(['Cbm'], 'C(=O)[NH2;D1]', "[C](=O)N", 0.2), + Substitution(['Cbz'], 'C(=O)OC[cH]1[cH][cH][cH1][cH][cH]1', "[C](=O)OCc1ccccc1", 0.4), + Substitution(['NHCbz'], 'C(=O)OC[cH]1[cH][cH][cH1][cH][cH]1', "[NH]C(=O)OCc1ccccc1", 0.4), + Substitution(['Cy'], '[CH1;X3]1[CH2][CH2][CH2][CH2][CH2]1', "[CH1]1CCCCC1", 0.3), + Substitution(['Fmoc'], 'C(=O)O[CH2][CH1]1c([cH1][cH1][cH1][cH1]2)c2c3c1[cH1][cH1][cH1][cH1]3', + "[C](=O)OCC1c(cccc2)c2c3c1cccc3", 0.6), + Substitution(['FmocHN','FmOcHN', 'NHFmoc'], 'C(=O)O[CH2][CH1]1c([cH1][cH1][cH1][cH1]2)c2c3c1[cH1][cH1][cH1][cH1]3', + "[NH]C(=O)OCC1c(cccc2)c2c3c1cccc3", 0.6), + Substitution(['Mes'], 
'[cH0]1c([CH3])cc([CH3])cc([CH3])1', "[c]1c(C)cc(C)cc(C)1", 0.5), + Substitution(['OMs','MsO'], '[OH0;D2]S(=O)(=O)[CH3]', "[O]S(=O)(=O)C", 0.7), + Substitution(['Ms'], 'S(=O)(=O)[CH3]', "[S](=O)(=O)C", 0.2), + Substitution(['Ph'], '[cH0]1[cH][cH][cH1][cH][cH]1', "[c]1ccccc1", 0.5), + + Substitution(['PMB'], '[CH2;D2][cH0]1[cH1][cH1][cH0](O[CH3])[cH1][cH1]1', "[CH2]c1ccc(OC)cc1", 0.2), + Substitution(['PMBN'], '[CH2;D2][cH0]1[cH1][cH1][cH0](O[CH3])[cH1][cH1]1', "[N]Cc1ccc(OC)cc1", 0.2), + Substitution(['Py'], '[cH0]1[n;+0][cH1][cH1][cH1][cH1]1', "[c]1ncccc1", 0.1), + # Substitution(['SEM','MES'], '[CH2;D2][CH2][Si]([CH3])([CH3])[CH3]', "[CH2]CSi(C)(C)C", 0.2), + Substitution(['SEM','MES'], '[CH2;D2][O][CH2][CH2][Si]([CH3])([CH3])[CH3]', "[CH2]OCC[Si](C)(C)C", 0.2),#fix above + + Substitution(['Suc'], 'C(=O)[CH2][CH2]C(=O)[OH]', "[C](=O)CCC(=O)O", 0.2), + Substitution(['TBS'], '[Si]([CH3])([CH3])C([CH3])([CH3])[CH3]', "[Si](C)(C)C(C)(C)C", 0.5), + Substitution(['TBZ'], 'C(=S)[cH]1[cH][cH][cH1][cH][cH]1', "[C](=S)c1ccccc1", 0.2), + Substitution(['OTf'], '[OH0;D2]S(=O)(=O)C(F)(F)F', "[O]S(=O)(=O)C(F)(F)F", 0.7), + Substitution(['Tf'], 'S(=O)(=O)C(F)(F)F', "[S](=O)(=O)C(F)(F)F", 0.2), + Substitution(['TFA'], 'C(=O)C(F)(F)F', "[C](=O)C(F)(F)F", 0.3), + Substitution(['TFAH2N'], 'C(=O)C(F)(F)F', "[NH]C(=O)C(F)(F)F", 0.3), + Substitution(['TMS'], '[Si]([CH3])([CH3])[CH3]', "[Si](C)(C)C", 0.5), + Substitution(['Ts'], 'S(=O)(=O)c1[cH1][cH1][cH0]([CH3])[cH1][cH1]1', "[S](=O)(=O)c1ccc(C)cc1", 0.6), # Ts + Substitution(['TsO','OTs'], '[O]S(C1=CC=C(C=C1)C)(=O)=O', "[O]S(C1=CC=C(C=C1)C)(=O)=O", 0.6), # Ts + + Substitution(['COCH3'], '[OH0;D2][CH3;D1]', "[C](=O)C", 0.3), + # Alkyl chains + Substitution(['OMe', 'MeO','H;CO', 'CH3O','OCH3', 'H3CO'], '[OH0;D2][CH3;D1]', "[O]C", 0.3), + Substitution(['SMe', 'MeS'], '[SH0;D2][CH3;D1]', "[S]C", 0.3), + Substitution(['NMe', 'MeN'], '[N;X3][CH3;D1]', "[N]C", 0.3),#modified as [NH]not wanted + Substitution(['NMe2', 'Me2N'], 
'[N;X3](C)[CH3;D1]', "[N](C)C", 0.3),#modified as [NH]not wanted + + Substitution(['Me'], '[CH3;D1]', "[CH3]", 0.1), + Substitution(['OEt', 'EtO','C2H5O','OC2H5'], '[OH0;D2][CH2;D2][CH3]', "[O]CC", 0.5), + Substitution(['MeOH2C','CH2OMe'], '[CH2;D2]O[CH3]', "[CH2]OC", 0.5), + Substitution(['Et', 'CH2CH3','CH3CH2'], '[CH2;D2][CH3]', "[CH2]C", 0.3), + + + Substitution(['Pr', 'nPr', 'n-Pr'], '[CH2;D2][CH2;D2][CH3]', "[CH2]CC", 0.3), + Substitution(['Bu', 'nBu', 'n-Bu'], '[CH2;D2][CH2;D2][CH2;D2][CH3]', "[CH2]CCC", 0.3), + # Substitution(['nBu', 'n-Bu'], '[CH2;D2][CH2;D2][CH2;D2][CH3]', "[CH2]CCC", 0.3), + + # Branched + Substitution(['iPr', 'i-Pr'], '[CH1;D3]([CH3])[CH3]', "[CH1](C)C", 0.2), + Substitution(['iBu', 'i-Bu'], '[CH2;D2][CH1;D3]([CH3])[CH3]', "[CH2]C(C)C", 0.2), + Substitution(['OiBu'], '[OH0;D2][CH2;D2][CH1;D3]([CH3])[CH3]', "[O]CC(C)C", 0.2), + Substitution(['OtBu','tBuO'], '[OH0;D2][CH0]([CH3])([CH3])[CH3]', "[O]C(C)(C)C", 0.6), + Substitution(['tBu', 't-Bu'], '[CH0]([CH3])([CH3])[CH3]', "[C](C)(C)C", 0.3), + + # Other shorthands (MIGHT NOT WANT ALL OF THESE) + Substitution(['CF3', 'F3C'], '[CH0;D4](F)(F)F', "[C](F)(F)F", 0.5), + Substitution(['NCF3', 'F3CN'], '[N;X3][CH0;D4](F)(F)F', "[NH]C(F)(F)F", 0.5), + Substitution(['OCF3', 'F3CO'], '[OH0;X2][CH0;D4](F)(F)F', "[O]C(F)(F)F", 0.5), + Substitution(['OCCl3', 'Cl3CO'], '[OH0;X2][CH0;D4](Cl)(Cl)Cl', "[O]C(Cl)(Cl)Cl", 0.5), + Substitution(['SCF3', 'F3CS'], '[SH0;X2][CH0;D4](F)(F)F', "[S]C(F)(F)F", 0.5), + Substitution(['CCl3'], '[CH0;D4](Cl)(Cl)Cl', "[C](Cl)(Cl)Cl", 0.5), + Substitution(['CO2H', 'HO2C', 'COOH'], 'C(=O)[OH]', "[C](=O)O", 0.5), # COOH + Substitution(['CO2NH4','COONH4','H4NOOC','H4NO2C'], 'C(=O)[OH]', "[C](=O)ON", 0.5), # COOH + Substitution([ 'COO-','CO2-'], 'C(=O)[OH]', "[C](=O)[O-]", 0.5), # COOH + # Substitution([ 'COO'], 'C(=O)[OH]', "[C](=O)O", 0.5), # COOH + Substitution(['CN', 'NC'], 'C#[ND1]', "[C]#N", 0.5), + # Substitution(['OCH3', 'H3CO'], '[OH0;D2][CH3]', "[O]C", 0.4), + #TODO 
if need just addit here + Substitution(['N3'], '[N]=[N+]=[N-]', "[N]=[N+]=[N-]", 0.4),#ACS image dataset has + # [N-]=[N+] + Substitution(['N2+Cl-','Cl-N2+'], '[N+]#[N].[Cl-]', "[N+]#[N].[Cl-]", 0.4),#ACS image dataset has + Substitution(['N2'], '[N]=[N-]', "[N]=[N-]", 0.4),#ACS image dataset has + Substitution(['N2H'], '[N]=[N-]', "[N]=[NH]", 0.4),#ACS image dataset has + Substitution(['NO','N=O','O=N','ON'], '[N]=[O]', "[N]=O", 0.4),#ACS image dataset has + Substitution(['NCH3'], '[N]C', "[NH]C", 0.4),#ACS image dataset has + Substitution(['NOMe'], '[N]OC', "[N]OC", 0.4),#ACS image dataset has + Substitution(['OCH2'], '[O]C', "[O]C", 0.4),#FORMULA_REGEX + Substitution(['C=O','O=C'], '[C]=[O]', "[C]=O", 0.4),#ACS image dataset has + Substitution(['NPh','PhN'], 'NC1=CC=CC=C1', "[N]C1=CC=CC=C1", 0.4),#ACS image dataset has + Substitution(['NHPh','PhNH','PhHN'], 'NC1=CC=CC=C1', "[NH]C1=CC=CC=C1", 0.4),#ACS image dataset has + Substitution(['TMSO','OSMT'], 'O[Si](C)(C)C', "[O][Si](C)(C)C", 0.5), + Substitution(['SPh','PhS'], 'SC1=CC=CC=C1', "[S]C1=CC=CC=C1", 0.4),#ACS image dataset has + Substitution(['SO3H'], 'S(=O)(=O)[OH]', "[S](=O)(=O)O", 0.4), + Substitution(['SO3NH2','SO3NH4','H4NO3S'], 'S(=O)(=O)[OH]', "[S](=O)(=O)ON", 0.4), + Substitution(['SO3'], 'S(=O)(=O)[OH]', "[S](=O)(=O)[O-]", 0.4), + Substitution(['SO2CF3'], '[S](=O)(=O)C(F)(F)F', "[S](=O)(=O)C(F)(F)F", 0.5), + Substitution(['SO2Cl'], '[S](=O)(=O)Cl', "[S](=O)(=O)Cl", 0.5), + Substitution(['SO2F'], '[S](=O)(=O)F', "[S](=O)(=O)F", 0.5), + Substitution(['SO2'], '[S](=O)(=O)', "[S](=O)(=O)", 0.5), + Substitution(['SO2NH'], '[S](=O)(=O)[N]', "[S](=O)(=O)[N]", 0.5),#US07323045-20080129-C00062 may lead wrong connext + Substitution(['SO2NH2'], '[S](=O)(=O)[NH2]', "[S](=O)(=O)[NH2]", 0.5), + Substitution(['SO2Me','SO2CH3'], '[S](=O)(=O)C', "[S](=O)(=O)C", 0.5), + Substitution(['NHO2S'], '[S](=O)(=O)[N]', "[N][S](=O)(=O)", 0.5),#US07323045-20080129-C00062 may lead wrong connext + Substitution(['OSO2Me'], 
'[O]S(=O)(=O)C', "[O]S(=O)(=O)C", 0.5), + Substitution(['NHSO2Me'], '[NH]S(=O)(=O)C', "[NH]S(=O)(=O)C", 0.5), + Substitution(['SOCH3','SOMe'], '[S](=O)(=O)', "[S](=O)C", 0.5), + + Substitution(['P+Ph3Br-'], '[P+](C1=CC=CC=C1)(C2=CC=CC=C2)C3=CC=CC=C3', "[P+](C1=CC=CC=C1)(C2=CC=CC=C2)C3=CC=CC=C3", 0.5), + Substitution(['N+Ph3Br-'], '[N+](C1=CC=CC=C1)(C2=CC=CC=C2)C3=CC=CC=C3', "[N+](C1=CC=CC=C1)(C2=CC=CC=C2)C3=CC=CC=C3", 0.5), + Substitution(['PPh2'], "[P](C1=CC=CC=C1)C2=CC=CC=C2", "[P](C1=CC=CC=C1)C2=CC=CC=C2", 0.5), + # Substitution(['BOcHN',"BOCHN"], "[NH]C(OC(C)(C)C)=O", "[NH]C(OC(C)(C)C)=O", 0.5), + Substitution(['CO2Me', 'COOMe'], 'C(=O)[OH0;D2][CH3]', "[C](=O)OC", 0.5), + Substitution(['ONa', 'NaO'], '[O][Na]', "[O][Na]", 0.5), + Substitution(['OTBDMS', 'TBDMSO'], "[O][Si](C)(C)C(C)(C)C", "[O][Si](C)(C)C(C)(C)C", 0.5), + Substitution(['CONH2'], '[C](O)(N)', "[C](=O)[NH2]", 0.5), + Substitution(['NHNH2'], '[NH2;D1]', "[NH]N", 0.1), + Substitution(['CONH'], 'CONH', '[C](=O)N', 0.5), + Substitution(['CH3CONH'], '[NH]C(=O)C', '[NH]C(=O)C', 0.5), + Substitution(['NH3Cl'], '[NH]Cl', '[NH]Cl', 0.5), + + Substitution(['SAc','AcS'], '[S]C(C)=O', "[S]C(C)=O", 0.5), + Substitution(['OAll'], '[O]CC=C', '[O]CC=C', 0.5), + # Substitution(['Tos'], '[Si](C)(C)C', '[Si](C)(C)C', 0.5),#NOTE different case ?? @@acs dataset ,we use the SO2here + Substitution(['Tos','TOs'], '[Si](C)(C)C', '[S](=O)(=O)C(C=C1)=CC=C1C', 0.5),#NOTE different case ?? + Substitution(['OTos','OTOs','soTO'], '[Si](C)(C)C', '[O]S(=O)(=O)C(C=C1)=CC=C1C', 0.5),#NOTE different case ?? 
+ Substitution(['TsN'], '[N]S(C1=CC=C(C=C1)C)(=O)=O', '[N]S(C1=CC=C(C=C1)C)(=O)=O', 0.5), + Substitution(['Ts'], '[S](C1=CC=C(C=C1)C)(=O)=O', '[S](C1=CC=C(C=C1)C)(=O)=O', 0.5), + Substitution(['COCF3'], '[C](=O)C(F)(F)(F)', '[C](=O)C(F)(F)(F)', 0.5), + Substitution(['CF2', 'F2C'], '[C;D4](F)(F)', "[C](F)(F)", 0.5), + Substitution(['PMB'], '[CH2]C1=CC=C(C=C1)OC', '[CH2]C1=CC=C(C=C1)OC', 0.5), + Substitution(['NHCOtBu'], '[NH]C(C(C)(C)C)=O','[NH]C(C(C)(C)C)=O', 0.5), + Substitution(['OCN'], '[N]=C=O', "[N]=C=O", 0.5), + Substitution(['Me3Si'], '[Si](C)(C)(C)', "[Si](C)(C)(C)", 0.5), + Substitution(['PhO','OPh'], '[O]C1=CC=CC=C1', "[O]C1=CC=CC=C1", 0.5), + Substitution(['Allyl'], '[CH2]C=C', '[CH2]C=C', 0.5), + Substitution(['C7H3'], '[C]#CC#CC#CC', '[C]#CC#CC#CC', 0.5), + Substitution(['C5H11'], '[CH2]CCCC', '[CH2]CCCC', 0.5), + Substitution(['R1R2N'], "[N]([*])[*]", "[N]([*])[*]", 0.5), + Substitution(['CO2R'], '[C](=O)O*', '[C](=O)O*', 0.5), + Substitution(['CCl3CH2O2C'], '[C](=O)OCC(Cl)(Cl)Cl', '[C](=O)OCC(Cl)(Cl)Cl', 0.5), + Substitution(['NHOH'], '[NH]O', '[NH]O', 0.5), + Substitution(['CO2'], '[C](=O)[O]', '[C](=O)[O]', 0.5), + Substitution(['O2C'], '[C](=O)[O]', '[O][C](=O)', 0.5),#NOTE the direction matters + + Substitution(['PPh3'], '[P](C1=CC=CC=C1)(C2=CC=CC=C2)C3=CC=CC=C3', '[P](C1=CC=CC=C1)(C2=CC=CC=C2)C3=CC=CC=C3', 0.5), + Substitution(['TfO'], '[C](=O)[O]', '[O]S(=O)(C(F)(F)F)=O', 0.5), + Substitution(['OCH2Ph'], '[O]CC1=CC=CC=C1', '[O]CC1=CC=CC=C1', 0.5), + Substitution(['OCH2CF3'], '[O]CC(F)(F)(F)', '[O]CC(F)(F)(F)', 0.5), + Substitution(['COOCH2Ph'], '[C](=O)OCC1=CC=CC=C1', '[C](=O)OCC1=CC=CC=C1', 0.5), + Substitution(['OCH2OC2H5'], '[C](=O)C(C)(C)C', '[O]COCC', 0.5), + + Substitution(['Trt'], '[C](C1=CC=CC=C1)(C2=CC=CC=C2)C3=CC=CC=C3', '[C](C1=CC=CC=C1)(C2=CC=CC=C2)C3=CC=CC=C3', 0.5), + Substitution(['SF5'], '[S](F)(F)(F)(F)F', '[S](F)(F)(F)(F)F', 0.5), + + # Substitution(['CH2CH'], '[CH2][CH]', '[CH2][CH]', 0.5), + # Substitution(['CH2CH2'], 
'[CH2][CH2]', '[CH2][CH2]', 0.5), + + # #SIMPLE abbv + Substitution(['S*'], '[S]*', '[S]*', 0.5), + Substitution(['N*, NH*'], '[NH]*', '[NH]*', 0.5), + Substitution(['C*','CH2*'], '[C]*', '[CH2]*', 0.5), + Substitution(['P*',"PH*"], '[P]*', '[PH]*', 0.5), + Substitution(['O*'], '[O]*', '[O]*', 0.5), + #() effect + Substitution(['N(CH3)2'], '[N](C)(C)', "[N](C)(C)", 0.5), + Substitution(['(C2H5)2N','Et2N'], '[N](C)(C)', "[N](CC)(CC)", 0.5), + Substitution(['B(OH)2'], '[B](O)O', "[B](O)O", 0.5), + Substitution(['CO2C(CH3)3'], '[C](=O)C(C)(C)C', '[C](=O)C(C)(C)C', 0.5), + Substitution(['P(O)(OEt)2', 'P(OEt)2(O)'], "[P](OCC)(=O)CCO", "[P](OCC)(=O)OCC", 0.5), + Substitution(['(CH2)16Me'], '[CH2]CCCCCCCCCCCCCCCC', "[CH2]CCCCCCCCCCCCCCCC", 0.3), + Substitution(['(CH2)11Me'], '[CH2]CCCCCCCCCCC', "[CH2]CCCCCCCCCCC", 0.3), + Substitution(['N(H)Et','Et(H)N'], '[NH]CC', '[NH]CC', 0.5), + Substitution(['N(H)Me','Me(H)N'], '[NH]C', '[NH]C', 0.5), + + + +] +ABBREVIATIONS = {abbrv: sub for sub in SUBSTITUTIONS for abbrv in sub.abbrvs} + + +def extract_abbreviation_key(item): + if isinstance(item, list): + while isinstance(item, list): + item = item[0] + return item + return item + + +def clean_unpaired_brackets(text): + #keep paired, del unpared + result = [] + stack = [] + bracket_pairs = {')': '(', ']': '['} + opening_brackets = {'(', '['} + + for char in text: + if char in opening_brackets: + stack.append(char) + result.append(char) + elif char in bracket_pairs: + if stack and stack[-1] == bracket_pairs[char]: + stack.pop() + result.append(char) + else: + # 未配对的闭合括号,跳过 + continue + else: + result.append(char) + return ''.join(result) + +# def del_unpairebrackets(opening_brackets): +# # 移除未配对的开括号 +# keep paired, del unpared +# result = [] +# stack = [] +# bracket_pairs = {')': '(', ']': '['} +# opening_brackets = {'(', '['} +# for char in result: +# if char in opening_brackets: +# stack.append(char) +# elif char in bracket_pairs: +# if stack and stack[-1] == bracket_pairs[char]: 
#             stack.pop()
#             final_result.append(char)
#         else:
#             continue
#     else:
#         final_result.append(char)
#     # if any opening brackets remain unclosed, remove them as well
#     # return ''.join(c for c in final_result if not stack or c not in opening_brackets)


def replace_c1(text):
    """Fix the common OCR confusion of 'Cl' being read as 'C1'.

    The negative lookahead keeps genuine ring-closure digits ('C12') intact.
    """
    return re.sub(r'C1(?!\d)', 'Cl', text)


def transform_formula(formula):
    """Rewrite 'Hg' that is really a mis-OCRed hydrogen count.

    'C<n>...Hg...' is interpreted as CnH(2n+1), e.g. 'C3HgO2' -> 'C3H7O2'.
    Returns the input unchanged when the pattern does not match.
    """
    match = re.match(r'C(\d+)(.*?)Hg(.*)', formula)
    if not match:
        return formula

    n = int(match.group(1))
    prefix = match.group(2)   # anything between the carbon count and 'Hg'
    suffix = match.group(3)   # anything after 'Hg' (e.g. 'O2')
    g_new = n * 2 + 1
    return f"C{n}{prefix}H{g_new}{suffix}"


def Cg_transform_formula(formula):
    """Rewrite a leading 'CgH<n>' as 'C<(n-1)//2>H<n>', e.g. 'CgH11O2' -> 'C5H11O2'.

    Returns the input unchanged when the pattern does not match.
    """
    # BUG FIX: the original pattern ended in a non-greedy '(.*?)', which at
    # the end of a pattern always matches the empty string, silently dropping
    # any trailing atoms (e.g. the 'O2' of 'CgH11O2' -> 'C5H11').
    match = re.match(r'CgH(\d+)(.*)', formula)
    if not match:
        return formula

    n = int(match.group(1))
    suffix = match.group(2)   # anything after the hydrogen count (e.g. 'O2')
    g_new = (n - 1) // 2
    return f"C{g_new}H{n}{suffix}"


def normalize_ocr_text(text, replacement_map):
    """Normalize OCR text using the predefined mapping rules"""
    if 'C1' in text:
        text = replace_c1(text)
    if 'Hg' in text:
        text = transform_formula(text)
    if 'Cg' in text:
        text = Cg_transform_formula(text)
    if 'Q' in text:
        # a 'Q' followed by an upper-case letter is treated as a mis-read 'O'
        pattern = r'Q([A-Z])(\w+)'
        replacement = r'O\1\2'
        text = re.sub(pattern, replacement, text)
    if text in ELEMENTS:
        return text
    # remove spaces
    if ' ' in text:
        text = text.replace(" ", "")
    if any(c in text for c in '0oO'):
        # NOTE(review): the remainder of this function is truncated in this
        # rendering of the source (a re.sub pattern is cut off mid-string);
        # the visible intent is o/O <-> 0 disambiguation.  TODO restore from
        # the full file.
        pass
    return text
# NOTE(review): the `def` line of this parser -- along with the FORMULA_REGEX
# and _parse_tokens definitions that precede it -- was lost in this rendering
# of the source.  The surviving docstring example and body are reproduced
# below under the presumed original name; confirm against the full file.
def _parse_formula(formula):
    """
    Parse a condensed formula into (element, count) pairs, e.g.
    'C2H4O' -> [('C', 2), ('H', 4), ('O', 1)]
    """
    tokens = FORMULA_REGEX.findall(formula)
    # if ''.join(tokens) != formula:
    #     tokens = FORMULA_REGEX_BACKUP.findall(formula)
    return _parse_tokens(tokens)


def _expand_carbon(elements: list):
    """
    Given list of pairs `(elt, num)`, output single list of all atoms in order,
    expanding carbon sequences (CaXb where a > 1 and X is halogen) if necessary
    Example: [('C', 2), ('H', 4), ('O', 1)] -> ['C', 'H', 'H', 'C', 'H', 'H', 'O']
    Returns False (not []) when the expansion is empty.
    """
    expanded = []
    idx = 0
    while idx < len(elements):
        elt, num = elements[idx]
        # skip unreasonable number of atoms
        if num > 100000:
            idx += 1
            continue
        if elt == 'C' and num > 1 and idx + 1 < len(elements):
            # carbon run: distribute the next element's count evenly over the
            # carbons; any remainder is appended after the run
            next_elt, next_num = elements[idx + 1]
            if next_num > 100000:
                idx += 1
                continue
            quotient, remainder = next_num // num, next_num % num
            for _ in range(num):
                expanded.append('C')
                expanded.extend(next_elt for _ in range(quotient))
            expanded.extend(next_elt for _ in range(remainder))
            idx += 2
        elif isinstance(elt, list):
            # nested formula: recurse and repeat the expanded sub-list
            nested = _expand_carbon(elt)
            expanded.extend(nested for _ in range(num))
            idx += 1
        else:
            # simplest case: append `elt` `num` times
            expanded.extend(elt for _ in range(num))
            idx += 1
    if expanded == []:
        return False
    return expanded


def replace_bracket(match):
    """re.sub callback deciding whether a [...] group keeps its brackets.

    Keeps brackets for charges/counts and bare [H]; strips brackets (and the
    H) from multi-character hydrogen-bearing symbols; otherwise just unwraps.
    """
    content = match.group(1)
    if re.search(r'\d|\+|-', content):
        # contains a digit or a charge sign: keep the whole [content]
        return f'[{content}]'
    if content == 'H':
        return '[H]'
    if len(content) >= 2 and 'H' in content:
        # multi-char symbol containing H: drop brackets and the H
        return ''.join(ch for ch in content if ch != 'H')
    return content
    # usage: re.sub(r'\[([^\[\]]+)\]', replace_bracket, smi)


def formula_regex(abbrev):
    """MolScribe-style tokenize-and-expand for condensed abbreviations."""
    tokens = FORMULA_REGEX.findall(abbrev)
    abbrev_exp = _expand_carbon(_parse_tokens(tokens))
    if abbrev_exp == []:
        return False
    return abbrev_exp


def _expand_abbreviationMS(abbrev):
    """
    Expand abbreviation into its SMILES; also converts [Rn] to [n*]
    Used in `_condensed_formula_list_to_smiles` when encountering abbrev. in condensed formula
    """
    if abbrev in ABBREVIATIONS:
        return ABBREVIATIONS[abbrev].smiles
    if abbrev in RGROUP_SYMBOLS or (abbrev[0] in RGROUP_SYMBOLS and abbrev[1:].isdigit()):
        if abbrev[1:].isdigit():
            return f'[{abbrev[1:]}*]'
        return '*'
    return f'[{abbrev}]'


def _get_bond_symb(bond_num):
    """Map a bond order (0-3) to its SMILES symbol ('.', '', '=', '#')."""
    symbols = {0: '.', 1: '', 2: '=', 3: '#'}
    if bond_num in symbols:
        return symbols[bond_num]
    print(f"check this val {bond_num} !!!")
    return ''


def _condensed_formula_list_to_smiles(formula_list, start_bond, end_bond=None, direction=None):
    """
    Converts condensed formula (in the form of a list of symbols) to smiles
    Input:
        `formula_list`: e.g. ['C', 'H', 'H', 'N', ['C', 'H', 'H', 'H'], ['C', 'H', 'H', 'H']] for CH2N(CH3)2
        `start_bond`: # bonds attached to beginning of formula
        `end_bond`: # bonds attached to end of formula (deduce automatically if None)
        `direction` (1, -1, or None): direction in which to process the list
            (1: left to right; -1: right to left; None: deduce automatically)
    Returns:
        `smiles`: smiles corresponding to input condensed formula
        `bonds_left`: bonds remaining at the end of the formula
        `num_trials`: number of trials
        `success` (bool): whether conversion was successful
    """
    if direction is None:
        # try left-to-right first, then right-to-left
        num_trials = 1
        for dir_choice in [1, -1]:
            smiles, bonds_left, trials, success = _condensed_formula_list_to_smiles(
                formula_list, start_bond, end_bond, dir_choice)
            num_trials += trials
            if success:
                return smiles, bonds_left, num_trials, success
        return None, None, num_trials, False
    assert direction == 1 or direction == -1

    def dfs(smiles, bonds_left, cur_idx, add_idx):
        """
        `smiles`: SMILES string so far
        `cur_idx`: index (in `formula_list`) of current atom
        `bonds_left`: bonds remaining on current atom
        `add_idx`: index (in `formula_list`) of atom to be attached next
        """
        num_trials = 1
        # end of formula: return result
        if (direction == 1 and add_idx == len(formula_list)) or (direction == -1 and add_idx == -1):
            if end_bond is not None and end_bond != bonds_left:
                return smiles, bonds_left, num_trials, False
            return smiles, bonds_left, num_trials, True

        # no more bonds but there are atoms remaining: conversion failed
        if bonds_left <= 0:
            return smiles, bonds_left, num_trials, False
        to_add = formula_list[add_idx]  # atom to be added to current atom
        if not isinstance(to_add, str):
            return smiles, bonds_left, num_trials, False
        if isinstance(to_add, list):
            # NOTE(review): unreachable -- the `not isinstance(to_add, str)`
            # early return above already rejects lists, so nested condensed
            # formulas are never expanded here.  Possibly a deliberate
            # disable; confirm against callers before changing.
            if bonds_left > 1:
                add_str, val, trials, success = _condensed_formula_list_to_smiles(to_add, 1, None, direction)
                if val > 0:
                    add_str = _get_bond_symb(val + 1) + add_str
                num_trials += trials
                if not success:
                    return smiles, bonds_left, num_trials, False
                result = dfs(smiles + f'({add_str})', bonds_left - 1, cur_idx, add_idx + direction)
            else:
                add_str, bonds_left, trials, success = _condensed_formula_list_to_smiles(to_add, 1, None, direction)
                num_trials += trials
                if not success:
                    return smiles, bonds_left, num_trials, False
                result = dfs(smiles + add_str, bonds_left, add_idx, add_idx + direction)
            smiles, bonds_left, trials, success = result
            num_trials += trials
            return smiles, bonds_left, num_trials, success
        # atom added is a single symbol: try all its plausible valences
        for val in VALENCES.get(to_add, [1]):
            add_str = _expand_abbreviationMS(to_add)
            if bonds_left > val:
                # atom does not use up remaining bonds: attach in parentheses
                if cur_idx >= 0:
                    add_str = _get_bond_symb(val) + add_str
                result = dfs(smiles + f'({add_str})', bonds_left - val, cur_idx, add_idx + direction)
            else:
                # atom uses up remaining bonds: it becomes the new current atom
                if cur_idx >= 0:
                    add_str = _get_bond_symb(bonds_left) + add_str
                result = dfs(smiles + add_str, val - bonds_left, add_idx, add_idx + direction)
            trials, success = result[2:]
            num_trials += trials
            if success:
                return result[0], result[1], num_trials, success
            if num_trials > 10000:
                break
        return smiles, bonds_left, num_trials, False

    cur_idx = -1 if direction == 1 else len(formula_list)
    add_idx = 0 if direction == 1 else len(formula_list) - 1
    return dfs('', start_bond, cur_idx, add_idx)


def swap_paren_bracket(text):
    """Turn a leading '(...)[...]' into '[...](...)'.

    Note: only the matched prefix is returned; anything after the closing
    bracket is dropped (original behavior preserved).
    """
    if not text.startswith('('):
        return text
    match = re.match(r'^\((.*?)\)\[(.*?)\]', text)
    if match:
        return f'[{match.group(2)}]({match.group(1)})'
    return text
def convert_ch2_string(s):
    """Expand a literal '(CH2)n' chain into an attachment-ready SMILES.

    '(CH2)3' -> '[CH2]CC'; a symbolic repeat such as '(CH2)m' is logged and
    returned unchanged, as is any non-matching string.
    """
    match = re.fullmatch(r'\(CH2\)(\d+|[a-zA-Z]+)', s)
    if match is None:
        return s

    suffix = match.group(1)
    if not suffix.isdigit():
        # symbolic repeat count (e.g. '(CH2)m'): cannot expand
        var = suffix
        print(var, s)
        return s
    n = int(suffix)
    return '[CH2]' if n == 1 else '[CH2]' + 'C' * (n - 1)


def process_string_joinused(s):
    """Strip 'H<count>' from a leading multi-character '[group]' prefix."""
    match = re.match(r'^\[([^\]]*)\](.*)$', s)
    if not match:
        return s

    content, rest = match.groups()
    # only rewrite multi-character groups that actually carry an H
    if len(content) > 1 and 'H' in content:
        cleaned = re.sub(r'H\d*', '', content)
        return f'[{cleaned}]{rest}'
    return s


def all_elements_in_dict(lst, dictionary):
    """Recursively check that every element of a (possibly nested) list is a
    key of `dictionary`.

    :param lst: list to check (may contain nested lists)
    :param dictionary: dict whose keys are consulted
    :return: True when every leaf element is a key, else False
    """
    for element in lst:
        if isinstance(element, list):
            if not all_elements_in_dict(element, dictionary):
                return False
        elif element not in dictionary:
            return False
    return True


def expand_cf2_to_smiles(input_string):
    """Expand '(CF2)nX' into an explicit SMILES chain ending in '[X]'.

    Non-matching input is returned unchanged.
    """
    match = re.match(r'\(CF2\)(\d+)([A-Za-z0-9]+)', input_string)
    if not match:
        return input_string
    n = int(match.group(1))
    tail_group = f"[{match.group(2)}]"
    # each CF2 unit is C(F)(F); the first carries the open-valence marker
    cf2_unit = 'C(F)(F)'
    return '[C](F)(F)' + cf2_unit * (n - 1) + tail_group if n > 0 else tail_group


def find_repeating_unit_and_smiles(s):
    """Detect a string made of one repeated unit and map it to SMILES.

    Returns (smiles, repeat_count, unit); (None, 0, None) when `s` is not a
    pure repetition.  Only CH2 / CF2 / SO2 units are currently mapped; other
    units yield an empty SMILES and a reminder message.
    """
    match = re.fullmatch(r'(.+?)(?:\1)+', s)
    if not match:
        return None, 0, None

    unit = match.group(1)
    repeat_count = len(s) // len(unit)
    unit_map = {
        "CH2": ("C", "[CH2]"),
        "CF2": ("C(F)(F)", "[C](F)(F)"),
        "SO2": ("S(=O)(=O)", "[S](=O)(=O)"),
    }
    if unit in unit_map:
        smiles_unit, smi_init = unit_map[unit]
    else:
        smiles_unit, smi_init = '', ''
        print(f'please add the repateat patter here !!! for: {s}')
    smiles = smi_init + smiles_unit * (repeat_count - 1)
    return smiles, repeat_count, unit
len(s) // len(unit) + # 根据重复单元生成SMILES(适当处理CH2 -> C, CF2 -> CF2) + if unit == "CH2": + smiles_unit = "C" # CH2 -> C + smi_init="[CH2]" + elif unit == "CF2": + smiles_unit = "C(F)(F)" # CF2保持原样 + smi_init="[C](F)(F)" + elif unit == "SO2": + smiles_unit = "S(=O)(=O)" # SO2保持原样 + smi_init="[S](=O)(=O)" + else: + smiles_unit,smi_init='','' + print(f'please add the repateat patter here !!! for: {s}') + # smiles_unit = unit # 其他单元直接使用 + # 生成最终的SMILES + smiles = smi_init + smiles_unit * (repeat_count - 1 ) + + return smiles, repeat_count, unit + else: + return None, 0, None # 如果没有匹配到,则返回None + +def get_smiles_from_symbol(symbol, mol, bonds): + """ + Convert symbol (abbrev. or condensed formula) to smiles + If condensed formula, determine parsing direction and num. bonds on each side using coordinates + """ + if symbol in ABBREVIATIONS: + return ABBREVIATIONS[symbol].smiles + if symbol in RGROUP_SYMBOLS or (symbol[0] in RGROUP_SYMBOLS and symbol[1:].isdigit()): + if symbol[1:].isdigit(): + return f'[{symbol[1:]}*]' + return '*' + + if len(symbol) > 20: + return None + smiles=convert_ch2_string(symbol) + if smiles !=symbol: + return smiles + if '(CF2)' in symbol: + smiles=expand_cf2_to_smiles(symbol) + return smiles + smiles, repeat_count, unit = find_repeating_unit_and_smiles(symbol) + if repeat_count>0: + return smiles + + #TODO@@@ add as speical case or add function, + # this is hard encode NOTE fix this next version + if symbol in ['CH2CH','CHCH2','CH2CH2', 'CH2CH2CH','CH2CH2CH','H2CH2CHC','CHCH2CH2','(CH2)10', 'H2C','CH2',#'CH2CH2NSO2CH3', + 'OCH2CHOHCH2NH','OCH2CHOHCH2','CF2O','OF2C','EtO2CHN','EtO2C', + 'CH2CH2C(O)0CH2CH3','CH2CH2C(O)OCH2CH3','l23I', + 'OCH2CH2OH','OCH2CHCH2CCH3','CH2O', + '(H4NO)2','SO2NHCH2CH','OCH2CH','OCF2H','COCOOCH2CH3','CH2CH2CH2CH','HCH2CH2CH2C','CF3CF2CF2CF2SO3', + # 'SO2(CH2)3SO2NHCH2CHCH2OH', + '(CF2)8H','PH3C','CO','OC', + 'CF2CF2H','NHSO2CH3','CH2CH2C','CH;CH2C(O)0CHCH3','CH2CH2C(O)OCHCH3', + 'NH2','H2N', 'CHO', 'OHC', 
'N(SO2CH3)2','CH2CH2O','CH2CH2C(O)OCH2CH3', + #ACS + 'Ar2P(O)','PhO2S','NHP(O)Ph2','P*Ph3','P+Ph3','NH2.HCl', + #CLEF + 'S[O]a', + #USPTO + '(C3H6O)7CH3','HC','(HC','(CH2CH2CH2CH-)','3(CHCHCHCH272', + #UOB + 'NHzBrH','NH2BrH', + #staker + '(co)','(CO)', + #JPO + 'CH3CH','CH3CCH3','CH3CO','CH3OCH2','CO2C','CH2CO2CH3',"COCl", + ]:#NOTE this are not passed by _condensed_formula_list_to_smiles function + #TODO fix me in next version, may be need LLM to track this + # Substitution(['CHO', 'OHC'], '[CH1](=O)', "[CH1](=O)", 0.5), + # Substitution(['NH2','H2N'], '[NH2;D1]', "[NH2]", 0.1), + #TODO symbol2SMILES() need dig ChemDraw + if symbol in ['CH2CH','CHCH2']:smiles='[CH2][CH]' + elif symbol in ['PH3C']:smiles='[CH2]P' + elif symbol in ['l23I']:smiles='[I]' + elif symbol in ['HC','(HC']:smiles='[CH]' + elif symbol in ['NHzBrH','NH2BrH']:smiles='[NH2].Br' + elif symbol in ['(C3H6O)7CH3']:smiles="[O]CCC"+"OCCC"*6+'C'#TODO maybe as function + elif symbol in ['NH2.HCl']:smiles="[NH2].Cl" + elif symbol in ['CH2CH2CH2CH','(CH2CH2CH2CH-)']:smiles='[CH2]CC[CH]' + elif symbol in ['3(CHCHCHCH272', 'CHCHCHCH2']:smiles='[CH]CC[CH2]' + # elif symbol in ['D']:smiles='[2H]' + elif symbol in [ 'CH3CH']:smiles='[CH]C' + elif symbol in [ 'CH2CO2CH3']:smiles='[CH2]C(=O)OC' + elif symbol in [ 'CO2C']:smiles='[C](=O)O[C]' + elif symbol in [ 'CH3CCH3']:smiles='[C](C)(C)' + elif symbol in [ 'CH3CO']:smiles='[C](=O)C' + elif symbol in [ 'CH3OCH2']:smiles='[CH2]OC' + + elif symbol in [ '(co)','(CO)']:smiles='[C](=O)' + elif symbol in ['Ar2P(O)']:smiles='[P](*)(*)(=O)' + elif symbol in ['PhO2S']:smiles='[S](=O)(=O)c1ccccc1' + elif symbol in ['CO','OC']:smiles='[C](=O)' + elif symbol in ['CH2O']:smiles='[CH2][O]' + elif symbol in ['P*Ph3','P+Ph3',]:smiles='[P+](c1ccccc1)(c1ccccc1)(c1ccccc1)' + elif symbol in ['NHP(O)Ph2']:smiles='[NH]P(=O)(c1ccccc1)c1ccccc1' + elif symbol in ['CH;CH2C(O)0CHCH3','CH2CH2C(O)OCHCH3']:smiles='[CH2]CC(=O)OCC' + elif symbol in 
['CH2CH2CH','H2CH2CHC','CHCH2CH2']:smiles='[CH2][CH2][CH]' + elif symbol in ['CH2CH2CH2CH']:smiles='[CH2]CC[CH]' + elif symbol in ['HCH2CH2CH2C']:smiles='[CH]CC[CH2]' + elif symbol in ['H2C','CH2']:smiles='[CH2]' + elif symbol in ['H2CH2C','CH2CH2']:smiles='[CH2][CH2]' + elif symbol in ['CHO', 'OHC']:smiles="[CH](=O)" + elif symbol in ['NH2','H2N']:smiles="[NH2]" + elif symbol in ['(CF2)8H',]:smiles="[C](F)(F)C(F)(F)C(F)(F)C(F)(F)C(F)(F)C(F)(F)C(F)(F)C(F)(F)" + elif symbol in ['CH2CH2C(O)OCH2CH3','CH2CH2C(O)0CH2CH3']:smiles='[CH2]CC(=O)OCC' + elif symbol in ['CF3CF2CF2CF2SO3']:smiles='[S](=O)(=O)([O-])C(F)(F)C(F)(F)C(F)(F)C(F)(F)(F)' + elif symbol in ['S[O]a']:smiles='[S](=O)' + elif symbol in ['COCl']:smiles='[C](=O)Cl' + + + + elif symbol in ['OCF2H']:smiles="[O]C(F)(F)" + elif symbol in ['CF2O']:smiles="[C](F)(F)[O]" + elif symbol in ['OF2C']:smiles="[O][C](F)(F)" + elif symbol in ['CF2CF2H']:smiles="[C](F)(F)C(F)(F)" + # elif symbol in ['CH2CH2NSO2CH3']:smiles='[CH2]CNS(=O)(C)=O' + elif symbol in ['CH2CH2O']:smiles='[CH2]CO' + elif symbol in ['OCH2CH2OH']:smiles='[O]CCO'#NOTE Chemdraw may give some idea + elif symbol in ['EtO2CHN']:smiles='[N]C(=O)OCC' + elif symbol in ['OCH2CHOHCH2NH']:smiles='[O]CC(O)CN' + elif symbol in ['OCH2CHCH2CCH3']:smiles='[O]C[CH]C[C]C' + elif symbol in ['(H4NO)2']:smiles='[O]NON' + elif symbol in ['SO2NHCH2CH']:smiles='[S](=O)(=O)NC[CH]' + elif symbol in ['N(SO2CH3)2']:smiles='[N](S(=O)(=O)C)(S(=O)(=O)C)' + elif symbol in ['CH2CH2C(O)OCH2CH3']:smiles='[CH2]CC(=O)OCC' + elif symbol in ['OCH2CH']:smiles='[O]C[CH]' + elif symbol in ['EtO2C']:smiles='C(=O)OCC' + elif symbol in ['CH2CH2C']:smiles='[CH2]C[C]' + elif symbol in ['NHSO2CH3']:smiles='[NH]S(=O)(=O)C' + elif symbol in ['COCOOCH2CH3']:smiles='C(=O)C(=O)OCC' + # elif symbol in ['SO2(CH2)3SO2NHCH2CHCH2OH']:smiles='[S](=O)(=O)CCCS(=O)(=O)NC[C]CO' + # elif symbol in ['H4NO3S']:smiles='[S]NCC' + # elif symbol in ['(CH2)10','[CH]CCCCCCCCC']:smiles='[CH]CCCCCCCCC'#as in 
convert_ch2_string() + else:smiles=None + return smiles + + total_bonds = int(sum([bond.GetBondTypeAsDouble() for bond in bonds]))#TODO aromtaic bond effect ?? + formula_list = _expand_carbon(_parse_formula(symbol)) + # all_in_dict = all(fl in ABBREVIATIONS for fl in formula_list) + all_in_dict=all_elements_in_dict(formula_list,ABBREVIATIONS) + #total_bonds, bonds_left 机制是有问题的, 所以需要以上的修补,机制不完善 + smiles, bonds_left, num_trails, success = _condensed_formula_list_to_smiles(formula_list, total_bonds, None) + # if debug: + print(f'{[formula_list, total_bonds]} use _condensed_formula_list_to_smiles {success} <<-------\n {smiles}') + if success: + smiles=swap_paren_bracket(smiles) + return smiles + elif all_in_dict :#NOTE resolve abbv combine + # smiles=ABBREVIATIONS[formula_list[0]].smiles + key = extract_abbreviation_key(formula_list[0]) + if key in ABBREVIATIONS: + smiles = ABBREVIATIONS[key].smiles + else: + # raise ValueError(f"Abbreviation {key} not found in ABBREVIATIONS.") + print(f"Abbreviation {key} not found in ABBREVIATIONS.") + smiles='' + for fl_i in range(1,len(formula_list)): + cur_smi=process_string_joinused(ABBREVIATIONS[formula_list[fl_i]].smiles) + smiles += cur_smi + return smiles + + return None + +def abbrev2smile(abbrev,abbrev_exp,mol,idx): + + atom_gost = mol.GetAtomWithIdx(idx) + bonds_gost = atom_gost.GetBonds() + sub_smi = get_smiles_from_symbol(abbrev, mol, bonds_gost) + + if sub_smi: + # print(f"succes expanding {abbrev},{abbrev_exp}\n{sub_smi}\t{idx}") + return sub_smi + else: + print(f"failed expanding {abbrev},{abbrev_exp}\n{sub_smi}\t{idx}") + return '[*]' + + # if abbrev_exp[0] in ABBREVIATIONS: + # init_smi=ABBREVIATIONS[abbrev_exp[0]].smiles + # else: + # if len(abbrev_exp[0])==1: + # init_smi=f'[{abbrev_exp[0]}]' + # else: + # print(f"{abbrev_exp[0]} @@@formula_regex") + # init_smi=f'[{abbrev_exp[0]}]' + # # init_smi=ABBREVIATIONS[abbrev_exp[0]].smiles if abbrev_exp[0] in ABBREVIATIONS else + # if len(abbrev_exp)==1: + # 
# (Removed: dead commented-out alternative implementation of abbrev2smile.)


def replace_cg_notation(astr):
    """Rewrite 'CgH<n>' generic-alkyl notation as an explicit 'C<c>H<n>' formula.

    The carbon count is derived from the hydrogen count of a saturated alkyl
    chain (CcH2c+1): c = (n - 1) // 2, e.g. 'CgH17' -> 'C8H17'.
    """
    def replacer(match):
        h_count = int(match.group(1))
        c_count = (h_count - 1) // 2
        return f'C{c_count}H{h_count}'

    return re.sub(r'CgH(\d+)', replacer, astr)


def _expand_abbreviation(abbrev, mol, idx):
    """
    Expand abbreviation into its SMILES; also converts [Rn] to [n*].
    Falls back to '[*]' when nothing matches.
    """
    if abbrev in ABBREVIATIONS:
        return ABBREVIATIONS[abbrev].smiles
    # FIX: each matching expander used to be called twice (once to test, once
    # to return); call each at most once, in the original priority order.
    for expander in (N_C_H_expand, C_F_expand, C_H_expand2, C_H_expand, C_H_affixExpand):
        expanded = expander(abbrev)
        if expanded:
            return expanded
    if abbrev in RGROUP_SYMBOLS or (abbrev[0] in RGROUP_SYMBOLS and abbrev[1:].isdigit()):
        if abbrev[1:].isdigit():
            return f'[{abbrev[1:]}*]'
        # Bare R-group symbols fall through to the generic handling below.
    elif abbrev in ELEMENTS:
        return f'[{abbrev}]'
    else:
        abbrev_exp = formula_regex(abbrev)
        if abbrev_exp:
            # Last resort: the MolScribe-style condensed-formula expansion.
            return abbrev2smile(abbrev, abbrev_exp, mol, idx)

    match = re.match(r'^(\d+)?(.*)', abbrev)
    if match:
        numeric_part, remaining_part = match.groups()
        if remaining_part in ELEMENTS:
            return f'[{abbrev}]'
        elif numeric_part:
            return f'[{numeric_part}*]'
    else:
        # Unreachable: the pattern above matches any string; kept for parity.
        print(f"fixme !!!@@@@: {abbrev}")
    return '[*]'


def count_current_bonds(mol, atom_idx):
    """Count current bonds (including bond order) for an atom."""
    atom = mol.GetAtomWithIdx(atom_idx)
    return sum(bond.GetBondTypeAsDouble() for bond in atom.GetBonds())
+ return sum(bond.GetBondTypeAsDouble() for bond in atom.GetBonds()) + +debug_not=True + +def expandABB(mol, ABBREVIATIONS, placeholder_atoms):#, RGROUP_SYMBOLS, ELEMENTS): + mols = [mol] + # 逆序遍历 placeholder_atoms,确保删除后不会影响后续索引 + for idx in sorted(placeholder_atoms.keys(), reverse=True) : + group = placeholder_atoms[idx] + group_smiles = _expand_abbreviation(group,mol,idx) + submol = Chem.MolFromSmiles(group_smiles) # 获取官能团的子分子 + try: + submol_rw = Chem.RWMol(submol) # 转换为可编辑的 RWMol + except Exception as e: + print(f"abbver: {group}") + print(f'try to convert {group_smiles} to sub_mol') + print(e) + if debug_not: + print(f"Failed to convert {group_smiles} to sub_mol, using placeholder [*] instead.") + submol = Chem.MolFromSmiles('[*]') + submol_rw = Chem.RWMol(submol) + else: + raise e#NOTE use it when debugging with adding abber and fixing rules in det_engine.py + + # 1. 识别 submol 的 anchor atoms(连接点) + anchor_atoms = [0]#always use the fisrt atom as anchor atom + for atom in submol_rw.GetAtoms(): + # 具有自由基的原子或标记为连接点的原子(例如 [*]) + if atom.GetNumRadicalElectrons() > 0 and atom.GetIdx() not in anchor_atoms:# or atom.GetSymbol() == '*': + anchor_atoms.append(atom.GetIdx()) + # 2. 复制主分子 + new_mol = Chem.RWMol(mol) + placeholder_idx = idx + # 3. 记录 placeholder (*) 原子的邻居及其键类型 + bonds_info = [] + for bond in new_mol.GetBonds(): + if bond.GetBeginAtomIdx() == placeholder_idx: + bonds_info.append({ + "neighbor": bond.GetEndAtomIdx(), + "bond_type": bond.GetBondType() + }) + elif bond.GetEndAtomIdx() == placeholder_idx: + bonds_info.append({ + "neighbor": bond.GetBeginAtomIdx(), + "bond_type": bond.GetBondType() + }) + + # 4. 断开 placeholder 的所有键 + for bond_info in bonds_info: + new_mol.RemoveBond(placeholder_idx, bond_info["neighbor"]) + + # 5. 删除 placeholder 原子 + new_mol.RemoveAtom(placeholder_idx) + + # 6. 
重新计算邻居索引(删除后索引变化) + adjusted_bonds_info = [] + for bond_info in bonds_info: + neighbor = bond_info["neighbor"] + if neighbor < placeholder_idx: + adjusted_neighbor = neighbor + else: + adjusted_neighbor = neighbor - 1 # 索引因删除原子而减 1 + adjusted_bonds_info.append({ + "neighbor": adjusted_neighbor, + "bond_type": bond_info["bond_type"] + }) + + # 7. 合并 submol + new_mol = Chem.RWMol(Chem.CombineMols(new_mol, submol_rw)) + + # 8. 计算 submol 的 anchor atoms 在合并后的索引 + submol_atom_offset = new_mol.GetNumAtoms() - submol_rw.GetNumAtoms() + new_anchor_indices = [submol_atom_offset + anchor_idx for anchor_idx in anchor_atoms] + + # 9. 重新连接官能团,使用原始键类型 + if len(new_anchor_indices) == 1: + # 单连接点情况:所有邻居连接到唯一的 anchor atom + anchor_idx = new_anchor_indices[0] + for bond_info in adjusted_bonds_info: + neighbor = bond_info["neighbor"] + bond_type = bond_info["bond_type"] + new_mol.AddBond(neighbor, anchor_idx, bond_type) + # 重置自由基电子数 + a1 = new_mol.GetAtomWithIdx(neighbor) + a2 = new_mol.GetAtomWithIdx(anchor_idx) + a1.SetNumRadicalElectrons(0) + a2.SetNumRadicalElectrons(0) + else: + # # 多连接点情况:先尝试按顺序连接, 如果* 连* 会存在多种合理价态的不同分子情况 + # 多连接点情况:根据邻居数量和 anchor atoms 分配连接 + if len(adjusted_bonds_info) > len(new_anchor_indices): + print(adjusted_bonds_info,' <---adjusted_bonds_info') + print(new_anchor_indices,'<---new_anchor_indices') + # raise ValueError(f"Too many neighbors ({len(adjusted_bonds_info)}) for submol with {len(new_anchor_indices)} anchor atoms.") + # for i, bond_info in enumerate(adjusted_bonds_info): + # # 按顺序将邻居连接到 anchor atoms + # anchor_idx = new_anchor_indices[i % len(new_anchor_indices)] + # neighbor = bond_info["neighbor"] + # bond_type = bond_info["bond_type"] + # new_mol.AddBond(neighbor, anchor_idx, bond_type) + # # 重置自由基电子数 + # a1 = new_mol.GetAtomWithIdx(neighbor) + # a2 = new_mol.GetAtomWithIdx(anchor_idx) + # a1.SetNumRadicalElectrons(0) + # a2.SetNumRadicalElectrons(0) + # 跟踪每个 anchor 的当前成键数 + anchor_bond_counts = {idx: 
new_mol.GetAtomWithIdx(idx).GetTotalValence() for idx in new_anchor_indices} + print(anchor_bond_counts,'<---anchor_bond_counts') + # max_valence = {6: 4, 7: 3, 8: 2} # 示例:C=4, N=3, O=2,需根据实际原子类型扩展 + adjusted_bonds_info = sorted(adjusted_bonds_info, key=lambda x: x['neighbor']) + if mol.GetNumConformers() > 0:#as some mol may not have the conf dispite pass the 2d assign process + pos_0 = mol.GetConformer().GetAtomPosition(adjusted_bonds_info[0]['neighbor']) + pos_1 = mol.GetConformer().GetAtomPosition(adjusted_bonds_info[-1]['neighbor']) + print(pos_0.x,pos_1.x,"xxx",adjusted_bonds_info) + # if group =='SO2NH': + # if pos_0.x 0.1: + if orig_valid and scaled_valid: + if orig_score >= scaled_score and orig_text: + return orig_text, orig_score, cropped_img_orig + elif scaled_text: + return scaled_text, scaled_score, cropped_img_scaled + elif orig_valid and not scaled_valid: + return orig_text, orig_score, cropped_img_orig + elif scaled_valid and not orig_valid: + return scaled_text, scaled_score, cropped_img_scaled + else: + print(f"Both texts are invalid: orig_text='{orig_text}', scaled_text='{scaled_text}'") + if orig_score >= scaled_score: + return orig_text, orig_score, cropped_img_orig + else: + return scaled_text, scaled_score, cropped_img_scaled + # 如果分差小于0.1,选择更合理的化学表达式 + else: + # 如果只有一个有效,选择有效的 + if orig_valid and not scaled_valid: + return orig_text, orig_score, cropped_img_orig + elif scaled_valid and not orig_valid: + return scaled_text, scaled_score, cropped_img_scaled + # 如果都有效,比较长度 + elif orig_valid and scaled_valid: + if orig_text in ABBREVIATIONS and scaled_text not in ABBREVIATIONS: + if N_C_H_expand(scaled_text) or C_F_expand(scaled_text) or C_H_expand2(scaled_text) or C_H_expand(scaled_text): + if len(scaled_text)> len(orig_text): + return scaled_text, scaled_score, cropped_img_scaled + return orig_text, orig_score, cropped_img_orig + elif orig_text not in ABBREVIATIONS and scaled_text in ABBREVIATIONS: + if N_C_H_expand(orig_text) or 
C_F_expand(orig_text) or C_H_expand2(orig_text) or C_H_expand(orig_text): + if len(orig_text)> len(scaled_text): + return orig_text, orig_score, cropped_img_orig + return scaled_text, scaled_score, cropped_img_scaled + elif orig_text not in ABBREVIATIONS and scaled_text not in ABBREVIATIONS: + if len(orig_text) > len(scaled_text): + return orig_text, orig_score, cropped_img_orig + else: + if len(orig_text) == len(scaled_text): + if orig_score >= scaled_score : + return orig_text, orig_score, cropped_img_orig + else: + return scaled_text, scaled_score, cropped_img_scaled + return scaled_text, scaled_score, cropped_img_scaled + + elif orig_text in ABBREVIATIONS and scaled_text in ABBREVIATIONS: + if len(orig_text) >= len(scaled_text): + return orig_text, orig_score, cropped_img_orig + else: + return scaled_text, scaled_score, cropped_img_scaled + # 如果都不有效,优先选择 orig(若存在) + elif orig_text: + return orig_text, orig_score, cropped_img_orig + elif scaled_text: + return scaled_text, scaled_score, cropped_img_scaled + + # 默认返回 scaled(若存在) + return scaled_text, scaled_score, cropped_img_scaled if scaled_text else (None, None, None) + +# def expandABB(mol,ABBREVIATIONS, placeholder_atoms):# RGROUP_SYMBOLS, ELEMENTS): +# mols = [mol] + +# # Process placeholders in reverse order to avoid index issues +# for idx in sorted(placeholder_atoms.keys(), reverse=True): +# group = placeholder_atoms[idx] +# group_smiles = _expand_abbreviation(group)# ABBREVIATIONS, RGROUP_SYMBOLS, ELEMENTS) + +# try: +# submol = Chem.MolFromSmiles(group_smiles) +# if not submol: +# raise ValueError(f"Invalid SMILES for group {group}: {group_smiles}") +# submol_rw = RWMol(submol) +# except Exception as e: +# print(f"Error processing SMILES for group {group}: {e}") +# continue + +# # Create a new editable molecule +# new_mol = RWMol(mol) +# placeholder_idx = idx + +# # Get neighbors of the placeholder atom +# neighbors = [nb.GetIdx() for nb in new_mol.GetAtomWithIdx(placeholder_idx).GetNeighbors()] + +# # 
Identify anchor atoms in submol (atoms marked as [*] or with isotope labels) +# anchor_atoms = [] +# for atom in submol.GetAtoms(): +# if atom.GetNumRadicalElectrons() > 0: +# #atom.GetSymbol() == '*' or atom.GetIsotope() > 0: +# anchor_atoms.append(atom.GetIdx()) + +# # Validate number of anchor atoms vs. neighbors +# if len(anchor_atoms) != len(neighbors): +# print(f"Warning: Mismatch between anchor atoms ({len(anchor_atoms)}) and neighbors ({len(neighbors)}) for group {group}") +# print(len(anchor_atoms), len(neighbors)) +# if len(anchor_atoms)==0: +# anchor_atoms.append(0)# use the first atom of submol as default such as PPh3 + + +# # Remove bonds involving the placeholder atom +# bonds_to_remove = [(bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()) +# for bond in new_mol.GetBonds() +# if bond.GetBeginAtomIdx() == placeholder_idx or bond.GetEndAtomIdx() == placeholder_idx] +# for bond in bonds_to_remove: +# new_mol.RemoveBond(bond[0], bond[1]) + +# # Remove the placeholder atom +# new_mol.RemoveAtom(placeholder_idx) + +# # Adjust neighbor indices after atom removal +# new_neighbors = [n - 1 if n > placeholder_idx else n for n in neighbors] + +# # Combine molecules +# new_mol = RWMol(CombineMols(new_mol, submol_rw)) + +# # Connect anchor atoms to neighbors +# submol_offset = new_mol.GetNumAtoms() - submol.GetNumAtoms() +# for anchor_idx, neighbor_idx in zip(anchor_atoms, new_neighbors): +# new_anchor_idx = submol_offset + anchor_idx +# new_mol.AddBond(neighbor_idx, new_anchor_idx, Chem.BondType.SINGLE) + +# # Reset radical electrons +# new_mol.GetAtomWithIdx(neighbor_idx).SetNumRadicalElectrons(0) +# new_mol.GetAtomWithIdx(new_anchor_idx).SetNumRadicalElectrons(0) + +# mol = new_mol +# mols.append(mol) + +# # Generate final SMILES +# try: +# modified_smiles = Chem.MolToSmiles(mols[-1]) +# except Exception as e: +# print(f"Error generating SMILES: {e}") +# return mols[-1], None + +# return mols[-1], modified_smiles + + +# def _expand_abbreviation(abbrev): +# """ +# 
Expand abbreviation into its SMILES; also converts [Rn] to [n*] +# Used in `_condensed_formula_list_to_smiles` when encountering abbrev. in condensed formula +# """ +# if abbrev in ABBREVIATIONS: +# return ABBREVIATIONS[abbrev].smiles +# elif abbrev in RGROUP_SYMBOLS or (abbrev[0] == 'R' and abbrev[1:].isdigit()): + +# if abbrev[1:].isdigit(): +# return f'[{abbrev[1:]}*]' +# elif abbrev in ELEMENTS:#ocr tool need this +# return f'[{abbrev}]' +# # try abbrev + +# match = re.match(r'^(\d+)?(.*)', abbrev) +# if match: +# numeric_part, remaining_part = match.groups() +# if remaining_part in ELEMENTS: +# return f'[{abbrev}]' +# else: +# if numeric_part: +# abbrev=f'[{numeric_part}*]' +# return '[*]' + + + +# def expandABB(mol,ABBREVIATIONS, placeholder_atoms): +# mols = [mol] +# # **第三步: 替换 * 并合并官能团** +# # 逆序遍历 placeholder_atoms,确保删除后不会影响后续索引 +# for idx in sorted(placeholder_atoms.keys(), reverse=True): +# group = placeholder_atoms[idx] # 获取官能团名称 +# # print(idx, group) +# group=_expand_abbreviation(group) +# submol = Chem.MolFromSmiles(group) # 获取官能团的子分子 +# submol_rw = RWMol(submol) # 让 submol 变成可编辑的 RWMol +# anchor_atom_idx = 0 # 选择 `submol` 的第一个原子作为连接点 as defined in ABBREVIATIONS +# # **1. 复制主分子** +# new_mol = RWMol(mol) +# # **2. 计算 `*` 在 `new_mol` 中的索引** +# placeholder_idx = idx +# # **3. 记录 `*` 原子的邻居** +# neighbors = [nb.GetIdx() for nb in new_mol.GetAtomWithIdx(placeholder_idx).GetNeighbors()] +# # **4. 断开 `*` 的所有键** +# bonds_to_remove = [] # 记录要断开的键 +# for bond in new_mol.GetBonds(): +# if bond.GetBeginAtomIdx() == placeholder_idx or bond.GetEndAtomIdx() == placeholder_idx: +# bonds_to_remove.append((bond.GetBeginAtomIdx(), bond.GetEndAtomIdx())) +# for bond in bonds_to_remove: +# new_mol.RemoveBond(bond[0], bond[1]) +# # **5. 删除 `*` 原子** +# new_mol.RemoveAtom(placeholder_idx) +# # **6. 
# (Removed: remainder of a dead, commented-out legacy expandABB implementation.)


def boxes_overlap(box1, box2):
    """Return True when two [x1, y1, x2, y2] boxes overlap (touching counts)."""
    x1, y1, x2, y2 = box1
    bx1, by1, bx2, by2 = box2
    return not (x2 < bx1 or x1 > bx2 or y2 < by1 or y1 > by2)


def boxes_overlap2(atombonx, bondbox):
    """Locate the point of *bondbox* facing away from *atombonx*.

    Returns the point 70% of the way from the bond-box centre towards the
    non-overlapping end (so an added H sits close to its heavy neighbour).
    When the bond box lies entirely inside the atom box the bond centre is
    returned.

    Parameters:
        atombonx: (x1, y1, x2, y2) atom box
        bondbox:  (bx1, by1, bx2, by2) bond box
    """
    x1, y1, x2, y2 = atombonx
    bx1, by1, bx2, by2 = bondbox

    bond_center_x = (bx1 + bx2) / 2
    bond_center_y = (by1 + by2) / 2

    def distance_to_center(x, y):
        # Distance from (x, y) to the centre of the atom box.
        center_x = (x1 + x2) / 2
        center_y = (y1 + y2) / 2
        return ((x - center_x) ** 2 + (y - center_y) ** 2) ** 0.5

    def get_partway_point(far_x, far_y):
        # 70% of the way from the bond centre towards (far_x, far_y).
        dx = far_x - bond_center_x
        dy = far_y - bond_center_y
        return bond_center_x + 0.7 * dx, bond_center_y + 0.7 * dy

    # Fully disjoint boxes: project towards the corner farther from the atom.
    if bx2 < x1 or bx1 > x2 or by2 < y1 or by1 > y2:
        dist1 = distance_to_center(bx1, by1)
        dist2 = distance_to_center(bx2, by2)
        far_x, far_y = (bx2, by2) if dist2 > dist1 else (bx1, by1)
        return get_partway_point(far_x, far_y)

    # Bond box fully inside the atom box: no free end, return the centre.
    if bx1 >= x1 and bx2 <= x2 and by1 >= y1 and by2 <= y2:
        return bond_center_x, bond_center_y

    # One corner inside the atom box: project towards the opposite corner.
    if x1 <= bx1 <= x2 and y1 <= by1 <= y2:
        return get_partway_point(bx2, by2)
    if x1 <= bx2 <= x2 and y1 <= by2 <= y2:
        return get_partway_point(bx1, by1)

    # Partial intersection with both corners outside: use the farther corner.
    dist1 = distance_to_center(bx1, by1)
    dist2 = distance_to_center(bx2, by2)
    far_x, far_y = (bx2, by2) if dist2 > dist1 else (bx1, by1)
    return get_partway_point(far_x, far_y)


charge_labels = [19, 20, 21, 22, 23]


def outputbox_update(output, charge_labels, bond_labels, lab2idx):
    """Add implicit-H atom boxes at the free end of bonds touching only one atom box.

    Parameters:
        output: detection dict with 'bbox', 'bbox_centers', 'scores', 'pred_classes'
        charge_labels: class ids that are charges (excluded from atoms)
        bond_labels: class ids that are bonds
        lab2idx: label-name -> class-id mapping (must contain 'H')
    Returns:
        (deep-copied output augmented with the new H boxes,
         {bond index: [atom index]} for bonds that overlapped a single atom box)
    """
    pred = output['pred_classes']
    bonds_mask = np.array([ins in bond_labels for ins in pred])
    bond_bbox = output['bbox'][bonds_mask]
    atoms_mask = np.array([ins not in bond_labels and ins not in charge_labels for ins in pred])
    atom_bbox = output['bbox'][atoms_mask]

    new_atoms = []
    b_len = 3  # half-size of the synthetic H box
    single_odd_b2a = {}
    for bi, bb in enumerate(bond_bbox):
        overlapped_atoms = []
        overlapped_abox = []
        for ai, aa in enumerate(atom_bbox):
            if boxes_overlap(bb, aa):
                overlapped_atoms.append(ai)
                overlapped_abox.append(aa)
        if len(overlapped_atoms) == 1:
            single_odd_b2a[bi] = overlapped_atoms
            # Place the hydrogen at the non-overlapping end of the bond box.
            nx, ny = boxes_overlap2(overlapped_abox[0], bb)
            new_atoms.append({
                'bbox': np.array([nx - b_len, ny - b_len, nx + b_len, ny + b_len]).reshape(-1, 4),
                'bbox_centers': np.array([nx, ny]).reshape(-1, 2),
                'scores': np.array([1.0]),
                'pred_classes': np.array([lab2idx['H']]),
            })

    output2_ = copy.deepcopy(output)
    for boxout in new_atoms:
        for k, arr in boxout.items():
            current = output2_[k]
            if arr.ndim == 1:
                output2_[k] = np.append(current, arr)
            elif arr.ndim >= 2:
                output2_[k] = np.concatenate([current, arr], axis=0)
            else:
                print('errprs, unkown conditions !!!@')
    return output2_, single_odd_b2a


def remove_unconnected_hydrogens(mol):
    """Remove H atoms that have no heavy-atom neighbour (isolated or H-H only).

    :param mol: RDKit Mol
    :return: edited RWMol
    """
    molexp = Chem.RWMol(mol)
    to_remove = [
        atom.GetIdx()
        for atom in molexp.GetAtoms()
        if atom.GetSymbol() == 'H'
        and not any(nb.GetSymbol() != 'H' for nb in atom.GetNeighbors())
    ]
    # Delete from highest index down so removals don't shift later indices.
    for ai in sorted(to_remove, reverse=True):
        molexp.RemoveAtom(ai)
    return molexp

# (Removed: redundant mid-file `from rdkit import Chem` / `AllChem` imports —
#  both are already imported at the top of the file.)


def remove_unconnected_hydrogens2(mol):
    """Remove H atoms with no heavy-atom neighbour and report their coordinates.

    :param mol: RDKit Mol carrying at least one conformer
    :return: (edited RWMol, [(x, y, z), ...] of the removed hydrogens)
    """
    rw_mol = Chem.RWMol(mol)
    conformer = rw_mol.GetConformer()
    to_remove = []
    removed_h_coords = []
    for atom in rw_mol.GetAtoms():
        if atom.GetSymbol() != 'H':
            continue
        if any(nb.GetSymbol() != 'H' for nb in atom.GetNeighbors()):
            continue
        to_remove.append(atom.GetIdx())
        pos = conformer.GetAtomPosition(atom.GetIdx())
        removed_h_coords.append((pos.x, pos.y, pos.z))
    for ai in sorted(to_remove, reverse=True):
        rw_mol.RemoveAtom(ai)
    return rw_mol, removed_h_coords


def detect_unconnected_hydrogens(mol):
    """Return indices (descending) of H atoms that have no heavy-atom neighbour."""
    rw_mol = Chem.RWMol(mol)
    # Conformer access mirrors the sibling helpers (requires one conformer).
    conformer = rw_mol.GetConformer()
    to_remove = []
    for atom in rw_mol.GetAtoms():
        if atom.GetSymbol() != 'H':
            continue
        if any(nb.GetSymbol() != 'H' for nb in atom.GetNeighbors()):
            continue
        to_remove.append(atom.GetIdx())
    to_remove.sort(reverse=True)
    return to_remove


def view_box_center2(bond_bbox, bond_centers, bond_scores, bond_classes,
                     overlap_dist_thresh=5.0,
                     max_centers_per_box=5,
                     plot_view=False):
    """Filter bond boxes/centres: merge near-duplicate centres, drop crowded boxes.

    Parameters:
        bond_bbox: (N, 4) array of [x1, y1, x2, y2]
        bond_centers: (N, 2) array of [x, y]
        bond_scores: (N,) scores
        bond_classes: (N,) class ids
        overlap_dist_thresh: centres closer than this are duplicates
        max_centers_per_box: boxes containing more centres than this are dropped
        plot_view: when True, also build a matplotlib figure of the result
    Returns:
        (bbox, centers, scores, classes, fig) — fig is None unless plot_view.
    """
    assert len(bond_bbox) == len(bond_centers) == len(bond_scores), "Input arrays must have equal length"
    n = len(bond_bbox)
    # Step 1: among centres closer than the threshold, keep the higher score.
    keep_centers = np.ones(n, dtype=bool)
    for i in range(n):
        if not keep_centers[i]:
            continue
        for j in range(i + 1, n):
            if not keep_centers[j]:
                continue
            dist = np.sqrt(np.sum((bond_centers[i] - bond_centers[j]) ** 2))
            if dist < overlap_dist_thresh:
                if bond_scores[i] > bond_scores[j]:
                    keep_centers[j] = False
                else:
                    keep_centers[i] = False
    bond_bbox = bond_bbox[keep_centers]
    bond_centers = bond_centers[keep_centers]
    bond_scores = bond_scores[keep_centers]
    bond_classes = bond_classes[keep_centers]
    n = len(bond_bbox)
    # Step 2: drop boxes that contain too many centres.
    keep_boxes = np.ones(n, dtype=bool)
    for i in range(n):
        x1, y1, x2, y2 = bond_bbox[i]
        centers_in_box = np.sum((bond_centers[:, 0] >= x1) & (bond_centers[:, 0] <= x2) &
                                (bond_centers[:, 1] >= y1) & (bond_centers[:, 1] <= y2))
        if centers_in_box > max_centers_per_box:
            keep_boxes[i] = False
    final_bond_bbox = bond_bbox[keep_boxes]
    final_bond_centers = bond_centers[keep_boxes]
    final_bond_scores = bond_scores[keep_boxes]
    final_bond_classes = bond_classes[keep_boxes]
    if plot_view:
        # Optional visualisation of the surviving boxes and centres.
        fig, ax = plt.subplots(figsize=(10, 10))
        for box in final_bond_bbox:
            x1, y1, x2, y2 = box
            rect = Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=1,
                             edgecolor='blue', facecolor='none')
            ax.add_patch(rect)
        for center in final_bond_centers:
            ax.add_patch(Circle(center, radius=5, edgecolor='red',
                                facecolor='none', linewidth=1))
        x_min = min(final_bond_bbox[:, 0].min(), final_bond_centers[:, 0].min()) - 10
        x_max = max(final_bond_bbox[:, 2].max(), final_bond_centers[:, 0].max()) + 10
        y_min = min(final_bond_bbox[:, 1].min(), final_bond_centers[:, 1].min()) - 10
        y_max = max(final_bond_bbox[:, 3].max(), final_bond_centers[:, 1].max()) + 10
        ax.set_xlim(x_min, x_max)
        ax.set_ylim(y_min, y_max)
        ax.set_title("Filtered Boxes and Centers")
        ax.set_xlabel("X")
        ax.set_ylabel("Y")
        plt.gca().set_aspect('equal', adjustable='box')
        plt.grid(True, linestyle='--', alpha=0.7)
    else:
        fig = None
    return final_bond_bbox, final_bond_centers, final_bond_scores, final_bond_classes, fig


def calculate_iou(box1, box2):
    """IoU (intersection over union) of two [x_min, y_min, x_max, y_max] boxes."""
    x1 = max(box1[0], box2[0])
    y1 = max(box1[1], box2[1])
    x2 = min(box1[2], box2[2])
    y2 = min(box1[3], box2[3])
    intersection = max(0, x2 - x1) * max(0, y2 - y1)
    area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
    union = area1 + area2 - intersection
    return intersection / union if union > 0 else 0


def nms_per_class(labels, boxes, scores, iou_thresh=0.5):
    """Per-class non-maximum suppression keeping the highest-scoring boxes.

    Parameters:
        labels: (N,) class labels
        boxes: (N, 4) [x1, y1, x2, y2]
        scores: (N,) confidences
        iou_thresh: suppress boxes with IoU >= this against a kept box
    Returns:
        {'labels', 'boxes', 'scores'} restricted to the kept detections.
    """
    kept_indices = []
    for label in np.unique(labels):
        class_mask = labels == label
        class_indices = np.where(class_mask)[0]
        class_boxes = boxes[class_mask]
        class_scores = scores[class_mask]
        # Highest score first.
        order = np.argsort(class_scores)[::-1]
        class_boxes = class_boxes[order]
        class_scores = class_scores[order]
        class_indices = class_indices[order]
        keep = []
        while len(class_scores) > 0:
            keep.append(class_indices[0])
            if len(class_scores) == 1:
                break
            ious = np.array([calculate_iou(class_boxes[0], box) for box in class_boxes[1:]])
            keep_mask = ious < iou_thresh
            class_boxes = class_boxes[1:][keep_mask]
            class_scores = class_scores[1:][keep_mask]
            class_indices = class_indices[1:][keep_mask]
        kept_indices.extend(keep)
    kept_indices = np.array(kept_indices)
    return {
        'labels': labels[kept_indices],
        'boxes': boxes[kept_indices],
        'scores': scores[kept_indices],
    }


def get_overlap_region(box1, box2):
    """Return (x_min, y_min, x_max, y_max) of the overlap of two boxes, or None.

    Args:
        box1, box2: [x_min, y_min, x_max, y_max]
    """
    x1 = max(box1[0], box2[0])
    y1 = max(box1[1], box2[1])
    x2 = min(box1[2], box2[2])
    y2 = min(box1[3], box2[3])
    if x2 <= x1 or y2 <= y1:
        return None  # No overlap
    return (x1, y1, x2, y2)


def are_bond_connected(box1, box2, bond_bboxes, bond_iou_threshold=0.1):
    """True when a bond box overlaps both atom boxes with its centre in their overlap.

    Args:
        box1, box2: atom boxes to check
        bond_bboxes: iterable of bond boxes
        bond_iou_threshold: minimum IoU of the bond box with each atom box
    """
    overlap_region = get_overlap_region(box1, box2)
    if overlap_region is None:
        return False  # Atom boxes don't even touch.
    ox_min, oy_min, ox_max, oy_max = overlap_region
    for bond_box in bond_bboxes:
        # Preliminary IoU check against both atom boxes.
        if (calculate_iou(box1, bond_box) > bond_iou_threshold
                and calculate_iou(box2, bond_box) > bond_iou_threshold):
            bond_center_x = (bond_box[0] + bond_box[2]) / 2
            bond_center_y = (bond_box[1] + bond_box[3]) / 2
            if ox_min <= bond_center_x <= ox_max and oy_min <= bond_center_y <= oy_max:
                return True
    return False
# (Removed: a second, byte-identical definition of calculate_iou that
#  immediately shadowed the one above.)
= max(box1[1], box2[1]) + x2 = min(box1[2], box2[2]) + y2 = min(box1[3], box2[3]) + + intersection = max(0, x2 - x1) * max(0, y2 - y1) + area1 = (box1[2] - box1[0]) * (box1[3] - box1[1]) + area2 = (box2[2] - box2[0]) * (box2[3] - box2[1]) + union = area1 + area2 - intersection + + return intersection / union if union > 0 else 0 + +def nms(atom_bboxes, atom_scores, atom_classes, iou_threshold=0.5): + """ + 应用非极大值抑制 (NMS) + atom_bboxes: 列表,包含所有边界框 [x_min, y_min, x_max, y_max] + atom_scores: 列表,包含每个边界框的置信度 + atom_classes: 列表,包含每个边界框的类别 + iou_threshold: IoU 阈值,用于判断是否抑制 + 返回: 保留的边界框、类别和置信度的索引 + """ + # 按置信度排序,获取索引 + indices = np.argsort(atom_scores)[::-1] # 从高到低排序 + + keep_indices = [] + while len(indices) > 0: # 使用 len(indices) 替代 indices.size + # 保留当前最高置信度的框 + current_idx = indices[0] + keep_indices.append(current_idx) + + # 计算当前框与其他框的 IoU + ious = np.array([calculate_iou(atom_bboxes[current_idx], atom_bboxes[idx]) for idx in indices[1:]]) + # 找出 IoU > threshold 的索引(相对于 indices[1:] 的偏移) + suppress_indices = indices[1:][ious > iou_threshold] + # 更新 indices,去除当前框和被抑制的框 + indices = np.setdiff1d(indices, np.concatenate(([current_idx], suppress_indices))) + # 调试信息 + # print(f"Current idx: {current_idx}, rmoved: {suppress_indices}, Remaining: {indices}") + # print(f"Current idx: {current_idx}, rmoved: {suppress_indices}, IOU: {ious}") + + # 返回保留的框、类别和置信度 + kept_bboxes = np.array([atom_bboxes[i] for i in keep_indices]) + kept_classes = np.array([atom_classes[i] for i in keep_indices]) + kept_scores = np.array([atom_scores[i] for i in keep_indices]) + + return kept_bboxes, kept_classes, kept_scores + +def count_bond_overlaps(box, bond_bboxes, bond_iou_threshold=0.1): + """ + Count how many bond boxes overlap with an atom box. 
+ + Args: + box: atom box [x_min, y_min, x_max, y_max] + bond_bboxes: array of bond boxes + bond_iou_threshold: IoU threshold for overlap + + Returns: + int: number of overlapping bond boxes + """ + return sum(1 for bond_box in bond_bboxes if calculate_iou(box, bond_box) > bond_iou_threshold) + + +def count_bond_overlaps(box, bond_bboxes, bond_iou_threshold=0.01): + """Count how many bond boxes overlap with an atom box.""" + return sum(1 for bond_box in bond_bboxes if calculate_iou(box, bond_box) > bond_iou_threshold) + +def count_atom_overlaps(box, all_bboxes, exclude_idx, min_iou=0.01): + """Count how many other atom boxes overlap with this box.""" + return sum(1 for i, other_box in enumerate(all_bboxes) + if i != exclude_idx and calculate_iou(box, other_box) > min_iou) + +def merge_low_iou_boxes(kept_bboxes, kept_classes, kept_scores, bond_bboxes, + merge_threshold=0.5, score_threshold=0.7, bond_iou_threshold=0.01, + high_iou_threshold=0.8, large_score_threshold=0.5): + """ + Merge or filter boxes with IoU conditions, removing large low-score boxes first. 
+ + Args: + kept_bboxes: array, atom bounding boxes [x_min, y_min, x_max, y_max] + kept_classes: array, class labels (e.g., 0 for 'C') + kept_scores: array, confidence scores + bond_bboxes: array, bond bounding boxes + merge_threshold: float, upper IoU threshold for merging + score_threshold: float, score threshold to preserve boxes + bond_iou_threshold: float, IoU threshold for bond connectivity + high_iou_threshold: float, IoU threshold for high-IoU merging + large_score_threshold: float, score threshold for large box removal (default 0.5) + + Returns: + tuple: (merged_bboxes, merged_classes, merged_scores) + """ + if len(kept_bboxes) <= 1: + return kept_bboxes, kept_classes, kept_scores + + kept_bboxes = np.array(kept_bboxes) + kept_classes = np.array(kept_classes) + kept_scores = np.array(kept_scores) + bond_bboxes = np.array(bond_bboxes) + + # Step 0: Remove large boxes with low scores, high atom overlaps, and high bond overlaps + areas = (kept_bboxes[:, 2] - kept_bboxes[:, 0]) * (kept_bboxes[:, 3] - kept_bboxes[:, 1]) + median_area = np.median(areas) + keep_mask = np.ones(len(kept_bboxes), dtype=bool) + + for i in range(len(kept_bboxes)): + if kept_scores[i] < large_score_threshold: + atom_overlaps = count_atom_overlaps(kept_bboxes[i], kept_bboxes, i) + bond_overlaps = count_bond_overlaps(kept_bboxes[i], bond_bboxes, bond_iou_threshold) + is_large = areas[i] > median_area # Define "large" as above median + if is_large and atom_overlaps >= 2 and bond_overlaps >= 3: + keep_mask[i] = False + print(f"Removed large low-score box idx {i}: score {kept_scores[i]}, " + f"area {areas[i]}, atom overlaps {atom_overlaps}, bond overlaps {bond_overlaps}") + + # Filter boxes + kept_bboxes = kept_bboxes[keep_mask] + print(f"afterRemoved large low-score atom box::{len(kept_bboxes)} ") + kept_classes = kept_classes[keep_mask] + kept_scores = kept_scores[keep_mask] + if len(kept_bboxes) == 0: + return np.array([]), np.array([]), np.array([]) + + merged_bboxes = [] + 
merged_classes = [] + merged_scores = [] + used_indices = set() + + # Step 1: Merge boxes with IoU > high_iou_threshold + i = 0 + while i < len(kept_bboxes): + if i in used_indices: + i += 1 + continue + + high_iou_group = [i] + for j in range(len(kept_bboxes)): + if j in used_indices or j == i: + continue + iou = calculate_iou(kept_bboxes[i], kept_bboxes[j]) + if iou > high_iou_threshold: + high_iou_group.append(j) + + if len(high_iou_group) > 1:#atom box ovrlaped + group_scores = kept_scores[high_iou_group] + max_score_idx = high_iou_group[np.argmax(group_scores)] + merged_bboxes.append(kept_bboxes[max_score_idx]) + merged_classes.append(kept_classes[max_score_idx]) + merged_scores.append(kept_scores[max_score_idx]) + used_indices.update(high_iou_group) + print(f"Merged high-IoU (> {high_iou_threshold}) boxes: {high_iou_group}, " + f"kept index: {max_score_idx}") + i += 1 + + # Step 2: Process remaining boxes + i = 0 + while i < len(kept_bboxes): + if i in used_indices: + i += 1 + continue + + current_indices = [i] + for j in range(len(kept_bboxes)): + if j in used_indices or j == i: + continue + iou = calculate_iou(kept_bboxes[i], kept_bboxes[j])#IOU between atoms box + if 0.05 <= iou < merge_threshold:#better detect model with score matters + #any small IOU between atoms will processed here + if kept_scores[j]<0.7: + current_indices.append(j) + + group_indices = current_indices + group_scores = kept_scores[group_indices] + group_classes = kept_classes[group_indices] + group_bboxes = kept_bboxes[group_indices] + + max_score = np.max(group_scores) + max_score_idx = group_indices[np.argmax(group_scores)] + + if max_score >= score_threshold: + bond_connected = False + if len(group_indices) > 1: + for idx1, idx2 in zip(group_indices[:-1], group_indices[1:]): + if are_bond_connected(kept_bboxes[idx1], kept_bboxes[idx2], + bond_bboxes, bond_iou_threshold): + bond_connected = True + break + if bond_connected: + for idx in group_indices: + 
merged_bboxes.append(kept_bboxes[idx]) + merged_classes.append(kept_classes[idx]) + merged_scores.append(kept_scores[idx]) + print(f"Kept all bond-connected boxes: {group_indices}") + else: + bond_overlap_counts = [count_bond_overlaps(kept_bboxes[idx], bond_bboxes, + bond_iou_threshold) for idx in group_indices] + max_overlaps = max(bond_overlap_counts) + candidates = [idx for idx, count in zip(group_indices, bond_overlap_counts) + if count == max_overlaps] + best_idx = max(candidates, key=lambda idx: kept_scores[idx]) + merged_bboxes.append(kept_bboxes[best_idx]) + merged_classes.append(kept_classes[best_idx]) + merged_scores.append(kept_scores[best_idx]) + # print(f"No bond box overlap, kept box with most bond overlaps: {best_idx}, " + # f"overlap count: {max_overlaps}") + else: + if len(group_indices) == 1: + merged_bboxes.append(kept_bboxes[i]) + merged_classes.append(kept_classes[i]) + merged_scores.append(kept_scores[i]) + print(f"Merged lower IOU @@ ONLY ONE box {i}") + else: + new_bbox = [ + np.min(group_bboxes[:, 0]), # x_min + np.min(group_bboxes[:, 1]), # y_min + np.max(group_bboxes[:, 2]), # x_max + np.max(group_bboxes[:, 3]) # y_max + ] + merged_bboxes.append(new_bbox) + merged_classes.append(group_classes[np.argmax(group_scores)]) + merged_scores.append(max_score) + print(f"Merged low-score boxes: {group_indices}") + used_indices.update(group_indices) + i += 1 + + print(f"after processs low IOU atom box::{len(merged_bboxes)} ") + return (np.array(merged_bboxes), np.array(merged_classes), np.array(merged_scores)) + + +def refine_boxes(atom_bboxes, atom_scores, atom_classes, bond_bboxes, + nms_iou_threshold=0.5, merge_threshold=0.5, score_threshold=0.5, + bond_iou_threshold=0.01, high_iou_threshold=0.8): + """ + Iteratively apply NMS and merge until the number of boxes stabilizes. 
+ + Args: + atom_bboxes, atom_scores, atom_classes: Initial atom box data + bond_bboxes: Bond box data + nms_iou_threshold, merge_threshold, score_threshold, bond_iou_threshold, high_iou_threshold: Parameters + + Returns: + tuple: (final_bboxes, final_classes, final_scores) + """ + current_bboxes = np.array(atom_bboxes) + current_classes = np.array(atom_classes) + current_scores = np.array(atom_scores) + prev_count = len(current_bboxes) + 1 # Ensure at least one iteration + + iteration = 0 + while len(current_bboxes) < prev_count: + print(f"\nIteration {iteration}: Starting with {len(current_bboxes)} boxes") + prev_count = len(current_bboxes) + + # Apply NMS + kept_bboxes, kept_classes, kept_scores = nms( + current_bboxes, current_scores, current_classes, iou_threshold=nms_iou_threshold + ) + print(f"After NMS: {len(kept_bboxes)} boxes") + + # Apply merge_low_iou_boxes + merged_bboxes, merged_classes, merged_scores = merge_low_iou_boxes( + kept_bboxes, kept_classes, kept_scores, bond_bboxes, + merge_threshold=merge_threshold, score_threshold=score_threshold, + bond_iou_threshold=bond_iou_threshold, high_iou_threshold=high_iou_threshold + ) + print(f"After merge: {len(merged_bboxes)} boxes") + + # Update for next iteration + current_bboxes = merged_bboxes + current_classes = merged_classes + current_scores = merged_scores + iteration += 1 + + print(f"Converged after {iteration} iterations with {len(current_bboxes)} boxes") + return current_bboxes, current_scores, current_classes + +def merge_low_iou_boxes_old(kept_bboxes, kept_classes, kept_scores, merge_threshold=0.3): + """ + 合并 IoU < merge_threshold 的边界框,使用较高 score 的 class + """ + if len(kept_bboxes) <= 1: + return kept_bboxes, kept_classes, kept_scores + + merged_bboxes = [] + merged_classes = [] + merged_scores = [] + used_indices = set() + + for i in range(len(kept_bboxes)): + if i in used_indices: + continue + + # 找到 IoU < merge_threshold 的框组 + current_indices = [i] + for j in range(i + 1, len(kept_bboxes)): 
############################################################################################################################################################
#molscrbe evaluate
from SmilesPE.pretokenizer import atomwise_tokenizer

def canonicalize_smiles(smiles, ignore_chiral=False, ignore_cistrans=False, replace_rgroup=True):
    """Canonicalize one SMILES with RDKit.

    Args:
        smiles: input SMILES (non-str / empty returns ('', False))
        ignore_chiral: drop chirality in the canonical form
        ignore_cistrans: strip '/' and '\\' (cis/trans markers) first
        replace_rgroup: map [Rn] tokens to [n*] and unparseable bracket
            atoms to '*'

    Returns:
        (canonical_smiles, success) — on failure the (possibly rewritten)
        input is returned with success=False.
    """
    if type(smiles) is not str or smiles == '':
        return '', False
    if ignore_cistrans:
        smiles = smiles.replace('/', '').replace('\\', '')
    if replace_rgroup:
        tokens = atomwise_tokenizer(smiles)
        for j, token in enumerate(tokens):
            if token[0] == '[' and token[-1] == ']':
                symbol = token[1:-1]
                if symbol[0] == 'R' and symbol[1:].isdigit():
                    tokens[j] = f'[{symbol[1:]}*]'
                elif Chem.AtomFromSmiles(token) is None:
                    tokens[j] = '*'
        smiles = ''.join(tokens)
    try:
        canon_smiles = Chem.CanonSmiles(smiles, useChiral=(not ignore_chiral))
        success = True
    except Exception:  # BUGFIX: was a bare except (also caught SystemExit/KeyboardInterrupt)
        canon_smiles = smiles
        success = False
    return canon_smiles, success

def convert_smiles_to_canonsmiles(
        smiles_list, ignore_chiral=False, ignore_cistrans=False, replace_rgroup=True, num_workers=16):
    """Canonicalize a list of SMILES in parallel.

    Returns:
        (list of canonical SMILES, mean success rate).
    """
    with multiprocessing.Pool(num_workers) as p:
        results = p.starmap(canonicalize_smiles,
                            [(smiles, ignore_chiral, ignore_cistrans, replace_rgroup) for smiles in smiles_list],
                            chunksize=128)
    canon_smiles, success = zip(*results)
    return list(canon_smiles), np.mean(success)

def tanimoto_similarity(smiles1, smiles2):
    """Tanimoto similarity of two SMILES via RDKit path fingerprints;
    returns 0 when either SMILES cannot be parsed/fingerprinted."""
    try:
        mol1 = Chem.MolFromSmiles(smiles1)
        mol2 = Chem.MolFromSmiles(smiles2)
        fp1 = Chem.RDKFingerprint(mol1)
        fp2 = Chem.RDKFingerprint(mol2)
        tanimoto = DataStructs.FingerprintSimilarity(fp1, fp2)
        return tanimoto
    except Exception:  # BUGFIX: was a bare except
        return 0


def compute_tanimoto_similarities(gold_smiles, pred_smiles, num_workers=16):
    """Pairwise Tanimoto similarities of two equal-length SMILES lists."""
    with multiprocessing.Pool(num_workers) as p:
        similarities = p.starmap(tanimoto_similarity, [(gs, ps) for gs, ps in zip(gold_smiles, pred_smiles)])
    return similarities

class SmilesEvaluator(object):
    """Evaluate predicted SMILES against gold SMILES.

    Precomputes cis/trans-stripped and chirality-stripped canonical forms
    of the gold set so that evaluate() only has to canonicalize the
    predictions.
    """

    def __init__(self, gold_smiles, num_workers=16, tanimoto=False):
        self.gold_smiles = gold_smiles
        self.num_workers = num_workers
        self.tanimoto = tanimoto
        self.gold_smiles_cistrans, _ = convert_smiles_to_canonsmiles(gold_smiles,
                                                                     ignore_cistrans=True,
                                                                     num_workers=num_workers)
        self.gold_smiles_chiral, _ = convert_smiles_to_canonsmiles(gold_smiles,
                                                                   ignore_chiral=True, ignore_cistrans=True,
                                                                   num_workers=num_workers)
        self.gold_smiles_cistrans = self._replace_empty(self.gold_smiles_cistrans)
        self.gold_smiles_chiral = self._replace_empty(self.gold_smiles_chiral)

    def _replace_empty(self, smiles_list):
        """Replace empty SMILES in the gold, otherwise it will be considered correct if both pred and gold is empty."""
        return [smiles if smiles is not None and type(smiles) is str and smiles != "" else ""
                for smiles in smiles_list]

    def evaluate(self, pred_smiles, include_details=False):
        """Score predictions against the gold set.

        Returns a dict with keys: 'tanimoto' (optional), 'canon_smiles'
        (exact match ignoring cis/trans), 'graph' (also ignoring
        chirality), and 'chiral' (accuracy restricted to gold molecules
        containing '@'; -1 when none exist).
        """
        results = {}
        if self.tanimoto:
            results['tanimoto'] = np.mean(compute_tanimoto_similarities(self.gold_smiles, pred_smiles))
        # Ignore double bond cis/trans
        pred_smiles_cistrans, _ = convert_smiles_to_canonsmiles(pred_smiles,
                                                                ignore_cistrans=True,
                                                                num_workers=self.num_workers)
        results['canon_smiles'] = np.mean(np.array(self.gold_smiles_cistrans) == np.array(pred_smiles_cistrans))
        if include_details:
            results['canon_smiles_details'] = (np.array(self.gold_smiles_cistrans) == np.array(pred_smiles_cistrans))
        # Ignore chirality (graph exact match)
        pred_smiles_chiral, _ = convert_smiles_to_canonsmiles(pred_smiles,
                                                              ignore_chiral=True, ignore_cistrans=True,
                                                              num_workers=self.num_workers)
        results['graph'] = np.mean(np.array(self.gold_smiles_chiral) == np.array(pred_smiles_chiral))
        # Evaluate on molecules with chiral centers
        chiral = np.array([[g, p] for g, p in zip(self.gold_smiles_cistrans, pred_smiles_cistrans) if '@' in g])
        results['chiral'] = np.mean(chiral[:, 0] == chiral[:, 1]) if len(chiral) > 0 else -1
        return results
def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                    data_loader: Iterable, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, max_norm: float = 0, **kwargs):
    """Run one training epoch of the detector.

    Args:
        model: detection model; called as model(samples, targets).
        criterion: loss module; called on (outputs, targets), returns a dict of losses.
        data_loader: iterable yielding (samples, targets) batches.
        optimizer: optimizer stepped once per batch.
        device: device the batch is moved to.
        epoch: epoch number (used only in the log header).
        max_norm: gradient-clipping threshold; 0 disables clipping.
        **kwargs: optional 'print_freq' (default 10), 'ema' (model EMA with
            an .update(model) method), 'scaler' (torch.amp GradScaler; its
            presence enables the mixed-precision path).

    Returns:
        dict mapping each logged meter name to its global average.

    NOTE(review): MetricLogger, SmoothedValue and reduce_dict are imported
    from commented-out lines at the top of this file — confirm those
    imports are active in the deployed build, otherwise this raises
    NameError at runtime.
    """
    model.train()
    criterion.train()
    metric_logger = MetricLogger(delimiter="  ")
    metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
    # metric_logger.add_meter('class_error', SmoothedValue(window_size=1, fmt='{value:.2f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = kwargs.get('print_freq', 10)

    ema = kwargs.get('ema', None)
    scaler = kwargs.get('scaler', None)

    for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
        samples = samples.to(device)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        if scaler is not None:
            # Mixed-precision path: forward under autocast, loss in fp32.
            with torch.autocast(device_type=str(device), cache_enabled=True):
                outputs = model(samples, targets)

            with torch.autocast(device_type=str(device), enabled=False):
                loss_dict = criterion(outputs, targets)

            loss = sum(loss_dict.values())
            scaler.scale(loss).backward()

            if max_norm > 0:
                # Gradients must be unscaled before clipping by norm.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        else:
            # Full-precision path.
            outputs = model(samples, targets)
            loss_dict = criterion(outputs, targets)

            loss = sum(loss_dict.values())
            optimizer.zero_grad()
            loss.backward()

            if max_norm > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

            optimizer.step()

        # ema
        if ema is not None:
            ema.update(model)

        # Reduce losses across distributed workers for logging only.
        loss_dict_reduced = reduce_dict(loss_dict)
        loss_value = sum(loss_dict_reduced.values())

        if not math.isfinite(loss_value):
            # A NaN/inf loss is unrecoverable — abort the run.
            print("Loss is {}, stopping training".format(loss_value))
            print(loss_dict_reduced)
            sys.exit(1)

        metric_logger.update(loss=loss_value, **loss_dict_reduced)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}



# @torch.no_grad()
# def evaluate(model: torch.nn.Module, criterion: torch.nn.Module, postprocessors, data_loader, base_ds, device, output_dir,
#             annot_file=f'/home/jovyan/rt-detr/data/real_processed/CLEF_with_charge/annotations/val.json',
#             outcsv_filename=f'/home/jovyan/rt-detr/rt-detr/output/output_charge_CLEF.csv',
#             ):
#     model.eval()
#     criterion.eval()

#     metric_logger = MetricLogger(delimiter="  ")
#     header = 'Test:'

#     iou_types = postprocessors.iou_types
#     coco_evaluator = CocoEvaluator(base_ds, iou_types)
+ +# panoptic_evaluator = None + +# # # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +# # home='/home/jovyan/rt-detr' +# # dataset = 'CLEF' +# # annot_file=f'/home/jovyan/rt-detr/data/real_processed/{dataset}_with_charge/annotations/test.json' +# # outcsv_filename/home/jovyan/rt-detr/rt-detr/output/output_charge_{dataset}.csv' + + +# # annot_file=f'/home/jovyan/rt-detr/data/real_processed/{dataset}_with_charge/annotations/test.json' +# # outcsv_filename=f'/home/jovyan/rt-detr/rt-detr/output/output_charge_{dataset}.csv' +# with open(annot_file, 'r') as file: +# data = json.load(file) + + + + +# image_id_to_name = {} + +# for image_data in data['images']: +# image_id = image_data['id'] +# image_path = image_data['file_name'] +# image_name = os.path.basename(image_path) +# image_id_to_name[image_id] = image_name + +# res_smiles = [] +# bond_labels = [13,14,15,16,17,18] +# idx_to_labels={0:'other',1:'C',2:'O',3:'N',4:'Cl',5:'Br',6:'S',7:'F',8:'B', +# 9:'I',10:'P',11:'H',12:'Si', +# #bond +# 13:'single',14:'wdge',15:'dash', +# 16:'=',17:'#',18:':',#aromatic +# #charge +# 19:'-4',20:'-2', +# 21:'-1',#- +# 22:'+1',#+ +# 23:'2', +# } +# lab2idx={v:k for k,v in idx_to_labels.items()} +# #indigo bond type stero maping +# indi_bond={ +# "1":'single', "2":'=',"3":'#',"4":':',"5":'wdge',"6":'dash', +# } + + +# smiles_data = pd.DataFrame({'file_name': [], +# 'SMILES':[]}) + +# output_dict = {} +# target_dict = {} +# filtered_output_dict = {} +# # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +# for samples, targets in metric_logger.log_every(data_loader, 10, header): +# samples = samples.to(device) +# targets = [{k: v.to(device) for k, v in t.items()} for t in targets] + +# outputs = model(samples) + +# orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0) +# results = postprocessors(outputs, orig_target_sizes)#RTDETRPostProcessor@@src/zoo/rtertr +# 
#results: a list of dict label box score +# res = {target['image_id'].item(): output for target, output in zip(targets, results)} + +# for target, output in zip(targets, results): +# output_dict[target['image_id'].item()] = output + +# stats = {} +# # stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()} +# if coco_evaluator is not None: +# if 'bbox' in iou_types: +# # stats['coco_eval_bbox'] = coco_evaluator.coco_eval['bbox'].stats.tolist() +# stats['coco_eval_bbox'] = coco_evaluator.coco_eval['bbox'].stats +# if 'segm' in iou_types: +# stats['coco_eval_masks'] = coco_evaluator.coco_eval['segm'].stats.tolist() + + + +# # ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ +# # ocr_recognition_only = get_ocr_recognition_only(force_cpu=False) +# # caption_remover = CaptionRemover(force_cpu=True) +# for key, value in output_dict.items():#TODO improve here +# selected_indices = value['scores'] > 0.5#may be >=0.5 cut off, as used the sigmoid? 
def remove_bond_directions_if_no_chiral(mol):
    """Clear wedge/dash direction marks on single bonds when the molecule
    has no chiral centers (assigned or unassigned).

    Args:
        mol: RDKit Mol or None.

    Returns:
        The same Mol (mutated in place), or None for None input.
    """
    if mol is None:
        return None
    chiral_centers = Chem.FindMolChiralCenters(mol, includeUnassigned=True)
    if not chiral_centers:
        for bond in mol.GetBonds():
            # Only single bonds can carry wedge/dash marks.
            if bond.GetBondType() == Chem.BondType.SINGLE:
                bond.SetBondDir(Chem.BondDir.NONE)
    return mol
#######################################################################################
def molExpanding(mol_rebuit, placeholder_atoms, wdbs, bond_dirs, alignmol=False):
    """Expand abbreviation placeholders in a rebuilt molecule and
    optionally re-apply wedge/dash bond directions via MCS alignment.

    Args:
        mol_rebuit: molecule with placeholder atoms for abbreviations.
        placeholder_atoms: placeholder info consumed by expandABB.
        wdbs: wedge/dash bond records; b[0]/b[1] are the original atom
            indices, b[3] keys into bond_dirs — TODO confirm record layout.
        bond_dirs: mapping from wdbs key to Chem.BondDir.
        alignmol: when True, align expanded mol to the original via MCS and
            transfer bond directions through the atom mapping.

    Returns:
        (expanded_smiles, mol) where mol has stereochemistry reassigned.
    """
    cm = copy.deepcopy(mol_rebuit)
    expand_mol, expand_smiles = expandABB(cm, ABBREVIATIONS, placeholder_atoms)
    rdm = copy.deepcopy(expand_mol)
    AllChem.Compute2DCoords(rdm)
    target_mol, ref_mol = rdm, cm

    if alignmol:
        mcs = rdFMCS.FindMCS([target_mol, ref_mol],  # larger, smaller order
                             atomCompare=rdFMCS.AtomCompare.CompareAny,
                             # bondCompare=rdFMCS.BondCompare.CompareAny,
                             ringCompare=rdFMCS.RingCompare.IgnoreRingFusion,
                             matchChiralTag=False,
                             )
        atommaping_pairs = g_atompair_matches([target_mol, ref_mol], mcs)
        atomMap = atommaping_pairs[0]
        try:
            rmsd2 = rdkit.Chem.rdMolAlign.AlignMol(prbMol=target_mol, refMol=ref_mol, atomMap=atomMap, maxIters=2000000)
        except Exception as e:
            print(atomMap, "@@@@")
            print(e)
        # After alignment, translate original atom ids to expanded-mol ids.
        c2p = {cur: pre for cur, pre in atomMap}
        p2c = {pre: cur for cur, pre in atomMap}
        for b in wdbs:  # re-apply bond directions
            p0, p1 = int(b[0]), int(b[1])  # may be absent from atomMap (MCS substructure only)
            if p0 in p2c.keys() and p1 in p2c.keys():
                c0, c1 = p2c[p0], p2c[p1]
                b_ = target_mol.GetBondBetweenAtoms(c0, c1)
                if b_:
                    b_.SetBondDir(bond_dirs[b[3]])
        # Stereo only enters the SMILES after the assignment steps below.
        expandStero_smi = Chem.MolToSmiles(target_mol)
    else:
        expandStero_smi = expand_smiles

    m = target_mol.GetMol()
    # Chem.SanitizeMol(m)
    Chem.DetectBondStereochemistry(m)
    Chem.AssignChiralTypesFromBondDirs(m)
    Chem.AssignStereochemistry(m)

    return expandStero_smi, m


def remove_backslash_and_slash(input_string):
    """Strip every '\\' and '/' (cis/trans bond markers) from a SMILES."""
    if "\\" in input_string:
        input_string = input_string.replace("\\", "")
    if "/" in input_string:
        input_string = input_string.replace("/", "")
    return input_string


def remove_number_before_star(input_string):
    """Delete digit runs immediately preceding a non-terminal '*', except
    when the character before the digits is a letter (ring-closure digits
    like 'c1*' are preserved).

    Example: '12*C' -> '*C'; 'c1*C' unchanged; a trailing '*' is ignored.

    BUGFIX: the original used a bare ``continue`` inside the alpha-guard
    without advancing ``i``, which made inputs such as 'c1*C' loop forever.
    """
    result = list(input_string)

    i = 0
    while i < len(result):
        if result[i] == '*' and i != len(result) - 1:
            # *c1c(*)c(*)c(C(*)(*)C(C)C)c(*)c1* --> unchanged
            j = i - 1
            if result[j - 1].isalpha():
                i += 1  # advance before skipping, otherwise this loops forever
                continue
            while j >= 0 and result[j].isdigit():
                result[j] = ''
                j -= 1
        i += 1

    return ''.join(result)


def remove_SP(input_string):
    """Remove square-planar stereo tags (@SP1-3) from a SMILES and, when a
    trigonal-bipyramidal tag (@TB) is present, strip the whole stereo
    suffix inside bracket atoms."""
    pattern = r'\[([^@]*)@?[A-Z0-9]*\]'
    input_string = re.sub(r'@SP[1-3]', '', input_string)
    if '@TB' in input_string:
        # Drop everything after '@' inside each bracket atom.
        result = re.sub(pattern, r'[\1]', input_string)
        input_string = result
    return input_string
def rdkit_canonicalize_smiles(smiles):
    """Canonicalize a SMILES after mapping R-groups and unparseable
    bracket atoms (e.g. '[R1]', element+digit abbreviations) to '*'.

    Returns:
        (canonical_smiles, success) — success is False when RDKit still
        cannot canonicalize after the replacements.
    """
    Aad_string = r'([A-Z][a-z]*)([0-9]+)'
    tokens = atomwise_tokenizer(smiles)
    for j, token in enumerate(tokens):
        if token[0] == '[' and token[-1] == ']':
            symbol = token[1:-1]
            # re.match (not findall) so valid hydrides like [BH2] are kept.
            matches = re.match(Aad_string, symbol)
            if matches:
                letters, numbers = matches.groups()
                print(f"{letters} {numbers}")
                tokens[j] = '*'
            elif symbol in RGROUP_SYMBOLS:
                tokens[j] = '*'
            elif Chem.AtomFromSmiles(token) is None:
                tokens[j] = '*'

    smiles = ''.join(tokens)
    try:
        canon_smiles = Chem.CanonSmiles(smiles, useChiral=False)
        success = True
    except Exception:  # BUGFIX: was a bare except
        canon_smiles = smiles
        success = False
    return canon_smiles, success

def NoRadical_Smi(smi):
    """Quench radical electrons by converting them to explicit hydrogens
    and return the resulting SMILES."""
    aa = Chem.MolFromSmiles(smi)
    for atom in aa.GetAtoms():
        if atom.GetNumRadicalElectrons() > 0:  # atom carries a radical
            atom.SetNumRadicalElectrons(0)
            # Top up hydrogens to restore the atom's valence.
            atom.SetNumExplicitHs(atom.GetTotalValence() - atom.GetExplicitValence())
    san_before = Chem.MolToSmiles(aa)
    return san_before

import logging

def check_and_fix_valence(smiles_or_list):
    """Check atom valences in a SMILES string or a list [smiles, suffix/prefix].
    Fix unusual valences (e.g., N(2)) by adding/removing hydrogens to maintain neutrality.

    Returns:
        (corrected_smiles_or_list, warnings)
    """
    # NOTE(review): basicConfig per call configures the root logger as a
    # side effect — confirm this is intended.
    logging.basicConfig(level=logging.WARNING)
    warnings = []

    # Standard valence dictionary for common atoms
    standard_valences = {
        'C': [4],
        'N': [3],  # Prioritize valence 3 for neutral nitrogen (e.g., amines, amides)
        'O': [2],
        'H': [1],
        'F': [1]
    }

    # Handle input: SMILES string or list from C_H_expand
    if isinstance(smiles_or_list, list):
        smiles, other_part = smiles_or_list
    else:
        smiles, other_part = smiles_or_list, None

    # Process main SMILES
    mol = Chem.MolFromSmiles(smiles, sanitize=False) if smiles else None
    if mol is None:
        warnings.append(f"Invalid SMILES: {smiles}")
        return smiles_or_list, warnings

    # Process other_part if it exists and is a valid SMILES
    other_part_mol = None
    if other_part:
        try:
            other_part_mol = Chem.MolFromSmiles(other_part, sanitize=False)
        except Exception:  # BUGFIX: was a bare except; other_part may just be a suffix/prefix
            pass

    # Helper: check and fix valence for one molecule.
    def process_molecule(mol, is_other_part=False):
        nonlocal warnings
        corrected = False
        prefix = "other_part" if is_other_part else "SMILES"

        # Compute valence explicitly to avoid precondition violation
        mol.UpdatePropertyCache(strict=False)

        # Flag every atom whose valence deviates from the standard table.
        for atom in mol.GetAtoms():
            symbol = atom.GetSymbol()
            valence = atom.GetTotalValence()
            expected_valences = standard_valences.get(symbol, [valence])
            if valence not in expected_valences:
                warnings.append(f"Unusual valence in {prefix} for {symbol}: {valence} (expected {expected_valences})")

        # Fix nitrogen valence issues by adjusting hydrogens
        if any('N' in w for w in warnings if prefix in w):
            rw_mol = Chem.RWMol(mol)  # editable copy
            for atom in rw_mol.GetAtoms():
                if atom.GetSymbol() != 'N':
                    continue
                valence = atom.GetTotalValence()
                if valence < 3:
                    # Add hydrogens to reach valence 3
                    hydrogens_needed = 3 - valence
                    atom.SetNumExplicitHs(atom.GetNumExplicitHs() + hydrogens_needed)
                    corrected = True
                elif valence > 3:
                    # Remove hydrogens if possible
                    hydrogens_to_remove = valence - 3
                    current_hydrogens = atom.GetNumExplicitHs()
                    if current_hydrogens >= hydrogens_to_remove:
                        atom.SetNumExplicitHs(current_hydrogens - hydrogens_to_remove)
                        corrected = True
                    else:
                        warnings.append(f"Cannot reduce N valence in {prefix} to 3 without removing non-H bonds")
            if corrected:
                mol = rw_mol.GetMol()

        # Sanitize molecule after corrections
        if corrected:
            try:
                Chem.SanitizeMol(mol, catchErrors=True)
                return mol, True
            except Exception as e:
                warnings.append(f"Failed to sanitize {prefix} after correction: {str(e)}")
                return mol, False
        return mol, False

    # Process main molecule
    mol, mol_corrected = process_molecule(mol)

    # Convert main molecule back to SMILES
    corrected_smiles = Chem.MolToSmiles(mol) if mol_corrected else smiles

    # Process other_part if it's a valid molecule
    corrected_other_part = other_part
    if other_part_mol:
        other_part_mol, other_corrected = process_molecule(other_part_mol, is_other_part=True)
        corrected_other_part = Chem.MolToSmiles(other_part_mol) if other_corrected else other_part

    # Return based on input type
    if other_part:
        return [corrected_smiles, corrected_other_part], warnings
    return corrected_smiles, warnings

def molfpsim(original_smiles, test_smiles):
    """Fingerprint similarity between two SMILES after heavy normalization
    (longest component only, cis/trans and SP stereo stripped, R-groups
    mapped to '*').

    Returns:
        (morgan_dice, rdkit_tanimoto) similarity pair.
    """
    # Desalt: compare only the longest component of each SMILES.
    test_smiles = select_longest_smiles(test_smiles)
    original_smiles = select_longest_smiles(original_smiles)
    test_smiles, warnings = check_and_fix_valence(test_smiles)

    original_smiles = remove_backslash_and_slash(original_smiles)  # cis/trans
    test_smiles = remove_backslash_and_slash(test_smiles)
    original_smiles = re.sub(r'\[(\d+)\*', '[*', original_smiles)  # [1*]-->[*]
    test_smiles = re.sub(r'\[(\d+)\*', '[*', test_smiles)
    # Drop exotic coordinate-derived stereo tags (rarely meaningful here).
    original_smiles = remove_SP(original_smiles)
    test_smiles = remove_SP(test_smiles)

    rd_smi_ori, success1 = rdkit_canonicalize_smiles(original_smiles)  # R-->*
    if "S" in rd_smi_ori and success1:  # NOTE: replace radical electrons with H
        rd_smi_ori = NoRadical_Smi(rd_smi_ori)
    rd_smi, success2 = rdkit_canonicalize_smiles(test_smiles)
    original_smiles, test_smiles = rd_smi_ori, rd_smi

    mol1 = Chem.MolFromSmiles(original_smiles)  # TODO: handle SMILES RDKit cannot parse in real data
    mol2 = Chem.MolFromSmiles(test_smiles)  # TODO: handle SMILES RDKit cannot parse in real data

    morganfps1 = AllChem.GetMorganFingerprint(mol1, useChirality=False)
    morganfps2 = AllChem.GetMorganFingerprint(mol2, useChirality=False)
    morgan_tani = DataStructs.DiceSimilarity(morganfps1, morganfps2)
    fp1 = Chem.RDKFingerprint(mol1)
    fp2 = Chem.RDKFingerprint(mol2)
    tanimoto = DataStructs.FingerprintSimilarity(fp1, fp2)
    return morgan_tani, tanimoto
remove_SP(original_smiles)#additional complex space stero from coordinates, most not used + test_smiles = remove_SP(test_smiles) + + rd_smi_ori, success1=rdkit_canonicalize_smiles(original_smiles)#R-->* + if "S" in rd_smi_ori and success1:#NOTE H replace radical electron + rd_smi_ori=NoRadical_Smi(rd_smi_ori) + rd_smi, success2=rdkit_canonicalize_smiles(test_smiles) + original_smiles,test_smiles=rd_smi_ori,rd_smi + + mol1 = Chem.MolFromSmiles(original_smiles)#TODO considering smiles with rdkit not recongized in real data + mol2 = Chem.MolFromSmiles(test_smiles)#TODO considering smiles with rdkit not recongized in real data + + morganfps1 = AllChem.GetMorganFingerprint(mol1, useChirality=False) + morganfps2 = AllChem.GetMorganFingerprint(mol2, useChirality=False) + morgan_tani = DataStructs.DiceSimilarity(morganfps1, morganfps2) + fp1 = Chem.RDKFingerprint(mol1) + fp2 = Chem.RDKFingerprint(mol2) + tanimoto = DataStructs.FingerprintSimilarity(fp1, fp2) + return morgan_tani, tanimoto + + + + +def comparing_smiles2(original_smiles,test_smiles):#I2M use the coordinates, so 2D coformation should be always + original_smiles = remove_backslash_and_slash(original_smiles)#c/s + test_smiles = remove_backslash_and_slash(test_smiles) + original_smiles = re.sub(r'\[(\d+)\*', '[*',original_smiles)#[1*]-->[*] + test_smiles = re.sub(r'\[(\d+)\*', '[*',test_smiles) + original_smiles = remove_SP(original_smiles)#additional complex space stero from coordinates, most not used + test_smiles = remove_SP(test_smiles) + + rd_smi_ori, success1=rdkit_canonicalize_smiles(original_smiles)#R-->* + if "S" in rd_smi_ori and success1:#NOTE H replace radical electron + rd_smi_ori=NoRadical_Smi(rd_smi_ori) + + rd_smi, success2=rdkit_canonicalize_smiles(test_smiles) + original_smiles,test_smiles=rd_smi_ori,rd_smi + + try: + original_mol = Chem.MolFromSmiles(original_smiles)#considering whe nmmet abbrev + test_mol = Chem.MolFromSmiles(test_smiles,sanitize=False)#as build mol may not sanitized for 
def smiles12_comparing(original_smiles,test_smiles):
    """Equality test between two SMILES strings.

    Same normalisation pipeline as comparing_smiles2 but without the
    early aromatic-SMILES fast path and without radical replacement:
    canonical SMILES, kekulised SMILES and (when no '*' wildcard is
    present) InChI are compared in turn. RDKit failures fall through to
    mismatching sentinel values so the function returns False instead of
    raising.
    """
    original_smiles = remove_backslash_and_slash(original_smiles)  # drop cis/trans '/' '\'
    test_smiles = remove_backslash_and_slash(test_smiles)
    original_smiles = re.sub(r'\[(\d+)\*', '[*',original_smiles)  # [1*] --> [*]
    test_smiles = re.sub(r'\[(\d+)\*', '[*',test_smiles)
    original_smiles = remove_SP(original_smiles)  # extra stereo tags from coordinates, mostly unused
    test_smiles = remove_SP(test_smiles)

    rd_smi_ori, success1=rdkit_canonicalize_smiles(original_smiles)
    rd_smi, success2=rdkit_canonicalize_smiles(test_smiles)
    original_smiles,test_smiles=rd_smi_ori,rd_smi
    try:
        original_mol = Chem.MolFromSmiles(original_smiles)  # may be None (e.g. unresolved abbreviations)
        test_mol = Chem.MolFromSmiles(test_smiles,sanitize=False)  # built mol may not be sanitizable by RDKit
        if original_mol:
            Chem.SanitizeMol(original_mol)
            keku_smi_ori=Chem.MolToSmiles(original_mol,kekuleSmiles=True)
        else:
            keku_smi_ori=original_smiles

        if test_mol:
            Chem.SanitizeMol(test_mol)
            keku_smi=Chem.MolToSmiles(test_mol,kekuleSmiles=True)
        else:
            keku_smi=test_smiles

        if '*' not in keku_smi:
            keku_inch_ori= Chem.MolToInchi(Chem.MolFromSmiles(keku_smi_ori))
            keku_inch_test= Chem.MolToInchi(Chem.MolFromSmiles(keku_smi))
        else:
            # InChI cannot represent wildcard atoms: use unequal sentinels so
            # the InChI branch of the final comparison never matches.
            keku_inch_ori= 1
            keku_inch_test= 2

        rd_smi=Chem.MolToSmiles(test_mol)  # TODO: accuracy could be improved here
        rd_smi_ori=Chem.MolToSmiles(original_mol)
    except Exception as e:  # TODO: narrow this — kekulize/sanitize failures are the expected cases
        print(f"comparing_smiles@@@ kekulize or SanitizeMol problems")
        print(e,"!!!!!!!\n")
        # Unequal sentinels: every fallback comparison below will fail.
        keku_inch_ori= 1
        keku_inch_test= 2
        keku_smi=1
        keku_smi_ori=2
    if not success1:  # original SMILES still invalid even after '*' replacement
        rd_smi_ori = rd_smi
    if rd_smi_ori == rd_smi or keku_smi_ori == keku_smi or keku_inch_ori==keku_inch_test :  # original SMILES may be written in kekulized style
        return True
    else:return False
def comparing_smiles(new_row,test_smiles):  # I2M uses coordinates, so a 2D conformation should always exist
    """Row-based variant of smiles12_comparing.

    Reads the reference SMILES from ``new_row['SMILESori']`` and compares it
    to *test_smiles* via canonical SMILES, kekulised SMILES and (when no '*'
    wildcard is present) InChI. On RDKit failure the offending *new_row* is
    printed for debugging and mismatching sentinels force a False result.
    """
    original_smiles=new_row['SMILESori']
    original_smiles = remove_backslash_and_slash(original_smiles)  # drop cis/trans '/' '\'
    test_smiles = remove_backslash_and_slash(test_smiles)
    original_smiles = re.sub(r'\[(\d+)\*', '[*',original_smiles)  # [1*] --> [*]
    test_smiles = re.sub(r'\[(\d+)\*', '[*',test_smiles)
    original_smiles = remove_SP(original_smiles)  # extra stereo tags from coordinates, mostly unused
    test_smiles = remove_SP(test_smiles)

    rd_smi_ori, success1=rdkit_canonicalize_smiles(original_smiles)
    rd_smi, success2=rdkit_canonicalize_smiles(test_smiles)
    original_smiles,test_smiles=rd_smi_ori,rd_smi
    try:
        original_mol = Chem.MolFromSmiles(original_smiles)  # may be None (e.g. unresolved abbreviations)
        test_mol = Chem.MolFromSmiles(test_smiles,sanitize=False)  # built mol may not be sanitizable by RDKit
        if original_mol:
            Chem.SanitizeMol(original_mol)
            keku_smi_ori=Chem.MolToSmiles(original_mol,kekuleSmiles=True)
        else:
            keku_smi_ori=original_smiles

        if test_mol:
            Chem.SanitizeMol(test_mol)
            keku_smi=Chem.MolToSmiles(test_mol,kekuleSmiles=True)
        else:
            keku_smi=test_smiles

        if '*' not in keku_smi:
            keku_inch_ori= Chem.MolToInchi(Chem.MolFromSmiles(keku_smi_ori))
            keku_inch_test= Chem.MolToInchi(Chem.MolFromSmiles(keku_smi))
        else:
            # InChI cannot represent wildcard atoms: use unequal sentinels so
            # the InChI branch of the final comparison never matches.
            keku_inch_ori= 1
            keku_inch_test= 2

        rd_smi=Chem.MolToSmiles(test_mol)  # TODO: accuracy could be improved here
        rd_smi_ori=Chem.MolToSmiles(original_mol)
    except Exception as e:  # TODO: narrow this — kekulize/sanitize failures are the expected cases
        print(f"comparing_smiles@@@ kekulize or SanitizeMol problems")
        print(new_row)
        print(e,"!!!!!!!\n")
        # Unequal sentinels: every fallback comparison below will fail.
        keku_inch_ori= 1
        keku_inch_test= 2
        keku_smi=1
        keku_smi_ori=2
    if not success1:  # original SMILES still invalid even after '*' replacement
        rd_smi_ori = rd_smi
    if rd_smi_ori == rd_smi or keku_smi_ori == keku_smi or keku_inch_ori==keku_inch_test :  # original SMILES may be written in kekulized style
        return True
    else:return False
def bbox2center(bbox):
    """Return the (x, y) centers of an array of boxes.

    Args:
        bbox: array of shape (N, 4) holding [x1, y1, x2, y2] rows.

    Returns:
        np.ndarray of shape (N, 2) with one (cx, cy) row per box.
    """
    mid_x = (bbox[:, 0] + bbox[:, 2]) / 2
    mid_y = (bbox[:, 1] + bbox[:, 3]) / 2
    return np.stack((mid_x, mid_y), axis=1)
import cv2
# Bond classes that carry a drawing direction (stereo wedge/dash, ring closures).
BONDDIRECT=['ENDUPRIGHT', 'BEGINWEDGE', 'BEGINDASH', 'ENDDOWNRIGHT']


def reorder_bond_bbox(bond_bbox, single_atom_bond):
    """Reorder *bond_bbox* so bonds listed in *single_atom_bond* come last.

    Args:
        bond_bbox: sequence of bond boxes.
        single_atom_bond: dict whose keys are indices of bonds that overlap
            only one atom box; these are moved to the end.

    Returns:
        list: bond boxes with normal bonds first, special bonds appended.
    """
    # Split indices into normal ones and ones that must be pushed to the back.
    normal_indices = []
    special_indices = []
    # Keys that must be moved to the end.
    keys_to_move = set(single_atom_bond.keys())
    # Classify every index.
    for i in range(len(bond_bbox)):
        if i in keys_to_move:
            special_indices.append(i)
        else:
            normal_indices.append(i)
    # New order: normal indices first, special indices last.
    new_order = normal_indices + special_indices
    # Re-arrange bond_bbox accordingly.
    reordered_bbox = [bond_bbox[i] for i in new_order]
    return reordered_bbox

def boxes_overlap(box1, box2):
    """
    Check whether two bounding boxes overlap.
    box1, box2: [x1, y1, x2, y2]
    """
    return not (box1[2] < box2[0] or box1[0] > box2[2] or
                box1[3] < box2[1] or box1[1] > box2[3])
def calculate_center(box):
    """
    Return the center point of a bounding box as a NumPy array [cx, cy].
    """
    return np.array([(box[0] + box[2]) / 2, (box[1] + box[3]) / 2])
def merge_boxes(box1, box2):
    """
    Merge two bounding boxes and return the enclosing box [x1, y1, x2, y2].
    """
    return [
        min(box1[0], box2[0]),
        min(box1[1], box2[1]),
        max(box1[2], box2[2]),
        max(box1[3], box2[3])
    ]


def get_merged_box(boxes):
    """Calculate the smallest box encompassing all given boxes."""
    x_mins = [box[0] for box in boxes]
    y_mins = [box[1] for box in boxes]
    x_maxs = [box[2] for box in boxes]
    y_maxs = [box[3] for box in boxes]
    return [min(x_mins), min(y_mins), max(x_maxs), max(y_maxs)]

def box_area(box):
    """Calculate the area of a box."""
    return (box[2] - box[0]) * (box[3] - box[1])

def Newbox_(atom_bbox,bond_bbox, lab2idx):
    """Synthesise H-atom boxes for bonds that overlap only one atom box.

    Args:
        atom_bbox: sequence of atom boxes [x1, y1, x2, y2].
        bond_bbox: sequence of bond boxes.
        lab2idx: label-name -> class-index mapping (must contain 'H').

    Returns:
        tuple: (new_atoms, single_odd_b2a) where new_atoms is a list of
        detection-style dicts (bbox / bbox_centers / scores / pred_classes)
        for the synthetic H atoms, and single_odd_b2a maps a bond index to
        its single overlapping atom index list.
    """
    # Add an H atom box when a direction bond touches only one atom.
    new_atoms=[]
    b_len=3  # half-width of the synthetic H atom box
    single_odd_b2a=dict()
    for bi,bb in enumerate(bond_bbox):
        overlapped_atoms = []
        overlapped_abox=[]
        for ai,aa in enumerate(atom_bbox):
            overlap_flag=boxes_overlap(bb, aa)  # TODO: use the atom/bond box overlap for bond-atom mapping, then build the mol
            if overlap_flag:
                overlapped_atoms.append(ai)
                overlapped_abox.append(aa)
        if len(overlapped_atoms) == 1:
            single_odd_b2a[bi]=overlapped_atoms
            # Compute the non-overlapping part of the bond box to place hydrogen.
            # NOTE(review): boxes_overlap2 is defined elsewhere in this file and
            # presumably returns that free end as an (x, y) point — confirm.
            non_overlapping_x,non_overlapping_y=boxes_overlap2(overlapped_abox[0], bb)
            new_atom_out={'bbox': np.array([non_overlapping_x - b_len,
                                            non_overlapping_y - b_len,
                                            non_overlapping_x + b_len,
                                            non_overlapping_y + b_len]).reshape(-1,4),
                          'bbox_centers': np.array([non_overlapping_x,non_overlapping_y]).reshape(-1,2),
                          'scores': np.array([1.0]),
                          'pred_classes': np.array([lab2idx['H']])}
            new_atoms.append(new_atom_out)
    return new_atoms, single_odd_b2a
def has_boxes(data):
    """Return True when *data* looks like detection-style OCR output:
    a non-empty list of [box, text_info] pairs whose box is a 4-point list.

    Used to decide whether PaddleOCR ran with its detector enabled.
    """
    if not isinstance(data, list) or not data:
        return False
    return all(
        isinstance(entry, list)
        and len(entry) == 2
        and isinstance(entry[0], list)
        and len(entry[0]) == 4
        for entry in data
    )

def AtomBox2bondBox(atom_box,bond_bbox):
    """Find the bond boxes touching *atom_box*.

    Returns:
        tuple(bool, list[int]): (whether any bond overlaps, indices of the
        overlapping bond boxes).
    """
    # TODO: use the atom/bond box overlap for bond-atom mapping, then build the mol
    neighbours = [bi for bi, bb in enumerate(bond_bbox) if boxes_overlap(bb, atom_box)]
    return bool(neighbours), neighbours


import torchvision.transforms.v2 as T

def image_to_tensor(image_path,debug=True):
    """Load *image_path*, force RGB, resize to 640x640 and return
    (float32 CHW tensor, original width, original height)."""
    pil_image = Image.open(image_path)
    width, height = pil_image.size

    # Coerce grayscale (or any other mode) to RGB before the transform.
    if pil_image.mode != "RGB":
        if pil_image.mode == "L":
            if debug: print("检测到灰度图像 (1 通道),转换为 RGB...")
        else:
            if debug: print(f"检测到 {pil_image.mode} 模式,转换为 RGB...")
        pil_image = pil_image.convert("RGB")

    # Resize, convert to a tensor, and force float32 dtype.
    pipeline = T.Compose([
        T.Resize((640, 640)),
        T.ToTensor(),
        lambda x: x.to(torch.float32),  # explicit dtype conversion
    ])
    return pipeline(pil_image), width, height
+ +# from src.zoo.rtdetr.rtdetr_postprocessor import RTDETRPostProcessor + +@torch.no_grad() +def evaluate_x(model: torch.nn.Module, criterion: torch.nn.Module, postprocessors, + data_loader, device, + outcsv_filename=f'/home/jovyan/rt-detr/rt-detr/output/output_charge_CLEF.csv', + visual_check=False, + other2ppsocr=True, + getacc=False, + ): + + postprocessor2=RTDETRPostProcessor(num_classes=30, use_focal_loss=True, num_top_queries=300, remap_mscoco_category=False) + output_directory = os.path.dirname(outcsv_filename) + prefix_f = os.path.basename(outcsv_filename).split('.')[0] + if other2ppsocr: + ocr = PaddleOCR( + use_angle_cls=True, + lang='latin',use_space_char=True,use_debug=False, + use_gpu=True if cv2.cuda.getCudaEnabledDeviceCount() > 0 else False) + + ocr2 = ocr2 = PaddleOCR(use_angle_cls=True,use_gpu =False,use_debug=False, + rec_algorithm='SVTR_LCNet', rec_model_dir='/nfs_home/bowen/.paddleocr/whl/rec/en/en_PP-OCRv4_rec_infer', + lang="en") + outcsv_filename=f"{output_directory}/{prefix_f}_withOCR.csv" + + + if visual_check: + output_directory = os.path.dirname(outcsv_filename) + prefix_f = os.path.basename(outcsv_filename).split('.')[0] + ima_checkdir=f"{output_directory}/{prefix_f}_Boxed" + os.makedirs(ima_checkdir, exist_ok=True) + + if getacc: + acc_summary=f"{outcsv_filename}.I2Msummary.txt" + flogout = open(f'{acc_summary}' , 'w') + failed=[] + mydiff=[] + simRD=0 + sim=0 + mysum=0 + + model.eval() + criterion.eval() + metric_logger = MetricLogger(delimiter=" ") + header = 'Infering:' + res_smiles = [] + idx_to_labels23={0:'other',1:'C',2:'O',3:'N',4:'Cl',5:'Br',6:'S',7:'F',8:'B', + 9:'I',10:'P',11:'*',12:'Si',13:'NONE',14:'BEGINWEDGE',15:'BEGINDASH', + 16:'=',17:'#',18:'-4',19:'-2',20:'-1',21:'1',22:'2',} + idx_to_labels30 = {0:'other',1:'C',2:'O',3:'N',4:'Cl',5:'Br',6:'S',7:'F',8:'B', + 9:'I',10:'P',11:'H',12:'Si',13:'NONE',14:'BEGINWEDGE',15:'BEGINDASH', + 16:'=',17:'#',18:'-4',19:'-2',20:'-1',21:'1',22:'2', + 23:'CF3',#NOTE rdkit get element 
not supporting group + 24:'CN', + 25:'Me', + 26:'CO2Et', + 27:'R', + 28:'Ph', + 29:'*', + } + bond_labels = [13,14,15,16,17] + + if postprocessors.num_classes==23: + # print(data["categories"]) + print(f'usage idx_to_labels23',idx_to_labels23) + idx_to_labels = idx_to_labels23 + elif postprocessors.num_classes==30: + # print(data["categories"])#NOTE 11 is H not * now + print(f'usage idx_to_labels30',idx_to_labels30) + idx_to_labels = idx_to_labels30 + else: + print(f"error unkown ways@@@@@@@@@@@!!!!!!!!!!idx_to_labels::{len(idx_to_labels)}\n{idx_to_labels}") + abrevie={"[23*]":'CF3', + "[24*]":'CN', + "[25*]":'Me', + "[26*]":'CO2Et', + "[27*]":'R', + "[28*]":'Ph', + "[29*]":'3~7UP', + } + # idx_to_labels = idx_to_labels23 + lab2idx={ v:k for k,v in idx_to_labels.items() } + + smiles_data = pd.DataFrame({'file_name': [], + 'SMILESori':[], + 'SMILESpre':[], + 'SMILESexp':[], + } + ) + output_dict = {} + output_ori={} + filtered_output_dict = {} + box_thresh=0.1 + # for samples, targets in metric_logger.log_every(data_loader, 10, header): + # samples = samples.to(device) + # # targets = [{k: v.to(device) for k, v in t.items()} for t in targets] + # outputs = model(samples) + # # orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)#.to(device) + # orig_target_sizes = targets["orig_size"].to(device) + # results = postprocessors(outputs, orig_target_sizes)#RTDETRPostProcessor@@src/zoo/rtertr + # for i_, z in enumerate(zip(targets['image_id'], results)): + # ti, output=z + # output_dict[ti.item()] = [ + # output, + # targets['img_path'][i_], + # targets['SMILES'][i_], + # ] + + # output_ori[ti.item()] =[ + # targets['img_path'][i_], + # targets['SMILES'][i_], + # ] + # print(len(output_ori),len(output_dict)) + for samples, targets in metric_logger.log_every(data_loader, 10, header): + # orig_target_sizes = targets["orig_size"].to(device) + for i_, ti in enumerate(targets['image_id']): + output_dict[ti.item()] = [ + targets['img_path'][i_], + 
targets['SMILES'][i_], + ] + + + for key, value in output_dict.items(): + + image_path = value[0] + SMILESori = value[1] + + # selected_indices = value['scores'] > 0.5#may be >=0.5 cut off, as used the sigmoid? + # selected_indices = value[0]['scores'] > box_thresh + # true_count = selected_indices.sum().item() + #testing here + image_path='/cadd_data/samba_share/from_docker/data/work_space/ori/real/acs/ol020229e-Scheme-c3-10.png' + + tensor,w,h = image_to_tensor(image_path) + tensor=tensor.unsqueeze(0).to(device) + print(tensor.size()) # Output tensor shape (C x H x W) + ori_size=torch.Tensor([w,h]).long().unsqueeze(0).to(device) + outputs = model(tensor) + result_ = postprocessor2(outputs, ori_size) + # result_ = postprocessors(outputs, ori_size) + score_=result_[0]['scores'] + boxe_=result_[0]['boxes'] + label_=result_[0]['labels'] + #---------------------------------################################ + selected_indices =score_ > box_thresh + true_count = selected_indices.sum().item() + output={ + 'labels': label_[selected_indices].to("cpu").numpy(), + 'boxes': boxe_[selected_indices].to("cpu").numpy(), + 'scores': score_[selected_indices].to("cpu").numpy() + } + + img_ori = Image.open(image_path).convert('RGB') + w_ori, h_ori = img_ori.size # 获取原始图像的尺寸 + print(w_ori, h_ori, "orignianl vs 1000,1000") + + print(f"selected_indices 中 True 的数量: {true_count}") + print(f"before nms_per_class, :: box 的数量:{len(output['labels'])}") + output = nms_per_class(output['labels'], output['boxes'], output['scores'], iou_thresh=0.5) + print(f"after nms_per_class, :: box 的数量:{len(output['labels'])}") + + + # filtered_output_dict={image_path: output} + x_center = (output["boxes"][:, 0] + output["boxes"][:, 2]) / 2 + y_center = (output["boxes"][:, 1] + output["boxes"][:, 3]) / 2 + # center_coords = torch.stack((x_center, y_center), dim=1) + center_coords = np.stack((x_center, y_center), axis=1) + # center_coords=np.stack((x_center, y_center)).reshape(-1,2)#NOTE not do this, mix 
element order shits + #TODO split atom_charge \ bond drawing + output = {'bbox': output["boxes"],#.to("cpu").numpy(), + 'bbox_centers': center_coords,#.to("cpu").numpy(), + 'scores': output["scores"],#.to("cpu").numpy(), + 'pred_classes': output["labels"],#.to("cpu").numpy() + } + ############################################################################################################################ + img_ori = Image.open(image_path).convert('RGB') + w_ori, h_ori = img_ori.size # 获取原始图像的尺寸 + print(w_ori, h_ori, "orignianl vs 1000,1000") + # 计算缩放比例 + scale_x = 1000 / w_ori + scale_y = 1000 / h_ori + img_ori_1k = img_ori.resize((1000,1000)) + img = Image.open(image_path).convert('RGB') + img = img.resize((1000,1000)) + # atom_bondBox_check=True + + print(f"from oupt socore > {box_thresh},get box {len(output['bbox'])} after nms_per_class ") + # split into atom bond charge nms, then mergd , then box2 mol NOTE charege and bond confidence at least >10% + charge_mask = np.array([True if ins in charge_labels and output['scores'][i]>0.1 else False for i, ins in enumerate(output['pred_classes'])]) + charges_bbox=output['bbox'][charge_mask] + charges_centers= output['bbox_centers'][charge_mask] + charges_classes= output['pred_classes'][charge_mask] + charges_scores= output['scores'][charge_mask] + charges_bbox, charges_centers, charges_scores,charges_classes,figc =view_box_center2(charges_bbox, charges_centers, charges_scores,charges_classes, overlap_dist_thresh=5.0, max_centers_per_box=5) + #view_box_center2 help remove large box if boxscore small than 0.5 + # bonds_mask2 = np.array([True if ins in bond_labels else False for ins in output['pred_classes']]) + # bonds_mask= output['scores'][bonds_mask2]>=0.1# TODO fix me, as training bond box overlap with bondbox,aussme bond socre make sense + bonds_mask = np.array([True if ins in bond_labels and output['scores'][i]>0.2 else False for i, ins in enumerate(output['pred_classes'])]) + bond_bbox=output['bbox'][bonds_mask] + 
bond_centers= output['bbox_centers'][bonds_mask] + bond_classes= output['pred_classes'][bonds_mask] + bond_scores= output['scores'][bonds_mask] + # bond_bbox2, bond_centers2, bond_scores2,bond_classes2,fig=view_box_center2(bond_bbox, bond_centers, bond_scores,bond_classes, overlap_dist_thresh=5.0, max_centers_per_box=5) + bond_bbox, bond_centers, bond_scores,bond_classes,fig =view_box_center2(bond_bbox, bond_centers, bond_scores,bond_classes, overlap_dist_thresh=5.0, max_centers_per_box=3) + bond_bbox, bond_classes, bond_scores = nms(bond_bbox, bond_scores,bond_classes, iou_threshold=0.5) + + heavy_mask= np.array([True if ins not in bond_labels and ins not in charge_labels and ins != lab2idx['H'] else False for ins in output['pred_classes']]) + h_mask= np.array([True if ins not in bond_labels and ins not in charge_labels and ins == lab2idx['H'] else False for ins in output['pred_classes']]) + + #TODO fix me if heavy or H all need this view_box_center2 filtering + heavy_bbox = output['bbox'][heavy_mask] + heavy_classes = output['pred_classes'][heavy_mask] + heavy_centers= output['bbox_centers'][heavy_mask] + heavy_scores= output['scores'][heavy_mask] + heavy_bbox, heavy_centers, heavy_scores,heavy_classes,fighv =view_box_center2(heavy_bbox, heavy_centers, heavy_scores,heavy_classes, overlap_dist_thresh=5.0, max_centers_per_box=5) + + #TODO del isolated C without bond box overlap + delt_hei=[] + for hei,hebox in enumerate(heavy_bbox): + he_class=idx_to_labels[heavy_classes[hei]] + b_nei=[] + if he_class in ['C']:#TODO add other cases + for bi,bb in enumerate(bond_bbox): + overlap_flag=boxes_overlap(bb, hebox)#TODO use tghe atom bond box overlap get bond atom mapping,then built mol + if overlap_flag: + b_nei.append(bi) + if len(b_nei)==0: + delt_hei.append(hei) + n = len(heavy_scores) # 更新数量 + keep_boxes = np.ones(n, dtype=bool) + keep_boxes[delt_hei]=False + heavy_bbox, heavy_centers, heavy_scores,heavy_classes=heavy_bbox[keep_boxes], heavy_centers[keep_boxes], 
heavy_scores[keep_boxes],heavy_classes[keep_boxes] + + h_bbox = output['bbox'][h_mask] + h_centers= output['bbox_centers'][h_mask] + h_classes= output['pred_classes'][h_mask] + h_scores= output['scores'][h_mask] + h_bbox, h_centers, h_scores,h_classes,figh =view_box_center2(h_bbox, h_centers, h_scores,h_classes, overlap_dist_thresh=5.0, max_centers_per_box=5) + + #NOTE need keep the order heavy atom first then following with Hs + # atoms_mask = np.array([True if ins not in bond_labels and ins not in charge_labels else False for ins in output['pred_classes']]) + # atom_bbox=output['bbox'][atoms_mask] + # atom_classes=output['pred_classes'][atoms_mask] + # 合并 bbox,保持重原子在前,氢原子在后 + atom_bbox = np.concatenate([heavy_bbox, h_bbox], axis=0) + atom_classes = np.concatenate([heavy_classes, h_classes], axis=0) + # atom_centers = np.concatenate([heavy_centers, h_centers], axis=0) + atom_scores = np.concatenate([heavy_scores, h_scores], axis=0) + #TODO nms checking + # kept_bboxes, kept_classes, kept_scores=nms(atom_bbox, atom_scores, atom_classes, iou_threshold=0.5) + # # kept_bboxes, kept_classes, kept_scores=nms_atomBox(atom_bbox, atom_scores, atom_classes, iou_threshold=0.5) + # merged_bboxes, merged_classes, merged_scores = merge_low_iou_boxes(kept_bboxes, kept_classes, kept_scores, merge_threshold=0.5, score_threshold=0.7) + # print(f'ater nms kept_box {len(kept_bboxes)}, followd merge_low_iou_boxes kept_box:: {len(merged_bboxes)}') + # atom_bbox, atom_classes, atom_scores=merged_bboxes, merged_classes, merged_scores + atom_bbox, atom_scores, atom_classes = refine_boxes(atom_bbox, atom_scores, atom_classes, bond_bbox) + + + x_center = (atom_bbox[:, 0] + atom_bbox[:, 2]) / 2 + y_center = (atom_bbox[:, 1] + atom_bbox[:, 3]) / 2 + # center_coords = torch.stack((x_center, y_center), dim=1) + center_coords = np.stack((x_center, y_center), axis=1) + atom_centers=center_coords + + print(f"before NMS :: heavy box {len(heavy_bbox)} ---- H box {len(h_bbox)}---bond 
box{len(bond_bbox)}") + print(f"after NMS+view_box_center2 :: atom box {len(atom_bbox)} bond box {len(bond_bbox)} charge box {len(charges_bbox)} ") + # print(f"bond box with only single atom box overlap:: {single_odd_bi}") + print(f"atom box afte NMS and merge_low_iou_boxes") + print(f"get box {len(output['bbox'])} with NMS") + print(f"atom score >0.1 bond score >0.2, then folllowed with NMS") + print(f"bond_bbox nums::",bond_bbox.shape,len(bond_bbox)) + print(f" OCR will start involved ")# + #check if ODD single-bonds with only one atom exisits, try add the atoms box for this bond + new_atoms, single_odd_b2a= Newbox_(atom_bbox,bond_bbox, lab2idx ) + print(f"new_atoms number {len(new_atoms)}\n{new_atoms}") + if len(new_atoms)>0: + for boxout in new_atoms: + for k,arr in boxout.items(): + value_or_row=output[k] + if arr.ndim == 1: + output[k]=np.append(value_or_row, arr) + elif arr.ndim >= 2: + output[k] = np.concatenate([value_or_row, arr], axis=0) + else: + print('errprs, unkown conditions !!!@') + #NOTE try to use OCR to help postprocess box adding and del + # 加载图像 OCR + image = cv2.imread(image_path) + image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + # 预处理图像突出下标 + gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) + _, thresh = cv2.threshold(gray, 150, 255, cv2.THRESH_BINARY_INV) + # print(_, thresh) + kernel = np.ones((2, 2), np.uint8) + dilated = cv2.dilate(thresh, kernel, iterations=1) + # cv2.imwrite("preprocessed.jpg", dilated)#NOTE comment if need checking + # result = ocr.ocr("preprocessed.jpg", cls=True) + # ocr.ocr(image_npocr, cls=True, det=False) + result = ocr.ocr(dilated, cls=True) # 直接传递 NumPy 数组 + # 解析结果 + text_boxes = [] + text_contents = [] + confidences = [] + for line in result: + print(line) + if line: + for box_info in line: + box = box_info[0] + x_coords = [point[0] for point in box] + y_coords = [point[1] for point in box] + text_box = [min(x_coords), min(y_coords), max(x_coords), max(y_coords)] + text = box_info[1][0] + 
text_boxes.append(text_box) + text_contents.append(text) + confidences.append(box_info[1][1]) + print("Detected text boxes:", text_boxes) + print("Detected text contents:", text_contents) + print("Confidences:", confidences) + #after whole img OCRed + # Initialize dictionaries and lists + ai2text = {} + ai2relplace = {} + ai2rdkitlab_unknown = {} + non_overlapping_texts = [] + # Build initial KDTree + tree = cKDTree(atom_centers) + # Collect indices to delete after the loop to keep tree valid during processing + indices_to_delete = set() + # Process each OCR text box + for ti, text_box in enumerate(text_boxes): + text_center = calculate_center(text_box) + ocr_text = text_contents[ti] + + # Normalize OCR text + if ocr_text in ['OH', 'HO']: + ocr_text = 'O' + elif ocr_text in ['SH', 'HS']: + ocr_text = 'S' + elif ocr_text in ['NH', 'HN']: + ocr_text = 'N' + elif ocr_text in ['CH', 'HC']: + ocr_text = 'C' + elif ocr_text == '0': + ocr_text = 'O' + elif ocr_text == 'L': + ocr_text = 'Li' + elif ocr_text[-1]=='-': + if ocr_text[:-1] in ABBREVIATIONS: + ocr_text=ocr_text[:-1] + + # Find all overlapping atom boxes + overlapping_indices = [] + for idx in range(len(atom_bbox)): + if idx not in indices_to_delete and boxes_overlap(atom_bbox[idx], text_box): + overlapping_indices.append(idx) + + if overlapping_indices: + # If there are overlapping atom boxes, merge them + if len(overlapping_indices) > 1: + # Get the smallest box encompassing all overlapping atom boxes + overlapping_boxes = [atom_bbox[idx] for idx in overlapping_indices] + merged_box = get_merged_box(overlapping_boxes) + overlapping_indices_atomboxclass=[idx_to_labels[atom_classes[i]] for i in overlapping_indices] + print(f"Merging {len(overlapping_indices)} atom boxes overlapping with OCR text: {ocr_text}") + print(f" {overlapping_indices} boxes type{overlapping_indices_atomboxclass} merged as OCR text: {ocr_text}") + merged_area = box_area(merged_box) + text_area = box_area(text_box) + final_box = merged_box 
if merged_area >= text_area else text_box + else: + # If only one overlap, use the text box directly + final_box = text_box + # Use the OCR text box as the merged box + primary_idx = overlapping_indices[0] + # atom_bbox[primary_idx] = text_box + + # Update the primary atom box + atom_bbox[primary_idx] = final_box + # Update class and dictionaries based on OCR text + if ocr_text in ABBREVIATIONS: + ai2relplace[primary_idx] = ocr_text + atom_classes[primary_idx] = 0 + if ocr_text in lab2idx: + atom_classes[primary_idx] = lab2idx[ocr_text] + elif ocr_text in ['H', 'C', 'O', 'N', 'Cl', 'Br', 'S', 'F', 'B', 'I', 'P', 'Si']: + atom_classes[primary_idx] = lab2idx[ocr_text] + elif ocr_text in RGROUP_SYMBOLS or (ocr_text[0] == 'R' and ocr_text[1:].isdigit()): + atom_classes[primary_idx] = 0 + else: + ai2rdkitlab_unknown[primary_idx] = ocr_text + atom_classes[primary_idx] = 0 + + ai2text[primary_idx] = ocr_text + + # Mark redundant indices for deletion + indices_to_delete.update(overlapping_indices[1:]) + + else: + # No overlap: record the text box and nearest atom index + distance, nearest_idx = tree.query(text_center) + if nearest_idx not in indices_to_delete: # Only record if nearest_idx is still valid + print(f"No overlap for OCR text '{ocr_text}', nearest atom box index: {nearest_idx}") + non_overlapping_texts.append({ + 'text': ocr_text, + 'text_box': text_box, + 'nearest_atom_idx': nearest_idx, + 'distance': distance + }) + + #set up atom_ocr match atom_class + atom_ocr=[] + for i,ai in enumerate(atom_classes): + if i in ai2text: + atom_ocr.append(ai2text[i]) + # elif i in ai2rdkitlab_unknown: + # atom_ocr.append(ai2rdkitlab_unknown[i]) + else: + atom_ocr.append(idx_to_labels[ai]) + print(f"atom class + ocr presented as symbols::\n{atom_ocr}") + atom_ocr=np.array(atom_ocr) + # Perform deletions after the loop + if indices_to_delete: + indices_to_keep = np.setdiff1d(np.arange(len(atom_bbox)), list(indices_to_delete)) + atom_bbox = atom_bbox[indices_to_keep] + 
atom_classes = atom_classes[indices_to_keep] + atom_centers = atom_centers[indices_to_keep] + atom_scores = atom_scores[indices_to_keep] + atom_ocr= atom_ocr[indices_to_keep] + + # Adjust dictionary indices + for d in [ai2text, ai2relplace, ai2rdkitlab_unknown]: + d_new = {} + for old_idx, value in d.items(): + new_idx = np.where(indices_to_keep == old_idx)[0][0] if old_idx in indices_to_keep else None + if new_idx is not None: + d_new[new_idx] = value + d.clear() + d.update(d_new) + + # Adjust nearest_atom_idx in non_overlapping_texts + for entry in non_overlapping_texts: + old_idx = entry['nearest_atom_idx'] + if old_idx in indices_to_keep: + entry['nearest_atom_idx'] = np.where(indices_to_keep == old_idx)[0][0] + else: + entry['nearest_atom_idx'] = -1 # Mark as invalid if the nearest atom was deleted + + # Rebuild KDTree if needed for further use + tree = cKDTree(atom_centers) + + # Final output + print("Whole img with OCR :: ai2relplace, ai2rdkitlab_unknown:", [ai2relplace, ai2rdkitlab_unknown]) + print(f"Adjusted ai ocr_text: {ai2text}") + print(f"Atom box num: {len(atom_bbox)}:: {[idx_to_labels[i] for i in atom_classes]}") + print("Non-overlapping OCR text boxes:", non_overlapping_texts) + + #for all heavy atom labels, consider N3 pred as N, or other cases, I2M not good as paddle on ABC + atomcorp_img = Image.open(image_path).convert('RGB') + atomcorp_img1k=atomcorp_img.resize([1000,1000]) + text_contents_star=[] + text_confidences_star=[] + text_boxes_star=[] + boxid2del=dict() + ocr_discrepancies = {} # New dictionary to record OCR vs. AI mismatches + print(f"has atom_bbox number {len(atom_bbox)}") + for i,box in enumerate(atom_bbox):#split ocr image + # if i in ai2text: continue #may be need comment this, if splited OCR acc better!! 
+ abox =box* [scale_x, scale_y, scale_x, scale_y] + cropped_img=atomcorp_img1k.crop(abox)#if use the small ori image will not get infos + image_npocr = np.array(cropped_img) + result_ocr= ocr2.ocr(image_npocr, det=False)#,cls=True,use_debug=False, det=False)#det fale not box but get rcongized more + # result_ocr= ocr.ocr(image_npocr, cls=True, det=False)#,cls=True, det=False)#det fale not box but get rcongized more + if result_ocr: + for line in result_ocr: + # print(f"Atom box--- {i}, OCR result---: {line}") + if line: + box_flag=has_boxes(line) + for box_info in line: + # print(len(box_info)) + if not box_flag: + text=box_info[0] + #[^a-zA-Z0-9\*\-\+] 表示匹配除了字母、数字、*、- 和 + 之外的所有字符。 + text=re.sub(r'[^a-zA-Z0-9,\*\-\+]', '', text)#remove special chars + score_=box_info[1] + text_contents_star.append(text) + text_confidences_star.append(score_) + else:#when paddleOCRuse detection model get text box info + box = box_info[0] + x_coords = [point[0] for point in box] + y_coords = [point[1] for point in box] + text_box = [min(x_coords), min(y_coords), max(x_coords), max(y_coords)] + text = box_info[1][0] + text=re.sub(r'[^a-zA-Z0-9,\*\-\+]', '', text)#remove special chars + text_boxes_star.append(text_box) + text_contents_star.append(text) + score_=box_info[1][1] + text_confidences_star.append(score_) + if i in ai2text:#ocr 全img vs split img + # print(f'from whole img ocr atom box {i}----from whole img::{ai2text[i]}') + if ai2text[i] != text: + text=ai2text[i] if len(ai2text[i])>=len(text) else text + print(f"Atom box {i}@@ OCR text: {text}, score: {score_}, AI class: {idx_to_labels[atom_classes[i]]}, AI score: {atom_scores[i]}") + # Normalize OCR text + if text in ['OH', 'HO']: + text = 'O' + elif text in ['SH', 'HS']: + text = 'S' + elif text in ['NH', 'HN']: + text = 'N' + elif text in ['CH', 'HC']: + text = 'C' + elif text == '0': + text = 'O' + elif text == 'L': + text = 'Li' + elif '-' in text: + if text[:-1] in ABBREVIATIONS: + text=text[:-1] + + # Check if OCR text 
is a single character and not a valid element + is_single_char = len(text) == 1 + ai_pred = idx_to_labels[atom_classes[i]] + #TOD add more simpfiled + if text=='0': + atom_classes[i]=lab2idx['O'] + elif text in ['H', 'C', 'O', 'N', 'Cl', 'Br', 'S', 'F', 'B', 'I', 'P', 'Si']: + atom_classes[i]=lab2idx[text]#need update to keep H following Heavy + # elif # ocr recongnized on lable C as other things chars + elif is_single_char and text not in ELEMENTS and ai_pred == 'C': + # Do not replace AI prediction, just record discrepancy + ocr_discrepancies[i] = { + 'ocr_text': text, + 'ocr_score': score_, + 'ai_class': ai_pred, + 'ai_score': atom_scores[i] + } + else: + overlap, b_nei=AtomBox2bondBox(atom_bbox[i],bond_bbox) + if not overlap: + if text not in ELEMENTS and text not in ABBREVIATIONS: + # print(f"new cases::{text} for atombox {i} {atom_bbox[i]}check how to fix it !!!") + # print(f'OCR text:: {text} score ::{box_info}||atom clss::{idx_to_labels[atom_classes[i]]} {atom_scores[i]}') + if text != idx_to_labels[atom_classes[i]]: + boxid2del[i]= [text,idx_to_labels[atom_classes[i]]]#will delt this atom box infos + else: + if text != idx_to_labels[atom_classes[i]]: + if atom_scores[i]<=score_: + if text in RGROUP_SYMBOLS or text in ABBREVIATIONS: + ai2relplace[i]=text + atom_classes[i]=0 + if text in lab2idx and lab2idx[text] in list(range(23,29)):atom_classes[i]=lab2idx[text] + elif text in ['H', 'C', 'O', 'N', 'Cl', 'Br', 'S', 'F', 'B', 'I', 'P', 'Si']: + atom_classes[i]=lab2idx[text] + else: + ai2relplace[i]=text + atom_classes[i]=0 + + # 按照 value 的第一个元素(假设是字符串)的长度进行排序,长度大的排前 + boxid2del = dict(sorted(boxid2del.items(), key=lambda item: item[0], reverse=True)) + print(f"considering del box",boxid2del) + print("after split img OCR:: ai2relplace,ai2rdkitlab_unknown",[ai2relplace,ai2rdkitlab_unknown]) + print(f"considering delet atomb box :{boxid2del}") + syms=[] + for i in range(len(atom_classes)): + if i in ai2relplace: syms.append(ai2relplace[i]) + elif i in 
ai2rdkitlab_unknown:syms.append(ai2rdkitlab_unknown[i]) + else: + syms.append(idx_to_labels[atom_classes[i]]) + print(f"atombox {atom_classes}:: number {len(atom_classes)}\n",[idx_to_labels[i] for i in atom_classes]) + print(f" {syms}") + #chedck isolated box, if need add bond box between the isolated box or not + isolated_ais = [] + # 第一步:构建 bond 到 atom 的映射,并计算 distance_threshold + bond_distances = [] + singleAtomBond=dict() + for bi, bb in enumerate(bond_bbox): + overlapped_atoms = [] + overlapped_abox = [] + for ai, aa in enumerate(atom_bbox): + overlap_flag = boxes_overlap(bb, aa) + if overlap_flag: + overlapped_atoms.append(ai) + overlapped_abox.append(aa) + # if bi not in b2a.keys(): + # b2a[bi] = [ai] + # else: + # b2a[bi].append(ai) + if len(overlapped_atoms) == 2: + center1 = calculate_center(atom_bbox[overlapped_atoms[0]]) + center2 = calculate_center(atom_bbox[overlapped_atoms[1]]) + distance = np.linalg.norm(center1 - center2) + bond_distances.append(distance) + # print(f"Bond {bi} connects atoms {overlapped_atoms}, distance: {distance:.2f}") + elif len(overlapped_atoms) == 1: + print(f"single bond - atom still exists for bond {bi}, need porcess this !!") + if bi not in singleAtomBond: + singleAtomBond[bi]=overlapped_atoms#considering use the add H box for solve TODO + + # 动态计算 distance_threshold + distance_threshold = max(bond_distances) if bond_distances else 100.0 # 默认值 10 如果无 bond + distance_threshold_min = min(bond_distances) if bond_distances else 100.0 # 默认值 10 如果无 bond + print(f"Calculated distance_threshold center based: {distance_threshold:.2f}") + + # 第二步:构建 atom 到 bond 的映射,并检测孤立原子 + a2b=dict() + for ai, aa in enumerate(atom_bbox): + b_nei = [] + for bi, bb in enumerate(bond_bbox): + overlap_flag = boxes_overlap(bb, aa) + if overlap_flag: + b_nei.append(bi) + a2b[ai] = b_nei + if a2b[ai] ==[]: + if ai not in isolated_ais: + isolated_ais.append(ai) + + isolated_ais=sorted(isolated_ais,reverse=True)#avoid delte atom with index errors + 
print(f"isolated_ais atom box {isolated_ais}\n ", [idx_to_labels[i] for i in atom_classes[isolated_ais]]) + + # 第三步:处理孤立原子,尝试合并或删除 + updated_atom_bbox = atom_bbox.copy() + updated_atom_classes = atom_classes.copy() + updated_atom_scores = atom_scores.copy() + print(f"atom bbox num {len(atom_bbox)}")#ttt + new_bond_bbox=[] + deleted_ais=[] + del4boxid2del=set() + for isolated_ai in isolated_ais: + isolated_box = atom_bbox[isolated_ai] + isolated_center = calculate_center(isolated_box) + nearest_distance = float('inf') + nearest_ai = -1 + # 找到最近的非孤立原子 + for ai, aa in enumerate(atom_bbox): + if ai not in isolated_ais and ai != isolated_ai: + center = calculate_center(aa) + distance = np.linalg.norm(isolated_center - center) + if distance < nearest_distance: + nearest_distance = distance + nearest_ai = ai + # 合并或删除逻辑 + if nearest_ai != -1: + if nearest_distance<=distance_threshold_min or (nearest_distance <=distance_threshold and nearest_distance>=distance_threshold_min):#this the centers dist not bond length + nearest_box = atom_bbox[nearest_ai] + nearest_class = atom_classes[nearest_ai] + nearest_center = calculate_center(nearest_box) + if isolated_ai in boxid2del: + textocr2del=boxid2del[isolated_ai][0] + else: + textocr2del=None + #NOTE based ont the class and ovelap bond box to adjust + overlap1,bondnei=AtomBox2bondBox(nearest_box,bond_bbox) + if len(bondnei)==1:#could be add two other bond, add bond box + # if textocr2del in [',', '+', '-'] or not any(c.isupper() for c in textocr2del): + if textocr2del is not None and not any(c.isupper() for c in textocr2del): + # del4boxid2del.add(isolated_ai) + deleted_ais.append(isolated_ai) + pass + else: + new_bc = (isolated_center + nearest_center)*0.5 + new_bondbox=np.array([new_bc[0] - nearest_distance*0.5, + new_bc[1] - nearest_distance*0.5, + new_bc[0] + nearest_distance*0.5, + new_bc[1] + nearest_distance*0.5] + ) + new_bond_bbox.append(new_bondbox.reshape(-1,4)) + print(f'add a new bond box new_bc for two atom boxes 
{isolated_ai} ---- {nearest_ai}::\n {idx_to_labels[atom_classes[isolated_ai]]} --- {idx_to_labels[atom_classes[nearest_ai]]}') + else:#TODO fix me when get the case with >=2 bonds need add bond also + try: + new_box = merge_boxes(isolated_box, nearest_box) + updated_atom_bbox[nearest_ai] = new_box + chosed_score_ = max(atom_scores[isolated_ai], atom_scores[nearest_ai]) + updated_atom_scores[nearest_ai] = chosed_score_ + except Exception as e: + print(f"file_name@: {image_path}\n SMILES in csv:\n{SMILESori}") + print(e) + print('nearest_ai ', nearest_ai) + check2=True + if check2: + padding=5 + # box_thresh=0.3 + atombox_img=draw_objs(copy.deepcopy(img), + atom_bbox* [scale_x, scale_y, scale_x, scale_y], + atom_classes, atom_scores , + category_index=idx_to_labels, + box_thresh=box_thresh, + line_thickness=3, + font='arial.ttf', + font_size=10) + bonbox_img=draw_objs(copy.deepcopy(img), + bond_bbox* [scale_x, scale_y, scale_x, scale_y], + bond_classes, bond_scores , + category_index=idx_to_labels, + box_thresh=0.01, + line_thickness=3, + font='arial.ttf', + font_size=10) + # Get sizes of the individual images + atom_width, atom_height = atombox_img.size + bon_width, bon_height = bonbox_img.size + combined_width = atom_width + bon_width + padding * 3 + combined_height = max(atom_height, bon_height) + padding * 2 + combined_img = Image.new('RGB', (combined_width, combined_height), color=(255, 255, 255)) # White background + # Paste the images onto the new canvas + combined_img.paste(atombox_img, (padding, padding)) # Top-left + combined_img.paste(bonbox_img, (atom_width + padding * 2, padding)) + print(f"atom box afte NMS and merge_low_iou_boxes") + combined_img.save(f"tttttttttttttttttttttttBoxed.png" + ) + raise Exception("@debug this!!\n") + + if chosed_score_>=0.5: + if chosed_score_==atom_scores[isolated_ai]: + updated_atom_classes[nearest_ai] = 0 # mrege replaced with * + # else: + # updated_atom_classes[nearest_ai] = atom_classes[nearest_ai] # 保留较高 score 的类别 + 
updated_atom_bbox = np.delete(updated_atom_bbox, isolated_ai, axis=0)#after mreged need del it + # updated_atom_bbox = np.delete(updated_atom_bbox, isolated_ai, axis=0) + updated_atom_classes = np.delete(updated_atom_classes, isolated_ai) + updated_atom_scores = np.delete(updated_atom_scores, isolated_ai) + print(f"Merged atom {isolated_ai} into {nearest_ai}, new box: {new_box}") + isolated_ais.remove(isolated_ai) + deleted_ais.append(isolated_ai) + # elif nearest_distance<=distance_threshold_min:#very close,mrege with nearest one + elif atom_scores[isolated_ai] < 0.5: + # 删除低分孤立原子 + updated_atom_bbox = np.delete(updated_atom_bbox, isolated_ai, axis=0) + updated_atom_classes = np.delete(updated_atom_classes, isolated_ai) + updated_atom_scores = np.delete(updated_atom_scores, isolated_ai) + print(f"DELET isolated atom {isolated_ai} with score {atom_scores[isolated_ai]}") + deleted_ais.append(isolated_ai) + # 更新索引,因为数组维度变化 + isolated_ais = [i if i < isolated_ai else i - 1 for i in isolated_ais if i != isolated_ai] + else: + print(f"KEEP isolated atom {isolated_ai} with score {atom_scores[isolated_ai]} >= 0.5") + + + else: + if atom_scores[isolated_ai] < 0.5: + updated_atom_bbox = np.delete(updated_atom_bbox, isolated_ai, axis=0) + updated_atom_classes = np.delete(updated_atom_classes, isolated_ai) + updated_atom_scores = np.delete(updated_atom_scores, isolated_ai) + print(f"DELET isolated atom {isolated_ai} with score {atom_scores[isolated_ai]}") + deleted_ais.append(isolated_ai) + isolated_ais = [i if i < isolated_ai else i - 1 for i in isolated_ais if i != isolated_ai] + else: + print(f"KEEP isolated atom {isolated_ai} with score {atom_scores[isolated_ai]} >= 0.5") + + if len(new_bond_bbox)>0: + for i,bond_box in enumerate(new_bond_bbox): + bond_bbox= np.concatenate([bond_bbox,bond_box],axis=0) + bond_scores= np.concatenate((bond_scores,np.array([0.9])),axis=0) + bond_classes= np.concatenate([bond_classes,np.array([13])],axis=0) + #reset bond center + x_center = 
(bond_bbox[:, 0] + bond_bbox[:, 2]) / 2 + y_center = (bond_bbox[:, 1] + bond_bbox[:, 3]) / 2 + # center_coords = torch.stack((x_center, y_center), dim=1) + center_coords = np.stack((x_center, y_center), axis=1) + bond_centers=center_coords + + #del the additional atom box that not connected by bond box also mismatch other rules + if len(deleted_ais) > 0: # 如果有需要删除的索引 + print(f"will delete atom box with idx :: {deleted_ais}") + # 使用 np.delete 一次性删除所有指定的行 + atom_classes = np.delete(atom_classes, deleted_ais, axis=0) + atom_scores = np.delete(atom_scores, deleted_ais, axis=0) + atom_bbox = np.delete(atom_bbox, deleted_ais, axis=0) + atom_ocr = np.delete(atom_ocr, deleted_ais, axis=0) + + # eles=[idx_to_labels[i] for i in atom_classes] + # print(eles,len(eles)) + cur_atomSymbols=[idx_to_labels[i] for i in atom_classes] + ocr_wholeImg=[] + for i in atom_classes: + if i in ai2relplace: + ocr_wholeImg.append(ai2relplace[i]) + elif i in ai2rdkitlab_unknown: + ocr_wholeImg.append(ai2rdkitlab_unknown[i]) + else: + ocr_wholeImg.append(idx_to_labels[i]) + print("ai2relplace,ai2rdkitlab_unknown",ai2relplace,ai2rdkitlab_unknown) + print("cur_atomSymbols:",cur_atomSymbols) + print(" atomSymbolsOCR:",ocr_wholeImg) + + # 找到 'H' 的索引, H after Heavy + h_indices = np.where(atom_classes == lab2idx['H'])[0] + non_h_indices = np.where(atom_classes != lab2idx['H'])[0] + # print(h_indices,non_h_indices) + # 重新排序 + new_order = np.concatenate((non_h_indices, h_indices)).astype(np.int64) + # newid2old_Hafter={ i:j for i,j in enumerate(new_order)} + # old2newid_Hafter={ j:i for i,j in enumerate(new_order)} + atom_classes = atom_classes[new_order] + atom_bbox = atom_bbox[new_order] + atom_scores = atom_scores[new_order] + x_center = (atom_bbox[:, 0] + atom_bbox[:, 2]) / 2 + y_center = (atom_bbox[:, 1] + atom_bbox[:, 3]) / 2 + # center_coords = torch.stack((x_center, y_center), dim=1) + center_coords = np.stack((x_center, y_center), axis=1) + atom_centers=center_coords#TODO 记得把 abbve idx label 
same reoder or mapping then bond + #bond box reoder like atom box, let the singleAtomBond later + bond_bbox = reorder_bond_bbox(bond_bbox, singleAtomBond) + bond_classes = reorder_bond_bbox(bond_classes, singleAtomBond) + bond_scores = reorder_bond_bbox(bond_scores, singleAtomBond) + bond_centers = reorder_bond_bbox(bond_centers, singleAtomBond) + + # 第二步:构建 atom 到 bond 的映射,并检测孤立原子 + a2b=dict() + for ai, aa in enumerate(atom_bbox): + b_nei = [] + for bi, bb in enumerate(bond_bbox): + overlap_flag = boxes_overlap(bb, aa) + if overlap_flag: + b_nei.append(bi) + a2b[ai] = b_nei + if a2b[ai] ==[]: + if ai not in isolated_ais: + isolated_ais.append(ai) + + b2a=dict() + for bi,bb in enumerate(bond_bbox): + overlapped_atoms = [] + overlapped_abox=[] + for ai,aa in enumerate(atom_bbox): + overlap_flag=boxes_overlap(bb, aa)#TODO use tghe atom bond box overlap get bond atom mapping,then built mol + if overlap_flag: + # print(bb, aa,overlap_flag) + overlapped_atoms.append(ai) + overlapped_abox.append(aa) + if bi not in b2a.keys(): + b2a[bi]=[ai] + else: + # vais=b2a[bi] + b2a[bi].append(ai) + if len(overlapped_atoms) == 1: + print(f"single bond -atom still exists {overlapped_atoms}") + + #c2a a2c + #charge atom idx maping + if len(charges_classes) > 0: + # print(charges_bbox,charges_classes,len(charges_classes)) + kdt = cKDTree(atom_centers) + atid_list=list(range(len(atom_centers))) + used_charge_indices=set() + c2a=dict() + for i, (x,y) in enumerate(charges_centers): + overlapped_abox=[] + cc=charges_bbox[i] + for ai, aa in enumerate(atom_bbox): + overlap_flag=boxes_overlap(cc, aa) + ac_iou=calculate_iou(cc, aa) + charge_=charges_classes[i] + charge_score=charges_scores[i] + if overlap_flag: + if i in c2a: + c2a[i].append(ai) + else: + c2a[i]=[ai] + if ai not in atid_list: + print(f"Warning: ai {ai} is out of range for atom_list.") + continue # 跳过当前循环迭代 + # idx_to_labels[charges_classes[0]] + a2c=dict() + for ci,v in c2a.items(): + charge_=idx_to_labels[charges_classes[ci]] 
+ if len(v)==1: + a2c[v[0]]=ci + else: + for ai in v: + ats=idx_to_labels[atom_classes[ai]] + if ats=='other': + ats='*' + if ats in ['F','Cl','I','Br','O'] and int(charge_)<0: + a2c[ai]=ci + elif ats in ['N','H','P'] and int(charge_)>0: + a2c[ai]=ci + else: + print(f'unusuaal case charge {charge_} with atom {ats}!!') + + print(f"all a2b b2a a2c c2a done, start mol built") + #finsh the update of box back to the output for retraining used + output={ + 'bbox': np.concatenate([atom_bbox, bond_bbox,charges_bbox], axis=0), + 'bbox_centers': np.concatenate([atom_centers, bond_centers,charges_centers],axis=0), + 'scores': np.concatenate([atom_scores, bond_scores, charges_scores],axis=0), + 'pred_classes': np.concatenate([atom_classes, bond_classes, charges_classes],axis=0), + 'image_path': image_path + } + # boxinfo + boxinfor={ + 'bbox': output['bbox'], + 'scores': output['scores'],#TODO use same vocabl ? + 'pred_classes': output['pred_classes'],#[ lab2idx[x] for x in output['pred_classes']],#changet it back to character + 'image_path': image_path + } + #split agin for buit mol + charge_mask = np.array([True if ins in charge_labels else False for ins in output['pred_classes']]) + charges_bbox=output['bbox'][charge_mask] + charges_centers=bbox2center(charges_bbox) + # charges_centers= output['bbox_centers'][charge_mask] + charges_classes= output['pred_classes'][charge_mask] + charges_scores= output['scores'][charge_mask] + charges_bbox, charges_centers, charges_scores,charges_classes,figc =view_box_center2(charges_bbox, charges_centers, charges_scores,charges_classes, overlap_dist_thresh=5.0, max_centers_per_box=5) + #view_box_center2 help remove large box if boxscore small than 0.5 + # bonds_mask2 = np.array([True if ins in bond_labels else False for ins in output['pred_classes']]) + # bonds_mask= output['scores'][bonds_mask2]>=0.1# TODO fix me, as training bond box overlap with bondbox,aussme bond socre make sense + bonds_mask = np.array([True if ins in bond_labels and 
output['scores'][i]>0.2 else False for i, ins in enumerate(output['pred_classes'])]) + bond_bbox=output['bbox'][bonds_mask] + bond_centers=bbox2center(bond_bbox) + # bond_centers= output['bbox_centers'][bonds_mask] + bond_classes= output['pred_classes'][bonds_mask] + bond_scores= output['scores'][bonds_mask] + print(f"before view_box_center2 bond nums {len(bond_scores)}") + # bond_bbox2, bond_centers2, bond_scores2,bond_classes2,fig=view_box_center2(bond_bbox, bond_centers, bond_scores,bond_classes, overlap_dist_thresh=5.0, max_centers_per_box=5) + bond_bbox, bond_centers, bond_scores,bond_classes,fig =view_box_center2(bond_bbox, bond_centers, bond_scores,bond_classes, overlap_dist_thresh=5.0, max_centers_per_box=3) + print(f"after view_box_center2 bond nums {len(bond_scores)}") + + heavy_mask= np.array([True if ins not in bond_labels and ins not in charge_labels and ins != lab2idx['H'] else False for ins in output['pred_classes']]) + h_mask= np.array([True if ins not in bond_labels and ins not in charge_labels and ins == lab2idx['H'] else False for ins in output['pred_classes']]) + + #TODO fix me if heavy or H all need this view_box_center2 filtering + heavy_bbox = output['bbox'][heavy_mask] + # heavy_classes = output['pred_classes'][heavy_mask] + heavy_centers=bbox2center(heavy_bbox) + # heavy_centers= output['bbox_centers'][heavy_mask] + heavy_scores= output['scores'][heavy_mask] + heavy_classes = output['pred_classes'][heavy_mask] + heavy_bbox, heavy_centers, heavy_scores,heavy_classes,fighv =view_box_center2(heavy_bbox, heavy_centers, heavy_scores,heavy_classes, overlap_dist_thresh=5.0, max_centers_per_box=5) + ###########################start build mol ########################## + rwmol_ = Chem.RWMol() + boxi2ai = {} # 预测索引 -> RDKit 索引 + placeholder_atoms=dict() + J=0 + for i, (bbox, a) in enumerate(zip(atom_bboxes, atom_classes)): + a2labl=False + a=replace_cg_notation(a) + # print(a,'atom box class label') + if a in ['H', 'C', 'O', 'N', 'Cl', 'Br', 'S', 
'F', 'B', 'I', 'P', 'Si']:# '*', I2M's defined atom types + # if a=='H':continue#skip H fristly,only with heavy atom then addH + ad = Chem.Atom(a)#TODO consider non chemical group and label for using + #TODO add pd rdkit known elemetns here + elif a in ELEMENTS: + ad = Chem.Atom(a) + elif a in ABBREVIATIONS : + ad = Chem.Atom("*") + placeholder_atoms[i] = a # 记录非标准原但有定义的官能团 类型及其位置, + a2labl=True + + else: + if N_C_H_expand(a): + ad = Chem.Atom("*") + placeholder_atoms[i] = a # 记录非标准原但有定义的官能团 类型及其位置, + a2labl=True + elif C_H_expand(a): + ad = Chem.Atom("*") + placeholder_atoms[i] = a # 记录非标准原但有定义的官能团 类型及其位置, + a2labl=True + elif C_H_expand2(a): + ad = Chem.Atom("*") + placeholder_atoms[i] = a # 记录非标准原但有定义的官能团 类型及其位置, + a2labl=True + elif formula_regex(a): + ad = Chem.Atom("*") + placeholder_atoms[i] = a # 记录非标准原但有定义的官能团 类型及其位置, + a2labl=True + else: + ad = Chem.Atom("*") + if a not in ['*',"other"]: + a2labl=True + # placeholder_atoms[idx] = a + # atom = Chem.Atom(symbol) + rwmol_.AddAtom(ad) + boxi2ai[J] = rwmol_.GetNumAtoms() - 1 + if a2labl: rwmol_.GetAtomWithIdx(J).SetProp("atomLabel", f"{a}")#mol set with label, mol_rebuild not + J+=1 + + # 使用 KDTree 构建重原子间的键(如果提供了 bond_bbox) + if len(charges_classes) > 0: + for k,v in a2c.items(): + fc=int(idx_to_labels[charges_classes[v]]) + rwmol_.GetAtomWithIdx(k).SetFormalCharge(fc) + # print(f"mol with heavy atoms number {i+1}, max heavy atom id {i}") + print(f"mol with atoms number {i+1}, max atom id {i}") + print(f"mol with bond box number {len(bond_classes)}") + print(f"placeholder_atoms@@ {placeholder_atoms}") + + #重原子 skeleton mol + bonds=dict() + existing_bonds = set() + b2aa=dict() + singleAtomBond=[] + bondWithdirct=[] + + # tree_heavy = KDTree(heavy_centers)#TODO before add bond consdiering reodering bond ?? 
+ tree_atom = KDTree(atom_centers)#TODO as atom bond are all reodered to kee H last + if len(idx_to_labels)==30: + _margin=0#ad this version bond dynamicaly changed + for bi, (bbox, idx_) in enumerate(zip(bond_bbox, bond_classes)):#not work for cross-bond, longer bond, as the center of bond may be close to as atoms not it two atoms + bond_type = idx_to_labels[idx_] + if len(idx_to_labels)==23: + if idx_to_labels[bond_type] in ['-','SINGLE', 'NONE', 'ENDUPRIGHT', 'BEGINWEDGE', 'BEGINDASH', 'ENDDOWNRIGHT']: + _margin = 5 + else: + _margin = 8 + anchor_positions = (bbox + [_margin, _margin, -_margin, -_margin]).reshape([2, -1]) + oposite_anchor_positions = anchor_positions.copy() + oposite_anchor_positions[:, 1] = oposite_anchor_positions[:, 1][::-1] + # Upper left, lower right, lower left, upper right + # x1y1, x2y2, x1y2, x2y1 : dinuogl lines + anchor_positions = np.concatenate([anchor_positions, oposite_anchor_positions]) + # print(f"anchor_positions {anchor_positions.shape}\n{anchor_positions}") + dists, neighbours = tree_atom.query(anchor_positions, k=1) + if np.argmin((dists[0] + dists[1], dists[2] + dists[3])) == 0: + # visualize setup + begin_idx, end_idx = neighbours[:2] + else: + # visualize setup + begin_idx, end_idx = neighbours[2:] + atom1_idx = boxi2ai[begin_idx] + atom2_idx = boxi2ai[end_idx] + if atom1_idx == atom2_idx:#NOTE when bond with only one terminal atom, other side H not used + print(f"attempt to add self-bond:{bi} atomIdx1 == atomIdx2 ::{[atom1_idx, atom2_idx]}") + print(f"for bond bi {bi} H atom may involbed dists:",dists) + print(neighbours) + print("anchor_positions",anchor_positions) + else: + if bond_type in ['-', 'NONE', 'ENDUPRIGHT', 'BEGINWEDGE', 'BEGINDASH', 'ENDDOWNRIGHT']: + if bond_type in BONDDIRECT: + bonds[bi] = (atom1_idx, atom2_idx, 'SINGLE', bond_type) + bondWithdirct.append(bi) + else: + bonds[bi] = (atom1_idx, atom2_idx, 'SINGLE', None) + bond_type=BONDTYPE['SINGLE'] + elif bond_type == '=': + bonds[bi] = (atom1_idx, 
atom2_idx, 'DOUBLE', None) + bond_type=BONDTYPE['DOUBLE'] + elif bond_type == '#': + bonds[bi] = (atom1_idx, atom2_idx, 'TRIPLE', None) + bond_type=BONDTYPE['TRIPLE'] + else: + print(f'unkown bond type relaced with single@@ {bond_type}') + bonds[bi] = (atom1_idx, atom2_idx, 'SINGLE', None) + bond_type=BONDTYPE['SINGLE'] + # 检查价态 + atom1 = rwmol_.GetAtomWithIdx(atom1_idx) + atom2 = rwmol_.GetAtomWithIdx(atom2_idx) + val1 = sum(b.GetBondTypeAsDouble() for b in atom1.GetBonds()) + val2 = sum(b.GetBondTypeAsDouble() for b in atom2.GetBonds()) + max_val1 = max(VALENCES[atom1.GetSymbol()]) + max_val2 = max(VALENCES[atom2.GetSymbol()]) + # bond_order = bond_type.AsDouble() + bond_order=BONDTYPE2ORD[bond_type] + if val1 + bond_order <= max_val1 and val2 + bond_order <= max_val2: + bond1 = rwmol_.GetBondBetweenAtoms(atom1_idx, atom2_idx) + bond2 = rwmol_.GetBondBetweenAtoms(atom2_idx, atom1_idx) + if bond1 or bond2: + # print(f'bond exists for {[atom1_idx, atom2_idx]}') + pass + # if (atom1_idx, atom2_idx) not in existing_bonds and (atom2_idx, atom1_idx) not in existing_bonds: + else: + # print(atom1_idx, atom2_idx, bond_type,[ bi, idx_to_labels[idx_] ]) + rwmol_.AddBond(atom1_idx, atom2_idx, bond_type) + else: + print(f"Skipping bond {bi}: Exceeds valence.") + existing_bonds.add((atom1_idx, atom2_idx)) + b2aa[bi]=sorted([atom1_idx, atom2_idx]) + + if len(bond_bbox)==1 and len(atom_bbox)==2: + ca1='[*:0][C:2]#[C:3][*:1]'#acs phC#CpH + rwmol_ = Chem.RWMol() + ats= ['*','*','C','C'] + for ia in ats: + a=Chem.Atom(ia) + id_=rwmol_.AddAtom(a) + # print(ia,id_) + rwmol_.AddBond(2, 3, Chem.BondType.TRIPLE) + rwmol_.AddBond(0, 2, Chem.BondType.SINGLE) + rwmol_.AddBond(1, 3, Chem.BondType.SINGLE) + + # Chem.MolFromSmiles(ca1) + for i in range(len(atom_classes)): + atom_classes[i]=lab2idx['*'] + AllChem.Compute2DCoords(rwmol_) + else: + rwmol_=copy.deepcopy(rwmol_) + print(f"placeholder_atoms {placeholder_atoms}") + + #assign 2D coords + mol = rwmol_.GetMol() + 
mol.RemoveAllConformers() + conf = Chem.Conformer(mol.GetNumAtoms()) + # conf.Set3D(True) + # for i, (x, y) in enumerate(heavy_centers): + for i, (x, y) in enumerate(atom_centers): + x, y=float(x),float(y) + conf.SetAtomPosition(i, (x, y, 0))#TODO why some time need -y, just display same as ori? + mol.AddConformer(conf) + # Chem.SanitizeMol(mol) + Chem.AssignStereochemistryFrom3D(mol) + rwmol_=Chem.RWMol(mol) + #as afte H a\lso didthis + skeleton_mol=copy.deepcopy(rwmol_) + print(skeleton_mol.GetNumBonds()) + chiral_centers_aids = Chem.FindMolChiralCenters(mol, includeUnassigned=True) + + # H realted post-process + heavyNumber=len(heavy_centers) + print(f'mol with heavy number atoms {heavyNumber}, max id {heavyNumber-1}') + onlyHeayMol=copy.deepcopy(rwmol_) + chiral_centers = Chem.FindMolChiralCenters( + rwmol_, includeUnassigned=True, includeCIP=False, useLegacyImplementation=False) + chiral_center_ids = [idx for idx, _ in chiral_centers] + Hais=[] + Hais_bt=[] + Hbd=[] + # H_existing_bonds = set() + for bi, ais in b2a.items():#from box overlap + bt=bond_classes[bi]# in [14,15]#directon bond + for ai in ais: + if ai>heavyNumber-1: + if bt in [14,15]:#directon bond + Hais.append(ais)#NOTE ais ai increasing order as two for loop increasing + print(f"within H bond box id {bi} bond direction {idx_to_labels[bt]} atoms box id {ais} ") + Hais_bt.append(idx_to_labels[bt]) + Hbd.append(bi) + # print(bonds[bi] ) + # add Hbonds with direction + H_existing_bonds = set() + ha2boxa=dict() + for ais, bt in zip(Hais,Hais_bt): + idx_2=ais[-1] + idx_1=ais[0] + hbond=rwmol_.GetBondBetweenAtoms(idx_1,idx_2) + if hbond is not None: + if idx_1 in chiral_center_ids:#if not in the chiral atom, will not set bond directions + hbond.SetBondDir(BOND_DIRS[bt]) + else: + had = Chem.Atom("H") + addHatom_idx = rwmol_.AddAtom(had) + ha2boxa[addHatom_idx]=idx_2 + # print(idx_2,addHatom_idx)#Note if detected H box will lead idx_2 != addHatom_idx + atom= rwmol_.GetAtomWithIdx(idx_1) + 
max_val=max(VALENCES[atom.GetSymbol()]) + val = sum(b.GetBondTypeAsDouble() for b in atom.GetBonds()) + if (idx_1, addHatom_idx) not in H_existing_bonds and (addHatom_idx, idx_1) not in H_existing_bonds: + if val<=max_val-1: + # print(f"atom id {idx_1} val {val} max_val {max_val}") + print(idx_1, addHatom_idx)#let check bond exist or not!! + rwmol_.AddBond(idx_1,addHatom_idx, Chem.BondType.SINGLE)#BOND_DIRS[bt] + b=rwmol_.GetBondBetweenAtoms(idx_1,addHatom_idx) + if idx_1 in chiral_center_ids:#if not in the chiral atom, will not set bond directions + b.SetBondDir(BOND_DIRS[bt])#############Note can be done in the following tree + H_existing_bonds.add((idx_1,addHatom_idx)) + i + if len(ha2boxa)>0:#consider Hnow + #use box coords assign 2D, remove extra Hs also update box + rwmol_.RemoveAllConformers()# + conf = Chem.Conformer(rwmol_.GetNumAtoms()) + conf.Set3D(True) + coords2d=[] + for i, (x, y) in enumerate(heavy_centers): + position = Point3D(float(x), float(y), 0.) # Create a Point3D object with x, y, and z=0 + conf.SetAtomPosition(i, position) + coords2d.append([x,y]) + for k,v in ha2boxa.items(): + x,y=atom_centers[v] + position = Point3D(float(x), float(y), 0.) 
# Create a Point3D object with x, y, and z=0 + conf.SetAtomPosition(k, position) + coords2d.append([x,y]) + rwmol_.AddConformer(conf) + + additonalH=detect_unconnected_hydrogens(rwmol_) + if len(additonalH)>0: + rwmol_,rmovedAtomcoords=remove_unconnected_hydrogens2(rwmol_) #NOTE 留给将来WEB开发用will dercease h atom,but the box have not updated TODO fix me this in feature activate learning + #update atom box infors + if len(rmovedAtomcoords)>0:#update box infors + delbb=[] + kdt = cKDTree(atom_centers) + for i, (x,y,z) in enumerate(rmovedAtomcoords):#z=0 + dist, idx_=kdt.query([x,y], k=1) + delbb.append(idx_) + mask = np.ones(len(atom_classes), dtype=bool) # 初始化为 True + mask[delbb] = False + atom_bbox = atom_bbox[mask] + atom_classes = atom_classes[mask] + atom_centers = atom_centers[mask] + # mol# mol_rebuit=copy.deepcopy(mol) + + mol=copy.deepcopy(rwmol_) + conf=mol.GetConformers()[0] + mola2xy=dict() + mola2d=[] + for i,a in enumerate(mol.GetAtoms()): + x,y,z=conf.GetAtomPosition(i) + mola2xy[i]=[x,y] + mola2d.append([x,y]) + # print( x,y,z) + kdt = cKDTree(mola2d) + chiral_centers = Chem.FindMolChiralCenters( + mol, includeUnassigned=True, includeCIP=False, useLegacyImplementation=False) + chiral_center_ids = [idx for idx, _ in chiral_centers] + + for bi,bcent in enumerate(bond_centers): + if bi in bondWithdirct :#and bi not in Hbd:#Note as set Hbd previously + dists, a1a2 = kdt.query(bcent, k=2) + a1,a2=sorted(a1a2) + a1,a2=int(a1),int(a2) + bt= mol.GetBondBetweenAtoms(a1, a2)#RDKit 的键是无向的,返回的是同一个 Bond 对象 + if bt: + # 获取键的当前起点和终点 + current_begin = bt.GetBeginAtomIdx() + current_end = bt.GetEndAtomIdx() + bond_dir=bond_dirs[idx_to_labels[bond_classes[bi]]] + if bond_dir == rdchem.BondDir.BEGINWEDGE: + reverse_dir = rdchem.BondDir.BEGINDASH + elif bond_dir == rdchem.BondDir.BEGINDASH: + reverse_dir = rdchem.BondDir.BEGINWEDGE + # else: + # reverse_dir= rdchem.BondDir.BEGINWEDGE + if a1 in chiral_center_ids: + if current_begin == a1: + bt.SetBondDir(bond_dir) + 
print(f'a1 dir') + else: + # 如果手性原子是终点,反转方向(例如用相反的楔形键) + bt.SetBondDir(reverse_dir) + print(f'a1 reverse_dir') + # print(f'set bond direction a1a2 {[bi, a1,a2]}') + # bt.SetBondDir(bond_dirs[idx_to_labels[bond_classes[bi]]]) + elif a2 in chiral_center_ids: + if current_begin == a2: + bt.SetBondDir(bond_dir) + print(f'a2 dir {bond_dir} {reverse_dir}') + else: + # 如果手性原子是终点,反转方向(例如用相反的楔形键),but not work, just remove and add + mol.RemoveBond(current_begin, current_end) + mol.AddBond(current_end, current_begin, bt.GetBondType()) + bond = mol.GetBondBetweenAtoms(current_end, current_begin) + bond.SetBondDir(bond_dir) + print(f'a2 reverse_dir {bond_dir} {reverse_dir}') + # bt= mol.GetBondBetweenAtoms(a2, a1) + # print(f'set bond direction a2a1 {[bi, a2,a1]}') + # bt.SetBondDir(bond_dirs[idx_to_labels[bond_classes[bi]]]) + else: + print('bond stro not with chiral atom???, will ignore this stero bond infors') + print(f"{[bi, bond_dir, current_begin,current_end]}") + # beginatom=mol.GetAtomWithIdx(current_begin) + # Endatom=mol.GetAtomWithIdx(current_end) + # beginatom_neis=len(beginatom.GetBonds()) + # Endatom_neis=len(Endatom.GetBonds()) + try: + mol_rebuit=mol.GetMol() + conf = mol_rebuit.GetConformer() + Chem.WedgeMolBonds(mol_rebuit,conf)# + Chem.DetectBondStereochemistry(mol_rebuit) + Chem.AssignChiralTypesFromBondDirs(mol_rebuit) + Chem.AssignStereochemistry(mol_rebuit) + # + smiH=Chem.MolToSmiles(mol_rebuit) + print(F"smiH\n",smiH) + # canon_smilesH = Chem.CanonSmiles(smiH) + # print(F"canon_smilesH\n",canon_smilesH) + # rdkit_coni_smiH=Chem.MolToSmiles(Chem.MolFromSmiles(smiH)) + # print(f"Chem.MolToSmiles(Chem.MolFromSmiles(smiH))\n {rdkit_coni_smiH}") + # + mol = rdkit.Chem.RWMol(mol_rebuit) + other2ppsocr=True + if other2ppsocr: + print() + need_cut=[] + ppstr=[] + ppstr_score=[] + crops=[] + index_token=dict() + expan=0#NOTE this control how much the part of bond in crop_Img + for i_,(heav_c,heav_box) in enumerate(zip(atom_classes,atom_bbox)): + if 
lab2idx['*']==heav_c or lab2idx['other']==heav_c or lab2idx['Cl']==heav_c: + need_cut.append(i_) + a=heav_box+np.array([-expan,-expan,expan,expan]) + # print(heav_box.shape,a.shape) + box=a * [scale_x, scale_y, scale_x, scale_y]#TODO need the fix as w h may not equal!! + # print(a,box,[scale_x, scale_y, scale_x, scale_y]) + cropped_img = img_ori_1k.crop(box) + crops.append(cropped_img) + image_npocr = np.array(cropped_img) + result_ocr= ocr2.ocr(image_npocr, det=False) + s_, score_ =result_ocr[0][0] + s_previos=atom_ocr[i_] + if s_previos != "other" : + s_=s_previos if len(s_previos)>=len(s_) else s_ + print(f'ocr::idx:{i_}',s_, score_ ) + if score_<=0.1:# process cropped_img and try again + # print(s_, "xxx",score_) + s_='*' + if s_=='+' or s_=='-': + s_="*" + if len(s_)>1: + s_=re.sub(r'[^a-zA-Z0-9,\*\-\+]', '', s_)#remove special chars + if re.match(r'^\d+$', s_): + s_=f'{s_}*'#number+ * + # print(f'why only numbers ? {s_}') + if s_=='L':s_='Li' + elif s_=='0':s_='O' + elif s_ in ['N,+ CI','N,+ Cl' ,'N,+Cl','N,+CI','N+CI']:s_='N2+Cl-' + elif s_ in ['NO,','O,N' ]:s_='NO2' + + + match = re.match(r'^(\d+)?(.*)', s_) + # print(s_,'xxxx') + if match: + numeric_part, remaining_part = match.groups() + fc_=mol.GetAtomWithIdx(i_).GetFormalCharge() + if remaining_part in ELEMENTS: + new_atom = Chem.Atom(remaining_part) + mol.ReplaceAtom(i_, new_atom) + print(i_, remaining_part,"@@@") + elif remaining_part in ABBREVIATIONS:# can be expanded with placeholder_atoms + placeholder_atoms[i_]=s_# such 2Na will be get for rdkit + elif remaining_part=='OH': + new_atom = Chem.Atom("O") + mol.ReplaceAtom(i_, new_atom) + elif remaining_part=='SH': + new_atom = Chem.Atom("S") + mol.ReplaceAtom(i_, new_atom) + elif remaining_part=='NH': + new_atom = Chem.Atom("N") + mol.ReplaceAtom(i_, new_atom) + mol.GetAtomWithIdx(i_).SetFormalCharge(fc_) + index_token[i_]=f'{s_}:{i_}' + print(f"idx:{i_}, atm: <{idx_to_labels[heav_c]}> --- [{s_}:{i_}] with score:{score_} ||previousOCR:: 
{atom_ocr[i_]}") + if s_ in ELEMENTS : + new_atom = Chem.Atom(s_) + mol.ReplaceAtom(i_, new_atom) + mol.GetAtomWithIdx(i_).SetProp("atomLabel", f"{s_}")#mol set with label, mol_rebuit not + ppstr.append(s_) + ppstr_score.append(score_) + if s_ in ABBREVIATIONS.keys(): + placeholder_atoms[i_]=s_ + # + bond_dirs_rev={v:k for k,v in bond_dirs.items()} + wdbs=[] + for b in mol.GetBonds(): + bd=b.GetBondDir() + bt=b.GetBondType() + # print(bd) + if bd ==bond_dirs['BEGINDASH'] or bd==bond_dirs['BEGINWEDGE']: + a1,a2=b.GetBeginAtomIdx(), b.GetEndAtomIdx() + wdbs.append([a1,a2,bt,bond_dirs_rev[bd]]) + + #expand mol if exists + # if len(placeholder_atoms)>0:### + cm=copy.deepcopy(mol) + # print(placeholder_atoms) + expand_mol, expand_smiles= expandABB(cm,ABBREVIATIONS, placeholder_atoms) + SMILESpre=expand_smiles + rdm=copy.deepcopy(expand_mol) + target_mol, ref_mol=rdm, cm + AllChem.Compute2DCoords(target_mol) + pair=[target_mol, ref_mol] + mcs=rdFMCS.FindMCS([target_mol, ref_mol], # larger,small order + # atomCompare=rdFMCS.AtomCompare.CompareAny, + bondCompare=rdFMCS.BondCompare.CompareAny, + ringCompare=rdFMCS.RingCompare.IgnoreRingFusion, + matchChiralTag=False, + ) + mcs_mol = Chem.MolFromSmarts(mcs.smartsString) + AllChem.Compute2DCoords(mcs_mol) + + matches0 = pair[0].GetSubstructMatches(mcs_mol, useQueryQueryMatches=True,uniquify=False, maxMatches=1000, useChirality=False) + matches1 = pair[1].GetSubstructMatches(mcs_mol, useQueryQueryMatches=True,uniquify=False, maxMatches=1000, useChirality=False) + if len(matches0) != len(matches1): + matches0=list(matches0) + matches1=list(matches1) + # print( "noted: matcher not equal !!") + if len(matches0)>len(matches1): + for i in range(0,len(matches0)): + if i < len(matches1): + pass + else: + ii=i % len(matches1) + matches1.append(matches1[ii]) + else: + for i in range(0,len(matches1)): + if i < len(matches0): + pass + else: + ii=i % len(matches0) + matches0.append(matches0[ii]) + assert len(matches0) == len(matches1), 
"matcher not equal break!!" + atommaping_pairs=[list(zip(matches0[i],matches1[i])) for i in range(0,len(matches0))] + atomMap=atommaping_pairs[0] + rmsd2=rdkit.Chem.rdMolAlign.AlignMol(prbMol=target_mol, refMol=ref_mol, atomMap=atomMap,maxIters=2000000) + print(f"rmsd {rmsd2}") + #ocr_mol + ocr_mol = copy.deepcopy(target_mol) + AllChem.Compute2DCoords(ocr_mol) + ocr_smi = Chem.MolToSmiles(ocr_mol) + molexp=ocr_mol + expandStero_smi, success= rdkit_canonicalize_smiles(ocr_smi) + # expandStero_smi = Chem.CanonSmiles(ocr_smi)#, useChiral=(not ignore_chiral)) + + # TODO #[3H] 2H prpared box for training are too smalled, need adjust + if visual_check: + boxed_img = draw_objs(img, + atom_bbox, + atom_classes, + atom_scores, + category_index=idx_to_labels, + box_thresh=0.5, + line_thickness=3, + font='arial.ttf', + font_size=10) + opts = Draw.MolDrawOptions() + opts.addAtomIndices = False + opts.addStereoAnnotation = False + img_ori = Image.open(image_path).convert('RGB') + img_ori_1k = img_ori.resize((1000,1000)) + if other2ppsocr: + img_rebuit = Draw.MolToImage(ocr_mol, options=opts,size=(1000, 1000)) + else: + img_rebuit = Draw.MolToImage(ocr_mol, options=opts,size=(1000, 1000)) + combined_img = Image.new('RGB', (img_ori_1k.width + boxed_img.width + img_rebuit.width, img_ori_1k.height)) + combined_img.paste(img_ori_1k, (0, 0)) + combined_img.paste(boxed_img, (img_ori_1k.width, 0)) + combined_img.paste(img_rebuit, (img_ori_1k.width + boxed_img.width, 0)) + imprefix=os.path.basename(image_path).split('.')[0] + combined_img.save(f"{ima_checkdir}/{imprefix}Boxed.png") + + new_row = {'file_name':image_path, "SMILESori":SMILESori, + 'SMILESpre':SMILESpre, + 'SMILESexp':expandStero_smi, + } + smiles_data = smiles_data._append(new_row, ignore_index=True) + + #accu similarity calculation + if getacc: + sameWithOutStero=comparing_smiles(new_row,SMILESpre)#try to ingnore cis chiral, as 2d coords including all the infos + 
sameWithOutStero_exp=comparing_smiles(new_row,expandStero_smi)#this ignore chairity and *number be * NOTE + + if (type(SMILESori)!=type('a')) or (type(SMILESpre)!=type('a')): + if sameWithOutStero or sameWithOutStero_exp: + mysum += 1 + else: + print(f"smiles problems\n{SMILESori}\n{SMILESpre}\n{image_path}") + failed.append([SMILESori,SMILESpre,image_path]) + mydiff.append([SMILESori,SMILESpre,image_path]) + continue + mol1 = Chem.MolFromSmiles(SMILESori)#TODO considering smiles with rdkit not recongized in real data + if mol1 is None: + rd_smi_ori, success1_=rdkit_canonicalize_smiles(SMILESori) + mol1=Chem.MolFromSmiles(rd_smi_ori) + if (mol_rebuit is None) or (mol1 is None): + if sameWithOutStero or sameWithOutStero_exp: + mysum += 1 + else: + print(f'get rdkit mol None\n{SMILESori}\n{SMILESpre}\n{image_path}') + failed.append([SMILESori,SMILESpre,image_path]) + mydiff.append([SMILESori,SMILESpre,image_path]) + continue + if mol1: + rdk_smi1=Chem.MolToSmiles(mol1) + else: + rdk_smi1=SMILESori + if mol_rebuit: + rdk_smi2=Chem.MolToSmiles(mol_rebuit) + else: + rdk_smi2='' + # if rdk_smi1==rdk_smi2 or rdk_smi1==expandStero_smi or sameWithOutStero:#also considering the abbre in Ori + if rdk_smi1==rdk_smi2 or rdk_smi1==expandStero_smi: + mysum += 1 + else: + if sameWithOutStero or sameWithOutStero_exp: + mysum += 1 + else: + mydiff.append([SMILESori,SMILESpre,image_path]) + if visual_check: + combined_img.save(f"{ima_checkdir}/{imprefix}Boxed_diff{len(mydiff)}.png") + try: + morganfps1 = AllChem.GetMorganFingerprint(mol1, 3,useChirality=True) + morganfps2 = AllChem.GetMorganFingerprint(mol_rebuit, 3,useChirality=True) + morgan_tani = DataStructs.DiceSimilarity(morganfps1, morganfps2) + fp1 = Chem.RDKFingerprint(mol1) + fp2 = Chem.RDKFingerprint(mol_rebuit) + tanimoto = DataStructs.FingerprintSimilarity(fp1, fp2) + if expandStero_smi!= '': + fp3 = Chem.RDKFingerprint(molexp) + morganfps3 = AllChem.GetMorganFingerprint(molexp, 3,useChirality=True) + morgan_tani3 = 
class RTDETRPostProcessor(nn.Module):
    """Post-processor for RT-DETR detection outputs.

    Converts raw model outputs (``pred_logits`` + ``pred_boxes`` in normalized
    cxcywh format) into per-image detection dicts with keys ``labels``,
    ``boxes`` (xyxy, scaled to the original image size) and ``scores``.

    Parameters
    ----------
    classes_dict : dict[int, str] | None
        Mapping from category id to label name. Defaults to the chemistry
        label map below (atoms, bond types, formal charges).
    use_focal_loss : bool
        If True, score with per-class sigmoid + flat top-k over all
        (query, class) pairs; otherwise softmax over classes with the last
        (background) class dropped.
    num_top_queries : int
        Number of detections kept per image.
    remap_mscoco_category : bool
        If True, remap contiguous training labels back to the category ids
        of ``classes_dict``.
    """

    __share__ = ['num_classes', 'use_focal_loss', 'num_top_queries', 'remap_mscoco_category']

    def __init__(self, classes_dict=None, use_focal_loss=True, num_top_queries=300, remap_mscoco_category=False) -> None:
        super().__init__()
        self.use_focal_loss = use_focal_loss
        if classes_dict is None:
            # Default label map. NOTE: 'wdge' (sic) is kept verbatim —
            # downstream code may match this exact string.
            classes_dict = {0: 'other', 1: 'C', 2: 'O', 3: 'N', 4: 'Cl', 5: 'Br', 6: 'S', 7: 'F', 8: 'B',
                            9: 'I', 10: 'P', 11: 'H', 12: 'Si',
                            # bond
                            13: 'single', 14: 'wdge', 15: 'dash',
                            16: '=', 17: '#', 18: ':',  # aromatic
                            # charge
                            19: '-4', 20: '-2',
                            21: '-1',  # -
                            22: '+1',  # +
                            23: '+2',
                            }
        num_classes = len(classes_dict)
        self.num_top_queries = num_top_queries
        self.num_classes = num_classes
        self.remap_mscoco_category = remap_mscoco_category
        self.deploy_mode = False

        # Bidirectional mapping between contiguous training labels (0..N-1)
        # and the (possibly non-contiguous) category ids in classes_dict.
        mscoco_category2label = {k: i for i, k in enumerate(classes_dict.keys())}
        mscoco_label2category = {v: k for k, v in mscoco_category2label.items()}
        self.mscoco_label2category = mscoco_label2category

    def extra_repr(self) -> str:
        return f'use_focal_loss={self.use_focal_loss}, num_classes={self.num_classes}, num_top_queries={self.num_top_queries}'

    def forward(self, outputs, orig_target_sizes):
        """Return a list (one entry per image) of dicts with keys
        ``labels``, ``boxes`` (xyxy, absolute pixels) and ``scores``.

        ``orig_target_sizes`` is expected as (batch, 2) — presumably
        (width, height) per image, matching the cxcywh scaling below.
        """
        logits, boxes = outputs['pred_logits'], outputs['pred_boxes']

        # Normalized cxcywh -> absolute xyxy in original-image coordinates.
        bbox_pred = torchvision.ops.box_convert(boxes, in_fmt='cxcywh', out_fmt='xyxy')
        bbox_pred *= orig_target_sizes.repeat(1, 2).unsqueeze(1)

        if self.use_focal_loss:
            # Per-class sigmoid scores; flat top-k over (query, class) pairs,
            # then recover the class and query indices from the flat index.
            scores = torch.sigmoid(logits)
            scores, index = torch.topk(scores.flatten(1), self.num_top_queries, dim=-1)
            labels = index % self.num_classes
            index = index // self.num_classes
            boxes = bbox_pred.gather(dim=1, index=index.unsqueeze(-1).repeat(1, 1, bbox_pred.shape[-1]))
        else:
            # FIX: softmax must run over the class dimension. The previous
            # `F.softmax(logits)` fell back to the deprecated implicit dim,
            # which for a 3-D tensor is NOT the last axis.
            scores = F.softmax(logits, dim=-1)[:, :, :-1]  # drop background class
            scores, labels = scores.max(dim=-1)
            boxes = bbox_pred
            if scores.shape[1] > self.num_top_queries:
                scores, index = torch.topk(scores, self.num_top_queries, dim=-1)
                labels = torch.gather(labels, dim=1, index=index)
                boxes = torch.gather(boxes, dim=1, index=index.unsqueeze(-1).tile(1, 1, boxes.shape[-1]))

        # TODO for onnx export
        if self.deploy_mode:
            return labels, boxes, scores

        if self.remap_mscoco_category:
            # Map contiguous labels back to original category ids.
            labels = torch.tensor([self.mscoco_label2category[int(x.item())] for x in labels.flatten()])\
                .to(boxes.device).reshape(labels.shape)

        results = []
        for lab, box, sco in zip(labels, boxes, scores):
            result = dict(labels=lab, boxes=box, scores=sco)
            results.append(result)

        return results

    def deploy(self, ):
        """Switch to deploy (export) mode: eval + raw-tuple forward output."""
        self.eval()
        self.deploy_mode = True
        return self

    @property
    def iou_types(self, ):
        return ('bbox', )