Spaces:
Runtime error
Runtime error
msIntui commited on
Commit Β·
910e0d4
0
Parent(s):
feat: initial clean deployment
Browse files- .gitignore +74 -0
- README.md +494 -0
- assets/AiAgent.png +0 -0
- assets/intuigence.png +0 -0
- assets/user.png +0 -0
- base.py +138 -0
- base_config.py +8 -0
- chatbot_agent.py +194 -0
- common.py +14 -0
- config.py +97 -0
- data_aggregation_ai.py +411 -0
- detection_schema.py +481 -0
- detection_utils.py +81 -0
- detectors.py +733 -0
- download_models.py +49 -0
- gradioChatApp.py +806 -0
- graph_construction.py +115 -0
- graph_processor.py +118 -0
- graph_visualization.py +180 -0
- line_detection_ai.py +413 -0
- line_detectors.py +109 -0
- logger.py +30 -0
- pdf_processor.py +144 -0
- requirements.txt +54 -0
- results/002_page_1_aggregated.json +0 -0
- results/002_page_1_detected_symbols.json +0 -0
- storage.py +208 -0
- symbol_detection.py +454 -0
- text_detection_combined.py +563 -0
- utils.py +132 -0
.gitignore
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Directories to ignore
|
| 2 |
+
archive/
|
| 3 |
+
debug/
|
| 4 |
+
samples/
|
| 5 |
+
chat/
|
| 6 |
+
# Models - allow specific model files
|
| 7 |
+
models/*
|
| 8 |
+
!models/yolo/
|
| 9 |
+
!models/deeplsd/
|
| 10 |
+
!models/doctr/
|
| 11 |
+
!models/*.pt
|
| 12 |
+
!models/*.tar
|
| 13 |
+
results/
|
| 14 |
+
logs/
|
| 15 |
+
DeepLSD/
|
| 16 |
+
|
| 17 |
+
# Large files
|
| 18 |
+
*.tar
|
| 19 |
+
*.pt
|
| 20 |
+
*.pth
|
| 21 |
+
*.onnx
|
| 22 |
+
*.weights
|
| 23 |
+
|
| 24 |
+
# Python
|
| 25 |
+
__pycache__/
|
| 26 |
+
*.py[cod]
|
| 27 |
+
*$py.class
|
| 28 |
+
*.so
|
| 29 |
+
.Python
|
| 30 |
+
env/
|
| 31 |
+
build/
|
| 32 |
+
develop-eggs/
|
| 33 |
+
dist/
|
| 34 |
+
downloads/
|
| 35 |
+
eggs/
|
| 36 |
+
.eggs/
|
| 37 |
+
lib/
|
| 38 |
+
lib64/
|
| 39 |
+
parts/
|
| 40 |
+
sdist/
|
| 41 |
+
var/
|
| 42 |
+
*.egg-info/
|
| 43 |
+
.installed.cfg
|
| 44 |
+
*.egg
|
| 45 |
+
|
| 46 |
+
# Virtual Environment
|
| 47 |
+
.venv
|
| 48 |
+
venv/
|
| 49 |
+
ENV/
|
| 50 |
+
|
| 51 |
+
# IDE
|
| 52 |
+
.idea/
|
| 53 |
+
.vscode/
|
| 54 |
+
*.swp
|
| 55 |
+
*.swo
|
| 56 |
+
|
| 57 |
+
# Project specific
|
| 58 |
+
!results/
|
| 59 |
+
!results/*.json
|
| 60 |
+
debug/
|
| 61 |
+
*.log
|
| 62 |
+
*.gz
|
| 63 |
+
models/
|
| 64 |
+
archive/
|
| 65 |
+
weights/
|
| 66 |
+
|
| 67 |
+
# Environment variables
|
| 68 |
+
.env
|
| 69 |
+
.env.*
|
| 70 |
+
|
| 71 |
+
# Explicitly track assets
|
| 72 |
+
!assets/
|
| 73 |
+
!assets/*.png
|
| 74 |
+
!assets/*.css
|
README.md
ADDED
|
@@ -0,0 +1,494 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Intelligent_PID
|
| 3 |
+
emoji: π
|
| 4 |
+
colorFrom: red
|
| 5 |
+
colorTo: red
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 3.50.2
|
| 8 |
+
app_file: gradioChatApp.py
|
| 9 |
+
pinned: false
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
# P&ID Processing with AI-Powered Graph Construction
|
| 13 |
+
|
| 14 |
+
## Overview
|
| 15 |
+
This project processes P&ID (Piping and Instrumentation Diagram) images using multiple AI models for symbol detection, text recognition, and line detection. It constructs a graph representation of the diagram and provides an interactive interface for querying the diagram's contents.
|
| 16 |
+
|
| 17 |
+
## Features
|
| 18 |
+
- P&ID Document Processing
|
| 19 |
+
- Symbol Detection
|
| 20 |
+
- Text Recognition
|
| 21 |
+
- Line Detection
|
| 22 |
+
- Knowledge Graph Generation
|
| 23 |
+
- Interactive Chat Interface
|
| 24 |
+
|
| 25 |
+
## Usage
|
| 26 |
+
1. Upload a P&ID document
|
| 27 |
+
2. Click "Process Document"
|
| 28 |
+
3. View results in different tabs
|
| 29 |
+
4. Ask questions about the P&ID in the chat
|
| 30 |
+
|
| 31 |
+
## Process Flow
|
| 32 |
+
|
| 33 |
+
```mermaid
|
| 34 |
+
graph TD
|
| 35 |
+
subgraph "Document Input"
|
| 36 |
+
A[Upload Document] --> B[Validate File]
|
| 37 |
+
B -->|PDF/Image| C[Document Processor]
|
| 38 |
+
B -->|Invalid| ERR[Error Message]
|
| 39 |
+
C -->|PDF| D1[Extract Pages]
|
| 40 |
+
C -->|Image| D2[Direct Process]
|
| 41 |
+
end
|
| 42 |
+
|
| 43 |
+
subgraph "Image Preprocessing"
|
| 44 |
+
D1 --> E[Optimize Image]
|
| 45 |
+
D2 --> E
|
| 46 |
+
E -->|CLAHE Enhancement| E1[Contrast Enhancement]
|
| 47 |
+
E1 -->|Denoising| E2[Clean Image]
|
| 48 |
+
E2 -->|Binarization| E3[Binary Image]
|
| 49 |
+
E3 -->|Resize| E4[Normalized Image]
|
| 50 |
+
end
|
| 51 |
+
|
| 52 |
+
subgraph "Line Detection Pipeline"
|
| 53 |
+
E4 --> L1[Load DeepLSD Model]
|
| 54 |
+
L1 --> L2[Scale Image 0.1x]
|
| 55 |
+
L2 --> L3[Grayscale Conversion]
|
| 56 |
+
L3 --> L4[Model Inference]
|
| 57 |
+
L4 --> L5[Scale Coordinates]
|
| 58 |
+
L5 --> L6[Draw Lines]
|
| 59 |
+
end
|
| 60 |
+
|
| 61 |
+
subgraph "Detection Pipeline"
|
| 62 |
+
E4 --> F[Symbol Detection]
|
| 63 |
+
E4 --> G[Text Detection]
|
| 64 |
+
|
| 65 |
+
F --> S1[Load YOLO Models]
|
| 66 |
+
G --> T1[Load OCR Models]
|
| 67 |
+
|
| 68 |
+
S1 --> S2[Detect Symbols]
|
| 69 |
+
T1 --> T2[Detect Text]
|
| 70 |
+
|
| 71 |
+
S2 --> S3[Process Symbols]
|
| 72 |
+
T2 --> T3[Process Text]
|
| 73 |
+
|
| 74 |
+
L6 --> L7[Process Lines]
|
| 75 |
+
end
|
| 76 |
+
|
| 77 |
+
subgraph "Data Integration"
|
| 78 |
+
S3 --> I[Data Aggregation]
|
| 79 |
+
T3 --> I
|
| 80 |
+
L7 --> I
|
| 81 |
+
I --> J[Create Edges]
|
| 82 |
+
J --> K[Build Graph Network]
|
| 83 |
+
K --> L[Generate Knowledge Graph]
|
| 84 |
+
end
|
| 85 |
+
|
| 86 |
+
subgraph "User Interface"
|
| 87 |
+
L --> M[Interactive Visualization]
|
| 88 |
+
M --> N[Chat Interface]
|
| 89 |
+
N --> O[Query Processing]
|
| 90 |
+
O --> P[Response Generation]
|
| 91 |
+
P --> N
|
| 92 |
+
end
|
| 93 |
+
|
| 94 |
+
style A fill:#f9f,stroke:#333,stroke-width:2px
|
| 95 |
+
style F fill:#fbb,stroke:#333,stroke-width:2px
|
| 96 |
+
style G fill:#bfb,stroke:#333,stroke-width:2px
|
| 97 |
+
style H fill:#bbf,stroke:#333,stroke-width:2px
|
| 98 |
+
style I fill:#fbf,stroke:#333,stroke-width:2px
|
| 99 |
+
style N fill:#bbf,stroke:#333,stroke-width:2px
|
| 100 |
+
|
| 101 |
+
%% Add style for model nodes
|
| 102 |
+
style SM1 fill:#ffe6e6,stroke:#333,stroke-width:2px
|
| 103 |
+
style SM2 fill:#ffe6e6,stroke:#333,stroke-width:2px
|
| 104 |
+
style LM1 fill:#e6e6ff,stroke:#333,stroke-width:2px
|
| 105 |
+
style DC1 fill:#e6ffe6,stroke:#333,stroke-width:2px
|
| 106 |
+
style DC2 fill:#e6ffe6,stroke:#333,stroke-width:2px
|
| 107 |
+
```
|
| 108 |
+
|
| 109 |
+
## Architecture
|
| 110 |
+
|
| 111 |
+

|
| 112 |
+
|
| 113 |
+
## Features
|
| 114 |
+
|
| 115 |
+
- **Multi-modal AI Processing**:
|
| 116 |
+
- Combined OCR approach using Tesseract, EasyOCR, and DocTR
|
| 117 |
+
- Symbol detection with optimized thresholds
|
| 118 |
+
- Intelligent line and connection detection
|
| 119 |
+
- **Document Processing**:
|
| 120 |
+
- Support for PDF, PNG, JPG, JPEG formats
|
| 121 |
+
- Automatic page extraction from PDFs
|
| 122 |
+
- Image optimization pipeline
|
| 123 |
+
- **Text Detection Types**:
|
| 124 |
+
- Equipment Tags
|
| 125 |
+
- Line Numbers
|
| 126 |
+
- Instrument Tags
|
| 127 |
+
- Valve Numbers
|
| 128 |
+
- Pipe Sizes
|
| 129 |
+
- Flow Directions
|
| 130 |
+
- Service Descriptions
|
| 131 |
+
- Process Instruments
|
| 132 |
+
- Nozzles
|
| 133 |
+
- Pipe Connectors
|
| 134 |
+
- **Data Integration**:
|
| 135 |
+
- Automatic edge detection
|
| 136 |
+
- Relationship mapping
|
| 137 |
+
- Confidence scoring
|
| 138 |
+
- Detailed detection statistics
|
| 139 |
+
- **User Interface**:
|
| 140 |
+
- Interactive visualization tabs
|
| 141 |
+
- Real-time processing feedback
|
| 142 |
+
- AI-powered chat interface
|
| 143 |
+
- Knowledge graph exploration
|
| 144 |
+
|
| 145 |
+
The entire process is visualized through an interactive Gradio-based UI, allowing users to upload a P&ID image, follow the detection steps, and view both the results and insights in real time.
|
| 146 |
+
|
| 147 |
+
## Key Files
|
| 148 |
+
|
| 149 |
+
- **gradioChatApp.py**: The main Gradio app script that handles the frontend and orchestrates the overall flow.
|
| 150 |
+
- **symbol_detection.py**: Module for detecting symbols using YOLO models.
|
| 151 |
+
- **text_detection_combined.py**: Unified module for text detection using multiple OCR engines (Tesseract, EasyOCR, DocTR).
|
| 152 |
+
- **line_detection_ai.py**: Module for detecting lines and connections using AI.
|
| 153 |
+
- **data_aggregation_ai.py**: Aggregates detected elements into a structured format.
|
| 154 |
+
- **graph_construction.py**: Constructs the graph network from aggregated data.
|
| 155 |
+
- **graph_processor.py**: Handles graph visualization and processing.
|
| 156 |
+
- **pdf_processor.py**: Handles PDF document processing and page extraction.
|
| 157 |
+
|
| 158 |
+
## Setup and Installation
|
| 159 |
+
|
| 160 |
+
1. Clone the repository:
|
| 161 |
+
```bash
|
| 162 |
+
git clone https://github.com/IntuigenceAI/intui-PnID-POC.git
|
| 163 |
+
cd intui-PnID-POC
|
| 164 |
+
```
|
| 165 |
+
|
| 166 |
+
2. Install dependencies using uv:
|
| 167 |
+
```bash
|
| 168 |
+
# Install uv if you haven't already
|
| 169 |
+
curl -LsSf https://astral.sh/uv/install.sh | sh
|
| 170 |
+
|
| 171 |
+
# Create and activate virtual environment
|
| 172 |
+
uv venv
|
| 173 |
+
source .venv/bin/activate # On Windows: .venv\Scripts\activate
|
| 174 |
+
|
| 175 |
+
# Install dependencies
|
| 176 |
+
uv pip install -r requirements.txt
|
| 177 |
+
```
|
| 178 |
+
|
| 179 |
+
3. Download required models:
|
| 180 |
+
```bash
|
| 181 |
+
python download_models.py  # Downloads DeepLSD model for line detection
|
| 182 |
+
```
|
| 183 |
+
|
| 184 |
+
4. Run the application:
|
| 185 |
+
```bash
|
| 186 |
+
python gradioChatApp.py
|
| 187 |
+
```
|
| 188 |
+
|
| 189 |
+
## Models
|
| 190 |
+
|
| 191 |
+
### Line Detection Model
|
| 192 |
+
- **DeepLSD Model**:
|
| 193 |
+
- File: deeplsd_md.tar
|
| 194 |
+
- Purpose: Line segment detection in P&ID diagrams
|
| 195 |
+
- Input Resolution: Variable (scaled to 0.1x for performance)
|
| 196 |
+
- Processing: Grayscale conversion and binary thresholding
|
| 197 |
+
|
| 198 |
+
### Text Detection Models
|
| 199 |
+
- **Combined OCR Approach**:
|
| 200 |
+
- Tesseract OCR
|
| 201 |
+
- EasyOCR
|
| 202 |
+
- DocTR
|
| 203 |
+
- Purpose: Text recognition and classification
|
| 204 |
+
|
| 205 |
+
### Graph Processing
|
| 206 |
+
- **NetworkX-based**:
|
| 207 |
+
- Purpose: Graph construction and analysis
|
| 208 |
+
- Features: Node linking, edge creation, path analysis
|
| 209 |
+
|
| 210 |
+
## Updating the Environment
|
| 211 |
+
|
| 212 |
+
To update the environment, use the following:
|
| 213 |
+
|
| 214 |
+
```bash
|
| 215 |
+
conda env update --file environment.yml --prune
|
| 216 |
+
```
|
| 217 |
+
|
| 218 |
+
This command will update the environment according to changes made in the `environment.yml`.
|
| 219 |
+
|
| 220 |
+
### Step 6: Deactivate the environment
|
| 221 |
+
|
| 222 |
+
When you're done, deactivate the environment by:
|
| 223 |
+
|
| 224 |
+
```bash
|
| 225 |
+
conda deactivate
|
| 226 |
+
```
|
| 227 |
+
|
| 228 |
+
2. Upload a P&ID image through the interface.
|
| 229 |
+
3. Follow the sequential steps of symbol, text, and line detection.
|
| 230 |
+
4. View the generated graph and AI agent's reasoning in the real-time chat box.
|
| 231 |
+
5. Save and export the results if satisfactory.
|
| 232 |
+
|
| 233 |
+
## Folder Structure
|
| 234 |
+
|
| 235 |
+
```
|
| 236 |
+
βββ assets/
|
| 237 |
+
β βββ AiAgent.png
|
| 238 |
+
β βββ llm.png
|
| 239 |
+
βββ gradioApp.py
|
| 240 |
+
βββ symbol_detection.py
|
| 241 |
+
βββ text_detection_combined.py
|
| 242 |
+
βββ line_detection_ai.py
|
| 243 |
+
βββ data_aggregation.py
|
| 244 |
+
βββ graph_construction.py
|
| 245 |
+
βββ graph_processor.py
|
| 246 |
+
βββ pdf_processor.py
|
| 247 |
+
βββ pnid_agent.py
|
| 248 |
+
βββ requirements.txt
|
| 249 |
+
βββ results/
|
| 250 |
+
βββ models/
|
| 251 |
+
β βββ symbol_detection_model.pth
|
| 252 |
+
```
|
| 253 |
+
|
| 254 |
+
## /models Folder
|
| 255 |
+
|
| 256 |
+
- **models/symbol_detection_model.pth**: This folder contains the pre-trained model for symbol detection in P&ID diagrams. This model is crucial for detecting key symbols such as valves, instruments, and pipes in the diagram. Make sure to download the model and place it in the `/models` directory before running the app.
|
| 257 |
+
|
| 258 |
+
## Future Work
|
| 259 |
+
|
| 260 |
+
- **Advanced Symbol Recognition**: Improve symbol detection by integrating more sophisticated recognition models.
|
| 261 |
+
- **Graph Enhancement**: Introduce more complex graph structures and logic for representing the relationships between the diagram's elements.
|
| 262 |
+
- **Data Export**: Allow export in additional formats such as DEXPI-compliant XML or JSON.
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
# Docker Information
|
| 266 |
+
|
| 267 |
+
We'll cover the basic docker operations here.
|
| 268 |
+
|
| 269 |
+
## Building
|
| 270 |
+
|
| 271 |
+
There is a dockerfile for each different project (they have slightly different requirements).
|
| 272 |
+
|
| 273 |
+
### `gradioChatApp.py`
|
| 274 |
+
|
| 275 |
+
Run this one as follows:
|
| 276 |
+
|
| 277 |
+
```
|
| 278 |
+
> docker build -t exp-pnid-to-graph_chat-w-graph:0.0.4 -f Dockerfile-chatApp .
|
| 279 |
+
> docker tag exp-pnid-to-graph_chat-w-graph:0.0.4 intaicr.azurecr.io/intai/exp-pnid-to-graph_chat-w-graph:0.0.4
|
| 280 |
+
```
|
| 281 |
+
|
| 282 |
+
## Deploying to ACR
|
| 283 |
+
|
| 284 |
+
### `gradioChatApp.py`
|
| 285 |
+
|
| 286 |
+
```
|
| 287 |
+
> az login
|
| 288 |
+
> az acr login --name intaicr
|
| 289 |
+
> docker push intaicr.azurecr.io/intai/exp-pnid-to-graph_chat-w-graph:0.0.4
|
| 290 |
+
```
|
| 291 |
+
|
| 292 |
+
## Models
|
| 293 |
+
|
| 294 |
+
### Symbol Detection Models
|
| 295 |
+
- **Intui_SDM_41.pt**: Primary model for equipment and large symbol detection
|
| 296 |
+
- Classes: Equipment, Vessels, Heat Exchangers
|
| 297 |
+
- Input Resolution: 1280x1280
|
| 298 |
+
- Confidence Threshold: 0.3-0.7 (adaptive)
|
| 299 |
+
|
| 300 |
+
- **Intui_SDM_20.pt**: Secondary model for instrument and small symbol detection
|
| 301 |
+
- Classes: Instruments, Valves, Indicators
|
| 302 |
+
- Input Resolution: 1280x1280
|
| 303 |
+
- Confidence Threshold: 0.3-0.7 (adaptive)
|
| 304 |
+
|
| 305 |
+
### Line Detection Model
|
| 306 |
+
- **intui_LDM_01.pt**: Specialized model for line and connection detection
|
| 307 |
+
- Classes: Solid Lines, Dashed Lines
|
| 308 |
+
- Input Resolution: 1280x1280
|
| 309 |
+
- Confidence Threshold: 0.5
|
| 310 |
+
|
| 311 |
+
### Text Detection Models
|
| 312 |
+
- **Tesseract**: v5.3.0
|
| 313 |
+
- Configuration:
|
| 314 |
+
- OEM Mode: 3 (Default)
|
| 315 |
+
- PSM Mode: 11 (Sparse text)
|
| 316 |
+
- Custom Whitelist: A-Z, 0-9, special characters
|
| 317 |
+
|
| 318 |
+
- **EasyOCR**: v1.7.1
|
| 319 |
+
- Configuration:
|
| 320 |
+
- Language: English
|
| 321 |
+
- Paragraph Mode: False
|
| 322 |
+
- Height Threshold: 2.0
|
| 323 |
+
- Width Threshold: 2.0
|
| 324 |
+
- Contrast Threshold: 0.2
|
| 325 |
+
|
| 326 |
+
- **DocTR**: v0.6.0
|
| 327 |
+
- Models:
|
| 328 |
+
- fast_base-688a8b34.pt
|
| 329 |
+
- crnn_vgg16_bn-9762b0b0.pt
|
| 330 |
+
|
| 331 |
+
# P&ID Line Detection
|
| 332 |
+
|
| 333 |
+
A deep learning-based pipeline for detecting lines in P&ID diagrams using DeepLSD.
|
| 334 |
+
|
| 335 |
+
## Architecture
|
| 336 |
+
```mermaid
|
| 337 |
+
graph TD
|
| 338 |
+
A[Input Image] --> B[Line Detection]
|
| 339 |
+
B --> C[DeepLSD Model]
|
| 340 |
+
C --> D[Post-processing]
|
| 341 |
+
D --> E[Output JSON/Image]
|
| 342 |
+
|
| 343 |
+
subgraph Line Detection Pipeline
|
| 344 |
+
B --> F[Image Preprocessing]
|
| 345 |
+
F --> G[Scale Image 0.1x]
|
| 346 |
+
G --> H[Grayscale Conversion]
|
| 347 |
+
H --> C
|
| 348 |
+
C --> I[Scale Coordinates]
|
| 349 |
+
I --> J[Draw Lines]
|
| 350 |
+
J --> E
|
| 351 |
+
end
|
| 352 |
+
```
|
| 353 |
+
|
| 354 |
+
## Setup
|
| 355 |
+
|
| 356 |
+
### Prerequisites
|
| 357 |
+
- Python 3.12+
|
| 358 |
+
- uv (for dependency management)
|
| 359 |
+
- Git
|
| 360 |
+
- CUDA-capable GPU (optional)
|
| 361 |
+
|
| 362 |
+
### Installation
|
| 363 |
+
|
| 364 |
+
1. Clone the repository:
|
| 365 |
+
```bash
|
| 366 |
+
git clone https://github.com/IntuigenceAI/intui-PnID-POC.git
|
| 367 |
+
cd intui-PnID-POC
|
| 368 |
+
```
|
| 369 |
+
|
| 370 |
+
2. Install dependencies using uv:
|
| 371 |
+
```bash
|
| 372 |
+
# Install uv if you haven't already
|
| 373 |
+
curl -LsSf https://astral.sh/uv/install.sh | sh
|
| 374 |
+
|
| 375 |
+
# Create and activate virtual environment
|
| 376 |
+
uv venv
|
| 377 |
+
source .venv/bin/activate # On Windows: .venv\Scripts\activate
|
| 378 |
+
|
| 379 |
+
# Install dependencies
|
| 380 |
+
uv pip install -r requirements.txt
|
| 381 |
+
```
|
| 382 |
+
|
| 383 |
+
3. Download DeepLSD model:
|
| 384 |
+
```bash
|
| 385 |
+
python download_models.py
|
| 386 |
+
```
|
| 387 |
+
|
| 388 |
+
## Usage
|
| 389 |
+
|
| 390 |
+
1. Run the line detection:
|
| 391 |
+
```bash
|
| 392 |
+
python line_detection_ai.py
|
| 393 |
+
```
|
| 394 |
+
|
| 395 |
+
The script will:
|
| 396 |
+
- Load the DeepLSD model
|
| 397 |
+
- Process input images at 0.1x scale for performance
|
| 398 |
+
- Generate line detections
|
| 399 |
+
- Save results as JSON and annotated images
|
| 400 |
+
|
| 401 |
+
## Configuration
|
| 402 |
+
|
| 403 |
+
Key parameters in `line_detection_ai.py`:
|
| 404 |
+
- `scale_factor`: Image scaling (default: 0.1)
|
| 405 |
+
- `device`: CPU/GPU selection
|
| 406 |
+
- `mask_json_paths`: Paths to text/symbol detection results
|
| 407 |
+
|
| 408 |
+
## Input/Output
|
| 409 |
+
|
| 410 |
+
### Input
|
| 411 |
+
- Original P&ID images
|
| 412 |
+
- Optional text/symbol detection JSON files for masking
|
| 413 |
+
|
| 414 |
+
### Output
|
| 415 |
+
- Annotated images with detected lines
|
| 416 |
+
- JSON files containing line coordinates and metadata
|
| 417 |
+
|
| 418 |
+
## Project Structure
|
| 419 |
+
|
| 420 |
+
```
|
| 421 |
+
βββ line_detection_ai.py # Main line detection script
|
| 422 |
+
βββ detectors.py # Line detector implementation
|
| 423 |
+
βββ download_models.py               # Model download utility
|
| 424 |
+
βββ models/ # Directory for model files
|
| 425 |
+
β βββ deeplsd_md.tar # DeepLSD model weights
|
| 426 |
+
βββ results/ # Output directory
|
| 427 |
+
βββ requirements.txt # Project dependencies
|
| 428 |
+
```
|
| 429 |
+
|
| 430 |
+
## Dependencies
|
| 431 |
+
|
| 432 |
+
Key dependencies:
|
| 433 |
+
- torch
|
| 434 |
+
- opencv-python
|
| 435 |
+
- numpy
|
| 436 |
+
- DeepLSD
|
| 437 |
+
|
| 438 |
+
See `requirements.txt` for the complete list.
|
| 439 |
+
|
| 440 |
+
## Contributing
|
| 441 |
+
|
| 442 |
+
1. Fork the repository
|
| 443 |
+
2. Create your feature branch (`git checkout -b feature/amazing-feature`)
|
| 444 |
+
3. Commit your changes (`git commit -m 'Add some amazing feature'`)
|
| 445 |
+
4. Push to the branch (`git push origin feature/amazing-feature`)
|
| 446 |
+
5. Open a Pull Request
|
| 447 |
+
|
| 448 |
+
## License
|
| 449 |
+
|
| 450 |
+
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
|
| 451 |
+
|
| 452 |
+
## Acknowledgments
|
| 453 |
+
|
| 454 |
+
- [DeepLSD](https://github.com/cvg/DeepLSD) for the line detection model
|
| 455 |
+
- Original P&ID processing pipeline by IntuigenceAI
|
| 456 |
+
---
|
| 457 |
+
title: PnID Diagram Analyzer
|
| 458 |
+
emoji: π
|
| 459 |
+
colorFrom: blue
|
| 460 |
+
colorTo: red
|
| 461 |
+
sdk: gradio
|
| 462 |
+
sdk_version: 4.19.2
|
| 463 |
+
app_file: gradioChatApp.py
|
| 464 |
+
pinned: false
|
| 465 |
+
---
|
| 466 |
+
|
| 467 |
+
# PnID Diagram Analyzer
|
| 468 |
+
|
| 469 |
+
This app analyzes PnID diagrams using AI to detect and interpret various elements.
|
| 470 |
+
|
| 471 |
+
## Features
|
| 472 |
+
- Line detection
|
| 473 |
+
- Symbol recognition
|
| 474 |
+
- Text detection
|
| 475 |
+
- Graph construction
|
| 476 |
+
|
| 477 |
+
# Intuigence P&ID Analyzer
|
| 478 |
+
|
| 479 |
+
Interactive P&ID analysis tool powered by AI.
|
| 480 |
+
|
| 481 |
+
## Features
|
| 482 |
+
- P&ID Document Processing
|
| 483 |
+
- Symbol Detection
|
| 484 |
+
- Text Recognition
|
| 485 |
+
- Line Detection
|
| 486 |
+
- Knowledge Graph Generation
|
| 487 |
+
- Interactive Chat Interface
|
| 488 |
+
|
| 489 |
+
## Usage
|
| 490 |
+
1. Upload a P&ID document
|
| 491 |
+
2. Click "Process Document"
|
| 492 |
+
3. View results in different tabs
|
| 493 |
+
4. Ask questions about the P&ID in the chat
|
| 494 |
+
|
assets/AiAgent.png
ADDED
|
assets/intuigence.png
ADDED
|
assets/user.png
ADDED
|
base.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from abc import ABC, abstractmethod
|
| 4 |
+
from typing import List, Optional, Dict
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import cv2
|
| 8 |
+
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from loguru import logger
|
| 11 |
+
import json
|
| 12 |
+
|
| 13 |
+
from common import DetectionResult
|
| 14 |
+
from storage import StorageInterface
|
| 15 |
+
from utils import DebugHandler, CoordinateTransformer
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class BaseConfig(ABC):
    """Abstract base class that every configuration object derives from."""

    def __post_init__(self):
        """Hook run after dataclass init; subclasses may normalise defaults here."""
        pass
|
| 24 |
+
|
| 25 |
+
class BaseDetector(ABC):
    """Common interface for detection models.

    Concrete subclasses supply model loading plus the
    preprocess -> detect -> postprocess steps of a detection pass.
    """

    def __init__(self,
                 config: BaseConfig,
                 debug_handler: DebugHandler = None):
        """Store the configuration and an optional debug handler.

        When no handler is given, a fresh DebugHandler is created.
        """
        self.config = config
        self.debug_handler = debug_handler or DebugHandler()

    @abstractmethod
    def _load_model(self, model_path: str):
        """Load the underlying model from *model_path* and return it."""
        ...

    @abstractmethod
    def detect(self, image: np.ndarray, *args, **kwargs):
        """Run the full detection pass on *image*."""
        ...

    @abstractmethod
    def _preprocess(self, image: np.ndarray) -> np.ndarray:
        """Transform *image* into the form expected by the model."""
        ...

    @abstractmethod
    def _postprocess(self, image: np.ndarray) -> np.ndarray:
        """Transform the model output image after detection."""
        ...
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class BaseDetectionPipeline(ABC):
    """Common skeleton for detection pipelines.

    Owns the storage backend, debug handler and coordinate transformer;
    concrete pipelines implement process_image().
    """

    def __init__(
        self,
        storage: StorageInterface,
        debug_handler=None
    ):
        # self.detector = detector
        self.storage = storage
        self.debug_handler = debug_handler or DebugHandler()
        self.transformer = CoordinateTransformer()

    @abstractmethod
    def process_image(
        self,
        image_path: str,
        output_dir: str,
        config
    ) -> DetectionResult:
        """Run the full pipeline on a single image and return its result."""
        ...

    def _apply_roi(self, image: np.ndarray, roi: np.ndarray) -> np.ndarray:
        """Crop *image* to the region of interest, if a valid one is given."""
        if roi is None or len(roi) != 4:
            # No usable ROI: hand back the untouched image.
            return image
        x_min, y_min, x_max, y_max = roi
        return image[y_min:y_max, x_min:x_max]

    def _adjust_coordinates(self, detections: List[Dict], roi: np.ndarray) -> List[Dict]:
        """Shift detection bboxes from ROI-local back to full-image coordinates."""
        if roi is None or len(roi) != 4:
            return detections

        dx, dy = roi[0], roi[1]
        shifted = []

        for det in detections:
            try:
                box = det["bbox"]
                shifted.append({
                    **det,
                    "bbox": [
                        int(box[0] + dx),
                        int(box[1] + dy),
                        int(box[2] + dx),
                        int(box[3] + dy),
                    ],
                })
            except KeyError:
                # Skip malformed entries but keep processing the rest.
                logger.warning("Invalid detection format during coordinate adjustment")
        return shifted

    def _persist_results(
        self,
        output_dir: str,
        image_path: str,
        detections: List[Dict],
        annotated_image: Optional[np.ndarray]
    ) -> Dict[str, str]:
        """Write the detection JSON (and optional annotated image) to storage."""
        self.storage.create_directory(output_dir)
        base_name = Path(image_path).stem

        # Persist line detections as JSON (dashed lines currently unused).
        json_path = Path(output_dir) / f"{base_name}_lines.json"
        payload = {
            "solid_lines": {"lines": detections},
            "dashed_lines": {"lines": []}
        }
        self.storage.save_file(
            str(json_path),
            json.dumps(payload, indent=2).encode('utf-8')
        )

        # Persist the annotated overlay, when one was produced.
        img_path = None
        if annotated_image is not None:
            img_path = Path(output_dir) / f"{base_name}_annotated.jpg"
            _, img_data = cv2.imencode('.jpg', annotated_image)
            self.storage.save_file(str(img_path), img_data.tobytes())

        return {
            "json_path": str(json_path),
            "image_path": str(img_path) if img_path else None
        }
base_config.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
|
| 3 |
+
class BaseConfig:
    """Base configuration object: stores arbitrary keyword settings as attributes.

    NOTE: this class was previously decorated with ``@dataclass``, which was a
    defect here — with no declared fields the generated ``__eq__`` made every
    instance compare equal and set ``__hash__`` to ``None`` (unhashable), while
    the explicit ``__init__`` suppressed the generated one anyway.  A plain
    class keeps the caller-visible behaviour (attribute access to kwargs)
    while restoring default identity-based equality and hashing.
    """

    def __init__(self, **kwargs):
        """Set every keyword argument as an attribute on the instance."""
        for key, value in kwargs.items():
            setattr(self, key, value)

    def __repr__(self):
        # Debug-friendly repr listing all stored settings.
        attrs = ", ".join(f"{k}={v!r}" for k, v in vars(self).items())
        return f"{type(self).__name__}({attrs})"
|
chatbot_agent.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# chatbot_agent.py
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import json
|
| 5 |
+
import re
|
| 6 |
+
from openai import OpenAI
|
| 7 |
+
import traceback
|
| 8 |
+
import logging
|
| 9 |
+
from dotenv import load_dotenv
|
| 10 |
+
|
| 11 |
+
# Load environment variables
|
| 12 |
+
load_dotenv()
|
| 13 |
+
|
| 14 |
+
# Get logger
|
| 15 |
+
logger = logging.getLogger(__name__)
|
| 16 |
+
|
| 17 |
+
# Initialize OpenAI client with error handling
|
| 18 |
+
def get_openai_client():
    """Build an OpenAI client from the OPENAI_API_KEY environment variable.

    Raises:
        ValueError: if the key is absent (or empty) in the environment.
    """
    key = os.getenv("OPENAI_API_KEY")
    if not key:
        raise ValueError("OpenAI API key not found in environment variables")
    return OpenAI(api_key=key)
|
| 23 |
+
|
| 24 |
+
def format_message(role, content):
    """Return a single chat-history entry in the OpenAI message format."""
    return dict(role=role, content=content)
|
| 27 |
+
|
| 28 |
+
def initialize_graph_prompt(graph_data):
    """Build the system prompt describing the knowledge-graph contents.

    Summarises whatever counts exist in ``graph_data['summary']`` and, when
    symbol details are present in ``graph_data['detailed_results']``, lists
    them one per line. Falls back to a generic prompt on any error.
    """
    try:
        summary = graph_data.get('summary', {})

        # Only include counts that actually exist in the summary.
        count_labels = [
            ('symbol_count', 'Symbols'),
            ('text_count', 'Texts'),
            ('line_count', 'Lines'),
            ('edge_count', 'Edges'),
        ]
        summary_parts = [f"{label}: {summary[key]}"
                         for key, label in count_labels if key in summary]
        summary_info = ", ".join(summary_parts) + "."

        # Describe each symbol node, using only the fields it provides.
        node_details = ""
        detailed_results = graph_data.get('detailed_results', {})
        if 'symbols' in detailed_results:
            node_details = "Nodes (symbols) in the graph include:\n"
            field_labels = [
                ('symbol_id', 'ID'),
                ('class_id', 'Class'),
                ('category', 'Category'),
                ('type', 'Type'),
                ('label', 'Label'),
            ]
            for symbol in detailed_results['symbols']:
                details = [f"{label}: {symbol[key]}"
                           for key, label in field_labels if key in symbol]
                if details:  # skip symbols with none of the known fields
                    node_details += ", ".join(details) + "\n"

        return (
            "You have access to a knowledge graph generated from a P&ID diagram. "
            f"The summary information includes:\n{summary_info}\n\n"
            f"{node_details}\n"
            "Answer questions about the P&ID elements using this information."
        )

    except Exception as e:
        logger.error(f"Error creating initial prompt: {str(e)}")
        return ("I have access to a P&ID diagram knowledge graph. "
                "I can help answer questions about the diagram elements.")
|
| 80 |
+
|
| 81 |
+
def get_assistant_response(user_message, json_path):
    """Answer a user question about the aggregated P&ID data.

    Counting questions about valves, pumps and general equipment are answered
    directly from the aggregated JSON; anything else is delegated to the
    OpenAI chat API with a system prompt describing the graph.

    Args:
        user_message: The user's question.
        json_path: Path to the aggregated detection JSON file.

    Returns:
        A response string, or an apology message if processing fails.
    """
    try:
        # Load the aggregated data
        with open(json_path, 'r') as f:
            data = json.load(f)

        question = user_message.lower()

        # Rule-based answers for common counting questions.
        # ("valves" contains "valve", so one substring test suffices;
        # same for "pumps"/"pump".)
        if "valve" in question:
            valve_count = sum(1 for symbol in data.get('symbols', [])
                              if 'class' in symbol and 'valve' in symbol['class'].lower())
            return f"I found {valve_count} valves in this P&ID."

        elif "pump" in question:
            pump_count = sum(1 for symbol in data.get('symbols', [])
                             if 'class' in symbol and 'pump' in symbol['class'].lower())
            return f"I found {pump_count} pumps in this P&ID."

        elif "equipment" in question or "components" in question:
            equipment_types = {}
            for symbol in data.get('symbols', []):
                if 'class' in symbol:
                    eq_type = symbol['class']
                    equipment_types[eq_type] = equipment_types.get(eq_type, 0) + 1

            response = "Here's a summary of the equipment I found:\n"
            for eq_type, count in equipment_types.items():
                response += f"- {eq_type}: {count}\n"
            return response

        # For anything else, use OpenAI.
        else:
            # Create the client only when it is actually needed, so the
            # rule-based answers above still work when no API key is
            # configured (the original built the client unconditionally
            # and therefore failed even for offline questions).
            client = get_openai_client()

            graph_data = {
                "summary": {
                    "symbol_count": len(data.get('symbols', [])),
                    "text_count": len(data.get('texts', [])),
                    "line_count": len(data.get('lines', [])),
                    "edge_count": len(data.get('edges', [])),
                },
                "detailed_results": data
            }

            initial_prompt = initialize_graph_prompt(graph_data)
            conversation = [
                {"role": "system", "content": initial_prompt},
                {"role": "user", "content": user_message}
            ]

            response = client.chat.completions.create(
                model="gpt-4-turbo",
                messages=conversation
            )
            return response.choices[0].message.content

    except Exception as e:
        logger.error(f"Error in get_assistant_response: {str(e)}")
        logger.error(traceback.format_exc())
        return "I apologize, but I encountered an error analyzing the P&ID data. Please try asking a different question."
|
| 144 |
+
|
| 145 |
+
# Testing and Usage block
|
| 146 |
+
# Testing and Usage block
if __name__ == "__main__":
    # Load the knowledge graph data from JSON file
    json_file_path = "results/0_aggregated_detections.json"
    try:
        with open(json_file_path, 'r') as file:
            graph_data = json.load(file)
    except FileNotFoundError:
        print(f"Error: File not found at {json_file_path}")
        graph_data = None
    except json.JSONDecodeError:
        print("Error: Failed to decode JSON. Please check the file format.")
        graph_data = None

    # Initialize conversation history with assistant's welcome message
    history = [format_message("assistant", "Hello! I am ready to answer your questions about the P&ID knowledge graph. The graph includes nodes (symbols), edges, linkers, and text tags, and I have detailed information available about each. Please ask any questions related to these elements and their connections.")]

    # Print the assistant's welcome message
    print("Assistant:", history[0]["content"])

    if graph_data:
        # Option 1: Test the graph prompt initialization
        print("\n--- Test: Graph Prompt Initialization ---")
        print(initialize_graph_prompt(graph_data))

        # Option 2: Simulate a conversation with a test question.
        # NOTE: get_assistant_response returns a single string; the original
        # code iterated over it, which yielded one *character* per loop turn.
        print("\n--- Test: Simulate Conversation ---")
        test_question = "Can you tell me about the connections between the nodes?"
        history.append(format_message("user", test_question))

        print(f"\nUser: {test_question}")
        response = get_assistant_response(test_question, json_file_path)
        print("Assistant:", response)
        history.append(format_message("assistant", response))

        # Option 3: Manually input questions for interactive testing
        while True:
            user_question = input("\nYou: ")
            if user_question.lower() in ["exit", "quit"]:
                print("Exiting chat. Goodbye!")
                break

            history.append(format_message("user", user_question))
            response = get_assistant_response(user_question, json_file_path)
            print("Assistant:", response)
            history.append(format_message("assistant", response))
    else:
        print("Unable to load graph data. Please check the file path and format.")
|
common.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from typing import List, Dict, Optional, Tuple, Union
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
from detection_schema import Line
|
| 6 |
+
|
| 7 |
+
@dataclass
class DetectionResult:
    """Common result container returned by the detection pipelines."""

    # True when the detection step completed without errors.
    success: bool
    # Human-readable error message when success is False.
    error: Optional[str] = None
    # Annotated copy of the input image, if one was produced.
    annotated_image: Optional[np.ndarray] = None
    # Elapsed processing time (presumably seconds — confirm with callers).
    processing_time: float = 0.0
    # Path to the saved JSON file with detection details, if any.
    json_path: Optional[str] = None
    # Path to the saved annotated image, if any.
    image_path: Optional[str] = None
|
config.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass, field
|
| 2 |
+
from typing import List, Dict, Optional, Tuple, Union
|
| 3 |
+
import numpy as np
|
| 4 |
+
from base import BaseConfig
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@dataclass
class ImageConfig(BaseConfig):
    """Configuration for global image-related settings"""

    # Region of interest as a 4-element array (presumably [xmin, ymin, xmax, ymax]
    # in pixels — confirm against the cropping code); None disables cropping.
    roi: Optional[np.ndarray] = field(default_factory=lambda: np.array([500, 500, 5300, 4000]))
    # JSON file holding text/symbol bounding boxes used for masking.
    mask_json_path: str = "./text_and_symbol_bboxes.json"
    # Whether annotated output images should be written.
    save_annotations: bool = True
    # Drawing style for annotations: color tuples (channel order depends on the
    # drawing backend — presumably OpenCV BGR), line thickness and font scale.
    annotation_style: Dict = field(default_factory=lambda: {
        'bbox_color': (255, 0, 0),
        'line_color': (0, 255, 0),
        'text_color': (0, 0, 255),
        'thickness': 2,
        'font_scale': 0.6
    })
|
| 21 |
+
|
| 22 |
+
@dataclass
class SymbolConfig(BaseConfig):
    """Configuration for Symbol Detection"""

    # Default detection model weights file.
    model_path: str = "models/symbol_detection.pt"
    # Minimum confidence required to keep a detection.
    confidence_threshold: float = 0.5
    # IoU threshold for non-maximum suppression.
    nms_threshold: float = 0.3
    # Smallest accepted symbol bounding box (width, height) — presumably pixels.
    min_size: Tuple[int, int] = (10, 10)
    # Largest accepted symbol bounding box (width, height) — presumably pixels.
    max_size: Tuple[int, int] = (200, 200)
    # Class labels; index 0 is background — NOTE(review): confirm this matches
    # the deployed model's label map.
    class_names: List[str] = field(default_factory=lambda: ["background", "valve", "pump", "sensor"])

    # Optional: Keep the multiple thresholds for experimentation
    confidence_thresholds: List[float] = field(default_factory=lambda: [0.1, 0.3, 0.5, 0.7, 0.9])
    # Alternative model checkpoints selectable at runtime.
    model_paths: Dict[str, str] = field(default_factory=lambda: {
        "model1": "models/Intui_SDM_41.pt",
        "model2": "models/Intui_SDM_20.pt"
    })
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
@dataclass
class TagConfig(BaseConfig):
    """Configuration for Tag Detection with OCR"""

    # Tag-detection model description file — presumably consumed by the tag
    # detector; confirm the expected format there.
    model_path: str = "models/tag_detection.json"
    # Minimum confidence required to keep a tag candidate.
    confidence_threshold: float = 0.5
    # IoU threshold used when suppressing overlapping tag boxes.
    iou_threshold: float = 0.4
    # OCR backends available for reading tag text.
    ocr_engines: List[str] = field(default_factory=lambda: ['tesseract', 'easyocr', 'doctr'])
    # Regex patterns used to classify recognised text into tag categories.
    text_patterns: Dict[str, str] = field(default_factory=lambda: {
        'Line_Number': r"\d{1,5}-[A-Z]{2,4}-\d{1,3}",
        'Equipment_Tag': r"[A-Z]{1,3}-[A-Z0-9]{1,4}-\d{1,3}",
        'Instrument_Tag': r"\d{2,3}-[A-Z]{2,4}-\d{2,3}",
        'Valve_Number': r"[A-Z]{1,2}-\d{3}",
        'Pipe_Size': r"\d{1,2}\"",
        'Flow_Direction': r"FROM|TO",
        'Service_Description': r"STEAM|WATER|AIR|GAS|DRAIN",
        'Process_Instrument': r"\d{2,3}(?:-[A-Z]{2,3})?-\d{2,3}|[A-Z]{2,3}-\d{2,3}",
        'Nozzle': r"N[0-9]{1,2}|MH",
        'Pipe_Connector': r"[0-9]{1,5}|[A-Z]{1,2}[0-9]{2,5}"
    })
    # Tesseract CLI options: LSTM engine (--oem 3), sparse-text page mode (--psm 11).
    tesseract_config: str = r'--oem 3 --psm 11'
    # Keyword arguments forwarded to EasyOCR — confirm at the call site.
    easyocr_params: Dict = field(default_factory=lambda: {
        'paragraph': False,
        'height_ths': 2.0,
        'width_ths': 2.0,
        'contrast_ths': 0.2
    })
|
| 67 |
+
|
| 68 |
+
@dataclass
class LineConfig(BaseConfig):
    """Configuration for Line Detection"""

    # Maximum distance (presumably pixels) for associating nearby line
    # endpoints — confirm in the line detector.
    threshold_distance: float = 10.0
    # Multiplier used to grow boxes/regions during line processing —
    # confirm usage in the line detector.
    expansion_factor: float = 1.1
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
@dataclass
class PointConfig(BaseConfig):
    """Configuration for Point Detection"""

    # Maximum distance (presumably pixels) used to match points —
    # confirm in the point detector.
    threshold_distance: float = 10.0
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
@dataclass
class JunctionConfig(BaseConfig):
    """Configuration for Junction Detection"""

    # Side length of the local analysis window (presumably pixels) —
    # confirm in the junction detector.
    window_size: int = 21
    # Neighbourhood radius around candidate junctions — confirm units.
    radius: int = 5
    # Lower bound of the accepted angle range (presumably degrees).
    angle_threshold_lb: float = 15.0
    # Upper bound of the accepted angle range (presumably degrees).
    angle_threshold_ub: float = 75.0
|
| 91 |
+
|
| 92 |
+
# @dataclass
|
| 93 |
+
# class JunctionConfig:
|
| 94 |
+
# radius: int = 5
|
| 95 |
+
# angle_threshold: float = 25.0
|
| 96 |
+
# colinear_threshold: float = 5.0
|
| 97 |
+
# connection_threshold: float = 5.0
|
data_aggregation_ai.py
ADDED
|
@@ -0,0 +1,411 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
import json
|
| 3 |
+
import logging
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
from typing import List, Dict, Optional, Tuple
|
| 6 |
+
from storage import StorageFactory
|
| 7 |
+
import uuid
|
| 8 |
+
import traceback
|
| 9 |
+
import os
|
| 10 |
+
import cv2
|
| 11 |
+
import numpy as np
|
| 12 |
+
|
| 13 |
+
logger = logging.getLogger(__name__)
|
| 14 |
+
|
| 15 |
+
class DataAggregator:
    """Combine symbol, text and line detection outputs into one dataset.

    Loads the per-detector JSON files, validates coordinates, builds a simple
    graph (nodes from symbols/texts, edges from lines) and renders an
    annotated overview image of all detections.
    """

    def __init__(self, storage=None):
        # Fall back to the configured default storage backend when none is injected.
        self.storage = storage or StorageFactory.get_storage()
        self.logger = logging.getLogger(__name__)

    def _parse_line_data(self, lines_data: dict) -> List[dict]:
        """Parse line detection data, skipping entries with invalid coordinates."""
        parsed_lines = []

        for line in lines_data.get("lines", []):
            try:
                # Extract and validate line coordinates
                start_coords = line["start"]["coords"]
                end_coords = line["end"]["coords"]
                bbox = line["bbox"]

                if not (self._is_valid_point(start_coords) and
                        self._is_valid_point(end_coords) and
                        self._is_valid_bbox(bbox)):
                    self.logger.warning(f"Invalid coordinates in line: {line['id']}")
                    continue

                # Normalise coordinates to ints and keep the original metadata.
                parsed_line = {
                    "id": line["id"],
                    "start_point": {
                        "x": int(start_coords["x"]),
                        "y": int(start_coords["y"]),
                        "type": line["start"]["type"],
                        "confidence": line["start"]["confidence"]
                    },
                    "end_point": {
                        "x": int(end_coords["x"]),
                        "y": int(end_coords["y"]),
                        "type": line["end"]["type"],
                        "confidence": line["end"]["confidence"]
                    },
                    "bbox": {
                        "xmin": int(bbox["xmin"]),
                        "ymin": int(bbox["ymin"]),
                        "xmax": int(bbox["xmax"]),
                        "ymax": int(bbox["ymax"])
                    },
                    "style": line["style"],
                    "confidence": line["confidence"]
                }
                parsed_lines.append(parsed_line)

            except Exception as e:
                # One malformed entry must not abort the whole parse.
                self.logger.error(f"Error parsing line {line.get('id')}: {str(e)}")
                continue

        return parsed_lines

    def _is_valid_point(self, point: dict) -> bool:
        """Return True when *point* has numeric x/y inside the accepted range."""
        try:
            x, y = point.get("x"), point.get("y")
            return (isinstance(x, (int, float)) and
                    isinstance(y, (int, float)) and
                    0 <= x <= 10000 and 0 <= y <= 10000)  # Adjust range as needed
        except Exception:  # was a bare except: don't swallow SystemExit/KeyboardInterrupt
            return False

    def _is_valid_bbox(self, bbox: dict) -> bool:
        """Return True when *bbox* has numeric, properly ordered, in-range corners."""
        try:
            xmin = bbox.get("xmin")
            ymin = bbox.get("ymin")
            xmax = bbox.get("xmax")
            ymax = bbox.get("ymax")

            return (isinstance(xmin, (int, float)) and
                    isinstance(ymin, (int, float)) and
                    isinstance(xmax, (int, float)) and
                    isinstance(ymax, (int, float)) and
                    xmin < xmax and ymin < ymax and
                    0 <= xmin <= 10000 and 0 <= ymin <= 10000 and
                    0 <= xmax <= 10000 and 0 <= ymax <= 10000)
        except Exception:  # was a bare except: don't swallow SystemExit/KeyboardInterrupt
            return False

    def _create_graph_data(self, lines: List[dict], symbols: List[dict], texts: List[dict]) -> Tuple[List[dict], List[dict]]:
        """Create graph nodes (symbols/texts) and edges (lines) from detections."""
        nodes = []
        edges = []

        # Debug input data
        self.logger.info("Creating graph data with:")
        self.logger.info(f"Lines: {len(lines)}")
        self.logger.info(f"Symbols: {len(symbols)}")
        self.logger.info(f"Texts: {len(texts)}")

        try:
            # Symbols become nodes; their bbox arrives as a list [x1, y1, x2, y2].
            for symbol in symbols:
                bbox = symbol["bbox"]
                nodes.append({
                    "id": symbol["symbol_id"],
                    "type": "symbol",
                    "category": symbol.get("category", ""),
                    "label": symbol.get("label", ""),
                    "confidence": symbol.get("confidence", 0.0),
                    "x": (bbox[0] + bbox[2]) / 2,  # bbox centre
                    "y": (bbox[1] + bbox[3]) / 2,
                    "bbox": {  # convert to dict format for consistency
                        "xmin": bbox[0],
                        "ymin": bbox[1],
                        "xmax": bbox[2],
                        "ymax": bbox[3]
                    }
                })

            # Texts also become nodes, with generated ids.
            for text in texts:
                bbox = text["bbox"]
                nodes.append({
                    "id": str(uuid.uuid4()),
                    "type": "text",
                    "content": text.get("text", ""),
                    "confidence": text.get("confidence", 0.0),
                    "x": (bbox[0] + bbox[2]) / 2,
                    "y": (bbox[1] + bbox[3]) / 2,
                    "bbox": {
                        "xmin": bbox[0],
                        "ymin": bbox[1],
                        "xmax": bbox[2],
                        "ymax": bbox[3]
                    }
                })

            # Lines become edges.
            for line in lines:
                edges.append({
                    "id": str(uuid.uuid4()),
                    "type": "line",
                    "start_point": line["start_point"],
                    "end_point": line["end_point"]
                })

        except Exception as e:
            self.logger.error(f"Error processing data: {str(e)}")
            # Log whichever item was being processed. The original expression
            # `symbol if 'symbol' in locals() else text` could itself raise a
            # NameError when neither loop had started; this form cannot.
            current = locals().get('symbol', locals().get('text', locals().get('line')))
            self.logger.error("Current symbol/text being processed: %s",
                              json.dumps(current, indent=2, default=str))
            raise

        return nodes, edges

    def _validate_coordinates(self, data, data_type):
        """Validate coordinate fields of a line/symbol/text record.

        Args:
            data: The record to check.
            data_type: One of 'line', 'symbol' or 'text'.

        Returns:
            True when all required fields are present and consistent.
        """
        if not data:
            return False

        try:
            if data_type == 'line':
                # Check start and end points
                start = data.get('start_point', {})
                end = data.get('end_point', {})
                bbox = data.get('bbox', {})

                required_fields = ['x', 'y', 'type']
                if not all(field in start for field in required_fields):
                    self.logger.warning(f"Missing required fields in start_point: {start}")
                    return False
                if not all(field in end for field in required_fields):
                    self.logger.warning(f"Missing required fields in end_point: {end}")
                    return False

                # Validate bbox coordinates
                if not all(key in bbox for key in ['xmin', 'ymin', 'xmax', 'ymax']):
                    self.logger.warning(f"Invalid bbox format: {bbox}")
                    return False

                # Check coordinate consistency
                if bbox['xmin'] > bbox['xmax'] or bbox['ymin'] > bbox['ymax']:
                    self.logger.warning(f"Invalid bbox coordinates: {bbox}")
                    return False

            elif data_type in ['symbol', 'text']:
                bbox = data.get('bbox', {})
                if not all(key in bbox for key in ['xmin', 'ymin', 'xmax', 'ymax']):
                    self.logger.warning(f"Invalid {data_type} bbox format: {bbox}")
                    return False

                # Check coordinate consistency
                if bbox['xmin'] > bbox['xmax'] or bbox['ymin'] > bbox['ymax']:
                    self.logger.warning(f"Invalid {data_type} bbox coordinates: {bbox}")
                    return False

            return True

        except Exception as e:
            self.logger.error(f"Validation error for {data_type}: {str(e)}")
            return False

    def aggregate_data(self, symbols_path: str, texts_path: str, lines_path: str) -> dict:
        """Aggregate detection results and create the graph structure.

        Returns:
            A dict with the raw detections ('lines', 'symbols', 'texts'),
            derived 'nodes'/'edges' and a 'metadata' block.
        """
        try:
            # Line detections come through the storage backend.
            lines_data = json.loads(self.storage.load_file(lines_path).decode('utf-8'))
            lines = self._parse_line_data(lines_data)
            self.logger.info(f"Loaded {len(lines)} lines")

            # Symbol detections (optional input).
            symbols = []
            if symbols_path and Path(symbols_path).exists():
                with open(symbols_path, 'r') as f:
                    symbols_data = json.load(f)
                self.logger.info("Symbol data keys: %s", list(symbols_data.keys()))

                symbols = symbols_data.get("detections", [])
                self.logger.info(f"Loaded {len(symbols)} symbols from {symbols_path}")
                # Guard the debug dumps: the original indexed [0] unconditionally
                # and crashed on empty or missing detection lists.
                if symbols:
                    self.logger.info("First symbol in detections: %s",
                                     json.dumps(symbols[0], indent=2))
                    self.logger.info("First symbol keys: %s", list(symbols[0].keys()))
                    self.logger.info("First symbol bbox: %s", symbols[0]["bbox"])

            # Text detections (optional input).
            texts = []
            if texts_path and Path(texts_path).exists():
                with open(texts_path, 'r') as f:
                    texts_data = json.load(f)
                self.logger.info("Text data keys: %s", list(texts_data.keys()))

                texts = texts_data.get("detections", [])
                self.logger.info(f"Loaded {len(texts)} texts from {texts_path}")
                if texts:
                    self.logger.info("First text in detections: %s",
                                     json.dumps(texts[0], indent=2))
                    self.logger.info("First text keys: %s", list(texts[0].keys()))
                    self.logger.info("First text bbox: %s", texts[0]["bbox"])

            # Create graph data
            nodes, edges = self._create_graph_data(lines, symbols, texts)
            self.logger.info(f"Created graph with {len(nodes)} nodes and {len(edges)} edges")

            return {
                "lines": lines,
                "symbols": symbols,
                "texts": texts,
                "nodes": nodes,
                "edges": edges,
                "metadata": {
                    "timestamp": datetime.now().isoformat(),
                    "version": "2.0"
                }
            }
        except Exception as e:
            self.logger.error(f"Error during aggregation: {str(e)}")
            self.logger.error("Stack trace:", exc_info=True)
            raise

    def _draw_aggregated_view(self, image: np.ndarray, results: dict) -> np.ndarray:
        """Draw all detections onto a copy of *image* and return it.

        Lines are green, symbol boxes cyan-ish (255,255,0) and text boxes
        purple; invalid entries are skipped with a warning.
        """
        annotated = image.copy()

        # Draw lines (green)
        for line in results.get('lines', []):
            try:
                cv2.line(annotated,
                         (line['start_point']['x'], line['start_point']['y']),
                         (line['end_point']['x'], line['end_point']['y']),
                         (0, 255, 0), 2)
            except Exception as e:
                self.logger.warning(f"Skipping invalid line: {str(e)}")
                continue

        # Draw symbols — bbox is a list [x1, y1, x2, y2], not a dict.
        for symbol in results.get('symbols', []):
            try:
                bbox = symbol['bbox']
                cv2.rectangle(annotated,
                              (bbox[0], bbox[1]),
                              (bbox[2], bbox[3]),
                              (255, 255, 0), 2)
            except Exception as e:
                self.logger.warning(f"Skipping invalid symbol: {str(e)}")
                continue

        # Draw texts (purple) — bbox is a list [x1, y1, x2, y2], not a dict.
        for text in results.get('texts', []):
            try:
                bbox = text['bbox']
                cv2.rectangle(annotated,
                              (bbox[0], bbox[1]),
                              (bbox[2], bbox[3]),
                              (128, 0, 128), 2)
            except Exception as e:
                self.logger.warning(f"Skipping invalid text: {str(e)}")
                continue

        return annotated

    def process_data(self, image_path: str, output_dir: str, symbols_path: str, texts_path: str, lines_path: str):
        """Run the full aggregation: JSON output plus annotated image.

        Returns:
            {'success': True, 'image_path': ..., 'json_path': ...} on success,
            {'success': False, 'error': ...} on failure.
        """
        try:
            self.logger.info("Processing data with:")
            self.logger.info(f"- Image: {image_path}")
            self.logger.info(f"- Symbols: {symbols_path}")
            self.logger.info(f"- Texts: {texts_path}")
            self.logger.info(f"- Lines: {lines_path}")

            base_name = Path(image_path).stem
            self.logger.info(f"Base name: {base_name}")

            aggregated_json = os.path.join(output_dir, f"{base_name}_aggregated.json")
            self.logger.info(f"Will save aggregated data to: {aggregated_json}")

            results = self.aggregate_data(symbols_path, texts_path, lines_path)
            self.logger.info("Data aggregation completed")

            with open(aggregated_json, 'w') as f:
                json.dump(results, f, indent=2)
            self.logger.info(f"Saved aggregated JSON to: {aggregated_json}")

            # Create visualization using original image
            image = cv2.imread(image_path)
            annotated = self._draw_aggregated_view(image, results)
            aggregated_image = os.path.join(output_dir, f"{base_name}_aggregated.png")
            cv2.imwrite(aggregated_image, annotated)

            # Return paths like other detectors
            return {
                'success': True,
                'image_path': aggregated_image,
                'json_path': aggregated_json
            }

        except Exception as e:
            self.logger.error(f"Error in data aggregation: {str(e)}")
            return {
                'success': False,
                'error': str(e)
            }
|
| 356 |
+
|
| 357 |
+
if __name__ == "__main__":
    # Manual smoke test against the sample detector outputs in ./results.
    # (Removed the redundant local `import os` — os is already imported at
    # module level — and the unused `from pprint import pprint`.)

    # Initialize the aggregator
    aggregator = DataAggregator()

    # Test paths using actual files in results folder
    results_dir = "results"
    base_name = "002_page_1"

    # Input paths
    symbols_path = os.path.join(results_dir, f"{base_name}_detected_symbols.json")
    texts_path = os.path.join(results_dir, f"{base_name}_detected_texts.json")
    lines_path = os.path.join(results_dir, f"{base_name}_detected_lines.json")

    # Verify files exist
    print("\nChecking input files:")
    print(f"Symbols file exists: {os.path.exists(symbols_path)}")
    print(f"Texts file exists: {os.path.exists(texts_path)}")
    print(f"Lines file exists: {os.path.exists(lines_path)}")

    try:
        # Process the data
        print("\nProcessing data...")
        result = aggregator.process_data(
            image_path=os.path.join(results_dir, f"{base_name}.png"),
            output_dir=results_dir,
            symbols_path=symbols_path,
            texts_path=texts_path,
            lines_path=lines_path
        )

        # Verify output files
        aggregated_json = os.path.join(results_dir, f"{base_name}_aggregated.json")
        aggregated_image = os.path.join(results_dir, f"{base_name}_aggregated.png")

        print("\nChecking output files:")
        print(f"Aggregated JSON exists: {os.path.exists(aggregated_json)}")
        print(f"Aggregated image exists: {os.path.exists(aggregated_image)}")

        # Load and print statistics from the aggregated result
        if os.path.exists(aggregated_json):
            with open(aggregated_json, 'r') as f:
                data = json.load(f)
            print("\nAggregation Results:")
            print(f"Number of Symbols: {len(data.get('symbols', []))}")
            print(f"Number of Texts: {len(data.get('texts', []))}")
            print(f"Number of Lines: {len(data.get('lines', []))}")
            print(f"Number of Nodes: {len(data.get('nodes', []))}")
            print(f"Number of Edges: {len(data.get('edges', []))}")

    except Exception as e:
        print(f"\nError during testing: {str(e)}")
        traceback.print_exc()
|
detection_schema.py
ADDED
|
@@ -0,0 +1,481 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass, field
|
| 2 |
+
from typing import List, Optional, Tuple, Dict
|
| 3 |
+
import uuid
|
| 4 |
+
from enum import Enum
|
| 5 |
+
import json
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
# ======================== Point ======================== #
|
| 9 |
+
class ConnectionType(Enum):
    """Semantic/visual style of a connection line in the drawing."""
    SOLID = "solid"      # continuous line
    DASHED = "dashed"    # dashed line
    PHANTOM = "phantom"  # phantom/hidden line
|
| 13 |
+
|
| 14 |
+
@dataclass
class Coordinates:
    """A 2-D integer coordinate (pixel position in the source image)."""
    x: int
    y: int
|
| 18 |
+
|
| 19 |
+
@dataclass
class BBox:
    """Axis-aligned bounding box defined by its min/max corner coordinates."""
    xmin: int
    ymin: int
    xmax: int
    ymax: int

    def width(self) -> int:
        """Return the horizontal extent (xmax - xmin)."""
        return self.xmax - self.xmin

    def height(self) -> int:
        """Return the vertical extent (ymax - ymin)."""
        return self.ymax - self.ymin
|
| 31 |
+
|
| 32 |
+
class JunctionType(str, Enum):
    """Shape classification of a line junction / endpoint.

    NOTE(review): this class is re-defined verbatim later in this module
    (in the "Junction" section); the later definition shadows this one.
    Safe only because both copies are identical -- consider removing one.
    """
    T = "T"      # T-shaped junction
    L = "L"      # L-shaped junction (corner)
    END = "END"  # free line endpoint
|
| 36 |
+
|
| 37 |
+
@dataclass
class Point:
    """A detected line endpoint/junction point with its location and local bbox."""
    coords: Coordinates  # point location
    bbox: BBox           # small detection region around the point
    type: JunctionType   # junction shape at this point (T / L / END)
    confidence: float = 1.0  # detector confidence -- presumably in [0, 1]; confirm
    id: str = field(default_factory=lambda: str(uuid.uuid4()))  # unique identifier
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# # ======================== Symbol ======================== #
|
| 47 |
+
# class SymbolType(Enum):
|
| 48 |
+
# VALVE = "valve"
|
| 49 |
+
# PUMP = "pump"
|
| 50 |
+
# SENSOR = "sensor"
|
| 51 |
+
# # Add others as needed
|
| 52 |
+
#
|
| 53 |
+
class ValveSubtype(Enum):
    """Finer-grained classification for valve symbols."""
    GATE = "gate"
    GLOBE = "globe"
    BUTTERFLY = "butterfly"
|
| 57 |
+
#
|
| 58 |
+
# @dataclass
|
| 59 |
+
# class Symbol:
|
| 60 |
+
# symbol_type: SymbolType
|
| 61 |
+
# bbox: BBox
|
| 62 |
+
# center: Coordinates
|
| 63 |
+
# connections: List[Point] = field(default_factory=list)
|
| 64 |
+
# subtype: Optional[ValveSubtype] = None
|
| 65 |
+
# id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
| 66 |
+
# confidence: float = 0.95
|
| 67 |
+
# model_metadata: dict = field(default_factory=dict)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# ======================== Symbol ======================== #
|
| 71 |
+
class SymbolType(Enum):
    """Coarse category of a detected diagram symbol."""
    VALVE = "valve"
    PUMP = "pump"
    SENSOR = "sensor"
    OTHER = "other"  # Added to handle unknown categories
|
| 76 |
+
|
| 77 |
+
@dataclass
class Symbol:
    """A detected diagram symbol with its classification, location and
    connection topology.

    Only ``center`` is required; all other fields default so that partially
    populated detector output can still be represented.
    """
    center: Coordinates  # symbol centre point
    symbol_type: SymbolType = field(default=SymbolType.OTHER)
    id: str = field(default_factory=lambda: str(uuid.uuid4()))
    class_id: int = -1        # raw model class index (-1 = unknown)
    original_label: str = ""  # label string exactly as emitted by the model
    category: str = ""  # e.g., "inst"
    type: str = ""  # e.g., "ind"
    label: str = ""  # e.g., "Solenoid_actuator"
    bbox: Optional[BBox] = None  # may also be supplied as [xmin, ymin, xmax, ymax]
    confidence: float = 0.95
    model_source: str = ""  # e.g., "model2"
    connections: List[Point] = field(default_factory=list)  # attachment points
    subtype: Optional[ValveSubtype] = None  # presumably meaningful only for valves; confirm
    model_metadata: dict = field(default_factory=dict)  # free-form per-model extras

    def __post_init__(self):
        """
        Handle any additional post-processing after initialization.

        Coerces a 4-element list bbox into a BBox instance so callers may
        pass either representation.
        """
        # Ensure bbox is a BBox object
        if isinstance(self.bbox, list) and len(self.bbox) == 4:
            self.bbox = BBox(*self.bbox)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
# ======================== Line ======================== #
|
| 104 |
+
@dataclass
class LineStyle:
    """Visual rendering style of a detected line."""
    connection_type: ConnectionType
    stroke_width: int = 2
    color: str = "#000000"  # CSS-style colors
|
| 109 |
+
|
| 110 |
+
@dataclass
class Line:
    """A detected line segment between two endpoint Points."""
    start: Point
    end: Point
    bbox: BBox  # bounding box of the whole segment
    id: str = field(default_factory=lambda: str(uuid.uuid4()))
    style: LineStyle = field(default_factory=lambda: LineStyle(ConnectionType.SOLID))
    confidence: float = 0.90
    topological_links: List[str] = field(default_factory=list)  # Linked symbols/junctions (element ids)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
# ======================== Junction ======================== #
|
| 122 |
+
class JunctionType(str, Enum):
    """Shape classification of a line junction / endpoint.

    NOTE(review): duplicate of the identical JunctionType declared earlier in
    this module; this re-definition shadows the first. Harmless only because
    the two are byte-identical -- consider deleting one copy.
    """
    T = "T"
    L = "L"
    END = "END"
|
| 126 |
+
|
| 127 |
+
@dataclass
class JunctionProperties:
    """Optional physical metadata attached to a junction."""
    flow_direction: Optional[str] = None  # "in", "out"
    pressure: Optional[float] = None  # kPa
|
| 131 |
+
|
| 132 |
+
@dataclass
class Junction:
    """A point where detected lines meet."""
    center: Coordinates
    junction_type: JunctionType
    id: str = field(default_factory=lambda: str(uuid.uuid4()))
    properties: JunctionProperties = field(default_factory=JunctionProperties)
    connected_lines: List[str] = field(default_factory=list)  # Line IDs
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
# # ======================== Tag ======================== #
|
| 142 |
+
# @dataclass
|
| 143 |
+
# class Tag:
|
| 144 |
+
# text: str
|
| 145 |
+
# bbox: BBox
|
| 146 |
+
# associated_element: str # ID of linked symbol/line
|
| 147 |
+
# id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
| 148 |
+
# font_size: int = 12
|
| 149 |
+
# rotation: float = 0.0 # Degrees
|
| 150 |
+
|
| 151 |
+
@dataclass
class Tag:
    """A piece of detected text (OCR output), optionally linked to another element."""
    text: str  # recognized text content
    bbox: BBox  # may also be supplied as [xmin, ymin, xmax, ymax]
    confidence: float = 1.0
    source: str = ""  # e.g., "easyocr"
    text_type: str = "Unknown"  # e.g., "Unknown", could be something else later
    id: str = field(default_factory=lambda: str(uuid.uuid4()))
    associated_element: Optional[str] = None  # ID of linked symbol/line (can be None)
    font_size: int = 12
    rotation: float = 0.0  # Degrees

    def __post_init__(self):
        """
        Ensure bbox is properly converted.

        Coerces a 4-element list into a BBox instance so callers may pass
        either representation.
        """
        if isinstance(self.bbox, list) and len(self.bbox) == 4:
            self.bbox = BBox(*self.bbox)
|
| 169 |
+
|
| 170 |
+
# ----------------------------
|
| 171 |
+
# DETECTION CONTEXT
|
| 172 |
+
# ----------------------------
|
| 173 |
+
|
| 174 |
+
@dataclass
class DetectionContext:
    """
    In-memory container for all detected elements (lines, points, symbols,
    junctions, tags).

    Each element is stored in a dict keyed by its ``id`` for O(1) lookup,
    update and removal. The class also provides lossless (de)serialization
    helpers (``to_dict``/``from_dict``, ``to_json``/``from_json``).
    """
    lines: Dict[str, Line] = field(default_factory=dict)
    points: Dict[str, Point] = field(default_factory=dict)
    symbols: Dict[str, Symbol] = field(default_factory=dict)
    junctions: Dict[str, Junction] = field(default_factory=dict)
    tags: Dict[str, Tag] = field(default_factory=dict)

    # -------------------------
    # 1) ADD / GET / REMOVE
    # -------------------------
    def add_line(self, line: Line) -> None:
        """Insert (or replace) a line, keyed by its id."""
        self.lines[line.id] = line

    def get_line(self, line_id: str) -> Optional[Line]:
        """Return the line with this id, or None if absent."""
        return self.lines.get(line_id)

    def remove_line(self, line_id: str) -> None:
        """Remove the line with this id; silently ignore unknown ids."""
        self.lines.pop(line_id, None)

    def add_point(self, point: Point) -> None:
        """Insert (or replace) a point, keyed by its id."""
        self.points[point.id] = point

    def get_point(self, point_id: str) -> Optional[Point]:
        """Return the point with this id, or None if absent."""
        return self.points.get(point_id)

    def remove_point(self, point_id: str) -> None:
        """Remove the point with this id; silently ignore unknown ids."""
        self.points.pop(point_id, None)

    def add_symbol(self, symbol: Symbol) -> None:
        """Insert (or replace) a symbol, keyed by its id."""
        self.symbols[symbol.id] = symbol

    def get_symbol(self, symbol_id: str) -> Optional[Symbol]:
        """Return the symbol with this id, or None if absent."""
        return self.symbols.get(symbol_id)

    def remove_symbol(self, symbol_id: str) -> None:
        """Remove the symbol with this id; silently ignore unknown ids."""
        self.symbols.pop(symbol_id, None)

    def add_junction(self, junction: Junction) -> None:
        """Insert (or replace) a junction, keyed by its id."""
        self.junctions[junction.id] = junction

    def get_junction(self, junction_id: str) -> Optional[Junction]:
        """Return the junction with this id, or None if absent."""
        return self.junctions.get(junction_id)

    def remove_junction(self, junction_id: str) -> None:
        """Remove the junction with this id; silently ignore unknown ids."""
        self.junctions.pop(junction_id, None)

    def add_tag(self, tag: Tag) -> None:
        """Insert (or replace) a tag, keyed by its id."""
        self.tags[tag.id] = tag

    def get_tag(self, tag_id: str) -> Optional[Tag]:
        """Return the tag with this id, or None if absent."""
        return self.tags.get(tag_id)

    def remove_tag(self, tag_id: str) -> None:
        """Remove the tag with this id; silently ignore unknown ids."""
        self.tags.pop(tag_id, None)

    # -------------------------
    # 2) SERIALIZATION: to_dict / from_dict
    # -------------------------
    def to_dict(self) -> dict:
        """Convert all stored objects into a JSON-serializable dictionary."""
        return {
            "lines": [self._line_to_dict(line) for line in self.lines.values()],
            "points": [self._point_to_dict(pt) for pt in self.points.values()],
            "symbols": [self._symbol_to_dict(sym) for sym in self.symbols.values()],
            "junctions": [self._junction_to_dict(jn) for jn in self.junctions.values()],
            "tags": [self._tag_to_dict(tg) for tg in self.tags.values()]
        }

    @classmethod
    def from_dict(cls, data: dict) -> "DetectionContext":
        """
        Create a new DetectionContext from a dictionary structure
        (e.g. loaded from JSON). Missing sections are treated as empty.
        """
        context = cls()

        # Points first: lines embed their own point dicts, but standalone
        # points must also be restored.
        for pt_dict in data.get("points", []):
            context.add_point(cls._point_from_dict(pt_dict))

        for ln_dict in data.get("lines", []):
            context.add_line(cls._line_from_dict(ln_dict))

        for sym_dict in data.get("symbols", []):
            context.add_symbol(cls._symbol_from_dict(sym_dict))

        for jn_dict in data.get("junctions", []):
            context.add_junction(cls._junction_from_dict(jn_dict))

        for tg_dict in data.get("tags", []):
            context.add_tag(cls._tag_from_dict(tg_dict))

        return context

    # -------------------------
    # 3) HELPER METHODS FOR (DE)SERIALIZATION
    # -------------------------
    @staticmethod
    def _bbox_to_dict(bbox: BBox) -> dict:
        """Serialize a BBox to a plain dict."""
        return {
            "xmin": bbox.xmin,
            "ymin": bbox.ymin,
            "xmax": bbox.xmax,
            "ymax": bbox.ymax
        }

    @staticmethod
    def _bbox_from_dict(d: dict) -> BBox:
        """Rebuild a BBox from its dict form."""
        return BBox(
            xmin=d["xmin"],
            ymin=d["ymin"],
            xmax=d["xmax"],
            ymax=d["ymax"]
        )

    @staticmethod
    def _coords_to_dict(coords: Coordinates) -> dict:
        """Serialize a Coordinates point to a plain dict."""
        return {"x": coords.x, "y": coords.y}

    @staticmethod
    def _coords_from_dict(d: dict) -> Coordinates:
        """Rebuild a Coordinates point from its dict form."""
        return Coordinates(x=d["x"], y=d["y"])

    @staticmethod
    def _line_style_to_dict(style: LineStyle) -> dict:
        """Serialize a LineStyle (enum stored as its string value)."""
        return {
            "connection_type": style.connection_type.value,
            "stroke_width": style.stroke_width,
            "color": style.color
        }

    @staticmethod
    def _line_style_from_dict(d: dict) -> LineStyle:
        """Rebuild a LineStyle; missing optional fields fall back to defaults."""
        return LineStyle(
            connection_type=ConnectionType(d["connection_type"]),
            stroke_width=d.get("stroke_width", 2),
            color=d.get("color", "#000000")
        )

    @staticmethod
    def _point_to_dict(pt: Point) -> dict:
        """Serialize a Point to a plain dict."""
        return {
            "id": pt.id,
            "coords": DetectionContext._coords_to_dict(pt.coords),
            "bbox": DetectionContext._bbox_to_dict(pt.bbox),
            "type": pt.type.value,
            "confidence": pt.confidence
        }

    @staticmethod
    def _point_from_dict(d: dict) -> Point:
        """Rebuild a Point from its dict form."""
        return Point(
            id=d["id"],
            coords=DetectionContext._coords_from_dict(d["coords"]),
            bbox=DetectionContext._bbox_from_dict(d["bbox"]),
            type=JunctionType(d["type"]),
            confidence=d.get("confidence", 1.0)
        )

    @staticmethod
    def _line_to_dict(ln: Line) -> dict:
        """Serialize a Line, embedding its endpoint Points."""
        return {
            "id": ln.id,
            "start": DetectionContext._point_to_dict(ln.start),
            "end": DetectionContext._point_to_dict(ln.end),
            "bbox": DetectionContext._bbox_to_dict(ln.bbox),
            "style": DetectionContext._line_style_to_dict(ln.style),
            "confidence": ln.confidence,
            "topological_links": ln.topological_links
        }

    @staticmethod
    def _line_from_dict(d: dict) -> Line:
        """Rebuild a Line (and its embedded endpoint Points) from dict form."""
        return Line(
            id=d["id"],
            start=DetectionContext._point_from_dict(d["start"]),
            end=DetectionContext._point_from_dict(d["end"]),
            bbox=DetectionContext._bbox_from_dict(d["bbox"]),
            style=DetectionContext._line_style_from_dict(d["style"]),
            confidence=d.get("confidence", 0.90),
            topological_links=d.get("topological_links", [])
        )

    @staticmethod
    def _symbol_to_dict(sym: Symbol) -> dict:
        """Serialize a Symbol.

        Fix: includes the model-provenance fields (class_id, original_label,
        category, type, label, model_source) that the previous version
        silently dropped, so to_dict/from_dict round-trips losslessly.
        Also tolerates ``bbox is None`` (a legal Symbol default).
        """
        return {
            "id": sym.id,
            "symbol_type": sym.symbol_type.value,
            "class_id": sym.class_id,
            "original_label": sym.original_label,
            "category": sym.category,
            "type": sym.type,
            "label": sym.label,
            "model_source": sym.model_source,
            "bbox": DetectionContext._bbox_to_dict(sym.bbox) if sym.bbox is not None else None,
            "center": DetectionContext._coords_to_dict(sym.center),
            "connections": [DetectionContext._point_to_dict(p) for p in sym.connections],
            "subtype": sym.subtype.value if sym.subtype else None,
            "confidence": sym.confidence,
            "model_metadata": sym.model_metadata
        }

    @staticmethod
    def _symbol_from_dict(d: dict) -> Symbol:
        """Rebuild a Symbol; unknown/missing optional keys fall back to the
        dataclass defaults so older JSON files still load."""
        bbox_dict = d.get("bbox")
        return Symbol(
            id=d["id"],
            symbol_type=SymbolType(d["symbol_type"]),
            class_id=d.get("class_id", -1),
            original_label=d.get("original_label", ""),
            category=d.get("category", ""),
            type=d.get("type", ""),
            label=d.get("label", ""),
            model_source=d.get("model_source", ""),
            bbox=DetectionContext._bbox_from_dict(bbox_dict) if bbox_dict else None,
            center=DetectionContext._coords_from_dict(d["center"]),
            connections=[DetectionContext._point_from_dict(p) for p in d.get("connections", [])],
            subtype=ValveSubtype(d["subtype"]) if d.get("subtype") else None,
            confidence=d.get("confidence", 0.95),
            model_metadata=d.get("model_metadata", {})
        )

    @staticmethod
    def _junction_props_to_dict(props: JunctionProperties) -> dict:
        """Serialize JunctionProperties to a plain dict."""
        return {
            "flow_direction": props.flow_direction,
            "pressure": props.pressure
        }

    @staticmethod
    def _junction_props_from_dict(d: dict) -> JunctionProperties:
        """Rebuild JunctionProperties; both fields are optional."""
        return JunctionProperties(
            flow_direction=d.get("flow_direction"),
            pressure=d.get("pressure")
        )

    @staticmethod
    def _junction_to_dict(jn: Junction) -> dict:
        """Serialize a Junction to a plain dict."""
        return {
            "id": jn.id,
            "center": DetectionContext._coords_to_dict(jn.center),
            "junction_type": jn.junction_type.value,
            "properties": DetectionContext._junction_props_to_dict(jn.properties),
            "connected_lines": jn.connected_lines
        }

    @staticmethod
    def _junction_from_dict(d: dict) -> Junction:
        """Rebuild a Junction from its dict form."""
        return Junction(
            id=d["id"],
            center=DetectionContext._coords_from_dict(d["center"]),
            junction_type=JunctionType(d["junction_type"]),
            properties=DetectionContext._junction_props_from_dict(d["properties"]),
            connected_lines=d.get("connected_lines", [])
        )

    @staticmethod
    def _tag_to_dict(tg: Tag) -> dict:
        """Serialize a Tag.

        Fix: includes confidence/source/text_type, which the previous version
        dropped, making tag round-trips lossy.
        """
        return {
            "id": tg.id,
            "text": tg.text,
            "bbox": DetectionContext._bbox_to_dict(tg.bbox),
            "confidence": tg.confidence,
            "source": tg.source,
            "text_type": tg.text_type,
            "associated_element": tg.associated_element,
            "font_size": tg.font_size,
            "rotation": tg.rotation
        }

    @staticmethod
    def _tag_from_dict(d: dict) -> Tag:
        """Rebuild a Tag.

        Fix: ``associated_element`` is Optional, so a missing key no longer
        raises KeyError; all optional fields use the dataclass defaults.
        """
        return Tag(
            id=d["id"],
            text=d["text"],
            bbox=DetectionContext._bbox_from_dict(d["bbox"]),
            confidence=d.get("confidence", 1.0),
            source=d.get("source", ""),
            text_type=d.get("text_type", "Unknown"),
            associated_element=d.get("associated_element"),
            font_size=d.get("font_size", 12),
            rotation=d.get("rotation", 0.0)
        )

    # -------------------------
    # 4) OPTIONAL UTILS
    # -------------------------
    def to_json(self, indent: int = 2) -> str:
        """Convert context to JSON, ensuring dataclasses and numpy types are handled correctly."""
        return json.dumps(self.to_dict(), default=self._json_serializer, indent=indent)

    @staticmethod
    def _json_serializer(obj):
        """Handles numpy types and unknown objects for JSON serialization."""
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()  # Convert arrays to lists
        if isinstance(obj, Enum):
            return obj.value  # Convert Enums to string values
        if hasattr(obj, "__dict__"):
            return obj.__dict__  # Convert dataclass objects to dict
        raise TypeError(f"Object of type {type(obj)} is not JSON serializable")

    @classmethod
    def from_json(cls, json_str: str) -> "DetectionContext":
        """Load DetectionContext from a JSON string."""
        data = json.loads(json_str)
        return cls.from_dict(data)
|
detection_utils.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from typing import List, Tuple
|
| 3 |
+
import math
|
| 4 |
+
|
| 5 |
+
def robust_merge_lines(lines: List[Tuple[float, float, float, float]],
                       angle_thresh: float = 5.0,
                       dist_thresh: float = 5.0) -> List[Tuple[float, float, float, float]]:
    """
    Merge near-collinear, near-touching line segments into single segments.

    Segments are grouped when their orientations differ by at most
    ``angle_thresh`` degrees and at least one pairing of endpoints is within
    ``dist_thresh`` pixels. Each group is replaced by the segment spanning
    the two extreme endpoints along the group's direction.

    Fix vs. previous version: the endpoint-proximity test now considers all
    four endpoint pairings (start/start, end/end, start/end, end/start), so
    segments stored in opposite directions (e.g. (0,0,10,0) and (12,0,2,0))
    can merge; previously only corresponding endpoints were compared.

    Args:
        lines: List of line segments [(x1, y1, x2, y2), ...]
        angle_thresh: Maximum orientation difference in degrees.
        dist_thresh: Maximum endpoint distance in pixels.

    Returns:
        List of merged line segments as 4-tuples of floats.
    """
    if not lines:
        return []

    # Work on a float copy; avoid shadowing the input parameter.
    segments = np.asarray(lines, dtype=float)

    # Orientation of each segment, folded into [0, 180) degrees so that
    # direction (reversed segments) does not matter.
    angles = np.degrees(np.arctan2(segments[:, 3] - segments[:, 1],
                                   segments[:, 2] - segments[:, 0])) % 180

    merged: List[Tuple[float, float, float, float]] = []
    used = set()

    for i in range(len(segments)):
        if i in used:
            continue

        # Collect all not-yet-consumed segments similar to segment i
        # (i always matches itself, so every segment ends up in some group).
        similar = []
        for j in range(len(segments)):
            if j in used:
                continue

            # Angular difference, accounting for the 180-degree wrap-around.
            angle_diff = abs(angles[i] - angles[j])
            angle_diff = min(angle_diff, 180 - angle_diff)
            if angle_diff > angle_thresh:
                continue

            # Endpoint proximity over all four endpoint pairings.
            d_ss = np.linalg.norm(segments[i, :2] - segments[j, :2])
            d_ee = np.linalg.norm(segments[i, 2:] - segments[j, 2:])
            d_se = np.linalg.norm(segments[i, :2] - segments[j, 2:])
            d_es = np.linalg.norm(segments[i, 2:] - segments[j, :2])
            if min(d_ss, d_ee, d_se, d_es) > dist_thresh:
                continue

            similar.append(j)
            used.add(j)

        if similar:
            # Project every endpoint of the group onto the direction of
            # segment i and span the two extremes.
            points = segments[similar].reshape(-1, 2)
            direction = np.array([np.cos(np.radians(angles[i])),
                                  np.sin(np.radians(angles[i]))])
            proj = points @ direction
            merged_line = np.concatenate([points[np.argmin(proj)],
                                          points[np.argmax(proj)]])
            merged.append(tuple(merged_line))

    return merged
|
| 78 |
+
|
| 79 |
+
def compute_line_angle(x1: float, y1: float, x2: float, y2: float) -> float:
    """Return the orientation of segment (x1, y1)-(x2, y2) in degrees,
    normalized to the half-open range [0, 180)."""
    delta_y = y2 - y1
    delta_x = x2 - x1
    angle_deg = math.degrees(math.atan2(delta_y, delta_x))
    return angle_deg % 180
|
detectors.py
ADDED
|
@@ -0,0 +1,733 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import math
|
| 3 |
+
import torch
|
| 4 |
+
import cv2
|
| 5 |
+
import numpy as np
|
| 6 |
+
from typing import List, Optional, Tuple, Dict
|
| 7 |
+
from dataclasses import replace
|
| 8 |
+
from math import sqrt
|
| 9 |
+
import json
|
| 10 |
+
import uuid
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
|
| 13 |
+
# Base classes and utilities
|
| 14 |
+
from base import BaseDetector
|
| 15 |
+
from detection_schema import DetectionContext
|
| 16 |
+
from utils import DebugHandler
|
| 17 |
+
from config import SymbolConfig, TagConfig, LineConfig, PointConfig, JunctionConfig
|
| 18 |
+
|
| 19 |
+
# DeepLSD model for line detection
|
| 20 |
+
from deeplsd.models.deeplsd_inference import DeepLSD
|
| 21 |
+
from ultralytics import YOLO
|
| 22 |
+
|
| 23 |
+
# Detection schema: dataclasses for different objects
|
| 24 |
+
from detection_schema import (
|
| 25 |
+
BBox,
|
| 26 |
+
Coordinates,
|
| 27 |
+
Point,
|
| 28 |
+
Line,
|
| 29 |
+
Symbol,
|
| 30 |
+
Tag,
|
| 31 |
+
SymbolType,
|
| 32 |
+
LineStyle,
|
| 33 |
+
ConnectionType,
|
| 34 |
+
JunctionType,
|
| 35 |
+
Junction
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
# Skeletonization and label processing for junction detection
|
| 39 |
+
from skimage.morphology import skeletonize
|
| 40 |
+
from skimage.measure import label
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
import os
|
| 44 |
+
import cv2
|
| 45 |
+
import torch
|
| 46 |
+
import numpy as np
|
| 47 |
+
from dataclasses import replace
|
| 48 |
+
from typing import List, Optional
|
| 49 |
+
from detection_utils import robust_merge_lines
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class LineDetector(BaseDetector):
|
| 53 |
+
"""
|
| 54 |
+
DeepLSD-based line detection with patch-based tiling and global merging.
|
| 55 |
+
"""
|
| 56 |
+
|
| 57 |
+
    def __init__(self,
                 config: LineConfig,
                 model_path: str,
                 model_config: dict,
                 device: torch.device,
                 debug_handler: Optional[DebugHandler] = None):
        """Initialize the DeepLSD line detector.

        Args:
            config: Line-detection configuration passed to BaseDetector.
            model_path: Filesystem path to the DeepLSD checkpoint (.tar/.pt).
            model_config: Dict of DeepLSD model hyperparameters.
            device: Requested torch device.
                NOTE(review): this argument is currently ignored -- the
                device is auto-selected below (mps > cuda > cpu); confirm
                that is intended.
            debug_handler: Optional handler used to dump debug artifacts.
        """
        super().__init__(config, debug_handler)

        # Fix device selection for Apple Silicon
        if torch.backends.mps.is_available():
            self.device = torch.device("mps")
        elif torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

        self.model_path = model_path
        self.model_config = model_config
        # Loads the checkpoint onto self.device and puts the model in eval mode.
        self.model = self._load_model(model_path)

        # Patch parameters: image is tiled into overlapping square patches.
        self.patch_size = 512
        self.overlap = 10

        # Merging thresholds for combining per-patch line segments.
        self.angle_thresh = 5.0  # degrees
        self.dist_thresh = 5.0  # pixels
|
| 84 |
+
|
| 85 |
+
def _preprocess(self, image: np.ndarray) -> np.ndarray:
|
| 86 |
+
kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2))
|
| 87 |
+
dilated = cv2.dilate(image, kernel, iterations=2)
|
| 88 |
+
|
| 89 |
+
skeleton = cv2.bitwise_not(dilated)
|
| 90 |
+
skeleton = skeletonize(skeleton // 255)
|
| 91 |
+
skeleton = (skeleton * 255).astype(np.uint8)
|
| 92 |
+
kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, 1))
|
| 93 |
+
clean_image = cv2.dilate(skeleton, kernel, iterations=5)
|
| 94 |
+
|
| 95 |
+
self.debug_handler.save_artifact(name="skeleton", data=clean_image, extension="png")
|
| 96 |
+
|
| 97 |
+
return clean_image
|
| 98 |
+
|
| 99 |
+
def _postprocess(self, image: np.ndarray) -> np.ndarray:
|
| 100 |
+
return None
|
| 101 |
+
# -------------------------------------
|
| 102 |
+
# 1) Load Model
|
| 103 |
+
# -------------------------------------
|
| 104 |
+
def _load_model(self, model_path: str) -> DeepLSD:
|
| 105 |
+
if not os.path.exists(model_path):
|
| 106 |
+
raise FileNotFoundError(f"Model file not found: {model_path}")
|
| 107 |
+
ckpt = torch.load(model_path, map_location=self.device)
|
| 108 |
+
model = DeepLSD(self.model_config)
|
| 109 |
+
model.load_state_dict(ckpt["model"])
|
| 110 |
+
return model.to(self.device).eval()
|
| 111 |
+
|
| 112 |
+
# -------------------------------------
|
| 113 |
+
# 2) Main Detection Pipeline
|
| 114 |
+
# -------------------------------------
|
| 115 |
+
def detect(self,
|
| 116 |
+
image: np.ndarray,
|
| 117 |
+
context: DetectionContext,
|
| 118 |
+
mask_coords: Optional[List[BBox]] = None,
|
| 119 |
+
*args,
|
| 120 |
+
**kwargs) -> None:
|
| 121 |
+
"""
|
| 122 |
+
Steps:
|
| 123 |
+
- Optional mask + threshold
|
| 124 |
+
- Tile into overlapping patches
|
| 125 |
+
- For each patch => run DeepLSD => re-map lines to global coords
|
| 126 |
+
- Merge lines robustly
|
| 127 |
+
- Build final Line objects => add to context
|
| 128 |
+
"""
|
| 129 |
+
mask_coords = mask_coords or []
|
| 130 |
+
|
| 131 |
+
skeleton = self._preprocess(image)
|
| 132 |
+
# (A) Optional mask + threshold if you want a binary
|
| 133 |
+
# If your model expects grayscale or binary, do it here:
|
| 134 |
+
processed_img = self._apply_mask_and_threshold(skeleton, mask_coords)
|
| 135 |
+
# (B) Patch-based inference => collect raw lines in global coords
|
| 136 |
+
all_lines = self._detect_in_patches(processed_img)
|
| 137 |
+
|
| 138 |
+
# (C) Merge the lines in the global coordinate system
|
| 139 |
+
merged_line_segments = robust_merge_lines(
|
| 140 |
+
all_lines,
|
| 141 |
+
angle_thresh=self.angle_thresh,
|
| 142 |
+
dist_thresh=self.dist_thresh
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
# (D) Convert merged segments => final Line objects, add to context
|
| 146 |
+
for (x1, y1, x2, y2) in merged_line_segments:
|
| 147 |
+
line_obj = self._create_line_object(x1, y1, x2, y2)
|
| 148 |
+
context.add_line(line_obj)
|
| 149 |
+
|
| 150 |
+
# -------------------------------------
|
| 151 |
+
# 3) Optional Mask + Threshold
|
| 152 |
+
# -------------------------------------
|
| 153 |
+
def _apply_mask_and_threshold(self, image: np.ndarray, mask_coords: List[BBox]) -> np.ndarray:
|
| 154 |
+
"""White out rectangular areas, then threshold to binary (if needed)."""
|
| 155 |
+
masked = image.copy()
|
| 156 |
+
for bbox in mask_coords:
|
| 157 |
+
x1, y1 = int(bbox.xmin), int(bbox.ymin)
|
| 158 |
+
x2, y2 = int(bbox.xmax), int(bbox.ymax)
|
| 159 |
+
cv2.rectangle(masked, (x1, y1), (x2, y2), (255, 255, 255), -1)
|
| 160 |
+
|
| 161 |
+
# If image has 3 channels, convert to grayscale
|
| 162 |
+
if len(masked.shape) == 3:
|
| 163 |
+
masked_gray = cv2.cvtColor(masked, cv2.COLOR_BGR2GRAY)
|
| 164 |
+
else:
|
| 165 |
+
masked_gray = masked
|
| 166 |
+
|
| 167 |
+
# Binary threshold (adjust threshold as needed)
|
| 168 |
+
# If your model expects a plain grayscale, skip threshold
|
| 169 |
+
binary_img = cv2.threshold(masked_gray, 127, 255, cv2.THRESH_BINARY)[1]
|
| 170 |
+
return binary_img
|
| 171 |
+
|
| 172 |
+
# -------------------------------------
|
| 173 |
+
# 4) Patch-Based Inference
|
| 174 |
+
# -------------------------------------
|
| 175 |
+
def _detect_in_patches(self, processed_img: np.ndarray) -> List[tuple]:
|
| 176 |
+
"""
|
| 177 |
+
Break the image into overlapping patches, run DeepLSD,
|
| 178 |
+
map local lines => global coords, and return the global line list.
|
| 179 |
+
"""
|
| 180 |
+
patch_size = self.patch_size
|
| 181 |
+
overlap = self.overlap
|
| 182 |
+
|
| 183 |
+
height, width = processed_img.shape[:2]
|
| 184 |
+
step = patch_size - overlap
|
| 185 |
+
|
| 186 |
+
all_lines = []
|
| 187 |
+
|
| 188 |
+
for y in range(0, height, step):
|
| 189 |
+
patch_ymax = min(y + patch_size, height)
|
| 190 |
+
patch_ymin = patch_ymax - patch_size if (patch_ymax - y) < patch_size else y
|
| 191 |
+
if patch_ymin < 0: patch_ymin = 0
|
| 192 |
+
|
| 193 |
+
for x in range(0, width, step):
|
| 194 |
+
patch_xmax = min(x + patch_size, width)
|
| 195 |
+
patch_xmin = patch_xmax - patch_size if (patch_xmax - x) < patch_size else x
|
| 196 |
+
if patch_xmin < 0: patch_xmin = 0
|
| 197 |
+
|
| 198 |
+
patch = processed_img[patch_ymin:patch_ymax, patch_xmin:patch_xmax]
|
| 199 |
+
|
| 200 |
+
# Run model
|
| 201 |
+
local_lines = self._run_model_inference(patch)
|
| 202 |
+
|
| 203 |
+
# Convert local lines => global coords
|
| 204 |
+
for ln in local_lines:
|
| 205 |
+
(x1_local, y1_local), (x2_local, y2_local) = ln
|
| 206 |
+
|
| 207 |
+
# offset by patch_xmin, patch_ymin
|
| 208 |
+
gx1 = x1_local + patch_xmin
|
| 209 |
+
gy1 = y1_local + patch_ymin
|
| 210 |
+
gx2 = x2_local + patch_xmin
|
| 211 |
+
gy2 = y2_local + patch_ymin
|
| 212 |
+
|
| 213 |
+
# Optional: clamp or filter lines partially out-of-bounds
|
| 214 |
+
if 0 <= gx1 < width and 0 <= gx2 < width and 0 <= gy1 < height and 0 <= gy2 < height:
|
| 215 |
+
all_lines.append((gx1, gy1, gx2, gy2))
|
| 216 |
+
|
| 217 |
+
return all_lines
|
| 218 |
+
|
| 219 |
+
# -------------------------------------
|
| 220 |
+
# 5) Model Inference (Single Patch)
|
| 221 |
+
# -------------------------------------
|
| 222 |
+
def _run_model_inference(self, patch_img: np.ndarray) -> np.ndarray:
|
| 223 |
+
"""
|
| 224 |
+
Run DeepLSD on a single patch (already masked/thresholded).
|
| 225 |
+
patch_img shape: [patchH, patchW].
|
| 226 |
+
Returns lines shape: [N, 2, 2].
|
| 227 |
+
"""
|
| 228 |
+
# Convert patch to float32 and scale
|
| 229 |
+
inp = torch.tensor(patch_img, dtype=torch.float32, device=self.device)[None, None] / 255.0
|
| 230 |
+
with torch.no_grad():
|
| 231 |
+
output = self.model({"image": inp})
|
| 232 |
+
lines = output["lines"][0] # shape (N, 2, 2)
|
| 233 |
+
return lines
|
| 234 |
+
|
| 235 |
+
# -------------------------------------
|
| 236 |
+
# 6) Convert Merged Segments => Line Objects
|
| 237 |
+
# -------------------------------------
|
| 238 |
+
def _create_line_object(self, x1: float, y1: float, x2: float, y2: float) -> Line:
|
| 239 |
+
"""
|
| 240 |
+
Create a minimal `Line` object from the final merged coordinates.
|
| 241 |
+
"""
|
| 242 |
+
margin = 2
|
| 243 |
+
# Start point
|
| 244 |
+
start_pt = Point(
|
| 245 |
+
coords=Coordinates(int(x1), int(y1)),
|
| 246 |
+
bbox=BBox(
|
| 247 |
+
xmin=int(x1 - margin),
|
| 248 |
+
ymin=int(y1 - margin),
|
| 249 |
+
xmax=int(x1 + margin),
|
| 250 |
+
ymax=int(y1 + margin)
|
| 251 |
+
),
|
| 252 |
+
type=JunctionType.END,
|
| 253 |
+
confidence=1.0
|
| 254 |
+
)
|
| 255 |
+
# End point
|
| 256 |
+
end_pt = Point(
|
| 257 |
+
coords=Coordinates(int(x2), int(y2)),
|
| 258 |
+
bbox=BBox(
|
| 259 |
+
xmin=int(x2 - margin),
|
| 260 |
+
ymin=int(y2 - margin),
|
| 261 |
+
xmax=int(x2 + margin),
|
| 262 |
+
ymax=int(y2 + margin)
|
| 263 |
+
),
|
| 264 |
+
type=JunctionType.END,
|
| 265 |
+
confidence=1.0
|
| 266 |
+
)
|
| 267 |
+
|
| 268 |
+
# Overall bounding box
|
| 269 |
+
x_min = int(min(x1, x2))
|
| 270 |
+
x_max = int(max(x1, x2))
|
| 271 |
+
y_min = int(min(y1, y2))
|
| 272 |
+
y_max = int(max(y1, y2))
|
| 273 |
+
|
| 274 |
+
line_obj = Line(
|
| 275 |
+
start=start_pt,
|
| 276 |
+
end=end_pt,
|
| 277 |
+
bbox=BBox(xmin=x_min, ymin=y_min, xmax=x_max, ymax=y_max),
|
| 278 |
+
style=LineStyle(
|
| 279 |
+
connection_type=ConnectionType.SOLID,
|
| 280 |
+
stroke_width=2,
|
| 281 |
+
color="#000000"
|
| 282 |
+
),
|
| 283 |
+
confidence=0.9,
|
| 284 |
+
topological_links=[]
|
| 285 |
+
)
|
| 286 |
+
return line_obj
|
| 287 |
+
|
| 288 |
+
class PointDetector(BaseDetector):
    """Unifies line endpoints that lie close together.

    Collects every start/end point of the lines already in the context,
    greedily clusters points falling within ``threshold_distance`` of a
    cluster's first member, rewires all lines so endpoints in the same
    cluster share one canonical Point object, and finally rebuilds
    ``context.points`` from the surviving endpoints.
    """

    def __init__(self,
                 config: PointConfig,
                 debug_handler: DebugHandler = None):
        super().__init__(config, debug_handler)  # purely geometric — no model
        self.threshold_distance = config.threshold_distance

    def _load_model(self, model_path: str):
        """Point unification needs no model."""
        return None

    def detect(self, image: np.ndarray, context: DetectionContext, *args, **kwargs) -> None:
        """Cluster line endpoints and rewrite lines to share canonical points."""
        # Every line contributes both of its endpoints.
        endpoints = [ep
                     for ln in context.lines.values()
                     for ep in (ln.start, ln.end)]

        # Greedy proximity clustering (see _cluster_points).
        clusters = self._cluster_points(endpoints, self.threshold_distance)

        # Map every non-anchor point id to its cluster's anchor
        # (the first point encountered in that cluster).
        canonical = {}
        for group in clusters:
            anchor = group[0]
            for member in group[1:]:
                canonical[member.id] = anchor

        # Rewire line endpoints onto the canonical points.
        for ln in context.lines.values():
            ln.start = canonical.get(ln.start.id, ln.start)
            ln.end = canonical.get(ln.end.id, ln.end)

        # Rebuild the point registry from what the lines now reference,
        # dropping the duplicates that were merged away.
        context.points = {ep.id: ep
                          for ln in context.lines.values()
                          for ep in (ln.start, ln.end)}

    def _preprocess(self, image: np.ndarray) -> np.ndarray:
        """The image is not modified by this detector."""
        return image

    def _postprocess(self, image: np.ndarray) -> np.ndarray:
        """The image is not modified by this detector."""
        return image

    # ----------------------
    # HELPER: clustering
    # ----------------------
    def _cluster_points(self, points: List[Point], threshold: float) -> List[List[Point]]:
        """Greedy single-pass clustering.

        A point joins the first existing cluster whose *first* member is
        closer than ``threshold``; otherwise it seeds a new cluster.
        Returns the list of clusters (each a list of Points).
        """
        groups: List[List[Point]] = []

        for candidate in points:
            home = None
            for group in groups:
                # The cluster's first member acts as its reference point.
                if self._distance(candidate, group[0]) < threshold:
                    home = group
                    break

            if home is None:
                groups.append([candidate])
            else:
                home.append(candidate)

        return groups

    def _distance(self, p1: Point, p2: Point) -> float:
        """Euclidean distance between two points' coordinates."""
        return sqrt((p1.coords.x - p2.coords.x) ** 2 + (p1.coords.y - p2.coords.y) ** 2)
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
class JunctionDetector(BaseDetector):
    """
    Classifies points as 'END', 'L', or 'T' by skeletonizing the binarized image
    and analyzing local connectivity. Also creates Junction objects in the context.

    The classification counts how many distinct skeleton branches ("exits")
    leave a small circle centered on each point: 1 => END, 2 => END or L
    (depending on an angle check), 3 => T.
    """

    def __init__(self, config: JunctionConfig, debug_handler: DebugHandler = None):
        super().__init__(config, debug_handler)  # no real model path
        # Side of the square crop examined around each point.
        self.window_size = config.window_size
        # Radius of the circular mask used to count skeleton exits.
        self.radius = config.radius
        # Acute-angle band [lb, ub] (degrees) accepted as an 'L' corner.
        self.angle_threshold_lb = config.angle_threshold_lb
        self.angle_threshold_ub = config.angle_threshold_ub
        # Unlike LineDetector, this class guards against a None handler.
        self.debug_handler = debug_handler or DebugHandler()

    def _load_model(self, model_path: str):
        """Not loading any actual model, just skeleton logic."""
        return None

    def detect(self,
               image: np.ndarray,
               context: DetectionContext,
               *args,
               **kwargs) -> None:
        """
        1) Convert to binary & skeletonize
        2) Classify each point in the context
        3) Create a Junction for each point and store it in context.junctions
           (with 'connected_lines' referencing lines that share this point).
        """
        # 1) Preprocess -> skeleton
        skeleton = self._create_skeleton(image)

        # 2) Classify each point (mutates pt.type in place)
        for pt in context.points.values():
            pt.type = self._classify_point(skeleton, pt)

        # 3) Create a Junction object for each point
        # If you prefer only T or L, you can filter out END points.
        self._record_junctions_in_context(context)

    def _preprocess(self, image: np.ndarray) -> np.ndarray:
        """Binarize: grayscale (if needed) then a fixed threshold at 127."""
        if image.ndim == 3:
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        else:
            gray = image
        _, bin_image = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)
        return bin_image

    def _postprocess(self, image: np.ndarray) -> np.ndarray:
        # No postprocessing required.
        return image

    def _create_skeleton(self, raw_image: np.ndarray) -> np.ndarray:
        """Skeletonize the binarized image (strokes become 1-px wide, 255)."""
        bin_img = self._preprocess(raw_image)
        # Invert so dark strokes become foreground; skeletonize needs a
        # boolean/binary array.
        inv = cv2.bitwise_not(bin_img)
        inv_bool = (inv > 127).astype(np.uint8)
        skel = skeletonize(inv_bool).astype(np.uint8) * 255
        return skel

    def _classify_point(self, skeleton: np.ndarray, pt: Point) -> JunctionType:
        """
        Given a skeleton image, look around 'pt' in a local window
        to determine if it's an END, L, or T.

        Counts connected skeleton components inside a circle of
        `self.radius` around the point ("exits"). NOTE(review): 4+ exits
        (a '+' crossing) falls through to the default END — confirm
        whether crossings should get their own JunctionType.
        """
        classification = JunctionType.END  # default

        half_w = self.window_size // 2
        x, y = pt.coords.x, pt.coords.y

        # Clamp the crop window to the image bounds.
        top = max(0, y - half_w)
        bottom = min(skeleton.shape[0], y + half_w + 1)
        left = max(0, x - half_w)
        right = min(skeleton.shape[1], x + half_w + 1)

        patch = (skeleton[top:bottom, left:right] > 127).astype(np.uint8)

        # create circular mask centered on the point (in patch coordinates)
        circle_mask = np.zeros_like(patch, dtype=np.uint8)
        local_cx = x - left
        local_cy = y - top
        cv2.circle(circle_mask, (local_cx, local_cy), self.radius, 1, -1)
        circle_skel = patch & circle_mask

        # label connected regions: each region is one branch leaving the point
        labeled = label(circle_skel, connectivity=2)
        num_exits = labeled.max()

        if num_exits == 1:
            classification = JunctionType.END
        elif num_exits == 2:
            # check angle for L
            classification = self._check_angle_for_L(labeled)
        elif num_exits == 3:
            classification = JunctionType.T

        return classification

    def _check_angle_for_L(self, labeled_region: np.ndarray) -> JunctionType:
        """
        If the angle between two branches is within
        [angle_threshold_lb, angle_threshold_ub], it's 'L'.
        Otherwise default to END.

        NOTE(review): this takes the first two pixels of the component
        labeled 1 only (`labeled_region == 1`), so the measured angle is
        the local orientation of a single branch, not the angle *between*
        the two branches (labels 1 and 2) as the docstring claims.
        Confirm intent; comparing both labels' centroids relative to the
        junction center would match the documented behavior.

        NOTE(review): uses `math`, which is not imported in this chunk —
        presumably imported near the top of the file (not visible here).
        """
        coords = np.argwhere(labeled_region == 1)
        if len(coords) < 2:
            return JunctionType.END

        # np.argwhere yields (row, col) pairs, i.e. (y, x).
        (y1, x1), (y2, x2) = coords[:2]
        dx = x2 - x1
        dy = y2 - y1
        angle = math.degrees(math.atan2(dy, dx))
        # Fold into [0, 90]: only the acute angle matters.
        acute_angle = min(abs(angle), 180 - abs(angle))

        if self.angle_threshold_lb <= acute_angle <= self.angle_threshold_ub:
            return JunctionType.L
        return JunctionType.END

    # -----------------------------------------
    # EXTRA STEP: Create Junction objects
    # -----------------------------------------
    def _record_junctions_in_context(self, context: DetectionContext):
        """
        Create a Junction object for each point in context.points.
        If you only want T/L points as junctions, filter them out.
        Also track any lines that connect to this point.

        Note: scans all lines per point, so this is O(points * lines).
        """

        for pt in context.points.values():
            # If you prefer to store all points as junction, do it:
            # or if you want only T or L, do:
            # if pt.type in {JunctionType.T, JunctionType.L}: ...

            jn = Junction(
                center=pt.coords,
                junction_type=pt.type,
                # add more properties if needed
            )

            # find lines that connect to this point (by shared Point id)
            connected_lines = []
            for ln in context.lines.values():
                if ln.start.id == pt.id or ln.end.id == pt.id:
                    connected_lines.append(ln.id)

            jn.connected_lines = connected_lines

            # add to context
            context.add_junction(jn)
|
| 546 |
+
|
| 547 |
+
import json
|
| 548 |
+
import uuid
|
| 549 |
+
|
| 550 |
+
class SymbolDetector(BaseDetector):
    """
    Loads precomputed symbol detections from a JSON file and registers
    them on the context as Symbol objects. No model inference happens here.
    """

    def __init__(self,
                 config: SymbolConfig,
                 debug_handler: Optional[DebugHandler] = None,
                 symbol_json_path: str = "./symbols.json"):
        super().__init__(config=config, debug_handler=debug_handler)
        self.symbol_json_path = symbol_json_path

    def _load_model(self, model_path: str):
        """Detections come from JSON, so there is no model to load."""
        return None

    def detect(self,
               image: np.ndarray,
               context: DetectionContext,
               *args,
               **kwargs) -> None:
        """Read symbol records from the JSON file and add them to context."""
        payload = self._load_json_data(self.symbol_json_path)
        if not payload:
            return

        # Records live under the top-level "detections" key.
        for entry in payload.get("detections", []):
            context.add_symbol(self._parse_symbol_record(entry))

    def _preprocess(self, image: np.ndarray) -> np.ndarray:
        return image

    def _postprocess(self, image: np.ndarray) -> np.ndarray:
        return image

    # --------------
    # HELPER METHODS
    # --------------
    def _load_json_data(self, json_path: str) -> dict:
        """Return the parsed JSON payload, or {} when the file is missing."""
        if os.path.exists(json_path):
            with open(json_path, "r", encoding="utf-8") as fh:
                return json.load(fh)

        # Record the missing-file condition as a debug artifact.
        self.debug_handler.save_artifact(name="symbol_error",
                                         data=b"Missing symbol JSON file",
                                         extension="txt")
        return {}

    def _parse_symbol_record(self, record: dict) -> Symbol:
        """Build one Symbol object from a JSON detection record."""
        raw_box = record.get("bbox", [0, 0, 0, 0])
        box = BBox(
            xmin=raw_box[0],
            ymin=raw_box[1],
            xmax=raw_box[2],
            ymax=raw_box[3]
        )

        # Integer midpoint of the bounding box.
        midpoint = Coordinates(
            x=(box.xmin + box.xmax) // 2,
            y=(box.ymin + box.ymax) // 2
        )

        return Symbol(
            id=record.get("symbol_id", ""),
            class_id=record.get("class_id", -1),
            original_label=record.get("original_label", ""),
            category=record.get("category", ""),
            type=record.get("type", ""),
            label=record.get("label", ""),
            bbox=box,
            center=midpoint,
            confidence=record.get("confidence", 0.95),
            model_source=record.get("model_source", ""),
            connections=[]
        )
|
| 646 |
+
|
| 647 |
+
class TagDetector(BaseDetector):
    """
    Loads precomputed tag detections from a JSON file and registers them
    on the context as Tag objects. No model inference happens here.
    """

    def __init__(self,
                 config: TagConfig,
                 debug_handler: Optional[DebugHandler] = None,
                 tag_json_path: str = "./tags.json"):
        super().__init__(config=config, debug_handler=debug_handler)
        self.tag_json_path = tag_json_path

    def _load_model(self, model_path: str):
        """Tag data comes from JSON, so there is no model to load."""
        return None

    def detect(self,
               image: np.ndarray,
               context: DetectionContext,
               *args,
               **kwargs) -> None:
        """Read tag records from the JSON file and add them to context."""
        payload = self._load_json_data(self.tag_json_path)
        if not payload:
            return

        # Records live under the top-level "detections" key.
        for entry in payload.get("detections", []):
            context.add_tag(self._parse_tag_record(entry))

    def _preprocess(self, image: np.ndarray) -> np.ndarray:
        return image

    def _postprocess(self, image: np.ndarray) -> np.ndarray:
        return image

    # --------------
    # HELPER METHODS
    # --------------
    def _load_json_data(self, json_path: str) -> dict:
        """Return the parsed JSON payload, or {} when the file is missing."""
        if os.path.exists(json_path):
            with open(json_path, "r", encoding="utf-8") as fh:
                return json.load(fh)

        # Record the missing-file condition as a debug artifact.
        self.debug_handler.save_artifact(name="tag_error",
                                         data=b"Missing tag JSON file",
                                         extension="txt")
        return {}

    def _parse_tag_record(self, record: dict) -> Tag:
        """Build one Tag object from a JSON detection record."""
        raw_box = record.get("bbox", [0, 0, 0, 0])
        box = BBox(
            xmin=raw_box[0],
            ymin=raw_box[1],
            xmax=raw_box[2],
            ymax=raw_box[3]
        )

        return Tag(
            text=record.get("text", ""),
            bbox=box,
            confidence=record.get("confidence", 1.0),
            source=record.get("source", ""),
            text_type=record.get("text_type", "Unknown"),
            id=record.get("id", str(uuid.uuid4())),
            font_size=record.get("font_size", 12),
            rotation=record.get("rotation", 0.0)
        )
|
download_models.py
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import shutil
|
| 3 |
+
import torch
|
| 4 |
+
from doctr.models import ocr_predictor
|
| 5 |
+
from ultralytics import YOLO
|
| 6 |
+
from deeplsd.models.deeplsd_inference import DeepLSD
|
| 7 |
+
from dotenv import load_dotenv
|
| 8 |
+
|
| 9 |
+
# Load environment variables
|
| 10 |
+
load_dotenv()
|
| 11 |
+
|
| 12 |
+
def copy_local_models():
    """Stage model weights into the deployment layout under ``models/``.

    Copies the YOLO, DeepLSD and DocTR weights from the sibling
    ``../models`` directory. Missing YOLO/DeepLSD sources are skipped
    silently (best-effort copy); DocTR alone has a download fallback via
    ``doctr.models.ocr_predictor(pretrained=True)``.
    """
    # Create model directories (idempotent).
    os.makedirs('models/yolo', exist_ok=True)
    os.makedirs('models/deeplsd', exist_ok=True)
    os.makedirs('models/doctr', exist_ok=True)

    # Source paths (adjust these to your local paths)
    local_models_dir = "../models"

    # Copy YOLO model
    yolo_src = os.path.join(local_models_dir, "yolov8n.pt")
    yolo_dst = "models/yolo/yolov8n.pt"
    if os.path.exists(yolo_src):
        shutil.copy2(yolo_src, yolo_dst)
        print(f"Copied YOLO model to {yolo_dst}")

    # Copy DeepLSD model
    deeplsd_src = os.path.join(local_models_dir, "deeplsd_md.tar")
    deeplsd_dst = "models/deeplsd/deeplsd_md.tar"
    if os.path.exists(deeplsd_src):
        shutil.copy2(deeplsd_src, deeplsd_dst)
        print(f"Copied DeepLSD model to {deeplsd_dst}")

    # Copy DocTR model if it exists locally; otherwise download it.
    doctr_src = os.path.join(local_models_dir, "ocr_predictor.pt")
    doctr_dst = "models/doctr/ocr_predictor.pt"
    if os.path.exists(doctr_src):
        shutil.copy2(doctr_src, doctr_dst)
        print(f"Copied DocTR model to {doctr_dst}")
    else:
        # Download DocTR model if not available locally
        predictor = ocr_predictor(pretrained=True)
        torch.save(predictor.state_dict(), doctr_dst)
        print(f"Downloaded DocTR model to {doctr_dst}")


def download_models():
    """Ensure model weights are present in the deployment layout.

    Bug fix: ``gradioChatApp.py`` imports this name
    (``from download_models import download_models, copy_local_models``)
    but no such function existed in this module, so the app crashed with
    an ImportError at startup. This entry point simply delegates to
    :func:`copy_local_models`, which already downloads the DocTR weights
    when they are missing locally.
    """
    copy_local_models()
|
| 47 |
+
|
| 48 |
+
# Script entry point: stage (and if needed download) model weights when
# this file is run directly, e.g. during Space build.
if __name__ == "__main__":
    copy_local_models()
|
gradioChatApp.py
ADDED
|
@@ -0,0 +1,806 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import base64
|
| 3 |
+
import gradio as gr
|
| 4 |
+
import json
|
| 5 |
+
from datetime import datetime
|
| 6 |
+
from symbol_detection import run_detection_with_optimal_threshold
|
| 7 |
+
from line_detection_ai import DiagramDetectionPipeline, LineDetector, LineConfig, ImageConfig, DebugHandler, \
|
| 8 |
+
PointConfig, JunctionConfig, PointDetector, JunctionDetector, SymbolConfig, SymbolDetector, TagConfig, TagDetector
|
| 9 |
+
from data_aggregation_ai import DataAggregator
|
| 10 |
+
from chatbot_agent import get_assistant_response
|
| 11 |
+
from storage import StorageFactory, LocalStorage
|
| 12 |
+
import traceback
|
| 13 |
+
from text_detection_combined import process_drawing
|
| 14 |
+
from pathlib import Path
|
| 15 |
+
from pdf_processor import DocumentProcessor
|
| 16 |
+
import networkx as nx
|
| 17 |
+
import logging
|
| 18 |
+
import matplotlib.pyplot as plt
|
| 19 |
+
from dotenv import load_dotenv
|
| 20 |
+
import torch
|
| 21 |
+
from graph_visualization import create_graph_visualization
|
| 22 |
+
import shutil
|
| 23 |
+
from detection_schema import BBox # Add this import
|
| 24 |
+
import cv2
|
| 25 |
+
import numpy as np
|
| 26 |
+
import time
|
| 27 |
+
from huggingface_hub import HfApi, login
|
| 28 |
+
from download_models import download_models, copy_local_models
|
| 29 |
+
|
| 30 |
+
# Load environment variables from .env file
load_dotenv()

# Configure logging at the start of the file (root logger: timestamped INFO lines).
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)

# Get logger for this module
logger = logging.getLogger(__name__)

# Disable duplicate logs from other modules: chatty third-party and
# pipeline loggers are raised to WARNING so the console shows app progress only.
logging.getLogger('PIL').setLevel(logging.WARNING)
logging.getLogger('matplotlib').setLevel(logging.WARNING)
logging.getLogger('gradio').setLevel(logging.WARNING)
logging.getLogger('networkx').setLevel(logging.WARNING)
logging.getLogger('line_detection_ai').setLevel(logging.WARNING)
logging.getLogger('symbol_detection').setLevel(logging.WARNING)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# Only log important messages
def log_process_step(message, level=logging.INFO):
    """Selectively log one pipeline step.

    Warnings and errors are always emitted at their given level; lower-level
    messages are emitted (at INFO) only when they announce a milestone, i.e.
    contain "completed" or "generated". Everything else is dropped.
    """
    if level >= logging.WARNING:
        logger.log(level, message)
        return
    lowered = message.lower()
    is_milestone = "completed" in lowered or "generated" in lowered
    if is_milestone:
        logger.info(message)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
# Helper function to format timestamps
def get_timestamp():
    """Return the current local time as a 'YYYY-MM-DD HH:MM:SS' string."""
    now = datetime.now()
    return now.strftime('%Y-%m-%d %H:%M:%S')
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def format_message(role, content):
    """Build one chat-history entry in the {'role': ..., 'content': ...} format."""
    return dict(role=role, content=content)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
# Load avatar images for agents.
localStorage = LocalStorage()


def _load_avatar_b64(path, fallback="assets/AiAgent.png"):
    """Return the base64-encoded bytes of an avatar asset, with a fallback.

    Bug fix: this commit ships only assets/AiAgent.png, assets/intuigence.png
    and assets/user.png — ``assets/llm.png`` is not committed, so the previous
    unconditional ``load_file("assets/llm.png")`` crashed the whole module at
    import time. Any load failure now falls back to *fallback* instead.
    """
    try:
        return base64.b64encode(localStorage.load_file(path)).decode()
    except Exception as exc:
        logger.warning(f"Could not load avatar {path} ({exc}); using fallback {fallback}")
        return base64.b64encode(localStorage.load_file(fallback)).decode()


agent_avatar = _load_avatar_b64("assets/AiAgent.png")
llm_avatar = _load_avatar_b64("assets/llm.png")  # missing from the repo -> falls back
user_avatar = _load_avatar_b64("assets/user.png")
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
# Chat message formatting with avatars and enhanced HTML for readability
def chat_message(role, message, avatar, timestamp):
    """Render one chat message as an HTML bubble with avatar and timestamp.

    Converts a small Markdown subset (bold, fenced/inline code, headings,
    newlines) to HTML.

    Bug fix: the previous chained ``str.replace`` version never produced
    closing tags — ``replace("**", "<strong>")`` consumed every ``**`` so the
    follow-up ``replace("**", "</strong>")`` matched nothing, every backtick
    became ``</code></pre>``, and the ``\\n1. `` list rewrites were dead code
    because ``\\n`` had already been rewritten to ``<br>``. Paired regex
    substitutions now emit balanced tags, and newlines are rewritten last.

    Args:
        role: "user" or "assistant"; used as a CSS class suffix.
        message: raw message text (may contain Markdown-style markup).
        avatar: base64-encoded PNG data for the speaker's avatar.
        timestamp: preformatted timestamp string to show under the bubble.

    Returns:
        str: HTML snippet for one chat bubble.
    """
    import re  # local import keeps the module's top-level import block untouched

    html = re.sub(r"\*\*(.+?)\*\*", r"<strong>\1</strong>", message)
    # Fenced code blocks first, so their contents are not re-parsed as inline code.
    html = re.sub(r"```(.*?)```", r"<pre><code>\1</code></pre>", html, flags=re.DOTALL)
    html = re.sub(r"`([^`]+)`", r"<code>\1</code>", html)
    # Headings before the newline rewrite (they are line-anchored); longest prefix first.
    html = re.sub(r"^###\s*(.*)$", r"<h3>\1</h3>", html, flags=re.MULTILINE)
    html = re.sub(r"^##\s*(.*)$", r"<h2>\1</h2>", html, flags=re.MULTILINE)
    html = re.sub(r"^#\s*(.*)$", r"<h1>\1</h1>", html, flags=re.MULTILINE)
    formatted_message = html.replace("\n", "<br>")

    return f"""
    <div class="chat-message {role}">
        <img src="data:image/png;base64,{avatar}" class="avatar"/>
        <div>
            <div class="speech-bubble {role}-bubble">{formatted_message}</div>
            <div class="timestamp">{timestamp}</div>
        </div>
    </div>
    """
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def resize_to_fit(image_path, max_width=800, max_height=600):
    """Scale the image at *image_path* to fit inside max_width x max_height.

    Aspect ratio is preserved by using one uniform scale factor; the image is
    always resized (upscaled when smaller than the target box).

    Returns:
        tuple: (resized BGR image array, scale factor), or (None, 1.0) when
        the file cannot be read.
    """
    img = cv2.imread(image_path)
    if img is None:
        return None, 1.0

    height, width = img.shape[:2]
    # One uniform factor satisfying both the width and the height constraint.
    scale = min(max_width / width, max_height / height)
    target_size = (int(width * scale), int(height * scale))
    return cv2.resize(img, target_size), scale
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
# Main processing function for P&ID steps
def process_pnid(image_file, progress=gr.Progress()):
    """Run the full P&ID analysis pipeline as a Gradio generator.

    Yields a 9-element ``outputs`` list after each major stage so the UI
    updates incrementally:
        [0] progress log text, [1] original high-quality PNG,
        [2] symbol-detection overlay, [3] tag/text overlay, [4] line overlay,
        [5] aggregated overlay, [6] knowledge-graph image,
        [7] chat history (list of role/content dicts), [8] aggregated JSON path.

    Args:
        image_file: Gradio file object for the uploaded P&ID document.
        progress: Gradio progress tracker (injected by Gradio).

    Raises:
        ValueError: when no file is uploaded or an intermediate artifact is
            missing/invalid; other exceptions are logged and re-raised.
    """
    try:
        if image_file is None:
            raise ValueError("No file uploaded. Please upload a file first.")

        progress_text = []
        outputs = [None] * 9  # Changed from 8 to 9 to match UI outputs
        base_name = os.path.splitext(os.path.basename(image_file.name))[0] + "_page_1"

        # Initialize chat history with proper format
        chat_history = [{"role": "assistant", "content": "Welcome! Upload a P&ID to begin analysis."}]
        outputs[7] = chat_history  # Chat history moved to index 7

        # Closure: appends a timestamped line to the shared progress log,
        # mirrors it into outputs[0], and advances the Gradio progress bar.
        def update_progress(step, message):
            progress_text.append(f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')} - {message}")
            outputs[0] = "\n".join(progress_text)  # Progress text
            progress(step)

        # Initialize storage and results directory
        storage = StorageFactory.get_storage()
        results_dir = "results"
        os.makedirs(results_dir, exist_ok=True)

        # Clean results directory (previous run's artifacts are discarded).
        logger.info("Cleaned results directory: results")
        for file in os.listdir(results_dir):
            file_path = os.path.join(results_dir, file)
            try:
                if os.path.isfile(file_path):
                    os.unlink(file_path)
            except Exception as e:
                logger.error(f"Error deleting file {file_path}: {str(e)}")

        # Step 1: File Upload (10%)
        logger.info(f"Processing file: {os.path.basename(image_file.name)}")
        update_progress(0.1, "Step 1/7: File uploaded successfully")
        yield outputs

        # Step 2: Document Processing - Get high quality PNG
        update_progress(0.2, "Step 2/7: Processing document...")
        doc_processor = DocumentProcessor(storage)
        processed_pages = doc_processor.process_document(
            file_path=image_file,
            output_dir=results_dir
        )

        if not processed_pages:
            raise ValueError("No pages processed from document")

        # Use high quality PNG for everything (only the first page is analyzed).
        high_quality_png = processed_pages[0]
        outputs[1] = high_quality_png  # P&ID Tab shows original high quality
        update_progress(0.25, "Document loaded and displayed")
        yield outputs

        # Step 3: Symbol Detection using high quality PNG
        # NOTE(review): diagram_bbox is captured but never used below — confirm intentional.
        detection_image_path, detection_json_path, _, diagram_bbox = run_detection_with_optimal_threshold(
            high_quality_png,  # Use high quality PNG
            results_dir=results_dir,
            file_name=os.path.basename(high_quality_png),
            storage=storage,
            resize_image=False  # Don't resize
        )
        outputs[2] = detection_image_path  # Symbols Tab
        symbol_json_path = detection_json_path

        # Step 4: Text Detection using high quality PNG
        # NOTE(review): text_summary is unused below — confirm intentional.
        text_results, text_summary = process_drawing(
            high_quality_png,  # Use high quality PNG
            results_dir,
            storage
        )
        text_json_path = text_results['json_path']
        outputs[3] = text_results['image_path']  # Tags Tab

        # Step 5: Line Detection (80%)
        update_progress(0.80, "Step 5/7: Line Detection")
        yield outputs

        try:
            # Initialize components
            debug_handler = DebugHandler(enabled=True, storage=storage)

            # Configure detectors
            line_config = LineConfig()
            point_config = PointConfig()
            junction_config = JunctionConfig()
            symbol_config = SymbolConfig(
                model_path="models/Intui_SDM_41.pt",
                confidence_threshold=0.5,
                nms_threshold=0.3
            )
            tag_config = TagConfig(
                model_path="models/tag_detection.json",
                confidence_threshold=0.5
            )

            # Create all required detectors
            symbol_detector = SymbolDetector(
                config=symbol_config,
                debug_handler=debug_handler
            )

            tag_detector = TagDetector(
                config=tag_config,
                debug_handler=debug_handler
            )

            # NOTE(review): device is hard-coded to CUDA — this fails on a
            # CPU-only host; confirm the deployment always has a GPU.
            line_detector = LineDetector(
                config=line_config,
                model_path="models/deeplsd_md.tar",
                model_config={"detect_lines": True},
                device=torch.device("cuda"),
                debug_handler=debug_handler
            )

            point_detector = PointDetector(
                config=point_config,
                debug_handler=debug_handler
            )

            junction_detector = JunctionDetector(
                config=junction_config,
                debug_handler=debug_handler
            )

            # Create pipeline with all detectors
            pipeline = DiagramDetectionPipeline(
                tag_detector=tag_detector,
                symbol_detector=symbol_detector,
                line_detector=line_detector,
                point_detector=point_detector,
                junction_detector=junction_detector,
                storage=storage,
                debug_handler=debug_handler
            )

            # Run pipeline with original high-res image
            line_results = pipeline.run(
                image_path=high_quality_png,
                output_dir=results_dir,
                config=ImageConfig()
            )
            line_json_path = line_results.json_path
            outputs[4] = line_results.image_path

            # Verify line detection output exists on disk.
            if not os.path.exists(line_json_path):
                raise ValueError(f"Line detection JSON not found: {line_json_path}")

            # Verify line detection JSON content
            with open(line_json_path, 'r') as f:
                line_data = json.load(f)
                if 'lines' not in line_data:
                    raise ValueError(f"Invalid line detection data format in {line_json_path}")
                logger.info(f"Line detection completed successfully with {len(line_data['lines'])} lines")

            # Verify all required JSONs exist before aggregation
            required_jsons = {
                'symbols': symbol_json_path,
                'texts': text_json_path,
                'lines': line_json_path
            }

            for name, path in required_jsons.items():
                if not os.path.exists(path):
                    raise ValueError(f"{name} JSON not found: {path}")
                # Verify JSON can be loaded
                with open(path, 'r') as f:
                    data = json.load(f)
                    logger.info(f"Loaded {name} JSON with {len(data.get('detections', data.get('lines', [])))} items")

            # Data Aggregation: merge symbols, texts and lines into one graph JSON.
            aggregator = DataAggregator(storage=storage)
            aggregated_result = aggregator.process_data(
                image_path=high_quality_png,
                output_dir=results_dir,
                symbols_path=symbol_json_path,
                texts_path=text_json_path,
                lines_path=line_json_path
            )

            # Verify aggregation result before graph creation
            if not aggregated_result.get('success'):
                raise ValueError(f"Data aggregation failed: {aggregated_result.get('error')}")

            aggregated_json_path = aggregated_result['json_path']
            if not os.path.exists(aggregated_json_path):
                raise ValueError(f"Aggregated JSON not found: {aggregated_json_path}")

            # Verify aggregated JSON content
            with open(aggregated_json_path, 'r') as f:
                aggregated_data = json.load(f)
                required_keys = ['nodes', 'edges', 'symbols', 'texts', 'lines']
                missing_keys = [k for k in required_keys if k not in aggregated_data]
                if missing_keys:
                    raise ValueError(f"Aggregated JSON missing required keys: {missing_keys}")
                logger.info("Aggregation completed successfully with:")
                logger.info(f"- {len(aggregated_data['nodes'])} nodes")
                logger.info(f"- {len(aggregated_data['edges'])} edges")

            # After aggregation, create graph visualization
            update_progress(0.85, "Step 6/7: Creating Knowledge Graph")
            try:
                # Create graph visualization
                graph_results = create_graph_visualization(
                    json_path=aggregated_json_path,
                    output_dir=results_dir,
                    base_name=base_name,
                    save_plot=True
                )

                if not graph_results.get('success'):
                    logger.error(f"Error in graph generation: {graph_results.get('error')}")
                    raise Exception(graph_results.get('error'))

                # NOTE(review): hard-codes "results" rather than results_dir.
                graph_path = f"results/{base_name}_graph_visualization.png"
                if not os.path.exists(graph_path):
                    raise Exception("Graph visualization file not created")

                update_progress(0.90, "Step 6/7: Knowledge Graph Created")

            except Exception as e:
                logger.error(f"Error creating graph visualization: {str(e)}")
                raise

            # Fix output assignments: (re)assign every UI slot in one place.
            outputs[0] = progress_text  # Progress text
            outputs[1] = high_quality_png  # P&ID
            outputs[2] = detection_image_path  # Symbols
            outputs[3] = text_results['image_path']  # Tags
            outputs[4] = line_results.image_path  # Lines
            # NOTE(review): assumes the aggregator saved this overlay PNG with
            # this exact naming convention — confirm against DataAggregator.
            outputs[5] = f"results/{base_name}_aggregated.png"  # Aggregated
            outputs[6] = graph_path  # Graph visualization
            outputs[7] = chat_history  # Chat
            outputs[8] = aggregated_json_path  # JSON state

            # Update progress with all steps
            update_progress(0.95, "Step 7/7: Finalizing Results")
            chat_history = [{"role": "assistant", "content": "Processing complete! I can help answer questions about the P&ID contents."}]
            outputs[7] = chat_history

            update_progress(1.0, "✅ Processing Complete")
            yield outputs

        except Exception as e:
            # Update chat with error message
            # NOTE(review): outputs is mutated here but never yielded before
            # re-raising, so the UI never sees this message — confirm intent.
            chat_history = [{"role": "assistant", "content": f"Error during processing: {str(e)}"}]
            outputs[7] = chat_history
            raise

    except Exception as e:
        logger.error(f"Error in process_pnid: {str(e)}")
        logger.error(f"Stack trace:\n{traceback.format_exc()}")
        # Update chat with error message
        chat_history = [{"role": "assistant", "content": f"Error: {str(e)}"}]
        outputs[7] = chat_history
        raise
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
# Separate function for Chat interaction
# NOTE(review): this handler appears to be legacy — create_ui wires its own
# handle_chat instead. It also mixes formats: chat_message() returns an HTML
# *string*, so `chat_history + chat_message(...)` raises TypeError when
# chat_history is the list-of-dicts format used elsewhere. Confirm whether
# this function is still reachable before relying on it.
def handle_user_message(user_input, chat_history, json_path_state):
    """Handle user messages and generate responses.

    Appends the user's message and the assistant's reply to the chat history
    and returns the updated history. Requires a processed P&ID (the
    aggregated-JSON path in *json_path_state*) before answering questions.
    """
    try:
        # Ignore empty / whitespace-only input.
        if not user_input or not user_input.strip():
            return chat_history

        # Add user message
        timestamp = get_timestamp()
        new_history = chat_history + chat_message("user", user_input, user_avatar, timestamp)

        # Check if json_path exists and is valid
        if not json_path_state or not os.path.exists(json_path_state):
            error_message = "Please upload and process a P&ID document first."
            return new_history + chat_message("assistant", error_message, agent_avatar, get_timestamp())

        try:
            # Log for debugging
            logger.info(f"Sending question to assistant: {user_input}")
            logger.info(f"Using JSON path: {json_path_state}")

            # Generate response
            response = get_assistant_response(user_input, json_path_state)

            # Handle the response: plain values are stringified; anything
            # generator-like yields its first item.
            if isinstance(response, (str, dict)):
                response_text = str(response)
            else:
                try:
                    # Try to get the first response from generator
                    response_text = next(response) if hasattr(response, '__next__') else str(response)
                except StopIteration:
                    response_text = "I apologize, but I couldn't generate a response."
                except Exception as e:
                    logger.error(f"Error processing response: {str(e)}")
                    response_text = "I apologize, but I encountered an error processing your request."

            logger.info(f"Generated response: {response_text}")

            if not response_text.strip():
                response_text = "I apologize, but I couldn't generate a response. Please try asking your question differently."

            # Add response to chat history
            new_history += chat_message("assistant", response_text, agent_avatar, get_timestamp())

        except Exception as e:
            logger.error(f"Error generating response: {str(e)}")
            logger.error(traceback.format_exc())
            error_message = "I apologize, but I encountered an error processing your request. Please try again."
            new_history += chat_message("assistant", error_message, agent_avatar, get_timestamp())

        return new_history

    except Exception as e:
        # Last-resort guard: never let the chat handler raise into Gradio.
        logger.error(f"Chat error: {str(e)}")
        logger.error(traceback.format_exc())
        return chat_history + chat_message(
            "assistant",
            "I apologize, but something went wrong. Please try again.",
            agent_avatar,
            get_timestamp()
        )
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
# Custom CSS for the dark three-panel layout (upload / preview / chat).
# NOTE(review): this constant appears unused — create_ui() builds its own
# `css` string and passes that to gr.Blocks. Confirm before removing.
custom_css = """
.full-height-row {
    height: calc(100vh - 150px); /* Adjusted height */
    margin: 0;
    padding: 10px;
}
.upload-box {
    background: #2a2a2a;
    border-radius: 8px;
    padding: 15px;
    margin-bottom: 15px;
    border: 1px solid #3a3a3a;
}
.status-box-container {
    background: #2a2a2a;
    border-radius: 8px;
    padding: 15px;
    height: calc(100vh - 350px); /* Reduced height */
    border: 1px solid #3a3a3a;
    margin-bottom: 15px;
}
.status-box {
    font-family: 'Courier New', monospace;
    font-size: 12px;
    line-height: 1.4;
    background-color: #1a1a1a;
    color: #00ff00;
    padding: 10px;
    border-radius: 5px;
    height: calc(100% - 40px); /* Adjust for header */
    overflow-y: auto;
    white-space: pre-wrap;
    word-wrap: break-word;
    border: none;
}
.preview-tabs {
    height: calc(100vh - 100px); /* Increased container height from 200px */
    background: #2a2a2a;
    border-radius: 8px;
    padding: 15px;
    border: 1px solid #3a3a3a;
    margin-bottom: 15px;
}
.chat-container {
    height: 100%; /* Take full height */
    display: flex;
    flex-direction: column;
    background: #2a2a2a;
    border-radius: 8px;
    padding: 15px;
    border: 1px solid #3a3a3a;
}
.chatbox {
    flex: 1; /* Take remaining space */
    overflow-y: auto;
    background: #1a1a1a;
    border-radius: 8px;
    padding: 15px;
    margin-bottom: 15px;
    color: #ffffff;
    min-height: 200px; /* Ensure minimum height */
}
.chat-input-group {
    height: auto; /* Allow natural height */
    min-height: 120px; /* Minimum height for input area */
    background: #1a1a1a;
    border-radius: 8px;
    padding: 15px;
    margin-top: auto; /* Push to bottom */
}
.chat-input {
    background: #2a2a2a;
    color: #ffffff;
    border: 1px solid #3a3a3a;
    border-radius: 5px;
    padding: 12px;
    min-height: 80px;
    width: 100%;
    margin-bottom: 10px;
}
.send-button {
    width: 100%;
    background: #4a4a4a;
    color: #ffffff;
    border-radius: 5px;
    border: none;
    padding: 12px;
    cursor: pointer;
    transition: background-color 0.3s;
}
.result-image {
    border-radius: 8px;
    box-shadow: 0 2px 4px rgba(0,0,0,0.1);
    margin: 10px 0;
    background: #ffffff;
}
.chat-message {
    display: flex;
    margin-bottom: 1rem;
    align-items: flex-start;
}
.chat-message .avatar {
    width: 40px;
    height: 40px;
    margin-right: 10px;
    border-radius: 50%;
}
.chat-message .speech-bubble {
    background: #2a2a2a;
    padding: 10px 15px;
    border-radius: 10px;
    max-width: 80%;
    margin-bottom: 5px;
}
.chat-message .timestamp {
    font-size: 0.8em;
    color: #666;
}
.logo-row {
    width: 100%;
    background-color: #1a1a1a;
    padding: 10px 0;
    margin: 0;
    border-bottom: 1px solid #3a3a3a;
}
"""
|
| 579 |
+
|
| 580 |
+
|
| 581 |
+
def create_ui():
    """Build the Gradio Blocks UI.

    Layout: a logo row, then three columns — file upload + progress (left),
    a tabbed preview of each pipeline stage (center), and the chatbot (right).
    Wires the Process button to process_pnid and the chat controls to a local
    handle_chat closure.

    Returns:
        gr.Blocks: the assembled (not yet launched) Gradio app.
    """
    current_dir = os.path.dirname(os.path.abspath(__file__))
    logo_path = os.path.join(current_dir, "assets", "intuigence.png")

    # Orange-accent theme overrides applied on top of gr.themes.Soft().
    css = """
    /* Theme colors */
    :root {
        --orange-primary: #ff6b2b;
        --orange-hover: #ff8651;
        --orange-light: rgba(255, 107, 43, 0.1);
    }

    /* Logo styling */
    .logo-container {
        padding: 10px 20px;
        margin-bottom: 10px;
        text-align: left;
        width: 100%;
        background: #1a1a1a; /* Match app background */
        border-bottom: 1px solid #3a3a3a;
    }
    .logo-container img {
        max-height: 40px;
        width: auto;
        display: inline-block !important;
    }
    /* Hide download and fullscreen buttons for logo */
    .logo-container .download-button,
    .logo-container .fullscreen-button {
        display: none !important;
    }
    /* Adjust main content padding */
    .main-content {
        padding-top: 10px;
    }
    /* Custom orange theme */
    .primary-button {
        background: var(--orange-primary) !important;
        color: white !important;
        border: none !important;
    }
    .primary-button:hover {
        background: var(--orange-hover) !important;
    }

    /* Tab styling */
    .tabs > .tab-nav > button.selected {
        border-color: var(--orange-primary) !important;
        color: var(--orange-primary) !important;
    }
    .tabs > .tab-nav > button:hover {
        border-color: var(--orange-hover) !important;
        color: var(--orange-hover) !important;
    }

    /* File upload button */
    .file-upload {
        background: var(--orange-primary) !important;
    }

    /* Progress bar */
    .progress-bar > div {
        background: var(--orange-primary) !important;
    }

    /* Tags and labels */
    .label-wrap {
        background: var(--orange-primary) !important;
    }

    /* Selected/active states */
    .selected, .active, .focused {
        border-color: var(--orange-primary) !important;
        color: var(--orange-primary) !important;
    }

    /* Links and interactive elements */
    a, .link, .interactive {
        color: var(--orange-primary) !important;
    }
    a:hover, .link:hover, .interactive:hover {
        color: var(--orange-hover) !important;
    }

    /* Input focus states */
    input:focus, textarea:focus {
        border-color: var(--orange-primary) !important;
        box-shadow: 0 0 0 1px var(--orange-light) !important;
    }

    /* Checkbox and radio */
    input[type="checkbox"]:checked, input[type="radio"]:checked {
        background-color: var(--orange-primary) !important;
        border-color: var(--orange-primary) !important;
    }
    """

    with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
        # Logo row (before main content)
        with gr.Row(elem_classes="logo-container"):
            gr.Image(
                value=logo_path,
                show_label=False,
                container=False,
                interactive=False,
                show_download_button=False,
                show_share_button=False,
                height=40
            )

        # State for storing file path
        # NOTE(review): file_path is declared but never wired to any event —
        # confirm whether it can be removed.
        file_path = gr.State()
        json_path = gr.State()  # path to the aggregated JSON, set by process_pnid

        # Main content row
        with gr.Row(elem_classes="main-content"):
            # Left column - File Upload & Processing
            with gr.Column(scale=3, elem_classes="column-panel"):
                file_output = gr.File(label="Upload P&ID Document")
                process_button = gr.Button(
                    "Process Document",
                    elem_classes="primary-button"  # Add custom class
                )
                progress_output = gr.Textbox(
                    label="Progress",
                    value="Waiting for document...",
                    interactive=False
                )

            # Center column - Preview Panel (one tab per pipeline stage)
            with gr.Column(scale=5, elem_classes="column-panel preview-panel"):
                with gr.Tabs() as tabs:
                    with gr.TabItem("P&ID"):
                        input_image = gr.Image(type="filepath", label="Original")
                    with gr.TabItem("Symbols"):
                        symbol_image = gr.Image(type="filepath", label="Detected Symbols")
                    with gr.TabItem("Tags"):
                        text_image = gr.Image(type="filepath", label="Detected Tags")
                    with gr.TabItem("Lines"):
                        line_image = gr.Image(type="filepath", label="Detected Lines")
                    with gr.TabItem("Aggregated"):
                        aggregated_image = gr.Image(type="filepath", label="Aggregated Results")
                    with gr.TabItem("Knowledge Graph"):
                        graph_image = gr.Image(type="filepath", label="Knowledge Graph")

            # Right column - Chat Interface
            with gr.Column(scale=4, elem_classes="column-panel chat-panel", elem_id="chat-panel"):
                chat_history = gr.Chatbot(
                    [],
                    elem_classes="chat-history",
                    height=400,
                    show_label=False,
                    type="messages",  # OpenAI-style role/content dicts
                    elem_id="chat-history"
                )
                with gr.Row():
                    chat_input = gr.Textbox(
                        placeholder="Ask me about the P&ID...",
                        show_label=False,
                        container=False
                    )
                    chat_button = gr.Button(
                        "Send",
                        elem_classes="primary-button"  # Add custom class
                    )

                # Chat callback: appends user + assistant messages (role/content
                # dicts) and clears the input box. Errors are shown as an
                # apologetic assistant message instead of raising.
                def handle_chat(user_message, chat_history, json_path):
                    if not user_message:
                        return "", chat_history

                    # Add user message
                    chat_history = chat_history + [{"role": "user", "content": user_message}]

                    try:
                        # Get assistant response
                        response = get_assistant_response(user_message, json_path)
                        # Add assistant response
                        chat_history = chat_history + [{"role": "assistant", "content": response}]
                    except Exception as e:
                        logger.error(f"Error in chat response: {str(e)}")
                        chat_history = chat_history + [
                            {"role": "assistant", "content": "I apologize, but I encountered an error processing your request."}
                        ]

                    return "", chat_history

                # Connect UI elements (Enter key and Send button do the same thing).
                chat_input.submit(handle_chat, [chat_input, chat_history, json_path], [chat_input, chat_history])
                chat_button.click(handle_chat, [chat_input, chat_history, json_path], [chat_input, chat_history])

                # process_pnid is a generator; each yield fills these 9 outputs in order.
                process_button.click(
                    process_pnid,
                    inputs=[file_output],
                    outputs=[
                        progress_output,  # Progress text (0)
                        input_image,      # P&ID (1)
                        symbol_image,     # Symbols (2)
                        text_image,       # Tags (3)
                        line_image,       # Lines (4)
                        aggregated_image, # Aggregated (5)
                        graph_image,      # Graph (6)
                        chat_history,     # Chat (7)
                        json_path         # State (8)
                    ],
                    show_progress="hidden"  # Hide progress in tabs
                )

    return demo
|
| 789 |
+
|
| 790 |
+
|
| 791 |
+
def main():
    """Entry point: ensure model weights are present, then launch the Gradio UI."""
    # Copy/download models on first run (presence of the YOLO weight is the sentinel).
    if not Path('models/yolo/yolov8n.pt').exists():
        copy_local_models()

    app = create_ui()
    # Local development settings; no Hugging Face Spaces conditional.
    app.launch(
        server_name="0.0.0.0",
        server_port=7861,  # Changed from 7860
        share=True,
    )


if __name__ == "__main__":
    main()
|
graph_construction.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import networkx as nx
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
import logging
|
| 7 |
+
import traceback
|
| 8 |
+
from storage import StorageFactory
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def construct_graph_network(data: dict, validation_results_path: str, results_dir: str, storage=None):
    """Construct a network graph from aggregated detection data and render it.

    Args:
        data: Aggregated detection dict with 'nodes' and 'edges' lists; each node
            carries 'id', 'type', and either 'coords' (connection points) or
            'bbox' (symbols / text).
        validation_results_path: Path to validation results. NOTE(review): this
            parameter is currently unused; kept for caller compatibility.
        results_dir: Directory where the PNG visualization and graph JSON are written.
        storage: Optional storage backend; defaults to StorageFactory.get_storage().
            NOTE(review): also currently unused beyond initialization.

    Returns:
        (graph, positions, figure) on success, or (None, None, None) on any error.
    """
    try:
        # Use provided storage or get a new one
        if storage is None:
            storage = StorageFactory.get_storage()

        G = nx.Graph()
        pos = {}  # node_id -> (x, y) layout positions taken from the detections

        # Nodes: connection points carry explicit coords; symbols/text use bbox centers.
        for node in data.get('nodes', []):
            node_id = node['id']
            node_type = node['type']

            if node_type == 'connection_point':
                pos[node_id] = (node['coords']['x'], node['coords']['y'])
            else:  # symbol or text
                bbox = node['bbox']
                pos[node_id] = (
                    (bbox['xmin'] + bbox['xmax']) / 2,
                    (bbox['ymin'] + bbox['ymax']) / 2
                )

            # Add node with all its properties attached as attributes.
            G.add_node(node_id, **node)

        # Edges, with any per-edge properties attached as attributes.
        for edge in data.get('edges', []):
            G.add_edge(
                edge['source'],
                edge['target'],
                **edge.get('properties', {})
            )

        # ---- Visualization ----
        fig = plt.figure(figsize=(20, 20))

        # Color nodes by type.
        node_colors = []
        for node in G.nodes():
            node_type = G.nodes[node]['type']
            if node_type == 'symbol':
                node_colors.append('lightblue')
            elif node_type == 'text':
                node_colors.append('lightgreen')
            else:  # connection_point
                node_colors.append('lightgray')

        nx.draw_networkx_nodes(G, pos, node_color=node_colors, node_size=500)
        nx.draw_networkx_edges(G, pos, edge_color='gray', width=1)

        # Short labels: S:<class> for symbols, T:<text> for text, C:<point_type>.
        labels = {}
        for node in G.nodes():
            node_data = G.nodes[node]
            if node_data['type'] == 'symbol':
                labels[node] = f"S:{node_data.get('properties', {}).get('class', '')}"
            elif node_data['type'] == 'text':
                content = node_data.get('content', '')
                labels[node] = f"T:{content[:10]}..." if len(content) > 10 else f"T:{content}"
            else:
                # ROBUSTNESS: use .get() on both levels so a connection point
                # without a 'properties' key no longer raises KeyError.
                labels[node] = f"C:{node_data.get('properties', {}).get('point_type', '')}"

        nx.draw_networkx_labels(G, pos, labels, font_size=8)

        plt.title("P&ID Knowledge Graph")
        plt.axis('off')

        # Save the visualization.
        graph_image_path = os.path.join(results_dir, f"{Path(data.get('image_path', 'graph')).stem}_graph.png")
        plt.savefig(graph_image_path, bbox_inches='tight', dpi=300)
        plt.close(fig)

        # Save graph data as JSON for future use.
        graph_json_path = os.path.join(results_dir, f"{Path(data.get('image_path', 'graph')).stem}_graph_data.json")
        with open(graph_json_path, 'w') as f:
            json.dump(nx.node_link_data(G), f, indent=2)

        # BUGFIX: the previous version returned plt.gcf() *after* plt.close(),
        # which creates and returns a fresh, empty figure. Return the figure
        # that was actually drawn instead.
        return G, pos, fig

    except Exception as e:
        logger.error(f"Error in construct_graph_network: {str(e)}")
        traceback.print_exc()
        return None, None, None
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
if __name__ == "__main__":
    # Standalone smoke test: build and display a graph from a sample aggregated
    # detection file. Silently does nothing when the sample file is absent.
    test_data_path = "results/test_aggregated.json"
    if os.path.exists(test_data_path):
        with open(test_data_path, 'r') as f:
            test_data = json.load(f)

        # validation path / results dir are hard-coded for this local test only.
        G, pos, fig = construct_graph_network(
            test_data,
            "results/validation.json",
            "results"
        )
        if fig:
            plt.show()
|
graph_processor.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import networkx as nx
|
| 3 |
+
import numpy as np
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
import traceback
|
| 6 |
+
import uuid
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def create_connected_graph(input_data):
    """Build a NetworkX graph, its layout, and a matplotlib figure from aggregated P&ID data.

    Expects a dict carrying 'symbols', 'texts', 'lines', 'nodes' and 'edges'.
    Symbol and text entries need a 'bbox' ({xmin, ymin, xmax, ymax}) to be placed;
    edges are added only between nodes that made it into the graph.

    Returns:
        (graph, positions, figure) on success, or (None, None, None) on failure.
    """
    try:
        # --- Input validation ---
        if not isinstance(input_data, dict):
            raise ValueError("Invalid input data format")

        required_keys = ['symbols', 'texts', 'lines', 'nodes', 'edges']
        if not all(key in input_data for key in required_keys):
            raise ValueError(f"Missing required keys in input data. Expected: {required_keys}")

        graph = nx.Graph()
        layout = {}  # node id -> (x, y), taken from bbox centers

        def _bbox_center(box):
            # Geometric center of an {xmin, ymin, xmax, ymax} box.
            return ((box['xmin'] + box['xmax']) / 2, (box['ymin'] + box['ymax']) / 2)

        # --- Symbol nodes ---
        for symbol in input_data['symbols']:
            box = symbol.get('bbox', [])
            sid = symbol.get('id', str(uuid.uuid4()))

            if box:
                layout[sid] = _bbox_center(box)
                graph.add_node(
                    sid,
                    type='symbol',
                    class_name=symbol.get('class', ''),
                    bbox=box,
                    confidence=symbol.get('confidence', 0.0),
                )

        # --- Text nodes ---
        for item in input_data['texts']:
            box = item.get('bbox', [])
            tid = item.get('id', str(uuid.uuid4()))

            if box:
                layout[tid] = _bbox_center(box)
                graph.add_node(
                    tid,
                    type='text',
                    text=item.get('text', ''),
                    bbox=box,
                    confidence=item.get('confidence', 0.0),
                )

        # --- Edges (only between nodes present in the graph) ---
        for edge in input_data['edges']:
            src = edge.get('source')
            dst = edge.get('target')
            if src and dst and src in graph and dst in graph:
                graph.add_edge(
                    src,
                    dst,
                    type=edge.get('type', 'connection'),
                    properties=edge.get('properties', {}),
                )

        # --- Visualization at detection coordinates ---
        plt.figure(figsize=(20, 20))

        palette = ['lightblue' if graph.nodes[n]['type'] == 'symbol' else 'lightgreen'
                   for n in graph.nodes()]
        nx.draw_networkx_nodes(graph, layout, node_color=palette, node_size=500)
        nx.draw_networkx_edges(graph, layout, edge_color='gray', width=1)

        # Short labels: S:<class> for symbols, T:<truncated text> for text.
        labels = {}
        for n in graph.nodes():
            attrs = graph.nodes[n]
            if attrs['type'] == 'symbol':
                labels[n] = f"S:{attrs['class_name']}"
            else:
                text = attrs.get('text', '')
                labels[n] = f"T:{text[:10]}..." if len(text) > 10 else f"T:{text}"

        nx.draw_networkx_labels(graph, layout, labels, font_size=8)

        plt.title("P&ID Network Graph")
        plt.axis('off')

        return graph, layout, plt.gcf()

    except Exception as e:
        print(f"Error in create_connected_graph: {str(e)}")
        traceback.print_exc()
        return None, None, None
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
if __name__ == "__main__":
    # Standalone smoke test: load a sample aggregated result and display the graph.
    # NOTE(review): assumes 'results/0_aggregated.json' exists; raises otherwise.
    with open('results/0_aggregated.json') as f:
        data = json.load(f)

    G, pos, fig = create_connected_graph(data)
    if fig:
        plt.show()
|
graph_visualization.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import networkx as nx
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
import os
|
| 5 |
+
from pprint import pprint
|
| 6 |
+
import uuid
|
| 7 |
+
import argparse
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
|
| 11 |
+
def create_graph_visualization(json_path: str, output_dir: str, base_name: str, save_plot: bool = True) -> dict:
    """Create a graph visualization using actual node coordinates from the aggregated JSON.

    Args:
        json_path: Path to the aggregated JSON file containing 'nodes' and 'edges'.
        output_dir: Directory where the PNG is written when save_plot is True.
        base_name: Base file name for outputs (a trailing '_aggregated' is stripped).
        save_plot: When True, render + save the PNG and include stats in the result.

    Returns:
        On success: {'success': True, 'graph': G} plus, when plotted,
        'image_path' and a node/edge 'stats' dict.
        On failure: {'success': False, 'error': <message>}.
    """
    try:
        # Remove '_aggregated' suffix if present
        if base_name.endswith('_aggregated'):
            base_name = base_name[:-len('_aggregated')]

        print("\nLoading JSON data...")
        with open(json_path, 'r') as f:
            data = json.load(f)

        G = nx.Graph()
        pos = {}            # node_id -> (x, y) for the fixed layout
        valid_nodes = []
        invalid_nodes = []

        # First pass - collect nodes that have an id and usable coordinates.
        print("\nValidating nodes...")
        for node in tqdm(data.get('nodes', []), desc="Validating"):
            try:
                node_id = str(node.get('id', ''))
                # BUGFIX: the previous check (`if node_id and x and y`) used
                # truthiness, which rejected any node lying on an axis
                # (x == 0 or y == 0). Require the coordinate keys to be present
                # and numerically convertible instead.
                if node_id and 'x' in node and 'y' in node:
                    pos[node_id] = (float(node['x']), float(node['y']))
                    valid_nodes.append(node)
                else:
                    invalid_nodes.append(node)
            except (ValueError, TypeError):
                # Non-numeric coordinates: count the node as invalid and move on.
                invalid_nodes.append(node)
                continue

        print(f"\nFound {len(valid_nodes)} valid nodes and {len(invalid_nodes)} invalid nodes")

        # Add valid nodes with their attributes.
        print("\nAdding valid nodes...")
        for node in tqdm(valid_nodes, desc="Nodes"):
            node_id = str(node.get('id', ''))
            attrs = {
                'type': node.get('type', ''),
                'label': node.get('label', ''),
                'x': float(node.get('x', 0)),
                'y': float(node.get('y', 0))
            }
            G.add_node(node_id, **attrs)

        # Add edges only when both endpoints survived node validation.
        print("\nAdding valid edges...")
        valid_edges = []
        invalid_edges = []

        for edge in tqdm(data.get('edges', []), desc="Edges"):
            try:
                start_id = str(edge.get('start_point', ''))
                end_id = str(edge.get('end_point', ''))

                if start_id in pos and end_id in pos:
                    valid_edges.append(edge)
                    attrs = {
                        'type': edge.get('type', ''),
                        'weight': edge.get('weight', 1.0)
                    }
                    G.add_edge(start_id, end_id, **attrs)
                else:
                    invalid_edges.append(edge)
            except Exception:
                invalid_edges.append(edge)
                continue

        print(f"\nFound {len(valid_edges)} valid edges and {len(invalid_edges)} invalid edges")

        if save_plot:
            print("\nGenerating visualization...")
            plt.figure(figsize=(20, 20))

            print("Drawing graph elements...")
            with tqdm(total=3, desc="Drawing") as pbar:
                # Draw nodes at their fixed detection coordinates.
                nx.draw_networkx_nodes(G, pos,
                                       node_color='lightblue',
                                       node_size=100)
                pbar.update(1)

                # Draw edges
                nx.draw_networkx_edges(G, pos)
                pbar.update(1)

                # Save plot (the figure is closed afterwards; callers that want
                # to display it should re-render from the saved PNG).
                image_path = os.path.join(output_dir, f"{base_name}_graph_visualization.png")
                plt.savefig(image_path, bbox_inches='tight', dpi=300)
                plt.close()
                pbar.update(1)

            print(f"\nVisualization saved to: {image_path}")
            return {
                'success': True,
                'image_path': image_path,
                'graph': G,
                'stats': {
                    'valid_nodes': len(valid_nodes),
                    'invalid_nodes': len(invalid_nodes),
                    'valid_edges': len(valid_edges),
                    'invalid_edges': len(invalid_edges)
                }
            }

        return {
            'success': True,
            'graph': G
        }

    except Exception as e:
        print(f"\nError creating graph: {str(e)}")
        return {
            'success': False,
            'error': str(e)
        }
|
| 130 |
+
|
| 131 |
+
if __name__ == "__main__":
    """Test the graph visualization independently"""

    # CLI: choose the aggregated JSON to visualize and where to put the output.
    parser = argparse.ArgumentParser(description='Create and visualize graph from aggregated JSON')
    parser.add_argument('--json_path', type=str, default="results/002_page_1_aggregated.json",
                        help='Path to aggregated JSON file')
    parser.add_argument('--output_dir', type=str, default="results",
                        help='Directory to save outputs')
    parser.add_argument('--show', action='store_true',
                        help='Show the plot interactively')

    args = parser.parse_args()

    # Verify input file exists before doing any work.
    if not os.path.exists(args.json_path):
        print(f"Error: Could not find input file {args.json_path}")
        exit(1)

    # Create output directory if it doesn't exist
    os.makedirs(args.output_dir, exist_ok=True)

    # Get base name from input file and remove '_aggregated' suffix
    # (create_graph_visualization strips it again defensively).
    base_name = Path(args.json_path).stem
    if base_name.endswith('_aggregated'):
        base_name = base_name[:-len('_aggregated')]

    print(f"\nProcessing:")
    print(f"Input: {args.json_path}")
    print(f"Output: {args.output_dir}/{base_name}_graph_visualization.png")

    try:
        # Create visualization
        result = create_graph_visualization(
            json_path=args.json_path,
            output_dir=args.output_dir,
            base_name=base_name,
            save_plot=True
        )

        if result['success']:
            print(f"\nSuccess! Graph visualization saved to: {result['image_path']}")
            if args.show:
                # NOTE(review): create_graph_visualization closes its figure
                # after saving, so plt.show() here may display nothing -- confirm.
                plt.show()
        else:
            print(f"\nError: {result['error']}")

    except Exception as e:
        print(f"\nError during visualization: {str(e)}")
        raise
|
line_detection_ai.py
ADDED
|
@@ -0,0 +1,413 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from base import BaseDetector, BaseDetectionPipeline
|
| 2 |
+
from utils import *
|
| 3 |
+
from config import (
|
| 4 |
+
ImageConfig,
|
| 5 |
+
SymbolConfig,
|
| 6 |
+
TagConfig,
|
| 7 |
+
LineConfig,
|
| 8 |
+
PointConfig,
|
| 9 |
+
JunctionConfig
|
| 10 |
+
)
|
| 11 |
+
from detectors import (
|
| 12 |
+
LineDetector,
|
| 13 |
+
PointDetector,
|
| 14 |
+
JunctionDetector,
|
| 15 |
+
SymbolDetector,
|
| 16 |
+
TagDetector
|
| 17 |
+
)
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
from storage import StorageFactory
|
| 20 |
+
from common import DetectionResult
|
| 21 |
+
from detection_schema import DetectionContext, JunctionType
|
| 22 |
+
from typing import List, Tuple, Optional, Dict
|
| 23 |
+
import torch
|
| 24 |
+
import numpy as np
|
| 25 |
+
import cv2
|
| 26 |
+
import os
|
| 27 |
+
import logging
|
| 28 |
+
|
| 29 |
+
logger = logging.getLogger(__name__)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class DiagramDetectionPipeline:
|
| 33 |
+
"""
|
| 34 |
+
Pipeline that runs multiple detectors (line, point, junction, etc.) on an image,
|
| 35 |
+
and keeps a shared DetectionContext in memory.
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
def __init__(self,
|
| 39 |
+
tag_detector: Optional[BaseDetector],
|
| 40 |
+
symbol_detector: Optional[BaseDetector],
|
| 41 |
+
line_detector: Optional[BaseDetector],
|
| 42 |
+
point_detector: Optional[BaseDetector],
|
| 43 |
+
junction_detector: Optional[BaseDetector],
|
| 44 |
+
storage: StorageInterface,
|
| 45 |
+
debug_handler: Optional[DebugHandler] = None,
|
| 46 |
+
transformer: Optional[CoordinateTransformer] = None):
|
| 47 |
+
"""
|
| 48 |
+
You can pass None for detectors you don't need.
|
| 49 |
+
"""
|
| 50 |
+
# super().__init__(storage=storage, debug_handler=debug_handler)
|
| 51 |
+
self.storage = storage
|
| 52 |
+
self.debug_handler = debug_handler
|
| 53 |
+
self.tag_detector = tag_detector
|
| 54 |
+
self.symbol_detector = symbol_detector
|
| 55 |
+
self.line_detector = line_detector
|
| 56 |
+
self.point_detector = point_detector
|
| 57 |
+
self.junction_detector = junction_detector
|
| 58 |
+
self.transformer = transformer or CoordinateTransformer()
|
| 59 |
+
|
| 60 |
+
def _load_image(self, image_path: str) -> np.ndarray:
|
| 61 |
+
"""Load image with validation."""
|
| 62 |
+
image_data = self.storage.load_file(image_path)
|
| 63 |
+
image = cv2.imdecode(np.frombuffer(image_data, np.uint8), cv2.IMREAD_COLOR)
|
| 64 |
+
if image is None:
|
| 65 |
+
raise ValueError(f"Failed to load image from {image_path}")
|
| 66 |
+
return image
|
| 67 |
+
|
| 68 |
+
def _crop_to_roi(self, image: np.ndarray, roi: Optional[list]) -> Tuple[np.ndarray, Tuple[int, int]]:
|
| 69 |
+
"""Crop to ROI if provided, else return full image."""
|
| 70 |
+
if roi is not None and len(roi) == 4:
|
| 71 |
+
x_min, y_min, x_max, y_max = roi
|
| 72 |
+
return image[y_min:y_max, x_min:x_max], (x_min, y_min)
|
| 73 |
+
return image, (0, 0)
|
| 74 |
+
|
| 75 |
+
def _remove_symbol_tag_bboxes(self, image: np.ndarray, context: DetectionContext) -> np.ndarray:
|
| 76 |
+
"""Fill symbol & tag bounding boxes with white to avoid line detection picking them up."""
|
| 77 |
+
masked = image.copy()
|
| 78 |
+
for sym in context.symbols.values():
|
| 79 |
+
cv2.rectangle(masked,
|
| 80 |
+
(sym.bbox.xmin, sym.bbox.ymin),
|
| 81 |
+
(sym.bbox.xmax, sym.bbox.ymax),
|
| 82 |
+
(255, 255, 255), # White
|
| 83 |
+
thickness=-1)
|
| 84 |
+
|
| 85 |
+
for tg in context.tags.values():
|
| 86 |
+
cv2.rectangle(masked,
|
| 87 |
+
(tg.bbox.xmin, tg.bbox.ymin),
|
| 88 |
+
(tg.bbox.xmax, tg.bbox.ymax),
|
| 89 |
+
(255, 255, 255),
|
| 90 |
+
thickness=-1)
|
| 91 |
+
return masked
|
| 92 |
+
|
| 93 |
+
def run(
|
| 94 |
+
self,
|
| 95 |
+
image_path: str,
|
| 96 |
+
output_dir: str,
|
| 97 |
+
config
|
| 98 |
+
) -> DetectionResult:
|
| 99 |
+
"""
|
| 100 |
+
Main pipeline steps (in local coords):
|
| 101 |
+
1) Load + crop image
|
| 102 |
+
2) Detect symbols & tags
|
| 103 |
+
3) Make a copy for final debug images
|
| 104 |
+
4) White out symbol/tag bounding boxes
|
| 105 |
+
5) Detect lines, points, junctions
|
| 106 |
+
6) Save final JSON
|
| 107 |
+
7) Generate debug images with various combinations
|
| 108 |
+
"""
|
| 109 |
+
try:
|
| 110 |
+
with self.debug_handler.track_performance("total_processing"):
|
| 111 |
+
# 1) Load & crop
|
| 112 |
+
image = self._load_image(image_path)
|
| 113 |
+
# cropped_image, roi_offset = self._crop_to_roi(image, config.roi)
|
| 114 |
+
|
| 115 |
+
# 2) Create fresh context
|
| 116 |
+
context = DetectionContext()
|
| 117 |
+
|
| 118 |
+
# 3) Detect symbols
|
| 119 |
+
with self.debug_handler.track_performance("symbol_detection"):
|
| 120 |
+
self.symbol_detector.detect(
|
| 121 |
+
image,
|
| 122 |
+
context=context,
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
# 4) Detect tags
|
| 126 |
+
with self.debug_handler.track_performance("tag_detection"):
|
| 127 |
+
self.tag_detector.detect(
|
| 128 |
+
image,
|
| 129 |
+
context=context,
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
# Make a copy of the cropped image for final debug combos
|
| 133 |
+
debug_cropped = image.copy()
|
| 134 |
+
|
| 135 |
+
# 5) White-out symbol/tag bboxes in the original cropped image
|
| 136 |
+
cropped_image = self._remove_symbol_tag_bboxes(image, context)
|
| 137 |
+
|
| 138 |
+
# 6) Detect lines
|
| 139 |
+
with self.debug_handler.track_performance("line_detection"):
|
| 140 |
+
self.line_detector.detect(cropped_image, context=context)
|
| 141 |
+
|
| 142 |
+
# 7) Detect points
|
| 143 |
+
if self.point_detector:
|
| 144 |
+
with self.debug_handler.track_performance("point_detection"):
|
| 145 |
+
self.point_detector.detect(cropped_image, context=context)
|
| 146 |
+
|
| 147 |
+
# 8) Detect junctions
|
| 148 |
+
if self.junction_detector:
|
| 149 |
+
with self.debug_handler.track_performance("junction_detection"):
|
| 150 |
+
self.junction_detector.detect(cropped_image, context=context)
|
| 151 |
+
|
| 152 |
+
# 9) Save final JSON & any final images
|
| 153 |
+
output_paths = self._persist_results(output_dir, image_path, context)
|
| 154 |
+
|
| 155 |
+
# 10) Save debug images in local coords using debug_cropped
|
| 156 |
+
self._save_all_combinations(debug_cropped, context, output_dir, image_path)
|
| 157 |
+
|
| 158 |
+
return DetectionResult(
|
| 159 |
+
success=True,
|
| 160 |
+
processing_time=self.debug_handler.metrics.get('total_processing', 0),
|
| 161 |
+
json_path=output_paths.get('json_path'),
|
| 162 |
+
image_path=output_paths.get('image_path') # Now returning the annotated image path
|
| 163 |
+
)
|
| 164 |
+
|
| 165 |
+
except Exception as e:
|
| 166 |
+
logger.error(f"Processing failed: {str(e)}")
|
| 167 |
+
return DetectionResult(
|
| 168 |
+
success=False,
|
| 169 |
+
error=str(e)
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
# ------------------------------------------------
|
| 173 |
+
# HELPER FUNCTIONS
|
| 174 |
+
# ------------------------------------------------
|
| 175 |
+
def _persist_results(self, output_dir: str, image_path: str, context: DetectionContext) -> dict:
|
| 176 |
+
"""Saves only JSON and line detection visualization"""
|
| 177 |
+
base_name = Path(image_path).stem
|
| 178 |
+
if base_name.startswith('display_'):
|
| 179 |
+
base_name = base_name[8:]
|
| 180 |
+
|
| 181 |
+
# Save JSON
|
| 182 |
+
json_path = Path(output_dir) / f"{base_name}_detected_lines.json"
|
| 183 |
+
context_json_str = context.to_json(indent=2)
|
| 184 |
+
self.storage.save_file(str(json_path), context_json_str.encode('utf-8'))
|
| 185 |
+
|
| 186 |
+
# Save line detection visualization using input image
|
| 187 |
+
annotated = self._draw_objects(
|
| 188 |
+
self._load_image(image_path), # Use input image instead of output
|
| 189 |
+
context,
|
| 190 |
+
draw_lines=True,
|
| 191 |
+
draw_points=False,
|
| 192 |
+
draw_symbols=False,
|
| 193 |
+
draw_junctions=False,
|
| 194 |
+
draw_tags=False
|
| 195 |
+
)
|
| 196 |
+
|
| 197 |
+
# Save visualization
|
| 198 |
+
image_path = Path(output_dir) / f"{base_name}_detected_lines.png"
|
| 199 |
+
_, encoded = cv2.imencode('.png', annotated)
|
| 200 |
+
self.storage.save_file(str(image_path), encoded.tobytes())
|
| 201 |
+
|
| 202 |
+
return {
|
| 203 |
+
"json_path": str(json_path),
|
| 204 |
+
"image_path": str(image_path)
|
| 205 |
+
}
|
| 206 |
+
|
| 207 |
+
def _save_all_combinations(self, local_image: np.ndarray, context: DetectionContext,
|
| 208 |
+
output_dir: str, image_path: str) -> None:
|
| 209 |
+
"""Only save line detection visualization"""
|
| 210 |
+
base_name = Path(image_path).stem
|
| 211 |
+
if base_name.startswith('display_'):
|
| 212 |
+
base_name = base_name[8:]
|
| 213 |
+
|
| 214 |
+
# Only save line detection visualization
|
| 215 |
+
annotated = self._draw_objects(local_image, context,
|
| 216 |
+
draw_symbols=False,
|
| 217 |
+
draw_tags=False,
|
| 218 |
+
draw_lines=True,
|
| 219 |
+
draw_points=False,
|
| 220 |
+
draw_junctions=False)
|
| 221 |
+
|
| 222 |
+
save_name = f"{base_name}_detected_lines.png"
|
| 223 |
+
save_path = Path(output_dir) / save_name
|
| 224 |
+
_, encoded = cv2.imencode('.png', annotated)
|
| 225 |
+
self.storage.save_file(str(save_path), encoded.tobytes())
|
| 226 |
+
|
| 227 |
+
def _draw_objects(self, base_image: np.ndarray, context: DetectionContext,
|
| 228 |
+
draw_lines: bool = True, draw_points: bool = True,
|
| 229 |
+
draw_symbols: bool = True, draw_junctions: bool = True,
|
| 230 |
+
draw_tags: bool = True) -> np.ndarray:
|
| 231 |
+
"""Draw detection results on a copy of base_image in local coords."""
|
| 232 |
+
annotated = base_image.copy()
|
| 233 |
+
|
| 234 |
+
# Lines
|
| 235 |
+
if draw_lines:
|
| 236 |
+
for ln in context.lines.values():
|
| 237 |
+
cv2.line(annotated,
|
| 238 |
+
(ln.start.coords.x, ln.start.coords.y),
|
| 239 |
+
(ln.end.coords.x, ln.end.coords.y),
|
| 240 |
+
(0, 255, 0), # green
|
| 241 |
+
2)
|
| 242 |
+
|
| 243 |
+
# Points
|
| 244 |
+
if draw_points:
|
| 245 |
+
for pt in context.points.values():
|
| 246 |
+
cv2.circle(annotated,
|
| 247 |
+
(pt.coords.x, pt.coords.y),
|
| 248 |
+
3,
|
| 249 |
+
(0, 0, 255), # red
|
| 250 |
+
-1)
|
| 251 |
+
|
| 252 |
+
# Symbols
|
| 253 |
+
if draw_symbols:
|
| 254 |
+
for sym in context.symbols.values():
|
| 255 |
+
cv2.rectangle(annotated,
|
| 256 |
+
(sym.bbox.xmin, sym.bbox.ymin),
|
| 257 |
+
(sym.bbox.xmax, sym.bbox.ymax),
|
| 258 |
+
(255, 255, 0), # cyan
|
| 259 |
+
2)
|
| 260 |
+
cv2.circle(annotated,
|
| 261 |
+
(sym.center.x, sym.center.y),
|
| 262 |
+
4,
|
| 263 |
+
(255, 0, 255), # magenta
|
| 264 |
+
-1)
|
| 265 |
+
|
| 266 |
+
# Junctions
|
| 267 |
+
if draw_junctions:
|
| 268 |
+
for jn in context.junctions.values():
|
| 269 |
+
if jn.junction_type == JunctionType.T:
|
| 270 |
+
color = (0, 165, 255) # orange
|
| 271 |
+
elif jn.junction_type == JunctionType.L:
|
| 272 |
+
color = (255, 0, 255) # magenta
|
| 273 |
+
else: # END
|
| 274 |
+
color = (0, 0, 255) # red
|
| 275 |
+
cv2.circle(annotated,
|
| 276 |
+
(jn.center.x, jn.center.y),
|
| 277 |
+
5,
|
| 278 |
+
color,
|
| 279 |
+
-1)
|
| 280 |
+
|
| 281 |
+
# Tags
|
| 282 |
+
if draw_tags:
|
| 283 |
+
for tg in context.tags.values():
|
| 284 |
+
cv2.rectangle(annotated,
|
| 285 |
+
(tg.bbox.xmin, tg.bbox.ymin),
|
| 286 |
+
(tg.bbox.xmax, tg.bbox.ymax),
|
| 287 |
+
(128, 0, 128), # purple
|
| 288 |
+
2)
|
| 289 |
+
cv2.putText(annotated,
|
| 290 |
+
tg.text,
|
| 291 |
+
(tg.bbox.xmin, tg.bbox.ymin - 5),
|
| 292 |
+
cv2.FONT_HERSHEY_SIMPLEX,
|
| 293 |
+
0.5,
|
| 294 |
+
(128, 0, 128),
|
| 295 |
+
1)
|
| 296 |
+
|
| 297 |
+
return annotated
|
| 298 |
+
|
| 299 |
+
def detect_lines(self, image_path: str, output_dir: str, config: Optional[Dict] = None) -> Dict:
    """Legacy interface for line detection.

    Builds a minimal pipeline containing only the line detector and runs
    it on a single image.

    Args:
        image_path: Path of the image to analyze.
        output_dir: Directory where the pipeline writes its results.
        config: Accepted for backward compatibility but currently IGNORED —
            the run below always uses a fresh ImageConfig().
            NOTE(review): confirm no caller relies on passing options here.

    Returns:
        The pipeline result object (whatever DiagramDetectionPipeline.run returns).
    """
    storage = StorageFactory.get_storage()
    debug_handler = DebugHandler(enabled=True, storage=storage)

    # This legacy path always runs on CPU and only wires in the line
    # detector; every other detector slot is disabled (None).
    line_detector = LineDetector(
        config=LineConfig(),
        model_path="models/deeplsd_md.tar",
        device=torch.device("cpu"),
        debug_handler=debug_handler
    )

    pipeline = DiagramDetectionPipeline(
        tag_detector=None,
        symbol_detector=None,
        line_detector=line_detector,
        point_detector=None,
        junction_detector=None,
        storage=storage,
        debug_handler=debug_handler
    )

    result = pipeline.run(image_path, output_dir, ImageConfig())
    return result
|
| 323 |
+
|
| 324 |
+
def _validate_and_normalize_coordinates(self, points):
|
| 325 |
+
"""Validate and normalize coordinates to image space"""
|
| 326 |
+
valid_points = []
|
| 327 |
+
for point in points:
|
| 328 |
+
x, y = point['x'], point['y']
|
| 329 |
+
# Validate coordinates are within image bounds
|
| 330 |
+
if 0 <= x <= self.image_width and 0 <= y <= self.image_height:
|
| 331 |
+
# Normalize coordinates if needed
|
| 332 |
+
valid_points.append({
|
| 333 |
+
'x': int(x),
|
| 334 |
+
'y': int(y),
|
| 335 |
+
'type': point.get('type', 'unknown'),
|
| 336 |
+
'confidence': point.get('confidence', 1.0)
|
| 337 |
+
})
|
| 338 |
+
return valid_points
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
if __name__ == "__main__":
    # 1) Initialize shared components
    storage = StorageFactory.get_storage()
    debug_handler = DebugHandler(enabled=True, storage=storage)

    # 2) Raw model configuration passed through to the line detector
    conf = {
        "detect_lines": True,
        "line_detection_params": {
            "merge": True,
            "filtering": True,
            "grad_thresh": 3,
            "grad_nfa": True
        }
    }

    # 3) Per-detector configuration objects
    line_config = LineConfig()
    point_config = PointConfig()
    junction_config = JunctionConfig()
    symbol_config = SymbolConfig()
    tag_config = TagConfig()

    # ========================== Detectors ========================== #
    symbol_detector = SymbolDetector(
        config=symbol_config,
        debug_handler=debug_handler
    )

    tag_detector = TagDetector(
        config=tag_config,
        debug_handler=debug_handler
    )

    # FIX: the original hard-coded torch.device("cuda"), which raises at
    # model-load time on CPU-only hosts (the Space runtime error is consistent
    # with this). Select CUDA only when it is actually available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    line_detector = LineDetector(
        config=line_config,
        model_path="models/deeplsd_md.tar",
        model_config=conf,
        device=device,
        debug_handler=debug_handler
    )

    point_detector = PointDetector(
        config=point_config,
        debug_handler=debug_handler)

    junction_detector = JunctionDetector(
        config=junction_config,
        debug_handler=debug_handler
    )

    # 4) Create pipeline with the full detector stack
    pipeline = DiagramDetectionPipeline(
        tag_detector=tag_detector,
        symbol_detector=symbol_detector,
        line_detector=line_detector,
        point_detector=point_detector,
        junction_detector=junction_detector,
        storage=storage,
        debug_handler=debug_handler
    )

    # 5) Run pipeline on a sample image
    result = pipeline.run(
        image_path="samples/images/0.jpg",
        output_dir="results/",
        config=ImageConfig()
    )

    if result.success:
        logger.info(f"Pipeline succeeded! See JSON at {result.json_path}")
    else:
        logger.error(f"Pipeline failed: {result.error}")
|
line_detectors.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
from typing import Dict, List, Optional, Tuple
|
| 5 |
+
from loguru import logger
|
| 6 |
+
|
| 7 |
+
# Check if DeepLSD is available
|
| 8 |
+
try:
|
| 9 |
+
from deeplsd.models.deeplsd_inference import DeepLSD
|
| 10 |
+
DEEPLSD_AVAILABLE = True
|
| 11 |
+
except ImportError:
|
| 12 |
+
DEEPLSD_AVAILABLE = False
|
| 13 |
+
logger.warning("DeepLSD not available, falling back to OpenCV")
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class OpenCVLineDetector:
    """Fallback line detector based on OpenCV's probabilistic Hough transform."""

    def __init__(self):
        # Tunable HoughLinesP parameters.
        self.params = {
            'threshold': 50,
            'minLineLength': 50,
            'maxLineGap': 10
        }

    def detect(self, image: np.ndarray) -> Dict:
        """Detect line segments via Canny edges + HoughLinesP.

        Args:
            image: BGR or already-grayscale image.

        Returns:
            {'lines': [...]} where each entry has float x1/y1/x2/y2 and a
            fixed confidence of 1.0 (Hough provides no per-line score).
        """
        # Grayscale conversion only when a 3-channel image is supplied.
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if len(image.shape) == 3 else image

        edge_map = cv2.Canny(gray, 50, 150, apertureSize=3)
        segments = cv2.HoughLinesP(
            edge_map, 1, np.pi / 180,
            threshold=self.params['threshold'],
            minLineLength=self.params['minLineLength'],
            maxLineGap=self.params['maxLineGap']
        )

        # HoughLinesP returns None when nothing was found.
        if segments is None:
            return {'lines': []}

        detections = [
            {
                'x1': float(seg[0][0]),
                'y1': float(seg[0][1]),
                'x2': float(seg[0][2]),
                'y2': float(seg[0][3]),
                'confidence': 1.0
            }
            for seg in segments
        ]
        return {'lines': detections}
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class DeepLSDDetector:
    """Line detector using DeepLSD model."""

    def __init__(self, model_path: str):
        # Fail fast if the optional DeepLSD dependency was not importable
        # (DEEPLSD_AVAILABLE is set by the try/except import at module top).
        if not DEEPLSD_AVAILABLE:
            raise ImportError("DeepLSD is not available")

        # Model and input tensors share this device; prefer GPU when present.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = self._load_model(model_path)

    def _load_model(self, model_path: str) -> DeepLSD:
        """Load the DeepLSD checkpoint, move it to the device, set eval mode.

        Raises:
            Exception: re-raised after logging if loading fails.
        """
        try:
            ckpt = torch.load(model_path, map_location=self.device)
            # NOTE(review): DeepLSD() is constructed with no config argument —
            # confirm the installed DeepLSD version supports default construction.
            model = DeepLSD()
            model.load_state_dict(ckpt['model'])
            return model.to(self.device).eval()
        except Exception as e:
            logger.error(f"Failed to load DeepLSD model: {str(e)}")
            raise

    def detect(self, image: np.ndarray) -> Dict:
        """Detect lines using DeepLSD.

        Args:
            image: BGR or grayscale numpy image.

        Returns:
            {'lines': [...]} with float x1/y1/x2/y2 and a confidence per
            segment; an empty list on any failure (errors are logged,
            never raised).
        """
        try:
            # Convert to tensor
            if len(image.shape) == 3:
                gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
            else:
                gray = image

            # Shape (1, 1, H, W), intensities scaled to [0, 1].
            tensor = torch.tensor(gray, dtype=torch.float32, device=self.device)[None, None] / 255.0

            # Run inference
            with torch.no_grad():
                output = self.model({"image": tensor})
                lines = output["lines"][0]  # [N, 2, 2] array

            # Convert to standard format
            detections = []
            for line in lines:
                (x1, y1), (x2, y2) = line
                detections.append({
                    'x1': float(x1),
                    'y1': float(y1),
                    'x2': float(x2),
                    'y2': float(y2),
                    # NOTE(review): this reads a single scalar from
                    # output["confidence"] and assigns it to EVERY segment —
                    # verify whether DeepLSD returns per-line scores that
                    # should be indexed per segment instead.
                    'confidence': float(output.get("confidence", [1.0])[0])
                })

            return {'lines': detections}

        except Exception as e:
            logger.error(f"Error in DeepLSD detection: {str(e)}")
            return {'lines': []}
|
logger.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from loguru import logger
|
| 3 |
+
import sys
|
| 4 |
+
|
| 5 |
+
def get_logger(name: str):
    """Configure loguru sinks and return a logger bound to *name*.

    WARNING: this reconfigures the *global* loguru logger on every call —
    logger.remove() strips handlers installed by any other module, and each
    call re-adds the stderr and file sinks below.
    """

    # Remove any existing handlers (global side effect, see warning above)
    logger.remove()

    # Add a new handler with custom format: INFO and above to stderr
    logger.add(
        sys.stderr,
        format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>",
        level="INFO"
    )

    # Add file handler for persistent logging: DEBUG and above,
    # rotate at 500 MB, keep 10 days, zip rotated files.
    logger.add(
        "logs/app.log",
        rotation="500 MB",
        retention="10 days",
        level="DEBUG",
        compression="zip"
    )

    # Create logger for the module.
    # NOTE(review): bind() stores *name* in record["extra"], but the format
    # above renders {name} (the call-site module path), so the bound value is
    # never displayed — confirm whether {extra[name]} was intended.
    module_logger = logger.bind(name=name)

    return module_logger
|
pdf_processor.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import fitz # PyMuPDF
|
| 3 |
+
import cv2
|
| 4 |
+
import numpy as np
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
import logging
|
| 7 |
+
from storage import StorageInterface
|
| 8 |
+
import shutil
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
class DocumentProcessor:
    """Converts input documents (PDF/PNG/JPG) into 600-DPI PNG page images.

    Rendered pages are written through the injected StorageInterface so the
    same code works for local disk and blob-backed deployments.
    """

    def __init__(self, storage: StorageInterface):
        self.storage = storage  # backend used to persist rendered pages
        self.target_dpi = 600   # Fixed at 600 DPI

    def clean_results_folder(self, output_dir: str):
        """Clean the results directory before processing new files.

        Removes the directory tree if present, then recreates it empty.
        """
        if os.path.exists(output_dir):
            try:
                shutil.rmtree(output_dir)
                logger.info(f"Cleaned results directory: {output_dir}")
            except Exception as e:
                logger.error(f"Error cleaning results directory: {str(e)}")
                raise
        os.makedirs(output_dir, exist_ok=True)

    def process_document(self, file_path: str, output_dir: str) -> list:
        """Process document (PDF/PNG/JPG) and return paths to processed pages.

        Raises:
            ValueError: if the file extension is not supported.
        """
        # Clean results folder first so stale pages never leak into results
        self.clean_results_folder(output_dir)

        file_ext = Path(file_path).suffix.lower()

        if file_ext == '.pdf':
            return self._process_pdf(file_path, output_dir)
        elif file_ext in ['.png', '.jpg', '.jpeg']:
            return self._process_image(file_path, output_dir)
        else:
            raise ValueError(f"Unsupported file format: {file_ext}")

    def _process_pdf(self, pdf_path: str, output_dir: str) -> list:
        """Render each PDF page at target_dpi and save it as a PNG."""
        processed_pages = []
        base_name = Path(pdf_path).stem

        try:
            # FIX: open the document in a context manager so the file handle
            # is released even if rendering a page raises (the original
            # called fitz.open() and never closed the document).
            with fitz.open(pdf_path) as doc:
                for page_num in range(len(doc)):
                    page = doc[page_num]

                    # Scale from PDF's native 72 DPI up to target_dpi.
                    pix = page.get_pixmap(matrix=fitz.Matrix(self.target_dpi / 72, self.target_dpi / 72))

                    # Raw samples -> (H, W, channels) array.
                    img = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, pix.n)
                    if pix.n == 4:  # RGBA -> RGB (drop alpha)
                        img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)

                    output_path = os.path.join(output_dir, f"{base_name}_page_{page_num + 1}.png")
                    self._save_image(img, output_path)
                    processed_pages.append(output_path)

            return processed_pages

        except Exception as e:
            logger.error(f"Error processing PDF: {str(e)}")
            raise

    def _process_image(self, image_path: str, output_dir: str) -> list:
        """Upscale a single raster image to target_dpi and save it as a PNG."""
        try:
            # Read image
            img = cv2.imread(image_path)
            if img is None:
                raise ValueError(f"Could not read image: {image_path}")

            # OpenCV loads BGR; the rest of the pipeline works in RGB.
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

            # Calculate scaling factor for 600 DPI.
            # NOTE(review): assumes the source was authored at 72 DPI —
            # images from other sources will be over/under-scaled.
            current_dpi = 72
            scale = self.target_dpi / current_dpi

            img = cv2.resize(img, None, fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)

            base_name = Path(image_path).stem
            output_path = os.path.join(output_dir, f"{base_name}_page_1.png")
            self._save_image(img, output_path)

            return [output_path]

        except Exception as e:
            logger.error(f"Error processing image: {str(e)}")
            raise

    def _save_image(self, img: np.ndarray, output_path: str):
        """PNG-encode the RGB image and persist it via the storage backend."""
        # Encode with OpenCV (expects BGR), then hand bytes to storage.
        _, buffer = cv2.imencode('.png', cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
        self.storage.save_file(output_path, buffer.tobytes())
|
| 107 |
+
|
| 108 |
+
if __name__ == "__main__":
    from storage import StorageFactory

    # Wire up the storage backend and the document processor.
    backend = StorageFactory.get_storage()
    doc_processor = DocumentProcessor(backend)

    # Input PDF and destination folder for rendered pages.
    pdf_path = "samples/001.pdf"
    output_dir = "results"  # Changed from "processed_pages" to "results"

    try:
        # Ensure output directory exists
        os.makedirs(output_dir, exist_ok=True)

        page_paths = doc_processor.process_document(
            file_path=pdf_path,
            output_dir=output_dir
        )

        # Report what was produced.
        print("\nProcessing Results:")
        print(f"Output Directory: {os.path.abspath(output_dir)}")

        mb = 1024 * 1024
        for page_path in page_paths:
            abs_path = os.path.abspath(page_path)
            file_size = os.path.getsize(page_path) / mb  # size in MB
            print(f"- {os.path.basename(page_path)} ({file_size:.2f} MB)")

        # Total size of everything written to the output directory.
        total_size = sum(
            os.path.getsize(os.path.join(output_dir, entry))
            for entry in os.listdir(output_dir)
        ) / mb
        print(f"\nTotal output size: {total_size:.2f} MB")

    except Exception as e:
        logger.error(f"Error processing PDF: {str(e)}")
        raise
|
requirements.txt
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Core dependencies
|
| 2 |
+
gradio>=3.50.2
|
| 3 |
+
numpy>=1.24.0
|
| 4 |
+
Pillow>=10.0.0
|
| 5 |
+
opencv-python==4.8.1.78
|
| 6 |
+
PyMuPDF==1.23.8 # for PDF processing
|
| 7 |
+
|
| 8 |
+
# OCR and Text Detection
|
| 9 |
+
python-doctr==0.11.0 # Latest stable version
|
| 10 |
+
easyocr==1.7.1
|
| 11 |
+
pytesseract==0.3.10
|
| 12 |
+
|
| 13 |
+
# Deep Learning
|
| 14 |
+
torch>=2.0.0
|
| 15 |
+
torchvision>=0.15.0
|
| 16 |
+
tensorflow==2.11.0 # Optional with torch
|
| 17 |
+
|
| 18 |
+
# Graph Processing
|
| 19 |
+
networkx>=3.0
|
| 20 |
+
matplotlib>=3.7.0
|
| 21 |
+
|
| 22 |
+
# Utilities
|
| 23 |
+
python-dotenv>=1.0.0
|
| 24 |
+
tqdm==4.66.1
|
| 25 |
+
loguru==0.7.2
|
| 26 |
+
scipy==1.11.4
|
| 27 |
+
pypdfium2==4.20.0
|
| 28 |
+
weasyprint==60.1
|
| 29 |
+
|
| 30 |
+
# Storage and Processing
|
| 31 |
+
azure-storage-blob==12.19.0
|
| 32 |
+
azure-core==1.29.5
|
| 33 |
+
|
| 34 |
+
# OCR Engines
|
| 35 |
+
ultralytics==8.0.0 # for YOLO models
|
| 36 |
+
deeplsd @ git+https://github.com/cvg/DeepLSD.git
|
| 37 |
+
omegaconf>=2.3.0 # Required by DeepLSD
|
| 38 |
+
pytlsd @ git+https://github.com/iago-suarez/pytlsd.git # Required by DeepLSD
|
| 39 |
+
|
| 40 |
+
# AI/Chat
|
| 41 |
+
openai>=1.0.0 # For ChatGPT integration
|
| 42 |
+
uuid>=1.30
|
| 43 |
+
shapely>=1.8.0 # for geometry operations
|
| 44 |
+
|
| 45 |
+
# Added from the code block
|
| 46 |
+
requests>=2.31.0
|
| 47 |
+
|
| 48 |
+
# Added from the code block
|
| 49 |
+
opencv-python-headless>=4.8.0
|
| 50 |
+
|
| 51 |
+
# Added from the code block
|
| 52 |
+
huggingface-hub>=0.19.0
|
| 53 |
+
transformers>=4.35.0
|
| 54 |
+
gradio==5.15.0  # NOTE: duplicates/overrides 'gradio>=3.50.2' near the top of this file — keep a single gradio entry to avoid resolver confusion
|
results/002_page_1_aggregated.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
results/002_page_1_detected_symbols.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
storage.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import shutil
|
| 3 |
+
from azure.storage.blob import BlobServiceClient
|
| 4 |
+
from abc import ABC, abstractmethod
|
| 5 |
+
import json
|
| 6 |
+
|
| 7 |
+
class StorageInterface(ABC):
    """Abstract storage backend: file CRUD plus directory helpers."""

    @abstractmethod
    def save_file(self, file_path: str, content: bytes) -> str:
        """Write *content* to *file_path*; return the stored path."""

    @abstractmethod
    def load_file(self, file_path: str) -> bytes:
        """Return the full contents of the file as bytes."""

    @abstractmethod
    def list_files(self, directory: str) -> list[str]:
        """List the files under *directory*."""

    @abstractmethod
    def file_exists(self, file_path: str) -> bool:
        """True if the file exists in the backend."""

    @abstractmethod
    def delete_file(self, file_path: str) -> None:
        """Remove a single file."""

    @abstractmethod
    def create_directory(self, directory: str) -> None:
        """Create a directory (no-op where the backend has none)."""

    @abstractmethod
    def delete_directory(self, directory: str) -> None:
        """Recursively remove a directory and its contents."""

    @abstractmethod
    def upload(self, local_path: str, destination_path: str) -> None:
        """Copy a local file into the backend at *destination_path*."""

    @abstractmethod
    def append_file(self, file_path: str, content: bytes) -> None:
        """Append bytes to an existing (or new) file."""

    @abstractmethod
    def get_modified_time(self, file_path: str) -> float:
        """Return the file's last-modified time as a UNIX timestamp."""

    @abstractmethod
    def directory_exists(self, directory: str) -> bool:
        """True if the directory (or prefix) exists."""

    def load_json(self, file_path):
        """Load and parse JSON file.

        Returns the parsed object, or None on any failure (the error is
        printed, not raised).
        """
        try:
            with open(file_path, 'r', encoding='utf-8') as fh:
                return json.load(fh)
        except Exception as e:
            print(f"Error loading JSON from {file_path}: {str(e)}")
            return None
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class LocalStorage(StorageInterface):
    """StorageInterface implementation backed by the local filesystem."""

    def save_file(self, file_path: str, content: bytes) -> str:
        """Write bytes to *file_path*, creating parent directories."""
        os.makedirs(os.path.dirname(file_path), exist_ok=True)
        with open(file_path, 'wb') as fh:
            fh.write(content)
        return file_path

    def load_file(self, file_path: str) -> bytes:
        """Read the whole file as bytes."""
        with open(file_path, 'rb') as fh:
            return fh.read()

    def list_files(self, directory: str) -> list[str]:
        """Return names (not full paths) of regular files in *directory*."""
        entries = os.listdir(directory)
        return [name for name in entries if os.path.isfile(os.path.join(directory, name))]

    def file_exists(self, file_path: str) -> bool:
        """True if anything exists at the path."""
        return os.path.exists(file_path)

    def delete_file(self, file_path: str) -> None:
        """Remove a single file."""
        os.remove(file_path)

    def create_directory(self, directory: str) -> None:
        """Create the directory tree if missing."""
        os.makedirs(directory, exist_ok=True)

    def delete_directory(self, directory: str) -> None:
        """Recursively delete a directory tree."""
        shutil.rmtree(directory)

    def upload(self, local_path: str, destination_path: str) -> None:
        """Copy a local file into place, creating parent directories."""
        os.makedirs(os.path.dirname(destination_path), exist_ok=True)
        shutil.copy(local_path, destination_path)

    def append_file(self, file_path: str, content: bytes) -> None:
        """Append bytes, creating the file (and parents) if needed."""
        os.makedirs(os.path.dirname(file_path), exist_ok=True)
        with open(file_path, 'ab') as fh:
            fh.write(content)

    def get_modified_time(self, file_path: str) -> float:
        """mtime as a UNIX timestamp."""
        return os.path.getmtime(file_path)

    def directory_exists(self, directory: str) -> bool:
        """True if the directory exists (delegates to file_exists)."""
        return self.file_exists(directory)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
class BlobStorage(StorageInterface):
    """
    Writes to blob storage, using local disk as a cache.

    Writes always go to the remote blob first and the local cache second,
    so the local copy's mtime is expected to be >= the remote's.

    TODO: Allow configuration of temp dir instead of just using the same paths in both local and remote
    """
    def __init__(self, connection_string: str, container_name: str):
        self.blob_service_client = BlobServiceClient.from_connection_string(connection_string)
        self.container_client = self.blob_service_client.get_container_client(container_name)
        self.local_storage = LocalStorage()  # on-disk cache mirroring blob paths

    def download(self, file_path: str) -> bytes:
        """Fetch a blob's full contents from the remote container."""
        blob_client = self.container_client.get_blob_client(file_path)
        return blob_client.download_blob().readall()

    def sync(self, file_path: str) -> None:
        """Ensure the local cache holds an up-to-date copy of the blob."""
        if not self.local_storage.file_exists(file_path):
            print(f"DEBUG: missing local version of {file_path} - downloading")
            self.local_storage.save_file(file_path, self.download(file_path))
        else:
            local_timestamp = self.local_storage.get_modified_time(file_path)
            remote_timestamp = self.get_modified_time(file_path)
            if local_timestamp < remote_timestamp:
                # We always write remotely before writing locally, so we expect local_timestamp to be > remote timestamp
                # FIX: corrected "DBEUG" typo in the original debug message.
                print(f"DEBUG: local version of {file_path} out of date - downloading")
                self.local_storage.save_file(file_path, self.download(file_path))

    def save_file(self, file_path: str, content: bytes) -> str:
        """Upload to blob (overwrite), then mirror to the local cache."""
        blob_client = self.container_client.get_blob_client(file_path)
        blob_client.upload_blob(content, overwrite=True)
        self.local_storage.save_file(file_path, content)
        return file_path

    def load_file(self, file_path: str) -> bytes:
        """Read from the local cache after syncing it with the remote."""
        self.sync(file_path)
        return self.local_storage.load_file(file_path)

    def list_files(self, directory: str) -> list[str]:
        """List blob names under the given prefix (remote only)."""
        return [blob.name for blob in self.container_client.list_blobs(name_starts_with=directory)]

    def file_exists(self, file_path: str) -> bool:
        """Existence is checked against the remote, not the cache."""
        blob_client = self.container_client.get_blob_client(file_path)
        return blob_client.exists()

    def delete_file(self, file_path: str) -> None:
        """Delete locally first, then remotely."""
        self.local_storage.delete_file(file_path)
        blob_client = self.container_client.get_blob_client(file_path)
        blob_client.delete_blob()

    def create_directory(self, directory: str) -> None:
        # Blob storage doesn't have directories, so only create it locally
        self.local_storage.create_directory(directory)

    def delete_directory(self, directory: str) -> None:
        """Remove the local tree and every blob under the prefix."""
        self.local_storage.delete_directory(directory)
        blobs_to_delete = self.container_client.list_blobs(name_starts_with=directory)
        for blob in blobs_to_delete:
            self.container_client.delete_blob(blob.name)

    def upload(self, local_path: str, destination_path: str) -> None:
        """Upload a local file to blob, then mirror it into the cache."""
        with open(local_path, "rb") as data:
            blob_client = self.container_client.get_blob_client(destination_path)
            blob_client.upload_blob(data, overwrite=True)
        self.local_storage.upload(local_path, destination_path)

    def append_file(self, file_path: str, content: bytes) -> None:
        """Append to an append-blob remotely, then to the local cache."""
        blob_client = self.container_client.get_blob_client(file_path)
        if not blob_client.exists():
            blob_client.create_append_blob()
        else:
            self.sync(file_path)

        blob_client.append_block(content)
        self.local_storage.append_file(file_path, content)

    def get_modified_time(self, file_path: str) -> float:
        """Remote last-modified time as a UNIX timestamp."""
        blob_client = self.container_client.get_blob_client(file_path)
        properties = blob_client.get_blob_properties()
        # Convert the UTC datetime to a UNIX timestamp
        return properties.last_modified.timestamp()

    def directory_exists(self, directory: str) -> bool:
        """True if at least one blob exists under the prefix."""
        blobs = self.container_client.list_blobs(name_starts_with=directory)
        return next(blobs, None) is not None
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
class StorageFactory:
    """Selects a storage backend from the STORAGE_TYPE environment variable."""

    @staticmethod
    def get_storage() -> StorageInterface:
        """Return a LocalStorage or BlobStorage instance based on env config.

        Raises:
            ValueError: when STORAGE_TYPE is unknown, or when blob storage
                is requested without the required Azure env vars.
        """
        storage_type = os.getenv('STORAGE_TYPE', 'local').lower()

        if storage_type == 'blob':
            connection_string = os.getenv('AZURE_STORAGE_CONNECTION_STRING')
            container_name = os.getenv('AZURE_STORAGE_CONTAINER_NAME')
            if not connection_string or not container_name:
                raise ValueError("Azure Blob Storage connection string and container name must be set")
            return BlobStorage(connection_string, container_name)

        if storage_type == 'local':
            return LocalStorage()

        raise ValueError(f"Unsupported storage type: {storage_type}")
|
| 208 |
+
|
symbol_detection.py
ADDED
|
@@ -0,0 +1,454 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import json
|
| 3 |
+
import uuid
|
| 4 |
+
import os
|
| 5 |
+
import logging
|
| 6 |
+
from ultralytics import YOLO
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
from storage import StorageInterface
|
| 9 |
+
import numpy as np
|
| 10 |
+
from typing import Tuple, List, Dict, Any
|
| 11 |
+
|
| 12 |
+
# Configure logging
|
| 13 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 14 |
+
|
| 15 |
+
# Constants
# Weight files for the two YOLO models whose detections are merged downstream.
MODEL_PATHS = {
    "model1": "models/Intui_SDM_41.pt",
    "model2": "models/Intui_SDM_20.pt"  # Add your second model path here
}
MAX_DIMENSION = 1280  # Longest image side (px) used for inference resizing
CONFIDENCE_THRESHOLDS = [0.1, 0.3, 0.5, 0.7, 0.9]  # Thresholds swept per model to pick the "best" one
TEXT_COLOR = (0, 0, 255)  # BGR red for annotation text
BOX_COLOR = (255, 0, 0)  # BGR blue for bounding boxes (note: (255, 0, 0) is blue in BGR, not red)
BG_COLOR = (255, 255, 255, 0.6)  # Semi-transparent white for text background (not referenced in this module's visible code)
THICKNESS = 1  # Thin text thickness
BOX_THICKNESS = 2  # Box line thickness
MIN_FONT_SCALE = 0.2  # Minimum font scale
MAX_FONT_SCALE = 1.0  # Maximum font scale
TEXT_PADDING = 20  # Increased padding between text elements
OVERLAP_THRESHOLD = 0.3  # Threshold for detecting text overlap (not referenced in this module's visible code)
|
| 31 |
+
|
| 32 |
+
def preprocess_image_for_symbol_detection(image_cv: np.ndarray) -> np.ndarray:
    """Return an edge-emphasized BGR copy of *image_cv* for symbol detection.

    Pipeline: grayscale -> histogram equalization -> bilateral smoothing
    -> Canny edge extraction -> back to 3-channel BGR so downstream code
    that expects a colour image keeps working.
    """
    single_channel = cv2.cvtColor(image_cv, cv2.COLOR_BGR2GRAY)
    contrast_boosted = cv2.equalizeHist(single_channel)
    denoised = cv2.bilateralFilter(contrast_boosted, 9, 75, 75)
    edge_map = cv2.Canny(denoised, 100, 200)
    return cv2.cvtColor(edge_map, cv2.COLOR_GRAY2BGR)
|
| 40 |
+
|
| 41 |
+
def evaluate_detections(detections_list: List[Dict[str, Any]]) -> int:
    """Score a detection set; the current metric is simply its size."""
    return sum(1 for _ in detections_list)
|
| 44 |
+
|
| 45 |
+
def resize_image_with_aspect_ratio(image_cv: np.ndarray, max_dimension: int) -> Tuple[np.ndarray, int, int]:
    """Shrink *image_cv* so its longest side is at most *max_dimension*.

    The aspect ratio is preserved. Images that already fit are returned
    untouched. Returns ``(image, width, height)`` for the (possibly)
    resized image.
    """
    height, width = image_cv.shape[:2]
    longest_side = max(width, height)
    if longest_side <= max_dimension:
        # Already small enough — hand back the original unchanged.
        return image_cv, width, height
    factor = max_dimension / float(longest_side)
    target_w = int(width * factor)
    target_h = int(height * factor)
    shrunk = cv2.resize(image_cv, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
    return shrunk, target_w, target_h
|
| 56 |
+
|
| 57 |
+
def merge_detections(all_detections: List[Dict]) -> List[Dict]:
    """
    Merge detections from all models, keeping only the highest-confidence
    detection when duplicates are found (same label and IoU > 0.5).

    Args:
        all_detections: Detection dicts with at least 'confidence',
            'original_label' and 'bbox' ([x1, y1, x2, y2]) keys.

    Returns:
        De-duplicated detections, ordered by descending confidence.

    Fix: the input list is no longer sorted in place, so callers keep
    their original ordering.
    """
    if not all_detections:
        return []

    def calculate_iou(box1, box2):
        """Calculate Intersection over Union (IoU) between two boxes."""
        x1 = max(box1[0], box2[0])
        y1 = max(box1[1], box2[1])
        x2 = min(box1[2], box2[2])
        y2 = min(box1[3], box2[3])

        intersection = max(0, x2 - x1) * max(0, y2 - y1)
        area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
        area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
        union = area1 + area2 - intersection

        return intersection / union if union > 0 else 0

    # Work on a sorted copy instead of mutating the caller's list.
    ordered = sorted(all_detections, key=lambda d: d['confidence'], reverse=True)
    keep = [True] * len(ordered)

    # Apply NMS and keep only the highest-confidence detection per duplicate set.
    for i in range(len(ordered)):
        if not keep[i]:
            continue

        current_box = ordered[i]['bbox']
        current_label = ordered[i]['original_label']

        for j in range(i + 1, len(ordered)):
            if not keep[j]:
                continue

            # Same label type and heavy overlap -> drop the lower-confidence one.
            # Since the list is sorted by confidence, i always outranks j.
            if (ordered[j]['original_label'] == current_label and
                    calculate_iou(current_box, ordered[j]['bbox']) > 0.5):
                keep[j] = False
                logging.info(f"Removing duplicate detection of {current_label} with lower confidence "
                             f"({ordered[j]['confidence']:.2f} < {ordered[i]['confidence']:.2f})")

    return [det for i, det in enumerate(ordered) if keep[i]]
|
| 108 |
+
|
| 109 |
+
def calculate_font_scale(image_width: int, bbox_width: int) -> float:
    """Pick a cv2 font scale adapted to image and bounding-box widths.

    The raw scale grows with the image's width relative to MAX_DIMENSION
    and with the box's share of the image width, then is clamped to the
    [MIN_FONT_SCALE, MAX_FONT_SCALE] interval.
    """
    image_factor = max(image_width / MAX_DIMENSION, 0.5)
    box_factor = max((bbox_width / image_width) * 6, 0.7)
    raw_scale = 0.7 * image_factor * box_factor

    # Clamp into the allowed range.
    if raw_scale < MIN_FONT_SCALE:
        return MIN_FONT_SCALE
    if raw_scale > MAX_FONT_SCALE:
        return MAX_FONT_SCALE
    return raw_scale
|
| 124 |
+
|
| 125 |
+
def check_overlap(rect1, rect2):
    """Return True when two (x1, y1, x2, y2) rectangles intersect.

    Rectangles that merely touch on an edge or corner count as overlapping.
    """
    left_a, top_a, right_a, bottom_a = rect1
    left_b, top_b, right_b, bottom_b = rect2

    horizontally = right_a >= left_b and left_a <= right_b
    vertically = bottom_a >= top_b and top_a <= bottom_b
    return horizontally and vertically
|
| 131 |
+
|
| 132 |
+
def draw_annotation(
    image: np.ndarray,
    bbox: List[int],
    text: str,
    confidence: float,
    model_source: str,
    existing_annotations: List[tuple] = None
) -> None:
    """
    Draw annotation with no background and thin fonts.

    Draws *bbox* on *image* in place, then renders a two-line label
    ("<text>" over "<confidence>%") near the box. Several candidate label
    positions are tried (above, below, and padded variants of each); the
    first one that fits inside the image and does not collide with a
    rectangle already in *existing_annotations* wins and is recorded there
    so later calls avoid it. If every candidate collides, the label is
    placed to the right of the box.

    NOTE(review): *model_source* is accepted but never used in this body.
    NOTE(review): the fallback 'side' rectangle is not appended to
    *existing_annotations*, so later labels may overlap it — confirm intended.
    """
    if existing_annotations is None:
        existing_annotations = []

    x1, y1, x2, y2 = bbox
    bbox_width = x2 - x1
    image_width = image.shape[1]
    image_height = image.shape[0]

    # Calculate adaptive font scale
    font_scale = calculate_font_scale(image_width, bbox_width)

    # Simplify the annotation text: label on line 1, rounded percent on line 2.
    annotation_text = f'{text}\n{confidence:.0f}%'
    lines = annotation_text.split('\n')

    # Calculate text dimensions (widest line and cumulative line heights).
    font = cv2.FONT_HERSHEY_SIMPLEX
    max_width = 0
    total_height = 0
    line_heights = []

    for line in lines:
        (width, height), baseline = cv2.getTextSize(
            line, font, font_scale, THICKNESS
        )
        max_width = max(max_width, width)
        line_height = height + baseline + TEXT_PADDING
        line_heights.append(line_height)
        total_height += line_height

    # Calculate initial text position with increased padding
    padding = TEXT_PADDING
    rect_x1 = max(0, x1 - padding)
    rect_x2 = min(image_width, x1 + max_width + padding * 2)

    # Try different positions to avoid overlap
    positions = [
        ('top', y1 - total_height - padding),
        ('bottom', y2 + padding),
        ('top_shifted', y1 - total_height - padding * 2),
        ('bottom_shifted', y2 + padding * 2)
    ]

    final_position = None
    for pos_name, y_pos in positions:
        # Reject candidates that fall outside the image vertically.
        if y_pos < 0 or y_pos + total_height > image_height:
            continue

        rect = (rect_x1, y_pos, rect_x2, y_pos + total_height)
        overlap = False

        for existing_rect in existing_annotations:
            if check_overlap(rect, existing_rect):
                overlap = True
                break

        if not overlap:
            final_position = (pos_name, y_pos)
            existing_annotations.append(rect)
            break

    # If no non-overlapping position found, use side position
    if final_position is None:
        rect_x1 = max(0, x1 + bbox_width + padding)
        rect_x2 = min(image_width, rect_x1 + max_width + padding * 2)
        y_pos = y1
        final_position = ('side', y_pos)

    rect_y1 = final_position[1]

    # Draw bounding box (no transparency)
    cv2.rectangle(image, (x1, y1), (x2, y2), BOX_COLOR, BOX_THICKNESS)

    # Draw text directly without background, one line at a time, each line
    # offset by the cumulative height of the lines above it.
    text_y = rect_y1 + line_heights[0] - padding
    for i, line in enumerate(lines):
        # Draw text with thin lines
        cv2.putText(
            image,
            line,
            (rect_x1 + padding, text_y + sum(line_heights[:i])),
            font,
            font_scale,
            TEXT_COLOR,
            THICKNESS,
            cv2.LINE_AA
        )
|
| 230 |
+
|
| 231 |
+
def run_detection_with_optimal_threshold(
    image_path: str,
    results_dir: str = "results",
    file_name: str = "",
    apply_preprocessing: bool = False,
    resize_image: bool = True,  # Changed default to True
    storage: StorageInterface = None
) -> Tuple[str, str, str, List[int]]:
    """Run detection with multiple models and merge results.

    Loads the image via *storage*, optionally resizes/preprocesses it,
    runs every model in MODEL_PATHS once, sweeps CONFIDENCE_THRESHOLDS to
    pick the filtering threshold maximizing evaluate_detections(), merges
    the per-model detections with IoU de-duplication, draws annotations,
    and writes an annotated PNG plus a JSON report into *results_dir*.

    Args:
        image_path: Storage path of the input image.
        results_dir: Directory (in storage) for output artifacts.
        file_name: Original file name; its stem names the outputs.
        apply_preprocessing: Run Canny edge preprocessing before inference.
        resize_image: Shrink the longest side to MAX_DIMENSION first.
        storage: Storage backend used for all reads and writes.

    Returns:
        (annotated_image_path, json_path, log_message, diagram_bbox) on
        success; ("Error during detection", None, None, None) on failure.
    """
    try:
        image_data = storage.load_file(image_path)
        nparr = np.frombuffer(image_data, np.uint8)
        original_image_cv = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        image_cv = original_image_cv.copy()

        if resize_image:
            logging.info("Resizing image for detection with aspect ratio...")
            image_cv, resized_width, resized_height = resize_image_with_aspect_ratio(image_cv, MAX_DIMENSION)
        else:
            logging.info("Skipping image resizing...")
            resized_height, resized_width = original_image_cv.shape[:2]

        if apply_preprocessing:
            logging.info("Preprocessing image for symbol detection...")
            image_cv = preprocess_image_for_symbol_detection(image_cv)
        else:
            logging.info("Skipping image preprocessing for symbol detection...")

        all_detections = []

        # Run detection with each model
        for model_name, model_path in MODEL_PATHS.items():
            logging.info(f"Running detection with model: {model_name}")

            if not model_path:
                logging.warning(f"No model path found for {model_name}")
                continue

            model = YOLO(model_path)

            # FIX: run inference exactly once per model. The predict() call
            # does not depend on the confidence threshold, so re-running it
            # inside the threshold loop (as before) repeated the identical
            # inference len(CONFIDENCE_THRESHOLDS) times.
            results = model.predict(source=image_cv, imgsz=MAX_DIMENSION)

            best_confidence_threshold = 0.5
            best_detections_list = []
            best_metric = -1

            # NOTE(review): evaluate_detections() simply counts detections,
            # so the lowest threshold always wins this sweep — confirm
            # whether a richer quality metric was intended.
            for confidence_threshold in CONFIDENCE_THRESHOLDS:
                logging.info(f"Filtering detections with confidence threshold: {confidence_threshold}...")

                detections_list = []
                for result in results:
                    for box in result.boxes:
                        confidence = float(box.conf[0])
                        if confidence >= confidence_threshold:
                            x1, y1, x2, y2 = map(float, box.xyxy[0])
                            class_id = int(box.cls[0])
                            label = result.names[class_id]

                            # Map box coordinates back to the original resolution.
                            scale_x = original_image_cv.shape[1] / resized_width
                            scale_y = original_image_cv.shape[0] / resized_height
                            x1 *= scale_x
                            x2 *= scale_x
                            y1 *= scale_y
                            y2 *= scale_y
                            x1, y1, x2, y2 = map(int, [x1, y1, x2, y2])

                            # Class names are encoded as "<category>_<type>_<name...>".
                            split_label = label.split('_')
                            if len(split_label) >= 3:
                                category = split_label[0]
                                type_ = split_label[1]
                                new_label = '_'.join(split_label[2:])
                            elif len(split_label) == 2:
                                category = split_label[0]
                                type_ = split_label[1]
                                new_label = split_label[1]
                            elif len(split_label) == 1:
                                category = split_label[0]
                                type_ = "Unknown"
                                new_label = split_label[0]
                            else:
                                logging.warning(f"Unexpected label format: {label}. Skipping this detection.")
                                continue

                            detection_id = str(uuid.uuid4())
                            detection_info = {
                                "symbol_id": detection_id,
                                "class_id": class_id,
                                "original_label": label,
                                "category": category,
                                "type": type_,
                                "label": new_label,
                                "confidence": confidence,
                                "bbox": [x1, y1, x2, y2],
                                "model_source": model_name
                            }
                            detections_list.append(detection_info)

                metric = evaluate_detections(detections_list)
                if metric > best_metric:
                    best_metric = metric
                    best_confidence_threshold = confidence_threshold
                    best_detections_list = detections_list

            all_detections.extend(best_detections_list)

        # Merge detections from both models
        merged_detections = merge_detections(all_detections)
        logging.info(f"Total detections after merging: {len(merged_detections)}")

        # Draw annotations on the image
        existing_annotations = []
        for det in merged_detections:
            draw_annotation(
                original_image_cv,
                det["bbox"],
                det["original_label"],
                det["confidence"] * 100,
                det["model_source"],
                existing_annotations
            )

        # Save results
        storage.create_directory(results_dir)
        file_name_without_extension = os.path.splitext(file_name)[0]

        # Prepare output JSON
        total_detected_symbols = len(merged_detections)
        class_counts = {}
        for det in merged_detections:
            full_label = det["original_label"]
            class_counts[full_label] = class_counts.get(full_label, 0) + 1

        output_json = {
            "total_detected_symbols": total_detected_symbols,
            "details": class_counts,
            "detections": merged_detections
        }

        # Save JSON and image
        detection_json_path = os.path.join(
            results_dir, f'{file_name_without_extension}_detected_symbols.json'
        )
        storage.save_file(
            detection_json_path,
            json.dumps(output_json, indent=4).encode('utf-8')
        )

        # Save with maximum quality
        detection_image_path = os.path.join(
            results_dir, f'{file_name_without_extension}_detected_symbols.png'  # Using PNG for transparency
        )

        # Configure image encoding parameters for maximum quality
        encode_params = [
            cv2.IMWRITE_PNG_COMPRESSION, 0  # No compression for PNG
        ]

        # Save as high-quality PNG to preserve transparency
        _, img_encoded = cv2.imencode(
            '.png',
            original_image_cv,
            encode_params
        )

        storage.save_file(detection_image_path, img_encoded.tobytes())

        # Calculate diagram bbox from merged detections
        diagram_bbox = [
            min([det['bbox'][0] for det in merged_detections], default=0),
            min([det['bbox'][1] for det in merged_detections], default=0),
            max([det['bbox'][2] for det in merged_detections], default=0),
            max([det['bbox'][3] for det in merged_detections], default=0)
        ]

        # Scale up image if it's too small
        # NOTE(review): this upscaling happens after the annotated image has
        # already been encoded and saved, and the resized copy is never used
        # afterwards — confirm whether it should run before cv2.imencode.
        min_width = 2000  # Minimum width for good visibility
        if original_image_cv.shape[1] < min_width:
            scale_factor = min_width / original_image_cv.shape[1]
            new_width = min_width
            new_height = int(original_image_cv.shape[0] * scale_factor)
            original_image_cv = cv2.resize(
                original_image_cv,
                (new_width, new_height),
                interpolation=cv2.INTER_CUBIC
            )

        return (
            detection_image_path,
            detection_json_path,
            f"Total detections after merging: {total_detected_symbols}",
            diagram_bbox
        )
    except Exception as e:
        logging.error(f"An error occurred: {e}")
        return "Error during detection", None, None, None
|
| 425 |
+
|
| 426 |
+
if __name__ == "__main__":
    # Smoke-test entry point: run the full multi-model detection pipeline
    # on a single pre-processed page image and log the resulting artifact
    # paths. Requires the storage backend and model weights to be present.
    from storage import StorageFactory

    uploaded_file_path = "processed_pages/10219-1-DG-BC-00011.01-REV_A_page_1_text.png"
    results_dir = "results"
    apply_symbol_preprocessing = False  # keep Canny preprocessing off by default
    resize_image = True  # shrink longest side to MAX_DIMENSION before inference

    storage = StorageFactory.get_storage()

    (
        detection_image_path,
        detection_json_path,
        detection_log_message,
        diagram_bbox
    ) = run_detection_with_optimal_threshold(
        uploaded_file_path,
        results_dir=results_dir,
        file_name=os.path.basename(uploaded_file_path),
        apply_preprocessing=apply_symbol_preprocessing,
        resize_image=resize_image,
        storage=storage
    )

    logging.info("Detection Image Path: %s", detection_image_path)
    logging.info("Detection JSON Path: %s", detection_json_path)
    logging.info("Detection Log Message: %s", detection_log_message)
    logging.info("Diagram BBox: %s", diagram_bbox)
    logging.info("Done!")
|
text_detection_combined.py
ADDED
|
@@ -0,0 +1,563 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import io
|
| 4 |
+
from PIL import Image, ImageDraw, ImageFont
|
| 5 |
+
import numpy as np
|
| 6 |
+
from doctr.models import ocr_predictor
|
| 7 |
+
import pytesseract
|
| 8 |
+
import easyocr
|
| 9 |
+
from storage import StorageInterface
|
| 10 |
+
import re
|
| 11 |
+
import logging
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
import cv2
|
| 14 |
+
import traceback
|
| 15 |
+
from typing import Tuple
|
| 16 |
+
|
| 17 |
+
# Initialize models
# OCR engines are loaded once at import time. FIX: pre-initialize both
# handles to None — previously a load failure left `doctr_model` and
# `easyocr_reader` undefined, so the `if easyocr_reader is None` guard in
# detect_with_easyocr raised NameError instead of degrading gracefully.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
doctr_model = None
easyocr_reader = None
try:
    doctr_model = ocr_predictor(pretrained=True)
    easyocr_reader = easyocr.Reader(['en'])
    logging.info("All OCR models loaded successfully")
except Exception as e:
    logging.error(f"Error loading OCR models: {e}")
|
| 26 |
+
|
| 27 |
+
# Combined patterns from all approaches
# Named regexes for bucketing OCR strings into drawing entity categories.
# The example shapes in the end-of-line comments are derived directly from
# each pattern. Values are runtime regex strings — do not reformat them.
TEXT_PATTERNS = {
    'Line_Number': r"(?:\d{1,5}[-](?:[A-Z]{2,4})[-]\d{1,3})",  # e.g. 1234-AB-12
    'Equipment_Tag': r"(?:[A-Z]{1,3}[-][A-Z0-9]{1,4}[-]\d{1,3})",  # e.g. P-101A-1
    'Instrument_Tag': r"(?:\d{2,3}[-][A-Z]{2,4}[-]\d{2,3})",  # e.g. 10-PIC-101
    'Valve_Number': r"(?:[A-Z]{1,2}[-]\d{3})",  # e.g. V-101
    'Pipe_Size': r"(?:\d{1,2}[\"])",  # e.g. 6"
    'Flow_Direction': r"(?:FROM|TO)",
    'Service_Description': r"(?:STEAM|WATER|AIR|GAS|DRAIN)",
    'Process_Instrument': r"(?:[0-9]{2,3}(?:-[A-Z]{2,3})?-[0-9]{2,3}|[A-Z]{2,3}-[0-9]{2,3})",  # e.g. 10-PT-101 or PT-101
    'Nozzle': r"(?:N[0-9]{1,2}|MH)",  # e.g. N1 or MH
    'Pipe_Connector': r"(?:[0-9]{1,5}|[A-Z]{1,2}[0-9]{2,5})"  # bare numbers or letter-number codes
}
|
| 40 |
+
|
| 41 |
+
def detect_text_combined(image, confidence_threshold=0.3):
    """Combine results from all three OCR approaches.

    Runs Tesseract, EasyOCR and DocTR over *image*, tags each raw hit
    with its engine name, merges overlapping boxes, then keeps only hits
    at or above *confidence_threshold*, classifying each surviving text
    string into a pattern type.
    """
    engines = [
        ('tesseract', detect_with_tesseract),
        ('easyocr', detect_with_easyocr),
        ('doctr', detect_with_doctr),
    ]

    # Collect raw hits from every engine, tagged with their source.
    raw_hits = []
    for engine_name, detect in engines:
        for hit in detect(image):
            hit['source'] = engine_name
            raw_hits.append(hit)

    # Collapse duplicate/overlapping boxes across engines.
    merged = merge_overlapping_detections(raw_hits)

    # Keep confident hits only, classifying each text string.
    kept = []
    for hit in merged:
        if hit['confidence'] < confidence_threshold:
            continue
        hit['text_type'] = classify_text(hit['text'])
        kept.append(hit)

    return kept
|
| 75 |
+
|
| 76 |
+
def generate_detailed_summary(results):
    """Generate detailed detection summary.

    Aggregates per-engine counts, per-pattern-type counts and items, a
    confidence histogram, a flat list of all detected items, and average
    confidences per engine and per pattern type.
    """
    sources = ('tesseract', 'easyocr', 'doctr')

    summary = {
        'total_detections': len(results),
        'by_type': {},
        'by_source': {
            name: {'count': 0, 'by_type': {}, 'avg_confidence': 0.0}
            for name in sources
        },
        'confidence_ranges': {
            '0.9-1.0': 0,
            '0.8-0.9': 0,
            '0.7-0.8': 0,
            '0.6-0.7': 0,
            '0.5-0.6': 0,
            '<0.5': 0
        },
        'detected_items': []
    }

    # Seed zeroed counters for every known pattern type, both globally
    # and inside each engine's breakdown.
    for pattern_type in TEXT_PATTERNS.keys():
        summary['by_type'][pattern_type] = {
            'count': 0,
            'avg_confidence': 0.0,
            'by_source': {name: 0 for name in sources},
            'items': []
        }
        for name in sources:
            summary['by_source'][name]['by_type'][pattern_type] = 0

    def bucket_for(conf):
        # Map a confidence value to its histogram bucket label.
        if conf >= 0.9:
            return '0.9-1.0'
        if conf >= 0.8:
            return '0.8-0.9'
        if conf >= 0.7:
            return '0.7-0.8'
        if conf >= 0.6:
            return '0.6-0.7'
        if conf >= 0.5:
            return '0.5-0.6'
        return '<0.5'

    source_confidences = {name: [] for name in sources}

    for result in results:
        source = result['source']
        conf = result['confidence']
        text_type = result['text_type']

        summary['by_source'][source]['count'] += 1
        source_confidences[source].append(conf)
        summary['confidence_ranges'][bucket_for(conf)] += 1

        # Unknown types are skipped here but still appear in detected_items.
        if text_type in summary['by_type']:
            type_stats = summary['by_type'][text_type]
            type_stats['count'] += 1
            type_stats['by_source'][source] += 1
            summary['by_source'][source]['by_type'][text_type] += 1
            type_stats['items'].append({
                'text': result['text'],
                'confidence': conf,
                'source': source,
                'bbox': result['bbox']
            })

        summary['detected_items'].append({
            'text': result['text'],
            'type': text_type,
            'confidence': conf,
            'source': source,
            'bbox': result['bbox']
        })

    # Per-engine average confidence.
    for name, confs in source_confidences.items():
        if confs:
            summary['by_source'][name]['avg_confidence'] = sum(confs) / len(confs)

    # Per-type average confidence.
    for stats in summary['by_type'].values():
        if stats['items']:
            stats['avg_confidence'] = sum(item['confidence'] for item in stats['items']) / len(stats['items'])

    return summary
|
| 179 |
+
|
| 180 |
+
def process_drawing(image_path: str, output_dir: str, storage: StorageInterface) -> Tuple[dict, dict]:
    """Process drawing with text detection.

    Runs all three OCR engines over the image at *image_path*, draws each
    detection onto an annotated copy, and writes a PNG plus a JSON report
    into *output_dir*.

    Args:
        image_path: Path to image file
        output_dir: Directory to save results
        storage: Optional storage handler
            NOTE(review): not used in this body — the image is read with
            cv2.imread and outputs are written with cv2.imwrite/open
            directly; confirm whether storage-backed I/O was intended.

    Returns:
        Tuple of (paths/results dict, summary dict).

    Raises:
        ValueError: if the image cannot be read.
        Exception: any OCR or I/O error is logged and re-raised.
    """
    try:
        # Read image using cv2
        image = cv2.imread(image_path)
        if image is None:
            raise ValueError(f"Could not read image: {image_path}")

        # Create annotated copy
        annotated_image = image.copy()

        # Initialize results and summary
        text_results = {
            'file_name': image_path,
            'detections': []
        }

        # NOTE(review): the 'by_type' counters below are initialized but
        # never updated in this function (no classification is applied
        # here), so they remain zero in the saved summary — confirm.
        text_summary = {
            'total_detections': 0,
            'by_source': {
                'tesseract': {'count': 0, 'avg_confidence': 0.0},
                'easyocr': {'count': 0, 'avg_confidence': 0.0},
                'doctr': {'count': 0, 'avg_confidence': 0.0}
            },
            'by_type': {
                'equipment_tag': {'count': 0, 'avg_confidence': 0.0},
                'line_number': {'count': 0, 'avg_confidence': 0.0},
                'instrument_tag': {'count': 0, 'avg_confidence': 0.0},
                'valve_number': {'count': 0, 'avg_confidence': 0.0},
                'pipe_size': {'count': 0, 'avg_confidence': 0.0},
                'flow_direction': {'count': 0, 'avg_confidence': 0.0},
                'service_description': {'count': 0, 'avg_confidence': 0.0},
                'process_instrument': {'count': 0, 'avg_confidence': 0.0},
                'nozzle': {'count': 0, 'avg_confidence': 0.0},
                'pipe_connector': {'count': 0, 'avg_confidence': 0.0},
                'other': {'count': 0, 'avg_confidence': 0.0}
            }
        }

        # Run OCR with different engines
        tesseract_results = detect_with_tesseract(image)
        easyocr_results = detect_with_easyocr(image)
        doctr_results = detect_with_doctr(image)

        # Combine results
        all_detections = []
        all_detections.extend([(res, 'tesseract') for res in tesseract_results])
        all_detections.extend([(res, 'easyocr') for res in easyocr_results])
        all_detections.extend([(res, 'doctr') for res in doctr_results])

        # Process each detection
        for detection, source in all_detections:
            # Update text_results
            text_results['detections'].append({
                'text': detection['text'],
                'bbox': detection['bbox'],
                'confidence': detection['confidence'],
                'source': source
            })

            # Update summary statistics; avg_confidence accumulates a raw
            # sum here and is divided by the count after the loop.
            text_summary['total_detections'] += 1
            text_summary['by_source'][source]['count'] += 1
            text_summary['by_source'][source]['avg_confidence'] += detection['confidence']

            # Draw detection on image
            x1, y1, x2, y2 = detection['bbox']
            cv2.rectangle(annotated_image, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
            cv2.putText(annotated_image, detection['text'], (int(x1), int(y1)-5),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)

        # Calculate average confidences
        for source in text_summary['by_source']:
            if text_summary['by_source'][source]['count'] > 0:
                text_summary['by_source'][source]['avg_confidence'] /= text_summary['by_source'][source]['count']

        # Save results with new naming convention
        base_name = Path(image_path).stem
        if base_name.startswith('display_'):
            base_name = base_name[8:]  # drop the 'display_' prefix

        text_result_image_path = os.path.join(output_dir, f"{base_name}_detected_texts.png")
        text_result_json_path = os.path.join(output_dir, f"{base_name}_detected_texts.json")

        # Save the annotated image
        cv2.imwrite(text_result_image_path, annotated_image)

        # Save the JSON results
        with open(text_result_json_path, 'w', encoding='utf-8') as f:
            json.dump({
                'file_name': image_path,
                'summary': text_summary,
                'detections': text_results['detections']
            }, f, indent=4, ensure_ascii=False)

        return {
            'image_path': text_result_image_path,
            'json_path': text_result_json_path,
            'results': text_results
        }, text_summary

    except Exception as e:
        logger.error(f"Text detection error: {str(e)}")
        raise
|
| 290 |
+
|
| 291 |
+
def detect_with_tesseract(image):
    """Detect text using Tesseract OCR"""
    # Sparse-text page segmentation (psm 11) plus a whitelist of characters
    # that occur in technical-drawing labels (tags, sizes, line numbers).
    custom_config = r'--oem 3 --psm 11 -c tessedit_char_whitelist="ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-.()" -c tessedit_write_images=true -c textord_heavy_nr=true -c textord_min_linesize=3'

    try:
        data = pytesseract.image_to_data(
            image,
            config=custom_config,
            output_type=pytesseract.Output.DICT
        )

        detections = []
        rows = zip(data['text'], data['conf'], data['left'],
                   data['top'], data['width'], data['height'])
        for raw_text, raw_conf, x, y, w, h in rows:
            # Keep only reasonably confident, non-empty tokens.
            # Threshold is deliberately low for technical text.
            if float(raw_conf) <= 30:
                continue
            token = raw_text.strip()
            if not token:
                continue
            detections.append({
                'text': token,
                'bbox': [x, y, x + w, y + h],
                'confidence': float(raw_conf) / 100.0
            })
        return detections

    except Exception as e:
        logger.error(f"Tesseract error: {str(e)}")
        return []
|
| 320 |
+
|
| 321 |
+
def detect_with_easyocr(image):
    """Detect text using EasyOCR"""
    # A reader that failed to initialise simply contributes no detections.
    if easyocr_reader is None:
        return []

    try:
        detections = []
        for quad, text, conf in easyocr_reader.readtext(
            np.array(image),
            paragraph=False,
            height_ths=2.0,
            width_ths=2.0,
            contrast_ths=0.2,
            text_threshold=0.5
        ):
            # EasyOCR returns a 4-point quadrilateral; collapse it to an
            # axis-aligned bounding box.
            xs = [p[0] for p in quad]
            ys = [p[1] for p in quad]
            detections.append({
                'text': text,
                'bbox': [int(min(xs)), int(min(ys)), int(max(xs)), int(max(ys))],
                'confidence': conf
            })
        return detections

    except Exception as e:
        logger.error(f"EasyOCR error: {str(e)}")
        return []
|
| 351 |
+
|
| 352 |
+
def detect_with_doctr(image):
    """Detect text using DocTR"""
    try:
        frame = np.array(image)
        height, width = frame.shape[:2]

        # Run the model and export a nested page/block/line/word structure.
        exported = doctr_model([frame]).export()

        detections = []
        scale = np.array([width, height])
        for page in exported['pages']:
            for block in page['blocks']:
                for line in block['lines']:
                    for word in line['words']:
                        # Geometry is normalized [0, 1]; rescale to pixels
                        # and take the axis-aligned extremes.
                        corners = np.array(word['geometry']) * scale
                        x1, y1 = corners.min(axis=0)
                        x2, y2 = corners.max(axis=0)
                        detections.append({
                            'text': word['value'],
                            'bbox': [int(x1), int(y1), int(x2), int(y2)],
                            'confidence': word.get('confidence', 0.5)
                        })
        return detections

    except Exception as e:
        logger.error(f"DocTR error: {str(e)}")
        return []
|
| 384 |
+
|
| 385 |
+
def merge_overlapping_detections(results, iou_threshold=0.5):
    """Merge overlapping detections from different sources"""
    if not results:
        return []

    def _iou(a, b):
        # Intersection rectangle corners.
        ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
        ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])

        if ix2 < ix1 or iy2 < iy1:
            return 0.0

        inter = (ix2 - ix1) * (iy2 - iy1)
        union = ((a[2] - a[0]) * (a[3] - a[1])
                 + (b[2] - b[0]) * (b[3] - b[1])
                 - inter)
        return inter / union if union > 0 else 0

    consumed = set()
    merged = []

    # Greedy grouping: each unconsumed detection anchors a group of every
    # later detection that overlaps it beyond the IoU threshold.
    for i, anchor in enumerate(results):
        if i in consumed:
            continue

        consumed.add(i)
        group = [anchor]

        for j, candidate in enumerate(results):
            if j in consumed:
                continue
            if _iou(anchor['bbox'], candidate['bbox']) > iou_threshold:
                group.append(candidate)
                consumed.add(j)

        # From each group keep the single highest-confidence detection.
        if len(group) == 1:
            merged.append(group[0])
        else:
            merged.append(max(group, key=lambda d: d['confidence']))

    return merged
|
| 432 |
+
|
| 433 |
+
def classify_text(text):
    """Classify text based on patterns"""
    if not text:
        return 'Unknown'

    # Normalise before matching: uppercase and remove all whitespace.
    normalized = re.sub(r'\s+', '', text.strip().upper())

    # First pattern that matches wins; order of TEXT_PATTERNS decides ties.
    matches = (text_type for text_type, pattern in TEXT_PATTERNS.items()
               if re.match(pattern, normalized))
    return next(matches, 'Unknown')
|
| 447 |
+
|
| 448 |
+
def annotate_image(image, results):
    """Create annotated image with detections"""
    # Boxes and labels are coloured, so force an RGB canvas first.
    if image.mode != 'RGB':
        image = image.convert('RGB')

    draw = ImageDraw.Draw(image)
    try:
        font = ImageFont.truetype("arial.ttf", 20)
    except IOError:
        # Fall back to PIL's built-in bitmap font when Arial is absent.
        font = ImageFont.load_default()

    # One colour per recognised text category.
    colors = {
        'Line_Number': "#FF0000",         # Bright Red
        'Equipment_Tag': "#00FF00",       # Bright Green
        'Instrument_Tag': "#0000FF",      # Bright Blue
        'Valve_Number': "#FFA500",        # Bright Orange
        'Pipe_Size': "#FF00FF",           # Bright Magenta
        'Process_Instrument': "#00FFFF",  # Bright Cyan
        'Nozzle': "#FFFF00",              # Yellow
        'Pipe_Connector': "#800080",      # Purple
        'Unknown': "#FF4444"              # Light Red
    }

    for detection in results:
        category = detection.get('text_type', 'Unknown')
        color = colors.get(category, colors['Unknown'])

        # Bounding box around the detected text.
        draw.rectangle(detection['bbox'], outline=color, width=3)

        # Label: "<text> (<confidence>)" plus the category when classified.
        label = f"{detection['text']} ({detection['confidence']:.2f})"
        if category != 'Unknown':
            label += f" [{category}]"

        # White backdrop behind the label, then the label itself, drawn
        # just above the box's top-left corner.
        anchor = (detection['bbox'][0], detection['bbox'][1] - 20)
        draw.rectangle(draw.textbbox(anchor, label, font=font), fill="#FFFFFF")
        draw.text(anchor, label, fill=color, font=font)

    return image
|
| 495 |
+
|
| 496 |
+
def save_annotated_image(image, path, storage):
    """Save annotated image with maximum quality"""
    # PNG with zero compression and no optimizer pass: every pixel is kept
    # intact for downstream inspection, at the cost of a larger file.
    buffer = io.BytesIO()
    image.save(buffer, format='PNG', optimize=False, compress_level=0)
    storage.save_file(path, buffer.getvalue())
|
| 506 |
+
|
| 507 |
+
if __name__ == "__main__":
    # Smoke-test entry point: runs the full text-detection pipeline on one
    # hard-coded page image and prints a human-readable summary.
    from storage import StorageFactory
    import logging

    # Configure logging
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    # Initialize storage (backend chosen by the factory's configuration)
    storage = StorageFactory.get_storage()

    # Test file paths
    file_path = "processed_pages/10219-1-DG-BC-00011.01-REV_A_page_1_text.png"
    result_path = "results"

    try:
        # Ensure result directory exists
        os.makedirs(result_path, exist_ok=True)

        # Process the drawing
        logger.info(f"Processing file: {file_path}")
        results, summary = process_drawing(file_path, result_path, storage)

        # Print detailed results
        print("\n=== DETAILED DETECTION RESULTS ===")
        print(f"\nTotal Detections: {summary['total_detections']}")

        # NOTE(review): this block expects summary keys 'by_type' (with
        # per-type 'items') and 'confidence_ranges'; confirm process_drawing
        # actually populates them in addition to 'by_source'.
        print("\nBreakdown by Text Type:")
        print("-" * 50)
        for text_type, stats in summary['by_type'].items():
            if stats['count'] > 0:
                print(f"\n{text_type}:")
                print(f"  Count: {stats['count']}")
                print(f"  Average Confidence: {stats['avg_confidence']:.2f}")
                print("  Items:")
                for item in stats['items']:
                    print(f"    - {item['text']} (conf: {item['confidence']:.2f}, source: {item['source']})")

        print("\nBreakdown by OCR Engine:")
        print("-" * 50)
        for source, count in summary['by_source'].items():
            print(f"{source}: {count} detections")

        print("\nConfidence Distribution:")
        print("-" * 50)
        for range_name, count in summary['confidence_ranges'].items():
            print(f"{range_name}: {count} detections")

        # Print output paths
        print("\nOutput Files:")
        print("-" * 50)
        print(f"Annotated Image: {results['image_path']}")
        print(f"JSON Results: {results['json_path']}")

    except Exception as e:
        # Log and re-raise so the process exits non-zero on failure.
        logger.error(f"Error processing file: {e}")
        raise
|
utils.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
import numpy as np
|
| 3 |
+
from contextlib import contextmanager
|
| 4 |
+
from loguru import logger
|
| 5 |
+
from typing import List, Dict, Optional, Tuple, Union
|
| 6 |
+
from detection_schema import BBox
|
| 7 |
+
from storage import StorageInterface
|
| 8 |
+
import cv2
|
| 9 |
+
|
| 10 |
+
class DebugHandler:
    """Production-grade debugging and performance tracking.

    With ``enabled=False`` (the default) every method is a cheap no-op, so an
    instance can be threaded through production code unconditionally.
    """

    def __init__(self, enabled: bool = False, storage: Optional["StorageInterface"] = None):
        """
        Args:
            enabled: turn on timing/artifact capture.
            storage: backend used by save_artifact; may be None when artifacts
                are never saved.
        """
        self.enabled = enabled
        self.storage = storage
        self.metrics: Dict[str, float] = {}  # operation name -> duration (s)
        self._start_time = None  # retained for backward compatibility only

    @contextmanager
    def track_performance(self, operation_name: str):
        """Context manager recording the wall-clock duration of a named
        operation into ``self.metrics`` (only when debugging is enabled)."""
        # BUG FIX: the previous version stored the start time on the shared
        # self._start_time attribute, so nested or interleaved trackers
        # clobbered each other's timings. Keep the start time local instead.
        start = time.perf_counter() if self.enabled else None
        if self.enabled:
            self._start_time = start
            logger.debug(f"Starting {operation_name}")

        yield

        if self.enabled:
            duration = time.perf_counter() - start
            self.metrics[operation_name] = duration
            logger.debug(f"{operation_name} completed in {duration:.2f}s")

    def save_artifact(self, name: str, data: Union[bytes, np.ndarray], extension: str = "png"):
        """Generic artifact storage handler.

        Accepts raw bytes or a numpy image; images are encoded into
        ``extension`` format first. Writes to ``debug/<name>.<extension>``.
        No-op unless debugging is enabled and a storage backend is set.
        """
        if not (self.enabled and self.storage):
            return

        path = f"debug/{name}.{extension}"

        if isinstance(data, np.ndarray):
            # Encode the image array into the requested container format.
            success, encoded_image = cv2.imencode(f".{extension}", data)
            if not success:
                logger.error("Failed to encode image for saving.")
                return
            data = encoded_image.tobytes()

        self.storage.save_file(path, data)
        logger.info(f"Saved debug artifact: {path}")
|
| 49 |
+
|
| 50 |
+
class CoordinateTransformer:
    """Translate bounding boxes and points between global page coordinates
    and ROI-local coordinates.

    ``roi`` is a 4-element sequence whose first two values are the ROI's
    (x_min, y_min) offset; any other value disables the transform and the
    input is returned unchanged.
    """

    @staticmethod
    def global_to_local_bbox(
        bbox: Union["BBox", List["BBox"]],
        roi: Optional[np.ndarray]
    ) -> Union["BBox", List["BBox"]]:
        """
        Convert global BBox(es) to ROI-local coordinates.
        Handles both a single BBox and a list of BBoxes.
        """
        if roi is None or len(roi) != 4:
            return bbox

        x_min, y_min, _, _ = roi

        def convert(b: "BBox") -> "BBox":
            return BBox(
                xmin=b.xmin - x_min,
                ymin=b.ymin - y_min,
                xmax=b.xmax - x_min,
                ymax=b.ymax - y_min
            )

        # BUG FIX: previously returned a lazy ``map`` object for list input,
        # despite the declared List[BBox] return type. A map is single-use
        # and supports neither len() nor indexing; materialize a real list.
        return [convert(b) for b in bbox] if isinstance(bbox, list) else convert(bbox)

    @staticmethod
    def local_to_global_bbox(
        bbox: Union["BBox", List["BBox"]],
        roi: Optional[np.ndarray]
    ) -> Union["BBox", List["BBox"]]:
        """
        Convert ROI-local BBox(es) to global coordinates.
        Handles both a single BBox and a list of BBoxes.
        """
        if roi is None or len(roi) != 4:
            return bbox

        x_min, y_min, _, _ = roi

        def convert(b: "BBox") -> "BBox":
            return BBox(
                xmin=b.xmin + x_min,
                ymin=b.ymin + y_min,
                xmax=b.xmax + x_min,
                ymax=b.ymax + y_min
            )

        # BUG FIX: same as global_to_local_bbox — return a list, not a map.
        return [convert(b) for b in bbox] if isinstance(bbox, list) else convert(bbox)

    # Maintain legacy tuple support if needed
    @staticmethod
    def global_to_local(
        bboxes: List[Tuple[int, int, int, int]],
        roi: Optional[np.ndarray]
    ) -> List[Tuple[int, int, int, int]]:
        """Legacy tuple version for backward compatibility"""
        if roi is None or len(roi) != 4:
            return bboxes

        x_min, y_min, _, _ = roi
        return [(x1 - x_min, y1 - y_min, x2 - x_min, y2 - y_min)
                for x1, y1, x2, y2 in bboxes]

    @staticmethod
    def local_to_global(
        bboxes: List[Tuple[int, int, int, int]],
        roi: Optional[np.ndarray]
    ) -> List[Tuple[int, int, int, int]]:
        """Legacy tuple version for backward compatibility"""
        if roi is None or len(roi) != 4:
            return bboxes

        x_min, y_min, _, _ = roi
        return [(x1 + x_min, y1 + y_min, x2 + x_min, y2 + y_min)
                for x1, y1, x2, y2 in bboxes]

    @staticmethod
    def local_to_global_point(point: Tuple[int, int], roi: Optional[np.ndarray]) -> Tuple[int, int]:
        """Convert a single point from local to global coordinates."""
        if roi is None or len(roi) != 4:
            return point
        x_min, y_min, _, _ = roi
        return (int(point[0] + x_min), int(point[1] + y_min))
|