import einops
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import spaces
import torch
from torch.utils.data import DataLoader
from dataset import SynLiDAR
from drum import DRUM
# Header markup rendered by `gr.HTML` at the top of the page.
# The `head`, `title` and `description` classes are styled by the CSS passed to
# `demo.launch`; without the elements below the text collapses into a single
# unstyled line, because HTML ignores bare newlines.
DESCRIPTION = """
<div class="head">
    <div class="title">DRUM: Diffusion-based Raydrop-aware Unpaired Mapping for Sim2Real LiDAR Segmentation</div>
    <div>ICRA 2026</div>
    <div><sup>1</sup>Kyushu University</div>
    <div><sup>2</sup>NASA Jet Propulsion Laboratory</div>
    <div class="description">
        This demo performs Sim2Real LiDAR translation by our method;
        converting random simulation samples of the SynLiDAR dataset into pseudo-real samples in the KITTI style.
    </div>
</div>
"""
# Pick the compute backend once at import time: CUDA first, then Apple's MPS,
# and the CPU as the last resort. The checks short-circuit in that order.
device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
# Fetch the pretrained model through torch.hub. The entry point returns the
# model itself, a helper object used for depth conversion and normalization,
# and the configuration the checkpoint was built with.
r2dm, lidar_utils, cfg = torch.hub.load(
    "kazuto1011/r2dm",
    "pretrained_r2dm",
    config="r2dm-h-kittiraw-300k-icra2026",
    device=device,
    trust_repo=True,
)

# DRUM wraps the pretrained model; `run` calls it to translate each batch.
drum = DRUM(r2dm)

# Simulation samples, created with the resolution stored in the model config.
dataset = SynLiDAR(shape=cfg.data.resolution, revision="sub")
def colorize(img: torch.Tensor) -> np.ndarray:
    """Render a tensor of intensities as 8-bit RGB with the "turbo" colormap.

    Values are clamped to ``[0, 1]`` before the colormap lookup. The colormap
    appends a trailing RGBA axis to the input shape; the alpha channel is
    dropped, so the result has the input shape plus a final axis of size 3.

    Args:
        img: Tensor of intensities on any device; it is detached, cast to
            float32 and moved to the CPU before conversion.

    Returns:
        A ``uint8`` array of RGB values in ``[0, 255]``.
    """
    values = torch.clamp(img.detach().float(), min=0, max=1)
    values = values.cpu().numpy()
    turbo = plt.get_cmap("turbo")
    rgba = turbo(values)
    # Keep R, G, B only; the cast to uint8 truncates rather than rounds.
    return (rgba[..., :3] * 255).astype(np.uint8)
@spaces.GPU
def run(
    nfe: int, num_harmonization_steps: int, batch_size: int, progress=gr.Progress()
) -> tuple[np.ndarray, np.ndarray]:
    """Translate one random batch of simulated scans and render both versions.

    A shuffled batch is drawn from ``dataset``, converted and normalized with
    ``lidar_utils``, passed through ``drum``, and both the input and the output
    are colorized for side-by-side comparison.

    Args:
        nfe: Number of sampling steps, forwarded to ``drum`` as ``num_steps``;
            it is also the length of the progress bar.
        num_harmonization_steps: Forwarded to ``drum`` as ``num_harmonization``.
        batch_size: Number of samples drawn from ``dataset``.
        progress: Gradio progress tracker; its ``tqdm`` wrapper is handed to
            ``drum`` as ``progress_bar``.

    Returns:
        ``(simulation, pseudo_real)`` as ``uint8`` RGB arrays of shape
        ``(batch * channels * height, width, 3)``: the channels of each sample
        are stacked along the height axis, then the samples are stacked below
        one another.
    """
    # Coerce every UI value up front. DataLoader raises ValueError for a
    # non-int batch size (e.g. 8.0 or "8" coming from an API client), so it
    # needs the same treatment as the other two values.
    nfe = int(nfe)
    num_harmonization_steps = int(num_harmonization_steps)
    batch_size = int(batch_size)

    # NOTE(review): 80 is presumably the maximum depth in metres, used to scale
    # reverted depth into [0, 1] for `colorize` -- confirm against lidar_utils.
    depth_scale = 80

    # A fresh shuffled loader per call yields a different random batch each time.
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False)
    item = next(iter(loader))

    depth_sim = item["depth"].to(device)
    depth_sim = lidar_utils.convert_depth(depth_sim)
    depth_sim = lidar_utils.normalize(depth_sim)

    result = drum(
        depth_sim,
        num_steps=nfe,
        num_harmonization=num_harmonization_steps,
        progress_bar=progress.tqdm(range(nfe)),
    )

    # Pseudo-real output: only channel 0 (depth) is reverted and rescaled; the
    # remaining channel(s) are shown as denormalized.
    x = lidar_utils.denormalize(result)
    x[:, [0]] = lidar_utils.revert_depth(x[:, [0]]) / depth_scale
    x = einops.rearrange(x, "b c h w -> b (c h) w")

    # Simulation input: pad with an all-zero copy along the channel axis so the
    # image has the same layout as the output.
    y = lidar_utils.denormalize(depth_sim)
    y = lidar_utils.revert_depth(y) / depth_scale
    y = torch.cat([y, torch.zeros_like(y)], dim=1)
    y = einops.rearrange(y, "b c h w -> b (c h) w")

    x = colorize(x)
    y = colorize(y)

    # Stack the samples of the batch vertically into a single image.
    x = einops.rearrange(x, "b ch w c -> (b ch) w c")
    y = einops.rearrange(y, "b ch w c -> (b ch) w c")

    # Order matches the slider label: simulation first, pseudo-real second.
    return y, x
# Page layout: header on top, then a row with the hyperparameter controls on
# the left and the before/after comparison slider on the right.
with gr.Blocks() as demo:
    gr.HTML(DESCRIPTION)
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("**Hyperparameters:**")
            nfe = gr.Dropdown(
                choices=[2**i for i in range(0, 9)],  # 1, 2, 4, ..., 256
                value=32,
                label="Number of function evaluations (NFE)",
            )
            num_harmonization_steps = gr.Dropdown(
                # A concrete list, consistent with the other dropdowns.
                choices=list(range(1, 9)),
                value=3,
                label="Number of harmonization steps",
            )
            batch_size = gr.Dropdown(
                choices=[1, 2, 4, 8],
                value=8,
                label="Batch size",
            )
            run_button = gr.Button("Run DRUM 🥁", variant="primary")
        with gr.Column(scale=2):
            # `run` lays the channels of each sample out along the image
            # height (depth first), so the caption says "vertically".
            gr.Markdown(
                "**Result:** depth (top) and reflectance (bottom) channels are stacked vertically for each sample."
            )
            comparison = gr.ImageSlider(
                label="Simulation (left) vs Pseudo-real (right)",
                type="numpy",
                max_height=1000,
            )
    # `run` returns a (simulation, pseudo-real) pair, which the single
    # ImageSlider output consumes as its two images.
    run_button.click(
        fn=run,
        inputs=[nfe, num_harmonization_steps, batch_size],
        outputs=[comparison],
    )
# Stylesheet for the page header; it defines the `head`, `title` and
# `description` classes.
_page_css = """
.head {
text-align: center;
display: block;
font-size: var(--text-xl);
}
.title {
font-size: var(--text-xxl);
font-weight: bold;
margin-top: 2rem;
}
.description {
font-size: var(--text-lg);
}
"""
_page_theme = gr.themes.Ocean()

# Enable the request queue, then start the server with the custom look.
demo.queue()
demo.launch(css=_page_css, theme=_page_theme)