jibsn committed on
Commit
6d6d481
·
verified ·
1 Parent(s): a85cc6f

Upload 5 files

Browse files
Files changed (5) hide show
  1. I2M_R4.onnx +3 -0
  2. ONNX0630.py +2025 -0
  3. app.py +63 -7
  4. det_engine.py +0 -0
  5. utils.py +712 -0
I2M_R4.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3b9cc7af809d1b91d467400f416f800d3908cd5ec733d32b7cefe906b9b71122
3
+ size 212933527
ONNX0630.py ADDED
@@ -0,0 +1,2025 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+ import os,sys
4
+ import argparse
5
+
6
+ model_usedpath='/nfs_home/bowen/works/pys/codes/i2m'
7
+ sys.path.append(model_usedpath)
8
+ home="/nfs_home/bowen/works/pys/codes/i2m"
9
+ bmd=f'/nfs_home/bowen/works/pys/codes/i2m/output0602/checkpoint0070.pth'#
10
+ parser = argparse.ArgumentParser()
11
+ parser.add_argument('--config', '-c', type=str, default=f'{home}/configs/rtdetr/rtdetr_r50vd_6x_coco.yml')
12
+ parser.add_argument('--resume', '-r', type=str, default=f'{bmd}')
13
+ parser.add_argument('--tuning', '-t', type=str,)# default='/nfs_home/bowen/model_checkpoint/rtdetr_r50vd_2x_coco_objects365_from_paddle.pth')
14
+ parser.add_argument('--test-only',default=True,)
15
+ parser.add_argument('--amp', default=False,)
16
+ parser.add_argument('--dataname', '-da', type=str, default=None)
17
+ parser.add_argument('--gpuid', '-gi', type=str, default=None)
18
+ parser.add_argument('--number', '-n', type=str, default=None)
19
+ args, unknown = parser.parse_known_args()#in jupyter
20
+ print(args)
21
+ if args.gpuid:
22
+ os.environ['CUDA_VISIBLE_DEVICES'] = f'{args.gpuid}'
23
+ else:
24
+ os.environ['CUDA_VISIBLE_DEVICES'] = '4'
25
+
26
+ parralel_n=2
27
+ os.environ["OMP_NUM_THREADS"] = f"{parralel_n}" # OpenMP
28
+ os.environ["OPENBLAS_NUM_THREADS"] = f"{parralel_n}" # OpenBLAS
29
+ os.environ["MKL_NUM_THREADS"] = f"{parralel_n}" # Intel MKL
30
+ os.environ["VECLIB_MAXIMUM_THREADS"] = f"{parralel_n}" # macOS Accelerate
31
+ os.environ["NUMEXPR_NUM_THREADS"] = f"{parralel_n}" # NumExpr
32
+ """
33
+ WARNING: OMP_NUM_THREADS set to 4, not 1. The computation speed will not be optimized if you use data parallel. It will fail if this PaddlePaddle binary is compiled with OpenBlas since OpenBlas does not support multi-threads.
34
+ PLEASE USE OMP_NUM_THREADS WISELY.
35
+
36
+ """
37
+
38
+
39
+ import shutil
40
+ import pandas as pd
41
+ # print(sys.path)
42
+ print(__file__)
43
+ cur_dir = os.path.dirname(os.path.abspath(__file__))
44
+ print(cur_dir)
45
+ python_path=cur_dir
46
+
47
+ sys.path.append(python_path)
48
+ # model_path='I2M_realv2.onnx'
49
+ # model_abs_path = os.path.abspath(model_path)
50
+ # if os.path.exists(model_abs_path):
51
+ # print(model_abs_path)
52
+
53
+ # from src.solver.det_engine import *
54
+ import cv2
55
+
56
+ import sys,copy
57
+ import torchvision
58
+
59
+ import torch
60
+ import tqdm
61
+ import matplotlib.pyplot as plt
62
+ from matplotlib.patches import Rectangle, Circle
63
+ from det_engine import N_C_H_expand, C_H_expand,C_H_expand2, C_F_expand, formula_regex, RTDETRPostProcessor
64
+ from det_engine import SmilesEvaluator, molfpsim
65
+
66
+
67
+ import rdkit
68
+ from rdkit import Chem
69
+ from rdkit.Chem import Draw, AllChem
70
+ from rdkit import DataStructs
71
+
72
+
73
+
74
+ print("CUDA available:", torch.cuda.is_available())
75
+ print("Number of GPUs:", torch.cuda.device_count())
76
+ # print("Current device:", torch.cuda.current_device())
77
+ print("Device name:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "No GPU")
78
+
79
+
80
+ # In[ ]:
81
+
82
# Area of an axis-aligned box; used to pick the smallest bbox.
def bbox_area(bbox):
    """Return the area (width * height) of an [x1, y1, x2, y2] box."""
    x1, y1, x2, y2 = bbox
    width = x2 - x1
    height = y2 - y1
    return width * height
86
+
87
def mol_idx(mol):
    """Annotate every atom of *mol* with its own index as the atom-map number.

    Mutates *mol* in place and returns it (handy for drawing with indices).
    """
    for atom in mol.GetAtoms():
        atom.SetProp('molAtomMapNumber', str(atom.GetIdx()))
    return mol
92
+
93
# Remove atom indices (the inverse of mol_idx).
def mol_idx_del(mol):
    """Clear the 'molAtomMapNumber' property from every atom that has it.

    Mutates *mol* in place and returns it.
    """
    for atom in mol.GetAtoms():
        if atom.HasProp('molAtomMapNumber'):
            atom.ClearProp('molAtomMapNumber')
    return mol
101
+
102
def is_contained_in(bbox_small, bbox_large):
    """Return True when *bbox_small* lies fully inside *bbox_large* (edges inclusive)."""
    xs1, ys1, xs2, ys2 = bbox_small
    xl1, yl1, xl2, yl2 = bbox_large
    inside_x = xl1 <= xs1 and xs2 <= xl2
    inside_y = yl1 <= ys1 and ys2 <= yl2
    return inside_x and inside_y
107
+
108
+
109
def NoRadical_Smi(smi):
    """Neutralize radical electrons in a SMILES by padding with explicit hydrogens.

    NOTE(review): assumes *smi* parses; a None from MolFromSmiles would raise
    AttributeError, exactly as in the original.
    """
    mol = Chem.MolFromSmiles(smi)
    for atom in mol.GetAtoms():
        if atom.GetNumRadicalElectrons() > 0:
            # Drop the radical electron(s) ...
            atom.SetNumRadicalElectrons(0)
            # ... and restore the valence with explicit hydrogens.
            atom.SetNumExplicitHs(atom.GetTotalValence() - atom.GetExplicitValence())
    return Chem.MolToSmiles(mol)
121
+
122
+
123
def select_longest_smiles(smiles):
    """Return the longest '.'-separated component of *smiles* (the main structure)."""
    # max() keeps the first component on length ties, like the original.
    return max(smiles.split('.'), key=len)
129
+
130
# Parse a formal-charge token such as '+', '2-', '1+' into a signed integer.
def parse_charge(charge_str):
    """Convert a charge annotation string to an int ('+' -> 1, '2-' -> -2, '0' -> 0)."""
    if charge_str.endswith('+'):
        magnitude = charge_str[:-1]
        return int(magnitude) if magnitude else 1
    if charge_str.endswith('-'):
        magnitude = charge_str[:-1]
        return -int(magnitude) if magnitude else -1
    # Plain signed/unsigned integer string.
    return int(charge_str)
138
+
139
+
140
+
141
def set_bondDriection(rwmol_, bondWithdirct):
    """Apply wedge/dash (stereo) directions to bonds of an RWMol.

    Parameters
    ----------
    rwmol_ : Chem.RWMol
        Molecule being assembled; mutated in place.
    bondWithdirct : dict
        bond-id -> (atom1_idx, atom2_idx, bond_type, score, w_d) where w_d is
        'wdge' or 'dash' (spelling as used throughout this pipeline).

    Returns the (mutated) molecule.
    """
    # Map chiral-center atom index -> stereo label so we can orient wedges.
    chiral_center_ids = Chem.FindMolChiralCenters(rwmol_, includeUnassigned=True)
    chirai_ai2sterolab = dict()
    if len(chiral_center_ids) > 0:
        chirai_ai2sterolab = {ai: stero_lab for ai, stero_lab in chiral_center_ids}

    for bi, binfo in bondWithdirct.items():
        atom1_idx, atom2_idx, bond_type, score, w_d = binfo
        # RDKit bonds are undirected; this returns the single Bond object.
        bt = rwmol_.GetBondBetweenAtoms(atom1_idx, atom2_idx)
        current_begin = bt.GetBeginAtomIdx()
        current_end = bt.GetEndAtomIdx()
        if w_d == 'wdge':
            bond_dir_ = rdchem.BondDir.BEGINWEDGE
            reverse_dir = rdchem.BondDir.BEGINDASH
        elif w_d == 'dash':
            bond_dir_ = rdchem.BondDir.BEGINDASH
            reverse_dir = rdchem.BondDir.BEGINWEDGE
        else:
            # BUG FIX: previously an unrecognized w_d left bond_dir_/reverse_dir
            # unbound, raising NameError below. Skip such entries instead.
            print(f'unknown bond direction {w_d!r} for bond {bi}, skipping')
            continue

        if atom1_idx in chirai_ai2sterolab.keys():
            if current_begin == atom1_idx:
                bt.SetBondDir(bond_dir_)
                print(f'atom1_idx dir')
            else:
                # Chiral atom is the bond's end: use the opposite wedge type.
                bt.SetBondDir(reverse_dir)
                print(f'atom1_idx reverse_dir')
        elif atom2_idx in chirai_ai2sterolab.keys():
            if current_begin == atom2_idx:
                bt.SetBondDir(bond_dir_)
                print(f'atom2_idx dir {bond_dir_} {reverse_dir}')
            else:
                # Reversing direction on the existing bond does not work here,
                # so re-add the bond with swapped endpoints and set the wedge.
                rwmol_.RemoveBond(current_begin, current_end)
                rwmol_.AddBond(current_end, current_begin, bt.GetBondType())
                bond = rwmol_.GetBondBetweenAtoms(current_end, current_begin)
                bond.SetBondDir(bond_dir_)
                print(f'atom2_idx reverse_dir {bond_dir_} {reverse_dir}')
        else:
            print('bond stro not with chiral atom???, will ignore this stero bond infors')
            print(f"{[bi, bond_dir_, current_begin,current_end]}")
    return rwmol_
184
+
185
+
186
+
187
+ # In[786]:
188
+
189
+
190
# Class-index layout of the detector head: 0-12 atoms, 13-18 bonds, 19-23 charges.
atom_labels = list(range(0, 13))
bond_labels = list(range(13, 19))
charge_labels = list(range(19, 24))

idx_to_labels = {
    # atom classes
    0: 'other', 1: 'C', 2: 'O', 3: 'N', 4: 'Cl', 5: 'Br', 6: 'S', 7: 'F',
    8: 'B', 9: 'I', 10: 'P', 11: 'H', 12: 'Si',
    # bond classes ('wdge' spelling is relied on elsewhere in the pipeline)
    13: 'single', 14: 'wdge', 15: 'dash',
    16: '=', 17: '#', 18: ':',  # ':' = aromatic
    # charge classes
    19: '-4', 20: '-2',
    21: '-1',
    22: '+1',
    23: '+2',
}
lab2idx = {label: idx for idx, label in idx_to_labels.items()}
bond_labels_symb = {idx_to_labels[i] for i in bond_labels}

# Named RDKit bond-direction constants used when rebuilding stereo bonds.
bond_dirs = {
    name: getattr(Chem.rdchem.BondDir, name)
    for name in ('NONE', 'ENDUPRIGHT', 'BEGINWEDGE', 'BEGINDASH', 'ENDDOWNRIGHT')
}
215
+
216
+
217
+ import pandas as pd
218
+ from typing import Iterable, List
219
+ from PIL import Image
220
+ import json,re
221
+
222
+ #TODO now abc single bond and OCR checking
223
+ #OCR 得到纯数字box 离原子距离应该小于最小的bond 距离,否则丢弃
224
+ from utils import calculate_iou,adjust_bbox1
225
+ from scipy.spatial import cKDTree, KDTree
226
+ import numpy as np
227
+ from rdkit import Chem
228
+ from paddleocr import PaddleOCR
229
+ from rdkit.Chem import rdchem, RWMol, CombineMols
230
+
231
+ from det_engine import ABBREVIATIONS,remove_SP
232
+ from det_engine import molExpanding,remove_bond_directions_if_no_chiral
233
+ from det_engine import (comparing_smiles,comparing_smiles2, remove_SP, expandABB,
234
+ ELEMENTS,
235
+ ABBREVIATIONS)
236
+
237
+
238
+
239
+ from det_engine import expandABB
240
+
241
def bbox2shapes(bboxes, classes, lab2idx):
    """Build labelme/X-AnyLabeling-style rectangle shape dicts.

    Labels absent from *lab2idx* are mapped to 'other'.
    """
    shapes = []
    for (x1, y1, x2, y2), label in zip(bboxes, classes):
        shapes.append({
            "kie_linking": [],
            "label": label if label in lab2idx else 'other',
            "score": 1.0,
            # Corner order: TL, TR, BR, BL.
            "points": [[x1, y1], [x2, y1], [x2, y2], [x1, y2]],
            "group_id": None,
            "description": None,
            "difficult": False,
            "shape_type": "rectangle",
            "flags": None,
            "attributes": {},
        })
    return shapes
268
+
269
def get_longest_part(smi_string):
    """Return the longest '.'-separated fragment; pass through undotted strings."""
    if '.' not in smi_string:
        return smi_string
    return max(smi_string.split('.'), key=len)
276
+
277
+
278
def split_output_by_numeric_classes(output):
    """Split a detection-output dict into numeric-class entries and the rest.

    Numeric classes are charge tokens like '+1', '2-', '3'. Returns
    (numeric_output, non_numeric_output) with the same keys as *output*.
    """
    numeric = {key: [] for key in output.keys()}
    non_numeric = {key: [] for key in output.keys()}

    for i, class_name in enumerate(output['pred_classes']):
        # Digits with an optional leading and/or trailing sign.
        if re.fullmatch(r'^[+-]?\d+[+-]?$', class_name):
            bucket = numeric
        else:
            bucket = non_numeric
        for key in output.keys():
            bucket[key].append(output[key][i])

    return numeric, non_numeric
298
+
299
+
300
def convert_shapes_to_output(json_data):
    """Convert labelme-style 'shapes' into the detector-output dict format.

    Returns {'bbox', 'bbox_centers', 'scores', 'pred_classes'} lists; a shape
    without a 'score' key defaults to 1.0.
    """
    output = {
        'bbox': [],
        'bbox_centers': [],
        'scores': [],
        'pred_classes': []
    }
    for shape in json_data['shapes']:
        xs = [pt[0] for pt in shape['points']]
        ys = [pt[1] for pt in shape['points']]
        # Axis-aligned box enclosing all polygon points.
        box = [min(xs), min(ys), max(xs), max(ys)]
        center = [(box[0] + box[2]) / 2, (box[1] + box[3]) / 2]

        output['bbox'].append(box)
        output['bbox_centers'].append(center)
        output['scores'].append(shape.get('score', 1.0))
        output['pred_classes'].append(shape['label'])

    return output
331
+
332
+
333
def getJsonData(src_json):
    """Load and return the JSON content of the file at *src_json*."""
    with open(src_json, 'r') as fh:
        return json.load(fh)
337
+
338
def replace_cg_notation(astr):
    """Expand 'CgH<n>' shorthand into an explicit alkyl formula 'C<(n-1)//2>H<n>'."""
    def _expand(match):
        h_count = int(match.group(1))
        return f'C{(h_count - 1) // 2}H{h_count}'

    return re.sub(r'CgH(\d+)', _expand, astr)
345
+
346
def viewcheck(image_path, bbox, color='red'):
    """Scatter-plot atom-box centers over the image, labelling each as 'a <i>'."""
    img = np.array(Image.open(image_path))
    plt.figure(figsize=(5, 4))
    plt.imshow(img)
    boxes = np.array(bbox)
    cx = (boxes[:, 0] + boxes[:, 2]) * 0.5
    cy = (boxes[:, 1] + boxes[:, 3]) * 0.5
    plt.scatter(cx, cy, c=color, s=50, label='Atom Centers', edgecolors='black')
    # Annotate each center with its index.
    for i, (x, y) in enumerate(zip(cx, cy)):
        plt.text(x, y, f'a {i}', fontsize=12, color=color, ha='center', va='bottom')
359
+
360
# Compact glyphs for bond classes used in plot annotations.
bclass_simple = {"single": '-', "wdge": 'w', "dash": '--',
                 "=": '=', "#": "#", ":": "aro"}

def viewcheck_b(image_path, bbox, bclass, color='red', figsize=(5, 4)):
    """Scatter-plot bond-box centers over the image, annotated with bond glyphs."""
    img = np.array(Image.open(image_path))
    plt.figure(figsize=figsize)
    plt.imshow(img)
    boxes = np.array(bbox)
    cx = (boxes[:, 0] + boxes[:, 2]) * 0.5
    cy = (boxes[:, 1] + boxes[:, 3]) * 0.5
    plt.scatter(cx, cy, c=color, s=50, label='bond Centers', edgecolors='black')
    # Each bond is tagged 'b<i><glyph>' (e.g. 'b3=').
    for i, (x, y) in enumerate(zip(cx, cy)):
        glyph = bclass_simple[bclass[i]]
        plt.text(x, y, f'b{i}{glyph}', fontsize=12, color=color, ha='center', va='bottom')
378
+
379
+
380
def anchor_draw(image_path, bond_bbox):
    """Visualize both diagonal anchor-point pairs of every bond box.

    Saves 'anchor_positions.png' (upper-left/lower-right corners, shrunk by a
    3px margin) and 'Opposite_anchor_positions.png' (the other diagonal).
    """
    img = np.array(Image.open(image_path))

    margin = 3
    primary_pts = []
    opposite_pts = []
    for bbox in bond_bbox:
        # Shrink the box by the margin and view it as two (x, y) points.
        shrunk = (np.array(bbox) + [margin, margin, -margin, -margin]).reshape([2, -1])
        # Swapping the y-coordinates gives the opposite diagonal's corners.
        flipped = shrunk.copy()
        flipped[:, 1] = flipped[:, 1][::-1]
        primary_pts.append(shrunk)
        opposite_pts.append(flipped)

    primary_pts = np.array(primary_pts).reshape(-1, 2)
    opposite_pts = np.array(opposite_pts).reshape(-1, 2)

    def _plot(points, color, label, title, outfile):
        # One figure per anchor family; points labelled 'B<bond>:<corner>'.
        plt.figure(figsize=(10, 8))
        plt.imshow(img)
        plt.scatter(points[:, 0], points[:, 1], c=color, s=50, label=label, edgecolors='black')
        for i, (x, y) in enumerate(points):
            plt.text(x, y, f'B{int(i/2)}:{i%2}', fontsize=10, color='white', ha='center', va='bottom')
        plt.title(title)
        plt.legend()
        plt.axis('off')
        plt.savefig(outfile)

    _plot(primary_pts, 'red', 'Anchor Positions',
          'Anchor Positions (Upper Left, Lower Right)', 'anchor_positions.png')
    _plot(opposite_pts, 'blue', 'Opposite Anchor Positions',
          'Opposite Anchor Positions (Lower Left, Upper Right)', 'Opposite_anchor_positions.png')
426
+
427
+
428
# Four corner points of a box, in (UL, UR, LL, LR) order.
def get_corners(bbox):
    """Return a (4, 2) array of the corners of [x_min, y_min, x_max, y_max]."""
    x_min, y_min, x_max, y_max = bbox
    return np.array([
        [x_min, y_min], [x_max, y_min],  # upper-left, upper-right
        [x_min, y_max], [x_max, y_max],  # lower-left, lower-right
    ])
435
+
436
# Find the atom box whose nearest corner is closest to any bond corner.
def find_nearest_atom(bond_corners, atom_bboxes, exclude_idx=None):
    """Return (index, min corner distance) of the closest atom box.

    Indices listed in *exclude_idx* are skipped; returns (None, inf) if all
    candidates are excluded.
    """
    best_dist = float('inf')
    best_idx = None
    for ai, (x1, y1, x2, y2) in enumerate(atom_bboxes):
        if exclude_idx is not None and ai in exclude_idx:
            continue
        # Corner set of the atom box (order is irrelevant for a minimum).
        atom_corners = ((x1, y1), (x2, y2), (x1, y2), (x2, y1))
        for bc in bond_corners:
            for ac in atom_corners:
                d = np.sqrt((bc[0] - ac[0]) ** 2 + (bc[1] - ac[1]) ** 2)
                if d < best_dist:
                    best_dist = d
                    best_idx = ai
    return best_idx, best_dist
451
# Vertex-to-vertex distance from a point set to the nearest atom box.
def get_min_distance_to_atom_box(vertices, atom_bboxes, exclude_idx=None):
    """Return (min corner distance, index of closest atom box); (-1, inf) if none."""
    best = float('inf')
    best_idx = -1
    for ai, ab in enumerate(atom_bboxes):
        if exclude_idx is not None and ai in exclude_idx:
            continue
        box_pts = np.array([[ab[0], ab[1]], [ab[2], ab[3]], [ab[0], ab[3]], [ab[2], ab[1]]])
        for v in vertices:
            for bp in box_pts:
                d = np.linalg.norm(v - bp)
                if d < best:
                    best = d
                    best_idx = ai
    return best, best_idx
466
+
467
+
468
# Overlap test used when checking for isolated atoms needing a bond.
def boxes_overlap(box1, box2):
    """Return True when the two axis-aligned boxes intersect or touch."""
    ax1, ay1, ax2, ay2 = box1
    bx1, by1, bx2, by2 = box2
    separated = ax2 < bx1 or bx2 < ax1 or ay2 < by1 or by2 < ay1
    return not separated
473
+
474
def min_corner_distance(box1, box2):
    """Smallest Euclidean distance between any corner of box1 and any of box2."""
    def _corners(b):
        return [[b[0], b[1]], [b[2], b[3]], [b[0], b[3]], [b[2], b[1]]]

    return min(
        np.sqrt((c1[0] - c2[0]) ** 2 + (c1[1] - c2[1]) ** 2)
        for c1 in _corners(box1)
        for c2 in _corners(box2)
    )
483
+
484
def clear_directory(path):
    """Delete every file, symlink and subdirectory inside *path* (path itself kept).

    Failures on individual entries are reported and skipped, not raised.
    """
    if not os.path.exists(path):
        print(f"Directory does not exist: {path}")
        return
    print(f"Clearing contents of: {path}")
    for filename in os.listdir(path):
        file_path = os.path.join(path, filename)
        try:
            if os.path.isfile(file_path) or os.path.islink(file_path):
                os.unlink(file_path)  # plain file or symlink
            elif os.path.isdir(file_path):
                shutil.rmtree(file_path)  # whole subtree
        except Exception as e:
            print(f'Failed to delete {file_path}. Reason: {e}')
498
+
499
+
500
def NHR_string(text):
    """Collapse R-group amine notations (NHR1, RHN2a, R2NH, ...) into 'NH*'.

    Unmatched text is returned unchanged.
    """
    rgroup_patterns = (
        r'NHR\d',         # NHR followed by one digit
        r'RHN[0-9a-z]+',  # RHN followed by digits / lowercase letters
        r'R[0-9a-z]+NH',  # R<id> prefix before NH
    )
    for pattern in rgroup_patterns:
        if re.search(pattern, text):
            return 'NH*'
    return text
521
+
522
+ from det_engine import normalize_ocr_text, check_and_fix_valence, rdkit_canonicalize_smiles
523
+ from det_engine import is_valid_chem_text,select_chem_expression
524
# Preprocess atom boxes to handle large functional groups
def preprocess_atom_boxes(atom_centers, atom_bbox, size_threshold_factor=2.5, min_subboxes=2):
    """
    Identify large atom boxes and split them into smaller sub-boxes of approximately average size.
    Returns updated atom_centers, atom_bbox, and a mapping of sub-boxes to original box IDs.
    """
    areas = [(b[2] - b[0]) * (b[3] - b[1]) for b in atom_bbox]
    # Robust average: drop the single smallest and largest areas when possible.
    if len(areas) > 2:
        avg_area = np.mean(sorted(areas)[1:-1])
    else:
        avg_area = np.mean(areas) if areas else 1.0

    new_centers = []
    new_boxes = []
    original_to_subbox = {}  # original atom index -> list of new sub-box indices
    subbox_to_original = {}  # new sub-box index -> original atom index
    next_idx = 0

    for orig_i, (box, center) in enumerate(zip(atom_bbox, atom_centers)):
        area = (box[2] - box[0]) * (box[3] - box[1])
        if area > avg_area * size_threshold_factor:
            # Oversized box (e.g. a group like CH2CH2CH2CH): split it into
            # roughly average-sized pieces along its longer dimension.
            count = max(min_subboxes, int(round(area / avg_area)))
            width = box[2] - box[0]
            height = box[3] - box[1]
            if width >= height:
                step = width / count
                pieces = [[box[0] + j * step, box[1], box[0] + (j + 1) * step, box[3]]
                          for j in range(count)]
            else:
                step = height / count
                pieces = [[box[0], box[1] + j * step, box[2], box[1] + (j + 1) * step]
                          for j in range(count)]
            new_boxes.extend(pieces)
            new_centers.extend([[(p[0] + p[2]) / 2, (p[1] + p[3]) / 2] for p in pieces])
            original_to_subbox[orig_i] = list(range(next_idx, next_idx + count))
            for j in range(count):
                subbox_to_original[next_idx + j] = orig_i
            next_idx += count
        else:
            # Normal-sized box: keep as-is with an identity mapping.
            new_boxes.append(box)
            new_centers.append(center)
            original_to_subbox[orig_i] = [next_idx]
            subbox_to_original[next_idx] = orig_i
            next_idx += 1

    return np.array(new_centers), new_boxes, original_to_subbox, subbox_to_original
589
+
590
+
591
+
592
+
593
+
594
+
595
# OCR stage used to read placeholder-atom / abbreviation text that the
# detector classifies as "other".
other2ppsocr = True
if other2ppsocr:
    # Latin-alphabet recognizer; GPU only when OpenCV reports a CUDA device.
    ocr = PaddleOCR(
        use_angle_cls=True,
        lang='latin', use_space_char=True, use_debug=False,
        use_gpu=True if cv2.cuda.getCudaEnabledDeviceCount() > 0 else False
    )

    # English recognizer with the SVTR_LCNet algorithm, CPU-only.
    # BUG FIX: was `ocr2 = ocr2 = PaddleOCR(...)` (redundant double assignment).
    ocr2 = PaddleOCR(use_angle_cls=True, use_gpu=False, use_debug=False,
                     rec_algorithm='SVTR_LCNet',
                     lang="en")
607
+ # outcsv_filename=f"{output_directory}/{prefix_f}_withOCR.csv"
608
+
609
+ # box_thresh=0.45# 292 -240=52
610
+ box_thresh=0.5# 292 -233=59
611
+ useocr=True
612
+ box_matter=0
613
+ getacc=False
614
+ getfpsim=False
615
+ visual_check=False
616
+
617
+
618
+
619
+ # da='acs'
620
+ # src_dir=f"D:\RPA\codes_share\wsl_\chem_box\\real\\real\{da}"
621
+ # src_file=f"{src_dir}.csv"
622
+ # df = pd.read_csv(src_file)
623
+ # dst_dirac = f"D:\RPA\codes_share\wsl_\chem_box\\need2check\{da}_ac"
624
+ # dst_dirb = f"D:\RPA\codes_share\wsl_\chem_box\\need2check\{da}_b"
625
+ da='acs'
626
+ #198th row, fixwd with expanded smiles
627
+ #326th, Tos we use the SO2Ph not SiC3 version,conflict fixed
628
+
629
+ #for view
630
+ # view_check_dir=f"D:\RPA\codes_share\wsl_\chem_box\\need2check\\{da}_fixedView"
631
+ da='CLEF'#NOTE
632
+ #462 S[O]a fixed,
633
+ #179 NHR8 fixed as NH-R8
634
+ #fix rows@582,750,411,7612,761 [(CH2)m] [(CH2)q] [(CH2)s] RDKIT NOT readable fixed as [CH2]
635
+ # 30,214, 795, 856, 583, 654, 618,339,138, 927, 203, 869, 261, 634, 180, 63,758, 718, 741,832,88,250, 799,303,956,810,596|bond erro, wrong smiles
636
+ #TODO still failed:1||SO2 mutil-rows from 992
637
+
638
+ da='UOB'#NOTE TODO fix rows@5119, 3420,990,1082,2451,3626,1634,627,5385
639
+ #all paseed @ v3
640
+ da='USPTO'#NOTE TODO fix rows@ 458, 10+ [O.], [NH2.],|| 4566,5523, ima!=smi, 3927,5234,4062|poly unitProb
641
+ #1658,4625,4944 #also SO3, SOOO erro
642
+ #1164 CHO,2703 CN NC err, 4421 CH2O erro
643
+ #4921, Rgroup error fixwed
644
+ #58, SMILES WRONG fixed (NH4NO)2
645
+ #2352 Fix wrong smiles
646
+ #4590 fix wrong smi
647
+ #3381, 4921 wrong smi fixed
648
+ #3071 image C8H13 may not expandable
649
+ # da='staker'#NOTE TODO fix rows@11422(del 11420.png, as it is table not chemMol)
650
+ #SO3 as SOOO, should be S(=O)(=O)O, as o-o-o strange in chemicstry, this erro 50 as below
651
+ #1971,5770,5972,5973,7541,7542,7666,7854,8258,8917,8918,11129,13281,14109,17131,17132,17189,17493,21091,21093,22314,22315,24524,24525,27294,27295,27296,27297,27586,27587,29562,29766,32835,33517,36197,36198,38199,38200,38661,38663,39174,39410,46717,48380,48381,48382,48443,48624,48625
652
+
653
+
654
+
655
+
656
+ da='JPO'
657
+ # da='staker'
658
+
659
+ if args.dataname:
660
+ da=args.dataname
661
+
662
+ # ac_b=False
663
+ ac_b=False
664
+ ac_b_smilesnotsame_writJson=True
665
+ if ac_b:
666
+ view_check_dir=f"D:\RPA\codes_share\wsl_\X-AnyLabeling\\need2check\\view_check_{da}\\failed"
667
+ view_dirac=f"{view_check_dir}/{da}_ac"
668
+ view_dirb=f"{view_check_dir}/{da}_b"
669
+ dst_dirac =view_dirac#when double check used
670
+ dst_dirb =view_dirb
671
+
672
+ # Construct paths using os.path.join
673
+ src_dir = cur_dir
674
+ src_file = os.path.join(src_dir, f"{da}.csv")
675
+ # df = pd.read_csv(src_file)
676
+ # print(f"src_file:\n{src_file}")
677
+ # Construct check and view directories
678
+ # view_check_dir2 = os.path.join(src_dir, f"{da}_fixedView", "failed")
679
+ # view_check_dir2 = os.path.join(src_dir, f"view_check_{da}", "failed")
680
+ view_check_dir2 = os.path.join(src_dir, f"view_check_{da}", "v3")#v3 has the manulay ac b json
681
+
682
+ N=1
683
+ if args.number:
684
+ N=int(args.number)
685
+
686
+ # view_dirac2 = os.path.join(view_check_dir2, f"{da}_ac_N_{N}")
687
+ # view_dirb2 = os.path.join(view_check_dir2, f"{da}_b_N_{N}")
688
+ view_dirac2 = os.path.join(view_check_dir2, f"{da}_ac")
689
+ view_dirb2 = os.path.join(view_check_dir2, f"{da}_b")
690
+
691
+ view_dirac_tmp = os.path.join(view_check_dir2, f"{da}_actmp")
692
+ view_dirac_tmp_debug=True
693
+
694
+
695
+ if ac_b:
696
+ need2mkdir=[view_check_dir,view_dirac, view_dirb, view_check_dir2,view_dirac2, view_dirb2]
697
+ else:
698
+ need2mkdir=[ view_check_dir2,view_dirac2, view_dirb2,view_dirac_tmp]
699
+ for dir_v in need2mkdir :
700
+
701
+ if not os.path.exists(dir_v):
702
+ os.makedirs(dir_v)
703
+
704
+ # ac_b=False
705
# Flag for the original two-pass workflow: when True, the ac/b view directories
# were cleared and rebuilt from the dataframe (kept below, commented out, for
# reference).  NOTE(review): indentation reconstructed from a diff render —
# confirm block nesting against the original file.
ac_b=False
# if ac_b:#update _ac _b
#     # clear both view directories
#     clear_directory(view_dirac2) #NOTE we only check for the better models faileds
#     clear_directory(view_dirb2)

#     #note box not equal as abbv eixsits, process single bond..TODO need check and fixing, may be need rdkit smiles Match
#     df['file_name'] = df['file_path'].str.split('/').str[-1]
#     # df['file_base'] =f"{da}_" + df['file_name'].str.replace('.png', '', regex=False)
#     df['file_base'] = df['file_name'].str.replace('.png', '', regex=False)



# outcsv_filename=f"{src_dir}/{da}_OUTPUTwithOCR.csv"
outcsv_filename=os.path.join(src_dir, f"{da}_OUTPUTwithOCR.csv")

if getacc:
    # Accuracy-evaluation mode: open the summary / failure logs and initialise
    # the accumulators used while scoring predictions.
    acc_summary=f"{outcsv_filename}.I2Msummary.txt"
    flogout = open(f'{acc_summary}' , 'w')  # NOTE(review): handle never closed in view
    flogout2 = open(f'{outcsv_filename}_acBoxWrong' , 'a')  # NOTE(review): handle never closed in view
    failed=[]
    failed_fb=[]
    mydiff=[]
    simRD=0
    sim=0
    simRDlist=[]
    mysum=0

    # Collects per-image SMILES results (original / predicted / expanded).
    smiles_data = pd.DataFrame({'file_name': [],
                                'SMILESori':[],
                                'SMILESpre':[],
                                'SMILESexp':[],
                                })

# rows_check = df
miss_file=[]
miss_filejs=[]
# for id_, row in rows_check.iterrows():
debug=False

rt_out=False
if not ac_b:
    view_dirac=view_dirac2
    view_dirb=view_dirb2
    dst_dirac =view_dirac#when double check used
    dst_dirb =view_dirb
    test_dir=f'./test/'#TODO WEB_dev put test images here
    # pngs=[f for f in os.listdir(view_dirac2) if '.png' in f]
    pngs=[f for f in os.listdir(test_dir) if '.png' in f]
    # if da=='staker':
    #     pngs=[f for f in os.listdir("/nfs_home/bowen/works/pys/codes/i2m/datas/real/staker") if '.png' in f]

    rt_out=True
    # view_check_dir3=f"D:\RPA\codes_share\wsl_\chem_box\\need2check\\{da}_fixedView\\v3"
    # view_check_dir3= os.path.join(src_dir, f"{da}_fixedView", "failed")
    view_check_dir3= os.path.join(src_dir, f"view_check_{da}", "v4")#with model output
    view_dirac3=f"{view_check_dir3}/{da}_ac"
    view_dirb3=f"{view_check_dir3}/{da}_b"
    for dir_v in [view_check_dir3,view_dirac3, view_dirb3]:
        if not os.path.exists(dir_v):
            os.makedirs(dir_v)
    # pngs=[f for f in os.listdir(view_dirac3) if '.png' in f]

# Mismatch flags: abbreviation expansion can make atom/bond counts differ from
# the original annotation (a/b not equal).
acn=False
bn=False
771
+
772
+
773
+ import torchvision.transforms.v2 as T
774
+
775
def image_to_tensor(image_path, debug=True):
    """Load an image, force RGB, resize to 640x640 and return a float32 tensor.

    Args:
        image_path: Path to the image file.
        debug: When True, print a note whenever a mode conversion happens.

    Returns:
        (tensor, w, h): the (3, 640, 640) float32 tensor in [0, 1], plus the
        ORIGINAL image width and height (needed later to rescale boxes).
    """
    image = Image.open(image_path)
    w, h = image.size

    # Normalise any non-RGB mode (grayscale "L", palette, RGBA, ...) to RGB so
    # the tensor always has 3 channels.  The original code special-cased "L"
    # with a duplicate branch that did the exact same conversion.
    if image.mode != "RGB":
        if debug:
            print(f"Detected {image.mode} mode image, converting to RGB...")
        image = image.convert("RGB")

    # Resize, convert to tensor (ToTensor already scales to [0, 1]) and force
    # float32 explicitly, since ToTensor's dtype can vary across versions.
    transform = T.Compose([
        T.Resize((640, 640)),
        T.ToTensor(),
        lambda x: x.to(torch.float32),
    ])

    tensor = transform(image)

    return tensor, w, h
798
def ouptnp2abc(output, idx_to_labels):
    """Split a flat detection dict into atom / bond / charge sub-dicts.

    Args:
        output: dict with numpy arrays under 'bbox', 'bbox_centers', 'scores'
            and 'pred_classes' (integer class ids).
        idx_to_labels: mapping from integer class id to label string.

    Returns:
        (output_a, output_b, output_c): atom, bond and charge detections.
        Each is a dict of plain Python lists with 'pred_classes' mapped to
        label strings via idx_to_labels; lists stay empty when no detection
        of that family is present.
    """
    # Class-id ranges per detection family (fixed by the training label map).
    atom_labels = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
    bond_labels = [13, 14, 15, 16, 17, 18]
    charge_labels = [19, 20, 21, 22, 23]

    def _filter(labels):
        # Select the rows of one family and convert arrays to plain lists.
        result = {'bbox': [], 'bbox_centers': [], 'scores': [], 'pred_classes': []}
        mask = np.isin(output['pred_classes'], labels)
        if np.any(mask):
            result['bbox'] = output['bbox'][mask].tolist()
            result['bbox_centers'] = output['bbox_centers'][mask].tolist()
            result['scores'] = output['scores'][mask].tolist()
            result['pred_classes'] = [idx_to_labels[idx]
                                      for idx in output['pred_classes'][mask].tolist()]
        return result

    # Original code repeated the filtering block three times verbatim; the
    # helper removes that duplication without changing the returned structure.
    output_a = _filter(atom_labels)
    output_b = _filter(bond_labels)
    output_c = _filter(charge_labels)

    return output_a, output_b, output_c
837
+
838
def bbox2center(bbox):
    """Return the (x, y) center points of an (N, 4) array of xyxy boxes."""
    # Midpoint of each coordinate pair, stacked column-wise into an (N, 2) array.
    return np.stack(((bbox[:, 0] + bbox[:, 2]) / 2,
                     (bbox[:, 1] + bbox[:, 3]) / 2), axis=1)
844
+
845
class bcolors:
    """ANSI escape sequences for coloured / styled terminal output.

    Usage pattern: f"{bcolors.FAIL}message{bcolors.ENDC}" — always terminate
    with ENDC to reset the terminal style.
    """
    HEADER = '\033[95m'     # bright magenta
    OKBLUE = '\033[94m'     # bright blue
    OKCYAN = '\033[96m'     # bright cyan
    OKGREEN = '\033[92m'    # bright green
    WARNING = '\033[93m'    # bright yellow
    FAIL = '\033[91m'       # bright red
    ENDC = '\033[0m'        # reset all attributes
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'
855
+
856
# Post-processor: converts raw RT-DETR logits/boxes into scored, labelled
# detections at the requested image scale.
postprocessor=RTDETRPostProcessor(classes_dict=idx_to_labels, use_focal_loss=True, num_top_queries=300, remap_mscoco_category=False)

#load onnx model
import torch.onnx
import onnx
import onnxruntime as ort
# NOTE(review): hard-coded absolute path to the exported detector (2025-06-05
# export); will not resolve outside the original machine — parameterise before deploy.
onnx_model_path = "/nfs_home/bowen/works/pys/codes/i2m/I2M_R4.onnx"#20250605
863
def image_to_tensor2(image_path):
    """Read an image with OpenCV, resize to 640x640 and return a CHW float tensor.

    Unlike image_to_tensor, this mirrors the deployment pre-processing:
    cv2.imdecode over np.fromfile is used so paths with non-ASCII characters
    load correctly (plain cv2.imread fails on them on some platforms).

    Args:
        image_path: Path to the image file.

    Returns:
        (tensor, image_w, image_h): a (1, 3, 640, 640) float32 torch tensor
        scaled to [0, 1], plus the original image width and height.

    Raises:
        FileNotFoundError: if image_path is None or does not exist.  The
        original version silently fell through and returned None here, which
        crashed later at tuple unpacking; failing fast is clearer.
    """
    if image_path is None or not os.path.exists(image_path):
        raise FileNotFoundError(f"image not found: {image_path}")

    # Decode from a raw byte buffer to tolerate non-ASCII paths, then BGR->RGB.
    cv_image = cv2.imdecode(np.fromfile(image_path, dtype=np.uint8), -1)
    input_image = cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB)

    image_h, image_w = input_image.shape[:2]
    input_h, input_w = 640, 640

    # Scale factors from the original size to the 640x640 network input.
    ratio_h = input_h / image_h
    ratio_w = input_w / image_w
    print(ratio_h, ratio_w)

    image = cv2.resize(
        input_image, (0, 0), fx=ratio_w, fy=ratio_h, interpolation=2
    )
    image = image.transpose((2, 0, 1))  # HWC to CHW
    image = np.ascontiguousarray(image).astype("float32")
    image /= 255  # 0 - 255 to 0.0 - 1.0
    if len(image.shape) == 3:
        image = image[None]  # add the batch dimension
    return torch.from_numpy(image), image_w, image_h
891
+
892
 + # Prepare input data: convert torch tensors to numpy for ONNX Runtime.
893
def to_numpy(tensor):
    """Return the tensor as a CPU numpy array, detaching it if grad-tracked."""
    if tensor.requires_grad:
        tensor = tensor.detach()
    return tensor.cpu().numpy()
895
+
896
# Load the ONNX graph and run the structural validity checker before use.
onnx_model = onnx.load(onnx_model_path)
onnx.checker.check_model(onnx_model)
print("ONNX模型检查通过")
# Create the ONNX Runtime session used for all inference below.
ort_session = ort.InferenceSession(onnx_model_path)
onnx_=True  # True: route inference through ONNX Runtime (the PyTorch path is commented out further down)
dfm=0
904
+
905
+ # ff='US20030130506A1_p0046_x1541_y1396_c00157'
906
+ # for ff in pngs:
907
+ def main():
908
+ for id_, ff in enumerate(pngs):
909
+ # if 'US20060154945A1_p0016_x0402_y1570_c00053' not in ff: continue
910
+ # indices = df.index[df['image_id'] == ff[:-4]].tolist()
911
+ # indices = df.index[df['file_base'] == ff[:-4]].tolist()
912
+ # try:
913
+ # id_=indices[0]
914
+ # except Exception as e:
915
+ # print([indices,ff])
916
+ # raise e
917
+ # SMILESori=rows_check.iloc[id_].SMILES
918
+ # file_base=rows_check.iloc[id_].file_base
919
+ # # if debug: print(id_, file_base)
920
+ # image_path= os.path.join(dst_dirac, f"{file_base}.png")
921
+
922
+ # ac_datadir=os.path.join(dst_dirac, f"{file_base}.json")
923
+ # ac_exist= os.path.exists(ac_datadir)
924
+ # if not ac_exist:
925
+ # miss_filejs.append(ac_datadir)
926
+ # continue
927
+ # image_path= os.path.join(test_dir, f"{ff}")
928
+ image_path= os.path.join(f"{ff}")
929
+ SMILESori=''
930
+ print(f"@@@@@@@@@@@@@@@@@@@@@@@ {id_}\n{image_path}\n {SMILESori}")
931
+ # print(image_path,b_datadir,ac_datadir)
932
+
933
+ img_ori = Image.open(image_path).convert('RGB')
934
+ w_ori, h_ori = img_ori.size # 获取原始图像的尺寸
935
+ # if [w_ori, h_ori]!=[256,256] and da=='staker':
936
+ # print(f"图像的尺寸不为256x256,而是{w_ori}x{h_ori},请检查图像是否正确:\n{ff}")
937
+ # continue
938
+
939
+ # print(f"图像的尺寸",[w_ori, h_ori ])
940
+ scale_x = 1000 / w_ori
941
+ scale_y = 1000 / h_ori
942
+ img_ori_1k = img_ori.resize((1000,1000))
943
+ # Example usage: #change thie image
944
+ tensor,w,h = image_to_tensor(image_path)
945
+ # tensor,w,h = image_to_tensor2(image_path)
946
+ tensor=tensor.unsqueeze(0)
947
+ if onnx_:
948
+ ort_inputs = {
949
+ ort_session.get_inputs()[0].name: to_numpy(tensor),
950
+ # ort_session.get_inputs()[1].name: to_numpy(dummy_grid)
951
+ }
952
+ ort_outputs = ort_session.run(None, ort_inputs)
953
+ # 转换为PyTorch格式
954
+ onnx_pred_logits = torch.from_numpy(ort_outputs[0])
955
+ onnx_pred_boxes = torch.from_numpy(ort_outputs[1])
956
+ # 构建与原模型一致的输出字典
957
+ onnx_output_dict = {
958
+ "pred_logits": onnx_pred_logits,
959
+ "pred_boxes": onnx_pred_boxes,
960
+ }
961
+ # else:
962
+ # #use original model
963
+ # with torch.no_grad():
964
+ # # print("training",_model.training)
965
+ # outputs_tensor = _model(tensor)
966
+ # 打印并比较结果
967
+ # print("PyTorch输出:", outputs_tensor)
968
+ # print("ONNX Runtime输出:", ort_outputs[0],ort_outputs[1],len(ort_outputs))
969
+
970
+ ori_size=torch.Tensor([w,h]).long().unsqueeze(0)
971
+ # result_ = postprocessor(outputs_tensor, ori_size)
972
+ result_ = postprocessor(onnx_output_dict, ori_size)
973
+
974
+ score_=result_[0]['scores']
975
+ boxe_=result_[0]['boxes']
976
+ label_=result_[0]['labels']
977
+ selected_indices =score_ > box_thresh
978
+ output={
979
+ 'labels': label_[selected_indices].to("cpu").numpy(),
980
+ 'boxes': boxe_[selected_indices].to("cpu").numpy(),
981
+ 'scores': score_[selected_indices].to("cpu").numpy()
982
+ }
983
+ center_coords=bbox2center(output['boxes'])
984
+ output = {'bbox': output["boxes"],
985
+ 'bbox_centers': center_coords,
986
+ 'scores': output["scores"],
987
+ 'pred_classes': output["labels"]}
988
+ output_a, output_b, output_c= ouptnp2abc(output,idx_to_labels)
989
+
990
+
991
+
992
+ if debug:print("c,a,b>>>>>",len(output_c['pred_classes']),len(output_a['pred_classes']),len(output_b['pred_classes']))
993
+ if len(output_a['pred_classes'])==0:
994
+ file_path = 'Check_AboxIs0.txt'
995
+ content = f'{image_path}@@{id_}---{image_path}\n'
996
+ # 文件存在则追加写入,不存在则创建并写入
997
+ with open(file_path, 'a', encoding='utf-8') as f:
998
+ f.write(content)
999
+ continue #may need manulay labeling
1000
+
1001
+ overlap_records = []
1002
+ to_remove = set()
1003
+ bond_boxes = output_b['bbox']
1004
+
1005
+ bboxes = output_a['bbox'].copy()
1006
+ a_center = output_a['bbox_centers'].copy()
1007
+
1008
+ scores = output_a['scores'].copy()
1009
+ pred_classes = output_a['pred_classes'].copy()
1010
+ to_remove = set()
1011
+
1012
+ # 计算所有 atom bbox 之间的 IoU, 并根据 IoU 进行处理
1013
+ for i in range(len(bboxes)):
1014
+ for j in range(i + 1, len(bboxes)):
1015
+ # iou, relationship, inter_area, union_area = calculate_iou(bboxes[i], bboxes[j])
1016
+ x_min1, y_min1, x_max1, y_max1 = bboxes[i]
1017
+ x_min2, y_min2, x_max2, y_max2 = bboxes[j]
1018
+ # 计算交集坐标
1019
+ x_min_inter = max(x_min1, x_min2)
1020
+ y_min_inter = max(y_min1, y_min2)
1021
+ x_max_inter = min(x_max1, x_max2)
1022
+ y_max_inter = min(y_max1, y_max2)
1023
+ # 计算交集面积
1024
+ inter_width = max(0, x_max_inter - x_min_inter)
1025
+ inter_height = max(0, y_max_inter - y_min_inter)
1026
+ inter_area = inter_width * inter_height
1027
+ # 计算两个框的面积
1028
+ area1 = (x_max1 - x_min1) * (y_max1 - y_min1)
1029
+ area2 = (x_max2 - x_min2) * (y_max2 - y_min2)
1030
+ # 计算并集面积
1031
+ union_area = area1 + area2 - inter_area
1032
+ # 计算 IoU
1033
+ iou = inter_area / union_area if union_area > 0 else 0
1034
+ score_i = scores[i] if scores[i] is not None else -1
1035
+ score_j = scores[j] if scores[j] is not None else -1
1036
+ # 完全重合
1037
+ if iou == 1:
1038
+ if score_i > score_j:
1039
+ to_remove.add(j)
1040
+ else:
1041
+ to_remove.add(i)
1042
+ elif iou>=0.8 and iou <1.0:#NOTE fix me if not right
1043
+ if score_i > score_j:
1044
+ to_remove.add(j)
1045
+ if debug: print([i,j,score_i,score_j],iou,f"will remove j {j}, i-j {i,j}")
1046
+ else:
1047
+ to_remove.add(i)
1048
+ if debug: print([i,j,score_i,score_j],iou,f"will remove i {i}, i-j {i,j} ")
1049
+
1050
+ # 包含关系
1051
+ elif iou > 0 and iou < 0.89 :
1052
+ if debug: print([i,j,score_i,score_j],iou,"<<<<<<111")
1053
+ if inter_area == area1 and area1 < area2: # bbox[j] 包含 bbox[i]
1054
+ large_idx, small_idx = j, i
1055
+ elif inter_area == area2 and area2 < area1: # bbox[i] 包含 bbox[j]
1056
+ large_idx, small_idx = i, j
1057
+ else:
1058
+ if debug: print([i,j,score_i,score_j],iou,'OVERLAP without processed this version')
1059
+ continue
1060
+ # 检查是否包含 bond box
1061
+ contains_bond = False
1062
+ for bond_bbox in bond_boxes:
1063
+ if is_contained_in(bond_bbox, bboxes[large_idx]):
1064
+ contains_bond = True
1065
+ # 调整较大 bbox
1066
+ bboxes[large_idx] = adjust_bbox1(bboxes[large_idx], bboxes[small_idx], bond_bbox)
1067
+ # to_remove.add(small_idx)
1068
+ break
1069
+ if not contains_bond:
1070
+ to_remove.add(small_idx)#NOTE use the cutoff >0.45,
1071
+ elif iou==0:#==0
1072
+ pass
1073
+ else:
1074
+ print([i,j,score_i,score_j],iou,"<<<<<<222")
1075
+ print('what this case ???')
1076
+
1077
+ # 删除被移除的 bbox
1078
+ atom_bboxes = [bboxes[i] for i in range(len(bboxes)) if i not in to_remove]
1079
+ atom_scores = [scores[i] for i in range(len(scores)) if i not in to_remove]
1080
+ atom_centers = [a_center[i] for i in range(len(a_center)) if i not in to_remove]
1081
+ atom_classes = [pred_classes[i] for i in range(len(pred_classes)) if i not in to_remove]
1082
+ #TODO need sort box with x first, then y dim, useful for * with multi neiborbond
1083
+ # Sort atom_bboxes and atom_scores by x1 (bbox[0]) first, then y1 (bbox[1])
1084
+ sorted_indices = sorted(range(len(atom_bboxes)), key=lambda i: (atom_bboxes[i][0], atom_bboxes[i][1]))
1085
+ atom_bboxes = [atom_bboxes[i] for i in sorted_indices]
1086
+ atom_scores = [atom_scores[i] for i in sorted_indices]
1087
+ atom_centers = [atom_centers[i] for i in sorted_indices]
1088
+ atom_classes = [atom_classes[i] for i in sorted_indices]
1089
+
1090
+ print(len(atom_classes),'xxxxxxxx')
1091
+ bond_bbox = output_b['bbox'].copy()
1092
+ bond_scores = output_b['scores'].copy()
1093
+ bond_classes = output_b['pred_classes'].copy()
1094
+
1095
+ if len(atom_bboxes)!=len(output_a['bbox']):
1096
+ # print(f"need manualy fix ac json file------ {file_base}")
1097
+ if getacc:
1098
+ flogout2.write(f"fix ac json file---: {file_base} \n")
1099
+ # raise ValueError(f"need manualy fix ac json file------ {file_base}")
1100
+ # NOTE NEED this codes follow code not del box , used4 prepare recorrect json boxfiles
1101
+ # atom_bboxes = output_a['bbox'].copy()
1102
+ # atom_scores = output_a['scores'].copy()
1103
+ # atom_classes = output_a['pred_classes'].copy()
1104
+ # atom_centers = output_a['bbox_centers'].copy()
1105
+ # sorted_indices = sorted(range(len(atom_bboxes)), key=lambda i: (atom_bboxes[i][0], atom_bboxes[i][1]))
1106
+ # atom_bboxes = [atom_bboxes[i] for i in sorted_indices]
1107
+ # atom_scores = [atom_scores[i] for i in sorted_indices]
1108
+ # atom_centers = [atom_centers[i] for i in sorted_indices]
1109
+ # atom_classes = [atom_classes[i] for i in sorted_indices]
1110
+
1111
+
1112
+ # atom_bbox=final_bboxes
1113
+ bonds = dict()
1114
+ b2aa = dict()
1115
+ singleAtomBond = dict()
1116
+ bondWithdirct = dict()
1117
+ _margin = 0
1118
+ bond_direction = dict()
1119
+
1120
+ # Preprocess atom boxes
1121
+ atom_centers_, atom_bbox_, original_to_subbox, subbox_to_original = preprocess_atom_boxes(atom_centers, atom_bboxes)
1122
+ # Build KDTree with updated atom centers
1123
+ tree_atom = KDTree(atom_centers_)#have to includ the splited box
1124
+ if debug:
1125
+ print(f"KDTree built with {len(atom_centers_)} atom centers")
1126
+
1127
+ for bi, (bbox, bond_type) in enumerate(zip(bond_bbox, bond_classes)):
1128
+ score = bond_scores[bi]
1129
+ if score is None:
1130
+ score = 1.0 # From manual addition
1131
+ bond_scores[bi] = score
1132
+
1133
+ anchor_positions = (np.array(bbox) + [_margin, _margin, -_margin, -_margin]).reshape([2, -1])
1134
+ oposite_anchor_positions = anchor_positions.copy()
1135
+ oposite_anchor_positions[:, 1] = oposite_anchor_positions[:, 1][::-1]
1136
+ anchor_positions = np.concatenate([anchor_positions, oposite_anchor_positions])
1137
+
1138
+ # Query KDTree for nearest atoms
1139
+ dists, neighbours = tree_atom.query(anchor_positions, k=1)
1140
+ if np.argmin((dists[0] + dists[1], dists[2] + dists[3])) == 0:
1141
+ begin_idx, end_idx = neighbours[:2]
1142
+ else:
1143
+ begin_idx, end_idx = neighbours[2:]
1144
+
1145
+ # Map sub-box indices back to original atom indices
1146
+ atom1_idx = int(subbox_to_original[int(begin_idx)])
1147
+ atom2_idx = int(subbox_to_original[int(end_idx)])
1148
+
1149
+ if atom1_idx == atom2_idx:
1150
+ if debug:
1151
+ print(f"singleAtomBond detected with bond id:{bi} atomIdx1 == atomIdx2 ::{[atom1_idx, atom2_idx]}")
1152
+ singleAtomBond[bi] = [atom1_idx]
1153
+
1154
+ min_ai = min([atom1_idx, atom2_idx])
1155
+ max_ai = max([atom1_idx, atom2_idx])
1156
+
1157
+ if debug:
1158
+ print(f"Bond {bi}: [{min_ai}, {max_ai}]")
1159
+
1160
+ # Assign bond type
1161
+ if bond_type in ['single', 'wdge', 'dash', '-', 'NONE', 'ENDUPRIGHT', 'BEGINWEDGE', 'BEGINDASH', 'ENDDOWNRIGHT']:
1162
+ bond_ = [min_ai, max_ai, 'SINGLE', score]
1163
+ if bond_type in ['wdge', 'dash', 'ENDUPRIGHT', 'BEGINWEDGE', 'BEGINDASH', 'ENDDOWNRIGHT']:
1164
+ bondWithdirct[bi] = [min_ai, max_ai, 'SINGLE', score, bond_type]
1165
+ elif bond_type == '=':
1166
+ bond_ = [min_ai, max_ai, 'DOUBLE', score]
1167
+ elif bond_type == '#':
1168
+ bond_ = [min_ai, max_ai, 'TRIPLE', score]
1169
+ elif bond_type == ':':
1170
+ bond_ = [min_ai, max_ai, 'AROMATIC', score]
1171
+ else:
1172
+ if debug:
1173
+ print(f"Unknown bond_type: {bond_type} for bond {bi} [{min_ai, max_ai}]")
1174
+ bond_ = [min_ai, max_ai, 'SINGLE', score]
1175
+
1176
+ bonds[bi] = bond_
1177
+ b2aa[bi] = sorted([min_ai, max_ai])
1178
+
1179
+ if debug:
1180
+ print(f"bonds {len(bonds)}, b2aa {len(b2aa)}, singleAtomBond {len(singleAtomBond)}, bondWithdirct {len(bondWithdirct)}")
1181
+
1182
+
1183
+ #try to set up a2b, baesed on bond-anchor_positions--atom center relationship
1184
+ a2b=dict()#may be updated as following singleAtomBond cases process
1185
+ isolated_a=set()
1186
+ aa2b_d2=dict()
1187
+ for k,v in b2aa.items():
1188
+ vt=(v[0],v[1])
1189
+ if vt in aa2b_d2:
1190
+ aa2b_d2[vt].append(k)
1191
+ else:
1192
+ aa2b_d2[vt]=[k]
1193
+
1194
+ for a in set(v):
1195
+ if a not in a2b.keys():
1196
+ a2b[a]=[k]
1197
+ else:
1198
+ a2b[a].append(k)
1199
+
1200
+ # 初始化 a2neib, iso_lated atom box and singleAtomBond box process need
1201
+ a2neib = {}
1202
+ # 遍历 a2b,构建邻居关系
1203
+ for atom, bns in a2b.items():
1204
+ neighbors = set() # 使用集合去重
1205
+ for bond in bns:
1206
+ atom_pair = b2aa[bond] # 获取 bond 连接的原子对
1207
+ # 如果当前原子在 atom_pair 中,添加另一个原子作为邻居
1208
+ nei={ai for ai in atom_pair if ai !=atom }
1209
+ neighbors.update(nei)
1210
+ # if atom in atom_pair:
1211
+ # other_atom = atom_pair[0] if atom == atom_pair[1] else atom_pair[1]
1212
+ # neighbors.add(other_atom)
1213
+ a2neib[atom] = sorted(list(neighbors)) # 转换为有序列表
1214
+
1215
+ #check isolated atom exsit, if need add bond for isloated atom box when overlaping with other atom box
1216
+ isolated_a=set()
1217
+ for ai, a_lab in enumerate(atom_classes):
1218
+ if ai not in a2b.keys():
1219
+ isolated_a.add(ai)
1220
+ if debug:print("detected possible isolated atom:", isolated_a)
1221
+
1222
+
1223
+ repeate_bonds={k:v for k,v in aa2b_d2.items() if len(v)>=2 }
1224
+ if debug:print(f"repeat bond box ids {repeate_bonds}")
1225
+ #get the minimu size of bond box, check isolated_a atom box overlap with other atom box, if overlap, then add bond box (default bond label with single, score 1.0) between them
1226
+ # update a2b,b2aa, and bond box bond_classes, elif not box not overlap, the isolated_a box min(4 point of box cornners to other atom box connrer) enough small than the existed bond box size
1227
+ if len(isolated_a)>0:
1228
+ isolated_a2del=[]
1229
+ # 计算现有键的最小尺寸
1230
+ bond_sizes = []
1231
+ for bbox in bond_bbox:
1232
+ width = bbox[2] - bbox[0]
1233
+ height = bbox[3] - bbox[1]
1234
+ size = min(width, height) # 使用较小边作为键的尺寸
1235
+ bond_sizes.append(size)
1236
+ min_bond_size = min(bond_sizes) if bond_sizes else 10.0 # 默认值若无键
1237
+ if debug:print("min_bond_size ",min_bond_size, 10)
1238
+ new_bond_idx = len(bond_bbox)
1239
+ isolated_aFound=[]
1240
+ singleAtomBond_fixed=[]
1241
+ # at2b_dist=dict()
1242
+
1243
+ for iso_atom in isolated_a:
1244
+ iso_box = atom_bboxes[iso_atom]
1245
+
1246
+ #with SingleAtomBond first then check with other atom box, may a1a2 repeat on >=two bonds
1247
+ for bi,atom_idx_list in singleAtomBond.items():
1248
+ bond_box = bond_bbox[bi]
1249
+ atom1_idx = atom_idx_list[0]
1250
+ bond_vertices = get_corners(bond_box)
1251
+ # 计算 atom1_center 到 bond box 4 个顶点的距离
1252
+ atom1_center = atom_centers[atom1_idx]
1253
+ distances = [np.linalg.norm(np.array(atom1_center) - v) for v in bond_vertices]
1254
+ closest_indices = np.argsort(distances)[:2] # 距离最小的两个顶点
1255
+ connected_vertices = bond_vertices[closest_indices]
1256
+ unconnected_vertices = bond_vertices[[i for i in range(4) if i not in closest_indices]]
1257
+ # exclude_=a2neib[atom1_idx]
1258
+ exclude_=[atom1_idx]+a2neib[atom1_idx]#add it self
1259
+ print(f'exclude this atom itself :: {exclude_},and its neiboughs {a2neib[atom1_idx]}')
1260
+ # 找到 atom2(未连接端到所有 atom box 顶点的最小距离)
1261
+ # _, atom2_idx_ = get_min_distance_to_atom_box(unconnected_vertices, atom_bboxes, exclude_idx=exclude_)
1262
+ atom2_idx_, dist2 = find_nearest_atom(unconnected_vertices, atom_bboxes, exclude_idx=exclude_)
1263
+ if iso_atom == atom2_idx_:
1264
+ # 从 atom1 找到最近的另一个 atom (atom2_1)
1265
+ if atom2_idx_< atom1_idx:
1266
+ k=[atom2_idx_, atom1_idx]
1267
+ else:
1268
+ k=[atom1_idx, atom2_idx_]
1269
+
1270
+ if atom2_idx_ not in a2neib[atom1_idx]:
1271
+ b2aa[bi]=k
1272
+ bonds[bi][0]=k[0]
1273
+ bonds[bi][1]=k[1]
1274
+ a2b.setdefault(iso_atom, []).append(bi)
1275
+
1276
+ if debug: print(f'@@isolated_a fix the SingleAtomBond {bi} as bond:{bonds[bi]} !!')
1277
+ singleAtomBond_fixed.append(bi)
1278
+ isolated_aFound.append(atom2_idx_)
1279
+
1280
+ if len(repeate_bonds)>0:
1281
+ at2b_dist=dict()#NOTE the case repeate bonds with isolated atom box
1282
+ iso_box_vertices = get_corners(iso_box)
1283
+ iso_atom_center = atom_centers[iso_atom]
1284
+ bond_box_idx_, bond_box_dist = find_nearest_atom(iso_box_vertices, bond_bbox, exclude_idx=[])
1285
+ for a1a2,bis in repeate_bonds.items():#{(2, 3): [3, 4]}
1286
+ for bi in bis:
1287
+ if bi ==bond_box_idx_:
1288
+ bond_box = bond_bbox[bi]
1289
+ bond_vertices = get_corners(bond_box)
1290
+ a1_,a2_=a1a2
1291
+ a1_atombox= atom_bboxes[a1_]
1292
+ a2_atombox= atom_bboxes[a2_]
1293
+ a1_flag= boxes_overlap(a1_atombox, bond_box)
1294
+ a2_flag= boxes_overlap(a2_atombox, bond_box)
1295
+ if a1_flag:
1296
+ atom1_idx_=a1_
1297
+ dist1=0
1298
+ elif a2_flag:
1299
+ atom1_idx_=a2_
1300
+ dist1=0
1301
+ else:
1302
+ distances = [np.linalg.norm(np.array(iso_atom_center) - v) for v in bond_vertices]
1303
+ closest_indices2 = np.argsort(distances)[:2] # 距离最小的两个顶点
1304
+ connected_vertices2 = bond_vertices[closest_indices2]#isolated_close
1305
+ connected_vertices1 = bond_vertices[[i for i in range(4) if i not in closest_indices2]]
1306
+ atom1_idx_, dist1 = find_nearest_atom(connected_vertices1, atom_bboxes, exclude_idx=[iso_atom])
1307
+ if debug:print("a1_flag,a2_flag,atom1_idx_, iso_atom",[a1_flag,a2_flag,atom1_idx_,iso_atom])
1308
+ min_ai=min([atom1_idx_,iso_atom])
1309
+ max_ai=max([atom1_idx_,iso_atom])
1310
+ k=(min_ai,max_ai)
1311
+ print(k,'repeate',bi)
1312
+ if k not in at2b_dist:
1313
+ at2b_dist[k]=[bi,a1a2,dist1]
1314
+ else:
1315
+ if dist1< at2b_dist[k][1]:
1316
+ at2b_dist[k]=[bi,a1a2,dist1]
1317
+ if debug:print(f"repate bond box id: {bi} fixed with {at2b_dist}")
1318
+ isolated_aFound.append(iso_atom)
1319
+ # for k,v in at2b_dist.items():
1320
+ #update bond atom box mapping
1321
+ isolated_a2del.append(iso_atom)
1322
+ b2aa[bi] = [min_ai,max_ai]
1323
+ a2b.setdefault(iso_atom, []).append(bi)
1324
+ bonds[bi][0]=k[0]
1325
+ bonds[bi][1]=k[1]
1326
+ if bi in bondWithdirct:
1327
+ bondWithdirct[bi][0]=k[0]
1328
+ bondWithdirct[bi][1]=k[1]
1329
+
1330
+ isolated_a=[ ai for ai in isolated_a if ai not in isolated_aFound]#updated
1331
+ singleAtomBond={bi:aili for bi,aili in singleAtomBond.items() if bi not in singleAtomBond_fixed}#updated
1332
+
1333
+ for iso_atom in isolated_a:
1334
+ iso_box = atom_bboxes[iso_atom]
1335
+ #with SingleAtomBond first then chec
1336
+ for other_idx, other_box in enumerate(atom_bboxes):
1337
+ if other_idx == iso_atom\
1338
+ or (atom_classes[other_idx] in ['other',"*"] and atom_classes[iso_atom] in ['other',"*"]):
1339
+ #also not inlcude other -- *
1340
+ continue
1341
+ # 检查重叠
1342
+ min_ai=min([iso_atom,other_idx])
1343
+ max_ai=max([iso_atom,other_idx])
1344
+
1345
+ if boxes_overlap(iso_box, other_box):
1346
+ # 添加默认单键
1347
+ new_bbox = [
1348
+ min(iso_box[0], other_box[0]),
1349
+ min(iso_box[1], other_box[1]),
1350
+ max(iso_box[2], other_box[2]),
1351
+ max(iso_box[3], other_box[3])
1352
+ ]
1353
+ bond_bbox.append(new_bbox)
1354
+ bond_classes.append('single')
1355
+ bond_scores.append(1.0)
1356
+ b2aa[new_bond_idx] = [iso_atom, other_idx]
1357
+ a2b.setdefault(iso_atom, []).append(new_bond_idx)
1358
+ a2b.setdefault(other_idx, []).append(new_bond_idx)
1359
+ isolated_a2del.append(iso_atom)
1360
+ new_bond_idx += 1
1361
+ bond_=[min_ai, max_ai, 'SINGLE', 1.0]
1362
+ last_=len(bonds)
1363
+ bonds[last_] = bond_
1364
+
1365
+ if debug:
1366
+ print(f"添加键 {new_bond_idx-1} 连接原子 {iso_atom} 和 {other_idx},as isoated box overlap with it ")
1367
+ # break
1368
+ else:
1369
+ # 检查角点最小距离
1370
+ min_dist = float('inf')
1371
+ closest_atom = None
1372
+ dist = min_corner_distance(iso_box, other_box)
1373
+ if dist < min_dist:
1374
+ min_dist = dist
1375
+ closest_atom = other_idx
1376
+ if min_dist < min_bond_size:
1377
+ # 添加默认单键
1378
+ new_bbox = [
1379
+ min(iso_box[0], atom_bboxes[closest_atom][0]),
1380
+ min(iso_box[1], atom_bboxes[closest_atom][1]),
1381
+ max(iso_box[2], atom_bboxes[closest_atom][2]),
1382
+ max(iso_box[3], atom_bboxes[closest_atom][3])
1383
+ ]
1384
+ bond_bbox.append(new_bbox)
1385
+ bond_classes.append('single')
1386
+ bond_scores.append(1.0)
1387
+ b2aa[new_bond_idx] = [iso_atom, closest_atom]
1388
+ a2b.setdefault(iso_atom, []).append(new_bond_idx)
1389
+ a2b.setdefault(closest_atom, []).append(new_bond_idx)
1390
+ isolated_a2del.append(iso_atom)
1391
+ new_bond_idx += 1
1392
+ if debug:
1393
+ print(f"添加键 {new_bond_idx-1} 连接原子 {iso_atom} 和 {closest_atom} (距离 {min_dist})")
1394
+ bond_=[min_ai, max_ai, 'SINGLE', 1.0]
1395
+ last_=len(bonds)
1396
+ bonds[last_] = bond_
1397
+
1398
+ # break#as isolated may be get more than 2 bonds
1399
+ if debug:
1400
+ print('isolated_a2del and isolated_a number',len(isolated_a2del),len(isolated_a))
1401
+ print('isolated_a ',isolated_a)
1402
+ print('isolated_a2del ',isolated_a2del)
1403
+
1404
+ a2b = dict(sorted(a2b.items()))
1405
+
1406
+ # 先处理 singleAtomBond, 再removed duplicated
1407
+ if len(singleAtomBond) > 0:
1408
+ # 初始化 a2neib
1409
+ a2neib = {}
1410
+ # 遍历 a2b,构建邻居关系
1411
+ for atom, bns in a2b.items():
1412
+ neighbors = set() # 使用集合去重
1413
+ for bond in bns:
1414
+ atom_pair = b2aa[bond] # 获取 bond 连接的原子对
1415
+ # 如果当前原子在 atom_pair 中,添加另一个原子作为邻居
1416
+ nei={ai for ai in atom_pair if ai !=atom }
1417
+ neighbors.update(nei)
1418
+ # if atom in atom_pair:
1419
+ # other_atom = atom_pair[0] if atom == atom_pair[1] else atom_pair[1]
1420
+ # neighbors.add(other_atom)
1421
+ a2neib[atom] = sorted(list(neighbors)) # 转换为有序列表
1422
+
1423
+ # 找到所有 C 的 bbox 尺寸
1424
+ c_bboxes = [bbox for bbox, cls in zip(output_a['bbox'], output_a['pred_classes']) if cls == 'C']
1425
+ if not c_bboxes:
1426
+ # 如果没有C原子,使用所有bbox中最小的
1427
+ print("Warning: No 'C' atoms found, using smallest bbox in output_a instead.")
1428
+ all_bboxes = output_a['bbox']
1429
+ if not all_bboxes:
1430
+ raise ValueError("No bboxes found in output_a at all.")
1431
+ smallest_bbox = min(all_bboxes, key=bbox_area)
1432
+ c_bboxes = [smallest_bbox] # 计算最小宽度和高度
1433
+ min_width = min([bbox[2] - bbox[0] for bbox in c_bboxes])
1434
+ min_height = min([bbox[3] - bbox[1] for bbox in c_bboxes])
1435
+
1436
+ # 处理 singleAtomBond
1437
+ for bi, atom_idx_list in singleAtomBond.items():
1438
+ bond_box = bond_bbox[bi]
1439
+ atom1_idx = atom_idx_list[0]
1440
+ bond_vertices = get_corners(bond_box)
1441
+ # 计算 atom1_center 到 bond box 4 个顶点的距离
1442
+ atom1_center = atom_centers[atom1_idx]
1443
+ distances = [np.linalg.norm(np.array(atom1_center) - v) for v in bond_vertices]
1444
+ closest_indices = np.argsort(distances)[:2] # 距离最小的两个顶点
1445
+ connected_vertices = bond_vertices[closest_indices]
1446
+ unconnected_vertices = bond_vertices[[i for i in range(4) if i not in closest_indices]]
1447
+ # exclude_=a2neib[atom1_idx]
1448
+ exclude_=[atom1_idx]#add it self
1449
+ print(f'exclude this atom itself :: {exclude_},and its neiboughs {a2neib[atom1_idx]}')
1450
+ # 找到 atom2(未连接端到所有 atom box 顶点的最小距离)
1451
+ # _, atom2_idx_ = get_min_distance_to_atom_box(unconnected_vertices, atom_bboxes, exclude_idx=exclude_)
1452
+ atom2_idx_, dist2 = find_nearest_atom(unconnected_vertices, atom_bboxes, exclude_idx=exclude_)
1453
+ # 从 atom1 找到最近的另一个 atom (atom2_1)
1454
+ atom1_corners = get_corners(atom_bboxes[atom1_idx])
1455
+ atom2_1_idx, dist2_1 = find_nearest_atom(atom1_corners, atom_bboxes, exclude_idx=exclude_)
1456
+ if debug:print("atom2_idx_ , atom2_1_idx,atom1_idx:",atom2_idx_, atom2_1_idx,atom1_idx)
1457
+ if atom2_idx_< atom1_idx:
1458
+ k=[atom2_idx_, atom1_idx]
1459
+ else:
1460
+ k=[atom1_idx, atom2_idx_]
1461
+
1462
+ if atom2_idx_ == atom2_1_idx :
1463
+ if atom2_idx_ not in a2neib[atom1_idx]:
1464
+ if debug: print('add new bond with existed atom')
1465
+ b2aa[bi]=k
1466
+ bonds[bi][0]=k[0]
1467
+ bonds[bi][1]=k[1]
1468
+ else:#need insert new atom box at this bond terminal site, default with C
1469
+ new_center=np.mean(unconnected_vertices, axis=0)
1470
+ # 生成新 C 的 bbox
1471
+ new_bbox = [
1472
+ new_center[0] - min_width / 2,
1473
+ new_center[1] - min_height / 2,
1474
+ new_center[0] + min_width / 2,
1475
+ new_center[1] + min_height / 2
1476
+ ]
1477
+ if debug: print('new atom box adding as C')
1478
+ atom_bboxes.append(new_bbox)
1479
+ atom_centers.append(new_center.tolist())
1480
+ atom_scores.append(bond_scores[bi]) # 使用 bond 的 score
1481
+ atom_classes.append('C')
1482
+ #updating
1483
+ atom2_idx_= len(atom_classes)-1
1484
+ k=[atom1_idx, atom2_idx_]
1485
+ bonds[bi][1]=atom2_idx_
1486
+ b2aa[bi][1]=atom2_idx_
1487
+ else: #atom2_idx_ != atom2_1_idx, keep atom2_idx_ from bond box privlage
1488
+ if atom2_idx_ not in a2neib[atom1_idx]:
1489
+ if debug: print(f'atom2_idx_ != atom2_1_idx| {atom2_idx_} != {atom2_1_idx} @add new bond with existed atom')
1490
+ b2aa[bi]=k
1491
+ bonds[bi][0]=k[0]
1492
+ bonds[bi][1]=k[1]
1493
+ else:#need insert new atom box at this bond terminal site, default with C
1494
+ new_center=np.mean(unconnected_vertices, axis=0)
1495
+ # 生成新 C 的 bbox
1496
+ new_bbox = [
1497
+ new_center[0] - min_width / 2,
1498
+ new_center[1] - min_height / 2,
1499
+ new_center[0] + min_width / 2,
1500
+ new_center[1] + min_height / 2
1501
+ ]
1502
+ atom_bboxes.append(new_bbox)#updateing atom box
1503
+ atom_centers.append(new_center.tolist())
1504
+ atom_scores.append(bond_scores[bi]) # 使用 bond 的 score
1505
+ atom_classes.append('C')
1506
+ #updating
1507
+ atom2_idx_= len(atom_classes)-1
1508
+ k=[atom1_idx, atom2_idx_]
1509
+ bonds[bi][1]=atom2_idx_
1510
+ b2aa[bi][1]=atom2_idx_
1511
+ if debug: print(f'atom2_idx_ != atom2_1_idx@new atom box {atom2_idx_}adding as C, with bond {bi} a1a2 {k}')
1512
+
1513
+ if bi in bondWithdirct.keys():
1514
+ bondWithdirct[bi][0]=k[0]
1515
+ bondWithdirct[bi][1]=k[1]#update atom2 index
1516
+
1517
+ #TODO, fix me, this case, may need ocr.ocr first, try to dicide need isolated atom added bond times
1518
+ if debug:print(f"before del bonds {len(bond_bbox)}")
1519
+ # viewcheck_b(image_path,bond_bbox,bond_classes,color='green',figsize=(10,7))
1520
+ # viewcheck(image_path,atom_bboxes,color='red')
1521
+ #update aa2b for remove duplicated bonds
1522
+ aa2b=dict()
1523
+ for bi, aa in b2aa.items():
1524
+ min_ai=min(aa)
1525
+ max_ai=max(aa)
1526
+ if bond_scores[bi] is None:
1527
+ bond_scores[bi]=1.0
1528
+ score_=bond_scores[bi]
1529
+ bond_type=bond_classes[bi]
1530
+ # print([bond_type,score_])
1531
+ #bond_type check afte singleAtomBond
1532
+ if bond_type in ['single','wdge','dash', '-', 'NONE', 'ENDUPRIGHT', 'BEGINWEDGE', 'BEGINDASH', 'ENDDOWNRIGHT']:
1533
+ bond_ = [min_ai, max_ai, 'SINGLE', score_]
1534
+ if bond_type in ['wdge','dash','ENDUPRIGHT', 'BEGINWEDGE', 'BEGINDASH', 'ENDDOWNRIGHT']:
1535
+ bondWithdirct[bi]=[min_ai, max_ai,'SINGLE', score_, bond_type]
1536
+ elif bond_type == '=':
1537
+ bond_ = [min_ai, max_ai, 'DOUBLE', score_]
1538
+ # print(bond_,"@@@@")
1539
+ elif bond_type == '#':
1540
+ bond_ = [min_ai, max_ai, 'TRIPLE', score_]
1541
+ elif bond_type == ':':
1542
+ bond_ = [min_ai, max_ai, 'AROMATIC', score_]
1543
+ else:
1544
+ print(f"what case here !!! with bond_type: {bond_type} || {[bi,min_ai, max_ai]}")
1545
+ bond_=[min_ai, max_ai, 'SINGLE', score_]
1546
+
1547
+ if (min_ai, max_ai) not in aa2b.keys() or aa2b[(min_ai, max_ai)][-2]<score_:
1548
+ aa2b[(min_ai, max_ai)]=[bi,score_,bond_[-2]]
1549
+ #SINGLE Atom bond 本来是不重复的, 会误认 repeated and remove TODO
1550
+
1551
+ #remove duplicated bonds based on score
1552
+ if len(aa2b)!=len(b2aa):
1553
+ # 1. 去重并生成新的 bi 映射
1554
+ new_bi_map = {} # 格式: {old_bi: new_bi}
1555
+ new_bonds = {}
1556
+ new_aa2b = {}
1557
+ new_b2aa = {}
1558
+ new_bondWithdirct = {}
1559
+ new_singleAtomBond = {}
1560
+ # 按 aa2b 的顺序分配新 bi(保留分数高的键)
1561
+ for new_bi, ((min_ai, max_ai), (old_bi, score, bond_type)) in enumerate(
1562
+ sorted(aa2b.items(), key=lambda x: x[1][1], reverse=True) # 按分数降序排序
1563
+ ):
1564
+ new_bi_map[old_bi] = new_bi
1565
+ new_bonds[new_bi] = [min_ai, max_ai, bond_type, score]
1566
+ new_aa2b[(min_ai, max_ai)] = [new_bi, score, bond_type]
1567
+ new_b2aa[new_bi] = [min_ai, max_ai]
1568
+
1569
+ # 2. 更新 bondWithdirct & singleAtomBond
1570
+ for old_bi, bond_info in bondWithdirct.items():
1571
+ if old_bi in new_bi_map:
1572
+ new_bi = new_bi_map[old_bi]
1573
+ new_bondWithdirct[new_bi] = bond_info
1574
+
1575
+ for old_bi, bond_info in singleAtomBond.items():
1576
+ if old_bi in new_bi_map:
1577
+ new_bi = new_bi_map[old_bi]
1578
+ new_singleAtomBond[new_bi] = bond_info
1579
+
1580
+ # 3. 替换旧数据结构, TODO ad bond box class scores here
1581
+ bonds = new_bonds
1582
+ aa2b = new_aa2b
1583
+ b2aa = new_b2aa
1584
+ bondWithdirct = new_bondWithdirct
1585
+ singleAtomBond = new_singleAtomBond
1586
+ if debug: print(f"去重完成: bonds={len(bonds)}, aa2b={len(aa2b)}, b2aa={len(b2aa)}, bondWithdirct={len(bondWithdirct)}")
1587
+ #remove duplicated bonds based on score
1588
+ # 4. 更新 bond_bbox, bond_scores, bond_classes
1589
+ old_bns=len(bond_scores)# iterate ALL original bond indices; max(new_bi_map.keys()) would miss removed duplicates above the max kept index
1590
+ to_remove_bonds=set()
1591
+ for i in range(old_bns):
1592
+ if i not in new_bi_map.keys():
1593
+ to_remove_bonds.add(i)
1594
+ print(to_remove_bonds)
1595
+ # 删除被移除的 bbox
1596
+ bond_scores = [bond_scores[i] for i in range(len(bond_scores)) if i not in to_remove_bonds]
1597
+ bond_classes = [bond_classes[i] for i in range(len(bond_classes)) if i not in to_remove_bonds]
1598
+ bond_bbox = [bond_bbox[i] for i in range(len(bond_bbox)) if i not in to_remove_bonds]
1599
+ bond_center = [[ (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2 ] for bbox in bond_bbox]
1600
+
1601
+
1602
+ a2b=dict()
1603
+ isolated_a=set()
1604
+ for k,v in b2aa.items():
1605
+ # a1,a2=v
1606
+ for a in v:
1607
+ if a not in a2b.keys():
1608
+ a2b[a]=[k]
1609
+ else:
1610
+ a2b[a].append(k)
1611
+
1612
+ for ai, a_lab in enumerate(atom_classes):
1613
+ if ai not in a2b.keys():
1614
+ isolated_a.add(ai)
1615
+ a2b = dict(sorted(a2b.items()))
1616
+
1617
+ # 初始化 a2neib
1618
+ a2neib = {}
1619
+ # 遍历 a2b,构建邻居关系
1620
+ for atom, bns in a2b.items():
1621
+ neighbors = set() # 使用集合去重
1622
+ for bond in bns:
1623
+ atom_pair = b2aa[bond] # 获取 bond 连接的原子对
1624
+ # 如果当前原子在 atom_pair 中,添加另一个原子作为邻居
1625
+ nei={ai for ai in atom_pair if ai !=atom }
1626
+ neighbors.update(nei)
1627
+ # if atom in atom_pair:
1628
+ # other_atom = atom_pair[0] if atom == atom_pair[1] else atom_pair[1]
1629
+ # neighbors.add(other_atom)
1630
+ a2neib[atom] = sorted(list(neighbors)) # 转换为有序列表
1631
+
1632
+ debug2=False
1633
+ if debug2:
1634
+ # 输出结果
1635
+ print("\nBonds:")
1636
+ for bi, bond_info in bonds.items():
1637
+ print(f"Bond {bi}: {bond_info}")
1638
+ print("\nSingle Atom Bonds:")
1639
+ for bi, atom_idx in singleAtomBond.items():
1640
+ print(f"Bond {bi}: {atom_idx}")
1641
+ print("Atom to Bonds box idx maping:")
1642
+ for ai, bond_ids in a2b.items():
1643
+ print(f"a2b-id {ai}: {bond_ids}")
1644
+ print(f"isolated_ atom box:: {isolated_a}")
1645
+ print(f"b2aa::{b2aa}")
1646
+ # 输出结果
1647
+ print("a2neib:")
1648
+ for atom, neighbors in a2neib.items():
1649
+ print(f"Atom {atom}: {neighbors}")
1650
+
1651
+ other2ppsocr = True
1652
+ ocr_ai2lab = dict()
1653
+ ocr_bbs = dict()
1654
+ scale_crop = False
1655
+ ocr_ai2lab_ori=dict()
1656
+ ocr_ai2lab_sca=dict()
1657
+
1658
+
1659
+ if other2ppsocr:
1660
+ elements = ['S', 'N', 'P', 'C', 'O']
1661
+ keys = [f"{e}{suffix}" for e in elements for suffix in ['R"', "R'", "R", "*"]]
1662
+ replacement_map = {key: f'{key[0]}*' for key in keys}
1663
+ if da=='staker':
1664
+ _margin=2#as staker use small image 256X256
1665
+ else:
1666
+ _margin=0
1667
+ for i, atc in enumerate(atom_classes):
1668
+ if 'other' == atc: # 30 idx_lab version OH-->Cl with high
1669
+ # Initialize variables to store both results
1670
+ orig_result = None
1671
+ orig_score = 0
1672
+ scaled_result = None
1673
+ scaled_score = 0
1674
+
1675
+ # Process original image crop
1676
+ abox_orig = np.array(atom_bboxes[i]) + np.array([-_margin, -_margin,_margin, _margin])
1677
+ cropped_img_orig = img_ori.crop(abox_orig)
1678
+ image_npocr_orig = np.array(cropped_img_orig)
1679
+ result_ocr_orig = ocr.ocr(image_npocr_orig, det=False)
1680
+
1681
+ if result_ocr_orig:
1682
+ orig_text = result_ocr_orig[0][0][0]
1683
+ orig_score = result_ocr_orig[0][0][1]
1684
+ if debug: print(f'oriCrop:\t {orig_text} {orig_score}')
1685
+ orig_text = normalize_ocr_text(orig_text, replacement_map)
1686
+ ocr_ai2lab_ori[i]=[orig_text,orig_score]
1687
+ # Process scaled image crop
1688
+ abox_scaled = np.array(atom_bboxes[i]) * np.array([scale_x, scale_y, scale_x, scale_y]) + np.array([-_margin, -_margin,_margin, _margin])
1689
+ cropped_img_scaled = img_ori_1k.crop(abox_scaled)
1690
+ image_npocr_scaled = np.array(cropped_img_scaled)
1691
+ result_ocr_scaled = ocr.ocr(image_npocr_scaled, det=False)
1692
+
1693
+ if result_ocr_scaled:
1694
+ scaled_text = result_ocr_scaled[0][0][0]
1695
+ scaled_score = result_ocr_scaled[0][0][1]
1696
+ if debug: print(f'scaled:\t {scaled_text} {scaled_score}')
1697
+ scaled_text = normalize_ocr_text(scaled_text, replacement_map)
1698
+ ocr_ai2lab_sca[i]=[scaled_text,scaled_score]
1699
+
1700
+
1701
+
1702
+ final_text, final_score, final_crop = select_chem_expression(
1703
+ orig_text, orig_score, scaled_text, scaled_score, cropped_img_orig, cropped_img_scaled
1704
+ )
1705
+
1706
+ if orig_text=='NO2' or scaled_text=='NO2':
1707
+ final_text='NO2'#AS stm NO score >NO2
1708
+ elif orig_text=='SO2' or scaled_text=='SO2':
1709
+ final_text='SO2'#AS stm NO score >SO2
1710
+ # elif orig_starts_upper == scaled_starts_upper:
1711
+ # # If both start with uppercase or both don't, use the higher score
1712
+ # final_text = orig_text if orig_score >= scaled_score else scaled_text
1713
+ # elif orig_starts_upper != scaled_starts_upper:
1714
+ # # If one starts with uppercase, use that one
1715
+ # final_text = orig_text if orig_starts_upper else scaled_text
1716
+
1717
+ if final_text:
1718
+ ocr_ai2lab[i] = [final_text, final_score]
1719
+ ocr_bbs[i] = final_crop
1720
+ atom_classes[i] = final_text
1721
+ if debug:
1722
+ print("ori",ocr_ai2lab_ori)
1723
+ print("sca",ocr_ai2lab_sca)
1724
+ print(ocr_ai2lab)
1725
+ #TODO make chem-group recongized dataBase next works !!!
1726
+
1727
+ if len(ocr_bbs)>0:
1728
+ if debug:print(f'numbs of ocr {len(ocr_bbs)} crop_ images')
1729
+ #merge the isolated_a Ph3Br into closet atom box
1730
+ # 3 in isolated_a, isolated_a, isolated_aFound
1731
+ giveup_isolateds=dict()
1732
+ if len(isolated_a):#after updated isolated_a still has the isolatd item
1733
+ for iso_atom in isolated_a:
1734
+ atom1_corners = get_corners(atom_bboxes[iso_atom])
1735
+ atom2_1_idx, dist2_1 = find_nearest_atom(atom1_corners, atom_bboxes, exclude_idx=[iso_atom])
1736
+ atom1_lab=atom_classes[iso_atom]
1737
+ atom2_lab=atom_classes[atom2_1_idx]
1738
+ if atom1_lab in ['Ph3Br','Ph3Br-']:
1739
+ if iso_atom not in giveup_isolateds.keys():
1740
+ giveup_isolateds[iso_atom]=[atom1_lab]
1741
+ else:
1742
+ giveup_isolateds[iso_atom].append(atom1_lab)
1743
+
1744
+ if atom2_lab in ['P','P+']:#merge as new group
1745
+ atom2_lab='P+Ph3Br-'
1746
+ elif atom2_lab in ['N','N+']:#merge as new group
1747
+ atom2_lab='N+Ph3Br-'
1748
+
1749
+ atom_classes[atom2_1_idx]=atom2_lab #update bonded atom label with the merged
1750
+
1751
+ #TODO add cases that need merge OCR results with bonded atom box
1752
+ if debug:
1753
+ print(f"giveup_isolateds {giveup_isolateds}")
1754
+ print(len(atom_classes),len(bond_classes),'<<<<<<<<<<<')#,len(charges_classes))
1755
+ ###########################start build mol ##########################
1756
+ rwmol_ = Chem.RWMol()
1757
+ boxi2ai = {} # 预测索引 -> RDKit 索引
1758
+ placeholder_atoms=dict()
1759
+ # print(len(atom_classes),len(bond_classes))#,len(charges_classes))
1760
+ #assign atom
1761
+ J=0
1762
+ for i, (bbox, a) in enumerate(zip(atom_bboxes, atom_classes)):
1763
+ a2labl=False
1764
+ a=replace_cg_notation(a)
1765
+ # print(a,'atom box class label')
1766
+ if a in ['H', 'C', 'O', 'N', 'Cl', 'Br', 'S', 'F', 'B', 'I', 'P', 'Si']:# '*', I2M's defined atom types
1767
+ # if a=='H':continue#skip H fristly,only with heavy atom then addH
1768
+ ad = Chem.Atom(a)#TODO consider non chemical group and label for using
1769
+ #TODO add pd rdkit known elemetns here
1770
+ elif a in ELEMENTS:
1771
+ ad = Chem.Atom(a)
1772
+
1773
+ elif a in ABBREVIATIONS :
1774
+ ad = Chem.Atom("*")
1775
+ placeholder_atoms[i] = a # 记录非标准原但有定义的官能团 类型及其位置,
1776
+ a2labl=True
1777
+ else:
1778
+ if N_C_H_expand(a):
1779
+ ad = Chem.Atom("*")
1780
+ placeholder_atoms[i] = a # 记录非标准原但有定义的官能团 类型及其位置,
1781
+ a2labl=True
1782
+ elif C_H_expand(a):
1783
+ ad = Chem.Atom("*")
1784
+ placeholder_atoms[i] = a # 记录非标准原但有定义的官能团 类型及其位置,
1785
+ a2labl=True
1786
+ elif C_H_expand2(a):
1787
+ ad = Chem.Atom("*")
1788
+ placeholder_atoms[i] = a # 记录非标准原但有定义的官能团 类型及其位置,
1789
+ a2labl=True
1790
+
1791
+ elif formula_regex(a):
1792
+ ad = Chem.Atom("*")
1793
+ placeholder_atoms[i] = a # 记录非标准原但有定义的官能团 类型及其位置,
1794
+ a2labl=True
1795
+ else:
1796
+ ad = Chem.Atom("*")
1797
+ if a not in ['*',"other"]:
1798
+ a2labl=True
1799
+ # placeholder_atoms[idx] = a
1800
+ # atom = Chem.Atom(symbol)
1801
+ rwmol_.AddAtom(ad)
1802
+ boxi2ai[J] = rwmol_.GetNumAtoms() - 1
1803
+ if a2labl: rwmol_.GetAtomWithIdx(J).SetProp("atomLabel", f"{a}")#mol set with label, mol_rebuild not
1804
+ J+=1
1805
+
1806
+
1807
+ charges_classes= output_c['pred_classes']
1808
+ charges_centers= output_c['bbox_centers']
1809
+ charges_scores= output_c['scores']
1810
+ charges_bbox= output_c['bbox']
1811
+ a2c=dict()
1812
+ c2a=dict()
1813
+
1814
+ # #assign charge
1815
+ if len(charges_classes) > 0:
1816
+ kdt = cKDTree(atom_centers)
1817
+ c2a = {} # 电荷索引到原子索引的映射
1818
+ used_atoms = set() # 跟踪已分配电荷的原子
1819
+ for i, charge_box in enumerate(charges_bbox):
1820
+ charge_value = parse_charge(charges_classes[i])
1821
+ overlapped_atoms = []
1822
+ # 检查重叠
1823
+ for ai, atom_box in enumerate(atom_bboxes):
1824
+ if boxes_overlap(charge_box, atom_box):
1825
+ overlapped_atoms.append(ai)
1826
+ if overlapped_atoms:
1827
+ # 如果有重叠,选择第一个未使用的原子(假设一个电荷只分配一个原子)
1828
+ for ai in overlapped_atoms:
1829
+ if ai not in used_atoms:
1830
+ c2a[i] = ai
1831
+ used_atoms.add(ai)
1832
+ break
1833
+ else:
1834
+ # 不重叠时,使用角点距离和 KDTree 验证
1835
+ x, y = charges_centers[i]
1836
+ dist_kdt, ai_kdt = kdt.query([x, y], k=1)
1837
+ # 计算角点距离最近的原子
1838
+ min_dist = float('inf')
1839
+ ai_corner = None
1840
+ for ai, atom_box in enumerate(atom_bboxes):
1841
+ dist = min_corner_distance(charge_box, atom_box)
1842
+ if dist < min_dist:
1843
+ min_dist = dist
1844
+ ai_corner = ai
1845
+ # 比较 KDTree 和角点距离结果
1846
+ if ai_kdt == ai_corner and ai_kdt not in used_atoms:
1847
+ c2a[i] = ai_kdt
1848
+ used_atoms.add(ai_kdt)
1849
+ else:
1850
+ # 检查电荷值和原子类型
1851
+ if charge_value != 0:
1852
+ symbol_kdt =atom_classes[ai_kdt]
1853
+ symbol_corner =atom_classes[ai_corner]
1854
+ # 如果电荷值不为零,分配给非C的原子,如果都是非C, 则根据kdt k=1来分配电荷
1855
+ if symbol_kdt == 'C' and symbol_corner != 'C' and ai_corner not in used_atoms:
1856
+ # KDTree 是碳,角点不是碳,优先分配给角点原子
1857
+ c2a[i] = ai_corner
1858
+ used_atoms.add(ai_corner)
1859
+ elif symbol_corner == 'C' and symbol_kdt != 'C' and ai_kdt not in used_atoms:
1860
+ # 角点是碳,KDTree 不是碳,优先分配给 KDTree 原子
1861
+ c2a[i] = ai_kdt
1862
+ used_atoms.add(ai_kdt)
1863
+ else:
1864
+ # 两个都是非碳,或两个都是碳,默认使用 KDTree 结果
1865
+ if ai_kdt not in used_atoms:
1866
+ c2a[i] = ai_kdt
1867
+ used_atoms.add(ai_kdt)
1868
+ elif ai_corner not in used_atoms:
1869
+ # 如果 KDTree 结果已使用,尝试角点结果
1870
+ c2a[i] = ai_corner
1871
+ used_atoms.add(ai_corner)
1872
+
1873
+ #assign charge
1874
+ a2c={v:k for k,v in c2a.items()}
1875
+ for k,v in a2c.items():
1876
+ fc=int(charges_classes[v])
1877
+ rwmol_.GetAtomWithIdx(k).SetFormalCharge(fc)
1878
+ # if k in placeholder_atoms:
1879
+ if atom_classes[k] in ['COO','CO2']:#TODO add more charge if need
1880
+ if fc==-1:
1881
+ atom_classes[k]=f"{atom_classes[k]}-"
1882
+ placeholder_atoms[k]=atom_classes[k]
1883
+ atom = rwmol_.GetAtomWithIdx(k)
1884
+ atom.SetProp("atomLabel",placeholder_atoms[k])
1885
+ elif fc==1:
1886
+ atom_classes[k]=f"{atom_classes[k]}+"
1887
+ placeholder_atoms[k]=atom_classes[k]
1888
+ atom = rwmol_.GetAtomWithIdx(k)
1889
+ atom.SetProp("atomLabel",placeholder_atoms[k])
1890
+ else:
1891
+ print(f"charge adding {fc} @ {atom_classes[v]}")
1892
+ print(f'placeholder_atoms {placeholder_atoms}')
1893
+ #add bonds
1894
+ for bi, bond in bonds.items():
1895
+ atom1_idx, atom2_idx, bond_type, score = bond
1896
+ if atom1_idx ==atom2_idx:print(f"self bond should be avoid or del on previous process!!")
1897
+ # print(f"Adding bond between atoms {atom1_idx} and {atom2_idx} of type {bond_type}")
1898
+ if bond_type == 'SINGLE':
1899
+ rwmol_.AddBond(atom1_idx, atom2_idx, Chem.BondType.SINGLE)
1900
+ elif bond_type == 'DOUBLE':
1901
+ rwmol_.AddBond(atom1_idx, atom2_idx, Chem.BondType.DOUBLE)
1902
+ elif bond_type == 'TRIPLE':
1903
+ rwmol_.AddBond(atom1_idx, atom2_idx, Chem.BondType.TRIPLE)
1904
+ elif bond_type == 'AROMATIC':
1905
+ rwmol_.AddBond(atom1_idx, atom2_idx, Chem.BondType.AROMATIC)
1906
+ else:
1907
+ print(f"Unknown bond type: {bond_type}")
1908
+
1909
+ if debug: print(f"all a2b b2a a2c c2a done, start mol built done")
1910
+ #set direction
1911
+ if len(bondWithdirct)>0:
1912
+ print(f"set bond direction for mollecule ")
1913
+ # rwmol_=set_bondDriection(rwmol_,bondWithdirct)
1914
+
1915
+ skeleton_smi = Chem.MolToSmiles(rwmol_) #TODO WEB_dev, use this rwmol_ for display without expand the R groups
1916
+ #ASSIGN COORDS
1917
+ coords = [(x,-y,0) for x,y in atom_centers]
1918
+ coords = tuple(coords)
1919
+ coords = tuple(tuple(num / 100 for num in sub_tuple) for sub_tuple in coords)
1920
+
1921
+ mol2D = rwmol_.GetMol()
1922
+ mol2D.RemoveAllConformers()
1923
+ conf = Chem.Conformer(mol2D.GetNumAtoms())
1924
+ conf.Set3D(True)
1925
+ for i, (x, y, z) in enumerate(coords):
1926
+ conf.SetAtomPosition(i, (x, y, z))
1927
+ mol2D.AddConformer(conf)
1928
+ try:
1929
+ Chem.SanitizeMol(mol2D)
1930
+ Chem.AssignStereochemistryFrom3D(mol2D)
1931
+ mol_rebuit2d=Chem.RWMol(mol2D)
1932
+ except Exception as e:
1933
+ print(e)
1934
+ print('before expanding!!! try to sanizemol and assign stereo')
1935
+ mol_rebuit2d=Chem.RWMol(rwmol_)
1936
+ if len(giveup_isolateds)>0:
1937
+ #clean with remove giveup_isolateds
1938
+ # 1. 先为每个原子设置一个“old_index”属性
1939
+ for atom in mol_rebuit2d.GetAtoms():
1940
+ atom.SetProp('old_index', str(atom.GetIdx()))
1941
+
1942
+ # 2. 删除原子时建议按照降序删除,避免索引变化带来的问题
1943
+ for ai in sorted(giveup_isolateds.keys(), reverse=True):
1944
+ mol_rebuit2d.RemoveAtom(ai)
1945
+ print(f"atom {ai} label {giveup_isolateds[ai]} removed")
1946
+
1947
+ # 3. 删除操作完成后,构建老索引到新索引的映射
1948
+ old_to_new = {}
1949
+ for atom in mol_rebuit2d.GetAtoms():
1950
+ old_idx = int(atom.GetProp('old_index'))
1951
+ new_idx = atom.GetIdx()
1952
+ old_to_new[old_idx] = new_idx
1953
+
1954
+ if len(placeholder_atoms)>0:#update placeholder_atoms
1955
+ placeholder_atoms2=dict()
1956
+ for k,v in placeholder_atoms.items():
1957
+ placeholder_atoms2[old_to_new[k]]=v
1958
+
1959
+ placeholder_atoms=placeholder_atoms2
1960
+ try:
1961
+ SMILESpre = Chem.MolToSmiles(mol_rebuit2d)
1962
+ except Exception as e:
1963
+ print(f"Error during SMILES generation: {e}")
1964
+ SMILESpre = Chem.MolToSmiles(mol_rebuit2d, canonical=False)
1965
+
1966
+
1967
+ if len(placeholder_atoms)>0:
1968
+ mol_expan=copy.deepcopy(mol_rebuit2d)
1969
+ if debug: print(f'MOL will be expanded with {placeholder_atoms} !!')
1970
+ wdbs=[]
1971
+ bond_dirs_rev={v:k for k,v in bond_dirs.items()}
1972
+
1973
+ for b in mol_expan.GetBonds():
1974
+ bd=b.GetBondDir()
1975
+ bt=b.GetBondType()
1976
+ # print(bd)
1977
+ if bd ==bond_dirs['BEGINDASH'] or bd==bond_dirs['BEGINWEDGE']:
1978
+ a1, a2 = b.GetBeginAtomIdx(), b.GetEndAtomIdx()
1979
+ wdbs.append([a1,a2,bt,bond_dirs_rev[bd]])
1980
+
1981
+ expandStero_smi1,molexp= molExpanding(mol_expan,placeholder_atoms,wdbs,bond_dirs)#TODO fix me whe n multi strings on a atom will missing this ocr infors
1982
+ molexp=remove_bond_directions_if_no_chiral(molexp)
1983
+ try:
1984
+ Chem.SanitizeMol(molexp)
1985
+ expandStero_smi=Chem.MolToSmiles(molexp)
1986
+ except Exception as e:
1987
+ print(f"Error during sanitization: {e}")
1988
+ expandStero_smi = expandStero_smi1
1989
+
1990
+ expandStero_smi=remove_SP(expandStero_smi)
1991
+
1992
+ else:
1993
+ molexp=mol_rebuit2d
1994
+ expandStero_smi=SMILESpre #save into csv files,
1995
+
1996
+ #TODO WEB_dev, now can display mol with expanded abbev from molexp
1997
+ new_row = {'file_name':image_path, "SMILESori":SMILESori,
1998
+ 'SMILESpre':SMILESpre,
1999
+ 'SMILESexp':expandStero_smi,
2000
+ }
2001
+
2002
+ # smiles_data = smiles_data._append(new_row, ignore_index=True)#TODO WEB_dev task done here, we can save predicted Rdkit Obj or smiles or display on web
2003
+ print(f"final prediction:\n {expandStero_smi}")
2004
+
2005
+ return expandStero_smi
2006
+
2007
+ main()
2008
+
2009
+ # 安全释放资源
2010
+ # def release_ocr(ocr_instance):
2011
+ # # 关闭所有相关模型
2012
+ # if hasattr(ocr_instance, 'detector'):
2013
+ # ocr_instance.detector = None
2014
+ # if hasattr(ocr_instance, 'recognizer'):
2015
+ # ocr_instance.recognizer = None
2016
+ # if hasattr(ocr_instance, 'cls'):
2017
+ # ocr_instance.cls = None
2018
+
2019
+ # # 调用释放函数
2020
+ # release_ocr(ocr)
2021
+ # del ocr
2022
+ # release_ocr(ocr2)
2023
+ # del ocr2
2024
+
2025
+
app.py CHANGED
@@ -1,7 +1,63 @@
1
- from fastapi import FastAPI
2
-
3
- app = FastAPI()
4
-
5
- @app.get("/")
6
- def greet_json():
7
- return {"Hello": "World!"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, File, UploadFile
2
+ from fastapi.responses import JSONResponse
3
+ import gradio as gr
4
+ import os
5
+ from ONNX0630 import main as predict_smiles
6
+ from PIL import Image
7
+ import io
8
+
9
+ # Initialize FastAPI app
10
+ app = FastAPI(title="Chemical Structure to SMILES API")
11
+
12
+ # API endpoint to predict SMILES from an image
13
+ @app.post("/predict")
14
+ async def predict(file: UploadFile = File(...)):
15
+ try:
16
+ # Read and save the uploaded image
17
+ contents = await file.read()
18
+ image = Image.open(io.BytesIO(contents))
19
+ temp_path = f"temp_{os.path.basename(file.filename or 'upload.png')}"  # sanitize untrusted upload filename (path traversal)
20
+ image.save(temp_path)
21
+
22
+ # Call the model function
23
+ smiles = predict_smiles(temp_path)
24
+
25
+ # Clean up temporary file
26
+ os.remove(temp_path)
27
+
28
+ return JSONResponse(content={"smiles": smiles})
29
+ except Exception as e:
30
+ return JSONResponse(content={"error": str(e)}, status_code=500)
31
+
32
+ # Gradio interface
33
+ def gradio_predict(image):
34
+ try:
35
+ # Save the uploaded image
36
+ temp_path = "temp_image.png"
37
+ image.save(temp_path)
38
+
39
+ # Call the model function
40
+ smiles = predict_smiles(temp_path)
41
+
42
+ # Clean up
43
+ os.remove(temp_path)
44
+
45
+ return smiles
46
+ except Exception as e:
47
+ return f"Error: {str(e)}"
48
+
49
+ # Define Gradio interface
50
+ iface = gr.Interface(
51
+ fn=gradio_predict,
52
+ inputs=gr.Image(type="pil"),
53
+ outputs=gr.Textbox(),
54
+ title="Chemical Structure to SMILES Converter",
55
+ description="Upload an image of a chemical structure to get its SMILES string."
56
+ )
57
+
58
+ # Launch Gradio with FastAPI
59
+ app = gr.mount_gradio_app(app, iface, path="/")
60
+
61
+ if __name__ == "__main__":
62
+ import uvicorn
63
+ uvicorn.run(app, host="0.0.0.0", port=7860)
det_engine.py ADDED
The diff for this file is too large to render. See raw diff
 
utils.py ADDED
@@ -0,0 +1,712 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import json
3
+ import math
4
+ import numpy as np
5
+ import pandas as pd
6
+ import torch
7
+ from scipy.spatial import cKDTree
8
+ from rdkit import Chem
9
+ from rdkit.Chem import RWMol
10
+ from rdkit.Chem import Draw, AllChem
11
+ from rdkit.Chem import rdDepictor
12
+ import matplotlib.pyplot as plt
13
+ import re
14
+ ##################### MolScribe####################################################################################
15
+ from typing import List
16
+ import matplotlib.pyplot as plt
17
+ from matplotlib.patches import Rectangle, Circle
18
+
19
+
20
+ COLORS = {
21
+ u'c': '0.0,0.75,0.75', u'b': '0.0,0.0,1.0', u'g': '0.0,0.5,0.0', u'y': '0.75,0.75,0',
22
+ u'k': '0.0,0.0,0.0', u'r': '1.0,0.0,0.0', u'm': '0.75,0,0.75'
23
+ }
24
+
25
+ #helper function
26
+ def view_box_center(bond_bbox,heavy_centers):
27
+ fig, ax = plt.subplots(figsize=(10, 10))
28
+ # 绘制矩形框 (boxes)
29
+ for box in bond_bbox:
30
+ x1, y1, x2, y2 = box
31
+ width = x2 - x1
32
+ height = y2 - y1
33
+ rect = Rectangle((x1, y1), width, height, linewidth=1, edgecolor='blue', facecolor='none')
34
+ ax.add_patch(rect)
35
+
36
+ # 绘制圆形 (centers)
37
+ for center in heavy_centers:
38
+ x, y = center
39
+ circle = Circle((x, y), radius=5, edgecolor='red', facecolor='none', linewidth=1)
40
+ ax.add_patch(circle)
41
+
42
+ # 设置坐标轴范围(根据数据自动调整)
43
+ x_min = min(bond_bbox[:, 0].min(), heavy_centers[:, 0].min()) - 10
44
+ x_max = max(bond_bbox[:, 2].max(), heavy_centers[:, 0].max()) + 10
45
+ y_min = min(bond_bbox[:, 1].min(), heavy_centers[:, 1].min()) - 10
46
+ y_max = max(bond_bbox[:, 3].max(), heavy_centers[:, 1].max()) + 10
47
+ ax.set_xlim(x_min, x_max)
48
+ ax.set_ylim(y_min, y_max)
49
+
50
+ # 设置标题和标签
51
+ ax.set_title("Boxes and Centers")
52
+ ax.set_xlabel("X")
53
+ ax.set_ylabel("Y")
54
+ # 显示图像
55
+ plt.gca().set_aspect('equal', adjustable='box') # 保持比例
56
+ plt.grid(True, linestyle='--', alpha=0.7)
57
+
58
+ def molIDX(mol):
59
+ for i, atom in enumerate(mol.GetAtoms()):
60
+ atom.SetAtomMapNum(i) #映射
61
+ # print(i)
62
+ return mol
63
+
64
+ def molIDX_del(mol):
65
+ for i, atom in enumerate(mol.GetAtoms()):
66
+ atom.SetAtomMapNum(0) #映射
67
+ print(i)
68
+ return mol
69
+ from det_engine import ABBREVIATIONS
70
+
71
+
72
+
73
+ def Val_extract_atom_info(error_message):
74
+ """
75
+ 从错误信息中提取 atomid, atomType 和 valence。
76
+ :param error_message: 错误信息字符串
77
+ :return: (atomid, atomType, valence) 元组
78
+ """
79
+ # 定义正则表达式来提取原子信息
80
+ pattern = r"Explicit valence for atom # (\d+) (\w), (\d+)"
81
+ pattern2 =r"Explicit valence for atom # (\d+) (\w) "
82
+ # print(type(error_message))
83
+ if not isinstance(error_message, type('strs')):
84
+ error_message=str(error_message)
85
+ match = re.search(pattern, error_message)
86
+ match2 = re.search(pattern2, error_message)
87
+ if match:
88
+ # 提取 atomid, atomType 和 valence
89
+ atomid = int(match.group(1)) # 原子索引
90
+ atomType = match.group(2) # 原子类型
91
+ valence = int(match.group(3)) # 当前价态
92
+ return atomid, atomType, valence
93
+ elif match2:
94
+ atomid = int(match2.group(1)) # 原子索引
95
+ atomType = match2.group(2) # 原子类型
96
+ # valence = int(match2.group(3)) # 当前价态
97
+ return atomid, atomType, None
98
+
99
+ else:
100
+ raise ValueError("无法从错误信息中提取原子信息")
101
+
102
+ def calculate_charge_adjustment(atom_symbol, current_valence):
103
+ """
104
+ 计算需要调整的电荷,根据反馈的原子符号和当前价态。
105
+ :param atom_symbol: 原子符号(如 "C")
106
+ :param current_valence: 当前价态(如 5)
107
+ :return: 需要添加的电荷数(正数表示负电荷,负数表示正电荷)
108
+ """
109
+ if atom_symbol not in VALENCES:
110
+ raise ValueError(f"未知的原子符号: {atom_symbol}")
111
+
112
+ # 查找该元素的最大价态
113
+ max_valence = max(VALENCES[atom_symbol])
114
+ if current_valence is None:
115
+ current_valence=max_valence
116
+ # 如果当前价态大于最大允许价态,需要调整电荷
117
+ if current_valence > max_valence:
118
+ # 需要添加的负电荷数
119
+ charge_adjustment = current_valence - max_valence
120
+ return charge_adjustment
121
+ else:
122
+ # 当前价态已经符合最大允许价态,不需要调整
123
+ return 0
124
+
125
+ from rdkit.Chem import rdchem, RWMol, CombineMols
126
+
127
+ def expandABB(mol,ABBREVIATIONS, placeholder_atoms):
128
+ mols = [mol]
129
+ # **第三步: 替换 * 并合并官能团**
130
+ # 逆序遍历 placeholder_atoms,确保删除后不会影响后续索引
131
+ for idx in sorted(placeholder_atoms.keys(), reverse=True):
132
+ group = placeholder_atoms[idx] # 获取官能团名称
133
+ # print(idx, group)
134
+ submol = Chem.MolFromSmiles(ABBREVIATIONS[group].smiles) # 获取官能团的子���子
135
+ submol_rw = RWMol(submol) # 让 submol 变成可编辑的 RWMol
136
+ anchor_atom_idx = 0 # 选择 `submol` 的第一个原子作为连接点 as defined in ABBREVIATIONS
137
+ # **1. 复制主分子**
138
+ new_mol = RWMol(mol)
139
+ # **2. 计算 `*` 在 `new_mol` 中的索引**
140
+ placeholder_idx = idx
141
+ # **3. 记录 `*` 原子的邻居**
142
+ neighbors = [nb.GetIdx() for nb in new_mol.GetAtomWithIdx(placeholder_idx).GetNeighbors()]
143
+ # **4. 断开 `*` 的所有键**
144
+ bonds_to_remove = [] # 记录要断开的键
145
+ for bond in new_mol.GetBonds():
146
+ if bond.GetBeginAtomIdx() == placeholder_idx or bond.GetEndAtomIdx() == placeholder_idx:
147
+ bonds_to_remove.append((bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()))
148
+ for bond in bonds_to_remove:
149
+ new_mol.RemoveBond(bond[0], bond[1])
150
+ # **5. 删除 `*` 原子**
151
+ new_mol.RemoveAtom(placeholder_idx)
152
+ # **6. 重新计算 `neighbors`(删除后索引变化)**
153
+ new_neighbors = []
154
+ for neighbor in neighbors:
155
+ if neighbor < placeholder_idx:
156
+ new_neighbors.append(neighbor)
157
+ else:
158
+ new_neighbors.append(neighbor - 1) # 因为删除了一个原子,所有索引 -1
159
+ # **7. 合并 `submol`**
160
+ new_mol = RWMol(CombineMols(new_mol, submol_rw))
161
+
162
+ # **8. 计算 `submol` 的第一个原子在合并后的位置**
163
+ new_anchor_idx = new_mol.GetNumAtoms() - len(submol_rw.GetAtoms()) + anchor_atom_idx
164
+
165
+ # **9. 重新连接官能团**
166
+ for neighbor in new_neighbors:
167
+ # print(neighbor, new_anchor_idx, "!!")
168
+ new_mol.AddBond(neighbor, new_anchor_idx, Chem.BondType.SINGLE)
169
+ a1=new_mol.GetAtomWithIdx(neighbor)
170
+ a2=new_mol.GetAtomWithIdx(new_anchor_idx)
171
+ a1.SetNumRadicalElectrons(0)
172
+ a2.SetNumRadicalElectrons(0)## 将自由基电子数设为 0,as has added new bond
173
+ # **10. 更新主分子**
174
+ mol = new_mol
175
+ mols.append(mol)
176
+ # # 遍历分子中的每个原子
177
+ # for atom in mols[-1].GetAtoms(): NOTE considering original image has the RadicalElectrons
178
+ # atom_idx = atom.GetIdx() # 原子索引
179
+ # radical_electrons = atom.GetNumRadicalElectrons() # 自由基电子数
180
+ # if radical_electrons > 0:
181
+ # # print(f"原子 {atom_idx} 存在自由基,自由基电子数: {radical_electrons}\n current NumExplicitHs: {atom.GetNumExplicitHs()}")
182
+ # # 消除自由基:通过添加氢原子调整价态
183
+ # atom.SetNumRadicalElectrons(0) # 将自由基电子数设为 0,as has added bond
184
+ # # atom.SetNumExplicitHs(atom.GetNumExplicitHs() + radical_electrons)
185
+ Chem.SanitizeMol(mols[-1])
186
+ # 输出修改后的分子 SMILES
187
+ modified_smiles = Chem.MolToSmiles(mols[-1])
188
+ # print(f"修改后的分子 SMILES: {modified_smiles}")
189
+ return mols[-1], modified_smiles
190
+
191
+ ################################################################################################################################################################
192
+ def output_to_smiles(output,idx_to_labels,bond_labels,result):#this will output * without abbre version
193
+ #only output smiles with *
194
+ x_center = (output["boxes"][:, 0] + output["boxes"][:, 2]) / 2
195
+ y_center = (output["boxes"][:, 1] + output["boxes"][:, 3]) / 2
196
+
197
+ center_coords = torch.stack((x_center, y_center), dim=1)
198
+
199
+ output = {'bbox': output["boxes"].to("cpu").numpy(),
200
+ 'bbox_centers': center_coords.to("cpu").numpy(),
201
+ 'scores': output["scores"].to("cpu").numpy(),
202
+ 'pred_classes': output["labels"].to("cpu").numpy()}
203
+
204
+
205
+ atoms_list, bonds_list,charge = bbox_to_graph_with_charge(output,
206
+ idx_to_labels=idx_to_labels,
207
+ bond_labels=bond_labels,
208
+ result=result)
209
+ smiles, mol= mol_from_graph_with_chiral(atoms_list, bonds_list,charge)
210
+ abc=[atoms_list, bonds_list,charge ]
211
+
212
+ if isinstance(smiles, type(None)):
213
+ print(f"get atoms_list problems")
214
+ # smiles, mol=None,None
215
+ elif isinstance(atoms_list,type(None)):
216
+ print(f"get atoms_list problems")
217
+ # smiles, mol=None,None
218
+ # else:
219
+ # smiles, mol=smiles_mol
220
+ return abc,smiles,mol,output
221
+
222
+
223
def output_to_smiles2(output, idx_to_labels, bond_labels, result):
    """Convert raw detector output into SMILES, keeping '*' placeholders.

    Unlike the abbreviation-expanding variant, this version leaves '*'
    (unrecognized-group) atoms as-is in the produced SMILES.

    Parameters
    ----------
    output : dict with torch tensors "boxes" (N,4), "scores" (N,), "labels" (N,).
    idx_to_labels : class-id -> label mapping passed through to the graph builder.
    bond_labels : class ids that denote bonds.
    result : OCR results forwarded to ``bbox_to_graph_with_charge``.

    Returns
    -------
    ([atoms_list, bonds_list, charge], smiles, mol, detections_dict)
    """
    boxes = output["boxes"]
    # Box centres: midpoint of (x1, x2) and (y1, y2).
    centers = torch.stack(((boxes[:, 0] + boxes[:, 2]) / 2,
                           (boxes[:, 1] + boxes[:, 3]) / 2), dim=1)

    detections = {'bbox': boxes.to("cpu").numpy(),
                  'bbox_centers': centers.to("cpu").numpy(),
                  'scores': output["scores"].to("cpu").numpy(),
                  'pred_classes': output["labels"].to("cpu").numpy()}

    atoms_list, bonds_list, charge = bbox_to_graph_with_charge(detections,
                                                               idx_to_labels=idx_to_labels,
                                                               bond_labels=bond_labels,
                                                               result=result)
    smiles, mol = mol_from_graph_with_chiral(atoms_list, bonds_list, charge)
    graph_parts = [atoms_list, bonds_list, charge]

    if smiles is None:
        print(f"get atoms_list problems")
    elif atoms_list is None:
        print(f"get atoms_list problems")
    return graph_parts, smiles, mol, detections
251
+
252
+
253
+
254
def bbox_to_graph(output, idx_to_labels, bond_labels,result):
    """Turn detections (boxes/classes/scores) into an atom table and bond list.

    Parameters
    ----------
    output : dict of numpy arrays: 'bbox' (N,4), 'bbox_centers' (N,2),
        'pred_classes' (N,), 'scores' (N,).
    idx_to_labels : mapping from class id to label string (atom symbols and
        bond-type names).
    bond_labels : collection of class ids that denote bonds; every other
        class is treated as an atom.
    result : unused in the live code path (only referenced by the disabled
        'other'-replacement block below).

    Returns
    -------
    (atoms_list, bonds_list) : pandas DataFrame with columns atom/x/y, and a
        list of (begin_idx, end_idx, bond_label, bond_label, score) tuples.
        NOTE(review): begin/end indices are positional rows of the KD-tree
        built from atoms_list at bond time — confirm downstream consumers
        expect positions, not DataFrame index labels (rows may have been
        dropped by the dedup pass).
    """

    # calculate atoms mask (pred classes that are atoms/bonds)
    atoms_mask = np.array([True if ins not in bond_labels else False for ins in output['pred_classes']])

    # get atom list
    atoms_list = [idx_to_labels[a] for a in output['pred_classes'][atoms_mask]]

    # (disabled) round-robin substitution of 'other' placeholders with OCR results:
    # if len(result) !=0 and 'other' in atoms_list:
    #     new_list = []
    #     replace_index = 0
    #     for item in atoms_list:
    #         if item == 'other':
    #             new_list.append(result[replace_index % len(result)])
    #             replace_index += 1
    #         else:
    #             new_list.append(item)
    #     atoms_list = new_list

    atoms_list = pd.DataFrame({'atom': atoms_list,
                               'x': output['bbox_centers'][atoms_mask, 0],
                               'y': output['bbox_centers'][atoms_mask, 1]})

    # in case atoms with sign gets detected two times, keep only the signed one
    for idx, row in atoms_list.iterrows():
        if row.atom[-1] != '0':
            # NOTE(review): assumes every label is >= 2 chars; a single-char
            # label whose last char isn't '0' would raise IndexError here.
            if row.atom[-2] != '-':#assume charge value -9~9
                overlapping = atoms_list[atoms_list.atom.str.startswith(row.atom[:-1])]
            else:
                overlapping = atoms_list[atoms_list.atom.str.startswith(row.atom[:-2])]

            kdt = cKDTree(overlapping[['x', 'y']])
            dists, neighbours = kdt.query([row.x, row.y], k=2)
            # a same-prefix atom closer than 7px is treated as a duplicate
            # detection and dropped (mutates atoms_list mid-iteration).
            if dists[1] < 7:
                atoms_list.drop(overlapping.index[neighbours[1]], axis=0, inplace=True)

    bonds_list = []

    # get bonds
    for bbox, bond_type, score in zip(output['bbox'][np.logical_not(atoms_mask)],
                                      output['pred_classes'][np.logical_not(atoms_mask)],
                                      output['scores'][np.logical_not(atoms_mask)]):

        # single-like bond boxes get a tighter corner margin than multi-bond boxes
        # if idx_to_labels[bond_type] == 'SINGLE':
        if idx_to_labels[bond_type] in ['-','SINGLE', 'NONE', 'ENDUPRIGHT', 'BEGINWEDGE', 'BEGINDASH', 'ENDDOWNRIGHT']:
            _margin = 5
        else:
            _margin = 8

        # anchor positions are _margin distances away from the corners of the bbox.
        anchor_positions = (bbox + [_margin, _margin, -_margin, -_margin]).reshape([2, -1])
        oposite_anchor_positions = anchor_positions.copy()
        oposite_anchor_positions[:, 1] = oposite_anchor_positions[:, 1][::-1]

        # Upper left, lower right, lower left, upper right
        # 0 - 1, 2 - 3
        anchor_positions = np.concatenate([anchor_positions, oposite_anchor_positions])

        # get the closest point to every corner
        atoms_pos = atoms_list[['x', 'y']].values
        kdt = cKDTree(atoms_pos)
        dists, neighbours = kdt.query(anchor_positions, k=1)

        # check corner with the smallest total distance to closest atoms
        # (i.e. pick whichever bbox diagonal the bond actually lies along)
        if np.argmin((dists[0] + dists[1], dists[2] + dists[3])) == 0:
            # visualize setup
            begin_idx, end_idx = neighbours[:2]
        else:
            # visualize setup
            begin_idx, end_idx = neighbours[2:]

        #NOTE this proces may lead self-bonding for one atom
        if begin_idx != end_idx:# avoid self-bond
            bonds_list.append((begin_idx, end_idx, idx_to_labels[bond_type], idx_to_labels[bond_type], score))
        else:
            continue
    # return atoms_list.atom.values.tolist(), bonds_list
    return atoms_list, bonds_list
332
+
333
+
334
def calculate_distance(coord1, coord2):
    """Return the Euclidean distance between two (x, y) coordinate pairs.

    Uses math.hypot, which is the idiomatic and numerically safer form of
    sqrt(dx**2 + dy**2) (avoids overflow/underflow of the squared terms).
    """
    return math.hypot(coord1[0] - coord2[0], coord1[1] - coord2[1])
337
+
338
def assemble_atoms_with_charges(atom_list, charge_list):
    """Merge detected charge tokens into their nearest atom labels.

    Each row of ``charge_list`` is snapped onto the closest atom centre
    (KD-tree lookup over atom x/y) and its charge string is appended to that
    atom's element part (any previous trailing charge digits are stripped
    first via a letters-only regex).

    Parameters
    ----------
    atom_list : DataFrame with columns 'atom', 'x', 'y' (reindexed 0..n-1 here).
    charge_list : DataFrame with columns 'charge', 'x', 'y'.

    Returns the updated ``atom_list``.
    """
    atom_list = atom_list.reset_index(drop=True)
    kdt = cKDTree(atom_list[['x', 'y']])

    for rec in charge_list.itertuples():
        charge_ = rec.charge
        dist, idx_atom = kdt.query([rec.x, rec.y], k=1)

        # Guard: query result must land inside the (re-indexed) atom table.
        if idx_atom not in atom_list.index:
            print(f"Warning: idx_atom {idx_atom} is out of range for atom_list.")
            continue

        atom_str = atom_list.iloc[idx_atom]['atom']
        if atom_str == '*':
            atom_ = atom_str + charge_
        else:
            try:
                # Keep only leading element letters / '*', drop old charge suffix.
                atom_ = re.findall(r'[A-Za-z*]+', atom_str)[0] + charge_
            except Exception as e:
                print(atom_str, charge_, charge_list)
                print(f"@assemble_atoms_with_charges\n {e}\n{atom_list}")
                atom_ = atom_str + charge_

        atom_list.loc[idx_atom, 'atom'] = atom_

    return atom_list
366
+
367
+
368
+
369
def assemble_atoms_with_charges2(atom_list, charge_list, max_distance=10):
    """Greedily attach each atom's nearest unused charge within ``max_distance``.

    For every atom, scan the not-yet-consumed charges, pick the closest one
    within ``max_distance`` pixels and append its token to the atom label
    ('1' is rendered as '+'); atoms with no nearby charge get a trailing '0'.
    Each charge is consumed at most once.

    Parameters
    ----------
    atom_list : DataFrame with columns 'atom', 'x', 'y' (mutated and returned).
    charge_list : DataFrame with columns 'charge', 'x', 'y'.
    max_distance : float — maximum pairing distance.
        NOTE(review): probably should scale with image size — confirm.

    Returns the mutated ``atom_list``.
    """
    used_charge_indices = set()

    for idx, atom in atom_list.iterrows():
        atom_coord = atom['x'], atom['y']
        closest_charge = None
        closest_idx = None
        min_distance = float('inf')

        for i, charge in charge_list.iterrows():
            if i in used_charge_indices:
                continue

            # Inline Euclidean distance (hypot is overflow-safe).
            distance = math.hypot(atom_coord[0] - charge['x'],
                                  atom_coord[1] - charge['y'])
            if distance <= max_distance and distance < min_distance:
                closest_charge = charge
                closest_idx = i
                min_distance = distance

        if closest_charge is not None:
            if closest_charge['charge'] == '1':
                charge_ = '+'
            else:
                charge_ = closest_charge['charge']
            atom_list.loc[idx, 'atom'] = atom['atom'] + charge_
            # BUG FIX: the loop above checks *indices*, but the original code
            # added tuple(charge) (the row's values) to the set, so charges
            # were never marked as used and could be assigned repeatedly.
            used_charge_indices.add(closest_idx)
        else:
            atom_list.loc[idx, 'atom'] = atom['atom'] + '0'

    return atom_list
408
+
409
+
410
+
411
def bbox_to_graph_with_charge(output, idx_to_labels, bond_labels,result):
    """Charge-aware detection-to-graph conversion: (atoms, bonds, charges).

    Splits predicted classes three ways — bonds (``bond_labels``), charges
    (module-level ``charge_labels``), atoms (everything else) — snaps each
    detected charge onto its nearest atom, de-duplicates signed atoms, then
    assigns every bond box to its two closest atoms via corner anchors.

    Parameters
    ----------
    output : dict of numpy arrays: 'bbox' (N,4), 'bbox_centers' (N,2),
        'pred_classes' (N,), 'scores' (N,).
    idx_to_labels : class-id -> label mapping; its length (23/24/30) selects
        the bond anchor-margin scheme below.
    bond_labels : class ids that denote bonds.
    result : unused in the live code path.

    Returns
    -------
    (atoms_list, bonds_list, charge_list) — DataFrame, list of
    (begin_idx, end_idx, label, label, score), DataFrame; or (None, None, None)
    when no atoms are detected.
    """
    bond_labels_pre=bond_labels
    # NOTE(review): relies on a module-level `charge_labels` (e.g. the
    # commented [18,19,20,21,22]) — confirm it is defined at import time.
    # charge_labels = [18,19,20,21,22]#make influence
    atoms_mask = np.array([True if ins not in bond_labels and ins not in charge_labels else False for ins in output['pred_classes']])

    try:
        # print(atoms_mask.shape)
        # print(output['pred_classes'].shape)
        atoms_list = [idx_to_labels[a] for a in output['pred_classes'][atoms_mask]]
        # Empty-detection early exit (atoms_list is a plain list here, so this
        # Series check only fires if an upstream change makes it a Series).
        if isinstance(atoms_list, pd.Series) and atoms_list.empty:
            return None, None, None
        else:
            atoms_list = pd.DataFrame({'atom': atoms_list,
                                       'x': output['bbox_centers'][atoms_mask, 0],
                                       'y': output['bbox_centers'][atoms_mask, 1],
                                       'bbox': output['bbox'][atoms_mask].tolist() ,#need this for */other converting
                                       'scores': output['scores'][atoms_mask].tolist(),
                                       })
    except Exception as e:
        print(output['pred_classes'][atoms_mask].dtype,output['pred_classes'][atoms_mask])#int64 [ 1 1 1 1 1 2 1 29]
        print(e)
        print(idx_to_labels)
        # print(output['pred_classes'][atoms_mask],"output['pred_classes'][atoms_mask]")


    # (disabled) IOU-based conflict resolution between overlapping atom boxes:
    # should one of two overlapping atom boxes be deleted?
    # TODO gmy version most box larger then normal mix the rules
    # confict_atompaire=[]
    # for i in range(len(atoms_list)):
    #     for j in range(i + 1, len(atoms_list)):
    #         iou_value = calculate_iou(atoms_list.bbox[i], atoms_list.bbox[j])
    #         if iou_value !=0:
    #             # print(f"IOU between box {i} and box {j}: {iou_value}")
    #             if i !=j : confict_atompaire.append([i,j])
    # if len(confict_atompaire)>0:
    #     need_del=[]
    #     for i,j in confict_atompaire:
    #         ij_lab=[atoms_list.loc[i].atom,atoms_list.loc[j].atom ]
    #         ij_score=[atoms_list.loc[i].scores,atoms_list.loc[j].scores]
    #         # print(ij_lab,ij_score)
    #         if ij_lab==['C','N'] or ij_lab==['N','C']:
    #             if atoms_list.loc[i].atom =='C':
    #                 need_del.append(i)
    #             else:
    #                 need_del.append(j)
    #         elif atoms_list.loc[i].scores> atoms_list.loc[j].scores:
    #             need_del.append(j)
    #         elif atoms_list.loc[j].scores> atoms_list.loc[i].scores:
    #             need_del.append(i)
    #     print(need_del)
    #     atoms_list= atoms_list.drop(need_del)

    charge_mask = np.array([True if ins in charge_labels else False for ins in output['pred_classes']])
    charge_list = [idx_to_labels[a] for a in output['pred_classes'][charge_mask]]
    charge_list = pd.DataFrame({'charge': charge_list,
                                'x': output['bbox_centers'][charge_mask, 0],
                                'y': output['bbox_centers'][charge_mask, 1],
                                'scores': output['scores'][charge_mask],

                                })

    # print(charge_list,'\n@bbox_to_graph_with_charge')
    try:
        # Every atom starts neutral: append the '0' charge sentinel.
        atoms_list['atom'] = atoms_list['atom']+'0'#add 0
    except Exception as e:
        print(e)
        print(atoms_list['atom'],'atoms_list["atom"] @@ adding 0 ')


    if len(charge_list) > 0:
        atoms_list = assemble_atoms_with_charges(atoms_list,charge_list)
    # else:#Note Most mols are not formal charged
    #     atoms_list['atom'] = atoms_list['atom']+'0'
    # print(atoms_list,"after @@assemble_atoms_with_charges ")

    # in case atoms with sign gets detected two times, keep only the signed one
    for idx, row in atoms_list.iterrows():
        if row.atom[-1] != '0':
            try:
                if row.atom[-2] != '-':#assume charge value -9~9
                    overlapping = atoms_list[atoms_list.atom.str.startswith(row.atom[:-1])]
            except Exception as e:
                print(row.atom,"@rin case atoms with sign gets detected two times")
                print(e)
            else:
                # NOTE(review): this `else` binds to the try/except, so on the
                # no-exception path it OVERWRITES `overlapping` with the
                # [:-2]-prefix match regardless of the sign test above — it
                # looks like it was meant to be the `if`'s else. Also, if the
                # except path fires, `overlapping` may be undefined at the
                # cKDTree call below. Confirm intended behavior.
                overlapping = atoms_list[atoms_list.atom.str.startswith(row.atom[:-2])]

            kdt = cKDTree(overlapping[['x', 'y']])
            dists, neighbours = kdt.query([row.x, row.y], k=2)
            # same-prefix atom closer than 7px -> duplicate detection, drop it
            if dists[1] < 7:
                atoms_list.drop(overlapping.index[neighbours[1]], axis=0, inplace=True)

    bonds_list = []
    # get bonds
    # bond_mask=np.logical_not(np.logical_not(atoms_mask) | np.logical_not(charge_mask))
    bond_mask=np.logical_not(atoms_mask) & np.logical_not(charge_mask)
    for bbox, bond_type, score in zip(output['bbox'][bond_mask], #NOTE also including the charge part
                                      output['pred_classes'][bond_mask],
                                      output['scores'][bond_mask]):

        # if idx_to_labels[bond_type] == 'SINGLE':
        # Margin scheme depends on label-set size; NOTE(review): `_margin` is
        # never initialized for other sizes, which would raise NameError.
        if len(idx_to_labels)==23:
            if idx_to_labels[bond_type] in ['-','SINGLE', 'NONE', 'ENDUPRIGHT', 'BEGINWEDGE', 'BEGINDASH', 'ENDDOWNRIGHT']:
                _margin = 5
            else:
                _margin = 8
        elif len(idx_to_labels)==30:
            _margin=0#ad this version bond dynamicaly changed
        elif len(idx_to_labels)==24:
            _margin=0#ad this version bond dynamicaly changed
        # anchor positions are _margin distances away from the corners of the bbox.
        anchor_positions = (bbox + [_margin, _margin, -_margin, -_margin]).reshape([2, -1])
        oposite_anchor_positions = anchor_positions.copy()
        oposite_anchor_positions[:, 1] = oposite_anchor_positions[:, 1][::-1]

        # Upper left, lower right, lower left, upper right
        # 0 - 1, 2 - 3
        anchor_positions = np.concatenate([anchor_positions, oposite_anchor_positions])

        # get the closest point to every corner
        atoms_pos = atoms_list[['x', 'y']].values
        kdt = cKDTree(atoms_pos)
        dists, neighbours = kdt.query(anchor_positions, k=1)

        # check corner with the smallest total distance to closest atoms
        if np.argmin((dists[0] + dists[1], dists[2] + dists[3])) == 0:
            # visualize setup
            begin_idx, end_idx = neighbours[:2]
        else:
            # visualize setup
            begin_idx, end_idx = neighbours[2:]

        #NOTE this proces may lead self-bonding for one atom
        if begin_idx != end_idx:
            if bond_type in bond_labels:# avoid self-bond
                bonds_list.append((begin_idx, end_idx, idx_to_labels[bond_type], idx_to_labels[bond_type], score))
            else:
                print(f'this box may be charges box not bonds {[bbox, bond_type, score ]}')
        else:
            continue
    # return atoms_list.atom.values.tolist(), bonds_list
    # print(f"@box2graph: atom,bond nums:: {len(atoms_list)}, {len(bonds_list)}")
    return atoms_list, bonds_list,charge_list#dataframe, list
554
+
555
def parse_atom(node):
    """Split an atom token such as 'C0', 'N+', 'O-', 'other0' into (symbol, charge).

    The trailing character encodes the formal charge: a digit, '+', or '-'.
    Tokens containing 'other' map to the wildcard symbol '*'.
    Returns a (symbol, formal_charge) tuple.
    """
    digits = [str(d) for d in range(10)]

    if 'other' in node:
        # Wildcard atom; charge still read off the token's tail.
        if '-' in node or '+' in node:
            return '*', (-1 if node[-1] == '-' else 1)
        # NOTE: node[-2:] is two characters, so this only yields a nonzero
        # charge when the token is short enough for the slice to be one digit.
        return '*', (int(node[-2:]) if node[-2:] in digits else 0)

    if node[-1] in digits:
        if '-' in node or '+' in node:
            # Sign embedded earlier in the token; last char is a digit here,
            # so this resolves to +1.
            return node[:-1], (-1 if node[-1] == '-' else 1)
        return node[:-1], int(node[-1])

    if node[-1] == '+':
        return node[:-1], 1
    if node[-1] == '-':
        return node[:-1], -1

    # No recognizable charge suffix: neutral atom.
    return node, 0
581
+
582
+ #from engine
583
+
584
def iou_(box1, box2):
    """Intersection-over-Union of two boxes.

    Parameters
    ----------
    box1, box2 : [x1, y1, x2, y2] corner coordinates.

    Returns
    -------
    float : IoU in [0, 1]; 0 when the boxes are disjoint or degenerate.
    """
    ix1, iy1 = max(box1[0], box2[0]), max(box1[1], box2[1])
    ix2, iy2 = min(box1[2], box2[2]), min(box1[3], box2[3])

    inter = max(0, ix2 - ix1) * max(0, iy2 - iy1)
    area_a = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area_b = (box2[2] - box2[0]) * (box2[3] - box2[1])
    denom = area_a + area_b - inter
    return inter / denom if denom > 0 else 0
603
+
604
+
605
def calculate_iou(bbox1, bbox2):
    """IoU of two [x_min, y_min, x_max, y_max] boxes plus a relationship tag.

    Returns
    -------
    (iou, result, inter_area, union_area) where ``result`` is a list of
    descriptive labels: no overlap / overlap, and for overlapping boxes
    whether they coincide or one contains the other.
    """
    ax1, ay1, ax2, ay2 = bbox1
    bx1, by1, bx2, by2 = bbox2

    # Intersection rectangle (clamped to zero width/height when disjoint).
    inter_w = max(0, min(ax2, bx2) - max(ax1, bx1))
    inter_h = max(0, min(ay2, by2) - max(ay1, by1))
    inter_area = inter_w * inter_h

    area1 = (ax2 - ax1) * (ay2 - ay1)
    area2 = (bx2 - bx1) * (by2 - by1)
    union_area = area1 + area2 - inter_area

    iou = inter_area / union_area if union_area > 0 else 0

    # Relationship classification (original labels kept verbatim).
    result = []
    if iou == 0:
        result.append("无重叠")
    elif iou > 0:
        result.append("有重叠")
        if iou == 1:
            result.append("完全重合")
        elif inter_area == area2:
            result.append("bbox1 包含 bbox2")
        elif inter_area == area1:
            result.append("bbox2 包含 bbox1")

    return iou, result, inter_area, union_area
645
+
646
def adjust_bbox1(large_bbox, small_bbox, bond_bbox):
    """Placeholder for shrinking ``large_bbox`` around an atom/bond pair.

    A candidate tightened box is computed but deliberately discarded; the
    function currently returns ``large_bbox`` unchanged (no-op). The unpacking
    still validates that each argument is a 4-element box.
    """
    x_min_l, y_min_l, x_max_l, y_max_l = large_bbox
    x_min_s, y_min_s, x_max_s, y_max_s = small_bbox
    x_min_b, y_min_b, x_max_b, y_max_b = bond_bbox

    # Candidate shrunken box — computed but intentionally unused for now.
    _candidate = (max([x_min_l, x_min_s, x_min_b]),
                  max([y_min_l, y_min_s, y_min_b]),
                  x_max_l, y_max_l)
    return large_bbox
658
+
659
+
660
def _nms_iou(box1, box2):
    """Scalar IoU of two [x1, y1, x2, y2] boxes (helper for nms_per_class)."""
    x1 = max(box1[0], box2[0])
    y1 = max(box1[1], box2[1])
    x2 = min(box1[2], box2[2])
    y2 = min(box1[3], box2[3])
    inter = max(0, x2 - x1) * max(0, y2 - y1)
    union = ((box1[2] - box1[0]) * (box1[3] - box1[1])
             + (box2[2] - box2[0]) * (box2[3] - box2[1]) - inter)
    return inter / union if union > 0 else 0


def nms_per_class(labels, boxes, scores, iou_thresh=0.5):
    """Apply non-maximum suppression independently within each class.

    Parameters
    ----------
    labels : np.ndarray of class ids, shape (N,)
    boxes : np.ndarray of [x1, y1, x2, y2] coordinates, shape (N, 4)
    scores : np.ndarray of confidences, shape (N,)
    iou_thresh : float — a box overlapping an already-kept box of the same
        class above this IoU is suppressed.

    Returns
    -------
    dict with the surviving 'labels', 'boxes', 'scores' arrays.
    """
    kept_indices = []
    for label in np.unique(labels):
        # Select this class's detections.
        class_mask = labels == label
        class_indices = np.where(class_mask)[0]
        class_boxes = boxes[class_mask]
        class_scores = scores[class_mask]

        # Sort by descending score so the best box is always kept first.
        order = np.argsort(class_scores)[::-1]
        class_boxes = class_boxes[order]
        class_scores = class_scores[order]
        class_indices = class_indices[order]

        while len(class_scores) > 0:
            kept_indices.append(class_indices[0])
            if len(class_scores) == 1:
                break

            # BUG FIX: the original called calculate_iou(...), which returns a
            # (iou, tags, inter, union) tuple containing a list — comparing
            # that against iou_thresh is invalid. Use a scalar IoU helper.
            ious = np.array([_nms_iou(class_boxes[0], box) for box in class_boxes[1:]])
            keep_mask = ious < iou_thresh
            class_boxes = class_boxes[1:][keep_mask]
            class_scores = class_scores[1:][keep_mask]
            class_indices = class_indices[1:][keep_mask]

    # dtype=int so an empty result still indexes cleanly (np.array([]) is float).
    kept_indices = np.array(kept_indices, dtype=int)
    return {
        'labels': labels[kept_indices],
        'boxes': boxes[kept_indices],
        'scores': scores[kept_indices]
    }
712
+