Spaces:
Build error
Build error
File size: 9,275 Bytes
import json
import networkx as nx
import matplotlib.pyplot as plt
import os
from pprint import pprint
import uuid
def _index_edges_by_id(edges):
    """Map each edge id to the first edge record that carries it."""
    index = {}
    for edge in edges:
        edge_id = edge.get('id')
        if edge_id is not None:
            # setdefault keeps the first occurrence, matching a linear
            # scan that stops at the first hit.
            index.setdefault(edge_id, edge)
    return index


def _add_line(G, pos, junction_map, line, edge_data, max_coord):
    """Add one line to ``G`` as an edge between two junction nodes.

    Coordinates come from the edge's ``connection_points`` when present,
    otherwise from the line's own ``start_point``/``end_point``. Junctions
    are de-duplicated by integer coordinates through ``junction_map``.
    Lines with a coordinate outside ``[0, max_coord]`` are reported and
    skipped.
    """
    start_point = line.get('start_point') or {}
    end_point = line.get('end_point') or {}
    properties = (edge_data or {}).get('properties') or {}
    conn_points = properties.get('connection_points')

    if conn_points:
        start_xy, end_xy = conn_points['start'], conn_points['end']
        start_id = str(uuid.uuid4())
        end_id = str(uuid.uuid4())
    else:
        # Fallback to the points stored on the line itself.
        start_xy, end_xy = start_point, end_point
        start_id = start_point.get('id', str(uuid.uuid4()))
        end_id = end_point.get('id', str(uuid.uuid4()))

    start_x, start_y = int(start_xy['x']), int(start_xy['y'])
    end_x, end_y = int(end_xy['x']), int(end_xy['y'])

    in_range = all(
        0 <= value <= max_coord
        for value in (start_x, start_y, end_x, end_y)
    )
    if not in_range:
        print(f"Skipping line with invalid coordinates: ({start_x}, {start_y}) -> ({end_x}, {end_y})")
        return

    endpoints = (
        (start_x, start_y, start_id, start_point),
        (end_x, end_y, end_id, end_point),
    )
    junction_ids = []
    for x, y, candidate_id, point in endpoints:
        key = f"{x}_{y}"
        if key not in junction_map:
            junction_map[key] = candidate_id
            G.add_node(candidate_id,
                       type='junction',
                       junction_type=point.get('type', 'unknown'),
                       coords={'x': x, 'y': y})
            pos[candidate_id] = (x, y)
        junction_ids.append(junction_map[key])

    G.add_edge(junction_ids[0], junction_ids[1],
               type='line',
               style=line.get('style') or {})


def _node_position(node):
    """Return ``(x, y)`` from a node's coords, center, or bbox midpoint."""
    coords = node.get('coords')
    if coords:
        return coords['x'], coords['y']
    if 'center' in node:
        return node['center']['x'], node['center']['y']
    bbox = node['bbox']
    return ((bbox['xmin'] + bbox['xmax']) / 2,
            (bbox['ymin'] + bbox['ymax']) / 2)


def _scale_positions(pos):
    """Rescale positions into the unit square, flipping the Y axis.

    An axis with zero extent maps every node to 0.5 on that axis.
    """
    x_vals = [p[0] for p in pos.values()]
    y_vals = [p[1] for p in pos.values()]
    x_min, x_max = min(x_vals), max(x_vals)
    y_min, y_max = min(y_vals), max(y_vals)

    scaled = {}
    for node, (x, y) in pos.items():
        fx = (x - x_min) / (x_max - x_min) if x_max > x_min else 0.5
        fy = (y - y_min) / (y_max - y_min) if y_max > y_min else 0.5
        scaled[node] = (fx, 1 - fy)  # Flip Y coordinates
    return scaled


def _node_style(node_data):
    """Return ``(color, size, label)`` used to draw one node."""
    node_type = node_data.get('type')
    if node_type == 'symbol':
        symbol_class = (node_data.get('properties') or {}).get('class', 'unknown')
        return 'lightblue', 1000, f"S:{symbol_class}"
    if node_type == 'text':
        content = str(node_data.get('content') or '')
        label = f"T:{content[:10]}..." if len(content) > 10 else f"T:{content}"
        return 'lightgreen', 800, label
    # Junctions: tiny pure-red dots without a label.
    return '#ff0000', 3, ""


def _draw_graph(G, scaled_pos, output_path):
    """Render ``G`` with matplotlib and save the figure to ``output_path``."""
    node_colors = []
    node_sizes = []
    labels = {}
    for node in G.nodes():
        color, size, label = _node_style(G.nodes[node])
        node_colors.append(color)
        node_sizes.append(size)
        if label:  # Labels only for symbols and texts
            labels[node] = label

    plt.figure(figsize=(20, 20))

    # Group edges by visual style so each distinct style needs only one
    # draw call; the dict's insertion order also drives the legend order.
    edges_by_style = {}
    for u, v, attrs in G.edges(data=True):
        if attrs.get('type') != 'line':
            continue
        style = attrs.get('style') or {}
        color = style.get('color', '#000000')
        if isinstance(color, list):
            color = tuple(color)  # lists are not hashable as dict keys
        width = float(style.get('stroke_width', 0.5))
        line_style = '--' if style.get('connection_type') == 'dashed' else '-'
        edges_by_style.setdefault((line_style, color, width), []).append((u, v))

    for (line_style, color, width), edgelist in edges_by_style.items():
        nx.draw_networkx_edges(G, scaled_pos,
                               edgelist=edgelist,
                               edge_color=color,
                               width=width,
                               alpha=0.7,
                               style=line_style)

    nx.draw_networkx_nodes(G, scaled_pos,
                           node_color=node_colors,
                           node_size=node_sizes,
                           alpha=1.0)
    nx.draw_networkx_labels(G, scaled_pos, labels,
                            font_size=8,
                            font_weight='bold')

    # Legend entries are proxy artists that are never added to the axes,
    # so they do not paint stray markers onto the plot itself.
    legend_elements = [
        plt.Line2D([], [], marker='o', linestyle='None',
                   color='lightblue', markersize=14, label='Symbol'),
        plt.Line2D([], [], marker='o', linestyle='None',
                   color='lightgreen', markersize=14, label='Text'),
        plt.Line2D([], [], marker='o', linestyle='None',
                   color='red', markersize=5, label='Junction'),
    ]
    for line_style, color, width in edges_by_style:
        legend_elements.append(
            plt.Line2D([0], [0], color=color, linestyle=line_style,
                       linewidth=width, label=f'Line ({line_style})')
        )

    # Single-column legend placed to the right of the axes.
    plt.legend(handles=legend_elements,
               loc='center left',
               bbox_to_anchor=(1, 0.5),
               ncol=1,
               fontsize=12,
               title="Graph Elements")
    plt.title("P&ID Knowledge Graph Visualization", pad=20, fontsize=16)
    plt.axis('on')
    plt.grid(True)

    # bbox_inches='tight' keeps the legend, which sits outside the axes.
    plt.savefig(output_path, bbox_inches='tight', dpi=300,
                facecolor='white', edgecolor='none')
    plt.close()


def create_graph_visualization(json_path, save_plot=True, max_coord=10000):
    """Create and visualize a graph from the aggregated JSON data.

    Lines become edges between junction nodes (one junction per distinct
    integer coordinate pair); entries of ``nodes`` whose type is
    ``'symbol'`` or ``'text'`` become additional nodes. Progress and
    skipped items are reported with ``print``.

    Args:
        json_path: Path of the JSON file to load. The code reads the
            top-level keys ``lines``, ``edges`` and ``nodes``.
        save_plot: When true, render the graph and write
            ``graph_visualization.png`` into the directory of ``json_path``.
        max_coord: Largest accepted line coordinate; a line with any
            endpoint coordinate outside ``[0, max_coord]`` is skipped.

    Returns:
        A ``(G, scaled_pos)`` tuple: the ``networkx.Graph`` and a dict
        mapping node id to its ``(x, y)`` position rescaled to ``[0, 1]``
        with the Y axis flipped.
    """
    # Load the aggregated data
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    print("\nData Structure:")
    print(f"Keys in data: {data.keys()}")
    for key, value in data.items():
        if isinstance(value, list):
            print(f"Number of {key}: {len(value)}")
            if value:
                print(f"Sample {key}:", value[0])

    G = nx.Graph()
    pos = {}
    # Track unique junctions by coordinates to avoid duplicates
    junction_map = {}
    # Built once so each line's edge lookup is O(1) instead of a scan.
    edge_index = _index_edges_by_id(data.get('edges', []))

    print("\nProcessing Lines and Junctions:")
    for line in data.get('lines', []):
        try:
            _add_line(G, pos, junction_map, line,
                      edge_index.get(line.get('id')), max_coord)
        except (KeyError, TypeError, ValueError, AttributeError) as e:
            # Best effort: one malformed line must not abort the graph.
            print(f"Error processing line: {str(e)}")
            continue

    # Add symbols and texts after lines
    print("\nProcessing Symbols and Texts:")
    for node in data.get('nodes', []):
        node_type = node.get('type')
        if node_type not in ('symbol', 'text'):
            continue
        try:
            node_id = node['id']
            x, y = _node_position(node)
        except (KeyError, TypeError) as e:
            print(f"Skipping {node_type} without usable id/position: {e}")
            continue
        G.add_node(node_id, **node)
        pos[node_id] = (x, y)
        print(f"Added {node_type} at ({x}, {y})")

    # Add default node if graph is empty so scaling and drawing still work
    if not pos:
        default_id = str(uuid.uuid4())
        G.add_node(default_id, type='junction', coords={'x': 0, 'y': 0})
        pos[default_id] = (0, 0)

    scaled_pos = _scale_positions(pos)

    if save_plot:
        output_path = os.path.join(os.path.dirname(json_path),
                                   "graph_visualization.png")
        _draw_graph(G, scaled_pos, output_path)

    return G, scaled_pos
if __name__ == "__main__":
    # Manual smoke test: build the visualization for one sample result file.
    json_path = "results/001_page_1_text_aggregated.json"
    if not os.path.exists(json_path):
        print(f"Error: Could not find {json_path}")
    else:
        G, scaled_pos = create_graph_visualization(json_path)
        # NOTE(review): the function closes its figure after saving, so
        # this call likely has no open figure left to display.
        plt.show()