opsiclear-admin committed on
Commit
81587ad
·
verified ·
1 Parent(s): ace43a1

Add app.py

Browse files
Files changed (1) hide show
  1. app.py +861 -7
app.py CHANGED
@@ -1,7 +1,861 @@
1
- import gradio as gr
2
-
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
-
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from gradio_client import Client, handle_file
3
+ import spaces
4
+ from concurrent.futures import ThreadPoolExecutor
5
+
6
+ import os
7
+ os.environ["OPENCV_IO_ENABLE_OPENEXR"] = '1'
8
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
9
+ os.environ["ATTN_BACKEND"] = "flash_attn_3"
10
+ os.environ["FLEX_GEMM_AUTOTUNE_CACHE_PATH"] = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'autotune_cache.json')
11
+ os.environ["FLEX_GEMM_AUTOTUNER_VERBOSE"] = '1'
12
+ from datetime import datetime
13
+ import shutil
14
+ import cv2
15
+ from typing import *
16
+ import numpy as np
17
+ from PIL import Image
18
+ import base64
19
+ import io
20
+ import tempfile
21
+
22
# Lazy imports - will be loaded when GPU is available.
# These module-level names start as None and are rebound by _lazy_import(),
# which declares them all `global` and imports each one that is still None.
torch = None
SparseTensor = None
Trellis2ImageTo3DPipeline = None
EnvMap = None
render_utils = None
o_voxel = None

# Global state - initialized on first GPU call.
# _initialize_pipeline() assigns `pipeline` and `envmap` and then flips
# `_initialized` to True so later calls return immediately.
pipeline = None
envmap = None
_initialized = False
34
+
35
+
36
def _lazy_import():
    """Bind the GPU-dependent modules to their module-level placeholders.

    Each placeholder is checked on its own, so a name that is already bound
    is left untouched and only the missing ones are imported. Must be called
    from within a @spaces.GPU function.
    """
    global torch, SparseTensor, Trellis2ImageTo3DPipeline, EnvMap, render_utils, o_voxel

    # Because the names are declared global above, each import statement
    # below rebinds the module-level placeholder directly.
    if torch is None:
        import torch
    if SparseTensor is None:
        from trellis2.modules.sparse import SparseTensor
    if Trellis2ImageTo3DPipeline is None:
        from trellis2.pipelines import Trellis2ImageTo3DPipeline
    if EnvMap is None:
        from trellis2.renderers import EnvMap
    if render_utils is None:
        from trellis2.utils import render_utils
    if o_voxel is None:
        import o_voxel
57
+
58
+
59
def _load_envmap(name):
    """Load `assets/hdri/<name>.exr` and wrap it in an EnvMap on the GPU.

    Raises:
        FileNotFoundError: If OpenCV cannot read the file. cv2.imread signals
            failure by returning None instead of raising, which would
            otherwise surface later as an opaque cvtColor error.
    """
    path = f'assets/hdri/{name}.exr'
    hdri_bgr = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if hdri_bgr is None:
        raise FileNotFoundError(f"Could not read HDRI environment map: {path}")
    # OpenCV returns BGR channel order; convert to RGB before building the tensor.
    hdri_rgb = cv2.cvtColor(hdri_bgr, cv2.COLOR_BGR2RGB)
    return EnvMap(torch.tensor(hdri_rgb, dtype=torch.float32, device='cuda'))


def _initialize_pipeline():
    """Initialize the pipeline and environment maps. Must be called from within a @spaces.GPU function.

    Does nothing once `_initialized` is True. Otherwise it runs the lazy
    imports, loads the 'microsoft/TRELLIS.2-4B' pipeline onto CUDA, and fills
    `envmap` with the 'forest', 'sunset' and 'courtyard' HDRI maps.

    Raises:
        FileNotFoundError: If one of the HDRI files cannot be read.
    """
    global pipeline, envmap, _initialized
    if _initialized:
        return

    _lazy_import()

    pipeline = Trellis2ImageTo3DPipeline.from_pretrained('microsoft/TRELLIS.2-4B')
    pipeline.rembg_model = None
    pipeline.low_vram = False
    pipeline.cuda()

    envmap = {name: _load_envmap(name) for name in ('forest', 'sunset', 'courtyard')}

    # Only mark as initialized after everything above succeeded, so a failed
    # attempt is retried on the next call.
    _initialized = True
88
+
89
+
90
# Upper bound for the seed slider and for randomly drawn seeds (largest int32).
MAX_SEED = np.iinfo(np.int32).max
# Per-session working directories are created under <directory of this file>/tmp.
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
# Render modes offered by the previewer. `icon` is the button image path and
# `render_key` is the key used to look up that mode's frames in the rendered
# snapshot dict. The __main__ block adds an `icon_base64` entry to each dict.
MODES = [
    {"name": "Normal", "icon": "assets/app/normal.png", "render_key": "normal"},
    {"name": "Clay render", "icon": "assets/app/clay.png", "render_key": "clay"},
    {"name": "Base color", "icon": "assets/app/basecolor.png", "render_key": "base_color"},
    {"name": "HDRI forest", "icon": "assets/app/hdri_forest.png", "render_key": "shaded_forest"},
    {"name": "HDRI sunset", "icon": "assets/app/hdri_sunset.png", "render_key": "shaded_sunset"},
    {"name": "HDRI courtyard", "icon": "assets/app/hdri_courtyard.png", "render_key": "shaded_courtyard"},
]
# Number of views rendered per mode (passed as `nviews`); also sets the slider range.
STEPS = 8
# Index into MODES and view index that are visible when a preview is first shown.
DEFAULT_MODE = 3
DEFAULT_STEP = 3
103
+
104
+
105
# Custom stylesheet handed to demo.launch(css=...). The first section overrides
# Gradio's own classes; the rest styles the HTML previewer that image_to_3d
# emits (mode buttons, stacked preview images, slider).
css = """
/* Overwrite Gradio Default Style */
.stepper-wrapper {
    padding: 0;
}

.stepper-container {
    padding: 0;
    align-items: center;
}

.step-button {
    flex-direction: row;
}

.step-connector {
    transform: none;
}

.step-number {
    width: 16px;
    height: 16px;
}

.step-label {
    position: relative;
    bottom: 0;
}

.wrap.center.full {
    inset: 0;
    height: 100%;
}

.wrap.center.full.translucent {
    background: var(--block-background-fill);
}

.meta-text-center {
    display: block !important;
    position: absolute !important;
    top: unset !important;
    bottom: 0 !important;
    right: 0 !important;
    transform: unset !important;
}

/* Previewer */
.previewer-container {
    position: relative;
    font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif;
    width: 100%;
    height: 722px;
    margin: 0 auto;
    padding: 20px;
    display: flex;
    flex-direction: column;
    align-items: center;
    justify-content: center;
}

.previewer-container .tips-icon {
    position: absolute;
    right: 10px;
    top: 10px;
    z-index: 10;
    border-radius: 10px;
    color: #fff;
    background-color: var(--color-accent);
    padding: 3px 6px;
    user-select: none;
}

.previewer-container .tips-text {
    position: absolute;
    right: 10px;
    top: 50px;
    color: #fff;
    background-color: var(--color-accent);
    border-radius: 10px;
    padding: 6px;
    text-align: left;
    max-width: 300px;
    z-index: 10;
    transition: all 0.3s;
    opacity: 0%;
    user-select: none;
}

.previewer-container .tips-text p {
    font-size: 14px;
    line-height: 1.2;
}

.tips-icon:hover + .tips-text {
    display: block;
    opacity: 100%;
}

/* Row 1: Display Modes */
.previewer-container .mode-row {
    width: 100%;
    display: flex;
    gap: 8px;
    justify-content: center;
    margin-bottom: 20px;
    flex-wrap: wrap;
}
.previewer-container .mode-btn {
    width: 24px;
    height: 24px;
    border-radius: 50%;
    cursor: pointer;
    opacity: 0.5;
    transition: all 0.2s;
    border: 2px solid #ddd;
    object-fit: cover;
}
.previewer-container .mode-btn:hover { opacity: 0.9; transform: scale(1.1); }
.previewer-container .mode-btn.active {
    opacity: 1;
    border-color: var(--color-accent);
    transform: scale(1.1);
}

/* Row 2: Display Image */
.previewer-container .display-row {
    margin-bottom: 20px;
    min-height: 400px;
    width: 100%;
    flex-grow: 1;
    display: flex;
    justify-content: center;
    align-items: center;
}
.previewer-container .previewer-main-image {
    max-width: 100%;
    max-height: 100%;
    flex-grow: 1;
    object-fit: contain;
    display: none;
}
.previewer-container .previewer-main-image.visible {
    display: block;
}

/* Row 3: Custom HTML Slider */
.previewer-container .slider-row {
    width: 100%;
    display: flex;
    flex-direction: column;
    align-items: center;
    gap: 10px;
    padding: 0 10px;
}

.previewer-container input[type=range] {
    -webkit-appearance: none;
    width: 100%;
    max-width: 400px;
    background: transparent;
}
.previewer-container input[type=range]::-webkit-slider-runnable-track {
    width: 100%;
    height: 8px;
    cursor: pointer;
    background: #ddd;
    border-radius: 5px;
}
.previewer-container input[type=range]::-webkit-slider-thumb {
    height: 20px;
    width: 20px;
    border-radius: 50%;
    background: var(--color-accent);
    cursor: pointer;
    -webkit-appearance: none;
    margin-top: -6px;
    box-shadow: 0 2px 5px rgba(0,0,0,0.2);
    transition: transform 0.1s;
}
.previewer-container input[type=range]::-webkit-slider-thumb:hover {
    transform: scale(1.2);
}

/* Overwrite Previewer Block Style */
.gradio-container .padded:has(.previewer-container) {
    padding: 0 !important;
}

.gradio-container:has(.previewer-container) [data-testid="block-label"] {
    position: absolute;
    top: 0;
    left: 0;
}
"""
300
+
301
+
302
# Script injected into the page via demo.launch(head=...). It defines the
# selectMode() and onSliderChange() handlers that the previewer HTML calls
# from its onclick/oninput attributes. Both delegate to refreshView(), where
# -1 means "keep the current value": the current mode/step is recovered by
# parsing the id ("view-m{mode}-s{step}") of the image that is visible.
head = """
<script>
    function refreshView(mode, step) {
        // 1. Find current mode and step
        const allImgs = document.querySelectorAll('.previewer-main-image');
        for (let i = 0; i < allImgs.length; i++) {
            const img = allImgs[i];
            if (img.classList.contains('visible')) {
                const id = img.id;
                const [_, m, s] = id.split('-');
                if (mode === -1) mode = parseInt(m.slice(1));
                if (step === -1) step = parseInt(s.slice(1));
                break;
            }
        }

        // 2. Hide ALL images
        // We select all elements with class 'previewer-main-image'
        allImgs.forEach(img => img.classList.remove('visible'));

        // 3. Construct the specific ID for the current state
        // Format: view-m{mode}-s{step}
        const targetId = 'view-m' + mode + '-s' + step;
        const targetImg = document.getElementById(targetId);

        // 4. Show ONLY the target
        if (targetImg) {
            targetImg.classList.add('visible');
        }

        // 5. Update Button Highlights
        const allBtns = document.querySelectorAll('.mode-btn');
        allBtns.forEach((btn, idx) => {
            if (idx === mode) btn.classList.add('active');
            else btn.classList.remove('active');
        });
    }

    // --- Action: Switch Mode ---
    function selectMode(mode) {
        refreshView(mode, -1);
    }

    // --- Action: Slider Change ---
    function onSliderChange(val) {
        refreshView(-1, parseInt(val));
    }
</script>
"""
351
+
352
+
353
# Placeholder markup shown in the preview panel before anything has been
# generated: an empty previewer container holding a faded "image" icon.
# Plain string on purpose - there is nothing to interpolate.
empty_html = """
<div class="previewer-container">
    <svg style=" opacity: .5; height: var(--size-5); color: var(--body-text-color);"
    xmlns="http://www.w3.org/2000/svg" width="100%" height="100%" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="feather feather-image"><rect x="3" y="3" width="18" height="18" rx="2" ry="2"></rect><circle cx="8.5" cy="8.5" r="1.5"></circle><polyline points="21 15 16 10 5 21"></polyline></svg>
</div>
"""
359
+
360
+
361
def image_to_base64(image):
    """Return `image` as a JPEG `data:` URI.

    The image is converted to RGB, encoded as JPEG at quality 85, and the
    bytes are base64-encoded behind a `data:image/jpeg;base64,` prefix.
    """
    rgb_image = image.convert("RGB")
    jpeg_buffer = io.BytesIO()
    rgb_image.save(jpeg_buffer, format="jpeg", quality=85)
    payload = base64.b64encode(jpeg_buffer.getvalue()).decode()
    return "data:image/jpeg;base64," + payload
367
+
368
+
369
def start_session(req: gr.Request):
    """Create the per-session working directory `TMP_DIR/<session_hash>` if it is missing."""
    os.makedirs(os.path.join(TMP_DIR, str(req.session_hash)), exist_ok=True)
372
+
373
+
374
def end_session(req: gr.Request):
    """Delete the per-session working directory `TMP_DIR/<session_hash>` when it exists."""
    session_dir = os.path.join(TMP_DIR, str(req.session_hash))
    if not os.path.exists(session_dir):
        return
    shutil.rmtree(session_dir)
378
+
379
+
380
def remove_background(input: Image.Image) -> Image.Image:
    """Send an image to the remote background-removal client and return its result.

    The image is converted to RGB, written to a temporary PNG, and passed to
    `rmbg_client.predict(..., api_name="/image")`. The `[0][0]` element of the
    response is opened as a PIL image and returned.
    """
    with tempfile.NamedTemporaryFile(suffix='.png') as tmp_png:
        rgb_input = input.convert('RGB')
        rgb_input.save(tmp_png.name)
        response = rmbg_client.predict(handle_file(tmp_png.name), api_name="/image")
        result_path = response[0][0]
        result = Image.open(result_path)
    return result
387
+
388
+
389
def preprocess_image(input: Image.Image) -> Image.Image:
    """
    Preprocess the input image.

    Steps:
      1. Use the image's own alpha channel if it is RGBA with at least one
         non-opaque pixel; otherwise obtain a mask via `remove_background`.
      2. Downscale so the longest side is at most 1024 px.
      3. Crop to a square around the bounding box of pixels whose alpha
         exceeds 80% of full opacity.
      4. Multiply the colour channels by alpha and drop the alpha channel.

    Args:
        input (Image.Image): The image to preprocess.

    Returns:
        Image.Image: The cropped RGB image with the background blacked out.

    Raises:
        ValueError: If no pixel is opaque enough to count as foreground.
    """
    # if has alpha channel, use it directly; otherwise, remove background
    has_alpha = False
    if input.mode == 'RGBA':
        alpha = np.array(input)[:, :, 3]
        has_alpha = not np.all(alpha == 255)

    scale = min(1, 1024 / max(input.size))
    if scale < 1:
        input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS)

    output = input if has_alpha else remove_background(input)
    # The steps below index channel 3, so make sure an alpha channel exists
    # even if the background remover returned an image without one.
    if output.mode != 'RGBA':
        output = output.convert('RGBA')

    alpha = np.array(output)[:, :, 3]
    foreground = np.argwhere(alpha > 0.8 * 255)  # rows of (y, x)
    if foreground.size == 0:
        # Without this check np.min below fails with an obscure
        # "zero-size array to reduction operation" error.
        raise ValueError("No foreground object found in the image (alpha mask is empty).")

    x_min, y_min = np.min(foreground[:, 1]), np.min(foreground[:, 0])
    x_max, y_max = np.max(foreground[:, 1]), np.max(foreground[:, 0])
    center_x, center_y = (x_min + x_max) / 2, (y_min + y_max) / 2
    # Square crop whose side is the longer side of the bounding box.
    half = int(max(x_max - x_min, y_max - y_min)) // 2
    crop_box = center_x - half, center_y - half, center_x + half, center_y + half
    output = output.crop(crop_box)  # type: ignore

    rgba = np.array(output).astype(np.float32) / 255
    premultiplied = rgba[:, :, :3] * rgba[:, :, 3:4]
    return Image.fromarray((premultiplied * 255).astype(np.uint8))
420
+
421
+
422
def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
    """
    Preprocess a list of input images for multi-image conditioning.
    Uses parallel processing for faster background removal.

    Args:
        images (List[Tuple[Image.Image, str]]): Gallery items; only the first
            element of each tuple (the image) is used.

    Returns:
        List[Image.Image]: The preprocessed images, in input order. An empty
        (or missing) input yields an empty list.
    """
    # ThreadPoolExecutor rejects max_workers=0, so handle "nothing to do" first.
    if not images:
        return []
    pil_images = [item[0] for item in images]
    with ThreadPoolExecutor(max_workers=min(4, len(pil_images))) as executor:
        return list(executor.map(preprocess_image, pil_images))
431
+
432
+
433
def pack_state(latents):
    """Convert a `(shape_slat, tex_slat, res)` triple into a dict of NumPy arrays.

    The features of both latents and the coordinates of the shape latent are
    moved to the CPU and converted with `.numpy()`; `res` is stored unchanged.
    """
    shape_latent, texture_latent, resolution = latents

    def _to_numpy(tensor):
        return tensor.cpu().numpy()

    packed = {}
    packed['shape_slat_feats'] = _to_numpy(shape_latent.feats)
    packed['tex_slat_feats'] = _to_numpy(texture_latent.feats)
    packed['coords'] = _to_numpy(shape_latent.coords)
    packed['res'] = resolution
    return packed
441
+
442
+
443
def unpack_state(state: dict):
    """Rebuild `(shape_slat, tex_slat, res)` on the GPU from a dict made by `pack_state`.

    The shape latent is a SparseTensor built from the stored features and
    coordinates. The texture latent is derived from it via `.replace(...)`
    with the stored texture features (presumably keeping the same
    coordinates; confirm against SparseTensor.replace).
    """
    _lazy_import()

    def _on_gpu(key):
        return torch.from_numpy(state[key]).cuda()

    shape_latent = SparseTensor(
        feats=_on_gpu('shape_slat_feats'),
        coords=_on_gpu('coords'),
    )
    texture_latent = shape_latent.replace(_on_gpu('tex_slat_feats'))
    resolution = state['res']
    return shape_latent, texture_latent, resolution
451
+
452
+
453
def get_seed(randomize_seed: bool, seed: int) -> int:
    """
    Pick the seed to use for generation.

    Returns a fresh random value in [0, MAX_SEED) when `randomize_seed` is
    set; otherwise returns `seed` unchanged.
    """
    if randomize_seed:
        return np.random.randint(0, MAX_SEED)
    return seed
458
+
459
+
460
def prepare_multi_example() -> List[Image.Image]:
    """
    Prepare multi-image examples for the gallery.

    Files in `assets/example_multi_image` are expected to be named
    `<case>_<i>.png` with i in 1..3. For every case the three views are
    resized to a height of 512 px and concatenated side by side.

    Returns:
        List[Image.Image]: One concatenated image per case, ordered by case
        name so the example list is the same on every start.
    """
    example_dir = "assets/example_multi_image"
    # Only PNG files define cases; sorting makes the order deterministic
    # (iteration order of a plain set of strings varies between runs).
    cases = sorted({
        name.split('_')[0]
        for name in os.listdir(example_dir)
        if name.endswith('.png')
    })
    examples = []
    for case in cases:
        views = []
        for i in range(1, 4):
            # The context manager closes the file once the pixels are copied out.
            with Image.open(f'{example_dir}/{case}_{i}.png') as img:
                width, height = img.size
                resized = img.resize((int(width / height * 512), 512))
                views.append(np.array(resized))
        examples.append(Image.fromarray(np.concatenate(views, axis=1)))
    return examples
475
+
476
+
477
def split_image(image: Image.Image) -> List[Image.Image]:
    """
    Split a concatenated image into multiple views.

    Views are found as runs of pixel columns that contain at least one pixel
    with non-zero alpha, separated by fully transparent columns. Each view is
    then passed through `preprocess_image`.

    Args:
        image (Image.Image): An RGBA image holding several views side by side.

    Returns:
        List[Image.Image]: The preprocessed views, left to right.

    Raises:
        ValueError: If the image has no alpha channel.
    """
    image = np.array(image)
    if image.ndim != 3 or image.shape[2] < 4:
        raise ValueError("split_image expects an RGBA image with an alpha channel.")
    alpha = image[..., 3]
    alpha = np.any(alpha > 0, axis=0)  # per-column "has content" flag
    start_pos = np.where(~alpha[:-1] & alpha[1:])[0].tolist()
    end_pos = np.where(alpha[:-1] & ~alpha[1:])[0].tolist()
    # A view touching the left or right border has no transparent/opaque
    # transition there, so add those boundaries explicitly. Otherwise the
    # start/end lists misalign and views get dropped or mispaired.
    if alpha.size and alpha[0]:
        start_pos.insert(0, 0)
    if alpha.size and alpha[-1]:
        end_pos.append(alpha.size - 1)
    views = [Image.fromarray(image[:, s:e + 1]) for s, e in zip(start_pos, end_pos)]
    return [preprocess_image(view) for view in views]
490
+
491
+
492
def _build_preview_html(images) -> str:
    """Build the previewer HTML from rendered snapshots.

    `images` maps each mode's `render_key` to a sequence of STEPS frames.
    Every (mode, step) frame becomes an <img> with id `view-m{mode}-s{step}`;
    only the DEFAULT_MODE/DEFAULT_STEP one starts out visible. Mode buttons
    and the slider call the JS functions defined in `head`.
    """
    # The stack of len(MODES) * STEPS images - encode in parallel for speed
    def encode_preview_image(task):
        m_idx, s_idx, render_key = task
        frame = Image.fromarray(images[render_key][s_idx])
        return (m_idx, s_idx), image_to_base64(frame)

    encode_tasks = [
        (m_idx, s_idx, mode['render_key'])
        for m_idx, mode in enumerate(MODES)
        for s_idx in range(STEPS)
    ]
    with ThreadPoolExecutor(max_workers=8) as executor:
        encoded_map = dict(executor.map(encode_preview_image, encode_tasks))

    image_tags = []
    for m_idx in range(len(MODES)):
        for s_idx in range(STEPS):
            unique_id = f"view-m{m_idx}-s{s_idx}"
            is_visible = (m_idx == DEFAULT_MODE and s_idx == DEFAULT_STEP)
            vis_class = "visible" if is_visible else ""
            img_base64 = encoded_map[(m_idx, s_idx)]
            image_tags.append(f"""
                <img id="{unique_id}"
                     class="previewer-main-image {vis_class}"
                     src="{img_base64}"
                     loading="eager">
            """)
    images_html = "".join(image_tags)

    # Button Row HTML
    button_tags = []
    for idx, mode in enumerate(MODES):
        active_class = "active" if idx == DEFAULT_MODE else ""
        # Note: onclick calls the JS function defined in Head
        button_tags.append(f"""
            <img src="{mode['icon_base64']}"
                 class="mode-btn {active_class}"
                 onclick="selectMode({idx})"
                 title="{mode['name']}">
        """)
    btns_html = "".join(button_tags)

    # Assemble the full component
    return f"""
    <div class="previewer-container">
        <div class="tips-wrapper">
            <div class="tips-icon">💡Tips</div>
            <div class="tips-text">
                <p>● <b>Render Mode</b> - Click on the circular buttons to switch between different render modes.</p>
                <p>● <b>View Angle</b> - Drag the slider to change the view angle.</p>
            </div>
        </div>

        <!-- Row 1: Viewport containing 48 static <img> tags -->
        <div class="display-row">
            {images_html}
        </div>

        <!-- Row 2 -->
        <div class="mode-row" id="btn-group">
            {btns_html}
        </div>

        <!-- Row 3: Slider -->
        <div class="slider-row">
            <input type="range" id="custom-slider" min="0" max="{STEPS - 1}" value="{DEFAULT_STEP}" step="1" oninput="onSliderChange(this.value)">
        </div>
    </div>
    """


@spaces.GPU(duration=120)
def image_to_3d(
    image: Image.Image,
    seed: int,
    resolution: str,
    ss_guidance_strength: float,
    ss_guidance_rescale: float,
    ss_sampling_steps: int,
    ss_rescale_t: float,
    shape_slat_guidance_strength: float,
    shape_slat_guidance_rescale: float,
    shape_slat_sampling_steps: int,
    shape_slat_rescale_t: float,
    tex_slat_guidance_strength: float,
    tex_slat_guidance_rescale: float,
    tex_slat_sampling_steps: int,
    tex_slat_rescale_t: float,
    req: gr.Request,
    progress=gr.Progress(track_tqdm=True),
    multiimages: List[Tuple[Image.Image, str]] = None,
    is_multiimage: bool = False,
    multiimage_algo: Literal["multidiffusion", "stochastic"] = "stochastic",
) -> Tuple[dict, str]:
    """
    Generate a 3D asset from one or several images and render its preview.

    Args:
        image (Image.Image): The (already preprocessed) single-image prompt.
        seed (int): The random seed passed to the pipeline.
        resolution (str): One of "512", "1024" or "1536"; selects the pipeline type.
        ss_* / shape_slat_* / tex_slat_*: Sampler settings (steps, guidance
            strength, guidance rescale, rescale_t) for the sparse-structure,
            shape and texture stages respectively.
        req (gr.Request): The Gradio request (not used in the body).
        progress: Gradio progress tracker (tqdm tracking enabled).
        multiimages: Gallery items used when `is_multiimage` is True; only the
            first element of each tuple is passed on.
        is_multiimage (bool): Run the multi-image pipeline instead of the single-image one.
        multiimage_algo: Algorithm name forwarded as `mode` to the multi-image pipeline.

    Returns:
        Tuple[dict, str]: The packed latent state (see `pack_state`) and the
        previewer HTML.
    """
    # Initialize pipeline on first call
    _initialize_pipeline()

    # --- Sampling ---
    # Both pipeline entry points take the same keyword arguments, so build them once.
    run_kwargs = dict(
        seed=seed,
        preprocess_image=False,
        sparse_structure_sampler_params={
            "steps": ss_sampling_steps,
            "guidance_strength": ss_guidance_strength,
            "guidance_rescale": ss_guidance_rescale,
            "rescale_t": ss_rescale_t,
        },
        shape_slat_sampler_params={
            "steps": shape_slat_sampling_steps,
            "guidance_strength": shape_slat_guidance_strength,
            "guidance_rescale": shape_slat_guidance_rescale,
            "rescale_t": shape_slat_rescale_t,
        },
        tex_slat_sampler_params={
            "steps": tex_slat_sampling_steps,
            "guidance_strength": tex_slat_guidance_strength,
            "guidance_rescale": tex_slat_guidance_rescale,
            "rescale_t": tex_slat_rescale_t,
        },
        pipeline_type={
            "512": "512",
            "1024": "1024_cascade",
            "1536": "1536_cascade",
        }[resolution],
        return_latent=True,
    )
    if not is_multiimage:
        outputs, latents = pipeline.run(image, **run_kwargs)
    else:
        outputs, latents = pipeline.run_multi_image(
            [item[0] for item in multiimages],
            mode=multiimage_algo,
            **run_kwargs,
        )

    mesh = outputs[0]
    mesh.simplify(16777216)  # nvdiffrast limit
    images = render_utils.render_snapshot(mesh, resolution=1024, r=2, fov=36, nviews=STEPS, envmap=envmap)
    state = pack_state(latents)
    torch.cuda.empty_cache()

    # --- HTML Construction ---
    full_html = _build_preview_html(images)

    return state, full_html
660
+
661
+
662
@spaces.GPU(duration=120)
def extract_glb(
    state: dict,
    decimation_target: int,
    texture_size: int,
    req: gr.Request,
    progress=gr.Progress(track_tqdm=True),
) -> Tuple[str, str]:
    """
    Extract a GLB file from the 3D model.

    Args:
        state (dict): The state of the generated 3D model.
        decimation_target (int): The target face count for decimation.
        texture_size (int): The texture resolution.
        req (gr.Request): The Gradio request; its session hash selects the output directory.
        progress: Gradio progress tracker (tqdm tracking enabled).

    Returns:
        Tuple[str, str]: The path to the extracted GLB file, twice: once for
        the 3D viewer and once for the download button.

    Raises:
        gr.Error: If no model has been generated yet (`state` is None).
    """
    # Extract can be clicked before Generate; fail with a readable message
    # instead of crashing while unpacking a missing state.
    if state is None:
        raise gr.Error("No generated 3D asset found. Please click Generate first.")

    # Initialize pipeline on first call
    _initialize_pipeline()

    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    shape_slat, tex_slat, res = unpack_state(state)
    mesh = pipeline.decode_latent(shape_slat, tex_slat, res)[0]
    mesh.simplify(16777216)  # nvdiffrast limit
    glb = o_voxel.postprocess.to_glb(
        vertices=mesh.vertices,
        faces=mesh.faces,
        attr_volume=mesh.attrs,
        coords=mesh.coords,
        attr_layout=pipeline.pbr_attr_layout,
        grid_size=res,
        aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
        decimation_target=decimation_target,
        texture_size=texture_size,
        remesh=True,
        remesh_band=1,
        remesh_project=0,
        use_tqdm=True,
    )
    # Millisecond-resolution timestamp keeps successive exports in one session apart.
    now = datetime.now()
    timestamp = now.strftime("%Y-%m-%dT%H%M%S") + f".{now.microsecond // 1000:03d}"
    os.makedirs(user_dir, exist_ok=True)
    glb_path = os.path.join(user_dir, f'sample_{timestamp}.glb')
    glb.export(glb_path, extension_webp=True)
    torch.cuda.empty_cache()
    return glb_path, glb_path
710
+
711
+
712
# UI definition. Layout: a left column with the inputs and settings, a middle
# column with a two-step walkthrough (preview, then GLB extraction), and
# example columns. Event wiring follows under "Handlers".
with gr.Blocks(delete_cache=(600, 600)) as demo:
    gr.Markdown("""
    ## Image to 3D Asset with [TRELLIS.2](https://microsoft.github.io/TRELLIS.2)
    * Upload an image (preferably with an alpha-masked foreground object) and click Generate to create a 3D asset.
    * Click Extract GLB to export and download the generated GLB file if you're satisfied with the result. Otherwise, try another time.
    """)

    with gr.Row():
        with gr.Column(scale=1, min_width=360):
            # One tab per input mode; selecting a tab updates `is_multiimage` below.
            with gr.Tabs() as input_tabs:
                with gr.Tab(label="Single Image", id=0) as single_image_input_tab:
                    image_prompt = gr.Image(label="Image Prompt", format="png", image_mode="RGBA", type="pil", height=400)
                with gr.Tab(label="Multiple Images", id=1) as multiimage_input_tab:
                    multiimage_prompt = gr.Gallery(label="Image Prompt", format="png", type="pil", height=400, columns=3)
                    gr.Markdown("""
                        Input different views of the object in separate images.
                        *NOTE: this is an experimental algorithm without training a specialized model. It may not produce the best results for all images, especially those having different poses or inconsistent details.*
                    """)

            resolution = gr.Radio(["512", "1024", "1536"], label="Resolution", value="1024")
            seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
            randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
            decimation_target = gr.Slider(100000, 500000, label="Decimation Target", value=300000, step=10000)
            texture_size = gr.Slider(1024, 4096, label="Texture Size", value=2048, step=1024)

            generate_btn = gr.Button("Generate")

            # Per-stage sampler settings forwarded to image_to_3d.
            with gr.Accordion(label="Advanced Settings", open=False):
                gr.Markdown("Stage 1: Sparse Structure Generation")
                with gr.Row():
                    ss_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
                    ss_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.7, step=0.01)
                    ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
                    ss_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=5.0, step=0.1)
                gr.Markdown("Stage 2: Shape Generation")
                with gr.Row():
                    shape_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
                    shape_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.5, step=0.01)
                    shape_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
                    shape_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1)
                gr.Markdown("Stage 3: Material Generation")
                with gr.Row():
                    tex_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=1.0, step=0.1)
                    tex_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.0, step=0.01)
                    tex_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
                    tex_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1)
                multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Multi-image Algorithm", value="stochastic")

        with gr.Column(scale=10):
            # Step 0 shows the HTML preview, step 1 the extracted GLB; the
            # click handlers below switch between them via `selected`.
            with gr.Walkthrough(selected=0) as walkthrough:
                with gr.Step("Preview", id=0):
                    preview_output = gr.HTML(empty_html, label="3D Asset Preview", show_label=True, container=True)
                    extract_btn = gr.Button("Extract GLB")
                with gr.Step("Extract", id=1):
                    glb_output = gr.Model3D(label="Extracted GLB", height=724, show_label=True, display_mode="solid", clear_color=(0.25, 0.25, 0.25, 1.0))
                    download_btn = gr.DownloadButton(label="Download GLB")
                    gr.Markdown("*We are actively working on improving the speed of GLB extraction. Currently, it may take half a minute or more and face count is limited.*")

        # Single-image examples: clicking one runs preprocess_image on it.
        with gr.Column(scale=1, min_width=172) as single_image_example:
            examples = gr.Examples(
                examples=[
                    f'assets/example_image/{image}'
                    for image in os.listdir("assets/example_image")
                ],
                inputs=[image_prompt],
                fn=preprocess_image,
                outputs=[image_prompt],
                run_on_click=True,
                examples_per_page=18,
            )

        # Multi-image examples: each example is one concatenated image that
        # split_image cuts back into views for the gallery.
        with gr.Column(visible=True) as multiimage_example:
            examples_multi = gr.Examples(
                examples=prepare_multi_example(),
                label="Multi Image Examples",
                inputs=[image_prompt],
                fn=split_image,
                outputs=[multiimage_prompt],
                run_on_click=True,
                examples_per_page=8,
            )

    # Per-session state: which input tab is active, and the packed latents
    # returned by image_to_3d (consumed by extract_glb).
    is_multiimage = gr.State(False)
    output_buf = gr.State()


    # Handlers
    demo.load(start_session)
    demo.unload(end_session)

    # NOTE(review): both tab handlers set both example columns to visible=True,
    # so only `is_multiimage` actually changes - confirm this is intended.
    single_image_input_tab.select(
        lambda: (False, gr.update(visible=True), gr.update(visible=True)),
        outputs=[is_multiimage, single_image_example, multiimage_example]
    )
    multiimage_input_tab.select(
        lambda: (True, gr.update(visible=True), gr.update(visible=True)),
        outputs=[is_multiimage, single_image_example, multiimage_example]
    )

    # Uploaded images are preprocessed and written back into the same component.
    image_prompt.upload(
        preprocess_image,
        inputs=[image_prompt],
        outputs=[image_prompt],
    )
    multiimage_prompt.upload(
        preprocess_images,
        inputs=[multiimage_prompt],
        outputs=[multiimage_prompt],
    )

    # Generate: resolve the seed, jump back to the preview step, then run generation.
    generate_btn.click(
        get_seed,
        inputs=[randomize_seed, seed],
        outputs=[seed],
    ).then(
        lambda: gr.Walkthrough(selected=0), outputs=walkthrough
    ).then(
        image_to_3d,
        inputs=[
            image_prompt, seed, resolution,
            ss_guidance_strength, ss_guidance_rescale, ss_sampling_steps, ss_rescale_t,
            shape_slat_guidance_strength, shape_slat_guidance_rescale, shape_slat_sampling_steps, shape_slat_rescale_t,
            tex_slat_guidance_strength, tex_slat_guidance_rescale, tex_slat_sampling_steps, tex_slat_rescale_t,
            multiimage_prompt, is_multiimage, multiimage_algo
        ],
        outputs=[output_buf, preview_output],
    )

    # Extract: switch to the extract step, then build the GLB from the stored latents.
    extract_btn.click(
        lambda: gr.Walkthrough(selected=1), outputs=walkthrough
    ).then(
        extract_glb,
        inputs=[output_buf, decimation_target, texture_size],
        outputs=[glb_output, download_btn],
    )
847
+
848
+
849
+ # Launch the Gradio app
850
if __name__ == "__main__":
    os.makedirs(TMP_DIR, exist_ok=True)

    # Construct ui components (CPU-only, no GPU needed):
    # embed each mode's button icon as a data URI so the previewer HTML can
    # reference it through mode['icon_base64'].
    for mode in MODES:
        with Image.open(mode['icon']) as icon:
            mode['icon_base64'] = image_to_base64(icon)

    # Remote background-removal client used by remove_background().
    rmbg_client = Client("briaai/BRIA-RMBG-2.0")

    demo.launch(css=css, head=head)