import json
import traceback
import uuid

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np


def _bbox_center(bbox):
    """Return the (x, y) centre of a bbox mapping, or None if there is no bbox.

    The bbox is expected to be a mapping with 'xmin', 'xmax', 'ymin' and
    'ymax' keys; a malformed non-empty bbox raises (KeyError/TypeError) so
    the caller's error boundary reports it.
    """
    if not bbox:
        return None
    return (
        (bbox['xmin'] + bbox['xmax']) / 2,
        (bbox['ymin'] + bbox['ymax']) / 2,
    )


def _node_id(item):
    """Return the item's 'id', or a fresh UUID string when absent or None."""
    node_id = item.get('id')
    # networkx rejects None as a node, so an explicit null id also falls back.
    return node_id if node_id is not None else str(uuid.uuid4())


def _node_label(node_data):
    """Build the short drawing label for one node's attribute dict."""
    if node_data['type'] == 'symbol':
        return f"S:{node_data['class_name']}"
    text = node_data.get('text')
    # Guard against null / non-string text values coming from the input JSON.
    text = '' if text is None else str(text)
    return f"T:{text[:10]}..." if len(text) > 10 else f"T:{text}"


def _draw_graph(G, pos):
    """Draw the graph on a new 20x20 figure and return that figure.

    Only nodes that have a position in ``pos`` (i.e. had a bbox) are drawn,
    together with the edges whose endpoints are both positioned; unpositioned
    nodes remain in ``G`` but are not rendered.  The figure is closed again if
    drawing fails so it does not leak into pyplot's global state.
    """
    fig, ax = plt.subplots(figsize=(20, 20))
    try:
        drawable = [node for node in G.nodes() if node in pos]
        node_colors = [
            'lightblue' if G.nodes[node]['type'] == 'symbol' else 'lightgreen'
            for node in drawable
        ]
        nx.draw_networkx_nodes(
            G, pos, nodelist=drawable, node_color=node_colors,
            node_size=500, ax=ax,
        )

        drawable_edges = [
            (u, v) for u, v in G.edges() if u in pos and v in pos
        ]
        nx.draw_networkx_edges(
            G, pos, edgelist=drawable_edges, edge_color='gray', width=1, ax=ax,
        )

        labels = {node: _node_label(G.nodes[node]) for node in drawable}
        nx.draw_networkx_labels(G, pos, labels=labels, font_size=8, ax=ax)

        ax.set_title("P&ID Network Graph")
        ax.axis('off')
    except Exception:
        plt.close(fig)
        raise
    return fig


def create_connected_graph(input_data):
    """Create a connected graph from the input data.

    Args:
        input_data: dict with the keys 'symbols', 'texts', 'lines', 'nodes'
            and 'edges'. Each symbol/text is a dict that may carry 'id',
            'bbox' (mapping with xmin/xmax/ymin/ymax), 'confidence' and
            'class' (symbols) or 'text' (texts). Each edge is a dict with
            'source', 'target' and optional 'type' / 'properties'.

    Returns:
        ``(G, pos, fig)``: the ``networkx.Graph``, a dict mapping node id to
        its bbox centre (only for nodes that had a bbox), and the matplotlib
        figure.  On any error the error is printed with a traceback and
        ``(None, None, None)`` is returned.
    """
    try:
        # Validate input data structure
        if not isinstance(input_data, dict):
            raise ValueError("Invalid input data format")

        # NOTE(review): 'lines' and 'nodes' are required but never read below;
        # confirm whether they should contribute nodes/edges.
        required_keys = ['symbols', 'texts', 'lines', 'nodes', 'edges']
        if not all(key in input_data for key in required_keys):
            raise ValueError(f"Missing required keys in input data. Expected: {required_keys}")

        G = nx.Graph()
        pos = {}  # node id -> (x, y) bbox centre, used as fixed layout

        for symbol in input_data['symbols']:
            symbol_id = _node_id(symbol)
            bbox = symbol.get('bbox', [])
            center = _bbox_center(bbox)
            if center is not None:
                pos[symbol_id] = center
            G.add_node(
                symbol_id,
                type='symbol',
                class_name=symbol.get('class', ''),
                bbox=bbox,
                confidence=symbol.get('confidence', 0.0),
            )

        for text in input_data['texts']:
            text_id = _node_id(text)
            bbox = text.get('bbox', [])
            center = _bbox_center(bbox)
            if center is not None:
                pos[text_id] = center
            G.add_node(
                text_id,
                type='text',
                text=text.get('text', ''),
                bbox=bbox,
                confidence=text.get('confidence', 0.0),
            )

        for edge in input_data['edges']:
            source = edge.get('source')
            target = edge.get('target')
            # Membership alone is the right filter: None/unhashable ids are
            # never in G, while falsy-but-valid ids such as 0 are kept.
            if source in G and target in G:
                G.add_edge(
                    source,
                    target,
                    type=edge.get('type', 'connection'),
                    properties=edge.get('properties', {}),
                )

        fig = _draw_graph(G, pos)
        return G, pos, fig

    except Exception as e:
        # Top-level boundary: callers rely on (None, None, None) on failure.
        print(f"Error in create_connected_graph: {str(e)}")
        traceback.print_exc()
        return None, None, None


if __name__ == "__main__":
    # Manual smoke test against a locally produced aggregation result.
    with open('results/0_aggregated.json', encoding='utf-8') as f:
        data = json.load(f)
    G, pos, fig = create_connected_graph(data)
    if fig is not None:
        plt.show()