import torch
import json
import os
from dataset.Graphs import Graphs
from typing import List, Dict, Tuple
def combine_feature_stats(chunks: List[Dict]) -> Tuple[torch.Tensor, torch.Tensor, int]:
    """
    Merge per-chunk (mean, std, count) statistics into global statistics.

    Uses the parallel (Chan) variant of Welford's algorithm: each chunk's
    population variance is converted back into a sum of squared deviations
    (M2) and folded into a running aggregate, so chunks of different sizes
    combine exactly.

    Returns:
        (mean, std, total_count); empty tensors and 0 when every chunk
        has a zero count.
    """
    total_count = 0
    running_mean = None
    running_m2 = None
    for stats in chunks:
        count = stats['count']
        if not count:
            # Empty chunks contribute nothing.
            continue
        chunk_mean = torch.tensor(stats['mean'])
        # Recover the chunk's sum of squared deviations from its
        # population std: M2 = std^2 * n.
        chunk_m2 = torch.tensor(stats['std']).pow(2) * count
        if total_count:
            merged = total_count + count
            shift = chunk_mean - running_mean
            running_mean = running_mean + shift * (count / merged)
            running_m2 = running_m2 + chunk_m2 + shift.pow(2) * (total_count * count / merged)
            total_count = merged
        else:
            # First non-empty chunk seeds the aggregate.
            running_mean = chunk_mean
            running_m2 = chunk_m2
            total_count = count
    if not total_count:
        return torch.tensor([]), torch.tensor([]), 0
    return running_mean, torch.sqrt(running_m2 / total_count), total_count
def global_stats(dirpath: str, dtype: torch.dtype) -> Dict[str, Tuple[torch.Tensor, torch.Tensor, int]]:
    """
    Load and cache combined normalization statistics for a dataset directory.

    Scans `dirpath` for per-chunk JSON stats files (as written by
    `save_stats`), merges their node and edge statistics with
    `combine_feature_stats`, and caches the result as `global_stats.json`
    in the same directory. Subsequent calls read the cached file directly.

    Args:
        dirpath: Directory containing per-chunk `*.json` stats files.
        dtype: Torch dtype for the returned mean/std tensors.

    Returns:
        Mapping with keys 'node' and 'edge', each a (mean, std, count)
        tuple; mean/std are empty tensors when no samples were found.
    """
    combined_stats_path = os.path.join(dirpath, "global_stats.json")
    if not os.path.exists(combined_stats_path):
        stats_list = []
        for fname in os.listdir(dirpath):
            # Skip the cache file itself so a stale or partially written
            # cache can never be folded into a rebuild.
            if fname.endswith('.json') and fname != "global_stats.json":
                with open(os.path.join(dirpath, fname), 'r') as f:
                    stats_list.append(json.load(f))
        combined = {
            'node': combine_feature_stats([s['node'] for s in stats_list]),
            'edge': combine_feature_stats([s['edge'] for s in stats_list]),
        }
        combined_json = {}
        for key, (mean, std, count) in combined.items():
            combined_json[key] = {
                'mean': mean.tolist() if mean.numel() > 0 else [],
                'std': std.tolist() if std.numel() > 0 else [],
                'count': count,
            }
        with open(combined_stats_path, 'w') as f:
            json.dump(combined_json, f, indent=4)
    # Always read back from disk so the fresh-build and cached paths return
    # exactly the same JSON-round-tripped values.
    with open(combined_stats_path, 'r') as f:
        combined_json = json.load(f)

    def to_tensor(d):
        # An empty list would otherwise build a float64 tensor; keep the
        # explicit empty-tensor fallback with the requested dtype.
        mean = torch.tensor(d['mean'], dtype=dtype) if d['mean'] else torch.tensor([], dtype=dtype)
        std = torch.tensor(d['std'], dtype=dtype) if d['std'] else torch.tensor([], dtype=dtype)
        return mean, std, d['count']

    return {
        'node': to_tensor(combined_json['node']),
        'edge': to_tensor(combined_json['edge']),
    }
def compute_stats(feats, eps=1e-6):
    """
    Return per-feature (mean, std) over dim 0, with std floored at `eps`.

    Variance is the population variance; a single sample has zero
    variance, so its std collapses to the `eps` floor.
    """
    mean = feats.mean(dim=0)
    if feats.size(0) <= 1:
        # No spread is computable from one row.
        var = torch.zeros_like(mean)
    else:
        var = (feats - mean).pow(2).mean(dim=0)
    # Floor the std so later normalization never divides by ~zero.
    std = var.sqrt().clamp(min=eps)
    return mean, std
def save_stats(graphs: 'Graphs', filepath: str, categorical_unique_threshold: int = 50):
    """
    Compute and save normalization stats (mean, std, counts) for node and
    edge features as JSON.

    Node feature columns with fewer than `categorical_unique_threshold`
    distinct values are treated as categorical and get identity
    normalization (mean=0, std=1) so they pass through unscaled.

    Args:
        graphs: Iterable of (graph, label) pairs; each graph is assumed to
            expose `ndata['features']` and `edata['features']` tensors
            (DGL-style) — confirm against the Graphs dataset class.
        filepath: Destination JSON path; parent directories are created
            as needed.
        categorical_unique_threshold: Maximum distinct values for a node
            feature column to count as categorical.

    Raises:
        ValueError: If `graphs` is empty.
    """
    if len(graphs) == 0:
        raise ValueError("No graphs to compute stats from.")
    # Stack node and edge features across every graph.
    all_node_feats = torch.cat([g.ndata['features'] for g, _ in graphs], dim=0)
    all_edge_feats = torch.cat([g.edata['features'] for g, _ in graphs], dim=0)
    node_mean, node_std = compute_stats(all_node_feats)
    edge_mean, edge_std = compute_stats(all_edge_feats)
    # Disable normalization for (likely) categorical node columns.
    categorical_mask = torch.tensor([
        torch.unique(all_node_feats[:, i]).numel() < categorical_unique_threshold
        for i in range(node_mean.size(0))
    ], dtype=torch.bool)
    node_mean[categorical_mask] = 0.0
    node_std[categorical_mask] = 1.0
    stats = {
        'node': {
            'mean': node_mean.tolist(),
            'std': node_std.tolist(),
            'count': all_node_feats.size(0),
        },
        'edge': {
            'mean': edge_mean.tolist(),
            'std': edge_std.tolist(),
            'count': all_edge_feats.size(0),
        },
    }
    # os.makedirs('') raises FileNotFoundError, so only create directories
    # when the path actually has a directory component.
    parent = os.path.dirname(filepath)
    if parent:
        os.makedirs(parent, exist_ok=True)
    with open(filepath, 'w') as f:
        json.dump(stats, f, indent=4)