# intelligent-pid / data_aggregation_ai.py
# msIntui
# Initial commit: Add core files for P&ID processing
# 9847531
from pathlib import Path
import json
import logging
from datetime import datetime
from typing import List, Dict, Optional, Tuple
from storage import StorageFactory
import uuid
import traceback
# Module-wide logger, named after this module.
logger: logging.Logger = logging.getLogger(__name__)
class DataAggregator:
    """Combine symbol, text and line detections into a knowledge-graph structure.

    Detection results are read through a storage backend exposing
    ``load_file(path)`` (returning bytes that are decoded as UTF-8 JSON). They are
    parsed and validated, then turned into graph nodes (symbols, texts,
    junctions, line endpoints) and edges (lines plus proximity-based
    associations).
    """

    # Upper bound accepted for any coordinate (adjust to the drawing resolution).
    MAX_COORDINATE = 10000
    # Max distance (px) between a symbol centre and a line endpoint to link them.
    SYMBOL_LINE_MAX_DISTANCE = 50
    # Margin (px) added around a symbol bbox when testing if a text labels it.
    TEXT_SYMBOL_MARGIN = 20
    # Max centre-to-centre offset (px), measured across a line, to associate a text.
    TEXT_LINE_MAX_OFFSET = 30
    # Endpoint types that are materialised as junction nodes.
    JUNCTION_TYPES = ("T", "L")
    # Upper-cased substrings that mark a text near a line as a line identifier.
    LINE_ID_PATTERNS = ("-", "LINE", "PIPE")

    def __init__(self, storage=None):
        """Create an aggregator.

        Args:
            storage: Backend with a ``load_file(path)`` method returning bytes.
                When omitted, ``StorageFactory.get_storage()`` is used.
        """
        self.storage = storage if storage is not None else StorageFactory.get_storage()
        self.logger = logging.getLogger(__name__)

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _center(bbox: dict) -> Dict[str, float]:
        """Return the centre point of an ``xmin/ymin/xmax/ymax`` bbox."""
        return {
            "x": (bbox["xmin"] + bbox["xmax"]) / 2,
            "y": (bbox["ymin"] + bbox["ymax"]) / 2,
        }

    @staticmethod
    def _point_key(point: dict) -> str:
        """Key used to match line endpoints/junctions at identical coordinates."""
        return f"{point['x']}_{point['y']}"

    @staticmethod
    def _resolve_id(item: dict):
        """Return ``item["id"]``, or a fresh UUID string when it is missing or None."""
        item_id = item.get("id")
        return item_id if item_id is not None else str(uuid.uuid4())

    def _is_coord(self, value) -> bool:
        """Return True for a real (non-bool) number within [0, MAX_COORDINATE]."""
        return (
            isinstance(value, (int, float))
            and not isinstance(value, bool)
            and 0 <= value <= self.MAX_COORDINATE
        )

    def _load_json(self, path: str):
        """Load ``path`` through the storage backend and decode it as UTF-8 JSON."""
        return json.loads(self.storage.load_file(path).decode("utf-8"))

    # ------------------------------------------------------------ line parsing

    def _parse_line_data(self, lines_data: dict) -> List[dict]:
        """Parse line detection data with coordinate validation.

        Each entry of ``lines_data["lines"]`` must provide ``id``, ``start`` and
        ``end`` (each with ``coords``, ``type``, ``confidence``), ``bbox``,
        ``style`` and ``confidence``. Entries with out-of-range coordinates are
        logged as warnings and skipped. Malformed entries are logged as errors and
        skipped, so one bad record does not abort the parse.

        Returns:
            Normalised line dicts with integer coordinates under
            ``start_point``, ``end_point`` and ``bbox``.
        """
        parsed_lines = []
        for line in lines_data.get("lines", []):
            line_id = line.get("id") if isinstance(line, dict) else None
            try:
                start, end = line["start"], line["end"]
                start_coords, end_coords = start["coords"], end["coords"]
                bbox = line["bbox"]
                if not (self._is_valid_point(start_coords)
                        and self._is_valid_point(end_coords)
                        and self._is_valid_bbox(bbox)):
                    self.logger.warning(f"Invalid coordinates in line: {line_id}")
                    continue
                parsed_lines.append({
                    "id": line["id"],
                    "start_point": {
                        "x": int(start_coords["x"]),
                        "y": int(start_coords["y"]),
                        "type": start["type"],
                        "confidence": start["confidence"],
                    },
                    "end_point": {
                        "x": int(end_coords["x"]),
                        "y": int(end_coords["y"]),
                        "type": end["type"],
                        "confidence": end["confidence"],
                    },
                    "bbox": {
                        "xmin": int(bbox["xmin"]),
                        "ymin": int(bbox["ymin"]),
                        "xmax": int(bbox["xmax"]),
                        "ymax": int(bbox["ymax"]),
                    },
                    "style": line["style"],
                    "confidence": line["confidence"],
                })
            except Exception as e:
                # Deliberate best effort: skip the record, keep parsing the rest.
                self.logger.error(f"Error parsing line {line_id}: {str(e)}")
        return parsed_lines

    def _is_valid_point(self, point: dict) -> bool:
        """Return True if ``point`` has numeric ``x``/``y`` within [0, MAX_COORDINATE]."""
        try:
            x, y = point.get("x"), point.get("y")
        except AttributeError:  # not a mapping
            return False
        return self._is_coord(x) and self._is_coord(y)

    def _is_valid_bbox(self, bbox: dict) -> bool:
        """Return True for an in-range bbox with strictly positive width and height."""
        try:
            xmin, ymin, xmax, ymax = (bbox.get(k) for k in ("xmin", "ymin", "xmax", "ymax"))
        except AttributeError:  # not a mapping
            return False
        return (
            all(self._is_coord(v) for v in (xmin, ymin, xmax, ymax))
            and xmin < xmax
            and ymin < ymax
        )

    # ------------------------------------------------------------ graph build

    def _create_graph_data(self, lines: List[dict], symbols: List[dict],
                           texts: List[dict]) -> Tuple[List[dict], List[dict]]:
        """Create nodes and edges for the knowledge graph following the three-step process.

        1. ``symbol`` and ``text`` nodes with their bbox centres.
        2. One ``junction`` node per distinct T/L endpoint location.
        3. ``line`` edges between endpoint nodes (non-junction ends get an
           ``endpoint`` node so no edge references a missing node), followed by
           proximity-based symbol-line, symbol-text and line-text edges.

        A symbol or text without an ``id`` gets a generated UUID. That UUID is
        used for both its node and every edge touching it.

        Args:
            lines: Parsed lines as produced by ``_parse_line_data``.
            symbols: Symbol detections, each with a ``bbox``.
            texts: Text detections, each with a ``bbox``.

        Returns:
            ``(nodes, edges)``. Association edges use the line's ``id`` as
            source/target, i.e. they refer to line *edges*.
        """
        symbol_ids = [self._resolve_id(s) for s in symbols]
        text_ids = [self._resolve_id(t) for t in texts]

        nodes: List[dict] = []
        nodes.extend(self._build_symbol_nodes(symbols, symbol_ids))
        nodes.extend(self._build_text_nodes(texts, text_ids))
        junction_nodes, junction_map = self._build_junction_nodes(lines)
        nodes.extend(junction_nodes)

        edges: List[dict] = []
        line_edges, endpoint_nodes = self._build_line_edges(lines, junction_map)
        nodes.extend(endpoint_nodes)
        edges.extend(line_edges)
        edges.extend(self._build_symbol_line_edges(lines, symbols, symbol_ids))
        edges.extend(self._build_symbol_text_edges(symbols, symbol_ids, texts, text_ids))
        edges.extend(self._build_line_text_edges(lines, texts, text_ids))
        return nodes, edges

    def _build_symbol_nodes(self, symbols: List[dict], symbol_ids: list) -> List[dict]:
        """Step 1a: one ``symbol`` node per symbol detection."""
        return [
            {
                "id": symbol_id,
                "type": "symbol",
                "category": symbol.get("category", "unknown"),
                "bbox": symbol["bbox"],
                "center": self._center(symbol["bbox"]),
                "confidence": symbol.get("confidence", 1.0),
                "properties": {
                    "class": symbol.get("class", ""),
                    "equipment_type": symbol.get("equipment_type", ""),
                    "original_label": symbol.get("original_label", ""),
                },
            }
            for symbol, symbol_id in zip(symbols, symbol_ids)
        ]

    def _build_text_nodes(self, texts: List[dict], text_ids: list) -> List[dict]:
        """Step 1b: one ``text`` node per text detection."""
        return [
            {
                "id": text_id,
                "type": "text",
                "content": text.get("text", ""),
                "bbox": text["bbox"],
                "center": self._center(text["bbox"]),
                "confidence": text.get("confidence", 1.0),
                "properties": {
                    "font_size": text.get("font_size"),
                    "rotation": text.get("rotation", 0.0),
                    "text_type": text.get("text_type", "unknown"),
                },
            }
            for text, text_id in zip(texts, text_ids)
        ]

    def _build_junction_nodes(self, lines: List[dict]) -> Tuple[List[dict], Dict[str, str]]:
        """Step 2: one ``junction`` node per distinct T/L endpoint location.

        Returns:
            The junction nodes and a map from point key to junction id. Lines
            meeting at the same T/L point share a single junction node.
        """
        nodes = []
        junction_map: Dict[str, str] = {}
        for line in lines:
            for point_name in ("start_point", "end_point"):
                point = line[point_name]
                if point.get("type") not in self.JUNCTION_TYPES:
                    continue
                key = self._point_key(point)
                if key in junction_map:
                    continue  # already created for an earlier line
                junction_id = str(uuid.uuid4())
                nodes.append({
                    "id": junction_id,
                    "type": "junction",
                    "junction_type": point["type"],
                    "coords": {"x": point["x"], "y": point["y"]},
                    "properties": {"confidence": point.get("confidence", 1.0)},
                })
                junction_map[key] = junction_id
        return nodes, junction_map

    def _build_line_edges(self, lines: List[dict],
                          junction_map: Dict[str, str]) -> Tuple[List[dict], List[dict]]:
        """Step 3a: one ``line`` edge per line, linking its two endpoint nodes.

        An endpoint that is not a junction gets an ``endpoint`` node. That node is
        shared by every line ending at the same coordinates.

        Returns:
            The line edges and the newly created endpoint nodes.
        """
        point_map = dict(junction_map)
        endpoint_nodes: List[dict] = []

        def node_for(point: dict) -> str:
            # Reuse an existing junction/endpoint at these coordinates, else create one.
            key = self._point_key(point)
            if key not in point_map:
                endpoint_id = str(uuid.uuid4())
                endpoint_nodes.append({
                    "id": endpoint_id,
                    "type": "endpoint",
                    "endpoint_type": point.get("type"),
                    "coords": {"x": point["x"], "y": point["y"]},
                    "properties": {"confidence": point.get("confidence", 1.0)},
                })
                point_map[key] = endpoint_id
            return point_map[key]

        edges = []
        for line in lines:
            start, end = line["start_point"], line["end_point"]
            edges.append({
                "id": line.get("id", str(uuid.uuid4())),
                "type": "line",
                "source": node_for(start),
                "target": node_for(end),
                "properties": {
                    "style": line["style"],
                    "confidence": line.get("confidence", 1.0),
                    "connection_points": {
                        "start": {"x": start["x"], "y": start["y"]},
                        "end": {"x": end["x"], "y": end["y"]},
                    },
                    "bbox": line["bbox"],
                },
            })
        return edges, endpoint_nodes

    def _build_symbol_line_edges(self, lines: List[dict], symbols: List[dict],
                                 symbol_ids: list) -> List[dict]:
        """Step 3b: link a symbol to every line endpoint close to its bbox centre."""
        edges = []
        for symbol, symbol_id in zip(symbols, symbol_ids):
            center = self._center(symbol["bbox"])
            for line in lines:
                for point_type in ("start_point", "end_point"):
                    point = line[point_type]
                    dist = ((point["x"] - center["x"]) ** 2
                            + (point["y"] - center["y"]) ** 2) ** 0.5
                    if dist < self.SYMBOL_LINE_MAX_DISTANCE:
                        edges.append({
                            "id": str(uuid.uuid4()),
                            "type": "symbol_line_connection",
                            "source": symbol_id,
                            "target": line["id"],
                            "properties": {
                                "connection_point": {"x": point["x"], "y": point["y"]},
                                "connection_type": point_type,
                                "distance": dist,
                            },
                        })
        return edges

    def _build_symbol_text_edges(self, symbols: List[dict], symbol_ids: list,
                                 texts: List[dict], text_ids: list) -> List[dict]:
        """Step 3c: label a symbol with texts whose centre lies in its (margin-grown) bbox."""
        margin = self.TEXT_SYMBOL_MARGIN
        edges = []
        for text, text_id in zip(texts, text_ids):
            center = self._center(text["bbox"])
            for symbol, symbol_id in zip(symbols, symbol_ids):
                sb = symbol["bbox"]
                if (sb["xmin"] - margin <= center["x"] <= sb["xmax"] + margin
                        and sb["ymin"] - margin <= center["y"] <= sb["ymax"] + margin):
                    edges.append({
                        "id": str(uuid.uuid4()),
                        "type": "symbol_text_association",
                        "source": symbol_id,
                        "target": text_id,
                        "properties": {
                            "association_type": "label",
                            "confidence": min(symbol.get("confidence", 1.0),
                                              text.get("confidence", 1.0)),
                        },
                    })
        return edges

    def _build_line_text_edges(self, lines: List[dict], texts: List[dict],
                               text_ids: list) -> List[dict]:
        """Step 3d: associate texts lying alongside a line (horizontally or vertically)."""
        max_offset = self.TEXT_LINE_MAX_OFFSET
        edges = []
        for text, text_id in zip(texts, text_ids):
            text_bbox = text["bbox"]
            text_center = self._center(text_bbox)
            # Depends only on the text, so classify once per text.
            content = (text.get("text") or "").upper()
            association_type = (
                "line_id" if any(p in content for p in self.LINE_ID_PATTERNS)
                else "description"
            )
            for line in lines:
                line_bbox = line["bbox"]
                line_center = self._center(line_bbox)
                dy = abs(text_center["y"] - line_center["y"])
                dx = abs(text_center["x"] - line_center["x"])
                # Close in y and overlapping in x (or the transpose for vertical lines).
                is_nearby_horizontal = (
                    dy < max_offset
                    and text_bbox["xmin"] <= line_bbox["xmax"]
                    and text_bbox["xmax"] >= line_bbox["xmin"]
                )
                is_nearby_vertical = (
                    dx < max_offset
                    and text_bbox["ymin"] <= line_bbox["ymax"]
                    and text_bbox["ymax"] >= line_bbox["ymin"]
                )
                if not (is_nearby_horizontal or is_nearby_vertical):
                    continue
                edges.append({
                    "id": str(uuid.uuid4()),
                    "type": "line_text_association",
                    "source": line["id"],
                    "target": text_id,
                    "properties": {
                        "association_type": association_type,
                        "relative_position": "horizontal" if is_nearby_horizontal else "vertical",
                        "confidence": min(line.get("confidence", 1.0),
                                          text.get("confidence", 1.0)),
                        "distance": dy if is_nearby_horizontal else dx,
                    },
                })
        return edges

    # -------------------------------------------------------------- validation

    def _validate_coordinates(self, data, data_type):
        """Validate the coordinate fields of a parsed line, symbol or text dict.

        For ``'line'``, ``start_point``/``end_point`` must carry ``x``, ``y`` and
        ``type``, and ``bbox`` all four corners. For ``'symbol'``/``'text'``,
        only ``bbox`` is checked. A bbox with ``min > max`` is rejected.
        Other ``data_type`` values pass unchecked. Problems are logged.

        Returns:
            True if valid, False otherwise (including empty ``data``).
        """
        if not data:
            return False
        bbox_keys = ['xmin', 'ymin', 'xmax', 'ymax']
        try:
            if data_type == 'line':
                start = data.get('start_point', {})
                end = data.get('end_point', {})
                bbox = data.get('bbox', {})
                required_fields = ['x', 'y', 'type']
                if not all(field in start for field in required_fields):
                    self.logger.warning(f"Missing required fields in start_point: {start}")
                    return False
                if not all(field in end for field in required_fields):
                    self.logger.warning(f"Missing required fields in end_point: {end}")
                    return False
                if not all(key in bbox for key in bbox_keys):
                    self.logger.warning(f"Invalid bbox format: {bbox}")
                    return False
                if bbox['xmin'] > bbox['xmax'] or bbox['ymin'] > bbox['ymax']:
                    self.logger.warning(f"Invalid bbox coordinates: {bbox}")
                    return False
            elif data_type in ['symbol', 'text']:
                bbox = data.get('bbox', {})
                if not all(key in bbox for key in bbox_keys):
                    self.logger.warning(f"Invalid {data_type} bbox format: {bbox}")
                    return False
                if bbox['xmin'] > bbox['xmax'] or bbox['ymin'] > bbox['ymax']:
                    self.logger.warning(f"Invalid {data_type} bbox coordinates: {bbox}")
                    return False
            return True
        except Exception as e:
            self.logger.error(f"Validation error for {data_type}: {str(e)}")
            return False

    # ------------------------------------------------------------- entry point

    def aggregate_data(self, symbols_path: str, texts_path: str, lines_path: str) -> dict:
        """Aggregate detection results and create graph structure.

        Args:
            symbols_path: JSON with a ``"symbols"`` list (optional input).
            texts_path: JSON with a ``"texts"`` list (optional input).
            lines_path: JSON with a ``"lines"`` list (required input).

        Returns:
            Dict with ``lines``, ``symbols``, ``texts``, ``nodes``, ``edges`` and
            ``metadata`` (local timestamp, version ``"2.0"``).

        Raises:
            Any error from loading/decoding the inputs or building the graph is
            logged with its traceback and re-raised.
        """
        try:
            lines = self._parse_line_data(self._load_json(lines_path))

            # NOTE(review): the optional inputs are probed on the *local*
            # filesystem but read through ``self.storage``. With a non-local
            # backend they are skipped unless the same path also exists locally.
            # Confirm this is intended.
            symbols = []
            if symbols_path and Path(symbols_path).exists():
                symbols = self._load_json(symbols_path).get("symbols", [])
            texts = []
            if texts_path and Path(texts_path).exists():
                texts = self._load_json(texts_path).get("texts", [])

            nodes, edges = self._create_graph_data(lines, symbols, texts)
            return {
                "lines": lines,
                "symbols": symbols,
                "texts": texts,
                "nodes": nodes,
                "edges": edges,
                "metadata": {
                    "timestamp": datetime.now().isoformat(),
                    "version": "2.0",
                },
            }
        except Exception as e:
            self.logger.exception("Error during aggregation: %s", e)
            raise
if __name__ == "__main__":
    # Manual smoke test: aggregate the 0_text_detected_*.json results under
    # results/ and write the combined output to results/0_aggregated_test.json.
    # Exits with status 1 if aggregation or saving fails.
    import os
    from pprint import pprint

    aggregator = DataAggregator()

    # Test paths (adjust these to match your results folder)
    results_dir = "results/"
    symbols_path = os.path.join(results_dir, "0_text_detected_symbols.json")
    texts_path = os.path.join(results_dir, "0_text_detected_texts.json")
    lines_path = os.path.join(results_dir, "0_text_detected_lines.json")

    try:
        aggregated_data = aggregator.aggregate_data(
            symbols_path=symbols_path,
            texts_path=texts_path,
            lines_path=lines_path,
        )

        # Written straight to the local filesystem, not through the storage backend.
        output_path = os.path.join(results_dir, "0_aggregated_test.json")
        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(aggregated_data, f, indent=2)

        print("\nAggregation Results:")
        for label, key in (("Symbols", "symbols"), ("Texts", "texts"),
                           ("Lines", "lines"), ("Nodes", "nodes"),
                           ("Edges", "edges")):
            print(f"Number of {label}: {len(aggregated_data[key])}")

        print("\nSample Node:")
        if aggregated_data['nodes']:
            pprint(aggregated_data['nodes'][0])
        print("\nSample Edge:")
        if aggregated_data['edges']:
            pprint(aggregated_data['edges'][0])

        print(f"\nAggregated data saved to: {output_path}")
    except Exception as e:
        print(f"Error during testing: {str(e)}")
        traceback.print_exc()
        # Signal failure to the calling shell/CI instead of exiting with 0.
        raise SystemExit(1)