File size: 3,871 Bytes
9847531
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import json
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
import traceback
import uuid

def _bbox_center(bbox):
    """Return the centre ``(x, y)`` of a bbox dict with xmin/xmax/ymin/ymax keys."""
    return (
        (bbox['xmin'] + bbox['xmax']) / 2,
        (bbox['ymin'] + bbox['ymax']) / 2,
    )


def _node_id(item):
    """Return ``item['id']``, or a fresh UUID string when it is missing or None."""
    # networkx refuses None as a node, so a null id also falls back to a UUID.
    # The UUID is only generated when it is actually needed.
    node_id = item.get('id')
    return node_id if node_id is not None else str(uuid.uuid4())


def _node_label(node_data):
    """Return the short plot label: 'S:<class>' for symbols, 'T:<text>' otherwise."""
    if node_data['type'] == 'symbol':
        return f"S:{node_data['class_name']}"
    text = node_data.get('text')
    # Coerce to str so a None or non-string text value cannot break len()/slicing.
    text = '' if text is None else str(text)
    return f"T:{text[:10]}..." if len(text) > 10 else f"T:{text}"


def create_connected_graph(input_data):
    """Create a connected graph from the input data and draw it.

    Args:
        input_data: dict that must contain the keys 'symbols', 'texts',
            'lines', 'nodes' and 'edges' (only 'symbols', 'texts' and 'edges'
            are read here). Each symbol/text entry is a dict with an optional
            'id' and a 'bbox' dict holding 'xmin', 'xmax', 'ymin', 'ymax';
            entries without a (non-empty) bbox are skipped. Each edge is a
            dict with 'source'/'target' ids and optional 'type'/'properties'.

    Returns:
        ``(G, pos, fig)``: the ``nx.Graph``, a dict mapping node id to its
        bbox centre ``(x, y)``, and the matplotlib Figure it was drawn on.
        On any error the message and traceback are printed and
        ``(None, None, None)`` is returned.
    """
    fig = None
    try:
        # Validate input data structure
        if not isinstance(input_data, dict):
            raise ValueError("Invalid input data format")

        # NOTE: 'lines' and 'nodes' are required but not used below.
        required_keys = ['symbols', 'texts', 'lines', 'nodes', 'edges']
        if not all(key in input_data for key in required_keys):
            raise ValueError(f"Missing required keys in input data. Expected: {required_keys}")

        G = nx.Graph()
        pos = {}  # node id -> (x, y) used as fixed layout positions

        # Add symbol nodes (only those that carry a bbox can be placed).
        for symbol in input_data['symbols']:
            bbox = symbol.get('bbox')
            if not bbox:
                continue
            symbol_id = _node_id(symbol)
            pos[symbol_id] = _bbox_center(bbox)
            G.add_node(
                symbol_id,
                type='symbol',
                class_name=symbol.get('class', ''),
                bbox=bbox,
                confidence=symbol.get('confidence', 0.0),
            )

        # Add text nodes.
        for text_item in input_data['texts']:
            bbox = text_item.get('bbox')
            if not bbox:
                continue
            text_id = _node_id(text_item)
            pos[text_id] = _bbox_center(bbox)
            G.add_node(
                text_id,
                type='text',
                text=text_item.get('text', ''),
                bbox=bbox,
                confidence=text_item.get('confidence', 0.0),
            )

        # Add edges between nodes that exist. A truthiness test here would
        # wrongly drop valid falsy ids such as 0; networkx never holds None
        # as a node, so membership alone filters missing endpoints.
        for edge in input_data['edges']:
            source = edge.get('source')
            target = edge.get('target')
            if source in G and target in G:
                G.add_edge(
                    source,
                    target,
                    type=edge.get('type', 'connection'),
                    properties=edge.get('properties', {}),
                )

        # Create visualization
        fig = plt.figure(figsize=(20, 20))

        node_colors = [
            'lightblue' if G.nodes[node]['type'] == 'symbol' else 'lightgreen'
            for node in G.nodes()
        ]
        nx.draw_networkx_nodes(G, pos, node_color=node_colors, node_size=500)
        nx.draw_networkx_edges(G, pos, edge_color='gray', width=1)

        labels = {node: _node_label(G.nodes[node]) for node in G.nodes()}
        nx.draw_networkx_labels(G, pos, labels, font_size=8)

        plt.title("P&ID Network Graph")
        plt.axis('off')

        return G, pos, fig

    except Exception as e:
        print(f"Error in create_connected_graph: {str(e)}")
        traceback.print_exc()
        # Don't leak a half-drawn figure when drawing fails part-way.
        if fig is not None:
            plt.close(fig)
        return None, None, None

if __name__ == "__main__":
    # Manual smoke test: build and display the graph for a sample result file.
    # JSON is UTF-8; don't depend on the platform's locale encoding.
    with open('results/0_aggregated.json', encoding='utf-8') as f:
        data = json.load(f)

    G, pos, fig = create_connected_graph(data)
    # create_connected_graph returns (None, None, None) on failure.
    if fig is not None:
        plt.show()