Spaces:
Runtime error
Runtime error
File size: 6,468 Bytes
910e0d4 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 | import json
import networkx as nx
import matplotlib.pyplot as plt
import os
from pprint import pprint
import uuid
import argparse
from pathlib import Path
from tqdm import tqdm
def _parse_node_position(node):
    """Return ``(node_id, x, y)`` for a usable node, or ``None`` if it is unusable.

    A node is usable when it is a dict with a non-null, non-empty ``id`` and
    finite numeric ``x``/``y`` values. Zero is accepted as a valid coordinate;
    NaN/inf and missing/null coordinates are rejected.
    """
    import math

    if not isinstance(node, dict):
        return None
    raw_id = node.get('id')
    raw_x = node.get('x')
    raw_y = node.get('y')
    # Missing/null fields are invalid; a null id must not become the string "None".
    if raw_id is None or raw_x is None or raw_y is None:
        return None
    node_id = str(raw_id)
    if not node_id:
        return None
    try:
        x = float(raw_x)
        y = float(raw_y)
    except (TypeError, ValueError):
        return None
    if not (math.isfinite(x) and math.isfinite(y)):
        return None
    return node_id, x, y


def _draw_graph(G, pos, image_path):
    """Draw G's nodes and edges at ``pos`` and save the figure as a PNG at ``image_path``."""
    fig = plt.figure(figsize=(20, 20))
    try:
        print("Drawing graph elements...")
        with tqdm(total=3, desc="Drawing") as pbar:
            nx.draw_networkx_nodes(G, pos,
                                   node_color='lightblue',
                                   node_size=100)
            pbar.update(1)
            nx.draw_networkx_edges(G, pos)
            pbar.update(1)
            fig.savefig(image_path, bbox_inches='tight', dpi=300)
            pbar.update(1)
    finally:
        # Release the figure even when drawing/saving fails so repeated calls
        # do not accumulate open matplotlib figures.
        plt.close(fig)


def create_graph_visualization(json_path: str, output_dir: str, base_name: str, save_plot: bool = True) -> dict:
    """Create graph visualization using actual coordinates from bboxes.

    Reads a JSON document with top-level ``nodes`` and ``edges`` lists, builds
    an undirected ``networkx.Graph`` from the nodes that have a usable id and
    coordinates, and adds only the edges whose ``start_point`` and
    ``end_point`` both refer to such nodes.

    Args:
        json_path: Path to the aggregated JSON file.
        output_dir: Directory for the PNG; created if it does not exist.
        base_name: Output file stem. A trailing ``_aggregated`` is stripped.
        save_plot: When True, draw the graph and save it as
            ``<output_dir>/<base_name>_graph_visualization.png``.

    Returns:
        On success, ``{'success': True, 'graph': G, 'stats': {...}}``, plus an
        ``'image_path'`` key when ``save_plot`` is True. ``stats`` holds the
        counts of valid/invalid nodes and edges. On any error, returns
        ``{'success': False, 'error': <message>}`` instead of raising.
    """
    try:
        # Remove '_aggregated' suffix if present
        if base_name.endswith('_aggregated'):
            base_name = base_name[:-len('_aggregated')]

        print("\nLoading JSON data...")
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        G = nx.Graph()
        pos = {}
        valid_nodes = []  # (node_id, x, y, raw node dict)
        invalid_nodes = []

        # First pass - collect valid nodes and their positions
        print("\nValidating nodes...")
        for node in tqdm(data.get('nodes', []), desc="Validating"):
            parsed = _parse_node_position(node)
            if parsed is None:
                invalid_nodes.append(node)
                continue
            node_id, x, y = parsed
            valid_nodes.append((node_id, x, y, node))
            pos[node_id] = (x, y)
        print(f"\nFound {len(valid_nodes)} valid nodes and {len(invalid_nodes)} invalid nodes")

        print("\nAdding valid nodes...")
        for node_id, x, y, node in tqdm(valid_nodes, desc="Nodes"):
            G.add_node(node_id,
                       type=node.get('type', ''),
                       label=node.get('label', ''),
                       x=x,
                       y=y)

        # Add edges only between nodes that passed validation
        print("\nAdding valid edges...")
        valid_edges = []
        invalid_edges = []
        for edge in tqdm(data.get('edges', []), desc="Edges"):
            if not isinstance(edge, dict):
                invalid_edges.append(edge)
                continue
            start_id = str(edge.get('start_point', ''))
            end_id = str(edge.get('end_point', ''))
            if start_id in pos and end_id in pos:
                valid_edges.append(edge)
                G.add_edge(start_id, end_id,
                           type=edge.get('type', ''),
                           weight=edge.get('weight', 1.0))
            else:
                invalid_edges.append(edge)
        print(f"\nFound {len(valid_edges)} valid edges and {len(invalid_edges)} invalid edges")

        stats = {
            'valid_nodes': len(valid_nodes),
            'invalid_nodes': len(invalid_nodes),
            'valid_edges': len(valid_edges),
            'invalid_edges': len(invalid_edges),
        }

        if not save_plot:
            return {'success': True, 'graph': G, 'stats': stats}

        print("\nGenerating visualization...")
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        image_path = os.path.join(output_dir, f"{base_name}_graph_visualization.png")
        _draw_graph(G, pos, image_path)
        print(f"\nVisualization saved to: {image_path}")
        return {
            'success': True,
            'image_path': image_path,
            'graph': G,
            'stats': stats,
        }
    except Exception as e:
        # Function-boundary handler: callers check result['success'].
        print(f"\nError creating graph: {str(e)}")
        return {
            'success': False,
            'error': str(e)
        }
if __name__ == "__main__":
    # Standalone entry point: build and save the graph visualization for one
    # aggregated JSON file, optionally displaying the saved image.
    parser = argparse.ArgumentParser(description='Create and visualize graph from aggregated JSON')
    parser.add_argument('--json_path', type=str, default="results/002_page_1_aggregated.json",
                        help='Path to aggregated JSON file')
    parser.add_argument('--output_dir', type=str, default="results",
                        help='Directory to save outputs')
    parser.add_argument('--show', action='store_true',
                        help='Show the plot interactively')
    args = parser.parse_args()

    # Verify input file exists
    if not os.path.exists(args.json_path):
        print(f"Error: Could not find input file {args.json_path}")
        # SystemExit avoids relying on the site-provided exit() builtin.
        raise SystemExit(1)

    # Create output directory if it doesn't exist
    os.makedirs(args.output_dir, exist_ok=True)

    # Get base name from input file and remove '_aggregated' suffix
    base_name = Path(args.json_path).stem
    if base_name.endswith('_aggregated'):
        base_name = base_name[:-len('_aggregated')]

    expected_image = os.path.join(args.output_dir, f"{base_name}_graph_visualization.png")
    print("\nProcessing:")
    print(f"Input: {args.json_path}")
    print(f"Output: {expected_image}")

    try:
        result = create_graph_visualization(
            json_path=args.json_path,
            output_dir=args.output_dir,
            base_name=base_name,
            save_plot=True
        )
        if result['success']:
            print(f"\nSuccess! Graph visualization saved to: {result['image_path']}")
            if args.show:
                # create_graph_visualization closes its figure after saving, so a
                # bare plt.show() would display nothing; re-open the saved PNG.
                image = plt.imread(result['image_path'])
                plt.figure(figsize=(12, 12))
                plt.imshow(image)
                plt.axis('off')
                plt.show()
        else:
            print(f"\nError: {result['error']}")
            # Signal failure to shells/CI instead of exiting with status 0.
            raise SystemExit(1)
    except Exception as e:
        print(f"\nError during visualization: {str(e)}")
        raise