brian4dwell committed on
Commit
594b88c
·
1 Parent(s): 5b0c756

add STream3r

Files changed (2)
  1. README.md +342 -12
  2. app.py +761 -4
README.md CHANGED
@@ -1,12 +1,342 @@
- ---
- title: Dwellbot Stream3r
- emoji: 📚
- colorFrom: blue
- colorTo: red
- sdk: gradio
- sdk_version: 5.46.0
- app_file: app.py
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ <div align="center">
+ <h1>
+ STream3R: Scalable Sequential 3D Reconstruction with Causal Transformer
+ </h1>
+ </div>
+
+ <div align="center">
+ <h4>
+ <a href="https://nirvanalan.github.io/projects/stream3r" target='_blank'>
+ <img src="https://img.shields.io/badge/🐳-Project%20Page-blue">
+ </a>
+ <a href="https://arxiv.org/abs/2508.10893" target='_blank'>
+ <img src="https://img.shields.io/badge/arXiv-2508.10893-b31b1b.svg">
+ </a>
+ <img src="https://visitor-badge.laobi.icu/badge?page_id=yhluo.STream3R">
+ </h4>
+ <div>
+ <a href='https://nirvanalan.github.io/' target='_blank'>Yushi Lan</a><sup>1*</sup>&emsp;
+ <a href='https://scholar.google.com/citations?user=fZxK2B0AAAAJ&hl' target='_blank'>Yihang Luo</a><sup>1*</sup>&emsp;
+ <a href='https://hongfz16.github.io' target='_blank'>Fangzhou Hong</a><sup>1</sup>&emsp;
+ <a href='https://shangchenzhou.com/' target='_blank'>Shangchen Zhou</a><sup>1</sup>&emsp;
+ <a href='https://chenhonghua.github.io/clay.github.io/' target='_blank'>Honghua Chen</a><sup>1</sup>&emsp;
+ <br>
+ <a href='https://zhaoyanglyu.github.io/' target='_blank'>Zhaoyang Lyu</a><sup>2</sup>&emsp;
+ <a href='https://williamyang1991.github.io/' target='_blank'>Shuai Yang</a><sup>3</sup>&emsp;
+ <a href='https://daibo.info/' target='_blank'>Bo Dai</a><sup>4</sup>&emsp;
+ <a href='https://www.mmlab-ntu.com/person/ccloy/' target='_blank'>Chen Change Loy</a><sup>1</sup>&emsp;
+ <a href='https://xingangpan.github.io/' target='_blank'>Xingang Pan</a><sup>1</sup>
+ </div>
+ <div>
+ S-Lab, Nanyang Technological University<sup>1</sup>;
+ <!-- &emsp; -->
+ <br>
+ Shanghai Artificial Intelligence Laboratory<sup>2</sup>;
+ WICT, Peking University<sup>3</sup>;
+ The University of Hong Kong<sup>4</sup>
+ <!-- <br>
+ <sup>*</sup>corresponding author -->
+ </div>
+ </div>
+
+ <br>
+
+ <div align="center">
+ <p>
+ <span style="font-variant: small-caps;"><strong>STream3R</strong></span> reformulates dense 3D reconstruction as a sequential registration task with causal attention.
+ <br>
+ <i>⭐ Now supports <b>FlashAttention</b>, <b>KV Cache</b>, <b>Causal Attention</b>, <b>Sliding Window Attention</b>, and <b>Full Attention</b>!</i>
+ </p>
+ <img width="820" alt="pipeline" src="assets/teaser_dynamic.gif">
+ :open_book: See more visual results on our <a href="https://nirvanalan.github.io/projects/stream3r" target="_blank">project page</a>
+ </div>
+
+ <br>
+
+ <details>
+ <summary><b>Abstract</b></summary>
+ <br>
+ <div align="center">
+ <img width="820" alt="pipeline" src="assets/pipeline.png">
+ <p align="justify">
+ We present STream3R, a novel approach to 3D reconstruction that reformulates pointmap prediction as a decoder-only Transformer problem. Existing state-of-the-art methods for multi-view reconstruction either depend on expensive global optimization or rely on simplistic memory mechanisms that scale poorly with sequence length. In contrast, STream3R introduces a streaming framework that processes image sequences efficiently using causal attention, inspired by advances in modern language modeling. By learning geometric priors from large-scale 3D datasets, STream3R generalizes well to diverse and challenging scenarios, including dynamic scenes where traditional methods often fail. Extensive experiments show that our method consistently outperforms prior work across both static and dynamic scene benchmarks. Moreover, STream3R is inherently compatible with LLM-style training infrastructure, enabling efficient large-scale pretraining and fine-tuning for various downstream 3D tasks. Our results underscore the potential of causal Transformer models for online 3D perception, paving the way for real-time 3D understanding in streaming environments.
+ </p>
+ </div>
+ </details>
+
+
+ ## :fire: News
+
+ - [Sep 16, 2025] The complete training code is released!
+ - [Aug 22, 2025] The evaluation code is now available!
+ - [Aug 15, 2025] Our inference code and weights are released!
+
+
+ ## 🔧 Installation
+
+ 1. Clone Repo
+ ```bash
+ git clone https://github.com/NIRVANALAN/STream3R
+ cd STream3R
+ ```
+
+ 2. Create Conda Environment
+ ```bash
+ conda create -n stream3r python=3.11 cmake=3.14.0 -y
+ conda activate stream3r
+ ```
+ 3. Install Python Dependencies
+
+ **Important:** Install [Torch](https://pytorch.org/get-started/locally/) based on your CUDA version. For example, for *Torch 2.8.0 + CUDA 12.6*:
+
+ ```bash
+ # Install Torch
+ pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu126
+
+ # Install other dependencies
+ pip install -r requirements.txt
+
+ # Install STream3R as a package
+ pip install -e .
+ ```
+
+ ## :computer: Inference
+
+ You can now try STream3R with the following code. The checkpoint will be downloaded automatically from [Hugging Face](https://huggingface.co/yslan/STream3R).
+
+ You can set the inference mode to `causal` for causal attention, `window` for sliding window attention (with a default window size of 5), or `full` for bidirectional attention.
+
+ ```python
+ import os
+ import torch
+ from stream3r.models.stream3r import STream3R
+ from stream3r.models.components.utils.load_fn import load_and_preprocess_images
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ model = STream3R.from_pretrained("yslan/STream3R").to(device)
+
+ example_dir = "examples/static_room"
+ image_names = [os.path.join(example_dir, file) for file in sorted(os.listdir(example_dir))]
+ images = load_and_preprocess_images(image_names).to(device)
+
+ with torch.no_grad():
+     # Use one mode ("causal", "window", or "full") in a single forward pass
+     predictions = model(images, mode="causal")
+ ```
+
+ We also support a KV cache version that enables streaming input via `StreamSession`. A `StreamSession` takes sequential inputs and processes them one by one, making it suitable for real-time or low-latency applications. This streaming 3D reconstruction pipeline can be applied in scenarios such as real-time robotics, autonomous navigation, online 3D understanding, and SLAM. Example usage:
+
+ ```python
+ import os
+ import torch
+ from stream3r.models.stream3r import STream3R
+ from stream3r.stream_session import StreamSession
+ from stream3r.models.components.utils.load_fn import load_and_preprocess_images
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ model = STream3R.from_pretrained("yslan/STream3R").to(device)
+
+ example_dir = "examples/static_room"
+ image_names = [os.path.join(example_dir, file) for file in sorted(os.listdir(example_dir))]
+ images = load_and_preprocess_images(image_names).to(device)
+ # StreamSession supports KV cache management for both "causal" and "window" modes.
+ session = StreamSession(model, mode="causal")
+
+ with torch.no_grad():
+     # Process images one by one to simulate streaming inference
+     for i in range(images.shape[0]):
+         image = images[i : i + 1]
+         predictions = session.forward_stream(image)
+ session.clear()
+ ```
+
+ ## :zap: Demo
+ You can run the demo built on [VGG-T's code](https://github.com/facebookresearch/vggt) using the script [`app.py`](app.py) with the following command:
+
+ ```sh
+ python app.py
+ ```
+
+ ## 📁 Code Structure
+
+ The repository is structured as follows:
+
+ ```
+ STream3R/
+ ├── stream3r/
+ │   ├── models/
+ │   │   ├── stream3r.py
+ │   │   ├── multiview_dust3r_module.py
+ │   │   └── components/
+ │   ├── dust3r/
+ │   ├── croco/
+ │   ├── utils/
+ │   └── stream_session.py
+ ├── configs/
+ ├── examples/
+ ├── assets/
+ ├── app.py
+ ├── requirements.txt
+ ├── setup.py
+ └── README.md
+ ```
+
+ ## :100: Quantitative Results
+
+ *3D Reconstruction Comparison on NRGBD.*
+
+ | Method | Type | Acc Mean ↓ | Acc Med. ↓ | Comp Mean ↓ | Comp Med. ↓ | NC Mean ↑ | NC Med. ↑ |
+ |---------------------|----------|------------|------------|-------------|-------------|-----------|-----------|
+ | VGG-T | FA | 0.073 | 0.018 | 0.077 | 0.021 | 0.910 | 0.990 |
+ | DUSt3R | Optim | 0.144 | 0.019 | 0.154 | 0.018 | 0.870 | 0.982 |
+ | MASt3R | Optim | 0.085 | 0.033 | 0.063 | 0.028 | 0.794 | 0.928 |
+ | MonST3R | Optim | 0.272 | 0.114 | 0.287 | 0.110 | 0.758 | 0.843 |
+ | Spann3R | Stream | 0.416 | 0.323 | 0.417 | 0.285 | 0.684 | 0.789 |
+ | CUT3R | Stream | 0.099 | 0.031 | 0.076 | 0.026 | 0.837 | 0.971 |
+ | StreamVGGT | Stream | 0.084 | 0.044 | 0.074 | 0.041 | 0.861 | 0.986 |
+ | Ours | Stream | **0.057** | **0.014** | **0.028** | **0.013** | **0.910** | **0.993** |
+
+ Read our [full paper](https://arxiv.org/abs/2508.10893) for more insights.
+
+ ## ⏳ GPU Memory Usage and Runtime
+
+ We report the peak GPU memory usage (VRAM) and the runtime of our full model for processing each streaming input using the `StreamSession` implementation. All experiments were conducted at a common resolution of 518 × 384 on a single H200 GPU. The benchmark includes both *Causal* for causal attention and *Window* for sliding window attention with a window size of 5.
+
+ *Runtime (s).*
+ | Num of Frames | 1 | 20 | 40 | 80 | 100 | 120 | 140 | 180 | 200 |
+ |-----------|--------|--------|--------|--------|--------|--------|--------|--------|--------|
+ | Causal | 0.1164 | 0.2034 | 0.3060 | 0.4986 | 0.5945 | 0.6947 | 0.7916 | 0.9911 | 1.1703 |
+ | Window | 0.1167 | 0.1528 | 0.1523 | 0.1517 | 0.1515 | 0.1512 | 0.1482 | 0.1443 | 0.1463 |
+
+ *VRAM (GB).*
+ | Num of Frames | 1 | 20 | 40 | 80 | 100 | 120 | 140 | 180 | 200 |
+ |-----------|--------|--------|--------|--------|--------|--------|--------|--------|--------|
+ | Causal | 5.49 | 9.02 | 12.92 | 21.00 | 25.03 | 29.10 | 33.21 | 41.31 | 45.41 |
+ | Window | 5.49 | 6.53 | 6.53 | 6.53 | 6.53 | 6.53 | 6.53 | 6.53 | 6.53 |
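
The per-frame latencies above were collected around each streaming step. A minimal timing harness along these lines can reproduce the measurement pattern; this sketch uses a dummy `process_frame` stand-in rather than the real `session.forward_stream`, and omits the CUDA synchronization and `torch.cuda.max_memory_allocated()` bookkeeping a real GPU benchmark would need:

```python
import time

def benchmark_stream(process_frame, frames):
    """Time each per-frame call; return a list of latencies in seconds."""
    latencies = []
    for frame in frames:
        start = time.perf_counter()
        process_frame(frame)  # e.g. session.forward_stream(frame) on GPU
        latencies.append(time.perf_counter() - start)
    return latencies

# Dummy stand-in for a per-frame model call, just to show the shape of the loop
latencies = benchmark_stream(lambda frame: sum(frame), [[1, 2], [3, 4], [5, 6]])
print(len(latencies))  # one latency measurement per frame
```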
+
+
+ ## :hotsprings: Training
+
+ 1. Prepare Dataset
+
+ We follow [CUT3R](https://github.com/CUT3R/CUT3R/blob/main/docs/preprocess.md) to preprocess the dataset for training.
+
+ 2. Set Up Config
+
+ Update the training config file `configs/experiment/stream3r/stream3r.yaml` as needed. For example:
+ - Set `pretrained` to the path of the [VGG-T checkpoint](https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt).
+ - Set `data_root` to the directory where you saved the processed dataset.
+
+ 3. Launch training with:
+ ```bash
+ python stream3r/train.py experiment=stream3r/stream3r
+ ```
+
+ 4. After training, you can convert the checkpoint into a `state_dict` file, for example:
+ ```python
+ from lightning.pytorch.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict
+
+ convert_zero_checkpoint_to_fp32_state_dict(
+     checkpoint_dir="logs/stream3r/runs/stream3r_99999/checkpoints/000-00002000.ckpt",
+     output_file="logs/stream3r/runs/stream3r_99999/checkpoints/last_aggregated.ckpt",
+     tag=None,
+ )
+ ```
+
+ ## 📈 Evaluation
+
+ The evaluation follows [MonST3R](https://github.com/Junyi42/monst3r), [Spann3R](https://github.com/HengyiWang/spann3r), and [CUT3R](https://github.com/CUT3R/CUT3R).
+
+ 1. Prepare Evaluation Dataset
+
+ We follow the dataset preparation guides from [MonST3R](https://github.com/Junyi42/monst3r/blob/main/data/evaluation_script.md) and [Spann3R](https://github.com/HengyiWang/spann3r/blob/main/docs/data_preprocess.md) to prepare the datasets. For convenience, we provide the processed datasets on [Hugging Face](https://huggingface.co/datasets/yslan/pointmap_regression_evalsets), which can be downloaded directly.
+
+ The datasets should be organized as follows under the root directory of the project:
+ ```
+ data/
+ ├── 7scenes
+ ├── bonn
+ ├── kitti
+ ├── neural_rgbd
+ ├── nyu-v2
+ ├── scannetv2
+ ├── sintel
+ └── tum
+ ```
+
+ 2. Run Evaluation
+
+ Use the provided scripts to evaluate the different tasks.
+
+ *For Video Depth and Camera Pose Estimation, some datasets contain more than 100 images. To reduce memory usage, we use `StreamSession` to process frames sequentially while managing the KV cache.*
+
+ ### Monodepth
+
+ ```bash
+ bash eval/monodepth/run.sh
+ ```
+ Results will be saved in `eval_results/monodepth/${model_name}/${data}/metric.json`.
+
+ ### Video Depth
+
+ ```bash
+ bash eval/video_depth/run.sh
+ ```
+ Results will be saved in `eval_results/video_depth/${model_name}/${data}/result_scale.json`.
+
+ ### Camera Pose Estimation
+
+ ```bash
+ bash eval/relpose/run.sh
+ ```
+ Results will be saved in `eval_results/relpose/${model_name}/${data}/_error_log.txt`.
+
+ ### Multi-view Reconstruction
+
+ ```bash
+ bash eval/mv_recon/run.sh
+ ```
+ Results will be saved in `eval_results/mv_recon/${model_name}/${data}/logs_all.txt`.
+
+
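
Since each task writes one result file per dataset, it can be handy to collect them into a single summary. A small sketch along these lines works for the `metric.json` layout above (assumptions: the exact directory layout and metric keys depend on what your `run.sh` actually writes, and the `abs_rel` value in the demo below is made up):

```python
import json
import os
import tempfile
from glob import glob

def collect_metrics(root):
    """Gather <root>/<dataset>/metric.json files into a {dataset: metrics} dict."""
    results = {}
    for path in sorted(glob(os.path.join(root, "*", "metric.json"))):
        dataset = os.path.basename(os.path.dirname(path))
        with open(path) as f:
            results[dataset] = json.load(f)
    return results

# Demo with a synthetic layout and a hypothetical metric name
root = tempfile.mkdtemp()
os.makedirs(os.path.join(root, "sintel"))
with open(os.path.join(root, "sintel", "metric.json"), "w") as f:
    json.dump({"abs_rel": 0.23}, f)
print(collect_metrics(root))  # {'sintel': {'abs_rel': 0.23}}
```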
+ ## :calendar: TODO
+
+ - [x] Release evaluation code.
+ - [x] Release training code.
+ - [ ] Release the metric-scale version.
+
+
+ ## :page_with_curl: License
+
+ This project is licensed under the <a rel="license" href="./LICENSE">NTU S-Lab License 1.0</a>. Redistribution and use should follow this license.
+
+ ## :pencil: Citation
+
+ If you find our code or paper helpful, please consider citing:
+
+ ```bibtex
+ @article{stream3r2025,
+   title={STream3R: Scalable Sequential 3D Reconstruction with Causal Transformer},
+   author={Lan, Yushi and Luo, Yihang and Hong, Fangzhou and Zhou, Shangchen and Chen, Honghua and Lyu, Zhaoyang and Yang, Shuai and Dai, Bo and Loy, Chen Change and Pan, Xingang},
+   journal={arXiv preprint arXiv:2508.10893},
+   year={2025}
+ }
+ ```
+ ## :pencil: Acknowledgments
+ We recognize several concurrent works on streaming methods. We encourage you to check them out:
+
+ [StreamVGGT](https://github.com/wzzheng/StreamVGGT) &nbsp;|&nbsp; [CUT3R](https://github.com/CUT3R/CUT3R) &nbsp;|&nbsp; [SLAM3R](https://github.com/PKU-VCL-3DV/SLAM3R) &nbsp;|&nbsp; [Spann3R](https://github.com/HengyiWang/spann3r)
+
+ STream3R is built on the shoulders of several outstanding open-source projects. Many thanks to the following exceptional projects:
+
+ [VGG-T](https://github.com/facebookresearch/vggt) &nbsp;|&nbsp; [Fast3R](https://github.com/facebookresearch/fast3r) &nbsp;|&nbsp; [DUSt3R](https://github.com/naver/dust3r) &nbsp;|&nbsp; [MonST3R](https://github.com/Junyi42/monst3r) &nbsp;|&nbsp; [Viser](https://github.com/nerfstudio-project/viser)
+
+ ## :mailbox: Contact
+ If you have any questions, please feel free to contact us via `lanyushi15@gmail.com` or GitHub issues.
app.py CHANGED
@@ -1,7 +1,764 @@
  import gradio as gr
-
- def greet(name):
-     return "Hello " + name + "!!"
-
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import os
+ import cv2
+ import torch
+ import numpy as np
  import gradio as gr
+ import shutil
+ from datetime import datetime
+ import glob
+ import gc
+ import time
+
+ from stream3r.models.stream3r import STream3R
+ from stream3r.stream_session import StreamSession
+ from stream3r.models.components.utils.load_fn import load_and_preprocess_images
+ from stream3r.models.components.utils.pose_enc import pose_encoding_to_extri_intri
+ from stream3r.models.components.utils.geometry import unproject_depth_map_to_point_map
+ from stream3r.utils.visual_utils import predictions_to_glb
+
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ model = STream3R.from_pretrained("yslan/STream3R")
+
+ # -------------------------------------------------------------------------
+ # 1) Core model inference
+ # -------------------------------------------------------------------------
+ def run_model(target_dir: str, model: STream3R, mode: str = "causal", streaming: bool = False) -> dict:
+     """
+     Run the STream3R model on images in the 'target_dir/images' folder and return predictions.
+
+     Args:
+         target_dir: Directory containing the images subfolder
+         model: STream3R model instance
+         mode: Processing mode ("causal", "window", or "full")
+         streaming: If True, use StreamSession for sequential processing; if False, use batch processing
+     """
+     print(f"Processing images from {target_dir}")
+
+     # Device check
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     if not torch.cuda.is_available():
+         raise ValueError("CUDA is not available. Check your environment.")
+
+     # Move model to device
+     model = model.to(device)
+     model.eval()
+
+     # Load and preprocess images
+     image_names = glob.glob(os.path.join(target_dir, "images", "*"))
+     image_names = sorted(image_names)
+     print(f"Found {len(image_names)} images")
+     if len(image_names) == 0:
+         raise ValueError("No images found. Check your upload.")
+
+     images = load_and_preprocess_images(image_names).to(device)
+     print(f"Preprocessed images shape: {images.shape}")
+
+     # Run inference
+     print(f"Running inference in {'streaming' if streaming else 'batch'} mode...")
+     dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
+
+     with torch.no_grad():
+         with torch.amp.autocast(dtype=dtype, device_type=device):
+             if streaming:
+                 # Use StreamSession for sequential processing
+                 if mode == "full":
+                     print("Warning: Streaming mode does not support 'full' attention mode. Switching to 'causal' mode.")
+                     mode = "causal"
+
+                 session = StreamSession(model, mode=mode)
+
+                 # Process images one by one to simulate streaming inference
+                 for i in range(images.shape[0]):
+                     image = images[i : i + 1]
+                     predictions = session.forward_stream(image)
+
+                 session.clear()
+             else:
+                 # Use batch processing (original behavior)
+                 predictions = model(images, mode=mode)
+
+     # Convert pose encoding to extrinsic and intrinsic matrices
+     print("Converting pose encoding to extrinsic and intrinsic matrices...")
+     extrinsic, intrinsic = pose_encoding_to_extri_intri(predictions["pose_enc"], images.shape[-2:])
+     predictions["extrinsic"] = extrinsic
+     predictions["intrinsic"] = intrinsic
+
+     # Convert tensors to numpy
+     for key in predictions.keys():
+         if isinstance(predictions[key], torch.Tensor):
+             predictions[key] = predictions[key].cpu().numpy().squeeze(0)  # remove batch dimension
+     predictions["pose_enc_list"] = None  # remove pose_enc_list
+
+     # Generate world points from depth map
+     print("Computing world points from depth map...")
+     depth_map = predictions["depth"]  # (S, H, W, 1)
+     world_points = unproject_depth_map_to_point_map(depth_map, predictions["extrinsic"], predictions["intrinsic"])
+     predictions["world_points_from_depth"] = world_points
+
+     # Clean up
+     torch.cuda.empty_cache()
+     return predictions
+
+
+ # -------------------------------------------------------------------------
+ # 2) Handle uploaded video/images --> produce target_dir + images
+ # -------------------------------------------------------------------------
+ def handle_uploads(input_video, input_images):
+     """
+     Create a new 'target_dir' + 'images' subfolder, and place user-uploaded
+     images or extracted frames from video into it. Return (target_dir, image_paths).
+     """
+     start_time = time.time()
+     gc.collect()
+     torch.cuda.empty_cache()
+
+     # Create a unique folder name
+     timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
+     target_dir = os.path.join("demo_cache", f"input_images_{timestamp}")
+     target_dir_images = os.path.join(target_dir, "images")
+
+     # Clean up if somehow that folder already exists
+     if os.path.exists(target_dir):
+         shutil.rmtree(target_dir)
+     os.makedirs(target_dir)
+     os.makedirs(target_dir_images)
+
+     image_paths = []
+
+     # --- Handle images ---
+     if input_images is not None:
+         for file_data in input_images:
+             if isinstance(file_data, dict) and "name" in file_data:
+                 file_path = file_data["name"]
+             else:
+                 file_path = file_data
+             dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
+             shutil.copy(file_path, dst_path)
+             image_paths.append(dst_path)
+
+     # --- Handle video ---
+     if input_video is not None:
+         if isinstance(input_video, dict) and "name" in input_video:
+             video_path = input_video["name"]
+         else:
+             video_path = input_video
+
+         vs = cv2.VideoCapture(video_path)
+         fps = vs.get(cv2.CAP_PROP_FPS)
+         frame_interval = max(int(fps), 1)  # about 1 frame/sec; guard against fps == 0
+
+         count = 0
+         video_frame_num = 0
+         while True:
+             gotit, frame = vs.read()
+             if not gotit:
+                 break
+             count += 1
+             if count % frame_interval == 0:
+                 image_path = os.path.join(target_dir_images, f"{video_frame_num:06}.png")
+                 cv2.imwrite(image_path, frame)
+                 image_paths.append(image_path)
+                 video_frame_num += 1
+
+     # Sort final images for gallery
+     image_paths = sorted(image_paths)
+
+     end_time = time.time()
+     print(f"Files copied to {target_dir_images}; took {end_time - start_time:.3f} seconds")
+     return target_dir, image_paths
+
+
+ # -------------------------------------------------------------------------
+ # 3) Update gallery on upload
+ # -------------------------------------------------------------------------
+ def update_gallery_on_upload(input_video, input_images):
+     """
+     Whenever the user uploads or changes files, handle them immediately and
+     show them in the gallery. Returns (reconstruction, target_dir, image_paths, log).
+     If nothing is uploaded, returns all Nones.
+     """
+     if not input_video and not input_images:
+         return None, None, None, None
+     target_dir, image_paths = handle_uploads(input_video, input_images)
+     return None, target_dir, image_paths, "Upload complete. Click 'Reconstruct' to begin 3D processing."
+
+
+ # -------------------------------------------------------------------------
+ # 4) Reconstruction: uses the target_dir plus any viz parameters
+ # -------------------------------------------------------------------------
+ def gradio_demo(
+     target_dir,
+     conf_thres=3.0,
+     frame_filter="All",
+     mask_black_bg=False,
+     mask_white_bg=False,
+     show_cam=True,
+     mask_sky=False,
+     prediction_mode="Pointmap Regression",
+     mode="causal",
+     streaming=False,
+ ):
+     """
+     Perform reconstruction using the already-created target_dir/images.
+     """
+     if not os.path.isdir(target_dir) or target_dir == "None":
+         return None, "No valid target directory found. Please upload first.", None
+
+     start_time = time.time()
+     gc.collect()
+     torch.cuda.empty_cache()
+
+     # Prepare frame_filter dropdown
+     target_dir_images = os.path.join(target_dir, "images")
+     all_files = sorted(os.listdir(target_dir_images)) if os.path.isdir(target_dir_images) else []
+     all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)]
+     frame_filter_choices = ["All"] + all_files
+
+     print("Running run_model...")
+     with torch.no_grad():
+         predictions = run_model(target_dir, model, mode=mode, streaming=streaming)
+
+     # Save predictions
+     prediction_save_path = os.path.join(target_dir, "predictions.npz")
+     np.savez(prediction_save_path, **predictions)
+
+     # Handle None frame_filter
+     if frame_filter is None:
+         frame_filter = "All"
+
+     # Build a GLB file name
+     glbfile = os.path.join(
+         target_dir,
+         f"glbscene_{conf_thres}_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_maskb{mask_black_bg}_maskw{mask_white_bg}_cam{show_cam}_sky{mask_sky}_pred{prediction_mode.replace(' ', '_')}_mode{mode}.glb",
+     )
+
+     # Convert predictions to GLB
+     glbscene = predictions_to_glb(
+         predictions,
+         conf_thres=conf_thres,
+         filter_by_frames=frame_filter,
+         mask_black_bg=mask_black_bg,
+         mask_white_bg=mask_white_bg,
+         show_cam=show_cam,
+         mask_sky=mask_sky,
+         target_dir=target_dir,
+         prediction_mode=prediction_mode,
+     )
+     glbscene.export(file_obj=glbfile)
+
+     # Cleanup
+     del predictions
+     gc.collect()
+     torch.cuda.empty_cache()
+
+     end_time = time.time()
+     print(f"Total time: {end_time - start_time:.2f} seconds (including IO)")
+     log_msg = f"Reconstruction Success ({len(all_files)} frames). Waiting for visualization."
+
+     return glbfile, log_msg, gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True)
+
+
+ # -------------------------------------------------------------------------
+ # 5) Helper functions for UI resets + re-visualization
+ # -------------------------------------------------------------------------
+ def clear_fields():
+     """
+     Clears the 3D viewer, the stored target_dir, and empties the gallery.
+     """
+     return None
+
+
+ def update_log():
+     """
+     Display a quick log message while waiting.
+     """
+     return "Loading and Reconstructing..."
+
+
+ def update_visualization(
+     target_dir, conf_thres, frame_filter, mask_black_bg, mask_white_bg, show_cam, mask_sky, prediction_mode, is_example
+ ):
+     """
+     Reload saved predictions from npz, create (or reuse) the GLB for new parameters,
+     and return it for the 3D viewer. If is_example == "True", skip.
+     """
+
+     # If it's an example click, skip as requested
+     if is_example == "True":
+         return None, "No reconstruction available. Please click the Reconstruct button first."
+
+     if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
+         return None, "No reconstruction available. Please click the Reconstruct button first."
+
+     predictions_path = os.path.join(target_dir, "predictions.npz")
+     if not os.path.exists(predictions_path):
+         return None, f"No reconstruction available at {predictions_path}. Please run 'Reconstruct' first."
+
+     key_list = [
+         "pose_enc",
+         "depth",
+         "depth_conf",
+         "world_points",
+         "world_points_conf",
+         "images",
+         "extrinsic",
+         "intrinsic",
+         "world_points_from_depth",
+     ]
+
+     loaded = np.load(predictions_path)
+     predictions = {key: np.array(loaded[key]) for key in key_list}
+
+     # Note: no `mode` parameter here, so it is not part of the cached GLB name
+     glbfile = os.path.join(
+         target_dir,
+         f"glbscene_{conf_thres}_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_maskb{mask_black_bg}_maskw{mask_white_bg}_cam{show_cam}_sky{mask_sky}_pred{prediction_mode.replace(' ', '_')}.glb",
+     )
+
+     if not os.path.exists(glbfile):
+         glbscene = predictions_to_glb(
+             predictions,
+             conf_thres=conf_thres,
+             filter_by_frames=frame_filter,
+             mask_black_bg=mask_black_bg,
+             mask_white_bg=mask_white_bg,
+             show_cam=show_cam,
+             mask_sky=mask_sky,
+             target_dir=target_dir,
+             prediction_mode=prediction_mode,
+         )
+         glbscene.export(file_obj=glbfile)
+
+     return glbfile, "Updating Visualization"
+
341
+
342
+ # -------------------------------------------------------------------------
343
+ # Example images
344
+ # -------------------------------------------------------------------------
345
+
346
+ great_wall_video = "examples/videos/great_wall.mp4"
347
+ colosseum_video = "examples/videos/Colosseum.mp4"
348
+ room_video = "examples/videos/room.mp4"
349
+ kitchen_video = "examples/videos/kitchen.mp4"
350
+ fern_video = "examples/videos/fern.mp4"
351
+ single_cartoon_video = "examples/videos/single_cartoon.mp4"
352
+ single_oil_painting_video = "examples/videos/single_oil_painting.mp4"
353
+ pyramid_video = "examples/videos/pyramid.mp4"
354
+
355

# -------------------------------------------------------------------------
# 6) Build Gradio UI
# -------------------------------------------------------------------------
theme = gr.themes.Ocean()
theme.set(
    checkbox_label_background_fill_selected="*button_primary_background_fill",
    checkbox_label_text_color_selected="*button_primary_text_color",
)

with gr.Blocks(
    theme=theme,
    css="""
    .custom-log * {
        font-style: italic;
        font-size: 22px !important;
        background-image: linear-gradient(120deg, #0ea5e9 0%, #6ee7b7 60%, #34d399 100%);
        -webkit-background-clip: text;
        background-clip: text;
        font-weight: bold !important;
        color: transparent !important;
        text-align: center !important;
    }

    .example-log * {
        font-style: italic;
        font-size: 16px !important;
        background-image: linear-gradient(120deg, #0ea5e9 0%, #6ee7b7 60%, #34d399 100%);
        -webkit-background-clip: text;
        background-clip: text;
        color: transparent !important;
    }

    #my_radio .wrap {
        display: flex;
        flex-wrap: nowrap;
        justify-content: center;
        align-items: center;
    }

    #my_radio .wrap label {
        display: flex;
        width: 50%;
        justify-content: center;
        align-items: center;
        margin: 0;
        padding: 10px 0;
        box-sizing: border-box;
    }
    """,
) as demo:
    # Instead of gr.State, we use a hidden Textbox:
    is_example = gr.Textbox(label="is_example", visible=False, value="None")
    num_images = gr.Textbox(label="num_images", visible=False, value="None")
    example_preview = gr.Image(label="Example Preview", visible=False)

    gr.HTML(
        """
    <h1>🌅 STream3R: Scalable Sequential 3D Reconstruction with Causal Transformer</h1>
    <p>
        <a href="https://github.com/NIRVANALAN/STream3R">GitHub Repository</a> |
        <a href="https://nirvanalan.github.io/projects/stream3r">Project Page</a> |
        <a href="https://arxiv.org/abs/2508.10893">Paper</a>
    </p>

    <blockquote>
        Special thanks to VGG-T for their visualization demo, which this demo is built upon!
    </blockquote>

    <div style="font-size: 16px; line-height: 1.5;">
        <p>Upload a video or a set of images to create a 3D reconstruction of a scene or object. STream3R takes these images and generates a 3D point cloud, along with estimated camera poses.</p>

        <h3>Getting Started:</h3>
        <ol>
            <li><strong>Upload Your Data:</strong> Use the "Upload Video" or "Upload Images" buttons on the left to provide your input. Videos are automatically split into individual frames (one frame per second).</li>
            <li><strong>Preview:</strong> Your uploaded images will appear in the gallery on the left.</li>
            <li><strong>Reconstruct:</strong> Click the "Reconstruct" button to start the 3D reconstruction process.</li>
            <li><strong>Visualize:</strong> The 3D reconstruction will appear in the viewer on the right. You can rotate, pan, and zoom to explore the model, and download the GLB file. Note that visualization of the 3D points may be slow for a large number of input images.</li>
            <li>
                <strong>Adjust Visualization (Optional):</strong>
                After reconstruction, you can fine-tune the visualization using the options below
                <details style="display:inline;">
                    <summary style="display:inline;">(<strong>click to expand</strong>):</summary>
                    <ul>
                        <li><em>Confidence Threshold:</em> Adjust the filtering of points based on confidence.</li>
                        <li><em>Show Points from Frame:</em> Select specific frames to display in the point cloud.</li>
                        <li><em>Show Camera:</em> Toggle the display of estimated camera positions.</li>
                        <li><em>Filter Sky / Filter Black Background:</em> Remove sky or black-background points.</li>
                        <li><em>Select a Prediction Mode:</em> Choose between "Depthmap and Camera Branch" or "Pointmap Branch."</li>
                    </ul>
                </details>
            </li>
        </ol>
        <p><strong style="color: #0ea5e9;">Please note:</strong> <span style="color: #0ea5e9; font-weight: bold;">STream3R typically reconstructs a scene in less than 1 second. However, visualizing the 3D points may take tens of seconds due to third-party rendering, which is independent of STream3R's processing time.</span></p>
    </div>
    """
    )
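The instructions above state that uploaded videos are sampled at one frame per second. The index selection behind that kind of sampling can be sketched as follows; this is an illustration under that stated assumption, not the app's actual extraction code (`frames_to_keep` is a hypothetical helper):

```python
def frames_to_keep(total_frames, fps):
    """Return the frame indices to keep when sampling one frame per second.

    For a video with `total_frames` frames at `fps` frames per second,
    keeping indices 0, fps, 2*fps, ... yields one frame per second.
    """
    step = max(int(round(fps)), 1)  # guard against fps < 1
    return list(range(0, int(total_frames), step))
```

For example, a 100-frame clip at 30 fps keeps frames 0, 30, 60, and 90.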

    target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None")

    with gr.Row():
        with gr.Column(scale=2):
            input_video = gr.Video(label="Upload Video", interactive=True)
            input_images = gr.File(file_count="multiple", label="Upload Images", interactive=True)

            image_gallery = gr.Gallery(
                label="Preview",
                columns=4,
                height="300px",
                show_download_button=True,
                object_fit="contain",
                preview=True,
            )

        with gr.Column(scale=4):
            with gr.Column():
                gr.Markdown("**3D Reconstruction (Point Cloud and Camera Poses)**")
                log_output = gr.Markdown(
                    "Please upload a video or images, then click Reconstruct.", elem_classes=["custom-log"]
                )
                reconstruction_output = gr.Model3D(height=520, zoom_speed=0.5, pan_speed=0.5)

            with gr.Row():
                submit_btn = gr.Button("Reconstruct", scale=1, variant="primary")
                clear_btn = gr.ClearButton(
                    [input_video, input_images, reconstruction_output, log_output, target_dir_output, image_gallery],
                    scale=1,
                )

            with gr.Row():
                prediction_mode = gr.Radio(
                    ["Depthmap and Camera Branch", "Pointmap Branch"],
                    label="Select a Prediction Mode",
                    value="Depthmap and Camera Branch",
                    scale=1,
                    elem_id="my_radio",
                )

            with gr.Row():
                streaming = gr.Radio(
                    [('stream', True), ('batch', False)],
                    label="Streaming or Batch Mode",
                    value=False,
                    scale=1,
                )

            with gr.Row():
                mode = gr.Radio(
                    ["causal", "window", "full"],
                    label="Select Processing Mode",
                    value="causal",
                    scale=1,
                )

            with gr.Row():
                conf_thres = gr.Slider(minimum=0, maximum=100, value=50, step=0.1, label="Confidence Threshold (%)")
                frame_filter = gr.Dropdown(choices=["All"], value="All", label="Show Points from Frame")
                with gr.Column():
                    show_cam = gr.Checkbox(label="Show Camera", value=True)
                    mask_sky = gr.Checkbox(label="Filter Sky", value=False)
                    mask_black_bg = gr.Checkbox(label="Filter Black Background", value=False)
                    mask_white_bg = gr.Checkbox(label="Filter White Background", value=False)

    # ---------------------- Examples section ----------------------
    def build_examples_from_folder():
        examples_root = "examples"
        entries = []
        if not os.path.isdir(examples_root):
            return entries
        candidate_dirs = sorted(
            [
                os.path.join(examples_root, d)
                for d in os.listdir(examples_root)
                if os.path.isdir(os.path.join(examples_root, d))
            ], reverse=True
        )
        if not candidate_dirs:
            candidate_dirs = [examples_root]
        for example_dir in candidate_dirs:
            image_files = []
            for pattern in ("*.png", "*.jpg", "*.jpeg", "*.bmp", "*.webp"):
                image_files.extend(sorted(glob.glob(os.path.join(example_dir, pattern))))
            if not image_files:
                continue
            preview_image = image_files[0]
            num_images_str = str(len(image_files))
            entries.append(
                [
                    preview_image,  # preview image (for visualization only)
                    None,  # input_video (unused for examples)
                    num_images_str,
                    image_files,  # input_images
                    15.0,  # conf_thres
                    False,  # mask_black_bg
                    False,  # mask_white_bg
                    True,  # show_cam
                    False,  # mask_sky
                    "Depthmap and Camera Branch",  # prediction_mode
                    "True",  # is_example
                    "causal",  # mode
                ]
            )
        return entries[:2]

    examples = build_examples_from_folder()
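`build_examples_from_folder` gathers images per extension pattern, sorting within each pattern before concatenating, so all PNGs come before all JPGs. That scanning step can be isolated into a small standalone helper (a sketch mirroring the function above; `list_example_images` is not a name in this app):

```python
import glob
import os


def list_example_images(example_dir):
    """Collect image files, sorted within each extension pattern in turn."""
    files = []
    for pattern in ("*.png", "*.jpg", "*.jpeg", "*.bmp", "*.webp"):
        files.extend(sorted(glob.glob(os.path.join(example_dir, pattern))))
    return files
```

Note the ordering consequence: the result is not globally sorted by filename, only within each extension group.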

    def example_pipeline(
        preview_image,
        input_video,
        num_images_str,
        input_images,
        conf_thres,
        mask_black_bg,
        mask_white_bg,
        show_cam,
        mask_sky,
        prediction_mode,
        is_example_str,
        mode="causal",
    ):
        """
        1) Copy example images to a new target_dir
        2) Reconstruct
        3) Return model3D + logs + new_dir + updated dropdown + gallery
        We do NOT return is_example. It's just an input.
        """
        target_dir, image_paths = handle_uploads(input_video, input_images)
        # Always use "All" for frame_filter in examples
        frame_filter = "All"
        glbfile, log_msg, dropdown = gradio_demo(
            target_dir, conf_thres, frame_filter, mask_black_bg, mask_white_bg, show_cam, mask_sky, prediction_mode, mode
        )
        return glbfile, log_msg, target_dir, dropdown, image_paths

    gr.Markdown("Click any row to load an example.", elem_classes=["example-log"])

    gr.Examples(
        examples=examples,
        inputs=[
            example_preview,
            input_video,
            num_images,
            input_images,
            conf_thres,
            mask_black_bg,
            mask_white_bg,
            show_cam,
            mask_sky,
            prediction_mode,
            is_example,
            mode,
        ],
        outputs=[reconstruction_output, log_output, target_dir_output, frame_filter, image_gallery],
        fn=example_pipeline,
        cache_examples=False,
        examples_per_page=50,
    )

    # -------------------------------------------------------------------------
    # "Reconstruct" button logic:
    # - Clear fields
    # - Update log
    # - gradio_demo(...) with the existing target_dir
    # - Then set is_example = "False"
    # -------------------------------------------------------------------------
    submit_btn.click(fn=clear_fields, inputs=[], outputs=[reconstruction_output]).then(
        fn=update_log, inputs=[], outputs=[log_output]
    ).then(
        fn=gradio_demo,
        inputs=[
            target_dir_output,
            conf_thres,
            frame_filter,
            mask_black_bg,
            mask_white_bg,
            show_cam,
            mask_sky,
            prediction_mode,
            mode,
            streaming,
        ],
        outputs=[reconstruction_output, log_output, frame_filter],
    ).then(
        fn=lambda: "False", inputs=[], outputs=[is_example]  # set is_example to "False"
    )

    # -------------------------------------------------------------------------
    # Real-time Visualization Updates
    # -------------------------------------------------------------------------
    conf_thres.change(
        update_visualization,
        [
            target_dir_output,
            conf_thres,
            frame_filter,
            mask_black_bg,
            mask_white_bg,
            show_cam,
            mask_sky,
            prediction_mode,
            is_example,
        ],
        [reconstruction_output, log_output],
    )
    frame_filter.change(
        update_visualization,
        [
            target_dir_output,
            conf_thres,
            frame_filter,
            mask_black_bg,
            mask_white_bg,
            show_cam,
            mask_sky,
            prediction_mode,
            is_example,
        ],
        [reconstruction_output, log_output],
    )
    mask_black_bg.change(
        update_visualization,
        [
            target_dir_output,
            conf_thres,
            frame_filter,
            mask_black_bg,
            mask_white_bg,
            show_cam,
            mask_sky,
            prediction_mode,
            is_example,
        ],
        [reconstruction_output, log_output],
    )
    mask_white_bg.change(
        update_visualization,
        [
            target_dir_output,
            conf_thres,
            frame_filter,
            mask_black_bg,
            mask_white_bg,
            show_cam,
            mask_sky,
            prediction_mode,
            is_example,
        ],
        [reconstruction_output, log_output],
    )
    show_cam.change(
        update_visualization,
        [
            target_dir_output,
            conf_thres,
            frame_filter,
            mask_black_bg,
            mask_white_bg,
            show_cam,
            mask_sky,
            prediction_mode,
            is_example,
        ],
        [reconstruction_output, log_output],
    )
    mask_sky.change(
        update_visualization,
        [
            target_dir_output,
            conf_thres,
            frame_filter,
            mask_black_bg,
            mask_white_bg,
            show_cam,
            mask_sky,
            prediction_mode,
            is_example,
        ],
        [reconstruction_output, log_output],
    )
    prediction_mode.change(
        update_visualization,
        [
            target_dir_output,
            conf_thres,
            frame_filter,
            mask_black_bg,
            mask_white_bg,
            show_cam,
            mask_sky,
            prediction_mode,
            is_example,
        ],
        [reconstruction_output, log_output],
    )
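All seven `.change()` bindings above share the same handler, input list, and output list, so they could be registered in a loop instead of repeated verbatim. A minimal sketch of that pattern, using a hypothetical `Control` stand-in rather than real Gradio components:

```python
class Control:
    """Stand-in for a Gradio component that records .change() bindings."""

    def __init__(self, name):
        self.name = name
        self.bindings = []

    def change(self, fn, inputs, outputs):
        self.bindings.append((fn, inputs, outputs))


def wire_visualization_updates(controls, handler, shared_inputs, outputs):
    # One loop replaces the seven near-identical .change() calls.
    for control in controls:
        control.change(handler, shared_inputs, outputs)


controls = [Control(n) for n in (
    "conf_thres", "frame_filter", "mask_black_bg", "mask_white_bg",
    "show_cam", "mask_sky", "prediction_mode",
)]
wire_visualization_updates(controls, lambda *args: None, controls, ["viewer", "log"])
```

With real components, the same loop works because every Gradio input component exposes a `.change()` event listener with the same `(fn, inputs, outputs)` shape.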
 
    # -------------------------------------------------------------------------
    # Auto-update gallery whenever user uploads or changes their files
    # -------------------------------------------------------------------------
    input_video.change(
        fn=update_gallery_on_upload,
        inputs=[input_video, input_images],
        outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
    )
    input_images.change(
        fn=update_gallery_on_upload,
        inputs=[input_video, input_images],
        outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
    )

demo.queue(max_size=20).launch(show_error=True, share=True)