# NOTE(review): the next three lines appear to be page-scrape residue, not module code.
# ArthurY's picture
# update source
# c3d0544
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ruff: noqa: S101,F722
from typing import List, Literal, Optional
import torch
import torch.nn as nn
from jaxtyping import Float
from torch import Tensor
from physicsnemo.models.figconvnet.components.encodings import SinusoidalEncoding
from physicsnemo.models.figconvnet.components.mlp import MLPBlock
from physicsnemo.models.figconvnet.components.reductions import (
REDUCTION_TYPES,
row_reduction,
)
from physicsnemo.models.figconvnet.geometries import PointFeatures
from physicsnemo.models.figconvnet.neighbor_ops import (
batched_neighbor_knn_search,
batched_neighbor_radius_search,
)
class PointFeatureTransform(nn.Module):
    """Lift a feature-only module so that it operates on ``PointFeatures``.

    The wrapped module is called on ``point_features.features``; the vertices
    are handed to the resulting ``PointFeatures`` as they are.
    """

    def __init__(
        self,
        feature_transform: nn.Module,
    ):
        """Store the module that ``forward`` applies to the feature tensor.

        Args:
            feature_transform: Module called with the ``features`` attribute
                of the incoming point features.
        """
        super().__init__()
        self.feature_transform = feature_transform

    def forward(self, point_features: PointFeatures["N C1"]) -> PointFeatures["N C2"]:
        """Return point features with transformed features and the same vertices."""
        transformed = self.feature_transform(point_features.features)
        return PointFeatures(point_features.vertices, transformed)
class PointFeatureCat(nn.Module):
    """Concatenate the features of two ``PointFeatures`` along the channel axis.

    The vertices of the first argument are reused for the result; the vertices
    of the second argument are not read.
    """

    def forward(
        self,
        point_features: PointFeatures["N C1"],
        point_features2: PointFeatures["N C2"],
    ) -> PointFeatures["N C3"]:
        """Return point features whose channels are those of both inputs.

        Args:
            point_features: Supplies the output vertices and the leading
                channels.
            point_features2: Supplies the trailing channels.

        Returns:
            ``PointFeatures`` built from ``point_features.vertices`` and the
            concatenated feature tensor.
        """
        # Concatenate on the last axis: that is the channel axis for both
        # unbatched (N, C) and batched (B, N, C) features, whereas a fixed
        # dim=1 would join batched features along the point axis.
        joined = torch.cat(
            [point_features.features, point_features2.features], dim=-1
        )
        return PointFeatures(point_features.vertices, joined)
class PointFeatureMLP(PointFeatureTransform):
    """Point-feature transform made of Linear -> nonlinearity -> Linear."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        hidden_channels: Optional[int] = None,
        multiplier: int = 2,
        nonlinearity: nn.Module = nn.GELU,
    ):
        """Build the two-layer MLP and register it as the feature transform.

        Args:
            in_channels: Input width of the first linear layer.
            out_channels: Output width of the second linear layer.
            hidden_channels: Width between the two linear layers. When None,
                ``multiplier * out_channels`` is used.
            multiplier: Factor used to derive the hidden width when
                ``hidden_channels`` is None.
            nonlinearity: Called with no arguments to create the activation
                placed between the two linear layers.
        """
        width = (
            multiplier * out_channels if hidden_channels is None else hidden_channels
        )
        # Layers are created in the same order as they are applied.
        layers = [
            nn.Linear(in_channels, width),
            nonlinearity(),
            nn.Linear(width, out_channels),
        ]
        PointFeatureTransform.__init__(self, nn.Sequential(*layers))
class PointFeatureConv(nn.Module):
    """Point convolution that aggregates neighbor features onto output points.

    For every output point the layer searches neighboring input points (radius
    or kNN search) and builds one edge feature per (neighbor, output point)
    pair by concatenating the neighbor's features, the output point's features
    and, optionally, their relative position (raw or positionally encoded).
    Edge features go through ``edge_transform_mlp``, are reduced per output
    point once for every entry in ``reductions``, concatenated, and finally
    passed through ``out_transform_mlp``.
    """

    def __init__(
        self,
        radius: float,
        edge_transform_mlp: Optional[nn.Module] = None,
        out_transform_mlp: Optional[nn.Module] = None,
        in_channels: int = 8,
        out_channels: int = 32,
        hidden_dim: Optional[int] = None,
        channel_multiplier: int = 2,
        use_rel_pos: bool = True,
        use_rel_pos_encode: bool = False,
        pos_encode_dim: int = 32,
        pos_encode_range: float = 4,
        reductions: Optional[List[REDUCTION_TYPES]] = None,
        downsample_voxel_size: Optional[float] = None,
        out_point_feature_type: Literal["provided", "downsample", "same"] = "same",
        provided_in_channels: Optional[int] = None,
        neighbor_search_vertices_scaler: Optional[Float[Tensor, "3"]] = None,
        neighbor_search_type: Literal["radius", "knn"] = "radius",
        radius_search_method: Literal["open3d", "warp"] = "warp",
        knn_k: Optional[int] = None,
    ):
        """Configure the neighbor search and build the edge/output MLPs.

        Args:
            radius: Search radius, used when ``neighbor_search_type`` is
                "radius". Also the upper bound for ``downsample_voxel_size``.
            edge_transform_mlp: Module applied to the concatenated edge
                features. ``forward`` reads its ``in_channels`` attribute.
                When None, an ``MLPBlock`` of matching width is created.
            out_transform_mlp: Module applied to the concatenated reductions.
                When None, an ``MLPBlock`` with
                ``out_channels * len(reductions)`` inputs is created.
            in_channels: Channels of the input point features.
            out_channels: Channels of the output point features.
            hidden_dim: Hidden width of the default MLPs. When None,
                ``channel_multiplier * out_channels`` is used.
            channel_multiplier: Factor used to derive ``hidden_dim``.
            use_rel_pos: If True (and ``use_rel_pos_encode`` is False), the
                relative position of each neighbor with respect to its output
                point (3 channels) is appended to the edge features.
            use_rel_pos_encode: If True, the positional encoding of that
                relative position (``pos_encode_dim * 3`` channels) is
                appended instead; this takes priority over ``use_rel_pos``.
            pos_encode_dim: First argument of ``SinusoidalEncoding``.
            pos_encode_range: ``data_range`` argument of
                ``SinusoidalEncoding``.
            reductions: Non-empty list or tuple of reductions applied per
                output point. When None, ``["mean"]`` is used.
            downsample_voxel_size: Voxel size handed to
                ``voxel_down_sample``; required for, and only allowed with,
                ``out_point_feature_type="downsample"``.
            out_point_feature_type: Where the output points come from:
                "provided" (passed to ``forward``), "downsample" (voxel
                down-sampled input) or "same" (the input points themselves).
            provided_in_channels: Channels of the provided output point
                features; required for, and only allowed with, "provided".
            neighbor_search_vertices_scaler: If not None, vertices are
                multiplied by this per-axis scale before the neighbor search,
                which yields an axis-aligned ellipsoidal neighborhood.
            neighbor_search_type: "radius" or "knn".
            radius_search_method: ``search_method`` passed to the radius
                search.
            knn_k: Number of neighbors; required when
                ``neighbor_search_type`` is "knn".

        Raises:
            AssertionError: If the arguments are inconsistent with
                ``out_point_feature_type``, ``reductions`` is empty or of the
                wrong type, or ``knn_k`` is missing for kNN search.
            ValueError: If ``downsample_voxel_size`` exceeds ``radius`` or
                ``neighbor_search_type`` is not "radius" or "knn".
        """
        super().__init__()
        if reductions is None:
            # Build a fresh list per instance; a list literal in the signature
            # would be one object shared by every instance using the default.
            reductions = ["mean"]
        assert isinstance(reductions, (tuple, list)) and len(reductions) > 0, (
            f"reductions must be a list or tuple of length > 0, got {reductions}"
        )
        if out_point_feature_type == "provided":
            assert downsample_voxel_size is None, (
                "downsample_voxel_size is only used for downsample"
            )
            assert provided_in_channels is not None, (
                "provided_in_channels must be provided for provided type"
            )
        elif out_point_feature_type == "downsample":
            assert downsample_voxel_size is not None, (
                "downsample_voxel_size must be provided for downsample"
            )
            assert provided_in_channels is None, (
                "provided_in_channels must be None for downsample type"
            )
        elif out_point_feature_type == "same":
            assert downsample_voxel_size is None, (
                "downsample_voxel_size is only used for downsample"
            )
            assert provided_in_channels is None, (
                "provided_in_channels must be None for same type"
            )
        if downsample_voxel_size is not None and downsample_voxel_size > radius:
            raise ValueError(
                f"downsample_voxel_size {downsample_voxel_size} must be <= radius {radius}"
            )

        self.reductions = reductions
        self.downsample_voxel_size = downsample_voxel_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_rel_pos = use_rel_pos
        self.use_rel_pos_encode = use_rel_pos_encode
        self.out_point_feature_type = out_point_feature_type
        self.neighbor_search_vertices_scaler = neighbor_search_vertices_scaler
        self.neighbor_search_type = neighbor_search_type
        self.radius_search_method = radius_search_method

        # A single attribute holds the search parameter of the active mode.
        if neighbor_search_type == "radius":
            self.radius_or_k = radius
        elif neighbor_search_type == "knn":
            assert knn_k is not None
            self.radius_or_k = knn_k
        else:
            raise ValueError(
                f"neighbor_search_type must be radius or knn, got {neighbor_search_type}"
            )

        self.positional_encoding = SinusoidalEncoding(
            pos_encode_dim, data_range=pos_encode_range
        )

        # Without a separate channel count for provided output features, the
        # output-side features are assumed to have the input's channel count.
        if provided_in_channels is None:
            provided_in_channels = in_channels
        if hidden_dim is None:
            hidden_dim = channel_multiplier * out_channels

        if edge_transform_mlp is None:
            # Edge = neighbor features + output-point features (+ position).
            edge_in_channels = in_channels + provided_in_channels
            if use_rel_pos_encode:
                edge_in_channels += pos_encode_dim * 3
            elif use_rel_pos:
                edge_in_channels += 3
            edge_transform_mlp = MLPBlock(
                in_channels=edge_in_channels,
                hidden_channels=hidden_dim,
                out_channels=out_channels,
            )
        self.edge_transform_mlp = edge_transform_mlp

        if out_transform_mlp is None:
            # Each reduction contributes out_channels channels.
            out_transform_mlp = MLPBlock(
                in_channels=out_channels * len(reductions),
                hidden_channels=hidden_dim,
                out_channels=out_channels,
            )
        self.out_transform_mlp = out_transform_mlp

    def __repr__(self):
        """Return a compact one-line summary of the layer configuration."""
        out_str = (
            f"{self.__class__.__name__}(in_channels={self.in_channels}"
            f" out_channels={self.out_channels}"
            f" search_type={self.neighbor_search_type}"
            f" reductions={self.reductions}"
        )
        if self.downsample_voxel_size is not None:
            out_str += f" down_voxel_size={self.downsample_voxel_size}"
        if self.use_rel_pos_encode:
            out_str += f" rel_pos_encode={self.use_rel_pos_encode}"
        out_str += ")"
        return out_str

    def forward(
        self,
        in_point_features: PointFeatures["B N C1"],
        out_point_features: Optional[PointFeatures["B M C1"]] = None,
        # in_weight: Optional[Float[Tensor, "B N"]] = None,  # noqa: F821
        neighbor_search_vertices_scaler: Optional[Float[Tensor, "3"]] = None,
    ) -> PointFeatures["B M C2"]:
        """Convolve the input point features onto the output points.

        Args:
            in_point_features: Source points and features.
            out_point_features: Target points and features. Must be given for
                the "provided" type and must be None otherwise, in which case
                the targets are the voxel down-sampled input ("downsample")
                or the input itself ("same").
            neighbor_search_vertices_scaler: Per-axis scale applied to the
                vertices for the neighbor search only. Ignored when a scaler
                was passed to the constructor.

        Returns:
            ``PointFeatures`` on the output vertices with features of shape
            ``(batch_size, num_points, channels)``.
        """
        if self.out_point_feature_type == "provided":
            assert out_point_features is not None, (
                "out_point_features must be provided for the provided type"
            )
        elif self.out_point_feature_type == "downsample":
            assert out_point_features is None
            out_point_features = in_point_features.voxel_down_sample(
                self.downsample_voxel_size
            )
        elif self.out_point_feature_type == "same":
            assert out_point_features is None
            out_point_features = in_point_features

        in_num_channels = in_point_features.num_channels
        out_num_channels = out_point_features.num_channels
        # The booleans act as 0/1 switches for the positional channels.
        assert (
            in_num_channels
            + out_num_channels
            + self.use_rel_pos_encode * self.positional_encoding.num_channels * 3
            + (not self.use_rel_pos_encode) * self.use_rel_pos * 3
            == self.edge_transform_mlp.in_channels
        ), (
            f"input features shape {in_point_features.features.shape} and {out_point_features.features.shape} does not match the edge_transform_mlp input features {self.edge_transform_mlp.in_channels}"
        )

        # Get the neighbors
        in_vertices = in_point_features.vertices
        out_vertices = out_point_features.vertices
        # The constructor-level scaler takes precedence over the call argument.
        if self.neighbor_search_vertices_scaler is not None:
            neighbor_search_vertices_scaler = self.neighbor_search_vertices_scaler
        if neighbor_search_vertices_scaler is not None:
            in_vertices = in_vertices * neighbor_search_vertices_scaler.to(in_vertices)
            out_vertices = out_vertices * neighbor_search_vertices_scaler.to(
                out_vertices
            )

        if self.neighbor_search_type == "knn":
            device = in_vertices.device
            neighbors_index = batched_neighbor_knn_search(
                in_vertices, out_vertices, self.radius_or_k
            )
            # B x M x K index, flattened to one entry per edge
            neighbors_index = neighbors_index.long().to(device).view(-1)
            # Every output point has exactly k neighbors, so the row splits
            # are 0, k, 2k, ..., (B * M) * k.
            neighbors_row_splits = (
                torch.arange(
                    0, out_vertices.shape[0] * out_vertices.shape[1] + 1, device=device
                )
                * self.radius_or_k
            )
            num_reps = self.radius_or_k
        elif self.neighbor_search_type == "radius":
            neighbors = batched_neighbor_radius_search(
                in_vertices,
                out_vertices,
                radius=self.radius_or_k,
                search_method=self.radius_search_method,
            )
            neighbors_index = neighbors.neighbors_index.long()
            neighbors_row_splits = neighbors.neighbors_row_splits
            # Neighbor count of each output point.
            num_reps = neighbors_row_splits[1:] - neighbors_row_splits[:-1]
        else:
            raise ValueError(
                f"neighbor_search_type must be radius or knn, got {self.neighbor_search_type}"
            )

        # Gather the neighbor (input-side) features, one row per edge.
        # reshape (unlike view) also accepts non-contiguous feature tensors.
        rep_in_features = in_point_features.features.reshape(-1, in_num_channels)[
            neighbors_index
        ]
        # repeat the self features using num_reps
        self_features = torch.repeat_interleave(
            out_point_features.features.reshape(-1, out_num_channels),
            num_reps,
            dim=0,
        )
        edge_features = [rep_in_features, self_features]

        if self.use_rel_pos or self.use_rel_pos_encode:
            # Relative positions use the original (unscaled) vertices.
            in_rep_vertices = in_point_features.vertices.reshape(-1, 3)[neighbors_index]
            self_vertices = torch.repeat_interleave(
                out_point_features.vertices.reshape(-1, 3), num_reps, dim=0
            )
            rel_pos = in_rep_vertices - self_vertices
            if self.use_rel_pos_encode:
                edge_features.append(self.positional_encoding(rel_pos))
            elif self.use_rel_pos:
                edge_features.append(rel_pos)

        edge_features = torch.cat(edge_features, dim=1)
        edge_features = self.edge_transform_mlp(edge_features)
        # if in_weight is not None:
        #     assert in_weight.shape[0] == in_point_features.features.shape[0]
        #     rep_weights = in_weight[neighbors_index]
        #     edge_features = edge_features * rep_weights.squeeze().unsqueeze(-1)

        # One reduced tensor per reduction, concatenated along channels.
        out_features = [
            row_reduction(edge_features, neighbors_row_splits, reduction=reduction)
            for reduction in self.reductions
        ]
        out_features = torch.cat(out_features, dim=-1)
        out_features = self.out_transform_mlp(out_features)
        # Convert back to the original shape
        out_features = out_features.view(
            out_point_features.batch_size,
            out_point_features.num_points,
            out_features.shape[-1],
        )
        return PointFeatures(out_point_features.vertices, out_features)
class PointFeatureConvBlock(nn.Module):
    """Residual block made of two ``PointFeatureConv`` layers.

    Computes ``act(norm2(conv2(act(norm1(conv1(x))))) + shortcut(y))`` where
    ``act`` is GELU, the norms are ``LayerNorm(out_channels)`` and ``y`` is
    the output-side point features (provided, down-sampled input, or the
    input itself, depending on ``out_point_feature_type``).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        radius: float,
        use_rel_pos: bool = True,
        use_rel_pos_encode: bool = False,
        pos_encode_dim: int = 32,
        reductions: Optional[List[REDUCTION_TYPES]] = None,
        downsample_voxel_size: Optional[float] = None,
        pos_encode_range: float = 4,
        out_point_feature_type: Literal["provided", "downsample", "same"] = "same",
        provided_in_channels: Optional[int] = None,
        neighbor_search_type: Literal["radius", "knn"] = "radius",
        knn_k: Optional[int] = None,
    ):
        """Build both convolutions, the norms and the shortcut.

        Args:
            in_channels: Channels of the input point features.
            out_channels: Channels produced by both convolutions.
            radius: Passed to both convolutions.
            use_rel_pos: Passed to both convolutions.
            use_rel_pos_encode: Passed to the first convolution.
            pos_encode_dim: Passed to the first convolution.
            reductions: Reductions used by both convolutions. When None,
                ``["mean"]`` is used.
            downsample_voxel_size: Passed to the first convolution and used
                in ``forward`` to down-sample the shortcut input.
            pos_encode_range: Passed to both convolutions.
            out_point_feature_type: Output point source of the first
                convolution; the second one always uses "same".
            provided_in_channels: Channels of the provided output point
                features; passed to the first convolution and used as the
                shortcut input width for the "provided" type.
            neighbor_search_type: Passed to both convolutions.
            knn_k: Passed to both convolutions.
        """
        super().__init__()
        if reductions is None:
            # Build a fresh list per instance; a list literal in the signature
            # would be one object shared by every block using the default.
            reductions = ["mean"]
        self.downsample_voxel_size = downsample_voxel_size
        self.out_point_feature_type = out_point_feature_type
        self.conv1 = PointFeatureConv(
            in_channels=in_channels,
            out_channels=out_channels,
            radius=radius,
            use_rel_pos=use_rel_pos,
            use_rel_pos_encode=use_rel_pos_encode,
            pos_encode_dim=pos_encode_dim,
            reductions=reductions,
            downsample_voxel_size=downsample_voxel_size,
            pos_encode_range=pos_encode_range,
            out_point_feature_type=out_point_feature_type,
            provided_in_channels=provided_in_channels,
            neighbor_search_type=neighbor_search_type,
            knn_k=knn_k,
        )
        # NOTE(review): conv2 is not given use_rel_pos_encode / pos_encode_dim,
        # so it falls back to PointFeatureConv's defaults for them - confirm
        # that this is intended.
        self.conv2 = PointFeatureConv(
            in_channels=out_channels,
            out_channels=out_channels,
            radius=radius,
            use_rel_pos=use_rel_pos,
            pos_encode_range=pos_encode_range,
            reductions=reductions,
            out_point_feature_type="same",
            neighbor_search_type=neighbor_search_type,
            knn_k=knn_k,
        )
        self.norm1 = PointFeatureTransform(nn.LayerNorm(out_channels))
        self.norm2 = PointFeatureTransform(nn.LayerNorm(out_channels))

        # The shortcut maps the output-side features to out_channels: an MLP
        # from provided_in_channels for provided points, otherwise identity
        # when the widths already match and an MLP from in_channels if not.
        if out_point_feature_type == "provided":
            self.shortcut = PointFeatureMLP(
                in_channels=provided_in_channels, out_channels=out_channels
            )
        elif in_channels == out_channels:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = PointFeatureMLP(in_channels, out_channels)
        self.nonlinear = PointFeatureTransform(nn.GELU())

    def forward(
        self,
        in_point_features: PointFeatures["B N C1"],
        out_point_features: Optional[PointFeatures["B M C2"]] = None,
    ) -> PointFeatures["B N C2"]:
        """Run the two convolutions and add the shortcut branch.

        Args:
            in_point_features: Source points and features.
            out_point_features: Target points and features; must be given for
                the "provided" type and must be None otherwise.

        Returns:
            The activated sum of the convolution branch and the shortcut
            applied to the output-side point features.
        """
        if self.out_point_feature_type == "provided":
            assert out_point_features is not None, (
                "out_point_features must be provided for the provided type"
            )
            out = self.conv1(in_point_features, out_point_features)
        elif self.out_point_feature_type == "downsample":
            assert out_point_features is None
            # Down-sampled here for the shortcut branch only; conv1 is called
            # without output points and performs its own down-sampling.
            # NOTE(review): presumably both calls yield the same points -
            # confirm voxel_down_sample is deterministic.
            out_point_features = in_point_features.voxel_down_sample(
                self.downsample_voxel_size
            )
            out = self.conv1(in_point_features)
        elif self.out_point_feature_type == "same":
            assert out_point_features is None
            out_point_features = in_point_features
            out = self.conv1(in_point_features)
        out = self.nonlinear(self.norm1(out))
        out = self.norm2(self.conv2(out))
        return self.nonlinear(out + self.shortcut(out_point_features))