# physctrl: src/utils/physics.py
# Last commit 4724018 ("update") by chenwang.
import os
import h5py
import torch
import torch.nn.functional as Fn
import numpy as np
import json
class DeformLoss(torch.nn.Module):
    """Physics-consistency loss on per-particle deformation gradients.

    For each frame ``t`` the particles' positions, velocities, masses and
    affine matrices ``C`` are transferred to a background grid (P2G) and the
    velocity gradient is read back at the particles (G2P) using quadratic
    B-spline weights, as in an MPM step. The deformation gradient is then
    advanced as ``F_pred = (I + dT * grad_v) @ F[t]`` and compared with
    ``F[t + 1]`` under an L1 loss.

    ``forward`` evaluates every (sample, step) pair at once (one grid per
    pair); ``forward_sequential`` computes the same quantity one pair at a
    time with a single grid.
    """

    def __init__(self, device="cuda", num_particles=2048, dt=0.0417,
                 grid_lim=10, grid_size=125, density=1000):
        """Store simulation constants (defaults match the original values).

        Args:
            device: device on which ``I33`` is created. The forward passes
                allocate their grids on the device of their inputs.
            num_particles: leading size of ``I33`` (stored as ``self.N``).
                The forward passes take the particle count from the inputs.
            dt: time step between consecutive simulation frames.
            grid_lim: extent of the grid along each axis; the cell size is
                ``grid_lim / grid_size``.
            grid_size: number of grid cells per axis.
            density: material density; particle mass is ``density * vol``.
        """
        super().__init__()
        self.device = device
        self.N = num_particles
        # Kept for backward compatibility; the forward passes broadcast a
        # 3x3 identity instead, so inputs need not contain exactly N particles.
        self.I33 = torch.eye(3, device=self.device).unsqueeze(0).repeat(self.N, 1, 1)
        self.dT = dt
        self.grid_lim = grid_lim
        self.grid_size = grid_size
        self.dx = self.grid_lim / self.grid_size
        self.inv_dx = 1 / self.dx
        self.density = density

    def _stencil(self, particle_x):
        """Return quadratic B-spline stencil data for positions ``(..., 3)``.

        Returns ``(base_pos, fx, w, dw)``:
          * ``base_pos``: int index of the lowest stencil node per axis;
          * ``fx``: particle offset from that node, in cell units;
          * ``w`` / ``dw``: per-axis weights and their derivatives with shape
            ``(..., 3, 3)``, indexed ``[..., axis, node]``.
        """
        grid_pos = particle_x * self.inv_dx
        base_pos = (grid_pos - 0.5).int()
        fx = grid_pos - base_pos
        w = torch.stack([0.5 * ((1.5 - fx) ** 2),
                         0.75 - ((fx - 1) ** 2),
                         0.5 * ((fx - 0.5) ** 2)], dim=-1)
        dw = torch.stack([fx - 1.5, -2 * (fx - 1), fx - 0.5], dim=-1)
        return base_pos, fx, w, dw

    def _advance_F(self, particle_x, particle_v, particle_F, particle_C,
                   particle_mass, dT):
        """Predict next-frame deformation gradients with one P2G/G2P round trip.

        All inputs share two leading dims ``(M, N)``: ``M`` independent
        particle sets, each transferred to its own grid, of ``N`` particles.
        ``particle_x``, ``particle_v`` and ``particle_mass`` are ``(M, N, 3)``;
        ``particle_F`` and ``particle_C`` are ``(M, N, 3, 3)``.

        Returns ``(I + dT * grad_v) @ particle_F`` of shape ``(M, N, 3, 3)``.

        NOTE(review): stencil nodes are assumed to lie inside the
        ``grid_size ** 3`` grid. Flat indices are not bounds-checked, so a
        particle too close to the grid border can alias into a neighbouring
        row instead of raising.
        """
        num_sets = particle_x.shape[0]
        gs = self.grid_size
        device, dtype = particle_x.device, particle_x.dtype
        base_pos, fx, w, dw = self._stencil(particle_x)
        base = base_pos.long()
        nodes = [(i, j, k) for i in range(3) for j in range(3) for k in range(3)]

        def flat_index(i, j, k):
            # Row-major flattened grid index of stencil node (i, j, k).
            return ((base[..., 0] + i) * gs + (base[..., 1] + j)) * gs + (base[..., 2] + k)

        # P2G: scatter mass and affine momentum m * (v + C @ dpos) to the grid.
        grid_m = torch.zeros((num_sets, gs ** 3), device=device, dtype=dtype)
        grid_v = torch.zeros((num_sets, gs ** 3, 3), device=device, dtype=dtype)
        for i, j, k in nodes:
            flat_idx = flat_index(i, j, k)
            weight = w[..., 0, i] * w[..., 1, j] * w[..., 2, k]
            offset = torch.tensor([i, j, k], device=device, dtype=dtype)
            dpos = (offset - fx) * self.dx
            momentum = weight.unsqueeze(-1) * particle_mass * (
                particle_v + (particle_C @ dpos.unsqueeze(-1)).squeeze(-1))
            grid_v = grid_v.scatter_add(1, flat_idx.unsqueeze(-1).repeat(1, 1, 3), momentum)
            grid_m = grid_m.scatter_add(1, flat_idx, weight * particle_mass[..., 0])

        # Momentum -> velocity; (near-)empty cells are divided by 1 instead.
        grid_m = torch.where(grid_m > 1e-15, grid_m, torch.ones_like(grid_m))
        grid_v = grid_v / grid_m.unsqueeze(-1)

        # G2P: grad_v = sum over nodes of v_node (outer) grad(weight).
        grad_v = torch.zeros_like(particle_F)
        for i, j, k in nodes:
            grad_w = torch.stack([dw[..., 0, i] * w[..., 1, j] * w[..., 2, k],
                                  w[..., 0, i] * dw[..., 1, j] * w[..., 2, k],
                                  w[..., 0, i] * w[..., 1, j] * dw[..., 2, k]], dim=-1) * self.inv_dx
            v_node = grid_v.gather(1, flat_index(i, j, k).unsqueeze(-1).repeat(1, 1, 3))
            grad_v = grad_v + v_node.unsqueeze(-1) @ grad_w.unsqueeze(-2)

        identity = torch.eye(3, device=grad_v.device, dtype=grad_v.dtype)
        return (identity + grad_v * dT) @ particle_F

    def forward_sequential(self, x, vol, F, C, frame_interval=2, norm_fac=5, v=None):
        """Same loss as :meth:`forward`, one (sample, step) pair at a time.

        Allocates a single grid per iteration instead of one per
        (sample, step), trading speed for peak memory. Arguments are as in
        :meth:`forward`.

        Returns:
            Sum over samples and steps of the per-step mean L1 error, divided
            by the batch size (``0.0`` if there are no steps).
        """
        # Denormalize x; the time step spans `frame_interval` simulation frames.
        if norm_fac > 0:
            x = x * 2 + norm_fac
        dT = self.dT * frame_interval
        start_t = 1 if frame_interval == 1 else 0
        end_t = x.shape[1] - 2
        loss = 0
        for b in range(x.shape[0]):
            particle_mass = (self.density * vol[b]).unsqueeze(-1).repeat(1, 3)
            for t in range(start_t, end_t):
                if v is not None:
                    particle_v = v[b, t + 1]
                else:
                    particle_v = (x[b, t + 2] - x[b, t]) / (2 * dT)
                # Leading set dimension of size 1 for the shared transfer helper.
                F_pred = self._advance_F(
                    x[b, t].unsqueeze(0),
                    particle_v.unsqueeze(0),
                    F[b, t].reshape(1, -1, 3, 3),
                    C[b, t].reshape(1, -1, 3, 3),
                    particle_mass.unsqueeze(0),
                    dT,
                )
                loss = loss + Fn.l1_loss(F_pred, F[b, t + 1].reshape(1, -1, 3, 3))
        return loss / x.shape[0]

    def forward(self, x, vol, F, C, frame_interval=2, norm_fac=5, v=None):
        """Deformation-gradient loss with all (sample, step) pairs batched.

        Args:
            x: particle positions indexed ``[batch, frame, particle, xyz]``;
                when ``norm_fac > 0`` they are denormalized as
                ``x * 2 + norm_fac`` first.
            vol: per-particle volumes, ``[batch, particle]``.
            F: deformation gradients, reshapable to
                ``[batch, frame, particle, 3, 3]``.
            C: affine velocity matrices, same layout as ``F``.
            frame_interval: simulation frames between consecutive samples;
                the time step is ``self.dT * frame_interval``. With 1, the
                first frame is skipped.
            norm_fac: denormalization offset (disabled when ``<= 0``).
            v: optional particle velocities, same layout as ``x``. When
                omitted, the velocity at frame ``t + 1`` is the central
                difference ``(x[t + 2] - x[t]) / (2 * dT)``.

        Returns:
            Mean L1 error between predicted and true next-frame ``F``,
            multiplied by the number of steps. This equals the per-step
            losses summed and averaged over the batch, matching
            ``forward_sequential``. Returns zero when there are no steps.
        """
        if norm_fac > 0:
            x = x * 2 + norm_fac
        dT = self.dT * frame_interval
        start_t = 1 if frame_interval == 1 else 0
        end_t = x.shape[1] - 2
        steps = end_t - start_t
        if steps <= 0:
            # Too few frames to form a (t, t + 1, t + 2) triple.
            return x.new_zeros(())
        num_particles = x.shape[2]
        M = x.shape[0] * steps

        def frames(tensor, offset, *tail):
            # Frames [start_t + offset, end_t + offset) flattened to (M, N, *tail).
            return tensor[:, start_t + offset:end_t + offset].reshape(M, num_particles, *tail)

        particle_x = frames(x, 0, 3)
        if v is not None:
            particle_v = frames(v, 1, 3)
        else:
            particle_v = (frames(x, 2, 3) - frames(x, 0, 3)) / (2 * dT)
        particle_F = frames(F, 0, 3, 3)
        particle_F_next = frames(F, 1, 3, 3)
        particle_C = frames(C, 0, 3, 3)
        # Same per-particle volume for every step of a sample.
        vol = vol.unsqueeze(1).repeat(1, steps, 1).reshape(M, num_particles)
        particle_mass = (self.density * vol).unsqueeze(-1).repeat(1, 1, 3)

        F_pred = self._advance_F(particle_x, particle_v, particle_F, particle_C,
                                 particle_mass, dT)
        return Fn.l1_loss(F_pred, particle_F_next) * steps
def loss_momentum(x, vol, force, drag_pt_num, start_frame=1, frame_interval=2,
                  norm_fac=5, v=None, density=1000, dt=0.0417):
    """MSE between total particle momentum and accumulated external impulse.

    For each interior frame ``t = 1 .. T-2`` the total momentum
    ``sum_n density * vol_n * v_n`` is compared with
    ``force * drag_pt_num * (t * dt * frame_interval + start_frame * dt)``.

    Args:
        x: particle positions, assumed ``[batch, frame, particle, xyz]``;
            denormalized as ``x * 2 + norm_fac`` when ``norm_fac > 0``.
        vol: per-particle volumes, ``[batch, particle]``.
        force: applied force per batch element (presumably ``[batch, 3]``
            -- TODO confirm against caller).
        drag_pt_num: per-batch multiplier on ``force`` (presumably
            ``[batch, 1]`` -- TODO confirm against caller).
        start_frame: number of un-subsampled ``dt`` steps of impulse added
            on top of the elapsed-time term.
        frame_interval: simulation frames between consecutive samples.
        norm_fac: denormalization offset (disabled when ``<= 0``).
        v: optional particle velocities with the same layout as ``x``; when
            omitted, central differences of ``x`` are used.
        density: material density.
        dt: time step between consecutive simulation frames.

    Returns:
        Scalar MSE loss tensor.
    """
    if norm_fac > 0:
        x = x * 2 + norm_fac
    # Consecutive samples are `frame_interval` simulation frames apart.
    dt = dt * frame_interval
    if v is not None:
        v_curr = v[:, 1:-1]
    else:
        # Central difference as the mean of backward and forward differences.
        v_backward = x[:, 1:-1] - x[:, :-2]
        v_forward = x[:, 2:] - x[:, 1:-1]
        v_curr = (v_backward + v_forward) / (2 * dt)
    # Total momentum per frame: sum over particles of mass * velocity.
    p_int = density * vol.unsqueeze(-1).unsqueeze(1) * v_curr
    p_int = p_int.sum(dim=2)
    # Elapsed time t * dt for each interior frame t = 1 .. T-2.
    dt_acc = torch.arange(1, x.shape[1] - 1, device=p_int.device, dtype=p_int.dtype) * dt
    force = force.unsqueeze(1)
    drag_pt_num = drag_pt_num.unsqueeze(1)
    dt_acc = dt_acc.unsqueeze(0).unsqueeze(-1).repeat(drag_pt_num.shape[0], 1, 3)
    p_ext = force * dt_acc * drag_pt_num
    # Impulse offset of `start_frame` steps at the original (un-subsampled) dt.
    p_ext = p_ext + start_frame * force * (dt / frame_interval) * drag_pt_num
    return Fn.mse_loss(p_int, p_ext)