"""Gradio demo for DRUM: Sim2Real translation of SynLiDAR samples into KITTI-style scans."""

import einops
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import spaces
import torch
from torch.utils.data import DataLoader

from dataset import SynLiDAR
from drum import DRUM

DESCRIPTION = """
DRUM: Diffusion-based Raydrop-aware Unpaired Mapping for Sim2Real LiDAR Segmentation
ICRA 2026
Tomoya Miyawaki1     Kazuto Nakashima1     Yumi Iwashita2     Ryo Kurazume1
1Kyushu University     2NASA Jet Propulsion Laboratory
Project | Paper | Code

This demo performs Sim2Real LiDAR translation by our method; converting random simulation samples of the SynLiDAR dataset into pseudo-real samples in the KITTI style.

"""

# Depth value mapped to the top of the colormap: depths are divided by this and
# `colorize` clamps to [0, 1], so anything farther saturates.
# NOTE(review): presumably meters (KITTI-style max range) — confirm.
MAX_DEPTH = 80

# Styles for the .head/.title/.description classes; passed to `demo.launch`.
CSS = (
    " .head { text-align: center; display: block; font-size: var(--text-xl); }"
    " .title { font-size: var(--text-xxl); font-weight: bold; margin-top: 2rem; }"
    " .description { font-size: var(--text-lg); } "
)

# Prefer CUDA, then Apple MPS, then fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
device = torch.device(device)

# Pretrained R2DM diffusion model, its LiDAR conversion helpers, and its config.
# trust_repo=True executes the hubconf code of the kazuto1011/r2dm repository.
r2dm, lidar_utils, cfg = torch.hub.load(
    repo_or_dir="kazuto1011/r2dm",
    model="pretrained_r2dm",
    config="r2dm-h-kittiraw-300k-icra2026",
    device=device,
    trust_repo=True,
)
drum = DRUM(r2dm)
# Simulation source samples, loaded at the resolution the model is configured for.
dataset = SynLiDAR(shape=cfg.data.resolution, revision="sub")


def colorize(img: torch.Tensor) -> np.ndarray:
    """Map values in [0, 1] to uint8 RGB using matplotlib's "turbo" colormap.

    The tensor is detached, cast to float32, clamped to [0, 1] and moved to the
    CPU. The result has the input's shape plus a trailing RGB axis of size 3.
    """
    img = img.detach().float().clamp(0, 1).cpu().numpy()
    rgb = plt.get_cmap("turbo")(img)[..., :3]  # drop the colormap's alpha channel
    return (rgb * 255).astype(np.uint8)


def _tile_batch(img: torch.Tensor) -> np.ndarray:
    """Colorize a (B, C, H, W) batch into one tall RGB image of shape (B*C*H, W, 3).

    Within each sample the channels are stacked vertically, and the samples are
    then stacked vertically as well.
    """
    return colorize(einops.rearrange(img, "b c h w -> (b c h) w"))


@spaces.GPU
def run(
    nfe: int,
    num_harmonization_steps: int,
    batch_size: int,
    progress=gr.Progress(),
):
    """Translate a random batch of SynLiDAR samples into pseudo-real scans with DRUM.

    Args:
        nfe: Number of function evaluations, forwarded to DRUM as ``num_steps``.
        num_harmonization_steps: Forwarded to DRUM as ``num_harmonization``.
        batch_size: Number of simulation samples drawn for this call.
        progress: Gradio progress tracker; its ``tqdm`` wraps ``range(nfe)`` and
            is handed to DRUM as ``progress_bar``.

    Returns:
        ``(simulation_image, pseudo_real_image)`` as uint8 RGB arrays, the pair
        consumed by the ``gr.ImageSlider`` output.
    """
    # Dropdown values arrive from the UI/API; cast all of them consistently
    # (DataLoader rejects a non-int batch_size).
    nfe = int(nfe)
    loader = DataLoader(
        dataset, batch_size=int(batch_size), shuffle=True, drop_last=False
    )
    item = next(iter(loader))  # one random batch per click

    depth_sim = item["depth"].to(device)
    depth_sim = lidar_utils.convert_depth(depth_sim)
    depth_sim = lidar_utils.normalize(depth_sim)

    result = drum(
        depth_sim,
        num_steps=nfe,
        num_harmonization=int(num_harmonization_steps),
        progress_bar=progress.tqdm(range(nfe)),
    )

    # Pseudo-real output: channel 0 is reverted to depth and scaled for display;
    # the remaining channel(s) are shown as denormalized.
    x = lidar_utils.denormalize(result)
    x[:, [0]] = lidar_utils.revert_depth(x[:, [0]]) / MAX_DEPTH

    # Simulated input, scaled the same way. It is zero-padded along the channel
    # axis to double its channel count so both panels tile to the same layout.
    # NOTE(review): presumably the simulated input carries depth only — confirm.
    y = lidar_utils.denormalize(depth_sim)
    y = lidar_utils.revert_depth(y) / MAX_DEPTH
    y = torch.cat([y, torch.zeros_like(y)], dim=1)

    return _tile_batch(y), _tile_batch(x)


with gr.Blocks() as demo:
    gr.HTML(DESCRIPTION)
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("**Hyperparameters:**")
            nfe = gr.Dropdown(
                choices=[2**i for i in range(0, 9)],
                value=32,
                label="Number of function evaluations (NFE)",
            )
            num_harmonization_steps = gr.Dropdown(
                choices=list(range(1, 9)),
                value=3,
                label="Number of harmonization steps",
            )
            batch_size = gr.Dropdown(
                choices=[1, 2, 4, 8],
                value=8,
                label="Batch size",
            )
            run_button = gr.Button("Run DRUM 🥁", variant="primary")
        with gr.Column(scale=2):
            gr.Markdown(
                "**Result:** depth and reflectance channels (and batch samples) "
                "are stacked vertically."
            )
            comparison = gr.ImageSlider(
                label="Simulation (left) vs Pseudo-real (right)",
                type="numpy",
                max_height=1000,
            )
    run_button.click(
        fn=run,
        inputs=[nfe, num_harmonization_steps, batch_size],
        outputs=[comparison],
    )

demo.queue()
demo.launch(css=CSS, theme=gr.themes.Ocean())