| import os |
| import shutil |
| import sys |
| import traceback |
| import uuid |
| from pathlib import Path |
| from typing import * |
|
|
| import gradio as gr |
| import numpy as np |
| import rembg |
| import spaces |
| import torch |
| import trimesh |
| from PIL import Image |
| from gradio_litmodel3d import LitModel3D |
|
|
| sys.path.append(os.getcwd()) |
| sys.path.append(os.path.join(os.getcwd(), 'third_parties/dsine')) |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
|
|
# Inclusive upper bound of the Seed slider (randomized seeds derive from it too).
MAX_SEED = 100
# Per-session scratch space lives next to this file, under ./tmp.
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
os.makedirs(TMP_DIR, exist_ok=True)


# Checkpoint names selectable in the UI; each resolves to ckpts/anigen/<name>.
SS_MODEL_CHOICES = ["ss_flow_duet", "ss_flow_solo", "ss_flow_epic"]
SLAT_MODEL_CHOICES = ["slat_flow_auto", "slat_flow_control"]
DEFAULT_SS_MODEL = "ss_flow_duet"
DEFAULT_SLAT_MODEL = "slat_flow_auto"


# Process-wide lazy singletons: the checkpoint names currently loaded into the
# pipeline, the pipeline itself, and the rembg session (see get_pipeline /
# get_rembg_session for initialization).
current_ss_model_name = DEFAULT_SS_MODEL
current_slat_model_name = DEFAULT_SLAT_MODEL
pipeline = None
rembg_session = None
|
|
|
|
def get_runtime_device() -> str:
    """Return 'cuda' when a CUDA device is available, otherwise 'cpu'."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
|
|
|
|
def get_session_dir(session_id: Optional[str]) -> str:
    """Resolve (and create if missing) the per-session temp directory.

    Falls back to a fresh random hex id when no session id is provided.
    """
    session = session_id if session_id else uuid.uuid4().hex
    session_dir = os.path.join(TMP_DIR, session)
    os.makedirs(session_dir, exist_ok=True)
    return session_dir
|
|
|
|
def get_pipeline(device: Optional[str] = None):
    """Return the process-wide AniGen pipeline, creating it on first use.

    The pipeline is cached in the module-level ``pipeline`` global. The first
    call builds it from the default SS/SLAT checkpoints; later calls move it
    to ``device`` if it currently sits on a different one.

    Args:
        device: Target device string ('cuda' / 'cpu'); defaults to whatever
            get_runtime_device() reports.

    Returns:
        The cached AnigenImageTo3DPipeline instance.
    """
    global pipeline

    device = device or get_runtime_device()

    if pipeline is None:
        # Imported lazily so the heavy model stack is only pulled in when a
        # pipeline is actually needed (e.g. inside the ZeroGPU worker).
        from anigen.pipelines import AnigenImageTo3DPipeline

        print(f"[app_hf] Initializing pipeline on {device}...", flush=True)
        pipeline = AnigenImageTo3DPipeline.from_pretrained(
            ss_flow_path=f'ckpts/anigen/{DEFAULT_SS_MODEL}',
            slat_flow_path=f'ckpts/anigen/{DEFAULT_SLAT_MODEL}',
            device=device,
            use_ema=False,
        )
        print(f"[app_hf] Pipeline initialized on {device}.", flush=True)
    elif pipeline.device.type != device:
        # A cached pipeline exists but lives on the wrong device — migrate it.
        print(f"[app_hf] Moving pipeline from {pipeline.device.type} to {device}...", flush=True)
        pipeline.to(torch.device(device))
        print(f"[app_hf] Pipeline moved to {device}.", flush=True)

    return pipeline
|
|
|
|
def get_rembg_session():
    """Lazily create and cache the rembg background-removal session.

    Returns:
        The cached rembg session backed by the 'birefnet-general' model.
    """
    global rembg_session
    if rembg_session is None:
        # Fix: the old log claimed 'u2net' but the session actually loads
        # 'birefnet-general' — keep the message in sync with the model name.
        print("[app_hf] Initializing rembg birefnet-general session on CPU...", flush=True)
        rembg_session = rembg.new_session("birefnet-general")
        print("[app_hf] rembg session ready.", flush=True)
    return rembg_session
|
|
|
|
def start_session(req: gr.Request):
    """Gradio load hook: ensure this session's temp directory exists."""
    _ = get_session_dir(req.session_hash)
|
|
|
|
def end_session(req: gr.Request):
    """Gradio unload hook: delete this session's temp directory."""
    session_dir = get_session_dir(req.session_hash)
    shutil.rmtree(session_dir, ignore_errors=True)
|
|
|
|
def preprocess_for_display_and_inference(image: Optional[Image.Image]) -> Tuple[Optional[Image.Image], Optional[Image.Image]]:
    """Normalize an upload into a 518x518 RGBA crop of the foreground subject.

    If the upload already carries a non-trivial alpha channel it is trusted as
    the foreground mask; otherwise the background is removed with rembg. The
    foreground bounding box is then cropped with a 20% margin and resized.

    Returns:
        (display_image, rgba_image): an RGB composite for the UI preview and
        the RGBA image used for inference; (None, None) when no image given.
    """
    if image is None:
        return None, None

    # Treat the alpha channel as meaningful only when at least one pixel is
    # not fully opaque; an all-255 alpha is just an RGB image in disguise.
    has_alpha = False
    if image.mode == 'RGBA':
        alpha = np.array(image)[:, :, 3]
        if not np.all(alpha == 255):
            has_alpha = True

    if has_alpha:
        rgba_output = image.convert('RGBA')
    else:
        # No usable alpha: cap the longest side at 1024px, then run
        # background removal to synthesize the alpha channel.
        input_image = image.convert('RGB')
        max_size = max(input_image.size)
        scale = min(1, 1024 / max_size)
        if scale < 1:
            input_image = input_image.resize(
                (int(input_image.width * scale), int(input_image.height * scale)),
                Image.Resampling.LANCZOS,
            )
        rgba_output = rembg.remove(input_image, session=get_rembg_session())

    # Bounding box of confidently-foreground pixels (alpha > 80% of 255).
    output_np = np.array(rgba_output)
    alpha = output_np[:, :, 3]
    bbox = np.argwhere(alpha > 0.8 * 255)
    if len(bbox) == 0:
        # Nothing detected — keep the full frame rather than cropping to nothing.
        bbox_crop = (0, 0, rgba_output.width, rgba_output.height)
    else:
        # argwhere yields (row, col); reorder into (left, top, right, bottom).
        bbox = np.min(bbox[:, 1]), np.min(bbox[:, 0]), np.max(bbox[:, 1]), np.max(bbox[:, 0])
        center = (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2
        size = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
        size = int(size * 1.2)  # square box with 20% margin
        # NOTE(review): this box may extend past the image edges; PIL's crop
        # pads out-of-range regions, which here yields transparent borders.
        bbox_crop = (
            int(center[0] - size // 2),
            int(center[1] - size // 2),
            int(center[0] + size // 2),
            int(center[1] + size // 2),
        )

    rgba_output = rgba_output.crop(bbox_crop)
    rgba_output = rgba_output.resize((518, 518), Image.Resampling.LANCZOS)

    # Display copy: composite RGB onto black by multiplying with alpha.
    display_np = np.array(rgba_output).astype(np.float32) / 255
    display_np = display_np[:, :, :3] * display_np[:, :, 3:4]
    display_image = Image.fromarray((display_np * 255).astype(np.uint8))
    return display_image, rgba_output
|
|
|
|
def save_processed_rgba(image: Optional[Image.Image], session_id: Optional[str] = None) -> Optional[str]:
    """Persist the processed RGBA image into the session dir; return its path.

    Returns None when there is no image to save.
    """
    if image is None:
        return None
    target_path = os.path.join(get_session_dir(session_id), 'processed_input_rgba.png')
    image.save(target_path)
    return target_path
|
|
|
|
def load_processed_rgba(path: Optional[str]) -> Optional[Image.Image]:
    """Reload the cached processed RGBA image, or None when unavailable."""
    if not path:
        return None
    candidate = Path(path)
    if not candidate.exists():
        return None
    return Image.open(candidate).convert('RGBA')
|
|
|
|
def prepare_input_for_generation(image: Optional[Image.Image], req: gr.Request = None) -> Tuple[Optional[str], Optional[Image.Image]]:
    """CPU-side preparation run before the ZeroGPU stage.

    Removes the background / crops the upload, saves the RGBA result into the
    session directory, and returns (saved path, display image).
    """
    print('[app_hf] Preparing input on CPU before GPU stage...', flush=True)
    session_hash = req.session_hash if req else None
    display_image, rgba_image = preprocess_for_display_and_inference(image)
    rgba_path = save_processed_rgba(rgba_image, session_hash)
    print(f'[app_hf] CPU preprocessing completed. path={rgba_path}', flush=True)
    return rgba_path, display_image
|
|
|
|
def get_seed(randomize_seed: bool, seed: int) -> int:
    """Return the seed to use for generation.

    Fix: np.random.randint's upper bound is exclusive, so the old call could
    never yield MAX_SEED even though the Seed slider allows it; draw from
    [0, MAX_SEED] inclusive instead.
    """
    if randomize_seed:
        # Cast so downstream components receive a plain int, not np.int64.
        return int(np.random.randint(0, MAX_SEED + 1))
    return seed
|
|
|
|
def on_slat_model_change(slat_model_name: str):
    """Enable the joints-density slider only for the control SLAT model.

    Returns updates for (joints_density slider, density hint markdown); the
    hint is shown exactly when the slider is disabled.
    """
    control_selected = slat_model_name == 'slat_flow_control'
    slider_update = gr.update(interactive=control_selected)
    hint_update = gr.update(visible=not control_selected)
    return slider_update, hint_update
|
|
|
|
def save_generation_state(state: Dict[str, Any], session_id: Optional[str]) -> str:
    """Serialize intermediate generation state into the session directory."""
    target_path = os.path.join(get_session_dir(session_id), 'generation_state.pt')
    torch.save(state, target_path)
    return target_path
|
|
|
|
def load_generation_state(state_path: str) -> Dict[str, Any]:
    """Read a generation-state checkpoint, mapping all tensors onto the CPU."""
    state: Dict[str, Any] = torch.load(state_path, map_location='cpu')
    return state
|
|
|
|
def export_preview_assets(
    user_dir: str,
    orig_vertices: np.ndarray,
    orig_faces: np.ndarray,
    joints: np.ndarray,
    parents: np.ndarray,
    skin_weights: np.ndarray,
    vertex_colors: Optional[np.ndarray],
) -> Tuple[str, Optional[str]]:
    """Write the rigged preview mesh GLB and, when possible, a skeleton GLB.

    Returns:
        (mesh_path, skeleton_path); skeleton_path is None when the skeleton
        visualization produced no geometry.
    """
    from anigen.utils.export_utils import convert_to_glb_from_data, visualize_skeleton_as_mesh

    mesh_path = os.path.join(user_dir, 'preview_mesh.glb')
    skeleton_path: Optional[str] = os.path.join(user_dir, 'preview_skeleton.glb')

    # Rigged preview: raw (unsimplified) geometry plus skinning data, untextured.
    convert_to_glb_from_data(
        trimesh.Trimesh(vertices=orig_vertices, faces=orig_faces, process=False),
        joints,
        parents,
        skin_weights,
        mesh_path,
        vertex_colors=vertex_colors,
        texture_image=None,
    )

    # Standalone skeleton visualization; skip the export when it is empty.
    skeleton = visualize_skeleton_as_mesh(joints, parents)
    if skeleton is None or len(skeleton.vertices) == 0:
        skeleton_path = None
    else:
        skeleton.export(skeleton_path)

    return mesh_path, skeleton_path
|
|
|
|
def get_user_dir_from_artifact_path(path: Optional[str]) -> str:
    """Return the fully-resolved parent directory of an artifact file.

    Raises:
        gr.Error: when no path was supplied.
    """
    if not path:
        raise gr.Error('Missing intermediate artifact path.')
    resolved = Path(path).resolve()
    return str(resolved.parent)
|
|
|
|
def update_download_buttons(mesh_path: Optional[str], skel_path: Optional[str]):
    """Point the download buttons at the exported files; disable when missing."""
    mesh_btn = gr.update(value=mesh_path, interactive=bool(mesh_path))
    skel_btn = gr.update(value=skel_path, interactive=bool(skel_path))
    return mesh_btn, skel_btn
|
|
|
|
def disable_download_buttons():
    """Clear and grey out both download buttons (their artifacts are stale)."""
    mesh_btn = gr.update(value=None, interactive=False)
    skel_btn = gr.update(value=None, interactive=False)
    return mesh_btn, skel_btn
|
|
|
|
def mark_generate_queued(ss_sampling_steps: int, slat_sampling_steps: int):
    """Status text shown while the preview job waits for a ZeroGPU slot."""
    message = (
        'CPU preprocessing done. Waiting for ZeroGPU allocation for preview generation '
        f'(SS steps: {ss_sampling_steps}, SLat steps: {slat_sampling_steps}).'
    )
    return message
|
|
|
|
def mark_extract_queued(texture_size: int, simplify_ratio: float, fill_holes: bool):
    """Status text shown while the extraction job waits for a ZeroGPU slot."""
    message = (
        'Waiting for ZeroGPU allocation for final GLB extraction '
        f'(texture: {texture_size}, simplify: {simplify_ratio:.2f}, fill_holes: {fill_holes}).'
    )
    return message
|
|
|
|
@spaces.GPU(duration=120)
def generate_preview(
    processed_input_rgba_path: Optional[str],
    seed: int,
    ss_model_name: str,
    slat_model_name: str,
    ss_guidance_strength: float,
    ss_sampling_steps: int,
    slat_guidance_strength: float,
    slat_sampling_steps: int,
    joints_density: int,
    progress=gr.Progress(track_tqdm=False),
) -> Tuple[str, str, Optional[str], str]:
    """GPU stage 1: sample a rigged preview mesh from the processed image.

    Runs the SS -> SLat -> decode chain, exports preview GLBs into the
    session directory, and saves intermediate tensors (moved to CPU) so
    extract_glb can later bake textures without re-sampling.

    Returns:
        (state_path, preview_mesh_path, preview_skeleton_path, status_text);
        preview_skeleton_path may be None.

    Raises:
        gr.Error: when the preprocessed input image is missing.
    """
    global current_ss_model_name, current_slat_model_name

    print('[app_hf] generate_preview: entered GPU function, requesting pipeline...', flush=True)
    try:
        from anigen.utils.skin_utils import repair_skeleton_parents

        device = get_runtime_device()
        print(f'[app_hf] generate_preview started on device={device}', flush=True)

        processed_input_rgba = load_processed_rgba(processed_input_rgba_path)
        if processed_input_rgba is None:
            raise gr.Error('Missing processed input image. Please upload an image and click Generate again.')

        print('[app_hf] processed input loaded; initializing pipeline next.', flush=True)
        pipe = get_pipeline(device)

        # Hot-swap checkpoints only when the dropdown selection changed since
        # the last run (module-level globals track what is currently loaded).
        if ss_model_name != current_ss_model_name:
            progress(0, desc=f'Loading SS model: {ss_model_name}...')
            pipe.load_ss_flow_model(f'ckpts/anigen/{ss_model_name}', device=device, use_ema=False)
            current_ss_model_name = ss_model_name

        if slat_model_name != current_slat_model_name:
            progress(0, desc=f'Loading SLAT model: {slat_model_name}...')
            pipe.load_slat_flow_model(f'ckpts/anigen/{slat_model_name}', device=device, use_ema=False)
            current_slat_model_name = slat_model_name

        # Seed both torch and numpy so sampling is reproducible per seed.
        torch.manual_seed(seed)
        np.random.seed(seed)

        progress(0.02, desc='Estimating normals...')
        processed_image, processed_normal = pipe.preprocess_image(processed_input_rgba)
        print('[app_hf] preprocessing on GPU worker finished.', flush=True)

        progress(0.08, desc='Encoding image conditions...')
        cond_dict_ss, cond_dict_slat_rgb = pipe.get_cond(processed_image, processed_normal)
        print('[app_hf] conditioning ready.', flush=True)

        # Progress mapping: SS sampling spans 10%-50%, SLat spans 50%-90%.
        def ss_progress_callback(step, total):
            frac = (step + 1) / total
            progress(0.10 + frac * 0.40, desc=f'SS Sampling: {step + 1}/{total}')

        def slat_progress_callback(step, total):
            frac = (step + 1) / total
            progress(0.50 + frac * 0.40, desc=f'SLat Sampling: {step + 1}/{total}')

        coords, coords_skl, _, _ = pipe.sample_sparse_structure(
            cond_dict_ss,
            strength=ss_guidance_strength,
            steps=ss_sampling_steps,
            progress_callback=ss_progress_callback,
        )
        print('[app_hf] sparse structure sampled.', flush=True)

        slat, slat_skl = pipe.sample_slat(
            cond_dict_slat_rgb,
            coords,
            coords_skl,
            strength=slat_guidance_strength,
            steps=slat_sampling_steps,
            joint_density=joints_density,
            progress_callback=slat_progress_callback,
        )
        print('[app_hf] slat sampled.', flush=True)

        progress(0.92, desc='Decoding preview mesh...')
        mesh_result, skeleton_result = pipe.decode_slat(slat, slat_skl)
        print('[app_hf] decode finished.', flush=True)

        # Detach everything to CPU: results must outlive this ZeroGPU call.
        joints = skeleton_result.joints_grouped.detach().cpu().to(torch.float32)
        parents = skeleton_result.parents_grouped.detach().cpu().to(torch.int32)
        # Repair invalid parent links before exporting / rigging.
        parents_np = repair_skeleton_parents(
            joints=joints.numpy(),
            parents=parents.numpy(),
            verbose=False,
        ).astype(np.int32)
        parents = torch.from_numpy(parents_np)
        skin_weights = skeleton_result.skin_pred.detach().cpu().to(torch.float32)
        orig_vertices = mesh_result.vertices.detach().cpu().to(torch.float32)
        orig_faces = mesh_result.faces.detach().cpu().to(torch.long)
        vertex_attrs = None
        if getattr(mesh_result, 'vertex_attrs', None) is not None:
            vertex_attrs = mesh_result.vertex_attrs.detach().cpu().to(torch.float32)

        # First three attribute channels are used as vertex colors when present.
        vertex_colors = None
        if vertex_attrs is not None and vertex_attrs.shape[-1] >= 3:
            vertex_colors = vertex_attrs[:, :3].numpy()

        user_dir = get_user_dir_from_artifact_path(processed_input_rgba_path)
        preview_mesh_path, preview_skeleton_path = export_preview_assets(
            user_dir=user_dir,
            orig_vertices=orig_vertices.numpy(),
            orig_faces=orig_faces.numpy(),
            joints=joints.numpy(),
            parents=parents.numpy(),
            skin_weights=skin_weights.numpy(),
            vertex_colors=vertex_colors,
        )

        # Persist CPU tensors so the later extract_glb stage can resume here.
        state = {
            'orig_vertices': orig_vertices.contiguous(),
            'orig_faces': orig_faces.contiguous(),
            'vertex_attrs': vertex_attrs.contiguous() if vertex_attrs is not None else None,
            'joints': joints.contiguous(),
            'parents': parents.contiguous(),
            'skin_weights': skin_weights.contiguous(),
            'mesh_res': int(getattr(mesh_result, 'res', 64)),
        }
        state_path = os.path.join(user_dir, 'generation_state.pt')
        torch.save(state, state_path)
        print(f'[app_hf] Preview ready. State saved to {state_path}', flush=True)

        status = 'Preview ready. Click “Extract GLB” to run simplification and texture baking.'
        return state_path, preview_mesh_path, preview_skeleton_path, status
    except Exception as exc:
        # Log the full traceback to the Space logs, then let Gradio surface it.
        print(f'[app_hf] generate_preview failed: {exc}', flush=True)
        print(traceback.format_exc(), flush=True)
        raise
    finally:
        # Always release cached CUDA memory before the ZeroGPU slot is returned.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
|
|
|
|
@spaces.GPU(duration=120)
def extract_glb(
    generation_state_path: Optional[str],
    texture_size: int,
    simplify_ratio: float,
    fill_holes: bool,
    progress=gr.Progress(track_tqdm=False),
) -> Tuple[str, Optional[str], str]:
    """GPU stage 2: simplify, UV-unwrap and texture-bake the final GLB.

    Consumes the state file written by generate_preview; never re-samples.

    Args:
        generation_state_path: Path to the stage-1 'generation_state.pt'.
        texture_size: Baked texture resolution; baking runs only when > 0.
        simplify_ratio: Simplification ratio passed to postprocess_mesh
            (0 disables simplification).
        fill_holes: Whether to close holes during mesh postprocessing.
        progress: Gradio progress reporter.

    Returns:
        (mesh_glb_path, skeleton_glb_path_or_None, status_text).

    Raises:
        gr.Error: when no preview state exists yet.
    """
    try:
        # Heavy project imports kept inside the GPU worker.
        # NOTE(review): filter_skinning_weights is imported but never used below.
        from anigen.representations.mesh.cube2mesh_skeleton import AniGenMeshExtractResult
        from anigen.utils.export_utils import convert_to_glb_from_data, visualize_skeleton_as_mesh
        from anigen.utils.postprocessing_utils import (
            bake_texture,
            barycentric_transfer_attributes,
            parametrize_mesh,
            postprocess_mesh,
        )
        from anigen.utils.render_utils import render_multiview
        from anigen.utils.skin_utils import (
            filter_skinning_weights,
            repair_skeleton_parents,
            smooth_skin_weights_on_mesh,
        )

        if not generation_state_path or not os.path.exists(generation_state_path):
            raise gr.Error('Please click Generate first to create a preview state.')

        device = get_runtime_device()
        print(f'[app_hf] extract_glb started on device={device}', flush=True)
        state = load_generation_state(generation_state_path)

        # Restore stage-1 outputs (saved as CPU tensors) as numpy arrays.
        orig_vertices = state['orig_vertices'].numpy()
        orig_faces = state['orig_faces'].numpy()
        joints = state['joints'].numpy()
        # Repair parent links again after loading from disk.
        parents = repair_skeleton_parents(
            joints=joints,
            parents=state['parents'].numpy().astype(np.int32),
            verbose=False,
        ).astype(np.int32)
        skin_weights = state['skin_weights'].numpy()
        vertex_attrs_cpu = state.get('vertex_attrs')
        vertex_colors = None
        if vertex_attrs_cpu is not None and vertex_attrs_cpu.shape[-1] >= 3:
            vertex_colors = vertex_attrs_cpu[:, :3].numpy()

        user_dir = get_user_dir_from_artifact_path(generation_state_path)
        output_glb_path = os.path.join(user_dir, 'mesh.glb')
        skeleton_glb_path = os.path.join(user_dir, 'skeleton.glb')

        progress(0.02, desc='Simplifying mesh...')
        new_vertices, new_faces = postprocess_mesh(
            orig_vertices,
            orig_faces,
            simplify=(simplify_ratio > 0),
            simplify_ratio=simplify_ratio,
            fill_holes=fill_holes,
            verbose=True,
        )

        # If simplification changed the vertex set, re-map per-vertex
        # attributes onto the new vertices via barycentric interpolation.
        if new_vertices.shape[0] != orig_vertices.shape[0]:
            orig_mesh = trimesh.Trimesh(vertices=orig_vertices, faces=orig_faces, process=False)
            skin_weights = barycentric_transfer_attributes(orig_mesh, skin_weights, new_vertices)
            if vertex_colors is not None:
                vertex_colors = barycentric_transfer_attributes(orig_mesh, vertex_colors, new_vertices)

        mesh = trimesh.Trimesh(vertices=new_vertices, faces=new_faces, process=False)

        progress(0.42, desc='Smoothing skin weights...')
        skin_weights = smooth_skin_weights_on_mesh(
            mesh,
            skin_weights,
            iterations=100,
            alpha=1.0,
        )

        texture_image = None
        if int(texture_size) > 0:
            # UV-unwrap; vmapping maps UV-vertices back to mesh vertices, so
            # per-vertex attributes must be re-indexed accordingly.
            progress(0.55, desc='Parameterizing UVs...')
            uv_vertices, uv_faces, uvs, vmapping = parametrize_mesh(new_vertices, new_faces)
            skin_weights = skin_weights[vmapping]
            if vertex_colors is not None:
                vertex_colors = vertex_colors[vmapping]

            # Rebuild the original (unsimplified) mesh on the compute device
            # to act as the teacher for multiview rendering.
            vertex_attrs = None
            if vertex_attrs_cpu is not None:
                vertex_attrs = vertex_attrs_cpu.to(device=device, dtype=torch.float32)
            mesh_result = AniGenMeshExtractResult(
                vertices=torch.as_tensor(orig_vertices, device=device, dtype=torch.float32),
                faces=torch.as_tensor(orig_faces, device=device, dtype=torch.long),
                vertex_attrs=vertex_attrs,
                res=int(state.get('mesh_res', 64)),
            )

            progress(0.65, desc='Rendering teacher views...')
            observations, extrinsics_mv, intrinsics_mv = render_multiview(
                mesh_result,
                resolution=1024,
                nviews=100,
            )
            # Foreground masks: any non-black pixel counts as covered.
            masks = [np.any(obs > 0, axis=-1) for obs in observations]
            extrinsics_np = [e.detach().cpu().numpy() for e in extrinsics_mv]
            intrinsics_np = [i.detach().cpu().numpy() for i in intrinsics_mv]

            progress(0.78, desc='Baking texture...')
            # Optimization-based baking ('opt' mode) needs autograd enabled.
            with torch.enable_grad():
                texture_image = bake_texture(
                    uv_vertices,
                    uv_faces,
                    uvs,
                    observations,
                    masks,
                    extrinsics_np,
                    intrinsics_np,
                    texture_size=int(texture_size),
                    mode='opt',
                    lambda_tv=0.01,
                    verbose=True,
                )

            # Swap in the UV-carrying mesh for the final export.
            mesh = trimesh.Trimesh(
                vertices=uv_vertices,
                faces=uv_faces,
                visual=trimesh.visual.TextureVisuals(uv=uvs),
                process=False,
            )

        progress(0.94, desc='Exporting GLB...')
        convert_to_glb_from_data(
            mesh,
            joints,
            parents,
            skin_weights,
            output_glb_path,
            vertex_colors=vertex_colors,
            texture_image=texture_image,
        )

        # Standalone skeleton GLB; skipped when the visualization is empty.
        skeleton_mesh = visualize_skeleton_as_mesh(joints, parents)
        if skeleton_mesh is not None and len(skeleton_mesh.vertices) > 0:
            skeleton_mesh.export(skeleton_glb_path)
        else:
            skeleton_glb_path = None

        # Release cached CUDA memory before the ZeroGPU slot is returned.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        print('[app_hf] Final GLB extraction completed.', flush=True)
        return output_glb_path, skeleton_glb_path, 'Final GLB ready with mesh simplification and texture baking.'
    except Exception as exc:
        # Log the full traceback to the Space logs, then let Gradio surface it.
        print(f'[app_hf] extract_glb failed: {exc}', flush=True)
        print(traceback.format_exc(), flush=True)
        raise
|
|
|
|
# ---------------------------------------------------------------------------
# UI layout and event wiring.
# ---------------------------------------------------------------------------
with gr.Blocks(delete_cache=(600, 600)) as demo:
    gr.Markdown(
        """
        ## Image to 3D Asset with [AniGen]
        * Click **Generate** for a fast preview, then click **Extract GLB** for the full textured glb file export.
        * [AniGen GitHub Repository](https://github.com/VAST-AI-Research/AniGen)
        * [Tripo: Your 3D Workspace with AI](https://www.tripo3d.ai)
        """
    )

    # Tip banner with a pure-CSS pulse animation on the emoji.
    gr.HTML("""
    <style>
    @keyframes gentle-pulse {
        0%, 100% { opacity: 1; }
        50% { opacity: 0.35; }
    }
    </style>
    <div style="text-align:left; color:#888; font-size:1em; line-height:1.6; margin-bottom:-8px;">
        <span style="animation: gentle-pulse 3s ease-in-out infinite; display:inline-block;">💡 <b>Tip</b></span>  
        Not satisfied with the geometry or skeleton?
        Try switching the SS Model to <code>ss_flow_solo</code> or <code>ss_flow_duet</code> in Generation Settings.
    </div>
    """)

    with gr.Row():
        with gr.Column():
            # Hidden per-session state: file paths produced by the CPU
            # preprocessing step and by the GPU preview step.
            processed_input_path_state = gr.State(value=None)
            generation_state_path = gr.State(value=None)
            image_prompt = gr.Image(label='Image Prompt', format='png', image_mode='RGBA', type='pil', height=300)

            with gr.Accordion(label='Generation Settings', open=True):
                seed = gr.Slider(0, MAX_SEED, label='Seed', value=42, step=1)
                randomize_seed = gr.Checkbox(label='Randomize Seed', value=False)

                gr.Markdown('**Model Selection**')
                with gr.Row():
                    ss_model_dropdown = gr.Dropdown(
                        choices=SS_MODEL_CHOICES,
                        value=DEFAULT_SS_MODEL,
                        label='SS Model (Sparse Structure)',
                    )
                    slat_model_dropdown = gr.Dropdown(
                        choices=SLAT_MODEL_CHOICES,
                        value=DEFAULT_SLAT_MODEL,
                        label='SLAT Model (Structured Latent)',
                    )

                gr.Markdown('Stage 1: Sparse Structure Generation')
                with gr.Row():
                    ss_guidance_strength = gr.Slider(0.0, 15.0, label='Guidance Strength', value=7.5, step=0.1)
                    ss_sampling_steps = gr.Slider(1, 50, label='Sampling Steps', value=25, step=1)

                gr.Markdown('Stage 2: Structured Latent Generation')
                with gr.Row():
                    slat_guidance_strength = gr.Slider(0.0, 10.0, label='Guidance Strength', value=3.0, step=0.1)
                    slat_sampling_steps = gr.Slider(1, 50, label='Sampling Steps', value=25, step=1)

                gr.Markdown('Skeleton & Skinning Settings')
                # Interactive only for the control SLAT model (see
                # on_slat_model_change wiring below).
                joints_density = gr.Slider(0, 4, label='Joints Density', value=1, step=1, interactive=False)
                density_hint = gr.Markdown(
                    '*Switch `SLAT Model` to `slat_flow_control` to enable joint density control.*',
                    visible=True,
                )

            with gr.Accordion(label='Extraction Settings', open=False):
                simplify_ratio = gr.Slider(0.0, 0.99, label='Mesh Simplification Ratio', value=0.95, step=0.01)
                fill_holes = gr.Checkbox(label='Fill Holes', value=True)
                texture_size = gr.Slider(256, 2048, label='Texture Resolution', value=1024, step=256)

            with gr.Row():
                generate_btn = gr.Button('Generate')
                extract_btn = gr.Button('Extract GLB')

        with gr.Column():
            # Output panels: preview mesh, skeleton, processed input, status.
            mesh_output = gr.Model3D(label="Generated Mesh", height=300, interactive=False)
            download_mesh = gr.DownloadButton(label='Download Mesh GLB', interactive=False)
            skeleton_output = LitModel3D(label='Skeleton Preview / Final GLB', exposure=5.0, height=300, interactive=False)
            download_skeleton = gr.DownloadButton(label='Download Skeleton GLB', interactive=False)
            processed_image_output = gr.Image(label='Processed Image', type='pil', height=300)
            status_output = gr.Markdown('Upload an image, click **Generate**, then click **Extract GLB**.')

    with gr.Row() as single_image_example:
        gr.Examples(
            examples=[
                f'assets/cond_images/{image}'
                for image in os.listdir('assets/cond_images')
            ],
            inputs=[image_prompt],
            examples_per_page=64,
        )

    # Session lifecycle: create the temp dir on page load, remove it on leave.
    demo.load(start_session)
    demo.unload(end_session)

    slat_model_dropdown.change(
        on_slat_model_change,
        inputs=[slat_model_dropdown],
        outputs=[joints_density, density_hint],
    )

    # Generate chain: resolve seed -> CPU preprocess -> queued status ->
    # GPU preview -> invalidate now-stale download buttons.
    generate_btn.click(
        get_seed,
        inputs=[randomize_seed, seed],
        outputs=[seed],
    ).then(
        prepare_input_for_generation,
        inputs=[image_prompt],
        outputs=[processed_input_path_state, processed_image_output],
    ).then(
        mark_generate_queued,
        inputs=[ss_sampling_steps, slat_sampling_steps],
        outputs=[status_output],
    ).then(
        generate_preview,
        inputs=[
            processed_input_path_state,
            seed,
            ss_model_dropdown,
            slat_model_dropdown,
            ss_guidance_strength,
            ss_sampling_steps,
            slat_guidance_strength,
            slat_sampling_steps,
            joints_density,
        ],
        outputs=[
            generation_state_path,
            mesh_output,
            skeleton_output,
            status_output,
        ],
    ).then(
        disable_download_buttons,
        outputs=[download_mesh, download_skeleton],
    )

    # Extract chain: queued status -> GPU extraction -> enable downloads.
    extract_btn.click(
        mark_extract_queued,
        inputs=[texture_size, simplify_ratio, fill_holes],
        outputs=[status_output],
    ).then(
        extract_glb,
        inputs=[generation_state_path, texture_size, simplify_ratio, fill_holes],
        outputs=[mesh_output, skeleton_output, status_output],
    ).then(
        update_download_buttons,
        inputs=[mesh_output, skeleton_output],
        outputs=[download_mesh, download_skeleton],
    )
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
# Download/verify model checkpoints at import time. ckpt_utils is loaded as
# an isolated module straight from its file path rather than via
# `import anigen...` — presumably to avoid importing the heavy anigen package
# (and its model/CUDA dependencies) in the main process before the UI starts;
# confirm before changing this to a regular import.
import importlib.util as _iu
_spec = _iu.spec_from_file_location(
    'anigen_ckpt_utils_isolated',
    os.path.join(os.path.dirname(os.path.abspath(__file__)), 'anigen/utils/ckpt_utils.py'),
)
_mod = _iu.module_from_spec(_spec)
_spec.loader.exec_module(_mod)
_mod.ensure_ckpts()
del _iu, _spec, _mod  # keep the module namespace clean
|
|
|
|
if __name__ == '__main__':
    # Bind on all interfaces so the server is reachable inside the container;
    # share=True additionally creates a public gradio.live link.
    demo.launch(server_name='0.0.0.0', share=True)
|
|