Spaces:
Sleeping
Sleeping
| # SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. | |
| # SPDX-FileCopyrightText: All rights reserved. | |
| # SPDX-License-Identifier: Apache-2.0 | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import warnings | |
| from types import NoneType | |
| from typing import TypeAlias | |
| try: | |
| from dgl import DGLGraph | |
| except ImportError: | |
| warnings.warn( | |
| "Note: This only applies if you're using DGL.\n" | |
| "MeshGraphNet (DGL version) requires the DGL library.\n" | |
| "Install it with your preferred CUDA version from:\n" | |
| "https://www.dgl.ai/pages/start.html\n" | |
| ) | |
| DGLGraph: TypeAlias = NoneType | |
| import torch | |
| import torch_cluster | |
| import torch_geometric as pyg | |
| import torch_scatter | |
| from physicsnemo.models.meshgraphnet.meshgraphnet import MeshGraphNet | |
class Mesh_Reduced(torch.nn.Module):
    """PbGMR-GMUS architecture.

    A mesh-reduced architecture that combines encoding and decoding processors
    for physics prediction in reduced mesh space. Mesh node features are
    processed by a MeshGraphNet encoder, layer-normalised, and interpolated
    onto a set of pivotal positions with inverse-squared-distance k-NN
    interpolation. Decoding interpolates pivotal features back onto the mesh
    and runs a second MeshGraphNet.

    Parameters
    ----------
    input_dim_nodes : int
        Number of node features.
    input_dim_edges : int
        Number of edge features.
    output_decode_dim : int
        Number of decoding outputs (per node).
    output_encode_dim : int, optional
        Number of encoding outputs (per pivotal position), by default 3.
    processor_size : int, optional
        Number of message passing blocks, by default 15.
    num_layers_node_processor : int, optional
        Number of MLP layers for processing nodes in each message passing block, by default 2.
    num_layers_edge_processor : int, optional
        Number of MLP layers for processing edge features in each message passing block, by default 2.
    hidden_dim_processor : int, optional
        Hidden layer size for the message passing blocks, by default 128.
    hidden_dim_node_encoder : int, optional
        Hidden layer size for the node feature encoder, by default 128.
    num_layers_node_encoder : int, optional
        Number of MLP layers for the node feature encoder, by default 2.
    hidden_dim_edge_encoder : int, optional
        Hidden layer size for the edge feature encoder, by default 128.
    num_layers_edge_encoder : int, optional
        Number of MLP layers for the edge feature encoder, by default 2.
    hidden_dim_node_decoder : int, optional
        Hidden layer size for the node feature decoder, by default 128.
    num_layers_node_decoder : int, optional
        Number of MLP layers for the node feature decoder, by default 2.
    k : int, optional
        Number of nearest neighbours used by the k-NN interpolation: mesh
        nodes per pivotal position in ``encode``, and pivotal positions per
        mesh node in ``decode``. By default 3.
    aggregation : str, optional
        Message aggregation type, by default "mean".

    Notes
    -----
    Reference: Han, Xu, et al. "Predicting physics in mesh-reduced space with temporal attention."
    arXiv preprint arXiv:2201.09113 (2022).
    """

    def __init__(
        self,
        input_dim_nodes: int,
        input_dim_edges: int,
        output_decode_dim: int,
        output_encode_dim: int = 3,
        processor_size: int = 15,
        num_layers_node_processor: int = 2,
        num_layers_edge_processor: int = 2,
        hidden_dim_processor: int = 128,
        hidden_dim_node_encoder: int = 128,
        num_layers_node_encoder: int = 2,
        hidden_dim_edge_encoder: int = 128,
        num_layers_edge_encoder: int = 2,
        hidden_dim_node_decoder: int = 128,
        num_layers_node_decoder: int = 2,
        k: int = 3,
        aggregation: str = "mean",
    ):
        super().__init__()
        # NOTE(review): these flags are never read inside this class; they may
        # be inspected by external code, so verify before removing them.
        self.knn_encoder_already = False
        self.knn_decoder_already = False
        # Encoder: mesh node features -> ``output_encode_dim`` features per node.
        # Arguments are positional, so their order must match MeshGraphNet's
        # constructor signature.
        self.encoder_processor = MeshGraphNet(
            input_dim_nodes,
            input_dim_edges,
            output_encode_dim,
            processor_size,
            "relu",
            num_layers_node_processor,
            num_layers_edge_processor,
            hidden_dim_processor,
            hidden_dim_node_encoder,
            num_layers_node_encoder,
            hidden_dim_edge_encoder,
            num_layers_edge_encoder,
            hidden_dim_node_decoder,
            num_layers_node_decoder,
            aggregation,
        )
        # Decoder: features interpolated back onto the mesh (width
        # ``output_encode_dim``) -> ``output_decode_dim`` outputs per node.
        self.decoder_processor = MeshGraphNet(
            output_encode_dim,
            input_dim_edges,
            output_decode_dim,
            processor_size,
            "relu",
            num_layers_node_processor,
            num_layers_edge_processor,
            hidden_dim_processor,
            hidden_dim_node_encoder,
            num_layers_node_encoder,
            hidden_dim_edge_encoder,
            num_layers_edge_encoder,
            hidden_dim_node_decoder,
            num_layers_node_decoder,
            aggregation,
        )
        # Neighbour count forwarded to ``knn_interpolate`` by encode/decode.
        self.k = k
        self.PivotalNorm = torch.nn.LayerNorm(output_encode_dim)

    def knn_interpolate(
        self,
        x: torch.Tensor,
        pos_x: torch.Tensor,
        pos_y: torch.Tensor,
        batch_x: torch.Tensor | None = None,
        batch_y: torch.Tensor | None = None,
        k: int = 3,
        num_workers: int = 1,
    ):
        """Perform k-nearest neighbor interpolation.

        Each target position in ``pos_y`` receives the weighted average of the
        features of its ``k`` nearest source positions in ``pos_x``. The
        weights are the inverse squared distances.

        Parameters
        ----------
        x : torch.Tensor
            Input features to interpolate, one row per entry of ``pos_x``.
        pos_x : torch.Tensor
            Source positions.
        pos_y : torch.Tensor
            Target positions.
        batch_x : torch.Tensor, optional
            Batch indices for source positions, by default None.
        batch_y : torch.Tensor, optional
            Batch indices for target positions, by default None.
        k : int, optional
            Number of nearest neighbors to consider, by default 3.
        num_workers : int, optional
            Number of workers for parallel processing, by default 1.

        Returns
        -------
        torch.Tensor
            Interpolated features, one row per target, cast to float32.
        torch.Tensor
            Source indices (into ``pos_x`` / ``x``) of each neighbour pair.
        torch.Tensor
            Target indices (into ``pos_y``) of each neighbour pair.
        torch.Tensor
            Interpolation weights of each pair, with a trailing singleton dim.
        """
        # The neighbour search and the distance weights depend only on the
        # positions, so they are computed without autograd tracking.
        with torch.no_grad():
            assign_index = torch_cluster.knn(
                pos_x,
                pos_y,
                k,
                batch_x=batch_x,
                batch_y=batch_y,
                num_workers=num_workers,
            )
            # Row 0 indexes targets (pos_y), row 1 indexes sources (pos_x).
            y_idx, x_idx = assign_index[0], assign_index[1]
            diff = pos_x[x_idx] - pos_y[y_idx]
            squared_distance = (diff * diff).sum(dim=-1, keepdim=True)
            # The clamp avoids division by zero when a target coincides with
            # a source position.
            weights = 1.0 / torch.clamp(squared_distance, min=1e-16)

        # The feature aggregation must stay OUTSIDE no_grad so that gradients
        # flow back into ``x`` (e.g. into the encoder during training).
        num_targets = pos_y.size(0)
        numerator = torch_scatter.scatter(
            x[x_idx] * weights, y_idx, 0, dim_size=num_targets, reduce="sum"
        )
        # NOTE(review): a target with no neighbours would get 0/0 = NaN here.
        denominator = torch_scatter.scatter(
            weights, y_idx, 0, dim_size=num_targets, reduce="sum"
        )
        y = numerator / denominator
        return y.float(), x_idx, y_idx, weights

    def _batch_indices(self, graph, position_mesh, position_pivotal, device):
        """Build batched positions and per-graph indices for k-NN interpolation.

        The same ``position_mesh`` and ``position_pivotal`` are tiled once per
        graph in the batch, i.e. all graphs are assumed to share one mesh and
        one set of pivotal positions.

        Parameters
        ----------
        graph : Union[DGLGraph, pyg.data.Data]
            Batched input graph.
        position_mesh : torch.Tensor
            Mesh positions of a single graph.
        position_pivotal : torch.Tensor
            Pivotal positions of a single graph.
        device : torch.device
            Device on which the index tensors are created.

        Returns
        -------
        tuple of torch.Tensor
            ``(position_mesh_batch, batch_mesh, position_pivotal_batch,
            batch_pivotal)``.

        Raises
        ------
        ValueError
            If ``graph`` is neither a DGLGraph nor a ``pyg.data.Data``.
        """
        # Validate first so unsupported inputs raise ValueError rather than an
        # AttributeError on ``graph.batch_size``.
        if not isinstance(graph, (DGLGraph, pyg.data.Data)):
            raise ValueError(f"Unsupported graph type: {type(graph)}")

        batch_size = graph.batch_size
        nodes_index = torch.arange(batch_size, device=device)
        if isinstance(graph, DGLGraph):
            batch_mesh = nodes_index.repeat_interleave(graph.batch_num_nodes())
        else:
            batch_mesh = graph.batch

        position_mesh_batch = position_mesh.repeat(batch_size, 1)
        position_pivotal_batch = position_pivotal.repeat(batch_size, 1)
        # Graph index i repeated once per pivotal position of that graph.
        batch_pivotal = nodes_index.repeat_interleave(len(position_pivotal))
        return position_mesh_batch, batch_mesh, position_pivotal_batch, batch_pivotal

    def encode(self, x, edge_features, graph, position_mesh, position_pivotal):
        """Encode mesh features to pivotal space.

        Parameters
        ----------
        x : torch.Tensor
            Input node features.
        edge_features : torch.Tensor
            Edge features.
        graph : Union[DGLGraph, pyg.data.Data]
            Input graph.
        position_mesh : torch.Tensor
            Mesh positions.
        position_pivotal : torch.Tensor
            Pivotal positions.

        Returns
        -------
        torch.Tensor
            Encoded features in pivotal space.

        Raises
        ------
        ValueError
            If ``graph`` is of an unsupported type.
        """
        x = self.encoder_processor(x, edge_features, graph)
        x = self.PivotalNorm(x)
        (
            position_mesh_batch,
            batch_mesh,
            position_pivotal_batch,
            batch_pivotal,
        ) = self._batch_indices(graph, position_mesh, position_pivotal, x.device)
        # Mesh nodes are the sources, pivotal positions the targets.
        x, _, _, _ = self.knn_interpolate(
            x=x,
            pos_x=position_mesh_batch,
            pos_y=position_pivotal_batch,
            batch_x=batch_mesh,
            batch_y=batch_pivotal,
            k=self.k,
        )
        return x

    def decode(self, x, edge_features, graph, position_mesh, position_pivotal):
        """Decode pivotal features back to mesh space.

        Parameters
        ----------
        x : torch.Tensor
            Input features in pivotal space.
        edge_features : torch.Tensor
            Edge features.
        graph : Union[DGLGraph, pyg.data.Data]
            Input graph.
        position_mesh : torch.Tensor
            Mesh positions.
        position_pivotal : torch.Tensor
            Pivotal positions.

        Returns
        -------
        torch.Tensor
            Decoded features in mesh space.

        Raises
        ------
        ValueError
            If ``graph`` is of an unsupported type.
        """
        (
            position_mesh_batch,
            batch_mesh,
            position_pivotal_batch,
            batch_pivotal,
        ) = self._batch_indices(graph, position_mesh, position_pivotal, x.device)
        # Pivotal positions are the sources, mesh nodes the targets.
        x, _, _, _ = self.knn_interpolate(
            x=x,
            pos_x=position_pivotal_batch,
            pos_y=position_mesh_batch,
            batch_x=batch_pivotal,
            batch_y=batch_mesh,
            k=self.k,
        )
        x = self.decoder_processor(x, edge_features, graph)
        return x