0xZohar commited on
Commit
3db6945
·
verified ·
1 Parent(s): dae7e2f

Fix: Match demo.zip checkpoint loading logic exactly - all 3 paths point to same existing file

Browse files
Files changed (1) hide show
  1. code/demo.py +30 -7
code/demo.py CHANGED
@@ -7,6 +7,11 @@ from PIL import Image, ImageDraw, ImageFont
7
  import numpy as np
8
  import torch
9
 
 
 
 
 
 
10
  # ZeroGPU Support - CRITICAL for HuggingFace Spaces
11
  try:
12
  import spaces
@@ -79,17 +84,28 @@ def model_predict(ldr_content):
79
 
80
  DEFAULT_PART_RENDER_PATH = "../data/car_1k/demos/example/part_ldr_1k_render/"
81
  os.makedirs(DEFAULT_PART_RENDER_PATH, exist_ok=True)
 
 
 
 
 
 
 
 
 
 
 
82
  def get_part_renderings(part_names):
83
  renderings = []
84
  for part in part_names:
85
  # 拼接零件对应的渲染图路径(假设文件名与part_name一致,后缀为.png)
86
  # 例如:part为"3001.dat" → 对应路径为 "./part_renders/3001.dat.png"
87
- part_base = part.replace(".dat", "") # 统一转为小写并移除.dat
88
  part_render_path = os.path.join(DEFAULT_PART_RENDER_PATH, f"{part_base}.png")
89
  # 检查文件是否存在,不存在则使用默认缺失图(可选逻辑)
90
  if not os.path.exists(part_render_path):
91
  # 若需要,可指定一张"未知零件"的默认图路径
92
- part_render_path = os.path.join(DEFAULT_PART_RENDER_PATH, "unknown_part.png")
93
 
94
  renderings.append((part_render_path, part)) # (图片路径, 零件名)
95
  return renderings
@@ -303,14 +319,14 @@ def generate_ldr_gpu(ldr_content, ldr_path):
303
  bert_shift = 1
304
  shift = 0
305
 
306
- # Prepare checkpoint paths
307
  config_path = os.path.join(os.path.dirname(__file__), 'cube3d/configs/open_model_v0.5.yaml')
308
- gpt_ckpt_path = None # test mode doesn't use this
309
 
310
  # Detect HuggingFace Spaces environment
311
  is_hf_space = os.getenv("SPACE_ID") is not None
312
 
313
  if is_hf_space:
 
314
  from huggingface_hub import hf_hub_download
315
  checkpoint_path = hf_hub_download(
316
  repo_id="0xZohar/object-assembler-models",
@@ -318,11 +334,18 @@ def generate_ldr_gpu(ldr_content, ldr_path):
318
  cache_dir=HF_CACHE_DIR,
319
  local_files_only=False
320
  )
 
 
 
 
 
 
 
 
 
 
321
  shape_ckpt_path = checkpoint_path
322
  save_gpt_ckpt_path = checkpoint_path
323
- else:
324
- shape_ckpt_path = 'model_weights/save_shape_cars_whole_p_rot_scratch_4mask_randp.safetensors'
325
- save_gpt_ckpt_path = 'model_weights/save_shape_cars_whole_p_rot_scratch_4mask_randp.safetensors'
326
 
327
  # Create fresh engine instance (fixes state corruption from caching)
328
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
7
  import numpy as np
8
  import torch
9
 
10
+ # Normalize OMP threads for libgomp (HF Spaces sometimes inject invalid values)
11
+ _omp_env = os.getenv("OMP_NUM_THREADS", "")
12
+ if _omp_env and not _omp_env.isdigit():
13
+ os.environ["OMP_NUM_THREADS"] = "4"
14
+
15
  # ZeroGPU Support - CRITICAL for HuggingFace Spaces
16
  try:
17
  import spaces
 
84
 
85
  DEFAULT_PART_RENDER_PATH = "../data/car_1k/demos/example/part_ldr_1k_render/"
86
  os.makedirs(DEFAULT_PART_RENDER_PATH, exist_ok=True)
87
+ # Ensure a visual placeholder exists to avoid broken images in the gallery
88
+ UNKNOWN_PART_IMG = os.path.join(DEFAULT_PART_RENDER_PATH, "unknown_part.png")
89
+ if not os.path.exists(UNKNOWN_PART_IMG):
90
+ os.makedirs(os.path.dirname(UNKNOWN_PART_IMG), exist_ok=True)
91
+ img = Image.new("RGB", (256, 256), color=(240, 240, 240))
92
+ draw = ImageDraw.Draw(img)
93
+ text = "No preview"
94
+ draw.text((70, 120), text, fill=(120, 120, 120))
95
+ img.save(UNKNOWN_PART_IMG)
96
+
97
+
98
  def get_part_renderings(part_names):
99
  renderings = []
100
  for part in part_names:
101
  # 拼接零件对应的渲染图路径(假设文件名与part_name一致,后缀为.png)
102
  # 例如:part为"3001.dat" → 对应路径为 "./part_renders/3001.dat.png"
103
+ part_base = part.replace(".dat", "").replace("/", "_") # 统一转为小写并移除非法路径分隔符
104
  part_render_path = os.path.join(DEFAULT_PART_RENDER_PATH, f"{part_base}.png")
105
  # 检查文件是否存在,不存在则使用默认缺失图(可选逻辑)
106
  if not os.path.exists(part_render_path):
107
  # 若需要,可指定一张"未知零件"的默认图路径
108
+ part_render_path = UNKNOWN_PART_IMG
109
 
110
  renderings.append((part_render_path, part)) # (图片路径, 零件名)
111
  return renderings
 
319
  bert_shift = 1
320
  shift = 0
321
 
322
+ # Prepare checkpoint paths (matching original demo.zip actual behavior)
323
  config_path = os.path.join(os.path.dirname(__file__), 'cube3d/configs/open_model_v0.5.yaml')
 
324
 
325
  # Detect HuggingFace Spaces environment
326
  is_hf_space = os.getenv("SPACE_ID") is not None
327
 
328
  if is_hf_space:
329
+ # HF Spaces: Download checkpoint
330
  from huggingface_hub import hf_hub_download
331
  checkpoint_path = hf_hub_download(
332
  repo_id="0xZohar/object-assembler-models",
 
334
  cache_dir=HF_CACHE_DIR,
335
  local_files_only=False
336
  )
337
+ # Original demo.zip references 3 files, but only 1 exists
338
+ # In test mode, engine.py ignores gpt_ckpt_path, uses save_gpt_ckpt_path for gpt_model
339
+ # and uses shape_ckpt_path for shape_model. Both must point to existing file.
340
+ gpt_ckpt_path = checkpoint_path # Unused in test mode, but must be valid path
341
+ shape_ckpt_path = checkpoint_path # Used for shape_model
342
+ save_gpt_ckpt_path = checkpoint_path # Used for gpt_model
343
+ else:
344
+ # Local environment: All point to the same existing file (matching demo.zip reality)
345
+ checkpoint_path = 'model_weights/save_shape_cars_whole_p_rot_scratch_4mask_randp.safetensors'
346
+ gpt_ckpt_path = checkpoint_path
347
  shape_ckpt_path = checkpoint_path
348
  save_gpt_ckpt_path = checkpoint_path
 
 
 
349
 
350
  # Create fresh engine instance (fixes state corruption from caching)
351
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')