# NOTE(review): this module performs heavy side effects at import time
# (pip installs, sys.modules mocking). Statement order is load-bearing
# and is preserved exactly as written.
import os
import sys
import subprocess  # <--- make sure this import stays right here (the installers below need it)!
import importlib
import site
import time


# --- 🧪 1. In-memory fake "diso" (must run before any business-logic import) ---
def mock_diso():
    """Register a stub `diso` package in sys.modules so later imports succeed.

    The stub exposes a no-op `DiffDMC` class; the native submodules
    `diso._C` and `diso.diso_native` are registered as empty modules too.
    """
    from types import ModuleType
    print("🧪 Creating emergency mock for diso...")
    diso = ModuleType("diso")

    class FakeDiffDMC:
        # No-op stand-in: construction and invocation both do nothing.
        def __init__(self, *args, **kwargs):
            pass

        def __call__(self, *args, **kwargs):
            return None

    diso.DiffDMC = FakeDiffDMC
    sys.modules["diso"] = diso
    sys.modules["diso._C"] = ModuleType("diso._C")
    sys.modules["diso.diso_native"] = ModuleType("diso.diso_native")
    print("✅ diso has been mocked successfully!")


mock_diso()


# --- 🚀 2. Fast environment install (the scatter/sparse wheels that already work) ---
def install_essential_packages():
    """Install build tooling, PyG extension wheels and light rendering deps.

    NOTE(review): subprocess.run is called without check=True, so a failed
    pip install is silently ignored and startup continues best-effort.
    """
    print("📦 Checking core dependencies...")
    # Make sure the base build environment is correct.
    subprocess.run([sys.executable, "-m", "pip", "install", "ninja", "setuptools", "wheel", "-q"])
    # Fast install of the PyG extensions from prebuilt wheels.
    subprocess.run([
        sys.executable, "-m", "pip", "install",
        "torch-scatter", "torch-sparse", "torch-cluster",
        "-f", "https://data.pyg.org/whl/torch-2.4.0+cu121.html",
        "--no-cache-dir"
    ])
    # Install the remaining rendering tools.
    subprocess.run([
        sys.executable, "-m", "pip", "install",
        "pyrender", "pyopengl==3.1.0", "pyyaml", "trimesh", "accelerate", "-q"
    ])
    importlib.invalidate_caches()
    site.main()  # re-process .pth files so freshly installed packages become importable
    print("🎉 Environment Installation Phase Finished.")


install_essential_packages()

# ... mock_diso and the install logic run above ...

# 1. Core output-path protection
os.environ["PARTCRAFTER_PROCESSED"] = os.environ.get("PARTCRAFTER_PROCESSED", "outputs")
os.makedirs(os.environ["PARTCRAFTER_PROCESSED"], exist_ok=True)

# 2. Model-weight download paths (make sure these directories exist as well)
os.makedirs("pretrained_weights/PartCrafter", exist_ok=True)
os.makedirs("pretrained_weights/RMBG-1.4", exist_ok=True)

# ... snapshot_download is executed further below ...

# --- 3. Business-logic imports (safe now that the environment is prepared) ---
import spaces
import gradio as gr
import numpy as np
import torch
import uuid
import shutil
from huggingface_hub import snapshot_download
from PIL import Image
from accelerate.utils import set_seed

# All of the original business logic (PartCrafterPipeline etc.) continues below.
# --- 🚀 Core fix: force a version rollback to avoid compiling from source ---
def pre_install_check():
    """Downgrade torch 2.9.x to 2.4.0+cu121 (which has prebuilt extension wheels).

    If a downgrade happens, the process re-execs itself via os.execv so the
    new torch version is actually loaded. Any error is printed, never fatal.
    """
    try:
        import torch
        # If this is a 2.9+ build, force a downgrade to 2.4.0 which has prebuilt wheels.
        if "2.9" in torch.__version__:
            print(f"🔄 Current torch {torch.__version__} is too new. Downgrading to 2.4.0 for speed...")
            subprocess.check_call([sys.executable, "-m", "pip", "install", "ninja", "setuptools", "wheel", "-q"])
            subprocess.check_call([
                sys.executable, "-m", "pip", "install",
                "torch==2.4.0+cu121", "torchvision==0.19.0+cu121",
                "--extra-index-url", "https://download.pytorch.org/whl/cu121"
            ])
            # Refresh the import caches.
            importlib.invalidate_caches()
            os.execv(sys.executable, ['python'] + sys.argv)  # restart the process to load the new version
    except Exception as e:
        print(f"Pre-install check note: {e}")


pre_install_check()

import trimesh
import glob
import importlib, site  # NOTE(review): re-import of importlib/site is redundant but harmless

# Re-discover all .pth/.egg-link files
for sitedir in site.getsitepackages():
    site.addsitedir(sitedir)
importlib.invalidate_caches()


# --- Simplified CUDA environment configuration ---
def setup_cuda_env():
    """Point CUDA_HOME/PATH/LD_LIBRARY_PATH at the system CUDA install, if present."""
    cuda_path = "/usr/local/cuda"
    if os.path.exists(cuda_path):
        os.environ["CUDA_HOME"] = cuda_path
        os.environ["PATH"] = f"{cuda_path}/bin:{os.environ['PATH']}"
        os.environ["LD_LIBRARY_PATH"] = f"{cuda_path}/lib64:{os.environ.get('LD_LIBRARY_PATH', '')}"
        print(f"==> Using system CUDA at {cuda_path}")


setup_cuda_env()


# --- 🚀 Optimized source-build plan targeting PyTorch 2.9.1 ---
# --- 🚀 Brute-force combined plan: breaking diso's last line of defense ---
def install_heavy_packages():
    """Install the PyG extensions, clone the diso sources, and install rendering deps."""
    os.environ['PYOPENGL_PLATFORM'] = 'egl'  # headless EGL rendering for pyrender
    # 1. PyG extensions (this part is already stable — keep unchanged).
    print("📦 Installing PyG extensions...")
    subprocess.run([
        sys.executable, "-m", "pip", "install",
        "torch-scatter", "torch-sparse", "torch-cluster",
        "-f", "https://data.pyg.org/whl/torch-2.4.0+cu121.html"
    ], check=True)
    # 2. Brute-force diso: clone the sources -> force the import.
    print("🔥 Attempting D-Plan: Manual diso injection...")
    diso_path = os.path.join(os.getcwd(), "diso_source")
    if not os.path.exists(diso_path):
        subprocess.run(["git", "clone", "https://github.com/SarahWeiii/diso.git", diso_path])
    # Add the diso source tree directly to the module search path, so Python
    # can still find the package structure even if the .so was never built.
    if diso_path not in sys.path:
        sys.path.insert(0, diso_path)
    # 3. Install rendering and other lightweight dependencies.
    print("📦 Installing rendering tools...")
    subprocess.run([sys.executable, "-m", "pip", "install", "pyrender", "pyopengl==3.1.0", "pyyaml", "-q"], check=True)
    importlib.invalidate_caches()
    print("🎉 Environment Installation Phase Finished.")


# Run the installation.
install_heavy_packages()

# --- 🛰️ Important: diso import patch ---
try:
    import diso
    print("✅ diso imported successfully!")
except ImportError:
    # If the import still fails, expose the cloned package directory directly.
    print("⚠️ diso import failed, applying emergency mock...")
    diso_src_path = os.path.join(os.getcwd(), "diso_source")
    sys.path.insert(0, diso_src_path)  # force Python to see the diso directory
    importlib.invalidate_caches()

# ... the business logic below is unchanged ...
from src.utils.data_utils import get_colored_mesh_composition, scene_to_parts, load_surfaces
from src.utils.render_utils import render_views_around_mesh, render_normal_views_around_mesh, make_grid_for_images_or_videos, export_renderings, explode_mesh
from src.pipelines.pipeline_partcrafter import PartCrafterPipeline
from src.utils.image_utils import prepare_image
from src.models.briarmbg import BriaRMBG

# Constants
MAX_NUM_PARTS = 16
DEVICE = "cuda"
DTYPE = torch.float16

# Download and initialize models
partcrafter_weights_dir = "pretrained_weights/PartCrafter"
rmbg_weights_dir = "pretrained_weights/RMBG-1.4"
snapshot_download(repo_id="wgsxm/PartCrafter", local_dir=partcrafter_weights_dir)
snapshot_download(repo_id="briaai/RMBG-1.4", local_dir=rmbg_weights_dir)
rmbg_net = BriaRMBG.from_pretrained(rmbg_weights_dir).to(DEVICE)
rmbg_net.eval()
pipe: PartCrafterPipeline = PartCrafterPipeline.from_pretrained(partcrafter_weights_dir).to(DEVICE, DTYPE)


def first_file_from_dir(directory, ext):
    """Return the lexicographically first `*.{ext}` file in `directory`, or None if none exist."""
    files = glob.glob(os.path.join(directory, f"*.{ext}"))
    return sorted(files)[0] if files else None


def get_duration(
    image_path,
    num_parts,
    seed,
    num_tokens,
    num_inference_steps,
    guidance_scale,
    use_flash_decoder,
    rmbg,
    session_id,
    progress,
):
    """Estimate the GPU time budget (seconds) for run_partcrafter from the part count.

    Signature mirrors run_partcrafter's because spaces.GPU(duration=...) calls
    this with the same arguments; only num_parts influences the estimate.
    """
    duration_seconds = 75
    if num_parts > 10:
        duration_seconds = 120
    elif num_parts > 5:
        duration_seconds = 90
    return int(duration_seconds)


@spaces.GPU(duration=140)
def gen_model_n_video(image_path: str, num_parts: int, progress=gr.Progress(track_tqdm=True),):
    """Craft the 3D parts, then render the preview GIF; returns (model_path, video_path)."""
    model_path = run_partcrafter(image_path, num_parts=num_parts, progress=progress)
    video_path = gen_video(model_path)
    return model_path, video_path


@spaces.GPU()
def gen_video(model_path):
    """Render a 36-view turnaround GIF next to the exported model.

    Returns the GIF path, or None (with a UI notice) when no model exists yet.
    """
    if model_path is None:
        gr.Info("You must craft the 3d parts first")
        return None
    export_dir = os.path.dirname(model_path)
    merged = trimesh.load(model_path)
    preview_path = os.path.join(export_dir, "rendering.gif")
    num_views = 36
    radius = 4
    fps = 7
    rendered_images = render_views_around_mesh(
        merged,
        num_views=num_views,
        radius=radius,
    )
    export_renderings(
        rendered_images,
        preview_path,
        fps=fps,
    )
    return preview_path


@spaces.GPU(duration=get_duration)
@torch.no_grad()
def run_partcrafter(image_path: str,
                    num_parts: int = 1,
                    seed: int = 0,
                    num_tokens: int = 1024,
                    num_inference_steps: int = 50,
                    guidance_scale: float = 7.0,
                    use_flash_decoder: bool = False,
                    rmbg: bool = True,
                    session_id=None,
                    progress=gr.Progress(track_tqdm=True),):
    """
    Generate structured 3D meshes from a 2D image using the PartCrafter pipeline.

    This function takes a single 2D image as input and produces a set of part-based
    3D meshes, using compositional latent diffusion with attention to structure and
    part separation. Optionally removes the background using a pretrained background
    removal model (RMBG), and outputs a merged object mesh.

    Args:
        image_path (str): Path to the input image file on disk.
        num_parts (int, optional): Number of distinct parts to decompose the object into.
            Defaults to 1.
        seed (int, optional): Random seed for reproducibility. Defaults to 0.
        num_tokens (int, optional): Number of tokens used during latent encoding.
            Higher values yield finer detail. Defaults to 1024.
        num_inference_steps (int, optional): Number of diffusion inference steps.
            More steps improve quality but increase runtime. Defaults to 50.
        guidance_scale (float, optional): Classifier-free guidance scale. Higher values
            emphasize adherence to conditioning. Defaults to 7.0.
        use_flash_decoder (bool, optional): Whether to use FlashAttention in the decoder
            for performance. Defaults to False.
        rmbg (bool, optional): Whether to apply background removal before processing.
            Defaults to True.
        session_id (str, optional): Optional session ID to manage export paths.
            If not provided, a random UUID is generated.
        progress (gr.Progress, optional): Gradio progress object for visual feedback.
            Automatically handled by Gradio.

    Returns:
        str: `merged_path` — file path to the merged full object mesh (`object.glb`).

    Notes:
        - This function utilizes HuggingFace pretrained weights for both part generation
          and background removal.
        - The final output includes merged model parts to visualize object structure.
        - Generation time depends on the number of parts and inference parameters.
    """
    max_num_expanded_coords = 1e9
    if session_id is None:
        session_id = uuid.uuid4().hex
    if rmbg:
        img_pil = prepare_image(image_path, bg_color=np.array([1.0, 1.0, 1.0]), rmbg_net=rmbg_net)
    else:
        img_pil = Image.open(image_path)
    set_seed(seed)
    start_time = time.time()
    outputs = pipe(
        image=[img_pil] * num_parts,
        attention_kwargs={"num_parts": num_parts},
        num_tokens=num_tokens,
        generator=torch.Generator(device=pipe.device).manual_seed(seed),
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        max_num_expanded_coords=max_num_expanded_coords,
        use_flash_decoder=use_flash_decoder,
    ).meshes
    duration = time.time() - start_time
    print(f"Generation time: {duration:.2f}s")
    # Ensure no None outputs
    for i, mesh in enumerate(outputs):
        if mesh is None:
            # Replace a failed part with a degenerate single-triangle mesh placeholder.
            outputs[i] = trimesh.Trimesh(vertices=[[0,0,0]], faces=[[0,0,0]])
    export_dir = os.path.join(os.environ["PARTCRAFTER_PROCESSED"], session_id)
    # If it already exists, delete it (and all its contents)
    if os.path.exists(export_dir):
        shutil.rmtree(export_dir)
    os.makedirs(export_dir, exist_ok=True)
    parts = []
    for idx, mesh in enumerate(outputs):
        part = os.path.join(export_dir, f"part_{idx:02}.glb")
        mesh.export(part)
        parts.append(part)
    # Merge and color
    merged = get_colored_mesh_composition(outputs)
    split_mesh = explode_mesh(merged)  # NOTE(review): result is unused here
    merged_path = os.path.join(export_dir, "object.glb")
    merged.export(merged_path)
    return merged_path


def cleanup(request: gr.Request):
    """Delete the per-session export directory when the browser session ends."""
    sid = request.session_hash
    if sid:
        d1 = os.path.join(os.environ["PARTCRAFTER_PROCESSED"], sid)
        shutil.rmtree(d1, ignore_errors=True)


def start_session(request: gr.Request):
    """Expose the Gradio session hash so exports can be grouped per session."""
    return request.session_hash


def build_demo():
    """Build and return the Gradio Blocks UI for the PartCrafter demo."""
    css = """ #col-container { margin: 0 auto; max-width: 1560px; } """
    theme = gr.themes.Ocean()
    with gr.Blocks(css=css, theme=theme) as demo:
        session_state = gr.State()
        demo.load(start_session, outputs=[session_state])
        with gr.Column(elem_id="col-container"):
            gr.HTML(
                """
"""
            )
            with gr.Row():
                with gr.Column(scale=1):
                    input_image = gr.Image(type="filepath", label="Input Image", height=256)
                    num_parts = gr.Slider(1, MAX_NUM_PARTS, value=4, step=1, label="Number of Parts")
                    run_button = gr.Button("Step 1 - 🧩 Craft 3D Parts", variant="primary")
                    video_button = gr.Button("Step 2 - 🎥 Generate Split Preview Gif (Optional)")
                    with gr.Accordion("Advanced Settings", open=False):
                        seed = gr.Number(value=0, label="Random Seed", precision=0)
                        num_tokens = gr.Slider(256, 2048, value=1024, step=64, label="Num Tokens")
                        num_steps = gr.Slider(1, 100, value=50, step=1, label="Inference Steps")
                        guidance = gr.Slider(1.0, 20.0, value=7.0, step=0.1, label="Guidance Scale")
                        flash_decoder = gr.Checkbox(value=False, label="Use Flash Decoder")
                        remove_bg = gr.Checkbox(value=True, label="Remove Background (RMBG)")
                with gr.Column(scale=2):
                    gr.HTML(
                        """The 3D Preview might take a few seconds to load the 3D model
"""
                    )
                    with gr.Row():
                        output_model = gr.Model3D(label="Merged 3D Object", height=512, interactive=False)
                        video_output = gr.Image(label="Split Preview", height=512)
            with gr.Row():
                with gr.Column():
                    examples = gr.Examples(
                        examples=[
                            [
                                "assets/images/np5_b81f29e567ea4db48014f89c9079e403.png",
                                5,
                            ],
                            [
                                "assets/images/np7_1c004909dedb4ebe8db69b4d7b077434.png",
                                7,
                            ],
                            [
                                "assets/images/np16_dino.png",
                                16,
                            ],
                            [
                                "assets/images/np13_39c0fa16ed324b54a605dcdbcd80797c.png",
                                13,
                            ],
                        ],
                        inputs=[input_image, num_parts],
                        outputs=[output_model, video_output],
                        fn=gen_model_n_video,
                        cache_examples=True
                    )
        run_button.click(fn=run_partcrafter,
                         inputs=[input_image, num_parts, seed, num_tokens, num_steps,
                                 guidance, flash_decoder, remove_bg, session_state],
                         outputs=[output_model])
        video_button.click(fn=gen_video, inputs=[output_model], outputs=[video_output])
    return demo


if __name__ == "__main__":
    demo = build_demo()
    demo.unload(cleanup)
    demo.queue()
    demo.launch(mcp_server=True, ssr_mode=False)