qbhf2's picture
added NvidiaWarp and GarmentCode repos
66c9c8a
# Copyright (c) 2023 NVIDIA CORPORATION. All rights reserved.
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Node creating a grid geometry simulated with a wave equation solver."""
from math import sqrt
import traceback
import numpy as np
import omni.graph.core as og
import omni.timeline
import warp as wp
import omni.warp.nodes
from omni.warp.nodes._impl.kernels.grid_create import grid_create_launch_kernel
from omni.warp.nodes.ogn.OgnWaveSolveDatabase import OgnWaveSolveDatabase
# Passed as the `active` argument of the `NodeTimer` scopes opened in
# `compute()`; set to True to turn those timers on.
PROFILING = False
# Kernels
# -----------------------------------------------------------------------------
@wp.func
def sample_height(
    height_map: wp.array(dtype=float),
    x: int,
    z: int,
    point_count_x: int,
    point_count_z: int,
):
    # Reads the height stored for the grid point (x, z) from the flattened
    # height map. Coordinates falling outside of the grid are snapped onto its
    # border, so neighbour lookups made from an edge point read the edge value.
    col = wp.clamp(x, 0, point_count_x - 1)
    row = wp.clamp(z, 0, point_count_z - 1)

    # Points are laid out row after row, `point_count_x` entries per row.
    return height_map[row * point_count_x + col]
@wp.func
def laplacian(
    height_map: wp.array(dtype=float),
    x: int,
    z: int,
    point_count_x: int,
    point_count_z: int,
):
    # Sums the second differences of the height along x and along z at the
    # grid point (x, z). The result is expressed in grid units; no scaling by
    # the cell size is applied here.
    # See https://en.wikipedia.org/wiki/Wave_equation.
    center = sample_height(height_map, x, z, point_count_x, point_count_z)

    # Direct neighbours along x.
    next_x = sample_height(height_map, x + 1, z, point_count_x, point_count_z)
    prev_x = sample_height(height_map, x - 1, z, point_count_x, point_count_z)

    # Direct neighbours along z.
    next_z = sample_height(height_map, x, z + 1, point_count_x, point_count_z)
    prev_z = sample_height(height_map, x, z - 1, point_count_x, point_count_z)

    second_diff_x = next_x - center * 2.0 + prev_x
    second_diff_z = next_z - center * 2.0 + prev_z
    return second_diff_x + second_diff_z
@wp.kernel(enable_backward=False)
def displace_kernel(
    point_count_x: int,
    center_x: float,
    center_z: float,
    radius: float,
    amplitude: float,
    time: float,
    out_height_map_0: wp.array(dtype=float),
    out_height_map_1: wp.array(dtype=float),
):
    # Overwrites the height of every grid point lying strictly within `radius`
    # (in grid units) of the point (center_x, center_z) with
    # `amplitude * sin(time)`. Both height map buffers receive the same value;
    # points outside of the disc are left untouched.
    point_index = wp.tid()

    # Recover the 2D grid coordinates from the flattened point index.
    col = point_index % point_count_x
    row = point_index // point_count_x

    offset_x = float(col) - center_x
    offset_z = float(row) - center_z
    dist_sq = float(offset_x * offset_x + offset_z * offset_z)
    radius_sq = radius * radius

    if dist_sq < radius_sq:
        height = amplitude * wp.sin(time)
        out_height_map_0[point_index] = height
        out_height_map_1[point_index] = height
@wp.kernel(enable_backward=False)
def simulate_kernel(
    point_count_x: int,
    point_count_z: int,
    inv_cell_size: float,
    speed: float,
    damping: float,
    dt: float,
    height_map_1: wp.array(dtype=float),
    out_height_map_0: wp.array(dtype=float),
):
    # Advances the height of one grid point by one time step.
    #
    # `height_map_1` holds the current heights and is only read.
    # `out_height_map_0` holds the previous heights on entry and is overwritten
    # with the new ones, each thread touching only its own entry.
    point_index = wp.tid()

    # Recover the 2D grid coordinates from the flattened point index.
    col = point_index % point_count_x
    row = point_index // point_count_x

    # Laplacian of the current heights, rescaled from grid units using the
    # inverse cell size.
    inv_cell_size_sq = inv_cell_size * inv_cell_size
    lap = laplacian(height_map_1, col, row, point_count_x, point_count_z)
    lap = lap * inv_cell_size_sq

    # Integrate and write the result in the 'previous' height map buffer since
    # it will be then swapped to become the 'current' one.
    previous = out_height_map_0[point_index]
    current = height_map_1[point_index]
    accel = lap * speed - (current - previous) * damping
    out_height_map_0[point_index] = current * 2.0 - previous + accel * dt * dt
@wp.kernel(enable_backward=False)
def update_mesh_kernel(
    height_map: wp.array(dtype=float),
    out_points: wp.array(dtype=wp.vec3),
):
    # Copies the height of each grid point into the second component of the
    # matching mesh point, keeping its first and third components as they are.
    point_index = wp.tid()

    point = out_points[point_index]
    out_points[point_index] = wp.vec3(
        point[0],
        height_map[point_index],
        point[2],
    )
# Internal State
# ------------------------------------------------------------------------------
class InternalState:
    """Per-node data kept across evaluations of the wave solver.

    Holds the two height map buffers used by the simulation, the time of the
    last evaluation, a validity flag, and a tracker watching the `size` and
    `cellSize` attributes.
    """

    def __init__(self) -> None:
        # Height map buffers; allocated by `initialize()`.
        self.height_map_0 = None
        self.height_map_1 = None

        # Time recorded by the last evaluation, used to detect rewinds.
        self.time = 0.0

        # Whether the state can be used as is.
        self.is_valid = False

        # Attributes for which a change triggers a reinitialization.
        self.attr_tracking = omni.warp.nodes.AttrTracking(
            ("size", "cellSize"),
        )

    def needs_initialization(self, db: OgnWaveSolveDatabase) -> bool:
        """Tells whether `initialize()` must run before the state is used.

        That is the case when the state is flagged as invalid, when one of the
        tracked attributes has changed, or when the input time is lower than
        the recorded one (the simulation is being rewound). The checks are
        made in that order and stop at the first one that matches.
        """
        return bool(
            not self.is_valid
            or self.attr_tracking.have_attrs_changed(db)
            or db.inputs.time < self.time
        )

    def initialize(
        self,
        db: OgnWaveSolveDatabase,
        dims: np.ndarray,
    ) -> bool:
        """Allocates the height maps and builds the output grid mesh.

        Args:
            db: Database of the node being evaluated.
            dims: Number of grid divisions, forwarded to the grid creation
                kernel.

        Returns:
            Always True.
        """
        mesh = db.outputs.mesh
        point_count = omni.warp.nodes.mesh_get_point_count(mesh)

        # Two zero-filled buffers, one height per mesh point, used as a double
        # buffer by the simulation.
        buffer_0 = wp.zeros(point_count, dtype=float)
        buffer_1 = wp.zeros(point_count, dtype=float)

        # Fill the output mesh with the grid's geometry and topology.
        grid_create_launch_kernel(
            omni.warp.nodes.mesh_get_points(mesh),
            omni.warp.nodes.mesh_get_face_vertex_counts(mesh),
            omni.warp.nodes.mesh_get_face_vertex_indices(mesh),
            omni.warp.nodes.mesh_get_normals(mesh),
            omni.warp.nodes.mesh_get_uvs(mesh),
            db.inputs.size.tolist(),
            dims.tolist(),
            update_topology=True,
        )

        # Only commit to the instance once everything above went through.
        self.height_map_0, self.height_map_1 = buffer_0, buffer_1
        self.attr_tracking.update_state(db)

        return True
# Compute
# ------------------------------------------------------------------------------
def displace(
    db: OgnWaveSolveDatabase,
    dims: np.ndarray,
    cell_size: np.ndarray,
) -> None:
    """Pushes the collider's footprint into the height maps.

    The collider is approximated by a sphere whose diameter is the largest
    side of its axis-aligned world extent. When that sphere crosses the grid's
    plane, the grid points covered by the resulting disc get their height
    overwritten in both height map buffers; otherwise nothing is launched.

    Args:
        db: Database of the node being evaluated.
        dims: Number of grid divisions along each grid axis.
        cell_size: Size of a grid cell along each grid axis.
    """

    def clamp_to_grid(value, upper):
        # Keeps a grid coordinate within [0, upper].
        return max(0, min(upper, value))

    state = db.internal_state

    # World transform of the grid mesh.
    grid_xform = omni.warp.nodes.bundle_get_world_xform(db.outputs.mesh)

    # World transform and axis-aligned world extent of the collider mesh.
    collider = db.inputs.collider
    collider_xform = omni.warp.nodes.bundle_get_world_xform(collider)
    extent = omni.warp.nodes.mesh_get_world_extent(
        collider,
        axis_aligned=True,
    )

    # Express the collider's position in the grid's object space, using
    # homogeneous coordinates.
    world_pos = np.pad(collider_xform[3][:3], (0, 1), constant_values=1)
    world_to_grid = np.linalg.inv(grid_xform).T
    local_pos = np.dot(world_to_grid, world_pos)

    # Half of the largest side of the extent stands in for the radius.
    radius = np.amax(extent[1] - extent[0]) * 0.5

    # Grid coordinates around which the displacement is applied, kept within
    # the grid's bounds.
    center_x = (dims[0] + 1) * 0.5 - float(local_pos[0]) / cell_size[0]
    center_z = (dims[1] + 1) * 0.5 - float(local_pos[2]) / cell_size[1]
    center_x = clamp_to_grid(center_x, dims[0])
    center_z = clamp_to_grid(center_z, dims[1])

    # Squared radius of the disc cut by the grid's plane through the sphere.
    # It is only positive when the collider reaches the grid.
    contact_radius_sq = (radius**2) - (abs(local_pos[1]) ** 2)
    if not contact_radius_sq > 0:
        return

    # Convert the contact radius to grid units using the average cell size.
    uniform_cell_size = (cell_size[0] + cell_size[1]) * 0.5
    radius_in_cells = sqrt(contact_radius_sq) / uniform_cell_size

    wp.launch(
        kernel=displace_kernel,
        dim=omni.warp.nodes.mesh_get_point_count(db.outputs.mesh),
        inputs=[
            dims[0] + 1,
            center_x,
            center_z,
            radius_in_cells,
            db.inputs.amplitude,
            db.inputs.time,
        ],
        outputs=[
            state.height_map_0,
            state.height_map_1,
        ],
    )
def simulate(
    db: OgnWaveSolveDatabase,
    dims: np.ndarray,
    cell_size: np.ndarray,
    sim_dt: float,
) -> None:
    """Solves the wave simulation for one time step.

    Launches `simulate_kernel` over every mesh point, then swaps the two
    height map buffers of the internal state so that `height_map_1` refers to
    the freshly computed heights.

    Args:
        db: Database of the node being evaluated.
        dims: Number of grid divisions along each grid axis.
        cell_size: Size of a grid cell along each grid axis.
        sim_dt: Duration of the time step, forwarded to the kernel's `dt`
            parameter.
    """
    state = db.internal_state

    # The kernel works with a single cell size, so average both axes.
    cell_size_uniform = (cell_size[0] + cell_size[1]) * 0.5

    wp.launch(
        kernel=simulate_kernel,
        dim=omni.warp.nodes.mesh_get_point_count(db.outputs.mesh),
        inputs=[
            # Point counts are one more than the number of divisions.
            dims[0] + 1,
            dims[1] + 1,
            1.0 / cell_size_uniform,
            db.inputs.speed,
            db.inputs.damping,
            sim_dt,
            state.height_map_1,
        ],
        outputs=[
            state.height_map_0,
        ],
    )

    # Swap the height map buffers: the kernel wrote the new heights into the
    # 'previous' buffer, which now becomes the 'current' one.
    state.height_map_0, state.height_map_1 = (
        state.height_map_1,
        state.height_map_0,
    )
def update_mesh(db: OgnWaveSolveDatabase) -> None:
    """Writes the heights from `height_map_1` into the output mesh's points."""
    mesh = db.outputs.mesh
    heights = db.internal_state.height_map_1

    # One thread per mesh point.
    wp.launch(
        kernel=update_mesh_kernel,
        dim=omni.warp.nodes.mesh_get_point_count(mesh),
        inputs=[heights],
        outputs=[omni.warp.nodes.mesh_get_points(mesh)],
    )
def compute(db: OgnWaveSolveDatabase) -> None:
    """Evaluates the node.

    Creates the output mesh bundle, then either (re)initializes the internal
    state or runs one step of displacement, simulation and mesh update. The
    input time is recorded at the end of the evaluation.
    """
    db.outputs.mesh.changes().activate()

    # Nothing can be done without a valid output bundle.
    if not db.outputs.mesh.valid:
        return

    state = db.internal_state

    # Number of grid divisions along each axis, truncated to integers.
    dims = (db.inputs.size / db.inputs.cellSize).astype(int)

    # Topology counts of the grid: one face per cell, four vertices per face,
    # and one more point than divisions along each axis.
    face_count = dims[0] * dims[1]
    vertex_count = face_count * 4
    point_count = (dims[0] + 1) * (dims[1] + 1)

    # Create a new geometry mesh within the output bundle.
    omni.warp.nodes.mesh_create_bundle(
        db.outputs.mesh,
        point_count,
        vertex_count,
        face_count,
        xform=db.inputs.transform,
        create_normals=True,
        create_uvs=True,
    )

    just_initialized = state.needs_initialization(db)
    if just_initialized and not state.initialize(db, dims):
        return

    # The simulation is skipped on the evaluation that initializes the state.
    if not just_initialized:
        # Duration of one simulation step, taken from the timeline's rate.
        ticks_per_second = omni.timeline.get_timeline_interface().get_ticks_per_second()
        sim_dt = 1.0 / ticks_per_second

        # Actual size of each cell given the overall grid size and the
        # truncated number of divisions.
        cell_size = db.inputs.size / dims

        if db.inputs.collider.valid:
            # Deform the grid where the collider is in contact with it.
            with omni.warp.nodes.NodeTimer("displace", db, active=PROFILING):
                displace(db, dims, cell_size)

        # Propagate the ripples using the wave equation.
        with omni.warp.nodes.NodeTimer("simulate", db, active=PROFILING):
            simulate(db, dims, cell_size, sim_dt)

        # Push the resulting heights to the mesh points.
        with omni.warp.nodes.NodeTimer("update_mesh", db, active=PROFILING):
            update_mesh(db)

    # Remember the current time so that a later rewind can be detected.
    state.time = db.inputs.time
# Node Entry Point
# ------------------------------------------------------------------------------
class OgnWaveSolve:
    """Node entry point for the wave solver."""

    @staticmethod
    def internal_state() -> InternalState:
        """Returns a fresh `InternalState` instance.

        NOTE(review): presumably called by OmniGraph to create the per-node
        state exposed as `db.internal_state` — confirm against the OmniGraph
        documentation.
        """
        return InternalState()

    @staticmethod
    def compute(db: OgnWaveSolveDatabase) -> None:
        """Runs the module-level `compute()` on the `cuda:0` device.

        Any exception raised along the way, including a failure to acquire
        the device, is logged on the node with its traceback and flags the
        internal state as invalid. On success, the state is flagged as valid
        and the `execOut` output is enabled.
        """
        try:
            # The device lookup is part of the guarded region so that a
            # missing or unusable CUDA device goes through the same error
            # path as any other evaluation failure.
            device = wp.get_device("cuda:0")
            with wp.ScopedDevice(device):
                compute(db)
        except Exception:
            db.log_error(traceback.format_exc())
            db.internal_state.is_valid = False
            return

        db.internal_state.is_valid = True

        # Fire the execution for the downstream nodes.
        db.outputs.execOut = og.ExecutionAttributeState.ENABLED