import json
import os
import uuid
from pprint import pprint

import matplotlib.pyplot as plt
import networkx as nx

# Line endpoints outside [0, _MAX_COORDINATE] on either axis are treated as
# invalid and the whole line is skipped.
_MAX_COORDINATE = 10000


def _point_id(point):
    """Return the id stored on a line endpoint, or a fresh UUID string if absent/null."""
    point_id = point.get('id')
    return str(uuid.uuid4()) if point_id is None else point_id


def _ensure_junction(graph, pos, junction_map, x, y, node_id, junction_type):
    """Return the junction node id at (x, y), creating the node on first sight."""
    key = f"{x}_{y}"
    if key not in junction_map:
        junction_map[key] = node_id
        graph.add_node(node_id,
                       type='junction',
                       junction_type=junction_type,
                       coords={'x': x, 'y': y})
        pos[node_id] = (x, y)
    return junction_map[key]


def _node_position(node):
    """Return (x, y) for a symbol/text node from coords, center or bbox, else None."""
    coords = node.get('coords')
    if coords:
        return coords['x'], coords['y']
    if 'center' in node:
        return node['center']['x'], node['center']['y']
    bbox = node.get('bbox')
    if bbox:
        return ((bbox['xmin'] + bbox['xmax']) / 2,
                (bbox['ymin'] + bbox['ymax']) / 2)
    return None


def _build_graph(data):
    """Build the graph and the raw (unscaled) node positions from the loaded JSON."""
    graph = nx.Graph()
    pos = {}

    # Junctions are de-duplicated by their integer coordinates ("x_y" key).
    junction_map = {}

    # Index edges by id once instead of rescanning the edge list for every
    # line. setdefault keeps the FIRST edge for a repeated id, which is what
    # a first-match linear scan would pick.
    edges_by_id = {}
    for edge in data.get('edges', []):
        if 'id' in edge:
            edges_by_id.setdefault(edge['id'], edge)

    print("\nProcessing Lines and Junctions:")
    for line in data.get('lines', []):
        # Best effort: a malformed line is reported and skipped, it must not
        # abort the whole visualization.
        try:
            edge_data = edges_by_id.get(line['id']) if 'id' in line else None
            properties = (edge_data.get('properties') or {}) if edge_data else {}
            conn_points = properties.get('connection_points')

            start_point = line.get('start_point') or {}
            end_point = line.get('end_point') or {}

            if conn_points:
                # Edge connection points take precedence over the line's own
                # endpoints; junction ids are freshly generated in this case.
                start_x = int(conn_points['start']['x'])
                start_y = int(conn_points['start']['y'])
                end_x = int(conn_points['end']['x'])
                end_y = int(conn_points['end']['y'])
                start_id = str(uuid.uuid4())
                end_id = str(uuid.uuid4())
            else:
                # Fallback to line points
                start_x = int(start_point['x'])
                start_y = int(start_point['y'])
                end_x = int(end_point['x'])
                end_y = int(end_point['y'])
                start_id = _point_id(start_point)
                end_id = _point_id(end_point)

            # Skip invalid coordinates
            if not all(0 <= value <= _MAX_COORDINATE
                       for value in (start_x, start_y, end_x, end_y)):
                print(f"Skipping line with invalid coordinates: ({start_x}, {start_y}) -> ({end_x}, {end_y})")
                continue

            start_node = _ensure_junction(
                graph, pos, junction_map, start_x, start_y, start_id,
                start_point.get('type', 'unknown'))
            end_node = _ensure_junction(
                graph, pos, junction_map, end_x, end_y, end_id,
                end_point.get('type', 'unknown'))

            # Add line as edge with style properties
            graph.add_edge(start_node, end_node,
                           type='line',
                           style=line.get('style', {}))
        except Exception as e:
            print(f"Error processing line: {str(e)}")
            continue

    print("\nProcessing Symbols and Texts:")
    for node in data.get('nodes', []):
        # Same best-effort policy as for lines: report and skip bad entries.
        try:
            if node.get('type') not in ('symbol', 'text'):
                continue
            position = _node_position(node)
            if position is None:
                print(f"Skipping {node['type']} without coords, center or bbox")
                continue
            node_id = node['id']
            graph.add_node(node_id, **node)
            pos[node_id] = position
            print(f"Added {node['type']} at ({position[0]}, {position[1]})")
        except Exception as e:
            print(f"Error processing node: {str(e)}")
            continue

    # Add default node if graph is empty so scaling and drawing always have
    # at least one position to work with.
    if not pos:
        default_id = str(uuid.uuid4())
        graph.add_node(default_id, type='junction', coords={'x': 0, 'y': 0})
        pos[default_id] = (0, 0)

    return graph, pos


def _scale_positions(pos):
    """Scale positions into [0, 1] on both axes, flipping Y so that y grows downwards."""
    x_vals = [x for x, _ in pos.values()]
    y_vals = [y for _, y in pos.values()]
    x_min, x_max = min(x_vals), max(x_vals)
    y_min, y_max = min(y_vals), max(y_vals)
    x_span = x_max - x_min
    y_span = y_max - y_min

    scaled_pos = {}
    for node, (x, y) in pos.items():
        # A degenerate axis (all values equal) is centred at 0.5.
        scaled_x = (x - x_min) / x_span if x_span > 0 else 0.5
        scaled_y = 1 - ((y - y_min) / y_span if y_span > 0 else 0.5)
        scaled_pos[node] = (scaled_x, scaled_y)
    return scaled_pos


def _node_styles(graph):
    """Return (colors, sizes, labels) for drawing, in graph node order."""
    colors = []
    sizes = []
    labels = {}
    for node, attrs in graph.nodes(data=True):
        node_type = attrs.get('type')
        if node_type == 'symbol':
            colors.append('lightblue')
            sizes.append(1000)
            symbol_class = (attrs.get('properties') or {}).get('class', 'unknown')
            labels[node] = f"S:{symbol_class}"
        elif node_type == 'text':
            colors.append('lightgreen')
            sizes.append(800)
            content = attrs.get('content') or ''
            labels[node] = f"T:{content[:10]}..." if len(content) > 10 else f"T:{content}"
        else:  # junction
            colors.append('#ff0000')  # Pure red
            sizes.append(3)  # Tiny marker; junctions get no label
    return colors, sizes, labels


def _draw_graph(graph, scaled_pos, output_path):
    """Render the graph with a legend and save it as an image at output_path."""
    node_colors, node_sizes, labels = _node_styles(graph)

    fig = plt.figure(figsize=(20, 20))
    try:
        # Group edges by visual style so each style needs a single draw call.
        # Dict insertion order also gives the legend its first-seen order.
        edges_by_style = {}
        for u, v, attrs in graph.edges(data=True):
            if attrs.get('type') != 'line':
                continue
            style = attrs.get('style') or {}
            color = style.get('color', '#000000')
            width = float(style.get('stroke_width', 0.5))
            line_style = '--' if style.get('connection_type') == 'dashed' else '-'
            edges_by_style.setdefault((line_style, color, width), []).append((u, v))

        for (line_style, color, width), edgelist in edges_by_style.items():
            nx.draw_networkx_edges(graph, scaled_pos,
                                   edgelist=edgelist,
                                   edge_color=color,
                                   width=width,
                                   alpha=0.7,
                                   style=line_style)

        nx.draw_networkx_nodes(graph, scaled_pos,
                               node_color=node_colors,
                               node_size=node_sizes,
                               alpha=1.0)

        # Labels exist only for symbols and texts.
        nx.draw_networkx_labels(graph, scaled_pos, labels,
                                font_size=8,
                                font_weight='bold')

        # Legend handles are proxy artists that are never added to the axes.
        # (plt.scatter would draw real markers into the plot at (0, 0).)
        legend_elements = [
            plt.Line2D([], [], marker='o', linestyle='None',
                       color='lightblue', markersize=14, label='Symbol'),
            plt.Line2D([], [], marker='o', linestyle='None',
                       color='lightgreen', markersize=14, label='Text'),
            plt.Line2D([], [], marker='o', linestyle='None',
                       color='red', markersize=4.5, label='Junction'),
        ]
        for line_style, color, width in edges_by_style:
            legend_elements.append(
                plt.Line2D([0], [0], color=color, linestyle=line_style,
                           linewidth=width, label=f'Line ({line_style})')
            )

        # Single-column legend placed to the right of the axes.
        plt.legend(handles=legend_elements,
                   loc='center left',
                   bbox_to_anchor=(1, 0.5),
                   ncol=1,
                   fontsize=12,
                   title="Graph Elements")

        plt.title("P&ID Knowledge Graph Visualization", pad=20, fontsize=16)
        plt.axis('on')
        plt.grid(True)

        # bbox_inches='tight' keeps the legend outside the axes in the image.
        plt.savefig(output_path,
                    bbox_inches='tight',
                    dpi=300,
                    facecolor='white',
                    edgecolor='none')
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)


def create_graph_visualization(json_path, save_plot=True):
    """Create and visualize a graph from the aggregated JSON data.

    Lines become edges between junction nodes (de-duplicated by integer
    coordinates); entries of ``nodes`` with type ``symbol`` or ``text``
    become nodes of their own. Malformed lines and nodes are reported on
    stdout and skipped.

    Args:
        json_path: Path to the aggregated JSON file. The keys read are
            ``lines``, ``edges`` and ``nodes``; each is optional.
        save_plot: When true, render the graph and write
            ``graph_visualization.png`` next to ``json_path``.

    Returns:
        A ``(G, scaled_pos)`` tuple: the ``networkx.Graph`` and a dict
        mapping each node id to its ``(x, y)`` position scaled into
        ``[0, 1]`` with the Y axis flipped.
    """
    # Load the aggregated data
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    print("\nData Structure:")
    print(f"Keys in data: {data.keys()}")
    for key, value in data.items():
        if isinstance(value, list):
            print(f"Number of {key}: {len(value)}")
            if value:
                print(f"Sample {key}:", value[0])

    G, pos = _build_graph(data)
    scaled_pos = _scale_positions(pos)

    if save_plot:
        output_path = os.path.join(os.path.dirname(json_path), "graph_visualization.png")
        _draw_graph(G, scaled_pos, output_path)

    return G, scaled_pos


if __name__ == "__main__":
    # Test the visualization
    json_path = "results/001_page_1_text_aggregated.json"
    if os.path.exists(json_path):
        G, scaled_pos = create_graph_visualization(json_path)
        # NOTE: the figure is closed after saving, so there is nothing left
        # for show() to display; the PNG next to the JSON is the output.
        plt.show()
    else:
        print(f"Error: Could not find {json_path}")