# intelligent-pid / graph_visualization.py
# msIntui
# Initial commit: Add core files for P&ID processing
# 9847531
import json
import networkx as nx
import matplotlib.pyplot as plt
import os
from pprint import pprint
import uuid
def _print_data_summary(data):
    """Print the top-level keys of *data*, plus a count and first sample of each list value."""
    print("\nData Structure:")
    print(f"Keys in data: {data.keys()}")
    for key, value in data.items():
        if isinstance(value, list):
            print(f"Number of {key}: {len(value)}")
            if value:
                print(f"Sample {key}:", value[0])


def _index_edges_by_id(edges):
    """Map edge ``id`` -> edge dict, keeping the first edge seen for each id.

    Built once so each line's lookup is O(1) instead of rescanning every edge
    per line. Edges without an ``id`` are ignored.
    """
    index = {}
    for edge in edges:
        edge_id = edge.get('id')
        if edge_id is not None:
            index.setdefault(edge_id, edge)
    return index


def _line_endpoints(line, edge_data):
    """Return the ``(start, end)`` endpoints of *line* as ``(x, y, node_id)`` tuples.

    Coordinates come from the matching edge's ``properties.connection_points``
    when present (fresh UUIDs are used as node ids). Otherwise they come from
    the line's own ``start_point`` / ``end_point``, reusing their ``id`` if given.
    Coordinates are truncated with ``int()``.

    Raises:
        KeyError, TypeError, ValueError: on malformed point data.
    """
    properties = (edge_data or {}).get('properties') or {}
    if 'connection_points' in properties:
        points = properties['connection_points']
        start, end = points['start'], points['end']
        start_id, end_id = str(uuid.uuid4()), str(uuid.uuid4())
    else:
        start, end = line['start_point'], line['end_point']
        start_id = start.get('id', str(uuid.uuid4()))
        end_id = end.get('id', str(uuid.uuid4()))
    return ((int(start['x']), int(start['y']), start_id),
            (int(end['x']), int(end['y']), end_id))


def _add_junction(G, pos, junction_map, x, y, node_id, junction_type):
    """Ensure a junction node exists at ``(x, y)`` and return its node id.

    Junctions are de-duplicated by integer coordinates. The first id seen at a
    location wins, and later endpoints at the same spot reuse it.
    """
    key = (x, y)
    if key not in junction_map:
        junction_map[key] = node_id
        G.add_node(node_id,
                   type='junction',
                   junction_type=junction_type,
                   coords={'x': x, 'y': y})
        pos[node_id] = (x, y)
    return junction_map[key]


def _node_position(node):
    """Return the raw ``(x, y)`` of a symbol/text node.

    Uses ``coords`` if non-empty, else ``center``, else the midpoint of ``bbox``.

    Raises:
        KeyError, TypeError: if none of those fields is usable.
    """
    coords = node.get('coords') or {}
    if coords:
        return coords['x'], coords['y']
    if 'center' in node:
        return node['center']['x'], node['center']['y']
    bbox = node['bbox']
    return (bbox['xmin'] + bbox['xmax']) / 2, (bbox['ymin'] + bbox['ymax']) / 2


def _scale_positions(pos):
    """Normalise raw positions into ``[0, 1]`` and flip the y axis.

    An axis with zero extent maps every node to 0.5 on that axis.
    """
    xs = [x for x, _ in pos.values()]
    ys = [y for _, y in pos.values()]
    x_min, x_span = min(xs), max(xs) - min(xs)
    y_min, y_span = min(ys), max(ys) - min(ys)
    scaled = {}
    for node, (x, y) in pos.items():
        sx = (x - x_min) / x_span if x_span > 0 else 0.5
        sy = (y - y_min) / y_span if y_span > 0 else 0.5
        # Flip Y: presumably source coordinates are image-space (y grows downward).
        scaled[node] = (sx, 1 - sy)
    return scaled


def _node_style(node_data):
    """Return ``(color, size, label)`` used to draw a node; junctions get no label."""
    node_type = node_data.get('type')
    if node_type == 'symbol':
        symbol_class = (node_data.get('properties') or {}).get('class', 'unknown')
        return 'lightblue', 1000, f"S:{symbol_class}"
    if node_type == 'text':
        content = str(node_data.get('content') or '')
        label = f"T:{content[:10]}..." if len(content) > 10 else f"T:{content}"
        return 'lightgreen', 800, label
    return '#ff0000', 3, ''  # junction: tiny red dot, no label


def _draw_graph(G, scaled_pos, output_path):
    """Render *G* at *scaled_pos* with a legend and save it as a PNG to *output_path*."""
    plt.figure(figsize=(20, 20))

    # Group line edges by (linestyle, color, width) so each style is drawn in one
    # call. Dict insertion order also gives a first-seen legend order.
    edge_groups = {}
    for u, v, edge_attrs in G.edges(data=True):
        if edge_attrs.get('type') != 'line':
            continue
        style = edge_attrs.get('style') or {}
        color = style.get('color', '#000000')
        if isinstance(color, list):
            color = tuple(color)  # JSON arrays are unhashable; matplotlib accepts tuples
        stroke_width = style.get('stroke_width')
        width = float(stroke_width) if stroke_width is not None else 0.5
        line_style = '--' if style.get('connection_type') == 'dashed' else '-'
        edge_groups.setdefault((line_style, color, width), []).append((u, v))

    for (line_style, color, width), edgelist in edge_groups.items():
        # One color per edge avoids networkx misreading an RGB tuple as per-edge colors.
        nx.draw_networkx_edges(G, scaled_pos,
                               edgelist=edgelist,
                               edge_color=[color] * len(edgelist),
                               width=width,
                               alpha=0.7,
                               style=line_style)

    # Explicit nodelist keeps colors/sizes aligned with the nodes they describe.
    nodelist = list(G.nodes())
    styles = [_node_style(G.nodes[n]) for n in nodelist]
    nx.draw_networkx_nodes(G, scaled_pos,
                           nodelist=nodelist,
                           node_color=[color for color, _, _ in styles],
                           node_size=[size for _, size, _ in styles],
                           alpha=1.0)

    # Labels only for symbols and texts (junction labels are empty).
    labels = {n: label for n, (_, _, label) in zip(nodelist, styles) if label}
    nx.draw_networkx_labels(G, scaled_pos, labels,
                            font_size=8,
                            font_weight='bold')

    # Proxy artists: unlike plt.scatter([0], [0], ...) they add nothing to the
    # axes, so no stray markers appear at the plot origin.
    legend_elements = [
        plt.Line2D([], [], marker='o', linestyle='None', color='lightblue',
                   markersize=14, label='Symbol'),
        plt.Line2D([], [], marker='o', linestyle='None', color='lightgreen',
                   markersize=14, label='Text'),
        plt.Line2D([], [], marker='o', linestyle='None', color='red',
                   markersize=4.5, label='Junction'),
    ]
    legend_elements.extend(
        plt.Line2D([0], [0], color=color, linestyle=line_style, linewidth=width,
                   label=f'Line ({line_style}, {color}, w={width:g})')
        for line_style, color, width in edge_groups
    )

    # Single-column legend placed to the right of the axes.
    plt.legend(handles=legend_elements,
               loc='center left',
               bbox_to_anchor=(1, 0.5),
               ncol=1,
               fontsize=12,
               title="Graph Elements")
    plt.title("P&ID Knowledge Graph Visualization", pad=20, fontsize=16)
    plt.axis('on')
    plt.grid(True)

    # bbox_inches='tight' keeps the outside-the-axes legend in the saved image.
    plt.savefig(output_path, bbox_inches='tight', dpi=300,
                facecolor='white', edgecolor='none')
    plt.close()


def create_graph_visualization(json_path, save_plot=True, *, max_coord=10000):
    """Build a graph from an aggregated P&ID JSON file and optionally plot it.

    Lines become edges between junction nodes, which are de-duplicated by
    integer coordinates. ``symbol`` and ``text`` entries of ``nodes`` become
    nodes carrying all of their JSON attributes. Malformed lines and nodes are
    reported and skipped.

    Args:
        json_path: Path to the aggregated JSON. Reads the optional top-level
            keys ``lines``, ``edges`` and ``nodes``.
        save_plot: When true, render the graph and save it as
            ``graph_visualization.png`` in the directory of *json_path*.
        max_coord: Lines with any endpoint coordinate outside
            ``[0, max_coord]`` are skipped.

    Returns:
        ``(G, scaled_pos)``: the ``nx.Graph`` and a dict mapping node id to an
        ``(x, y)`` position scaled into ``[0, 1]`` with the y axis flipped.
    """
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    _print_data_summary(data)

    G = nx.Graph()
    pos = {}
    junction_map = {}  # (x, y) -> junction node id
    edges_by_id = _index_edges_by_id(data.get('edges', []))

    print("\nProcessing Lines and Junctions:")
    for line in data.get('lines', []):
        try:
            (start_x, start_y, start_id), (end_x, end_y, end_id) = _line_endpoints(
                line, edges_by_id.get(line.get('id')))
            if not all(0 <= c <= max_coord for c in (start_x, start_y, end_x, end_y)):
                print(f"Skipping line with invalid coordinates: ({start_x}, {start_y}) -> ({end_x}, {end_y})")
                continue
            # Endpoint metadata may be absent when coords came from connection_points.
            start_point = line.get('start_point') or {}
            end_point = line.get('end_point') or {}
            u = _add_junction(G, pos, junction_map, start_x, start_y, start_id,
                              start_point.get('type', 'unknown'))
            v = _add_junction(G, pos, junction_map, end_x, end_y, end_id,
                              end_point.get('type', 'unknown'))
            G.add_edge(u, v, type='line', style=line.get('style', {}))
        except (KeyError, TypeError, ValueError, AttributeError) as e:
            # Best effort: one malformed line should not abort the whole graph.
            print(f"Error processing line: {str(e)}")

    print("\nProcessing Symbols and Texts:")
    for node in data.get('nodes', []):
        node_type = node.get('type')
        if node_type not in ('symbol', 'text'):
            continue
        try:
            node_id = node['id']
            x, y = _node_position(node)
        except (KeyError, TypeError) as e:
            print(f"Skipping {node_type} without usable id/position: {e}")
            continue
        G.add_node(node_id, **node)
        pos[node_id] = (x, y)
        print(f"Added {node_type} at ({x}, {y})")

    # Keep downstream scaling/drawing well-defined for an empty graph.
    if not pos:
        default_id = str(uuid.uuid4())
        G.add_node(default_id, type='junction', coords={'x': 0, 'y': 0})
        pos[default_id] = (0, 0)

    scaled_pos = _scale_positions(pos)

    if save_plot:
        output_path = os.path.join(os.path.dirname(json_path), "graph_visualization.png")
        _draw_graph(G, scaled_pos, output_path)

    return G, scaled_pos
if __name__ == "__main__":
    import sys

    # Usage: python graph_visualization.py [path/to/aggregated.json]
    json_path = sys.argv[1] if len(sys.argv) > 1 else "results/001_page_1_text_aggregated.json"
    if os.path.exists(json_path):
        G, scaled_pos = create_graph_visualization(json_path)
        # create_graph_visualization saves the PNG and closes the figure, so a
        # plt.show() here would display nothing; report the result instead.
        print(f"Graph built: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")
    else:
        print(f"Error: Could not find {json_path}")
        sys.exit(1)