import argparse
import json
import math
import os
import uuid
from pathlib import Path
from pprint import pprint
from typing import Optional, Tuple

import matplotlib.pyplot as plt
import networkx as nx
from tqdm import tqdm

_AGGREGATED_SUFFIX = '_aggregated'


def _strip_aggregated_suffix(name: str) -> str:
    """Return *name* without a trailing '_aggregated' suffix, if present."""
    if name.endswith(_AGGREGATED_SUFFIX):
        return name[:-len(_AGGREGATED_SUFFIX)]
    return name


def _parse_node(node) -> Optional[Tuple[str, float, float]]:
    """Return ``(node_id, x, y)`` for a usable node record, else ``None``.

    A record is usable when it is a dict with a non-empty, non-null ``id``
    and with ``x`` and ``y`` entries that convert to finite floats. A
    coordinate of exactly 0 is accepted; only missing, non-numeric, NaN or
    infinite coordinates are rejected.
    """
    if not isinstance(node, dict):
        return None
    raw_id = node.get('id')
    if raw_id is None:
        return None
    node_id = str(raw_id)
    if not node_id:
        return None
    try:
        x = float(node['x'])
        y = float(node['y'])
    except (KeyError, TypeError, ValueError):
        return None
    if not (math.isfinite(x) and math.isfinite(y)):
        return None
    return node_id, x, y


def _build_graph(data: dict):
    """Build the graph from parsed JSON; return ``(graph, pos, stats)``.

    ``pos`` maps node id -> (x, y). ``stats`` holds the counts of valid and
    invalid node and edge records encountered in the input.
    """
    graph = nx.Graph()
    pos = {}
    valid_nodes = 0
    invalid_nodes = 0

    print("\nValidating and adding nodes...")
    # `or []` guards against an explicit `"nodes": null` in the JSON.
    for node in tqdm(data.get('nodes') or [], desc="Nodes"):
        parsed = _parse_node(node)
        if parsed is None:
            invalid_nodes += 1
            continue
        node_id, x, y = parsed
        pos[node_id] = (x, y)
        graph.add_node(
            node_id,
            type=node.get('type', ''),
            label=node.get('label', ''),
            x=x,
            y=y,
        )
        valid_nodes += 1

    print(f"\nFound {valid_nodes} valid nodes and {invalid_nodes} invalid nodes")

    print("\nAdding valid edges...")
    valid_edges = 0
    invalid_edges = 0
    for edge in tqdm(data.get('edges') or [], desc="Edges"):
        if not isinstance(edge, dict):
            invalid_edges += 1
            continue
        start_raw = edge.get('start_point')
        end_raw = edge.get('end_point')
        # A null endpoint must not be stringified into the id "None".
        start_id = '' if start_raw is None else str(start_raw)
        end_id = '' if end_raw is None else str(end_raw)
        # Only keep edges whose endpoints are both valid, positioned nodes.
        if start_id in pos and end_id in pos:
            graph.add_edge(
                start_id,
                end_id,
                type=edge.get('type', ''),
                weight=edge.get('weight', 1.0),
            )
            valid_edges += 1
        else:
            invalid_edges += 1

    print(f"\nFound {valid_edges} valid edges and {invalid_edges} invalid edges")

    stats = {
        'valid_nodes': valid_nodes,
        'invalid_nodes': invalid_nodes,
        'valid_edges': valid_edges,
        'invalid_edges': invalid_edges,
    }
    return graph, pos, stats


def _render_graph(graph, pos, output_dir: str, base_name: str,
                  save_plot: bool, show: bool) -> Optional[str]:
    """Draw the graph; optionally save and/or show it.

    Returns the saved image path, or ``None`` when nothing was saved. The
    figure is always closed afterwards, even if drawing or saving fails.
    """
    image_path = None
    print("\nGenerating visualization...")
    fig = plt.figure(figsize=(20, 20))
    try:
        print("Drawing graph elements...")
        with tqdm(total=3, desc="Drawing") as pbar:
            nx.draw_networkx_nodes(graph, pos, node_color='lightblue', node_size=100)
            pbar.update(1)

            nx.draw_networkx_edges(graph, pos)
            pbar.update(1)

            if save_plot:
                if output_dir:
                    os.makedirs(output_dir, exist_ok=True)
                image_path = os.path.join(
                    output_dir, f"{base_name}_graph_visualization.png"
                )
                fig.savefig(image_path, bbox_inches='tight', dpi=300)
            pbar.update(1)

        if image_path is not None:
            print(f"\nVisualization saved to: {image_path}")
        if show:
            # Must happen before the figure is closed, otherwise nothing
            # is displayed.
            plt.show()
    finally:
        plt.close(fig)
    return image_path


def create_graph_visualization(json_path: str, output_dir: str, base_name: str,
                               save_plot: bool = True, *,
                               show: bool = False) -> dict:
    """Create a graph visualization using the node coordinates in a JSON file.

    The JSON document is read for a ``nodes`` list (records with ``id``,
    ``x``, ``y`` and optional ``type`` / ``label``) and an ``edges`` list
    (records with ``start_point``, ``end_point`` and optional ``type`` /
    ``weight``). Nodes without an id or without finite coordinates are
    skipped, as are edges that reference a skipped or unknown node.

    Args:
        json_path: Path of the JSON file to load.
        output_dir: Directory that receives the PNG (created if missing).
        base_name: Prefix of the output file name; a trailing
            ``_aggregated`` is removed.
        save_plot: When true, write ``<base_name>_graph_visualization.png``.
        show: When true, display the figure via ``plt.show()`` before it is
            closed.

    Returns:
        On success: ``{'success': True, 'graph': <nx.Graph>, 'stats': {...}}``
        plus ``'image_path'`` when a plot was saved. On any failure:
        ``{'success': False, 'error': <message>}``. Exceptions are not
        propagated to the caller.
    """
    try:
        base_name = _strip_aggregated_suffix(base_name)

        print("\nLoading JSON data...")
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        graph, pos, stats = _build_graph(data)

        result = {
            'success': True,
            'graph': graph,
            'stats': stats,
        }

        if save_plot or show:
            image_path = _render_graph(
                graph, pos, output_dir, base_name, save_plot, show
            )
            if image_path is not None:
                result['image_path'] = image_path

        return result

    except Exception as e:
        # Boundary handler: callers rely on the success/error dict rather
        # than on exceptions.
        print(f"\nError creating graph: {str(e)}")
        return {
            'success': False,
            'error': str(e)
        }


def main() -> None:
    """Command-line entry point: render one aggregated JSON file to a PNG."""
    parser = argparse.ArgumentParser(
        description='Create and visualize graph from aggregated JSON'
    )
    parser.add_argument('--json_path', type=str,
                        default="results/002_page_1_aggregated.json",
                        help='Path to aggregated JSON file')
    parser.add_argument('--output_dir', type=str,
                        default="results",
                        help='Directory to save outputs')
    parser.add_argument('--show', action='store_true',
                        help='Show the plot interactively')
    args = parser.parse_args()

    if not os.path.exists(args.json_path):
        print(f"Error: Could not find input file {args.json_path}")
        raise SystemExit(1)

    os.makedirs(args.output_dir, exist_ok=True)

    base_name = _strip_aggregated_suffix(Path(args.json_path).stem)

    print("\nProcessing:")
    print(f"Input: {args.json_path}")
    print(f"Output: {args.output_dir}/{base_name}_graph_visualization.png")

    try:
        result = create_graph_visualization(
            json_path=args.json_path,
            output_dir=args.output_dir,
            base_name=base_name,
            save_plot=True,
            show=args.show,
        )
    except Exception as e:
        print(f"\nError during visualization: {str(e)}")
        raise

    if result['success']:
        print(f"\nSuccess! Graph visualization saved to: {result['image_path']}")
    else:
        print(f"\nError: {result['error']}")
        # Report the failure to the calling shell instead of exiting 0.
        raise SystemExit(1)


if __name__ == "__main__":
    main()