drum / app.py
kazuto1011's picture
Update app.py
7c6d11c verified
import einops
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import spaces
import torch
from torch.utils.data import DataLoader
from dataset import SynLiDAR
from drum import DRUM
# HTML header for the demo page: paper title, venue, authors, affiliations,
# project/paper/code links, and a short description of what the demo does.
# The markup relies on the "head", "title", and "description" CSS classes
# that are styled by the stylesheet passed to `demo.launch`.
DESCRIPTION = """
<div class="head">
<div class="title">DRUM: Diffusion-based Raydrop-aware Unpaired Mapping for Sim2Real LiDAR Segmentation</div>
<div class="conference">ICRA 2026</div>
<div class="authors">
<a href="https://miyawakitomoya.com/" target="_blank" rel="noopener"> Tomoya Miyawaki</a><sup>1</sup>
&nbsp;&nbsp;&nbsp;
<a href="https://kazuto1011.github.io/" target="_blank" rel="noopener"> Kazuto Nakashima</a><sup>1</sup>
&nbsp;&nbsp;&nbsp;
<a href="https://www-robotics.jpl.nasa.gov/who-we-are/people/yumi_iwashita/" target="_blank" rel="noopener"> Yumi Iwashita</a><sup>2</sup>
&nbsp;&nbsp;&nbsp;
<a href="https://robotics.ait.kyushu-u.ac.jp/kurazume/en/" target="_blank" rel="noopener"> Ryo Kurazume</a><sup>1</sup>
</div>
<div class="affiliations">
<sup>1</sup>Kyushu University
&nbsp;&nbsp;&nbsp;
<sup>2</sup>NASA Jet Propulsion Laboratory
</div>
<div class="materials">
<a href="https://miya-tomoya.github.io/drum/">Project</a> |
<a href="https://arxiv.org/abs/2603.26263">Paper</a> |
<a href="https://github.com/miya-tomoya/drum">Code</a>
</div>
<br>
<div class="description">
This demo performs Sim2Real LiDAR translation by our method;
converting random <b>simulation</b> samples of the SynLiDAR dataset into <b>pseudo-real</b> samples in the KITTI style.<br>
</div>
<br>
</div>
"""
# Select the compute backend in order of preference: CUDA first, then MPS,
# and the CPU as the final fallback.
device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available() else "cpu"
)

# Load the pretrained R2DM model through torch.hub. The entry point also
# returns the LiDAR helper utilities and the configuration object that the
# rest of this script uses.
r2dm, lidar_utils, cfg = torch.hub.load(
    "kazuto1011/r2dm",
    "pretrained_r2dm",
    config="r2dm-h-kittiraw-300k-icra2026",
    device=device,
    trust_repo=True,
)

# DRUM is built on top of the loaded model; `run` calls it on each batch.
drum = DRUM(r2dm)

# Simulation samples, requested at the resolution stored in the loaded config.
dataset = SynLiDAR(shape=cfg.data.resolution, revision="sub")
def colorize(img: torch.Tensor) -> np.ndarray:
    """Render a tensor as an 8-bit RGB image using the "turbo" colormap.

    Values are clamped to ``[0, 1]`` before the colormap lookup, so anything
    outside that range saturates at the ends of the colormap.

    Args:
        img: Tensor of scalar intensities; any shape is accepted.

    Returns:
        A ``uint8`` array with the same leading shape as ``img`` and a
        trailing axis of size 3 holding the RGB values.
    """
    values = img.detach().float().clamp(0, 1).cpu().numpy()
    turbo = plt.get_cmap("turbo")
    # Keep only the first three channels of the colormap output (drop alpha).
    rgb = turbo(values)[..., :3]
    scaled = rgb * 255
    return scaled.astype(np.uint8)
@spaces.GPU
def run(
    nfe: int, num_harmonization_steps: int, batch_size: int, progress=gr.Progress()
):
    """Translate a random batch of simulation samples and render both sides.

    A batch is drawn at random from the module-level ``dataset``, passed
    through ``drum``, and both the input and the output are converted to
    color images for the comparison slider.

    Args:
        nfe: Passed to ``drum`` as ``num_steps``; also sets the length of the
            progress bar.
        num_harmonization_steps: Passed to ``drum`` as ``num_harmonization``.
        batch_size: Number of dataset samples drawn for this run.
        progress: Gradio progress tracker used to report sampling progress.

    Returns:
        A ``(simulation, pseudo_real)`` pair of ``uint8`` RGB arrays in
        which all samples and channels are stacked along the height axis.
    """
    # Values coming from the UI/API are not guaranteed to be Python ints, so
    # coerce every hyperparameter once up front. DataLoader in particular
    # rejects a non-int batch size.
    nfe = int(nfe)
    num_harmonization_steps = int(num_harmonization_steps)
    batch_size = int(batch_size)

    # A freshly shuffled loader yields a different random batch on every call.
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False)
    item = next(iter(loader))

    depth_sim = item["depth"].to(device)
    depth_sim = lidar_utils.convert_depth(depth_sim)
    depth_sim = lidar_utils.normalize(depth_sim)

    result = drum(
        depth_sim,
        num_steps=nfe,
        num_harmonization=num_harmonization_steps,
        progress_bar=progress.tqdm(range(nfe)),
    )

    # Pseudo-real output: channel 0 is reverted as depth and rescaled for
    # display; the remaining channel(s) are shown as denormalized.
    # NOTE(review): 80 is presumably the maximum depth used to bring values
    # into [0, 1] for the colormap -- confirm against the model config.
    x = lidar_utils.denormalize(result)
    x[:, [0]] = lidar_utils.revert_depth(x[:, [0]]) / 80
    x = einops.rearrange(x, "b c h w -> b (c h) w")

    # Simulation input: append an all-zero copy along the channel axis so the
    # stacked image has the same layout as the output.
    y = lidar_utils.denormalize(depth_sim)
    y = lidar_utils.revert_depth(y) / 80
    y = torch.cat([y, torch.zeros_like(y)], dim=1)
    y = einops.rearrange(y, "b c h w -> b (c h) w")

    x = colorize(x)
    y = colorize(y)

    # Stack every sample of the batch along the height axis into one image.
    x = einops.rearrange(x, "b ch w c -> (b ch) w c")
    y = einops.rearrange(y, "b ch w c -> (b ch) w c")
    return y, x
# Page layout: hyperparameter controls in the left column, the
# simulation vs. pseudo-real comparison slider in the right column.
with gr.Blocks() as demo:
    gr.HTML(DESCRIPTION)
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("**Hyperparameters:**")
            nfe = gr.Dropdown(
                choices=[2**i for i in range(0, 9)],
                value=32,
                label="Number of function evaluations (NFE)",
            )
            num_harmonization_steps = gr.Dropdown(
                # A concrete list, consistent with the other dropdowns.
                choices=list(range(1, 9)),
                value=3,
                label="Number of harmonization steps",
            )
            batch_size = gr.Dropdown(
                choices=[1, 2, 4, 8],
                value=8,
                label="Batch size",
            )
            run_button = gr.Button("Run DRUM 🥁", variant="primary")
        with gr.Column(scale=2):
            # `run` stacks channels and samples along the image height, so
            # the layout is described as vertical stacking.
            gr.Markdown(
                "**Result:** depth and reflectance channels are stacked vertically for each sample."
            )
            comparison = gr.ImageSlider(
                label="Simulation (left) vs Pseudo-real (right)",
                type="numpy",
                max_height=1000,
            )
    # `run` returns a (simulation, pseudo-real) pair, which fills the two
    # sides of the single slider component.
    run_button.click(
        fn=run,
        inputs=[nfe, num_harmonization_steps, batch_size],
        outputs=[comparison],
    )

demo.queue()
demo.launch(
    css="""
.head {
text-align: center;
display: block;
font-size: var(--text-xl);
}
.title {
font-size: var(--text-xxl);
font-weight: bold;
margin-top: 2rem;
}
.description {
font-size: var(--text-lg);
}
""",
    theme=gr.themes.Ocean(),
)