opsiclear-admin committed on
Commit
02da9ed
·
verified ·
1 Parent(s): 304d5fa

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. .gitattributes +44 -44
  2. README.md +73 -4
  3. app_local.py +559 -0
.gitattributes CHANGED
@@ -1,38 +1,38 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  assets/example_image/0a34fae7ba57cb8870df5325b9c30ea474def1b0913c19c596655b85a79fdee4.webp filter=lfs diff=lfs merge=lfs -text
37
  assets/example_image/0e4984a9b3765ce80e9853443f9319ecedf90885c74b56cccfebc09402740f8a.webp filter=lfs diff=lfs merge=lfs -text
38
  assets/example_image/0f168a4b1b6e96c72e9627c97a212c27a4572250ff58e25703b9d0c2bc74191a.webp filter=lfs diff=lfs merge=lfs -text
@@ -108,6 +108,15 @@ assets/example_multi_image/popmart_3.png filter=lfs diff=lfs merge=lfs -text
108
  assets/example_multi_image/rabbit_1.png filter=lfs diff=lfs merge=lfs -text
109
  assets/example_multi_image/rabbit_2.png filter=lfs diff=lfs merge=lfs -text
110
  assets/example_multi_image/rabbit_3.png filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
111
  assets/example_multi_image/tiger_1.png filter=lfs diff=lfs merge=lfs -text
112
  assets/example_multi_image/tiger_2.png filter=lfs diff=lfs merge=lfs -text
113
  assets/example_multi_image/tiger_3.png filter=lfs diff=lfs merge=lfs -text
@@ -123,12 +132,3 @@ assets/hdri/night.exr filter=lfs diff=lfs merge=lfs -text
123
  assets/hdri/sunrise.exr filter=lfs diff=lfs merge=lfs -text
124
  assets/hdri/sunset.exr filter=lfs diff=lfs merge=lfs -text
125
  assets/teaser.webp filter=lfs diff=lfs merge=lfs -text
126
- assets/example_multi_image/scan55_1.png filter=lfs diff=lfs merge=lfs -text
127
- assets/example_multi_image/scan55_2.png filter=lfs diff=lfs merge=lfs -text
128
- assets/example_multi_image/scan55_3.png filter=lfs diff=lfs merge=lfs -text
129
- assets/example_multi_image/scan65_1.png filter=lfs diff=lfs merge=lfs -text
130
- assets/example_multi_image/scan65_2.png filter=lfs diff=lfs merge=lfs -text
131
- assets/example_multi_image/scan65_3.png filter=lfs diff=lfs merge=lfs -text
132
- assets/example_multi_image/scan69_1.png filter=lfs diff=lfs merge=lfs -text
133
- assets/example_multi_image/scan69_2.png filter=lfs diff=lfs merge=lfs -text
134
- assets/example_multi_image/scan69_3.png filter=lfs diff=lfs merge=lfs -text
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  assets/example_image/0a34fae7ba57cb8870df5325b9c30ea474def1b0913c19c596655b85a79fdee4.webp filter=lfs diff=lfs merge=lfs -text
37
  assets/example_image/0e4984a9b3765ce80e9853443f9319ecedf90885c74b56cccfebc09402740f8a.webp filter=lfs diff=lfs merge=lfs -text
38
  assets/example_image/0f168a4b1b6e96c72e9627c97a212c27a4572250ff58e25703b9d0c2bc74191a.webp filter=lfs diff=lfs merge=lfs -text
 
108
  assets/example_multi_image/rabbit_1.png filter=lfs diff=lfs merge=lfs -text
109
  assets/example_multi_image/rabbit_2.png filter=lfs diff=lfs merge=lfs -text
110
  assets/example_multi_image/rabbit_3.png filter=lfs diff=lfs merge=lfs -text
111
+ assets/example_multi_image/scan55_1.png filter=lfs diff=lfs merge=lfs -text
112
+ assets/example_multi_image/scan55_2.png filter=lfs diff=lfs merge=lfs -text
113
+ assets/example_multi_image/scan55_3.png filter=lfs diff=lfs merge=lfs -text
114
+ assets/example_multi_image/scan65_1.png filter=lfs diff=lfs merge=lfs -text
115
+ assets/example_multi_image/scan65_2.png filter=lfs diff=lfs merge=lfs -text
116
+ assets/example_multi_image/scan65_3.png filter=lfs diff=lfs merge=lfs -text
117
+ assets/example_multi_image/scan69_1.png filter=lfs diff=lfs merge=lfs -text
118
+ assets/example_multi_image/scan69_2.png filter=lfs diff=lfs merge=lfs -text
119
+ assets/example_multi_image/scan69_3.png filter=lfs diff=lfs merge=lfs -text
120
  assets/example_multi_image/tiger_1.png filter=lfs diff=lfs merge=lfs -text
121
  assets/example_multi_image/tiger_2.png filter=lfs diff=lfs merge=lfs -text
122
  assets/example_multi_image/tiger_3.png filter=lfs diff=lfs merge=lfs -text
 
132
  assets/hdri/sunrise.exr filter=lfs diff=lfs merge=lfs -text
133
  assets/hdri/sunset.exr filter=lfs diff=lfs merge=lfs -text
134
  assets/teaser.webp filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: TRELLIS.2 Multi-Image
3
  emoji: 🧊
4
  colorFrom: blue
5
  colorTo: purple
@@ -10,10 +10,79 @@ app_file: app.py
10
  pinned: false
11
  license: mit
12
  short_description: Multi-view image to 3D generation
 
13
  ---
14
 
15
- # TRELLIS.2 Multi-Image Conditioning
16
 
17
- Multi-view image to 3D generation using [TRELLIS.2](https://microsoft.github.io/TRELLIS.2) with multi-image conditioning.
18
 
19
- Upload multiple views of an object to generate a 3D model.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: TRELLIS.2 Multi-Image Conditioning
3
  emoji: 🧊
4
  colorFrom: blue
5
  colorTo: purple
 
10
  pinned: false
11
  license: mit
12
  short_description: Multi-view image to 3D generation
13
+ suggested_hardware: a100-large
14
  ---
15
 
16
+ # TRELLIS.2 Multi-Image Conditioning Fork
17
 
18
+ This fork extends [TRELLIS.2](https://github.com/microsoft/TRELLIS.2) with multi-image conditioning and Windows support.
19
 
20
+ ## What's New
21
+
22
+ - **Multi-image conditioning**: Use multiple views for better 3D reconstruction
23
+ - **Windows support**: Runs on Windows with automatic `sdpa` attention fallback
24
+
25
+ > For the interactive visualization tool, see the [viser_view branch](https://github.com/OpsiClear/Trellis2_multi_image_conditioning/tree/viser_view).
26
+
27
+ ## Installation
28
+
29
+ ```sh
30
+ git clone https://github.com/OpsiClear/Trellis2_multi_image_conditioning.git --recursive
31
+ cd Trellis2_multi_image_conditioning
32
+ . ./setup.sh --new-env --basic --flash-attn --nvdiffrast --nvdiffrec --cumesh --o-voxel --flexgemm
33
+ ```
34
+
35
+ > On Windows, `flash-attn` is unavailable. The code automatically falls back to PyTorch's native `sdpa` backend.
36
+
37
+ ## Usage
38
+
39
+ ### Multi-Image Generation
40
+
41
+ ```python
42
+ from trellis2.pipelines import Trellis2ImageTo3DPipeline
43
+ from PIL import Image
44
+
45
+ pipeline = Trellis2ImageTo3DPipeline.from_pretrained("microsoft/TRELLIS.2-4B")
46
+ pipeline.cuda()
47
+
48
+ # Load multiple views
49
+ images = [Image.open(f"view_{i}.png") for i in range(4)]
50
+
51
+ # Generate with multi-image conditioning
52
+ mesh = pipeline.run_multi_image(images)[0]
53
+ ```
54
+
55
+ Or run the example:
56
+ ```sh
57
+ python example_multi_image.py
58
+ ```
59
+
60
+ ### Other Examples
61
+
62
+ ```sh
63
+ python example.py # Single image generation
64
+ python app.py # Gradio web demo
65
+ python example_texturing.py # PBR texture generation
66
+ python app_texturing.py # Texture generation web demo
67
+ ```
68
+
69
+ ## Attribution
70
+
71
+ This project is a fork of [TRELLIS.2](https://github.com/microsoft/TRELLIS.2) by Microsoft Corporation, originally released under the MIT License.
72
+
73
+ For full documentation, training instructions, and model details, see the original repository.
74
+
75
+ If you use this code, please cite the original paper:
76
+
77
+ ```bibtex
78
+ @article{xiang2025trellis2,
79
+ title={Native and Compact Structured Latents for 3D Generation},
80
+ author={Xiang, Jianfeng and Chen, Xiaoxue and Xu, Sicheng and Wang, Ruicheng and Lv, Zelong and Deng, Yu and Zhu, Hongyuan and Dong, Yue and Zhao, Hao and Yuan, Nicholas Jing and Yang, Jiaolong},
81
+ journal={Tech report},
82
+ year={2025}
83
+ }
84
+ ```
85
+
86
+ ## License
87
+
88
+ The original TRELLIS.2 code is MIT licensed by Microsoft Corporation. New additions in this fork (multi-image conditioning) are licensed under [AGPL-3.0](LICENSE).
app_local.py ADDED
@@ -0,0 +1,559 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Local testing version of app.py for Windows
3
+ - Uses sdpa backend instead of flash_attn_3
4
+ - Loads models at startup (no lazy imports needed)
5
+ - Mock @spaces.GPU decorator
6
+ """
7
+
8
+ import gradio as gr
9
+ from gradio_client import Client, handle_file
10
+ from concurrent.futures import ThreadPoolExecutor
11
+
12
+ import os
13
+ os.environ["OPENCV_IO_ENABLE_OPENEXR"] = '1'
14
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
15
+ os.environ["ATTN_BACKEND"] = "sdpa" # Windows fallback
16
+ os.environ["FLEX_GEMM_AUTOTUNE_CACHE_PATH"] = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'autotune_cache.json')
17
+ os.environ["FLEX_GEMM_AUTOTUNER_VERBOSE"] = '1'
18
+
19
+ from datetime import datetime
20
+ import shutil
21
+ import cv2
22
+ from typing import *
23
+ import torch
24
+ import numpy as np
25
+ from PIL import Image
26
+ import base64
27
+ import io
28
+ import tempfile
29
+
30
+ from trellis2.modules.sparse import SparseTensor
31
+ from trellis2.pipelines import Trellis2ImageTo3DPipeline
32
+ from trellis2.renderers import EnvMap
33
+ from trellis2.utils import render_utils
34
+ import o_voxel
35
+
36
+ # Mock spaces.GPU decorator for local testing
37
class MockSpaces:
    """Local stand-in for the ``spaces`` module used by the hosted app.

    Only ``GPU`` is provided, and it is a no-op: decorated functions are
    returned unchanged so the same ``@spaces.GPU(...)`` call sites run locally.
    """

    @staticmethod
    def GPU(duration=60):
        """Return the decorated function untouched.

        Works both as ``@spaces.GPU(duration=120)`` / ``@spaces.GPU()`` and as
        a bare ``@spaces.GPU``.

        Args:
            duration: Accepted for call-site compatibility and otherwise
                ignored. When the decorator is applied bare, the decorated
                function itself arrives in this slot.

        Returns:
            The decorated function (bare usage) or a pass-through decorator.
        """
        if callable(duration):
            # Bare ``@spaces.GPU``: the first positional argument is the
            # function being decorated, so hand it straight back.
            return duration

        def decorator(fn):
            return fn

        return decorator
43
+
44
# Module-level ``spaces`` object so ``@spaces.GPU(...)`` resolves to the mock.
spaces = MockSpaces()

# Upper bound for the seed slider and for randomly drawn seeds (int32 max).
MAX_SEED = np.iinfo(np.int32).max
# Per-session output directories are created under ``tmp`` next to this script.
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
# Render modes offered by the previewer: display name, icon path for the mode
# button, and the key used to look up that mode's frames in the render output.
MODES = [
    {"name": "Normal", "icon": "assets/app/normal.png", "render_key": "normal"},
    {"name": "Clay render", "icon": "assets/app/clay.png", "render_key": "clay"},
    {"name": "Base color", "icon": "assets/app/basecolor.png", "render_key": "base_color"},
    {"name": "HDRI forest", "icon": "assets/app/hdri_forest.png", "render_key": "shaded_forest"},
    {"name": "HDRI sunset", "icon": "assets/app/hdri_sunset.png", "render_key": "shaded_sunset"},
    {"name": "HDRI courtyard", "icon": "assets/app/hdri_courtyard.png", "render_key": "shaded_courtyard"},
]
# Number of view angles rendered per mode; also sets the previewer slider range.
STEPS = 8
# Indices of the render mode and view angle that are visible initially.
DEFAULT_MODE = 3
DEFAULT_STEP = 3
59
+
60
+
61
# Stylesheet text: overrides for Gradio's default stepper/wrapper styles plus
# the layout of the custom previewer (mode buttons, image display, slider).
css = """
/* Overwrite Gradio Default Style */
.stepper-wrapper { padding: 0; }
.stepper-container { padding: 0; align-items: center; }
.step-button { flex-direction: row; }
.step-connector { transform: none; }
.step-number { width: 16px; height: 16px; }
.step-label { position: relative; bottom: 0; }
.wrap.center.full { inset: 0; height: 100%; }
.wrap.center.full.translucent { background: var(--block-background-fill); }
.meta-text-center { display: block !important; position: absolute !important; top: unset !important; bottom: 0 !important; right: 0 !important; transform: unset !important; }

/* Previewer */
.previewer-container { position: relative; font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif; width: 100%; height: 722px; margin: 0 auto; padding: 20px; display: flex; flex-direction: column; align-items: center; justify-content: center; }
.previewer-container .tips-icon { position: absolute; right: 10px; top: 10px; z-index: 10; border-radius: 10px; color: #fff; background-color: var(--color-accent); padding: 3px 6px; user-select: none; }
.previewer-container .tips-text { position: absolute; right: 10px; top: 50px; color: #fff; background-color: var(--color-accent); border-radius: 10px; padding: 6px; text-align: left; max-width: 300px; z-index: 10; transition: all 0.3s; opacity: 0%; user-select: none; }
.previewer-container .tips-text p { font-size: 14px; line-height: 1.2; }
.tips-icon:hover + .tips-text { display: block; opacity: 100%; }
.previewer-container .mode-row { width: 100%; display: flex; gap: 8px; justify-content: center; margin-bottom: 20px; flex-wrap: wrap; }
.previewer-container .mode-btn { width: 24px; height: 24px; border-radius: 50%; cursor: pointer; opacity: 0.5; transition: all 0.2s; border: 2px solid #ddd; object-fit: cover; }
.previewer-container .mode-btn:hover { opacity: 0.9; transform: scale(1.1); }
.previewer-container .mode-btn.active { opacity: 1; border-color: var(--color-accent); transform: scale(1.1); }
.previewer-container .display-row { margin-bottom: 20px; min-height: 400px; width: 100%; flex-grow: 1; display: flex; justify-content: center; align-items: center; }
.previewer-container .previewer-main-image { max-width: 100%; max-height: 100%; flex-grow: 1; object-fit: contain; display: none; }
.previewer-container .previewer-main-image.visible { display: block; }
.previewer-container .slider-row { width: 100%; display: flex; flex-direction: column; align-items: center; gap: 10px; padding: 0 10px; }
.previewer-container input[type=range] { -webkit-appearance: none; width: 100%; max-width: 400px; background: transparent; }
.previewer-container input[type=range]::-webkit-slider-runnable-track { width: 100%; height: 8px; cursor: pointer; background: #ddd; border-radius: 5px; }
.previewer-container input[type=range]::-webkit-slider-thumb { height: 20px; width: 20px; border-radius: 50%; background: var(--color-accent); cursor: pointer; -webkit-appearance: none; margin-top: -6px; box-shadow: 0 2px 5px rgba(0,0,0,0.2); transition: transform 0.1s; }
.previewer-container input[type=range]::-webkit-slider-thumb:hover { transform: scale(1.2); }
.gradio-container .padded:has(.previewer-container) { padding: 0 !important; }
.gradio-container:has(.previewer-container) [data-testid="block-label"] { position: absolute; top: 0; left: 0; }
"""
94
+
95
+
96
# Script markup defining the previewer's client-side helpers:
# refreshView(mode, step) shows the image whose id is 'view-m<mode>-s<step>'
# and marks the matching mode button active; passing -1 for either argument
# keeps the value of the currently visible image. selectMode and
# onSliderChange are the handlers referenced by the generated preview HTML.
head = """
<script>
function refreshView(mode, step) {
const allImgs = document.querySelectorAll('.previewer-main-image');
for (let i = 0; i < allImgs.length; i++) {
const img = allImgs[i];
if (img.classList.contains('visible')) {
const id = img.id;
const [_, m, s] = id.split('-');
if (mode === -1) mode = parseInt(m.slice(1));
if (step === -1) step = parseInt(s.slice(1));
break;
}
}
allImgs.forEach(img => img.classList.remove('visible'));
const targetId = 'view-m' + mode + '-s' + step;
const targetImg = document.getElementById(targetId);
if (targetImg) { targetImg.classList.add('visible'); }
const allBtns = document.querySelectorAll('.mode-btn');
allBtns.forEach((btn, idx) => {
if (idx === mode) btn.classList.add('active');
else btn.classList.remove('active');
});
}
function selectMode(mode) { refreshView(mode, -1); }
function onSliderChange(val) { refreshView(-1, parseInt(val)); }
</script>
"""
124
+
125
+
126
# Placeholder markup used as the initial value of the preview HTML component:
# an empty previewer container holding a dimmed image icon.
empty_html = """
<div class="previewer-container">
<svg style="opacity: .5; height: var(--size-5); color: var(--body-text-color);"
xmlns="http://www.w3.org/2000/svg" width="100%" height="100%" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"><rect x="3" y="3" width="18" height="18" rx="2" ry="2"></rect><circle cx="8.5" cy="8.5" r="1.5"></circle><polyline points="21 15 16 10 5 21"></polyline></svg>
</div>
"""
132
+
133
+
134
def image_to_base64(image):
    """Encode an image as a JPEG ``data:`` URI.

    The image is converted to RGB, written as JPEG (quality 85) into an
    in-memory buffer, and the bytes are base64-encoded.

    Args:
        image: Object exposing ``convert`` and ``save`` as used below
            (a PIL image at the call sites in this file).

    Returns:
        A string of the form ``data:image/jpeg;base64,<payload>``.
    """
    rgb_image = image.convert("RGB")
    with io.BytesIO() as jpeg_buffer:
        rgb_image.save(jpeg_buffer, format="jpeg", quality=85)
        jpeg_bytes = jpeg_buffer.getvalue()
    payload = base64.b64encode(jpeg_bytes).decode()
    return "data:image/jpeg;base64," + payload
140
+
141
+
142
def start_session(req: gr.Request):
    """Create the per-session working directory under ``TMP_DIR``.

    The directory is named after ``req.session_hash``; an already existing
    directory is left as is.
    """
    session_name = str(req.session_hash)
    os.makedirs(os.path.join(TMP_DIR, session_name), exist_ok=True)
145
+
146
+
147
def end_session(req: gr.Request):
    """Delete the per-session working directory, if it exists.

    The directory is located under ``TMP_DIR`` and named after
    ``req.session_hash``.
    """
    session_dir = os.path.join(TMP_DIR, str(req.session_hash))
    if not os.path.exists(session_dir):
        return
    shutil.rmtree(session_dir)
151
+
152
+
153
def remove_background(input: Image.Image) -> Image.Image:
    """Remove the background of an image through the remote ``rmbg_client``.

    The image is converted to RGB, written to a temporary PNG, and sent to
    the client's ``/image`` endpoint; the first entry of the first result is
    opened and returned.

    Args:
        input: Image to process.

    Returns:
        The image opened from the path returned by the service.

    Note:
        ``rmbg_client`` is not defined in the visible part of this file;
        presumably it is created elsewhere in the module - confirm.
    """
    # Reserve a unique path, then close the handle right away so the file can
    # be written, uploaded and finally removed by name (an open handle would
    # block removal on Windows).
    with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as f:
        tmp_path = f.name
    try:
        input.convert('RGB').save(tmp_path)
        result_path = rmbg_client.predict(handle_file(tmp_path), api_name="/image")[0][0]
        output = Image.open(result_path)
    finally:
        # Clean up even when saving or the remote call fails.
        os.unlink(tmp_path)
    return output
161
+
162
+
163
def preprocess_image(input: Image.Image) -> Image.Image:
    """Crop an image to its foreground and premultiply it by alpha.

    Steps:
      1. Downscale so the longest side is at most 1024 pixels.
      2. Use the image's own alpha channel when it is RGBA with at least one
         non-opaque pixel; otherwise obtain one via ``remove_background``.
      3. Crop to a square centered on the bounding box of pixels whose alpha
         exceeds ``0.8 * 255``.
      4. Multiply RGB by alpha and return the 3-channel result.

    Args:
        input: Source image.

    Returns:
        The cropped, alpha-premultiplied image.

    Raises:
        ValueError: If no pixel has alpha above the foreground threshold
            (e.g. a fully transparent image), so no bounding box exists.
    """
    has_alpha = False
    if input.mode == 'RGBA':
        alpha = np.array(input)[:, :, 3]
        if not np.all(alpha == 255):
            has_alpha = True

    max_size = max(input.size)
    scale = min(1, 1024 / max_size)
    if scale < 1:
        input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS)

    output = input if has_alpha else remove_background(input)

    alpha = np.array(output)[:, :, 3]
    foreground = np.argwhere(alpha > 0.8 * 255)  # rows of (y, x)
    if foreground.size == 0:
        # Without this guard the min/max reductions below fail with an
        # opaque "zero-size array" error.
        raise ValueError("No foreground pixels found in the image (alpha channel is empty).")

    bbox = np.min(foreground[:, 1]), np.min(foreground[:, 0]), np.max(foreground[:, 1]), np.max(foreground[:, 0])
    center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
    size = int(max(bbox[2] - bbox[0], bbox[3] - bbox[1]))
    # Square crop window around the bounding-box center.
    bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2
    output = output.crop(bbox)

    output = np.array(output).astype(np.float32) / 255
    output = output[:, :, :3] * output[:, :, 3:4]  # premultiply RGB by alpha
    output = Image.fromarray((output * 255).astype(np.uint8))
    return output
190
+
191
+
192
def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
    """Preprocess a list of images concurrently.

    Args:
        images: Sequence of tuples whose first element is the image; the
            remaining elements are ignored.

    Returns:
        The results of ``preprocess_image`` in input order; an empty list
        when ``images`` is empty.
    """
    pil_images = [item[0] for item in images]
    if not pil_images:
        # ThreadPoolExecutor rejects max_workers=0, so short-circuit here.
        return []
    with ThreadPoolExecutor(max_workers=min(4, len(pil_images))) as executor:
        return list(executor.map(preprocess_image, pil_images))
197
+
198
+
199
def pack_state(latents):
    """Convert pipeline latents into a dict of NumPy arrays.

    Args:
        latents: Triple ``(shape_slat, tex_slat, res)``. ``feats`` of both
            latents and ``coords`` of ``shape_slat`` are moved to CPU and
            converted with ``.numpy()``; ``res`` is stored unchanged.

    Returns:
        Dict with keys ``shape_slat_feats``, ``tex_slat_feats``, ``coords``
        and ``res``.
    """
    shape_slat, tex_slat, res = latents

    def _to_host(tensor):
        return tensor.cpu().numpy()

    state = {}
    state['shape_slat_feats'] = _to_host(shape_slat.feats)
    state['tex_slat_feats'] = _to_host(tex_slat.feats)
    state['coords'] = _to_host(shape_slat.coords)
    state['res'] = res
    return state
207
+
208
+
209
def unpack_state(state: dict):
    """Rebuild the latents stored by ``pack_state`` on the CUDA device.

    Args:
        state: Dict with ``shape_slat_feats``, ``tex_slat_feats``, ``coords``
            (NumPy arrays) and ``res``.

    Returns:
        Tuple ``(shape_slat, tex_slat, res)``; ``tex_slat`` is derived from
        ``shape_slat`` via ``replace`` with the texture features.
    """
    def _to_gpu(key):
        return torch.from_numpy(state[key]).cuda()

    shape_slat = SparseTensor(
        feats=_to_gpu('shape_slat_feats'),
        coords=_to_gpu('coords'),
    )
    tex_feats = _to_gpu('tex_slat_feats')
    tex_slat = shape_slat.replace(tex_feats)
    return shape_slat, tex_slat, state['res']
216
+
217
+
218
def get_seed(randomize_seed: bool, seed: int) -> int:
    """Return ``seed``, or a random value in ``[0, MAX_SEED)`` when requested.

    Args:
        randomize_seed: Draw a fresh seed with ``np.random.randint`` when true.
        seed: Seed returned unchanged when ``randomize_seed`` is false.
    """
    if not randomize_seed:
        return seed
    return np.random.randint(0, MAX_SEED)
220
+
221
+
222
def prepare_multi_example() -> List[Image.Image]:
    """Build one side-by-side example image per multi-view case.

    Files in ``assets/example_multi_image`` are grouped by the text before
    the first underscore. For each case, views ``<case>_1.png`` to
    ``<case>_3.png`` are resized to a height of 512 and concatenated
    horizontally. Cases with fewer than three views are skipped.

    Returns:
        The concatenated example images ordered by case name, or an empty
        list when the example directory does not exist.
    """
    example_dir = "assets/example_multi_image"
    if not os.path.exists(example_dir):
        return []

    # Sorted so the example order is the same on every run (plain set
    # iteration order is not stable across interpreter invocations).
    cases = sorted({name.split('_')[0] for name in os.listdir(example_dir) if '_' in name})

    images = []
    for case in cases:
        views = []
        for i in range(1, 4):
            img_path = f'{example_dir}/{case}_{i}.png'
            if not os.path.exists(img_path):
                continue
            # Context manager closes the file once the pixels are copied out.
            with Image.open(img_path) as img:
                W, H = img.size
                views.append(np.array(img.resize((int(W / H * 512), 512))))
        if len(views) == 3:
            images.append(Image.fromarray(np.concatenate(views, axis=1)))
    return images
240
+
241
+
242
def split_image(image: Image.Image) -> List[Image.Image]:
    """Split a horizontally concatenated image into its separate views.

    Views are the maximal runs of columns containing at least one pixel with
    alpha > 0. Each run is cut out (with one leading transparent column when
    available) and passed through ``preprocess_image``.

    Args:
        image: Image whose array has an alpha channel at index 3.

    Returns:
        The preprocessed views, left to right.
    """
    pixels = np.array(image)
    opaque_cols = np.any(pixels[..., 3] > 0, axis=0)

    # Pad with transparent sentinels so runs touching the left or right image
    # edge still produce a start/end transition; otherwise starts and ends
    # would be paired incorrectly or a view would be dropped.
    padded = np.concatenate(([False], opaque_cols, [False]))
    first_cols = np.flatnonzero(~padded[:-1] & padded[1:]).tolist()  # first opaque column of each run
    stop_cols = np.flatnonzero(padded[:-1] & ~padded[1:]).tolist()   # one past the last opaque column

    views = []
    for first, stop in zip(first_cols, stop_cols):
        left = max(first - 1, 0)  # keep the single transparent column before the run
        views.append(Image.fromarray(pixels[:, left:stop]))
    return [preprocess_image(view) for view in views]
252
+
253
+
254
@spaces.GPU(duration=120)
def image_to_3d(
    image: Image.Image,
    seed: int,
    resolution: str,
    ss_guidance_strength: float,
    ss_guidance_rescale: float,
    ss_sampling_steps: int,
    ss_rescale_t: float,
    shape_slat_guidance_strength: float,
    shape_slat_guidance_rescale: float,
    shape_slat_sampling_steps: int,
    shape_slat_rescale_t: float,
    tex_slat_guidance_strength: float,
    tex_slat_guidance_rescale: float,
    tex_slat_sampling_steps: int,
    tex_slat_rescale_t: float,
    req: gr.Request,
    progress=gr.Progress(track_tqdm=True),
    multiimages: List[Tuple[Image.Image, str]] = None,
    is_multiimage: bool = False,
    multiimage_algo: Literal["multidiffusion", "stochastic"] = "stochastic",
) -> Tuple[dict, str]:
    """Generate a 3D asset and build the HTML preview for it.

    Runs the pipeline in single-image or multi-image mode, renders ``STEPS``
    snapshots of the resulting mesh, and embeds them as base64 JPEGs in the
    previewer markup, one per render mode in ``MODES`` and view angle.

    Args:
        image: Conditioning image for single-image mode.
        seed: Seed passed to the pipeline.
        resolution: One of ``"512"``, ``"1024"``, ``"1536"``; selects the
            pipeline type.
        ss_*: Sampler settings for the sparse-structure stage.
        shape_slat_*: Sampler settings for the shape stage.
        tex_slat_*: Sampler settings for the texture stage.
        req: Request object (not used in the body).
        progress: Progress tracker (not used in the body).
        multiimages: Tuples whose first element is a conditioning image; used
            only when ``is_multiimage`` is true.
        is_multiimage: Select multi-image conditioning.
        multiimage_algo: Passed as ``mode`` to ``run_multi_image``.

    Returns:
        ``(state, full_html)``: the packed latents from ``pack_state`` and
        the previewer HTML.

    Raises:
        gr.Error: If multi-image mode is selected without any images.
        KeyError: If ``resolution`` is not one of the supported values.
    """
    # Arguments shared by both pipeline entry points.
    run_kwargs = dict(
        seed=seed,
        preprocess_image=False,
        sparse_structure_sampler_params={
            "steps": ss_sampling_steps,
            "guidance_strength": ss_guidance_strength,
            "guidance_rescale": ss_guidance_rescale,
            "rescale_t": ss_rescale_t,
        },
        shape_slat_sampler_params={
            "steps": shape_slat_sampling_steps,
            "guidance_strength": shape_slat_guidance_strength,
            "guidance_rescale": shape_slat_guidance_rescale,
            "rescale_t": shape_slat_rescale_t,
        },
        tex_slat_sampler_params={
            "steps": tex_slat_sampling_steps,
            "guidance_strength": tex_slat_guidance_strength,
            "guidance_rescale": tex_slat_guidance_rescale,
            "rescale_t": tex_slat_rescale_t,
        },
        pipeline_type={
            "512": "512",
            "1024": "1024_cascade",
            "1536": "1536_cascade",
        }[resolution],
        return_latent=True,
    )

    if is_multiimage:
        if not multiimages:
            # Surface a readable message instead of a TypeError on None.
            raise gr.Error("Multi-image mode requires at least one input image.")
        outputs, latents = pipeline.run_multi_image(
            [img[0] for img in multiimages],
            mode=multiimage_algo,
            **run_kwargs,
        )
    else:
        outputs, latents = pipeline.run(image, **run_kwargs)

    mesh = outputs[0]
    mesh.simplify(16777216)
    images = render_utils.render_snapshot(mesh, resolution=1024, r=2, fov=36, nviews=STEPS, envmap=envmap)
    state = pack_state(latents)
    torch.cuda.empty_cache()

    def encode_preview_image(args):
        m_idx, s_idx, render_key = args
        img_base64 = image_to_base64(Image.fromarray(images[render_key][s_idx]))
        return (m_idx, s_idx, img_base64)

    encode_tasks = [(m_idx, s_idx, mode['render_key']) for m_idx, mode in enumerate(MODES) for s_idx in range(STEPS)]

    # JPEG encoding of the frames is independent per frame, so fan it out.
    # executor.map yields results in task order (mode-major, then view angle).
    with ThreadPoolExecutor(max_workers=8) as executor:
        encoded_results = list(executor.map(encode_preview_image, encode_tasks))

    image_tags = []
    for m_idx, s_idx, img_base64 in encoded_results:
        unique_id = f"view-m{m_idx}-s{s_idx}"
        vis_class = "visible" if (m_idx == DEFAULT_MODE and s_idx == DEFAULT_STEP) else ""
        image_tags.append(
            f'<img id="{unique_id}" class="previewer-main-image {vis_class}" src="{img_base64}" loading="eager">'
        )
    images_html = "".join(image_tags)

    # NOTE(review): "icon_base64" is not part of the MODES literal; presumably
    # it is filled in at startup elsewhere in this file - confirm.
    button_tags = []
    for idx, mode in enumerate(MODES):
        active_class = "active" if idx == DEFAULT_MODE else ""
        button_tags.append(
            f'<img src="{mode["icon_base64"]}" class="mode-btn {active_class}" onclick="selectMode({idx})" title="{mode["name"]}">'
        )
    btns_html = "".join(button_tags)

    full_html = f"""
    <div class="previewer-container">
        <div class="tips-wrapper">
            <div class="tips-icon">Tips</div>
            <div class="tips-text">
                <p>Render Mode - Click on the circular buttons to switch between different render modes.</p>
                <p>View Angle - Drag the slider to change the view angle.</p>
            </div>
        </div>
        <div class="display-row">{images_html}</div>
        <div class="mode-row" id="btn-group">{btns_html}</div>
        <div class="slider-row">
            <input type="range" id="custom-slider" min="0" max="{STEPS - 1}" value="{DEFAULT_STEP}" step="1" oninput="onSliderChange(this.value)">
        </div>
    </div>
    """
    return state, full_html
387
+
388
+
389
@spaces.GPU(duration=120)
def extract_glb(
    state: dict,
    decimation_target: int,
    texture_size: int,
    req: gr.Request,
    progress=gr.Progress(track_tqdm=True),
) -> Tuple[str, str]:
    """Decode the stored latents and export the mesh as a GLB file.

    Args:
        state: Packed latents as produced by ``pack_state``.
        decimation_target: Forwarded to ``o_voxel.postprocess.to_glb``.
        texture_size: Forwarded to ``o_voxel.postprocess.to_glb``.
        req: Request whose ``session_hash`` names the output directory.
        progress: Progress tracker (not used in the body).

    Returns:
        The path of the written GLB file, twice (one value per output).
    """
    session_dir = os.path.join(TMP_DIR, str(req.session_hash))

    shape_slat, tex_slat, res = unpack_state(state)
    decoded = pipeline.decode_latent(shape_slat, tex_slat, res)
    mesh = decoded[0]
    mesh.simplify(16777216)

    glb_options = dict(
        vertices=mesh.vertices,
        faces=mesh.faces,
        attr_volume=mesh.attrs,
        coords=mesh.coords,
        attr_layout=pipeline.pbr_attr_layout,
        grid_size=res,
        aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
        decimation_target=decimation_target,
        texture_size=texture_size,
        remesh=True,
        remesh_band=1,
        remesh_project=0,
        use_tqdm=True,
    )
    glb = o_voxel.postprocess.to_glb(**glb_options)

    # File name carries a millisecond-resolution timestamp.
    stamp = datetime.now()
    timestamp = "{}.{:03d}".format(stamp.strftime("%Y-%m-%dT%H%M%S"), stamp.microsecond // 1000)

    os.makedirs(session_dir, exist_ok=True)
    glb_path = os.path.join(session_dir, f'sample_{timestamp}.glb')
    glb.export(glb_path, extension_webp=True)
    torch.cuda.empty_cache()
    return glb_path, glb_path
423
+
424
+
425
# Gradio UI: input/settings column, preview/extract walkthrough, and example
# gallery, followed by the event wiring that connects them to the handlers.
with gr.Blocks(delete_cache=(600, 600)) as demo:
    gr.Markdown("""
    ## Image to 3D Asset with [TRELLIS.2](https://microsoft.github.io/TRELLIS.2) - Local Testing
    * Upload an image and click Generate to create a 3D asset.
    """)

    with gr.Row():
        # ---- Left column: prompt inputs and generation settings ----
        with gr.Column(scale=1, min_width=360):
            with gr.Tabs() as input_tabs:
                with gr.Tab(label="Single Image", id=0) as single_image_input_tab:
                    image_prompt = gr.Image(
                        label="Image Prompt",
                        format="png",
                        image_mode="RGBA",
                        type="pil",
                        height=400,
                    )
                with gr.Tab(label="Multiple Images", id=1) as multiimage_input_tab:
                    multiimage_prompt = gr.Gallery(
                        label="Image Prompt",
                        format="png",
                        type="pil",
                        height=400,
                        columns=3,
                    )
                    gr.Markdown(value="Input different views of the object in separate images.")

            resolution = gr.Radio(choices=["512", "1024", "1536"], value="1024", label="Resolution")
            seed = gr.Slider(minimum=0, maximum=MAX_SEED, value=0, step=1, label="Seed")
            randomize_seed = gr.Checkbox(value=True, label="Randomize Seed")
            decimation_target = gr.Slider(
                minimum=100000, maximum=500000, value=300000, step=10000, label="Decimation Target"
            )
            texture_size = gr.Slider(minimum=1024, maximum=4096, value=2048, step=1024, label="Texture Size")

            generate_btn = gr.Button(value="Generate")

            with gr.Accordion(label="Advanced Settings", open=False):
                gr.Markdown(value="Stage 1: Sparse Structure Generation")
                with gr.Row():
                    ss_guidance_strength = gr.Slider(
                        minimum=1.0, maximum=10.0, value=7.5, step=0.1, label="Guidance Strength"
                    )
                    ss_guidance_rescale = gr.Slider(
                        minimum=0.0, maximum=1.0, value=0.7, step=0.01, label="Guidance Rescale"
                    )
                    ss_sampling_steps = gr.Slider(
                        minimum=1, maximum=50, value=12, step=1, label="Sampling Steps"
                    )
                    ss_rescale_t = gr.Slider(
                        minimum=1.0, maximum=6.0, value=5.0, step=0.1, label="Rescale T"
                    )

                gr.Markdown(value="Stage 2: Shape Generation")
                with gr.Row():
                    shape_slat_guidance_strength = gr.Slider(
                        minimum=1.0, maximum=10.0, value=7.5, step=0.1, label="Guidance Strength"
                    )
                    shape_slat_guidance_rescale = gr.Slider(
                        minimum=0.0, maximum=1.0, value=0.5, step=0.01, label="Guidance Rescale"
                    )
                    shape_slat_sampling_steps = gr.Slider(
                        minimum=1, maximum=50, value=12, step=1, label="Sampling Steps"
                    )
                    shape_slat_rescale_t = gr.Slider(
                        minimum=1.0, maximum=6.0, value=3.0, step=0.1, label="Rescale T"
                    )

                gr.Markdown(value="Stage 3: Material Generation")
                with gr.Row():
                    tex_slat_guidance_strength = gr.Slider(
                        minimum=1.0, maximum=10.0, value=1.0, step=0.1, label="Guidance Strength"
                    )
                    tex_slat_guidance_rescale = gr.Slider(
                        minimum=0.0, maximum=1.0, value=0.0, step=0.01, label="Guidance Rescale"
                    )
                    tex_slat_sampling_steps = gr.Slider(
                        minimum=1, maximum=50, value=12, step=1, label="Sampling Steps"
                    )
                    tex_slat_rescale_t = gr.Slider(
                        minimum=1.0, maximum=6.0, value=3.0, step=0.1, label="Rescale T"
                    )

                multiimage_algo = gr.Radio(
                    choices=["stochastic", "multidiffusion"],
                    value="stochastic",
                    label="Multi-image Algorithm",
                )

        # ---- Middle column: two-step walkthrough (preview, then GLB) ----
        with gr.Column(scale=10):
            with gr.Walkthrough(selected=0) as walkthrough:
                with gr.Step("Preview", id=0):
                    preview_output = gr.HTML(
                        value=empty_html,
                        label="3D Asset Preview",
                        show_label=True,
                        container=True,
                    )
                    extract_btn = gr.Button(value="Extract GLB")
                with gr.Step("Extract", id=1):
                    glb_output = gr.Model3D(
                        label="Extracted GLB",
                        height=724,
                        show_label=True,
                        display_mode="solid",
                        clear_color=(0.25, 0.25, 0.25, 1.0),
                    )
                    download_btn = gr.DownloadButton(label="Download GLB")

        # ---- Right column: clickable multi-image examples ----
        with gr.Column(scale=1, min_width=172) as multiimage_example:
            examples_multi = gr.Examples(
                examples=prepare_multi_example(),
                label="Multi Image Examples",
                inputs=[image_prompt],
                fn=split_image,
                outputs=[multiimage_prompt],
                run_on_click=True,
                examples_per_page=8,
            )

    # Per-session state: which input tab is active, and the buffer that
    # image_to_3d writes and extract_glb later reads.
    is_multiimage = gr.State(value=False)
    output_buf = gr.State()

    # Session lifecycle hooks.
    demo.load(start_session)
    demo.unload(end_session)

    # Selecting a tab records whether multi-image mode is active.
    single_image_input_tab.select(fn=lambda: False, outputs=[is_multiimage])
    multiimage_input_tab.select(fn=lambda: True, outputs=[is_multiimage])

    # Uploaded images are run through preprocessing and written back to the
    # same component they came from.
    image_prompt.upload(fn=preprocess_image, inputs=[image_prompt], outputs=[image_prompt])
    multiimage_prompt.upload(fn=preprocess_images, inputs=[multiimage_prompt], outputs=[multiimage_prompt])

    # Generate: resolve the seed, jump back to the Preview step, then run the
    # image-to-3D handler with every setting in the order it expects.
    stage1_settings = [ss_guidance_strength, ss_guidance_rescale, ss_sampling_steps, ss_rescale_t]
    stage2_settings = [
        shape_slat_guidance_strength,
        shape_slat_guidance_rescale,
        shape_slat_sampling_steps,
        shape_slat_rescale_t,
    ]
    stage3_settings = [
        tex_slat_guidance_strength,
        tex_slat_guidance_rescale,
        tex_slat_sampling_steps,
        tex_slat_rescale_t,
    ]
    generation_inputs = [
        image_prompt, seed, resolution,
        *stage1_settings,
        *stage2_settings,
        *stage3_settings,
        multiimage_prompt, is_multiimage, multiimage_algo,
    ]

    seed_resolved = generate_btn.click(
        fn=get_seed,
        inputs=[randomize_seed, seed],
        outputs=[seed],
    )
    preview_shown = seed_resolved.then(
        fn=lambda: gr.Walkthrough(selected=0),
        outputs=walkthrough,
    )
    preview_shown.then(
        fn=image_to_3d,
        inputs=generation_inputs,
        outputs=[output_buf, preview_output],
    )

    # Extract: switch to the Extract step, then build the GLB from the buffer.
    extract_shown = extract_btn.click(
        fn=lambda: gr.Walkthrough(selected=1),
        outputs=walkthrough,
    )
    extract_shown.then(
        fn=extract_glb,
        inputs=[output_buf, decimation_target, texture_size],
        outputs=[glb_output, download_btn],
    )
524
+
525
+
526
# Script entry point: prepare assets, load the models onto the GPU, then
# launch the Gradio app. `rmbg_client`, `pipeline` and `envmap` are assigned
# at module level here; presumably the handlers defined earlier in the file
# read them as globals -- verify before renaming any of them.
if __name__ == "__main__":
    os.makedirs(TMP_DIR, exist_ok=True)

    # Store a base64 encoding of each mode's icon on the mode entry. The
    # context manager closes the image file once it has been encoded.
    for mode in MODES:
        with Image.open(mode['icon']) as icon:
            mode['icon_base64'] = image_to_base64(icon)

    print("Connecting to background removal service...")
    rmbg_client = Client("briaai/BRIA-RMBG-2.0")

    print("Loading TRELLIS.2 pipeline...")
    pipeline = Trellis2ImageTo3DPipeline.from_pretrained('microsoft/TRELLIS.2-4B')
    # NOTE(review): the pipeline's own background-removal model is cleared,
    # presumably because `rmbg_client` is used instead -- confirm.
    pipeline.rembg_model = None
    pipeline.low_vram = False
    pipeline.cuda()

    def _load_envmap(name):
        """Read ``assets/hdri/<name>.exr`` and wrap it in an ``EnvMap`` on the GPU.

        Raises:
            RuntimeError: if OpenCV cannot read the file. ``cv2.imread``
                returns ``None`` in that case rather than raising, which
                would otherwise surface as an opaque ``cvtColor`` error.
        """
        path = f'assets/hdri/{name}.exr'
        bgr = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if bgr is None:
            raise RuntimeError(f"Failed to load environment map: {path}")
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        return EnvMap(torch.tensor(rgb, dtype=torch.float32, device='cuda'))

    print("Loading environment maps...")
    envmap = {
        name: _load_envmap(name)
        for name in ('forest', 'sunset', 'courtyard')
    }

    print("Starting Gradio app...")
    demo.launch(css=css, head=head)