import torch
import torch.nn as nn
import dgl


def _normalize_splits(split):
    """Return ``split`` as a flat list of Python ints (one count per graph).

    Tensors of any shape are flattened with ``reshape`` (works for
    non-contiguous and 0-d tensors, unlike ``view``/``len``); any other
    iterable is converted element-wise with ``int``.
    """
    if isinstance(split, torch.Tensor):
        return split.reshape(-1).tolist()
    return [int(x) for x in split]


def _split_reduce(feat, splits, op):
    """Reduce consecutive row segments of ``feat`` with ``op``.

    Args:
        feat: tensor whose first dimension concatenates all segments.
        splits: list of ints, the row count of each segment; passed to
            ``torch.split``, so they must sum to ``feat.shape[0]``.
        op: 'mean', 'sum', or 'max'.

    Returns:
        Tensor of shape [len(splits), *feat.shape[1:]]; empty segments
        produce zeros.

    Raises:
        ValueError: if ``op`` is not one of the supported reductions.
    """
    reducers = {
        'mean': lambda chunk: chunk.mean(0),
        'sum': lambda chunk: chunk.sum(0),
        'max': lambda chunk: chunk.max(0).values,
    }
    reduce_fn = reducers.get(op)
    if reduce_fn is None:
        raise ValueError(f"Unknown op: {op}")
    # Built from feat.shape[1:] instead of feat[0] so a batch with zero total
    # rows (e.g. every graph has no edges) does not raise IndexError.
    zeros = feat.new_zeros(feat.shape[1:])
    chunks = torch.split(feat, splits, dim=0)
    return torch.stack(
        [reduce_fn(chunk) if chunk.shape[0] > 0 else zeros for chunk in chunks]
    )


def _dgl_readout(readouts, batched_graph, feat_key, op):
    """Call the DGL readout function that ``readouts`` maps ``op`` to.

    Raises:
        ValueError: if ``op`` is not a key of ``readouts``.
    """
    readout = readouts.get(op)
    if readout is None:
        raise ValueError(f"Unknown op: {op}")
    return readout(batched_graph, feat_key)


def mean_nodes(batched_graph, feat_key='h', op='mean', node_split=None):
    """
    Aggregates node features per disjoint graph in a batched DGLGraph.

    Args:
        batched_graph: DGLGraph
        feat_key: str, node feature key
        op: 'mean', 'sum', or 'max'
        node_split: 1D tensor or list of ints (num nodes per graph). When
            None or empty, DGL's built-in readout (dgl.mean_nodes /
            sum_nodes / max_nodes) is used instead; otherwise node rows are
            split by these counts and empty graphs yield zeros.
    Returns:
        Tensor of shape [num_graphs, node_feat_dim]
    Raises:
        ValueError: if ``op`` is not supported.
    """
    # Read first so a missing key surfaces as the ndata lookup error.
    feat = batched_graph.ndata[feat_key]
    splits = None if node_split is None else _normalize_splits(node_split)
    if not splits:
        readouts = {
            'mean': dgl.mean_nodes,
            'sum': dgl.sum_nodes,
            'max': dgl.max_nodes,
        }
        return _dgl_readout(readouts, batched_graph, feat_key, op)
    return _split_reduce(feat, splits, op)


def mean_edges(batched_graph, feat_key='e', op='mean', edge_split=None):
    """
    Aggregates edge features per disjoint graph in a batched DGLGraph.

    Args:
        batched_graph: DGLGraph
        feat_key: str, edge feature key
        op: 'mean', 'sum', or 'max'
        edge_split: 1D tensor or list of ints (num edges per graph). When
            None or empty, DGL's built-in readout (dgl.mean_edges /
            sum_edges / max_edges) is used instead; otherwise edge rows are
            split by these counts and graphs without edges yield zeros.
    Returns:
        Tensor of shape [num_graphs, edge_feat_dim]
    Raises:
        ValueError: if ``op`` is not supported.
    """
    # Read first so a missing key surfaces as the edata lookup error.
    feat = batched_graph.edata[feat_key]
    splits = None if edge_split is None else _normalize_splits(edge_split)
    if not splits:
        readouts = {
            'mean': dgl.mean_edges,
            'sum': dgl.sum_edges,
            'max': dgl.max_edges,
        }
        return _dgl_readout(readouts, batched_graph, feat_key, op)
    return _split_reduce(feat, splits, op)


def Make_SLP(in_size, out_size, activation=nn.ReLU, dropout=0):
    """Build one Linear -> activation -> Dropout stage.

    ``activation`` is a module class and is instantiated here. Returns a
    plain list of modules (not an nn.Sequential) so callers can concatenate
    several stages before wrapping them.
    """
    return [nn.Linear(in_size, out_size), activation(), nn.Dropout(dropout)]


def Make_MLP(in_size, hid_size, out_size, n_layers, activation=nn.ReLU, dropout=0):
    """Build an nn.Sequential of ``n_layers`` SLP stages plus a final LayerNorm.

    With ``n_layers > 1`` the widths are in -> hid, (n_layers - 2) x
    hid -> hid, then hid -> out; otherwise a single in -> out stage is used.
    Every stage, including the last, ends in activation + dropout, and an
    ``nn.LayerNorm(out_size)`` is appended at the end.
    """
    if n_layers > 1:
        sizes = [in_size] + [hid_size] * (n_layers - 1) + [out_size]
    else:
        sizes = [in_size, out_size]
    layers = []
    for d_in, d_out in zip(sizes[:-1], sizes[1:]):
        layers += Make_SLP(d_in, d_out, activation, dropout)
    layers.append(nn.LayerNorm(out_size))
    return nn.Sequential(*layers)


def _counts_on(split, device):
    """Return ``split`` (list or tensor) as a flat int64 tensor on ``device``."""
    if torch.is_tensor(split):
        counts = split.to(device=device, dtype=torch.long)
    else:
        counts = torch.tensor(split, dtype=torch.long, device=device)
    return counts.flatten()


def broadcast_global_to_nodes(globals, node_split):
    """
    globals: [num_graphs, global_dim]
    node_split: list/1D tensor of length num_graphs, number of nodes per graph
    Returns: [total_num_nodes, global_dim] -- row i of ``globals`` repeated
        node_split[i] times.
    Raises: ValueError if node_split is None.
    """
    if node_split is None:
        raise ValueError("node_split must be provided")
    return torch.repeat_interleave(globals, _counts_on(node_split, globals.device), dim=0)


def broadcast_global_to_edges(globals, edge_split):
    """
    globals: [num_graphs, global_dim] (on CUDA or CPU)
    edge_split: list/1D tensor of length num_graphs, number of edges per graph (CPU or CUDA)
    Returns: [total_num_edges, global_dim] -- row i of ``globals`` repeated
        edge_split[i] times.
    Raises: ValueError if edge_split is None.
    """
    if edge_split is None:
        raise ValueError("edge_split must be provided")
    return torch.repeat_interleave(globals, _counts_on(edge_split, globals.device), dim=0)


def copy_v(edges):
    """DGL edge UDF: emit the destination node's 'h' feature as message 'm_v'."""
    return {'m_v': edges.dst['h']}