0xZohar committed on
Commit
9fa6864
·
verified ·
1 Parent(s): 3db6945

Update code/demo.py: 主程序(支持Model Hub权重加载)

Browse files
Files changed (1) hide show
  1. code/demo.py +29 -14
code/demo.py CHANGED
@@ -319,33 +319,48 @@ def generate_ldr_gpu(ldr_content, ldr_path):
319
  bert_shift = 1
320
  shift = 0
321
 
322
- # Prepare checkpoint paths (matching original demo.zip actual behavior)
323
  config_path = os.path.join(os.path.dirname(__file__), 'cube3d/configs/open_model_v0.5.yaml')
324
 
325
  # Detect HuggingFace Spaces environment
326
  is_hf_space = os.getenv("SPACE_ID") is not None
327
 
328
  if is_hf_space:
329
- # HF Spaces: Download checkpoint
330
  from huggingface_hub import hf_hub_download
331
- checkpoint_path = hf_hub_download(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
332
  repo_id="0xZohar/object-assembler-models",
333
  filename="save_shape_cars_whole_p_rot_scratch_4mask_randp.safetensors",
334
  cache_dir=HF_CACHE_DIR,
335
  local_files_only=False
336
  )
337
- # Original demo.zip references 3 files, but only 1 exists
338
- # In test mode, engine.py ignores gpt_ckpt_path, uses save_gpt_ckpt_path for gpt_model
339
- # and uses shape_ckpt_path for shape_model. Both must point to existing file.
340
- gpt_ckpt_path = checkpoint_path # Unused in test mode, but must be valid path
341
- shape_ckpt_path = checkpoint_path # Used for shape_model
342
- save_gpt_ckpt_path = checkpoint_path # Used for gpt_model
343
  else:
344
- # Local environment: All point to the same existing file (matching demo.zip reality)
345
- checkpoint_path = 'model_weights/save_shape_cars_whole_p_rot_scratch_4mask_randp.safetensors'
346
- gpt_ckpt_path = checkpoint_path
347
- shape_ckpt_path = checkpoint_path
348
- save_gpt_ckpt_path = checkpoint_path
349
 
350
  # Create fresh engine instance (fixes state corruption from caching)
351
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
319
  bert_shift = 1
320
  shift = 0
321
 
322
+ # Prepare checkpoint paths (3 separate weight files as per original demo.zip design)
323
  config_path = os.path.join(os.path.dirname(__file__), 'cube3d/configs/open_model_v0.5.yaml')
324
 
325
  # Detect HuggingFace Spaces environment
326
  is_hf_space = os.getenv("SPACE_ID") is not None
327
 
328
  if is_hf_space:
329
+ # HF Spaces: Download all 3 checkpoints from Model Hub
330
  from huggingface_hub import hf_hub_download
331
+ print("📥 Downloading model weights from HuggingFace Hub...")
332
+
333
+ # Download base GPT model (7.17 GB)
334
+ gpt_ckpt_path = hf_hub_download(
335
+ repo_id="0xZohar/object-assembler-models",
336
+ filename="shape_gpt.safetensors",
337
+ cache_dir=HF_CACHE_DIR,
338
+ local_files_only=False
339
+ )
340
+ print(f" ✓ Base GPT model loaded")
341
+
342
+ # Download shape tokenizer (1.09 GB)
343
+ shape_ckpt_path = hf_hub_download(
344
+ repo_id="0xZohar/object-assembler-models",
345
+ filename="shape_tokenizer.safetensors",
346
+ cache_dir=HF_CACHE_DIR,
347
+ local_files_only=False
348
+ )
349
+ print(f" ✓ Shape tokenizer loaded")
350
+
351
+ # Download fine-tuned adapter (1.68 GB)
352
+ save_gpt_ckpt_path = hf_hub_download(
353
  repo_id="0xZohar/object-assembler-models",
354
  filename="save_shape_cars_whole_p_rot_scratch_4mask_randp.safetensors",
355
  cache_dir=HF_CACHE_DIR,
356
  local_files_only=False
357
  )
358
+ print(f" ✓ Fine-tuned adapter loaded")
 
 
 
 
 
359
  else:
360
+ # Local environment: Use local paths (matching original demo.zip structure)
361
+ gpt_ckpt_path = 'temp_weights/shape_gpt.safetensors'
362
+ shape_ckpt_path = 'temp_weights/shape_tokenizer.safetensors'
363
+ save_gpt_ckpt_path = '/private/tmp/demo_extracted/demo/code/model_weights/save_shape_cars_whole_p_rot_scratch_4mask_randp.safetensors'
 
364
 
365
  # Create fresh engine instance (fixes state corruption from caching)
366
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')