opsiclear-admin committed on
Commit
35f099b
·
verified ·
1 Parent(s): c9ea1f8

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +712 -778
app.py CHANGED
@@ -1,778 +1,712 @@
1
- import gradio as gr
2
- from gradio_client import Client, handle_file
3
- import spaces
4
- from concurrent.futures import ThreadPoolExecutor
5
-
6
- import os
7
- import sys
8
-
9
- # Prioritize local o-voxel submodule (with cumesh.fill_holes() fix) over prebuilt wheel
10
- _script_dir = os.path.dirname(os.path.abspath(__file__))
11
- sys.path.insert(0, os.path.join(_script_dir, 'o-voxel'))
12
-
13
- os.environ["OPENCV_IO_ENABLE_OPENEXR"] = '1'
14
- os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
15
- os.environ["ATTN_BACKEND"] = "flash_attn_3"
16
- os.environ["FLEX_GEMM_AUTOTUNE_CACHE_PATH"] = os.path.join(_script_dir, 'autotune_cache.json')
17
- os.environ["FLEX_GEMM_AUTOTUNER_VERBOSE"] = '1'
18
- from datetime import datetime
19
- import shutil
20
- import cv2
21
- from typing import *
22
- import torch
23
- import numpy as np
24
- from PIL import Image
25
- import base64
26
- import io
27
- import tempfile
28
- from trellis2.modules.sparse import SparseTensor
29
- from trellis2.pipelines import Trellis2ImageTo3DPipeline
30
- from trellis2.renderers import EnvMap
31
- from trellis2.utils import render_utils
32
- import o_voxel
33
-
34
-
35
- MAX_SEED = np.iinfo(np.int32).max
36
- TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
37
- MODES = [
38
- {"name": "Normal", "icon": "assets/app/normal.png", "render_key": "normal"},
39
- {"name": "Clay render", "icon": "assets/app/clay.png", "render_key": "clay"},
40
- {"name": "Base color", "icon": "assets/app/basecolor.png", "render_key": "base_color"},
41
- {"name": "HDRI forest", "icon": "assets/app/hdri_forest.png", "render_key": "shaded_forest"},
42
- {"name": "HDRI sunset", "icon": "assets/app/hdri_sunset.png", "render_key": "shaded_sunset"},
43
- {"name": "HDRI courtyard", "icon": "assets/app/hdri_courtyard.png", "render_key": "shaded_courtyard"},
44
- ]
45
- STEPS = 8
46
- DEFAULT_MODE = 3
47
- DEFAULT_STEP = 3
48
-
49
-
50
- css = """
51
- /* Overwrite Gradio Default Style */
52
- .stepper-wrapper {
53
- padding: 0;
54
- }
55
-
56
- .stepper-container {
57
- padding: 0;
58
- align-items: center;
59
- }
60
-
61
- .step-button {
62
- flex-direction: row;
63
- }
64
-
65
- .step-connector {
66
- transform: none;
67
- }
68
-
69
- .step-number {
70
- width: 16px;
71
- height: 16px;
72
- }
73
-
74
- .step-label {
75
- position: relative;
76
- bottom: 0;
77
- }
78
-
79
- .wrap.center.full {
80
- inset: 0;
81
- height: 100%;
82
- }
83
-
84
- .wrap.center.full.translucent {
85
- background: var(--block-background-fill);
86
- }
87
-
88
- .meta-text-center {
89
- display: block !important;
90
- position: absolute !important;
91
- top: unset !important;
92
- bottom: 0 !important;
93
- right: 0 !important;
94
- transform: unset !important;
95
- }
96
-
97
- /* Previewer */
98
- .previewer-container {
99
- position: relative;
100
- font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif;
101
- width: 100%;
102
- height: 722px;
103
- margin: 0 auto;
104
- padding: 20px;
105
- display: flex;
106
- flex-direction: column;
107
- align-items: center;
108
- justify-content: center;
109
- }
110
-
111
- .previewer-container .tips-icon {
112
- position: absolute;
113
- right: 10px;
114
- top: 10px;
115
- z-index: 10;
116
- border-radius: 10px;
117
- color: #fff;
118
- background-color: var(--color-accent);
119
- padding: 3px 6px;
120
- user-select: none;
121
- }
122
-
123
- .previewer-container .tips-text {
124
- position: absolute;
125
- right: 10px;
126
- top: 50px;
127
- color: #fff;
128
- background-color: var(--color-accent);
129
- border-radius: 10px;
130
- padding: 6px;
131
- text-align: left;
132
- max-width: 300px;
133
- z-index: 10;
134
- transition: all 0.3s;
135
- opacity: 0%;
136
- user-select: none;
137
- }
138
-
139
- .previewer-container .tips-text p {
140
- font-size: 14px;
141
- line-height: 1.2;
142
- }
143
-
144
- .tips-icon:hover + .tips-text {
145
- display: block;
146
- opacity: 100%;
147
- }
148
-
149
- /* Row 1: Display Modes */
150
- .previewer-container .mode-row {
151
- width: 100%;
152
- display: flex;
153
- gap: 8px;
154
- justify-content: center;
155
- margin-bottom: 20px;
156
- flex-wrap: wrap;
157
- }
158
- .previewer-container .mode-btn {
159
- width: 24px;
160
- height: 24px;
161
- border-radius: 50%;
162
- cursor: pointer;
163
- opacity: 0.5;
164
- transition: all 0.2s;
165
- border: 2px solid var(--neutral-600, #555);
166
- object-fit: cover;
167
- }
168
- .previewer-container .mode-btn:hover { opacity: 0.9; transform: scale(1.1); }
169
- .previewer-container .mode-btn.active {
170
- opacity: 1;
171
- border-color: var(--color-accent);
172
- transform: scale(1.1);
173
- }
174
-
175
- /* Row 2: Display Image */
176
- .previewer-container .display-row {
177
- margin-bottom: 20px;
178
- min-height: 400px;
179
- width: 100%;
180
- flex-grow: 1;
181
- display: flex;
182
- justify-content: center;
183
- align-items: center;
184
- }
185
- .previewer-container .previewer-main-image {
186
- max-width: 100%;
187
- max-height: 100%;
188
- flex-grow: 1;
189
- object-fit: contain;
190
- display: none;
191
- }
192
- .previewer-container .previewer-main-image.visible {
193
- display: block;
194
- }
195
-
196
- /* Row 3: Custom HTML Slider */
197
- .previewer-container .slider-row {
198
- width: 100%;
199
- display: flex;
200
- flex-direction: column;
201
- align-items: center;
202
- gap: 10px;
203
- padding: 0 10px;
204
- }
205
-
206
- .previewer-container input[type=range] {
207
- -webkit-appearance: none;
208
- width: 100%;
209
- max-width: 400px;
210
- background: transparent;
211
- }
212
- .previewer-container input[type=range]::-webkit-slider-runnable-track {
213
- width: 100%;
214
- height: 8px;
215
- cursor: pointer;
216
- background: var(--neutral-700, #404040);
217
- border-radius: 5px;
218
- }
219
- .previewer-container input[type=range]::-webkit-slider-thumb {
220
- height: 20px;
221
- width: 20px;
222
- border-radius: 50%;
223
- background: var(--color-accent);
224
- cursor: pointer;
225
- -webkit-appearance: none;
226
- margin-top: -6px;
227
- box-shadow: 0 2px 5px rgba(0,0,0,0.2);
228
- transition: transform 0.1s;
229
- }
230
- .previewer-container input[type=range]::-webkit-slider-thumb:hover {
231
- transform: scale(1.2);
232
- }
233
-
234
- /* Overwrite Previewer Block Style */
235
- .gradio-container .padded:has(.previewer-container) {
236
- padding: 0 !important;
237
- }
238
-
239
- .gradio-container:has(.previewer-container) [data-testid="block-label"] {
240
- position: absolute;
241
- top: 0;
242
- left: 0;
243
- }
244
- """
245
-
246
-
247
- head = """
248
- <script>
249
- function refreshView(mode, step) {
250
- // 1. Find current mode and step
251
- const allImgs = document.querySelectorAll('.previewer-main-image');
252
- for (let i = 0; i < allImgs.length; i++) {
253
- const img = allImgs[i];
254
- if (img.classList.contains('visible')) {
255
- const id = img.id;
256
- const [_, m, s] = id.split('-');
257
- if (mode === -1) mode = parseInt(m.slice(1));
258
- if (step === -1) step = parseInt(s.slice(1));
259
- break;
260
- }
261
- }
262
-
263
- // 2. Hide ALL images
264
- // We select all elements with class 'previewer-main-image'
265
- allImgs.forEach(img => img.classList.remove('visible'));
266
-
267
- // 3. Construct the specific ID for the current state
268
- // Format: view-m{mode}-s{step}
269
- const targetId = 'view-m' + mode + '-s' + step;
270
- const targetImg = document.getElementById(targetId);
271
-
272
- // 4. Show ONLY the target
273
- if (targetImg) {
274
- targetImg.classList.add('visible');
275
- }
276
-
277
- // 5. Update Button Highlights
278
- const allBtns = document.querySelectorAll('.mode-btn');
279
- allBtns.forEach((btn, idx) => {
280
- if (idx === mode) btn.classList.add('active');
281
- else btn.classList.remove('active');
282
- });
283
- }
284
-
285
- // --- Action: Switch Mode ---
286
- function selectMode(mode) {
287
- refreshView(mode, -1);
288
- }
289
-
290
- // --- Action: Slider Change ---
291
- function onSliderChange(val) {
292
- refreshView(-1, parseInt(val));
293
- }
294
- </script>
295
- """
296
-
297
-
298
- empty_html = f"""
299
- <div class="previewer-container">
300
- <svg style=" opacity: .5; height: var(--size-5); color: var(--body-text-color);"
301
- xmlns="http://www.w3.org/2000/svg" width="100%" height="100%" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round" class="feather feather-image"><rect x="3" y="3" width="18" height="18" rx="2" ry="2"></rect><circle cx="8.5" cy="8.5" r="1.5"></circle><polyline points="21 15 16 10 5 21"></polyline></svg>
302
- </div>
303
- """
304
-
305
-
306
- def image_to_base64(image):
307
- buffered = io.BytesIO()
308
- image = image.convert("RGB")
309
- image.save(buffered, format="jpeg", quality=85)
310
- img_str = base64.b64encode(buffered.getvalue()).decode()
311
- return f"data:image/jpeg;base64,{img_str}"
312
-
313
-
314
- def start_session(req: gr.Request):
315
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
316
- os.makedirs(user_dir, exist_ok=True)
317
-
318
-
319
- def end_session(req: gr.Request):
320
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
321
- shutil.rmtree(user_dir)
322
-
323
-
324
- def remove_background(input: Image.Image) -> Image.Image:
325
- with tempfile.NamedTemporaryFile(suffix='.png') as f:
326
- input = input.convert('RGB')
327
- input.save(f.name)
328
- output = rmbg_client.predict(handle_file(f.name), api_name="/image")[0][0]
329
- output = Image.open(output)
330
- return output
331
-
332
-
333
- def preprocess_image(input: Image.Image) -> Image.Image:
334
- """
335
- Preprocess the input image.
336
- """
337
- # if has alpha channel, use it directly; otherwise, remove background
338
- has_alpha = False
339
- if input.mode == 'RGBA':
340
- alpha = np.array(input)[:, :, 3]
341
- if not np.all(alpha == 255):
342
- has_alpha = True
343
- max_size = max(input.size)
344
- scale = min(1, 1024 / max_size)
345
- if scale < 1:
346
- input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS)
347
- if has_alpha:
348
- output = input
349
- else:
350
- output = remove_background(input)
351
- output_np = np.array(output)
352
- alpha = output_np[:, :, 3]
353
- bbox = np.argwhere(alpha > 0.8 * 255)
354
- if bbox.size == 0:
355
- # No visible pixels, center the image in a square
356
- size = max(output.size)
357
- square = Image.new('RGB', (size, size), (0, 0, 0))
358
- output_rgb = output.convert('RGB') if output.mode == 'RGBA' else output
359
- square.paste(output_rgb, ((size - output.width) // 2, (size - output.height) // 2))
360
- return square
361
- bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0])
362
- center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
363
- size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
364
- size = int(size * 1)
365
- bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2
366
- output = output.crop(bbox) # type: ignore
367
- output = np.array(output).astype(np.float32) / 255
368
- output = output[:, :, :3] * output[:, :, 3:4]
369
- output = Image.fromarray((output * 255).astype(np.uint8))
370
- return output
371
-
372
-
373
- def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
374
- """
375
- Preprocess a list of input images for multi-image conditioning.
376
- Uses parallel processing for faster background removal.
377
- """
378
- images = [image[0] for image in images]
379
- with ThreadPoolExecutor(max_workers=min(4, len(images))) as executor:
380
- processed_images = list(executor.map(preprocess_image, images))
381
- return processed_images
382
-
383
-
384
- def pack_state(latents: Tuple[SparseTensor, SparseTensor, int]) -> dict:
385
- shape_slat, tex_slat, res = latents
386
- return {
387
- 'shape_slat_feats': shape_slat.feats.cpu().numpy(),
388
- 'tex_slat_feats': tex_slat.feats.cpu().numpy(),
389
- 'coords': shape_slat.coords.cpu().numpy(),
390
- 'res': res,
391
- }
392
-
393
-
394
- def unpack_state(state: dict) -> Tuple[SparseTensor, SparseTensor, int]:
395
- shape_slat = SparseTensor(
396
- feats=torch.from_numpy(state['shape_slat_feats']).cuda(),
397
- coords=torch.from_numpy(state['coords']).cuda(),
398
- )
399
- tex_slat = shape_slat.replace(torch.from_numpy(state['tex_slat_feats']).cuda())
400
- return shape_slat, tex_slat, state['res']
401
-
402
-
403
- def get_seed(randomize_seed: bool, seed: int) -> int:
404
- """
405
- Get the random seed.
406
- """
407
- return np.random.randint(0, MAX_SEED) if randomize_seed else seed
408
-
409
-
410
- def prepare_multi_example() -> List[str]:
411
- """
412
- Prepare multi-image examples. Returns list of image paths.
413
- Shows only the first view as representative thumbnail.
414
- """
415
- multi_case = sorted(set([i.split('_')[0] for i in os.listdir("assets/example_multi_image")]))
416
- examples = []
417
- for case in multi_case:
418
- first_img = f'assets/example_multi_image/{case}_1.png'
419
- if os.path.exists(first_img):
420
- examples.append(first_img)
421
- return examples
422
-
423
-
424
- def load_multi_example(image) -> List[Image.Image]:
425
- """Load all views for a multi-image case by matching the input image."""
426
- import hashlib
427
-
428
- # Convert numpy array to PIL Image if needed
429
- if isinstance(image, np.ndarray):
430
- image = Image.fromarray(image)
431
-
432
- # Get hash of input image for matching
433
- input_hash = hashlib.md5(np.array(image.convert('RGBA')).tobytes()).hexdigest()
434
-
435
- # Find matching case by comparing with first images
436
- multi_case = sorted(set([i.split('_')[0] for i in os.listdir("assets/example_multi_image")]))
437
- for case_name in multi_case:
438
- first_img_path = f'assets/example_multi_image/{case_name}_1.png'
439
- if os.path.exists(first_img_path):
440
- first_img = Image.open(first_img_path).convert('RGBA')
441
- first_hash = hashlib.md5(np.array(first_img).tobytes()).hexdigest()
442
- if first_hash == input_hash:
443
- # Found match, load all views
444
- images = []
445
- for i in range(1, 7):
446
- img_path = f'assets/example_multi_image/{case_name}_{i}.png'
447
- if os.path.exists(img_path):
448
- img = Image.open(img_path)
449
- images.append(preprocess_image(img))
450
- return images
451
-
452
- # No match found, return the single image preprocessed
453
- return [preprocess_image(image)]
454
-
455
-
456
- def split_image(image: Image.Image) -> List[Image.Image]:
457
- """
458
- Split a concatenated image into multiple views.
459
- """
460
- image = np.array(image)
461
- alpha = image[..., 3]
462
- alpha = np.any(alpha > 0, axis=0)
463
- start_pos = np.where(~alpha[:-1] & alpha[1:])[0].tolist()
464
- end_pos = np.where(alpha[:-1] & ~alpha[1:])[0].tolist()
465
- images = []
466
- for s, e in zip(start_pos, end_pos):
467
- images.append(Image.fromarray(image[:, s:e+1]))
468
- return [preprocess_image(image) for image in images]
469
-
470
-
471
- @spaces.GPU(duration=90)
472
- def image_to_3d(
473
- seed: int,
474
- resolution: str,
475
- ss_guidance_strength: float,
476
- ss_guidance_rescale: float,
477
- ss_sampling_steps: int,
478
- ss_rescale_t: float,
479
- shape_slat_guidance_strength: float,
480
- shape_slat_guidance_rescale: float,
481
- shape_slat_sampling_steps: int,
482
- shape_slat_rescale_t: float,
483
- tex_slat_guidance_strength: float,
484
- tex_slat_guidance_rescale: float,
485
- tex_slat_sampling_steps: int,
486
- tex_slat_rescale_t: float,
487
- multiimages: List[Tuple[Image.Image, str]],
488
- multiimage_algo: Literal["stochastic", "multidiffusion"],
489
- tex_multiimage_algo: Literal["stochastic", "multidiffusion"],
490
- req: gr.Request,
491
- progress=gr.Progress(track_tqdm=True),
492
- ) -> str:
493
- # --- Sampling ---
494
- outputs, latents = pipeline.run_multi_image(
495
- [image[0] for image in multiimages],
496
- seed=seed,
497
- preprocess_image=False,
498
- sparse_structure_sampler_params={
499
- "steps": ss_sampling_steps,
500
- "guidance_strength": ss_guidance_strength,
501
- "guidance_rescale": ss_guidance_rescale,
502
- "rescale_t": ss_rescale_t,
503
- },
504
- shape_slat_sampler_params={
505
- "steps": shape_slat_sampling_steps,
506
- "guidance_strength": shape_slat_guidance_strength,
507
- "guidance_rescale": shape_slat_guidance_rescale,
508
- "rescale_t": shape_slat_rescale_t,
509
- },
510
- tex_slat_sampler_params={
511
- "steps": tex_slat_sampling_steps,
512
- "guidance_strength": tex_slat_guidance_strength,
513
- "guidance_rescale": tex_slat_guidance_rescale,
514
- "rescale_t": tex_slat_rescale_t,
515
- },
516
- pipeline_type={
517
- "512": "512",
518
- "1024": "1024_cascade",
519
- "1536": "1536_cascade",
520
- }[resolution],
521
- return_latent=True,
522
- mode=multiimage_algo,
523
- tex_mode=tex_multiimage_algo,
524
- )
525
- mesh = outputs[0]
526
- mesh.simplify(16777216) # nvdiffrast limit
527
- images = render_utils.render_snapshot(mesh, resolution=1024, r=2, fov=36, nviews=STEPS, envmap=envmap)
528
- state = pack_state(latents)
529
- del outputs, mesh, latents # Free memory
530
- torch.cuda.empty_cache()
531
-
532
- # --- HTML Construction ---
533
- # The Stack of 48 Images - encode in parallel for speed
534
- def encode_preview_image(args):
535
- m_idx, s_idx, render_key = args
536
- img_base64 = image_to_base64(Image.fromarray(images[render_key][s_idx]))
537
- return (m_idx, s_idx, img_base64)
538
-
539
- encode_tasks = [
540
- (m_idx, s_idx, mode['render_key'])
541
- for m_idx, mode in enumerate(MODES)
542
- for s_idx in range(STEPS)
543
- ]
544
-
545
- with ThreadPoolExecutor(max_workers=8) as executor:
546
- encoded_results = list(executor.map(encode_preview_image, encode_tasks))
547
-
548
- # Build HTML from encoded results
549
- encoded_map = {(m, s): b64 for m, s, b64 in encoded_results}
550
- images_html = ""
551
- for m_idx, mode in enumerate(MODES):
552
- for s_idx in range(STEPS):
553
- unique_id = f"view-m{m_idx}-s{s_idx}"
554
- is_visible = (m_idx == DEFAULT_MODE and s_idx == DEFAULT_STEP)
555
- vis_class = "visible" if is_visible else ""
556
- img_base64 = encoded_map[(m_idx, s_idx)]
557
-
558
- images_html += f"""
559
- <img id="{unique_id}"
560
- class="previewer-main-image {vis_class}"
561
- src="{img_base64}"
562
- loading="eager">
563
- """
564
-
565
- # Button Row HTML
566
- btns_html = ""
567
- for idx, mode in enumerate(MODES):
568
- active_class = "active" if idx == DEFAULT_MODE else ""
569
- # Note: onclick calls the JS function defined in Head
570
- btns_html += f"""
571
- <img src="{mode['icon_base64']}"
572
- class="mode-btn {active_class}"
573
- onclick="selectMode({idx})"
574
- title="{mode['name']}">
575
- """
576
-
577
- # Assemble the full component
578
- full_html = f"""
579
- <div class="previewer-container">
580
- <div class="tips-wrapper">
581
- <div class="tips-icon">💡Tips</div>
582
- <div class="tips-text">
583
- <p>● <b>Render Mode</b> - Click on the circular buttons to switch between different render modes.</p>
584
- <p>● <b>View Angle</b> - Drag the slider to change the view angle.</p>
585
- </div>
586
- </div>
587
-
588
- <!-- Row 1: Viewport containing 48 static <img> tags -->
589
- <div class="display-row">
590
- {images_html}
591
- </div>
592
-
593
- <!-- Row 2 -->
594
- <div class="mode-row" id="btn-group">
595
- {btns_html}
596
- </div>
597
-
598
- <!-- Row 3: Slider -->
599
- <div class="slider-row">
600
- <input type="range" id="custom-slider" min="0" max="{STEPS - 1}" value="{DEFAULT_STEP}" step="1" oninput="onSliderChange(this.value)">
601
- </div>
602
- </div>
603
- """
604
-
605
- return state, full_html
606
-
607
-
608
- @spaces.GPU(duration=60)
609
- def extract_glb(
610
- state: dict,
611
- decimation_target: int,
612
- texture_size: int,
613
- req: gr.Request,
614
- progress=gr.Progress(track_tqdm=True),
615
- ) -> str:
616
- """
617
- Extract a GLB file from the 3D model.
618
-
619
- Args:
620
- state (dict): The state of the generated 3D model.
621
- decimation_target (int): The target face count for decimation.
622
- texture_size (int): The texture resolution.
623
-
624
- Returns:
625
- str: The path to the extracted GLB file.
626
- """
627
- user_dir = os.path.join(TMP_DIR, str(req.session_hash))
628
- shape_slat, tex_slat, res = unpack_state(state)
629
- mesh = pipeline.decode_latent(shape_slat, tex_slat, res)[0]
630
- mesh.simplify(16777216) # nvdiffrast limit
631
- glb = o_voxel.postprocess.to_glb(
632
- vertices=mesh.vertices,
633
- faces=mesh.faces,
634
- attr_volume=mesh.attrs,
635
- coords=mesh.coords,
636
- attr_layout=pipeline.pbr_attr_layout,
637
- grid_size=res,
638
- aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
639
- decimation_target=decimation_target,
640
- texture_size=texture_size,
641
- remesh=True,
642
- remesh_band=1,
643
- remesh_project=0,
644
- use_tqdm=True,
645
- )
646
- now = datetime.now()
647
- timestamp = now.strftime("%Y-%m-%dT%H%M%S") + f".{now.microsecond // 1000:03d}"
648
- os.makedirs(user_dir, exist_ok=True)
649
- glb_path = os.path.join(user_dir, f'sample_{timestamp}.glb')
650
- glb.export(glb_path, extension_webp=True)
651
- torch.cuda.empty_cache()
652
- return glb_path
653
-
654
-
655
- with gr.Blocks(delete_cache=(600, 600), theme=gr.themes.Soft(primary_hue="orange", neutral_hue="slate")) as demo:
656
- gr.Markdown("""
657
- ## Image to 3D Asset with [TRELLIS.2](https://microsoft.github.io/TRELLIS.2)
658
- * Upload an image and click Generate to create a 3D asset. If the image has alpha channel, it will be used as the mask. Otherwise, background is automatically removed.
659
- * Click Extract GLB to export the GLB file if you're satisfied with the preview.
660
- """)
661
-
662
- with gr.Row():
663
- with gr.Column(scale=1, min_width=360):
664
- multiimage_prompt = gr.Gallery(label="Multi-View Images", format="png", type="pil", height=400, columns=3)
665
-
666
- resolution = gr.Radio(["512", "1024", "1536"], label="Resolution", value="1024")
667
- seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
668
- randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
669
- decimation_target = gr.Slider(100000, 500000, label="Decimation Target", value=300000, step=10000)
670
- texture_size = gr.Slider(1024, 4096, label="Texture Size", value=2048, step=1024)
671
-
672
- with gr.Row():
673
- generate_btn = gr.Button("Generate", variant="primary")
674
- extract_btn = gr.Button("Extract GLB")
675
-
676
- with gr.Accordion(label="Advanced Settings", open=False):
677
- gr.Markdown("Stage 1: Sparse Structure Generation")
678
- with gr.Row():
679
- ss_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
680
- ss_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.7, step=0.01)
681
- ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
682
- ss_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=5.0, step=0.1)
683
- gr.Markdown("Stage 2: Shape Generation")
684
- with gr.Row():
685
- shape_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
686
- shape_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.5, step=0.01)
687
- shape_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
688
- shape_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1)
689
- gr.Markdown("Stage 3: Material Generation")
690
- with gr.Row():
691
- tex_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=1.0, step=0.1)
692
- tex_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.0, step=0.01)
693
- tex_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
694
- tex_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1)
695
- multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Structure Algorithm", value="stochastic")
696
- tex_multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Texture Algorithm", value="multidiffusion")
697
-
698
- with gr.Column(scale=10):
699
- preview_output = gr.HTML(empty_html, label="3D Asset Preview", show_label=True, container=True)
700
- glb_output = gr.Model3D(label="Extracted GLB", height=400, show_label=True, display_mode="solid", clear_color=(0.25, 0.25, 0.25, 1.0), visible=False)
701
-
702
- example_image = gr.Image(visible=False) # Hidden component for examples
703
- examples_multi = gr.Examples(
704
- examples=prepare_multi_example(),
705
- inputs=[example_image],
706
- fn=load_multi_example,
707
- outputs=[multiimage_prompt],
708
- run_on_click=True,
709
- examples_per_page=24,
710
- )
711
-
712
- output_buf = gr.State()
713
-
714
-
715
- # Handlers
716
- demo.load(start_session)
717
- demo.unload(end_session)
718
- multiimage_prompt.upload(
719
- preprocess_images,
720
- inputs=[multiimage_prompt],
721
- outputs=[multiimage_prompt],
722
- )
723
-
724
- generate_btn.click(
725
- get_seed,
726
- inputs=[randomize_seed, seed],
727
- outputs=[seed],
728
- ).then(
729
- image_to_3d,
730
- inputs=[
731
- seed, resolution,
732
- ss_guidance_strength, ss_guidance_rescale, ss_sampling_steps, ss_rescale_t,
733
- shape_slat_guidance_strength, shape_slat_guidance_rescale, shape_slat_sampling_steps, shape_slat_rescale_t,
734
- tex_slat_guidance_strength, tex_slat_guidance_rescale, tex_slat_sampling_steps, tex_slat_rescale_t,
735
- multiimage_prompt, multiimage_algo, tex_multiimage_algo
736
- ],
737
- outputs=[output_buf, preview_output],
738
- )
739
-
740
- extract_btn.click(
741
- extract_glb,
742
- inputs=[output_buf, decimation_target, texture_size],
743
- outputs=[glb_output],
744
- )
745
-
746
-
747
- # Launch the Gradio app
748
- if __name__ == "__main__":
749
- os.makedirs(TMP_DIR, exist_ok=True)
750
-
751
- # Construct ui components
752
- btn_img_base64_strs = {}
753
- for i in range(len(MODES)):
754
- icon = Image.open(MODES[i]['icon'])
755
- MODES[i]['icon_base64'] = image_to_base64(icon)
756
-
757
- rmbg_client = Client("briaai/BRIA-RMBG-2.0")
758
- pipeline = Trellis2ImageTo3DPipeline.from_pretrained('microsoft/TRELLIS.2-4B')
759
- pipeline.rembg_model = None
760
- pipeline.low_vram = True # Enable low VRAM mode for better memory efficiency
761
- pipeline.cuda()
762
-
763
- envmap = {
764
- 'forest': EnvMap(torch.tensor(
765
- cv2.cvtColor(cv2.imread('assets/hdri/forest.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB),
766
- dtype=torch.float32, device='cuda'
767
- )),
768
- 'sunset': EnvMap(torch.tensor(
769
- cv2.cvtColor(cv2.imread('assets/hdri/sunset.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB),
770
- dtype=torch.float32, device='cuda'
771
- )),
772
- 'courtyard': EnvMap(torch.tensor(
773
- cv2.cvtColor(cv2.imread('assets/hdri/courtyard.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB),
774
- dtype=torch.float32, device='cuda'
775
- )),
776
- }
777
-
778
- demo.queue(max_size=10, default_concurrency_limit=1).launch(css=css, head=head)
 
1
+ import gradio as gr
2
+ from gradio_client import Client, handle_file
3
+ import spaces
4
+ from concurrent.futures import ThreadPoolExecutor
5
+
6
+ import os
7
+ os.environ["OPENCV_IO_ENABLE_OPENEXR"] = '1'
8
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
9
+ os.environ["ATTN_BACKEND"] = "flash_attn_3"
10
+ os.environ["FLEX_GEMM_AUTOTUNE_CACHE_PATH"] = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'autotune_cache.json')
11
+ os.environ["FLEX_GEMM_AUTOTUNER_VERBOSE"] = '1'
12
+ from datetime import datetime
13
+ import shutil
14
+ import cv2
15
+ from typing import *
16
+ import numpy as np
17
+ from PIL import Image
18
+ import base64
19
+ import io
20
+ import tempfile
21
+
22
+ # Lazy imports - will be loaded when GPU is available
23
+ torch = None
24
+ SparseTensor = None
25
+ Trellis2ImageTo3DPipeline = None
26
+ EnvMap = None
27
+ render_utils = None
28
+ o_voxel = None
29
+
30
+ # Global state - initialized on first GPU call
31
+ pipeline = None
32
+ envmap = None
33
+ _initialized = False
34
+
35
+
36
+ def _lazy_import():
37
+ """Import GPU-dependent modules. Must be called from within a @spaces.GPU function."""
38
+ global torch, SparseTensor, Trellis2ImageTo3DPipeline, EnvMap, render_utils, o_voxel
39
+ if torch is None:
40
+ import torch as _torch
41
+ torch = _torch
42
+ if SparseTensor is None:
43
+ from trellis2.modules.sparse import SparseTensor as _SparseTensor
44
+ SparseTensor = _SparseTensor
45
+ if Trellis2ImageTo3DPipeline is None:
46
+ from trellis2.pipelines import Trellis2ImageTo3DPipeline as _Trellis2ImageTo3DPipeline
47
+ Trellis2ImageTo3DPipeline = _Trellis2ImageTo3DPipeline
48
+ if EnvMap is None:
49
+ from trellis2.renderers import EnvMap as _EnvMap
50
+ EnvMap = _EnvMap
51
+ if render_utils is None:
52
+ from trellis2.utils import render_utils as _render_utils
53
+ render_utils = _render_utils
54
+ if o_voxel is None:
55
+ import o_voxel as _o_voxel
56
+ o_voxel = _o_voxel
57
+ # Patch postprocess module with local fix for cumesh.fill_holes() bug
58
+ import importlib.util
59
+ _local_pp = os.path.join(os.path.dirname(os.path.abspath(__file__)),
60
+ 'o-voxel', 'o_voxel', 'postprocess.py')
61
+ if os.path.exists(_local_pp):
62
+ _spec = importlib.util.spec_from_file_location('o_voxel.postprocess', _local_pp)
63
+ _mod = importlib.util.module_from_spec(_spec)
64
+ _spec.loader.exec_module(_mod)
65
+ o_voxel.postprocess = _mod
66
+ import sys
67
+ sys.modules['o_voxel.postprocess'] = _mod
68
+
69
+
70
+ def _initialize_pipeline():
71
+ """Initialize the pipeline and environment maps. Must be called from within a @spaces.GPU function."""
72
+ global pipeline, envmap, _initialized
73
+ if _initialized:
74
+ return
75
+
76
+ _lazy_import()
77
+
78
+ pipeline = Trellis2ImageTo3DPipeline.from_pretrained('microsoft/TRELLIS.2-4B')
79
+ pipeline.rembg_model = None
80
+ pipeline.low_vram = False
81
+ pipeline.cuda()
82
+
83
+ envmap = {
84
+ 'forest': EnvMap(torch.tensor(
85
+ cv2.cvtColor(cv2.imread('assets/hdri/forest.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB),
86
+ dtype=torch.float32, device='cuda'
87
+ )),
88
+ 'sunset': EnvMap(torch.tensor(
89
+ cv2.cvtColor(cv2.imread('assets/hdri/sunset.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB),
90
+ dtype=torch.float32, device='cuda'
91
+ )),
92
+ 'courtyard': EnvMap(torch.tensor(
93
+ cv2.cvtColor(cv2.imread('assets/hdri/courtyard.exr', cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB),
94
+ dtype=torch.float32, device='cuda'
95
+ )),
96
+ }
97
+
98
+ _initialized = True
99
+
100
+
101
+ MAX_SEED = np.iinfo(np.int32).max
102
+ TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
103
+ MODES = [
104
+ {"name": "Normal", "icon": "assets/app/normal.png", "render_key": "normal"},
105
+ {"name": "Clay render", "icon": "assets/app/clay.png", "render_key": "clay"},
106
+ {"name": "Base color", "icon": "assets/app/basecolor.png", "render_key": "base_color"},
107
+ {"name": "HDRI forest", "icon": "assets/app/hdri_forest.png", "render_key": "shaded_forest"},
108
+ {"name": "HDRI sunset", "icon": "assets/app/hdri_sunset.png", "render_key": "shaded_sunset"},
109
+ {"name": "HDRI courtyard", "icon": "assets/app/hdri_courtyard.png", "render_key": "shaded_courtyard"},
110
+ ]
111
+ STEPS = 8
112
+ DEFAULT_MODE = 3
113
+ DEFAULT_STEP = 3
114
+
115
+
116
# Dark-theme CSS injected into the Gradio app via demo.launch(css=...).
# Covers generic Gradio widgets plus the custom `.previewer-container` HTML
# emitted by image_to_3d (mode buttons, view slider, tips tooltip).
css = """
/* ColmapView Dark Theme */
:root {
    --bg-void: #0a0a0a;
    --bg-primary: #0f0f0f;
    --bg-secondary: #161616;
    --bg-tertiary: #1e1e1e;
    --bg-input: #1a1a1a;
    --bg-elevated: #242424;
    --bg-hover: #262626;
    --text-primary: #e8e8e8;
    --text-secondary: #8a8a8a;
    --text-muted: #5a5a5a;
    --border-subtle: #222222;
    --border-color: #2a2a2a;
    --border-light: #3a3a3a;
    --accent: #b8b8b8;
    --accent-hover: #d0d0d0;
    --accent-dim: rgba(184, 184, 184, 0.12);
    --shadow-sm: 0 2px 4px rgba(0, 0, 0, 0.4), 0 1px 2px rgba(0, 0, 0, 0.3);
    --shadow-md: 0 4px 12px rgba(0, 0, 0, 0.5), 0 2px 4px rgba(0, 0, 0, 0.3);
    --radius: 0.25rem;
    --radius-md: 0.375rem;
    --radius-lg: 0.5rem;
    --font-sans: 'Roboto', -apple-system, BlinkMacSystemFont, sans-serif;
    --font-mono: 'JetBrains Mono', 'Fira Code', 'Consolas', monospace;
}

/* Global Overrides */
.gradio-container { background: var(--bg-primary) !important; color: var(--text-primary) !important; font-family: var(--font-sans) !important; }
.dark { background: var(--bg-primary) !important; }
body { background: var(--bg-void) !important; color-scheme: dark; }

/* Panels & Blocks */
.block { background: var(--bg-secondary) !important; border: 1px solid var(--border-color) !important; border-radius: var(--radius-lg) !important; }
.panel { background: var(--bg-secondary) !important; }

/* Inputs */
input, textarea, select { background: var(--bg-input) !important; border: 1px solid var(--border-color) !important; color: var(--text-primary) !important; border-radius: var(--radius) !important; }
input:focus, textarea:focus, select:focus { border-color: var(--accent) !important; outline: none !important; }

/* Buttons */
.primary { background: var(--accent) !important; color: var(--bg-void) !important; border: none !important; }
.primary:hover { background: var(--accent-hover) !important; }
.secondary { background: var(--bg-hover) !important; color: var(--text-primary) !important; border: 1px solid var(--border-color) !important; }
.secondary:hover { background: var(--bg-tertiary) !important; }
button { transition: all 0.15s ease-out !important; }

/* Labels & Text */
label, .label-wrap { color: var(--text-secondary) !important; font-size: 0.875rem !important; }
.prose { color: var(--text-primary) !important; }
.prose h2 { color: var(--text-primary) !important; border: none !important; }

/* Sliders */
input[type="range"] { accent-color: var(--accent) !important; }

/* Gallery */
.gallery { background: var(--bg-tertiary) !important; border: 1px solid var(--border-color) !important; }

/* Accordion */
.accordion { background: var(--bg-secondary) !important; border: 1px solid var(--border-color) !important; }

/* Scrollbar */
::-webkit-scrollbar { width: 12px; }
::-webkit-scrollbar-track { background: var(--bg-secondary); }
::-webkit-scrollbar-thumb { background: var(--border-light); border-radius: 9999px; }
::-webkit-scrollbar-thumb:hover { background: var(--text-muted); }

/* Gradio Overrides */
.stepper-wrapper { padding: 0; }
.stepper-container { padding: 0; align-items: center; }
.step-button { flex-direction: row; }
.step-connector { transform: none; }
.step-number { width: 16px; height: 16px; background: var(--bg-tertiary) !important; border: 1px solid var(--border-color) !important; }
.step-label { position: relative; bottom: 0; color: var(--text-secondary) !important; }
.wrap.center.full { inset: 0; height: 100%; }
.wrap.center.full.translucent { background: var(--bg-secondary); }
.meta-text-center { display: block !important; position: absolute !important; top: unset !important; bottom: 0 !important; right: 0 !important; transform: unset !important; }

/* Previewer */
.previewer-container {
    position: relative;
    font-family: var(--font-sans);
    width: 100%;
    height: 722px;
    margin: 0 auto;
    padding: 20px;
    display: flex;
    flex-direction: column;
    align-items: center;
    justify-content: center;
    background: var(--bg-void);
    border-radius: var(--radius-lg);
}
.previewer-container .tips-icon {
    position: absolute; right: 10px; top: 10px; z-index: 10;
    border-radius: var(--radius-md); color: var(--bg-void);
    background-color: var(--accent); padding: 4px 8px;
    user-select: none; font-size: 0.75rem; font-weight: 500;
}
.previewer-container .tips-text {
    position: absolute; right: 10px; top: 45px;
    color: var(--text-primary); background-color: var(--bg-elevated);
    border: 1px solid var(--border-light); border-radius: var(--radius-md);
    padding: 8px 12px; text-align: left; max-width: 280px; z-index: 10;
    transition: opacity 0.15s ease-out; opacity: 0; user-select: none;
    box-shadow: var(--shadow-md);
}
.previewer-container .tips-text p { font-size: 0.75rem; line-height: 1.4; color: var(--text-secondary); margin: 4px 0; }
.tips-icon:hover + .tips-text { opacity: 1; }
.previewer-container .mode-row { width: 100%; display: flex; gap: 8px; justify-content: center; margin-bottom: 16px; flex-wrap: wrap; }
.previewer-container .mode-btn {
    width: 28px; height: 28px; border-radius: 50%; cursor: pointer;
    opacity: 0.5; transition: all 0.15s ease-out;
    border: 2px solid var(--border-light); object-fit: cover;
}
.previewer-container .mode-btn:hover { opacity: 0.85; transform: scale(1.1); border-color: var(--text-muted); }
.previewer-container .mode-btn.active { opacity: 1; border-color: var(--accent); transform: scale(1.1); }
.previewer-container .display-row { margin-bottom: 16px; min-height: 400px; width: 100%; flex-grow: 1; display: flex; justify-content: center; align-items: center; }
.previewer-container .previewer-main-image { max-width: 100%; max-height: 100%; flex-grow: 1; object-fit: contain; display: none; }
.previewer-container .previewer-main-image.visible { display: block; }
.previewer-container .slider-row { width: 100%; display: flex; flex-direction: column; align-items: center; gap: 8px; padding: 0 16px; }
.previewer-container input[type=range] { -webkit-appearance: none; width: 100%; max-width: 400px; background: transparent; height: 20px; }
.previewer-container input[type=range]::-webkit-slider-runnable-track { width: 100%; height: 4px; cursor: pointer; background: var(--bg-tertiary); border-radius: 9999px; }
.previewer-container input[type=range]::-webkit-slider-thumb {
    height: 12px; width: 12px; border-radius: 50%; background: var(--accent);
    cursor: pointer; -webkit-appearance: none; margin-top: -4px;
    box-shadow: var(--shadow-sm); transition: transform 0.1s ease-out;
}
.previewer-container input[type=range]::-webkit-slider-thumb:hover { transform: scale(1.2); background: var(--accent-hover); }
.gradio-container .padded:has(.previewer-container) { padding: 0 !important; }
.gradio-container:has(.previewer-container) [data-testid="block-label"] { position: absolute; top: 0; left: 0; }
"""
249
+
250
+
251
# Inline JS injected via demo.launch(head=...). Drives the previewer HTML that
# image_to_3d emits: every (mode, step) view is a pre-rendered <img> with id
# "view-m{mode}-s{step}", and these handlers toggle the `visible` class so no
# server round-trip is needed to change mode or view angle.
head = """
<script>
function refreshView(mode, step) {
    const allImgs = document.querySelectorAll('.previewer-main-image');
    // Recover the currently shown (mode, step) so a -1 argument means "keep".
    for (let i = 0; i < allImgs.length; i++) {
        const img = allImgs[i];
        if (img.classList.contains('visible')) {
            const id = img.id;
            const [_, m, s] = id.split('-');
            if (mode === -1) mode = parseInt(m.slice(1));
            if (step === -1) step = parseInt(s.slice(1));
            break;
        }
    }
    allImgs.forEach(img => img.classList.remove('visible'));
    const targetId = 'view-m' + mode + '-s' + step;
    const targetImg = document.getElementById(targetId);
    if (targetImg) { targetImg.classList.add('visible'); }
    // Highlight the active mode button.
    const allBtns = document.querySelectorAll('.mode-btn');
    allBtns.forEach((btn, idx) => {
        if (idx === mode) btn.classList.add('active');
        else btn.classList.remove('active');
    });
}
function selectMode(mode) { refreshView(mode, -1); }
function onSliderChange(val) { refreshView(-1, parseInt(val)); }
</script>
"""
279
+
280
+
281
# Placeholder shown in the preview pane before any generation has run
# (an image-frame SVG icon centered in the previewer container).
empty_html = """
<div class="previewer-container">
    <svg style="opacity: .5; height: var(--size-5); color: var(--body-text-color);"
    xmlns="http://www.w3.org/2000/svg" width="100%" height="100%" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"><rect x="3" y="3" width="18" height="18" rx="2" ry="2"></rect><circle cx="8.5" cy="8.5" r="1.5"></circle><polyline points="21 15 16 10 5 21"></polyline></svg>
</div>
"""
287
+
288
+
289
def image_to_base64(image):
    """Encode a PIL image as a `data:image/jpeg;base64,...` URI (quality 85)."""
    buf = io.BytesIO()
    image.convert("RGB").save(buf, format="jpeg", quality=85)
    payload = base64.b64encode(buf.getvalue()).decode()
    return "data:image/jpeg;base64," + payload
295
+
296
+
297
def start_session(req: gr.Request):
    """Create this session's scratch directory under TMP_DIR (idempotent)."""
    session_dir = os.path.join(TMP_DIR, str(req.session_hash))
    os.makedirs(session_dir, exist_ok=True)
300
+
301
+
302
def end_session(req: gr.Request):
    """Remove this session's scratch directory, if it was ever created."""
    session_dir = os.path.join(TMP_DIR, str(req.session_hash))
    if os.path.exists(session_dir):
        shutil.rmtree(session_dir)
306
+
307
+
308
def remove_background(input: Image.Image) -> Image.Image:
    """Strip the background from an image via the remote BRIA-RMBG space.

    The image is flattened to RGB, written to a temp PNG, sent to the
    `rmbg_client` Gradio space, and the returned matted file is reopened
    as a PIL image (RGBA).
    """
    with tempfile.NamedTemporaryFile(suffix='.png') as tmp:
        input.convert('RGB').save(tmp.name)
        result_path = rmbg_client.predict(handle_file(tmp.name), api_name="/image")[0][0]
        return Image.open(result_path)
315
+
316
+
317
def preprocess_image(input: Image.Image) -> Image.Image:
    """Preprocess a single input image.

    Steps: downscale so the longest side is <= 1024 px, remove the background
    (unless the upload already has meaningful transparency), square-crop to
    the foreground's alpha bounding box, then premultiply RGB by alpha onto
    black and return an opaque RGB image.
    """
    # Treat a fully opaque RGBA image as having no usable alpha channel.
    has_alpha = False
    if input.mode == 'RGBA':
        alpha = np.array(input)[:, :, 3]
        if not np.all(alpha == 255):
            has_alpha = True
    # Cap the longest side at 1024 px (never upscale).
    max_size = max(input.size)
    scale = min(1, 1024 / max_size)
    if scale < 1:
        input = input.resize((int(input.width * scale), int(input.height * scale)), Image.Resampling.LANCZOS)
    if has_alpha:
        output = input
    else:
        # Delegate matting to the remote background-removal service.
        output = remove_background(input)
    # From here on `output` is expected to be RGBA (channel 3 indexed below).
    output_np = np.array(output)
    alpha = output_np[:, :, 3]
    # Bounding box of confidently-foreground pixels (alpha > 80%).
    # NOTE(review): assumes at least one pixel passes the threshold; a fully
    # transparent matting result would make np.min fail on an empty array.
    bbox = np.argwhere(alpha > 0.8 * 255)
    bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0])
    center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
    size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
    # The *1 factor looks like a padding-ratio placeholder (currently no pad).
    size = int(size * 1)
    # Square crop centered on the foreground; coordinates may be floats
    # (center is a float), which PIL's crop accepts.
    bbox = center[0] - size // 2, center[1] - size // 2, center[0] + size // 2, center[1] + size // 2
    output = output.crop(bbox)
    # Composite onto black: multiply RGB by alpha, then drop the alpha channel.
    output = np.array(output).astype(np.float32) / 255
    output = output[:, :, :3] * output[:, :, 3:4]
    output = Image.fromarray((output * 255).astype(np.uint8))
    return output
345
+
346
+
347
def preprocess_images(images: List[Tuple[Image.Image, str]]) -> List[Image.Image]:
    """Preprocess every uploaded image concurrently (up to 4 worker threads)."""
    if not images:
        return []
    # Gallery entries may be (image, caption) tuples or bare images.
    raw = [entry[0] if isinstance(entry, tuple) else entry for entry in images]
    with ThreadPoolExecutor(max_workers=min(4, len(raw))) as pool:
        return list(pool.map(preprocess_image, raw))
355
+
356
+
357
def pack_state(latents):
    """Serialize pipeline latents into a CPU/numpy dict safe to hold in gr.State.

    `latents` is a (shape_slat, tex_slat, res) triple; both sparse tensors
    share the same coordinate grid, so only the shape tensor's coords are kept.
    """
    shape_slat, tex_slat, res = latents
    return dict(
        shape_slat_feats=shape_slat.feats.cpu().numpy(),
        tex_slat_feats=tex_slat.feats.cpu().numpy(),
        coords=shape_slat.coords.cpu().numpy(),
        res=res,
    )
365
+
366
+
367
def unpack_state(state: dict):
    """Rebuild (shape_slat, tex_slat, res) on the GPU from a pack_state() dict."""
    _lazy_import()
    coords = torch.from_numpy(state['coords']).cuda()
    shape_slat = SparseTensor(
        feats=torch.from_numpy(state['shape_slat_feats']).cuda(),
        coords=coords,
    )
    tex_slat = shape_slat.replace(torch.from_numpy(state['tex_slat_feats']).cuda())
    return shape_slat, tex_slat, state['res']
375
+
376
+
377
def get_seed(randomize_seed: bool, seed: int) -> int:
    """Return the user's seed, or a fresh random one when randomization is on."""
    if randomize_seed:
        return np.random.randint(0, MAX_SEED)
    return seed
379
+
380
+
381
def prepare_examples() -> List[List[str]]:
    """Scan assets/example_multi_image for '<case>_<i>.png' groups of views.

    Returns one list of image paths per case (sorted by case name), so each
    example loads as a multi-view set rather than a concatenated image.
    """
    example_dir = "assets/example_multi_image"
    if not os.path.exists(example_dir):
        return []
    names = os.listdir(example_dir)
    cases = {n.split('_')[0] for n in names if '_' in n and n.endswith('.png')}
    examples = []
    for case in sorted(cases):
        # Views are numbered 1..9 per case; keep only the ones that exist.
        candidates = [f'{example_dir}/{case}_{i}.png' for i in range(1, 10)]
        views = [p for p in candidates if os.path.exists(p)]
        if views:
            examples.append(views)
    return examples
398
+
399
+
400
def load_example(example_paths: List[str]) -> List[Image.Image]:
    """Open each example image path with PIL, preserving order."""
    return [Image.open(path) for path in example_paths]
407
+
408
+
409
@spaces.GPU(duration=120)
def image_to_3d(
    images: List[Tuple[Image.Image, str]],
    seed: int,
    resolution: str,
    ss_guidance_strength: float,
    ss_guidance_rescale: float,
    ss_sampling_steps: int,
    ss_rescale_t: float,
    shape_slat_guidance_strength: float,
    shape_slat_guidance_rescale: float,
    shape_slat_sampling_steps: int,
    shape_slat_rescale_t: float,
    tex_slat_guidance_strength: float,
    tex_slat_guidance_rescale: float,
    tex_slat_sampling_steps: int,
    tex_slat_rescale_t: float,
    multiimage_algo: Literal["multidiffusion", "stochastic"],
    tex_multiimage_algo: Literal["multidiffusion", "stochastic"],
    req: gr.Request,
    progress=gr.Progress(track_tqdm=True),
) -> Tuple[dict, str]:
    """Generate a 3D asset from one or more images and build the HTML previewer.

    Runs the three TRELLIS.2 stages (sparse structure -> shape -> material)
    with the given sampler settings, renders STEPS views for every mode in
    MODES, and returns:
        - a pack_state() dict of the latents (held in gr.State for later
          GLB extraction),
        - previewer HTML with all rendered views inlined as base64 JPEGs.

    Raises gr.Error when no image was uploaded.

    Note: the original return annotation said `-> str`; the function actually
    returns a (state, html) pair, which the annotation now reflects.
    """
    # Load models on first call (lazy so the Space starts quickly).
    _initialize_pipeline()

    # Extract images from gallery format
    if not images:
        raise gr.Error("Please upload at least one image")

    imgs = [img[0] if isinstance(img, tuple) else img for img in images]

    # --- Sampling ---
    # Sampler and cascade settings are identical for the single- and
    # multi-image paths; build the shared kwargs once instead of duplicating
    # the parameter dicts in both branches.
    run_kwargs = dict(
        seed=seed,
        preprocess_image=False,  # inputs were already preprocessed on upload
        sparse_structure_sampler_params={
            "steps": ss_sampling_steps,
            "guidance_strength": ss_guidance_strength,
            "guidance_rescale": ss_guidance_rescale,
            "rescale_t": ss_rescale_t,
        },
        shape_slat_sampler_params={
            "steps": shape_slat_sampling_steps,
            "guidance_strength": shape_slat_guidance_strength,
            "guidance_rescale": shape_slat_guidance_rescale,
            "rescale_t": shape_slat_rescale_t,
        },
        tex_slat_sampler_params={
            "steps": tex_slat_sampling_steps,
            "guidance_strength": tex_slat_guidance_strength,
            "guidance_rescale": tex_slat_guidance_rescale,
            "rescale_t": tex_slat_rescale_t,
        },
        pipeline_type={
            "512": "512",
            "1024": "1024_cascade",
            "1536": "1536_cascade",
        }[resolution],
        return_latent=True,
    )
    if len(imgs) == 1:
        # Single image mode
        outputs, latents = pipeline.run(imgs[0], **run_kwargs)
    else:
        # Multi-image mode
        outputs, latents = pipeline.run_multi_image(
            imgs,
            mode=multiimage_algo,
            tex_mode=tex_multiimage_algo,
            **run_kwargs,
        )

    mesh = outputs[0]
    mesh.simplify(16777216)
    render_images = render_utils.render_snapshot(mesh, resolution=1024, r=2, fov=36, nviews=STEPS, envmap=envmap)
    state = pack_state(latents)
    torch.cuda.empty_cache()

    # --- HTML Construction ---
    # JPEG-encode every (mode, step) view in parallel; encoding all
    # len(MODES) * STEPS frames serially would dominate post-processing time.
    def encode_preview_image(args):
        m_idx, s_idx, render_key = args
        img_base64 = image_to_base64(Image.fromarray(render_images[render_key][s_idx]))
        return (m_idx, s_idx, img_base64)

    encode_tasks = [(m_idx, s_idx, mode['render_key']) for m_idx, mode in enumerate(MODES) for s_idx in range(STEPS)]

    with ThreadPoolExecutor(max_workers=8) as executor:
        encoded_results = list(executor.map(encode_preview_image, encode_tasks))

    encoded_map = {(m, s): b64 for m, s, b64 in encoded_results}

    # One <img> per (mode, step); only the default view starts visible and the
    # inline JS in `head` toggles the `visible` class client-side.
    images_html = ""
    for m_idx, mode in enumerate(MODES):
        for s_idx in range(STEPS):
            unique_id = f"view-m{m_idx}-s{s_idx}"
            is_visible = (m_idx == DEFAULT_MODE and s_idx == DEFAULT_STEP)
            vis_class = "visible" if is_visible else ""
            img_base64 = encoded_map[(m_idx, s_idx)]
            images_html += f'<img id="{unique_id}" class="previewer-main-image {vis_class}" src="{img_base64}" loading="eager">'

    # Mode-switch buttons; icon_base64 is pre-encoded at startup in __main__.
    btns_html = ""
    for idx, mode in enumerate(MODES):
        active_class = "active" if idx == DEFAULT_MODE else ""
        btns_html += f'<img src="{mode["icon_base64"]}" class="mode-btn {active_class}" onclick="selectMode({idx})" title="{mode["name"]}">'

    full_html = f"""
    <div class="previewer-container">
        <div class="tips-wrapper">
            <div class="tips-icon">Tips</div>
            <div class="tips-text">
                <p>Render Mode - Click buttons to switch render modes.</p>
                <p>View Angle - Drag slider to change view.</p>
            </div>
        </div>
        <div class="display-row">{images_html}</div>
        <div class="mode-row" id="btn-group">{btns_html}</div>
        <div class="slider-row">
            <input type="range" id="custom-slider" min="0" max="{STEPS - 1}" value="{DEFAULT_STEP}" step="1" oninput="onSliderChange(this.value)">
        </div>
    </div>
    """
    return state, full_html
555
+
556
+
557
@spaces.GPU(duration=120)
def extract_glb(
    state: dict,
    decimation_target: int,
    texture_size: int,
    req: gr.Request,
    progress=gr.Progress(track_tqdm=True),
) -> Tuple[str, str]:
    """Decode the stored latents into a textured GLB file.

    `state` is the pack_state() dict produced by image_to_3d. Returns the
    GLB path twice: once for the Model3D viewer and once for the download
    button.
    """
    _initialize_pipeline()

    user_dir = os.path.join(TMP_DIR, str(req.session_hash))
    # Move latents back to the GPU and decode them into a mesh.
    shape_slat, tex_slat, res = unpack_state(state)
    mesh = pipeline.decode_latent(shape_slat, tex_slat, res)[0]
    mesh.simplify(16777216)
    # Bake voxel PBR attributes into a textured, remeshed, decimated GLB.
    glb = o_voxel.postprocess.to_glb(
        vertices=mesh.vertices,
        faces=mesh.faces,
        attr_volume=mesh.attrs,
        coords=mesh.coords,
        attr_layout=pipeline.pbr_attr_layout,
        grid_size=res,
        aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
        decimation_target=decimation_target,
        texture_size=texture_size,
        remesh=True,
        remesh_band=1,
        remesh_project=0,
        use_tqdm=True,
    )
    # Millisecond-resolution timestamp keeps repeated extractions distinct.
    now = datetime.now()
    timestamp = now.strftime("%Y-%m-%dT%H%M%S") + f".{now.microsecond // 1000:03d}"
    os.makedirs(user_dir, exist_ok=True)
    glb_path = os.path.join(user_dir, f'sample_{timestamp}.glb')
    glb.export(glb_path, extension_webp=True)
    torch.cuda.empty_cache()
    return glb_path, glb_path
593
+
594
+
595
# UI layout and event wiring. Cached files older than 600 s are purged.
with gr.Blocks(delete_cache=(600, 600)) as demo:
    gr.Markdown("""
    ## Multi-View Image to 3D with [TRELLIS.2](https://microsoft.github.io/TRELLIS.2)
    Upload one or more images of an object and click Generate to create a 3D asset.
    Multiple views from different angles will produce better results.
    """)

    with gr.Row():
        # --- Left column: inputs and generation settings ---
        with gr.Column(scale=1, min_width=360):
            image_prompt = gr.Gallery(
                label="Input Images (upload 1 or more views)",
                format="png",
                type="pil",
                height=400,
                columns=3,
                object_fit="contain"
            )
            gr.Markdown("*Upload multiple views of the same object for better 3D reconstruction.*")

            resolution = gr.Radio(["512", "1024", "1536"], label="Resolution", value="1024")
            seed = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
            randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
            decimation_target = gr.Slider(100000, 500000, label="Decimation Target", value=300000, step=10000)
            texture_size = gr.Slider(1024, 4096, label="Texture Size", value=2048, step=1024)

            generate_btn = gr.Button("Generate", variant="primary")

            # Per-stage sampler settings (defaults mirror the pipeline's).
            with gr.Accordion(label="Advanced Settings", open=False):
                gr.Markdown("Stage 1: Sparse Structure Generation")
                with gr.Row():
                    ss_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
                    ss_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.7, step=0.01)
                    ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
                    ss_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=5.0, step=0.1)
                gr.Markdown("Stage 2: Shape Generation")
                with gr.Row():
                    shape_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
                    shape_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.5, step=0.01)
                    shape_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
                    shape_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1)
                gr.Markdown("Stage 3: Material Generation")
                with gr.Row():
                    tex_slat_guidance_strength = gr.Slider(1.0, 10.0, label="Guidance Strength", value=1.0, step=0.1)
                    tex_slat_guidance_rescale = gr.Slider(0.0, 1.0, label="Guidance Rescale", value=0.0, step=0.01)
                    tex_slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=12, step=1)
                    tex_slat_rescale_t = gr.Slider(1.0, 6.0, label="Rescale T", value=3.0, step=0.1)
                multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Structure Algorithm", value="stochastic")
                tex_multiimage_algo = gr.Radio(["stochastic", "multidiffusion"], label="Texture Algorithm", value="multidiffusion")

        # --- Center column: preview -> extract walkthrough ---
        with gr.Column(scale=10):
            with gr.Walkthrough(selected=0) as walkthrough:
                with gr.Step("Preview", id=0):
                    preview_output = gr.HTML(empty_html, label="3D Asset Preview", show_label=True, container=True)
                    extract_btn = gr.Button("Extract GLB")
                with gr.Step("Extract", id=1):
                    glb_output = gr.Model3D(label="Extracted GLB", height=724, show_label=True, display_mode="solid", clear_color=(0.25, 0.25, 0.25, 1.0))
                    download_btn = gr.DownloadButton(label="Download GLB")
                    gr.Markdown("*GLB extraction may take 30+ seconds.*")

        # --- Right column: example multi-view sets ---
        with gr.Column(scale=1, min_width=200):
            gr.Markdown("### Examples")
            # Create example buttons that load images into gallery
            example_data = prepare_examples()
            for i, example_paths in enumerate(example_data[:12]):  # Limit to 12 examples
                case_name = os.path.basename(example_paths[0]).split('_')[0]
                btn = gr.Button(f"{case_name} ({len(example_paths)} views)", size="sm")
                btn.click(
                    # Default-arg binding captures this iteration's paths
                    # (avoids the late-binding closure pitfall).
                    fn=lambda paths=example_paths: load_example(paths),
                    outputs=[image_prompt]
                )

    # Latents from the last generation, shared between generate and extract.
    output_buf = gr.State()

    # Handlers
    demo.load(start_session)
    demo.unload(end_session)

    # Preprocess uploads in place (background removal, crop, resize).
    image_prompt.upload(
        preprocess_images,
        inputs=[image_prompt],
        outputs=[image_prompt],
    )

    # Generate: resolve seed -> jump to Preview step -> run the pipeline.
    generate_btn.click(
        get_seed, inputs=[randomize_seed, seed], outputs=[seed],
    ).then(
        lambda: gr.Walkthrough(selected=0), outputs=walkthrough
    ).then(
        image_to_3d,
        inputs=[
            image_prompt, seed, resolution,
            ss_guidance_strength, ss_guidance_rescale, ss_sampling_steps, ss_rescale_t,
            shape_slat_guidance_strength, shape_slat_guidance_rescale, shape_slat_sampling_steps, shape_slat_rescale_t,
            tex_slat_guidance_strength, tex_slat_guidance_rescale, tex_slat_sampling_steps, tex_slat_rescale_t,
            multiimage_algo, tex_multiimage_algo
        ],
        outputs=[output_buf, preview_output],
    )

    # Extract: jump to the Extract step, then decode latents into a GLB.
    extract_btn.click(
        lambda: gr.Walkthrough(selected=1), outputs=walkthrough
    ).then(
        extract_glb,
        inputs=[output_buf, decimation_target, texture_size],
        outputs=[glb_output, download_btn],
    )
701
+
702
+
703
if __name__ == "__main__":
    os.makedirs(TMP_DIR, exist_ok=True)

    # Pre-encode the mode-button icons as base64 data URIs so image_to_3d
    # can inline them directly in the previewer HTML.
    for i in range(len(MODES)):
        icon = Image.open(MODES[i]['icon'])
        MODES[i]['icon_base64'] = image_to_base64(icon)

    # Remote background-removal client used by remove_background().
    rmbg_client = Client("briaai/BRIA-RMBG-2.0")

    demo.launch(css=css, head=head)