Julian Bilcke and Claude committed
Commit 2a6e562 · 1 Parent(s): 1221d93

Add ZeroGPU Gradio app and deployment documentation

- Add app_gradio.py with ZeroGPU integration
- Add ZEROGPU_MIGRATION.md with implementation guide
- Add CLAUDE.md for AI assistant context
- Update README with ZeroGPU demo instructions
- Update requirements.txt for Gradio compatibility
- Add example camera trajectory JSON

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

.claude/settings.local.json ADDED
@@ -0,0 +1,11 @@
+ {
+   "permissions": {
+     "allow": [
+       "Bash(git remote rename:*)",
+       "Bash(git remote add:*)",
+       "Bash(git add:*)"
+     ],
+     "deny": [],
+     "ask": []
+   }
+ }
CLAUDE.md ADDED
@@ -0,0 +1,178 @@
+ # CLAUDE.md
+ 
+ This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
+ 
+ ## Project Overview
+ 
+ FlashWorld is a high-quality 3D scene generation system that creates 3D scenes from text or image prompts in ~7 seconds on a single A100/A800 GPU. The project uses diffusion-based transformers with Gaussian Splatting for 3D reconstruction.
+ 
+ **Key capabilities:**
+ - Fast 3D scene generation (7 seconds on A100/A800)
+ - Text-to-3D and Image-to-3D generation
+ - Supports 24GB GPU memory configurations
+ - Outputs 3D Gaussian Splatting (.ply) files
+ 
+ ## Running the Application
+ 
+ ### Local Demo (Flask + Custom UI)
+ ```bash
+ python app.py --port 7860 --gpu 0 --cache_dir ./tmpfiles --max_concurrent 1
+ ```
+ 
+ Access the web interface at `http://HOST_IP:7860`
+ 
+ **Important flags:**
+ - `--offload_t5`: Offload text encoding to CPU to reduce GPU memory (trades speed for memory)
+ - `--ckpt`: Path to a custom checkpoint (auto-downloads from HuggingFace if not provided)
+ - `--max_concurrent`: Maximum concurrent generation tasks (default: 1)
+ 
+ ### ZeroGPU Demo (Gradio)
+ ```bash
+ python app_gradio.py
+ ```
+ 
+ **ZeroGPU Configuration:**
+ - Uses the `@spaces.GPU(duration=15)` decorator with a 15-second GPU budget
+ - Model loading happens **outside** the GPU decorator scope (in global scope)
+ - Gradio 5.49.1+ required
+ - Compatible with Hugging Face Spaces ZeroGPU hardware
+ - Automatically downloads the model checkpoint from HuggingFace Hub
+ 
+ ### Installation
+ Dependencies are in `requirements.txt`. Key packages:
+ - PyTorch 2.6.0 with CUDA support
+ - Custom gsplat version from a specific commit
+ - Custom diffusers version from a specific commit
+ 
+ Install with:
+ ```bash
+ pip install -r requirements.txt
+ ```
+ 
+ ## Architecture
+ 
+ ### Core Components
+ 
+ **GenerationSystem** (app.py:90-346)
+ - Main neural network system combining VAE, text encoder, transformer, and 3D reconstruction
+ - Key submodules:
+   - `vae`: AutoencoderKLWan for image encoding/decoding (from the Wan2.2-TI2V-5B model)
+   - `text_encoder`: UMT5 for text embedding
+   - `transformer`: WanTransformer3DModel for diffusion denoising
+   - `recon_decoder`: WANDecoderPixelAligned3DGSReconstructionModel for 3D Gaussian Splatting reconstruction
+ - Uses a flow matching scheduler with 4 denoising steps
+ - Implements a feedback mechanism where previous predictions inform the next denoising step
+ 
+ **Key Generation Pipeline:**
+ 1. Text/image prompt → text embeddings + optional image latents
+ 2. Create raymaps from camera parameters (6DOF)
+ 3. Iterative denoising with a 3D feedback loop (4 steps at timesteps [0, 250, 500, 750])
+ 4. Final prediction → 3D Gaussian parameters → render to images
+ 5. Export to PLY file format
+ 
+ ### Model Files
+ 
+ **models/transformer_wan.py**
+ - 3D transformer for video diffusion (adapted from the Wan2.2 model)
+ - Handles temporal + spatial attention with RoPE (Rotary Position Embeddings)
+ 
+ **models/reconstruction_model.py**
+ - `WANDecoderPixelAligned3DGSReconstructionModel`: Converts latent features to 3D Gaussian parameters
+ - `PixelAligned3DGS`: Per-pixel Gaussian parameter prediction
+ - Outputs: positions (xyz), opacity, scales, rotations, SH features
+ 
+ **models/autoencoder_kl_wan.py**
+ - VAE for image encoding/decoding (WAN architecture)
+ - Custom 3D causal convolutions adapted for single-frame processing
+ 
+ **models/render.py**
+ - Gaussian Splatting rasterization using the gsplat library
+ 
+ **utils.py**
+ - Camera utilities: normalize_cameras, create_rays, create_raymaps
+ - Quaternion operations: quaternion_to_matrix, matrix_to_quaternion, quaternion_slerp
+ - Camera interpolation: sample_from_dense_cameras, sample_from_two_pose
+ - Export: export_ply_for_gaussians
+ 
+ ### Gradio Interface (app_gradio.py)
+ 
+ **ZeroGPU Integration:**
+ - Model initialized in global scope (outside the @spaces.GPU decorator)
+ - `generate_scene()` function decorated with `@spaces.GPU(duration=15)`
+ - Accepts image prompts (PIL), text prompts, camera JSON, and resolution
+ - Returns a PLY file and a status message
+ - Uses the Gradio Progress API for user feedback
+ 
+ **Input Format:**
+ - Image: PIL Image (optional)
+ - Text: String prompt (optional)
+ - Camera JSON: Array of camera dictionaries with `quaternion`, `position`, `fx`, `fy`, `cx`, `cy`
+ - Resolution: String in "NxHxW" format (e.g., "24x480x704")
+ 
+ ### Flask API (app.py - Local Only)
+ 
+ **Concurrency Management** (concurrency_manager.py)
+ - Thread-pool based task queue for handling multiple generation requests
+ - Task states: QUEUED → RUNNING → COMPLETED/FAILED
+ - Automatic cleanup of old cached files (30-minute TTL)
+ 
+ **API Endpoints:**
+ - `POST /generate`: Submit a generation task (returns task_id immediately)
+ - `GET /task/<task_id>`: Poll task status and get results
+ - `GET /download/<file_id>`: Download a generated PLY file
+ - `DELETE /delete/<file_id>`: Clean up generated files
+ - `GET /status`: Get queue status
+ - `GET /`: Serve the web interface (index.html)
+ 
+ **Request Format:**
+ ```json
+ {
+   "image_prompt": "<base64 or path>",  // optional
+   "text_prompt": "...",
+   "cameras": [{"quaternion": [...], "position": [...], "fx": ..., "fy": ..., "cx": ..., "cy": ...}],
+   "resolution": [n_frames, height, width],
+   "image_index": 0  // which frame to condition on
+ }
+ ```
+ 
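+ A minimal polling-client sketch for this API (endpoint paths from the list above; the exact JSON field names in the responses, e.g. `task_id` and `status`, are assumptions):
+ 
+ ```python
+ import time
+ import requests
+ 
+ BASE = "http://localhost:7860"
+ 
+ # Submit a generation task; the server returns a task id immediately.
+ resp = requests.post(f"{BASE}/generate", json={
+     "text_prompt": "a cozy reading room",
+     "cameras": [{"quaternion": [1, 0, 0, 0], "position": [0, 0, 0],
+                  "fx": 352.0, "fy": 352.0, "cx": 352.0, "cy": 240.0}],
+     "resolution": [24, 480, 704],
+ })
+ task_id = resp.json()["task_id"]  # field name assumed
+ 
+ # Poll until the task leaves the QUEUED/RUNNING states.
+ while True:
+     task = requests.get(f"{BASE}/task/{task_id}").json()
+     if task.get("status") in ("COMPLETED", "FAILED"):  # states listed above
+         break
+     time.sleep(1)
+ ```
+ 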
+ ### Camera System
+ 
+ Cameras are represented as 11D vectors: `[qw, qx, qy, qz, tx, ty, tz, fx, fy, cx, cy]`
+ - First 4: quaternion rotation (real-first convention)
+ - Next 3: translation
+ - Last 4: intrinsics (normalized by image dimensions)
+ 
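+ For illustration, a sketch that packs one identity-rotation camera into this layout (the intrinsics match the example values used elsewhere in this repo):
+ 
+ ```python
+ import numpy as np
+ 
+ h, w = 480, 704
+ fx = fy = 352.0          # focal lengths in pixels
+ cx, cy = 352.0, 240.0    # principal point in pixels
+ 
+ camera = np.array([
+     1, 0, 0, 0,                       # [qw, qx, qy, qz], real-first, identity rotation
+     0, 0, 0,                          # [tx, ty, tz]
+     fx / w, fy / h, cx / w, cy / h,   # intrinsics normalized by image dimensions
+ ], dtype=np.float32)
+ assert camera.shape == (11,)
+ ```
+ 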
+ **Camera normalization** (utils.py:269-296):
+ - Centers the scene around the first camera
+ - Normalizes the translation scale based on the max camera distance
+ - Critical for stable 3D generation
+ 
+ ## Development Notes
+ 
+ ### Memory Management
+ - The model uses FP8 quantization (quant.py) for the transformer to reduce memory
+ - VAE and text encoder can be offloaded to CPU with the `--offload_t5` and `--offload_vae` flags
+ - Checkpoint mechanism for the decoder to reduce memory during training
+ 
+ ### Key Constants
+ - Latent dimension: 48 channels
+ - Temporal downsample: 4x
+ - Spatial downsample: 16x
+ - Feature dimension: 1024 channels
+ - Latent patch size: 2
+ - Denoising timesteps: [0, 250, 500, 750]
+ 
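+ As a worked example, for the default "24x480x704" resolution these constants give:
+ 
+ ```python
+ n_frame, H, W = 24, 480, 704
+ latent_dim, spatial_ds, patch = 48, 16, 2        # constants listed above
+ 
+ h, w = H // spatial_ds, W // spatial_ds          # 30 x 44 latent grid per frame
+ latent_shape = (1, latent_dim, n_frame, h, w)    # (B, C, T, H', W')
+ tokens = n_frame * (h // patch) * (w // patch)   # 24 * 15 * 22 = 7920 transformer tokens
+ print(latent_shape, tokens)
+ ```
+ 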
+ ### Model Weights
+ - Primary checkpoint auto-downloads from HuggingFace: `imlixinyang/FlashWorld`
+ - Base diffusion model: `Wan-AI/Wan2.2-TI2V-5B-Diffusers`
+ - Model is adapted with additional input/output channels for 3D features
+ 
+ ### Rendering
+ - Uses gsplat 1.5.2 for differentiable Gaussian Splatting
+ - SH degree: 2 (supports spherical harmonics up to degree 2)
+ - Background modes: 'white', 'black', 'random'
+ - Output FPS: 15
+ 
+ ## License
+ 
+ CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) - Academic research use only.
README.md CHANGED
@@ -1,3 +1,15 @@
+ ---
+ title: FlashWorld
+ emoji: 🌎
+ colorFrom: blue
+ colorTo: green
+ sdk: gradio
+ sdk_version: 5.49.1
+ app_file: app_gradio.py
+ pinned: false
+ license: cc-by-nc-sa-4.0
+ python_version: 3.10.13
+ ---
  
  <p align="center">
    <h2 align="center">
@@ -52,15 +64,37 @@ git clone https://github.com/imlixinyang/FlashWorld.git
  cd FlashWorld
  ```
  
- - run our demo app by:
+ - run our demo app:
+ 
+ **Local Demo (Flask + Custom UI):**
  ```
  python app.py
  ```
  
+ **ZeroGPU Demo (Gradio):**
+ ```
+ python app_gradio.py
+ ```
+ 
  If your machine has limited GPU memory, consider adding the ```--offload_t5``` flag to offload text encoding to the CPU, which will reduce GPU memory usage. Note that this may slow down the generation speed somewhat.
  
  Then, open your web browser and navigate to ```http://HOST_IP:7860``` to start exploring FlashWorld!
  
+ ## ZeroGPU Deployment
+ 
+ This repository is compatible with Hugging Face Spaces using ZeroGPU. The `app_gradio.py` file provides a Gradio interface with:
+ 
+ - **15-second GPU budget** per generation (configurable via `@spaces.GPU(duration=15)`)
+ - Model loading happens **outside** the GPU decorator for efficiency
+ - Supports both image and text prompts
+ - Camera trajectory input via JSON
+ - Outputs 3D Gaussian Splatting PLY files
+ 
+ To deploy on Hugging Face Spaces:
+ 1. Create a new Space with ZeroGPU hardware
+ 2. Set `app_file: app_gradio.py` in the README header
+ 3. The model checkpoint will be automatically downloaded from HuggingFace Hub
+ 
  <!-- We also provide example trajectory json files and input images in the `examples/` directory. -->
  
  ## More Generation Results
ZEROGPU_MIGRATION.md ADDED
@@ -0,0 +1,286 @@
+ # ZeroGPU Migration Guide
+ 
+ This document describes the changes made to enable FlashWorld to run on Hugging Face Spaces with ZeroGPU.
+ 
+ ## Overview
+ 
+ FlashWorld has been adapted to support ZeroGPU deployment on Hugging Face Spaces. This allows the model to run on free, dynamically allocated GPU resources with a configurable time budget.
+ 
+ ## Changes Made
+ 
+ ### 1. New Gradio Application (`app_gradio.py`)
+ 
+ Created a new Gradio-based interface that replaces the Flask API for ZeroGPU deployment:
+ 
+ **Key Features:**
+ - Uses Gradio 5.49.1+ for the interface
+ - Implements the `@spaces.GPU(duration=15)` decorator with a 15-second GPU budget
+ - Model loading happens in global scope (outside the GPU decorator) for efficiency
+ - Simpler interface compared to the original Flask app with custom HTML
+ - Accepts a camera trajectory as JSON input
+ - Returns PLY files for download
+ 
+ **Architecture:**
+ ```python
+ # Model loads globally (once, at startup)
+ generation_system = GenerationSystem(ckpt_path=ckpt_path, device=device, offload_t5=args.offload_t5)
+ 
+ # Generation function uses GPU only when called
+ @spaces.GPU(duration=15)
+ def generate_scene(image_prompt, text_prompt, camera_json, resolution):
+     # GPU-intensive work happens here
+     # Returns PLY file + status message
+     ...
+ ```
+ 
+ ### 2. Requirements Updates (`requirements.txt`)
+ 
+ **Removed:**
+ - `flask==3.1.2` (not needed for ZeroGPU deployment)
+ 
+ **Added:**
+ - `spaces` (Hugging Face Spaces integration library)
+ 
+ **Kept:**
+ - `gradio==5.49.1` (required for the Gradio SDK)
+ - All other dependencies remain unchanged
+ 
+ ### 3. README Updates
+ 
+ **Added YAML frontmatter:**
+ ```yaml
+ ---
+ title: FlashWorld
+ emoji: 🌎
+ colorFrom: blue
+ colorTo: green
+ sdk: gradio
+ sdk_version: 5.49.1
+ app_file: app_gradio.py
+ pinned: false
+ license: cc-by-nc-sa-4.0
+ python_version: 3.10.13
+ ---
+ ```
+ 
+ **Added ZeroGPU deployment section:**
+ - Instructions for deploying on Hugging Face Spaces
+ - Documentation of the 15-second GPU budget
+ - Explanation of the model loading strategy
+ 
+ ### 4. CLAUDE.md Updates
+ 
+ Updated the development documentation to include:
+ - Instructions for running both Flask (local) and Gradio (ZeroGPU) versions
+ - Documentation of the ZeroGPU configuration
+ - Explanation of decorator usage and model loading patterns
+ 
+ ### 5. Example Camera Trajectory
+ 
+ Created `examples/simple_trajectory.json` with a basic 5-camera forward-moving trajectory to help users get started.
+ 
+ ## Key Design Decisions
+ 
+ ### Why 15 Seconds?
+ 
+ The GPU duration budget was set to 15 seconds for the following reasons:
+ 1. Generation takes ~7 seconds on A100/A800
+ 2. Additional time is needed for:
+    - Input processing (image resizing, camera parsing)
+    - Export to PLY format
+    - Buffer for slower GPUs or variable load
+ 3. The ZeroGPU default is 60 seconds, so 15 seconds is conservative
+ 
+ ### Model Loading Strategy
+ 
+ The model is loaded **once** in global scope, not inside the `@spaces.GPU` decorator:
+ 
+ **Advantages:**
+ - Model loads at startup, not on every generation
+ - Faster response time for users
+ - More efficient use of the GPU time budget
+ - Follows ZeroGPU best practices
+ 
+ **Implementation:**
+ ```python
+ # Global scope - loads once at startup
+ generation_system = GenerationSystem(...)
+ 
+ # GPU decorator - only for inference
+ @spaces.GPU(duration=15)
+ def generate_scene(...):
+     return generation_system.generate(...)
+ ```
+ 
+ ### Input Format
+ 
+ Camera trajectories are provided as JSON to keep the Gradio interface simple:
+ 
+ ```json
+ {
+   "cameras": [
+     {
+       "quaternion": [w, x, y, z],
+       "position": [x, y, z],
+       "fx": 352.0,
+       "fy": 352.0,
+       "cx": 352.0,
+       "cy": 240.0
+     }
+   ]
+ }
+ ```
+ 
+ This differs from the Flask API, which used nested dictionaries in the POST request.
+ 
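+ A small sketch for loading and sanity-checking a trajectory in this format (using the example file added in this commit):
+ 
+ ```python
+ import json
+ 
+ with open("examples/simple_trajectory.json") as f:
+     data = json.load(f)
+ 
+ for cam in data["cameras"]:
+     assert len(cam["quaternion"]) == 4   # [w, x, y, z]
+     assert len(cam["position"]) == 3     # [x, y, z]
+     assert all(k in cam for k in ("fx", "fy", "cx", "cy"))
+ print(f"{len(data['cameras'])} cameras loaded")
+ ```
+ 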
+ ## Deployment Instructions
+ 
+ ### Local Testing
+ 
+ Test the Gradio app locally before deploying:
+ 
+ ```bash
+ python app_gradio.py
+ ```
+ 
+ This will start the Gradio interface at `http://localhost:7860`.
+ 
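+ The script also accepts a few optional flags (defined via `argparse` in `app_gradio.py`):
+ 
+ ```bash
+ # Use a local checkpoint, pick a GPU, and offload the T5 encoder to CPU
+ python app_gradio.py --ckpt ./model.ckpt --gpu 0 --offload_t5
+ ```
+ 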
+ ### Hugging Face Spaces Deployment
+ 
+ 1. **Create a new Space:**
+    - Go to https://huggingface.co/spaces
+    - Click "Create new Space"
+    - Select "ZeroGPU" as hardware
+ 
+ 2. **Upload files:**
+    - Push this repository to the Space
+    - Ensure `app_gradio.py` is set as the app file in README.md
+ 
+ 3. **Configuration:**
+    - The Space will automatically use the YAML frontmatter in README.md
+    - Model checkpoint will auto-download from HuggingFace Hub
+    - No additional configuration needed
+ 
+ 4. **Optional: Enable the `--offload_t5` flag:**
+    - Edit `app_gradio.py` to add `offload_t5=True` in the `GenerationSystem` initialization
+    - This reduces GPU memory usage but may slightly increase generation time
+ 
+ ## Limitations
+ 
+ ### ZeroGPU Constraints
+ 
+ 1. **60-second hard limit:** Cannot exceed 60 seconds per GPU call
+ 2. **No torch.compile:** Not supported in the ZeroGPU environment
+ 3. **Gradio only:** Must use the Gradio SDK (no Flask or other frameworks)
+ 4. **Python 3.10.13:** Recommended Python version
+ 
+ ### Feature Differences from the Flask App
+ 
+ The Gradio app (`app_gradio.py`) differs from the original Flask app (`app.py`):
+ 
+ **Missing features:**
+ - Custom HTML/CSS interface
+ - Real-time 3D preview with Spark.js
+ - Manual camera trajectory recording with mouse/keyboard
+ - Template-based trajectory generation
+ - Queue visualization with progress bars
+ - Concurrent request handling
+ 
+ **Present features:**
+ - Image and text prompts
+ - Camera trajectory input (via JSON)
+ - PLY file generation and download
+ - Simple, accessible Gradio interface
+ 
+ ### Recommended Usage
+ 
+ For **ZeroGPU deployment:**
+ - Use `app_gradio.py`
+ - Keep camera trajectories reasonable (≤24 frames)
+ - Consider enabling `--offload_t5` for memory savings
+ 
+ For **local development with full features:**
+ - Use `app.py`
+ - Enjoy the full custom UI with interactive camera controls
+ - Support for multiple concurrent generations
+ 
+ ## Testing
+ 
+ ### Test the Gradio App
+ 
+ ```bash
+ # Start the app
+ python app_gradio.py
+ 
+ # In the browser (http://localhost:7860):
+ # 1. Upload an image (optional)
+ # 2. Enter text prompt (optional)
+ # 3. Paste example camera JSON from examples/simple_trajectory.json
+ # 4. Select resolution (24x480x704)
+ # 5. Click "Generate 3D Scene"
+ ```
+ 
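+ The app can also be exercised programmatically with `gradio_client` (a sketch; the endpoint name `/generate_scene` assumes Gradio's default of naming API endpoints after the wrapped function):
+ 
+ ```python
+ from gradio_client import Client
+ 
+ client = Client("http://localhost:7860")
+ with open("examples/simple_trajectory.json") as f:
+     camera_json = f.read()
+ 
+ # Arguments mirror the UI inputs: image, text, camera JSON, resolution.
+ ply_path, status = client.predict(
+     None,                        # no image prompt
+     "a quiet forest clearing",   # text prompt
+     camera_json,
+     "24x480x704",
+     api_name="/generate_scene",
+ )
+ print(status, ply_path)
+ ```
+ 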
+ ### Verify GPU Decorator
+ 
+ Check that model loading happens outside the decorator:
+ 
+ ```python
+ # Good - model loads once at startup
+ generation_system = GenerationSystem(...)
+ 
+ @spaces.GPU(duration=15)
+ def generate_scene(...):
+     return generation_system.generate(...)
+ 
+ # Bad - would reload the model on every call (slow!)
+ @spaces.GPU(duration=15)
+ def generate_scene(...):
+     generation_system = GenerationSystem(...)  # Don't do this!
+     return generation_system.generate(...)
+ ```
+ 
+ ## Troubleshooting
+ 
+ ### "GPU budget exceeded"
+ 
+ **Cause:** Generation took longer than 15 seconds.
+ 
+ **Solutions:**
+ - Reduce the number of frames in the camera trajectory
+ - Enable the `--offload_t5` flag
+ - Increase the duration: `@spaces.GPU(duration=20)`
+ 
+ ### "Out of memory"
+ 
+ **Cause:** GPU memory exhausted.
+ 
+ **Solutions:**
+ - Enable T5 offloading: `offload_t5=True`
+ - Enable VAE offloading: `offload_vae=True`
+ - Reduce the resolution
+ - Reduce the number of frames
+ 
+ ### "Model checkpoint not found"
+ 
+ **Cause:** The automatic download failed.
+ 
+ **Solutions:**
+ - Check the internet connection
+ - Verify HuggingFace access
+ - Manually download the checkpoint and specify it with the `--ckpt` flag
+ 
+ ## Future Improvements
+ 
+ Potential enhancements for ZeroGPU deployment:
+ 
+ 1. **Gradio Blocks UI:** Add more interactive controls
+ 2. **Example gallery:** Pre-loaded example camera trajectories
+ 3. **3D visualization:** Embed a PLY viewer in the Gradio interface
+ 4. **Video preview:** Show the rendered video before downloading the PLY
+ 5. **Dynamic duration:** Adjust the GPU budget based on camera count (see the sketch below)
+ 
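+ For the dynamic-duration idea, a sketch of what this could look like: the `spaces` package accepts a callable for `duration` that receives the same arguments as the decorated function (the per-frame estimate below is an assumption, not a measured number):
+ 
+ ```python
+ import spaces
+ 
+ def gpu_duration(image_prompt, text_prompt, camera_json, resolution):
+     # Base budget plus a rough per-frame allowance, capped at the 60 s hard limit.
+     n_frame = int(resolution.split("x")[0])
+     return min(60, 10 + n_frame // 4)
+ 
+ @spaces.GPU(duration=gpu_duration)
+ def generate_scene(image_prompt, text_prompt, camera_json, resolution):
+     ...
+ ```
+ 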
+ ## References
+ 
+ - [ZeroGPU Documentation](https://huggingface.co/docs/hub/en/spaces-zerogpu)
+ - [Gradio Documentation](https://gradio.app/docs/)
+ - [FlashWorld Paper](https://arxiv.org/pdf/2510.13678)
+ - [FlashWorld Project Page](https://imlixinyang.github.io/FlashWorld-Project-Page)
app_gradio.py ADDED
@@ -0,0 +1,528 @@
+ try:
+     import spaces
+     GPU = spaces.GPU
+     print("spaces GPU is available")
+ except ImportError:
+     # Fallback no-op decorator so the app also runs outside HF Spaces.
+     def GPU(duration=15):
+         def decorator(func):
+             return func
+         return decorator
+     print("spaces GPU is NOT available, using fallback decorator")
+ 
+ import os
+ import torch
+ import numpy as np
+ import imageio
+ import json
+ import time
+ from PIL import Image
+ import gradio as gr
+ from huggingface_hub import hf_hub_download
+ import einops
+ import torch.nn as nn
+ import torch.nn.functional as F
+ 
+ from models import *
+ from utils import *
+ from transformers import T5TokenizerFast, UMT5EncoderModel
+ from diffusers import FlowMatchEulerDiscreteScheduler
+ 
+ 
+ class MyFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler):
+     def index_for_timestep(self, timestep, schedule_timesteps=None):
+         if schedule_timesteps is None:
+             schedule_timesteps = self.timesteps
+ 
+         return torch.argmin(
+             (timestep - schedule_timesteps.to(timestep.device)).abs(), dim=0).item()
+ 
+ 
+ class GenerationSystem(nn.Module):
+     def __init__(self, ckpt_path=None, device="cuda:0", offload_t5=False, offload_vae=False):
+         super().__init__()
+         self.device = device
+         self.offload_t5 = offload_t5
+         self.offload_vae = offload_vae
+ 
+         self.latent_dim = 48
+         self.temporal_downsample_factor = 4
+         self.spatial_downsample_factor = 16
+ 
+         self.feat_dim = 1024
+ 
+         self.latent_patch_size = 2
+ 
+         self.denoising_steps = [0, 250, 500, 750]
+ 
+         model_id = "Wan-AI/Wan2.2-TI2V-5B-Diffusers"
+ 
+         self.vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float).eval()
+ 
+         # Patch the VAE's causal 3D convs for single-frame use by dropping temporal padding.
+         from models.autoencoder_kl_wan import WanCausalConv3d
+         with torch.no_grad():
+             for name, module in self.vae.named_modules():
+                 if isinstance(module, WanCausalConv3d):
+                     time_pad = module._padding[4]
+                     module.padding = (0, module._padding[2], module._padding[0])
+                     module._padding = (0, 0, 0, 0, 0, 0)
+                     module.weight = torch.nn.Parameter(module.weight[:, :, time_pad:].clone())
+ 
+         self.vae.requires_grad_(False)
+ 
+         self.register_buffer('latents_mean', torch.tensor(self.vae.config.latents_mean).float().view(1, self.vae.config.z_dim, 1, 1, 1).to(self.device))
+         self.register_buffer('latents_std', torch.tensor(self.vae.config.latents_std).float().view(1, self.vae.config.z_dim, 1, 1, 1).to(self.device))
+ 
+         self.latent_scale_fn = lambda x: (x - self.latents_mean) / self.latents_std
+         self.latent_unscale_fn = lambda x: x * self.latents_std + self.latents_mean
+ 
+         self.tokenizer = T5TokenizerFast.from_pretrained(model_id, subfolder="tokenizer")
+ 
+         self.text_encoder = UMT5EncoderModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.float32).eval().requires_grad_(False).to(self.device if not self.offload_t5 else "cpu")
+ 
+         self.transformer = WanTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.float32).train().requires_grad_(False)
+ 
+         # Extend the input patch embedding for extra channels: 6 raymap channels + condition latents.
+         self.transformer.patch_embedding.weight = nn.Parameter(F.pad(self.transformer.patch_embedding.weight, (0, 0, 0, 0, 0, 0, 0, 6 + self.latent_dim)))
+ 
+         # Extend proj_out so the transformer also emits feat_dim reconstruction features.
+         weight = self.transformer.proj_out.weight.reshape(self.latent_patch_size ** 2, self.latent_dim, self.transformer.proj_out.weight.shape[1])
+         bias = self.transformer.proj_out.bias.reshape(self.latent_patch_size ** 2, self.latent_dim)
+ 
+         extra_weight = torch.randn(self.latent_patch_size ** 2, self.feat_dim, self.transformer.proj_out.weight.shape[1]) * 0.02
+         extra_bias = torch.zeros(self.latent_patch_size ** 2, self.feat_dim)
+ 
+         self.transformer.proj_out.weight = nn.Parameter(torch.cat([weight, extra_weight], dim=1).flatten(0, 1).detach().clone())
+         self.transformer.proj_out.bias = nn.Parameter(torch.cat([bias, extra_bias], dim=1).flatten(0, 1).detach().clone())
+ 
+         self.recon_decoder = WANDecoderPixelAligned3DGSReconstructionModel(self.vae, self.feat_dim, use_render_checkpointing=True, use_network_checkpointing=False).train().requires_grad_(False).to(self.device)
+ 
+         self.scheduler = MyFlowMatchEulerDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler", shift=3)
+ 
+         self.register_buffer('timesteps', self.scheduler.timesteps.clone().to(self.device))
+ 
+         self.transformer.disable_gradient_checkpointing()
+         self.transformer.gradient_checkpointing = False
+ 
+         self.add_feedback_for_transformer()
+ 
+         if ckpt_path is not None:
+             state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=False)
+             self.transformer.load_state_dict(state_dict["transformer"])
+             self.recon_decoder.load_state_dict(state_dict["recon_decoder"])
+             print(f"Loaded {ckpt_path}.")
+ 
+         from quant import FluxFp8GeMMProcessor
+ 
+         FluxFp8GeMMProcessor(self.transformer)
+ 
+         del self.vae.post_quant_conv, self.vae.decoder
+         self.vae.to(self.device if not self.offload_vae else "cpu")
+ 
+         self.transformer.to(self.device)
+ 
+     def add_feedback_for_transformer(self):
+         # Extra input channels for the feedback loop: previous features + previous latents.
+         self.use_feedback = True
+         self.transformer.patch_embedding.weight = nn.Parameter(F.pad(self.transformer.patch_embedding.weight, (0, 0, 0, 0, 0, 0, 0, self.feat_dim + self.latent_dim)))
+ 
+     def encode_text(self, texts):
+         max_sequence_length = 512
+ 
+         text_inputs = self.tokenizer(
+             texts,
+             padding="max_length",
+             max_length=max_sequence_length,
+             truncation=True,
+             add_special_tokens=True,
+             return_attention_mask=True,
+             return_tensors="pt",
+         )
+         if getattr(self, "offload_t5", False):
+             text_input_ids = text_inputs.input_ids.to("cpu")
+             mask = text_inputs.attention_mask.to("cpu")
+         else:
+             text_input_ids = text_inputs.input_ids.to(self.device)
+             mask = text_inputs.attention_mask.to(self.device)
+         seq_lens = mask.gt(0).sum(dim=1).long()
+ 
+         if getattr(self, "offload_t5", False):
+             with torch.no_grad():
+                 text_embeds = self.text_encoder(text_input_ids, mask).last_hidden_state.to(self.device)
+         else:
+             text_embeds = self.text_encoder(text_input_ids, mask).last_hidden_state
+         text_embeds = [u[:v] for u, v in zip(text_embeds, seq_lens)]
+         text_embeds = torch.stack(
+             [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in text_embeds], dim=0
+         )
+         return text_embeds.float()
+ 
+     def forward_generator(self, noisy_latents, raymaps, condition_latents, t, text_embeds, cameras, render_cameras, image_height, image_width, need_3d_mode=True):
+ 
+         out = self.transformer(
+             hidden_states=torch.cat([noisy_latents, raymaps, condition_latents], dim=1),
+             timestep=t,
+             encoder_hidden_states=text_embeds,
+             return_dict=False,
+         )[0]
+ 
+         v_pred, feats = out.split([self.latent_dim, self.feat_dim], dim=1)
+ 
+         sigma = torch.stack([self.scheduler.sigmas[self.scheduler.index_for_timestep(_t)] for _t in t.unbind(0)], dim=0).to(self.device)
+         latents_pred_2d = noisy_latents - sigma * v_pred
+ 
+         if need_3d_mode:
+             scene_params = self.recon_decoder(
+                 einops.rearrange(feats, 'B C T H W -> (B T) C H W').unsqueeze(2),
+                 einops.rearrange(self.latent_unscale_fn(latents_pred_2d.detach()), 'B C T H W -> (B T) C H W').unsqueeze(2),
+                 cameras
+             ).flatten(1, -2)
+ 
+             images_pred, _ = self.recon_decoder.render(scene_params.unbind(0), render_cameras, image_height, image_width, bg_mode="white")
+ 
+             latents_pred_3d = einops.rearrange(self.latent_scale_fn(self.vae.encode(
+                 einops.rearrange(images_pred, 'B T C H W -> (B T) C H W', T=images_pred.shape[1]).unsqueeze(2).to(self.device if not self.offload_vae else "cpu").float()
+             ).latent_dist.sample().to(self.device)).squeeze(2), '(B T) C H W -> B C T H W', T=images_pred.shape[1]).to(noisy_latents.dtype)
+ 
+         return {
+             '2d': latents_pred_2d,
+             '3d': latents_pred_3d if need_3d_mode else None,
+             'rgb_3d': images_pred if need_3d_mode else None,
+             'scene': scene_params if need_3d_mode else None,
+             'feat': feats
+         }
+ 
+     @torch.no_grad()
+     @torch.amp.autocast(dtype=torch.bfloat16, device_type="cuda")
+     def generate(self, cameras, n_frame, image=None, text="", image_index=0, image_height=480, image_width=704, video_output_path=None):
+         with torch.no_grad():
+             batch_size = 1
+ 
+             cameras = cameras.to(self.device).unsqueeze(0)
+ 
+             if cameras.shape[1] != n_frame:
+                 render_cameras = cameras.clone()
+                 cameras = sample_from_dense_cameras(cameras.squeeze(0), torch.linspace(0, 1, n_frame, device=self.device)).unsqueeze(0)
+             else:
+                 render_cameras = cameras
+ 
+             cameras, ref_w2c, T_norm = normalize_cameras(cameras, return_meta=True, n_frame=None)
+ 
+             render_cameras = normalize_cameras(render_cameras, ref_w2c=ref_w2c, T_norm=T_norm, n_frame=None)
+ 
+             text = "[Static] " + text
+ 
+             text_embeds = self.encode_text([text])
+ 
+             masks = torch.zeros(batch_size, n_frame, device=self.device)
+ 
+             condition_latents = torch.zeros(batch_size, self.latent_dim, n_frame, image_height // self.spatial_downsample_factor, image_width // self.spatial_downsample_factor, device=self.device)
+ 
+             if image is not None:
+                 image = image.to(self.device)
+ 
+                 latent = self.latent_scale_fn(self.vae.encode(
+                     image.unsqueeze(0).unsqueeze(2).to(self.device if not self.offload_vae else "cpu").float()
+                 ).latent_dist.sample().to(self.device)).squeeze(2)
+ 
+                 masks[:, image_index] = 1
+                 condition_latents[:, :, image_index] = latent
+ 
+             raymaps = create_raymaps(cameras, image_height // self.spatial_downsample_factor, image_width // self.spatial_downsample_factor)
+             raymaps = einops.rearrange(raymaps, 'B T H W C -> B C T H W', T=n_frame)
+ 
+             noise = torch.randn(batch_size, self.latent_dim, n_frame, image_height // self.spatial_downsample_factor, image_width // self.spatial_downsample_factor, device=self.device)
+ 
+             noisy_latents = noise
+ 
+             torch.cuda.empty_cache()
+ 
+             if self.use_feedback:
+                 prev_latents_pred = torch.zeros(batch_size, self.latent_dim, n_frame, image_height // self.spatial_downsample_factor, image_width // self.spatial_downsample_factor, device=self.device)
+ 
+                 prev_feats = torch.zeros(batch_size, self.feat_dim, n_frame, image_height // self.spatial_downsample_factor, image_width // self.spatial_downsample_factor, device=self.device)
+ 
+             for i in range(len(self.denoising_steps)):
+                 t_ids = torch.full((noisy_latents.shape[0],), self.denoising_steps[i], device=self.device)
+ 
+                 t = self.timesteps[t_ids]
+ 
+                 if self.use_feedback:
+                     _condition_latents = torch.cat([condition_latents, prev_feats, prev_latents_pred], dim=1)
+                 else:
+                     _condition_latents = condition_latents
+ 
+                 if i < len(self.denoising_steps) - 1:
+                     out = self.forward_generator(noisy_latents, raymaps, _condition_latents, t, text_embeds, cameras, cameras, image_height, image_width, need_3d_mode=True)
+ 
+                     latents_pred = out["3d"]
+ 
+                     if self.use_feedback:
+                         prev_latents_pred = latents_pred
+                         prev_feats = out['feat']
+ 
+                     noisy_latents = self.scheduler.scale_noise(latents_pred, self.timesteps[torch.full((noisy_latents.shape[0],), self.denoising_steps[i + 1], device=self.device)], torch.randn_like(noise))
+ 
+                 else:
+                     # Final step: predict latents and decode 3D Gaussians directly.
+                     out = self.transformer(
+                         hidden_states=torch.cat([noisy_latents, raymaps, _condition_latents], dim=1),
+                         timestep=t,
+                         encoder_hidden_states=text_embeds,
+                         return_dict=False,
+                     )[0]
+ 
+                     v_pred, feats = out.split([self.latent_dim, self.feat_dim], dim=1)
+ 
+                     sigma = torch.stack([self.scheduler.sigmas[self.scheduler.index_for_timestep(_t)] for _t in t.unbind(0)], dim=0).to(self.device)
+                     latents_pred = noisy_latents - sigma * v_pred
+ 
+                     scene_params = self.recon_decoder(
+                         einops.rearrange(feats, 'B C T H W -> (B T) C H W').unsqueeze(2),
+                         einops.rearrange(self.latent_unscale_fn(latents_pred.detach()), 'B C T H W -> (B T) C H W').unsqueeze(2),
+                         cameras
+                     ).flatten(1, -2)
+ 
+             if video_output_path is not None:
+                 interpolated_images_pred, _ = self.recon_decoder.render(scene_params.unbind(0), render_cameras, image_height, image_width, bg_mode="white")
+ 
+                 interpolated_images_pred = einops.rearrange(interpolated_images_pred[0].clamp(-1, 1).add(1).div(2), 'T C H W -> T H W C')
+ 
+                 interpolated_images_pred = [torch.cat([img], dim=1).detach().cpu().mul(255).numpy().astype(np.uint8) for i, img in enumerate(interpolated_images_pred.unbind(0))]
+ 
+                 imageio.mimwrite(video_output_path, interpolated_images_pred, fps=15, quality=8, macro_block_size=1)
+ 
+             scene_params = scene_params[0]
+ 
+             scene_params = scene_params.detach().cpu()
+ 
+             return scene_params, ref_w2c, T_norm
+ 
+ 
+ # Initialize the model globally (outside GPU decorator)
+ print("Initializing model...")
+ import argparse
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--ckpt", default=None)
+ parser.add_argument("--gpu", type=int, default=0)
+ parser.add_argument("--offload_t5", action="store_true", help="Offload T5 encoder to CPU to save GPU memory")
+ args, _ = parser.parse_known_args()
+ 
+ # Ensure model.ckpt exists, download if not present
+ if args.ckpt is None:
+     from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
+     ckpt_path = os.path.join(HUGGINGFACE_HUB_CACHE, "models--imlixinyang--FlashWorld", "snapshots", "6a8e88c6f88678ac098e4c82675f0aee555d6e5d", "model.ckpt")
+     if not os.path.exists(ckpt_path):
+         print("Downloading model checkpoint...")
+         # hf_hub_download returns the path of the cached file; use it as the checkpoint path.
+         ckpt_path = hf_hub_download(repo_id="imlixinyang/FlashWorld", filename="model.ckpt")
+ else:
+     ckpt_path = args.ckpt
+ 
+ device = f"cuda:{args.gpu}" if torch.cuda.is_available() else "cpu"
+ print(f"Loading model on device: {device}")
+ generation_system = GenerationSystem(ckpt_path=ckpt_path, device=device, offload_t5=args.offload_t5)
+ print("Model loaded successfully!")
+ 
+ 
+ # GPU-decorated generation function with 15-second budget
+ @GPU(duration=15)
+ def generate_scene(
+     image_prompt,
+     text_prompt,
+     camera_json,
+     resolution,
+     progress=gr.Progress()
+ ):
+     """
+     Generate 3D scene from image/text prompts and camera trajectory.
+ 
+     Args:
+         image_prompt: PIL Image or None
+         text_prompt: str
+         camera_json: JSON string with camera trajectory
+         resolution: str in format "NxHxW"
+     """
+     try:
+         progress(0, desc="Parsing inputs...")
+ 
+         # Parse resolution
+         n_frame, image_height, image_width = [int(x) for x in resolution.split('x')]
+ 
+         # Parse camera JSON
+         try:
+             camera_data = json.loads(camera_json)
+             if "cameras" not in camera_data or len(camera_data["cameras"]) == 0:
+                 return None, "Error: No cameras found in JSON"
+         except json.JSONDecodeError as e:
+             return None, f"Error: Invalid JSON format: {str(e)}"
+ 
+         progress(0.1, desc="Processing camera trajectory...")
+ 
+         # Convert cameras to tensor
+         cameras = []
+         for cam in camera_data["cameras"]:
+             quat = cam["quaternion"]  # [w, x, y, z]
+             pos = cam["position"]  # [x, y, z]
+             fx = cam.get("fx", 0.5 / np.tan(0.5 * 60 * np.pi / 180) * image_height)
+             fy = cam.get("fy", 0.5 / np.tan(0.5 * 60 * np.pi / 180) * image_height)
+             cx = cam.get("cx", 0.5 * image_width)
+             cy = cam.get("cy", 0.5 * image_height)
+ 
+             camera_tensor = np.array([
+                 quat[0], quat[1], quat[2], quat[3],  # quaternion
+                 pos[0], pos[1], pos[2],  # position
+                 fx / image_width, fy / image_height,  # normalized focal lengths
+                 cx / image_width, cy / image_height  # normalized principal point
+             ], dtype=np.float32)
+             cameras.append(camera_tensor)
+ 
+         cameras = torch.from_numpy(np.stack(cameras, axis=0))
+ 
+         # Process image prompt
+         image = None
+         if image_prompt is not None:
+             progress(0.2, desc="Processing image prompt...")
+             # Convert PIL to tensor and resize
+             img = image_prompt.convert('RGB')
+             w, h = img.size
+ 
+             # Center crop
+             if image_height / h > image_width / w:
+                 scale = image_height / h
+             else:
+                 scale = image_width / w
+ 
+             new_h = int(image_height / scale)
+             new_w = int(image_width / scale)
+ 
+             img = img.crop((
+                 (w - new_w) // 2, (h - new_h) // 2,
+                 new_w + (w - new_w) // 2, new_h + (h - new_h) // 2
+             )).resize((image_width, image_height))
+ 
+             image = torch.from_numpy(np.array(img)).float().permute(2, 0, 1) / 255.0 * 2 - 1
+ 
+         progress(0.3, desc="Generating 3D scene (this takes ~7 seconds)...")
+ 
+         # Generate scene
+         output_path = f"/tmp/flashworld_output_{int(time.time())}.mp4"
+         scene_params, ref_w2c, T_norm = generation_system.generate(
+             cameras=cameras,
+             n_frame=n_frame,
+             image=image,
+             text=text_prompt,
+             image_index=0,
+             image_height=image_height,
+             image_width=image_width,
+             video_output_path=output_path
+         )
+ 
+         progress(0.9, desc="Exporting result...")
+ 
+         # Export to PLY
+         ply_path = f"/tmp/flashworld_output_{int(time.time())}.ply"
+         export_ply_for_gaussians(ply_path, scene_params, opacity_threshold=0.001, T_norm=T_norm)
+ 
+         progress(1.0, desc="Done!")
+ 
+         return ply_path, f"Generation successful! Scene contains {scene_params.shape[0]} Gaussians."
+ 
+     except Exception as e:
+         import traceback
+         error_msg = f"Error during generation: {str(e)}\n{traceback.format_exc()}"
+         print(error_msg)
+         return None, error_msg
+ 
+ 
+ # Create Gradio interface
+ def create_demo():
+     with gr.Blocks(title="FlashWorld: Fast 3D Scene Generation") as demo:
+         gr.Markdown("""
+         # FlashWorld: High-quality 3D Scene Generation within Seconds
+ 
+         Generate 3D scenes in ~7 seconds from text or image prompts with camera trajectory!
+ 
+         **Note:** This demo uses ZeroGPU with a 15-second budget. Please ensure your camera trajectory is reasonable.
+         """)
+ 
+         with gr.Row():
+             with gr.Column():
+                 # Input controls
+                 gr.Markdown("### 1. Prompts")
+                 image_input = gr.Image(label="Image Prompt (Optional)", type="pil")
+                 text_input = gr.Textbox(
+                     label="Text Prompt",
+                     placeholder="A beautiful mountain landscape with trees...",
+                     value=""
+                 )
+ 
+                 gr.Markdown("### 2. Camera Trajectory")
+                 camera_json_input = gr.Code(
+                     label="Camera JSON",
+                     language="json",
+                     value="""{
+   "cameras": [
+     {
+       "quaternion": [1, 0, 0, 0],
+       "position": [0, 0, 0],
+       "fx": 352.0,
+       "fy": 352.0,
+       "cx": 352.0,
+       "cy": 240.0
+     },
+     {
+       "quaternion": [1, 0, 0, 0],
+       "position": [0, 0, -0.5],
+       "fx": 352.0,
+       "fy": 352.0,
+       "cx": 352.0,
+       "cy": 240.0
+     }
+   ]
+ }""",
+                     lines=15
+                 )
+ 
+                 gr.Markdown("### 3. Resolution")
+                 resolution_input = gr.Dropdown(
+                     label="Resolution (NxHxW)",
+                     choices=["24x480x704", "24x704x480"],
+                     value="24x480x704"
+                 )
+ 
+                 generate_btn = gr.Button("Generate 3D Scene", variant="primary", size="lg")
+ 
+             with gr.Column():
+                 # Output
+                 gr.Markdown("### Output")
+                 output_file = gr.File(label="Download PLY File")
+                 output_message = gr.Textbox(label="Status", lines=3)
+ 
+                 gr.Markdown("""
+                 ### Instructions:
+                 1. **Optional:** Upload an image prompt
+                 2. **Optional:** Enter a text description
+                 3. **Required:** Provide camera trajectory as JSON
+                 4. Select resolution (24 frames recommended)
+                 5. Click "Generate 3D Scene"
+ 
+                 The camera JSON should contain an array of cameras with:
+                 - `quaternion`: [w, x, y, z] rotation
+                 - `position`: [x, y, z] translation
+                 - `fx`, `fy`: focal lengths (pixels)
+                 - `cx`, `cy`: principal point (pixels)
+ 
+                 **Tips:**
+                 - Generation takes ~7 seconds on GPU
+                 - Download the PLY file to view in 3D viewers
+                 - Use reasonable camera trajectories (not too many frames)
+                 """)
+ 
+         # Connect the button
+         generate_btn.click(
+             fn=generate_scene,
+             inputs=[image_input, text_input, camera_json_input, resolution_input],
+             outputs=[output_file, output_message]
+         )
+ 
+     return demo
+ 
+ 
+ if __name__ == "__main__":
+     demo = create_demo()
+     demo.launch(server_name="0.0.0.0", server_port=7860, share=False)
examples/simple_trajectory.json ADDED
@@ -0,0 +1,44 @@
+ {
+   "cameras": [
+     {
+       "quaternion": [1, 0, 0, 0],
+       "position": [0, 0, 0],
+       "fx": 352.0,
+       "fy": 352.0,
+       "cx": 352.0,
+       "cy": 240.0
+     },
+     {
+       "quaternion": [1, 0, 0, 0],
+       "position": [0, 0, -0.2],
+       "fx": 352.0,
+       "fy": 352.0,
+       "cx": 352.0,
+       "cy": 240.0
+     },
+     {
+       "quaternion": [1, 0, 0, 0],
+       "position": [0, 0, -0.4],
+       "fx": 352.0,
+       "fy": 352.0,
+       "cx": 352.0,
+       "cy": 240.0
+     },
+     {
+       "quaternion": [1, 0, 0, 0],
+       "position": [0, 0, -0.6],
+       "fx": 352.0,
+       "fy": 352.0,
+       "cx": 352.0,
+       "cy": 240.0
+     },
+     {
+       "quaternion": [1, 0, 0, 0],
+       "position": [0, 0, -0.8],
+       "fx": 352.0,
+       "fy": 352.0,
+       "cx": 352.0,
+       "cy": 240.0
+     }
+   ]
+ }
requirements.txt CHANGED
@@ -11,9 +11,9 @@ opencv-python==4.12.0.88
  av==15.1.0
  plyfile==1.1.2
  ftfy==6.3.1
- flask==3.1.2
  gradio==5.49.1
  gsplat==1.5.2
  accelerate==1.10.1
+ spaces
  git+https://github.com/huggingface/diffusers.git@447e8322f76efea55d4769cd67c372edbf0715b8
  git+https://github.com/nerfstudio-project/gsplat.git@32f2a54d21c7ecb135320bb02b136b7407ae5712