Spaces:
Running on Zero
| import argparse | |
| import base64 | |
| import concurrent.futures | |
| import importlib.util | |
| import json | |
| import os | |
| import sys | |
| import threading | |
| import time | |
| from datetime import datetime | |
| from typing import Any, Dict, Generator, List, Optional, Tuple | |
# Process-wide configuration, applied before cv2 / torch / trellis are imported
# further down in this file.
os.environ.update(
    {
        # Lets cv2.imread open the .exr HDRI environment maps.
        "OPENCV_IO_ENABLE_OPENEXR": "1",
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
    }
)
# Prefer flash_attn, then xformers; fall back to "sdpa" when neither is installed.
_attn_backend = next(
    (
        candidate
        for candidate in ("flash_attn", "xformers")
        if importlib.util.find_spec(candidate) is not None
    ),
    "sdpa",
)
# setdefault: an explicitly exported environment value always wins.
os.environ.setdefault("ATTN_BACKEND", _attn_backend)
os.environ.setdefault("SPARSE_ATTN_BACKEND", _attn_backend)
os.environ.setdefault("SPCONV_ALGO", "native")
os.environ["FLEX_GEMM_AUTOTUNE_CACHE_PATH"] = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "autotune_cache.json",
)
os.environ["FLEX_GEMM_AUTOTUNER_VERBOSE"] = "1"
os.environ.setdefault("TRELLIS_REMBG_MODEL", "briaai/RMBG-2.0")
| import cv2 | |
| import gradio as gr | |
| import imageio | |
| import numpy as np | |
| import spaces | |
| import torch | |
| import trimesh | |
| from PIL import Image | |
| from gradio_litmodel3d import LitModel3D | |
| sys.path.append(os.getcwd()) | |
| from trellis.pipelines import TrellisImageTo3DPipeline | |
| from trellis.representations import Gaussian, MeshExtractResult | |
| from trellis.utils import postprocessing_utils as trellis_postprocessing_utils | |
| from trellis.utils import render_utils as trellis_render_utils | |
| import logging | |
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

parser = argparse.ArgumentParser(description="Pocket 3D AI")
parser.add_argument("--prod", action="store_true", help="Run in production mode")
parser.add_argument(
    "--port",
    type=int,
    help="Port to run the server on (default: 8081 for prod, 8080 for dev)",
)
# parse_known_args ignores arguments this parser does not define instead of erroring.
cmd_args, _unknown_args = parser.parse_known_args()
prod = cmd_args.prod
if cmd_args.port:
    port = cmd_args.port
else:
    port = 8081 if prod else 8080
# Extra UI options are only exposed outside production mode.
show_options = not prod
# Hugging Face Spaces exposes SPACE_ID; its presence is treated as "running on Spaces".
RUNNING_ON_SPACES = bool(os.getenv("SPACE_ID"))

# Optional public origin override: assume https when no scheme is given and strip
# trailing slashes so paths can be appended directly.
PUBLIC_BASE_URL = os.getenv("PUBLIC_BASE_URL", "").strip()
if PUBLIC_BASE_URL and not PUBLIC_BASE_URL.startswith(("http://", "https://")):
    PUBLIC_BASE_URL = "https://" + PUBLIC_BASE_URL
PUBLIC_BASE_URL = PUBLIC_BASE_URL.rstrip("/")

if RUNNING_ON_SPACES:
    # Required behind proxies/custom domains so Gradio/Uvicorn generate https URLs.
    os.environ.setdefault("FORWARDED_ALLOW_IPS", "*")
# Gaussian RGB rendering needs the optional diff_gaussian_rasterization extension.
HAS_DIFF_GAUSSIAN_RASTERIZATION = (
    importlib.util.find_spec("diff_gaussian_rasterization") is not None
)
if not HAS_DIFF_GAUSSIAN_RASTERIZATION:
    logger.warning(
        "diff_gaussian_rasterization is missing: preview/texture baking will fall back to mesh appearance."
    )
else:
    logger.info("diff_gaussian_rasterization detected: Gaussian RGB rendering enabled.")

# Which generation backend to load; anything unrecognised falls back to "original".
TRELLIS_RUNTIME = os.getenv("TRELLIS_RUNTIME", "original").strip().lower()
if TRELLIS_RUNTIME not in ("original", "trellis2"):
    logger.warning("Unknown TRELLIS_RUNTIME=%s, defaulting to 'original'.", TRELLIS_RUNTIME)
    TRELLIS_RUNTIME = "original"
# Largest signed 32-bit integer; exclusive upper bound used by get_seed().
MAX_SEED = np.iinfo(np.int32).max

# Every runtime path is rooted at this file's directory, so all of them are absolute.
APP_DIR = os.path.dirname(os.path.abspath(__file__))
TMP_DIR = os.path.join(APP_DIR, "cache")       # preview MP4s
EXPORT_DIR = os.path.join(APP_DIR, "exports")  # GLB / STL outputs
os.makedirs(TMP_DIR, exist_ok=True)
os.makedirs(EXPORT_DIR, exist_ok=True)

TMP_DIR_ABS, EXPORT_DIR_ABS = TMP_DIR, EXPORT_DIR
ASSETS_DIR_ABS = os.path.join(APP_DIR, "assets")
# NOTE(review): presumably passed to Gradio as allowed file-serving paths outside this chunk.
ALLOWED_PATHS = [TMP_DIR_ABS, EXPORT_DIR_ABS, ASSETS_DIR_ABS]
if gr.NO_RELOAD:
    # Runtime handles filled in by initialize_pipeline(). Code under gr.NO_RELOAD
    # is skipped on Gradio hot reloads, so these are not reset to None on reload.
    pipeline = envmap = trellis2_render_utils = EnvMap = o_voxel = None
def _load_hdri_envmaps(names=("forest", "sunset", "courtyard")) -> Dict[str, np.ndarray]:
    """Read the bundled HDRI environment maps as RGB arrays keyed by name.

    Files are looked up under ``ASSETS_DIR_ABS/hdri/<name>.exr`` so loading does not
    depend on the process working directory. Missing or unreadable files are logged
    and skipped. On Spaces, maps are downsampled to 512x256.
    """
    maps: Dict[str, np.ndarray] = {}
    for name in names:
        exr_path = os.path.join(ASSETS_DIR_ABS, "hdri", f"{name}.exr")
        if not os.path.exists(exr_path):
            logger.warning("HDRI environment map not found, skipping: %s", exr_path)
            continue
        exr = cv2.imread(exr_path, cv2.IMREAD_UNCHANGED)
        if exr is None:
            logger.warning("Failed to read HDRI environment map, skipping: %s", exr_path)
            continue
        if RUNNING_ON_SPACES:
            exr = cv2.resize(exr, (512, 256), interpolation=cv2.INTER_AREA)
        # OpenCV loads channels as BGR; the renderer is fed RGB.
        maps[name] = cv2.cvtColor(exr, cv2.COLOR_BGR2RGB)
    return maps


def initialize_pipeline():
    """Load the selected TRELLIS runtime once and publish it through module globals.

    For ``TRELLIS_RUNTIME == "original"`` this loads TRELLIS-image-large (mesh +
    Gaussian formats). For ``"trellis2"`` it loads TRELLIS.2-4B, binds the TRELLIS.2
    render utils / ``EnvMap`` / ``o_voxel`` modules, and fills ``envmap`` with the
    HDRI arrays from ``_load_hdri_envmaps``. Calls after a successful load are no-ops.

    Raises:
        Exception: any loading error is logged, ``pipeline`` is reset to None and
            the error is re-raised.
    """
    global pipeline, envmap, trellis2_render_utils, EnvMap, o_voxel
    if pipeline is not None:
        return
    logger.info("Initializing runtime '%s'...", TRELLIS_RUNTIME)
    start_time = time.time()
    try:
        if TRELLIS_RUNTIME == "original":
            pipeline = TrellisImageTo3DPipeline.from_pretrained(
                "JeffreyXiang/TRELLIS-image-large",
                formats=["mesh", "gaussian"],
            )
            # Park the models on CPU when the pipeline offers it.
            if hasattr(pipeline, "_move_all_models_to_cpu"):
                pipeline._move_all_models_to_cpu()
            logger.info("Using original TRELLIS runtime.")
        else:
            # Imported lazily so the original runtime does not need trellis2/o_voxel.
            from trellis2.pipelines import Trellis2ImageTo3DPipeline
            from trellis2.renderers import EnvMap as Trellis2EnvMap
            from trellis2.utils import render_utils as trellis2_render_utils_mod
            import o_voxel as o_voxel_mod

            pipeline = Trellis2ImageTo3DPipeline.from_pretrained("microsoft/TRELLIS.2-4B")
            pipeline.low_vram = False
            pipeline._device = "cpu"
            trellis2_render_utils = trellis2_render_utils_mod
            EnvMap = Trellis2EnvMap
            o_voxel = o_voxel_mod
            envmap = _load_hdri_envmaps()
            logger.info("Using TRELLIS.2 runtime.")
        logger.info("Pipeline initialized in %.2fs.", time.time() - start_time)
    except Exception as e:
        logger.error("Failed to initialize pipeline: %s", e, exc_info=True)
        pipeline = None
        raise


initialize_pipeline()
def clear_cuda_cache() -> None:
    """Release PyTorch's unoccupied cached CUDA memory; does nothing without CUDA."""
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
def normalize_video_frames(frames: Any) -> List[np.ndarray]:
    """Convert renderer output into a list of ``(H, W, 3)`` ``uint8`` frames.

    Accepts ``None``, a single ``(H, W, C)`` array, a ``(N, H, W, C)`` array, or any
    iterable of array-likes. For each frame:

    * a fourth (alpha) channel is dropped; a single channel is repeated to RGB;
    * float data whose finite values lie in ``[0, 1]`` is scaled by 255, and float
      data in ``[-1, 1]`` is mapped through ``(x + 1) * 127.5``;
    * NaN becomes 0, +inf 255 and -inf 0 before the ``uint8`` cast (casting
      non-finite floats to integers is undefined);
    * values are clipped to ``[0, 255]``.

    Frames that are ``None`` or not 3-D are skipped.
    """
    if frames is None:
        return []

    def _normalize_frame(arr: np.ndarray) -> Optional[np.ndarray]:
        if arr is None or arr.ndim != 3:
            return None
        if arr.shape[-1] == 4:
            arr = arr[:, :, :3]
        elif arr.shape[-1] == 1:
            # Grayscale: replicate so RGB->BGR conversion and side-by-side stacking work.
            arr = np.repeat(arr, 3, axis=-1)
        if np.issubdtype(arr.dtype, np.floating):
            finite = np.isfinite(arr)
            if finite.any():
                finite_values = arr[finite]
                min_v = float(finite_values.min())
                max_v = float(finite_values.max())
                if min_v >= 0.0 and max_v <= 1.0:
                    arr = arr * 255.0
                elif min_v >= -1.0 and max_v <= 1.0:
                    arr = (arr + 1.0) * 127.5
            arr = np.nan_to_num(arr, nan=0.0, posinf=255.0, neginf=0.0)
        arr = np.clip(arr, 0.0, 255.0)
        return arr.astype(np.uint8)

    if isinstance(frames, np.ndarray):
        if frames.ndim == 4:
            return [nf for f in frames if (nf := _normalize_frame(f)) is not None]
        if frames.ndim == 3:
            nf = _normalize_frame(frames)
            return [nf] if nf is not None else []
        return []

    normalized = []
    for frame in frames:
        if frame is None:
            continue
        nf = _normalize_frame(np.asarray(frame))
        if nf is not None:
            normalized.append(nf)
    return normalized
def write_mp4(video_path: str, frames: List[np.ndarray], fps: int = 15) -> bool:
    """Encode RGB ``uint8`` frames to ``video_path`` as MP4.

    FFMPEG/libx264 through imageio is tried first. If it fails or produces an
    empty file, OpenCV's ``mp4v`` writer is used as a fallback. For OpenCV, frames
    are converted RGB->BGR and resized to the first frame's size when they differ.

    Returns:
        True once a non-empty file exists at ``video_path``; False when ``frames``
        is empty or both writers fail.
    """
    if not frames:
        return False
    try:
        with imageio.get_writer(
            video_path,
            format="FFMPEG",
            mode="I",
            fps=fps,
            codec="libx264",
            ffmpeg_params=["-preset", "ultrafast", "-movflags", "faststart"],
        ) as writer:
            for frame in frames:
                writer.append_data(frame)
        if os.path.exists(video_path) and os.path.getsize(video_path) > 0:
            return True
    except Exception as ffmpeg_err:
        logger.warning("FFMPEG video writer failed: %s", ffmpeg_err)

    cv_writer = None
    try:
        h, w = frames[0].shape[:2]
        cv_writer = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (w, h))
        if not cv_writer.isOpened():
            raise RuntimeError("OpenCV VideoWriter failed to open.")
        for frame in frames:
            if frame.shape[0] != h or frame.shape[1] != w:
                frame = cv2.resize(frame, (w, h), interpolation=cv2.INTER_AREA)
            cv_writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    except Exception as opencv_err:
        logger.error("OpenCV video writer failed: %s", opencv_err, exc_info=True)
        return False
    finally:
        # Always release the native writer, including when a frame write raised.
        if cv_writer is not None:
            cv_writer.release()
    # Checked after release() so the container has been finalised on disk.
    return os.path.exists(video_path) and os.path.getsize(video_path) > 0
def is_video_mostly_black(frames: List[np.ndarray], mean_threshold: float = 2.0) -> bool:
    """Heuristically decide whether a rendered clip is (nearly) all black.

    Looks at up to the first 8 non-None frames and compares the average of their
    per-frame mean pixel values against ``mean_threshold``. An empty input, or one
    with no usable frames, counts as black.
    """
    if not frames:
        return True
    frame_means = [float(np.mean(item)) for item in frames[:8] if item is not None]
    return not frame_means or sum(frame_means) / len(frame_means) < mean_threshold
def start_session(req: gr.Request):
    """Session-start hook: free cached CUDA memory and ensure output dirs exist.

    ``req`` is accepted for the Gradio callback signature but not used.
    """
    clear_cuda_cache()
    for directory in (TMP_DIR, EXPORT_DIR):
        os.makedirs(directory, exist_ok=True)
def end_session(req: gr.Request):
    """Session-end hook: only releases cached CUDA memory (``req`` is unused)."""
    return clear_cuda_cache()
def preprocess_image(image: Optional[Image.Image]) -> Optional[Image.Image]:
    """Run the loaded pipeline's ``preprocess_image`` on ``image``.

    Returns None when ``image`` is None or preprocessing raises (the error is logged).
    """
    if image is not None:
        try:
            return pipeline.preprocess_image(image)
        except Exception as e:
            logger.error("Error during image preprocessing: %s", e, exc_info=True)
    return None
def preprocess_images(images: List[Image.Image]) -> Optional[List[Image.Image]]:
    """Preprocess several images with the pipeline's ``preprocess_images``.

    Returns None for empty input, when preprocessing raises, or when any
    individual result is None. Failures are logged.
    """
    if not images:
        return None
    try:
        processed = pipeline.preprocess_images(images)
        has_failure = any(img is None for img in processed)
    except Exception as e:
        logger.error("Error during multi-image preprocessing: %s", e, exc_info=True)
        return None
    if has_failure:
        logger.error("One or more images failed preprocessing.")
        return None
    return processed
def get_seed(randomize_seed: bool, seed: int) -> int:
    """Return ``seed`` as given, or a random value in ``[0, MAX_SEED)`` if ``randomize_seed``."""
    if not randomize_seed:
        return seed
    return np.random.randint(0, MAX_SEED)
def encode_file_to_b64(file_path: str) -> str:
    """Return the raw bytes of ``file_path`` as a standard base64 ASCII string."""
    with open(file_path, "rb") as source:
        raw_bytes = source.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("ascii")
def decode_b64_to_file(file_b64: str, out_path: str) -> None:
    """Decode base64 text and write the bytes to ``out_path``, truncating any existing file."""
    with open(out_path, "wb") as sink:
        decoded_bytes = base64.b64decode(file_b64)
        sink.write(decoded_bytes)
def build_export_payload(glb_path: str, stl_path: Optional[str]) -> Dict[str, Any]:
    """Bundle exported model files as base64 text for ``materialize_export_payload``.

    Always contains ``glb_name`` and ``glb_b64``. ``stl_name`` and ``stl_b64`` are
    added only when ``stl_path`` is truthy. Names are file basenames only.
    """
    payload: Dict[str, Any] = {}
    payload["glb_name"] = os.path.basename(glb_path)
    payload["glb_b64"] = encode_file_to_b64(glb_path)
    if stl_path:
        payload.update(
            stl_name=os.path.basename(stl_path),
            stl_b64=encode_file_to_b64(stl_path),
        )
    return payload
def get_public_base_url(req: Optional[gr.Request]) -> str:
    """Best-effort public origin used to build absolute file URLs.

    Resolution order: the ``PUBLIC_BASE_URL`` override, the request's first
    ``x-forwarded-host`` value, its first ``host`` value, then ``SPACE_HOST``.
    Host-derived results always use ``https://``. Returns ``""`` when nothing is known.
    """
    if PUBLIC_BASE_URL:
        return PUBLIC_BASE_URL

    def _first(header_value: str) -> str:
        # Proxies may append comma-separated hops; the first entry is the client-facing host.
        return header_value.split(",")[0].strip()

    if req is not None:
        headers = req.headers or {}
        host = _first(headers.get("x-forwarded-host", "")) or _first(headers.get("host", ""))
        if host:
            return f"https://{host}"
    space_host = os.getenv("SPACE_HOST", "").strip()
    return f"https://{space_host}" if space_host else ""
def to_public_file_value(file_path: str, req: Optional[gr.Request]) -> str:
    """On Spaces with a known origin, return ``<origin>/file=<path>``; otherwise the path itself."""
    base_url = get_public_base_url(req)
    if not (RUNNING_ON_SPACES and base_url):
        return file_path
    return f"{base_url}/file={file_path}"
def materialize_export_payload(
    payload: Optional[Dict[str, Any]],
    req: gr.Request,
) -> Tuple[Optional[str], Dict[str, Any]]:
    """Write a base64 export payload back to timestamped files under EXPORT_DIR.

    Args:
        payload: dict shaped like ``build_export_payload`` output (``glb_b64``
            required; ``glb_name``, ``stl_name``, ``stl_b64`` optional), or None.
        req: Gradio request; not used here.

    Returns:
        ``(glb_path, stl_update)``: the absolute GLB path (None when ``payload`` is
        None) and a ``gr.update`` for the STL download, interactive only when a
        non-empty STL was written.

    Raises:
        gr.Error: when decoding/writing fails or the GLB is empty (cause chained).
    """
    if payload is None:
        return None, gr.update(value=None, visible=True, interactive=False)

    def _with_suffix(name: str, suffix: str) -> str:
        # Case-insensitive check so "MODEL.GLB" is not turned into "MODEL.GLB.glb".
        return name if name.lower().endswith(suffix) else f"{name}{suffix}"

    try:
        os.makedirs(EXPORT_DIR, exist_ok=True)
        stamp = datetime.now().strftime("%Y-%m%d-%H%M%S-%f")

        glb_name = _with_suffix(str(payload.get("glb_name", "model.glb")), ".glb")
        glb_path_abs = os.path.abspath(os.path.join(EXPORT_DIR, f"{stamp}-{glb_name}"))
        decode_b64_to_file(str(payload["glb_b64"]), glb_path_abs)
        glb_size = os.path.getsize(glb_path_abs)
        if glb_size <= 0:
            raise RuntimeError("Materialized GLB file is empty.")

        stl_path_abs = None
        stl_b64 = payload.get("stl_b64")
        if stl_b64:
            stl_name = _with_suffix(str(payload.get("stl_name", "model.stl")), ".stl")
            stl_path_abs = os.path.abspath(os.path.join(EXPORT_DIR, f"{stamp}-{stl_name}"))
            decode_b64_to_file(str(stl_b64), stl_path_abs)
            if os.path.getsize(stl_path_abs) <= 0:
                logger.warning("Materialized STL is empty; hiding download.")
                stl_path_abs = None

        logger.info("Materialized GLB for UI: %s (%d bytes)", glb_path_abs, glb_size)
        if stl_path_abs:
            logger.info("Materialized STL for UI: %s (%d bytes)", stl_path_abs, os.path.getsize(stl_path_abs))
        stl_update = gr.update(value=stl_path_abs or None, visible=True, interactive=bool(stl_path_abs))
        return glb_path_abs, stl_update
    except Exception as e:
        logger.error("Materializing export payload failed: %s", e, exc_info=True)
        raise gr.Error("Model generated, but packaging for display failed. Please retry.") from e
def export_stl_from_glb(glb_path: str, target_size_mm: float = 152.4) -> Optional[str]:
    """Convert a GLB file into a print-oriented STL written under EXPORT_DIR.

    All triangle-mesh geometry is concatenated and rotated +90 degrees about the
    X axis (presumably Y-up GLB -> Z-up for slicers; confirm). It is then scaled
    uniformly so its largest bounding-box extent equals ``target_size_mm``
    (default 152.4, i.e. 6 inches when read as millimetres).

    STL output is optional, so failures do not abort the caller: they are logged
    and None is returned, and the GLB result is kept.

    Returns:
        The STL path, or None when the GLB has no faces or the conversion fails.
    """
    try:
        mesh_data = trimesh.load_mesh(glb_path, force="mesh")
        mesh_to_export = None
        if isinstance(mesh_data, trimesh.Scene):
            valid = [
                g
                for g in mesh_data.geometry.values()
                if isinstance(g, trimesh.Trimesh) and g.vertices is not None and len(g.vertices) > 0
            ]
            if valid:
                combined_mesh = trimesh.util.concatenate(valid)
                if isinstance(combined_mesh, trimesh.Trimesh) and len(combined_mesh.vertices) > 0:
                    mesh_to_export = combined_mesh
        elif isinstance(mesh_data, trimesh.Trimesh) and len(mesh_data.vertices) > 0:
            mesh_to_export = mesh_data

        if mesh_to_export is None or mesh_to_export.faces is None or len(mesh_to_export.faces) == 0:
            logger.warning("No mesh faces found in %s; skipping STL export.", glb_path)
            return None

        mesh_to_export = mesh_to_export.copy()
        rot_x_90 = trimesh.transformations.rotation_matrix(np.deg2rad(90), [1, 0, 0])
        mesh_to_export.apply_transform(rot_x_90)
        bbox = mesh_to_export.bounds
        current_size = (bbox[1] - bbox[0]).max()
        if current_size > 0:
            mesh_to_export.vertices *= target_size_mm / current_size
        current_time_stl = datetime.now().strftime("%Y-%m%d-%H%M%S-%f")
        stl_path = os.path.join(EXPORT_DIR, f"{current_time_stl}.stl")
        mesh_to_export.export(stl_path)
    except Exception as e:
        logger.error("STL export failed for %s: %s", glb_path, e, exc_info=True)
        return None
    logger.info("STL exported: %s", stl_path)
    return stl_path
def get_preview_settings(req: Optional[gr.Request]) -> Tuple[bool, int, int, int]:
    """Choose preview video parameters from the client's User-Agent.

    Returns ``(is_mobile, resolution, fps, seconds)``. A User-Agent containing
    android/iphone/ipad/mobile gets 256 px frames, otherwise 384 px. Both use
    15 fps for 3 seconds.
    """
    headers = req.headers if req else {}
    user_agent = headers.get("User-Agent", "").lower()
    mobile_markers = ("android", "iphone", "ipad", "mobile")
    is_mobile = any(marker in user_agent for marker in mobile_markers)
    return is_mobile, (256 if is_mobile else 384), 15, 3
def pack_original_state(outputs: Dict[str, Any]) -> Dict[str, Any]:
    """Snapshot the first mesh (and first Gaussian, if any) of TRELLIS outputs.

    Tensors are detached and copied to CPU numpy arrays, so the result holds only
    numpy arrays and plain Python values. ``unpack_original_state`` rebuilds the
    objects from it.
    """
    def _to_np(tensor):
        return tensor.detach().cpu().numpy()

    def _to_np_or_none(tensor):
        return None if tensor is None else _to_np(tensor)

    mesh = outputs["mesh"][0]
    packed: Dict[str, Any] = {
        "runtime": "original",
        "mesh": {
            "vertices": _to_np(mesh.vertices),
            "faces": _to_np(mesh.faces),
            "vertex_attrs": _to_np_or_none(mesh.vertex_attrs),
            "res": int(getattr(mesh, "res", 64)),
        },
    }

    gaussian_outputs = outputs.get("gaussian", None)
    if not gaussian_outputs or gaussian_outputs[0] is None:
        return packed
    gaussian = gaussian_outputs[0]
    packed["gaussian"] = {
        "init_params": dict(getattr(gaussian, "init_params", {})),
        "xyz": _to_np(gaussian._xyz),
        "features_dc": _to_np(gaussian._features_dc),
        "features_rest": _to_np_or_none(gaussian._features_rest),
        "scaling": _to_np(gaussian._scaling),
        "rotation": _to_np(gaussian._rotation),
        "opacity": _to_np(gaussian._opacity),
    }
    return packed
def unpack_original_state(state: Dict[str, Any]) -> Dict[str, Any]:
    """Inverse of ``pack_original_state``: rebuild TRELLIS mesh/Gaussian objects.

    Tensors are placed on CUDA when available, otherwise CPU. Faces become int64
    and everything else float32. The result has a ``"mesh"`` key and, when the
    state carried one, a ``"gaussian"`` key.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _tensor(array, dtype=torch.float32):
        return torch.tensor(array, device=device, dtype=dtype)

    mesh_info = state["mesh"]
    attrs = mesh_info.get("vertex_attrs")
    restored: Dict[str, Any] = {
        "mesh": MeshExtractResult(
            vertices=_tensor(mesh_info["vertices"]),
            faces=_tensor(mesh_info["faces"], torch.int64),
            vertex_attrs=None if attrs is None else _tensor(attrs),
            res=int(mesh_info.get("res", 64)),
        )
    }

    gauss_info = state.get("gaussian")
    if gauss_info is None:
        return restored
    gaussian = Gaussian(device=device.type, **gauss_info["init_params"])
    gaussian._xyz = _tensor(gauss_info["xyz"])
    gaussian._features_dc = _tensor(gauss_info["features_dc"])
    features_rest = gauss_info["features_rest"]
    gaussian._features_rest = None if features_rest is None else _tensor(features_rest)
    gaussian._scaling = _tensor(gauss_info["scaling"])
    gaussian._rotation = _tensor(gauss_info["rotation"])
    gaussian._opacity = _tensor(gauss_info["opacity"])
    restored["gaussian"] = gaussian
    return restored
def pack_trellis2_state(mesh: Any, grid_size: int) -> Dict[str, Any]:
    """Snapshot a TRELLIS.2 mesh result as CPU numpy arrays plus plain lists.

    The ``layout`` slices are stored as ``[start, stop]`` pairs, which
    ``unpack_trellis2_state`` turns back into ``slice`` objects.
    """
    def _to_np(tensor):
        return tensor.detach().cpu().numpy()

    mesh_payload = {
        "vertices": _to_np(mesh.vertices),
        "faces": _to_np(mesh.faces),
        "attrs": _to_np(mesh.attrs),
        "coords": _to_np(mesh.coords),
        "voxel_shape": list(mesh.voxel_shape),
        "layout": {name: [span.start, span.stop] for name, span in mesh.layout.items()},
    }
    return {"runtime": "trellis2", "mesh": mesh_payload, "grid_size": grid_size}
def unpack_trellis2_state(state: Dict[str, Any]) -> Dict[str, Any]:
    """Inverse of ``pack_trellis2_state``: tensors and slices ready for ``o_voxel.postprocess.to_glb``.

    Tensors go to CUDA when available, else CPU. Faces and coords become int32;
    vertices and attrs float32.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    info = state["mesh"]
    attr_layout = {name: slice(span[0], span[1]) for name, span in info["layout"].items()}

    def _on_device(key, dtype):
        return torch.tensor(info[key], device=device, dtype=dtype)

    return {
        "vertices": _on_device("vertices", torch.float32),
        "faces": _on_device("faces", torch.int32),
        "attrs": _on_device("attrs", torch.float32),
        "coords": _on_device("coords", torch.int32),
        "voxel_shape": torch.Size(info["voxel_shape"]),
        "attr_layout": attr_layout,
        "grid_size": int(state["grid_size"]),
    }
def render_original_preview(outputs: Dict[str, Any], req: gr.Request) -> Optional[str]:
    """Render a color + normal turntable preview MP4 for original-TRELLIS outputs.

    Color frames come from the Gaussian when one exists and
    diff_gaussian_rasterization is installed, otherwise from the mesh. A
    ModuleNotFoundError for the rasterizer at render time also falls back to the
    mesh. Mostly-black color frames are replaced by the normal frames. Frames are
    stacked vertically on mobile clients and side by side elsewhere.

    Returns:
        Path of the MP4 under TMP_DIR, or None when rendering yields no frames or
        encoding fails.
    """
    is_mobile, resolution, fps, seconds = get_preview_settings(req)
    num_frames = seconds * fps
    mesh = outputs["mesh"][0]
    gaussian_outputs = outputs.get("gaussian", None)
    gaussian = gaussian_outputs[0] if gaussian_outputs and gaussian_outputs[0] is not None else None

    def _render_color_and_normal(color_source: Any) -> Tuple[Any, Any]:
        """Render color (from ``color_source``) and normals (from the mesh) in two threads."""
        render_kwargs = dict(resolution=resolution, num_frames=num_frames, verbose=False)
        with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
            future_color = executor.submit(
                trellis_render_utils.render_video, color_source, mode="color", **render_kwargs
            )
            future_normal = executor.submit(
                trellis_render_utils.render_video, mesh, mode="normal", **render_kwargs
            )
            return future_color.result(), future_normal.result()

    # Skip the Gaussian attempt when the rasterizer is already known to be missing.
    use_gaussian = gaussian is not None and HAS_DIFF_GAUSSIAN_RASTERIZATION
    try:
        color_result, normal_result = _render_color_and_normal(gaussian if use_gaussian else mesh)
    except ModuleNotFoundError as e:
        if not use_gaussian or "diff_gaussian_rasterization" not in str(e):
            raise
        logger.warning("Gaussian renderer unavailable; falling back to mesh-only preview color.")
        color_result, normal_result = _render_color_and_normal(mesh)

    color_frames = color_result.get("color", []) if color_result else []
    normal_frames = normal_result.get("normal", []) if normal_result else []
    if not color_frames or not normal_frames:
        logger.warning("Preview rendering returned no frames.")
        return None
    # Some Spaces builds miss gaussian rasterization and mesh color can be all black.
    # In that case, keep the left pane informative by mirroring normals.
    if is_video_mostly_black(color_frames):
        logger.warning("Preview color frames are mostly black; falling back to normal frames for left pane.")
        color_frames = normal_frames

    stack_axis = 0 if is_mobile else 1
    combined = [
        np.concatenate([color_frame, normal_frame], axis=stack_axis)
        for color_frame, normal_frame in zip(color_frames, normal_frames)
    ]
    # Microseconds keep concurrent sessions from overwriting each other's preview file.
    current_time = datetime.now().strftime("%Y-%m%d-%H%M%S-%f")
    video_path = os.path.join(TMP_DIR, f"{current_time}.mp4")
    return video_path if write_mp4(video_path, combined, fps=fps) else None
def render_trellis2_preview(mesh: Any, req: gr.Request) -> Optional[str]:
    """Render a turntable preview MP4 for a TRELLIS.2 mesh.

    Uses the "sunset" HDRI when loaded, else the first available one. With no
    HDRIs an empty envmap dict is passed to the renderer. Frame source priority:
    "shaded", then the first "shaded_*" key, then "color"; normals are written
    only when no color-like frames exist.

    Returns:
        Path of the MP4 under TMP_DIR, or None when no frames were produced or
        encoding fails.
    """
    _is_mobile, resolution, fps, seconds = get_preview_settings(req)
    num_frames = seconds * fps

    available = envmap or {}
    chosen_name = "sunset" if "sunset" in available else next(iter(available), None)
    if chosen_name is not None:
        # Only the HDRI actually used is uploaded to the GPU.
        preview_envmap: Any = EnvMap(
            torch.tensor(available[chosen_name], dtype=torch.float32, device="cuda")
        )
    else:
        preview_envmap = {}
    vid_result = trellis2_render_utils.render_video(
        mesh,
        resolution=resolution,
        num_frames=num_frames,
        r=2,
        fov=36,
        envmap=preview_envmap,
    )

    shaded_frames = vid_result.get("shaded")
    if shaded_frames is None:
        shaded_keys = [k for k in vid_result.keys() if k.startswith("shaded_")]
        if shaded_keys:
            shaded_frames = vid_result[shaded_keys[0]]
    color_frames = normalize_video_frames(shaded_frames if shaded_frames is not None else vid_result.get("color", []))
    normal_frames = normalize_video_frames(vid_result.get("normal", []))
    if not color_frames and not normal_frames:
        return None

    # Microseconds keep concurrent sessions from overwriting each other's preview file.
    current_time = datetime.now().strftime("%Y-%m%d-%H%M%S-%f")
    video_path = os.path.join(TMP_DIR, f"{current_time}.mp4")
    ok = write_mp4(video_path, color_frames if color_frames else normal_frames, fps=fps)
    return video_path if ok else None
def _run_original_pipeline(
    image: Image.Image,
    seed: int,
    ss_guidance_strength: float,
    ss_sampling_steps: int,
    slat_guidance_strength: float,
    slat_sampling_steps: int,
) -> Dict[str, Any]:
    """Run original TRELLIS on one already-preprocessed image (mesh + Gaussian outputs)."""
    sparse_params = {"steps": ss_sampling_steps, "cfg_strength": ss_guidance_strength}
    slat_params = {"steps": slat_sampling_steps, "cfg_strength": slat_guidance_strength}
    return pipeline.run(
        image,
        seed=seed,
        formats=["mesh", "gaussian"],
        preprocess_image=False,
        sparse_structure_sampler_params=sparse_params,
        slat_sampler_params=slat_params,
    )
def _run_original_pipeline_multi(
    images: List[Image.Image],
    seed: int,
    ss_guidance_strength: float,
    ss_sampling_steps: int,
    slat_guidance_strength: float,
    slat_sampling_steps: int,
) -> Dict[str, Any]:
    """Run original TRELLIS multi-image generation in "stochastic" mode.

    Images must already be preprocessed. NOTE(review): no live caller is visible
    in this chunk (the multi-image handler below is commented out).
    """
    sparse_params = {"steps": ss_sampling_steps, "cfg_strength": ss_guidance_strength}
    slat_params = {"steps": slat_sampling_steps, "cfg_strength": slat_guidance_strength}
    return pipeline.run_multi_image(
        images,
        seed=seed,
        formats=["mesh", "gaussian"],
        preprocess_image=False,
        mode="stochastic",
        sparse_structure_sampler_params=sparse_params,
        slat_sampler_params=slat_params,
    )
def _run_trellis2_pipeline(
    image: Image.Image,
    seed: int,
    ss_guidance_strength: float,
    ss_sampling_steps: int,
    slat_guidance_strength: float,
    slat_sampling_steps: int,
) -> Tuple[Any, int]:
    """Run TRELLIS.2 on one preprocessed image and return ``(mesh, grid_size)``.

    The pipeline is moved to CUDA first. On Spaces the "512" pipeline (grid 512)
    is used, elsewhere "1024_cascade" (grid 1024). The SLAT steps/guidance drive
    both the shape and texture samplers.

    Raises:
        RuntimeError: when the o_voxel module was not loaded.
    """
    if o_voxel is None:
        raise RuntimeError("TRELLIS.2 runtime requires o_voxel, but it is not available.")
    pipeline.cuda()
    pipeline_type, grid_size = ("512", 512) if RUNNING_ON_SPACES else ("1024_cascade", 1024)
    shape_params = {"steps": slat_sampling_steps, "guidance_strength": slat_guidance_strength}
    outputs = pipeline.run(
        image,
        seed=seed,
        preprocess_image=False,
        sparse_structure_sampler_params={
            "steps": ss_sampling_steps,
            "guidance_strength": ss_guidance_strength,
        },
        shape_slat_sampler_params=shape_params,
        # Separate dict object, same contents as the shape sampler's parameters.
        tex_slat_sampler_params=dict(shape_params),
        pipeline_type=pipeline_type,
        return_latent=False,
    )
    return outputs[0], grid_size
def _extract_model_payload_from_state(
    state: Dict[str, Any],
    mesh_simplify: float,
    texture_size: int,
    progress=gr.Progress(track_tqdm=True),
) -> Optional[Dict[str, Any]]:
    """Bake a packed generation state into a GLB (plus optional STL) payload.

    Args:
        state: output of ``pack_original_state`` or ``pack_trellis2_state``,
            dispatched on ``state["runtime"]``.
        mesh_simplify: original runtime: passed as ``to_glb(simplify=...)``.
            TRELLIS.2: sets ``decimation_target = max(100000, int((1 - mesh_simplify) * 500000))``.
        texture_size: baked texture resolution given to the GLB exporter.
        progress: Gradio progress tracker, or None to skip progress updates
            (``generate_and_extract`` passes None when running this in a worker thread).

    Returns:
        The dict produced by ``build_export_payload``.

    Raises:
        gr.Error: when o_voxel is missing for TRELLIS.2, or when the exported GLB
            cannot be loaded back or contains no vertices.
    """
    if state.get("runtime") == "original":
        original_state = unpack_original_state(state)
        mesh = original_state["mesh"]
        app_rep = original_state.get("gaussian")
        if app_rep is None or not HAS_DIFF_GAUSSIAN_RASTERIZATION:
            if app_rep is not None and not HAS_DIFF_GAUSSIAN_RASTERIZATION:
                logger.warning(
                    "diff_gaussian_rasterization unavailable; using mesh appearance for GLB baking."
                )
            app_rep = mesh
        glb = trellis_postprocessing_utils.to_glb(
            app_rep,
            mesh,
            simplify=mesh_simplify,
            fill_holes=True,
            texture_size=texture_size,
            verbose=False,
        )
    else:
        if o_voxel is None:
            raise gr.Error("TRELLIS.2 runtime requires o_voxel, but it is not available.")
        mesh_state = unpack_trellis2_state(state)
        decimation_target = max(100000, int((1.0 - mesh_simplify) * 500000))
        glb = o_voxel.postprocess.to_glb(
            vertices=mesh_state["vertices"],
            faces=mesh_state["faces"],
            attr_volume=mesh_state["attrs"],
            coords=mesh_state["coords"],
            attr_layout=mesh_state["attr_layout"],
            grid_size=mesh_state["grid_size"],
            aabb=[[-0.5, -0.5, -0.5], [0.5, 0.5, 0.5]],
            decimation_target=decimation_target,
            texture_size=texture_size,
            remesh=not RUNNING_ON_SPACES,
            remesh_band=1,
            remesh_project=0,
            use_tqdm=False,
        )

    if progress is not None:
        progress(0.84, desc="Exporting GLB...")
    current_time_glb = datetime.now().strftime("%Y-%m%d-%H%M%S-%f")
    glb_path_abs = os.path.abspath(os.path.join(EXPORT_DIR, f"{current_time_glb}.glb"))
    glb.export(glb_path_abs)
    logger.info("GLB exported: %s (%d bytes)", glb_path_abs, os.path.getsize(glb_path_abs))

    try:
        loaded = trimesh.load(glb_path_abs, force="scene")
        # force="scene" wraps even a single mesh in a Scene, so inspect its geometry.
        geometries = list(loaded.geometry.values()) if isinstance(loaded, trimesh.Scene) else [loaded]
        if not geometries:
            raise RuntimeError("Exported GLB contains no geometry.")
        has_vertices = any(
            getattr(g, "vertices", None) is not None and len(g.vertices) > 0 for g in geometries
        )
        if not has_vertices:
            raise RuntimeError("Exported GLB mesh is empty.")
    except Exception as validate_err:
        logger.error("GLB validation failed: %s", validate_err, exc_info=True)
        raise gr.Error("Model was exported but failed validation for viewer rendering.") from validate_err

    stl_start = time.time()
    stl_path = export_stl_from_glb(glb_path_abs)
    if stl_path is not None:
        stl_path = os.path.abspath(stl_path)
    logger.info("STL Export Time: %.2fs", time.time() - stl_start)
    logger.info("Returning model file: %s", glb_path_abs)
    logger.info("Returning STL file: %s", stl_path)
    return build_export_payload(glb_path_abs, stl_path)
def generate_and_extract(
    image: Optional[Image.Image],
    seed: int,
    ss_guidance_strength: float,
    ss_sampling_steps: int,
    slat_guidance_strength: float,
    slat_sampling_steps: int,
    mesh_simplify: float,
    texture_size: int,
    req: gr.Request,
    progress=gr.Progress(track_tqdm=True),
):
    """Single GPU session: generate preview + extract model (generator).

    Combining both phases into one @spaces.GPU call halves ZeroGPU quota
    usage (one 40s reservation instead of two). Yields the preview video
    immediately so the user sees it while model extraction continues.

    NOTE(review): no ``@spaces.GPU`` decorator is visible on this function in
    this chunk; confirm it is applied where the handler is registered.

    Yields:
        ``(video_path, state, export_payload, gr.update())``. First the preview
        with ``export_payload=None``, then the same tuple with the payload from
        ``_extract_model_payload_from_state``. A single
        ``(None, None, None, gr.update())`` is yielded when there is no image or
        no pipeline.

    Raises:
        gr.Error: user-facing failures. ``gr.Error`` raised by helpers (such as
            GLB validation) propagates unchanged so its specific message is shown.
    """
    if image is None or pipeline is None:
        yield None, None, None, gr.update()
        return
    if not torch.cuda.is_available():
        raise gr.Error("GPU is not ready. Please retry in a few seconds.")
    job_start = time.time()
    preprocess_start = time.time()
    image = preprocess_image(image)
    if image is None:
        raise gr.Error("Image preprocessing failed. Please try a different image.")
    logger.info("Preprocess Time: %.2fs", time.time() - preprocess_start)
    progress(0.1, desc=f"Preprocessed image in {time.time() - preprocess_start:.1f}s")
    try:
        # --- Phase A: generate the 3D representation ---
        outputs = None
        mesh = None
        grid_size = 512
        inference_start = time.time()
        progress(0.15, desc="Generating 3D structure...")
        if TRELLIS_RUNTIME == "original":
            outputs = _run_original_pipeline(
                image,
                seed,
                ss_guidance_strength,
                ss_sampling_steps,
                slat_guidance_strength,
                slat_sampling_steps,
            )
        else:
            mesh, grid_size = _run_trellis2_pipeline(
                image,
                seed,
                ss_guidance_strength,
                ss_sampling_steps,
                slat_guidance_strength,
                slat_sampling_steps,
            )
        logger.info("Inference Time: %.2fs", time.time() - inference_start)
        preview_start = time.time()
        progress(0.45, desc=f"Rendering preview... {time.time() - job_start:.1f}s")
        # Pack state immediately so both tasks can proceed in parallel
        if TRELLIS_RUNTIME == "original":
            state = pack_original_state(outputs)
        else:
            state = pack_trellis2_state(mesh, grid_size)
        # Launch preview rendering and model extraction concurrently.
        # CUDA ops from different CPU threads safely serialize on the GPU.
        with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool:
            if TRELLIS_RUNTIME == "original":
                future_video = pool.submit(render_original_preview, outputs, req)
            else:
                future_video = pool.submit(render_trellis2_preview, mesh, req)
            # progress=None: the worker thread does not report Gradio progress.
            future_extract = pool.submit(
                _extract_model_payload_from_state, state, mesh_simplify, texture_size, None,
            )
            # Yield the video as soon as it's ready (model extraction continues)
            video_path = future_video.result()
            logger.info("Preview Render Time: %.2fs", time.time() - preview_start)
            yield video_path, state, None, gr.update()
            # Wait for model extraction to finish
            export_payload = future_extract.result()
            logger.info("Phase A+B Parallel Time: %.2fs", time.time() - preview_start)
            logger.info("Total Time: %.2fs", time.time() - job_start)
            yield video_path, state, export_payload, gr.update()
    except gr.Error:
        # Already user-facing; do not mask it with a generic message.
        raise
    except RuntimeError as err:
        if "out of memory" in str(err).lower():
            raise gr.Error("GPU out of memory. Try reducing texture size and retry.") from err
        raise gr.Error("Generation failed. Try another image or lower complexity.") from err
    except Exception as err:
        raise gr.Error("Generation failed. Please retry.") from err
    finally:
        clear_cuda_cache()
| # @spaces.GPU(duration=30) | |
| # def generate_preview_multi( | |
| # gallery_images: Optional[List], | |
| # seed: int, | |
| # ss_guidance_strength: float, | |
| # ss_sampling_steps: int, | |
| # slat_guidance_strength: float, | |
| # slat_sampling_steps: int, | |
| # req: gr.Request, | |
| # progress=gr.Progress(track_tqdm=True), | |
| # ): | |
| # """Generate a 3D preview from multiple input images (Trellis 1 only).""" | |
| # if not gallery_images or pipeline is None: | |
| # return None, None | |
| # if TRELLIS_RUNTIME != "original": | |
| # raise gr.Error("Multi-image generation is only supported with the original TRELLIS runtime.") | |
| # if not torch.cuda.is_available(): | |
| # raise gr.Error("GPU is not ready. Please retry in a few seconds.") | |
| # # Extract PIL images from Gradio gallery format | |
| # pil_images: List[Image.Image] = [] | |
| # for item in gallery_images: | |
| # if isinstance(item, Image.Image): | |
| # pil_images.append(item) | |
| # elif isinstance(item, tuple): | |
| # pil_images.append(item[0] if isinstance(item[0], Image.Image) else Image.open(item[0])) | |
| # elif isinstance(item, str): | |
| # pil_images.append(Image.open(item)) | |
| # elif isinstance(item, dict) and "name" in item: | |
| # pil_images.append(Image.open(item["name"])) | |
| # if len(pil_images) < 2: | |
| # raise gr.Error("Please upload at least 2 images for multi-image generation.") | |
| # logger.info("Multi-image generation with %d images", len(pil_images)) | |
| # job_start = time.time() | |
| # preprocess_start = time.time() | |
| # processed_images = preprocess_images(pil_images) | |
| # if processed_images is None: | |
| # raise gr.Error("Image preprocessing failed. Please try different images.") | |
| # logger.info("Multi-image Preprocess Time: %.2fs", time.time() - preprocess_start) | |
| # progress(0.1, desc=f"Preprocessed {len(processed_images)} images in {time.time() - preprocess_start:.1f}s") | |
| # try: | |
| # inference_start = time.time() | |
| # progress(0.2, desc=f"Generating 3D structure from {len(processed_images)} views...") | |
| # outputs = _run_original_pipeline_multi( | |
| # processed_images, | |
| # seed, | |
| # ss_guidance_strength, | |
| # ss_sampling_steps, | |
| # slat_guidance_strength, | |
| # slat_sampling_steps, | |
| # ) | |
| # if outputs is None: | |
| # raise gr.Error("Multi-image pipeline returned no results. Please retry.") | |
| # logger.info("Multi-image Inference Time: %.2fs", time.time() - inference_start) | |
| # preview_start = time.time() | |
| # progress(0.58, desc=f"Rendering preview... {time.time() - job_start:.1f}s") | |
| # video_path = render_original_preview(outputs, req) | |
| # state = pack_original_state(outputs) | |
| # logger.info("Preview Render Time: %.2fs", time.time() - preview_start) | |
| # logger.info("Multi-Image Phase A Total Time: %.2fs", time.time() - job_start) | |
| # return video_path, state | |
| # except RuntimeError as re: | |
| # if "out of memory" in str(re).lower(): | |
| # raise gr.Error("GPU out of memory. Try fewer images or lower complexity.") | |
| # raise gr.Error("Multi-image generation failed. Try different images.") from re | |
| # except Exception as e: | |
| # raise gr.Error("Multi-image generation failed. Please retry.") from e | |
| # finally: | |
| # clear_cuda_cache() | |
# Custom stylesheet handed to gr.Blocks(css=...). It centers headings, hides the Gradio
# footer and toast notifications, and sizes the input-image, preview-video and 3D-viewer
# containers (taller above a 768px viewport width). It also styles the sponsor banner and
# the full-width STL download button, and hides progress text inside the 3D model viewer.
css = """
h1, h2, h3 { text-align: center; display: block; }
h1 a {color: #5A11FF !important; text-decoration: none !important;}
footer { visibility: hidden; }
.gradio-container { max-width: 1100px !important; }
.gr-image-container { display: flex !important; justify-content: center !important; align-items: center !important; width: 100%; height: 240px; }
.gr-image-container img { width: 100%; height: 100%; object-fit: contain; object-position: center; }
.gr-image { display: flex; justify-content: center; align-items: center; width: 100%; height: 512px; overflow: hidden; }
.gr-image img { width: 100%; height: 100%; object-fit: cover; object-position: center; }
.video-container video { width: 100% !important; height: 100% !important; object-fit: contain !important; object-position: center !important; }
.sponsor-banner { text-align: center; margin: 8px 0 14px 0; }
.sponsor-banner-title { font-size: 1.05rem; font-weight: 700; margin-bottom: 8px; }
.sponsor-banner-title a {
color: #5A11FF !important;
text-decoration: none !important;
}
.sponsor-banner-button {
display: inline-block;
padding: 8px 14px;
border-radius: 10px;
font-weight: 700;
text-decoration: none !important;
background: linear-gradient(90deg, #2f6bff 0%, #7d4dff 100%);
color: #ffffff !important;
}
.toast-wrap, .toast-body, .toast-container { display: none !important; }
.model-container .progress-text, .model-container .progress-level { display: none !important; }
@media screen and (min-width: 768px) {
.gr-image-container { height: 360px !important; }
.video-container { height: 360px !important; max-width: 680px !important; margin: 0 auto !important; aspect-ratio: auto !important; }
.model-container { height: 480px !important; max-width: 680px !important; margin: 0 auto !important; }
}
.custom-header { display: flex; align-items: center; height: 100%; }
.stl-download-btn { max-width: 680px !important; margin: 0 auto !important; }
.stl-download-btn button { width: 100% !important; background: #5A11FF !important; border-color: #5A11FF !important; color: #fff !important; }
"""
# schema.org "SoftwareApplication" metadata for this app. It is serialized as JSON-LD
# into a <script type="application/ld+json"> tag in the page <head> (see head_html).
# Dict insertion order is significant: it fixes the key order of the emitted JSON.
_schema_author = {
    "@type": "Organization",
    "name": "SkyeBrowse",
    "url": "https://www.skyebrowse.com",
}
_schema_offer = {
    "@type": "Offer",
    "price": "0",
    "priceCurrency": "USD",
}
schema_data = {
    "@context": "https://schema.org",
    "@type": "SoftwareApplication",
    "name": "Pocket 3D AI by SkyeBrowse",
    "operatingSystem": "Web",
    "applicationCategory": "MultimediaApplication",
    "description": (
        "Instant AI-powered 3D model generation from a single image. "
        "Upload a photo and get a downloadable GLB and STL model in seconds."
    ),
    "author": _schema_author,
    "offers": _schema_offer,
}

# Extra markup injected into <head> via gr.Blocks(head=...): the JSON-LD block above,
# a canonical URL, a meta description and Open Graph tags for link previews.
_schema_json = json.dumps(schema_data)
head_html = f"""
<script type="application/ld+json">
{_schema_json}
</script>
<link rel="canonical" href="https://3dai.skyebrowse.com/" />
<meta name="description" content="Generate 3D models from a single image with SkyeBrowse Pocket 3D AI. Download GLB and STL files instantly.">
<meta property="og:title" content="Pocket 3D AI | Powered by SkyeBrowse">
<meta property="og:type" content="website">
<meta property="og:url" content="https://3dai.skyebrowse.com/">
<meta property="og:image" content="https://www.skyebrowse.com/logo.png">
"""
# Client-side script passed to gr.Blocks(js=...); it runs once when the page loads.
#  1. A MutationObserver hides progress/ETA elements whose text mentions "zero gpu".
#  2. When the hostname ends with "app.skyebrowse.com", links to sibling SkyeBrowse sites
#     are rewritten to same-origin paths (linkMap). This re-runs on every DOM mutation.
# Each anchor is rewritten at most once (tracked via data-sb-link-rewritten). Without that
# guard, when the page is served from app.skyebrowse.com itself, an already-rewritten link
# such as origin + '/interior-ai' still has hostname 'app.skyebrowse.com'. That hostname is a
# linkMap key, so the next mutation would turn every link into origin + '/app'.
custom_js = """
() => {
    new MutationObserver(() => {
        document.querySelectorAll('.progress-text, .eta-bar, .progress-level-inner').forEach(el => {
            if (el.textContent.match(/zero\\s*gpu/i)) {
                el.style.visibility = 'hidden';
            }
        });
    }).observe(document.body, {childList: true, subtree: true, characterData: true});
    // Rewrite external app links when hosted on *.app.skyebrowse.com
    const hostname = window.location.hostname;
    if (hostname.endsWith('app.skyebrowse.com')) {
        const origin = window.location.origin;
        const linkMap = {
            'interiorai.skyebrowse.com': origin + '/interior-ai',
            'anime.skyebrowse.com': origin + '/anime-ai',
            '3dai.skyebrowse.com': origin + '/3d-ai',
            'app.skyebrowse.com': origin + '/app',
            'www.skyebrowse.com': origin,
        };
        function rewriteLinks() {
            document.querySelectorAll('a[href]').forEach(a => {
                // Already rewritten: its new same-origin href may itself match a linkMap key.
                if (a.dataset.sbLinkRewritten) {
                    return;
                }
                try {
                    const url = new URL(a.href);
                    if (linkMap[url.hostname]) {
                        a.href = linkMap[url.hostname];
                        a.dataset.sbLinkRewritten = '1';
                    }
                } catch(e) {}
            });
        }
        rewriteLinks();
        new MutationObserver(rewriteLinks).observe(document.body, {childList: true, subtree: true});
    }
}
"""
# Top-level Gradio UI: header/links, single-image input + preview video, examples + 3D
# viewer + STL download, (optionally hidden) generation/export settings, and the event
# chain that runs generation whenever an image is uploaded or an example is picked.
with gr.Blocks(theme="Taithrah/Minimal", css=css, js=custom_js, head=head_html, title="Pocket 3D AI | SkyeBrowse") as demo:
    default_ss_steps = 25
    default_slat_steps = 25
    default_texture_size = 2048
    texture_min = 512
    texture_max = 4096
    # Slider values lie on the grid texture_min + k * texture_step. The step must divide
    # both (default - min) and (max - min) so 2048 and 4096 are reachable; the previous
    # step of 1024 only allowed 512/1536/2560/3584.
    texture_step = 512
    # with gr.Row(equal_height=True):
    # gr.Image("assets/sb_pocket_logo_dark.png", show_label=False, container=False, show_download_button=False, min_width=50, interactive=False, show_fullscreen_button=False)
    gr.Markdown("# 🤖 Pocket 3D AI by [SkyeBrowse](https://www.skyebrowse.com)")
    gr.Markdown(
        "Turn any image into a 3D model in seconds. Upload or paste a photo and Pocket 3D AI will generate a "
        "full 3D model with textures. Preview it right in your browser, then download the **GLB** or **STL** file "
        "for 3D printing, game engines, AR/VR, and more."
    )
    gr.HTML(
        '<div class="sponsor-banner">'
        # '<div class="sponsor-banner-title">Sponsored by <a href="https://www.skyebrowse.com" target="_blank">SkyeBrowse</a></div>'
        '<a class="sponsor-banner-button" href="https://www.skyebrowse.com" target="_blank">Try more AI + 3D modeling</a>'
        '</div>'
    )
    with gr.Row():
        gr.HTML(
            '<div style="text-align: center; margin-bottom: 15px; border-bottom: 1px solid #eee; padding-bottom: 10px;">'
            '<span style="font-weight: 600;">Try our other AI tools + 3D modeling: </span>'
            '<a href="https://interiorai.skyebrowse.com" style="color: #2f6bff; font-weight: bold; margin: 0 12px; text-decoration: none;">🏠 Interior AI Designer</a>'
            '<a href="https://anime.skyebrowse.com" style="color: #2f6bff; font-weight: bold; margin: 0 12px; text-decoration: none;">🎨 Anime AI Art</a>'
            '<a href="https://app.skyebrowse.com" style="color: #2f6bff; font-weight: bold; margin: 0 12px; text-decoration: none;">🛸 3D Drone Mapping</a>'
            '</div>'
        )
    with gr.Column():
        with gr.Row():
            with gr.Column(scale=2, min_width=100, variant="default"):
                image_prompt = gr.Image(
                    label="Single Image Input",
                    format="png",
                    image_mode="RGBA",
                    type="pil",
                    sources=["upload", "clipboard"],
                    container=True,
                    mirror_webcam=True,
                    visible=True,
                    height=240,
                    elem_classes="gr-image-container",
                )
                # multi_image_gallery = gr.Gallery(
                # label="Multi-Image Input (2–6 views)",
                # columns=3,
                # rows=2,
                # height=240,
                # object_fit="contain",
                # type="pil",
                # visible=True if TRELLIS_RUNTIME == "original" else False,
                # )
                # multi_image_btn = gr.Button(
                # "🚀 Generate from Multiple Images",
                # variant="secondary",
                # visible=True if TRELLIS_RUNTIME == "original" else False,
                # size="lg",
                # )
            with gr.Column(scale=5, min_width=100):
                video_output = gr.Video(
                    label="Preview",
                    height=240,
                    elem_classes="video-container",
                    visible=True,
                    autoplay=True,
                    loop=True,
                    show_download_button=True,
                    interactive=False,
                )
        with gr.Row(equal_height=False):
            with gr.Column(scale=2, min_width=100, variant="default"):
                example_dir = "./assets/example_image"
                examples = gr.Examples(
                    # Sorted: os.listdir order is filesystem-dependent. Dotfiles (.DS_Store,
                    # .gitkeep, ...) are skipped because gr.Examples would treat them as images.
                    examples=[
                        f"{example_dir}/{name}"
                        for name in sorted(os.listdir(example_dir))
                        if not name.startswith(".")
                    ],
                    inputs=[image_prompt],
                    examples_per_page=9,
                )
            with gr.Column(scale=5):
                model_output = LitModel3D(
                    label="",
                    container=True,
                    zoom_speed=0.5,
                    pan_speed=3.0,
                    exposure=10.0,
                    height=360,
                    elem_classes="model-container",
                    visible=True,
                )
                stl_download_button = gr.DownloadButton(
                    label="Download STL",
                    visible=True,
                    interactive=False,
                    size="lg",
                    variant="primary",
                    elem_classes="stl-download-btn",
                )
    with gr.Accordion(label="Generation Settings", open=False, visible=show_options and not RUNNING_ON_SPACES):
        seed_slider = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1)
        randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
        gr.Markdown("Stage 1: Sparse Structure Generation")
        with gr.Row():
            ss_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=7.5, step=0.1)
            ss_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=default_ss_steps, step=1)
        gr.Markdown("Stage 2: Structured Latent Generation")
        with gr.Row():
            slat_guidance_strength = gr.Slider(0.0, 10.0, label="Guidance Strength", value=1.5, step=0.1)
            slat_sampling_steps = gr.Slider(1, 50, label="Sampling Steps", value=default_slat_steps, step=1)
    # On Spaces the settings are replaced by hidden controls carrying the Spaces defaults.
    if RUNNING_ON_SPACES:
        seed_slider = gr.Slider(0, MAX_SEED, label="Seed", value=0, step=1, visible=False)
        randomize_seed = gr.Checkbox(label="Randomize Seed", value=True, visible=False)
        ss_guidance_strength = gr.Slider(0.0, 10.0, value=6.0, step=0.1, visible=False)
        ss_sampling_steps = gr.Slider(1, 50, value=default_ss_steps, step=1, visible=False)
        slat_guidance_strength = gr.Slider(0.0, 10.0, value=1.0, step=0.1, visible=False)
        slat_sampling_steps = gr.Slider(1, 50, value=default_slat_steps, step=1, visible=False)
    with gr.Accordion(label="GLB Extraction Settings", open=False, visible=show_options and not RUNNING_ON_SPACES):
        mesh_simplify = gr.Slider(0.0, 0.98, label="Simplify", value=0.95, step=0.01)
        texture_size = gr.Slider(texture_min, texture_max, label="Texture Size", value=default_texture_size, step=texture_step)
    if RUNNING_ON_SPACES:
        mesh_simplify = gr.Slider(0.0, 0.98, value=0.95, step=0.01, visible=False)
        texture_size = gr.Slider(texture_min, texture_max, value=default_texture_size, step=texture_step, visible=False)
    export_payload_state = gr.State(value=None)
    demo.load(start_session)
    demo.unload(end_session)
    # multi_common_inputs = [
    # multi_image_gallery,
    # seed_slider,
    # ss_guidance_strength,
    # ss_sampling_steps,
    # slat_guidance_strength,
    # slat_sampling_steps,
    # ]
    pipeline_state = gr.State(value=None)
    combined_inputs = [
        image_prompt,
        seed_slider,
        ss_guidance_strength,
        ss_sampling_steps,
        slat_guidance_strength,
        slat_sampling_steps,
        mesh_simplify,
        texture_size,
    ]

    def _chain_generation(trigger):
        """Attach the seed -> generate/extract -> materialize chain to an event listener.

        ``trigger`` is a Gradio event-listener method such as ``image_prompt.upload``.
        The upload path and the example-selection path share this one chain, so their
        inputs and outputs stay identical. Returns the final event dependency.
        """
        return trigger(
            fn=get_seed,
            inputs=[randomize_seed, seed_slider],
            outputs=[seed_slider],
            show_progress="minimal",
            trigger_mode="always_last",
        ).then(
            fn=generate_and_extract,
            inputs=combined_inputs,
            outputs=[video_output, pipeline_state, export_payload_state, model_output],
            show_progress="minimal",
            scroll_to_output=True,
        ).then(
            fn=materialize_export_payload,
            inputs=[export_payload_state],
            outputs=[model_output, stl_download_button],
            show_progress="minimal",
        )

    _chain_generation(image_prompt.upload)
    # multi_image_btn.click(
    # get_seed,
    # inputs=[randomize_seed, seed_slider],
    # outputs=[seed_slider],
    # show_progress="minimal",
    # ).then(
    # fn=generate_preview_multi,
    # inputs=multi_common_inputs,
    # outputs=[video_output, pipeline_state],
    # show_progress="minimal",
    # scroll_to_output=True,
    # ).then(
    # fn=extract_model,
    # inputs=[pipeline_state, mesh_simplify, texture_size],
    # outputs=[export_payload_state, model_output],
    # show_progress="minimal",
    # ).then(
    # fn=materialize_export_payload,
    # inputs=[export_payload_state],
    # outputs=[model_output, stl_download_button],
    # show_progress="minimal",
    # )
    _chain_generation(examples.dataset.select)
    gr.Markdown(
        """
---
### About Pocket 3D AI by SkyeBrowse
Pocket 3D AI generates 3D models from a single image in seconds.
Upload any photo and get a downloadable **GLB** and **STL** file ready for 3D printing,
game engines, or AR/VR applications. Built on **TRELLIS** image-to-3D reconstruction,
this tool is part of the **SkyeBrowse AI** ecosystem powering
**3D reconstruction**, **photogrammetry**, and **spatial data visualization**.
"""
    )
| RESTART_INTERVAL_SECONDS = 3600 # 1 hour | |
| def _auto_restart_space(): | |
| """Background thread that restarts this HF Space every hour to keep Zero GPU healthy.""" | |
| space_id = os.getenv("SPACE_ID") | |
| if not space_id: | |
| return | |
| try: | |
| from huggingface_hub import HfApi | |
| api = HfApi() | |
| except Exception: | |
| logger.warning("huggingface_hub not available; auto-restart disabled") | |
| return | |
| logger.info("Auto-restart thread started — will restart %s every %ds", space_id, RESTART_INTERVAL_SECONDS) | |
| time.sleep(RESTART_INTERVAL_SECONDS) | |
| logger.info("Auto-restart: restarting space %s now", space_id) | |
| try: | |
| api.restart_space(space_id) | |
| except Exception as e: | |
| logger.error("Auto-restart failed: %s", e) | |
# Only on Hugging Face Spaces: run the periodic self-restart watchdog in the background.
# daemon=True so the watchdog never keeps the interpreter alive at shutdown.
if RUNNING_ON_SPACES:
    _restart_thread = threading.Thread(
        target=_auto_restart_space,
        daemon=True,
    )
    _restart_thread.start()
if __name__ == "__main__":
    # Refuse to serve a UI whose model pipeline failed to load.
    if pipeline is None:
        logger.critical("Pipeline failed to initialize. Exiting.")
        sys.exit(1)
    logger.info("Launching runtime: %s", TRELLIS_RUNTIME)

    # Choose queue/launch options for the deployment target, then launch once below.
    if RUNNING_ON_SPACES:
        logger.info("Launching on HuggingFace Spaces")
        queue_opts = dict(max_size=20, default_concurrency_limit=20, api_open=False)
        launch_opts = dict(show_api=False, share=False, allowed_paths=ALLOWED_PATHS)
        if not PUBLIC_BASE_URL:
            logger.warning(
                "PUBLIC_BASE_URL is not set. If using a custom domain, set PUBLIC_BASE_URL=https://your-domain"
            )
        else:
            launch_opts["root_path"] = PUBLIC_BASE_URL
            logger.info("Using PUBLIC_BASE_URL for URL generation: %s", PUBLIC_BASE_URL)
    else:
        # Self-hosted modes share the bind address, port, favicon and allowed paths.
        launch_opts = dict(
            server_name="0.0.0.0",
            server_port=port,
            show_api=False,
            favicon_path="assets/sb_3d_ai_logo.png",
            allowed_paths=ALLOWED_PATHS,
        )
        if prod:
            logger.info("Launching in PRODUCTION mode on port %s", port)
            # NOTE(review): unlike the Spaces and development branches, production does not
            # pass api_open=False to queue() — confirm the open queue API is intentional.
            queue_opts = dict(max_size=20, default_concurrency_limit=5)
            launch_opts["share"] = False
        else:
            logger.info("Launching in DEVELOPMENT mode on port %s", port)
            queue_opts = dict(api_open=False)
            launch_opts.update(debug=True, share=True)

    demo.queue(**queue_opts).launch(**launch_opts)