Spaces:
Running on Zero
Running on Zero
| import einops | |
| import gradio as gr | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import spaces | |
| import torch | |
| from torch.utils.data import DataLoader | |
| from dataset import SynLiDAR | |
| from drum import DRUM | |
# HTML header rendered via gr.HTML at the top of the demo. The .head, .title and
# .description classes are styled by the css passed to demo.launch.
# Every external link opens in a new tab: when the app is embedded in an iframe
# (as on Hugging Face Spaces), in-frame navigation to sites that forbid being
# framed (e.g. GitHub, arXiv) would otherwise fail.
DESCRIPTION = """
<div class="head">
<div class="title">DRUM: Diffusion-based Raydrop-aware Unpaired Mapping for Sim2Real LiDAR Segmentation</div>
<div class="conference">ICRA 2026</div>
<div class="authors">
<a href="https://miyawakitomoya.com/" target="_blank" rel="noopener"> Tomoya Miyawaki</a><sup>1</sup>

<a href="https://kazuto1011.github.io/" target="_blank" rel="noopener"> Kazuto Nakashima</a><sup>1</sup>

<a href="https://www-robotics.jpl.nasa.gov/who-we-are/people/yumi_iwashita/" target="_blank" rel="noopener"> Yumi Iwashita</a><sup>2</sup>

<a href="https://robotics.ait.kyushu-u.ac.jp/kurazume/en/" target="_blank" rel="noopener"> Ryo Kurazume</a><sup>1</sup>
</div>
<div class="affiliations">
<sup>1</sup>Kyushu University

<sup>2</sup>NASA Jet Propulsion Laboratory
</div>
<div class="materials">
<a href="https://miya-tomoya.github.io/drum/" target="_blank" rel="noopener">Project</a> |
<a href="https://arxiv.org/abs/2603.26263" target="_blank" rel="noopener">Paper</a> |
<a href="https://github.com/miya-tomoya/drum" target="_blank" rel="noopener">Code</a>
</div>
<br>
<div class="description">
This demo performs Sim2Real LiDAR translation with our method,
converting random <b>simulation</b> samples of the SynLiDAR dataset into <b>pseudo-real</b> samples in the KITTI style.<br>
</div>
<br>
</div>
"""
def _pick_device_name() -> str:
    """Return the preferred backend name: CUDA first, then Apple MPS, else CPU."""
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


# Module-wide torch device used for the model and input tensors.
device = torch.device(_pick_device_name())
# Load the pretrained R2DM model together with its LiDAR helper utilities and
# config through torch.hub. trust_repo=True runs the hub repo's code without
# prompting. NOTE(review): the repo is not pinned to a tag/commit, so the
# fetched code can change between restarts.
r2dm, lidar_utils, cfg = torch.hub.load(
    "kazuto1011/r2dm",
    "pretrained_r2dm",
    config="r2dm-h-kittiraw-300k-icra2026",
    device=device,
    trust_repo=True,
)

# DRUM is constructed around the pretrained R2DM model.
drum = DRUM(r2dm)

# Simulation source samples ("sub" revision), at the resolution given by the
# loaded config.
dataset = SynLiDAR(revision="sub", shape=cfg.data.resolution)
def colorize(img: torch.Tensor) -> np.ndarray:
    """Convert a tensor to 8-bit RGB with matplotlib's ``turbo`` colormap.

    The tensor is detached, cast to float and clamped to [0, 1] before mapping.
    The result has the input's shape plus a trailing RGB axis of size 3, with
    dtype ``uint8``.
    """
    values = img.detach().float().clamp(0, 1).cpu().numpy()
    turbo = plt.get_cmap("turbo")
    rgba = turbo(values)  # floats in [0, 1], trailing RGBA axis
    return (rgba[..., :3] * 255).astype(np.uint8)
def _to_stacked_image(x: torch.Tensor) -> np.ndarray:
    """Colorize a (B, C, H, W) tensor into a single (B*C*H, W, 3) uint8 image.

    The channels of each sample are stacked top to bottom, and the samples of
    the batch are then stacked top to bottom as well.
    """
    x = einops.rearrange(x, "b c h w -> b (c h) w")
    rgb = colorize(x)
    return einops.rearrange(rgb, "b ch w c -> (b ch) w c")


# On ZeroGPU Spaces a GPU is only attached inside functions decorated with
# @spaces.GPU; outside ZeroGPU the decorator has no effect.
@spaces.GPU
def run(
    nfe: int, num_harmonization_steps: int, batch_size: int, progress=gr.Progress()
):
    """Translate a random batch of SynLiDAR simulation samples with DRUM.

    Args:
        nfe: Number of function evaluations; forwarded to DRUM as ``num_steps``
            and used as the length of the progress bar.
        num_harmonization_steps: Forwarded to DRUM as ``num_harmonization``.
        batch_size: Number of random simulation samples to draw and translate.
        progress: Gradio progress tracker; ``progress.tqdm(range(nfe))`` is
            passed to DRUM as ``progress_bar``.

    Returns:
        ``(simulation, pseudo_real)``: two uint8 RGB arrays of shape
        (H, W, 3) for ``gr.ImageSlider``. Within each image, every sample's
        depth channel sits above its second channel (reflectance for the
        pseudo-real output, zeros for the simulation input), and samples are
        stacked top to bottom.
    """
    # Depth is divided by this so it lands (mostly) in colorize's [0, 1] range.
    # Presumably the maximum sensor range in meters — TODO confirm.
    max_depth = 80
    nfe = int(nfe)

    # A fresh shuffled loader per call yields a different random batch each run.
    loader = DataLoader(
        dataset, batch_size=int(batch_size), shuffle=True, drop_last=False
    )
    batch = next(iter(loader))
    depth_sim = batch["depth"].to(device)
    depth_sim = lidar_utils.convert_depth(depth_sim)
    depth_sim = lidar_utils.normalize(depth_sim)

    result = drum(
        depth_sim,
        num_steps=nfe,
        num_harmonization=int(num_harmonization_steps),
        progress_bar=progress.tqdm(range(nfe)),
    )

    # Pseudo-real output: channel 0 is depth; rescale only that channel.
    pseudo_real = lidar_utils.denormalize(result)
    pseudo_real[:, [0]] = lidar_utils.revert_depth(pseudo_real[:, [0]]) / max_depth

    simulation = lidar_utils.denormalize(depth_sim)
    simulation = lidar_utils.revert_depth(simulation) / max_depth
    # The simulation input carries no reflectance; append a zero channel so it
    # is laid out like the two-channel output (presumably same image size).
    simulation = torch.cat([simulation, torch.zeros_like(simulation)], dim=1)

    return _to_stacked_image(simulation), _to_stacked_image(pseudo_real)
# Gradio UI: hyperparameter dropdowns and a Run button on the left, and a
# simulation vs. pseudo-real comparison slider on the right.
with gr.Blocks() as demo:
    gr.HTML(DESCRIPTION)
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("**Hyperparameters:**")
            nfe = gr.Dropdown(
                choices=[2**i for i in range(0, 9)],  # 1, 2, 4, ..., 256
                value=32,
                label="Number of function evaluations (NFE)",
            )
            num_harmonization_steps = gr.Dropdown(
                choices=list(range(1, 9)),
                value=3,
                label="Number of harmonization steps",
            )
            batch_size = gr.Dropdown(
                choices=[1, 2, 4, 8],
                value=8,
                label="Batch size",
            )
            run_button = gr.Button("Run DRUM 🥁", variant="primary")
        with gr.Column(scale=2):
            # run() stacks channels and samples along the image height, so the
            # layout is vertical (depth above reflectance), not horizontal.
            gr.Markdown(
                "**Result:** for each sample, the depth (top) and reflectance (bottom) "
                "channels are stacked vertically; the simulation input has no "
                "reflectance, so that channel is shown as zeros."
            )
            comparison = gr.ImageSlider(
                label="Simulation (left) vs Pseudo-real (right)",
                type="numpy",
                max_height=1000,
            )
    # run() returns (simulation, pseudo_real), matching the slider's left/right.
    run_button.click(
        fn=run,
        inputs=[nfe, num_harmonization_steps, batch_size],
        outputs=[comparison],
    )
# Enable Gradio's request queue, then serve the app with the Ocean theme and
# the CSS for the header classes used in DESCRIPTION.
demo.queue().launch(
    theme=gr.themes.Ocean(),
    css="""
.head {
text-align: center;
display: block;
font-size: var(--text-xl);
}
.title {
font-size: var(--text-xxl);
font-weight: bold;
margin-top: 2rem;
}
.description {
font-size: var(--text-lg);
}
""",
)