Upload 5 files
Browse files- I2M_R4.onnx +3 -0
- ONNX0630.py +2025 -0
- app.py +63 -7
- det_engine.py +0 -0
- utils.py +712 -0
I2M_R4.onnx
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3b9cc7af809d1b91d467400f416f800d3908cd5ec733d32b7cefe906b9b71122
|
| 3 |
+
size 212933527
|
ONNX0630.py
ADDED
|
@@ -0,0 +1,2025 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
# coding: utf-8
|
| 3 |
+
import os,sys
|
| 4 |
+
import argparse
|
| 5 |
+
|
| 6 |
+
model_usedpath='/nfs_home/bowen/works/pys/codes/i2m'
|
| 7 |
+
sys.path.append(model_usedpath)
|
| 8 |
+
home="/nfs_home/bowen/works/pys/codes/i2m"
|
| 9 |
+
bmd=f'/nfs_home/bowen/works/pys/codes/i2m/output0602/checkpoint0070.pth'#
|
| 10 |
+
parser = argparse.ArgumentParser()
|
| 11 |
+
parser.add_argument('--config', '-c', type=str, default=f'{home}/configs/rtdetr/rtdetr_r50vd_6x_coco.yml')
|
| 12 |
+
parser.add_argument('--resume', '-r', type=str, default=f'{bmd}')
|
| 13 |
+
parser.add_argument('--tuning', '-t', type=str,)# default='/nfs_home/bowen/model_checkpoint/rtdetr_r50vd_2x_coco_objects365_from_paddle.pth')
|
| 14 |
+
parser.add_argument('--test-only',default=True,)
|
| 15 |
+
parser.add_argument('--amp', default=False,)
|
| 16 |
+
parser.add_argument('--dataname', '-da', type=str, default=None)
|
| 17 |
+
parser.add_argument('--gpuid', '-gi', type=str, default=None)
|
| 18 |
+
parser.add_argument('--number', '-n', type=str, default=None)
|
| 19 |
+
args, unknown = parser.parse_known_args()#in jupyter
|
| 20 |
+
print(args)
|
| 21 |
+
if args.gpuid:
|
| 22 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = f'{args.gpuid}'
|
| 23 |
+
else:
|
| 24 |
+
os.environ['CUDA_VISIBLE_DEVICES'] = '4'
|
| 25 |
+
|
| 26 |
+
parralel_n=2
|
| 27 |
+
os.environ["OMP_NUM_THREADS"] = f"{parralel_n}" # OpenMP
|
| 28 |
+
os.environ["OPENBLAS_NUM_THREADS"] = f"{parralel_n}" # OpenBLAS
|
| 29 |
+
os.environ["MKL_NUM_THREADS"] = f"{parralel_n}" # Intel MKL
|
| 30 |
+
os.environ["VECLIB_MAXIMUM_THREADS"] = f"{parralel_n}" # macOS Accelerate
|
| 31 |
+
os.environ["NUMEXPR_NUM_THREADS"] = f"{parralel_n}" # NumExpr
|
| 32 |
+
"""
|
| 33 |
+
WARNING: OMP_NUM_THREADS set to 4, not 1. The computation speed will not be optimized if you use data parallel. It will fail if this PaddlePaddle binary is compiled with OpenBlas since OpenBlas does not support multi-threads.
|
| 34 |
+
PLEASE USE OMP_NUM_THREADS WISELY.
|
| 35 |
+
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
import shutil
|
| 40 |
+
import pandas as pd
|
| 41 |
+
# print(sys.path)
|
| 42 |
+
print(__file__)
|
| 43 |
+
cur_dir = os.path.dirname(os.path.abspath(__file__))
|
| 44 |
+
print(cur_dir)
|
| 45 |
+
python_path=cur_dir
|
| 46 |
+
|
| 47 |
+
sys.path.append(python_path)
|
| 48 |
+
# model_path='I2M_realv2.onnx'
|
| 49 |
+
# model_abs_path = os.path.abspath(model_path)
|
| 50 |
+
# if os.path.exists(model_abs_path):
|
| 51 |
+
# print(model_abs_path)
|
| 52 |
+
|
| 53 |
+
# from src.solver.det_engine import *
|
| 54 |
+
import cv2
|
| 55 |
+
|
| 56 |
+
import sys,copy
|
| 57 |
+
import torchvision
|
| 58 |
+
|
| 59 |
+
import torch
|
| 60 |
+
import tqdm
|
| 61 |
+
import matplotlib.pyplot as plt
|
| 62 |
+
from matplotlib.patches import Rectangle, Circle
|
| 63 |
+
from det_engine import N_C_H_expand, C_H_expand,C_H_expand2, C_F_expand, formula_regex, RTDETRPostProcessor
|
| 64 |
+
from det_engine import SmilesEvaluator, molfpsim
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
import rdkit
|
| 68 |
+
from rdkit import Chem
|
| 69 |
+
from rdkit.Chem import Draw, AllChem
|
| 70 |
+
from rdkit import DataStructs
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
print("CUDA available:", torch.cuda.is_available())
|
| 75 |
+
print("Number of GPUs:", torch.cuda.device_count())
|
| 76 |
+
# print("Current device:", torch.cuda.current_device())
|
| 77 |
+
print("Device name:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "No GPU")
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
# In[ ]:
|
| 81 |
+
|
| 82 |
+
# 计算bbox面积并找到最小的
|
| 83 |
+
def bbox_area(bbox):
    """Return the area of an axis-aligned box given as (x1, y1, x2, y2)."""
    x_min, y_min, x_max, y_max = bbox
    width = x_max - x_min
    height = y_max - y_min
    return width * height
|
| 86 |
+
|
| 87 |
+
def mol_idx( mol ):
    """Stamp every atom of *mol* with its own index as the 'molAtomMapNumber' property.

    Mutates *mol* in place (so RDKit depictions show atom indices) and
    returns the same molecule object for chaining.
    """
    atoms = mol.GetNumAtoms()
    for idx in range( atoms ):
        # map number := the atom's current index
        mol.GetAtomWithIdx( idx ).SetProp( 'molAtomMapNumber', str( mol.GetAtomWithIdx( idx ).GetIdx() ) )
    return mol
|
| 92 |
+
|
| 93 |
+
# 移除原子索引
|
| 94 |
+
# Remove atom index labels (inverse of mol_idx)
def mol_idx_del(mol):
    """Clear the 'molAtomMapNumber' property from every atom that has it.

    Mutates *mol* in place and returns it.
    """
    atoms = mol.GetNumAtoms()
    for idx in range(atoms):
        atom = mol.GetAtomWithIdx(idx)
        if atom.HasProp('molAtomMapNumber'):  # check the property exists first
            atom.ClearProp('molAtomMapNumber')  # then clear it
    return mol
|
| 101 |
+
|
| 102 |
+
def is_contained_in(bbox_small, bbox_large):
    """Return True when *bbox_small* lies entirely inside *bbox_large* (edges inclusive)."""
    sx1, sy1, sx2, sy2 = bbox_small
    lx1, ly1, lx2, ly2 = bbox_large
    if sx1 < lx1 or sx2 > lx2:
        return False
    return sy1 >= ly1 and sy2 <= ly2
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def NoRadical_Smi(smi):
    """Neutralize radical electrons in a SMILES string by converting them to implicit/explicit Hs.

    Parses *smi*, zeroes the radical-electron count on each radical atom and
    tops up explicit hydrogens to restore the atom's total valence, then
    returns the rewritten SMILES.

    NOTE(review): assumes Chem.MolFromSmiles succeeds — a None return for an
    unparsable SMILES would raise AttributeError here; confirm callers only
    pass valid SMILES.
    """
    aa=Chem.MolFromSmiles(smi)
    for atom in aa.GetAtoms():
        if atom.GetNumRadicalElectrons() > 0:  # atom carries a free radical
            # Remove the radical by adding hydrogens instead:
            atom.SetNumRadicalElectrons(0)  # set the free-electron count to 0
            # Adjust explicit hydrogens to fill the valence gap left behind
            atom.SetNumExplicitHs(atom.GetTotalValence() - atom.GetExplicitValence())
    san_before=Chem.MolToSmiles(aa)
    return san_before
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def select_longest_smiles(smiles):
    """Split a SMILES string on '.' and return the fragment with the most characters.

    Ties keep the earliest fragment, matching max()'s first-wins behavior.
    """
    best = ''
    for fragment in smiles.split('.'):
        if len(fragment) > len(best):
            best = fragment
    return best
|
| 129 |
+
|
| 130 |
+
# 解析电荷值
|
| 131 |
+
# Parse a textual charge annotation into a signed integer
def parse_charge(charge_str):
    """Convert charge text like '2-', '+', '1+', '-3' into a signed int.

    Trailing '+'/'-' styles ('1+', '-') and plain signed integers ('-3', '2')
    are both accepted; a bare sign means magnitude 1.

    Raises ValueError for text that is not a recognizable charge.
    """
    if charge_str.endswith('+'):
        magnitude = charge_str[:-1]
        return int(magnitude) if magnitude else 1   # '1+' -> 1, '+' -> 1
    if charge_str.endswith('-'):
        magnitude = charge_str[:-1]
        return -int(magnitude) if magnitude else -1  # '2-' -> -2, '-' -> -1
    return int(charge_str)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def set_bondDriection(rwmol_,bondWithdirct):
    """Apply wedge/dash bond directions from detection output onto an RDKit RWMol.

    Parameters:
        rwmol_: editable RDKit molecule whose bonds get their BondDir set.
        bondWithdirct: mapping bond-id -> (atom1_idx, atom2_idx, bond_type,
            score, w_d) where w_d is 'wdge' or 'dash'.

    Returns the (mutated) rwmol_.

    NOTE(review): if w_d is neither 'wdge' nor 'dash', bond_dir_/reverse_dir
    are never assigned and the code below would raise UnboundLocalError (or
    reuse the previous iteration's values) — confirm upstream guarantees
    w_d ∈ {'wdge', 'dash'}.
    """
    #set direction
    chiral_center_ids = Chem.FindMolChiralCenters(rwmol_, includeUnassigned=True)
    # chiral_center_ids
    chirai_ai2sterolab=dict()
    if len(chiral_center_ids)>0:
        # atom index -> stereo label ('R'/'S'/'?')
        chirai_ai2sterolab={ai:stero_lab for ai, stero_lab in chiral_center_ids }

    for bi, binfo in bondWithdirct.items():
        atom1_idx, atom2_idx, bond_type, score, w_d = binfo
        bt= rwmol_.GetBondBetweenAtoms(atom1_idx, atom2_idx)# RDKit bonds are undirected; this returns the single shared Bond object
        current_begin = bt.GetBeginAtomIdx()
        current_end = bt.GetEndAtomIdx()
        if w_d=='wdge':
            bond_dir_=rdchem.BondDir.BEGINWEDGE
            reverse_dir = rdchem.BondDir.BEGINDASH
        elif w_d=='dash':
            bond_dir_=rdchem.BondDir.BEGINDASH
            reverse_dir = rdchem.BondDir.BEGINWEDGE

        if atom1_idx in chirai_ai2sterolab.keys():
            if current_begin == atom1_idx:
                bt.SetBondDir(bond_dir_)
                print(f'atom1_idx dir')
            else:
                # The chiral atom is the bond's end atom: flip to the opposite wedge/dash
                bt.SetBondDir(reverse_dir)
                print(f'atom1_idx reverse_dir')
        elif atom2_idx in chirai_ai2sterolab.keys():
            if current_begin == atom2_idx:
                bt.SetBondDir(bond_dir_)
                print(f'atom2_idx dir {bond_dir_} {reverse_dir}')
            else:
                # The chiral atom is the end atom; simply reversing the dir does not
                # work here, so remove the bond and re-add it with swapped endpoints
                rwmol_.RemoveBond(current_begin, current_end)
                rwmol_.AddBond(current_end, current_begin, bt.GetBondType())
                bond = rwmol_.GetBondBetweenAtoms(current_end, current_begin)
                bond.SetBondDir(bond_dir_)
                print(f'atom2_idx reverse_dir {bond_dir_} {reverse_dir}')
        else:
            # Neither endpoint is a chiral center: drop this stereo annotation
            print('bond stro not with chiral atom???, will ignore this stero bond infors')
            print(f"{[bi, bond_dir_, current_begin,current_end]}")
    return rwmol_
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
# In[786]:
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
atom_labels = [0,1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
|
| 191 |
+
bond_labels = [13,14,15,16,17,18]
|
| 192 |
+
charge_labels=[19,20,21,22,23]
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
idx_to_labels={0:'other',1:'C',2:'O',3:'N',4:'Cl',5:'Br',6:'S',7:'F',8:'B',
|
| 196 |
+
9:'I',10:'P',11:'H',12:'Si',
|
| 197 |
+
#bond
|
| 198 |
+
13:'single',14:'wdge',15:'dash',
|
| 199 |
+
16:'=',17:'#',18:':',#aromatic
|
| 200 |
+
#charge
|
| 201 |
+
19:'-4',20:'-2',
|
| 202 |
+
21:'-1',#-
|
| 203 |
+
22:'+1',#+
|
| 204 |
+
23:'+2',
|
| 205 |
+
}
|
| 206 |
+
lab2idx={ v:k for k,v in idx_to_labels.items()}
|
| 207 |
+
bond_labels_symb={idx_to_labels[i] for i in bond_labels}
|
| 208 |
+
|
| 209 |
+
bond_dirs = {'NONE': Chem.rdchem.BondDir.NONE,
|
| 210 |
+
'ENDUPRIGHT': Chem.rdchem.BondDir.ENDUPRIGHT,
|
| 211 |
+
'BEGINWEDGE': Chem.rdchem.BondDir.BEGINWEDGE,
|
| 212 |
+
'BEGINDASH': Chem.rdchem.BondDir.BEGINDASH,
|
| 213 |
+
'ENDDOWNRIGHT': Chem.rdchem.BondDir.ENDDOWNRIGHT,
|
| 214 |
+
}
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
import pandas as pd
|
| 218 |
+
from typing import Iterable, List
|
| 219 |
+
from PIL import Image
|
| 220 |
+
import json,re
|
| 221 |
+
|
| 222 |
+
#TODO now abc single bond and OCR checking
|
| 223 |
+
#OCR 得到纯数字box 离原子距离应该小于最小的bond 距离,否则丢弃
|
| 224 |
+
from utils import calculate_iou,adjust_bbox1
|
| 225 |
+
from scipy.spatial import cKDTree, KDTree
|
| 226 |
+
import numpy as np
|
| 227 |
+
from rdkit import Chem
|
| 228 |
+
from paddleocr import PaddleOCR
|
| 229 |
+
from rdkit.Chem import rdchem, RWMol, CombineMols
|
| 230 |
+
|
| 231 |
+
from det_engine import ABBREVIATIONS,remove_SP
|
| 232 |
+
from det_engine import molExpanding,remove_bond_directions_if_no_chiral
|
| 233 |
+
from det_engine import (comparing_smiles,comparing_smiles2, remove_SP, expandABB,
|
| 234 |
+
ELEMENTS,
|
| 235 |
+
ABBREVIATIONS)
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
from det_engine import expandABB
|
| 240 |
+
|
| 241 |
+
def bbox2shapes(bboxes, classes, lab2idx):
    """Convert parallel bbox/class lists into labelme-style rectangle shape dicts.

    Labels not present in *lab2idx* are replaced with 'other'. Every shape
    gets a fixed score of 1.0 and four corner points in TL, TR, BR, BL order.
    """
    shapes = []
    for (x1, y1, x2, y2), label in zip(bboxes, classes):
        if label not in lab2idx:
            label = 'other'
        corner_points = [
            [x1, y1],  # top-left
            [x2, y1],  # top-right
            [x2, y2],  # bottom-right
            [x1, y2],  # bottom-left
        ]
        shapes.append({
            "kie_linking": [],
            "label": label,
            "score": 1.0,
            "points": corner_points,
            "group_id": None,
            "description": None,
            "difficult": False,
            "shape_type": "rectangle",
            "flags": None,
            "attributes": {},
        })
    return shapes
|
| 268 |
+
|
| 269 |
+
def get_longest_part(smi_string):
    """Return the longest '.'-separated fragment; strings without '.' pass through unchanged."""
    if '.' not in smi_string:
        return smi_string
    return max(smi_string.split('.'), key=len)
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
def split_output_by_numeric_classes(output):
    """Partition a detection dict-of-lists by whether each pred_class is a charge-like number.

    A class string matching optional sign + digits + optional trailing sign
    (e.g. '2', '+1', '2-') goes to the numeric bucket; everything else goes
    to the non-numeric bucket. Returns (numeric_output, non_numeric_output),
    each with the same keys as *output*.
    """
    keys = list(output.keys())
    numeric_output = {k: [] for k in keys}
    non_numeric_output = {k: [] for k in keys}
    charge_like = re.compile(r'^[+-]?\d+[+-]?$')

    for i, class_name in enumerate(output['pred_classes']):
        bucket = numeric_output if charge_like.fullmatch(class_name) else non_numeric_output
        for k in keys:
            bucket[k].append(output[k][i])

    return numeric_output, non_numeric_output
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
def convert_shapes_to_output(json_data):
    """Flatten labelme-style 'shapes' into a detection dict-of-lists.

    For each shape, the axis-aligned bbox [x_min, y_min, x_max, y_max] is
    derived from its polygon points, centers are bbox midpoints, scores
    default to 1.0 when absent, and pred_classes come from shape['label'].
    """
    bboxes, centers, scores, classes = [], [], [], []
    for shape in json_data['shapes']:
        xs = [pt[0] for pt in shape['points']]
        ys = [pt[1] for pt in shape['points']]
        box = [min(xs), min(ys), max(xs), max(ys)]
        bboxes.append(box)
        centers.append([(box[0] + box[2]) / 2, (box[1] + box[3]) / 2])
        scores.append(shape.get('score', 1.0))
        classes.append(shape['label'])
    return {
        'bbox': bboxes,
        'bbox_centers': centers,
        'scores': scores,
        'pred_classes': classes,
    }
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
def getJsonData(src_json):
    """Load and return the JSON document stored at *src_json*.

    Parameters:
        src_json: path to a JSON file (e.g. a COCO annotation file).

    Returns:
        The deserialized JSON content.

    Raises:
        OSError if the file cannot be opened; json.JSONDecodeError on
        malformed JSON.
    """
    # Explicit UTF-8 avoids failures on platforms whose default text
    # encoding is not UTF-8 (JSON files are UTF-8 per RFC 8259).
    with open(src_json, 'r', encoding='utf-8') as f:
        coco_data = json.load(f)
    return coco_data
|
| 337 |
+
|
| 338 |
+
def replace_cg_notation(astr):
    """Rewrite generic-alkyl tokens 'CgH<n>' as explicit formulas 'C<(n-1)//2>H<n>'."""
    def _expand(match):
        hydrogen_count = int(match.group(1))
        carbon_count = (hydrogen_count - 1) // 2
        return 'C%dH%d' % (carbon_count, hydrogen_count)

    return re.sub(r'CgH(\d+)', _expand, astr)
|
| 345 |
+
|
| 346 |
+
def viewcheck(image_path,bbox,color='red'):
    """Visual debug helper: scatter atom bbox centers over the image at *image_path*.

    Builds a matplotlib figure (does not save or show it explicitly) with one
    marker per bbox, annotated 'a <index>'.
    """
    image = Image.open(image_path)
    image_array = np.array(image)
    # create the plot
    plt.figure(figsize=(5, 4))  # set figure size
    plt.imshow(image_array)  # show the image
    bbox = np.array(bbox)
    # box centers: midpoint of (x1, x2) and (y1, y2)
    x_coords = (bbox[:, 0]+bbox[:, 2])*0.5
    y_coords =( bbox[:, 1]+bbox[:, 3])*0.5
    plt.scatter(x_coords, y_coords, c=color, s=50, label='Atom Centers', edgecolors='black')
    # add labels (optional)
    for i, (x, y) in enumerate(zip(x_coords, y_coords)):
        plt.text(x, y, f'a {i}', fontsize=12, color=color, ha='center', va='bottom')
|
| 359 |
+
|
| 360 |
+
bclass_simple={"single":'-', "wdge":'w', "dash":'--',
|
| 361 |
+
"=":'=', "#":"#", ":":"aro"}
|
| 362 |
+
|
| 363 |
+
def viewcheck_b(image_path,bbox,bclass,color='red',figsize=(5,4)):
    """Visual debug helper: scatter bond bbox centers over the image, labeled 'b<i><symbol>'.

    *bclass* entries must be keys of the module-level bclass_simple mapping
    (e.g. 'single', '=', 'dash'); the abbreviated symbol is appended to each
    marker label. Builds a matplotlib figure without saving/showing it.
    """
    image = Image.open(image_path)
    image_array = np.array(image)
    # create the plot
    plt.figure(figsize=figsize)  # set figure size
    plt.imshow(image_array)  # show the image
    # extract bbox centers
    bbox = np.array(bbox)
    x_coords = (bbox[:, 0]+bbox[:, 2])*0.5
    y_coords =( bbox[:, 1]+bbox[:, 3])*0.5
    plt.scatter(x_coords, y_coords, c=color, s=50, label='bond Centers', edgecolors='black')
    # add labels (optional): bond index + simplified bond symbol
    for i, (x, y) in enumerate(zip(x_coords, y_coords)):
        simpl_b=bclass_simple[bclass[i]]
        plt.text(x, y, f'b{i}{simpl_b}', fontsize=12, color=color, ha='center', va='bottom')
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
def anchor_draw(image_path, bond_bbox):
    """Debug visualization of bond-endpoint anchors; writes two PNGs to the CWD.

    For each bond bbox, anchors are the bbox corners inset by a 3px margin:
    the main diagonal pair (upper-left, lower-right) is drawn in red and
    saved as 'anchor_positions.png'; the anti-diagonal pair (lower-left,
    upper-right) in blue as 'Opposite_anchor_positions.png'. Markers are
    labeled 'B<bond>:<0|1>' for the two ends of each bond.
    """
    # load the image
    image = Image.open(image_path)
    image_array = np.array(image)

    # initialization
    _margin = 3
    all_anchor_positions = []
    all_oposite_anchor_positions = []

    # compute the anchors of every bond
    for bi, bbox in enumerate(bond_bbox):
        # inset the bbox by the margin, then view as two (x, y) points
        anchor_positions = (np.array(bbox) + [_margin, _margin, -_margin, -_margin]).reshape([2, -1])
        oposite_anchor_positions = anchor_positions.copy()
        # swap the y components to get the anti-diagonal corner pair
        oposite_anchor_positions[:, 1] = oposite_anchor_positions[:, 1][::-1]
        anchor_positions = np.concatenate([anchor_positions, oposite_anchor_positions])

        # first two points are the anchor_positions, last two the opposite ones
        all_anchor_positions.append(anchor_positions[:2])  # [upper-left, lower-right]
        all_oposite_anchor_positions.append(anchor_positions[2:])  # [lower-left, upper-right]

    # convert to numpy arrays, one (x, y) row per anchor point
    all_anchor_positions = np.array(all_anchor_positions).reshape(-1, 2)
    all_oposite_anchor_positions = np.array(all_oposite_anchor_positions).reshape(-1, 2)

    # figure 1: draw anchor_positions
    plt.figure(figsize=(10, 8))
    plt.imshow(image_array)
    plt.scatter(all_anchor_positions[:, 0], all_anchor_positions[:, 1], c='red', s=50, label='Anchor Positions', edgecolors='black')
    for i, (x, y) in enumerate(all_anchor_positions):
        plt.text(x, y, f'B{int(i/2)}:{i%2}', fontsize=10, color='white', ha='center', va='bottom')
    plt.title('Anchor Positions (Upper Left, Lower Right)')
    plt.legend()
    plt.axis('off')
    plt.savefig('anchor_positions.png')

    # figure 2: draw the opposite anchors
    plt.figure(figsize=(10, 8))
    plt.imshow(image_array)
    plt.scatter(all_oposite_anchor_positions[:, 0], all_oposite_anchor_positions[:, 1], c='blue', s=50, label='Opposite Anchor Positions', edgecolors='black')
    for i, (x, y) in enumerate(all_oposite_anchor_positions):
        plt.text(x, y, f'B{int(i/2)}:{i%2}', fontsize=10, color='white', ha='center', va='bottom')
    plt.title('Opposite Anchor Positions (Lower Left, Upper Right)')
    plt.legend()
    plt.axis('off')
    plt.savefig('Opposite_anchor_positions.png')
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
# 计算 4 个顶点
|
| 429 |
+
# Compute the four corner points of a bbox
def get_corners(bbox):
    """Return the 4 corners of (x_min, y_min, x_max, y_max) as a (4, 2) array.

    Row order: top-left, top-right, bottom-left, bottom-right.
    """
    x_min, y_min, x_max, y_max = bbox
    corners = [
        (x_min, y_min),  # top-left
        (x_max, y_min),  # top-right
        (x_min, y_max),  # bottom-left
        (x_max, y_max),  # bottom-right
    ]
    return np.array(corners)
|
| 435 |
+
|
| 436 |
+
# 计算两组顶点之间的最小距离并返回最近的 atom_idx
|
| 437 |
+
# Find the atom bbox closest (corner-to-corner) to a set of bond corners
def find_nearest_atom(bond_corners, atom_bboxes, exclude_idx=None):
    """Return (nearest_atom_index, min_distance) by brute-force corner comparison.

    Every corner of every non-excluded atom bbox is compared against every
    bond corner with Euclidean distance; the closest atom wins. Returns
    (None, inf) when all atoms are excluded or the list is empty.
    """
    nearest_idx, min_dist = None, float('inf')
    for idx, (ax1, ay1, ax2, ay2) in enumerate(atom_bboxes):
        if exclude_idx is not None and idx in exclude_idx:
            continue
        # Same corner set get_corners() yields: TL, TR, BL, BR (inlined here).
        atom_corners = np.array([[ax1, ay1], [ax2, ay1], [ax1, ay2], [ax2, ay2]])
        for bc in bond_corners:
            for ac in atom_corners:
                dist = np.sqrt((bc[0] - ac[0]) ** 2 + (bc[1] - ac[1]) ** 2)
                if dist < min_dist:
                    min_dist, nearest_idx = dist, idx
    return nearest_idx, min_dist
|
| 451 |
+
# 计算顶点到顶点的距离
|
| 452 |
+
# Vertex-to-vertex distance from a point set to atom bboxes
def get_min_distance_to_atom_box(vertices, atom_bboxes, exclude_idx=None):
    """Return (min_distance, closest_atom_index) from *vertices* to atom-bbox corners.

    *vertices* must be array-like 2D points (numpy subtraction is used).
    closest_atom_idx stays -1 when every atom is excluded or the list is empty.
    """
    min_dist = float('inf')
    closest_atom_idx = -1
    for idx, ab in enumerate(atom_bboxes):
        if exclude_idx is not None and idx in exclude_idx:
            continue
        box_corners = np.array([[ab[0], ab[1]], [ab[2], ab[3]], [ab[0], ab[3]], [ab[2], ab[1]]])
        for vertex in vertices:
            for corner in box_corners:
                dist = np.linalg.norm(vertex - corner)
                if dist < min_dist:
                    min_dist = dist
                    closest_atom_idx = idx
    return min_dist, closest_atom_idx
|
| 466 |
+
|
| 467 |
+
|
| 468 |
+
# 检查孤立原子并添加键
|
| 469 |
+
# Axis-aligned overlap test (edges touching counts as overlap)
def boxes_overlap(box1, box2):
    """Return True when two (x1, y1, x2, y2) boxes intersect or touch."""
    x1, y1, x2, y2 = box1
    x3, y3, x4, y4 = box2
    separated_horizontally = x2 < x3 or x4 < x1
    separated_vertically = y2 < y3 or y4 < y1
    return not (separated_horizontally or separated_vertically)
|
| 473 |
+
|
| 474 |
+
def min_corner_distance(box1, box2):
    """Smallest Euclidean distance between any corner of box1 and any corner
    of box2 (both boxes given as x_min, y_min, x_max, y_max)."""
    def _corners(b):
        return [(b[0], b[1]), (b[2], b[3]), (b[0], b[3]), (b[2], b[1])]
    return min(
        np.hypot(p[0] - q[0], p[1] - q[1])
        for p in _corners(box1)
        for q in _corners(box2)
    )
|
| 483 |
+
|
| 484 |
+
def clear_directory(path):
    """Delete every entry inside `path` (files, symlinks, sub-directories)
    while keeping the directory itself.

    Prints a notice when `path` does not exist; per-entry failures are
    reported and skipped rather than raised.
    """
    if not os.path.exists(path):
        print(f"Directory does not exist: {path}")
        return
    print(f"Clearing contents of: {path}")
    for entry in os.listdir(path):
        target = os.path.join(path, entry)
        try:
            if os.path.isfile(target) or os.path.islink(target):
                os.unlink(target)  # plain file or symlink
            elif os.path.isdir(target):
                shutil.rmtree(target)  # sub-directory tree
        except Exception as e:
            print(f'Failed to delete {target}. Reason: {e}')
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
def NHR_string(text):
    """Collapse R-group amine notations to the wildcard form 'NH*'.

    A text matching any of these OCR patterns is replaced wholesale:
      * 'NHR' followed by a digit            (e.g. NHR1)
      * 'RHN' + one or more digits/lowercase (e.g. RHN2a)
      * 'R' + digits/lowercase + 'NH'        (e.g. R2NH)
    Anything else is returned unchanged.
    """
    amine_patterns = (r'NHR\d', r'RHN[0-9a-z]+', r'R[0-9a-z]+NH')
    for pattern in amine_patterns:
        if re.search(pattern, text):
            return 'NH*'
    return text
|
| 521 |
+
|
| 522 |
+
from det_engine import normalize_ocr_text, check_and_fix_valence, rdkit_canonicalize_smiles
|
| 523 |
+
from det_engine import is_valid_chem_text,select_chem_expression
|
| 524 |
+
# Preprocess atom boxes to handle large functional groups
|
| 525 |
+
def preprocess_atom_boxes(atom_centers, atom_bbox, size_threshold_factor=2.5, min_subboxes=2):
    """Split oversized atom boxes (multi-atom functional groups such as
    CH2CH2CH2CH) into roughly average-sized sub-boxes.

    Args:
        atom_centers: per-box [x, y] centers, same order as `atom_bbox`.
        atom_bbox: list of [x_min, y_min, x_max, y_max] boxes.
        size_threshold_factor: boxes larger than avg_area * factor are split.
        min_subboxes: minimum number of pieces a large box is cut into.

    Returns:
        (new_centers ndarray, new_bbox list,
         original_to_subbox: original index -> list of new indices,
         subbox_to_original: new index -> original index)
    """
    box_areas = [(b[2] - b[0]) * (b[3] - b[1]) for b in atom_bbox]
    # Robust average: drop the single smallest and largest area (outliers).
    if len(box_areas) > 2:
        avg_area = np.mean(sorted(box_areas)[1:-1])
    else:
        avg_area = np.mean(box_areas) if box_areas else 1.0

    centers_out = []
    bboxes_out = []
    original_to_subbox = {}
    subbox_to_original = {}
    cursor = 0  # next index in the rebuilt lists

    for orig_i, (box, center) in enumerate(zip(atom_bbox, atom_centers)):
        area = (box[2] - box[0]) * (box[3] - box[1])
        if area > avg_area * size_threshold_factor:
            # Cut the large box into ~avg-sized pieces along its longer axis.
            n_pieces = max(min_subboxes, int(round(area / avg_area)))
            pieces = _split_box(box, n_pieces)
            bboxes_out.extend(pieces)
            centers_out.extend(
                [[(p[0] + p[2]) / 2, (p[1] + p[3]) / 2] for p in pieces]
            )
            new_ids = list(range(cursor, cursor + n_pieces))
        else:
            # Keep the box as-is.
            bboxes_out.append(box)
            centers_out.append(center)
            new_ids = [cursor]
        original_to_subbox[orig_i] = new_ids
        for nid in new_ids:
            subbox_to_original[nid] = orig_i
        cursor += len(new_ids)

    return np.array(centers_out), bboxes_out, original_to_subbox, subbox_to_original


def _split_box(box, n):
    """Cut `box` into `n` equal slices along its longer axis."""
    x0, y0, x1, y1 = box
    if (x1 - x0) >= (y1 - y0):
        step = (x1 - x0) / n  # horizontal split
        return [[x0 + k * step, y0, x0 + (k + 1) * step, y1] for k in range(n)]
    step = (y1 - y0) / n  # vertical split
    return [[x0, y0 + k * step, x1, y0 + (k + 1) * step] for k in range(n)]
|
| 589 |
+
|
| 590 |
+
|
| 591 |
+
|
| 592 |
+
|
| 593 |
+
|
| 594 |
+
|
| 595 |
+
# OCR models used to recognise placeholder-atom / abbreviation text.
other2ppsocr = True
if other2ppsocr:
    # Latin-script recogniser; GPU only when OpenCV reports a CUDA device.
    ocr = PaddleOCR(
        use_angle_cls=True,
        lang='latin', use_space_char=True, use_debug=False,
        use_gpu=True if cv2.cuda.getCudaEnabledDeviceCount() > 0 else False
    )

    # English recogniser (SVTR_LCNet backbone), CPU only.
    ocr2 = PaddleOCR(use_angle_cls=True, use_gpu=False, use_debug=False,
                     rec_algorithm='SVTR_LCNet',
                     lang="en")

# Detection-score cutoff for keeping predicted boxes
# (0.45 gave 292-240=52; 0.5 gives 292-233=59).
box_thresh = 0.5
useocr = True
box_matter = 0
getacc = False
getfpsim = False
visual_check = False
|
| 616 |
+
|
| 617 |
+
|
| 618 |
+
|
| 619 |
+
# Dataset selection. Each assignment overrides the previous one, so the
# effective default is the LAST assignment ('JPO') unless --dataname is given.
# The notes record rows that were manually inspected/fixed in each benchmark.
da = 'acs'
# acs: row 198 fixed with expanded SMILES; row 326 'Tos' uses the SO2Ph
# (not SiC3) version — conflict fixed.

da = 'CLEF'
# CLEF: 462 S[O]a fixed; 179 NHR8 fixed as NH-R8; rows 582/750/411/7612/761
# had [(CH2)m]/[(CH2)q]/[(CH2)s] which RDKit cannot read, fixed as [CH2];
# rows 30,214,795,856,583,654,618,339,138,927,203,869,261,634,180,63,758,
# 718,741,832,88,250,799,303,956,810,596 had bond errors / wrong SMILES.
# TODO still failing: SO2 multi-row cases from 992 onwards.

da = 'UOB'
# UOB: TODO fix rows 5119,3420,990,1082,2451,3626,1634,627,5385 (all passed @ v3).

da = 'USPTO'
# USPTO: TODO rows 458 and ~10 more with [O.]/[NH2.]; 4566/5523 image != SMILES;
# 3927/5234/4062 polymer-unit problems; 1658/4625/4944 SO3-as-SOOO errors;
# 1164 CHO, 2703 CN/NC, 4421 CH2O errors; 4921 R-group error fixed;
# 58 (NH4NO)2 wrong SMILES fixed; 2352/4590/3381 wrong SMILES fixed;
# 3071 image C8H13 may not be expandable.

# da='staker'  # staker: 11422 removed (11420.png is a table, not a molecule);
# ~50 rows (1971, 5770, 5972, ... 48624, 48625) encode SO3 as SOOO where it
# should be S(=O)(=O)O — O-O-O chains are chemically implausible.

da = 'JPO'
# da='staker'

# Command-line override takes precedence over the hard-coded default above.
if args.dataname:
    da = args.dataname
|
| 661 |
+
|
| 662 |
+
# ac_b switches to the Windows-share viewer directories used for the
# manual double-check workflow; otherwise paths are built from cur_dir.
ac_b = False
ac_b_smilesnotsame_writJson = True
if ac_b:
    view_check_dir = f"D:\RPA\codes_share\wsl_\X-AnyLabeling\\need2check\\view_check_{da}\\failed"
    view_dirac = f"{view_check_dir}/{da}_ac"
    view_dirb = f"{view_check_dir}/{da}_b"
    dst_dirac = view_dirac  # used when double-checking
    dst_dirb = view_dirb

# Input CSV lives next to the script.
src_dir = cur_dir
src_file = os.path.join(src_dir, f"{da}.csv")
# v3 holds the manually corrected atom/charge and bond JSON annotations.
view_check_dir2 = os.path.join(src_dir, f"view_check_{da}", "v3")

N = 1
if args.number:
    N = int(args.number)

view_dirac2 = os.path.join(view_check_dir2, f"{da}_ac")
view_dirb2 = os.path.join(view_check_dir2, f"{da}_b")

view_dirac_tmp = os.path.join(view_check_dir2, f"{da}_actmp")
view_dirac_tmp_debug = True

# Create whichever output directories this run needs.
if ac_b:
    need2mkdir = [view_check_dir, view_dirac, view_dirb, view_check_dir2, view_dirac2, view_dirb2]
else:
    need2mkdir = [view_check_dir2, view_dirac2, view_dirb2, view_dirac_tmp]
for dir_v in need2mkdir:
    if not os.path.exists(dir_v):
        os.makedirs(dir_v)

ac_b = False
|
| 706 |
+
# if ac_b:#update _ac _b
|
| 707 |
+
# # 清空两个目录
|
| 708 |
+
# clear_directory(view_dirac2) #NOTE we only check for the better models faileds
|
| 709 |
+
# clear_directory(view_dirb2)
|
| 710 |
+
|
| 711 |
+
# #note box not equal as abbv eixsits, process single bond..TODO need check and fixing, may be need rdkit smiles Match
|
| 712 |
+
# df['file_name'] = df['file_path'].str.split('/').str[-1]
|
| 713 |
+
# # df['file_base'] =f"{da}_" + df['file_name'].str.replace('.png', '', regex=False)
|
| 714 |
+
# df['file_base'] = df['file_name'].str.replace('.png', '', regex=False)
|
| 715 |
+
|
| 716 |
+
|
| 717 |
+
|
| 718 |
+
# Output CSV collecting OCR-augmented predictions.
outcsv_filename = os.path.join(src_dir, f"{da}_OUTPUTwithOCR.csv")

if getacc:
    # Accuracy summary plus a log of images whose atom/charge boxes were wrong.
    acc_summary = f"{outcsv_filename}.I2Msummary.txt"
    flogout = open(f'{acc_summary}', 'w')
    flogout2 = open(f'{outcsv_filename}_acBoxWrong', 'a')

# Accumulators for the evaluation loop.
failed = []
failed_fb = []
mydiff = []
simRD = 0
sim = 0
simRDlist = []
mysum = 0

smiles_data = pd.DataFrame({'file_name': [],
                            'SMILESori': [],
                            'SMILESpre': [],
                            'SMILESexp': [],
                            })

miss_file = []
miss_filejs = []
debug = False

rt_out = False
if not ac_b:
    view_dirac = view_dirac2
    view_dirb = view_dirb2
    dst_dirac = view_dirac  # used when double-checking
    dst_dirb = view_dirb
    test_dir = f'./test/'  # TODO WEB_dev: put test images here
    pngs = [f for f in os.listdir(test_dir) if '.png' in f]

    rt_out = True
    # v4 holds the model output for visual comparison.
    view_check_dir3 = os.path.join(src_dir, f"view_check_{da}", "v4")
    view_dirac3 = f"{view_check_dir3}/{da}_ac"
    view_dirb3 = f"{view_check_dir3}/{da}_b"
    for dir_v in [view_check_dir3, view_dirac3, view_dirb3]:
        if not os.path.exists(dir_v):
            os.makedirs(dir_v)

# Abbreviation expansion makes atom/bond counts differ from the originals.
acn = False
bn = False
|
| 771 |
+
|
| 772 |
+
|
| 773 |
+
import torchvision.transforms.v2 as T
|
| 774 |
+
|
| 775 |
+
def image_to_tensor(image_path, debug=True):
    """Open `image_path`, force RGB, resize to 640x640 and return a float32
    CHW tensor plus the ORIGINAL width and height.

    Returns:
        (tensor, original_width, original_height)
    """
    img = Image.open(image_path)
    orig_w, orig_h = img.size

    # The detector expects a 3-channel RGB input: convert grayscale and any
    # other colour mode before transforming.
    if img.mode == "L":
        if debug:
            print("检测到灰度图像 (1 通道),转换为 RGB...")
        img = img.convert("RGB")
    elif img.mode != "RGB":
        if debug:
            print(f"检测到 {img.mode} 模式,转换为 RGB...")
        img = img.convert("RGB")

    # Resize to the network input size, convert to a tensor, force float32.
    pipeline = T.Compose([
        T.Resize((640, 640)),
        T.ToTensor(),
        lambda x: x.to(torch.float32),  # explicit dtype conversion
    ])
    return pipeline(img), orig_w, orig_h
|
| 798 |
+
def ouptnp2abc(output, idx_to_labels):
    """Split raw detections into atom / bond / charge groups.

    `output` holds parallel numpy arrays under 'bbox', 'bbox_centers',
    'scores' and 'pred_classes'. Class ids 0-12 are atoms, 13-18 bonds and
    19-23 charges. Numeric class ids are replaced by their label names via
    `idx_to_labels`.

    Returns:
        (output_a, output_b, output_c) — dicts with list values for atoms,
        bonds and charges respectively; empty lists when a group is absent.
    """
    class_ids = {
        'atom': list(range(0, 13)),
        'bond': list(range(13, 19)),
        'charge': list(range(19, 24)),
    }
    grouped = {}
    for group, ids in class_ids.items():
        bucket = {'bbox': [], 'bbox_centers': [], 'scores': [], 'pred_classes': []}
        mask = np.isin(output['pred_classes'], ids)
        if np.any(mask):
            bucket['bbox'] = output['bbox'][mask].tolist()
            bucket['bbox_centers'] = output['bbox_centers'][mask].tolist()
            bucket['scores'] = output['scores'][mask].tolist()
            bucket['pred_classes'] = [
                idx_to_labels[idx] for idx in output['pred_classes'][mask].tolist()
            ]
        grouped[group] = bucket

    return grouped['atom'], grouped['bond'], grouped['charge']
|
| 837 |
+
|
| 838 |
+
def bbox2center(bbox):
    """Convert an (N, 4) array of xyxy boxes to their (N, 2) centers."""
    xy_min = bbox[:, :2]
    xy_max = bbox[:, 2:4]
    return (xy_min + xy_max) / 2
|
| 844 |
+
|
| 845 |
+
class bcolors:
    """ANSI escape sequences for coloured terminal output."""
    HEADER = '\033[95m'     # bright magenta
    OKBLUE = '\033[94m'     # bright blue
    OKCYAN = '\033[96m'     # bright cyan
    OKGREEN = '\033[92m'    # bright green
    WARNING = '\033[93m'    # bright yellow
    FAIL = '\033[91m'       # bright red
    ENDC = '\033[0m'        # reset all attributes
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'
|
| 855 |
+
|
| 856 |
+
# Converts raw DETR logits/boxes into thresholded, labelled detections.
postprocessor = RTDETRPostProcessor(classes_dict=idx_to_labels, use_focal_loss=True,
                                    num_top_queries=300, remap_mscoco_category=False)

# ONNX export / inference stack.
import torch.onnx
import onnx
import onnxruntime as ort

onnx_model_path = "/nfs_home/bowen/works/pys/codes/i2m/I2M_R4.onnx"  # exported 2025-06-05
|
| 863 |
+
def image_to_tensor2(image_path):
    """Read an image with OpenCV and pre-process it for the 640x640 detector.

    The file is read via np.fromfile + cv2.imdecode, which (unlike
    cv2.imread) handles non-ASCII paths on Windows.

    Args:
        image_path: path to the image file.

    Returns:
        (1x3x640x640 float32 torch tensor scaled to [0, 1],
         original_width, original_height)

    Raises:
        FileNotFoundError: if `image_path` is None or does not exist.
            (Previously the function silently fell through and returned
            None, making callers fail later with an opaque unpack error.)
    """
    if image_path is None or not os.path.exists(image_path):
        raise FileNotFoundError(f"image not found: {image_path}")

    # NOTE: potential issue - unable to handle flipped images;
    # cv2.imdecode leaves EXIF orientation untouched.
    cv_image = cv2.imdecode(np.fromfile(image_path, dtype=np.uint8), -1)
    input_image = cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB)

    image_h, image_w = input_image.shape[:2]
    input_h, input_w = 640, 640

    # Scale factors from the original size to the network input size.
    ratio_h = input_h / image_h
    ratio_w = input_w / image_w
    print(ratio_h, ratio_w)

    # Resize, HWC -> CHW, float32, normalise to [0, 1], add batch dim.
    image = cv2.resize(
        input_image, (0, 0), fx=ratio_w, fy=ratio_h, interpolation=2
    )
    image = image.transpose((2, 0, 1))  # HWC to CHW
    image = np.ascontiguousarray(image).astype("float32")
    image /= 255  # 0 - 255 to 0.0 - 1.0
    if len(image.shape) == 3:
        image = image[None]
    return torch.from_numpy(image), image_w, image_h
|
| 891 |
+
|
| 892 |
+
# Prepare the input data
|
| 893 |
+
def to_numpy(tensor):
    """Convert a torch tensor to a CPU numpy array, detaching it from the
    autograd graph when it requires gradients."""
    if tensor.requires_grad:
        return tensor.detach().cpu().numpy()
    return tensor.cpu().numpy()
|
| 895 |
+
|
| 896 |
+
# Load and validate the ONNX model
|
| 897 |
+
# Load the exported graph, validate it, then open an inference session.
onnx_model = onnx.load(onnx_model_path)
onnx.checker.check_model(onnx_model)
print("ONNX模型检查通过")

ort_session = ort.InferenceSession(onnx_model_path)
onnx_ = True  # run inference through ONNX Runtime instead of the torch model
dfm = 0
|
| 904 |
+
|
| 905 |
+
# ff='US20030130506A1_p0046_x1541_y1396_c00157'
|
| 906 |
+
# for ff in pngs:
|
| 907 |
+
def main():
|
| 908 |
+
for id_, ff in enumerate(pngs):
|
| 909 |
+
# if 'US20060154945A1_p0016_x0402_y1570_c00053' not in ff: continue
|
| 910 |
+
# indices = df.index[df['image_id'] == ff[:-4]].tolist()
|
| 911 |
+
# indices = df.index[df['file_base'] == ff[:-4]].tolist()
|
| 912 |
+
# try:
|
| 913 |
+
# id_=indices[0]
|
| 914 |
+
# except Exception as e:
|
| 915 |
+
# print([indices,ff])
|
| 916 |
+
# raise e
|
| 917 |
+
# SMILESori=rows_check.iloc[id_].SMILES
|
| 918 |
+
# file_base=rows_check.iloc[id_].file_base
|
| 919 |
+
# # if debug: print(id_, file_base)
|
| 920 |
+
# image_path= os.path.join(dst_dirac, f"{file_base}.png")
|
| 921 |
+
|
| 922 |
+
# ac_datadir=os.path.join(dst_dirac, f"{file_base}.json")
|
| 923 |
+
# ac_exist= os.path.exists(ac_datadir)
|
| 924 |
+
# if not ac_exist:
|
| 925 |
+
# miss_filejs.append(ac_datadir)
|
| 926 |
+
# continue
|
| 927 |
+
# image_path= os.path.join(test_dir, f"{ff}")
|
| 928 |
+
image_path= os.path.join(f"{ff}")
|
| 929 |
+
SMILESori=''
|
| 930 |
+
print(f"@@@@@@@@@@@@@@@@@@@@@@@ {id_}\n{image_path}\n {SMILESori}")
|
| 931 |
+
# print(image_path,b_datadir,ac_datadir)
|
| 932 |
+
|
| 933 |
+
img_ori = Image.open(image_path).convert('RGB')
|
| 934 |
+
w_ori, h_ori = img_ori.size # 获取原始图像的尺寸
|
| 935 |
+
# if [w_ori, h_ori]!=[256,256] and da=='staker':
|
| 936 |
+
# print(f"图像的尺寸不为256x256,而是{w_ori}x{h_ori},请检查图像是否正确:\n{ff}")
|
| 937 |
+
# continue
|
| 938 |
+
|
| 939 |
+
# print(f"图像的尺寸",[w_ori, h_ori ])
|
| 940 |
+
scale_x = 1000 / w_ori
|
| 941 |
+
scale_y = 1000 / h_ori
|
| 942 |
+
img_ori_1k = img_ori.resize((1000,1000))
|
| 943 |
+
# Example usage: #change thie image
|
| 944 |
+
tensor,w,h = image_to_tensor(image_path)
|
| 945 |
+
# tensor,w,h = image_to_tensor2(image_path)
|
| 946 |
+
tensor=tensor.unsqueeze(0)
|
| 947 |
+
if onnx_:
|
| 948 |
+
ort_inputs = {
|
| 949 |
+
ort_session.get_inputs()[0].name: to_numpy(tensor),
|
| 950 |
+
# ort_session.get_inputs()[1].name: to_numpy(dummy_grid)
|
| 951 |
+
}
|
| 952 |
+
ort_outputs = ort_session.run(None, ort_inputs)
|
| 953 |
+
# 转换为PyTorch格式
|
| 954 |
+
onnx_pred_logits = torch.from_numpy(ort_outputs[0])
|
| 955 |
+
onnx_pred_boxes = torch.from_numpy(ort_outputs[1])
|
| 956 |
+
# 构建与原模型一致的输出字典
|
| 957 |
+
onnx_output_dict = {
|
| 958 |
+
"pred_logits": onnx_pred_logits,
|
| 959 |
+
"pred_boxes": onnx_pred_boxes,
|
| 960 |
+
}
|
| 961 |
+
# else:
|
| 962 |
+
# #use original model
|
| 963 |
+
# with torch.no_grad():
|
| 964 |
+
# # print("training",_model.training)
|
| 965 |
+
# outputs_tensor = _model(tensor)
|
| 966 |
+
# 打印并比较结果
|
| 967 |
+
# print("PyTorch输出:", outputs_tensor)
|
| 968 |
+
# print("ONNX Runtime输出:", ort_outputs[0],ort_outputs[1],len(ort_outputs))
|
| 969 |
+
|
| 970 |
+
ori_size=torch.Tensor([w,h]).long().unsqueeze(0)
|
| 971 |
+
# result_ = postprocessor(outputs_tensor, ori_size)
|
| 972 |
+
result_ = postprocessor(onnx_output_dict, ori_size)
|
| 973 |
+
|
| 974 |
+
score_=result_[0]['scores']
|
| 975 |
+
boxe_=result_[0]['boxes']
|
| 976 |
+
label_=result_[0]['labels']
|
| 977 |
+
selected_indices =score_ > box_thresh
|
| 978 |
+
output={
|
| 979 |
+
'labels': label_[selected_indices].to("cpu").numpy(),
|
| 980 |
+
'boxes': boxe_[selected_indices].to("cpu").numpy(),
|
| 981 |
+
'scores': score_[selected_indices].to("cpu").numpy()
|
| 982 |
+
}
|
| 983 |
+
center_coords=bbox2center(output['boxes'])
|
| 984 |
+
output = {'bbox': output["boxes"],
|
| 985 |
+
'bbox_centers': center_coords,
|
| 986 |
+
'scores': output["scores"],
|
| 987 |
+
'pred_classes': output["labels"]}
|
| 988 |
+
output_a, output_b, output_c= ouptnp2abc(output,idx_to_labels)
|
| 989 |
+
|
| 990 |
+
|
| 991 |
+
|
| 992 |
+
if debug:print("c,a,b>>>>>",len(output_c['pred_classes']),len(output_a['pred_classes']),len(output_b['pred_classes']))
|
| 993 |
+
if len(output_a['pred_classes'])==0:
|
| 994 |
+
file_path = 'Check_AboxIs0.txt'
|
| 995 |
+
content = f'{image_path}@@{id_}---{image_path}\n'
|
| 996 |
+
# 文件存在则追加写入,不存在则创建并写入
|
| 997 |
+
with open(file_path, 'a', encoding='utf-8') as f:
|
| 998 |
+
f.write(content)
|
| 999 |
+
continue #may need manulay labeling
|
| 1000 |
+
|
| 1001 |
+
overlap_records = []
|
| 1002 |
+
to_remove = set()
|
| 1003 |
+
bond_boxes = output_b['bbox']
|
| 1004 |
+
|
| 1005 |
+
bboxes = output_a['bbox'].copy()
|
| 1006 |
+
a_center = output_a['bbox_centers'].copy()
|
| 1007 |
+
|
| 1008 |
+
scores = output_a['scores'].copy()
|
| 1009 |
+
pred_classes = output_a['pred_classes'].copy()
|
| 1010 |
+
to_remove = set()
|
| 1011 |
+
|
| 1012 |
+
# 计算所有 atom bbox 之间的 IoU, 并根据 IoU 进行处理
|
| 1013 |
+
for i in range(len(bboxes)):
|
| 1014 |
+
for j in range(i + 1, len(bboxes)):
|
| 1015 |
+
# iou, relationship, inter_area, union_area = calculate_iou(bboxes[i], bboxes[j])
|
| 1016 |
+
x_min1, y_min1, x_max1, y_max1 = bboxes[i]
|
| 1017 |
+
x_min2, y_min2, x_max2, y_max2 = bboxes[j]
|
| 1018 |
+
# 计算交集坐标
|
| 1019 |
+
x_min_inter = max(x_min1, x_min2)
|
| 1020 |
+
y_min_inter = max(y_min1, y_min2)
|
| 1021 |
+
x_max_inter = min(x_max1, x_max2)
|
| 1022 |
+
y_max_inter = min(y_max1, y_max2)
|
| 1023 |
+
# 计算交集面积
|
| 1024 |
+
inter_width = max(0, x_max_inter - x_min_inter)
|
| 1025 |
+
inter_height = max(0, y_max_inter - y_min_inter)
|
| 1026 |
+
inter_area = inter_width * inter_height
|
| 1027 |
+
# 计算两个框的面积
|
| 1028 |
+
area1 = (x_max1 - x_min1) * (y_max1 - y_min1)
|
| 1029 |
+
area2 = (x_max2 - x_min2) * (y_max2 - y_min2)
|
| 1030 |
+
# 计算并集面积
|
| 1031 |
+
union_area = area1 + area2 - inter_area
|
| 1032 |
+
# 计算 IoU
|
| 1033 |
+
iou = inter_area / union_area if union_area > 0 else 0
|
| 1034 |
+
score_i = scores[i] if scores[i] is not None else -1
|
| 1035 |
+
score_j = scores[j] if scores[j] is not None else -1
|
| 1036 |
+
# 完全重合
|
| 1037 |
+
if iou == 1:
|
| 1038 |
+
if score_i > score_j:
|
| 1039 |
+
to_remove.add(j)
|
| 1040 |
+
else:
|
| 1041 |
+
to_remove.add(i)
|
| 1042 |
+
elif iou>=0.8 and iou <1.0:#NOTE fix me if not right
|
| 1043 |
+
if score_i > score_j:
|
| 1044 |
+
to_remove.add(j)
|
| 1045 |
+
if debug: print([i,j,score_i,score_j],iou,f"will remove j {j}, i-j {i,j}")
|
| 1046 |
+
else:
|
| 1047 |
+
to_remove.add(i)
|
| 1048 |
+
if debug: print([i,j,score_i,score_j],iou,f"will remove i {i}, i-j {i,j} ")
|
| 1049 |
+
|
| 1050 |
+
# 包含关系
|
| 1051 |
+
elif iou > 0 and iou < 0.89 :
|
| 1052 |
+
if debug: print([i,j,score_i,score_j],iou,"<<<<<<111")
|
| 1053 |
+
if inter_area == area1 and area1 < area2: # bbox[j] 包含 bbox[i]
|
| 1054 |
+
large_idx, small_idx = j, i
|
| 1055 |
+
elif inter_area == area2 and area2 < area1: # bbox[i] 包含 bbox[j]
|
| 1056 |
+
large_idx, small_idx = i, j
|
| 1057 |
+
else:
|
| 1058 |
+
if debug: print([i,j,score_i,score_j],iou,'OVERLAP without processed this version')
|
| 1059 |
+
continue
|
| 1060 |
+
# 检查是否包含 bond box
|
| 1061 |
+
contains_bond = False
|
| 1062 |
+
for bond_bbox in bond_boxes:
|
| 1063 |
+
if is_contained_in(bond_bbox, bboxes[large_idx]):
|
| 1064 |
+
contains_bond = True
|
| 1065 |
+
# 调整较大 bbox
|
| 1066 |
+
bboxes[large_idx] = adjust_bbox1(bboxes[large_idx], bboxes[small_idx], bond_bbox)
|
| 1067 |
+
# to_remove.add(small_idx)
|
| 1068 |
+
break
|
| 1069 |
+
if not contains_bond:
|
| 1070 |
+
to_remove.add(small_idx)#NOTE use the cutoff >0.45,
|
| 1071 |
+
elif iou==0:#==0
|
| 1072 |
+
pass
|
| 1073 |
+
else:
|
| 1074 |
+
print([i,j,score_i,score_j],iou,"<<<<<<222")
|
| 1075 |
+
print('what this case ???')
|
| 1076 |
+
|
| 1077 |
+
# 删除被移除的 bbox
|
| 1078 |
+
atom_bboxes = [bboxes[i] for i in range(len(bboxes)) if i not in to_remove]
|
| 1079 |
+
atom_scores = [scores[i] for i in range(len(scores)) if i not in to_remove]
|
| 1080 |
+
atom_centers = [a_center[i] for i in range(len(a_center)) if i not in to_remove]
|
| 1081 |
+
atom_classes = [pred_classes[i] for i in range(len(pred_classes)) if i not in to_remove]
|
| 1082 |
+
#TODO need sort box with x first, then y dim, useful for * with multi neiborbond
|
| 1083 |
+
# Sort atom_bboxes and atom_scores by x1 (bbox[0]) first, then y1 (bbox[1])
|
| 1084 |
+
sorted_indices = sorted(range(len(atom_bboxes)), key=lambda i: (atom_bboxes[i][0], atom_bboxes[i][1]))
|
| 1085 |
+
atom_bboxes = [atom_bboxes[i] for i in sorted_indices]
|
| 1086 |
+
atom_scores = [atom_scores[i] for i in sorted_indices]
|
| 1087 |
+
atom_centers = [atom_centers[i] for i in sorted_indices]
|
| 1088 |
+
atom_classes = [atom_classes[i] for i in sorted_indices]
|
| 1089 |
+
|
| 1090 |
+
print(len(atom_classes),'xxxxxxxx')
|
| 1091 |
+
bond_bbox = output_b['bbox'].copy()
|
| 1092 |
+
bond_scores = output_b['scores'].copy()
|
| 1093 |
+
bond_classes = output_b['pred_classes'].copy()
|
| 1094 |
+
|
| 1095 |
+
if len(atom_bboxes)!=len(output_a['bbox']):
|
| 1096 |
+
# print(f"need manualy fix ac json file------ {file_base}")
|
| 1097 |
+
if getacc:
|
| 1098 |
+
flogout2.write(f"fix ac json file---: {file_base} \n")
|
| 1099 |
+
# raise ValueError(f"need manualy fix ac json file------ {file_base}")
|
| 1100 |
+
# NOTE NEED this codes follow code not del box , used4 prepare recorrect json boxfiles
|
| 1101 |
+
# atom_bboxes = output_a['bbox'].copy()
|
| 1102 |
+
# atom_scores = output_a['scores'].copy()
|
| 1103 |
+
# atom_classes = output_a['pred_classes'].copy()
|
| 1104 |
+
# atom_centers = output_a['bbox_centers'].copy()
|
| 1105 |
+
# sorted_indices = sorted(range(len(atom_bboxes)), key=lambda i: (atom_bboxes[i][0], atom_bboxes[i][1]))
|
| 1106 |
+
# atom_bboxes = [atom_bboxes[i] for i in sorted_indices]
|
| 1107 |
+
# atom_scores = [atom_scores[i] for i in sorted_indices]
|
| 1108 |
+
# atom_centers = [atom_centers[i] for i in sorted_indices]
|
| 1109 |
+
# atom_classes = [atom_classes[i] for i in sorted_indices]
|
| 1110 |
+
|
| 1111 |
+
|
| 1112 |
+
# atom_bbox=final_bboxes
|
| 1113 |
+
bonds = dict()
|
| 1114 |
+
b2aa = dict()
|
| 1115 |
+
singleAtomBond = dict()
|
| 1116 |
+
bondWithdirct = dict()
|
| 1117 |
+
_margin = 0
|
| 1118 |
+
bond_direction = dict()
|
| 1119 |
+
|
| 1120 |
+
# Preprocess atom boxes
|
| 1121 |
+
atom_centers_, atom_bbox_, original_to_subbox, subbox_to_original = preprocess_atom_boxes(atom_centers, atom_bboxes)
|
| 1122 |
+
# Build KDTree with updated atom centers
|
| 1123 |
+
tree_atom = KDTree(atom_centers_)#have to includ the splited box
|
| 1124 |
+
if debug:
|
| 1125 |
+
print(f"KDTree built with {len(atom_centers_)} atom centers")
|
| 1126 |
+
|
| 1127 |
+
for bi, (bbox, bond_type) in enumerate(zip(bond_bbox, bond_classes)):
|
| 1128 |
+
score = bond_scores[bi]
|
| 1129 |
+
if score is None:
|
| 1130 |
+
score = 1.0 # From manual addition
|
| 1131 |
+
bond_scores[bi] = score
|
| 1132 |
+
|
| 1133 |
+
anchor_positions = (np.array(bbox) + [_margin, _margin, -_margin, -_margin]).reshape([2, -1])
|
| 1134 |
+
oposite_anchor_positions = anchor_positions.copy()
|
| 1135 |
+
oposite_anchor_positions[:, 1] = oposite_anchor_positions[:, 1][::-1]
|
| 1136 |
+
anchor_positions = np.concatenate([anchor_positions, oposite_anchor_positions])
|
| 1137 |
+
|
| 1138 |
+
# Query KDTree for nearest atoms
|
| 1139 |
+
dists, neighbours = tree_atom.query(anchor_positions, k=1)
|
| 1140 |
+
if np.argmin((dists[0] + dists[1], dists[2] + dists[3])) == 0:
|
| 1141 |
+
begin_idx, end_idx = neighbours[:2]
|
| 1142 |
+
else:
|
| 1143 |
+
begin_idx, end_idx = neighbours[2:]
|
| 1144 |
+
|
| 1145 |
+
# Map sub-box indices back to original atom indices
|
| 1146 |
+
atom1_idx = int(subbox_to_original[int(begin_idx)])
|
| 1147 |
+
atom2_idx = int(subbox_to_original[int(end_idx)])
|
| 1148 |
+
|
| 1149 |
+
if atom1_idx == atom2_idx:
|
| 1150 |
+
if debug:
|
| 1151 |
+
print(f"singleAtomBond detected with bond id:{bi} atomIdx1 == atomIdx2 ::{[atom1_idx, atom2_idx]}")
|
| 1152 |
+
singleAtomBond[bi] = [atom1_idx]
|
| 1153 |
+
|
| 1154 |
+
min_ai = min([atom1_idx, atom2_idx])
|
| 1155 |
+
max_ai = max([atom1_idx, atom2_idx])
|
| 1156 |
+
|
| 1157 |
+
if debug:
|
| 1158 |
+
print(f"Bond {bi}: [{min_ai}, {max_ai}]")
|
| 1159 |
+
|
| 1160 |
+
# Assign bond type
|
| 1161 |
+
if bond_type in ['single', 'wdge', 'dash', '-', 'NONE', 'ENDUPRIGHT', 'BEGINWEDGE', 'BEGINDASH', 'ENDDOWNRIGHT']:
|
| 1162 |
+
bond_ = [min_ai, max_ai, 'SINGLE', score]
|
| 1163 |
+
if bond_type in ['wdge', 'dash', 'ENDUPRIGHT', 'BEGINWEDGE', 'BEGINDASH', 'ENDDOWNRIGHT']:
|
| 1164 |
+
bondWithdirct[bi] = [min_ai, max_ai, 'SINGLE', score, bond_type]
|
| 1165 |
+
elif bond_type == '=':
|
| 1166 |
+
bond_ = [min_ai, max_ai, 'DOUBLE', score]
|
| 1167 |
+
elif bond_type == '#':
|
| 1168 |
+
bond_ = [min_ai, max_ai, 'TRIPLE', score]
|
| 1169 |
+
elif bond_type == ':':
|
| 1170 |
+
bond_ = [min_ai, max_ai, 'AROMATIC', score]
|
| 1171 |
+
else:
|
| 1172 |
+
if debug:
|
| 1173 |
+
print(f"Unknown bond_type: {bond_type} for bond {bi} [{min_ai, max_ai}]")
|
| 1174 |
+
bond_ = [min_ai, max_ai, 'SINGLE', score]
|
| 1175 |
+
|
| 1176 |
+
bonds[bi] = bond_
|
| 1177 |
+
b2aa[bi] = sorted([min_ai, max_ai])
|
| 1178 |
+
|
| 1179 |
+
if debug:
|
| 1180 |
+
print(f"bonds {len(bonds)}, b2aa {len(b2aa)}, singleAtomBond {len(singleAtomBond)}, bondWithdirct {len(bondWithdirct)}")
|
| 1181 |
+
|
| 1182 |
+
|
| 1183 |
+
#try to set up a2b, baesed on bond-anchor_positions--atom center relationship
|
| 1184 |
+
a2b=dict()#may be updated as following singleAtomBond cases process
|
| 1185 |
+
isolated_a=set()
|
| 1186 |
+
aa2b_d2=dict()
|
| 1187 |
+
for k,v in b2aa.items():
|
| 1188 |
+
vt=(v[0],v[1])
|
| 1189 |
+
if vt in aa2b_d2:
|
| 1190 |
+
aa2b_d2[vt].append(k)
|
| 1191 |
+
else:
|
| 1192 |
+
aa2b_d2[vt]=[k]
|
| 1193 |
+
|
| 1194 |
+
for a in set(v):
|
| 1195 |
+
if a not in a2b.keys():
|
| 1196 |
+
a2b[a]=[k]
|
| 1197 |
+
else:
|
| 1198 |
+
a2b[a].append(k)
|
| 1199 |
+
|
| 1200 |
+
# 初始化 a2neib, iso_lated atom box and singleAtomBond box process need
|
| 1201 |
+
a2neib = {}
|
| 1202 |
+
# 遍历 a2b,构建邻居关系
|
| 1203 |
+
for atom, bns in a2b.items():
|
| 1204 |
+
neighbors = set() # 使用集合去重
|
| 1205 |
+
for bond in bns:
|
| 1206 |
+
atom_pair = b2aa[bond] # 获取 bond 连接的原子对
|
| 1207 |
+
# 如果当前原子在 atom_pair 中,添加另一个原子作为邻居
|
| 1208 |
+
nei={ai for ai in atom_pair if ai !=atom }
|
| 1209 |
+
neighbors.update(nei)
|
| 1210 |
+
# if atom in atom_pair:
|
| 1211 |
+
# other_atom = atom_pair[0] if atom == atom_pair[1] else atom_pair[1]
|
| 1212 |
+
# neighbors.add(other_atom)
|
| 1213 |
+
a2neib[atom] = sorted(list(neighbors)) # 转换为有序列表
|
| 1214 |
+
|
| 1215 |
+
#check isolated atom exsit, if need add bond for isloated atom box when overlaping with other atom box
|
| 1216 |
+
isolated_a=set()
|
| 1217 |
+
for ai, a_lab in enumerate(atom_classes):
|
| 1218 |
+
if ai not in a2b.keys():
|
| 1219 |
+
isolated_a.add(ai)
|
| 1220 |
+
if debug:print("detected possible isolated atom:", isolated_a)
|
| 1221 |
+
|
| 1222 |
+
|
| 1223 |
+
repeate_bonds={k:v for k,v in aa2b_d2.items() if len(v)>=2 }
|
| 1224 |
+
if debug:print(f"repeat bond box ids {repeate_bonds}")
|
| 1225 |
+
#get the minimu size of bond box, check isolated_a atom box overlap with other atom box, if overlap, then add bond box (default bond label with single, score 1.0) between them
|
| 1226 |
+
# update a2b,b2aa, and bond box bond_classes, elif not box not overlap, the isolated_a box min(4 point of box cornners to other atom box connrer) enough small than the existed bond box size
|
| 1227 |
+
if len(isolated_a)>0:
|
| 1228 |
+
isolated_a2del=[]
|
| 1229 |
+
# 计算现有键的最小尺寸
|
| 1230 |
+
bond_sizes = []
|
| 1231 |
+
for bbox in bond_bbox:
|
| 1232 |
+
width = bbox[2] - bbox[0]
|
| 1233 |
+
height = bbox[3] - bbox[1]
|
| 1234 |
+
size = min(width, height) # 使用较小边作为键的尺寸
|
| 1235 |
+
bond_sizes.append(size)
|
| 1236 |
+
min_bond_size = min(bond_sizes) if bond_sizes else 10.0 # 默认值若无键
|
| 1237 |
+
if debug:print("min_bond_size ",min_bond_size, 10)
|
| 1238 |
+
new_bond_idx = len(bond_bbox)
|
| 1239 |
+
isolated_aFound=[]
|
| 1240 |
+
singleAtomBond_fixed=[]
|
| 1241 |
+
# at2b_dist=dict()
|
| 1242 |
+
|
| 1243 |
+
for iso_atom in isolated_a:
|
| 1244 |
+
iso_box = atom_bboxes[iso_atom]
|
| 1245 |
+
|
| 1246 |
+
#with SingleAtomBond first then check with other atom box, may a1a2 repeat on >=two bonds
|
| 1247 |
+
for bi,atom_idx_list in singleAtomBond.items():
|
| 1248 |
+
bond_box = bond_bbox[bi]
|
| 1249 |
+
atom1_idx = atom_idx_list[0]
|
| 1250 |
+
bond_vertices = get_corners(bond_box)
|
| 1251 |
+
# 计算 atom1_center 到 bond box 4 个顶点的距离
|
| 1252 |
+
atom1_center = atom_centers[atom1_idx]
|
| 1253 |
+
distances = [np.linalg.norm(np.array(atom1_center) - v) for v in bond_vertices]
|
| 1254 |
+
closest_indices = np.argsort(distances)[:2] # 距离最小的两个顶点
|
| 1255 |
+
connected_vertices = bond_vertices[closest_indices]
|
| 1256 |
+
unconnected_vertices = bond_vertices[[i for i in range(4) if i not in closest_indices]]
|
| 1257 |
+
# exclude_=a2neib[atom1_idx]
|
| 1258 |
+
exclude_=[atom1_idx]+a2neib[atom1_idx]#add it self
|
| 1259 |
+
print(f'exclude this atom itself :: {exclude_},and its neiboughs {a2neib[atom1_idx]}')
|
| 1260 |
+
# 找到 atom2(未连接端到所有 atom box 顶点的最小距离)
|
| 1261 |
+
# _, atom2_idx_ = get_min_distance_to_atom_box(unconnected_vertices, atom_bboxes, exclude_idx=exclude_)
|
| 1262 |
+
atom2_idx_, dist2 = find_nearest_atom(unconnected_vertices, atom_bboxes, exclude_idx=exclude_)
|
| 1263 |
+
if iso_atom == atom2_idx_:
|
| 1264 |
+
# 从 atom1 找到最近的另一个 atom (atom2_1)
|
| 1265 |
+
if atom2_idx_< atom1_idx:
|
| 1266 |
+
k=[atom2_idx_, atom1_idx]
|
| 1267 |
+
else:
|
| 1268 |
+
k=[atom1_idx, atom2_idx_]
|
| 1269 |
+
|
| 1270 |
+
if atom2_idx_ not in a2neib[atom1_idx]:
|
| 1271 |
+
b2aa[bi]=k
|
| 1272 |
+
bonds[bi][0]=k[0]
|
| 1273 |
+
bonds[bi][1]=k[1]
|
| 1274 |
+
a2b.setdefault(iso_atom, []).append(bi)
|
| 1275 |
+
|
| 1276 |
+
if debug: print(f'@@isolated_a fix the SingleAtomBond {bi} as bond:{bonds[bi]} !!')
|
| 1277 |
+
singleAtomBond_fixed.append(bi)
|
| 1278 |
+
isolated_aFound.append(atom2_idx_)
|
| 1279 |
+
|
| 1280 |
+
if len(repeate_bonds)>0:
|
| 1281 |
+
at2b_dist=dict()#NOTE the case repeate bonds with isolated atom box
|
| 1282 |
+
iso_box_vertices = get_corners(iso_box)
|
| 1283 |
+
iso_atom_center = atom_centers[iso_atom]
|
| 1284 |
+
bond_box_idx_, bond_box_dist = find_nearest_atom(iso_box_vertices, bond_bbox, exclude_idx=[])
|
| 1285 |
+
for a1a2,bis in repeate_bonds.items():#{(2, 3): [3, 4]}
|
| 1286 |
+
for bi in bis:
|
| 1287 |
+
if bi ==bond_box_idx_:
|
| 1288 |
+
bond_box = bond_bbox[bi]
|
| 1289 |
+
bond_vertices = get_corners(bond_box)
|
| 1290 |
+
a1_,a2_=a1a2
|
| 1291 |
+
a1_atombox= atom_bboxes[a1_]
|
| 1292 |
+
a2_atombox= atom_bboxes[a2_]
|
| 1293 |
+
a1_flag= boxes_overlap(a1_atombox, bond_box)
|
| 1294 |
+
a2_flag= boxes_overlap(a2_atombox, bond_box)
|
| 1295 |
+
if a1_flag:
|
| 1296 |
+
atom1_idx_=a1_
|
| 1297 |
+
dist1=0
|
| 1298 |
+
elif a2_flag:
|
| 1299 |
+
atom1_idx_=a2_
|
| 1300 |
+
dist1=0
|
| 1301 |
+
else:
|
| 1302 |
+
distances = [np.linalg.norm(np.array(iso_atom_center) - v) for v in bond_vertices]
|
| 1303 |
+
closest_indices2 = np.argsort(distances)[:2] # 距离最小的两个顶点
|
| 1304 |
+
connected_vertices2 = bond_vertices[closest_indices2]#isolated_close
|
| 1305 |
+
connected_vertices1 = bond_vertices[[i for i in range(4) if i not in closest_indices2]]
|
| 1306 |
+
atom1_idx_, dist1 = find_nearest_atom(connected_vertices1, atom_bboxes, exclude_idx=[iso_atom])
|
| 1307 |
+
if debug:print("a1_flag,a2_flag,atom1_idx_, iso_atom",[a1_flag,a2_flag,atom1_idx_,iso_atom])
|
| 1308 |
+
min_ai=min([atom1_idx_,iso_atom])
|
| 1309 |
+
max_ai=max([atom1_idx_,iso_atom])
|
| 1310 |
+
k=(min_ai,max_ai)
|
| 1311 |
+
print(k,'repeate',bi)
|
| 1312 |
+
if k not in at2b_dist:
|
| 1313 |
+
at2b_dist[k]=[bi,a1a2,dist1]
|
| 1314 |
+
else:
|
| 1315 |
+
if dist1< at2b_dist[k][1]:
|
| 1316 |
+
at2b_dist[k]=[bi,a1a2,dist1]
|
| 1317 |
+
if debug:print(f"repate bond box id: {bi} fixed with {at2b_dist}")
|
| 1318 |
+
isolated_aFound.append(iso_atom)
|
| 1319 |
+
# for k,v in at2b_dist.items():
|
| 1320 |
+
#update bond atom box mapping
|
| 1321 |
+
isolated_a2del.append(iso_atom)
|
| 1322 |
+
b2aa[bi] = [min_ai,max_ai]
|
| 1323 |
+
a2b.setdefault(iso_atom, []).append(bi)
|
| 1324 |
+
bonds[bi][0]=k[0]
|
| 1325 |
+
bonds[bi][1]=k[1]
|
| 1326 |
+
if bi in bondWithdirct:
|
| 1327 |
+
bondWithdirct[bi][0]=k[0]
|
| 1328 |
+
bondWithdirct[bi][1]=k[1]
|
| 1329 |
+
|
| 1330 |
+
isolated_a=[ ai for ai in isolated_a if ai not in isolated_aFound]#updated
|
| 1331 |
+
singleAtomBond={bi:aili for bi,aili in singleAtomBond.items() if bi not in singleAtomBond_fixed}#updated
|
| 1332 |
+
|
| 1333 |
+
for iso_atom in isolated_a:
|
| 1334 |
+
iso_box = atom_bboxes[iso_atom]
|
| 1335 |
+
#with SingleAtomBond first then chec
|
| 1336 |
+
for other_idx, other_box in enumerate(atom_bboxes):
|
| 1337 |
+
if other_idx == iso_atom\
|
| 1338 |
+
or (atom_classes[other_idx] in ['other',"*"] and atom_classes[iso_atom] in ['other',"*"]):
|
| 1339 |
+
#also not inlcude other -- *
|
| 1340 |
+
continue
|
| 1341 |
+
# 检查重叠
|
| 1342 |
+
min_ai=min([iso_atom,other_idx])
|
| 1343 |
+
max_ai=max([iso_atom,other_idx])
|
| 1344 |
+
|
| 1345 |
+
if boxes_overlap(iso_box, other_box):
|
| 1346 |
+
# 添加默认单键
|
| 1347 |
+
new_bbox = [
|
| 1348 |
+
min(iso_box[0], other_box[0]),
|
| 1349 |
+
min(iso_box[1], other_box[1]),
|
| 1350 |
+
max(iso_box[2], other_box[2]),
|
| 1351 |
+
max(iso_box[3], other_box[3])
|
| 1352 |
+
]
|
| 1353 |
+
bond_bbox.append(new_bbox)
|
| 1354 |
+
bond_classes.append('single')
|
| 1355 |
+
bond_scores.append(1.0)
|
| 1356 |
+
b2aa[new_bond_idx] = [iso_atom, other_idx]
|
| 1357 |
+
a2b.setdefault(iso_atom, []).append(new_bond_idx)
|
| 1358 |
+
a2b.setdefault(other_idx, []).append(new_bond_idx)
|
| 1359 |
+
isolated_a2del.append(iso_atom)
|
| 1360 |
+
new_bond_idx += 1
|
| 1361 |
+
bond_=[min_ai, max_ai, 'SINGLE', 1.0]
|
| 1362 |
+
last_=len(bonds)
|
| 1363 |
+
bonds[last_] = bond_
|
| 1364 |
+
|
| 1365 |
+
if debug:
|
| 1366 |
+
print(f"添加键 {new_bond_idx-1} 连接原子 {iso_atom} 和 {other_idx},as isoated box overlap with it ")
|
| 1367 |
+
# break
|
| 1368 |
+
else:
|
| 1369 |
+
# 检查角点最小距离
|
| 1370 |
+
min_dist = float('inf')
|
| 1371 |
+
closest_atom = None
|
| 1372 |
+
dist = min_corner_distance(iso_box, other_box)
|
| 1373 |
+
if dist < min_dist:
|
| 1374 |
+
min_dist = dist
|
| 1375 |
+
closest_atom = other_idx
|
| 1376 |
+
if min_dist < min_bond_size:
|
| 1377 |
+
# 添加默认单键
|
| 1378 |
+
new_bbox = [
|
| 1379 |
+
min(iso_box[0], atom_bboxes[closest_atom][0]),
|
| 1380 |
+
min(iso_box[1], atom_bboxes[closest_atom][1]),
|
| 1381 |
+
max(iso_box[2], atom_bboxes[closest_atom][2]),
|
| 1382 |
+
max(iso_box[3], atom_bboxes[closest_atom][3])
|
| 1383 |
+
]
|
| 1384 |
+
bond_bbox.append(new_bbox)
|
| 1385 |
+
bond_classes.append('single')
|
| 1386 |
+
bond_scores.append(1.0)
|
| 1387 |
+
b2aa[new_bond_idx] = [iso_atom, closest_atom]
|
| 1388 |
+
a2b.setdefault(iso_atom, []).append(new_bond_idx)
|
| 1389 |
+
a2b.setdefault(closest_atom, []).append(new_bond_idx)
|
| 1390 |
+
isolated_a2del.append(iso_atom)
|
| 1391 |
+
new_bond_idx += 1
|
| 1392 |
+
if debug:
|
| 1393 |
+
print(f"添加键 {new_bond_idx-1} 连接原子 {iso_atom} 和 {closest_atom} (距离 {min_dist})")
|
| 1394 |
+
bond_=[min_ai, max_ai, 'SINGLE', 1.0]
|
| 1395 |
+
last_=len(bonds)
|
| 1396 |
+
bonds[last_] = bond_
|
| 1397 |
+
|
| 1398 |
+
# break#as isolated may be get more than 2 bonds
|
| 1399 |
+
if debug:
|
| 1400 |
+
print('isolated_a2del and isolated_a number',len(isolated_a2del),len(isolated_a))
|
| 1401 |
+
print('isolated_a ',isolated_a)
|
| 1402 |
+
print('isolated_a2del ',isolated_a2del)
|
| 1403 |
+
|
| 1404 |
+
a2b = dict(sorted(a2b.items()))
|
| 1405 |
+
|
| 1406 |
+
# 先处理 singleAtomBond, 再removed duplicated
|
| 1407 |
+
if len(singleAtomBond) > 0:
|
| 1408 |
+
# 初始化 a2neib
|
| 1409 |
+
a2neib = {}
|
| 1410 |
+
# 遍历 a2b,构建邻居关系
|
| 1411 |
+
for atom, bns in a2b.items():
|
| 1412 |
+
neighbors = set() # 使用集合去重
|
| 1413 |
+
for bond in bns:
|
| 1414 |
+
atom_pair = b2aa[bond] # 获取 bond 连接的原子对
|
| 1415 |
+
# 如果当前原子在 atom_pair 中,添加另一个原子作为邻居
|
| 1416 |
+
nei={ai for ai in atom_pair if ai !=atom }
|
| 1417 |
+
neighbors.update(nei)
|
| 1418 |
+
# if atom in atom_pair:
|
| 1419 |
+
# other_atom = atom_pair[0] if atom == atom_pair[1] else atom_pair[1]
|
| 1420 |
+
# neighbors.add(other_atom)
|
| 1421 |
+
a2neib[atom] = sorted(list(neighbors)) # 转换为有序列表
|
| 1422 |
+
|
| 1423 |
+
# 找到所有 C 的 bbox 尺寸
|
| 1424 |
+
c_bboxes = [bbox for bbox, cls in zip(output_a['bbox'], output_a['pred_classes']) if cls == 'C']
|
| 1425 |
+
if not c_bboxes:
|
| 1426 |
+
# 如果没有C原子,使用所有bbox中最小的
|
| 1427 |
+
print("Warning: No 'C' atoms found, using smallest bbox in output_a instead.")
|
| 1428 |
+
all_bboxes = output_a['bbox']
|
| 1429 |
+
if not all_bboxes:
|
| 1430 |
+
raise ValueError("No bboxes found in output_a at all.")
|
| 1431 |
+
smallest_bbox = min(all_bboxes, key=bbox_area)
|
| 1432 |
+
c_bboxes = [smallest_bbox] # 计算最小宽度和高度
|
| 1433 |
+
min_width = min([bbox[2] - bbox[0] for bbox in c_bboxes])
|
| 1434 |
+
min_height = min([bbox[3] - bbox[1] for bbox in c_bboxes])
|
| 1435 |
+
|
| 1436 |
+
# 处理 singleAtomBond
|
| 1437 |
+
for bi, atom_idx_list in singleAtomBond.items():
|
| 1438 |
+
bond_box = bond_bbox[bi]
|
| 1439 |
+
atom1_idx = atom_idx_list[0]
|
| 1440 |
+
bond_vertices = get_corners(bond_box)
|
| 1441 |
+
# 计算 atom1_center 到 bond box 4 个顶点的距离
|
| 1442 |
+
atom1_center = atom_centers[atom1_idx]
|
| 1443 |
+
distances = [np.linalg.norm(np.array(atom1_center) - v) for v in bond_vertices]
|
| 1444 |
+
closest_indices = np.argsort(distances)[:2] # 距离最小的两个顶点
|
| 1445 |
+
connected_vertices = bond_vertices[closest_indices]
|
| 1446 |
+
unconnected_vertices = bond_vertices[[i for i in range(4) if i not in closest_indices]]
|
| 1447 |
+
# exclude_=a2neib[atom1_idx]
|
| 1448 |
+
exclude_=[atom1_idx]#add it self
|
| 1449 |
+
print(f'exclude this atom itself :: {exclude_},and its neiboughs {a2neib[atom1_idx]}')
|
| 1450 |
+
# 找到 atom2(未连接端到所有 atom box 顶点的最小距离)
|
| 1451 |
+
# _, atom2_idx_ = get_min_distance_to_atom_box(unconnected_vertices, atom_bboxes, exclude_idx=exclude_)
|
| 1452 |
+
atom2_idx_, dist2 = find_nearest_atom(unconnected_vertices, atom_bboxes, exclude_idx=exclude_)
|
| 1453 |
+
# 从 atom1 找到最近的另一个 atom (atom2_1)
|
| 1454 |
+
atom1_corners = get_corners(atom_bboxes[atom1_idx])
|
| 1455 |
+
atom2_1_idx, dist2_1 = find_nearest_atom(atom1_corners, atom_bboxes, exclude_idx=exclude_)
|
| 1456 |
+
if debug:print("atom2_idx_ , atom2_1_idx,atom1_idx:",atom2_idx_, atom2_1_idx,atom1_idx)
|
| 1457 |
+
if atom2_idx_< atom1_idx:
|
| 1458 |
+
k=[atom2_idx_, atom1_idx]
|
| 1459 |
+
else:
|
| 1460 |
+
k=[atom1_idx, atom2_idx_]
|
| 1461 |
+
|
| 1462 |
+
if atom2_idx_ == atom2_1_idx :
|
| 1463 |
+
if atom2_idx_ not in a2neib[atom1_idx]:
|
| 1464 |
+
if debug: print('add new bond with existed atom')
|
| 1465 |
+
b2aa[bi]=k
|
| 1466 |
+
bonds[bi][0]=k[0]
|
| 1467 |
+
bonds[bi][1]=k[1]
|
| 1468 |
+
else:#need insert new atom box at this bond terminal site, default with C
|
| 1469 |
+
new_center=np.mean(unconnected_vertices, axis=0)
|
| 1470 |
+
# 生成新 C 的 bbox
|
| 1471 |
+
new_bbox = [
|
| 1472 |
+
new_center[0] - min_width / 2,
|
| 1473 |
+
new_center[1] - min_height / 2,
|
| 1474 |
+
new_center[0] + min_width / 2,
|
| 1475 |
+
new_center[1] + min_height / 2
|
| 1476 |
+
]
|
| 1477 |
+
if debug: print('new atom box adding as C')
|
| 1478 |
+
atom_bboxes.append(new_bbox)
|
| 1479 |
+
atom_centers.append(new_center.tolist())
|
| 1480 |
+
atom_scores.append(bond_scores[bi]) # 使用 bond 的 score
|
| 1481 |
+
atom_classes.append('C')
|
| 1482 |
+
#updating
|
| 1483 |
+
atom2_idx_= len(atom_classes)-1
|
| 1484 |
+
k=[atom1_idx, atom2_idx_]
|
| 1485 |
+
bonds[bi][1]=atom2_idx_
|
| 1486 |
+
b2aa[bi][1]=atom2_idx_
|
| 1487 |
+
else: #atom2_idx_ != atom2_1_idx, keep atom2_idx_ from bond box privlage
|
| 1488 |
+
if atom2_idx_ not in a2neib[atom1_idx]:
|
| 1489 |
+
if debug: print(f'atom2_idx_ != atom2_1_idx| {atom2_idx_} != {atom2_1_idx} @add new bond with existed atom')
|
| 1490 |
+
b2aa[bi]=k
|
| 1491 |
+
bonds[bi][0]=k[0]
|
| 1492 |
+
bonds[bi][1]=k[1]
|
| 1493 |
+
else:#need insert new atom box at this bond terminal site, default with C
|
| 1494 |
+
new_center=np.mean(unconnected_vertices, axis=0)
|
| 1495 |
+
# 生成新 C 的 bbox
|
| 1496 |
+
new_bbox = [
|
| 1497 |
+
new_center[0] - min_width / 2,
|
| 1498 |
+
new_center[1] - min_height / 2,
|
| 1499 |
+
new_center[0] + min_width / 2,
|
| 1500 |
+
new_center[1] + min_height / 2
|
| 1501 |
+
]
|
| 1502 |
+
atom_bboxes.append(new_bbox)#updateing atom box
|
| 1503 |
+
atom_centers.append(new_center.tolist())
|
| 1504 |
+
atom_scores.append(bond_scores[bi]) # 使用 bond 的 score
|
| 1505 |
+
atom_classes.append('C')
|
| 1506 |
+
#updating
|
| 1507 |
+
atom2_idx_= len(atom_classes)-1
|
| 1508 |
+
k=[atom1_idx, atom2_idx_]
|
| 1509 |
+
bonds[bi][1]=atom2_idx_
|
| 1510 |
+
b2aa[bi][1]=atom2_idx_
|
| 1511 |
+
if debug: print(f'atom2_idx_ != atom2_1_idx@new atom box {atom2_idx_}adding as C, with bond {bi} a1a2 {k}')
|
| 1512 |
+
|
| 1513 |
+
if bi in bondWithdirct.keys():
|
| 1514 |
+
bondWithdirct[bi][0]=k[0]
|
| 1515 |
+
bondWithdirct[bi][1]=k[1]#update atom2 index
|
| 1516 |
+
|
| 1517 |
+
#TODO, fix me, this case, may need ocr.ocr first, try to dicide need isolated atom added bond times
|
| 1518 |
+
if debug:print(f"before del bonds {len(bond_bbox)}")
|
| 1519 |
+
# viewcheck_b(image_path,bond_bbox,bond_classes,color='green',figsize=(10,7))
|
| 1520 |
+
# viewcheck(image_path,atom_bboxes,color='red')
|
| 1521 |
+
#update aa2b for remove duplicated bonds
|
| 1522 |
+
aa2b=dict()
|
| 1523 |
+
for bi, aa in b2aa.items():
|
| 1524 |
+
min_ai=min(aa)
|
| 1525 |
+
max_ai=max(aa)
|
| 1526 |
+
if bond_scores[bi] is None:
|
| 1527 |
+
bond_scores[bi]=1.0
|
| 1528 |
+
score_=bond_scores[bi]
|
| 1529 |
+
bond_type=bond_classes[bi]
|
| 1530 |
+
# print([bond_type,score_])
|
| 1531 |
+
#bond_type check afte singleAtomBond
|
| 1532 |
+
if bond_type in ['single','wdge','dash', '-', 'NONE', 'ENDUPRIGHT', 'BEGINWEDGE', 'BEGINDASH', 'ENDDOWNRIGHT']:
|
| 1533 |
+
bond_ = [min_ai, max_ai, 'SINGLE', score]
|
| 1534 |
+
if bond_type in ['wdge','dash','ENDUPRIGHT', 'BEGINWEDGE', 'BEGINDASH', 'ENDDOWNRIGHT']:
|
| 1535 |
+
bondWithdirct[bi]=[min_ai, max_ai,'SINGLE', score, bond_type]
|
| 1536 |
+
elif bond_type == '=':
|
| 1537 |
+
bond_ = [min_ai, max_ai, 'DOUBLE', score]
|
| 1538 |
+
# print(bond_,"@@@@")
|
| 1539 |
+
elif bond_type == '#':
|
| 1540 |
+
bond_ = [min_ai, max_ai, 'TRIPLE', score]
|
| 1541 |
+
elif bond_type == ':':
|
| 1542 |
+
bond_ = [min_ai, max_ai, 'AROMATIC', score]
|
| 1543 |
+
else:
|
| 1544 |
+
print(f"what case here !!! with bond_type: {bond_type} || {[bi,min_ai, max_ai]}")
|
| 1545 |
+
bond_=[min_ai, max_ai, 'SINGLE', score]
|
| 1546 |
+
|
| 1547 |
+
if (min_ai, max_ai) not in aa2b.keys() or aa2b[(min_ai, max_ai)][-2]<score_:
|
| 1548 |
+
aa2b[(min_ai, max_ai)]=[bi,score_,bond_[-2]]
|
| 1549 |
+
#SINGEL Atom bond 本来是不重复的,会误认repeate and remove TODO
|
| 1550 |
+
|
| 1551 |
+
#remove duplicated bonds based on score
|
| 1552 |
+
if len(aa2b)!=len(b2aa):
|
| 1553 |
+
# 1. 去重并生成新的 bi 映射
|
| 1554 |
+
new_bi_map = {} # 格式: {old_bi: new_bi}
|
| 1555 |
+
new_bonds = {}
|
| 1556 |
+
new_aa2b = {}
|
| 1557 |
+
new_b2aa = {}
|
| 1558 |
+
new_bondWithdirct = {}
|
| 1559 |
+
new_singleAtomBond = {}
|
| 1560 |
+
# 按 aa2b 的顺序分配新 bi(保留分数高的键)
|
| 1561 |
+
for new_bi, ((min_ai, max_ai), (old_bi, score, bond_type)) in enumerate(
|
| 1562 |
+
sorted(aa2b.items(), key=lambda x: x[1][1], reverse=True) # 按分数降序排序
|
| 1563 |
+
):
|
| 1564 |
+
new_bi_map[old_bi] = new_bi
|
| 1565 |
+
new_bonds[new_bi] = [min_ai, max_ai, bond_type, score]
|
| 1566 |
+
new_aa2b[(min_ai, max_ai)] = [new_bi, score, bond_type]
|
| 1567 |
+
new_b2aa[new_bi] = [min_ai, max_ai]
|
| 1568 |
+
|
| 1569 |
+
# 2. 更新 bondWithdirct & singleAtomBond
|
| 1570 |
+
for old_bi, bond_info in bondWithdirct.items():
|
| 1571 |
+
if old_bi in new_bi_map:
|
| 1572 |
+
new_bi = new_bi_map[old_bi]
|
| 1573 |
+
new_bondWithdirct[new_bi] = bond_info
|
| 1574 |
+
|
| 1575 |
+
for old_bi, bond_info in singleAtomBond.items():
|
| 1576 |
+
if old_bi in new_bi_map:
|
| 1577 |
+
new_bi = new_bi_map[old_bi]
|
| 1578 |
+
new_singleAtomBond[new_bi] = bond_info
|
| 1579 |
+
|
| 1580 |
+
# 3. 替换旧数据结构, TODO ad bond box class scores here
|
| 1581 |
+
bonds = new_bonds
|
| 1582 |
+
aa2b = new_aa2b
|
| 1583 |
+
b2aa = new_b2aa
|
| 1584 |
+
bondWithdirct = new_bondWithdirct
|
| 1585 |
+
singleAtomBond = new_singleAtomBond
|
| 1586 |
+
if debug: print(f"去重完成: bonds={len(bonds)}, aa2b={len(aa2b)}, b2aa={len(b2aa)}, bondWithdirct={len(bondWithdirct)}")
|
| 1587 |
+
#remove duplicated bonds based on score
|
| 1588 |
+
# 4. 更新 bond_bbox, bond_scores, bond_classes
|
| 1589 |
+
old_bns=max(new_bi_map.keys())
|
| 1590 |
+
to_remove_bonds=set()
|
| 1591 |
+
for i in range(old_bns):
|
| 1592 |
+
if i not in new_bi_map.keys():
|
| 1593 |
+
to_remove_bonds.add(i)
|
| 1594 |
+
print(to_remove_bonds)
|
| 1595 |
+
# 删除被移除的 bbox
|
| 1596 |
+
bond_scores = [bond_scores[i] for i in range(len(bond_scores)) if i not in to_remove_bonds]
|
| 1597 |
+
bond_classes = [bond_classes[i] for i in range(len(bond_classes)) if i not in to_remove_bonds]
|
| 1598 |
+
bond_bbox = [bond_bbox[i] for i in range(len(bond_bbox)) if i not in to_remove_bonds]
|
| 1599 |
+
bond_center = [[ (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2 ] for bbox in bond_bbox]
|
| 1600 |
+
|
| 1601 |
+
|
| 1602 |
+
a2b=dict()
|
| 1603 |
+
isolated_a=set()
|
| 1604 |
+
for k,v in b2aa.items():
|
| 1605 |
+
# a1,a2=v
|
| 1606 |
+
for a in v:
|
| 1607 |
+
if a not in a2b.keys():
|
| 1608 |
+
a2b[a]=[k]
|
| 1609 |
+
else:
|
| 1610 |
+
a2b[a].append(k)
|
| 1611 |
+
|
| 1612 |
+
for ai, a_lab in enumerate(atom_classes):
|
| 1613 |
+
if ai not in a2b.keys():
|
| 1614 |
+
isolated_a.add(ai)
|
| 1615 |
+
a2b = dict(sorted(a2b.items()))
|
| 1616 |
+
|
| 1617 |
+
# 初始化 a2neib
|
| 1618 |
+
a2neib = {}
|
| 1619 |
+
# 遍历 a2b,构建邻居关系
|
| 1620 |
+
for atom, bns in a2b.items():
|
| 1621 |
+
neighbors = set() # 使用集合去重
|
| 1622 |
+
for bond in bns:
|
| 1623 |
+
atom_pair = b2aa[bond] # 获取 bond 连接的原子对
|
| 1624 |
+
# 如果当前原子在 atom_pair 中,添加另一个原子作为邻居
|
| 1625 |
+
nei={ai for ai in atom_pair if ai !=atom }
|
| 1626 |
+
neighbors.update(nei)
|
| 1627 |
+
# if atom in atom_pair:
|
| 1628 |
+
# other_atom = atom_pair[0] if atom == atom_pair[1] else atom_pair[1]
|
| 1629 |
+
# neighbors.add(other_atom)
|
| 1630 |
+
a2neib[atom] = sorted(list(neighbors)) # 转换为有序列表
|
| 1631 |
+
|
| 1632 |
+
debug2=False
|
| 1633 |
+
if debug2:
|
| 1634 |
+
# 输出结果
|
| 1635 |
+
print("\nBonds:")
|
| 1636 |
+
for bi, bond_info in bonds.items():
|
| 1637 |
+
print(f"Bond {bi}: {bond_info}")
|
| 1638 |
+
print("\nSingle Atom Bonds:")
|
| 1639 |
+
for bi, atom_idx in singleAtomBond.items():
|
| 1640 |
+
print(f"Bond {bi}: {atom_idx}")
|
| 1641 |
+
print("Atom to Bonds box idx maping:")
|
| 1642 |
+
for ai, bond_ids in a2b.items():
|
| 1643 |
+
print(f"a2b-id {ai}: {bond_ids}")
|
| 1644 |
+
print(f"isolated_ atom box:: {isolated_a}")
|
| 1645 |
+
print(f"b2aa::{b2aa}")
|
| 1646 |
+
# 输出结果
|
| 1647 |
+
print("a2neib:")
|
| 1648 |
+
for atom, neighbors in a2neib.items():
|
| 1649 |
+
print(f"Atom {atom}: {neighbors}")
|
| 1650 |
+
|
| 1651 |
+
other2ppsocr = True
|
| 1652 |
+
ocr_ai2lab = dict()
|
| 1653 |
+
ocr_bbs = dict()
|
| 1654 |
+
scale_crop = False
|
| 1655 |
+
ocr_ai2lab_ori=dict()
|
| 1656 |
+
ocr_ai2lab_sca=dict()
|
| 1657 |
+
|
| 1658 |
+
|
| 1659 |
+
if other2ppsocr:
|
| 1660 |
+
elements = ['S', 'N', 'P', 'C', 'O']
|
| 1661 |
+
keys = [f"{e}{suffix}" for e in elements for suffix in ['R"', "R'", "R", "*"]]
|
| 1662 |
+
replacement_map = {key: f'{key[0]}*' for key in keys}
|
| 1663 |
+
if da=='staker':
|
| 1664 |
+
_margin=2#as staker use small image 256X256
|
| 1665 |
+
else:
|
| 1666 |
+
_margin=0
|
| 1667 |
+
for i, atc in enumerate(atom_classes):
|
| 1668 |
+
if 'other' == atc: # 30 idx_lab version OH-->Cl with high
|
| 1669 |
+
# Initialize variables to store both results
|
| 1670 |
+
orig_result = None
|
| 1671 |
+
orig_score = 0
|
| 1672 |
+
scaled_result = None
|
| 1673 |
+
scaled_score = 0
|
| 1674 |
+
|
| 1675 |
+
# Process original image crop
|
| 1676 |
+
abox_orig = np.array(atom_bboxes[i]) + np.array([-_margin, -_margin,_margin, _margin])
|
| 1677 |
+
cropped_img_orig = img_ori.crop(abox_orig)
|
| 1678 |
+
image_npocr_orig = np.array(cropped_img_orig)
|
| 1679 |
+
result_ocr_orig = ocr.ocr(image_npocr_orig, det=False)
|
| 1680 |
+
|
| 1681 |
+
if result_ocr_orig:
|
| 1682 |
+
orig_text = result_ocr_orig[0][0][0]
|
| 1683 |
+
orig_score = result_ocr_orig[0][0][1]
|
| 1684 |
+
if debug: print(f'oriCrop:\t {orig_text} {orig_score}')
|
| 1685 |
+
orig_text = normalize_ocr_text(orig_text, replacement_map)
|
| 1686 |
+
ocr_ai2lab_ori[i]=[orig_text,orig_score]
|
| 1687 |
+
# Process scaled image crop
|
| 1688 |
+
abox_scaled = np.array(atom_bboxes[i]) * np.array([scale_x, scale_y, scale_x, scale_y]) + np.array([-_margin, -_margin,_margin, _margin])
|
| 1689 |
+
cropped_img_scaled = img_ori_1k.crop(abox_scaled)
|
| 1690 |
+
image_npocr_scaled = np.array(cropped_img_scaled)
|
| 1691 |
+
result_ocr_scaled = ocr.ocr(image_npocr_scaled, det=False)
|
| 1692 |
+
|
| 1693 |
+
if result_ocr_scaled:
|
| 1694 |
+
scaled_text = result_ocr_scaled[0][0][0]
|
| 1695 |
+
scaled_score = result_ocr_scaled[0][0][1]
|
| 1696 |
+
if debug: print(f'scaled:\t {scaled_text} {scaled_score}')
|
| 1697 |
+
scaled_text = normalize_ocr_text(scaled_text, replacement_map)
|
| 1698 |
+
ocr_ai2lab_sca[i]=[scaled_text,scaled_score]
|
| 1699 |
+
|
| 1700 |
+
|
| 1701 |
+
|
| 1702 |
+
final_text, final_score, final_crop = select_chem_expression(
|
| 1703 |
+
orig_text, orig_score, scaled_text, scaled_score, cropped_img_orig, cropped_img_scaled
|
| 1704 |
+
)
|
| 1705 |
+
|
| 1706 |
+
if orig_text=='NO2' or scaled_text=='NO2':
|
| 1707 |
+
final_text='NO2'#AS stm NO score >NO2
|
| 1708 |
+
elif orig_text=='SO2' or scaled_text=='SO2':
|
| 1709 |
+
final_text='SO2'#AS stm NO score >SO2
|
| 1710 |
+
# elif orig_starts_upper == scaled_starts_upper:
|
| 1711 |
+
# # If both start with uppercase or both don't, use the higher score
|
| 1712 |
+
# final_text = orig_text if orig_score >= scaled_score else scaled_text
|
| 1713 |
+
# elif orig_starts_upper != scaled_starts_upper:
|
| 1714 |
+
# # If one starts with uppercase, use that one
|
| 1715 |
+
# final_text = orig_text if orig_starts_upper else scaled_text
|
| 1716 |
+
|
| 1717 |
+
if final_text:
|
| 1718 |
+
ocr_ai2lab[i] = [final_text, final_score]
|
| 1719 |
+
ocr_bbs[i] = final_crop
|
| 1720 |
+
atom_classes[i] = final_text
|
| 1721 |
+
if debug:
|
| 1722 |
+
print("ori",ocr_ai2lab_ori)
|
| 1723 |
+
print("sca",ocr_ai2lab_sca)
|
| 1724 |
+
print(ocr_ai2lab)
|
| 1725 |
+
#TODO make chem-group recongized dataBase next works !!!
|
| 1726 |
+
|
| 1727 |
+
if len(ocr_bbs)>0:
|
| 1728 |
+
if debug:print(f'numbs of ocr {len(ocr_bbs)} crop_ images')
|
| 1729 |
+
#merge the isolated_a Ph3Br into closet atom box
|
| 1730 |
+
# 3 in isolated_a, isolated_a, isolated_aFound
|
| 1731 |
+
giveup_isolateds=dict()
|
| 1732 |
+
if len(isolated_a):#after updated isolated_a still has the isolatd item
|
| 1733 |
+
for iso_atom in isolated_a:
|
| 1734 |
+
atom1_corners = get_corners(atom_bboxes[iso_atom])
|
| 1735 |
+
atom2_1_idx, dist2_1 = find_nearest_atom(atom1_corners, atom_bboxes, exclude_idx=[iso_atom])
|
| 1736 |
+
atom1_lab=atom_classes[iso_atom]
|
| 1737 |
+
atom2_lab=atom_classes[atom2_1_idx]
|
| 1738 |
+
if atom1_lab in ['Ph3Br','Ph3Br-']:
|
| 1739 |
+
if iso_atom not in giveup_isolateds.keys():
|
| 1740 |
+
giveup_isolateds[iso_atom]=[atom1_lab]
|
| 1741 |
+
else:
|
| 1742 |
+
giveup_isolateds[iso_atom].append(atom1_lab)
|
| 1743 |
+
|
| 1744 |
+
if atom2_lab in ['P','P+']:#merge as new group
|
| 1745 |
+
atom2_lab='P+Ph3Br-'
|
| 1746 |
+
elif atom2_lab in ['N','N+']:#merge as new group
|
| 1747 |
+
atom2_lab='N+Ph3Br-'
|
| 1748 |
+
|
| 1749 |
+
atom_classes[atom2_1_idx]=atom2_lab #update bonded atom label with the merged
|
| 1750 |
+
|
| 1751 |
+
#TODO add cases that need merge OCR results with bonded atom box
|
| 1752 |
+
if debug:
|
| 1753 |
+
print(f"giveup_isolateds {giveup_isolateds}")
|
| 1754 |
+
print(len(atom_classes),len(bond_classes),'<<<<<<<<<<<')#,len(charges_classes))
|
| 1755 |
+
###########################start build mol ##########################
|
| 1756 |
+
rwmol_ = Chem.RWMol()
|
| 1757 |
+
boxi2ai = {} # 预测索引 -> RDKit 索引
|
| 1758 |
+
placeholder_atoms=dict()
|
| 1759 |
+
# print(len(atom_classes),len(bond_classes))#,len(charges_classes))
|
| 1760 |
+
#assign atom
|
| 1761 |
+
J=0
|
| 1762 |
+
for i, (bbox, a) in enumerate(zip(atom_bboxes, atom_classes)):
|
| 1763 |
+
a2labl=False
|
| 1764 |
+
a=replace_cg_notation(a)
|
| 1765 |
+
# print(a,'atom box class label')
|
| 1766 |
+
if a in ['H', 'C', 'O', 'N', 'Cl', 'Br', 'S', 'F', 'B', 'I', 'P', 'Si']:# '*', I2M's defined atom types
|
| 1767 |
+
# if a=='H':continue#skip H fristly,only with heavy atom then addH
|
| 1768 |
+
ad = Chem.Atom(a)#TODO consider non chemical group and label for using
|
| 1769 |
+
#TODO add pd rdkit known elemetns here
|
| 1770 |
+
elif a in ELEMENTS:
|
| 1771 |
+
ad = Chem.Atom(a)
|
| 1772 |
+
|
| 1773 |
+
elif a in ABBREVIATIONS :
|
| 1774 |
+
ad = Chem.Atom("*")
|
| 1775 |
+
placeholder_atoms[i] = a # 记录非标准原但有定义的官能团 类型及其位置,
|
| 1776 |
+
a2labl=True
|
| 1777 |
+
else:
|
| 1778 |
+
if N_C_H_expand(a):
|
| 1779 |
+
ad = Chem.Atom("*")
|
| 1780 |
+
placeholder_atoms[i] = a # 记录非标准原但有定义的官能团 类型及其位置,
|
| 1781 |
+
a2labl=True
|
| 1782 |
+
elif C_H_expand(a):
|
| 1783 |
+
ad = Chem.Atom("*")
|
| 1784 |
+
placeholder_atoms[i] = a # 记录非标准原但有定义的官能团 类型及其位置,
|
| 1785 |
+
a2labl=True
|
| 1786 |
+
elif C_H_expand2(a):
|
| 1787 |
+
ad = Chem.Atom("*")
|
| 1788 |
+
placeholder_atoms[i] = a # 记录非标准原但有定义的官能团 类型及其位置,
|
| 1789 |
+
a2labl=True
|
| 1790 |
+
|
| 1791 |
+
elif formula_regex(a):
|
| 1792 |
+
ad = Chem.Atom("*")
|
| 1793 |
+
placeholder_atoms[i] = a # 记录非标准原但有定义的官能团 类型及其位置,
|
| 1794 |
+
a2labl=True
|
| 1795 |
+
else:
|
| 1796 |
+
ad = Chem.Atom("*")
|
| 1797 |
+
if a not in ['*',"other"]:
|
| 1798 |
+
a2labl=True
|
| 1799 |
+
# placeholder_atoms[idx] = a
|
| 1800 |
+
# atom = Chem.Atom(symbol)
|
| 1801 |
+
rwmol_.AddAtom(ad)
|
| 1802 |
+
boxi2ai[J] = rwmol_.GetNumAtoms() - 1
|
| 1803 |
+
if a2labl: rwmol_.GetAtomWithIdx(J).SetProp("atomLabel", f"{a}")#mol set with label, mol_rebuild not
|
| 1804 |
+
J+=1
|
| 1805 |
+
|
| 1806 |
+
|
| 1807 |
+
charges_classes= output_c['pred_classes']
|
| 1808 |
+
charges_centers= output_c['bbox_centers']
|
| 1809 |
+
charges_scores= output_c['scores']
|
| 1810 |
+
charges_bbox= output_c['bbox']
|
| 1811 |
+
a2c=dict()
|
| 1812 |
+
c2a=dict()
|
| 1813 |
+
|
| 1814 |
+
# #assign charge
|
| 1815 |
+
if len(charges_classes) > 0:
|
| 1816 |
+
kdt = cKDTree(atom_centers)
|
| 1817 |
+
c2a = {} # 电荷索引到原子索引的映射
|
| 1818 |
+
used_atoms = set() # 跟踪已分配电荷的原子
|
| 1819 |
+
for i, charge_box in enumerate(charges_bbox):
|
| 1820 |
+
charge_value = parse_charge(charges_classes[i])
|
| 1821 |
+
overlapped_atoms = []
|
| 1822 |
+
# 检查重叠
|
| 1823 |
+
for ai, atom_box in enumerate(atom_bboxes):
|
| 1824 |
+
if boxes_overlap(charge_box, atom_box):
|
| 1825 |
+
overlapped_atoms.append(ai)
|
| 1826 |
+
if overlapped_atoms:
|
| 1827 |
+
# 如果有重叠,选择第一个未使用的原子(假设一个电荷只分配一个原子)
|
| 1828 |
+
for ai in overlapped_atoms:
|
| 1829 |
+
if ai not in used_atoms:
|
| 1830 |
+
c2a[i] = ai
|
| 1831 |
+
used_atoms.add(ai)
|
| 1832 |
+
break
|
| 1833 |
+
else:
|
| 1834 |
+
# 不重叠时,使用角点距离和 KDTree 验证
|
| 1835 |
+
x, y = charges_centers[i]
|
| 1836 |
+
dist_kdt, ai_kdt = kdt.query([x, y], k=1)
|
| 1837 |
+
# 计算角点距离最近的原子
|
| 1838 |
+
min_dist = float('inf')
|
| 1839 |
+
ai_corner = None
|
| 1840 |
+
for ai, atom_box in enumerate(atom_bboxes):
|
| 1841 |
+
dist = min_corner_distance(charge_box, atom_box)
|
| 1842 |
+
if dist < min_dist:
|
| 1843 |
+
min_dist = dist
|
| 1844 |
+
ai_corner = ai
|
| 1845 |
+
# 比较 KDTree 和角点距离结果
|
| 1846 |
+
if ai_kdt == ai_corner and ai_kdt not in used_atoms:
|
| 1847 |
+
c2a[i] = ai_kdt
|
| 1848 |
+
used_atoms.add(ai_kdt)
|
| 1849 |
+
else:
|
| 1850 |
+
# 检查电荷值和原子类型
|
| 1851 |
+
if charge_value != 0:
|
| 1852 |
+
symbol_kdt =atom_classes[ai_kdt]
|
| 1853 |
+
symbol_corner =atom_classes[ai_corner]
|
| 1854 |
+
# 如果电荷值不为零,分配给非C的原子,如果都是非C, 则根据kdt k=1来分配电荷
|
| 1855 |
+
if symbol_kdt == 'C' and symbol_corner != 'C' and ai_corner not in used_atoms:
|
| 1856 |
+
# KDTree 是碳,角点不是碳,优先分配给角点原子
|
| 1857 |
+
c2a[i] = ai_corner
|
| 1858 |
+
used_atoms.add(ai_corner)
|
| 1859 |
+
elif symbol_corner == 'C' and symbol_kdt != 'C' and ai_kdt not in used_atoms:
|
| 1860 |
+
# 角点是碳,KDTree 不是碳,优先分配给 KDTree 原子
|
| 1861 |
+
c2a[i] = ai_kdt
|
| 1862 |
+
used_atoms.add(ai_kdt)
|
| 1863 |
+
else:
|
| 1864 |
+
# 两个都是非碳,或两个都是碳,默认使用 KDTree 结果
|
| 1865 |
+
if ai_kdt not in used_atoms:
|
| 1866 |
+
c2a[i] = ai_kdt
|
| 1867 |
+
used_atoms.add(ai_kdt)
|
| 1868 |
+
elif ai_corner not in used_atoms:
|
| 1869 |
+
# 如果 KDTree 结果已使用,尝试角点结果
|
| 1870 |
+
c2a[i] = ai_corner
|
| 1871 |
+
used_atoms.add(ai_corner)
|
| 1872 |
+
|
| 1873 |
+
#assign charge
|
| 1874 |
+
a2c={v:k for k,v in c2a.items()}
|
| 1875 |
+
for k,v in a2c.items():
|
| 1876 |
+
fc=int(charges_classes[v])
|
| 1877 |
+
rwmol_.GetAtomWithIdx(k).SetFormalCharge(fc)
|
| 1878 |
+
# if k in placeholder_atoms:
|
| 1879 |
+
if atom_classes[k] in ['COO','CO2']:#TODO add more charge if need
|
| 1880 |
+
if fc==-1:
|
| 1881 |
+
atom_classes[k]=f"{atom_classes[k]}-"
|
| 1882 |
+
placeholder_atoms[k]=atom_classes[k]
|
| 1883 |
+
atom = rwmol_.GetAtomWithIdx(k)
|
| 1884 |
+
atom.SetProp("atomLabel",placeholder_atoms[k])
|
| 1885 |
+
elif fc==1:
|
| 1886 |
+
atom_classes[k]=f"{atom_classes[k]}+"
|
| 1887 |
+
placeholder_atoms[k]=atom_classes[k]
|
| 1888 |
+
atom = rwmol_.GetAtomWithIdx(k)
|
| 1889 |
+
atom.SetProp("atomLabel",placeholder_atoms[k])
|
| 1890 |
+
else:
|
| 1891 |
+
print(f"charge adding {fc} @ {atom_classes[v]}")
|
| 1892 |
+
print(f'placeholder_atoms {placeholder_atoms}')
|
| 1893 |
+
#add bonds
|
| 1894 |
+
for bi, bond in bonds.items():
|
| 1895 |
+
atom1_idx, atom2_idx, bond_type, score = bond
|
| 1896 |
+
if atom1_idx ==atom2_idx:print(f"self bond should be avoid or del on previous process!!")
|
| 1897 |
+
# print(f"Adding bond between atoms {atom1_idx} and {atom2_idx} of type {bond_type}")
|
| 1898 |
+
if bond_type == 'SINGLE':
|
| 1899 |
+
rwmol_.AddBond(atom1_idx, atom2_idx, Chem.BondType.SINGLE)
|
| 1900 |
+
elif bond_type == 'DOUBLE':
|
| 1901 |
+
rwmol_.AddBond(atom1_idx, atom2_idx, Chem.BondType.DOUBLE)
|
| 1902 |
+
elif bond_type == 'TRIPLE':
|
| 1903 |
+
rwmol_.AddBond(atom1_idx, atom2_idx, Chem.BondType.TRIPLE)
|
| 1904 |
+
elif bond_type == 'AROMATIC':
|
| 1905 |
+
rwmol_.AddBond(atom1_idx, atom2_idx, Chem.BondType.AROMATIC)
|
| 1906 |
+
else:
|
| 1907 |
+
print(f"Unknown bond type: {bond_type}")
|
| 1908 |
+
|
| 1909 |
+
if debug: print(f"all a2b b2a a2c c2a done, start mol built done")
|
| 1910 |
+
#set direction
|
| 1911 |
+
if len(bondWithdirct)>0:
|
| 1912 |
+
print(f"set bond direction for mollecule ")
|
| 1913 |
+
# rwmol_=set_bondDriection(rwmol_,bondWithdirct)
|
| 1914 |
+
|
| 1915 |
+
skeleton_smi = Chem.MolToSmiles(rwmol_) #TODO WEB_dev, use this rwmol_ for display without expand the R groups
|
| 1916 |
+
#ASSIGN COORDS
|
| 1917 |
+
coords = [(x,-y,0) for x,y in atom_centers]
|
| 1918 |
+
coords = tuple(coords)
|
| 1919 |
+
coords = tuple(tuple(num / 100 for num in sub_tuple) for sub_tuple in coords)
|
| 1920 |
+
|
| 1921 |
+
mol2D = rwmol_.GetMol()
|
| 1922 |
+
mol2D.RemoveAllConformers()
|
| 1923 |
+
conf = Chem.Conformer(mol2D.GetNumAtoms())
|
| 1924 |
+
conf.Set3D(True)
|
| 1925 |
+
for i, (x, y, z) in enumerate(coords):
|
| 1926 |
+
conf.SetAtomPosition(i, (x, y, z))
|
| 1927 |
+
mol2D.AddConformer(conf)
|
| 1928 |
+
try:
|
| 1929 |
+
Chem.SanitizeMol(mol2D)
|
| 1930 |
+
Chem.AssignStereochemistryFrom3D(mol2D)
|
| 1931 |
+
mol_rebuit2d=Chem.RWMol(mol2D)
|
| 1932 |
+
except Exception as e:
|
| 1933 |
+
print(e)
|
| 1934 |
+
print('before expanding!!! try to sanizemol and assign stereo')
|
| 1935 |
+
mol_rebuit2d=Chem.RWMol(rwmol_)
|
| 1936 |
+
if len(giveup_isolateds)>0:
|
| 1937 |
+
#clean with remove giveup_isolateds
|
| 1938 |
+
# 1. 先为每个原子设置一个“old_index”属性
|
| 1939 |
+
for atom in mol_rebuit2d.GetAtoms():
|
| 1940 |
+
atom.SetProp('old_index', str(atom.GetIdx()))
|
| 1941 |
+
|
| 1942 |
+
# 2. 删除原子时建议按照降序删除,避免索引变化带来的问题
|
| 1943 |
+
for ai in sorted(giveup_isolateds.keys(), reverse=True):
|
| 1944 |
+
mol_rebuit2d.RemoveAtom(ai)
|
| 1945 |
+
print(f"atom {ai} label {giveup_isolateds[ai]} removed")
|
| 1946 |
+
|
| 1947 |
+
# 3. 删除操作完成后,构建老索引到新索引的映射
|
| 1948 |
+
old_to_new = {}
|
| 1949 |
+
for atom in mol_rebuit2d.GetAtoms():
|
| 1950 |
+
old_idx = int(atom.GetProp('old_index'))
|
| 1951 |
+
new_idx = atom.GetIdx()
|
| 1952 |
+
old_to_new[old_idx] = new_idx
|
| 1953 |
+
|
| 1954 |
+
if len(placeholder_atoms)>0:#update placeholder_atoms
|
| 1955 |
+
placeholder_atoms2=dict()
|
| 1956 |
+
for k,v in placeholder_atoms.items():
|
| 1957 |
+
placeholder_atoms2[old_to_new[k]]=v
|
| 1958 |
+
|
| 1959 |
+
placeholder_atoms=placeholder_atoms2
|
| 1960 |
+
try:
|
| 1961 |
+
SMILESpre = Chem.MolToSmiles(mol_rebuit2d)
|
| 1962 |
+
except Exception as e:
|
| 1963 |
+
print(f"Error during SMILES generation: {e}")
|
| 1964 |
+
SMILESpre = Chem.MolToSmiles(mol_rebuit2d, canonical=False)
|
| 1965 |
+
|
| 1966 |
+
|
| 1967 |
+
if len(placeholder_atoms)>0:
|
| 1968 |
+
mol_expan=copy.deepcopy(mol_rebuit2d)
|
| 1969 |
+
if debug: print(f'MOL will be expanded with {placeholder_atoms} !!')
|
| 1970 |
+
wdbs=[]
|
| 1971 |
+
bond_dirs_rev={v:k for k,v in bond_dirs.items()}
|
| 1972 |
+
|
| 1973 |
+
for b in mol_expan.GetBonds():
|
| 1974 |
+
bd=b.GetBondDir()
|
| 1975 |
+
bt=b.GetBondType()
|
| 1976 |
+
# print(bd)
|
| 1977 |
+
if bd ==bond_dirs['BEGINDASH'] or bd==bond_dirs['BEGINWEDGE']:
|
| 1978 |
+
a1, a2 = b.GetBeginAtomIdx(), b.GetEndAtomIdx()
|
| 1979 |
+
wdbs.append([a1,a2,bt,bond_dirs_rev[bd]])
|
| 1980 |
+
|
| 1981 |
+
expandStero_smi1,molexp= molExpanding(mol_expan,placeholder_atoms,wdbs,bond_dirs)#TODO fix me whe n multi strings on a atom will missing this ocr infors
|
| 1982 |
+
molexp=remove_bond_directions_if_no_chiral(molexp)
|
| 1983 |
+
try:
|
| 1984 |
+
Chem.SanitizeMol(molexp)
|
| 1985 |
+
expandStero_smi=Chem.MolToSmiles(molexp)
|
| 1986 |
+
except Exception as e:
|
| 1987 |
+
print(f"Error during sanitization: {e}")
|
| 1988 |
+
expandStero_smi = expandStero_smi1
|
| 1989 |
+
|
| 1990 |
+
expandStero_smi=remove_SP(expandStero_smi)
|
| 1991 |
+
|
| 1992 |
+
else:
|
| 1993 |
+
molexp=mol_rebuit2d
|
| 1994 |
+
expandStero_smi=SMILESpre #save into csv files,
|
| 1995 |
+
|
| 1996 |
+
#TODO WEB_dev, now can display mol with expanded abbev from molexp
|
| 1997 |
+
new_row = {'file_name':image_path, "SMILESori":SMILESori,
|
| 1998 |
+
'SMILESpre':SMILESpre,
|
| 1999 |
+
'SMILESexp':expandStero_smi,
|
| 2000 |
+
}
|
| 2001 |
+
|
| 2002 |
+
# smiles_data = smiles_data._append(new_row, ignore_index=True)#TODO WEB_dev task done here, we can save predicted Rdkit Obj or smiles or display on web
|
| 2003 |
+
print(f"final prediction:\n {expandStero_smi}")
|
| 2004 |
+
|
| 2005 |
+
return expandStero_smi
|
| 2006 |
+
|
| 2007 |
+
main()
|
| 2008 |
+
|
| 2009 |
+
# 安全释放资源
|
| 2010 |
+
# def release_ocr(ocr_instance):
|
| 2011 |
+
# # 关闭所有相关模型
|
| 2012 |
+
# if hasattr(ocr_instance, 'detector'):
|
| 2013 |
+
# ocr_instance.detector = None
|
| 2014 |
+
# if hasattr(ocr_instance, 'recognizer'):
|
| 2015 |
+
# ocr_instance.recognizer = None
|
| 2016 |
+
# if hasattr(ocr_instance, 'cls'):
|
| 2017 |
+
# ocr_instance.cls = None
|
| 2018 |
+
|
| 2019 |
+
# # 调用释放函数
|
| 2020 |
+
# release_ocr(ocr)
|
| 2021 |
+
# del ocr
|
| 2022 |
+
# release_ocr(ocr2)
|
| 2023 |
+
# del ocr2
|
| 2024 |
+
|
| 2025 |
+
|
app.py
CHANGED
|
@@ -1,7 +1,63 @@
|
|
| 1 |
-
from fastapi import FastAPI
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import FastAPI, File, UploadFile
|
| 2 |
+
from fastapi.responses import JSONResponse
|
| 3 |
+
import gradio as gr
|
| 4 |
+
import os
|
| 5 |
+
from ONNX0630 import main as predict_smiles
|
| 6 |
+
from PIL import Image
|
| 7 |
+
import io
|
| 8 |
+
|
| 9 |
+
# Initialize the FastAPI application; the REST endpoint lives at /predict and
# the Gradio UI is mounted onto this same app further below.
app = FastAPI(title="Chemical Structure to SMILES API")
|
| 11 |
+
|
| 12 |
+
# API endpoint to predict SMILES from an image
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Predict the SMILES string for an uploaded chemical-structure image.

    The upload is decoded with PIL and written to a temporary file because
    the ONNX pipeline expects an image path on disk. The temporary file is
    always removed, even when prediction fails (the original leaked it on
    any exception).

    Returns:
        JSONResponse with {"smiles": ...} on success, or
        {"error": ...} with HTTP status 500 on any failure.
    """
    temp_path = None
    try:
        # Read and decode the uploaded image
        contents = await file.read()
        image = Image.open(io.BytesIO(contents))

        # basename() guards against path traversal via a crafted filename
        temp_path = f"temp_{os.path.basename(file.filename or 'upload.png')}"
        image.save(temp_path)

        # Call the model function
        smiles = predict_smiles(temp_path)
        return JSONResponse(content={"smiles": smiles})
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)
    finally:
        # Clean up the temporary file on both success and failure paths
        if temp_path and os.path.exists(temp_path):
            os.remove(temp_path)
|
| 31 |
+
|
| 32 |
+
# Gradio interface callback
def gradio_predict(image):
    """Convert a PIL image of a chemical structure to a SMILES string.

    A unique temporary file is used so concurrent Gradio sessions cannot
    clobber each other's input (the original used a fixed "temp_image.png"
    for every request), and the file is removed even when prediction raises.

    Args:
        image: PIL.Image provided by the Gradio image widget.

    Returns:
        The predicted SMILES string, or an "Error: ..." message string.
    """
    import tempfile  # local import: only this handler needs it

    temp_path = None
    try:
        # Unique path per call -> safe under concurrent requests
        fd, temp_path = tempfile.mkstemp(suffix=".png")
        os.close(fd)
        image.save(temp_path)

        # Call the model function
        return predict_smiles(temp_path)
    except Exception as e:
        return f"Error: {str(e)}"
    finally:
        # Clean up on both success and failure paths
        if temp_path and os.path.exists(temp_path):
            os.remove(temp_path)
|
| 48 |
+
|
| 49 |
+
# Define the Gradio UI: single image input -> SMILES text output
iface = gr.Interface(
    fn=gradio_predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Textbox(),
    title="Chemical Structure to SMILES Converter",
    description="Upload an image of a chemical structure to get its SMILES string."
)

# Mount the Gradio UI at "/" on the FastAPI app; the REST endpoint stays at /predict
app = gr.mount_gradio_app(app, iface, path="/")

if __name__ == "__main__":
    import uvicorn
    # Port 7860 is the Hugging Face Spaces convention
    uvicorn.run(app, host="0.0.0.0", port=7860)
|
det_engine.py
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
utils.py
ADDED
|
@@ -0,0 +1,712 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import json
|
| 3 |
+
import math
|
| 4 |
+
import numpy as np
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import torch
|
| 7 |
+
from scipy.spatial import cKDTree
|
| 8 |
+
from rdkit import Chem
|
| 9 |
+
from rdkit.Chem import RWMol
|
| 10 |
+
from rdkit.Chem import Draw, AllChem
|
| 11 |
+
from rdkit.Chem import rdDepictor
|
| 12 |
+
import matplotlib.pyplot as plt
|
| 13 |
+
import re
|
| 14 |
+
##################### MolScribe####################################################################################
|
| 15 |
+
from typing import List
|
| 16 |
+
import matplotlib.pyplot as plt
|
| 17 |
+
from matplotlib.patches import Rectangle, Circle
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# Matplotlib single-letter color codes mapped to "R,G,B" strings
# (components in the 0-1 range, comma-separated).
COLORS = {
    u'c': '0.0,0.75,0.75', u'b': '0.0,0.0,1.0', u'g': '0.0,0.5,0.0', u'y': '0.75,0.75,0',
    u'k': '0.0,0.0,0.0', u'r': '1.0,0.0,0.0', u'm': '0.75,0,0.75'
}
|
| 24 |
+
|
| 25 |
+
# helper function
def view_box_center(bond_bbox, heavy_centers):
    """Debug helper: draw bond bounding boxes (blue rectangles) and heavy-atom
    centers (red circles) on one matplotlib axis.

    Expects numpy arrays: bond_bbox of shape (N, 4) as (x1, y1, x2, y2) and
    heavy_centers of shape (M, 2) as (x, y).
    """
    fig, ax = plt.subplots(figsize=(10, 10))

    # One rectangle per bond bounding box
    for (x1, y1, x2, y2) in bond_bbox:
        ax.add_patch(
            Rectangle((x1, y1), x2 - x1, y2 - y1,
                      linewidth=1, edgecolor='blue', facecolor='none'))

    # One circle per heavy-atom center
    for (cx, cy) in heavy_centers:
        ax.add_patch(
            Circle((cx, cy), radius=5, edgecolor='red', facecolor='none', linewidth=1))

    # Axis limits padded by 10 px around the union of all boxes and centers
    lo_x = min(bond_bbox[:, 0].min(), heavy_centers[:, 0].min()) - 10
    hi_x = max(bond_bbox[:, 2].max(), heavy_centers[:, 0].max()) + 10
    lo_y = min(bond_bbox[:, 1].min(), heavy_centers[:, 1].min()) - 10
    hi_y = max(bond_bbox[:, 3].max(), heavy_centers[:, 1].max()) + 10
    ax.set_xlim(lo_x, hi_x)
    ax.set_ylim(lo_y, hi_y)

    ax.set_title("Boxes and Centers")
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    plt.gca().set_aspect('equal', adjustable='box')  # keep a 1:1 aspect ratio
    plt.grid(True, linestyle='--', alpha=0.7)
|
| 57 |
+
|
| 58 |
+
def molIDX(mol):
    """Write each atom's index into its atom-map number so the indices show
    up when the molecule is drawn. Returns the same mol, modified in place."""
    for idx, atom in enumerate(mol.GetAtoms()):
        atom.SetAtomMapNum(idx)
    return mol
|
| 63 |
+
|
| 64 |
+
def molIDX_del(mol):
    """Clear every atom's atom-map number (set to 0), undoing molIDX.

    Returns the same mol, modified in place.

    Fix: removed the per-atom debug ``print(i)`` left in the loop — its twin
    molIDX has the same print commented out, and the output was pure noise.
    """
    for atom in mol.GetAtoms():
        # 0 means "no atom map" in RDKit, i.e. the index annotation is removed
        atom.SetAtomMapNum(0)
    return mol
|
| 69 |
+
from det_engine import ABBREVIATIONS
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def Val_extract_atom_info(error_message):
    """Extract (atom_idx, atom_symbol, valence) from an RDKit explicit-valence
    error message, e.g.:

        "Explicit valence for atom # 3 N, 4, is greater than permitted"

    Args:
        error_message: the exception object or its string representation.

    Returns:
        (atom_idx, atom_symbol, valence); valence is None when the message
        variant does not include one.

    Raises:
        ValueError: if the message matches neither known format.

    Fix: the original used ``(\\w)`` which matches a single character, so
    two-letter elements (Cl, Br, Si, ...) never matched and raised ValueError;
    ``([A-Za-z]+)`` handles them and remains backward compatible.
    """
    pattern = r"Explicit valence for atom # (\d+) ([A-Za-z]+), (\d+)"
    pattern2 = r"Explicit valence for atom # (\d+) ([A-Za-z]+) "
    if not isinstance(error_message, str):
        error_message = str(error_message)

    match = re.search(pattern, error_message)
    if match:
        atomid = int(match.group(1))   # atom index
        atomType = match.group(2)      # element symbol
        valence = int(match.group(3))  # offending valence
        return atomid, atomType, valence

    match2 = re.search(pattern2, error_message)
    if match2:
        # Message variant without an explicit valence value
        atomid = int(match2.group(1))
        atomType = match2.group(2)
        return atomid, atomType, None

    raise ValueError("无法从错误信息中提取原子信息")
|
| 101 |
+
|
| 102 |
+
def calculate_charge_adjustment(atom_symbol, current_valence):
    """Return how many charges must be added to make the atom's valence legal.

    The observed valence is compared against the maximum valence allowed for
    the element (looked up in the module-level VALENCES table). A positive
    return value is the excess over the maximum; 0 means no adjustment.

    Args:
        atom_symbol: element symbol, e.g. "C".
        current_valence: observed valence, or None to treat it as the maximum.

    Raises:
        ValueError: if atom_symbol is not present in VALENCES.
    """
    if atom_symbol not in VALENCES:
        raise ValueError(f"未知的原子符号: {atom_symbol}")

    max_valence = max(VALENCES[atom_symbol])
    # None means "unknown": assume exactly the maximum, i.e. no adjustment
    if current_valence is None:
        current_valence = max_valence

    excess = current_valence - max_valence
    return excess if excess > 0 else 0
|
| 124 |
+
|
| 125 |
+
from rdkit.Chem import rdchem, RWMol, CombineMols
|
| 126 |
+
|
| 127 |
+
def expandABB(mol,ABBREVIATIONS, placeholder_atoms):
    # Expand abbreviation placeholder atoms into their full functional groups.
    #
    # mol               : RDKit mol containing placeholder atoms for abbreviations
    # ABBREVIATIONS     : mapping group-name -> entry exposing a .smiles attribute
    # placeholder_atoms : {atom_idx: group_name} for each placeholder in mol
    #
    # Returns (expanded_mol, expanded_smiles).
    mols = [mol]
    # Replace each placeholder and merge in its functional group.
    # Iterate placeholder indices in DESCENDING order so removing an atom
    # never shifts the index of a placeholder that is still to be processed.
    for idx in sorted(placeholder_atoms.keys(), reverse=True):
        group = placeholder_atoms[idx]  # abbreviation name for this placeholder
        submol = Chem.MolFromSmiles(ABBREVIATIONS[group].smiles)  # functional-group fragment
        submol_rw = RWMol(submol)  # editable copy of the fragment
        # Attachment point: first atom of the fragment, as defined in ABBREVIATIONS
        anchor_atom_idx = 0
        # 1. Work on an editable copy of the current molecule
        new_mol = RWMol(mol)
        # 2. Index of the placeholder atom inside new_mol
        placeholder_idx = idx
        # 3. Remember the placeholder's neighbours so they can be reconnected later
        neighbors = [nb.GetIdx() for nb in new_mol.GetAtomWithIdx(placeholder_idx).GetNeighbors()]
        # 4. Break every bond that touches the placeholder
        bonds_to_remove = []  # collected first: cannot mutate while iterating GetBonds()
        for bond in new_mol.GetBonds():
            if bond.GetBeginAtomIdx() == placeholder_idx or bond.GetEndAtomIdx() == placeholder_idx:
                bonds_to_remove.append((bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()))
        for bond in bonds_to_remove:
            new_mol.RemoveBond(bond[0], bond[1])
        # 5. Delete the placeholder atom itself
        new_mol.RemoveAtom(placeholder_idx)
        # 6. Re-map neighbour indices (indices above the removed atom shift down by one)
        new_neighbors = []
        for neighbor in neighbors:
            if neighbor < placeholder_idx:
                new_neighbors.append(neighbor)
            else:
                new_neighbors.append(neighbor - 1)
        # 7. Merge the fragment into the molecule (fragment atoms are appended at the end)
        new_mol = RWMol(CombineMols(new_mol, submol_rw))

        # 8. Position of the fragment's anchor atom after the merge
        new_anchor_idx = new_mol.GetNumAtoms() - len(submol_rw.GetAtoms()) + anchor_atom_idx

        # 9. Reconnect the fragment anchor to the placeholder's former neighbours
        for neighbor in new_neighbors:
            new_mol.AddBond(neighbor, new_anchor_idx, Chem.BondType.SINGLE)
            a1=new_mol.GetAtomWithIdx(neighbor)
            a2=new_mol.GetAtomWithIdx(new_anchor_idx)
            a1.SetNumRadicalElectrons(0)
            a2.SetNumRadicalElectrons(0)## clear radicals: the newly added bond satisfies the valence
        # 10. Adopt the merged molecule for the next iteration
        mol = new_mol
        mols.append(mol)
    # NOTE(review): a global radical clean-up pass over all atoms was left
    # disabled in the original because the source image may legitimately
    # contain radicals.
    Chem.SanitizeMol(mols[-1])
    # Emit the SMILES of the fully expanded molecule
    modified_smiles = Chem.MolToSmiles(mols[-1])
    return mols[-1], modified_smiles
|
| 190 |
+
|
| 191 |
+
################################################################################################################################################################
|
| 192 |
+
def output_to_smiles(output, idx_to_labels, bond_labels, result):
    """Convert raw detector output (boxes/labels/scores tensors) into a
    molecular graph and SMILES.

    Placeholder atoms are left as '*' (abbreviations are NOT expanded here).

    Returns:
        ([atoms_list, bonds_list, charge], smiles, mol, output) where the
        last element is the detections repackaged as a dict of numpy arrays.
    """
    boxes = output["boxes"]
    # Box centers: midpoints of (x1, x2) and (y1, y2)
    centers = torch.stack(((boxes[:, 0] + boxes[:, 2]) / 2,
                           (boxes[:, 1] + boxes[:, 3]) / 2), dim=1)

    output = {
        'bbox': boxes.to("cpu").numpy(),
        'bbox_centers': centers.to("cpu").numpy(),
        'scores': output["scores"].to("cpu").numpy(),
        'pred_classes': output["labels"].to("cpu").numpy(),
    }

    atoms_list, bonds_list, charge = bbox_to_graph_with_charge(
        output,
        idx_to_labels=idx_to_labels,
        bond_labels=bond_labels,
        result=result)
    smiles, mol = mol_from_graph_with_chiral(atoms_list, bonds_list, charge)
    abc = [atoms_list, bonds_list, charge]

    # Diagnostic only: downstream callers receive the None values unchanged
    if smiles is None or atoms_list is None:
        print(f"get atoms_list problems")

    return abc, smiles, mol, output
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def output_to_smiles2(output, idx_to_labels, bond_labels, result):
    """Identical to output_to_smiles; kept for backward compatibility.

    The original body was a byte-for-byte duplicate of output_to_smiles, so
    this now delegates — any fix only needs to land in one place.

    Returns:
        ([atoms_list, bonds_list, charge], smiles, mol, output) — see
        output_to_smiles.
    """
    return output_to_smiles(output, idx_to_labels, bond_labels, result)
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
def bbox_to_graph(output, idx_to_labels, bond_labels,result):
    """Turn detector boxes into an atom table and a bond list.

    Args:
        output: dict with 'bbox', 'bbox_centers', 'scores', 'pred_classes'
            (numpy arrays, one row per detection).
        idx_to_labels: class index -> label string (atom symbols and bond types).
        bond_labels: collection of class indices that are bonds, not atoms.
        result: OCR results; unused in this variant (the 'other'-substitution
            code that consumed it was disabled).

    Returns:
        (atoms_list, bonds_list): atoms_list is a DataFrame with columns
        atom/x/y; bonds_list holds tuples
        (begin_idx, end_idx, bond_type, bond_type, score).
    """
    # calculate atoms mask (pred classes that are atoms rather than bonds)
    atoms_mask = np.array([True if ins not in bond_labels else False for ins in output['pred_classes']])

    # get atom label list
    atoms_list = [idx_to_labels[a] for a in output['pred_classes'][atoms_mask]]

    atoms_list = pd.DataFrame({'atom': atoms_list,
                               'x': output['bbox_centers'][atoms_mask, 0],
                               'y': output['bbox_centers'][atoms_mask, 1]})

    # In case an atom with a charge sign gets detected twice, keep only the signed one
    for idx, row in atoms_list.iterrows():
        if row.atom[-1] != '0':
            if row.atom[-2] != '-':#assume charge value -9~9
                overlapping = atoms_list[atoms_list.atom.str.startswith(row.atom[:-1])]
            else:
                overlapping = atoms_list[atoms_list.atom.str.startswith(row.atom[:-2])]

            # Drop the unsigned duplicate when it sits within 7 px of this atom
            kdt = cKDTree(overlapping[['x', 'y']])
            dists, neighbours = kdt.query([row.x, row.y], k=2)
            if dists[1] < 7:
                atoms_list.drop(overlapping.index[neighbours[1]], axis=0, inplace=True)

    bonds_list = []

    # get bonds
    for bbox, bond_type, score in zip(output['bbox'][np.logical_not(atoms_mask)],
                                      output['pred_classes'][np.logical_not(atoms_mask)],
                                      output['scores'][np.logical_not(atoms_mask)]):

        # Narrower corner margin for single-like bond boxes, wider otherwise
        if idx_to_labels[bond_type] in ['-','SINGLE', 'NONE', 'ENDUPRIGHT', 'BEGINWEDGE', 'BEGINDASH', 'ENDDOWNRIGHT']:
            _margin = 5
        else:
            _margin = 8

        # anchor positions are _margin distances away from the corners of the bbox
        anchor_positions = (bbox + [_margin, _margin, -_margin, -_margin]).reshape([2, -1])
        oposite_anchor_positions = anchor_positions.copy()
        oposite_anchor_positions[:, 1] = oposite_anchor_positions[:, 1][::-1]

        # Order after concat: upper-left, lower-right, lower-left, upper-right
        # (diagonal pairs 0-1 and 2-3)
        anchor_positions = np.concatenate([anchor_positions, oposite_anchor_positions])

        # get the closest atom to every corner
        atoms_pos = atoms_list[['x', 'y']].values
        kdt = cKDTree(atoms_pos)
        dists, neighbours = kdt.query(anchor_positions, k=1)

        # pick the diagonal whose corners have the smallest total distance to atoms
        if np.argmin((dists[0] + dists[1], dists[2] + dists[3])) == 0:
            begin_idx, end_idx = neighbours[:2]
        else:
            begin_idx, end_idx = neighbours[2:]

        # NOTE this process may snap both corners to one atom (self-bond); skip those
        if begin_idx != end_idx:# avoid self-bond
            bonds_list.append((begin_idx, end_idx, idx_to_labels[bond_type], idx_to_labels[bond_type], score))
        else:
            continue
    return atoms_list, bonds_list
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
def calculate_distance(coord1, coord2):
    """Euclidean distance between two (x, y) points."""
    dx = coord1[0] - coord2[0]
    dy = coord1[1] - coord2[1]
    return math.sqrt(dx ** 2 + dy ** 2)
|
| 337 |
+
|
| 338 |
+
def assemble_atoms_with_charges(atom_list, charge_list):
    """Append the nearest detected charge symbol to each atom's label.

    Args:
        atom_list: DataFrame with columns 'atom', 'x', 'y'.
        charge_list: DataFrame with columns 'charge', 'x', 'y'.

    Returns:
        atom_list (re-indexed) with charge strings appended to 'atom'.
    """
    # NOTE(review): used_charge_indices is checked below but never added to,
    # so charges are never actually marked as consumed in this variant.
    used_charge_indices=set()
    atom_list = atom_list.reset_index(drop=True)
    kdt = cKDTree(atom_list[['x','y']])
    for i, charge in charge_list.iterrows():
        if i in used_charge_indices:
            continue
        charge_=charge['charge']
        # Nearest atom; idx_atom is a positional index into the re-indexed atom_list
        dist, idx_atom=kdt.query([charge_list.x[i],charge_list.y[i]], k=1)
        if idx_atom not in atom_list.index:
            print(f"Warning: idx_atom {idx_atom} is out of range for atom_list.")
            continue  # skip this charge
        atom_str = atom_list.iloc[idx_atom]['atom']
        if atom_str=='*':
            atom_=atom_str + charge_
        else:
            try:
                # Keep only the element letters (and '*') before appending the charge
                atom_ = re.findall(r'[A-Za-z*]+', atom_str)[0] + charge_
            except Exception as e:
                print(atom_str,charge_,charge_list)
                print(f"@assemble_atoms_with_charges\n {e}\n{atom_list}")
                atom_=atom_str + charge_
        atom_list.loc[idx_atom,'atom']=atom_

    return atom_list
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
def assemble_atoms_with_charges2(atom_list, charge_list, max_distance=10):
    """Attach the nearest free charge label to each atom (within max_distance).

    For every atom row, the closest not-yet-used charge within max_distance
    (Euclidean, pixel units) is appended to its 'atom' string; a charge of
    '1' is rendered as '+'. Atoms with no charge nearby get a '0' suffix.

    Bug fixed: the original stored ``tuple(charge)`` in used_charge_indices
    but tested membership with the DataFrame index ``i``, so a single charge
    detection could be assigned to several atoms. The charge row index is
    now tracked, consuming each charge at most once.

    Args:
        atom_list: DataFrame with columns 'atom', 'x', 'y' (modified in place).
        charge_list: DataFrame with columns 'charge', 'x', 'y'.
        max_distance: maximum distance for a charge-atom pairing.
            NOTE(review): a fixed pixel threshold — may need scaling with
            image size; confirm against callers.

    Returns:
        The updated atom_list DataFrame.
    """
    used_charge_indices = set()

    for idx, atom in atom_list.iterrows():
        closest_charge = None
        closest_charge_idx = None
        min_distance = float('inf')

        for i, charge in charge_list.iterrows():
            if i in used_charge_indices:
                continue

            # Euclidean pixel distance between this atom and the charge box center
            distance = math.hypot(atom['x'] - charge['x'], atom['y'] - charge['y'])
            if distance <= max_distance and distance < min_distance:
                closest_charge = charge
                closest_charge_idx = i
                min_distance = distance

        if closest_charge is not None:
            if closest_charge['charge'] == '1':
                charge_ = '+'
            else:
                charge_ = closest_charge['charge']
            atom_list.loc[idx, 'atom'] = atom['atom'] + charge_
            used_charge_indices.add(closest_charge_idx)  # consume this charge
        else:
            # No free charge nearby: mark the atom as neutral
            atom_list.loc[idx, 'atom'] = atom['atom'] + '0'

    return atom_list
|
| 408 |
+
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
def bbox_to_graph_with_charge(output, idx_to_labels, bond_labels, result):
    """Convert raw detection output into molecule-graph building blocks.

    Args:
        output: dict of numpy arrays, one entry per detected box:
            'pred_classes', 'bbox' ([x1,y1,x2,y2]), 'bbox_centers', 'scores'.
        idx_to_labels: mapping class index -> label string; its length
            (23 / 24 / 30) selects the anchor margin used for bond matching.
        bond_labels: class indices that represent bond types.
        result: unused in this function body — kept for interface compatibility.

    Returns:
        tuple: (atoms_list, bonds_list, charge_list) where atoms_list is a
        DataFrame (atom/x/y/bbox/scores), bonds_list is a list of
        (begin_idx, end_idx, bond_label, bond_label, score) tuples, and
        charge_list is a DataFrame (charge/x/y/scores).
    """
    bond_labels_pre = bond_labels  # NOTE(review): unused local — confirm before removing
    # NOTE(review): `charge_labels` is a module-level global (previously
    # hard-coded as [18,19,20,21,22]) — confirm it is defined at import time.
    atoms_mask = np.array([True if ins not in bond_labels and ins not in charge_labels else False for ins in output['pred_classes']])

    try:
        atoms_list = [idx_to_labels[a] for a in output['pred_classes'][atoms_mask]]
        # NOTE(review): atoms_list is a plain list at this point, so this
        # Series check can never be True — the (None, None, None) early
        # return looks like legacy code; confirm.
        if isinstance(atoms_list, pd.Series) and atoms_list.empty:
            return None, None, None
        else:
            atoms_list = pd.DataFrame({'atom': atoms_list,
                                       'x': output['bbox_centers'][atoms_mask, 0],
                                       'y': output['bbox_centers'][atoms_mask, 1],
                                       'bbox': output['bbox'][atoms_mask].tolist(),  # kept for '*'/other conversion downstream
                                       'scores': output['scores'][atoms_mask].tolist(),
                                       })
    except Exception as e:
        # Diagnostic dump when a predicted class index is missing from idx_to_labels.
        print(output['pred_classes'][atoms_mask].dtype, output['pred_classes'][atoms_mask])
        print(e)
        print(idx_to_labels)

    # (An experimental IOU-based de-duplication of overlapping atom boxes was
    # previously drafted here — removed as dead commented-out code.)

    # Collect charge boxes into their own table.
    charge_mask = np.array([True if ins in charge_labels else False for ins in output['pred_classes']])
    charge_list = [idx_to_labels[a] for a in output['pred_classes'][charge_mask]]
    charge_list = pd.DataFrame({'charge': charge_list,
                                'x': output['bbox_centers'][charge_mask, 0],
                                'y': output['bbox_centers'][charge_mask, 1],
                                'scores': output['scores'][charge_mask],
                                })

    # Default every atom to formal charge 0 by appending a '0' suffix.
    try:
        atoms_list['atom'] = atoms_list['atom'] + '0'
    except Exception as e:
        print(e)
        print(atoms_list['atom'], 'atoms_list["atom"] @@ adding 0 ')

    # Merge detected charges onto their nearest atoms. Most molecules carry
    # no formal charge, so this is skipped when no charge boxes exist.
    if len(charge_list) > 0:
        atoms_list = assemble_atoms_with_charges(atoms_list, charge_list)

    # In case atoms with sign get detected two times, keep only the signed one.
    for idx, row in atoms_list.iterrows():
        if row.atom[-1] != '0':
            try:
                if row.atom[-2] != '-':  # assumes charge magnitude is a single digit (-9..9)
                    overlapping = atoms_list[atoms_list.atom.str.startswith(row.atom[:-1])]
            except Exception as e:
                print(row.atom, "@rin case atoms with sign gets detected two times")
                print(e)
            else:
                # NOTE(review): this is the *try* statement's else clause (it
                # runs whenever no exception occurred), so it overwrites the
                # assignment made inside the try block — confirm intent.
                overlapping = atoms_list[atoms_list.atom.str.startswith(row.atom[:-2])]

            # Drop the nearest same-prefix atom if it sits within 7 px —
            # treated as a duplicate detection of the same atom.
            kdt = cKDTree(overlapping[['x', 'y']])
            dists, neighbours = kdt.query([row.x, row.y], k=2)
            if dists[1] < 7:
                atoms_list.drop(overlapping.index[neighbours[1]], axis=0, inplace=True)

    bonds_list = []
    # Bond boxes are everything that is neither an atom nor a charge.
    bond_mask = np.logical_not(atoms_mask) & np.logical_not(charge_mask)
    for bbox, bond_type, score in zip(output['bbox'][bond_mask],
                                      output['pred_classes'][bond_mask],
                                      output['scores'][bond_mask]):

        # The corner margin depends on which label set the model was trained
        # with; for the 24- and 30-label sets bond boxes are already tight.
        if len(idx_to_labels) == 23:
            if idx_to_labels[bond_type] in ['-', 'SINGLE', 'NONE', 'ENDUPRIGHT', 'BEGINWEDGE', 'BEGINDASH', 'ENDDOWNRIGHT']:
                _margin = 5
            else:
                _margin = 8
        elif len(idx_to_labels) == 30:
            _margin = 0  # in this version bond boxes change dynamically
        elif len(idx_to_labels) == 24:
            _margin = 0  # in this version bond boxes change dynamically
        # NOTE(review): `_margin` is unbound for any other label-set size —
        # confirm only 23/24/30-label models are ever used here.

        # Anchor positions sit _margin inside the corners of the bbox:
        # rows 0-1 are (upper-left, lower-right), rows 2-3 the opposite diagonal.
        anchor_positions = (bbox + [_margin, _margin, -_margin, -_margin]).reshape([2, -1])
        oposite_anchor_positions = anchor_positions.copy()
        oposite_anchor_positions[:, 1] = oposite_anchor_positions[:, 1][::-1]
        anchor_positions = np.concatenate([anchor_positions, oposite_anchor_positions])

        # Find the closest atom to each of the four corner anchors.
        atoms_pos = atoms_list[['x', 'y']].values
        kdt = cKDTree(atoms_pos)
        dists, neighbours = kdt.query(anchor_positions, k=1)

        # Pick the diagonal whose two anchors have the smallest total
        # distance to their nearest atoms; its endpoints become the bond ends.
        if np.argmin((dists[0] + dists[1], dists[2] + dists[3])) == 0:
            begin_idx, end_idx = neighbours[:2]
        else:
            begin_idx, end_idx = neighbours[2:]

        # NOTE: this matching can map both ends onto the same atom
        # (a self-bond); such bonds are skipped.
        if begin_idx != end_idx:
            if bond_type in bond_labels:  # guard against charge boxes slipping through the mask
                bonds_list.append((begin_idx, end_idx, idx_to_labels[bond_type], idx_to_labels[bond_type], score))
            else:
                print(f'this box may be charges box not bonds {[bbox, bond_type, score ]}')
        else:
            continue
    return atoms_list, bonds_list, charge_list  # DataFrame, list, DataFrame
|
| 554 |
+
|
| 555 |
+
def parse_atom(node):
    """Split an atom label with an appended charge suffix into (symbol, charge).

    Upstream code produces labels such as 'C0' (neutral), 'N+'/'O-' (unit
    charge), 'N-1'/'N+2' (signed magnitude) and 'other...' (unknown atom,
    mapped to the wildcard '*').

    Args:
        node: atom label string, e.g. 'C0', 'N+', 'O-', 'N-1', 'other0'.

    Returns:
        tuple: (atom_symbol, formal_charge) with formal_charge an int.
    """
    digits = [str(x) for x in range(10)]
    if 'other' in node:
        # Unknown atom type -> wildcard symbol.
        a = '*'
        if '-' in node or '+' in node:
            # Sign only; magnitude for 'other' labels is not parsed here.
            fc = -1 if '-' in node else 1
        else:
            # Bug fix: the old check tested the 2-char tail `node[-2:]`
            # against single-digit strings, which is always False, so the
            # charge was silently parsed as 0.
            fc = int(node[-1]) if node[-1] in digits else 0
    elif node[-1] in digits:
        if '-' in node:
            # Bug fix: node[-1] is a digit in this branch, so the old
            # `node[-1] == '-'` test never fired and '-1' suffixes were
            # parsed as +1 (with the '-' left glued to the symbol).
            a, _, magnitude = node.rpartition('-')
            fc = -int(magnitude)
        elif '+' in node:
            a, _, magnitude = node.rpartition('+')
            fc = int(magnitude)
        else:
            a = node[:-1]
            fc = int(node[-1])
    elif node[-1] == '+':
        a = node[:-1]
        fc = 1
    elif node[-1] == '-':
        a = node[:-1]
        fc = -1
    else:
        # No recognizable suffix: neutral atom, label used as-is.
        a = node
        fc = 0
    return a, fc
|
| 581 |
+
|
| 582 |
+
#from engine
|
| 583 |
+
|
| 584 |
+
def iou_(box1, box2):
    """Compute the Intersection over Union (IoU) of two boxes.

    Args:
        box1, box2: box coordinates in [x1, y1, x2, y2] format.

    Returns:
        float: IoU value; 0 when the boxes do not overlap.
    """
    # Intersection rectangle, clamped to zero extent when the boxes are disjoint.
    ix1, iy1 = max(box1[0], box2[0]), max(box1[1], box2[1])
    ix2, iy2 = min(box1[2], box2[2]), min(box1[3], box2[3])
    inter = max(0, ix2 - ix1) * max(0, iy2 - iy1)

    a1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
    a2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
    denom = a1 + a2 - inter
    if denom > 0:
        return inter / denom
    return 0
|
| 603 |
+
|
| 604 |
+
|
| 605 |
+
def calculate_iou(bbox1, bbox2):
    """Compute IoU of two axis-aligned boxes plus a relationship description.

    Args:
        bbox1, bbox2: box coordinates [x_min, y_min, x_max, y_max].

    Returns:
        tuple: (iou, result, inter_area, union_area) where `result` is a list
        of (Chinese) strings describing the overlap relationship.
    """
    ax1, ay1, ax2, ay2 = bbox1
    bx1, by1, bx2, by2 = bbox2

    # Intersection rectangle, clamped to zero extent when disjoint.
    inter_w = max(0, min(ax2, bx2) - max(ax1, bx1))
    inter_h = max(0, min(ay2, by2) - max(ay1, by1))
    inter_area = inter_w * inter_h

    area1 = (ax2 - ax1) * (ay2 - ay1)
    area2 = (bx2 - bx1) * (by2 - by1)
    union_area = area1 + area2 - inter_area
    iou = inter_area / union_area if union_area > 0 else 0

    # Describe the relationship between the two boxes.
    result = []
    if iou == 0:
        result.append("无重叠")  # no overlap
    elif iou > 0:
        result.append("有重叠")  # overlapping
        if iou == 1:
            result.append("完全重合")  # boxes coincide exactly
        elif inter_area == area2:
            result.append("bbox1 包含 bbox2")  # bbox1 contains bbox2
        elif inter_area == area1:
            result.append("bbox2 包含 bbox1")  # bbox2 contains bbox1

    return iou, result, inter_area, union_area
|
| 645 |
+
|
| 646 |
+
def adjust_bbox1(large_bbox, small_bbox, bond_bbox):
    """Placeholder for shrinking a large atom bbox that overlaps a smaller
    atom bbox and a bond bbox.

    The adjustment strategy has not been settled yet, so the large bbox is
    currently returned unchanged. (Defect fixed: the previous version
    computed a candidate `scaled_box` that was never used — dead code —
    before returning the input; behavior is preserved.)

    Args:
        large_bbox: [x_min, y_min, x_max, y_max] of the larger atom box.
        small_bbox: [x_min, y_min, x_max, y_max] of the smaller atom box.
        bond_bbox: [x_min, y_min, x_max, y_max] of the neighbouring bond box.

    Returns:
        The `large_bbox` argument, unmodified.
    """
    return large_bbox
|
| 658 |
+
|
| 659 |
+
|
| 660 |
+
def nms_per_class(labels, boxes, scores, iou_thresh=0.5):
    """Apply non-maximum suppression per class, keeping the highest-scoring boxes.

    Args:
        labels: numpy array of class labels.
        boxes: numpy array of box coordinates [x1, y1, x2, y2].
        scores: numpy array of confidence scores.
        iou_thresh: IoU threshold; a lower-scoring box of the same class is
            suppressed when its IoU with a kept box reaches this value.

    Returns:
        dict: {'labels', 'boxes', 'scores'} restricted to the kept boxes.
    """
    def _iou_one_vs_many(ref, others):
        # Vectorized scalar IoU of one box against an (N, 4) array.
        x1 = np.maximum(ref[0], others[:, 0])
        y1 = np.maximum(ref[1], others[:, 1])
        x2 = np.minimum(ref[2], others[:, 2])
        y2 = np.minimum(ref[3], others[:, 3])
        inter = np.clip(x2 - x1, 0, None) * np.clip(y2 - y1, 0, None)
        area_ref = (ref[2] - ref[0]) * (ref[3] - ref[1])
        area_others = (others[:, 2] - others[:, 0]) * (others[:, 3] - others[:, 1])
        union = area_ref + area_others - inter
        # Guard against zero-area pairs (degenerate boxes) -> IoU 0.
        return np.where(union > 0, inter / np.where(union > 0, union, 1), 0.0)

    kept_indices = []
    for label in np.unique(labels):
        # Restrict to the current class.
        class_mask = labels == label
        class_indices = np.where(class_mask)[0]
        class_boxes = boxes[class_mask]
        class_scores = scores[class_mask]

        # Sort by descending score so the best box of each cluster survives.
        order = np.argsort(class_scores)[::-1]
        class_boxes = class_boxes[order]
        class_scores = class_scores[order]
        class_indices = class_indices[order]

        while len(class_scores) > 0:
            # Keep the current best box.
            kept_indices.append(class_indices[0])
            if len(class_scores) == 1:
                break

            # Bug fix: the previous version built an array of the *tuples*
            # returned by calculate_iou() and compared them to a float
            # threshold; use the scalar IoU values instead.
            ious = _iou_one_vs_many(class_boxes[0], class_boxes[1:])
            keep_mask = ious < iou_thresh
            class_boxes = class_boxes[1:][keep_mask]
            class_scores = class_scores[1:][keep_mask]
            class_indices = class_indices[1:][keep_mask]

    # dtype=int keeps fancy indexing valid even when nothing was kept
    # (np.array([]) defaults to float and cannot be used as an index).
    kept_indices = np.array(kept_indices, dtype=int)
    return {
        'labels': labels[kept_indices],
        'boxes': boxes[kept_indices],
        'scores': scores[kept_indices]
    }
|
| 712 |
+
|