Spaces:
Paused
Paused
Fix: Match demo.zip checkpoint loading logic exactly - all 3 paths point to same existing file
Browse files- code/demo.py +30 -7
code/demo.py
CHANGED
|
@@ -7,6 +7,11 @@ from PIL import Image, ImageDraw, ImageFont
|
|
| 7 |
import numpy as np
|
| 8 |
import torch
|
| 9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
# ZeroGPU Support - CRITICAL for HuggingFace Spaces
|
| 11 |
try:
|
| 12 |
import spaces
|
|
@@ -79,17 +84,28 @@ def model_predict(ldr_content):
|
|
| 79 |
|
| 80 |
DEFAULT_PART_RENDER_PATH = "../data/car_1k/demos/example/part_ldr_1k_render/"
|
| 81 |
os.makedirs(DEFAULT_PART_RENDER_PATH, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
def get_part_renderings(part_names):
|
| 83 |
renderings = []
|
| 84 |
for part in part_names:
|
| 85 |
# 拼接零件对应的渲染图路径(假设文件名与part_name一致,后缀为.png)
|
| 86 |
# 例如:part为"3001.dat" → 对应路径为 "./part_renders/3001.png"
|
| 87 |
-
part_base = part.replace(".dat", "") #
|
| 88 |
part_render_path = os.path.join(DEFAULT_PART_RENDER_PATH, f"{part_base}.png")
|
| 89 |
# 检查文件是否存在,不存在则使用默认缺失图(可选逻辑)
|
| 90 |
if not os.path.exists(part_render_path):
|
| 91 |
# 若需要,可指定一张"未知零件"的默认图路径
|
| 92 |
-
part_render_path =
|
| 93 |
|
| 94 |
renderings.append((part_render_path, part)) # (图片路径, 零件名)
|
| 95 |
return renderings
|
|
@@ -303,14 +319,14 @@ def generate_ldr_gpu(ldr_content, ldr_path):
|
|
| 303 |
bert_shift = 1
|
| 304 |
shift = 0
|
| 305 |
|
| 306 |
-
# Prepare checkpoint paths
|
| 307 |
config_path = os.path.join(os.path.dirname(__file__), 'cube3d/configs/open_model_v0.5.yaml')
|
| 308 |
-
gpt_ckpt_path = None # test mode doesn't use this
|
| 309 |
|
| 310 |
# Detect HuggingFace Spaces environment
|
| 311 |
is_hf_space = os.getenv("SPACE_ID") is not None
|
| 312 |
|
| 313 |
if is_hf_space:
|
|
|
|
| 314 |
from huggingface_hub import hf_hub_download
|
| 315 |
checkpoint_path = hf_hub_download(
|
| 316 |
repo_id="0xZohar/object-assembler-models",
|
|
@@ -318,11 +334,18 @@ def generate_ldr_gpu(ldr_content, ldr_path):
|
|
| 318 |
cache_dir=HF_CACHE_DIR,
|
| 319 |
local_files_only=False
|
| 320 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 321 |
shape_ckpt_path = checkpoint_path
|
| 322 |
save_gpt_ckpt_path = checkpoint_path
|
| 323 |
-
else:
|
| 324 |
-
shape_ckpt_path = 'model_weights/save_shape_cars_whole_p_rot_scratch_4mask_randp.safetensors'
|
| 325 |
-
save_gpt_ckpt_path = 'model_weights/save_shape_cars_whole_p_rot_scratch_4mask_randp.safetensors'
|
| 326 |
|
| 327 |
# Create fresh engine instance (fixes state corruption from caching)
|
| 328 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
|
| 7 |
import numpy as np
|
| 8 |
import torch
|
| 9 |
|
| 10 |
+
# Normalize OMP threads for libgomp (HF Spaces sometimes inject invalid values)
|
| 11 |
+
_omp_env = os.getenv("OMP_NUM_THREADS", "")
|
| 12 |
+
if _omp_env and not _omp_env.isdigit():
|
| 13 |
+
os.environ["OMP_NUM_THREADS"] = "4"
|
| 14 |
+
|
| 15 |
# ZeroGPU Support - CRITICAL for HuggingFace Spaces
|
| 16 |
try:
|
| 17 |
import spaces
|
|
|
|
| 84 |
|
| 85 |
DEFAULT_PART_RENDER_PATH = "../data/car_1k/demos/example/part_ldr_1k_render/"
|
| 86 |
os.makedirs(DEFAULT_PART_RENDER_PATH, exist_ok=True)
|
| 87 |
+
# Ensure a visual placeholder exists to avoid broken images in the gallery
|
| 88 |
+
UNKNOWN_PART_IMG = os.path.join(DEFAULT_PART_RENDER_PATH, "unknown_part.png")
|
| 89 |
+
if not os.path.exists(UNKNOWN_PART_IMG):
|
| 90 |
+
os.makedirs(os.path.dirname(UNKNOWN_PART_IMG), exist_ok=True)
|
| 91 |
+
img = Image.new("RGB", (256, 256), color=(240, 240, 240))
|
| 92 |
+
draw = ImageDraw.Draw(img)
|
| 93 |
+
text = "No preview"
|
| 94 |
+
draw.text((70, 120), text, fill=(120, 120, 120))
|
| 95 |
+
img.save(UNKNOWN_PART_IMG)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
def get_part_renderings(part_names):
|
| 99 |
renderings = []
|
| 100 |
for part in part_names:
|
| 101 |
# 拼接零件对应的渲染图路径(假设文件名与part_name一致,后缀为.png)
|
| 102 |
# 例如:part为"3001.dat" → 对应路径为 "./part_renders/3001.png"
|
| 103 |
+
part_base = part.replace(".dat", "").replace("/", "_") # 去除 .dat 后缀并将路径分隔符替换为下划线,避免非法文件名
|
| 104 |
part_render_path = os.path.join(DEFAULT_PART_RENDER_PATH, f"{part_base}.png")
|
| 105 |
# 检查文件是否存在,不存在则使用默认缺失图(可选逻辑)
|
| 106 |
if not os.path.exists(part_render_path):
|
| 107 |
# 若需要,可指定一张"未知零件"的默认图路径
|
| 108 |
+
part_render_path = UNKNOWN_PART_IMG
|
| 109 |
|
| 110 |
renderings.append((part_render_path, part)) # (图片路径, 零件名)
|
| 111 |
return renderings
|
|
|
|
| 319 |
bert_shift = 1
|
| 320 |
shift = 0
|
| 321 |
|
| 322 |
+
# Prepare checkpoint paths (matching original demo.zip actual behavior)
|
| 323 |
config_path = os.path.join(os.path.dirname(__file__), 'cube3d/configs/open_model_v0.5.yaml')
|
|
|
|
| 324 |
|
| 325 |
# Detect HuggingFace Spaces environment
|
| 326 |
is_hf_space = os.getenv("SPACE_ID") is not None
|
| 327 |
|
| 328 |
if is_hf_space:
|
| 329 |
+
# HF Spaces: Download checkpoint
|
| 330 |
from huggingface_hub import hf_hub_download
|
| 331 |
checkpoint_path = hf_hub_download(
|
| 332 |
repo_id="0xZohar/object-assembler-models",
|
|
|
|
| 334 |
cache_dir=HF_CACHE_DIR,
|
| 335 |
local_files_only=False
|
| 336 |
)
|
| 337 |
+
# Original demo.zip references 3 files, but only 1 exists
|
| 338 |
+
# In test mode, engine.py ignores gpt_ckpt_path, uses save_gpt_ckpt_path for gpt_model
|
| 339 |
+
# and uses shape_ckpt_path for shape_model. Both must point to existing file.
|
| 340 |
+
gpt_ckpt_path = checkpoint_path # Unused in test mode, but must be valid path
|
| 341 |
+
shape_ckpt_path = checkpoint_path # Used for shape_model
|
| 342 |
+
save_gpt_ckpt_path = checkpoint_path # Used for gpt_model
|
| 343 |
+
else:
|
| 344 |
+
# Local environment: All point to the same existing file (matching demo.zip reality)
|
| 345 |
+
checkpoint_path = 'model_weights/save_shape_cars_whole_p_rot_scratch_4mask_randp.safetensors'
|
| 346 |
+
gpt_ckpt_path = checkpoint_path
|
| 347 |
shape_ckpt_path = checkpoint_path
|
| 348 |
save_gpt_ckpt_path = checkpoint_path
|
|
|
|
|
|
|
|
|
|
| 349 |
|
| 350 |
# Create fresh engine instance (fixes state corruption from caching)
|
| 351 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|