Spaces:
Build error
Build error
| import json | |
| import networkx as nx | |
| import matplotlib.pyplot as plt | |
| import os | |
| from pprint import pprint | |
| import uuid | |
# Line endpoints outside [0, _MAX_COORD] on either axis are treated as
# invalid and the line is skipped.
_MAX_COORD = 10000


def _index_edges_by_id(edges):
    """Map each edge id to the first edge that carries it."""
    index = {}
    for edge in edges:
        try:
            # setdefault keeps the first occurrence, like a linear scan would.
            index.setdefault(edge['id'], edge)
        except (KeyError, TypeError):
            # No id (or an unhashable one): this edge can never match a line.
            continue
    return index


def _resolve_line_endpoints(line, edge_data):
    """Return ((sx, sy), (ex, ey), start_id, end_id) for one line record."""
    properties = (edge_data or {}).get('properties') or {}
    if 'connection_points' in properties:
        # Edge connection points are preferred over the raw line points.
        conn_points = properties['connection_points']
        start_xy = (int(conn_points['start']['x']), int(conn_points['start']['y']))
        end_xy = (int(conn_points['end']['x']), int(conn_points['end']['y']))
        return start_xy, end_xy, str(uuid.uuid4()), str(uuid.uuid4())

    # Fallback to the line's own end points.
    start_point = line['start_point']
    end_point = line['end_point']
    start_xy = (int(start_point['x']), int(start_point['y']))
    end_xy = (int(end_point['x']), int(end_point['y']))
    start_id = start_point.get('id', str(uuid.uuid4()))
    end_id = end_point.get('id', str(uuid.uuid4()))
    return start_xy, end_xy, start_id, end_id


def _junction_type(line, endpoint_key):
    """Return the 'type' of a line end point, or 'unknown' if unavailable."""
    point = line.get(endpoint_key)
    if isinstance(point, dict):
        return point.get('type', 'unknown')
    return 'unknown'


def _ensure_junction(G, pos, junction_map, xy, candidate_id, junction_type):
    """Return the junction node id at xy, creating the node if needed."""
    key = f"{xy[0]}_{xy[1]}"
    if key not in junction_map:
        G.add_node(candidate_id,
                   type='junction',
                   junction_type=junction_type,
                   coords={'x': xy[0], 'y': xy[1]})
        pos[candidate_id] = xy
        # Register the key only after the node exists, so a failure above
        # can never leave the map pointing at a node missing from G/pos.
        junction_map[key] = candidate_id
    return junction_map[key]


def _node_position(node):
    """Return (x, y) for a symbol/text node from coords, center or bbox."""
    coords = node.get('coords', {})
    if coords:
        return coords['x'], coords['y']
    if 'center' in node:
        return node['center']['x'], node['center']['y']
    bbox = node['bbox']
    return ((bbox['xmin'] + bbox['xmax']) / 2,
            (bbox['ymin'] + bbox['ymax']) / 2)


def _scale_positions(pos):
    """Normalise positions into [0, 1] x [0, 1] with the Y axis flipped."""
    x_vals = [p[0] for p in pos.values()]
    y_vals = [p[1] for p in pos.values()]
    x_min, x_max = min(x_vals), max(x_vals)
    y_min, y_max = min(y_vals), max(y_vals)

    scaled_pos = {}
    for node_id, (x, y) in pos.items():
        # A degenerate axis (all values equal) is centred at 0.5.
        scaled_x = (x - x_min) / (x_max - x_min) if x_max > x_min else 0.5
        scaled_y = 1 - ((y - y_min) / (y_max - y_min) if y_max > y_min else 0.5)
        scaled_pos[node_id] = (scaled_x, scaled_y)
    return scaled_pos


def _node_appearance(node_data):
    """Return (color, size, label) used to draw one node."""
    node_type = node_data.get('type')
    if node_type == 'symbol':
        symbol_class = (node_data.get('properties') or {}).get('class', 'unknown')
        return 'lightblue', 1000, f"S:{symbol_class}"
    if node_type == 'text':
        content = str(node_data.get('content') or '')
        label = f"T:{content[:10]}..." if len(content) > 10 else f"T:{content}"
        return 'lightgreen', 800, label
    # Junctions: tiny pure-red dots without a label.
    return '#ff0000', 3, ""


def _draw_graph(G, scaled_pos, node_colors, node_sizes, labels, output_path):
    """Render the graph with matplotlib and save it to output_path."""
    fig = plt.figure(figsize=(20, 20))
    try:
        # Draw edges one by one so each can carry its own style.
        edge_styles = []
        for u, v, attrs in G.edges(data=True):
            if attrs.get('type') != 'line':
                continue
            style = attrs.get('style') or {}
            color = style.get('color', '#000000')
            width = float(style.get('stroke_width', 0.5))
            line_style = '--' if style.get('connection_type') == 'dashed' else '-'
            nx.draw_networkx_edges(G, scaled_pos,
                                   edgelist=[(u, v)],
                                   edge_color=color,
                                   width=width,
                                   alpha=0.7,
                                   style=line_style)
            # Track unique styles for the legend (a list, because colours
            # coming from JSON are not guaranteed to be hashable).
            edge_style = (line_style, color, width)
            if edge_style not in edge_styles:
                edge_styles.append(edge_style)

        nx.draw_networkx_nodes(G, scaled_pos,
                               node_color=node_colors,
                               node_size=node_sizes,
                               alpha=1.0)

        # Only symbols and texts carry a non-empty label.
        visible_labels = {k: v for k, v in labels.items() if v}
        nx.draw_networkx_labels(G, scaled_pos, visible_labels,
                                font_size=8,
                                font_weight='bold')

        # Legend entries are proxy artists that are never added to the axes.
        # Building them with plt.scatter would paint stray markers at the
        # plot origin. markersize is a diameter, roughly sqrt of scatter's s.
        legend_elements = [
            plt.Line2D([], [], linestyle='None', marker='o',
                       color='lightblue', markersize=14, label='Symbol'),
            plt.Line2D([], [], linestyle='None', marker='o',
                       color='lightgreen', markersize=14, label='Text'),
            plt.Line2D([], [], linestyle='None', marker='o',
                       color='red', markersize=4.5, label='Junction'),
        ]
        for line_style, color, width in edge_styles:
            legend_elements.append(
                plt.Line2D([0], [0], color=color, linestyle=line_style,
                           linewidth=width, label=f'Line ({line_style})')
            )

        # Single-column legend placed to the right of the axes.
        plt.legend(handles=legend_elements,
                   loc='center left',
                   bbox_to_anchor=(1, 0.5),
                   ncol=1,
                   fontsize=12,
                   title="Graph Elements")
        plt.title("P&ID Knowledge Graph Visualization", pad=20, fontsize=16)
        plt.axis('on')
        plt.grid(True)

        # bbox_inches='tight' keeps the legend outside the axes in the image.
        plt.savefig(output_path, bbox_inches='tight', dpi=300,
                    facecolor='white', edgecolor='none')
    finally:
        # Release the figure even if drawing or saving failed.
        plt.close(fig)


def create_graph_visualization(json_path, save_plot=True):
    """Create and visualize a graph from the aggregated JSON data.

    Lines become edges between junction nodes, which are de-duplicated by
    their integer coordinates. Nodes of type 'symbol' or 'text' are added
    with all of their JSON fields as node attributes. Records that cannot
    be processed are reported on stdout and skipped.

    Args:
        json_path: Path to the aggregated JSON file. The keys 'lines',
            'edges' and 'nodes' are read when present.
        save_plot: When true, render the graph and write
            "graph_visualization.png" next to json_path.

    Returns:
        A tuple (G, scaled_pos): the networkx graph, and a dict mapping
        each node id to its (x, y) position normalised to [0, 1] with the
        Y axis flipped.
    """
    # JSON is UTF-8 by specification; do not depend on the locale default.
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    print("\nData Structure:")
    print(f"Keys in data: {data.keys()}")
    for key, value in data.items():
        if isinstance(value, list):
            print(f"Number of {key}: {len(value)}")
            if value:
                print(f"Sample {key}:", value[0])

    G = nx.Graph()
    pos = {}
    # Maps "x_y" coordinate keys to junction node ids to avoid duplicates.
    junction_map = {}

    print("\nProcessing Lines and Junctions:")
    # Index edges once instead of rescanning the edge list for every line.
    edge_index = _index_edges_by_id(data.get('edges') or [])
    for line in data.get('lines') or []:
        # Best effort by design: one malformed line must not abort the run.
        try:
            edge_data = edge_index.get(line['id']) if 'id' in line else None
            start_xy, end_xy, start_id, end_id = _resolve_line_endpoints(
                line, edge_data)
            (start_x, start_y), (end_x, end_y) = start_xy, end_xy

            # Skip invalid coordinates
            if not all(0 <= value <= _MAX_COORD
                       for value in (start_x, start_y, end_x, end_y)):
                print(f"Skipping line with invalid coordinates: ({start_x}, {start_y}) -> ({end_x}, {end_y})")
                continue

            # Resolve every value that may raise before touching the graph.
            start_type = _junction_type(line, 'start_point')
            end_type = _junction_type(line, 'end_point')
            line_style = line.get('style', {})

            start_node = _ensure_junction(G, pos, junction_map,
                                          start_xy, start_id, start_type)
            end_node = _ensure_junction(G, pos, junction_map,
                                        end_xy, end_id, end_type)
            G.add_edge(start_node, end_node, type='line', style=line_style)
        except Exception as e:
            print(f"Error processing line: {str(e)}")
            continue

    print("\nProcessing Symbols and Texts:")
    for node in data.get('nodes') or []:
        # Same best-effort policy as for lines: report and skip bad records.
        try:
            if node['type'] not in ('symbol', 'text'):
                continue
            node_id = node['id']
            x, y = _node_position(node)
            G.add_node(node_id, **node)
        except (KeyError, TypeError) as e:
            print(f"Error processing node: {str(e)}")
            continue
        pos[node_id] = (x, y)
        print(f"Added {node['type']} at ({x}, {y})")

    # Add a default node if the graph is empty so scaling has data.
    if not pos:
        default_id = str(uuid.uuid4())
        G.add_node(default_id, type='junction', coords={'x': 0, 'y': 0})
        pos[default_id] = (0, 0)

    scaled_pos = _scale_positions(pos)

    if save_plot:
        node_colors = []
        node_sizes = []
        labels = {}
        for node_id in G.nodes():
            color, size, label = _node_appearance(G.nodes[node_id])
            node_colors.append(color)
            node_sizes.append(size)
            labels[node_id] = label

        output_path = os.path.join(os.path.dirname(json_path),
                                   "graph_visualization.png")
        _draw_graph(G, scaled_pos, node_colors, node_sizes, labels, output_path)

    return G, scaled_pos
if __name__ == "__main__":
    # Manual smoke test: build the visualization for one sample results file.
    json_path = "results/001_page_1_text_aggregated.json"
    if not os.path.exists(json_path):
        # Report the missing input instead of raising from open().
        print(f"Error: Could not find {json_path}")
    else:
        G, scaled_pos = create_graph_visualization(json_path)
        plt.show()