msIntui committed
Commit 9847531 (0 parents)

Initial commit: Add core files for P&ID processing
.gitattributes ADDED
@@ -0,0 +1,6 @@
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ models/* filter=lfs diff=lfs merge=lfs -text
+ chat/*.safetensors filter=lfs diff=lfs merge=lfs -text
+ chat/adapter_model.safetensors filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,33 @@
+ # Samples and large files
+ samples/
+ *.pdf
+ *.jpg
+ *.jpeg
+ *.png
+ *.zip
+
+ # Model files
+ *.pth
+ *.pt
+ *.tar
+ models/*
+ chat/*.safetensors
+
+ # Python
+ __pycache__/
+ *.py[cod]
+ *.class
+ .env
+ .venv/
+ venv/
+ ENV/
+
+ # IDE
+ .vscode/
+ .idea/
+
+ # Other
+ archive/
+ archive 2/
+ results/
+ DeepLSD/
README.md ADDED
@@ -0,0 +1,464 @@
+ ---
+ title: Intelligent_PID
+ emoji: 🔍
+ colorFrom: blue
+ colorTo: red
+ sdk: gradio
+ sdk_version: 5.3.0
+ app_file: gradioChatApp.py
+ pinned: false
+ license: mit
+ ---
+
+ # P&ID Processing with AI-Powered Graph Construction
+
+ ## Overview
+
+ This project processes P&ID (Piping and Instrumentation Diagram) images using multiple AI models for symbol detection, text recognition, and line detection. It constructs a graph representation of the diagram and provides an interactive interface for querying the diagram's contents.
+
+ ## Process Flow
+
+ ```mermaid
+ graph TD
+     subgraph "Document Input"
+         A[Upload Document] --> B[Validate File]
+         B -->|PDF/Image| C[Document Processor]
+         B -->|Invalid| ERR[Error Message]
+         C -->|PDF| D1[Extract Pages]
+         C -->|Image| D2[Direct Process]
+     end
+
+     subgraph "Image Preprocessing"
+         D1 --> E[Optimize Image]
+         D2 --> E
+         E -->|CLAHE Enhancement| E1[Contrast Enhancement]
+         E1 -->|Denoising| E2[Clean Image]
+         E2 -->|Binarization| E3[Binary Image]
+         E3 -->|Resize| E4[Normalized Image]
+     end
+
+     subgraph "Line Detection Pipeline"
+         E4 --> L1[Load DeepLSD Model]
+         L1 --> L2[Scale Image 0.1x]
+         L2 --> L3[Grayscale Conversion]
+         L3 --> L4[Model Inference]
+         L4 --> L5[Scale Coordinates]
+         L5 --> L6[Draw Lines]
+     end
+
+     subgraph "Detection Pipeline"
+         E4 --> F[Symbol Detection]
+         E4 --> G[Text Detection]
+
+         F --> S1[Load YOLO Models]
+         G --> T1[Load OCR Models]
+
+         S1 --> S2[Detect Symbols]
+         T1 --> T2[Detect Text]
+
+         S2 --> S3[Process Symbols]
+         T2 --> T3[Process Text]
+
+         L6 --> L7[Process Lines]
+     end
+
+     subgraph "Data Integration"
+         S3 --> I[Data Aggregation]
+         T3 --> I
+         L7 --> I
+         I --> J[Create Edges]
+         J --> K[Build Graph Network]
+         K --> L[Generate Knowledge Graph]
+     end
+
+     subgraph "User Interface"
+         L --> M[Interactive Visualization]
+         M --> N[Chat Interface]
+         N --> O[Query Processing]
+         O --> P[Response Generation]
+         P --> N
+     end
+
+     style A fill:#f9f,stroke:#333,stroke-width:2px
+     style F fill:#fbb,stroke:#333,stroke-width:2px
+     style G fill:#bfb,stroke:#333,stroke-width:2px
+     style I fill:#fbf,stroke:#333,stroke-width:2px
+     style N fill:#bbf,stroke:#333,stroke-width:2px
+ ```
+
+ ## Architecture
+
+ ![Project Architecture](./assets/P&ID_to_Graph.drawio.png)
+
+ ## Features
+
+ - **Multi-modal AI Processing**:
+   - Combined OCR approach using Tesseract, EasyOCR, and DocTR
+   - Symbol detection with optimized thresholds
+   - Intelligent line and connection detection
+ - **Document Processing**:
+   - Support for PDF, PNG, JPG, JPEG formats
+   - Automatic page extraction from PDFs
+   - Image optimization pipeline
+ - **Text Detection Types**:
+   - Equipment Tags
+   - Line Numbers
+   - Instrument Tags
+   - Valve Numbers
+   - Pipe Sizes
+   - Flow Directions
+   - Service Descriptions
+   - Process Instruments
+   - Nozzles
+   - Pipe Connectors
+ - **Data Integration**:
+   - Automatic edge detection
+   - Relationship mapping
+   - Confidence scoring
+   - Detailed detection statistics
+ - **User Interface**:
+   - Interactive visualization tabs
+   - Real-time processing feedback
+   - AI-powered chat interface
+   - Knowledge graph exploration
+
+ The entire process is visualized through an interactive Gradio-based UI, allowing users to upload a P&ID image, follow the detection steps, and view both the results and insights in real time.
+
+ ## Key Files
+
+ - **gradioChatApp.py**: The main Gradio app script that handles the frontend and orchestrates the overall flow.
+ - **symbol_detection.py**: Module for detecting symbols using YOLO models.
+ - **text_detection_combined.py**: Unified module for text detection using multiple OCR engines (Tesseract, EasyOCR, DocTR).
+ - **line_detection_ai.py**: Module for detecting lines and connections using AI.
+ - **data_aggregation.py**: Aggregates detected elements into a structured format.
+ - **graph_construction.py**: Constructs the graph network from aggregated data.
+ - **graph_processor.py**: Handles graph visualization and processing.
+ - **pdf_processor.py**: Handles PDF document processing and page extraction.
+
+ ## Setup and Installation
+
+ 1. Clone the repository:
+    ```bash
+    git clone https://github.com/IntuigenceAI/intui-PnID-POC.git
+    cd intui-PnID-POC
+    ```
+
+ 2. Install dependencies using uv:
+    ```bash
+    # Install uv if you haven't already
+    curl -LsSf https://astral.sh/uv/install.sh | sh
+
+    # Create and activate virtual environment
+    uv venv
+    source .venv/bin/activate  # On Windows: .venv\Scripts\activate
+
+    # Install dependencies
+    uv pip install -r requirements.txt
+    ```
+
+ 3. Download required models:
+    ```bash
+    python download_model.py  # Downloads DeepLSD model for line detection
+    ```
+
+ 4. Run the application:
+    ```bash
+    python gradioChatApp.py
+    ```
+
+ ## Models
+
+ ### Line Detection Model
+ - **DeepLSD Model**:
+   - File: deeplsd_md.tar
+   - Purpose: Line segment detection in P&ID diagrams
+   - Input Resolution: Variable (scaled to 0.1x for performance)
+   - Processing: Grayscale conversion and binary thresholding
+
+ ### Text Detection Models
+ - **Combined OCR Approach**:
+   - Tesseract OCR
+   - EasyOCR
+   - DocTR
+   - Purpose: Text recognition and classification
+
+ ### Graph Processing
+ - **NetworkX-based**:
+   - Purpose: Graph construction and analysis
+   - Features: Node linking, edge creation, path analysis
+
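The NetworkX-based graph layer above can be sketched in a few calls. This is a minimal sketch assuming `networkx` is installed; the node IDs and attributes below are illustrative, not the repo's real schema.

```python
import networkx as nx

# Toy P&ID graph: a pump connected to a vessel through a valve.
# "P-101", "V-12", "T-201" are made-up tags for illustration.
G = nx.Graph()
G.add_node("P-101", category="pump")
G.add_node("V-12", category="valve")
G.add_node("T-201", category="vessel")
G.add_edge("P-101", "V-12", kind="pipe")
G.add_edge("V-12", "T-201", kind="pipe")

# Path analysis: which elements lie between the pump and the vessel?
print(nx.shortest_path(G, "P-101", "T-201"))
```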
+ ## Updating the Environment
+
+ If you use the conda-based setup instead of uv, update the environment with:
+
+ ```bash
+ conda env update --file environment.yml --prune
+ ```
+
+ This command updates the environment according to changes made in `environment.yml`. When you're done, deactivate the environment with:
+
+ ```bash
+ conda deactivate
+ ```
+
+ ## Usage
+
+ 1. Run the application (`python gradioChatApp.py`) and open the Gradio interface.
+ 2. Upload a P&ID image through the interface.
+ 3. Follow the sequential steps of symbol, text, and line detection.
+ 4. View the generated graph and the AI agent's reasoning in the real-time chat box.
+ 5. Save and export the results if satisfactory.
+
+ ## Folder Structure
+
+ ```
+ ├── assets/
+ │   ├── AiAgent.png
+ │   └── llm.png
+ ├── gradioChatApp.py
+ ├── symbol_detection.py
+ ├── text_detection_combined.py
+ ├── line_detection_ai.py
+ ├── data_aggregation.py
+ ├── graph_construction.py
+ ├── graph_processor.py
+ ├── pdf_processor.py
+ ├── pnid_agent.py
+ ├── requirements.txt
+ ├── results/
+ └── models/
+     └── symbol_detection_model.pth
+ ```
+
+ ## /models Folder
+
+ - **models/symbol_detection_model.pth**: The pre-trained model for symbol detection in P&ID diagrams. It is crucial for detecting key symbols such as valves, instruments, and pipes. Make sure to download the model and place it in the `/models` directory before running the app.
+
+ ## Future Work
+
+ - **Advanced Symbol Recognition**: Improve symbol detection by integrating more sophisticated recognition models.
+ - **Graph Enhancement**: Introduce more complex graph structures and logic for representing the relationships between the diagram's elements.
+ - **Data Export**: Allow export in additional formats such as DEXPI-compliant XML or JSON.
+
+ # Docker Information
+
+ Basic Docker operations are covered here.
+
+ ## Building
+
+ There is a Dockerfile for each project (they have slightly different requirements).
+
+ ### `gradioChatApp.py`
+
+ Build and tag the image as follows:
+
+ ```bash
+ docker build -t exp-pnid-to-graph_chat-w-graph:0.0.4 -f Dockerfile-chatApp .
+ docker tag exp-pnid-to-graph_chat-w-graph:0.0.4 intaicr.azurecr.io/intai/exp-pnid-to-graph_chat-w-graph:0.0.4
+ ```
+
+ ## Deploying to ACR
+
+ ### `gradioChatApp.py`
+
+ ```bash
+ az login
+ az acr login --name intaicr
+ docker push intaicr.azurecr.io/intai/exp-pnid-to-graph_chat-w-graph:0.0.4
+ ```
+
+ ## Models
+
+ ### Symbol Detection Models
+ - **Intui_SDM_41.pt**: Primary model for equipment and large symbol detection
+   - Classes: Equipment, Vessels, Heat Exchangers
+   - Input Resolution: 1280x1280
+   - Confidence Threshold: 0.3-0.7 (adaptive)
+
+ - **Intui_SDM_20.pt**: Secondary model for instrument and small symbol detection
+   - Classes: Instruments, Valves, Indicators
+   - Input Resolution: 1280x1280
+   - Confidence Threshold: 0.3-0.7 (adaptive)
+
+ ### Line Detection Model
+ - **intui_LDM_01.pt**: Specialized model for line and connection detection
+   - Classes: Solid Lines, Dashed Lines
+   - Input Resolution: 1280x1280
+   - Confidence Threshold: 0.5
+
+ ### Text Detection Models
+ - **Tesseract**: v5.3.0
+   - Configuration:
+     - OEM Mode: 3 (Default)
+     - PSM Mode: 11 (Sparse text)
+     - Custom Whitelist: A-Z, 0-9, special characters
+
+ - **EasyOCR**: v1.7.1
+   - Configuration:
+     - Language: English
+     - Paragraph Mode: False
+     - Height Threshold: 2.0
+     - Width Threshold: 2.0
+     - Contrast Threshold: 0.2
+
+ - **DocTR**: v0.6.0
+   - Models:
+     - fast_base-688a8b34.pt
+     - crnn_vgg16_bn-9762b0b0.pt
+
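One way the readings from the three OCR engines above can be combined is per-region confidence voting. This is a hedged sketch, not the repo's actual merging code; all engine names, region IDs, and confidences below are made up.

```python
# Merge OCR readings: for each detected text region, keep the reading
# reported with the highest confidence across engines.
def merge_ocr(results):
    best = {}
    for engine, region, text, conf in results:
        if region not in best or conf > best[region][1]:
            best[region] = (text, conf)
    return {region: text for region, (text, _) in best.items()}

readings = [
    ("tesseract", "r1", "P-1O1", 0.61),  # misread the digit 0 as the letter O
    ("easyocr", "r1", "P-101", 0.88),
    ("doctr", "r1", "P-101", 0.84),
]
print(merge_ocr(readings))  # {'r1': 'P-101'}
```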
+ # P&ID Line Detection
+
+ A deep learning-based pipeline for detecting lines in P&ID diagrams using DeepLSD.
+
+ ## Architecture
+ ```mermaid
+ graph TD
+     A[Input Image] --> B[Line Detection]
+     B --> C[DeepLSD Model]
+     C --> D[Post-processing]
+     D --> E[Output JSON/Image]
+
+     subgraph Line Detection Pipeline
+         B --> F[Image Preprocessing]
+         F --> G[Scale Image 0.1x]
+         G --> H[Grayscale Conversion]
+         H --> C
+         C --> I[Scale Coordinates]
+         I --> J[Draw Lines]
+         J --> E
+     end
+ ```
+
+ ## Setup
+
+ ### Prerequisites
+ - Python 3.12+
+ - uv (for dependency management)
+ - Git
+ - CUDA-capable GPU (optional)
+
+ ### Installation
+
+ Installation is the same as in Setup and Installation above: clone the repository, create a virtual environment with uv, and install the dependencies from `requirements.txt`. Then download the DeepLSD model:
+
+ ```bash
+ python download_model.py
+ ```
+
+ ## Usage
+
+ 1. Run the line detection:
+    ```bash
+    python line_detection_ai.py
+    ```
+
+ The script will:
+ - Load the DeepLSD model
+ - Process input images at 0.1x scale for performance
+ - Generate line detections
+ - Save results as JSON and annotated images
+
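The 0.1x scale-down/scale-up step described above can be sketched as follows. This is a pure-Python stand-in for the OpenCV/DeepLSD code: lines are detected on the downscaled image, then endpoints are mapped back to full-resolution pixel coordinates before drawing.

```python
SCALE_FACTOR = 0.1  # matches the default scale_factor documented below

def to_full_resolution(line, scale=SCALE_FACTOR):
    """Rescale a line's endpoints from the downscaled image back to the
    original image, rounding to whole pixel coordinates."""
    return [(round(x / scale), round(y / scale)) for (x, y) in line]

# A line detected on the 0.1x image...
line_small = [(12.0, 34.0), (56.0, 78.0)]
# ...is rescaled before being drawn on the full-resolution diagram.
print(to_full_resolution(line_small))  # [(120, 340), (560, 780)]
```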
+ ## Configuration
+
+ Key parameters in `line_detection_ai.py`:
+ - `scale_factor`: Image scaling (default: 0.1)
+ - `device`: CPU/GPU selection
+ - `mask_json_paths`: Paths to text/symbol detection results
+
+ ## Input/Output
+
+ ### Input
+ - Original P&ID images
+ - Optional text/symbol detection JSON files for masking
+
+ ### Output
+ - Annotated images with detected lines
+ - JSON files containing line coordinates and metadata
+
+ ## Project Structure
+
+ ```
+ ├── line_detection_ai.py   # Main line detection script
+ ├── detectors.py           # Line detector implementation
+ ├── download_model.py      # Model download utility
+ ├── models/                # Directory for model files
+ │   └── deeplsd_md.tar     # DeepLSD model weights
+ ├── results/               # Output directory
+ └── requirements.txt       # Project dependencies
+ ```
+
+ ## Dependencies
+
+ Key dependencies:
+ - torch
+ - opencv-python
+ - numpy
+ - DeepLSD
+
+ See `requirements.txt` for the complete list.
+
+ ## Contributing
+
+ 1. Fork the repository
+ 2. Create your feature branch (`git checkout -b feature/amazing-feature`)
+ 3. Commit your changes (`git commit -m 'Add some amazing feature'`)
+ 4. Push to the branch (`git push origin feature/amazing-feature`)
+ 5. Open a Pull Request
+
+ ## License
+
+ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
+
+ ## Acknowledgments
+
+ - [DeepLSD](https://github.com/cvg/DeepLSD) for the line detection model
+ - Original P&ID processing pipeline by IntuigenceAI
+
+ # PnID Diagram Analyzer
+
+ This app analyzes PnID diagrams using AI to detect and interpret various elements.
+
+ ## Features
+ - Line detection
+ - Symbol recognition
+ - Text detection
+ - Graph construction
base.py ADDED
@@ -0,0 +1,138 @@
+ import torch
+
+ from abc import ABC, abstractmethod
+ from typing import List, Optional, Dict
+
+ import numpy as np
+ import cv2
+
+ from pathlib import Path
+ from loguru import logger
+ import json
+
+ from common import DetectionResult
+ from storage import StorageInterface
+ from utils import DebugHandler, CoordinateTransformer
+
+
+ class BaseConfig(ABC):
+     """Abstract base config for all configuration classes."""
+
+     def __post_init__(self):
+         """Ensures default values are set correctly for all configs."""
+         pass
+
+
+ class BaseDetector(ABC):
+     """Abstract base class for detection models."""
+
+     def __init__(self,
+                  config: BaseConfig,
+                  debug_handler: DebugHandler = None):
+         self.config = config
+         self.debug_handler = debug_handler or DebugHandler()
+
+     @abstractmethod
+     def _load_model(self, model_path: str):
+         """Load and return the detection model."""
+         pass
+
+     @abstractmethod
+     def detect(self, image: np.ndarray, *args, **kwargs):
+         """Run detection on an input image."""
+         pass
+
+     @abstractmethod
+     def _preprocess(self, image: np.ndarray) -> np.ndarray:
+         """Preprocess the input image before detection."""
+         pass
+
+     @abstractmethod
+     def _postprocess(self, image: np.ndarray) -> np.ndarray:
+         """Postprocess the image after detection."""
+         pass
+
+
+ class BaseDetectionPipeline(ABC):
+     """Abstract base class for detection pipelines."""
+
+     def __init__(
+         self,
+         storage: StorageInterface,
+         debug_handler=None
+     ):
+         self.storage = storage
+         self.debug_handler = debug_handler or DebugHandler()
+         self.transformer = CoordinateTransformer()
+
+     @abstractmethod
+     def process_image(
+         self,
+         image_path: str,
+         output_dir: str,
+         config
+     ) -> DetectionResult:
+         """Main processing pipeline for a single image."""
+         pass
+
+     def _apply_roi(self, image: np.ndarray, roi: np.ndarray) -> np.ndarray:
+         """Apply region-of-interest cropping."""
+         if roi is not None and len(roi) == 4:
+             x_min, y_min, x_max, y_max = roi
+             return image[y_min:y_max, x_min:x_max]
+         return image
+
+     def _adjust_coordinates(self, detections: List[Dict], roi: np.ndarray) -> List[Dict]:
+         """Shift detection coordinates from ROI space back to full-image space."""
+         if roi is None or len(roi) != 4:
+             return detections
+
+         x_offset, y_offset = roi[0], roi[1]
+         adjusted = []
+
+         for det in detections:
+             try:
+                 adjusted_bbox = [
+                     int(det["bbox"][0] + x_offset),
+                     int(det["bbox"][1] + y_offset),
+                     int(det["bbox"][2] + x_offset),
+                     int(det["bbox"][3] + y_offset)
+                 ]
+                 adjusted.append({**det, "bbox": adjusted_bbox})
+             except KeyError:
+                 logger.warning("Invalid detection format during coordinate adjustment")
+         return adjusted
+
+     def _persist_results(
+         self,
+         output_dir: str,
+         image_path: str,
+         detections: List[Dict],
+         annotated_image: Optional[np.ndarray]
+     ) -> Dict[str, str]:
+         """Save detection results and annotations."""
+         self.storage.create_directory(output_dir)
+         base_name = Path(image_path).stem
+
+         # Save JSON results
+         json_path = Path(output_dir) / f"{base_name}_lines.json"
+         self.storage.save_file(
+             str(json_path),
+             json.dumps({
+                 "solid_lines": {"lines": detections},
+                 "dashed_lines": {"lines": []}
+             }, indent=2).encode('utf-8')
+         )
+
+         # Save annotated image
+         img_path = None
+         if annotated_image is not None:
+             img_path = Path(output_dir) / f"{base_name}_annotated.jpg"
+             _, img_data = cv2.imencode('.jpg', annotated_image)
+             self.storage.save_file(str(img_path), img_data.tobytes())
+
+         return {
+             "json_path": str(json_path),
+             "image_path": str(img_path) if img_path else None
+         }
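The ROI coordinate adjustment performed by `_adjust_coordinates` above can be sketched standalone: detections made on a cropped region are shifted back into full-image coordinates. Plain lists stand in for the NumPy arrays, and the sample detection is made up.

```python
def adjust_coordinates(detections, roi):
    """Shift bboxes from ROI space to full-image space (mirror of the method above)."""
    if roi is None or len(roi) != 4:
        return detections
    x_off, y_off = roi[0], roi[1]
    return [
        {**det, "bbox": [det["bbox"][0] + x_off, det["bbox"][1] + y_off,
                         det["bbox"][2] + x_off, det["bbox"][3] + y_off]}
        for det in detections
    ]

# A detection found inside the cropped ROI starting at (500, 500)...
dets = [{"bbox": [10, 20, 30, 40], "label": "valve"}]
# ...maps to (510, 520)-(530, 540) in the original image.
print(adjust_coordinates(dets, [500, 500, 5300, 4000]))
```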
chatbot_agent.py ADDED
@@ -0,0 +1,193 @@
+ # chatbot_agent.py
+
+ import os
+ import json
+ import logging
+ import traceback
+
+ from openai import OpenAI
+
+ # Get logger
+ logger = logging.getLogger(__name__)
+
+ client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
+
+
+ def format_message(role, content):
+     """Format a message for the chatbot history."""
+     return {"role": role, "content": content}
+
+
+ def initialize_graph_prompt(graph_data):
+     """Initialize the conversation with detailed node, edge, linker, and text information."""
+     summary = graph_data["summary"]
+     summary_info = (
+         f"Symbols (represented as nodes): {summary['symbol_count']}, "
+         f"Texts: {summary['text_count']}, "
+         f"Lines: {summary['line_count']}, "
+         f"Linkers: {summary.get('linker_count', 0)}, "
+         f"Edges: {summary['edge_count']}."
+     )
+
+     # Prepare detailed node (symbol) data
+     node_details = "Nodes (symbols) in the graph include the following details:\n"
+     for symbol in graph_data["detailed_results"]["symbols"]:
+         node_details += (
+             f"Node ID: {symbol['symbol_id']}, Class ID: {symbol['class_id']}, "
+             f"Category: {symbol['category']}, Type: {symbol['type']}, "
+             f"Label: {symbol['label']}, Confidence: {symbol['confidence']}\n"
+         )
+
+     # Prepare edge data
+     edge_details = "Edges in the graph showing connections between nodes are as follows:\n"
+     for edge in graph_data["detailed_results"].get("edges", []):
+         edge_details += (
+             f"Edge ID: {edge['edge_id']}, From Node: {edge['symbol_1_id']}, "
+             f"To Node: {edge['symbol_2_id']}, Type: {edge.get('type', 'unknown')}\n"
+         )
+
+     # Prepare linker data
+     linker_details = "Linkers in the diagram are as follows:\n"
+     for linker in graph_data["detailed_results"].get("linkers", []):
+         linker_details += (
+             f"Symbol ID: {linker['symbol_id']}, Associated Text IDs: {linker.get('text_ids', [])}, "
+             f"Associated Edge IDs: {linker.get('edge_ids', [])}, Position: {linker.get('bbox', 'unknown')}\n"
+         )
+
+     # Prepare text (tag) data
+     text_details = "Text elements with associated tags in the diagram are as follows:\n"
+     for text in graph_data["detailed_results"].get("texts", []):
+         text_details += (
+             f"Text ID: {text['text_id']}, Content: {text['content']}, "
+             f"Confidence: {text['confidence']}, Position: {text['bbox']}\n"
+         )
+
+     initial_prompt = (
+         "You have access to a knowledge graph generated from a P&ID diagram. "
+         f"The summary information includes:\n{summary_info}\n\n"
+         "The detailed information about each node (symbol) in the graph is as follows:\n"
+         f"{node_details}\n"
+         "The edges connecting these nodes are as follows:\n"
+         f"{edge_details}\n"
+         "The linkers in the diagram are as follows:\n"
+         f"{linker_details}\n"
+         "The text elements and their tags in the diagram are as follows:\n"
+         f"{text_details}\n"
+         "Answer questions about specific nodes, edges, types, labels, categories, linkers, or text tags using this information."
+     )
+
+     return initial_prompt
+
+
+ def get_assistant_response(user_message, json_path):
+     """Generate a response based on the P&ID data, using rules first and OpenAI as fallback."""
+     try:
+         # Load the aggregated data
+         with open(json_path, 'r') as f:
+             data = json.load(f)
+
+         # Process the user's question
+         question = user_message.lower()
+
+         # Use rule-based responses for specific questions
+         if "valve" in question or "valves" in question:
+             valve_count = sum(1 for symbol in data.get('symbols', [])
+                               if 'class' in symbol and 'valve' in symbol['class'].lower())
+             return f"I found {valve_count} valves in this P&ID."
+
+         elif "pump" in question or "pumps" in question:
+             pump_count = sum(1 for symbol in data.get('symbols', [])
+                              if 'class' in symbol and 'pump' in symbol['class'].lower())
+             return f"I found {pump_count} pumps in this P&ID."
+
+         elif "equipment" in question or "components" in question:
+             equipment_types = {}
+             for symbol in data.get('symbols', []):
+                 if 'class' in symbol:
+                     eq_type = symbol['class']
+                     equipment_types[eq_type] = equipment_types.get(eq_type, 0) + 1
+
+             response = "Here's a summary of the equipment I found:\n"
+             for eq_type, count in equipment_types.items():
+                 response += f"- {eq_type}: {count}\n"
+             return response
+
+         # For other questions, use OpenAI
+         else:
+             # Prepare the conversation context
+             graph_data = {
+                 "summary": {
+                     "symbol_count": len(data.get('symbols', [])),
+                     "text_count": len(data.get('texts', [])),
+                     "line_count": len(data.get('lines', [])),
+                     "linker_count": len(data.get('linkers', [])),
+                     "edge_count": len(data.get('edges', [])),
+                 },
+                 "detailed_results": data
+             }
+
+             initial_prompt = initialize_graph_prompt(graph_data)
+             conversation = [
+                 {"role": "system", "content": initial_prompt},
+                 {"role": "user", "content": user_message}
+             ]
+
+             response = client.chat.completions.create(
+                 model="gpt-4-turbo",
+                 messages=conversation
+             )
+             return response.choices[0].message.content
+
+     except Exception as e:
+         logger.error(f"Error in get_assistant_response: {str(e)}")
+         logger.error(traceback.format_exc())
+         return "I apologize, but I encountered an error analyzing the P&ID data. Please try asking a different question."
+
+
+ # Testing and usage block
+ if __name__ == "__main__":
+     # Load the knowledge graph data from the JSON file
+     json_file_path = "results/0_aggregated_detections.json"
+     try:
+         with open(json_file_path, 'r') as file:
+             raw_data = json.load(file)
+     except FileNotFoundError:
+         print(f"Error: File not found at {json_file_path}")
+         raw_data = None
+     except json.JSONDecodeError:
+         print("Error: Failed to decode JSON. Please check the file format.")
+         raw_data = None
+
+     # Initialize conversation history with the assistant's welcome message
+     history = [format_message("assistant", "Hello! I am ready to answer your questions about the P&ID knowledge graph. The graph includes nodes (symbols), edges, linkers, and text tags, and I have detailed information available about each. Please ask any questions related to these elements and their connections.")]
+
+     # Print the assistant's welcome message
+     print("Assistant:", history[0]["content"])
+
+     # Individual testing options
+     if raw_data:
+         # Option 1: Test the graph prompt initialization.
+         # Wrap the raw aggregated data in the structure initialize_graph_prompt expects.
+         graph_data = {
+             "summary": {
+                 "symbol_count": len(raw_data.get('symbols', [])),
+                 "text_count": len(raw_data.get('texts', [])),
+                 "line_count": len(raw_data.get('lines', [])),
+                 "linker_count": len(raw_data.get('linkers', [])),
+                 "edge_count": len(raw_data.get('edges', [])),
+             },
+             "detailed_results": raw_data
+         }
+         print("\n--- Test: Graph Prompt Initialization ---")
+         initial_prompt = initialize_graph_prompt(graph_data)
+         print(initial_prompt)
+
+         # Option 2: Simulate a conversation with a test question
+         print("\n--- Test: Simulate Conversation ---")
+         test_question = "Can you tell me about the connections between the nodes?"
+         history.append(format_message("user", test_question))
+
+         print(f"\nUser: {test_question}")
+         response = get_assistant_response(test_question, json_file_path)
+         print("Assistant:", response)
+         history.append(format_message("assistant", response))
+
+         # Option 3: Manually input questions for interactive testing
+         while True:
+             user_question = input("\nYou: ")
+             if user_question.lower() in ["exit", "quit"]:
+                 print("Exiting chat. Goodbye!")
+                 break
+
+             history.append(format_message("user", user_question))
+             response = get_assistant_response(user_question, json_file_path)
+             print("Assistant:", response)
+             history.append(format_message("assistant", response))
+     else:
+         print("Unable to load graph data. Please check the file path and format.")
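The rule-based branch in `get_assistant_response` above reduces to a one-line count over the aggregated symbols. A standalone sketch with made-up sample data:

```python
# Count symbols whose class name mentions "valve", as in the valve branch above.
data = {"symbols": [
    {"class": "Gate Valve"},
    {"class": "Pump"},
    {"class": "Check Valve"},
]}
valve_count = sum(1 for symbol in data.get("symbols", [])
                  if "class" in symbol and "valve" in symbol["class"].lower())
print(f"I found {valve_count} valves in this P&ID.")  # I found 2 valves in this P&ID.
```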
common.py ADDED
@@ -0,0 +1,14 @@
+ from dataclasses import dataclass
+ from typing import Optional
+
+ import numpy as np
+
+
+ @dataclass
+ class DetectionResult:
+     success: bool
+     error: Optional[str] = None
+     annotated_image: Optional[np.ndarray] = None
+     processing_time: float = 0.0
+     json_path: Optional[str] = None
+     image_path: Optional[str] = None
config.py ADDED
@@ -0,0 +1,116 @@
1
+ from dataclasses import dataclass, field
2
+ from typing import List, Dict, Optional, Tuple, Union
3
+ import numpy as np
4
+ from base import BaseConfig
5
+
6
+
7
+ @dataclass
8
+ class ImageConfig(BaseConfig):
9
+ """Configuration for global image-related settings"""
10
+
11
+ roi: Optional[np.ndarray] = field(default_factory=lambda: np.array([500, 500, 5300, 4000]))
12
+ mask_json_path: str = "./text_and_symbol_bboxes.json"
13
+ save_annotations: bool = True
14
+     annotation_style: Dict = field(default_factory=lambda: {
+         'bbox_color': (255, 0, 0),
+         'line_color': (0, 255, 0),
+         'text_color': (0, 0, 255),
+         'thickness': 2,
+         'font_scale': 0.6
+     })
+
+ @dataclass
+ class SymbolConfig:
+     """Configuration for Symbol Detection"""
+     model_paths: Dict[str, str] = field(default_factory=lambda: {
+         "model1": "models/Intui_SDM_41.pt",
+         "model2": "models/Intui_SDM_20.pt"
+     })
+     # Multiple confidence thresholds to evaluate
+     confidence_thresholds: List[float] = field(default_factory=lambda: [0.1, 0.3, 0.5, 0.7, 0.9])
+     apply_preprocessing: bool = False
+     resize_image: bool = True
+     max_dimension: int = 1280
+     iou_threshold: float = 0.5
+     # Optional mapping from class_id to SymbolType
+     symbol_type_mapping: Dict[str, str] = field(default_factory=lambda: {
+         "valve": "VALVE",
+         "pump": "PUMP",
+         "sensor": "SENSOR"
+     })
+
+
+ @dataclass
+ class TagConfig(BaseConfig):
+     """Configuration for Tag Detection with OCR"""
+     confidence_threshold: float = 0.5
+     iou_threshold: float = 0.4
+     ocr_engines: List[str] = field(default_factory=lambda: ['tesseract', 'easyocr', 'doctr'])
+     text_patterns: Dict[str, str] = field(default_factory=lambda: {
+         'Line_Number': r"\d{1,5}-[A-Z]{2,4}-\d{1,3}",
+         'Equipment_Tag': r"[A-Z]{1,3}-[A-Z0-9]{1,4}-\d{1,3}",
+         'Instrument_Tag': r"\d{2,3}-[A-Z]{2,4}-\d{2,3}",
+         'Valve_Number': r"[A-Z]{1,2}-\d{3}",
+         'Pipe_Size': r"\d{1,2}\"",
+         'Flow_Direction': r"FROM|TO",
+         'Service_Description': r"STEAM|WATER|AIR|GAS|DRAIN",
+         'Process_Instrument': r"\d{2,3}(?:-[A-Z]{2,3})?-\d{2,3}|[A-Z]{2,3}-\d{2,3}",
+         'Nozzle': r"N[0-9]{1,2}|MH",
+         'Pipe_Connector': r"[0-9]{1,5}|[A-Z]{1,2}[0-9]{2,5}"
+     })
+     tesseract_config: str = r'--oem 3 --psm 11'
+     easyocr_params: Dict = field(default_factory=lambda: {
+         'paragraph': False,
+         'height_ths': 2.0,
+         'width_ths': 2.0,
+         'contrast_ths': 0.2
+     })
+
+ @dataclass
+ class LineConfig(BaseConfig):
+     """Configuration for Line Detection"""
+     threshold_distance: float = 10.0
+     expansion_factor: float = 1.1
+
+
+ @dataclass
+ class PointConfig(BaseConfig):
+     """Configuration for Point Detection"""
+     threshold_distance: float = 10.0
+
+
+ @dataclass
+ class JunctionConfig(BaseConfig):
+     """Configuration for Junction Detection"""
+     window_size: int = 21
+     radius: int = 5
+     angle_threshold_lb: float = 15.0
+     angle_threshold_ub: float = 75.0
+     expansion_factor: float = 1.1
+     mask_json_paths: List[str] = field(default_factory=list)  # Now a list
+     roi: Optional[np.ndarray] = field(default_factory=lambda: np.array([300, 500, 7500, 7000]))
+     save_annotations: bool = True
+     annotation_style: Dict = None
+
+     def __post_init__(self):
+         self.annotation_style = self.annotation_style or {
+             'bbox_color': (255, 0, 0),
+             'line_color': (0, 255, 0),
+             'text_color': (0, 0, 255),
+             'thickness': 2,
+             'font_scale': 0.6
+         }
data_aggregation_ai.py ADDED
@@ -0,0 +1,455 @@
+ from pathlib import Path
+ import json
+ import logging
+ from datetime import datetime
+ from typing import List, Dict, Optional, Tuple
+ from storage import StorageFactory
+ import uuid
+ import traceback
+
+ logger = logging.getLogger(__name__)
+
+ class DataAggregator:
+     def __init__(self, storage=None):
+         self.storage = storage or StorageFactory.get_storage()
+         self.logger = logging.getLogger(__name__)
+
+     def _parse_line_data(self, lines_data: dict) -> List[dict]:
+         """Parse line detection data with coordinate validation"""
+         parsed_lines = []
+
+         for line in lines_data.get("lines", []):
+             try:
+                 # Extract and validate line coordinates
+                 start_coords = line["start"]["coords"]
+                 end_coords = line["end"]["coords"]
+                 bbox = line["bbox"]
+
+                 # Validate coordinates
+                 if not (self._is_valid_point(start_coords) and
+                         self._is_valid_point(end_coords) and
+                         self._is_valid_bbox(bbox)):
+                     self.logger.warning(f"Invalid coordinates in line: {line['id']}")
+                     continue
+
+                 # Create parsed line with validated coordinates
+                 parsed_line = {
+                     "id": line["id"],
+                     "start_point": {
+                         "x": int(start_coords["x"]),
+                         "y": int(start_coords["y"]),
+                         "type": line["start"]["type"],
+                         "confidence": line["start"]["confidence"]
+                     },
+                     "end_point": {
+                         "x": int(end_coords["x"]),
+                         "y": int(end_coords["y"]),
+                         "type": line["end"]["type"],
+                         "confidence": line["end"]["confidence"]
+                     },
+                     "bbox": {
+                         "xmin": int(bbox["xmin"]),
+                         "ymin": int(bbox["ymin"]),
+                         "xmax": int(bbox["xmax"]),
+                         "ymax": int(bbox["ymax"])
+                     },
+                     "style": line["style"],
+                     "confidence": line["confidence"]
+                 }
+                 parsed_lines.append(parsed_line)
+
+             except Exception as e:
+                 self.logger.error(f"Error parsing line {line.get('id')}: {str(e)}")
+                 continue
+
+         return parsed_lines
+
+     def _is_valid_point(self, point: dict) -> bool:
+         """Validate point coordinates"""
+         try:
+             x, y = point.get("x"), point.get("y")
+             return (isinstance(x, (int, float)) and
+                     isinstance(y, (int, float)) and
+                     0 <= x <= 10000 and 0 <= y <= 10000)  # Adjust range as needed
+         except Exception:
+             return False
+
+     def _is_valid_bbox(self, bbox: dict) -> bool:
+         """Validate bbox coordinates"""
+         try:
+             xmin = bbox.get("xmin")
+             ymin = bbox.get("ymin")
+             xmax = bbox.get("xmax")
+             ymax = bbox.get("ymax")
+
+             return (isinstance(xmin, (int, float)) and
+                     isinstance(ymin, (int, float)) and
+                     isinstance(xmax, (int, float)) and
+                     isinstance(ymax, (int, float)) and
+                     xmin < xmax and ymin < ymax and
+                     0 <= xmin <= 10000 and 0 <= ymin <= 10000 and
+                     0 <= xmax <= 10000 and 0 <= ymax <= 10000)
+         except Exception:
+             return False
+
+     def _create_graph_data(self, lines: List[dict], symbols: List[dict], texts: List[dict]) -> Tuple[List[dict], List[dict]]:
+         """Create nodes and edges for the knowledge graph following the three-step process"""
+         nodes = []
+         edges = []
+
+         # Step 1: Create Object Nodes with their properties and center points
+         # 1a. Symbol Nodes
+         for symbol in symbols:
+             bbox = symbol["bbox"]
+             center_x = (bbox["xmin"] + bbox["xmax"]) / 2
+             center_y = (bbox["ymin"] + bbox["ymax"]) / 2
+
+             node = {
+                 "id": symbol.get("id", str(uuid.uuid4())),
+                 "type": "symbol",
+                 "category": symbol.get("category", "unknown"),
+                 "bbox": bbox,
+                 "center": {"x": center_x, "y": center_y},
+                 "confidence": symbol.get("confidence", 1.0),
+                 "properties": {
+                     "class": symbol.get("class", ""),
+                     "equipment_type": symbol.get("equipment_type", ""),
+                     "original_label": symbol.get("original_label", ""),
+                 }
+             }
+             nodes.append(node)
+
+         # 1b. Text Nodes
+         for text in texts:
+             bbox = text["bbox"]
+             center_x = (bbox["xmin"] + bbox["xmax"]) / 2
+             center_y = (bbox["ymin"] + bbox["ymax"]) / 2
+
+             node = {
+                 "id": text.get("id", str(uuid.uuid4())),
+                 "type": "text",
+                 "content": text.get("text", ""),
+                 "bbox": bbox,
+                 "center": {"x": center_x, "y": center_y},
+                 "confidence": text.get("confidence", 1.0),
+                 "properties": {
+                     "font_size": text.get("font_size"),
+                     "rotation": text.get("rotation", 0.0),
+                     "text_type": text.get("text_type", "unknown")
+                 }
+             }
+             nodes.append(node)
+
+         # Step 2: Create Junction Nodes (T/L connections)
+         junction_map = {}  # To track junctions for edge creation
+         for line in lines:
+             # Check start point
+             if line["start_point"].get("type") in ["T", "L"]:
+                 junction_id = str(uuid.uuid4())
+                 junction_node = {
+                     "id": junction_id,
+                     "type": "junction",
+                     "junction_type": line["start_point"]["type"],
+                     "coords": {
+                         "x": line["start_point"]["x"],
+                         "y": line["start_point"]["y"]
+                     },
+                     "properties": {
+                         "confidence": line["start_point"].get("confidence", 1.0)
+                     }
+                 }
+                 nodes.append(junction_node)
+                 junction_map[f"{line['start_point']['x']}_{line['start_point']['y']}"] = junction_id
+
+             # Check end point
+             if line["end_point"].get("type") in ["T", "L"]:
+                 junction_id = str(uuid.uuid4())
+                 junction_node = {
+                     "id": junction_id,
+                     "type": "junction",
+                     "junction_type": line["end_point"]["type"],
+                     "coords": {
+                         "x": line["end_point"]["x"],
+                         "y": line["end_point"]["y"]
+                     },
+                     "properties": {
+                         "confidence": line["end_point"].get("confidence", 1.0)
+                     }
+                 }
+                 nodes.append(junction_node)
+                 junction_map[f"{line['end_point']['x']}_{line['end_point']['y']}"] = junction_id
+
+         # Step 3: Create Edges with connection points and topology
+         # 3a. Line-Junction Connections
+         for line in lines:
+             line_id = line.get("id", str(uuid.uuid4()))
+             start_key = f"{line['start_point']['x']}_{line['start_point']['y']}"
+             end_key = f"{line['end_point']['x']}_{line['end_point']['y']}"
+
+             # Create edge for line itself
+             edge = {
+                 "id": line_id,
+                 "type": "line",
+                 "source": junction_map.get(start_key, str(uuid.uuid4())),
+                 "target": junction_map.get(end_key, str(uuid.uuid4())),
+                 "properties": {
+                     "style": line["style"],
+                     "confidence": line.get("confidence", 1.0),
+                     "connection_points": {
+                         "start": {"x": line["start_point"]["x"], "y": line["start_point"]["y"]},
+                         "end": {"x": line["end_point"]["x"], "y": line["end_point"]["y"]}
+                     },
+                     "bbox": line["bbox"]
+                 }
+             }
+             edges.append(edge)
+
+         # 3b. Symbol-Line Connections (based on spatial proximity)
+         for symbol in symbols:
+             symbol_center = {
+                 "x": (symbol["bbox"]["xmin"] + symbol["bbox"]["xmax"]) / 2,
+                 "y": (symbol["bbox"]["ymin"] + symbol["bbox"]["ymax"]) / 2
+             }
+
+             # Find connected lines based on proximity to endpoints
+             for line in lines:
+                 # Check if line endpoints are near symbol center
+                 for point_type in ["start_point", "end_point"]:
+                     point = line[point_type]
+                     dist = ((point["x"] - symbol_center["x"])**2 +
+                             (point["y"] - symbol_center["y"])**2)**0.5
+
+                     if dist < 50:  # Threshold for connection, adjust as needed
+                         edge = {
+                             "id": str(uuid.uuid4()),
+                             "type": "symbol_line_connection",
+                             "source": symbol["id"],
+                             "target": line["id"],
+                             "properties": {
+                                 "connection_point": {"x": point["x"], "y": point["y"]},
+                                 "connection_type": point_type,
+                                 "distance": dist
+                             }
+                         }
+                         edges.append(edge)
+
+         # 3c. Symbol-Text Associations (based on proximity and containment)
+         for text in texts:
+             text_center = {
+                 "x": (text["bbox"]["xmin"] + text["bbox"]["xmax"]) / 2,
+                 "y": (text["bbox"]["ymin"] + text["bbox"]["ymax"]) / 2
+             }
+
+             for symbol in symbols:
+                 # Check if text is near or contained within symbol
+                 if (text_center["x"] >= symbol["bbox"]["xmin"] - 20 and
+                         text_center["x"] <= symbol["bbox"]["xmax"] + 20 and
+                         text_center["y"] >= symbol["bbox"]["ymin"] - 20 and
+                         text_center["y"] <= symbol["bbox"]["ymax"] + 20):
+
+                     edge = {
+                         "id": str(uuid.uuid4()),
+                         "type": "symbol_text_association",
+                         "source": symbol["id"],
+                         "target": text["id"],
+                         "properties": {
+                             "association_type": "label",
+                             "confidence": min(symbol.get("confidence", 1.0),
+                                               text.get("confidence", 1.0))
+                         }
+                     }
+                     edges.append(edge)
+
+         # 3d. Line-Text Associations (based on proximity and alignment)
+         for text in texts:
+             text_center = {
+                 "x": (text["bbox"]["xmin"] + text["bbox"]["xmax"]) / 2,
+                 "y": (text["bbox"]["ymin"] + text["bbox"]["ymax"]) / 2
+             }
+             text_bbox = text["bbox"]
+
+             for line in lines:
+                 line_bbox = line["bbox"]
+                 line_center = {
+                     "x": (line_bbox["xmin"] + line_bbox["xmax"]) / 2,
+                     "y": (line_bbox["ymin"] + line_bbox["ymax"]) / 2
+                 }
+
+                 # Check if text is near the line (using both center and bbox)
+                 is_nearby_horizontal = (
+                     abs(text_center["y"] - line_center["y"]) < 30 and  # Vertical proximity
+                     text_bbox["xmin"] <= line_bbox["xmax"] and
+                     text_bbox["xmax"] >= line_bbox["xmin"]
+                 )
+
+                 is_nearby_vertical = (
+                     abs(text_center["x"] - line_center["x"]) < 30 and  # Horizontal proximity
+                     text_bbox["ymin"] <= line_bbox["ymax"] and
+                     text_bbox["ymax"] >= line_bbox["ymin"]
+                 )
+
+                 # Determine text type and position relative to line
+                 if is_nearby_horizontal or is_nearby_vertical:
+                     text_type = text.get("text_type", "unknown").lower()
+
+                     # Classify the text based on content and position
+                     if any(pattern in text.get("text", "").upper()
+                            for pattern in ["-", "LINE", "PIPE"]):
+                         association_type = "line_id"
+                     else:
+                         association_type = "description"
+
+                     edge = {
+                         "id": str(uuid.uuid4()),
+                         "type": "line_text_association",
+                         "source": line["id"],
+                         "target": text["id"],
+                         "properties": {
+                             "association_type": association_type,
+                             "relative_position": "horizontal" if is_nearby_horizontal else "vertical",
+                             "confidence": min(line.get("confidence", 1.0),
+                                               text.get("confidence", 1.0)),
+                             "distance": abs(text_center["y"] - line_center["y"]) if is_nearby_horizontal
+                                         else abs(text_center["x"] - line_center["x"])
+                         }
+                     }
+                     edges.append(edge)
+
+         return nodes, edges
+
+     def _validate_coordinates(self, data, data_type):
+         """Validate coordinates in the data"""
+         if not data:
+             return False
+
+         try:
+             if data_type == 'line':
+                 # Check start and end points
+                 start = data.get('start_point', {})
+                 end = data.get('end_point', {})
+                 bbox = data.get('bbox', {})
+
+                 required_fields = ['x', 'y', 'type']
+                 if not all(field in start for field in required_fields):
+                     self.logger.warning(f"Missing required fields in start_point: {start}")
+                     return False
+                 if not all(field in end for field in required_fields):
+                     self.logger.warning(f"Missing required fields in end_point: {end}")
+                     return False
+
+                 # Validate bbox coordinates
+                 if not all(key in bbox for key in ['xmin', 'ymin', 'xmax', 'ymax']):
+                     self.logger.warning(f"Invalid bbox format: {bbox}")
+                     return False
+
+                 # Check coordinate consistency
+                 if bbox['xmin'] > bbox['xmax'] or bbox['ymin'] > bbox['ymax']:
+                     self.logger.warning(f"Invalid bbox coordinates: {bbox}")
+                     return False
+
+             elif data_type in ['symbol', 'text']:
+                 bbox = data.get('bbox', {})
+                 if not all(key in bbox for key in ['xmin', 'ymin', 'xmax', 'ymax']):
+                     self.logger.warning(f"Invalid {data_type} bbox format: {bbox}")
+                     return False
+
+                 # Check coordinate consistency
+                 if bbox['xmin'] > bbox['xmax'] or bbox['ymin'] > bbox['ymax']:
+                     self.logger.warning(f"Invalid {data_type} bbox coordinates: {bbox}")
+                     return False
+
+             return True
+
+         except Exception as e:
+             self.logger.error(f"Validation error for {data_type}: {str(e)}")
+             return False
+
+     def aggregate_data(self, symbols_path: str, texts_path: str, lines_path: str) -> dict:
+         """Aggregate detection results and create graph structure"""
+         try:
+             # Load line detection results
+             lines_data = json.loads(self.storage.load_file(lines_path).decode('utf-8'))
+             lines = self._parse_line_data(lines_data)
+
+             # Load symbol detections
+             symbols = []
+             if symbols_path and Path(symbols_path).exists():
+                 symbols_data = json.loads(self.storage.load_file(symbols_path).decode('utf-8'))
+                 symbols = symbols_data.get("symbols", [])
+
+             # Load text detections
+             texts = []
+             if texts_path and Path(texts_path).exists():
+                 texts_data = json.loads(self.storage.load_file(texts_path).decode('utf-8'))
+                 texts = texts_data.get("texts", [])
+
+             # Create graph data
+             nodes, edges = self._create_graph_data(lines, symbols, texts)
+
+             # Combine all detections
+             aggregated_data = {
+                 "lines": lines,
+                 "symbols": symbols,
+                 "texts": texts,
+                 "nodes": nodes,
+                 "edges": edges,
+                 "metadata": {
+                     "timestamp": datetime.now().isoformat(),
+                     "version": "2.0"
+                 }
+             }
+
+             return aggregated_data
+
+         except Exception as e:
+             logger.error(f"Error during aggregation: {str(e)}")
+             raise
+
+ if __name__ == "__main__":
+     import os
+     from pprint import pprint
+
+     # Initialize the aggregator
+     aggregator = DataAggregator()
+
+     # Test paths (adjust these to match your results folder)
+     results_dir = "results/"
+     symbols_path = os.path.join(results_dir, "0_text_detected_symbols.json")
+     texts_path = os.path.join(results_dir, "0_text_detected_texts.json")
+     lines_path = os.path.join(results_dir, "0_text_detected_lines.json")
+
+     try:
+         # Aggregate the data
+         aggregated_data = aggregator.aggregate_data(
+             symbols_path=symbols_path,
+             texts_path=texts_path,
+             lines_path=lines_path
+         )
+
+         # Save the aggregated result
+         output_path = os.path.join(results_dir, "0_aggregated_test.json")
+         with open(output_path, 'w') as f:
+             json.dump(aggregated_data, f, indent=2)
+
+         # Print some statistics
+         print("\nAggregation Results:")
+         print(f"Number of Symbols: {len(aggregated_data['symbols'])}")
+         print(f"Number of Texts: {len(aggregated_data['texts'])}")
+         print(f"Number of Lines: {len(aggregated_data['lines'])}")
+         print(f"Number of Nodes: {len(aggregated_data['nodes'])}")
+         print(f"Number of Edges: {len(aggregated_data['edges'])}")
+
+         # Print sample of each type
+         print("\nSample Node:")
+         if aggregated_data['nodes']:
+             pprint(aggregated_data['nodes'][0])
+
+         print("\nSample Edge:")
+         if aggregated_data['edges']:
+             pprint(aggregated_data['edges'][0])
+
+         print(f"\nAggregated data saved to: {output_path}")
+
+     except Exception as e:
+         print(f"Error during testing: {str(e)}")
+         traceback.print_exc()
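
Step 3b above links a symbol to a line whenever a line endpoint falls within 50 px of the symbol's bbox center (Euclidean distance). A standalone sketch of just that rule, with hypothetical toy coordinates:

```python
# Sketch of the symbol-line proximity rule from _create_graph_data, step 3b.
# The 50 px threshold mirrors the code above; the valve bbox and endpoints
# are illustrative toy values, not taken from real detection output.

def endpoint_near_symbol(point: dict, bbox: dict, threshold: float = 50.0) -> bool:
    """True if `point` lies within `threshold` px of the bbox center."""
    cx = (bbox["xmin"] + bbox["xmax"]) / 2
    cy = (bbox["ymin"] + bbox["ymax"]) / 2
    dist = ((point["x"] - cx) ** 2 + (point["y"] - cy) ** 2) ** 0.5
    return dist <= threshold

valve_bbox = {"xmin": 100, "ymin": 100, "xmax": 140, "ymax": 140}  # center (120, 120)
print(endpoint_near_symbol({"x": 150, "y": 120}, valve_bbox))  # True  (30 px away)
print(endpoint_near_symbol({"x": 300, "y": 120}, valve_bbox))  # False (180 px away)
```

Because the check runs for both endpoints of every line, a line passing close to a symbol can produce two connection edges; deduplication would happen downstream if needed.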
detection_schema.py ADDED
@@ -0,0 +1,481 @@
+ from dataclasses import dataclass, field
+ from typing import List, Optional, Tuple, Dict
+ import uuid
+ from enum import Enum
+ import json
+ import numpy as np
+
+ # ======================== Point ======================== #
+ class ConnectionType(Enum):
+     SOLID = "solid"
+     DASHED = "dashed"
+     PHANTOM = "phantom"
+
+ @dataclass
+ class Coordinates:
+     x: int
+     y: int
+
+ @dataclass
+ class BBox:
+     xmin: int
+     ymin: int
+     xmax: int
+     ymax: int
+
+     def width(self) -> int:
+         return self.xmax - self.xmin
+
+     def height(self) -> int:
+         return self.ymax - self.ymin
+
+ class JunctionType(str, Enum):
+     T = "T"
+     L = "L"
+     END = "END"
+
+ @dataclass
+ class Point:
+     coords: Coordinates
+     bbox: BBox
+     type: JunctionType
+     confidence: float = 1.0
+     id: str = field(default_factory=lambda: str(uuid.uuid4()))
+
+
+ # # ======================== Symbol ======================== #
+ # class SymbolType(Enum):
+ #     VALVE = "valve"
+ #     PUMP = "pump"
+ #     SENSOR = "sensor"
+ #     # Add others as needed
+ #
+ class ValveSubtype(Enum):
+     GATE = "gate"
+     GLOBE = "globe"
+     BUTTERFLY = "butterfly"
+ #
+ # @dataclass
+ # class Symbol:
+ #     symbol_type: SymbolType
+ #     bbox: BBox
+ #     center: Coordinates
+ #     connections: List[Point] = field(default_factory=list)
+ #     subtype: Optional[ValveSubtype] = None
+ #     id: str = field(default_factory=lambda: str(uuid.uuid4()))
+ #     confidence: float = 0.95
+ #     model_metadata: dict = field(default_factory=dict)
+
+
+ # ======================== Symbol ======================== #
+ class SymbolType(Enum):
+     VALVE = "valve"
+     PUMP = "pump"
+     SENSOR = "sensor"
+     OTHER = "other"  # Added to handle unknown categories
+
+ @dataclass
+ class Symbol:
+     center: Coordinates
+     symbol_type: SymbolType = field(default=SymbolType.OTHER)
+     id: str = field(default_factory=lambda: str(uuid.uuid4()))
+     class_id: int = -1
+     original_label: str = ""
+     category: str = ""  # e.g., "inst"
+     type: str = ""  # e.g., "ind"
+     label: str = ""  # e.g., "Solenoid_actuator"
+     bbox: BBox = None
+     confidence: float = 0.95
+     model_source: str = ""  # e.g., "model2"
+     connections: List[Point] = field(default_factory=list)
+     subtype: Optional[ValveSubtype] = None
+     model_metadata: dict = field(default_factory=dict)
+
+     def __post_init__(self):
+         """Handle any additional post-processing after initialization."""
+         # Ensure bbox is a BBox object
+         if isinstance(self.bbox, list) and len(self.bbox) == 4:
+             self.bbox = BBox(*self.bbox)
+
+
+ # ======================== Line ======================== #
+ @dataclass
+ class LineStyle:
+     connection_type: ConnectionType
+     stroke_width: int = 2
+     color: str = "#000000"  # CSS-style colors
+
+ @dataclass
+ class Line:
+     start: Point
+     end: Point
+     bbox: BBox
+     id: str = field(default_factory=lambda: str(uuid.uuid4()))
+     style: LineStyle = field(default_factory=lambda: LineStyle(ConnectionType.SOLID))
+     confidence: float = 0.90
+     topological_links: List[str] = field(default_factory=list)  # Linked symbols/junctions
+
+
+ # ======================== Junction ======================== #
+ class JunctionType(str, Enum):
+     T = "T"
+     L = "L"
+     END = "END"
+
+ @dataclass
+ class JunctionProperties:
+     flow_direction: Optional[str] = None  # "in", "out"
+     pressure: Optional[float] = None  # kPa
+
+ @dataclass
+ class Junction:
+     center: Coordinates
+     junction_type: JunctionType
+     id: str = field(default_factory=lambda: str(uuid.uuid4()))
+     properties: JunctionProperties = field(default_factory=JunctionProperties)
+     connected_lines: List[str] = field(default_factory=list)  # Line IDs
+
+
+ # # ======================== Tag ======================== #
+ # @dataclass
+ # class Tag:
+ #     text: str
+ #     bbox: BBox
+ #     associated_element: str  # ID of linked symbol/line
+ #     id: str = field(default_factory=lambda: str(uuid.uuid4()))
+ #     font_size: int = 12
+ #     rotation: float = 0.0  # Degrees
+
+ @dataclass
+ class Tag:
+     text: str
+     bbox: BBox
+     confidence: float = 1.0
+     source: str = ""  # e.g., "easyocr"
+     text_type: str = "Unknown"  # e.g., "Unknown", could be refined later
+     id: str = field(default_factory=lambda: str(uuid.uuid4()))
+     associated_element: Optional[str] = None  # ID of linked symbol/line (can be None)
+     font_size: int = 12
+     rotation: float = 0.0  # Degrees
+
+     def __post_init__(self):
+         """Ensure bbox is properly converted."""
+         if isinstance(self.bbox, list) and len(self.bbox) == 4:
+             self.bbox = BBox(*self.bbox)
+
+ # ----------------------------
+ # DETECTION CONTEXT
+ # ----------------------------
+
+ @dataclass
+ class DetectionContext:
+     """
+     In-memory container for all detected elements (lines, points, symbols, junctions, tags).
+     Each element is stored in a dict keyed by 'id' for quick lookup and update.
+     """
+     lines: Dict[str, Line] = field(default_factory=dict)
+     points: Dict[str, Point] = field(default_factory=dict)
+     symbols: Dict[str, Symbol] = field(default_factory=dict)
+     junctions: Dict[str, Junction] = field(default_factory=dict)
+     tags: Dict[str, Tag] = field(default_factory=dict)
+
+     # -------------------------
+     # 1) ADD / GET / REMOVE
+     # -------------------------
+     def add_line(self, line: Line) -> None:
+         self.lines[line.id] = line
+
+     def get_line(self, line_id: str) -> Optional[Line]:
+         return self.lines.get(line_id)
+
+     def remove_line(self, line_id: str) -> None:
+         self.lines.pop(line_id, None)
+
+     def add_point(self, point: Point) -> None:
+         self.points[point.id] = point
+
+     def get_point(self, point_id: str) -> Optional[Point]:
+         return self.points.get(point_id)
+
+     def remove_point(self, point_id: str) -> None:
+         self.points.pop(point_id, None)
+
+     def add_symbol(self, symbol: Symbol) -> None:
+         self.symbols[symbol.id] = symbol
+
+     def get_symbol(self, symbol_id: str) -> Optional[Symbol]:
+         return self.symbols.get(symbol_id)
+
+     def remove_symbol(self, symbol_id: str) -> None:
+         self.symbols.pop(symbol_id, None)
+
+     def add_junction(self, junction: Junction) -> None:
+         self.junctions[junction.id] = junction
+
+     def get_junction(self, junction_id: str) -> Optional[Junction]:
+         return self.junctions.get(junction_id)
+
+     def remove_junction(self, junction_id: str) -> None:
+         self.junctions.pop(junction_id, None)
+
+     def add_tag(self, tag: Tag) -> None:
+         self.tags[tag.id] = tag
+
+     def get_tag(self, tag_id: str) -> Optional[Tag]:
+         return self.tags.get(tag_id)
+
+     def remove_tag(self, tag_id: str) -> None:
+         self.tags.pop(tag_id, None)
+
+     # -------------------------
+     # 2) SERIALIZATION: to_dict / from_dict
+     # -------------------------
+     def to_dict(self) -> dict:
+         """Convert all stored objects into a JSON-serializable dictionary."""
+         return {
+             "lines": [self._line_to_dict(line) for line in self.lines.values()],
+             "points": [self._point_to_dict(pt) for pt in self.points.values()],
+             "symbols": [self._symbol_to_dict(sym) for sym in self.symbols.values()],
+             "junctions": [self._junction_to_dict(jn) for jn in self.junctions.values()],
+             "tags": [self._tag_to_dict(tg) for tg in self.tags.values()]
+         }
+
+     @classmethod
+     def from_dict(cls, data: dict) -> "DetectionContext":
+         """
+         Create a new DetectionContext from a dictionary structure (e.g. loaded from JSON).
+         """
+         context = cls()
+
+         # Points
+         for pt_dict in data.get("points", []):
+             pt_obj = cls._point_from_dict(pt_dict)
+             context.add_point(pt_obj)
+
+         # Lines
+         for ln_dict in data.get("lines", []):
+             ln_obj = cls._line_from_dict(ln_dict)
+             context.add_line(ln_obj)
+
+         # Symbols
+         for sym_dict in data.get("symbols", []):
+             sym_obj = cls._symbol_from_dict(sym_dict)
+             context.add_symbol(sym_obj)
+
+         # Junctions
+         for jn_dict in data.get("junctions", []):
+             jn_obj = cls._junction_from_dict(jn_dict)
+             context.add_junction(jn_obj)
+
+         # Tags
+         for tg_dict in data.get("tags", []):
+             tg_obj = cls._tag_from_dict(tg_dict)
+             context.add_tag(tg_obj)
+
+         return context
+
+     # -------------------------
+     # 3) HELPER METHODS FOR (DE)SERIALIZATION
+     # -------------------------
+     @staticmethod
+     def _bbox_to_dict(bbox: BBox) -> dict:
+         return {
+             "xmin": bbox.xmin,
+             "ymin": bbox.ymin,
+             "xmax": bbox.xmax,
+             "ymax": bbox.ymax
+         }
+
+     @staticmethod
+     def _bbox_from_dict(d: dict) -> BBox:
+         return BBox(
+             xmin=d["xmin"],
+             ymin=d["ymin"],
+             xmax=d["xmax"],
+             ymax=d["ymax"]
+         )
+
+     @staticmethod
+     def _coords_to_dict(coords: Coordinates) -> dict:
+         return {
+             "x": coords.x,
+             "y": coords.y
+         }
+
+     @staticmethod
+     def _coords_from_dict(d: dict) -> Coordinates:
+         return Coordinates(x=d["x"], y=d["y"])
+
+     @staticmethod
+     def _line_style_to_dict(style: LineStyle) -> dict:
+         return {
+             "connection_type": style.connection_type.value,
+             "stroke_width": style.stroke_width,
+             "color": style.color
+         }
+
+     @staticmethod
+     def _line_style_from_dict(d: dict) -> LineStyle:
+         return LineStyle(
+             connection_type=ConnectionType(d["connection_type"]),
+             stroke_width=d.get("stroke_width", 2),
+             color=d.get("color", "#000000")
+         )
+
+     @staticmethod
+     def _point_to_dict(pt: Point) -> dict:
+         return {
+             "id": pt.id,
+             "coords": DetectionContext._coords_to_dict(pt.coords),
+             "bbox": DetectionContext._bbox_to_dict(pt.bbox),
+             "type": pt.type.value,
+             "confidence": pt.confidence
+         }
+
+     @staticmethod
+     def _point_from_dict(d: dict) -> Point:
+         return Point(
+             id=d["id"],
+             coords=DetectionContext._coords_from_dict(d["coords"]),
+             bbox=DetectionContext._bbox_from_dict(d["bbox"]),
+             type=JunctionType(d["type"]),
+             confidence=d.get("confidence", 1.0)
+         )
+
+     @staticmethod
+     def _line_to_dict(ln: Line) -> dict:
+         return {
+             "id": ln.id,
+             "start": DetectionContext._point_to_dict(ln.start),
+             "end": DetectionContext._point_to_dict(ln.end),
+             "bbox": DetectionContext._bbox_to_dict(ln.bbox),
+             "style": DetectionContext._line_style_to_dict(ln.style),
+             "confidence": ln.confidence,
+             "topological_links": ln.topological_links
+         }
+
+     @staticmethod
+     def _line_from_dict(d: dict) -> Line:
+         return Line(
+             id=d["id"],
+             start=DetectionContext._point_from_dict(d["start"]),
+             end=DetectionContext._point_from_dict(d["end"]),
+             bbox=DetectionContext._bbox_from_dict(d["bbox"]),
+             style=DetectionContext._line_style_from_dict(d["style"]),
+             confidence=d.get("confidence", 0.90),
+             topological_links=d.get("topological_links", [])
+         )
+
+     @staticmethod
+     def _symbol_to_dict(sym: Symbol) -> dict:
+         return {
+             "id": sym.id,
+             "symbol_type": sym.symbol_type.value,
+             "bbox": DetectionContext._bbox_to_dict(sym.bbox),
+             "center": DetectionContext._coords_to_dict(sym.center),
+             "connections": [DetectionContext._point_to_dict(p) for p in sym.connections],
+             "subtype": sym.subtype.value if sym.subtype else None,
+             "confidence": sym.confidence,
+             "model_metadata": sym.model_metadata
+         }
+
+     @staticmethod
+     def _symbol_from_dict(d: dict) -> Symbol:
+         return Symbol(
+             id=d["id"],
+             symbol_type=SymbolType(d["symbol_type"]),
+             bbox=DetectionContext._bbox_from_dict(d["bbox"]),
+             center=DetectionContext._coords_from_dict(d["center"]),
+             connections=[DetectionContext._point_from_dict(p) for p in d.get("connections", [])],
+             subtype=ValveSubtype(d["subtype"]) if d.get("subtype") else None,
+             confidence=d.get("confidence", 0.95),
+             model_metadata=d.get("model_metadata", {})
+         )
+
+     @staticmethod
+     def _junction_props_to_dict(props: JunctionProperties) -> dict:
+         return {
+             "flow_direction": props.flow_direction,
+             "pressure": props.pressure
+         }
+
+     @staticmethod
+     def _junction_props_from_dict(d: dict) -> JunctionProperties:
+         return JunctionProperties(
+             flow_direction=d.get("flow_direction"),
+             pressure=d.get("pressure")
+         )
+
+     @staticmethod
+     def _junction_to_dict(jn: Junction) -> dict:
+         return {
+             "id": jn.id,
+             "center": DetectionContext._coords_to_dict(jn.center),
+             "junction_type": jn.junction_type.value,
+             "properties": DetectionContext._junction_props_to_dict(jn.properties),
+             "connected_lines": jn.connected_lines
+         }
+
+     @staticmethod
+     def _junction_from_dict(d: dict) -> Junction:
+         return Junction(
+             id=d["id"],
+             center=DetectionContext._coords_from_dict(d["center"]),
+             junction_type=JunctionType(d["junction_type"]),
+             properties=DetectionContext._junction_props_from_dict(d["properties"]),
+             connected_lines=d.get("connected_lines", [])
+         )
+
+     @staticmethod
+     def _tag_to_dict(tg: Tag) -> dict:
+         return {
+             "id": tg.id,
+             "text": tg.text,
+             "bbox": DetectionContext._bbox_to_dict(tg.bbox),
+             "associated_element": tg.associated_element,
+             "font_size": tg.font_size,
+             "rotation": tg.rotation
+         }
+
+     @staticmethod
+     def _tag_from_dict(d: dict) -> Tag:
+         return Tag(
+             id=d["id"],
+             text=d["text"],
+             bbox=DetectionContext._bbox_from_dict(d["bbox"]),
+             associated_element=d["associated_element"],
+             font_size=d.get("font_size", 12),
+             rotation=d.get("rotation", 0.0)
+         )
+
+     # -------------------------
+     # 4) OPTIONAL UTILS
+     # -------------------------
+     def to_json(self, indent: int = 2) -> str:
+         """Convert context to JSON, ensuring dataclasses and numpy types are handled correctly."""
+         return json.dumps(self.to_dict(), default=self._json_serializer, indent=indent)
+
+     @staticmethod
+     def _json_serializer(obj):
+         """Handles numpy types and unknown objects for JSON serialization."""
+         if isinstance(obj, np.integer):
+             return int(obj)
+         if isinstance(obj, np.floating):
+             return float(obj)
+         if isinstance(obj, np.ndarray):
+             return obj.tolist()  # Convert arrays to lists
+         if isinstance(obj, Enum):
+             return obj.value  # Convert Enums to string values
+         if hasattr(obj, "__dict__"):
+             return obj.__dict__  # Convert dataclass objects to dict
+         raise TypeError(f"Object of type {type(obj)} is not JSON serializable")
+
+     @classmethod
+     def from_json(cls, json_str: str) -> "DetectionContext":
+         """Load DetectionContext from a JSON string."""
+         return cls.from_dict(json.loads(json_str))
+ data = json.loads(json_str)
481
+ return cls.from_dict(data)
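The `_json_serializer` fallback above can be exercised standalone. A minimal stdlib-only sketch of the same pattern (the `Color` and `Style` classes below are hypothetical stand-ins for illustration, not the project's dataclasses):

```python
import json
from enum import Enum

def json_serializer(obj):
    # Mirror of the fallback logic: Enums become their value,
    # plain objects fall back to their __dict__.
    if isinstance(obj, Enum):
        return obj.value
    if hasattr(obj, "__dict__"):
        return obj.__dict__
    raise TypeError(f"Object of type {type(obj)} is not JSON serializable")

class Color(Enum):
    SOLID = "solid"

class Style:
    def __init__(self):
        self.color = Color.SOLID
        self.width = 2

print(json.dumps(Style(), default=json_serializer))
# {"color": "solid", "width": 2}
```

`json.dumps` only invokes `default=` for objects it cannot serialize itself, so nested Enums inside the returned `__dict__` are handled recursively.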
detectors.py ADDED
@@ -0,0 +1,1096 @@
1
+ import os
2
+ import math
3
+ import torch
4
+ import cv2
5
+ import numpy as np
6
+ from typing import List, Optional, Tuple, Dict
7
+ from dataclasses import replace
8
+ from math import sqrt
9
+ import json
10
+ import uuid
11
+ from pathlib import Path
12
+
13
+ # Base classes and utilities
14
+ from base import BaseDetector
15
+ from detection_schema import DetectionContext
16
+ from utils import DebugHandler
17
+ from config import SymbolConfig, TagConfig, LineConfig, PointConfig, JunctionConfig
18
+
19
+ # DeepLSD model for line detection
20
+ from deeplsd.models.deeplsd_inference import DeepLSD
21
+ from ultralytics import YOLO
22
+
23
+ # Detection schema: dataclasses for different objects
24
+ from detection_schema import (
25
+ BBox,
26
+ Coordinates,
27
+ Point,
28
+ Line,
29
+ Symbol,
30
+ Tag,
31
+ SymbolType,
32
+ LineStyle,
33
+ ConnectionType,
34
+ JunctionType,
35
+ Junction
36
+ )
37
+
38
+ # Skeletonization and label processing for junction detection
39
+ from skimage.morphology import skeletonize
40
+ from skimage.measure import label
41
+
42
+
43
+ class LineDetector(BaseDetector):
44
+ """
45
+ DeepLSD-based line detection that populates newly detected lines (and naive endpoints)
46
+ directly into a DetectionContext.
47
+ """
48
+
49
+ def __init__(self,
+ config: LineConfig,
+ model_path: str,
+ model_config: dict,
+ device: torch.device,
+ debug_handler: DebugHandler = None):
+ self.device = device
+ self.model_path = model_path
+ self.model_config = model_config
+ super().__init__(config, debug_handler)
+ self._load_params()
+ self.model = self._load_model(model_path)
+ self.scale_factor = 0.75 # Downscale factor applied to the input before inference
+ self.margin = 10 # BBox expansion margin
76
+
77
+ # -------------------------------------
78
+ # BaseDetector requirements
79
+ # -------------------------------------
80
+ def _load_model(self, model_path: str) -> DeepLSD:
81
+ """Load and configure the DeepLSD model."""
82
+ if not os.path.exists(model_path):
83
+ raise FileNotFoundError(f"Model file not found: {model_path}")
84
+ ckpt = torch.load(model_path, map_location=self.device)
85
+ model = DeepLSD(self.model_config)
+ model.load_state_dict(ckpt["model"])
92
+ return model.to(self.device).eval()
93
+
94
+ def _preprocess(self, image: np.ndarray) -> np.ndarray:
95
+ """
96
+ Not used directly here. We'll handle our own
97
+ masking + threshold steps in the detect() method.
98
+ """
99
+ return image
100
+
101
+ def _postprocess(self, image: np.ndarray) -> np.ndarray:
102
+ """
103
+ Not used directly. Postprocessing is integrated
104
+ into detect() after we create lines.
105
+ """
106
+ return image
107
+
108
+ # -------------------------------------
109
+ # Our main detection method
110
+ # -------------------------------------
111
+ def detect(self,
112
+ image: np.ndarray,
113
+ context: DetectionContext,
114
+ mask_coords: Optional[List[BBox]] = None,
115
+ *args,
116
+ **kwargs) -> None:
117
+ """
118
+ Main detection pipeline:
119
+ 1) Apply mask
120
+ 2) Convert to binary & downscale
121
+ 3) Run DeepLSD
122
+ 4) Build minimal Line objects (with naive endpoints)
123
+ 5) Scale lines to original resolution
124
+ 6) Store the lines into the context
125
+
126
+ We do NOT unify endpoints here or classify them as T/L/etc.
127
+ """
128
+ mask_coords = mask_coords or []
129
+
130
+ # (A) Preprocess
131
+ processed_img = self._apply_mask_and_downscale(image, mask_coords)
132
+
133
+ # (B) Inference
134
+ raw_output = self._run_model_inference(processed_img)
135
+
136
+ # (C) Create lines in downscaled space
137
+ downscaled_lines = self._create_lines_from_output(raw_output)
138
+
139
+ # (D) Scale them to original resolution
140
+ lines_scaled = [self._scale_line(ln) for ln in downscaled_lines]
141
+
142
+ # (E) Add them to context
143
+ for line in lines_scaled:
144
+ context.add_line(line)
145
+
146
+ # -------------------------------------
147
+ # Internal helpers
148
+ # -------------------------------------
149
+ def _load_params(self):
150
+ """Load any model_config parameters if needed."""
151
+ pass
152
+
153
+ def _apply_mask_and_downscale(self, image: np.ndarray, mask_coords: List[BBox]) -> np.ndarray:
154
+ """Apply rectangular mask, then threshold, then downscale."""
155
+ masked = self._apply_masking(image, mask_coords)
156
+ gray = cv2.cvtColor(masked, cv2.COLOR_RGB2GRAY)
157
+ binary_full = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)[1]
163
+
164
+ kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, 1))
165
+ dilated = cv2.dilate(binary_full, kernel, iterations=2)
166
+
167
+ # Downscale
168
+ binary_downscaled = cv2.resize(
169
+ dilated,
170
+ None,
171
+ fx=self.scale_factor,
172
+ fy=self.scale_factor
173
+ )
174
+ return binary_downscaled
175
+
176
+ def _apply_masking(self, image: np.ndarray, mask_coords: List[BBox]) -> np.ndarray:
177
+ """White-out rectangular areas to ignore them."""
178
+ masked = image.copy()
179
+ for bbox in mask_coords:
180
+ x1, y1 = int(bbox.xmin), int(bbox.ymin)
181
+ x2, y2 = int(bbox.xmax), int(bbox.ymax)
182
+ cv2.rectangle(masked, (x1, y1), (x2, y2), (255, 255, 255), -1)
183
+ return masked
184
+
185
+ def _run_model_inference(self, downscaled_binary: np.ndarray) -> np.ndarray:
186
+ """Run DeepLSD on the downscaled binary image, returning raw lines [N, 2, 2]."""
187
+ tensor = torch.tensor(downscaled_binary, dtype=torch.float32, device=self.device)[None, None] / 255.0
188
+ with torch.no_grad():
190
+ output = self.model({"image": tensor})
191
+ # shape: [batch, num_lines, 2, 2]
192
+ return output["lines"][0]
193
+
194
+ def _create_lines_from_output(self, model_output: np.ndarray) -> List[Line]:
195
+ """
196
+ Convert each [2,2] line segment into a minimal Line with naive endpoints (type=END).
197
+ Coordinates are in downscaled space.
198
+ """
199
+ lines = []
200
+ for endpoints in model_output:
201
+ (x1, y1), (x2, y2) = endpoints # shape (2,) each
202
+
203
+ p_start = self._create_point(x1, y1)
204
+ p_end = self._create_point(x2, y2)
205
+
206
+ # minimal bounding box in downscaled coords
207
+ x_min = min(x1, x2)
208
+ x_max = max(x1, x2)
209
+ y_min = min(y1, y2)
210
+ y_max = max(y1, y2)
211
+
212
+ line_obj = Line(
213
+ start=p_start,
214
+ end=p_end,
215
+ bbox=BBox(
216
+ xmin=int(x_min),
217
+ ymin=int(y_min),
218
+ xmax=int(x_max),
219
+ ymax=int(y_max)
220
+ ),
221
+ # style / confidence / ID assigned by default
222
+ style=LineStyle(
223
+ connection_type=ConnectionType.SOLID,
224
+ stroke_width=2,
225
+ color="#000000"
226
+ ),
227
+ confidence=0.9,
228
+ topological_links=[]
229
+ )
230
+ lines.append(line_obj)
231
+
232
+ return lines
233
+
234
+ def _create_point(self, x: float, y: float) -> Point:
235
+ """
236
+ Creates a naive 'END'-type Point at downscaled coords.
237
+ We'll scale it later.
238
+ """
239
+ margin = 2
240
+ return Point(
241
+ coords=Coordinates(x=int(x), y=int(y)),
242
+ bbox=BBox(
243
+ xmin=int(x - margin),
244
+ ymin=int(y - margin),
245
+ xmax=int(x + margin),
246
+ ymax=int(y + margin)
247
+ ),
248
+ type=JunctionType.END, # no classification here
249
+ confidence=1.0
250
+ )
251
+
252
+ def _scale_line(self, line: Line) -> Line:
253
+ """
254
+ Scale line's start/end points + bounding box to original resolution.
255
+ """
256
+ scaled_start = self._scale_point(line.start)
257
+ scaled_end = self._scale_point(line.end)
258
+
259
+ # recalc bounding box in original scale
260
+ new_bbox = BBox(
261
+ xmin=min(scaled_start.bbox.xmin, scaled_end.bbox.xmin),
262
+ ymin=min(scaled_start.bbox.ymin, scaled_end.bbox.ymin),
263
+ xmax=max(scaled_start.bbox.xmax, scaled_end.bbox.xmax),
264
+ ymax=max(scaled_start.bbox.ymax, scaled_end.bbox.ymax)
265
+ )
266
+
267
+ return replace(line, start=scaled_start, end=scaled_end, bbox=new_bbox)
268
+
269
+ def _scale_point(self, point: Point) -> Point:
270
+ sx = int(point.coords.x * 1/self.scale_factor)
271
+ sy = int(point.coords.y * 1/self.scale_factor)
272
+
273
+ bb = point.bbox
274
+ scaled_bbox = BBox(
275
+ xmin=int(bb.xmin * 1/self.scale_factor),
276
+ ymin=int(bb.ymin * 1/self.scale_factor),
277
+ xmax=int(bb.xmax * 1/self.scale_factor),
278
+ ymax=int(bb.ymax * 1/self.scale_factor)
279
+ )
280
+ return replace(point, coords=Coordinates(sx, sy), bbox=scaled_bbox)
281
+
282
+
283
+ class PointDetector(BaseDetector):
284
+ """
285
+ A detector that:
286
+ 1) Reads lines from the context
287
+ 2) Clusters endpoints within 'threshold_distance'
288
+ 3) Updates lines so that shared endpoints reference the same Point object
289
+ """
290
+
291
+ def __init__(self,
292
+ config:PointConfig,
293
+ debug_handler: DebugHandler = None):
294
+ super().__init__(config, debug_handler) # No real model to load
295
+ self.threshold_distance = config.threshold_distance
296
+
297
+ def _load_model(self, model_path: str):
298
+ """No model needed for simple point unification."""
299
+ return None
300
+
301
+ def detect(self, image: np.ndarray, context: DetectionContext, *args, **kwargs) -> None:
302
+ """
303
+ Main method called by the pipeline.
304
+ 1) Gather all line endpoints from context
305
+ 2) Cluster them within 'threshold_distance'
306
+ 3) Update the line endpoints so they reference the unified cluster point
307
+ """
308
+ # 1) Collect all endpoints
309
+ endpoints = []
310
+ for line in context.lines.values():
311
+ endpoints.append(line.start)
312
+ endpoints.append(line.end)
313
+
314
+ # 2) Cluster endpoints
315
+ clusters = self._cluster_points(endpoints, self.threshold_distance)
316
+
317
+ # 3) Build a dictionary of "representative" points
318
+ # So that each cluster has one "canonical" point
319
+ # Then we link all the points in that cluster to the canonical reference
320
+ unified_point_map = {}
321
+ for cluster in clusters:
322
+ # let's pick the first point in the cluster as the "representative"
323
+ rep_point = cluster[0]
324
+ for p in cluster[1:]:
325
+ unified_point_map[p.id] = rep_point
326
+
327
+ # 4) Update all lines to reference the canonical point
328
+ for line in context.lines.values():
329
+ # unify start
330
+ if line.start.id in unified_point_map:
331
+ line.start = unified_point_map[line.start.id]
332
+ # unify end
333
+ if line.end.id in unified_point_map:
334
+ line.end = unified_point_map[line.end.id]
335
+
336
+ # We could also store the final set of unique points back in context.points
337
+ # (e.g. clearing old duplicates).
338
+ # That step is optional: you might prefer to keep everything in lines only,
339
+ # or you might want context.points as a separate reference.
340
+
341
+ # If you want to keep unique points in context.points:
342
+ new_points = {}
343
+ for line in context.lines.values():
344
+ new_points[line.start.id] = line.start
345
+ new_points[line.end.id] = line.end
346
+ context.points = new_points # replace the dictionary of points
347
+
348
+ def _preprocess(self, image: np.ndarray) -> np.ndarray:
349
+ """No specific image preprocessing needed."""
350
+ return image
351
+
352
+ def _postprocess(self, image: np.ndarray) -> np.ndarray:
353
+ """No specific image postprocessing needed."""
354
+ return image
355
+
356
+ # ----------------------
357
+ # HELPER: clustering
358
+ # ----------------------
359
+ def _cluster_points(self, points: List[Point], threshold: float) -> List[List[Point]]:
360
+ """
361
+ Very naive clustering:
362
+ 1) Start from the first point
363
+ 2) If it's within threshold of an existing cluster's representative,
364
+ put it in that cluster
365
+ 3) Otherwise start a new cluster
366
+ Return: list of clusters, each is a list of Points
367
+ """
368
+ clusters = []
369
+
370
+ for pt in points:
371
+ placed = False
372
+ for cluster in clusters:
373
+ # pick the first point in the cluster as reference
374
+ ref_pt = cluster[0]
375
+ if self._distance(pt, ref_pt) < threshold:
376
+ cluster.append(pt)
377
+ placed = True
378
+ break
379
+
380
+ if not placed:
381
+ clusters.append([pt])
382
+
383
+ return clusters
384
+
385
+ def _distance(self, p1: Point, p2: Point) -> float:
386
+ dx = p1.coords.x - p2.coords.x
387
+ dy = p1.coords.y - p2.coords.y
388
+ return sqrt(dx*dx + dy*dy)
389
+
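The greedy clustering used by `PointDetector` can be sketched standalone on plain coordinate tuples (`cluster_points` below is a hypothetical helper for illustration, not the project's API):

```python
from math import dist  # Python 3.8+

def cluster_points(points, threshold):
    # Greedy single-pass clustering: each point joins the first cluster
    # whose representative (first member) is within `threshold`.
    clusters = []
    for pt in points:
        for cluster in clusters:
            if dist(pt, cluster[0]) < threshold:
                cluster.append(pt)
                break
        else:
            clusters.append([pt])
    return clusters

pts = [(0, 0), (1, 1), (50, 50), (51, 49)]
print(cluster_points(pts, threshold=5.0))
# [[(0, 0), (1, 1)], [(50, 50), (51, 49)]]
```

Note the result is order-dependent: the first point seen in a neighbourhood becomes the cluster representative, which matches the detector's "pick the first point as canonical" behaviour.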
390
+
391
+ class JunctionDetector(BaseDetector):
392
+ """
393
+ Classifies points as 'END', 'L', or 'T' by skeletonizing the binarized image
394
+ and analyzing local connectivity. Also creates Junction objects in the context.
395
+ """
396
+
397
+ def __init__(self, config: JunctionConfig, debug_handler: DebugHandler = None):
398
+ super().__init__(config, debug_handler) # no real model path
399
+ self.window_size = config.window_size
400
+ self.radius = config.radius
401
+ self.angle_threshold_lb = config.angle_threshold_lb
402
+ self.angle_threshold_ub = config.angle_threshold_ub
403
+ self.debug_handler = debug_handler or DebugHandler()
404
+
405
+ def _load_model(self, model_path: str):
406
+ """Not loading any actual model, just skeleton logic."""
407
+ return None
408
+
409
+ def detect(self,
410
+ image: np.ndarray,
411
+ context: DetectionContext,
412
+ *args,
413
+ **kwargs) -> None:
414
+ """
415
+ 1) Convert to binary & skeletonize
416
+ 2) Classify each point in the context
417
+ 3) Create a Junction for each point and store it in context.junctions
418
+ (with 'connected_lines' referencing lines that share this point).
419
+ """
420
+ # 1) Preprocess -> skeleton
421
+ skeleton = self._create_skeleton(image)
422
+
423
+ # 2) Classify each point
424
+ for pt in context.points.values():
425
+ pt.type = self._classify_point(skeleton, pt)
426
+
427
+ # 3) Create a Junction object for each point
428
+ # If you prefer only T or L, you can filter out END points.
429
+ self._record_junctions_in_context(context)
430
+
431
+ def _preprocess(self, image: np.ndarray) -> np.ndarray:
432
+ """We might do thresholding; let's do a simple binary threshold."""
433
+ if image.ndim == 3:
434
+ gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
435
+ else:
436
+ gray = image
437
+ _, bin_image = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY)
438
+ return bin_image
439
+
440
+ def _postprocess(self, image: np.ndarray) -> np.ndarray:
441
+ return image
442
+
443
+ def _create_skeleton(self, raw_image: np.ndarray) -> np.ndarray:
444
+ """Skeletonize the binarized image."""
445
+ bin_img = self._preprocess(raw_image)
446
+ # For skeletonize, we need a boolean array
447
+ inv = cv2.bitwise_not(bin_img)
448
+ inv_bool = (inv > 127).astype(np.uint8)
449
+ skel = skeletonize(inv_bool).astype(np.uint8) * 255
450
+ return skel
451
+
452
+ def _classify_point(self, skeleton: np.ndarray, pt: Point) -> JunctionType:
453
+ """
454
+ Given a skeleton image, look around 'pt' in a local window
455
+ to determine if it's an END, L, or T.
456
+ """
457
+ classification = JunctionType.END # default
458
+
459
+ half_w = self.window_size // 2
460
+ x, y = pt.coords.x, pt.coords.y
461
+
462
+ top = max(0, y - half_w)
463
+ bottom = min(skeleton.shape[0], y + half_w + 1)
464
+ left = max(0, x - half_w)
465
+ right = min(skeleton.shape[1], x + half_w + 1)
466
+
467
+ patch = (skeleton[top:bottom, left:right] > 127).astype(np.uint8)
468
+
469
+ # create circular mask
470
+ circle_mask = np.zeros_like(patch, dtype=np.uint8)
471
+ local_cx = x - left
472
+ local_cy = y - top
473
+ cv2.circle(circle_mask, (local_cx, local_cy), self.radius, 1, -1)
474
+ circle_skel = patch & circle_mask
475
+
476
+ # label connected regions
477
+ labeled = label(circle_skel, connectivity=2)
478
+ num_exits = labeled.max()
479
+
480
+ if num_exits == 1:
481
+ classification = JunctionType.END
482
+ elif num_exits == 2:
483
+ # check angle for L
484
+ classification = self._check_angle_for_L(labeled)
485
+ elif num_exits == 3:
486
+ classification = JunctionType.T
487
+
488
+ return classification
489
+
490
+ def _check_angle_for_L(self, labeled_region: np.ndarray) -> JunctionType:
491
+ """
492
+ If the angle between two branches is within
493
+ [angle_threshold_lb, angle_threshold_ub], it's 'L'.
494
+ Otherwise default to END.
495
+ """
496
+ coords = np.argwhere(labeled_region == 1)
497
+ if len(coords) < 2:
498
+ return JunctionType.END
499
+
500
+ (y1, x1), (y2, x2) = coords[:2]
501
+ dx = x2 - x1
502
+ dy = y2 - y1
503
+ angle = math.degrees(math.atan2(dy, dx))
504
+ acute_angle = min(abs(angle), 180 - abs(angle))
505
+
506
+ if self.angle_threshold_lb <= acute_angle <= self.angle_threshold_ub:
507
+ return JunctionType.L
508
+ return JunctionType.END
509
+
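The acute-angle fold used in `_check_angle_for_L` can be isolated as a small sketch (`acute_angle` is a hypothetical helper, shown only to make the geometry explicit):

```python
import math

def acute_angle(p1, p2):
    # Angle of the segment p1 -> p2, folded into [0, 90] degrees,
    # matching min(abs(angle), 180 - abs(angle)) in _check_angle_for_L.
    angle = math.degrees(math.atan2(p2[1] - p1[1], p2[0] - p1[0]))
    return min(abs(angle), 180 - abs(angle))

print(acute_angle((0, 0), (10, 10)))  # 45.0 -> within a typical L-corner band
print(acute_angle((0, 0), (-10, 1)))  # ~5.71 -> near-horizontal, likely END
```

This shows why a near-horizontal or near-vertical pair of branch pixels falls outside `[angle_threshold_lb, angle_threshold_ub]` and defaults to `END`.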
510
+ # -----------------------------------------
511
+ # EXTRA STEP: Create Junction objects
512
+ # -----------------------------------------
513
+ def _record_junctions_in_context(self, context: DetectionContext):
514
+ """
515
+ Create a Junction object for each point in context.points.
516
+ If you only want T/L points as junctions, filter them out.
517
+ Also track any lines that connect to this point.
518
+ """
519
+
520
+ for pt in context.points.values():
521
+ # If you prefer to store all points as junction, do it:
522
+ # or if you want only T or L, do:
523
+ # if pt.type in {JunctionType.T, JunctionType.L}: ...
524
+
525
+ jn = Junction(
526
+ center=pt.coords,
527
+ junction_type=pt.type,
528
+ # add more properties if needed
529
+ )
530
+
531
+ # find lines that connect to this point
532
+ connected_lines = []
533
+ for ln in context.lines.values():
534
+ if ln.start.id == pt.id or ln.end.id == pt.id:
535
+ connected_lines.append(ln.id)
536
+
537
+ jn.connected_lines = connected_lines
538
+
539
+ # add to context
540
+ context.add_junction(jn)
541
+
542
+ # from loguru import logger
543
+ #
544
+ #
545
+ # class SymbolDetector(BaseDetector):
546
+ # """
547
+ # YOLO-based symbol detector using multiple confidence thresholds,
548
+ # merges final detections, and stores them in the context.
549
+ # """
550
+ #
551
+ # def __init__(self, config: SymbolConfig, debug_handler: Optional[DebugHandler] = None):
552
+ # super().__init__(config, debug_handler)
553
+ # self.config = config
554
+ # self.debug_handler = debug_handler or DebugHandler()
555
+ # self.models = self._load_models()
556
+ # self.class_map = self._build_class_map()
557
+ #
558
+ # logger.info("Symbol detector initialized with config: %s", self.config)
559
+ #
560
+ # # -----------------------------
561
+ # # BaseDetector Implementation
562
+ # # -----------------------------
563
+ # def _load_model(self, model_path: str):
564
+ # """We won't use this single-model loader; see _load_models()."""
565
+ # pass
566
+ #
567
+ # def detect(self,
568
+ # image: np.ndarray,
569
+ # context: DetectionContext,
570
+ # roi_offset: Tuple[int, int],
571
+ # *args,
572
+ # **kwargs) -> None:
573
+ # """
574
+ # Run multi-threshold YOLO detection for each model, pick best threshold,
575
+ # merge detections, and store Symbol objects in context.
576
+ # """
577
+ # try:
578
+ # with self.debug_handler.track_performance("symbol_detection"):
579
+ # # 1) Possibly preprocess & resize
580
+ # processed_img = self._preprocess(image)
581
+ # resized_img, scale_factor = self._resize_image(processed_img)
582
+ #
583
+ # # 2) Detect with all models, each using multiple thresholds
584
+ # all_detections = []
585
+ # for model_name, model in self.models.items():
586
+ # best_detections = self._detect_best_threshold(
587
+ # model, resized_img, image.shape, scale_factor, model_name
588
+ # )
589
+ # all_detections.extend(best_detections)
590
+ #
591
+ # # 3) Merge detections using NMS logic
592
+ # merged_detections = self._merge_detections(all_detections)
593
+ #
594
+ # # 4) Update context with final symbols
595
+ # self._update_context(merged_detections, context)
596
+ #
597
+ # # 5) Create optional debug image artifact
598
+ # debug_image = self._create_debug_image(processed_img, merged_detections)
599
+ # _, debug_img_encoded = cv2.imencode('.jpg', debug_image)
600
+ # self.debug_handler.save_artifact(
601
+ # name="symbol_detection_debug",
602
+ # data=debug_img_encoded.tobytes(),
603
+ # extension="jpg"
604
+ # )
605
+ #
606
+ # except Exception as e:
607
+ # logger.error("Symbol detection failed: %s", str(e), exc_info=True)
608
+ # self.debug_handler.save_artifact(
609
+ # name="symbol_detection_error",
610
+ # data=f"Detection error: {str(e)}".encode('utf-8'),
611
+ # extension="txt"
612
+ # )
613
+ #
614
+ # def _preprocess(self, image: np.ndarray) -> np.ndarray:
615
+ # """Preprocess if needed (e.g., histogram equalization)."""
616
+ # if self.config.apply_preprocessing:
617
+ # gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
618
+ # equalized = cv2.equalizeHist(gray)
619
+ # # convert back to BGR for YOLO
620
+ # return cv2.cvtColor(equalized, cv2.COLOR_GRAY2BGR)
621
+ # return image.copy()
622
+ #
623
+ # def _postprocess(self, image: np.ndarray) -> np.ndarray:
624
+ # return None
625
+ #
626
+ # # -----------------------------
627
+ # # Internal Helpers
628
+ # # -----------------------------
629
+ # def _load_models(self) -> Dict[str, YOLO]:
630
+ # """Load multiple YOLO models from config."""
631
+ # models = {}
632
+ # for model_name, path_str in self.config.model_paths.items():
633
+ # path = Path(path_str)
634
+ # if not path.exists():
635
+ # raise FileNotFoundError(f"Model file not found: {path_str}")
636
+ # models[model_name] = YOLO(str(path))
637
+ # logger.info(f"Loaded model '{model_name}' from {path_str}")
638
+ # return models
639
+ #
640
+ # def _build_class_map(self) -> Dict[int, SymbolType]:
641
+ # """
642
+ # Convert config symbol_type_mapping (like {"pump": "PUMP"})
643
+ # into a dictionary from YOLO class_id to SymbolType.
644
+ # If you have a fixed list of YOLO classes, you can map them here.
645
+ # """
646
+ # # For example, if YOLO has classes like ["valve", "pump", ...],
647
+ # # you might want to do something more dynamic.
648
+ # # For now, let's just return an empty dict or handle it in detection.
649
+ # return {}
650
+ #
651
+ # def _resize_image(self, image: np.ndarray) -> Tuple[np.ndarray, float]:
652
+ # """Resize while maintaining aspect ratio if needed."""
653
+ # h, w = image.shape[:2]
654
+ # if not self.config.resize_image:
655
+ # return image, 1.0
656
+ #
657
+ # if max(w, h) > self.config.max_dimension:
658
+ # scale = self.config.max_dimension / max(w, h)
659
+ # new_w, new_h = int(w * scale), int(h * scale)
660
+ # resized = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
661
+ # return resized, scale
662
+ # return image, 1.0
663
+ #
664
+ # def _detect_best_threshold(self,
665
+ # model: YOLO,
666
+ # resized_img: np.ndarray,
667
+ # orig_shape: Tuple[int, int, int],
668
+ # scale_factor: float,
669
+ # model_name: str) -> List[Dict]:
670
+ # """
671
+ # Run detection across multiple confidence thresholds.
672
+ # Use the threshold that yields the 'best metric' (currently # of detections).
673
+ # """
674
+ # best_metric = -1
675
+ # best_threshold = 0.5
676
+ # best_detections_list = []
677
+ #
678
+ # # Evaluate each threshold
679
+ # for thresh in self.config.confidence_thresholds:
680
+ # # Run YOLO detection
681
+ # # Setting conf=thresh or conf=0.0 + we do filtering ourselves.
682
+ # results = model.predict(
683
+ # source=resized_img,
684
+ # imgsz=self.config.max_dimension,
685
+ # conf=0.0, # We'll filter manually below
686
+ # verbose=False
687
+ # )
688
+ #
689
+ # # Convert to detection dict
690
+ # detections_list = []
691
+ # for result in results:
692
+ # for box in result.boxes:
693
+ # conf_val = float(box.conf[0])
694
+ # if conf_val >= thresh:
695
+ # # Convert bounding box coords to original (local) coords
696
+ # x1, y1, x2, y2 = self._scale_coordinates(
697
+ # box.xyxy[0].cpu().numpy(),
698
+ # resized_img.shape, # shape after resizing
699
+ # scale_factor
700
+ # )
701
+ # class_id = int(box.cls[0])
702
+ # label = result.names[class_id] if result.names else "unknown_label"
703
+ #
704
+ # # parse label (category, type, new_label)
705
+ # category, type_str, new_label = self._parse_label(label)
706
+ #
707
+ # detection_info = {
708
+ # "symbol_id": str(uuid.uuid4()),
709
+ # "class_id": class_id,
710
+ # "original_label": label,
711
+ # "category": category,
712
+ # "type": type_str,
713
+ # "label": new_label,
714
+ # "confidence": conf_val,
715
+ # "bbox": [x1, y1, x2, y2],
716
+ # "model_source": model_name
717
+ # }
718
+ # detections_list.append(detection_info)
719
+ #
720
+ # # Evaluate
721
+ # metric = self._evaluate_detections(detections_list)
722
+ # if metric > best_metric:
723
+ # best_metric = metric
724
+ # best_threshold = thresh
725
+ # best_detections_list = detections_list
726
+ #
727
+ # logger.info(f"For model {model_name}, best threshold={best_threshold:.2f} with {best_metric} detections.")
728
+ # return best_detections_list
729
+ #
730
+ # def _evaluate_detections(self, detections_list: List[Dict]) -> int:
731
+ # """A simple metric: # of detections."""
732
+ # return len(detections_list)
733
+ #
734
+ # def _parse_label(self, label: str) -> Tuple[str, str, str]:
735
+ # """
736
+ # Attempt to parse the YOLO label into (category, type, new_label).
737
+ # Example label: "inst_ind_Solenoid_actuator"
738
+ # -> category=inst, type=ind, new_label="Solenoid_actuator"
739
+ # If no underscores, we fallback to "Unknown" for type.
740
+ # """
741
+ # split_label = label.split('_')
742
+ # if len(split_label) >= 3:
743
+ # category = split_label[0]
744
+ # type_ = split_label[1]
745
+ # new_label = '_'.join(split_label[2:])
746
+ # elif len(split_label) == 2:
747
+ # category = split_label[0]
748
+ # type_ = split_label[1]
749
+ # new_label = split_label[1]
750
+ # elif len(split_label) == 1:
751
+ # category = split_label[0]
752
+ # type_ = "Unknown"
753
+ # new_label = split_label[0]
754
+ # else:
755
+ # logger.warning(f"Unexpected label format: {label}")
756
+ # return ("Unknown", "Unknown", label)
757
+ #
758
+ # return (category, type_, new_label)
759
+ #
760
+ # def _scale_coordinates(self,
761
+ # coords: np.ndarray,
762
+ # resized_shape: Tuple[int, int, int],
763
+ # scale_factor: float) -> Tuple[int, int, int, int]:
764
+ # """
765
+ # Scale YOLO's [x1,y1,x2,y2] from the resized image back to the original local coords.
766
+ # """
767
+ # x1, y1, x2, y2 = coords
768
+ # # Because we resized by scale_factor
769
+ # # so original coordinate = coords / scale_factor
770
+ # return (
771
+ # int(x1 / scale_factor),
772
+ # int(y1 / scale_factor),
773
+ # int(x2 / scale_factor),
774
+ # int(y2 / scale_factor),
775
+ # )
776
+ #
777
+ # def _merge_detections(self, all_detections: List[Dict]) -> List[Dict]:
778
+ # """Merge using NMS-like approach (IoU-based) across all models."""
779
+ # if not all_detections:
780
+ # return []
781
+ #
782
+ # # Sort by confidence (descending)
783
+ # all_detections.sort(key=lambda x: x['confidence'], reverse=True)
784
+ # keep = [True] * len(all_detections)
785
+ #
786
+ # for i in range(len(all_detections)):
787
+ # if not keep[i]:
788
+ # continue
789
+ # for j in range(i + 1, len(all_detections)):
790
+ # if not keep[j]:
791
+ # continue
792
+ # # Merge if same class_id & high IoU
793
+ # if (all_detections[i]['class_id'] == all_detections[j]['class_id'] and
794
+ # self._calculate_iou(all_detections[i]['bbox'], all_detections[j]['bbox']) > 0.5):
795
+ # keep[j] = False
796
+ #
797
+ # return [det for idx, det in enumerate(all_detections) if keep[idx]]
798
+ #
799
+ # def _calculate_iou(self, box1: List[int], box2: List[int]) -> float:
800
+ # """Intersection over Union"""
801
+ # x_left = max(box1[0], box2[0])
802
+ # y_top = max(box1[1], box2[1])
803
+ # x_right = min(box1[2], box2[2])
804
+ # y_bottom = min(box1[3], box2[3])
805
+ #
806
+ # inter_w = max(0, x_right - x_left)
807
+ # inter_h = max(0, y_bottom - y_top)
808
+ # intersection = inter_w * inter_h
809
+ #
810
+ # area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
811
+ # area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
812
+ # union = float(area1 + area2 - intersection)
813
+ # return intersection / union if union > 0 else 0.0
814
+ #
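Although commented out above, the merge logic is standard greedy non-maximum suppression keyed on class and IoU. A self-contained sketch of the same idea:

```python
def iou(a, b):
    """Intersection over Union of two [x1, y1, x2, y2] boxes."""
    x1, y1 = max(a[0], b[0]), max(a[1], b[1])
    x2, y2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0, x2 - x1) * max(0, y2 - y1)
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0

def nms(detections, iou_threshold=0.5):
    """Greedy NMS: keep the highest-confidence box per same-class overlap cluster."""
    ordered = sorted(detections, key=lambda d: d["confidence"], reverse=True)
    kept = []
    for det in ordered:
        if all(k["class_id"] != det["class_id"]
               or iou(k["bbox"], det["bbox"]) <= iou_threshold
               for k in kept):
            kept.append(det)
    return kept
```

Unlike the in-place `keep` flags above, this accumulates survivors directly; the result is the same set of boxes.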
815
+ # def _update_context(self, detections: List[Dict], context: DetectionContext) -> None:
816
+ # """Convert final detections into Symbol objects & add to context."""
817
+ # for det in detections:
818
+ # x1, y1, x2, y2 = det['bbox']
819
+ # # Use your Symbol dataclass from detection_schema
820
+ # symbol_obj = Symbol(
821
+ # bbox=BBox(xmin=x1, ymin=y1, xmax=x2, ymax=y2),
822
+ # center=Coordinates(x=(x1 + x2) // 2, y=(y1 + y2) // 2),
823
+ # symbol_type=SymbolType.OTHER, # default
824
+ # confidence=det['confidence'],
825
+ # model_source=det['model_source'],
826
+ # class_id=det['class_id'],
827
+ # original_label=det['original_label'],
828
+ # category=det['category'],
829
+ # type=det['type'],
830
+ # label=det['label']
831
+ # )
832
+ # context.add_symbol(symbol_obj)
833
+ #
834
+ # def _create_debug_image(self, image: np.ndarray, detections: List[Dict]) -> np.ndarray:
835
+ # """Optional: draw bounding boxes & labels on a copy of 'image'."""
836
+ # debug_img = image.copy()
837
+ # for det in detections:
838
+ # x1, y1, x2, y2 = det['bbox']
839
+ # cv2.rectangle(debug_img, (x1, y1), (x2, y2), (0, 255, 0), 2)
840
+ # txt = f"{det['label']} {det['confidence']:.2f}"
841
+ # cv2.putText(debug_img, txt, (x1, max(0, y1 - 10)),
842
+ # cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)
843
+ # return debug_img
844
+ #
845
+ #
846
+ # class TagDetector(BaseDetector):
847
+ # """
848
+ # A placeholder detector that reads precomputed tag data
849
+ # from a JSON file and populates the context with Tag objects.
850
+ # """
851
+ #
852
+ # def __init__(self,
853
+ # config: TagConfig,
854
+ # debug_handler: Optional[DebugHandler] = None,
855
+ # tag_json_path: str = "./tags.json"):
856
+ # super().__init__(config=config, debug_handler=debug_handler)
857
+ # self.tag_json_path = tag_json_path
858
+ #
859
+ # def _load_model(self, model_path: str):
860
+ # """Not loading an actual model; tag data is read from JSON."""
861
+ # return None
862
+ #
863
+ # def detect(self,
864
+ # image: np.ndarray,
865
+ # context: DetectionContext,
866
+ # roi_offset: Tuple[int, int],
867
+ # *args,
868
+ # **kwargs) -> None:
869
+ # """
870
+ # Reads from a JSON file containing tag info,
871
+ # adjusts coordinates using roi_offset, and updates context.
872
+ # """
873
+ #
874
+ # tag_data = self._load_json_data(self.tag_json_path)
875
+ # if not tag_data:
876
+ # return
877
+ #
878
+ # x_min, y_min = roi_offset # Offset values from cropping
879
+ #
880
+ # for record in tag_data.get("detections", []): # Fix: Use "detections" key
881
+ # tag_obj = self._parse_tag_record(record, x_min, y_min)
882
+ # context.add_tag(tag_obj)
883
+ #
884
+ # def _preprocess(self, image: np.ndarray) -> np.ndarray:
885
+ # return image
886
+ #
887
+ # def _postprocess(self, image: np.ndarray) -> np.ndarray:
888
+ # return image
889
+ #
890
+ # # --------------
891
+ # # HELPER METHODS
892
+ # # --------------
893
+ # def _load_json_data(self, json_path: str) -> dict:
894
+ # if not os.path.exists(json_path):
895
+ # self.debug_handler.save_artifact(name="tag_error",
896
+ # data=b"Missing tag JSON file",
897
+ # extension="txt")
898
+ # return {}
899
+ #
900
+ # with open(json_path, "r", encoding="utf-8") as f:
901
+ # return json.load(f)
902
+ #
903
+ # def _parse_tag_record(self, record: dict, x_min: int, y_min: int) -> Tag:
904
+ # """
905
+ # Builds a Tag object from a JSON record, adjusting coordinates for cropping.
906
+ # """
907
+ # bbox_list = record.get("bbox", [0, 0, 0, 0])
908
+ # bbox_obj = BBox(
909
+ # xmin=bbox_list[0] - x_min,
910
+ # ymin=bbox_list[1] - y_min,
911
+ # xmax=bbox_list[2] - x_min,
912
+ # ymax=bbox_list[3] - y_min
913
+ # )
914
+ #
915
+ # return Tag(
916
+ # text=record.get("text", ""),
917
+ # bbox=bbox_obj,
918
+ # confidence=record.get("confidence", 1.0),
919
+ # source=record.get("source", ""),
920
+ # text_type=record.get("text_type", "Unknown"),
921
+ # id=record.get("id", str(uuid.uuid4())),
922
+ # font_size=record.get("font_size", 12),
923
+ # rotation=record.get("rotation", 0.0)
924
+ # )
925
+
926
+
927
+ import json
928
+ import uuid
929
+
930
+ class SymbolDetector(BaseDetector):
931
+ """
932
+ A placeholder detector that reads precomputed symbol data
933
+ from a JSON file and populates the context with Symbol objects.
934
+ """
935
+
936
+ def __init__(self,
937
+ config: SymbolConfig,
938
+ debug_handler: Optional[DebugHandler] = None,
939
+ symbol_json_path: str = "./symbols.json"):
940
+ super().__init__(config=config, debug_handler=debug_handler)
941
+ self.symbol_json_path = symbol_json_path
942
+
943
+ def _load_model(self, model_path: str):
944
+ """Not loading an actual model; symbol data is read from JSON."""
945
+ return None
946
+
947
+ def detect(self,
948
+ image: np.ndarray,
949
+ context: DetectionContext,
950
+ roi_offset: Tuple[int, int],
951
+ *args,
952
+ **kwargs) -> None:
953
+ """
954
+ Reads from a JSON file containing symbol info,
955
+ adjusts coordinates using roi_offset, and updates context.
956
+ """
957
+ symbol_data = self._load_json_data(self.symbol_json_path)
958
+ if not symbol_data:
959
+ return
960
+
961
+ x_min, y_min = roi_offset # Offset values from cropping
962
+
963
+ for record in symbol_data.get("detections", []): # Fix: Use "detections" key
964
+ sym_obj = self._parse_symbol_record(record, x_min, y_min)
965
+ context.add_symbol(sym_obj)
966
+
967
+ def _preprocess(self, image: np.ndarray) -> np.ndarray:
968
+ return image
969
+
970
+ def _postprocess(self, image: np.ndarray) -> np.ndarray:
971
+ return image
972
+
973
+ # --------------
974
+ # HELPER METHODS
975
+ # --------------
976
+ def _load_json_data(self, json_path: str) -> dict:
977
+ if not os.path.exists(json_path):
978
+ self.debug_handler.save_artifact(name="symbol_error",
979
+ data=b"Missing symbol JSON file",
980
+ extension="txt")
981
+ return {}
982
+
983
+ with open(json_path, "r", encoding="utf-8") as f:
984
+ return json.load(f)
985
+
986
+ def _parse_symbol_record(self, record: dict, x_min: int, y_min: int) -> Symbol:
987
+ """
988
+ Builds a Symbol object from a JSON record, adjusting coordinates for cropping.
989
+ """
990
+ bbox_list = record.get("bbox", [0, 0, 0, 0])
991
+ bbox_obj = BBox(
992
+ xmin=bbox_list[0] - x_min,
993
+ ymin=bbox_list[1] - y_min,
994
+ xmax=bbox_list[2] - x_min,
995
+ ymax=bbox_list[3] - y_min
996
+ )
997
+
998
+ # Compute the center
999
+ center_coords = Coordinates(
1000
+ x=(bbox_obj.xmin + bbox_obj.xmax) // 2,
1001
+ y=(bbox_obj.ymin + bbox_obj.ymax) // 2
1002
+ )
1003
+
1004
+ return Symbol(
1005
+ id=record.get("symbol_id", ""),
1006
+ class_id=record.get("class_id", -1),
1007
+ original_label=record.get("original_label", ""),
1008
+ category=record.get("category", ""),
1009
+ type=record.get("type", ""),
1010
+ label=record.get("label", ""),
1011
+ bbox=bbox_obj,
1012
+ center=center_coords,
1013
+ confidence=record.get("confidence", 0.95),
1014
+ model_source=record.get("model_source", ""),
1015
+ connections=[]
1016
+ )
1017
+
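Both JSON-backed detectors share one coordinate step: subtract the ROI origin from every stored bbox so coordinates become local to the cropped image. A minimal sketch of that offset (`shift_bbox` and the record layout are illustrative, not the app's API):

```python
def shift_bbox(bbox, roi_offset):
    """Translate an absolute [x1, y1, x2, y2] bbox into crop-local coordinates."""
    x_min, y_min = roi_offset
    x1, y1, x2, y2 = bbox
    return [x1 - x_min, y1 - y_min, x2 - x_min, y2 - y_min]

record = {"bbox": [120, 240, 180, 300], "label": "Gate_valve"}
print(shift_bbox(record["bbox"], (100, 200)))  # [20, 40, 80, 100]
```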
1018
+ class TagDetector(BaseDetector):
1019
+ """
1020
+ A placeholder detector that reads precomputed tag data
1021
+ from a JSON file and populates the context with Tag objects.
1022
+ """
1023
+
1024
+ def __init__(self,
1025
+ config: TagConfig,
1026
+ debug_handler: Optional[DebugHandler] = None,
1027
+ tag_json_path: str = "./tags.json"):
1028
+ super().__init__(config=config, debug_handler=debug_handler)
1029
+ self.tag_json_path = tag_json_path
1030
+
1031
+ def _load_model(self, model_path: str):
1032
+ """Not loading an actual model; tag data is read from JSON."""
1033
+ return None
1034
+
1035
+ def detect(self,
1036
+ image: np.ndarray,
1037
+ context: DetectionContext,
1038
+ roi_offset: Tuple[int, int],
1039
+ *args,
1040
+ **kwargs) -> None:
1041
+ """
1042
+ Reads from a JSON file containing tag info,
1043
+ adjusts coordinates using roi_offset, and updates context.
1044
+ """
1045
+
1046
+ tag_data = self._load_json_data(self.tag_json_path)
1047
+ if not tag_data:
1048
+ return
1049
+
1050
+ x_min, y_min = roi_offset # Offset values from cropping
1051
+
1052
+ for record in tag_data.get("detections", []): # Fix: Use "detections" key
1053
+ tag_obj = self._parse_tag_record(record, x_min, y_min)
1054
+ context.add_tag(tag_obj)
1055
+
1056
+ def _preprocess(self, image: np.ndarray) -> np.ndarray:
1057
+ return image
1058
+
1059
+ def _postprocess(self, image: np.ndarray) -> np.ndarray:
1060
+ return image
1061
+
1062
+ # --------------
1063
+ # HELPER METHODS
1064
+ # --------------
1065
+ def _load_json_data(self, json_path: str) -> dict:
1066
+ if not os.path.exists(json_path):
1067
+ self.debug_handler.save_artifact(name="tag_error",
1068
+ data=b"Missing tag JSON file",
1069
+ extension="txt")
1070
+ return {}
1071
+
1072
+ with open(json_path, "r", encoding="utf-8") as f:
1073
+ return json.load(f)
1074
+
1075
+ def _parse_tag_record(self, record: dict, x_min: int, y_min: int) -> Tag:
1076
+ """
1077
+ Builds a Tag object from a JSON record, adjusting coordinates for cropping.
1078
+ """
1079
+ bbox_list = record.get("bbox", [0, 0, 0, 0])
1080
+ bbox_obj = BBox(
1081
+ xmin=bbox_list[0] - x_min,
1082
+ ymin=bbox_list[1] - y_min,
1083
+ xmax=bbox_list[2] - x_min,
1084
+ ymax=bbox_list[3] - y_min
1085
+ )
1086
+
1087
+ return Tag(
1088
+ text=record.get("text", ""),
1089
+ bbox=bbox_obj,
1090
+ confidence=record.get("confidence", 1.0),
1091
+ source=record.get("source", ""),
1092
+ text_type=record.get("text_type", "Unknown"),
1093
+ id=record.get("id", str(uuid.uuid4())),
1094
+ font_size=record.get("font_size", 12),
1095
+ rotation=record.get("rotation", 0.0)
1096
+ )
gradioChatApp.py ADDED
@@ -0,0 +1,672 @@
1
+ import os
2
+ import base64
3
+ import gradio as gr
4
+ import json
5
+ from datetime import datetime
6
+ from symbol_detection import run_detection_with_optimal_threshold
7
+ from line_detection_ai import DiagramDetectionPipeline, LineDetector, LineConfig, ImageConfig, DebugHandler, PointConfig, JunctionConfig, PointDetector, JunctionDetector, SymbolConfig, SymbolDetector, TagConfig, TagDetector
8
+ from data_aggregation_ai import DataAggregator
9
+ from chatbot_agent import get_assistant_response
10
+ from storage import StorageFactory, LocalStorage
11
+ import traceback
12
+ from text_detection_combined import process_drawing
13
+ from pathlib import Path
14
+ from pdf_processor import DocumentProcessor
15
+ import networkx as nx
16
+ import logging
17
+ import matplotlib.pyplot as plt
18
+ from dotenv import load_dotenv
19
+ import torch
20
+ from graph_visualization import create_graph_visualization
21
+ import shutil
22
+
23
+ # Load environment variables from .env file
24
+ load_dotenv()
25
+
26
+ # Configure logging at the start of the file
27
+ logging.basicConfig(
28
+ level=logging.INFO,
29
+ format='%(asctime)s - %(levelname)s - %(message)s',
30
+ datefmt='%Y-%m-%d %H:%M:%S'
31
+ )
32
+
33
+ # Get logger for this module
34
+ logger = logging.getLogger(__name__)
35
+
36
+ # Disable duplicate logs from other modules
37
+ logging.getLogger('PIL').setLevel(logging.WARNING)
38
+ logging.getLogger('matplotlib').setLevel(logging.WARNING)
39
+ logging.getLogger('gradio').setLevel(logging.WARNING)
40
+ logging.getLogger('networkx').setLevel(logging.WARNING)
41
+ logging.getLogger('line_detection_ai').setLevel(logging.WARNING)
42
+ logging.getLogger('symbol_detection').setLevel(logging.WARNING)
43
+
44
+ # Only log important messages
45
+ def log_process_step(message, level=logging.INFO):
46
+ """Log processing steps with appropriate level"""
47
+ if level >= logging.WARNING:
48
+ logger.log(level, message)
49
+ elif "completed" in message.lower() or "generated" in message.lower():
50
+ logger.info(message)
51
+
52
+ # Helper function to format timestamps
53
+ def get_timestamp():
54
+ return datetime.now().strftime('%Y-%m-%d %H:%M:%S')
55
+
56
+ def format_message(role, content):
57
+ """Format message for chatbot history."""
58
+ return {"role": role, "content": content}
59
+
60
+ # Load avatar images for agents
61
+ localStorage = LocalStorage()
62
+ agent_avatar = base64.b64encode(localStorage.load_file("assets/AiAgent.png")).decode()
63
+ llm_avatar = base64.b64encode(localStorage.load_file("assets/llm.png")).decode()
64
+ user_avatar = base64.b64encode(localStorage.load_file("assets/user.png")).decode()
65
+
66
+ # Chat message formatting with avatars and enhanced HTML for readability
67
+ def chat_message(role, message, avatar, timestamp):
68
+ # Convert Markdown-style formatting to HTML
69
+ formatted_message = (
70
+ message.replace("**", "<strong>").replace("**", "</strong>")
71
+ .replace("###", "<h3>").replace("##", "<h2>")
72
+ .replace("#", "<h1>").replace("\n", "<br>")
73
+ .replace("```", "<pre><code>").replace("`", "</code></pre>")
74
+ .replace("\n1. ", "<br>1. ") # For ordered lists starting with "1."
75
+ .replace("\n2. ", "<br>2. ")
76
+ .replace("\n3. ", "<br>3. ")
77
+ .replace("\n4. ", "<br>4. ")
78
+ .replace("\n5. ", "<br>5. ")
79
+ )
80
+
81
+ return f"""
82
+ <div class="chat-message {role}">
83
+ <img src="data:image/png;base64,{avatar}" class="avatar"/>
84
+ <div>
85
+ <div class="speech-bubble {role}-bubble">{formatted_message}</div>
86
+ <div class="timestamp">{timestamp}</div>
87
+ </div>
88
+ </div>
89
+ """
90
+
91
+ # Main processing function for P&ID steps
92
+ def process_pnid(image_file, progress_status, progress=gr.Progress()):
93
+ """Process P&ID document with real-time progress updates."""
94
+ try:
95
+ # Disable verbose logging for processing components
96
+ logging.getLogger('line_detection_ai').setLevel(logging.WARNING)
97
+ logging.getLogger('symbol_detection').setLevel(logging.WARNING)
98
+ logging.getLogger('text_detection').setLevel(logging.WARNING)
99
+
100
+ progress_text = []
101
+ outputs = [None] * 9
102
+
103
+ def update_progress(step, message):
104
+ timestamp = get_timestamp()
105
+ progress_text.append(f"{timestamp} - {message}")
106
+ outputs[7] = "\n".join(progress_text[-20:]) # Keep last 20 lines
107
+ progress(step, desc=message) # step is a 0-1 fraction; the message already names the step
108
+ return outputs
109
+
110
+ # Update progress with smaller steps
111
+ update_progress(0.1, "Starting processing...")
112
+ yield outputs
113
+
114
+ storage = StorageFactory.get_storage()
115
+ results_dir = "results"
117
+
118
+ if image_file is None:
119
+ raise ValueError("No file uploaded")
120
+
121
+ os.makedirs(results_dir, exist_ok=True)
124
+
125
+ # Step 1: File Upload (10%)
126
+ logger.info(f"Processing file: {os.path.basename(image_file)}")
127
+ update_progress(0.1, "Step 1/7: File uploaded successfully")
128
+ yield outputs
129
+
130
+ # Step 2: Document Processing (25%)
131
+ update_progress(0.25, "Step 2/7: Processing document...")
132
+ yield outputs
133
+
134
+ doc_processor = DocumentProcessor(storage)
135
+ processed_pages = doc_processor.process_document(
136
+ file_path=image_file,
137
+ output_dir=results_dir
138
+ )
139
+
140
+ if not processed_pages:
141
+ raise ValueError("No pages processed from document")
142
+
143
+ display_path = processed_pages[0]
144
+ outputs[0] = display_path
145
+ update_progress(0.25, "Document processed successfully")
146
+ yield outputs
147
+
148
+ # Step 3: Symbol Detection (45%)
149
+ update_progress(0.45, "Step 3/7: Symbol Detection")
150
+ yield outputs
151
+
152
+ # Store detection results and diagram_bbox
153
+ detection_results = run_detection_with_optimal_threshold(
154
+ display_path,
155
+ results_dir=results_dir,
156
+ file_name=os.path.basename(display_path),
157
+ resize_image=True,
158
+ storage=storage
159
+ )
160
+ detection_image_path, detection_json_path, _, diagram_bbox = detection_results
161
+
162
+ if diagram_bbox is None:
163
+ logger.warning("No diagram bounding box detected, using full image")
164
+ # Provide a fallback bbox if needed
165
+ diagram_bbox = [0, 0, 0, 0] # Or get image dimensions
166
+
167
+ outputs[1] = detection_image_path
168
+ update_progress(0.45, "Symbol detection completed")
169
+ yield outputs
170
+
171
+ # Step 4: Text Detection (65%)
172
+ update_progress(0.65, "Step 4/7: Text Detection")
173
+ yield outputs
174
+
175
+ text_results, text_summary = process_drawing(display_path, results_dir, storage)
176
+ outputs[2] = text_results['image_path']
177
+
178
+ update_progress(0.65, "Text detection completed")
179
+ update_progress(0.65, f"Found {text_summary['total_detections']} text elements")
180
+ yield outputs
181
+
182
+ # Step 5: Line Detection (80%)
183
+ update_progress(0.80, "Step 5/7: Line Detection")
184
+ yield outputs
185
+
186
+ try:
187
+ # Initialize components
188
+ debug_handler = DebugHandler(enabled=True, storage=storage)
189
+
190
+ # Configure detectors
191
+ line_config = LineConfig()
192
+ point_config = PointConfig()
193
+ junction_config = JunctionConfig()
194
+ symbol_config = SymbolConfig()
195
+ tag_config = TagConfig()
196
+
197
+ # Create all required detectors
198
+ symbol_detector = SymbolDetector(
199
+ config=symbol_config,
200
+ debug_handler=debug_handler
201
+ )
202
+
203
+ tag_detector = TagDetector(
204
+ config=tag_config,
205
+ debug_handler=debug_handler
206
+ )
207
+
208
+ line_detector = LineDetector(
209
+ config=line_config,
210
+ model_path="models/deeplsd_md.tar",
211
+ model_config={"detect_lines": True},
212
+ device=torch.device("cpu"),
213
+ debug_handler=debug_handler
214
+ )
215
+
216
+ point_detector = PointDetector(
217
+ config=point_config,
218
+ debug_handler=debug_handler
219
+ )
220
+
221
+ junction_detector = JunctionDetector(
222
+ config=junction_config,
223
+ debug_handler=debug_handler
224
+ )
225
+
226
+ # Create and run pipeline with all detectors
227
+ pipeline = DiagramDetectionPipeline(
228
+ tag_detector=tag_detector,
229
+ symbol_detector=symbol_detector,
230
+ line_detector=line_detector,
231
+ point_detector=point_detector,
232
+ junction_detector=junction_detector,
233
+ storage=storage,
234
+ debug_handler=debug_handler
235
+ )
236
+
237
+ # Run pipeline
238
+ result = pipeline.run(
239
+ image_path=display_path,
240
+ output_dir=results_dir,
241
+ config=ImageConfig()
242
+ )
243
+
244
+ if result.success:
245
+ line_image_path = result.image_path
246
+ line_json_path = result.json_path
247
+ outputs[3] = line_image_path
248
+ update_progress(0.80, "Line detection completed")
249
+ else:
250
+ logger.error(f"Pipeline failed: {result.error}")
251
+ raise Exception(result.error)
252
+
253
+ except Exception as e:
254
+ logger.error(f"Line detection error: {str(e)}")
255
+ raise
256
+
257
+ # Step 6: Data Aggregation (90%)
258
+ update_progress(0.90, "Step 6/7: Data Aggregation")
259
+ yield outputs
260
+
261
+ data_aggregator = DataAggregator(storage=storage)
262
+ aggregated_data = data_aggregator.aggregate_data(
263
+ symbols_path=detection_json_path,
264
+ texts_path=text_results['json_path'],
265
+ lines_path=line_json_path
266
+ )
267
+
268
+ # Add image path to aggregated data
269
+ aggregated_data['image_path'] = display_path
270
+
271
+ # Save aggregated data
272
+ aggregated_json_path = os.path.join(results_dir, f"{Path(display_path).stem}_aggregated.json")
273
+ with open(aggregated_json_path, 'w') as f:
274
+ json.dump(aggregated_data, f, indent=2)
275
+
276
+ # Use the detection image as the aggregated view for now
277
+ # TODO: Implement visualization in DataAggregator if needed
278
+ outputs[4] = detection_image_path # Changed from aggregated_image_path
279
+ outputs[8] = aggregated_json_path
280
+ update_progress(0.90, "Data aggregation completed")
281
+ yield outputs
282
+
283
+ # Step 7: Graph Generation (95%)
284
+ update_progress(0.95, "Step 7/7: Generating knowledge graph...")
285
+ yield outputs
286
+
287
+ try:
288
+ with open(aggregated_json_path, 'r') as f:
289
+ aggregated_detection_data = json.load(f)
290
+
291
+ logger.info("Creating knowledge graph...")
292
+
293
+ # Create graph visualization - this will save the visualization file
294
+ G, _ = create_graph_visualization(aggregated_json_path, save_plot=True)
295
+
296
+ if G is not None:
297
+ # Use the saved visualization file
298
+ graph_image_path = os.path.join(os.path.dirname(aggregated_json_path), "graph_visualization.png")
299
+
300
+ if os.path.exists(graph_image_path):
301
+ outputs[5] = graph_image_path
302
+ update_progress(0.95, "Knowledge graph generated")
303
+ logger.info("Knowledge graph generated and saved successfully")
304
+
305
+ # Final completion (100%)
306
+ update_progress(1.0, "✅ Processing Complete")
307
+ welcome_message = chat_message(
308
+ "agent",
309
+ "Processing complete! I can help answer questions about the P&ID contents.",
310
+ agent_avatar,
311
+ get_timestamp()
312
+ )
313
+ outputs[6] = welcome_message
314
+ update_progress(1.0, "✅ All processing steps completed successfully!")
315
+ yield outputs
316
+ else:
317
+ logger.warning("Graph visualization file not found")
318
+ update_progress(1.0, "⚠️ Warning: Graph visualization could not be generated")
319
+ yield outputs
320
+ else:
321
+ logger.warning("No graph was generated")
322
+ update_progress(1.0, "⚠️ Warning: No graph could be generated")
323
+ yield outputs
324
+
325
+ except Exception as e:
326
+ logger.error(f"Error in graph generation: {str(e)}")
327
+ logger.error(f"Traceback: {traceback.format_exc()}")
328
+ raise
329
+
330
+ except Exception as e:
331
+ logger.error(f"Error in process_pnid: {str(e)}")
332
+ logger.error(traceback.format_exc())
333
+ error_msg = f"❌ Error: {str(e)}"
334
+ update_progress(1.0, error_msg)
335
+ yield outputs
336
+
337
+ # Separate function for Chat interaction
338
+ def handle_user_message(user_input, chat_history, json_path_state):
339
+ """Handle user messages and generate responses."""
340
+ try:
341
+ if not user_input or not user_input.strip():
342
+ return chat_history
343
+
344
+ # Add user message
345
+ timestamp = get_timestamp()
346
+ new_history = chat_history + chat_message("user", user_input, user_avatar, timestamp)
347
+
348
+ # Check if json_path exists and is valid
349
+ if not json_path_state or not os.path.exists(json_path_state):
350
+ error_message = "Please upload and process a P&ID document first."
351
+ return new_history + chat_message("assistant", error_message, agent_avatar, get_timestamp())
352
+
353
+ try:
354
+ # Log for debugging
355
+ logger.info(f"Sending question to assistant: {user_input}")
356
+ logger.info(f"Using JSON path: {json_path_state}")
357
+
358
+ # Generate response
359
+ response = get_assistant_response(user_input, json_path_state)
360
+
361
+ # Handle the response
362
+ if isinstance(response, (str, dict)):
363
+ response_text = str(response)
364
+ else:
365
+ try:
366
+ # Try to get the first response from generator
367
+ response_text = next(response) if hasattr(response, '__next__') else str(response)
368
+ except StopIteration:
369
+ response_text = "I apologize, but I couldn't generate a response."
370
+ except Exception as e:
371
+ logger.error(f"Error processing response: {str(e)}")
372
+ response_text = "I apologize, but I encountered an error processing your request."
373
+
374
+ logger.info(f"Generated response: {response_text}")
375
+
376
+ if not response_text.strip():
377
+ response_text = "I apologize, but I couldn't generate a response. Please try asking your question differently."
378
+
379
+ # Add response to chat history
380
+ new_history += chat_message("assistant", response_text, agent_avatar, get_timestamp())
381
+
382
+ except Exception as e:
383
+ logger.error(f"Error generating response: {str(e)}")
384
+ logger.error(traceback.format_exc())
385
+ error_message = "I apologize, but I encountered an error processing your request. Please try again."
386
+ new_history += chat_message("assistant", error_message, agent_avatar, get_timestamp())
387
+
388
+ return new_history
389
+
390
+ except Exception as e:
391
+ logger.error(f"Chat error: {str(e)}")
392
+ logger.error(traceback.format_exc())
393
+ return chat_history + chat_message(
394
+ "assistant",
395
+ "I apologize, but something went wrong. Please try again.",
396
+ agent_avatar,
397
+ get_timestamp()
398
+ )
399
+
400
+ # Update custom CSS
401
+ custom_css = """
402
+ .full-height-row {
403
+ height: calc(100vh - 150px); /* Adjusted height */
404
+ margin: 0;
405
+ padding: 10px;
406
+ }
407
+ .upload-box {
408
+ background: #2a2a2a;
409
+ border-radius: 8px;
410
+ padding: 15px;
411
+ margin-bottom: 15px;
412
+ border: 1px solid #3a3a3a;
413
+ }
414
+ .status-box-container {
415
+ background: #2a2a2a;
416
+ border-radius: 8px;
417
+ padding: 15px;
418
+ height: calc(100vh - 350px); /* Reduced height */
419
+ border: 1px solid #3a3a3a;
420
+ margin-bottom: 15px;
421
+ }
422
+ .status-box {
423
+ font-family: 'Courier New', monospace;
424
+ font-size: 12px;
425
+ line-height: 1.4;
426
+ background-color: #1a1a1a;
427
+ color: #00ff00;
428
+ padding: 10px;
429
+ border-radius: 5px;
430
+ height: calc(100% - 40px); /* Adjust for header */
431
+ overflow-y: auto;
432
+ white-space: pre-wrap;
433
+ word-wrap: break-word;
434
+ border: none;
435
+ }
436
+ .preview-tabs {
437
+ height: calc(100vh - 350px); /* Reduced height */
438
+ background: #2a2a2a;
439
+ border-radius: 8px;
440
+ padding: 15px;
441
+ border: 1px solid #3a3a3a;
442
+ margin-bottom: 15px;
443
+ }
444
+ .chat-container {
445
+ height: 100%; /* Take full height */
446
+ display: flex;
447
+ flex-direction: column;
448
+ background: #2a2a2a;
449
+ border-radius: 8px;
450
+ padding: 15px;
451
+ border: 1px solid #3a3a3a;
452
+ }
453
+ .chatbox {
454
+ flex: 1; /* Take remaining space */
455
+ overflow-y: auto;
456
+ background: #1a1a1a;
457
+ border-radius: 8px;
458
+ padding: 15px;
459
+ margin-bottom: 15px;
460
+ color: #ffffff;
461
+ min-height: 200px; /* Ensure minimum height */
462
+ }
463
+ .chat-input-group {
464
+ height: auto; /* Allow natural height */
465
+ min-height: 120px; /* Minimum height for input area */
466
+ background: #1a1a1a;
467
+ border-radius: 8px;
468
+ padding: 15px;
469
+ margin-top: auto; /* Push to bottom */
470
+ }
471
+ .chat-input {
472
+ background: #2a2a2a;
473
+ color: #ffffff;
474
+ border: 1px solid #3a3a3a;
475
+ border-radius: 5px;
476
+ padding: 12px;
477
+ min-height: 80px;
478
+ width: 100%;
479
+ margin-bottom: 10px;
480
+ }
481
+ .send-button {
482
+ width: 100%;
483
+ background: #4a4a4a;
484
+ color: #ffffff;
485
+ border-radius: 5px;
486
+ border: none;
487
+ padding: 12px;
488
+ cursor: pointer;
489
+ transition: background-color 0.3s;
490
+ }
491
+ .result-image {
492
+ border-radius: 8px;
493
+ box-shadow: 0 2px 4px rgba(0,0,0,0.1);
494
+ margin: 10px 0;
495
+ background: #ffffff;
496
+ }
497
+ .chat-message {
498
+ display: flex;
499
+ margin-bottom: 1rem;
500
+ align-items: flex-start;
501
+ }
502
+ .chat-message .avatar {
503
+ width: 40px;
504
+ height: 40px;
505
+ margin-right: 10px;
506
+ border-radius: 50%;
507
+ }
508
+ .chat-message .speech-bubble {
509
+ background: #2a2a2a;
510
+ padding: 10px 15px;
511
+ border-radius: 10px;
512
+ max-width: 80%;
513
+ margin-bottom: 5px;
514
+ }
515
+ .chat-message .timestamp {
516
+ font-size: 0.8em;
517
+ color: #666;
518
+ }
519
+ .logo-row {
520
+ width: 100%;
521
+ background-color: #1a1a1a;
522
+ padding: 10px 0;
523
+ margin: 0;
524
+ border-bottom: 1px solid #3a3a3a;
525
+ }
526
+ """
527
+
528
+ def create_ui():
529
+     with gr.Blocks(css=custom_css) as demo:
+         # Logo row
+         with gr.Row(elem_classes=["logo-row"]):
+             try:
+                 logo_path = os.path.join(os.path.dirname(__file__), "assets", "intuigence.png")
+                 if os.path.exists(logo_path):
+                     with open(logo_path, "rb") as f:
+                         logo_base64 = base64.b64encode(f.read()).decode()
+                     gr.HTML(f"""
+                         <div style="text-align: center; padding: 10px; background-color: #1a1a1a; width: 100%;">
+                             <img src="data:image/png;base64,{logo_base64}"
+                                  alt="Intuigence Logo"
+                                  style="height: 60px; object-fit: contain;">
+                         </div>
+                     """)
+                 else:
+                     logger.warning(f"Logo not found at {logo_path}")
+             except Exception as e:
+                 logger.error(f"Error loading logo: {e}")
+
+         # Main layout
+         with gr.Row(equal_height=True, elem_classes=["full-height-row"]):
+             # Left column
+             with gr.Column(scale=2):
+                 # Upload area
+                 with gr.Column(elem_classes=["upload-box"]):
+                     image_input = gr.File(
+                         label="Upload P&ID Document",
+                         file_types=[".pdf", ".png", ".jpg", ".jpeg"],
+                         file_count="single",
+                         type="filepath"
+                     )
+
+                 # Status area
+                 with gr.Column(elem_classes=["status-box-container"]):
+                     gr.Markdown("### Processing Status")
+                     progress_status = gr.Textbox(
+                         label="Status",
+                         show_label=False,
+                         elem_classes=["status-box"],
+                         lines=15,
+                         max_lines=20,
+                         interactive=False,
+                         autoscroll=True,
+                         value=""  # Initialize with empty value
+                     )
+                     json_path_state = gr.State()
+
+             # Center column
+             with gr.Column(scale=5):
+                 with gr.Tabs(elem_classes=["preview-tabs"]) as tabs:
+                     with gr.TabItem("P&ID"):
+                         original_image = gr.Image(label="Original P&ID", height=450)  # Reduced height
+                     with gr.TabItem("Symbols"):
+                         symbol_image = gr.Image(label="Detected Symbols", height=450)
+                     with gr.TabItem("Tags"):
+                         text_image = gr.Image(label="Detected Tags", height=450)
+                     with gr.TabItem("Pipelines"):
+                         line_image = gr.Image(label="Detected Lines", height=450)
+                     with gr.TabItem("Aggregated"):
+                         aggregated_image = gr.Image(label="Aggregated Results", height=450)
+                     with gr.TabItem("Graph"):
+                         graph_image = gr.Image(label="Knowledge Graph", height=450)
+
+             # Right column
+             with gr.Column(scale=3):
+                 with gr.Column(elem_classes=["chat-container"]):
+                     gr.Markdown("### Chat Interface")
+                     # Initialize chat with a welcome message
+                     initial_chat = chat_message(
+                         "agent",
+                         "Ready to process P&ID documents and answer questions.",
+                         agent_avatar,
+                         get_timestamp()
+                     )
+                     chat_output = gr.HTML(
+                         label="Chat",
+                         elem_classes=["chatbox"],
+                         value=initial_chat
+                     )
+                     # Message input and send button in a fixed-height container
+                     with gr.Column(elem_classes=["chat-input-group"]):
+                         user_input = gr.Textbox(
+                             show_label=False,
+                             placeholder="Type your question here...",
+                             elem_classes=["chat-input"],
+                             lines=3
+                         )
+                         send_button = gr.Button(
+                             "Send",
+                             elem_classes=["send-button"]
+                         )
+
+         # Set up event handlers inside the Blocks context
+         image_input.upload(
+             fn=process_pnid,
+             inputs=[image_input, progress_status],
+             outputs=[
+                 original_image,
+                 symbol_image,
+                 text_image,
+                 line_image,
+                 aggregated_image,
+                 graph_image,
+                 chat_output,
+                 progress_status,
+                 json_path_state
+             ],
+             show_progress="hidden"  # Hide the default progress bar
+         )
+
+         # Clear the input after sending and forward the message to the handler
+         def clear_and_handle_message(user_message, chat_history, json_path):
+             response = handle_user_message(user_message, chat_history, json_path)
+             return "", response  # Clear input after sending
+
+         send_button.click(
+             fn=clear_and_handle_message,
+             inputs=[user_input, chat_output, json_path_state],
+             outputs=[user_input, chat_output]
+         )
+
+         # Also trigger on Enter key
+         user_input.submit(
+             fn=clear_and_handle_message,
+             inputs=[user_input, chat_output, json_path_state],
+             outputs=[user_input, chat_output]
+         )
+
+     return demo
+
+ def main():
+     demo = create_ui()
+     # Remove HF Spaces conditional, just use local development settings
+     demo.launch(server_name="0.0.0.0",
+                 server_port=7860,
+                 share=False)
+
+ if __name__ == "__main__":
+     main()
+ else:
+     # For Spaces deployment
+     demo = create_ui()
+     app = demo.app  # Gradio requires 'app' variable for Spaces
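
The chat wiring above clears the textbox by returning `""` as the first output mapped back to `user_input`. A minimal, framework-free sketch of that pattern (the `handler` argument here is hypothetical, standing in for `handle_user_message`):

```python
def clear_after_send(handler):
    """Wrap a chat handler so it returns a pair: an empty string (which,
    when mapped to the input textbox, clears it) and the handler's response."""
    def wrapped(message, history, state):
        response = handler(message, history, state)
        return "", response
    return wrapped
```

In the app both `send_button.click` and `user_input.submit` map outputs to `[user_input, chat_output]`, so the empty string lands in the textbox and the response in the chat panel.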
graph_construction.py ADDED
@@ -0,0 +1,113 @@
+ import os
+ import json
+ import networkx as nx
+ import matplotlib.pyplot as plt
+ from pathlib import Path
+ import logging
+ import traceback
+ from storage import StorageFactory
+
+ logger = logging.getLogger(__name__)
+
+ def construct_graph_network(data: dict, validation_results_path: str, results_dir: str, storage=None):
+     """Construct a network graph from aggregated detection data."""
+     try:
+         # Use the provided storage or get a new one
+         if storage is None:
+             storage = StorageFactory.get_storage()
+
+         # Create graph
+         G = nx.Graph()
+         pos = {}  # Node positions
+
+         # Add nodes from the aggregated data
+         for node in data.get('nodes', []):
+             node_id = node['id']
+             node_type = node['type']
+
+             # Calculate position based on node type
+             if node_type == 'connection_point':
+                 pos[node_id] = (node['coords']['x'], node['coords']['y'])
+             else:  # symbol or text
+                 bbox = node['bbox']
+                 pos[node_id] = (
+                     (bbox['xmin'] + bbox['xmax']) / 2,
+                     (bbox['ymin'] + bbox['ymax']) / 2
+                 )
+
+             # Add node with all its properties
+             G.add_node(node_id, **node)
+
+         # Add edges from the aggregated data
+         for edge in data.get('edges', []):
+             G.add_edge(
+                 edge['source'],
+                 edge['target'],
+                 **edge.get('properties', {})
+             )
+
+         # Create visualization
+         plt.figure(figsize=(20, 20))
+
+         # Draw nodes with different colors based on type
+         node_colors = []
+         for node in G.nodes():
+             node_type = G.nodes[node]['type']
+             if node_type == 'symbol':
+                 node_colors.append('lightblue')
+             elif node_type == 'text':
+                 node_colors.append('lightgreen')
+             else:  # connection_point
+                 node_colors.append('lightgray')
+
+         nx.draw_networkx_nodes(G, pos, node_color=node_colors, node_size=500)
+         nx.draw_networkx_edges(G, pos, edge_color='gray', width=1)
+
+         # Add labels
+         labels = {}
+         for node in G.nodes():
+             node_data = G.nodes[node]
+             if node_data['type'] == 'symbol':
+                 labels[node] = f"S:{node_data.get('properties', {}).get('class', '')}"
+             elif node_data['type'] == 'text':
+                 content = node_data.get('content', '')
+                 labels[node] = f"T:{content[:10]}..." if len(content) > 10 else f"T:{content}"
+             else:
+                 labels[node] = f"C:{node_data.get('properties', {}).get('point_type', '')}"
+
+         nx.draw_networkx_labels(G, pos, labels, font_size=8)
+
+         plt.title("P&ID Knowledge Graph")
+         plt.axis('off')
+
+         # Capture the figure before saving; calling plt.gcf() after
+         # plt.close() would return a new, empty figure.
+         fig = plt.gcf()
+
+         # Save the visualization
+         graph_image_path = os.path.join(results_dir, f"{Path(data.get('image_path', 'graph')).stem}_graph.png")
+         fig.savefig(graph_image_path, bbox_inches='tight', dpi=300)
+
+         # Save graph data as JSON for future use
+         graph_json_path = os.path.join(results_dir, f"{Path(data.get('image_path', 'graph')).stem}_graph_data.json")
+         with open(graph_json_path, 'w') as f:
+             json.dump(nx.node_link_data(G), f, indent=2)
+
+         # Note: the caller is responsible for closing the returned figure
+         return G, pos, fig
+
+     except Exception as e:
+         logger.error(f"Error in construct_graph_network: {str(e)}")
+         traceback.print_exc()
+         return None, None, None
+
+ if __name__ == "__main__":
+     # Test code
+     test_data_path = "results/test_aggregated.json"
+     if os.path.exists(test_data_path):
+         with open(test_data_path, 'r') as f:
+             test_data = json.load(f)
+
+         G, pos, fig = construct_graph_network(
+             test_data,
+             "results/validation.json",
+             "results"
+         )
+         if fig:
+             plt.show()
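
The node-placement rule in `construct_graph_network` (connection points carry explicit coordinates, symbols and texts sit at their bounding-box center) can be isolated as a small helper. A sketch of that rule, assuming the same node dictionaries as the aggregated JSON:

```python
def node_position(node: dict) -> tuple:
    """Position rule used when laying out the graph: connection points
    use their explicit coords; symbols and texts use their bbox center."""
    if node['type'] == 'connection_point':
        return (node['coords']['x'], node['coords']['y'])
    bbox = node['bbox']
    return ((bbox['xmin'] + bbox['xmax']) / 2,
            (bbox['ymin'] + bbox['ymax']) / 2)
```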
graph_processor.py ADDED
@@ -0,0 +1,115 @@
+ import json
+ import networkx as nx
+ import numpy as np
+ import matplotlib.pyplot as plt
+ import traceback
+ import uuid
+
+ def create_connected_graph(input_data):
+     """Create a connected graph from the input data."""
+     try:
+         # Validate input data structure
+         if not isinstance(input_data, dict):
+             raise ValueError("Invalid input data format")
+
+         # Check for required keys in the new format
+         required_keys = ['symbols', 'texts', 'lines', 'nodes', 'edges']
+         if not all(key in input_data for key in required_keys):
+             raise ValueError(f"Missing required keys in input data. Expected: {required_keys}")
+
+         # Create graph
+         G = nx.Graph()
+
+         # Track positions for layout
+         pos = {}
+
+         # Add symbol nodes
+         for symbol in input_data['symbols']:
+             bbox = symbol.get('bbox', [])
+             symbol_id = symbol.get('id', str(uuid.uuid4()))
+
+             if bbox:
+                 # Calculate center position
+                 center_x = (bbox['xmin'] + bbox['xmax']) / 2
+                 center_y = (bbox['ymin'] + bbox['ymax']) / 2
+                 pos[symbol_id] = (center_x, center_y)
+
+                 G.add_node(
+                     symbol_id,
+                     type='symbol',
+                     class_name=symbol.get('class', ''),
+                     bbox=bbox,
+                     confidence=symbol.get('confidence', 0.0)
+                 )
+
+         # Add text nodes
+         for text in input_data['texts']:
+             bbox = text.get('bbox', [])
+             text_id = text.get('id', str(uuid.uuid4()))
+
+             if bbox:
+                 center_x = (bbox['xmin'] + bbox['xmax']) / 2
+                 center_y = (bbox['ymin'] + bbox['ymax']) / 2
+                 pos[text_id] = (center_x, center_y)
+
+                 G.add_node(
+                     text_id,
+                     type='text',
+                     text=text.get('text', ''),
+                     bbox=bbox,
+                     confidence=text.get('confidence', 0.0)
+                 )
+
+         # Add edges from the edges list (only when both endpoints exist)
+         for edge in input_data['edges']:
+             source = edge.get('source')
+             target = edge.get('target')
+             if source and target and source in G and target in G:
+                 G.add_edge(
+                     source,
+                     target,
+                     type=edge.get('type', 'connection'),
+                     properties=edge.get('properties', {})
+                 )
+
+         # Create visualization
+         plt.figure(figsize=(20, 20))
+
+         # Draw nodes with fixed positions
+         nx.draw_networkx_nodes(G, pos,
+                                node_color=['lightblue' if G.nodes[node]['type'] == 'symbol' else 'lightgreen' for node in G.nodes()],
+                                node_size=500)
+
+         # Draw edges
+         nx.draw_networkx_edges(G, pos, edge_color='gray', width=1)
+
+         # Add labels
+         labels = {}
+         for node in G.nodes():
+             node_data = G.nodes[node]
+             if node_data['type'] == 'symbol':
+                 labels[node] = f"S:{node_data['class_name']}"
+             else:
+                 text = node_data.get('text', '')
+                 labels[node] = f"T:{text[:10]}..." if len(text) > 10 else f"T:{text}"
+
+         nx.draw_networkx_labels(G, pos, labels, font_size=8)
+
+         plt.title("P&ID Network Graph")
+         plt.axis('off')
+
+         return G, pos, plt.gcf()
+
+     except Exception as e:
+         print(f"Error in create_connected_graph: {str(e)}")
+         traceback.print_exc()
+         return None, None, None
+
+ if __name__ == "__main__":
+     # Test code
+     with open('results/0_aggregated.json') as f:
+         data = json.load(f)
+
+     G, pos, fig = create_connected_graph(data)
+     if fig:
+         plt.show()
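
The edge loop in `create_connected_graph` only adds an edge when both endpoints are already nodes, which prevents NetworkX from silently auto-creating dangling nodes for unknown IDs. That filter can be sketched on plain dictionaries:

```python
def valid_edges(edges, node_ids):
    """Keep only edges whose source and target both exist in node_ids,
    mirroring the guard used before G.add_edge above."""
    return [e for e in edges
            if e.get('source') in node_ids and e.get('target') in node_ids]
```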
graph_visualization.py ADDED
@@ -0,0 +1,235 @@
+ import json
+ import networkx as nx
+ import matplotlib.pyplot as plt
+ import os
+ from pprint import pprint
+ import uuid
+
+ def create_graph_visualization(json_path, save_plot=True):
+     """Create and visualize a graph from the aggregated JSON data."""
+
+     # Load the aggregated data
+     with open(json_path, 'r') as f:
+         data = json.load(f)
+
+     print("\nData Structure:")
+     print(f"Keys in data: {data.keys()}")
+     for key in data.keys():
+         if isinstance(data[key], list):
+             print(f"Number of {key}: {len(data[key])}")
+             if data[key]:
+                 print(f"Sample {key}:", data[key][0])
+
+     # Create a new graph
+     G = nx.Graph()
+     pos = {}
+
+     # Track unique junctions by coordinates to avoid duplicates
+     junction_map = {}
+
+     # Process lines and create unique junctions
+     print("\nProcessing Lines and Junctions:")
+     for line in data.get('lines', []):
+         try:
+             # Get line properties from edge data for accurate coordinates
+             edge_data = None
+             for edge in data.get('edges', []):
+                 if edge['id'] == line['id']:
+                     edge_data = edge
+                     break
+
+             # Get coordinates and create unique IDs for start/end points
+             if edge_data and 'connection_points' in edge_data['properties']:
+                 conn_points = edge_data['properties']['connection_points']
+                 start_x = int(conn_points['start']['x'])
+                 start_y = int(conn_points['start']['y'])
+                 end_x = int(conn_points['end']['x'])
+                 end_y = int(conn_points['end']['y'])
+                 start_id = str(uuid.uuid4())
+                 end_id = str(uuid.uuid4())
+             else:
+                 # Fall back to the line's own points
+                 start_x = int(line['start_point']['x'])
+                 start_y = int(line['start_point']['y'])
+                 end_x = int(line['end_point']['x'])
+                 end_y = int(line['end_point']['y'])
+                 start_id = line['start_point'].get('id', str(uuid.uuid4()))
+                 end_id = line['end_point'].get('id', str(uuid.uuid4()))
+
+             # Skip invalid coordinates
+             if not (0 <= start_x <= 10000 and 0 <= start_y <= 10000 and
+                     0 <= end_x <= 10000 and 0 <= end_y <= 10000):
+                 print(f"Skipping line with invalid coordinates: ({start_x}, {start_y}) -> ({end_x}, {end_y})")
+                 continue
+
+             # Create or reuse junction nodes
+             start_key = f"{start_x}_{start_y}"
+             end_key = f"{end_x}_{end_y}"
+
+             if start_key not in junction_map:
+                 junction_map[start_key] = start_id
+                 G.add_node(start_id,
+                            type='junction',
+                            junction_type=line['start_point'].get('type', 'unknown'),
+                            coords={'x': start_x, 'y': start_y})
+                 pos[start_id] = (start_x, start_y)
+
+             if end_key not in junction_map:
+                 junction_map[end_key] = end_id
+                 G.add_node(end_id,
+                            type='junction',
+                            junction_type=line['end_point'].get('type', 'unknown'),
+                            coords={'x': end_x, 'y': end_y})
+                 pos[end_id] = (end_x, end_y)
+
+             # Add the line as an edge with its style properties
+             G.add_edge(junction_map[start_key], junction_map[end_key],
+                        type='line',
+                        style=line.get('style', {}))
+
+         except Exception as e:
+             print(f"Error processing line: {str(e)}")
+             continue
+
+     # Add symbols and texts after lines
+     print("\nProcessing Symbols and Texts:")
+     for node in data.get('nodes', []):
+         if node['type'] in ['symbol', 'text']:
+             node_id = node['id']
+             coords = node.get('coords', {})
+             if coords:
+                 x, y = coords['x'], coords['y']
+             elif 'center' in node:
+                 x, y = node['center']['x'], node['center']['y']
+             else:
+                 bbox = node['bbox']
+                 x = (bbox['xmin'] + bbox['xmax']) / 2
+                 y = (bbox['ymin'] + bbox['ymax']) / 2
+
+             G.add_node(node_id, **node)
+             pos[node_id] = (x, y)
+             print(f"Added {node['type']} at ({x}, {y})")
+
+     # Add a default node if the graph is empty
+     if not pos:
+         default_id = str(uuid.uuid4())
+         G.add_node(default_id, type='junction', coords={'x': 0, 'y': 0})
+         pos[default_id] = (0, 0)
+
+     # Scale positions to fit in the [0, 1] range
+     x_vals = [p[0] for p in pos.values()]
+     y_vals = [p[1] for p in pos.values()]
+     x_min, x_max = min(x_vals), max(x_vals)
+     y_min, y_max = min(y_vals), max(y_vals)
+
+     scaled_pos = {}
+     for node, (x, y) in pos.items():
+         scaled_x = (x - x_min) / (x_max - x_min) if x_max > x_min else 0.5
+         scaled_y = 1 - ((y - y_min) / (y_max - y_min) if y_max > y_min else 0.5)  # Flip Y coordinates
+         scaled_pos[node] = (scaled_x, scaled_y)
+
+     # Visualization attributes
+     node_colors = []
+     node_sizes = []
+     labels = {}
+
+     for node in G.nodes():
+         node_data = G.nodes[node]
+         if node_data['type'] == 'symbol':
+             node_colors.append('lightblue')
+             node_sizes.append(1000)
+             labels[node] = f"S:{node_data.get('properties', {}).get('class', 'unknown')}"
+         elif node_data['type'] == 'text':
+             node_colors.append('lightgreen')
+             node_sizes.append(800)
+             content = node_data.get('content', '')
+             labels[node] = f"T:{content[:10]}..." if len(content) > 10 else f"T:{content}"
+         else:  # junction
+             node_colors.append('#ff0000')  # Pure red
+             node_sizes.append(5)  # Small junction nodes
+             labels[node] = ""  # No labels for junctions
+
+     # Render the visualization
+     if save_plot:
+         plt.figure(figsize=(20, 20))
+
+         # Draw edges with styles
+         edge_styles = []
+         for (u, v, edge_attrs) in G.edges(data=True):
+             if edge_attrs.get('type') == 'line':
+                 style = edge_attrs.get('style', {})
+                 color = style.get('color', '#000000')
+                 width = float(style.get('stroke_width', 0.5))
+                 alpha = 0.7
+                 line_style = '--' if style.get('connection_type') == 'dashed' else '-'
+
+                 nx.draw_networkx_edges(G, scaled_pos,
+                                        edgelist=[(u, v)],
+                                        edge_color=color,
+                                        width=width,
+                                        alpha=alpha,
+                                        style=line_style)
+
+                 # Track unique styles for the legend
+                 edge_style = (line_style, color, width)
+                 if edge_style not in edge_styles:
+                     edge_styles.append(edge_style)
+
+         # Draw nodes with smaller junctions
+         nx.draw_networkx_nodes(G, scaled_pos,
+                                node_color=node_colors,
+                                node_size=[3 if size == 5 else size for size in node_sizes],
+                                alpha=1.0)
+
+         # Add labels only for symbols and texts
+         labels = {k: v for k, v in labels.items() if v}
+         nx.draw_networkx_labels(G, scaled_pos, labels,
+                                 font_size=8,
+                                 font_weight='bold')
+
+         # Create a comprehensive legend
+         legend_elements = []
+
+         # Node types
+         legend_elements.extend([
+             plt.scatter([0], [0], c='lightblue', s=200, label='Symbol'),
+             plt.scatter([0], [0], c='lightgreen', s=200, label='Text'),
+             plt.scatter([0], [0], c='red', s=20, label='Junction')
+         ])
+
+         # Line styles
+         for style, color, width in edge_styles:
+             legend_elements.append(
+                 plt.Line2D([0], [0], color=color, linestyle=style,
+                            linewidth=width, label=f'Line ({style})')
+             )
+
+         # Add the legend in a single column beside the plot
+         plt.legend(handles=legend_elements,
+                    loc='center left',
+                    bbox_to_anchor=(1, 0.5),
+                    ncol=1,
+                    fontsize=12,
+                    title="Graph Elements")
+
+         plt.title("P&ID Knowledge Graph Visualization", pad=20, fontsize=16)
+         plt.axis('on')
+         plt.grid(True)
+
+         # Save with extra space for the legend
+         output_path = os.path.join(os.path.dirname(json_path), "graph_visualization.png")
+         plt.savefig(output_path, bbox_inches='tight', dpi=300,
+                     facecolor='white', edgecolor='none')
+         plt.close()
+
+     return G, scaled_pos
+
+ if __name__ == "__main__":
+     # Test the visualization
+     json_path = "results/001_page_1_text_aggregated.json"
+
+     if os.path.exists(json_path):
+         G, scaled_pos = create_graph_visualization(json_path)
+         plt.show()
+     else:
+         print(f"Error: Could not find {json_path}")
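
The min-max scaling step in `create_graph_visualization` maps pixel positions into [0, 1] and flips Y, because image coordinates grow downward while matplotlib's grow upward. A self-contained sketch of that transform:

```python
def scale_positions(pos):
    """Min-max scale a {node: (x, y)} dict to [0, 1] and flip Y.
    Degenerate axes (all points share a coordinate) collapse to 0.5,
    matching the behavior of the code above."""
    xs = [p[0] for p in pos.values()]
    ys = [p[1] for p in pos.values()]
    x_min, x_max = min(xs), max(xs)
    y_min, y_max = min(ys), max(ys)
    scaled = {}
    for node, (x, y) in pos.items():
        sx = (x - x_min) / (x_max - x_min) if x_max > x_min else 0.5
        sy = 1 - ((y - y_min) / (y_max - y_min) if y_max > y_min else 0.5)
        scaled[node] = (sx, sy)
    return scaled
```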
line_detection_ai.py ADDED
@@ -0,0 +1,410 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from base import BaseDetector, BaseDetectionPipeline
2
+ from utils import *
3
+ from config import (
4
+ ImageConfig,
5
+ SymbolConfig,
6
+ TagConfig,
7
+ LineConfig,
8
+ PointConfig,
9
+ JunctionConfig
10
+ )
11
+ from detectors import (
12
+ LineDetector,
13
+ PointDetector,
14
+ JunctionDetector,
15
+ SymbolDetector,
16
+ TagDetector
17
+ )
18
+ from pathlib import Path
19
+ from storage import StorageFactory
20
+ from common import DetectionResult
21
+ from detection_schema import DetectionContext, JunctionType
22
+ from typing import List, Tuple, Optional, Dict
23
+ import torch
24
+ import numpy as np
25
+ import cv2
26
+ import os
27
+ import logging
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
+ class DiagramDetectionPipeline:
32
+ """
33
+ Pipeline that runs multiple detectors (line, point, junction, etc.) on an image,
34
+ and keeps a shared DetectionContext in memory.
35
+ """
36
+
37
+ def __init__(self,
38
+ tag_detector: Optional[BaseDetector],
39
+ symbol_detector: Optional[BaseDetector],
40
+ line_detector: Optional[BaseDetector],
41
+ point_detector: Optional[BaseDetector],
42
+ junction_detector: Optional[BaseDetector],
43
+ storage: StorageInterface,
44
+ debug_handler: Optional[DebugHandler] = None,
45
+ transformer: Optional[CoordinateTransformer] = None):
46
+ """
47
+ You can pass None for detectors you don't need.
48
+ """
49
+ # super().__init__(storage=storage, debug_handler=debug_handler)
50
+ self.storage = storage
51
+ self.debug_handler = debug_handler
52
+ self.tag_detector = tag_detector
53
+ self.symbol_detector = symbol_detector
54
+ self.line_detector = line_detector
55
+ self.point_detector = point_detector
56
+ self.junction_detector = junction_detector
57
+ self.transformer = transformer or CoordinateTransformer()
58
+
59
+ def _load_image(self, image_path: str) -> np.ndarray:
60
+ """Load image with validation."""
61
+ image_data = self.storage.load_file(image_path)
62
+ image = cv2.imdecode(np.frombuffer(image_data, np.uint8), cv2.IMREAD_COLOR)
63
+ if image is None:
64
+ raise ValueError(f"Failed to load image from {image_path}")
65
+ return image
66
+
67
+ def _crop_to_roi(self, image: np.ndarray, roi: Optional[list]) -> Tuple[np.ndarray, Tuple[int, int]]:
68
+ """Crop to ROI if provided, else return full image."""
69
+ if roi is not None and len(roi) == 4:
70
+ x_min, y_min, x_max, y_max = roi
71
+ return image[y_min:y_max, x_min:x_max], (x_min, y_min)
72
+ return image, (0, 0)
73
+
74
+ def _remove_symbol_tag_bboxes(self, image: np.ndarray, context: DetectionContext) -> np.ndarray:
75
+ """Fill symbol & tag bounding boxes with white to avoid line detection picking them up."""
76
+ masked = image.copy()
77
+ for sym in context.symbols.values():
78
+ cv2.rectangle(masked,
79
+ (sym.bbox.xmin, sym.bbox.ymin),
80
+ (sym.bbox.xmax, sym.bbox.ymax),
81
+ (255, 255, 255), # White
82
+ thickness=-1)
83
+
84
+ for tg in context.tags.values():
85
+ cv2.rectangle(masked,
86
+ (tg.bbox.xmin, tg.bbox.ymin),
87
+ (tg.bbox.xmax, tg.bbox.ymax),
88
+ (255, 255, 255),
89
+ thickness=-1)
90
+ return masked
91
+
92
+ def run(
93
+ self,
94
+ image_path: str,
95
+ output_dir: str,
96
+ config
97
+ ) -> DetectionResult:
98
+ """
99
+ Main pipeline steps (in local coords):
100
+ 1) Load + crop image
101
+ 2) Detect symbols & tags
102
+ 3) Make a copy for final debug images
103
+ 4) White out symbol/tag bounding boxes
104
+ 5) Detect lines, points, junctions
105
+ 6) Save final JSON
106
+ 7) Generate debug images with various combinations
107
+ """
108
+ try:
109
+ with self.debug_handler.track_performance("total_processing"):
110
+ # 1) Load & crop
111
+ image = self._load_image(image_path)
112
+ cropped_image, roi_offset = self._crop_to_roi(image, config.roi)
113
+
114
+ # 2) Create fresh context
115
+ context = DetectionContext()
116
+
117
+ # 3) Detect symbols
118
+ with self.debug_handler.track_performance("symbol_detection"):
119
+ self.symbol_detector.detect(
120
+ cropped_image,
121
+ context=context,
122
+ roi_offset=roi_offset
123
+ )
124
+
125
+ # 4) Detect tags
126
+ with self.debug_handler.track_performance("tag_detection"):
127
+ self.tag_detector.detect(
128
+ cropped_image,
129
+ context=context,
130
+ roi_offset=roi_offset
131
+ )
132
+
133
+ # Make a copy of the cropped image for final debug combos
134
+ debug_cropped = cropped_image.copy()
135
+
136
+ # 5) White-out symbol/tag bboxes in the original cropped image
137
+ cropped_image = self._remove_symbol_tag_bboxes(cropped_image, context)
138
+
139
+ # 6) Detect lines
140
+ with self.debug_handler.track_performance("line_detection"):
141
+ self.line_detector.detect(cropped_image, context=context)
142
+
143
+ # 7) Detect points
144
+ if self.point_detector:
145
+ with self.debug_handler.track_performance("point_detection"):
146
+ self.point_detector.detect(cropped_image, context=context)
147
+
148
+ # 8) Detect junctions
149
+ if self.junction_detector:
150
+ with self.debug_handler.track_performance("junction_detection"):
151
+ self.junction_detector.detect(cropped_image, context=context)
152
+
153
+ # 9) Save final JSON & any final images
154
+ output_paths = self._persist_results(output_dir, image_path, context)
155
+
156
+ # 10) Save debug images in local coords using debug_cropped
157
+ self._save_all_combinations(debug_cropped, context, output_dir, image_path)
158
+
159
+ return DetectionResult(
160
+ success=True,
161
+ processing_time=self.debug_handler.metrics.get('total_processing', 0),
162
+ json_path=output_paths.get('json_path'),
163
+ image_path=output_paths.get('image_path') # Now returning the annotated image path
164
+ )
165
+
166
+ except Exception as e:
167
+ logger.error(f"Processing failed: {str(e)}")
168
+ return DetectionResult(
169
+ success=False,
170
+ error=str(e)
171
+ )
172
+
173
+ # ------------------------------------------------
174
+ # HELPER FUNCTIONS
175
+ # ------------------------------------------------
176
+ def _persist_results(self, output_dir: str, image_path: str, context: DetectionContext) -> dict:
177
+ """Saves final JSON and debug images to disk."""
178
+ self.storage.create_directory(output_dir)
179
+ base_name = Path(image_path).stem
180
+
181
+ # Save JSON
182
+ json_path = Path(output_dir) / f"{base_name}_detected_lines.json"
183
+ context_json_str = context.to_json(indent=2)
184
+ self.storage.save_file(str(json_path), context_json_str.encode('utf-8'))
185
+
186
+ # Save annotated image for pipeline display
187
+ annotated_image = self._draw_objects(
188
+ self._load_image(image_path),
189
+ context,
190
+ draw_lines=True,
191
+ draw_points=True,
192
+ draw_symbols=True,
193
+ draw_junctions=True,
194
+ draw_tags=True
195
+ )
196
+ image_path = Path(output_dir) / f"{base_name}_annotated.jpg"
197
+ _, encoded = cv2.imencode('.jpg', annotated_image)
198
+ self.storage.save_file(str(image_path), encoded.tobytes())
199
+
200
+ return {
201
+ "json_path": str(json_path),
202
+ "image_path": str(image_path)
203
+ }
204
+
205
+ def _save_all_combinations(self, local_image: np.ndarray, context: DetectionContext,
206
+ output_dir: str, image_path: str) -> None:
207
+ """Produce debug images with different combinations."""
208
+ base_name = Path(image_path).stem
209
+ base_name = base_name.split("_")[0]
210
+ combos = [
211
+ ("text_detected_symbols", dict(draw_symbols=True, draw_tags=False, draw_lines=False, draw_points=False, draw_junctions=False)),
212
+ ("text_detected_texts", dict(draw_symbols=False, draw_tags=True, draw_lines=False, draw_points=False, draw_junctions=False)),
213
+ ("text_detected_lines", dict(draw_symbols=False, draw_tags=False, draw_lines=True, draw_points=False, draw_junctions=False)),
214
+ ]
215
+ self.storage.create_directory(output_dir)
216
+
217
+ for combo_name, flags in combos:
218
+ annotated = self._draw_objects(local_image, context, **flags)
219
+ save_name = f"{base_name}_{combo_name}.jpg"
220
+ save_path = Path(output_dir) / save_name
221
+ _, encoded = cv2.imencode('.jpg', annotated)
222
+ self.storage.save_file(str(save_path), encoded.tobytes())
223
+ logger.info(f"Saved debug image: {save_path}")
224
+
225
+ def _draw_objects(self, base_image: np.ndarray, context: DetectionContext,
226
+ draw_lines: bool = True, draw_points: bool = True,
227
+ draw_symbols: bool = True, draw_junctions: bool = True,
228
+ draw_tags: bool = True) -> np.ndarray:
229
+ """Draw detection results on a copy of base_image in local coords."""
230
+ annotated = base_image.copy()
231
+
232
+ # Lines
233
+ if draw_lines:
234
+ for ln in context.lines.values():
235
+ cv2.line(annotated,
236
+ (ln.start.coords.x, ln.start.coords.y),
237
+ (ln.end.coords.x, ln.end.coords.y),
238
+ (0, 255, 0), # green
239
+ 2)
240
+
241
+ # Points
242
+ if draw_points:
243
+ for pt in context.points.values():
244
+ cv2.circle(annotated,
245
+ (pt.coords.x, pt.coords.y),
246
+ 3,
247
+ (0, 0, 255), # red
248
+ -1)
249
+
250
+ # Symbols
251
+ if draw_symbols:
252
+ for sym in context.symbols.values():
253
+ cv2.rectangle(annotated,
254
+ (sym.bbox.xmin, sym.bbox.ymin),
255
+ (sym.bbox.xmax, sym.bbox.ymax),
256
+ (255, 255, 0), # cyan
257
+ 2)
258
+ cv2.circle(annotated,
259
+ (sym.center.x, sym.center.y),
260
+ 4,
261
+ (255, 0, 255), # magenta
262
+ -1)
263
+
264
+ # Junctions
265
+ if draw_junctions:
266
+ for jn in context.junctions.values():
267
+ if jn.junction_type == JunctionType.T:
268
+ color = (0, 165, 255) # orange
269
+ elif jn.junction_type == JunctionType.L:
270
+ color = (255, 0, 255) # magenta
271
+ else: # END
272
+ color = (0, 0, 255) # red
273
+ cv2.circle(annotated,
274
+ (jn.center.x, jn.center.y),
275
+ 5,
276
+ color,
277
+ -1)
278
+
279
+ # Tags
280
+ if draw_tags:
281
+ for tg in context.tags.values():
282
+ cv2.rectangle(annotated,
283
+ (tg.bbox.xmin, tg.bbox.ymin),
284
+ (tg.bbox.xmax, tg.bbox.ymax),
285
+ (128, 0, 128), # purple
286
+ 2)
287
+ cv2.putText(annotated,
288
+ tg.text,
289
+ (tg.bbox.xmin, tg.bbox.ymin - 5),
290
+ cv2.FONT_HERSHEY_SIMPLEX,
291
+ 0.5,
292
+ (128, 0, 128),
293
+ 1)
294
+
295
+ return annotated
296
+
297
+ def detect_lines(self, image_path: str, output_dir: str, config: Optional[Dict] = None) -> Dict:
298
+ """Legacy interface for line detection"""
299
+ storage = StorageFactory.get_storage()
300
+ debug_handler = DebugHandler(enabled=True, storage=storage)
301
+
302
+ line_detector = LineDetector(
303
+ config=LineConfig(),
304
+ model_path="models/deeplsd_md.tar",
305
+ device=torch.device("cpu"),
306
+ debug_handler=debug_handler
307
+ )
308
+
309
+ pipeline = DiagramDetectionPipeline(
310
+ tag_detector=None,
311
+ symbol_detector=None,
312
+ line_detector=line_detector,
313
+ point_detector=None,
314
+             junction_detector=None,
+             storage=storage,
+             debug_handler=debug_handler
+         )
+
+         result = pipeline.run(image_path, output_dir, ImageConfig())
+         return result
+
+     def _validate_and_normalize_coordinates(self, points):
+         """Validate points and cast their coordinates to integer image space."""
+         valid_points = []
+         for point in points:
+             x, y = point['x'], point['y']
+             # Keep only coordinates that fall within the image bounds
+             if 0 <= x <= self.image_width and 0 <= y <= self.image_height:
+                 valid_points.append({
+                     'x': int(x),
+                     'y': int(y),
+                     'type': point.get('type', 'unknown'),
+                     'confidence': point.get('confidence', 1.0)
+                 })
+         return valid_points
+
+ if __name__ == "__main__":
+     # 1) Initialize components
+     storage = StorageFactory.get_storage()
+     debug_handler = DebugHandler(enabled=True, storage=storage)
+
+     # 2) DeepLSD line-detection model configuration
+     conf = {
+         "detect_lines": True,
+         "line_detection_params": {
+             "merge": True,
+             "filtering": True,
+             "grad_thresh": 3,
+             "grad_nfa": True
+         }
+     }
+
+     # 3) Configure
+     line_config = LineConfig()
+     point_config = PointConfig()
+     junction_config = JunctionConfig()
+     symbol_config = SymbolConfig()
+     tag_config = TagConfig()
+
+     # ========================== Detectors ========================== #
+     symbol_detector = SymbolDetector(
+         config=symbol_config,
+         debug_handler=debug_handler
+     )
+
+     tag_detector = TagDetector(
+         config=tag_config,
+         debug_handler=debug_handler
+     )
+
+     line_detector = LineDetector(
+         config=line_config,
+         model_path="models/deeplsd_md.tar",
+         model_config=conf,
+         device=torch.device("cpu"),  # or "cuda" if available
+         debug_handler=debug_handler
+     )
+
+     point_detector = PointDetector(
+         config=point_config,
+         debug_handler=debug_handler
+     )
+
+     junction_detector = JunctionDetector(
+         config=junction_config,
+         debug_handler=debug_handler
+     )
+
+     # 4) Create pipeline
+     pipeline = DiagramDetectionPipeline(
+         tag_detector=tag_detector,
+         symbol_detector=symbol_detector,
+         line_detector=line_detector,
+         point_detector=point_detector,
+         junction_detector=junction_detector,
+         storage=storage,
+         debug_handler=debug_handler
+     )
+
+     # 5) Run pipeline
+     result = pipeline.run(
+         image_path="samples/images/0.jpg",
+         output_dir="results/",
+         config=ImageConfig()
+     )
+
+     if result.success:
+         logger.info(f"Pipeline succeeded! See JSON at {result.json_path}")
+     else:
+         logger.error(f"Pipeline failed: {result.error}")
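The bounds-check-and-cast logic of `_validate_and_normalize_coordinates` can be exercised standalone. This is a minimal sketch with a hypothetical free function (no pipeline classes required); the point dictionaries follow the same `x`/`y`/`type`/`confidence` shape used above:

```python
def validate_points(points, image_width, image_height):
    """Keep only points inside [0, width] x [0, height]; cast coordinates to int."""
    valid = []
    for p in points:
        x, y = p['x'], p['y']
        if 0 <= x <= image_width and 0 <= y <= image_height:
            valid.append({
                'x': int(x),
                'y': int(y),
                'type': p.get('type', 'unknown'),
                'confidence': p.get('confidence', 1.0),
            })
    return valid

pts = [{'x': 10.5, 'y': 20.2}, {'x': -5, 'y': 30}, {'x': 100, 'y': 300}]
# Only the first point survives: x=-5 is negative, y=300 exceeds the height of 200
print(validate_points(pts, 200, 200))
```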
packages.txt ADDED
@@ -0,0 +1,3 @@
+ libgl1-mesa-glx
+ libglib2.0-0
+ tesseract-ocr
pdf_processor.py ADDED
@@ -0,0 +1,288 @@
+ import fitz  # PyMuPDF
+ import os
+ import logging
+ from pathlib import Path
+ import numpy as np
+ from PIL import Image
+ import io
+ import cv2
+ from storage import StorageInterface
+ from typing import List, Dict, Tuple, Any
+ import json
+ from text_detection_combined import process_drawing
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ class DocumentProcessor:
+     def __init__(self, storage: StorageInterface):
+         self.storage = storage
+         self.logger = logging.getLogger(__name__)
+
+         # Optimal processing parameters
+         self.target_dpi = 600  # increased from 300 to 600 DPI
+         self.min_dimension = 2000  # minimum width/height
+         self.max_dimension = 8000  # increased max dimension for the higher DPI
+         self.quality = 95  # JPEG quality for saving
+
+     def process_document(self, file_path: str, output_dir: str) -> list:
+         """Process a document (PDF/PNG/JPG) and return paths to the processed pages."""
+         file_ext = Path(file_path).suffix.lower()
+
+         if file_ext == '.pdf':
+             return self._process_pdf(file_path, output_dir)
+         elif file_ext in ['.png', '.jpg', '.jpeg']:
+             return self._process_image(file_path, output_dir)
+         else:
+             raise ValueError(f"Unsupported file format: {file_ext}")
+
+     def _process_pdf(self, pdf_path: str, output_dir: str) -> list:
+         """Process a PDF document."""
+         processed_pages = []
+         processing_results = {}
+
+         try:
+             # Create the output directory if it doesn't exist
+             os.makedirs(output_dir, exist_ok=True)
+
+             # Clean up any existing files for this document
+             base_name = Path(pdf_path).stem
+             for file in os.listdir(output_dir):
+                 if file.startswith(base_name) and file != os.path.basename(pdf_path):
+                     file_path = os.path.join(output_dir, file)
+                     try:
+                         if os.path.isfile(file_path):
+                             os.unlink(file_path)
+                     except Exception as e:
+                         self.logger.error(f"Error deleting file {file_path}: {e}")
+
+             # Read the PDF file directly since it's already in the results directory
+             with open(pdf_path, 'rb') as f:
+                 pdf_data = f.read()
+
+             doc = fitz.open(stream=pdf_data, filetype="pdf")
+
+             for page_num in range(len(doc)):
+                 page = doc[page_num]
+
+                 # Zoom factor for the target DPI (PDF user space is 72 points per inch)
+                 zoom = self.target_dpi / 72
+                 matrix = fitz.Matrix(zoom, zoom)
+
+                 # Render a high-resolution image
+                 pix = page.get_pixmap(matrix=matrix)
+                 img_data = pix.tobytes()
+
+                 # Convert to a numpy array
+                 nparr = np.frombuffer(img_data, np.uint8)
+                 img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
+
+                 # Create the base filename
+                 base_filename = f"{Path(pdf_path).stem}_page_{page_num + 1}"
+
+                 # Process and save the different versions
+                 optimized_versions = {
+                     'text': self._optimize_for_text(img.copy()),
+                     'symbol': self._optimize_for_symbols(img.copy()),
+                     'line': self._optimize_for_lines(img.copy())
+                 }
+
+                 paths = {
+                     'text': os.path.join(output_dir, f"{base_filename}_text.png"),
+                     'symbol': os.path.join(output_dir, f"{base_filename}_symbol.png"),
+                     'line': os.path.join(output_dir, f"{base_filename}_line.png")
+                 }
+
+                 # Save each version
+                 for version_type, optimized_img in optimized_versions.items():
+                     self._save_image(optimized_img, paths[version_type])
+                     processed_pages.append(paths[version_type])
+
+                 # Store the processing results
+                 processing_results[str(page_num + 1)] = {
+                     "page_number": page_num + 1,
+                     "dimensions": {
+                         "width": img.shape[1],
+                         "height": img.shape[0]
+                     },
+                     "paths": paths,
+                     "dpi": self.target_dpi,
+                     "zoom_factor": zoom
+                 }
+
+             # Save the processing results as JSON
+             results_json_path = os.path.join(
+                 output_dir,
+                 f"{Path(pdf_path).stem}_processing_results.json"
+             )
+             with open(results_json_path, 'w') as f:
+                 json.dump(processing_results, f, indent=4)
+
+             return processed_pages
+
+         except Exception as e:
+             self.logger.error(f"Error processing PDF: {str(e)}")
+             raise
+
+     def _process_image(self, image_path: str, output_dir: str) -> list:
+         """Process a single image file."""
+         try:
+             # Load the image
+             image_data = self.storage.load_file(image_path)
+             nparr = np.frombuffer(image_data, np.uint8)
+             img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
+
+             # Process the image
+             processed_img = self._optimize_image(img)
+
+             # Save the processed image
+             output_path = os.path.join(
+                 output_dir,
+                 f"{Path(image_path).stem}_text.png"
+             )
+             self._save_image(processed_img, output_path)
+
+             return [output_path]
+
+         except Exception as e:
+             self.logger.error(f"Error processing image: {str(e)}")
+             raise
+
+     def _optimize_image(self, img: np.ndarray) -> np.ndarray:
+         """Optimize an image for the best detection results."""
+         # Convert to grayscale for processing
+         gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+
+         # Enhance contrast
+         clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
+         enhanced = clahe.apply(gray)
+
+         # Denoise
+         denoised = cv2.fastNlMeansDenoising(enhanced)
+
+         # Binarize
+         _, binary = cv2.threshold(denoised, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
+
+         # Resize while maintaining the aspect ratio
+         height, width = binary.shape
+         scale = min(self.max_dimension / max(width, height),
+                     max(self.min_dimension / min(width, height), 1.0))
+
+         if scale != 1.0:
+             new_width = int(width * scale)
+             new_height = int(height * scale)
+             resized = cv2.resize(binary, (new_width, new_height),
+                                  interpolation=cv2.INTER_LANCZOS4)
+         else:
+             resized = binary
+
+         # Convert back to BGR for compatibility
+         return cv2.cvtColor(resized, cv2.COLOR_GRAY2BGR)
+
+     def _optimize_for_text(self, img: np.ndarray) -> np.ndarray:
+         """Optimize an image for text detection."""
+         # Convert to grayscale
+         gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+
+         # Enhance contrast using CLAHE
+         clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
+         enhanced = clahe.apply(gray)
+
+         # Denoise
+         denoised = cv2.fastNlMeansDenoising(enhanced)
+
+         # Adaptive thresholding for better text separation
+         binary = cv2.adaptiveThreshold(denoised, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
+                                        cv2.THRESH_BINARY, 11, 2)
+
+         # Convert back to BGR
+         return cv2.cvtColor(binary, cv2.COLOR_GRAY2BGR)
+
+     def _optimize_for_symbols(self, img: np.ndarray) -> np.ndarray:
+         """Optimize an image for symbol detection."""
+         # Convert to grayscale
+         gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+
+         # Bilateral filter to preserve edges while reducing noise
+         bilateral = cv2.bilateralFilter(gray, 9, 75, 75)
+
+         # Enhance contrast
+         clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
+         enhanced = clahe.apply(bilateral)
+
+         # Sharpen the image
+         kernel = np.array([[-1, -1, -1],
+                            [-1,  9, -1],
+                            [-1, -1, -1]])
+         sharpened = cv2.filter2D(enhanced, -1, kernel)
+
+         # Convert back to BGR
+         return cv2.cvtColor(sharpened, cv2.COLOR_GRAY2BGR)
+
+     def _optimize_for_lines(self, img: np.ndarray) -> np.ndarray:
+         """Optimize an image for line detection."""
+         # Convert to grayscale
+         gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+
+         # Reduce noise while preserving edges
+         denoised = cv2.GaussianBlur(gray, (3, 3), 0)
+
+         # Edge enhancement
+         edges = cv2.Canny(denoised, 50, 150)
+
+         # Dilate the edges to connect broken lines
+         kernel = np.ones((2, 2), np.uint8)
+         dilated = cv2.dilate(edges, kernel, iterations=1)
+
+         # Convert back to BGR
+         return cv2.cvtColor(dilated, cv2.COLOR_GRAY2BGR)
+
+     def _save_image(self, img: np.ndarray, output_path: str):
+         """Save a processed image with optimal quality."""
+         # Encode the image losslessly
+         _, buffer = cv2.imencode('.png', img, [
+             cv2.IMWRITE_PNG_COMPRESSION, 0
+         ])
+
+         # Save to storage
+         self.storage.save_file(output_path, buffer.tobytes())
+
+ if __name__ == "__main__":
+     from storage import StorageFactory
+     import shutil
+
+     # Initialize the storage and processor
+     storage = StorageFactory.get_storage()
+     processor = DocumentProcessor(storage)
+
+     # Process a PDF
+     pdf_path = "samples/001.pdf"
+     output_dir = "results"
+
+     try:
+         # Ensure the output directory exists
+         os.makedirs(output_dir, exist_ok=True)
+
+         results = processor.process_document(
+             file_path=pdf_path,
+             output_dir=output_dir
+         )
+
+         # Print detailed results
+         print("\nProcessing Results:")
+         print(f"Output Directory: {os.path.abspath(output_dir)}")
+
+         for page_path in results:
+             abs_path = os.path.abspath(page_path)
+             file_size = os.path.getsize(page_path) / (1024 * 1024)  # bytes -> MB
+             print(f"- {os.path.basename(page_path)} ({file_size:.2f} MB)")
+
+         # Calculate the total size of the output
+         total_size = sum(os.path.getsize(os.path.join(output_dir, f))
+                          for f in os.listdir(output_dir)) / (1024 * 1024)
+         print(f"\nTotal output size: {total_size:.2f} MB")
+
+     except Exception as e:
+         logger.error(f"Error processing PDF: {str(e)}")
+         raise
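The 600 DPI render in `_process_pdf` hinges on one conversion: PDF pages are laid out at 72 points per inch, so the zoom factor is `target_dpi / 72`. A quick standalone sketch of the resulting pixel dimensions (the helper name and the A4 page size are illustrative, not part of this module):

```python
def render_size(page_width_pts, page_height_pts, target_dpi):
    """Pixel dimensions when a PDF page is rasterized at target_dpi.

    PDF user space is 72 points per inch, so zoom = target_dpi / 72.
    """
    zoom = target_dpi / 72
    return round(page_width_pts * zoom), round(page_height_pts * zoom)

# An A4 page is 595 x 842 points; at 600 DPI this renders to a ~5000 x 7000 pixel image
print(render_size(595, 842, 600))  # (4958, 7017)
```

This is why `max_dimension` was raised to 8000 alongside the DPI increase: at 300 DPI the same page would only be about 2480 x 3508 pixels.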
requirements.txt ADDED
@@ -0,0 +1,39 @@
+ # Core dependencies
+ gradio>=4.0.0
+ numpy>=1.24.0
+ Pillow>=8.0.0
+ opencv-python-headless>=4.8.0
+ PyMuPDF>=1.18.0  # PDF processing
+
+ # OCR engines
+ pytesseract>=0.3.8
+ easyocr>=1.7.1
+ python-doctr>=0.7.0  # DocTR OCR
+ tensorflow>=2.8.0  # required by DocTR
+
+ # Deep learning
+ torch>=2.1.0
+ torchvision>=0.15.0
+ ultralytics>=8.0.0  # YOLO models
+ deeplsd  # line detection
+ omegaconf>=2.3.0  # required by DeepLSD
+
+ # Graph processing
+ networkx>=2.6.0
+ plotly>=5.3.0
+
+ # Utilities
+ tqdm>=4.66.0
+ python-dotenv>=0.19.0
+ shapely>=1.8.0  # geometry operations
+
+ # Azure Storage
+ azure-storage-blob>=12.0.0
+ azure-core>=1.24.0
+
+ # AI/Chat
+ openai>=1.0.0  # ChatGPT integration
+ loguru>=0.7.0
+ matplotlib>=3.4.0
+
+ requests>=2.31.0
setup.py ADDED
@@ -0,0 +1,24 @@
+ from setuptools import setup, find_packages
+
+ setup(
+     name="pid-processor",
+     version="0.1.0",
+     packages=find_packages(),
+     install_requires=[
+         line.strip()
+         for line in open("requirements.txt").readlines()
+         if line.strip() and not line.startswith("#")
+     ],
+     author="Your Name",
+     author_email="your.email@example.com",
+     description="P&ID Processing with AI-Powered Graph Construction",
+     long_description=open("README.md").read(),
+     long_description_content_type="text/markdown",
+     url="https://github.com/yourusername/your-repo-name",
+     classifiers=[
+         "Programming Language :: Python :: 3",
+         "License :: OSI Approved :: MIT License",
+         "Operating System :: OS Independent",
+     ],
+     python_requires=">=3.8",
+ )
storage.py ADDED
@@ -0,0 +1,208 @@
+ import os
+ import shutil
+ from azure.storage.blob import BlobServiceClient
+ from abc import ABC, abstractmethod
+ import json
+
+ class StorageInterface(ABC):
+
+     @abstractmethod
+     def save_file(self, file_path: str, content: bytes) -> str:
+         pass
+
+     @abstractmethod
+     def load_file(self, file_path: str) -> bytes:
+         pass
+
+     @abstractmethod
+     def list_files(self, directory: str) -> list[str]:
+         pass
+
+     @abstractmethod
+     def file_exists(self, file_path: str) -> bool:
+         pass
+
+     @abstractmethod
+     def delete_file(self, file_path: str) -> None:
+         pass
+
+     @abstractmethod
+     def create_directory(self, directory: str) -> None:
+         pass
+
+     @abstractmethod
+     def delete_directory(self, directory: str) -> None:
+         pass
+
+     @abstractmethod
+     def upload(self, local_path: str, destination_path: str) -> None:
+         pass
+
+     @abstractmethod
+     def append_file(self, file_path: str, content: bytes) -> None:
+         pass
+
+     @abstractmethod
+     def get_modified_time(self, file_path: str) -> float:
+         pass
+
+     @abstractmethod
+     def directory_exists(self, directory: str) -> bool:
+         pass
+
+     def load_json(self, file_path):
+         """Load and parse a JSON file via the storage backend."""
+         try:
+             raw = self.load_file(file_path)
+             data = json.loads(raw)
+             return data
+         except Exception as e:
+             print(f"Error loading JSON from {file_path}: {str(e)}")
+             return None
+
+
+ class LocalStorage(StorageInterface):
+
+     def save_file(self, file_path: str, content: bytes) -> str:
+         os.makedirs(os.path.dirname(file_path), exist_ok=True)
+         with open(file_path, 'wb') as f:
+             f.write(content)
+         return file_path
+
+     def load_file(self, file_path: str) -> bytes:
+         with open(file_path, 'rb') as f:
+             return f.read()
+
+     def list_files(self, directory: str) -> list[str]:
+         return [f for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f))]
+
+     def file_exists(self, file_path: str) -> bool:
+         return os.path.exists(file_path)
+
+     def delete_file(self, file_path: str) -> None:
+         os.remove(file_path)
+
+     def create_directory(self, directory: str) -> None:
+         os.makedirs(directory, exist_ok=True)
+
+     def delete_directory(self, directory: str) -> None:
+         shutil.rmtree(directory)
+
+     def upload(self, local_path: str, destination_path: str) -> None:
+         os.makedirs(os.path.dirname(destination_path), exist_ok=True)
+         shutil.copy(local_path, destination_path)
+
+     def append_file(self, file_path: str, content: bytes) -> None:
+         os.makedirs(os.path.dirname(file_path), exist_ok=True)
+         with open(file_path, 'ab') as f:
+             f.write(content)
+
+     def get_modified_time(self, file_path: str) -> float:
+         return os.path.getmtime(file_path)
+
+     def directory_exists(self, directory: str) -> bool:
+         return self.file_exists(directory)
+
+
+ class BlobStorage(StorageInterface):
+     """
+     Writes to blob storage, using the local disk as a cache.
+
+     TODO: Allow configuration of a temp dir instead of using the same paths both locally and remotely.
+     """
+     def __init__(self, connection_string: str, container_name: str):
+         self.blob_service_client = BlobServiceClient.from_connection_string(connection_string)
+         self.container_client = self.blob_service_client.get_container_client(container_name)
+         self.local_storage = LocalStorage()
+
+     def download(self, file_path: str) -> bytes:
+         blob_client = self.container_client.get_blob_client(file_path)
+         return blob_client.download_blob().readall()
+
+     def sync(self, file_path: str) -> None:
+         if not self.local_storage.file_exists(file_path):
+             print(f"DEBUG: missing local version of {file_path} - downloading")
+             self.local_storage.save_file(file_path, self.download(file_path))
+         else:
+             local_timestamp = self.local_storage.get_modified_time(file_path)
+             remote_timestamp = self.get_modified_time(file_path)
+             if local_timestamp < remote_timestamp:
+                 # We always write remotely before writing locally, so we expect local_timestamp > remote_timestamp
+                 print(f"DEBUG: local version of {file_path} out of date - downloading")
+                 self.local_storage.save_file(file_path, self.download(file_path))
+
+
+     def save_file(self, file_path: str, content: bytes) -> str:
+         blob_client = self.container_client.get_blob_client(file_path)
+         blob_client.upload_blob(content, overwrite=True)
+         self.local_storage.save_file(file_path, content)
+         return file_path
+
+     def load_file(self, file_path: str) -> bytes:
+         self.sync(file_path)
+         return self.local_storage.load_file(file_path)
+
+     def list_files(self, directory: str) -> list[str]:
+         return [blob.name for blob in self.container_client.list_blobs(name_starts_with=directory)]
+
+     def file_exists(self, file_path: str) -> bool:
+         blob_client = self.container_client.get_blob_client(file_path)
+         return blob_client.exists()
+
+     def delete_file(self, file_path: str) -> None:
+         self.local_storage.delete_file(file_path)
+         blob_client = self.container_client.get_blob_client(file_path)
+         blob_client.delete_blob()
+
+     def create_directory(self, directory: str) -> None:
+         # Blob storage doesn't have directories, so only create it locally
+         self.local_storage.create_directory(directory)
+
+     def delete_directory(self, directory: str) -> None:
+         self.local_storage.delete_directory(directory)
+         blobs_to_delete = self.container_client.list_blobs(name_starts_with=directory)
+         for blob in blobs_to_delete:
+             self.container_client.delete_blob(blob.name)
+
+     def upload(self, local_path: str, destination_path: str) -> None:
+         with open(local_path, "rb") as data:
+             blob_client = self.container_client.get_blob_client(destination_path)
+             blob_client.upload_blob(data, overwrite=True)
+         self.local_storage.upload(local_path, destination_path)
+
+     def append_file(self, file_path: str, content: bytes) -> None:
+         blob_client = self.container_client.get_blob_client(file_path)
+         if not blob_client.exists():
+             blob_client.create_append_blob()
+         else:
+             self.sync(file_path)
+
+         blob_client.append_block(content)
+         self.local_storage.append_file(file_path, content)
+
+     def get_modified_time(self, file_path: str) -> float:
+         blob_client = self.container_client.get_blob_client(file_path)
+         properties = blob_client.get_blob_properties()
+         # Convert the UTC datetime to a UNIX timestamp
+         return properties.last_modified.timestamp()
+
+     def directory_exists(self, directory: str) -> bool:
+         blobs = self.container_client.list_blobs(name_starts_with=directory)
+         return next(blobs, None) is not None
+
+
+ class StorageFactory:
+     @staticmethod
+     def get_storage() -> StorageInterface:
+         storage_type = os.getenv('STORAGE_TYPE', 'local').lower()
+         if storage_type == 'local':
+             return LocalStorage()
+         elif storage_type == 'blob':
+             connection_string = os.getenv('AZURE_STORAGE_CONNECTION_STRING')
+             container_name = os.getenv('AZURE_STORAGE_CONTAINER_NAME')
+             if not connection_string or not container_name:
+                 raise ValueError("Azure Blob Storage connection string and container name must be set")
+             return BlobStorage(connection_string, container_name)
+         else:
+             raise ValueError(f"Unsupported storage type: {storage_type}")
+
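A minimal usage sketch of the byte-oriented storage contract defined in storage.py, mirroring `LocalStorage`'s save/load/append round-trip. This is a standalone re-implementation using only the standard library (the class name is hypothetical; it does not import the module above):

```python
import os
import tempfile

class MiniLocalStorage:
    """Standalone mirror of LocalStorage's byte-oriented save/load/append contract."""

    def save_file(self, path: str, content: bytes) -> str:
        # Parent directories are created on demand, matching LocalStorage
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, 'wb') as f:
            f.write(content)
        return path

    def load_file(self, path: str) -> bytes:
        with open(path, 'rb') as f:
            return f.read()

    def append_file(self, path: str, content: bytes) -> None:
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, 'ab') as f:
            f.write(content)

storage = MiniLocalStorage()
root = tempfile.mkdtemp()
p = storage.save_file(os.path.join(root, "results", "log.txt"), b"hello ")
storage.append_file(p, b"world")
print(storage.load_file(p))  # b'hello world'
```

Everything flows through `bytes`, which is what lets `BlobStorage` reuse `LocalStorage` as its on-disk cache without format conversions.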
symbol_detection.py ADDED
@@ -0,0 +1,454 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import json
3
+ import uuid
4
+ import os
5
+ import logging
6
+ from ultralytics import YOLO
7
+ from tqdm import tqdm
8
+ from storage import StorageInterface
9
+ import numpy as np
10
+ from typing import Tuple, List, Dict, Any
11
+
12
+ # Configure logging
13
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
14
+
15
+ # Constants
16
+ MODEL_PATHS = {
17
+ "model1": "models/Intui_SDM_41.pt",
18
+ "model2": "models/Intui_SDM_20.pt" # Add your second model path here
19
+ }
20
+ MAX_DIMENSION = 1280
21
+ CONFIDENCE_THRESHOLDS = [0.1, 0.3, 0.5, 0.7, 0.9]
22
+ TEXT_COLOR = (0, 0, 255) # Red color for text
23
+ BOX_COLOR = (255, 0, 0) # Red color for box (no transparency)
24
+ BG_COLOR = (255, 255, 255, 0.6) # Semi-transparent white for text background
25
+ THICKNESS = 1 # Thin text thickness
26
+ BOX_THICKNESS = 2 # Box line thickness
27
+ MIN_FONT_SCALE = 0.2 # Minimum font scale
28
+ MAX_FONT_SCALE = 1.0 # Maximum font scale
29
+ TEXT_PADDING = 20 # Increased padding between text elements
30
+ OVERLAP_THRESHOLD = 0.3 # Threshold for detecting text overlap
31
+
32
+ def preprocess_image_for_symbol_detection(image_cv: np.ndarray) -> np.ndarray:
33
+ """Preprocess the image for symbol detection."""
34
+ gray = cv2.cvtColor(image_cv, cv2.COLOR_BGR2GRAY)
35
+ equalized = cv2.equalizeHist(gray)
36
+ filtered = cv2.bilateralFilter(equalized, 9, 75, 75)
37
+ edges = cv2.Canny(filtered, 100, 200)
38
+ preprocessed_image = cv2.cvtColor(edges, cv2.COLOR_GRAY2BGR)
39
+ return preprocessed_image
40
+
41
+ def evaluate_detections(detections_list: List[Dict[str, Any]]) -> int:
42
+ """Evaluate the quality of detections."""
43
+ return len(detections_list)
44
+
45
+ def resize_image_with_aspect_ratio(image_cv: np.ndarray, max_dimension: int) -> Tuple[np.ndarray, int, int]:
46
+ """Resize the image while maintaining the aspect ratio."""
47
+ original_height, original_width = image_cv.shape[:2]
48
+ if max(original_width, original_height) > max_dimension:
49
+ scale = max_dimension / float(max(original_width, original_height))
50
+ new_width = int(original_width * scale)
51
+ new_height = int(original_height * scale)
52
+ image_cv = cv2.resize(image_cv, (new_width, new_height), interpolation=cv2.INTER_LINEAR)
53
+ else:
54
+ new_width, new_height = original_width, original_height
55
+ return image_cv, new_width, new_height
56
+
57
+ def merge_detections(all_detections: List[Dict]) -> List[Dict]:
58
+ """
59
+ Merge detections from all models, keeping only the highest confidence detection
60
+ when duplicates are found using IoU.
61
+ """
62
+ if not all_detections:
63
+ return []
64
+
65
+ # Sort by confidence
66
+ all_detections.sort(key=lambda x: x['confidence'], reverse=True)
67
+
68
+ # Keep track of which detections to keep
69
+ keep = [True] * len(all_detections)
70
+
71
+ def calculate_iou(box1, box2):
72
+ """Calculate Intersection over Union (IoU) between two boxes."""
73
+ x1 = max(box1[0], box2[0])
74
+ y1 = max(box1[1], box2[1])
75
+ x2 = min(box1[2], box2[2])
76
+ y2 = min(box1[3], box2[3])
77
+
78
+ intersection = max(0, x2 - x1) * max(0, y2 - y1)
79
+ area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
80
+ area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
81
+ union = area1 + area2 - intersection
82
+
83
+ return intersection / union if union > 0 else 0
84
+
85
+ # Apply NMS and keep only highest confidence detection
86
+ for i in range(len(all_detections)):
87
+ if not keep[i]:
88
+ continue
89
+
90
+ current_box = all_detections[i]['bbox']
91
+ current_label = all_detections[i]['original_label']
92
+
93
+ for j in range(i + 1, len(all_detections)):
94
+ if not keep[j]:
95
+ continue
96
+
97
+ # Check if same label type and high IoU
98
+ if (all_detections[j]['original_label'] == current_label and
99
+ calculate_iou(current_box, all_detections[j]['bbox']) > 0.5):
100
+ # Since list is sorted by confidence, i will always have higher confidence than j
101
+ keep[j] = False
102
+ logging.info(f"Removing duplicate detection of {current_label} with lower confidence "
103
+ f"({all_detections[j]['confidence']:.2f} < {all_detections[i]['confidence']:.2f})")
104
+
105
+ # Add kept detections to final list
106
+ merged_detections = [det for i, det in enumerate(all_detections) if keep[i]]
107
+ return merged_detections
108
+
109
+ def calculate_font_scale(image_width: int, bbox_width: int) -> float:
110
+ """
111
+ Calculate appropriate font scale based on image and bbox dimensions.
112
+ """
113
+ base_scale = 0.7 # Increased base scale for better visibility
114
+
115
+ # Adjust font size based on image width and bbox width
116
+ width_ratio = image_width / MAX_DIMENSION
117
+ bbox_ratio = bbox_width / image_width
118
+
119
+ # Calculate adaptive scale with increased multipliers
120
+ adaptive_scale = base_scale * max(width_ratio, 0.5) * max(bbox_ratio * 6, 0.7)
121
+
122
+ # Ensure font scale stays within reasonable bounds
123
+ return min(max(adaptive_scale, MIN_FONT_SCALE), MAX_FONT_SCALE)
124
+
125
+ def check_overlap(rect1, rect2):
126
+ """Check if two rectangles overlap."""
127
+ x1_1, y1_1, x2_1, y2_1 = rect1
128
+ x1_2, y1_2, x2_2, y2_2 = rect2
129
+
130
+ return not (x2_1 < x1_2 or x1_1 > x2_2 or y2_1 < y1_2 or y1_1 > y2_2)
131
+
132
+ def draw_annotation(
133
+ image: np.ndarray,
134
+ bbox: List[int],
135
+ text: str,
136
+ confidence: float,
137
+ model_source: str,
138
+ existing_annotations: List[tuple] = None
139
+ ) -> None:
140
+ """
141
+ Draw annotation with no background and thin fonts.
142
+ """
143
+ if existing_annotations is None:
144
+ existing_annotations = []
145
+
146
+ x1, y1, x2, y2 = bbox
147
+ bbox_width = x2 - x1
148
+ image_width = image.shape[1]
149
+ image_height = image.shape[0]
150
+
151
+ # Calculate adaptive font scale
152
+ font_scale = calculate_font_scale(image_width, bbox_width)
153
+
154
+ # Simplify the annotation text
155
+ annotation_text = f'{text}\n{confidence:.0f}%'
156
+ lines = annotation_text.split('\n')
157
+
158
+ # Calculate text dimensions
159
+ font = cv2.FONT_HERSHEY_SIMPLEX
160
+ max_width = 0
161
+ total_height = 0
162
+ line_heights = []
163
+
164
+ for line in lines:
165
+ (width, height), baseline = cv2.getTextSize(
166
+ line, font, font_scale, THICKNESS
167
+ )
168
+ max_width = max(max_width, width)
169
+ line_height = height + baseline + TEXT_PADDING
170
+ line_heights.append(line_height)
171
+ total_height += line_height
172
+
173
+ # Calculate initial text position with increased padding
174
+ padding = TEXT_PADDING
175
+ rect_x1 = max(0, x1 - padding)
176
+ rect_x2 = min(image_width, x1 + max_width + padding * 2)
177
+
178
+ # Try different positions to avoid overlap
179
+ positions = [
180
+ ('top', y1 - total_height - padding),
181
+ ('bottom', y2 + padding),
182
+ ('top_shifted', y1 - total_height - padding * 2),
183
+ ('bottom_shifted', y2 + padding * 2)
184
+ ]
185
+
186
+ final_position = None
187
+ for pos_name, y_pos in positions:
188
+ if y_pos < 0 or y_pos + total_height > image_height:
189
+ continue
190
+
191
+ rect = (rect_x1, y_pos, rect_x2, y_pos + total_height)
192
+ overlap = False
193
+
194
+ for existing_rect in existing_annotations:
195
+ if check_overlap(rect, existing_rect):
196
+ overlap = True
197
+ break
198
+
199
+ if not overlap:
200
+ final_position = (pos_name, y_pos)
201
+ existing_annotations.append(rect)
202
+ break
203
+
204
+ # If no non-overlapping position found, use side position
205
+ if final_position is None:
206
+ rect_x1 = max(0, x1 + bbox_width + padding)
207
+ rect_x2 = min(image_width, rect_x1 + max_width + padding * 2)
208
+ y_pos = y1
209
+ final_position = ('side', y_pos)
210
+
211
+ rect_y1 = final_position[1]
212
+
213
+ # Draw bounding box (no transparency)
214
+ cv2.rectangle(image, (x1, y1), (x2, y2), BOX_COLOR, BOX_THICKNESS)
215
+
216
+ # Draw text directly without background
217
+ text_y = rect_y1 + line_heights[0] - padding
218
+ for i, line in enumerate(lines):
219
+ # Draw text with thin lines
220
+ cv2.putText(
221
+ image,
222
+ line,
223
+ (rect_x1 + padding, text_y + sum(line_heights[:i])),
224
+ font,
225
+ font_scale,
226
+ TEXT_COLOR,
227
+ THICKNESS,
228
+ cv2.LINE_AA
229
+ )
230
+
231
+ def run_detection_with_optimal_threshold(
232
+ image_path: str,
233
+ results_dir: str = "results",
234
+ file_name: str = "",
235
+ apply_preprocessing: bool = False,
236
+ resize_image: bool = True, # Changed default to True
237
+ storage: StorageInterface = None
238
+ ) -> Tuple[str, str, str, List[int]]:
239
+ """Run detection with multiple models and merge results."""
240
+ try:
241
+ image_data = storage.load_file(image_path)
242
+ nparr = np.frombuffer(image_data, np.uint8)
243
+ original_image_cv = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
244
+ image_cv = original_image_cv.copy()
245
+
246
+ if resize_image:
247
+ logging.info("Resizing image for detection with aspect ratio...")
248
+ image_cv, resized_width, resized_height = resize_image_with_aspect_ratio(image_cv, MAX_DIMENSION)
249
+ else:
250
+ logging.info("Skipping image resizing...")
251
+ resized_height, resized_width = original_image_cv.shape[:2]
252
+
253
+ if apply_preprocessing:
254
+ logging.info("Preprocessing image for symbol detection...")
255
+ image_cv = preprocess_image_for_symbol_detection(image_cv)
256
+ else:
257
+ logging.info("Skipping image preprocessing for symbol detection...")
258
+
259
+ all_detections = []
260
+
261
+         # Run detection with each model
+         for model_name, model_path in MODEL_PATHS.items():
+             logging.info(f"Running detection with model: {model_name}")
+ 
+             if not model_path:
+                 logging.warning(f"No model path found for {model_name}")
+                 continue
+ 
+             model = YOLO(model_path)
+ 
+             # Predict once per model; the thresholds below only filter the boxes,
+             # so rerunning inference for each threshold would be redundant.
+             results = model.predict(source=image_cv, imgsz=MAX_DIMENSION)
+ 
+             best_confidence_threshold = 0.5
+             best_detections_list = []
+             best_metric = -1
+ 
+             for confidence_threshold in CONFIDENCE_THRESHOLDS:
+                 logging.info(f"Filtering detections with confidence threshold: {confidence_threshold}...")
+ 
+                 detections_list = []
+                 for result in results:
+                     for box in result.boxes:
+                         confidence = float(box.conf[0])
+                         if confidence < confidence_threshold:
+                             continue
+                         x1, y1, x2, y2 = map(float, box.xyxy[0])
+                         class_id = int(box.cls[0])
+                         label = result.names[class_id]
+ 
+                         # Scale boxes from the resized image back to the original resolution
+                         scale_x = original_image_cv.shape[1] / resized_width
+                         scale_y = original_image_cv.shape[0] / resized_height
+                         x1 *= scale_x
+                         x2 *= scale_x
+                         y1 *= scale_y
+                         y2 *= scale_y
+                         x1, y1, x2, y2 = map(int, [x1, y1, x2, y2])
+ 
+                         # Labels follow a "category_type_name" convention
+                         split_label = label.split('_')
+                         if len(split_label) >= 3:
+                             category = split_label[0]
+                             type_ = split_label[1]
+                             new_label = '_'.join(split_label[2:])
+                         elif len(split_label) == 2:
+                             category = split_label[0]
+                             type_ = split_label[1]
+                             new_label = split_label[1]
+                         elif len(split_label) == 1:
+                             category = split_label[0]
+                             type_ = "Unknown"
+                             new_label = split_label[0]
+                         else:
+                             logging.warning(f"Unexpected label format: {label}. Skipping this detection.")
+                             continue
+ 
+                         detection_id = str(uuid.uuid4())
+                         detection_info = {
+                             "symbol_id": detection_id,
+                             "class_id": class_id,
+                             "original_label": label,
+                             "category": category,
+                             "type": type_,
+                             "label": new_label,
+                             "confidence": confidence,
+                             "bbox": [x1, y1, x2, y2],
+                             "model_source": model_name
+                         }
+                         detections_list.append(detection_info)
+ 
+                 metric = evaluate_detections(detections_list)
+                 if metric > best_metric:
+                     best_metric = metric
+                     best_confidence_threshold = confidence_threshold
+                     best_detections_list = detections_list
+ 
+             all_detections.extend(best_detections_list)
+ 
+         # Merge detections from all models
+         merged_detections = merge_detections(all_detections)
+         logging.info(f"Total detections after merging: {len(merged_detections)}")
+ 
+         # Draw annotations on the image
+         existing_annotations = []
+         for det in merged_detections:
+             draw_annotation(
+                 original_image_cv,
+                 det["bbox"],
+                 det["original_label"],
+                 det["confidence"] * 100,
+                 det["model_source"],
+                 existing_annotations
+             )
+ 
+         # Save results
+         storage.create_directory(results_dir)
+         file_name_without_extension = os.path.splitext(file_name)[0]
+ 
+         # Prepare output JSON
+         total_detected_symbols = len(merged_detections)
+         class_counts = {}
+         for det in merged_detections:
+             full_label = det["original_label"]
+             class_counts[full_label] = class_counts.get(full_label, 0) + 1
+ 
+         output_json = {
+             "total_detected_symbols": total_detected_symbols,
+             "details": class_counts,
+             "detections": merged_detections
+         }
+ 
+         # Save JSON
+         detection_json_path = os.path.join(
+             results_dir, f'{file_name_without_extension}_detected_symbols.json'
+         )
+         storage.save_file(
+             detection_json_path,
+             json.dumps(output_json, indent=4).encode('utf-8')
+         )
+ 
+         # Calculate the overall diagram bbox from the merged detections
+         # (coordinates refer to the original image resolution)
+         diagram_bbox = [
+             min([det['bbox'][0] for det in merged_detections], default=0),
+             min([det['bbox'][1] for det in merged_detections], default=0),
+             max([det['bbox'][2] for det in merged_detections], default=0),
+             max([det['bbox'][3] for det in merged_detections], default=0)
+         ]
+ 
+         # Upscale small images before saving so annotations stay legible
+         min_width = 2000  # Minimum width for good visibility
+         if original_image_cv.shape[1] < min_width:
+             scale_factor = min_width / original_image_cv.shape[1]
+             new_height = int(original_image_cv.shape[0] * scale_factor)
+             original_image_cv = cv2.resize(
+                 original_image_cv,
+                 (min_width, new_height),
+                 interpolation=cv2.INTER_CUBIC
+             )
+ 
+         # Save as lossless PNG for maximum quality
+         detection_image_path = os.path.join(
+             results_dir, f'{file_name_without_extension}_detected_symbols.png'
+         )
+         encode_params = [
+             cv2.IMWRITE_PNG_COMPRESSION, 0  # No compression for PNG
+         ]
+         _, img_encoded = cv2.imencode(
+             '.png',
+             original_image_cv,
+             encode_params
+         )
+         storage.save_file(detection_image_path, img_encoded.tobytes())
+ 
+         return (
+             detection_image_path,
+             detection_json_path,
+             f"Total detections after merging: {total_detected_symbols}",
+             diagram_bbox
+         )
+     except Exception as e:
+         logging.error(f"An error occurred: {e}")
+         return "Error during detection", None, None, None
+ 
+ if __name__ == "__main__":
+     from storage import StorageFactory
+ 
+     logging.basicConfig(level=logging.INFO)
+ 
+     uploaded_file_path = "processed_pages/10219-1-DG-BC-00011.01-REV_A_page_1_text.png"
+     results_dir = "results"
+     apply_symbol_preprocessing = False
+     resize_image = True
+ 
+     storage = StorageFactory.get_storage()
+ 
+     (
+         detection_image_path,
+         detection_json_path,
+         detection_log_message,
+         diagram_bbox
+     ) = run_detection_with_optimal_threshold(
+         uploaded_file_path,
+         results_dir=results_dir,
+         file_name=os.path.basename(uploaded_file_path),
+         apply_preprocessing=apply_symbol_preprocessing,
+         resize_image=resize_image,
+         storage=storage
+     )
+ 
+     logging.info("Detection Image Path: %s", detection_image_path)
+     logging.info("Detection JSON Path: %s", detection_json_path)
+     logging.info("Detection Log Message: %s", detection_log_message)
+     logging.info("Diagram BBox: %s", diagram_bbox)
+     logging.info("Done!")
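`merge_detections` is imported from elsewhere in the repo and is not part of this commit. For context, here is a minimal sketch of one plausible greedy, IoU-based cross-model merge; the name `merge_detections_sketch` and the 0.5 threshold are illustrative assumptions, not the repo's actual implementation:

```python
def iou(a, b):
    """Intersection-over-union of two [x1, y1, x2, y2] boxes."""
    x1, y1 = max(a[0], b[0]), max(a[1], b[1])
    x2, y2 = min(a[2], b[2]), min(a[3], b[3])
    if x2 <= x1 or y2 <= y1:
        return 0.0
    inter = (x2 - x1) * (y2 - y1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter)

def merge_detections_sketch(detections, iou_threshold=0.5):
    """Greedy merge: keep the highest-confidence box in each overlap group."""
    ordered = sorted(detections, key=lambda d: d["confidence"], reverse=True)
    kept = []
    for det in ordered:
        # Keep a box only if it does not overlap any already-kept box too much
        if all(iou(det["bbox"], k["bbox"]) <= iou_threshold for k in kept):
            kept.append(det)
    return kept
```

This is essentially non-maximum suppression applied across the per-model best-threshold outputs, so a symbol found by both models is reported once.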
text_detection_combined.py ADDED
@@ -0,0 +1,553 @@
+ import os
+ import json
+ import io
+ from PIL import Image, ImageDraw, ImageFont
+ import numpy as np
+ from doctr.models import ocr_predictor
+ import pytesseract
+ import easyocr
+ from storage import StorageInterface
+ import re
+ import logging
+ from pathlib import Path
+ import cv2
+ import traceback
+ 
+ logger = logging.getLogger(__name__)
+ 
+ # Initialize models; fall back to None so the detectors can skip a missing engine
+ doctr_model = None
+ easyocr_reader = None
+ try:
+     doctr_model = ocr_predictor(pretrained=True)
+     easyocr_reader = easyocr.Reader(['en'])
+     logger.info("All OCR models loaded successfully")
+ except Exception as e:
+     logger.error(f"Error loading OCR models: {e}")
+ 
+ # Combined patterns from all approaches
+ TEXT_PATTERNS = {
+     'Line_Number': r"(?:\d{1,5}[-](?:[A-Z]{2,4})[-]\d{1,3})",
+     'Equipment_Tag': r"(?:[A-Z]{1,3}[-][A-Z0-9]{1,4}[-]\d{1,3})",
+     'Instrument_Tag': r"(?:\d{2,3}[-][A-Z]{2,4}[-]\d{2,3})",
+     'Valve_Number': r"(?:[A-Z]{1,2}[-]\d{3})",
+     'Pipe_Size': r"(?:\d{1,2}[\"])",
+     'Flow_Direction': r"(?:FROM|TO)",
+     'Service_Description': r"(?:STEAM|WATER|AIR|GAS|DRAIN)",
+     'Process_Instrument': r"(?:[0-9]{2,3}(?:-[A-Z]{2,3})?-[0-9]{2,3}|[A-Z]{2,3}-[0-9]{2,3})",
+     'Nozzle': r"(?:N[0-9]{1,2}|MH)",
+     'Pipe_Connector': r"(?:[0-9]{1,5}|[A-Z]{1,2}[0-9]{2,5})"
+ }
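Since `classify_text` in this file returns the first pattern that matches, dictionary order matters: broad patterns like `Pipe_Connector` should come last. A small self-contained sketch of how these regexes resolve typical P&ID strings (patterns copied from the dict above; the helper name `classify` is illustrative):

```python
import re

# A subset of the patterns above, in the same priority order
patterns = {
    'Line_Number': r"(?:\d{1,5}[-](?:[A-Z]{2,4})[-]\d{1,3})",
    'Service_Description': r"(?:STEAM|WATER|AIR|GAS|DRAIN)",
    'Pipe_Size': r"(?:\d{1,2}[\"])",
}

def classify(text):
    # Same normalization as classify_text: uppercase and strip whitespace
    text = re.sub(r'\s+', '', text.strip().upper())
    for name, pattern in patterns.items():
        if re.match(pattern, text):
            return name
    return 'Unknown'
```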
+ 
+ def detect_text_combined(image, confidence_threshold=0.3):
+     """Combine results from all three OCR approaches"""
+     results = []
+ 
+     # 1. Tesseract Detection
+     tesseract_results = detect_with_tesseract(image)
+     for result in tesseract_results:
+         result['source'] = 'tesseract'
+         results.append(result)
+ 
+     # 2. EasyOCR Detection
+     easyocr_results = detect_with_easyocr(image)
+     for result in easyocr_results:
+         result['source'] = 'easyocr'
+         results.append(result)
+ 
+     # 3. DocTR Detection
+     doctr_results = detect_with_doctr(image)
+     for result in doctr_results:
+         result['source'] = 'doctr'
+         results.append(result)
+ 
+     # Merge overlapping detections
+     merged_results = merge_overlapping_detections(results)
+ 
+     # Classify and filter results
+     classified_results = []
+     for result in merged_results:
+         if result['confidence'] >= confidence_threshold:
+             result['text_type'] = classify_text(result['text'])
+             classified_results.append(result)
+ 
+     return classified_results
+ 
+ def generate_detailed_summary(results):
+     """Generate a detailed detection summary"""
+     summary = {
+         'total_detections': len(results),
+         'by_type': {},
+         'by_source': {
+             'tesseract': {'count': 0, 'by_type': {}, 'avg_confidence': 0.0},
+             'easyocr': {'count': 0, 'by_type': {}, 'avg_confidence': 0.0},
+             'doctr': {'count': 0, 'by_type': {}, 'avg_confidence': 0.0}
+         },
+         'confidence_ranges': {
+             '0.9-1.0': 0,
+             '0.8-0.9': 0,
+             '0.7-0.8': 0,
+             '0.6-0.7': 0,
+             '0.5-0.6': 0,
+             '<0.5': 0
+         },
+         'detected_items': []
+     }
+ 
+     # Initialize type counters
+     for pattern_type in TEXT_PATTERNS.keys():
+         summary['by_type'][pattern_type] = {
+             'count': 0,
+             'avg_confidence': 0.0,
+             'by_source': {'tesseract': 0, 'easyocr': 0, 'doctr': 0},
+             'items': []
+         }
+         # Initialize source-specific type counters
+         for source in summary['by_source'].keys():
+             summary['by_source'][source]['by_type'][pattern_type] = 0
+ 
+     # Process each detection
+     source_confidences = {'tesseract': [], 'easyocr': [], 'doctr': []}
+ 
+     for result in results:
+         source = result['source']
+         conf = result['confidence']
+         text_type = result['text_type']
+ 
+         # Update source statistics
+         summary['by_source'][source]['count'] += 1
+         source_confidences[source].append(conf)
+ 
+         # Update confidence ranges
+         if conf >= 0.9: summary['confidence_ranges']['0.9-1.0'] += 1
+         elif conf >= 0.8: summary['confidence_ranges']['0.8-0.9'] += 1
+         elif conf >= 0.7: summary['confidence_ranges']['0.7-0.8'] += 1
+         elif conf >= 0.6: summary['confidence_ranges']['0.6-0.7'] += 1
+         elif conf >= 0.5: summary['confidence_ranges']['0.5-0.6'] += 1
+         else: summary['confidence_ranges']['<0.5'] += 1
+ 
+         # Update type statistics
+         if text_type in summary['by_type']:
+             type_stats = summary['by_type'][text_type]
+             type_stats['count'] += 1
+             type_stats['by_source'][source] += 1
+             summary['by_source'][source]['by_type'][text_type] += 1
+             type_stats['items'].append({
+                 'text': result['text'],
+                 'confidence': conf,
+                 'source': source,
+                 'bbox': result['bbox']
+             })
+ 
+         # Add to detected items
+         summary['detected_items'].append({
+             'text': result['text'],
+             'type': text_type,
+             'confidence': conf,
+             'source': source,
+             'bbox': result['bbox']
+         })
+ 
+     # Calculate average confidences per source
+     for source, confs in source_confidences.items():
+         if confs:
+             summary['by_source'][source]['avg_confidence'] = sum(confs) / len(confs)
+ 
+     # Calculate average confidences per type
+     for text_type, stats in summary['by_type'].items():
+         if stats['items']:
+             stats['avg_confidence'] = sum(item['confidence'] for item in stats['items']) / len(stats['items'])
+ 
+     return summary
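The if/elif chain that fills `confidence_ranges` can also be read as a threshold table; a tiny standalone sketch of the same bucketing (the function name `confidence_bucket` is illustrative):

```python
def confidence_bucket(conf):
    """Map a confidence score to the range labels used in the summary."""
    for lower, label in [(0.9, '0.9-1.0'), (0.8, '0.8-0.9'), (0.7, '0.7-0.8'),
                         (0.6, '0.6-0.7'), (0.5, '0.5-0.6')]:
        if conf >= lower:
            return label
    return '<0.5'
```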
+ 
+ def process_drawing(image_path, results_dir, storage=None):
+     try:
+         # Read image using cv2
+         image = cv2.imread(image_path)
+         if image is None:
+             raise ValueError(f"Could not read image from {image_path}")
+ 
+         # Create annotated copy
+         annotated_image = image.copy()
+ 
+         # Initialize results and summary
+         text_results = {
+             'file_name': image_path,
+             'detections': []
+         }
+ 
+         text_summary = {
+             'total_detections': 0,
+             'by_source': {
+                 'tesseract': {'count': 0, 'avg_confidence': 0.0},
+                 'easyocr': {'count': 0, 'avg_confidence': 0.0},
+                 'doctr': {'count': 0, 'avg_confidence': 0.0}
+             },
+             'by_type': {
+                 'equipment_tag': {'count': 0, 'avg_confidence': 0.0},
+                 'line_number': {'count': 0, 'avg_confidence': 0.0},
+                 'instrument_tag': {'count': 0, 'avg_confidence': 0.0},
+                 'valve_number': {'count': 0, 'avg_confidence': 0.0},
+                 'pipe_size': {'count': 0, 'avg_confidence': 0.0},
+                 'flow_direction': {'count': 0, 'avg_confidence': 0.0},
+                 'service_description': {'count': 0, 'avg_confidence': 0.0},
+                 'process_instrument': {'count': 0, 'avg_confidence': 0.0},
+                 'nozzle': {'count': 0, 'avg_confidence': 0.0},
+                 'pipe_connector': {'count': 0, 'avg_confidence': 0.0},
+                 'other': {'count': 0, 'avg_confidence': 0.0}
+             }
+         }
+ 
+         # Run OCR with the different engines
+         tesseract_results = detect_with_tesseract(image)
+         easyocr_results = detect_with_easyocr(image)
+         doctr_results = detect_with_doctr(image)
+ 
+         # Combine results
+         all_detections = []
+         all_detections.extend([(res, 'tesseract') for res in tesseract_results])
+         all_detections.extend([(res, 'easyocr') for res in easyocr_results])
+         all_detections.extend([(res, 'doctr') for res in doctr_results])
+ 
+         # Process each detection
+         for detection, source in all_detections:
+             text_results['detections'].append({
+                 'text': detection['text'],
+                 'bbox': detection['bbox'],
+                 'confidence': detection['confidence'],
+                 'source': source
+             })
+ 
+             # Update summary statistics
+             text_summary['total_detections'] += 1
+             text_summary['by_source'][source]['count'] += 1
+             text_summary['by_source'][source]['avg_confidence'] += detection['confidence']
+ 
+             # Classify the text and update per-type statistics
+             text_type = classify_text(detection['text']).lower()
+             if text_type not in text_summary['by_type']:
+                 text_type = 'other'
+             text_summary['by_type'][text_type]['count'] += 1
+             text_summary['by_type'][text_type]['avg_confidence'] += detection['confidence']
+ 
+             # Draw detection on image
+             x1, y1, x2, y2 = detection['bbox']
+             cv2.rectangle(annotated_image, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
+             cv2.putText(annotated_image, detection['text'], (int(x1), int(y1) - 5),
+                         cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
+ 
+         # Convert confidence sums into averages
+         for source_stats in text_summary['by_source'].values():
+             if source_stats['count'] > 0:
+                 source_stats['avg_confidence'] /= source_stats['count']
+         for type_stats in text_summary['by_type'].values():
+             if type_stats['count'] > 0:
+                 type_stats['avg_confidence'] /= type_stats['count']
+ 
+         # Save results with the new naming convention
+         base_name = Path(image_path).stem
+         text_result_image_path = os.path.join(results_dir, f"{base_name}_detected_texts.jpg")
+         text_result_json_path = os.path.join(results_dir, f"{base_name}_detected_texts.json")
+ 
+         # Save the annotated image
+         success = cv2.imwrite(text_result_image_path, annotated_image)
+         if not success:
+             raise ValueError(f"Failed to save image to {text_result_image_path}")
+ 
+         # Save the JSON results
+         with open(text_result_json_path, 'w', encoding='utf-8') as f:
+             json.dump({
+                 'file_name': image_path,
+                 'summary': text_summary,
+                 'detections': text_results['detections']
+             }, f, indent=4, ensure_ascii=False)
+ 
+         return {
+             'image_path': text_result_image_path,
+             'json_path': text_result_json_path,
+             'results': text_results
+         }, text_summary
+ 
+     except Exception as e:
+         print(f"Error in process_drawing: {str(e)}")
+         traceback.print_exc()
+         return None, None
+ 
+ def detect_with_tesseract(image):
+     """Detect text using Tesseract OCR"""
+     # Configure Tesseract for technical drawings
+     custom_config = r'--oem 3 --psm 11 -c tessedit_char_whitelist="ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789-.()" -c tessedit_write_images=true -c textord_heavy_nr=true -c textord_min_linesize=3'
+ 
+     try:
+         data = pytesseract.image_to_data(
+             image,
+             config=custom_config,
+             output_type=pytesseract.Output.DICT
+         )
+ 
+         results = []
+         for i in range(len(data['text'])):
+             conf = float(data['conf'][i])
+             if conf > 30:  # Lower confidence threshold for technical text
+                 text = data['text'][i].strip()
+                 if text:
+                     x, y, w, h = data['left'][i], data['top'][i], data['width'][i], data['height'][i]
+                     results.append({
+                         'text': text,
+                         'bbox': [x, y, x + w, y + h],
+                         'confidence': conf / 100.0
+                     })
+         return results
+ 
+     except Exception as e:
+         logger.error(f"Tesseract error: {str(e)}")
+         return []
+ 
+ def detect_with_easyocr(image):
+     """Detect text using EasyOCR"""
+     if easyocr_reader is None:
+         return []
+ 
+     try:
+         results = easyocr_reader.readtext(
+             np.array(image),
+             paragraph=False,
+             height_ths=2.0,
+             width_ths=2.0,
+             contrast_ths=0.2,
+             text_threshold=0.5
+         )
+ 
+         parsed_results = []
+         for bbox, text, conf in results:
+             # Collapse the quadrilateral into an axis-aligned box
+             x1, y1 = min(point[0] for point in bbox), min(point[1] for point in bbox)
+             x2, y2 = max(point[0] for point in bbox), max(point[1] for point in bbox)
+ 
+             parsed_results.append({
+                 'text': text,
+                 'bbox': [int(x1), int(y1), int(x2), int(y2)],
+                 'confidence': conf
+             })
+         return parsed_results
+ 
+     except Exception as e:
+         logger.error(f"EasyOCR error: {str(e)}")
+         return []
+ 
+ def detect_with_doctr(image):
+     """Detect text using DocTR"""
+     if doctr_model is None:
+         return []
+ 
+     try:
+         # Ensure we have a numpy array (handles both PIL images and cv2 arrays)
+         image_np = np.array(image)
+ 
+         # Get predictions
+         result = doctr_model([image_np])
+         doc = result.export()
+ 
+         # Parse results
+         results = []
+         height, width = image_np.shape[:2]
+         for page in doc['pages']:
+             for block in page['blocks']:
+                 for line in block['lines']:
+                     for word in line['words']:
+                         # Convert normalized coordinates to absolute pixels
+                         points = np.array(word['geometry']) * np.array([width, height])
+                         x1, y1 = points.min(axis=0)
+                         x2, y2 = points.max(axis=0)
+ 
+                         results.append({
+                             'text': word['value'],
+                             'bbox': [int(x1), int(y1), int(x2), int(y2)],
+                             'confidence': word.get('confidence', 0.5)
+                         })
+         return results
+ 
+     except Exception as e:
+         logger.error(f"DocTR error: {str(e)}")
+         return []
+ 
+ def merge_overlapping_detections(results, iou_threshold=0.5):
+     """Merge overlapping detections from different sources"""
+     if not results:
+         return []
+ 
+     def calculate_iou(box1, box2):
+         x1 = max(box1[0], box2[0])
+         y1 = max(box1[1], box2[1])
+         x2 = min(box1[2], box2[2])
+         y2 = min(box1[3], box2[3])
+ 
+         if x2 < x1 or y2 < y1:
+             return 0.0
+ 
+         intersection = (x2 - x1) * (y2 - y1)
+         area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
+         area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
+         union = area1 + area2 - intersection
+ 
+         return intersection / union if union > 0 else 0
+ 
+     merged = []
+     used = set()
+ 
+     for i, r1 in enumerate(results):
+         if i in used:
+             continue
+ 
+         current_group = [r1]
+         used.add(i)
+ 
+         for j, r2 in enumerate(results):
+             if j in used:
+                 continue
+ 
+             if calculate_iou(r1['bbox'], r2['bbox']) > iou_threshold:
+                 current_group.append(r2)
+                 used.add(j)
+ 
+         if len(current_group) == 1:
+             merged.append(current_group[0])
+         else:
+             # Keep the detection with the highest confidence in each group
+             best_detection = max(current_group, key=lambda x: x['confidence'])
+             merged.append(best_detection)
+ 
+     return merged
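A quick worked example of the IoU computation that drives the merge: two 10×10 boxes offset by 5 px overlap in a 5×10 strip, giving IoU = 50 / 150 = 1/3, which is below the default 0.5 threshold, so they would be kept as separate detections. The function is reproduced here so the snippet runs on its own:

```python
def calculate_iou(box1, box2):
    """Intersection-over-union of two [x1, y1, x2, y2] boxes."""
    x1, y1 = max(box1[0], box2[0]), max(box1[1], box2[1])
    x2, y2 = min(box1[2], box2[2]), min(box1[3], box2[3])
    if x2 < x1 or y2 < y1:
        return 0.0
    intersection = (x2 - x1) * (y2 - y1)
    area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
    area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
    union = area1 + area2 - intersection
    return intersection / union if union > 0 else 0

# Two 10x10 boxes shifted 5 px in x: intersection 50, union 150, IoU = 1/3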
+ 
+ def classify_text(text):
+     """Classify text based on the known tag patterns"""
+     if not text:
+         return 'Unknown'
+ 
+     # Clean and normalize text
+     text = text.strip().upper()
+     text = re.sub(r'\s+', '', text)
+ 
+     for text_type, pattern in TEXT_PATTERNS.items():
+         if re.match(pattern, text):
+             return text_type
+ 
+     return 'Unknown'
+ 
+ def annotate_image(image, results):
+     """Create an annotated image with detections"""
+     # Convert image to RGB mode to ensure color support
+     if image.mode != 'RGB':
+         image = image.convert('RGB')
+ 
+     # Create drawing object
+     draw = ImageDraw.Draw(image)
+     try:
+         font = ImageFont.truetype("arial.ttf", 20)
+     except IOError:
+         font = ImageFont.load_default()
+ 
+     # Define colors for different text types
+     colors = {
+         'Line_Number': "#FF0000",         # Bright Red
+         'Equipment_Tag': "#00FF00",       # Bright Green
+         'Instrument_Tag': "#0000FF",      # Bright Blue
+         'Valve_Number': "#FFA500",        # Bright Orange
+         'Pipe_Size': "#FF00FF",           # Bright Magenta
+         'Process_Instrument': "#00FFFF",  # Bright Cyan
+         'Nozzle': "#FFFF00",              # Yellow
+         'Pipe_Connector': "#800080",      # Purple
+         'Unknown': "#FF4444"              # Light Red
+     }
+ 
+     # Draw detections
+     for result in results:
+         text_type = result.get('text_type', 'Unknown')
+         color = colors.get(text_type, colors['Unknown'])
+ 
+         # Draw bounding box
+         draw.rectangle(result['bbox'], outline=color, width=3)
+ 
+         # Create label
+         label = f"{result['text']} ({result['confidence']:.2f})"
+         if text_type != 'Unknown':
+             label += f" [{text_type}]"
+ 
+         # Draw label background
+         text_bbox = draw.textbbox((result['bbox'][0], result['bbox'][1] - 20), label, font=font)
+         draw.rectangle(text_bbox, fill="#FFFFFF")
+ 
+         # Draw label text
+         draw.text((result['bbox'][0], result['bbox'][1] - 20), label, fill=color, font=font)
+ 
+     return image
+ 
+ def save_annotated_image(image, path, storage):
+     """Save an annotated image with maximum quality"""
+     image_byte_array = io.BytesIO()
+     image.save(
+         image_byte_array,
+         format='PNG',
+         optimize=False,
+         compress_level=0
+     )
+     storage.save_file(path, image_byte_array.getvalue())
+ 
+ if __name__ == "__main__":
+     from storage import StorageFactory
+ 
+     # Configure logging
+     logging.basicConfig(level=logging.INFO)
+     logger = logging.getLogger(__name__)
+ 
+     # Initialize storage
+     storage = StorageFactory.get_storage()
+ 
+     # Test file paths
+     file_path = "processed_pages/10219-1-DG-BC-00011.01-REV_A_page_1_text.png"
+     result_path = "results"
+ 
+     try:
+         # Ensure result directory exists
+         os.makedirs(result_path, exist_ok=True)
+ 
+         # Process the drawing
+         logger.info(f"Processing file: {file_path}")
+         results, summary = process_drawing(file_path, result_path, storage)
+         if results is None:
+             raise RuntimeError(f"Processing failed for {file_path}")
+ 
+         # Print detailed results
+         print("\n=== DETAILED DETECTION RESULTS ===")
+         print(f"\nTotal Detections: {summary['total_detections']}")
+ 
+         print("\nBreakdown by Text Type:")
+         print("-" * 50)
+         for text_type, stats in summary['by_type'].items():
+             if stats['count'] > 0:
+                 print(f"\n{text_type}:")
+                 print(f"  Count: {stats['count']}")
+                 print(f"  Average Confidence: {stats['avg_confidence']:.2f}")
+                 # Per-item lists are only present in generate_detailed_summary() output
+                 for item in stats.get('items', []):
+                     print(f"  - {item['text']} (conf: {item['confidence']:.2f}, source: {item['source']})")
+ 
+         print("\nBreakdown by OCR Engine:")
+         print("-" * 50)
+         for source, stats in summary['by_source'].items():
+             print(f"{source}: {stats['count']} detections (avg conf: {stats['avg_confidence']:.2f})")
+ 
+         # confidence_ranges is only present in generate_detailed_summary() output
+         if 'confidence_ranges' in summary:
+             print("\nConfidence Distribution:")
+             print("-" * 50)
+             for range_name, count in summary['confidence_ranges'].items():
+                 print(f"{range_name}: {count} detections")
+ 
+         # Print output paths
+         print("\nOutput Files:")
+         print("-" * 50)
+         print(f"Annotated Image: {results['image_path']}")
+         print(f"JSON Results: {results['json_path']}")
+ 
+     except Exception as e:
+         logger.error(f"Error processing file: {e}")
+         raise
utils.py ADDED
@@ -0,0 +1,122 @@
+ import time
+ import numpy as np
+ from contextlib import contextmanager
+ from loguru import logger
+ from typing import List, Dict, Optional, Tuple, Union
+ from detection_schema import BBox
+ from storage import StorageInterface
+ 
+ class DebugHandler:
+     """Production-grade debugging and performance tracking"""
+ 
+     def __init__(self, enabled: bool = False, storage: StorageInterface = None):
+         self.enabled = enabled
+         self.storage = storage
+         self.metrics = {}
+ 
+     @contextmanager
+     def track_performance(self, operation_name: str):
+         """Context manager for performance tracking"""
+         # Use a local start time so nested tracked operations do not clobber each other
+         start_time = time.perf_counter() if self.enabled else None
+         if self.enabled:
+             logger.debug(f"Starting {operation_name}")
+ 
+         yield
+ 
+         if self.enabled:
+             duration = time.perf_counter() - start_time
+             self.metrics[operation_name] = duration
+             logger.debug(f"{operation_name} completed in {duration:.2f}s")
+ 
+     def save_artifact(self, name: str, data: bytes, extension: str = "png"):
+         """Generic artifact storage handler"""
+         if self.enabled and self.storage:
+             path = f"debug/{name}.{extension}"
+             self.storage.save_file(path, data)
+             logger.info(f"Saved debug artifact: {path}")
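The `track_performance` pattern can be exercised standalone; a minimal sketch without the loguru and storage dependencies (the class name `TimingTracker` is illustrative, not part of the repo):

```python
import time
from contextlib import contextmanager

class TimingTracker:
    """Standalone sketch of the DebugHandler timing pattern."""
    def __init__(self, enabled=True):
        self.enabled = enabled
        self.metrics = {}

    @contextmanager
    def track(self, name):
        # Record wall-clock duration of the wrapped block when enabled
        start = time.perf_counter() if self.enabled else None
        yield
        if self.enabled:
            self.metrics[name] = time.perf_counter() - start

tracker = TimingTracker()
with tracker.track("sleep"):
    time.sleep(0.01)
```

The `with` block is timed even if its body raises, only because the `yield` here is not wrapped in `try/finally`; the repo's version has the same property, which is usually acceptable for debug-only instrumentation.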
+ 
+ 
+ class CoordinateTransformer:
+     @staticmethod
+     def global_to_local_bbox(
+         bbox: Union[BBox, List[BBox]],
+         roi: Optional[np.ndarray]
+     ) -> Union[BBox, List[BBox]]:
+         """
+         Convert global BBox(es) to ROI-local coordinates.
+         Handles both a single BBox and lists of BBoxes.
+         """
+         if roi is None or len(roi) != 4:
+             return bbox
+ 
+         x_min, y_min, _, _ = roi
+ 
+         def convert(b: BBox) -> BBox:
+             return BBox(
+                 xmin=b.xmin - x_min,
+                 ymin=b.ymin - y_min,
+                 xmax=b.xmax - x_min,
+                 ymax=b.ymax - y_min
+             )
+ 
+         # Return a list, not a lazy map object, so the annotated return type holds
+         return [convert(b) for b in bbox] if isinstance(bbox, list) else convert(bbox)
+ 
+     @staticmethod
+     def local_to_global_bbox(
+         bbox: Union[BBox, List[BBox]],
+         roi: Optional[np.ndarray]
+     ) -> Union[BBox, List[BBox]]:
+         """
+         Convert ROI-local BBox(es) to global coordinates.
+         Handles both a single BBox and lists of BBoxes.
+         """
+         if roi is None or len(roi) != 4:
+             return bbox
+ 
+         x_min, y_min, _, _ = roi
+ 
+         def convert(b: BBox) -> BBox:
+             return BBox(
+                 xmin=b.xmin + x_min,
+                 ymin=b.ymin + y_min,
+                 xmax=b.xmax + x_min,
+                 ymax=b.ymax + y_min
+             )
+ 
+         # Return a list, not a lazy map object, so the annotated return type holds
+         return [convert(b) for b in bbox] if isinstance(bbox, list) else convert(bbox)
+ 
+     # Maintain legacy tuple support if needed
+     @staticmethod
+     def global_to_local(
+         bboxes: List[Tuple[int, int, int, int]],
+         roi: Optional[np.ndarray]
+     ) -> List[Tuple[int, int, int, int]]:
+         """Legacy tuple version for backward compatibility"""
+         if roi is None or len(roi) != 4:
+             return bboxes
+ 
+         x_min, y_min, _, _ = roi
+         return [(x1 - x_min, y1 - y_min, x2 - x_min, y2 - y_min)
+                 for x1, y1, x2, y2 in bboxes]
+ 
+     @staticmethod
+     def local_to_global(
+         bboxes: List[Tuple[int, int, int, int]],
+         roi: Optional[np.ndarray]
+     ) -> List[Tuple[int, int, int, int]]:
+         """Legacy tuple version for backward compatibility"""
+         if roi is None or len(roi) != 4:
+             return bboxes
+ 
+         x_min, y_min, _, _ = roi
+         return [(x1 + x_min, y1 + y_min, x2 + x_min, y2 + y_min)
+                 for x1, y1, x2, y2 in bboxes]
+ 
+     @staticmethod
+     def local_to_global_point(point: Tuple[int, int], roi: Optional[np.ndarray]) -> Tuple[int, int]:
+         """Convert a single point from local to global coordinates"""
+         if roi is None or len(roi) != 4:
+             return point
+         x_min, y_min, _, _ = roi
+         return (int(point[0] + x_min), int(point[1] + y_min))
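The legacy tuple helpers are inverses of each other; a standalone round-trip sketch mirroring `CoordinateTransformer.global_to_local` / `local_to_global` (the formulas are reproduced as free functions so the snippet runs on its own):

```python
import numpy as np

# Mirrors CoordinateTransformer.global_to_local for plain tuples
def global_to_local(bboxes, roi):
    if roi is None or len(roi) != 4:
        return bboxes
    x_min, y_min, _, _ = roi
    return [(x1 - x_min, y1 - y_min, x2 - x_min, y2 - y_min)
            for x1, y1, x2, y2 in bboxes]

# Mirrors CoordinateTransformer.local_to_global for plain tuples
def local_to_global(bboxes, roi):
    if roi is None or len(roi) != 4:
        return bboxes
    x_min, y_min, _, _ = roi
    return [(x1 + x_min, y1 + y_min, x2 + x_min, y2 + y_min)
            for x1, y1, x2, y2 in bboxes]

roi = np.array([100, 200, 500, 600])  # [x_min, y_min, x_max, y_max]
boxes = [(150, 250, 180, 280)]
local = global_to_local(boxes, roi)       # shift into ROI frame
restored = local_to_global(local, roi)    # shift back to the global frame
```

Only the ROI origin `(x_min, y_min)` matters for the shift; the ROI's extent is used solely for the length-4 validity check.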