Spaces:
Sleeping
Sleeping
| # SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. | |
| # SPDX-FileCopyrightText: All rights reserved. | |
| # SPDX-License-Identifier: Apache-2.0 | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import warnings | |
| from types import NoneType | |
| from typing import Any, Callable, Dict, Tuple, TypeAlias, Union | |
| import torch | |
| from torch import Tensor | |
| from torch.utils.checkpoint import checkpoint | |
| from torch_geometric.data import Data as PyGData | |
| from torch_geometric.data import HeteroData as PyGHeteroData | |
| try: | |
| import dgl # noqa: F401 for docs | |
| import dgl.function as fn | |
| from dgl import DGLGraph | |
| except ImportError: | |
| warnings.warn( | |
| "Note: This only applies if you're using DGL.\n" | |
| "MeshGraphNet (DGL version) requires the DGL library.\n" | |
| "Install it with your preferred CUDA version from:\n" | |
| "https://www.dgl.ai/pages/start.html\n" | |
| ) | |
| DGLGraph: TypeAlias = NoneType | |
| try: | |
| import torch_scatter | |
| except ImportError: | |
| warnings.warn( | |
| "MeshGraphNet will soon require PyTorch Geometric and torch_scatter.\n" | |
| "Install it from here:\n" | |
| "https://github.com/rusty1s/pytorch_scatter\n" | |
| ) | |
| from physicsnemo.models.gnn_layers import CuGraphCSC | |
| GraphType: TypeAlias = PyGData | PyGHeteroData | DGLGraph | CuGraphCSC | |
| try: | |
| from pylibcugraphops.pytorch.operators import ( | |
| agg_concat_e2n, | |
| update_efeat_bipartite_e2e, | |
| update_efeat_static_e2e, | |
| ) | |
| USE_CUGRAPHOPS = True | |
| except ImportError: | |
| update_efeat_bipartite_e2e = None | |
| update_efeat_static_e2e = None | |
| agg_concat_e2n = None | |
| USE_CUGRAPHOPS = False | |
def checkpoint_identity(layer: Callable, *args: Any, **kwargs: Any) -> Any:
    """Call ``layer`` directly, as a stand-in for a checkpoint function.

    Used when checkpointing is disabled: the positional arguments are
    forwarded to ``layer`` and its result is returned unchanged.

    Parameters
    ----------
    layer : Callable
        The model layer or function to call.
    *args
        Positional arguments forwarded to ``layer``.
    **kwargs
        Accepted but NOT forwarded to ``layer``. Presumably this lets callers
        pass checkpoint-only options to either checkpoint function with the
        same call -- confirm against callers.

    Returns
    -------
    Any
        Whatever ``layer(*args)`` returns.
    """
    output = layer(*args)
    return output
def set_checkpoint_fn(do_checkpointing: bool) -> Callable:
    """Select the function used to run (optionally checkpointed) layers.

    Parameters
    ----------
    do_checkpointing : bool
        When truthy, ``torch.utils.checkpoint.checkpoint`` is returned,
        trading extra computation for lower memory use during
        backpropagation. Otherwise ``checkpoint_identity`` is returned,
        which simply calls the layer.

    Returns
    -------
    Callable
        The selected checkpoint function.
    """
    return checkpoint if do_checkpointing else checkpoint_identity
def concat_message_function(edges: Tensor) -> Dict[str, Tensor]:
    """Concatenate edge, source-node and destination-node features.

    Parameters
    ----------
    edges : Tensor
        Edge batch object. It is accessed through ``edges.data["x"]``,
        ``edges.src["x"]`` and ``edges.dst["x"]``, so despite the annotation
        it must expose those three mappings (this function is passed to
        ``graph.apply_edges`` elsewhere in this file).

    Returns
    -------
    Dict[str, Tensor]
        Single-entry dict ``{"cat_feat": ...}`` holding the features
        concatenated along dimension 1, in the order edge, source,
        destination.
    """
    edge_x = edges.data["x"]
    src_x = edges.src["x"]
    dst_x = edges.dst["x"]
    return {"cat_feat": torch.cat((edge_x, src_x, dst_x), dim=1)}
def concat_efeat_dgl(
    efeat: Tensor,
    nfeat: Union[Tensor, Tuple[torch.Tensor, torch.Tensor]],
    graph: DGLGraph,
) -> Tensor:
    """Concatenate edge features with source and destination node features.

    Use for homogeneous DGL graphs.

    Parameters
    ----------
    efeat : Tensor
        Edge features.
    nfeat : Tensor | Tuple[Tensor, Tensor]
        Either one tensor of node features (stored in ``graph.ndata``) or a
        ``(source, destination)`` pair (stored in ``graph.srcdata`` and
        ``graph.dstdata`` respectively).
    graph : DGLGraph
        Graph.

    Returns
    -------
    Tensor
        The ``"cat_feat"`` edge data produced by ``concat_message_function``.
    """
    # Decide which node-data views receive which features before entering
    # the local scope, so a malformed tuple fails before the graph is touched.
    if isinstance(nfeat, Tuple):
        src_x, dst_x = nfeat
        node_assignments = (("srcdata", src_x), ("dstdata", dst_x))
    else:
        node_assignments = (("ndata", nfeat),)

    # Feature writes happen inside local_scope, as in the original layout.
    with graph.local_scope():
        for view_name, feats in node_assignments:
            getattr(graph, view_name)["x"] = feats
        graph.edata["x"] = efeat
        graph.apply_edges(concat_message_function)
        return graph.edata["cat_feat"]
def concat_efeat_hetero_dgl(
    mesh_efeat: Tensor,
    world_efeat: Tensor,
    nfeat: Union[Tensor, Tuple[torch.Tensor, torch.Tensor]],
    graph: DGLGraph,
) -> Tensor:
    """Concatenate edge features with source and destination node features.

    Use for heterogeneous graphs: mesh and world edge features are stacked
    along dimension 0 (mesh first) before being attached to the graph edges.

    Parameters
    ----------
    mesh_efeat : Tensor
        Mesh edge features.
    world_efeat : Tensor
        World edge features.
    nfeat : Tensor | Tuple[Tensor, Tensor]
        Either one tensor of node features (stored in ``graph.ndata``) or a
        ``(source, destination)`` pair (stored in ``graph.srcdata`` and
        ``graph.dstdata`` respectively).
    graph : DGLGraph
        Graph.

    Returns
    -------
    Tensor
        The ``"cat_feat"`` edge data produced by ``concat_message_function``.
    """
    # Resolve the node-data targets up front; unpacking happens before the
    # graph is touched, exactly as in the branch-per-case layout.
    if isinstance(nfeat, Tuple):
        src_x, dst_x = nfeat
        node_assignments = (("srcdata", src_x), ("dstdata", dst_x))
    else:
        node_assignments = (("ndata", nfeat),)

    with graph.local_scope():
        for view_name, feats in node_assignments:
            getattr(graph, view_name)["x"] = feats
        graph.edata["x"] = torch.cat([mesh_efeat, world_efeat], dim=0)
        graph.apply_edges(concat_message_function)
        return graph.edata["cat_feat"]
def concat_efeat_pyg(
    efeat: Tensor,
    nfeat: Union[Tensor, Tuple[Tensor, Tensor]],
    graph: PyGData | PyGHeteroData,
) -> Tensor:
    """Concatenate edge features with source and destination node features.

    Use for PyG graphs.

    Parameters
    ----------
    efeat : Tensor
        Edge features.
    nfeat : Tensor | Tuple[Tensor, Tensor]
        Node features. A single tensor is used for both source and
        destination nodes; a tuple is unpacked as ``(source, destination)``.
    graph : PyGData | PyGHeteroData
        Graph. For ``PyGHeteroData`` only the edge index of the first edge
        type (``graph.edge_types[0]``) is used.

    Returns
    -------
    Tensor
        Edge, gathered source and gathered destination features
        concatenated along dimension 1.
    """
    if isinstance(nfeat, Tuple):
        src_feat, dst_feat = nfeat
    else:
        src_feat = dst_feat = nfeat

    if isinstance(graph, PyGHeteroData):
        edge_index = graph[graph.edge_types[0]].edge_index
    else:
        edge_index = graph.edge_index
    src_idx, dst_idx = edge_index.long()

    gathered = (efeat, src_feat[src_idx], dst_feat[dst_idx])
    return torch.cat(gathered, dim=1)
def concat_efeat(
    efeat: Tensor,
    nfeat: Union[Tensor, Tuple[Tensor]],
    graph: GraphType,
) -> Tensor:
    """Concatenate edge features with source and destination node features.

    Use for homogeneous graphs. Dispatches on the graph type (CuGraphCSC,
    DGL, PyG) and on whether ``nfeat`` is a single tensor or a
    ``(source, destination)`` pair.

    Parameters
    ----------
    efeat : Tensor
        Edge features.
    nfeat : Tensor | Tuple[Tensor]
        Node features, or a tuple of source and destination node features.
    graph : GraphType
        Graph.

    Returns
    -------
    Tensor
        Concatenated edge features with source and destination node features.

    Raises
    ------
    ValueError
        If ``nfeat`` is neither a tensor nor a tuple, or if the graph type
        is not supported.
    """
    # Normalise the node features into a (source, destination) pair and
    # remember whether a single shared tensor was supplied.
    if isinstance(nfeat, Tensor):
        shared = True
        src_feat = dst_feat = nfeat
    elif isinstance(nfeat, Tuple):
        shared = False
        src_feat, dst_feat = nfeat
    else:
        raise ValueError(f"Unsupported node feature type: {type(nfeat)}")

    if isinstance(graph, CuGraphCSC):
        use_dgl_backend = graph.dgl_graph is not None or not USE_CUGRAPHOPS
        distributed = graph.is_distributed

        # Single tensor on a non-distributed graph with cugraph-ops
        # available: use the static (non-bipartite) operator.
        if shared and not use_dgl_backend and not distributed:
            static_graph = graph.to_static_csc()
            return update_efeat_static_e2e(
                efeat,
                nfeat,
                static_graph,
                mode="concat",
                use_source_emb=True,
                use_target_emb=True,
            )

        if distributed:
            src_feat = graph.get_src_node_features_in_local_graph(src_feat)

        if use_dgl_backend:
            return concat_efeat_dgl(
                efeat, (src_feat, dst_feat), graph.to_dgl_graph()
            )

        # torch.int64 to avoid indexing overflows due to current behavior of cugraph-ops
        bipartite_graph = graph.to_bipartite_csc(dtype=torch.int64)
        return update_efeat_bipartite_e2e(
            efeat, src_feat, dst_feat, bipartite_graph, "concat"
        )

    # A shared tensor is passed through as-is; a pair is passed as a tuple.
    node_arg = nfeat if shared else (src_feat, dst_feat)
    if isinstance(graph, DGLGraph):
        return concat_efeat_dgl(efeat, node_arg, graph)
    if isinstance(graph, (PyGData, PyGHeteroData)):
        return concat_efeat_pyg(efeat, node_arg, graph)
    raise ValueError(f"Unsupported graph type: {type(graph)}")
def concat_efeat_hetero(
    mesh_efeat: Tensor,
    world_efeat: Tensor,
    nfeat: Union[Tensor, Tuple[Tensor, Tensor]],
    graph: GraphType,
) -> Tensor:
    """Concatenate edge features with source and destination node features.

    Use for heterogeneous graphs.

    Parameters
    ----------
    mesh_efeat : Tensor
        Mesh edge features.
    world_efeat : Tensor
        World edge features.
    nfeat : Tensor | Tuple[Tensor, Tensor]
        Node features, or a tuple of source and destination node features.
    graph : GraphType
        Graph. DGL graphs and ``PyGData`` graphs are handled.

    Returns
    -------
    Tensor
        Concatenated edge features with source and destination node features.

    Raises
    ------
    NotImplementedError
        If ``graph`` is a ``CuGraphCSC``.
    ValueError
        If the graph type is otherwise unsupported.
    """
    if isinstance(graph, CuGraphCSC):
        raise NotImplementedError(
            "concat_efeat_hetero is not supported for CuGraphCSC graphs yet."
        )
    if isinstance(graph, DGLGraph):
        return concat_efeat_hetero_dgl(mesh_efeat, world_efeat, nfeat, graph)
    if isinstance(graph, PyGData):
        # PyG path: stack mesh and world edge features (mesh first) and reuse
        # the homogeneous PyG implementation.
        stacked_efeat = torch.cat((mesh_efeat, world_efeat), dim=0)
        return concat_efeat_pyg(stacked_efeat, nfeat, graph)
    raise ValueError(f"Unsupported graph type: {type(graph)}")
def sum_edge_node_feat(
    efeat: Tensor, src_feat: Tensor, dst_feat: Tensor, src_idx: Tensor, dst_idx: Tensor
) -> Tensor:
    """Sum edge features with gathered source and destination node features.

    Parameters
    ----------
    efeat : Tensor
        Edge features.
    src_feat : Tensor
        Source node features, indexed by ``src_idx``.
    dst_feat : Tensor
        Destination node features, indexed by ``dst_idx``.
    src_idx : Tensor
        Source node indices.
    dst_idx : Tensor
        Destination node indices.

    Returns
    -------
    Tensor
        ``efeat + src_feat[src_idx] + dst_feat[dst_idx]``.
    """
    gathered_src = src_feat[src_idx]
    gathered_dst = dst_feat[dst_idx]
    return efeat + gathered_src + gathered_dst
def sum_efeat(
    efeat: Tensor,
    nfeat: Union[Tensor, Tuple[Tensor]],
    graph: GraphType,
) -> Tensor:
    """Sum edge features with source and destination node features.

    Parameters
    ----------
    efeat : Tensor
        Edge features.
    nfeat : Tensor | Tuple[Tensor]
        Node features (static setting) or tuple of node features of
        source and destination nodes (bipartite setting).
    graph : GraphType
        The underlying graph. For ``PyGHeteroData`` the edge index of the
        first edge type (``graph.edge_types[0]``) is used.

    Returns
    -------
    Tensor
        Sum of edge features with source and destination node features.

    Raises
    ------
    ValueError
        If the graph type is not supported.
    """
    # Normalise to a (source, destination) pair; remember whether a single
    # tensor was given, since only that case can use the static operator.
    if isinstance(nfeat, Tensor):
        static_input = True
        src_feat, dst_feat = nfeat, nfeat
    else:
        static_input = False
        src_feat, dst_feat = nfeat

    if isinstance(graph, CuGraphCSC):
        use_dgl_backend = graph.dgl_graph is not None or not USE_CUGRAPHOPS
        distributed = graph.is_distributed

        if static_input and not use_dgl_backend and not distributed:
            # Static (non-bipartite) graph: use the static operator with the
            # same keyword arguments concat_efeat uses. The previous call to
            # the bipartite operator passed the graph in the dst_feat slot.
            static_graph = graph.to_static_csc()
            return update_efeat_static_e2e(
                efeat,
                nfeat,
                static_graph,
                mode="sum",
                use_source_emb=True,
                use_target_emb=True,
            )

        if distributed:
            src_feat = graph.get_src_node_features_in_local_graph(src_feat)

        if use_dgl_backend:
            src, dst = (item.long() for item in graph.to_dgl_graph().edges())
            return sum_edge_node_feat(efeat, src_feat, dst_feat, src, dst)

        # NOTE(review): concat_efeat requests dtype=torch.int64 here to avoid
        # indexing overflows; confirm whether the sum path needs it as well.
        bipartite_graph = graph.to_bipartite_csc()
        return update_efeat_bipartite_e2e(
            efeat, src_feat, dst_feat, bipartite_graph, mode="sum"
        )

    if isinstance(graph, DGLGraph):
        src, dst = (item.long() for item in graph.edges())
    elif isinstance(graph, (PyGData, PyGHeteroData)):
        # Both the single-tensor and the tuple case accept hetero graphs,
        # matching concat_efeat.
        if isinstance(graph, PyGHeteroData):
            src, dst = graph[graph.edge_types[0]].edge_index.long()
        else:
            src, dst = graph.edge_index.long()
    else:
        raise ValueError(f"Unsupported graph type: {type(graph)}")

    return sum_edge_node_feat(efeat, src_feat, dst_feat, src, dst)
def agg_concat_dgl(
    efeat: Tensor, dst_nfeat: Tensor, graph: DGLGraph, aggregation: str
) -> Tensor:
    """Aggregate edge features and concatenate them with destination node features.

    Parameters
    ----------
    efeat : Tensor
        Edge features.
    dst_nfeat : Tensor
        Node features (destination nodes).
    graph : DGLGraph
        Graph.
    aggregation : str
        Aggregation method (sum or mean).

    Returns
    -------
    Tensor
        Aggregated edge features concatenated with destination node features
        along the last dimension (aggregated features first).

    Raises
    ------
    RuntimeError
        If aggregation method is not sum or mean.
    """
    with graph.local_scope():
        # Attach the edge features first, then validate the aggregation mode.
        graph.edata["x"] = efeat

        if aggregation not in ("sum", "mean"):
            raise RuntimeError("Not a valid aggregation!")
        reducer = fn.sum if aggregation == "sum" else fn.mean

        # Copy each edge's features as its message and reduce per destination.
        graph.update_all(fn.copy_e("x", "m"), reducer("m", "h_dest"))

        return torch.cat((graph.dstdata["h_dest"], dst_nfeat), -1)
def agg_concat_hetero_dgl(
    mesh_efeat: Tensor,
    world_efeat: Tensor,
    dst_nfeat: Tensor,
    graph: DGLGraph,
    aggregation: str,
) -> Tensor:
    """Aggregate edge features and concatenate them with destination node features.

    Use for heterogeneous graphs: mesh and world edge features are stacked
    along dimension 0 (mesh first) before aggregation.

    Parameters
    ----------
    mesh_efeat : Tensor
        Mesh edge features.
    world_efeat : Tensor
        World edge features.
    dst_nfeat : Tensor
        Node features (destination nodes).
    graph : DGLGraph
        Graph.
    aggregation : str
        Aggregation method (sum or mean).

    Returns
    -------
    Tensor
        Aggregated edge features concatenated with destination node features
        along the last dimension (aggregated features first).

    Raises
    ------
    RuntimeError
        If aggregation method is not sum or mean.
    """
    with graph.local_scope():
        # Attach the stacked edge features first, then validate the mode.
        graph.edata["x"] = torch.cat([mesh_efeat, world_efeat], dim=0)

        if aggregation not in ("sum", "mean"):
            raise RuntimeError("Not a valid aggregation!")
        reducer = fn.sum if aggregation == "sum" else fn.mean

        # Copy each edge's features as its message and reduce per destination.
        graph.update_all(fn.copy_e("x", "m"), reducer("m", "h_dest"))

        return torch.cat((graph.dstdata["h_dest"], dst_nfeat), -1)
def agg_concat_pyg(
    efeat: Tensor,
    nfeat: Tensor,
    graph: PyGData | PyGHeteroData,
    aggregation: str,
) -> Tensor:
    """Aggregate edge features and concatenate them with node features (PyG).

    Parameters
    ----------
    efeat : Tensor
        Edge features.
    nfeat : Tensor
        Node features; ``nfeat.shape[0]`` sets the size of the aggregated
        output along dimension 0.
    graph : PyGData | PyGHeteroData
        Graph. For ``PyGHeteroData`` only the edge index of the first edge
        type (``graph.edge_types[0]``) is used.
    aggregation : str
        Reduction name, passed as ``reduce`` to ``torch_scatter.scatter``.

    Returns
    -------
    Tensor
        Aggregated edge features concatenated with ``nfeat`` along the last
        dimension (aggregated features first).
    """
    if isinstance(graph, PyGHeteroData):
        edge_index = graph[graph.edge_types[0]].edge_index
    else:
        edge_index = graph.edge_index
    # Only the destination row is needed to scatter edges onto nodes.
    _src, dst = edge_index.long()

    aggregated = torch_scatter.scatter(
        efeat, dst, dim=0, dim_size=nfeat.shape[0], reduce=aggregation
    )
    return torch.cat((aggregated, nfeat), -1)
def aggregate_and_concat(
    efeat: Tensor,
    nfeat: Tensor,
    graph: GraphType,
    aggregation: str,
):
    """
    Aggregate edge features and concatenate them with destination node features.

    Parameters
    ----------
    efeat : Tensor
        Edge features.
    nfeat : Tensor
        Node features (destination nodes).
    graph : GraphType
        Graph.
    aggregation : str
        Aggregation method (sum or mean).

    Returns
    -------
    Tensor
        Aggregated edge features concatenated with destination node features.

    Raises
    ------
    RuntimeError
        If aggregation method is not sum or mean (raised by the DGL path).
    ValueError
        If the graph type is not supported.
    """
    if isinstance(graph, CuGraphCSC):
        # In this case, we don't have to distinguish a distributed setting
        # from the default setting, as both efeat and nfeat are already
        # guaranteed to be on the same rank in both cases due to our
        # partitioning scheme.
        if graph.dgl_graph is not None or not USE_CUGRAPHOPS:
            return agg_concat_dgl(efeat, nfeat, graph.to_dgl_graph(), aggregation)
        static_graph = graph.to_static_csc()
        return agg_concat_e2n(nfeat, efeat, static_graph, aggregation)

    if isinstance(graph, DGLGraph):
        return agg_concat_dgl(efeat, nfeat, graph, aggregation)
    if isinstance(graph, (PyGData, PyGHeteroData)):
        return agg_concat_pyg(efeat, nfeat, graph, aggregation)
    raise ValueError(f"Unsupported graph type: {type(graph)}")
def aggregate_and_concat_hetero(
    mesh_efeat: Tensor,
    world_efeat: Tensor,
    nfeat: Tensor,
    graph: GraphType,
    aggregation: str,
):
    """
    Aggregate edge features and concatenate them with destination node features.

    Use for heterogeneous graphs.

    Parameters
    ----------
    mesh_efeat : Tensor
        Mesh edge features.
    world_efeat : Tensor
        World edge features.
    nfeat : Tensor
        Node features (destination nodes).
    graph : GraphType
        Graph. DGL graphs and ``PyGData`` graphs are handled.
    aggregation : str
        Aggregation method (sum or mean).

    Returns
    -------
    Tensor
        Aggregated edge features concatenated with destination node features.

    Raises
    ------
    NotImplementedError
        If ``graph`` is a ``CuGraphCSC``.
    RuntimeError
        If aggregation method is not sum or mean (raised by the DGL path).
    ValueError
        If the graph type is otherwise unsupported.
    """
    if isinstance(graph, CuGraphCSC):
        raise NotImplementedError(
            "aggregate_and_concat_hetero is not supported for CuGraphCSC graphs yet."
        )
    if isinstance(graph, DGLGraph):
        return agg_concat_hetero_dgl(
            mesh_efeat, world_efeat, nfeat, graph, aggregation
        )
    if isinstance(graph, PyGData):
        # PyG path: stack mesh and world edge features (mesh first) and reuse
        # the homogeneous PyG implementation.
        stacked_efeat = torch.cat((mesh_efeat, world_efeat), dim=0)
        return agg_concat_pyg(stacked_efeat, nfeat, graph, aggregation)
    raise ValueError(f"Unsupported graph type: {type(graph)}")