import os
import json
import logging
import traceback
from pathlib import Path

import networkx as nx
import matplotlib.pyplot as plt

from storage import StorageFactory

logger = logging.getLogger(__name__)

# Fill colour per node 'type'; anything else (e.g. 'connection_point') is grey.
_NODE_COLORS = {'symbol': 'lightblue', 'text': 'lightgreen'}
_DEFAULT_NODE_COLOR = 'lightgray'


def _node_position(node: dict) -> tuple:
    """Return the (x, y) drawing position of an aggregated-data node.

    Connection points use their explicit ``coords``; every other node type
    (symbols, text) is placed at the centre of its ``bbox``.
    """
    if node['type'] == 'connection_point':
        return (node['coords']['x'], node['coords']['y'])
    bbox = node['bbox']
    return (
        (bbox['xmin'] + bbox['xmax']) / 2,
        (bbox['ymin'] + bbox['ymax']) / 2,
    )


def _node_label(attrs: dict) -> str:
    """Return the short display label for a node given its attribute dict."""
    node_type = attrs.get('type')
    # 'properties' may be absent or None on any node type.
    properties = attrs.get('properties') or {}
    if node_type == 'symbol':
        return f"S:{properties.get('class', '')}"
    if node_type == 'text':
        content = attrs.get('content')
        content = '' if content is None else str(content)
        # Truncate long text to keep the plot readable.
        return f"T:{content[:10]}..." if len(content) > 10 else f"T:{content}"
    return f"C:{properties.get('point_type', '')}"


def construct_graph_network(data: dict, validation_results_path: str, results_dir: str,
                            storage=None, *, close_figure: bool = True):
    """Construct network graph from aggregated detection data.

    Builds an undirected ``networkx.Graph`` from ``data['nodes']`` and
    ``data['edges']``, draws it, saves the drawing as
    ``<stem>_graph.png`` and the graph as node-link JSON
    ``<stem>_graph_data.json`` in ``results_dir``. ``<stem>`` is the stem of
    ``data['image_path']``, or ``'graph'`` when that is missing.

    Args:
        data: Aggregated detection data with ``'nodes'`` (dicts with ``'id'``,
            ``'type'`` and either ``'coords'`` or ``'bbox'``) and ``'edges'``
            (dicts with ``'source'``, ``'target'`` and optional ``'properties'``).
        validation_results_path: Currently unused; kept for interface
            compatibility.
        results_dir: Output directory; created if it does not exist.
        storage: Storage backend; obtained from ``StorageFactory`` when None.
            NOTE(review): not used for writing yet — outputs go through
            plain file I/O.
        close_figure: When True (default) the figure is closed in pyplot
            after saving so repeated calls do not accumulate open figures;
            the returned Figure can still be saved or embedded. Pass False
            to keep it registered with pyplot (e.g. for ``plt.show()``).

    Returns:
        ``(G, pos, fig)`` — the graph, a ``{node_id: (x, y)}`` position map,
        and the matplotlib Figure the graph was drawn on; or
        ``(None, None, None)`` if any step fails (the error is logged).
    """
    fig = None
    try:
        # Use provided storage or get a new one
        if storage is None:
            storage = StorageFactory.get_storage()

        G = nx.Graph()
        pos = {}

        for node in data.get('nodes', []):
            node_id = node['id']
            pos[node_id] = _node_position(node)
            G.add_node(node_id, **node)

        for edge in data.get('edges', []):
            source, target = edge['source'], edge['target']
            # An edge to an unknown node would create a node without 'type'
            # or a position and make drawing fail; skip it instead.
            missing = [n for n in (source, target) if n not in pos]
            if missing:
                logger.warning("Skipping edge %s -> %s: unknown node(s) %s",
                               source, target, missing)
                continue
            G.add_edge(source, target, **(edge.get('properties') or {}))

        # Draw on an explicit figure/axes so the figure we return is the one
        # that was actually drawn (plt.gcf() after plt.close() yields a new,
        # empty figure).
        fig, ax = plt.subplots(figsize=(20, 20))
        node_colors = [
            _NODE_COLORS.get(G.nodes[n].get('type'), _DEFAULT_NODE_COLOR)
            for n in G.nodes()
        ]
        nx.draw_networkx_nodes(G, pos, node_color=node_colors, node_size=500, ax=ax)
        nx.draw_networkx_edges(G, pos, edge_color='gray', width=1, ax=ax)
        labels = {n: _node_label(attrs) for n, attrs in G.nodes(data=True)}
        nx.draw_networkx_labels(G, pos, labels, font_size=8, ax=ax)
        ax.set_title("P&ID Knowledge Graph")
        ax.axis('off')

        if results_dir:
            os.makedirs(results_dir, exist_ok=True)
        stem = Path(data.get('image_path') or 'graph').stem

        graph_image_path = os.path.join(results_dir, f"{stem}_graph.png")
        fig.savefig(graph_image_path, bbox_inches='tight', dpi=300)
        if close_figure:
            plt.close(fig)

        # Save graph data as JSON for future use
        graph_json_path = os.path.join(results_dir, f"{stem}_graph_data.json")
        with open(graph_json_path, 'w', encoding='utf-8') as f:
            json.dump(nx.node_link_data(G), f, indent=2)

        return G, pos, fig

    except Exception as e:
        logger.exception("Error in construct_graph_network: %s", e)
        if fig is not None:
            # Don't leak a half-drawn figure into pyplot's global state.
            plt.close(fig)
        return None, None, None


if __name__ == "__main__":
    # Manual smoke test: builds the graph from a saved aggregation result and
    # displays it (the figure is kept open so plt.show() can render it).
    test_data_path = "results/test_aggregated.json"
    if os.path.exists(test_data_path):
        with open(test_data_path, 'r', encoding='utf-8') as f:
            test_data = json.load(f)

        G, pos, fig = construct_graph_network(
            test_data,
            "results/validation.json",
            "results",
            close_figure=False,
        )
        if fig is not None:
            plt.show()