import os import shutil import sys import traceback import uuid from pathlib import Path from typing import * import gradio as gr import numpy as np import rembg import spaces import torch import trimesh from PIL import Image from gradio_litmodel3d import LitModel3D sys.path.append(os.getcwd()) sys.path.append(os.path.join(os.getcwd(), 'third_parties/dsine')) # IMPORTANT: Do NOT import anything from `anigen.*` or `third_parties.*` at # module scope. `anigen/__init__.py` eagerly imports `anigen.models`, # `anigen.modules`, `anigen.pipelines`, etc., which pulls in `warp` and other # native libs. When warp imports it tries to init CUDA, and on ZeroGPU the main # process has no GPU, so it sets a bad global CUDA state. Any subsequent # `@spaces.GPU` forked worker then dies silently with "GPU task aborted" before # the task body runs. Keeping the main process free of `anigen` imports avoids # this. The worker imports `anigen` fresh and works correctly. MAX_SEED = 100 TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp') os.makedirs(TMP_DIR, exist_ok=True) SS_MODEL_CHOICES = ["ss_flow_duet", "ss_flow_solo", "ss_flow_epic"] SLAT_MODEL_CHOICES = ["slat_flow_auto", "slat_flow_control"] DEFAULT_SS_MODEL = "ss_flow_duet" DEFAULT_SLAT_MODEL = "slat_flow_auto" current_ss_model_name = DEFAULT_SS_MODEL current_slat_model_name = DEFAULT_SLAT_MODEL pipeline = None rembg_session = None def get_runtime_device() -> str: return "cuda" if torch.cuda.is_available() else "cpu" def get_session_dir(session_id: Optional[str]) -> str: target_session = session_id or uuid.uuid4().hex user_dir = os.path.join(TMP_DIR, target_session) os.makedirs(user_dir, exist_ok=True) return user_dir def get_pipeline(device: Optional[str] = None): # NOTE: This function must only be called from inside a `@spaces.GPU` # worker. The ZeroGPU pattern of pre-loading on CPU in the main process and # then calling `.to("cuda")` inside the worker crashes silently for the # AniGen pipeline (worker exits with "GPU task aborted" before any print). # We therefore lazily create a fresh pipeline *inside* the worker and cache # it in a module-level global so that subsequent reused workers skip the # 40+s load. global pipeline device = device or get_runtime_device() if pipeline is None: from anigen.pipelines import AnigenImageTo3DPipeline print(f"[app_hf] Initializing pipeline on {device}...", flush=True) pipeline = AnigenImageTo3DPipeline.from_pretrained( ss_flow_path=f'ckpts/anigen/{DEFAULT_SS_MODEL}', slat_flow_path=f'ckpts/anigen/{DEFAULT_SLAT_MODEL}', device=device, use_ema=False, ) print(f"[app_hf] Pipeline initialized on {device}.", flush=True) elif pipeline.device.type != device: print(f"[app_hf] Moving pipeline from {pipeline.device.type} to {device}...", flush=True) pipeline.to(torch.device(device)) print(f"[app_hf] Pipeline moved to {device}.", flush=True) return pipeline def get_rembg_session(): global rembg_session if rembg_session is None: print("[app_hf] Initializing rembg u2net session on CPU...", flush=True) rembg_session = rembg.new_session("birefnet-general") print("[app_hf] rembg session ready.", flush=True) return rembg_session def start_session(req: gr.Request): get_session_dir(req.session_hash) def end_session(req: gr.Request): shutil.rmtree(get_session_dir(req.session_hash), ignore_errors=True) def preprocess_for_display_and_inference(image: Optional[Image.Image]) -> Tuple[Optional[Image.Image], Optional[Image.Image]]: if image is None: return None, None has_alpha = False if image.mode == 'RGBA': alpha = np.array(image)[:, :, 3] if not np.all(alpha == 255): has_alpha = True if has_alpha: rgba_output = image.convert('RGBA') else: input_image = image.convert('RGB') max_size = max(input_image.size) scale = min(1, 1024 / max_size) if scale < 1: input_image = input_image.resize( (int(input_image.width * scale), int(input_image.height * scale)), Image.Resampling.LANCZOS, ) rgba_output = rembg.remove(input_image, session=get_rembg_session()) output_np = np.array(rgba_output) alpha = output_np[:, :, 3] bbox = np.argwhere(alpha > 0.8 * 255) if len(bbox) == 0: bbox_crop = (0, 0, rgba_output.width, rgba_output.height) else: bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0]) center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2 size = max(bbox[2] - bbox[0], bbox[3] - bbox[1]) size = int(size * 1.2) bbox_crop = ( int(center[0] - size // 2), int(center[1] - size // 2), int(center[0] + size // 2), int(center[1] + size // 2), ) rgba_output = rgba_output.crop(bbox_crop) rgba_output = rgba_output.resize((518, 518), Image.Resampling.LANCZOS) display_np = np.array(rgba_output).astype(np.float32) / 255 display_np = display_np[:, :, :3] * display_np[:, :, 3:4] display_image = Image.fromarray((display_np * 255).astype(np.uint8)) return display_image, rgba_output def save_processed_rgba(image: Optional[Image.Image], session_id: Optional[str] = None) -> Optional[str]: if image is None: return None file_path = os.path.join(get_session_dir(session_id), 'processed_input_rgba.png') image.save(file_path) return file_path def load_processed_rgba(path: Optional[str]) -> Optional[Image.Image]: if not path: return None file_path = Path(path) if not file_path.exists(): return None return Image.open(file_path).convert('RGBA') def prepare_input_for_generation(image: Optional[Image.Image], req: gr.Request = None) -> Tuple[Optional[str], Optional[Image.Image]]: print('[app_hf] Preparing input on CPU before GPU stage...', flush=True) processed_image, processed_rgba = preprocess_for_display_and_inference(image) processed_rgba_path = save_processed_rgba(processed_rgba, req.session_hash if req else None) print(f'[app_hf] CPU preprocessing completed. path={processed_rgba_path}', flush=True) return processed_rgba_path, processed_image def get_seed(randomize_seed: bool, seed: int) -> int: return np.random.randint(0, MAX_SEED) if randomize_seed else seed def on_slat_model_change(slat_model_name: str): is_control = (slat_model_name == 'slat_flow_control') return ( gr.update(interactive=is_control), gr.update(visible=not is_control), ) def save_generation_state(state: Dict[str, Any], session_id: Optional[str]) -> str: state_path = os.path.join(get_session_dir(session_id), 'generation_state.pt') torch.save(state, state_path) return state_path def load_generation_state(state_path: str) -> Dict[str, Any]: return torch.load(state_path, map_location='cpu') def export_preview_assets( user_dir: str, orig_vertices: np.ndarray, orig_faces: np.ndarray, joints: np.ndarray, parents: np.ndarray, skin_weights: np.ndarray, vertex_colors: Optional[np.ndarray], ) -> Tuple[str, Optional[str]]: # Lazy import: see the note at the top of the file about not importing # anigen in the main process. from anigen.utils.export_utils import convert_to_glb_from_data, visualize_skeleton_as_mesh preview_mesh_path = os.path.join(user_dir, 'preview_mesh.glb') preview_skeleton_path = os.path.join(user_dir, 'preview_skeleton.glb') preview_mesh = trimesh.Trimesh(vertices=orig_vertices, faces=orig_faces, process=False) convert_to_glb_from_data( preview_mesh, joints, parents, skin_weights, preview_mesh_path, vertex_colors=vertex_colors, texture_image=None, ) skeleton_mesh = visualize_skeleton_as_mesh(joints, parents) if skeleton_mesh is not None and len(skeleton_mesh.vertices) > 0: skeleton_mesh.export(preview_skeleton_path) else: preview_skeleton_path = None return preview_mesh_path, preview_skeleton_path def get_user_dir_from_artifact_path(path: Optional[str]) -> str: if not path: raise gr.Error('Missing intermediate artifact path.') return str(Path(path).resolve().parent) def update_download_buttons(mesh_path: Optional[str], skel_path: Optional[str]): return ( gr.update(value=mesh_path, interactive=bool(mesh_path)), gr.update(value=skel_path, interactive=bool(skel_path)), ) def disable_download_buttons(): return ( gr.update(value=None, interactive=False), gr.update(value=None, interactive=False), ) def mark_generate_queued(ss_sampling_steps: int, slat_sampling_steps: int): return ( f'CPU preprocessing done. Waiting for ZeroGPU allocation for preview generation ' f'(SS steps: {ss_sampling_steps}, SLat steps: {slat_sampling_steps}).' ) def mark_extract_queued(texture_size: int, simplify_ratio: float, fill_holes: bool): return ( f'Waiting for ZeroGPU allocation for final GLB extraction ' f'(texture: {texture_size}, simplify: {simplify_ratio:.2f}, fill_holes: {fill_holes}).' ) @spaces.GPU(duration=120) def generate_preview( processed_input_rgba_path: Optional[str], seed: int, ss_model_name: str, slat_model_name: str, ss_guidance_strength: float, ss_sampling_steps: int, slat_guidance_strength: float, slat_sampling_steps: int, joints_density: int, progress=gr.Progress(track_tqdm=False), ) -> Tuple[str, str, Optional[str], str]: global current_ss_model_name, current_slat_model_name print('[app_hf] generate_preview: entered GPU function, requesting pipeline...', flush=True) try: from anigen.utils.skin_utils import repair_skeleton_parents device = get_runtime_device() print(f'[app_hf] generate_preview started on device={device}', flush=True) processed_input_rgba = load_processed_rgba(processed_input_rgba_path) if processed_input_rgba is None: raise gr.Error('Missing processed input image. Please upload an image and click Generate again.') print('[app_hf] processed input loaded; initializing pipeline next.', flush=True) pipe = get_pipeline(device) if ss_model_name != current_ss_model_name: progress(0, desc=f'Loading SS model: {ss_model_name}...') pipe.load_ss_flow_model(f'ckpts/anigen/{ss_model_name}', device=device, use_ema=False) current_ss_model_name = ss_model_name if slat_model_name != current_slat_model_name: progress(0, desc=f'Loading SLAT model: {slat_model_name}...') pipe.load_slat_flow_model(f'ckpts/anigen/{slat_model_name}', device=device, use_ema=False) current_slat_model_name = slat_model_name torch.manual_seed(seed) np.random.seed(seed) progress(0.02, desc='Estimating normals...') processed_image, processed_normal = pipe.preprocess_image(processed_input_rgba) print('[app_hf] preprocessing on GPU worker finished.', flush=True) progress(0.08, desc='Encoding image conditions...') cond_dict_ss, cond_dict_slat_rgb = pipe.get_cond(processed_image, processed_normal) print('[app_hf] conditioning ready.', flush=True) def ss_progress_callback(step, total): frac = (step + 1) / total progress(0.10 + frac * 0.40, desc=f'SS Sampling: {step + 1}/{total}') def slat_progress_callback(step, total): frac = (step + 1) / total progress(0.50 + frac * 0.40, desc=f'SLat Sampling: {step + 1}/{total}') coords, coords_skl, _, _ = pipe.sample_sparse_structure( cond_dict_ss, strength=ss_guidance_strength, steps=ss_sampling_steps, progress_callback=ss_progress_callback, ) print('[app_hf] sparse structure sampled.', flush=True) slat, slat_skl = pipe.sample_slat( cond_dict_slat_rgb, coords, coords_skl, strength=slat_guidance_strength, steps=slat_sampling_steps, joint_density=joints_density, progress_callback=slat_progress_callback, ) print('[app_hf] slat sampled.', flush=True) progress(0.92, desc='Decoding preview mesh...') mesh_result, skeleton_result = pipe.decode_slat(slat, slat_skl) print('[app_hf] decode finished.', flush=True) joints = skeleton_result.joints_grouped.detach().cpu().to(torch.float32) parents = skeleton_result.parents_grouped.detach().cpu().to(torch.int32) parents_np = repair_skeleton_parents( joints=joints.numpy(), parents=parents.numpy(), verbose=False, ).astype(np.int32) parents = torch.from_numpy(parents_np) skin_weights = skeleton_result.skin_pred.detach().cpu().to(torch.float32) orig_vertices = mesh_result.vertices.detach().cpu().to(torch.float32) orig_faces = mesh_result.faces.detach().cpu().to(torch.long) vertex_attrs = None if getattr(mesh_result, 'vertex_attrs', None) is not None: vertex_attrs = mesh_result.vertex_attrs.detach().cpu().to(torch.float32) vertex_colors = None if vertex_attrs is not None and vertex_attrs.shape[-1] >= 3: vertex_colors = vertex_attrs[:, :3].numpy() user_dir = get_user_dir_from_artifact_path(processed_input_rgba_path) preview_mesh_path, preview_skeleton_path = export_preview_assets( user_dir=user_dir, orig_vertices=orig_vertices.numpy(), orig_faces=orig_faces.numpy(), joints=joints.numpy(), parents=parents.numpy(), skin_weights=skin_weights.numpy(), vertex_colors=vertex_colors, ) state = { 'orig_vertices': orig_vertices.contiguous(), 'orig_faces': orig_faces.contiguous(), 'vertex_attrs': vertex_attrs.contiguous() if vertex_attrs is not None else None, 'joints': joints.contiguous(), 'parents': parents.contiguous(), 'skin_weights': skin_weights.contiguous(), 'mesh_res': int(getattr(mesh_result, 'res', 64)), } state_path = os.path.join(user_dir, 'generation_state.pt') torch.save(state, state_path) print(f'[app_hf] Preview ready. State saved to {state_path}', flush=True) status = 'Preview ready. Click “Extract GLB” to run simplification and texture baking.' return state_path, preview_mesh_path, preview_skeleton_path, status except Exception as exc: print(f'[app_hf] generate_preview failed: {exc}', flush=True) print(traceback.format_exc(), flush=True) raise finally: # Don't move pipeline back to CPU: the ZeroGPU worker is reused across # calls, so keeping the pipeline on GPU lets subsequent calls skip the # 40+s load. Moving to CPU and back has also been shown to trigger # silent worker aborts for this pipeline. if torch.cuda.is_available(): torch.cuda.empty_cache() @spaces.GPU(duration=120) def extract_glb( generation_state_path: Optional[str], texture_size: int, simplify_ratio: float, fill_holes: bool, progress=gr.Progress(track_tqdm=False), ) -> Tuple[str, Optional[str], str]: try: from anigen.representations.mesh.cube2mesh_skeleton import AniGenMeshExtractResult from anigen.utils.export_utils import convert_to_glb_from_data, visualize_skeleton_as_mesh from anigen.utils.postprocessing_utils import ( bake_texture, barycentric_transfer_attributes, parametrize_mesh, postprocess_mesh, ) from anigen.utils.render_utils import render_multiview from anigen.utils.skin_utils import ( filter_skinning_weights, repair_skeleton_parents, smooth_skin_weights_on_mesh, ) if not generation_state_path or not os.path.exists(generation_state_path): raise gr.Error('Please click Generate first to create a preview state.') device = get_runtime_device() print(f'[app_hf] extract_glb started on device={device}', flush=True) state = load_generation_state(generation_state_path) orig_vertices = state['orig_vertices'].numpy() orig_faces = state['orig_faces'].numpy() joints = state['joints'].numpy() parents = repair_skeleton_parents( joints=joints, parents=state['parents'].numpy().astype(np.int32), verbose=False, ).astype(np.int32) skin_weights = state['skin_weights'].numpy() vertex_attrs_cpu = state.get('vertex_attrs') vertex_colors = None if vertex_attrs_cpu is not None and vertex_attrs_cpu.shape[-1] >= 3: vertex_colors = vertex_attrs_cpu[:, :3].numpy() user_dir = get_user_dir_from_artifact_path(generation_state_path) output_glb_path = os.path.join(user_dir, 'mesh.glb') skeleton_glb_path = os.path.join(user_dir, 'skeleton.glb') progress(0.02, desc='Simplifying mesh...') new_vertices, new_faces = postprocess_mesh( orig_vertices, orig_faces, simplify=(simplify_ratio > 0), simplify_ratio=simplify_ratio, fill_holes=fill_holes, verbose=True, ) if new_vertices.shape[0] != orig_vertices.shape[0]: orig_mesh = trimesh.Trimesh(vertices=orig_vertices, faces=orig_faces, process=False) skin_weights = barycentric_transfer_attributes(orig_mesh, skin_weights, new_vertices) if vertex_colors is not None: vertex_colors = barycentric_transfer_attributes(orig_mesh, vertex_colors, new_vertices) mesh = trimesh.Trimesh(vertices=new_vertices, faces=new_faces, process=False) # progress(0.30, desc='Filtering skin weights...') # skin_weights = filter_skinning_weights(mesh, skin_weights, joints, parents) progress(0.42, desc='Smoothing skin weights...') skin_weights = smooth_skin_weights_on_mesh( mesh, skin_weights, iterations=100, alpha=1.0, ) texture_image = None if int(texture_size) > 0: progress(0.55, desc='Parameterizing UVs...') uv_vertices, uv_faces, uvs, vmapping = parametrize_mesh(new_vertices, new_faces) skin_weights = skin_weights[vmapping] if vertex_colors is not None: vertex_colors = vertex_colors[vmapping] vertex_attrs = None if vertex_attrs_cpu is not None: vertex_attrs = vertex_attrs_cpu.to(device=device, dtype=torch.float32) mesh_result = AniGenMeshExtractResult( vertices=torch.as_tensor(orig_vertices, device=device, dtype=torch.float32), faces=torch.as_tensor(orig_faces, device=device, dtype=torch.long), vertex_attrs=vertex_attrs, res=int(state.get('mesh_res', 64)), ) progress(0.65, desc='Rendering teacher views...') observations, extrinsics_mv, intrinsics_mv = render_multiview( mesh_result, resolution=1024, nviews=100, ) masks = [np.any(obs > 0, axis=-1) for obs in observations] extrinsics_np = [e.detach().cpu().numpy() for e in extrinsics_mv] intrinsics_np = [i.detach().cpu().numpy() for i in intrinsics_mv] progress(0.78, desc='Baking texture...') with torch.enable_grad(): texture_image = bake_texture( uv_vertices, uv_faces, uvs, observations, masks, extrinsics_np, intrinsics_np, texture_size=int(texture_size), mode='opt', lambda_tv=0.01, verbose=True, ) mesh = trimesh.Trimesh( vertices=uv_vertices, faces=uv_faces, visual=trimesh.visual.TextureVisuals(uv=uvs), process=False, ) progress(0.94, desc='Exporting GLB...') convert_to_glb_from_data( mesh, joints, parents, skin_weights, output_glb_path, vertex_colors=vertex_colors, texture_image=texture_image, ) skeleton_mesh = visualize_skeleton_as_mesh(joints, parents) if skeleton_mesh is not None and len(skeleton_mesh.vertices) > 0: skeleton_mesh.export(skeleton_glb_path) else: skeleton_glb_path = None if torch.cuda.is_available(): torch.cuda.empty_cache() print('[app_hf] Final GLB extraction completed.', flush=True) return output_glb_path, skeleton_glb_path, 'Final GLB ready with mesh simplification and texture baking.' except Exception as exc: print(f'[app_hf] extract_glb failed: {exc}', flush=True) print(traceback.format_exc(), flush=True) raise with gr.Blocks(delete_cache=(600, 600)) as demo: gr.Markdown( """ ## Image to 3D Asset with [AniGen] * Click **Generate** for a fast preview, then click **Extract GLB** for the full textured glb file export. * [AniGen GitHub Repository](https://github.com/VAST-AI-Research/AniGen) * [Tripo: Your 3D Workspace with AI](https://www.tripo3d.ai) """ ) gr.HTML("""
ss_flow_solo or ss_flow_duet in Generation Settings.