Spaces:
Running on Zero
Running on Zero
| """ | |
| RunPod Serverless Handler for ASM Image-to-3D | |
| Optimized for RunPod Cached Models + Baked Docker Models | |
| """ | |
| import os | |
| import sys | |
| import runpod | |
| import base64 | |
| import time | |
| import uuid | |
| import json | |
| import asyncio | |
| import threading | |
| import shutil | |
| import torch | |
| import numpy as np | |
| import imageio | |
| from io import BytesIO | |
| from PIL import Image, ImageOps | |
| from pillow_heif import register_heif_opener | |
| from typing import Dict, Any, Generator, Tuple, List, Literal | |
| from easydict import EasyDict as edict | |
# Teach Pillow to decode HEIC/HEIF files (the default format of iPhone photos).
register_heif_opener()

# =============================================================================
# STORAGE CONFIGURATION (Simplified)
# =============================================================================
# Models are looked up in RunPod's cached-model store first and otherwise in
# the copy baked into the Docker image; a network volume is not required for
# models.
print("=== RunPod Worker Starting ===", flush=True)

# Cache locations
RUNPOD_MODEL_CACHE = "/runpod-volume/huggingface-cache/hub"  # RunPod cached models (HF hub layout)
BAKED_CACHE = "/app/cache"  # models baked into the image at build time
def find_cached_model(model_name: str, cache_root: "str | None" = None) -> "str | None":
    """Locate a model snapshot inside a Hugging Face hub-style cache.

    Args:
        model_name: Hub repo id, e.g. ``"arabago96/ASM-model"``.
        cache_root: Hub cache directory to search; defaults to
            ``RUNPOD_MODEL_CACHE``.

    Returns:
        Path of the snapshot directory, or ``None`` when the model is not
        cached (missing or empty ``snapshots`` folder). The snapshot named
        by ``refs/main`` is preferred; otherwise the most recently modified
        snapshot directory is used, so the result no longer depends on
        ``os.listdir`` order when several revisions are cached.
    """
    root = RUNPOD_MODEL_CACHE if cache_root is None else cache_root
    model_dir = os.path.join(root, "models--" + model_name.replace("/", "--"))
    snapshots_dir = os.path.join(model_dir, "snapshots")
    if not os.path.isdir(snapshots_dir):
        return None

    # refs/main stores the commit hash of the revision the cache treats as current.
    try:
        with open(os.path.join(model_dir, "refs", "main"), "r") as ref_file:
            ref = ref_file.read().strip()
    except OSError:
        ref = ""
    if ref:
        preferred = os.path.join(snapshots_dir, ref)
        if os.path.isdir(preferred):
            return preferred

    snapshots = [
        os.path.join(snapshots_dir, name)
        for name in os.listdir(snapshots_dir)
        if os.path.isdir(os.path.join(snapshots_dir, name))
    ]
    if not snapshots:
        return None
    return max(snapshots, key=os.path.getmtime)
# Check RunPod cached models first (fastest).
USE_RUNPOD_CACHE = os.path.exists(RUNPOD_MODEL_CACHE)
cached_model_path = None
if USE_RUNPOD_CACHE:
    cached_model_path = find_cached_model("arabago96/ASM-model")
    if cached_model_path:
        print("⚡ RunPod Cache: ASM-model ready (host NVMe SSD)", flush=True)
        os.environ['HF_HOME'] = "/runpod-volume/huggingface-cache"
        # DINOv2 is fetched through torch.hub, which the HF cache does not
        # cover. Reuse the copy baked into the image when present instead of
        # letting torch.hub download it again; an explicit TORCH_HOME wins.
        baked_torch_home = os.path.join(BAKED_CACHE, 'torch')
        if os.path.isdir(baked_torch_home):
            os.environ.setdefault('TORCH_HOME', baked_torch_home)
    else:
        USE_RUNPOD_CACHE = False

# Fallback: models baked into the Docker image, with STRICT verification.
if not USE_RUNPOD_CACHE:
    hf_hub_dir = os.path.join(BAKED_CACHE, 'huggingface', 'hub')
    torch_hub_dir = os.path.join(BAKED_CACHE, 'torch', 'hub')
    # Check each model's own cache folder, not just the shared hub directory:
    # the hub directory also exists when only BiRefNet was baked.
    asm_baked = os.path.exists(os.path.join(hf_hub_dir, 'models--arabago96--ASM-model'))
    dinov2_baked = os.path.exists(os.path.join(torch_hub_dir, 'facebookresearch_dinov2_main'))
    dinov2_weights = os.path.exists(
        os.path.join(torch_hub_dir, 'checkpoints', 'dinov2_vitl14_reg4_pretrain.pth')
    )
    birefnet_baked = os.path.exists(os.path.join(hf_hub_dir, 'models--ZhengPeng7--BiRefNet'))

    print("🔍 Checking baked models:", flush=True)
    print(f"   ASM Model: {'✅' if asm_baked else '❌'} {hf_hub_dir}/models--arabago96--ASM-model", flush=True)
    print(f"   DINOv2: {'✅' if dinov2_baked and dinov2_weights else '❌'}", flush=True)
    print(f"   BiRefNet: {'✅' if birefnet_baked else '❌'} (Replaces rembg)", flush=True)

    if asm_baked and dinov2_baked and dinov2_weights and birefnet_baked:
        print("🐳 All models baked in Docker image", flush=True)
        os.environ['HF_HOME'] = os.path.join(BAKED_CACHE, 'huggingface')
        os.environ['TORCH_HOME'] = os.path.join(BAKED_CACHE, 'torch')
        os.environ['XDG_CACHE_HOME'] = BAKED_CACHE
    else:
        # FAIL FAST - no silent downloads.
        missing = []
        if not asm_baked:
            missing.append("ASM Model")
        if not dinov2_baked:
            missing.append("DINOv2 (Repo)")
        if not dinov2_weights:
            missing.append("DINOv2 (Weights)")
        if not birefnet_baked:
            missing.append("BiRefNet")
        print(f"❌ FATAL: Missing models: {', '.join(missing)}", flush=True)
        print(f"   RunPod cache: NOT FOUND at {RUNPOD_MODEL_CACHE}", flush=True)
        print(f"   Baked cache: INCOMPLETE at {BAKED_CACHE}", flush=True)
        print("   Fix: Enable RunPod cached models OR rebuild Docker with baked models", flush=True)
        raise RuntimeError(f"Missing models: {', '.join(missing)}. Enable RunPod cache or rebuild Docker.")

# NOTE: torch.hub.set_dir() is deliberately not called; torch.hub derives its
# directory from TORCH_HOME, which is set above.

# Force xformers attention to prevent crashes.
os.environ['ATTN_BACKEND'] = 'xformers'
os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '1'

# Print config summary.
print(f"   HF_HOME: {os.environ.get('HF_HOME')}", flush=True)
print(f"   TORCH_HOME: {os.environ.get('TORCH_HOME')}", flush=True)
# =============================================================================
# IMPORTS (after env vars are set)
# =============================================================================
# The heavy `asm` modules (pipelines, representations, utils) are imported
# lazily inside the functions that use them, so that they read the cache
# environment variables configured above and are only loaded when needed.
import warnings

warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)

# Background removal is done with BiRefNet (it replaced rembg).
from utils_birefnet import BiRefNet
# =============================================================================
# INFERENCE CLASS (Lazy Loading)
# =============================================================================
class ASMInference:
    """Handles ASM image-to-3D inference with lazy model loading.

    - Step 1 (preprocess): only loads BiRefNet.
    - Step 2 (generate): loads the ASM pipeline (+ DINOv2).
    - Step 3 (export): works from saved state; no model is loaded here.

    Both loaders use double-checked locking and publish a model only after it
    has been moved to the GPU, so concurrent jobs never load a model twice or
    observe a half-initialised one.
    """

    def __init__(self, model_name: str = "arabago96/ASM-model"):
        self.pipeline = None
        self.model_name = model_name
        self.lock = threading.Lock()           # guards pipeline loading
        self.birefnet_lock = threading.Lock()  # guards BiRefNet loading and inference
        self.birefnet = None

    def _load_birefnet(self):
        """Load BiRefNet on first use and move it to the GPU (thread-safe)."""
        if self.birefnet is not None:
            return
        with self.birefnet_lock:
            if self.birefnet is not None:
                return
            print("[LOAD] Initializing BiRefNet (First Run Only)...", flush=True)
            t0 = time.time()
            model = BiRefNet()  # Default: ZhengPeng7/BiRefNet
            model.cuda()
            # Publish only once the model is on the GPU.
            self.birefnet = model
            print(f"[LOAD] BiRefNet ready in {time.time()-t0:.1f}s", flush=True)

    def _load_model(self):
        """Load the ASM pipeline (which includes DINOv2) on first use (thread-safe)."""
        if self.pipeline is not None:
            return
        print("[LOAD] Initializing ASM+Torch (First Run Only)...", flush=True)
        # Lazy import: the pipeline module is only needed once generation starts.
        from asm.pipelines import ASMImageTo3DPipeline
        with self.lock:
            if self.pipeline is not None:
                return
            hf_home = os.environ.get('HF_HOME', 'default')
            torch_home = os.environ.get('TORCH_HOME', 'default')
            print(f"[LOAD] ASM Model from: {hf_home}", flush=True)
            print(f"[LOAD] DINOv2 from: {torch_home}", flush=True)
            t0 = time.time()
            pipeline = ASMImageTo3DPipeline.from_pretrained(self.model_name)
            pipeline.cuda()
            self.pipeline = pipeline
            print(f"[LOAD] ASM+DINOv2 ready in {time.time()-t0:.1f}s", flush=True)

    def preprocess_image(self, input_image: Image.Image) -> Image.Image:
        """Remove the background, square-crop around the subject, composite on black.

        Args:
            input_image: RGB or RGBA image. An RGBA image whose alpha channel
                is not fully opaque is treated as already matted and skips
                BiRefNet.

        Returns:
            An RGB image of the square crop around the foreground, with RGB
            multiplied by alpha (i.e. composited onto black).

        Raises:
            ValueError: if the image has zero width or height.
        """
        if min(input_image.size) == 0:
            raise ValueError("Input image has zero width or height")
        self._load_birefnet()

        # 1. Downscale so the longest side is at most 1024 px (never below 1 px).
        scale = min(1, 1024 / max(input_image.size))
        if scale < 1:
            input_image = input_image.resize(
                (max(1, int(input_image.width * scale)), max(1, int(input_image.height * scale))),
                Image.Resampling.LANCZOS,
            )

        # 2. Reuse an existing non-trivial alpha channel; otherwise run BiRefNet.
        has_alpha = (
            input_image.mode == 'RGBA'
            and not np.all(np.array(input_image)[:, :, 3] == 255)
        )
        if has_alpha:
            output = input_image
        else:
            with self.birefnet_lock:
                output = self.birefnet(input_image)

        # 3. Square crop centred on the confident foreground.
        alpha = np.array(output)[:, :, 3]
        output = output.crop(self._square_crop_box(alpha, output.width, output.height))

        # 4. Composite onto black by multiplying RGB with alpha.
        rgba = np.array(output).astype(np.float32) / 255
        rgb = rgba[:, :, :3] * rgba[:, :, 3:4]
        return Image.fromarray((rgb * 255).astype(np.uint8))

    @staticmethod
    def _square_crop_box(alpha: np.ndarray, width: int, height: int) -> Tuple[int, int, int, int]:
        """Return a square PIL box (left, top, right, bottom) around pixels with alpha > 0.8*255.

        When no pixel passes the threshold the full image extent is used.
        The box may extend past the image; PIL fills that area with zeros.
        """
        coords = np.argwhere(alpha > 0.8 * 255)  # rows are (y, x)
        if coords.shape[0] == 0:
            # PIL boxes are (left, top, right, bottom): width before height.
            x0, y0, x1, y1 = 0, 0, width, height
        else:
            x0, y0 = coords[:, 1].min(), coords[:, 0].min()
            x1, y1 = coords[:, 1].max(), coords[:, 0].max()
        cx, cy = (x0 + x1) / 2, (y0 + y1) / 2
        half = int(max(x1 - x0, y1 - y0)) // 2
        return (int(cx - half), int(cy - half), int(cx + half), int(cy + half))

    def generate_3d(self, image: Image.Image, seed: int = 0,
                    ss_guidance_strength: float = 9, ss_sampling_steps: int = 30,
                    slat_guidance_strength: float = 9, slat_sampling_steps: int = 50,
                    preprocess: bool = False) -> Tuple["Gaussian", "MeshExtractResult"]:
        """Generate a Gaussian and a mesh from one image. Loads ASM on first use.

        Numeric parameters are coerced (``int`` for seed/steps, ``float`` for
        guidance) because they usually arrive from JSON request payloads.
        """
        self._load_model()
        seed = int(seed)
        ss_sampling_steps = int(ss_sampling_steps)
        slat_sampling_steps = int(slat_sampling_steps)
        print(f"[GENERATE] seed={seed}, ss_steps={ss_sampling_steps}, slat_steps={slat_sampling_steps}", flush=True)
        outputs = self.pipeline.run(
            image,
            seed=seed,
            formats=["gaussian", "mesh"],
            preprocess_image=preprocess,
            sparse_structure_sampler_params={
                "steps": ss_sampling_steps,
                "cfg_strength": float(ss_guidance_strength),
            },
            slat_sampler_params={
                "steps": slat_sampling_steps,
                "cfg_strength": float(slat_guidance_strength),
            },
        )
        print("[GENERATE] Complete", flush=True)
        return outputs['gaussian'][0], outputs['mesh'][0]

    def generate_3d_multi_image(self, images: List[Image.Image], seed: int = 0,
                                ss_guidance_strength: float = 9, ss_sampling_steps: int = 30,
                                slat_guidance_strength: float = 9, slat_sampling_steps: int = 50,
                                multiimage_algo: str = "stochastic",
                                preprocess: bool = False) -> Tuple["Gaussian", "MeshExtractResult"]:
        """Generate a Gaussian and a mesh from several views of one object.

        ``multiimage_algo`` is forwarded as the pipeline's ``mode``. Numeric
        parameters are coerced as in ``generate_3d``.
        """
        self._load_model()
        outputs = self.pipeline.run_multi_image(
            images,
            seed=int(seed),
            formats=["gaussian", "mesh"],
            preprocess_image=preprocess,
            sparse_structure_sampler_params={
                "steps": int(ss_sampling_steps),
                "cfg_strength": float(ss_guidance_strength),
            },
            slat_sampler_params={
                "steps": int(slat_sampling_steps),
                "cfg_strength": float(slat_guidance_strength),
            },
            mode=multiimage_algo,
        )
        return outputs['gaussian'][0], outputs['mesh'][0]

    def render_video(self, gaussian: "Gaussian", output_path: str, num_frames: int = 120, fps: int = 15) -> str:
        """Render a turntable of *gaussian* to *output_path* and return the path."""
        from asm.utils import render_utils
        video = render_utils.render_video(gaussian, num_frames=num_frames)['color']
        imageio.mimsave(output_path, video, fps=fps)
        return output_path

    def export_glb(self, gaussian: "Gaussian", mesh: "MeshExtractResult", output_path: str,
                   mesh_simplify: float = 0.9, texture_size: int = 1024) -> str:
        """Write a simplified, textured GLB to *output_path* and return the path."""
        from asm.utils import postprocessing_utils
        glb = postprocessing_utils.to_glb(gaussian, mesh, simplify=mesh_simplify, texture_size=texture_size, verbose=False)
        glb.export(output_path)
        return output_path

    def export_glb_raw(self, gaussian: "Gaussian", mesh: "MeshExtractResult", output_path: str, texture_size: int = 1024) -> str:
        """Write a GLB via ``to_glb_raw`` (no simplify argument) and return the path."""
        from asm.utils import postprocessing_utils
        glb = postprocessing_utils.to_glb_raw(gaussian, mesh, texture_size=texture_size, verbose=False)
        glb.export(output_path)
        return output_path

    def export_gaussian(self, gaussian: "Gaussian", output_path: str) -> str:
        """Save the Gaussian splat as a PLY file and return the path."""
        gaussian.save_ply(output_path)
        return output_path

    def cleanup_gpu(self):
        """Release cached, unused GPU memory back to the driver."""
        torch.cuda.empty_cache()
# =============================================================================
# STATE MANAGEMENT
# =============================================================================
def pack_state(gs: "Gaussian", mesh: "MeshExtractResult") -> dict:
    """Snapshot a Gaussian and a mesh as plain NumPy arrays on the CPU.

    The Gaussian's ``init_params`` are copied verbatim next to its tensors so
    that ``unpack_state`` can rebuild it. Tensors are detached first, so this
    also works on tensors that still require grad.
    """
    def to_numpy(tensor):
        # .numpy() refuses tensors that are part of an autograd graph.
        return tensor.detach().cpu().numpy()

    return {
        'gaussian': {
            **gs.init_params,
            '_xyz': to_numpy(gs._xyz),
            '_features_dc': to_numpy(gs._features_dc),
            '_scaling': to_numpy(gs._scaling),
            '_rotation': to_numpy(gs._rotation),
            '_opacity': to_numpy(gs._opacity),
        },
        'mesh': {
            'vertices': to_numpy(mesh.vertices),
            'faces': to_numpy(mesh.faces),
        },
    }
def unpack_state(state: dict) -> Tuple["Gaussian", "MeshExtractResult"]:
    """Rebuild a Gaussian and a mesh on the GPU from a ``pack_state`` dict.

    The mesh is returned as an EasyDict exposing only ``vertices`` and
    ``faces`` tensors, not as a ``MeshExtractResult`` instance.
    """
    from asm.representations import Gaussian
    from easydict import EasyDict as edict

    g_state = state['gaussian']
    constructor_fields = (
        'aabb', 'sh_degree', 'mininum_kernel_size',
        'scaling_bias', 'opacity_bias', 'scaling_activation',
    )
    gs = Gaussian(**{field: g_state[field] for field in constructor_fields})
    for attr in ('_xyz', '_features_dc', '_scaling', '_rotation', '_opacity'):
        setattr(gs, attr, torch.tensor(g_state[attr], device='cuda'))

    m_state = state['mesh']
    mesh = edict(
        vertices=torch.tensor(m_state['vertices'], device='cuda'),
        faces=torch.tensor(m_state['faces'], device='cuda'),
    )
    return gs, mesh
# =============================================================================
# GLOBALS
# =============================================================================
# Single process-wide inference engine; its models load lazily on first use.
inference = ASMInference()

# Session files go on the network volume when it is mounted, so every worker
# sees the same sessions; local runs fall back to /tmp.
if os.path.exists('/runpod-volume'):
    SESSION_DIR = '/runpod-volume/sessions'
else:
    SESSION_DIR = '/tmp/runpod_sessions'
os.makedirs(SESSION_DIR, exist_ok=True)
SESSION_TIMEOUT = 3600  # seconds (1 hour)

# Concurrency: number of jobs this worker accepts at the same time.
MAX_CONCURRENT_JOBS = int(os.environ.get('MAX_CONCURRENT_JOBS', '2'))
# =============================================================================
# UTILITIES
# =============================================================================
def decode_image(image_b64: str) -> Image.Image:
    """Decode a base64 image into an RGB or RGBA PIL image.

    Accepts raw base64 or a ``data:<mime>;base64,<payload>`` URI. EXIF
    orientation is applied. Any other mode is converted to RGBA when it
    carries transparency (an alpha band or palette ``transparency`` info),
    and to RGB otherwise.

    Raises:
        ValueError: (``binascii.Error``) if the payload is not valid base64.
        PIL.UnidentifiedImageError: if the bytes are not a recognised image.
    """
    if isinstance(image_b64, str) and image_b64.startswith("data:"):
        # Drop the data-URI header; its letters would otherwise be decoded as payload.
        image_b64 = image_b64.partition(",")[2]
    image = Image.open(BytesIO(base64.b64decode(image_b64)))
    image = ImageOps.exif_transpose(image)
    if image.mode not in ("RGB", "RGBA"):
        # Palette ("P") images keep transparency in image.info, not in the mode name.
        has_transparency = "A" in image.mode or "transparency" in image.info
        image = image.convert("RGBA" if has_transparency else "RGB")
    return image
def encode_image(image: Image.Image, format: str = "PNG") -> str:
    """Serialise *image* in the Pillow *format* given and return it as base64 text."""
    with BytesIO() as sink:
        image.save(sink, format=format)
        encoded_bytes = sink.getvalue()
    return base64.b64encode(encoded_bytes).decode()
def encode_file(file_path: str) -> str:
    """Return the raw bytes of *file_path* encoded as a base64 ASCII string."""
    with open(file_path, "rb") as handle:
        payload = handle.read()
    return base64.b64encode(payload).decode()
def get_session_dir(session_id: str, base_dir: "str | None" = None) -> str:
    """Return the directory that holds the files of *session_id*.

    ``session_id`` comes straight from client requests, and the returned
    path is later handed to ``shutil.rmtree`` by the cleanup step, so it must
    never escape the sessions root (e.g. ``"/"`` or ``"../x"``). Only UUIDs,
    the format ``create_session`` issues, are accepted. They are normalised
    to canonical lowercase form.

    Args:
        session_id: Client-supplied session id.
        base_dir: Sessions root; defaults to ``SESSION_DIR``.

    Raises:
        ValueError: if *session_id* is not a valid UUID.
    """
    try:
        canonical = str(uuid.UUID(session_id))
    except (ValueError, TypeError, AttributeError):
        raise ValueError(f"Invalid session_id: {session_id!r}") from None
    root = SESSION_DIR if base_dir is None else base_dir
    return os.path.join(root, canonical)
def create_session() -> str:
    """Create a fresh session directory stamped with its creation time.

    Returns:
        The new session id (a random UUID4 string).
    """
    new_id = str(uuid.uuid4())
    directory = get_session_dir(new_id)
    os.makedirs(directory, exist_ok=True)
    # 'created' holds the epoch timestamp used by the expiry sweep.
    stamp_path = os.path.join(directory, 'created')
    with open(stamp_path, 'w') as stamp:
        stamp.write(str(time.time()))
    return new_id
def save_session_state(session_id: str, state: dict):
    """Write a packed state (layout of ``pack_state``) to ``state.npz`` in the session.

    The archive is first written to a uniquely named temporary file in the
    same directory and then moved into place with ``os.replace``. A
    concurrent reader therefore never sees a half-written ``state.npz``, and
    two writers never share a temporary file.
    """
    session_dir = get_session_dir(session_id)
    gaussian = state['gaussian']
    mesh = state['mesh']
    final_path = os.path.join(session_dir, 'state.npz')
    tmp_path = os.path.join(session_dir, f'.state.{uuid.uuid4().hex}.tmp')
    try:
        # Pass a file object: given a path without ".npz", numpy would append the suffix.
        with open(tmp_path, 'wb') as fh:
            np.savez_compressed(
                fh,
                gaussian_aabb=gaussian['aabb'],
                gaussian_sh_degree=gaussian['sh_degree'],
                gaussian_mininum_kernel_size=gaussian['mininum_kernel_size'],
                gaussian_scaling_bias=gaussian['scaling_bias'],
                gaussian_opacity_bias=gaussian['opacity_bias'],
                gaussian_scaling_activation=gaussian['scaling_activation'],
                gaussian_xyz=gaussian['_xyz'],
                gaussian_features_dc=gaussian['_features_dc'],
                gaussian_scaling=gaussian['_scaling'],
                gaussian_rotation=gaussian['_rotation'],
                gaussian_opacity=gaussian['_opacity'],
                mesh_vertices=mesh['vertices'],
                mesh_faces=mesh['faces'],
            )
        os.replace(tmp_path, final_path)
    except BaseException:
        # Do not leave a stray temporary file behind on failure.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
def load_session_state(session_id: str) -> dict:
    """Load the state written by ``save_session_state``.

    Returns a dict with the ``{'gaussian': {...}, 'mesh': {...}}`` layout of
    ``pack_state``, with scalar fields converted back to Python
    ``int``/``float``/``str``.

    Raises:
        ValueError: if the session has no saved state.
    """
    state_path = os.path.join(get_session_dir(session_id), 'state.npz')
    if not os.path.exists(state_path):
        raise ValueError(f"No state found for session {session_id}")
    # np.load on an .npz keeps the archive open until closed; the context
    # manager releases it once every array below has been read into memory.
    # NOTE(review): allow_pickle=True is kept so an aabb stored as a 0-d
    # (possibly object) array still loads. These archives are written by this
    # service only; if the session volume can be written by others, drop it.
    with np.load(state_path, allow_pickle=True) as data:
        aabb = data['gaussian_aabb']
        return {
            'gaussian': {
                'aabb': aabb.item() if aabb.ndim == 0 else aabb,
                'sh_degree': int(data['gaussian_sh_degree']),
                'mininum_kernel_size': float(data['gaussian_mininum_kernel_size']),
                'scaling_bias': float(data['gaussian_scaling_bias']),
                'opacity_bias': float(data['gaussian_opacity_bias']),
                'scaling_activation': str(data['gaussian_scaling_activation']),
                '_xyz': data['gaussian_xyz'],
                '_features_dc': data['gaussian_features_dc'],
                '_scaling': data['gaussian_scaling'],
                '_rotation': data['gaussian_rotation'],
                '_opacity': data['gaussian_opacity'],
            },
            'mesh': {
                'vertices': data['mesh_vertices'],
                'faces': data['mesh_faces'],
            },
        }
def save_preprocessed_image(session_id: str, image: Image.Image):
    """Store *image* as ``preprocessed.png`` in the session directory."""
    target_path = os.path.join(get_session_dir(session_id), 'preprocessed.png')
    image.save(target_path, 'PNG')
def load_preprocessed_image(session_id: str) -> Image.Image:
    """Return the session's ``preprocessed.png`` as a fully loaded image.

    Raises:
        ValueError: if the session has no preprocessed image.
    """
    path = os.path.join(get_session_dir(session_id), 'preprocessed.png')
    if not os.path.exists(path):
        raise ValueError(f"No preprocessed image for session {session_id}")
    # Image.open is lazy and keeps the file open; copy() decodes the pixels so
    # the handle can be closed right away.
    with Image.open(path) as img:
        return img.copy()
def cleanup_old_sessions(session_dir: "str | None" = None, timeout: "float | None" = None):
    """Delete session directories older than the session timeout.

    Best-effort and never raises: it runs in a background thread after each
    job. Each session is handled independently, so one unreadable session
    no longer aborts the whole sweep. Age is taken from the session's
    ``created`` stamp, falling back to the directory's mtime when the stamp
    is missing or unparsable (e.g. a crash between creating the directory
    and writing the stamp). Such sessions are therefore reclaimed too.

    Args:
        session_dir: Sessions root; defaults to ``SESSION_DIR``.
        timeout: Maximum age in seconds; defaults to ``SESSION_TIMEOUT``.
    """
    root = SESSION_DIR if session_dir is None else session_dir
    max_age = SESSION_TIMEOUT if timeout is None else timeout
    now = time.time()
    try:
        entries = os.listdir(root)
    except OSError as e:
        print(f"[CLEANUP] Cannot list {root}: {e}", flush=True)
        return
    for name in entries:
        path = os.path.join(root, name)
        try:
            if not os.path.isdir(path):
                continue
            try:
                with open(os.path.join(path, 'created'), 'r') as f:
                    created = float(f.read().strip())
            except (OSError, ValueError):
                created = os.path.getmtime(path)
            if now - created > max_age:
                shutil.rmtree(path, ignore_errors=True)
        except OSError as e:
            # e.g. the directory vanished mid-sweep (another worker removed it).
            print(f"[CLEANUP] Skipping {name}: {e}", flush=True)
# =============================================================================
# STEP HANDLERS
# =============================================================================
def handle_preprocess(input_data: Dict[str, Any]) -> Dict[str, Any]:
    """Step 1: remove the background from one image (BiRefNet only - fast).

    The session is created only after the image has been decoded and
    preprocessed, so invalid input does not leave an orphaned session
    directory behind.

    Returns:
        ``session_id`` plus the base64 ``preprocessed_image``, or
        ``{"error": ...}`` when no image was supplied.
    """
    start_time = time.time()
    if "image" not in input_data:
        return {"error": "No image provided"}
    image = decode_image(input_data["image"])
    preprocessed = inference.preprocess_image(image)
    session_id = create_session()
    save_preprocessed_image(session_id, preprocessed)
    processing_time = round(time.time() - start_time, 2)
    print(f"[PREPROCESS] Complete in {processing_time}s", flush=True)
    return {
        "session_id": session_id,
        "preprocessed_image": encode_image(preprocessed),
        "step": "preprocess",
        "status": "complete",
        "processing_time": processing_time,
        "next_step": "generate",
    }
def handle_preprocess_multi(input_data: Dict[str, Any]) -> Dict[str, Any]:
    """Step 1 (multi): remove backgrounds from two or more views of one object.

    All images are decoded and preprocessed before the session is created,
    so a bad image does not leave a half-populated session behind. The
    session stores ``preprocessed_<i>.png`` files plus an ``image_count``
    file for ``generate_multi``.
    """
    start_time = time.time()
    if "images" not in input_data:
        return {"error": "No images provided"}
    images_b64 = input_data["images"]
    if not isinstance(images_b64, list) or len(images_b64) < 2:
        return {"error": "Multi-image mode requires at least 2 images"}

    preprocessed_images = [
        inference.preprocess_image(decode_image(img_b64)) for img_b64 in images_b64
    ]

    session_id = create_session()
    session_dir = get_session_dir(session_id)
    for i, preprocessed in enumerate(preprocessed_images):
        preprocessed.save(os.path.join(session_dir, f'preprocessed_{i}.png'), 'PNG')
    with open(os.path.join(session_dir, 'image_count'), 'w') as f:
        f.write(str(len(preprocessed_images)))

    processing_time = round(time.time() - start_time, 2)
    print(f"[PREPROCESS_MULTI] Complete in {processing_time}s", flush=True)
    return {
        "session_id": session_id,
        "preprocessed_image": encode_image(preprocessed_images[0]),
        "preprocessed_images": [encode_image(img) for img in preprocessed_images],
        "image_count": len(preprocessed_images),
        "step": "preprocess_multi",
        "status": "complete",
        "processing_time": processing_time,
        "next_step": "generate_multi",
    }
def handle_generate(input_data: Dict[str, Any]) -> Dict[str, Any]:
    """Step 2: build the 3D asset for an existing session (loads ASM + DINOv2).

    Reads the session's preprocessed image, runs generation, stores the
    packed state for the export step and returns an MP4 preview as base64.
    Returns ``{"error": ...}`` when the session id is missing or unknown.
    """
    started = time.time()

    session_id = input_data.get("session_id")
    if not session_id:
        return {"error": "No session_id provided"}

    try:
        source_image = load_preprocessed_image(session_id)
    except ValueError as exc:
        return {"error": str(exc)}

    workdir = get_session_dir(session_id)
    sampler_options = dict(
        seed=input_data.get("seed", 0),
        ss_guidance_strength=input_data.get("ss_guidance_strength", 9),
        ss_sampling_steps=input_data.get("ss_sampling_steps", 30),
        slat_guidance_strength=input_data.get("slat_guidance_strength", 9),
        slat_sampling_steps=input_data.get("slat_sampling_steps", 50),
    )
    gaussian, mesh = inference.generate_3d(image=source_image, preprocess=False, **sampler_options)
    save_session_state(session_id, pack_state(gaussian, mesh))

    preview_path = os.path.join(workdir, 'preview.mp4')
    inference.render_video(gaussian, preview_path)
    inference.cleanup_gpu()

    elapsed = round(time.time() - started, 2)
    print(f"[GENERATE] Complete in {elapsed}s", flush=True)
    return {
        "session_id": session_id,
        "video": encode_file(preview_path),
        "step": "generate",
        "status": "complete",
        "processing_time": elapsed,
        "next_step": "export",
    }
def handle_generate_multi(input_data: Dict[str, Any]) -> Dict[str, Any]:
    """Step 2 (multi): build the 3D asset from a multi-image session.

    Reads ``image_count`` and the ``preprocessed_<i>.png`` files written by
    ``preprocess_multi``, runs multi-view generation, stores the packed
    state and returns an MP4 preview as base64. Invalid ids, missing or
    corrupt session files are reported as ``{"error": ...}``.
    """
    start_time = time.time()
    session_id = input_data.get("session_id")
    if not session_id:
        return {"error": "No session_id provided"}
    try:
        session_dir = get_session_dir(session_id)
    except ValueError as e:
        return {"error": str(e)}

    count_file = os.path.join(session_dir, 'image_count')
    if not os.path.exists(count_file):
        return {"error": "No multi-image session found"}
    try:
        with open(count_file, 'r') as f:
            image_count = int(f.read().strip())
    except ValueError:
        return {"error": "Corrupt multi-image session (invalid image_count)"}

    preprocessed_images = []
    for i in range(image_count):
        img_path = os.path.join(session_dir, f'preprocessed_{i}.png')
        if not os.path.exists(img_path):
            return {"error": f"Missing preprocessed image {i}"}
        # copy() decodes eagerly so the file handle is released immediately.
        with Image.open(img_path) as img:
            preprocessed_images.append(img.copy())

    gaussian, mesh = inference.generate_3d_multi_image(
        images=preprocessed_images,
        seed=input_data.get("seed", 0),
        ss_guidance_strength=input_data.get("ss_guidance_strength", 9),
        ss_sampling_steps=input_data.get("ss_sampling_steps", 30),
        slat_guidance_strength=input_data.get("slat_guidance_strength", 9),
        slat_sampling_steps=input_data.get("slat_sampling_steps", 50),
        multiimage_algo=input_data.get("multiimage_algo", "stochastic"),
        preprocess=False,
    )
    save_session_state(session_id, pack_state(gaussian, mesh))
    video_path = os.path.join(session_dir, 'preview.mp4')
    inference.render_video(gaussian, video_path)
    inference.cleanup_gpu()
    processing_time = round(time.time() - start_time, 2)
    print(f"[GENERATE_MULTI] Complete in {processing_time}s", flush=True)
    return {
        "session_id": session_id,
        "video": encode_file(video_path),
        "step": "generate_multi",
        "status": "complete",
        "processing_time": processing_time,
        "next_step": "export",
    }
def handle_export(input_data: Dict[str, Any]) -> Dict[str, Any]:
    """Step 3: export a session's saved state (no model loading needed).

    ``output_format`` is one of ``"glb"``, ``"glb_raw"``, ``"gaussian"``
    (PLY) or ``"video"``. Request options are validated before the state is
    moved to the GPU, and GPU memory is released even if the export fails.
    """
    start_time = time.time()
    session_id = input_data.get("session_id")
    if not session_id:
        return {"error": "No session_id provided"}

    output_format = input_data.get("output_format", "glb")
    if output_format not in ("glb", "glb_raw", "gaussian", "video"):
        return {"error": f"Unknown output_format: {output_format}"}
    try:
        # Values usually arrive from JSON; coerce them once here.
        mesh_simplify = float(input_data.get("mesh_simplify", 0.9))
        texture_size = int(input_data.get("texture_size", 1024))
    except (TypeError, ValueError):
        return {"error": "mesh_simplify must be a number and texture_size an integer"}

    try:
        state = load_session_state(session_id)
    except ValueError as e:
        return {"error": str(e)}
    session_dir = get_session_dir(session_id)
    print(f"[EXPORT] format={output_format}, simplify={mesh_simplify}, texture={texture_size}", flush=True)

    result = {}
    try:
        gaussian, mesh = unpack_state(state)
        if output_format == "glb":
            output_path = os.path.join(session_dir, "output.glb")
            inference.export_glb(gaussian, mesh, output_path, mesh_simplify, texture_size)
            result["glb"] = encode_file(output_path)
        elif output_format == "glb_raw":
            output_path = os.path.join(session_dir, "output_raw.glb")
            inference.export_glb_raw(gaussian, mesh, output_path, texture_size)
            result["glb"] = encode_file(output_path)
        elif output_format == "gaussian":
            output_path = os.path.join(session_dir, "output.ply")
            inference.export_gaussian(gaussian, output_path)
            result["gaussian"] = encode_file(output_path)
        else:  # "video"
            video_path = os.path.join(session_dir, "output.mp4")
            inference.render_video(gaussian, video_path)
            result["video"] = encode_file(video_path)
    finally:
        inference.cleanup_gpu()

    processing_time = round(time.time() - start_time, 2)
    print(f"[EXPORT] Complete in {processing_time}s", flush=True)
    return {
        "session_id": session_id,
        "step": "export",
        "status": "complete",
        "processing_time": processing_time,
        **result,
    }
def handle_cleanup(input_data: Dict[str, Any]) -> Dict[str, Any]:
    """Delete a session's directory and everything in it.

    ``session_id`` is client-controlled and ``rmtree`` is destructive, so the
    resolved target must be a direct child of ``SESSION_DIR``. Ids such as
    ``"/"`` or ``"../huggingface-cache"`` are rejected with an error instead
    of deleting arbitrary directories.
    """
    session_id = input_data.get("session_id")
    if not session_id:
        return {"error": "No session_id provided"}
    try:
        session_dir = get_session_dir(session_id)
    except ValueError as e:
        return {"error": str(e)}

    sessions_root = os.path.realpath(SESSION_DIR)
    target = os.path.realpath(session_dir)
    if os.path.dirname(target) != sessions_root:
        return {"error": f"Invalid session_id: {session_id}"}

    if os.path.isdir(target):
        shutil.rmtree(target, ignore_errors=True)
    return {"session_id": session_id, "step": "cleanup", "status": "complete"}
def handle_full_pipeline(input_data: Dict[str, Any]) -> Dict[str, Any]:
    """Run the entire pipeline in ONE job (preprocess -> generate -> export).

    Export options are validated before any GPU work, so a bad request fails
    immediately instead of after a full generation. ``output_format`` may be
    ``"glb"``, ``"glb_raw"``, ``"gaussian"`` or ``"video"``. For ``"video"``
    the preview video, which is always returned, is the export. GPU memory
    is released even when a stage fails.
    """
    total_start = time.time()
    step_timings = {}

    if "image" not in input_data:
        return {"error": "No image provided", "step": "full_pipeline"}
    output_format = input_data.get("output_format", "glb")
    if output_format not in ("glb", "glb_raw", "gaussian", "video"):
        return {"error": f"Unknown output_format: {output_format}", "step": "full_pipeline"}
    try:
        # Values usually arrive from JSON; coerce them once here.
        mesh_simplify = float(input_data.get("mesh_simplify", 0.9))
        texture_size = int(input_data.get("texture_size", 1024))
    except (TypeError, ValueError):
        return {
            "error": "mesh_simplify must be a number and texture_size an integer",
            "step": "full_pipeline",
        }

    try:
        # === PREPROCESS ===
        step_start = time.time()
        image = decode_image(input_data["image"])
        preprocessed = inference.preprocess_image(image)
        # Open the session only once the input is known to be usable.
        session_id = create_session()
        session_dir = get_session_dir(session_id)
        save_preprocessed_image(session_id, preprocessed)
        preprocessed_b64 = encode_image(preprocessed)
        step_timings["preprocess"] = round(time.time() - step_start, 2)
        print(f"[FULL] Preprocess: {step_timings['preprocess']}s", flush=True)

        # === GENERATE ===
        step_start = time.time()
        gaussian, mesh = inference.generate_3d(
            image=preprocessed,
            seed=input_data.get("seed", 0),
            ss_guidance_strength=input_data.get("ss_guidance_strength", 9),
            ss_sampling_steps=input_data.get("ss_sampling_steps", 30),
            slat_guidance_strength=input_data.get("slat_guidance_strength", 9),
            slat_sampling_steps=input_data.get("slat_sampling_steps", 50),
            preprocess=False,
        )
        save_session_state(session_id, pack_state(gaussian, mesh))
        video_path = os.path.join(session_dir, 'preview.mp4')
        inference.render_video(gaussian, video_path)
        video_b64 = encode_file(video_path)
        step_timings["generate"] = round(time.time() - step_start, 2)
        print(f"[FULL] Generate: {step_timings['generate']}s", flush=True)

        # === EXPORT ===
        step_start = time.time()
        result = {}
        if output_format == "glb":
            output_path = os.path.join(session_dir, "output.glb")
            inference.export_glb(gaussian, mesh, output_path, mesh_simplify, texture_size)
            result["glb"] = encode_file(output_path)
        elif output_format == "glb_raw":
            output_path = os.path.join(session_dir, "output_raw.glb")
            inference.export_glb_raw(gaussian, mesh, output_path, texture_size)
            result["glb"] = encode_file(output_path)
        elif output_format == "gaussian":
            output_path = os.path.join(session_dir, "output.ply")
            inference.export_gaussian(gaussian, output_path)
            result["gaussian"] = encode_file(output_path)
        # "video": the preview rendered above is the deliverable.
        step_timings["export"] = round(time.time() - step_start, 2)
        print(f"[FULL] Export: {step_timings['export']}s", flush=True)
    finally:
        # === CLEANUP ===
        inference.cleanup_gpu()

    total_time = round(time.time() - total_start, 2)
    print(f"[FULL] ✅ Complete in {total_time}s", flush=True)
    return {
        "session_id": session_id,
        "preprocessed_image": preprocessed_b64,
        "video": video_b64,
        **result,
        "step": "full_pipeline",
        "status": "complete",
        "processing_time": total_time,
        "step_timings": step_timings,
    }
# =============================================================================
# MAIN HANDLER
# =============================================================================
def _run_step(step: str, input_data: Dict[str, Any]) -> Dict[str, Any]:
    """Dispatch *step* to its handler synchronously and return the handler's result.

    Unknown steps (including non-string values) produce an error dict.
    """
    step_table = {
        "preprocess": handle_preprocess,
        "preprocess_multi": handle_preprocess_multi,
        "generate": handle_generate,
        "generate_multi": handle_generate_multi,
        "export": handle_export,
        "cleanup": handle_cleanup,
        "full_pipeline": handle_full_pipeline,
    }
    # Only strings can name a step; this also keeps unhashable JSON values out of the lookup.
    step_handler = step_table.get(step) if isinstance(step, str) else None
    if step_handler is None:
        return {"error": f"Unknown step: {step}"}
    return step_handler(input_data)
async def handler(job: Dict[str, Any]) -> Dict[str, Any]:
    """Async entry point for RunPod serverless jobs.

    Runs the requested step in a worker thread (so the event loop can serve
    other concurrent jobs), then starts a background sweep of expired
    sessions. Any exception is returned as ``{"error", "details"}`` instead
    of failing the job.
    """
    try:
        job_id = str(job.get("id", "unknown"))
        # A job may carry "input": null; treat it like an empty payload.
        input_data = job.get("input") or {}
        step = input_data.get("step", "preprocess")
        # str() for logging only: a non-string session_id must not crash the log line.
        session_label = str(input_data.get("session_id", "new"))
        print(f"[JOB] {job_id[:8]} | step={step} | session={session_label[:8]}", flush=True)
        result = await asyncio.to_thread(_run_step, step, input_data)
        # Cleanup AFTER the job is done (zero impact on generation).
        threading.Thread(target=cleanup_old_sessions, daemon=True).start()
        print(f"[JOB] {job_id[:8]} | done", flush=True)
        return result
    except Exception as e:
        import traceback
        print(f"[ERROR] {e}", flush=True)
        return {"error": str(e), "details": traceback.format_exc()}
def concurrency_modifier(current_concurrency: int) -> int:
    """Report to RunPod how many jobs this worker may run at once.

    The limit is the static ``MAX_CONCURRENT_JOBS`` setting; the current
    concurrency that RunPod passes in is not consulted.
    """
    job_limit = MAX_CONCURRENT_JOBS
    return job_limit
if __name__ == "__main__":
    print(f"[STARTUP] MAX_CONCURRENT_JOBS={MAX_CONCURRENT_JOBS}", flush=True)
    # Register the async handler and the concurrency callback with RunPod.
    worker_config = {
        "handler": handler,
        "concurrency_modifier": concurrency_modifier,
    }
    runpod.serverless.start(worker_config)