# intelligent-pid / graph_construction.py
# msIntui
# Initial commit: Add core files for P&ID processing
# 9847531
import os
import json
import networkx as nx
import matplotlib.pyplot as plt
from pathlib import Path
import logging
import traceback
from storage import StorageFactory
logger = logging.getLogger(__name__)
# Fill colours used when drawing each node type; any other type (connection
# points) falls back to _DEFAULT_NODE_COLOR.
_NODE_COLORS = {'symbol': 'lightblue', 'text': 'lightgreen'}
_DEFAULT_NODE_COLOR = 'lightgray'


def _node_position(node: dict):
    """Return the (x, y) drawing position of a detection node, or None.

    Connection points are placed at their ``coords``; every other node type
    (symbols and text) is placed at the centre of its ``bbox``. Returns None
    when the geometry required by the node's type is missing or malformed.
    """
    try:
        if node.get('type') == 'connection_point':
            coords = node['coords']
            return (coords['x'], coords['y'])
        bbox = node['bbox']
        return (
            (bbox['xmin'] + bbox['xmax']) / 2,
            (bbox['ymin'] + bbox['ymax']) / 2,
        )
    except (KeyError, TypeError):
        return None


def _build_graph(data: dict):
    """Build an undirected graph and its node-position map from aggregated data.

    Nodes without usable geometry, and edges whose endpoints were not added as
    nodes, are skipped with a warning. Adding such an edge would otherwise
    create an attribute-less node with no position, which breaks drawing.

    Returns:
        (graph, pos) where pos maps node id -> (x, y).
    """
    graph = nx.Graph()
    pos = {}
    for node in data.get('nodes', []):
        node_id = node['id']
        xy = _node_position(node)
        if xy is None:
            logger.warning("Skipping node %r: missing or invalid geometry", node_id)
            continue
        pos[node_id] = xy
        # Keep every detection field as a node attribute.
        graph.add_node(node_id, **node)
    for edge in data.get('edges', []):
        source, target = edge['source'], edge['target']
        if source not in graph or target not in graph:
            logger.warning("Skipping edge %r -> %r: unknown endpoint", source, target)
            continue
        graph.add_edge(source, target, **(edge.get('properties') or {}))
    return graph, pos


def _node_label(attrs: dict) -> str:
    """Return the on-plot label of a node: ``S:<class>``, ``T:<text>`` or ``C:<point_type>``.

    Text content longer than 10 characters is truncated and suffixed with "...".
    """
    node_type = attrs.get('type')
    props = attrs.get('properties') or {}
    if node_type == 'symbol':
        return f"S:{props.get('class', '')}"
    if node_type == 'text':
        content = attrs.get('content')
        content = '' if content is None else str(content)
        return f"T:{content[:10]}..." if len(content) > 10 else f"T:{content}"
    return f"C:{props.get('point_type', '')}"


def _render_graph(graph, pos, image_path: str):
    """Draw *graph* at *pos*, save it as a PNG at *image_path* and return the figure.

    The figure is closed in pyplot before returning, even when drawing fails,
    so repeated calls do not accumulate open figures. The returned Figure
    object remains usable (e.g. ``fig.savefig(...)``) but is no longer shown
    by ``plt.show()``.
    """
    fig, ax = plt.subplots(figsize=(20, 20))
    try:
        # nodes(data='type') iterates in the same order draw_networkx_nodes uses.
        colors = [
            _NODE_COLORS.get(node_type, _DEFAULT_NODE_COLOR)
            for _, node_type in graph.nodes(data='type')
        ]
        nx.draw_networkx_nodes(graph, pos, node_color=colors, node_size=500, ax=ax)
        nx.draw_networkx_edges(graph, pos, edge_color='gray', width=1, ax=ax)
        labels = {node_id: _node_label(attrs) for node_id, attrs in graph.nodes(data=True)}
        nx.draw_networkx_labels(graph, pos, labels, font_size=8, ax=ax)
        ax.set_title("P&ID Knowledge Graph")
        ax.axis('off')
        fig.savefig(image_path, bbox_inches='tight', dpi=300)
    finally:
        plt.close(fig)
    return fig


def construct_graph_network(data: dict, validation_results_path: str, results_dir: str, storage=None):
    """Construct network graph from aggregated detection data.

    Builds an undirected networkx graph from ``data['nodes']`` and
    ``data['edges']``. It then writes two files into *results_dir*:
    ``<stem>_graph.png`` (rendering) and ``<stem>_graph_data.json``
    (``nx.node_link_data`` dump). ``<stem>`` is the stem of
    ``data['image_path']``, or ``graph`` when that key is absent or empty.

    Args:
        data: Aggregated detections. Keys read: 'nodes' (each with 'id',
            'type' and either 'coords' {x, y} for connection points or 'bbox'
            {xmin, ymin, xmax, ymax} otherwise; optional 'properties' and
            'content'), 'edges' (each with 'source', 'target' and optional
            'properties'), and optional 'image_path'.
        validation_results_path: Accepted for interface compatibility; not
            read by this function.
        results_dir: Output directory; created if it does not exist.
        storage: Optional storage backend. When None, one is obtained from
            ``StorageFactory.get_storage()``.
            NOTE(review): the backend is not used for any I/O here.

    Returns:
        ``(graph, pos, fig)`` on success. *pos* maps node id -> (x, y), and
        *fig* is the rendered matplotlib Figure, already closed in pyplot.
        ``(None, None, None)`` on any error; the error is logged with its
        traceback.
    """
    try:
        # Use provided storage or get a new one
        if storage is None:
            storage = StorageFactory.get_storage()

        graph, pos = _build_graph(data)

        if results_dir:
            os.makedirs(results_dir, exist_ok=True)
        stem = Path(data.get('image_path') or 'graph').stem

        fig = _render_graph(graph, pos, os.path.join(results_dir, f"{stem}_graph.png"))

        # Save graph data as JSON for future use
        graph_json_path = os.path.join(results_dir, f"{stem}_graph_data.json")
        with open(graph_json_path, 'w', encoding='utf-8') as f:
            json.dump(nx.node_link_data(graph), f, indent=2)

        # Return the figure that was actually drawn. plt.gcf() after
        # plt.close() would create and return a new, blank figure.
        return graph, pos, fig
    except Exception as e:
        logger.exception("Error in construct_graph_network: %s", e)
        return None, None, None
if __name__ == "__main__":
    # Manual smoke test: rebuild the graph from a saved aggregation result and
    # display the PNG rendering that construct_graph_network writes to disk.
    logging.basicConfig(level=logging.INFO)
    test_data_path = "results/test_aggregated.json"
    if not os.path.exists(test_data_path):
        logger.error("Test data not found: %s", test_data_path)
    else:
        with open(test_data_path, 'r', encoding='utf-8') as f:
            test_data = json.load(f)
        G, pos, fig = construct_graph_network(
            test_data,
            "results/validation.json",
            "results"
        )
        if G is not None:
            logger.info(
                "Graph built: %d nodes, %d edges",
                G.number_of_nodes(),
                G.number_of_edges(),
            )
            # construct_graph_network closes its drawing figure before
            # returning, so plt.show() would not display the graph. Show the
            # saved PNG instead.
            stem = Path(test_data.get('image_path', 'graph')).stem
            image_path = os.path.join("results", f"{stem}_graph.png")
            if os.path.exists(image_path):
                plt.figure(figsize=(12, 12))
                plt.imshow(plt.imread(image_path))
                plt.axis('off')
                plt.show()
            else:
                logger.warning("Rendered graph image not found: %s", image_path)