msIntui commited on
Commit
910e0d4
Β·
0 Parent(s):

feat: initial clean deployment

Browse files
.gitignore ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Directories to ignore
2
+ archive/
3
+ debug/
4
+ samples/
5
+ chat/
6
+ # Models - allow specific model files
7
+ models/*
8
+ !models/yolo/
9
+ !models/deeplsd/
10
+ !models/doctr/
11
+ !models/*.pt
12
+ !models/*.tar
13
+ results/
14
+ logs/
15
+ DeepLSD/
16
+
17
+ # Large files
18
+ *.tar
19
+ *.pt
20
+ *.pth
21
+ *.onnx
22
+ *.weights
23
+
24
+ # Python
25
+ __pycache__/
26
+ *.py[cod]
27
+ *$py.class
28
+ *.so
29
+ .Python
30
+ env/
31
+ build/
32
+ develop-eggs/
33
+ dist/
34
+ downloads/
35
+ eggs/
36
+ .eggs/
37
+ lib/
38
+ lib64/
39
+ parts/
40
+ sdist/
41
+ var/
42
+ *.egg-info/
43
+ .installed.cfg
44
+ *.egg
45
+
46
+ # Virtual Environment
47
+ .venv
48
+ venv/
49
+ ENV/
50
+
51
+ # IDE
52
+ .idea/
53
+ .vscode/
54
+ *.swp
55
+ *.swo
56
+
57
+ # Project specific
58
+ !results/
59
+ !results/*.json
60
+ debug/
61
+ *.log
62
+ *.gz
63
+ models/
64
+ archive/
65
+ weights/
66
+
67
+ # Environment variables
68
+ .env
69
+ .env.*
70
+
71
+ # Explicitly track assets
72
+ !assets/
73
+ !assets/*.png
74
+ !assets/*.css
README.md ADDED
@@ -0,0 +1,494 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Intelligent_PID
3
+ emoji: πŸ”
4
+ colorFrom: red
5
+ colorTo: red
6
+ sdk: gradio
7
+ sdk_version: 3.50.2
8
+ app_file: gradioChatApp.py
9
+ pinned: false
10
+ ---
11
+
12
+ # P&ID Processing with AI-Powered Graph Construction
13
+
14
+ ## Overview
15
+ This project processes P&ID (Piping and Instrumentation Diagram) images using multiple AI models for symbol detection, text recognition, and line detection. It constructs a graph representation of the diagram and provides an interactive interface for querying the diagram's contents.
16
+
17
+ ## Features
18
+ - P&ID Document Processing
19
+ - Symbol Detection
20
+ - Text Recognition
21
+ - Line Detection
22
+ - Knowledge Graph Generation
23
+ - Interactive Chat Interface
24
+
25
+ ## Usage
26
+ 1. Upload a P&ID document
27
+ 2. Click "Process Document"
28
+ 3. View results in different tabs
29
+ 4. Ask questions about the P&ID in the chat
30
+
31
+ ## Process Flow
32
+
33
+ ```mermaid
34
+ graph TD
35
+ subgraph "Document Input"
36
+ A[Upload Document] --> B[Validate File]
37
+ B -->|PDF/Image| C[Document Processor]
38
+ B -->|Invalid| ERR[Error Message]
39
+ C -->|PDF| D1[Extract Pages]
40
+ C -->|Image| D2[Direct Process]
41
+ end
42
+
43
+ subgraph "Image Preprocessing"
44
+ D1 --> E[Optimize Image]
45
+ D2 --> E
46
+ E -->|CLAHE Enhancement| E1[Contrast Enhancement]
47
+ E1 -->|Denoising| E2[Clean Image]
48
+ E2 -->|Binarization| E3[Binary Image]
49
+ E3 -->|Resize| E4[Normalized Image]
50
+ end
51
+
52
+ subgraph "Line Detection Pipeline"
53
+ E4 --> L1[Load DeepLSD Model]
54
+ L1 --> L2[Scale Image 0.1x]
55
+ L2 --> L3[Grayscale Conversion]
56
+ L3 --> L4[Model Inference]
57
+ L4 --> L5[Scale Coordinates]
58
+ L5 --> L6[Draw Lines]
59
+ end
60
+
61
+ subgraph "Detection Pipeline"
62
+ E4 --> F[Symbol Detection]
63
+ E4 --> G[Text Detection]
64
+
65
+ F --> S1[Load YOLO Models]
66
+ G --> T1[Load OCR Models]
67
+
68
+ S1 --> S2[Detect Symbols]
69
+ T1 --> T2[Detect Text]
70
+
71
+ S2 --> S3[Process Symbols]
72
+ T2 --> T3[Process Text]
73
+
74
+ L6 --> L7[Process Lines]
75
+ end
76
+
77
+ subgraph "Data Integration"
78
+ S3 --> I[Data Aggregation]
79
+ T3 --> I
80
+ L7 --> I
81
+ I --> J[Create Edges]
82
+ J --> K[Build Graph Network]
83
+ K --> L[Generate Knowledge Graph]
84
+ end
85
+
86
+ subgraph "User Interface"
87
+ L --> M[Interactive Visualization]
88
+ M --> N[Chat Interface]
89
+ N --> O[Query Processing]
90
+ O --> P[Response Generation]
91
+ P --> N
92
+ end
93
+
94
+ style A fill:#f9f,stroke:#333,stroke-width:2px
95
+ style F fill:#fbb,stroke:#333,stroke-width:2px
96
+ style G fill:#bfb,stroke:#333,stroke-width:2px
97
+ %% style H fill:#bbf,stroke:#333,stroke-width:2px  (disabled: no node H is defined in this diagram)
98
+ style I fill:#fbf,stroke:#333,stroke-width:2px
99
+ style N fill:#bbf,stroke:#333,stroke-width:2px
100
+
101
+ %% Add style for model nodes -- disabled: nodes SM1/SM2/LM1/DC1/DC2 are not
+ %% defined anywhere in this diagram, so styling them breaks rendering.
+ %% style SM1 fill:#ffe6e6,stroke:#333,stroke-width:2px
+ %% style SM2 fill:#ffe6e6,stroke:#333,stroke-width:2px
+ %% style LM1 fill:#e6e6ff,stroke:#333,stroke-width:2px
+ %% style DC1 fill:#e6ffe6,stroke:#333,stroke-width:2px
+ %% style DC2 fill:#e6ffe6,stroke:#333,stroke-width:2px
107
+ ```
108
+
109
+ ## Architecture
110
+
111
+ ![Project Architecture](./assets/P&ID_to_Graph.drawio.png)
112
+
113
+ ## Features
114
+
115
+ - **Multi-modal AI Processing**:
116
+ - Combined OCR approach using Tesseract, EasyOCR, and DocTR
117
+ - Symbol detection with optimized thresholds
118
+ - Intelligent line and connection detection
119
+ - **Document Processing**:
120
+ - Support for PDF, PNG, JPG, JPEG formats
121
+ - Automatic page extraction from PDFs
122
+ - Image optimization pipeline
123
+ - **Text Detection Types**:
124
+ - Equipment Tags
125
+ - Line Numbers
126
+ - Instrument Tags
127
+ - Valve Numbers
128
+ - Pipe Sizes
129
+ - Flow Directions
130
+ - Service Descriptions
131
+ - Process Instruments
132
+ - Nozzles
133
+ - Pipe Connectors
134
+ - **Data Integration**:
135
+ - Automatic edge detection
136
+ - Relationship mapping
137
+ - Confidence scoring
138
+ - Detailed detection statistics
139
+ - **User Interface**:
140
+ - Interactive visualization tabs
141
+ - Real-time processing feedback
142
+ - AI-powered chat interface
143
+ - Knowledge graph exploration
144
+
145
+ The entire process is visualized through an interactive Gradio-based UI, allowing users to upload a P&ID image, follow the detection steps, and view both the results and insights in real time.
146
+
147
+ ## Key Files
148
+
149
+ - **gradioChatApp.py**: The main Gradio app script that handles the frontend and orchestrates the overall flow.
150
+ - **symbol_detection.py**: Module for detecting symbols using YOLO models.
151
+ - **text_detection_combined.py**: Unified module for text detection using multiple OCR engines (Tesseract, EasyOCR, DocTR).
152
+ - **line_detection_ai.py**: Module for detecting lines and connections using AI.
153
+ - **data_aggregation.py**: Aggregates detected elements into a structured format.
154
+ - **graph_construction.py**: Constructs the graph network from aggregated data.
155
+ - **graph_processor.py**: Handles graph visualization and processing.
156
+ - **pdf_processor.py**: Handles PDF document processing and page extraction.
157
+
158
+ ## Setup and Installation
159
+
160
+ 1. Clone the repository:
161
+ ```bash
162
+ git clone https://github.com/IntuigenceAI/intui-PnID-POC.git
163
+ cd intui-PnID-POC
164
+ ```
165
+
166
+ 2. Install dependencies using uv:
167
+ ```bash
168
+ # Install uv if you haven't already
169
+ curl -LsSf https://astral.sh/uv/install.sh | sh
170
+
171
+ # Create and activate virtual environment
172
+ uv venv
173
+ source .venv/bin/activate # On Windows: .venv\Scripts\activate
174
+
175
+ # Install dependencies
176
+ uv pip install -r requirements.txt
177
+ ```
178
+
179
+ 3. Download required models:
180
+ ```bash
181
+ python download_model.py # Downloads DeepLSD model for line detection
182
+ ```
183
+
184
+ 4. Run the application:
185
+ ```bash
186
+ python gradioChatApp.py
187
+ ```
188
+
189
+ ## Models
190
+
191
+ ### Line Detection Model
192
+ - **DeepLSD Model**:
193
+ - File: deeplsd_md.tar
194
+ - Purpose: Line segment detection in P&ID diagrams
195
+ - Input Resolution: Variable (scaled to 0.1x for performance)
196
+ - Processing: Grayscale conversion and binary thresholding
197
+
198
+ ### Text Detection Models
199
+ - **Combined OCR Approach**:
200
+ - Tesseract OCR
201
+ - EasyOCR
202
+ - DocTR
203
+ - Purpose: Text recognition and classification
204
+
205
+ ### Graph Processing
206
+ - **NetworkX-based**:
207
+ - Purpose: Graph construction and analysis
208
+ - Features: Node linking, edge creation, path analysis
209
+
210
+ ## Updating the Environment
211
+
212
+ To update the environment, use the following:
213
+
214
+ ```bash
215
+ conda env update --file environment.yml --prune
216
+ ```
217
+
218
+ This command will update the environment according to changes made in the `environment.yml`.
219
+
220
+ ### Step 6: Deactivate the environment
221
+
222
+ When you're done, deactivate the environment by:
223
+
224
+ ```bash
225
+ conda deactivate
226
+ ```
227
+
228
+ 2. Upload a P&ID image through the interface.
229
+ 3. Follow the sequential steps of symbol, text, and line detection.
230
+ 4. View the generated graph and AI agent's reasoning in the real-time chat box.
231
+ 5. Save and export the results if satisfactory.
232
+
233
+ ## Folder Structure
234
+
235
+ ```
236
+ β”œβ”€β”€ assets/
237
+ β”‚ └── AiAgent.png
238
+ β”‚ └── llm.png
239
+ β”œβ”€β”€ gradioApp.py
240
+ β”œβ”€β”€ symbol_detection.py
241
+ β”œβ”€β”€ text_detection_combined.py
242
+ β”œβ”€β”€ line_detection_ai.py
243
+ β”œβ”€β”€ data_aggregation.py
244
+ β”œβ”€β”€ graph_construction.py
245
+ β”œβ”€β”€ graph_processor.py
246
+ β”œβ”€β”€ pdf_processor.py
247
+ β”œβ”€β”€ pnid_agent.py
248
+ β”œβ”€β”€ requirements.txt
249
+ β”œβ”€β”€ results/
250
+ β”œβ”€β”€ models/
251
+ β”‚ └── symbol_detection_model.pth
252
+ ```
253
+
254
+ ## /models Folder
255
+
256
+ - **models/symbol_detection_model.pth**: This folder contains the pre-trained model for symbol detection in P&ID diagrams. This model is crucial for detecting key symbols such as valves, instruments, and pipes in the diagram. Make sure to download the model and place it in the `/models` directory before running the app.
257
+
258
+ ## Future Work
259
+
260
+ - **Advanced Symbol Recognition**: Improve symbol detection by integrating more sophisticated recognition models.
261
+ - **Graph Enhancement**: Introduce more complex graph structures and logic for representing the relationships between the diagram's elements.
262
+ - **Data Export**: Allow export in additional formats such as DEXPI-compliant XML or JSON.
263
+
264
+
265
+ # Docker Information
266
+
267
+ We'll cover the basic docker operations here.
268
+
269
+ ## Building
270
+
271
+ There is a dockerfile for each different project (they have slightly different requirements).
272
+
273
+ ### `gradioChatApp.py`
274
+
275
+ Run this one as follows:
276
+
277
+ ```
278
+ > docker build -t exp-pnid-to-graph_chat-w-graph:0.0.4 -f Dockerfile-chatApp .
279
+ > docker tag exp-pnid-to-graph_chat-w-graph:0.0.4 intaicr.azurecr.io/intai/exp-pnid-to-graph_chat-w-graph:0.0.4
280
+ ```
281
+
282
+ ## Deploying to ACR
283
+
284
+ ### `gradioChatApp.py`
285
+
286
+ ```
287
+ > az login
288
+ > az acr login --name intaicr
289
+ > docker push intaicr.azurecr.io/intai/exp-pnid-to-graph_chat-w-graph:0.0.4
290
+ ```
291
+
292
+ ## Models
293
+
294
+ ### Symbol Detection Models
295
+ - **Intui_SDM_41.pt**: Primary model for equipment and large symbol detection
296
+ - Classes: Equipment, Vessels, Heat Exchangers
297
+ - Input Resolution: 1280x1280
298
+ - Confidence Threshold: 0.3-0.7 (adaptive)
299
+
300
+ - **Intui_SDM_20.pt**: Secondary model for instrument and small symbol detection
301
+ - Classes: Instruments, Valves, Indicators
302
+ - Input Resolution: 1280x1280
303
+ - Confidence Threshold: 0.3-0.7 (adaptive)
304
+
305
+ ### Line Detection Model
306
+ - **intui_LDM_01.pt**: Specialized model for line and connection detection
307
+ - Classes: Solid Lines, Dashed Lines
308
+ - Input Resolution: 1280x1280
309
+ - Confidence Threshold: 0.5
310
+
311
+ ### Text Detection Models
312
+ - **Tesseract**: v5.3.0
313
+ - Configuration:
314
+ - OEM Mode: 3 (Default)
315
+ - PSM Mode: 11 (Sparse text)
316
+ - Custom Whitelist: A-Z, 0-9, special characters
317
+
318
+ - **EasyOCR**: v1.7.1
319
+ - Configuration:
320
+ - Language: English
321
+ - Paragraph Mode: False
322
+ - Height Threshold: 2.0
323
+ - Width Threshold: 2.0
324
+ - Contrast Threshold: 0.2
325
+
326
+ - **DocTR**: v0.6.0
327
+ - Models:
328
+ - fast_base-688a8b34.pt
329
+ - crnn_vgg16_bn-9762b0b0.pt
330
+
331
+ # P&ID Line Detection
332
+
333
+ A deep learning-based pipeline for detecting lines in P&ID diagrams using DeepLSD.
334
+
335
+ ## Architecture
336
+ ```mermaid
337
+ graph TD
338
+ A[Input Image] --> B[Line Detection]
339
+ B --> C[DeepLSD Model]
340
+ C --> D[Post-processing]
341
+ D --> E[Output JSON/Image]
342
+
343
+ subgraph Line Detection Pipeline
344
+ B --> F[Image Preprocessing]
345
+ F --> G[Scale Image 0.1x]
346
+ G --> H[Grayscale Conversion]
347
+ H --> C
348
+ C --> I[Scale Coordinates]
349
+ I --> J[Draw Lines]
350
+ J --> E
351
+ end
352
+ ```
353
+
354
+ ## Setup
355
+
356
+ ### Prerequisites
357
+ - Python 3.12+
358
+ - uv (for dependency management)
359
+ - Git
360
+ - CUDA-capable GPU (optional)
361
+
362
+ ### Installation
363
+
364
+ 1. Clone the repository:
365
+ ```bash
366
+ git clone https://github.com/IntuigenceAI/intui-PnID-POC.git
367
+ cd intui-PnID-POC
368
+ ```
369
+
370
+ 2. Install dependencies using uv:
371
+ ```bash
372
+ # Install uv if you haven't already
373
+ curl -LsSf https://astral.sh/uv/install.sh | sh
374
+
375
+ # Create and activate virtual environment
376
+ uv venv
377
+ source .venv/bin/activate # On Windows: .venv\Scripts\activate
378
+
379
+ # Install dependencies
380
+ uv pip install -r requirements.txt
381
+ ```
382
+
383
+ 3. Download DeepLSD model:
384
+ ```bash
385
+ python download_model.py
386
+ ```
387
+
388
+ ## Usage
389
+
390
+ 1. Run the line detection:
391
+ ```bash
392
+ python line_detection_ai.py
393
+ ```
394
+
395
+ The script will:
396
+ - Load the DeepLSD model
397
+ - Process input images at 0.1x scale for performance
398
+ - Generate line detections
399
+ - Save results as JSON and annotated images
400
+
401
+ ## Configuration
402
+
403
+ Key parameters in `line_detection_ai.py`:
404
+ - `scale_factor`: Image scaling (default: 0.1)
405
+ - `device`: CPU/GPU selection
406
+ - `mask_json_paths`: Paths to text/symbol detection results
407
+
408
+ ## Input/Output
409
+
410
+ ### Input
411
+ - Original P&ID images
412
+ - Optional text/symbol detection JSON files for masking
413
+
414
+ ### Output
415
+ - Annotated images with detected lines
416
+ - JSON files containing line coordinates and metadata
417
+
418
+ ## Project Structure
419
+
420
+ ```
421
+ β”œβ”€β”€ line_detection_ai.py # Main line detection script
422
+ β”œβ”€β”€ detectors.py # Line detector implementation
423
+ β”œβ”€β”€ download_model.py # Model download utility
424
+ β”œβ”€β”€ models/ # Directory for model files
425
+ β”‚ └── deeplsd_md.tar # DeepLSD model weights
426
+ β”œβ”€β”€ results/ # Output directory
427
+ └── requirements.txt # Project dependencies
428
+ ```
429
+
430
+ ## Dependencies
431
+
432
+ Key dependencies:
433
+ - torch
434
+ - opencv-python
435
+ - numpy
436
+ - DeepLSD
437
+
438
+ See `requirements.txt` for the complete list.
439
+
440
+ ## Contributing
441
+
442
+ 1. Fork the repository
443
+ 2. Create your feature branch (`git checkout -b feature/amazing-feature`)
444
+ 3. Commit your changes (`git commit -m 'Add some amazing feature'`)
445
+ 4. Push to the branch (`git push origin feature/amazing-feature`)
446
+ 5. Open a Pull Request
447
+
448
+ ## License
449
+
450
+ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
451
+
452
+ ## Acknowledgments
453
+
454
+ - [DeepLSD](https://github.com/cvg/DeepLSD) for the line detection model
455
+ - Original P&ID processing pipeline by IntuigenceAI
456
+ ---
457
+ title: PnID Diagram Analyzer
458
+ emoji: πŸ”
459
+ colorFrom: blue
460
+ colorTo: red
461
+ sdk: gradio
462
+ sdk_version: 4.19.2
463
+ app_file: gradioChatApp.py
464
+ pinned: false
465
+ ---
466
+
467
+ # PnID Diagram Analyzer
468
+
469
+ This app analyzes PnID diagrams using AI to detect and interpret various elements.
470
+
471
+ ## Features
472
+ - Line detection
473
+ - Symbol recognition
474
+ - Text detection
475
+ - Graph construction
476
+
477
+ # Intuigence P&ID Analyzer
478
+
479
+ Interactive P&ID analysis tool powered by AI.
480
+
481
+ ## Features
482
+ - P&ID Document Processing
483
+ - Symbol Detection
484
+ - Text Recognition
485
+ - Line Detection
486
+ - Knowledge Graph Generation
487
+ - Interactive Chat Interface
488
+
489
+ ## Usage
490
+ 1. Upload a P&ID document
491
+ 2. Click "Process Document"
492
+ 3. View results in different tabs
493
+ 4. Ask questions about the P&ID in the chat
494
+
assets/AiAgent.png ADDED
assets/intuigence.png ADDED
assets/user.png ADDED
base.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import List, Optional, Dict
5
+
6
+ import numpy as np
7
+ import cv2
8
+
9
+ from pathlib import Path
10
+ from loguru import logger
11
+ import json
12
+
13
+ from common import DetectionResult
14
+ from storage import StorageInterface
15
+ from utils import DebugHandler, CoordinateTransformer
16
+
17
+
18
class BaseConfig(ABC):
    """Abstract base for every configuration class in the project."""

    def __post_init__(self):
        """Hook for dataclass subclasses to normalize or validate defaults.

        Intentionally a no-op here; subclasses override as needed.
        """
24
+
25
class BaseDetector(ABC):
    """Abstract base class for detection models.

    Subclasses wrap a single model and implement loading, pre/post
    processing, and inference on numpy images.
    """

    def __init__(self,
                 config: BaseConfig,
                 debug_handler: DebugHandler = None):
        # A fresh DebugHandler is created when none (or a falsy one) is given.
        self.config = config
        self.debug_handler = debug_handler or DebugHandler()

    @abstractmethod
    def _load_model(self, model_path: str):
        """Load and return the detection model from *model_path*."""
        pass

    @abstractmethod
    def detect(self, image: np.ndarray, *args, **kwargs):
        """Run detection on an input image and return the detections."""
        pass

    @abstractmethod
    def _preprocess(self, image: np.ndarray) -> np.ndarray:
        """Preprocess the input image before detection."""
        pass

    @abstractmethod
    def _postprocess(self, image: np.ndarray) -> np.ndarray:
        """Postprocess the image after detection.

        NOTE(review): the original docstring said "before detection",
        an apparent copy-paste slip from _preprocess.
        """
        pass
53
+
54
+
55
class BaseDetectionPipeline(ABC):
    """Abstract base class for detection pipelines.

    Concrete subclasses implement process_image(); the helpers here
    provide shared ROI cropping, coordinate adjustment, and result
    persistence.
    """

    def __init__(
        self,
        storage: StorageInterface,
        debug_handler=None
    ):
        # self.detector = detector
        self.storage = storage
        self.debug_handler = debug_handler or DebugHandler()
        self.transformer = CoordinateTransformer()

    @abstractmethod
    def process_image(
        self,
        image_path: str,
        output_dir: str,
        config
    ) -> DetectionResult:
        """Main processing pipeline for a single image."""
        pass

    def _apply_roi(self, image: np.ndarray, roi: np.ndarray) -> np.ndarray:
        """Apply region-of-interest cropping.

        roi is unpacked as (x_min, y_min, x_max, y_max); the image is
        returned unchanged when roi is missing or not length 4.
        """
        if roi is not None and len(roi) == 4:
            x_min, y_min, x_max, y_max = roi
            return image[y_min:y_max, x_min:x_max]
        return image

    def _adjust_coordinates(self, detections: List[Dict], roi: np.ndarray) -> List[Dict]:
        """Translate detection bboxes from ROI-local back to full-image space.

        Detections with a missing or malformed "bbox" are skipped with a
        warning instead of aborting the whole batch.
        """
        if roi is None or len(roi) != 4:
            return detections

        x_offset, y_offset = roi[0], roi[1]
        adjusted = []

        for det in detections:
            try:
                adjusted_bbox = [
                    int(det["bbox"][0] + x_offset),
                    int(det["bbox"][1] + y_offset),
                    int(det["bbox"][2] + x_offset),
                    int(det["bbox"][3] + y_offset)
                ]
                adjusted_det = {**det, "bbox": adjusted_bbox}
                adjusted.append(adjusted_det)
            # Broadened from KeyError alone: a short or non-numeric bbox
            # raises IndexError/TypeError and previously crashed the pipeline.
            except (KeyError, IndexError, TypeError):
                logger.warning("Invalid detection format during coordinate adjustment")
        return adjusted

    def _persist_results(
        self,
        output_dir: str,
        image_path: str,
        detections: List[Dict],
        annotated_image: Optional[np.ndarray]
    ) -> Dict[str, str]:
        """Save detection results as JSON plus an optional annotated image.

        Returns:
            Dict with "json_path" (always set) and "image_path" (None when
            no annotated image was written).
        """
        self.storage.create_directory(output_dir)
        base_name = Path(image_path).stem

        # Save JSON results
        json_path = Path(output_dir) / f"{base_name}_lines.json"
        self.storage.save_file(
            str(json_path),
            json.dumps({
                "solid_lines": {"lines": detections},
                "dashed_lines": {"lines": []}
            }, indent=2).encode('utf-8')
        )

        # Save annotated image
        img_path = None
        if annotated_image is not None:
            # cv2.imencode returns (success, buffer); the success flag was
            # previously ignored, which could persist a garbage buffer.
            ok, img_data = cv2.imencode('.jpg', annotated_image)
            if ok:
                img_path = Path(output_dir) / f"{base_name}_annotated.jpg"
                self.storage.save_file(str(img_path), img_data.tobytes())
            else:
                logger.warning("Failed to encode annotated image; skipping save")

        return {
            "json_path": str(json_path),
            "image_path": str(img_path) if img_path else None
        }
base_config.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
@dataclass
class BaseConfig:
    """Base configuration class.

    Accepts arbitrary keyword arguments and stores each one as an
    instance attribute, so callers can configure objects ad hoc.
    """
    def __init__(self, **kwargs):
        for name in kwargs:
            setattr(self, name, kwargs[name])
chatbot_agent.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # chatbot_agent.py
2
+
3
+ import os
4
+ import json
5
+ import re
6
+ from openai import OpenAI
7
+ import traceback
8
+ import logging
9
+ from dotenv import load_dotenv
10
+
11
+ # Load environment variables
12
+ load_dotenv()
13
+
14
+ # Get logger
15
+ logger = logging.getLogger(__name__)
16
+
17
def get_openai_client():
    """Return an OpenAI client configured from the OPENAI_API_KEY env var.

    Raises:
        ValueError: if the API key is not present in the environment.
    """
    key = os.getenv("OPENAI_API_KEY")
    if key:
        return OpenAI(api_key=key)
    raise ValueError("OpenAI API key not found in environment variables")
23
+
24
def format_message(role, content):
    """Build one chat-history entry in the OpenAI message format."""
    message = dict(role=role, content=content)
    return message
27
+
28
def initialize_graph_prompt(graph_data):
    """Build the system prompt describing the knowledge-graph contents.

    Summarizes element counts and lists per-symbol attributes so the LLM
    can answer questions about the P&ID. Falls back to a generic prompt
    if the graph data cannot be read.
    """
    try:
        summary = graph_data.get('summary', {})

        # Count fields are optional; include only those present, in this
        # fixed order.
        count_labels = [
            ('symbol_count', 'Symbols'),
            ('text_count', 'Texts'),
            ('line_count', 'Lines'),
            ('edge_count', 'Edges'),
        ]
        counts = [f"{label}: {summary[key]}"
                  for key, label in count_labels if key in summary]
        summary_info = ", ".join(counts) + "."

        # Describe each symbol using whichever attributes it carries.
        attr_labels = [
            ('symbol_id', 'ID'),
            ('class_id', 'Class'),
            ('category', 'Category'),
            ('type', 'Type'),
            ('label', 'Label'),
        ]
        node_details = ""
        detailed_results = graph_data.get('detailed_results', {})
        if 'symbols' in detailed_results:
            pieces = ["Nodes (symbols) in the graph include:\n"]
            for symbol in detailed_results['symbols']:
                attrs = [f"{label}: {symbol[key]}"
                         for key, label in attr_labels if key in symbol]
                if attrs:  # skip symbols with none of the known attributes
                    pieces.append(", ".join(attrs) + "\n")
            node_details = "".join(pieces)

        return (
            "You have access to a knowledge graph generated from a P&ID diagram. "
            f"The summary information includes:\n{summary_info}\n\n"
            f"{node_details}\n"
            "Answer questions about the P&ID elements using this information."
        )

    except Exception as e:
        logger.error(f"Error creating initial prompt: {str(e)}")
        return ("I have access to a P&ID diagram knowledge graph. "
                "I can help answer questions about the diagram elements.")
80
+
81
def get_assistant_response(user_message, json_path):
    """Generate response based on P&ID data and OpenAI.

    Answers simple count-style questions (valves, pumps, equipment)
    directly from the aggregated detection JSON; any other question is
    delegated to the OpenAI chat API with a system prompt built from the
    same data via initialize_graph_prompt().

    Args:
        user_message: The user's question, free text.
        json_path: Path to the aggregated detections JSON file.

    Returns:
        A response string; a generic apology message on any error.
    """
    try:
        client = get_openai_client()
        # Load the aggregated data
        with open(json_path, 'r') as f:
            data = json.load(f)

        # Process the user's question (case-insensitive keyword matching)
        question = user_message.lower()

        # Use rule-based responses for specific questions.
        # NOTE(review): the "or" alternatives are redundant -- "valves"
        # already contains "valve" -- but harmless.
        if "valve" in question or "valves" in question:
            valve_count = sum(1 for symbol in data.get('symbols', [])
                              if 'class' in symbol and 'valve' in symbol['class'].lower())
            return f"I found {valve_count} valves in this P&ID."

        elif "pump" in question or "pumps" in question:
            pump_count = sum(1 for symbol in data.get('symbols', [])
                             if 'class' in symbol and 'pump' in symbol['class'].lower())
            return f"I found {pump_count} pumps in this P&ID."

        elif "equipment" in question or "components" in question:
            # Tally detected symbols per class name.
            equipment_types = {}
            for symbol in data.get('symbols', []):
                if 'class' in symbol:
                    eq_type = symbol['class']
                    equipment_types[eq_type] = equipment_types.get(eq_type, 0) + 1

            response = "Here's a summary of the equipment I found:\n"
            for eq_type, count in equipment_types.items():
                response += f"- {eq_type}: {count}\n"
            return response

        # For other questions, use OpenAI
        else:
            # Prepare the conversation context: counts plus the full raw
            # detection data for the prompt builder.
            graph_data = {
                "summary": {
                    "symbol_count": len(data.get('symbols', [])),
                    "text_count": len(data.get('texts', [])),
                    "line_count": len(data.get('lines', [])),
                    "edge_count": len(data.get('edges', [])),
                },
                "detailed_results": data
            }

            initial_prompt = initialize_graph_prompt(graph_data)
            conversation = [
                {"role": "system", "content": initial_prompt},
                {"role": "user", "content": user_message}
            ]

            response = client.chat.completions.create(
                model="gpt-4-turbo",
                messages=conversation
            )
            # Return only the assistant text from the first choice.
            return response.choices[0].message.content

    except Exception as e:
        # Broad catch is deliberate: chat must degrade gracefully rather
        # than crash the UI; details go to the log.
        logger.error(f"Error in get_assistant_response: {str(e)}")
        logger.error(traceback.format_exc())
        return "I apologize, but I encountered an error analyzing the P&ID data. Please try asking a different question."
144
+
145
# Testing and Usage block
if __name__ == "__main__":
    # Load the knowledge graph data from JSON file
    json_file_path = "results/0_aggregated_detections.json"
    try:
        with open(json_file_path, 'r') as file:
            graph_data = json.load(file)
    except FileNotFoundError:
        print(f"Error: File not found at {json_file_path}")
        graph_data = None
    except json.JSONDecodeError:
        print("Error: Failed to decode JSON. Please check the file format.")
        graph_data = None

    # Initialize conversation history with assistant's welcome message
    history = [format_message("assistant", "Hello! I am ready to answer your questions about the P&ID knowledge graph. The graph includes nodes (symbols), edges, linkers, and text tags, and I have detailed information available about each. Please ask any questions related to these elements and their connections.")]

    # Print the assistant's welcome message
    print("Assistant:", history[0]["content"])

    # Individual Testing Options
    if graph_data:
        # Option 1: Test the graph prompt initialization
        print("\n--- Test: Graph Prompt Initialization ---")
        initial_prompt = initialize_graph_prompt(graph_data)
        print(initial_prompt)

        # Option 2: Simulate a conversation with a test question
        print("\n--- Test: Simulate Conversation ---")
        test_question = "Can you tell me about the connections between the nodes?"
        history.append(format_message("user", test_question))

        print(f"\nUser: {test_question}")
        # BUG FIX: get_assistant_response returns a single string, not a
        # generator; the old `for response in get_assistant_response(...)`
        # iterated the string and printed one CHARACTER per line.
        response = get_assistant_response(test_question, json_file_path)
        print("Assistant:", response)
        history.append(format_message("assistant", response))

        # Option 3: Manually input questions for interactive testing
        while True:
            user_question = input("\nYou: ")
            if user_question.lower() in ["exit", "quit"]:
                print("Exiting chat. Goodbye!")
                break

            history.append(format_message("user", user_question))
            response = get_assistant_response(user_question, json_file_path)
            print("Assistant:", response)
            history.append(format_message("assistant", response))
    else:
        print("Unable to load graph data. Please check the file path and format.")
common.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import List, Dict, Optional, Tuple, Union
3
+ import numpy as np
4
+
5
+ from detection_schema import Line
6
+
7
@dataclass
class DetectionResult:
    """Outcome of a single detection pipeline run."""
    # Whether the pipeline completed without a fatal error.
    success: bool
    # Human-readable error description when success is False.
    error: Optional[str] = None
    # Input image with detections drawn on it, if one was produced.
    annotated_image: Optional[np.ndarray] = None
    # Wall-clock duration of the run -- presumably seconds; TODO confirm.
    processing_time: float = 0.0
    # Path to the persisted JSON detections file, once written.
    json_path: Optional[str] = None
    # Path to the persisted annotated image, once written.
    image_path: Optional[str] = None
config.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from typing import List, Dict, Optional, Tuple, Union
3
+ import numpy as np
4
+ from base import BaseConfig
5
+
6
+
7
@dataclass
class ImageConfig(BaseConfig):
    """Configuration for global image-related settings"""

    # Region of interest -- presumably (x_min, y_min, x_max, y_max), matching
    # the unpacking order in BaseDetectionPipeline._apply_roi; TODO confirm.
    roi: Optional[np.ndarray] = field(default_factory=lambda: np.array([500, 500, 5300, 4000]))
    # JSON file with text/symbol bounding boxes used for masking.
    mask_json_path: str = "./text_and_symbol_bboxes.json"
    # Whether annotated output images are written alongside results.
    save_annotations: bool = True
    # Drawing options for annotations; color tuples are 3-channel --
    # NOTE(review): BGR vs RGB order depends on the drawing code, confirm.
    annotation_style: Dict = field(default_factory=lambda: {
        'bbox_color': (255, 0, 0),
        'line_color': (0, 255, 0),
        'text_color': (0, 0, 255),
        'thickness': 2,
        'font_scale': 0.6
    })
+ })
21
+
22
@dataclass
class SymbolConfig(BaseConfig):
    """Configuration for Symbol Detection"""
    # Default weights file used when a single model is run.
    model_path: str = "models/symbol_detection.pt"
    # Minimum score for a detection to be kept.
    confidence_threshold: float = 0.5
    # Threshold for non-maximum suppression -- presumably IoU; confirm.
    nms_threshold: float = 0.3
    # Accepted symbol bounding-box size range in pixels -- TODO confirm
    # (width, height) order against the detector implementation.
    min_size: Tuple[int, int] = (10, 10)
    max_size: Tuple[int, int] = (200, 200)
    class_names: List[str] = field(default_factory=lambda: ["background", "valve", "pump", "sensor"])

    # Optional: Keep the multiple thresholds for experimentation
    confidence_thresholds: List[float] = field(default_factory=lambda: [0.1, 0.3, 0.5, 0.7, 0.9])
    # Two-model setup: per the README, SDM_41 targets equipment/large
    # symbols and SDM_20 targets instruments/small symbols.
    model_paths: Dict[str, str] = field(default_factory=lambda: {
        "model1": "models/Intui_SDM_41.pt",
        "model2": "models/Intui_SDM_20.pt"
    })
+ })
38
+
39
+
40
+
41
@dataclass
class TagConfig(BaseConfig):
    """Configuration for Tag Detection with OCR"""
    # Detection model/config file for tag regions.
    model_path: str = "models/tag_detection.json"
    # Minimum score for a tag detection to be kept.
    confidence_threshold: float = 0.5
    # IoU threshold for merging/suppressing overlapping boxes -- confirm.
    iou_threshold: float = 0.4
    # OCR backends combined for text recognition.
    ocr_engines: List[str] = field(default_factory=lambda: ['tesseract', 'easyocr', 'doctr'])
    # Regexes used to classify recognized text into P&ID tag categories.
    text_patterns: Dict[str, str] = field(default_factory=lambda: {
        'Line_Number': r"\d{1,5}-[A-Z]{2,4}-\d{1,3}",
        'Equipment_Tag': r"[A-Z]{1,3}-[A-Z0-9]{1,4}-\d{1,3}",
        'Instrument_Tag': r"\d{2,3}-[A-Z]{2,4}-\d{2,3}",
        'Valve_Number': r"[A-Z]{1,2}-\d{3}",
        'Pipe_Size': r"\d{1,2}\"",
        'Flow_Direction': r"FROM|TO",
        'Service_Description': r"STEAM|WATER|AIR|GAS|DRAIN",
        'Process_Instrument': r"\d{2,3}(?:-[A-Z]{2,3})?-\d{2,3}|[A-Z]{2,3}-\d{2,3}",
        'Nozzle': r"N[0-9]{1,2}|MH",
        'Pipe_Connector': r"[0-9]{1,5}|[A-Z]{1,2}[0-9]{2,5}"
    })
    # Tesseract flags: OEM 3 = default engine, PSM 11 = sparse text.
    tesseract_config: str = r'--oem 3 --psm 11'
    # Keyword arguments for EasyOCR -- presumably forwarded to readtext();
    # TODO confirm against the OCR module.
    easyocr_params: Dict = field(default_factory=lambda: {
        'paragraph': False,
        'height_ths': 2.0,
        'width_ths': 2.0,
        'contrast_ths': 0.2
    })
+ })
67
+
68
@dataclass
class LineConfig(BaseConfig):
    """Configuration for Line Detection"""

    # Max distance for associating points with a line -- presumably pixels;
    # TODO confirm against the line-detection module.
    threshold_distance: float = 10.0
    # Multiplier used to expand regions/boxes during matching -- confirm.
    expansion_factor: float = 1.1
74
+
75
+
76
@dataclass
class PointConfig(BaseConfig):
    """Configuration for Point Detection"""

    # Max distance for matching points -- presumably pixels; TODO confirm.
    threshold_distance: float = 10.0
81
+
82
+
83
@dataclass
class JunctionConfig(BaseConfig):
    """Configuration for Junction Detection"""

    # Size of the local window examined around a candidate junction --
    # presumably pixels; TODO confirm.
    window_size: int = 21
    # Search radius around a junction point -- presumably pixels; confirm.
    radius: int = 5
    # Accepted angle range between intersecting lines, in degrees
    # (lb = lower bound, ub = upper bound).
    angle_threshold_lb: float = 15.0
    angle_threshold_ub: float = 75.0
91
+
92
+ # @dataclass
93
+ # class JunctionConfig:
94
+ # radius: int = 5
95
+ # angle_threshold: float = 25.0
96
+ # colinear_threshold: float = 5.0
97
+ # connection_threshold: float = 5.0
data_aggregation_ai.py ADDED
@@ -0,0 +1,411 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ import json
3
+ import logging
4
+ from datetime import datetime
5
+ from typing import List, Dict, Optional, Tuple
6
+ from storage import StorageFactory
7
+ import uuid
8
+ import traceback
9
+ import os
10
+ import cv2
11
+ import numpy as np
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
356
+
357
+ if __name__ == "__main__":
358
+ import os
359
+ from pprint import pprint
360
+
361
+ # Initialize the aggregator
362
+ aggregator = DataAggregator()
363
+
364
+ # Test paths using actual files in results folder
365
+ results_dir = "results"
366
+ base_name = "002_page_1"
367
+
368
+ # Input paths
369
+ symbols_path = os.path.join(results_dir, f"{base_name}_detected_symbols.json")
370
+ texts_path = os.path.join(results_dir, f"{base_name}_detected_texts.json")
371
+ lines_path = os.path.join(results_dir, f"{base_name}_detected_lines.json")
372
+
373
+ # Verify files exist
374
+ print(f"\nChecking input files:")
375
+ print(f"Symbols file exists: {os.path.exists(symbols_path)}")
376
+ print(f"Texts file exists: {os.path.exists(texts_path)}")
377
+ print(f"Lines file exists: {os.path.exists(lines_path)}")
378
+
379
+ try:
380
+ # Process the data
381
+ print("\nProcessing data...")
382
+ result = aggregator.process_data(
383
+ image_path=os.path.join(results_dir, f"{base_name}.png"),
384
+ output_dir=results_dir,
385
+ symbols_path=symbols_path,
386
+ texts_path=texts_path,
387
+ lines_path=lines_path
388
+ )
389
+
390
+ # Verify output files
391
+ aggregated_json = os.path.join(results_dir, f"{base_name}_aggregated.json")
392
+ aggregated_image = os.path.join(results_dir, f"{base_name}_aggregated.png")
393
+
394
+ print("\nChecking output files:")
395
+ print(f"Aggregated JSON exists: {os.path.exists(aggregated_json)}")
396
+ print(f"Aggregated image exists: {os.path.exists(aggregated_image)}")
397
+
398
+ # Load and print statistics from aggregated result
399
+ if os.path.exists(aggregated_json):
400
+ with open(aggregated_json, 'r') as f:
401
+ data = json.load(f)
402
+ print("\nAggregation Results:")
403
+ print(f"Number of Symbols: {len(data.get('symbols', []))}")
404
+ print(f"Number of Texts: {len(data.get('texts', []))}")
405
+ print(f"Number of Lines: {len(data.get('lines', []))}")
406
+ print(f"Number of Nodes: {len(data.get('nodes', []))}")
407
+ print(f"Number of Edges: {len(data.get('edges', []))}")
408
+
409
+ except Exception as e:
410
+ print(f"\nError during testing: {str(e)}")
411
+ traceback.print_exc()
detection_schema.py ADDED
@@ -0,0 +1,481 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from typing import List, Optional, Tuple, Dict
3
+ import uuid
4
+ from enum import Enum
5
+ import json
6
+ import numpy as np
7
+
8
+ # ======================== Point ======================== #
9
+ class ConnectionType(Enum):
10
+ SOLID = "solid"
11
+ DASHED = "dashed"
12
+ PHANTOM = "phantom"
13
+
14
+ @dataclass
15
+ class Coordinates:
16
+ x: int
17
+ y: int
18
+
19
+ @dataclass
20
+ class BBox:
21
+ xmin: int
22
+ ymin: int
23
+ xmax: int
24
+ ymax: int
25
+
26
+ def width(self) -> int:
27
+ return self.xmax - self.xmin
28
+
29
+ def height(self) -> int:
30
+ return self.ymax - self.ymin
31
+
32
+ class JunctionType(str, Enum):
33
+ T = "T"
34
+ L = "L"
35
+ END = "END"
36
+
37
+ @dataclass
38
+ class Point:
39
+ coords: Coordinates
40
+ bbox: BBox
41
+ type: JunctionType
42
+ confidence: float = 1.0
43
+ id: str = field(default_factory=lambda: str(uuid.uuid4()))
44
+
45
+
46
+ # # ======================== Symbol ======================== #
47
+ # class SymbolType(Enum):
48
+ # VALVE = "valve"
49
+ # PUMP = "pump"
50
+ # SENSOR = "sensor"
51
+ # # Add others as needed
52
+ #
53
+ class ValveSubtype(Enum):
54
+ GATE = "gate"
55
+ GLOBE = "globe"
56
+ BUTTERFLY = "butterfly"
57
+ #
58
+ # @dataclass
59
+ # class Symbol:
60
+ # symbol_type: SymbolType
61
+ # bbox: BBox
62
+ # center: Coordinates
63
+ # connections: List[Point] = field(default_factory=list)
64
+ # subtype: Optional[ValveSubtype] = None
65
+ # id: str = field(default_factory=lambda: str(uuid.uuid4()))
66
+ # confidence: float = 0.95
67
+ # model_metadata: dict = field(default_factory=dict)
68
+
69
+
70
+ # ======================== Symbol ======================== #
71
+ class SymbolType(Enum):
72
+ VALVE = "valve"
73
+ PUMP = "pump"
74
+ SENSOR = "sensor"
75
+ OTHER = "other" # Added to handle unknown categories
76
+
77
+ @dataclass
78
+ class Symbol:
79
+ center: Coordinates
80
+ symbol_type: SymbolType = field(default=SymbolType.OTHER)
81
+ id: str = field(default_factory=lambda: str(uuid.uuid4()))
82
+ class_id: int = -1
83
+ original_label: str = ""
84
+ category: str = "" # e.g., "inst"
85
+ type: str = "" # e.g., "ind"
86
+ label: str = "" # e.g., "Solenoid_actuator"
87
+ bbox: BBox = None
88
+ confidence: float = 0.95
89
+ model_source: str = "" # e.g., "model2"
90
+ connections: List[Point] = field(default_factory=list)
91
+ subtype: Optional[ValveSubtype] = None
92
+ model_metadata: dict = field(default_factory=dict)
93
+
94
+ def __post_init__(self):
95
+ """
96
+ Handle any additional post-processing after initialization.
97
+ """
98
+ # Ensure bbox is a BBox object
99
+ if isinstance(self.bbox, list) and len(self.bbox) == 4:
100
+ self.bbox = BBox(*self.bbox)
101
+
102
+
103
+ # ======================== Line ======================== #
104
+ @dataclass
105
+ class LineStyle:
106
+ connection_type: ConnectionType
107
+ stroke_width: int = 2
108
+ color: str = "#000000" # CSS-style colors
109
+
110
+ @dataclass
111
+ class Line:
112
+ start: Point
113
+ end: Point
114
+ bbox: BBox
115
+ id: str = field(default_factory=lambda: str(uuid.uuid4()))
116
+ style: LineStyle = field(default_factory=lambda: LineStyle(ConnectionType.SOLID))
117
+ confidence: float = 0.90
118
+ topological_links: List[str] = field(default_factory=list) # Linked symbols/junctions
119
+
120
+
121
+ # ======================== Junction ======================== #
122
+ class JunctionType(str, Enum):
123
+ T = "T"
124
+ L = "L"
125
+ END = "END"
126
+
127
+ @dataclass
128
+ class JunctionProperties:
129
+ flow_direction: Optional[str] = None # "in", "out"
130
+ pressure: Optional[float] = None # kPa
131
+
132
+ @dataclass
133
+ class Junction:
134
+ center: Coordinates
135
+ junction_type: JunctionType
136
+ id: str = field(default_factory=lambda: str(uuid.uuid4()))
137
+ properties: JunctionProperties = field(default_factory=JunctionProperties)
138
+ connected_lines: List[str] = field(default_factory=list) # Line IDs
139
+
140
+
141
+ # # ======================== Tag ======================== #
142
+ # @dataclass
143
+ # class Tag:
144
+ # text: str
145
+ # bbox: BBox
146
+ # associated_element: str # ID of linked symbol/line
147
+ # id: str = field(default_factory=lambda: str(uuid.uuid4()))
148
+ # font_size: int = 12
149
+ # rotation: float = 0.0 # Degrees
150
+
151
+ @dataclass
152
+ class Tag:
153
+ text: str
154
+ bbox: BBox
155
+ confidence: float = 1.0
156
+ source: str = "" # e.g., "easyocr"
157
+ text_type: str = "Unknown" # e.g., "Unknown", could be something else later
158
+ id: str = field(default_factory=lambda: str(uuid.uuid4()))
159
+ associated_element: Optional[str] = None # ID of linked symbol/line (can be None)
160
+ font_size: int = 12
161
+ rotation: float = 0.0 # Degrees
162
+
163
+ def __post_init__(self):
164
+ """
165
+ Ensure bbox is properly converted.
166
+ """
167
+ if isinstance(self.bbox, list) and len(self.bbox) == 4:
168
+ self.bbox = BBox(*self.bbox)
169
+
170
+ # ----------------------------
171
+ # DETECTION CONTEXT
172
+ # ----------------------------
173
+
174
+ @dataclass
175
+ class DetectionContext:
176
+ """
177
+ In-memory container for all detected elements (lines, points, symbols, junctions, tags).
178
+ Each element is stored in a dict keyed by 'id' for quick lookup and update.
179
+ """
180
+ lines: Dict[str, Line] = field(default_factory=dict)
181
+ points: Dict[str, Point] = field(default_factory=dict)
182
+ symbols: Dict[str, Symbol] = field(default_factory=dict)
183
+ junctions: Dict[str, Junction] = field(default_factory=dict)
184
+ tags: Dict[str, Tag] = field(default_factory=dict)
185
+
186
+ # -------------------------
187
+ # 1) ADD / GET / REMOVE
188
+ # -------------------------
189
+ def add_line(self, line: Line) -> None:
190
+ self.lines[line.id] = line
191
+
192
+ def get_line(self, line_id: str) -> Optional[Line]:
193
+ return self.lines.get(line_id)
194
+
195
+ def remove_line(self, line_id: str) -> None:
196
+ self.lines.pop(line_id, None)
197
+
198
+ def add_point(self, point: Point) -> None:
199
+ self.points[point.id] = point
200
+
201
+ def get_point(self, point_id: str) -> Optional[Point]:
202
+ return self.points.get(point_id)
203
+
204
+ def remove_point(self, point_id: str) -> None:
205
+ self.points.pop(point_id, None)
206
+
207
+ def add_symbol(self, symbol: Symbol) -> None:
208
+ self.symbols[symbol.id] = symbol
209
+
210
+ def get_symbol(self, symbol_id: str) -> Optional[Symbol]:
211
+ return self.symbols.get(symbol_id)
212
+
213
+ def remove_symbol(self, symbol_id: str) -> None:
214
+ self.symbols.pop(symbol_id, None)
215
+
216
+ def add_junction(self, junction: Junction) -> None:
217
+ self.junctions[junction.id] = junction
218
+
219
+ def get_junction(self, junction_id: str) -> Optional[Junction]:
220
+ return self.junctions.get(junction_id)
221
+
222
+ def remove_junction(self, junction_id: str) -> None:
223
+ self.junctions.pop(junction_id, None)
224
+
225
+ def add_tag(self, tag: Tag) -> None:
226
+ self.tags[tag.id] = tag
227
+
228
+ def get_tag(self, tag_id: str) -> Optional[Tag]:
229
+ return self.tags.get(tag_id)
230
+
231
+ def remove_tag(self, tag_id: str) -> None:
232
+ self.tags.pop(tag_id, None)
233
+
234
+ # -------------------------
235
+ # 2) SERIALIZATION: to_dict / from_dict
236
+ # -------------------------
237
+ def to_dict(self) -> dict:
238
+ """Convert all stored objects into a JSON-serializable dictionary."""
239
+ return {
240
+ "lines": [self._line_to_dict(line) for line in self.lines.values()],
241
+ "points": [self._point_to_dict(pt) for pt in self.points.values()],
242
+ "symbols": [self._symbol_to_dict(sym) for sym in self.symbols.values()],
243
+ "junctions": [self._junction_to_dict(jn) for jn in self.junctions.values()],
244
+ "tags": [self._tag_to_dict(tg) for tg in self.tags.values()]
245
+ }
246
+
247
+ @classmethod
248
+ def from_dict(cls, data: dict) -> "DetectionContext":
249
+ """
250
+ Create a new DetectionContext from a dictionary structure (e.g. loaded from JSON).
251
+ """
252
+ context = cls()
253
+
254
+ # Points
255
+ for pt_dict in data.get("points", []):
256
+ pt_obj = cls._point_from_dict(pt_dict)
257
+ context.add_point(pt_obj)
258
+
259
+ # Lines
260
+ for ln_dict in data.get("lines", []):
261
+ ln_obj = cls._line_from_dict(ln_dict)
262
+ context.add_line(ln_obj)
263
+
264
+ # Symbols
265
+ for sym_dict in data.get("symbols", []):
266
+ sym_obj = cls._symbol_from_dict(sym_dict)
267
+ context.add_symbol(sym_obj)
268
+
269
+ # Junctions
270
+ for jn_dict in data.get("junctions", []):
271
+ jn_obj = cls._junction_from_dict(jn_dict)
272
+ context.add_junction(jn_obj)
273
+
274
+ # Tags
275
+ for tg_dict in data.get("tags", []):
276
+ tg_obj = cls._tag_from_dict(tg_dict)
277
+ context.add_tag(tg_obj)
278
+
279
+ return context
280
+
281
+ # -------------------------
282
+ # 3) HELPER METHODS FOR (DE)SERIALIZATION
283
+ # -------------------------
284
+ @staticmethod
285
+ def _bbox_to_dict(bbox: BBox) -> dict:
286
+ return {
287
+ "xmin": bbox.xmin,
288
+ "ymin": bbox.ymin,
289
+ "xmax": bbox.xmax,
290
+ "ymax": bbox.ymax
291
+ }
292
+
293
+ @staticmethod
294
+ def _bbox_from_dict(d: dict) -> BBox:
295
+ return BBox(
296
+ xmin=d["xmin"],
297
+ ymin=d["ymin"],
298
+ xmax=d["xmax"],
299
+ ymax=d["ymax"]
300
+ )
301
+
302
+ @staticmethod
303
+ def _coords_to_dict(coords: Coordinates) -> dict:
304
+ return {
305
+ "x": coords.x,
306
+ "y": coords.y
307
+ }
308
+
309
+ @staticmethod
310
+ def _coords_from_dict(d: dict) -> Coordinates:
311
+ return Coordinates(x=d["x"], y=d["y"])
312
+
313
+ @staticmethod
314
+ def _line_style_to_dict(style: LineStyle) -> dict:
315
+ return {
316
+ "connection_type": style.connection_type.value,
317
+ "stroke_width": style.stroke_width,
318
+ "color": style.color
319
+ }
320
+
321
+ @staticmethod
322
+ def _line_style_from_dict(d: dict) -> LineStyle:
323
+ return LineStyle(
324
+ connection_type=ConnectionType(d["connection_type"]),
325
+ stroke_width=d.get("stroke_width", 2),
326
+ color=d.get("color", "#000000")
327
+ )
328
+
329
+ @staticmethod
330
+ def _point_to_dict(pt: Point) -> dict:
331
+ return {
332
+ "id": pt.id,
333
+ "coords": DetectionContext._coords_to_dict(pt.coords),
334
+ "bbox": DetectionContext._bbox_to_dict(pt.bbox),
335
+ "type": pt.type.value,
336
+ "confidence": pt.confidence
337
+ }
338
+
339
+ @staticmethod
340
+ def _point_from_dict(d: dict) -> Point:
341
+ return Point(
342
+ id=d["id"],
343
+ coords=DetectionContext._coords_from_dict(d["coords"]),
344
+ bbox=DetectionContext._bbox_from_dict(d["bbox"]),
345
+ type=JunctionType(d["type"]),
346
+ confidence=d.get("confidence", 1.0)
347
+ )
348
+
349
+ @staticmethod
350
+ def _line_to_dict(ln: Line) -> dict:
351
+ return {
352
+ "id": ln.id,
353
+ "start": DetectionContext._point_to_dict(ln.start),
354
+ "end": DetectionContext._point_to_dict(ln.end),
355
+ "bbox": DetectionContext._bbox_to_dict(ln.bbox),
356
+ "style": DetectionContext._line_style_to_dict(ln.style),
357
+ "confidence": ln.confidence,
358
+ "topological_links": ln.topological_links
359
+ }
360
+
361
+ @staticmethod
362
+ def _line_from_dict(d: dict) -> Line:
363
+ return Line(
364
+ id=d["id"],
365
+ start=DetectionContext._point_from_dict(d["start"]),
366
+ end=DetectionContext._point_from_dict(d["end"]),
367
+ bbox=DetectionContext._bbox_from_dict(d["bbox"]),
368
+ style=DetectionContext._line_style_from_dict(d["style"]),
369
+ confidence=d.get("confidence", 0.90),
370
+ topological_links=d.get("topological_links", [])
371
+ )
372
+
373
+ @staticmethod
374
+ def _symbol_to_dict(sym: Symbol) -> dict:
375
+ return {
376
+ "id": sym.id,
377
+ "symbol_type": sym.symbol_type.value,
378
+ "bbox": DetectionContext._bbox_to_dict(sym.bbox),
379
+ "center": DetectionContext._coords_to_dict(sym.center),
380
+ "connections": [DetectionContext._point_to_dict(p) for p in sym.connections],
381
+ "subtype": sym.subtype.value if sym.subtype else None,
382
+ "confidence": sym.confidence,
383
+ "model_metadata": sym.model_metadata
384
+ }
385
+
386
+ @staticmethod
387
+ def _symbol_from_dict(d: dict) -> Symbol:
388
+ return Symbol(
389
+ id=d["id"],
390
+ symbol_type=SymbolType(d["symbol_type"]),
391
+ bbox=DetectionContext._bbox_from_dict(d["bbox"]),
392
+ center=DetectionContext._coords_from_dict(d["center"]),
393
+ connections=[DetectionContext._point_from_dict(p) for p in d.get("connections", [])],
394
+ subtype=ValveSubtype(d["subtype"]) if d.get("subtype") else None,
395
+ confidence=d.get("confidence", 0.95),
396
+ model_metadata=d.get("model_metadata", {})
397
+ )
398
+
399
+ @staticmethod
400
+ def _junction_props_to_dict(props: JunctionProperties) -> dict:
401
+ return {
402
+ "flow_direction": props.flow_direction,
403
+ "pressure": props.pressure
404
+ }
405
+
406
+ @staticmethod
407
+ def _junction_props_from_dict(d: dict) -> JunctionProperties:
408
+ return JunctionProperties(
409
+ flow_direction=d.get("flow_direction"),
410
+ pressure=d.get("pressure")
411
+ )
412
+
413
+ @staticmethod
414
+ def _junction_to_dict(jn: Junction) -> dict:
415
+ return {
416
+ "id": jn.id,
417
+ "center": DetectionContext._coords_to_dict(jn.center),
418
+ "junction_type": jn.junction_type.value,
419
+ "properties": DetectionContext._junction_props_to_dict(jn.properties),
420
+ "connected_lines": jn.connected_lines
421
+ }
422
+
423
+ @staticmethod
424
+ def _junction_from_dict(d: dict) -> Junction:
425
+ return Junction(
426
+ id=d["id"],
427
+ center=DetectionContext._coords_from_dict(d["center"]),
428
+ junction_type=JunctionType(d["junction_type"]),
429
+ properties=DetectionContext._junction_props_from_dict(d["properties"]),
430
+ connected_lines=d.get("connected_lines", [])
431
+ )
432
+
433
+ @staticmethod
434
+ def _tag_to_dict(tg: Tag) -> dict:
435
+ return {
436
+ "id": tg.id,
437
+ "text": tg.text,
438
+ "bbox": DetectionContext._bbox_to_dict(tg.bbox),
439
+ "associated_element": tg.associated_element,
440
+ "font_size": tg.font_size,
441
+ "rotation": tg.rotation
442
+ }
443
+
444
+ @staticmethod
445
+ def _tag_from_dict(d: dict) -> Tag:
446
+ return Tag(
447
+ id=d["id"],
448
+ text=d["text"],
449
+ bbox=DetectionContext._bbox_from_dict(d["bbox"]),
450
+ associated_element=d["associated_element"],
451
+ font_size=d.get("font_size", 12),
452
+ rotation=d.get("rotation", 0.0)
453
+ )
454
+
455
+ # -------------------------
456
+ # 4) OPTIONAL UTILS
457
+ # -------------------------
458
+ def to_json(self, indent: int = 2) -> str:
459
+ """Convert context to JSON, ensuring dataclasses and numpy types are handled correctly."""
460
+ return json.dumps(self.to_dict(), default=self._json_serializer, indent=indent)
461
+
462
+ @staticmethod
463
+ def _json_serializer(obj):
464
+ """Handles numpy types and unknown objects for JSON serialization."""
465
+ if isinstance(obj, np.integer):
466
+ return int(obj)
467
+ if isinstance(obj, np.floating):
468
+ return float(obj)
469
+ if isinstance(obj, np.ndarray):
470
+ return obj.tolist() # Convert arrays to lists
471
+ if isinstance(obj, Enum):
472
+ return obj.value # Convert Enums to string values
473
+ if hasattr(obj, "__dict__"):
474
+ return obj.__dict__ # Convert dataclass objects to dict
475
+ raise TypeError(f"Object of type {type(obj)} is not JSON serializable")
476
+
477
+ @classmethod
478
+ def from_json(cls, json_str: str) -> "DetectionContext":
479
+ """Load DetectionContext from a JSON string."""
480
+ data = json.loads(json_str)
481
+ return cls.from_dict(data)
detection_utils.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from typing import List, Tuple
3
+ import math
4
+
5
def robust_merge_lines(lines: List[Tuple[float, float, float, float]],
                       angle_thresh: float = 5.0,
                       dist_thresh: float = 5.0) -> List[Tuple[float, float, float, float]]:
    """
    Merge similar line segments using angle and distance thresholds.

    Two segments are grouped when their orientations differ by less than
    ``angle_thresh`` (modulo 180 degrees) and at least one pair of endpoints
    lies within ``dist_thresh``. Endpoint proximity is checked in BOTH
    endpoint orders (fix: the original only compared start<->start and
    end<->end, so segments stored with reversed endpoints never merged).
    Each group is replaced by the segment spanning the extreme projections
    of all its endpoints along the group's reference direction.

    Args:
        lines: List of line segments [(x1, y1, x2, y2), ...]
        angle_thresh: Maximum angle difference in degrees.
        dist_thresh: Maximum endpoint distance in pixels.

    Returns:
        List of merged line segments as (x1, y1, x2, y2) tuples.
    """
    if not lines:
        return []

    lines = np.array(lines)

    # Undirected orientation of every segment, folded into [0, 180).
    angles = np.degrees(np.arctan2(lines[:, 3] - lines[:, 1],
                                   lines[:, 2] - lines[:, 0])) % 180

    merged = []
    used = set()

    for i, line1 in enumerate(lines):
        if i in used:
            continue

        # Collect all not-yet-used segments similar to segment i.
        # Segment i itself always qualifies (zero angle/distance), so
        # `similar` is never empty when we reach the merge step.
        similar = []
        for j, line2 in enumerate(lines):
            if j in used:
                continue

            # Angle difference on the circle of undirected orientations.
            angle_diff = abs(angles[i] - angles[j])
            angle_diff = min(angle_diff, 180 - angle_diff)
            if angle_diff > angle_thresh:
                continue

            # Endpoint proximity: try both endpoint pairings so that
            # reversed segments (end point stored first) still match.
            d_ss = np.linalg.norm(line1[:2] - line2[:2])
            d_ee = np.linalg.norm(line1[2:] - line2[2:])
            d_se = np.linalg.norm(line1[:2] - line2[2:])
            d_es = np.linalg.norm(line1[2:] - line2[:2])
            if min(d_ss, d_ee, d_se, d_es) > dist_thresh:
                continue

            similar.append(j)
            used.add(j)

        if similar:
            # Span the whole group: project every endpoint onto the group's
            # reference direction and keep the two extreme points.
            points = lines[similar].reshape(-1, 2)
            direction = np.array([np.cos(np.radians(angles[i])),
                                  np.sin(np.radians(angles[i]))])
            proj = points @ direction
            merged_line = np.concatenate([points[np.argmin(proj)],
                                          points[np.argmax(proj)]])
            merged.append(tuple(merged_line))

    return merged
78
+
79
def compute_line_angle(x1: float, y1: float, x2: float, y2: float) -> float:
    """Return the orientation of segment (x1, y1)->(x2, y2) in degrees, in [0, 180)."""
    raw_degrees = math.degrees(math.atan2(y2 - y1, x2 - x1))
    return raw_degrees % 180
detectors.py ADDED
@@ -0,0 +1,733 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import math
3
+ import torch
4
+ import cv2
5
+ import numpy as np
6
+ from typing import List, Optional, Tuple, Dict
7
+ from dataclasses import replace
8
+ from math import sqrt
9
+ import json
10
+ import uuid
11
+ from pathlib import Path
12
+
13
+ # Base classes and utilities
14
+ from base import BaseDetector
15
+ from detection_schema import DetectionContext
16
+ from utils import DebugHandler
17
+ from config import SymbolConfig, TagConfig, LineConfig, PointConfig, JunctionConfig
18
+
19
+ # DeepLSD model for line detection
20
+ from deeplsd.models.deeplsd_inference import DeepLSD
21
+ from ultralytics import YOLO
22
+
23
+ # Detection schema: dataclasses for different objects
24
+ from detection_schema import (
25
+ BBox,
26
+ Coordinates,
27
+ Point,
28
+ Line,
29
+ Symbol,
30
+ Tag,
31
+ SymbolType,
32
+ LineStyle,
33
+ ConnectionType,
34
+ JunctionType,
35
+ Junction
36
+ )
37
+
38
+ # Skeletonization and label processing for junction detection
39
+ from skimage.morphology import skeletonize
40
+ from skimage.measure import label
41
+
42
+
43
+ import os
44
+ import cv2
45
+ import torch
46
+ import numpy as np
47
+ from dataclasses import replace
48
+ from typing import List, Optional
49
+ from detection_utils import robust_merge_lines
50
+
51
+
52
class LineDetector(BaseDetector):
    """
    DeepLSD-based line detection with patch-based tiling and global merging.

    Pipeline: skeletonize the input, optionally mask regions out, run DeepLSD
    on overlapping 512x512 patches, remap patch-local segments to global
    coordinates, then merge near-duplicate segments into Line objects.
    """

    def __init__(self,
                 config: LineConfig,
                 model_path: str,
                 model_config: dict,
                 device: torch.device,
                 debug_handler: DebugHandler = None):
        """
        Args:
            config: Line-detection configuration.
            model_path: Path to the DeepLSD checkpoint file.
            model_config: Dict of DeepLSD model hyperparameters.
            device: Requested torch device.
                NOTE(review): this argument is ignored — the device is
                re-selected below (mps > cuda > cpu); confirm that is intended.
            debug_handler: Optional handler for saving debug artifacts.
        """
        super().__init__(config, debug_handler)

        # Fix device selection for Apple Silicon: prefer MPS, then CUDA, then CPU.
        if torch.backends.mps.is_available():
            self.device = torch.device("mps")
        elif torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

        self.model_path = model_path
        self.model_config = model_config
        self.model = self._load_model(model_path)

        # Patch parameters: tile size and overlap between neighboring tiles.
        self.patch_size = 512
        self.overlap = 10

        # Merging thresholds used by robust_merge_lines.
        self.angle_thresh = 5.0  # degrees
        self.dist_thresh = 5.0  # pixels

    def _preprocess(self, image: np.ndarray) -> np.ndarray:
        """Thicken strokes, skeletonize, and re-thicken to a clean 1px-ish line image.

        Assumes a light-background drawing where strokes are dark
        (the image is inverted before skeletonization) — TODO confirm.
        """
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (2, 2))
        dilated = cv2.dilate(image, kernel, iterations=2)

        # Invert so strokes become foreground for skeletonize (expects 0/1 input).
        skeleton = cv2.bitwise_not(dilated)
        skeleton = skeletonize(skeleton // 255)
        skeleton = (skeleton * 255).astype(np.uint8)
        # NOTE(review): a 1x1 rectangular kernel makes this dilation a no-op
        # regardless of iterations — confirm a larger kernel was intended.
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, 1))
        clean_image = cv2.dilate(skeleton, kernel, iterations=5)

        self.debug_handler.save_artifact(name="skeleton", data=clean_image, extension="png")

        return clean_image

    def _postprocess(self, image: np.ndarray) -> np.ndarray:
        # NOTE(review): returns None while sibling detectors return the image —
        # confirm callers never use this detector's postprocess output.
        return None
    # -------------------------------------
    # 1) Load Model
    # -------------------------------------
    def _load_model(self, model_path: str) -> DeepLSD:
        """Load the DeepLSD checkpoint onto self.device in eval mode.

        Raises:
            FileNotFoundError: If the checkpoint file does not exist.
        """
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file not found: {model_path}")
        ckpt = torch.load(model_path, map_location=self.device)
        model = DeepLSD(self.model_config)
        model.load_state_dict(ckpt["model"])
        return model.to(self.device).eval()

    # -------------------------------------
    # 2) Main Detection Pipeline
    # -------------------------------------
    def detect(self,
               image: np.ndarray,
               context: DetectionContext,
               mask_coords: Optional[List[BBox]] = None,
               *args,
               **kwargs) -> None:
        """
        Detect lines in `image` and add them to `context`.

        Steps:
          - Optional mask + threshold
          - Tile into overlapping patches
          - For each patch => run DeepLSD => re-map lines to global coords
          - Merge lines robustly
          - Build final Line objects => add to context
        """
        mask_coords = mask_coords or []

        skeleton = self._preprocess(image)
        # (A) Optional mask + threshold if you want a binary
        # If your model expects grayscale or binary, do it here:
        processed_img = self._apply_mask_and_threshold(skeleton, mask_coords)
        # (B) Patch-based inference => collect raw lines in global coords
        all_lines = self._detect_in_patches(processed_img)

        # (C) Merge the lines in the global coordinate system
        merged_line_segments = robust_merge_lines(
            all_lines,
            angle_thresh=self.angle_thresh,
            dist_thresh=self.dist_thresh
        )

        # (D) Convert merged segments => final Line objects, add to context
        for (x1, y1, x2, y2) in merged_line_segments:
            line_obj = self._create_line_object(x1, y1, x2, y2)
            context.add_line(line_obj)

    # -------------------------------------
    # 3) Optional Mask + Threshold
    # -------------------------------------
    def _apply_mask_and_threshold(self, image: np.ndarray, mask_coords: List[BBox]) -> np.ndarray:
        """White out rectangular areas, then threshold to binary (if needed)."""
        masked = image.copy()
        for bbox in mask_coords:
            x1, y1 = int(bbox.xmin), int(bbox.ymin)
            x2, y2 = int(bbox.xmax), int(bbox.ymax)
            # Filled white rectangle erases content so no lines are detected there.
            cv2.rectangle(masked, (x1, y1), (x2, y2), (255, 255, 255), -1)

        # If image has 3 channels, convert to grayscale
        if len(masked.shape) == 3:
            masked_gray = cv2.cvtColor(masked, cv2.COLOR_BGR2GRAY)
        else:
            masked_gray = masked

        # Binary threshold (adjust threshold as needed)
        # If your model expects a plain grayscale, skip threshold
        binary_img = cv2.threshold(masked_gray, 127, 255, cv2.THRESH_BINARY)[1]
        return binary_img

    # -------------------------------------
    # 4) Patch-Based Inference
    # -------------------------------------
    def _detect_in_patches(self, processed_img: np.ndarray) -> List[tuple]:
        """
        Break the image into overlapping patches, run DeepLSD,
        map local lines => global coords, and return the global line list.
        """
        patch_size = self.patch_size
        overlap = self.overlap

        height, width = processed_img.shape[:2]
        step = patch_size - overlap

        all_lines = []

        for y in range(0, height, step):
            # Shift the last row of patches up so every patch is full-size
            # (when the image is large enough).
            patch_ymax = min(y + patch_size, height)
            patch_ymin = patch_ymax - patch_size if (patch_ymax - y) < patch_size else y
            if patch_ymin < 0: patch_ymin = 0

            for x in range(0, width, step):
                patch_xmax = min(x + patch_size, width)
                patch_xmin = patch_xmax - patch_size if (patch_xmax - x) < patch_size else x
                if patch_xmin < 0: patch_xmin = 0

                patch = processed_img[patch_ymin:patch_ymax, patch_xmin:patch_xmax]

                # Run model
                local_lines = self._run_model_inference(patch)

                # Convert local lines => global coords
                for ln in local_lines:
                    (x1_local, y1_local), (x2_local, y2_local) = ln

                    # offset by patch_xmin, patch_ymin
                    gx1 = x1_local + patch_xmin
                    gy1 = y1_local + patch_ymin
                    gx2 = x2_local + patch_xmin
                    gy2 = y2_local + patch_ymin

                    # Optional: clamp or filter lines partially out-of-bounds
                    if 0 <= gx1 < width and 0 <= gx2 < width and 0 <= gy1 < height and 0 <= gy2 < height:
                        all_lines.append((gx1, gy1, gx2, gy2))

        return all_lines

    # -------------------------------------
    # 5) Model Inference (Single Patch)
    # -------------------------------------
    def _run_model_inference(self, patch_img: np.ndarray) -> np.ndarray:
        """
        Run DeepLSD on a single patch (already masked/thresholded).
        patch_img shape: [patchH, patchW].
        Returns lines shape: [N, 2, 2].
        NOTE(review): annotated as np.ndarray but the value comes straight
        from the model output dict — confirm it is not a torch tensor.
        """
        # Convert patch to float32, add batch+channel dims, and scale to [0, 1].
        inp = torch.tensor(patch_img, dtype=torch.float32, device=self.device)[None, None] / 255.0
        with torch.no_grad():
            output = self.model({"image": inp})
            lines = output["lines"][0]  # shape (N, 2, 2)
        return lines

    # -------------------------------------
    # 6) Convert Merged Segments => Line Objects
    # -------------------------------------
    def _create_line_object(self, x1: float, y1: float, x2: float, y2: float) -> Line:
        """
        Create a minimal `Line` object from the final merged coordinates.

        Each endpoint gets a small (margin=2px) bounding box around it and is
        initially typed as an END junction; JunctionDetector may reclassify.
        """
        margin = 2
        # Start point
        start_pt = Point(
            coords=Coordinates(int(x1), int(y1)),
            bbox=BBox(
                xmin=int(x1 - margin),
                ymin=int(y1 - margin),
                xmax=int(x1 + margin),
                ymax=int(y1 + margin)
            ),
            type=JunctionType.END,
            confidence=1.0
        )
        # End point
        end_pt = Point(
            coords=Coordinates(int(x2), int(y2)),
            bbox=BBox(
                xmin=int(x2 - margin),
                ymin=int(y2 - margin),
                xmax=int(x2 + margin),
                ymax=int(y2 + margin)
            ),
            type=JunctionType.END,
            confidence=1.0
        )

        # Overall bounding box
        x_min = int(min(x1, x2))
        x_max = int(max(x1, x2))
        y_min = int(min(y1, y2))
        y_max = int(max(y1, y2))

        line_obj = Line(
            start=start_pt,
            end=end_pt,
            bbox=BBox(xmin=x_min, ymin=y_min, xmax=x_max, ymax=y_max),
            style=LineStyle(
                connection_type=ConnectionType.SOLID,
                stroke_width=2,
                color="#000000"
            ),
            confidence=0.9,
            topological_links=[]
        )
        return line_obj
287
+
288
class PointDetector(BaseDetector):
    """
    Unifies line endpoints that lie close together.

    Reads every line from the context, clusters endpoints falling within
    ``threshold_distance`` of each other, and rewrites the lines so that all
    endpoints of a cluster reference one shared Point object.
    """

    def __init__(self,
                 config: PointConfig,
                 debug_handler: DebugHandler = None):
        # No model is involved; this detector is pure geometry.
        super().__init__(config, debug_handler)
        self.threshold_distance = config.threshold_distance

    def _load_model(self, model_path: str):
        """No model needed for simple point unification."""
        return None

    def detect(self, image: np.ndarray, context: DetectionContext, *args, **kwargs) -> None:
        """
        Unify nearby line endpoints in `context`.

        1) Gather all line endpoints.
        2) Cluster them within ``threshold_distance``.
        3) Point every line at its cluster's canonical (first) endpoint and
           rebuild ``context.points`` from the surviving unique endpoints.
        """
        # Gather every endpoint of every line.
        endpoints = []
        for ln in context.lines.values():
            endpoints.extend((ln.start, ln.end))

        # Map each non-representative point id to the cluster representative
        # (the first point encountered in that cluster).
        remap = {}
        for cluster in self._cluster_points(endpoints, self.threshold_distance):
            representative = cluster[0]
            for member in cluster[1:]:
                remap[member.id] = representative

        # Rewire lines to the canonical points.
        for ln in context.lines.values():
            ln.start = remap.get(ln.start.id, ln.start)
            ln.end = remap.get(ln.end.id, ln.end)

        # Rebuild context.points from the unified endpoints so no stale
        # duplicates remain.
        unified = {}
        for ln in context.lines.values():
            unified[ln.start.id] = ln.start
            unified[ln.end.id] = ln.end
        context.points = unified

    def _preprocess(self, image: np.ndarray) -> np.ndarray:
        """No specific image preprocessing needed."""
        return image

    def _postprocess(self, image: np.ndarray) -> np.ndarray:
        """No specific image postprocessing needed."""
        return image

    # ----------------------
    # HELPER: clustering
    # ----------------------
    def _cluster_points(self, points: List[Point], threshold: float) -> List[List[Point]]:
        """
        Greedy single-pass clustering.

        Each point joins the first existing cluster whose FIRST member lies
        within ``threshold``; otherwise it starts a new cluster. Order of the
        input therefore influences the result (intentionally naive).
        """
        clusters = []
        for candidate in points:
            target = None
            for cluster in clusters:
                # The cluster's first point acts as its reference.
                if self._distance(candidate, cluster[0]) < threshold:
                    target = cluster
                    break
            if target is None:
                clusters.append([candidate])
            else:
                target.append(candidate)
        return clusters

    def _distance(self, p1: Point, p2: Point) -> float:
        """Euclidean distance between the coordinates of two points."""
        delta_x = p1.coords.x - p2.coords.x
        delta_y = p1.coords.y - p2.coords.y
        return sqrt(delta_x * delta_x + delta_y * delta_y)
394
+
395
+
396
class JunctionDetector(BaseDetector):
    """
    Classifies points as 'END', 'L', or 'T' by skeletonizing the binarized image
    and analyzing local connectivity. Also creates Junction objects in the context.
    """

    def __init__(self, config: JunctionConfig, debug_handler: DebugHandler = None):
        """
        Args:
            config: Junction configuration carrying the local-analysis window
                size, circle radius, and the L-angle acceptance band.
            debug_handler: Optional debug artifact handler; a default one is
                created when omitted.
        """
        super().__init__(config, debug_handler)  # no real model path
        self.window_size = config.window_size
        self.radius = config.radius
        self.angle_threshold_lb = config.angle_threshold_lb
        self.angle_threshold_ub = config.angle_threshold_ub
        self.debug_handler = debug_handler or DebugHandler()

    def _load_model(self, model_path: str):
        """Not loading any actual model, just skeleton logic."""
        return None

    def detect(self,
               image: np.ndarray,
               context: DetectionContext,
               *args,
               **kwargs) -> None:
        """
        1) Convert to binary & skeletonize
        2) Classify each point in the context
        3) Create a Junction for each point and store it in context.junctions
           (with 'connected_lines' referencing lines that share this point).
        """
        # 1) Preprocess -> skeleton
        skeleton = self._create_skeleton(image)

        # 2) Classify each point
        for pt in context.points.values():
            pt.type = self._classify_point(skeleton, pt)

        # 3) Create a Junction object for each point
        #    If you prefer only T or L, you can filter out END points.
        self._record_junctions_in_context(context)

    def _preprocess(self, image: np.ndarray) -> np.ndarray:
        """We might do thresholding; let's do a simple binary threshold."""
        if image.ndim == 3:
            gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        else:
            gray = image
        _, bin_image = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)
        return bin_image

    def _postprocess(self, image: np.ndarray) -> np.ndarray:
        """No postprocessing; returns the image unchanged."""
        return image

    def _create_skeleton(self, raw_image: np.ndarray) -> np.ndarray:
        """Skeletonize the binarized image (uint8, strokes as 255 on 0)."""
        bin_img = self._preprocess(raw_image)
        # For skeletonize, we need a boolean array; invert so dark strokes
        # become foreground.
        inv = cv2.bitwise_not(bin_img)
        inv_bool = (inv > 127).astype(np.uint8)
        skel = skeletonize(inv_bool).astype(np.uint8) * 255
        return skel

    def _classify_point(self, skeleton: np.ndarray, pt: Point) -> JunctionType:
        """
        Given a skeleton image, look around 'pt' in a local window
        to determine if it's an END, L, or T.

        The classification proxy is the number of connected skeleton regions
        inside a circle of ``self.radius`` around the point: 1 -> END,
        2 -> candidate L (angle-checked), 3 -> T.
        """
        classification = JunctionType.END  # default

        half_w = self.window_size // 2
        x, y = pt.coords.x, pt.coords.y

        # Clamp the analysis window to the image bounds.
        top = max(0, y - half_w)
        bottom = min(skeleton.shape[0], y + half_w + 1)
        left = max(0, x - half_w)
        right = min(skeleton.shape[1], x + half_w + 1)

        patch = (skeleton[top:bottom, left:right] > 127).astype(np.uint8)

        # create circular mask centered on the point (window-local coords)
        circle_mask = np.zeros_like(patch, dtype=np.uint8)
        local_cx = x - left
        local_cy = y - top
        cv2.circle(circle_mask, (local_cx, local_cy), self.radius, 1, -1)
        circle_skel = patch & circle_mask

        # label connected regions; the number of components approximates the
        # number of skeleton branches leaving the point
        labeled = label(circle_skel, connectivity=2)
        num_exits = labeled.max()

        if num_exits == 1:
            classification = JunctionType.END
        elif num_exits == 2:
            # check angle for L
            classification = self._check_angle_for_L(labeled)
        elif num_exits == 3:
            classification = JunctionType.T

        return classification

    def _check_angle_for_L(self, labeled_region: np.ndarray) -> JunctionType:
        """
        If the angle between two branches is within
        [angle_threshold_lb, angle_threshold_ub], it's 'L'.
        Otherwise default to END.

        NOTE(review): the coordinates below are the first two pixels of the
        region labeled 1 only, so the measured angle is a chord WITHIN one
        branch rather than the angle BETWEEN the two branches — confirm this
        is the intended heuristic.
        """
        coords = np.argwhere(labeled_region == 1)
        if len(coords) < 2:
            return JunctionType.END

        (y1, x1), (y2, x2) = coords[:2]
        dx = x2 - x1
        dy = y2 - y1
        angle = math.degrees(math.atan2(dy, dx))
        # Fold to the acute equivalent in [0, 90].
        acute_angle = min(abs(angle), 180 - abs(angle))

        if self.angle_threshold_lb <= acute_angle <= self.angle_threshold_ub:
            return JunctionType.L
        return JunctionType.END

    # -----------------------------------------
    # EXTRA STEP: Create Junction objects
    # -----------------------------------------
    def _record_junctions_in_context(self, context: DetectionContext):
        """
        Create a Junction object for each point in context.points.
        If you only want T/L points as junctions, filter them out.
        Also track any lines that connect to this point.
        """

        for pt in context.points.values():
            # If you prefer to store all points as junction, do it:
            # or if you want only T or L, do:
            # if pt.type in {JunctionType.T, JunctionType.L}: ...

            jn = Junction(
                center=pt.coords,
                junction_type=pt.type,
                # add more properties if needed
            )

            # find lines that connect to this point (by shared point id)
            connected_lines = []
            for ln in context.lines.values():
                if ln.start.id == pt.id or ln.end.id == pt.id:
                    connected_lines.append(ln.id)

            jn.connected_lines = connected_lines

            # add to context
            context.add_junction(jn)
546
+
547
+ import json
548
+ import uuid
549
+
550
class SymbolDetector(BaseDetector):
    """
    Placeholder detector: loads precomputed symbol detections from a JSON
    file and registers them on the context as Symbol objects.
    """

    def __init__(self,
                 config: SymbolConfig,
                 debug_handler: Optional[DebugHandler] = None,
                 symbol_json_path: str = "./symbols.json"):
        super().__init__(config=config, debug_handler=debug_handler)
        self.symbol_json_path = symbol_json_path

    def _load_model(self, model_path: str):
        """Not loading an actual model; symbol data is read from JSON."""
        return None

    def detect(self,
               image: np.ndarray,
               context: DetectionContext,
               *args,
               **kwargs) -> None:
        """
        Read precomputed symbol records from the configured JSON file and
        add a Symbol to `context` for each entry under the "detections" key.
        """
        records = self._load_json_data(self.symbol_json_path)
        if not records:
            return

        for entry in records.get("detections", []):
            context.add_symbol(self._parse_symbol_record(entry))

    def _preprocess(self, image: np.ndarray) -> np.ndarray:
        """No preprocessing; the image is unused."""
        return image

    def _postprocess(self, image: np.ndarray) -> np.ndarray:
        """No postprocessing; the image is unused."""
        return image

    # --------------
    # HELPER METHODS
    # --------------
    def _load_json_data(self, json_path: str) -> dict:
        """Load the JSON payload; record a debug artifact and return {} if missing."""
        if not os.path.exists(json_path):
            self.debug_handler.save_artifact(name="symbol_error",
                                             data=b"Missing symbol JSON file",
                                             extension="txt")
            return {}

        with open(json_path, "r", encoding="utf-8") as handle:
            return json.load(handle)

    def _parse_symbol_record(self, record: dict) -> Symbol:
        """Build a Symbol object from one JSON detection record."""
        xmin, ymin, xmax, ymax = record.get("bbox", [0, 0, 0, 0])
        bbox_obj = BBox(xmin=xmin, ymin=ymin, xmax=xmax, ymax=ymax)

        # Center of the bounding box (integer division keeps pixel coords).
        center_coords = Coordinates(
            x=(bbox_obj.xmin + bbox_obj.xmax) // 2,
            y=(bbox_obj.ymin + bbox_obj.ymax) // 2
        )

        return Symbol(
            id=record.get("symbol_id", ""),
            class_id=record.get("class_id", -1),
            original_label=record.get("original_label", ""),
            category=record.get("category", ""),
            type=record.get("type", ""),
            label=record.get("label", ""),
            bbox=bbox_obj,
            center=center_coords,
            confidence=record.get("confidence", 0.95),
            model_source=record.get("model_source", ""),
            connections=[]
        )
646
+
647
class TagDetector(BaseDetector):
    """
    Placeholder detector: loads precomputed tag (text) detections from a
    JSON file and registers them on the context as Tag objects.
    """

    def __init__(self,
                 config: TagConfig,
                 debug_handler: Optional[DebugHandler] = None,
                 tag_json_path: str = "./tags.json"):
        super().__init__(config=config, debug_handler=debug_handler)
        self.tag_json_path = tag_json_path

    def _load_model(self, model_path: str):
        """Not loading an actual model; tag data is read from JSON."""
        return None

    def detect(self,
               image: np.ndarray,
               context: DetectionContext,
               *args,
               **kwargs) -> None:
        """
        Read precomputed tag records from the configured JSON file and add a
        Tag to `context` for each entry under the "detections" key.
        """
        records = self._load_json_data(self.tag_json_path)
        if not records:
            return

        for entry in records.get("detections", []):
            context.add_tag(self._parse_tag_record(entry))

    def _preprocess(self, image: np.ndarray) -> np.ndarray:
        """No preprocessing; the image is unused."""
        return image

    def _postprocess(self, image: np.ndarray) -> np.ndarray:
        """No postprocessing; the image is unused."""
        return image

    # --------------
    # HELPER METHODS
    # --------------
    def _load_json_data(self, json_path: str) -> dict:
        """Load the JSON payload; record a debug artifact and return {} if missing."""
        if not os.path.exists(json_path):
            self.debug_handler.save_artifact(name="tag_error",
                                             data=b"Missing tag JSON file",
                                             extension="txt")
            return {}

        with open(json_path, "r", encoding="utf-8") as handle:
            return json.load(handle)

    def _parse_tag_record(self, record: dict) -> Tag:
        """Build a Tag object from one JSON detection record."""
        xmin, ymin, xmax, ymax = record.get("bbox", [0, 0, 0, 0])
        bbox_obj = BBox(xmin=xmin, ymin=ymin, xmax=xmax, ymax=ymax)

        return Tag(
            text=record.get("text", ""),
            bbox=bbox_obj,
            confidence=record.get("confidence", 1.0),
            source=record.get("source", ""),
            text_type=record.get("text_type", "Unknown"),
            # Records without an id get a fresh UUID so tags stay addressable.
            id=record.get("id", str(uuid.uuid4())),
            font_size=record.get("font_size", 12),
            rotation=record.get("rotation", 0.0)
        )
download_models.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ import torch
4
+ from doctr.models import ocr_predictor
5
+ from ultralytics import YOLO
6
+ from deeplsd.models.deeplsd_inference import DeepLSD
7
+ from dotenv import load_dotenv
8
+
9
+ # Load environment variables
10
+ load_dotenv()
11
+
12
def copy_local_models():
    """Copy models from local directory to deployment"""
    # Destination layout expected by the app.
    for target_dir in ('models/yolo', 'models/deeplsd', 'models/doctr'):
        os.makedirs(target_dir, exist_ok=True)

    # Source paths (adjust these to your local paths)
    local_models_dir = "../models"

    # Models that are simply copied over when present locally.
    simple_copies = (
        ("yolov8n.pt", "models/yolo/yolov8n.pt", "YOLO"),
        ("deeplsd_md.tar", "models/deeplsd/deeplsd_md.tar", "DeepLSD"),
    )
    for src_name, dst, label in simple_copies:
        src = os.path.join(local_models_dir, src_name)
        if os.path.exists(src):
            shutil.copy2(src, dst)
            print(f"Copied {label} model to {dst}")

    # DocTR: copy if available locally, otherwise download pretrained weights.
    doctr_src = os.path.join(local_models_dir, "ocr_predictor.pt")
    doctr_dst = "models/doctr/ocr_predictor.pt"
    if os.path.exists(doctr_src):
        shutil.copy2(doctr_src, doctr_dst)
        print(f"Copied DocTR model to {doctr_dst}")
    else:
        # Download DocTR model if not available locally
        predictor = ocr_predictor(pretrained=True)
        torch.save(predictor.state_dict(), doctr_dst)
        print(f"Downloaded DocTR model to {doctr_dst}")
47
+
48
+ if __name__ == "__main__":
49
+ copy_local_models()
gradioChatApp.py ADDED
@@ -0,0 +1,806 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import base64
3
+ import gradio as gr
4
+ import json
5
+ from datetime import datetime
6
+ from symbol_detection import run_detection_with_optimal_threshold
7
+ from line_detection_ai import DiagramDetectionPipeline, LineDetector, LineConfig, ImageConfig, DebugHandler, \
8
+ PointConfig, JunctionConfig, PointDetector, JunctionDetector, SymbolConfig, SymbolDetector, TagConfig, TagDetector
9
+ from data_aggregation_ai import DataAggregator
10
+ from chatbot_agent import get_assistant_response
11
+ from storage import StorageFactory, LocalStorage
12
+ import traceback
13
+ from text_detection_combined import process_drawing
14
+ from pathlib import Path
15
+ from pdf_processor import DocumentProcessor
16
+ import networkx as nx
17
+ import logging
18
+ import matplotlib.pyplot as plt
19
+ from dotenv import load_dotenv
20
+ import torch
21
+ from graph_visualization import create_graph_visualization
22
+ import shutil
23
+ from detection_schema import BBox # Add this import
24
+ import cv2
25
+ import numpy as np
26
+ import time
27
+ from huggingface_hub import HfApi, login
28
+ from download_models import download_models, copy_local_models
29
+
30
# Load environment variables from .env file
load_dotenv()

# Configure logging at the start of the file.
# All module loggers inherit this root configuration.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)

# Get logger for this module
logger = logging.getLogger(__name__)

# Disable duplicate logs from other modules by raising their
# per-logger threshold to WARNING (INFO/DEBUG from them is dropped).
logging.getLogger('PIL').setLevel(logging.WARNING)
logging.getLogger('matplotlib').setLevel(logging.WARNING)
logging.getLogger('gradio').setLevel(logging.WARNING)
logging.getLogger('networkx').setLevel(logging.WARNING)
logging.getLogger('line_detection_ai').setLevel(logging.WARNING)
logging.getLogger('symbol_detection').setLevel(logging.WARNING)
50
+
51
+
52
+ # Only log important messages
53
def log_process_step(message, level=logging.INFO):
    """Selectively log a processing step.

    Warnings and errors are always emitted at their given level; anything
    less severe is logged (at INFO) only when it reports a milestone,
    i.e. the text contains "completed" or "generated".
    """
    if level >= logging.WARNING:
        logger.log(level, message)
        return
    lowered = message.lower()
    if any(marker in lowered for marker in ("completed", "generated")):
        logger.info(message)
59
+
60
+
61
+ # Helper function to format timestamps
62
def get_timestamp():
    """Return the current local time formatted as 'YYYY-MM-DD HH:MM:SS'."""
    return f"{datetime.now():%Y-%m-%d %H:%M:%S}"
64
+
65
+
66
def format_message(role, content):
    """Build a single chatbot-history entry in the messages format."""
    return dict(role=role, content=content)
69
+
70
+
71
# Load avatar images for agents.  Each avatar is read from the local
# assets directory and base64-encoded so chat_message() can embed it
# directly in the HTML as a data URI (no separate static file serving).
localStorage = LocalStorage()
agent_avatar = base64.b64encode(localStorage.load_file("assets/AiAgent.png")).decode()
llm_avatar = base64.b64encode(localStorage.load_file("assets/llm.png")).decode()
user_avatar = base64.b64encode(localStorage.load_file("assets/user.png")).decode()
76
+
77
+
78
+ # Chat message formatting with avatars and enhanced HTML for readability
79
def chat_message(role, message, avatar, timestamp):
    """Render one chat message as an HTML bubble with avatar and timestamp.

    A small subset of Markdown in *message* is converted to HTML:
    fenced/inline code, paired ``**bold**`` markers, ``#``/``##``/``###``
    headings, and newlines.

    Bug fixed: the original chained ``.replace()`` calls were broken —
    ``.replace("**", "<strong>")`` consumed every ``**`` so the follow-up
    ``.replace("**", "</strong>")`` never matched (same for backticks),
    and the ``"\\n1. "`` replacements were dead code because ``"\\n"`` had
    already been rewritten to ``"<br>"``.  Paired regex substitutions
    produce matching open/close tags instead.
    """
    import re

    # Code first, so its contents are not further transformed.
    formatted = re.sub(r"```(.*?)```", r"<pre><code>\1</code></pre>", message, flags=re.DOTALL)
    formatted = re.sub(r"`([^`]+)`", r"<code>\1</code>", formatted)
    # Paired bold markers.
    formatted = re.sub(r"\*\*(.+?)\*\*", r"<strong>\1</strong>", formatted)
    # Headings: longest marker first so "###" is not matched as "#".
    formatted = re.sub(r"^###\s*(.+)$", r"<h3>\1</h3>", formatted, flags=re.MULTILINE)
    formatted = re.sub(r"^##\s*(.+)$", r"<h2>\1</h2>", formatted, flags=re.MULTILINE)
    formatted = re.sub(r"^#\s*(.+)$", r"<h1>\1</h1>", formatted, flags=re.MULTILINE)
    # Newlines last, so the MULTILINE heading regexes above still see line starts.
    formatted_message = formatted.replace("\n", "<br>")

    return f"""
    <div class="chat-message {role}">
        <img src="data:image/png;base64,{avatar}" class="avatar"/>
        <div>
            <div class="speech-bubble {role}-bubble">{formatted_message}</div>
            <div class="timestamp">{timestamp}</div>
        </div>
    </div>
    """
102
+
103
+
104
def resize_to_fit(image_path, max_width=800, max_height=600):
    """Resize the image at *image_path* to fit within the given bounds.

    The aspect ratio is preserved; a single uniform scale factor is
    applied so both dimensions fit inside (max_width, max_height).

    Returns:
        (resized_image, scale) on success, or (None, 1.0) when the
        image cannot be read.
    """
    img = cv2.imread(image_path)
    if img is None:
        return None, 1.0

    height, width = img.shape[:2]

    # Uniform scale that fits both dimensions inside the editor window.
    scale = min(max_width / width, max_height / height)

    # Always resize (even when scale > 1) so the image fills the editor.
    target_size = (int(width * scale), int(height * scale))
    return cv2.resize(img, target_size), scale
124
+
125
+
126
+ # Main processing function for P&ID steps
127
def process_pnid(image_file, progress=gr.Progress()):
    """Process a P&ID document end-to-end, yielding UI updates as it goes.

    Generator for the Gradio "Process Document" click handler.  Each
    ``yield`` emits a 9-element list matching the UI outputs:
    [progress text, original image, symbols image, tags image, lines
    image, aggregated image, graph image, chat history, aggregated-JSON
    path].  Pipeline: document conversion -> symbol detection -> text
    detection -> line detection -> aggregation -> knowledge graph.

    Raises:
        ValueError: when no file is uploaded or an intermediate artifact
            is missing/malformed; errors are also surfaced in the chat.
    """
    try:
        if image_file is None:
            raise ValueError("No file uploaded. Please upload a file first.")

        progress_text = []
        outputs = [None] * 9  # one slot per UI output (see docstring)
        base_name = os.path.splitext(os.path.basename(image_file.name))[0] + "_page_1"

        # Initialize chat history with the messages-format expected by gr.Chatbot
        chat_history = [{"role": "assistant", "content": "Welcome! Upload a P&ID to begin analysis."}]
        outputs[7] = chat_history  # Chat history lives at index 7

        def update_progress(step, message):
            # Append a timestamped line to the progress log and move the bar.
            progress_text.append(f"{datetime.now().strftime('%Y-%m-%d %H:%M:%S')} - {message}")
            outputs[0] = "\n".join(progress_text)  # Progress text
            progress(step)

        # Initialize storage and results directory
        storage = StorageFactory.get_storage()
        results_dir = "results"
        os.makedirs(results_dir, exist_ok=True)

        # Clean results directory so artifacts from a previous run cannot
        # be mistaken for this run's output.
        logger.info("Cleaned results directory: results")
        for file in os.listdir(results_dir):
            file_path = os.path.join(results_dir, file)
            try:
                if os.path.isfile(file_path):
                    os.unlink(file_path)
            except Exception as e:
                logger.error(f"Error deleting file {file_path}: {str(e)}")

        # Step 1: File Upload (10%)
        logger.info(f"Processing file: {os.path.basename(image_file.name)}")
        update_progress(0.1, "Step 1/7: File uploaded successfully")
        yield outputs

        # Step 2: Document Processing - convert the upload (e.g. PDF) to a
        # high-quality PNG used by every downstream detector.
        update_progress(0.2, "Step 2/7: Processing document...")
        doc_processor = DocumentProcessor(storage)
        processed_pages = doc_processor.process_document(
            file_path=image_file,
            output_dir=results_dir
        )

        if not processed_pages:
            raise ValueError("No pages processed from document")

        # Only the first page is analyzed (base_name hard-codes "_page_1").
        high_quality_png = processed_pages[0]
        outputs[1] = high_quality_png  # P&ID tab shows original high quality
        update_progress(0.25, "Document loaded and displayed")
        yield outputs

        # Step 3: Symbol Detection using the high quality PNG
        detection_image_path, detection_json_path, _, diagram_bbox = run_detection_with_optimal_threshold(
            high_quality_png,
            results_dir=results_dir,
            file_name=os.path.basename(high_quality_png),
            storage=storage,
            resize_image=False  # keep full resolution for detection
        )
        outputs[2] = detection_image_path  # Symbols tab
        symbol_json_path = detection_json_path

        # Step 4: Text (tag) Detection using the high quality PNG
        text_results, text_summary = process_drawing(
            high_quality_png,
            results_dir,
            storage
        )
        text_json_path = text_results['json_path']
        outputs[3] = text_results['image_path']  # Tags tab

        # Step 5: Line Detection (80%)
        update_progress(0.80, "Step 5/7: Line Detection")
        yield outputs

        try:
            # Initialize components
            debug_handler = DebugHandler(enabled=True, storage=storage)

            # Configure detectors
            line_config = LineConfig()
            point_config = PointConfig()
            junction_config = JunctionConfig()
            symbol_config = SymbolConfig(
                model_path="models/Intui_SDM_41.pt",
                confidence_threshold=0.5,
                nms_threshold=0.3
            )
            tag_config = TagConfig(
                model_path="models/tag_detection.json",
                confidence_threshold=0.5
            )

            # Create all required detectors
            symbol_detector = SymbolDetector(
                config=symbol_config,
                debug_handler=debug_handler
            )

            tag_detector = TagDetector(
                config=tag_config,
                debug_handler=debug_handler
            )

            # NOTE(review): device is hard-coded to CUDA — this will fail
            # on CPU-only hosts; confirm the deployment always has a GPU.
            line_detector = LineDetector(
                config=line_config,
                model_path="models/deeplsd_md.tar",
                model_config={"detect_lines": True},
                device=torch.device("cuda"),
                debug_handler=debug_handler
            )

            point_detector = PointDetector(
                config=point_config,
                debug_handler=debug_handler
            )

            junction_detector = JunctionDetector(
                config=junction_config,
                debug_handler=debug_handler
            )

            # Create pipeline with all detectors
            pipeline = DiagramDetectionPipeline(
                tag_detector=tag_detector,
                symbol_detector=symbol_detector,
                line_detector=line_detector,
                point_detector=point_detector,
                junction_detector=junction_detector,
                storage=storage,
                debug_handler=debug_handler
            )

            # Run pipeline with the original high-res image
            line_results = pipeline.run(
                image_path=high_quality_png,
                output_dir=results_dir,
                config=ImageConfig()
            )
            line_json_path = line_results.json_path
            outputs[4] = line_results.image_path

            # Verify line detection output exists on disk
            if not os.path.exists(line_json_path):
                raise ValueError(f"Line detection JSON not found: {line_json_path}")

            # Verify line detection JSON content
            with open(line_json_path, 'r') as f:
                line_data = json.load(f)
                if 'lines' not in line_data:
                    raise ValueError(f"Invalid line detection data format in {line_json_path}")
                logger.info(f"Line detection completed successfully with {len(line_data['lines'])} lines")

            # Verify all required JSONs exist before aggregation
            required_jsons = {
                'symbols': symbol_json_path,
                'texts': text_json_path,
                'lines': line_json_path
            }

            for name, path in required_jsons.items():
                if not os.path.exists(path):
                    raise ValueError(f"{name} JSON not found: {path}")
                # Verify the JSON parses before handing it to the aggregator
                with open(path, 'r') as f:
                    data = json.load(f)
                    logger.info(f"Loaded {name} JSON with {len(data.get('detections', data.get('lines', [])))} items")

            # Data Aggregation: merge symbols, texts, and lines into one graph JSON
            aggregator = DataAggregator(storage=storage)
            aggregated_result = aggregator.process_data(
                image_path=high_quality_png,
                output_dir=results_dir,
                symbols_path=symbol_json_path,
                texts_path=text_json_path,
                lines_path=line_json_path
            )

            # Verify aggregation result before graph creation
            if not aggregated_result.get('success'):
                raise ValueError(f"Data aggregation failed: {aggregated_result.get('error')}")

            aggregated_json_path = aggregated_result['json_path']
            if not os.path.exists(aggregated_json_path):
                raise ValueError(f"Aggregated JSON not found: {aggregated_json_path}")

            # Verify aggregated JSON content has the keys the graph step needs
            with open(aggregated_json_path, 'r') as f:
                aggregated_data = json.load(f)
                required_keys = ['nodes', 'edges', 'symbols', 'texts', 'lines']
                missing_keys = [k for k in required_keys if k not in aggregated_data]
                if missing_keys:
                    raise ValueError(f"Aggregated JSON missing required keys: {missing_keys}")
                logger.info("Aggregation completed successfully with:")
                logger.info(f"- {len(aggregated_data['nodes'])} nodes")
                logger.info(f"- {len(aggregated_data['edges'])} edges")

            # After aggregation, create the knowledge-graph visualization
            update_progress(0.85, "Step 6/7: Creating Knowledge Graph")
            try:
                graph_results = create_graph_visualization(
                    json_path=aggregated_json_path,
                    output_dir=results_dir,
                    base_name=base_name,
                    save_plot=True
                )

                if not graph_results.get('success'):
                    logger.error(f"Error in graph generation: {graph_results.get('error')}")
                    raise Exception(graph_results.get('error'))

                # Path convention must match graph_visualization's output naming.
                graph_path = f"results/{base_name}_graph_visualization.png"
                if not os.path.exists(graph_path):
                    raise Exception("Graph visualization file not created")

                update_progress(0.90, "Step 6/7: Knowledge Graph Created")

            except Exception as e:
                logger.error(f"Error creating graph visualization: {str(e)}")
                raise

            # Final output assignments for all 9 UI slots.
            # NOTE(review): outputs[0] is set to the raw list here but is
            # overwritten with the joined string by update_progress below.
            outputs[0] = progress_text  # Progress text
            outputs[1] = high_quality_png  # P&ID
            outputs[2] = detection_image_path  # Symbols
            outputs[3] = text_results['image_path']  # Tags
            outputs[4] = line_results.image_path  # Lines
            outputs[5] = f"results/{base_name}_aggregated.png"  # Aggregated
            outputs[6] = graph_path  # Graph visualization
            outputs[7] = chat_history  # Chat
            outputs[8] = aggregated_json_path  # JSON state (enables the chatbot)

            # Update progress with all steps
            update_progress(0.95, "Step 7/7: Finalizing Results")
            chat_history = [{"role": "assistant", "content": "Processing complete! I can help answer questions about the P&ID contents."}]
            outputs[7] = chat_history

            update_progress(1.0, "βœ… Processing Complete")
            yield outputs

        except Exception as e:
            # Surface the error in the chat panel, then re-raise to Gradio
            chat_history = [{"role": "assistant", "content": f"Error during processing: {str(e)}"}]
            outputs[7] = chat_history
            raise

    except Exception as e:
        logger.error(f"Error in process_pnid: {str(e)}")
        logger.error(f"Stack trace:\n{traceback.format_exc()}")
        # Surface the error in the chat panel, then re-raise to Gradio
        chat_history = [{"role": "assistant", "content": f"Error: {str(e)}"}]
        outputs[7] = chat_history
        raise
386
+
387
+
388
+ # Separate function for Chat interaction
389
def handle_user_message(user_input, chat_history, json_path_state):
    """Handle a user chat message and append the assistant's reply.

    Args:
        user_input: Raw text typed by the user; blank input is ignored.
        chat_history: Existing list of rendered chat messages.
        json_path_state: Path to the aggregated-results JSON from processing,
            or None when no document has been processed yet.

    Returns:
        The updated chat history (a new list; the input list is not mutated).

    Bug fixed: chat_message() returns a rendered HTML *string*, but the
    original code concatenated it directly onto the history list
    (``list + str``), which raises TypeError on every call.  Each rendered
    message is now wrapped in a one-element list / appended explicitly.
    """
    try:
        if not user_input or not user_input.strip():
            return chat_history

        # Add the user's message to a fresh history list.
        timestamp = get_timestamp()
        new_history = chat_history + [chat_message("user", user_input, user_avatar, timestamp)]

        # A processed document is required before the assistant can answer.
        if not json_path_state or not os.path.exists(json_path_state):
            error_message = "Please upload and process a P&ID document first."
            return new_history + [chat_message("assistant", error_message, agent_avatar, get_timestamp())]

        try:
            # Log for debugging
            logger.info(f"Sending question to assistant: {user_input}")
            logger.info(f"Using JSON path: {json_path_state}")

            # Generate response
            response = get_assistant_response(user_input, json_path_state)

            # Normalize the response: plain values are stringified, a
            # generator yields its first item.
            if isinstance(response, (str, dict)):
                response_text = str(response)
            else:
                try:
                    response_text = next(response) if hasattr(response, '__next__') else str(response)
                except StopIteration:
                    response_text = "I apologize, but I couldn't generate a response."
                except Exception as e:
                    logger.error(f"Error processing response: {str(e)}")
                    response_text = "I apologize, but I encountered an error processing your request."

            logger.info(f"Generated response: {response_text}")

            if not response_text.strip():
                response_text = "I apologize, but I couldn't generate a response. Please try asking your question differently."

            # Add the assistant's reply to the history.
            new_history.append(chat_message("assistant", response_text, agent_avatar, get_timestamp()))

        except Exception as e:
            logger.error(f"Error generating response: {str(e)}")
            logger.error(traceback.format_exc())
            error_message = "I apologize, but I encountered an error processing your request. Please try again."
            new_history.append(chat_message("assistant", error_message, agent_avatar, get_timestamp()))

        return new_history

    except Exception as e:
        logger.error(f"Chat error: {str(e)}")
        logger.error(traceback.format_exc())
        return chat_history + [chat_message(
            "assistant",
            "I apologize, but something went wrong. Please try again.",
            agent_avatar,
            get_timestamp()
        )]
450
+
451
+
452
# Custom CSS for the dark-themed layout: panel sizing, terminal-style
# progress box, chat bubbles with avatars, and the logo row.  The string
# is passed to Gradio as page-level CSS.
custom_css = """
.full-height-row {
    height: calc(100vh - 150px); /* Adjusted height */
    margin: 0;
    padding: 10px;
}
.upload-box {
    background: #2a2a2a;
    border-radius: 8px;
    padding: 15px;
    margin-bottom: 15px;
    border: 1px solid #3a3a3a;
}
.status-box-container {
    background: #2a2a2a;
    border-radius: 8px;
    padding: 15px;
    height: calc(100vh - 350px); /* Reduced height */
    border: 1px solid #3a3a3a;
    margin-bottom: 15px;
}
.status-box {
    font-family: 'Courier New', monospace;
    font-size: 12px;
    line-height: 1.4;
    background-color: #1a1a1a;
    color: #00ff00;
    padding: 10px;
    border-radius: 5px;
    height: calc(100% - 40px); /* Adjust for header */
    overflow-y: auto;
    white-space: pre-wrap;
    word-wrap: break-word;
    border: none;
}
.preview-tabs {
    height: calc(100vh - 100px); /* Increased container height from 200px */
    background: #2a2a2a;
    border-radius: 8px;
    padding: 15px;
    border: 1px solid #3a3a3a;
    margin-bottom: 15px;
}
.chat-container {
    height: 100%; /* Take full height */
    display: flex;
    flex-direction: column;
    background: #2a2a2a;
    border-radius: 8px;
    padding: 15px;
    border: 1px solid #3a3a3a;
}
.chatbox {
    flex: 1; /* Take remaining space */
    overflow-y: auto;
    background: #1a1a1a;
    border-radius: 8px;
    padding: 15px;
    margin-bottom: 15px;
    color: #ffffff;
    min-height: 200px; /* Ensure minimum height */
}
.chat-input-group {
    height: auto; /* Allow natural height */
    min-height: 120px; /* Minimum height for input area */
    background: #1a1a1a;
    border-radius: 8px;
    padding: 15px;
    margin-top: auto; /* Push to bottom */
}
.chat-input {
    background: #2a2a2a;
    color: #ffffff;
    border: 1px solid #3a3a3a;
    border-radius: 5px;
    padding: 12px;
    min-height: 80px;
    width: 100%;
    margin-bottom: 10px;
}
.send-button {
    width: 100%;
    background: #4a4a4a;
    color: #ffffff;
    border-radius: 5px;
    border: none;
    padding: 12px;
    cursor: pointer;
    transition: background-color 0.3s;
}
.result-image {
    border-radius: 8px;
    box-shadow: 0 2px 4px rgba(0,0,0,0.1);
    margin: 10px 0;
    background: #ffffff;
}
.chat-message {
    display: flex;
    margin-bottom: 1rem;
    align-items: flex-start;
}
.chat-message .avatar {
    width: 40px;
    height: 40px;
    margin-right: 10px;
    border-radius: 50%;
}
.chat-message .speech-bubble {
    background: #2a2a2a;
    padding: 10px 15px;
    border-radius: 10px;
    max-width: 80%;
    margin-bottom: 5px;
}
.chat-message .timestamp {
    font-size: 0.8em;
    color: #666;
}
.logo-row {
    width: 100%;
    background-color: #1a1a1a;
    padding: 10px 0;
    margin: 0;
    border-bottom: 1px solid #3a3a3a;
}
"""
579
+
580
+
581
def create_ui():
    """Build and return the Gradio Blocks app for the P&ID analyzer.

    Layout: a logo row, then three columns — upload/progress (left),
    tabbed result previews (center), and the chat panel (right).  Event
    wiring connects the process button to process_pnid() and the chat
    input/button to the local handle_chat() closure.
    """
    current_dir = os.path.dirname(os.path.abspath(__file__))
    logo_path = os.path.join(current_dir, "assets", "intuigence.png")

    # Page-level CSS: orange accent theme plus logo/layout tweaks.
    css = """
    /* Theme colors */
    :root {
        --orange-primary: #ff6b2b;
        --orange-hover: #ff8651;
        --orange-light: rgba(255, 107, 43, 0.1);
    }

    /* Logo styling */
    .logo-container {
        padding: 10px 20px;
        margin-bottom: 10px;
        text-align: left;
        width: 100%;
        background: #1a1a1a; /* Match app background */
        border-bottom: 1px solid #3a3a3a;
    }
    .logo-container img {
        max-height: 40px;
        width: auto;
        display: inline-block !important;
    }
    /* Hide download and fullscreen buttons for logo */
    .logo-container .download-button,
    .logo-container .fullscreen-button {
        display: none !important;
    }
    /* Adjust main content padding */
    .main-content {
        padding-top: 10px;
    }
    /* Custom orange theme */
    .primary-button {
        background: var(--orange-primary) !important;
        color: white !important;
        border: none !important;
    }
    .primary-button:hover {
        background: var(--orange-hover) !important;
    }

    /* Tab styling */
    .tabs > .tab-nav > button.selected {
        border-color: var(--orange-primary) !important;
        color: var(--orange-primary) !important;
    }
    .tabs > .tab-nav > button:hover {
        border-color: var(--orange-hover) !important;
        color: var(--orange-hover) !important;
    }

    /* File upload button */
    .file-upload {
        background: var(--orange-primary) !important;
    }

    /* Progress bar */
    .progress-bar > div {
        background: var(--orange-primary) !important;
    }

    /* Tags and labels */
    .label-wrap {
        background: var(--orange-primary) !important;
    }

    /* Selected/active states */
    .selected, .active, .focused {
        border-color: var(--orange-primary) !important;
        color: var(--orange-primary) !important;
    }

    /* Links and interactive elements */
    a, .link, .interactive {
        color: var(--orange-primary) !important;
    }
    a:hover, .link:hover, .interactive:hover {
        color: var(--orange-hover) !important;
    }

    /* Input focus states */
    input:focus, textarea:focus {
        border-color: var(--orange-primary) !important;
        box-shadow: 0 0 0 1px var(--orange-light) !important;
    }

    /* Checkbox and radio */
    input[type="checkbox"]:checked, input[type="radio"]:checked {
        background-color: var(--orange-primary) !important;
        border-color: var(--orange-primary) !important;
    }
    """

    with gr.Blocks(theme=gr.themes.Soft(), css=css) as demo:
        # Logo row (before main content)
        with gr.Row(elem_classes="logo-container"):
            gr.Image(
                value=logo_path,
                show_label=False,
                container=False,
                interactive=False,
                show_download_button=False,
                show_share_button=False,
                height=40
            )

        # Session state.  NOTE(review): file_path is declared but never
        # wired to any event — presumably a leftover; confirm and remove.
        file_path = gr.State()
        json_path = gr.State()  # path of the aggregated JSON, set by process_pnid

        # Main content row
        with gr.Row(elem_classes="main-content"):
            # Left column - File Upload & Processing
            with gr.Column(scale=3, elem_classes="column-panel"):
                file_output = gr.File(label="Upload P&ID Document")
                process_button = gr.Button(
                    "Process Document",
                    elem_classes="primary-button"  # orange accent styling
                )
                progress_output = gr.Textbox(
                    label="Progress",
                    value="Waiting for document...",
                    interactive=False
                )

            # Center column - Preview Panel, one tab per pipeline stage
            with gr.Column(scale=5, elem_classes="column-panel preview-panel"):
                with gr.Tabs() as tabs:
                    with gr.TabItem("P&ID"):
                        input_image = gr.Image(type="filepath", label="Original")
                    with gr.TabItem("Symbols"):
                        symbol_image = gr.Image(type="filepath", label="Detected Symbols")
                    with gr.TabItem("Tags"):
                        text_image = gr.Image(type="filepath", label="Detected Tags")
                    with gr.TabItem("Lines"):
                        line_image = gr.Image(type="filepath", label="Detected Lines")
                    with gr.TabItem("Aggregated"):
                        aggregated_image = gr.Image(type="filepath", label="Aggregated Results")
                    with gr.TabItem("Knowledge Graph"):
                        graph_image = gr.Image(type="filepath", label="Knowledge Graph")

            # Right column - Chat Interface
            with gr.Column(scale=4, elem_classes="column-panel chat-panel", elem_id="chat-panel"):
                chat_history = gr.Chatbot(
                    [],
                    elem_classes="chat-history",
                    height=400,
                    show_label=False,
                    type="messages",  # dict-based {"role", "content"} entries
                    elem_id="chat-history"
                )
                with gr.Row():
                    chat_input = gr.Textbox(
                        placeholder="Ask me about the P&ID...",
                        show_label=False,
                        container=False
                    )
                    chat_button = gr.Button(
                        "Send",
                        elem_classes="primary-button"  # orange accent styling
                    )

        def handle_chat(user_message, chat_history, json_path):
            # Append the user message and the assistant reply (messages
            # format); returns "" to clear the input textbox.
            if not user_message:
                return "", chat_history

            # Add user message
            chat_history = chat_history + [{"role": "user", "content": user_message}]

            try:
                # Get assistant response
                response = get_assistant_response(user_message, json_path)
                # Add assistant response
                chat_history = chat_history + [{"role": "assistant", "content": response}]
            except Exception as e:
                logger.error(f"Error in chat response: {str(e)}")
                chat_history = chat_history + [
                    {"role": "assistant", "content": "I apologize, but I encountered an error processing your request."}
                ]

            return "", chat_history

        # Connect UI elements: Enter key and Send button both submit.
        chat_input.submit(handle_chat, [chat_input, chat_history, json_path], [chat_input, chat_history])
        chat_button.click(handle_chat, [chat_input, chat_history, json_path], [chat_input, chat_history])

        # process_pnid is a generator; each yield updates all 9 outputs.
        process_button.click(
            process_pnid,
            inputs=[file_output],
            outputs=[
                progress_output,  # Progress text (0)
                input_image,  # P&ID (1)
                symbol_image,  # Symbols (2)
                text_image,  # Tags (3)
                line_image,  # Lines (4)
                aggregated_image,  # Aggregated (5)
                graph_image,  # Graph (6)
                chat_history,  # Chat (7)
                json_path  # State (8)
            ],
            show_progress="hidden"  # Hide progress in tabs
        )

    return demo
789
+
790
+
791
def main():
    """Entry point: make sure model weights are present, then launch the app."""
    # Copy/download model weights only on first run.
    yolo_weights = 'models/yolo/yolov8n.pt'
    if not os.path.exists(yolo_weights):
        copy_local_models()

    app = create_ui()
    # Local development settings (no HF Spaces special-casing).
    app.launch(
        server_name="0.0.0.0",
        server_port=7861,  # 7860 left free for other local Gradio apps
        share=True
    )
803
+
804
+
805
if __name__ == "__main__":
    # Launch the Gradio app when run as a script.
    main()
graph_construction.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import networkx as nx
4
+ import matplotlib.pyplot as plt
5
+ from pathlib import Path
6
+ import logging
7
+ import traceback
8
+ from storage import StorageFactory
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
def construct_graph_network(data: dict, validation_results_path: str, results_dir: str, storage=None):
    """Construct and render a network graph from aggregated detection data.

    Args:
        data: Aggregated detection dict with 'nodes' and 'edges' lists;
            node positions come from 'coords' (connection points) or the
            bbox center (symbols/texts).
        validation_results_path: NOTE(review): accepted but never used in
            this function — confirm whether it can be dropped.
        results_dir: Directory where the PNG and graph JSON are written.
        storage: Optional storage backend; a default is created when None.

    Returns:
        (G, pos, figure) on success, or (None, None, None) on any error
        (the error is logged, not raised).
    """
    try:
        # Use provided storage or get a new one
        if storage is None:
            storage = StorageFactory.get_storage()

        # Create graph
        G = nx.Graph()
        pos = {}  # node id -> (x, y) layout positions

        # Add nodes from the aggregated data
        for node in data.get('nodes', []):
            node_id = node['id']
            node_type = node['type']

            # Position: connection points carry explicit coords; symbols
            # and texts use their bounding-box center.
            if node_type == 'connection_point':
                pos[node_id] = (node['coords']['x'], node['coords']['y'])
            else:  # symbol or text
                bbox = node['bbox']
                pos[node_id] = (
                    (bbox['xmin'] + bbox['xmax']) / 2,
                    (bbox['ymin'] + bbox['ymax']) / 2
                )

            # Add node with all its properties
            G.add_node(node_id, **node)

        # Add edges from the aggregated data
        for edge in data.get('edges', []):
            G.add_edge(
                edge['source'],
                edge['target'],
                **edge.get('properties', {})
            )

        # Create visualization
        plt.figure(figsize=(20, 20))

        # Draw nodes with different colors based on type
        node_colors = []
        for node in G.nodes():
            node_type = G.nodes[node]['type']
            if node_type == 'symbol':
                node_colors.append('lightblue')
            elif node_type == 'text':
                node_colors.append('lightgreen')
            else:  # connection_point
                node_colors.append('lightgray')

        nx.draw_networkx_nodes(G, pos, node_color=node_colors, node_size=500)
        nx.draw_networkx_edges(G, pos, edge_color='gray', width=1)

        # Add compact labels: S:<class>, T:<text prefix>, C:<point type>
        labels = {}
        for node in G.nodes():
            node_data = G.nodes[node]
            if node_data['type'] == 'symbol':
                labels[node] = f"S:{node_data.get('properties', {}).get('class', '')}"
            elif node_data['type'] == 'text':
                content = node_data.get('content', '')
                labels[node] = f"T:{content[:10]}..." if len(content) > 10 else f"T:{content}"
            else:
                # NOTE(review): unguarded node_data['properties'] — will
                # KeyError if a connection_point lacks 'properties'.
                labels[node] = f"C:{node_data['properties'].get('point_type', '')}"

        nx.draw_networkx_labels(G, pos, labels, font_size=8)

        plt.title("P&ID Knowledge Graph")
        plt.axis('off')

        # Save the visualization as a high-DPI PNG
        graph_image_path = os.path.join(results_dir, f"{Path(data.get('image_path', 'graph')).stem}_graph.png")
        plt.savefig(graph_image_path, bbox_inches='tight', dpi=300)
        plt.close()

        # Save graph data as JSON for future use (node-link format)
        graph_json_path = os.path.join(results_dir, f"{Path(data.get('image_path', 'graph')).stem}_graph_data.json")
        with open(graph_json_path, 'w') as f:
            json.dump(nx.node_link_data(G), f, indent=2)

        return G, pos, plt.gcf()

    except Exception as e:
        logger.error(f"Error in construct_graph_network: {str(e)}")
        traceback.print_exc()
        return None, None, None
100
+
101
+
102
if __name__ == "__main__":
    # Manual smoke test: build and show a graph from a sample
    # aggregated-results file, if one exists.
    test_data_path = "results/test_aggregated.json"
    if os.path.exists(test_data_path):
        with open(test_data_path, 'r') as f:
            test_data = json.load(f)

        G, pos, fig = construct_graph_network(
            test_data,
            "results/validation.json",
            "results"
        )
        if fig:
            plt.show()
graph_processor.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import networkx as nx
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+ import traceback
6
+ import uuid
7
+
8
+
9
def create_connected_graph(input_data):
    """Create a connected graph (and matplotlib figure) from aggregated data.

    Args:
        input_data: Aggregated dict that must contain 'symbols', 'texts',
            'lines', 'nodes', and 'edges'.  Symbol and text entries become
            graph nodes positioned at their bbox centers; 'edges' entries
            connect nodes already present in the graph.

    Returns:
        (G, pos, figure) on success, or (None, None, None) on any error
        (the error is printed, not raised).
    """
    try:
        # Validate input data structure
        if not isinstance(input_data, dict):
            raise ValueError("Invalid input data format")

        # Check for required keys in new format
        required_keys = ['symbols', 'texts', 'lines', 'nodes', 'edges']
        if not all(key in input_data for key in required_keys):
            raise ValueError(f"Missing required keys in input data. Expected: {required_keys}")

        # Create graph
        G = nx.Graph()

        # Track positions for layout (node id -> (x, y))
        pos = {}

        # Add symbol nodes
        for symbol in input_data['symbols']:
            # NOTE(review): default is a list but bbox is indexed with
            # string keys below — presumably bbox is always a dict with
            # xmin/xmax/ymin/ymax; confirm against the aggregator output.
            bbox = symbol.get('bbox', [])
            symbol_id = symbol.get('id', str(uuid.uuid4()))

            if bbox:
                # Calculate center position
                center_x = (bbox['xmin'] + bbox['xmax']) / 2
                center_y = (bbox['ymin'] + bbox['ymax']) / 2
                pos[symbol_id] = (center_x, center_y)

                G.add_node(
                    symbol_id,
                    type='symbol',
                    class_name=symbol.get('class', ''),
                    bbox=bbox,
                    confidence=symbol.get('confidence', 0.0)
                )

        # Add text nodes
        for text in input_data['texts']:
            bbox = text.get('bbox', [])
            text_id = text.get('id', str(uuid.uuid4()))

            if bbox:
                center_x = (bbox['xmin'] + bbox['xmax']) / 2
                center_y = (bbox['ymin'] + bbox['ymax']) / 2
                pos[text_id] = (center_x, center_y)

                G.add_node(
                    text_id,
                    type='text',
                    text=text.get('text', ''),
                    bbox=bbox,
                    confidence=text.get('confidence', 0.0)
                )

        # Add edges from the edges list; endpoints missing from the graph
        # (e.g. skipped for lacking a bbox) are silently dropped.
        for edge in input_data['edges']:
            source = edge.get('source')
            target = edge.get('target')
            if source and target and source in G and target in G:
                G.add_edge(
                    source,
                    target,
                    type=edge.get('type', 'connection'),
                    properties=edge.get('properties', {})
                )

        # Create visualization
        plt.figure(figsize=(20, 20))

        # Draw nodes with fixed positions (blue symbols, green texts)
        nx.draw_networkx_nodes(G, pos,
                               node_color=['lightblue' if G.nodes[node]['type'] == 'symbol' else 'lightgreen' for node
                                           in G.nodes()],
                               node_size=500)

        # Draw edges
        nx.draw_networkx_edges(G, pos, edge_color='gray', width=1)

        # Add compact labels: S:<class> for symbols, T:<text prefix> for texts
        labels = {}
        for node in G.nodes():
            node_data = G.nodes[node]
            if node_data['type'] == 'symbol':
                labels[node] = f"S:{node_data['class_name']}"
            else:
                text = node_data.get('text', '')
                labels[node] = f"T:{text[:10]}..." if len(text) > 10 else f"T:{text}"

        nx.draw_networkx_labels(G, pos, labels, font_size=8)

        plt.title("P&ID Network Graph")
        plt.axis('off')

        return G, pos, plt.gcf()

    except Exception as e:
        print(f"Error in create_connected_graph: {str(e)}")
        traceback.print_exc()
        return None, None, None
+
110
+
111
+ if __name__ == "__main__":
112
+ # Test code
113
+ with open('results/0_aggregated.json') as f:
114
+ data = json.load(f)
115
+
116
+ G, pos, fig = create_connected_graph(data)
117
+ if fig:
118
+ plt.show()
graph_visualization.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import networkx as nx
3
+ import matplotlib.pyplot as plt
4
+ import os
5
+ from pprint import pprint
6
+ import uuid
7
+ import argparse
8
+ from pathlib import Path
9
+ from tqdm import tqdm
10
+
11
def create_graph_visualization(json_path: str, output_dir: str, base_name: str, save_plot: bool = True) -> dict:
    """Create a graph visualization using the actual bbox coordinates from an aggregated JSON file.

    Args:
        json_path: Path to the aggregated JSON file containing 'nodes' and 'edges'.
        output_dir: Directory where the PNG visualization is written.
        base_name: Base name for the output file (a trailing '_aggregated' is stripped).
        save_plot: When True, render and save the plot; otherwise only build the graph.

    Returns:
        dict with 'success' plus, on success, 'graph' (and 'image_path'/'stats' when
        save_plot is True); on failure, 'error' holds the exception message.
    """
    try:
        # Remove '_aggregated' suffix if present
        if base_name.endswith('_aggregated'):
            base_name = base_name[:-len('_aggregated')]

        print("\nLoading JSON data...")
        with open(json_path, 'r') as f:
            data = json.load(f)

        # Create graph
        G = nx.Graph()
        pos = {}
        valid_nodes = []
        invalid_nodes = []

        # First pass - collect nodes with usable coordinates.
        # FIX: the previous truthiness check (`if node_id and x and y`) silently
        # rejected legitimate nodes sitting on an image edge (x == 0 or y == 0).
        # We now only require that both coordinates are present and numeric.
        print("\nValidating nodes...")
        for node in tqdm(data.get('nodes', []), desc="Validating"):
            try:
                node_id = str(node.get('id', ''))
                raw_x = node.get('x')
                raw_y = node.get('y')

                if node_id and raw_x is not None and raw_y is not None:
                    pos[node_id] = (float(raw_x), float(raw_y))
                    valid_nodes.append(node)
                else:
                    invalid_nodes.append(node)
            except (ValueError, TypeError):
                # Non-numeric coordinates cannot be positioned.
                invalid_nodes.append(node)
                continue

        print(f"\nFound {len(valid_nodes)} valid nodes and {len(invalid_nodes)} invalid nodes")

        # Add valid nodes
        print("\nAdding valid nodes...")
        for node in tqdm(valid_nodes, desc="Nodes"):
            node_id = str(node.get('id', ''))
            attrs = {
                'type': node.get('type', ''),
                'label': node.get('label', ''),
                'x': float(node.get('x', 0)),
                'y': float(node.get('y', 0))
            }
            G.add_node(node_id, **attrs)

        # Add valid edges (only between valid nodes)
        print("\nAdding valid edges...")
        valid_edges = []
        invalid_edges = []

        for edge in tqdm(data.get('edges', []), desc="Edges"):
            try:
                start_id = str(edge.get('start_point', ''))
                end_id = str(edge.get('end_point', ''))

                if start_id in pos and end_id in pos:  # Only add if both nodes exist
                    valid_edges.append(edge)
                    attrs = {
                        'type': edge.get('type', ''),
                        'weight': edge.get('weight', 1.0)
                    }
                    G.add_edge(start_id, end_id, **attrs)
                else:
                    invalid_edges.append(edge)
            except Exception:
                invalid_edges.append(edge)
                continue

        print(f"\nFound {len(valid_edges)} valid edges and {len(invalid_edges)} invalid edges")

        if save_plot:
            print("\nGenerating visualization...")
            plt.figure(figsize=(20, 20))

            print("Drawing graph elements...")
            with tqdm(total=3, desc="Drawing") as pbar:
                # Draw nodes
                nx.draw_networkx_nodes(G, pos,
                                       node_color='lightblue',
                                       node_size=100)
                pbar.update(1)

                # Draw edges
                nx.draw_networkx_edges(G, pos)
                pbar.update(1)

                # Save plot
                image_path = os.path.join(output_dir, f"{base_name}_graph_visualization.png")
                plt.savefig(image_path, bbox_inches='tight', dpi=300)
                plt.close()
                pbar.update(1)

            print(f"\nVisualization saved to: {image_path}")
            return {
                'success': True,
                'image_path': image_path,
                'graph': G,
                'stats': {
                    'valid_nodes': len(valid_nodes),
                    'invalid_nodes': len(invalid_nodes),
                    'valid_edges': len(valid_edges),
                    'invalid_edges': len(invalid_edges)
                }
            }

        return {
            'success': True,
            'graph': G
        }

    except Exception as e:
        print(f"\nError creating graph: {str(e)}")
        return {
            'success': False,
            'error': str(e)
        }
130
+
131
if __name__ == "__main__":
    """Test the graph visualization independently"""

    # Command-line interface for running this module standalone.
    cli = argparse.ArgumentParser(description='Create and visualize graph from aggregated JSON')
    cli.add_argument('--json_path', type=str, default="results/002_page_1_aggregated.json",
                     help='Path to aggregated JSON file')
    cli.add_argument('--output_dir', type=str, default="results",
                     help='Directory to save outputs')
    cli.add_argument('--show', action='store_true',
                     help='Show the plot interactively')
    opts = cli.parse_args()

    # Guard clause: bail out early when the input file is missing.
    if not os.path.exists(opts.json_path):
        print(f"Error: Could not find input file {opts.json_path}")
        exit(1)

    os.makedirs(opts.output_dir, exist_ok=True)

    # Derive the output base name, dropping the '_aggregated' suffix if present.
    stem = Path(opts.json_path).stem
    if stem.endswith('_aggregated'):
        stem = stem[:-len('_aggregated')]

    print(f"\nProcessing:")
    print(f"Input: {opts.json_path}")
    print(f"Output: {opts.output_dir}/{stem}_graph_visualization.png")

    try:
        outcome = create_graph_visualization(
            json_path=opts.json_path,
            output_dir=opts.output_dir,
            base_name=stem,
            save_plot=True
        )

        if outcome['success']:
            print(f"\nSuccess! Graph visualization saved to: {outcome['image_path']}")
            if opts.show:
                # NOTE(review): the figure is closed inside create_graph_visualization,
                # so plt.show() here has nothing left to display — confirm intent.
                plt.show()
        else:
            print(f"\nError: {outcome['error']}")

    except Exception as e:
        print(f"\nError during visualization: {str(e)}")
        raise
line_detection_ai.py ADDED
@@ -0,0 +1,413 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from base import BaseDetector, BaseDetectionPipeline
2
+ from utils import *
3
+ from config import (
4
+ ImageConfig,
5
+ SymbolConfig,
6
+ TagConfig,
7
+ LineConfig,
8
+ PointConfig,
9
+ JunctionConfig
10
+ )
11
+ from detectors import (
12
+ LineDetector,
13
+ PointDetector,
14
+ JunctionDetector,
15
+ SymbolDetector,
16
+ TagDetector
17
+ )
18
+ from pathlib import Path
19
+ from storage import StorageFactory
20
+ from common import DetectionResult
21
+ from detection_schema import DetectionContext, JunctionType
22
+ from typing import List, Tuple, Optional, Dict
23
+ import torch
24
+ import numpy as np
25
+ import cv2
26
+ import os
27
+ import logging
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
+
32
class DiagramDetectionPipeline:
    """
    Pipeline that runs multiple detectors (line, point, junction, etc.) on an image,
    and keeps a shared DetectionContext in memory.

    Detectors are injected via the constructor; point and junction detectors may be
    None (their stages are skipped in run()). Results are persisted through the
    injected StorageInterface.
    """

    def __init__(self,
                 tag_detector: Optional[BaseDetector],
                 symbol_detector: Optional[BaseDetector],
                 line_detector: Optional[BaseDetector],
                 point_detector: Optional[BaseDetector],
                 junction_detector: Optional[BaseDetector],
                 storage: StorageInterface,
                 debug_handler: Optional[DebugHandler] = None,
                 transformer: Optional[CoordinateTransformer] = None):
        """
        You can pass None for detectors you don't need.

        NOTE(review): run() calls self.debug_handler.track_performance(...) and the
        symbol/tag/line detectors unconditionally, so despite the Optional
        annotations those are effectively required — confirm before passing None.
        """
        # super().__init__(storage=storage, debug_handler=debug_handler)
        self.storage = storage
        self.debug_handler = debug_handler
        self.tag_detector = tag_detector
        self.symbol_detector = symbol_detector
        self.line_detector = line_detector
        self.point_detector = point_detector
        self.junction_detector = junction_detector
        # Fall back to a default transformer when none is supplied.
        self.transformer = transformer or CoordinateTransformer()

    def _load_image(self, image_path: str) -> np.ndarray:
        """Load image with validation.

        Bytes come from the storage backend and are decoded with OpenCV,
        so the result is a BGR ndarray.

        Raises:
            ValueError: if the bytes cannot be decoded into an image.
        """
        image_data = self.storage.load_file(image_path)
        image = cv2.imdecode(np.frombuffer(image_data, np.uint8), cv2.IMREAD_COLOR)
        if image is None:
            raise ValueError(f"Failed to load image from {image_path}")
        return image

    def _crop_to_roi(self, image: np.ndarray, roi: Optional[list]) -> Tuple[np.ndarray, Tuple[int, int]]:
        """Crop to ROI if provided, else return full image.

        Returns the (possibly cropped) image together with the (x_min, y_min)
        offset of the crop so detections can be mapped back to full-image
        coordinates. Currently unused by run() (the crop call is commented out).
        """
        if roi is not None and len(roi) == 4:
            x_min, y_min, x_max, y_max = roi
            return image[y_min:y_max, x_min:x_max], (x_min, y_min)
        return image, (0, 0)

    def _remove_symbol_tag_bboxes(self, image: np.ndarray, context: DetectionContext) -> np.ndarray:
        """Fill symbol & tag bounding boxes with white to avoid line detection picking them up."""
        masked = image.copy()
        # Paint every detected symbol box solid white (thickness=-1 fills).
        for sym in context.symbols.values():
            cv2.rectangle(masked,
                          (sym.bbox.xmin, sym.bbox.ymin),
                          (sym.bbox.xmax, sym.bbox.ymax),
                          (255, 255, 255),  # White
                          thickness=-1)

        # Same for detected text tags.
        for tg in context.tags.values():
            cv2.rectangle(masked,
                          (tg.bbox.xmin, tg.bbox.ymin),
                          (tg.bbox.xmax, tg.bbox.ymax),
                          (255, 255, 255),
                          thickness=-1)
        return masked

    def run(
        self,
        image_path: str,
        output_dir: str,
        config
    ) -> DetectionResult:
        """
        Main pipeline steps (in local coords):
          1) Load + crop image
          2) Detect symbols & tags
          3) Make a copy for final debug images
          4) White out symbol/tag bounding boxes
          5) Detect lines, points, junctions
          6) Save final JSON
          7) Generate debug images with various combinations

        Returns:
            DetectionResult with success flag, timing and output paths; on any
            exception, a failed DetectionResult carrying the error message.
        """
        try:
            with self.debug_handler.track_performance("total_processing"):
                # 1) Load & crop
                image = self._load_image(image_path)
                # cropped_image, roi_offset = self._crop_to_roi(image, config.roi)

                # 2) Create fresh context
                context = DetectionContext()

                # 3) Detect symbols (detectors write their results into `context`)
                with self.debug_handler.track_performance("symbol_detection"):
                    self.symbol_detector.detect(
                        image,
                        context=context,
                    )

                # 4) Detect tags
                with self.debug_handler.track_performance("tag_detection"):
                    self.tag_detector.detect(
                        image,
                        context=context,
                    )

                # Make a copy of the cropped image for final debug combos
                debug_cropped = image.copy()

                # 5) White-out symbol/tag bboxes in the original cropped image
                # so the line detector doesn't trace symbol outlines or text.
                cropped_image = self._remove_symbol_tag_bboxes(image, context)

                # 6) Detect lines
                with self.debug_handler.track_performance("line_detection"):
                    self.line_detector.detect(cropped_image, context=context)

                # 7) Detect points (optional stage)
                if self.point_detector:
                    with self.debug_handler.track_performance("point_detection"):
                        self.point_detector.detect(cropped_image, context=context)

                # 8) Detect junctions (optional stage)
                if self.junction_detector:
                    with self.debug_handler.track_performance("junction_detection"):
                        self.junction_detector.detect(cropped_image, context=context)

                # 9) Save final JSON & any final images
                output_paths = self._persist_results(output_dir, image_path, context)

                # 10) Save debug images in local coords using debug_cropped
                self._save_all_combinations(debug_cropped, context, output_dir, image_path)

                return DetectionResult(
                    success=True,
                    processing_time=self.debug_handler.metrics.get('total_processing', 0),
                    json_path=output_paths.get('json_path'),
                    image_path=output_paths.get('image_path')  # Now returning the annotated image path
                )

        except Exception as e:
            logger.error(f"Processing failed: {str(e)}")
            return DetectionResult(
                success=False,
                error=str(e)
            )

    # ------------------------------------------------
    # HELPER FUNCTIONS
    # ------------------------------------------------
    def _persist_results(self, output_dir: str, image_path: str, context: DetectionContext) -> dict:
        """Saves only JSON and line detection visualization.

        Returns:
            dict with 'json_path' and 'image_path' of the saved artifacts.
        """
        base_name = Path(image_path).stem
        # Strip the UI's 'display_' prefix so outputs use the original name.
        if base_name.startswith('display_'):
            base_name = base_name[8:]

        # Save JSON
        json_path = Path(output_dir) / f"{base_name}_detected_lines.json"
        context_json_str = context.to_json(indent=2)
        self.storage.save_file(str(json_path), context_json_str.encode('utf-8'))

        # Save line detection visualization using input image
        annotated = self._draw_objects(
            self._load_image(image_path),  # Use input image instead of output
            context,
            draw_lines=True,
            draw_points=False,
            draw_symbols=False,
            draw_junctions=False,
            draw_tags=False
        )

        # Save visualization (note: this rebinds the `image_path` parameter)
        image_path = Path(output_dir) / f"{base_name}_detected_lines.png"
        _, encoded = cv2.imencode('.png', annotated)
        self.storage.save_file(str(image_path), encoded.tobytes())

        return {
            "json_path": str(json_path),
            "image_path": str(image_path)
        }

    def _save_all_combinations(self, local_image: np.ndarray, context: DetectionContext,
                               output_dir: str, image_path: str) -> None:
        """Only save line detection visualization.

        NOTE(review): this writes the same "{base_name}_detected_lines.png" path
        as _persist_results, overwriting it with an identical rendering — confirm
        whether both calls are needed.
        """
        base_name = Path(image_path).stem
        if base_name.startswith('display_'):
            base_name = base_name[8:]

        # Only save line detection visualization
        annotated = self._draw_objects(local_image, context,
                                       draw_symbols=False,
                                       draw_tags=False,
                                       draw_lines=True,
                                       draw_points=False,
                                       draw_junctions=False)

        save_name = f"{base_name}_detected_lines.png"
        save_path = Path(output_dir) / save_name
        _, encoded = cv2.imencode('.png', annotated)
        self.storage.save_file(str(save_path), encoded.tobytes())

    def _draw_objects(self, base_image: np.ndarray, context: DetectionContext,
                      draw_lines: bool = True, draw_points: bool = True,
                      draw_symbols: bool = True, draw_junctions: bool = True,
                      draw_tags: bool = True) -> np.ndarray:
        """Draw detection results on a copy of base_image in local coords.

        Each flag toggles one overlay layer; colors are BGR (OpenCV order).
        """
        annotated = base_image.copy()

        # Lines
        if draw_lines:
            for ln in context.lines.values():
                cv2.line(annotated,
                         (ln.start.coords.x, ln.start.coords.y),
                         (ln.end.coords.x, ln.end.coords.y),
                         (0, 255, 0),  # green
                         2)

        # Points
        if draw_points:
            for pt in context.points.values():
                cv2.circle(annotated,
                           (pt.coords.x, pt.coords.y),
                           3,
                           (0, 0, 255),  # red
                           -1)

        # Symbols: bounding box plus a filled dot at the symbol center
        if draw_symbols:
            for sym in context.symbols.values():
                cv2.rectangle(annotated,
                              (sym.bbox.xmin, sym.bbox.ymin),
                              (sym.bbox.xmax, sym.bbox.ymax),
                              (255, 255, 0),  # cyan
                              2)
                cv2.circle(annotated,
                           (sym.center.x, sym.center.y),
                           4,
                           (255, 0, 255),  # magenta
                           -1)

        # Junctions: color encodes junction type (T / L / END)
        if draw_junctions:
            for jn in context.junctions.values():
                if jn.junction_type == JunctionType.T:
                    color = (0, 165, 255)  # orange
                elif jn.junction_type == JunctionType.L:
                    color = (255, 0, 255)  # magenta
                else:  # END
                    color = (0, 0, 255)  # red
                cv2.circle(annotated,
                           (jn.center.x, jn.center.y),
                           5,
                           color,
                           -1)

        # Tags: bounding box with recognized text above it
        if draw_tags:
            for tg in context.tags.values():
                cv2.rectangle(annotated,
                              (tg.bbox.xmin, tg.bbox.ymin),
                              (tg.bbox.xmax, tg.bbox.ymax),
                              (128, 0, 128),  # purple
                              2)
                cv2.putText(annotated,
                            tg.text,
                            (tg.bbox.xmin, tg.bbox.ymin - 5),
                            cv2.FONT_HERSHEY_SIMPLEX,
                            0.5,
                            (128, 0, 128),
                            1)

        return annotated

    def detect_lines(self, image_path: str, output_dir: str, config: Optional[Dict] = None) -> Dict:
        """Legacy interface for line detection.

        Builds a fresh, line-only pipeline (CPU DeepLSD) and runs it; the
        detectors configured on `self` are intentionally not used here.
        """
        storage = StorageFactory.get_storage()
        debug_handler = DebugHandler(enabled=True, storage=storage)

        line_detector = LineDetector(
            config=LineConfig(),
            model_path="models/deeplsd_md.tar",
            device=torch.device("cpu"),
            debug_handler=debug_handler
        )

        pipeline = DiagramDetectionPipeline(
            tag_detector=None,
            symbol_detector=None,
            line_detector=line_detector,
            point_detector=None,
            junction_detector=None,
            storage=storage,
            debug_handler=debug_handler
        )

        result = pipeline.run(image_path, output_dir, ImageConfig())
        return result

    def _validate_and_normalize_coordinates(self, points):
        """Validate and normalize coordinates to image space.

        Keeps only points lying inside [0, image_width] x [0, image_height],
        coercing coordinates to int and defaulting 'type'/'confidence'.

        NOTE(review): self.image_width / self.image_height are never assigned in
        this class — this method appears to rely on state set elsewhere; confirm
        before calling it.
        """
        valid_points = []
        for point in points:
            x, y = point['x'], point['y']
            # Validate coordinates are within image bounds
            if 0 <= x <= self.image_width and 0 <= y <= self.image_height:
                # Normalize coordinates if needed
                valid_points.append({
                    'x': int(x),
                    'y': int(y),
                    'type': point.get('type', 'unknown'),
                    'confidence': point.get('confidence', 1.0)
                })
        return valid_points
339
+
340
+
341
if __name__ == "__main__":
    # Manual end-to-end run of the full detection pipeline on a sample image.

    # 1) Shared infrastructure: storage backend + debug/metrics handler.
    backend = StorageFactory.get_storage()
    dbg = DebugHandler(enabled=True, storage=backend)

    # 2) DeepLSD model configuration for the line detector.
    deeplsd_conf = {
        "detect_lines": True,
        "line_detection_params": {
            "merge": True,
            "filtering": True,
            "grad_thresh": 3,
            "grad_nfa": True
        }
    }

    # 3) Per-detector configuration objects (all defaults).
    line_cfg = LineConfig()
    point_cfg = PointConfig()
    junction_cfg = JunctionConfig()
    symbol_cfg = SymbolConfig()
    tag_cfg = TagConfig()

    # ========================== Detectors ========================== #
    symbol_det = SymbolDetector(
        config=symbol_cfg,
        debug_handler=dbg
    )

    tag_det = TagDetector(
        config=tag_cfg,
        debug_handler=dbg
    )

    line_det = LineDetector(
        config=line_cfg,
        model_path="models/deeplsd_md.tar",
        model_config=deeplsd_conf,
        device=torch.device("cuda"),  # switch to "cpu" when no GPU is available
        debug_handler=dbg
    )

    point_det = PointDetector(
        config=point_cfg,
        debug_handler=dbg)

    junction_det = JunctionDetector(
        config=junction_cfg,
        debug_handler=dbg
    )

    # 4) Assemble the pipeline from the detectors above.
    pipeline = DiagramDetectionPipeline(
        tag_detector=tag_det,
        symbol_detector=symbol_det,
        line_detector=line_det,
        point_detector=point_det,
        junction_detector=junction_det,
        storage=backend,
        debug_handler=dbg
    )

    # 5) Execute on the sample drawing and report the outcome.
    result = pipeline.run(
        image_path="samples/images/0.jpg",
        output_dir="results/",
        config=ImageConfig()
    )

    if result.success:
        logger.info(f"Pipeline succeeded! See JSON at {result.json_path}")
    else:
        logger.error(f"Pipeline failed: {result.error}")
line_detectors.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import torch
4
+ from typing import Dict, List, Optional, Tuple
5
+ from loguru import logger
6
+
7
# Check if DeepLSD is available.
# DeepLSD is an optional dependency: when the import fails, the module-level
# flag below lets DeepLSDDetector refuse construction gracefully, and callers
# can fall back to the OpenCV Hough-based detector instead.
try:
    from deeplsd.models.deeplsd_inference import DeepLSD
    DEEPLSD_AVAILABLE = True
except ImportError:
    DEEPLSD_AVAILABLE = False
    logger.warning("DeepLSD not available, falling back to OpenCV")
14
+
15
+
16
class OpenCVLineDetector:
    """Fallback line detector based on OpenCV's probabilistic Hough transform."""

    def __init__(self):
        # Hough transform tuning knobs (pixel units).
        self.params = {
            'threshold': 50,
            'minLineLength': 50,
            'maxLineGap': 10
        }

    def detect(self, image: np.ndarray) -> Dict:
        """Detect straight line segments with Canny + HoughLinesP.

        Accepts a BGR or grayscale image; returns {'lines': [...]} where each
        entry has float endpoints and a constant confidence of 1.0.
        """
        # Collapse to grayscale only when a 3-channel image was supplied.
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if len(image.shape) == 3 else image

        edges = cv2.Canny(gray, 50, 150, apertureSize=3)
        segments = cv2.HoughLinesP(
            edges, 1, np.pi / 180,
            threshold=self.params['threshold'],
            minLineLength=self.params['minLineLength'],
            maxLineGap=self.params['maxLineGap']
        )

        if segments is None:
            return {'lines': []}

        # HoughLinesP yields shape (N, 1, 4): unpack each [x1, y1, x2, y2].
        return {
            'lines': [
                {
                    'x1': float(seg[0][0]),
                    'y1': float(seg[0][1]),
                    'x2': float(seg[0][2]),
                    'y2': float(seg[0][3]),
                    'confidence': 1.0
                }
                for seg in segments
            ]
        }
54
+
55
+
56
class DeepLSDDetector:
    """Line detector backed by the DeepLSD deep line-segment model.

    Requires the optional `deeplsd` package (see the DEEPLSD_AVAILABLE guard
    at module top); the constructor raises ImportError when it is missing.
    """

    def __init__(self, model_path: str):
        # Refuse construction when the optional dependency is absent.
        if not DEEPLSD_AVAILABLE:
            raise ImportError("DeepLSD is not available")

        # Prefer GPU when present; inference falls back to CPU otherwise.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = self._load_model(model_path)

    def _load_model(self, model_path: str) -> DeepLSD:
        """Load DeepLSD weights from a checkpoint and return the model in eval mode.

        Raises:
            Exception: re-raised after logging when the checkpoint cannot be loaded.
        """
        try:
            ckpt = torch.load(model_path, map_location=self.device)
            # NOTE(review): DeepLSD() is constructed without a config argument —
            # confirm the installed DeepLSD version supports that.
            model = DeepLSD()
            model.load_state_dict(ckpt['model'])
            return model.to(self.device).eval()
        except Exception as e:
            logger.error(f"Failed to load DeepLSD model: {str(e)}")
            raise

    def detect(self, image: np.ndarray) -> Dict:
        """Detect line segments; returns {'lines': [...]} (empty list on failure)."""
        try:
            # Convert to tensor: the model expects a single grayscale channel.
            if len(image.shape) == 3:
                gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
            else:
                gray = image

            # Shape [1, 1, H, W], float32, normalized to [0, 1].
            tensor = torch.tensor(gray, dtype=torch.float32, device=self.device)[None, None] / 255.0

            # Run inference
            with torch.no_grad():
                output = self.model({"image": tensor})
                lines = output["lines"][0]  # [N, 2, 2] array

            # Convert to standard format
            detections = []
            for line in lines:
                (x1, y1), (x2, y2) = line
                detections.append({
                    'x1': float(x1),
                    'y1': float(y1),
                    'x2': float(x2),
                    'y2': float(y2),
                    # NOTE(review): element [0] is used for every segment, so all
                    # detections share one confidence value — verify whether a
                    # per-line index was intended for the DeepLSD output schema.
                    'confidence': float(output.get("confidence", [1.0])[0])
                })

            return {'lines': detections}

        except Exception as e:
            logger.error(f"Error in DeepLSD detection: {str(e)}")
            return {'lines': []}
logger.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from loguru import logger
3
+ import sys
4
+
5
def get_logger(name: str):
    """Configure the global loguru logger (once) and return a bound instance.

    The first call installs two sinks:
      * stderr at INFO with a colored, structured format
      * logs/app.log at DEBUG with 500 MB rotation, 10-day retention and
        zip compression

    Subsequent calls skip reconfiguration. Fix: the original implementation
    tore down and rebuilt the sinks on every call, which reset logging for
    every module that had already obtained a logger.

    Args:
        name: Logical module name attached to every record via bind().

    Returns:
        A loguru logger bound with {'name': name}.
    """
    if not getattr(get_logger, "_configured", False):
        # Remove any existing handlers (first call only)
        logger.remove()

        # Add a new handler with custom format
        logger.add(
            sys.stderr,
            format="<green>{time:YYYY-MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>",
            level="INFO"
        )

        # Add file handler for persistent logging
        logger.add(
            "logs/app.log",
            rotation="500 MB",
            retention="10 days",
            level="DEBUG",
            compression="zip"
        )

        get_logger._configured = True

    # Create logger for the module
    return logger.bind(name=name)
pdf_processor.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import fitz # PyMuPDF
3
+ import cv2
4
+ import numpy as np
5
+ from pathlib import Path
6
+ import logging
7
+ from storage import StorageInterface
8
+ import shutil
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
class DocumentProcessor:
    """Convert input documents (PDF / PNG / JPG) into 600-DPI PNG page images.

    Pages are written through the injected StorageInterface; the output
    directory is wiped before each run.
    """

    def __init__(self, storage: StorageInterface):
        self.storage = storage
        self.target_dpi = 600  # Fixed at 600 DPI

    def clean_results_folder(self, output_dir: str):
        """Clean the results directory before processing new files.

        Removes the directory tree (logging and re-raising on failure) and
        recreates it empty.
        """
        if os.path.exists(output_dir):
            try:
                shutil.rmtree(output_dir)
                logger.info(f"Cleaned results directory: {output_dir}")
            except Exception as e:
                logger.error(f"Error cleaning results directory: {str(e)}")
                raise
        os.makedirs(output_dir, exist_ok=True)

    def process_document(self, file_path: str, output_dir: str) -> list:
        """Process a document (PDF/PNG/JPG) and return paths to processed pages.

        Raises:
            ValueError: for unsupported file extensions.
        """
        # Clean results folder first
        self.clean_results_folder(output_dir)

        file_ext = Path(file_path).suffix.lower()

        if file_ext == '.pdf':
            return self._process_pdf(file_path, output_dir)
        elif file_ext in ['.png', '.jpg', '.jpeg']:
            return self._process_image(file_path, output_dir)
        else:
            raise ValueError(f"Unsupported file format: {file_ext}")

    def _process_pdf(self, pdf_path: str, output_dir: str) -> list:
        """Rasterize every PDF page at target_dpi and save each as a PNG."""
        processed_pages = []
        base_name = Path(pdf_path).stem
        doc = None

        try:
            # Open PDF
            doc = fitz.open(pdf_path)

            for page_num in range(len(doc)):
                page = doc[page_num]

                # Render at target DPI (PDF user space is 72 DPI, hence the scale).
                pix = page.get_pixmap(matrix=fitz.Matrix(self.target_dpi / 72, self.target_dpi / 72))

                # Convert the raw pixmap buffer to an RGB numpy array.
                img = np.frombuffer(pix.samples, dtype=np.uint8).reshape(pix.height, pix.width, pix.n)
                if pix.n == 4:  # RGBA -> RGB
                    img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)

                # Save image
                output_path = os.path.join(output_dir, f"{base_name}_page_{page_num + 1}.png")
                self._save_image(img, output_path)
                processed_pages.append(output_path)

            return processed_pages

        except Exception as e:
            logger.error(f"Error processing PDF: {str(e)}")
            raise
        finally:
            # Fix: always release the PyMuPDF document handle; the original
            # implementation leaked it on both success and failure paths.
            if doc is not None:
                doc.close()

    def _process_image(self, image_path: str, output_dir: str) -> list:
        """Upscale a single raster image to the 600-DPI target and save as PNG."""
        try:
            # Read image
            img = cv2.imread(image_path)
            if img is None:
                raise ValueError(f"Could not read image: {image_path}")

            # OpenCV loads BGR; the rest of the pipeline works in RGB.
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

            # assumes the source image is 72 DPI (standard screen) — TODO confirm
            current_dpi = 72
            scale = self.target_dpi / current_dpi

            # Resize image
            img = cv2.resize(img, None, fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)

            # Save image
            base_name = Path(image_path).stem
            output_path = os.path.join(output_dir, f"{base_name}_page_1.png")
            self._save_image(img, output_path)

            return [output_path]

        except Exception as e:
            logger.error(f"Error processing image: {str(e)}")
            raise

    def _save_image(self, img: np.ndarray, output_path: str):
        """Encode an RGB image as PNG and persist it via the storage backend."""
        _, buffer = cv2.imencode('.png', cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
        self.storage.save_file(output_path, buffer.tobytes())
107
+
108
if __name__ == "__main__":
    from storage import StorageFactory

    # Manual test: rasterize a sample PDF into the results folder.
    backend = StorageFactory.get_storage()
    doc_processor = DocumentProcessor(backend)

    sample_pdf = "samples/001.pdf"
    results_dir = "results"  # Changed from "processed_pages" to "results"

    try:
        # Ensure output directory exists
        os.makedirs(results_dir, exist_ok=True)

        pages = doc_processor.process_document(
            file_path=sample_pdf,
            output_dir=results_dir
        )

        # Report what was produced and how large each page is.
        print("\nProcessing Results:")
        print(f"Output Directory: {os.path.abspath(results_dir)}")

        for page_path in pages:
            abs_path = os.path.abspath(page_path)  # kept for parity with original; currently unused
            size_mb = os.path.getsize(page_path) / (1024 * 1024)
            print(f"- {os.path.basename(page_path)} ({size_mb:.2f} MB)")

        # Total size of everything now in the output directory.
        bytes_total = sum(os.path.getsize(os.path.join(results_dir, name))
                          for name in os.listdir(results_dir))
        total_size = bytes_total / (1024 * 1024)
        print(f"\nTotal output size: {total_size:.2f} MB")

    except Exception as e:
        logger.error(f"Error processing PDF: {str(e)}")
        raise
requirements.txt ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Core dependencies
2
+ # gradio is pinned at the bottom of this file (gradio==5.15.0); listing a
+ # second 'gradio>=3.50.2' spec here makes pip fail with "Double requirement given"
3
+ numpy>=1.24.0
4
+ Pillow>=10.0.0
5
+ opencv-python==4.8.1.78
6
+ PyMuPDF==1.23.8 # for PDF processing
7
+
8
+ # OCR and Text Detection
9
+ python-doctr==0.11.0 # Latest stable version
10
+ easyocr==1.7.1
11
+ pytesseract==0.3.10
12
+
13
+ # Deep Learning
14
+ torch>=2.0.0
15
+ torchvision>=0.15.0
16
+ tensorflow==2.11.0 # Optional with torch
17
+
18
+ # Graph Processing
19
+ networkx>=3.0
20
+ matplotlib>=3.7.0
21
+
22
+ # Utilities
23
+ python-dotenv>=1.0.0
24
+ tqdm==4.66.1
25
+ loguru==0.7.2
26
+ scipy==1.11.4
27
+ pypdfium2==4.20.0
28
+ weasyprint==60.1
29
+
30
+ # Storage and Processing
31
+ azure-storage-blob==12.19.0
32
+ azure-core==1.29.5
33
+
34
+ # OCR Engines
35
+ ultralytics==8.0.0 # for YOLO models
36
+ deeplsd @ git+https://github.com/cvg/DeepLSD.git
37
+ omegaconf>=2.3.0 # Required by DeepLSD
38
+ pytlsd @ git+https://github.com/iago-suarez/pytlsd.git # Required by DeepLSD
39
+
40
+ # AI/Chat
41
+ openai>=1.0.0 # For ChatGPT integration
42
+ uuid>=1.30
43
+ shapely>=1.8.0 # for geometry operations
44
+
45
+ # Added from the code block
46
+ requests>=2.31.0
47
+
48
+ # Added from the code block
49
+ opencv-python-headless>=4.8.0
50
+
51
+ # Added from the code block
52
+ huggingface-hub>=0.19.0
53
+ transformers>=4.35.0
54
+ gradio==5.15.0
results/002_page_1_aggregated.json ADDED
The diff for this file is too large to render. See raw diff
 
results/002_page_1_detected_symbols.json ADDED
The diff for this file is too large to render. See raw diff
 
storage.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ from azure.storage.blob import BlobServiceClient
4
+ from abc import ABC, abstractmethod
5
+ import json
6
+
7
class StorageInterface(ABC):
    """Abstract file-storage backend (local disk, Azure Blob Storage, ...).

    Paths are plain strings; content is always raw ``bytes``.  Concrete
    backends must implement every abstract method below.
    """

    @abstractmethod
    def save_file(self, file_path: str, content: bytes) -> str:
        """Write *content* to *file_path* (overwriting) and return the path."""

    @abstractmethod
    def load_file(self, file_path: str) -> bytes:
        """Return the full contents of *file_path*."""

    @abstractmethod
    def list_files(self, directory: str) -> list[str]:
        """Return the files under *directory*."""

    @abstractmethod
    def file_exists(self, file_path: str) -> bool:
        """Return True when *file_path* exists in the backend."""

    @abstractmethod
    def delete_file(self, file_path: str) -> None:
        """Remove *file_path* from the backend."""

    @abstractmethod
    def create_directory(self, directory: str) -> None:
        """Create *directory* (no-op where directories are not a real concept)."""

    @abstractmethod
    def delete_directory(self, directory: str) -> None:
        """Recursively remove *directory* and everything under it."""

    @abstractmethod
    def upload(self, local_path: str, destination_path: str) -> None:
        """Copy an existing local file into the backend at *destination_path*."""

    @abstractmethod
    def append_file(self, file_path: str, content: bytes) -> None:
        """Append *content* to *file_path*, creating the file if needed."""

    @abstractmethod
    def get_modified_time(self, file_path: str) -> float:
        """Return the last-modified time of *file_path* as a UNIX timestamp."""

    @abstractmethod
    def directory_exists(self, directory: str) -> bool:
        """Return True when *directory* exists (or has content) in the backend."""

    def load_json(self, file_path):
        """Load and parse JSON file.

        BUGFIX: reads through self.load_file() instead of open(), so blob-backed
        storages can serve files that only exist remotely; the previous direct
        open() bypassed the backend entirely.  Returns None on any error.
        """
        try:
            return json.loads(self.load_file(file_path).decode('utf-8'))
        except Exception as e:
            print(f"Error loading JSON from {file_path}: {str(e)}")
            return None
62
+
63
+
64
class LocalStorage(StorageInterface):
    """StorageInterface implementation backed by the local filesystem."""

    @staticmethod
    def _ensure_parent_dir(file_path: str) -> None:
        # BUGFIX: os.path.dirname() is '' for a bare filename and
        # os.makedirs('') raises FileNotFoundError - only create a parent
        # directory when the path actually has one.
        parent = os.path.dirname(file_path)
        if parent:
            os.makedirs(parent, exist_ok=True)

    def save_file(self, file_path: str, content: bytes) -> str:
        """Write *content* to *file_path* (parents auto-created); return the path."""
        self._ensure_parent_dir(file_path)
        with open(file_path, 'wb') as f:
            f.write(content)
        return file_path

    def load_file(self, file_path: str) -> bytes:
        """Read and return the file's full contents."""
        with open(file_path, 'rb') as f:
            return f.read()

    def list_files(self, directory: str) -> list[str]:
        """Names (not full paths) of the regular files directly inside *directory*."""
        return [f for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f))]

    def file_exists(self, file_path: str) -> bool:
        """True when the path exists (file or directory)."""
        return os.path.exists(file_path)

    def delete_file(self, file_path: str) -> None:
        """Remove the file; raises FileNotFoundError when absent."""
        os.remove(file_path)

    def create_directory(self, directory: str) -> None:
        """Create the directory (and parents); no-op when it already exists."""
        os.makedirs(directory, exist_ok=True)

    def delete_directory(self, directory: str) -> None:
        """Recursively remove the directory tree."""
        shutil.rmtree(directory)

    def upload(self, local_path: str, destination_path: str) -> None:
        """Copy a local file to *destination_path*, creating parent dirs as needed."""
        self._ensure_parent_dir(destination_path)
        shutil.copy(local_path, destination_path)

    def append_file(self, file_path: str, content: bytes) -> None:
        """Append *content* to the file, creating it (and parents) if needed."""
        self._ensure_parent_dir(file_path)
        with open(file_path, 'ab') as f:
            f.write(content)

    def get_modified_time(self, file_path: str) -> float:
        """Last-modification time as a UNIX timestamp."""
        return os.path.getmtime(file_path)

    def directory_exists(self, directory: str) -> bool:
        """True when the path exists; does not verify it is actually a directory."""
        return self.file_exists(directory)
105
+
106
+
107
class BlobStorage(StorageInterface):
    """
    Writes to Azure Blob Storage, using local disk as a read cache.

    Every write goes to the blob container first and is then mirrored to the
    same relative path on local disk, so reads can usually be served locally.

    TODO: Allow configuration of temp dir instead of just using the same paths in both local and remote
    """

    def __init__(self, connection_string: str, container_name: str):
        self.blob_service_client = BlobServiceClient.from_connection_string(connection_string)
        self.container_client = self.blob_service_client.get_container_client(container_name)
        self.local_storage = LocalStorage()

    def download(self, file_path: str) -> bytes:
        """Fetch a blob's full contents from the container."""
        blob_client = self.container_client.get_blob_client(file_path)
        return blob_client.download_blob().readall()

    def sync(self, file_path: str) -> None:
        """Ensure the local cached copy of *file_path* is at least as new as the blob."""
        if not self.local_storage.file_exists(file_path):
            print(f"DEBUG: missing local version of {file_path} - downloading")
            self.local_storage.save_file(file_path, self.download(file_path))
        else:
            local_timestamp = self.local_storage.get_modified_time(file_path)
            remote_timestamp = self.get_modified_time(file_path)
            if local_timestamp < remote_timestamp:
                # We always write remotely before writing locally, so normally
                # local_timestamp >= remote_timestamp; an older local copy means
                # another writer updated the blob and we must re-download.
                # (BUGFIX: log prefix was misspelled "DBEUG".)
                print(f"DEBUG: local version of {file_path} out of date - downloading")
                self.local_storage.save_file(file_path, self.download(file_path))

    def save_file(self, file_path: str, content: bytes) -> str:
        """Upload *content* (overwriting), mirror it locally, return the path."""
        blob_client = self.container_client.get_blob_client(file_path)
        blob_client.upload_blob(content, overwrite=True)
        self.local_storage.save_file(file_path, content)
        return file_path

    def load_file(self, file_path: str) -> bytes:
        """Read via the local cache, refreshing it from the blob when stale."""
        self.sync(file_path)
        return self.local_storage.load_file(file_path)

    def list_files(self, directory: str) -> list[str]:
        """Blob names under the *directory* prefix (remote listing only)."""
        return [blob.name for blob in self.container_client.list_blobs(name_starts_with=directory)]

    def file_exists(self, file_path: str) -> bool:
        """True when the blob exists remotely (the local cache is not consulted)."""
        blob_client = self.container_client.get_blob_client(file_path)
        return blob_client.exists()

    def delete_file(self, file_path: str) -> None:
        """Delete the local cached copy and then the blob."""
        self.local_storage.delete_file(file_path)
        blob_client = self.container_client.get_blob_client(file_path)
        blob_client.delete_blob()

    def create_directory(self, directory: str) -> None:
        # Blob storage doesn't have directories, so only create it locally
        self.local_storage.create_directory(directory)

    def delete_directory(self, directory: str) -> None:
        """Remove the local directory tree and every blob under the prefix."""
        self.local_storage.delete_directory(directory)
        blobs_to_delete = self.container_client.list_blobs(name_starts_with=directory)
        for blob in blobs_to_delete:
            self.container_client.delete_blob(blob.name)

    def upload(self, local_path: str, destination_path: str) -> None:
        """Stream a local file into the container, then mirror it into the cache."""
        with open(local_path, "rb") as data:
            blob_client = self.container_client.get_blob_client(destination_path)
            blob_client.upload_blob(data, overwrite=True)
        self.local_storage.upload(local_path, destination_path)

    def append_file(self, file_path: str, content: bytes) -> None:
        """Append to an append-blob (creating it if needed) and to the local copy."""
        blob_client = self.container_client.get_blob_client(file_path)
        if not blob_client.exists():
            blob_client.create_append_blob()
        else:
            self.sync(file_path)

        blob_client.append_block(content)
        self.local_storage.append_file(file_path, content)

    def get_modified_time(self, file_path: str) -> float:
        """Blob last-modified time as a UNIX timestamp."""
        blob_client = self.container_client.get_blob_client(file_path)
        properties = blob_client.get_blob_properties()
        # Convert the UTC datetime to a UNIX timestamp
        return properties.last_modified.timestamp()

    def directory_exists(self, directory: str) -> bool:
        """True when at least one blob exists under the *directory* prefix."""
        blobs = self.container_client.list_blobs(name_starts_with=directory)
        return next(blobs, None) is not None
192
+
193
+
194
class StorageFactory:
    """Builds the storage backend selected by the STORAGE_TYPE env var."""

    @staticmethod
    def get_storage() -> StorageInterface:
        """Return a LocalStorage ('local', the default) or BlobStorage ('blob')."""
        selected = os.getenv('STORAGE_TYPE', 'local').lower()
        if selected == 'blob':
            conn = os.getenv('AZURE_STORAGE_CONNECTION_STRING')
            container = os.getenv('AZURE_STORAGE_CONTAINER_NAME')
            if not (conn and container):
                raise ValueError("Azure Blob Storage connection string and container name must be set")
            return BlobStorage(conn, container)
        if selected == 'local':
            return LocalStorage()
        raise ValueError(f"Unsupported storage type: {selected}")
208
+
symbol_detection.py ADDED
@@ -0,0 +1,454 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import json
3
+ import uuid
4
+ import os
5
+ import logging
6
+ from ultralytics import YOLO
7
+ from tqdm import tqdm
8
+ from storage import StorageInterface
9
+ import numpy as np
10
+ from typing import Tuple, List, Dict, Any
11
+
12
+ # Configure logging
13
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
14
+
15
+ # Constants
16
+ MODEL_PATHS = {
17
+ "model1": "models/Intui_SDM_41.pt",
18
+ "model2": "models/Intui_SDM_20.pt" # Add your second model path here
19
+ }
20
+ MAX_DIMENSION = 1280
21
+ CONFIDENCE_THRESHOLDS = [0.1, 0.3, 0.5, 0.7, 0.9]
22
+ TEXT_COLOR = (0, 0, 255) # Red color for text
23
+ BOX_COLOR = (255, 0, 0) # Red color for box (no transparency)
24
+ BG_COLOR = (255, 255, 255, 0.6) # Semi-transparent white for text background
25
+ THICKNESS = 1 # Thin text thickness
26
+ BOX_THICKNESS = 2 # Box line thickness
27
+ MIN_FONT_SCALE = 0.2 # Minimum font scale
28
+ MAX_FONT_SCALE = 1.0 # Maximum font scale
29
+ TEXT_PADDING = 20 # Increased padding between text elements
30
+ OVERLAP_THRESHOLD = 0.3 # Threshold for detecting text overlap
31
+
32
def preprocess_image_for_symbol_detection(image_cv: np.ndarray) -> np.ndarray:
    """Edge-enhance a BGR image for symbol detection.

    Pipeline: grayscale -> histogram equalization -> edge-preserving
    bilateral filter -> Canny edges, converted back to 3-channel BGR so it
    can be fed to the detector unchanged.
    """
    as_gray = cv2.cvtColor(image_cv, cv2.COLOR_BGR2GRAY)
    contrast_boosted = cv2.equalizeHist(as_gray)
    smoothed = cv2.bilateralFilter(contrast_boosted, 9, 75, 75)
    edge_map = cv2.Canny(smoothed, 100, 200)
    return cv2.cvtColor(edge_map, cv2.COLOR_GRAY2BGR)
40
+
41
def evaluate_detections(detections_list: List[Dict[str, Any]]) -> int:
    """Score a detection set; the current metric is simply the detection count."""
    return len(detections_list)
44
+
45
def resize_image_with_aspect_ratio(image_cv: np.ndarray, max_dimension: int) -> Tuple[np.ndarray, int, int]:
    """Shrink the image so its longer side is at most *max_dimension*.

    Aspect ratio is preserved and the image is never upscaled.  Returns the
    (possibly resized) image together with its final width and height.
    """
    height, width = image_cv.shape[:2]
    longest_side = max(width, height)
    if longest_side <= max_dimension:
        return image_cv, width, height
    scale = max_dimension / float(longest_side)
    target_w, target_h = int(width * scale), int(height * scale)
    resized = cv2.resize(image_cv, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
    return resized, target_w, target_h
56
+
57
def merge_detections(all_detections: List[Dict]) -> List[Dict]:
    """
    Merge detections from all models, keeping only the highest confidence
    detection when duplicates (same label, IoU > 0.5) are found.

    Returns a NEW list sorted by descending confidence.  BUGFIX: previously
    the caller's list was sorted in place as a side effect; we now sort a
    copy and leave the input untouched.
    """
    if not all_detections:
        return []

    # Rank by confidence so any suppressed duplicate always has the lower score.
    ranked = sorted(all_detections, key=lambda d: d['confidence'], reverse=True)
    keep = [True] * len(ranked)

    def _iou(a, b):
        """Intersection-over-Union of two [x1, y1, x2, y2] boxes."""
        ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
        ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
        inter = max(0, ix2 - ix1) * max(0, iy2 - iy1)
        union = ((a[2] - a[0]) * (a[3] - a[1])
                 + (b[2] - b[0]) * (b[3] - b[1])
                 - inter)
        return inter / union if union > 0 else 0

    # Greedy NMS restricted to detections sharing the same original label.
    for i, det in enumerate(ranked):
        if not keep[i]:
            continue
        for j in range(i + 1, len(ranked)):
            if not keep[j]:
                continue
            other = ranked[j]
            if (other['original_label'] == det['original_label'] and
                    _iou(det['bbox'], other['bbox']) > 0.5):
                # ranked[] is confidence-sorted, so j is always the weaker one.
                keep[j] = False
                logging.info(f"Removing duplicate detection of {det['original_label']} with lower confidence "
                             f"({other['confidence']:.2f} < {det['confidence']:.2f})")

    return [det for det, kept in zip(ranked, keep) if kept]
108
+
109
def calculate_font_scale(image_width: int, bbox_width: int) -> float:
    """
    Calculate an annotation font scale from the image and bbox widths.

    The scale grows with both the image width (relative to MAX_DIMENSION)
    and the bbox width (relative to the image), with floors so tiny boxes
    stay readable, and is clamped to [MIN_FONT_SCALE, MAX_FONT_SCALE].
    """
    base_scale = 0.7  # generous base for visibility
    width_factor = max(image_width / MAX_DIMENSION, 0.5)
    bbox_factor = max((bbox_width / image_width) * 6, 0.7)
    adaptive = base_scale * width_factor * bbox_factor
    return min(max(adaptive, MIN_FONT_SCALE), MAX_FONT_SCALE)
124
+
125
def check_overlap(rect1, rect2):
    """Return True when two (x1, y1, x2, y2) rectangles intersect or touch."""
    ax1, ay1, ax2, ay2 = rect1
    bx1, by1, bx2, by2 = rect2
    # De Morgan form of "not (entirely left/right/above/below)".
    return ax1 <= bx2 and bx1 <= ax2 and ay1 <= by2 and by1 <= ay2
131
+
132
def draw_annotation(
    image: np.ndarray,
    bbox: List[int],
    text: str,
    confidence: float,
    model_source: str,
    existing_annotations: List[tuple] = None
) -> None:
    """
    Draw a detection box and its label text on *image* in place.

    Candidate label positions above/below the box are tried in order and the
    first one that fits inside the image without overlapping a previously
    placed label is used; otherwise the label falls back to the right-hand
    side of the box.  *existing_annotations* accumulates placed label
    rectangles across calls so successive labels avoid each other.

    Args:
        image: BGR image, modified in place.
        bbox: [x1, y1, x2, y2] detection box in image pixels.
        text: label text (first line of the annotation).
        confidence: confidence in percent (rendered as "NN%").
        model_source: accepted for interface compatibility; not rendered.
        existing_annotations: mutable list of (x1, y1, x2, y2) label rects.
    """
    if existing_annotations is None:
        existing_annotations = []

    x1, y1, x2, y2 = bbox
    bbox_width = x2 - x1
    image_width = image.shape[1]
    image_height = image.shape[0]

    font_scale = calculate_font_scale(image_width, bbox_width)

    # Two-line annotation: label on top, confidence percentage below.
    annotation_text = f'{text}\n{confidence:.0f}%'
    lines = annotation_text.split('\n')

    # Measure each line to size the label rectangle.
    font = cv2.FONT_HERSHEY_SIMPLEX
    max_width = 0
    total_height = 0
    line_heights = []
    for line in lines:
        (width, height), baseline = cv2.getTextSize(
            line, font, font_scale, THICKNESS
        )
        max_width = max(max_width, width)
        line_height = height + baseline + TEXT_PADDING
        line_heights.append(line_height)
        total_height += line_height

    padding = TEXT_PADDING
    rect_x1 = max(0, x1 - padding)
    rect_x2 = min(image_width, x1 + max_width + padding * 2)

    # Candidate vertical placements, in order of preference.
    positions = [
        ('top', y1 - total_height - padding),
        ('bottom', y2 + padding),
        ('top_shifted', y1 - total_height - padding * 2),
        ('bottom_shifted', y2 + padding * 2)
    ]

    final_position = None
    for pos_name, y_pos in positions:
        if y_pos < 0 or y_pos + total_height > image_height:
            continue  # would fall outside the image
        rect = (rect_x1, y_pos, rect_x2, y_pos + total_height)
        if not any(check_overlap(rect, placed) for placed in existing_annotations):
            final_position = (pos_name, y_pos)
            existing_annotations.append(rect)
            break

    # Fallback: place the label to the right of the bbox.
    if final_position is None:
        rect_x1 = max(0, x1 + bbox_width + padding)
        rect_x2 = min(image_width, rect_x1 + max_width + padding * 2)
        y_pos = y1
        final_position = ('side', y_pos)
        # BUGFIX: the fallback rectangle was never recorded, so later labels
        # could be placed on top of it.  Register it like the other positions.
        existing_annotations.append((rect_x1, y_pos, rect_x2, y_pos + total_height))

    rect_y1 = final_position[1]

    # Draw bounding box (no transparency).
    cv2.rectangle(image, (x1, y1), (x2, y2), BOX_COLOR, BOX_THICKNESS)

    # Draw the text lines directly, without a background rectangle.
    text_y = rect_y1 + line_heights[0] - padding
    for i, line in enumerate(lines):
        cv2.putText(
            image,
            line,
            (rect_x1 + padding, text_y + sum(line_heights[:i])),
            font,
            font_scale,
            TEXT_COLOR,
            THICKNESS,
            cv2.LINE_AA
        )
230
+
231
def run_detection_with_optimal_threshold(
    image_path: str,
    results_dir: str = "results",
    file_name: str = "",
    apply_preprocessing: bool = False,
    resize_image: bool = True,
    storage: StorageInterface = None
) -> Tuple[str, str, str, List[int]]:
    """Run symbol detection with every configured YOLO model and merge results.

    The image is loaded through *storage* (required), optionally resized to
    MAX_DIMENSION and edge-preprocessed, and each model in MODEL_PATHS is run
    ONCE.  Each threshold in CONFIDENCE_THRESHOLDS is then applied as a pure
    filter over that model's detections and the best-scoring set (per
    evaluate_detections) is kept.  All models' detections are de-duplicated
    with merge_detections, drawn onto the original-resolution image, and an
    annotated PNG plus a JSON summary are written under *results_dir*.

    Returns:
        (annotated_image_path, json_path, log_message, diagram_bbox) on
        success, or ("Error during detection", None, None, None) on failure.
    """
    try:
        # Decode the input image via the storage backend.
        image_data = storage.load_file(image_path)
        nparr = np.frombuffer(image_data, np.uint8)
        original_image_cv = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
        image_cv = original_image_cv.copy()

        if resize_image:
            logging.info("Resizing image for detection with aspect ratio...")
            image_cv, resized_width, resized_height = resize_image_with_aspect_ratio(image_cv, MAX_DIMENSION)
        else:
            logging.info("Skipping image resizing...")
            resized_height, resized_width = original_image_cv.shape[:2]

        if apply_preprocessing:
            logging.info("Preprocessing image for symbol detection...")
            image_cv = preprocess_image_for_symbol_detection(image_cv)
        else:
            logging.info("Skipping image preprocessing for symbol detection...")

        # Scale factors from the (possibly resized) inference image back to
        # original-image pixel coordinates.  Loop-invariant, so hoisted.
        scale_x = original_image_cv.shape[1] / resized_width
        scale_y = original_image_cv.shape[0] / resized_height

        all_detections = []
        for model_name, model_path in MODEL_PATHS.items():
            logging.info(f"Running detection with model: {model_name}")
            if not model_path:
                logging.warning(f"No model path found for {model_name}")
                continue

            model = YOLO(model_path)

            # PERF FIX: run inference ONCE per model.  The previous version
            # called model.predict() again for every confidence threshold even
            # though the call was identical each time - thresholding is a pure
            # filter on the returned confidences, so 4 of the 5 passes were
            # wasted inference.
            results = model.predict(source=image_cv, imgsz=MAX_DIMENSION)

            candidates = []
            for result in results:
                for box in result.boxes:
                    confidence = float(box.conf[0])
                    x1, y1, x2, y2 = map(float, box.xyxy[0])
                    class_id = int(box.cls[0])
                    label = result.names[class_id]

                    # Map the bbox back to original-image coordinates.
                    x1, x2 = x1 * scale_x, x2 * scale_x
                    y1, y2 = y1 * scale_y, y2 * scale_y
                    x1, y1, x2, y2 = map(int, [x1, y1, x2, y2])

                    # Labels follow "<category>_<type>_<name...>" with shorter
                    # fallbacks for 2- and 1-part labels.
                    split_label = label.split('_')
                    if len(split_label) >= 3:
                        category = split_label[0]
                        type_ = split_label[1]
                        new_label = '_'.join(split_label[2:])
                    elif len(split_label) == 2:
                        category = split_label[0]
                        type_ = split_label[1]
                        new_label = split_label[1]
                    elif len(split_label) == 1:
                        category = split_label[0]
                        type_ = "Unknown"
                        new_label = split_label[0]
                    else:
                        logging.warning(f"Unexpected label format: {label}. Skipping this detection.")
                        continue

                    candidates.append({
                        "symbol_id": str(uuid.uuid4()),
                        "class_id": class_id,
                        "original_label": label,
                        "category": category,
                        "type": type_,
                        "label": new_label,
                        "confidence": confidence,
                        "bbox": [x1, y1, x2, y2],
                        "model_source": model_name
                    })

            # Keep the threshold whose filtered set scores best.
            best_detections_list = []
            best_metric = -1
            for confidence_threshold in CONFIDENCE_THRESHOLDS:
                logging.info(f"Evaluating confidence threshold: {confidence_threshold}...")
                detections_list = [d for d in candidates if d['confidence'] >= confidence_threshold]
                metric = evaluate_detections(detections_list)
                if metric > best_metric:
                    best_metric = metric
                    best_detections_list = detections_list

            all_detections.extend(best_detections_list)

        # Merge detections from all models.
        merged_detections = merge_detections(all_detections)
        logging.info(f"Total detections after merging: {len(merged_detections)}")

        # Annotate the original-resolution image in place.
        existing_annotations = []
        for det in merged_detections:
            draw_annotation(
                original_image_cv,
                det["bbox"],
                det["original_label"],
                det["confidence"] * 100,
                det["model_source"],
                existing_annotations
            )

        storage.create_directory(results_dir)
        file_name_without_extension = os.path.splitext(file_name)[0]

        # Build the output JSON: totals, per-label counts, and raw detections.
        total_detected_symbols = len(merged_detections)
        class_counts = {}
        for det in merged_detections:
            full_label = det["original_label"]
            class_counts[full_label] = class_counts.get(full_label, 0) + 1

        output_json = {
            "total_detected_symbols": total_detected_symbols,
            "details": class_counts,
            "detections": merged_detections
        }

        detection_json_path = os.path.join(
            results_dir, f'{file_name_without_extension}_detected_symbols.json'
        )
        storage.save_file(
            detection_json_path,
            json.dumps(output_json, indent=4).encode('utf-8')
        )

        # Save the annotated image as an uncompressed PNG (lossless).
        detection_image_path = os.path.join(
            results_dir, f'{file_name_without_extension}_detected_symbols.png'
        )
        _, img_encoded = cv2.imencode(
            '.png',
            original_image_cv,
            [cv2.IMWRITE_PNG_COMPRESSION, 0]
        )
        storage.save_file(detection_image_path, img_encoded.tobytes())

        # Overall extent of all detections ([0, 0, 0, 0] when there are none).
        diagram_bbox = [
            min([det['bbox'][0] for det in merged_detections], default=0),
            min([det['bbox'][1] for det in merged_detections], default=0),
            max([det['bbox'][2] for det in merged_detections], default=0),
            max([det['bbox'][3] for det in merged_detections], default=0)
        ]

        # NOTE: the old post-save upscaling of the annotated image was dead
        # code (the upscaled copy was never saved or returned) and was removed.

        return (
            detection_image_path,
            detection_json_path,
            f"Total detections after merging: {total_detected_symbols}",
            diagram_bbox
        )
    except Exception as e:
        logging.error(f"An error occurred: {e}")
        return "Error during detection", None, None, None
425
+
426
def _main() -> None:
    """Smoke-test entry point: run symbol detection on one sample drawing."""
    from storage import StorageFactory

    sample_path = "processed_pages/10219-1-DG-BC-00011.01-REV_A_page_1_text.png"
    backend = StorageFactory.get_storage()

    image_path, json_path, log_message, bbox = run_detection_with_optimal_threshold(
        sample_path,
        results_dir="results",
        file_name=os.path.basename(sample_path),
        apply_preprocessing=False,
        resize_image=True,
        storage=backend,
    )

    logging.info("Detection Image Path: %s", image_path)
    logging.info("Detection JSON Path: %s", json_path)
    logging.info("Detection Log Message: %s", log_message)
    logging.info("Diagram BBox: %s", bbox)
    logging.info("Done!")


if __name__ == "__main__":
    _main()
text_detection_combined.py ADDED
@@ -0,0 +1,563 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import io
4
+ from PIL import Image, ImageDraw, ImageFont
5
+ import numpy as np
6
+ from doctr.models import ocr_predictor
7
+ import pytesseract
8
+ import easyocr
9
+ from storage import StorageInterface
10
+ import re
11
+ import logging
12
+ from pathlib import Path
13
+ import cv2
14
+ import traceback
15
+ from typing import Tuple
16
+
17
+ # Initialize models
18
+ try:
19
+ logging.basicConfig(level=logging.INFO)
20
+ logger = logging.getLogger(__name__)
21
+ doctr_model = ocr_predictor(pretrained=True)
22
+ easyocr_reader = easyocr.Reader(['en'])
23
+ logging.info("All OCR models loaded successfully")
24
+ except Exception as e:
25
+ logging.error(f"Error loading OCR models: {e}")
26
+
27
+ # Combined patterns from all approaches
28
+ TEXT_PATTERNS = {
29
+ 'Line_Number': r"(?:\d{1,5}[-](?:[A-Z]{2,4})[-]\d{1,3})",
30
+ 'Equipment_Tag': r"(?:[A-Z]{1,3}[-][A-Z0-9]{1,4}[-]\d{1,3})",
31
+ 'Instrument_Tag': r"(?:\d{2,3}[-][A-Z]{2,4}[-]\d{2,3})",
32
+ 'Valve_Number': r"(?:[A-Z]{1,2}[-]\d{3})",
33
+ 'Pipe_Size': r"(?:\d{1,2}[\"])",
34
+ 'Flow_Direction': r"(?:FROM|TO)",
35
+ 'Service_Description': r"(?:STEAM|WATER|AIR|GAS|DRAIN)",
36
+ 'Process_Instrument': r"(?:[0-9]{2,3}(?:-[A-Z]{2,3})?-[0-9]{2,3}|[A-Z]{2,3}-[0-9]{2,3})",
37
+ 'Nozzle': r"(?:N[0-9]{1,2}|MH)",
38
+ 'Pipe_Connector': r"(?:[0-9]{1,5}|[A-Z]{1,2}[0-9]{2,5})"
39
+ }
40
+
41
def detect_text_combined(image, confidence_threshold=0.3):
    """Run Tesseract, EasyOCR and DocTR on *image* and fuse their output.

    Each raw detection is tagged with the engine that produced it,
    overlapping hits are merged, and only merged detections at or above
    *confidence_threshold* are kept, each classified via classify_text().
    """
    engines = (
        ('tesseract', detect_with_tesseract),
        ('easyocr', detect_with_easyocr),
        ('doctr', detect_with_doctr),
    )

    # Collect every engine's detections, tagged with their source.
    raw_detections = []
    for engine_name, detect in engines:
        for detection in detect(image):
            detection['source'] = engine_name
            raw_detections.append(detection)

    # Collapse overlapping detections across engines.
    merged = merge_overlapping_detections(raw_detections)

    # Keep confident detections only, classifying each one.
    classified = []
    for detection in merged:
        if detection['confidence'] < confidence_threshold:
            continue
        detection['text_type'] = classify_text(detection['text'])
        classified.append(detection)

    return classified
75
+
76
def generate_detailed_summary(results):
    """Build an aggregate report over classified OCR detections.

    Returns a dict with overall counts, per-pattern-type statistics
    (counts, per-source breakdown, average confidence, raw items),
    per-source statistics, a confidence histogram, and the flat list of
    detected items.  Each input result must carry 'text', 'text_type',
    'confidence', 'source' and 'bbox'.
    """
    sources = ('tesseract', 'easyocr', 'doctr')

    summary = {
        'total_detections': len(results),
        'by_type': {
            pattern: {
                'count': 0,
                'avg_confidence': 0.0,
                'by_source': {src: 0 for src in sources},
                'items': [],
            }
            for pattern in TEXT_PATTERNS
        },
        'by_source': {
            src: {
                'count': 0,
                'by_type': {pattern: 0 for pattern in TEXT_PATTERNS},
                'avg_confidence': 0.0,
            }
            for src in sources
        },
        'confidence_ranges': {
            '0.9-1.0': 0,
            '0.8-0.9': 0,
            '0.7-0.8': 0,
            '0.6-0.7': 0,
            '0.5-0.6': 0,
            '<0.5': 0,
        },
        'detected_items': [],
    }

    # Histogram buckets: first lower bound that the confidence reaches wins.
    confidence_buckets = (
        (0.9, '0.9-1.0'),
        (0.8, '0.8-0.9'),
        (0.7, '0.7-0.8'),
        (0.6, '0.6-0.7'),
        (0.5, '0.5-0.6'),
    )
    per_source_confidences = {src: [] for src in sources}

    for result in results:
        source = result['source']
        conf = result['confidence']
        text_type = result['text_type']

        summary['by_source'][source]['count'] += 1
        per_source_confidences[source].append(conf)

        for lower_bound, bucket in confidence_buckets:
            if conf >= lower_bound:
                summary['confidence_ranges'][bucket] += 1
                break
        else:
            summary['confidence_ranges']['<0.5'] += 1

        # Only known pattern types get per-type statistics.
        if text_type in summary['by_type']:
            type_stats = summary['by_type'][text_type]
            type_stats['count'] += 1
            type_stats['by_source'][source] += 1
            summary['by_source'][source]['by_type'][text_type] += 1
            type_stats['items'].append({
                'text': result['text'],
                'confidence': conf,
                'source': source,
                'bbox': result['bbox'],
            })

        summary['detected_items'].append({
            'text': result['text'],
            'type': text_type,
            'confidence': conf,
            'source': source,
            'bbox': result['bbox'],
        })

    # Per-source average confidences.
    for source, confs in per_source_confidences.items():
        if confs:
            summary['by_source'][source]['avg_confidence'] = sum(confs) / len(confs)

    # Per-type average confidences.
    for type_stats in summary['by_type'].values():
        if type_stats['items']:
            type_stats['avg_confidence'] = (
                sum(item['confidence'] for item in type_stats['items'])
                / len(type_stats['items'])
            )

    return summary
179
+
180
def process_drawing(image_path: str, output_dir: str, storage: StorageInterface) -> Tuple[dict, dict]:
    """Run multi-engine text detection over a drawing and persist the results.

    Runs Tesseract, EasyOCR and DocTR on the image, draws every detection on
    an annotated copy, and writes both the annotated PNG and a JSON report
    into ``output_dir``.

    Args:
        image_path: Path to image file.
        output_dir: Directory to save results.
        storage: Optional storage handler (currently unused by this function).

    Returns:
        Tuple of (paths/detections dict, summary statistics dict).

    Raises:
        ValueError: if the image cannot be read.
    """
    try:
        image = cv2.imread(image_path)
        if image is None:
            raise ValueError(f"Could not read image: {image_path}")

        # Detections are drawn on a copy so the source image stays untouched.
        annotated_image = image.copy()

        text_results = {'file_name': image_path, 'detections': []}

        # NOTE(review): 'by_type' is initialized here but never updated below —
        # presumably filled by a later classification pass; confirm.
        text_summary = {
            'total_detections': 0,
            'by_source': {
                'tesseract': {'count': 0, 'avg_confidence': 0.0},
                'easyocr': {'count': 0, 'avg_confidence': 0.0},
                'doctr': {'count': 0, 'avg_confidence': 0.0}
            },
            'by_type': {
                'equipment_tag': {'count': 0, 'avg_confidence': 0.0},
                'line_number': {'count': 0, 'avg_confidence': 0.0},
                'instrument_tag': {'count': 0, 'avg_confidence': 0.0},
                'valve_number': {'count': 0, 'avg_confidence': 0.0},
                'pipe_size': {'count': 0, 'avg_confidence': 0.0},
                'flow_direction': {'count': 0, 'avg_confidence': 0.0},
                'service_description': {'count': 0, 'avg_confidence': 0.0},
                'process_instrument': {'count': 0, 'avg_confidence': 0.0},
                'nozzle': {'count': 0, 'avg_confidence': 0.0},
                'pipe_connector': {'count': 0, 'avg_confidence': 0.0},
                'other': {'count': 0, 'avg_confidence': 0.0}
            }
        }

        # Each engine's output is tagged with its source name; order matters
        # for the resulting detection list.
        engine_outputs = [
            ('tesseract', detect_with_tesseract(image)),
            ('easyocr', detect_with_easyocr(image)),
            ('doctr', detect_with_doctr(image)),
        ]

        for source, detections in engine_outputs:
            for detection in detections:
                text_results['detections'].append({
                    'text': detection['text'],
                    'bbox': detection['bbox'],
                    'confidence': detection['confidence'],
                    'source': source
                })

                # Accumulate the raw confidence sum; converted to an average
                # once all detections are counted.
                text_summary['total_detections'] += 1
                source_stats = text_summary['by_source'][source]
                source_stats['count'] += 1
                source_stats['avg_confidence'] += detection['confidence']

                # Draw the detection box and its text on the annotated copy.
                x1, y1, x2, y2 = detection['bbox']
                cv2.rectangle(annotated_image, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
                cv2.putText(annotated_image, detection['text'], (int(x1), int(y1)-5),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)

        # Turn the accumulated confidence sums into averages.
        for source_stats in text_summary['by_source'].values():
            if source_stats['count'] > 0:
                source_stats['avg_confidence'] /= source_stats['count']

        # Output naming convention: strip any 'display_' prefix.
        base_name = Path(image_path).stem
        if base_name.startswith('display_'):
            base_name = base_name[len('display_'):]

        text_result_image_path = os.path.join(output_dir, f"{base_name}_detected_texts.png")
        text_result_json_path = os.path.join(output_dir, f"{base_name}_detected_texts.json")

        cv2.imwrite(text_result_image_path, annotated_image)

        with open(text_result_json_path, 'w', encoding='utf-8') as f:
            json.dump({
                'file_name': image_path,
                'summary': text_summary,
                'detections': text_results['detections']
            }, f, indent=4, ensure_ascii=False)

        return {
            'image_path': text_result_image_path,
            'json_path': text_result_json_path,
            'results': text_results
        }, text_summary

    except Exception as e:
        logger.error(f"Text detection error: {str(e)}")
        raise
290
+
291
def detect_with_tesseract(image):
    """Detect text using Tesseract OCR.

    Returns a list of dicts with 'text', 'bbox' ([x1, y1, x2, y2]) and
    'confidence' (scaled to 0..1); [] on any Tesseract failure.
    """
    # PSM 11 = sparse text; charset restricted to tag-style characters.
    custom_config = r'--oem 3 --psm 11 -c tessedit_char_whitelist="ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-.()" -c tessedit_write_images=true -c textord_heavy_nr=true -c textord_min_linesize=3'

    try:
        data = pytesseract.image_to_data(
            image,
            config=custom_config,
            output_type=pytesseract.Output.DICT
        )

        detections = []
        rows = zip(data['text'], data['conf'], data['left'], data['top'],
                   data['width'], data['height'])
        for raw_text, raw_conf, x, y, w, h in rows:
            conf = float(raw_conf)
            # Low threshold on purpose: technical text scores poorly.
            if conf <= 30:
                continue
            cleaned = raw_text.strip()
            if not cleaned:
                continue
            detections.append({
                'text': cleaned,
                'bbox': [x, y, x + w, y + h],
                'confidence': conf / 100.0
            })
        return detections

    except Exception as e:
        logger.error(f"Tesseract error: {str(e)}")
        return []
320
+
321
def detect_with_easyocr(image):
    """Detect text using the module-level EasyOCR reader.

    Returns a list of dicts with 'text', axis-aligned integer 'bbox'
    ([x1, y1, x2, y2]) and 'confidence'; [] if the reader is missing
    or raises.
    """
    if easyocr_reader is None:
        return []

    try:
        raw = easyocr_reader.readtext(
            np.array(image),
            paragraph=False,
            height_ths=2.0,
            width_ths=2.0,
            contrast_ths=0.2,
            text_threshold=0.5
        )

        detections = []
        for corners, text, conf in raw:
            # EasyOCR returns a 4-point polygon; collapse it to its
            # axis-aligned bounding rectangle.
            xs = [point[0] for point in corners]
            ys = [point[1] for point in corners]
            detections.append({
                'text': text,
                'bbox': [int(min(xs)), int(min(ys)), int(max(xs)), int(max(ys))],
                'confidence': conf
            })
        return detections

    except Exception as e:
        logger.error(f"EasyOCR error: {str(e)}")
        return []
351
+
352
def detect_with_doctr(image):
    """Detect text using DocTR.

    Runs the module-level ``doctr_model`` on the image and flattens the
    page/block/line/word hierarchy into a list of dicts with 'text',
    'bbox' ([x1, y1, x2, y2] in absolute pixels) and 'confidence'
    (defaults to 0.5 when DocTR omits it). Returns [] on any failure.
    """
    try:
        image_np = np.array(image)

        # Hoisted out of the word loops: the image size (and hence the
        # normalization scale) is invariant — recomputing it per word was
        # pure overhead.
        height, width = image_np.shape[:2]
        scale = np.array([width, height])

        result = doctr_model([image_np])
        doc = result.export()

        results = []
        for page in doc['pages']:
            for block in page['blocks']:
                for line in block['lines']:
                    for word in line['words']:
                        # DocTR geometry is normalized to [0, 1]; convert
                        # to absolute pixel coordinates.
                        points = np.array(word['geometry']) * scale
                        x1, y1 = points.min(axis=0)
                        x2, y2 = points.max(axis=0)

                        results.append({
                            'text': word['value'],
                            'bbox': [int(x1), int(y1), int(x2), int(y2)],
                            'confidence': word.get('confidence', 0.5)
                        })
        return results

    except Exception as e:
        logger.error(f"DocTR error: {str(e)}")
        return []
384
+
385
def merge_overlapping_detections(results, iou_threshold=0.5):
    """Collapse overlapping detections from different sources.

    Detections are grouped greedily: each unclaimed detection becomes an
    anchor, and every later unclaimed detection whose IoU with the *anchor*
    exceeds ``iou_threshold`` joins its group. The highest-confidence member
    of each group is kept.
    """
    if not results:
        return []

    def _iou(a, b):
        # Intersection rectangle of the two axis-aligned boxes.
        left = max(a[0], b[0])
        top = max(a[1], b[1])
        right = min(a[2], b[2])
        bottom = min(a[3], b[3])

        if right < left or bottom < top:
            return 0.0

        inter = (right - left) * (bottom - top)
        union = ((a[2] - a[0]) * (a[3] - a[1])
                 + (b[2] - b[0]) * (b[3] - b[1])
                 - inter)
        return inter / union if union > 0 else 0

    consumed = set()
    kept = []

    for anchor_idx, anchor in enumerate(results):
        if anchor_idx in consumed:
            continue

        consumed.add(anchor_idx)
        group = [anchor]

        for other_idx, other in enumerate(results):
            if other_idx in consumed:
                continue
            if _iou(anchor['bbox'], other['bbox']) > iou_threshold:
                group.append(other)
                consumed.add(other_idx)

        # Singleton groups pass through; otherwise keep the best-scoring one.
        if len(group) == 1:
            kept.append(group[0])
        else:
            kept.append(max(group, key=lambda d: d['confidence']))

    return kept
432
+
433
def classify_text(text):
    """Classify a text string against the known tag/number patterns.

    Returns the first matching pattern name from TEXT_PATTERNS, or
    'Unknown' for empty/falsy input or no match.
    """
    if not text:
        return 'Unknown'

    # Normalize: uppercase and strip all internal whitespace before matching.
    normalized = re.sub(r'\s+', '', text.strip().upper())

    return next(
        (name for name, pattern in TEXT_PATTERNS.items()
         if re.match(pattern, normalized)),
        'Unknown',
    )
447
+
448
def annotate_image(image, results):
    """Draw a labelled bounding box for each detection onto the image.

    Returns the (possibly RGB-converted) annotated PIL image.
    """
    # Colored boxes require an RGB image.
    if image.mode != 'RGB':
        image = image.convert('RGB')

    draw = ImageDraw.Draw(image)
    try:
        font = ImageFont.truetype("arial.ttf", 20)
    except IOError:
        # Arial not installed — fall back to PIL's built-in bitmap font.
        font = ImageFont.load_default()

    # Per-type box/label colors; unrecognized types use the 'Unknown' color.
    colors = {
        'Line_Number': "#FF0000",         # Bright Red
        'Equipment_Tag': "#00FF00",       # Bright Green
        'Instrument_Tag': "#0000FF",      # Bright Blue
        'Valve_Number': "#FFA500",        # Bright Orange
        'Pipe_Size': "#FF00FF",           # Bright Magenta
        'Process_Instrument': "#00FFFF",  # Bright Cyan
        'Nozzle': "#FFFF00",              # Yellow
        'Pipe_Connector': "#800080",      # Purple
        'Unknown': "#FF4444"              # Light Red
    }

    for detection in results:
        detected_type = detection.get('text_type', 'Unknown')
        color = colors.get(detected_type, colors['Unknown'])

        draw.rectangle(detection['bbox'], outline=color, width=3)

        label = f"{detection['text']} ({detection['confidence']:.2f})"
        if detected_type != 'Unknown':
            label += f" [{detected_type}]"

        # Label sits just above the box, on a white backdrop for readability.
        anchor = (detection['bbox'][0], detection['bbox'][1] - 20)
        draw.rectangle(draw.textbbox(anchor, label, font=font), fill="#FFFFFF")
        draw.text(anchor, label, fill=color, font=font)

    return image
495
+
496
def save_annotated_image(image, path, storage):
    """Persist the annotated image as an uncompressed PNG via the storage backend."""
    buffer = io.BytesIO()
    # optimize=False / compress_level=0: favor fidelity and speed over size.
    image.save(buffer, format='PNG', optimize=False, compress_level=0)
    storage.save_file(path, buffer.getvalue())
506
+
507
if __name__ == "__main__":
    from storage import StorageFactory
    import logging

    # Configure logging
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)

    # Initialize storage
    storage = StorageFactory.get_storage()

    # Test file paths
    file_path = "processed_pages/10219-1-DG-BC-00011.01-REV_A_page_1_text.png"
    result_path = "results"

    try:
        # Ensure result directory exists
        os.makedirs(result_path, exist_ok=True)

        # Process the drawing
        logger.info(f"Processing file: {file_path}")
        results, summary = process_drawing(file_path, result_path, storage)

        # Print detailed results
        print("\n=== DETAILED DETECTION RESULTS ===")
        print(f"\nTotal Detections: {summary['total_detections']}")

        print("\nBreakdown by Text Type:")
        print("-" * 50)
        for text_type, stats in summary['by_type'].items():
            if stats['count'] > 0:
                print(f"\n{text_type}:")
                print(f"  Count: {stats['count']}")
                print(f"  Average Confidence: {stats['avg_confidence']:.2f}")
                # BUG FIX: process_drawing's summary has no 'items' key —
                # only the richer summary builder produces it. Guard with
                # .get so this report never raises KeyError.
                items = stats.get('items', [])
                if items:
                    print("  Items:")
                    for item in items:
                        print(f"    - {item['text']} (conf: {item['confidence']:.2f}, source: {item['source']})")

        print("\nBreakdown by OCR Engine:")
        print("-" * 50)
        # BUG FIX: 'by_source' values are stats dicts, not bare counts; the
        # old code printed the dict repr instead of the detection count.
        for source, stats in summary['by_source'].items():
            print(f"{source}: {stats['count']} detections")

        # BUG FIX: 'confidence_ranges' is absent from process_drawing's
        # summary, so indexing it unconditionally crashed every run.
        confidence_ranges = summary.get('confidence_ranges', {})
        if confidence_ranges:
            print("\nConfidence Distribution:")
            print("-" * 50)
            for range_name, count in confidence_ranges.items():
                print(f"{range_name}: {count} detections")

        # Print output paths
        print("\nOutput Files:")
        print("-" * 50)
        print(f"Annotated Image: {results['image_path']}")
        print(f"JSON Results: {results['json_path']}")

    except Exception as e:
        logger.error(f"Error processing file: {e}")
        raise
utils.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import numpy as np
3
+ from contextlib import contextmanager
4
+ from loguru import logger
5
+ from typing import List, Dict, Optional, Tuple, Union
6
+ from detection_schema import BBox
7
+ from storage import StorageInterface
8
+ import cv2
9
+
10
class DebugHandler:
    """Production-grade debugging and performance tracking.

    When ``enabled`` is False every method is a no-op, so an instance can be
    left wired into production code paths at negligible cost.
    """

    def __init__(self, enabled: bool = False, storage: "StorageInterface" = None):
        # enabled: master switch for timing and artifact output.
        # storage: backend used by save_artifact (may be None).
        self.enabled = enabled
        self.storage = storage
        self.metrics = {}          # operation name -> last measured duration (s)
        self._start_time = None    # kept for backward compatibility only

    @contextmanager
    def track_performance(self, operation_name: str):
        """Context manager that records the wall-clock duration of a block.

        BUG FIX: the start time is now captured in a local variable rather
        than only on the shared ``self._start_time``, so nested or
        overlapping ``track_performance`` blocks no longer clobber each
        other's timings.
        """
        start = None
        if self.enabled:
            start = time.perf_counter()
            self._start_time = start  # preserved for any external readers
            logger.debug(f"Starting {operation_name}")

        yield

        if self.enabled and start is not None:
            duration = time.perf_counter() - start
            self.metrics[operation_name] = duration
            logger.debug(f"{operation_name} completed in {duration:.2f}s")

    def save_artifact(self, name: str, data: "Union[bytes, np.ndarray]", extension: str = "png"):
        """Persist a debug artifact under ``debug/<name>.<extension>``.

        ``data`` may be raw bytes or an image array (encoded to ``extension``
        with OpenCV first). No-op unless enabled and a storage backend is set.
        The annotation was corrected: the old ``bytes`` hint contradicted the
        explicit ndarray handling below.
        """
        if self.enabled and self.storage:
            path = f"debug/{name}.{extension}"

            # Arrays must be encoded to an image byte stream before saving.
            if isinstance(data, np.ndarray):
                success, encoded_image = cv2.imencode(f".{extension}", data)
                if not success:
                    logger.error("Failed to encode image for saving.")
                    return
                data = encoded_image.tobytes()

            self.storage.save_file(path, data)
            logger.info(f"Saved debug artifact: {path}")
49
+
50
class CoordinateTransformer:
    """Translate bounding boxes and points between global and ROI-local frames.

    ``roi`` is treated as (x_min, y_min, x_max, y_max); when it is None or
    not 4 elements long, every method is the identity.
    """

    @staticmethod
    def global_to_local_bbox(
        bbox: "Union[BBox, List[BBox]]",
        roi: "Optional[np.ndarray]"
    ) -> "Union[BBox, List[BBox]]":
        """Convert global BBox(es) to ROI-local coordinates.

        Handles both a single BBox and a list of BBoxes. BUG FIX: list input
        now returns a real list — the previous ``map(...)`` returned a
        one-shot iterator, which broke ``len()``, indexing, repeated
        iteration and the declared ``List[BBox]`` return type.
        """
        if roi is None or len(roi) != 4:
            return bbox

        x_min, y_min, _, _ = roi

        def convert(b):
            return BBox(
                xmin=b.xmin - x_min,
                ymin=b.ymin - y_min,
                xmax=b.xmax - x_min,
                ymax=b.ymax - y_min
            )

        return [convert(b) for b in bbox] if isinstance(bbox, list) else convert(bbox)

    @staticmethod
    def local_to_global_bbox(
        bbox: "Union[BBox, List[BBox]]",
        roi: "Optional[np.ndarray]"
    ) -> "Union[BBox, List[BBox]]":
        """Convert ROI-local BBox(es) to global coordinates.

        Handles both a single BBox and a list of BBoxes. Same list-vs-map
        fix as ``global_to_local_bbox``.
        """
        if roi is None or len(roi) != 4:
            return bbox

        x_min, y_min, _, _ = roi

        def convert(b):
            return BBox(
                xmin=b.xmin + x_min,
                ymin=b.ymin + y_min,
                xmax=b.xmax + x_min,
                ymax=b.ymax + y_min
            )

        return [convert(b) for b in bbox] if isinstance(bbox, list) else convert(bbox)

    # Maintain legacy tuple support if needed
    @staticmethod
    def global_to_local(
        bboxes: List[Tuple[int, int, int, int]],
        roi: "Optional[np.ndarray]"
    ) -> List[Tuple[int, int, int, int]]:
        """Legacy tuple version for backward compatibility."""
        if roi is None or len(roi) != 4:
            return bboxes

        x_min, y_min, _, _ = roi
        return [(x1 - x_min, y1 - y_min, x2 - x_min, y2 - y_min)
                for x1, y1, x2, y2 in bboxes]

    @staticmethod
    def local_to_global(
        bboxes: List[Tuple[int, int, int, int]],
        roi: "Optional[np.ndarray]"
    ) -> List[Tuple[int, int, int, int]]:
        """Legacy tuple version for backward compatibility."""
        if roi is None or len(roi) != 4:
            return bboxes

        x_min, y_min, _, _ = roi
        return [(x1 + x_min, y1 + y_min, x2 + x_min, y2 + y_min)
                for x1, y1, x2, y2 in bboxes]

    @staticmethod
    def local_to_global_point(point: Tuple[int, int], roi: "Optional[np.ndarray]") -> Tuple[int, int]:
        """Convert a single point from local to global coordinates."""
        if roi is None or len(roi) != 4:
            return point
        x_min, y_min, _, _ = roi
        return (int(point[0] + x_min), int(point[1] + y_min))