Spaces:
Build error
Build error
File size: 3,871 Bytes
9847531 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 | import json
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
import traceback
import uuid
def create_connected_graph(input_data):
"""Create a connected graph from the input data"""
try:
# Validate input data structure
if not isinstance(input_data, dict):
raise ValueError("Invalid input data format")
# Check for required keys in new format
required_keys = ['symbols', 'texts', 'lines', 'nodes', 'edges']
if not all(key in input_data for key in required_keys):
raise ValueError(f"Missing required keys in input data. Expected: {required_keys}")
# Create graph
G = nx.Graph()
# Track positions for layout
pos = {}
# Add symbol nodes
for symbol in input_data['symbols']:
bbox = symbol.get('bbox', [])
symbol_id = symbol.get('id', str(uuid.uuid4()))
if bbox:
# Calculate center position
center_x = (bbox['xmin'] + bbox['xmax']) / 2
center_y = (bbox['ymin'] + bbox['ymax']) / 2
pos[symbol_id] = (center_x, center_y)
G.add_node(
symbol_id,
type='symbol',
class_name=symbol.get('class', ''),
bbox=bbox,
confidence=symbol.get('confidence', 0.0)
)
# Add text nodes
for text in input_data['texts']:
bbox = text.get('bbox', [])
text_id = text.get('id', str(uuid.uuid4()))
if bbox:
center_x = (bbox['xmin'] + bbox['xmax']) / 2
center_y = (bbox['ymin'] + bbox['ymax']) / 2
pos[text_id] = (center_x, center_y)
G.add_node(
text_id,
type='text',
text=text.get('text', ''),
bbox=bbox,
confidence=text.get('confidence', 0.0)
)
# Add edges from the edges list
for edge in input_data['edges']:
source = edge.get('source')
target = edge.get('target')
if source and target and source in G and target in G:
G.add_edge(
source,
target,
type=edge.get('type', 'connection'),
properties=edge.get('properties', {})
)
# Create visualization
plt.figure(figsize=(20, 20))
# Draw nodes with fixed positions
nx.draw_networkx_nodes(G, pos,
node_color=['lightblue' if G.nodes[node]['type'] == 'symbol' else 'lightgreen' for node in G.nodes()],
node_size=500)
# Draw edges
nx.draw_networkx_edges(G, pos, edge_color='gray', width=1)
# Add labels
labels = {}
for node in G.nodes():
node_data = G.nodes[node]
if node_data['type'] == 'symbol':
labels[node] = f"S:{node_data['class_name']}"
else:
text = node_data.get('text', '')
labels[node] = f"T:{text[:10]}..." if len(text) > 10 else f"T:{text}"
nx.draw_networkx_labels(G, pos, labels, font_size=8)
plt.title("P&ID Network Graph")
plt.axis('off')
return G, pos, plt.gcf()
except Exception as e:
print(f"Error in create_connected_graph: {str(e)}")
traceback.print_exc()
return None, None, None
if __name__ == "__main__":
    # Manual smoke test: build and display the graph for a sample aggregated result file.
    # Read as UTF-8 explicitly so OCR'd text with non-ASCII characters decodes the
    # same regardless of the platform's locale encoding.
    with open('results/0_aggregated.json', encoding='utf-8') as f:
        data = json.load(f)
    G, pos, fig = create_connected_graph(data)
    # create_connected_graph signals failure with (None, None, None).
    if fig is not None:
        plt.show()