Spaces:
Build error
Build error
msIntui commited on
Commit ·
9847531
0
Parent(s):
Initial commit: Add core files for P&ID processing
Browse files- .gitattributes +6 -0
- .gitignore +33 -0
- README.md +464 -0
- base.py +138 -0
- chatbot_agent.py +193 -0
- common.py +14 -0
- config.py +116 -0
- data_aggregation_ai.py +455 -0
- detection_schema.py +481 -0
- detectors.py +1096 -0
- gradioChatApp.py +672 -0
- graph_construction.py +113 -0
- graph_processor.py +115 -0
- graph_visualization.py +235 -0
- line_detection_ai.py +410 -0
- packages.txt +3 -0
- pdf_processor.py +288 -0
- requirements.txt +41 -0
- setup.py +24 -0
- storage.py +208 -0
- symbol_detection.py +454 -0
- text_detection_combined.py +553 -0
- utils.py +122 -0
.gitattributes
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
models/* filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
chat/*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
chat/adapter_model.safetensors filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Samples and large files
|
| 2 |
+
samples/
|
| 3 |
+
*.pdf
|
| 4 |
+
*.jpg
|
| 5 |
+
*.jpeg
|
| 6 |
+
*.png
|
| 7 |
+
*.zip
|
| 8 |
+
|
| 9 |
+
# Model files
|
| 10 |
+
*.pth
|
| 11 |
+
*.pt
|
| 12 |
+
*.tar
|
| 13 |
+
models/*
|
| 14 |
+
chat/*.safetensors
|
| 15 |
+
|
| 16 |
+
# Python
|
| 17 |
+
__pycache__/
|
| 18 |
+
*.py[cod]
|
| 19 |
+
*.class
|
| 20 |
+
.env
|
| 21 |
+
.venv/
|
| 22 |
+
venv/
|
| 23 |
+
ENV/
|
| 24 |
+
|
| 25 |
+
# IDE
|
| 26 |
+
.vscode/
|
| 27 |
+
.idea/
|
| 28 |
+
|
| 29 |
+
# Other
|
| 30 |
+
archive/
|
| 31 |
+
archive 2/
|
| 32 |
+
results/
|
| 33 |
+
DeepLSD/
|
README.md
ADDED
|
@@ -0,0 +1,464 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Intelligent_PID
|
| 3 |
+
emoji: 🔍
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: red
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 5.3.0
|
| 8 |
+
app_file: gradioChatApp.py
|
| 9 |
+
pinned: false
|
| 10 |
+
license: mit
|
| 11 |
+
---
|
| 12 |
+
|
| 13 |
+
# P&ID Processing with AI-Powered Graph Construction
|
| 14 |
+
|
| 15 |
+
## Overview
|
| 16 |
+
|
| 17 |
+
This project processes P&ID (Piping and Instrumentation Diagram) images using multiple AI models for symbol detection, text recognition, and line detection. It constructs a graph representation of the diagram and provides an interactive interface for querying the diagram's contents.
|
| 18 |
+
|
| 19 |
+
## Process Flow
|
| 20 |
+
|
| 21 |
+
```mermaid
|
| 22 |
+
graph TD
|
| 23 |
+
subgraph "Document Input"
|
| 24 |
+
A[Upload Document] --> B[Validate File]
|
| 25 |
+
B -->|PDF/Image| C[Document Processor]
|
| 26 |
+
B -->|Invalid| ERR[Error Message]
|
| 27 |
+
C -->|PDF| D1[Extract Pages]
|
| 28 |
+
C -->|Image| D2[Direct Process]
|
| 29 |
+
end
|
| 30 |
+
|
| 31 |
+
subgraph "Image Preprocessing"
|
| 32 |
+
D1 --> E[Optimize Image]
|
| 33 |
+
D2 --> E
|
| 34 |
+
E -->|CLAHE Enhancement| E1[Contrast Enhancement]
|
| 35 |
+
E1 -->|Denoising| E2[Clean Image]
|
| 36 |
+
E2 -->|Binarization| E3[Binary Image]
|
| 37 |
+
E3 -->|Resize| E4[Normalized Image]
|
| 38 |
+
end
|
| 39 |
+
|
| 40 |
+
subgraph "Line Detection Pipeline"
|
| 41 |
+
E4 --> L1[Load DeepLSD Model]
|
| 42 |
+
L1 --> L2[Scale Image 0.1x]
|
| 43 |
+
L2 --> L3[Grayscale Conversion]
|
| 44 |
+
L3 --> L4[Model Inference]
|
| 45 |
+
L4 --> L5[Scale Coordinates]
|
| 46 |
+
L5 --> L6[Draw Lines]
|
| 47 |
+
end
|
| 48 |
+
|
| 49 |
+
subgraph "Detection Pipeline"
|
| 50 |
+
E4 --> F[Symbol Detection]
|
| 51 |
+
E4 --> G[Text Detection]
|
| 52 |
+
|
| 53 |
+
F --> S1[Load YOLO Models]
|
| 54 |
+
G --> T1[Load OCR Models]
|
| 55 |
+
|
| 56 |
+
S1 --> S2[Detect Symbols]
|
| 57 |
+
T1 --> T2[Detect Text]
|
| 58 |
+
|
| 59 |
+
S2 --> S3[Process Symbols]
|
| 60 |
+
T2 --> T3[Process Text]
|
| 61 |
+
|
| 62 |
+
L6 --> L7[Process Lines]
|
| 63 |
+
end
|
| 64 |
+
|
| 65 |
+
subgraph "Data Integration"
|
| 66 |
+
S3 --> I[Data Aggregation]
|
| 67 |
+
T3 --> I
|
| 68 |
+
L7 --> I
|
| 69 |
+
I --> J[Create Edges]
|
| 70 |
+
J --> K[Build Graph Network]
|
| 71 |
+
K --> L[Generate Knowledge Graph]
|
| 72 |
+
end
|
| 73 |
+
|
| 74 |
+
subgraph "User Interface"
|
| 75 |
+
L --> M[Interactive Visualization]
|
| 76 |
+
M --> N[Chat Interface]
|
| 77 |
+
N --> O[Query Processing]
|
| 78 |
+
O --> P[Response Generation]
|
| 79 |
+
P --> N
|
| 80 |
+
end
|
| 81 |
+
|
| 82 |
+
style A fill:#f9f,stroke:#333,stroke-width:2px
|
| 83 |
+
style F fill:#fbb,stroke:#333,stroke-width:2px
|
| 84 |
+
style G fill:#bfb,stroke:#333,stroke-width:2px
|
| 85 |
+
style H fill:#bbf,stroke:#333,stroke-width:2px
|
| 86 |
+
style I fill:#fbf,stroke:#333,stroke-width:2px
|
| 87 |
+
style N fill:#bbf,stroke:#333,stroke-width:2px
|
| 88 |
+
|
| 89 |
+
%% Add style for model nodes
|
| 90 |
+
style SM1 fill:#ffe6e6,stroke:#333,stroke-width:2px
|
| 91 |
+
style SM2 fill:#ffe6e6,stroke:#333,stroke-width:2px
|
| 92 |
+
style LM1 fill:#e6e6ff,stroke:#333,stroke-width:2px
|
| 93 |
+
style DC1 fill:#e6ffe6,stroke:#333,stroke-width:2px
|
| 94 |
+
style DC2 fill:#e6ffe6,stroke:#333,stroke-width:2px
|
| 95 |
+
```
|
| 96 |
+
|
| 97 |
+
## Architecture
|
| 98 |
+
|
| 99 |
+

|
| 100 |
+
|
| 101 |
+
## Features
|
| 102 |
+
|
| 103 |
+
- **Multi-modal AI Processing**:
|
| 104 |
+
- Combined OCR approach using Tesseract, EasyOCR, and DocTR
|
| 105 |
+
- Symbol detection with optimized thresholds
|
| 106 |
+
- Intelligent line and connection detection
|
| 107 |
+
- **Document Processing**:
|
| 108 |
+
- Support for PDF, PNG, JPG, JPEG formats
|
| 109 |
+
- Automatic page extraction from PDFs
|
| 110 |
+
- Image optimization pipeline
|
| 111 |
+
- **Text Detection Types**:
|
| 112 |
+
- Equipment Tags
|
| 113 |
+
- Line Numbers
|
| 114 |
+
- Instrument Tags
|
| 115 |
+
- Valve Numbers
|
| 116 |
+
- Pipe Sizes
|
| 117 |
+
- Flow Directions
|
| 118 |
+
- Service Descriptions
|
| 119 |
+
- Process Instruments
|
| 120 |
+
- Nozzles
|
| 121 |
+
- Pipe Connectors
|
| 122 |
+
- **Data Integration**:
|
| 123 |
+
- Automatic edge detection
|
| 124 |
+
- Relationship mapping
|
| 125 |
+
- Confidence scoring
|
| 126 |
+
- Detailed detection statistics
|
| 127 |
+
- **User Interface**:
|
| 128 |
+
- Interactive visualization tabs
|
| 129 |
+
- Real-time processing feedback
|
| 130 |
+
- AI-powered chat interface
|
| 131 |
+
- Knowledge graph exploration
|
| 132 |
+
|
| 133 |
+
The entire process is visualized through an interactive Gradio-based UI, allowing users to upload a P&ID image, follow the detection steps, and view both the results and insights in real time.
|
| 134 |
+
|
| 135 |
+
## Key Files
|
| 136 |
+
|
| 137 |
+
- **gradioChatApp.py**: The main Gradio app script that handles the frontend and orchestrates the overall flow.
|
| 138 |
+
- **symbol_detection.py**: Module for detecting symbols using YOLO models.
|
| 139 |
+
- **text_detection_combined.py**: Unified module for text detection using multiple OCR engines (Tesseract, EasyOCR, DocTR).
|
| 140 |
+
- **line_detection_ai.py**: Module for detecting lines and connections using AI.
|
| 141 |
+
- **data_aggregation.py**: Aggregates detected elements into a structured format.
|
| 142 |
+
- **graph_construction.py**: Constructs the graph network from aggregated data.
|
| 143 |
+
- **graph_processor.py**: Handles graph visualization and processing.
|
| 144 |
+
- **pdf_processor.py**: Handles PDF document processing and page extraction.
|
| 145 |
+
|
| 146 |
+
## Setup and Installation
|
| 147 |
+
|
| 148 |
+
1. Clone the repository:
|
| 149 |
+
```bash
|
| 150 |
+
git clone https://github.com/IntuigenceAI/intui-PnID-POC.git
|
| 151 |
+
cd intui-PnID-POC
|
| 152 |
+
```
|
| 153 |
+
|
| 154 |
+
2. Install dependencies using uv:
|
| 155 |
+
```bash
|
| 156 |
+
# Install uv if you haven't already
|
| 157 |
+
curl -LsSf https://astral.sh/uv/install.sh | sh
|
| 158 |
+
|
| 159 |
+
# Create and activate virtual environment
|
| 160 |
+
uv venv
|
| 161 |
+
source .venv/bin/activate # On Windows: .venv\Scripts\activate
|
| 162 |
+
|
| 163 |
+
# Install dependencies
|
| 164 |
+
uv pip install -r requirements.txt
|
| 165 |
+
```
|
| 166 |
+
|
| 167 |
+
3. Download required models:
|
| 168 |
+
```bash
|
| 169 |
+
python download_model.py # Downloads DeepLSD model for line detection
|
| 170 |
+
```
|
| 171 |
+
|
| 172 |
+
4. Run the application:
|
| 173 |
+
```bash
|
| 174 |
+
python gradioChatApp.py
|
| 175 |
+
```
|
| 176 |
+
|
| 177 |
+
## Models
|
| 178 |
+
|
| 179 |
+
### Line Detection Model
|
| 180 |
+
- **DeepLSD Model**:
|
| 181 |
+
- File: deeplsd_md.tar
|
| 182 |
+
- Purpose: Line segment detection in P&ID diagrams
|
| 183 |
+
- Input Resolution: Variable (scaled to 0.1x for performance)
|
| 184 |
+
- Processing: Grayscale conversion and binary thresholding
|
| 185 |
+
|
| 186 |
+
### Text Detection Models
|
| 187 |
+
- **Combined OCR Approach**:
|
| 188 |
+
- Tesseract OCR
|
| 189 |
+
- EasyOCR
|
| 190 |
+
- DocTR
|
| 191 |
+
- Purpose: Text recognition and classification
|
| 192 |
+
|
| 193 |
+
### Graph Processing
|
| 194 |
+
- **NetworkX-based**:
|
| 195 |
+
- Purpose: Graph construction and analysis
|
| 196 |
+
- Features: Node linking, edge creation, path analysis
|
| 197 |
+
|
| 198 |
+
## Updating the Environment
|
| 199 |
+
|
| 200 |
+
To update the environment, use the following:
|
| 201 |
+
|
| 202 |
+
```bash
|
| 203 |
+
conda env update --file environment.yml --prune
|
| 204 |
+
```
|
| 205 |
+
|
| 206 |
+
This command will update the environment according to changes made in the `environment.yml`.
|
| 207 |
+
|
| 208 |
+
### Step 6: Deactivate the environment
|
| 209 |
+
|
| 210 |
+
When you're done, deactivate the environment by:
|
| 211 |
+
|
| 212 |
+
```bash
|
| 213 |
+
conda deactivate
|
| 214 |
+
```
|
| 215 |
+
|
| 216 |
+
2. Upload a P&ID image through the interface.
|
| 217 |
+
3. Follow the sequential steps of symbol, text, and line detection.
|
| 218 |
+
4. View the generated graph and AI agent's reasoning in the real-time chat box.
|
| 219 |
+
5. Save and export the results if satisfactory.
|
| 220 |
+
|
| 221 |
+
## Folder Structure
|
| 222 |
+
|
| 223 |
+
```
|
| 224 |
+
├── assets/
|
| 225 |
+
│ └── AiAgent.png
|
| 226 |
+
│ └── llm.png
|
| 227 |
+
├── gradioApp.py
|
| 228 |
+
├── symbol_detection.py
|
| 229 |
+
├── text_detection_combined.py
|
| 230 |
+
├── line_detection_ai.py
|
| 231 |
+
├── data_aggregation.py
|
| 232 |
+
├── graph_construction.py
|
| 233 |
+
├── graph_processor.py
|
| 234 |
+
├── pdf_processor.py
|
| 235 |
+
├── pnid_agent.py
|
| 236 |
+
├── requirements.txt
|
| 237 |
+
├── results/
|
| 238 |
+
├── models/
|
| 239 |
+
│ └── symbol_detection_model.pth
|
| 240 |
+
```
|
| 241 |
+
|
| 242 |
+
## /models Folder
|
| 243 |
+
|
| 244 |
+
- **models/symbol_detection_model.pth**: This folder contains the pre-trained model for symbol detection in P&ID diagrams. This model is crucial for detecting key symbols such as valves, instruments, and pipes in the diagram. Make sure to download the model and place it in the `/models` directory before running the app.
|
| 245 |
+
|
| 246 |
+
## Future Work
|
| 247 |
+
|
| 248 |
+
- **Advanced Symbol Recognition**: Improve symbol detection by integrating more sophisticated recognition models.
|
| 249 |
+
- **Graph Enhancement**: Introduce more complex graph structures and logic for representing the relationships between the diagram's elements.
|
| 250 |
+
- **Data Export**: Allow export in additional formats such as DEXPI-compliant XML or JSON.
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
# Docker Information
|
| 254 |
+
|
| 255 |
+
We'll cover the basic docker operations here.
|
| 256 |
+
|
| 257 |
+
## Building
|
| 258 |
+
|
| 259 |
+
There is a dockerfile for each different project (they have slightly different requiremnts).
|
| 260 |
+
|
| 261 |
+
### `gradioChatApp.py`
|
| 262 |
+
|
| 263 |
+
Run this one as follows:
|
| 264 |
+
|
| 265 |
+
```
|
| 266 |
+
> docker build -t exp-pnid-to-graph_chat-w-graph:0.0.4 -f Dockerfile-chatApp .
|
| 267 |
+
> docker tag exp-pnid-to-graph_chat-w-graph:0.0.4 intaicr.azurecr.io/intai/exp-pnid-to-graph_chat-w-graph:0.0.4
|
| 268 |
+
```
|
| 269 |
+
|
| 270 |
+
## Deploying to ACR
|
| 271 |
+
|
| 272 |
+
### `gradioChatApp.py`
|
| 273 |
+
|
| 274 |
+
```
|
| 275 |
+
> az login
|
| 276 |
+
> az acr login --name intaicr
|
| 277 |
+
> docker push intaicr.azurecr.io/intai/exp-pnid-to-graph_chat-w-graph:0.0.4
|
| 278 |
+
```
|
| 279 |
+
|
| 280 |
+
## Models
|
| 281 |
+
|
| 282 |
+
### Symbol Detection Models
|
| 283 |
+
- **Intui_SDM_41.pt**: Primary model for equipment and large symbol detection
|
| 284 |
+
- Classes: Equipment, Vessels, Heat Exchangers
|
| 285 |
+
- Input Resolution: 1280x1280
|
| 286 |
+
- Confidence Threshold: 0.3-0.7 (adaptive)
|
| 287 |
+
|
| 288 |
+
- **Intui_SDM_20.pt**: Secondary model for instrument and small symbol detection
|
| 289 |
+
- Classes: Instruments, Valves, Indicators
|
| 290 |
+
- Input Resolution: 1280x1280
|
| 291 |
+
- Confidence Threshold: 0.3-0.7 (adaptive)
|
| 292 |
+
|
| 293 |
+
### Line Detection Model
|
| 294 |
+
- **intui_LDM_01.pt**: Specialized model for line and connection detection
|
| 295 |
+
- Classes: Solid Lines, Dashed Lines
|
| 296 |
+
- Input Resolution: 1280x1280
|
| 297 |
+
- Confidence Threshold: 0.5
|
| 298 |
+
|
| 299 |
+
### Text Detection Models
|
| 300 |
+
- **Tesseract**: v5.3.0
|
| 301 |
+
- Configuration:
|
| 302 |
+
- OEM Mode: 3 (Default)
|
| 303 |
+
- PSM Mode: 11 (Sparse text)
|
| 304 |
+
- Custom Whitelist: A-Z, 0-9, special characters
|
| 305 |
+
|
| 306 |
+
- **EasyOCR**: v1.7.1
|
| 307 |
+
- Configuration:
|
| 308 |
+
- Language: English
|
| 309 |
+
- Paragraph Mode: False
|
| 310 |
+
- Height Threshold: 2.0
|
| 311 |
+
- Width Threshold: 2.0
|
| 312 |
+
- Contrast Threshold: 0.2
|
| 313 |
+
|
| 314 |
+
- **DocTR**: v0.6.0
|
| 315 |
+
- Models:
|
| 316 |
+
- fast_base-688a8b34.pt
|
| 317 |
+
- crnn_vgg16_bn-9762b0b0.pt
|
| 318 |
+
|
| 319 |
+
# P&ID Line Detection
|
| 320 |
+
|
| 321 |
+
A deep learning-based pipeline for detecting lines in P&ID diagrams using DeepLSD.
|
| 322 |
+
|
| 323 |
+
## Architecture
|
| 324 |
+
```mermaid
|
| 325 |
+
graph TD
|
| 326 |
+
A[Input Image] --> B[Line Detection]
|
| 327 |
+
B --> C[DeepLSD Model]
|
| 328 |
+
C --> D[Post-processing]
|
| 329 |
+
D --> E[Output JSON/Image]
|
| 330 |
+
|
| 331 |
+
subgraph Line Detection Pipeline
|
| 332 |
+
B --> F[Image Preprocessing]
|
| 333 |
+
F --> G[Scale Image 0.1x]
|
| 334 |
+
G --> H[Grayscale Conversion]
|
| 335 |
+
H --> C
|
| 336 |
+
C --> I[Scale Coordinates]
|
| 337 |
+
I --> J[Draw Lines]
|
| 338 |
+
J --> E
|
| 339 |
+
end
|
| 340 |
+
```
|
| 341 |
+
|
| 342 |
+
## Setup
|
| 343 |
+
|
| 344 |
+
### Prerequisites
|
| 345 |
+
- Python 3.12+
|
| 346 |
+
- uv (for dependency management)
|
| 347 |
+
- Git
|
| 348 |
+
- CUDA-capable GPU (optional)
|
| 349 |
+
|
| 350 |
+
### Installation
|
| 351 |
+
|
| 352 |
+
1. Clone the repository:
|
| 353 |
+
```bash
|
| 354 |
+
git clone https://github.com/IntuigenceAI/intui-PnID-POC.git
|
| 355 |
+
cd intui-PnID-POC
|
| 356 |
+
```
|
| 357 |
+
|
| 358 |
+
2. Install dependencies using uv:
|
| 359 |
+
```bash
|
| 360 |
+
# Install uv if you haven't already
|
| 361 |
+
curl -LsSf https://astral.sh/uv/install.sh | sh
|
| 362 |
+
|
| 363 |
+
# Create and activate virtual environment
|
| 364 |
+
uv venv
|
| 365 |
+
source .venv/bin/activate # On Windows: .venv\Scripts\activate
|
| 366 |
+
|
| 367 |
+
# Install dependencies
|
| 368 |
+
uv pip install -r requirements.txt
|
| 369 |
+
```
|
| 370 |
+
|
| 371 |
+
3. Download DeepLSD model:
|
| 372 |
+
```bash
|
| 373 |
+
python download_model.py
|
| 374 |
+
```
|
| 375 |
+
|
| 376 |
+
## Usage
|
| 377 |
+
|
| 378 |
+
1. Run the line detection:
|
| 379 |
+
```bash
|
| 380 |
+
python line_detection_ai.py
|
| 381 |
+
```
|
| 382 |
+
|
| 383 |
+
The script will:
|
| 384 |
+
- Load the DeepLSD model
|
| 385 |
+
- Process input images at 0.1x scale for performance
|
| 386 |
+
- Generate line detections
|
| 387 |
+
- Save results as JSON and annotated images
|
| 388 |
+
|
| 389 |
+
## Configuration
|
| 390 |
+
|
| 391 |
+
Key parameters in `line_detection_ai.py`:
|
| 392 |
+
- `scale_factor`: Image scaling (default: 0.1)
|
| 393 |
+
- `device`: CPU/GPU selection
|
| 394 |
+
- `mask_json_paths`: Paths to text/symbol detection results
|
| 395 |
+
|
| 396 |
+
## Input/Output
|
| 397 |
+
|
| 398 |
+
### Input
|
| 399 |
+
- Original P&ID images
|
| 400 |
+
- Optional text/symbol detection JSON files for masking
|
| 401 |
+
|
| 402 |
+
### Output
|
| 403 |
+
- Annotated images with detected lines
|
| 404 |
+
- JSON files containing line coordinates and metadata
|
| 405 |
+
|
| 406 |
+
## Project Structure
|
| 407 |
+
|
| 408 |
+
```
|
| 409 |
+
├── line_detection_ai.py # Main line detection script
|
| 410 |
+
├── detectors.py # Line detector implementation
|
| 411 |
+
├── download_model.py # Model download utility
|
| 412 |
+
├── models/ # Directory for model files
|
| 413 |
+
│ └── deeplsd_md.tar # DeepLSD model weights
|
| 414 |
+
├── results/ # Output directory
|
| 415 |
+
└── requirements.txt # Project dependencies
|
| 416 |
+
```
|
| 417 |
+
|
| 418 |
+
## Dependencies
|
| 419 |
+
|
| 420 |
+
Key dependencies:
|
| 421 |
+
- torch
|
| 422 |
+
- opencv-python
|
| 423 |
+
- numpy
|
| 424 |
+
- DeepLSD
|
| 425 |
+
|
| 426 |
+
See `requirements.txt` for the complete list.
|
| 427 |
+
|
| 428 |
+
## Contributing
|
| 429 |
+
|
| 430 |
+
1. Fork the repository
|
| 431 |
+
2. Create your feature branch (`git checkout -b feature/amazing-feature`)
|
| 432 |
+
3. Commit your changes (`git commit -m 'Add some amazing feature'`)
|
| 433 |
+
4. Push to the branch (`git push origin feature/amazing-feature`)
|
| 434 |
+
5. Open a Pull Request
|
| 435 |
+
|
| 436 |
+
## License
|
| 437 |
+
|
| 438 |
+
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
|
| 439 |
+
|
| 440 |
+
## Acknowledgments
|
| 441 |
+
|
| 442 |
+
- [DeepLSD](https://github.com/cvg/DeepLSD) for the line detection model
|
| 443 |
+
- Original P&ID processing pipeline by IntuigenceAI
|
| 444 |
+
|
| 445 |
+
---
|
| 446 |
+
title: PnID Diagram Analyzer
|
| 447 |
+
emoji: 🔍
|
| 448 |
+
colorFrom: blue
|
| 449 |
+
colorTo: red
|
| 450 |
+
sdk: gradio
|
| 451 |
+
sdk_version: 4.19.2
|
| 452 |
+
app_file: gradioChatApp.py
|
| 453 |
+
pinned: false
|
| 454 |
+
---
|
| 455 |
+
|
| 456 |
+
# PnID Diagram Analyzer
|
| 457 |
+
|
| 458 |
+
This app analyzes PnID diagrams using AI to detect and interpret various elements.
|
| 459 |
+
|
| 460 |
+
## Features
|
| 461 |
+
- Line detection
|
| 462 |
+
- Symbol recognition
|
| 463 |
+
- Text detection
|
| 464 |
+
- Graph construction
|
base.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from abc import ABC, abstractmethod
|
| 4 |
+
from typing import List, Optional, Dict
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import cv2
|
| 8 |
+
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from loguru import logger
|
| 11 |
+
import json
|
| 12 |
+
|
| 13 |
+
from common import DetectionResult
|
| 14 |
+
from storage import StorageInterface
|
| 15 |
+
from utils import DebugHandler, CoordinateTransformer
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class BaseConfig(ABC):
|
| 19 |
+
"""Abstract Base Config for all configuration classes."""
|
| 20 |
+
|
| 21 |
+
def __post_init__(self):
|
| 22 |
+
"""Ensures default values are set correctly for all configs."""
|
| 23 |
+
pass
|
| 24 |
+
|
| 25 |
+
class BaseDetector(ABC):
|
| 26 |
+
"""Abstract base class for detection models."""
|
| 27 |
+
|
| 28 |
+
def __init__(self,
|
| 29 |
+
config: BaseConfig,
|
| 30 |
+
debug_handler: DebugHandler = None):
|
| 31 |
+
self.config = config
|
| 32 |
+
self.debug_handler = debug_handler or DebugHandler()
|
| 33 |
+
|
| 34 |
+
@abstractmethod
|
| 35 |
+
def _load_model(self, model_path: str):
|
| 36 |
+
"""Load and return the detection model."""
|
| 37 |
+
pass
|
| 38 |
+
|
| 39 |
+
@abstractmethod
|
| 40 |
+
def detect(self, image: np.ndarray, *args, **kwargs):
|
| 41 |
+
"""Run detection on an input image."""
|
| 42 |
+
pass
|
| 43 |
+
|
| 44 |
+
@abstractmethod
|
| 45 |
+
def _preprocess(self, image: np.ndarray) -> np.ndarray:
|
| 46 |
+
"""Preprocess the input image before detection."""
|
| 47 |
+
pass
|
| 48 |
+
|
| 49 |
+
@abstractmethod
|
| 50 |
+
def _postprocess(self, image: np.ndarray) -> np.ndarray:
|
| 51 |
+
"""Postprocess the input image before detection."""
|
| 52 |
+
pass
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class BaseDetectionPipeline(ABC):
|
| 56 |
+
"""Abstract base class for detection pipelines."""
|
| 57 |
+
|
| 58 |
+
def __init__(
|
| 59 |
+
self,
|
| 60 |
+
storage: StorageInterface,
|
| 61 |
+
debug_handler=None
|
| 62 |
+
):
|
| 63 |
+
# self.detector = detector
|
| 64 |
+
self.storage = storage
|
| 65 |
+
self.debug_handler = debug_handler or DebugHandler()
|
| 66 |
+
self.transformer = CoordinateTransformer()
|
| 67 |
+
|
| 68 |
+
@abstractmethod
|
| 69 |
+
def process_image(
|
| 70 |
+
self,
|
| 71 |
+
image_path: str,
|
| 72 |
+
output_dir: str,
|
| 73 |
+
config
|
| 74 |
+
) -> DetectionResult:
|
| 75 |
+
"""Main processing pipeline for a single image."""
|
| 76 |
+
pass
|
| 77 |
+
|
| 78 |
+
def _apply_roi(self, image: np.ndarray, roi: np.ndarray) -> np.ndarray:
|
| 79 |
+
"""Apply region of interest cropping."""
|
| 80 |
+
if roi is not None and len(roi) == 4:
|
| 81 |
+
x_min, y_min, x_max, y_max = roi
|
| 82 |
+
return image[y_min:y_max, x_min:x_max]
|
| 83 |
+
return image
|
| 84 |
+
|
| 85 |
+
def _adjust_coordinates(self, detections: List[Dict], roi: np.ndarray) -> List[Dict]:
|
| 86 |
+
"""Adjust detection coordinates based on ROI"""
|
| 87 |
+
if roi is None or len(roi) != 4:
|
| 88 |
+
return detections
|
| 89 |
+
|
| 90 |
+
x_offset, y_offset = roi[0], roi[1]
|
| 91 |
+
adjusted = []
|
| 92 |
+
|
| 93 |
+
for det in detections:
|
| 94 |
+
try:
|
| 95 |
+
adjusted_bbox = [
|
| 96 |
+
int(det["bbox"][0] + x_offset),
|
| 97 |
+
int(det["bbox"][1] + y_offset),
|
| 98 |
+
int(det["bbox"][2] + x_offset),
|
| 99 |
+
int(det["bbox"][3] + y_offset)
|
| 100 |
+
]
|
| 101 |
+
adjusted_det = {**det, "bbox": adjusted_bbox}
|
| 102 |
+
adjusted.append(adjusted_det)
|
| 103 |
+
except KeyError:
|
| 104 |
+
logger.warning("Invalid detection format during coordinate adjustment")
|
| 105 |
+
return adjusted
|
| 106 |
+
|
| 107 |
+
def _persist_results(
|
| 108 |
+
self,
|
| 109 |
+
output_dir: str,
|
| 110 |
+
image_path: str,
|
| 111 |
+
detections: List[Dict],
|
| 112 |
+
annotated_image: Optional[np.ndarray]
|
| 113 |
+
) -> Dict[str, str]:
|
| 114 |
+
"""Save detection results and annotations"""
|
| 115 |
+
self.storage.create_directory(output_dir)
|
| 116 |
+
base_name = Path(image_path).stem
|
| 117 |
+
|
| 118 |
+
# Save JSON results
|
| 119 |
+
json_path = Path(output_dir) / f"{base_name}_lines.json"
|
| 120 |
+
self.storage.save_file(
|
| 121 |
+
str(json_path),
|
| 122 |
+
json.dumps({
|
| 123 |
+
"solid_lines": {"lines": detections},
|
| 124 |
+
"dashed_lines": {"lines": []}
|
| 125 |
+
}, indent=2).encode('utf-8')
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
# Save annotated image
|
| 129 |
+
img_path = None
|
| 130 |
+
if annotated_image is not None:
|
| 131 |
+
img_path = Path(output_dir) / f"{base_name}_annotated.jpg"
|
| 132 |
+
_, img_data = cv2.imencode('.jpg', annotated_image)
|
| 133 |
+
self.storage.save_file(str(img_path), img_data.tobytes())
|
| 134 |
+
|
| 135 |
+
return {
|
| 136 |
+
"json_path": str(json_path),
|
| 137 |
+
"image_path": str(img_path) if img_path else None
|
| 138 |
+
}
|
chatbot_agent.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# chatbot_agent.py
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import json
|
| 5 |
+
import re
|
| 6 |
+
from openai import OpenAI
|
| 7 |
+
import traceback
|
| 8 |
+
import logging
|
| 9 |
+
|
| 10 |
+
# Get logger
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
| 14 |
+
|
| 15 |
+
def format_message(role, content):
|
| 16 |
+
"""Format message for chatbot history."""
|
| 17 |
+
return {"role": role, "content": content}
|
| 18 |
+
|
| 19 |
+
def initialize_graph_prompt(graph_data):
|
| 20 |
+
"""Initialize the conversation with detailed node, edge, linker, and text information."""
|
| 21 |
+
summary_info = (
|
| 22 |
+
f"Symbols (represented as nodes): {graph_data['summary']['symbol_count']}, "
|
| 23 |
+
f"Texts: {graph_data['summary']['text_count']}, "
|
| 24 |
+
f"Lines: {graph_data['summary']['line_count']}, "
|
| 25 |
+
f"Linkers: {graph_data['summary']['linker_count']}, "
|
| 26 |
+
f"Edges: {graph_data['summary']['edge_count']}."
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
# Prepare detailed node (symbol) data
|
| 30 |
+
node_details = "Nodes (symbols) in the graph include the following details:\n"
|
| 31 |
+
for symbol in graph_data["detailed_results"]["symbols"]:
|
| 32 |
+
node_details += (
|
| 33 |
+
f"Node ID: {symbol['symbol_id']}, Class ID: {symbol['class_id']}, "
|
| 34 |
+
f"Category: {symbol['category']}, Type: {symbol['type']}, "
|
| 35 |
+
f"Label: {symbol['label']}, Confidence: {symbol['confidence']}\n"
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
# Prepare edge data
|
| 39 |
+
edge_details = "Edges in the graph showing connections between nodes are as follows:\n"
|
| 40 |
+
for edge in graph_data["detailed_results"].get("edges", []):
|
| 41 |
+
edge_details += (
|
| 42 |
+
f"Edge ID: {edge['edge_id']}, From Node: {edge['symbol_1_id']}, "
|
| 43 |
+
f"To Node: {edge['symbol_2_id']}, Type: {edge.get('type', 'unknown')}\n"
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# Prepare linker data
|
| 48 |
+
linker_details = "Linkers in the diagram are as follows:\n"
|
| 49 |
+
for linker in graph_data["detailed_results"].get("linkers", []):
|
| 50 |
+
linker_details += (
|
| 51 |
+
f"Symbol ID: {linker['symbol_id']}, Associated Text IDs: {linker.get('text_ids', [])}, "
|
| 52 |
+
f"Associated Edge IDs: {linker.get('edge_ids', [])}, Position: {linker.get('bbox', 'unknown')}\n"
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
# Prepare text (tag) data
|
| 57 |
+
text_details = "Text elements with associated tags in the diagram are as follows:\n"
|
| 58 |
+
for text in graph_data["detailed_results"].get("texts", []):
|
| 59 |
+
text_details += (
|
| 60 |
+
f"Text ID: {text['text_id']}, Content: {text['content']}, "
|
| 61 |
+
f"Confidence: {text['confidence']}, Position: {text['bbox']}\n"
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
initial_prompt = (
|
| 66 |
+
"You have access to a knowledge graph generated from a P&ID diagram. "
|
| 67 |
+
f"The summary information includes:\n{summary_info}\n\n"
|
| 68 |
+
"The detailed information about each node (symbol) in the graph is as follows:\n"
|
| 69 |
+
f"{node_details}\n"
|
| 70 |
+
"The edges connecting these nodes are as follows:\n"
|
| 71 |
+
f"{edge_details}\n"
|
| 72 |
+
"The linkers in the diagram are as follows:\n"
|
| 73 |
+
f"{linker_details}\n"
|
| 74 |
+
"The text elements and their tags in the diagram are as follows:\n"
|
| 75 |
+
f"{text_details}\n"
|
| 76 |
+
"Answer questions about specific nodes, edges, types, labels, categories, linkers, or text tags using this information."
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
return initial_prompt
|
| 80 |
+
|
| 81 |
+
def get_assistant_response(user_message, json_path):
    """Generate a response to *user_message* from the aggregated P&ID data.

    Simple counting questions (valves, pumps, equipment) are answered with
    fast rule-based logic straight from the JSON file; every other question
    is delegated to the OpenAI chat model with the full knowledge-graph
    context built by ``initialize_graph_prompt``.

    Args:
        user_message: The user's question as free text.
        json_path: Path to the aggregated detections JSON file.

    Returns:
        A plain-text answer, or an apology string if anything fails.
    """
    def count_symbols_by_keyword(data, keyword):
        # Count symbols whose class name contains the keyword (case-insensitive).
        # Shared by the valve and pump branches, which previously duplicated
        # this expression inline.
        return sum(
            1
            for symbol in data.get('symbols', [])
            if 'class' in symbol and keyword in symbol['class'].lower()
        )

    try:
        # Load the aggregated data produced by the detection pipeline.
        with open(json_path, 'r') as f:
            data = json.load(f)

        # Process the user's question
        question = user_message.lower()

        # Use rule-based responses for specific questions
        if "valve" in question or "valves" in question:
            return f"I found {count_symbols_by_keyword(data, 'valve')} valves in this P&ID."

        elif "pump" in question or "pumps" in question:
            return f"I found {count_symbols_by_keyword(data, 'pump')} pumps in this P&ID."

        elif "equipment" in question or "components" in question:
            # Tally each distinct symbol class (insertion order preserved).
            equipment_types = {}
            for symbol in data.get('symbols', []):
                if 'class' in symbol:
                    eq_type = symbol['class']
                    equipment_types[eq_type] = equipment_types.get(eq_type, 0) + 1

            response = "Here's a summary of the equipment I found:\n"
            for eq_type, count in equipment_types.items():
                response += f"- {eq_type}: {count}\n"
            return response

        # For other questions, use OpenAI
        else:
            # Prepare the conversation context with summary + full details.
            graph_data = {
                "summary": {
                    "symbol_count": len(data.get('symbols', [])),
                    "text_count": len(data.get('texts', [])),
                    "line_count": len(data.get('lines', [])),
                    "edge_count": len(data.get('edges', [])),
                },
                "detailed_results": data
            }

            initial_prompt = initialize_graph_prompt(graph_data)
            conversation = [
                {"role": "system", "content": initial_prompt},
                {"role": "user", "content": user_message}
            ]

            response = client.chat.completions.create(
                model="gpt-4-turbo",
                messages=conversation
            )
            return response.choices[0].message.content

    except Exception as e:
        # Never propagate: the chat UI expects a string either way.
        logger.error(f"Error in get_assistant_response: {str(e)}")
        logger.error(traceback.format_exc())
        return "I apologize, but I encountered an error analyzing the P&ID data. Please try asking a different question."
|
| 143 |
+
|
| 144 |
+
# Testing and Usage block
if __name__ == "__main__":
    # Load the knowledge graph data from JSON file
    json_file_path = "results/0_aggregated_detections.json"
    try:
        with open(json_file_path, 'r') as file:
            graph_data = json.load(file)
    except FileNotFoundError:
        print(f"Error: File not found at {json_file_path}")
        graph_data = None
    except json.JSONDecodeError:
        print("Error: Failed to decode JSON. Please check the file format.")
        graph_data = None

    # Initialize conversation history with assistant's welcome message
    history = [format_message("assistant", "Hello! I am ready to answer your questions about the P&ID knowledge graph. The graph includes nodes (symbols), edges, linkers, and text tags, and I have detailed information available about each. Please ask any questions related to these elements and their connections.")]

    # Print the assistant's welcome message
    print("Assistant:", history[0]["content"])

    # Individual Testing Options
    if graph_data:
        # Option 1: Test the graph prompt initialization
        print("\n--- Test: Graph Prompt Initialization ---")
        initial_prompt = initialize_graph_prompt(graph_data)
        print(initial_prompt)

        # Option 2: Simulate a conversation with a test question.
        # BUG FIX: get_assistant_response returns a single string, not a
        # generator. The previous `for response in get_assistant_response(...)`
        # iterated character-by-character, printing one letter per loop turn
        # and appending hundreds of single-character history entries.
        print("\n--- Test: Simulate Conversation ---")
        test_question = "Can you tell me about the connections between the nodes?"
        history.append(format_message("user", test_question))

        print(f"\nUser: {test_question}")
        response = get_assistant_response(test_question, json_file_path)
        print("Assistant:", response)
        history.append(format_message("assistant", response))

        # Option 3: Manually input questions for interactive testing
        while True:
            user_question = input("\nYou: ")
            if user_question.lower() in ["exit", "quit"]:
                print("Exiting chat. Goodbye!")
                break

            history.append(format_message("user", user_question))
            response = get_assistant_response(user_question, json_file_path)
            print("Assistant:", response)
            history.append(format_message("assistant", response))
    else:
        print("Unable to load graph data. Please check the file path and format.")
|
common.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from typing import List, Dict, Optional, Tuple, Union
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
from detection_schema import Line
|
| 6 |
+
|
| 7 |
+
@dataclass
class DetectionResult:
    """Standard result envelope returned by the detector stages.

    Carries success/failure state plus the optional artifacts a detector
    may produce (annotated image, output file paths, timing).
    """
    success: bool  # True when the detection step completed without error
    error: Optional[str] = None  # human-readable failure reason when success is False
    annotated_image: Optional[np.ndarray] = None  # annotated copy of the input image, if produced
    processing_time: float = 0.0  # wall-clock seconds spent in the detector
    json_path: Optional[str] = None  # where the detection JSON was persisted, if saved
    image_path: Optional[str] = None  # where the annotated image was persisted, if saved
|
config.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass, field
|
| 2 |
+
from typing import List, Dict, Optional, Tuple, Union
|
| 3 |
+
import numpy as np
|
| 4 |
+
from base import BaseConfig
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@dataclass
class ImageConfig(BaseConfig):
    """Configuration for global image-related settings"""

    # Region of interest as a 4-element array [500, 500, 5300, 4000];
    # presumably [xmin, ymin, xmax, ymax] pixel coordinates — TODO confirm ordering
    roi: Optional[np.ndarray] = field(default_factory=lambda: np.array([500, 500, 5300, 4000]))
    # JSON file holding text/symbol bounding boxes used for masking
    mask_json_path: str = "./text_and_symbol_bboxes.json"
    # Whether annotated output images should be written to disk
    save_annotations: bool = True
    # Drawing parameters: colour tuples (presumably BGR for OpenCV — verify),
    # line thickness in pixels, and font scale for labels
    annotation_style: Dict = field(default_factory=lambda: {
        'bbox_color': (255, 0, 0),
        'line_color': (0, 255, 0),
        'text_color': (0, 0, 255),
        'thickness': 2,
        'font_scale': 0.6
    })
|
| 21 |
+
|
| 22 |
+
@dataclass
class SymbolConfig:
    """Configuration for Symbol Detection"""
    # Paths to the detector weight files (.pt — presumably PyTorch/YOLO; verify)
    model_paths: Dict[str, str] = field(default_factory=lambda: {
        "model1": "models/Intui_SDM_41.pt",
        "model2": "models/Intui_SDM_20.pt"
    })
    # Multiple thresholds to test
    confidence_thresholds: List[float] = field(default_factory=lambda: [0.1, 0.3, 0.5, 0.7, 0.9])
    # Whether to run image preprocessing before inference
    apply_preprocessing: bool = False
    # Downscale large inputs before detection
    resize_image: bool = True
    # Longest image side (px) after resizing
    max_dimension: int = 1280
    # IoU threshold for suppressing/merging overlapping detections
    iou_threshold: float = 0.5
    # Optional mapping from class_id to SymbolType
    symbol_type_mapping: Dict[str, str] = field(default_factory=lambda: {
        "valve": "VALVE",
        "pump": "PUMP",
        "sensor": "SENSOR"
    })
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
@dataclass
class TagConfig(BaseConfig):
    """Configuration for Tag Detection with OCR"""
    # Minimum detection confidence to keep a tag
    confidence_threshold: float = 0.5
    # IoU threshold for suppressing overlapping tag boxes
    iou_threshold: float = 0.4
    # OCR backends to run; presumably tried in this order — verify at call site
    ocr_engines: List[str] = field(default_factory=lambda: ['tesseract', 'easyocr', 'doctr'])
    # Regexes classifying recognised text into P&ID tag categories
    text_patterns: Dict[str, str] = field(default_factory=lambda: {
        'Line_Number': r"\d{1,5}-[A-Z]{2,4}-\d{1,3}",
        'Equipment_Tag': r"[A-Z]{1,3}-[A-Z0-9]{1,4}-\d{1,3}",
        'Instrument_Tag': r"\d{2,3}-[A-Z]{2,4}-\d{2,3}",
        'Valve_Number': r"[A-Z]{1,2}-\d{3}",
        'Pipe_Size': r"\d{1,2}\"",
        'Flow_Direction': r"FROM|TO",
        'Service_Description': r"STEAM|WATER|AIR|GAS|DRAIN",
        'Process_Instrument': r"\d{2,3}(?:-[A-Z]{2,3})?-\d{2,3}|[A-Z]{2,3}-\d{2,3}",
        'Nozzle': r"N[0-9]{1,2}|MH",
        'Pipe_Connector': r"[0-9]{1,5}|[A-Z]{1,2}[0-9]{2,5}"
    })
    # Raw CLI options passed to Tesseract (OEM 3 = default engine, PSM 11 = sparse text)
    tesseract_config: str = r'--oem 3 --psm 11'
    # Keyword arguments forwarded to EasyOCR — TODO confirm the call site
    easyocr_params: Dict = field(default_factory=lambda: {
        'paragraph': False,
        'height_ths': 2.0,
        'width_ths': 2.0,
        'contrast_ths': 0.2
    })
|
| 68 |
+
|
| 69 |
+
@dataclass
class LineConfig(BaseConfig):
    """Configuration for Line Detection"""

    # Distance threshold in pixels — presumably for snapping/merging nearby
    # line segments; TODO confirm against the line detector's usage
    threshold_distance: float = 10.0
    # Factor by which regions are expanded — presumably grows mask boxes
    # before line extraction; TODO confirm
    expansion_factor: float = 1.1
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
@dataclass
class PointConfig(BaseConfig):
    """Configuration for Point Detection"""

    # Distance threshold in pixels — presumably the maximum gap for matching
    # detected points; TODO confirm against the point detector's usage
    threshold_distance: float = 10.0
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
@dataclass
class JunctionConfig(BaseConfig):
    """Configuration for Junction Detection.

    NOTE(review): this class previously contained an unresolved git merge
    conflict (``<<<<<<< HEAD`` ... ``>>>>>>> temp/test-integration``), which
    is a syntax error and broke the whole module. Resolved by keeping the
    HEAD side — the complete dataclass — since the other side contained only
    a commented-out legacy version of the same config.
    """

    window_size: int = 21  # side length (px) of the window inspected around a candidate junction
    radius: int = 5  # sampling radius (px) around the junction centre
    angle_threshold_lb: float = 15.0  # lower angle bound (degrees) for accepting a junction
    angle_threshold_ub: float = 75.0  # upper angle bound (degrees) for accepting a junction
    expansion_factor: float = 1.1  # factor by which detected regions are grown before masking
    mask_json_paths: List[str] = field(default_factory=list)  # JSON files with bboxes to mask out

    # Region of interest as [300, 500, 7500, 7000]; presumably
    # [xmin, ymin, xmax, ymax] pixels — TODO confirm ordering
    roi: Optional[np.ndarray] = field(default_factory=lambda: np.array([300, 500, 7500, 7000]))
    save_annotations: bool = True  # write annotated debug images to disk
    annotation_style: Dict = None  # filled in by __post_init__ when not supplied

    def __post_init__(self):
        # A mutable default cannot live on the field itself; supply the
        # standard drawing style when the caller did not pass one.
        self.annotation_style = self.annotation_style or {
            'bbox_color': (255, 0, 0),
            'line_color': (0, 255, 0),
            'text_color': (0, 0, 255),
            'thickness': 2,
            'font_scale': 0.6
        }
|
data_aggregation_ai.py
ADDED
|
@@ -0,0 +1,455 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
import json
|
| 3 |
+
import logging
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
from typing import List, Dict, Optional, Tuple
|
| 6 |
+
from storage import StorageFactory
|
| 7 |
+
import uuid
|
| 8 |
+
import traceback
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
class DataAggregator:
    """Merges symbol, text, and line detection results into one graph dataset."""

    def __init__(self, storage=None):
        """Create an aggregator.

        Args:
            storage: Backend used to read detection files; defaults to the
                backend returned by ``StorageFactory.get_storage()``.
        """
        self.storage = storage or StorageFactory.get_storage()
        self.logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
def _parse_line_data(self, lines_data: dict) -> List[dict]:
|
| 18 |
+
"""Parse line detection data with coordinate validation"""
|
| 19 |
+
parsed_lines = []
|
| 20 |
+
|
| 21 |
+
for line in lines_data.get("lines", []):
|
| 22 |
+
try:
|
| 23 |
+
# Extract and validate line coordinates
|
| 24 |
+
start_coords = line["start"]["coords"]
|
| 25 |
+
end_coords = line["end"]["coords"]
|
| 26 |
+
bbox = line["bbox"]
|
| 27 |
+
|
| 28 |
+
# Validate coordinates
|
| 29 |
+
if not (self._is_valid_point(start_coords) and
|
| 30 |
+
self._is_valid_point(end_coords) and
|
| 31 |
+
self._is_valid_bbox(bbox)):
|
| 32 |
+
self.logger.warning(f"Invalid coordinates in line: {line['id']}")
|
| 33 |
+
continue
|
| 34 |
+
|
| 35 |
+
# Create parsed line with validated coordinates
|
| 36 |
+
parsed_line = {
|
| 37 |
+
"id": line["id"],
|
| 38 |
+
"start_point": {
|
| 39 |
+
"x": int(start_coords["x"]),
|
| 40 |
+
"y": int(start_coords["y"]),
|
| 41 |
+
"type": line["start"]["type"],
|
| 42 |
+
"confidence": line["start"]["confidence"]
|
| 43 |
+
},
|
| 44 |
+
"end_point": {
|
| 45 |
+
"x": int(end_coords["x"]),
|
| 46 |
+
"y": int(end_coords["y"]),
|
| 47 |
+
"type": line["end"]["type"],
|
| 48 |
+
"confidence": line["end"]["confidence"]
|
| 49 |
+
},
|
| 50 |
+
"bbox": {
|
| 51 |
+
"xmin": int(bbox["xmin"]),
|
| 52 |
+
"ymin": int(bbox["ymin"]),
|
| 53 |
+
"xmax": int(bbox["xmax"]),
|
| 54 |
+
"ymax": int(bbox["ymax"])
|
| 55 |
+
},
|
| 56 |
+
"style": line["style"],
|
| 57 |
+
"confidence": line["confidence"]
|
| 58 |
+
}
|
| 59 |
+
parsed_lines.append(parsed_line)
|
| 60 |
+
|
| 61 |
+
except Exception as e:
|
| 62 |
+
self.logger.error(f"Error parsing line {line.get('id')}: {str(e)}")
|
| 63 |
+
continue
|
| 64 |
+
|
| 65 |
+
return parsed_lines
|
| 66 |
+
|
| 67 |
+
def _is_valid_point(self, point: dict) -> bool:
|
| 68 |
+
"""Validate point coordinates"""
|
| 69 |
+
try:
|
| 70 |
+
x, y = point.get("x"), point.get("y")
|
| 71 |
+
return (isinstance(x, (int, float)) and
|
| 72 |
+
isinstance(y, (int, float)) and
|
| 73 |
+
0 <= x <= 10000 and 0 <= y <= 10000) # Adjust range as needed
|
| 74 |
+
except:
|
| 75 |
+
return False
|
| 76 |
+
|
| 77 |
+
def _is_valid_bbox(self, bbox: dict) -> bool:
|
| 78 |
+
"""Validate bbox coordinates"""
|
| 79 |
+
try:
|
| 80 |
+
xmin = bbox.get("xmin")
|
| 81 |
+
ymin = bbox.get("ymin")
|
| 82 |
+
xmax = bbox.get("xmax")
|
| 83 |
+
ymax = bbox.get("ymax")
|
| 84 |
+
|
| 85 |
+
return (isinstance(xmin, (int, float)) and
|
| 86 |
+
isinstance(ymin, (int, float)) and
|
| 87 |
+
isinstance(xmax, (int, float)) and
|
| 88 |
+
isinstance(ymax, (int, float)) and
|
| 89 |
+
xmin < xmax and ymin < ymax and
|
| 90 |
+
0 <= xmin <= 10000 and 0 <= ymin <= 10000 and
|
| 91 |
+
0 <= xmax <= 10000 and 0 <= ymax <= 10000)
|
| 92 |
+
except:
|
| 93 |
+
return False
|
| 94 |
+
|
| 95 |
+
def _create_graph_data(self, lines: List[dict], symbols: List[dict], texts: List[dict]) -> Tuple[List[dict], List[dict]]:
|
| 96 |
+
"""Create nodes and edges for the knowledge graph following the three-step process"""
|
| 97 |
+
nodes = []
|
| 98 |
+
edges = []
|
| 99 |
+
|
| 100 |
+
# Step 1: Create Object Nodes with their properties and center points
|
| 101 |
+
# 1a. Symbol Nodes
|
| 102 |
+
for symbol in symbols:
|
| 103 |
+
bbox = symbol["bbox"]
|
| 104 |
+
center_x = (bbox["xmin"] + bbox["xmax"]) / 2
|
| 105 |
+
center_y = (bbox["ymin"] + bbox["ymax"]) / 2
|
| 106 |
+
|
| 107 |
+
node = {
|
| 108 |
+
"id": symbol.get("id", str(uuid.uuid4())),
|
| 109 |
+
"type": "symbol",
|
| 110 |
+
"category": symbol.get("category", "unknown"),
|
| 111 |
+
"bbox": bbox,
|
| 112 |
+
"center": {"x": center_x, "y": center_y},
|
| 113 |
+
"confidence": symbol.get("confidence", 1.0),
|
| 114 |
+
"properties": {
|
| 115 |
+
"class": symbol.get("class", ""),
|
| 116 |
+
"equipment_type": symbol.get("equipment_type", ""),
|
| 117 |
+
"original_label": symbol.get("original_label", ""),
|
| 118 |
+
}
|
| 119 |
+
}
|
| 120 |
+
nodes.append(node)
|
| 121 |
+
|
| 122 |
+
# 1b. Text Nodes
|
| 123 |
+
for text in texts:
|
| 124 |
+
bbox = text["bbox"]
|
| 125 |
+
center_x = (bbox["xmin"] + bbox["xmax"]) / 2
|
| 126 |
+
center_y = (bbox["ymin"] + bbox["ymax"]) / 2
|
| 127 |
+
|
| 128 |
+
node = {
|
| 129 |
+
"id": text.get("id", str(uuid.uuid4())),
|
| 130 |
+
"type": "text",
|
| 131 |
+
"content": text.get("text", ""),
|
| 132 |
+
"bbox": bbox,
|
| 133 |
+
"center": {"x": center_x, "y": center_y},
|
| 134 |
+
"confidence": text.get("confidence", 1.0),
|
| 135 |
+
"properties": {
|
| 136 |
+
"font_size": text.get("font_size"),
|
| 137 |
+
"rotation": text.get("rotation", 0.0),
|
| 138 |
+
"text_type": text.get("text_type", "unknown")
|
| 139 |
+
}
|
| 140 |
+
}
|
| 141 |
+
nodes.append(node)
|
| 142 |
+
|
| 143 |
+
# Step 2: Create Junction Nodes (T/L connections)
|
| 144 |
+
junction_map = {} # To track junctions for edge creation
|
| 145 |
+
for line in lines:
|
| 146 |
+
# Check start point
|
| 147 |
+
if line["start_point"].get("type") in ["T", "L"]:
|
| 148 |
+
junction_id = str(uuid.uuid4())
|
| 149 |
+
junction_node = {
|
| 150 |
+
"id": junction_id,
|
| 151 |
+
"type": "junction",
|
| 152 |
+
"junction_type": line["start_point"]["type"],
|
| 153 |
+
"coords": {
|
| 154 |
+
"x": line["start_point"]["x"],
|
| 155 |
+
"y": line["start_point"]["y"]
|
| 156 |
+
},
|
| 157 |
+
"properties": {
|
| 158 |
+
"confidence": line["start_point"].get("confidence", 1.0)
|
| 159 |
+
}
|
| 160 |
+
}
|
| 161 |
+
nodes.append(junction_node)
|
| 162 |
+
junction_map[f"{line['start_point']['x']}_{line['start_point']['y']}"] = junction_id
|
| 163 |
+
|
| 164 |
+
# Check end point
|
| 165 |
+
if line["end_point"].get("type") in ["T", "L"]:
|
| 166 |
+
junction_id = str(uuid.uuid4())
|
| 167 |
+
junction_node = {
|
| 168 |
+
"id": junction_id,
|
| 169 |
+
"type": "junction",
|
| 170 |
+
"junction_type": line["end_point"]["type"],
|
| 171 |
+
"coords": {
|
| 172 |
+
"x": line["end_point"]["x"],
|
| 173 |
+
"y": line["end_point"]["y"]
|
| 174 |
+
},
|
| 175 |
+
"properties": {
|
| 176 |
+
"confidence": line["end_point"].get("confidence", 1.0)
|
| 177 |
+
}
|
| 178 |
+
}
|
| 179 |
+
nodes.append(junction_node)
|
| 180 |
+
junction_map[f"{line['end_point']['x']}_{line['end_point']['y']}"] = junction_id
|
| 181 |
+
|
| 182 |
+
# Step 3: Create Edges with connection points and topology
|
| 183 |
+
# 3a. Line-Junction Connections
|
| 184 |
+
for line in lines:
|
| 185 |
+
line_id = line.get("id", str(uuid.uuid4()))
|
| 186 |
+
start_key = f"{line['start_point']['x']}_{line['start_point']['y']}"
|
| 187 |
+
end_key = f"{line['end_point']['x']}_{line['end_point']['y']}"
|
| 188 |
+
|
| 189 |
+
# Create edge for line itself
|
| 190 |
+
edge = {
|
| 191 |
+
"id": line_id,
|
| 192 |
+
"type": "line",
|
| 193 |
+
"source": junction_map.get(start_key, str(uuid.uuid4())),
|
| 194 |
+
"target": junction_map.get(end_key, str(uuid.uuid4())),
|
| 195 |
+
"properties": {
|
| 196 |
+
"style": line["style"],
|
| 197 |
+
"confidence": line.get("confidence", 1.0),
|
| 198 |
+
"connection_points": {
|
| 199 |
+
"start": {"x": line["start_point"]["x"], "y": line["start_point"]["y"]},
|
| 200 |
+
"end": {"x": line["end_point"]["x"], "y": line["end_point"]["y"]}
|
| 201 |
+
},
|
| 202 |
+
"bbox": line["bbox"]
|
| 203 |
+
}
|
| 204 |
+
}
|
| 205 |
+
edges.append(edge)
|
| 206 |
+
|
| 207 |
+
# 3b. Symbol-Line Connections (based on spatial proximity)
|
| 208 |
+
for symbol in symbols:
|
| 209 |
+
symbol_center = {
|
| 210 |
+
"x": (symbol["bbox"]["xmin"] + symbol["bbox"]["xmax"]) / 2,
|
| 211 |
+
"y": (symbol["bbox"]["ymin"] + symbol["bbox"]["ymax"]) / 2
|
| 212 |
+
}
|
| 213 |
+
|
| 214 |
+
# Find connected lines based on proximity to endpoints
|
| 215 |
+
for line in lines:
|
| 216 |
+
# Check if line endpoints are near symbol center
|
| 217 |
+
for point_type in ["start_point", "end_point"]:
|
| 218 |
+
point = line[point_type]
|
| 219 |
+
dist = ((point["x"] - symbol_center["x"])**2 +
|
| 220 |
+
(point["y"] - symbol_center["y"])**2)**0.5
|
| 221 |
+
|
| 222 |
+
if dist < 50: # Threshold for connection, adjust as needed
|
| 223 |
+
edge = {
|
| 224 |
+
"id": str(uuid.uuid4()),
|
| 225 |
+
"type": "symbol_line_connection",
|
| 226 |
+
"source": symbol["id"],
|
| 227 |
+
"target": line["id"],
|
| 228 |
+
"properties": {
|
| 229 |
+
"connection_point": {"x": point["x"], "y": point["y"]},
|
| 230 |
+
"connection_type": point_type,
|
| 231 |
+
"distance": dist
|
| 232 |
+
}
|
| 233 |
+
}
|
| 234 |
+
edges.append(edge)
|
| 235 |
+
|
| 236 |
+
# 3c. Symbol-Text Associations (based on proximity and containment)
|
| 237 |
+
for text in texts:
|
| 238 |
+
text_center = {
|
| 239 |
+
"x": (text["bbox"]["xmin"] + text["bbox"]["xmax"]) / 2,
|
| 240 |
+
"y": (text["bbox"]["ymin"] + text["bbox"]["ymax"]) / 2
|
| 241 |
+
}
|
| 242 |
+
|
| 243 |
+
for symbol in symbols:
|
| 244 |
+
# Check if text is near or contained within symbol
|
| 245 |
+
if (text_center["x"] >= symbol["bbox"]["xmin"] - 20 and
|
| 246 |
+
text_center["x"] <= symbol["bbox"]["xmax"] + 20 and
|
| 247 |
+
text_center["y"] >= symbol["bbox"]["ymin"] - 20 and
|
| 248 |
+
text_center["y"] <= symbol["bbox"]["ymax"] + 20):
|
| 249 |
+
|
| 250 |
+
edge = {
|
| 251 |
+
"id": str(uuid.uuid4()),
|
| 252 |
+
"type": "symbol_text_association",
|
| 253 |
+
"source": symbol["id"],
|
| 254 |
+
"target": text["id"],
|
| 255 |
+
"properties": {
|
| 256 |
+
"association_type": "label",
|
| 257 |
+
"confidence": min(symbol.get("confidence", 1.0),
|
| 258 |
+
text.get("confidence", 1.0))
|
| 259 |
+
}
|
| 260 |
+
}
|
| 261 |
+
edges.append(edge)
|
| 262 |
+
|
| 263 |
+
# 3d. Line-Text Associations (based on proximity and alignment)
|
| 264 |
+
for text in texts:
|
| 265 |
+
text_center = {
|
| 266 |
+
"x": (text["bbox"]["xmin"] + text["bbox"]["xmax"]) / 2,
|
| 267 |
+
"y": (text["bbox"]["ymin"] + text["bbox"]["ymax"]) / 2
|
| 268 |
+
}
|
| 269 |
+
text_bbox = text["bbox"]
|
| 270 |
+
|
| 271 |
+
for line in lines:
|
| 272 |
+
line_bbox = line["bbox"]
|
| 273 |
+
line_center = {
|
| 274 |
+
"x": (line_bbox["xmin"] + line_bbox["xmax"]) / 2,
|
| 275 |
+
"y": (line_bbox["ymin"] + line_bbox["ymax"]) / 2
|
| 276 |
+
}
|
| 277 |
+
|
| 278 |
+
# Check if text is near the line (using both center and bbox)
|
| 279 |
+
is_nearby_horizontal = (
|
| 280 |
+
abs(text_center["y"] - line_center["y"]) < 30 and # Vertical proximity
|
| 281 |
+
text_bbox["xmin"] <= line_bbox["xmax"] and
|
| 282 |
+
text_bbox["xmax"] >= line_bbox["xmin"]
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
is_nearby_vertical = (
|
| 286 |
+
abs(text_center["x"] - line_center["x"]) < 30 and # Horizontal proximity
|
| 287 |
+
text_bbox["ymin"] <= line_bbox["ymax"] and
|
| 288 |
+
text_bbox["ymax"] >= line_bbox["ymin"]
|
| 289 |
+
)
|
| 290 |
+
|
| 291 |
+
# Determine text type and position relative to line
|
| 292 |
+
if is_nearby_horizontal or is_nearby_vertical:
|
| 293 |
+
text_type = text.get("text_type", "unknown").lower()
|
| 294 |
+
|
| 295 |
+
# Classify the text based on content and position
|
| 296 |
+
if any(pattern in text.get("text", "").upper()
|
| 297 |
+
for pattern in ["-", "LINE", "PIPE"]):
|
| 298 |
+
association_type = "line_id"
|
| 299 |
+
else:
|
| 300 |
+
association_type = "description"
|
| 301 |
+
|
| 302 |
+
edge = {
|
| 303 |
+
"id": str(uuid.uuid4()),
|
| 304 |
+
"type": "line_text_association",
|
| 305 |
+
"source": line["id"],
|
| 306 |
+
"target": text["id"],
|
| 307 |
+
"properties": {
|
| 308 |
+
"association_type": association_type,
|
| 309 |
+
"relative_position": "horizontal" if is_nearby_horizontal else "vertical",
|
| 310 |
+
"confidence": min(line.get("confidence", 1.0),
|
| 311 |
+
text.get("confidence", 1.0)),
|
| 312 |
+
"distance": abs(text_center["y"] - line_center["y"]) if is_nearby_horizontal
|
| 313 |
+
else abs(text_center["x"] - line_center["x"])
|
| 314 |
+
}
|
| 315 |
+
}
|
| 316 |
+
edges.append(edge)
|
| 317 |
+
|
| 318 |
+
return nodes, edges
|
| 319 |
+
|
| 320 |
+
def _validate_coordinates(self, data, data_type):
|
| 321 |
+
"""Validate coordinates in the data"""
|
| 322 |
+
if not data:
|
| 323 |
+
return False
|
| 324 |
+
|
| 325 |
+
try:
|
| 326 |
+
if data_type == 'line':
|
| 327 |
+
# Check start and end points
|
| 328 |
+
start = data.get('start_point', {})
|
| 329 |
+
end = data.get('end_point', {})
|
| 330 |
+
bbox = data.get('bbox', {})
|
| 331 |
+
|
| 332 |
+
required_fields = ['x', 'y', 'type']
|
| 333 |
+
if not all(field in start for field in required_fields):
|
| 334 |
+
self.logger.warning(f"Missing required fields in start_point: {start}")
|
| 335 |
+
return False
|
| 336 |
+
if not all(field in end for field in required_fields):
|
| 337 |
+
self.logger.warning(f"Missing required fields in end_point: {end}")
|
| 338 |
+
return False
|
| 339 |
+
|
| 340 |
+
# Validate bbox coordinates
|
| 341 |
+
if not all(key in bbox for key in ['xmin', 'ymin', 'xmax', 'ymax']):
|
| 342 |
+
self.logger.warning(f"Invalid bbox format: {bbox}")
|
| 343 |
+
return False
|
| 344 |
+
|
| 345 |
+
# Check coordinate consistency
|
| 346 |
+
if bbox['xmin'] > bbox['xmax'] or bbox['ymin'] > bbox['ymax']:
|
| 347 |
+
self.logger.warning(f"Invalid bbox coordinates: {bbox}")
|
| 348 |
+
return False
|
| 349 |
+
|
| 350 |
+
elif data_type in ['symbol', 'text']:
|
| 351 |
+
bbox = data.get('bbox', {})
|
| 352 |
+
if not all(key in bbox for key in ['xmin', 'ymin', 'xmax', 'ymax']):
|
| 353 |
+
self.logger.warning(f"Invalid {data_type} bbox format: {bbox}")
|
| 354 |
+
return False
|
| 355 |
+
|
| 356 |
+
# Check coordinate consistency
|
| 357 |
+
if bbox['xmin'] > bbox['xmax'] or bbox['ymin'] > bbox['ymax']:
|
| 358 |
+
self.logger.warning(f"Invalid {data_type} bbox coordinates: {bbox}")
|
| 359 |
+
return False
|
| 360 |
+
|
| 361 |
+
return True
|
| 362 |
+
|
| 363 |
+
except Exception as e:
|
| 364 |
+
self.logger.error(f"Validation error for {data_type}: {str(e)}")
|
| 365 |
+
return False
|
| 366 |
+
|
| 367 |
+
def aggregate_data(self, symbols_path: str, texts_path: str, lines_path: str) -> dict:
|
| 368 |
+
"""Aggregate detection results and create graph structure"""
|
| 369 |
+
try:
|
| 370 |
+
# Load line detection results
|
| 371 |
+
lines_data = json.loads(self.storage.load_file(lines_path).decode('utf-8'))
|
| 372 |
+
lines = self._parse_line_data(lines_data)
|
| 373 |
+
|
| 374 |
+
# Load symbol detections
|
| 375 |
+
symbols = []
|
| 376 |
+
if symbols_path and Path(symbols_path).exists():
|
| 377 |
+
symbols_data = json.loads(self.storage.load_file(symbols_path).decode('utf-8'))
|
| 378 |
+
symbols = symbols_data.get("symbols", [])
|
| 379 |
+
|
| 380 |
+
# Load text detections
|
| 381 |
+
texts = []
|
| 382 |
+
if texts_path and Path(texts_path).exists():
|
| 383 |
+
texts_data = json.loads(self.storage.load_file(texts_path).decode('utf-8'))
|
| 384 |
+
texts = texts_data.get("texts", [])
|
| 385 |
+
|
| 386 |
+
# Create graph data
|
| 387 |
+
nodes, edges = self._create_graph_data(lines, symbols, texts)
|
| 388 |
+
|
| 389 |
+
# Combine all detections
|
| 390 |
+
aggregated_data = {
|
| 391 |
+
"lines": lines,
|
| 392 |
+
"symbols": symbols,
|
| 393 |
+
"texts": texts,
|
| 394 |
+
"nodes": nodes,
|
| 395 |
+
"edges": edges,
|
| 396 |
+
"metadata": {
|
| 397 |
+
"timestamp": datetime.now().isoformat(),
|
| 398 |
+
"version": "2.0"
|
| 399 |
+
}
|
| 400 |
+
}
|
| 401 |
+
|
| 402 |
+
return aggregated_data
|
| 403 |
+
|
| 404 |
+
except Exception as e:
|
| 405 |
+
logger.error(f"Error during aggregation: {str(e)}")
|
| 406 |
+
raise
|
| 407 |
+
|
| 408 |
+
if __name__ == "__main__":
    import os
    from pprint import pprint

    # Smoke-test the aggregator against a local results folder.
    aggregator = DataAggregator()

    # Input paths (adjust to match your results folder layout).
    results_dir = "results/"
    symbols_path = os.path.join(results_dir, "0_text_detected_symbols.json")
    texts_path = os.path.join(results_dir, "0_text_detected_texts.json")
    lines_path = os.path.join(results_dir, "0_text_detected_lines.json")

    try:
        aggregated_data = aggregator.aggregate_data(
            symbols_path=symbols_path,
            texts_path=texts_path,
            lines_path=lines_path,
        )

        # Persist the combined result next to the inputs.
        output_path = os.path.join(results_dir, "0_aggregated_test.json")
        with open(output_path, 'w') as f:
            json.dump(aggregated_data, f, indent=2)

        # Summary statistics for a quick sanity check.
        print("\nAggregation Results:")
        for label, key in (("Symbols", "symbols"), ("Texts", "texts"),
                           ("Lines", "lines"), ("Nodes", "nodes"),
                           ("Edges", "edges")):
            print(f"Number of {label}: {len(aggregated_data[key])}")

        # Show one example of each graph element type, when present.
        print("\nSample Node:")
        if aggregated_data['nodes']:
            pprint(aggregated_data['nodes'][0])

        print("\nSample Edge:")
        if aggregated_data['edges']:
            pprint(aggregated_data['edges'][0])

        print(f"\nAggregated data saved to: {output_path}")

    except Exception as e:
        print(f"Error during testing: {str(e)}")
        traceback.print_exc()
|
detection_schema.py
ADDED
|
@@ -0,0 +1,481 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass, field
|
| 2 |
+
from typing import List, Optional, Tuple, Dict
|
| 3 |
+
import uuid
|
| 4 |
+
from enum import Enum
|
| 5 |
+
import json
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
# ======================== Point ======================== #
|
| 9 |
+
class ConnectionType(Enum):
|
| 10 |
+
SOLID = "solid"
|
| 11 |
+
DASHED = "dashed"
|
| 12 |
+
PHANTOM = "phantom"
|
| 13 |
+
|
| 14 |
+
@dataclass
|
| 15 |
+
class Coordinates:
|
| 16 |
+
x: int
|
| 17 |
+
y: int
|
| 18 |
+
|
| 19 |
+
@dataclass
|
| 20 |
+
class BBox:
|
| 21 |
+
xmin: int
|
| 22 |
+
ymin: int
|
| 23 |
+
xmax: int
|
| 24 |
+
ymax: int
|
| 25 |
+
|
| 26 |
+
def width(self) -> int:
|
| 27 |
+
return self.xmax - self.xmin
|
| 28 |
+
|
| 29 |
+
def height(self) -> int:
|
| 30 |
+
return self.ymax - self.ymin
|
| 31 |
+
|
| 32 |
+
class JunctionType(str, Enum):
|
| 33 |
+
T = "T"
|
| 34 |
+
L = "L"
|
| 35 |
+
END = "END"
|
| 36 |
+
|
| 37 |
+
@dataclass
|
| 38 |
+
class Point:
|
| 39 |
+
coords: Coordinates
|
| 40 |
+
bbox: BBox
|
| 41 |
+
type: JunctionType
|
| 42 |
+
confidence: float = 1.0
|
| 43 |
+
id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# # ======================== Symbol ======================== #
|
| 47 |
+
# class SymbolType(Enum):
|
| 48 |
+
# VALVE = "valve"
|
| 49 |
+
# PUMP = "pump"
|
| 50 |
+
# SENSOR = "sensor"
|
| 51 |
+
# # Add others as needed
|
| 52 |
+
#
|
| 53 |
+
class ValveSubtype(Enum):
|
| 54 |
+
GATE = "gate"
|
| 55 |
+
GLOBE = "globe"
|
| 56 |
+
BUTTERFLY = "butterfly"
|
| 57 |
+
#
|
| 58 |
+
# @dataclass
|
| 59 |
+
# class Symbol:
|
| 60 |
+
# symbol_type: SymbolType
|
| 61 |
+
# bbox: BBox
|
| 62 |
+
# center: Coordinates
|
| 63 |
+
# connections: List[Point] = field(default_factory=list)
|
| 64 |
+
# subtype: Optional[ValveSubtype] = None
|
| 65 |
+
# id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
| 66 |
+
# confidence: float = 0.95
|
| 67 |
+
# model_metadata: dict = field(default_factory=dict)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# ======================== Symbol ======================== #
|
| 71 |
+
class SymbolType(Enum):
|
| 72 |
+
VALVE = "valve"
|
| 73 |
+
PUMP = "pump"
|
| 74 |
+
SENSOR = "sensor"
|
| 75 |
+
OTHER = "other" # Added to handle unknown categories
|
| 76 |
+
|
| 77 |
+
@dataclass
|
| 78 |
+
class Symbol:
|
| 79 |
+
center: Coordinates
|
| 80 |
+
symbol_type: SymbolType = field(default=SymbolType.OTHER)
|
| 81 |
+
id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
| 82 |
+
class_id: int = -1
|
| 83 |
+
original_label: str = ""
|
| 84 |
+
category: str = "" # e.g., "inst"
|
| 85 |
+
type: str = "" # e.g., "ind"
|
| 86 |
+
label: str = "" # e.g., "Solenoid_actuator"
|
| 87 |
+
bbox: BBox = None
|
| 88 |
+
confidence: float = 0.95
|
| 89 |
+
model_source: str = "" # e.g., "model2"
|
| 90 |
+
connections: List[Point] = field(default_factory=list)
|
| 91 |
+
subtype: Optional[ValveSubtype] = None
|
| 92 |
+
model_metadata: dict = field(default_factory=dict)
|
| 93 |
+
|
| 94 |
+
def __post_init__(self):
|
| 95 |
+
"""
|
| 96 |
+
Handle any additional post-processing after initialization.
|
| 97 |
+
"""
|
| 98 |
+
# Ensure bbox is a BBox object
|
| 99 |
+
if isinstance(self.bbox, list) and len(self.bbox) == 4:
|
| 100 |
+
self.bbox = BBox(*self.bbox)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
# ======================== Line ======================== #
|
| 104 |
+
@dataclass
|
| 105 |
+
class LineStyle:
|
| 106 |
+
connection_type: ConnectionType
|
| 107 |
+
stroke_width: int = 2
|
| 108 |
+
color: str = "#000000" # CSS-style colors
|
| 109 |
+
|
| 110 |
+
@dataclass
|
| 111 |
+
class Line:
|
| 112 |
+
start: Point
|
| 113 |
+
end: Point
|
| 114 |
+
bbox: BBox
|
| 115 |
+
id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
| 116 |
+
style: LineStyle = field(default_factory=lambda: LineStyle(ConnectionType.SOLID))
|
| 117 |
+
confidence: float = 0.90
|
| 118 |
+
topological_links: List[str] = field(default_factory=list) # Linked symbols/junctions
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
# ======================== Junction ======================== #
|
| 122 |
+
class JunctionType(str, Enum):
|
| 123 |
+
T = "T"
|
| 124 |
+
L = "L"
|
| 125 |
+
END = "END"
|
| 126 |
+
|
| 127 |
+
@dataclass
|
| 128 |
+
class JunctionProperties:
|
| 129 |
+
flow_direction: Optional[str] = None # "in", "out"
|
| 130 |
+
pressure: Optional[float] = None # kPa
|
| 131 |
+
|
| 132 |
+
@dataclass
|
| 133 |
+
class Junction:
|
| 134 |
+
center: Coordinates
|
| 135 |
+
junction_type: JunctionType
|
| 136 |
+
id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
| 137 |
+
properties: JunctionProperties = field(default_factory=JunctionProperties)
|
| 138 |
+
connected_lines: List[str] = field(default_factory=list) # Line IDs
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
# # ======================== Tag ======================== #
|
| 142 |
+
# @dataclass
|
| 143 |
+
# class Tag:
|
| 144 |
+
# text: str
|
| 145 |
+
# bbox: BBox
|
| 146 |
+
# associated_element: str # ID of linked symbol/line
|
| 147 |
+
# id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
| 148 |
+
# font_size: int = 12
|
| 149 |
+
# rotation: float = 0.0 # Degrees
|
| 150 |
+
|
| 151 |
+
@dataclass
|
| 152 |
+
class Tag:
|
| 153 |
+
text: str
|
| 154 |
+
bbox: BBox
|
| 155 |
+
confidence: float = 1.0
|
| 156 |
+
source: str = "" # e.g., "easyocr"
|
| 157 |
+
text_type: str = "Unknown" # e.g., "Unknown", could be something else later
|
| 158 |
+
id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
| 159 |
+
associated_element: Optional[str] = None # ID of linked symbol/line (can be None)
|
| 160 |
+
font_size: int = 12
|
| 161 |
+
rotation: float = 0.0 # Degrees
|
| 162 |
+
|
| 163 |
+
def __post_init__(self):
|
| 164 |
+
"""
|
| 165 |
+
Ensure bbox is properly converted.
|
| 166 |
+
"""
|
| 167 |
+
if isinstance(self.bbox, list) and len(self.bbox) == 4:
|
| 168 |
+
self.bbox = BBox(*self.bbox)
|
| 169 |
+
|
| 170 |
+
# ----------------------------
|
| 171 |
+
# DETECTION CONTEXT
|
| 172 |
+
# ----------------------------
|
| 173 |
+
|
| 174 |
+
@dataclass
|
| 175 |
+
class DetectionContext:
|
| 176 |
+
"""
|
| 177 |
+
In-memory container for all detected elements (lines, points, symbols, junctions, tags).
|
| 178 |
+
Each element is stored in a dict keyed by 'id' for quick lookup and update.
|
| 179 |
+
"""
|
| 180 |
+
lines: Dict[str, Line] = field(default_factory=dict)
|
| 181 |
+
points: Dict[str, Point] = field(default_factory=dict)
|
| 182 |
+
symbols: Dict[str, Symbol] = field(default_factory=dict)
|
| 183 |
+
junctions: Dict[str, Junction] = field(default_factory=dict)
|
| 184 |
+
tags: Dict[str, Tag] = field(default_factory=dict)
|
| 185 |
+
|
| 186 |
+
# -------------------------
|
| 187 |
+
# 1) ADD / GET / REMOVE
|
| 188 |
+
# -------------------------
|
| 189 |
+
def add_line(self, line: Line) -> None:
|
| 190 |
+
self.lines[line.id] = line
|
| 191 |
+
|
| 192 |
+
def get_line(self, line_id: str) -> Optional[Line]:
|
| 193 |
+
return self.lines.get(line_id)
|
| 194 |
+
|
| 195 |
+
def remove_line(self, line_id: str) -> None:
|
| 196 |
+
self.lines.pop(line_id, None)
|
| 197 |
+
|
| 198 |
+
def add_point(self, point: Point) -> None:
|
| 199 |
+
self.points[point.id] = point
|
| 200 |
+
|
| 201 |
+
def get_point(self, point_id: str) -> Optional[Point]:
|
| 202 |
+
return self.points.get(point_id)
|
| 203 |
+
|
| 204 |
+
def remove_point(self, point_id: str) -> None:
|
| 205 |
+
self.points.pop(point_id, None)
|
| 206 |
+
|
| 207 |
+
def add_symbol(self, symbol: Symbol) -> None:
|
| 208 |
+
self.symbols[symbol.id] = symbol
|
| 209 |
+
|
| 210 |
+
def get_symbol(self, symbol_id: str) -> Optional[Symbol]:
|
| 211 |
+
return self.symbols.get(symbol_id)
|
| 212 |
+
|
| 213 |
+
def remove_symbol(self, symbol_id: str) -> None:
|
| 214 |
+
self.symbols.pop(symbol_id, None)
|
| 215 |
+
|
| 216 |
+
def add_junction(self, junction: Junction) -> None:
|
| 217 |
+
self.junctions[junction.id] = junction
|
| 218 |
+
|
| 219 |
+
def get_junction(self, junction_id: str) -> Optional[Junction]:
|
| 220 |
+
return self.junctions.get(junction_id)
|
| 221 |
+
|
| 222 |
+
def remove_junction(self, junction_id: str) -> None:
|
| 223 |
+
self.junctions.pop(junction_id, None)
|
| 224 |
+
|
| 225 |
+
def add_tag(self, tag: Tag) -> None:
|
| 226 |
+
self.tags[tag.id] = tag
|
| 227 |
+
|
| 228 |
+
def get_tag(self, tag_id: str) -> Optional[Tag]:
|
| 229 |
+
return self.tags.get(tag_id)
|
| 230 |
+
|
| 231 |
+
def remove_tag(self, tag_id: str) -> None:
|
| 232 |
+
self.tags.pop(tag_id, None)
|
| 233 |
+
|
| 234 |
+
# -------------------------
|
| 235 |
+
# 2) SERIALIZATION: to_dict / from_dict
|
| 236 |
+
# -------------------------
|
| 237 |
+
def to_dict(self) -> dict:
|
| 238 |
+
"""Convert all stored objects into a JSON-serializable dictionary."""
|
| 239 |
+
return {
|
| 240 |
+
"lines": [self._line_to_dict(line) for line in self.lines.values()],
|
| 241 |
+
"points": [self._point_to_dict(pt) for pt in self.points.values()],
|
| 242 |
+
"symbols": [self._symbol_to_dict(sym) for sym in self.symbols.values()],
|
| 243 |
+
"junctions": [self._junction_to_dict(jn) for jn in self.junctions.values()],
|
| 244 |
+
"tags": [self._tag_to_dict(tg) for tg in self.tags.values()]
|
| 245 |
+
}
|
| 246 |
+
|
| 247 |
+
@classmethod
|
| 248 |
+
def from_dict(cls, data: dict) -> "DetectionContext":
|
| 249 |
+
"""
|
| 250 |
+
Create a new DetectionContext from a dictionary structure (e.g. loaded from JSON).
|
| 251 |
+
"""
|
| 252 |
+
context = cls()
|
| 253 |
+
|
| 254 |
+
# Points
|
| 255 |
+
for pt_dict in data.get("points", []):
|
| 256 |
+
pt_obj = cls._point_from_dict(pt_dict)
|
| 257 |
+
context.add_point(pt_obj)
|
| 258 |
+
|
| 259 |
+
# Lines
|
| 260 |
+
for ln_dict in data.get("lines", []):
|
| 261 |
+
ln_obj = cls._line_from_dict(ln_dict)
|
| 262 |
+
context.add_line(ln_obj)
|
| 263 |
+
|
| 264 |
+
# Symbols
|
| 265 |
+
for sym_dict in data.get("symbols", []):
|
| 266 |
+
sym_obj = cls._symbol_from_dict(sym_dict)
|
| 267 |
+
context.add_symbol(sym_obj)
|
| 268 |
+
|
| 269 |
+
# Junctions
|
| 270 |
+
for jn_dict in data.get("junctions", []):
|
| 271 |
+
jn_obj = cls._junction_from_dict(jn_dict)
|
| 272 |
+
context.add_junction(jn_obj)
|
| 273 |
+
|
| 274 |
+
# Tags
|
| 275 |
+
for tg_dict in data.get("tags", []):
|
| 276 |
+
tg_obj = cls._tag_from_dict(tg_dict)
|
| 277 |
+
context.add_tag(tg_obj)
|
| 278 |
+
|
| 279 |
+
return context
|
| 280 |
+
|
| 281 |
+
# -------------------------
|
| 282 |
+
# 3) HELPER METHODS FOR (DE)SERIALIZATION
|
| 283 |
+
# -------------------------
|
| 284 |
+
@staticmethod
|
| 285 |
+
def _bbox_to_dict(bbox: BBox) -> dict:
|
| 286 |
+
return {
|
| 287 |
+
"xmin": bbox.xmin,
|
| 288 |
+
"ymin": bbox.ymin,
|
| 289 |
+
"xmax": bbox.xmax,
|
| 290 |
+
"ymax": bbox.ymax
|
| 291 |
+
}
|
| 292 |
+
|
| 293 |
+
@staticmethod
|
| 294 |
+
def _bbox_from_dict(d: dict) -> BBox:
|
| 295 |
+
return BBox(
|
| 296 |
+
xmin=d["xmin"],
|
| 297 |
+
ymin=d["ymin"],
|
| 298 |
+
xmax=d["xmax"],
|
| 299 |
+
ymax=d["ymax"]
|
| 300 |
+
)
|
| 301 |
+
|
| 302 |
+
@staticmethod
|
| 303 |
+
def _coords_to_dict(coords: Coordinates) -> dict:
|
| 304 |
+
return {
|
| 305 |
+
"x": coords.x,
|
| 306 |
+
"y": coords.y
|
| 307 |
+
}
|
| 308 |
+
|
| 309 |
+
@staticmethod
|
| 310 |
+
def _coords_from_dict(d: dict) -> Coordinates:
|
| 311 |
+
return Coordinates(x=d["x"], y=d["y"])
|
| 312 |
+
|
| 313 |
+
@staticmethod
|
| 314 |
+
def _line_style_to_dict(style: LineStyle) -> dict:
|
| 315 |
+
return {
|
| 316 |
+
"connection_type": style.connection_type.value,
|
| 317 |
+
"stroke_width": style.stroke_width,
|
| 318 |
+
"color": style.color
|
| 319 |
+
}
|
| 320 |
+
|
| 321 |
+
@staticmethod
|
| 322 |
+
def _line_style_from_dict(d: dict) -> LineStyle:
|
| 323 |
+
return LineStyle(
|
| 324 |
+
connection_type=ConnectionType(d["connection_type"]),
|
| 325 |
+
stroke_width=d.get("stroke_width", 2),
|
| 326 |
+
color=d.get("color", "#000000")
|
| 327 |
+
)
|
| 328 |
+
|
| 329 |
+
@staticmethod
|
| 330 |
+
def _point_to_dict(pt: Point) -> dict:
|
| 331 |
+
return {
|
| 332 |
+
"id": pt.id,
|
| 333 |
+
"coords": DetectionContext._coords_to_dict(pt.coords),
|
| 334 |
+
"bbox": DetectionContext._bbox_to_dict(pt.bbox),
|
| 335 |
+
"type": pt.type.value,
|
| 336 |
+
"confidence": pt.confidence
|
| 337 |
+
}
|
| 338 |
+
|
| 339 |
+
@staticmethod
|
| 340 |
+
def _point_from_dict(d: dict) -> Point:
|
| 341 |
+
return Point(
|
| 342 |
+
id=d["id"],
|
| 343 |
+
coords=DetectionContext._coords_from_dict(d["coords"]),
|
| 344 |
+
bbox=DetectionContext._bbox_from_dict(d["bbox"]),
|
| 345 |
+
type=JunctionType(d["type"]),
|
| 346 |
+
confidence=d.get("confidence", 1.0)
|
| 347 |
+
)
|
| 348 |
+
|
| 349 |
+
@staticmethod
|
| 350 |
+
def _line_to_dict(ln: Line) -> dict:
|
| 351 |
+
return {
|
| 352 |
+
"id": ln.id,
|
| 353 |
+
"start": DetectionContext._point_to_dict(ln.start),
|
| 354 |
+
"end": DetectionContext._point_to_dict(ln.end),
|
| 355 |
+
"bbox": DetectionContext._bbox_to_dict(ln.bbox),
|
| 356 |
+
"style": DetectionContext._line_style_to_dict(ln.style),
|
| 357 |
+
"confidence": ln.confidence,
|
| 358 |
+
"topological_links": ln.topological_links
|
| 359 |
+
}
|
| 360 |
+
|
| 361 |
+
@staticmethod
|
| 362 |
+
def _line_from_dict(d: dict) -> Line:
|
| 363 |
+
return Line(
|
| 364 |
+
id=d["id"],
|
| 365 |
+
start=DetectionContext._point_from_dict(d["start"]),
|
| 366 |
+
end=DetectionContext._point_from_dict(d["end"]),
|
| 367 |
+
bbox=DetectionContext._bbox_from_dict(d["bbox"]),
|
| 368 |
+
style=DetectionContext._line_style_from_dict(d["style"]),
|
| 369 |
+
confidence=d.get("confidence", 0.90),
|
| 370 |
+
topological_links=d.get("topological_links", [])
|
| 371 |
+
)
|
| 372 |
+
|
| 373 |
+
@staticmethod
|
| 374 |
+
def _symbol_to_dict(sym: Symbol) -> dict:
|
| 375 |
+
return {
|
| 376 |
+
"id": sym.id,
|
| 377 |
+
"symbol_type": sym.symbol_type.value,
|
| 378 |
+
"bbox": DetectionContext._bbox_to_dict(sym.bbox),
|
| 379 |
+
"center": DetectionContext._coords_to_dict(sym.center),
|
| 380 |
+
"connections": [DetectionContext._point_to_dict(p) for p in sym.connections],
|
| 381 |
+
"subtype": sym.subtype.value if sym.subtype else None,
|
| 382 |
+
"confidence": sym.confidence,
|
| 383 |
+
"model_metadata": sym.model_metadata
|
| 384 |
+
}
|
| 385 |
+
|
| 386 |
+
@staticmethod
|
| 387 |
+
def _symbol_from_dict(d: dict) -> Symbol:
|
| 388 |
+
return Symbol(
|
| 389 |
+
id=d["id"],
|
| 390 |
+
symbol_type=SymbolType(d["symbol_type"]),
|
| 391 |
+
bbox=DetectionContext._bbox_from_dict(d["bbox"]),
|
| 392 |
+
center=DetectionContext._coords_from_dict(d["center"]),
|
| 393 |
+
connections=[DetectionContext._point_from_dict(p) for p in d.get("connections", [])],
|
| 394 |
+
subtype=ValveSubtype(d["subtype"]) if d.get("subtype") else None,
|
| 395 |
+
confidence=d.get("confidence", 0.95),
|
| 396 |
+
model_metadata=d.get("model_metadata", {})
|
| 397 |
+
)
|
| 398 |
+
|
| 399 |
+
@staticmethod
|
| 400 |
+
def _junction_props_to_dict(props: JunctionProperties) -> dict:
|
| 401 |
+
return {
|
| 402 |
+
"flow_direction": props.flow_direction,
|
| 403 |
+
"pressure": props.pressure
|
| 404 |
+
}
|
| 405 |
+
|
| 406 |
+
@staticmethod
|
| 407 |
+
def _junction_props_from_dict(d: dict) -> JunctionProperties:
|
| 408 |
+
return JunctionProperties(
|
| 409 |
+
flow_direction=d.get("flow_direction"),
|
| 410 |
+
pressure=d.get("pressure")
|
| 411 |
+
)
|
| 412 |
+
|
| 413 |
+
@staticmethod
|
| 414 |
+
def _junction_to_dict(jn: Junction) -> dict:
|
| 415 |
+
return {
|
| 416 |
+
"id": jn.id,
|
| 417 |
+
"center": DetectionContext._coords_to_dict(jn.center),
|
| 418 |
+
"junction_type": jn.junction_type.value,
|
| 419 |
+
"properties": DetectionContext._junction_props_to_dict(jn.properties),
|
| 420 |
+
"connected_lines": jn.connected_lines
|
| 421 |
+
}
|
| 422 |
+
|
| 423 |
+
@staticmethod
|
| 424 |
+
def _junction_from_dict(d: dict) -> Junction:
|
| 425 |
+
return Junction(
|
| 426 |
+
id=d["id"],
|
| 427 |
+
center=DetectionContext._coords_from_dict(d["center"]),
|
| 428 |
+
junction_type=JunctionType(d["junction_type"]),
|
| 429 |
+
properties=DetectionContext._junction_props_from_dict(d["properties"]),
|
| 430 |
+
connected_lines=d.get("connected_lines", [])
|
| 431 |
+
)
|
| 432 |
+
|
| 433 |
+
@staticmethod
|
| 434 |
+
def _tag_to_dict(tg: Tag) -> dict:
|
| 435 |
+
return {
|
| 436 |
+
"id": tg.id,
|
| 437 |
+
"text": tg.text,
|
| 438 |
+
"bbox": DetectionContext._bbox_to_dict(tg.bbox),
|
| 439 |
+
"associated_element": tg.associated_element,
|
| 440 |
+
"font_size": tg.font_size,
|
| 441 |
+
"rotation": tg.rotation
|
| 442 |
+
}
|
| 443 |
+
|
| 444 |
+
@staticmethod
|
| 445 |
+
def _tag_from_dict(d: dict) -> Tag:
|
| 446 |
+
return Tag(
|
| 447 |
+
id=d["id"],
|
| 448 |
+
text=d["text"],
|
| 449 |
+
bbox=DetectionContext._bbox_from_dict(d["bbox"]),
|
| 450 |
+
associated_element=d["associated_element"],
|
| 451 |
+
font_size=d.get("font_size", 12),
|
| 452 |
+
rotation=d.get("rotation", 0.0)
|
| 453 |
+
)
|
| 454 |
+
|
| 455 |
+
# -------------------------
|
| 456 |
+
# 4) OPTIONAL UTILS
|
| 457 |
+
# -------------------------
|
| 458 |
+
def to_json(self, indent: int = 2) -> str:
|
| 459 |
+
"""Convert context to JSON, ensuring dataclasses and numpy types are handled correctly."""
|
| 460 |
+
return json.dumps(self.to_dict(), default=self._json_serializer, indent=indent)
|
| 461 |
+
|
| 462 |
+
@staticmethod
|
| 463 |
+
def _json_serializer(obj):
|
| 464 |
+
"""Handles numpy types and unknown objects for JSON serialization."""
|
| 465 |
+
if isinstance(obj, np.integer):
|
| 466 |
+
return int(obj)
|
| 467 |
+
if isinstance(obj, np.floating):
|
| 468 |
+
return float(obj)
|
| 469 |
+
if isinstance(obj, np.ndarray):
|
| 470 |
+
return obj.tolist() # Convert arrays to lists
|
| 471 |
+
if isinstance(obj, Enum):
|
| 472 |
+
return obj.value # Convert Enums to string values
|
| 473 |
+
if hasattr(obj, "__dict__"):
|
| 474 |
+
return obj.__dict__ # Convert dataclass objects to dict
|
| 475 |
+
raise TypeError(f"Object of type {type(obj)} is not JSON serializable")
|
| 476 |
+
|
| 477 |
+
@classmethod
|
| 478 |
+
def from_json(cls, json_str: str) -> "DetectionContext":
|
| 479 |
+
"""Load DetectionContext from a JSON string."""
|
| 480 |
+
data = json.loads(json_str)
|
| 481 |
+
return cls.from_dict(data)
|
detectors.py
ADDED
|
@@ -0,0 +1,1096 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import math
|
| 3 |
+
import torch
|
| 4 |
+
import cv2
|
| 5 |
+
import numpy as np
|
| 6 |
+
from typing import List, Optional, Tuple, Dict
|
| 7 |
+
from dataclasses import replace
|
| 8 |
+
from math import sqrt
|
| 9 |
+
import json
|
| 10 |
+
import uuid
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
|
| 13 |
+
# Base classes and utilities
|
| 14 |
+
from base import BaseDetector
|
| 15 |
+
from detection_schema import DetectionContext
|
| 16 |
+
from utils import DebugHandler
|
| 17 |
+
from config import SymbolConfig, TagConfig, LineConfig, PointConfig, JunctionConfig
|
| 18 |
+
|
| 19 |
+
# DeepLSD model for line detection
|
| 20 |
+
from deeplsd.models.deeplsd_inference import DeepLSD
|
| 21 |
+
from ultralytics import YOLO
|
| 22 |
+
|
| 23 |
+
# Detection schema: dataclasses for different objects
|
| 24 |
+
from detection_schema import (
|
| 25 |
+
BBox,
|
| 26 |
+
Coordinates,
|
| 27 |
+
Point,
|
| 28 |
+
Line,
|
| 29 |
+
Symbol,
|
| 30 |
+
Tag,
|
| 31 |
+
SymbolType,
|
| 32 |
+
LineStyle,
|
| 33 |
+
ConnectionType,
|
| 34 |
+
JunctionType,
|
| 35 |
+
Junction
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
# Skeletonization and label processing for junction detection
|
| 39 |
+
from skimage.morphology import skeletonize
|
| 40 |
+
from skimage.measure import label
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class LineDetector(BaseDetector):
|
| 44 |
+
"""
|
| 45 |
+
DeepLSD-based line detection that populates newly detected lines (and naive endpoints)
|
| 46 |
+
directly into a DetectionContext.
|
| 47 |
+
"""
|
| 48 |
+
|
| 49 |
+
<<<<<<< HEAD
|
| 50 |
+
def __init__(self, model_path=None, model=None, model_config=None, device=None, debug_handler=None):
|
| 51 |
+
self.device = device or torch.device('cpu')
|
| 52 |
+
self.debug_handler = debug_handler
|
| 53 |
+
if model is not None:
|
| 54 |
+
self.model = model
|
| 55 |
+
else:
|
| 56 |
+
super().__init__(model_path)
|
| 57 |
+
self.config = model_config or {}
|
| 58 |
+
self.scale_factor = 8.0 # Inverse of 0.5 scaling
|
| 59 |
+
self.margin = 10 # BBox expansion margin
|
| 60 |
+
=======
|
| 61 |
+
def __init__(self,
|
| 62 |
+
config: LineConfig,
|
| 63 |
+
model_path: str,
|
| 64 |
+
model_config: dict,
|
| 65 |
+
device: torch.device,
|
| 66 |
+
debug_handler: DebugHandler = None):
|
| 67 |
+
self.device = device
|
| 68 |
+
self.model_path = model_path
|
| 69 |
+
self.model_config = model_config
|
| 70 |
+
super().__init__(config, debug_handler)
|
| 71 |
+
self._load_params()
|
| 72 |
+
self.model = self._load_model(model_path)
|
| 73 |
+
self.scale_factor = 0.75 # For downscaling input to model
|
| 74 |
+
self.margin = 10
|
| 75 |
+
>>>>>>> temp/test-integration
|
| 76 |
+
|
| 77 |
+
# -------------------------------------
|
| 78 |
+
# BaseDetector requirements
|
| 79 |
+
# -------------------------------------
|
| 80 |
+
def _load_model(self, model_path: str) -> DeepLSD:
|
| 81 |
+
"""Load and configure the DeepLSD model."""
|
| 82 |
+
if not os.path.exists(model_path):
|
| 83 |
+
raise FileNotFoundError(f"Model file not found: {model_path}")
|
| 84 |
+
ckpt = torch.load(model_path, map_location=self.device)
|
| 85 |
+
<<<<<<< HEAD
|
| 86 |
+
model = DeepLSD(self.config)
|
| 87 |
+
model.load_state_dict(ckpt['model'])
|
| 88 |
+
=======
|
| 89 |
+
model = DeepLSD(self.model_config)
|
| 90 |
+
model.load_state_dict(ckpt["model"])
|
| 91 |
+
>>>>>>> temp/test-integration
|
| 92 |
+
return model.to(self.device).eval()
|
| 93 |
+
|
| 94 |
+
def _preprocess(self, image: np.ndarray) -> np.ndarray:
|
| 95 |
+
"""
|
| 96 |
+
Not used directly here. We'll handle our own
|
| 97 |
+
masking + threshold steps in the detect() method.
|
| 98 |
+
"""
|
| 99 |
+
return image
|
| 100 |
+
|
| 101 |
+
def _postprocess(self, image: np.ndarray) -> np.ndarray:
|
| 102 |
+
"""
|
| 103 |
+
Not used directly. Postprocessing is integrated
|
| 104 |
+
into detect() after we create lines.
|
| 105 |
+
"""
|
| 106 |
+
return image
|
| 107 |
+
|
| 108 |
+
# -------------------------------------
|
| 109 |
+
# Our main detection method
|
| 110 |
+
# -------------------------------------
|
| 111 |
+
def detect(self,
|
| 112 |
+
image: np.ndarray,
|
| 113 |
+
context: DetectionContext,
|
| 114 |
+
mask_coords: Optional[List[BBox]] = None,
|
| 115 |
+
*args,
|
| 116 |
+
**kwargs) -> None:
|
| 117 |
+
"""
|
| 118 |
+
Main detection pipeline:
|
| 119 |
+
1) Apply mask
|
| 120 |
+
2) Convert to binary & downscale
|
| 121 |
+
3) Run DeepLSD
|
| 122 |
+
4) Build minimal Line objects (with naive endpoints)
|
| 123 |
+
5) Scale lines to original resolution
|
| 124 |
+
6) Store the lines into the context
|
| 125 |
+
|
| 126 |
+
We do NOT unify endpoints here or classify them as T/L/etc.
|
| 127 |
+
"""
|
| 128 |
+
mask_coords = mask_coords or []
|
| 129 |
+
|
| 130 |
+
# (A) Preprocess
|
| 131 |
+
processed_img = self._apply_mask_and_downscale(image, mask_coords)
|
| 132 |
+
|
| 133 |
+
# (B) Inference
|
| 134 |
+
raw_output = self._run_model_inference(processed_img)
|
| 135 |
+
|
| 136 |
+
# (C) Create lines in downscaled space
|
| 137 |
+
downscaled_lines = self._create_lines_from_output(raw_output)
|
| 138 |
+
|
| 139 |
+
# (D) Scale them to original resolution
|
| 140 |
+
lines_scaled = [self._scale_line(ln) for ln in downscaled_lines]
|
| 141 |
+
|
| 142 |
+
# (E) Add them to context
|
| 143 |
+
for line in lines_scaled:
|
| 144 |
+
context.add_line(line)
|
| 145 |
+
|
| 146 |
+
# -------------------------------------
|
| 147 |
+
# Internal helpers
|
| 148 |
+
# -------------------------------------
|
| 149 |
+
    def _load_params(self):
        """Load any model_config parameters if needed."""
        # Currently a no-op hook, called from __init__ before the model is
        # built; kept so tunables can later be pulled from self.model_config.
        pass
|
| 152 |
+
|
| 153 |
+
def _apply_mask_and_downscale(self, image: np.ndarray, mask_coords: List[BBox]) -> np.ndarray:
|
| 154 |
+
"""Apply rectangular mask, then threshold, then downscale."""
|
| 155 |
+
masked = self._apply_masking(image, mask_coords)
|
| 156 |
+
gray = cv2.cvtColor(masked, cv2.COLOR_RGB2GRAY)
|
| 157 |
+
<<<<<<< HEAD
|
| 158 |
+
binary = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)[1]
|
| 159 |
+
return cv2.resize(binary, None, fx=1/self.scale_factor, fy=1/self.scale_factor)
|
| 160 |
+
=======
|
| 161 |
+
binary_full = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)[1]
|
| 162 |
+
>>>>>>> temp/test-integration
|
| 163 |
+
|
| 164 |
+
kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, 1))
|
| 165 |
+
dilated = cv2.dilate(binary_full, kernel, iterations=2)
|
| 166 |
+
|
| 167 |
+
# Downscale
|
| 168 |
+
binary_downscaled = cv2.resize(
|
| 169 |
+
dilated,
|
| 170 |
+
None,
|
| 171 |
+
fx=self.scale_factor,
|
| 172 |
+
fy=self.scale_factor
|
| 173 |
+
)
|
| 174 |
+
return binary_downscaled
|
| 175 |
+
|
| 176 |
+
def _apply_masking(self, image: np.ndarray, mask_coords: List[BBox]) -> np.ndarray:
|
| 177 |
+
"""White-out rectangular areas to ignore them."""
|
| 178 |
+
masked = image.copy()
|
| 179 |
+
for bbox in mask_coords:
|
| 180 |
+
x1, y1 = int(bbox.xmin), int(bbox.ymin)
|
| 181 |
+
x2, y2 = int(bbox.xmax), int(bbox.ymax)
|
| 182 |
+
cv2.rectangle(masked, (x1, y1), (x2, y2), (255, 255, 255), -1)
|
| 183 |
+
return masked
|
| 184 |
+
|
| 185 |
+
def _run_model_inference(self, downscaled_binary: np.ndarray) -> np.ndarray:
|
| 186 |
+
"""Run DeepLSD on the downscaled binary image, returning raw lines [N, 2, 2]."""
|
| 187 |
+
tensor = torch.tensor(downscaled_binary, dtype=torch.float32, device=self.device)[None, None] / 255.0
|
| 188 |
+
# tensor = torch.tensor(downscaled_binary, dtype=torch.float32, device=self.device)[None, None] / 255.0
|
| 189 |
+
with torch.no_grad():
|
| 190 |
+
output = self.model({"image": tensor})
|
| 191 |
+
# shape: [batch, num_lines, 2, 2]
|
| 192 |
+
return output["lines"][0]
|
| 193 |
+
|
| 194 |
+
def _create_lines_from_output(self, model_output: np.ndarray) -> List[Line]:
|
| 195 |
+
"""
|
| 196 |
+
Convert each [2,2] line segment into a minimal Line with naive endpoints (type=END).
|
| 197 |
+
Coordinates are in downscaled space.
|
| 198 |
+
"""
|
| 199 |
+
lines = []
|
| 200 |
+
for endpoints in model_output:
|
| 201 |
+
(x1, y1), (x2, y2) = endpoints # shape (2,) each
|
| 202 |
+
|
| 203 |
+
p_start = self._create_point(x1, y1)
|
| 204 |
+
p_end = self._create_point(x2, y2)
|
| 205 |
+
|
| 206 |
+
# minimal bounding box in downscaled coords
|
| 207 |
+
x_min = min(x1, x2)
|
| 208 |
+
x_max = max(x1, x2)
|
| 209 |
+
y_min = min(y1, y2)
|
| 210 |
+
y_max = max(y1, y2)
|
| 211 |
+
|
| 212 |
+
line_obj = Line(
|
| 213 |
+
start=p_start,
|
| 214 |
+
end=p_end,
|
| 215 |
+
bbox=BBox(
|
| 216 |
+
xmin=int(x_min),
|
| 217 |
+
ymin=int(y_min),
|
| 218 |
+
xmax=int(x_max),
|
| 219 |
+
ymax=int(y_max)
|
| 220 |
+
),
|
| 221 |
+
# style / confidence / ID assigned by default
|
| 222 |
+
style=LineStyle(
|
| 223 |
+
connection_type=ConnectionType.SOLID,
|
| 224 |
+
stroke_width=2,
|
| 225 |
+
color="#000000"
|
| 226 |
+
),
|
| 227 |
+
confidence=0.9,
|
| 228 |
+
topological_links=[]
|
| 229 |
+
)
|
| 230 |
+
lines.append(line_obj)
|
| 231 |
+
|
| 232 |
+
return lines
|
| 233 |
+
|
| 234 |
+
def _create_point(self, x: float, y: float) -> Point:
|
| 235 |
+
"""
|
| 236 |
+
Creates a naive 'END'-type Point at downscaled coords.
|
| 237 |
+
We'll scale it later.
|
| 238 |
+
"""
|
| 239 |
+
margin = 2
|
| 240 |
+
return Point(
|
| 241 |
+
coords=Coordinates(x=int(x), y=int(y)),
|
| 242 |
+
bbox=BBox(
|
| 243 |
+
xmin=int(x - margin),
|
| 244 |
+
ymin=int(y - margin),
|
| 245 |
+
xmax=int(x + margin),
|
| 246 |
+
ymax=int(y + margin)
|
| 247 |
+
),
|
| 248 |
+
type=JunctionType.END, # no classification here
|
| 249 |
+
confidence=1.0
|
| 250 |
+
)
|
| 251 |
+
|
| 252 |
+
def _scale_line(self, line: Line) -> Line:
|
| 253 |
+
"""
|
| 254 |
+
Scale line's start/end points + bounding box to original resolution.
|
| 255 |
+
"""
|
| 256 |
+
scaled_start = self._scale_point(line.start)
|
| 257 |
+
scaled_end = self._scale_point(line.end)
|
| 258 |
+
|
| 259 |
+
# recalc bounding box in original scale
|
| 260 |
+
new_bbox = BBox(
|
| 261 |
+
xmin=min(scaled_start.bbox.xmin, scaled_end.bbox.xmin),
|
| 262 |
+
ymin=min(scaled_start.bbox.ymin, scaled_end.bbox.ymin),
|
| 263 |
+
xmax=max(scaled_start.bbox.xmax, scaled_end.bbox.xmax),
|
| 264 |
+
ymax=max(scaled_start.bbox.ymax, scaled_end.bbox.ymax)
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
return replace(line, start=scaled_start, end=scaled_end, bbox=new_bbox)
|
| 268 |
+
|
| 269 |
+
def _scale_point(self, point: Point) -> Point:
|
| 270 |
+
sx = int(point.coords.x * 1/self.scale_factor)
|
| 271 |
+
sy = int(point.coords.y * 1/self.scale_factor)
|
| 272 |
+
|
| 273 |
+
bb = point.bbox
|
| 274 |
+
scaled_bbox = BBox(
|
| 275 |
+
xmin=int(bb.xmin * 1/self.scale_factor),
|
| 276 |
+
ymin=int(bb.ymin * 1/self.scale_factor),
|
| 277 |
+
xmax=int(bb.xmax * 1/self.scale_factor),
|
| 278 |
+
ymax=int(bb.ymax * 1/self.scale_factor)
|
| 279 |
+
)
|
| 280 |
+
return replace(point, coords=Coordinates(sx, sy), bbox=scaled_bbox)
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
class PointDetector(BaseDetector):
    """
    A detector that:
    1) Reads lines from the context
    2) Clusters endpoints within 'threshold_distance'
    3) Updates lines so that shared endpoints reference the same Point object
    """

    def __init__(self,
                 config:PointConfig,
                 debug_handler: DebugHandler = None):
        """
        Args:
            config: Supplies 'threshold_distance', the maximum pixel distance
                at which two endpoints are merged into one point.
            debug_handler: Optional debug-artifact handler.
        """
        super().__init__(config, debug_handler) # No real model to load
        self.threshold_distance = config.threshold_distance

    def _load_model(self, model_path: str):
        """No model needed for simple point unification."""
        return None

    def detect(self, image: np.ndarray, context: DetectionContext, *args, **kwargs) -> None:
        """
        Main method called by the pipeline.
        1) Gather all line endpoints from context
        2) Cluster them within 'threshold_distance'
        3) Update the line endpoints so they reference the unified cluster point
        """
        # 1) Collect all endpoints
        endpoints = []
        for line in context.lines.values():
            endpoints.append(line.start)
            endpoints.append(line.end)

        # 2) Cluster endpoints
        clusters = self._cluster_points(endpoints, self.threshold_distance)

        # 3) Build a dictionary of "representative" points
        # So that each cluster has one "canonical" point
        # Then we link all the points in that cluster to the canonical reference
        unified_point_map = {}
        for cluster in clusters:
            # let's pick the first point in the cluster as the "representative"
            rep_point = cluster[0]
            for p in cluster[1:]:
                unified_point_map[p.id] = rep_point

        # 4) Update all lines to reference the canonical point
        # NOTE(review): this assigns line.start/line.end in place, which
        # assumes Line is a mutable (non-frozen) dataclass — confirm, since
        # other code uses dataclasses.replace() on Line instances.
        for line in context.lines.values():
            # unify start
            if line.start.id in unified_point_map:
                line.start = unified_point_map[line.start.id]
            # unify end
            if line.end.id in unified_point_map:
                line.end = unified_point_map[line.end.id]

        # We could also store the final set of unique points back in context.points
        # (e.g. clearing old duplicates).
        # That step is optional: you might prefer to keep everything in lines only,
        # or you might want context.points as a separate reference.

        # If you want to keep unique points in context.points:
        # Rebuilding from the (now unified) line endpoints drops the merged
        # duplicate points from the context.
        new_points = {}
        for line in context.lines.values():
            new_points[line.start.id] = line.start
            new_points[line.end.id] = line.end
        context.points = new_points # replace the dictionary of points

    def _preprocess(self, image: np.ndarray) -> np.ndarray:
        """No specific image preprocessing needed."""
        return image

    def _postprocess(self, image: np.ndarray) -> np.ndarray:
        """No specific image postprocessing needed."""
        return image

    # ----------------------
    # HELPER: clustering
    # ----------------------
    def _cluster_points(self, points: List[Point], threshold: float) -> List[List[Point]]:
        """
        Very naive clustering:
        1) Start from the first point
        2) If it's within threshold of an existing cluster's representative,
        put it in that cluster
        3) Otherwise start a new cluster
        Return: list of clusters, each is a list of Points

        NOTE(review): membership is judged against each cluster's FIRST
        point only, so a chain of nearby points can still split into
        multiple clusters; complexity is O(n * clusters).
        """
        clusters = []

        for pt in points:
            placed = False
            for cluster in clusters:
                # pick the first point in the cluster as reference
                ref_pt = cluster[0]
                if self._distance(pt, ref_pt) < threshold:
                    cluster.append(pt)
                    placed = True
                    break

            if not placed:
                clusters.append([pt])

        return clusters

    def _distance(self, p1: Point, p2: Point) -> float:
        """Euclidean distance between the two points' coordinates."""
        dx = p1.coords.x - p2.coords.x
        dy = p1.coords.y - p2.coords.y
        return sqrt(dx*dx + dy*dy)
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
class JunctionDetector(BaseDetector):
    """
    Classifies points as 'END', 'L', or 'T' by skeletonizing the binarized image
    and analyzing local connectivity. Also creates Junction objects in the context.
    """

    def __init__(self, config: JunctionConfig, debug_handler: DebugHandler = None):
        """
        Args:
            config: Supplies window_size (local patch side length), radius
                (circle used to count skeleton exits), and the angle bounds
                that separate 'L' corners from plain ends.
            debug_handler: Optional debug-artifact handler.
        """
        super().__init__(config, debug_handler) # no real model path
        self.window_size = config.window_size
        self.radius = config.radius
        self.angle_threshold_lb = config.angle_threshold_lb
        self.angle_threshold_ub = config.angle_threshold_ub
        self.debug_handler = debug_handler or DebugHandler()

    def _load_model(self, model_path: str):
        """Not loading any actual model, just skeleton logic."""
        return None

    def detect(self,
               image: np.ndarray,
               context: DetectionContext,
               *args,
               **kwargs) -> None:
        """
        1) Convert to binary & skeletonize
        2) Classify each point in the context
        3) Create a Junction for each point and store it in context.junctions
        (with 'connected_lines' referencing lines that share this point).
        """
        # 1) Preprocess -> skeleton
        skeleton = self._create_skeleton(image)

        # 2) Classify each point
        for pt in context.points.values():
            pt.type = self._classify_point(skeleton, pt)

        # 3) Create a Junction object for each point
        # If you prefer only T or L, you can filter out END points.
        self._record_junctions_in_context(context)

    def _preprocess(self, image: np.ndarray) -> np.ndarray:
        """We might do thresholding; let's do a simple binary threshold."""
        # Accept either a 3-channel (BGR) or an already-grayscale image.
        if image.ndim == 3:
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        else:
            gray = image
        _, bin_image = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)
        return bin_image

    def _postprocess(self, image: np.ndarray) -> np.ndarray:
        # Interface stub: no postprocessing for this detector.
        return image

    def _create_skeleton(self, raw_image: np.ndarray) -> np.ndarray:
        """Skeletonize the binarized image."""
        bin_img = self._preprocess(raw_image)
        # For skeletonize, we need a boolean array
        # Invert first: skeletonize expects foreground (lines) to be the
        # "true" pixels, and drawings are dark-on-light after thresholding.
        inv = cv2.bitwise_not(bin_img)
        inv_bool = (inv > 127).astype(np.uint8)
        skel = skeletonize(inv_bool).astype(np.uint8) * 255
        return skel

    def _classify_point(self, skeleton: np.ndarray, pt: Point) -> JunctionType:
        """
        Given a skeleton image, look around 'pt' in a local window
        to determine if it's an END, L, or T.
        """
        classification = JunctionType.END # default

        half_w = self.window_size // 2
        x, y = pt.coords.x, pt.coords.y

        # Clamp the window to the image bounds.
        top = max(0, y - half_w)
        bottom = min(skeleton.shape[0], y + half_w + 1)
        left = max(0, x - half_w)
        right = min(skeleton.shape[1], x + half_w + 1)

        patch = (skeleton[top:bottom, left:right] > 127).astype(np.uint8)

        # create circular mask
        circle_mask = np.zeros_like(patch, dtype=np.uint8)
        local_cx = x - left
        local_cy = y - top
        cv2.circle(circle_mask, (local_cx, local_cy), self.radius, 1, -1)
        circle_skel = patch & circle_mask

        # label connected regions
        # Each connected component inside the circle counts as one skeleton
        # "exit" leaving the point.
        labeled = label(circle_skel, connectivity=2)
        num_exits = labeled.max()

        if num_exits == 1:
            classification = JunctionType.END
        elif num_exits == 2:
            # check angle for L
            classification = self._check_angle_for_L(labeled)
        elif num_exits == 3:
            classification = JunctionType.T

        return classification

    def _check_angle_for_L(self, labeled_region: np.ndarray) -> JunctionType:
        """
        If the angle between two branches is within
        [angle_threshold_lb, angle_threshold_ub], it's 'L'.
        Otherwise default to END.

        NOTE(review): only pixels of component 1 are inspected, and only the
        first two of them — so the angle computed is between two pixels of
        the SAME branch, not between the two branches; verify this is the
        intended behavior.
        """
        coords = np.argwhere(labeled_region == 1)
        if len(coords) < 2:
            return JunctionType.END

        (y1, x1), (y2, x2) = coords[:2]
        dx = x2 - x1
        dy = y2 - y1
        angle = math.degrees(math.atan2(dy, dx))
        # Fold into [0, 90] so the orientation sign does not matter.
        acute_angle = min(abs(angle), 180 - abs(angle))

        if self.angle_threshold_lb <= acute_angle <= self.angle_threshold_ub:
            return JunctionType.L
        return JunctionType.END

    # -----------------------------------------
    # EXTRA STEP: Create Junction objects
    # -----------------------------------------
    def _record_junctions_in_context(self, context: DetectionContext):
        """
        Create a Junction object for each point in context.points.
        If you only want T/L points as junctions, filter them out.
        Also track any lines that connect to this point.
        """

        for pt in context.points.values():
            # If you prefer to store all points as junction, do it:
            # or if you want only T or L, do:
            # if pt.type in {JunctionType.T, JunctionType.L}: ...

            jn = Junction(
                center=pt.coords,
                junction_type=pt.type,
                # add more properties if needed
            )

            # find lines that connect to this point
            # (O(points * lines) scan; fine for typical drawing sizes)
            connected_lines = []
            for ln in context.lines.values():
                if ln.start.id == pt.id or ln.end.id == pt.id:
                    connected_lines.append(ln.id)

            jn.connected_lines = connected_lines

            # add to context
            context.add_junction(jn)
|
| 541 |
+
|
| 542 |
+
# from loguru import logger
|
| 543 |
+
#
|
| 544 |
+
#
|
| 545 |
+
# class SymbolDetector(BaseDetector):
|
| 546 |
+
# """
|
| 547 |
+
# YOLO-based symbol detector using multiple confidence thresholds,
|
| 548 |
+
# merges final detections, and stores them in the context.
|
| 549 |
+
# """
|
| 550 |
+
#
|
| 551 |
+
# def __init__(self, config: SymbolConfig, debug_handler: Optional[DebugHandler] = None):
|
| 552 |
+
# super().__init__(config, debug_handler)
|
| 553 |
+
# self.config = config
|
| 554 |
+
# self.debug_handler = debug_handler or DebugHandler()
|
| 555 |
+
# self.models = self._load_models()
|
| 556 |
+
# self.class_map = self._build_class_map()
|
| 557 |
+
#
|
| 558 |
+
# logger.info("Symbol detector initialized with config: %s", self.config)
|
| 559 |
+
#
|
| 560 |
+
# # -----------------------------
|
| 561 |
+
# # BaseDetector Implementation
|
| 562 |
+
# # -----------------------------
|
| 563 |
+
# def _load_model(self, model_path: str):
|
| 564 |
+
# """We won't use this single-model loader; see _load_models()."""
|
| 565 |
+
# pass
|
| 566 |
+
#
|
| 567 |
+
# def detect(self,
|
| 568 |
+
# image: np.ndarray,
|
| 569 |
+
# context: DetectionContext,
|
| 570 |
+
# roi_offset: Tuple[int, int],
|
| 571 |
+
# *args,
|
| 572 |
+
# **kwargs) -> None:
|
| 573 |
+
# """
|
| 574 |
+
# Run multi-threshold YOLO detection for each model, pick best threshold,
|
| 575 |
+
# merge detections, and store Symbol objects in context.
|
| 576 |
+
# """
|
| 577 |
+
# try:
|
| 578 |
+
# with self.debug_handler.track_performance("symbol_detection"):
|
| 579 |
+
# # 1) Possibly preprocess & resize
|
| 580 |
+
# processed_img = self._preprocess(image)
|
| 581 |
+
# resized_img, scale_factor = self._resize_image(processed_img)
|
| 582 |
+
#
|
| 583 |
+
# # 2) Detect with all models, each using multiple thresholds
|
| 584 |
+
# all_detections = []
|
| 585 |
+
# for model_name, model in self.models.items():
|
| 586 |
+
# best_detections = self._detect_best_threshold(
|
| 587 |
+
# model, resized_img, image.shape, scale_factor, model_name
|
| 588 |
+
# )
|
| 589 |
+
# all_detections.extend(best_detections)
|
| 590 |
+
#
|
| 591 |
+
# # 3) Merge detections using NMS logic
|
| 592 |
+
# merged_detections = self._merge_detections(all_detections)
|
| 593 |
+
#
|
| 594 |
+
# # 4) Update context with final symbols
|
| 595 |
+
# self._update_context(merged_detections, context)
|
| 596 |
+
#
|
| 597 |
+
# # 5) Create optional debug image artifact
|
| 598 |
+
# debug_image = self._create_debug_image(processed_img, merged_detections)
|
| 599 |
+
# _, debug_img_encoded = cv2.imencode('.jpg', debug_image)
|
| 600 |
+
# self.debug_handler.save_artifact(
|
| 601 |
+
# name="symbol_detection_debug",
|
| 602 |
+
# data=debug_img_encoded.tobytes(),
|
| 603 |
+
# extension="jpg"
|
| 604 |
+
# )
|
| 605 |
+
#
|
| 606 |
+
# except Exception as e:
|
| 607 |
+
# logger.error("Symbol detection failed: %s", str(e), exc_info=True)
|
| 608 |
+
# self.debug_handler.save_artifact(
|
| 609 |
+
# name="symbol_detection_error",
|
| 610 |
+
# data=f"Detection error: {str(e)}".encode('utf-8'),
|
| 611 |
+
# extension="txt"
|
| 612 |
+
# )
|
| 613 |
+
#
|
| 614 |
+
# def _preprocess(self, image: np.ndarray) -> np.ndarray:
|
| 615 |
+
# """Preprocess if needed (e.g., histogram equalization)."""
|
| 616 |
+
# if self.config.apply_preprocessing:
|
| 617 |
+
# gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
| 618 |
+
# equalized = cv2.equalizeHist(gray)
|
| 619 |
+
# # convert back to BGR for YOLO
|
| 620 |
+
# return cv2.cvtColor(equalized, cv2.COLOR_GRAY2BGR)
|
| 621 |
+
# return image.copy()
|
| 622 |
+
#
|
| 623 |
+
# def _postprocess(self, image: np.ndarray) -> np.ndarray:
|
| 624 |
+
# return None
|
| 625 |
+
#
|
| 626 |
+
# # -----------------------------
|
| 627 |
+
# # Internal Helpers
|
| 628 |
+
# # -----------------------------
|
| 629 |
+
# def _load_models(self) -> Dict[str, YOLO]:
|
| 630 |
+
# """Load multiple YOLO models from config."""
|
| 631 |
+
# models = {}
|
| 632 |
+
# for model_name, path_str in self.config.model_paths.items():
|
| 633 |
+
# path = Path(path_str)
|
| 634 |
+
# if not path.exists():
|
| 635 |
+
# raise FileNotFoundError(f"Model file not found: {path_str}")
|
| 636 |
+
# models[model_name] = YOLO(str(path))
|
| 637 |
+
# logger.info(f"Loaded model '{model_name}' from {path_str}")
|
| 638 |
+
# return models
|
| 639 |
+
#
|
| 640 |
+
# def _build_class_map(self) -> Dict[int, SymbolType]:
|
| 641 |
+
# """
|
| 642 |
+
# Convert config symbol_type_mapping (like {"pump": "PUMP"})
|
| 643 |
+
# into a dictionary from YOLO class_id to SymbolType.
|
| 644 |
+
# If you have a fixed list of YOLO classes, you can map them here.
|
| 645 |
+
# """
|
| 646 |
+
# # For example, if YOLO has classes like ["valve", "pump", ...],
|
| 647 |
+
# # you might want to do something more dynamic.
|
| 648 |
+
# # For now, let's just return an empty dict or handle it in detection.
|
| 649 |
+
# return {}
|
| 650 |
+
#
|
| 651 |
+
# def _resize_image(self, image: np.ndarray) -> Tuple[np.ndarray, float]:
|
| 652 |
+
# """Resize while maintaining aspect ratio if needed."""
|
| 653 |
+
# h, w = image.shape[:2]
|
| 654 |
+
# if not self.config.resize_image:
|
| 655 |
+
# return image, 1.0
|
| 656 |
+
#
|
| 657 |
+
# if max(w, h) > self.config.max_dimension:
|
| 658 |
+
# scale = self.config.max_dimension / max(w, h)
|
| 659 |
+
# new_w, new_h = int(w * scale), int(h * scale)
|
| 660 |
+
# resized = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
|
| 661 |
+
# return resized, scale
|
| 662 |
+
# return image, 1.0
|
| 663 |
+
#
|
| 664 |
+
# def _detect_best_threshold(self,
|
| 665 |
+
# model: YOLO,
|
| 666 |
+
# resized_img: np.ndarray,
|
| 667 |
+
# orig_shape: Tuple[int, int, int],
|
| 668 |
+
# scale_factor: float,
|
| 669 |
+
# model_name: str) -> List[Dict]:
|
| 670 |
+
# """
|
| 671 |
+
# Run detection across multiple confidence thresholds.
|
| 672 |
+
# Use the threshold that yields the 'best metric' (currently # of detections).
|
| 673 |
+
# """
|
| 674 |
+
# best_metric = -1
|
| 675 |
+
# best_threshold = 0.5
|
| 676 |
+
# best_detections_list = []
|
| 677 |
+
#
|
| 678 |
+
# # Evaluate each threshold
|
| 679 |
+
# for thresh in self.config.confidence_thresholds:
|
| 680 |
+
# # Run YOLO detection
|
| 681 |
+
# # Setting conf=thresh or conf=0.0 + we do filtering ourselves.
|
| 682 |
+
# results = model.predict(
|
| 683 |
+
# source=resized_img,
|
| 684 |
+
# imgsz=self.config.max_dimension,
|
| 685 |
+
# conf=0.0, # We'll filter manually below
|
| 686 |
+
# verbose=False
|
| 687 |
+
# )
|
| 688 |
+
#
|
| 689 |
+
# # Convert to detection dict
|
| 690 |
+
# detections_list = []
|
| 691 |
+
# for result in results:
|
| 692 |
+
# for box in result.boxes:
|
| 693 |
+
# conf_val = float(box.conf[0])
|
| 694 |
+
# if conf_val >= thresh:
|
| 695 |
+
# # Convert bounding box coords to original (local) coords
|
| 696 |
+
# x1, y1, x2, y2 = self._scale_coordinates(
|
| 697 |
+
# box.xyxy[0].cpu().numpy(),
|
| 698 |
+
# resized_img.shape, # shape after resizing
|
| 699 |
+
# scale_factor
|
| 700 |
+
# )
|
| 701 |
+
# class_id = int(box.cls[0])
|
| 702 |
+
# label = result.names[class_id] if result.names else "unknown_label"
|
| 703 |
+
#
|
| 704 |
+
# # parse label (category, type, new_label)
|
| 705 |
+
# category, type_str, new_label = self._parse_label(label)
|
| 706 |
+
#
|
| 707 |
+
# detection_info = {
|
| 708 |
+
# "symbol_id": str(uuid.uuid4()),
|
| 709 |
+
# "class_id": class_id,
|
| 710 |
+
# "original_label": label,
|
| 711 |
+
# "category": category,
|
| 712 |
+
# "type": type_str,
|
| 713 |
+
# "label": new_label,
|
| 714 |
+
# "confidence": conf_val,
|
| 715 |
+
# "bbox": [x1, y1, x2, y2],
|
| 716 |
+
# "model_source": model_name
|
| 717 |
+
# }
|
| 718 |
+
# detections_list.append(detection_info)
|
| 719 |
+
#
|
| 720 |
+
# # Evaluate
|
| 721 |
+
# metric = self._evaluate_detections(detections_list)
|
| 722 |
+
# if metric > best_metric:
|
| 723 |
+
# best_metric = metric
|
| 724 |
+
# best_threshold = thresh
|
| 725 |
+
# best_detections_list = detections_list
|
| 726 |
+
#
|
| 727 |
+
# logger.info(f"For model {model_name}, best threshold={best_threshold:.2f} with {best_metric} detections.")
|
| 728 |
+
# return best_detections_list
|
| 729 |
+
#
|
| 730 |
+
# def _evaluate_detections(self, detections_list: List[Dict]) -> int:
|
| 731 |
+
# """A simple metric: # of detections."""
|
| 732 |
+
# return len(detections_list)
|
| 733 |
+
#
|
| 734 |
+
# def _parse_label(self, label: str) -> Tuple[str, str, str]:
|
| 735 |
+
# """
|
| 736 |
+
# Attempt to parse the YOLO label into (category, type, new_label).
|
| 737 |
+
# Example label: "inst_ind_Solenoid_actuator"
|
| 738 |
+
# -> category=inst, type=ind, new_label="Solenoid_actuator"
|
| 739 |
+
# If no underscores, we fallback to "Unknown" for type.
|
| 740 |
+
# """
|
| 741 |
+
# split_label = label.split('_')
|
| 742 |
+
# if len(split_label) >= 3:
|
| 743 |
+
# category = split_label[0]
|
| 744 |
+
# type_ = split_label[1]
|
| 745 |
+
# new_label = '_'.join(split_label[2:])
|
| 746 |
+
# elif len(split_label) == 2:
|
| 747 |
+
# category = split_label[0]
|
| 748 |
+
# type_ = split_label[1]
|
| 749 |
+
# new_label = split_label[1]
|
| 750 |
+
# elif len(split_label) == 1:
|
| 751 |
+
# category = split_label[0]
|
| 752 |
+
# type_ = "Unknown"
|
| 753 |
+
# new_label = split_label[0]
|
| 754 |
+
# else:
|
| 755 |
+
# logger.warning(f"Unexpected label format: {label}")
|
| 756 |
+
# return ("Unknown", "Unknown", label)
|
| 757 |
+
#
|
| 758 |
+
# return (category, type_, new_label)
|
| 759 |
+
#
|
| 760 |
+
# def _scale_coordinates(self,
|
| 761 |
+
# coords: np.ndarray,
|
| 762 |
+
# resized_shape: Tuple[int, int, int],
|
| 763 |
+
# scale_factor: float) -> Tuple[int, int, int, int]:
|
| 764 |
+
# """
|
| 765 |
+
# Scale YOLO's [x1,y1,x2,y2] from the resized image back to the original local coords.
|
| 766 |
+
# """
|
| 767 |
+
# x1, y1, x2, y2 = coords
|
| 768 |
+
# # Because we resized by scale_factor
|
| 769 |
+
# # so original coordinate = coords / scale_factor
|
| 770 |
+
# return (
|
| 771 |
+
# int(x1 / scale_factor),
|
| 772 |
+
# int(y1 / scale_factor),
|
| 773 |
+
# int(x2 / scale_factor),
|
| 774 |
+
# int(y2 / scale_factor),
|
| 775 |
+
# )
|
| 776 |
+
#
|
| 777 |
+
# def _merge_detections(self, all_detections: List[Dict]) -> List[Dict]:
|
| 778 |
+
# """Merge using NMS-like approach (IoU-based) across all models."""
|
| 779 |
+
# if not all_detections:
|
| 780 |
+
# return []
|
| 781 |
+
#
|
| 782 |
+
# # Sort by confidence (descending)
|
| 783 |
+
# all_detections.sort(key=lambda x: x['confidence'], reverse=True)
|
| 784 |
+
# keep = [True] * len(all_detections)
|
| 785 |
+
#
|
| 786 |
+
# for i in range(len(all_detections)):
|
| 787 |
+
# if not keep[i]:
|
| 788 |
+
# continue
|
| 789 |
+
# for j in range(i + 1, len(all_detections)):
|
| 790 |
+
# if not keep[j]:
|
| 791 |
+
# continue
|
| 792 |
+
# # Merge if same class_id & high IoU
|
| 793 |
+
# if (all_detections[i]['class_id'] == all_detections[j]['class_id'] and
|
| 794 |
+
# self._calculate_iou(all_detections[i]['bbox'], all_detections[j]['bbox']) > 0.5):
|
| 795 |
+
# keep[j] = False
|
| 796 |
+
#
|
| 797 |
+
# return [det for idx, det in enumerate(all_detections) if keep[idx]]
|
| 798 |
+
#
|
| 799 |
+
# def _calculate_iou(self, box1: List[int], box2: List[int]) -> float:
|
| 800 |
+
# """Intersection over Union"""
|
| 801 |
+
# x_left = max(box1[0], box2[0])
|
| 802 |
+
# y_top = max(box1[1], box2[1])
|
| 803 |
+
# x_right = min(box1[2], box2[2])
|
| 804 |
+
# y_bottom = min(box1[3], box2[3])
|
| 805 |
+
#
|
| 806 |
+
# inter_w = max(0, x_right - x_left)
|
| 807 |
+
# inter_h = max(0, y_bottom - y_top)
|
| 808 |
+
# intersection = inter_w * inter_h
|
| 809 |
+
#
|
| 810 |
+
# area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
|
| 811 |
+
# area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
|
| 812 |
+
# union = float(area1 + area2 - intersection)
|
| 813 |
+
# return intersection / union if union > 0 else 0.0
|
| 814 |
+
#
|
| 815 |
+
# def _update_context(self, detections: List[Dict], context: DetectionContext) -> None:
|
| 816 |
+
# """Convert final detections into Symbol objects & add to context."""
|
| 817 |
+
# for det in detections:
|
| 818 |
+
# x1, y1, x2, y2 = det['bbox']
|
| 819 |
+
# # Use your Symbol dataclass from detection_schema
|
| 820 |
+
# symbol_obj = Symbol(
|
| 821 |
+
# bbox=BBox(xmin=x1, ymin=y1, xmax=x2, ymax=y2),
|
| 822 |
+
# center=Coordinates(x=(x1 + x2) // 2, y=(y1 + y2) // 2),
|
| 823 |
+
# symbol_type=SymbolType.OTHER, # default
|
| 824 |
+
# confidence=det['confidence'],
|
| 825 |
+
# model_source=det['model_source'],
|
| 826 |
+
# class_id=det['class_id'],
|
| 827 |
+
# original_label=det['original_label'],
|
| 828 |
+
# category=det['category'],
|
| 829 |
+
# type=det['type'],
|
| 830 |
+
# label=det['label']
|
| 831 |
+
# )
|
| 832 |
+
# context.add_symbol(symbol_obj)
|
| 833 |
+
#
|
| 834 |
+
# def _create_debug_image(self, image: np.ndarray, detections: List[Dict]) -> np.ndarray:
|
| 835 |
+
# """Optional: draw bounding boxes & labels on a copy of 'image'."""
|
| 836 |
+
# debug_img = image.copy()
|
| 837 |
+
# for det in detections:
|
| 838 |
+
# x1, y1, x2, y2 = det['bbox']
|
| 839 |
+
# cv2.rectangle(debug_img, (x1, y1), (x2, y2), (0, 255, 0), 2)
|
| 840 |
+
# txt = f"{det['label']} {det['confidence']:.2f}"
|
| 841 |
+
# cv2.putText(debug_img, txt, (x1, max(0, y1 - 10)),
|
| 842 |
+
# cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)
|
| 843 |
+
# return debug_img
|
| 844 |
+
#
|
| 845 |
+
#
|
| 846 |
+
# class TagDetector(BaseDetector):
|
| 847 |
+
# """
|
| 848 |
+
# A placeholder detector that reads precomputed tag data
|
| 849 |
+
# from a JSON file and populates the context with Tag objects.
|
| 850 |
+
# """
|
| 851 |
+
#
|
| 852 |
+
# def __init__(self,
|
| 853 |
+
# config: TagConfig,
|
| 854 |
+
# debug_handler: Optional[DebugHandler] = None,
|
| 855 |
+
# tag_json_path: str = "./tags.json"):
|
| 856 |
+
# super().__init__(config=config, debug_handler=debug_handler)
|
| 857 |
+
# self.tag_json_path = tag_json_path
|
| 858 |
+
#
|
| 859 |
+
# def _load_model(self, model_path: str):
|
| 860 |
+
# """Not loading an actual model; tag data is read from JSON."""
|
| 861 |
+
# return None
|
| 862 |
+
#
|
| 863 |
+
# def detect(self,
|
| 864 |
+
# image: np.ndarray,
|
| 865 |
+
# context: DetectionContext,
|
| 866 |
+
# roi_offset: Tuple[int, int],
|
| 867 |
+
# *args,
|
| 868 |
+
# **kwargs) -> None:
|
| 869 |
+
# """
|
| 870 |
+
# Reads from a JSON file containing tag info,
|
| 871 |
+
# adjusts coordinates using roi_offset, and updates context.
|
| 872 |
+
# """
|
| 873 |
+
#
|
| 874 |
+
# tag_data = self._load_json_data(self.tag_json_path)
|
| 875 |
+
# if not tag_data:
|
| 876 |
+
# return
|
| 877 |
+
#
|
| 878 |
+
# x_min, y_min = roi_offset # Offset values from cropping
|
| 879 |
+
#
|
| 880 |
+
# for record in tag_data.get("detections", []): # Fix: Use "detections" key
|
| 881 |
+
# tag_obj = self._parse_tag_record(record, x_min, y_min)
|
| 882 |
+
# context.add_tag(tag_obj)
|
| 883 |
+
#
|
| 884 |
+
# def _preprocess(self, image: np.ndarray) -> np.ndarray:
|
| 885 |
+
# return image
|
| 886 |
+
#
|
| 887 |
+
# def _postprocess(self, image: np.ndarray) -> np.ndarray:
|
| 888 |
+
# return image
|
| 889 |
+
#
|
| 890 |
+
# # --------------
|
| 891 |
+
# # HELPER METHODS
|
| 892 |
+
# # --------------
|
| 893 |
+
# def _load_json_data(self, json_path: str) -> dict:
|
| 894 |
+
# if not os.path.exists(json_path):
|
| 895 |
+
# self.debug_handler.save_artifact(name="tag_error",
|
| 896 |
+
# data=b"Missing tag JSON file",
|
| 897 |
+
# extension="txt")
|
| 898 |
+
# return {}
|
| 899 |
+
#
|
| 900 |
+
# with open(json_path, "r", encoding="utf-8") as f:
|
| 901 |
+
# return json.load(f)
|
| 902 |
+
#
|
| 903 |
+
# def _parse_tag_record(self, record: dict, x_min: int, y_min: int) -> Tag:
|
| 904 |
+
# """
|
| 905 |
+
# Builds a Tag object from a JSON record, adjusting coordinates for cropping.
|
| 906 |
+
# """
|
| 907 |
+
# bbox_list = record.get("bbox", [0, 0, 0, 0])
|
| 908 |
+
# bbox_obj = BBox(
|
| 909 |
+
# xmin=bbox_list[0] - x_min,
|
| 910 |
+
# ymin=bbox_list[1] - y_min,
|
| 911 |
+
# xmax=bbox_list[2] - x_min,
|
| 912 |
+
# ymax=bbox_list[3] - y_min
|
| 913 |
+
# )
|
| 914 |
+
#
|
| 915 |
+
# return Tag(
|
| 916 |
+
# text=record.get("text", ""),
|
| 917 |
+
# bbox=bbox_obj,
|
| 918 |
+
# confidence=record.get("confidence", 1.0),
|
| 919 |
+
# source=record.get("source", ""),
|
| 920 |
+
# text_type=record.get("text_type", "Unknown"),
|
| 921 |
+
# id=record.get("id", str(uuid.uuid4())),
|
| 922 |
+
# font_size=record.get("font_size", 12),
|
| 923 |
+
# rotation=record.get("rotation", 0.0)
|
| 924 |
+
# )
|
| 925 |
+
|
| 926 |
+
|
| 927 |
+
import json
|
| 928 |
+
import uuid
|
| 929 |
+
|
| 930 |
+
class SymbolDetector(BaseDetector):
    """
    Placeholder detector that loads precomputed symbol detections from a JSON
    file (instead of running a model) and populates the DetectionContext with
    Symbol objects whose coordinates are shifted into ROI-local space.
    """

    def __init__(self,
                 config: SymbolConfig,
                 debug_handler: Optional[DebugHandler] = None,
                 symbol_json_path: str = "./symbols.json"):
        """
        Args:
            config: Detector configuration (forwarded to BaseDetector).
            debug_handler: Optional handler used to persist debug artifacts.
            symbol_json_path: Path to the precomputed symbol-detection JSON.
        """
        super().__init__(config=config, debug_handler=debug_handler)
        self.symbol_json_path = symbol_json_path

    def _load_model(self, model_path: str):
        """Not loading an actual model; symbol data is read from JSON."""
        return None

    def detect(self,
               image: np.ndarray,
               context: DetectionContext,
               roi_offset: Tuple[int, int],
               *args,
               **kwargs) -> None:
        """
        Read symbol records from the JSON file, shift their bounding boxes by
        ``roi_offset`` (the crop origin), and add them to ``context``.
        """
        symbol_data = self._load_json_data(self.symbol_json_path)
        if not symbol_data:
            return

        x_min, y_min = roi_offset  # Offset values from cropping

        for record in symbol_data.get("detections", []):  # "detections" is the payload key
            context.add_symbol(self._parse_symbol_record(record, x_min, y_min))

    def _preprocess(self, image: np.ndarray) -> np.ndarray:
        # No preprocessing needed for the JSON-backed detector.
        return image

    def _postprocess(self, image: np.ndarray) -> np.ndarray:
        # No postprocessing needed for the JSON-backed detector.
        return image

    # --------------
    # HELPER METHODS
    # --------------
    def _save_error_artifact(self, message: str) -> None:
        """Persist an error note via the debug handler, if one is configured."""
        # Fix: debug_handler is Optional (defaults to None in __init__), but the
        # original dereferenced it unconditionally and raised AttributeError
        # whenever the JSON file was missing and no handler was supplied.
        if self.debug_handler is not None:
            self.debug_handler.save_artifact(name="symbol_error",
                                             data=message.encode("utf-8"),
                                             extension="txt")

    def _load_json_data(self, json_path: str) -> dict:
        """Return the parsed JSON payload, or {} if the file is missing/unreadable."""
        if not os.path.exists(json_path):
            self._save_error_artifact("Missing symbol JSON file")
            return {}

        try:
            with open(json_path, "r", encoding="utf-8") as f:
                return json.load(f)
        except (json.JSONDecodeError, OSError) as exc:
            # Fix: a corrupt or unreadable file previously crashed the pipeline;
            # treat it like a missing file (best-effort semantics) and record why.
            self._save_error_artifact(f"Unreadable symbol JSON file: {exc}")
            return {}

    def _parse_symbol_record(self, record: dict, x_min: int, y_min: int) -> Symbol:
        """
        Builds a Symbol object from a JSON record, adjusting coordinates for cropping.
        """
        bbox_list = record.get("bbox", [0, 0, 0, 0])
        bbox_obj = BBox(
            xmin=bbox_list[0] - x_min,
            ymin=bbox_list[1] - y_min,
            xmax=bbox_list[2] - x_min,
            ymax=bbox_list[3] - y_min
        )

        # Center of the (already offset-corrected) box, integer pixel coords.
        center_coords = Coordinates(
            x=(bbox_obj.xmin + bbox_obj.xmax) // 2,
            y=(bbox_obj.ymin + bbox_obj.ymax) // 2
        )

        return Symbol(
            id=record.get("symbol_id", ""),
            class_id=record.get("class_id", -1),
            original_label=record.get("original_label", ""),
            category=record.get("category", ""),
            type=record.get("type", ""),
            label=record.get("label", ""),
            bbox=bbox_obj,
            center=center_coords,
            confidence=record.get("confidence", 0.95),
            model_source=record.get("model_source", ""),
            connections=[]
        )
|
| 1017 |
+
|
| 1018 |
+
class TagDetector(BaseDetector):
    """
    Placeholder detector that loads precomputed tag (text) detections from a
    JSON file (instead of running a model) and populates the DetectionContext
    with Tag objects whose coordinates are shifted into ROI-local space.
    """

    def __init__(self,
                 config: TagConfig,
                 debug_handler: Optional[DebugHandler] = None,
                 tag_json_path: str = "./tags.json"):
        """
        Args:
            config: Detector configuration (forwarded to BaseDetector).
            debug_handler: Optional handler used to persist debug artifacts.
            tag_json_path: Path to the precomputed tag-detection JSON.
        """
        super().__init__(config=config, debug_handler=debug_handler)
        self.tag_json_path = tag_json_path

    def _load_model(self, model_path: str):
        """Not loading an actual model; tag data is read from JSON."""
        return None

    def detect(self,
               image: np.ndarray,
               context: DetectionContext,
               roi_offset: Tuple[int, int],
               *args,
               **kwargs) -> None:
        """
        Read tag records from the JSON file, shift their bounding boxes by
        ``roi_offset`` (the crop origin), and add them to ``context``.
        """
        tag_data = self._load_json_data(self.tag_json_path)
        if not tag_data:
            return

        x_min, y_min = roi_offset  # Offset values from cropping

        for record in tag_data.get("detections", []):  # "detections" is the payload key
            context.add_tag(self._parse_tag_record(record, x_min, y_min))

    def _preprocess(self, image: np.ndarray) -> np.ndarray:
        # No preprocessing needed for the JSON-backed detector.
        return image

    def _postprocess(self, image: np.ndarray) -> np.ndarray:
        # No postprocessing needed for the JSON-backed detector.
        return image

    # --------------
    # HELPER METHODS
    # --------------
    def _save_error_artifact(self, message: str) -> None:
        """Persist an error note via the debug handler, if one is configured."""
        # Fix: debug_handler is Optional (defaults to None in __init__), but the
        # original dereferenced it unconditionally and raised AttributeError
        # whenever the JSON file was missing and no handler was supplied.
        if self.debug_handler is not None:
            self.debug_handler.save_artifact(name="tag_error",
                                             data=message.encode("utf-8"),
                                             extension="txt")

    def _load_json_data(self, json_path: str) -> dict:
        """Return the parsed JSON payload, or {} if the file is missing/unreadable."""
        if not os.path.exists(json_path):
            self._save_error_artifact("Missing tag JSON file")
            return {}

        try:
            with open(json_path, "r", encoding="utf-8") as f:
                return json.load(f)
        except (json.JSONDecodeError, OSError) as exc:
            # Fix: a corrupt or unreadable file previously crashed the pipeline;
            # treat it like a missing file (best-effort semantics) and record why.
            self._save_error_artifact(f"Unreadable tag JSON file: {exc}")
            return {}

    def _parse_tag_record(self, record: dict, x_min: int, y_min: int) -> Tag:
        """
        Builds a Tag object from a JSON record, adjusting coordinates for cropping.
        """
        bbox_list = record.get("bbox", [0, 0, 0, 0])
        bbox_obj = BBox(
            xmin=bbox_list[0] - x_min,
            ymin=bbox_list[1] - y_min,
            xmax=bbox_list[2] - x_min,
            ymax=bbox_list[3] - y_min
        )

        return Tag(
            text=record.get("text", ""),
            bbox=bbox_obj,
            confidence=record.get("confidence", 1.0),
            source=record.get("source", ""),
            text_type=record.get("text_type", "Unknown"),
            id=record.get("id", str(uuid.uuid4())),
            font_size=record.get("font_size", 12),
            rotation=record.get("rotation", 0.0)
        )
|
gradioChatApp.py
ADDED
|
@@ -0,0 +1,672 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Standard library
import base64
import json
import logging
import os
import re
import shutil
import traceback
from datetime import datetime
from pathlib import Path

# Third-party
import gradio as gr
import matplotlib.pyplot as plt
import networkx as nx
import torch
from dotenv import load_dotenv

# Project modules
from chatbot_agent import get_assistant_response
from data_aggregation_ai import DataAggregator
from graph_visualization import create_graph_visualization
from line_detection_ai import DiagramDetectionPipeline, LineDetector, LineConfig, ImageConfig, DebugHandler, PointConfig, JunctionConfig, PointDetector, JunctionDetector, SymbolConfig, SymbolDetector, TagConfig, TagDetector
from pdf_processor import DocumentProcessor
from storage import StorageFactory, LocalStorage
from symbol_detection import run_detection_with_optimal_threshold
from text_detection_combined import process_drawing
|
| 22 |
+
|
| 23 |
+
# Load environment variables from a local .env file (API keys, config).
load_dotenv()

# Configure root logging once, at import time, for the whole app.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)

# Module-level logger used throughout this file.
logger = logging.getLogger(__name__)

# Quiet chatty third-party and pipeline loggers so only warnings surface.
logging.getLogger('PIL').setLevel(logging.WARNING)
logging.getLogger('matplotlib').setLevel(logging.WARNING)
logging.getLogger('gradio').setLevel(logging.WARNING)
logging.getLogger('networkx').setLevel(logging.WARNING)
logging.getLogger('line_detection_ai').setLevel(logging.WARNING)
logging.getLogger('symbol_detection').setLevel(logging.WARNING)
|
| 43 |
+
|
| 44 |
+
# Only log important messages
def log_process_step(message, level=logging.INFO):
    """Emit a processing-step log entry, suppressing low-value chatter.

    Warnings and errors always go through; INFO-level messages are only
    logged when they announce a completed/generated milestone.
    """
    if level >= logging.WARNING:
        logger.log(level, message)
        return
    lowered = message.lower()
    if any(keyword in lowered for keyword in ("completed", "generated")):
        logger.info(message)
+
|
| 52 |
+
# Helper function to format timestamps
def get_timestamp():
    """Return the current local time as 'YYYY-MM-DD HH:MM:SS'."""
    return f"{datetime.now():%Y-%m-%d %H:%M:%S}"
|
| 55 |
+
|
| 56 |
+
def format_message(role, content):
    """Package one chat turn as the {'role', 'content'} dict the chatbot expects."""
    return dict(role=role, content=content)
|
| 59 |
+
|
| 60 |
+
# Load avatar images for the chat UI and embed them as base64 strings.
# NOTE(review): these run at import time and will raise if any asset file is
# missing — confirm the assets/ directory ships with every deployment.
localStorage = LocalStorage()
agent_avatar = base64.b64encode(localStorage.load_file("assets/AiAgent.png")).decode()
llm_avatar = base64.b64encode(localStorage.load_file("assets/llm.png")).decode()
user_avatar = base64.b64encode(localStorage.load_file("assets/user.png")).decode()
|
| 65 |
+
|
| 66 |
+
# Chat message formatting with avatars and enhanced HTML for readability
def chat_message(role, message, avatar, timestamp):
    """
    Render a single chat turn as an HTML bubble with an avatar and timestamp.

    Args:
        role: CSS role class ("user", "agent", "assistant", ...).
        message: Message text with lightweight Markdown-style markup.
        avatar: Base64-encoded PNG for the speaker's avatar.
        timestamp: Preformatted time string shown under the bubble.

    Returns:
        An HTML fragment string styled by the classes defined in ``custom_css``.
    """
    # Paired markers must be converted with regexes. The original chained
    # str.replace() calls were broken: .replace("**", "<strong>") consumed
    # every "**", so the follow-up .replace("**", "</strong>") never matched
    # and no closing tag was ever emitted (same for ``` and `).
    formatted_message = re.sub(r"\*\*(.+?)\*\*", r"<strong>\1</strong>",
                               message, flags=re.DOTALL)
    formatted_message = re.sub(r"```(.*?)```", r"<pre><code>\1</code></pre>",
                               formatted_message, flags=re.DOTALL)
    formatted_message = re.sub(r"`([^`]+)`", r"<code>\1</code>", formatted_message)

    # Heading markers, most specific first so "###" is not eaten by "#".
    # NOTE(review): these still emit unclosed <h1>/<h2>/<h3> tags, as before;
    # a real Markdown renderer would be safer if richer formatting is needed.
    formatted_message = (
        formatted_message.replace("###", "<h3>")
        .replace("##", "<h2>")
        .replace("#", "<h1>")
        .replace("\n", "<br>")
    )
    # (The original "\n1. " → "<br>1. " replacements were dead code: they ran
    # after every "\n" had already become "<br>" and could never match.)

    return f"""
    <div class="chat-message {role}">
        <img src="data:image/png;base64,{avatar}" class="avatar"/>
        <div>
            <div class="speech-bubble {role}-bubble">{formatted_message}</div>
            <div class="timestamp">{timestamp}</div>
        </div>
    </div>
    """
|
| 90 |
+
|
| 91 |
+
# Main processing function for P&ID steps
def process_pnid(image_file, progress_status, progress=gr.Progress()):
    """Process a P&ID document end-to-end, yielding UI updates after each step.

    Generator used by Gradio: each ``yield outputs`` pushes the 9-slot output
    list to the bound components. Slots (by index):
      0: processed page image, 1: symbol-detection image, 2: text-detection
      image, 3: line-detection image, 4: aggregated view, 5: graph image,
      6: welcome chat message, 7: progress log text, 8: aggregated JSON path.

    Args:
        image_file: Path to the uploaded PDF/PNG/JPG document.
        progress_status: Bound status textbox value (unused directly here).
        progress: Gradio progress tracker (injected by Gradio).
    """
    try:
        # Disable verbose logging for processing components
        logging.getLogger('line_detection_ai').setLevel(logging.WARNING)
        logging.getLogger('symbol_detection').setLevel(logging.WARNING)
        logging.getLogger('text_detection').setLevel(logging.WARNING)

        progress_text = []
        outputs = [None] * 9

        def update_progress(step, message):
            # Closure over progress_text/outputs: reads whatever those names
            # are currently bound to in the enclosing scope.
            timestamp = get_timestamp()
            progress_text.append(f"{timestamp} - {message}")
            outputs[7] = "\n".join(progress_text[-20:])  # Keep last 20 lines
            # NOTE: step is a fraction (0.0-1.0), so the desc reads e.g.
            # "Step 0.45/7: ..." — cosmetic quirk, left as-is.
            progress(step, desc=f"Step {step}/7: {message}")
            return outputs

        # Update progress with smaller steps
        update_progress(0.1, "Starting processing...")
        yield outputs

        storage = StorageFactory.get_storage()
        results_dir = "results"
        # NOTE(review): rebinding outputs/progress_text below discards the
        # "Starting processing..." entry appended above — confirm intentional.
        outputs = [None] * 9

        if image_file is None:
            raise ValueError("No file uploaded")

        os.makedirs(results_dir, exist_ok=True)
        current_progress = 0  # NOTE(review): unused after this point
        progress_text = []

        # Step 1: File Upload (10%)
        logger.info(f"Processing file: {os.path.basename(image_file)}")
        update_progress(0.1, "Step 1/7: File uploaded successfully")
        yield outputs

        # Step 2: Document Processing (25%) — PDF pages rendered to images.
        update_progress(0.25, "Step 2/7: Processing document...")
        yield outputs

        doc_processor = DocumentProcessor(storage)
        processed_pages = doc_processor.process_document(
            file_path=image_file,
            output_dir=results_dir
        )

        if not processed_pages:
            raise ValueError("No pages processed from document")

        # Only the first page is processed/displayed.
        display_path = processed_pages[0]
        outputs[0] = display_path
        update_progress(0.25, "Document processed successfully")
        yield outputs

        # Step 3: Symbol Detection (45%)
        update_progress(0.45, "Step 3/7: Symbol Detection")
        yield outputs

        # Store detection results and diagram_bbox
        detection_results = run_detection_with_optimal_threshold(
            display_path,
            results_dir=results_dir,
            file_name=os.path.basename(display_path),
            resize_image=True,
            storage=storage
        )
        detection_image_path, detection_json_path, _, diagram_bbox = detection_results

        if diagram_bbox is None:
            logger.warning("No diagram bounding box detected, using full image")
            # Provide a fallback bbox if needed
            diagram_bbox = [0, 0, 0, 0]  # Or get image dimensions

        outputs[1] = detection_image_path
        update_progress(0.45, "Symbol detection completed")
        yield outputs

        # Step 4: Text Detection (65%)
        update_progress(0.65, "Step 4/7: Text Detection")
        yield outputs

        text_results, text_summary = process_drawing(display_path, results_dir, storage)
        outputs[2] = text_results['image_path']

        update_progress(0.65, "Text detection completed")
        update_progress(0.65, f"Found {text_summary['total_detections']} text elements")
        yield outputs

        # Step 5: Line Detection (80%) — DeepLSD-based pipeline on CPU.
        update_progress(0.80, "Step 5/7: Line Detection")
        yield outputs

        try:
            # Initialize components
            debug_handler = DebugHandler(enabled=True, storage=storage)

            # Configure detectors
            line_config = LineConfig()
            point_config = PointConfig()
            junction_config = JunctionConfig()
            symbol_config = SymbolConfig()
            tag_config = TagConfig()

            # Create all required detectors
            symbol_detector = SymbolDetector(
                config=symbol_config,
                debug_handler=debug_handler
            )

            tag_detector = TagDetector(
                config=tag_config,
                debug_handler=debug_handler
            )

            line_detector = LineDetector(
                config=line_config,
                model_path="models/deeplsd_md.tar",
                model_config={"detect_lines": True},
                device=torch.device("cpu"),
                debug_handler=debug_handler
            )

            point_detector = PointDetector(
                config=point_config,
                debug_handler=debug_handler
            )

            junction_detector = JunctionDetector(
                config=junction_config,
                debug_handler=debug_handler
            )

            # Create and run pipeline with all detectors
            pipeline = DiagramDetectionPipeline(
                tag_detector=tag_detector,
                symbol_detector=symbol_detector,
                line_detector=line_detector,
                point_detector=point_detector,
                junction_detector=junction_detector,
                storage=storage,
                debug_handler=debug_handler
            )

            # Run pipeline
            result = pipeline.run(
                image_path=display_path,
                output_dir=results_dir,
                config=ImageConfig()
            )

            if result.success:
                line_image_path = result.image_path
                line_json_path = result.json_path
                outputs[3] = line_image_path
                update_progress(0.80, "Line detection completed")
            else:
                logger.error(f"Pipeline failed: {result.error}")
                raise Exception(result.error)

        except Exception as e:
            # Log and re-raise so the outer handler reports it to the UI.
            logger.error(f"Line detection error: {str(e)}")
            raise

        # Step 6: Data Aggregation (90%) — merge symbol/text/line JSONs.
        update_progress(0.90, "Step 6/7: Data Aggregation")
        yield outputs

        data_aggregator = DataAggregator(storage=storage)
        aggregated_data = data_aggregator.aggregate_data(
            symbols_path=detection_json_path,
            texts_path=text_results['json_path'],
            lines_path=line_json_path
        )

        # Add image path to aggregated data
        aggregated_data['image_path'] = display_path

        # Save aggregated data
        aggregated_json_path = os.path.join(results_dir, f"{Path(display_path).stem}_aggregated.json")
        with open(aggregated_json_path, 'w') as f:
            json.dump(aggregated_data, f, indent=2)

        # Use the detection image as the aggregated view for now
        # TODO: Implement visualization in DataAggregator if needed
        outputs[4] = detection_image_path  # Changed from aggregated_image_path
        outputs[8] = aggregated_json_path
        update_progress(0.90, "Data aggregation completed")
        yield outputs

        # Step 7: Graph Generation (95%) — build the knowledge graph image.
        update_progress(0.95, "Step 7/7: Generating knowledge graph...")
        yield outputs

        try:
            with open(aggregated_json_path, 'r') as f:
                aggregated_detection_data = json.load(f)

            logger.info("Creating knowledge graph...")

            # Create graph visualization - this will save the visualization file
            G, _ = create_graph_visualization(aggregated_json_path, save_plot=True)

            if G is not None:
                # Use the saved visualization file
                graph_image_path = os.path.join(os.path.dirname(aggregated_json_path), "graph_visualization.png")

                if os.path.exists(graph_image_path):
                    outputs[5] = graph_image_path
                    update_progress(0.95, "Knowledge graph generated")
                    logger.info("Knowledge graph generated and saved successfully")

                    # Final completion (100%)
                    update_progress(1.0, "✅ Processing Complete")
                    welcome_message = chat_message(
                        "agent",
                        "Processing complete! I can help answer questions about the P&ID contents.",
                        agent_avatar,
                        get_timestamp()
                    )
                    outputs[6] = welcome_message
                    update_progress(1.0, "✅ All processing steps completed successfully!")
                    yield outputs
                else:
                    logger.warning("Graph visualization file not found")
                    update_progress(1.0, "⚠️ Warning: Graph visualization could not be generated")
                    yield outputs
            else:
                logger.warning("No graph was generated")
                update_progress(1.0, "⚠️ Warning: No graph could be generated")
                yield outputs

        except Exception as e:
            logger.error(f"Error in graph generation: {str(e)}")
            logger.error(f"Traceback: {traceback.format_exc()}")
            raise

    except Exception as e:
        # Top-level guard: surface any failure to the UI via the progress log.
        logger.error(f"Error in process_pnid: {str(e)}")
        logger.error(traceback.format_exc())
        error_msg = f"❌ Error: {str(e)}"
        update_progress(1.0, error_msg)
        yield outputs
|
| 336 |
+
|
| 337 |
+
# Separate function for Chat interaction
def handle_user_message(user_input, chat_history, json_path_state):
    """Append the user's question and the assistant's answer to the chat HTML.

    Args:
        user_input: Raw text typed by the user.
        chat_history: Accumulated chat HTML string (concatenated bubbles).
        json_path_state: Path to the aggregated-detection JSON produced by
            process_pnid; required before questions can be answered.

    Returns:
        The updated chat HTML string (never raises to the UI).
    """
    try:
        # Ignore empty/whitespace-only submissions.
        if not user_input or not user_input.strip():
            return chat_history

        # Add user message
        timestamp = get_timestamp()
        new_history = chat_history + chat_message("user", user_input, user_avatar, timestamp)

        # Check if json_path exists and is valid
        if not json_path_state or not os.path.exists(json_path_state):
            error_message = "Please upload and process a P&ID document first."
            return new_history + chat_message("assistant", error_message, agent_avatar, get_timestamp())

        try:
            # Log for debugging
            logger.info(f"Sending question to assistant: {user_input}")
            logger.info(f"Using JSON path: {json_path_state}")

            # Generate response
            response = get_assistant_response(user_input, json_path_state)

            # Handle the response: get_assistant_response may return a plain
            # str/dict or a generator/stream — normalize to a single string.
            if isinstance(response, (str, dict)):
                response_text = str(response)
            else:
                try:
                    # Try to get the first response from generator
                    response_text = next(response) if hasattr(response, '__next__') else str(response)
                except StopIteration:
                    # Generator produced nothing.
                    response_text = "I apologize, but I couldn't generate a response."
                except Exception as e:
                    logger.error(f"Error processing response: {str(e)}")
                    response_text = "I apologize, but I encountered an error processing your request."

            logger.info(f"Generated response: {response_text}")

            # Fall back to a canned message rather than an empty bubble.
            if not response_text.strip():
                response_text = "I apologize, but I couldn't generate a response. Please try asking your question differently."

            # Add response to chat history
            new_history += chat_message("assistant", response_text, agent_avatar, get_timestamp())

        except Exception as e:
            # Assistant failure: keep the user's message, append an apology.
            logger.error(f"Error generating response: {str(e)}")
            logger.error(traceback.format_exc())
            error_message = "I apologize, but I encountered an error processing your request. Please try again."
            new_history += chat_message("assistant", error_message, agent_avatar, get_timestamp())

        return new_history

    except Exception as e:
        # Last-resort guard: never propagate an exception into the Gradio UI.
        logger.error(f"Chat error: {str(e)}")
        logger.error(traceback.format_exc())
        return chat_history + chat_message(
            "assistant",
            "I apologize, but something went wrong. Please try again.",
            agent_avatar,
            get_timestamp()
        )
|
| 399 |
+
|
| 400 |
+
# Custom CSS injected into the Gradio Blocks app: dark theme, fixed-height
# panels for the status log / preview tabs, and chat-bubble styling used by
# the HTML fragments produced in chat_message().
custom_css = """
.full-height-row {
    height: calc(100vh - 150px); /* Adjusted height */
    margin: 0;
    padding: 10px;
}
.upload-box {
    background: #2a2a2a;
    border-radius: 8px;
    padding: 15px;
    margin-bottom: 15px;
    border: 1px solid #3a3a3a;
}
.status-box-container {
    background: #2a2a2a;
    border-radius: 8px;
    padding: 15px;
    height: calc(100vh - 350px); /* Reduced height */
    border: 1px solid #3a3a3a;
    margin-bottom: 15px;
}
.status-box {
    font-family: 'Courier New', monospace;
    font-size: 12px;
    line-height: 1.4;
    background-color: #1a1a1a;
    color: #00ff00;
    padding: 10px;
    border-radius: 5px;
    height: calc(100% - 40px); /* Adjust for header */
    overflow-y: auto;
    white-space: pre-wrap;
    word-wrap: break-word;
    border: none;
}
.preview-tabs {
    height: calc(100vh - 350px); /* Reduced height */
    background: #2a2a2a;
    border-radius: 8px;
    padding: 15px;
    border: 1px solid #3a3a3a;
    margin-bottom: 15px;
}
.chat-container {
    height: 100%; /* Take full height */
    display: flex;
    flex-direction: column;
    background: #2a2a2a;
    border-radius: 8px;
    padding: 15px;
    border: 1px solid #3a3a3a;
}
.chatbox {
    flex: 1; /* Take remaining space */
    overflow-y: auto;
    background: #1a1a1a;
    border-radius: 8px;
    padding: 15px;
    margin-bottom: 15px;
    color: #ffffff;
    min-height: 200px; /* Ensure minimum height */
}
.chat-input-group {
    height: auto; /* Allow natural height */
    min-height: 120px; /* Minimum height for input area */
    background: #1a1a1a;
    border-radius: 8px;
    padding: 15px;
    margin-top: auto; /* Push to bottom */
}
.chat-input {
    background: #2a2a2a;
    color: #ffffff;
    border: 1px solid #3a3a3a;
    border-radius: 5px;
    padding: 12px;
    min-height: 80px;
    width: 100%;
    margin-bottom: 10px;
}
.send-button {
    width: 100%;
    background: #4a4a4a;
    color: #ffffff;
    border-radius: 5px;
    border: none;
    padding: 12px;
    cursor: pointer;
    transition: background-color 0.3s;
}
.result-image {
    border-radius: 8px;
    box-shadow: 0 2px 4px rgba(0,0,0,0.1);
    margin: 10px 0;
    background: #ffffff;
}
.chat-message {
    display: flex;
    margin-bottom: 1rem;
    align-items: flex-start;
}
.chat-message .avatar {
    width: 40px;
    height: 40px;
    margin-right: 10px;
    border-radius: 50%;
}
.chat-message .speech-bubble {
    background: #2a2a2a;
    padding: 10px 15px;
    border-radius: 10px;
    max-width: 80%;
    margin-bottom: 5px;
}
.chat-message .timestamp {
    font-size: 0.8em;
    color: #666;
}
.logo-row {
    width: 100%;
    background-color: #1a1a1a;
    padding: 10px 0;
    margin: 0;
    border-bottom: 1px solid #3a3a3a;
}
"""
|
| 527 |
+
|
| 528 |
+
def create_ui():
    """Build and return the Gradio Blocks application.

    Layout:
      * a full-width logo banner row,
      * a three-column main row: upload + processing status (left, scale=2),
        tabbed result previews (center, scale=5), chat panel (right, scale=3).

    All event handlers (upload pipeline, send button, Enter-to-send) are wired
    inside the Blocks context so they bind to the components created here.
    """
    with gr.Blocks(css=custom_css) as demo:
        # Logo row: the logo is inlined as a base64 data URI so no static file
        # route is needed at serve time; failures are logged, never fatal.
        with gr.Row(elem_classes=["logo-row"]):
            try:
                logo_path = os.path.join(os.path.dirname(__file__), "assets", "intuigence.png")
                if os.path.exists(logo_path):
                    with open(logo_path, "rb") as f:
                        logo_base64 = base64.b64encode(f.read()).decode()
                    gr.HTML(f"""
                        <div style="text-align: center; padding: 10px; background-color: #1a1a1a; width: 100%;">
                            <img src="data:image/png;base64,{logo_base64}"
                                 alt="Intuigence Logo"
                                 style="height: 60px; object-fit: contain;">
                        </div>
                    """)
                else:
                    logger.warning(f"Logo not found at {logo_path}")
            except Exception as e:
                logger.error(f"Error loading logo: {e}")

        # Main layout
        with gr.Row(equal_height=True, elem_classes=["full-height-row"]):
            # Left column: document upload and live processing status
            with gr.Column(scale=2):
                # Upload area
                with gr.Column(elem_classes=["upload-box"]):
                    image_input = gr.File(
                        label="Upload P&ID Document",
                        file_types=[".pdf", ".png", ".jpg", ".jpeg"],
                        file_count="single",
                        type="filepath"
                    )

                # Status area
                with gr.Column(elem_classes=["status-box-container"]):
                    gr.Markdown("### Processing Status")
                    progress_status = gr.Textbox(
                        label="Status",
                        show_label=False,
                        elem_classes=["status-box"],
                        lines=15,
                        max_lines=20,
                        interactive=False,
                        autoscroll=True,
                        value=""  # Initialize with empty value
                    )
                # Hidden state: path to the aggregated JSON produced by
                # process_pnid; later fed into the chat handler.
                json_path_state = gr.State()

            # Center column: one tab per pipeline stage output image
            with gr.Column(scale=5):
                with gr.Tabs(elem_classes=["preview-tabs"]) as tabs:
                    with gr.TabItem("P&ID"):
                        original_image = gr.Image(label="Original P&ID", height=450)  # Reduced height
                    with gr.TabItem("Symbols"):
                        symbol_image = gr.Image(label="Detected Symbols", height=450)
                    with gr.TabItem("Tags"):
                        text_image = gr.Image(label="Detected Tags", height=450)
                    with gr.TabItem("Pipelines"):
                        line_image = gr.Image(label="Detected Lines", height=450)
                    with gr.TabItem("Aggregated"):
                        aggregated_image = gr.Image(label="Aggregated Results", height=450)
                    with gr.TabItem("Graph"):
                        graph_image = gr.Image(label="Knowledge Graph", height=450)

            # Right column: chat interface rendered as raw HTML
            with gr.Column(scale=3):
                with gr.Column(elem_classes=["chat-container"]):
                    gr.Markdown("### Chat Interface")
                    # Initialize chat with a welcome message
                    initial_chat = chat_message(
                        "agent",
                        "Ready to process P&ID documents and answer questions.",
                        agent_avatar,
                        get_timestamp()
                    )
                    chat_output = gr.HTML(
                        label="Chat",
                        elem_classes=["chatbox"],
                        value=initial_chat
                    )
                    # Message input and send button in a fixed-height container
                    with gr.Column(elem_classes=["chat-input-group"]):
                        user_input = gr.Textbox(
                            show_label=False,
                            placeholder="Type your question here...",
                            elem_classes=["chat-input"],
                            lines=3
                        )
                        send_button = gr.Button(
                            "Send",
                            elem_classes=["send-button"]
                        )

        # Set up event handlers inside the Blocks context.
        # process_pnid drives the whole detection pipeline and populates every
        # preview tab, the chat transcript, the status box, and the JSON path state.
        image_input.upload(
            fn=process_pnid,
            inputs=[image_input, progress_status],
            outputs=[
                original_image,
                symbol_image,
                text_image,
                line_image,
                aggregated_image,
                graph_image,
                chat_output,
                progress_status,
                json_path_state
            ],
            show_progress="hidden"  # Hide the default progress bar
        )

        # Add input clearing and enable/disable logic for chat.
        # Wrapper returns ("", response) so the textbox empties after sending.
        def clear_and_handle_message(user_message, chat_history, json_path):
            response = handle_user_message(user_message, chat_history, json_path)
            return "", response  # Clear input after sending

        send_button.click(
            fn=clear_and_handle_message,
            inputs=[user_input, chat_output, json_path_state],
            outputs=[user_input, chat_output]
        )

        # Also trigger on Enter key
        user_input.submit(
            fn=clear_and_handle_message,
            inputs=[user_input, chat_output, json_path_state],
            outputs=[user_input, chat_output]
        )

    return demo
|
| 659 |
+
|
| 660 |
+
def main():
    """Build the Gradio UI and serve it locally on port 7860 (no sharing)."""
    ui = create_ui()
    ui.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
    )
|
| 666 |
+
|
| 667 |
+
if __name__ == "__main__":
    main()
else:
    # For Spaces deployment: Spaces imports this module and expects a
    # module-level Blocks object (`demo`). Note that `Blocks.app` is only
    # populated inside launch(), so reading `demo.app` unconditionally at
    # import time raised AttributeError; expose it defensively instead.
    demo = create_ui()
    app = getattr(demo, "app", None)  # set for real once demo.launch() runs
|
graph_construction.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import networkx as nx
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
import logging
|
| 7 |
+
import traceback
|
| 8 |
+
from storage import StorageFactory
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
def construct_graph_network(data: dict, validation_results_path: str, results_dir: str, storage=None):
    """Construct a network graph from aggregated detection data and render it.

    Args:
        data: Aggregated detection dict with 'nodes' and 'edges' lists and,
            optionally, an 'image_path' used to derive output file names.
        validation_results_path: Path to validation results (currently unused
            here; kept for interface compatibility with callers).
        results_dir: Directory where the PNG rendering and graph JSON are written.
        storage: Optional storage backend; a default one is created when omitted.

    Returns:
        Tuple (G, pos, fig) on success, or (None, None, None) on any failure.
    """
    try:
        # Use provided storage or get a new one
        if storage is None:
            storage = StorageFactory.get_storage()

        G = nx.Graph()
        pos = {}  # node id -> (x, y) drawing position

        # Add nodes: connection points carry explicit coords; symbols/texts
        # are positioned at their bounding-box centers.
        for node in data.get('nodes', []):
            node_id = node['id']
            node_type = node['type']

            if node_type == 'connection_point':
                pos[node_id] = (node['coords']['x'], node['coords']['y'])
            else:  # symbol or text
                bbox = node['bbox']
                pos[node_id] = (
                    (bbox['xmin'] + bbox['xmax']) / 2,
                    (bbox['ymin'] + bbox['ymax']) / 2
                )

            # Add node with all its properties
            G.add_node(node_id, **node)

        # Add edges from the aggregated data
        for edge in data.get('edges', []):
            G.add_edge(
                edge['source'],
                edge['target'],
                **edge.get('properties', {})
            )

        # Keep a handle on the figure now: calling plt.gcf() after plt.close()
        # (as the previous version did) creates a brand-new EMPTY figure, so
        # callers received a blank plot.
        fig = plt.figure(figsize=(20, 20))

        # Color nodes by type
        node_colors = []
        for node in G.nodes():
            node_type = G.nodes[node]['type']
            if node_type == 'symbol':
                node_colors.append('lightblue')
            elif node_type == 'text':
                node_colors.append('lightgreen')
            else:  # connection_point
                node_colors.append('lightgray')

        nx.draw_networkx_nodes(G, pos, node_color=node_colors, node_size=500)
        nx.draw_networkx_edges(G, pos, edge_color='gray', width=1)

        # Labels: S:<class> for symbols, T:<truncated text> for texts,
        # C:<point_type> for connection points.
        labels = {}
        for node in G.nodes():
            node_data = G.nodes[node]
            if node_data['type'] == 'symbol':
                labels[node] = f"S:{node_data.get('properties', {}).get('class', '')}"
            elif node_data['type'] == 'text':
                content = node_data.get('content', '')
                labels[node] = f"T:{content[:10]}..." if len(content) > 10 else f"T:{content}"
            else:
                labels[node] = f"C:{node_data['properties'].get('point_type', '')}"

        nx.draw_networkx_labels(G, pos, labels, font_size=8)

        plt.title("P&ID Knowledge Graph")
        plt.axis('off')

        # Save the visualization
        stem = Path(data.get('image_path', 'graph')).stem
        graph_image_path = os.path.join(results_dir, f"{stem}_graph.png")
        plt.savefig(graph_image_path, bbox_inches='tight', dpi=300)

        # Save graph data as JSON for future use
        graph_json_path = os.path.join(results_dir, f"{stem}_graph_data.json")
        with open(graph_json_path, 'w') as f:
            json.dump(nx.node_link_data(G), f, indent=2)

        # Return the live (not closed) figure so callers can still display it.
        return G, pos, fig

    except Exception as e:
        logger.error(f"Error in construct_graph_network: {str(e)}")
        traceback.print_exc()
        return None, None, None
|
| 99 |
+
|
| 100 |
+
if __name__ == "__main__":
    # Smoke test: render the graph for a locally saved aggregation result.
    test_data_path = "results/test_aggregated.json"
    if os.path.exists(test_data_path):
        with open(test_data_path, 'r') as f:
            test_data = json.load(f)

        G, pos, fig = construct_graph_network(
            test_data, "results/validation.json", "results"
        )
        if fig is not None:
            plt.show()
|
graph_processor.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import networkx as nx
|
| 3 |
+
import numpy as np
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
import traceback
|
| 6 |
+
import uuid
|
| 7 |
+
|
| 8 |
+
def create_connected_graph(input_data):
    """Create a connected graph from the input data.

    Builds a NetworkX graph from an aggregated-detection dict (symbol and
    text nodes positioned at bbox centers, edges from the 'edges' list) and
    draws it with matplotlib at the detections' fixed positions.

    Args:
        input_data: dict that must contain the keys
            'symbols', 'texts', 'lines', 'nodes', 'edges'.
            NOTE(review): only 'symbols', 'texts' and 'edges' are consumed
            below; 'lines' and 'nodes' are merely required to be present.

    Returns:
        Tuple (G, pos, fig) on success, or (None, None, None) on any failure
        (the exception is printed, not re-raised).
    """
    try:
        # Validate input data structure
        if not isinstance(input_data, dict):
            raise ValueError("Invalid input data format")

        # Check for required keys in new format
        required_keys = ['symbols', 'texts', 'lines', 'nodes', 'edges']
        if not all(key in input_data for key in required_keys):
            raise ValueError(f"Missing required keys in input data. Expected: {required_keys}")

        # Create graph
        G = nx.Graph()

        # Track positions for layout (node id -> (x, y))
        pos = {}

        # Add symbol nodes.
        # NOTE(review): entries without a bbox are skipped entirely (never
        # added to the graph), which also means later edges referencing them
        # are dropped by the `source in G and target in G` guard below.
        for symbol in input_data['symbols']:
            bbox = symbol.get('bbox', [])
            symbol_id = symbol.get('id', str(uuid.uuid4()))

            if bbox:
                # Calculate center position
                center_x = (bbox['xmin'] + bbox['xmax']) / 2
                center_y = (bbox['ymin'] + bbox['ymax']) / 2
                pos[symbol_id] = (center_x, center_y)

                G.add_node(
                    symbol_id,
                    type='symbol',
                    class_name=symbol.get('class', ''),
                    bbox=bbox,
                    confidence=symbol.get('confidence', 0.0)
                )

        # Add text nodes (same bbox-center placement as symbols)
        for text in input_data['texts']:
            bbox = text.get('bbox', [])
            text_id = text.get('id', str(uuid.uuid4()))

            if bbox:
                center_x = (bbox['xmin'] + bbox['xmax']) / 2
                center_y = (bbox['ymin'] + bbox['ymax']) / 2
                pos[text_id] = (center_x, center_y)

                G.add_node(
                    text_id,
                    type='text',
                    text=text.get('text', ''),
                    bbox=bbox,
                    confidence=text.get('confidence', 0.0)
                )

        # Add edges from the edges list; silently skip edges whose endpoints
        # were not added above.
        for edge in input_data['edges']:
            source = edge.get('source')
            target = edge.get('target')
            if source and target and source in G and target in G:
                G.add_edge(
                    source,
                    target,
                    type=edge.get('type', 'connection'),
                    properties=edge.get('properties', {})
                )

        # Create visualization
        plt.figure(figsize=(20, 20))

        # Draw nodes with fixed positions (blue = symbol, green = text)
        nx.draw_networkx_nodes(G, pos,
            node_color=['lightblue' if G.nodes[node]['type'] == 'symbol' else 'lightgreen' for node in G.nodes()],
            node_size=500)

        # Draw edges
        nx.draw_networkx_edges(G, pos, edge_color='gray', width=1)

        # Add labels: S:<class> for symbols, T:<text truncated to 10 chars>
        labels = {}
        for node in G.nodes():
            node_data = G.nodes[node]
            if node_data['type'] == 'symbol':
                labels[node] = f"S:{node_data['class_name']}"
            else:
                text = node_data.get('text', '')
                labels[node] = f"T:{text[:10]}..." if len(text) > 10 else f"T:{text}"

        nx.draw_networkx_labels(G, pos, labels, font_size=8)

        plt.title("P&ID Network Graph")
        plt.axis('off')

        # The figure created above is still the current one, so gcf() is valid here.
        return G, pos, plt.gcf()

    except Exception as e:
        print(f"Error in create_connected_graph: {str(e)}")
        traceback.print_exc()
        return None, None, None
|
| 107 |
+
|
| 108 |
+
if __name__ == "__main__":
    # Smoke test against a locally saved aggregation result.
    with open('results/0_aggregated.json') as handle:
        payload = json.load(handle)

    graph, layout, figure = create_connected_graph(payload)
    if figure is not None:
        plt.show()
|
graph_visualization.py
ADDED
|
@@ -0,0 +1,235 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import networkx as nx
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
import os
|
| 5 |
+
from pprint import pprint
|
| 6 |
+
import uuid
|
| 7 |
+
|
| 8 |
+
def create_graph_visualization(json_path, save_plot=True):
    """Create and visualize a graph from the aggregated JSON data.

    Reads the aggregated detection JSON at *json_path*, builds a NetworkX
    graph with junction nodes (line endpoints, de-duplicated by integer
    coordinates), symbol nodes and text nodes, scales all positions into
    [0, 1] (Y flipped so the image origin matches), and — when *save_plot*
    is True — renders the result to "graph_visualization.png" next to the
    input file.

    Args:
        json_path: Path to the aggregated JSON file.
        save_plot: When True, draw and save the matplotlib figure.

    Returns:
        Tuple (G, scaled_pos): the graph and the scaled position map.
    """

    # Load the aggregated data
    with open(json_path, 'r') as f:
        data = json.load(f)

    # Diagnostic dump of the top-level structure
    print("\nData Structure:")
    print(f"Keys in data: {data.keys()}")
    for key in data.keys():
        if isinstance(data[key], list):
            print(f"Number of {key}: {len(data[key])}")
            if data[key]:
                print(f"Sample {key}:", data[key][0])

    # Create a new graph
    G = nx.Graph()
    pos = {}

    # Track unique junctions by coordinates to avoid duplicates
    # (maps "x_y" string -> node id)
    junction_map = {}

    # Process lines and create unique junctions
    print("\nProcessing Lines and Junctions:")
    for line in data.get('lines', []):
        try:
            # Get line properties from edge data for accurate coordinates
            # (linear scan matching edge id to line id)
            edge_data = None
            for edge in data.get('edges', []):
                if edge['id'] == line['id']:
                    edge_data = edge
                    break

            # Get coordinates and create unique ID for start/end points
            if edge_data and 'connection_points' in edge_data['properties']:
                conn_points = edge_data['properties']['connection_points']
                start_x = int(conn_points['start']['x'])
                start_y = int(conn_points['start']['y'])
                end_x = int(conn_points['end']['x'])
                end_y = int(conn_points['end']['y'])
                start_id = str(uuid.uuid4())
                end_id = str(uuid.uuid4())
            else:
                # Fallback to line points
                start_x = int(line['start_point']['x'])
                start_y = int(line['start_point']['y'])
                end_x = int(line['end_point']['x'])
                end_y = int(line['end_point']['y'])
                start_id = line['start_point'].get('id', str(uuid.uuid4()))
                end_id = line['end_point'].get('id', str(uuid.uuid4()))

            # Skip invalid coordinates (hard-coded 10000-px sanity bound)
            if not (0 <= start_x <= 10000 and 0 <= start_y <= 10000 and
                    0 <= end_x <= 10000 and 0 <= end_y <= 10000):
                print(f"Skipping line with invalid coordinates: ({start_x}, {start_y}) -> ({end_x}, {end_y})")
                continue

            # Create or get junction nodes keyed by integer coordinates
            start_key = f"{start_x}_{start_y}"
            end_key = f"{end_x}_{end_y}"

            if start_key not in junction_map:
                junction_map[start_key] = start_id
                G.add_node(start_id,
                           type='junction',
                           junction_type=line['start_point'].get('type', 'unknown'),
                           coords={'x': start_x, 'y': start_y})
                pos[start_id] = (start_x, start_y)

            if end_key not in junction_map:
                junction_map[end_key] = end_id
                G.add_node(end_id,
                           type='junction',
                           junction_type=line['end_point'].get('type', 'unknown'),
                           coords={'x': end_x, 'y': end_y})
                pos[end_id] = (end_x, end_y)

            # Add line as edge with style properties
            G.add_edge(junction_map[start_key], junction_map[end_key],
                       type='line',
                       style=line.get('style', {}))

        except Exception as e:
            # Per-line best-effort: a malformed record must not abort the run
            print(f"Error processing line: {str(e)}")
            continue

    # Add symbols and texts after lines. Position preference:
    # explicit 'coords' > 'center' > bbox center.
    print("\nProcessing Symbols and Texts:")
    for node in data.get('nodes', []):
        if node['type'] in ['symbol', 'text']:
            node_id = node['id']
            coords = node.get('coords', {})
            if coords:
                x, y = coords['x'], coords['y']
            elif 'center' in node:
                x, y = node['center']['x'], node['center']['y']
            else:
                bbox = node['bbox']
                x = (bbox['xmin'] + bbox['xmax']) / 2
                y = (bbox['ymin'] + bbox['ymax']) / 2

            G.add_node(node_id, **node)
            pos[node_id] = (x, y)
            print(f"Added {node['type']} at ({x}, {y})")

    # Add default node if graph is empty (keeps min/max scaling valid below)
    if not pos:
        default_id = str(uuid.uuid4())
        G.add_node(default_id, type='junction', coords={'x': 0, 'y': 0})
        pos[default_id] = (0, 0)

    # Scale positions to fit in [0, 1] range
    x_vals = [p[0] for p in pos.values()]
    y_vals = [p[1] for p in pos.values()]
    x_min, x_max = min(x_vals), max(x_vals)
    y_min, y_max = min(y_vals), max(y_vals)

    scaled_pos = {}
    for node, (x, y) in pos.items():
        # Degenerate (single-coordinate) axes collapse to 0.5
        scaled_x = (x - x_min) / (x_max - x_min) if x_max > x_min else 0.5
        scaled_y = 1 - ((y - y_min) / (y_max - y_min) if y_max > y_min else 0.5)  # Flip Y coordinates
        scaled_pos[node] = (scaled_x, scaled_y)

    # Visualization attributes (parallel to G.nodes() iteration order)
    node_colors = []
    node_sizes = []
    labels = {}

    for node in G.nodes():
        node_data = G.nodes[node]
        if node_data['type'] == 'symbol':
            node_colors.append('lightblue')
            node_sizes.append(1000)
            labels[node] = f"S:{node_data.get('properties', {}).get('class', 'unknown')}"
        elif node_data['type'] == 'text':
            node_colors.append('lightgreen')
            node_sizes.append(800)
            content = node_data.get('content', '')
            labels[node] = f"T:{content[:10]}..." if len(content) > 10 else f"T:{content}"
        else:  # junction
            node_colors.append('#ff0000')  # Pure red
            node_sizes.append(5)  # Even smaller junction nodes
            labels[node] = ""  # No labels for junctions

    # Update visualization
    if save_plot:
        plt.figure(figsize=(20, 20))

        # Draw edges with styles (one draw call per line edge)
        edge_styles = []
        # NOTE(review): the loop variable `data` below shadows the JSON dict
        # loaded above; harmless today only because the dict is not used
        # after this point.
        for (u, v, data) in G.edges(data=True):
            if data.get('type') == 'line':
                style = data.get('style', {})
                color = style.get('color', '#000000')
                width = float(style.get('stroke_width', 0.5))
                alpha = 0.7
                line_style = '--' if style.get('connection_type') == 'dashed' else '-'

                nx.draw_networkx_edges(G, scaled_pos,
                                       edgelist=[(u, v)],
                                       edge_color=color,
                                       width=width,
                                       alpha=alpha,
                                       style=line_style)

                # Track unique styles for legend
                edge_style = (line_style, color, width)
                if edge_style not in edge_styles:
                    edge_styles.append(edge_style)

        # Draw nodes with smaller junctions
        nx.draw_networkx_nodes(G, scaled_pos,
                               node_color=node_colors,
                               node_size=[3 if size == 5 else size for size in node_sizes],  # Even smaller junctions
                               alpha=1.0)

        # Add labels only for symbols and texts (junction labels are "")
        labels = {k: v for k, v in labels.items() if v}
        nx.draw_networkx_labels(G, scaled_pos, labels,
                                font_size=8,
                                font_weight='bold')

        # Create comprehensive legend
        legend_elements = []

        # Node types
        legend_elements.extend([
            plt.scatter([0], [0], c='lightblue', s=200, label='Symbol'),
            plt.scatter([0], [0], c='lightgreen', s=200, label='Text'),
            plt.scatter([0], [0], c='red', s=20, label='Junction')
        ])

        # Line styles (one legend entry per unique style seen above)
        for style, color, width in edge_styles:
            legend_elements.append(
                plt.Line2D([0], [0], color=color, linestyle=style,
                           linewidth=width, label=f'Line ({style})')
            )

        # Add legend with two columns
        plt.legend(handles=legend_elements,
                   loc='center left',
                   bbox_to_anchor=(1, 0.5),
                   ncol=1,
                   fontsize=12,
                   title="Graph Elements")

        plt.title("P&ID Knowledge Graph Visualization", pad=20, fontsize=16)
        plt.axis('on')
        plt.grid(True)

        # Save with extra space for legend
        output_path = os.path.join(os.path.dirname(json_path), "graph_visualization.png")
        plt.savefig(output_path, bbox_inches='tight', dpi=300,
                    facecolor='white', edgecolor='none')
        plt.close()

    return G, scaled_pos
|
| 226 |
+
|
| 227 |
+
if __name__ == "__main__":
    # Render the visualization for a sample aggregated result, if present.
    json_path = "results/001_page_1_text_aggregated.json"

    if not os.path.exists(json_path):
        print(f"Error: Could not find {json_path}")
    else:
        G, scaled_pos = create_graph_visualization(json_path)
        plt.show()
|
line_detection_ai.py
ADDED
|
@@ -0,0 +1,410 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from base import BaseDetector, BaseDetectionPipeline
|
| 2 |
+
from utils import *
|
| 3 |
+
from config import (
|
| 4 |
+
ImageConfig,
|
| 5 |
+
SymbolConfig,
|
| 6 |
+
TagConfig,
|
| 7 |
+
LineConfig,
|
| 8 |
+
PointConfig,
|
| 9 |
+
JunctionConfig
|
| 10 |
+
)
|
| 11 |
+
from detectors import (
|
| 12 |
+
LineDetector,
|
| 13 |
+
PointDetector,
|
| 14 |
+
JunctionDetector,
|
| 15 |
+
SymbolDetector,
|
| 16 |
+
TagDetector
|
| 17 |
+
)
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
from storage import StorageFactory
|
| 20 |
+
from common import DetectionResult
|
| 21 |
+
from detection_schema import DetectionContext, JunctionType
|
| 22 |
+
from typing import List, Tuple, Optional, Dict
|
| 23 |
+
import torch
|
| 24 |
+
import numpy as np
|
| 25 |
+
import cv2
|
| 26 |
+
import os
|
| 27 |
+
import logging
|
| 28 |
+
|
| 29 |
+
logger = logging.getLogger(__name__)
|
| 30 |
+
|
| 31 |
+
class DiagramDetectionPipeline:
|
| 32 |
+
"""
|
| 33 |
+
Pipeline that runs multiple detectors (line, point, junction, etc.) on an image,
|
| 34 |
+
and keeps a shared DetectionContext in memory.
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
def __init__(self,
|
| 38 |
+
tag_detector: Optional[BaseDetector],
|
| 39 |
+
symbol_detector: Optional[BaseDetector],
|
| 40 |
+
line_detector: Optional[BaseDetector],
|
| 41 |
+
point_detector: Optional[BaseDetector],
|
| 42 |
+
junction_detector: Optional[BaseDetector],
|
| 43 |
+
storage: StorageInterface,
|
| 44 |
+
debug_handler: Optional[DebugHandler] = None,
|
| 45 |
+
transformer: Optional[CoordinateTransformer] = None):
|
| 46 |
+
"""
|
| 47 |
+
You can pass None for detectors you don't need.
|
| 48 |
+
"""
|
| 49 |
+
# super().__init__(storage=storage, debug_handler=debug_handler)
|
| 50 |
+
self.storage = storage
|
| 51 |
+
self.debug_handler = debug_handler
|
| 52 |
+
self.tag_detector = tag_detector
|
| 53 |
+
self.symbol_detector = symbol_detector
|
| 54 |
+
self.line_detector = line_detector
|
| 55 |
+
self.point_detector = point_detector
|
| 56 |
+
self.junction_detector = junction_detector
|
| 57 |
+
self.transformer = transformer or CoordinateTransformer()
|
| 58 |
+
|
| 59 |
+
def _load_image(self, image_path: str) -> np.ndarray:
|
| 60 |
+
"""Load image with validation."""
|
| 61 |
+
image_data = self.storage.load_file(image_path)
|
| 62 |
+
image = cv2.imdecode(np.frombuffer(image_data, np.uint8), cv2.IMREAD_COLOR)
|
| 63 |
+
if image is None:
|
| 64 |
+
raise ValueError(f"Failed to load image from {image_path}")
|
| 65 |
+
return image
|
| 66 |
+
|
| 67 |
+
def _crop_to_roi(self, image: np.ndarray, roi: Optional[list]) -> Tuple[np.ndarray, Tuple[int, int]]:
|
| 68 |
+
"""Crop to ROI if provided, else return full image."""
|
| 69 |
+
if roi is not None and len(roi) == 4:
|
| 70 |
+
x_min, y_min, x_max, y_max = roi
|
| 71 |
+
return image[y_min:y_max, x_min:x_max], (x_min, y_min)
|
| 72 |
+
return image, (0, 0)
|
| 73 |
+
|
| 74 |
+
def _remove_symbol_tag_bboxes(self, image: np.ndarray, context: DetectionContext) -> np.ndarray:
|
| 75 |
+
"""Fill symbol & tag bounding boxes with white to avoid line detection picking them up."""
|
| 76 |
+
masked = image.copy()
|
| 77 |
+
for sym in context.symbols.values():
|
| 78 |
+
cv2.rectangle(masked,
|
| 79 |
+
(sym.bbox.xmin, sym.bbox.ymin),
|
| 80 |
+
(sym.bbox.xmax, sym.bbox.ymax),
|
| 81 |
+
(255, 255, 255), # White
|
| 82 |
+
thickness=-1)
|
| 83 |
+
|
| 84 |
+
for tg in context.tags.values():
|
| 85 |
+
cv2.rectangle(masked,
|
| 86 |
+
(tg.bbox.xmin, tg.bbox.ymin),
|
| 87 |
+
(tg.bbox.xmax, tg.bbox.ymax),
|
| 88 |
+
(255, 255, 255),
|
| 89 |
+
thickness=-1)
|
| 90 |
+
return masked
|
| 91 |
+
|
| 92 |
+
def run(
|
| 93 |
+
self,
|
| 94 |
+
image_path: str,
|
| 95 |
+
output_dir: str,
|
| 96 |
+
config
|
| 97 |
+
) -> DetectionResult:
|
| 98 |
+
"""
|
| 99 |
+
Main pipeline steps (in local coords):
|
| 100 |
+
1) Load + crop image
|
| 101 |
+
2) Detect symbols & tags
|
| 102 |
+
3) Make a copy for final debug images
|
| 103 |
+
4) White out symbol/tag bounding boxes
|
| 104 |
+
5) Detect lines, points, junctions
|
| 105 |
+
6) Save final JSON
|
| 106 |
+
7) Generate debug images with various combinations
|
| 107 |
+
"""
|
| 108 |
+
try:
|
| 109 |
+
with self.debug_handler.track_performance("total_processing"):
|
| 110 |
+
# 1) Load & crop
|
| 111 |
+
image = self._load_image(image_path)
|
| 112 |
+
cropped_image, roi_offset = self._crop_to_roi(image, config.roi)
|
| 113 |
+
|
| 114 |
+
# 2) Create fresh context
|
| 115 |
+
context = DetectionContext()
|
| 116 |
+
|
| 117 |
+
# 3) Detect symbols
|
| 118 |
+
with self.debug_handler.track_performance("symbol_detection"):
|
| 119 |
+
self.symbol_detector.detect(
|
| 120 |
+
cropped_image,
|
| 121 |
+
context=context,
|
| 122 |
+
roi_offset=roi_offset
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
# 4) Detect tags
|
| 126 |
+
with self.debug_handler.track_performance("tag_detection"):
|
| 127 |
+
self.tag_detector.detect(
|
| 128 |
+
cropped_image,
|
| 129 |
+
context=context,
|
| 130 |
+
roi_offset=roi_offset
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
# Make a copy of the cropped image for final debug combos
|
| 134 |
+
debug_cropped = cropped_image.copy()
|
| 135 |
+
|
| 136 |
+
# 5) White-out symbol/tag bboxes in the original cropped image
|
| 137 |
+
cropped_image = self._remove_symbol_tag_bboxes(cropped_image, context)
|
| 138 |
+
|
| 139 |
+
# 6) Detect lines
|
| 140 |
+
with self.debug_handler.track_performance("line_detection"):
|
| 141 |
+
self.line_detector.detect(cropped_image, context=context)
|
| 142 |
+
|
| 143 |
+
# 7) Detect points
|
| 144 |
+
if self.point_detector:
|
| 145 |
+
with self.debug_handler.track_performance("point_detection"):
|
| 146 |
+
self.point_detector.detect(cropped_image, context=context)
|
| 147 |
+
|
| 148 |
+
# 8) Detect junctions
|
| 149 |
+
if self.junction_detector:
|
| 150 |
+
with self.debug_handler.track_performance("junction_detection"):
|
| 151 |
+
self.junction_detector.detect(cropped_image, context=context)
|
| 152 |
+
|
| 153 |
+
# 9) Save final JSON & any final images
|
| 154 |
+
output_paths = self._persist_results(output_dir, image_path, context)
|
| 155 |
+
|
| 156 |
+
# 10) Save debug images in local coords using debug_cropped
|
| 157 |
+
self._save_all_combinations(debug_cropped, context, output_dir, image_path)
|
| 158 |
+
|
| 159 |
+
return DetectionResult(
|
| 160 |
+
success=True,
|
| 161 |
+
processing_time=self.debug_handler.metrics.get('total_processing', 0),
|
| 162 |
+
json_path=output_paths.get('json_path'),
|
| 163 |
+
image_path=output_paths.get('image_path') # Now returning the annotated image path
|
| 164 |
+
)
|
| 165 |
+
|
| 166 |
+
except Exception as e:
|
| 167 |
+
logger.error(f"Processing failed: {str(e)}")
|
| 168 |
+
return DetectionResult(
|
| 169 |
+
success=False,
|
| 170 |
+
error=str(e)
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
# ------------------------------------------------
|
| 174 |
+
# HELPER FUNCTIONS
|
| 175 |
+
# ------------------------------------------------
|
| 176 |
+
def _persist_results(self, output_dir: str, image_path: str, context: DetectionContext) -> dict:
|
| 177 |
+
"""Saves final JSON and debug images to disk."""
|
| 178 |
+
self.storage.create_directory(output_dir)
|
| 179 |
+
base_name = Path(image_path).stem
|
| 180 |
+
|
| 181 |
+
# Save JSON
|
| 182 |
+
json_path = Path(output_dir) / f"{base_name}_detected_lines.json"
|
| 183 |
+
context_json_str = context.to_json(indent=2)
|
| 184 |
+
self.storage.save_file(str(json_path), context_json_str.encode('utf-8'))
|
| 185 |
+
|
| 186 |
+
# Save annotated image for pipeline display
|
| 187 |
+
annotated_image = self._draw_objects(
|
| 188 |
+
self._load_image(image_path),
|
| 189 |
+
context,
|
| 190 |
+
draw_lines=True,
|
| 191 |
+
draw_points=True,
|
| 192 |
+
draw_symbols=True,
|
| 193 |
+
draw_junctions=True,
|
| 194 |
+
draw_tags=True
|
| 195 |
+
)
|
| 196 |
+
image_path = Path(output_dir) / f"{base_name}_annotated.jpg"
|
| 197 |
+
_, encoded = cv2.imencode('.jpg', annotated_image)
|
| 198 |
+
self.storage.save_file(str(image_path), encoded.tobytes())
|
| 199 |
+
|
| 200 |
+
return {
|
| 201 |
+
"json_path": str(json_path),
|
| 202 |
+
"image_path": str(image_path)
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
+
def _save_all_combinations(self, local_image: np.ndarray, context: DetectionContext,
|
| 206 |
+
output_dir: str, image_path: str) -> None:
|
| 207 |
+
"""Produce debug images with different combinations."""
|
| 208 |
+
base_name = Path(image_path).stem
|
| 209 |
+
base_name = base_name.split("_")[0]
|
| 210 |
+
combos = [
|
| 211 |
+
("text_detected_symbols", dict(draw_symbols=True, draw_tags=False, draw_lines=False, draw_points=False, draw_junctions=False)),
|
| 212 |
+
("text_detected_texts", dict(draw_symbols=False, draw_tags=True, draw_lines=False, draw_points=False, draw_junctions=False)),
|
| 213 |
+
("text_detected_lines", dict(draw_symbols=False, draw_tags=False, draw_lines=True, draw_points=False, draw_junctions=False)),
|
| 214 |
+
]
|
| 215 |
+
self.storage.create_directory(output_dir)
|
| 216 |
+
|
| 217 |
+
for combo_name, flags in combos:
|
| 218 |
+
annotated = self._draw_objects(local_image, context, **flags)
|
| 219 |
+
save_name = f"{base_name}_{combo_name}.jpg"
|
| 220 |
+
save_path = Path(output_dir) / save_name
|
| 221 |
+
_, encoded = cv2.imencode('.jpg', annotated)
|
| 222 |
+
self.storage.save_file(str(save_path), encoded.tobytes())
|
| 223 |
+
logger.info(f"Saved debug image: {save_path}")
|
| 224 |
+
|
| 225 |
+
def _draw_objects(self, base_image: np.ndarray, context: DetectionContext,
|
| 226 |
+
draw_lines: bool = True, draw_points: bool = True,
|
| 227 |
+
draw_symbols: bool = True, draw_junctions: bool = True,
|
| 228 |
+
draw_tags: bool = True) -> np.ndarray:
|
| 229 |
+
"""Draw detection results on a copy of base_image in local coords."""
|
| 230 |
+
annotated = base_image.copy()
|
| 231 |
+
|
| 232 |
+
# Lines
|
| 233 |
+
if draw_lines:
|
| 234 |
+
for ln in context.lines.values():
|
| 235 |
+
cv2.line(annotated,
|
| 236 |
+
(ln.start.coords.x, ln.start.coords.y),
|
| 237 |
+
(ln.end.coords.x, ln.end.coords.y),
|
| 238 |
+
(0, 255, 0), # green
|
| 239 |
+
2)
|
| 240 |
+
|
| 241 |
+
# Points
|
| 242 |
+
if draw_points:
|
| 243 |
+
for pt in context.points.values():
|
| 244 |
+
cv2.circle(annotated,
|
| 245 |
+
(pt.coords.x, pt.coords.y),
|
| 246 |
+
3,
|
| 247 |
+
(0, 0, 255), # red
|
| 248 |
+
-1)
|
| 249 |
+
|
| 250 |
+
# Symbols
|
| 251 |
+
if draw_symbols:
|
| 252 |
+
for sym in context.symbols.values():
|
| 253 |
+
cv2.rectangle(annotated,
|
| 254 |
+
(sym.bbox.xmin, sym.bbox.ymin),
|
| 255 |
+
(sym.bbox.xmax, sym.bbox.ymax),
|
| 256 |
+
(255, 255, 0), # cyan
|
| 257 |
+
2)
|
| 258 |
+
cv2.circle(annotated,
|
| 259 |
+
(sym.center.x, sym.center.y),
|
| 260 |
+
4,
|
| 261 |
+
(255, 0, 255), # magenta
|
| 262 |
+
-1)
|
| 263 |
+
|
| 264 |
+
# Junctions
|
| 265 |
+
if draw_junctions:
|
| 266 |
+
for jn in context.junctions.values():
|
| 267 |
+
if jn.junction_type == JunctionType.T:
|
| 268 |
+
color = (0, 165, 255) # orange
|
| 269 |
+
elif jn.junction_type == JunctionType.L:
|
| 270 |
+
color = (255, 0, 255) # magenta
|
| 271 |
+
else: # END
|
| 272 |
+
color = (0, 0, 255) # red
|
| 273 |
+
cv2.circle(annotated,
|
| 274 |
+
(jn.center.x, jn.center.y),
|
| 275 |
+
5,
|
| 276 |
+
color,
|
| 277 |
+
-1)
|
| 278 |
+
|
| 279 |
+
# Tags
|
| 280 |
+
if draw_tags:
|
| 281 |
+
for tg in context.tags.values():
|
| 282 |
+
cv2.rectangle(annotated,
|
| 283 |
+
(tg.bbox.xmin, tg.bbox.ymin),
|
| 284 |
+
(tg.bbox.xmax, tg.bbox.ymax),
|
| 285 |
+
(128, 0, 128), # purple
|
| 286 |
+
2)
|
| 287 |
+
cv2.putText(annotated,
|
| 288 |
+
tg.text,
|
| 289 |
+
(tg.bbox.xmin, tg.bbox.ymin - 5),
|
| 290 |
+
cv2.FONT_HERSHEY_SIMPLEX,
|
| 291 |
+
0.5,
|
| 292 |
+
(128, 0, 128),
|
| 293 |
+
1)
|
| 294 |
+
|
| 295 |
+
return annotated
|
| 296 |
+
|
| 297 |
+
    def detect_lines(self, image_path: str, output_dir: str, config: Optional[Dict] = None) -> Dict:
        """Legacy interface for line detection"""
        # NOTE(review): `config` is accepted for backward compatibility but is
        # currently ignored — ImageConfig() is always used below.  Confirm
        # whether any caller still passes a config that should be honoured.
        storage = StorageFactory.get_storage()
        debug_handler = DebugHandler(enabled=True, storage=storage)

        # Only a line detector is wired in; the pipeline must tolerate the
        # None detectors passed below.
        line_detector = LineDetector(
            config=LineConfig(),
            model_path="models/deeplsd_md.tar",
            device=torch.device("cpu"),
            debug_handler=debug_handler
        )

        pipeline = DiagramDetectionPipeline(
            tag_detector=None,
            symbol_detector=None,
            line_detector=line_detector,
            point_detector=None,
            junction_detector=None,
            storage=storage,
            debug_handler=debug_handler
        )

        # NOTE(review): despite the annotated Dict return type, this actually
        # returns a DetectionResult dataclass instance.
        result = pipeline.run(image_path, output_dir, ImageConfig())
        return result
| 321 |
+
|
| 322 |
+
def _validate_and_normalize_coordinates(self, points):
|
| 323 |
+
"""Validate and normalize coordinates to image space"""
|
| 324 |
+
valid_points = []
|
| 325 |
+
for point in points:
|
| 326 |
+
x, y = point['x'], point['y']
|
| 327 |
+
# Validate coordinates are within image bounds
|
| 328 |
+
if 0 <= x <= self.image_width and 0 <= y <= self.image_height:
|
| 329 |
+
# Normalize coordinates if needed
|
| 330 |
+
valid_points.append({
|
| 331 |
+
'x': int(x),
|
| 332 |
+
'y': int(y),
|
| 333 |
+
'type': point.get('type', 'unknown'),
|
| 334 |
+
'confidence': point.get('confidence', 1.0)
|
| 335 |
+
})
|
| 336 |
+
return valid_points
|
| 337 |
+
|
| 338 |
+
if __name__ == "__main__":
    # Smoke-test entry point: wires every detector into the pipeline and runs
    # it end-to-end on one sample drawing.

    # 1) Initialize components
    storage = StorageFactory.get_storage()
    debug_handler = DebugHandler(enabled=True, storage=storage)

    # 2) Build detectors
    # DeepLSD model configuration forwarded to the line detector.
    conf = {
        "detect_lines": True,
        "line_detection_params": {
            "merge": True,
            "filtering": True,
            "grad_thresh": 3,
            "grad_nfa": True
        }
    }

    # 3) Configure — default config objects for every detector
    line_config = LineConfig()
    point_config = PointConfig()
    junction_config = JunctionConfig()
    symbol_config = SymbolConfig()
    tag_config = TagConfig()

    # ========================== Detectors ========================== #
    symbol_detector = SymbolDetector(
        config=symbol_config,
        debug_handler=debug_handler
    )

    tag_detector = TagDetector(
        config=tag_config,
        debug_handler=debug_handler
    )

    line_detector = LineDetector(
        config=line_config,
        model_path="models/deeplsd_md.tar",
        model_config=conf,
        device=torch.device("cpu"),  # or "cuda" if available
        debug_handler=debug_handler
    )

    point_detector = PointDetector(
        config=point_config,
        debug_handler=debug_handler)

    junction_detector = JunctionDetector(
        config=junction_config,
        debug_handler=debug_handler
    )

    # 4) Create pipeline
    pipeline = DiagramDetectionPipeline(
        tag_detector=tag_detector,
        symbol_detector=symbol_detector,
        line_detector=line_detector,
        point_detector=point_detector,
        junction_detector=junction_detector,
        storage=storage,
        debug_handler=debug_handler
    )

    # 5) Run pipeline on the bundled sample image
    result = pipeline.run(
        image_path="samples/images/0.jpg",
        output_dir="results/",
        config=ImageConfig()
    )

    if result.success:
        logger.info(f"Pipeline succeeded! See JSON at {result.json_path}")
    else:
        logger.error(f"Pipeline failed: {result.error}")
packages.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
libgl1-mesa-glx
|
| 2 |
+
libglib2.0-0
|
| 3 |
+
tesseract-ocr
|
pdf_processor.py
ADDED
|
@@ -0,0 +1,288 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import fitz # PyMuPDF
|
| 2 |
+
import os
|
| 3 |
+
import logging
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
import numpy as np
|
| 6 |
+
from PIL import Image
|
| 7 |
+
import io
|
| 8 |
+
import cv2 # Add this import
|
| 9 |
+
from storage import StorageInterface
|
| 10 |
+
from typing import List, Dict, Tuple, Any
|
| 11 |
+
import json
|
| 12 |
+
from text_detection_combined import process_drawing
|
| 13 |
+
|
| 14 |
+
# Configure logging
|
| 15 |
+
logging.basicConfig(level=logging.INFO)
|
| 16 |
+
logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
class DocumentProcessor:
    """Converts P&ID documents (PDF/PNG/JPG) into detection-ready images.

    For PDFs, each page is rasterised at high DPI and saved in three variants
    optimised for text, symbol and line detection respectively.
    """

    def __init__(self, storage: StorageInterface):
        self.storage = storage
        self.logger = logging.getLogger(__name__)

        # Configure optimal processing parameters
        self.target_dpi = 600      # rasterisation DPI for PDF pages
        self.min_dimension = 2000  # minimum width/height after resizing
        self.max_dimension = 8000  # maximum width/height after resizing
        self.quality = 95          # JPEG quality for saving (JPEG output only)

    def process_document(self, file_path: str, output_dir: str) -> list:
        """Process a document (PDF/PNG/JPG) and return paths to processed pages.

        Raises:
            ValueError: for unsupported file extensions.
        """
        file_ext = Path(file_path).suffix.lower()

        if file_ext == '.pdf':
            return self._process_pdf(file_path, output_dir)
        elif file_ext in ('.png', '.jpg', '.jpeg'):
            return self._process_image(file_path, output_dir)
        else:
            raise ValueError(f"Unsupported file format: {file_ext}")

    def _process_pdf(self, pdf_path: str, output_dir: str) -> list:
        """Rasterise every PDF page and save text/symbol/line-optimised copies.

        Also writes a `<stem>_processing_results.json` sidecar describing each
        page's dimensions and output paths.
        """
        processed_pages = []
        processing_results = {}

        try:
            os.makedirs(output_dir, exist_ok=True)

            # Clean up stale outputs from a previous run of this document
            base_name = Path(pdf_path).stem
            for file in os.listdir(output_dir):
                if file.startswith(base_name) and file != os.path.basename(pdf_path):
                    stale_path = os.path.join(output_dir, file)
                    try:
                        if os.path.isfile(stale_path):
                            os.unlink(stale_path)
                    except Exception as e:
                        self.logger.error(f"Error deleting file {stale_path}: {e}")

            # Read PDF bytes directly (the PDF already lives in the output dir)
            with open(pdf_path, 'rb') as f:
                pdf_data = f.read()

            doc = fitz.open(stream=pdf_data, filetype="pdf")
            try:
                for page_num in range(len(doc)):
                    page = doc[page_num]

                    # Zoom factor relative to PDF native 72 DPI
                    zoom = self.target_dpi / 72
                    matrix = fitz.Matrix(zoom, zoom)

                    # Get high-resolution image
                    pix = page.get_pixmap(matrix=matrix)
                    img_data = pix.tobytes()

                    # Convert to numpy array and decode
                    nparr = np.frombuffer(img_data, np.uint8)
                    img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
                    if img is None:
                        # BUGFIX: a decode failure previously crashed later
                        # with an opaque AttributeError; fail fast instead.
                        raise ValueError(
                            f"Failed to decode page {page_num + 1} of {pdf_path}")

                    base_filename = f"{Path(pdf_path).stem}_page_{page_num + 1}"

                    # Process and save different versions
                    optimized_versions = {
                        'text': self._optimize_for_text(img.copy()),
                        'symbol': self._optimize_for_symbols(img.copy()),
                        'line': self._optimize_for_lines(img.copy())
                    }

                    paths = {
                        version: os.path.join(output_dir, f"{base_filename}_{version}.png")
                        for version in ('text', 'symbol', 'line')
                    }

                    for version_type, optimized_img in optimized_versions.items():
                        self._save_image(optimized_img, paths[version_type])
                        processed_pages.append(paths[version_type])

                    processing_results[str(page_num + 1)] = {
                        "page_number": page_num + 1,
                        "dimensions": {
                            "width": img.shape[1],
                            "height": img.shape[0]
                        },
                        "paths": paths,
                        "dpi": self.target_dpi,
                        "zoom_factor": zoom
                    }
            finally:
                # BUGFIX: the document handle was never released
                doc.close()

            # Save processing results JSON
            results_json_path = os.path.join(
                output_dir,
                f"{Path(pdf_path).stem}_processing_results.json"
            )
            with open(results_json_path, 'w') as f:
                json.dump(processing_results, f, indent=4)

            return processed_pages

        except Exception as e:
            self.logger.error(f"Error processing PDF: {str(e)}")
            raise

    def _process_image(self, image_path: str, output_dir: str) -> list:
        """Process a single raster image file and save the optimised result."""
        try:
            image_data = self.storage.load_file(image_path)
            nparr = np.frombuffer(image_data, np.uint8)
            img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
            if img is None:
                # BUGFIX: guard against undecodable input instead of crashing
                # inside cv2.cvtColor with an opaque error.
                raise ValueError(f"Failed to decode image {image_path}")

            processed_img = self._optimize_image(img)

            output_path = os.path.join(
                output_dir,
                f"{Path(image_path).stem}_text.png"
            )
            self._save_image(processed_img, output_path)

            return [output_path]

        except Exception as e:
            self.logger.error(f"Error processing image: {str(e)}")
            raise

    def _optimize_image(self, img: np.ndarray) -> np.ndarray:
        """General-purpose optimisation: contrast, denoise, binarise, resize."""
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

        # Enhance contrast
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        enhanced = clahe.apply(gray)

        # Denoise
        denoised = cv2.fastNlMeansDenoising(enhanced)

        # Binarize (Otsu picks the threshold automatically)
        _, binary = cv2.threshold(denoised, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

        # Resize while maintaining aspect ratio, bounded by min/max dimensions
        height, width = binary.shape
        scale = min(self.max_dimension / max(width, height),
                    max(self.min_dimension / min(width, height), 1.0))

        if scale != 1.0:
            new_width = int(width * scale)
            new_height = int(height * scale)
            resized = cv2.resize(binary, (new_width, new_height),
                                 interpolation=cv2.INTER_LANCZOS4)
        else:
            resized = binary

        # Convert back to BGR for downstream compatibility
        return cv2.cvtColor(resized, cv2.COLOR_GRAY2BGR)

    def _optimize_for_text(self, img: np.ndarray) -> np.ndarray:
        """Optimise image for OCR/text detection (CLAHE + adaptive threshold)."""
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        enhanced = clahe.apply(gray)

        denoised = cv2.fastNlMeansDenoising(enhanced)

        # Adaptive thresholding separates text from varying backgrounds
        binary = cv2.adaptiveThreshold(denoised, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                       cv2.THRESH_BINARY, 11, 2)

        return cv2.cvtColor(binary, cv2.COLOR_GRAY2BGR)

    def _optimize_for_symbols(self, img: np.ndarray) -> np.ndarray:
        """Optimise image for symbol detection (edge-preserving denoise + sharpen)."""
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

        # Bilateral filter preserves edges while reducing noise
        bilateral = cv2.bilateralFilter(gray, 9, 75, 75)

        clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
        enhanced = clahe.apply(bilateral)

        # Sharpening kernel (centre 9, neighbours -1)
        kernel = np.array([[-1, -1, -1],
                           [-1, 9, -1],
                           [-1, -1, -1]])
        sharpened = cv2.filter2D(enhanced, -1, kernel)

        return cv2.cvtColor(sharpened, cv2.COLOR_GRAY2BGR)

    def _optimize_for_lines(self, img: np.ndarray) -> np.ndarray:
        """Optimise image for line detection (Canny edges + dilation)."""
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

        # Light blur to reduce noise while preserving edges
        denoised = cv2.GaussianBlur(gray, (3, 3), 0)

        edges = cv2.Canny(denoised, 50, 150)

        # Dilate edges to connect broken line segments
        kernel = np.ones((2, 2), np.uint8)
        dilated = cv2.dilate(edges, kernel, iterations=1)

        return cv2.cvtColor(dilated, cv2.COLOR_GRAY2BGR)

    def _save_image(self, img: np.ndarray, output_path: str):
        """Encode *img* as lossless PNG and persist it via the storage backend."""
        _, buffer = cv2.imencode('.png', img, [
            cv2.IMWRITE_PNG_COMPRESSION, 0  # no compression: max quality/speed
        ])
        self.storage.save_file(output_path, buffer.tobytes())
| 250 |
+
|
| 251 |
+
if __name__ == "__main__":
    from storage import StorageFactory
    import shutil

    # Wire up storage and the processor under test.
    storage = StorageFactory.get_storage()
    processor = DocumentProcessor(storage)

    # Inputs for the smoke run; outputs land in results/ next to the PDF.
    pdf_path = "samples/001.pdf"
    output_dir = "results"

    try:
        os.makedirs(output_dir, exist_ok=True)

        results = processor.process_document(
            file_path=pdf_path,
            output_dir=output_dir
        )

        # Report what was produced and how large each output is.
        print("\nProcessing Results:")
        print(f"Output Directory: {os.path.abspath(output_dir)}")

        bytes_per_mb = 1024 * 1024
        for page_path in results:
            size_mb = os.path.getsize(page_path) / bytes_per_mb
            print(f"- {os.path.basename(page_path)} ({size_mb:.2f} MB)")

        total_mb = sum(
            os.path.getsize(os.path.join(output_dir, name))
            for name in os.listdir(output_dir)
        ) / bytes_per_mb
        print(f"\nTotal output size: {total_mb:.2f} MB")

    except Exception as e:
        logger.error(f"Error processing PDF: {str(e)}")
        raise
requirements.txt
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Core dependencies
|
| 2 |
+
gradio>=4.0.0
|
| 3 |
+
numpy>=1.24.0
|
| 4 |
+
Pillow>=8.0.0
|
| 5 |
+
opencv-python-headless>=4.8.0
|
| 6 |
+
PyMuPDF>=1.18.0 # for PDF processing
|
| 7 |
+
|
| 8 |
+
# OCR Engines
|
| 9 |
+
pytesseract>=0.3.8
|
| 10 |
+
easyocr>=1.7.1
|
| 11 |
+
python-doctr>=0.7.0 # For DocTR OCR
|
| 12 |
+
tensorflow>=2.8.0 # Required by DocTR
|
| 13 |
+
|
| 14 |
+
# Deep Learning
|
| 15 |
+
torch>=2.1.0
|
| 16 |
+
torchvision>=0.15.0
|
| 17 |
+
ultralytics>=8.0.0 # for YOLO models
|
| 18 |
+
deeplsd # Add this for line detection
|
| 19 |
+
omegaconf>=2.3.0 # Required by DeepLSD
|
| 20 |
+
|
| 21 |
+
# Graph Processing
|
| 22 |
+
networkx>=2.6.0
|
| 23 |
+
plotly>=5.3.0
|
| 24 |
+
|
| 25 |
+
# Utilities
|
| 26 |
+
tqdm>=4.66.0
|
| 27 |
+
python-dotenv>=0.19.0
|
| 28 |
+
# uuid is part of the Python 3 standard library; the PyPI "uuid" package is an obsolete Python 2 backport and must not be installed
|
| 29 |
+
shapely>=1.8.0 # for geometry operations
|
| 30 |
+
|
| 31 |
+
# Azure Storage
|
| 32 |
+
azure-storage-blob>=12.0.0
|
| 33 |
+
azure-core>=1.24.0
|
| 34 |
+
|
| 35 |
+
# AI/Chat
|
| 36 |
+
openai>=1.0.0 # For ChatGPT integration
|
| 37 |
+
loguru>=0.7.0
|
| 38 |
+
matplotlib>=3.4.0
|
| 39 |
+
|
| 40 |
+
# Added from the code block
|
| 41 |
+
requests>=2.31.0
|
setup.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from setuptools import setup, find_packages


def _read_requirements(path="requirements.txt"):
    """Parse a requirements file into a clean list of requirement specifiers.

    Skips blank lines and comment lines, and strips trailing inline comments.
    The previous inline parser passed empty strings and entries like
    "PyMuPDF>=1.18.0  # for PDF processing" straight to setuptools, which
    breaks installation.  It also never closed the file handle.
    """
    requirements = []
    with open(path) as fh:
        for raw in fh:
            # Drop everything after '#' so inline comments don't leak into
            # the specifier, then trim surrounding whitespace.
            spec = raw.split("#", 1)[0].strip()
            if spec:
                requirements.append(spec)
    return requirements


setup(
    name="pid-processor",
    version="0.1.0",
    packages=find_packages(),
    install_requires=_read_requirements(),
    author="Your Name",
    author_email="your.email@example.com",
    description="P&ID Processing with AI-Powered Graph Construction",
    long_description=open("README.md").read(),
    long_description_content_type="text/markdown",
    url="https://github.com/yourusername/your-repo-name",
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: MIT License",
        "Operating System :: OS Independent",
    ],
    python_requires=">=3.8",
)
storage.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import shutil
|
| 3 |
+
from azure.storage.blob import BlobServiceClient
|
| 4 |
+
from abc import ABC, abstractmethod
|
| 5 |
+
import json
|
| 6 |
+
|
| 7 |
+
class StorageInterface(ABC):
    """Abstract file/blob storage backend.

    Implementations provide byte-oriented file primitives plus a few
    directory operations; ``load_json`` is a shared convenience helper.
    """

    @abstractmethod
    def save_file(self, file_path: str, content: bytes) -> str:
        """Write *content* to *file_path* and return the stored path."""
        pass

    @abstractmethod
    def load_file(self, file_path: str) -> bytes:
        """Return the raw bytes stored at *file_path*."""
        pass

    @abstractmethod
    def list_files(self, directory: str) -> list[str]:
        """Return the names of the files (not directories) in *directory*."""
        pass

    @abstractmethod
    def file_exists(self, file_path: str) -> bool:
        """Return True if *file_path* exists."""
        pass

    @abstractmethod
    def delete_file(self, file_path: str) -> None:
        """Remove the file at *file_path*."""
        pass

    @abstractmethod
    def create_directory(self, directory: str) -> None:
        """Create *directory* (and parents) if it does not already exist."""
        pass

    @abstractmethod
    def delete_directory(self, directory: str) -> None:
        """Recursively remove *directory* and its contents."""
        pass

    @abstractmethod
    def upload(self, local_path: str, destination_path: str) -> None:
        """Copy a local file into the storage backend at *destination_path*."""
        pass

    @abstractmethod
    def append_file(self, file_path: str, content: bytes) -> None:
        """Append *content* to the file at *file_path*, creating it if needed."""
        pass

    @abstractmethod
    def get_modified_time(self, file_path: str) -> float:
        """Return the last-modification time of *file_path* as a POSIX timestamp."""
        pass

    @abstractmethod
    def directory_exists(self, directory: str) -> bool:
        """Return True if *directory* exists."""
        pass

    def load_json(self, file_path):
        """Load and parse JSON file."""
        # NOTE(review): reads the local filesystem directly, bypassing the
        # backend's load_file — remote-only backends will not see their data;
        # confirm this is intended.  Errors are swallowed and reported via
        # print, and callers must handle the None return.
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
                return data
        except Exception as e:
            print(f"Error loading JSON from {file_path}: {str(e)}")
            return None
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class LocalStorage(StorageInterface):
    """StorageInterface implementation backed by the local filesystem."""

    def save_file(self, file_path: str, content: bytes) -> str:
        """Write *content* to *file_path*, creating parent dirs as needed."""
        parent = os.path.dirname(file_path)
        if parent:  # BUGFIX: os.makedirs('') raises for bare filenames
            os.makedirs(parent, exist_ok=True)
        with open(file_path, 'wb') as f:
            f.write(content)
        return file_path

    def load_file(self, file_path: str) -> bytes:
        """Return the raw bytes of *file_path*."""
        with open(file_path, 'rb') as f:
            return f.read()

    def list_files(self, directory: str) -> list[str]:
        """Return names of regular files directly inside *directory*."""
        return [f for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f))]

    def file_exists(self, file_path: str) -> bool:
        """Return True if *file_path* exists (file or directory)."""
        return os.path.exists(file_path)

    def delete_file(self, file_path: str) -> None:
        """Remove the file at *file_path* (raises FileNotFoundError if missing)."""
        os.remove(file_path)

    def create_directory(self, directory: str) -> None:
        """Create *directory* and any missing parents; no-op if it exists."""
        os.makedirs(directory, exist_ok=True)

    def delete_directory(self, directory: str) -> None:
        """Recursively remove *directory*."""
        shutil.rmtree(directory)

    def upload(self, local_path: str, destination_path: str) -> None:
        """Copy a local file to *destination_path*, creating parent dirs."""
        parent = os.path.dirname(destination_path)
        if parent:  # BUGFIX: guard against destinations with no directory part
            os.makedirs(parent, exist_ok=True)
        shutil.copy(local_path, destination_path)

    def append_file(self, file_path: str, content: bytes) -> None:
        """Append *content* to *file_path*, creating it (and parents) if needed."""
        parent = os.path.dirname(file_path)
        if parent:  # BUGFIX: guard against bare filenames
            os.makedirs(parent, exist_ok=True)
        with open(file_path, 'ab') as f:
            f.write(content)

    def get_modified_time(self, file_path: str) -> float:
        """Return the file's mtime as a POSIX timestamp."""
        return os.path.getmtime(file_path)

    def directory_exists(self, directory: str) -> bool:
        """Return True if *directory* exists (delegates to file_exists)."""
        return self.file_exists(directory)
| 105 |
+
|
| 106 |
+
|
| 107 |
+
class BlobStorage(StorageInterface):
    """
    Writes to blob storage, using local disk as a cache.

    Every write goes to the remote container first and is then mirrored
    locally, so under single-writer use the local copy is expected to be
    at least as new as the remote one (see sync()).

    TODO: Allow configuration of temp dir instead of just using the same paths in both local and remote
    """
    def __init__(self, connection_string: str, container_name: str):
        self.blob_service_client = BlobServiceClient.from_connection_string(connection_string)
        self.container_client = self.blob_service_client.get_container_client(container_name)
        self.local_storage = LocalStorage()

    def download(self, file_path: str) -> bytes:
        """Fetch a blob's full content from the remote container."""
        blob_client = self.container_client.get_blob_client(file_path)
        return blob_client.download_blob().readall()

    def sync(self, file_path: str) -> None:
        """Ensure the local cached copy of file_path exists and is up to date."""
        if not self.local_storage.file_exists(file_path):
            print(f"DEBUG: missing local version of {file_path} - downloading")
            self.local_storage.save_file(file_path, self.download(file_path))
        else:
            local_timestamp = self.local_storage.get_modified_time(file_path)
            remote_timestamp = self.get_modified_time(file_path)
            if local_timestamp < remote_timestamp:
                # We always write remotely before writing locally, so we expect
                # local_timestamp to be > remote_timestamp; an older local copy
                # means another writer updated the blob.
                # Fix: corrected the "DBEUG" typo in the debug message.
                print(f"DEBUG: local version of {file_path} out of date - downloading")
                self.local_storage.save_file(file_path, self.download(file_path))

    def save_file(self, file_path: str, content: bytes) -> str:
        """Write content remotely first, then mirror it to the local cache."""
        blob_client = self.container_client.get_blob_client(file_path)
        blob_client.upload_blob(content, overwrite=True)
        self.local_storage.save_file(file_path, content)
        return file_path

    def load_file(self, file_path: str) -> bytes:
        """Read file contents, syncing the local cache first."""
        self.sync(file_path)
        return self.local_storage.load_file(file_path)

    def list_files(self, directory: str) -> list[str]:
        """List remote blob names under the given prefix."""
        return [blob.name for blob in self.container_client.list_blobs(name_starts_with=directory)]

    def file_exists(self, file_path: str) -> bool:
        """True if the blob exists remotely (local cache is not consulted)."""
        blob_client = self.container_client.get_blob_client(file_path)
        return blob_client.exists()

    def delete_file(self, file_path: str) -> None:
        """Delete the blob and, if present, its local cached copy.

        Fix: the file may never have been cached locally; unconditionally
        calling os.remove() via local_storage raised FileNotFoundError.
        """
        if self.local_storage.file_exists(file_path):
            self.local_storage.delete_file(file_path)
        blob_client = self.container_client.get_blob_client(file_path)
        blob_client.delete_blob()

    def create_directory(self, directory: str) -> None:
        # Blob storage doesn't have directories, so only create it locally
        self.local_storage.create_directory(directory)

    def delete_directory(self, directory: str) -> None:
        """Delete all blobs under the prefix plus the local cache directory.

        Fix: guard the local rmtree — the directory may not exist locally.
        """
        if self.local_storage.directory_exists(directory):
            self.local_storage.delete_directory(directory)
        blobs_to_delete = self.container_client.list_blobs(name_starts_with=directory)
        for blob in blobs_to_delete:
            self.container_client.delete_blob(blob.name)

    def upload(self, local_path: str, destination_path: str) -> None:
        """Upload a local file to the container and mirror it into the cache."""
        with open(local_path, "rb") as data:
            blob_client = self.container_client.get_blob_client(destination_path)
            blob_client.upload_blob(data, overwrite=True)
        self.local_storage.upload(local_path, destination_path)

    def append_file(self, file_path: str, content: bytes) -> None:
        """Append bytes to an append blob, creating it on first use."""
        blob_client = self.container_client.get_blob_client(file_path)
        if not blob_client.exists():
            blob_client.create_append_blob()
        else:
            self.sync(file_path)

        blob_client.append_block(content)
        self.local_storage.append_file(file_path, content)

    def get_modified_time(self, file_path: str) -> float:
        """Return the blob's last-modified time as a UNIX timestamp."""
        blob_client = self.container_client.get_blob_client(file_path)
        properties = blob_client.get_blob_properties()
        # Convert the UTC datetime to a UNIX timestamp
        return properties.last_modified.timestamp()

    def directory_exists(self, directory: str) -> bool:
        """True if at least one blob exists under the given prefix."""
        blobs = self.container_client.list_blobs(name_starts_with=directory)
        return next(blobs, None) is not None
|
| 194 |
+
class StorageFactory:
    """Builds the storage backend selected by the STORAGE_TYPE env var."""

    @staticmethod
    def get_storage() -> StorageInterface:
        """Return a LocalStorage ('local', the default) or a BlobStorage ('blob').

        Raises ValueError for an unknown type or missing Azure settings.
        """
        selected = os.getenv('STORAGE_TYPE', 'local').lower()

        if selected == 'local':
            return LocalStorage()

        if selected == 'blob':
            connection_string = os.getenv('AZURE_STORAGE_CONNECTION_STRING')
            container_name = os.getenv('AZURE_STORAGE_CONTAINER_NAME')
            if not connection_string or not container_name:
                raise ValueError("Azure Blob Storage connection string and container name must be set")
            return BlobStorage(connection_string, container_name)

        raise ValueError(f"Unsupported storage type: {selected}")
+
|
symbol_detection.py
ADDED
|
@@ -0,0 +1,454 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import json
|
| 3 |
+
import uuid
|
| 4 |
+
import os
|
| 5 |
+
import logging
|
| 6 |
+
from ultralytics import YOLO
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
from storage import StorageInterface
|
| 9 |
+
import numpy as np
|
| 10 |
+
from typing import Tuple, List, Dict, Any
|
| 11 |
+
|
| 12 |
+
# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Constants
# Weight files of the YOLO models that are run in turn; their detections are merged.
MODEL_PATHS = {
    "model1": "models/Intui_SDM_41.pt",
    "model2": "models/Intui_SDM_20.pt" # Add your second model path here
}
MAX_DIMENSION = 1280  # Longest image side used both for resizing and YOLO imgsz
CONFIDENCE_THRESHOLDS = [0.1, 0.3, 0.5, 0.7, 0.9]  # Thresholds tried per model; the best-scoring one is kept
TEXT_COLOR = (0, 0, 255)  # Annotation text colour — red in OpenCV's BGR order
BOX_COLOR = (255, 0, 0)  # Bounding-box colour — NOTE(review): (255, 0, 0) is blue in BGR, despite the original "red" comment; confirm intended colour
BG_COLOR = (255, 255, 255, 0.6)  # Semi-transparent white for text background
THICKNESS = 1  # Thin text stroke thickness
BOX_THICKNESS = 2  # Box line thickness
MIN_FONT_SCALE = 0.2  # Lower clamp for the adaptive font scale
MAX_FONT_SCALE = 1.0  # Upper clamp for the adaptive font scale
TEXT_PADDING = 20  # Padding between/around text elements
OVERLAP_THRESHOLD = 0.3  # Threshold for detecting text overlap
def preprocess_image_for_symbol_detection(image_cv: np.ndarray) -> np.ndarray:
    """Turn a BGR image into a 3-channel edge map for symbol detection.

    Pipeline: grayscale -> histogram equalisation -> bilateral filter
    (edge-preserving smoothing) -> Canny edges -> back to 3-channel BGR
    so downstream consumers still receive a colour-shaped array.
    """
    single_channel = cv2.cvtColor(image_cv, cv2.COLOR_BGR2GRAY)
    contrast_boosted = cv2.equalizeHist(single_channel)
    smoothed = cv2.bilateralFilter(contrast_boosted, 9, 75, 75)
    edge_map = cv2.Canny(smoothed, 100, 200)
    return cv2.cvtColor(edge_map, cv2.COLOR_GRAY2BGR)
def evaluate_detections(detections_list: List[Dict[str, Any]]) -> int:
    """Score a detection set; the score is simply the number of detections."""
    return sum(1 for _ in detections_list)
def resize_image_with_aspect_ratio(image_cv: np.ndarray, max_dimension: int) -> Tuple[np.ndarray, int, int]:
    """Shrink the image so its longest side is at most max_dimension.

    The aspect ratio is preserved.  Images already within the limit are
    returned untouched.  Returns (image, width, height) of the result.
    """
    height, width = image_cv.shape[:2]
    longest_side = max(width, height)
    if longest_side <= max_dimension:
        return image_cv, width, height

    ratio = max_dimension / float(longest_side)
    target_width = int(width * ratio)
    target_height = int(height * ratio)
    shrunk = cv2.resize(image_cv, (target_width, target_height), interpolation=cv2.INTER_LINEAR)
    return shrunk, target_width, target_height
def merge_detections(all_detections: List[Dict]) -> List[Dict]:
    """
    Merge detections from all models, keeping only the highest-confidence
    detection when duplicates are found (same label and IoU > 0.5).

    Fix: the original sorted the caller's list in place, mutating the
    argument as a side effect; this version works on a sorted copy.

    Returns the surviving detections ordered by descending confidence.
    """
    if not all_detections:
        return []

    def calculate_iou(box1, box2):
        """Calculate Intersection over Union (IoU) between two [x1, y1, x2, y2] boxes."""
        x1 = max(box1[0], box2[0])
        y1 = max(box1[1], box2[1])
        x2 = min(box1[2], box2[2])
        y2 = min(box1[3], box2[3])

        intersection = max(0, x2 - x1) * max(0, y2 - y1)
        area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
        area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
        union = area1 + area2 - intersection

        return intersection / union if union > 0 else 0

    # Sorted copy (descending confidence) so index i always has confidence
    # >= index j for j > i; the caller's list is left untouched.
    ordered = sorted(all_detections, key=lambda d: d['confidence'], reverse=True)
    keep = [True] * len(ordered)

    # Greedy NMS: suppress any later same-label detection that overlaps a
    # kept one by more than IoU 0.5.
    for i in range(len(ordered)):
        if not keep[i]:
            continue

        current_box = ordered[i]['bbox']
        current_label = ordered[i]['original_label']

        for j in range(i + 1, len(ordered)):
            if not keep[j]:
                continue

            if (ordered[j]['original_label'] == current_label and
                    calculate_iou(current_box, ordered[j]['bbox']) > 0.5):
                # i always has the higher confidence thanks to the sort.
                keep[j] = False
                logging.info(f"Removing duplicate detection of {current_label} with lower confidence "
                             f"({ordered[j]['confidence']:.2f} < {ordered[i]['confidence']:.2f})")

    return [det for i, det in enumerate(ordered) if keep[i]]
def calculate_font_scale(image_width: int, bbox_width: int) -> float:
    """Pick an OpenCV font scale adapted to the image and box widths.

    Starts from a 0.7 baseline, grows with the image width (relative to
    MAX_DIMENSION) and with the box's share of the image width, then
    clamps the result to [MIN_FONT_SCALE, MAX_FONT_SCALE].
    """
    base = 0.7  # baseline scale before any adjustment

    # Each factor is floored so tiny images/boxes cannot collapse the scale.
    image_factor = max(image_width / MAX_DIMENSION, 0.5)
    box_factor = max((bbox_width / image_width) * 6, 0.7)

    raw_scale = base * image_factor * box_factor

    if raw_scale < MIN_FONT_SCALE:
        return MIN_FONT_SCALE
    if raw_scale > MAX_FONT_SCALE:
        return MAX_FONT_SCALE
    return raw_scale
def check_overlap(rect1, rect2):
    """Return True when the two (x1, y1, x2, y2) rectangles intersect or touch."""
    ax1, ay1, ax2, ay2 = rect1
    bx1, by1, bx2, by2 = rect2

    # Overlap requires intersection on both axes (edge contact counts).
    horizontal = ax1 <= bx2 and bx1 <= ax2
    vertical = ay1 <= by2 and by1 <= ay2
    return horizontal and vertical
def draw_annotation(
    image: np.ndarray,
    bbox: List[int],
    text: str,
    confidence: float,
    model_source: str,
    existing_annotations: List[tuple] = None
) -> None:
    """
    Draw annotation with no background and thin fonts.

    Draws the detection's bounding box on `image` (modified in place) and
    writes the label plus a confidence percentage near it, trying several
    candidate positions so labels of nearby detections do not overlap.

    Args:
        image: BGR image that is annotated in place.
        bbox: [x1, y1, x2, y2] pixel coordinates of the detection.
        text: label text shown on the first line of the annotation.
        confidence: confidence as a percentage (the caller passes conf * 100;
            it is rendered with `:.0f` and a '%' suffix).
        model_source: model name for the detection; currently not rendered.
        existing_annotations: rectangles of labels already placed; this list
            is MUTATED (appended to) so subsequent calls avoid collisions.
    """
    if existing_annotations is None:
        existing_annotations = []

    x1, y1, x2, y2 = bbox
    bbox_width = x2 - x1
    image_width = image.shape[1]
    image_height = image.shape[0]

    # Calculate adaptive font scale
    font_scale = calculate_font_scale(image_width, bbox_width)

    # Simplify the annotation text: two lines — label, then "NN%".
    annotation_text = f'{text}\n{confidence:.0f}%'
    lines = annotation_text.split('\n')

    # Calculate text dimensions: widest line and the summed line heights
    # (each line gets baseline room plus TEXT_PADDING).
    font = cv2.FONT_HERSHEY_SIMPLEX
    max_width = 0
    total_height = 0
    line_heights = []

    for line in lines:
        (width, height), baseline = cv2.getTextSize(
            line, font, font_scale, THICKNESS
        )
        max_width = max(max_width, width)
        line_height = height + baseline + TEXT_PADDING
        line_heights.append(line_height)
        total_height += line_height

    # Calculate initial text position with increased padding; the label block
    # is horizontally anchored to the box's left edge, clamped to the image.
    padding = TEXT_PADDING
    rect_x1 = max(0, x1 - padding)
    rect_x2 = min(image_width, x1 + max_width + padding * 2)

    # Try different vertical positions (above, below, then shifted variants)
    # until one fits inside the image and does not overlap a placed label.
    positions = [
        ('top', y1 - total_height - padding),
        ('bottom', y2 + padding),
        ('top_shifted', y1 - total_height - padding * 2),
        ('bottom_shifted', y2 + padding * 2)
    ]

    final_position = None
    for pos_name, y_pos in positions:
        # Skip candidates that would run off the top or bottom of the image.
        if y_pos < 0 or y_pos + total_height > image_height:
            continue

        rect = (rect_x1, y_pos, rect_x2, y_pos + total_height)
        overlap = False

        for existing_rect in existing_annotations:
            if check_overlap(rect, existing_rect):
                overlap = True
                break

        if not overlap:
            final_position = (pos_name, y_pos)
            # Reserve this rectangle so later annotations avoid it.
            existing_annotations.append(rect)
            break

    # If no non-overlapping position found, use side position (to the right
    # of the box, level with its top); this fallback is not reserved.
    if final_position is None:
        rect_x1 = max(0, x1 + bbox_width + padding)
        rect_x2 = min(image_width, rect_x1 + max_width + padding * 2)
        y_pos = y1
        final_position = ('side', y_pos)

    rect_y1 = final_position[1]

    # Draw bounding box (no transparency)
    cv2.rectangle(image, (x1, y1), (x2, y2), BOX_COLOR, BOX_THICKNESS)

    # Draw text directly without background; each line is offset by the
    # cumulative heights of the lines above it.
    text_y = rect_y1 + line_heights[0] - padding
    for i, line in enumerate(lines):
        # Draw text with thin lines
        cv2.putText(
            image,
            line,
            (rect_x1 + padding, text_y + sum(line_heights[:i])),
            font,
            font_scale,
            TEXT_COLOR,
            THICKNESS,
            cv2.LINE_AA
        )
|
| 231 |
+
def run_detection_with_optimal_threshold(
    image_path: str,
    results_dir: str = "results",
    file_name: str = "",
    apply_preprocessing: bool = False,
    resize_image: bool = True,  # Changed default to True
    storage: StorageInterface = None
) -> Tuple[str, str, str, List[int]]:
    """Run detection with multiple models and merge results.

    Pipeline:
      1. Load the image through `storage` and keep an untouched
         full-resolution copy for annotation.
      2. Optionally resize to MAX_DIMENSION and/or apply edge-map
         preprocessing before inference.
      3. For each model in MODEL_PATHS, try every threshold in
         CONFIDENCE_THRESHOLDS and keep the threshold whose filtered
         detection list scores best (evaluate_detections == count).
      4. Merge all models' detections (IoU de-duplication), annotate the
         full-resolution image, and save a JSON summary plus an annotated
         PNG under `results_dir` via `storage`.

    Returns:
        (annotated_image_path, json_path, log_message, diagram_bbox) where
        diagram_bbox is the union [min_x1, min_y1, max_x2, max_y2] of all
        detection boxes (zeros when there are none).  On any exception the
        tuple is ("Error during detection", None, None, None).
    """
    try:
        # Decode stored bytes into a BGR array; the copy is what gets
        # resized/preprocessed for inference.
        image_data = storage.load_file(image_path)
        nparr = np.frombuffer(image_data, np.uint8)
        original_image_cv = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        image_cv = original_image_cv.copy()

        if resize_image:
            logging.info("Resizing image for detection with aspect ratio...")
            image_cv, resized_width, resized_height = resize_image_with_aspect_ratio(image_cv, MAX_DIMENSION)
        else:
            logging.info("Skipping image resizing...")
            resized_height, resized_width = original_image_cv.shape[:2]

        if apply_preprocessing:
            logging.info("Preprocessing image for symbol detection...")
            image_cv = preprocess_image_for_symbol_detection(image_cv)
        else:
            logging.info("Skipping image preprocessing for symbol detection...")

        all_detections = []

        # Run detection with each model
        for model_name, model_path in MODEL_PATHS.items():
            logging.info(f"Running detection with model: {model_name}")

            if not model_path:
                logging.warning(f"No model path found for {model_name}")
                continue

            model = YOLO(model_path)

            # NOTE(review): best_confidence_threshold is tracked but never
            # read after the loop; only best_detections_list is used.
            best_confidence_threshold = 0.5
            best_detections_list = []
            best_metric = -1

            for confidence_threshold in CONFIDENCE_THRESHOLDS:
                logging.info(f"Running detection with confidence threshold: {confidence_threshold}...")
                # The predict call itself is identical for every threshold;
                # the threshold is only applied when filtering boxes below.
                results = model.predict(source=image_cv, imgsz=MAX_DIMENSION)

                detections_list = []
                for result in results:
                    for box in result.boxes:
                        confidence = float(box.conf[0])
                        if confidence >= confidence_threshold:
                            x1, y1, x2, y2 = map(float, box.xyxy[0])
                            class_id = int(box.cls[0])
                            label = result.names[class_id]

                            # Map coordinates from the (possibly resized)
                            # inference image back to the original resolution.
                            scale_x = original_image_cv.shape[1] / resized_width
                            scale_y = original_image_cv.shape[0] / resized_height
                            x1 *= scale_x
                            x2 *= scale_x
                            y1 *= scale_y
                            y2 *= scale_y
                            x1, y1, x2, y2 = map(int, [x1, y1, x2, y2])

                            # Labels are underscore-encoded; the first part is
                            # treated as category, the second as type, the rest
                            # as the display label, with fallbacks for fewer parts.
                            split_label = label.split('_')
                            if len(split_label) >= 3:
                                category = split_label[0]
                                type_ = split_label[1]
                                new_label = '_'.join(split_label[2:])
                            elif len(split_label) == 2:
                                category = split_label[0]
                                type_ = split_label[1]
                                new_label = split_label[1]
                            elif len(split_label) == 1:
                                category = split_label[0]
                                type_ = "Unknown"
                                new_label = split_label[0]
                            else:
                                logging.warning(f"Unexpected label format: {label}. Skipping this detection.")
                                continue

                            detection_id = str(uuid.uuid4())
                            detection_info = {
                                "symbol_id": detection_id,
                                "class_id": class_id,
                                "original_label": label,
                                "category": category,
                                "type": type_,
                                "label": new_label,
                                "confidence": confidence,
                                "bbox": [x1, y1, x2, y2],
                                "model_source": model_name
                            }
                            detections_list.append(detection_info)

                # "Best" threshold == the one yielding the most detections.
                metric = evaluate_detections(detections_list)
                if metric > best_metric:
                    best_metric = metric
                    best_confidence_threshold = confidence_threshold
                    best_detections_list = detections_list

            all_detections.extend(best_detections_list)

        # Merge detections from both models
        merged_detections = merge_detections(all_detections)
        logging.info(f"Total detections after merging: {len(merged_detections)}")

        # Draw annotations on the image; existing_annotations accumulates the
        # label rectangles already placed so draw_annotation can avoid overlap.
        existing_annotations = []
        for det in merged_detections:
            draw_annotation(
                original_image_cv,
                det["bbox"],
                det["original_label"],
                det["confidence"] * 100,
                det["model_source"],
                existing_annotations
            )

        # Save results
        storage.create_directory(results_dir)
        file_name_without_extension = os.path.splitext(file_name)[0]

        # Prepare output JSON: overall count plus per-label counts.
        total_detected_symbols = len(merged_detections)
        class_counts = {}
        for det in merged_detections:
            full_label = det["original_label"]
            class_counts[full_label] = class_counts.get(full_label, 0) + 1

        output_json = {
            "total_detected_symbols": total_detected_symbols,
            "details": class_counts,
            "detections": merged_detections
        }

        # Save JSON and image
        detection_json_path = os.path.join(
            results_dir, f'{file_name_without_extension}_detected_symbols.json'
        )
        storage.save_file(
            detection_json_path,
            json.dumps(output_json, indent=4).encode('utf-8')
        )

        # Save with maximum quality
        detection_image_path = os.path.join(
            results_dir, f'{file_name_without_extension}_detected_symbols.png'  # Using PNG for transparency
        )

        # Configure image encoding parameters for maximum quality
        encode_params = [
            cv2.IMWRITE_PNG_COMPRESSION, 0  # No compression for PNG
        ]

        # Save as high-quality PNG to preserve transparency
        _, img_encoded = cv2.imencode(
            '.png',
            original_image_cv,
            encode_params
        )

        storage.save_file(detection_image_path, img_encoded.tobytes())

        # Calculate diagram bbox from merged detections (defaults keep the
        # min/max calls safe when there are no detections).
        diagram_bbox = [
            min([det['bbox'][0] for det in merged_detections], default=0),
            min([det['bbox'][1] for det in merged_detections], default=0),
            max([det['bbox'][2] for det in merged_detections], default=0),
            max([det['bbox'][3] for det in merged_detections], default=0)
        ]

        # Scale up image if it's too small.
        # NOTE(review): this runs after the annotated PNG has been saved and
        # the rebound local variable is not used afterwards, so the upscale
        # has no effect on any saved or returned artifact.
        min_width = 2000  # Minimum width for good visibility
        if original_image_cv.shape[1] < min_width:
            scale_factor = min_width / original_image_cv.shape[1]
            new_width = min_width
            new_height = int(original_image_cv.shape[0] * scale_factor)
            original_image_cv = cv2.resize(
                original_image_cv,
                (new_width, new_height),
                interpolation=cv2.INTER_CUBIC
            )

        return (
            detection_image_path,
            detection_json_path,
            f"Total detections after merging: {total_detected_symbols}",
            diagram_bbox
        )
    except Exception as e:
        # Broad catch keeps the caller alive; the error tuple mirrors the
        # arity of the success tuple.
        logging.error(f"An error occurred: {e}")
        return "Error during detection", None, None, None
|
| 426 |
+
if __name__ == "__main__":
    # Manual smoke test: run the full symbol-detection pipeline on one page.
    from storage import StorageFactory

    # Input page image (from the upstream preprocessing step) and output dir.
    uploaded_file_path = "processed_pages/10219-1-DG-BC-00011.01-REV_A_page_1_text.png"
    results_dir = "results"
    apply_symbol_preprocessing = False  # keep the raw image; no edge-map preprocessing
    resize_image = True  # downscale to MAX_DIMENSION before inference

    # Backend (local disk or blob) is selected via the STORAGE_TYPE env var.
    storage = StorageFactory.get_storage()

    (
        detection_image_path,
        detection_json_path,
        detection_log_message,
        diagram_bbox
    ) = run_detection_with_optimal_threshold(
        uploaded_file_path,
        results_dir=results_dir,
        file_name=os.path.basename(uploaded_file_path),
        apply_preprocessing=apply_symbol_preprocessing,
        resize_image=resize_image,
        storage=storage
    )

    # Report where the artifacts landed and the overall diagram extent.
    logging.info("Detection Image Path: %s", detection_image_path)
    logging.info("Detection JSON Path: %s", detection_json_path)
    logging.info("Detection Log Message: %s", detection_log_message)
    logging.info("Diagram BBox: %s", diagram_bbox)
    logging.info("Done!")
|
text_detection_combined.py
ADDED
|
@@ -0,0 +1,553 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import io
|
| 4 |
+
from PIL import Image, ImageDraw, ImageFont
|
| 5 |
+
import numpy as np
|
| 6 |
+
from doctr.models import ocr_predictor
|
| 7 |
+
import pytesseract
|
| 8 |
+
import easyocr
|
| 9 |
+
from storage import StorageInterface
|
| 10 |
+
import re
|
| 11 |
+
import logging
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
import cv2
|
| 14 |
+
import traceback
|
| 15 |
+
|
| 16 |
+
# Initialize OCR models once at import time.
# Bug fix: on failure the original left `doctr_model` / `easyocr_reader`
# undefined, so later calls raised NameError instead of degrading gracefully.
# We now fall back to None, which detect_with_easyocr() already checks for.
try:
    doctr_model = ocr_predictor(pretrained=True)
    easyocr_reader = easyocr.Reader(['en'])
    logging.info("All OCR models loaded successfully")
except Exception as e:
    logging.error(f"Error loading OCR models: {e}")
    doctr_model = None
    easyocr_reader = None

# Combined regex patterns used to classify OCR text found on P&ID drawings.
# classify_text() checks these in insertion order, so more specific patterns
# come first; 'Pipe_Connector' is extremely generic (any 1-5 digit run) and
# must stay last.
TEXT_PATTERNS = {
    'Line_Number': r"(?:\d{1,5}[-](?:[A-Z]{2,4})[-]\d{1,3})",           # e.g. "10219-BC-11"
    'Equipment_Tag': r"(?:[A-Z]{1,3}[-][A-Z0-9]{1,4}[-]\d{1,3})",       # e.g. "P-101A-1"
    'Instrument_Tag': r"(?:\d{2,3}[-][A-Z]{2,4}[-]\d{2,3})",            # e.g. "10-PIC-101"
    'Valve_Number': r"(?:[A-Z]{1,2}[-]\d{3})",                          # e.g. "V-123"
    'Pipe_Size': r"(?:\d{1,2}[\"])",                                    # e.g. "6\""
    'Flow_Direction': r"(?:FROM|TO)",
    'Service_Description': r"(?:STEAM|WATER|AIR|GAS|DRAIN)",
    'Process_Instrument': r"(?:[0-9]{2,3}(?:-[A-Z]{2,3})?-[0-9]{2,3}|[A-Z]{2,3}-[0-9]{2,3})",
    'Nozzle': r"(?:N[0-9]{1,2}|MH)",
    'Pipe_Connector': r"(?:[0-9]{1,5}|[A-Z]{1,2}[0-9]{2,5})"            # generic; keep last
}
|
| 37 |
+
|
| 38 |
+
def detect_text_combined(image, confidence_threshold=0.3):
    """Run all three OCR engines on *image* and return merged detections.

    Each detection dict carries 'text', 'bbox', 'confidence', the 'source'
    engine name, and a 'text_type' classification from classify_text().
    Overlapping boxes from different engines are deduplicated by
    merge_overlapping_detections(); detections below *confidence_threshold*
    are dropped.
    """
    detections = []

    # Run every engine and tag each hit with its source name.
    engines = (
        ('tesseract', detect_with_tesseract),
        ('easyocr', detect_with_easyocr),
        ('doctr', detect_with_doctr),
    )
    for source_name, detector in engines:
        for detection in detector(image):
            detection['source'] = source_name
            detections.append(detection)

    # Collapse duplicate boxes reported by multiple engines.
    deduplicated = merge_overlapping_detections(detections)

    # Keep only confident detections and attach a pattern-based type label.
    classified = []
    for detection in deduplicated:
        if detection['confidence'] < confidence_threshold:
            continue
        detection['text_type'] = classify_text(detection['text'])
        classified.append(detection)

    return classified
|
| 72 |
+
|
| 73 |
+
def generate_detailed_summary(results):
    """Generate detailed detection summary.

    Builds a nested statistics dict from detections produced by
    detect_text_combined(): totals, per-type and per-source counts with
    average confidences, a confidence histogram, and a flat list of all
    detected items. Each element of *results* must carry 'source',
    'confidence', 'text_type', 'text' and 'bbox' keys.
    """
    summary = {
        'total_detections': len(results),
        'by_type': {},
        'by_source': {
            'tesseract': {
                'count': 0,
                'by_type': {},
                'avg_confidence': 0.0
            },
            'easyocr': {
                'count': 0,
                'by_type': {},
                'avg_confidence': 0.0
            },
            'doctr': {
                'count': 0,
                'by_type': {},
                'avg_confidence': 0.0
            }
        },
        'confidence_ranges': {
            '0.9-1.0': 0,
            '0.8-0.9': 0,
            '0.7-0.8': 0,
            '0.6-0.7': 0,
            '0.5-0.6': 0,
            '<0.5': 0
        },
        'detected_items': []
    }

    # Initialize type counters (one bucket per known TEXT_PATTERNS key).
    # NOTE: detections classified as 'Unknown' are not bucketed here; they
    # only appear in 'detected_items' below.
    for pattern_type in TEXT_PATTERNS.keys():
        summary['by_type'][pattern_type] = {
            'count': 0,
            'avg_confidence': 0.0,
            'by_source': {
                'tesseract': 0,
                'easyocr': 0,
                'doctr': 0
            },
            'items': []
        }
        # Initialize source-specific type counters
        for source in summary['by_source'].keys():
            summary['by_source'][source]['by_type'][pattern_type] = 0

    # Process each detection, accumulating per-source confidences so the
    # averages can be computed in a single pass at the end.
    source_confidences = {'tesseract': [], 'easyocr': [], 'doctr': []}

    for result in results:
        # Get source and confidence
        source = result['source']
        conf = result['confidence']
        text_type = result['text_type']

        # Update source statistics
        summary['by_source'][source]['count'] += 1
        source_confidences[source].append(conf)

        # Update confidence ranges (histogram with 0.1-wide bins above 0.5)
        if conf >= 0.9: summary['confidence_ranges']['0.9-1.0'] += 1
        elif conf >= 0.8: summary['confidence_ranges']['0.8-0.9'] += 1
        elif conf >= 0.7: summary['confidence_ranges']['0.7-0.8'] += 1
        elif conf >= 0.6: summary['confidence_ranges']['0.6-0.7'] += 1
        elif conf >= 0.5: summary['confidence_ranges']['0.5-0.6'] += 1
        else: summary['confidence_ranges']['<0.5'] += 1

        # Update type statistics (skipped for 'Unknown' types, which have
        # no pre-initialized bucket).
        if text_type in summary['by_type']:
            type_stats = summary['by_type'][text_type]
            type_stats['count'] += 1
            type_stats['by_source'][source] += 1
            summary['by_source'][source]['by_type'][text_type] += 1
            type_stats['items'].append({
                'text': result['text'],
                'confidence': conf,
                'source': source,
                'bbox': result['bbox']
            })

        # Add to detected items (flat list, includes 'Unknown' detections)
        summary['detected_items'].append({
            'text': result['text'],
            'type': text_type,
            'confidence': conf,
            'source': source,
            'bbox': result['bbox']
        })

    # Calculate average confidences per source
    for source, confs in source_confidences.items():
        if confs:
            summary['by_source'][source]['avg_confidence'] = sum(confs) / len(confs)

    # Calculate average confidences for each type
    for text_type, stats in summary['by_type'].items():
        if stats['items']:
            stats['avg_confidence'] = sum(item['confidence'] for item in stats['items']) / len(stats['items'])

    return summary
|
| 176 |
+
|
| 177 |
+
def process_drawing(image_path, results_dir, storage=None):
    """Run all OCR engines over one drawing and persist annotated outputs.

    Reads the image at *image_path* with OpenCV, collects raw detections
    from Tesseract, EasyOCR and DocTR (no overlap merging or type
    classification here — unlike detect_text_combined()), draws each box
    onto a copy of the image, and writes both the annotated JPEG and a JSON
    results file into *results_dir*.

    Args:
        image_path: path to the input drawing image.
        results_dir: directory for the "<stem>_detected_texts.{jpg,json}" outputs.
        storage: accepted for interface compatibility but unused here —
            files are written directly to the local filesystem.

    Returns:
        (info_dict, text_summary) on success, where info_dict holds
        'image_path', 'json_path' and 'results'; (None, None) on any error.

    NOTE(review): the returned text_summary contains only 'total_detections',
    'by_source' and 'by_type' — no 'confidence_ranges' and no per-type
    'items' lists (those exist only in generate_detailed_summary() output).
    """
    try:
        # Read image using cv2
        image = cv2.imread(image_path)
        if image is None:
            raise ValueError(f"Could not read image from {image_path}")

        # Create annotated copy
        annotated_image = image.copy()

        # Initialize results and summary
        text_results = {
            'file_name': image_path,
            'detections': []
        }

        text_summary = {
            'total_detections': 0,
            'by_source': {
                'tesseract': {'count': 0, 'avg_confidence': 0.0},
                'easyocr': {'count': 0, 'avg_confidence': 0.0},
                'doctr': {'count': 0, 'avg_confidence': 0.0}
            },
            'by_type': {
                'equipment_tag': {'count': 0, 'avg_confidence': 0.0},
                'line_number': {'count': 0, 'avg_confidence': 0.0},
                'instrument_tag': {'count': 0, 'avg_confidence': 0.0},
                'valve_number': {'count': 0, 'avg_confidence': 0.0},
                'pipe_size': {'count': 0, 'avg_confidence': 0.0},
                'flow_direction': {'count': 0, 'avg_confidence': 0.0},
                'service_description': {'count': 0, 'avg_confidence': 0.0},
                'process_instrument': {'count': 0, 'avg_confidence': 0.0},
                'nozzle': {'count': 0, 'avg_confidence': 0.0},
                'pipe_connector': {'count': 0, 'avg_confidence': 0.0},
                'other': {'count': 0, 'avg_confidence': 0.0}
            }
        }

        # Run OCR with different engines
        tesseract_results = detect_with_tesseract(image)
        easyocr_results = detect_with_easyocr(image)
        doctr_results = detect_with_doctr(image)

        # Combine results, tagging each detection with its engine name
        all_detections = []
        all_detections.extend([(res, 'tesseract') for res in tesseract_results])
        all_detections.extend([(res, 'easyocr') for res in easyocr_results])
        all_detections.extend([(res, 'doctr') for res in doctr_results])

        # Process each detection
        for detection, source in all_detections:
            # Update text_results
            text_results['detections'].append({
                'text': detection['text'],
                'bbox': detection['bbox'],
                'confidence': detection['confidence'],
                'source': source
            })

            # Update summary statistics (avg_confidence accumulates a sum
            # here; it is divided by the count after the loop)
            text_summary['total_detections'] += 1
            text_summary['by_source'][source]['count'] += 1
            text_summary['by_source'][source]['avg_confidence'] += detection['confidence']

            # Draw detection on image (green box with text label above it)
            x1, y1, x2, y2 = detection['bbox']
            cv2.rectangle(annotated_image, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
            cv2.putText(annotated_image, detection['text'], (int(x1), int(y1)-5),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)

        # Calculate average confidences (sum accumulated above / count)
        for source in text_summary['by_source']:
            if text_summary['by_source'][source]['count'] > 0:
                text_summary['by_source'][source]['avg_confidence'] /= text_summary['by_source'][source]['count']

        # Save results with new naming convention
        base_name = Path(image_path).stem
        text_result_image_path = os.path.join(results_dir, f"{base_name}_detected_texts.jpg")
        text_result_json_path = os.path.join(results_dir, f"{base_name}_detected_texts.json")

        # Save the annotated image
        success = cv2.imwrite(text_result_image_path, annotated_image)
        if not success:
            raise ValueError(f"Failed to save image to {text_result_image_path}")

        # Save the JSON results
        with open(text_result_json_path, 'w', encoding='utf-8') as f:
            json.dump({
                'file_name': image_path,
                'summary': text_summary,
                'detections': text_results['detections']
            }, f, indent=4, ensure_ascii=False)

        return {
            'image_path': text_result_image_path,
            'json_path': text_result_json_path,
            'results': text_results
        }, text_summary

    except Exception as e:
        # Best-effort: log the failure and signal it via (None, None)
        print(f"Error in process_drawing: {str(e)}")
        traceback.print_exc()
        return None, None
|
| 280 |
+
|
| 281 |
+
def detect_with_tesseract(image):
    """Detect text using Tesseract OCR.

    Args:
        image: image in a form accepted by pytesseract (numpy array / PIL image).

    Returns:
        List of dicts with 'text', 'bbox' as [x1, y1, x2, y2] and
        'confidence' scaled to [0, 1]. Empty list on any Tesseract failure.
    """
    # Configure Tesseract for technical drawings: LSTM engine (--oem 3),
    # sparse-text page segmentation (--psm 11) and a character whitelist
    # limited to what appears in P&ID tags.
    custom_config = r'--oem 3 --psm 11 -c tessedit_char_whitelist="ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-.()" -c tessedit_write_images=true -c textord_heavy_nr=true -c textord_min_linesize=3'

    try:
        data = pytesseract.image_to_data(
            image,
            config=custom_config,
            output_type=pytesseract.Output.DICT
        )

        results = []
        for i in range(len(data['text'])):
            conf = float(data['conf'][i])
            if conf > 30:  # Lower confidence threshold for technical text
                text = data['text'][i].strip()
                if text:
                    x, y, w, h = data['left'][i], data['top'][i], data['width'][i], data['height'][i]
                    results.append({
                        'text': text,
                        'bbox': [x, y, x + w, y + h],
                        'confidence': conf / 100.0  # Tesseract reports 0-100
                    })
        return results

    except Exception as e:
        # Bug fix: `logger` is only defined under the __main__ guard, so the
        # original raised NameError here; use the module-level logging API.
        logging.error(f"Tesseract error: {str(e)}")
        return []
|
| 310 |
+
|
| 311 |
+
def detect_with_easyocr(image):
    """Detect text using EasyOCR.

    Args:
        image: image convertible via np.array() (PIL image or ndarray).

    Returns:
        List of dicts with 'text', axis-aligned integer 'bbox'
        [x1, y1, x2, y2] and 'confidence'. Empty list if the reader is
        unavailable or detection fails.
    """
    try:
        # Bug fix: the availability check is inside the try block so a
        # missing `easyocr_reader` global (model load failed at import time)
        # is caught instead of raising NameError.
        if easyocr_reader is None:
            return []

        results = easyocr_reader.readtext(
            np.array(image),
            paragraph=False,
            height_ths=2.0,
            width_ths=2.0,
            contrast_ths=0.2,
            text_threshold=0.5
        )

        parsed_results = []
        for bbox, text, conf in results:
            # EasyOCR returns a 4-point polygon; reduce it to an
            # axis-aligned bounding box.
            x1, y1 = min(point[0] for point in bbox), min(point[1] for point in bbox)
            x2, y2 = max(point[0] for point in bbox), max(point[1] for point in bbox)

            parsed_results.append({
                'text': text,
                'bbox': [int(x1), int(y1), int(x2), int(y2)],
                'confidence': conf
            })
        return parsed_results

    except Exception as e:
        # Bug fix: `logger` is only defined under the __main__ guard, so the
        # original raised NameError here; use the module-level logging API.
        logging.error(f"EasyOCR error: {str(e)}")
        return []
|
| 341 |
+
|
| 342 |
+
def detect_with_doctr(image):
    """Detect text using DocTR.

    Args:
        image: image convertible via np.array() (PIL image or ndarray).

    Returns:
        List of dicts with 'text', absolute integer 'bbox' [x1, y1, x2, y2]
        and 'confidence' (0.5 when DocTR reports none). Empty list on failure,
        including when the model failed to load at import time.
    """
    try:
        # Convert PIL image to numpy array
        image_np = np.array(image)

        # Get predictions (raises inside try if doctr_model is unavailable,
        # which the except below turns into an empty result)
        result = doctr_model([image_np])
        doc = result.export()

        # Parse results: DocTR nests words under pages > blocks > lines
        results = []
        for page in doc['pages']:
            for block in page['blocks']:
                for line in block['lines']:
                    for word in line['words']:
                        # Convert normalized [0, 1] coordinates to absolute pixels
                        height, width = image_np.shape[:2]
                        points = np.array(word['geometry']) * np.array([width, height])
                        x1, y1 = points.min(axis=0)
                        x2, y2 = points.max(axis=0)

                        results.append({
                            'text': word['value'],
                            'bbox': [int(x1), int(y1), int(x2), int(y2)],
                            'confidence': word.get('confidence', 0.5)
                        })
        return results

    except Exception as e:
        # Bug fix: `logger` is only defined under the __main__ guard, so the
        # original raised NameError here; use the module-level logging API.
        logging.error(f"DocTR error: {str(e)}")
        return []
|
| 374 |
+
|
| 375 |
+
def merge_overlapping_detections(results, iou_threshold=0.5):
    """Collapse detections whose boxes overlap beyond *iou_threshold*.

    Greedy single pass: each unconsumed detection seeds a group, absorbs
    every later unconsumed detection overlapping it, and the group member
    with the highest 'confidence' survives. Order of surviving detections
    follows the input order of their group seeds.
    """
    if not results:
        return []

    def _iou(box_a, box_b):
        """Intersection-over-union of two [x1, y1, x2, y2] boxes."""
        left = max(box_a[0], box_b[0])
        top = max(box_a[1], box_b[1])
        right = min(box_a[2], box_b[2])
        bottom = min(box_a[3], box_b[3])

        if right < left or bottom < top:
            return 0.0

        inter = (right - left) * (bottom - top)
        area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
        area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
        union = area_a + area_b - inter
        return inter / union if union > 0 else 0

    merged = []
    consumed = set()

    for i, seed in enumerate(results):
        if i in consumed:
            continue

        consumed.add(i)
        group = [seed]

        for j, candidate in enumerate(results):
            if j in consumed:
                continue
            if _iou(seed['bbox'], candidate['bbox']) > iou_threshold:
                group.append(candidate)
                consumed.add(j)

        # Keep the most confident member (a singleton keeps itself).
        merged.append(max(group, key=lambda det: det['confidence']))

    return merged
|
| 422 |
+
|
| 423 |
+
def classify_text(text):
    """Return the TEXT_PATTERNS key whose regex matches *text*, else 'Unknown'.

    The text is uppercased and stripped of all whitespace before matching;
    patterns are tried in TEXT_PATTERNS insertion order and the first match
    (anchored at the start via re.match) wins.
    """
    if not text:
        return 'Unknown'

    # Normalize: uppercase and remove every whitespace run.
    normalized = re.sub(r'\s+', '', text.strip().upper())

    for label, pattern in TEXT_PATTERNS.items():
        if re.match(pattern, normalized):
            return label

    return 'Unknown'
|
| 437 |
+
|
| 438 |
+
def annotate_image(image, results):
    """Create annotated image with detections.

    Draws a colored bounding box per detection plus a text label
    ("<text> (<confidence>) [<type>]") on a white background above it.
    If *image* is already RGB it is drawn on in place; otherwise an RGB
    copy is created and returned. Each element of *results* must carry
    'text', 'bbox' and 'confidence'; 'text_type' is optional.
    """
    # Convert image to RGB mode to ensure color support
    if image.mode != 'RGB':
        image = image.convert('RGB')

    # Create drawing object
    draw = ImageDraw.Draw(image)
    try:
        font = ImageFont.truetype("arial.ttf", 20)
    except IOError:
        # Arial not available on this system — fall back to PIL's default font
        font = ImageFont.load_default()

    # Define colors for different text types
    colors = {
        'Line_Number': "#FF0000",        # Bright Red
        'Equipment_Tag': "#00FF00",      # Bright Green
        'Instrument_Tag': "#0000FF",     # Bright Blue
        'Valve_Number': "#FFA500",       # Bright Orange
        'Pipe_Size': "#FF00FF",          # Bright Magenta
        'Process_Instrument': "#00FFFF", # Bright Cyan
        'Nozzle': "#FFFF00",             # Yellow
        'Pipe_Connector': "#800080",     # Purple
        'Unknown': "#FF4444"             # Light Red (also the fallback color)
    }

    # Draw detections
    for result in results:
        text_type = result.get('text_type', 'Unknown')
        color = colors.get(text_type, colors['Unknown'])

        # Draw bounding box
        draw.rectangle(result['bbox'], outline=color, width=3)

        # Create label: "<text> (<confidence>)", plus "[<type>]" when classified
        label = f"{result['text']} ({result['confidence']:.2f})"
        if text_type != 'Unknown':
            label += f" [{text_type}]"

        # Draw label background (white box 20px above the bbox top-left)
        text_bbox = draw.textbbox((result['bbox'][0], result['bbox'][1] - 20), label, font=font)
        draw.rectangle(text_bbox, fill="#FFFFFF")

        # Draw label text
        draw.text((result['bbox'][0], result['bbox'][1] - 20), label, fill=color, font=font)

    return image
|
| 485 |
+
|
| 486 |
+
def save_annotated_image(image, path, storage):
    """Persist *image* to *path* as an uncompressed PNG via the storage backend.

    Encoding uses compress_level=0 and no optimization so the annotated
    output keeps maximum quality.
    """
    buffer = io.BytesIO()
    image.save(
        buffer,
        format='PNG',
        optimize=False,
        compress_level=0
    )
    storage.save_file(path, buffer.getvalue())
|
| 496 |
+
|
| 497 |
+
if __name__ == "__main__":
    from storage import StorageFactory
    import logging

    # Configure logging
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    # Initialize storage
    storage = StorageFactory.get_storage()

    # Test file paths
    file_path = "processed_pages/10219-1-DG-BC-00011.01-REV_A_page_1_text.png"
    result_path = "results"

    try:
        # Ensure result directory exists
        os.makedirs(result_path, exist_ok=True)

        # Process the drawing
        logger.info(f"Processing file: {file_path}")
        results, summary = process_drawing(file_path, result_path, storage)

        # Bug fix: process_drawing returns (None, None) on failure; fail
        # loudly here instead of crashing below on subscripting None.
        if results is None or summary is None:
            raise RuntimeError(f"Processing failed for {file_path}")

        # Print detailed results
        print("\n=== DETAILED DETECTION RESULTS ===")
        print(f"\nTotal Detections: {summary['total_detections']}")

        print("\nBreakdown by Text Type:")
        print("-" * 50)
        for text_type, stats in summary['by_type'].items():
            if stats['count'] > 0:
                print(f"\n{text_type}:")
                print(f"  Count: {stats['count']}")
                print(f"  Average Confidence: {stats['avg_confidence']:.2f}")
                # Bug fix: 'items' exists only in generate_detailed_summary()
                # output, not in process_drawing()'s summary — guard the access
                # instead of raising KeyError.
                items = stats.get('items', [])
                if items:
                    print("  Items:")
                    for item in items:
                        print(f"    - {item['text']} (conf: {item['confidence']:.2f}, source: {item['source']})")

        print("\nBreakdown by OCR Engine:")
        print("-" * 50)
        for source, stats in summary['by_source'].items():
            # Bug fix: each value is a stats dict, not a bare count — the
            # original printed the whole dict as "<source>: {...} detections".
            print(f"{source}: {stats['count']} detections")

        print("\nConfidence Distribution:")
        print("-" * 50)
        # Bug fix: 'confidence_ranges' also exists only in
        # generate_detailed_summary() output; guard so this demo works with
        # the summary process_drawing() actually returns.
        for range_name, count in summary.get('confidence_ranges', {}).items():
            print(f"{range_name}: {count} detections")

        # Print output paths
        print("\nOutput Files:")
        print("-" * 50)
        print(f"Annotated Image: {results['image_path']}")
        print(f"JSON Results: {results['json_path']}")

    except Exception as e:
        logger.error(f"Error processing file: {e}")
        raise
|
utils.py
ADDED
|
@@ -0,0 +1,122 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
import numpy as np
|
| 3 |
+
from contextlib import contextmanager
|
| 4 |
+
from loguru import logger
|
| 5 |
+
from typing import List, Dict, Optional, Tuple, Union
|
| 6 |
+
from detection_schema import BBox
|
| 7 |
+
from storage import StorageInterface
|
| 8 |
+
|
| 9 |
+
class DebugHandler:
    """Production-grade debugging and performance tracking.

    When ``enabled`` is False, performance tracking and artifact saving are
    no-ops, so handlers can stay wired into production code paths.
    """

    def __init__(self, enabled: bool = False, storage: StorageInterface = None):
        self.enabled = enabled        # master switch for tracking and artifacts
        self.storage = storage        # backend used by save_artifact()
        self.metrics = {}             # operation name -> duration in seconds
        self._start_time = None       # set on entry by track_performance()

    @contextmanager
    def track_performance(self, operation_name: str):
        """Context manager that records the wall-clock duration of a block."""
        if self.enabled:
            self._start_time = time.perf_counter()
            logger.debug(f"Starting {operation_name}")

        yield

        if self.enabled:
            elapsed = time.perf_counter() - self._start_time
            self.metrics[operation_name] = elapsed
            logger.debug(f"{operation_name} completed in {elapsed:.2f}s")

    def save_artifact(self, name: str, data: bytes, extension: str = "png"):
        """Generic artifact storage handler."""
        if not (self.enabled and self.storage):
            return
        path = f"debug/{name}.{extension}"
        self.storage.save_file(path, data)
        logger.info(f"Saved debug artifact: {path}")
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class CoordinateTransformer:
    """Translate bounding boxes and points between the global image frame and
    an ROI-local frame, where *roi* is (x_min, y_min, x_max, y_max).

    Every method returns its input unchanged when *roi* is None or not
    length 4, so callers can pass an optional ROI straight through.
    """

    @staticmethod
    def _translate(b: BBox, dx, dy) -> BBox:
        """Return a copy of *b* shifted by (dx, dy)."""
        return BBox(
            xmin=b.xmin + dx,
            ymin=b.ymin + dy,
            xmax=b.xmax + dx,
            ymax=b.ymax + dy
        )

    @staticmethod
    def global_to_local_bbox(
        bbox: Union[BBox, List[BBox]],
        roi: Optional[np.ndarray]
    ) -> Union[BBox, List[BBox]]:
        """
        Convert global BBox(es) to ROI-local coordinates.
        Handles both single BBox and lists of BBoxes.
        """
        if roi is None or len(roi) != 4:
            return bbox

        x_min, y_min, _, _ = roi

        if isinstance(bbox, list):
            # Bug fix: the original returned a lazy `map` object here, which
            # violated the List[BBox] annotation and could only be iterated once.
            return [CoordinateTransformer._translate(b, -x_min, -y_min) for b in bbox]
        return CoordinateTransformer._translate(bbox, -x_min, -y_min)

    @staticmethod
    def local_to_global_bbox(
        bbox: Union[BBox, List[BBox]],
        roi: Optional[np.ndarray]
    ) -> Union[BBox, List[BBox]]:
        """
        Convert ROI-local BBox(es) to global coordinates.
        Handles both single BBox and lists of BBoxes.
        """
        if roi is None or len(roi) != 4:
            return bbox

        x_min, y_min, _, _ = roi

        if isinstance(bbox, list):
            # Bug fix: return a real list instead of a one-shot `map` object.
            return [CoordinateTransformer._translate(b, x_min, y_min) for b in bbox]
        return CoordinateTransformer._translate(bbox, x_min, y_min)

    # Maintain legacy tuple support if needed
    @staticmethod
    def global_to_local(
        bboxes: List[Tuple[int, int, int, int]],
        roi: Optional[np.ndarray]
    ) -> List[Tuple[int, int, int, int]]:
        """Legacy tuple version for backward compatibility"""
        if roi is None or len(roi) != 4:
            return bboxes

        x_min, y_min, _, _ = roi
        return [(x1 - x_min, y1 - y_min, x2 - x_min, y2 - y_min)
                for x1, y1, x2, y2 in bboxes]

    @staticmethod
    def local_to_global(
        bboxes: List[Tuple[int, int, int, int]],
        roi: Optional[np.ndarray]
    ) -> List[Tuple[int, int, int, int]]:
        """Legacy tuple version for backward compatibility"""
        if roi is None or len(roi) != 4:
            return bboxes

        x_min, y_min, _, _ = roi
        return [(x1 + x_min, y1 + y_min, x2 + x_min, y2 + y_min)
                for x1, y1, x2, y2 in bboxes]

    @staticmethod
    def local_to_global_point(point: Tuple[int, int], roi: Optional[np.ndarray]) -> Tuple[int, int]:
        """Convert single point from local to global coordinates"""
        if roi is None or len(roi) != 4:
            return point
        x_min, y_min, _, _ = roi
        return (int(point[0] + x_min), int(point[1] + y_min))
|