Spaces:
Sleeping
Sleeping
File size: 35,910 Bytes
c840ad0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 |
from __future__ import annotations
import logging
import os
import shutil
import uuid
import time
from pathlib import Path
from gradio_client import Client, handle_file
from gradio_client import exceptions as gradio_exceptions
import httpx
from config import (
HF_TOKEN,
MODEL_STORAGE_PATH,
)
logger = logging.getLogger(__name__)
class HuggingFaceService:
"""Handles 3D model generation via Hugging Face Spaces."""
# ============================================================================
# Initialization
# ============================================================================
def __init__(self):
"""Initializes all Gradio Clients.
Note: HF_TOKEN is optional. If not provided, Spaces will use their own quota.
If provided, authenticated calls may consume your GPU quota.
"""
# HF_TOKEN is optional - we'll try without it first to avoid consuming your quota
self.trellis_client: Client | None = None
self.shap_e_client: Client | None = None
self.hunyuan_client: Client | None = None
self.trellis_client2: Client | None = None
# Ensure model storage directory exists
self.storage_path = Path(MODEL_STORAGE_PATH)
self.storage_path.mkdir(parents=True, exist_ok=True)
# Initialize TRELLIS client (non-blocking - text-to-3D features will be unavailable if this fails)
# Hugging Face Spaces can be slow or sleeping (free tier), so we retry with backoff
self.trellis_client = self._initialize_client_with_retry(
"dkatz2391/TRELLIS_TextTo3D_Try2",
"TRELLIS",
max_retries=2,
)
# Initialize Shap-E client (non-blocking)
self.shap_e_client = self._initialize_client_with_retry(
"hysts/Shap-E",
"Shap-E",
max_retries=2,
)
# Initialize Hunyuan3D-2 client for image-to-3D (non-blocking)
self.hunyuan_client = self._initialize_client_with_retry(
"tencent/Hunyuan3D-2.1",
"Hunyuan3D-2",
max_retries=2,
)
self.trellis_client2 = self._initialize_client_with_retry(
"trellis-community/TRELLIS",
"TRELLIS-2",
max_retries=2,
)
def _initialize_client_with_retry(
self, space_id: str, service_name: str, max_retries: int = 2
) -> Client | None:
"""Initialize a Gradio Client with retry logic and exponential backoff.
Hugging Face Spaces on the free tier can be slow to respond or sleeping,
so we retry with increasing delays.
Args:
space_id: Hugging Face Space ID (e.g., "username/space-name")
service_name: Human-readable name for logging
max_retries: Maximum number of retry attempts
Returns:
Client instance if successful, None if all retries failed
"""
for attempt in range(max_retries + 1):
try:
if attempt > 0:
# Exponential backoff: 5s, 10s
wait_time = 5 * (2 ** (attempt - 1))
logger.info(
f"Retrying {service_name} initialization (attempt {attempt + 1}/{max_retries + 1}) "
f"after {wait_time}s delay..."
)
time.sleep(wait_time)
logger.info(
f"Initializing {service_name} Client (attempt {attempt + 1}/{max_retries + 1})..."
)
# Try without token first to avoid consuming your GPU quota
# Most public Spaces work without authentication and use their own quota
try:
client = Client(space_id) # No token - uses Space owner's quota
logger.info(
f"β {service_name} Gradio Client initialized (no token - using Space owner's quota)"
)
return client
except Exception as no_token_error:
# If that fails, try with token (may consume your GPU quota)
if HF_TOKEN:
logger.warning(
f"{service_name} requires authentication. Using HF_TOKEN (may consume your GPU quota)"
)
client = Client(space_id, hf_token=HF_TOKEN)
logger.info(
f"β {service_name} Gradio Client initialized (with token)"
)
return client
else:
# Re-raise the original error if no token available
raise no_token_error
except Exception as e:
error_msg = str(e).lower()
is_timeout = "timeout" in error_msg or "timed out" in error_msg
if attempt < max_retries:
logger.warning(
f"{service_name} initialization attempt {attempt + 1} failed: {e}. "
f"Will retry..."
)
else:
# Final attempt failed
if is_timeout:
logger.warning(
f"Failed to initialize {service_name} Client after {max_retries + 1} attempts: {e}. "
f"This is likely because the Hugging Face Space is sleeping (free tier) or slow to respond. "
f"{service_name} features will not be available. "
f"The Space will wake up automatically when first used, but initialization may take longer."
)
else:
logger.warning(
f"Failed to initialize {service_name} Client after {max_retries + 1} attempts: {e}. "
f"{service_name} features will not be available."
)
return None
return None
# ============================================================================
# Public API - Model Generation
# ============================================================================
def text_to_3d(self, prompt: str, model_id: str | None = None) -> str:
"""Generate a 3D model directly from a text prompt using TRELLIS.
Uses the /generate_and_extract_glb endpoint which:
- Generates 3D model from text
- Extracts GLB with textures
- Returns GLB file path directly
Args:
prompt: Text prompt for 3D model generation
model_id: Optional model ID to use for the filename. If not provided, generates a new UUID.
Returns:
Path to the generated GLB file
"""
if not self.trellis_client:
raise RuntimeError("TRELLIS Client not initialized.")
logger.info(f"Generating 3D model from text prompt: '{prompt[:60]}...'")
filename = model_id if model_id else str(uuid.uuid4())
glb_path = self.storage_path / f"{filename}.glb"
try:
# TRELLIS API Endpoint Comparison:
# - /generate_and_extract_glb: Returns GLB file directly with textures
# Includes mesh_simplify and texture_size parameters for AR optimization
# - /text_to_3d: Basic endpoint, may return raw 3D data (needs testing)
# Missing mesh_simplify and texture_size parameters
#
# We use /generate_and_extract_glb because:
# 1. Returns GLB file path/URL directly (no additional extraction step)
# 2. Includes texture_size parameter (critical for texture export)
# 3. Includes mesh_simplify parameter (optimizes for AR performance)
# 4. One-step process - generates and extracts in single call
logger.info("Generating 3D model with TRELLIS...")
try:
result = self.trellis_client.predict(
prompt=prompt,
seed=0, # Fixed seed for reproducibility
ss_guidance_strength=7.5,
ss_sampling_steps=25,
slat_guidance_strength=7.5,
slat_sampling_steps=25,
mesh_simplify=0.95, # Slight simplification for AR
texture_size=1024, # Texture size - key for texture export!
api_name="/generate_and_extract_glb", # Returns GLB directly
)
logger.debug(f"TRELLIS result type: {type(result)}")
logger.debug(f"TRELLIS result: {result}")
except gradio_exceptions.AppError as app_error:
logger.error(f"TRELLIS AppError: {app_error}")
raise RuntimeError(
f"TRELLIS Space returned an error: {app_error}. "
f"This might mean: 1) The Space is busy or has queue limits, "
f"2) The prompt is invalid, 3) The Space is experiencing issues."
) from app_error
except Exception as api_error:
logger.error(f"TRELLIS API call failed: {api_error}")
import traceback
logger.error(f"Full traceback: {traceback.format_exc()}")
raise RuntimeError(
f"TRELLIS API call failed: {api_error}"
) from api_error
# TRELLIS returns a string (file path or URL) directly
if isinstance(result, str):
glb_file_path = result
elif isinstance(result, dict) and "value" in result:
# Sometimes Gradio returns dict with 'value' key
glb_file_path = result["value"]
else:
logger.error(
f"Unexpected TRELLIS result format: {type(result)}, value: {result}"
)
raise RuntimeError(f"Unexpected TRELLIS result format: {type(result)}")
if not glb_file_path or not isinstance(glb_file_path, str):
raise RuntimeError(
f"TRELLIS returned invalid file path: {glb_file_path}"
)
# TRELLIS may return a URL instead of a local file path
if glb_file_path.startswith("http://") or glb_file_path.startswith(
"https://"
):
# Download the GLB file from the URL
logger.info(f"TRELLIS returned URL, downloading GLB: {glb_file_path}")
try:
with httpx.Client(timeout=60.0) as client:
response = client.get(glb_file_path)
response.raise_for_status()
# Ensure parent directory exists before saving
glb_path.parent.mkdir(parents=True, exist_ok=True)
# Save to our storage
with open(glb_path, "wb") as f:
f.write(response.content)
logger.info(
f"β Downloaded TRELLIS GLB: {glb_path} ({os.path.getsize(glb_path)} bytes)"
)
except Exception as download_error:
logger.error(f"Failed to download GLB from URL: {download_error}")
raise RuntimeError(
f"Failed to download GLB from TRELLIS URL: {download_error}"
)
else:
# Local file path - check if it exists and copy it
if not Path(glb_file_path).exists():
raise RuntimeError(
f"Generated GLB file not found: {glb_file_path}. "
f"TRELLIS may have failed to generate the model."
)
# Ensure parent directory exists before copying
glb_path.parent.mkdir(parents=True, exist_ok=True)
# Copy the generated GLB to our storage
shutil.copy(glb_file_path, glb_path)
logger.info(
f"β TRELLIS GLB generated: {glb_path} ({os.path.getsize(glb_path)} bytes)"
)
return str(glb_path)
except Exception as e:
logger.error(f"TRELLIS text-to-3D generation failed: {e}")
raise RuntimeError(f"3D model generation failed: {e}")
def text_to_3d_shap_e(self, prompt: str, model_id: str | None = None) -> str:
"""Generate a 3D model from text prompt using Shap-E.
Uses Shap-E for advanced 3D model generation from text prompts.
Uses the /text_to_3d endpoint with the prompt parameter for text-to-3D.
API Reference: https://huggingface.co/spaces/hysts/Shap-E
Args:
prompt: Text prompt for 3D model generation
model_id: Optional model ID to use for the filename. If not provided, generates a new UUID.
Returns:
Path to the generated GLB file
"""
if not self.shap_e_client:
raise RuntimeError("Shap-E Client not initialized.")
logger.info(
f"Generating 3D model with Shap-E from text prompt: '{prompt[:60]}...'"
)
filename = model_id if model_id else str(uuid.uuid4())
glb_path = self.storage_path / f"{filename}.glb"
try:
logger.info("Generating 3D model with Shap-E...")
try:
DEFAULT_SEED = 0
DEFAULT_GUIDANCE_SCALE = 20.0
DEFAULT_STEPS = 100
# Shap-E /text-to-3d endpoint typically accepts:
# - prompt (str)
# - seed (int/float)
# - guidance_scale (float)
# - num_inference_steps (int) - not "steps"
result = self.shap_e_client.predict(
prompt=prompt,
seed=DEFAULT_SEED,
guidance_scale=DEFAULT_GUIDANCE_SCALE,
num_inference_steps=DEFAULT_STEPS, # This value is correctly mapped to 'num_inference_steps'
api_name="/text-to-3d",
)
logger.debug(f"Shap-E result type: {type(result)}")
logger.debug(f"Shap-E result: {result}")
except gradio_exceptions.AppError as app_error:
logger.error(f"Shap-E AppError: {app_error}")
raise RuntimeError(
f"Shap-E Space returned an error: {app_error}. "
f"This might mean: 1) The Space is busy or has queue limits, "
f"2) The prompt is invalid, 3) The Space is experiencing issues."
) from app_error
except Exception as api_error:
logger.error(f"Shap-E API call failed: {api_error}")
import traceback
logger.error(f"Full traceback: {traceback.format_exc()}")
raise RuntimeError(
f"Shap-E API call failed: {api_error}"
) from api_error
# Handle result - format depends on the actual endpoint response
if isinstance(result, str):
temp_file_path = result
elif isinstance(result, (list, tuple)) and len(result) > 0:
# Shap-E may return a tuple/list with file path as first element
temp_file_path = result[0]
elif isinstance(result, dict):
# Check for common keys that might contain file path
if "file" in result:
temp_file_path = result["file"]
elif "value" in result:
temp_file_path = result["value"]
else:
logger.error(
f"Unexpected Shap-E result format: {type(result)}, value: {result}"
)
raise RuntimeError(
f"Unexpected Shap-E result format: {type(result)}"
)
else:
logger.error(
f"Unexpected Shap-E result format: {type(result)}, value: {result}"
)
raise RuntimeError(f"Unexpected Shap-E result format: {type(result)}")
if not temp_file_path or not isinstance(temp_file_path, str):
raise RuntimeError(
f"Shap-E returned invalid file path: {temp_file_path}"
)
# Shap-E may return a URL instead of a local file path
if temp_file_path.startswith("http://") or temp_file_path.startswith(
"https://"
):
# Download the GLB file from the URL
logger.info(f"Shap-E returned URL, downloading GLB: {temp_file_path}")
try:
with httpx.Client(timeout=120.0) as client:
response = client.get(temp_file_path)
response.raise_for_status()
# Ensure parent directory exists before saving
glb_path.parent.mkdir(parents=True, exist_ok=True)
# Save to our storage
with open(glb_path, "wb") as f:
f.write(response.content)
logger.info(
f"β Downloaded Shap-E GLB: {glb_path} ({os.path.getsize(glb_path)} bytes)"
)
except Exception as download_error:
logger.error(f"Failed to download GLB from URL: {download_error}")
raise RuntimeError(
f"Failed to download GLB from Shap-E URL: {download_error}"
)
else:
# Local file path - check if it exists and copy it
if not Path(temp_file_path).exists():
raise RuntimeError(
f"Generated GLB file not found: {temp_file_path}. "
f"Shap-E may have failed to generate the model."
)
# Ensure parent directory exists before copying
glb_path.parent.mkdir(parents=True, exist_ok=True)
# Note: Shap-E might return a .ply, but we save it as .glb
# as requested by the function signature.
shutil.copy(temp_file_path, glb_path)
logger.info(
f"β Shap-E GLB generated: {glb_path} ({os.path.getsize(glb_path)} bytes)"
)
return str(glb_path)
except Exception as e:
logger.error(f"Shap-E text-to-3D generation failed: {e}")
raise RuntimeError(f"3D model generation failed: {e}")
def image_to_3d_hunyuan(
self, image_file_path: str, model_id: str | None = None
) -> str:
"""
Generate a 3D model from a 2D image using Hunyuan3D.
This optimized version performs remote chaining of the /generation_all and
/on_export_click steps to avoid slow local download/re-upload cycles of intermediary files.
"""
if not self.hunyuan_client:
raise RuntimeError("Hunyuan3D Client not initialized.")
if not Path(image_file_path).exists():
logger.error(f"Input image file not found at: {image_file_path}")
raise FileNotFoundError(f"Input image file not found: {image_file_path}")
logger.info(
f"Generating 3D model with Hunyuan3D from image file: '{image_file_path}'"
)
filename = model_id if model_id else str(uuid.uuid4())
glb_path = self.storage_path / f"{filename}.glb"
# Default parameters - Tuned for faster generation time and to avoid GPU quota limits
DEFAULT_STEPS = 15 # Further reduced from 20 to decrease generation time and avoid GPU quota limits (Job requested 180s vs 75s left)
DEFAULT_GUIDANCE = 5.0
DEFAULT_SEED = 1234
RANDOMIZE_SEED = True
DEFAULT_OCTREE_RES = 128 # Further reduced from 192 to decrease generation time and avoid GPU quota limits
REMOVE_BG = True
DEFAULT_NUM_CHUNKS = (
4000 # Further reduced from 5000 to decrease generation time
)
# Helper function to extract file path, handling Gradio's inconsistent wrapping
def extract_path(result_part):
if isinstance(result_part, dict):
# Check for 'value' (from previous API return) or 'path'/'name' (if pre-wrapped)
if "value" in result_part:
return result_part["value"]
if "path" in result_part:
return result_part["path"]
if "name" in result_part:
return result_part["name"]
return result_part
# Helper function to wrap a remote path into the required dictionary format for the next API call
# CRITICAL: Use 'path' as the key to satisfy the downstream Gradio component's FileData validation.
def format_for_remote_api(path: str) -> dict:
return {"path": path}
try:
# --- Step 1: Call /generation_all to get mesh and texture ---
logger.info("Hunyuan3D Step 1/2: Calling /generation_all...")
try:
gen_result = self.hunyuan_client.predict(
image=handle_file(image_file_path),
mv_image_front=None,
mv_image_back=None,
mv_image_left=None,
mv_image_right=None,
steps=DEFAULT_STEPS,
guidance_scale=DEFAULT_GUIDANCE,
seed=DEFAULT_SEED,
octree_resolution=DEFAULT_OCTREE_RES,
check_box_rembg=REMOVE_BG,
num_chunks=DEFAULT_NUM_CHUNKS,
randomize_seed=RANDOMIZE_SEED,
api_name="/generation_all",
)
logger.debug(f"Hunyuan3D /generation_all result: {gen_result}")
except gradio_exceptions.AppError as app_error:
logger.error(f"Hunyuan3D /generation_all AppError: {app_error}")
raise RuntimeError(
f"Hunyuan3D Space (generation) returned an error: {app_error}."
) from app_error
except Exception as api_error:
logger.error(f"Hunyuan3D /generation_all API call failed: {api_error}")
import traceback
logger.error(f"Full traceback: {traceback.format_exc()}")
raise RuntimeError(
f"Hunyuan3D /generation_all API call failed: {api_error}"
) from api_error
# /generation_all returns a tuple of 5 elements
if not isinstance(gen_result, (list, tuple)) or len(gen_result) < 2:
raise RuntimeError(
f"Unexpected result from Hunyuan3D /generation_all: {gen_result}"
)
# Extract raw remote paths (strings)
mesh_raw_path = extract_path(gen_result[0])
texture_raw_path = extract_path(gen_result[1])
if not mesh_raw_path:
raise RuntimeError(
f"Hunyuan3D /generation_all did not return a mesh file path."
)
logger.info(f"Hunyuan3D generated mesh: {mesh_raw_path}")
logger.info(f"Hunyuan3D generated texture: {texture_raw_path}")
# Check if /generation_all already returned a GLB file (rare, but possible)
if isinstance(texture_raw_path, str) and texture_raw_path.endswith(".glb"):
logger.info(
"Hunyuan3D /generation_all already returned a GLB file, skipping export step."
)
temp_file_path = texture_raw_path
else:
# --- Step 2: Call /on_export_click to convert to GLB (Remote Chaining) ---
logger.info(
"Hunyuan3D Step 2/2: Calling /on_export_click to export GLB (Remote Chaining)..."
)
# Pass remote paths directly, formatted in the exact structure Gradio needs: {"path": path}.
file_out_param = format_for_remote_api(mesh_raw_path)
if texture_raw_path:
file_out2_param = format_for_remote_api(texture_raw_path)
export_texture_flag = True
else:
# Use mesh for file_out2 if texture is missing
file_out2_param = format_for_remote_api(mesh_raw_path)
export_texture_flag = False
logger.info(
"No texture file generated, using mesh file for both file_out parameters"
)
try:
export_result = self.hunyuan_client.predict(
file_out=file_out_param, # Optimized: {"path": "/remote/path"}
file_out2=file_out2_param, # Optimized: {"path": "/remote/path"}
file_type="glb",
reduce_face=False,
export_texture=export_texture_flag,
target_face_num=10000,
api_name="/on_export_click",
)
logger.debug(f"Hunyuan3D /on_export_click result: {export_result}")
except gradio_exceptions.AppError as app_error:
logger.error(f"Hunyuan3D /on_export_click AppError: {app_error}")
raise RuntimeError(
f"Hunyuan3D Space (export) returned an error: {app_error}."
) from app_error
except Exception as api_error:
logger.error(
f"Hunyuan3D /on_export_click API call failed: {api_error}"
)
import traceback
logger.error(f"Full traceback: {traceback.format_exc()}")
raise RuntimeError(
f"Hunyuan3D /on_export_click API call failed: {api_error}"
) from api_error
# /on_export_click returns a tuple of 2 elements
if (
not isinstance(export_result, (list, tuple))
or len(export_result) < 2
):
raise RuntimeError(
f"Unexpected result from Hunyuan3D /on_export_click: {export_result}"
)
# Final result is the downloaded GLB URL/path
temp_file_path = extract_path(export_result[1])
if not temp_file_path or not isinstance(temp_file_path, str):
raise RuntimeError(
f"Hunyuan3D /on_export_click returned invalid file path: {temp_file_path}"
)
# --- Step 3: Handle the final GLB file (download or copy) ---
# The final GLB URL is the only file we need to download locally
if temp_file_path.startswith("http://") or temp_file_path.startswith(
"https://"
):
logger.info(
f"Hunyuan3D returned URL, downloading GLB: {temp_file_path}"
)
try:
with httpx.Client(timeout=120.0) as client:
response = client.get(temp_file_path)
response.raise_for_status()
glb_path.parent.mkdir(parents=True, exist_ok=True)
with open(glb_path, "wb") as f:
f.write(response.content)
logger.info(
f"β Downloaded Hunyuan3D GLB: {glb_path} ({os.path.getsize(glb_path)} bytes)"
)
except Exception as download_error:
logger.error(f"Failed to download GLB from URL: {download_error}")
raise RuntimeError(
f"Failed to download GLB from Hunyuan3D URL: {download_error}"
)
else:
if not Path(temp_file_path).exists():
raise RuntimeError(
f"Generated GLB file not found: {temp_file_path}. "
)
glb_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copy(temp_file_path, glb_path)
logger.info(
f"β Hunyuan3D GLB generated: {glb_path} ({os.path.getsize(glb_path)} bytes)"
)
return str(glb_path)
except Exception as e:
logger.error(f"Hunyuan3D image-to-3D generation failed: {e}")
raise RuntimeError(f"3D model generation failed: {e}")
def image_to_3d_TRELLIS(
self, image_file_path: str, model_id: str | None = None
) -> str:
"""
Generate a 3D model from a 2D image using the TRELLIS API endpoint.
This function has been streamlined to use the single-step TRELLIS API
endpoint (`/generate_and_extract_glb`) for faster, simpler generation,
assuming the client (self.trellis_client2) is configured for the TRELLIS Space.
Args:
image_file_path: Local file path to the input 2D image.
model_id: Optional model ID to use for the output filename.
Returns:
Path to the generated GLB file (local storage path).
"""
# NOTE: Assuming self.trellis_client is now used for the TRELLIS-style generation.
if not self.trellis_client2:
raise RuntimeError(
"3D Generation Client not initialized (using TRELLIS single-step API)."
)
input_path = Path(image_file_path)
if not input_path.exists():
logger.error(f"Input image file not found at: {image_file_path}")
raise FileNotFoundError(f"Input image file not found: {image_file_path}")
logger.info(
f"Generating 3D model with TRELLIS (single-step) from image file: '{image_file_path}'"
)
filename = model_id if model_id else str(uuid.uuid4())
glb_path = self.storage_path / f"{filename}.glb"
# Helper function to extract file path, handling Gradio's inconsistent wrapping
def extract_path(result_part):
if isinstance(result_part, dict):
if "value" in result_part:
return result_part["value"]
if "path" in result_part:
return result_part["path"]
if "name" in result_part:
return result_part["name"]
return result_part
try:
# --- Step 1: Call /generate_and_extract_glb (TRELLIS single-step API) ---
logger.info("TRELLIS Step 1/1: Calling /generate_and_extract_glb...")
# Parameters derived from the TRELLIS Gradio API example
# Using self.trellis_client2 as the client for this operation
gen_result = self.trellis_client2.predict(
image=handle_file(image_file_path),
multiimages=[],
seed=0,
ss_guidance_strength=7.5,
ss_sampling_steps=12,
slat_guidance_strength=3,
slat_sampling_steps=12,
multiimage_algo="stochastic",
mesh_simplify=0.95,
texture_size=1024,
api_name="/generate_and_extract_glb",
)
logger.debug(f"TRELLIS /generate_and_extract_glb result: {gen_result}")
if not isinstance(gen_result, (list, tuple)) or not gen_result:
raise RuntimeError(
f"Unexpected result from TRELLIS /generate_and_extract_glb: {gen_result}"
)
# The TRELLIS endpoint returns a tuple of 3 elements: [video, glb/gaussian, download_glb]
# The second element (index [1]) is the file path for the Litmodel3d component (glb/gaussian).
# We check the first element ([0]) which is the video component, as it often contains the file path first.
# The TRELLIS API documentation shows the video component (index 0) is the first return element.
# We need the path to the downloaded GLB file. Let's assume the first element path or the second element path.
# Based on the API doc:
# Returns tuple of 3 elements:
# [0] dict(video: filepath, subtitles: filepath | None) -> The Video component path
# [1] filepath -> The Litmodel3d component path (GLB/Gaussian)
# [2] filepath -> The Downloadbutton path (GLB)
# The most reliable path to the output file is often the second element [1] or third [2],
# but the original TRELLIS function used [0]. We will update to use [1] for the Litmodel3d path
# as it is more likely to contain the actual GLB data path when running remotely.
# Let's use index [1] or [2] if available, falling back to [0].
temp_file_path = (
extract_path(gen_result[1]) if len(gen_result) > 1 else None
)
if not temp_file_path:
temp_file_path = extract_path(gen_result[0])
if not temp_file_path or not isinstance(temp_file_path, str):
raise RuntimeError(
f"TRELLIS returned invalid file path: {temp_file_path}"
)
# --- Step 2: Handle the final GLB file (download or copy) ---
if temp_file_path.startswith("http://") or temp_file_path.startswith(
"https://"
):
logger.info(f"TRELLIS returned URL, downloading GLB: {temp_file_path}")
try:
with httpx.Client(timeout=120.0) as http_client:
response = http_client.get(temp_file_path)
response.raise_for_status()
glb_path.parent.mkdir(parents=True, exist_ok=True)
with open(glb_path, "wb") as f:
f.write(response.content)
logger.info(
f"β Downloaded TRELLIS GLB: {glb_path} ({os.path.getsize(glb_path)} bytes)"
)
except Exception as download_error:
logger.error(f"Failed to download GLB from URL: {download_error}")
raise RuntimeError(
f"Failed to download GLB from TRELLIS URL: {download_error}"
)
else:
if not Path(temp_file_path).exists():
raise RuntimeError(
f"Generated GLB file not found: {temp_file_path}. "
)
glb_path.parent.mkdir(parents=True, exist_ok=True)
shutil.copy(temp_file_path, glb_path)
logger.info(
f"β TRELLIS GLB generated: {glb_path} ({os.path.getsize(glb_path)} bytes)"
)
return str(glb_path)
except gradio_exceptions.AppError as app_error:
logger.error(f"TRELLIS AppError: {app_error}")
raise RuntimeError(
f"TRELLIS Space returned an error during generation: {app_error}."
) from app_error
except Exception as e:
logger.error(f"TRELLIS image-to-3D generation failed: {e}")
raise RuntimeError(f"3D model generation failed: {e}")
|