# Intelligent_PID / graph_visualization.py
# msIntui
# feat: initial clean deployment
# 910e0d4
import argparse
import json
import math
import os
import uuid
from pathlib import Path
from pprint import pprint

import matplotlib.pyplot as plt
import networkx as nx
from tqdm import tqdm
def _parse_node(node):
    """Return ``(node_id, x, y)`` for a usable node record, else ``None``.

    A record is usable when it is a mapping with a non-null, non-empty ``id``
    and finite numeric ``x`` and ``y`` values. A coordinate of exactly ``0`` is
    accepted; only missing, null, non-numeric, NaN or infinite values are
    rejected.
    """
    try:
        raw_id = node.get('id')
        raw_x = node.get('x')
        raw_y = node.get('y')
        if raw_id is None or raw_x is None or raw_y is None:
            return None
        node_id = str(raw_id)
        x = float(raw_x)
        y = float(raw_y)
    except (AttributeError, TypeError, ValueError):
        # AttributeError: the record is not a mapping (no ``.get``).
        # TypeError / ValueError: the coordinate cannot be converted to float.
        return None

    if not node_id or not (math.isfinite(x) and math.isfinite(y)):
        return None
    return node_id, x, y


def create_graph_visualization(json_path: str, output_dir: str, base_name: str, save_plot: bool = True) -> dict:
    """Build a graph from an aggregated JSON file and optionally render it.

    Nodes are placed at the ``x``/``y`` coordinates stored on each node record.
    Node records without a usable id or coordinates are skipped, and edges are
    kept only when both of their endpoints are valid nodes.

    Args:
        json_path: Path to a JSON file whose top-level object holds ``nodes``
            and ``edges`` lists.
        output_dir: Directory that receives the PNG; created if missing.
        base_name: Stem of the output file name; a trailing ``_aggregated``
            is removed.
        save_plot: When true, draw the graph and save it as
            ``<base_name>_graph_visualization.png`` inside ``output_dir``.

    Returns:
        On success, a dict with ``success=True``, ``graph`` (the
        ``networkx.Graph``), ``stats`` (valid/invalid node and edge counts)
        and, when ``save_plot`` is true, ``image_path``. On any failure, a
        dict with ``success=False`` and ``error`` (the exception message);
        exceptions are not propagated to the caller.
    """
    try:
        # Remove '_aggregated' suffix if present
        if base_name.endswith('_aggregated'):
            base_name = base_name[:-len('_aggregated')]

        print("\nLoading JSON data...")
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        if not isinstance(data, dict):
            raise ValueError("Top-level JSON value must be an object containing 'nodes' and 'edges'")

        G = nx.Graph()
        pos = {}
        valid_nodes = []  # (node_id, x, y, original record)
        invalid_nodes = []

        # First pass - collect valid nodes
        print("\nValidating nodes...")
        for node in tqdm(data.get('nodes') or [], desc="Validating"):
            parsed = _parse_node(node)
            if parsed is None:
                invalid_nodes.append(node)
                continue
            node_id, x, y = parsed
            valid_nodes.append((node_id, x, y, node))
            pos[node_id] = (x, y)

        print(f"\nFound {len(valid_nodes)} valid nodes and {len(invalid_nodes)} invalid nodes")

        # Add valid nodes
        print("\nAdding valid nodes...")
        for node_id, x, y, node in tqdm(valid_nodes, desc="Nodes"):
            attrs = {
                'type': node.get('type', ''),
                'label': node.get('label', ''),
                'x': x,
                'y': y,
            }
            G.add_node(node_id, **attrs)

        # Add valid edges (only between valid nodes)
        print("\nAdding valid edges...")
        valid_edges = []
        invalid_edges = []
        for edge in tqdm(data.get('edges') or [], desc="Edges"):
            try:
                start_id = str(edge.get('start_point', ''))
                end_id = str(edge.get('end_point', ''))
                linked = start_id in pos and end_id in pos
                if linked:
                    attrs = {
                        'type': edge.get('type', ''),
                        'weight': edge.get('weight', 1.0),
                    }
                    G.add_edge(start_id, end_id, **attrs)
            except (AttributeError, TypeError, ValueError):
                # Malformed edge record (e.g. not a mapping): count it as invalid.
                linked = False
            # Classify only after add_edge has succeeded, so an edge is never
            # counted as both valid and invalid.
            (valid_edges if linked else invalid_edges).append(edge)

        print(f"\nFound {len(valid_edges)} valid edges and {len(invalid_edges)} invalid edges")

        stats = {
            'valid_nodes': len(valid_nodes),
            'invalid_nodes': len(invalid_nodes),
            'valid_edges': len(valid_edges),
            'invalid_edges': len(invalid_edges),
        }

        if not save_plot:
            return {
                'success': True,
                'graph': G,
                'stats': stats,
            }

        print("\nGenerating visualization...")
        if output_dir:
            # An empty output_dir means "current directory"; makedirs('') would raise.
            os.makedirs(output_dir, exist_ok=True)
        image_path = os.path.join(output_dir, f"{base_name}_graph_visualization.png")

        fig = plt.figure(figsize=(20, 20))
        try:
            ax = fig.gca()
            print("Drawing graph elements...")
            with tqdm(total=3, desc="Drawing") as pbar:
                # Draw nodes
                nx.draw_networkx_nodes(G, pos,
                                       node_color='lightblue',
                                       node_size=100,
                                       ax=ax)
                pbar.update(1)

                # Draw edges
                nx.draw_networkx_edges(G, pos, ax=ax)
                pbar.update(1)

                # Save plot
                fig.savefig(image_path, bbox_inches='tight', dpi=300)
                pbar.update(1)
        finally:
            # Always release the figure, even when drawing or saving fails.
            plt.close(fig)

        print(f"\nVisualization saved to: {image_path}")
        return {
            'success': True,
            'image_path': image_path,
            'graph': G,
            'stats': stats,
        }

    except Exception as e:
        # Boundary handler: callers get failures as a result dict, not an exception.
        print(f"\nError creating graph: {str(e)}")
        return {
            'success': False,
            'error': str(e)
        }
if __name__ == "__main__":
    # Test the graph visualization independently.

    # Set up argument parser
    parser = argparse.ArgumentParser(description='Create and visualize graph from aggregated JSON')
    parser.add_argument('--json_path', type=str, default="results/002_page_1_aggregated.json",
                        help='Path to aggregated JSON file')
    parser.add_argument('--output_dir', type=str, default="results",
                        help='Directory to save outputs')
    parser.add_argument('--show', action='store_true',
                        help='Show the plot interactively')
    args = parser.parse_args()

    # Verify input file exists
    if not os.path.exists(args.json_path):
        print(f"Error: Could not find input file {args.json_path}")
        # SystemExit works even when the site-provided ``exit`` helper is absent.
        raise SystemExit(1)

    # Create output directory if it doesn't exist
    os.makedirs(args.output_dir, exist_ok=True)

    # Get base name from input file and remove '_aggregated' suffix
    base_name = Path(args.json_path).stem
    if base_name.endswith('_aggregated'):
        base_name = base_name[:-len('_aggregated')]

    print("\nProcessing:")
    print(f"Input: {args.json_path}")
    print(f"Output: {args.output_dir}/{base_name}_graph_visualization.png")

    try:
        # Create visualization
        result = create_graph_visualization(
            json_path=args.json_path,
            output_dir=args.output_dir,
            base_name=base_name,
            save_plot=True
        )

        if not result['success']:
            print(f"\nError: {result['error']}")
            # Report the failure through the process exit status as well.
            raise SystemExit(1)

        print(f"\nSuccess! Graph visualization saved to: {result['image_path']}")
        if args.show:
            # create_graph_visualization closes its figure after saving, so a
            # bare plt.show() would display nothing; show the saved PNG instead.
            plt.figure(figsize=(10, 10))
            plt.imshow(plt.imread(result['image_path']))
            plt.axis('off')
            plt.show()

    except Exception as e:
        print(f"\nError during visualization: {str(e)}")
        raise