# Spaces: Build error
import json
import traceback
import uuid

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
def _add_detection_node(G, pos, item, node_type, **attrs):
    """Add one symbol/text detection dict as a node of ``G``.

    The node id is ``item['id']``; when that is missing or None a fresh UUID4
    string is used (networkx rejects None as a node). When the item has a
    non-empty ``bbox`` dict (keys xmin/xmax/ymin/ymax), its centre is stored
    in ``pos``. ``attrs`` are extra node attributes (e.g. ``class_name`` or
    ``text``). Returns the node id.
    """
    bbox = item.get('bbox', {})  # dict default: bbox is indexed by key below
    node_id = item.get('id')
    if node_id is None:
        node_id = str(uuid.uuid4())
    if bbox:
        pos[node_id] = (
            (bbox['xmin'] + bbox['xmax']) / 2,
            (bbox['ymin'] + bbox['ymax']) / 2,
        )
    G.add_node(
        node_id,
        type=node_type,
        **attrs,
        bbox=bbox,
        confidence=item.get('confidence', 0.0),
    )
    return node_id


def _fill_missing_positions(G, pos, seed=42):
    """Give every node of ``G`` lacking an entry in ``pos`` a layout position.

    Nodes already in ``pos`` are kept fixed; the remaining ones are placed by
    a seeded (hence deterministic) ``nx.spring_layout``. ``pos`` is updated
    in place and returned. NOTE: networkx may need SciPy to lay out large
    graphs.
    """
    missing = [node for node in G if node not in pos]
    if not missing:
        return pos
    layout = nx.spring_layout(
        G,
        pos=dict(pos) if pos else None,
        fixed=list(pos) if pos else None,
        seed=seed,
    )
    for node in missing:
        pos[node] = tuple(float(coord) for coord in layout[node])
    return pos


def _node_label(data):
    """Return the short plot label ("S:<class>" or "T:<text>") for a node's attributes."""
    if data['type'] == 'symbol':
        return f"S:{data['class_name']}"
    text = data.get('text')
    text = '' if text is None else str(text)
    return f"T:{text[:10]}..." if len(text) > 10 else f"T:{text}"


def _draw_graph(G, pos, ax):
    """Draw the nodes (colored by type), edges and labels of ``G`` onto ``ax``."""
    node_colors = [
        'lightblue' if G.nodes[node]['type'] == 'symbol' else 'lightgreen'
        for node in G
    ]
    nx.draw_networkx_nodes(G, pos, node_color=node_colors, node_size=500, ax=ax)
    nx.draw_networkx_edges(G, pos, edge_color='gray', width=1, ax=ax)
    labels = {node: _node_label(data) for node, data in G.nodes(data=True)}
    nx.draw_networkx_labels(G, pos, labels, font_size=8, ax=ax)
    # NOTE(review): if bbox coordinates are image pixels (y grows downward),
    # the drawing is vertically mirrored; consider ax.invert_yaxis() - confirm.
    ax.set_title("P&ID Network Graph")
    ax.axis('off')


def create_connected_graph(input_data):
    """Create a connected graph from the input data and render it.

    Args:
        input_data: dict that must contain the keys 'symbols', 'texts',
            'lines', 'nodes' and 'edges' (only 'symbols', 'texts' and
            'edges' are read). Symbols/texts are dicts with optional 'id',
            'bbox' (xmin/xmax/ymin/ymax), 'confidence', and 'class'
            (symbols) or 'text' (texts). Edges are dicts with 'source',
            'target' and optional 'type'/'properties'; edges whose
            endpoints are not known node ids are skipped.

    Returns:
        ``(G, pos, fig)``: the ``nx.Graph``, a node -> (x, y) mapping that
        covers every node (bbox centres, with a layout fallback for nodes
        without a bbox), and the matplotlib Figure. On any error the message
        and traceback are printed and ``(None, None, None)`` is returned.
    """
    fig = None
    try:
        # Validate input data structure
        if not isinstance(input_data, dict):
            raise ValueError("Invalid input data format")
        required_keys = ['symbols', 'texts', 'lines', 'nodes', 'edges']
        if not all(key in input_data for key in required_keys):
            raise ValueError(f"Missing required keys in input data. Expected: {required_keys}")

        G = nx.Graph()
        pos = {}
        for symbol in input_data['symbols']:
            _add_detection_node(G, pos, symbol, 'symbol', class_name=symbol.get('class', ''))
        for text in input_data['texts']:
            _add_detection_node(G, pos, text, 'text', text=text.get('text', ''))

        for edge in input_data['edges']:
            source = edge.get('source')
            target = edge.get('target')
            # Graph.__contains__ is False for None/unhashable ids, so no
            # truthiness test is needed (which would wrongly drop id 0).
            if source in G and target in G:
                G.add_edge(
                    source,
                    target,
                    type=edge.get('type', 'connection'),
                    properties=edge.get('properties', {}),
                )

        # Without this, any bbox-less node makes drawing raise KeyError.
        _fill_missing_positions(G, pos)

        fig = plt.figure(figsize=(20, 20))
        _draw_graph(G, pos, fig.add_subplot(111))
        return G, pos, fig
    except Exception as e:
        # Boundary handler: the caller checks the returned figure rather than
        # catching, so report the failure and return Nones.
        print(f"Error in create_connected_graph: {str(e)}")
        traceback.print_exc()
        if fig is not None:
            plt.close(fig)  # don't leak a half-drawn pyplot figure
        return None, None, None
if __name__ == "__main__":
    # Manual smoke test: build the graph from a sample results file and show it.
    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open('results/0_aggregated.json', encoding='utf-8') as f:
        data = json.load(f)
    G, pos, fig = create_connected_graph(data)
    # create_connected_graph returns fig=None on failure.
    if fig is not None:
        plt.show()