Spaces:
Runtime error
Runtime error
FrozenBurning
commited on
Commit
·
81ecb2b
1
Parent(s):
06ea84f
single view to 3D init release
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +4 -0
- README.md +2 -1
- app.py +209 -0
- assets/examples/blue_cat.png +0 -0
- assets/examples/bubble_mart_blue.png +0 -0
- assets/examples/bulldog.png +0 -0
- assets/examples/ceramic.png +0 -0
- assets/examples/chair_watermelon.png +0 -0
- assets/examples/cup_rgba.png +0 -0
- assets/examples/cute_horse.jpg +0 -0
- assets/examples/earphone.jpg +0 -0
- assets/examples/firedragon.png +0 -0
- assets/examples/fox.jpg +0 -0
- assets/examples/fruit_elephant.jpg +0 -0
- assets/examples/hatsune_miku.png +0 -0
- assets/examples/ikun_rgba.png +0 -0
- assets/examples/mailbox.png +0 -0
- assets/examples/mario.png +0 -0
- assets/examples/mei_ling_panda.png +0 -0
- assets/examples/mushroom_teapot.jpg +0 -0
- assets/examples/pikachu.png +0 -0
- assets/examples/potplant_rgba.png +0 -0
- assets/examples/seed_frog.png +0 -0
- assets/examples/shuai_panda_notail.png +0 -0
- assets/examples/yellow_duck.png +0 -0
- configs/inference_dit.yml +97 -0
- dva/__init__.py +5 -0
- dva/attr_dict.py +66 -0
- dva/geom.py +653 -0
- dva/io.py +56 -0
- dva/layers.py +157 -0
- dva/losses.py +239 -0
- dva/mvp/extensions/mvpraymarch/bvh.cu +292 -0
- dva/mvp/extensions/mvpraymarch/cudadispatch.h +104 -0
- dva/mvp/extensions/mvpraymarch/helper_math.h +1453 -0
- dva/mvp/extensions/mvpraymarch/makefile +2 -0
- dva/mvp/extensions/mvpraymarch/mvpraymarch.cpp +405 -0
- dva/mvp/extensions/mvpraymarch/mvpraymarch.py +559 -0
- dva/mvp/extensions/mvpraymarch/mvpraymarch_kernel.cu +208 -0
- dva/mvp/extensions/mvpraymarch/mvpraymarch_subset_kernel.h +218 -0
- dva/mvp/extensions/mvpraymarch/primaccum.h +101 -0
- dva/mvp/extensions/mvpraymarch/primsampler.h +94 -0
- dva/mvp/extensions/mvpraymarch/primtransf.h +182 -0
- dva/mvp/extensions/mvpraymarch/setup.py +30 -0
- dva/mvp/extensions/mvpraymarch/utils.h +847 -0
- dva/mvp/extensions/utils/helper_math.h +1453 -0
- dva/mvp/extensions/utils/makefile +2 -0
- dva/mvp/extensions/utils/setup.py +29 -0
- dva/mvp/extensions/utils/utils.cpp +137 -0
- dva/mvp/extensions/utils/utils.py +211 -0
.gitignore
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__
|
| 2 |
+
build
|
| 3 |
+
*.so
|
| 4 |
+
runs
|
README.md
CHANGED
|
@@ -1,10 +1,11 @@
|
|
| 1 |
---
|
| 2 |
-
title: 3DTopia
|
| 3 |
emoji: 🌖
|
| 4 |
colorFrom: green
|
| 5 |
colorTo: pink
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 4.41.0
|
|
|
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
---
|
|
|
|
| 1 |
---
|
| 2 |
+
title: 3DTopia-XL
|
| 3 |
emoji: 🌖
|
| 4 |
colorFrom: green
|
| 5 |
colorTo: pink
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 4.41.0
|
| 8 |
+
python_version: 3.9
|
| 9 |
app_file: app.py
|
| 10 |
pinned: false
|
| 11 |
---
|
app.py
ADDED
|
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import imageio
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
os.system("bash install.sh")
|
| 6 |
+
|
| 7 |
+
from omegaconf import OmegaConf
|
| 8 |
+
import tqdm
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
import torchvision.transforms.functional as TF
|
| 13 |
+
import rembg
|
| 14 |
+
import gradio as gr
|
| 15 |
+
from dva.io import load_from_config
|
| 16 |
+
from dva.ray_marcher import RayMarcher
|
| 17 |
+
from dva.visualize import visualize_primvolume, visualize_video_primvolume
|
| 18 |
+
from inference import remove_background, resize_foreground, extract_texmesh
|
| 19 |
+
from models.diffusion import create_diffusion
|
| 20 |
+
from huggingface_hub import hf_hub_download
|
| 21 |
+
ckpt_path = hf_hub_download(repo_id="frozenburning/3DTopia-XL", filename="model_sview_dit_fp16.pt")
|
| 22 |
+
vae_ckpt_path = hf_hub_download(repo_id="frozenburning/3DTopia-XL", filename="model_vae_fp16.pt")
|
| 23 |
+
|
| 24 |
+
GRADIO_PRIM_VIDEO_PATH = 'prim.mp4'
|
| 25 |
+
GRADIO_RGB_VIDEO_PATH = 'rgb.mp4'
|
| 26 |
+
GRADIO_MAT_VIDEO_PATH = 'mat.mp4'
|
| 27 |
+
GRADIO_GLB_PATH = 'pbr_mesh.glb'
|
| 28 |
+
CONFIG_PATH = "./configs/inference_dit.yml"
|
| 29 |
+
|
| 30 |
+
config = OmegaConf.load(CONFIG_PATH)
|
| 31 |
+
config.checkpoint_path = ckpt_path
|
| 32 |
+
config.model.vae_checkpoint_path = vae_ckpt_path
|
| 33 |
+
# model
|
| 34 |
+
model = load_from_config(config.model.generator)
|
| 35 |
+
state_dict = torch.load(config.checkpoint_path, map_location='cpu')
|
| 36 |
+
model.load_state_dict(state_dict['ema'])
|
| 37 |
+
vae = load_from_config(config.model.vae)
|
| 38 |
+
vae_state_dict = torch.load(config.model.vae_checkpoint_path, map_location='cpu')
|
| 39 |
+
vae.load_state_dict(vae_state_dict['model_state_dict'])
|
| 40 |
+
conditioner = load_from_config(config.model.conditioner)
|
| 41 |
+
|
| 42 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 43 |
+
vae = vae.to(device)
|
| 44 |
+
conditioner = conditioner.to(device)
|
| 45 |
+
model = model.to(device)
|
| 46 |
+
model.eval()
|
| 47 |
+
|
| 48 |
+
amp = True
|
| 49 |
+
precision_dtype = torch.float16
|
| 50 |
+
|
| 51 |
+
rm = RayMarcher(
|
| 52 |
+
config.image_height,
|
| 53 |
+
config.image_width,
|
| 54 |
+
**config.rm,
|
| 55 |
+
).to(device)
|
| 56 |
+
|
| 57 |
+
perchannel_norm = False
|
| 58 |
+
if "latent_mean" in config.model:
|
| 59 |
+
latent_mean = torch.Tensor(config.model.latent_mean)[None, None, :].to(device)
|
| 60 |
+
latent_std = torch.Tensor(config.model.latent_std)[None, None, :].to(device)
|
| 61 |
+
assert latent_mean.shape[-1] == config.model.generator.in_channels
|
| 62 |
+
perchannel_norm = True
|
| 63 |
+
|
| 64 |
+
config.diffusion.pop("timestep_respacing")
|
| 65 |
+
config.model.pop("vae")
|
| 66 |
+
config.model.pop("vae_checkpoint_path")
|
| 67 |
+
config.model.pop("conditioner")
|
| 68 |
+
config.model.pop("generator")
|
| 69 |
+
config.model.pop("latent_nf")
|
| 70 |
+
config.model.pop("latent_mean")
|
| 71 |
+
config.model.pop("latent_std")
|
| 72 |
+
model_primx = load_from_config(config.model)
|
| 73 |
+
# load rembg
|
| 74 |
+
rembg_session = rembg.new_session()
|
| 75 |
+
|
| 76 |
+
# process function
|
| 77 |
+
def process(input_image, input_num_steps=25, input_seed=42, input_cfg=6.0):
|
| 78 |
+
# seed
|
| 79 |
+
torch.manual_seed(input_seed)
|
| 80 |
+
|
| 81 |
+
os.makedirs(config.output_dir, exist_ok=True)
|
| 82 |
+
output_rgb_video_path = os.path.join(config.output_dir, GRADIO_RGB_VIDEO_PATH)
|
| 83 |
+
output_prim_video_path = os.path.join(config.output_dir, GRADIO_PRIM_VIDEO_PATH)
|
| 84 |
+
output_mat_video_path = os.path.join(config.output_dir, GRADIO_MAT_VIDEO_PATH)
|
| 85 |
+
output_glb_path = os.path.join(config.output_dir, GRADIO_GLB_PATH)
|
| 86 |
+
|
| 87 |
+
diffusion = create_diffusion(timestep_respacing=respacing, **config.diffusion)
|
| 88 |
+
sample_fn = diffusion.ddim_sample_loop_progressive
|
| 89 |
+
fwd_fn = model.forward_with_cfg
|
| 90 |
+
|
| 91 |
+
# text-conditioned
|
| 92 |
+
if input_image is None:
|
| 93 |
+
raise NotImplementedError
|
| 94 |
+
# image-conditioned (may also input text, but no text usually works too)
|
| 95 |
+
else:
|
| 96 |
+
input_image = remove_background(input_image, rembg_session)
|
| 97 |
+
input_image = resize_foreground(input_image, 0.85)
|
| 98 |
+
raw_image = np.array(input_image)
|
| 99 |
+
mask = (raw_image[..., -1][..., None] > 0) * 1
|
| 100 |
+
raw_image = raw_image[..., :3] * mask
|
| 101 |
+
input_cond = torch.from_numpy(np.array(raw_image)[None, ...]).to(device)
|
| 102 |
+
|
| 103 |
+
with torch.no_grad():
|
| 104 |
+
latent = torch.randn(1, config.model.num_prims, 1, 4, 4, 4)
|
| 105 |
+
batch = {}
|
| 106 |
+
inf_bs = 1
|
| 107 |
+
inf_x = torch.randn(inf_bs, config.model.num_prims, 68).to(device)
|
| 108 |
+
y = conditioner.encoder(input_cond)
|
| 109 |
+
model_kwargs = dict(y=y[:inf_bs, ...], precision_dtype=precision_dtype, enable_amp=amp)
|
| 110 |
+
if input_cfg >= 0:
|
| 111 |
+
model_kwargs['cfg_scale'] = input_cfg
|
| 112 |
+
for samples in sample_fn(fwd_fn, inf_x.shape, inf_x, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device):
|
| 113 |
+
final_samples = samples
|
| 114 |
+
recon_param = final_samples["sample"].reshape(inf_bs, config.model.num_prims, -1)
|
| 115 |
+
if perchannel_norm:
|
| 116 |
+
recon_param = recon_param / config.model.latent_nf * latent_std + latent_mean
|
| 117 |
+
recon_srt_param = recon_param[:, :, 0:4]
|
| 118 |
+
recon_feat_param = recon_param[:, :, 4:] # [8, 2048, 64]
|
| 119 |
+
recon_feat_param_list = []
|
| 120 |
+
# one-by-one to avoid oom
|
| 121 |
+
for inf_bidx in range(inf_bs):
|
| 122 |
+
if not perchannel_norm:
|
| 123 |
+
decoded = vae.decode(recon_feat_param[inf_bidx, ...].reshape(1*config.model.num_prims, *latent.shape[-4:]) / config.model.latent_nf)
|
| 124 |
+
else:
|
| 125 |
+
decoded = vae.decode(recon_feat_param[inf_bidx, ...].reshape(1*config.model.num_prims, *latent.shape[-4:]))
|
| 126 |
+
recon_feat_param_list.append(decoded.detach())
|
| 127 |
+
recon_feat_param = torch.concat(recon_feat_param_list, dim=0)
|
| 128 |
+
# invert normalization
|
| 129 |
+
if not perchannel_norm:
|
| 130 |
+
recon_srt_param[:, :, 0:1] = (recon_srt_param[:, :, 0:1] / 10) + 0.05
|
| 131 |
+
recon_feat_param[:, 0:1, ...] /= 5.
|
| 132 |
+
recon_feat_param[:, 1:, ...] = (recon_feat_param[:, 1:, ...] + 1) / 2.
|
| 133 |
+
recon_feat_param = recon_feat_param.reshape(inf_bs, config.model.num_prims, -1)
|
| 134 |
+
recon_param = torch.concat([recon_srt_param, recon_feat_param], dim=-1)
|
| 135 |
+
visualize_video_primvolume(config.output_dir, batch, recon_param, 60, rm, device)
|
| 136 |
+
prim_params = {'srt_param': recon_srt_param[0].detach().cpu(), 'feat_param': recon_feat_param[0].detach().cpu()}
|
| 137 |
+
torch.save({'model_state_dict': prim_params}, "{}/denoised.pt".format(config.output_dir))
|
| 138 |
+
|
| 139 |
+
# exporting GLB mesh
|
| 140 |
+
denoise_param_path = os.path.join(config.output_dir, 'denoised.pt')
|
| 141 |
+
primx_ckpt_weight = torch.load(denoise_param_path, map_location='cpu')['model_state_dict']
|
| 142 |
+
model_primx.load_state_dict(ckpt_weight)
|
| 143 |
+
model_primx.to(device)
|
| 144 |
+
model_primx.eval()
|
| 145 |
+
with torch.no_grad():
|
| 146 |
+
model_primx.srt_param[:, 1:4] *= 0.85
|
| 147 |
+
extract_texmesh(config.inference, model_primx, output_glb_path, device)
|
| 148 |
+
|
| 149 |
+
return output_rgb_video_path, output_prim_video_path, output_mat_video_path, output_glb_path
|
| 150 |
+
|
| 151 |
+
# gradio UI
|
| 152 |
+
_TITLE = '''3DTopia-XL'''
|
| 153 |
+
|
| 154 |
+
_DESCRIPTION = '''
|
| 155 |
+
<div>
|
| 156 |
+
<a style="display:inline-block" href="https://frozenburning.github.io/projects/3DTopia-XL/"><img src='https://img.shields.io/badge/public_website-8A2BE2'></a>
|
| 157 |
+
<a style="display:inline-block; margin-left: .5em" href="https://github.com/3DTopia/3DTopia-XL"><img src='https://img.shields.io/github/stars/3DTopia/3DTopia-XL?style=social'/></a>
|
| 158 |
+
</div>
|
| 159 |
+
|
| 160 |
+
* Now we offer 1) single image conditioned model, we will release 2) multiview images conditioned model and 3) pure text conditioned model in the future!
|
| 161 |
+
* If you find the output unsatisfying, try using different seeds!
|
| 162 |
+
'''
|
| 163 |
+
|
| 164 |
+
block = gr.Blocks(title=_TITLE).queue()
|
| 165 |
+
with block:
|
| 166 |
+
with gr.Row():
|
| 167 |
+
with gr.Column(scale=1):
|
| 168 |
+
gr.Markdown('# ' + _TITLE)
|
| 169 |
+
gr.Markdown(_DESCRIPTION)
|
| 170 |
+
|
| 171 |
+
with gr.Row(variant='panel'):
|
| 172 |
+
with gr.Column(scale=1):
|
| 173 |
+
# input image
|
| 174 |
+
input_image = gr.Image(label="image", type='pil')
|
| 175 |
+
# inference steps
|
| 176 |
+
input_num_steps = gr.Slider(label="inference steps", minimum=1, maximum=100, step=1, value=25)
|
| 177 |
+
# random seed
|
| 178 |
+
input_cfg = gr.Slider(label="CFG scale", minimum=0, maximum=15, step=1, value=6)
|
| 179 |
+
# random seed
|
| 180 |
+
input_seed = gr.Slider(label="random seed", minimum=0, maximum=100000, step=1, value=42)
|
| 181 |
+
# gen button
|
| 182 |
+
button_gen = gr.Button("Generate")
|
| 183 |
+
|
| 184 |
+
with gr.Column(scale=1):
|
| 185 |
+
with gr.Tab("Video"):
|
| 186 |
+
# final video results
|
| 187 |
+
output_rgb_video = gr.Video(label="video")
|
| 188 |
+
output_prim_video = gr.Video(label="video")
|
| 189 |
+
output_mat_video = gr.Video(label="video")
|
| 190 |
+
with gr.Tab("GLB"):
|
| 191 |
+
# glb file
|
| 192 |
+
output_glb = gr.File(label="glb")
|
| 193 |
+
|
| 194 |
+
button_gen.click(process, inputs=[input_image, input_num_steps, input_seed, input_cfg], outputs=[output_rgb_video, output_prim_video, output_mat_video, output_glb])
|
| 195 |
+
|
| 196 |
+
gr.Examples(
|
| 197 |
+
examples=[
|
| 198 |
+
"assets/examples/fruit_elephant.jpg",
|
| 199 |
+
"assets/examples/mei_ling_panda.png",
|
| 200 |
+
"assets/examples/shuai_panda_notail.png",
|
| 201 |
+
],
|
| 202 |
+
inputs=[input_image],
|
| 203 |
+
outputs=[output_rgb_video, output_prim_video, output_mat_video, output_glb],
|
| 204 |
+
fn=lambda x: process(input_image=x),
|
| 205 |
+
cache_examples=False,
|
| 206 |
+
label='Single Image to 3D PBR Asset'
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
block.launch(server_name="0.0.0.0", share=True)
|
assets/examples/blue_cat.png
ADDED
|
assets/examples/bubble_mart_blue.png
ADDED
|
assets/examples/bulldog.png
ADDED
|
assets/examples/ceramic.png
ADDED
|
assets/examples/chair_watermelon.png
ADDED
|
assets/examples/cup_rgba.png
ADDED
|
assets/examples/cute_horse.jpg
ADDED
|
assets/examples/earphone.jpg
ADDED
|
assets/examples/firedragon.png
ADDED
|
assets/examples/fox.jpg
ADDED
|
assets/examples/fruit_elephant.jpg
ADDED
|
assets/examples/hatsune_miku.png
ADDED
|
assets/examples/ikun_rgba.png
ADDED
|
assets/examples/mailbox.png
ADDED
|
assets/examples/mario.png
ADDED
|
assets/examples/mei_ling_panda.png
ADDED
|
assets/examples/mushroom_teapot.jpg
ADDED
|
assets/examples/pikachu.png
ADDED
|
assets/examples/potplant_rgba.png
ADDED
|
assets/examples/seed_frog.png
ADDED
|
assets/examples/shuai_panda_notail.png
ADDED
|
assets/examples/yellow_duck.png
ADDED
|
configs/inference_dit.yml
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
debug: False
|
| 2 |
+
root_data_dir: ./runs
|
| 3 |
+
checkpoint_path:
|
| 4 |
+
global_seed: 42
|
| 5 |
+
|
| 6 |
+
inference:
|
| 7 |
+
input_dir:
|
| 8 |
+
ddim: 25
|
| 9 |
+
cfg: 6
|
| 10 |
+
seed: ${global_seed}
|
| 11 |
+
precision: fp16
|
| 12 |
+
export_glb: True
|
| 13 |
+
decimate: 100000
|
| 14 |
+
mc_resolution: 256
|
| 15 |
+
batch_size: 4096
|
| 16 |
+
remesh: False
|
| 17 |
+
|
| 18 |
+
image_height: 518
|
| 19 |
+
image_width: 518
|
| 20 |
+
|
| 21 |
+
model:
|
| 22 |
+
class_name: models.primsdf.PrimSDF
|
| 23 |
+
num_prims: 2048
|
| 24 |
+
dim_feat: 6
|
| 25 |
+
prim_shape: 8
|
| 26 |
+
init_scale: 0.05 # useless if auto_scale_init == True
|
| 27 |
+
sdf2alpha_var: 0.005
|
| 28 |
+
auto_scale_init: True
|
| 29 |
+
init_sampling: uniform
|
| 30 |
+
vae:
|
| 31 |
+
class_name: models.vae3d_dib.VAE
|
| 32 |
+
in_channels: ${model.dim_feat}
|
| 33 |
+
latent_channels: 1
|
| 34 |
+
out_channels: ${model.vae.in_channels}
|
| 35 |
+
down_channels: [32, 256]
|
| 36 |
+
mid_attention: True
|
| 37 |
+
up_channels: [256, 32]
|
| 38 |
+
layers_per_block: 2
|
| 39 |
+
gradient_checkpointing: False
|
| 40 |
+
vae_checkpoint_path:
|
| 41 |
+
conditioner:
|
| 42 |
+
class_name: models.conditioner.image.ImageConditioner
|
| 43 |
+
num_prims: ${model.num_prims}
|
| 44 |
+
dim_feat: ${model.dim_feat}
|
| 45 |
+
prim_shape: ${model.prim_shape}
|
| 46 |
+
sample_view: False
|
| 47 |
+
encoder_config:
|
| 48 |
+
class_name: models.conditioner.image_dinov2.Dinov2Wrapper
|
| 49 |
+
model_name: dinov2_vitb14_reg
|
| 50 |
+
freeze: True
|
| 51 |
+
generator:
|
| 52 |
+
class_name: models.dit_crossattn.DiT
|
| 53 |
+
seq_length: ${model.num_prims}
|
| 54 |
+
in_channels: 68 # equals to model.vae.latent_channels * latent_dim^3
|
| 55 |
+
condition_channels: 768
|
| 56 |
+
hidden_size: 1152
|
| 57 |
+
depth: 28
|
| 58 |
+
num_heads: 16
|
| 59 |
+
attn_proj_bias: True
|
| 60 |
+
cond_drop_prob: 0.1
|
| 61 |
+
gradient_checkpointing: False
|
| 62 |
+
latent_nf: 1.0
|
| 63 |
+
latent_mean: [ 0.0442, -0.0029, -0.0425, -0.0043, -0.4086, -0.2906, -0.7002, -0.0852, -0.4446, -0.6896, -0.7344, -0.3524, -0.5488, -0.4313, -1.1715, -0.0875, -0.6131, -0.3924, -0.7335, -0.3749, 0.4658, -0.0236, 0.8362, 0.3388, 0.0188, 0.5988, -0.1853, 1.1579, 0.6240, 0.0758, 0.9641, 0.6586, 0.6260, 0.2384, 0.7798, 0.8297, -0.6543, -0.4441, -1.3887, -0.0393, -0.9008, -0.8616, -1.7434, -0.1328, -0.8119, -0.8225, -1.8533, -0.0444, -1.0510, -0.5158, -1.1907, -0.5265, 0.2832, 0.6037, 0.5981, 0.5461, 0.4366, 0.4144, 0.7219, 0.5722, 0.5937, 0.5598, 0.9414, 0.7419, 0.2102, 0.3388, 0.4501, 0.5166]
|
| 64 |
+
latent_std: [0.0219, 0.3707, 0.3911, 0.3610, 0.7549, 0.7909, 0.9691, 0.9193, 0.8218, 0.9389, 1.1785, 1.0254, 0.6376, 0.6568, 0.7892, 0.8468, 0.8775, 0.7920, 0.9037, 0.9329, 0.9196, 1.1123, 1.3041, 1.0955, 1.2727, 1.6565, 1.8502, 1.7006, 0.8973, 1.0408, 1.2034, 1.2703, 1.0373, 1.0486, 1.0716, 0.9746, 0.7088, 0.8685, 1.0030, 0.9504, 1.0410, 1.3033, 1.5368, 1.4386, 0.6142, 0.6887, 0.9085, 0.9903, 1.0190, 0.9302, 1.0121, 0.9964, 1.1474, 1.2729, 1.4627, 1.1404, 1.3713, 1.6692, 1.8424, 1.5047, 1.1356, 1.2369, 1.3554, 1.1848, 1.1319, 1.0822, 1.1972, 0.9916]
|
| 65 |
+
|
| 66 |
+
diffusion:
|
| 67 |
+
timestep_respacing:
|
| 68 |
+
noise_schedule: squaredcos_cap_v2
|
| 69 |
+
diffusion_steps: 1000
|
| 70 |
+
parameterization: v
|
| 71 |
+
|
| 72 |
+
rm:
|
| 73 |
+
volradius: 10000.0
|
| 74 |
+
dt: 1.0
|
| 75 |
+
|
| 76 |
+
optimizer:
|
| 77 |
+
class_name: torch.optim.AdamW
|
| 78 |
+
lr: 0.0001
|
| 79 |
+
weight_decay: 0
|
| 80 |
+
|
| 81 |
+
scheduler:
|
| 82 |
+
class_name: dva.scheduler.CosineWarmupScheduler
|
| 83 |
+
warmup_iters: 3000
|
| 84 |
+
max_iters: 200000
|
| 85 |
+
|
| 86 |
+
train:
|
| 87 |
+
batch_size: 8
|
| 88 |
+
n_workers: 4
|
| 89 |
+
n_epochs: 1000
|
| 90 |
+
log_every_n_steps: 50
|
| 91 |
+
summary_every_n_steps: 10000
|
| 92 |
+
ckpt_every_n_steps: 10000
|
| 93 |
+
amp: False
|
| 94 |
+
precision: tf32
|
| 95 |
+
|
| 96 |
+
tag: 3dtopia-xl-sview
|
| 97 |
+
output_dir: ${root_data_dir}/inference/${tag}
|
dva/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
dva/attr_dict.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import json
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class AttrDict:
|
| 11 |
+
def __init__(self, entries):
|
| 12 |
+
self.add_entries_(entries)
|
| 13 |
+
|
| 14 |
+
def keys(self):
|
| 15 |
+
return self.__dict__.keys()
|
| 16 |
+
|
| 17 |
+
def values(self):
|
| 18 |
+
return self.__dict__.values()
|
| 19 |
+
|
| 20 |
+
def __getitem__(self, key):
|
| 21 |
+
return self.__dict__[key]
|
| 22 |
+
|
| 23 |
+
def __setitem__(self, key, value):
|
| 24 |
+
self.__dict__[key] = value
|
| 25 |
+
|
| 26 |
+
def __delitem__(self, key):
|
| 27 |
+
return self.__dict__.__delitem__(key)
|
| 28 |
+
|
| 29 |
+
def __contains__(self, key):
|
| 30 |
+
return key in self.__dict__
|
| 31 |
+
|
| 32 |
+
def __repr__(self):
|
| 33 |
+
return self.__dict__.__repr__()
|
| 34 |
+
|
| 35 |
+
def __getattr__(self, attr):
|
| 36 |
+
if attr.startswith("__"):
|
| 37 |
+
return self.__getattribute__(attr)
|
| 38 |
+
return self.__dict__[attr]
|
| 39 |
+
|
| 40 |
+
def items(self):
|
| 41 |
+
return self.__dict__.items()
|
| 42 |
+
|
| 43 |
+
def __iter__(self):
|
| 44 |
+
return iter(self.items())
|
| 45 |
+
|
| 46 |
+
def add_entries_(self, entries, overwrite=True):
|
| 47 |
+
for key, value in entries.items():
|
| 48 |
+
if key not in self.__dict__:
|
| 49 |
+
if isinstance(value, dict):
|
| 50 |
+
self.__dict__[key] = AttrDict(value)
|
| 51 |
+
else:
|
| 52 |
+
self.__dict__[key] = value
|
| 53 |
+
else:
|
| 54 |
+
if isinstance(value, dict):
|
| 55 |
+
self.__dict__[key].add_entries_(entries=value, overwrite=overwrite)
|
| 56 |
+
elif overwrite or self.__dict__[key] is None:
|
| 57 |
+
self.__dict__[key] = value
|
| 58 |
+
|
| 59 |
+
def serialize(self):
|
| 60 |
+
return json.dumps(self, default=self.obj_to_dict, indent=4)
|
| 61 |
+
|
| 62 |
+
def obj_to_dict(self, obj):
|
| 63 |
+
return obj.__dict__
|
| 64 |
+
|
| 65 |
+
def get(self, key, default=None):
|
| 66 |
+
return self.__dict__.get(key, default)
|
dva/geom.py
ADDED
|
@@ -0,0 +1,653 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch as th
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
|
| 7 |
+
from sklearn.neighbors import KDTree
|
| 8 |
+
|
| 9 |
+
import logging
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger(__name__)
|
| 12 |
+
|
| 13 |
+
# NOTE: we need pytorch3d primarily for UV rasterization things
|
| 14 |
+
from pytorch3d.renderer.mesh.rasterize_meshes import rasterize_meshes
|
| 15 |
+
from pytorch3d.structures import Meshes
|
| 16 |
+
from typing import Union, Optional, Tuple
|
| 17 |
+
import trimesh
|
| 18 |
+
from trimesh import Trimesh
|
| 19 |
+
from trimesh.triangles import points_to_barycentric
|
| 20 |
+
|
| 21 |
+
try:
|
| 22 |
+
# pyre-fixme[21]: Could not find module `igl`.
|
| 23 |
+
from igl import point_mesh_squared_distance # @manual
|
| 24 |
+
|
| 25 |
+
# pyre-fixme[3]: Return type must be annotated.
|
| 26 |
+
# pyre-fixme[2]: Parameter must be annotated.
|
| 27 |
+
def closest_point(mesh, points):
|
| 28 |
+
"""Helper function that mimics trimesh.proximity.closest_point but uses
|
| 29 |
+
IGL for faster queries."""
|
| 30 |
+
v = mesh.vertices
|
| 31 |
+
vi = mesh.faces
|
| 32 |
+
dist, face_idxs, p = point_mesh_squared_distance(points, v, vi)
|
| 33 |
+
return p, dist, face_idxs
|
| 34 |
+
|
| 35 |
+
except ImportError:
|
| 36 |
+
from trimesh.proximity import closest_point
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def closest_point_barycentrics(v, vi, points):
|
| 40 |
+
"""Given a 3D mesh and a set of query points, return closest point barycentrics
|
| 41 |
+
Args:
|
| 42 |
+
v: np.array (float)
|
| 43 |
+
[N, 3] mesh vertices
|
| 44 |
+
|
| 45 |
+
vi: np.array (int)
|
| 46 |
+
[N, 3] mesh triangle indices
|
| 47 |
+
|
| 48 |
+
points: np.array (float)
|
| 49 |
+
[M, 3] query points
|
| 50 |
+
|
| 51 |
+
Returns:
|
| 52 |
+
Tuple[approx, barys, interp_idxs, face_idxs]
|
| 53 |
+
approx: [M, 3] approximated (closest) points on the mesh
|
| 54 |
+
barys: [M, 3] barycentric weights that produce "approx"
|
| 55 |
+
interp_idxs: [M, 3] vertex indices for barycentric interpolation
|
| 56 |
+
face_idxs: [M] face indices for barycentric interpolation. interp_idxs = vi[face_idxs]
|
| 57 |
+
"""
|
| 58 |
+
mesh = Trimesh(vertices=v, faces=vi, process=False)
|
| 59 |
+
p, _, face_idxs = closest_point(mesh, points)
|
| 60 |
+
p = p.reshape((points.shape[0], 3))
|
| 61 |
+
face_idxs = face_idxs.reshape((points.shape[0],))
|
| 62 |
+
barys = points_to_barycentric(mesh.triangles[face_idxs], p)
|
| 63 |
+
b0, b1, b2 = np.split(barys, 3, axis=1)
|
| 64 |
+
|
| 65 |
+
interp_idxs = vi[face_idxs]
|
| 66 |
+
v0 = v[interp_idxs[:, 0]]
|
| 67 |
+
v1 = v[interp_idxs[:, 1]]
|
| 68 |
+
v2 = v[interp_idxs[:, 2]]
|
| 69 |
+
approx = b0 * v0 + b1 * v1 + b2 * v2
|
| 70 |
+
return approx, barys, interp_idxs, face_idxs
|
| 71 |
+
|
| 72 |
+
def make_uv_face_index(
|
| 73 |
+
vt: th.Tensor,
|
| 74 |
+
vti: th.Tensor,
|
| 75 |
+
uv_shape: Union[Tuple[int, int], int],
|
| 76 |
+
flip_uv: bool = True,
|
| 77 |
+
device: Optional[Union[str, th.device]] = None,
|
| 78 |
+
):
|
| 79 |
+
"""Compute a UV-space face index map identifying which mesh face contains each
|
| 80 |
+
texel. For texels with no assigned triangle, the index will be -1."""
|
| 81 |
+
|
| 82 |
+
if isinstance(uv_shape, int):
|
| 83 |
+
uv_shape = (uv_shape, uv_shape)
|
| 84 |
+
|
| 85 |
+
uv_max_shape_ind = uv_shape.index(max(uv_shape))
|
| 86 |
+
uv_min_shape_ind = uv_shape.index(min(uv_shape))
|
| 87 |
+
uv_ratio = uv_shape[uv_max_shape_ind] / uv_shape[uv_min_shape_ind]
|
| 88 |
+
|
| 89 |
+
if device is not None:
|
| 90 |
+
if isinstance(device, str):
|
| 91 |
+
dev = th.device(device)
|
| 92 |
+
else:
|
| 93 |
+
dev = device
|
| 94 |
+
assert dev.type == "cuda"
|
| 95 |
+
else:
|
| 96 |
+
dev = th.device("cuda")
|
| 97 |
+
|
| 98 |
+
vt = 1.0 - vt.clone()
|
| 99 |
+
|
| 100 |
+
if flip_uv:
|
| 101 |
+
vt = vt.clone()
|
| 102 |
+
vt[:, 1] = 1 - vt[:, 1]
|
| 103 |
+
vt_pix = 2.0 * vt.to(dev) - 1.0
|
| 104 |
+
vt_pix = th.cat([vt_pix, th.ones_like(vt_pix[:, 0:1])], dim=1)
|
| 105 |
+
|
| 106 |
+
vt_pix[:, uv_min_shape_ind] *= uv_ratio
|
| 107 |
+
meshes = Meshes(vt_pix[np.newaxis], vti[np.newaxis].to(dev))
|
| 108 |
+
with th.no_grad():
|
| 109 |
+
face_index, _, _, _ = rasterize_meshes(
|
| 110 |
+
meshes, uv_shape, faces_per_pixel=1, z_clip_value=0.0, bin_size=0
|
| 111 |
+
)
|
| 112 |
+
face_index = face_index[0, ..., 0]
|
| 113 |
+
return face_index
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def make_uv_vert_index(
|
| 117 |
+
vt: th.Tensor,
|
| 118 |
+
vi: th.Tensor,
|
| 119 |
+
vti: th.Tensor,
|
| 120 |
+
uv_shape: Union[Tuple[int, int], int],
|
| 121 |
+
flip_uv: bool = True,
|
| 122 |
+
):
|
| 123 |
+
"""Compute a UV-space vertex index map identifying which mesh vertices
|
| 124 |
+
comprise the triangle containing each texel. For texels with no assigned
|
| 125 |
+
triangle, all indices will be -1.
|
| 126 |
+
"""
|
| 127 |
+
face_index_map = make_uv_face_index(vt, vti, uv_shape, flip_uv)
|
| 128 |
+
vert_index_map = vi[face_index_map.clamp(min=0)]
|
| 129 |
+
vert_index_map[face_index_map < 0] = -1
|
| 130 |
+
return vert_index_map.long()
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def bary_coords(points: th.Tensor, triangles: th.Tensor, eps: float = 1.0e-6):
|
| 134 |
+
"""Computes barycentric coordinates for a set of 2D query points given
|
| 135 |
+
coordintes for the 3 vertices of the enclosing triangle for each point."""
|
| 136 |
+
x = points[:, 0] - triangles[2, :, 0]
|
| 137 |
+
x1 = triangles[0, :, 0] - triangles[2, :, 0]
|
| 138 |
+
x2 = triangles[1, :, 0] - triangles[2, :, 0]
|
| 139 |
+
y = points[:, 1] - triangles[2, :, 1]
|
| 140 |
+
y1 = triangles[0, :, 1] - triangles[2, :, 1]
|
| 141 |
+
y2 = triangles[1, :, 1] - triangles[2, :, 1]
|
| 142 |
+
denom = y2 * x1 - y1 * x2
|
| 143 |
+
n0 = y2 * x - x2 * y
|
| 144 |
+
n1 = x1 * y - y1 * x
|
| 145 |
+
|
| 146 |
+
# Small epsilon to prevent divide-by-zero error.
|
| 147 |
+
denom = th.where(denom >= 0, denom.clamp(min=eps), denom.clamp(max=-eps))
|
| 148 |
+
|
| 149 |
+
bary_0 = n0 / denom
|
| 150 |
+
bary_1 = n1 / denom
|
| 151 |
+
bary_2 = 1.0 - bary_0 - bary_1
|
| 152 |
+
|
| 153 |
+
return th.stack((bary_0, bary_1, bary_2))
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def make_uv_barys(
|
| 157 |
+
vt: th.Tensor,
|
| 158 |
+
vti: th.Tensor,
|
| 159 |
+
uv_shape: Union[Tuple[int, int], int],
|
| 160 |
+
flip_uv: bool = True,
|
| 161 |
+
):
|
| 162 |
+
"""Compute a UV-space barycentric map where each texel contains barycentric
|
| 163 |
+
coordinates for that texel within its enclosing UV triangle. For texels
|
| 164 |
+
with no assigned triangle, all 3 barycentric coordinates will be 0.
|
| 165 |
+
"""
|
| 166 |
+
if isinstance(uv_shape, int):
|
| 167 |
+
uv_shape = (uv_shape, uv_shape)
|
| 168 |
+
|
| 169 |
+
if flip_uv:
|
| 170 |
+
# Flip here because texture coordinates in some of our topo files are
|
| 171 |
+
# stored in OpenGL convention with Y=0 on the bottom of the texture
|
| 172 |
+
# unlike numpy/torch arrays/tensors.
|
| 173 |
+
vt = vt.clone()
|
| 174 |
+
vt[:, 1] = 1 - vt[:, 1]
|
| 175 |
+
|
| 176 |
+
face_index_map = make_uv_face_index(vt, vti, uv_shape, flip_uv=False)
|
| 177 |
+
vti_map = vti.long()[face_index_map.clamp(min=0)]
|
| 178 |
+
|
| 179 |
+
uv_max_shape_ind = uv_shape.index(max(uv_shape))
|
| 180 |
+
uv_min_shape_ind = uv_shape.index(min(uv_shape))
|
| 181 |
+
uv_ratio = uv_shape[uv_max_shape_ind] / uv_shape[uv_min_shape_ind]
|
| 182 |
+
vt = vt.clone()
|
| 183 |
+
vt = vt * 2 - 1
|
| 184 |
+
vt[:, uv_min_shape_ind] *= uv_ratio
|
| 185 |
+
uv_tri_uvs = vt[vti_map].permute(2, 0, 1, 3)
|
| 186 |
+
|
| 187 |
+
uv_grid = th.meshgrid(
|
| 188 |
+
th.linspace(0.5, uv_shape[0] - 0.5, uv_shape[0]) / uv_shape[0],
|
| 189 |
+
th.linspace(0.5, uv_shape[1] - 0.5, uv_shape[1]) / uv_shape[1],
|
| 190 |
+
)
|
| 191 |
+
uv_grid = th.stack(uv_grid[::-1], dim=2).to(uv_tri_uvs)
|
| 192 |
+
uv_grid = uv_grid * 2 - 1
|
| 193 |
+
uv_grid[..., uv_min_shape_ind] *= uv_ratio
|
| 194 |
+
|
| 195 |
+
bary_map = bary_coords(uv_grid.view(-1, 2), uv_tri_uvs.view(3, -1, 2))
|
| 196 |
+
bary_map = bary_map.permute(1, 0).view(uv_shape[0], uv_shape[1], 3)
|
| 197 |
+
bary_map[face_index_map < 0] = 0
|
| 198 |
+
return face_index_map, bary_map
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def index_image_impaint(
|
| 202 |
+
index_image: th.Tensor,
|
| 203 |
+
bary_image: Optional[th.Tensor] = None,
|
| 204 |
+
distance_threshold=100.0,
|
| 205 |
+
):
|
| 206 |
+
# getting the mask around the indexes?
|
| 207 |
+
if len(index_image.shape) == 3:
|
| 208 |
+
valid_index = (index_image != -1).any(dim=-1)
|
| 209 |
+
elif len(index_image.shape) == 2:
|
| 210 |
+
valid_index = index_image != -1
|
| 211 |
+
else:
|
| 212 |
+
raise ValueError("`index_image` should be a [H,W] or [H,W,C] image")
|
| 213 |
+
|
| 214 |
+
invalid_index = ~valid_index
|
| 215 |
+
|
| 216 |
+
device = index_image.device
|
| 217 |
+
|
| 218 |
+
valid_ij = th.stack(th.where(valid_index), dim=-1)
|
| 219 |
+
invalid_ij = th.stack(th.where(invalid_index), dim=-1)
|
| 220 |
+
lookup_valid = KDTree(valid_ij.cpu().numpy())
|
| 221 |
+
|
| 222 |
+
dists, idxs = lookup_valid.query(invalid_ij.cpu())
|
| 223 |
+
|
| 224 |
+
# TODO: try average?
|
| 225 |
+
idxs = th.as_tensor(idxs, device=device)[..., 0]
|
| 226 |
+
dists = th.as_tensor(dists, device=device)[..., 0]
|
| 227 |
+
|
| 228 |
+
dist_mask = dists < distance_threshold
|
| 229 |
+
|
| 230 |
+
invalid_border = th.zeros_like(invalid_index)
|
| 231 |
+
invalid_border[invalid_index] = dist_mask
|
| 232 |
+
|
| 233 |
+
invalid_src_ij = valid_ij[idxs][dist_mask]
|
| 234 |
+
invalid_dst_ij = invalid_ij[dist_mask]
|
| 235 |
+
|
| 236 |
+
index_image_imp = index_image.clone()
|
| 237 |
+
|
| 238 |
+
index_image_imp[invalid_dst_ij[:, 0], invalid_dst_ij[:, 1]] = index_image[
|
| 239 |
+
invalid_src_ij[:, 0], invalid_src_ij[:, 1]
|
| 240 |
+
]
|
| 241 |
+
|
| 242 |
+
if bary_image is not None:
|
| 243 |
+
bary_image_imp = bary_image.clone()
|
| 244 |
+
|
| 245 |
+
bary_image_imp[invalid_dst_ij[:, 0], invalid_dst_ij[:, 1]] = bary_image[
|
| 246 |
+
invalid_src_ij[:, 0], invalid_src_ij[:, 1]
|
| 247 |
+
]
|
| 248 |
+
|
| 249 |
+
return index_image_imp, bary_image_imp
|
| 250 |
+
return index_image_imp
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
class GeometryModule(nn.Module):
|
| 254 |
+
def __init__(
|
| 255 |
+
self,
|
| 256 |
+
v,
|
| 257 |
+
vi,
|
| 258 |
+
vt,
|
| 259 |
+
vti,
|
| 260 |
+
uv_size,
|
| 261 |
+
v2uv: Optional[th.Tensor] = None,
|
| 262 |
+
flip_uv=False,
|
| 263 |
+
impaint=False,
|
| 264 |
+
impaint_threshold=100.0,
|
| 265 |
+
):
|
| 266 |
+
super().__init__()
|
| 267 |
+
|
| 268 |
+
self.register_buffer("v", th.as_tensor(v))
|
| 269 |
+
self.register_buffer("vi", th.as_tensor(vi))
|
| 270 |
+
self.register_buffer("vt", th.as_tensor(vt))
|
| 271 |
+
self.register_buffer("vti", th.as_tensor(vti))
|
| 272 |
+
if v2uv is not None:
|
| 273 |
+
self.register_buffer("v2uv", th.as_tensor(v2uv, dtype=th.int64))
|
| 274 |
+
|
| 275 |
+
# TODO: should we just pass topology here?
|
| 276 |
+
# self.n_verts = v2uv.shape[0]
|
| 277 |
+
self.n_verts = vi.max() + 1
|
| 278 |
+
|
| 279 |
+
self.uv_size = uv_size
|
| 280 |
+
|
| 281 |
+
# TODO: can't we just index face_index?
|
| 282 |
+
index_image = make_uv_vert_index(
|
| 283 |
+
self.vt, self.vi, self.vti, uv_shape=uv_size, flip_uv=flip_uv
|
| 284 |
+
).cpu()
|
| 285 |
+
face_index, bary_image = make_uv_barys(
|
| 286 |
+
self.vt, self.vti, uv_shape=uv_size, flip_uv=flip_uv
|
| 287 |
+
)
|
| 288 |
+
if impaint:
|
| 289 |
+
if min(uv_size) >= 1024:
|
| 290 |
+
logger.info(
|
| 291 |
+
"impainting index image might take a while for sizes >= 1024"
|
| 292 |
+
)
|
| 293 |
+
|
| 294 |
+
index_image, bary_image = index_image_impaint(
|
| 295 |
+
index_image, bary_image, impaint_threshold
|
| 296 |
+
)
|
| 297 |
+
# TODO: we can avoid doing this 2x
|
| 298 |
+
face_index = index_image_impaint(
|
| 299 |
+
face_index, distance_threshold=impaint_threshold
|
| 300 |
+
)
|
| 301 |
+
|
| 302 |
+
self.register_buffer("index_image", index_image.cpu())
|
| 303 |
+
self.register_buffer("bary_image", bary_image.cpu())
|
| 304 |
+
self.register_buffer("face_index_image", face_index.cpu())
|
| 305 |
+
|
| 306 |
+
def render_index_images(self, uv_size, flip_uv=False, impaint=False):
|
| 307 |
+
index_image = make_uv_vert_index(
|
| 308 |
+
self.vt, self.vi, self.vti, uv_shape=uv_size, flip_uv=flip_uv
|
| 309 |
+
)
|
| 310 |
+
face_image, bary_image = make_uv_barys(
|
| 311 |
+
self.vt, self.vti, uv_shape=uv_size, flip_uv=flip_uv
|
| 312 |
+
)
|
| 313 |
+
|
| 314 |
+
if impaint:
|
| 315 |
+
index_image, bary_image = index_image_impaint(
|
| 316 |
+
index_image,
|
| 317 |
+
bary_image,
|
| 318 |
+
)
|
| 319 |
+
|
| 320 |
+
return index_image, face_image, bary_image
|
| 321 |
+
|
| 322 |
+
def vn(self, verts):
|
| 323 |
+
return vert_normals(verts, self.vi[np.newaxis].to(th.long))
|
| 324 |
+
|
| 325 |
+
def to_uv(self, values):
|
| 326 |
+
return values_to_uv(values, self.index_image, self.bary_image)
|
| 327 |
+
|
| 328 |
+
def from_uv(self, values_uv):
|
| 329 |
+
# TODO: we need to sample this
|
| 330 |
+
return sample_uv(values_uv, self.vt, self.v2uv.to(th.long))
|
| 331 |
+
|
| 332 |
+
def rand_sample_3d_uv(self, count, uv_img):
|
| 333 |
+
"""
|
| 334 |
+
Sample a set of 3D points on the surface of mesh, return corresponding interpolated values in UV space.
|
| 335 |
+
|
| 336 |
+
Args:
|
| 337 |
+
count - num of 3D points to be sampled
|
| 338 |
+
|
| 339 |
+
uv_img - the image in uv space to be sampled, e.g., texture
|
| 340 |
+
"""
|
| 341 |
+
_mesh = Trimesh(vertices=self.v.detach().cpu().numpy(), faces=self.vi.detach().cpu().numpy(), process=False)
|
| 342 |
+
points, _ = trimesh.sample.sample_surface(_mesh, count)
|
| 343 |
+
return self.sample_uv_from_3dpts(points, uv_img)
|
| 344 |
+
|
| 345 |
+
def sample_uv_from_3dpts(self, points, uv_img):
|
| 346 |
+
num_pts = points.shape[0]
|
| 347 |
+
approx, barys, interp_idxs, face_idxs = closest_point_barycentrics(self.v.detach().cpu().numpy(), self.vi.detach().cpu().numpy(), points)
|
| 348 |
+
interp_uv_coords = self.vt[interp_idxs, :] # [N, 3, 2]
|
| 349 |
+
# do bary interp first to get interp_uv_coord in high-reso uv space
|
| 350 |
+
target_uv_coords = th.sum(interp_uv_coords * th.from_numpy(barys)[..., None], dim=1).float()
|
| 351 |
+
# then directly sample from uv space
|
| 352 |
+
sampled_values = sample_uv(values_uv=uv_img.permute(2, 0, 1)[None, ...], uv_coords=target_uv_coords) # [1, count, c]
|
| 353 |
+
approx_values = sampled_values[0].reshape(num_pts, uv_img.shape[2])
|
| 354 |
+
return approx_values.numpy(), points
|
| 355 |
+
|
| 356 |
+
def vert_sample_uv(self, uv_img):
|
| 357 |
+
count = self.v.shape[0]
|
| 358 |
+
points = self.v.detach().cpu().numpy()
|
| 359 |
+
approx_values, _ = self.sample_uv_from_3dpts(points, uv_img)
|
| 360 |
+
return approx_values
|
| 361 |
+
|
| 362 |
+
|
| 363 |
+
def sample_uv(
|
| 364 |
+
values_uv,
|
| 365 |
+
uv_coords,
|
| 366 |
+
v2uv: Optional[th.Tensor] = None,
|
| 367 |
+
mode: str = "bilinear",
|
| 368 |
+
align_corners: bool = True,
|
| 369 |
+
flip_uvs: bool = False,
|
| 370 |
+
):
|
| 371 |
+
batch_size = values_uv.shape[0]
|
| 372 |
+
|
| 373 |
+
if flip_uvs:
|
| 374 |
+
uv_coords = uv_coords.clone()
|
| 375 |
+
uv_coords[:, 1] = 1.0 - uv_coords[:, 1]
|
| 376 |
+
|
| 377 |
+
# uv_coords_norm is [1, N, 1, 2] afterwards
|
| 378 |
+
uv_coords_norm = (uv_coords * 2.0 - 1.0)[np.newaxis, :, np.newaxis].expand(
|
| 379 |
+
batch_size, -1, -1, -1
|
| 380 |
+
)
|
| 381 |
+
# uv_shape = values_uv.shape[-2:]
|
| 382 |
+
# uv_max_shape_ind = uv_shape.index(max(uv_shape))
|
| 383 |
+
# uv_min_shape_ind = uv_shape.index(min(uv_shape))
|
| 384 |
+
# uv_ratio = uv_shape[uv_max_shape_ind] / uv_shape[uv_min_shape_ind]
|
| 385 |
+
# uv_coords_norm[..., uv_min_shape_ind] *= uv_ratio
|
| 386 |
+
|
| 387 |
+
values = (
|
| 388 |
+
F.grid_sample(values_uv, uv_coords_norm, align_corners=align_corners, mode=mode)
|
| 389 |
+
.squeeze(-1)
|
| 390 |
+
.permute((0, 2, 1))
|
| 391 |
+
)
|
| 392 |
+
|
| 393 |
+
if v2uv is not None:
|
| 394 |
+
values_duplicate = values[:, v2uv]
|
| 395 |
+
values = values_duplicate.mean(2)
|
| 396 |
+
|
| 397 |
+
return values
|
| 398 |
+
|
| 399 |
+
|
| 400 |
+
def values_to_uv(values, index_img, bary_img):
|
| 401 |
+
uv_size = index_img.shape
|
| 402 |
+
index_mask = th.all(index_img != -1, dim=-1)
|
| 403 |
+
idxs_flat = index_img[index_mask].to(th.int64)
|
| 404 |
+
bary_flat = bary_img[index_mask].to(th.float32)
|
| 405 |
+
# NOTE: here we assume
|
| 406 |
+
values_flat = th.sum(values[:, idxs_flat].permute(0, 3, 1, 2) * bary_flat, dim=-1)
|
| 407 |
+
values_uv = th.zeros(
|
| 408 |
+
values.shape[0],
|
| 409 |
+
values.shape[-1],
|
| 410 |
+
uv_size[0],
|
| 411 |
+
uv_size[1],
|
| 412 |
+
dtype=values.dtype,
|
| 413 |
+
device=values.device,
|
| 414 |
+
)
|
| 415 |
+
values_uv[:, :, index_mask] = values_flat
|
| 416 |
+
return values_uv
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
def face_normals(v, vi, eps: float = 1e-5):
|
| 420 |
+
pts = v[:, vi]
|
| 421 |
+
v0 = pts[:, :, 1] - pts[:, :, 0]
|
| 422 |
+
v1 = pts[:, :, 2] - pts[:, :, 0]
|
| 423 |
+
n = th.cross(v0, v1, dim=-1)
|
| 424 |
+
norm = th.norm(n, dim=-1, keepdim=True)
|
| 425 |
+
norm[norm < eps] = 1
|
| 426 |
+
n /= norm
|
| 427 |
+
return n
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
def vert_normals(v, vi, eps: float = 1.0e-5):
|
| 431 |
+
fnorms = face_normals(v, vi)
|
| 432 |
+
fnorms = fnorms[:, :, None].expand(-1, -1, 3, -1).reshape(fnorms.shape[0], -1, 3)
|
| 433 |
+
vi_flat = vi.view(1, -1).expand(v.shape[0], -1)
|
| 434 |
+
vnorms = th.zeros_like(v)
|
| 435 |
+
for j in range(3):
|
| 436 |
+
vnorms[..., j].scatter_add_(1, vi_flat, fnorms[..., j])
|
| 437 |
+
norm = th.norm(vnorms, dim=-1, keepdim=True)
|
| 438 |
+
norm[norm < eps] = 1
|
| 439 |
+
vnorms /= norm
|
| 440 |
+
return vnorms
|
| 441 |
+
|
| 442 |
+
|
| 443 |
+
def compute_view_cos(verts, faces, camera_pos):
|
| 444 |
+
vn = F.normalize(vert_normals(verts, faces), dim=-1)
|
| 445 |
+
v2c = F.normalize(verts - camera_pos[:, np.newaxis], dim=-1)
|
| 446 |
+
return th.einsum("bnd,bnd->bn", vn, v2c)
|
| 447 |
+
|
| 448 |
+
|
| 449 |
+
def compute_tbn(geom, vt, vi, vti):
|
| 450 |
+
"""Computes tangent, bitangent, and normal vectors given a mesh.
|
| 451 |
+
Args:
|
| 452 |
+
geom: [N, n_verts, 3] th.Tensor
|
| 453 |
+
Vertex positions.
|
| 454 |
+
vt: [n_uv_coords, 2] th.Tensor
|
| 455 |
+
UV coordinates.
|
| 456 |
+
vi: [..., 3] th.Tensor
|
| 457 |
+
Face vertex indices.
|
| 458 |
+
vti: [..., 3] th.Tensor
|
| 459 |
+
Face UV indices.
|
| 460 |
+
Returns:
|
| 461 |
+
[..., 3] th.Tensors for T, B, N.
|
| 462 |
+
"""
|
| 463 |
+
|
| 464 |
+
v0 = geom[:, vi[..., 0]]
|
| 465 |
+
v1 = geom[:, vi[..., 1]]
|
| 466 |
+
v2 = geom[:, vi[..., 2]]
|
| 467 |
+
vt0 = vt[vti[..., 0]]
|
| 468 |
+
vt1 = vt[vti[..., 1]]
|
| 469 |
+
vt2 = vt[vti[..., 2]]
|
| 470 |
+
|
| 471 |
+
v01 = v1 - v0
|
| 472 |
+
v02 = v2 - v0
|
| 473 |
+
vt01 = vt1 - vt0
|
| 474 |
+
vt02 = vt2 - vt0
|
| 475 |
+
f = 1.0 / (
|
| 476 |
+
vt01[None, ..., 0] * vt02[None, ..., 1]
|
| 477 |
+
- vt01[None, ..., 1] * vt02[None, ..., 0]
|
| 478 |
+
)
|
| 479 |
+
tangent = f[..., None] * th.stack(
|
| 480 |
+
[
|
| 481 |
+
v01[..., 0] * vt02[None, ..., 1] - v02[..., 0] * vt01[None, ..., 1],
|
| 482 |
+
v01[..., 1] * vt02[None, ..., 1] - v02[..., 1] * vt01[None, ..., 1],
|
| 483 |
+
v01[..., 2] * vt02[None, ..., 1] - v02[..., 2] * vt01[None, ..., 1],
|
| 484 |
+
],
|
| 485 |
+
dim=-1,
|
| 486 |
+
)
|
| 487 |
+
tangent = F.normalize(tangent, dim=-1)
|
| 488 |
+
normal = F.normalize(th.cross(v01, v02, dim=3), dim=-1)
|
| 489 |
+
bitangent = F.normalize(th.cross(tangent, normal, dim=3), dim=-1)
|
| 490 |
+
|
| 491 |
+
return tangent, bitangent, normal
|
| 492 |
+
|
| 493 |
+
|
| 494 |
+
def compute_v2uv(n_verts, vi, vti, n_max=4):
|
| 495 |
+
"""Computes mapping from vertex indices to texture indices.
|
| 496 |
+
|
| 497 |
+
Args:
|
| 498 |
+
vi: [F, 3], triangles
|
| 499 |
+
vti: [F, 3], texture triangles
|
| 500 |
+
n_max: int, max number of texture locations
|
| 501 |
+
|
| 502 |
+
Returns:
|
| 503 |
+
[n_verts, n_max], texture indices
|
| 504 |
+
"""
|
| 505 |
+
v2uv_dict = {}
|
| 506 |
+
for i_v, i_uv in zip(vi.reshape(-1), vti.reshape(-1)):
|
| 507 |
+
v2uv_dict.setdefault(i_v, set()).add(i_uv)
|
| 508 |
+
assert len(v2uv_dict) == n_verts
|
| 509 |
+
v2uv = np.zeros((n_verts, n_max), dtype=np.int32)
|
| 510 |
+
for i in range(n_verts):
|
| 511 |
+
vals = sorted(list(v2uv_dict[i]))
|
| 512 |
+
v2uv[i, :] = vals[0]
|
| 513 |
+
v2uv[i, : len(vals)] = np.array(vals)
|
| 514 |
+
return v2uv
|
| 515 |
+
|
| 516 |
+
|
| 517 |
+
def compute_neighbours(n_verts, vi, n_max_values=10):
|
| 518 |
+
"""Computes first-ring neighbours given vertices and faces."""
|
| 519 |
+
n_vi = vi.shape[0]
|
| 520 |
+
|
| 521 |
+
adj = {i: set() for i in range(n_verts)}
|
| 522 |
+
for i in range(n_vi):
|
| 523 |
+
for idx in vi[i]:
|
| 524 |
+
adj[idx] |= set(vi[i]) - set([idx])
|
| 525 |
+
|
| 526 |
+
nbs_idxs = np.tile(np.arange(n_verts)[:, np.newaxis], (1, n_max_values))
|
| 527 |
+
nbs_weights = np.zeros((n_verts, n_max_values), dtype=np.float32)
|
| 528 |
+
|
| 529 |
+
for idx in range(n_verts):
|
| 530 |
+
n_values = min(len(adj[idx]), n_max_values)
|
| 531 |
+
nbs_idxs[idx, :n_values] = np.array(list(adj[idx]))[:n_values]
|
| 532 |
+
nbs_weights[idx, :n_values] = -1.0 / n_values
|
| 533 |
+
|
| 534 |
+
return nbs_idxs, nbs_weights
|
| 535 |
+
|
| 536 |
+
|
| 537 |
+
def make_postex(v, idxim, barim):
|
| 538 |
+
return (
|
| 539 |
+
barim[None, :, :, 0, None] * v[:, idxim[:, :, 0]]
|
| 540 |
+
+ barim[None, :, :, 1, None] * v[:, idxim[:, :, 1]]
|
| 541 |
+
+ barim[None, :, :, 2, None] * v[:, idxim[:, :, 2]]
|
| 542 |
+
).permute(0, 3, 1, 2)
|
| 543 |
+
|
| 544 |
+
|
| 545 |
+
def matrix_to_axisangle(r):
|
| 546 |
+
th = th.arccos(0.5 * (r[..., 0, 0] + r[..., 1, 1] + r[..., 2, 2] - 1.0))[..., None]
|
| 547 |
+
vec = (
|
| 548 |
+
0.5
|
| 549 |
+
* th.stack(
|
| 550 |
+
[
|
| 551 |
+
r[..., 2, 1] - r[..., 1, 2],
|
| 552 |
+
r[..., 0, 2] - r[..., 2, 0],
|
| 553 |
+
r[..., 1, 0] - r[..., 0, 1],
|
| 554 |
+
],
|
| 555 |
+
dim=-1,
|
| 556 |
+
)
|
| 557 |
+
/ th.sin(th)
|
| 558 |
+
)
|
| 559 |
+
return th, vec
|
| 560 |
+
|
| 561 |
+
|
| 562 |
+
def axisangle_to_matrix(rvec):
|
| 563 |
+
theta = th.sqrt(1e-5 + th.sum(rvec**2, dim=-1))
|
| 564 |
+
rvec = rvec / theta[..., None]
|
| 565 |
+
costh = th.cos(theta)
|
| 566 |
+
sinth = th.sin(theta)
|
| 567 |
+
return th.stack(
|
| 568 |
+
(
|
| 569 |
+
th.stack(
|
| 570 |
+
(
|
| 571 |
+
rvec[..., 0] ** 2 + (1.0 - rvec[..., 0] ** 2) * costh,
|
| 572 |
+
rvec[..., 0] * rvec[..., 1] * (1.0 - costh) - rvec[..., 2] * sinth,
|
| 573 |
+
rvec[..., 0] * rvec[..., 2] * (1.0 - costh) + rvec[..., 1] * sinth,
|
| 574 |
+
),
|
| 575 |
+
dim=-1,
|
| 576 |
+
),
|
| 577 |
+
th.stack(
|
| 578 |
+
(
|
| 579 |
+
rvec[..., 0] * rvec[..., 1] * (1.0 - costh) + rvec[..., 2] * sinth,
|
| 580 |
+
rvec[..., 1] ** 2 + (1.0 - rvec[..., 1] ** 2) * costh,
|
| 581 |
+
rvec[..., 1] * rvec[..., 2] * (1.0 - costh) - rvec[..., 0] * sinth,
|
| 582 |
+
),
|
| 583 |
+
dim=-1,
|
| 584 |
+
),
|
| 585 |
+
th.stack(
|
| 586 |
+
(
|
| 587 |
+
rvec[..., 0] * rvec[..., 2] * (1.0 - costh) - rvec[..., 1] * sinth,
|
| 588 |
+
rvec[..., 1] * rvec[..., 2] * (1.0 - costh) + rvec[..., 0] * sinth,
|
| 589 |
+
rvec[..., 2] ** 2 + (1.0 - rvec[..., 2] ** 2) * costh,
|
| 590 |
+
),
|
| 591 |
+
dim=-1,
|
| 592 |
+
),
|
| 593 |
+
),
|
| 594 |
+
dim=-2,
|
| 595 |
+
)
|
| 596 |
+
|
| 597 |
+
|
| 598 |
+
def rotation_interp(r0, r1, alpha):
|
| 599 |
+
r0a = r0.view(-1, 3, 3)
|
| 600 |
+
r1a = r1.view(-1, 3, 3)
|
| 601 |
+
r = th.bmm(r0a.permute(0, 2, 1), r1a).view_as(r0)
|
| 602 |
+
|
| 603 |
+
th, rvec = matrix_to_axisangle(r)
|
| 604 |
+
rvec = rvec * (alpha * th)
|
| 605 |
+
|
| 606 |
+
r = axisangle_to_matrix(rvec)
|
| 607 |
+
return th.bmm(r0a, r.view(-1, 3, 3)).view_as(r0)
|
| 608 |
+
|
| 609 |
+
|
| 610 |
+
def convert_camera_parameters(Rt, K):
|
| 611 |
+
R = Rt[:, :3, :3]
|
| 612 |
+
t = -R.permute(0, 2, 1).bmm(Rt[:, :3, 3].unsqueeze(2)).squeeze(2)
|
| 613 |
+
return dict(
|
| 614 |
+
campos=t,
|
| 615 |
+
camrot=R,
|
| 616 |
+
focal=K[:, :2, :2],
|
| 617 |
+
princpt=K[:, :2, 2],
|
| 618 |
+
)
|
| 619 |
+
|
| 620 |
+
|
| 621 |
+
def project_points_multi(p, Rt, K, normalize=False, size=None):
|
| 622 |
+
"""Project a set of 3D points into multiple cameras with a pinhole model.
|
| 623 |
+
Args:
|
| 624 |
+
p: [B, N, 3], input 3D points in world coordinates
|
| 625 |
+
Rt: [B, NC, 3, 4], extrinsics (where NC is the number of cameras to project to)
|
| 626 |
+
K: [B, NC, 3, 3], intrinsics
|
| 627 |
+
normalize: bool, whether to normalize coordinates to [-1.0, 1.0]
|
| 628 |
+
Returns:
|
| 629 |
+
tuple:
|
| 630 |
+
- [B, NC, N, 2] - projected points
|
| 631 |
+
- [B, NC, N] - their
|
| 632 |
+
"""
|
| 633 |
+
B, N = p.shape[:2]
|
| 634 |
+
NC = Rt.shape[1]
|
| 635 |
+
|
| 636 |
+
Rt = Rt.reshape(B * NC, 3, 4)
|
| 637 |
+
K = K.reshape(B * NC, 3, 3)
|
| 638 |
+
|
| 639 |
+
# [B, N, 3] -> [B * NC, N, 3]
|
| 640 |
+
p = p[:, np.newaxis].expand(-1, NC, -1, -1).reshape(B * NC, -1, 3)
|
| 641 |
+
p_cam = p @ Rt[:, :3, :3].transpose(-2, -1) + Rt[:, :3, 3][:, np.newaxis]
|
| 642 |
+
p_pix = p_cam @ K.transpose(-2, -1)
|
| 643 |
+
p_depth = p_pix[:, :, 2:]
|
| 644 |
+
p_pix = (p_pix[..., :2] / p_depth).reshape(B, NC, N, 2)
|
| 645 |
+
p_depth = p_depth.reshape(B, NC, N)
|
| 646 |
+
|
| 647 |
+
if normalize:
|
| 648 |
+
assert size is not None
|
| 649 |
+
h, w = size
|
| 650 |
+
p_pix = (
|
| 651 |
+
2.0 * p_pix / th.as_tensor([w, h], dtype=th.float32, device=p.device) - 1.0
|
| 652 |
+
)
|
| 653 |
+
return p_pix, p_depth
|
dva/io.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import json
|
| 8 |
+
import cv2
|
| 9 |
+
import numpy as np
|
| 10 |
+
import copy
|
| 11 |
+
import importlib
|
| 12 |
+
from typing import Any, Dict
|
| 13 |
+
|
| 14 |
+
def load_module(module_name, class_name=None, silent: bool = False):
|
| 15 |
+
module = importlib.import_module(module_name)
|
| 16 |
+
return getattr(module, class_name) if class_name else module
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def load_class(class_name):
|
| 20 |
+
return load_module(*class_name.rsplit(".", 1))
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def load_from_config(config, **kwargs):
|
| 24 |
+
"""Instantiate an object given a config and arguments."""
|
| 25 |
+
assert "class_name" in config and "module_name" not in config
|
| 26 |
+
config = copy.deepcopy(config)
|
| 27 |
+
class_name = config.pop("class_name")
|
| 28 |
+
object_class = load_class(class_name)
|
| 29 |
+
return object_class(**config, **kwargs)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def load_opencv_calib(extrin_path, intrin_path):
|
| 33 |
+
cameras = {}
|
| 34 |
+
|
| 35 |
+
fse = cv2.FileStorage()
|
| 36 |
+
fse.open(extrin_path, cv2.FileStorage_READ)
|
| 37 |
+
|
| 38 |
+
fsi = cv2.FileStorage()
|
| 39 |
+
fsi.open(intrin_path, cv2.FileStorage_READ)
|
| 40 |
+
|
| 41 |
+
names = [
|
| 42 |
+
fse.getNode("names").at(c).string() for c in range(fse.getNode("names").size())
|
| 43 |
+
]
|
| 44 |
+
|
| 45 |
+
for camera in names:
|
| 46 |
+
rot = fse.getNode(f"R_{camera}").mat()
|
| 47 |
+
R = fse.getNode(f"Rot_{camera}").mat()
|
| 48 |
+
T = fse.getNode(f"T_{camera}").mat()
|
| 49 |
+
R_pred = cv2.Rodrigues(rot)[0]
|
| 50 |
+
assert np.all(np.isclose(R_pred, R))
|
| 51 |
+
K = fsi.getNode(f"K_{camera}").mat()
|
| 52 |
+
cameras[camera] = {
|
| 53 |
+
"Rt": np.concatenate([R, T], axis=1).astype(np.float32),
|
| 54 |
+
"K": K.astype(np.float32),
|
| 55 |
+
}
|
| 56 |
+
return cameras
|
dva/layers.py
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import torch as th
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
|
| 12 |
+
from dva.mvp.models.utils import Conv2dWN, Conv2dWNUB, ConvTranspose2dWNUB, initmod
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class ConvBlock(nn.Module):
|
| 16 |
+
def __init__(
|
| 17 |
+
self,
|
| 18 |
+
in_channels,
|
| 19 |
+
out_channels,
|
| 20 |
+
size,
|
| 21 |
+
lrelu_slope=0.2,
|
| 22 |
+
kernel_size=3,
|
| 23 |
+
padding=1,
|
| 24 |
+
wnorm_dim=0,
|
| 25 |
+
):
|
| 26 |
+
super().__init__()
|
| 27 |
+
|
| 28 |
+
self.conv_resize = Conv2dWN(in_channels, out_channels, kernel_size=1)
|
| 29 |
+
self.conv1 = Conv2dWNUB(
|
| 30 |
+
in_channels,
|
| 31 |
+
in_channels,
|
| 32 |
+
kernel_size=kernel_size,
|
| 33 |
+
padding=padding,
|
| 34 |
+
height=size,
|
| 35 |
+
width=size,
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
self.lrelu1 = nn.LeakyReLU(lrelu_slope)
|
| 39 |
+
self.conv2 = Conv2dWNUB(
|
| 40 |
+
in_channels,
|
| 41 |
+
out_channels,
|
| 42 |
+
kernel_size=kernel_size,
|
| 43 |
+
padding=padding,
|
| 44 |
+
height=size,
|
| 45 |
+
width=size,
|
| 46 |
+
)
|
| 47 |
+
self.lrelu2 = nn.LeakyReLU(lrelu_slope)
|
| 48 |
+
|
| 49 |
+
def forward(self, x):
|
| 50 |
+
x_skip = self.conv_resize(x)
|
| 51 |
+
x = self.conv1(x)
|
| 52 |
+
x = self.lrelu1(x)
|
| 53 |
+
x = self.conv2(x)
|
| 54 |
+
x = self.lrelu2(x)
|
| 55 |
+
return x + x_skip
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def tile2d(x, size: int):
|
| 59 |
+
"""Tile a given set of features into a convolutional map.
|
| 60 |
+
|
| 61 |
+
Args:
|
| 62 |
+
x: float tensor of shape [N, F]
|
| 63 |
+
size: int or a tuple
|
| 64 |
+
|
| 65 |
+
Returns:
|
| 66 |
+
a feature map [N, F, size[0], size[1]]
|
| 67 |
+
"""
|
| 68 |
+
# size = size if isinstance(size, tuple) else (size, size)
|
| 69 |
+
# NOTE: expecting only int here (!!!)
|
| 70 |
+
return x[:, :, np.newaxis, np.newaxis].expand(-1, -1, size, size)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def weights_initializer(m, alpha: float = 1.0):
|
| 74 |
+
return initmod(m, nn.init.calculate_gain("leaky_relu", alpha))
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class UNetWB(nn.Module):
|
| 78 |
+
def __init__(
|
| 79 |
+
self,
|
| 80 |
+
in_channels,
|
| 81 |
+
out_channels,
|
| 82 |
+
size,
|
| 83 |
+
n_init_ftrs=8,
|
| 84 |
+
out_scale=0.1,
|
| 85 |
+
):
|
| 86 |
+
# super().__init__(*args, **kwargs)
|
| 87 |
+
super().__init__()
|
| 88 |
+
|
| 89 |
+
self.out_scale = 0.1
|
| 90 |
+
|
| 91 |
+
F = n_init_ftrs
|
| 92 |
+
|
| 93 |
+
# TODO: allow changing the size?
|
| 94 |
+
self.size = size
|
| 95 |
+
|
| 96 |
+
self.down1 = nn.Sequential(
|
| 97 |
+
Conv2dWNUB(in_channels, F, self.size // 2, self.size // 2, 4, 2, 1),
|
| 98 |
+
nn.LeakyReLU(0.2),
|
| 99 |
+
)
|
| 100 |
+
self.down2 = nn.Sequential(
|
| 101 |
+
Conv2dWNUB(F, 2 * F, self.size // 4, self.size // 4, 4, 2, 1),
|
| 102 |
+
nn.LeakyReLU(0.2),
|
| 103 |
+
)
|
| 104 |
+
self.down3 = nn.Sequential(
|
| 105 |
+
Conv2dWNUB(2 * F, 4 * F, self.size // 8, self.size // 8, 4, 2, 1),
|
| 106 |
+
nn.LeakyReLU(0.2),
|
| 107 |
+
)
|
| 108 |
+
self.down4 = nn.Sequential(
|
| 109 |
+
Conv2dWNUB(4 * F, 8 * F, self.size // 16, self.size // 16, 4, 2, 1),
|
| 110 |
+
nn.LeakyReLU(0.2),
|
| 111 |
+
)
|
| 112 |
+
self.down5 = nn.Sequential(
|
| 113 |
+
Conv2dWNUB(8 * F, 16 * F, self.size // 32, self.size // 32, 4, 2, 1),
|
| 114 |
+
nn.LeakyReLU(0.2),
|
| 115 |
+
)
|
| 116 |
+
self.up1 = nn.Sequential(
|
| 117 |
+
ConvTranspose2dWNUB(
|
| 118 |
+
16 * F, 8 * F, self.size // 16, self.size // 16, 4, 2, 1
|
| 119 |
+
),
|
| 120 |
+
nn.LeakyReLU(0.2),
|
| 121 |
+
)
|
| 122 |
+
self.up2 = nn.Sequential(
|
| 123 |
+
ConvTranspose2dWNUB(8 * F, 4 * F, self.size // 8, self.size // 8, 4, 2, 1),
|
| 124 |
+
nn.LeakyReLU(0.2),
|
| 125 |
+
)
|
| 126 |
+
self.up3 = nn.Sequential(
|
| 127 |
+
ConvTranspose2dWNUB(4 * F, 2 * F, self.size // 4, self.size // 4, 4, 2, 1),
|
| 128 |
+
nn.LeakyReLU(0.2),
|
| 129 |
+
)
|
| 130 |
+
self.up4 = nn.Sequential(
|
| 131 |
+
ConvTranspose2dWNUB(2 * F, F, self.size // 2, self.size // 2, 4, 2, 1),
|
| 132 |
+
nn.LeakyReLU(0.2),
|
| 133 |
+
)
|
| 134 |
+
self.up5 = nn.Sequential(
|
| 135 |
+
ConvTranspose2dWNUB(F, F, self.size, self.size, 4, 2, 1), nn.LeakyReLU(0.2)
|
| 136 |
+
)
|
| 137 |
+
self.out = Conv2dWNUB(
|
| 138 |
+
F + in_channels, out_channels, self.size, self.size, kernel_size=1
|
| 139 |
+
)
|
| 140 |
+
self.apply(lambda x: initmod(x, 0.2))
|
| 141 |
+
initmod(self.out, 1.0)
|
| 142 |
+
|
| 143 |
+
def forward(self, x):
|
| 144 |
+
x1 = x
|
| 145 |
+
x2 = self.down1(x1)
|
| 146 |
+
x3 = self.down2(x2)
|
| 147 |
+
x4 = self.down3(x3)
|
| 148 |
+
x5 = self.down4(x4)
|
| 149 |
+
x6 = self.down5(x5)
|
| 150 |
+
# TODO: switch to concat?
|
| 151 |
+
x = self.up1(x6) + x5
|
| 152 |
+
x = self.up2(x) + x4
|
| 153 |
+
x = self.up3(x) + x3
|
| 154 |
+
x = self.up4(x) + x2
|
| 155 |
+
x = self.up5(x)
|
| 156 |
+
x = th.cat([x, x1], dim=1)
|
| 157 |
+
return self.out(x) * self.out_scale
|
dva/losses.py
ADDED
|
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch as th
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
import logging
|
| 12 |
+
|
| 13 |
+
from .vgg import VGGLossMasked
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger("dva.{__name__}")
|
| 16 |
+
|
| 17 |
+
class DCTLoss(nn.Module):
|
| 18 |
+
def __init__(self, weights):
|
| 19 |
+
super().__init__()
|
| 20 |
+
self.weights = weights
|
| 21 |
+
|
| 22 |
+
def forward(self, inputs, preds, iteration=None):
|
| 23 |
+
loss_dict = {"loss_total": 0.0}
|
| 24 |
+
target = inputs['gt']
|
| 25 |
+
recon = preds['recon']
|
| 26 |
+
posterior = preds['posterior']
|
| 27 |
+
fft_gt = th.view_as_real(th.fft.fft(target.reshape(target.shape[0], -1)))
|
| 28 |
+
fft_recon = th.view_as_real(th.fft.fft(recon.reshape(recon.shape[0], -1)))
|
| 29 |
+
loss_recon_dct_l1 = th.mean(th.abs(fft_gt - fft_recon))
|
| 30 |
+
loss_recon_l1 = th.mean(th.abs(target - recon))
|
| 31 |
+
loss_kl = posterior.kl().mean()
|
| 32 |
+
loss_dict.update(loss_recon_l1=loss_recon_l1, loss_recon_dct_l1=loss_recon_dct_l1, loss_kl=loss_kl)
|
| 33 |
+
loss_total = self.weights.recon * loss_recon_dct_l1 + self.weights.kl * loss_kl
|
| 34 |
+
|
| 35 |
+
loss_dict["loss_total"] = loss_total
|
| 36 |
+
return loss_total, loss_dict
|
| 37 |
+
|
| 38 |
+
class VAESepL2Loss(nn.Module):
|
| 39 |
+
def __init__(self, weights):
|
| 40 |
+
super().__init__()
|
| 41 |
+
self.weights = weights
|
| 42 |
+
|
| 43 |
+
def forward(self, inputs, preds, iteration=None):
|
| 44 |
+
loss_dict = {"loss_total": 0.0}
|
| 45 |
+
target = inputs['gt']
|
| 46 |
+
recon = preds['recon']
|
| 47 |
+
posterior = preds['posterior']
|
| 48 |
+
recon_diff = (target - recon) ** 2
|
| 49 |
+
loss_recon_sdf_l1 = th.mean(recon_diff[:, 0:1, ...])
|
| 50 |
+
loss_recon_rgb_l1 = th.mean(recon_diff[:, 1:4, ...])
|
| 51 |
+
loss_recon_mat_l1 = th.mean(recon_diff[:, 4:6, ...])
|
| 52 |
+
loss_kl = posterior.kl().mean()
|
| 53 |
+
loss_dict.update(loss_sdf_l1=loss_recon_sdf_l1, loss_rgb_l1=loss_recon_rgb_l1, loss_mat_l1=loss_recon_mat_l1, loss_kl=loss_kl)
|
| 54 |
+
loss_total = self.weights.sdf * loss_recon_sdf_l1 + self.weights.rgb * loss_recon_rgb_l1 + self.weights.mat * loss_recon_mat_l1
|
| 55 |
+
if "kl" in self.weights:
|
| 56 |
+
loss_total += self.weights.kl * loss_kl
|
| 57 |
+
|
| 58 |
+
loss_dict["loss_total"] = loss_total
|
| 59 |
+
return loss_total, loss_dict
|
| 60 |
+
|
| 61 |
+
class VAESepLoss(nn.Module):
|
| 62 |
+
def __init__(self, weights):
|
| 63 |
+
super().__init__()
|
| 64 |
+
self.weights = weights
|
| 65 |
+
|
| 66 |
+
def forward(self, inputs, preds, iteration=None):
|
| 67 |
+
loss_dict = {"loss_total": 0.0}
|
| 68 |
+
target = inputs['gt']
|
| 69 |
+
recon = preds['recon']
|
| 70 |
+
posterior = preds['posterior']
|
| 71 |
+
recon_diff = th.abs(target - recon)
|
| 72 |
+
loss_recon_sdf_l1 = th.mean(recon_diff[:, 0:1, ...])
|
| 73 |
+
loss_recon_rgb_l1 = th.mean(recon_diff[:, 1:4, ...])
|
| 74 |
+
loss_recon_mat_l1 = th.mean(recon_diff[:, 4:6, ...])
|
| 75 |
+
loss_kl = posterior.kl().mean()
|
| 76 |
+
loss_dict.update(loss_sdf_l1=loss_recon_sdf_l1, loss_rgb_l1=loss_recon_rgb_l1, loss_mat_l1=loss_recon_mat_l1, loss_kl=loss_kl)
|
| 77 |
+
loss_total = self.weights.sdf * loss_recon_sdf_l1 + self.weights.rgb * loss_recon_rgb_l1 + self.weights.mat * loss_recon_mat_l1
|
| 78 |
+
if "kl" in self.weights:
|
| 79 |
+
loss_total += self.weights.kl * loss_kl
|
| 80 |
+
|
| 81 |
+
loss_dict["loss_total"] = loss_total
|
| 82 |
+
return loss_total, loss_dict
|
| 83 |
+
|
| 84 |
+
class VAELoss(nn.Module):
|
| 85 |
+
def __init__(self, weights):
|
| 86 |
+
super().__init__()
|
| 87 |
+
self.weights = weights
|
| 88 |
+
|
| 89 |
+
def forward(self, inputs, preds, iteration=None):
|
| 90 |
+
loss_dict = {"loss_total": 0.0}
|
| 91 |
+
target = inputs['gt']
|
| 92 |
+
recon = preds['recon']
|
| 93 |
+
posterior = preds['posterior']
|
| 94 |
+
loss_recon_l1 = th.mean(th.abs(target - recon))
|
| 95 |
+
loss_kl = posterior.kl().mean()
|
| 96 |
+
loss_dict.update(loss_recon_l1=loss_recon_l1, loss_kl=loss_kl)
|
| 97 |
+
loss_total = self.weights.recon * loss_recon_l1 + self.weights.kl * loss_kl
|
| 98 |
+
|
| 99 |
+
loss_dict["loss_total"] = loss_total
|
| 100 |
+
return loss_total, loss_dict
|
| 101 |
+
|
| 102 |
+
class PrimSDFLoss(nn.Module):
|
| 103 |
+
def __init__(self, weights, shape_opt_steps=2000, tex_opt_steps=6000):
|
| 104 |
+
super().__init__()
|
| 105 |
+
self.weights = weights
|
| 106 |
+
self.shape_opt_steps = shape_opt_steps
|
| 107 |
+
self.tex_opt_steps = tex_opt_steps
|
| 108 |
+
|
| 109 |
+
def forward(self, inputs, preds, iteration=None):
|
| 110 |
+
loss_dict = {"loss_total": 0.0}
|
| 111 |
+
|
| 112 |
+
if iteration < self.shape_opt_steps:
|
| 113 |
+
target_sdf = inputs['sdf']
|
| 114 |
+
sdf = preds['sdf']
|
| 115 |
+
loss_sdf_l1 = th.mean(th.abs(sdf - target_sdf))
|
| 116 |
+
loss_dict.update(loss_sdf_l1=loss_sdf_l1)
|
| 117 |
+
loss_total = self.weights.sdf_l1 * loss_sdf_l1
|
| 118 |
+
|
| 119 |
+
prim_scale = preds["prim_scale"]
|
| 120 |
+
# we use 1/scale instead of the original 100/scale as our scale is normalized to [-1, 1] cube
|
| 121 |
+
if "vol_sum" in self.weights:
|
| 122 |
+
loss_prim_vol_sum = th.mean(th.sum(th.prod(1 / prim_scale, dim=-1), dim=-1))
|
| 123 |
+
loss_dict.update(loss_prim_vol_sum=loss_prim_vol_sum)
|
| 124 |
+
loss_total += self.weights.vol_sum * loss_prim_vol_sum
|
| 125 |
+
|
| 126 |
+
if iteration >= self.shape_opt_steps and iteration < self.tex_opt_steps:
|
| 127 |
+
target_tex = inputs['tex']
|
| 128 |
+
tex = preds['tex']
|
| 129 |
+
loss_tex_l1 = th.mean(th.abs(tex - target_tex))
|
| 130 |
+
loss_dict.update(loss_tex_l1=loss_tex_l1)
|
| 131 |
+
|
| 132 |
+
loss_total = (
|
| 133 |
+
self.weights.rgb_l1 * loss_tex_l1
|
| 134 |
+
)
|
| 135 |
+
if "mat_l1" in self.weights:
|
| 136 |
+
target_mat = inputs['mat']
|
| 137 |
+
mat = preds['mat']
|
| 138 |
+
loss_mat_l1 = th.mean(th.abs(mat - target_mat))
|
| 139 |
+
loss_dict.update(loss_mat_l1=loss_mat_l1)
|
| 140 |
+
loss_total += self.weights.mat_l1 * loss_mat_l1
|
| 141 |
+
|
| 142 |
+
if "grad_l2" in self.weights:
|
| 143 |
+
loss_grad_l2 = th.mean((preds["grad"] - inputs["grad"]) ** 2)
|
| 144 |
+
loss_total += self.weights.grad_l2 * loss_grad_l2
|
| 145 |
+
loss_dict.update(loss_grad_l2=loss_grad_l2)
|
| 146 |
+
|
| 147 |
+
loss_dict["loss_total"] = loss_total
|
| 148 |
+
return loss_total, loss_dict
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
class TotalMVPLoss(nn.Module):
|
| 152 |
+
def __init__(self, weights, assets=None):
|
| 153 |
+
super().__init__()
|
| 154 |
+
|
| 155 |
+
self.weights = weights
|
| 156 |
+
|
| 157 |
+
if "vgg" in self.weights:
|
| 158 |
+
self.vgg_loss = VGGLossMasked()
|
| 159 |
+
|
| 160 |
+
def forward(self, inputs, preds, iteration=None):
|
| 161 |
+
|
| 162 |
+
loss_dict = {"loss_total": 0.0}
|
| 163 |
+
|
| 164 |
+
B = inputs["image"].shape
|
| 165 |
+
|
| 166 |
+
# rgb
|
| 167 |
+
target_rgb = inputs["image"].permute(0, 2, 3, 1)
|
| 168 |
+
# removing the mask
|
| 169 |
+
target_rgb = target_rgb * inputs["image_mask"][:, 0, :, :, np.newaxis]
|
| 170 |
+
|
| 171 |
+
rgb = preds["rgb"]
|
| 172 |
+
loss_rgb_mse = th.mean(((rgb - target_rgb) / 16.0) ** 2.0)
|
| 173 |
+
loss_dict.update(loss_rgb_mse=loss_rgb_mse)
|
| 174 |
+
|
| 175 |
+
alpha = preds["alpha"]
|
| 176 |
+
|
| 177 |
+
# mask loss
|
| 178 |
+
target_mask = inputs["image_mask"][:, 0].to(th.float32)
|
| 179 |
+
loss_mask_mae = th.mean((target_mask - alpha).abs())
|
| 180 |
+
loss_dict.update(loss_mask_mae=loss_mask_mae)
|
| 181 |
+
|
| 182 |
+
B = alpha.shape[0]
|
| 183 |
+
|
| 184 |
+
# beta prior on opacity
|
| 185 |
+
loss_alpha_prior = th.mean(
|
| 186 |
+
th.log(0.1 + alpha.reshape(B, -1))
|
| 187 |
+
+ th.log(0.1 + 1.0 - alpha.reshape(B, -1))
|
| 188 |
+
- -2.20727
|
| 189 |
+
)
|
| 190 |
+
loss_dict.update(loss_alpha_prior=loss_alpha_prior)
|
| 191 |
+
|
| 192 |
+
prim_scale = preds["prim_scale"]
|
| 193 |
+
loss_prim_vol_sum = th.mean(th.sum(th.prod(100.0 / prim_scale, dim=-1), dim=-1))
|
| 194 |
+
loss_dict.update(loss_prim_vol_sum=loss_prim_vol_sum)
|
| 195 |
+
|
| 196 |
+
loss_total = (
|
| 197 |
+
self.weights.rgb_mse * loss_rgb_mse
|
| 198 |
+
+ self.weights.mask_mae * loss_mask_mae
|
| 199 |
+
+ self.weights.alpha_prior * loss_alpha_prior
|
| 200 |
+
+ self.weights.prim_vol_sum * loss_prim_vol_sum
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
if "embs_l2" in self.weights:
|
| 204 |
+
loss_embs_l2 = th.sum(th.norm(preds["embs"], dim=1))
|
| 205 |
+
loss_total += self.weights.embs_l2 * loss_embs_l2
|
| 206 |
+
loss_dict.update(loss_embs_l2=loss_embs_l2)
|
| 207 |
+
|
| 208 |
+
if "vgg" in self.weights:
|
| 209 |
+
loss_vgg = self.vgg_loss(
|
| 210 |
+
rgb.permute(0, 3, 1, 2),
|
| 211 |
+
target_rgb.permute(0, 3, 1, 2),
|
| 212 |
+
inputs["image_mask"],
|
| 213 |
+
)
|
| 214 |
+
loss_total += self.weights.vgg * loss_vgg
|
| 215 |
+
loss_dict.update(loss_vgg=loss_vgg)
|
| 216 |
+
|
| 217 |
+
if "prim_scale_var" in self.weights:
|
| 218 |
+
log_prim_scale = th.log(prim_scale)
|
| 219 |
+
# NOTE: should we detach this?
|
| 220 |
+
log_prim_scale_mean = th.mean(log_prim_scale, dim=1, keepdim=True)
|
| 221 |
+
loss_prim_scale_var = th.mean((log_prim_scale - log_prim_scale_mean) ** 2.0)
|
| 222 |
+
loss_total += self.weights.prim_scale_var * loss_prim_scale_var
|
| 223 |
+
loss_dict.update(loss_prim_scale_var=loss_prim_scale_var)
|
| 224 |
+
|
| 225 |
+
loss_dict["loss_total"] = loss_total
|
| 226 |
+
|
| 227 |
+
return loss_total, loss_dict
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def process_losses(loss_dict, reduce=True, detach=True):
|
| 231 |
+
"""Preprocess the dict of losses outputs."""
|
| 232 |
+
result = {
|
| 233 |
+
k.replace("loss_", ""): v for k, v in loss_dict.items() if k.startswith("loss_")
|
| 234 |
+
}
|
| 235 |
+
if detach:
|
| 236 |
+
result = {k: v.detach() for k, v in result.items()}
|
| 237 |
+
if reduce:
|
| 238 |
+
result = {k: float(v.mean().item()) for k, v in result.items()}
|
| 239 |
+
return result
|
dva/mvp/extensions/mvpraymarch/bvh.cu
ADDED
|
@@ -0,0 +1,292 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
//
|
| 4 |
+
// This source code is licensed under the license found in the
|
| 5 |
+
// LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
#include <cmath>
|
| 8 |
+
#include <cstdio>
|
| 9 |
+
#include <functional>
|
| 10 |
+
#include <map>
|
| 11 |
+
|
| 12 |
+
#include "helper_math.h"
|
| 13 |
+
|
| 14 |
+
#include "cudadispatch.h"
|
| 15 |
+
|
| 16 |
+
#include "primtransf.h"
|
| 17 |
+
|
| 18 |
+
// Expands a 10-bit integer into 30 bits
|
| 19 |
+
// by inserting 2 zeros after each bit.
|
| 20 |
+
__device__ unsigned int expand_bits(unsigned int v) {
|
| 21 |
+
v = (v * 0x00010001u) & 0xFF0000FFu;
|
| 22 |
+
v = (v * 0x00000101u) & 0x0F00F00Fu;
|
| 23 |
+
v = (v * 0x00000011u) & 0xC30C30C3u;
|
| 24 |
+
v = (v * 0x00000005u) & 0x49249249u;
|
| 25 |
+
return v;
|
| 26 |
+
}
|
| 27 |
+
|
| 28 |
+
// Calculates a 30-bit Morton code for the
|
| 29 |
+
// given 3D point located within the unit cube [0,1].
|
| 30 |
+
__device__ unsigned int morton3D(float x, float y, float z) {
|
| 31 |
+
x = fminf(fmaxf(x * 1024.0f, 0.0f), 1023.0f);
|
| 32 |
+
y = fminf(fmaxf(y * 1024.0f, 0.0f), 1023.0f);
|
| 33 |
+
z = fminf(fmaxf(z * 1024.0f, 0.0f), 1023.0f);
|
| 34 |
+
unsigned int xx = expand_bits((unsigned int)x);
|
| 35 |
+
unsigned int yy = expand_bits((unsigned int)y);
|
| 36 |
+
unsigned int zz = expand_bits((unsigned int)z);
|
| 37 |
+
return xx * 4 + yy * 2 + zz;
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
+
template<typename PrimTransfT>
|
| 41 |
+
__global__ void compute_morton_kernel(
|
| 42 |
+
int N, int K,
|
| 43 |
+
typename PrimTransfT::Data data,
|
| 44 |
+
int * code
|
| 45 |
+
) {
|
| 46 |
+
const int count = N * K;
|
| 47 |
+
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < count; index += blockDim.x * gridDim.x) {
|
| 48 |
+
const int k = index % K;
|
| 49 |
+
const int n = index / K;
|
| 50 |
+
|
| 51 |
+
//float4 c = center[n * K + k];
|
| 52 |
+
float3 c = data.get_center(n, k);
|
| 53 |
+
code[n * K + k] = morton3D(c.x, c.y, c.z);
|
| 54 |
+
}
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
__forceinline__ __device__ int delta(int* sortedcodes, int x, int y, int K) {
|
| 58 |
+
if (x >= 0 && x <= K - 1 && y >= 0 && y <= K - 1) {
|
| 59 |
+
return sortedcodes[x] == sortedcodes[y] ?
|
| 60 |
+
32 + __clz(x ^ y) :
|
| 61 |
+
__clz(sortedcodes[x] ^ sortedcodes[y]);
|
| 62 |
+
}
|
| 63 |
+
return -1;
|
| 64 |
+
}
|
| 65 |
+
|
| 66 |
+
__forceinline__ __device__ int sign(int x) {
|
| 67 |
+
return (int)(x > 0) - (int)(x < 0);
|
| 68 |
+
}
|
| 69 |
+
|
| 70 |
+
__device__ int find_split(
|
| 71 |
+
int* sortedcodes,
|
| 72 |
+
int first,
|
| 73 |
+
int last,
|
| 74 |
+
int K) {
|
| 75 |
+
float commonPrefix = delta(sortedcodes, first, last, K);
|
| 76 |
+
int split = first;
|
| 77 |
+
int step = last - first;
|
| 78 |
+
|
| 79 |
+
do {
|
| 80 |
+
step = (step + 1) >> 1; // exponential decrease
|
| 81 |
+
int newSplit = split + step; // proposed new position
|
| 82 |
+
|
| 83 |
+
if (newSplit < last) {
|
| 84 |
+
int splitPrefix = delta(sortedcodes, first, newSplit, K);
|
| 85 |
+
if (splitPrefix > commonPrefix) {
|
| 86 |
+
split = newSplit; // accept proposal
|
| 87 |
+
}
|
| 88 |
+
}
|
| 89 |
+
} while (step > 1);
|
| 90 |
+
|
| 91 |
+
return split;
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
__device__ int2 determine_range(int* sortedcodes, int K, int idx) {
|
| 95 |
+
int d = sign(delta(sortedcodes, idx, idx + 1, K) - delta(sortedcodes, idx, idx - 1, K));
|
| 96 |
+
int dmin = delta(sortedcodes, idx, idx - d, K);
|
| 97 |
+
int lmax = 2;
|
| 98 |
+
while (delta(sortedcodes, idx, idx + lmax * d, K) > dmin) {
|
| 99 |
+
lmax = lmax * 2;
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
int l = 0;
|
| 103 |
+
for (int t = lmax / 2; t >= 1; t /= 2) {
|
| 104 |
+
if (delta(sortedcodes, idx, idx + (l + t)*d, K) > dmin) {
|
| 105 |
+
l += t;
|
| 106 |
+
}
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
int j = idx + l*d;
|
| 110 |
+
int2 range;
|
| 111 |
+
range.x = min(idx, j);
|
| 112 |
+
range.y = max(idx, j);
|
| 113 |
+
|
| 114 |
+
return range;
|
| 115 |
+
}
|
| 116 |
+
|
| 117 |
+
__global__ void build_tree_kernel(
|
| 118 |
+
int N, int K,
|
| 119 |
+
int * sortedcodes,
|
| 120 |
+
int2 * nodechildren,
|
| 121 |
+
int * nodeparent) {
|
| 122 |
+
const int count = N * (K + K - 1);
|
| 123 |
+
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < count; index += blockDim.x * gridDim.x) {
|
| 124 |
+
const int k = index % (K + K - 1);
|
| 125 |
+
const int n = index / (K + K - 1);
|
| 126 |
+
|
| 127 |
+
if (k >= K - 1) {
|
| 128 |
+
// leaf
|
| 129 |
+
nodechildren[n * (K + K - 1) + k] = make_int2(-(k - (K - 1)) - 1, -(k - (K - 1)) - 2);
|
| 130 |
+
} else {
|
| 131 |
+
// internal node
|
| 132 |
+
|
| 133 |
+
// find out which range of objects the node corresponds to
|
| 134 |
+
int2 range = determine_range(sortedcodes + n * K, K, k);
|
| 135 |
+
int first = range.x;
|
| 136 |
+
int last = range.y;
|
| 137 |
+
|
| 138 |
+
// determine where to split the range
|
| 139 |
+
int split = find_split(sortedcodes + n * K, first, last, K);
|
| 140 |
+
|
| 141 |
+
// select childA
|
| 142 |
+
int childa = split == first ? (K - 1) + split : split;
|
| 143 |
+
|
| 144 |
+
// select childB
|
| 145 |
+
int childb = split + 1 == last ? (K - 1) + split + 1 : split + 1;
|
| 146 |
+
|
| 147 |
+
// record parent-child relationships
|
| 148 |
+
nodechildren[n * (K + K - 1) + k] = make_int2(childa, childb);
|
| 149 |
+
nodeparent[n * (K + K - 1) + childa] = k;
|
| 150 |
+
nodeparent[n * (K + K - 1) + childb] = k;
|
| 151 |
+
}
|
| 152 |
+
}
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
template<typename PrimTransfT>
|
| 156 |
+
__global__ void compute_aabb_kernel(
|
| 157 |
+
int N, int K,
|
| 158 |
+
typename PrimTransfT::Data data,
|
| 159 |
+
int * sortedobjid,
|
| 160 |
+
int2 * nodechildren,
|
| 161 |
+
int * nodeparent,
|
| 162 |
+
float3 * nodeaabb,
|
| 163 |
+
int * atom) {
|
| 164 |
+
const int count = N * K;
|
| 165 |
+
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < count; index += blockDim.x * gridDim.x) {
|
| 166 |
+
const int k = index % K;
|
| 167 |
+
const int n = index / K;
|
| 168 |
+
|
| 169 |
+
// compute BBOX for leaf
|
| 170 |
+
int kk = sortedobjid[n * K + k];
|
| 171 |
+
|
| 172 |
+
float3 pmin;
|
| 173 |
+
float3 pmax;
|
| 174 |
+
data.compute_aabb(n, kk, pmin, pmax);
|
| 175 |
+
|
| 176 |
+
nodeaabb[n * (K + K - 1) * 2 + ((K - 1) + k) * 2 + 0] = pmin;
|
| 177 |
+
nodeaabb[n * (K + K - 1) * 2 + ((K - 1) + k) * 2 + 1] = pmax;
|
| 178 |
+
|
| 179 |
+
int node = nodeparent[n * (K + K - 1) + ((K - 1) + k)];
|
| 180 |
+
|
| 181 |
+
while (node != -1 && atomicCAS(&atom[n * (K - 1) + node], 0, 1) == 1) {
|
| 182 |
+
int2 children = nodechildren[n * (K + K - 1) + node];
|
| 183 |
+
float3 laabbmin = nodeaabb[n * (K + K - 1) * 2 + children.x * 2 + 0];
|
| 184 |
+
float3 laabbmax = nodeaabb[n * (K + K - 1) * 2 + children.x * 2 + 1];
|
| 185 |
+
float3 raabbmin = nodeaabb[n * (K + K - 1) * 2 + children.y * 2 + 0];
|
| 186 |
+
float3 raabbmax = nodeaabb[n * (K + K - 1) * 2 + children.y * 2 + 1];
|
| 187 |
+
|
| 188 |
+
float3 aabbmin = fminf(laabbmin, raabbmin);
|
| 189 |
+
float3 aabbmax = fmaxf(laabbmax, raabbmax);
|
| 190 |
+
|
| 191 |
+
nodeaabb[n * (K + K - 1) * 2 + node * 2 + 0] = aabbmin;
|
| 192 |
+
nodeaabb[n * (K + K - 1) * 2 + node * 2 + 1] = aabbmax;
|
| 193 |
+
|
| 194 |
+
node = nodeparent[n * (K + K - 1) + node];
|
| 195 |
+
|
| 196 |
+
__threadfence();
|
| 197 |
+
}
|
| 198 |
+
}
|
| 199 |
+
}
|
| 200 |
+
|
| 201 |
+
void compute_morton_cuda(
|
| 202 |
+
int N, int K,
|
| 203 |
+
float * primpos,
|
| 204 |
+
int * code,
|
| 205 |
+
int algorithm,
|
| 206 |
+
cudaStream_t stream) {
|
| 207 |
+
int count = N * K;
|
| 208 |
+
int blocksize = 512;
|
| 209 |
+
int gridsize = (count + blocksize - 1) / blocksize;
|
| 210 |
+
|
| 211 |
+
std::shared_ptr<PrimTransfDataBase> primtransf_data;
|
| 212 |
+
primtransf_data = std::make_shared<PrimTransfSRT::Data>(PrimTransfSRT::Data{
|
| 213 |
+
PrimTransfDataBase{},
|
| 214 |
+
K, (float3*)primpos, nullptr,
|
| 215 |
+
K * 3, nullptr, nullptr,
|
| 216 |
+
K, nullptr, nullptr});
|
| 217 |
+
|
| 218 |
+
std::map<int, std::function<void(dim3, dim3, cudaStream_t, int, int, std::shared_ptr<PrimTransfDataBase>, int*)>> dispatcher = {
|
| 219 |
+
{ 0, make_cudacall(compute_morton_kernel<PrimTransfSRT>) }
|
| 220 |
+
};
|
| 221 |
+
|
| 222 |
+
auto iter = dispatcher.find(min(0, algorithm));
|
| 223 |
+
if (iter != dispatcher.end()) {
|
| 224 |
+
(iter->second)(
|
| 225 |
+
dim3(gridsize), dim3(blocksize), stream,
|
| 226 |
+
N, K,
|
| 227 |
+
primtransf_data,
|
| 228 |
+
code);
|
| 229 |
+
}
|
| 230 |
+
}
|
| 231 |
+
|
| 232 |
+
void build_tree_cuda(
|
| 233 |
+
int N, int K,
|
| 234 |
+
int * sortedcode,
|
| 235 |
+
int * nodechildren,
|
| 236 |
+
int * nodeparent,
|
| 237 |
+
cudaStream_t stream) {
|
| 238 |
+
int count = N * (K + K - 1);
|
| 239 |
+
int nthreads = 512;
|
| 240 |
+
int nblocks = (count + nthreads - 1) / nthreads;
|
| 241 |
+
build_tree_kernel<<<nblocks, nthreads, 0, stream>>>(
|
| 242 |
+
N, K,
|
| 243 |
+
sortedcode,
|
| 244 |
+
reinterpret_cast<int2 *>(nodechildren),
|
| 245 |
+
nodeparent);
|
| 246 |
+
}
|
| 247 |
+
|
| 248 |
+
void compute_aabb_cuda(
|
| 249 |
+
int N, int K,
|
| 250 |
+
float * primpos,
|
| 251 |
+
float * primrot,
|
| 252 |
+
float * primscale,
|
| 253 |
+
int * sortedobjid,
|
| 254 |
+
int * nodechildren,
|
| 255 |
+
int * nodeparent,
|
| 256 |
+
float * nodeaabb,
|
| 257 |
+
int algorithm,
|
| 258 |
+
cudaStream_t stream) {
|
| 259 |
+
int * atom;
|
| 260 |
+
cudaMalloc(&atom, N * (K - 1) * 4);
|
| 261 |
+
cudaMemset(atom, 0, N * (K - 1) * 4);
|
| 262 |
+
|
| 263 |
+
int count = N * K;
|
| 264 |
+
int blocksize = 512;
|
| 265 |
+
int gridsize = (count + blocksize - 1) / blocksize;
|
| 266 |
+
|
| 267 |
+
std::shared_ptr<PrimTransfDataBase> primtransf_data;
|
| 268 |
+
primtransf_data = std::make_shared<PrimTransfSRT::Data>(PrimTransfSRT::Data{
|
| 269 |
+
PrimTransfDataBase{},
|
| 270 |
+
K, (float3*)primpos, nullptr,
|
| 271 |
+
K * 3, (float3*)primrot, nullptr,
|
| 272 |
+
K, (float3*)primscale, nullptr});
|
| 273 |
+
|
| 274 |
+
std::map<int, std::function<void(dim3, dim3, cudaStream_t, int, int, std::shared_ptr<PrimTransfDataBase>, int*, int2*, int*, float3*, int*)>> dispatcher = {
|
| 275 |
+
{ 0, make_cudacall(compute_aabb_kernel<PrimTransfSRT>) }
|
| 276 |
+
};
|
| 277 |
+
|
| 278 |
+
auto iter = dispatcher.find(min(0, algorithm));
|
| 279 |
+
if (iter != dispatcher.end()) {
|
| 280 |
+
(iter->second)(
|
| 281 |
+
dim3(gridsize), dim3(blocksize), stream,
|
| 282 |
+
N, K,
|
| 283 |
+
primtransf_data,
|
| 284 |
+
sortedobjid,
|
| 285 |
+
reinterpret_cast<int2 *>(nodechildren),
|
| 286 |
+
nodeparent,
|
| 287 |
+
reinterpret_cast<float3 *>(nodeaabb),
|
| 288 |
+
atom);
|
| 289 |
+
}
|
| 290 |
+
|
| 291 |
+
cudaFree(atom);
|
| 292 |
+
}
|
dva/mvp/extensions/mvpraymarch/cudadispatch.h
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
//
|
| 4 |
+
// This source code is licensed under the license found in the
|
| 5 |
+
// LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
#ifndef cudadispatch_h_
|
| 8 |
+
#define cudadispatch_h_
|
| 9 |
+
|
| 10 |
+
#include <functional>
|
| 11 |
+
#include <memory>
|
| 12 |
+
#include <type_traits>
|
| 13 |
+
|
| 14 |
+
template<typename T, typename = void>
|
| 15 |
+
struct get_base {
|
| 16 |
+
typedef T type;
|
| 17 |
+
};
|
| 18 |
+
|
| 19 |
+
template<typename T>
|
| 20 |
+
struct get_base<T, typename std::enable_if<std::is_base_of<typename T::base, T>::value>::type> {
|
| 21 |
+
typedef std::shared_ptr<typename T::base> type;
|
| 22 |
+
};
|
| 23 |
+
|
| 24 |
+
template<typename T> struct is_shared_ptr : std::false_type {};
|
| 25 |
+
template<typename T> struct is_shared_ptr<std::shared_ptr<T>> : std::true_type {};
|
| 26 |
+
|
| 27 |
+
template<typename OutT, typename T>
|
| 28 |
+
auto convert_shptr_impl2(std::shared_ptr<T> t) {
|
| 29 |
+
return *static_cast<OutT*>(t.get());
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
template<typename OutT, typename T>
|
| 33 |
+
auto convert_shptr_impl(T&& t, std::false_type) {
|
| 34 |
+
return convert_shptr_impl2<OutT>(t);
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
template<typename OutT, typename T>
|
| 38 |
+
auto convert_shptr_impl(T&& t, std::true_type) {
|
| 39 |
+
return std::forward<T>(t);
|
| 40 |
+
}
|
| 41 |
+
|
| 42 |
+
template<typename OutT, typename T>
|
| 43 |
+
auto convert_shptr(T&& t) {
|
| 44 |
+
return convert_shptr_impl<OutT>(std::forward<T>(t), std::is_same<OutT, T>{});
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
template<typename... ArgsIn>
|
| 48 |
+
struct cudacall {
|
| 49 |
+
struct functbase {
|
| 50 |
+
virtual ~functbase() {}
|
| 51 |
+
virtual void call(dim3, dim3, cudaStream_t, ArgsIn...) const = 0;
|
| 52 |
+
};
|
| 53 |
+
|
| 54 |
+
template<typename... ArgsOut>
|
| 55 |
+
struct funct : public functbase {
|
| 56 |
+
std::function<void(ArgsOut...)> fn;
|
| 57 |
+
funct(void(*fn_)(ArgsOut...)) : fn(fn_) { }
|
| 58 |
+
void call(dim3 gridsize, dim3 blocksize, cudaStream_t stream, ArgsIn... args) const {
|
| 59 |
+
void (*const*kfunc)(ArgsOut...) = fn.template target<void (*)(ArgsOut...)>();
|
| 60 |
+
(*kfunc)<<<gridsize, blocksize, 0, stream>>>(
|
| 61 |
+
std::forward<ArgsOut>(convert_shptr<ArgsOut>(std::forward<ArgsIn>(args)))...);
|
| 62 |
+
}
|
| 63 |
+
};
|
| 64 |
+
|
| 65 |
+
std::shared_ptr<functbase> fn;
|
| 66 |
+
|
| 67 |
+
template<typename... ArgsOut>
|
| 68 |
+
cudacall(void(*fn_)(ArgsOut...)) : fn(std::make_shared<funct<ArgsOut...>>(fn_)) { }
|
| 69 |
+
|
| 70 |
+
template<typename... ArgsTmp>
|
| 71 |
+
void call(dim3 gridsize, dim3 blocksize, cudaStream_t stream, ArgsTmp&&... args) const {
|
| 72 |
+
fn->call(gridsize, blocksize, stream, std::forward<ArgsIn>(args)...);
|
| 73 |
+
}
|
| 74 |
+
};
|
| 75 |
+
|
| 76 |
+
template <typename F, typename T>
|
| 77 |
+
struct binder {
|
| 78 |
+
F f; T t;
|
| 79 |
+
template <typename... Args>
|
| 80 |
+
auto operator()(Args&&... args) const
|
| 81 |
+
-> decltype(f(t, std::forward<Args>(args)...)) {
|
| 82 |
+
return f(t, std::forward<Args>(args)...);
|
| 83 |
+
}
|
| 84 |
+
};
|
| 85 |
+
|
| 86 |
+
template <typename F, typename T>
|
| 87 |
+
binder<typename std::decay<F>::type
|
| 88 |
+
, typename std::decay<T>::type> BindFirst(F&& f, T&& t) {
|
| 89 |
+
return { std::forward<F>(f), std::forward<T>(t) };
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
template<typename... ArgsOut>
|
| 93 |
+
auto make_cudacall_(void(*fn)(ArgsOut...)) {
|
| 94 |
+
return BindFirst(
|
| 95 |
+
std::mem_fn(&cudacall<typename get_base<ArgsOut>::type...>::template call<typename get_base<ArgsOut>::type...>),
|
| 96 |
+
cudacall<typename get_base<ArgsOut>::type...>(fn));
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
template<typename... ArgsOut>
|
| 100 |
+
std::function<void(dim3, dim3, cudaStream_t, typename get_base<ArgsOut>::type...)> make_cudacall(void(*fn)(ArgsOut...)) {
|
| 101 |
+
return std::function<void(dim3, dim3, cudaStream_t, typename get_base<ArgsOut>::type...)>(make_cudacall_(fn));
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
#endif
|
dva/mvp/extensions/mvpraymarch/helper_math.h
ADDED
|
@@ -0,0 +1,1453 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/**
|
| 2 |
+
* Copyright 1993-2013 NVIDIA Corporation. All rights reserved.
|
| 3 |
+
*
|
| 4 |
+
* Please refer to the NVIDIA end user license agreement (EULA) associated
|
| 5 |
+
* with this source code for terms and conditions that govern your use of
|
| 6 |
+
* this software. Any use, reproduction, disclosure, or distribution of
|
| 7 |
+
* this software and related documentation outside the terms of the EULA
|
| 8 |
+
* is strictly prohibited.
|
| 9 |
+
*
|
| 10 |
+
*/
|
| 11 |
+
|
| 12 |
+
/*
|
| 13 |
+
* This file implements common mathematical operations on vector types
|
| 14 |
+
* (float3, float4 etc.) since these are not provided as standard by CUDA.
|
| 15 |
+
*
|
| 16 |
+
* The syntax is modeled on the Cg standard library.
|
| 17 |
+
*
|
| 18 |
+
* This is part of the Helper library includes
|
| 19 |
+
*
|
| 20 |
+
* Thanks to Linh Hah for additions and fixes.
|
| 21 |
+
*/
|
| 22 |
+
|
| 23 |
+
#ifndef HELPER_MATH_H
|
| 24 |
+
#define HELPER_MATH_H
|
| 25 |
+
|
| 26 |
+
#include "cuda_runtime.h"
|
| 27 |
+
|
| 28 |
+
typedef unsigned int uint;
|
| 29 |
+
typedef unsigned short ushort;
|
| 30 |
+
|
| 31 |
+
#ifndef EXIT_WAIVED
|
| 32 |
+
#define EXIT_WAIVED 2
|
| 33 |
+
#endif
|
| 34 |
+
|
| 35 |
+
#ifndef __CUDACC__
|
| 36 |
+
#include <math.h>
|
| 37 |
+
|
| 38 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 39 |
+
// host implementations of CUDA functions
|
| 40 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 41 |
+
|
| 42 |
+
inline float fminf(float a, float b)
|
| 43 |
+
{
|
| 44 |
+
return a < b ? a : b;
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
inline float fmaxf(float a, float b)
|
| 48 |
+
{
|
| 49 |
+
return a > b ? a : b;
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
inline int max(int a, int b)
|
| 53 |
+
{
|
| 54 |
+
return a > b ? a : b;
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
inline int min(int a, int b)
|
| 58 |
+
{
|
| 59 |
+
return a < b ? a : b;
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
inline float rsqrtf(float x)
|
| 63 |
+
{
|
| 64 |
+
return 1.0f / sqrtf(x);
|
| 65 |
+
}
|
| 66 |
+
#endif
|
| 67 |
+
|
| 68 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 69 |
+
// constructors
|
| 70 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 71 |
+
|
| 72 |
+
inline __host__ __device__ float2 make_float2(float s)
|
| 73 |
+
{
|
| 74 |
+
return make_float2(s, s);
|
| 75 |
+
}
|
| 76 |
+
inline __host__ __device__ float2 make_float2(float3 a)
|
| 77 |
+
{
|
| 78 |
+
return make_float2(a.x, a.y);
|
| 79 |
+
}
|
| 80 |
+
inline __host__ __device__ float2 make_float2(int2 a)
|
| 81 |
+
{
|
| 82 |
+
return make_float2(float(a.x), float(a.y));
|
| 83 |
+
}
|
| 84 |
+
inline __host__ __device__ float2 make_float2(uint2 a)
|
| 85 |
+
{
|
| 86 |
+
return make_float2(float(a.x), float(a.y));
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
inline __host__ __device__ int2 make_int2(int s)
|
| 90 |
+
{
|
| 91 |
+
return make_int2(s, s);
|
| 92 |
+
}
|
| 93 |
+
inline __host__ __device__ int2 make_int2(int3 a)
|
| 94 |
+
{
|
| 95 |
+
return make_int2(a.x, a.y);
|
| 96 |
+
}
|
| 97 |
+
inline __host__ __device__ int2 make_int2(uint2 a)
|
| 98 |
+
{
|
| 99 |
+
return make_int2(int(a.x), int(a.y));
|
| 100 |
+
}
|
| 101 |
+
inline __host__ __device__ int2 make_int2(float2 a)
|
| 102 |
+
{
|
| 103 |
+
return make_int2(int(a.x), int(a.y));
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
inline __host__ __device__ uint2 make_uint2(uint s)
|
| 107 |
+
{
|
| 108 |
+
return make_uint2(s, s);
|
| 109 |
+
}
|
| 110 |
+
inline __host__ __device__ uint2 make_uint2(uint3 a)
|
| 111 |
+
{
|
| 112 |
+
return make_uint2(a.x, a.y);
|
| 113 |
+
}
|
| 114 |
+
inline __host__ __device__ uint2 make_uint2(int2 a)
|
| 115 |
+
{
|
| 116 |
+
return make_uint2(uint(a.x), uint(a.y));
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
inline __host__ __device__ float3 make_float3(float s)
|
| 120 |
+
{
|
| 121 |
+
return make_float3(s, s, s);
|
| 122 |
+
}
|
| 123 |
+
inline __host__ __device__ float3 make_float3(float2 a)
|
| 124 |
+
{
|
| 125 |
+
return make_float3(a.x, a.y, 0.0f);
|
| 126 |
+
}
|
| 127 |
+
inline __host__ __device__ float3 make_float3(float2 a, float s)
|
| 128 |
+
{
|
| 129 |
+
return make_float3(a.x, a.y, s);
|
| 130 |
+
}
|
| 131 |
+
inline __host__ __device__ float3 make_float3(float4 a)
|
| 132 |
+
{
|
| 133 |
+
return make_float3(a.x, a.y, a.z);
|
| 134 |
+
}
|
| 135 |
+
inline __host__ __device__ float3 make_float3(int3 a)
|
| 136 |
+
{
|
| 137 |
+
return make_float3(float(a.x), float(a.y), float(a.z));
|
| 138 |
+
}
|
| 139 |
+
inline __host__ __device__ float3 make_float3(uint3 a)
|
| 140 |
+
{
|
| 141 |
+
return make_float3(float(a.x), float(a.y), float(a.z));
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
inline __host__ __device__ int3 make_int3(int s)
|
| 145 |
+
{
|
| 146 |
+
return make_int3(s, s, s);
|
| 147 |
+
}
|
| 148 |
+
inline __host__ __device__ int3 make_int3(int2 a)
|
| 149 |
+
{
|
| 150 |
+
return make_int3(a.x, a.y, 0);
|
| 151 |
+
}
|
| 152 |
+
inline __host__ __device__ int3 make_int3(int2 a, int s)
|
| 153 |
+
{
|
| 154 |
+
return make_int3(a.x, a.y, s);
|
| 155 |
+
}
|
| 156 |
+
inline __host__ __device__ int3 make_int3(uint3 a)
|
| 157 |
+
{
|
| 158 |
+
return make_int3(int(a.x), int(a.y), int(a.z));
|
| 159 |
+
}
|
| 160 |
+
inline __host__ __device__ int3 make_int3(float3 a)
|
| 161 |
+
{
|
| 162 |
+
return make_int3(int(a.x), int(a.y), int(a.z));
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
inline __host__ __device__ uint3 make_uint3(uint s)
|
| 166 |
+
{
|
| 167 |
+
return make_uint3(s, s, s);
|
| 168 |
+
}
|
| 169 |
+
inline __host__ __device__ uint3 make_uint3(uint2 a)
|
| 170 |
+
{
|
| 171 |
+
return make_uint3(a.x, a.y, 0);
|
| 172 |
+
}
|
| 173 |
+
inline __host__ __device__ uint3 make_uint3(uint2 a, uint s)
|
| 174 |
+
{
|
| 175 |
+
return make_uint3(a.x, a.y, s);
|
| 176 |
+
}
|
| 177 |
+
inline __host__ __device__ uint3 make_uint3(uint4 a)
|
| 178 |
+
{
|
| 179 |
+
return make_uint3(a.x, a.y, a.z);
|
| 180 |
+
}
|
| 181 |
+
inline __host__ __device__ uint3 make_uint3(int3 a)
|
| 182 |
+
{
|
| 183 |
+
return make_uint3(uint(a.x), uint(a.y), uint(a.z));
|
| 184 |
+
}
|
| 185 |
+
|
| 186 |
+
inline __host__ __device__ float4 make_float4(float s)
|
| 187 |
+
{
|
| 188 |
+
return make_float4(s, s, s, s);
|
| 189 |
+
}
|
| 190 |
+
inline __host__ __device__ float4 make_float4(float3 a)
|
| 191 |
+
{
|
| 192 |
+
return make_float4(a.x, a.y, a.z, 0.0f);
|
| 193 |
+
}
|
| 194 |
+
inline __host__ __device__ float4 make_float4(float3 a, float w)
|
| 195 |
+
{
|
| 196 |
+
return make_float4(a.x, a.y, a.z, w);
|
| 197 |
+
}
|
| 198 |
+
inline __host__ __device__ float4 make_float4(int4 a)
|
| 199 |
+
{
|
| 200 |
+
return make_float4(float(a.x), float(a.y), float(a.z), float(a.w));
|
| 201 |
+
}
|
| 202 |
+
inline __host__ __device__ float4 make_float4(uint4 a)
|
| 203 |
+
{
|
| 204 |
+
return make_float4(float(a.x), float(a.y), float(a.z), float(a.w));
|
| 205 |
+
}
|
| 206 |
+
|
| 207 |
+
inline __host__ __device__ int4 make_int4(int s)
|
| 208 |
+
{
|
| 209 |
+
return make_int4(s, s, s, s);
|
| 210 |
+
}
|
| 211 |
+
inline __host__ __device__ int4 make_int4(int3 a)
|
| 212 |
+
{
|
| 213 |
+
return make_int4(a.x, a.y, a.z, 0);
|
| 214 |
+
}
|
| 215 |
+
inline __host__ __device__ int4 make_int4(int3 a, int w)
|
| 216 |
+
{
|
| 217 |
+
return make_int4(a.x, a.y, a.z, w);
|
| 218 |
+
}
|
| 219 |
+
inline __host__ __device__ int4 make_int4(uint4 a)
|
| 220 |
+
{
|
| 221 |
+
return make_int4(int(a.x), int(a.y), int(a.z), int(a.w));
|
| 222 |
+
}
|
| 223 |
+
inline __host__ __device__ int4 make_int4(float4 a)
|
| 224 |
+
{
|
| 225 |
+
return make_int4(int(a.x), int(a.y), int(a.z), int(a.w));
|
| 226 |
+
}
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
inline __host__ __device__ uint4 make_uint4(uint s)
|
| 230 |
+
{
|
| 231 |
+
return make_uint4(s, s, s, s);
|
| 232 |
+
}
|
| 233 |
+
inline __host__ __device__ uint4 make_uint4(uint3 a)
|
| 234 |
+
{
|
| 235 |
+
return make_uint4(a.x, a.y, a.z, 0);
|
| 236 |
+
}
|
| 237 |
+
inline __host__ __device__ uint4 make_uint4(uint3 a, uint w)
|
| 238 |
+
{
|
| 239 |
+
return make_uint4(a.x, a.y, a.z, w);
|
| 240 |
+
}
|
| 241 |
+
inline __host__ __device__ uint4 make_uint4(int4 a)
|
| 242 |
+
{
|
| 243 |
+
return make_uint4(uint(a.x), uint(a.y), uint(a.z), uint(a.w));
|
| 244 |
+
}
|
| 245 |
+
|
| 246 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 247 |
+
// negate
|
| 248 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 249 |
+
|
| 250 |
+
inline __host__ __device__ float2 operator-(float2 &a)
|
| 251 |
+
{
|
| 252 |
+
return make_float2(-a.x, -a.y);
|
| 253 |
+
}
|
| 254 |
+
inline __host__ __device__ int2 operator-(int2 &a)
|
| 255 |
+
{
|
| 256 |
+
return make_int2(-a.x, -a.y);
|
| 257 |
+
}
|
| 258 |
+
inline __host__ __device__ float3 operator-(float3 &a)
|
| 259 |
+
{
|
| 260 |
+
return make_float3(-a.x, -a.y, -a.z);
|
| 261 |
+
}
|
| 262 |
+
inline __host__ __device__ int3 operator-(int3 &a)
|
| 263 |
+
{
|
| 264 |
+
return make_int3(-a.x, -a.y, -a.z);
|
| 265 |
+
}
|
| 266 |
+
inline __host__ __device__ float4 operator-(float4 &a)
|
| 267 |
+
{
|
| 268 |
+
return make_float4(-a.x, -a.y, -a.z, -a.w);
|
| 269 |
+
}
|
| 270 |
+
inline __host__ __device__ int4 operator-(int4 &a)
|
| 271 |
+
{
|
| 272 |
+
return make_int4(-a.x, -a.y, -a.z, -a.w);
|
| 273 |
+
}
|
| 274 |
+
|
| 275 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 276 |
+
// addition
|
| 277 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 278 |
+
|
| 279 |
+
inline __host__ __device__ float2 operator+(float2 a, float2 b)
|
| 280 |
+
{
|
| 281 |
+
return make_float2(a.x + b.x, a.y + b.y);
|
| 282 |
+
}
|
| 283 |
+
inline __host__ __device__ void operator+=(float2 &a, float2 b)
|
| 284 |
+
{
|
| 285 |
+
a.x += b.x;
|
| 286 |
+
a.y += b.y;
|
| 287 |
+
}
|
| 288 |
+
inline __host__ __device__ float2 operator+(float2 a, float b)
|
| 289 |
+
{
|
| 290 |
+
return make_float2(a.x + b, a.y + b);
|
| 291 |
+
}
|
| 292 |
+
inline __host__ __device__ float2 operator+(float b, float2 a)
|
| 293 |
+
{
|
| 294 |
+
return make_float2(a.x + b, a.y + b);
|
| 295 |
+
}
|
| 296 |
+
inline __host__ __device__ void operator+=(float2 &a, float b)
|
| 297 |
+
{
|
| 298 |
+
a.x += b;
|
| 299 |
+
a.y += b;
|
| 300 |
+
}
|
| 301 |
+
|
| 302 |
+
inline __host__ __device__ int2 operator+(int2 a, int2 b)
|
| 303 |
+
{
|
| 304 |
+
return make_int2(a.x + b.x, a.y + b.y);
|
| 305 |
+
}
|
| 306 |
+
inline __host__ __device__ void operator+=(int2 &a, int2 b)
|
| 307 |
+
{
|
| 308 |
+
a.x += b.x;
|
| 309 |
+
a.y += b.y;
|
| 310 |
+
}
|
| 311 |
+
inline __host__ __device__ int2 operator+(int2 a, int b)
|
| 312 |
+
{
|
| 313 |
+
return make_int2(a.x + b, a.y + b);
|
| 314 |
+
}
|
| 315 |
+
inline __host__ __device__ int2 operator+(int b, int2 a)
|
| 316 |
+
{
|
| 317 |
+
return make_int2(a.x + b, a.y + b);
|
| 318 |
+
}
|
| 319 |
+
inline __host__ __device__ void operator+=(int2 &a, int b)
|
| 320 |
+
{
|
| 321 |
+
a.x += b;
|
| 322 |
+
a.y += b;
|
| 323 |
+
}
|
| 324 |
+
|
| 325 |
+
inline __host__ __device__ uint2 operator+(uint2 a, uint2 b)
|
| 326 |
+
{
|
| 327 |
+
return make_uint2(a.x + b.x, a.y + b.y);
|
| 328 |
+
}
|
| 329 |
+
inline __host__ __device__ void operator+=(uint2 &a, uint2 b)
|
| 330 |
+
{
|
| 331 |
+
a.x += b.x;
|
| 332 |
+
a.y += b.y;
|
| 333 |
+
}
|
| 334 |
+
inline __host__ __device__ uint2 operator+(uint2 a, uint b)
|
| 335 |
+
{
|
| 336 |
+
return make_uint2(a.x + b, a.y + b);
|
| 337 |
+
}
|
| 338 |
+
inline __host__ __device__ uint2 operator+(uint b, uint2 a)
|
| 339 |
+
{
|
| 340 |
+
return make_uint2(a.x + b, a.y + b);
|
| 341 |
+
}
|
| 342 |
+
inline __host__ __device__ void operator+=(uint2 &a, uint b)
|
| 343 |
+
{
|
| 344 |
+
a.x += b;
|
| 345 |
+
a.y += b;
|
| 346 |
+
}
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
inline __host__ __device__ float3 operator+(float3 a, float3 b)
|
| 350 |
+
{
|
| 351 |
+
return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
|
| 352 |
+
}
|
| 353 |
+
inline __host__ __device__ void operator+=(float3 &a, float3 b)
|
| 354 |
+
{
|
| 355 |
+
a.x += b.x;
|
| 356 |
+
a.y += b.y;
|
| 357 |
+
a.z += b.z;
|
| 358 |
+
}
|
| 359 |
+
inline __host__ __device__ float3 operator+(float3 a, float b)
|
| 360 |
+
{
|
| 361 |
+
return make_float3(a.x + b, a.y + b, a.z + b);
|
| 362 |
+
}
|
| 363 |
+
inline __host__ __device__ void operator+=(float3 &a, float b)
|
| 364 |
+
{
|
| 365 |
+
a.x += b;
|
| 366 |
+
a.y += b;
|
| 367 |
+
a.z += b;
|
| 368 |
+
}
|
| 369 |
+
|
| 370 |
+
inline __host__ __device__ int3 operator+(int3 a, int3 b)
|
| 371 |
+
{
|
| 372 |
+
return make_int3(a.x + b.x, a.y + b.y, a.z + b.z);
|
| 373 |
+
}
|
| 374 |
+
inline __host__ __device__ void operator+=(int3 &a, int3 b)
|
| 375 |
+
{
|
| 376 |
+
a.x += b.x;
|
| 377 |
+
a.y += b.y;
|
| 378 |
+
a.z += b.z;
|
| 379 |
+
}
|
| 380 |
+
inline __host__ __device__ int3 operator+(int3 a, int b)
|
| 381 |
+
{
|
| 382 |
+
return make_int3(a.x + b, a.y + b, a.z + b);
|
| 383 |
+
}
|
| 384 |
+
inline __host__ __device__ void operator+=(int3 &a, int b)
|
| 385 |
+
{
|
| 386 |
+
a.x += b;
|
| 387 |
+
a.y += b;
|
| 388 |
+
a.z += b;
|
| 389 |
+
}
|
| 390 |
+
|
| 391 |
+
inline __host__ __device__ uint3 operator+(uint3 a, uint3 b)
|
| 392 |
+
{
|
| 393 |
+
return make_uint3(a.x + b.x, a.y + b.y, a.z + b.z);
|
| 394 |
+
}
|
| 395 |
+
inline __host__ __device__ void operator+=(uint3 &a, uint3 b)
|
| 396 |
+
{
|
| 397 |
+
a.x += b.x;
|
| 398 |
+
a.y += b.y;
|
| 399 |
+
a.z += b.z;
|
| 400 |
+
}
|
| 401 |
+
inline __host__ __device__ uint3 operator+(uint3 a, uint b)
|
| 402 |
+
{
|
| 403 |
+
return make_uint3(a.x + b, a.y + b, a.z + b);
|
| 404 |
+
}
|
| 405 |
+
inline __host__ __device__ void operator+=(uint3 &a, uint b)
|
| 406 |
+
{
|
| 407 |
+
a.x += b;
|
| 408 |
+
a.y += b;
|
| 409 |
+
a.z += b;
|
| 410 |
+
}
|
| 411 |
+
|
| 412 |
+
inline __host__ __device__ int3 operator+(int b, int3 a)
|
| 413 |
+
{
|
| 414 |
+
return make_int3(a.x + b, a.y + b, a.z + b);
|
| 415 |
+
}
|
| 416 |
+
inline __host__ __device__ uint3 operator+(uint b, uint3 a)
|
| 417 |
+
{
|
| 418 |
+
return make_uint3(a.x + b, a.y + b, a.z + b);
|
| 419 |
+
}
|
| 420 |
+
inline __host__ __device__ float3 operator+(float b, float3 a)
|
| 421 |
+
{
|
| 422 |
+
return make_float3(a.x + b, a.y + b, a.z + b);
|
| 423 |
+
}
|
| 424 |
+
|
| 425 |
+
inline __host__ __device__ float4 operator+(float4 a, float4 b)
|
| 426 |
+
{
|
| 427 |
+
return make_float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
|
| 428 |
+
}
|
| 429 |
+
inline __host__ __device__ void operator+=(float4 &a, float4 b)
|
| 430 |
+
{
|
| 431 |
+
a.x += b.x;
|
| 432 |
+
a.y += b.y;
|
| 433 |
+
a.z += b.z;
|
| 434 |
+
a.w += b.w;
|
| 435 |
+
}
|
| 436 |
+
inline __host__ __device__ float4 operator+(float4 a, float b)
|
| 437 |
+
{
|
| 438 |
+
return make_float4(a.x + b, a.y + b, a.z + b, a.w + b);
|
| 439 |
+
}
|
| 440 |
+
inline __host__ __device__ float4 operator+(float b, float4 a)
|
| 441 |
+
{
|
| 442 |
+
return make_float4(a.x + b, a.y + b, a.z + b, a.w + b);
|
| 443 |
+
}
|
| 444 |
+
inline __host__ __device__ void operator+=(float4 &a, float b)
|
| 445 |
+
{
|
| 446 |
+
a.x += b;
|
| 447 |
+
a.y += b;
|
| 448 |
+
a.z += b;
|
| 449 |
+
a.w += b;
|
| 450 |
+
}
|
| 451 |
+
|
| 452 |
+
inline __host__ __device__ int4 operator+(int4 a, int4 b)
|
| 453 |
+
{
|
| 454 |
+
return make_int4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
|
| 455 |
+
}
|
| 456 |
+
inline __host__ __device__ void operator+=(int4 &a, int4 b)
|
| 457 |
+
{
|
| 458 |
+
a.x += b.x;
|
| 459 |
+
a.y += b.y;
|
| 460 |
+
a.z += b.z;
|
| 461 |
+
a.w += b.w;
|
| 462 |
+
}
|
| 463 |
+
inline __host__ __device__ int4 operator+(int4 a, int b)
|
| 464 |
+
{
|
| 465 |
+
return make_int4(a.x + b, a.y + b, a.z + b, a.w + b);
|
| 466 |
+
}
|
| 467 |
+
inline __host__ __device__ int4 operator+(int b, int4 a)
|
| 468 |
+
{
|
| 469 |
+
return make_int4(a.x + b, a.y + b, a.z + b, a.w + b);
|
| 470 |
+
}
|
| 471 |
+
inline __host__ __device__ void operator+=(int4 &a, int b)
|
| 472 |
+
{
|
| 473 |
+
a.x += b;
|
| 474 |
+
a.y += b;
|
| 475 |
+
a.z += b;
|
| 476 |
+
a.w += b;
|
| 477 |
+
}
|
| 478 |
+
|
| 479 |
+
inline __host__ __device__ uint4 operator+(uint4 a, uint4 b)
|
| 480 |
+
{
|
| 481 |
+
return make_uint4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
|
| 482 |
+
}
|
| 483 |
+
inline __host__ __device__ void operator+=(uint4 &a, uint4 b)
|
| 484 |
+
{
|
| 485 |
+
a.x += b.x;
|
| 486 |
+
a.y += b.y;
|
| 487 |
+
a.z += b.z;
|
| 488 |
+
a.w += b.w;
|
| 489 |
+
}
|
| 490 |
+
inline __host__ __device__ uint4 operator+(uint4 a, uint b)
|
| 491 |
+
{
|
| 492 |
+
return make_uint4(a.x + b, a.y + b, a.z + b, a.w + b);
|
| 493 |
+
}
|
| 494 |
+
inline __host__ __device__ uint4 operator+(uint b, uint4 a)
|
| 495 |
+
{
|
| 496 |
+
return make_uint4(a.x + b, a.y + b, a.z + b, a.w + b);
|
| 497 |
+
}
|
| 498 |
+
inline __host__ __device__ void operator+=(uint4 &a, uint b)
|
| 499 |
+
{
|
| 500 |
+
a.x += b;
|
| 501 |
+
a.y += b;
|
| 502 |
+
a.z += b;
|
| 503 |
+
a.w += b;
|
| 504 |
+
}
|
| 505 |
+
|
| 506 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 507 |
+
// subtract
|
| 508 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 509 |
+
|
| 510 |
+
inline __host__ __device__ float2 operator-(float2 a, float2 b)
|
| 511 |
+
{
|
| 512 |
+
return make_float2(a.x - b.x, a.y - b.y);
|
| 513 |
+
}
|
| 514 |
+
inline __host__ __device__ void operator-=(float2 &a, float2 b)
|
| 515 |
+
{
|
| 516 |
+
a.x -= b.x;
|
| 517 |
+
a.y -= b.y;
|
| 518 |
+
}
|
| 519 |
+
inline __host__ __device__ float2 operator-(float2 a, float b)
|
| 520 |
+
{
|
| 521 |
+
return make_float2(a.x - b, a.y - b);
|
| 522 |
+
}
|
| 523 |
+
inline __host__ __device__ float2 operator-(float b, float2 a)
|
| 524 |
+
{
|
| 525 |
+
return make_float2(b - a.x, b - a.y);
|
| 526 |
+
}
|
| 527 |
+
inline __host__ __device__ void operator-=(float2 &a, float b)
|
| 528 |
+
{
|
| 529 |
+
a.x -= b;
|
| 530 |
+
a.y -= b;
|
| 531 |
+
}
|
| 532 |
+
|
| 533 |
+
inline __host__ __device__ int2 operator-(int2 a, int2 b)
|
| 534 |
+
{
|
| 535 |
+
return make_int2(a.x - b.x, a.y - b.y);
|
| 536 |
+
}
|
| 537 |
+
inline __host__ __device__ void operator-=(int2 &a, int2 b)
|
| 538 |
+
{
|
| 539 |
+
a.x -= b.x;
|
| 540 |
+
a.y -= b.y;
|
| 541 |
+
}
|
| 542 |
+
inline __host__ __device__ int2 operator-(int2 a, int b)
|
| 543 |
+
{
|
| 544 |
+
return make_int2(a.x - b, a.y - b);
|
| 545 |
+
}
|
| 546 |
+
inline __host__ __device__ int2 operator-(int b, int2 a)
|
| 547 |
+
{
|
| 548 |
+
return make_int2(b - a.x, b - a.y);
|
| 549 |
+
}
|
| 550 |
+
inline __host__ __device__ void operator-=(int2 &a, int b)
|
| 551 |
+
{
|
| 552 |
+
a.x -= b;
|
| 553 |
+
a.y -= b;
|
| 554 |
+
}
|
| 555 |
+
|
| 556 |
+
inline __host__ __device__ uint2 operator-(uint2 a, uint2 b)
|
| 557 |
+
{
|
| 558 |
+
return make_uint2(a.x - b.x, a.y - b.y);
|
| 559 |
+
}
|
| 560 |
+
inline __host__ __device__ void operator-=(uint2 &a, uint2 b)
|
| 561 |
+
{
|
| 562 |
+
a.x -= b.x;
|
| 563 |
+
a.y -= b.y;
|
| 564 |
+
}
|
| 565 |
+
inline __host__ __device__ uint2 operator-(uint2 a, uint b)
|
| 566 |
+
{
|
| 567 |
+
return make_uint2(a.x - b, a.y - b);
|
| 568 |
+
}
|
| 569 |
+
inline __host__ __device__ uint2 operator-(uint b, uint2 a)
|
| 570 |
+
{
|
| 571 |
+
return make_uint2(b - a.x, b - a.y);
|
| 572 |
+
}
|
| 573 |
+
inline __host__ __device__ void operator-=(uint2 &a, uint b)
|
| 574 |
+
{
|
| 575 |
+
a.x -= b;
|
| 576 |
+
a.y -= b;
|
| 577 |
+
}
|
| 578 |
+
|
| 579 |
+
inline __host__ __device__ float3 operator-(float3 a, float3 b)
|
| 580 |
+
{
|
| 581 |
+
return make_float3(a.x - b.x, a.y - b.y, a.z - b.z);
|
| 582 |
+
}
|
| 583 |
+
inline __host__ __device__ void operator-=(float3 &a, float3 b)
|
| 584 |
+
{
|
| 585 |
+
a.x -= b.x;
|
| 586 |
+
a.y -= b.y;
|
| 587 |
+
a.z -= b.z;
|
| 588 |
+
}
|
| 589 |
+
inline __host__ __device__ float3 operator-(float3 a, float b)
|
| 590 |
+
{
|
| 591 |
+
return make_float3(a.x - b, a.y - b, a.z - b);
|
| 592 |
+
}
|
| 593 |
+
inline __host__ __device__ float3 operator-(float b, float3 a)
|
| 594 |
+
{
|
| 595 |
+
return make_float3(b - a.x, b - a.y, b - a.z);
|
| 596 |
+
}
|
| 597 |
+
inline __host__ __device__ void operator-=(float3 &a, float b)
|
| 598 |
+
{
|
| 599 |
+
a.x -= b;
|
| 600 |
+
a.y -= b;
|
| 601 |
+
a.z -= b;
|
| 602 |
+
}
|
| 603 |
+
|
| 604 |
+
inline __host__ __device__ int3 operator-(int3 a, int3 b)
|
| 605 |
+
{
|
| 606 |
+
return make_int3(a.x - b.x, a.y - b.y, a.z - b.z);
|
| 607 |
+
}
|
| 608 |
+
inline __host__ __device__ void operator-=(int3 &a, int3 b)
|
| 609 |
+
{
|
| 610 |
+
a.x -= b.x;
|
| 611 |
+
a.y -= b.y;
|
| 612 |
+
a.z -= b.z;
|
| 613 |
+
}
|
| 614 |
+
inline __host__ __device__ int3 operator-(int3 a, int b)
|
| 615 |
+
{
|
| 616 |
+
return make_int3(a.x - b, a.y - b, a.z - b);
|
| 617 |
+
}
|
| 618 |
+
inline __host__ __device__ int3 operator-(int b, int3 a)
|
| 619 |
+
{
|
| 620 |
+
return make_int3(b - a.x, b - a.y, b - a.z);
|
| 621 |
+
}
|
| 622 |
+
inline __host__ __device__ void operator-=(int3 &a, int b)
|
| 623 |
+
{
|
| 624 |
+
a.x -= b;
|
| 625 |
+
a.y -= b;
|
| 626 |
+
a.z -= b;
|
| 627 |
+
}
|
| 628 |
+
|
| 629 |
+
inline __host__ __device__ uint3 operator-(uint3 a, uint3 b)
|
| 630 |
+
{
|
| 631 |
+
return make_uint3(a.x - b.x, a.y - b.y, a.z - b.z);
|
| 632 |
+
}
|
| 633 |
+
inline __host__ __device__ void operator-=(uint3 &a, uint3 b)
|
| 634 |
+
{
|
| 635 |
+
a.x -= b.x;
|
| 636 |
+
a.y -= b.y;
|
| 637 |
+
a.z -= b.z;
|
| 638 |
+
}
|
| 639 |
+
inline __host__ __device__ uint3 operator-(uint3 a, uint b)
|
| 640 |
+
{
|
| 641 |
+
return make_uint3(a.x - b, a.y - b, a.z - b);
|
| 642 |
+
}
|
| 643 |
+
inline __host__ __device__ uint3 operator-(uint b, uint3 a)
|
| 644 |
+
{
|
| 645 |
+
return make_uint3(b - a.x, b - a.y, b - a.z);
|
| 646 |
+
}
|
| 647 |
+
inline __host__ __device__ void operator-=(uint3 &a, uint b)
|
| 648 |
+
{
|
| 649 |
+
a.x -= b;
|
| 650 |
+
a.y -= b;
|
| 651 |
+
a.z -= b;
|
| 652 |
+
}
|
| 653 |
+
|
| 654 |
+
inline __host__ __device__ float4 operator-(float4 a, float4 b)
|
| 655 |
+
{
|
| 656 |
+
return make_float4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
|
| 657 |
+
}
|
| 658 |
+
inline __host__ __device__ void operator-=(float4 &a, float4 b)
|
| 659 |
+
{
|
| 660 |
+
a.x -= b.x;
|
| 661 |
+
a.y -= b.y;
|
| 662 |
+
a.z -= b.z;
|
| 663 |
+
a.w -= b.w;
|
| 664 |
+
}
|
| 665 |
+
inline __host__ __device__ float4 operator-(float4 a, float b)
|
| 666 |
+
{
|
| 667 |
+
return make_float4(a.x - b, a.y - b, a.z - b, a.w - b);
|
| 668 |
+
}
|
| 669 |
+
inline __host__ __device__ void operator-=(float4 &a, float b)
|
| 670 |
+
{
|
| 671 |
+
a.x -= b;
|
| 672 |
+
a.y -= b;
|
| 673 |
+
a.z -= b;
|
| 674 |
+
a.w -= b;
|
| 675 |
+
}
|
| 676 |
+
|
| 677 |
+
inline __host__ __device__ int4 operator-(int4 a, int4 b)
|
| 678 |
+
{
|
| 679 |
+
return make_int4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
|
| 680 |
+
}
|
| 681 |
+
inline __host__ __device__ void operator-=(int4 &a, int4 b)
|
| 682 |
+
{
|
| 683 |
+
a.x -= b.x;
|
| 684 |
+
a.y -= b.y;
|
| 685 |
+
a.z -= b.z;
|
| 686 |
+
a.w -= b.w;
|
| 687 |
+
}
|
| 688 |
+
inline __host__ __device__ int4 operator-(int4 a, int b)
|
| 689 |
+
{
|
| 690 |
+
return make_int4(a.x - b, a.y - b, a.z - b, a.w - b);
|
| 691 |
+
}
|
| 692 |
+
inline __host__ __device__ int4 operator-(int b, int4 a)
|
| 693 |
+
{
|
| 694 |
+
return make_int4(b - a.x, b - a.y, b - a.z, b - a.w);
|
| 695 |
+
}
|
| 696 |
+
inline __host__ __device__ void operator-=(int4 &a, int b)
|
| 697 |
+
{
|
| 698 |
+
a.x -= b;
|
| 699 |
+
a.y -= b;
|
| 700 |
+
a.z -= b;
|
| 701 |
+
a.w -= b;
|
| 702 |
+
}
|
| 703 |
+
|
| 704 |
+
inline __host__ __device__ uint4 operator-(uint4 a, uint4 b)
|
| 705 |
+
{
|
| 706 |
+
return make_uint4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
|
| 707 |
+
}
|
| 708 |
+
inline __host__ __device__ void operator-=(uint4 &a, uint4 b)
|
| 709 |
+
{
|
| 710 |
+
a.x -= b.x;
|
| 711 |
+
a.y -= b.y;
|
| 712 |
+
a.z -= b.z;
|
| 713 |
+
a.w -= b.w;
|
| 714 |
+
}
|
| 715 |
+
inline __host__ __device__ uint4 operator-(uint4 a, uint b)
|
| 716 |
+
{
|
| 717 |
+
return make_uint4(a.x - b, a.y - b, a.z - b, a.w - b);
|
| 718 |
+
}
|
| 719 |
+
inline __host__ __device__ uint4 operator-(uint b, uint4 a)
|
| 720 |
+
{
|
| 721 |
+
return make_uint4(b - a.x, b - a.y, b - a.z, b - a.w);
|
| 722 |
+
}
|
| 723 |
+
inline __host__ __device__ void operator-=(uint4 &a, uint b)
|
| 724 |
+
{
|
| 725 |
+
a.x -= b;
|
| 726 |
+
a.y -= b;
|
| 727 |
+
a.z -= b;
|
| 728 |
+
a.w -= b;
|
| 729 |
+
}
|
| 730 |
+
|
| 731 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 732 |
+
// multiply
|
| 733 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 734 |
+
|
| 735 |
+
inline __host__ __device__ float2 operator*(float2 a, float2 b)
|
| 736 |
+
{
|
| 737 |
+
return make_float2(a.x * b.x, a.y * b.y);
|
| 738 |
+
}
|
| 739 |
+
inline __host__ __device__ void operator*=(float2 &a, float2 b)
|
| 740 |
+
{
|
| 741 |
+
a.x *= b.x;
|
| 742 |
+
a.y *= b.y;
|
| 743 |
+
}
|
| 744 |
+
inline __host__ __device__ float2 operator*(float2 a, float b)
|
| 745 |
+
{
|
| 746 |
+
return make_float2(a.x * b, a.y * b);
|
| 747 |
+
}
|
| 748 |
+
inline __host__ __device__ float2 operator*(float b, float2 a)
|
| 749 |
+
{
|
| 750 |
+
return make_float2(b * a.x, b * a.y);
|
| 751 |
+
}
|
| 752 |
+
inline __host__ __device__ void operator*=(float2 &a, float b)
|
| 753 |
+
{
|
| 754 |
+
a.x *= b;
|
| 755 |
+
a.y *= b;
|
| 756 |
+
}
|
| 757 |
+
|
| 758 |
+
inline __host__ __device__ int2 operator*(int2 a, int2 b)
|
| 759 |
+
{
|
| 760 |
+
return make_int2(a.x * b.x, a.y * b.y);
|
| 761 |
+
}
|
| 762 |
+
inline __host__ __device__ void operator*=(int2 &a, int2 b)
|
| 763 |
+
{
|
| 764 |
+
a.x *= b.x;
|
| 765 |
+
a.y *= b.y;
|
| 766 |
+
}
|
| 767 |
+
inline __host__ __device__ int2 operator*(int2 a, int b)
|
| 768 |
+
{
|
| 769 |
+
return make_int2(a.x * b, a.y * b);
|
| 770 |
+
}
|
| 771 |
+
inline __host__ __device__ int2 operator*(int b, int2 a)
|
| 772 |
+
{
|
| 773 |
+
return make_int2(b * a.x, b * a.y);
|
| 774 |
+
}
|
| 775 |
+
inline __host__ __device__ void operator*=(int2 &a, int b)
|
| 776 |
+
{
|
| 777 |
+
a.x *= b;
|
| 778 |
+
a.y *= b;
|
| 779 |
+
}
|
| 780 |
+
|
| 781 |
+
inline __host__ __device__ uint2 operator*(uint2 a, uint2 b)
|
| 782 |
+
{
|
| 783 |
+
return make_uint2(a.x * b.x, a.y * b.y);
|
| 784 |
+
}
|
| 785 |
+
inline __host__ __device__ void operator*=(uint2 &a, uint2 b)
|
| 786 |
+
{
|
| 787 |
+
a.x *= b.x;
|
| 788 |
+
a.y *= b.y;
|
| 789 |
+
}
|
| 790 |
+
inline __host__ __device__ uint2 operator*(uint2 a, uint b)
|
| 791 |
+
{
|
| 792 |
+
return make_uint2(a.x * b, a.y * b);
|
| 793 |
+
}
|
| 794 |
+
inline __host__ __device__ uint2 operator*(uint b, uint2 a)
|
| 795 |
+
{
|
| 796 |
+
return make_uint2(b * a.x, b * a.y);
|
| 797 |
+
}
|
| 798 |
+
inline __host__ __device__ void operator*=(uint2 &a, uint b)
|
| 799 |
+
{
|
| 800 |
+
a.x *= b;
|
| 801 |
+
a.y *= b;
|
| 802 |
+
}
|
| 803 |
+
|
| 804 |
+
inline __host__ __device__ float3 operator*(float3 a, float3 b)
|
| 805 |
+
{
|
| 806 |
+
return make_float3(a.x * b.x, a.y * b.y, a.z * b.z);
|
| 807 |
+
}
|
| 808 |
+
inline __host__ __device__ void operator*=(float3 &a, float3 b)
|
| 809 |
+
{
|
| 810 |
+
a.x *= b.x;
|
| 811 |
+
a.y *= b.y;
|
| 812 |
+
a.z *= b.z;
|
| 813 |
+
}
|
| 814 |
+
inline __host__ __device__ float3 operator*(float3 a, float b)
|
| 815 |
+
{
|
| 816 |
+
return make_float3(a.x * b, a.y * b, a.z * b);
|
| 817 |
+
}
|
| 818 |
+
inline __host__ __device__ float3 operator*(float b, float3 a)
|
| 819 |
+
{
|
| 820 |
+
return make_float3(b * a.x, b * a.y, b * a.z);
|
| 821 |
+
}
|
| 822 |
+
inline __host__ __device__ void operator*=(float3 &a, float b)
|
| 823 |
+
{
|
| 824 |
+
a.x *= b;
|
| 825 |
+
a.y *= b;
|
| 826 |
+
a.z *= b;
|
| 827 |
+
}
|
| 828 |
+
|
| 829 |
+
inline __host__ __device__ int3 operator*(int3 a, int3 b)
|
| 830 |
+
{
|
| 831 |
+
return make_int3(a.x * b.x, a.y * b.y, a.z * b.z);
|
| 832 |
+
}
|
| 833 |
+
inline __host__ __device__ void operator*=(int3 &a, int3 b)
|
| 834 |
+
{
|
| 835 |
+
a.x *= b.x;
|
| 836 |
+
a.y *= b.y;
|
| 837 |
+
a.z *= b.z;
|
| 838 |
+
}
|
| 839 |
+
inline __host__ __device__ int3 operator*(int3 a, int b)
|
| 840 |
+
{
|
| 841 |
+
return make_int3(a.x * b, a.y * b, a.z * b);
|
| 842 |
+
}
|
| 843 |
+
inline __host__ __device__ int3 operator*(int b, int3 a)
|
| 844 |
+
{
|
| 845 |
+
return make_int3(b * a.x, b * a.y, b * a.z);
|
| 846 |
+
}
|
| 847 |
+
inline __host__ __device__ void operator*=(int3 &a, int b)
|
| 848 |
+
{
|
| 849 |
+
a.x *= b;
|
| 850 |
+
a.y *= b;
|
| 851 |
+
a.z *= b;
|
| 852 |
+
}
|
| 853 |
+
|
| 854 |
+
inline __host__ __device__ uint3 operator*(uint3 a, uint3 b)
|
| 855 |
+
{
|
| 856 |
+
return make_uint3(a.x * b.x, a.y * b.y, a.z * b.z);
|
| 857 |
+
}
|
| 858 |
+
inline __host__ __device__ void operator*=(uint3 &a, uint3 b)
|
| 859 |
+
{
|
| 860 |
+
a.x *= b.x;
|
| 861 |
+
a.y *= b.y;
|
| 862 |
+
a.z *= b.z;
|
| 863 |
+
}
|
| 864 |
+
inline __host__ __device__ uint3 operator*(uint3 a, uint b)
|
| 865 |
+
{
|
| 866 |
+
return make_uint3(a.x * b, a.y * b, a.z * b);
|
| 867 |
+
}
|
| 868 |
+
inline __host__ __device__ uint3 operator*(uint b, uint3 a)
|
| 869 |
+
{
|
| 870 |
+
return make_uint3(b * a.x, b * a.y, b * a.z);
|
| 871 |
+
}
|
| 872 |
+
inline __host__ __device__ void operator*=(uint3 &a, uint b)
|
| 873 |
+
{
|
| 874 |
+
a.x *= b;
|
| 875 |
+
a.y *= b;
|
| 876 |
+
a.z *= b;
|
| 877 |
+
}
|
| 878 |
+
|
| 879 |
+
inline __host__ __device__ float4 operator*(float4 a, float4 b)
|
| 880 |
+
{
|
| 881 |
+
return make_float4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
|
| 882 |
+
}
|
| 883 |
+
inline __host__ __device__ void operator*=(float4 &a, float4 b)
|
| 884 |
+
{
|
| 885 |
+
a.x *= b.x;
|
| 886 |
+
a.y *= b.y;
|
| 887 |
+
a.z *= b.z;
|
| 888 |
+
a.w *= b.w;
|
| 889 |
+
}
|
| 890 |
+
inline __host__ __device__ float4 operator*(float4 a, float b)
|
| 891 |
+
{
|
| 892 |
+
return make_float4(a.x * b, a.y * b, a.z * b, a.w * b);
|
| 893 |
+
}
|
| 894 |
+
inline __host__ __device__ float4 operator*(float b, float4 a)
|
| 895 |
+
{
|
| 896 |
+
return make_float4(b * a.x, b * a.y, b * a.z, b * a.w);
|
| 897 |
+
}
|
| 898 |
+
inline __host__ __device__ void operator*=(float4 &a, float b)
|
| 899 |
+
{
|
| 900 |
+
a.x *= b;
|
| 901 |
+
a.y *= b;
|
| 902 |
+
a.z *= b;
|
| 903 |
+
a.w *= b;
|
| 904 |
+
}
|
| 905 |
+
|
| 906 |
+
inline __host__ __device__ int4 operator*(int4 a, int4 b)
|
| 907 |
+
{
|
| 908 |
+
return make_int4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
|
| 909 |
+
}
|
| 910 |
+
inline __host__ __device__ void operator*=(int4 &a, int4 b)
|
| 911 |
+
{
|
| 912 |
+
a.x *= b.x;
|
| 913 |
+
a.y *= b.y;
|
| 914 |
+
a.z *= b.z;
|
| 915 |
+
a.w *= b.w;
|
| 916 |
+
}
|
| 917 |
+
inline __host__ __device__ int4 operator*(int4 a, int b)
|
| 918 |
+
{
|
| 919 |
+
return make_int4(a.x * b, a.y * b, a.z * b, a.w * b);
|
| 920 |
+
}
|
| 921 |
+
inline __host__ __device__ int4 operator*(int b, int4 a)
|
| 922 |
+
{
|
| 923 |
+
return make_int4(b * a.x, b * a.y, b * a.z, b * a.w);
|
| 924 |
+
}
|
| 925 |
+
inline __host__ __device__ void operator*=(int4 &a, int b)
|
| 926 |
+
{
|
| 927 |
+
a.x *= b;
|
| 928 |
+
a.y *= b;
|
| 929 |
+
a.z *= b;
|
| 930 |
+
a.w *= b;
|
| 931 |
+
}
|
| 932 |
+
|
| 933 |
+
inline __host__ __device__ uint4 operator*(uint4 a, uint4 b)
|
| 934 |
+
{
|
| 935 |
+
return make_uint4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
|
| 936 |
+
}
|
| 937 |
+
inline __host__ __device__ void operator*=(uint4 &a, uint4 b)
|
| 938 |
+
{
|
| 939 |
+
a.x *= b.x;
|
| 940 |
+
a.y *= b.y;
|
| 941 |
+
a.z *= b.z;
|
| 942 |
+
a.w *= b.w;
|
| 943 |
+
}
|
| 944 |
+
inline __host__ __device__ uint4 operator*(uint4 a, uint b)
|
| 945 |
+
{
|
| 946 |
+
return make_uint4(a.x * b, a.y * b, a.z * b, a.w * b);
|
| 947 |
+
}
|
| 948 |
+
inline __host__ __device__ uint4 operator*(uint b, uint4 a)
|
| 949 |
+
{
|
| 950 |
+
return make_uint4(b * a.x, b * a.y, b * a.z, b * a.w);
|
| 951 |
+
}
|
| 952 |
+
inline __host__ __device__ void operator*=(uint4 &a, uint b)
|
| 953 |
+
{
|
| 954 |
+
a.x *= b;
|
| 955 |
+
a.y *= b;
|
| 956 |
+
a.z *= b;
|
| 957 |
+
a.w *= b;
|
| 958 |
+
}
|
| 959 |
+
|
| 960 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 961 |
+
// divide
|
| 962 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 963 |
+
|
| 964 |
+
inline __host__ __device__ float2 operator/(float2 a, float2 b)
|
| 965 |
+
{
|
| 966 |
+
return make_float2(a.x / b.x, a.y / b.y);
|
| 967 |
+
}
|
| 968 |
+
inline __host__ __device__ void operator/=(float2 &a, float2 b)
|
| 969 |
+
{
|
| 970 |
+
a.x /= b.x;
|
| 971 |
+
a.y /= b.y;
|
| 972 |
+
}
|
| 973 |
+
inline __host__ __device__ float2 operator/(float2 a, float b)
|
| 974 |
+
{
|
| 975 |
+
return make_float2(a.x / b, a.y / b);
|
| 976 |
+
}
|
| 977 |
+
inline __host__ __device__ void operator/=(float2 &a, float b)
|
| 978 |
+
{
|
| 979 |
+
a.x /= b;
|
| 980 |
+
a.y /= b;
|
| 981 |
+
}
|
| 982 |
+
inline __host__ __device__ float2 operator/(float b, float2 a)
|
| 983 |
+
{
|
| 984 |
+
return make_float2(b / a.x, b / a.y);
|
| 985 |
+
}
|
| 986 |
+
|
| 987 |
+
inline __host__ __device__ float3 operator/(float3 a, float3 b)
|
| 988 |
+
{
|
| 989 |
+
return make_float3(a.x / b.x, a.y / b.y, a.z / b.z);
|
| 990 |
+
}
|
| 991 |
+
inline __host__ __device__ void operator/=(float3 &a, float3 b)
|
| 992 |
+
{
|
| 993 |
+
a.x /= b.x;
|
| 994 |
+
a.y /= b.y;
|
| 995 |
+
a.z /= b.z;
|
| 996 |
+
}
|
| 997 |
+
inline __host__ __device__ float3 operator/(float3 a, float b)
|
| 998 |
+
{
|
| 999 |
+
return make_float3(a.x / b, a.y / b, a.z / b);
|
| 1000 |
+
}
|
| 1001 |
+
inline __host__ __device__ void operator/=(float3 &a, float b)
|
| 1002 |
+
{
|
| 1003 |
+
a.x /= b;
|
| 1004 |
+
a.y /= b;
|
| 1005 |
+
a.z /= b;
|
| 1006 |
+
}
|
| 1007 |
+
inline __host__ __device__ float3 operator/(float b, float3 a)
|
| 1008 |
+
{
|
| 1009 |
+
return make_float3(b / a.x, b / a.y, b / a.z);
|
| 1010 |
+
}
|
| 1011 |
+
|
| 1012 |
+
inline __host__ __device__ float4 operator/(float4 a, float4 b)
|
| 1013 |
+
{
|
| 1014 |
+
return make_float4(a.x / b.x, a.y / b.y, a.z / b.z, a.w / b.w);
|
| 1015 |
+
}
|
| 1016 |
+
inline __host__ __device__ void operator/=(float4 &a, float4 b)
|
| 1017 |
+
{
|
| 1018 |
+
a.x /= b.x;
|
| 1019 |
+
a.y /= b.y;
|
| 1020 |
+
a.z /= b.z;
|
| 1021 |
+
a.w /= b.w;
|
| 1022 |
+
}
|
| 1023 |
+
inline __host__ __device__ float4 operator/(float4 a, float b)
|
| 1024 |
+
{
|
| 1025 |
+
return make_float4(a.x / b, a.y / b, a.z / b, a.w / b);
|
| 1026 |
+
}
|
| 1027 |
+
inline __host__ __device__ void operator/=(float4 &a, float b)
|
| 1028 |
+
{
|
| 1029 |
+
a.x /= b;
|
| 1030 |
+
a.y /= b;
|
| 1031 |
+
a.z /= b;
|
| 1032 |
+
a.w /= b;
|
| 1033 |
+
}
|
| 1034 |
+
inline __host__ __device__ float4 operator/(float b, float4 a)
|
| 1035 |
+
{
|
| 1036 |
+
return make_float4(b / a.x, b / a.y, b / a.z, b / a.w);
|
| 1037 |
+
}
|
| 1038 |
+
|
| 1039 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1040 |
+
// min
|
| 1041 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1042 |
+
|
| 1043 |
+
inline __host__ __device__ float2 fminf(float2 a, float2 b)
|
| 1044 |
+
{
|
| 1045 |
+
return make_float2(fminf(a.x,b.x), fminf(a.y,b.y));
|
| 1046 |
+
}
|
| 1047 |
+
inline __host__ __device__ float3 fminf(float3 a, float3 b)
|
| 1048 |
+
{
|
| 1049 |
+
return make_float3(fminf(a.x,b.x), fminf(a.y,b.y), fminf(a.z,b.z));
|
| 1050 |
+
}
|
| 1051 |
+
inline __host__ __device__ float4 fminf(float4 a, float4 b)
|
| 1052 |
+
{
|
| 1053 |
+
return make_float4(fminf(a.x,b.x), fminf(a.y,b.y), fminf(a.z,b.z), fminf(a.w,b.w));
|
| 1054 |
+
}
|
| 1055 |
+
|
| 1056 |
+
inline __host__ __device__ int2 min(int2 a, int2 b)
|
| 1057 |
+
{
|
| 1058 |
+
return make_int2(min(a.x,b.x), min(a.y,b.y));
|
| 1059 |
+
}
|
| 1060 |
+
inline __host__ __device__ int3 min(int3 a, int3 b)
|
| 1061 |
+
{
|
| 1062 |
+
return make_int3(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z));
|
| 1063 |
+
}
|
| 1064 |
+
inline __host__ __device__ int4 min(int4 a, int4 b)
|
| 1065 |
+
{
|
| 1066 |
+
return make_int4(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z), min(a.w,b.w));
|
| 1067 |
+
}
|
| 1068 |
+
|
| 1069 |
+
inline __host__ __device__ uint2 min(uint2 a, uint2 b)
|
| 1070 |
+
{
|
| 1071 |
+
return make_uint2(min(a.x,b.x), min(a.y,b.y));
|
| 1072 |
+
}
|
| 1073 |
+
inline __host__ __device__ uint3 min(uint3 a, uint3 b)
|
| 1074 |
+
{
|
| 1075 |
+
return make_uint3(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z));
|
| 1076 |
+
}
|
| 1077 |
+
inline __host__ __device__ uint4 min(uint4 a, uint4 b)
|
| 1078 |
+
{
|
| 1079 |
+
return make_uint4(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z), min(a.w,b.w));
|
| 1080 |
+
}
|
| 1081 |
+
|
| 1082 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1083 |
+
// max
|
| 1084 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1085 |
+
|
| 1086 |
+
inline __host__ __device__ float2 fmaxf(float2 a, float2 b)
|
| 1087 |
+
{
|
| 1088 |
+
return make_float2(fmaxf(a.x,b.x), fmaxf(a.y,b.y));
|
| 1089 |
+
}
|
| 1090 |
+
inline __host__ __device__ float3 fmaxf(float3 a, float3 b)
|
| 1091 |
+
{
|
| 1092 |
+
return make_float3(fmaxf(a.x,b.x), fmaxf(a.y,b.y), fmaxf(a.z,b.z));
|
| 1093 |
+
}
|
| 1094 |
+
inline __host__ __device__ float4 fmaxf(float4 a, float4 b)
|
| 1095 |
+
{
|
| 1096 |
+
return make_float4(fmaxf(a.x,b.x), fmaxf(a.y,b.y), fmaxf(a.z,b.z), fmaxf(a.w,b.w));
|
| 1097 |
+
}
|
| 1098 |
+
|
| 1099 |
+
inline __host__ __device__ int2 max(int2 a, int2 b)
|
| 1100 |
+
{
|
| 1101 |
+
return make_int2(max(a.x,b.x), max(a.y,b.y));
|
| 1102 |
+
}
|
| 1103 |
+
inline __host__ __device__ int3 max(int3 a, int3 b)
|
| 1104 |
+
{
|
| 1105 |
+
return make_int3(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z));
|
| 1106 |
+
}
|
| 1107 |
+
inline __host__ __device__ int4 max(int4 a, int4 b)
|
| 1108 |
+
{
|
| 1109 |
+
return make_int4(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z), max(a.w,b.w));
|
| 1110 |
+
}
|
| 1111 |
+
|
| 1112 |
+
inline __host__ __device__ uint2 max(uint2 a, uint2 b)
|
| 1113 |
+
{
|
| 1114 |
+
return make_uint2(max(a.x,b.x), max(a.y,b.y));
|
| 1115 |
+
}
|
| 1116 |
+
inline __host__ __device__ uint3 max(uint3 a, uint3 b)
|
| 1117 |
+
{
|
| 1118 |
+
return make_uint3(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z));
|
| 1119 |
+
}
|
| 1120 |
+
inline __host__ __device__ uint4 max(uint4 a, uint4 b)
|
| 1121 |
+
{
|
| 1122 |
+
return make_uint4(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z), max(a.w,b.w));
|
| 1123 |
+
}
|
| 1124 |
+
|
| 1125 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1126 |
+
// lerp
|
| 1127 |
+
// - linear interpolation between a and b, based on value t in [0, 1] range
|
| 1128 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1129 |
+
|
| 1130 |
+
inline __device__ __host__ float lerp(float a, float b, float t)
|
| 1131 |
+
{
|
| 1132 |
+
return a + t*(b-a);
|
| 1133 |
+
}
|
| 1134 |
+
inline __device__ __host__ float2 lerp(float2 a, float2 b, float t)
|
| 1135 |
+
{
|
| 1136 |
+
return a + t*(b-a);
|
| 1137 |
+
}
|
| 1138 |
+
inline __device__ __host__ float3 lerp(float3 a, float3 b, float t)
|
| 1139 |
+
{
|
| 1140 |
+
return a + t*(b-a);
|
| 1141 |
+
}
|
| 1142 |
+
inline __device__ __host__ float4 lerp(float4 a, float4 b, float t)
|
| 1143 |
+
{
|
| 1144 |
+
return a + t*(b-a);
|
| 1145 |
+
}
|
| 1146 |
+
|
| 1147 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1148 |
+
// clamp
|
| 1149 |
+
// - clamp the value v to be in the range [a, b]
|
| 1150 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1151 |
+
|
| 1152 |
+
inline __device__ __host__ float clamp(float f, float a, float b)
|
| 1153 |
+
{
|
| 1154 |
+
return fmaxf(a, fminf(f, b));
|
| 1155 |
+
}
|
| 1156 |
+
inline __device__ __host__ int clamp(int f, int a, int b)
|
| 1157 |
+
{
|
| 1158 |
+
return max(a, min(f, b));
|
| 1159 |
+
}
|
| 1160 |
+
inline __device__ __host__ uint clamp(uint f, uint a, uint b)
|
| 1161 |
+
{
|
| 1162 |
+
return max(a, min(f, b));
|
| 1163 |
+
}
|
| 1164 |
+
|
| 1165 |
+
inline __device__ __host__ float2 clamp(float2 v, float a, float b)
|
| 1166 |
+
{
|
| 1167 |
+
return make_float2(clamp(v.x, a, b), clamp(v.y, a, b));
|
| 1168 |
+
}
|
| 1169 |
+
inline __device__ __host__ float2 clamp(float2 v, float2 a, float2 b)
|
| 1170 |
+
{
|
| 1171 |
+
return make_float2(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y));
|
| 1172 |
+
}
|
| 1173 |
+
inline __device__ __host__ float3 clamp(float3 v, float a, float b)
|
| 1174 |
+
{
|
| 1175 |
+
return make_float3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b));
|
| 1176 |
+
}
|
| 1177 |
+
inline __device__ __host__ float3 clamp(float3 v, float3 a, float3 b)
|
| 1178 |
+
{
|
| 1179 |
+
return make_float3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z));
|
| 1180 |
+
}
|
| 1181 |
+
inline __device__ __host__ float4 clamp(float4 v, float a, float b)
|
| 1182 |
+
{
|
| 1183 |
+
return make_float4(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b), clamp(v.w, a, b));
|
| 1184 |
+
}
|
| 1185 |
+
inline __device__ __host__ float4 clamp(float4 v, float4 a, float4 b)
|
| 1186 |
+
{
|
| 1187 |
+
return make_float4(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z), clamp(v.w, a.w, b.w));
|
| 1188 |
+
}
|
| 1189 |
+
|
| 1190 |
+
inline __device__ __host__ int2 clamp(int2 v, int a, int b)
|
| 1191 |
+
{
|
| 1192 |
+
return make_int2(clamp(v.x, a, b), clamp(v.y, a, b));
|
| 1193 |
+
}
|
| 1194 |
+
inline __device__ __host__ int2 clamp(int2 v, int2 a, int2 b)
|
| 1195 |
+
{
|
| 1196 |
+
return make_int2(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y));
|
| 1197 |
+
}
|
| 1198 |
+
inline __device__ __host__ int3 clamp(int3 v, int a, int b)
|
| 1199 |
+
{
|
| 1200 |
+
return make_int3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b));
|
| 1201 |
+
}
|
| 1202 |
+
inline __device__ __host__ int3 clamp(int3 v, int3 a, int3 b)
|
| 1203 |
+
{
|
| 1204 |
+
return make_int3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z));
|
| 1205 |
+
}
|
| 1206 |
+
inline __device__ __host__ int4 clamp(int4 v, int a, int b)
|
| 1207 |
+
{
|
| 1208 |
+
return make_int4(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b), clamp(v.w, a, b));
|
| 1209 |
+
}
|
| 1210 |
+
inline __device__ __host__ int4 clamp(int4 v, int4 a, int4 b)
|
| 1211 |
+
{
|
| 1212 |
+
return make_int4(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z), clamp(v.w, a.w, b.w));
|
| 1213 |
+
}
|
| 1214 |
+
|
| 1215 |
+
inline __device__ __host__ uint2 clamp(uint2 v, uint a, uint b)
|
| 1216 |
+
{
|
| 1217 |
+
return make_uint2(clamp(v.x, a, b), clamp(v.y, a, b));
|
| 1218 |
+
}
|
| 1219 |
+
inline __device__ __host__ uint2 clamp(uint2 v, uint2 a, uint2 b)
|
| 1220 |
+
{
|
| 1221 |
+
return make_uint2(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y));
|
| 1222 |
+
}
|
| 1223 |
+
inline __device__ __host__ uint3 clamp(uint3 v, uint a, uint b)
|
| 1224 |
+
{
|
| 1225 |
+
return make_uint3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b));
|
| 1226 |
+
}
|
| 1227 |
+
inline __device__ __host__ uint3 clamp(uint3 v, uint3 a, uint3 b)
|
| 1228 |
+
{
|
| 1229 |
+
return make_uint3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z));
|
| 1230 |
+
}
|
| 1231 |
+
inline __device__ __host__ uint4 clamp(uint4 v, uint a, uint b)
|
| 1232 |
+
{
|
| 1233 |
+
return make_uint4(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b), clamp(v.w, a, b));
|
| 1234 |
+
}
|
| 1235 |
+
inline __device__ __host__ uint4 clamp(uint4 v, uint4 a, uint4 b)
|
| 1236 |
+
{
|
| 1237 |
+
return make_uint4(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z), clamp(v.w, a.w, b.w));
|
| 1238 |
+
}
|
| 1239 |
+
|
| 1240 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1241 |
+
// dot product
|
| 1242 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1243 |
+
|
| 1244 |
+
inline __host__ __device__ float dot(float2 a, float2 b)
|
| 1245 |
+
{
|
| 1246 |
+
return a.x * b.x + a.y * b.y;
|
| 1247 |
+
}
|
| 1248 |
+
inline __host__ __device__ float dot(float3 a, float3 b)
|
| 1249 |
+
{
|
| 1250 |
+
return a.x * b.x + a.y * b.y + a.z * b.z;
|
| 1251 |
+
}
|
| 1252 |
+
inline __host__ __device__ float dot(float4 a, float4 b)
|
| 1253 |
+
{
|
| 1254 |
+
return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w;
|
| 1255 |
+
}
|
| 1256 |
+
|
| 1257 |
+
inline __host__ __device__ int dot(int2 a, int2 b)
|
| 1258 |
+
{
|
| 1259 |
+
return a.x * b.x + a.y * b.y;
|
| 1260 |
+
}
|
| 1261 |
+
inline __host__ __device__ int dot(int3 a, int3 b)
|
| 1262 |
+
{
|
| 1263 |
+
return a.x * b.x + a.y * b.y + a.z * b.z;
|
| 1264 |
+
}
|
| 1265 |
+
inline __host__ __device__ int dot(int4 a, int4 b)
|
| 1266 |
+
{
|
| 1267 |
+
return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w;
|
| 1268 |
+
}
|
| 1269 |
+
|
| 1270 |
+
inline __host__ __device__ uint dot(uint2 a, uint2 b)
|
| 1271 |
+
{
|
| 1272 |
+
return a.x * b.x + a.y * b.y;
|
| 1273 |
+
}
|
| 1274 |
+
inline __host__ __device__ uint dot(uint3 a, uint3 b)
|
| 1275 |
+
{
|
| 1276 |
+
return a.x * b.x + a.y * b.y + a.z * b.z;
|
| 1277 |
+
}
|
| 1278 |
+
inline __host__ __device__ uint dot(uint4 a, uint4 b)
|
| 1279 |
+
{
|
| 1280 |
+
return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w;
|
| 1281 |
+
}
|
| 1282 |
+
|
| 1283 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1284 |
+
// length
|
| 1285 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1286 |
+
|
| 1287 |
+
inline __host__ __device__ float length(float2 v)
|
| 1288 |
+
{
|
| 1289 |
+
return sqrtf(dot(v, v));
|
| 1290 |
+
}
|
| 1291 |
+
inline __host__ __device__ float length(float3 v)
|
| 1292 |
+
{
|
| 1293 |
+
return sqrtf(dot(v, v));
|
| 1294 |
+
}
|
| 1295 |
+
inline __host__ __device__ float length(float4 v)
|
| 1296 |
+
{
|
| 1297 |
+
return sqrtf(dot(v, v));
|
| 1298 |
+
}
|
| 1299 |
+
|
| 1300 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1301 |
+
// normalize
|
| 1302 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1303 |
+
|
| 1304 |
+
inline __host__ __device__ float2 normalize(float2 v)
|
| 1305 |
+
{
|
| 1306 |
+
float invLen = rsqrtf(dot(v, v));
|
| 1307 |
+
return v * invLen;
|
| 1308 |
+
}
|
| 1309 |
+
inline __host__ __device__ float3 normalize(float3 v)
|
| 1310 |
+
{
|
| 1311 |
+
float invLen = rsqrtf(dot(v, v));
|
| 1312 |
+
return v * invLen;
|
| 1313 |
+
}
|
| 1314 |
+
inline __host__ __device__ float4 normalize(float4 v)
|
| 1315 |
+
{
|
| 1316 |
+
float invLen = rsqrtf(dot(v, v));
|
| 1317 |
+
return v * invLen;
|
| 1318 |
+
}
|
| 1319 |
+
|
| 1320 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1321 |
+
// floor
|
| 1322 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1323 |
+
|
| 1324 |
+
inline __host__ __device__ float2 floorf(float2 v)
|
| 1325 |
+
{
|
| 1326 |
+
return make_float2(floorf(v.x), floorf(v.y));
|
| 1327 |
+
}
|
| 1328 |
+
inline __host__ __device__ float3 floorf(float3 v)
|
| 1329 |
+
{
|
| 1330 |
+
return make_float3(floorf(v.x), floorf(v.y), floorf(v.z));
|
| 1331 |
+
}
|
| 1332 |
+
inline __host__ __device__ float4 floorf(float4 v)
|
| 1333 |
+
{
|
| 1334 |
+
return make_float4(floorf(v.x), floorf(v.y), floorf(v.z), floorf(v.w));
|
| 1335 |
+
}
|
| 1336 |
+
|
| 1337 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1338 |
+
// frac - returns the fractional portion of a scalar or each vector component
|
| 1339 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1340 |
+
|
| 1341 |
+
inline __host__ __device__ float fracf(float v)
|
| 1342 |
+
{
|
| 1343 |
+
return v - floorf(v);
|
| 1344 |
+
}
|
| 1345 |
+
inline __host__ __device__ float2 fracf(float2 v)
|
| 1346 |
+
{
|
| 1347 |
+
return make_float2(fracf(v.x), fracf(v.y));
|
| 1348 |
+
}
|
| 1349 |
+
inline __host__ __device__ float3 fracf(float3 v)
|
| 1350 |
+
{
|
| 1351 |
+
return make_float3(fracf(v.x), fracf(v.y), fracf(v.z));
|
| 1352 |
+
}
|
| 1353 |
+
inline __host__ __device__ float4 fracf(float4 v)
|
| 1354 |
+
{
|
| 1355 |
+
return make_float4(fracf(v.x), fracf(v.y), fracf(v.z), fracf(v.w));
|
| 1356 |
+
}
|
| 1357 |
+
|
| 1358 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1359 |
+
// fmod
|
| 1360 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1361 |
+
|
| 1362 |
+
inline __host__ __device__ float2 fmodf(float2 a, float2 b)
|
| 1363 |
+
{
|
| 1364 |
+
return make_float2(fmodf(a.x, b.x), fmodf(a.y, b.y));
|
| 1365 |
+
}
|
| 1366 |
+
inline __host__ __device__ float3 fmodf(float3 a, float3 b)
|
| 1367 |
+
{
|
| 1368 |
+
return make_float3(fmodf(a.x, b.x), fmodf(a.y, b.y), fmodf(a.z, b.z));
|
| 1369 |
+
}
|
| 1370 |
+
inline __host__ __device__ float4 fmodf(float4 a, float4 b)
|
| 1371 |
+
{
|
| 1372 |
+
return make_float4(fmodf(a.x, b.x), fmodf(a.y, b.y), fmodf(a.z, b.z), fmodf(a.w, b.w));
|
| 1373 |
+
}
|
| 1374 |
+
|
| 1375 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1376 |
+
// absolute value
|
| 1377 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1378 |
+
|
| 1379 |
+
inline __host__ __device__ float2 fabs(float2 v)
|
| 1380 |
+
{
|
| 1381 |
+
return make_float2(fabs(v.x), fabs(v.y));
|
| 1382 |
+
}
|
| 1383 |
+
inline __host__ __device__ float3 fabs(float3 v)
|
| 1384 |
+
{
|
| 1385 |
+
return make_float3(fabs(v.x), fabs(v.y), fabs(v.z));
|
| 1386 |
+
}
|
| 1387 |
+
inline __host__ __device__ float4 fabs(float4 v)
|
| 1388 |
+
{
|
| 1389 |
+
return make_float4(fabs(v.x), fabs(v.y), fabs(v.z), fabs(v.w));
|
| 1390 |
+
}
|
| 1391 |
+
|
| 1392 |
+
inline __host__ __device__ int2 abs(int2 v)
|
| 1393 |
+
{
|
| 1394 |
+
return make_int2(abs(v.x), abs(v.y));
|
| 1395 |
+
}
|
| 1396 |
+
inline __host__ __device__ int3 abs(int3 v)
|
| 1397 |
+
{
|
| 1398 |
+
return make_int3(abs(v.x), abs(v.y), abs(v.z));
|
| 1399 |
+
}
|
| 1400 |
+
inline __host__ __device__ int4 abs(int4 v)
|
| 1401 |
+
{
|
| 1402 |
+
return make_int4(abs(v.x), abs(v.y), abs(v.z), abs(v.w));
|
| 1403 |
+
}
|
| 1404 |
+
|
| 1405 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1406 |
+
// reflect
|
| 1407 |
+
// - returns reflection of incident ray I around surface normal N
|
| 1408 |
+
// - N should be normalized, reflected vector's length is equal to length of I
|
| 1409 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1410 |
+
|
| 1411 |
+
inline __host__ __device__ float3 reflect(float3 i, float3 n)
|
| 1412 |
+
{
|
| 1413 |
+
return i - 2.0f * n * dot(n,i);
|
| 1414 |
+
}
|
| 1415 |
+
|
| 1416 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1417 |
+
// cross product
|
| 1418 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1419 |
+
|
| 1420 |
+
inline __host__ __device__ float3 cross(float3 a, float3 b)
|
| 1421 |
+
{
|
| 1422 |
+
return make_float3(a.y*b.z - a.z*b.y, a.z*b.x - a.x*b.z, a.x*b.y - a.y*b.x);
|
| 1423 |
+
}
|
| 1424 |
+
|
| 1425 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1426 |
+
// smoothstep
|
| 1427 |
+
// - returns 0 if x < a
|
| 1428 |
+
// - returns 1 if x > b
|
| 1429 |
+
// - otherwise returns smooth interpolation between 0 and 1 based on x
|
| 1430 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1431 |
+
|
| 1432 |
+
inline __device__ __host__ float smoothstep(float a, float b, float x)
|
| 1433 |
+
{
|
| 1434 |
+
float y = clamp((x - a) / (b - a), 0.0f, 1.0f);
|
| 1435 |
+
return (y*y*(3.0f - (2.0f*y)));
|
| 1436 |
+
}
|
| 1437 |
+
inline __device__ __host__ float2 smoothstep(float2 a, float2 b, float2 x)
|
| 1438 |
+
{
|
| 1439 |
+
float2 y = clamp((x - a) / (b - a), 0.0f, 1.0f);
|
| 1440 |
+
return (y*y*(make_float2(3.0f) - (make_float2(2.0f)*y)));
|
| 1441 |
+
}
|
| 1442 |
+
inline __device__ __host__ float3 smoothstep(float3 a, float3 b, float3 x)
|
| 1443 |
+
{
|
| 1444 |
+
float3 y = clamp((x - a) / (b - a), 0.0f, 1.0f);
|
| 1445 |
+
return (y*y*(make_float3(3.0f) - (make_float3(2.0f)*y)));
|
| 1446 |
+
}
|
| 1447 |
+
inline __device__ __host__ float4 smoothstep(float4 a, float4 b, float4 x)
|
| 1448 |
+
{
|
| 1449 |
+
float4 y = clamp((x - a) / (b - a), 0.0f, 1.0f);
|
| 1450 |
+
return (y*y*(make_float4(3.0f) - (make_float4(2.0f)*y)));
|
| 1451 |
+
}
|
| 1452 |
+
|
| 1453 |
+
#endif
|
dva/mvp/extensions/mvpraymarch/makefile
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
all:
|
| 2 |
+
python setup.py build_ext --inplace
|
dva/mvp/extensions/mvpraymarch/mvpraymarch.cpp
ADDED
|
@@ -0,0 +1,405 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
//
|
| 4 |
+
// This source code is licensed under the license found in the
|
| 5 |
+
// LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
#include <torch/extension.h>
|
| 8 |
+
#include <c10/cuda/CUDAStream.h>
|
| 9 |
+
|
| 10 |
+
#include <vector>
|
| 11 |
+
|
| 12 |
+
void compute_morton_cuda(
|
| 13 |
+
int N, int K,
|
| 14 |
+
float * primpos,
|
| 15 |
+
int * code,
|
| 16 |
+
int algorithm,
|
| 17 |
+
cudaStream_t stream);
|
| 18 |
+
|
| 19 |
+
void build_tree_cuda(
|
| 20 |
+
int N, int K,
|
| 21 |
+
int * sortedcode,
|
| 22 |
+
int * nodechildren,
|
| 23 |
+
int * nodeparent,
|
| 24 |
+
cudaStream_t stream);
|
| 25 |
+
|
| 26 |
+
void compute_aabb_cuda(
|
| 27 |
+
int N, int K,
|
| 28 |
+
float * primpos,
|
| 29 |
+
float * primrot,
|
| 30 |
+
float * primscale,
|
| 31 |
+
int * sortedobjid,
|
| 32 |
+
int * nodechildren,
|
| 33 |
+
int * nodeparent,
|
| 34 |
+
float * nodeaabb,
|
| 35 |
+
int algorithm,
|
| 36 |
+
cudaStream_t stream);
|
| 37 |
+
|
| 38 |
+
void raymarch_forward_cuda(
|
| 39 |
+
int N, int H, int W, int K,
|
| 40 |
+
float * rayposim,
|
| 41 |
+
float * raydirim,
|
| 42 |
+
float stepsize,
|
| 43 |
+
float * tminmaxim,
|
| 44 |
+
|
| 45 |
+
int * sortedobjid,
|
| 46 |
+
int * nodechildren,
|
| 47 |
+
float * nodeaabb,
|
| 48 |
+
|
| 49 |
+
float * primpos,
|
| 50 |
+
float * primrot,
|
| 51 |
+
float * primscale,
|
| 52 |
+
|
| 53 |
+
int TD, int TH, int TW,
|
| 54 |
+
float * tplate,
|
| 55 |
+
int WD, int WH, int WW,
|
| 56 |
+
float * warp,
|
| 57 |
+
|
| 58 |
+
float * rayrgbaim,
|
| 59 |
+
float * raysatim,
|
| 60 |
+
int * raytermim,
|
| 61 |
+
|
| 62 |
+
int algorithm, bool sortboxes, int maxhitboxes, bool synchitboxes,
|
| 63 |
+
bool chlast, float fadescale, float fadeexp, int accum, float termthresh,
|
| 64 |
+
int griddim, int blocksizex, int blocksizey,
|
| 65 |
+
cudaStream_t stream);
|
| 66 |
+
|
| 67 |
+
void raymarch_backward_cuda(
|
| 68 |
+
int N, int H, int W, int K,
|
| 69 |
+
float * rayposim,
|
| 70 |
+
float * raydirim,
|
| 71 |
+
float stepsize,
|
| 72 |
+
float * tminmaxim,
|
| 73 |
+
|
| 74 |
+
int * sortedobjid,
|
| 75 |
+
int * nodechildren,
|
| 76 |
+
float * nodeaabb,
|
| 77 |
+
|
| 78 |
+
float * primpos,
|
| 79 |
+
float * grad_primpos,
|
| 80 |
+
float * primrot,
|
| 81 |
+
float * grad_primrot,
|
| 82 |
+
float * primscale,
|
| 83 |
+
float * grad_primscale,
|
| 84 |
+
|
| 85 |
+
int TD, int TH, int TW,
|
| 86 |
+
float * tplate,
|
| 87 |
+
float * grad_tplate,
|
| 88 |
+
int WD, int WH, int WW,
|
| 89 |
+
float * warp,
|
| 90 |
+
float * grad_warp,
|
| 91 |
+
|
| 92 |
+
float * rayrgbaim,
|
| 93 |
+
float * grad_rayrgba,
|
| 94 |
+
float * raysatim,
|
| 95 |
+
int * raytermim,
|
| 96 |
+
|
| 97 |
+
int algorithm, bool sortboxes, int maxhitboxes, bool synchitboxes,
|
| 98 |
+
bool chlast, float fadescale, float fadeexp, int accum, float termthresh,
|
| 99 |
+
int griddim, int blocksizex, int blocksizey,
|
| 100 |
+
cudaStream_t stream);
|
| 101 |
+
|
| 102 |
+
#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
|
| 103 |
+
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
|
| 104 |
+
#define CHECK_INPUT(x) CHECK_CUDA((x)); CHECK_CONTIGUOUS((x))
|
| 105 |
+
|
| 106 |
+
std::vector<torch::Tensor> compute_morton(
|
| 107 |
+
torch::Tensor primpos,
|
| 108 |
+
torch::Tensor code,
|
| 109 |
+
int algorithm) {
|
| 110 |
+
CHECK_INPUT(primpos);
|
| 111 |
+
CHECK_INPUT(code);
|
| 112 |
+
|
| 113 |
+
int N = primpos.size(0);
|
| 114 |
+
int K = primpos.size(1);
|
| 115 |
+
|
| 116 |
+
compute_morton_cuda(
|
| 117 |
+
N, K,
|
| 118 |
+
reinterpret_cast<float *>(primpos.data_ptr()),
|
| 119 |
+
reinterpret_cast<int *>(code.data_ptr()),
|
| 120 |
+
algorithm,
|
| 121 |
+
0);
|
| 122 |
+
|
| 123 |
+
return {};
|
| 124 |
+
}
|
| 125 |
+
|
| 126 |
+
std::vector<torch::Tensor> build_tree(
|
| 127 |
+
torch::Tensor sortedcode,
|
| 128 |
+
torch::Tensor nodechildren,
|
| 129 |
+
torch::Tensor nodeparent) {
|
| 130 |
+
CHECK_INPUT(sortedcode);
|
| 131 |
+
CHECK_INPUT(nodechildren);
|
| 132 |
+
CHECK_INPUT(nodeparent);
|
| 133 |
+
|
| 134 |
+
int N = sortedcode.size(0);
|
| 135 |
+
int K = sortedcode.size(1);
|
| 136 |
+
|
| 137 |
+
build_tree_cuda(N, K,
|
| 138 |
+
reinterpret_cast<int *>(sortedcode.data_ptr()),
|
| 139 |
+
reinterpret_cast<int *>(nodechildren.data_ptr()),
|
| 140 |
+
reinterpret_cast<int *>(nodeparent.data_ptr()),
|
| 141 |
+
0);
|
| 142 |
+
|
| 143 |
+
return {};
|
| 144 |
+
}
|
| 145 |
+
|
| 146 |
+
std::vector<torch::Tensor> compute_aabb(
|
| 147 |
+
torch::Tensor primpos,
|
| 148 |
+
torch::optional<torch::Tensor> primrot,
|
| 149 |
+
torch::optional<torch::Tensor> primscale,
|
| 150 |
+
torch::Tensor sortedobjid,
|
| 151 |
+
torch::Tensor nodechildren,
|
| 152 |
+
torch::Tensor nodeparent,
|
| 153 |
+
torch::Tensor nodeaabb,
|
| 154 |
+
int algorithm) {
|
| 155 |
+
CHECK_INPUT(sortedobjid);
|
| 156 |
+
CHECK_INPUT(primpos);
|
| 157 |
+
if (primrot) { CHECK_INPUT(*primrot); }
|
| 158 |
+
if (primscale) { CHECK_INPUT(*primscale); }
|
| 159 |
+
CHECK_INPUT(nodechildren);
|
| 160 |
+
CHECK_INPUT(nodeparent);
|
| 161 |
+
CHECK_INPUT(nodeaabb);
|
| 162 |
+
|
| 163 |
+
int N = primpos.size(0);
|
| 164 |
+
int K = primpos.size(1);
|
| 165 |
+
|
| 166 |
+
compute_aabb_cuda(N, K,
|
| 167 |
+
reinterpret_cast<float *>(primpos.data_ptr()),
|
| 168 |
+
primrot ? reinterpret_cast<float *>(primrot->data_ptr()) : nullptr,
|
| 169 |
+
primscale ? reinterpret_cast<float *>(primscale->data_ptr()) : nullptr,
|
| 170 |
+
reinterpret_cast<int *>(sortedobjid.data_ptr()),
|
| 171 |
+
reinterpret_cast<int *>(nodechildren.data_ptr()),
|
| 172 |
+
reinterpret_cast<int *>(nodeparent.data_ptr()),
|
| 173 |
+
reinterpret_cast<float *>(nodeaabb.data_ptr()),
|
| 174 |
+
algorithm,
|
| 175 |
+
0);
|
| 176 |
+
|
| 177 |
+
return {};
|
| 178 |
+
}
|
| 179 |
+
|
| 180 |
+
std::vector<torch::Tensor> raymarch_forward(
|
| 181 |
+
torch::Tensor rayposim,
|
| 182 |
+
torch::Tensor raydirim,
|
| 183 |
+
float stepsize,
|
| 184 |
+
torch::Tensor tminmaxim,
|
| 185 |
+
|
| 186 |
+
torch::optional<torch::Tensor> sortedobjid,
|
| 187 |
+
torch::optional<torch::Tensor> nodechildren,
|
| 188 |
+
torch::optional<torch::Tensor> nodeaabb,
|
| 189 |
+
|
| 190 |
+
torch::Tensor primpos,
|
| 191 |
+
torch::optional<torch::Tensor> primrot,
|
| 192 |
+
torch::optional<torch::Tensor> primscale,
|
| 193 |
+
|
| 194 |
+
torch::Tensor tplate,
|
| 195 |
+
torch::optional<torch::Tensor> warp,
|
| 196 |
+
|
| 197 |
+
torch::Tensor rayrgbaim,
|
| 198 |
+
torch::optional<torch::Tensor> raysatim,
|
| 199 |
+
torch::optional<torch::Tensor> raytermim,
|
| 200 |
+
|
| 201 |
+
int algorithm=0,
|
| 202 |
+
bool sortboxes=true,
|
| 203 |
+
int maxhitboxes=512,
|
| 204 |
+
bool synchitboxes=false,
|
| 205 |
+
bool chlast=false,
|
| 206 |
+
float fadescale=8.f,
|
| 207 |
+
float fadeexp=8.f,
|
| 208 |
+
int accum=0,
|
| 209 |
+
float termthresh=0.f,
|
| 210 |
+
int griddim=3,
|
| 211 |
+
int blocksizex=8,
|
| 212 |
+
int blocksizey=16) {
|
| 213 |
+
CHECK_INPUT(rayposim);
|
| 214 |
+
CHECK_INPUT(raydirim);
|
| 215 |
+
CHECK_INPUT(tminmaxim);
|
| 216 |
+
if (sortedobjid) { CHECK_INPUT(*sortedobjid); }
|
| 217 |
+
if (nodechildren) { CHECK_INPUT(*nodechildren); }
|
| 218 |
+
if (nodeaabb) { CHECK_INPUT(*nodeaabb); }
|
| 219 |
+
CHECK_INPUT(tplate);
|
| 220 |
+
if (warp) { CHECK_INPUT(*warp); }
|
| 221 |
+
CHECK_INPUT(primpos);
|
| 222 |
+
if (primrot) { CHECK_INPUT(*primrot); }
|
| 223 |
+
if (primscale) { CHECK_INPUT(*primscale); }
|
| 224 |
+
CHECK_INPUT(rayrgbaim);
|
| 225 |
+
if (raysatim) { CHECK_INPUT(*raysatim); }
|
| 226 |
+
if (raytermim) { CHECK_INPUT(*raytermim); }
|
| 227 |
+
|
| 228 |
+
int N = rayposim.size(0);
|
| 229 |
+
int H = rayposim.size(1);
|
| 230 |
+
int W = rayposim.size(2);
|
| 231 |
+
int K = primpos.size(1);
|
| 232 |
+
|
| 233 |
+
int TD, TH, TW;
|
| 234 |
+
if (chlast) {
|
| 235 |
+
TD = tplate.size(2); TH = tplate.size(3); TW = tplate.size(4);
|
| 236 |
+
} else {
|
| 237 |
+
TD = tplate.size(3); TH = tplate.size(4); TW = tplate.size(5);
|
| 238 |
+
}
|
| 239 |
+
|
| 240 |
+
int WD = 0, WH = 0, WW = 0;
|
| 241 |
+
if (warp) {
|
| 242 |
+
if (chlast) {
|
| 243 |
+
WD = warp->size(2); WH = warp->size(3); WW = warp->size(4);
|
| 244 |
+
} else {
|
| 245 |
+
WD = warp->size(3); WH = warp->size(4); WW = warp->size(5);
|
| 246 |
+
}
|
| 247 |
+
}
|
| 248 |
+
|
| 249 |
+
raymarch_forward_cuda(N, H, W, K,
|
| 250 |
+
reinterpret_cast<float *>(rayposim.data_ptr()),
|
| 251 |
+
reinterpret_cast<float *>(raydirim.data_ptr()),
|
| 252 |
+
stepsize,
|
| 253 |
+
reinterpret_cast<float *>(tminmaxim.data_ptr()),
|
| 254 |
+
sortedobjid ? reinterpret_cast<int *>(sortedobjid->data_ptr()) : nullptr,
|
| 255 |
+
nodechildren ? reinterpret_cast<int *>(nodechildren->data_ptr()) : nullptr,
|
| 256 |
+
nodeaabb ? reinterpret_cast<float *>(nodeaabb->data_ptr()) : nullptr,
|
| 257 |
+
|
| 258 |
+
// prim transforms
|
| 259 |
+
reinterpret_cast<float *>(primpos.data_ptr()),
|
| 260 |
+
primrot ? reinterpret_cast<float *>(primrot->data_ptr()) : nullptr,
|
| 261 |
+
primscale ? reinterpret_cast<float *>(primscale->data_ptr()) : nullptr,
|
| 262 |
+
|
| 263 |
+
// prim sampler
|
| 264 |
+
TD, TH, TW,
|
| 265 |
+
reinterpret_cast<float *>(tplate.data_ptr()),
|
| 266 |
+
WD, WH, WW,
|
| 267 |
+
warp ? reinterpret_cast<float *>(warp->data_ptr()) : nullptr,
|
| 268 |
+
|
| 269 |
+
// prim accumulator
|
| 270 |
+
reinterpret_cast<float *>(rayrgbaim.data_ptr()),
|
| 271 |
+
raysatim ? reinterpret_cast<float *>(raysatim->data_ptr()) : nullptr,
|
| 272 |
+
raytermim ? reinterpret_cast<int *>(raytermim->data_ptr()) : nullptr,
|
| 273 |
+
|
| 274 |
+
// options
|
| 275 |
+
algorithm, sortboxes, maxhitboxes, synchitboxes, chlast, fadescale, fadeexp, accum, termthresh,
|
| 276 |
+
griddim, blocksizex, blocksizey,
|
| 277 |
+
0);
|
| 278 |
+
|
| 279 |
+
return {};
|
| 280 |
+
}
|
| 281 |
+
|
| 282 |
+
std::vector<torch::Tensor> raymarch_backward(
|
| 283 |
+
torch::Tensor rayposim,
|
| 284 |
+
torch::Tensor raydirim,
|
| 285 |
+
float stepsize,
|
| 286 |
+
torch::Tensor tminmaxim,
|
| 287 |
+
|
| 288 |
+
torch::optional<torch::Tensor> sortedobjid,
|
| 289 |
+
torch::optional<torch::Tensor> nodechildren,
|
| 290 |
+
torch::optional<torch::Tensor> nodeaabb,
|
| 291 |
+
|
| 292 |
+
torch::Tensor primpos,
|
| 293 |
+
torch::Tensor grad_primpos,
|
| 294 |
+
torch::optional<torch::Tensor> primrot,
|
| 295 |
+
torch::optional<torch::Tensor> grad_primrot,
|
| 296 |
+
torch::optional<torch::Tensor> primscale,
|
| 297 |
+
torch::optional<torch::Tensor> grad_primscale,
|
| 298 |
+
|
| 299 |
+
torch::Tensor tplate,
|
| 300 |
+
torch::Tensor grad_tplate,
|
| 301 |
+
torch::optional<torch::Tensor> warp,
|
| 302 |
+
torch::optional<torch::Tensor> grad_warp,
|
| 303 |
+
|
| 304 |
+
torch::Tensor rayrgbaim,
|
| 305 |
+
torch::Tensor grad_rayrgba,
|
| 306 |
+
torch::optional<torch::Tensor> raysatim,
|
| 307 |
+
torch::optional<torch::Tensor> raytermim,
|
| 308 |
+
|
| 309 |
+
int algorithm=0,
|
| 310 |
+
bool sortboxes=true,
|
| 311 |
+
int maxhitboxes=512,
|
| 312 |
+
bool synchitboxes=false,
|
| 313 |
+
bool chlast=false,
|
| 314 |
+
float fadescale=8.f,
|
| 315 |
+
float fadeexp=8.f,
|
| 316 |
+
int accum=0,
|
| 317 |
+
float termthresh=0.f,
|
| 318 |
+
int griddim=3,
|
| 319 |
+
int blocksizex=8,
|
| 320 |
+
int blocksizey=16) {
|
| 321 |
+
CHECK_INPUT(rayposim);
|
| 322 |
+
CHECK_INPUT(raydirim);
|
| 323 |
+
CHECK_INPUT(tminmaxim);
|
| 324 |
+
if (sortedobjid) { CHECK_INPUT(*sortedobjid); }
|
| 325 |
+
if (nodechildren) { CHECK_INPUT(*nodechildren); }
|
| 326 |
+
if (nodeaabb) { CHECK_INPUT(*nodeaabb); }
|
| 327 |
+
CHECK_INPUT(tplate);
|
| 328 |
+
if (warp) { CHECK_INPUT(*warp); }
|
| 329 |
+
CHECK_INPUT(primpos);
|
| 330 |
+
if (primrot) { CHECK_INPUT(*primrot); }
|
| 331 |
+
if (primscale) { CHECK_INPUT(*primscale); }
|
| 332 |
+
CHECK_INPUT(rayrgbaim);
|
| 333 |
+
if (raysatim) { CHECK_INPUT(*raysatim); }
|
| 334 |
+
if (raytermim) { CHECK_INPUT(*raytermim); }
|
| 335 |
+
CHECK_INPUT(grad_rayrgba);
|
| 336 |
+
CHECK_INPUT(grad_tplate);
|
| 337 |
+
if (grad_warp) { CHECK_INPUT(*grad_warp); }
|
| 338 |
+
CHECK_INPUT(grad_primpos);
|
| 339 |
+
if (grad_primrot) { CHECK_INPUT(*grad_primrot); }
|
| 340 |
+
if (grad_primscale) { CHECK_INPUT(*grad_primscale); }
|
| 341 |
+
|
| 342 |
+
int N = rayposim.size(0);
|
| 343 |
+
int H = rayposim.size(1);
|
| 344 |
+
int W = rayposim.size(2);
|
| 345 |
+
int K = primpos.size(1);
|
| 346 |
+
|
| 347 |
+
int TD, TH, TW;
|
| 348 |
+
if (chlast) {
|
| 349 |
+
TD = tplate.size(2); TH = tplate.size(3); TW = tplate.size(4);
|
| 350 |
+
} else {
|
| 351 |
+
TD = tplate.size(3); TH = tplate.size(4); TW = tplate.size(5);
|
| 352 |
+
}
|
| 353 |
+
|
| 354 |
+
int WD = 0, WH = 0, WW = 0;
|
| 355 |
+
if (warp) {
|
| 356 |
+
if (chlast) {
|
| 357 |
+
WD = warp->size(2); WH = warp->size(3); WW = warp->size(4);
|
| 358 |
+
} else {
|
| 359 |
+
WD = warp->size(3); WH = warp->size(4); WW = warp->size(5);
|
| 360 |
+
}
|
| 361 |
+
}
|
| 362 |
+
|
| 363 |
+
raymarch_backward_cuda(N, H, W, K,
|
| 364 |
+
reinterpret_cast<float *>(rayposim.data_ptr()),
|
| 365 |
+
reinterpret_cast<float *>(raydirim.data_ptr()),
|
| 366 |
+
stepsize,
|
| 367 |
+
reinterpret_cast<float *>(tminmaxim.data_ptr()),
|
| 368 |
+
sortedobjid ? reinterpret_cast<int *>(sortedobjid->data_ptr()) : nullptr,
|
| 369 |
+
nodechildren ? reinterpret_cast<int *>(nodechildren->data_ptr()) : nullptr,
|
| 370 |
+
nodeaabb ? reinterpret_cast<float *>(nodeaabb->data_ptr()) : nullptr,
|
| 371 |
+
|
| 372 |
+
reinterpret_cast<float *>(primpos.data_ptr()),
|
| 373 |
+
reinterpret_cast<float *>(grad_primpos.data_ptr()),
|
| 374 |
+
primrot ? reinterpret_cast<float *>(primrot->data_ptr()) : nullptr,
|
| 375 |
+
grad_primrot ? reinterpret_cast<float *>(grad_primrot->data_ptr()) : nullptr,
|
| 376 |
+
primscale ? reinterpret_cast<float *>(primscale->data_ptr()) : nullptr,
|
| 377 |
+
grad_primscale ? reinterpret_cast<float *>(grad_primscale->data_ptr()) : nullptr,
|
| 378 |
+
|
| 379 |
+
TD, TH, TW,
|
| 380 |
+
reinterpret_cast<float *>(tplate.data_ptr()),
|
| 381 |
+
reinterpret_cast<float *>(grad_tplate.data_ptr()),
|
| 382 |
+
WD, WH, WW,
|
| 383 |
+
warp ? reinterpret_cast<float *>(warp->data_ptr()) : nullptr,
|
| 384 |
+
grad_warp ? reinterpret_cast<float *>(grad_warp->data_ptr()) : nullptr,
|
| 385 |
+
|
| 386 |
+
reinterpret_cast<float *>(rayrgbaim.data_ptr()),
|
| 387 |
+
reinterpret_cast<float *>(grad_rayrgba.data_ptr()),
|
| 388 |
+
raysatim ? reinterpret_cast<float *>(raysatim->data_ptr()) : nullptr,
|
| 389 |
+
raytermim ? reinterpret_cast<int *>(raytermim->data_ptr()) : nullptr,
|
| 390 |
+
|
| 391 |
+
algorithm, sortboxes, maxhitboxes, synchitboxes, chlast, fadescale, fadeexp, accum, termthresh,
|
| 392 |
+
griddim, blocksizex, blocksizey,
|
| 393 |
+
0);
|
| 394 |
+
|
| 395 |
+
return {};
|
| 396 |
+
}
|
| 397 |
+
|
| 398 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
| 399 |
+
m.def("compute_morton", &compute_morton, "compute morton codes (CUDA)");
|
| 400 |
+
m.def("build_tree", &build_tree, "build BVH tree (CUDA)");
|
| 401 |
+
m.def("compute_aabb", &compute_aabb, "compute AABB sizes (CUDA)");
|
| 402 |
+
|
| 403 |
+
m.def("raymarch_forward", &raymarch_forward, "raymarch forward (CUDA)");
|
| 404 |
+
m.def("raymarch_backward", &raymarch_backward, "raymarch backward (CUDA)");
|
| 405 |
+
}
|
dva/mvp/extensions/mvpraymarch/mvpraymarch.py
ADDED
|
@@ -0,0 +1,559 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import time
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
from torch.autograd import Function
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
|
| 15 |
+
try:
|
| 16 |
+
from . import mvpraymarchlib
|
| 17 |
+
except:
|
| 18 |
+
import mvpraymarchlib
|
| 19 |
+
|
| 20 |
+
def build_accel(primtransfin, algo, fixedorder=False):
|
| 21 |
+
"""build bvh structure given primitive centers and sizes
|
| 22 |
+
|
| 23 |
+
Parameters:
|
| 24 |
+
----------
|
| 25 |
+
primtransfin : tuple[tensor, tensor, tensor]
|
| 26 |
+
primitive transform tensors
|
| 27 |
+
algo : int
|
| 28 |
+
raymarching algorithm
|
| 29 |
+
fixedorder : optional[str]
|
| 30 |
+
True means the bvh builder will not reorder primitives and will
|
| 31 |
+
use a trivial tree structure. Likely to be slow for arbitrary
|
| 32 |
+
configurations of primitives.
|
| 33 |
+
|
| 34 |
+
"""
|
| 35 |
+
primpos, primrot, primscale = primtransfin
|
| 36 |
+
|
| 37 |
+
N = primpos.size(0)
|
| 38 |
+
K = primpos.size(1)
|
| 39 |
+
|
| 40 |
+
dev = primpos.device
|
| 41 |
+
|
| 42 |
+
# compute and sort morton codes
|
| 43 |
+
if fixedorder:
|
| 44 |
+
sortedobjid = (torch.arange(N*K, dtype=torch.int32, device=dev) % K).view(N, K)
|
| 45 |
+
else:
|
| 46 |
+
cmax = primpos.max(dim=1, keepdim=True)[0]
|
| 47 |
+
cmin = primpos.min(dim=1, keepdim=True)[0]
|
| 48 |
+
|
| 49 |
+
centers_norm = (primpos - cmin) / (cmax - cmin).clamp(min=1e-8)
|
| 50 |
+
|
| 51 |
+
mortoncode = torch.empty((N, K), dtype=torch.int32, device=dev)
|
| 52 |
+
mvpraymarchlib.compute_morton(centers_norm, mortoncode, algo)
|
| 53 |
+
sortedcode, sortedobjid_long = torch.sort(mortoncode, dim=-1)
|
| 54 |
+
sortedobjid = sortedobjid_long.int()
|
| 55 |
+
|
| 56 |
+
if fixedorder:
|
| 57 |
+
nodechildren = torch.cat([
|
| 58 |
+
torch.arange(1, (K - 1) * 2 + 1, dtype=torch.int32, device=dev),
|
| 59 |
+
torch.div(torch.arange(-2, -(K * 2 + 1) - 1, -1, dtype=torch.int32, device=dev), 2, rounding_mode="floor")],
|
| 60 |
+
dim=0).view(1, K + K - 1, 2).repeat(N, 1, 1)
|
| 61 |
+
nodeparent = (
|
| 62 |
+
torch.div(torch.arange(-1, K * 2 - 2, dtype=torch.int32, device=dev), 2, rounding_mode="floor")
|
| 63 |
+
.view(1, -1).repeat(N, 1))
|
| 64 |
+
else:
|
| 65 |
+
nodechildren = torch.empty((N, K + K - 1, 2), dtype=torch.int32, device=dev)
|
| 66 |
+
nodeparent = torch.full((N, K + K - 1), -1, dtype=torch.int32, device=dev)
|
| 67 |
+
mvpraymarchlib.build_tree(sortedcode, nodechildren, nodeparent)
|
| 68 |
+
|
| 69 |
+
nodeaabb = torch.empty((N, K + K - 1, 2, 3), dtype=torch.float32, device=dev)
|
| 70 |
+
mvpraymarchlib.compute_aabb(*primtransfin, sortedobjid, nodechildren, nodeparent, nodeaabb, algo)
|
| 71 |
+
|
| 72 |
+
return sortedobjid, nodechildren, nodeaabb
|
| 73 |
+
|
| 74 |
+
class MVPRaymarch(Function):
|
| 75 |
+
"""Custom Function for raymarching Mixture of Volumetric Primitives."""
|
| 76 |
+
@staticmethod
|
| 77 |
+
def forward(self, raypos, raydir, stepsize, tminmax,
|
| 78 |
+
primpos, primrot, primscale,
|
| 79 |
+
template, warp,
|
| 80 |
+
rayterm, gradmode, options):
|
| 81 |
+
algo = options["algo"]
|
| 82 |
+
usebvh = options["usebvh"]
|
| 83 |
+
sortprims = options["sortprims"]
|
| 84 |
+
randomorder = options["randomorder"]
|
| 85 |
+
maxhitboxes = options["maxhitboxes"]
|
| 86 |
+
synchitboxes = options["synchitboxes"]
|
| 87 |
+
chlast = options["chlast"]
|
| 88 |
+
fadescale = options["fadescale"]
|
| 89 |
+
fadeexp = options["fadeexp"]
|
| 90 |
+
accum = options["accum"]
|
| 91 |
+
termthresh = options["termthresh"]
|
| 92 |
+
griddim = options["griddim"]
|
| 93 |
+
if isinstance(options["blocksize"], tuple):
|
| 94 |
+
blocksizex, blocksizey = options["blocksize"]
|
| 95 |
+
else:
|
| 96 |
+
blocksizex = options["blocksize"]
|
| 97 |
+
blocksizey = 1
|
| 98 |
+
|
| 99 |
+
assert raypos.is_contiguous() and raypos.size(3) == 3
|
| 100 |
+
assert raydir.is_contiguous() and raydir.size(3) == 3
|
| 101 |
+
assert tminmax.is_contiguous() and tminmax.size(3) == 2
|
| 102 |
+
|
| 103 |
+
assert primpos is None or primpos.is_contiguous() and primpos.size(2) == 3
|
| 104 |
+
assert primrot is None or primrot.is_contiguous() and primrot.size(2) == 3
|
| 105 |
+
assert primscale is None or primscale.is_contiguous() and primscale.size(2) == 3
|
| 106 |
+
|
| 107 |
+
if chlast:
|
| 108 |
+
assert template.is_contiguous() and len(template.size()) == 6 and template.size(-1) == 4
|
| 109 |
+
assert warp is None or (warp.is_contiguous() and warp.size(-1) == 3)
|
| 110 |
+
else:
|
| 111 |
+
assert template.is_contiguous() and len(template.size()) == 6 and template.size(2) == 4
|
| 112 |
+
assert warp is None or (warp.is_contiguous() and warp.size(2) == 3)
|
| 113 |
+
|
| 114 |
+
primtransfin = (primpos, primrot, primscale)
|
| 115 |
+
|
| 116 |
+
# Build bvh
|
| 117 |
+
if usebvh is not False:
|
| 118 |
+
# compute radius of primitives
|
| 119 |
+
sortedobjid, nodechildren, nodeaabb = build_accel(primtransfin,
|
| 120 |
+
algo, fixedorder=usebvh=="fixedorder")
|
| 121 |
+
assert sortedobjid.is_contiguous()
|
| 122 |
+
assert nodechildren.is_contiguous()
|
| 123 |
+
assert nodeaabb.is_contiguous()
|
| 124 |
+
|
| 125 |
+
if randomorder:
|
| 126 |
+
sortedobjid = sortedobjid[torch.randperm(len(sortedobjid))]
|
| 127 |
+
else:
|
| 128 |
+
_, sortedobjid, nodechildren, nodeaabb = None, None, None, None
|
| 129 |
+
|
| 130 |
+
# march through boxes
|
| 131 |
+
N, H, W = raypos.size(0), raypos.size(1), raypos.size(2)
|
| 132 |
+
rayrgba = torch.empty((N, H, W, 4), device=raypos.device)
|
| 133 |
+
if gradmode:
|
| 134 |
+
raysat = torch.full((N, H, W, 3), -1, dtype=torch.float32, device=raypos.device)
|
| 135 |
+
rayterm = None
|
| 136 |
+
else:
|
| 137 |
+
raysat = None
|
| 138 |
+
rayterm = None
|
| 139 |
+
|
| 140 |
+
mvpraymarchlib.raymarch_forward(
|
| 141 |
+
raypos, raydir, stepsize, tminmax,
|
| 142 |
+
sortedobjid, nodechildren, nodeaabb,
|
| 143 |
+
*primtransfin,
|
| 144 |
+
template, warp,
|
| 145 |
+
rayrgba, raysat, rayterm,
|
| 146 |
+
algo, sortprims, maxhitboxes, synchitboxes, chlast,
|
| 147 |
+
fadescale, fadeexp,
|
| 148 |
+
accum, termthresh,
|
| 149 |
+
griddim, blocksizex, blocksizey)
|
| 150 |
+
|
| 151 |
+
self.save_for_backward(
|
| 152 |
+
raypos, raydir, tminmax,
|
| 153 |
+
sortedobjid, nodechildren, nodeaabb,
|
| 154 |
+
primpos, primrot, primscale,
|
| 155 |
+
template, warp,
|
| 156 |
+
rayrgba, raysat, rayterm)
|
| 157 |
+
self.options = options
|
| 158 |
+
self.stepsize = stepsize
|
| 159 |
+
|
| 160 |
+
return rayrgba
|
| 161 |
+
|
| 162 |
+
@staticmethod
|
| 163 |
+
def backward(self, grad_rayrgba):
|
| 164 |
+
(raypos, raydir, tminmax,
|
| 165 |
+
sortedobjid, nodechildren, nodeaabb,
|
| 166 |
+
primpos, primrot, primscale,
|
| 167 |
+
template, warp,
|
| 168 |
+
rayrgba, raysat, rayterm) = self.saved_tensors
|
| 169 |
+
algo = self.options["algo"]
|
| 170 |
+
usebvh = self.options["usebvh"]
|
| 171 |
+
sortprims = self.options["sortprims"]
|
| 172 |
+
maxhitboxes = self.options["maxhitboxes"]
|
| 173 |
+
synchitboxes = self.options["synchitboxes"]
|
| 174 |
+
chlast = self.options["chlast"]
|
| 175 |
+
fadescale = self.options["fadescale"]
|
| 176 |
+
fadeexp = self.options["fadeexp"]
|
| 177 |
+
accum = self.options["accum"]
|
| 178 |
+
termthresh = self.options["termthresh"]
|
| 179 |
+
griddim = self.options["griddim"]
|
| 180 |
+
if isinstance(self.options["bwdblocksize"], tuple):
|
| 181 |
+
blocksizex, blocksizey = self.options["bwdblocksize"]
|
| 182 |
+
else:
|
| 183 |
+
blocksizex = self.options["bwdblocksize"]
|
| 184 |
+
blocksizey = 1
|
| 185 |
+
|
| 186 |
+
stepsize = self.stepsize
|
| 187 |
+
|
| 188 |
+
grad_primpos = torch.zeros_like(primpos)
|
| 189 |
+
grad_primrot = torch.zeros_like(primrot)
|
| 190 |
+
grad_primscale = torch.zeros_like(primscale)
|
| 191 |
+
primtransfin = (primpos, grad_primpos, primrot, grad_primrot, primscale, grad_primscale)
|
| 192 |
+
|
| 193 |
+
grad_template = torch.zeros_like(template)
|
| 194 |
+
grad_warp = torch.zeros_like(warp) if warp is not None else None
|
| 195 |
+
|
| 196 |
+
mvpraymarchlib.raymarch_backward(raypos, raydir, stepsize, tminmax,
|
| 197 |
+
sortedobjid, nodechildren, nodeaabb,
|
| 198 |
+
|
| 199 |
+
*primtransfin,
|
| 200 |
+
|
| 201 |
+
template, grad_template, warp, grad_warp,
|
| 202 |
+
|
| 203 |
+
rayrgba, grad_rayrgba.contiguous(), raysat, rayterm,
|
| 204 |
+
|
| 205 |
+
algo, sortprims, maxhitboxes, synchitboxes, chlast,
|
| 206 |
+
fadescale, fadeexp,
|
| 207 |
+
accum, termthresh,
|
| 208 |
+
griddim, blocksizex, blocksizey)
|
| 209 |
+
|
| 210 |
+
return (None, None, None, None,
|
| 211 |
+
grad_primpos, grad_primrot, grad_primscale,
|
| 212 |
+
grad_template, grad_warp,
|
| 213 |
+
None, None, None)
|
| 214 |
+
|
| 215 |
+
def mvpraymarch(raypos, raydir, stepsize, tminmax,
|
| 216 |
+
primtransf,
|
| 217 |
+
template, warp,
|
| 218 |
+
rayterm=None,
|
| 219 |
+
algo=0, usebvh="fixedorder",
|
| 220 |
+
sortprims=False, randomorder=False,
|
| 221 |
+
maxhitboxes=512, synchitboxes=True,
|
| 222 |
+
chlast=True, fadescale=8., fadeexp=8.,
|
| 223 |
+
accum=0, termthresh=0.,
|
| 224 |
+
griddim=3, blocksize=(8, 16), bwdblocksize=(8, 16)):
|
| 225 |
+
"""Main entry point for raymarching MVP.
|
| 226 |
+
|
| 227 |
+
Parameters:
|
| 228 |
+
----------
|
| 229 |
+
raypos: N x H x W x 3 tensor of ray origins
|
| 230 |
+
raydir: N x H x W x 3 tensor of ray directions
|
| 231 |
+
stepsize: raymarching step size
|
| 232 |
+
tminmax: N x H x W x 2 tensor of raymarching min/max bounds
|
| 233 |
+
template: N x K x 4 x TD x TH x TW tensor of K RGBA primitives
|
| 234 |
+
warp: N x K x 3 x TD x TH x TW tensor of K warp fields (optional)
|
| 235 |
+
primpos: N x K x 3 tensor of primitive centers
|
| 236 |
+
primrot: N x K x 3 x 3 tensor of primitive orientations
|
| 237 |
+
primscale: N x K x 3 tensor of primitive inverse dimension lengths
|
| 238 |
+
algo: algorithm for raymarching (valid values: 0, 1). algo=0 is the fastest.
|
| 239 |
+
Currently algo=0 has a limit of 512 primitives per ray, so problems can
|
| 240 |
+
occur if there are many more boxes. all sortprims=True options have
|
| 241 |
+
this limitation, but you can use (algo=1, sortprims=False,
|
| 242 |
+
usebvh="fixedorder") which works correctly and has no primitive number
|
| 243 |
+
limitation (but is slightly slower).
|
| 244 |
+
usebvh: True to use bvh, "fixedorder" for a simple BVH, False for no bvh
|
| 245 |
+
sortprims: True to sort overlapping primitives at a sample point. Must
|
| 246 |
+
be True for gradients to match the PyTorch gradients. Seems unstable
|
| 247 |
+
if False but also not a big performance bottleneck.
|
| 248 |
+
chlast: whether template is provided as channels last or not. True tends
|
| 249 |
+
to be faster.
|
| 250 |
+
fadescale: Opacity is faded at the borders of the primitives by the equation
|
| 251 |
+
exp(-fadescale * x ** fadeexp) where x is the normalized coordinates of
|
| 252 |
+
the primitive.
|
| 253 |
+
fadeexp: Opacity is faded at the borders of the primitives by the equation
|
| 254 |
+
exp(-fadescale * x ** fadeexp) where x is the normalized coordinates of
|
| 255 |
+
the primitive.
|
| 256 |
+
griddim: CUDA grid dimensionality.
|
| 257 |
+
blocksize: blocksize of CUDA kernels. Should be 2-element tuple if
|
| 258 |
+
griddim>1, or integer if griddim==1."""
|
| 259 |
+
if isinstance(primtransf, tuple):
|
| 260 |
+
primpos, primrot, primscale = primtransf
|
| 261 |
+
else:
|
| 262 |
+
primpos, primrot, primscale = (
|
| 263 |
+
primtransf[:, :, 0, :].contiguous(),
|
| 264 |
+
primtransf[:, :, 1:4, :].contiguous(),
|
| 265 |
+
primtransf[:, :, 4, :].contiguous())
|
| 266 |
+
primtransfin = (primpos, primrot, primscale)
|
| 267 |
+
|
| 268 |
+
out = MVPRaymarch.apply(raypos, raydir, stepsize, tminmax,
|
| 269 |
+
*primtransfin,
|
| 270 |
+
template, warp,
|
| 271 |
+
rayterm, torch.is_grad_enabled(),
|
| 272 |
+
{"algo": algo, "usebvh": usebvh, "sortprims": sortprims, "randomorder": randomorder,
|
| 273 |
+
"maxhitboxes": maxhitboxes, "synchitboxes": synchitboxes,
|
| 274 |
+
"chlast": chlast, "fadescale": fadescale, "fadeexp": fadeexp,
|
| 275 |
+
"accum": accum, "termthresh": termthresh,
|
| 276 |
+
"griddim": griddim, "blocksize": blocksize, "bwdblocksize": bwdblocksize})
|
| 277 |
+
return out
|
| 278 |
+
|
| 279 |
+
class Rodrigues(nn.Module):
|
| 280 |
+
def __init__(self):
|
| 281 |
+
super(Rodrigues, self).__init__()
|
| 282 |
+
|
| 283 |
+
def forward(self, rvec):
|
| 284 |
+
theta = torch.sqrt(1e-5 + torch.sum(rvec ** 2, dim=1))
|
| 285 |
+
rvec = rvec / theta[:, None]
|
| 286 |
+
costh = torch.cos(theta)
|
| 287 |
+
sinth = torch.sin(theta)
|
| 288 |
+
return torch.stack((
|
| 289 |
+
rvec[:, 0] ** 2 + (1. - rvec[:, 0] ** 2) * costh,
|
| 290 |
+
rvec[:, 0] * rvec[:, 1] * (1. - costh) - rvec[:, 2] * sinth,
|
| 291 |
+
rvec[:, 0] * rvec[:, 2] * (1. - costh) + rvec[:, 1] * sinth,
|
| 292 |
+
|
| 293 |
+
rvec[:, 0] * rvec[:, 1] * (1. - costh) + rvec[:, 2] * sinth,
|
| 294 |
+
rvec[:, 1] ** 2 + (1. - rvec[:, 1] ** 2) * costh,
|
| 295 |
+
rvec[:, 1] * rvec[:, 2] * (1. - costh) - rvec[:, 0] * sinth,
|
| 296 |
+
|
| 297 |
+
rvec[:, 0] * rvec[:, 2] * (1. - costh) - rvec[:, 1] * sinth,
|
| 298 |
+
rvec[:, 1] * rvec[:, 2] * (1. - costh) + rvec[:, 0] * sinth,
|
| 299 |
+
rvec[:, 2] ** 2 + (1. - rvec[:, 2] ** 2) * costh), dim=1).view(-1, 3, 3)
|
| 300 |
+
|
| 301 |
+
def gradcheck(usebvh=True, sortprims=True, maxhitboxes=512, synchitboxes=False,
|
| 302 |
+
dowarp=False, chlast=False, fadescale=8., fadeexp=8.,
|
| 303 |
+
accum=0, termthresh=0., algo=0, griddim=2, blocksize=(8, 16), bwdblocksize=(8, 16)):
|
| 304 |
+
N = 2
|
| 305 |
+
H = 65
|
| 306 |
+
W = 65
|
| 307 |
+
k3 = 4
|
| 308 |
+
K = k3*k3*k3
|
| 309 |
+
|
| 310 |
+
M = 32
|
| 311 |
+
|
| 312 |
+
print("=================================================================")
|
| 313 |
+
print("usebvh={}, sortprims={}, maxhb={}, synchb={}, dowarp={}, chlast={}, "
|
| 314 |
+
"fadescale={}, fadeexp={}, accum={}, termthresh={}, algo={}, griddim={}, "
|
| 315 |
+
"blocksize={}, bwdblocksize={}".format(
|
| 316 |
+
usebvh, sortprims, maxhitboxes, synchitboxes, dowarp, chlast,
|
| 317 |
+
fadescale, fadeexp, accum, termthresh, algo, griddim, blocksize,
|
| 318 |
+
bwdblocksize))
|
| 319 |
+
|
| 320 |
+
# generate random inputs
|
| 321 |
+
torch.manual_seed(1112)
|
| 322 |
+
|
| 323 |
+
coherent_rays = True
|
| 324 |
+
if not coherent_rays:
|
| 325 |
+
_raypos = torch.randn(N, H, W, 3).to("cuda")
|
| 326 |
+
_raydir = torch.randn(N, H, W, 3).to("cuda")
|
| 327 |
+
_raydir /= torch.sqrt(torch.sum(_raydir ** 2, dim=-1, keepdim=True))
|
| 328 |
+
else:
|
| 329 |
+
focal = torch.tensor([[W*4.0, W*4.0] for n in range(N)])
|
| 330 |
+
princpt = torch.tensor([[W*0.5, H*0.5] for n in range(N)])
|
| 331 |
+
pixely, pixelx = torch.meshgrid(torch.arange(H).float(), torch.arange(W).float())
|
| 332 |
+
pixelcoords = torch.stack([pixelx, pixely], dim=-1)[None, :, :, :].repeat(N, 1, 1, 1)
|
| 333 |
+
|
| 334 |
+
raydir = (pixelcoords - princpt[:, None, None, :]) / focal[:, None, None, :]
|
| 335 |
+
raydir = torch.cat([raydir, torch.ones_like(raydir[:, :, :, 0:1])], dim=-1)
|
| 336 |
+
raydir = raydir / torch.sqrt(torch.sum(raydir ** 2, dim=-1, keepdim=True))
|
| 337 |
+
|
| 338 |
+
_raypos = torch.tensor([-0.0, 0.0, -4.])[None, None, None, :].repeat(N, H, W, 1).to("cuda")
|
| 339 |
+
_raydir = raydir.to("cuda")
|
| 340 |
+
_raydir /= torch.sqrt(torch.sum(_raydir ** 2, dim=-1, keepdim=True))
|
| 341 |
+
|
| 342 |
+
max_len = 6.0
|
| 343 |
+
_stepsize = max_len / 15.386928
|
| 344 |
+
_tminmax = max_len*torch.arange(2, dtype=torch.float32)[None, None, None, :].repeat(N, H, W, 1).to("cuda") + \
|
| 345 |
+
torch.rand(N, H, W, 2, device="cuda") * 1.
|
| 346 |
+
|
| 347 |
+
_template = torch.randn(N, K, 4, M, M, M, requires_grad=True)
|
| 348 |
+
_template.data[:, :, -1, :, :, :] -= 3.5
|
| 349 |
+
_template = _template.contiguous().detach().clone()
|
| 350 |
+
_template.requires_grad = True
|
| 351 |
+
gridxyz = torch.stack(torch.meshgrid(
|
| 352 |
+
torch.linspace(-1., 1., M//2),
|
| 353 |
+
torch.linspace(-1., 1., M//2),
|
| 354 |
+
torch.linspace(-1., 1., M//2))[::-1], dim=0).contiguous()
|
| 355 |
+
_warp = (torch.randn(N, K, 3, M//2, M//2, M//2) * 0.01 + gridxyz[None, None, :, :, :, :]).contiguous().detach().clone()
|
| 356 |
+
_warp.requires_grad = True
|
| 357 |
+
_primpos = torch.randn(N, K, 3, requires_grad=True)
|
| 358 |
+
_primpos = torch.randn(N, K, 3, requires_grad=True)
|
| 359 |
+
|
| 360 |
+
coherent_centers = True
|
| 361 |
+
if coherent_centers:
|
| 362 |
+
ns = k3
|
| 363 |
+
#assert ns*ns*ns==K
|
| 364 |
+
grid3d = torch.stack(torch.meshgrid(
|
| 365 |
+
torch.linspace(-1., 1., ns),
|
| 366 |
+
torch.linspace(-1., 1., ns),
|
| 367 |
+
torch.linspace(-1., 1., K//(ns*ns)))[::-1], dim=0)[None]
|
| 368 |
+
_primpos = ((
|
| 369 |
+
grid3d.permute((0, 2, 3, 4, 1)).reshape(1, K, 3).expand(N, -1, -1) +
|
| 370 |
+
0.1 * torch.randn(N, K, 3, requires_grad=True)
|
| 371 |
+
)).contiguous().detach().clone()
|
| 372 |
+
_primpos.requires_grad = True
|
| 373 |
+
scale_ws = 1.
|
| 374 |
+
_primrot = torch.randn(N, K, 3)
|
| 375 |
+
rodrigues = Rodrigues()
|
| 376 |
+
_primrot = rodrigues(_primrot.view(-1, 3)).view(N, K, 3, 3).contiguous().detach().clone()
|
| 377 |
+
_primrot.requires_grad = True
|
| 378 |
+
|
| 379 |
+
_primscale = torch.randn(N, K, 3, requires_grad=True)
|
| 380 |
+
_primscale.data *= 0.0
|
| 381 |
+
|
| 382 |
+
if dowarp:
|
| 383 |
+
params = [_template, _warp, _primscale, _primrot, _primpos]
|
| 384 |
+
paramnames = ["template", "warp", "primscale", "primrot", "primpos"]
|
| 385 |
+
else:
|
| 386 |
+
params = [_template, _primscale, _primrot, _primpos]
|
| 387 |
+
paramnames = ["template", "primscale", "primrot", "primpos"]
|
| 388 |
+
|
| 389 |
+
termthreshorig = termthresh
|
| 390 |
+
|
| 391 |
+
########################### run pytorch version ###########################
|
| 392 |
+
|
| 393 |
+
raypos = _raypos
|
| 394 |
+
raydir = _raydir
|
| 395 |
+
stepsize = _stepsize
|
| 396 |
+
tminmax = _tminmax
|
| 397 |
+
|
| 398 |
+
#template = F.softplus(_template.to("cuda") * 1.5)
|
| 399 |
+
template = F.softplus(_template.to("cuda") * 1.5) if algo != 2 else _template.to("cuda") * 1.5
|
| 400 |
+
warp = _warp.to("cuda")
|
| 401 |
+
primpos = _primpos.to("cuda") * 0.3
|
| 402 |
+
primrot = _primrot.to("cuda")
|
| 403 |
+
primscale = scale_ws * torch.exp(0.1 * _primscale.to("cuda"))
|
| 404 |
+
|
| 405 |
+
# python raymarching implementation
|
| 406 |
+
rayrgba = torch.zeros((N, H, W, 4)).to("cuda")
|
| 407 |
+
raypos = raypos + raydir * tminmax[:, :, :, 0, None]
|
| 408 |
+
t = tminmax[:, :, :, 0]
|
| 409 |
+
|
| 410 |
+
step = 0
|
| 411 |
+
t0 = t.detach().clone()
|
| 412 |
+
raypos0 = raypos.detach().clone()
|
| 413 |
+
|
| 414 |
+
torch.cuda.synchronize()
|
| 415 |
+
time0 = time.time()
|
| 416 |
+
|
| 417 |
+
while (t < tminmax[:, :, :, 1]).any():
|
| 418 |
+
valid2 = torch.ones_like(rayrgba[:, :, :, 3:4])
|
| 419 |
+
|
| 420 |
+
for k in range(K):
|
| 421 |
+
y0 = torch.bmm(
|
| 422 |
+
(raypos - primpos[:, k, None, None, :]).view(raypos.size(0), -1, raypos.size(3)),
|
| 423 |
+
primrot[:, k, :, :]).view_as(raypos) * primscale[:, k, None, None, :]
|
| 424 |
+
|
| 425 |
+
fade = torch.exp(-fadescale * torch.sum(torch.abs(y0) ** fadeexp, dim=-1, keepdim=True))
|
| 426 |
+
|
| 427 |
+
if dowarp:
|
| 428 |
+
y1 = F.grid_sample(
|
| 429 |
+
warp[:, k, :, :, :, :],
|
| 430 |
+
y0[:, None, :, :, :], align_corners=True)[:, :, 0, :, :].permute(0, 2, 3, 1)
|
| 431 |
+
else:
|
| 432 |
+
y1 = y0
|
| 433 |
+
|
| 434 |
+
sample = F.grid_sample(
|
| 435 |
+
template[:, k, :, :, :, :],
|
| 436 |
+
y1[:, None, :, :, :], align_corners=True)[:, :, 0, :, :].permute(0, 2, 3, 1)
|
| 437 |
+
|
| 438 |
+
valid1 = (
|
| 439 |
+
torch.prod(y0[:, :, :, :] >= -1., dim=-1, keepdim=True) *
|
| 440 |
+
torch.prod(y0[:, :, :, :] <= 1., dim=-1, keepdim=True))
|
| 441 |
+
|
| 442 |
+
valid = ((t >= tminmax[:, :, :, 0]) & (t < tminmax[:, :, :, 1])).float()[:, :, :, None]
|
| 443 |
+
|
| 444 |
+
alpha0 = sample[:, :, :, 3:4]
|
| 445 |
+
|
| 446 |
+
rgb = sample[:, :, :, 0:3] * valid * valid1
|
| 447 |
+
alpha = alpha0 * fade * stepsize * valid * valid1
|
| 448 |
+
|
| 449 |
+
if accum == 0:
|
| 450 |
+
newalpha = rayrgba[:, :, :, 3:4] + alpha
|
| 451 |
+
contrib = (newalpha.clamp(max=1.0) - rayrgba[:, :, :, 3:4]) * valid * valid1
|
| 452 |
+
rayrgba = rayrgba + contrib * torch.cat([rgb, torch.ones_like(alpha)], dim=-1)
|
| 453 |
+
else:
|
| 454 |
+
raise
|
| 455 |
+
|
| 456 |
+
step += 1
|
| 457 |
+
t = t0 + stepsize * step
|
| 458 |
+
raypos = raypos0 + raydir * stepsize * step
|
| 459 |
+
|
| 460 |
+
print(rayrgba[..., -1].min().item(), rayrgba[..., -1].max().item())
|
| 461 |
+
|
| 462 |
+
sample0 = rayrgba
|
| 463 |
+
|
| 464 |
+
torch.cuda.synchronize()
|
| 465 |
+
time1 = time.time()
|
| 466 |
+
|
| 467 |
+
sample0.backward(torch.ones_like(sample0))
|
| 468 |
+
|
| 469 |
+
torch.cuda.synchronize()
|
| 470 |
+
time2 = time.time()
|
| 471 |
+
|
| 472 |
+
print("{:<10} {:>10} {:>10} {:>10}".format("", "fwd", "bwd", "total"))
|
| 473 |
+
print("{:<10} {:10.5} {:10.5} {:10.5}".format("pytime", time1 - time0, time2 - time1, time2 - time0))
|
| 474 |
+
|
| 475 |
+
grads0 = [p.grad.detach().clone() for p in params]
|
| 476 |
+
|
| 477 |
+
for p in params:
|
| 478 |
+
p.grad.detach_()
|
| 479 |
+
p.grad.zero_()
|
| 480 |
+
|
| 481 |
+
############################## run cuda version ###########################
|
| 482 |
+
|
| 483 |
+
raypos = _raypos
|
| 484 |
+
raydir = _raydir
|
| 485 |
+
stepsize = _stepsize
|
| 486 |
+
tminmax = _tminmax
|
| 487 |
+
|
| 488 |
+
template = F.softplus(_template.to("cuda") * 1.5) if algo != 2 else _template.to("cuda") * 1.5
|
| 489 |
+
warp = _warp.to("cuda")
|
| 490 |
+
if chlast:
|
| 491 |
+
template = template.permute(0, 1, 3, 4, 5, 2).contiguous()
|
| 492 |
+
warp = warp.permute(0, 1, 3, 4, 5, 2).contiguous()
|
| 493 |
+
primpos = _primpos.to("cuda") * 0.3
|
| 494 |
+
primrot = _primrot.to("cuda")
|
| 495 |
+
primscale = scale_ws * torch.exp(0.1 * _primscale.to("cuda"))
|
| 496 |
+
|
| 497 |
+
niter = 1
|
| 498 |
+
|
| 499 |
+
tf, tb = 0., 0.
|
| 500 |
+
for i in range(niter):
|
| 501 |
+
for p in params:
|
| 502 |
+
try:
|
| 503 |
+
p.grad.detach_()
|
| 504 |
+
p.grad.zero_()
|
| 505 |
+
except:
|
| 506 |
+
pass
|
| 507 |
+
t0 = time.time()
|
| 508 |
+
torch.cuda.synchronize()
|
| 509 |
+
sample1 = mvpraymarch(raypos, raydir, stepsize, tminmax,
|
| 510 |
+
(primpos, primrot, primscale),
|
| 511 |
+
template, warp if dowarp else None,
|
| 512 |
+
algo=algo, usebvh=usebvh, sortprims=sortprims,
|
| 513 |
+
maxhitboxes=maxhitboxes, synchitboxes=synchitboxes,
|
| 514 |
+
chlast=chlast, fadescale=fadescale, fadeexp=fadeexp,
|
| 515 |
+
accum=accum, termthresh=termthreshorig,
|
| 516 |
+
griddim=griddim, blocksize=blocksize, bwdblocksize=bwdblocksize)
|
| 517 |
+
t1 = time.time()
|
| 518 |
+
torch.cuda.synchronize()
|
| 519 |
+
sample1.backward(torch.ones_like(sample1), retain_graph=True)
|
| 520 |
+
torch.cuda.synchronize()
|
| 521 |
+
t2 = time.time()
|
| 522 |
+
tf += t1 - t0
|
| 523 |
+
tb += t2 - t1
|
| 524 |
+
|
| 525 |
+
print("{:<10} {:10.5} {:10.5} {:10.5}".format("time", tf / niter, tb / niter, (tf + tb) / niter))
|
| 526 |
+
grads1 = [p.grad.detach().clone() for p in params]
|
| 527 |
+
|
| 528 |
+
############# compare results #############
|
| 529 |
+
|
| 530 |
+
print("-----------------------------------------------------------------")
|
| 531 |
+
print("{:>10} {:>10} {:>10} {:>10} {:>10} {:>10} {:>10} {:>10}".format("", "maxabsdiff", "dp", "||py||", "||cuda||", "index", "py", "cuda"))
|
| 532 |
+
ind = torch.argmax(torch.abs(sample0 - sample1))
|
| 533 |
+
print("{:<10} {:>10.5} {:>10.5} {:>10.5} {:>10.5} {:>10} {:>10.5} {:>10.5}".format(
|
| 534 |
+
"fwd",
|
| 535 |
+
torch.max(torch.abs(sample0 - sample1)).item(),
|
| 536 |
+
(torch.sum(sample0 * sample1) / torch.sqrt(torch.sum(sample0 * sample0) * torch.sum(sample1 * sample1))).item(),
|
| 537 |
+
torch.sqrt(torch.sum(sample0 * sample0)).item(),
|
| 538 |
+
torch.sqrt(torch.sum(sample1 * sample1)).item(),
|
| 539 |
+
ind.item(),
|
| 540 |
+
sample0.view(-1)[ind].item(),
|
| 541 |
+
sample1.view(-1)[ind].item()))
|
| 542 |
+
|
| 543 |
+
for p, g0, g1 in zip(paramnames, grads0, grads1):
|
| 544 |
+
ind = torch.argmax(torch.abs(g0 - g1))
|
| 545 |
+
print("{:<10} {:>10.5} {:>10.5} {:>10.5} {:>10.5} {:>10} {:>10.5} {:>10.5}".format(
|
| 546 |
+
p,
|
| 547 |
+
torch.max(torch.abs(g0 - g1)).item(),
|
| 548 |
+
(torch.sum(g0 * g1) / torch.sqrt(torch.sum(g0 * g0) * torch.sum(g1 * g1))).item(),
|
| 549 |
+
torch.sqrt(torch.sum(g0 * g0)).item(),
|
| 550 |
+
torch.sqrt(torch.sum(g1 * g1)).item(),
|
| 551 |
+
ind.item(),
|
| 552 |
+
g0.view(-1)[ind].item(),
|
| 553 |
+
g1.view(-1)[ind].item()))
|
| 554 |
+
|
| 555 |
+
if __name__ == "__main__":
|
| 556 |
+
gradcheck(usebvh="fixedorder", sortprims=False, maxhitboxes=512, synchitboxes=True,
|
| 557 |
+
dowarp=False, chlast=True, fadescale=6.5, fadeexp=7.5, accum=0, algo=0, griddim=3)
|
| 558 |
+
gradcheck(usebvh="fixedorder", sortprims=False, maxhitboxes=512, synchitboxes=True,
|
| 559 |
+
dowarp=True, chlast=True, fadescale=6.5, fadeexp=7.5, accum=0, algo=1, griddim=3)
|
dva/mvp/extensions/mvpraymarch/mvpraymarch_kernel.cu
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
//
|
| 4 |
+
// This source code is licensed under the license found in the
|
| 5 |
+
// LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
#include <chrono>
|
| 8 |
+
#include <functional>
|
| 9 |
+
#include <iostream>
|
| 10 |
+
#include <map>
|
| 11 |
+
#include <memory>
|
| 12 |
+
#include <tuple>
|
| 13 |
+
#include <vector>
|
| 14 |
+
|
| 15 |
+
#include "helper_math.h"
|
| 16 |
+
|
| 17 |
+
#include "cudadispatch.h"
|
| 18 |
+
|
| 19 |
+
#include "utils.h"
|
| 20 |
+
|
| 21 |
+
#include "primtransf.h"
|
| 22 |
+
#include "primsampler.h"
|
| 23 |
+
#include "primaccum.h"
|
| 24 |
+
|
| 25 |
+
#include "mvpraymarch_subset_kernel.h"
|
| 26 |
+
|
| 27 |
+
typedef std::shared_ptr<PrimTransfDataBase> PrimTransfDataBase_ptr;
|
| 28 |
+
typedef std::shared_ptr<PrimSamplerDataBase> PrimSamplerDataBase_ptr;
|
| 29 |
+
typedef std::shared_ptr<PrimAccumDataBase> PrimAccumDataBase_ptr;
|
| 30 |
+
typedef std::function<void(dim3, dim3, cudaStream_t, int, int, int, int,
|
| 31 |
+
float3*, float3*, float, float2*, int*, int2*, float3*,
|
| 32 |
+
PrimTransfDataBase_ptr, PrimSamplerDataBase_ptr,
|
| 33 |
+
PrimAccumDataBase_ptr)> mapfn_t;
|
| 34 |
+
typedef RaySubsetFixedBVH<false, 512, true, PrimTransfSRT> raysubset_t;
|
| 35 |
+
|
| 36 |
+
void raymarch_forward_cuda(
|
| 37 |
+
int N, int H, int W, int K,
|
| 38 |
+
float * rayposim,
|
| 39 |
+
float * raydirim,
|
| 40 |
+
float stepsize,
|
| 41 |
+
float * tminmaxim,
|
| 42 |
+
|
| 43 |
+
int * sortedobjid,
|
| 44 |
+
int * nodechildren,
|
| 45 |
+
float * nodeaabb,
|
| 46 |
+
float * primpos,
|
| 47 |
+
float * primrot,
|
| 48 |
+
float * primscale,
|
| 49 |
+
|
| 50 |
+
int TD, int TH, int TW,
|
| 51 |
+
float * tplate,
|
| 52 |
+
int WD, int WH, int WW,
|
| 53 |
+
float * warp,
|
| 54 |
+
|
| 55 |
+
float * rayrgbaim,
|
| 56 |
+
float * raysatim,
|
| 57 |
+
int * raytermim,
|
| 58 |
+
|
| 59 |
+
int algorithm,
|
| 60 |
+
bool sortboxes,
|
| 61 |
+
int maxhitboxes,
|
| 62 |
+
bool synchitboxes,
|
| 63 |
+
bool chlast,
|
| 64 |
+
float fadescale,
|
| 65 |
+
float fadeexp,
|
| 66 |
+
int accum,
|
| 67 |
+
float termthresh,
|
| 68 |
+
int griddim, int blocksizex, int blocksizey,
|
| 69 |
+
cudaStream_t stream) {
|
| 70 |
+
dim3 blocksize(blocksizex, blocksizey);
|
| 71 |
+
dim3 gridsize;
|
| 72 |
+
gridsize = dim3(
|
| 73 |
+
(W + blocksize.x - 1) / blocksize.x,
|
| 74 |
+
(H + blocksize.y - 1) / blocksize.y,
|
| 75 |
+
N);
|
| 76 |
+
|
| 77 |
+
std::shared_ptr<PrimTransfDataBase> primtransf_data;
|
| 78 |
+
primtransf_data = std::make_shared<PrimTransfSRT::Data>(PrimTransfSRT::Data{
|
| 79 |
+
PrimTransfDataBase{},
|
| 80 |
+
K, (float3*)primpos, nullptr,
|
| 81 |
+
K * 3, (float3*)primrot, nullptr,
|
| 82 |
+
K, (float3*)primscale, nullptr});
|
| 83 |
+
std::shared_ptr<PrimSamplerDataBase> primsampler_data;
|
| 84 |
+
if (algorithm == 1) {
|
| 85 |
+
primsampler_data = std::make_shared<PrimSamplerTW<true>::Data>(PrimSamplerTW<true>::Data{
|
| 86 |
+
PrimSamplerDataBase{},
|
| 87 |
+
fadescale, fadeexp,
|
| 88 |
+
K * TD * TH * TW * 4, TD, TH, TW, tplate, nullptr,
|
| 89 |
+
K * WD * WH * WW * 3, WD, WH, WW, warp, nullptr});
|
| 90 |
+
} else {
|
| 91 |
+
primsampler_data = std::make_shared<PrimSamplerTW<false>::Data>(PrimSamplerTW<false>::Data{
|
| 92 |
+
PrimSamplerDataBase{},
|
| 93 |
+
fadescale, fadeexp,
|
| 94 |
+
K * TD * TH * TW * 4, TD, TH, TW, tplate, nullptr,
|
| 95 |
+
0, 0, 0, 0, nullptr, nullptr});
|
| 96 |
+
}
|
| 97 |
+
std::shared_ptr<PrimAccumDataBase> primaccum_data = std::make_shared<PrimAccumAdditive::Data>(PrimAccumAdditive::Data{
|
| 98 |
+
PrimAccumDataBase{},
|
| 99 |
+
termthresh, H * W, W, 1, (float4*)rayrgbaim, nullptr, (float3*)raysatim});
|
| 100 |
+
|
| 101 |
+
std::map<int, mapfn_t> dispatcher = {
|
| 102 |
+
{0, make_cudacall(raymarch_subset_forward_kernel<512, 4, raysubset_t, PrimTransfSRT, PrimSamplerTW<false>, PrimAccumAdditive>)},
|
| 103 |
+
{1, make_cudacall(raymarch_subset_forward_kernel<512, 4, raysubset_t, PrimTransfSRT, PrimSamplerTW<true>, PrimAccumAdditive>)}};
|
| 104 |
+
|
| 105 |
+
auto iter = dispatcher.find(algorithm);
|
| 106 |
+
if (iter != dispatcher.end()) {
|
| 107 |
+
(iter->second)(
|
| 108 |
+
gridsize, blocksize, stream,
|
| 109 |
+
N, H, W, K,
|
| 110 |
+
reinterpret_cast<float3 *>(rayposim),
|
| 111 |
+
reinterpret_cast<float3 *>(raydirim),
|
| 112 |
+
stepsize,
|
| 113 |
+
reinterpret_cast<float2 *>(tminmaxim),
|
| 114 |
+
reinterpret_cast<int *>(sortedobjid),
|
| 115 |
+
reinterpret_cast<int2 *>(nodechildren),
|
| 116 |
+
reinterpret_cast<float3 *>(nodeaabb),
|
| 117 |
+
primtransf_data,
|
| 118 |
+
primsampler_data,
|
| 119 |
+
primaccum_data);
|
| 120 |
+
}
|
| 121 |
+
}
|
| 122 |
+
|
| 123 |
+
void raymarch_backward_cuda(
|
| 124 |
+
int N, int H, int W, int K,
|
| 125 |
+
float * rayposim,
|
| 126 |
+
float * raydirim,
|
| 127 |
+
float stepsize,
|
| 128 |
+
float * tminmaxim,
|
| 129 |
+
int * sortedobjid,
|
| 130 |
+
int * nodechildren,
|
| 131 |
+
float * nodeaabb,
|
| 132 |
+
|
| 133 |
+
float * primpos,
|
| 134 |
+
float * grad_primpos,
|
| 135 |
+
float * primrot,
|
| 136 |
+
float * grad_primrot,
|
| 137 |
+
float * primscale,
|
| 138 |
+
float * grad_primscale,
|
| 139 |
+
|
| 140 |
+
int TD, int TH, int TW,
|
| 141 |
+
float * tplate,
|
| 142 |
+
float * grad_tplate,
|
| 143 |
+
int WD, int WH, int WW,
|
| 144 |
+
float * warp,
|
| 145 |
+
float * grad_warp,
|
| 146 |
+
|
| 147 |
+
float * rayrgbaim,
|
| 148 |
+
float * grad_rayrgba,
|
| 149 |
+
float * raysatim,
|
| 150 |
+
int * raytermim,
|
| 151 |
+
|
| 152 |
+
int algorithm, bool sortboxes, int maxhitboxes, bool synchitboxes,
|
| 153 |
+
bool chlast, float fadescale, float fadeexp, int accum, float termthresh,
|
| 154 |
+
int griddim, int blocksizex, int blocksizey,
|
| 155 |
+
|
| 156 |
+
cudaStream_t stream) {
|
| 157 |
+
dim3 blocksize(blocksizex, blocksizey);
|
| 158 |
+
dim3 gridsize;
|
| 159 |
+
gridsize = dim3(
|
| 160 |
+
(W + blocksize.x - 1) / blocksize.x,
|
| 161 |
+
(H + blocksize.y - 1) / blocksize.y,
|
| 162 |
+
N);
|
| 163 |
+
|
| 164 |
+
std::shared_ptr<PrimTransfDataBase> primtransf_data;
|
| 165 |
+
primtransf_data = std::make_shared<PrimTransfSRT::Data>(PrimTransfSRT::Data{
|
| 166 |
+
PrimTransfDataBase{},
|
| 167 |
+
K, (float3*)primpos, (float3*)grad_primpos,
|
| 168 |
+
K * 3, (float3*)primrot, (float3*)grad_primrot,
|
| 169 |
+
K, (float3*)primscale, (float3*)grad_primscale});
|
| 170 |
+
std::shared_ptr<PrimSamplerDataBase> primsampler_data;
|
| 171 |
+
if (algorithm == 1) {
|
| 172 |
+
primsampler_data = std::make_shared<PrimSamplerTW<true>::Data>(PrimSamplerTW<true>::Data{
|
| 173 |
+
PrimSamplerDataBase{},
|
| 174 |
+
fadescale, fadeexp,
|
| 175 |
+
K * TD * TH * TW * 4, TD, TH, TW, tplate, grad_tplate,
|
| 176 |
+
K * WD * WH * WW * 3, WD, WH, WW, warp, grad_warp});
|
| 177 |
+
} else {
|
| 178 |
+
primsampler_data = std::make_shared<PrimSamplerTW<false>::Data>(PrimSamplerTW<false>::Data{
|
| 179 |
+
PrimSamplerDataBase{},
|
| 180 |
+
fadescale, fadeexp,
|
| 181 |
+
K * TD * TH * TW * 4, TD, TH, TW, tplate, grad_tplate,
|
| 182 |
+
0, 0, 0, 0, nullptr, nullptr});
|
| 183 |
+
}
|
| 184 |
+
std::shared_ptr<PrimAccumDataBase> primaccum_data = std::make_shared<PrimAccumAdditive::Data>(PrimAccumAdditive::Data{
|
| 185 |
+
PrimAccumDataBase{},
|
| 186 |
+
termthresh, H * W, W, 1, (float4*)rayrgbaim, (float4*)grad_rayrgba, (float3*)raysatim});
|
| 187 |
+
|
| 188 |
+
std::map<int, mapfn_t> dispatcher = {
|
| 189 |
+
{0, make_cudacall(raymarch_subset_backward_kernel<true, 512, 4, raysubset_t, PrimTransfSRT, PrimSamplerTW<false>, PrimAccumAdditive>)},
|
| 190 |
+
{1, make_cudacall(raymarch_subset_backward_kernel<true, 512, 4, raysubset_t, PrimTransfSRT, PrimSamplerTW<true>, PrimAccumAdditive>)}};
|
| 191 |
+
|
| 192 |
+
auto iter = dispatcher.find(algorithm);
|
| 193 |
+
if (iter != dispatcher.end()) {
|
| 194 |
+
(iter->second)(
|
| 195 |
+
gridsize, blocksize, stream,
|
| 196 |
+
N, H, W, K,
|
| 197 |
+
reinterpret_cast<float3 *>(rayposim),
|
| 198 |
+
reinterpret_cast<float3 *>(raydirim),
|
| 199 |
+
stepsize,
|
| 200 |
+
reinterpret_cast<float2 *>(tminmaxim),
|
| 201 |
+
reinterpret_cast<int *>(sortedobjid),
|
| 202 |
+
reinterpret_cast<int2 *>(nodechildren),
|
| 203 |
+
reinterpret_cast<float3 *>(nodeaabb),
|
| 204 |
+
primtransf_data,
|
| 205 |
+
primsampler_data,
|
| 206 |
+
primaccum_data);
|
| 207 |
+
}
|
| 208 |
+
}
|
dva/mvp/extensions/mvpraymarch/mvpraymarch_subset_kernel.h
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
//
|
| 4 |
+
// This source code is licensed under the license found in the
|
| 5 |
+
// LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
template<
|
| 8 |
+
int maxhitboxes,
|
| 9 |
+
int nwarps,
|
| 10 |
+
class RaySubsetT=RaySubsetFixedBVH<false, 512, true, PrimTransfSRT>,
|
| 11 |
+
class PrimTransfT=PrimTransfSRT,
|
| 12 |
+
class PrimSamplerT=PrimSamplerTW<false>,
|
| 13 |
+
class PrimAccumT=PrimAccumAdditive>
|
| 14 |
+
__global__ void raymarch_subset_forward_kernel(
|
| 15 |
+
int N, int H, int W, int K,
|
| 16 |
+
float3 * rayposim,
|
| 17 |
+
float3 * raydirim,
|
| 18 |
+
float stepsize,
|
| 19 |
+
float2 * tminmaxim,
|
| 20 |
+
int * sortedobjid,
|
| 21 |
+
int2 * nodechildren,
|
| 22 |
+
float3 * nodeaabb,
|
| 23 |
+
typename PrimTransfT::Data primtransf_data,
|
| 24 |
+
typename PrimSamplerT::Data primsampler_data,
|
| 25 |
+
typename PrimAccumT::Data primaccum_data
|
| 26 |
+
) {
|
| 27 |
+
int w = blockIdx.x * blockDim.x + threadIdx.x;
|
| 28 |
+
int h = blockIdx.y * blockDim.y + threadIdx.y;
|
| 29 |
+
int n = blockIdx.z;
|
| 30 |
+
bool validthread = (w < W) && (h < H) && (n<N);
|
| 31 |
+
|
| 32 |
+
assert(nwarps == 0 || blockDim.x * blockDim.y / 32 <= nwarps);
|
| 33 |
+
const int warpid = __shfl_sync(0xffffffff, (threadIdx.y * blockDim.x + threadIdx.x) / 32, 0);
|
| 34 |
+
assert(__match_any_sync(0xffffffff, (threadIdx.y * blockDim.x + threadIdx.x) / 32) == 0xffffffff);
|
| 35 |
+
|
| 36 |
+
// warpmask contains the valid threads in the warp
|
| 37 |
+
unsigned warpmask = 0xffffffff;
|
| 38 |
+
n = min(N - 1, n);
|
| 39 |
+
h = min(H - 1, h);
|
| 40 |
+
w = min(W - 1, w);
|
| 41 |
+
|
| 42 |
+
sortedobjid += n * K;
|
| 43 |
+
nodechildren += n * (K + K - 1);
|
| 44 |
+
nodeaabb += n * (K + K - 1) * 2;
|
| 45 |
+
|
| 46 |
+
primtransf_data.n_stride(n);
|
| 47 |
+
primsampler_data.n_stride(n);
|
| 48 |
+
primaccum_data.n_stride(n, h, w);
|
| 49 |
+
|
| 50 |
+
float3 raypos = rayposim[n * H * W + h * W + w];
|
| 51 |
+
float3 raydir = raydirim[n * H * W + h * W + w];
|
| 52 |
+
float2 tminmax = tminmaxim[n * H * W + h * W + w];
|
| 53 |
+
|
| 54 |
+
int hitboxes[nwarps > 0 ? 1 : maxhitboxes];
|
| 55 |
+
__shared__ int hitboxes_sh[nwarps > 0 ? maxhitboxes * nwarps : 1];
|
| 56 |
+
int * hitboxes_ptr = nwarps > 0 ? hitboxes_sh + maxhitboxes * warpid : hitboxes;
|
| 57 |
+
int nhitboxes = 0;
|
| 58 |
+
|
| 59 |
+
// find raytminmax
|
| 60 |
+
float2 rtminmax = make_float2(std::numeric_limits<float>::infinity(), -std::numeric_limits<float>::infinity());
|
| 61 |
+
RaySubsetT::forward(warpmask, K, raypos, raydir, tminmax, rtminmax,
|
| 62 |
+
sortedobjid, nodechildren, nodeaabb,
|
| 63 |
+
primtransf_data, hitboxes_ptr, nhitboxes);
|
| 64 |
+
rtminmax.x = max(rtminmax.x, tminmax.x);
|
| 65 |
+
rtminmax.y = min(rtminmax.y, tminmax.y);
|
| 66 |
+
__syncwarp(warpmask);
|
| 67 |
+
|
| 68 |
+
float t = tminmax.x;
|
| 69 |
+
raypos = raypos + raydir * tminmax.x;
|
| 70 |
+
|
| 71 |
+
int incs = floor((rtminmax.x - t) / stepsize);
|
| 72 |
+
t += incs * stepsize;
|
| 73 |
+
raypos += raydir * incs * stepsize;
|
| 74 |
+
|
| 75 |
+
PrimAccumT pa;
|
| 76 |
+
|
| 77 |
+
while (!__all_sync(warpmask, t > rtminmax.y + 1e-5f || pa.is_done())) {
|
| 78 |
+
for (int ks = 0; ks < nhitboxes; ++ks) {
|
| 79 |
+
int k = hitboxes_ptr[ks];
|
| 80 |
+
|
| 81 |
+
// compute primitive-relative coordinate
|
| 82 |
+
PrimTransfT pt;
|
| 83 |
+
float3 samplepos = pt.forward(primtransf_data, k, raypos);
|
| 84 |
+
|
| 85 |
+
if (pt.valid(samplepos) && !pa.is_done() && t < rtminmax.y + 1e-5f) {
|
| 86 |
+
// sample
|
| 87 |
+
PrimSamplerT ps;
|
| 88 |
+
float4 sample = ps.forward(primsampler_data, k, samplepos);
|
| 89 |
+
|
| 90 |
+
// accumulate
|
| 91 |
+
pa.forward_prim(primaccum_data, sample, stepsize);
|
| 92 |
+
}
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
// update position
|
| 96 |
+
t += stepsize;
|
| 97 |
+
raypos += raydir * stepsize;
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
pa.write(primaccum_data);
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
template <
|
| 104 |
+
bool forwarddir,
|
| 105 |
+
int maxhitboxes,
|
| 106 |
+
int nwarps,
|
| 107 |
+
class RaySubsetT=RaySubsetFixedBVH<false, 512, true, PrimTransfSRT>,
|
| 108 |
+
class PrimTransfT=PrimTransfSRT,
|
| 109 |
+
class PrimSamplerT=PrimSamplerTW<false>,
|
| 110 |
+
class PrimAccumT=PrimAccumAdditive>
|
| 111 |
+
__global__ void raymarch_subset_backward_kernel(
|
| 112 |
+
int N, int H, int W, int K,
|
| 113 |
+
float3 * rayposim,
|
| 114 |
+
float3 * raydirim,
|
| 115 |
+
float stepsize,
|
| 116 |
+
float2 * tminmaxim,
|
| 117 |
+
int * sortedobjid,
|
| 118 |
+
int2 * nodechildren,
|
| 119 |
+
float3 * nodeaabb,
|
| 120 |
+
typename PrimTransfT::Data primtransf_data,
|
| 121 |
+
typename PrimSamplerT::Data primsampler_data,
|
| 122 |
+
typename PrimAccumT::Data primaccum_data
|
| 123 |
+
) {
|
| 124 |
+
int w = blockIdx.x * blockDim.x + threadIdx.x;
|
| 125 |
+
int h = blockIdx.y * blockDim.y + threadIdx.y;
|
| 126 |
+
int n = blockIdx.z;
|
| 127 |
+
bool validthread = (w < W) && (h < H) && (n<N);
|
| 128 |
+
|
| 129 |
+
assert(nwarps == 0 || blockDim.x * blockDim.y / 32 <= nwarps);
|
| 130 |
+
const int warpid = __shfl_sync(0xffffffff, (threadIdx.y * blockDim.x + threadIdx.x) / 32, 0);
|
| 131 |
+
assert(__match_any_sync(0xffffffff, (threadIdx.y * blockDim.x + threadIdx.x) / 32) == 0xffffffff);
|
| 132 |
+
|
| 133 |
+
// warpmask contains the valid threads in the warp
|
| 134 |
+
unsigned warpmask = 0xffffffff;
|
| 135 |
+
n = min(N - 1, n);
|
| 136 |
+
h = min(H - 1, h);
|
| 137 |
+
w = min(W - 1, w);
|
| 138 |
+
|
| 139 |
+
sortedobjid += n * K;
|
| 140 |
+
nodechildren += n * (K + K - 1);
|
| 141 |
+
nodeaabb += n * (K + K - 1) * 2;
|
| 142 |
+
|
| 143 |
+
primtransf_data.n_stride(n);
|
| 144 |
+
primsampler_data.n_stride(n);
|
| 145 |
+
primaccum_data.n_stride(n, h, w);
|
| 146 |
+
|
| 147 |
+
float3 raypos = rayposim[n * H * W + h * W + w];
|
| 148 |
+
float3 raydir = raydirim[n * H * W + h * W + w];
|
| 149 |
+
float2 tminmax = tminmaxim[n * H * W + h * W + w];
|
| 150 |
+
|
| 151 |
+
PrimAccumT pa;
|
| 152 |
+
pa.read(primaccum_data);
|
| 153 |
+
|
| 154 |
+
int hitboxes[nwarps > 0 ? 1 : maxhitboxes];
|
| 155 |
+
__shared__ int hitboxes_sh[nwarps > 0 ? maxhitboxes * nwarps : 1];
|
| 156 |
+
int * hitboxes_ptr = nwarps > 0 ? hitboxes_sh + maxhitboxes * warpid : hitboxes;
|
| 157 |
+
int nhitboxes = 0;
|
| 158 |
+
|
| 159 |
+
// find raytminmax
|
| 160 |
+
float2 rtminmax = make_float2(std::numeric_limits<float>::infinity(), -std::numeric_limits<float>::infinity());
|
| 161 |
+
RaySubsetT::forward(warpmask, K, raypos, raydir, tminmax, rtminmax,
|
| 162 |
+
sortedobjid, nodechildren, nodeaabb,
|
| 163 |
+
primtransf_data, hitboxes_ptr, nhitboxes);
|
| 164 |
+
rtminmax.x = max(rtminmax.x, tminmax.x);
|
| 165 |
+
rtminmax.y = min(rtminmax.y, tminmax.y);
|
| 166 |
+
__syncwarp(warpmask);
|
| 167 |
+
|
| 168 |
+
// set up raymarching position
|
| 169 |
+
float t = tminmax.x;
|
| 170 |
+
raypos = raypos + raydir * tminmax.x;
|
| 171 |
+
|
| 172 |
+
int incs = floor((rtminmax.x - t) / stepsize);
|
| 173 |
+
t += incs * stepsize;
|
| 174 |
+
raypos += raydir * incs * stepsize;
|
| 175 |
+
|
| 176 |
+
if (!forwarddir) {
|
| 177 |
+
int nsteps = pa.get_nsteps();
|
| 178 |
+
t += nsteps * stepsize;
|
| 179 |
+
raypos += raydir * nsteps * stepsize;
|
| 180 |
+
}
|
| 181 |
+
|
| 182 |
+
while (__any_sync(warpmask, (
|
| 183 |
+
(forwarddir && t < rtminmax.y + 1e-5f ||
|
| 184 |
+
!forwarddir && t > rtminmax.x - 1e-5f) &&
|
| 185 |
+
!pa.is_done()))) {
|
| 186 |
+
for (int ks = 0; ks < nhitboxes; ++ks) {
|
| 187 |
+
int k = hitboxes_ptr[forwarddir ? ks : nhitboxes - ks - 1];
|
| 188 |
+
|
| 189 |
+
PrimTransfT pt;
|
| 190 |
+
float3 samplepos = pt.forward(primtransf_data, k, raypos);
|
| 191 |
+
|
| 192 |
+
bool evalprim = pt.valid(samplepos) && !pa.is_done() && t < rtminmax.y + 1e-5f;
|
| 193 |
+
|
| 194 |
+
float3 dL_samplepos = make_float3(0.f);
|
| 195 |
+
if (evalprim) {
|
| 196 |
+
PrimSamplerT ps;
|
| 197 |
+
float4 sample = ps.forward(primsampler_data, k, samplepos);
|
| 198 |
+
|
| 199 |
+
float4 dL_sample = pa.forwardbackward_prim(primaccum_data, sample, stepsize);
|
| 200 |
+
|
| 201 |
+
dL_samplepos = ps.backward(primsampler_data, k, samplepos, sample, dL_sample, validthread);
|
| 202 |
+
}
|
| 203 |
+
|
| 204 |
+
if (__any_sync(warpmask, evalprim)) {
|
| 205 |
+
pt.backward(primtransf_data, k, samplepos, dL_samplepos, validthread && evalprim);
|
| 206 |
+
}
|
| 207 |
+
}
|
| 208 |
+
|
| 209 |
+
if (forwarddir) {
|
| 210 |
+
t += stepsize;
|
| 211 |
+
raypos += raydir * stepsize;
|
| 212 |
+
} else {
|
| 213 |
+
t -= stepsize;
|
| 214 |
+
raypos -= raydir * stepsize;
|
| 215 |
+
}
|
| 216 |
+
}
|
| 217 |
+
}
|
| 218 |
+
|
dva/mvp/extensions/mvpraymarch/primaccum.h
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
//
|
| 4 |
+
// This source code is licensed under the license found in the
|
| 5 |
+
// LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
#ifndef MVPRAYMARCHER_PRIMACCUM_H_
|
| 8 |
+
#define MVPRAYMARCHER_PRIMACCUM_H_
|
| 9 |
+
|
| 10 |
+
struct PrimAccumDataBase {
|
| 11 |
+
typedef PrimAccumDataBase base;
|
| 12 |
+
};
|
| 13 |
+
|
| 14 |
+
struct PrimAccumAdditive {
|
| 15 |
+
struct Data : public PrimAccumDataBase {
|
| 16 |
+
float termthresh;
|
| 17 |
+
|
| 18 |
+
int nstride, hstride, wstride;
|
| 19 |
+
float4 * rayrgbaim;
|
| 20 |
+
float4 * grad_rayrgbaim;
|
| 21 |
+
float3 * raysatim;
|
| 22 |
+
|
| 23 |
+
__forceinline__ __device__ void n_stride(int n, int h, int w) {
|
| 24 |
+
rayrgbaim += n * nstride + h * hstride + w * wstride;
|
| 25 |
+
grad_rayrgbaim += n * nstride + h * hstride + w * wstride;
|
| 26 |
+
if (raysatim) {
|
| 27 |
+
raysatim += n * nstride + h * hstride + w * wstride;
|
| 28 |
+
}
|
| 29 |
+
}
|
| 30 |
+
};
|
| 31 |
+
|
| 32 |
+
float4 rayrgba;
|
| 33 |
+
float3 raysat;
|
| 34 |
+
bool sat;
|
| 35 |
+
float4 dL_rayrgba;
|
| 36 |
+
|
| 37 |
+
__forceinline__ __device__ PrimAccumAdditive() :
|
| 38 |
+
rayrgba(make_float4(0.f)),
|
| 39 |
+
raysat(make_float3(-1.f)),
|
| 40 |
+
sat(false) {
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
__forceinline__ __device__ bool is_done() const {
|
| 44 |
+
return sat;
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
__forceinline__ __device__ int get_nsteps() const {
|
| 48 |
+
return 0;
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
__forceinline__ __device__ void write(const Data & data) {
|
| 52 |
+
*data.rayrgbaim = rayrgba;
|
| 53 |
+
if (data.raysatim) {
|
| 54 |
+
*data.raysatim = raysat;
|
| 55 |
+
}
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
__forceinline__ __device__ void read(const Data & data) {
|
| 59 |
+
dL_rayrgba = *data.grad_rayrgbaim;
|
| 60 |
+
raysat = *data.raysatim;
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
__forceinline__ __device__ void forward_prim(const Data & data, float4 sample, float stepsize) {
|
| 64 |
+
// accumulate
|
| 65 |
+
float3 rgb = make_float3(sample);
|
| 66 |
+
float alpha = sample.w;
|
| 67 |
+
float newalpha = rayrgba.w + alpha * stepsize;
|
| 68 |
+
float contrib = fminf(newalpha, 1.f) - rayrgba.w;
|
| 69 |
+
|
| 70 |
+
rayrgba += make_float4(rgb, 1.f) * contrib;
|
| 71 |
+
|
| 72 |
+
if (newalpha >= 1.f) {
|
| 73 |
+
// save saturation point
|
| 74 |
+
if (!sat) {
|
| 75 |
+
raysat = rgb;
|
| 76 |
+
}
|
| 77 |
+
sat = true;
|
| 78 |
+
}
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
__forceinline__ __device__ float4 forwardbackward_prim(const Data & data, float4 sample, float stepsize) {
|
| 82 |
+
float3 rgb = make_float3(sample);
|
| 83 |
+
float4 rgb1 = make_float4(rgb, 1.f);
|
| 84 |
+
sample.w *= stepsize;
|
| 85 |
+
|
| 86 |
+
bool thissat = rayrgba.w + sample.w >= 1.f;
|
| 87 |
+
sat = sat || thissat;
|
| 88 |
+
|
| 89 |
+
float weight = sat ? (1.f - rayrgba.w) : sample.w;
|
| 90 |
+
|
| 91 |
+
float3 dL_rgb = weight * make_float3(dL_rayrgba);
|
| 92 |
+
float dL_alpha = sat ? 0.f :
|
| 93 |
+
stepsize * dot(rgb1 - (raysat.x > -1.f ? make_float4(raysat, 1.f) : make_float4(0.f)), dL_rayrgba);
|
| 94 |
+
|
| 95 |
+
rayrgba += make_float4(rgb, 1.f) * weight;
|
| 96 |
+
|
| 97 |
+
return make_float4(dL_rgb, dL_alpha);
|
| 98 |
+
}
|
| 99 |
+
};
|
| 100 |
+
|
| 101 |
+
#endif
|
dva/mvp/extensions/mvpraymarch/primsampler.h
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
//
|
| 4 |
+
// This source code is licensed under the license found in the
|
| 5 |
+
// LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
#ifndef MVPRAYMARCHER_PRIMSAMPLER_H_
|
| 8 |
+
#define MVPRAYMARCHER_PRIMSAMPLER_H_
|
| 9 |
+
|
| 10 |
+
struct PrimSamplerDataBase {
|
| 11 |
+
typedef PrimSamplerDataBase base;
|
| 12 |
+
};
|
| 13 |
+
|
| 14 |
+
template<
|
| 15 |
+
bool dowarp,
|
| 16 |
+
template<typename> class GridSamplerT=GridSamplerChlast>
|
| 17 |
+
struct PrimSamplerTW {
|
| 18 |
+
struct Data : public PrimSamplerDataBase {
|
| 19 |
+
float fadescale, fadeexp;
|
| 20 |
+
|
| 21 |
+
int tplate_nstride;
|
| 22 |
+
int TD, TH, TW;
|
| 23 |
+
float * tplate;
|
| 24 |
+
float * grad_tplate;
|
| 25 |
+
|
| 26 |
+
int warp_nstride;
|
| 27 |
+
int WD, WH, WW;
|
| 28 |
+
float * warp;
|
| 29 |
+
float * grad_warp;
|
| 30 |
+
|
| 31 |
+
__forceinline__ __device__ void n_stride(int n) {
|
| 32 |
+
tplate += n * tplate_nstride;
|
| 33 |
+
grad_tplate += n * tplate_nstride;
|
| 34 |
+
warp += n * warp_nstride;
|
| 35 |
+
grad_warp += n * warp_nstride;
|
| 36 |
+
}
|
| 37 |
+
};
|
| 38 |
+
|
| 39 |
+
float fade;
|
| 40 |
+
float * tplate_ptr;
|
| 41 |
+
float * warp_ptr;
|
| 42 |
+
float3 yy1;
|
| 43 |
+
|
| 44 |
+
__forceinline__ __device__ float4 forward(
|
| 45 |
+
const Data & data,
|
| 46 |
+
int k,
|
| 47 |
+
float3 y0) {
|
| 48 |
+
fade = __expf(-data.fadescale * (
|
| 49 |
+
__powf(abs(y0.x), data.fadeexp) +
|
| 50 |
+
__powf(abs(y0.y), data.fadeexp) +
|
| 51 |
+
__powf(abs(y0.z), data.fadeexp)));
|
| 52 |
+
|
| 53 |
+
if (dowarp) {
|
| 54 |
+
warp_ptr = data.warp + (k * 3 * data.WD * data.WH * data.WW);
|
| 55 |
+
yy1 = GridSamplerT<float3>::forward(3, data.WD, data.WH, data.WW, warp_ptr, y0, false);
|
| 56 |
+
} else {
|
| 57 |
+
yy1 = y0;
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
tplate_ptr = data.tplate + (k * 4 * data.TD * data.TH * data.TW);
|
| 61 |
+
float4 sample = GridSamplerT<float4>::forward(4, data.TD, data.TH, data.TW, tplate_ptr, yy1, false);
|
| 62 |
+
|
| 63 |
+
sample.w *= fade;
|
| 64 |
+
|
| 65 |
+
return sample;
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
__forceinline__ __device__ float3 backward(const Data & data, int k, float3 y0,
|
| 69 |
+
float4 sample, float4 dL_sample, bool validthread) {
|
| 70 |
+
float3 dfade_y0 = -(data.fadescale * data.fadeexp) * make_float3(
|
| 71 |
+
__powf(abs(y0.x), data.fadeexp - 1.f) * (y0.x > 0.f ? 1.f : -1.f),
|
| 72 |
+
__powf(abs(y0.y), data.fadeexp - 1.f) * (y0.y > 0.f ? 1.f : -1.f),
|
| 73 |
+
__powf(abs(y0.z), data.fadeexp - 1.f) * (y0.z > 0.f ? 1.f : -1.f));
|
| 74 |
+
float3 dL_y0 = dfade_y0 * sample.w * dL_sample.w;
|
| 75 |
+
|
| 76 |
+
dL_sample.w *= fade;
|
| 77 |
+
|
| 78 |
+
float * grad_tplate_ptr = data.grad_tplate + (k * 4 * data.TD * data.TH * data.TW);
|
| 79 |
+
float3 dL_y1 = GridSamplerT<float4>::backward(4, data.TD, data.TH, data.TW,
|
| 80 |
+
tplate_ptr, grad_tplate_ptr, yy1, validthread ? dL_sample : make_float4(0.f), false);
|
| 81 |
+
|
| 82 |
+
if (dowarp) {
|
| 83 |
+
float * grad_warp_ptr = data.grad_warp + (k * 3 * data.WD * data.WH * data.WW);
|
| 84 |
+
dL_y0 += GridSamplerT<float3>::backward(3, data.WD, data.WH, data.WW,
|
| 85 |
+
warp_ptr, grad_warp_ptr, y0, validthread ? dL_y1 : make_float3(0.f), false);
|
| 86 |
+
} else {
|
| 87 |
+
dL_y0 += dL_y1;
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
return dL_y0;
|
| 91 |
+
}
|
| 92 |
+
};
|
| 93 |
+
|
| 94 |
+
#endif
|
dva/mvp/extensions/mvpraymarch/primtransf.h
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
//
|
| 4 |
+
// This source code is licensed under the license found in the
|
| 5 |
+
// LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
#ifndef MVPRAYMARCHER_PRIMTRANSF_H_
|
| 8 |
+
#define MVPRAYMARCHER_PRIMTRANSF_H_
|
| 9 |
+
|
| 10 |
+
#include "utils.h"
|
| 11 |
+
|
| 12 |
+
__forceinline__ __device__ void compute_aabb_srt(
|
| 13 |
+
float3 pt, float3 pr0, float3 pr1, float3 pr2, float3 ps,
|
| 14 |
+
float3 & pmin, float3 & pmax) {
|
| 15 |
+
float3 p;
|
| 16 |
+
p = make_float3(-1.f, -1.f, -1.f) / ps;
|
| 17 |
+
p = make_float3(dot(p, pr0), dot(p, pr1), dot(p, pr2)) + pt;
|
| 18 |
+
|
| 19 |
+
pmin = p;
|
| 20 |
+
pmax = p;
|
| 21 |
+
|
| 22 |
+
p = make_float3(1.f, -1.f, -1.f) / ps;
|
| 23 |
+
p = make_float3(dot(p, pr0), dot(p, pr1), dot(p, pr2)) + pt;
|
| 24 |
+
|
| 25 |
+
pmin = fminf(pmin, p);
|
| 26 |
+
pmax = fmaxf(pmax, p);
|
| 27 |
+
|
| 28 |
+
p = make_float3(-1.f, 1.f, -1.f) / ps;
|
| 29 |
+
p = make_float3(dot(p, pr0), dot(p, pr1), dot(p, pr2)) + pt;
|
| 30 |
+
|
| 31 |
+
pmin = fminf(pmin, p);
|
| 32 |
+
pmax = fmaxf(pmax, p);
|
| 33 |
+
|
| 34 |
+
p = make_float3(1.f, 1.f, -1.f) / ps;
|
| 35 |
+
p = make_float3(dot(p, pr0), dot(p, pr1), dot(p, pr2)) + pt;
|
| 36 |
+
|
| 37 |
+
pmin = fminf(pmin, p);
|
| 38 |
+
pmax = fmaxf(pmax, p);
|
| 39 |
+
|
| 40 |
+
p = make_float3(-1.f, -1.f, 1.f) / ps;
|
| 41 |
+
p = make_float3(dot(p, pr0), dot(p, pr1), dot(p, pr2)) + pt;
|
| 42 |
+
|
| 43 |
+
pmin = fminf(pmin, p);
|
| 44 |
+
pmax = fmaxf(pmax, p);
|
| 45 |
+
|
| 46 |
+
p = make_float3(1.f, -1.f, 1.f) / ps;
|
| 47 |
+
p = make_float3(dot(p, pr0), dot(p, pr1), dot(p, pr2)) + pt;
|
| 48 |
+
|
| 49 |
+
pmin = fminf(pmin, p);
|
| 50 |
+
pmax = fmaxf(pmax, p);
|
| 51 |
+
|
| 52 |
+
p = make_float3(-1.f, 1.f, 1.f) / ps;
|
| 53 |
+
p = make_float3(dot(p, pr0), dot(p, pr1), dot(p, pr2)) + pt;
|
| 54 |
+
|
| 55 |
+
pmin = fminf(pmin, p);
|
| 56 |
+
pmax = fmaxf(pmax, p);
|
| 57 |
+
|
| 58 |
+
p = make_float3(1.f, 1.f, 1.f) / ps;
|
| 59 |
+
p = make_float3(dot(p, pr0), dot(p, pr1), dot(p, pr2)) + pt;
|
| 60 |
+
|
| 61 |
+
pmin = fminf(pmin, p);
|
| 62 |
+
pmax = fmaxf(pmax, p);
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
struct PrimTransfDataBase {
|
| 66 |
+
typedef PrimTransfDataBase base;
|
| 67 |
+
};
|
| 68 |
+
|
| 69 |
+
struct PrimTransfSRT {
|
| 70 |
+
struct Data : public PrimTransfDataBase {
|
| 71 |
+
int primpos_nstride;
|
| 72 |
+
float3 * primpos;
|
| 73 |
+
float3 * grad_primpos;
|
| 74 |
+
int primrot_nstride;
|
| 75 |
+
float3 * primrot;
|
| 76 |
+
float3 * grad_primrot;
|
| 77 |
+
int primscale_nstride;
|
| 78 |
+
float3 * primscale;
|
| 79 |
+
float3 * grad_primscale;
|
| 80 |
+
|
| 81 |
+
__forceinline__ __device__ void n_stride(int n) {
|
| 82 |
+
primpos += n * primpos_nstride;
|
| 83 |
+
grad_primpos += n * primpos_nstride;
|
| 84 |
+
primrot += n * primrot_nstride;
|
| 85 |
+
grad_primrot += n * primrot_nstride;
|
| 86 |
+
primscale += n * primscale_nstride;
|
| 87 |
+
grad_primscale += n * primscale_nstride;
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
__forceinline__ __device__ float3 get_center(int n, int k) {
|
| 91 |
+
return primpos[n * primpos_nstride + k];
|
| 92 |
+
}
|
| 93 |
+
|
| 94 |
+
__forceinline__ __device__ void compute_aabb(int n, int k, float3 & pmin, float3 & pmax) {
|
| 95 |
+
float3 pt = primpos[n * primpos_nstride + k];
|
| 96 |
+
float3 pr0 = primrot[n * primrot_nstride + k * 3 + 0];
|
| 97 |
+
float3 pr1 = primrot[n * primrot_nstride + k * 3 + 1];
|
| 98 |
+
float3 pr2 = primrot[n * primrot_nstride + k * 3 + 2];
|
| 99 |
+
float3 ps = primscale[n * primscale_nstride + k];
|
| 100 |
+
|
| 101 |
+
compute_aabb_srt(pt, pr0, pr1, pr2, ps, pmin, pmax);
|
| 102 |
+
}
|
| 103 |
+
};
|
| 104 |
+
|
| 105 |
+
float3 xmt;
|
| 106 |
+
float3 pr0;
|
| 107 |
+
float3 pr1;
|
| 108 |
+
float3 pr2;
|
| 109 |
+
float3 rxmt;
|
| 110 |
+
float3 ps;
|
| 111 |
+
|
| 112 |
+
static __forceinline__ __device__ bool valid(float3 pos) {
|
| 113 |
+
return (
|
| 114 |
+
pos.x > -1.f && pos.x < 1.f &&
|
| 115 |
+
pos.y > -1.f && pos.y < 1.f &&
|
| 116 |
+
pos.z > -1.f && pos.z < 1.f);
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
__forceinline__ __device__ float3 forward(
|
| 120 |
+
const Data & data,
|
| 121 |
+
int k,
|
| 122 |
+
float3 x) {
|
| 123 |
+
float3 pt = data.primpos[k];
|
| 124 |
+
pr0 = data.primrot[(k) * 3 + 0];
|
| 125 |
+
pr1 = data.primrot[(k) * 3 + 1];
|
| 126 |
+
pr2 = data.primrot[(k) * 3 + 2];
|
| 127 |
+
ps = data.primscale[k];
|
| 128 |
+
xmt = x - pt;
|
| 129 |
+
rxmt = pr0 * xmt.x + pr1 * xmt.y + pr2 * xmt.z;
|
| 130 |
+
float3 y0 = rxmt * ps;
|
| 131 |
+
return y0;
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
static __forceinline__ __device__ void forward2(
|
| 135 |
+
const Data & data,
|
| 136 |
+
int k,
|
| 137 |
+
float3 r, float3 d, float3 & rout, float3 & dout) {
|
| 138 |
+
float3 pt = data.primpos[k];
|
| 139 |
+
float3 pr0 = data.primrot[k * 3 + 0];
|
| 140 |
+
float3 pr1 = data.primrot[k * 3 + 1];
|
| 141 |
+
float3 pr2 = data.primrot[k * 3 + 2];
|
| 142 |
+
float3 ps = data.primscale[k];
|
| 143 |
+
float3 xmt = r - pt;
|
| 144 |
+
float3 dmt = d;
|
| 145 |
+
float3 rxmt = pr0 * xmt.x;
|
| 146 |
+
float3 rdmt = pr0 * dmt.x;
|
| 147 |
+
rxmt += pr1 * xmt.y;
|
| 148 |
+
rdmt += pr1 * dmt.y;
|
| 149 |
+
rxmt += pr2 * xmt.z;
|
| 150 |
+
rdmt += pr2 * dmt.z;
|
| 151 |
+
rout = rxmt * ps;
|
| 152 |
+
dout = rdmt * ps;
|
| 153 |
+
}
|
| 154 |
+
|
| 155 |
+
__forceinline__ __device__ void backward(const Data & data, int k, float3 x, float3 dL_y0, bool validthread) {
|
| 156 |
+
fastAtomicAdd((float*)data.grad_primscale + k * 3 + 0, validthread ? rxmt.x * dL_y0.x : 0.f);
|
| 157 |
+
fastAtomicAdd((float*)data.grad_primscale + k * 3 + 1, validthread ? rxmt.y * dL_y0.y : 0.f);
|
| 158 |
+
fastAtomicAdd((float*)data.grad_primscale + k * 3 + 2, validthread ? rxmt.z * dL_y0.z : 0.f);
|
| 159 |
+
|
| 160 |
+
dL_y0 *= ps;
|
| 161 |
+
float3 gpr0 = xmt.x * dL_y0;
|
| 162 |
+
fastAtomicAdd((float*)data.grad_primrot + (k * 3 + 0) * 3 + 0, validthread ? gpr0.x : 0.f);
|
| 163 |
+
fastAtomicAdd((float*)data.grad_primrot + (k * 3 + 0) * 3 + 1, validthread ? gpr0.y : 0.f);
|
| 164 |
+
fastAtomicAdd((float*)data.grad_primrot + (k * 3 + 0) * 3 + 2, validthread ? gpr0.z : 0.f);
|
| 165 |
+
|
| 166 |
+
float3 gpr1 = xmt.y * dL_y0;
|
| 167 |
+
fastAtomicAdd((float*)data.grad_primrot + (k * 3 + 1) * 3 + 0, validthread ? gpr1.x : 0.f);
|
| 168 |
+
fastAtomicAdd((float*)data.grad_primrot + (k * 3 + 1) * 3 + 1, validthread ? gpr1.y : 0.f);
|
| 169 |
+
fastAtomicAdd((float*)data.grad_primrot + (k * 3 + 1) * 3 + 2, validthread ? gpr1.z : 0.f);
|
| 170 |
+
|
| 171 |
+
float3 gpr2 = xmt.z * dL_y0;
|
| 172 |
+
fastAtomicAdd((float*)data.grad_primrot + (k * 3 + 2) * 3 + 0, validthread ? gpr2.x : 0.f);
|
| 173 |
+
fastAtomicAdd((float*)data.grad_primrot + (k * 3 + 2) * 3 + 1, validthread ? gpr2.y : 0.f);
|
| 174 |
+
fastAtomicAdd((float*)data.grad_primrot + (k * 3 + 2) * 3 + 2, validthread ? gpr2.z : 0.f);
|
| 175 |
+
|
| 176 |
+
fastAtomicAdd((float*)data.grad_primpos + k * 3 + 0, validthread ? -dot(pr0, dL_y0) : 0.f);
|
| 177 |
+
fastAtomicAdd((float*)data.grad_primpos + k * 3 + 1, validthread ? -dot(pr1, dL_y0) : 0.f);
|
| 178 |
+
fastAtomicAdd((float*)data.grad_primpos + k * 3 + 2, validthread ? -dot(pr2, dL_y0) : 0.f);
|
| 179 |
+
}
|
| 180 |
+
};
|
| 181 |
+
|
| 182 |
+
#endif
|
dva/mvp/extensions/mvpraymarch/setup.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from setuptools import setup
|
| 8 |
+
|
| 9 |
+
from torch.utils.cpp_extension import CUDAExtension, BuildExtension
|
| 10 |
+
|
| 11 |
+
if __name__ == "__main__":
|
| 12 |
+
import torch
|
| 13 |
+
setup(
|
| 14 |
+
name="mvpraymarch",
|
| 15 |
+
ext_modules=[
|
| 16 |
+
CUDAExtension(
|
| 17 |
+
"mvpraymarchlib",
|
| 18 |
+
sources=["mvpraymarch.cpp", "mvpraymarch_kernel.cu", "bvh.cu"],
|
| 19 |
+
extra_compile_args={
|
| 20 |
+
"nvcc": [
|
| 21 |
+
"-use_fast_math",
|
| 22 |
+
"-arch=sm_70",
|
| 23 |
+
"-std=c++17",
|
| 24 |
+
"-lineinfo",
|
| 25 |
+
]
|
| 26 |
+
}
|
| 27 |
+
)
|
| 28 |
+
],
|
| 29 |
+
cmdclass={"build_ext": BuildExtension}
|
| 30 |
+
)
|
dva/mvp/extensions/mvpraymarch/utils.h
ADDED
|
@@ -0,0 +1,847 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
//
|
| 4 |
+
// This source code is licensed under the license found in the
|
| 5 |
+
// LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
#ifndef MVPRAYMARCHER_UTILS_H_
|
| 8 |
+
#define MVPRAYMARCHER_UTILS_H_
|
| 9 |
+
|
| 10 |
+
#include <cassert>
|
| 11 |
+
#include <cmath>
|
| 12 |
+
|
| 13 |
+
#include <limits>
|
| 14 |
+
|
| 15 |
+
#include "helper_math.h"
|
| 16 |
+
|
| 17 |
+
static __forceinline__ __device__ float clock_diff(long long int end, long long int start) {
|
| 18 |
+
long long int max_clock = std::numeric_limits<long long int>::max();
|
| 19 |
+
return (end<start? (end + float(max_clock-start)) : float(end-start));
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
static __forceinline__ __device__
|
| 23 |
+
bool allgt(float3 a, float3 b) {
|
| 24 |
+
return a.x >= b.x && a.y >= b.y && a.z >= b.z;
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
static __forceinline__ __device__
|
| 28 |
+
bool alllt(float3 a, float3 b) {
|
| 29 |
+
return a.x <= b.x && a.y <= b.y && a.z <= b.z;
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
static __forceinline__ __device__
|
| 33 |
+
float4 softplus(float4 x) {
|
| 34 |
+
return make_float4(
|
| 35 |
+
x.x > 20.f ? x.x : logf(1.f + expf(x.x)),
|
| 36 |
+
x.y > 20.f ? x.y : logf(1.f + expf(x.y)),
|
| 37 |
+
x.z > 20.f ? x.z : logf(1.f + expf(x.z)),
|
| 38 |
+
x.w > 20.f ? x.w : logf(1.f + expf(x.w)));
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
static __forceinline__ __device__
|
| 42 |
+
float softplus(float x) {
|
| 43 |
+
// that's a neat trick
|
| 44 |
+
return __logf(1.f + __expf(-abs(x))) + max(x, 0.f);
|
| 45 |
+
}
|
| 46 |
+
static __forceinline__ __device__
|
| 47 |
+
float softplus_grad(float x) {
|
| 48 |
+
// that's a neat trick
|
| 49 |
+
float expnabsx = __expf(-abs(x));
|
| 50 |
+
return (0.5f - expnabsx / (1.f + expnabsx)) * copysign(1.f, x) + 0.5f;
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
static __forceinline__ __device__
|
| 55 |
+
float4 sigmoid(float4 x) {
|
| 56 |
+
return make_float4(
|
| 57 |
+
1.f / (1.f + expf(-x.x)),
|
| 58 |
+
1.f / (1.f + expf(-x.y)),
|
| 59 |
+
1.f / (1.f + expf(-x.z)),
|
| 60 |
+
1.f / (1.f + expf(-x.w)));
|
| 61 |
+
}
|
| 62 |
+
|
| 63 |
+
// perform reduction on warp, then call atomicAdd for only one lane
|
| 64 |
+
static __forceinline__ __device__ void fastAtomicAdd(float * ptr, float val) {
|
| 65 |
+
for (int offset = 16; offset > 0; offset /= 2) {
|
| 66 |
+
val += __shfl_down_sync(0xffffffff, val, offset);
|
| 67 |
+
}
|
| 68 |
+
|
| 69 |
+
const int laneid = (threadIdx.y * blockDim.x + threadIdx.x) % 32;
|
| 70 |
+
if (laneid == 0) {
|
| 71 |
+
atomicAdd(ptr, val);
|
| 72 |
+
}
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
static __forceinline__ __device__
|
| 77 |
+
bool within_bounds_3d(int d, int h, int w, int D, int H, int W) {
|
| 78 |
+
return d >= 0 && d < D && h >= 0 && h < H && w >= 0 && w < W;
|
| 79 |
+
}
|
| 80 |
+
|
| 81 |
+
static __forceinline__ __device__
|
| 82 |
+
void safe_add_3d(float *data, int d, int h, int w,
|
| 83 |
+
int sD, int sH, int sW, int D, int H, int W,
|
| 84 |
+
float delta) {
|
| 85 |
+
if (within_bounds_3d(d, h, w, D, H, W)) {
|
| 86 |
+
atomicAdd(data + d * sD + h * sH + w * sW, delta);
|
| 87 |
+
}
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
static __forceinline__ __device__
|
| 91 |
+
void safe_add_3d(float3 *data, int d, int h, int w,
|
| 92 |
+
int sD, int sH, int sW, int D, int H, int W,
|
| 93 |
+
float3 delta) {
|
| 94 |
+
if (within_bounds_3d(d, h, w, D, H, W)) {
|
| 95 |
+
atomicAdd((float*)data + (d * sD + h * sH + w * sW) * 3 + 0, delta.x);
|
| 96 |
+
atomicAdd((float*)data + (d * sD + h * sH + w * sW) * 3 + 1, delta.y);
|
| 97 |
+
atomicAdd((float*)data + (d * sD + h * sH + w * sW) * 3 + 2, delta.z);
|
| 98 |
+
}
|
| 99 |
+
}
|
| 100 |
+
|
| 101 |
+
static __forceinline__ __device__
|
| 102 |
+
void safe_add_3d(float4 *data, int d, int h, int w,
|
| 103 |
+
int sD, int sH, int sW, int D, int H, int W,
|
| 104 |
+
float4 delta) {
|
| 105 |
+
if (within_bounds_3d(d, h, w, D, H, W)) {
|
| 106 |
+
atomicAdd((float*)data + (d * sD + h * sH + w * sW) * 4 + 0, delta.x);
|
| 107 |
+
atomicAdd((float*)data + (d * sD + h * sH + w * sW) * 4 + 1, delta.y);
|
| 108 |
+
atomicAdd((float*)data + (d * sD + h * sH + w * sW) * 4 + 2, delta.z);
|
| 109 |
+
atomicAdd((float*)data + (d * sD + h * sH + w * sW) * 4 + 3, delta.w);
|
| 110 |
+
}
|
| 111 |
+
}
|
| 112 |
+
|
| 113 |
+
static __forceinline__ __device__
|
| 114 |
+
float clip_coordinates(float in, int clip_limit) {
|
| 115 |
+
return ::min(static_cast<float>(clip_limit - 1), ::max(in, 0.f));
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
template <typename scalar_t>
|
| 119 |
+
static __forceinline__ __device__
|
| 120 |
+
float clip_coordinates_set_grad(float in, int clip_limit, scalar_t *grad_in) {
|
| 121 |
+
if (in < 0.f) {
|
| 122 |
+
*grad_in = static_cast<scalar_t>(0);
|
| 123 |
+
return 0.f;
|
| 124 |
+
} else {
|
| 125 |
+
float max = static_cast<float>(clip_limit - 1);
|
| 126 |
+
if (in > max) {
|
| 127 |
+
*grad_in = static_cast<scalar_t>(0);
|
| 128 |
+
return max;
|
| 129 |
+
} else {
|
| 130 |
+
*grad_in = static_cast<scalar_t>(1);
|
| 131 |
+
return in;
|
| 132 |
+
}
|
| 133 |
+
}
|
| 134 |
+
}
|
| 135 |
+
|
| 136 |
+
template<typename out_t>
|
| 137 |
+
static __device__ out_t grid_sample_forward(int C, int inp_D, int inp_H,
|
| 138 |
+
int inp_W, float* vals, float3 pos, bool border) {
|
| 139 |
+
int inp_sW = 1, inp_sH = inp_W, inp_sD = inp_W * inp_H, inp_sC = inp_W * inp_H * inp_D;
|
| 140 |
+
int out_sC = 1;
|
| 141 |
+
|
| 142 |
+
// normalize ix, iy, iz from [-1, 1] to [0, inp_W-1] & [0, inp_H-1] & [0, inp_D-1]
|
| 143 |
+
float ix = max(-10.f, min(10.f, ((pos.x + 1.f) * 0.5f))) * (inp_W - 1);
|
| 144 |
+
float iy = max(-10.f, min(10.f, ((pos.y + 1.f) * 0.5f))) * (inp_H - 1);
|
| 145 |
+
float iz = max(-10.f, min(10.f, ((pos.z + 1.f) * 0.5f))) * (inp_D - 1);
|
| 146 |
+
|
| 147 |
+
if (border) {
|
| 148 |
+
// clip coordinates to image borders
|
| 149 |
+
ix = clip_coordinates(ix, inp_W);
|
| 150 |
+
iy = clip_coordinates(iy, inp_H);
|
| 151 |
+
iz = clip_coordinates(iz, inp_D);
|
| 152 |
+
}
|
| 153 |
+
|
| 154 |
+
// get corner pixel values from (x, y, z)
|
| 155 |
+
// for 4d, we used north-east-south-west
|
| 156 |
+
// for 5d, we add top-bottom
|
| 157 |
+
int ix_tnw = static_cast<int>(::floor(ix));
|
| 158 |
+
int iy_tnw = static_cast<int>(::floor(iy));
|
| 159 |
+
int iz_tnw = static_cast<int>(::floor(iz));
|
| 160 |
+
|
| 161 |
+
int ix_tne = ix_tnw + 1;
|
| 162 |
+
int iy_tne = iy_tnw;
|
| 163 |
+
int iz_tne = iz_tnw;
|
| 164 |
+
|
| 165 |
+
int ix_tsw = ix_tnw;
|
| 166 |
+
int iy_tsw = iy_tnw + 1;
|
| 167 |
+
int iz_tsw = iz_tnw;
|
| 168 |
+
|
| 169 |
+
int ix_tse = ix_tnw + 1;
|
| 170 |
+
int iy_tse = iy_tnw + 1;
|
| 171 |
+
int iz_tse = iz_tnw;
|
| 172 |
+
|
| 173 |
+
int ix_bnw = ix_tnw;
|
| 174 |
+
int iy_bnw = iy_tnw;
|
| 175 |
+
int iz_bnw = iz_tnw + 1;
|
| 176 |
+
|
| 177 |
+
int ix_bne = ix_tnw + 1;
|
| 178 |
+
int iy_bne = iy_tnw;
|
| 179 |
+
int iz_bne = iz_tnw + 1;
|
| 180 |
+
|
| 181 |
+
int ix_bsw = ix_tnw;
|
| 182 |
+
int iy_bsw = iy_tnw + 1;
|
| 183 |
+
int iz_bsw = iz_tnw + 1;
|
| 184 |
+
|
| 185 |
+
int ix_bse = ix_tnw + 1;
|
| 186 |
+
int iy_bse = iy_tnw + 1;
|
| 187 |
+
int iz_bse = iz_tnw + 1;
|
| 188 |
+
|
| 189 |
+
// get surfaces to each neighbor:
|
| 190 |
+
float tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz);
|
| 191 |
+
float tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz);
|
| 192 |
+
float tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz);
|
| 193 |
+
float tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz);
|
| 194 |
+
float bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse);
|
| 195 |
+
float bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw);
|
| 196 |
+
float bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne);
|
| 197 |
+
float bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw);
|
| 198 |
+
|
| 199 |
+
out_t result;
|
| 200 |
+
//auto inp_ptr_NC = input.data + n * inp_sN;
|
| 201 |
+
//auto out_ptr_NCDHW = output.data + n * out_sN + d * out_sD + h * out_sH + w * out_sW;
|
| 202 |
+
float * inp_ptr_NC = vals;
|
| 203 |
+
float * out_ptr_NCDHW = &result.x;
|
| 204 |
+
for (int c = 0; c < C; ++c, inp_ptr_NC += inp_sC, out_ptr_NCDHW += out_sC) {
|
| 205 |
+
// (c, iz_tnw, iy_tnw, ix_tnw) * tnw + (c, iz_tne, iy_tne, ix_tne) * tne
|
| 206 |
+
// + (c, iz_tsw, iy_tsw, ix_tsw) * tsw + (c, iz_tse, iy_tse, ix_tse) * tse
|
| 207 |
+
// + (c, iz_bnw, iy_bnw, ix_bnw) * bnw + (c, iz_bne, iy_bne, ix_bne) * bne
|
| 208 |
+
// + (c, iz_bsw, iy_bsw, ix_bsw) * bsw + (c, iz_bse, iy_bse, ix_bse) * bse
|
| 209 |
+
*out_ptr_NCDHW = static_cast<float>(0);
|
| 210 |
+
if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) {
|
| 211 |
+
*out_ptr_NCDHW += inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW] * tnw;
|
| 212 |
+
}
|
| 213 |
+
if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) {
|
| 214 |
+
*out_ptr_NCDHW += inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW] * tne;
|
| 215 |
+
}
|
| 216 |
+
if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) {
|
| 217 |
+
*out_ptr_NCDHW += inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW] * tsw;
|
| 218 |
+
}
|
| 219 |
+
if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) {
|
| 220 |
+
*out_ptr_NCDHW += inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW] * tse;
|
| 221 |
+
}
|
| 222 |
+
if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) {
|
| 223 |
+
*out_ptr_NCDHW += inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW] * bnw;
|
| 224 |
+
}
|
| 225 |
+
if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) {
|
| 226 |
+
*out_ptr_NCDHW += inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW] * bne;
|
| 227 |
+
}
|
| 228 |
+
if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) {
|
| 229 |
+
*out_ptr_NCDHW += inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW] * bsw;
|
| 230 |
+
}
|
| 231 |
+
if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) {
|
| 232 |
+
*out_ptr_NCDHW += inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW] * bse;
|
| 233 |
+
}
|
| 234 |
+
}
|
| 235 |
+
return result;
|
| 236 |
+
}
|
| 237 |
+
|
| 238 |
+
template<typename out_t>
|
| 239 |
+
static __device__ float3 grid_sample_backward(int C, int inp_D, int inp_H,
|
| 240 |
+
int inp_W, float* vals, float* grad_vals, float3 pos, out_t grad_out,
|
| 241 |
+
bool border) {
|
| 242 |
+
int inp_sW = 1, inp_sH = inp_W, inp_sD = inp_W * inp_H, inp_sC = inp_W * inp_H * inp_D;
|
| 243 |
+
int gInp_sW = 1, gInp_sH = inp_W, gInp_sD = inp_W * inp_H, gInp_sC = inp_W * inp_H * inp_D;
|
| 244 |
+
int gOut_sC = 1;
|
| 245 |
+
|
| 246 |
+
// normalize ix, iy, iz from [-1, 1] to [0, inp_W-1] & [0, inp_H-1] & [0, inp_D-1]
|
| 247 |
+
float ix = max(-10.f, min(10.f, ((pos.x + 1.f) * 0.5f))) * (inp_W - 1);
|
| 248 |
+
float iy = max(-10.f, min(10.f, ((pos.y + 1.f) * 0.5f))) * (inp_H - 1);
|
| 249 |
+
float iz = max(-10.f, min(10.f, ((pos.z + 1.f) * 0.5f))) * (inp_D - 1);
|
| 250 |
+
|
| 251 |
+
float gix_mult = (inp_W - 1.f) / 2;
|
| 252 |
+
float giy_mult = (inp_H - 1.f) / 2;
|
| 253 |
+
float giz_mult = (inp_D - 1.f) / 2;
|
| 254 |
+
|
| 255 |
+
if (border) {
|
| 256 |
+
// clip coordinates to image borders
|
| 257 |
+
ix = clip_coordinates_set_grad(ix, inp_W, &gix_mult);
|
| 258 |
+
iy = clip_coordinates_set_grad(iy, inp_H, &giy_mult);
|
| 259 |
+
iz = clip_coordinates_set_grad(iz, inp_D, &giz_mult);
|
| 260 |
+
}
|
| 261 |
+
|
| 262 |
+
// get corner pixel values from (x, y, z)
|
| 263 |
+
// for 4d, we used north-east-south-west
|
| 264 |
+
// for 5d, we add top-bottom
|
| 265 |
+
int ix_tnw = static_cast<int>(::floor(ix));
|
| 266 |
+
int iy_tnw = static_cast<int>(::floor(iy));
|
| 267 |
+
int iz_tnw = static_cast<int>(::floor(iz));
|
| 268 |
+
|
| 269 |
+
int ix_tne = ix_tnw + 1;
|
| 270 |
+
int iy_tne = iy_tnw;
|
| 271 |
+
int iz_tne = iz_tnw;
|
| 272 |
+
|
| 273 |
+
int ix_tsw = ix_tnw;
|
| 274 |
+
int iy_tsw = iy_tnw + 1;
|
| 275 |
+
int iz_tsw = iz_tnw;
|
| 276 |
+
|
| 277 |
+
int ix_tse = ix_tnw + 1;
|
| 278 |
+
int iy_tse = iy_tnw + 1;
|
| 279 |
+
int iz_tse = iz_tnw;
|
| 280 |
+
|
| 281 |
+
int ix_bnw = ix_tnw;
|
| 282 |
+
int iy_bnw = iy_tnw;
|
| 283 |
+
int iz_bnw = iz_tnw + 1;
|
| 284 |
+
|
| 285 |
+
int ix_bne = ix_tnw + 1;
|
| 286 |
+
int iy_bne = iy_tnw;
|
| 287 |
+
int iz_bne = iz_tnw + 1;
|
| 288 |
+
|
| 289 |
+
int ix_bsw = ix_tnw;
|
| 290 |
+
int iy_bsw = iy_tnw + 1;
|
| 291 |
+
int iz_bsw = iz_tnw + 1;
|
| 292 |
+
|
| 293 |
+
int ix_bse = ix_tnw + 1;
|
| 294 |
+
int iy_bse = iy_tnw + 1;
|
| 295 |
+
int iz_bse = iz_tnw + 1;
|
| 296 |
+
|
| 297 |
+
// get surfaces to each neighbor:
|
| 298 |
+
float tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz);
|
| 299 |
+
float tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz);
|
| 300 |
+
float tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz);
|
| 301 |
+
float tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz);
|
| 302 |
+
float bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse);
|
| 303 |
+
float bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw);
|
| 304 |
+
float bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne);
|
| 305 |
+
float bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw);
|
| 306 |
+
|
| 307 |
+
float gix = static_cast<float>(0), giy = static_cast<float>(0), giz = static_cast<float>(0);
|
| 308 |
+
//float *gOut_ptr_NCDHW = grad_output.data + n * gOut_sN + d * gOut_sD + h * gOut_sH + w * gOut_sW;
|
| 309 |
+
//float *gInp_ptr_NC = grad_input.data + n * gInp_sN;
|
| 310 |
+
//float *inp_ptr_NC = input.data + n * inp_sN;
|
| 311 |
+
float *gOut_ptr_NCDHW = &grad_out.x;
|
| 312 |
+
float *gInp_ptr_NC = grad_vals;
|
| 313 |
+
float *inp_ptr_NC = vals;
|
| 314 |
+
// calculate bilinear weighted pixel value and set output pixel
|
| 315 |
+
for (int c = 0; c < C; ++c, gOut_ptr_NCDHW += gOut_sC, gInp_ptr_NC += gInp_sC, inp_ptr_NC += inp_sC) {
|
| 316 |
+
float gOut = *gOut_ptr_NCDHW;
|
| 317 |
+
|
| 318 |
+
// calculate and set grad_input
|
| 319 |
+
safe_add_3d(gInp_ptr_NC, iz_tnw, iy_tnw, ix_tnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tnw * gOut);
|
| 320 |
+
safe_add_3d(gInp_ptr_NC, iz_tne, iy_tne, ix_tne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tne * gOut);
|
| 321 |
+
safe_add_3d(gInp_ptr_NC, iz_tsw, iy_tsw, ix_tsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tsw * gOut);
|
| 322 |
+
safe_add_3d(gInp_ptr_NC, iz_tse, iy_tse, ix_tse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tse * gOut);
|
| 323 |
+
safe_add_3d(gInp_ptr_NC, iz_bnw, iy_bnw, ix_bnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bnw * gOut);
|
| 324 |
+
safe_add_3d(gInp_ptr_NC, iz_bne, iy_bne, ix_bne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bne * gOut);
|
| 325 |
+
safe_add_3d(gInp_ptr_NC, iz_bsw, iy_bsw, ix_bsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bsw * gOut);
|
| 326 |
+
safe_add_3d(gInp_ptr_NC, iz_bse, iy_bse, ix_bse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bse * gOut);
|
| 327 |
+
|
| 328 |
+
// calculate grad_grid
|
| 329 |
+
if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) {
|
| 330 |
+
float tnw_val = inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW];
|
| 331 |
+
gix -= tnw_val * (iy_bse - iy) * (iz_bse - iz) * gOut;
|
| 332 |
+
giy -= tnw_val * (ix_bse - ix) * (iz_bse - iz) * gOut;
|
| 333 |
+
giz -= tnw_val * (ix_bse - ix) * (iy_bse - iy) * gOut;
|
| 334 |
+
}
|
| 335 |
+
if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) {
|
| 336 |
+
float tne_val = inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW];
|
| 337 |
+
gix += tne_val * (iy_bsw - iy) * (iz_bsw - iz) * gOut;
|
| 338 |
+
giy -= tne_val * (ix - ix_bsw) * (iz_bsw - iz) * gOut;
|
| 339 |
+
giz -= tne_val * (ix - ix_bsw) * (iy_bsw - iy) * gOut;
|
| 340 |
+
}
|
| 341 |
+
if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) {
|
| 342 |
+
float tsw_val = inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW];
|
| 343 |
+
gix -= tsw_val * (iy - iy_bne) * (iz_bne - iz) * gOut;
|
| 344 |
+
giy += tsw_val * (ix_bne - ix) * (iz_bne - iz) * gOut;
|
| 345 |
+
giz -= tsw_val * (ix_bne - ix) * (iy - iy_bne) * gOut;
|
| 346 |
+
}
|
| 347 |
+
if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) {
|
| 348 |
+
float tse_val = inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW];
|
| 349 |
+
gix += tse_val * (iy - iy_bnw) * (iz_bnw - iz) * gOut;
|
| 350 |
+
giy += tse_val * (ix - ix_bnw) * (iz_bnw - iz) * gOut;
|
| 351 |
+
giz -= tse_val * (ix - ix_bnw) * (iy - iy_bnw) * gOut;
|
| 352 |
+
}
|
| 353 |
+
if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) {
|
| 354 |
+
float bnw_val = inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW];
|
| 355 |
+
gix -= bnw_val * (iy_tse - iy) * (iz - iz_tse) * gOut;
|
| 356 |
+
giy -= bnw_val * (ix_tse - ix) * (iz - iz_tse) * gOut;
|
| 357 |
+
giz += bnw_val * (ix_tse - ix) * (iy_tse - iy) * gOut;
|
| 358 |
+
}
|
| 359 |
+
if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) {
|
| 360 |
+
float bne_val = inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW];
|
| 361 |
+
gix += bne_val * (iy_tsw - iy) * (iz - iz_tsw) * gOut;
|
| 362 |
+
giy -= bne_val * (ix - ix_tsw) * (iz - iz_tsw) * gOut;
|
| 363 |
+
giz += bne_val * (ix - ix_tsw) * (iy_tsw - iy) * gOut;
|
| 364 |
+
}
|
| 365 |
+
if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) {
|
| 366 |
+
float bsw_val = inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW];
|
| 367 |
+
gix -= bsw_val * (iy - iy_tne) * (iz - iz_tne) * gOut;
|
| 368 |
+
giy += bsw_val * (ix_tne - ix) * (iz - iz_tne) * gOut;
|
| 369 |
+
giz += bsw_val * (ix_tne - ix) * (iy - iy_tne) * gOut;
|
| 370 |
+
}
|
| 371 |
+
if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) {
|
| 372 |
+
float bse_val = inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW];
|
| 373 |
+
gix += bse_val * (iy - iy_tnw) * (iz - iz_tnw) * gOut;
|
| 374 |
+
giy += bse_val * (ix - ix_tnw) * (iz - iz_tnw) * gOut;
|
| 375 |
+
giz += bse_val * (ix - ix_tnw) * (iy - iy_tnw) * gOut;
|
| 376 |
+
}
|
| 377 |
+
}
|
| 378 |
+
|
| 379 |
+
return make_float3(gix_mult * gix, giy_mult * giy, giz_mult * giz);
|
| 380 |
+
}
|
| 381 |
+
|
| 382 |
+
// this dummy struct necessary because c++ is dumb
|
| 383 |
+
template<typename out_t>
|
| 384 |
+
struct GridSampler {
|
| 385 |
+
static __forceinline__ __device__ out_t forward(int C, int inp_D, int inp_H, int inp_W,
|
| 386 |
+
float* vals, float3 pos, bool border) {
|
| 387 |
+
return grid_sample_forward<out_t>(C, inp_D, inp_H, inp_W, vals, pos, border);
|
| 388 |
+
}
|
| 389 |
+
|
| 390 |
+
static __forceinline__ __device__ float3 backward(int C, int inp_D, int inp_H, int inp_W,
|
| 391 |
+
float* vals, float* grad_vals, float3 pos, out_t grad_out, bool border) {
|
| 392 |
+
return grid_sample_backward<out_t>(C, inp_D, inp_H, inp_W, vals, grad_vals, pos, grad_out, border);
|
| 393 |
+
}
|
| 394 |
+
};
|
| 395 |
+
|
| 396 |
+
//template <typename T>
|
| 397 |
+
//__device__ void cswap ( T& a, T& b ) {
|
| 398 |
+
// T c(a); a=b; b=c;
|
| 399 |
+
//}
|
| 400 |
+
|
| 401 |
+
static __forceinline__ __device__
|
| 402 |
+
int within_bounds_3d_ind(int d, int h, int w, int D, int H, int W) {
|
| 403 |
+
return d >= 0 && d < D && h >= 0 && h < H && w >= 0 && w < W ? ((d * H) + h) * W + w : -1;
|
| 404 |
+
}
|
| 405 |
+
|
| 406 |
+
template<class out_t>
|
| 407 |
+
static __device__ out_t grid_sample_chlast_forward(int, int inp_D, int inp_H,
|
| 408 |
+
int inp_W, float * vals, float3 pos, bool border) {
|
| 409 |
+
int inp_sW = 1, inp_sH = inp_W, inp_sD = inp_W * inp_H;
|
| 410 |
+
|
| 411 |
+
// normalize ix, iy, iz from [-1, 1] to [0, inp_W-1] & [0, inp_H-1] & [0, inp_D-1]
|
| 412 |
+
float ix = max(-100.f, min(100.f, ((pos.x + 1.f) / 2))) * (inp_W - 1);
|
| 413 |
+
float iy = max(-100.f, min(100.f, ((pos.y + 1.f) / 2))) * (inp_H - 1);
|
| 414 |
+
float iz = max(-100.f, min(100.f, ((pos.z + 1.f) / 2))) * (inp_D - 1);
|
| 415 |
+
|
| 416 |
+
if (border) {
|
| 417 |
+
// clip coordinates to image borders
|
| 418 |
+
ix = clip_coordinates(ix, inp_W);
|
| 419 |
+
iy = clip_coordinates(iy, inp_H);
|
| 420 |
+
iz = clip_coordinates(iz, inp_D);
|
| 421 |
+
}
|
| 422 |
+
|
| 423 |
+
// get corner pixel values from (x, y, z)
|
| 424 |
+
// for 4d, we used north-east-south-west
|
| 425 |
+
// for 5d, we add top-bottom
|
| 426 |
+
int ix_tnw = static_cast<int>(::floor(ix));
|
| 427 |
+
int iy_tnw = static_cast<int>(::floor(iy));
|
| 428 |
+
int iz_tnw = static_cast<int>(::floor(iz));
|
| 429 |
+
|
| 430 |
+
int ix_tne = ix_tnw + 1;
|
| 431 |
+
int iy_tne = iy_tnw;
|
| 432 |
+
int iz_tne = iz_tnw;
|
| 433 |
+
|
| 434 |
+
int ix_tsw = ix_tnw;
|
| 435 |
+
int iy_tsw = iy_tnw + 1;
|
| 436 |
+
int iz_tsw = iz_tnw;
|
| 437 |
+
|
| 438 |
+
int ix_tse = ix_tnw + 1;
|
| 439 |
+
int iy_tse = iy_tnw + 1;
|
| 440 |
+
int iz_tse = iz_tnw;
|
| 441 |
+
|
| 442 |
+
int ix_bnw = ix_tnw;
|
| 443 |
+
int iy_bnw = iy_tnw;
|
| 444 |
+
int iz_bnw = iz_tnw + 1;
|
| 445 |
+
|
| 446 |
+
int ix_bne = ix_tnw + 1;
|
| 447 |
+
int iy_bne = iy_tnw;
|
| 448 |
+
int iz_bne = iz_tnw + 1;
|
| 449 |
+
|
| 450 |
+
int ix_bsw = ix_tnw;
|
| 451 |
+
int iy_bsw = iy_tnw + 1;
|
| 452 |
+
int iz_bsw = iz_tnw + 1;
|
| 453 |
+
|
| 454 |
+
int ix_bse = ix_tnw + 1;
|
| 455 |
+
int iy_bse = iy_tnw + 1;
|
| 456 |
+
int iz_bse = iz_tnw + 1;
|
| 457 |
+
|
| 458 |
+
// get surfaces to each neighbor:
|
| 459 |
+
float tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz);
|
| 460 |
+
float tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz);
|
| 461 |
+
float tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz);
|
| 462 |
+
float tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz);
|
| 463 |
+
float bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse);
|
| 464 |
+
float bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw);
|
| 465 |
+
float bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne);
|
| 466 |
+
float bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw);
|
| 467 |
+
|
| 468 |
+
out_t result;
|
| 469 |
+
memset(&result, 0, sizeof(out_t));
|
| 470 |
+
out_t * inp_ptr_NC = (out_t*)vals;
|
| 471 |
+
out_t * out_ptr_NCDHW = &result;
|
| 472 |
+
{
|
| 473 |
+
if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) {
|
| 474 |
+
*out_ptr_NCDHW += inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW] * tnw;
|
| 475 |
+
}
|
| 476 |
+
if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) {
|
| 477 |
+
*out_ptr_NCDHW += inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW] * tne;
|
| 478 |
+
}
|
| 479 |
+
if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) {
|
| 480 |
+
*out_ptr_NCDHW += inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW] * tsw;
|
| 481 |
+
}
|
| 482 |
+
if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) {
|
| 483 |
+
*out_ptr_NCDHW += inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW] * tse;
|
| 484 |
+
}
|
| 485 |
+
if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) {
|
| 486 |
+
*out_ptr_NCDHW += inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW] * bnw;
|
| 487 |
+
}
|
| 488 |
+
if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) {
|
| 489 |
+
*out_ptr_NCDHW += inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW] * bne;
|
| 490 |
+
}
|
| 491 |
+
if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) {
|
| 492 |
+
*out_ptr_NCDHW += inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW] * bsw;
|
| 493 |
+
}
|
| 494 |
+
if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) {
|
| 495 |
+
*out_ptr_NCDHW += inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW] * bse;
|
| 496 |
+
}
|
| 497 |
+
}
|
| 498 |
+
|
| 499 |
+
return result;
|
| 500 |
+
}
|
| 501 |
+
|
| 502 |
+
template<typename out_t>
|
| 503 |
+
static __device__ float3 grid_sample_chlast_backward(int, int inp_D, int inp_H,
|
| 504 |
+
int inp_W, float* vals, float* grad_vals, float3 pos, out_t grad_out,
|
| 505 |
+
bool border) {
|
| 506 |
+
int inp_sW = 1, inp_sH = inp_W, inp_sD = inp_W * inp_H;
|
| 507 |
+
int gInp_sW = 1, gInp_sH = inp_W, gInp_sD = inp_W * inp_H;
|
| 508 |
+
|
| 509 |
+
// normalize ix, iy, iz from [-1, 1] to [0, inp_W-1] & [0, inp_H-1] & [0, inp_D-1]
|
| 510 |
+
float ix = max(-100.f, min(100.f, ((pos.x + 1.f) / 2))) * (inp_W - 1);
|
| 511 |
+
float iy = max(-100.f, min(100.f, ((pos.y + 1.f) / 2))) * (inp_H - 1);
|
| 512 |
+
float iz = max(-100.f, min(100.f, ((pos.z + 1.f) / 2))) * (inp_D - 1);
|
| 513 |
+
|
| 514 |
+
float gix_mult = (inp_W - 1.f) / 2;
|
| 515 |
+
float giy_mult = (inp_H - 1.f) / 2;
|
| 516 |
+
float giz_mult = (inp_D - 1.f) / 2;
|
| 517 |
+
|
| 518 |
+
if (border) {
|
| 519 |
+
// clip coordinates to image borders
|
| 520 |
+
ix = clip_coordinates_set_grad(ix, inp_W, &gix_mult);
|
| 521 |
+
iy = clip_coordinates_set_grad(iy, inp_H, &giy_mult);
|
| 522 |
+
iz = clip_coordinates_set_grad(iz, inp_D, &giz_mult);
|
| 523 |
+
}
|
| 524 |
+
|
| 525 |
+
// get corner pixel values from (x, y, z)
|
| 526 |
+
// for 4d, we used north-east-south-west
|
| 527 |
+
// for 5d, we add top-bottom
|
| 528 |
+
int ix_tnw = static_cast<int>(::floor(ix));
|
| 529 |
+
int iy_tnw = static_cast<int>(::floor(iy));
|
| 530 |
+
int iz_tnw = static_cast<int>(::floor(iz));
|
| 531 |
+
|
| 532 |
+
int ix_tne = ix_tnw + 1;
|
| 533 |
+
int iy_tne = iy_tnw;
|
| 534 |
+
int iz_tne = iz_tnw;
|
| 535 |
+
|
| 536 |
+
int ix_tsw = ix_tnw;
|
| 537 |
+
int iy_tsw = iy_tnw + 1;
|
| 538 |
+
int iz_tsw = iz_tnw;
|
| 539 |
+
|
| 540 |
+
int ix_tse = ix_tnw + 1;
|
| 541 |
+
int iy_tse = iy_tnw + 1;
|
| 542 |
+
int iz_tse = iz_tnw;
|
| 543 |
+
|
| 544 |
+
int ix_bnw = ix_tnw;
|
| 545 |
+
int iy_bnw = iy_tnw;
|
| 546 |
+
int iz_bnw = iz_tnw + 1;
|
| 547 |
+
|
| 548 |
+
int ix_bne = ix_tnw + 1;
|
| 549 |
+
int iy_bne = iy_tnw;
|
| 550 |
+
int iz_bne = iz_tnw + 1;
|
| 551 |
+
|
| 552 |
+
int ix_bsw = ix_tnw;
|
| 553 |
+
int iy_bsw = iy_tnw + 1;
|
| 554 |
+
int iz_bsw = iz_tnw + 1;
|
| 555 |
+
|
| 556 |
+
int ix_bse = ix_tnw + 1;
|
| 557 |
+
int iy_bse = iy_tnw + 1;
|
| 558 |
+
int iz_bse = iz_tnw + 1;
|
| 559 |
+
|
| 560 |
+
// get surfaces to each neighbor:
|
| 561 |
+
float tnw = (ix_bse - ix) * (iy_bse - iy) * (iz_bse - iz);
|
| 562 |
+
float tne = (ix - ix_bsw) * (iy_bsw - iy) * (iz_bsw - iz);
|
| 563 |
+
float tsw = (ix_bne - ix) * (iy - iy_bne) * (iz_bne - iz);
|
| 564 |
+
float tse = (ix - ix_bnw) * (iy - iy_bnw) * (iz_bnw - iz);
|
| 565 |
+
float bnw = (ix_tse - ix) * (iy_tse - iy) * (iz - iz_tse);
|
| 566 |
+
float bne = (ix - ix_tsw) * (iy_tsw - iy) * (iz - iz_tsw);
|
| 567 |
+
float bsw = (ix_tne - ix) * (iy - iy_tne) * (iz - iz_tne);
|
| 568 |
+
float bse = (ix - ix_tnw) * (iy - iy_tnw) * (iz - iz_tnw);
|
| 569 |
+
|
| 570 |
+
float gix = static_cast<float>(0), giy = static_cast<float>(0), giz = static_cast<float>(0);
|
| 571 |
+
out_t *gOut_ptr_NCDHW = &grad_out;
|
| 572 |
+
out_t *gInp_ptr_NC = (out_t*)grad_vals;
|
| 573 |
+
out_t *inp_ptr_NC = (out_t*)vals;
|
| 574 |
+
|
| 575 |
+
// calculate bilinear weighted pixel value and set output pixel
|
| 576 |
+
{
|
| 577 |
+
out_t gOut = *gOut_ptr_NCDHW;
|
| 578 |
+
|
| 579 |
+
// calculate and set grad_input
|
| 580 |
+
safe_add_3d(gInp_ptr_NC, iz_tnw, iy_tnw, ix_tnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tnw * gOut);
|
| 581 |
+
safe_add_3d(gInp_ptr_NC, iz_tne, iy_tne, ix_tne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tne * gOut);
|
| 582 |
+
safe_add_3d(gInp_ptr_NC, iz_tsw, iy_tsw, ix_tsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tsw * gOut);
|
| 583 |
+
safe_add_3d(gInp_ptr_NC, iz_tse, iy_tse, ix_tse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, tse * gOut);
|
| 584 |
+
safe_add_3d(gInp_ptr_NC, iz_bnw, iy_bnw, ix_bnw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bnw * gOut);
|
| 585 |
+
safe_add_3d(gInp_ptr_NC, iz_bne, iy_bne, ix_bne, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bne * gOut);
|
| 586 |
+
safe_add_3d(gInp_ptr_NC, iz_bsw, iy_bsw, ix_bsw, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bsw * gOut);
|
| 587 |
+
safe_add_3d(gInp_ptr_NC, iz_bse, iy_bse, ix_bse, gInp_sD, gInp_sH, gInp_sW, inp_D, inp_H, inp_W, bse * gOut);
|
| 588 |
+
|
| 589 |
+
// calculate grad_grid
|
| 590 |
+
if (within_bounds_3d(iz_tnw, iy_tnw, ix_tnw, inp_D, inp_H, inp_W)) {
|
| 591 |
+
out_t tnw_val = inp_ptr_NC[iz_tnw * inp_sD + iy_tnw * inp_sH + ix_tnw * inp_sW];
|
| 592 |
+
gix -= (iy_bse - iy) * (iz_bse - iz) * dot(tnw_val, gOut);
|
| 593 |
+
giy -= (ix_bse - ix) * (iz_bse - iz) * dot(tnw_val, gOut);
|
| 594 |
+
giz -= (ix_bse - ix) * (iy_bse - iy) * dot(tnw_val, gOut);
|
| 595 |
+
}
|
| 596 |
+
if (within_bounds_3d(iz_tne, iy_tne, ix_tne, inp_D, inp_H, inp_W)) {
|
| 597 |
+
out_t tne_val = inp_ptr_NC[iz_tne * inp_sD + iy_tne * inp_sH + ix_tne * inp_sW];
|
| 598 |
+
gix += (iy_bsw - iy) * (iz_bsw - iz) * dot(tne_val, gOut);
|
| 599 |
+
giy -= (ix - ix_bsw) * (iz_bsw - iz) * dot(tne_val, gOut);
|
| 600 |
+
giz -= (ix - ix_bsw) * (iy_bsw - iy) * dot(tne_val, gOut);
|
| 601 |
+
}
|
| 602 |
+
if (within_bounds_3d(iz_tsw, iy_tsw, ix_tsw, inp_D, inp_H, inp_W)) {
|
| 603 |
+
out_t tsw_val = inp_ptr_NC[iz_tsw * inp_sD + iy_tsw * inp_sH + ix_tsw * inp_sW];
|
| 604 |
+
gix -= (iy - iy_bne) * (iz_bne - iz) * dot(tsw_val, gOut);
|
| 605 |
+
giy += (ix_bne - ix) * (iz_bne - iz) * dot(tsw_val, gOut);
|
| 606 |
+
giz -= (ix_bne - ix) * (iy - iy_bne) * dot(tsw_val, gOut);
|
| 607 |
+
}
|
| 608 |
+
if (within_bounds_3d(iz_tse, iy_tse, ix_tse, inp_D, inp_H, inp_W)) {
|
| 609 |
+
out_t tse_val = inp_ptr_NC[iz_tse * inp_sD + iy_tse * inp_sH + ix_tse * inp_sW];
|
| 610 |
+
gix += (iy - iy_bnw) * (iz_bnw - iz) * dot(tse_val, gOut);
|
| 611 |
+
giy += (ix - ix_bnw) * (iz_bnw - iz) * dot(tse_val, gOut);
|
| 612 |
+
giz -= (ix - ix_bnw) * (iy - iy_bnw) * dot(tse_val, gOut);
|
| 613 |
+
}
|
| 614 |
+
if (within_bounds_3d(iz_bnw, iy_bnw, ix_bnw, inp_D, inp_H, inp_W)) {
|
| 615 |
+
out_t bnw_val = inp_ptr_NC[iz_bnw * inp_sD + iy_bnw * inp_sH + ix_bnw * inp_sW];
|
| 616 |
+
gix -= (iy_tse - iy) * (iz - iz_tse) * dot(bnw_val, gOut);
|
| 617 |
+
giy -= (ix_tse - ix) * (iz - iz_tse) * dot(bnw_val, gOut);
|
| 618 |
+
giz += (ix_tse - ix) * (iy_tse - iy) * dot(bnw_val, gOut);
|
| 619 |
+
}
|
| 620 |
+
if (within_bounds_3d(iz_bne, iy_bne, ix_bne, inp_D, inp_H, inp_W)) {
|
| 621 |
+
out_t bne_val = inp_ptr_NC[iz_bne * inp_sD + iy_bne * inp_sH + ix_bne * inp_sW];
|
| 622 |
+
gix += (iy_tsw - iy) * (iz - iz_tsw) * dot(bne_val, gOut);
|
| 623 |
+
giy -= (ix - ix_tsw) * (iz - iz_tsw) * dot(bne_val, gOut);
|
| 624 |
+
giz += (ix - ix_tsw) * (iy_tsw - iy) * dot(bne_val, gOut);
|
| 625 |
+
}
|
| 626 |
+
if (within_bounds_3d(iz_bsw, iy_bsw, ix_bsw, inp_D, inp_H, inp_W)) {
|
| 627 |
+
out_t bsw_val = inp_ptr_NC[iz_bsw * inp_sD + iy_bsw * inp_sH + ix_bsw * inp_sW];
|
| 628 |
+
gix -= (iy - iy_tne) * (iz - iz_tne) * dot(bsw_val, gOut);
|
| 629 |
+
giy += (ix_tne - ix) * (iz - iz_tne) * dot(bsw_val, gOut);
|
| 630 |
+
giz += (ix_tne - ix) * (iy - iy_tne) * dot(bsw_val, gOut);
|
| 631 |
+
}
|
| 632 |
+
if (within_bounds_3d(iz_bse, iy_bse, ix_bse, inp_D, inp_H, inp_W)) {
|
| 633 |
+
out_t bse_val = inp_ptr_NC[iz_bse * inp_sD + iy_bse * inp_sH + ix_bse * inp_sW];
|
| 634 |
+
gix += (iy - iy_tnw) * (iz - iz_tnw) * dot(bse_val, gOut);
|
| 635 |
+
giy += (ix - ix_tnw) * (iz - iz_tnw) * dot(bse_val, gOut);
|
| 636 |
+
giz += (ix - ix_tnw) * (iy - iy_tnw) * dot(bse_val, gOut);
|
| 637 |
+
}
|
| 638 |
+
}
|
| 639 |
+
|
| 640 |
+
return make_float3(gix_mult * gix, giy_mult * giy, giz_mult * giz);
|
| 641 |
+
}
|
| 642 |
+
|
| 643 |
+
template<typename out_t>
|
| 644 |
+
struct GridSamplerChlast {
|
| 645 |
+
static __forceinline__ __device__ out_t forward(int C, int inp_D, int inp_H, int inp_W,
|
| 646 |
+
float* vals, float3 pos, bool border) {
|
| 647 |
+
return grid_sample_chlast_forward<out_t>(C, inp_D, inp_H, inp_W, vals, pos, border);
|
| 648 |
+
}
|
| 649 |
+
|
| 650 |
+
static __forceinline__ __device__ float3 backward(int C, int inp_D, int inp_H, int inp_W,
|
| 651 |
+
float* vals, float* grad_vals, float3 pos, out_t grad_out, bool border) {
|
| 652 |
+
return grid_sample_chlast_backward<out_t>(C, inp_D, inp_H, inp_W, vals, grad_vals, pos, grad_out, border);
|
| 653 |
+
}
|
| 654 |
+
};
|
| 655 |
+
|
| 656 |
+
|
| 657 |
+
inline __host__ __device__ float min_component(float3 a) {
|
| 658 |
+
return fminf(fminf(a.x,a.y),a.z);
|
| 659 |
+
}
|
| 660 |
+
|
| 661 |
+
inline __host__ __device__ float max_component(float3 a) {
|
| 662 |
+
return fmaxf(fmaxf(a.x,a.y),a.z);
|
| 663 |
+
}
|
| 664 |
+
|
| 665 |
+
inline __host__ __device__ float3 abs(float3 a) {
|
| 666 |
+
return make_float3(abs(a.x), abs(a.y), abs(a.z));
|
| 667 |
+
}
|
| 668 |
+
|
| 669 |
+
__forceinline__ __device__ bool ray_aabb_hit(float3 p0, float3 p1, float3 raypos, float3 raydir) {
|
| 670 |
+
float3 t0 = (p0 - raypos) / raydir;
|
| 671 |
+
float3 t1 = (p1 - raypos) / raydir;
|
| 672 |
+
float3 tmin = fminf(t0,t1), tmax = fmaxf(t0,t1);
|
| 673 |
+
|
| 674 |
+
return max_component(tmin) <= min_component(tmax);
|
| 675 |
+
}
|
| 676 |
+
|
| 677 |
+
__forceinline__ __device__ bool ray_aabb_hit_ird(float3 p0, float3 p1, float3 raypos, float3 ird) {
|
| 678 |
+
float3 t0 = (p0 - raypos) * ird;
|
| 679 |
+
float3 t1 = (p1 - raypos) * ird;
|
| 680 |
+
float3 tmin = fminf(t0,t1), tmax = fmaxf(t0,t1);
|
| 681 |
+
|
| 682 |
+
return max_component(tmin) <= min_component(tmax);
|
| 683 |
+
|
| 684 |
+
}
|
| 685 |
+
__forceinline__ __device__ void ray_aabb_hit_ird_tminmax(float3 p0, float3 p1,
|
| 686 |
+
float3 raypos, float3 ird, float &otmin, float &otmax) {
|
| 687 |
+
float3 t0 = (p0 - raypos) * ird;
|
| 688 |
+
float3 t1 = (p1 - raypos) * ird;
|
| 689 |
+
float3 tmin = fminf(t0,t1), tmax = fmaxf(t0,t1);
|
| 690 |
+
tmin = fminf(t0,t1);
|
| 691 |
+
tmax = fmaxf(t0,t1);
|
| 692 |
+
otmin = max_component(tmin);
|
| 693 |
+
otmax = min_component(tmax);
|
| 694 |
+
}
|
| 695 |
+
|
| 696 |
+
inline __device__ bool aabb_intersect(float3 p0, float3 p1, float3 r0, float3 rd, float &tmin, float &tmax) {
|
| 697 |
+
float tymin, tymax, tzmin, tzmax;
|
| 698 |
+
const float3 bounds[2] = {p0, p1};
|
| 699 |
+
float3 ird = 1.0f/rd;
|
| 700 |
+
int sx = (ird.x<0) ? 1 : 0;
|
| 701 |
+
int sy = (ird.y<0) ? 1 : 0;
|
| 702 |
+
int sz = (ird.z<0) ? 1 : 0;
|
| 703 |
+
tmin = (bounds[sx].x - r0.x) * ird.x;
|
| 704 |
+
tmax = (bounds[1-sx].x - r0.x) * ird.x;
|
| 705 |
+
tymin = (bounds[sy].y - r0.y) * ird.y;
|
| 706 |
+
tymax = (bounds[1-sy].y - r0.y) * ird.y;
|
| 707 |
+
|
| 708 |
+
if ((tmin > tymax) || (tymin > tmax))
|
| 709 |
+
return false;
|
| 710 |
+
if (tymin > tmin)
|
| 711 |
+
tmin = tymin;
|
| 712 |
+
if (tymax < tmax)
|
| 713 |
+
tmax = tymax;
|
| 714 |
+
|
| 715 |
+
tzmin = (bounds[sz].z - r0.z) * ird.z;
|
| 716 |
+
tzmax = (bounds[1-sz].z - r0.z) * ird.z;
|
| 717 |
+
|
| 718 |
+
if ((tmin > tzmax) || (tzmin > tmax))
|
| 719 |
+
return false;
|
| 720 |
+
if (tzmin > tmin)
|
| 721 |
+
tmin = tzmin;
|
| 722 |
+
if (tzmax < tmax)
|
| 723 |
+
tmax = tzmax;
|
| 724 |
+
|
| 725 |
+
return true;
|
| 726 |
+
}
|
| 727 |
+
|
| 728 |
+
template<bool sortboxes, int maxhitboxes, bool sync, class PrimTransfT>
|
| 729 |
+
static __forceinline__ __device__ void ray_subset_fixedbvh(
|
| 730 |
+
unsigned warpmask,
|
| 731 |
+
int K,
|
| 732 |
+
float3 raypos,
|
| 733 |
+
float3 raydir,
|
| 734 |
+
float2 tminmax,
|
| 735 |
+
float2 &rtminmax,
|
| 736 |
+
int * sortedobjid,
|
| 737 |
+
int2 * nodechildren,
|
| 738 |
+
float3 * nodeaabb,
|
| 739 |
+
const typename PrimTransfT::Data & primtransf_data,
|
| 740 |
+
int *hitboxes,
|
| 741 |
+
int & num) {
|
| 742 |
+
float3 iraydir = 1.0f/raydir;
|
| 743 |
+
int stack[64];
|
| 744 |
+
int* stack_ptr = stack;
|
| 745 |
+
*stack_ptr++ = -1;
|
| 746 |
+
int node = 0;
|
| 747 |
+
do {
|
| 748 |
+
// check if we're in a leaf
|
| 749 |
+
if (node >= (K - 1)) {
|
| 750 |
+
{
|
| 751 |
+
int k = node - (K - 1);
|
| 752 |
+
|
| 753 |
+
float3 r0, rd;
|
| 754 |
+
PrimTransfT::forward2(primtransf_data, k, raypos, raydir, r0, rd);
|
| 755 |
+
|
| 756 |
+
float3 ird = 1.0f/rd;
|
| 757 |
+
float3 t0 = (-1.f - r0) * ird;
|
| 758 |
+
float3 t1 = (1.f - r0) * ird;
|
| 759 |
+
float3 tmin = fminf(t0,t1), tmax = fmaxf(t0,t1);
|
| 760 |
+
|
| 761 |
+
float trmin = max_component(tmin);
|
| 762 |
+
float trmax = min_component(tmax);
|
| 763 |
+
|
| 764 |
+
bool intersection = trmin <= trmax;
|
| 765 |
+
|
| 766 |
+
if (intersection) {
|
| 767 |
+
// hit
|
| 768 |
+
rtminmax.x = fminf(rtminmax.x, trmin);
|
| 769 |
+
rtminmax.y = fmaxf(rtminmax.y, trmax);
|
| 770 |
+
}
|
| 771 |
+
|
| 772 |
+
if (sync) {
|
| 773 |
+
intersection = __any_sync(warpmask, intersection);
|
| 774 |
+
}
|
| 775 |
+
|
| 776 |
+
if (intersection) {
|
| 777 |
+
if (sortboxes) {
|
| 778 |
+
if (num < maxhitboxes) {
|
| 779 |
+
int j = num - 1;
|
| 780 |
+
while (j >= 0 && hitboxes[j] > k) {
|
| 781 |
+
hitboxes[j + 1] = hitboxes[j];
|
| 782 |
+
j = j - 1;
|
| 783 |
+
}
|
| 784 |
+
hitboxes[j + 1] = k;
|
| 785 |
+
num++;
|
| 786 |
+
}
|
| 787 |
+
} else {
|
| 788 |
+
if (num < maxhitboxes) {
|
| 789 |
+
hitboxes[num++] = k;
|
| 790 |
+
}
|
| 791 |
+
}
|
| 792 |
+
}
|
| 793 |
+
}
|
| 794 |
+
|
| 795 |
+
node = *--stack_ptr;
|
| 796 |
+
} else {
|
| 797 |
+
int2 children = make_int2(node * 2 + 1, node * 2 + 2);
|
| 798 |
+
|
| 799 |
+
// check if we're in each child's bbox
|
| 800 |
+
float3 * nodeaabb_ptr = nodeaabb + children.x * 2;
|
| 801 |
+
bool traverse_l = ray_aabb_hit_ird(nodeaabb_ptr[0], nodeaabb_ptr[1], raypos, iraydir);
|
| 802 |
+
bool traverse_r = ray_aabb_hit_ird(nodeaabb_ptr[2], nodeaabb_ptr[3], raypos, iraydir);
|
| 803 |
+
|
| 804 |
+
if (sync) {
|
| 805 |
+
traverse_l = __any_sync(warpmask, traverse_l);
|
| 806 |
+
traverse_r = __any_sync(warpmask, traverse_r);
|
| 807 |
+
}
|
| 808 |
+
|
| 809 |
+
// update stack
|
| 810 |
+
if (!traverse_l && !traverse_r) {
|
| 811 |
+
node = *--stack_ptr;
|
| 812 |
+
} else {
|
| 813 |
+
node = traverse_l ? children.x : children.y;
|
| 814 |
+
if (traverse_l && traverse_r) {
|
| 815 |
+
*stack_ptr++ = children.y;
|
| 816 |
+
}
|
| 817 |
+
}
|
| 818 |
+
|
| 819 |
+
if (sync) {
|
| 820 |
+
__syncwarp(warpmask);
|
| 821 |
+
}
|
| 822 |
+
}
|
| 823 |
+
} while (node != -1);
|
| 824 |
+
}
|
| 825 |
+
|
| 826 |
+
template<bool sortboxes, int maxhitboxes, bool sync, class PrimTransfT>
|
| 827 |
+
struct RaySubsetFixedBVH {
|
| 828 |
+
static __forceinline__ __device__ void forward(
|
| 829 |
+
unsigned warpmask,
|
| 830 |
+
int K,
|
| 831 |
+
float3 raypos,
|
| 832 |
+
float3 raydir,
|
| 833 |
+
float2 tminmax,
|
| 834 |
+
float2 &rtminmax,
|
| 835 |
+
int * sortedobjid,
|
| 836 |
+
int2 * nodechildren,
|
| 837 |
+
float3 * nodeaabb,
|
| 838 |
+
const typename PrimTransfT::Data & primtransf_data,
|
| 839 |
+
int *hitboxes,
|
| 840 |
+
int & num) {
|
| 841 |
+
ray_subset_fixedbvh<sortboxes, maxhitboxes, sync, PrimTransfT>(
|
| 842 |
+
warpmask, K, raypos, raydir, tminmax, rtminmax,
|
| 843 |
+
sortedobjid, nodechildren, nodeaabb, primtransf_data, hitboxes, num);
|
| 844 |
+
}
|
| 845 |
+
};
|
| 846 |
+
|
| 847 |
+
#endif
|
dva/mvp/extensions/utils/helper_math.h
ADDED
|
@@ -0,0 +1,1453 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/**
|
| 2 |
+
* Copyright 1993-2013 NVIDIA Corporation. All rights reserved.
|
| 3 |
+
*
|
| 4 |
+
* Please refer to the NVIDIA end user license agreement (EULA) associated
|
| 5 |
+
* with this source code for terms and conditions that govern your use of
|
| 6 |
+
* this software. Any use, reproduction, disclosure, or distribution of
|
| 7 |
+
* this software and related documentation outside the terms of the EULA
|
| 8 |
+
* is strictly prohibited.
|
| 9 |
+
*
|
| 10 |
+
*/
|
| 11 |
+
|
| 12 |
+
/*
|
| 13 |
+
* This file implements common mathematical operations on vector types
|
| 14 |
+
* (float3, float4 etc.) since these are not provided as standard by CUDA.
|
| 15 |
+
*
|
| 16 |
+
* The syntax is modeled on the Cg standard library.
|
| 17 |
+
*
|
| 18 |
+
* This is part of the Helper library includes
|
| 19 |
+
*
|
| 20 |
+
* Thanks to Linh Hah for additions and fixes.
|
| 21 |
+
*/
|
| 22 |
+
|
| 23 |
+
#ifndef HELPER_MATH_H
|
| 24 |
+
#define HELPER_MATH_H
|
| 25 |
+
|
| 26 |
+
#include "cuda_runtime.h"
|
| 27 |
+
|
| 28 |
+
typedef unsigned int uint;
|
| 29 |
+
typedef unsigned short ushort;
|
| 30 |
+
|
| 31 |
+
#ifndef EXIT_WAIVED
|
| 32 |
+
#define EXIT_WAIVED 2
|
| 33 |
+
#endif
|
| 34 |
+
|
| 35 |
+
#ifndef __CUDACC__
|
| 36 |
+
#include <math.h>
|
| 37 |
+
|
| 38 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 39 |
+
// host implementations of CUDA functions
|
| 40 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 41 |
+
|
| 42 |
+
inline float fminf(float a, float b)
|
| 43 |
+
{
|
| 44 |
+
return a < b ? a : b;
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
inline float fmaxf(float a, float b)
|
| 48 |
+
{
|
| 49 |
+
return a > b ? a : b;
|
| 50 |
+
}
|
| 51 |
+
|
| 52 |
+
inline int max(int a, int b)
|
| 53 |
+
{
|
| 54 |
+
return a > b ? a : b;
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
inline int min(int a, int b)
|
| 58 |
+
{
|
| 59 |
+
return a < b ? a : b;
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
inline float rsqrtf(float x)
|
| 63 |
+
{
|
| 64 |
+
return 1.0f / sqrtf(x);
|
| 65 |
+
}
|
| 66 |
+
#endif
|
| 67 |
+
|
| 68 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 69 |
+
// constructors
|
| 70 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 71 |
+
|
| 72 |
+
inline __host__ __device__ float2 make_float2(float s)
|
| 73 |
+
{
|
| 74 |
+
return make_float2(s, s);
|
| 75 |
+
}
|
| 76 |
+
inline __host__ __device__ float2 make_float2(float3 a)
|
| 77 |
+
{
|
| 78 |
+
return make_float2(a.x, a.y);
|
| 79 |
+
}
|
| 80 |
+
inline __host__ __device__ float2 make_float2(int2 a)
|
| 81 |
+
{
|
| 82 |
+
return make_float2(float(a.x), float(a.y));
|
| 83 |
+
}
|
| 84 |
+
inline __host__ __device__ float2 make_float2(uint2 a)
|
| 85 |
+
{
|
| 86 |
+
return make_float2(float(a.x), float(a.y));
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
inline __host__ __device__ int2 make_int2(int s)
|
| 90 |
+
{
|
| 91 |
+
return make_int2(s, s);
|
| 92 |
+
}
|
| 93 |
+
inline __host__ __device__ int2 make_int2(int3 a)
|
| 94 |
+
{
|
| 95 |
+
return make_int2(a.x, a.y);
|
| 96 |
+
}
|
| 97 |
+
inline __host__ __device__ int2 make_int2(uint2 a)
|
| 98 |
+
{
|
| 99 |
+
return make_int2(int(a.x), int(a.y));
|
| 100 |
+
}
|
| 101 |
+
inline __host__ __device__ int2 make_int2(float2 a)
|
| 102 |
+
{
|
| 103 |
+
return make_int2(int(a.x), int(a.y));
|
| 104 |
+
}
|
| 105 |
+
|
| 106 |
+
inline __host__ __device__ uint2 make_uint2(uint s)
|
| 107 |
+
{
|
| 108 |
+
return make_uint2(s, s);
|
| 109 |
+
}
|
| 110 |
+
inline __host__ __device__ uint2 make_uint2(uint3 a)
|
| 111 |
+
{
|
| 112 |
+
return make_uint2(a.x, a.y);
|
| 113 |
+
}
|
| 114 |
+
inline __host__ __device__ uint2 make_uint2(int2 a)
|
| 115 |
+
{
|
| 116 |
+
return make_uint2(uint(a.x), uint(a.y));
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
inline __host__ __device__ float3 make_float3(float s)
|
| 120 |
+
{
|
| 121 |
+
return make_float3(s, s, s);
|
| 122 |
+
}
|
| 123 |
+
inline __host__ __device__ float3 make_float3(float2 a)
|
| 124 |
+
{
|
| 125 |
+
return make_float3(a.x, a.y, 0.0f);
|
| 126 |
+
}
|
| 127 |
+
inline __host__ __device__ float3 make_float3(float2 a, float s)
|
| 128 |
+
{
|
| 129 |
+
return make_float3(a.x, a.y, s);
|
| 130 |
+
}
|
| 131 |
+
inline __host__ __device__ float3 make_float3(float4 a)
|
| 132 |
+
{
|
| 133 |
+
return make_float3(a.x, a.y, a.z);
|
| 134 |
+
}
|
| 135 |
+
inline __host__ __device__ float3 make_float3(int3 a)
|
| 136 |
+
{
|
| 137 |
+
return make_float3(float(a.x), float(a.y), float(a.z));
|
| 138 |
+
}
|
| 139 |
+
inline __host__ __device__ float3 make_float3(uint3 a)
|
| 140 |
+
{
|
| 141 |
+
return make_float3(float(a.x), float(a.y), float(a.z));
|
| 142 |
+
}
|
| 143 |
+
|
| 144 |
+
inline __host__ __device__ int3 make_int3(int s)
|
| 145 |
+
{
|
| 146 |
+
return make_int3(s, s, s);
|
| 147 |
+
}
|
| 148 |
+
inline __host__ __device__ int3 make_int3(int2 a)
|
| 149 |
+
{
|
| 150 |
+
return make_int3(a.x, a.y, 0);
|
| 151 |
+
}
|
| 152 |
+
inline __host__ __device__ int3 make_int3(int2 a, int s)
|
| 153 |
+
{
|
| 154 |
+
return make_int3(a.x, a.y, s);
|
| 155 |
+
}
|
| 156 |
+
inline __host__ __device__ int3 make_int3(uint3 a)
|
| 157 |
+
{
|
| 158 |
+
return make_int3(int(a.x), int(a.y), int(a.z));
|
| 159 |
+
}
|
| 160 |
+
inline __host__ __device__ int3 make_int3(float3 a)
|
| 161 |
+
{
|
| 162 |
+
return make_int3(int(a.x), int(a.y), int(a.z));
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
inline __host__ __device__ uint3 make_uint3(uint s)
|
| 166 |
+
{
|
| 167 |
+
return make_uint3(s, s, s);
|
| 168 |
+
}
|
| 169 |
+
inline __host__ __device__ uint3 make_uint3(uint2 a)
|
| 170 |
+
{
|
| 171 |
+
return make_uint3(a.x, a.y, 0);
|
| 172 |
+
}
|
| 173 |
+
inline __host__ __device__ uint3 make_uint3(uint2 a, uint s)
|
| 174 |
+
{
|
| 175 |
+
return make_uint3(a.x, a.y, s);
|
| 176 |
+
}
|
| 177 |
+
inline __host__ __device__ uint3 make_uint3(uint4 a)
|
| 178 |
+
{
|
| 179 |
+
return make_uint3(a.x, a.y, a.z);
|
| 180 |
+
}
|
| 181 |
+
inline __host__ __device__ uint3 make_uint3(int3 a)
|
| 182 |
+
{
|
| 183 |
+
return make_uint3(uint(a.x), uint(a.y), uint(a.z));
|
| 184 |
+
}
|
| 185 |
+
|
| 186 |
+
inline __host__ __device__ float4 make_float4(float s)
|
| 187 |
+
{
|
| 188 |
+
return make_float4(s, s, s, s);
|
| 189 |
+
}
|
| 190 |
+
inline __host__ __device__ float4 make_float4(float3 a)
|
| 191 |
+
{
|
| 192 |
+
return make_float4(a.x, a.y, a.z, 0.0f);
|
| 193 |
+
}
|
| 194 |
+
inline __host__ __device__ float4 make_float4(float3 a, float w)
|
| 195 |
+
{
|
| 196 |
+
return make_float4(a.x, a.y, a.z, w);
|
| 197 |
+
}
|
| 198 |
+
inline __host__ __device__ float4 make_float4(int4 a)
|
| 199 |
+
{
|
| 200 |
+
return make_float4(float(a.x), float(a.y), float(a.z), float(a.w));
|
| 201 |
+
}
|
| 202 |
+
inline __host__ __device__ float4 make_float4(uint4 a)
|
| 203 |
+
{
|
| 204 |
+
return make_float4(float(a.x), float(a.y), float(a.z), float(a.w));
|
| 205 |
+
}
|
| 206 |
+
|
| 207 |
+
inline __host__ __device__ int4 make_int4(int s)
|
| 208 |
+
{
|
| 209 |
+
return make_int4(s, s, s, s);
|
| 210 |
+
}
|
| 211 |
+
inline __host__ __device__ int4 make_int4(int3 a)
|
| 212 |
+
{
|
| 213 |
+
return make_int4(a.x, a.y, a.z, 0);
|
| 214 |
+
}
|
| 215 |
+
inline __host__ __device__ int4 make_int4(int3 a, int w)
|
| 216 |
+
{
|
| 217 |
+
return make_int4(a.x, a.y, a.z, w);
|
| 218 |
+
}
|
| 219 |
+
inline __host__ __device__ int4 make_int4(uint4 a)
|
| 220 |
+
{
|
| 221 |
+
return make_int4(int(a.x), int(a.y), int(a.z), int(a.w));
|
| 222 |
+
}
|
| 223 |
+
inline __host__ __device__ int4 make_int4(float4 a)
|
| 224 |
+
{
|
| 225 |
+
return make_int4(int(a.x), int(a.y), int(a.z), int(a.w));
|
| 226 |
+
}
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
inline __host__ __device__ uint4 make_uint4(uint s)
|
| 230 |
+
{
|
| 231 |
+
return make_uint4(s, s, s, s);
|
| 232 |
+
}
|
| 233 |
+
inline __host__ __device__ uint4 make_uint4(uint3 a)
|
| 234 |
+
{
|
| 235 |
+
return make_uint4(a.x, a.y, a.z, 0);
|
| 236 |
+
}
|
| 237 |
+
inline __host__ __device__ uint4 make_uint4(uint3 a, uint w)
|
| 238 |
+
{
|
| 239 |
+
return make_uint4(a.x, a.y, a.z, w);
|
| 240 |
+
}
|
| 241 |
+
inline __host__ __device__ uint4 make_uint4(int4 a)
|
| 242 |
+
{
|
| 243 |
+
return make_uint4(uint(a.x), uint(a.y), uint(a.z), uint(a.w));
|
| 244 |
+
}
|
| 245 |
+
|
| 246 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 247 |
+
// negate
|
| 248 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 249 |
+
|
| 250 |
+
inline __host__ __device__ float2 operator-(float2 &a)
|
| 251 |
+
{
|
| 252 |
+
return make_float2(-a.x, -a.y);
|
| 253 |
+
}
|
| 254 |
+
inline __host__ __device__ int2 operator-(int2 &a)
|
| 255 |
+
{
|
| 256 |
+
return make_int2(-a.x, -a.y);
|
| 257 |
+
}
|
| 258 |
+
inline __host__ __device__ float3 operator-(float3 &a)
|
| 259 |
+
{
|
| 260 |
+
return make_float3(-a.x, -a.y, -a.z);
|
| 261 |
+
}
|
| 262 |
+
inline __host__ __device__ int3 operator-(int3 &a)
|
| 263 |
+
{
|
| 264 |
+
return make_int3(-a.x, -a.y, -a.z);
|
| 265 |
+
}
|
| 266 |
+
inline __host__ __device__ float4 operator-(float4 &a)
|
| 267 |
+
{
|
| 268 |
+
return make_float4(-a.x, -a.y, -a.z, -a.w);
|
| 269 |
+
}
|
| 270 |
+
inline __host__ __device__ int4 operator-(int4 &a)
|
| 271 |
+
{
|
| 272 |
+
return make_int4(-a.x, -a.y, -a.z, -a.w);
|
| 273 |
+
}
|
| 274 |
+
|
| 275 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 276 |
+
// addition
|
| 277 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 278 |
+
|
| 279 |
+
inline __host__ __device__ float2 operator+(float2 a, float2 b)
|
| 280 |
+
{
|
| 281 |
+
return make_float2(a.x + b.x, a.y + b.y);
|
| 282 |
+
}
|
| 283 |
+
inline __host__ __device__ void operator+=(float2 &a, float2 b)
|
| 284 |
+
{
|
| 285 |
+
a.x += b.x;
|
| 286 |
+
a.y += b.y;
|
| 287 |
+
}
|
| 288 |
+
inline __host__ __device__ float2 operator+(float2 a, float b)
|
| 289 |
+
{
|
| 290 |
+
return make_float2(a.x + b, a.y + b);
|
| 291 |
+
}
|
| 292 |
+
inline __host__ __device__ float2 operator+(float b, float2 a)
|
| 293 |
+
{
|
| 294 |
+
return make_float2(a.x + b, a.y + b);
|
| 295 |
+
}
|
| 296 |
+
inline __host__ __device__ void operator+=(float2 &a, float b)
|
| 297 |
+
{
|
| 298 |
+
a.x += b;
|
| 299 |
+
a.y += b;
|
| 300 |
+
}
|
| 301 |
+
|
| 302 |
+
inline __host__ __device__ int2 operator+(int2 a, int2 b)
|
| 303 |
+
{
|
| 304 |
+
return make_int2(a.x + b.x, a.y + b.y);
|
| 305 |
+
}
|
| 306 |
+
inline __host__ __device__ void operator+=(int2 &a, int2 b)
|
| 307 |
+
{
|
| 308 |
+
a.x += b.x;
|
| 309 |
+
a.y += b.y;
|
| 310 |
+
}
|
| 311 |
+
inline __host__ __device__ int2 operator+(int2 a, int b)
|
| 312 |
+
{
|
| 313 |
+
return make_int2(a.x + b, a.y + b);
|
| 314 |
+
}
|
| 315 |
+
inline __host__ __device__ int2 operator+(int b, int2 a)
|
| 316 |
+
{
|
| 317 |
+
return make_int2(a.x + b, a.y + b);
|
| 318 |
+
}
|
| 319 |
+
inline __host__ __device__ void operator+=(int2 &a, int b)
|
| 320 |
+
{
|
| 321 |
+
a.x += b;
|
| 322 |
+
a.y += b;
|
| 323 |
+
}
|
| 324 |
+
|
| 325 |
+
inline __host__ __device__ uint2 operator+(uint2 a, uint2 b)
|
| 326 |
+
{
|
| 327 |
+
return make_uint2(a.x + b.x, a.y + b.y);
|
| 328 |
+
}
|
| 329 |
+
inline __host__ __device__ void operator+=(uint2 &a, uint2 b)
|
| 330 |
+
{
|
| 331 |
+
a.x += b.x;
|
| 332 |
+
a.y += b.y;
|
| 333 |
+
}
|
| 334 |
+
inline __host__ __device__ uint2 operator+(uint2 a, uint b)
|
| 335 |
+
{
|
| 336 |
+
return make_uint2(a.x + b, a.y + b);
|
| 337 |
+
}
|
| 338 |
+
inline __host__ __device__ uint2 operator+(uint b, uint2 a)
|
| 339 |
+
{
|
| 340 |
+
return make_uint2(a.x + b, a.y + b);
|
| 341 |
+
}
|
| 342 |
+
inline __host__ __device__ void operator+=(uint2 &a, uint b)
|
| 343 |
+
{
|
| 344 |
+
a.x += b;
|
| 345 |
+
a.y += b;
|
| 346 |
+
}
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
inline __host__ __device__ float3 operator+(float3 a, float3 b)
|
| 350 |
+
{
|
| 351 |
+
return make_float3(a.x + b.x, a.y + b.y, a.z + b.z);
|
| 352 |
+
}
|
| 353 |
+
inline __host__ __device__ void operator+=(float3 &a, float3 b)
|
| 354 |
+
{
|
| 355 |
+
a.x += b.x;
|
| 356 |
+
a.y += b.y;
|
| 357 |
+
a.z += b.z;
|
| 358 |
+
}
|
| 359 |
+
inline __host__ __device__ float3 operator+(float3 a, float b)
|
| 360 |
+
{
|
| 361 |
+
return make_float3(a.x + b, a.y + b, a.z + b);
|
| 362 |
+
}
|
| 363 |
+
inline __host__ __device__ void operator+=(float3 &a, float b)
|
| 364 |
+
{
|
| 365 |
+
a.x += b;
|
| 366 |
+
a.y += b;
|
| 367 |
+
a.z += b;
|
| 368 |
+
}
|
| 369 |
+
|
| 370 |
+
inline __host__ __device__ int3 operator+(int3 a, int3 b)
|
| 371 |
+
{
|
| 372 |
+
return make_int3(a.x + b.x, a.y + b.y, a.z + b.z);
|
| 373 |
+
}
|
| 374 |
+
inline __host__ __device__ void operator+=(int3 &a, int3 b)
|
| 375 |
+
{
|
| 376 |
+
a.x += b.x;
|
| 377 |
+
a.y += b.y;
|
| 378 |
+
a.z += b.z;
|
| 379 |
+
}
|
| 380 |
+
inline __host__ __device__ int3 operator+(int3 a, int b)
|
| 381 |
+
{
|
| 382 |
+
return make_int3(a.x + b, a.y + b, a.z + b);
|
| 383 |
+
}
|
| 384 |
+
inline __host__ __device__ void operator+=(int3 &a, int b)
|
| 385 |
+
{
|
| 386 |
+
a.x += b;
|
| 387 |
+
a.y += b;
|
| 388 |
+
a.z += b;
|
| 389 |
+
}
|
| 390 |
+
|
| 391 |
+
inline __host__ __device__ uint3 operator+(uint3 a, uint3 b)
|
| 392 |
+
{
|
| 393 |
+
return make_uint3(a.x + b.x, a.y + b.y, a.z + b.z);
|
| 394 |
+
}
|
| 395 |
+
inline __host__ __device__ void operator+=(uint3 &a, uint3 b)
|
| 396 |
+
{
|
| 397 |
+
a.x += b.x;
|
| 398 |
+
a.y += b.y;
|
| 399 |
+
a.z += b.z;
|
| 400 |
+
}
|
| 401 |
+
inline __host__ __device__ uint3 operator+(uint3 a, uint b)
|
| 402 |
+
{
|
| 403 |
+
return make_uint3(a.x + b, a.y + b, a.z + b);
|
| 404 |
+
}
|
| 405 |
+
inline __host__ __device__ void operator+=(uint3 &a, uint b)
|
| 406 |
+
{
|
| 407 |
+
a.x += b;
|
| 408 |
+
a.y += b;
|
| 409 |
+
a.z += b;
|
| 410 |
+
}
|
| 411 |
+
|
| 412 |
+
inline __host__ __device__ int3 operator+(int b, int3 a)
|
| 413 |
+
{
|
| 414 |
+
return make_int3(a.x + b, a.y + b, a.z + b);
|
| 415 |
+
}
|
| 416 |
+
inline __host__ __device__ uint3 operator+(uint b, uint3 a)
|
| 417 |
+
{
|
| 418 |
+
return make_uint3(a.x + b, a.y + b, a.z + b);
|
| 419 |
+
}
|
| 420 |
+
inline __host__ __device__ float3 operator+(float b, float3 a)
|
| 421 |
+
{
|
| 422 |
+
return make_float3(a.x + b, a.y + b, a.z + b);
|
| 423 |
+
}
|
| 424 |
+
|
| 425 |
+
inline __host__ __device__ float4 operator+(float4 a, float4 b)
|
| 426 |
+
{
|
| 427 |
+
return make_float4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
|
| 428 |
+
}
|
| 429 |
+
inline __host__ __device__ void operator+=(float4 &a, float4 b)
|
| 430 |
+
{
|
| 431 |
+
a.x += b.x;
|
| 432 |
+
a.y += b.y;
|
| 433 |
+
a.z += b.z;
|
| 434 |
+
a.w += b.w;
|
| 435 |
+
}
|
| 436 |
+
inline __host__ __device__ float4 operator+(float4 a, float b)
|
| 437 |
+
{
|
| 438 |
+
return make_float4(a.x + b, a.y + b, a.z + b, a.w + b);
|
| 439 |
+
}
|
| 440 |
+
inline __host__ __device__ float4 operator+(float b, float4 a)
|
| 441 |
+
{
|
| 442 |
+
return make_float4(a.x + b, a.y + b, a.z + b, a.w + b);
|
| 443 |
+
}
|
| 444 |
+
inline __host__ __device__ void operator+=(float4 &a, float b)
|
| 445 |
+
{
|
| 446 |
+
a.x += b;
|
| 447 |
+
a.y += b;
|
| 448 |
+
a.z += b;
|
| 449 |
+
a.w += b;
|
| 450 |
+
}
|
| 451 |
+
|
| 452 |
+
inline __host__ __device__ int4 operator+(int4 a, int4 b)
|
| 453 |
+
{
|
| 454 |
+
return make_int4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
|
| 455 |
+
}
|
| 456 |
+
inline __host__ __device__ void operator+=(int4 &a, int4 b)
|
| 457 |
+
{
|
| 458 |
+
a.x += b.x;
|
| 459 |
+
a.y += b.y;
|
| 460 |
+
a.z += b.z;
|
| 461 |
+
a.w += b.w;
|
| 462 |
+
}
|
| 463 |
+
inline __host__ __device__ int4 operator+(int4 a, int b)
|
| 464 |
+
{
|
| 465 |
+
return make_int4(a.x + b, a.y + b, a.z + b, a.w + b);
|
| 466 |
+
}
|
| 467 |
+
inline __host__ __device__ int4 operator+(int b, int4 a)
|
| 468 |
+
{
|
| 469 |
+
return make_int4(a.x + b, a.y + b, a.z + b, a.w + b);
|
| 470 |
+
}
|
| 471 |
+
inline __host__ __device__ void operator+=(int4 &a, int b)
|
| 472 |
+
{
|
| 473 |
+
a.x += b;
|
| 474 |
+
a.y += b;
|
| 475 |
+
a.z += b;
|
| 476 |
+
a.w += b;
|
| 477 |
+
}
|
| 478 |
+
|
| 479 |
+
inline __host__ __device__ uint4 operator+(uint4 a, uint4 b)
|
| 480 |
+
{
|
| 481 |
+
return make_uint4(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w);
|
| 482 |
+
}
|
| 483 |
+
inline __host__ __device__ void operator+=(uint4 &a, uint4 b)
|
| 484 |
+
{
|
| 485 |
+
a.x += b.x;
|
| 486 |
+
a.y += b.y;
|
| 487 |
+
a.z += b.z;
|
| 488 |
+
a.w += b.w;
|
| 489 |
+
}
|
| 490 |
+
inline __host__ __device__ uint4 operator+(uint4 a, uint b)
|
| 491 |
+
{
|
| 492 |
+
return make_uint4(a.x + b, a.y + b, a.z + b, a.w + b);
|
| 493 |
+
}
|
| 494 |
+
inline __host__ __device__ uint4 operator+(uint b, uint4 a)
|
| 495 |
+
{
|
| 496 |
+
return make_uint4(a.x + b, a.y + b, a.z + b, a.w + b);
|
| 497 |
+
}
|
| 498 |
+
inline __host__ __device__ void operator+=(uint4 &a, uint b)
|
| 499 |
+
{
|
| 500 |
+
a.x += b;
|
| 501 |
+
a.y += b;
|
| 502 |
+
a.z += b;
|
| 503 |
+
a.w += b;
|
| 504 |
+
}
|
| 505 |
+
|
| 506 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 507 |
+
// subtract
|
| 508 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 509 |
+
|
| 510 |
+
inline __host__ __device__ float2 operator-(float2 a, float2 b)
|
| 511 |
+
{
|
| 512 |
+
return make_float2(a.x - b.x, a.y - b.y);
|
| 513 |
+
}
|
| 514 |
+
inline __host__ __device__ void operator-=(float2 &a, float2 b)
|
| 515 |
+
{
|
| 516 |
+
a.x -= b.x;
|
| 517 |
+
a.y -= b.y;
|
| 518 |
+
}
|
| 519 |
+
inline __host__ __device__ float2 operator-(float2 a, float b)
|
| 520 |
+
{
|
| 521 |
+
return make_float2(a.x - b, a.y - b);
|
| 522 |
+
}
|
| 523 |
+
inline __host__ __device__ float2 operator-(float b, float2 a)
|
| 524 |
+
{
|
| 525 |
+
return make_float2(b - a.x, b - a.y);
|
| 526 |
+
}
|
| 527 |
+
inline __host__ __device__ void operator-=(float2 &a, float b)
|
| 528 |
+
{
|
| 529 |
+
a.x -= b;
|
| 530 |
+
a.y -= b;
|
| 531 |
+
}
|
| 532 |
+
|
| 533 |
+
inline __host__ __device__ int2 operator-(int2 a, int2 b)
|
| 534 |
+
{
|
| 535 |
+
return make_int2(a.x - b.x, a.y - b.y);
|
| 536 |
+
}
|
| 537 |
+
inline __host__ __device__ void operator-=(int2 &a, int2 b)
|
| 538 |
+
{
|
| 539 |
+
a.x -= b.x;
|
| 540 |
+
a.y -= b.y;
|
| 541 |
+
}
|
| 542 |
+
inline __host__ __device__ int2 operator-(int2 a, int b)
|
| 543 |
+
{
|
| 544 |
+
return make_int2(a.x - b, a.y - b);
|
| 545 |
+
}
|
| 546 |
+
inline __host__ __device__ int2 operator-(int b, int2 a)
|
| 547 |
+
{
|
| 548 |
+
return make_int2(b - a.x, b - a.y);
|
| 549 |
+
}
|
| 550 |
+
inline __host__ __device__ void operator-=(int2 &a, int b)
|
| 551 |
+
{
|
| 552 |
+
a.x -= b;
|
| 553 |
+
a.y -= b;
|
| 554 |
+
}
|
| 555 |
+
|
| 556 |
+
inline __host__ __device__ uint2 operator-(uint2 a, uint2 b)
|
| 557 |
+
{
|
| 558 |
+
return make_uint2(a.x - b.x, a.y - b.y);
|
| 559 |
+
}
|
| 560 |
+
inline __host__ __device__ void operator-=(uint2 &a, uint2 b)
|
| 561 |
+
{
|
| 562 |
+
a.x -= b.x;
|
| 563 |
+
a.y -= b.y;
|
| 564 |
+
}
|
| 565 |
+
inline __host__ __device__ uint2 operator-(uint2 a, uint b)
|
| 566 |
+
{
|
| 567 |
+
return make_uint2(a.x - b, a.y - b);
|
| 568 |
+
}
|
| 569 |
+
inline __host__ __device__ uint2 operator-(uint b, uint2 a)
|
| 570 |
+
{
|
| 571 |
+
return make_uint2(b - a.x, b - a.y);
|
| 572 |
+
}
|
| 573 |
+
inline __host__ __device__ void operator-=(uint2 &a, uint b)
|
| 574 |
+
{
|
| 575 |
+
a.x -= b;
|
| 576 |
+
a.y -= b;
|
| 577 |
+
}
|
| 578 |
+
|
| 579 |
+
inline __host__ __device__ float3 operator-(float3 a, float3 b)
|
| 580 |
+
{
|
| 581 |
+
return make_float3(a.x - b.x, a.y - b.y, a.z - b.z);
|
| 582 |
+
}
|
| 583 |
+
inline __host__ __device__ void operator-=(float3 &a, float3 b)
|
| 584 |
+
{
|
| 585 |
+
a.x -= b.x;
|
| 586 |
+
a.y -= b.y;
|
| 587 |
+
a.z -= b.z;
|
| 588 |
+
}
|
| 589 |
+
inline __host__ __device__ float3 operator-(float3 a, float b)
|
| 590 |
+
{
|
| 591 |
+
return make_float3(a.x - b, a.y - b, a.z - b);
|
| 592 |
+
}
|
| 593 |
+
inline __host__ __device__ float3 operator-(float b, float3 a)
|
| 594 |
+
{
|
| 595 |
+
return make_float3(b - a.x, b - a.y, b - a.z);
|
| 596 |
+
}
|
| 597 |
+
inline __host__ __device__ void operator-=(float3 &a, float b)
|
| 598 |
+
{
|
| 599 |
+
a.x -= b;
|
| 600 |
+
a.y -= b;
|
| 601 |
+
a.z -= b;
|
| 602 |
+
}
|
| 603 |
+
|
| 604 |
+
inline __host__ __device__ int3 operator-(int3 a, int3 b)
|
| 605 |
+
{
|
| 606 |
+
return make_int3(a.x - b.x, a.y - b.y, a.z - b.z);
|
| 607 |
+
}
|
| 608 |
+
inline __host__ __device__ void operator-=(int3 &a, int3 b)
|
| 609 |
+
{
|
| 610 |
+
a.x -= b.x;
|
| 611 |
+
a.y -= b.y;
|
| 612 |
+
a.z -= b.z;
|
| 613 |
+
}
|
| 614 |
+
inline __host__ __device__ int3 operator-(int3 a, int b)
|
| 615 |
+
{
|
| 616 |
+
return make_int3(a.x - b, a.y - b, a.z - b);
|
| 617 |
+
}
|
| 618 |
+
inline __host__ __device__ int3 operator-(int b, int3 a)
|
| 619 |
+
{
|
| 620 |
+
return make_int3(b - a.x, b - a.y, b - a.z);
|
| 621 |
+
}
|
| 622 |
+
inline __host__ __device__ void operator-=(int3 &a, int b)
|
| 623 |
+
{
|
| 624 |
+
a.x -= b;
|
| 625 |
+
a.y -= b;
|
| 626 |
+
a.z -= b;
|
| 627 |
+
}
|
| 628 |
+
|
| 629 |
+
inline __host__ __device__ uint3 operator-(uint3 a, uint3 b)
|
| 630 |
+
{
|
| 631 |
+
return make_uint3(a.x - b.x, a.y - b.y, a.z - b.z);
|
| 632 |
+
}
|
| 633 |
+
inline __host__ __device__ void operator-=(uint3 &a, uint3 b)
|
| 634 |
+
{
|
| 635 |
+
a.x -= b.x;
|
| 636 |
+
a.y -= b.y;
|
| 637 |
+
a.z -= b.z;
|
| 638 |
+
}
|
| 639 |
+
inline __host__ __device__ uint3 operator-(uint3 a, uint b)
|
| 640 |
+
{
|
| 641 |
+
return make_uint3(a.x - b, a.y - b, a.z - b);
|
| 642 |
+
}
|
| 643 |
+
inline __host__ __device__ uint3 operator-(uint b, uint3 a)
|
| 644 |
+
{
|
| 645 |
+
return make_uint3(b - a.x, b - a.y, b - a.z);
|
| 646 |
+
}
|
| 647 |
+
inline __host__ __device__ void operator-=(uint3 &a, uint b)
|
| 648 |
+
{
|
| 649 |
+
a.x -= b;
|
| 650 |
+
a.y -= b;
|
| 651 |
+
a.z -= b;
|
| 652 |
+
}
|
| 653 |
+
|
| 654 |
+
inline __host__ __device__ float4 operator-(float4 a, float4 b)
|
| 655 |
+
{
|
| 656 |
+
return make_float4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
|
| 657 |
+
}
|
| 658 |
+
inline __host__ __device__ void operator-=(float4 &a, float4 b)
|
| 659 |
+
{
|
| 660 |
+
a.x -= b.x;
|
| 661 |
+
a.y -= b.y;
|
| 662 |
+
a.z -= b.z;
|
| 663 |
+
a.w -= b.w;
|
| 664 |
+
}
|
| 665 |
+
inline __host__ __device__ float4 operator-(float4 a, float b)
|
| 666 |
+
{
|
| 667 |
+
return make_float4(a.x - b, a.y - b, a.z - b, a.w - b);
|
| 668 |
+
}
|
| 669 |
+
inline __host__ __device__ void operator-=(float4 &a, float b)
|
| 670 |
+
{
|
| 671 |
+
a.x -= b;
|
| 672 |
+
a.y -= b;
|
| 673 |
+
a.z -= b;
|
| 674 |
+
a.w -= b;
|
| 675 |
+
}
|
| 676 |
+
|
| 677 |
+
inline __host__ __device__ int4 operator-(int4 a, int4 b)
|
| 678 |
+
{
|
| 679 |
+
return make_int4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
|
| 680 |
+
}
|
| 681 |
+
inline __host__ __device__ void operator-=(int4 &a, int4 b)
|
| 682 |
+
{
|
| 683 |
+
a.x -= b.x;
|
| 684 |
+
a.y -= b.y;
|
| 685 |
+
a.z -= b.z;
|
| 686 |
+
a.w -= b.w;
|
| 687 |
+
}
|
| 688 |
+
inline __host__ __device__ int4 operator-(int4 a, int b)
|
| 689 |
+
{
|
| 690 |
+
return make_int4(a.x - b, a.y - b, a.z - b, a.w - b);
|
| 691 |
+
}
|
| 692 |
+
inline __host__ __device__ int4 operator-(int b, int4 a)
|
| 693 |
+
{
|
| 694 |
+
return make_int4(b - a.x, b - a.y, b - a.z, b - a.w);
|
| 695 |
+
}
|
| 696 |
+
inline __host__ __device__ void operator-=(int4 &a, int b)
|
| 697 |
+
{
|
| 698 |
+
a.x -= b;
|
| 699 |
+
a.y -= b;
|
| 700 |
+
a.z -= b;
|
| 701 |
+
a.w -= b;
|
| 702 |
+
}
|
| 703 |
+
|
| 704 |
+
inline __host__ __device__ uint4 operator-(uint4 a, uint4 b)
|
| 705 |
+
{
|
| 706 |
+
return make_uint4(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w);
|
| 707 |
+
}
|
| 708 |
+
inline __host__ __device__ void operator-=(uint4 &a, uint4 b)
|
| 709 |
+
{
|
| 710 |
+
a.x -= b.x;
|
| 711 |
+
a.y -= b.y;
|
| 712 |
+
a.z -= b.z;
|
| 713 |
+
a.w -= b.w;
|
| 714 |
+
}
|
| 715 |
+
inline __host__ __device__ uint4 operator-(uint4 a, uint b)
|
| 716 |
+
{
|
| 717 |
+
return make_uint4(a.x - b, a.y - b, a.z - b, a.w - b);
|
| 718 |
+
}
|
| 719 |
+
inline __host__ __device__ uint4 operator-(uint b, uint4 a)
|
| 720 |
+
{
|
| 721 |
+
return make_uint4(b - a.x, b - a.y, b - a.z, b - a.w);
|
| 722 |
+
}
|
| 723 |
+
inline __host__ __device__ void operator-=(uint4 &a, uint b)
|
| 724 |
+
{
|
| 725 |
+
a.x -= b;
|
| 726 |
+
a.y -= b;
|
| 727 |
+
a.z -= b;
|
| 728 |
+
a.w -= b;
|
| 729 |
+
}
|
| 730 |
+
|
| 731 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 732 |
+
// multiply
|
| 733 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 734 |
+
|
| 735 |
+
inline __host__ __device__ float2 operator*(float2 a, float2 b)
|
| 736 |
+
{
|
| 737 |
+
return make_float2(a.x * b.x, a.y * b.y);
|
| 738 |
+
}
|
| 739 |
+
inline __host__ __device__ void operator*=(float2 &a, float2 b)
|
| 740 |
+
{
|
| 741 |
+
a.x *= b.x;
|
| 742 |
+
a.y *= b.y;
|
| 743 |
+
}
|
| 744 |
+
inline __host__ __device__ float2 operator*(float2 a, float b)
|
| 745 |
+
{
|
| 746 |
+
return make_float2(a.x * b, a.y * b);
|
| 747 |
+
}
|
| 748 |
+
inline __host__ __device__ float2 operator*(float b, float2 a)
|
| 749 |
+
{
|
| 750 |
+
return make_float2(b * a.x, b * a.y);
|
| 751 |
+
}
|
| 752 |
+
inline __host__ __device__ void operator*=(float2 &a, float b)
|
| 753 |
+
{
|
| 754 |
+
a.x *= b;
|
| 755 |
+
a.y *= b;
|
| 756 |
+
}
|
| 757 |
+
|
| 758 |
+
inline __host__ __device__ int2 operator*(int2 a, int2 b)
|
| 759 |
+
{
|
| 760 |
+
return make_int2(a.x * b.x, a.y * b.y);
|
| 761 |
+
}
|
| 762 |
+
inline __host__ __device__ void operator*=(int2 &a, int2 b)
|
| 763 |
+
{
|
| 764 |
+
a.x *= b.x;
|
| 765 |
+
a.y *= b.y;
|
| 766 |
+
}
|
| 767 |
+
inline __host__ __device__ int2 operator*(int2 a, int b)
|
| 768 |
+
{
|
| 769 |
+
return make_int2(a.x * b, a.y * b);
|
| 770 |
+
}
|
| 771 |
+
inline __host__ __device__ int2 operator*(int b, int2 a)
|
| 772 |
+
{
|
| 773 |
+
return make_int2(b * a.x, b * a.y);
|
| 774 |
+
}
|
| 775 |
+
inline __host__ __device__ void operator*=(int2 &a, int b)
|
| 776 |
+
{
|
| 777 |
+
a.x *= b;
|
| 778 |
+
a.y *= b;
|
| 779 |
+
}
|
| 780 |
+
|
| 781 |
+
inline __host__ __device__ uint2 operator*(uint2 a, uint2 b)
|
| 782 |
+
{
|
| 783 |
+
return make_uint2(a.x * b.x, a.y * b.y);
|
| 784 |
+
}
|
| 785 |
+
inline __host__ __device__ void operator*=(uint2 &a, uint2 b)
|
| 786 |
+
{
|
| 787 |
+
a.x *= b.x;
|
| 788 |
+
a.y *= b.y;
|
| 789 |
+
}
|
| 790 |
+
inline __host__ __device__ uint2 operator*(uint2 a, uint b)
|
| 791 |
+
{
|
| 792 |
+
return make_uint2(a.x * b, a.y * b);
|
| 793 |
+
}
|
| 794 |
+
inline __host__ __device__ uint2 operator*(uint b, uint2 a)
|
| 795 |
+
{
|
| 796 |
+
return make_uint2(b * a.x, b * a.y);
|
| 797 |
+
}
|
| 798 |
+
inline __host__ __device__ void operator*=(uint2 &a, uint b)
|
| 799 |
+
{
|
| 800 |
+
a.x *= b;
|
| 801 |
+
a.y *= b;
|
| 802 |
+
}
|
| 803 |
+
|
| 804 |
+
inline __host__ __device__ float3 operator*(float3 a, float3 b)
|
| 805 |
+
{
|
| 806 |
+
return make_float3(a.x * b.x, a.y * b.y, a.z * b.z);
|
| 807 |
+
}
|
| 808 |
+
inline __host__ __device__ void operator*=(float3 &a, float3 b)
|
| 809 |
+
{
|
| 810 |
+
a.x *= b.x;
|
| 811 |
+
a.y *= b.y;
|
| 812 |
+
a.z *= b.z;
|
| 813 |
+
}
|
| 814 |
+
inline __host__ __device__ float3 operator*(float3 a, float b)
|
| 815 |
+
{
|
| 816 |
+
return make_float3(a.x * b, a.y * b, a.z * b);
|
| 817 |
+
}
|
| 818 |
+
inline __host__ __device__ float3 operator*(float b, float3 a)
|
| 819 |
+
{
|
| 820 |
+
return make_float3(b * a.x, b * a.y, b * a.z);
|
| 821 |
+
}
|
| 822 |
+
inline __host__ __device__ void operator*=(float3 &a, float b)
|
| 823 |
+
{
|
| 824 |
+
a.x *= b;
|
| 825 |
+
a.y *= b;
|
| 826 |
+
a.z *= b;
|
| 827 |
+
}
|
| 828 |
+
|
| 829 |
+
inline __host__ __device__ int3 operator*(int3 a, int3 b)
|
| 830 |
+
{
|
| 831 |
+
return make_int3(a.x * b.x, a.y * b.y, a.z * b.z);
|
| 832 |
+
}
|
| 833 |
+
inline __host__ __device__ void operator*=(int3 &a, int3 b)
|
| 834 |
+
{
|
| 835 |
+
a.x *= b.x;
|
| 836 |
+
a.y *= b.y;
|
| 837 |
+
a.z *= b.z;
|
| 838 |
+
}
|
| 839 |
+
inline __host__ __device__ int3 operator*(int3 a, int b)
|
| 840 |
+
{
|
| 841 |
+
return make_int3(a.x * b, a.y * b, a.z * b);
|
| 842 |
+
}
|
| 843 |
+
inline __host__ __device__ int3 operator*(int b, int3 a)
|
| 844 |
+
{
|
| 845 |
+
return make_int3(b * a.x, b * a.y, b * a.z);
|
| 846 |
+
}
|
| 847 |
+
inline __host__ __device__ void operator*=(int3 &a, int b)
|
| 848 |
+
{
|
| 849 |
+
a.x *= b;
|
| 850 |
+
a.y *= b;
|
| 851 |
+
a.z *= b;
|
| 852 |
+
}
|
| 853 |
+
|
| 854 |
+
inline __host__ __device__ uint3 operator*(uint3 a, uint3 b)
|
| 855 |
+
{
|
| 856 |
+
return make_uint3(a.x * b.x, a.y * b.y, a.z * b.z);
|
| 857 |
+
}
|
| 858 |
+
inline __host__ __device__ void operator*=(uint3 &a, uint3 b)
|
| 859 |
+
{
|
| 860 |
+
a.x *= b.x;
|
| 861 |
+
a.y *= b.y;
|
| 862 |
+
a.z *= b.z;
|
| 863 |
+
}
|
| 864 |
+
inline __host__ __device__ uint3 operator*(uint3 a, uint b)
|
| 865 |
+
{
|
| 866 |
+
return make_uint3(a.x * b, a.y * b, a.z * b);
|
| 867 |
+
}
|
| 868 |
+
inline __host__ __device__ uint3 operator*(uint b, uint3 a)
|
| 869 |
+
{
|
| 870 |
+
return make_uint3(b * a.x, b * a.y, b * a.z);
|
| 871 |
+
}
|
| 872 |
+
inline __host__ __device__ void operator*=(uint3 &a, uint b)
|
| 873 |
+
{
|
| 874 |
+
a.x *= b;
|
| 875 |
+
a.y *= b;
|
| 876 |
+
a.z *= b;
|
| 877 |
+
}
|
| 878 |
+
|
| 879 |
+
inline __host__ __device__ float4 operator*(float4 a, float4 b)
|
| 880 |
+
{
|
| 881 |
+
return make_float4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
|
| 882 |
+
}
|
| 883 |
+
inline __host__ __device__ void operator*=(float4 &a, float4 b)
|
| 884 |
+
{
|
| 885 |
+
a.x *= b.x;
|
| 886 |
+
a.y *= b.y;
|
| 887 |
+
a.z *= b.z;
|
| 888 |
+
a.w *= b.w;
|
| 889 |
+
}
|
| 890 |
+
inline __host__ __device__ float4 operator*(float4 a, float b)
|
| 891 |
+
{
|
| 892 |
+
return make_float4(a.x * b, a.y * b, a.z * b, a.w * b);
|
| 893 |
+
}
|
| 894 |
+
inline __host__ __device__ float4 operator*(float b, float4 a)
|
| 895 |
+
{
|
| 896 |
+
return make_float4(b * a.x, b * a.y, b * a.z, b * a.w);
|
| 897 |
+
}
|
| 898 |
+
inline __host__ __device__ void operator*=(float4 &a, float b)
|
| 899 |
+
{
|
| 900 |
+
a.x *= b;
|
| 901 |
+
a.y *= b;
|
| 902 |
+
a.z *= b;
|
| 903 |
+
a.w *= b;
|
| 904 |
+
}
|
| 905 |
+
|
| 906 |
+
inline __host__ __device__ int4 operator*(int4 a, int4 b)
|
| 907 |
+
{
|
| 908 |
+
return make_int4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
|
| 909 |
+
}
|
| 910 |
+
inline __host__ __device__ void operator*=(int4 &a, int4 b)
|
| 911 |
+
{
|
| 912 |
+
a.x *= b.x;
|
| 913 |
+
a.y *= b.y;
|
| 914 |
+
a.z *= b.z;
|
| 915 |
+
a.w *= b.w;
|
| 916 |
+
}
|
| 917 |
+
inline __host__ __device__ int4 operator*(int4 a, int b)
|
| 918 |
+
{
|
| 919 |
+
return make_int4(a.x * b, a.y * b, a.z * b, a.w * b);
|
| 920 |
+
}
|
| 921 |
+
inline __host__ __device__ int4 operator*(int b, int4 a)
|
| 922 |
+
{
|
| 923 |
+
return make_int4(b * a.x, b * a.y, b * a.z, b * a.w);
|
| 924 |
+
}
|
| 925 |
+
inline __host__ __device__ void operator*=(int4 &a, int b)
|
| 926 |
+
{
|
| 927 |
+
a.x *= b;
|
| 928 |
+
a.y *= b;
|
| 929 |
+
a.z *= b;
|
| 930 |
+
a.w *= b;
|
| 931 |
+
}
|
| 932 |
+
|
| 933 |
+
inline __host__ __device__ uint4 operator*(uint4 a, uint4 b)
|
| 934 |
+
{
|
| 935 |
+
return make_uint4(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w);
|
| 936 |
+
}
|
| 937 |
+
inline __host__ __device__ void operator*=(uint4 &a, uint4 b)
|
| 938 |
+
{
|
| 939 |
+
a.x *= b.x;
|
| 940 |
+
a.y *= b.y;
|
| 941 |
+
a.z *= b.z;
|
| 942 |
+
a.w *= b.w;
|
| 943 |
+
}
|
| 944 |
+
inline __host__ __device__ uint4 operator*(uint4 a, uint b)
|
| 945 |
+
{
|
| 946 |
+
return make_uint4(a.x * b, a.y * b, a.z * b, a.w * b);
|
| 947 |
+
}
|
| 948 |
+
inline __host__ __device__ uint4 operator*(uint b, uint4 a)
|
| 949 |
+
{
|
| 950 |
+
return make_uint4(b * a.x, b * a.y, b * a.z, b * a.w);
|
| 951 |
+
}
|
| 952 |
+
inline __host__ __device__ void operator*=(uint4 &a, uint b)
|
| 953 |
+
{
|
| 954 |
+
a.x *= b;
|
| 955 |
+
a.y *= b;
|
| 956 |
+
a.z *= b;
|
| 957 |
+
a.w *= b;
|
| 958 |
+
}
|
| 959 |
+
|
| 960 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 961 |
+
// divide
|
| 962 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 963 |
+
|
| 964 |
+
inline __host__ __device__ float2 operator/(float2 a, float2 b)
|
| 965 |
+
{
|
| 966 |
+
return make_float2(a.x / b.x, a.y / b.y);
|
| 967 |
+
}
|
| 968 |
+
inline __host__ __device__ void operator/=(float2 &a, float2 b)
|
| 969 |
+
{
|
| 970 |
+
a.x /= b.x;
|
| 971 |
+
a.y /= b.y;
|
| 972 |
+
}
|
| 973 |
+
inline __host__ __device__ float2 operator/(float2 a, float b)
|
| 974 |
+
{
|
| 975 |
+
return make_float2(a.x / b, a.y / b);
|
| 976 |
+
}
|
| 977 |
+
inline __host__ __device__ void operator/=(float2 &a, float b)
|
| 978 |
+
{
|
| 979 |
+
a.x /= b;
|
| 980 |
+
a.y /= b;
|
| 981 |
+
}
|
| 982 |
+
inline __host__ __device__ float2 operator/(float b, float2 a)
|
| 983 |
+
{
|
| 984 |
+
return make_float2(b / a.x, b / a.y);
|
| 985 |
+
}
|
| 986 |
+
|
| 987 |
+
inline __host__ __device__ float3 operator/(float3 a, float3 b)
|
| 988 |
+
{
|
| 989 |
+
return make_float3(a.x / b.x, a.y / b.y, a.z / b.z);
|
| 990 |
+
}
|
| 991 |
+
inline __host__ __device__ void operator/=(float3 &a, float3 b)
|
| 992 |
+
{
|
| 993 |
+
a.x /= b.x;
|
| 994 |
+
a.y /= b.y;
|
| 995 |
+
a.z /= b.z;
|
| 996 |
+
}
|
| 997 |
+
inline __host__ __device__ float3 operator/(float3 a, float b)
|
| 998 |
+
{
|
| 999 |
+
return make_float3(a.x / b, a.y / b, a.z / b);
|
| 1000 |
+
}
|
| 1001 |
+
inline __host__ __device__ void operator/=(float3 &a, float b)
|
| 1002 |
+
{
|
| 1003 |
+
a.x /= b;
|
| 1004 |
+
a.y /= b;
|
| 1005 |
+
a.z /= b;
|
| 1006 |
+
}
|
| 1007 |
+
inline __host__ __device__ float3 operator/(float b, float3 a)
|
| 1008 |
+
{
|
| 1009 |
+
return make_float3(b / a.x, b / a.y, b / a.z);
|
| 1010 |
+
}
|
| 1011 |
+
|
| 1012 |
+
inline __host__ __device__ float4 operator/(float4 a, float4 b)
|
| 1013 |
+
{
|
| 1014 |
+
return make_float4(a.x / b.x, a.y / b.y, a.z / b.z, a.w / b.w);
|
| 1015 |
+
}
|
| 1016 |
+
inline __host__ __device__ void operator/=(float4 &a, float4 b)
|
| 1017 |
+
{
|
| 1018 |
+
a.x /= b.x;
|
| 1019 |
+
a.y /= b.y;
|
| 1020 |
+
a.z /= b.z;
|
| 1021 |
+
a.w /= b.w;
|
| 1022 |
+
}
|
| 1023 |
+
inline __host__ __device__ float4 operator/(float4 a, float b)
|
| 1024 |
+
{
|
| 1025 |
+
return make_float4(a.x / b, a.y / b, a.z / b, a.w / b);
|
| 1026 |
+
}
|
| 1027 |
+
inline __host__ __device__ void operator/=(float4 &a, float b)
|
| 1028 |
+
{
|
| 1029 |
+
a.x /= b;
|
| 1030 |
+
a.y /= b;
|
| 1031 |
+
a.z /= b;
|
| 1032 |
+
a.w /= b;
|
| 1033 |
+
}
|
| 1034 |
+
inline __host__ __device__ float4 operator/(float b, float4 a)
|
| 1035 |
+
{
|
| 1036 |
+
return make_float4(b / a.x, b / a.y, b / a.z, b / a.w);
|
| 1037 |
+
}
|
| 1038 |
+
|
| 1039 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1040 |
+
// min
|
| 1041 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1042 |
+
|
| 1043 |
+
inline __host__ __device__ float2 fminf(float2 a, float2 b)
|
| 1044 |
+
{
|
| 1045 |
+
return make_float2(fminf(a.x,b.x), fminf(a.y,b.y));
|
| 1046 |
+
}
|
| 1047 |
+
inline __host__ __device__ float3 fminf(float3 a, float3 b)
|
| 1048 |
+
{
|
| 1049 |
+
return make_float3(fminf(a.x,b.x), fminf(a.y,b.y), fminf(a.z,b.z));
|
| 1050 |
+
}
|
| 1051 |
+
inline __host__ __device__ float4 fminf(float4 a, float4 b)
|
| 1052 |
+
{
|
| 1053 |
+
return make_float4(fminf(a.x,b.x), fminf(a.y,b.y), fminf(a.z,b.z), fminf(a.w,b.w));
|
| 1054 |
+
}
|
| 1055 |
+
|
| 1056 |
+
inline __host__ __device__ int2 min(int2 a, int2 b)
|
| 1057 |
+
{
|
| 1058 |
+
return make_int2(min(a.x,b.x), min(a.y,b.y));
|
| 1059 |
+
}
|
| 1060 |
+
inline __host__ __device__ int3 min(int3 a, int3 b)
|
| 1061 |
+
{
|
| 1062 |
+
return make_int3(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z));
|
| 1063 |
+
}
|
| 1064 |
+
inline __host__ __device__ int4 min(int4 a, int4 b)
|
| 1065 |
+
{
|
| 1066 |
+
return make_int4(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z), min(a.w,b.w));
|
| 1067 |
+
}
|
| 1068 |
+
|
| 1069 |
+
inline __host__ __device__ uint2 min(uint2 a, uint2 b)
|
| 1070 |
+
{
|
| 1071 |
+
return make_uint2(min(a.x,b.x), min(a.y,b.y));
|
| 1072 |
+
}
|
| 1073 |
+
inline __host__ __device__ uint3 min(uint3 a, uint3 b)
|
| 1074 |
+
{
|
| 1075 |
+
return make_uint3(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z));
|
| 1076 |
+
}
|
| 1077 |
+
inline __host__ __device__ uint4 min(uint4 a, uint4 b)
|
| 1078 |
+
{
|
| 1079 |
+
return make_uint4(min(a.x,b.x), min(a.y,b.y), min(a.z,b.z), min(a.w,b.w));
|
| 1080 |
+
}
|
| 1081 |
+
|
| 1082 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1083 |
+
// max
|
| 1084 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1085 |
+
|
| 1086 |
+
inline __host__ __device__ float2 fmaxf(float2 a, float2 b)
|
| 1087 |
+
{
|
| 1088 |
+
return make_float2(fmaxf(a.x,b.x), fmaxf(a.y,b.y));
|
| 1089 |
+
}
|
| 1090 |
+
inline __host__ __device__ float3 fmaxf(float3 a, float3 b)
|
| 1091 |
+
{
|
| 1092 |
+
return make_float3(fmaxf(a.x,b.x), fmaxf(a.y,b.y), fmaxf(a.z,b.z));
|
| 1093 |
+
}
|
| 1094 |
+
inline __host__ __device__ float4 fmaxf(float4 a, float4 b)
|
| 1095 |
+
{
|
| 1096 |
+
return make_float4(fmaxf(a.x,b.x), fmaxf(a.y,b.y), fmaxf(a.z,b.z), fmaxf(a.w,b.w));
|
| 1097 |
+
}
|
| 1098 |
+
|
| 1099 |
+
inline __host__ __device__ int2 max(int2 a, int2 b)
|
| 1100 |
+
{
|
| 1101 |
+
return make_int2(max(a.x,b.x), max(a.y,b.y));
|
| 1102 |
+
}
|
| 1103 |
+
inline __host__ __device__ int3 max(int3 a, int3 b)
|
| 1104 |
+
{
|
| 1105 |
+
return make_int3(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z));
|
| 1106 |
+
}
|
| 1107 |
+
inline __host__ __device__ int4 max(int4 a, int4 b)
|
| 1108 |
+
{
|
| 1109 |
+
return make_int4(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z), max(a.w,b.w));
|
| 1110 |
+
}
|
| 1111 |
+
|
| 1112 |
+
inline __host__ __device__ uint2 max(uint2 a, uint2 b)
|
| 1113 |
+
{
|
| 1114 |
+
return make_uint2(max(a.x,b.x), max(a.y,b.y));
|
| 1115 |
+
}
|
| 1116 |
+
inline __host__ __device__ uint3 max(uint3 a, uint3 b)
|
| 1117 |
+
{
|
| 1118 |
+
return make_uint3(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z));
|
| 1119 |
+
}
|
| 1120 |
+
inline __host__ __device__ uint4 max(uint4 a, uint4 b)
|
| 1121 |
+
{
|
| 1122 |
+
return make_uint4(max(a.x,b.x), max(a.y,b.y), max(a.z,b.z), max(a.w,b.w));
|
| 1123 |
+
}
|
| 1124 |
+
|
| 1125 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1126 |
+
// lerp
|
| 1127 |
+
// - linear interpolation between a and b, based on value t in [0, 1] range
|
| 1128 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1129 |
+
|
| 1130 |
+
inline __device__ __host__ float lerp(float a, float b, float t)
|
| 1131 |
+
{
|
| 1132 |
+
return a + t*(b-a);
|
| 1133 |
+
}
|
| 1134 |
+
inline __device__ __host__ float2 lerp(float2 a, float2 b, float t)
|
| 1135 |
+
{
|
| 1136 |
+
return a + t*(b-a);
|
| 1137 |
+
}
|
| 1138 |
+
inline __device__ __host__ float3 lerp(float3 a, float3 b, float t)
|
| 1139 |
+
{
|
| 1140 |
+
return a + t*(b-a);
|
| 1141 |
+
}
|
| 1142 |
+
inline __device__ __host__ float4 lerp(float4 a, float4 b, float t)
|
| 1143 |
+
{
|
| 1144 |
+
return a + t*(b-a);
|
| 1145 |
+
}
|
| 1146 |
+
|
| 1147 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1148 |
+
// clamp
|
| 1149 |
+
// - clamp the value v to be in the range [a, b]
|
| 1150 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1151 |
+
|
| 1152 |
+
inline __device__ __host__ float clamp(float f, float a, float b)
|
| 1153 |
+
{
|
| 1154 |
+
return fmaxf(a, fminf(f, b));
|
| 1155 |
+
}
|
| 1156 |
+
inline __device__ __host__ int clamp(int f, int a, int b)
|
| 1157 |
+
{
|
| 1158 |
+
return max(a, min(f, b));
|
| 1159 |
+
}
|
| 1160 |
+
inline __device__ __host__ uint clamp(uint f, uint a, uint b)
|
| 1161 |
+
{
|
| 1162 |
+
return max(a, min(f, b));
|
| 1163 |
+
}
|
| 1164 |
+
|
| 1165 |
+
inline __device__ __host__ float2 clamp(float2 v, float a, float b)
|
| 1166 |
+
{
|
| 1167 |
+
return make_float2(clamp(v.x, a, b), clamp(v.y, a, b));
|
| 1168 |
+
}
|
| 1169 |
+
inline __device__ __host__ float2 clamp(float2 v, float2 a, float2 b)
|
| 1170 |
+
{
|
| 1171 |
+
return make_float2(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y));
|
| 1172 |
+
}
|
| 1173 |
+
inline __device__ __host__ float3 clamp(float3 v, float a, float b)
|
| 1174 |
+
{
|
| 1175 |
+
return make_float3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b));
|
| 1176 |
+
}
|
| 1177 |
+
inline __device__ __host__ float3 clamp(float3 v, float3 a, float3 b)
|
| 1178 |
+
{
|
| 1179 |
+
return make_float3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z));
|
| 1180 |
+
}
|
| 1181 |
+
inline __device__ __host__ float4 clamp(float4 v, float a, float b)
|
| 1182 |
+
{
|
| 1183 |
+
return make_float4(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b), clamp(v.w, a, b));
|
| 1184 |
+
}
|
| 1185 |
+
inline __device__ __host__ float4 clamp(float4 v, float4 a, float4 b)
|
| 1186 |
+
{
|
| 1187 |
+
return make_float4(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z), clamp(v.w, a.w, b.w));
|
| 1188 |
+
}
|
| 1189 |
+
|
| 1190 |
+
inline __device__ __host__ int2 clamp(int2 v, int a, int b)
|
| 1191 |
+
{
|
| 1192 |
+
return make_int2(clamp(v.x, a, b), clamp(v.y, a, b));
|
| 1193 |
+
}
|
| 1194 |
+
inline __device__ __host__ int2 clamp(int2 v, int2 a, int2 b)
|
| 1195 |
+
{
|
| 1196 |
+
return make_int2(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y));
|
| 1197 |
+
}
|
| 1198 |
+
inline __device__ __host__ int3 clamp(int3 v, int a, int b)
|
| 1199 |
+
{
|
| 1200 |
+
return make_int3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b));
|
| 1201 |
+
}
|
| 1202 |
+
inline __device__ __host__ int3 clamp(int3 v, int3 a, int3 b)
|
| 1203 |
+
{
|
| 1204 |
+
return make_int3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z));
|
| 1205 |
+
}
|
| 1206 |
+
inline __device__ __host__ int4 clamp(int4 v, int a, int b)
|
| 1207 |
+
{
|
| 1208 |
+
return make_int4(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b), clamp(v.w, a, b));
|
| 1209 |
+
}
|
| 1210 |
+
inline __device__ __host__ int4 clamp(int4 v, int4 a, int4 b)
|
| 1211 |
+
{
|
| 1212 |
+
return make_int4(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z), clamp(v.w, a.w, b.w));
|
| 1213 |
+
}
|
| 1214 |
+
|
| 1215 |
+
inline __device__ __host__ uint2 clamp(uint2 v, uint a, uint b)
|
| 1216 |
+
{
|
| 1217 |
+
return make_uint2(clamp(v.x, a, b), clamp(v.y, a, b));
|
| 1218 |
+
}
|
| 1219 |
+
inline __device__ __host__ uint2 clamp(uint2 v, uint2 a, uint2 b)
|
| 1220 |
+
{
|
| 1221 |
+
return make_uint2(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y));
|
| 1222 |
+
}
|
| 1223 |
+
inline __device__ __host__ uint3 clamp(uint3 v, uint a, uint b)
|
| 1224 |
+
{
|
| 1225 |
+
return make_uint3(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b));
|
| 1226 |
+
}
|
| 1227 |
+
inline __device__ __host__ uint3 clamp(uint3 v, uint3 a, uint3 b)
|
| 1228 |
+
{
|
| 1229 |
+
return make_uint3(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z));
|
| 1230 |
+
}
|
| 1231 |
+
inline __device__ __host__ uint4 clamp(uint4 v, uint a, uint b)
|
| 1232 |
+
{
|
| 1233 |
+
return make_uint4(clamp(v.x, a, b), clamp(v.y, a, b), clamp(v.z, a, b), clamp(v.w, a, b));
|
| 1234 |
+
}
|
| 1235 |
+
inline __device__ __host__ uint4 clamp(uint4 v, uint4 a, uint4 b)
|
| 1236 |
+
{
|
| 1237 |
+
return make_uint4(clamp(v.x, a.x, b.x), clamp(v.y, a.y, b.y), clamp(v.z, a.z, b.z), clamp(v.w, a.w, b.w));
|
| 1238 |
+
}
|
| 1239 |
+
|
| 1240 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1241 |
+
// dot product
|
| 1242 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1243 |
+
|
| 1244 |
+
inline __host__ __device__ float dot(float2 a, float2 b)
|
| 1245 |
+
{
|
| 1246 |
+
return a.x * b.x + a.y * b.y;
|
| 1247 |
+
}
|
| 1248 |
+
inline __host__ __device__ float dot(float3 a, float3 b)
|
| 1249 |
+
{
|
| 1250 |
+
return a.x * b.x + a.y * b.y + a.z * b.z;
|
| 1251 |
+
}
|
| 1252 |
+
inline __host__ __device__ float dot(float4 a, float4 b)
|
| 1253 |
+
{
|
| 1254 |
+
return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w;
|
| 1255 |
+
}
|
| 1256 |
+
|
| 1257 |
+
inline __host__ __device__ int dot(int2 a, int2 b)
|
| 1258 |
+
{
|
| 1259 |
+
return a.x * b.x + a.y * b.y;
|
| 1260 |
+
}
|
| 1261 |
+
inline __host__ __device__ int dot(int3 a, int3 b)
|
| 1262 |
+
{
|
| 1263 |
+
return a.x * b.x + a.y * b.y + a.z * b.z;
|
| 1264 |
+
}
|
| 1265 |
+
inline __host__ __device__ int dot(int4 a, int4 b)
|
| 1266 |
+
{
|
| 1267 |
+
return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w;
|
| 1268 |
+
}
|
| 1269 |
+
|
| 1270 |
+
inline __host__ __device__ uint dot(uint2 a, uint2 b)
|
| 1271 |
+
{
|
| 1272 |
+
return a.x * b.x + a.y * b.y;
|
| 1273 |
+
}
|
| 1274 |
+
inline __host__ __device__ uint dot(uint3 a, uint3 b)
|
| 1275 |
+
{
|
| 1276 |
+
return a.x * b.x + a.y * b.y + a.z * b.z;
|
| 1277 |
+
}
|
| 1278 |
+
inline __host__ __device__ uint dot(uint4 a, uint4 b)
|
| 1279 |
+
{
|
| 1280 |
+
return a.x * b.x + a.y * b.y + a.z * b.z + a.w * b.w;
|
| 1281 |
+
}
|
| 1282 |
+
|
| 1283 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1284 |
+
// length
|
| 1285 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1286 |
+
|
| 1287 |
+
inline __host__ __device__ float length(float2 v)
|
| 1288 |
+
{
|
| 1289 |
+
return sqrtf(dot(v, v));
|
| 1290 |
+
}
|
| 1291 |
+
inline __host__ __device__ float length(float3 v)
|
| 1292 |
+
{
|
| 1293 |
+
return sqrtf(dot(v, v));
|
| 1294 |
+
}
|
| 1295 |
+
inline __host__ __device__ float length(float4 v)
|
| 1296 |
+
{
|
| 1297 |
+
return sqrtf(dot(v, v));
|
| 1298 |
+
}
|
| 1299 |
+
|
| 1300 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1301 |
+
// normalize
|
| 1302 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1303 |
+
|
| 1304 |
+
inline __host__ __device__ float2 normalize(float2 v)
|
| 1305 |
+
{
|
| 1306 |
+
float invLen = rsqrtf(dot(v, v));
|
| 1307 |
+
return v * invLen;
|
| 1308 |
+
}
|
| 1309 |
+
inline __host__ __device__ float3 normalize(float3 v)
|
| 1310 |
+
{
|
| 1311 |
+
float invLen = rsqrtf(dot(v, v));
|
| 1312 |
+
return v * invLen;
|
| 1313 |
+
}
|
| 1314 |
+
inline __host__ __device__ float4 normalize(float4 v)
|
| 1315 |
+
{
|
| 1316 |
+
float invLen = rsqrtf(dot(v, v));
|
| 1317 |
+
return v * invLen;
|
| 1318 |
+
}
|
| 1319 |
+
|
| 1320 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1321 |
+
// floor
|
| 1322 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1323 |
+
|
| 1324 |
+
inline __host__ __device__ float2 floorf(float2 v)
|
| 1325 |
+
{
|
| 1326 |
+
return make_float2(floorf(v.x), floorf(v.y));
|
| 1327 |
+
}
|
| 1328 |
+
inline __host__ __device__ float3 floorf(float3 v)
|
| 1329 |
+
{
|
| 1330 |
+
return make_float3(floorf(v.x), floorf(v.y), floorf(v.z));
|
| 1331 |
+
}
|
| 1332 |
+
inline __host__ __device__ float4 floorf(float4 v)
|
| 1333 |
+
{
|
| 1334 |
+
return make_float4(floorf(v.x), floorf(v.y), floorf(v.z), floorf(v.w));
|
| 1335 |
+
}
|
| 1336 |
+
|
| 1337 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1338 |
+
// frac - returns the fractional portion of a scalar or each vector component
|
| 1339 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1340 |
+
|
| 1341 |
+
inline __host__ __device__ float fracf(float v)
|
| 1342 |
+
{
|
| 1343 |
+
return v - floorf(v);
|
| 1344 |
+
}
|
| 1345 |
+
inline __host__ __device__ float2 fracf(float2 v)
|
| 1346 |
+
{
|
| 1347 |
+
return make_float2(fracf(v.x), fracf(v.y));
|
| 1348 |
+
}
|
| 1349 |
+
inline __host__ __device__ float3 fracf(float3 v)
|
| 1350 |
+
{
|
| 1351 |
+
return make_float3(fracf(v.x), fracf(v.y), fracf(v.z));
|
| 1352 |
+
}
|
| 1353 |
+
inline __host__ __device__ float4 fracf(float4 v)
|
| 1354 |
+
{
|
| 1355 |
+
return make_float4(fracf(v.x), fracf(v.y), fracf(v.z), fracf(v.w));
|
| 1356 |
+
}
|
| 1357 |
+
|
| 1358 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1359 |
+
// fmod
|
| 1360 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1361 |
+
|
| 1362 |
+
inline __host__ __device__ float2 fmodf(float2 a, float2 b)
|
| 1363 |
+
{
|
| 1364 |
+
return make_float2(fmodf(a.x, b.x), fmodf(a.y, b.y));
|
| 1365 |
+
}
|
| 1366 |
+
inline __host__ __device__ float3 fmodf(float3 a, float3 b)
|
| 1367 |
+
{
|
| 1368 |
+
return make_float3(fmodf(a.x, b.x), fmodf(a.y, b.y), fmodf(a.z, b.z));
|
| 1369 |
+
}
|
| 1370 |
+
inline __host__ __device__ float4 fmodf(float4 a, float4 b)
|
| 1371 |
+
{
|
| 1372 |
+
return make_float4(fmodf(a.x, b.x), fmodf(a.y, b.y), fmodf(a.z, b.z), fmodf(a.w, b.w));
|
| 1373 |
+
}
|
| 1374 |
+
|
| 1375 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1376 |
+
// absolute value
|
| 1377 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1378 |
+
|
| 1379 |
+
inline __host__ __device__ float2 fabs(float2 v)
|
| 1380 |
+
{
|
| 1381 |
+
return make_float2(fabs(v.x), fabs(v.y));
|
| 1382 |
+
}
|
| 1383 |
+
inline __host__ __device__ float3 fabs(float3 v)
|
| 1384 |
+
{
|
| 1385 |
+
return make_float3(fabs(v.x), fabs(v.y), fabs(v.z));
|
| 1386 |
+
}
|
| 1387 |
+
inline __host__ __device__ float4 fabs(float4 v)
|
| 1388 |
+
{
|
| 1389 |
+
return make_float4(fabs(v.x), fabs(v.y), fabs(v.z), fabs(v.w));
|
| 1390 |
+
}
|
| 1391 |
+
|
| 1392 |
+
inline __host__ __device__ int2 abs(int2 v)
|
| 1393 |
+
{
|
| 1394 |
+
return make_int2(abs(v.x), abs(v.y));
|
| 1395 |
+
}
|
| 1396 |
+
inline __host__ __device__ int3 abs(int3 v)
|
| 1397 |
+
{
|
| 1398 |
+
return make_int3(abs(v.x), abs(v.y), abs(v.z));
|
| 1399 |
+
}
|
| 1400 |
+
inline __host__ __device__ int4 abs(int4 v)
|
| 1401 |
+
{
|
| 1402 |
+
return make_int4(abs(v.x), abs(v.y), abs(v.z), abs(v.w));
|
| 1403 |
+
}
|
| 1404 |
+
|
| 1405 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1406 |
+
// reflect
|
| 1407 |
+
// - returns reflection of incident ray I around surface normal N
|
| 1408 |
+
// - N should be normalized, reflected vector's length is equal to length of I
|
| 1409 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1410 |
+
|
| 1411 |
+
inline __host__ __device__ float3 reflect(float3 i, float3 n)
|
| 1412 |
+
{
|
| 1413 |
+
return i - 2.0f * n * dot(n,i);
|
| 1414 |
+
}
|
| 1415 |
+
|
| 1416 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1417 |
+
// cross product
|
| 1418 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1419 |
+
|
| 1420 |
+
inline __host__ __device__ float3 cross(float3 a, float3 b)
|
| 1421 |
+
{
|
| 1422 |
+
return make_float3(a.y*b.z - a.z*b.y, a.z*b.x - a.x*b.z, a.x*b.y - a.y*b.x);
|
| 1423 |
+
}
|
| 1424 |
+
|
| 1425 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1426 |
+
// smoothstep
|
| 1427 |
+
// - returns 0 if x < a
|
| 1428 |
+
// - returns 1 if x > b
|
| 1429 |
+
// - otherwise returns smooth interpolation between 0 and 1 based on x
|
| 1430 |
+
////////////////////////////////////////////////////////////////////////////////
|
| 1431 |
+
|
| 1432 |
+
inline __device__ __host__ float smoothstep(float a, float b, float x)
|
| 1433 |
+
{
|
| 1434 |
+
float y = clamp((x - a) / (b - a), 0.0f, 1.0f);
|
| 1435 |
+
return (y*y*(3.0f - (2.0f*y)));
|
| 1436 |
+
}
|
| 1437 |
+
inline __device__ __host__ float2 smoothstep(float2 a, float2 b, float2 x)
|
| 1438 |
+
{
|
| 1439 |
+
float2 y = clamp((x - a) / (b - a), 0.0f, 1.0f);
|
| 1440 |
+
return (y*y*(make_float2(3.0f) - (make_float2(2.0f)*y)));
|
| 1441 |
+
}
|
| 1442 |
+
inline __device__ __host__ float3 smoothstep(float3 a, float3 b, float3 x)
|
| 1443 |
+
{
|
| 1444 |
+
float3 y = clamp((x - a) / (b - a), 0.0f, 1.0f);
|
| 1445 |
+
return (y*y*(make_float3(3.0f) - (make_float3(2.0f)*y)));
|
| 1446 |
+
}
|
| 1447 |
+
inline __device__ __host__ float4 smoothstep(float4 a, float4 b, float4 x)
|
| 1448 |
+
{
|
| 1449 |
+
float4 y = clamp((x - a) / (b - a), 0.0f, 1.0f);
|
| 1450 |
+
return (y*y*(make_float4(3.0f) - (make_float4(2.0f)*y)));
|
| 1451 |
+
}
|
| 1452 |
+
|
| 1453 |
+
#endif
|
dva/mvp/extensions/utils/makefile
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
all:
|
| 2 |
+
python setup.py build_ext --inplace
|
dva/mvp/extensions/utils/setup.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from setuptools import setup
|
| 8 |
+
|
| 9 |
+
from torch.utils.cpp_extension import CUDAExtension, BuildExtension
|
| 10 |
+
|
| 11 |
+
if __name__ == "__main__":
|
| 12 |
+
import torch
|
| 13 |
+
setup(
|
| 14 |
+
name="utils",
|
| 15 |
+
ext_modules=[
|
| 16 |
+
CUDAExtension(
|
| 17 |
+
"utilslib",
|
| 18 |
+
sources=["utils.cpp", "utils_kernel.cu"],
|
| 19 |
+
extra_compile_args={
|
| 20 |
+
"nvcc": [
|
| 21 |
+
"-arch=sm_70",
|
| 22 |
+
"-std=c++14",
|
| 23 |
+
"-lineinfo",
|
| 24 |
+
]
|
| 25 |
+
}
|
| 26 |
+
)
|
| 27 |
+
],
|
| 28 |
+
cmdclass={"build_ext": BuildExtension}
|
| 29 |
+
)
|
dva/mvp/extensions/utils/utils.cpp
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
// All rights reserved.
|
| 3 |
+
//
|
| 4 |
+
// This source code is licensed under the license found in the
|
| 5 |
+
// LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
#include <torch/extension.h>
|
| 8 |
+
#include <c10/cuda/CUDAStream.h>
|
| 9 |
+
|
| 10 |
+
#include <vector>
|
| 11 |
+
|
| 12 |
+
void compute_raydirs_forward_cuda(
|
| 13 |
+
int N, int H, int W,
|
| 14 |
+
float * viewposim,
|
| 15 |
+
float * viewrotim,
|
| 16 |
+
float * focalim,
|
| 17 |
+
float * princptim,
|
| 18 |
+
float * pixelcoordsim,
|
| 19 |
+
float volradius,
|
| 20 |
+
float * raypos,
|
| 21 |
+
float * raydir,
|
| 22 |
+
float * tminmax,
|
| 23 |
+
cudaStream_t stream);
|
| 24 |
+
|
| 25 |
+
void compute_raydirs_backward_cuda(
|
| 26 |
+
int N, int H, int W,
|
| 27 |
+
float * viewposim,
|
| 28 |
+
float * viewrotim,
|
| 29 |
+
float * focalim,
|
| 30 |
+
float * princptim,
|
| 31 |
+
float * pixelcoordsim,
|
| 32 |
+
float volradius,
|
| 33 |
+
float * raypos,
|
| 34 |
+
float * raydir,
|
| 35 |
+
float * tminmax,
|
| 36 |
+
float * grad_viewposim,
|
| 37 |
+
float * grad_viewrotim,
|
| 38 |
+
float * grad_focalim,
|
| 39 |
+
float * grad_princptim,
|
| 40 |
+
cudaStream_t stream);
|
| 41 |
+
|
| 42 |
+
#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
|
| 43 |
+
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
|
| 44 |
+
#define CHECK_INPUT(x) CHECK_CUDA((x)); CHECK_CONTIGUOUS((x))
|
| 45 |
+
|
| 46 |
+
std::vector<torch::Tensor> compute_raydirs_forward(
|
| 47 |
+
torch::Tensor viewposim,
|
| 48 |
+
torch::Tensor viewrotim,
|
| 49 |
+
torch::Tensor focalim,
|
| 50 |
+
torch::Tensor princptim,
|
| 51 |
+
torch::optional<torch::Tensor> pixelcoordsim,
|
| 52 |
+
int W, int H,
|
| 53 |
+
float volradius,
|
| 54 |
+
torch::Tensor rayposim,
|
| 55 |
+
torch::Tensor raydirim,
|
| 56 |
+
torch::Tensor tminmaxim) {
|
| 57 |
+
CHECK_INPUT(viewposim);
|
| 58 |
+
CHECK_INPUT(viewrotim);
|
| 59 |
+
CHECK_INPUT(focalim);
|
| 60 |
+
CHECK_INPUT(princptim);
|
| 61 |
+
if (pixelcoordsim) { CHECK_INPUT(*pixelcoordsim); }
|
| 62 |
+
CHECK_INPUT(rayposim);
|
| 63 |
+
CHECK_INPUT(raydirim);
|
| 64 |
+
CHECK_INPUT(tminmaxim);
|
| 65 |
+
|
| 66 |
+
int N = viewposim.size(0);
|
| 67 |
+
assert(!pixelcoordsim || (pixelcoordsim.size(1) == H && pixelcoordsim.size(2) == W));
|
| 68 |
+
|
| 69 |
+
compute_raydirs_forward_cuda(N, H, W,
|
| 70 |
+
reinterpret_cast<float *>(viewposim.data_ptr()),
|
| 71 |
+
reinterpret_cast<float *>(viewrotim.data_ptr()),
|
| 72 |
+
reinterpret_cast<float *>(focalim.data_ptr()),
|
| 73 |
+
reinterpret_cast<float *>(princptim.data_ptr()),
|
| 74 |
+
pixelcoordsim ? reinterpret_cast<float *>(pixelcoordsim->data_ptr()) : nullptr,
|
| 75 |
+
volradius,
|
| 76 |
+
reinterpret_cast<float *>(rayposim.data_ptr()),
|
| 77 |
+
reinterpret_cast<float *>(raydirim.data_ptr()),
|
| 78 |
+
reinterpret_cast<float *>(tminmaxim.data_ptr()),
|
| 79 |
+
0);
|
| 80 |
+
|
| 81 |
+
return {};
|
| 82 |
+
}
|
| 83 |
+
|
| 84 |
+
std::vector<torch::Tensor> compute_raydirs_backward(
|
| 85 |
+
torch::Tensor viewposim,
|
| 86 |
+
torch::Tensor viewrotim,
|
| 87 |
+
torch::Tensor focalim,
|
| 88 |
+
torch::Tensor princptim,
|
| 89 |
+
torch::optional<torch::Tensor> pixelcoordsim,
|
| 90 |
+
int W, int H,
|
| 91 |
+
float volradius,
|
| 92 |
+
torch::Tensor rayposim,
|
| 93 |
+
torch::Tensor raydirim,
|
| 94 |
+
torch::Tensor tminmaxim,
|
| 95 |
+
torch::Tensor grad_viewpos,
|
| 96 |
+
torch::Tensor grad_viewrot,
|
| 97 |
+
torch::Tensor grad_focal,
|
| 98 |
+
torch::Tensor grad_princpt) {
|
| 99 |
+
CHECK_INPUT(viewposim);
|
| 100 |
+
CHECK_INPUT(viewrotim);
|
| 101 |
+
CHECK_INPUT(focalim);
|
| 102 |
+
CHECK_INPUT(princptim);
|
| 103 |
+
if (pixelcoordsim) { CHECK_INPUT(*pixelcoordsim); }
|
| 104 |
+
CHECK_INPUT(rayposim);
|
| 105 |
+
CHECK_INPUT(raydirim);
|
| 106 |
+
CHECK_INPUT(tminmaxim);
|
| 107 |
+
CHECK_INPUT(grad_viewpos);
|
| 108 |
+
CHECK_INPUT(grad_viewrot);
|
| 109 |
+
CHECK_INPUT(grad_focal);
|
| 110 |
+
CHECK_INPUT(grad_princpt);
|
| 111 |
+
|
| 112 |
+
int N = viewposim.size(0);
|
| 113 |
+
assert(!pixelcoordsim || (pixelcoordsim.size(1) == H && pixelcoordsim.size(2) == W));
|
| 114 |
+
|
| 115 |
+
compute_raydirs_backward_cuda(N, H, W,
|
| 116 |
+
reinterpret_cast<float *>(viewposim.data_ptr()),
|
| 117 |
+
reinterpret_cast<float *>(viewrotim.data_ptr()),
|
| 118 |
+
reinterpret_cast<float *>(focalim.data_ptr()),
|
| 119 |
+
reinterpret_cast<float *>(princptim.data_ptr()),
|
| 120 |
+
pixelcoordsim ? reinterpret_cast<float *>(pixelcoordsim->data_ptr()) : nullptr,
|
| 121 |
+
volradius,
|
| 122 |
+
reinterpret_cast<float *>(rayposim.data_ptr()),
|
| 123 |
+
reinterpret_cast<float *>(raydirim.data_ptr()),
|
| 124 |
+
reinterpret_cast<float *>(tminmaxim.data_ptr()),
|
| 125 |
+
reinterpret_cast<float *>(grad_viewpos.data_ptr()),
|
| 126 |
+
reinterpret_cast<float *>(grad_viewrot.data_ptr()),
|
| 127 |
+
reinterpret_cast<float *>(grad_focal.data_ptr()),
|
| 128 |
+
reinterpret_cast<float *>(grad_princpt.data_ptr()),
|
| 129 |
+
0);
|
| 130 |
+
|
| 131 |
+
return {};
|
| 132 |
+
}
|
| 133 |
+
|
| 134 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
| 135 |
+
m.def("compute_raydirs_forward", &compute_raydirs_forward, "raydirs forward (CUDA)");
|
| 136 |
+
m.def("compute_raydirs_backward", &compute_raydirs_backward, "raydirs backward (CUDA)");
|
| 137 |
+
}
|
dva/mvp/extensions/utils/utils.py
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import time
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
from torch.autograd import Function
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
|
| 15 |
+
try:
|
| 16 |
+
from . import utilslib
|
| 17 |
+
except:
|
| 18 |
+
import utilslib
|
| 19 |
+
|
| 20 |
+
class ComputeRaydirs(Function):
|
| 21 |
+
@staticmethod
|
| 22 |
+
def forward(self, viewpos, viewrot, focal, princpt, pixelcoords, volradius):
|
| 23 |
+
for tensor in [viewpos, viewrot, focal, princpt, pixelcoords]:
|
| 24 |
+
assert tensor.is_contiguous()
|
| 25 |
+
|
| 26 |
+
N = viewpos.size(0)
|
| 27 |
+
if isinstance(pixelcoords, tuple):
|
| 28 |
+
W, H = pixelcoords
|
| 29 |
+
pixelcoords = None
|
| 30 |
+
else:
|
| 31 |
+
H = pixelcoords.size(1)
|
| 32 |
+
W = pixelcoords.size(2)
|
| 33 |
+
|
| 34 |
+
raypos = torch.empty((N, H, W, 3), device=viewpos.device)
|
| 35 |
+
raydirs = torch.empty((N, H, W, 3), device=viewpos.device)
|
| 36 |
+
tminmax = torch.empty((N, H, W, 2), device=viewpos.device)
|
| 37 |
+
utilslib.compute_raydirs_forward(viewpos, viewrot, focal, princpt,
|
| 38 |
+
pixelcoords, W, H, volradius, raypos, raydirs, tminmax)
|
| 39 |
+
|
| 40 |
+
return raypos, raydirs, tminmax
|
| 41 |
+
|
| 42 |
+
@staticmethod
|
| 43 |
+
def backward(self, grad_raydirs, grad_tminmax):
|
| 44 |
+
return None, None, None, None, None, None
|
| 45 |
+
|
| 46 |
+
def compute_raydirs(viewpos, viewrot, focal, princpt, pixelcoords, volradius):
|
| 47 |
+
raypos, raydirs, tminmax = ComputeRaydirs.apply(viewpos, viewrot, focal, princpt, pixelcoords, volradius)
|
| 48 |
+
return raypos, raydirs, tminmax
|
| 49 |
+
|
| 50 |
+
class Rodrigues(nn.Module):
|
| 51 |
+
def __init__(self):
|
| 52 |
+
super(Rodrigues, self).__init__()
|
| 53 |
+
|
| 54 |
+
def forward(self, rvec):
|
| 55 |
+
theta = torch.sqrt(1e-5 + torch.sum(rvec ** 2, dim=1))
|
| 56 |
+
rvec = rvec / theta[:, None]
|
| 57 |
+
costh = torch.cos(theta)
|
| 58 |
+
sinth = torch.sin(theta)
|
| 59 |
+
return torch.stack((
|
| 60 |
+
rvec[:, 0] ** 2 + (1. - rvec[:, 0] ** 2) * costh,
|
| 61 |
+
rvec[:, 0] * rvec[:, 1] * (1. - costh) - rvec[:, 2] * sinth,
|
| 62 |
+
rvec[:, 0] * rvec[:, 2] * (1. - costh) + rvec[:, 1] * sinth,
|
| 63 |
+
|
| 64 |
+
rvec[:, 0] * rvec[:, 1] * (1. - costh) + rvec[:, 2] * sinth,
|
| 65 |
+
rvec[:, 1] ** 2 + (1. - rvec[:, 1] ** 2) * costh,
|
| 66 |
+
rvec[:, 1] * rvec[:, 2] * (1. - costh) - rvec[:, 0] * sinth,
|
| 67 |
+
|
| 68 |
+
rvec[:, 0] * rvec[:, 2] * (1. - costh) - rvec[:, 1] * sinth,
|
| 69 |
+
rvec[:, 1] * rvec[:, 2] * (1. - costh) + rvec[:, 0] * sinth,
|
| 70 |
+
rvec[:, 2] ** 2 + (1. - rvec[:, 2] ** 2) * costh), dim=1).view(-1, 3, 3)
|
| 71 |
+
|
| 72 |
+
def gradcheck():
|
| 73 |
+
N = 2
|
| 74 |
+
H = 64
|
| 75 |
+
W = 64
|
| 76 |
+
k3 = 4
|
| 77 |
+
K = k3*k3*k3
|
| 78 |
+
|
| 79 |
+
M = 32
|
| 80 |
+
volradius = 1.
|
| 81 |
+
|
| 82 |
+
# generate random inputs
|
| 83 |
+
torch.manual_seed(1113)
|
| 84 |
+
|
| 85 |
+
rodrigues = Rodrigues()
|
| 86 |
+
|
| 87 |
+
_viewpos = torch.tensor([[-0.0, 0.0, -4.] for n in range(N)], device="cuda") + torch.randn(N, 3, device="cuda") * 0.1
|
| 88 |
+
viewrvec = torch.randn(N, 3, device="cuda") * 0.01
|
| 89 |
+
_viewrot = rodrigues(viewrvec)
|
| 90 |
+
|
| 91 |
+
_focal = torch.tensor([[W*4.0, W*4.0] for n in range(N)], device="cuda")
|
| 92 |
+
_princpt = torch.tensor([[W*0.5, H*0.5] for n in range(N)], device="cuda")
|
| 93 |
+
pixely, pixelx = torch.meshgrid(torch.arange(H, device="cuda").float(), torch.arange(W, device="cuda").float())
|
| 94 |
+
_pixelcoords = torch.stack([pixelx, pixely], dim=-1)[None, :, :, :].repeat(N, 1, 1, 1)
|
| 95 |
+
|
| 96 |
+
_viewpos = _viewpos.contiguous().detach().clone()
|
| 97 |
+
_viewpos.requires_grad = True
|
| 98 |
+
_viewrot = _viewrot.contiguous().detach().clone()
|
| 99 |
+
_viewrot.requires_grad = True
|
| 100 |
+
_focal = _focal.contiguous().detach().clone()
|
| 101 |
+
_focal.requires_grad = True
|
| 102 |
+
_princpt = _princpt.contiguous().detach().clone()
|
| 103 |
+
_princpt.requires_grad = True
|
| 104 |
+
_pixelcoords = _pixelcoords.contiguous().detach().clone()
|
| 105 |
+
_pixelcoords.requires_grad = True
|
| 106 |
+
|
| 107 |
+
max_len = 6.0
|
| 108 |
+
_stepsize = max_len / 15.5
|
| 109 |
+
|
| 110 |
+
params = [_viewpos, _viewrot, _focal, _princpt]
|
| 111 |
+
paramnames = ["viewpos", "viewrot", "focal", "princpt"]
|
| 112 |
+
|
| 113 |
+
########################### run pytorch version ###########################
|
| 114 |
+
|
| 115 |
+
viewpos = _viewpos
|
| 116 |
+
viewrot = _viewrot
|
| 117 |
+
focal = _focal
|
| 118 |
+
princpt = _princpt
|
| 119 |
+
pixelcoords = _pixelcoords
|
| 120 |
+
|
| 121 |
+
raypos = viewpos[:, None, None, :].repeat(1, H, W, 1)
|
| 122 |
+
|
| 123 |
+
raydir = (pixelcoords - princpt[:, None, None, :]) / focal[:, None, None, :]
|
| 124 |
+
raydir = torch.cat([raydir, torch.ones_like(raydir[:, :, :, 0:1])], dim=-1)
|
| 125 |
+
raydir = torch.sum(viewrot[:, None, None, :, :] * raydir[:, :, :, :, None], dim=-2)
|
| 126 |
+
raydir = raydir / torch.sqrt(torch.sum(raydir ** 2, dim=-1, keepdim=True))
|
| 127 |
+
|
| 128 |
+
t1 = (-1. - viewpos[:, None, None, :]) / raydir
|
| 129 |
+
t2 = ( 1. - viewpos[:, None, None, :]) / raydir
|
| 130 |
+
tmin = torch.max(torch.min(t1[..., 0], t2[..., 0]),
|
| 131 |
+
torch.max(torch.min(t1[..., 1], t2[..., 1]),
|
| 132 |
+
torch.min(t1[..., 2], t2[..., 2]))).clamp(min=0.)
|
| 133 |
+
tmax = torch.min(torch.max(t1[..., 0], t2[..., 0]),
|
| 134 |
+
torch.min(torch.max(t1[..., 1], t2[..., 1]),
|
| 135 |
+
torch.max(t1[..., 2], t2[..., 2])))
|
| 136 |
+
|
| 137 |
+
tminmax = torch.stack([tmin, tmax], dim=-1)
|
| 138 |
+
|
| 139 |
+
sample0 = raydir
|
| 140 |
+
|
| 141 |
+
torch.cuda.synchronize()
|
| 142 |
+
time1 = time.time()
|
| 143 |
+
|
| 144 |
+
sample0.backward(torch.ones_like(sample0))
|
| 145 |
+
|
| 146 |
+
torch.cuda.synchronize()
|
| 147 |
+
time2 = time.time()
|
| 148 |
+
|
| 149 |
+
grads0 = [p.grad.detach().clone() if p.grad is not None else None for p in params]
|
| 150 |
+
|
| 151 |
+
for p in params:
|
| 152 |
+
if p.grad is not None:
|
| 153 |
+
p.grad.detach_()
|
| 154 |
+
p.grad.zero_()
|
| 155 |
+
|
| 156 |
+
############################## run cuda version ###########################
|
| 157 |
+
|
| 158 |
+
viewpos = _viewpos
|
| 159 |
+
viewrot = _viewrot
|
| 160 |
+
focal = _focal
|
| 161 |
+
princpt = _princpt
|
| 162 |
+
pixelcoords = _pixelcoords
|
| 163 |
+
|
| 164 |
+
niter = 1
|
| 165 |
+
|
| 166 |
+
for p in params:
|
| 167 |
+
if p.grad is not None:
|
| 168 |
+
p.grad.detach_()
|
| 169 |
+
p.grad.zero_()
|
| 170 |
+
t0 = time.time()
|
| 171 |
+
torch.cuda.synchronize()
|
| 172 |
+
|
| 173 |
+
sample1 = compute_raydirs(viewpos, viewrot, focal, princpt, pixelcoords, volradius)[1]
|
| 174 |
+
|
| 175 |
+
t1 = time.time()
|
| 176 |
+
torch.cuda.synchronize()
|
| 177 |
+
|
| 178 |
+
print("-----------------------------------------------------------------")
|
| 179 |
+
print("{:>10} {:>10} {:>10} {:>10} {:>10} {:>10}".format("", "maxabsdiff", "dp", "index", "py", "cuda"))
|
| 180 |
+
ind = torch.argmax(torch.abs(sample0 - sample1))
|
| 181 |
+
print("{:<10} {:>10.5} {:>10.5} {:>10} {:>10.5} {:>10.5}".format(
|
| 182 |
+
"fwd",
|
| 183 |
+
torch.max(torch.abs(sample0 - sample1)).item(),
|
| 184 |
+
(torch.sum(sample0 * sample1) / torch.sqrt(torch.sum(sample0 * sample0) * torch.sum(sample1 * sample1))).item(),
|
| 185 |
+
ind.item(),
|
| 186 |
+
sample0.view(-1)[ind].item(),
|
| 187 |
+
sample1.view(-1)[ind].item()))
|
| 188 |
+
|
| 189 |
+
sample1.backward(torch.ones_like(sample1), retain_graph=True)
|
| 190 |
+
|
| 191 |
+
torch.cuda.synchronize()
|
| 192 |
+
t2 = time.time()
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
print("{:<10} {:10.5} {:10.5} {:10.5}".format("time", tf / niter, tb / niter, (tf + tb) / niter))
|
| 196 |
+
grads1 = [p.grad.detach().clone() if p.grad is not None else None for p in params]
|
| 197 |
+
|
| 198 |
+
############# compare results #############
|
| 199 |
+
|
| 200 |
+
for p, g0, g1 in zip(paramnames, grads0, grads1):
|
| 201 |
+
ind = torch.argmax(torch.abs(g0 - g1))
|
| 202 |
+
print("{:<10} {:>10.5} {:>10.5} {:>10} {:>10.5} {:>10.5}".format(
|
| 203 |
+
p,
|
| 204 |
+
torch.max(torch.abs(g0 - g1)).item(),
|
| 205 |
+
(torch.sum(g0 * g1) / torch.sqrt(torch.sum(g0 * g0) * torch.sum(g1 * g1))).item(),
|
| 206 |
+
ind.item(),
|
| 207 |
+
g0.view(-1)[ind].item(),
|
| 208 |
+
g1.view(-1)[ind].item()))
|
| 209 |
+
|
| 210 |
+
if __name__ == "__main__":
|
| 211 |
+
gradcheck()
|