File size: 4,629 Bytes
5ceead6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import torch
import json
import os
from dataset.Graphs import Graphs
from typing import List, Dict, Tuple

def combine_feature_stats(chunks: List[Dict]) -> Tuple[torch.Tensor, torch.Tensor, int]:
    """
    Merge per-chunk mean/std/count summaries into one set of statistics.

    Uses the parallel (Chan et al.) variant of Welford's algorithm, folding
    one chunk at a time into a running mean and running sum of squared
    deviations (M2).

    Args:
        chunks: dicts with keys 'mean' (list), 'std' (list), 'count' (int).
            The stored std is assumed to be the population std.

    Returns:
        (mean, std, count) over all chunks combined; empty tensors and 0
        when every chunk has count 0.
    """
    total_count = 0
    running_mean = None
    running_m2 = None

    for stats in chunks:
        count = stats['count']
        if count == 0:
            continue

        chunk_mean = torch.tensor(stats['mean'])
        # Recover the chunk's sum of squared deviations from its stored std.
        chunk_m2 = torch.tensor(stats['std']) ** 2 * count

        if total_count == 0:
            running_mean, running_m2, total_count = chunk_mean, chunk_m2, count
            continue

        merged_count = total_count + count
        delta = chunk_mean - running_mean
        running_mean = running_mean + delta * (count / merged_count)
        running_m2 = running_m2 + chunk_m2 + delta ** 2 * (total_count * count / merged_count)
        total_count = merged_count

    if total_count == 0:
        return torch.tensor([]), torch.tensor([]), 0

    return running_mean, torch.sqrt(running_m2 / total_count), total_count

def global_stats(dirpath: str, dtype: torch.dtype) -> Dict[str, Tuple[torch.Tensor, torch.Tensor, int]]:
    """
    Return combined node/edge normalization stats for a directory of chunk files.

    On first call, every per-chunk JSON stats file in `dirpath` is loaded,
    merged via `combine_feature_stats`, and the result is cached to
    `<dirpath>/global_stats.json`. Subsequent calls read the cached file.

    Args:
        dirpath: directory containing per-chunk stats JSON files, each with
            top-level 'node' and 'edge' entries.
        dtype: torch dtype for the returned mean/std tensors.

    Returns:
        {'node': (mean, std, count), 'edge': (mean, std, count)}.
    """
    combined_stats_path = os.path.join(dirpath, "global_stats.json")

    if not os.path.exists(combined_stats_path):
        stats_list = []
        # Sort for a deterministic merge order (float results depend on it),
        # and skip the cache file itself in case of a stale/partial run.
        for fname in sorted(os.listdir(dirpath)):
            if fname.endswith('.json') and fname != "global_stats.json":
                with open(os.path.join(dirpath, fname), 'r') as f:
                    stats_list.append(json.load(f))

        combined = {
            'node': combine_feature_stats([s['node'] for s in stats_list]),
            'edge': combine_feature_stats([s['edge'] for s in stats_list]),
        }

        combined_json = {}
        for key, (mean, std, count) in combined.items():
            combined_json[key] = {
                'mean': mean.tolist() if mean.numel() > 0 else [],
                'std': std.tolist() if std.numel() > 0 else [],
                'count': count,
            }

        with open(combined_stats_path, 'w') as f:
            json.dump(combined_json, f, indent=4)

    with open(combined_stats_path, 'r') as f:
        combined_json = json.load(f)

    def to_tensor(d):
        # Empty lists round-trip to empty tensors of the requested dtype.
        mean = torch.tensor(d['mean'], dtype=dtype) if d['mean'] else torch.tensor([], dtype=dtype)
        std = torch.tensor(d['std'], dtype=dtype) if d['std'] else torch.tensor([], dtype=dtype)
        return mean, std, d['count']

    return {
        'node': to_tensor(combined_json['node']),
        'edge': to_tensor(combined_json['edge']),
    }

def compute_stats(feats, eps=1e-6):
    """
    Return per-column mean and population std of a 2D feature tensor.

    A single-row input gets zero variance, and every std value below
    `eps` is raised to `eps` so later normalization never divides by ~0.
    """
    mean = feats.mean(dim=0)
    if feats.size(0) <= 1:
        variance = torch.zeros_like(mean)
    else:
        variance = ((feats - mean) ** 2).mean(dim=0)
    std = variance.sqrt().clamp_min(eps)
    return mean, std

def save_stats(graphs: 'Graphs', filepath: str, categorical_unique_threshold=50):
    """
    Compute and save normalization stats (mean, std, counts) for node and edge features.

    Node feature columns with fewer than `categorical_unique_threshold` distinct
    values are treated as categorical and have normalization disabled
    (mean=0, std=1). Edge features are never marked categorical here.

    Args:
        graphs: sequence of (graph, label) pairs; each graph exposes 2D
            ndata['features'] and edata['features'] tensors.
        filepath: destination JSON path; missing parent directories are created.
        categorical_unique_threshold: max distinct values for a node column to
            count as categorical.

    Raises:
        ValueError: if `graphs` is empty.
    """
    if len(graphs) == 0:
        raise ValueError("No graphs to compute stats from.")

    # Concatenate features across every graph so stats cover the full dataset.
    all_node_feats = torch.cat([g.ndata['features'] for g, _ in graphs], dim=0)
    all_edge_feats = torch.cat([g.edata['features'] for g, _ in graphs], dim=0)

    node_mean, node_std = compute_stats(all_node_feats)
    edge_mean, edge_std = compute_stats(all_edge_feats)

    # Disable normalization for (likely) categorical node columns.
    categorical_mask = torch.tensor([
        torch.unique(all_node_feats[:, i]).numel() < categorical_unique_threshold
        for i in range(node_mean.size(0))
    ], dtype=torch.bool)
    node_mean[categorical_mask] = 0.0
    node_std[categorical_mask] = 1.0

    stats = {
        'node': {
            'mean': node_mean.tolist(),
            'std': node_std.tolist(),
            'count': all_node_feats.size(0),
        },
        'edge': {
            'mean': edge_mean.tolist(),
            'std': edge_std.tolist(),
            'count': all_edge_feats.size(0),
        },
    }

    # BUG FIX: os.makedirs('') raises FileNotFoundError when `filepath` is a
    # bare filename — only create directories when there is a parent to create.
    dirname = os.path.dirname(filepath)
    if dirname:
        os.makedirs(dirname, exist_ok=True)

    with open(filepath, 'w') as f:
        json.dump(stats, f, indent=4)