Spaces:
Build error
Build error
| import os | |
| import json | |
| import networkx as nx | |
| import matplotlib.pyplot as plt | |
| from pathlib import Path | |
| import logging | |
| import traceback | |
| from storage import StorageFactory | |
| logger = logging.getLogger(__name__) | |
def _node_position(node: dict) -> tuple:
    """Return the (x, y) drawing position of one aggregated-data node.

    Connection points are placed at ``node['coords']``; every other node type
    (symbol or text) is placed at the centre of ``node['bbox']``.
    """
    if node['type'] == 'connection_point':
        coords = node['coords']
        return (coords['x'], coords['y'])
    bbox = node['bbox']
    return (
        (bbox['xmin'] + bbox['xmax']) / 2,
        (bbox['ymin'] + bbox['ymax']) / 2,
    )


def _node_color(node_type) -> str:
    """Map a node type to its fill colour (anything unknown is drawn grey)."""
    return {'symbol': 'lightblue', 'text': 'lightgreen'}.get(node_type, 'lightgray')


def _node_label(attrs: dict) -> str:
    """Build the short on-canvas label for a node from its graph attributes."""
    node_type = attrs.get('type')
    properties = attrs.get('properties') or {}
    if node_type == 'symbol':
        return f"S:{properties.get('class', '')}"
    if node_type == 'text':
        content = attrs.get('content')
        content = '' if content is None else str(content)
        # Keep labels readable on a dense drawing: truncate long text to 10 chars.
        return f"T:{content[:10]}..." if len(content) > 10 else f"T:{content}"
    return f"C:{properties.get('point_type', '')}"


def _build_graph(data: dict):
    """Build an undirected graph and a node-position map from aggregated data."""
    G = nx.Graph()
    pos = {}
    for node in data.get('nodes', []):
        node_id = node['id']
        pos[node_id] = _node_position(node)
        # Store every field of the input node as a node attribute.
        G.add_node(node_id, **node)
    for edge in data.get('edges', []):
        source, target = edge['source'], edge['target']
        # add_edge would silently create attribute-less, position-less nodes for
        # unknown ids, which later breaks colouring, labelling and drawing.
        if source not in G or target not in G:
            logger.warning(
                "Skipping edge %s -> %s: endpoint not among aggregated nodes",
                source, target,
            )
            continue
        G.add_edge(source, target, **(edge.get('properties') or {}))
    return G, pos


def _render_graph(G, pos):
    """Draw G at the given positions on a new figure and return that figure."""
    fig, ax = plt.subplots(figsize=(20, 20))
    try:
        nodes = list(G.nodes())
        node_colors = [_node_color(G.nodes[n].get('type')) for n in nodes]
        nx.draw_networkx_nodes(G, pos, nodelist=nodes, node_color=node_colors,
                               node_size=500, ax=ax)
        nx.draw_networkx_edges(G, pos, edge_color='gray', width=1, ax=ax)
        labels = {n: _node_label(G.nodes[n]) for n in nodes}
        nx.draw_networkx_labels(G, pos, labels, font_size=8, ax=ax)
        ax.set_title("P&ID Knowledge Graph")
        ax.axis('off')
    except Exception:
        # Do not leave a half-drawn figure registered with pyplot.
        plt.close(fig)
        raise
    return fig


def construct_graph_network(data: dict, validation_results_path: str, results_dir: str, storage=None):
    """Construct network graph from aggregated detection data.

    Builds an undirected networkx graph from ``data['nodes']`` and
    ``data['edges']``. It renders the graph to
    ``<results_dir>/<stem>_graph.png`` and writes node-link JSON to
    ``<results_dir>/<stem>_graph_data.json``. ``<stem>`` is the file stem of
    ``data['image_path']``, or ``graph`` when that key is absent or empty.

    Args:
        data: Aggregated detections. Each node needs ``id`` and ``type``.
            Connection points also need ``coords`` (``x``/``y``); other
            types need ``bbox`` (``xmin``/``xmax``/``ymin``/``ymax``). Edges
            need ``source`` and ``target`` and may carry ``properties``.
            Edges whose endpoints are not listed as nodes are skipped with a
            warning.
        validation_results_path: Accepted for interface compatibility; not
            used by this function.
        results_dir: Directory for the PNG and JSON outputs. It is created
            if missing.
        storage: Optional storage backend. When None, one is obtained from
            ``StorageFactory.get_storage()``. NOTE(review): the backend is
            currently not used for writing; outputs go to the local
            filesystem.

    Returns:
        ``(G, pos, fig)`` on success: the graph, a ``{node_id: (x, y)}``
        position map, and the rendered matplotlib Figure. The figure has
        already been closed in pyplot, so ``plt.show()`` will not display
        it; use ``fig.savefig`` or embed the Figure object directly.
        ``(None, None, None)`` if any step fails; the error is logged with
        its traceback.
    """
    try:
        # Use provided storage or get a new one (kept for existing side effects).
        if storage is None:
            storage = StorageFactory.get_storage()

        G, pos = _build_graph(data)
        fig = _render_graph(G, pos)

        # os.makedirs('') raises, while an empty results_dir means "current
        # directory" for os.path.join, so only create a non-empty path.
        if results_dir:
            os.makedirs(results_dir, exist_ok=True)
        stem = Path(data.get('image_path') or 'graph').stem

        graph_image_path = os.path.join(results_dir, f"{stem}_graph.png")
        try:
            fig.savefig(graph_image_path, bbox_inches='tight', dpi=300)
        finally:
            # Release pyplot's reference so repeated calls do not accumulate
            # open figures; the Figure object itself stays usable for callers.
            plt.close(fig)

        # Save graph data as JSON for future use.
        graph_json_path = os.path.join(results_dir, f"{stem}_graph_data.json")
        with open(graph_json_path, 'w', encoding='utf-8') as f:
            json.dump(nx.node_link_data(G), f, indent=2)

        return G, pos, fig
    except Exception as e:
        logger.exception("Error in construct_graph_network: %s", e)
        return None, None, None
if __name__ == "__main__":
    # Manual smoke test: build the graph for a sample aggregated-detection file
    # and display the PNG that construct_graph_network wrote to disk.
    logging.basicConfig(level=logging.INFO)
    test_data_path = "results/test_aggregated.json"
    if os.path.exists(test_data_path):
        with open(test_data_path, 'r', encoding='utf-8') as f:
            test_data = json.load(f)
        G, pos, fig = construct_graph_network(
            test_data,
            "results/validation.json",
            "results"
        )
        if G is not None:
            # The returned figure is not a usable pyplot "current figure" (it is
            # closed or empty after saving), so show the saved image instead.
            stem = Path(test_data.get('image_path') or 'graph').stem
            graph_image_path = os.path.join("results", f"{stem}_graph.png")
            if os.path.exists(graph_image_path):
                plt.close('all')
                plt.figure(figsize=(12, 12))
                plt.imshow(plt.imread(graph_image_path))
                plt.axis('off')
                plt.show()
            else:
                logger.warning("Rendered graph not found: %s", graph_image_path)
    else:
        logger.warning("Test data not found: %s", test_data_path)