File size: 4,022 Bytes
9847531
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import os
import json
import networkx as nx
import matplotlib.pyplot as plt
from pathlib import Path
import logging
import traceback
from storage import StorageFactory

logger = logging.getLogger(__name__)

def _node_position(node: dict) -> tuple:
    """Return the (x, y) drawing position for one aggregated node.

    Connection points carry explicit ``coords``; every other node type
    (symbol or text) is placed at the centre of its ``bbox``.
    """
    if node['type'] == 'connection_point':
        coords = node['coords']
        return (coords['x'], coords['y'])
    bbox = node['bbox']
    return (
        (bbox['xmin'] + bbox['xmax']) / 2,
        (bbox['ymin'] + bbox['ymax']) / 2,
    )


def _node_color(node_type) -> str:
    """Return the fill colour used when drawing a node of ``node_type``."""
    if node_type == 'symbol':
        return 'lightblue'
    if node_type == 'text':
        return 'lightgreen'
    return 'lightgray'  # connection_point (and any other type)


def _node_label(node_data: dict) -> str:
    """Return the short on-canvas label for one graph node's attribute dict."""
    node_type = node_data.get('type')
    # Missing or null 'properties' is treated as empty for every node type.
    properties = node_data.get('properties') or {}
    if node_type == 'symbol':
        return f"S:{properties.get('class', '')}"
    if node_type == 'text':
        content = node_data.get('content')
        content = '' if content is None else str(content)
        return f"T:{content[:10]}..." if len(content) > 10 else f"T:{content}"
    return f"C:{properties.get('point_type', '')}"


def construct_graph_network(data: dict, validation_results_path: str, results_dir: str, storage=None):
    """Build, render and persist a network graph from aggregated detection data.

    Args:
        data: Aggregated detections. Reads ``data['nodes']`` (each with ``id``,
            ``type``, and ``coords`` for connection points or ``bbox``
            otherwise), ``data['edges']`` (each with ``source``, ``target`` and
            optional ``properties``) and optional ``data['image_path']``, whose
            stem names the output files.
        validation_results_path: Accepted for interface compatibility; not used
            by this function.
        results_dir: Directory that receives ``<stem>_graph.png`` and
            ``<stem>_graph_data.json``; created if it does not exist.
        storage: Optional storage backend. When ``None``, one is obtained from
            ``StorageFactory.get_storage()``.

    Returns:
        ``(G, pos, fig)``: the ``networkx.Graph``, a dict mapping node id to its
        (x, y) drawing position, and the matplotlib ``Figure`` that was rendered
        and saved. The figure is closed in pyplot (so ``plt.show()`` will not
        display it), but the object can still be saved or embedded. On any
        error the traceback is logged and ``(None, None, None)`` is returned.
    """
    fig = None
    try:
        # Use provided storage or get a new one.
        # NOTE(review): ``storage`` is resolved but never used below; both
        # outputs are written directly to the local filesystem — confirm intent.
        if storage is None:
            storage = StorageFactory.get_storage()

        G = nx.Graph()
        pos = {}  # node id -> (x, y) drawing position

        for node in data.get('nodes', []):
            node_id = node['id']
            pos[node_id] = _node_position(node)
            # Keep every aggregated property on the graph node.
            G.add_node(node_id, **node)

        for edge in data.get('edges', []):
            source, target = edge['source'], edge['target']
            # networkx would silently create bare endpoint nodes (no 'type', no
            # position), which breaks drawing; skip such edges instead.
            if source not in pos or target not in pos:
                logger.warning(
                    "Skipping edge %r -> %r: endpoint not present in nodes",
                    source, target,
                )
                continue
            G.add_edge(source, target, **(edge.get('properties') or {}))

        # Draw on an explicit figure/axes so the figure saved is the one returned.
        fig, ax = plt.subplots(figsize=(20, 20))
        node_colors = [_node_color(G.nodes[n].get('type')) for n in G.nodes()]
        nx.draw_networkx_nodes(G, pos, ax=ax, node_color=node_colors, node_size=500)
        nx.draw_networkx_edges(G, pos, ax=ax, edge_color='gray', width=1)
        labels = {n: _node_label(G.nodes[n]) for n in G.nodes()}
        nx.draw_networkx_labels(G, pos, labels, font_size=8, ax=ax)
        ax.set_title("P&ID Knowledge Graph")
        ax.axis('off')

        os.makedirs(results_dir, exist_ok=True)
        stem = Path(data.get('image_path') or 'graph').stem

        # Save the visualization, then release pyplot's reference to it.
        graph_image_path = os.path.join(results_dir, f"{stem}_graph.png")
        fig.savefig(graph_image_path, bbox_inches='tight', dpi=300)
        plt.close(fig)

        # Save graph data as JSON for future use.
        graph_json_path = os.path.join(results_dir, f"{stem}_graph_data.json")
        with open(graph_json_path, 'w', encoding='utf-8') as f:
            json.dump(nx.node_link_data(G), f, indent=2)

        return G, pos, fig

    except Exception as e:
        logger.exception("Error in construct_graph_network: %s", e)
        if fig is not None:
            plt.close(fig)  # don't leak a half-drawn figure
        return None, None, None

if __name__ == "__main__":
    # Manual smoke test: build the graph from a saved aggregation result and
    # display the rendered image.
    test_data_path = "results/test_aggregated.json"
    if not os.path.exists(test_data_path):
        print(f"Test data not found: {test_data_path}")
    else:
        with open(test_data_path, 'r', encoding='utf-8') as f:
            test_data = json.load(f)

        G, pos, fig = construct_graph_network(
            test_data,
            "results/validation.json",
            "results"
        )
        if G is not None:
            # construct_graph_network closes its figure after saving, so
            # plt.show() would not display it; show the saved PNG instead.
            stem = Path(test_data.get('image_path') or 'graph').stem
            image_path = os.path.join("results", f"{stem}_graph.png")
            if os.path.exists(image_path):
                plt.figure(figsize=(12, 12))
                plt.imshow(plt.imread(image_path))
                plt.axis('off')
                plt.show()
            else:
                print(f"Rendered graph image not found: {image_path}")