# AniGen / app.py
# Yihua7 - Initial commit: AniGen - Animatable 3D Generation (6b92ff7)
import os
import shutil
import sys
import traceback
import uuid
from pathlib import Path
from typing import *
import gradio as gr
import numpy as np
import rembg
import spaces
import torch
import trimesh
from PIL import Image
from gradio_litmodel3d import LitModel3D
sys.path.append(os.getcwd())
sys.path.append(os.path.join(os.getcwd(), 'third_parties/dsine'))
# IMPORTANT: Do NOT import anything from `anigen.*` or `third_parties.*` at
# module scope. `anigen/__init__.py` eagerly imports `anigen.models`,
# `anigen.modules`, `anigen.pipelines`, etc., which pulls in `warp` and other
# native libs. When warp imports it tries to init CUDA, and on ZeroGPU the main
# process has no GPU, so it sets a bad global CUDA state. Any subsequent
# `@spaces.GPU` forked worker then dies silently with "GPU task aborted" before
# the task body runs. Keeping the main process free of `anigen` imports avoids
# this. The worker imports `anigen` fresh and works correctly.
# Upper bound used for the seed slider and for randomized seeds.
MAX_SEED = 100
# Root of the per-session scratch directories; lives next to this file.
TMP_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tmp')
os.makedirs(TMP_DIR, exist_ok=True)
# Checkpoint names offered in the UI; each one is loaded from `ckpts/anigen/<name>`.
SS_MODEL_CHOICES = ["ss_flow_duet", "ss_flow_solo", "ss_flow_epic"]
SLAT_MODEL_CHOICES = ["slat_flow_auto", "slat_flow_control"]
DEFAULT_SS_MODEL = "ss_flow_duet"
DEFAULT_SLAT_MODEL = "slat_flow_auto"
# Names of the checkpoints currently loaded into `pipeline`; `generate_preview`
# updates them after it swaps a model.
current_ss_model_name = DEFAULT_SS_MODEL
current_slat_model_name = DEFAULT_SLAT_MODEL
# Lazily created singletons, filled in by `get_pipeline` / `get_rembg_session`.
pipeline = None
rembg_session = None
def get_runtime_device() -> str:
    """Return ``"cuda"`` when torch reports an available CUDA device, else ``"cpu"``."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"
def get_session_dir(session_id: Optional[str]) -> str:
    """Return the scratch directory for a session, creating it if needed.

    A falsy ``session_id`` is replaced by a random hex identifier, so every
    such call produces a brand-new directory under ``TMP_DIR``.
    """
    folder_name = session_id if session_id else uuid.uuid4().hex
    session_path = os.path.join(TMP_DIR, folder_name)
    os.makedirs(session_path, exist_ok=True)
    return session_path
def get_pipeline(device: Optional[str] = None):
    """Return the process-wide AniGen pipeline, creating or moving it as needed.

    NOTE: call this only from inside a `@spaces.GPU` worker. The usual ZeroGPU
    pattern (pre-load on CPU in the main process, then `.to("cuda")` in the
    worker) crashes silently for the AniGen pipeline: the worker exits with
    "GPU task aborted" before any print. The pipeline is therefore built
    lazily inside the worker and cached in the module-level ``pipeline``
    global, so reused workers skip the 40+s load.

    Args:
        device: Target device name; defaults to ``get_runtime_device()``.

    Returns:
        The cached pipeline instance.
    """
    global pipeline
    device = device or get_runtime_device()

    if pipeline is not None:
        # Already built: only relocate it when it lives on another device type.
        if pipeline.device.type != device:
            print(f"[app_hf] Moving pipeline from {pipeline.device.type} to {device}...", flush=True)
            pipeline.to(torch.device(device))
            print(f"[app_hf] Pipeline moved to {device}.", flush=True)
        return pipeline

    # Deferred import keeps `anigen` out of the main process.
    from anigen.pipelines import AnigenImageTo3DPipeline

    print(f"[app_hf] Initializing pipeline on {device}...", flush=True)
    pipeline = AnigenImageTo3DPipeline.from_pretrained(
        ss_flow_path=f'ckpts/anigen/{DEFAULT_SS_MODEL}',
        slat_flow_path=f'ckpts/anigen/{DEFAULT_SLAT_MODEL}',
        device=device,
        use_ema=False,
    )
    print(f"[app_hf] Pipeline initialized on {device}.", flush=True)
    return pipeline
def get_rembg_session():
    """Return the shared rembg background-removal session, creating it once.

    The session is cached in the module-level ``rembg_session`` global so the
    model is only initialized on first use.
    """
    global rembg_session
    if rembg_session is None:
        # Keep the model name in one place so the log line always matches the
        # model that is actually loaded (it used to claim "u2net").
        model_name = "birefnet-general"
        print(f"[app_hf] Initializing rembg {model_name} session on CPU...", flush=True)
        rembg_session = rembg.new_session(model_name)
        print("[app_hf] rembg session ready.", flush=True)
    return rembg_session
def start_session(req: gr.Request):
    """Ensure the scratch directory for the request's session hash exists."""
    session_hash = req.session_hash
    get_session_dir(session_hash)
def end_session(req: gr.Request):
    """Remove the scratch directory of the request's session, ignoring errors."""
    session_path = get_session_dir(req.session_hash)
    shutil.rmtree(session_path, ignore_errors=True)
def preprocess_for_display_and_inference(image: Optional[Image.Image]) -> Tuple[Optional[Image.Image], Optional[Image.Image]]:
    """Matte, crop and resize an input image.

    Args:
        image: The uploaded image, or ``None``.

    Returns:
        ``(display_image, rgba_output)``: the 518x518 RGB image with colour
        multiplied by alpha, and the 518x518 RGBA image it was derived from.
        ``(None, None)`` when ``image`` is ``None``.
    """
    if image is None:
        return None, None

    # An RGBA input only counts as already matted when at least one pixel is
    # not fully opaque; otherwise the background is removed with rembg.
    already_matted = image.mode == 'RGBA' and not np.all(np.array(image)[:, :, 3] == 255)

    if already_matted:
        rgba_output = image.convert('RGBA')
    else:
        input_image = image.convert('RGB')
        longest_side = max(input_image.size)
        scale = min(1, 1024 / longest_side)
        if scale < 1:
            # Cap the longest side at 1024 px before background removal.
            input_image = input_image.resize(
                (int(input_image.width * scale), int(input_image.height * scale)),
                Image.Resampling.LANCZOS,
            )
        rgba_output = rembg.remove(input_image, session=get_rembg_session())

    # Locate pixels whose alpha exceeds 80% and build a square crop around them.
    alpha = np.array(rgba_output)[:, :, 3]
    foreground = np.argwhere(alpha > 0.8 * 255)
    if len(foreground) == 0:
        # Nothing opaque enough: keep the full frame.
        crop_box = (0, 0, rgba_output.width, rgba_output.height)
    else:
        rows, cols = foreground[:, 0], foreground[:, 1]
        left, top, right, bottom = np.min(cols), np.min(rows), np.max(cols), np.max(rows)
        center_x = (left + right) / 2
        center_y = (top + bottom) / 2
        # Square side: the larger bbox extent, enlarged by 20%.
        size = int(max(right - left, bottom - top) * 1.2)
        half = size // 2
        crop_box = (
            int(center_x - half),
            int(center_y - half),
            int(center_x + half),
            int(center_y + half),
        )

    rgba_output = rgba_output.crop(crop_box).resize((518, 518), Image.Resampling.LANCZOS)

    # Display copy: colour channels multiplied by alpha.
    rgba_float = np.array(rgba_output).astype(np.float32) / 255
    premultiplied = rgba_float[:, :, :3] * rgba_float[:, :, 3:4]
    display_image = Image.fromarray((premultiplied * 255).astype(np.uint8))
    return display_image, rgba_output
def save_processed_rgba(image: Optional[Image.Image], session_id: Optional[str] = None) -> Optional[str]:
    """Save the processed RGBA image into the session directory.

    Returns the written file path, or ``None`` (without touching the
    filesystem) when ``image`` is ``None``.
    """
    if image is not None:
        target_path = os.path.join(get_session_dir(session_id), 'processed_input_rgba.png')
        image.save(target_path)
        return target_path
    return None
def load_processed_rgba(path: Optional[str]) -> Optional[Image.Image]:
    """Load a previously saved processed image as RGBA.

    Args:
        path: Path of the saved image; may be ``None`` or empty.

    Returns:
        The image converted to RGBA, or ``None`` when ``path`` is falsy or
        does not exist on disk.
    """
    if not path:
        return None
    file_path = Path(path)
    if not file_path.exists():
        return None
    # `Image.open` is lazy and holds the file open. `convert` decodes the data
    # into a new image, so the source handle can be closed deterministically
    # instead of being left to the garbage collector.
    with Image.open(file_path) as source_image:
        return source_image.convert('RGBA')
def prepare_input_for_generation(image: Optional[Image.Image], req: gr.Request = None) -> Tuple[Optional[str], Optional[Image.Image]]:
    """Preprocess the uploaded image and persist its RGBA version.

    Returns:
        ``(rgba_path, display_image)``: the path of the saved RGBA image (or
        ``None``) and the image to show in the UI.
    """
    print('[app_hf] Preparing input on CPU before GPU stage...', flush=True)
    display_image, rgba_image = preprocess_for_display_and_inference(image)
    # Without a request the image goes into a freshly generated session dir.
    session_id = None if not req else req.session_hash
    rgba_path = save_processed_rgba(rgba_image, session_id)
    print(f'[app_hf] CPU preprocessing completed. path={rgba_path}', flush=True)
    return rgba_path, display_image
def get_seed(randomize_seed: bool, seed: int) -> int:
    """Return the seed to use for the next generation.

    Args:
        randomize_seed: When true, ignore ``seed`` and draw a fresh one.
        seed: The seed currently selected in the UI.

    Returns:
        ``seed`` unchanged, or a random integer in ``[0, MAX_SEED]``.
    """
    if not randomize_seed:
        return seed
    # Use a fresh, OS-entropy-seeded generator rather than the global NumPy
    # RNG: `generate_preview` calls `np.random.seed(seed)`, which would make
    # "random" seeds a deterministic function of the previous one whenever
    # both run in the same process. `endpoint=True` makes MAX_SEED reachable,
    # matching the slider range, and `int(...)` returns a plain Python int.
    return int(np.random.default_rng().integers(0, MAX_SEED, endpoint=True))
def on_slat_model_change(slat_model_name: str):
    """Build UI updates for a SLAT model selection change.

    Returns two ``gr.update`` objects: the first sets ``interactive`` to
    whether ``slat_flow_control`` is selected, the second sets ``visible`` to
    the opposite.
    """
    control_selected = slat_model_name == 'slat_flow_control'
    density_slider_update = gr.update(interactive=control_selected)
    hint_update = gr.update(visible=not control_selected)
    return density_slider_update, hint_update
def save_generation_state(state: Dict[str, Any], session_id: Optional[str]) -> str:
    """Write ``state`` with ``torch.save`` into the session directory and return the path."""
    session_dir = get_session_dir(session_id)
    target_path = os.path.join(session_dir, 'generation_state.pt')
    torch.save(state, target_path)
    return target_path
def load_generation_state(state_path: str) -> Dict[str, Any]:
    """Load a generation state file onto the CPU.

    Args:
        state_path: Path of a file written with ``torch.save``.

    Returns:
        The deserialized state dictionary with tensors mapped to the CPU.
    """
    # The state written by `generate_preview` holds only tensors, `None` and an
    # int, so the restricted unpickler is sufficient. Requesting it explicitly
    # avoids arbitrary-code unpickling (and the FutureWarning) on torch
    # versions where `weights_only` still defaults to False.
    return torch.load(state_path, map_location='cpu', weights_only=True)
def export_preview_assets(
    user_dir: str,
    orig_vertices: np.ndarray,
    orig_faces: np.ndarray,
    joints: np.ndarray,
    parents: np.ndarray,
    skin_weights: np.ndarray,
    vertex_colors: Optional[np.ndarray],
) -> Tuple[str, Optional[str]]:
    """Export the preview mesh and, when available, a skeleton visualization.

    Writes ``preview_mesh.glb`` (no texture image) and, if a non-empty
    skeleton mesh can be built, ``preview_skeleton.glb`` into ``user_dir``.

    Returns:
        ``(mesh_path, skeleton_path)``; ``skeleton_path`` is ``None`` when no
        skeleton mesh was produced.
    """
    # Deferred import: `anigen` must not be imported at module scope.
    from anigen.utils.export_utils import convert_to_glb_from_data, visualize_skeleton_as_mesh

    mesh_glb = os.path.join(user_dir, 'preview_mesh.glb')
    skeleton_glb = os.path.join(user_dir, 'preview_skeleton.glb')

    convert_to_glb_from_data(
        trimesh.Trimesh(vertices=orig_vertices, faces=orig_faces, process=False),
        joints,
        parents,
        skin_weights,
        mesh_glb,
        vertex_colors=vertex_colors,
        texture_image=None,
    )

    skeleton_mesh = visualize_skeleton_as_mesh(joints, parents)
    if skeleton_mesh is None or len(skeleton_mesh.vertices) == 0:
        return mesh_glb, None
    skeleton_mesh.export(skeleton_glb)
    return mesh_glb, skeleton_glb
def get_user_dir_from_artifact_path(path: Optional[str]) -> str:
    """Return the resolved parent directory of an artifact path.

    Raises:
        gr.Error: If ``path`` is ``None`` or empty.
    """
    if path:
        return str(Path(path).resolve().parent)
    raise gr.Error('Missing intermediate artifact path.')
def update_download_buttons(mesh_path: Optional[str], skel_path: Optional[str]):
    """Build updates that point both download buttons at the given files.

    Each button becomes interactive only when its path is truthy.
    """
    mesh_update = gr.update(value=mesh_path, interactive=bool(mesh_path))
    skeleton_update = gr.update(value=skel_path, interactive=bool(skel_path))
    return mesh_update, skeleton_update
def disable_download_buttons():
    """Build updates that clear and disable both download buttons."""
    mesh_update = gr.update(value=None, interactive=False)
    skeleton_update = gr.update(value=None, interactive=False)
    return mesh_update, skeleton_update
def mark_generate_queued(ss_sampling_steps: int, slat_sampling_steps: int):
    """Return the status text shown while preview generation waits for a GPU."""
    headline = f'CPU preprocessing done. Waiting for ZeroGPU allocation for preview generation '
    step_summary = f'(SS steps: {ss_sampling_steps}, SLat steps: {slat_sampling_steps}).'
    return headline + step_summary
def mark_extract_queued(texture_size: int, simplify_ratio: float, fill_holes: bool):
    """Return the status text shown while GLB extraction waits for a GPU."""
    message_parts = [
        f'Waiting for ZeroGPU allocation for final GLB extraction ',
        f'(texture: {texture_size}, simplify: {simplify_ratio:.2f}, fill_holes: {fill_holes}).',
    ]
    return ''.join(message_parts)
@spaces.GPU(duration=120)
def generate_preview(
    processed_input_rgba_path: Optional[str],
    seed: int,
    ss_model_name: str,
    slat_model_name: str,
    ss_guidance_strength: float,
    ss_sampling_steps: int,
    slat_guidance_strength: float,
    slat_sampling_steps: int,
    joints_density: int,
    progress=gr.Progress(track_tqdm=False),
) -> Tuple[str, str, Optional[str], str]:
    """Run the AniGen pipeline on the processed image and export preview assets.

    Args:
        processed_input_rgba_path: Path of the preprocessed RGBA image; its
            parent directory also receives every output file.
        seed: Seed applied to both the torch and the NumPy global RNG.
        ss_model_name: Sparse-structure checkpoint name under `ckpts/anigen/`.
        slat_model_name: Structured-latent checkpoint name under `ckpts/anigen/`.
        ss_guidance_strength: Passed as `strength` to `sample_sparse_structure`.
        ss_sampling_steps: Passed as `steps` to `sample_sparse_structure`.
        slat_guidance_strength: Passed as `strength` to `sample_slat`.
        slat_sampling_steps: Passed as `steps` to `sample_slat`.
        joints_density: Passed as `joint_density` to `sample_slat`.
        progress: Gradio progress reporter.

    Returns:
        `(state_path, preview_mesh_path, preview_skeleton_path, status)`, where
        `state_path` points at the saved `generation_state.pt` and
        `preview_skeleton_path` may be `None`.

    Raises:
        gr.Error: If the processed input image cannot be loaded.
        Exception: Any other failure is logged with its traceback and re-raised.
    """
    global current_ss_model_name, current_slat_model_name
    print('[app_hf] generate_preview: entered GPU function, requesting pipeline...', flush=True)
    try:
        # Deferred import: `anigen` must not be imported at module scope.
        from anigen.utils.skin_utils import repair_skeleton_parents
        device = get_runtime_device()
        print(f'[app_hf] generate_preview started on device={device}', flush=True)
        processed_input_rgba = load_processed_rgba(processed_input_rgba_path)
        if processed_input_rgba is None:
            raise gr.Error('Missing processed input image. Please upload an image and click Generate again.')
        print('[app_hf] processed input loaded; initializing pipeline next.', flush=True)
        pipe = get_pipeline(device)
        # Swap checkpoints only when the requested model differs from the one
        # recorded as currently loaded, then record the new name.
        if ss_model_name != current_ss_model_name:
            progress(0, desc=f'Loading SS model: {ss_model_name}...')
            pipe.load_ss_flow_model(f'ckpts/anigen/{ss_model_name}', device=device, use_ema=False)
            current_ss_model_name = ss_model_name
        if slat_model_name != current_slat_model_name:
            progress(0, desc=f'Loading SLAT model: {slat_model_name}...')
            pipe.load_slat_flow_model(f'ckpts/anigen/{slat_model_name}', device=device, use_ema=False)
            current_slat_model_name = slat_model_name
        # Seed both global RNGs with the requested seed.
        torch.manual_seed(seed)
        np.random.seed(seed)
        progress(0.02, desc='Estimating normals...')
        processed_image, processed_normal = pipe.preprocess_image(processed_input_rgba)
        print('[app_hf] preprocessing on GPU worker finished.', flush=True)
        progress(0.08, desc='Encoding image conditions...')
        cond_dict_ss, cond_dict_slat_rgb = pipe.get_cond(processed_image, processed_normal)
        print('[app_hf] conditioning ready.', flush=True)
        # The two samplers report into consecutive slices of the progress bar:
        # SS sampling covers 0.10-0.50 and SLat sampling covers 0.50-0.90.
        def ss_progress_callback(step, total):
            frac = (step + 1) / total
            progress(0.10 + frac * 0.40, desc=f'SS Sampling: {step + 1}/{total}')
        def slat_progress_callback(step, total):
            frac = (step + 1) / total
            progress(0.50 + frac * 0.40, desc=f'SLat Sampling: {step + 1}/{total}')
        coords, coords_skl, _, _ = pipe.sample_sparse_structure(
            cond_dict_ss,
            strength=ss_guidance_strength,
            steps=ss_sampling_steps,
            progress_callback=ss_progress_callback,
        )
        print('[app_hf] sparse structure sampled.', flush=True)
        slat, slat_skl = pipe.sample_slat(
            cond_dict_slat_rgb,
            coords,
            coords_skl,
            strength=slat_guidance_strength,
            steps=slat_sampling_steps,
            joint_density=joints_density,
            progress_callback=slat_progress_callback,
        )
        print('[app_hf] slat sampled.', flush=True)
        progress(0.92, desc='Decoding preview mesh...')
        mesh_result, skeleton_result = pipe.decode_slat(slat, slat_skl)
        print('[app_hf] decode finished.', flush=True)
        # Detach everything to CPU tensors so it can be exported and saved.
        joints = skeleton_result.joints_grouped.detach().cpu().to(torch.float32)
        parents = skeleton_result.parents_grouped.detach().cpu().to(torch.int32)
        parents_np = repair_skeleton_parents(
            joints=joints.numpy(),
            parents=parents.numpy(),
            verbose=False,
        ).astype(np.int32)
        parents = torch.from_numpy(parents_np)
        skin_weights = skeleton_result.skin_pred.detach().cpu().to(torch.float32)
        orig_vertices = mesh_result.vertices.detach().cpu().to(torch.float32)
        orig_faces = mesh_result.faces.detach().cpu().to(torch.long)
        # `vertex_attrs` is optional on the mesh result; its first three
        # channels are used as vertex colors when present.
        vertex_attrs = None
        if getattr(mesh_result, 'vertex_attrs', None) is not None:
            vertex_attrs = mesh_result.vertex_attrs.detach().cpu().to(torch.float32)
        vertex_colors = None
        if vertex_attrs is not None and vertex_attrs.shape[-1] >= 3:
            vertex_colors = vertex_attrs[:, :3].numpy()
        # Outputs are written next to the processed input image.
        user_dir = get_user_dir_from_artifact_path(processed_input_rgba_path)
        preview_mesh_path, preview_skeleton_path = export_preview_assets(
            user_dir=user_dir,
            orig_vertices=orig_vertices.numpy(),
            orig_faces=orig_faces.numpy(),
            joints=joints.numpy(),
            parents=parents.numpy(),
            skin_weights=skin_weights.numpy(),
            vertex_colors=vertex_colors,
        )
        # Saved state that `extract_glb` later reloads from disk.
        state = {
            'orig_vertices': orig_vertices.contiguous(),
            'orig_faces': orig_faces.contiguous(),
            'vertex_attrs': vertex_attrs.contiguous() if vertex_attrs is not None else None,
            'joints': joints.contiguous(),
            'parents': parents.contiguous(),
            'skin_weights': skin_weights.contiguous(),
            'mesh_res': int(getattr(mesh_result, 'res', 64)),
        }
        state_path = os.path.join(user_dir, 'generation_state.pt')
        torch.save(state, state_path)
        print(f'[app_hf] Preview ready. State saved to {state_path}', flush=True)
        status = 'Preview ready. Click “Extract GLB” to run simplification and texture baking.'
        return state_path, preview_mesh_path, preview_skeleton_path, status
    except Exception as exc:
        # Log the failure with its traceback, then let the error propagate.
        print(f'[app_hf] generate_preview failed: {exc}', flush=True)
        print(traceback.format_exc(), flush=True)
        raise
    finally:
        # Don't move pipeline back to CPU: the ZeroGPU worker is reused across
        # calls, so keeping the pipeline on GPU lets subsequent calls skip the
        # 40+s load. Moving to CPU and back has also been shown to trigger
        # silent worker aborts for this pipeline.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
@spaces.GPU(duration=120)
def extract_glb(
    generation_state_path: Optional[str],
    texture_size: int,
    simplify_ratio: float,
    fill_holes: bool,
    progress=gr.Progress(track_tqdm=False),
) -> Tuple[str, Optional[str], str]:
    """Produce the final rigged GLB from a saved generation state.

    The mesh is post-processed (optional simplification and hole filling),
    skin weights are transferred and smoothed, a texture is baked when
    ``texture_size`` is positive, and the result is written as ``mesh.glb``
    (plus ``skeleton.glb`` when a skeleton mesh can be built) next to the
    state file.

    Args:
        generation_state_path: Path of the state file saved by the preview step.
        texture_size: Texture resolution; baking is skipped when it is <= 0.
        simplify_ratio: Passed to `postprocess_mesh`; simplification is enabled
            only when it is > 0.
        fill_holes: Passed to `postprocess_mesh`.
        progress: Gradio progress reporter.

    Returns:
        `(output_glb_path, skeleton_glb_path, status)`; `skeleton_glb_path`
        may be `None`.

    Raises:
        gr.Error: If the state path is missing or does not exist.
        Exception: Any other failure is logged with its traceback and re-raised.
    """
    try:
        # Deferred imports: `anigen` must not be imported at module scope.
        from anigen.representations.mesh.cube2mesh_skeleton import AniGenMeshExtractResult
        from anigen.utils.export_utils import convert_to_glb_from_data, visualize_skeleton_as_mesh
        from anigen.utils.postprocessing_utils import (
            bake_texture,
            barycentric_transfer_attributes,
            parametrize_mesh,
            postprocess_mesh,
        )
        from anigen.utils.render_utils import render_multiview
        from anigen.utils.skin_utils import (
            filter_skinning_weights,
            repair_skeleton_parents,
            smooth_skin_weights_on_mesh,
        )
        if not generation_state_path or not os.path.exists(generation_state_path):
            raise gr.Error('Please click Generate first to create a preview state.')
        device = get_runtime_device()
        print(f'[app_hf] extract_glb started on device={device}', flush=True)
        state = load_generation_state(generation_state_path)
        orig_vertices = state['orig_vertices'].numpy()
        orig_faces = state['orig_faces'].numpy()
        joints = state['joints'].numpy()
        parents = repair_skeleton_parents(
            joints=joints,
            parents=state['parents'].numpy().astype(np.int32),
            verbose=False,
        ).astype(np.int32)
        skin_weights = state['skin_weights'].numpy()
        # The first three channels of the optional vertex attributes serve as
        # vertex colors.
        vertex_attrs_cpu = state.get('vertex_attrs')
        vertex_colors = None
        if vertex_attrs_cpu is not None and vertex_attrs_cpu.shape[-1] >= 3:
            vertex_colors = vertex_attrs_cpu[:, :3].numpy()
        # Outputs are written next to the state file.
        user_dir = get_user_dir_from_artifact_path(generation_state_path)
        output_glb_path = os.path.join(user_dir, 'mesh.glb')
        skeleton_glb_path = os.path.join(user_dir, 'skeleton.glb')
        progress(0.02, desc='Simplifying mesh...')
        new_vertices, new_faces = postprocess_mesh(
            orig_vertices,
            orig_faces,
            simplify=(simplify_ratio > 0),
            simplify_ratio=simplify_ratio,
            fill_holes=fill_holes,
            verbose=True,
        )
        # When post-processing changed the vertex count, per-vertex attributes
        # are re-sampled from the original mesh onto the new vertices.
        if new_vertices.shape[0] != orig_vertices.shape[0]:
            orig_mesh = trimesh.Trimesh(vertices=orig_vertices, faces=orig_faces, process=False)
            skin_weights = barycentric_transfer_attributes(orig_mesh, skin_weights, new_vertices)
            if vertex_colors is not None:
                vertex_colors = barycentric_transfer_attributes(orig_mesh, vertex_colors, new_vertices)
        mesh = trimesh.Trimesh(vertices=new_vertices, faces=new_faces, process=False)
        # progress(0.30, desc='Filtering skin weights...')
        # skin_weights = filter_skinning_weights(mesh, skin_weights, joints, parents)
        progress(0.42, desc='Smoothing skin weights...')
        skin_weights = smooth_skin_weights_on_mesh(
            mesh,
            skin_weights,
            iterations=100,
            alpha=1.0,
        )
        texture_image = None
        if int(texture_size) > 0:
            progress(0.55, desc='Parameterizing UVs...')
            uv_vertices, uv_faces, uvs, vmapping = parametrize_mesh(new_vertices, new_faces)
            # Re-index per-vertex attributes with the mapping returned by the
            # UV parametrization.
            skin_weights = skin_weights[vmapping]
            if vertex_colors is not None:
                vertex_colors = vertex_colors[vmapping]
            # The original (un-simplified) mesh is rendered from multiple views
            # and those renders are the observations for texture baking.
            vertex_attrs = None
            if vertex_attrs_cpu is not None:
                vertex_attrs = vertex_attrs_cpu.to(device=device, dtype=torch.float32)
            mesh_result = AniGenMeshExtractResult(
                vertices=torch.as_tensor(orig_vertices, device=device, dtype=torch.float32),
                faces=torch.as_tensor(orig_faces, device=device, dtype=torch.long),
                vertex_attrs=vertex_attrs,
                res=int(state.get('mesh_res', 64)),
            )
            progress(0.65, desc='Rendering teacher views...')
            observations, extrinsics_mv, intrinsics_mv = render_multiview(
                mesh_result,
                resolution=1024,
                nviews=100,
            )
            # A pixel counts as foreground when any channel is non-zero.
            masks = [np.any(obs > 0, axis=-1) for obs in observations]
            extrinsics_np = [e.detach().cpu().numpy() for e in extrinsics_mv]
            intrinsics_np = [i.detach().cpu().numpy() for i in intrinsics_mv]
            progress(0.78, desc='Baking texture...')
            # Gradients are explicitly enabled around the 'opt' mode bake.
            with torch.enable_grad():
                texture_image = bake_texture(
                    uv_vertices,
                    uv_faces,
                    uvs,
                    observations,
                    masks,
                    extrinsics_np,
                    intrinsics_np,
                    texture_size=int(texture_size),
                    mode='opt',
                    lambda_tv=0.01,
                    verbose=True,
                )
            mesh = trimesh.Trimesh(
                vertices=uv_vertices,
                faces=uv_faces,
                visual=trimesh.visual.TextureVisuals(uv=uvs),
                process=False,
            )
        progress(0.94, desc='Exporting GLB...')
        convert_to_glb_from_data(
            mesh,
            joints,
            parents,
            skin_weights,
            output_glb_path,
            vertex_colors=vertex_colors,
            texture_image=texture_image,
        )
        skeleton_mesh = visualize_skeleton_as_mesh(joints, parents)
        if skeleton_mesh is not None and len(skeleton_mesh.vertices) > 0:
            skeleton_mesh.export(skeleton_glb_path)
        else:
            skeleton_glb_path = None
        print('[app_hf] Final GLB extraction completed.', flush=True)
        return output_glb_path, skeleton_glb_path, 'Final GLB ready with mesh simplification and texture baking.'
    except Exception as exc:
        # Log the failure with its traceback, then let the error propagate.
        print(f'[app_hf] extract_glb failed: {exc}', flush=True)
        print(traceback.format_exc(), flush=True)
        raise
    finally:
        # Release cached GPU memory on failure as well as on success, matching
        # `generate_preview`: the worker process is reused across calls, so a
        # failed bake must not leave its cached allocations behind.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
# Gradio UI definition and event wiring.
# NOTE(review): the nesting of the layout containers below was reconstructed
# from a source whose indentation was lost; confirm it against the intended
# page layout.
with gr.Blocks(delete_cache=(600, 600)) as demo:
    gr.Markdown(
        """
## Image to 3D Asset with [AniGen]
* Click **Generate** for a fast preview, then click **Extract GLB** for the full textured glb file export.
* [AniGen GitHub Repository](https://github.com/VAST-AI-Research/AniGen)
* [Tripo: Your 3D Workspace with AI](https://www.tripo3d.ai)
"""
    )
    gr.HTML("""
<style>
@keyframes gentle-pulse {
0%, 100% { opacity: 1; }
50% { opacity: 0.35; }
}
</style>
<div style="text-align:left; color:#888; font-size:1em; line-height:1.6; margin-bottom:-8px;">
<span style="animation: gentle-pulse 3s ease-in-out infinite; display:inline-block;">&#128161; <b>Tip</b></span>&ensp;
Not satisfied with the geometry or skeleton?
Try switching the SS Model to <code>ss_flow_solo</code> or <code>ss_flow_duet</code> in Generation Settings.
</div>
""")
    with gr.Row():
        # Left column: input image and all settings.
        with gr.Column():
            # Per-session state: paths of intermediate artifacts handed from
            # one event handler to the next.
            processed_input_path_state = gr.State(value=None)
            generation_state_path = gr.State(value=None)
            image_prompt = gr.Image(label='Image Prompt', format='png', image_mode='RGBA', type='pil', height=300)
            with gr.Accordion(label='Generation Settings', open=True):
                seed = gr.Slider(0, MAX_SEED, label='Seed', value=42, step=1)
                randomize_seed = gr.Checkbox(label='Randomize Seed', value=False)
                gr.Markdown('**Model Selection**')
                with gr.Row():
                    ss_model_dropdown = gr.Dropdown(
                        choices=SS_MODEL_CHOICES,
                        value=DEFAULT_SS_MODEL,
                        label='SS Model (Sparse Structure)',
                    )
                    slat_model_dropdown = gr.Dropdown(
                        choices=SLAT_MODEL_CHOICES,
                        value=DEFAULT_SLAT_MODEL,
                        label='SLAT Model (Structured Latent)',
                    )
                gr.Markdown('Stage 1: Sparse Structure Generation')
                with gr.Row():
                    ss_guidance_strength = gr.Slider(0.0, 15.0, label='Guidance Strength', value=7.5, step=0.1)
                    ss_sampling_steps = gr.Slider(1, 50, label='Sampling Steps', value=25, step=1)
                gr.Markdown('Stage 2: Structured Latent Generation')
                with gr.Row():
                    slat_guidance_strength = gr.Slider(0.0, 10.0, label='Guidance Strength', value=3.0, step=0.1)
                    slat_sampling_steps = gr.Slider(1, 50, label='Sampling Steps', value=25, step=1)
                gr.Markdown('Skeleton & Skinning Settings')
                # Starts read-only; `on_slat_model_change` toggles it together
                # with the hint below.
                joints_density = gr.Slider(0, 4, label='Joints Density', value=1, step=1, interactive=False)
                density_hint = gr.Markdown(
                    '*Switch `SLAT Model` to `slat_flow_control` to enable joint density control.*',
                    visible=True,
                )
            with gr.Accordion(label='Extraction Settings', open=False):
                simplify_ratio = gr.Slider(0.0, 0.99, label='Mesh Simplification Ratio', value=0.95, step=0.01)
                fill_holes = gr.Checkbox(label='Fill Holes', value=True)
                texture_size = gr.Slider(256, 2048, label='Texture Resolution', value=1024, step=256)
            with gr.Row():
                generate_btn = gr.Button('Generate')
                extract_btn = gr.Button('Extract GLB')
        # Right column: outputs, download buttons and status line.
        with gr.Column():
            mesh_output = gr.Model3D(label="Generated Mesh", height=300, interactive=False)
            download_mesh = gr.DownloadButton(label='Download Mesh GLB', interactive=False)
            skeleton_output = LitModel3D(label='Skeleton Preview / Final GLB', exposure=5.0, height=300, interactive=False)
            download_skeleton = gr.DownloadButton(label='Download Skeleton GLB', interactive=False)
            processed_image_output = gr.Image(label='Processed Image', type='pil', height=300)
            status_output = gr.Markdown('Upload an image, click **Generate**, then click **Extract GLB**.')
    # Example images: every file found in `assets/cond_images`.
    with gr.Row() as single_image_example:
        gr.Examples(
            examples=[
                f'assets/cond_images/{image}'
                for image in os.listdir('assets/cond_images')
            ],
            inputs=[image_prompt],
            examples_per_page=64,
        )
    # Session lifecycle: create the scratch directory on page load and delete
    # it on unload.
    demo.load(start_session)
    demo.unload(end_session)
    slat_model_dropdown.change(
        on_slat_model_change,
        inputs=[slat_model_dropdown],
        outputs=[joints_density, density_hint],
    )
    # Generate chain: pick the seed, preprocess the image, show the queued
    # status, run the preview generation, then disable the download buttons.
    generate_btn.click(
        get_seed,
        inputs=[randomize_seed, seed],
        outputs=[seed],
    ).then(
        prepare_input_for_generation,
        inputs=[image_prompt],
        outputs=[processed_input_path_state, processed_image_output],
    ).then(
        mark_generate_queued,
        inputs=[ss_sampling_steps, slat_sampling_steps],
        outputs=[status_output],
    ).then(
        generate_preview,
        inputs=[
            processed_input_path_state,
            seed,
            ss_model_dropdown,
            slat_model_dropdown,
            ss_guidance_strength,
            ss_sampling_steps,
            slat_guidance_strength,
            slat_sampling_steps,
            joints_density,
        ],
        outputs=[
            generation_state_path,
            mesh_output,
            skeleton_output,
            status_output,
        ],
    ).then(
        disable_download_buttons,
        outputs=[download_mesh, download_skeleton],
    )
    # Extract chain: show the queued status, run the final extraction, then
    # point the download buttons at the produced files.
    extract_btn.click(
        mark_extract_queued,
        inputs=[texture_size, simplify_ratio, fill_holes],
        outputs=[status_output],
    ).then(
        extract_glb,
        inputs=[generation_state_path, texture_size, simplify_ratio, fill_holes],
        outputs=[mesh_output, skeleton_output, status_output],
    ).then(
        update_download_buttons,
        inputs=[mesh_output, skeleton_output],
        outputs=[download_mesh, download_skeleton],
    )
# Pre-download any missing checkpoints at module load so the first request
# doesn't pay the download cost. The pipeline itself is NOT instantiated here:
# AniGen's module tree crashes the ZeroGPU forked worker when it tries to move
# the CPU-preloaded pipeline to cuda. We instead load the pipeline lazily
# inside `generate_preview` (which runs inside the spaces.GPU worker); because
# ZeroGPU reuses the worker process across calls, the 40+s load only happens
# on the very first request per worker.
#
# We import `ensure_ckpts` by loading the file directly, rather than doing
# `from anigen.utils.ckpt_utils import ensure_ckpts`, because the latter runs
# `anigen/__init__.py` which eagerly imports `anigen.models`, `warp`, spconv
# etc. and leaves the main process in a bad CUDA state. See the note at the
# top of this file.
import importlib.util as _importlib_util

# Load `anigen/utils/ckpt_utils.py` straight from its file path under a
# standalone module name, so that `anigen/__init__.py` is never executed in
# this process, then fetch any missing checkpoints.
_ckpt_utils_path = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    'anigen/utils/ckpt_utils.py',
)
_ckpt_spec = _importlib_util.spec_from_file_location('anigen_ckpt_utils_isolated', _ckpt_utils_path)
_ckpt_module = _importlib_util.module_from_spec(_ckpt_spec)
_ckpt_spec.loader.exec_module(_ckpt_module)
_ckpt_module.ensure_ckpts()
# Drop the temporaries so they do not linger in the module namespace.
del _importlib_util, _ckpt_utils_path, _ckpt_spec, _ckpt_module
# Start the Gradio server only when this file is executed as a script.
if __name__ == '__main__':
    # Bind to all network interfaces and request a public share link.
    demo.launch(server_name='0.0.0.0', share=True)