# PartCrafter / app.py
# theYiran's picture
# Update app.py
# 79b41c0 verified
import os
import sys
import subprocess # <--- 确保这行在这里!
import importlib
import site
import time
# --- 🧪 1. In-memory stub for diso (must run before any application import) ---
def mock_diso():
    """Register placeholder ``diso`` modules in ``sys.modules``.

    Three entries are installed: ``diso`` (exposing a do-nothing ``DiffDMC``
    class), ``diso._C`` and ``diso.diso_native`` (both empty modules). Any
    entries already present under those names are overwritten.
    """
    import types

    print("🧪 Creating emergency mock for diso...")

    class FakeDiffDMC:
        # Accepts any constructor arguments and ignores them.
        def __init__(self, *args, **kwargs):
            pass

        # Calling an instance always yields None.
        def __call__(self, *args, **kwargs):
            return None

    stub = types.ModuleType("diso")
    stub.DiffDMC = FakeDiffDMC
    sys.modules["diso"] = stub
    for submodule_name in ("diso._C", "diso.diso_native"):
        sys.modules[submodule_name] = types.ModuleType(submodule_name)
    print("✅ diso has been mocked successfully!")
mock_diso()
# --- 🚀 2. Fast environment setup (the scatter/sparse wheels that already work) ---
def install_essential_packages():
    """Install the app's runtime dependencies with pip, on a best-effort basis.

    Three pip invocations run in order: build tooling, the PyG extensions
    (resolved against the torch-2.4.0+cu121 wheel index) and the rendering /
    utility packages. A step that exits non-zero is reported and the remaining
    steps still run; this function never raises because of a failed install.
    Afterwards the import caches are invalidated and ``site.main()`` is re-run.
    """
    print("📦 Checking core dependencies...")
    pip_install = [sys.executable, "-m", "pip", "install"]
    steps = (
        # Make sure the basic build environment is in place.
        ("build tooling", ["ninja", "setuptools", "wheel", "-q"]),
        # Prebuilt PyG extension wheels.
        (
            "PyG extensions",
            [
                "torch-scatter", "torch-sparse", "torch-cluster",
                "-f", "https://data.pyg.org/whl/torch-2.4.0+cu121.html",
                "--no-cache-dir",
            ],
        ),
        # Remaining rendering tools and helpers.
        (
            "rendering tools",
            ["pyrender", "pyopengl==3.1.0", "pyyaml", "trimesh", "accelerate", "-q"],
        ),
    )
    for label, pip_args in steps:
        completed = subprocess.run(pip_install + pip_args)
        # Surface failures instead of silently dropping the exit code, but keep
        # going: the installation is deliberately best-effort.
        if completed.returncode != 0:
            print(f"⚠️ pip install of {label} failed with exit code {completed.returncode}; continuing.")
    importlib.invalidate_caches()
    site.main()
    print("🎉 Environment Installation Phase Finished.")
install_essential_packages()
# ... the mock_diso and installation logic from above ...
# 1. Protect the core paths
# Keep an externally supplied export root; fall back to "outputs" otherwise.
os.environ.setdefault("PARTCRAFTER_PROCESSED", "outputs")
# 2. Make sure the export root and both model-weight download directories exist.
for _required_dir in (
    os.environ["PARTCRAFTER_PROCESSED"],
    "pretrained_weights/PartCrafter",
    "pretrained_weights/RMBG-1.4",
):
    os.makedirs(_required_dir, exist_ok=True)
# ... continue on to snapshot_download ...
# --- 3. Import the application logic proper (from here on the several hundred lines below should no longer fail) ---
import spaces
import gradio as gr
import numpy as np
import torch
import uuid
import shutil
from huggingface_hub import snapshot_download
from PIL import Image
from accelerate.utils import set_seed
# From here on, paste all of your original application logic (PartCrafterPipeline, etc.)
# ...
# --- 🚀 Core fix: force a torch downgrade so that no source build is needed ---
def pre_install_check():
    """Downgrade PyTorch to 2.4.0+cu121 and restart when torch is 2.9 or newer.

    The installed version is compared numerically on (major, minor), so 2.10
    and later are covered as well and unrelated version strings that merely
    contain the text "2.9" are not matched. When a downgrade is needed, pip
    installs the pinned torch/torchvision pair and the current process is
    replaced via ``os.execv`` so the new version gets loaded.

    An environment marker is set right before the restart; if the restarted
    process still sees a too-new torch it reports that and carries on instead
    of restarting forever.

    Any failure (torch missing, pip failing, ...) is printed and swallowed:
    this check is best-effort and must not abort startup.
    """
    import re  # local import: only this startup check needs it

    restart_marker = "PARTCRAFTER_TORCH_DOWNGRADED"
    try:
        import torch

        parsed = re.match(r"(\d+)\.(\d+)", torch.__version__)
        if parsed is None or (int(parsed.group(1)), int(parsed.group(2))) < (2, 9):
            return
        if os.environ.get(restart_marker) == "1":
            print(f"⚠️ torch {torch.__version__} is still too new after a downgrade attempt; not restarting again.")
            return
        print(f"🔄 Current torch {torch.__version__} is too new. Downgrading to 2.4.0 for speed...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", "ninja", "setuptools", "wheel", "-q"])
        subprocess.check_call([
            sys.executable, "-m", "pip", "install",
            "torch==2.4.0+cu121", "torchvision==0.19.0+cu121",
            "--extra-index-url", "https://download.pytorch.org/whl/cu121"
        ])
        # Refresh the import machinery's view of the installed packages.
        importlib.invalidate_caches()
        # The marker is inherited by the replacement process and stops a loop.
        os.environ[restart_marker] = "1"
        # exec replaces the process without flushing buffered output.
        sys.stdout.flush()
        sys.stderr.flush()
        # Restart the interpreter so the new torch version is the one loaded.
        os.execv(sys.executable, [sys.executable] + sys.argv)
    except Exception as e:
        print(f"Pre-install check note: {e}")
pre_install_check()
import trimesh
import glob
import importlib, site
# Walk every site-packages directory again so that .pth / .egg-link files
# written by the pip installs above are processed in this process too.
for _site_dir in site.getsitepackages():
    site.addsitedir(_site_dir)
# Drop cached finder state so the next import sees the refreshed paths.
importlib.invalidate_caches()
# --- Simplified CUDA environment configuration ---
def setup_cuda_env(cuda_path="/usr/local/cuda"):
    """Point CUDA_HOME, PATH and LD_LIBRARY_PATH at a CUDA installation.

    Nothing happens when ``cuda_path`` does not exist. Otherwise ``CUDA_HOME``
    is set to it, ``<cuda_path>/bin`` is moved to the front of ``PATH`` and
    ``<cuda_path>/lib64`` to the front of ``LD_LIBRARY_PATH``.

    Args:
        cuda_path: Root of the CUDA toolkit. Defaults to the system location.
    """
    if not os.path.exists(cuda_path):
        return

    def _put_first(variable, entry):
        # Drop empty entries (an empty LD_LIBRARY_PATH element means "current
        # directory" to the loader) and earlier copies of `entry`, so repeated
        # calls neither duplicate it nor leave a stray separator behind.
        kept = [
            item
            for item in os.environ.get(variable, "").split(os.pathsep)
            if item and item != entry
        ]
        os.environ[variable] = os.pathsep.join([entry] + kept)

    os.environ["CUDA_HOME"] = cuda_path
    _put_first("PATH", f"{cuda_path}/bin")
    _put_first("LD_LIBRARY_PATH", f"{cuda_path}/lib64")
    print(f"==> Using system CUDA at {cuda_path}")
setup_cuda_env()
# --- 🚀 Optimised source-build plan written for PyTorch 2.9.1 ---
# NOTE(review): the wheels requested below target torch 2.4.0+cu121 — confirm the heading above is still accurate.
# --- 🚀 Brute-force consolidated version: getting past diso's last line of defence ---
def install_heavy_packages():
    """Install the PyG extensions and rendering tools and fetch the diso sources.

    Steps, in order:
      1. Set ``PYOPENGL_PLATFORM=egl``.
      2. pip-install torch-scatter / torch-sparse / torch-cluster from the
         torch-2.4.0+cu121 wheel index (a failure raises).
      3. Clone the diso repository into ``./diso_source`` if that directory is
         not there yet and put it at the front of ``sys.path``. This step is
         optional: a missing ``git`` executable or a failed clone is reported
         and startup continues.
      4. pip-install pyrender, pyopengl==3.1.0 and pyyaml (a failure raises).

    Raises:
        subprocess.CalledProcessError: If one of the pip installs fails.
    """
    os.environ['PYOPENGL_PLATFORM'] = 'egl'
    # 1. PyG extensions (this part is already stable; leave it alone).
    print("📦 Installing PyG extensions...")
    subprocess.run([
        sys.executable, "-m", "pip", "install",
        "torch-scatter", "torch-sparse", "torch-cluster",
        "-f", "https://data.pyg.org/whl/torch-2.4.0+cu121.html"
    ], check=True)
    # 2. Brute-force diso: clone the sources and make them importable.
    print("🔥 Attempting D-Plan: Manual diso injection...")
    diso_path = os.path.join(os.getcwd(), "diso_source")
    if not os.path.exists(diso_path):
        try:
            clone = subprocess.run(["git", "clone", "https://github.com/SarahWeiii/diso.git", diso_path])
        except OSError as exc:
            # git is not installed / not executable; the clone is optional.
            print(f"⚠️ Could not run git to clone diso: {exc}")
        else:
            if clone.returncode != 0:
                print(f"⚠️ git clone of diso failed with exit code {clone.returncode}; continuing without it.")
    # Put the diso source tree on the module search path so that Python can
    # find the package layout even when no compiled .so was produced.
    if diso_path not in sys.path:
        sys.path.insert(0, diso_path)
    # 3. Rendering tools and other lightweight dependencies.
    print("📦 Installing rendering tools...")
    subprocess.run([sys.executable, "-m", "pip", "install", "pyrender", "pyopengl==3.1.0", "pyyaml", "-q"], check=True)
    importlib.invalidate_caches()
    print("🎉 Environment Installation Phase Finished.")
# Run the installation
install_heavy_packages()
# --- 🛰️ Key step: diso import patch ---
try:
    import diso
except ImportError:
    # Still failing: expose the cloned diso source tree directly by putting
    # it at the very front of the module search path.
    print("⚠️ diso import failed, applying emergency mock...")
    diso_src_path = os.path.join(os.getcwd(), "diso_source")
    sys.path.insert(0, diso_src_path)
    # Make the import system notice the newly added directory.
    importlib.invalidate_caches()
else:
    print("✅ diso imported successfully!")
# ... the code that follows is unchanged ...
from src.utils.data_utils import get_colored_mesh_composition, scene_to_parts, load_surfaces
from src.utils.render_utils import render_views_around_mesh, render_normal_views_around_mesh, make_grid_for_images_or_videos, export_renderings, explode_mesh
from src.pipelines.pipeline_partcrafter import PartCrafterPipeline
from src.utils.image_utils import prepare_image
from src.models.briarmbg import BriaRMBG
# Constants
# Upper bound of the "Number of Parts" slider in the UI.
MAX_NUM_PARTS = 16
# Device and dtype that the models below are moved to.
DEVICE = "cuda"
DTYPE = torch.float16
# Download and initialize models
# Everything below runs once, at import time: both weight snapshots are
# fetched into local directories and the models are loaded from them.
partcrafter_weights_dir = "pretrained_weights/PartCrafter"
rmbg_weights_dir = "pretrained_weights/RMBG-1.4"
snapshot_download(repo_id="wgsxm/PartCrafter", local_dir=partcrafter_weights_dir)
snapshot_download(repo_id="briaai/RMBG-1.4", local_dir=rmbg_weights_dir)
# Background-removal network, moved to DEVICE and switched to eval mode.
rmbg_net = BriaRMBG.from_pretrained(rmbg_weights_dir).to(DEVICE)
rmbg_net.eval()
# Part-generation pipeline, moved to DEVICE with dtype DTYPE.
pipe: PartCrafterPipeline = PartCrafterPipeline.from_pretrained(partcrafter_weights_dir).to(DEVICE, DTYPE)
def first_file_from_dir(directory, ext):
    """Return the lexicographically smallest ``*.<ext>`` path directly inside
    ``directory``, or ``None`` when no file matches."""
    pattern = os.path.join(directory, f"*.{ext}")
    return min(glob.glob(pattern), default=None)
def get_duration(
    image_path,
    num_parts=1,
    seed=0,
    num_tokens=1024,
    num_inference_steps=50,
    guidance_scale=7.0,
    use_flash_decoder=False,
    rmbg=True,
    session_id=None,
    progress=None,
):
    """Return the GPU time budget, in seconds, for one ``run_partcrafter`` call.

    Used as the ``duration`` callback of ``spaces.GPU`` on ``run_partcrafter``,
    so the parameters and their defaults mirror that function: whatever
    argument list is valid for ``run_partcrafter`` is valid here too. Only
    ``num_parts`` influences the result.

    Returns:
        int: 120 for more than 10 parts, 90 for more than 5 parts, else 75.
    """
    if num_parts > 10:
        return 120
    if num_parts > 5:
        return 90
    return 75
@spaces.GPU(duration=140)
def gen_model_n_video(image_path: str,
                      num_parts: int,
                      progress=gr.Progress(track_tqdm=True),):
    """Run part generation and then render the preview for the result.

    Returns:
        A ``(model_path, video_path)`` pair: the value returned by
        ``run_partcrafter`` and the value returned by ``gen_video`` for it.
    """
    mesh_file = run_partcrafter(image_path, num_parts=num_parts, progress=progress)
    return mesh_file, gen_video(mesh_file)
@spaces.GPU()
def gen_video(model_path):
    """Render a turntable preview of the mesh stored at ``model_path``.

    The mesh is loaded with trimesh, rendered from 36 views at radius 4 and
    exported at 7 fps to ``rendering.gif`` next to the mesh file.

    Returns:
        The path of the written ``rendering.gif``, or ``None`` (after showing
        an info message) when ``model_path`` is ``None``.
    """
    if model_path is None:
        gr.Info("You must craft the 3d parts first")
        return None

    gif_path = os.path.join(os.path.dirname(model_path), "rendering.gif")
    frames = render_views_around_mesh(
        trimesh.load(model_path),
        num_views=36,
        radius=4,
    )
    export_renderings(frames, gif_path, fps=7)
    return gif_path
@spaces.GPU(duration=get_duration)
@torch.no_grad()
def run_partcrafter(image_path: str,
                    num_parts: int = 1,
                    seed: int = 0,
                    num_tokens: int = 1024,
                    num_inference_steps: int = 50,
                    guidance_scale: float = 7.0,
                    use_flash_decoder: bool = False,
                    rmbg: bool = True,
                    session_id = None,
                    progress=gr.Progress(track_tqdm=True),):
    """
    Generate structured 3D meshes from a 2D image using the PartCrafter pipeline.

    This function takes a single 2D image as input and produces a set of part-based 3D meshes,
    using compositional latent diffusion with attention to structure and part separation.
    Optionally removes the background using a pretrained background removal model (RMBG),
    and outputs a merged object mesh.

    Args:
        image_path (str): Path to the input image file on disk.
        num_parts (int, optional): Number of distinct parts to decompose the object into. Defaults to 1.
        seed (int, optional): Random seed for reproducibility. Defaults to 0.
        num_tokens (int, optional): Number of tokens used during latent encoding. Higher values yield finer detail. Defaults to 1024.
        num_inference_steps (int, optional): Number of diffusion inference steps. More steps improve quality but increase runtime. Defaults to 50.
        guidance_scale (float, optional): Classifier-free guidance scale. Higher values emphasize adherence to conditioning. Defaults to 7.0.
        use_flash_decoder (bool, optional): Whether to use FlashAttention in the decoder for performance. Defaults to False.
        rmbg (bool, optional): Whether to apply background removal before processing. Defaults to True.
        session_id (str, optional): Optional session ID to manage export paths. If not provided, a random UUID is generated.
        progress (gr.Progress, optional): Gradio progress object for visual feedback. Automatically handled by Gradio.

    Returns:
        str: File path to the merged full object mesh (`object.glb`).

    Raises:
        gr.Error: If no input image was provided.

    Notes:
        - This function utilizes HuggingFace pretrained weights for both part generation and background removal.
        - Each generated part is also written to `part_XX.glb` in the same export directory.
        - Generation time depends on the number of parts and inference parameters.
    """
    # Fail with a readable message instead of an error from deep inside the
    # image preprocessing when the user has not supplied an image.
    if image_path is None:
        raise gr.Error("Please provide an input image first")

    max_num_expanded_coords = 1e9

    if session_id is None:
        session_id = uuid.uuid4().hex

    if rmbg:
        img_pil = prepare_image(image_path, bg_color=np.array([1.0, 1.0, 1.0]), rmbg_net=rmbg_net)
    else:
        img_pil = Image.open(image_path)

    set_seed(seed)
    start_time = time.time()
    outputs = pipe(
        image=[img_pil] * num_parts,
        attention_kwargs={"num_parts": num_parts},
        num_tokens=num_tokens,
        generator=torch.Generator(device=pipe.device).manual_seed(seed),
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        max_num_expanded_coords=max_num_expanded_coords,
        use_flash_decoder=use_flash_decoder,
    ).meshes
    duration = time.time() - start_time
    print(f"Generation time: {duration:.2f}s")

    # Ensure no None outputs: substitute a one-vertex placeholder mesh.
    for i, mesh in enumerate(outputs):
        if mesh is None:
            outputs[i] = trimesh.Trimesh(vertices=[[0,0,0]], faces=[[0,0,0]])

    export_dir = os.path.join(os.environ["PARTCRAFTER_PROCESSED"], session_id)
    # If it already exists, delete it (and all its contents)
    if os.path.exists(export_dir):
        shutil.rmtree(export_dir)
    os.makedirs(export_dir, exist_ok=True)

    # One .glb file per generated part.
    for idx, mesh in enumerate(outputs):
        mesh.export(os.path.join(export_dir, f"part_{idx:02}.glb"))

    # Merge and color
    merged = get_colored_mesh_composition(outputs)
    # NOTE(review): the exploded mesh is never used afterwards; the call is
    # kept because it is not visible here whether explode_mesh has side
    # effects on `merged` — confirm before removing.
    split_mesh = explode_mesh(merged)
    merged_path = os.path.join(export_dir, "object.glb")
    merged.export(merged_path)
    return merged_path
def cleanup(request: gr.Request):
    """Remove the export directory named after the request's session hash.

    Does nothing when the request carries no session hash; errors while
    deleting are ignored.
    """
    session = request.session_hash
    if not session:
        return
    session_dir = os.path.join(os.environ["PARTCRAFTER_PROCESSED"], session)
    shutil.rmtree(session_dir, ignore_errors=True)
def start_session(request: gr.Request):
    """Return the session hash carried by ``request``."""
    session_hash = request.session_hash
    return session_hash
def build_demo():
    """Assemble the Gradio Blocks interface and return it.

    The page holds the input image and generation controls, a 3D viewer plus
    preview image for the results, a set of cached examples, and the two
    button handlers (part generation and preview rendering).
    """
    # NOTE(review): the nesting of the layout contexts below was reconstructed
    # from statement order — confirm it against the intended page layout.
    page_css = """
#col-container {
margin: 0 auto;
max-width: 1560px;
}
"""
    header_html = """
<div style="text-align: center;">
<p style="font-size:16px; display: inline; margin: 0;">
<strong>PartCrafter</strong> – Structured 3D Mesh Generation via Compositional Latent Diffusion Transformers
</p>
<a href="https://github.com/wgsxm/PartCrafter" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
<img src="https://img.shields.io/badge/GitHub-Repo-blue" alt="GitHub Repo">
</a>
</div>
<div style="text-align: center;">
HF Space by :<a href="https://twitter.com/alexandernasa/" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
<img src="https://img.shields.io/twitter/url/https/twitter.com/cloudposse.svg?style=social&label=Follow Me" alt="GitHub Repo">
</a>
</div>
"""
    preview_hint_html = """
<p style="opacity: 0.6; font-style: italic;">
The 3D Preview might take a few seconds to load the 3D model
</p>
"""
    # Each row is (input image path, number of parts).
    example_rows = [
        ["assets/images/np5_b81f29e567ea4db48014f89c9079e403.png", 5],
        ["assets/images/np7_1c004909dedb4ebe8db69b4d7b077434.png", 7],
        ["assets/images/np16_dino.png", 16],
        ["assets/images/np13_39c0fa16ed324b54a605dcdbcd80797c.png", 13],
    ]

    with gr.Blocks(css=page_css, theme=gr.themes.Ocean()) as demo:
        # Per-session value filled by start_session when the page loads.
        session_state = gr.State()
        demo.load(start_session, outputs=[session_state])

        with gr.Column(elem_id="col-container"):
            gr.HTML(header_html)

            with gr.Row():
                # Left: input image and generation controls.
                with gr.Column(scale=1):
                    input_image = gr.Image(type="filepath", label="Input Image", height=256)
                    num_parts = gr.Slider(1, MAX_NUM_PARTS, value=4, step=1, label="Number of Parts")
                    run_button = gr.Button("Step 1 - 🧩 Craft 3D Parts", variant="primary")
                    video_button = gr.Button("Step 2 - 🎥 Generate Split Preview Gif (Optional)")

                    with gr.Accordion("Advanced Settings", open=False):
                        seed = gr.Number(value=0, label="Random Seed", precision=0)
                        num_tokens = gr.Slider(256, 2048, value=1024, step=64, label="Num Tokens")
                        num_steps = gr.Slider(1, 100, value=50, step=1, label="Inference Steps")
                        guidance = gr.Slider(1.0, 20.0, value=7.0, step=0.1, label="Guidance Scale")
                        flash_decoder = gr.Checkbox(value=False, label="Use Flash Decoder")
                        remove_bg = gr.Checkbox(value=True, label="Remove Background (RMBG)")

                # Right: result viewers.
                with gr.Column(scale=2):
                    gr.HTML(preview_hint_html)
                    with gr.Row():
                        output_model = gr.Model3D(label="Merged 3D Object", height=512, interactive=False)
                        video_output = gr.Image(label="Split Preview", height=512)

            with gr.Row():
                with gr.Column():
                    gr.Examples(
                        examples=example_rows,
                        inputs=[input_image, num_parts],
                        outputs=[output_model, video_output],
                        fn=gen_model_n_video,
                        cache_examples=True
                    )

            # Step 1: generate the parts and show the merged mesh.
            run_button.click(
                fn=run_partcrafter,
                inputs=[
                    input_image, num_parts, seed, num_tokens, num_steps,
                    guidance, flash_decoder, remove_bg, session_state,
                ],
                outputs=[output_model],
            )
            # Step 2: render the preview from the mesh currently shown.
            video_button.click(
                fn=gen_video,
                inputs=[output_model],
                outputs=[video_output],
            )

    return demo
if __name__ == "__main__":
    # Build the UI, register the per-session cleanup handler, then enable the
    # queue and start the server (MCP server on, server-side rendering off).
    demo = build_demo()
    demo.unload(cleanup)
    demo.queue().launch(mcp_server=True, ssr_mode=False)