Spaces:
Runtime error
Runtime error
File size: 3,887 Bytes
910e0d4 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 | import os
import json
import networkx as nx
import matplotlib.pyplot as plt
from pathlib import Path
import logging
import traceback
from storage import StorageFactory
# Module-level logger; construct_graph_network reports its failures through it.
logger = logging.getLogger(__name__)
# Fill colour per node type; any other type is drawn like a connection point.
_NODE_COLORS = {'symbol': 'lightblue', 'text': 'lightgreen'}
_DEFAULT_NODE_COLOR = 'lightgray'


def _node_position(node: dict) -> tuple:
    """Return the (x, y) drawing position of one aggregated node.

    Connection points use their explicit ``coords``; symbols and text use the
    centre of their ``bbox``.
    """
    if node['type'] == 'connection_point':
        return (node['coords']['x'], node['coords']['y'])
    bbox = node['bbox']
    return (
        (bbox['xmin'] + bbox['xmax']) / 2,
        (bbox['ymin'] + bbox['ymax']) / 2,
    )


def _node_label(node_data: dict) -> str:
    """Return the on-plot label: ``S:<class>``, ``T:<text>`` or ``C:<point_type>``."""
    node_type = node_data.get('type')
    # `or {}` also tolerates an explicit null "properties" value.
    properties = node_data.get('properties') or {}
    if node_type == 'symbol':
        return f"S:{properties.get('class', '')}"
    if node_type == 'text':
        content = node_data.get('content') or ''
        # Truncate long text to 10 characters to keep the plot legible.
        return f"T:{content[:10]}..." if len(content) > 10 else f"T:{content}"
    return f"C:{properties.get('point_type', '')}"


def _render_graph(G, pos: dict, image_path: str):
    """Draw G at the given positions, save it as a PNG at image_path, return the Figure."""
    fig, ax = plt.subplots(figsize=(20, 20))
    try:
        node_colors = [
            _NODE_COLORS.get(G.nodes[n].get('type'), _DEFAULT_NODE_COLOR)
            for n in G.nodes()
        ]
        nx.draw_networkx_nodes(G, pos, node_color=node_colors, node_size=500, ax=ax)
        nx.draw_networkx_edges(G, pos, edge_color='gray', width=1, ax=ax)
        labels = {n: _node_label(G.nodes[n]) for n in G.nodes()}
        nx.draw_networkx_labels(G, pos, labels, font_size=8, ax=ax)
        ax.set_title("P&ID Knowledge Graph")
        ax.axis('off')
        fig.savefig(image_path, bbox_inches='tight', dpi=300)
    finally:
        # Always drop pyplot's reference (also on failure) so repeated calls do
        # not accumulate open figures. The Figure object itself is returned and
        # can still be saved or embedded by the caller.
        plt.close(fig)
    return fig


def construct_graph_network(data: dict, validation_results_path: str, results_dir: str, storage=None):
    """Construct, draw and save a network graph from aggregated detection data.

    Args:
        data: Aggregated detections. Reads ``data['nodes']`` (each with ``id``,
            ``type``, and ``coords`` for connection points or ``bbox``
            otherwise), ``data['edges']`` (``source``, ``target``, optional
            ``properties``) and optional ``data['image_path']``, whose file
            stem names the outputs (``'graph'`` when absent).
        validation_results_path: Currently unused; kept for interface
            compatibility.
        results_dir: Directory receiving ``<stem>_graph.png`` and
            ``<stem>_graph_data.json``; created if it does not exist.
        storage: Storage backend; when None one is obtained from
            ``StorageFactory.get_storage()``. NOTE(review): it is not used
            further here — outputs are written with ``savefig``/``open``.

    Returns:
        ``(G, pos, fig)``: the networkx graph, a node-id -> (x, y) position
        map, and the rendered matplotlib Figure (already closed in pyplot).
        On any error the exception is logged and ``(None, None, None)`` is
        returned.
    """
    try:
        if storage is None:
            storage = StorageFactory.get_storage()

        G = nx.Graph()
        pos = {}
        for node in data.get('nodes', []):
            node_id = node['id']
            pos[node_id] = _node_position(node)
            G.add_node(node_id, **node)

        for edge in data.get('edges', []):
            source, target = edge['source'], edge['target']
            # add_edge would silently create attribute-less, position-less
            # nodes that later break colouring/drawing; skip such edges instead.
            if source not in G or target not in G:
                logger.warning(
                    "Skipping edge %s -> %s: endpoint not among nodes", source, target
                )
                continue
            G.add_edge(source, target, **(edge.get('properties') or {}))

        os.makedirs(results_dir, exist_ok=True)
        stem = Path(data.get('image_path') or 'graph').stem

        graph_image_path = os.path.join(results_dir, f"{stem}_graph.png")
        fig = _render_graph(G, pos, graph_image_path)

        # Save graph data as JSON for future use.
        graph_json_path = os.path.join(results_dir, f"{stem}_graph_data.json")
        with open(graph_json_path, 'w', encoding='utf-8') as f:
            json.dump(nx.node_link_data(G), f, indent=2)

        return G, pos, fig
    except Exception as e:
        logger.exception("Error in construct_graph_network: %s", e)
        return None, None, None
if __name__ == "__main__":
    # Manual smoke test: build and save the graph for a previously aggregated
    # result, then display the rendered PNG.
    test_data_path = "results/test_aggregated.json"
    if os.path.exists(test_data_path):
        with open(test_data_path, 'r', encoding='utf-8') as f:
            test_data = json.load(f)
        G, pos, fig = construct_graph_network(
            test_data,
            "results/validation.json",
            "results"
        )
        if fig is not None:
            # construct_graph_network closes its figure after saving it, so show
            # the PNG it wrote rather than whatever pyplot's current figure is.
            stem = Path(test_data.get('image_path') or 'graph').stem
            graph_image_path = os.path.join("results", f"{stem}_graph.png")
            plt.figure(figsize=(10, 10))
            plt.imshow(plt.imread(graph_image_path))
            plt.axis('off')
            plt.show()
    else:
        print(f"Test data not found: {test_data_path}")